mirror of
https://github.com/langgenius/dify.git
synced 2026-05-10 06:00:51 -04:00
Signed-off-by: -LAN- <laipz8200@outlook.com> Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: WH-2099 <wh2099@pm.me>
128 lines
4.4 KiB
Python
128 lines
4.4 KiB
Python
import logging
|
|
|
|
from sqlalchemy import select
|
|
from sqlalchemy.orm import Session, sessionmaker
|
|
|
|
from configs import dify_config
|
|
from core.errors.error import QuotaExceededError
|
|
from extensions.ext_database import db
|
|
from models import TenantCreditPool
|
|
from models.enums import ProviderQuotaType
|
|
|
|
# Module-scoped logger, named after this module's import path.
logger = logging.getLogger(name=__name__)


class CreditPoolService:
    """Create, read, and deduct from per-tenant ``TenantCreditPool`` rows."""

    @staticmethod
    def _get_locked_pool(session: Session, tenant_id: str, pool_type: str) -> TenantCreditPool | None:
        """Load the tenant's pool of ``pool_type`` using ``SELECT ... FOR UPDATE``.

        The row lock lasts until ``session``'s transaction ends, so concurrent
        deductions against the same pool row are serialized.

        Returns:
            The locked pool row, or ``None`` when no matching row exists.
        """
        return session.scalar(
            select(TenantCreditPool)
            .where(
                TenantCreditPool.tenant_id == tenant_id,
                TenantCreditPool.pool_type == pool_type,
            )
            .limit(1)
            .with_for_update()
        )

    @classmethod
    def create_default_pool(cls, tenant_id: str) -> TenantCreditPool:
        """Create and commit a TRIAL credit pool for a new tenant.

        The pool starts with ``quota_used=0`` and a limit of
        ``dify_config.HOSTED_POOL_CREDITS``. It is committed on ``db.session``.

        Returns:
            The newly persisted pool.
        """
        credit_pool = TenantCreditPool(
            tenant_id=tenant_id,
            quota_limit=dify_config.HOSTED_POOL_CREDITS,
            quota_used=0,
            pool_type=ProviderQuotaType.TRIAL,
        )
        db.session.add(credit_pool)
        db.session.commit()
        return credit_pool

    @classmethod
    def get_pool(cls, tenant_id: str, pool_type: str = "trial") -> TenantCreditPool | None:
        """Return the tenant's pool of ``pool_type``, or ``None`` if absent.

        Runs in a short-lived session built with ``expire_on_commit=False``, so
        the returned instance keeps its loaded attributes after the session closes.
        No row lock is taken.
        """
        with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
            return session.scalar(
                select(TenantCreditPool)
                .where(
                    TenantCreditPool.tenant_id == tenant_id,
                    TenantCreditPool.pool_type == pool_type,
                )
                .limit(1)
            )

    @classmethod
    def check_credits_available(
        cls,
        tenant_id: str,
        credits_required: int,
        pool_type: str = "trial",
    ) -> bool:
        """Report whether the pool's remaining balance covers ``credits_required``.

        This is read-only and does not lock the row, so the answer may be stale by
        the time a deduction runs. Returns ``False`` when the pool does not exist.
        """
        pool = cls.get_pool(tenant_id, pool_type)
        if not pool:
            return False
        return pool.remaining_credits >= credits_required

    @classmethod
    def check_and_deduct_credits(
        cls,
        tenant_id: str,
        credits_required: int,
        pool_type: str = "trial",
    ) -> int:
        """Deduct exactly ``credits_required`` or raise without mutating the pool.

        The pool row is locked for the duration of the transaction, and the
        deduction is committed when the ``begin()`` block exits.

        Returns:
            ``credits_required`` on success, or ``0`` when ``credits_required <= 0``
            (nothing is read or written in that case).

        Raises:
            QuotaExceededError: the pool is missing, has no credits left, or has
                fewer than ``credits_required``. Any other failure is logged and
                re-raised as ``QuotaExceededError`` with the original error
                chained as its cause.
        """
        if credits_required <= 0:
            return 0

        try:
            with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
                pool = cls._get_locked_pool(session=session, tenant_id=tenant_id, pool_type=pool_type)
                if not pool:
                    raise QuotaExceededError("Credit pool not found")

                remaining_credits = pool.remaining_credits
                if remaining_credits <= 0:
                    raise QuotaExceededError("No credits remaining")
                if remaining_credits < credits_required:
                    raise QuotaExceededError("Insufficient credits remaining")

                pool.quota_used += credits_required
        except QuotaExceededError:
            # Business-rule rejections propagate unchanged; the transaction rolls back.
            raise
        except Exception as e:
            logger.exception("Failed to deduct credits for tenant %s", tenant_id)
            # Keep the underlying DB/commit error as __cause__ for debugging (ruff B904).
            raise QuotaExceededError("Failed to deduct credits") from e

        return credits_required

    @classmethod
    def deduct_credits_capped(
        cls,
        tenant_id: str,
        credits_required: int,
        pool_type: str = "trial",
    ) -> int:
        """Deduct up to the available balance and return the amount actually deducted.

        The pool row is locked while the capped amount is computed and applied.
        The change is committed when the ``begin()`` block exits.

        Returns:
            ``min(credits_required, remaining)``. The result is ``0`` when
            ``credits_required <= 0``, when the pool does not exist (a warning is
            logged), or when no credits remain.

        Raises:
            QuotaExceededError: any unexpected failure, logged and re-raised with
                the original error chained as its cause.
        """
        if credits_required <= 0:
            return 0

        try:
            with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
                pool = cls._get_locked_pool(session=session, tenant_id=tenant_id, pool_type=pool_type)
                if not pool:
                    logger.warning("Credit pool not found, tenant_id=%s, pool_type=%s", tenant_id, pool_type)
                    return 0

                deducted_credits = min(credits_required, pool.remaining_credits)
                if deducted_credits <= 0:
                    return 0

                pool.quota_used += deducted_credits
                # Returning inside the block still commits on context-manager exit.
                return deducted_credits
        except QuotaExceededError:
            raise
        except Exception as e:
            logger.exception("Failed to deduct capped credits for tenant %s", tenant_id)
            # Keep the underlying DB/commit error as __cause__ for debugging (ruff B904).
            raise QuotaExceededError("Failed to deduct credits") from e