mirror of
https://github.com/langgenius/dify.git
synced 2026-05-25 10:00:43 -04:00
fix: credit pool access outside flask context (#36143)
This commit is contained in:
@@ -1,9 +1,10 @@
|
||||
import logging
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from configs import dify_config
|
||||
from core.db.session_factory import session_factory
|
||||
from core.errors.error import QuotaExceededError
|
||||
from extensions.ext_database import db
|
||||
from models import TenantCreditPool
|
||||
@@ -41,7 +42,7 @@ class CreditPoolService:
|
||||
@classmethod
|
||||
def get_pool(cls, tenant_id: str, pool_type: str = "trial") -> TenantCreditPool | None:
|
||||
"""get tenant credit pool"""
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
||||
with session_factory.get_session_maker().begin() as session:
|
||||
return session.scalar(
|
||||
select(TenantCreditPool)
|
||||
.where(
|
||||
@@ -76,7 +77,7 @@ class CreditPoolService:
|
||||
return 0
|
||||
|
||||
try:
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
||||
with session_factory.get_session_maker().begin() as session:
|
||||
pool = cls._get_locked_pool(session=session, tenant_id=tenant_id, pool_type=pool_type)
|
||||
if not pool:
|
||||
raise QuotaExceededError("Credit pool not found")
|
||||
@@ -108,7 +109,7 @@ class CreditPoolService:
|
||||
return 0
|
||||
|
||||
try:
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
||||
with session_factory.get_session_maker().begin() as session:
|
||||
pool = cls._get_locked_pool(session=session, tenant_id=tenant_id, pool_type=pool_type)
|
||||
if not pool:
|
||||
logger.warning("Credit pool not found, tenant_id=%s, pool_type=%s", tenant_id, pool_type)
|
||||
|
||||
@@ -1,8 +1,12 @@
|
||||
from collections.abc import Iterator
|
||||
from contextlib import contextmanager
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import create_engine, select
|
||||
from sqlalchemy.engine import Engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from configs import dify_config
|
||||
from core.app.llm.quota import (
|
||||
@@ -21,6 +25,13 @@ from models.enums import ProviderQuotaType as ModelProviderQuotaType
|
||||
from models.provider import Provider, ProviderType
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _patched_credit_pool_session_factory(engine: Engine) -> Iterator[None]:
|
||||
session_maker = sessionmaker(bind=engine, expire_on_commit=False)
|
||||
with patch("services.credit_pool_service.session_factory.get_session_maker", return_value=session_maker):
|
||||
yield
|
||||
|
||||
|
||||
def test_ensure_llm_quota_available_for_model_raises_when_system_model_is_exhausted() -> None:
|
||||
provider_configuration = SimpleNamespace(
|
||||
using_provider_type=ProviderType.SYSTEM,
|
||||
@@ -148,7 +159,7 @@ def test_deduct_llm_quota_for_model_caps_trial_pool_when_usage_exceeds_remaining
|
||||
|
||||
with (
|
||||
patch("core.app.llm.quota.create_plugin_provider_manager", return_value=provider_manager),
|
||||
patch("services.credit_pool_service.db", SimpleNamespace(engine=engine)),
|
||||
_patched_credit_pool_session_factory(engine),
|
||||
):
|
||||
deduct_llm_quota_for_model(
|
||||
tenant_id="tenant-id",
|
||||
|
||||
@@ -1,8 +1,12 @@
|
||||
from collections.abc import Iterator
|
||||
from contextlib import contextmanager
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import patch
|
||||
from uuid import uuid4
|
||||
|
||||
from sqlalchemy import create_engine, select
|
||||
from sqlalchemy.engine import Engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from core.app.entities.app_invoke_entities import ChatAppGenerateEntity
|
||||
from core.entities.provider_entities import ProviderQuotaType, QuotaUnit
|
||||
@@ -11,6 +15,13 @@ from models import TenantCreditPool
|
||||
from models.provider import ProviderType
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _patched_credit_pool_session_factory(engine: Engine) -> Iterator[None]:
|
||||
session_maker = sessionmaker(bind=engine, expire_on_commit=False)
|
||||
with patch("services.credit_pool_service.session_factory.get_session_maker", return_value=session_maker):
|
||||
yield
|
||||
|
||||
|
||||
def test_message_created_trial_credit_accounting_does_not_raise_when_balance_is_insufficient() -> None:
|
||||
engine = create_engine("sqlite:///:memory:")
|
||||
TenantCreditPool.__table__.create(engine)
|
||||
@@ -54,7 +65,7 @@ def test_message_created_trial_credit_accounting_does_not_raise_when_balance_is_
|
||||
message = SimpleNamespace(message_tokens=2, answer_tokens=1)
|
||||
|
||||
with (
|
||||
patch("services.credit_pool_service.db", SimpleNamespace(engine=engine)),
|
||||
_patched_credit_pool_session_factory(engine),
|
||||
patch.object(update_provider_when_message_created, "_execute_provider_updates"),
|
||||
):
|
||||
update_provider_when_message_created.handle(
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
from types import SimpleNamespace
|
||||
from collections.abc import Iterator
|
||||
from contextlib import contextmanager
|
||||
from unittest.mock import patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import create_engine, select
|
||||
from sqlalchemy.engine import Engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from core.errors.error import QuotaExceededError
|
||||
from models import TenantCreditPool
|
||||
@@ -31,15 +33,33 @@ def _create_engine_with_pool(*, quota_limit: int, quota_used: int) -> tuple[Engi
|
||||
return engine, tenant_id, pool_id
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _patched_session_factory(engine: Engine) -> Iterator[None]:
|
||||
session_maker = sessionmaker(bind=engine, expire_on_commit=False)
|
||||
with patch("services.credit_pool_service.session_factory.get_session_maker", return_value=session_maker):
|
||||
yield
|
||||
|
||||
|
||||
def _get_quota_used(*, engine: Engine, pool_id: str) -> int | None:
|
||||
with engine.connect() as connection:
|
||||
return connection.scalar(select(TenantCreditPool.quota_used).where(TenantCreditPool.id == pool_id))
|
||||
|
||||
|
||||
def test_get_pool_uses_configured_session_factory_without_flask_app_context() -> None:
|
||||
engine, tenant_id, _ = _create_engine_with_pool(quota_limit=10, quota_used=2)
|
||||
|
||||
with _patched_session_factory(engine):
|
||||
pool = CreditPoolService.get_pool(tenant_id=tenant_id, pool_type=ProviderQuotaType.TRIAL)
|
||||
|
||||
assert pool is not None
|
||||
assert pool.tenant_id == tenant_id
|
||||
assert pool.quota_used == 2
|
||||
|
||||
|
||||
def test_check_and_deduct_credits_deducts_exact_amount_when_sufficient() -> None:
|
||||
engine, tenant_id, pool_id = _create_engine_with_pool(quota_limit=10, quota_used=2)
|
||||
|
||||
with patch("services.credit_pool_service.db", SimpleNamespace(engine=engine)):
|
||||
with _patched_session_factory(engine):
|
||||
deducted_credits = CreditPoolService.check_and_deduct_credits(tenant_id=tenant_id, credits_required=3)
|
||||
|
||||
assert deducted_credits == 3
|
||||
@@ -55,7 +75,7 @@ def test_check_and_deduct_credits_raises_when_pool_is_missing() -> None:
|
||||
TenantCreditPool.__table__.create(engine)
|
||||
|
||||
with (
|
||||
patch("services.credit_pool_service.db", SimpleNamespace(engine=engine)),
|
||||
_patched_session_factory(engine),
|
||||
pytest.raises(QuotaExceededError, match="Credit pool not found"),
|
||||
):
|
||||
CreditPoolService.check_and_deduct_credits(tenant_id=str(uuid4()), credits_required=1)
|
||||
@@ -65,7 +85,7 @@ def test_check_and_deduct_credits_raises_when_pool_is_empty() -> None:
|
||||
engine, tenant_id, pool_id = _create_engine_with_pool(quota_limit=10, quota_used=10)
|
||||
|
||||
with (
|
||||
patch("services.credit_pool_service.db", SimpleNamespace(engine=engine)),
|
||||
_patched_session_factory(engine),
|
||||
pytest.raises(QuotaExceededError, match="No credits remaining"),
|
||||
):
|
||||
CreditPoolService.check_and_deduct_credits(tenant_id=tenant_id, credits_required=1)
|
||||
@@ -77,7 +97,7 @@ def test_check_and_deduct_credits_raises_without_partial_deduction_when_insuffic
|
||||
engine, tenant_id, pool_id = _create_engine_with_pool(quota_limit=10, quota_used=9)
|
||||
|
||||
with (
|
||||
patch("services.credit_pool_service.db", SimpleNamespace(engine=engine)),
|
||||
_patched_session_factory(engine),
|
||||
pytest.raises(QuotaExceededError, match="Insufficient credits remaining"),
|
||||
):
|
||||
CreditPoolService.check_and_deduct_credits(tenant_id=tenant_id, credits_required=3)
|
||||
@@ -89,7 +109,7 @@ def test_check_and_deduct_credits_wraps_unexpected_deduction_errors() -> None:
|
||||
engine, tenant_id, pool_id = _create_engine_with_pool(quota_limit=10, quota_used=2)
|
||||
|
||||
with (
|
||||
patch("services.credit_pool_service.db", SimpleNamespace(engine=engine)),
|
||||
_patched_session_factory(engine),
|
||||
patch.object(CreditPoolService, "_get_locked_pool", side_effect=RuntimeError("database unavailable")),
|
||||
pytest.raises(QuotaExceededError, match="Failed to deduct credits"),
|
||||
):
|
||||
@@ -106,7 +126,7 @@ def test_deduct_credits_capped_returns_zero_when_pool_is_missing() -> None:
|
||||
engine = create_engine("sqlite:///:memory:")
|
||||
TenantCreditPool.__table__.create(engine)
|
||||
|
||||
with patch("services.credit_pool_service.db", SimpleNamespace(engine=engine)):
|
||||
with _patched_session_factory(engine):
|
||||
deducted_credits = CreditPoolService.deduct_credits_capped(tenant_id=str(uuid4()), credits_required=1)
|
||||
|
||||
assert deducted_credits == 0
|
||||
@@ -115,7 +135,7 @@ def test_deduct_credits_capped_returns_zero_when_pool_is_missing() -> None:
|
||||
def test_deduct_credits_capped_returns_zero_when_pool_is_empty() -> None:
|
||||
engine, tenant_id, pool_id = _create_engine_with_pool(quota_limit=10, quota_used=10)
|
||||
|
||||
with patch("services.credit_pool_service.db", SimpleNamespace(engine=engine)):
|
||||
with _patched_session_factory(engine):
|
||||
deducted_credits = CreditPoolService.deduct_credits_capped(tenant_id=tenant_id, credits_required=1)
|
||||
|
||||
assert deducted_credits == 0
|
||||
@@ -125,7 +145,7 @@ def test_deduct_credits_capped_returns_zero_when_pool_is_empty() -> None:
|
||||
def test_deduct_credits_capped_deducts_only_remaining_balance_when_insufficient() -> None:
|
||||
engine, tenant_id, pool_id = _create_engine_with_pool(quota_limit=10, quota_used=9)
|
||||
|
||||
with patch("services.credit_pool_service.db", SimpleNamespace(engine=engine)):
|
||||
with _patched_session_factory(engine):
|
||||
deducted_credits = CreditPoolService.deduct_credits_capped(tenant_id=tenant_id, credits_required=3)
|
||||
|
||||
assert deducted_credits == 1
|
||||
@@ -136,7 +156,7 @@ def test_deduct_credits_capped_wraps_unexpected_deduction_errors() -> None:
|
||||
engine, tenant_id, pool_id = _create_engine_with_pool(quota_limit=10, quota_used=2)
|
||||
|
||||
with (
|
||||
patch("services.credit_pool_service.db", SimpleNamespace(engine=engine)),
|
||||
_patched_session_factory(engine),
|
||||
patch.object(CreditPoolService, "_get_locked_pool", side_effect=RuntimeError("database unavailable")),
|
||||
pytest.raises(QuotaExceededError, match="Failed to deduct credits"),
|
||||
):
|
||||
@@ -149,7 +169,7 @@ def test_deduct_credits_capped_reraises_quota_exceeded_errors() -> None:
|
||||
engine, tenant_id, pool_id = _create_engine_with_pool(quota_limit=10, quota_used=2)
|
||||
|
||||
with (
|
||||
patch("services.credit_pool_service.db", SimpleNamespace(engine=engine)),
|
||||
_patched_session_factory(engine),
|
||||
patch.object(CreditPoolService, "_get_locked_pool", side_effect=QuotaExceededError("quota unavailable")),
|
||||
pytest.raises(QuotaExceededError, match="quota unavailable"),
|
||||
):
|
||||
|
||||
Reference in New Issue
Block a user