diff --git a/api/services/credit_pool_service.py b/api/services/credit_pool_service.py index 1f419d7a5b..2e103dec15 100644 --- a/api/services/credit_pool_service.py +++ b/api/services/credit_pool_service.py @@ -1,9 +1,10 @@ import logging from sqlalchemy import select -from sqlalchemy.orm import Session, sessionmaker +from sqlalchemy.orm import Session from configs import dify_config +from core.db.session_factory import session_factory from core.errors.error import QuotaExceededError from extensions.ext_database import db from models import TenantCreditPool @@ -41,7 +42,7 @@ class CreditPoolService: @classmethod def get_pool(cls, tenant_id: str, pool_type: str = "trial") -> TenantCreditPool | None: """get tenant credit pool""" - with sessionmaker(db.engine, expire_on_commit=False).begin() as session: + with session_factory.get_session_maker().begin() as session: return session.scalar( select(TenantCreditPool) .where( @@ -76,7 +77,7 @@ class CreditPoolService: return 0 try: - with sessionmaker(db.engine, expire_on_commit=False).begin() as session: + with session_factory.get_session_maker().begin() as session: pool = cls._get_locked_pool(session=session, tenant_id=tenant_id, pool_type=pool_type) if not pool: raise QuotaExceededError("Credit pool not found") @@ -108,7 +109,7 @@ class CreditPoolService: return 0 try: - with sessionmaker(db.engine, expire_on_commit=False).begin() as session: + with session_factory.get_session_maker().begin() as session: pool = cls._get_locked_pool(session=session, tenant_id=tenant_id, pool_type=pool_type) if not pool: logger.warning("Credit pool not found, tenant_id=%s, pool_type=%s", tenant_id, pool_type) diff --git a/api/tests/unit_tests/core/app/test_llm_quota.py b/api/tests/unit_tests/core/app/test_llm_quota.py index d9390a4a8f..baf54851e9 100644 --- a/api/tests/unit_tests/core/app/test_llm_quota.py +++ b/api/tests/unit_tests/core/app/test_llm_quota.py @@ -1,8 +1,12 @@ +from collections.abc import Iterator +from contextlib import contextmanager from types import SimpleNamespace from unittest.mock import MagicMock, patch import pytest from sqlalchemy import create_engine, select +from sqlalchemy.engine import Engine +from sqlalchemy.orm import sessionmaker from configs import dify_config from core.app.llm.quota import ( @@ -21,6 +25,13 @@ from models.enums import ProviderQuotaType as ModelProviderQuotaType from models.provider import Provider, ProviderType +@contextmanager +def _patched_credit_pool_session_factory(engine: Engine) -> Iterator[None]: + session_maker = sessionmaker(bind=engine, expire_on_commit=False) + with patch("services.credit_pool_service.session_factory.get_session_maker", return_value=session_maker): + yield + + def test_ensure_llm_quota_available_for_model_raises_when_system_model_is_exhausted() -> None: provider_configuration = SimpleNamespace( using_provider_type=ProviderType.SYSTEM, @@ -148,7 +159,7 @@ def test_deduct_llm_quota_for_model_caps_trial_pool_when_usage_exceeds_remaining with ( patch("core.app.llm.quota.create_plugin_provider_manager", return_value=provider_manager), - patch("services.credit_pool_service.db", SimpleNamespace(engine=engine)), + _patched_credit_pool_session_factory(engine), ): deduct_llm_quota_for_model( tenant_id="tenant-id", diff --git a/api/tests/unit_tests/events/test_update_provider_when_message_created.py b/api/tests/unit_tests/events/test_update_provider_when_message_created.py index 9cb8ca7854..ca21b0d205 100644 --- a/api/tests/unit_tests/events/test_update_provider_when_message_created.py +++ b/api/tests/unit_tests/events/test_update_provider_when_message_created.py @@ -1,8 +1,12 @@ +from collections.abc import Iterator +from contextlib import contextmanager from types import SimpleNamespace from unittest.mock import patch from uuid import uuid4 from sqlalchemy import create_engine, select +from sqlalchemy.engine import Engine +from sqlalchemy.orm import sessionmaker from core.app.entities.app_invoke_entities import ChatAppGenerateEntity from core.entities.provider_entities import ProviderQuotaType, QuotaUnit @@ -11,6 +15,13 @@ from models import TenantCreditPool from models.provider import ProviderType +@contextmanager +def _patched_credit_pool_session_factory(engine: Engine) -> Iterator[None]: + session_maker = sessionmaker(bind=engine, expire_on_commit=False) + with patch("services.credit_pool_service.session_factory.get_session_maker", return_value=session_maker): + yield + + def test_message_created_trial_credit_accounting_does_not_raise_when_balance_is_insufficient() -> None: engine = create_engine("sqlite:///:memory:") TenantCreditPool.__table__.create(engine) @@ -54,7 +65,7 @@ def test_message_created_trial_credit_accounting_does_not_raise_when_balance_is_ message = SimpleNamespace(message_tokens=2, answer_tokens=1) with ( - patch("services.credit_pool_service.db", SimpleNamespace(engine=engine)), + _patched_credit_pool_session_factory(engine), patch.object(update_provider_when_message_created, "_execute_provider_updates"), ): update_provider_when_message_created.handle( diff --git a/api/tests/unit_tests/services/test_credit_pool_service.py b/api/tests/unit_tests/services/test_credit_pool_service.py index e77ef894e7..62e492c67d 100644 --- a/api/tests/unit_tests/services/test_credit_pool_service.py +++ b/api/tests/unit_tests/services/test_credit_pool_service.py @@ -1,10 +1,12 @@ -from types import SimpleNamespace +from collections.abc import Iterator +from contextlib import contextmanager from unittest.mock import patch from uuid import uuid4 import pytest from sqlalchemy import create_engine, select from sqlalchemy.engine import Engine +from sqlalchemy.orm import sessionmaker from core.errors.error import QuotaExceededError from models import TenantCreditPool @@ -31,15 +33,33 @@ def _create_engine_with_pool(*, quota_limit: int, quota_used: int) -> tuple[Engi return engine, tenant_id, pool_id +@contextmanager +def _patched_session_factory(engine: Engine) -> Iterator[None]: + session_maker = sessionmaker(bind=engine, expire_on_commit=False) + with patch("services.credit_pool_service.session_factory.get_session_maker", return_value=session_maker): + yield + + def _get_quota_used(*, engine: Engine, pool_id: str) -> int | None: with engine.connect() as connection: return connection.scalar(select(TenantCreditPool.quota_used).where(TenantCreditPool.id == pool_id)) +def test_get_pool_uses_configured_session_factory_without_flask_app_context() -> None: + engine, tenant_id, _ = _create_engine_with_pool(quota_limit=10, quota_used=2) + + with _patched_session_factory(engine): + pool = CreditPoolService.get_pool(tenant_id=tenant_id, pool_type=ProviderQuotaType.TRIAL) + + assert pool is not None + assert pool.tenant_id == tenant_id + assert pool.quota_used == 2 + + def test_check_and_deduct_credits_deducts_exact_amount_when_sufficient() -> None: engine, tenant_id, pool_id = _create_engine_with_pool(quota_limit=10, quota_used=2) - with patch("services.credit_pool_service.db", SimpleNamespace(engine=engine)): + with _patched_session_factory(engine): deducted_credits = CreditPoolService.check_and_deduct_credits(tenant_id=tenant_id, credits_required=3) assert deducted_credits == 3 @@ -55,7 +75,7 @@ def test_check_and_deduct_credits_raises_when_pool_is_missing() -> None: TenantCreditPool.__table__.create(engine) with ( - patch("services.credit_pool_service.db", SimpleNamespace(engine=engine)), + _patched_session_factory(engine), pytest.raises(QuotaExceededError, match="Credit pool not found"), ): CreditPoolService.check_and_deduct_credits(tenant_id=str(uuid4()), credits_required=1) @@ -65,7 +85,7 @@ def test_check_and_deduct_credits_raises_when_pool_is_empty() -> None: engine, tenant_id, pool_id = _create_engine_with_pool(quota_limit=10, quota_used=10) with ( - patch("services.credit_pool_service.db", SimpleNamespace(engine=engine)), + _patched_session_factory(engine), pytest.raises(QuotaExceededError, match="No credits remaining"), ): CreditPoolService.check_and_deduct_credits(tenant_id=tenant_id, credits_required=1) @@ -77,7 +97,7 @@ def test_check_and_deduct_credits_raises_without_partial_deduction_when_insuffic engine, tenant_id, pool_id = _create_engine_with_pool(quota_limit=10, quota_used=9) with ( - patch("services.credit_pool_service.db", SimpleNamespace(engine=engine)), + _patched_session_factory(engine), pytest.raises(QuotaExceededError, match="Insufficient credits remaining"), ): CreditPoolService.check_and_deduct_credits(tenant_id=tenant_id, credits_required=3) @@ -89,7 +109,7 @@ def test_check_and_deduct_credits_wraps_unexpected_deduction_errors() -> None: engine, tenant_id, pool_id = _create_engine_with_pool(quota_limit=10, quota_used=2) with ( - patch("services.credit_pool_service.db", SimpleNamespace(engine=engine)), + _patched_session_factory(engine), patch.object(CreditPoolService, "_get_locked_pool", side_effect=RuntimeError("database unavailable")), pytest.raises(QuotaExceededError, match="Failed to deduct credits"), ): @@ -106,7 +126,7 @@ def test_deduct_credits_capped_returns_zero_when_pool_is_missing() -> None: engine = create_engine("sqlite:///:memory:") TenantCreditPool.__table__.create(engine) - with patch("services.credit_pool_service.db", SimpleNamespace(engine=engine)): + with _patched_session_factory(engine): deducted_credits = CreditPoolService.deduct_credits_capped(tenant_id=str(uuid4()), credits_required=1) assert deducted_credits == 0 @@ -115,7 +135,7 @@ def test_deduct_credits_capped_returns_zero_when_pool_is_missing() -> None: def test_deduct_credits_capped_returns_zero_when_pool_is_empty() -> None: engine, tenant_id, pool_id = _create_engine_with_pool(quota_limit=10, quota_used=10) - with patch("services.credit_pool_service.db", SimpleNamespace(engine=engine)): + with _patched_session_factory(engine): deducted_credits = CreditPoolService.deduct_credits_capped(tenant_id=tenant_id, credits_required=1) assert deducted_credits == 0 @@ -125,7 +145,7 @@ def test_deduct_credits_capped_returns_zero_when_pool_is_empty() -> None: def test_deduct_credits_capped_deducts_only_remaining_balance_when_insufficient() -> None: engine, tenant_id, pool_id = _create_engine_with_pool(quota_limit=10, quota_used=9) - with patch("services.credit_pool_service.db", SimpleNamespace(engine=engine)): + with _patched_session_factory(engine): deducted_credits = CreditPoolService.deduct_credits_capped(tenant_id=tenant_id, credits_required=3) assert deducted_credits == 1 @@ -136,7 +156,7 @@ def test_deduct_credits_capped_wraps_unexpected_deduction_errors() -> None: engine, tenant_id, pool_id = _create_engine_with_pool(quota_limit=10, quota_used=2) with ( - patch("services.credit_pool_service.db", SimpleNamespace(engine=engine)), + _patched_session_factory(engine), patch.object(CreditPoolService, "_get_locked_pool", side_effect=RuntimeError("database unavailable")), pytest.raises(QuotaExceededError, match="Failed to deduct credits"), ): @@ -149,7 +169,7 @@ def test_deduct_credits_capped_reraises_quota_exceeded_errors() -> None: engine, tenant_id, pool_id = _create_engine_with_pool(quota_limit=10, quota_used=2) with ( - patch("services.credit_pool_service.db", SimpleNamespace(engine=engine)), + _patched_session_factory(engine), patch.object(CreditPoolService, "_get_locked_pool", side_effect=QuotaExceededError("quota unavailable")), pytest.raises(QuotaExceededError, match="quota unavailable"), ):