diff --git a/api/enums/quota_type.py b/api/enums/quota_type.py index 95146d17d9..a10ac21f69 100644 --- a/api/enums/quota_type.py +++ b/api/enums/quota_type.py @@ -1,9 +1,5 @@ -import logging -import uuid from enum import StrEnum, auto -logger = logging.getLogger(__name__) - class QuotaType(StrEnum): """ @@ -23,164 +19,3 @@ class QuotaType(StrEnum): return "api_rate_limit" case _: raise ValueError(f"Invalid quota type: {self}") - - def consume(self, tenant_id: str, amount: int = 1): - """ - Consume quota using Reserve + immediate Commit. - - This is the simple one-shot mode: Reserve freezes quota, then Commit - confirms it right away. The returned QuotaCharge supports .refund() - which calls Release (idempotent even after Commit). - - For advanced two-phase usage (e.g. streaming), use reserve() directly - and call charge.commit() / charge.refund() manually. - - Args: - tenant_id: The tenant identifier - amount: Amount to consume (default: 1) - - Returns: - QuotaCharge with reservation_id for potential refund - - Raises: - QuotaExceededError: When quota is insufficient - """ - charge = self.reserve(tenant_id, amount) - if charge.success and charge.charge_id: - charge.commit() - return charge - - def reserve(self, tenant_id: str, amount: int = 1): - """ - Reserve quota before task execution (Reserve phase only). - - The caller MUST call charge.commit() after the task succeeds, - or charge.refund() if the task fails. - - If neither is called, the reservation auto-expires in ~75 seconds. - - Args: - tenant_id: The tenant identifier - amount: Amount to reserve (default: 1) - - Returns: - QuotaCharge — call .commit() on success, .refund() on failure - - Raises: - QuotaExceededError: When quota is insufficient - """ - from configs import dify_config - from services.billing_service import BillingService - from services.errors.app import QuotaExceededError - from services.quota_service import QuotaCharge, unlimited - - if not dify_config.BILLING_ENABLED: - logger.debug("Billing disabled, allowing request for %s", tenant_id) - return QuotaCharge(success=True, charge_id=None, _quota_type=self) - - logger.info("Reserving %d %s quota for tenant %s", amount, self.value, tenant_id) - - if amount <= 0: - raise ValueError("Amount to reserve must be greater than 0") - - request_id = str(uuid.uuid4()) - feature_key = self.billing_key - - try: - reserve_resp = BillingService.quota_reserve( - tenant_id=tenant_id, - feature_key=feature_key, - request_id=request_id, - amount=amount, - ) - - reservation_id = reserve_resp.get("reservation_id") - if not reservation_id: - logger.warning( - "Reserve returned no reservation_id for %s, feature %s, response: %s", - tenant_id, - self.value, - reserve_resp, - ) - raise QuotaExceededError(feature=self.value, tenant_id=tenant_id, required=amount) - - logger.debug( - "Reserved %d %s quota for tenant %s, reservation_id: %s", - amount, - self.value, - tenant_id, - reservation_id, - ) - return QuotaCharge( - success=True, - charge_id=reservation_id, - _quota_type=self, - _tenant_id=tenant_id, - _feature_key=feature_key, - _amount=amount, - ) - - except QuotaExceededError: - raise - except ValueError: - raise - except Exception: - logger.exception("Failed to reserve quota for %s, feature %s", tenant_id, self.value) - return unlimited() - - def check(self, tenant_id: str, amount: int = 1) -> bool: - from configs import dify_config - - if not dify_config.BILLING_ENABLED: - return True - - if amount <= 0: - raise ValueError("Amount to check must be greater than 0") - - try: - remaining = self.get_remaining(tenant_id) - return remaining >= amount if remaining != -1 else True - except Exception: - logger.exception("Failed to check quota for %s, feature %s", tenant_id, self.value) - return True - - def release(self, reservation_id: str, tenant_id: str, feature_key: str) -> None: - """ - Release a reservation. Guarantees no exceptions. - """ - try: - from configs import dify_config - from services.billing_service import BillingService - - if not dify_config.BILLING_ENABLED: - return - - if not reservation_id: - return - - logger.info("Releasing %s quota, reservation_id: %s", self.value, reservation_id) - BillingService.quota_release( - tenant_id=tenant_id, - feature_key=feature_key, - reservation_id=reservation_id, - ) - except Exception: - logger.exception("Failed to release quota, reservation_id: %s", reservation_id) - - def get_remaining(self, tenant_id: str) -> int: - from services.billing_service import BillingService - - try: - usage_info = BillingService.get_quota_info(tenant_id) - if isinstance(usage_info, dict): - feature_info = usage_info.get(self.billing_key, {}) - if isinstance(feature_info, dict): - limit = feature_info.get("limit", 0) - usage = feature_info.get("usage", 0) - if limit == -1: - return -1 - return max(0, limit - usage) - return 0 - except Exception: - logger.exception("Failed to get remaining quota for %s, feature %s", tenant_id, self.value) - return -1 diff --git a/api/services/app_generate_service.py b/api/services/app_generate_service.py index ee24424cd5..5bff841c10 100644 --- a/api/services/app_generate_service.py +++ b/api/services/app_generate_service.py @@ -24,7 +24,7 @@ from models.model import Account, App, AppMode, EndUser from models.workflow import Workflow, WorkflowRun from services.errors.app import QuotaExceededError, WorkflowIdFormatError, WorkflowNotFoundError from services.errors.llm import InvokeRateLimitError -from services.quota_service import unlimited +from services.quota_service import QuotaService, unlimited from services.workflow_service import WorkflowService from tasks.app_generate.workflow_execute_task import AppExecutionParams, workflow_based_app_execution_task @@ -107,7 +107,7 @@ class AppGenerateService: quota_charge = unlimited() if dify_config.BILLING_ENABLED: try: - quota_charge = QuotaType.WORKFLOW.reserve(app_model.tenant_id) + quota_charge = QuotaService.reserve(QuotaType.WORKFLOW, app_model.tenant_id) except QuotaExceededError: raise InvokeRateLimitError(f"Workflow execution quota limit reached for tenant {app_model.tenant_id}") diff --git a/api/services/async_workflow_service.py b/api/services/async_workflow_service.py index e78e03cce5..327756753c 100644 --- a/api/services/async_workflow_service.py +++ b/api/services/async_workflow_service.py @@ -22,7 +22,7 @@ from models.trigger import WorkflowTriggerLog, WorkflowTriggerLogDict from models.workflow import Workflow from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository from services.errors.app import QuotaExceededError, WorkflowNotFoundError, WorkflowQuotaLimitError -from services.quota_service import unlimited +from services.quota_service import QuotaService, unlimited from services.workflow.entities import AsyncTriggerResponse, TriggerData, WorkflowTaskData from services.workflow.queue_dispatcher import QueueDispatcherManager, QueuePriority from services.workflow_service import WorkflowService @@ -135,7 +135,7 @@ class AsyncWorkflowService: # 7. Reserve quota (commit after successful dispatch) quota_charge = unlimited() try: - quota_charge = QuotaType.WORKFLOW.reserve(trigger_data.tenant_id) + quota_charge = QuotaService.reserve(QuotaType.WORKFLOW, trigger_data.tenant_id) except QuotaExceededError as e: # Update trigger log status trigger_log.status = WorkflowTriggerStatus.RATE_LIMITED diff --git a/api/services/quota_service.py b/api/services/quota_service.py index 77cfdda3a3..4c784315c7 100644 --- a/api/services/quota_service.py +++ b/api/services/quota_service.py @@ -1,9 +1,12 @@ from __future__ import annotations import logging +import uuid from dataclasses import dataclass, field from typing import TYPE_CHECKING +from configs import dify_config + if TYPE_CHECKING: from enums.quota_type import QuotaType @@ -16,7 +19,7 @@ class QuotaCharge: Result of a quota reservation (Reserve phase). Lifecycle: - charge = QuotaType.TRIGGER.consume(tenant_id) # Reserve + charge = QuotaService.consume(QuotaType.TRIGGER, tenant_id) try: do_work() charge.commit() # Confirm consumption @@ -81,10 +84,150 @@ class QuotaCharge: if not self.charge_id or not self._tenant_id or not self._feature_key: return - self._quota_type.release(self.charge_id, self._tenant_id, self._feature_key) + QuotaService.release(self._quota_type, self.charge_id, self._tenant_id, self._feature_key) def unlimited() -> QuotaCharge: from enums.quota_type import QuotaType return QuotaCharge(success=True, charge_id=None, _quota_type=QuotaType.UNLIMITED) + + +class QuotaService: + """Orchestrates quota reserve / commit / release lifecycle via BillingService.""" + + @staticmethod + def consume(quota_type: QuotaType, tenant_id: str, amount: int = 1) -> QuotaCharge: + """ + Reserve + immediate Commit (one-shot mode). + + The returned QuotaCharge supports .refund() which calls Release. + For two-phase usage (e.g. streaming), use reserve() directly. + """ + charge = QuotaService.reserve(quota_type, tenant_id, amount) + if charge.success and charge.charge_id: + charge.commit() + return charge + + @staticmethod + def reserve(quota_type: QuotaType, tenant_id: str, amount: int = 1) -> QuotaCharge: + """ + Reserve quota before task execution (Reserve phase only). + + The caller MUST call charge.commit() after the task succeeds, + or charge.refund() if the task fails. + + Raises: + QuotaExceededError: When quota is insufficient + """ + from services.billing_service import BillingService + from services.errors.app import QuotaExceededError + + if not dify_config.BILLING_ENABLED: + logger.debug("Billing disabled, allowing request for %s", tenant_id) + return QuotaCharge(success=True, charge_id=None, _quota_type=quota_type) + + logger.info("Reserving %d %s quota for tenant %s", amount, quota_type.value, tenant_id) + + if amount <= 0: + raise ValueError("Amount to reserve must be greater than 0") + + request_id = str(uuid.uuid4()) + feature_key = quota_type.billing_key + + try: + reserve_resp = BillingService.quota_reserve( + tenant_id=tenant_id, + feature_key=feature_key, + request_id=request_id, + amount=amount, + ) + + reservation_id = reserve_resp.get("reservation_id") + if not reservation_id: + logger.warning( + "Reserve returned no reservation_id for %s, feature %s, response: %s", + tenant_id, + quota_type.value, + reserve_resp, + ) + raise QuotaExceededError(feature=quota_type.value, tenant_id=tenant_id, required=amount) + + logger.debug( + "Reserved %d %s quota for tenant %s, reservation_id: %s", + amount, + quota_type.value, + tenant_id, + reservation_id, + ) + return QuotaCharge( + success=True, + charge_id=reservation_id, + _quota_type=quota_type, + _tenant_id=tenant_id, + _feature_key=feature_key, + _amount=amount, + ) + + except QuotaExceededError: + raise + except ValueError: + raise + except Exception: + logger.exception("Failed to reserve quota for %s, feature %s", tenant_id, quota_type.value) + return unlimited() + + @staticmethod + def check(quota_type: QuotaType, tenant_id: str, amount: int = 1) -> bool: + if not dify_config.BILLING_ENABLED: + return True + + if amount <= 0: + raise ValueError("Amount to check must be greater than 0") + + try: + remaining = QuotaService.get_remaining(quota_type, tenant_id) + return remaining >= amount if remaining != -1 else True + except Exception: + logger.exception("Failed to check quota for %s, feature %s", tenant_id, quota_type.value) + return True + + @staticmethod + def release(quota_type: QuotaType, reservation_id: str, tenant_id: str, feature_key: str) -> None: + """Release a reservation. Guarantees no exceptions.""" + try: + from services.billing_service import BillingService + + if not dify_config.BILLING_ENABLED: + return + + if not reservation_id: + return + + logger.info("Releasing %s quota, reservation_id: %s", quota_type.value, reservation_id) + BillingService.quota_release( + tenant_id=tenant_id, + feature_key=feature_key, + reservation_id=reservation_id, + ) + except Exception: + logger.exception("Failed to release quota, reservation_id: %s", reservation_id) + + @staticmethod + def get_remaining(quota_type: QuotaType, tenant_id: str) -> int: + from services.billing_service import BillingService + + try: + usage_info = BillingService.get_quota_info(tenant_id) + if isinstance(usage_info, dict): + feature_info = usage_info.get(quota_type.billing_key, {}) + if isinstance(feature_info, dict): + limit = feature_info.get("limit", 0) + usage = feature_info.get("usage", 0) + if limit == -1: + return -1 + return max(0, limit - usage) + return 0 + except Exception: + logger.exception("Failed to get remaining quota for %s, feature %s", tenant_id, quota_type.value) + return -1 diff --git a/api/services/trigger/webhook_service.py b/api/services/trigger/webhook_service.py index 2254ce8e7d..b71ae6fe41 100644 --- a/api/services/trigger/webhook_service.py +++ b/api/services/trigger/webhook_service.py @@ -28,6 +28,7 @@ from core.workflow.nodes.trigger_webhook.entities import ( WebhookParameter, ) from enums.quota_type import QuotaType +from services.quota_service import QuotaService from extensions.ext_database import db from extensions.ext_redis import redis_client from factories import file_factory @@ -784,7 +785,7 @@ class WebhookService: # reserve quota before triggering workflow execution try: - quota_charge = QuotaType.TRIGGER.reserve(webhook_trigger.tenant_id) + quota_charge = QuotaService.reserve(QuotaType.TRIGGER, webhook_trigger.tenant_id) except QuotaExceededError: AppTriggerService.mark_tenant_triggers_rate_limited(webhook_trigger.tenant_id) logger.info( diff --git a/api/tasks/trigger_processing_tasks.py b/api/tasks/trigger_processing_tasks.py index 1b1e809baf..b9f382eccf 100644 --- a/api/tasks/trigger_processing_tasks.py +++ b/api/tasks/trigger_processing_tasks.py @@ -42,7 +42,7 @@ from models.workflow import Workflow, WorkflowAppLog, WorkflowAppLogCreatedFrom, from services.async_workflow_service import AsyncWorkflowService from services.end_user_service import EndUserService from services.errors.app import QuotaExceededError -from services.quota_service import unlimited +from services.quota_service import QuotaService, unlimited from services.trigger.app_trigger_service import AppTriggerService from services.trigger.trigger_provider_service import TriggerProviderService from services.trigger.trigger_request_service import TriggerHttpRequestCachingService @@ -302,7 +302,7 @@ def dispatch_triggered_workflow( # reserve quota before invoking trigger quota_charge = unlimited() try: - quota_charge = QuotaType.TRIGGER.reserve(subscription.tenant_id) + quota_charge = QuotaService.reserve(QuotaType.TRIGGER, subscription.tenant_id) except QuotaExceededError: AppTriggerService.mark_tenant_triggers_rate_limited(subscription.tenant_id) logger.info( diff --git a/api/tasks/workflow_schedule_tasks.py b/api/tasks/workflow_schedule_tasks.py index 774e1103df..dfb2fb3391 100644 --- a/api/tasks/workflow_schedule_tasks.py +++ b/api/tasks/workflow_schedule_tasks.py @@ -12,7 +12,7 @@ from enums.quota_type import QuotaType from models.trigger import WorkflowSchedulePlan from services.async_workflow_service import AsyncWorkflowService from services.errors.app import QuotaExceededError -from services.quota_service import unlimited +from services.quota_service import QuotaService, unlimited from services.trigger.app_trigger_service import AppTriggerService from services.trigger.schedule_service import ScheduleService from services.workflow.entities import ScheduleTriggerData @@ -44,7 +44,7 @@ def run_schedule_trigger(schedule_id: str) -> None: quota_charge = unlimited() try: - quota_charge = QuotaType.TRIGGER.reserve(schedule.tenant_id) + quota_charge = QuotaService.reserve(QuotaType.TRIGGER, schedule.tenant_id) except QuotaExceededError: AppTriggerService.mark_tenant_triggers_rate_limited(schedule.tenant_id) logger.info("Tenant %s rate limited, skipping schedule trigger %s", schedule.tenant_id, schedule_id) diff --git a/api/tests/test_containers_integration_tests/trigger/test_trigger_e2e.py b/api/tests/test_containers_integration_tests/trigger/test_trigger_e2e.py index 7539bae685..3514447240 100644 --- a/api/tests/test_containers_integration_tests/trigger/test_trigger_e2e.py +++ b/api/tests/test_containers_integration_tests/trigger/test_trigger_e2e.py @@ -602,9 +602,9 @@ def test_schedule_trigger_creates_trigger_log( ) # Mock quota to avoid rate limiting - from enums import quota_type + from services import quota_service - monkeypatch.setattr(quota_type.QuotaType.TRIGGER, "consume", lambda _tenant_id: quota_type.unlimited()) + monkeypatch.setattr(quota_service.QuotaService, "reserve", lambda *_args, **_kwargs: quota_service.unlimited()) # Execute schedule trigger workflow_schedule_tasks.run_schedule_trigger(plan.id) diff --git a/api/tests/unit_tests/enums/test_quota_type.py b/api/tests/unit_tests/enums/test_quota_type.py index 85b5b8bb3e..f256ff3b4e 100644 --- a/api/tests/unit_tests/enums/test_quota_type.py +++ b/api/tests/unit_tests/enums/test_quota_type.py @@ -1,11 +1,11 @@ -"""Unit tests for QuotaType and QuotaCharge.""" +"""Unit tests for QuotaType, QuotaService, and QuotaCharge.""" from unittest.mock import patch import pytest from enums.quota_type import QuotaType -from services.quota_service import QuotaCharge, unlimited +from services.quota_service import QuotaCharge, QuotaService, unlimited class TestQuotaType: @@ -19,31 +19,33 @@ class TestQuotaType: with pytest.raises(ValueError, match="Invalid quota type"): _ = QuotaType.UNLIMITED.billing_key + +class TestQuotaService: def test_reserve_billing_disabled(self): with ( - patch("configs.dify_config") as mock_cfg, + patch("services.quota_service.dify_config") as mock_cfg, patch("services.billing_service.BillingService"), ): mock_cfg.BILLING_ENABLED = False - charge = QuotaType.TRIGGER.reserve("t1") + charge = QuotaService.reserve(QuotaType.TRIGGER, "t1") assert charge.success is True assert charge.charge_id is None def test_reserve_zero_amount_raises(self): - with patch("configs.dify_config") as mock_cfg: + with patch("services.quota_service.dify_config") as mock_cfg: mock_cfg.BILLING_ENABLED = True with pytest.raises(ValueError, match="greater than 0"): - QuotaType.TRIGGER.reserve("t1", amount=0) + QuotaService.reserve(QuotaType.TRIGGER, "t1", amount=0) def test_reserve_success(self): with ( - patch("configs.dify_config") as mock_cfg, + patch("services.quota_service.dify_config") as mock_cfg, patch("services.billing_service.BillingService") as mock_bs, ): mock_cfg.BILLING_ENABLED = True mock_bs.quota_reserve.return_value = {"reservation_id": "rid-1", "available": 99} - charge = QuotaType.TRIGGER.reserve("t1", amount=1) + charge = QuotaService.reserve(QuotaType.TRIGGER, "t1", amount=1) assert charge.success is True assert charge.charge_id == "rid-1" @@ -56,175 +58,175 @@ class TestQuotaType: from services.errors.app import QuotaExceededError with ( - patch("configs.dify_config") as mock_cfg, + patch("services.quota_service.dify_config") as mock_cfg, patch("services.billing_service.BillingService") as mock_bs, ): mock_cfg.BILLING_ENABLED = True mock_bs.quota_reserve.return_value = {} with pytest.raises(QuotaExceededError): - QuotaType.TRIGGER.reserve("t1") + QuotaService.reserve(QuotaType.TRIGGER, "t1") def test_reserve_quota_exceeded_propagates(self): from services.errors.app import QuotaExceededError with ( - patch("configs.dify_config") as mock_cfg, + patch("services.quota_service.dify_config") as mock_cfg, patch("services.billing_service.BillingService") as mock_bs, ): mock_cfg.BILLING_ENABLED = True mock_bs.quota_reserve.side_effect = QuotaExceededError(feature="trigger", tenant_id="t1", required=1) with pytest.raises(QuotaExceededError): - QuotaType.TRIGGER.reserve("t1") + QuotaService.reserve(QuotaType.TRIGGER, "t1") def test_reserve_api_exception_returns_unlimited(self): with ( - patch("configs.dify_config") as mock_cfg, + patch("services.quota_service.dify_config") as mock_cfg, patch("services.billing_service.BillingService") as mock_bs, ): mock_cfg.BILLING_ENABLED = True mock_bs.quota_reserve.side_effect = RuntimeError("network") - charge = QuotaType.TRIGGER.reserve("t1") + charge = QuotaService.reserve(QuotaType.TRIGGER, "t1") assert charge.success is True assert charge.charge_id is None def test_consume_calls_reserve_and_commit(self): with ( - patch("configs.dify_config") as mock_cfg, + patch("services.quota_service.dify_config") as mock_cfg, patch("services.billing_service.BillingService") as mock_bs, ): mock_cfg.BILLING_ENABLED = True mock_bs.quota_reserve.return_value = {"reservation_id": "rid-c"} mock_bs.quota_commit.return_value = {} - charge = QuotaType.TRIGGER.consume("t1") + charge = QuotaService.consume(QuotaType.TRIGGER, "t1") assert charge.success is True mock_bs.quota_commit.assert_called_once() def test_check_billing_disabled(self): - with patch("configs.dify_config") as mock_cfg: + with patch("services.quota_service.dify_config") as mock_cfg: mock_cfg.BILLING_ENABLED = False - assert QuotaType.TRIGGER.check("t1") is True + assert QuotaService.check(QuotaType.TRIGGER, "t1") is True def test_check_zero_amount_raises(self): - with patch("configs.dify_config") as mock_cfg: + with patch("services.quota_service.dify_config") as mock_cfg: mock_cfg.BILLING_ENABLED = True with pytest.raises(ValueError, match="greater than 0"): - QuotaType.TRIGGER.check("t1", amount=0) + QuotaService.check(QuotaType.TRIGGER, "t1", amount=0) def test_check_sufficient_quota(self): with ( - patch("configs.dify_config") as mock_cfg, - patch.object(QuotaType.TRIGGER, "get_remaining", return_value=100), + patch("services.quota_service.dify_config") as mock_cfg, + patch.object(QuotaService, "get_remaining", return_value=100), ): mock_cfg.BILLING_ENABLED = True - assert QuotaType.TRIGGER.check("t1", amount=50) is True + assert QuotaService.check(QuotaType.TRIGGER, "t1", amount=50) is True def test_check_insufficient_quota(self): with ( - patch("configs.dify_config") as mock_cfg, - patch.object(QuotaType.TRIGGER, "get_remaining", return_value=5), + patch("services.quota_service.dify_config") as mock_cfg, + patch.object(QuotaService, "get_remaining", return_value=5), ): mock_cfg.BILLING_ENABLED = True - assert QuotaType.TRIGGER.check("t1", amount=10) is False + assert QuotaService.check(QuotaType.TRIGGER, "t1", amount=10) is False def test_check_unlimited_quota(self): with ( - patch("configs.dify_config") as mock_cfg, - patch.object(QuotaType.TRIGGER, "get_remaining", return_value=-1), + patch("services.quota_service.dify_config") as mock_cfg, + patch.object(QuotaService, "get_remaining", return_value=-1), ): mock_cfg.BILLING_ENABLED = True - assert QuotaType.TRIGGER.check("t1", amount=999) is True + assert QuotaService.check(QuotaType.TRIGGER, "t1", amount=999) is True def test_check_exception_returns_true(self): with ( - patch("configs.dify_config") as mock_cfg, - patch.object(QuotaType.TRIGGER, "get_remaining", side_effect=RuntimeError), + patch("services.quota_service.dify_config") as mock_cfg, + patch.object(QuotaService, "get_remaining", side_effect=RuntimeError), ): mock_cfg.BILLING_ENABLED = True - assert QuotaType.TRIGGER.check("t1") is True + assert QuotaService.check(QuotaType.TRIGGER, "t1") is True def test_release_billing_disabled(self): with ( - patch("configs.dify_config") as mock_cfg, + patch("services.quota_service.dify_config") as mock_cfg, patch("services.billing_service.BillingService") as mock_bs, ): mock_cfg.BILLING_ENABLED = False - QuotaType.TRIGGER.release("rid-1", "t1", "trigger_event") + QuotaService.release(QuotaType.TRIGGER, "rid-1", "t1", "trigger_event") mock_bs.quota_release.assert_not_called() def test_release_empty_reservation(self): with ( - patch("configs.dify_config") as mock_cfg, + patch("services.quota_service.dify_config") as mock_cfg, patch("services.billing_service.BillingService") as mock_bs, ): mock_cfg.BILLING_ENABLED = True - QuotaType.TRIGGER.release("", "t1", "trigger_event") + QuotaService.release(QuotaType.TRIGGER, "", "t1", "trigger_event") mock_bs.quota_release.assert_not_called() def test_release_success(self): with ( - patch("configs.dify_config") as mock_cfg, + patch("services.quota_service.dify_config") as mock_cfg, patch("services.billing_service.BillingService") as mock_bs, ): mock_cfg.BILLING_ENABLED = True mock_bs.quota_release.return_value = {} - QuotaType.TRIGGER.release("rid-1", "t1", "trigger_event") + QuotaService.release(QuotaType.TRIGGER, "rid-1", "t1", "trigger_event") mock_bs.quota_release.assert_called_once_with( tenant_id="t1", feature_key="trigger_event", reservation_id="rid-1" ) def test_release_exception_swallowed(self): with ( - patch("configs.dify_config") as mock_cfg, + patch("services.quota_service.dify_config") as mock_cfg, patch("services.billing_service.BillingService") as mock_bs, ): mock_cfg.BILLING_ENABLED = True mock_bs.quota_release.side_effect = RuntimeError("fail") - QuotaType.TRIGGER.release("rid-1", "t1", "trigger_event") + QuotaService.release(QuotaType.TRIGGER, "rid-1", "t1", "trigger_event") def test_get_remaining_normal(self): with patch("services.billing_service.BillingService") as mock_bs: mock_bs.get_quota_info.return_value = {"trigger_event": {"limit": 100, "usage": 30}} - assert QuotaType.TRIGGER.get_remaining("t1") == 70 + assert QuotaService.get_remaining(QuotaType.TRIGGER, "t1") == 70 def test_get_remaining_unlimited(self): with patch("services.billing_service.BillingService") as mock_bs: mock_bs.get_quota_info.return_value = {"trigger_event": {"limit": -1, "usage": 0}} - assert QuotaType.TRIGGER.get_remaining("t1") == -1 + assert QuotaService.get_remaining(QuotaType.TRIGGER, "t1") == -1 def test_get_remaining_over_limit_returns_zero(self): with patch("services.billing_service.BillingService") as mock_bs: mock_bs.get_quota_info.return_value = {"trigger_event": {"limit": 10, "usage": 15}} - assert QuotaType.TRIGGER.get_remaining("t1") == 0 + assert QuotaService.get_remaining(QuotaType.TRIGGER, "t1") == 0 def test_get_remaining_exception_returns_neg1(self): with patch("services.billing_service.BillingService") as mock_bs: mock_bs.get_quota_info.side_effect = RuntimeError - assert QuotaType.TRIGGER.get_remaining("t1") == -1 + assert QuotaService.get_remaining(QuotaType.TRIGGER, "t1") == -1 def test_get_remaining_empty_response(self): with patch("services.billing_service.BillingService") as mock_bs: mock_bs.get_quota_info.return_value = {} - assert QuotaType.TRIGGER.get_remaining("t1") == 0 + assert QuotaService.get_remaining(QuotaType.TRIGGER, "t1") == 0 def test_get_remaining_non_dict_response(self): with patch("services.billing_service.BillingService") as mock_bs: mock_bs.get_quota_info.return_value = "invalid" - assert QuotaType.TRIGGER.get_remaining("t1") == 0 + assert QuotaService.get_remaining(QuotaType.TRIGGER, "t1") == 0 def test_get_remaining_feature_not_in_response(self): with patch("services.billing_service.BillingService") as mock_bs: mock_bs.get_quota_info.return_value = {"other_feature": {"limit": 100, "usage": 0}} - remaining = QuotaType.TRIGGER.get_remaining("t1") + remaining = QuotaService.get_remaining(QuotaType.TRIGGER, "t1") assert remaining == 0 def test_get_remaining_non_dict_feature_info(self): with patch("services.billing_service.BillingService") as mock_bs: mock_bs.get_quota_info.return_value = {"trigger_event": "not_a_dict"} - assert QuotaType.TRIGGER.get_remaining("t1") == 0 + assert QuotaService.get_remaining(QuotaType.TRIGGER, "t1") == 0 class TestQuotaCharge: @@ -310,7 +312,7 @@ class TestQuotaCharge: charge.commit() def test_refund_success(self): - with patch.object(QuotaType.TRIGGER, "release") as mock_rel: + with patch.object(QuotaService, "release") as mock_rel: charge = QuotaCharge( success=True, charge_id="rid-1", @@ -319,16 +321,16 @@ class TestQuotaCharge: _feature_key="trigger_event", ) charge.refund() - mock_rel.assert_called_once_with("rid-1", "t1", "trigger_event") + mock_rel.assert_called_once_with(QuotaType.TRIGGER, "rid-1", "t1", "trigger_event") def test_refund_no_charge_id_noop(self): - with patch.object(QuotaType.TRIGGER, "release") as mock_rel: + with patch.object(QuotaService, "release") as mock_rel: charge = QuotaCharge(success=True, charge_id=None, _quota_type=QuotaType.TRIGGER) charge.refund() mock_rel.assert_not_called() def test_refund_no_tenant_id_noop(self): - with patch.object(QuotaType.TRIGGER, "release") as mock_rel: + with patch.object(QuotaService, "release") as mock_rel: charge = QuotaCharge( success=True, charge_id="rid-1", diff --git a/api/tests/unit_tests/services/test_app_generate_service.py b/api/tests/unit_tests/services/test_app_generate_service.py index 68ee6ae9d6..6750dca570 100644 --- a/api/tests/unit_tests/services/test_app_generate_service.py +++ b/api/tests/unit_tests/services/test_app_generate_service.py @@ -448,7 +448,7 @@ class TestGenerateBilling: monkeypatch.setattr(ags_module.dify_config, "BILLING_ENABLED", True) quota_charge = MagicMock() reserve_mock = mocker.patch( - "services.app_generate_service.QuotaType.WORKFLOW.reserve", + "services.app_generate_service.QuotaService.reserve", return_value=quota_charge, ) mocker.patch( @@ -476,7 +476,7 @@ class TestGenerateBilling: monkeypatch.setattr(ags_module.dify_config, "BILLING_ENABLED", True) mocker.patch( - "services.app_generate_service.QuotaType.WORKFLOW.reserve", + "services.app_generate_service.QuotaService.reserve", side_effect=QuotaExceededError(feature="workflow", tenant_id="t", required=1), ) @@ -493,7 +493,7 @@ class TestGenerateBilling: monkeypatch.setattr(ags_module.dify_config, "BILLING_ENABLED", True) quota_charge = MagicMock() mocker.patch( - "services.app_generate_service.QuotaType.WORKFLOW.reserve", + "services.app_generate_service.QuotaService.reserve", return_value=quota_charge, ) mocker.patch( diff --git a/api/tests/unit_tests/services/test_async_workflow_service.py b/api/tests/unit_tests/services/test_async_workflow_service.py index ab83c8020f..07f8324d13 100644 --- a/api/tests/unit_tests/services/test_async_workflow_service.py +++ b/api/tests/unit_tests/services/test_async_workflow_service.py @@ -57,7 +57,7 @@ class TestAsyncWorkflowService: - repo: SQLAlchemyWorkflowTriggerLogRepository - dispatcher_manager_class: QueueDispatcherManager class - dispatcher: dispatcher instance - - quota_workflow: QuotaType.WORKFLOW + - quota_service: QuotaService mock - get_workflow: AsyncWorkflowService._get_workflow method - professional_task: execute_workflow_professional - team_task: execute_workflow_team @@ -72,7 +72,7 @@ class TestAsyncWorkflowService: mock_repo.create.side_effect = _create_side_effect mock_dispatcher = MagicMock() - quota_workflow = MagicMock() + mock_quota_service = MagicMock() mock_get_workflow = MagicMock() mock_professional_task = MagicMock() @@ -93,8 +93,8 @@ class TestAsyncWorkflowService: ) as mock_get_workflow, patch.object( async_workflow_service_module, - "QuotaType", - new=SimpleNamespace(WORKFLOW=quota_workflow), + "QuotaService", + new=mock_quota_service, ), patch.object(async_workflow_service_module, "execute_workflow_professional") as mock_professional_task, patch.object(async_workflow_service_module, "execute_workflow_team") as mock_team_task, @@ -107,7 +107,7 @@ class TestAsyncWorkflowService: "repo": mock_repo, "dispatcher_manager_class": mock_dispatcher_manager_class, "dispatcher": mock_dispatcher, - "quota_workflow": quota_workflow, + "quota_service": mock_quota_service, "get_workflow": mock_get_workflow, "professional_task": mock_professional_task, "team_task": mock_team_task, @@ -147,7 +147,7 @@ class TestAsyncWorkflowService: mocks["sandbox_task"].delay.return_value = task_result quota_charge_mock = MagicMock() - mocks["quota_workflow"].reserve.return_value = quota_charge_mock + mocks["quota_service"].reserve.return_value = quota_charge_mock class DummyAccount: def __init__(self, user_id: str): @@ -166,7 +166,7 @@ class TestAsyncWorkflowService: assert result.status == "queued" assert result.queue == queue_name - mocks["quota_workflow"].reserve.assert_called_once_with("tenant-123") + mocks["quota_service"].reserve.assert_called_once() quota_charge_mock.commit.assert_called_once() assert session.commit.call_count == 2 @@ -254,7 +254,7 @@ class TestAsyncWorkflowService: mocks = async_workflow_trigger_mocks mocks["dispatcher"].get_queue_name.return_value = QueuePriority.TEAM mocks["get_workflow"].return_value = workflow - mocks["quota_workflow"].reserve.side_effect = QuotaExceededError( + mocks["quota_service"].reserve.side_effect = QuotaExceededError( feature="workflow", tenant_id="tenant-123", required=1,