diff --git a/api/enums/quota_type.py b/api/enums/quota_type.py index a2a7f689a5..95146d17d9 100644 --- a/api/enums/quota_type.py +++ b/api/enums/quota_type.py @@ -1,85 +1,10 @@ import logging import uuid -from dataclasses import dataclass, field from enum import StrEnum, auto logger = logging.getLogger(__name__) -@dataclass -class QuotaCharge: - """ - Result of a quota reservation (Reserve phase). - - Lifecycle: - charge = QuotaType.TRIGGER.consume(tenant_id) # Reserve - try: - do_work() - charge.commit() # Confirm consumption - except: - charge.refund() # Release frozen quota - - If neither commit() nor refund() is called, the billing system's - cleanup CronJob will auto-release the reservation within ~75 seconds. - """ - - success: bool - charge_id: str | None # reservation_id - _quota_type: "QuotaType" - _tenant_id: str | None = None - _feature_key: str | None = None - _amount: int = 0 - _committed: bool = field(default=False, repr=False) - - def commit(self, actual_amount: int | None = None) -> None: - """ - Confirm the consumption with actual amount. - - Args: - actual_amount: Actual amount consumed. Defaults to the reserved amount. - If less than reserved, the difference is refunded automatically. - """ - if self._committed or not self.charge_id or not self._tenant_id or not self._feature_key: - return - - try: - from services.billing_service import BillingService - - amount = actual_amount if actual_amount is not None else self._amount - BillingService.quota_commit( - tenant_id=self._tenant_id, - feature_key=self._feature_key, - reservation_id=self.charge_id, - actual_amount=amount, - ) - self._committed = True - logger.debug( - "Committed %s quota for tenant %s, reservation_id: %s, amount: %d", - self._quota_type.value, - self._tenant_id, - self.charge_id, - amount, - ) - except Exception: - logger.exception("Failed to commit quota, reservation_id: %s", self.charge_id) - - def refund(self) -> None: - """ - Release the reserved quota (cancel the charge). - - Safe to call even if: - - charge failed or was disabled (charge_id is None) - - already committed (Release after Commit is a no-op) - - already refunded (idempotent) - - This method guarantees no exceptions will be raised. - """ - if not self.charge_id or not self._tenant_id or not self._feature_key: - return - - self._quota_type.release(self.charge_id, self._tenant_id, self._feature_key) - - class QuotaType(StrEnum): """ Supported quota types for tenant feature usage. @@ -99,7 +24,7 @@ class QuotaType(StrEnum): case _: raise ValueError(f"Invalid quota type: {self}") - def consume(self, tenant_id: str, amount: int = 1) -> QuotaCharge: + def consume(self, tenant_id: str, amount: int = 1): """ Consume quota using Reserve + immediate Commit. @@ -125,7 +50,7 @@ class QuotaType(StrEnum): charge.commit() return charge - def reserve(self, tenant_id: str, amount: int = 1) -> QuotaCharge: + def reserve(self, tenant_id: str, amount: int = 1): """ Reserve quota before task execution (Reserve phase only). @@ -147,6 +72,7 @@ class QuotaType(StrEnum): from configs import dify_config from services.billing_service import BillingService from services.errors.app import QuotaExceededError + from services.quota_service import QuotaCharge, unlimited if not dify_config.BILLING_ENABLED: logger.debug("Billing disabled, allowing request for %s", tenant_id) @@ -245,7 +171,7 @@ class QuotaType(StrEnum): from services.billing_service import BillingService try: - usage_info = BillingService.get_tenant_feature_plan_usage_info(tenant_id) + usage_info = BillingService.get_quota_info(tenant_id) if isinstance(usage_info, dict): feature_info = usage_info.get(self.billing_key, {}) if isinstance(feature_info, dict): @@ -258,7 +184,3 @@ class QuotaType(StrEnum): except Exception: logger.exception("Failed to get remaining quota for %s, feature %s", tenant_id, self.value) return -1 - - -def unlimited() -> QuotaCharge: - return QuotaCharge(success=True, charge_id=None, _quota_type=QuotaType.UNLIMITED) diff --git a/api/services/app_generate_service.py b/api/services/app_generate_service.py index ce68a1dcba..246252fda0 100644 --- a/api/services/app_generate_service.py +++ b/api/services/app_generate_service.py @@ -18,7 +18,8 @@ from core.app.features.rate_limiting import RateLimit from core.app.features.rate_limiting.rate_limit import rate_limit_context from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig from core.db import session_factory -from enums.quota_type import QuotaType, unlimited +from enums.quota_type import QuotaType +from services.quota_service import unlimited from extensions.otel import AppGenerateHandler, trace_span from models.model import Account, App, AppMode, EndUser from models.workflow import Workflow, WorkflowRun diff --git a/api/services/async_workflow_service.py b/api/services/async_workflow_service.py index 5f7201de11..d9a468795c 100644 --- a/api/services/async_workflow_service.py +++ b/api/services/async_workflow_service.py @@ -13,7 +13,8 @@ from celery.result import AsyncResult from sqlalchemy import select from sqlalchemy.orm import Session -from enums.quota_type import QuotaType, unlimited +from enums.quota_type import QuotaType +from services.quota_service import unlimited from extensions.ext_database import db from models.account import Account from models.enums import CreatorUserRole, WorkflowTriggerStatus diff --git a/api/services/billing_service.py b/api/services/billing_service.py index 755407d849..717b818d52 100644 --- a/api/services/billing_service.py +++ b/api/services/billing_service.py @@ -26,6 +26,29 @@ class SubscriptionPlan(TypedDict): expiration_date: int +class QuotaReserveResult(TypedDict): + reservation_id: str + available: int + reserved: int + + +class QuotaCommitResult(TypedDict): + available: int + reserved: int + refunded: int + + +class QuotaReleaseResult(TypedDict): + available: int + reserved: int + released: int + + +_quota_reserve_adapter = TypeAdapter(QuotaReserveResult) +_quota_commit_adapter = TypeAdapter(QuotaCommitResult) +_quota_release_adapter = TypeAdapter(QuotaReleaseResult) + + class BillingService: base_url = os.environ.get("BILLING_API_URL", "BILLING_API_URL") secret_key = os.environ.get("BILLING_API_SECRET_KEY", "BILLING_API_SECRET_KEY") @@ -46,19 +69,21 @@ class BillingService: @classmethod def get_tenant_feature_plan_usage_info(cls, tenant_id: str): + """Deprecated: Use get_quota_info instead.""" params = {"tenant_id": tenant_id} - usage_info = cls._send_request("GET", "/quota/info", params=params) + usage_info = cls._send_request("GET", "/tenant-feature-usage/info", params=params) return usage_info + @classmethod + def get_quota_info(cls, tenant_id: str): + params = {"tenant_id": tenant_id} + return cls._send_request("GET", "/quota/info", params=params) + @classmethod def quota_reserve( cls, tenant_id: str, feature_key: str, request_id: str, amount: int = 1, meta: dict | None = None - ) -> dict: - """Reserve quota before task execution. - - Returns: - {"reservation_id": "uuid", "available": int, "reserved": int} - """ + ) -> QuotaReserveResult: + """Reserve quota before task execution.""" payload: dict = { "tenant_id": tenant_id, "feature_key": feature_key, @@ -67,17 +92,13 @@ class BillingService: } if meta: payload["meta"] = meta - return cls._send_request("POST", "/quota/reserve", json=payload) + return _quota_reserve_adapter.validate_python(cls._send_request("POST", "/quota/reserve", json=payload)) @classmethod def quota_commit( cls, tenant_id: str, feature_key: str, reservation_id: str, actual_amount: int, meta: dict | None = None - ) -> dict: - """Commit a reservation with actual consumption. - - Returns: - {"available": int, "reserved": int, "refunded": int} - """ + ) -> QuotaCommitResult: + """Commit a reservation with actual consumption.""" payload: dict = { "tenant_id": tenant_id, "feature_key": feature_key, @@ -86,23 +107,21 @@ class BillingService: } if meta: payload["meta"] = meta - return cls._send_request("POST", "/quota/commit", json=payload) + return _quota_commit_adapter.validate_python(cls._send_request("POST", "/quota/commit", json=payload)) @classmethod - def quota_release(cls, tenant_id: str, feature_key: str, reservation_id: str) -> dict: - """Release a reservation (cancel, return frozen quota). - - Returns: - {"available": int, "reserved": int, "released": int} - """ - return cls._send_request( - "POST", - "/quota/release", - json={ - "tenant_id": tenant_id, - "feature_key": feature_key, - "reservation_id": reservation_id, - }, + def quota_release(cls, tenant_id: str, feature_key: str, reservation_id: str) -> QuotaReleaseResult: + """Release a reservation (cancel, return frozen quota).""" + return _quota_release_adapter.validate_python( + cls._send_request( + "POST", + "/quota/release", + json={ + "tenant_id": tenant_id, + "feature_key": feature_key, + "reservation_id": reservation_id, + }, + ) ) @classmethod diff --git a/api/services/quota_service.py b/api/services/quota_service.py new file mode 100644 index 0000000000..77cfdda3a3 --- /dev/null +++ b/api/services/quota_service.py @@ -0,0 +1,90 @@ +from __future__ import annotations + +import logging +from dataclasses import dataclass, field +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from enums.quota_type import QuotaType + +logger = logging.getLogger(__name__) + + +@dataclass +class QuotaCharge: + """ + Result of a quota reservation (Reserve phase). + + Lifecycle: + charge = QuotaType.TRIGGER.consume(tenant_id) # Reserve + try: + do_work() + charge.commit() # Confirm consumption + except: + charge.refund() # Release frozen quota + + If neither commit() nor refund() is called, the billing system's + cleanup CronJob will auto-release the reservation within ~75 seconds. + """ + + success: bool + charge_id: str | None # reservation_id + _quota_type: QuotaType + _tenant_id: str | None = None + _feature_key: str | None = None + _amount: int = 0 + _committed: bool = field(default=False, repr=False) + + def commit(self, actual_amount: int | None = None) -> None: + """ + Confirm the consumption with actual amount. + + Args: + actual_amount: Actual amount consumed. Defaults to the reserved amount. + If less than reserved, the difference is refunded automatically. + """ + if self._committed or not self.charge_id or not self._tenant_id or not self._feature_key: + return + + try: + from services.billing_service import BillingService + + amount = actual_amount if actual_amount is not None else self._amount + BillingService.quota_commit( + tenant_id=self._tenant_id, + feature_key=self._feature_key, + reservation_id=self.charge_id, + actual_amount=amount, + ) + self._committed = True + logger.debug( + "Committed %s quota for tenant %s, reservation_id: %s, amount: %d", + self._quota_type, + self._tenant_id, + self.charge_id, + amount, + ) + except Exception: + logger.exception("Failed to commit quota, reservation_id: %s", self.charge_id) + + def refund(self) -> None: + """ + Release the reserved quota (cancel the charge). + + Safe to call even if: + - charge failed or was disabled (charge_id is None) + - already committed (Release after Commit is a no-op) + - already refunded (idempotent) + + This method guarantees no exceptions will be raised. + """ + if not self.charge_id or not self._tenant_id or not self._feature_key: + return + + self._quota_type.release(self.charge_id, self._tenant_id, self._feature_key) + + +def unlimited() -> QuotaCharge: + from enums.quota_type import QuotaType + + return QuotaCharge(success=True, charge_id=None, _quota_type=QuotaType.UNLIMITED) diff --git a/api/tasks/trigger_processing_tasks.py b/api/tasks/trigger_processing_tasks.py index aca2b93fc9..a3f72da6f5 100644 --- a/api/tasks/trigger_processing_tasks.py +++ b/api/tasks/trigger_processing_tasks.py @@ -28,7 +28,8 @@ from core.trigger.entities.entities import TriggerProviderEntity from core.trigger.provider import PluginTriggerProviderController from core.trigger.trigger_manager import TriggerManager from core.workflow.nodes.trigger_plugin.entities import TriggerEventNodeData -from enums.quota_type import QuotaType, unlimited +from enums.quota_type import QuotaType +from services.quota_service import unlimited from models.enums import ( AppTriggerType, CreatorUserRole, diff --git a/api/tasks/workflow_schedule_tasks.py b/api/tasks/workflow_schedule_tasks.py index 9aa90c3793..2816bd04ab 100644 --- a/api/tasks/workflow_schedule_tasks.py +++ b/api/tasks/workflow_schedule_tasks.py @@ -8,7 +8,8 @@ from core.workflow.nodes.trigger_schedule.exc import ( ScheduleNotFoundError, TenantOwnerNotFoundError, ) -from enums.quota_type import QuotaType, unlimited +from enums.quota_type import QuotaType +from services.quota_service import unlimited from models.trigger import WorkflowSchedulePlan from services.async_workflow_service import AsyncWorkflowService from services.errors.app import QuotaExceededError diff --git a/api/tests/unit_tests/enums/test_quota_type.py b/api/tests/unit_tests/enums/test_quota_type.py index ac34064f8a..85b5b8bb3e 100644 --- a/api/tests/unit_tests/enums/test_quota_type.py +++ b/api/tests/unit_tests/enums/test_quota_type.py @@ -4,7 +4,8 @@ from unittest.mock import patch import pytest -from enums.quota_type import QuotaCharge, QuotaType, unlimited +from enums.quota_type import QuotaType +from services.quota_service import QuotaCharge, unlimited class TestQuotaType: @@ -186,43 +187,43 @@ class TestQuotaType: def test_get_remaining_normal(self): with patch("services.billing_service.BillingService") as mock_bs: - mock_bs.get_tenant_feature_plan_usage_info.return_value = {"trigger_event": {"limit": 100, "usage": 30}} + mock_bs.get_quota_info.return_value = {"trigger_event": {"limit": 100, "usage": 30}} assert QuotaType.TRIGGER.get_remaining("t1") == 70 def test_get_remaining_unlimited(self): with patch("services.billing_service.BillingService") as mock_bs: - mock_bs.get_tenant_feature_plan_usage_info.return_value = {"trigger_event": {"limit": -1, "usage": 0}} + mock_bs.get_quota_info.return_value = {"trigger_event": {"limit": -1, "usage": 0}} assert QuotaType.TRIGGER.get_remaining("t1") == -1 def test_get_remaining_over_limit_returns_zero(self): with patch("services.billing_service.BillingService") as mock_bs: - mock_bs.get_tenant_feature_plan_usage_info.return_value = {"trigger_event": {"limit": 10, "usage": 15}} + mock_bs.get_quota_info.return_value = {"trigger_event": {"limit": 10, "usage": 15}} assert QuotaType.TRIGGER.get_remaining("t1") == 0 def test_get_remaining_exception_returns_neg1(self): with patch("services.billing_service.BillingService") as mock_bs: - mock_bs.get_tenant_feature_plan_usage_info.side_effect = RuntimeError + mock_bs.get_quota_info.side_effect = RuntimeError assert QuotaType.TRIGGER.get_remaining("t1") == -1 def test_get_remaining_empty_response(self): with patch("services.billing_service.BillingService") as mock_bs: - mock_bs.get_tenant_feature_plan_usage_info.return_value = {} + mock_bs.get_quota_info.return_value = {} assert QuotaType.TRIGGER.get_remaining("t1") == 0 def test_get_remaining_non_dict_response(self): with patch("services.billing_service.BillingService") as mock_bs: - mock_bs.get_tenant_feature_plan_usage_info.return_value = "invalid" + mock_bs.get_quota_info.return_value = "invalid" assert QuotaType.TRIGGER.get_remaining("t1") == 0 def test_get_remaining_feature_not_in_response(self): with patch("services.billing_service.BillingService") as mock_bs: - mock_bs.get_tenant_feature_plan_usage_info.return_value = {"other_feature": {"limit": 100, "usage": 0}} + mock_bs.get_quota_info.return_value = {"other_feature": {"limit": 100, "usage": 0}} remaining = QuotaType.TRIGGER.get_remaining("t1") assert remaining == 0 def test_get_remaining_non_dict_feature_info(self): with patch("services.billing_service.BillingService") as mock_bs: - mock_bs.get_tenant_feature_plan_usage_info.return_value = {"trigger_event": "not_a_dict"} + mock_bs.get_quota_info.return_value = {"trigger_event": "not_a_dict"} assert QuotaType.TRIGGER.get_remaining("t1") == 0 diff --git a/api/tests/unit_tests/services/test_billing_service.py b/api/tests/unit_tests/services/test_billing_service.py index 3ae2e006e3..ed78397dc3 100644 --- a/api/tests/unit_tests/services/test_billing_service.py +++ b/api/tests/unit_tests/services/test_billing_service.py @@ -415,7 +415,7 @@ class TestBillingServiceUsageCalculation: yield mock def test_get_tenant_feature_plan_usage_info(self, mock_send_request): - """Test retrieval of tenant feature plan usage information.""" + """Test retrieval of tenant feature plan usage information (legacy endpoint).""" # Arrange tenant_id = "tenant-123" expected_response = {"features": {"trigger": {"used": 50, "limit": 100}, "workflow": {"used": 20, "limit": 50}}} @@ -424,6 +424,20 @@ class TestBillingServiceUsageCalculation: # Act result = BillingService.get_tenant_feature_plan_usage_info(tenant_id) + # Assert + assert result == expected_response + mock_send_request.assert_called_once_with("GET", "/tenant-feature-usage/info", params={"tenant_id": tenant_id}) + + def test_get_quota_info(self, mock_send_request): + """Test retrieval of quota info from new endpoint.""" + # Arrange + tenant_id = "tenant-123" + expected_response = {"trigger_event": {"limit": 100, "usage": 30}, "api_rate_limit": {"limit": -1, "usage": 0}} + mock_send_request.return_value = expected_response + + # Act + result = BillingService.get_quota_info(tenant_id) + # Assert assert result == expected_response mock_send_request.assert_called_once_with("GET", "/quota/info", params={"tenant_id": tenant_id}) @@ -526,8 +540,19 @@ class TestBillingServiceQuotaOperations: json={"tenant_id": "t1", "feature_key": "trigger_event", "request_id": "req-1", "amount": 1}, ) + def test_quota_reserve_coerces_string_to_int(self, mock_send_request): + """Test that TypeAdapter coerces string values to int.""" + mock_send_request.return_value = {"reservation_id": "rid-str", "available": "99", "reserved": "1"} + + result = BillingService.quota_reserve(tenant_id="t1", feature_key="trigger_event", request_id="req-s", amount=1) + + assert result["available"] == 99 + assert isinstance(result["available"], int) + assert result["reserved"] == 1 + assert isinstance(result["reserved"], int) + def test_quota_reserve_with_meta(self, mock_send_request): - mock_send_request.return_value = {"reservation_id": "rid-2"} + mock_send_request.return_value = {"reservation_id": "rid-2", "available": 98, "reserved": 1} meta = {"source": "webhook"} BillingService.quota_reserve( @@ -557,8 +582,21 @@ class TestBillingServiceQuotaOperations: }, ) + def test_quota_commit_coerces_string_to_int(self, mock_send_request): + """Test that TypeAdapter coerces string values to int.""" + mock_send_request.return_value = {"available": "97", "reserved": "0", "refunded": "1"} + + result = BillingService.quota_commit( + tenant_id="t1", feature_key="trigger_event", reservation_id="rid-s", actual_amount=1 + ) + + assert result["available"] == 97 + assert isinstance(result["available"], int) + assert result["refunded"] == 1 + assert isinstance(result["refunded"], int) + def test_quota_commit_with_meta(self, mock_send_request): - mock_send_request.return_value = {} + mock_send_request.return_value = {"available": 97, "reserved": 0, "refunded": 0} meta = {"reason": "partial"} BillingService.quota_commit( @@ -581,6 +619,17 @@ class TestBillingServiceQuotaOperations: json={"tenant_id": "t1", "feature_key": "trigger_event", "reservation_id": "rid-1"}, ) + def test_quota_release_coerces_string_to_int(self, mock_send_request): + """Test that TypeAdapter coerces string values to int.""" + mock_send_request.return_value = {"available": "100", "reserved": "0", "released": "1"} + + result = BillingService.quota_release(tenant_id="t1", feature_key="trigger_event", reservation_id="rid-s") + + assert result["available"] == 100 + assert isinstance(result["available"], int) + assert result["released"] == 1 + assert isinstance(result["released"], int) + class TestBillingServiceRateLimitEnforcement: """Unit tests for rate limit enforcement mechanisms.