Merge branch 'feat/new-biliing-quota' into deploy/dev

@@ -8,8 +8,8 @@ from hashlib import sha256
 from typing import Any, TypedDict, cast

 from pydantic import BaseModel, TypeAdapter
-from sqlalchemy import func, select
-from sqlalchemy.orm import Session
+from sqlalchemy import delete, func, select, update
+from sqlalchemy.orm import Session, sessionmaker


 class InvitationData(TypedDict):
@@ -83,6 +83,12 @@ from tasks.mail_reset_password_task import (
 logger = logging.getLogger(__name__)


+class InvitationDetailDict(TypedDict):
+    account: Account
+    data: InvitationData
+    tenant: Tenant
+
+
 def _try_join_enterprise_default_workspace(account_id: str) -> None:
     """Best-effort join to enterprise default workspace."""
     if not dify_config.ENTERPRISE_ENABLED:
@@ -144,22 +150,26 @@ class AccountService:

     @staticmethod
     def load_user(user_id: str) -> None | Account:
-        account = db.session.query(Account).filter_by(id=user_id).first()
+        account = db.session.get(Account, user_id)
         if not account:
             return None

         if account.status == AccountStatus.BANNED:
             raise Unauthorized("Account is banned.")

-        current_tenant = db.session.query(TenantAccountJoin).filter_by(account_id=account.id, current=True).first()
+        current_tenant = db.session.scalar(
+            select(TenantAccountJoin)
+            .where(TenantAccountJoin.account_id == account.id, TenantAccountJoin.current == True)
+            .limit(1)
+        )
         if current_tenant:
             account.set_tenant_id(current_tenant.tenant_id)
         else:
-            available_ta = (
-                db.session.query(TenantAccountJoin)
-                .filter_by(account_id=account.id)
+            available_ta = db.session.scalar(
+                select(TenantAccountJoin)
+                .where(TenantAccountJoin.account_id == account.id)
                 .order_by(TenantAccountJoin.id.asc())
-                .first()
+                .limit(1)
             )
             if not available_ta:
                 return None
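Note: the recurring pattern in this commit swaps the legacy Query API for 2.0-style select() statements. A minimal sketch of the equivalence, assuming a Flask-SQLAlchemy style db.session and the mapped models above (the user_id variable is illustrative only):

    from sqlalchemy import select

    # Legacy 1.x style: build the query from the session.
    account = db.session.query(Account).filter_by(id=user_id).first()

    # 2.0 style: build a select() statement and let the session run it.
    # scalar() returns the first column of the first row, or None.
    account = db.session.scalar(select(Account).where(Account.id == user_id).limit(1))

    # Primary-key lookups can skip the statement entirely and hit the
    # identity map via Session.get(), as load_user() now does.
    account = db.session.get(Account, user_id)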
@@ -195,7 +205,7 @@ class AccountService:
     def authenticate(email: str, password: str, invite_token: str | None = None) -> Account:
         """authenticate account with email and password"""

-        account = db.session.query(Account).filter_by(email=email).first()
+        account = db.session.scalar(select(Account).where(Account.email == email).limit(1))
         if not account:
             raise AccountPasswordError("Invalid email or password.")

@@ -371,8 +381,10 @@ class AccountService:
         """Link account integrate"""
         try:
             # Query whether there is an existing binding record for the same provider
-            account_integrate: AccountIntegrate | None = (
-                db.session.query(AccountIntegrate).filter_by(account_id=account.id, provider=provider).first()
+            account_integrate: AccountIntegrate | None = db.session.scalar(
+                select(AccountIntegrate)
+                .where(AccountIntegrate.account_id == account.id, AccountIntegrate.provider == provider)
+                .limit(1)
             )

             if account_integrate:
@@ -416,7 +428,9 @@ class AccountService:
     def update_account_email(account: Account, email: str) -> Account:
         """Update account email"""
         account.email = email
-        account_integrate = db.session.query(AccountIntegrate).filter_by(account_id=account.id).first()
+        account_integrate = db.session.scalar(
+            select(AccountIntegrate).where(AccountIntegrate.account_id == account.id).limit(1)
+        )
         if account_integrate:
             db.session.delete(account_integrate)
         db.session.add(account)
@@ -818,7 +832,7 @@ class AccountService:
             )
         )

-        account = db.session.query(Account).where(Account.email == email).first()
+        account = db.session.scalar(select(Account).where(Account.email == email).limit(1))
         if not account:
             return None

@@ -1018,7 +1032,7 @@ class AccountService:

     @staticmethod
     def check_email_unique(email: str) -> bool:
-        return db.session.query(Account).filter_by(email=email).first() is None
+        return db.session.scalar(select(Account).where(Account.email == email).limit(1)) is None


 class TenantService:
@@ -1061,11 +1075,11 @@ class TenantService:
     @staticmethod
     def create_owner_tenant_if_not_exist(account: Account, name: str | None = None, is_setup: bool | None = False):
         """Check if user have a workspace or not"""
-        available_ta = (
-            db.session.query(TenantAccountJoin)
-            .filter_by(account_id=account.id)
+        available_ta = db.session.scalar(
+            select(TenantAccountJoin)
+            .where(TenantAccountJoin.account_id == account.id)
             .order_by(TenantAccountJoin.id.asc())
-            .first()
+            .limit(1)
         )

         if available_ta:
@@ -1096,7 +1110,11 @@ class TenantService:
             logger.error("Tenant %s has already an owner.", tenant.id)
             raise Exception("Tenant already has an owner.")

-        ta = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=account.id).first()
+        ta = db.session.scalar(
+            select(TenantAccountJoin)
+            .where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == account.id)
+            .limit(1)
+        )
         if ta:
             ta.role = TenantAccountRole(role)
         else:
@@ -1111,11 +1129,12 @@ class TenantService:
     @staticmethod
     def get_join_tenants(account: Account) -> list[Tenant]:
         """Get account join tenants"""
-        return (
-            db.session.query(Tenant)
-            .join(TenantAccountJoin, Tenant.id == TenantAccountJoin.tenant_id)
-            .where(TenantAccountJoin.account_id == account.id, Tenant.status == TenantStatus.NORMAL)
-            .all()
+        return list(
+            db.session.scalars(
+                select(Tenant)
+                .join(TenantAccountJoin, Tenant.id == TenantAccountJoin.tenant_id)
+                .where(TenantAccountJoin.account_id == account.id, Tenant.status == TenantStatus.NORMAL)
+            ).all()
         )

     @staticmethod
@@ -1125,7 +1144,11 @@ class TenantService:
         if not tenant:
             raise TenantNotFoundError("Tenant not found.")

-        ta = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=account.id).first()
+        ta = db.session.scalar(
+            select(TenantAccountJoin)
+            .where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == account.id)
+            .limit(1)
+        )
         if ta:
             tenant.role = ta.role
         else:
@@ -1140,23 +1163,25 @@ class TenantService:
         if tenant_id is None:
             raise ValueError("Tenant ID must be provided.")

-        tenant_account_join = (
-            db.session.query(TenantAccountJoin)
+        tenant_account_join = db.session.scalar(
+            select(TenantAccountJoin)
             .join(Tenant, TenantAccountJoin.tenant_id == Tenant.id)
             .where(
                 TenantAccountJoin.account_id == account.id,
                 TenantAccountJoin.tenant_id == tenant_id,
                 Tenant.status == TenantStatus.NORMAL,
             )
-            .first()
+            .limit(1)
         )

         if not tenant_account_join:
             raise AccountNotLinkTenantError("Tenant not found or account is not a member of the tenant.")
         else:
-            db.session.query(TenantAccountJoin).where(
-                TenantAccountJoin.account_id == account.id, TenantAccountJoin.tenant_id != tenant_id
-            ).update({"current": False})
+            db.session.execute(
+                update(TenantAccountJoin)
+                .where(TenantAccountJoin.account_id == account.id, TenantAccountJoin.tenant_id != tenant_id)
+                .values(current=False)
+            )
             tenant_account_join.current = True
             # Set the current tenant for the account
             account.set_tenant_id(tenant_account_join.tenant_id)
@@ -1165,8 +1190,8 @@ class TenantService:
     @staticmethod
     def get_tenant_members(tenant: Tenant) -> list[Account]:
         """Get tenant members"""
-        query = (
-            db.session.query(Account, TenantAccountJoin.role)
+        stmt = (
+            select(Account, TenantAccountJoin.role)
+            .select_from(Account)
             .join(TenantAccountJoin, Account.id == TenantAccountJoin.account_id)
             .where(TenantAccountJoin.tenant_id == tenant.id)
@@ -1175,7 +1200,7 @@ class TenantService:
         # Initialize an empty list to store the updated accounts
         updated_accounts = []

-        for account, role in query:
+        for account, role in db.session.execute(stmt):
             account.role = role
             updated_accounts.append(account)

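Note: db.session.execute(stmt) on a two-column select yields Row objects that unpack positionally, which is what the rewritten loop relies on. A small sketch under the same models (purely illustrative):

    stmt = (
        select(Account, TenantAccountJoin.role)
        .select_from(Account)
        .join(TenantAccountJoin, Account.id == TenantAccountJoin.account_id)
    )
    for account, role in db.session.execute(stmt):
        # each Row is (Account instance, role column value)
        account.role = role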
@@ -1184,8 +1209,8 @@ class TenantService:
     @staticmethod
     def get_dataset_operator_members(tenant: Tenant) -> list[Account]:
         """Get dataset admin members"""
-        query = (
-            db.session.query(Account, TenantAccountJoin.role)
+        stmt = (
+            select(Account, TenantAccountJoin.role)
+            .select_from(Account)
             .join(TenantAccountJoin, Account.id == TenantAccountJoin.account_id)
             .where(TenantAccountJoin.tenant_id == tenant.id)
@@ -1195,7 +1220,7 @@ class TenantService:
         # Initialize an empty list to store the updated accounts
         updated_accounts = []

-        for account, role in query:
+        for account, role in db.session.execute(stmt):
             account.role = role
             updated_accounts.append(account)

@@ -1208,26 +1233,31 @@ class TenantService:
             raise ValueError("all roles must be TenantAccountRole")

         return (
-            db.session.query(TenantAccountJoin)
-            .where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.role.in_([role.value for role in roles]))
-            .first()
+            db.session.scalar(
+                select(TenantAccountJoin)
+                .where(
+                    TenantAccountJoin.tenant_id == tenant.id,
+                    TenantAccountJoin.role.in_([role.value for role in roles]),
+                )
+                .limit(1)
+            )
             is not None
         )

     @staticmethod
     def get_user_role(account: Account, tenant: Tenant) -> TenantAccountRole | None:
         """Get the role of the current account for a given tenant"""
-        join = (
-            db.session.query(TenantAccountJoin)
+        join = db.session.scalar(
+            select(TenantAccountJoin)
             .where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == account.id)
-            .first()
+            .limit(1)
         )
         return TenantAccountRole(join.role) if join else None

     @staticmethod
     def get_tenant_count() -> int:
         """Get tenant count"""
-        return cast(int, db.session.query(func.count(Tenant.id)).scalar())
+        return cast(int, db.session.scalar(select(func.count(Tenant.id))))

     @staticmethod
     def check_member_permission(tenant: Tenant, operator: Account, member: Account | None, action: str):
@@ -1244,7 +1274,11 @@ class TenantService:
         if operator.id == member.id:
             raise CannotOperateSelfError("Cannot operate self.")

-        ta_operator = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=operator.id).first()
+        ta_operator = db.session.scalar(
+            select(TenantAccountJoin)
+            .where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == operator.id)
+            .limit(1)
+        )

         if not ta_operator or ta_operator.role not in perms[action]:
             raise NoPermissionError(f"No permission to {action} member.")
@@ -1262,7 +1296,11 @@ class TenantService:

         TenantService.check_member_permission(tenant, operator, account, "remove")

-        ta = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=account.id).first()
+        ta = db.session.scalar(
+            select(TenantAccountJoin)
+            .where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == account.id)
+            .limit(1)
+        )
         if not ta:
             raise MemberNotInTenantError("Member not in tenant.")

@@ -1277,7 +1315,12 @@ class TenantService:
         should_delete_account = False
         if account.status == AccountStatus.PENDING:
             # autoflush flushes ta deletion before this query, so 0 means no remaining joins
-            remaining_joins = db.session.query(TenantAccountJoin).filter_by(account_id=account_id).count()
+            remaining_joins = (
+                db.session.scalar(
+                    select(func.count(TenantAccountJoin.id)).where(TenantAccountJoin.account_id == account_id)
+                )
+                or 0
+            )
             if remaining_joins == 0:
                 db.session.delete(account)
                 should_delete_account = True
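Note: the "autoflush" comment above leans on a Session behavior worth spelling out: issuing any query flushes pending changes first, so the earlier db.session.delete(ta) is flushed before the COUNT runs and a result of 0 really means no joins remain. Session.scalar() is also typed as possibly returning None, hence the "or 0". A minimal sketch of that ordering (illustrative):

    db.session.delete(ta)           # pending, not yet flushed
    remaining = db.session.scalar(  # autoflush emits the DELETE first
        select(func.count(TenantAccountJoin.id)).where(TenantAccountJoin.account_id == account_id)
    ) or 0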
@@ -1312,8 +1355,10 @@ class TenantService:
         """Update member role"""
         TenantService.check_member_permission(tenant, operator, member, "update")

-        target_member_join = (
-            db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=member.id).first()
+        target_member_join = db.session.scalar(
+            select(TenantAccountJoin)
+            .where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == member.id)
+            .limit(1)
         )

         if not target_member_join:
@@ -1324,8 +1369,10 @@ class TenantService:

         if new_role == "owner":
             # Find the current owner and change their role to 'admin'
-            current_owner_join = (
-                db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, role="owner").first()
+            current_owner_join = db.session.scalar(
+                select(TenantAccountJoin)
+                .where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.role == "owner")
+                .limit(1)
             )
             if current_owner_join:
                 current_owner_join.role = TenantAccountRole.ADMIN
@@ -1384,10 +1431,10 @@ class RegisterService:
             db.session.add(dify_setup)
             db.session.commit()
         except Exception as e:
-            db.session.query(DifySetup).delete()
-            db.session.query(TenantAccountJoin).delete()
-            db.session.query(Account).delete()
-            db.session.query(Tenant).delete()
+            db.session.execute(delete(DifySetup))
+            db.session.execute(delete(TenantAccountJoin))
+            db.session.execute(delete(Account))
+            db.session.execute(delete(Tenant))
             db.session.commit()

             logger.exception("Setup account failed, email: %s, name: %s", email, name)
@@ -1469,7 +1516,7 @@ class RegisterService:

         check_workspace_member_invite_permission(tenant.id)

-        with Session(db.engine) as session:
+        with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
             account = AccountService.get_account_by_email_with_case_fallback(email, session=session)

             if not account:
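Note: the two context managers behave differently. Session(db.engine) only opens a session and leaves transaction control to the caller, while sessionmaker(db.engine, expire_on_commit=False).begin() commits on clean exit, rolls back on exception, and keeps loaded attributes readable after the commit. Illustrative sketch (the attribute assignment is made up):

    from sqlalchemy.orm import sessionmaker

    session_factory = sessionmaker(db.engine, expire_on_commit=False)
    with session_factory.begin() as session:
        account = session.get(Account, account_id)
        account.name = "updated"
    # committed here; account's attributes remain usable because
    # expire_on_commit=False skips the post-commit expiration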
@@ -1488,7 +1535,11 @@ class RegisterService:
             TenantService.switch_tenant(account, tenant.id)
         else:
             TenantService.check_member_permission(tenant, inviter, account, "add")
-            ta = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=account.id).first()
+            ta = db.session.scalar(
+                select(TenantAccountJoin)
+                .where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == account.id)
+                .limit(1)
+            )

             if not ta:
                 TenantService.create_tenant_member(tenant, account, role)
@@ -1540,26 +1591,23 @@ class RegisterService:
     @classmethod
     def get_invitation_if_token_valid(
         cls, workspace_id: str | None, email: str | None, token: str
-    ) -> dict[str, Any] | None:
+    ) -> InvitationDetailDict | None:
         invitation_data = cls.get_invitation_by_token(token, workspace_id, email)
         if not invitation_data:
             return None

-        tenant = (
-            db.session.query(Tenant)
-            .where(Tenant.id == invitation_data["workspace_id"], Tenant.status == "normal")
-            .first()
+        tenant = db.session.scalar(
+            select(Tenant).where(Tenant.id == invitation_data["workspace_id"], Tenant.status == "normal").limit(1)
         )

         if not tenant:
             return None

-        tenant_account = (
-            db.session.query(Account, TenantAccountJoin.role)
+        tenant_account = db.session.execute(
+            select(Account, TenantAccountJoin.role)
             .join(TenantAccountJoin, Account.id == TenantAccountJoin.account_id)
             .where(Account.email == invitation_data["email"], TenantAccountJoin.tenant_id == tenant.id)
-            .first()
-        )
+        ).first()

         if not tenant_account:
             return None
@@ -1605,7 +1653,7 @@ class RegisterService:
     @classmethod
     def get_invitation_with_case_fallback(
         cls, workspace_id: str | None, email: str | None, token: str
-    ) -> dict[str, Any] | None:
+    ) -> InvitationDetailDict | None:
         invitation = cls.get_invitation_if_token_valid(workspace_id, email, token)
         if invitation or not email or email == email.lower():
             return invitation

@@ -32,22 +32,33 @@ class AdvancedPromptTemplateService:
     def get_common_prompt(cls, app_mode: str, model_mode: str, has_context: str):
         context_prompt = copy.deepcopy(CONTEXT)

-        if app_mode == AppMode.CHAT:
-            if model_mode == "completion":
-                return cls.get_completion_prompt(
-                    copy.deepcopy(CHAT_APP_COMPLETION_PROMPT_CONFIG), has_context, context_prompt
-                )
-            elif model_mode == "chat":
-                return cls.get_chat_prompt(copy.deepcopy(CHAT_APP_CHAT_PROMPT_CONFIG), has_context, context_prompt)
-        elif app_mode == AppMode.COMPLETION:
-            if model_mode == "completion":
-                return cls.get_completion_prompt(
-                    copy.deepcopy(COMPLETION_APP_COMPLETION_PROMPT_CONFIG), has_context, context_prompt
-                )
-            elif model_mode == "chat":
-                return cls.get_chat_prompt(
-                    copy.deepcopy(COMPLETION_APP_CHAT_PROMPT_CONFIG), has_context, context_prompt
-                )
+        match app_mode:
+            case AppMode.CHAT:
+                match model_mode:
+                    case "completion":
+                        return cls.get_completion_prompt(
+                            copy.deepcopy(CHAT_APP_COMPLETION_PROMPT_CONFIG), has_context, context_prompt
+                        )
+                    case "chat":
+                        return cls.get_chat_prompt(
+                            copy.deepcopy(CHAT_APP_CHAT_PROMPT_CONFIG), has_context, context_prompt
+                        )
+                    case _:
+                        pass
+            case AppMode.COMPLETION:
+                match model_mode:
+                    case "completion":
+                        return cls.get_completion_prompt(
+                            copy.deepcopy(COMPLETION_APP_COMPLETION_PROMPT_CONFIG), has_context, context_prompt
+                        )
+                    case "chat":
+                        return cls.get_chat_prompt(
+                            copy.deepcopy(COMPLETION_APP_CHAT_PROMPT_CONFIG), has_context, context_prompt
+                        )
+                    case _:
+                        pass
+            case _:
+                pass
         # default return empty dict
         return {}

@@ -73,25 +84,38 @@ class AdvancedPromptTemplateService:
     def get_baichuan_prompt(cls, app_mode: str, model_mode: str, has_context: str):
         baichuan_context_prompt = copy.deepcopy(BAICHUAN_CONTEXT)

-        if app_mode == AppMode.CHAT:
-            if model_mode == "completion":
-                return cls.get_completion_prompt(
-                    copy.deepcopy(BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG), has_context, baichuan_context_prompt
-                )
-            elif model_mode == "chat":
-                return cls.get_chat_prompt(
-                    copy.deepcopy(BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG), has_context, baichuan_context_prompt
-                )
-        elif app_mode == AppMode.COMPLETION:
-            if model_mode == "completion":
-                return cls.get_completion_prompt(
-                    copy.deepcopy(BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG),
-                    has_context,
-                    baichuan_context_prompt,
-                )
-            elif model_mode == "chat":
-                return cls.get_chat_prompt(
-                    copy.deepcopy(BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG), has_context, baichuan_context_prompt
-                )
+        match app_mode:
+            case AppMode.CHAT:
+                match model_mode:
+                    case "completion":
+                        return cls.get_completion_prompt(
+                            copy.deepcopy(BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG),
+                            has_context,
+                            baichuan_context_prompt,
+                        )
+                    case "chat":
+                        return cls.get_chat_prompt(
+                            copy.deepcopy(BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG), has_context, baichuan_context_prompt
+                        )
+                    case _:
+                        pass
+            case AppMode.COMPLETION:
+                match model_mode:
+                    case "completion":
+                        return cls.get_completion_prompt(
+                            copy.deepcopy(BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG),
+                            has_context,
+                            baichuan_context_prompt,
+                        )
+                    case "chat":
+                        return cls.get_chat_prompt(
+                            copy.deepcopy(BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG),
+                            has_context,
+                            baichuan_context_prompt,
+                        )
+                    case _:
+                        pass
+            case _:
+                pass
         # default return empty dict
         return {}

@@ -4,7 +4,9 @@ import uuid
 import pandas as pd

 logger = logging.getLogger(__name__)
-from sqlalchemy import or_, select
+from typing import TypedDict
+
+from sqlalchemy import delete, or_, select, update
 from werkzeug.datastructures import FileStorage
 from werkzeug.exceptions import NotFound

@@ -23,15 +25,34 @@ from tasks.annotation.enable_annotation_reply_task import enable_annotation_repl
 from tasks.annotation.update_annotation_to_index_task import update_annotation_to_index_task


+class AnnotationJobStatusDict(TypedDict):
+    job_id: str
+    job_status: str
+
+
+class EmbeddingModelDict(TypedDict):
+    embedding_provider_name: str
+    embedding_model_name: str
+
+
+class AnnotationSettingDict(TypedDict):
+    id: str
+    enabled: bool
+    score_threshold: float
+    embedding_model: EmbeddingModelDict | dict
+
+
+class AnnotationSettingDisabledDict(TypedDict):
+    enabled: bool
+
+
 class AppAnnotationService:
     @classmethod
     def up_insert_app_annotation_from_message(cls, args: dict, app_id: str) -> MessageAnnotation:
         # get app info
         current_user, current_tenant_id = current_account_with_tenant()
-        app = (
-            db.session.query(App)
-            .where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal")
-            .first()
+        app = db.session.scalar(
+            select(App).where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal").limit(1)
         )

         if not app:
@@ -43,7 +64,9 @@ class AppAnnotationService:

         if args.get("message_id"):
             message_id = str(args["message_id"])
-            message = db.session.query(Message).where(Message.id == message_id, Message.app_id == app.id).first()
+            message = db.session.scalar(
+                select(Message).where(Message.id == message_id, Message.app_id == app.id).limit(1)
+            )

             if not message:
                 raise NotFound("Message Not Exists.")
@@ -72,7 +95,9 @@ class AppAnnotationService:
         db.session.add(annotation)
         db.session.commit()

-        annotation_setting = db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first()
+        annotation_setting = db.session.scalar(
+            select(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).limit(1)
+        )
         assert current_tenant_id is not None
         if annotation_setting:
             add_annotation_to_index_task.delay(
@@ -85,7 +110,7 @@ class AppAnnotationService:
         return annotation

     @classmethod
-    def enable_app_annotation(cls, args: dict, app_id: str):
+    def enable_app_annotation(cls, args: dict, app_id: str) -> AnnotationJobStatusDict:
         enable_app_annotation_key = f"enable_app_annotation_{str(app_id)}"
         cache_result = redis_client.get(enable_app_annotation_key)
         if cache_result is not None:
@@ -109,7 +134,7 @@ class AppAnnotationService:
         return {"job_id": job_id, "job_status": "waiting"}

     @classmethod
-    def disable_app_annotation(cls, app_id: str):
+    def disable_app_annotation(cls, app_id: str) -> AnnotationJobStatusDict:
         _, current_tenant_id = current_account_with_tenant()
         disable_app_annotation_key = f"disable_app_annotation_{str(app_id)}"
         cache_result = redis_client.get(disable_app_annotation_key)
@@ -128,10 +153,8 @@ class AppAnnotationService:
     def get_annotation_list_by_app_id(cls, app_id: str, page: int, limit: int, keyword: str):
         # get app info
         _, current_tenant_id = current_account_with_tenant()
-        app = (
-            db.session.query(App)
-            .where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal")
-            .first()
+        app = db.session.scalar(
+            select(App).where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal").limit(1)
         )

         if not app:
@@ -170,20 +193,17 @@ class AppAnnotationService:
         """
         # get app info
         _, current_tenant_id = current_account_with_tenant()
-        app = (
-            db.session.query(App)
-            .where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal")
-            .first()
+        app = db.session.scalar(
+            select(App).where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal").limit(1)
         )

         if not app:
             raise NotFound("App not found")
-        annotations = (
-            db.session.query(MessageAnnotation)
+        annotations = db.session.scalars(
+            select(MessageAnnotation)
             .where(MessageAnnotation.app_id == app_id)
             .order_by(MessageAnnotation.created_at.desc())
-            .all()
-        )
+        ).all()

         # Sanitize CSV-injectable fields to prevent formula injection
         for annotation in annotations:
@@ -200,10 +220,8 @@ class AppAnnotationService:
     def insert_app_annotation_directly(cls, args: dict, app_id: str) -> MessageAnnotation:
         # get app info
         current_user, current_tenant_id = current_account_with_tenant()
-        app = (
-            db.session.query(App)
-            .where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal")
-            .first()
+        app = db.session.scalar(
+            select(App).where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal").limit(1)
         )

         if not app:
@@ -219,7 +237,9 @@ class AppAnnotationService:
         db.session.add(annotation)
         db.session.commit()
         # if annotation reply is enabled , add annotation to index
-        annotation_setting = db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first()
+        annotation_setting = db.session.scalar(
+            select(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).limit(1)
+        )
         if annotation_setting:
             add_annotation_to_index_task.delay(
                 annotation.id,
@@ -234,16 +254,14 @@ class AppAnnotationService:
     def update_app_annotation_directly(cls, args: dict, app_id: str, annotation_id: str):
         # get app info
         _, current_tenant_id = current_account_with_tenant()
-        app = (
-            db.session.query(App)
-            .where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal")
-            .first()
+        app = db.session.scalar(
+            select(App).where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal").limit(1)
         )

         if not app:
             raise NotFound("App not found")

-        annotation = db.session.query(MessageAnnotation).where(MessageAnnotation.id == annotation_id).first()
+        annotation = db.session.get(MessageAnnotation, annotation_id)

         if not annotation:
             raise NotFound("Annotation not found")
@@ -257,8 +275,8 @@ class AppAnnotationService:

         db.session.commit()
         # if annotation reply is enabled , add annotation to index
-        app_annotation_setting = (
-            db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first()
+        app_annotation_setting = db.session.scalar(
+            select(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).limit(1)
         )

         if app_annotation_setting:
@@ -276,16 +294,14 @@ class AppAnnotationService:
     def delete_app_annotation(cls, app_id: str, annotation_id: str):
         # get app info
         _, current_tenant_id = current_account_with_tenant()
-        app = (
-            db.session.query(App)
-            .where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal")
-            .first()
+        app = db.session.scalar(
+            select(App).where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal").limit(1)
         )

         if not app:
             raise NotFound("App not found")

-        annotation = db.session.query(MessageAnnotation).where(MessageAnnotation.id == annotation_id).first()
+        annotation = db.session.get(MessageAnnotation, annotation_id)

         if not annotation:
             raise NotFound("Annotation not found")
@@ -301,8 +317,8 @@ class AppAnnotationService:

         db.session.commit()
         # if annotation reply is enabled , delete annotation index
-        app_annotation_setting = (
-            db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first()
+        app_annotation_setting = db.session.scalar(
+            select(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).limit(1)
         )

         if app_annotation_setting:
@@ -314,22 +330,19 @@ class AppAnnotationService:
     def delete_app_annotations_in_batch(cls, app_id: str, annotation_ids: list[str]):
         # get app info
         _, current_tenant_id = current_account_with_tenant()
-        app = (
-            db.session.query(App)
-            .where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal")
-            .first()
+        app = db.session.scalar(
+            select(App).where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal").limit(1)
         )

         if not app:
             raise NotFound("App not found")

         # Fetch annotations and their settings in a single query
-        annotations_to_delete = (
-            db.session.query(MessageAnnotation, AppAnnotationSetting)
+        annotations_to_delete = db.session.execute(
+            select(MessageAnnotation, AppAnnotationSetting)
             .outerjoin(AppAnnotationSetting, MessageAnnotation.app_id == AppAnnotationSetting.app_id)
             .where(MessageAnnotation.id.in_(annotation_ids))
-            .all()
-        )
+        ).all()

         if not annotations_to_delete:
             return {"deleted_count": 0}
@@ -338,9 +351,9 @@ class AppAnnotationService:
         annotation_ids_to_delete = [annotation.id for annotation, _ in annotations_to_delete]

         # Step 2: Bulk delete hit histories in a single query
-        db.session.query(AppAnnotationHitHistory).where(
-            AppAnnotationHitHistory.annotation_id.in_(annotation_ids_to_delete)
-        ).delete(synchronize_session=False)
+        db.session.execute(
+            delete(AppAnnotationHitHistory).where(AppAnnotationHitHistory.annotation_id.in_(annotation_ids_to_delete))
+        )

         # Step 3: Trigger async tasks for search index deletion
         for annotation, annotation_setting in annotations_to_delete:
@@ -350,11 +363,10 @@ class AppAnnotationService:
             )

         # Step 4: Bulk delete annotations in a single query
-        deleted_count = (
-            db.session.query(MessageAnnotation)
-            .where(MessageAnnotation.id.in_(annotation_ids_to_delete))
-            .delete(synchronize_session=False)
+        delete_result = db.session.execute(
+            delete(MessageAnnotation).where(MessageAnnotation.id.in_(annotation_ids_to_delete))
         )
+        deleted_count = getattr(delete_result, "rowcount", 0)

         db.session.commit()
         return {"deleted_count": deleted_count}
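Note: executing a 2.0-style delete() returns a CursorResult whose rowcount reports how many rows matched; the getattr(..., "rowcount", 0) guard is only a defensive default in case a result type without that attribute comes back. Sketch:

    result = db.session.execute(
        delete(MessageAnnotation).where(MessageAnnotation.id.in_(annotation_ids_to_delete))
    )
    deleted_count = result.rowcount  # rows actually deleted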
@@ -375,10 +387,8 @@ class AppAnnotationService:

         # get app info
         current_user, current_tenant_id = current_account_with_tenant()
-        app = (
-            db.session.query(App)
-            .where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal")
-            .first()
+        app = db.session.scalar(
+            select(App).where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal").limit(1)
         )

         if not app:
@@ -499,16 +509,14 @@ class AppAnnotationService:
     def get_annotation_hit_histories(cls, app_id: str, annotation_id: str, page, limit):
         _, current_tenant_id = current_account_with_tenant()
         # get app info
-        app = (
-            db.session.query(App)
-            .where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal")
-            .first()
+        app = db.session.scalar(
+            select(App).where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal").limit(1)
         )

         if not app:
             raise NotFound("App not found")

-        annotation = db.session.query(MessageAnnotation).where(MessageAnnotation.id == annotation_id).first()
+        annotation = db.session.get(MessageAnnotation, annotation_id)

         if not annotation:
             raise NotFound("Annotation not found")
@@ -528,7 +536,7 @@ class AppAnnotationService:

     @classmethod
     def get_annotation_by_id(cls, annotation_id: str) -> MessageAnnotation | None:
-        annotation = db.session.query(MessageAnnotation).where(MessageAnnotation.id == annotation_id).first()
+        annotation = db.session.get(MessageAnnotation, annotation_id)

         if not annotation:
             return None
@@ -548,8 +556,10 @@ class AppAnnotationService:
         score: float,
     ):
         # add hit count to annotation
-        db.session.query(MessageAnnotation).where(MessageAnnotation.id == annotation_id).update(
-            {MessageAnnotation.hit_count: MessageAnnotation.hit_count + 1}, synchronize_session=False
+        db.session.execute(
+            update(MessageAnnotation)
+            .where(MessageAnnotation.id == annotation_id)
+            .values(hit_count=MessageAnnotation.hit_count + 1)
         )

         annotation_hit_history = AppAnnotationHitHistory(
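Note: pushing the increment into .values(hit_count=MessageAnnotation.hit_count + 1) keeps the arithmetic inside the UPDATE itself, so concurrent hits cannot race a read-modify-write done in Python. Roughly the SQL emitted (table and column names assumed from the model, shown for illustration only):

    UPDATE message_annotations
    SET hit_count = hit_count + 1
    WHERE id = :annotation_id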
@@ -567,19 +577,19 @@ class AppAnnotationService:
         db.session.commit()

     @classmethod
-    def get_app_annotation_setting_by_app_id(cls, app_id: str):
+    def get_app_annotation_setting_by_app_id(cls, app_id: str) -> AnnotationSettingDict | AnnotationSettingDisabledDict:
         _, current_tenant_id = current_account_with_tenant()
         # get app info
-        app = (
-            db.session.query(App)
-            .where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal")
-            .first()
+        app = db.session.scalar(
+            select(App).where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal").limit(1)
         )

         if not app:
             raise NotFound("App not found")

-        annotation_setting = db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first()
+        annotation_setting = db.session.scalar(
+            select(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).limit(1)
+        )
         if annotation_setting:
             collection_binding_detail = annotation_setting.collection_binding_detail
             if collection_binding_detail:
@@ -602,25 +612,25 @@ class AppAnnotationService:
         return {"enabled": False}

     @classmethod
-    def update_app_annotation_setting(cls, app_id: str, annotation_setting_id: str, args: dict):
+    def update_app_annotation_setting(
+        cls, app_id: str, annotation_setting_id: str, args: dict
+    ) -> AnnotationSettingDict:
         current_user, current_tenant_id = current_account_with_tenant()
         # get app info
-        app = (
-            db.session.query(App)
-            .where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal")
-            .first()
+        app = db.session.scalar(
+            select(App).where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal").limit(1)
         )

         if not app:
             raise NotFound("App not found")

-        annotation_setting = (
-            db.session.query(AppAnnotationSetting)
+        annotation_setting = db.session.scalar(
+            select(AppAnnotationSetting)
             .where(
                 AppAnnotationSetting.app_id == app_id,
                 AppAnnotationSetting.id == annotation_setting_id,
             )
-            .first()
+            .limit(1)
         )
         if not annotation_setting:
             raise NotFound("App annotation not found")
@@ -653,26 +663,26 @@ class AppAnnotationService:
     @classmethod
     def clear_all_annotations(cls, app_id: str):
         _, current_tenant_id = current_account_with_tenant()
-        app = (
-            db.session.query(App)
-            .where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal")
-            .first()
+        app = db.session.scalar(
+            select(App).where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal").limit(1)
         )

         if not app:
             raise NotFound("App not found")

         # if annotation reply is enabled, delete annotation index
-        app_annotation_setting = (
-            db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first()
+        app_annotation_setting = db.session.scalar(
+            select(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).limit(1)
         )

-        annotations_query = db.session.query(MessageAnnotation).where(MessageAnnotation.app_id == app_id)
-        for annotation in annotations_query.yield_per(100):
-            annotation_hit_histories_query = db.session.query(AppAnnotationHitHistory).where(
-                AppAnnotationHitHistory.annotation_id == annotation.id
-            )
-            for annotation_hit_history in annotation_hit_histories_query.yield_per(100):
+        annotations_iter = db.session.scalars(
+            select(MessageAnnotation).where(MessageAnnotation.app_id == app_id)
+        ).yield_per(100)
+        for annotation in annotations_iter:
+            hit_histories_iter = db.session.scalars(
+                select(AppAnnotationHitHistory).where(AppAnnotationHitHistory.annotation_id == annotation.id)
+            ).yield_per(100)
+            for annotation_hit_history in hit_histories_iter:
                 db.session.delete(annotation_hit_history)

         # if annotation reply is enabled, delete annotation index

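Note: db.session.scalars(stmt).yield_per(100) streams rows in batches of 100 instead of materializing the whole result up front, mirroring what the legacy Query.yield_per() did. Sketch of the pattern:

    annotations_iter = db.session.scalars(
        select(MessageAnnotation).where(MessageAnnotation.app_id == app_id)
    ).yield_per(100)
    for annotation in annotations_iter:
        ...  # rows are fetched 100 at a time as the loop advances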
@@ -118,139 +118,143 @@ class AppGenerateService:
         try:
             request_id = rate_limit.enter(request_id)
             quota_charge.commit()
-            if app_model.mode == AppMode.COMPLETION:
-                return rate_limit.generate(
-                    CompletionAppGenerator.convert_to_event_stream(
-                        CompletionAppGenerator().generate(
-                            app_model=app_model, user=user, args=args, invoke_from=invoke_from, streaming=streaming
-                        ),
-                    ),
-                    request_id=request_id,
-                )
-            elif app_model.mode == AppMode.AGENT_CHAT or app_model.is_agent:
-                return rate_limit.generate(
-                    AgentChatAppGenerator.convert_to_event_stream(
-                        AgentChatAppGenerator().generate(
-                            app_model=app_model, user=user, args=args, invoke_from=invoke_from, streaming=streaming
-                        ),
-                    ),
-                    request_id,
-                )
-            elif app_model.mode == AppMode.CHAT:
-                return rate_limit.generate(
-                    ChatAppGenerator.convert_to_event_stream(
-                        ChatAppGenerator().generate(
-                            app_model=app_model, user=user, args=args, invoke_from=invoke_from, streaming=streaming
-                        ),
-                    ),
-                    request_id=request_id,
-                )
-            elif app_model.mode == AppMode.ADVANCED_CHAT:
-                workflow_id = args.get("workflow_id")
-                workflow = cls._get_workflow(app_model, invoke_from, workflow_id)
-
-                if streaming:
-                    # Streaming mode: subscribe to SSE and enqueue the execution on first subscriber
-                    with rate_limit_context(rate_limit, request_id):
-                        payload = AppExecutionParams.new(
-                            app_model=app_model,
-                            workflow=workflow,
-                            user=user,
-                            args=args,
-                            invoke_from=invoke_from,
-                            streaming=True,
-                            call_depth=0,
-                        )
-                        payload_json = payload.model_dump_json()
-
-                        def on_subscribe():
-                            workflow_based_app_execution_task.delay(payload_json)
-
-                        on_subscribe = cls._build_streaming_task_on_subscribe(on_subscribe)
-                        generator = AdvancedChatAppGenerator()
-                        return rate_limit.generate(
-                            generator.convert_to_event_stream(
-                                generator.retrieve_events(
-                                    AppMode.ADVANCED_CHAT,
-                                    payload.workflow_run_id,
-                                    on_subscribe=on_subscribe,
-                                ),
-                            ),
-                            request_id=request_id,
-                        )
-                else:
-                    # Blocking mode: run synchronously and return JSON instead of SSE
-                    # Keep behaviour consistent with WORKFLOW blocking branch.
-                    advanced_generator = AdvancedChatAppGenerator()
-                    return rate_limit.generate(
-                        advanced_generator.convert_to_event_stream(
-                            advanced_generator.generate(
-                                app_model=app_model,
-                                workflow=workflow,
-                                user=user,
-                                args=args,
-                                invoke_from=invoke_from,
-                                workflow_run_id=str(uuid.uuid4()),
-                                streaming=False,
-                            )
-                        ),
-                        request_id=request_id,
-                    )
-            elif app_model.mode == AppMode.WORKFLOW:
-                workflow_id = args.get("workflow_id")
-                workflow = cls._get_workflow(app_model, invoke_from, workflow_id)
-                if streaming:
-                    with rate_limit_context(rate_limit, request_id):
-                        payload = AppExecutionParams.new(
-                            app_model=app_model,
-                            workflow=workflow,
-                            user=user,
-                            args=args,
-                            invoke_from=invoke_from,
-                            streaming=True,
-                            call_depth=0,
-                            root_node_id=root_node_id,
-                            workflow_run_id=str(uuid.uuid4()),
-                        )
-                        payload_json = payload.model_dump_json()
-
-                        def on_subscribe():
-                            workflow_based_app_execution_task.delay(payload_json)
-
-                        on_subscribe = cls._build_streaming_task_on_subscribe(on_subscribe)
-                        return rate_limit.generate(
-                            WorkflowAppGenerator.convert_to_event_stream(
-                                MessageBasedAppGenerator.retrieve_events(
-                                    AppMode.WORKFLOW,
-                                    payload.workflow_run_id,
-                                    on_subscribe=on_subscribe,
-                                ),
-                            ),
-                            request_id,
-                        )
-
-                pause_config = PauseStateLayerConfig(
-                    session_factory=session_factory.get_session_maker(),
-                    state_owner_user_id=workflow.created_by,
-                )
-                return rate_limit.generate(
-                    WorkflowAppGenerator.convert_to_event_stream(
-                        WorkflowAppGenerator().generate(
-                            app_model=app_model,
-                            workflow=workflow,
-                            user=user,
-                            args=args,
-                            invoke_from=invoke_from,
-                            streaming=False,
-                            root_node_id=root_node_id,
-                            call_depth=0,
-                            pause_state_config=pause_config,
-                        ),
-                    ),
-                    request_id,
-                )
-            else:
-                raise ValueError(f"Invalid app mode {app_model.mode}")
+            effective_mode = (
+                AppMode.AGENT_CHAT if app_model.is_agent and app_model.mode != AppMode.AGENT_CHAT else app_model.mode
+            )
+            match effective_mode:
+                case AppMode.COMPLETION:
+                    return rate_limit.generate(
+                        CompletionAppGenerator.convert_to_event_stream(
+                            CompletionAppGenerator().generate(
+                                app_model=app_model, user=user, args=args, invoke_from=invoke_from, streaming=streaming
+                            ),
+                        ),
+                        request_id=request_id,
+                    )
+                case AppMode.AGENT_CHAT:
+                    return rate_limit.generate(
+                        AgentChatAppGenerator.convert_to_event_stream(
+                            AgentChatAppGenerator().generate(
+                                app_model=app_model, user=user, args=args, invoke_from=invoke_from, streaming=streaming
+                            ),
+                        ),
+                        request_id,
+                    )
+                case AppMode.CHAT:
+                    return rate_limit.generate(
+                        ChatAppGenerator.convert_to_event_stream(
+                            ChatAppGenerator().generate(
+                                app_model=app_model, user=user, args=args, invoke_from=invoke_from, streaming=streaming
+                            ),
+                        ),
+                        request_id=request_id,
+                    )
+                case AppMode.ADVANCED_CHAT:
+                    workflow_id = args.get("workflow_id")
+                    workflow = cls._get_workflow(app_model, invoke_from, workflow_id)
+
+                    if streaming:
+                        # Streaming mode: subscribe to SSE and enqueue the execution on first subscriber
+                        with rate_limit_context(rate_limit, request_id):
+                            payload = AppExecutionParams.new(
+                                app_model=app_model,
+                                workflow=workflow,
+                                user=user,
+                                args=args,
+                                invoke_from=invoke_from,
+                                streaming=True,
+                                call_depth=0,
+                            )
+                            payload_json = payload.model_dump_json()
+
+                            def on_subscribe():
+                                workflow_based_app_execution_task.delay(payload_json)
+
+                            on_subscribe = cls._build_streaming_task_on_subscribe(on_subscribe)
+                            generator = AdvancedChatAppGenerator()
+                            return rate_limit.generate(
+                                generator.convert_to_event_stream(
+                                    generator.retrieve_events(
+                                        AppMode.ADVANCED_CHAT,
+                                        payload.workflow_run_id,
+                                        on_subscribe=on_subscribe,
+                                    ),
+                                ),
+                                request_id=request_id,
+                            )
+                    else:
+                        # Blocking mode: run synchronously and return JSON instead of SSE
+                        # Keep behaviour consistent with WORKFLOW blocking branch.
+                        advanced_generator = AdvancedChatAppGenerator()
+                        return rate_limit.generate(
+                            advanced_generator.convert_to_event_stream(
+                                advanced_generator.generate(
+                                    app_model=app_model,
+                                    workflow=workflow,
+                                    user=user,
+                                    args=args,
+                                    invoke_from=invoke_from,
+                                    workflow_run_id=str(uuid.uuid4()),
+                                    streaming=False,
+                                )
+                            ),
+                            request_id=request_id,
+                        )
+                case AppMode.WORKFLOW:
+                    workflow_id = args.get("workflow_id")
+                    workflow = cls._get_workflow(app_model, invoke_from, workflow_id)
+                    if streaming:
+                        with rate_limit_context(rate_limit, request_id):
+                            payload = AppExecutionParams.new(
+                                app_model=app_model,
+                                workflow=workflow,
+                                user=user,
+                                args=args,
+                                invoke_from=invoke_from,
+                                streaming=True,
+                                call_depth=0,
+                                root_node_id=root_node_id,
+                                workflow_run_id=str(uuid.uuid4()),
+                            )
+                            payload_json = payload.model_dump_json()
+
+                            def on_subscribe():
+                                workflow_based_app_execution_task.delay(payload_json)
+
+                            on_subscribe = cls._build_streaming_task_on_subscribe(on_subscribe)
+                            return rate_limit.generate(
+                                WorkflowAppGenerator.convert_to_event_stream(
+                                    MessageBasedAppGenerator.retrieve_events(
+                                        AppMode.WORKFLOW,
+                                        payload.workflow_run_id,
+                                        on_subscribe=on_subscribe,
+                                    ),
+                                ),
+                                request_id,
+                            )
+
+                    pause_config = PauseStateLayerConfig(
+                        session_factory=session_factory.get_session_maker(),
+                        state_owner_user_id=workflow.created_by,
+                    )
+                    return rate_limit.generate(
+                        WorkflowAppGenerator.convert_to_event_stream(
+                            WorkflowAppGenerator().generate(
+                                app_model=app_model,
+                                workflow=workflow,
+                                user=user,
+                                args=args,
+                                invoke_from=invoke_from,
+                                streaming=False,
+                                root_node_id=root_node_id,
+                                call_depth=0,
+                                pause_state_config=pause_config,
+                            ),
+                        ),
+                        request_id,
+                    )
+                case _:
+                    raise ValueError(f"Invalid app mode {app_model.mode}")
         except Exception:
             quota_charge.refund()
             rate_limit.exit(request_id)
@@ -282,43 +286,73 @@ class AppGenerateService:

     @classmethod
     def generate_single_iteration(cls, app_model: App, user: Account, node_id: str, args: Any, streaming: bool = True):
-        if app_model.mode == AppMode.ADVANCED_CHAT:
-            workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER)
-            return AdvancedChatAppGenerator.convert_to_event_stream(
-                AdvancedChatAppGenerator().single_iteration_generate(
-                    app_model=app_model, workflow=workflow, node_id=node_id, user=user, args=args, streaming=streaming
-                )
-            )
-        elif app_model.mode == AppMode.WORKFLOW:
-            workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER)
-            return AdvancedChatAppGenerator.convert_to_event_stream(
-                WorkflowAppGenerator().single_iteration_generate(
-                    app_model=app_model, workflow=workflow, node_id=node_id, user=user, args=args, streaming=streaming
-                )
-            )
-        else:
-            raise ValueError(f"Invalid app mode {app_model.mode}")
+        match app_model.mode:
+            case AppMode.COMPLETION | AppMode.CHAT | AppMode.AGENT_CHAT:
+                raise ValueError(f"Invalid app mode {app_model.mode}")
+            case AppMode.ADVANCED_CHAT:
+                workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER)
+                return AdvancedChatAppGenerator.convert_to_event_stream(
+                    AdvancedChatAppGenerator().single_iteration_generate(
+                        app_model=app_model,
+                        workflow=workflow,
+                        node_id=node_id,
+                        user=user,
+                        args=args,
+                        streaming=streaming,
+                    )
+                )
+            case AppMode.WORKFLOW:
+                workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER)
+                return AdvancedChatAppGenerator.convert_to_event_stream(
+                    WorkflowAppGenerator().single_iteration_generate(
+                        app_model=app_model,
+                        workflow=workflow,
+                        node_id=node_id,
+                        user=user,
+                        args=args,
+                        streaming=streaming,
+                    )
+                )
+            case AppMode.CHANNEL | AppMode.RAG_PIPELINE:
+                raise ValueError(f"Invalid app mode {app_model.mode}")
+            case _:
+                raise ValueError(f"Invalid app mode {app_model.mode}")

     @classmethod
     def generate_single_loop(
         cls, app_model: App, user: Account, node_id: str, args: LoopNodeRunPayload, streaming: bool = True
     ):
-        if app_model.mode == AppMode.ADVANCED_CHAT:
-            workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER)
-            return AdvancedChatAppGenerator.convert_to_event_stream(
-                AdvancedChatAppGenerator().single_loop_generate(
-                    app_model=app_model, workflow=workflow, node_id=node_id, user=user, args=args, streaming=streaming
-                )
-            )
-        elif app_model.mode == AppMode.WORKFLOW:
-            workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER)
-            return AdvancedChatAppGenerator.convert_to_event_stream(
-                WorkflowAppGenerator().single_loop_generate(
-                    app_model=app_model, workflow=workflow, node_id=node_id, user=user, args=args, streaming=streaming
-                )
-            )
-        else:
-            raise ValueError(f"Invalid app mode {app_model.mode}")
+        match app_model.mode:
+            case AppMode.COMPLETION | AppMode.CHAT | AppMode.AGENT_CHAT:
+                raise ValueError(f"Invalid app mode {app_model.mode}")
+            case AppMode.ADVANCED_CHAT:
+                workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER)
+                return AdvancedChatAppGenerator.convert_to_event_stream(
+                    AdvancedChatAppGenerator().single_loop_generate(
+                        app_model=app_model,
+                        workflow=workflow,
+                        node_id=node_id,
+                        user=user,
+                        args=args,
+                        streaming=streaming,
+                    )
+                )
+            case AppMode.WORKFLOW:
+                workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER)
+                return AdvancedChatAppGenerator.convert_to_event_stream(
+                    WorkflowAppGenerator().single_loop_generate(
+                        app_model=app_model,
+                        workflow=workflow,
+                        node_id=node_id,
+                        user=user,
+                        args=args,
+                        streaming=streaming,
+                    )
+                )
+            case AppMode.CHANNEL | AppMode.RAG_PIPELINE:
+                raise ValueError(f"Invalid app mode {app_model.mode}")
+            case _:
+                raise ValueError(f"Invalid app mode {app_model.mode}")

     @classmethod
     def generate_more_like_this(

@@ -7,11 +7,12 @@ from models.model import AppMode, AppModelConfigDict
 class AppModelConfigService:
     @classmethod
     def validate_configuration(cls, tenant_id: str, config: dict, app_mode: AppMode) -> AppModelConfigDict:
-        if app_mode == AppMode.CHAT:
-            return ChatAppConfigManager.config_validate(tenant_id, config)
-        elif app_mode == AppMode.AGENT_CHAT:
-            return AgentChatAppConfigManager.config_validate(tenant_id, config)
-        elif app_mode == AppMode.COMPLETION:
-            return CompletionAppConfigManager.config_validate(tenant_id, config)
-        else:
-            raise ValueError(f"Invalid app mode: {app_mode}")
+        match app_mode:
+            case AppMode.CHAT:
+                return ChatAppConfigManager.config_validate(tenant_id, config)
+            case AppMode.AGENT_CHAT:
+                return AgentChatAppConfigManager.config_validate(tenant_id, config)
+            case AppMode.COMPLETION:
+                return CompletionAppConfigManager.config_validate(tenant_id, config)
+            case AppMode.WORKFLOW | AppMode.ADVANCED_CHAT | AppMode.CHANNEL | AppMode.RAG_PIPELINE:
+                raise ValueError(f"Invalid app mode: {app_mode}")
@@ -11,7 +11,7 @@ from typing import Any, Union

 from celery.result import AsyncResult
 from sqlalchemy import select
-from sqlalchemy.orm import Session
+from sqlalchemy.orm import Session, sessionmaker

 from enums.quota_type import QuotaType
 from extensions.ext_database import db
@@ -244,7 +244,7 @@ class AsyncWorkflowService:
         Returns:
             Trigger log as dictionary or None if not found
         """
-        with Session(db.engine) as session:
+        with sessionmaker(db.engine).begin() as session:
             trigger_log_repo = SQLAlchemyWorkflowTriggerLogRepository(session)
             trigger_log = trigger_log_repo.get_by_id(workflow_trigger_log_id, tenant_id)

@@ -270,7 +270,7 @@ class AsyncWorkflowService:
         Returns:
             List of trigger logs as dictionaries
         """
-        with Session(db.engine) as session:
+        with sessionmaker(db.engine).begin() as session:
             trigger_log_repo = SQLAlchemyWorkflowTriggerLogRepository(session)
             logs = trigger_log_repo.get_recent_logs(
                 tenant_id=tenant_id, app_id=app_id, hours=hours, limit=limit, offset=offset
@@ -293,7 +293,7 @@ class AsyncWorkflowService:
         Returns:
             List of failed trigger logs as dictionaries
         """
-        with Session(db.engine) as session:
+        with sessionmaker(db.engine).begin() as session:
             trigger_log_repo = SQLAlchemyWorkflowTriggerLogRepository(session)
             logs = trigger_log_repo.get_failed_for_retry(
                 tenant_id=tenant_id, max_retry_count=max_retry_count, limit=limit
@@ -1,6 +1,6 @@
 import base64

-from sqlalchemy import Engine
+from sqlalchemy import Engine, select
 from sqlalchemy.orm import sessionmaker
 from werkzeug.exceptions import NotFound

@@ -22,8 +22,8 @@ class AttachmentService:
             raise AssertionError("must be a sessionmaker or an Engine.")

     def get_file_base64(self, file_id: str) -> str:
-        upload_file = (
-            self._session_maker(expire_on_commit=False).query(UploadFile).where(UploadFile.id == file_id).first()
+        upload_file = self._session_maker(expire_on_commit=False).scalar(
+            select(UploadFile).where(UploadFile.id == file_id).limit(1)
         )
         if not upload_file:
             raise NotFound("File not found")
@@ -107,6 +107,124 @@ class BillingInfo(TypedDict):
_billing_info_adapter = TypeAdapter(BillingInfo)


class _TenantFeatureQuota(TypedDict):
usage: int
limit: int
reset_date: NotRequired[int]


class TenantFeatureQuotaInfo(TypedDict):
"""Response of /quota/info.

NOTE (hj24):
- Same convention as BillingInfo: billing may return int fields as str;
always keep non-strict mode to auto-coerce.
"""

trigger_event: _TenantFeatureQuota
api_rate_limit: _TenantFeatureQuota


_tenant_feature_quota_info_adapter = TypeAdapter(TenantFeatureQuotaInfo)


class _BillingQuota(TypedDict):
size: int
limit: int


class _VectorSpaceQuota(TypedDict):
size: float
limit: int


class _KnowledgeRateLimit(TypedDict):
# NOTE (hj24):
# 1. Returned for sandbox users but null for other plans; it's defined but never used.
# 2. Keep it for compatibility for now; it can be deprecated in future versions.
size: NotRequired[int]
# NOTE END
limit: int


class _BillingSubscription(TypedDict):
plan: str
interval: str
education: bool


class BillingInfo(TypedDict):
"""Response of /subscription/info.

NOTE (hj24):
- Fields not listed here (e.g. trigger_event, api_rate_limit) are stripped by TypeAdapter.validate_python()
- To preserve precision, billing may serialize int fields as str; be careful when using TypeAdapter:
1. validate_python in non-strict mode will coerce them to the expected type
2. In strict mode, it will raise ValidationError
3. To preserve compatibility, always keep non-strict mode here and avoid strict mode
"""

enabled: bool
subscription: _BillingSubscription
members: _BillingQuota
apps: _BillingQuota
vector_space: _VectorSpaceQuota
knowledge_rate_limit: _KnowledgeRateLimit
documents_upload_quota: _BillingQuota
annotation_quota_limit: _BillingQuota
docs_processing: str
can_replace_logo: bool
model_load_balancing_enabled: bool
knowledge_pipeline_publish_enabled: bool
next_credit_reset_date: NotRequired[int]


_billing_info_adapter = TypeAdapter(BillingInfo)
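
Note: a small runnable sketch of the non-strict coercion convention documented in the NOTE above; the quota shape here mirrors _TenantFeatureQuota but is illustrative:

    from typing import NotRequired, TypedDict

    from pydantic import TypeAdapter, ValidationError

    class Quota(TypedDict):
        usage: int
        limit: int
        reset_date: NotRequired[int]

    adapter = TypeAdapter(Quota)

    # Non-strict (default): "42" is coerced to int, undeclared keys are stripped.
    quota = adapter.validate_python({"usage": "42", "limit": 100, "plan": "pro"})
    assert quota == {"usage": 42, "limit": 100}

    # Strict mode refuses the str -> int coercion instead.
    try:
        adapter.validate_python({"usage": "42", "limit": 100}, strict=True)
    except ValidationError:
        print("strict mode raises on '42'")
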
class KnowledgeRateLimitDict(TypedDict):
limit: int
subscription_plan: str


class TenantFeaturePlanUsageDict(TypedDict):
result: str
history_id: str


class LangContentDict(TypedDict):
lang: str
title: str
subtitle: str
body: str
title_pic_url: str


class NotificationDict(TypedDict):
notification_id: str
contents: dict[str, LangContentDict]
frequency: Literal["once", "every_page_load"]


class AccountNotificationDict(TypedDict, total=False):
should_show: bool
notification: NotificationDict
shouldShow: bool
notifications: list[dict]


class UpsertNotificationDict(TypedDict):
notification_id: str


class BatchAddNotificationAccountsDict(TypedDict):
count: int


class DismissNotificationDict(TypedDict):
success: bool


class BillingService:
base_url = os.environ.get("BILLING_API_URL", "BILLING_API_URL")
secret_key = os.environ.get("BILLING_API_SECRET_KEY", "BILLING_API_SECRET_KEY")
@@ -133,9 +251,11 @@ class BillingService:
return usage_info

@classmethod
def get_quota_info(cls, tenant_id: str):
def get_quota_info(cls, tenant_id: str) -> TenantFeatureQuotaInfo:
params = {"tenant_id": tenant_id}
return cls._send_request("GET", "/quota/info", params=params)
return _tenant_feature_quota_info_adapter.validate_python(
cls._send_request("GET", "/quota/info", params=params)
)

@classmethod
def quota_reserve(
@@ -183,7 +303,7 @@ class BillingService:
)

@classmethod
def get_knowledge_rate_limit(cls, tenant_id: str):
def get_knowledge_rate_limit(cls, tenant_id: str) -> KnowledgeRateLimitDict:
params = {"tenant_id": tenant_id}

knowledge_rate_limit = cls._send_request("GET", "/subscription/knowledge-rate-limit", params=params)
@@ -214,7 +334,9 @@ class BillingService:
return cls._send_request("GET", "/invoices", params=params)

@classmethod
def update_tenant_feature_plan_usage(cls, tenant_id: str, feature_key: str, delta: int) -> dict:
def update_tenant_feature_plan_usage(
cls, tenant_id: str, feature_key: str, delta: int
) -> TenantFeaturePlanUsageDict:
"""
Update tenant feature plan usage.

@@ -234,7 +356,7 @@ class BillingService:
)

@classmethod
def refund_tenant_feature_plan_usage(cls, history_id: str) -> dict:
def refund_tenant_feature_plan_usage(cls, history_id: str) -> TenantFeaturePlanUsageDict:
"""
Refund a previous usage charge.

@@ -530,7 +652,7 @@ class BillingService:
return tenant_whitelist

@classmethod
def get_account_notification(cls, account_id: str) -> dict:
def get_account_notification(cls, account_id: str) -> AccountNotificationDict:
"""Return the active in-product notification for account_id, if any.

Calling this endpoint also marks the notification as seen; subsequent
@@ -554,13 +676,13 @@ class BillingService:
@classmethod
def upsert_notification(
cls,
contents: list[dict],
contents: list[LangContentDict],
frequency: str = "once",
status: str = "active",
notification_id: str | None = None,
start_time: str | None = None,
end_time: str | None = None,
) -> dict:
) -> UpsertNotificationDict:
"""Create or update a notification.

contents: list of {"lang": str, "title": str, "subtitle": str, "body": str, "title_pic_url": str}
@@ -581,7 +703,9 @@ class BillingService:
return cls._send_request("POST", "/notifications", json=payload)

@classmethod
def batch_add_notification_accounts(cls, notification_id: str, account_ids: list[str]) -> dict:
def batch_add_notification_accounts(
cls, notification_id: str, account_ids: list[str]
) -> BatchAddNotificationAccountsDict:
"""Register target account IDs for a notification (max 1000 per call).

Returns {"count": int}.
@@ -593,7 +717,7 @@ class BillingService:
)

@classmethod
def dismiss_notification(cls, notification_id: str, account_id: str) -> dict:
def dismiss_notification(cls, notification_id: str, account_id: str) -> DismissNotificationDict:
"""Mark a notification as dismissed for an account.

Returns {"success": bool}.
@@ -346,7 +346,7 @@ class ClearFreePlanTenantExpiredLogs:
started_at = datetime.datetime(2023, 4, 3, 8, 59, 24)
current_time = started_at

with Session(db.engine) as session:
with sessionmaker(db.engine).begin() as session:
total_tenant_count = session.query(Tenant.id).count()

click.echo(click.style(f"Total tenant count: {total_tenant_count}", fg="white"))
@@ -398,7 +398,7 @@ class ClearFreePlanTenantExpiredLogs:
# Initial interval of 1 day, will be dynamically adjusted based on tenant count
interval = datetime.timedelta(days=1)
# Process tenants in this batch
with Session(db.engine) as session:
with sessionmaker(db.engine).begin() as session:
# Calculate tenant count in next batch with current interval
# Try different intervals until we find one with a reasonable tenant count
test_intervals = [
@@ -1,7 +1,7 @@
import logging

from sqlalchemy import select, update
from sqlalchemy.orm import Session
from sqlalchemy.orm import sessionmaker

from configs import dify_config
from core.errors.error import QuotaExceededError
@@ -71,7 +71,7 @@ class CreditPoolService:
actual_credits = min(credits_required, pool.remaining_credits)

try:
with Session(db.engine) as session:
with sessionmaker(db.engine).begin() as session:
stmt = (
update(TenantCreditPool)
.where(
@@ -81,7 +81,6 @@ class CreditPoolService:
.values(quota_used=TenantCreditPool.quota_used + actual_credits)
)
session.execute(stmt)
session.commit()
except Exception:
logger.exception("Failed to deduct credits for tenant %s", tenant_id)
raise QuotaExceededError("Failed to deduct credits")
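
Note: the UPDATE above sets quota_used relative to its current value, so the increment happens inside the database in one statement instead of a Python-side read-modify-write that could lose concurrent updates (and .begin() now owns the commit, which is why the explicit session.commit() is dropped). A self-contained sketch with an illustrative stand-in model:

    from sqlalchemy import Integer, String, create_engine, select, update
    from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column

    class Base(DeclarativeBase):
        pass

    class CreditPool(Base):  # stand-in for TenantCreditPool
        __tablename__ = "credit_pools"
        tenant_id: Mapped[str] = mapped_column(String, primary_key=True)
        quota_used: Mapped[int] = mapped_column(Integer, default=0)

    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)

    with Session(engine) as session:
        session.add(CreditPool(tenant_id="t1"))
        session.commit()
        # Emits: UPDATE credit_pools SET quota_used = quota_used + 5 WHERE ...
        session.execute(
            update(CreditPool)
            .where(CreditPool.tenant_id == "t1")
            .values(quota_used=CreditPool.quota_used + 5)
        )
        session.commit()
        assert session.scalar(select(CreditPool.quota_used)) == 5
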
@@ -7,15 +7,15 @@ import time
import uuid
from collections import Counter
from collections.abc import Sequence
from typing import Any, Literal, cast
from typing import Any, Literal, TypedDict, cast

import sqlalchemy as sa
from graphon.file import helpers as file_helpers
from graphon.model_runtime.entities.model_entities import ModelFeature, ModelType
from graphon.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from redis.exceptions import LockNotOwnedError
from sqlalchemy import exists, func, select
from sqlalchemy.orm import Session
from sqlalchemy import delete, exists, func, select, update
from sqlalchemy.orm import Session, sessionmaker
from werkzeug.exceptions import Forbidden, NotFound

from configs import dify_config
@@ -107,6 +107,16 @@ from tasks.sync_website_document_indexing_task import sync_website_document_inde
logger = logging.getLogger(__name__)


class ProcessRulesDict(TypedDict):
mode: str
rules: dict[str, Any]


class AutoDisableLogsDict(TypedDict):
document_ids: list[str]
count: int


class DatasetService:
@staticmethod
def get_datasets(page, per_page, tenant_id=None, user=None, search=None, tag_ids=None, include_all=False):
@@ -114,9 +124,11 @@ class DatasetService:

if user:
# get permitted dataset ids
dataset_permission = (
db.session.query(DatasetPermission).filter_by(account_id=user.id, tenant_id=tenant_id).all()
)
dataset_permission = db.session.scalars(
select(DatasetPermission).where(
DatasetPermission.account_id == user.id, DatasetPermission.tenant_id == tenant_id
)
).all()
permitted_dataset_ids = {dp.dataset_id for dp in dataset_permission} if dataset_permission else None

if user.current_role == TenantAccountRole.DATASET_OPERATOR:
@@ -180,21 +192,20 @@ class DatasetService:
return datasets.items, datasets.total

@staticmethod
def get_process_rules(dataset_id):
def get_process_rules(dataset_id) -> ProcessRulesDict:
# get the latest process rule
dataset_process_rule = (
db.session.query(DatasetProcessRule)
dataset_process_rule = db.session.execute(
select(DatasetProcessRule)
.where(DatasetProcessRule.dataset_id == dataset_id)
.order_by(DatasetProcessRule.created_at.desc())
.limit(1)
.one_or_none()
)
).scalar_one_or_none()
if dataset_process_rule:
mode = dataset_process_rule.mode
rules = dataset_process_rule.rules_dict
rules = dataset_process_rule.rules_dict or {}
else:
mode = DocumentService.DEFAULT_RULES["mode"]
rules = DocumentService.DEFAULT_RULES["rules"]
mode = str(DocumentService.DEFAULT_RULES["mode"])
rules = dict(DocumentService.DEFAULT_RULES.get("rules") or {})
return {"mode": mode, "rules": rules}
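
Note: get_process_rules now goes through execute(...).scalar_one_or_none() rather than scalar(). The two differ in how they treat multiple rows, which is why the statement keeps .limit(1). A sketch with an illustrative model:

    from sqlalchemy import String, create_engine, select
    from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column

    class Base(DeclarativeBase):
        pass

    class Rule(Base):  # stand-in for DatasetProcessRule
        __tablename__ = "rules"
        id: Mapped[str] = mapped_column(String, primary_key=True)

    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)

    with Session(engine) as session:
        stmt = select(Rule).limit(1)
        # scalar() returns the first result or None, silently ignoring extras.
        first = session.scalar(stmt)
        # scalar_one_or_none() returns None for zero rows but raises
        # MultipleResultsFound for more than one; .limit(1) keeps it safe.
        strict = session.execute(stmt).scalar_one_or_none()
        assert first is None and strict is None
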
@staticmethod
@@ -225,7 +236,7 @@ class DatasetService:
summary_index_setting: dict | None = None,
):
# check if dataset name already exists
if db.session.query(Dataset).filter_by(name=name, tenant_id=tenant_id).first():
if db.session.scalar(select(Dataset).where(Dataset.name == name, Dataset.tenant_id == tenant_id).limit(1)):
raise DatasetNameDuplicateError(f"Dataset with name {name} already exists.")
embedding_model = None
if indexing_technique == IndexTechniqueType.HIGH_QUALITY:
@@ -300,17 +311,17 @@ class DatasetService:
):
if rag_pipeline_dataset_create_entity.name:
# check if dataset name already exists
if (
db.session.query(Dataset)
.filter_by(name=rag_pipeline_dataset_create_entity.name, tenant_id=tenant_id)
.first()
if db.session.scalar(
select(Dataset)
.where(Dataset.name == rag_pipeline_dataset_create_entity.name, Dataset.tenant_id == tenant_id)
.limit(1)
):
raise DatasetNameDuplicateError(
f"Dataset with name {rag_pipeline_dataset_create_entity.name} already exists."
)
else:
# generate an incremental name: Untitled 1, 2, 3 ...
datasets = db.session.query(Dataset).filter_by(tenant_id=tenant_id).all()
datasets = db.session.scalars(select(Dataset).where(Dataset.tenant_id == tenant_id)).all()
names = [dataset.name for dataset in datasets]
rag_pipeline_dataset_create_entity.name = generate_incremental_name(
names,
@@ -344,7 +355,7 @@ class DatasetService:

@staticmethod
def get_dataset(dataset_id) -> Dataset | None:
dataset: Dataset | None = db.session.query(Dataset).filter_by(id=dataset_id).first()
dataset: Dataset | None = db.session.get(Dataset, dataset_id)
return dataset
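
Note: replacing query(...).filter_by(id=...).first() with db.session.get(Model, pk) is more than sugar: get() consults the session's identity map before touching the database. A minimal sketch (the Dataset model below is a stand-in):

    from sqlalchemy import String, create_engine
    from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column

    class Base(DeclarativeBase):
        pass

    class Dataset(Base):  # stand-in for the real model
        __tablename__ = "datasets"
        id: Mapped[str] = mapped_column(String, primary_key=True)

    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)

    with Session(engine) as session:
        ds = Dataset(id="d1")
        session.add(ds)
        session.flush()
        # Identity-map hit: the already-loaded instance is returned, no new SQL.
        assert session.get(Dataset, "d1") is ds
        # Unknown primary keys return None rather than raising.
        assert session.get(Dataset, "missing") is None
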
@staticmethod
@@ -466,14 +477,14 @@ class DatasetService:

@staticmethod
def _has_dataset_same_name(tenant_id: str, dataset_id: str, name: str):
dataset = (
db.session.query(Dataset)
dataset = db.session.scalar(
select(Dataset)
.where(
Dataset.id != dataset_id,
Dataset.name == name,
Dataset.tenant_id == tenant_id,
)
.first()
.limit(1)
)
return dataset is not None

@@ -540,7 +551,7 @@ class DatasetService:
external_knowledge_id: External knowledge identifier
external_knowledge_api_id: External knowledge API identifier
"""
with Session(db.engine) as session:
with sessionmaker(db.engine).begin() as session:
external_knowledge_binding = (
session.query(ExternalKnowledgeBindings).filter_by(dataset_id=dataset_id).first()
)
@@ -548,14 +559,14 @@ class DatasetService:
if not external_knowledge_binding:
raise ValueError("External knowledge binding not found.")

# Update binding if values have changed
if (
external_knowledge_binding.external_knowledge_id != external_knowledge_id
or external_knowledge_binding.external_knowledge_api_id != external_knowledge_api_id
):
external_knowledge_binding.external_knowledge_id = external_knowledge_id
external_knowledge_binding.external_knowledge_api_id = external_knowledge_api_id
db.session.add(external_knowledge_binding)
# Update binding if values have changed
if (
external_knowledge_binding.external_knowledge_id != external_knowledge_id
or external_knowledge_binding.external_knowledge_api_id != external_knowledge_api_id
):
external_knowledge_binding.external_knowledge_id = external_knowledge_id
external_knowledge_binding.external_knowledge_api_id = external_knowledge_api_id
session.add(external_knowledge_binding)

@staticmethod
def _update_internal_dataset(dataset, data, user):
@@ -596,7 +607,7 @@ class DatasetService:
filtered_data["icon_info"] = data.get("icon_info")

# Update dataset in database
db.session.query(Dataset).filter_by(id=dataset.id).update(filtered_data)
db.session.execute(update(Dataset).where(Dataset.id == dataset.id).values(**filtered_data))
db.session.commit()

# Reload dataset to get updated values
@@ -631,7 +642,7 @@ class DatasetService:
if dataset.runtime_mode != DatasetRuntimeMode.RAG_PIPELINE:
return

pipeline = db.session.query(Pipeline).filter_by(id=dataset.pipeline_id).first()
pipeline = db.session.get(Pipeline, dataset.pipeline_id)
if not pipeline:
return

@@ -1138,8 +1149,10 @@ class DatasetService:
if dataset.permission == DatasetPermissionEnum.PARTIAL_TEAM:
# For partial team permission, user needs explicit permission or be the creator
if dataset.created_by != user.id:
user_permission = (
db.session.query(DatasetPermission).filter_by(dataset_id=dataset.id, account_id=user.id).first()
user_permission = db.session.scalar(
select(DatasetPermission)
.where(DatasetPermission.dataset_id == dataset.id, DatasetPermission.account_id == user.id)
.limit(1)
)
if not user_permission:
logger.debug("User %s does not have permission to access dataset %s", user.id, dataset.id)
@@ -1161,7 +1174,9 @@ class DatasetService:
elif dataset.permission == DatasetPermissionEnum.PARTIAL_TEAM:
if not any(
dp.dataset_id == dataset.id
for dp in db.session.query(DatasetPermission).filter_by(account_id=user.id).all()
for dp in db.session.scalars(
select(DatasetPermission).where(DatasetPermission.account_id == user.id)
).all()
):
raise NoPermissionError("You do not have permission to access this dataset.")

@@ -1175,12 +1190,11 @@ class DatasetService:

@staticmethod
def get_related_apps(dataset_id: str):
return (
db.session.query(AppDatasetJoin)
return db.session.scalars(
select(AppDatasetJoin)
.where(AppDatasetJoin.dataset_id == dataset_id)
.order_by(db.desc(AppDatasetJoin.created_at))
.all()
)
.order_by(AppDatasetJoin.created_at.desc())
).all()

@staticmethod
def update_dataset_api_status(dataset_id: str, status: bool):
@@ -1195,7 +1209,7 @@ class DatasetService:
db.session.commit()
@staticmethod
def get_dataset_auto_disable_logs(dataset_id: str):
def get_dataset_auto_disable_logs(dataset_id: str) -> AutoDisableLogsDict:
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
features = FeatureService.get_features(current_user.current_tenant_id)
@@ -1396,8 +1410,8 @@ class DocumentService:
@staticmethod
def get_document(dataset_id: str, document_id: str | None = None) -> Document | None:
if document_id:
document = (
db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
document = db.session.scalar(
select(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).limit(1)
)
return document
else:
@@ -1626,7 +1640,7 @@ class DocumentService:

@staticmethod
def get_document_by_id(document_id: str) -> Document | None:
document = db.session.query(Document).where(Document.id == document_id).first()
document = db.session.get(Document, document_id)

return document

@@ -1691,7 +1705,7 @@ class DocumentService:

@staticmethod
def get_document_file_detail(file_id: str):
file_detail = db.session.query(UploadFile).where(UploadFile.id == file_id).one_or_none()
file_detail = db.session.get(UploadFile, file_id)
return file_detail

@staticmethod
@@ -1765,9 +1779,11 @@ class DocumentService:
document.name = name
db.session.add(document)
if document.data_source_info_dict and "upload_file_id" in document.data_source_info_dict:
db.session.query(UploadFile).where(
UploadFile.id == document.data_source_info_dict["upload_file_id"]
).update({UploadFile.name: name})
db.session.execute(
update(UploadFile)
.where(UploadFile.id == document.data_source_info_dict["upload_file_id"])
.values(name=name)
)

db.session.commit()

@@ -1854,8 +1870,8 @@ class DocumentService:

@staticmethod
def get_documents_position(dataset_id):
document = (
db.session.query(Document).filter_by(dataset_id=dataset_id).order_by(Document.position.desc()).first()
document = db.session.scalar(
select(Document).where(Document.dataset_id == dataset_id).order_by(Document.position.desc()).limit(1)
)
if document:
return document.position + 1
@@ -2012,28 +2028,28 @@ class DocumentService:
if not knowledge_config.data_source.info_list.file_info_list:
raise ValueError("File source info is required")
upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids
files = (
db.session.query(UploadFile)
.where(
UploadFile.tenant_id == dataset.tenant_id,
UploadFile.id.in_(upload_file_list),
)
.all()
files = list(
db.session.scalars(
select(UploadFile).where(
UploadFile.tenant_id == dataset.tenant_id,
UploadFile.id.in_(upload_file_list),
)
).all()
)
if len(files) != len(set(upload_file_list)):
raise FileNotExistsError("One or more files not found.")

file_names = [file.name for file in files]
db_documents = (
db.session.query(Document)
.where(
Document.dataset_id == dataset.id,
Document.tenant_id == current_user.current_tenant_id,
Document.data_source_type == DataSourceType.UPLOAD_FILE,
Document.enabled == True,
Document.name.in_(file_names),
)
.all()
db_documents = list(
db.session.scalars(
select(Document).where(
Document.dataset_id == dataset.id,
Document.tenant_id == current_user.current_tenant_id,
Document.data_source_type == DataSourceType.UPLOAD_FILE,
Document.enabled == True,
Document.name.in_(file_names),
)
).all()
)
documents_map = {document.name: document for document in db_documents}
for file in files:
@@ -2079,15 +2095,15 @@ class DocumentService:
raise ValueError("No notion info list found.")
exist_page_ids = []
exist_document = {}
documents = (
db.session.query(Document)
.filter_by(
dataset_id=dataset.id,
tenant_id=current_user.current_tenant_id,
data_source_type=DataSourceType.NOTION_IMPORT,
enabled=True,
)
.all()
documents = list(
db.session.scalars(
select(Document).where(
Document.dataset_id == dataset.id,
Document.tenant_id == current_user.current_tenant_id,
Document.data_source_type == DataSourceType.NOTION_IMPORT,
Document.enabled == True,
)
).all()
)
if documents:
for document in documents:
@@ -2518,14 +2534,15 @@ class DocumentService:
assert isinstance(current_user, Account)

documents_count = (
db.session.query(Document)
.where(
Document.completed_at.isnot(None),
Document.enabled == True,
Document.archived == False,
Document.tenant_id == current_user.current_tenant_id,
db.session.scalar(
select(func.count(Document.id)).where(
Document.completed_at.isnot(None),
Document.enabled == True,
Document.archived == False,
Document.tenant_id == current_user.current_tenant_id,
)
)
.count()
or 0
)
return documents_count
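
Note: the count is now computed as SELECT count(...) directly via scalar(), instead of Query.count(), which wraps the original query in a subquery; COUNT always yields a row, so the "or 0" is only a guard for the Optional typing of scalar(). Sketch with a stand-in model:

    from sqlalchemy import Integer, create_engine, func, select
    from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column

    class Base(DeclarativeBase):
        pass

    class Doc(Base):  # stand-in for Document
        __tablename__ = "docs"
        id: Mapped[int] = mapped_column(Integer, primary_key=True)

    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)

    with Session(engine) as session:
        # Emits: SELECT count(docs.id) FROM docs
        n = session.scalar(select(func.count(Doc.id)))
        assert n == 0
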
@@ -2575,10 +2592,10 @@ class DocumentService:
raise ValueError("No file info list found.")
upload_file_list = document_data.data_source.info_list.file_info_list.file_ids
for file_id in upload_file_list:
file = (
db.session.query(UploadFile)
file = db.session.scalar(
select(UploadFile)
.where(UploadFile.tenant_id == dataset.tenant_id, UploadFile.id == file_id)
.first()
.limit(1)
)

# raise error if file not found
@@ -2595,8 +2612,8 @@ class DocumentService:
notion_info_list = document_data.data_source.info_list.notion_info_list
for notion_info in notion_info_list:
workspace_id = notion_info.workspace_id
data_source_binding = (
db.session.query(DataSourceOauthBinding)
data_source_binding = db.session.scalar(
select(DataSourceOauthBinding)
.where(
sa.and_(
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
@@ -2605,7 +2622,7 @@ class DocumentService:
DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"',
)
)
.first()
.limit(1)
)
if not data_source_binding:
raise ValueError("Data source binding not found.")
@@ -2650,8 +2667,10 @@ class DocumentService:
db.session.commit()
# update document segment

db.session.query(DocumentSegment).filter_by(document_id=document.id).update(
{DocumentSegment.status: SegmentStatus.RE_SEGMENT}
db.session.execute(
update(DocumentSegment)
.where(DocumentSegment.document_id == document.id)
.values(status=SegmentStatus.RE_SEGMENT)
)
db.session.commit()
# trigger async task
@@ -3143,10 +3162,8 @@ class SegmentService:
lock_name = f"add_segment_lock_document_id_{document.id}"
try:
with redis_client.lock(lock_name, timeout=600):
max_position = (
db.session.query(func.max(DocumentSegment.position))
.where(DocumentSegment.document_id == document.id)
.scalar()
max_position = db.session.scalar(
select(func.max(DocumentSegment.position)).where(DocumentSegment.document_id == document.id)
)
segment_document = DocumentSegment(
tenant_id=current_user.current_tenant_id,
@@ -3198,7 +3215,7 @@ class SegmentService:
segment_document.status = SegmentStatus.ERROR
segment_document.error = str(e)
db.session.commit()
segment = db.session.query(DocumentSegment).where(DocumentSegment.id == segment_document.id).first()
segment = db.session.get(DocumentSegment, segment_document.id)
return segment
except LockNotOwnedError:
pass
@@ -3221,10 +3238,8 @@ class SegmentService:
model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model,
)
max_position = (
db.session.query(func.max(DocumentSegment.position))
.where(DocumentSegment.document_id == document.id)
.scalar()
max_position = db.session.scalar(
select(func.max(DocumentSegment.position)).where(DocumentSegment.document_id == document.id)
)
pre_segment_data_list = []
segment_data_list = []
@@ -3369,11 +3384,7 @@ class SegmentService:
else:
raise ValueError("The knowledge base index technique is not high quality!")
# get the process rule
processing_rule = (
db.session.query(DatasetProcessRule)
.where(DatasetProcessRule.id == document.dataset_process_rule_id)
.first()
)
processing_rule = db.session.get(DatasetProcessRule, document.dataset_process_rule_id)
if processing_rule:
VectorService.generate_child_chunks(
segment, document, dataset, embedding_model_instance, processing_rule, True
@@ -3391,13 +3402,13 @@ class SegmentService:
# Query existing summary from database
from models.dataset import DocumentSegmentSummary

existing_summary = (
db.session.query(DocumentSegmentSummary)
existing_summary = db.session.scalar(
select(DocumentSegmentSummary)
.where(
DocumentSegmentSummary.chunk_id == segment.id,
DocumentSegmentSummary.dataset_id == dataset.id,
)
.first()
.limit(1)
)

# Check if summary has changed
@@ -3473,11 +3484,7 @@ class SegmentService:
else:
raise ValueError("The knowledge base index technique is not high quality!")
# get the process rule
processing_rule = (
db.session.query(DatasetProcessRule)
.where(DatasetProcessRule.id == document.dataset_process_rule_id)
.first()
)
processing_rule = db.session.get(DatasetProcessRule, document.dataset_process_rule_id)
if processing_rule:
VectorService.generate_child_chunks(
segment, document, dataset, embedding_model_instance, processing_rule, True
@@ -3489,13 +3496,13 @@ class SegmentService:
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
from models.dataset import DocumentSegmentSummary

existing_summary = (
db.session.query(DocumentSegmentSummary)
existing_summary = db.session.scalar(
select(DocumentSegmentSummary)
.where(
DocumentSegmentSummary.chunk_id == segment.id,
DocumentSegmentSummary.dataset_id == dataset.id,
)
.first()
.limit(1)
)
if args.summary is None:
@@ -3561,7 +3568,7 @@ class SegmentService:
segment.status = SegmentStatus.ERROR
segment.error = str(e)
db.session.commit()
new_segment = db.session.query(DocumentSegment).where(DocumentSegment.id == segment.id).first()
new_segment = db.session.get(DocumentSegment, segment.id)
if not new_segment:
raise ValueError("new_segment is not found")
return new_segment
@@ -3581,15 +3588,14 @@ class SegmentService:
# Get child chunk IDs before parent segment is deleted
child_node_ids = []
if segment.index_node_id:
child_chunks = (
db.session.query(ChildChunk.index_node_id)
.where(
ChildChunk.segment_id == segment.id,
ChildChunk.dataset_id == dataset.id,
)
.all()
child_node_ids = list(
db.session.scalars(
select(ChildChunk.index_node_id).where(
ChildChunk.segment_id == segment.id,
ChildChunk.dataset_id == dataset.id,
)
).all()
)
child_node_ids = [chunk[0] for chunk in child_chunks if chunk[0]]

delete_segment_from_index_task.delay(
[segment.index_node_id], dataset.id, document.id, [segment.id], child_node_ids
@@ -3608,17 +3614,14 @@ class SegmentService:
# Check if segment_ids is not empty to avoid WHERE false condition
if not segment_ids or len(segment_ids) == 0:
return
segments_info = (
db.session.query(DocumentSegment)
.with_entities(DocumentSegment.index_node_id, DocumentSegment.id, DocumentSegment.word_count)
.where(
segments_info = db.session.execute(
select(DocumentSegment.index_node_id, DocumentSegment.id, DocumentSegment.word_count).where(
DocumentSegment.id.in_(segment_ids),
DocumentSegment.dataset_id == dataset.id,
DocumentSegment.document_id == document.id,
DocumentSegment.tenant_id == current_user.current_tenant_id,
)
.all()
)
).all()

if not segments_info:
return
@@ -3630,15 +3633,16 @@ class SegmentService:
# Get child chunk IDs before parent segments are deleted
child_node_ids = []
if index_node_ids:
child_chunks = (
db.session.query(ChildChunk.index_node_id)
.where(
ChildChunk.segment_id.in_(segment_db_ids),
ChildChunk.dataset_id == dataset.id,
)
.all()
)
child_node_ids = [chunk[0] for chunk in child_chunks if chunk[0]]
child_node_ids = [
nid
for nid in db.session.scalars(
select(ChildChunk.index_node_id).where(
ChildChunk.segment_id.in_(segment_db_ids),
ChildChunk.dataset_id == dataset.id,
)
).all()
if nid
]

# Start async cleanup with both parent and child node IDs
if index_node_ids or child_node_ids:
@@ -3654,7 +3658,7 @@ class SegmentService:
db.session.add(document)

# Delete database records
db.session.query(DocumentSegment).where(DocumentSegment.id.in_(segment_ids)).delete()
db.session.execute(delete(DocumentSegment).where(DocumentSegment.id.in_(segment_ids)))
db.session.commit()
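
Note: the bulk delete() statement above removes rows in one DELETE ... WHERE id IN (...) without loading ORM instances, matching the old Query.delete() but in 2.0 style. Runnable sketch with a stand-in model:

    from sqlalchemy import Integer, create_engine, delete, func, select
    from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column

    class Base(DeclarativeBase):
        pass

    class Segment(Base):  # stand-in for DocumentSegment
        __tablename__ = "segments"
        id: Mapped[int] = mapped_column(Integer, primary_key=True)

    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)

    with Session(engine) as session:
        session.add_all([Segment(id=1), Segment(id=2), Segment(id=3)])
        session.commit()
        # Single statement; rows are never materialized in the session.
        session.execute(delete(Segment).where(Segment.id.in_([1, 2])))
        session.commit()
        assert session.scalar(select(func.count(Segment.id))) == 1
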
@classmethod
@@ -3728,15 +3732,13 @@ class SegmentService:
with redis_client.lock(lock_name, timeout=20):
index_node_id = str(uuid.uuid4())
index_node_hash = helper.generate_text_hash(content)
max_position = (
db.session.query(func.max(ChildChunk.position))
.where(
max_position = db.session.scalar(
select(func.max(ChildChunk.position)).where(
ChildChunk.tenant_id == current_user.current_tenant_id,
ChildChunk.dataset_id == dataset.id,
ChildChunk.document_id == document.id,
ChildChunk.segment_id == segment.id,
)
.scalar()
)
child_chunk = ChildChunk(
tenant_id=current_user.current_tenant_id,
@@ -3896,10 +3898,8 @@ class SegmentService:
@classmethod
def get_child_chunk_by_id(cls, child_chunk_id: str, tenant_id: str) -> ChildChunk | None:
"""Get a child chunk by its ID."""
result = (
db.session.query(ChildChunk)
.where(ChildChunk.id == child_chunk_id, ChildChunk.tenant_id == tenant_id)
.first()
result = db.session.scalar(
select(ChildChunk).where(ChildChunk.id == child_chunk_id, ChildChunk.tenant_id == tenant_id).limit(1)
)
return result if isinstance(result, ChildChunk) else None

@@ -3934,10 +3934,10 @@ class SegmentService:
@classmethod
def get_segment_by_id(cls, segment_id: str, tenant_id: str) -> DocumentSegment | None:
"""Get a segment by its ID."""
result = (
db.session.query(DocumentSegment)
result = db.session.scalar(
select(DocumentSegment)
.where(DocumentSegment.id == segment_id, DocumentSegment.tenant_id == tenant_id)
.first()
.limit(1)
)
return result if isinstance(result, DocumentSegment) else None

@@ -3980,15 +3980,15 @@ class DatasetCollectionBindingService:
def get_dataset_collection_binding(
cls, provider_name: str, model_name: str, collection_type: str = "dataset"
) -> DatasetCollectionBinding:
dataset_collection_binding = (
db.session.query(DatasetCollectionBinding)
dataset_collection_binding = db.session.scalar(
select(DatasetCollectionBinding)
.where(
DatasetCollectionBinding.provider_name == provider_name,
DatasetCollectionBinding.model_name == model_name,
DatasetCollectionBinding.type == collection_type,
)
.order_by(DatasetCollectionBinding.created_at)
.first()
.limit(1)
)

if not dataset_collection_binding:
@@ -4006,13 +4006,13 @@ class DatasetCollectionBindingService:
def get_dataset_collection_binding_by_id_and_type(
cls, collection_binding_id: str, collection_type: str = "dataset"
) -> DatasetCollectionBinding:
dataset_collection_binding = (
db.session.query(DatasetCollectionBinding)
dataset_collection_binding = db.session.scalar(
select(DatasetCollectionBinding)
.where(
DatasetCollectionBinding.id == collection_binding_id, DatasetCollectionBinding.type == collection_type
)
.order_by(DatasetCollectionBinding.created_at)
.first()
.limit(1)
)
if not dataset_collection_binding:
raise ValueError("Dataset collection binding not found")
@@ -4034,7 +4034,7 @@ class DatasetPermissionService:
@classmethod
def update_partial_member_list(cls, tenant_id, dataset_id, user_list):
try:
db.session.query(DatasetPermission).where(DatasetPermission.dataset_id == dataset_id).delete()
db.session.execute(delete(DatasetPermission).where(DatasetPermission.dataset_id == dataset_id))
permissions = []
for user in user_list:
permission = DatasetPermission(
@@ -4070,7 +4070,7 @@ class DatasetPermissionService:
@classmethod
def clear_partial_member_list(cls, dataset_id):
try:
db.session.query(DatasetPermission).where(DatasetPermission.dataset_id == dataset_id).delete()
db.session.execute(delete(DatasetPermission).where(DatasetPermission.dataset_id == dataset_id))
db.session.commit()
except Exception as e:
db.session.rollback()
@@ -4,6 +4,7 @@ from collections.abc import Mapping
from typing import Any

from graphon.model_runtime.entities.provider_entities import FormType
from sqlalchemy import func, select
from sqlalchemy.orm import Session

from configs import dify_config
@@ -367,16 +368,16 @@ class DatasourceProviderService:
check if tenant oauth params is enabled
"""
return (
db.session.query(DatasourceOauthTenantParamConfig)
.filter_by(
tenant_id=tenant_id,
provider=datasource_provider_id.provider_name,
plugin_id=datasource_provider_id.plugin_id,
enabled=True,
db.session.scalar(
select(func.count(DatasourceOauthTenantParamConfig.id)).where(
DatasourceOauthTenantParamConfig.tenant_id == tenant_id,
DatasourceOauthTenantParamConfig.provider == datasource_provider_id.provider_name,
DatasourceOauthTenantParamConfig.plugin_id == datasource_provider_id.plugin_id,
DatasourceOauthTenantParamConfig.enabled == True,
)
)
.count()
> 0
)
or 0
) > 0

def get_tenant_oauth_client(
self, tenant_id: str, datasource_provider_id: DatasourceProviderID, mask: bool = False
@@ -384,14 +385,14 @@ class DatasourceProviderService:
"""
get tenant oauth client
"""
tenant_oauth_client_params = (
db.session.query(DatasourceOauthTenantParamConfig)
.filter_by(
tenant_id=tenant_id,
provider=datasource_provider_id.provider_name,
plugin_id=datasource_provider_id.plugin_id,
tenant_oauth_client_params = db.session.scalar(
select(DatasourceOauthTenantParamConfig)
.where(
DatasourceOauthTenantParamConfig.tenant_id == tenant_id,
DatasourceOauthTenantParamConfig.provider == datasource_provider_id.provider_name,
DatasourceOauthTenantParamConfig.plugin_id == datasource_provider_id.plugin_id,
)
.first()
.limit(1)
)
if tenant_oauth_client_params:
encrypter, _ = self.get_oauth_encrypter(tenant_id, datasource_provider_id)
@@ -707,24 +708,27 @@ class DatasourceProviderService:
:return:
"""
# Get all provider configurations of the current workspace
datasource_providers: list[DatasourceProvider] = (
db.session.query(DatasourceProvider)
datasource_providers: list[DatasourceProvider] = list(
db.session.scalars(
select(DatasourceProvider).where(
DatasourceProvider.tenant_id == tenant_id,
DatasourceProvider.provider == provider,
DatasourceProvider.plugin_id == plugin_id,
)
).all()
)
if not datasource_providers:
return []
copy_credentials_list = []
default_provider = db.session.execute(
select(DatasourceProvider.id)
.where(
DatasourceProvider.tenant_id == tenant_id,
DatasourceProvider.provider == provider,
DatasourceProvider.plugin_id == plugin_id,
)
.all()
)
if not datasource_providers:
return []
copy_credentials_list = []
default_provider = (
db.session.query(DatasourceProvider.id)
.filter_by(tenant_id=tenant_id, provider=provider, plugin_id=plugin_id)
.order_by(DatasourceProvider.is_default.desc(), DatasourceProvider.created_at.asc())
.first()
)
).first()
default_provider_id = default_provider.id if default_provider else None
for datasource_provider in datasource_providers:
encrypted_credentials = datasource_provider.encrypted_credentials
@@ -880,14 +884,14 @@ class DatasourceProviderService:
:return:
"""
# Get all provider configurations of the current workspace
datasource_providers: list[DatasourceProvider] = (
db.session.query(DatasourceProvider)
.where(
DatasourceProvider.tenant_id == tenant_id,
DatasourceProvider.provider == provider,
DatasourceProvider.plugin_id == plugin_id,
)
.all()
datasource_providers: list[DatasourceProvider] = list(
db.session.scalars(
select(DatasourceProvider).where(
DatasourceProvider.tenant_id == tenant_id,
DatasourceProvider.provider == provider,
DatasourceProvider.plugin_id == plugin_id,
)
).all()
)
if not datasource_providers:
return []
@@ -987,10 +991,15 @@ class DatasourceProviderService:
:param plugin_id: plugin id
:return:
"""
datasource_provider = (
db.session.query(DatasourceProvider)
.filter_by(tenant_id=tenant_id, id=auth_id, provider=provider, plugin_id=plugin_id)
.first()
datasource_provider = db.session.scalar(
select(DatasourceProvider)
.where(
DatasourceProvider.tenant_id == tenant_id,
DatasourceProvider.id == auth_id,
DatasourceProvider.provider == provider,
DatasourceProvider.plugin_id == plugin_id,
)
.limit(1)
)
if datasource_provider:
db.session.delete(datasource_provider)
@@ -1,7 +1,7 @@
import logging
from collections.abc import Mapping

from sqlalchemy import case
from sqlalchemy import case, select
from sqlalchemy.orm import Session

from core.app.entities.app_invoke_entities import InvokeFrom
@@ -25,14 +25,14 @@ class EndUserService:
"""

with Session(db.engine, expire_on_commit=False) as session:
return (
session.query(EndUser)
return session.scalar(
select(EndUser)
.where(
EndUser.id == end_user_id,
EndUser.tenant_id == tenant_id,
EndUser.app_id == app_id,
)
.first()
.limit(1)
)

@classmethod
@@ -57,8 +57,8 @@ class EndUserService:
with Session(db.engine, expire_on_commit=False) as session:
# Query with ORDER BY to prioritize exact type matches while maintaining backward compatibility
# This single query approach is more efficient than separate queries
end_user = (
session.query(EndUser)
end_user = session.scalar(
select(EndUser)
.where(
EndUser.tenant_id == tenant_id,
EndUser.app_id == app_id,
@@ -68,7 +68,7 @@ class EndUserService:
# Prioritize records with matching type (0 = match, 1 = no match)
case((EndUser.type == type, 0), else_=1)
)
.first()
.limit(1)
)

if end_user:
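
Note: the case() expression in the ORDER BY ranks exact type matches ahead of everything else (0 sorts before 1), so one query returns the best match while still falling back to any existing row. Illustrative sketch (the model and values are stand-ins):

    from sqlalchemy import Integer, String, case, create_engine, select
    from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column

    class Base(DeclarativeBase):
        pass

    class EndUser(Base):  # stand-in for the real model
        __tablename__ = "end_users"
        id: Mapped[int] = mapped_column(Integer, primary_key=True)
        type: Mapped[str] = mapped_column(String)

    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)

    with Session(engine) as session:
        session.add_all([EndUser(id=1, type="service_api"), EndUser(id=2, type="browser")])
        session.commit()
        match_first = case((EndUser.type == "browser", 0), else_=1)
        best = session.scalar(select(EndUser).order_by(match_first).limit(1))
        assert best is not None and best.type == "browser"
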
@@ -137,15 +137,15 @@ class EndUserService:

with Session(db.engine, expire_on_commit=False) as session:
# Fetch existing end users for all target apps in a single query
existing_end_users: list[EndUser] = (
session.query(EndUser)
.where(
EndUser.tenant_id == tenant_id,
EndUser.app_id.in_(unique_app_ids),
EndUser.session_id == user_id,
EndUser.type == type,
)
.all()
existing_end_users: list[EndUser] = list(
session.scalars(
select(EndUser).where(
EndUser.tenant_id == tenant_id,
EndUser.app_id.in_(unique_app_ids),
EndUser.session_id == user_id,
EndUser.type == type,
)
).all()
)

found_app_ids: set[str] = set()

@@ -1,23 +1,15 @@
import enum
import logging

from pydantic import BaseModel

from configs import dify_config
from core.entities import PluginCredentialType
from services.enterprise.base import EnterprisePluginManagerRequest
from services.errors.base import BaseServiceError

logger = logging.getLogger(__name__)


class PluginCredentialType(enum.Enum):
MODEL = 0  # must be 0 for API contract compatibility
TOOL = 1  # must be 1 for API contract compatibility

def to_number(self):
return self.value


class CheckCredentialPolicyComplianceRequest(BaseModel):
dify_credential_id: str
provider: str
api/services/entities/auth_entities.py (new file, 31 lines)
@@ -0,0 +1,31 @@
from pydantic import BaseModel, Field, field_validator

from libs.helper import EmailStr
from libs.password import valid_password


class LoginPayloadBase(BaseModel):
email: EmailStr
password: str


class ForgotPasswordSendPayload(BaseModel):
email: EmailStr
language: str | None = None


class ForgotPasswordCheckPayload(BaseModel):
email: EmailStr
code: str
token: str = Field(min_length=1)


class ForgotPasswordResetPayload(BaseModel):
token: str = Field(min_length=1)
new_password: str
password_confirm: str

@field_validator("new_password", "password_confirm")
@classmethod
def validate_password(cls, value: str) -> str:
return valid_password(value)
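
Note: the new payload models centralize request validation. A runnable sketch of how the field_validator behaves, with a local stand-in for libs.password.valid_password (the 8-character rule here is illustrative, not the project's actual policy):

    from pydantic import BaseModel, Field, ValidationError, field_validator

    class ResetPayload(BaseModel):  # mirrors ForgotPasswordResetPayload
        token: str = Field(min_length=1)
        new_password: str
        password_confirm: str

        @field_validator("new_password", "password_confirm")
        @classmethod
        def validate_password(cls, value: str) -> str:
            if len(value) < 8:  # stand-in for valid_password()
                raise ValueError("password must be at least 8 characters")
            return value

    try:
        ResetPayload(token="t", new_password="short", password_confirm="short")
    except ValidationError as exc:
        print(exc.error_count(), "validation errors")  # one per failing field
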
@@ -1,17 +1,12 @@
from enum import StrEnum
from typing import Literal

from pydantic import BaseModel, field_validator

from core.rag.entities import Rule
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.retrieval.retrieval_methods import RetrievalMethod


class ParentMode(StrEnum):
FULL_DOC = "full-doc"
PARAGRAPH = "paragraph"


class NotionIcon(BaseModel):
type: str
url: str | None = None
@@ -53,24 +48,6 @@ class DataSource(BaseModel):
info_list: InfoList


class PreProcessingRule(BaseModel):
id: str
enabled: bool


class Segmentation(BaseModel):
separator: str = "\n"
max_tokens: int
chunk_overlap: int = 0


class Rule(BaseModel):
pre_processing_rules: list[PreProcessingRule] | None = None
segmentation: Segmentation | None = None
parent_mode: Literal["full-doc", "paragraph"] | None = None
subchunk_segmentation: Segmentation | None = None


class ProcessRule(BaseModel):
mode: Literal["automatic", "custom", "hierarchical"]
rules: Rule | None = None

@@ -2,6 +2,7 @@ from typing import Literal

from pydantic import BaseModel, field_validator

from core.rag.entities import KeywordSetting, VectorSetting
from core.rag.retrieval.retrieval_methods import RetrievalMethod


@@ -36,24 +37,6 @@ class RerankingModelConfig(BaseModel):
reranking_model_name: str | None = ""


class VectorSetting(BaseModel):
"""
Vector Setting.
"""

vector_weight: float
embedding_provider_name: str
embedding_model_name: str


class KeywordSetting(BaseModel):
"""
Keyword Setting.
"""

keyword_weight: float


class WeightedScoreConfig(BaseModel):
"""
Weighted score Config.
@@ -63,23 +46,6 @@ class WeightedScoreConfig(BaseModel):
keyword_setting: KeywordSetting | None


class EmbeddingSetting(BaseModel):
"""
Embedding Setting.
"""

embedding_provider_name: str
embedding_model_name: str


class EconomySetting(BaseModel):
"""
Economy Setting.
"""

keyword_number: int


class RetrievalSetting(BaseModel):
"""
Retrieval Setting.
@@ -95,16 +61,6 @@ class RetrievalSetting(BaseModel):
weights: WeightedScoreConfig | None = None


class IndexMethod(BaseModel):
"""
Knowledge Index Setting.
"""

indexing_technique: Literal["high_quality", "economy"]
embedding_setting: EmbeddingSetting
economy_setting: EconomySetting


class KnowledgeConfiguration(BaseModel):
"""
Knowledge Base Configuration.
@@ -5,11 +5,11 @@ from urllib.parse import urlparse

import httpx
from graphon.nodes.http_request.exc import InvalidHttpMethodError
from sqlalchemy import select
from sqlalchemy import func, select

from constants import HIDDEN_VALUE
from core.helper import ssrf_proxy
from core.rag.entities.metadata_entities import MetadataCondition
from core.rag.entities import MetadataFilteringCondition
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models.dataset import (
@@ -103,8 +103,10 @@ class ExternalDatasetService:

@staticmethod
def get_external_knowledge_api(external_knowledge_api_id: str, tenant_id: str) -> ExternalKnowledgeApis:
external_knowledge_api: ExternalKnowledgeApis | None = (
db.session.query(ExternalKnowledgeApis).filter_by(id=external_knowledge_api_id, tenant_id=tenant_id).first()
external_knowledge_api: ExternalKnowledgeApis | None = db.session.scalar(
select(ExternalKnowledgeApis)
.where(ExternalKnowledgeApis.id == external_knowledge_api_id, ExternalKnowledgeApis.tenant_id == tenant_id)
.limit(1)
)
if external_knowledge_api is None:
raise ValueError("api template not found")
@@ -112,8 +114,10 @@ class ExternalDatasetService:

@staticmethod
def update_external_knowledge_api(tenant_id, user_id, external_knowledge_api_id, args) -> ExternalKnowledgeApis:
external_knowledge_api: ExternalKnowledgeApis | None = (
db.session.query(ExternalKnowledgeApis).filter_by(id=external_knowledge_api_id, tenant_id=tenant_id).first()
external_knowledge_api: ExternalKnowledgeApis | None = db.session.scalar(
select(ExternalKnowledgeApis)
.where(ExternalKnowledgeApis.id == external_knowledge_api_id, ExternalKnowledgeApis.tenant_id == tenant_id)
.limit(1)
)
if external_knowledge_api is None:
raise ValueError("api template not found")
@@ -132,8 +136,10 @@ class ExternalDatasetService:

@staticmethod
def delete_external_knowledge_api(tenant_id: str, external_knowledge_api_id: str):
external_knowledge_api = (
db.session.query(ExternalKnowledgeApis).filter_by(id=external_knowledge_api_id, tenant_id=tenant_id).first()
external_knowledge_api = db.session.scalar(
select(ExternalKnowledgeApis)
.where(ExternalKnowledgeApis.id == external_knowledge_api_id, ExternalKnowledgeApis.tenant_id == tenant_id)
.limit(1)
)
if external_knowledge_api is None:
raise ValueError("api template not found")
@@ -144,9 +150,12 @@ class ExternalDatasetService:
@staticmethod
def external_knowledge_api_use_check(external_knowledge_api_id: str) -> tuple[bool, int]:
count = (
db.session.query(ExternalKnowledgeBindings)
.filter_by(external_knowledge_api_id=external_knowledge_api_id)
.count()
db.session.scalar(
select(func.count(ExternalKnowledgeBindings.id)).where(
ExternalKnowledgeBindings.external_knowledge_api_id == external_knowledge_api_id
)
)
or 0
)
if count > 0:
return True, count
@@ -154,8 +163,10 @@ class ExternalDatasetService:

@staticmethod
def get_external_knowledge_binding_with_dataset_id(tenant_id: str, dataset_id: str) -> ExternalKnowledgeBindings:
external_knowledge_binding: ExternalKnowledgeBindings | None = (
db.session.query(ExternalKnowledgeBindings).filter_by(dataset_id=dataset_id, tenant_id=tenant_id).first()
external_knowledge_binding: ExternalKnowledgeBindings | None = db.session.scalar(
select(ExternalKnowledgeBindings)
.where(ExternalKnowledgeBindings.dataset_id == dataset_id, ExternalKnowledgeBindings.tenant_id == tenant_id)
.limit(1)
)
if not external_knowledge_binding:
raise ValueError("external knowledge binding not found")
@@ -163,8 +174,10 @@ class ExternalDatasetService:

@staticmethod
def document_create_args_validate(tenant_id: str, external_knowledge_api_id: str, process_parameter: dict):
external_knowledge_api = (
db.session.query(ExternalKnowledgeApis).filter_by(id=external_knowledge_api_id, tenant_id=tenant_id).first()
external_knowledge_api = db.session.scalar(
select(ExternalKnowledgeApis)
.where(ExternalKnowledgeApis.id == external_knowledge_api_id, ExternalKnowledgeApis.tenant_id == tenant_id)
.limit(1)
)
if external_knowledge_api is None or external_knowledge_api.settings is None:
raise ValueError("api template not found")
@@ -238,12 +251,17 @@ class ExternalDatasetService:
@staticmethod
def create_external_dataset(tenant_id: str, user_id: str, args: dict) -> Dataset:
# check if dataset name already exists
if db.session.query(Dataset).filter_by(name=args.get("name"), tenant_id=tenant_id).first():
if db.session.scalar(
select(Dataset).where(Dataset.name == args.get("name"), Dataset.tenant_id == tenant_id).limit(1)
):
raise DatasetNameDuplicateError(f"Dataset with name {args.get('name')} already exists.")
external_knowledge_api = (
db.session.query(ExternalKnowledgeApis)
.filter_by(id=args.get("external_knowledge_api_id"), tenant_id=tenant_id)
.first()
external_knowledge_api = db.session.scalar(
select(ExternalKnowledgeApis)
.where(
ExternalKnowledgeApis.id == args.get("external_knowledge_api_id"),
ExternalKnowledgeApis.tenant_id == tenant_id,
)
.limit(1)
)

if external_knowledge_api is None:
@@ -284,18 +302,20 @@ class ExternalDatasetService:
dataset_id: str,
query: str,
external_retrieval_parameters: dict,
metadata_condition: MetadataCondition | None = None,
metadata_condition: MetadataFilteringCondition | None = None,
):
external_knowledge_binding = (
db.session.query(ExternalKnowledgeBindings).filter_by(dataset_id=dataset_id, tenant_id=tenant_id).first()
external_knowledge_binding = db.session.scalar(
select(ExternalKnowledgeBindings)
.where(ExternalKnowledgeBindings.dataset_id == dataset_id, ExternalKnowledgeBindings.tenant_id == tenant_id)
.limit(1)
)
if not external_knowledge_binding:
raise ValueError("external knowledge binding not found")

external_knowledge_api = (
db.session.query(ExternalKnowledgeApis)
.filter_by(id=external_knowledge_binding.external_knowledge_api_id)
.first()
external_knowledge_api = db.session.scalar(
select(ExternalKnowledgeApis)
.where(ExternalKnowledgeApis.id == external_knowledge_binding.external_knowledge_api_id)
.limit(1)
)
if external_knowledge_api is None or external_knowledge_api.settings is None:
raise ValueError("external api template not found")
@@ -1,7 +1,7 @@
import json
import logging
import time
from typing import Any
from typing import Any, TypedDict

from graphon.model_runtime.entities import LLMMode

@@ -18,6 +18,16 @@ from models.enums import CreatorUserRole, DatasetQuerySource
logger = logging.getLogger(__name__)


class QueryDict(TypedDict):
content: str


class RetrieveResponseDict(TypedDict):
query: QueryDict
records: list[dict[str, Any]]


default_retrieval_model = {
"search_method": RetrievalMethod.SEMANTIC_SEARCH,
"reranking_enable": False,
@@ -34,7 +44,7 @@ class HitTestingService:
dataset: Dataset,
query: str,
account: Account,
retrieval_model: Any,  # FIXME drop this any
retrieval_model: dict | None,
external_retrieval_model: dict,
attachment_ids: list | None = None,
limit: int = 10,
|
||||
@@ -44,12 +54,13 @@ class HitTestingService:
|
||||
# get retrieval model , if the model is not setting , using default
|
||||
if not retrieval_model:
|
||||
retrieval_model = dataset.retrieval_model or default_retrieval_model
|
||||
assert isinstance(retrieval_model, dict)
|
||||
document_ids_filter = None
|
||||
metadata_filtering_conditions = retrieval_model.get("metadata_filtering_conditions", {})
|
||||
if metadata_filtering_conditions and query:
|
||||
dataset_retrieval = DatasetRetrieval()
|
||||
|
||||
from core.app.app_config.entities import MetadataFilteringCondition
|
||||
from core.rag.entities import MetadataFilteringCondition
|
||||
|
||||
metadata_filtering_conditions = MetadataFilteringCondition.model_validate(metadata_filtering_conditions)
|
||||
|
||||
@@ -150,7 +161,7 @@ class HitTestingService:
|
||||
return dict(cls.compact_external_retrieve_response(dataset, query, all_documents))
|
||||
|
||||
@classmethod
|
||||
def compact_retrieve_response(cls, query: str, documents: list[Document]) -> dict[Any, Any]:
|
||||
def compact_retrieve_response(cls, query: str, documents: list[Document]) -> RetrieveResponseDict:
|
||||
records = RetrievalService.format_retrieval_documents(documents)
|
||||
|
||||
return {
|
||||
@@ -161,7 +172,7 @@ class HitTestingService:
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def compact_external_retrieve_response(cls, dataset: Dataset, query: str, documents: list) -> dict[Any, Any]:
|
||||
def compact_external_retrieve_response(cls, dataset: Dataset, query: str, documents: list) -> RetrieveResponseDict:
|
||||
records = []
|
||||
if dataset.provider == "external":
|
||||
for document in documents:
|
||||
|
||||
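The QueryDict/RetrieveResponseDict additions replace dict[Any, Any] returns with TypedDicts. A minimal sketch of the idea (the names mirror the diff, but the function body is illustrative, not the service's real logic):

from typing import Any, TypedDict


class QueryDict(TypedDict):
    content: str


class RetrieveResponseDict(TypedDict):
    query: QueryDict
    records: list[dict[str, Any]]


def compact(query: str, records: list[dict[str, Any]]) -> RetrieveResponseDict:
    # The shape is checked statically; at runtime this is a plain dict.
    return {"query": {"content": query}, "records": records}


resp = compact("hello", [])
print(resp["query"]["content"])  # type checkers know this is a str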
@@ -1,6 +1,8 @@
import copy
import logging

from sqlalchemy import delete, func, select

from core.rag.index_processor.constant.built_in_field import BuiltInField, MetadataDataSource
from extensions.ext_database import db
from extensions.ext_redis import redis_client
@@ -25,10 +27,14 @@ class MetadataService:
raise ValueError("Metadata name cannot exceed 255 characters.")
current_user, current_tenant_id = current_account_with_tenant()
# check if metadata name already exists
if (
db.session.query(DatasetMetadata)
.filter_by(tenant_id=current_tenant_id, dataset_id=dataset_id, name=metadata_args.name)
.first()
if db.session.scalar(
select(DatasetMetadata)
.where(
DatasetMetadata.tenant_id == current_tenant_id,
DatasetMetadata.dataset_id == dataset_id,
DatasetMetadata.name == metadata_args.name,
)
.limit(1)
):
raise ValueError("Metadata name already exists.")
for field in BuiltInField:
@@ -54,10 +60,14 @@ class MetadataService:
lock_key = f"dataset_metadata_lock_{dataset_id}"
# check if metadata name already exists
current_user, current_tenant_id = current_account_with_tenant()
if (
db.session.query(DatasetMetadata)
.filter_by(tenant_id=current_tenant_id, dataset_id=dataset_id, name=name)
.first()
if db.session.scalar(
select(DatasetMetadata)
.where(
DatasetMetadata.tenant_id == current_tenant_id,
DatasetMetadata.dataset_id == dataset_id,
DatasetMetadata.name == name,
)
.limit(1)
):
raise ValueError("Metadata name already exists.")
for field in BuiltInField:
@@ -65,7 +75,11 @@ class MetadataService:
raise ValueError("Metadata name already exists in Built-in fields.")
try:
MetadataService.knowledge_base_metadata_lock_check(dataset_id, None)
metadata = db.session.query(DatasetMetadata).filter_by(id=metadata_id, dataset_id=dataset_id).first()
metadata = db.session.scalar(
select(DatasetMetadata)
.where(DatasetMetadata.id == metadata_id, DatasetMetadata.dataset_id == dataset_id)
.limit(1)
)
if metadata is None:
raise ValueError("Metadata not found.")
old_name = metadata.name
@@ -74,9 +88,9 @@ class MetadataService:
metadata.updated_at = naive_utc_now()

# update related documents
dataset_metadata_bindings = (
db.session.query(DatasetMetadataBinding).filter_by(metadata_id=metadata_id).all()
)
dataset_metadata_bindings = db.session.scalars(
select(DatasetMetadataBinding).where(DatasetMetadataBinding.metadata_id == metadata_id)
).all()
if dataset_metadata_bindings:
document_ids = [binding.document_id for binding in dataset_metadata_bindings]
documents = DocumentService.get_document_by_ids(document_ids)
@@ -101,15 +115,19 @@ class MetadataService:
lock_key = f"dataset_metadata_lock_{dataset_id}"
try:
MetadataService.knowledge_base_metadata_lock_check(dataset_id, None)
metadata = db.session.query(DatasetMetadata).filter_by(id=metadata_id, dataset_id=dataset_id).first()
metadata = db.session.scalar(
select(DatasetMetadata)
.where(DatasetMetadata.id == metadata_id, DatasetMetadata.dataset_id == dataset_id)
.limit(1)
)
if metadata is None:
raise ValueError("Metadata not found.")
db.session.delete(metadata)

# deal related documents
dataset_metadata_bindings = (
db.session.query(DatasetMetadataBinding).filter_by(metadata_id=metadata_id).all()
)
dataset_metadata_bindings = db.session.scalars(
select(DatasetMetadataBinding).where(DatasetMetadataBinding.metadata_id == metadata_id)
).all()
if dataset_metadata_bindings:
document_ids = [binding.document_id for binding in dataset_metadata_bindings]
documents = DocumentService.get_document_by_ids(document_ids)
@@ -224,16 +242,23 @@ class MetadataService:

# deal metadata binding (in the same transaction as the doc_metadata update)
if not operation.partial_update:
db.session.query(DatasetMetadataBinding).filter_by(document_id=operation.document_id).delete()
db.session.execute(
delete(DatasetMetadataBinding).where(
DatasetMetadataBinding.document_id == operation.document_id
)
)

current_user, current_tenant_id = current_account_with_tenant()
for metadata_value in operation.metadata_list:
# check if binding already exists
if operation.partial_update:
existing_binding = (
db.session.query(DatasetMetadataBinding)
.filter_by(document_id=operation.document_id, metadata_id=metadata_value.id)
.first()
existing_binding = db.session.scalar(
select(DatasetMetadataBinding)
.where(
DatasetMetadataBinding.document_id == operation.document_id,
DatasetMetadataBinding.metadata_id == metadata_value.id,
)
.limit(1)
)
if existing_binding:
continue
@@ -275,9 +300,13 @@ class MetadataService:
"id": item.get("id"),
"name": item.get("name"),
"type": item.get("type"),
"count": db.session.query(DatasetMetadataBinding)
.filter_by(metadata_id=item.get("id"), dataset_id=dataset.id)
.count(),
"count": db.session.scalar(
select(func.count(DatasetMetadataBinding.id)).where(
DatasetMetadataBinding.metadata_id == item.get("id"),
DatasetMetadataBinding.dataset_id == dataset.id,
)
)
or 0,
}
for item in dataset.doc_metadata or []
if item.get("id") != "built-in"
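Two related rewrites appear in this file: Query.delete() becomes an executable delete() statement, and Query.count() becomes a func.count() scalar. A compact sketch under the same assumptions as before (in-memory SQLite, hypothetical Binding model standing in for DatasetMetadataBinding):

from sqlalchemy import String, create_engine, delete, func, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class Binding(Base):  # hypothetical stand-in model
    __tablename__ = "bindings"
    id: Mapped[int] = mapped_column(primary_key=True)
    document_id: Mapped[str] = mapped_column(String(36))


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    # Bulk DELETE as an executable statement instead of Query.delete():
    session.execute(delete(Binding).where(Binding.document_id == "doc-1"))
    # Query.count() becomes a scalar SELECT COUNT(...); the `or 0` in the
    # diff guards the Optional return of scalar().
    n = session.scalar(select(func.count(Binding.id))) or 0
    print(n)  # 0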
@@ -1,6 +1,6 @@
import json
import logging
from typing import Any, Union
from typing import Any, TypedDict, Union

from graphon.model_runtime.entities.model_entities import ModelType
from graphon.model_runtime.entities.provider_entities import (
@@ -25,6 +25,23 @@ from models.provider import LoadBalancingModelConfig, ProviderCredential, Provid
logger = logging.getLogger(__name__)


class LoadBalancingConfigDetailDict(TypedDict):
id: str
name: str
credentials: dict[str, Any]
enabled: bool


class LoadBalancingConfigSummaryDict(TypedDict):
id: str
name: str
credentials: dict[str, Any]
credential_id: str | None
enabled: bool
in_cooldown: bool
ttl: int


class ModelLoadBalancingService:
@staticmethod
def _get_provider_manager(tenant_id: str) -> ProviderManager:
@@ -74,7 +91,7 @@ class ModelLoadBalancingService:

def get_load_balancing_configs(
self, tenant_id: str, provider: str, model: str, model_type: str, config_from: str = ""
) -> tuple[bool, list[dict]]:
) -> tuple[bool, list[LoadBalancingConfigSummaryDict]]:
"""
Get load balancing configurations.
:param tenant_id: workspace id
@@ -156,7 +173,7 @@ class ModelLoadBalancingService:
decoding_rsa_key, decoding_cipher_rsa = encrypter.get_decrypt_decoding(tenant_id)

# fetch status and ttl for each config
datas = []
datas: list[LoadBalancingConfigSummaryDict] = []
for load_balancing_config in load_balancing_configs:
in_cooldown, ttl = LBModelManager.get_config_in_cooldown_and_ttl(
tenant_id=tenant_id,
@@ -214,7 +231,7 @@ class ModelLoadBalancingService:

def get_load_balancing_config(
self, tenant_id: str, provider: str, model: str, model_type: str, config_id: str
) -> dict | None:
) -> LoadBalancingConfigDetailDict | None:
"""
Get load balancing configuration.
:param tenant_id: workspace id
@@ -267,12 +284,13 @@ class ModelLoadBalancingService:
credentials=credentials, credential_form_schemas=credential_schemas.credential_form_schemas
)

return {
result: LoadBalancingConfigDetailDict = {
"id": load_balancing_model_config.id,
"name": load_balancing_model_config.name,
"credentials": credentials,
"enabled": load_balancing_model_config.enabled,
}
return result

def _init_inherit_config(
self, tenant_id: str, provider: str, model: str, model_type: ModelType
@@ -2,7 +2,7 @@ import enum
import uuid

from sqlalchemy import select
from sqlalchemy.orm import Session
from sqlalchemy.orm import sessionmaker
from werkzeug.exceptions import BadRequest

from extensions.ext_database import db
@@ -29,7 +29,7 @@ class OAuthServerService:
def get_oauth_provider_app(client_id: str) -> OAuthProviderApp | None:
query = select(OAuthProviderApp).where(OAuthProviderApp.client_id == client_id)

with Session(db.engine) as session:
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
return session.execute(query).scalar_one_or_none()

@staticmethod
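Here the bare Session(db.engine) context manager is swapped for sessionmaker(...).begin(), which scopes a transaction together with the session: commit on normal exit, rollback on exception, close either way. expire_on_commit=False keeps loaded attribute values readable after the commit, which matters when an ORM object is returned out of the block. A hedged sketch (SQLite stand-in engine; the text() query is just a placeholder):

from sqlalchemy import create_engine, text
from sqlalchemy.orm import sessionmaker

engine = create_engine("sqlite://")

# One factory, transaction-per-block usage; no explicit commit() needed.
session_factory = sessionmaker(engine, expire_on_commit=False)
with session_factory.begin() as session:
    value = session.execute(text("select 1")).scalar_one()
print(value)  # 1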
@@ -5,7 +5,7 @@ import time
from collections.abc import Mapping, Sequence
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from typing import Any, TypedDict
from typing import TypedDict
from uuid import uuid4

import click
@@ -42,6 +42,16 @@ class _TenantPluginRecord(TypedDict):
_tenant_plugin_adapter: TypeAdapter[_TenantPluginRecord] = TypeAdapter(_TenantPluginRecord)


class ExtractedPluginsDict(TypedDict):
plugins: dict[str, str]
plugin_not_exist: list[str]


class PluginInstallResultDict(TypedDict):
success: list[str]
failed: list[str]


class PluginMigration:
@classmethod
def extract_plugins(cls, filepath: str, workers: int):
@@ -310,7 +320,7 @@ class PluginMigration:
Path(output_file).write_text(json.dumps(cls.extract_unique_plugins(extracted_plugins)))

@classmethod
def extract_unique_plugins(cls, extracted_plugins: str) -> Mapping[str, Any]:
def extract_unique_plugins(cls, extracted_plugins: str) -> ExtractedPluginsDict:
plugins: dict[str, str] = {}
plugin_ids = []
plugin_not_exist = []
@@ -524,7 +534,7 @@ class PluginMigration:
@classmethod
def handle_plugin_instance_install(
cls, tenant_id: str, plugin_identifiers_map: Mapping[str, str]
) -> Mapping[str, Any]:
) -> PluginInstallResultDict:
"""
Install plugins for a tenant.
"""
@@ -1,6 +1,7 @@
from collections.abc import Mapping, Sequence
from typing import Any, Literal

from sqlalchemy import select
from sqlalchemy.orm import Session

from core.plugin.entities.parameters import PluginParameterOption
@@ -56,24 +57,24 @@ class PluginParameterService:
# fetch credentials from db
with Session(db.engine) as session:
if credential_id:
db_record = (
session.query(BuiltinToolProvider)
db_record = session.scalar(
select(BuiltinToolProvider)
.where(
BuiltinToolProvider.tenant_id == tenant_id,
BuiltinToolProvider.provider == provider,
BuiltinToolProvider.id == credential_id,
)
.first()
.limit(1)
)
else:
db_record = (
session.query(BuiltinToolProvider)
db_record = session.scalar(
select(BuiltinToolProvider)
.where(
BuiltinToolProvider.tenant_id == tenant_id,
BuiltinToolProvider.provider == provider,
)
.order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc())
.first()
.limit(1)
)

if db_record is None:
@@ -38,11 +38,7 @@ from core.datasource.online_document.online_document_plugin import OnlineDocumen
from core.datasource.online_drive.online_drive_plugin import OnlineDriveDatasourcePlugin
from core.datasource.website_crawl.website_crawl_plugin import WebsiteCrawlDatasourcePlugin
from core.helper import marketplace
from core.rag.entities.event import (
DatasourceCompletedEvent,
DatasourceErrorEvent,
DatasourceProcessingEvent,
)
from core.rag.entities import DatasourceCompletedEvent, DatasourceErrorEvent, DatasourceProcessingEvent
from core.repositories.factory import DifyCoreRepositoryFactory, OrderConfig
from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository
from core.workflow.node_factory import LATEST_VERSION, get_node_type_classes_mapping
@@ -156,27 +152,27 @@ class RagPipelineService:
:param template_id: template id
:param template_info: template info
"""
customized_template: PipelineCustomizedTemplate | None = (
db.session.query(PipelineCustomizedTemplate)
customized_template: PipelineCustomizedTemplate | None = db.session.scalar(
select(PipelineCustomizedTemplate)
.where(
PipelineCustomizedTemplate.id == template_id,
PipelineCustomizedTemplate.tenant_id == current_user.current_tenant_id,
)
.first()
.limit(1)
)
if not customized_template:
raise ValueError("Customized pipeline template not found.")
# check if the template name already exists
template_name = template_info.name
if template_name:
template = (
db.session.query(PipelineCustomizedTemplate)
template = db.session.scalar(
select(PipelineCustomizedTemplate)
.where(
PipelineCustomizedTemplate.name == template_name,
PipelineCustomizedTemplate.tenant_id == current_user.current_tenant_id,
PipelineCustomizedTemplate.id != template_id,
)
.first()
.limit(1)
)
if template:
raise ValueError("Template name is already exists")
|
||||
@@ -192,13 +188,13 @@ class RagPipelineService:
"""
Delete customized pipeline template.
"""
customized_template: PipelineCustomizedTemplate | None = (
db.session.query(PipelineCustomizedTemplate)
customized_template: PipelineCustomizedTemplate | None = db.session.scalar(
select(PipelineCustomizedTemplate)
.where(
PipelineCustomizedTemplate.id == template_id,
PipelineCustomizedTemplate.tenant_id == current_user.current_tenant_id,
)
.first()
.limit(1)
)
if not customized_template:
raise ValueError("Customized pipeline template not found.")
@@ -210,14 +206,14 @@ class RagPipelineService:
Get draft workflow
"""
# fetch draft workflow by rag pipeline
workflow = (
db.session.query(Workflow)
workflow = db.session.scalar(
select(Workflow)
.where(
Workflow.tenant_id == pipeline.tenant_id,
Workflow.app_id == pipeline.id,
Workflow.version == "draft",
)
.first()
.limit(1)
)

# return draft workflow
@@ -232,28 +228,28 @@ class RagPipelineService:
return None

# fetch published workflow by workflow_id
workflow = (
db.session.query(Workflow)
workflow = db.session.scalar(
select(Workflow)
.where(
Workflow.tenant_id == pipeline.tenant_id,
Workflow.app_id == pipeline.id,
Workflow.id == pipeline.workflow_id,
)
.first()
.limit(1)
)

return workflow

def get_published_workflow_by_id(self, pipeline: Pipeline, workflow_id: str) -> Workflow | None:
"""Fetch a published workflow snapshot by ID for restore operations."""
workflow = (
db.session.query(Workflow)
workflow = db.session.scalar(
select(Workflow)
.where(
Workflow.tenant_id == pipeline.tenant_id,
Workflow.app_id == pipeline.id,
Workflow.id == workflow_id,
)
.first()
.limit(1)
)
if workflow and workflow.version == Workflow.VERSION_DRAFT:
raise IsDraftWorkflowError("source workflow must be published")
@@ -974,7 +970,7 @@ class RagPipelineService:
if invoke_from.value == InvokeFrom.PUBLISHED_PIPELINE:
document_id = get_system_segment(variable_pool, SystemVariableKey.DOCUMENT_ID)
if document_id:
document = db.session.query(Document).where(Document.id == document_id.value).first()
document = db.session.get(Document, document_id.value)
if document:
document.indexing_status = IndexingStatus.ERROR
document.error = error
@@ -1178,15 +1174,15 @@ class RagPipelineService:
"""
Publish customized pipeline template
"""
pipeline = db.session.query(Pipeline).where(Pipeline.id == pipeline_id).first()
pipeline = db.session.get(Pipeline, pipeline_id)
if not pipeline:
raise ValueError("Pipeline not found")
if not pipeline.workflow_id:
raise ValueError("Pipeline workflow not found")
workflow = db.session.query(Workflow).where(Workflow.id == pipeline.workflow_id).first()
workflow = db.session.get(Workflow, pipeline.workflow_id)
if not workflow:
raise ValueError("Workflow not found")
with Session(db.engine) as session:
with sessionmaker(db.engine).begin() as session:
dataset = pipeline.retrieve_dataset(session=session)
if not dataset:
raise ValueError("Dataset not found")
@@ -1194,26 +1190,26 @@ class RagPipelineService:
# check if the template name already exists
template_name = args.get("name")
|
||||
if template_name:
|
||||
template = (
|
||||
db.session.query(PipelineCustomizedTemplate)
|
||||
template = db.session.scalar(
|
||||
select(PipelineCustomizedTemplate)
|
||||
.where(
|
||||
PipelineCustomizedTemplate.name == template_name,
|
||||
PipelineCustomizedTemplate.tenant_id == pipeline.tenant_id,
|
||||
)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
if template:
|
||||
raise ValueError("Template name is already exists")
|
||||

max_position = (
db.session.query(func.max(PipelineCustomizedTemplate.position))
.where(PipelineCustomizedTemplate.tenant_id == pipeline.tenant_id)
.scalar()
max_position = db.session.scalar(
select(func.max(PipelineCustomizedTemplate.position)).where(
PipelineCustomizedTemplate.tenant_id == pipeline.tenant_id
)
)

from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService

with Session(db.engine) as session:
with sessionmaker(db.engine).begin() as session:
rag_pipeline_dsl_service = RagPipelineDslService(session)
dsl = rag_pipeline_dsl_service.export_rag_pipeline_dsl(pipeline=pipeline, include_secret=True)
if args.get("icon_info") is None:
@@ -1239,13 +1235,14 @@ class RagPipelineService:

def is_workflow_exist(self, pipeline: Pipeline) -> bool:
return (
db.session.query(Workflow)
.where(
Workflow.tenant_id == pipeline.tenant_id,
Workflow.app_id == pipeline.id,
Workflow.version == Workflow.VERSION_DRAFT,
db.session.scalar(
select(func.count(Workflow.id)).where(
Workflow.tenant_id == pipeline.tenant_id,
Workflow.app_id == pipeline.id,
Workflow.version == Workflow.VERSION_DRAFT,
)
)
.count()
or 0
) > 0

def get_node_last_run(
@@ -1353,11 +1350,11 @@ class RagPipelineService:

def get_recommended_plugins(self, type: str) -> dict:
# Query active recommended plugins
query = db.session.query(PipelineRecommendedPlugin).where(PipelineRecommendedPlugin.active == True)
stmt = select(PipelineRecommendedPlugin).where(PipelineRecommendedPlugin.active == True)
if type and type != "all":
query = query.where(PipelineRecommendedPlugin.type == type)
stmt = stmt.where(PipelineRecommendedPlugin.type == type)

pipeline_recommended_plugins = query.order_by(PipelineRecommendedPlugin.position.asc()).all()
pipeline_recommended_plugins = db.session.scalars(stmt.order_by(PipelineRecommendedPlugin.position.asc())).all()

if not pipeline_recommended_plugins:
return {
@@ -1396,14 +1393,12 @@ class RagPipelineService:
"""
Retry error document
"""
document_pipeline_execution_log = (
db.session.query(DocumentPipelineExecutionLog)
.where(DocumentPipelineExecutionLog.document_id == document.id)
.first()
document_pipeline_execution_log = db.session.scalar(
select(DocumentPipelineExecutionLog).where(DocumentPipelineExecutionLog.document_id == document.id).limit(1)
)
if not document_pipeline_execution_log:
raise ValueError("Document pipeline execution log not found")
pipeline = db.session.query(Pipeline).where(Pipeline.id == document_pipeline_execution_log.pipeline_id).first()
pipeline = db.session.get(Pipeline, document_pipeline_execution_log.pipeline_id)
if not pipeline:
raise ValueError("Pipeline not found")
# convert to app config
@@ -1432,23 +1427,23 @@ class RagPipelineService:
"""
Get datasource plugins
"""
dataset: Dataset | None = (
db.session.query(Dataset)
dataset: Dataset | None = db.session.scalar(
select(Dataset)
.where(
Dataset.id == dataset_id,
Dataset.tenant_id == tenant_id,
)
.first()
.limit(1)
)
if not dataset:
raise ValueError("Dataset not found")
pipeline: Pipeline | None = (
db.session.query(Pipeline)
pipeline: Pipeline | None = db.session.scalar(
select(Pipeline)
.where(
Pipeline.id == dataset.pipeline_id,
Pipeline.tenant_id == tenant_id,
)
.first()
.limit(1)
)
if not pipeline:
raise ValueError("Pipeline not found")
@@ -1530,23 +1525,23 @@ class RagPipelineService:
"""
Get pipeline
"""
dataset: Dataset | None = (
db.session.query(Dataset)
dataset: Dataset | None = db.session.scalar(
select(Dataset)
.where(
Dataset.id == dataset_id,
Dataset.tenant_id == tenant_id,
)
.first()
.limit(1)
)
if not dataset:
raise ValueError("Dataset not found")
pipeline: Pipeline | None = (
db.session.query(Pipeline)
pipeline: Pipeline | None = db.session.scalar(
select(Pipeline)
.where(
Pipeline.id == dataset.pipeline_id,
Pipeline.tenant_id == tenant_id,
)
.first()
.limit(1)
)
if not pipeline:
raise ValueError("Pipeline not found")
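Several lookups by primary key in this file are rewritten from query(...).where(Model.id == pk).first() to db.session.get(). A small sketch of why that is the preferred form (hypothetical Doc model, in-memory SQLite):

from sqlalchemy import String, create_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class Doc(Base):  # hypothetical stand-in for Pipeline/Workflow/Document
    __tablename__ = "docs"
    id: Mapped[str] = mapped_column(String(36), primary_key=True)


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    # Session.get() is the canonical primary-key lookup: it can serve the
    # object straight from the identity map without emitting SQL at all.
    doc = session.get(Doc, "some-id")
    print(doc)  # None here; no such row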
@@ -3,7 +3,7 @@ import logging
import random
import time
from collections.abc import Sequence
from typing import TYPE_CHECKING, cast
from typing import TYPE_CHECKING, TypedDict, cast

import sqlalchemy as sa
from sqlalchemy import delete, select, tuple_
@@ -158,6 +158,13 @@ class MessagesCleanupMetrics:
self._record(self._job_duration_seconds, job_duration_seconds, attributes)


class MessagesCleanStatsDict(TypedDict):
batches: int
total_messages: int
filtered_messages: int
total_deleted: int


class MessagesCleanService:
"""
Service for cleaning expired messages based on retention policies.
@@ -299,7 +306,7 @@ class MessagesCleanService:
task_label=task_label,
)

def run(self) -> dict[str, int]:
def run(self) -> MessagesCleanStatsDict:
"""
Execute the message cleanup operation.

@@ -319,7 +326,7 @@ class MessagesCleanService:
job_duration_seconds=time.monotonic() - run_start,
)

def _clean_messages_by_time_range(self) -> dict[str, int]:
def _clean_messages_by_time_range(self) -> MessagesCleanStatsDict:
"""
Clean messages within a time range using cursor-based pagination.

@@ -334,7 +341,7 @@ class MessagesCleanService:
Returns:
Dict with statistics: batches, filtered_messages, total_deleted
"""
stats = {
stats: MessagesCleanStatsDict = {
"batches": 0,
"total_messages": 0,
"filtered_messages": 0,
@@ -24,7 +24,7 @@ import zipfile
from collections.abc import Sequence
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from typing import Any
from typing import Any, TypedDict

import click
from graphon.enums import WorkflowType
@@ -49,6 +49,23 @@ from services.retention.workflow_run.constants import ARCHIVE_BUNDLE_NAME, ARCHI
logger = logging.getLogger(__name__)


class TableStatsManifestEntry(TypedDict):
row_count: int
checksum: str
size_bytes: int


class ArchiveManifestDict(TypedDict):
schema_version: str
workflow_run_id: str
tenant_id: str
app_id: str
workflow_id: str
created_at: str
archived_at: str
tables: dict[str, TableStatsManifestEntry]


@dataclass
class TableStats:
"""Statistics for a single archived table."""
@@ -472,25 +489,26 @@ class WorkflowRunArchiver:
self,
run: WorkflowRun,
table_stats: list[TableStats],
) -> dict[str, Any]:
) -> ArchiveManifestDict:
"""Generate a manifest for the archived workflow run."""
return {
"schema_version": ARCHIVE_SCHEMA_VERSION,
"workflow_run_id": run.id,
"tenant_id": run.tenant_id,
"app_id": run.app_id,
"workflow_id": run.workflow_id,
"created_at": run.created_at.isoformat(),
"archived_at": datetime.datetime.now(datetime.UTC).isoformat(),
"tables": {
stat.table_name: {
"row_count": stat.row_count,
"checksum": stat.checksum,
"size_bytes": stat.size_bytes,
}
for stat in table_stats
},
tables: dict[str, TableStatsManifestEntry] = {
stat.table_name: {
"row_count": stat.row_count,
"checksum": stat.checksum,
"size_bytes": stat.size_bytes,
}
for stat in table_stats
}
return ArchiveManifestDict(
schema_version=ARCHIVE_SCHEMA_VERSION,
workflow_run_id=run.id,
tenant_id=run.tenant_id,
app_id=run.app_id,
workflow_id=run.workflow_id,
created_at=run.created_at.isoformat(),
archived_at=datetime.datetime.now(datetime.UTC).isoformat(),
tables=tables,
)

def _build_archive_bundle(self, manifest_data: bytes, table_payloads: dict[str, bytes]) -> bytes:
buffer = io.BytesIO()
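The manifest is now built by calling the TypedDict with keyword arguments instead of returning an untyped dict literal, so a missing or mistyped field is a static error at the construction site. A reduced sketch of that style (trimmed field set, illustrative names):

from typing import TypedDict


class ManifestEntry(TypedDict):
    row_count: int
    checksum: str


class Manifest(TypedDict):
    schema_version: str
    tables: dict[str, ManifestEntry]


# Calling a TypedDict like a constructor builds a plain dict at runtime,
# but the type checker verifies every required key is present and typed.
manifest = Manifest(
    schema_version="1",
    tables={"messages": ManifestEntry(row_count=3, checksum="abc")},
)
print(manifest["tables"]["messages"]["row_count"])  # 3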
@@ -3,7 +3,7 @@ import logging
import random
import time
from collections.abc import Iterable, Sequence
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, TypedDict

import click
from sqlalchemy.orm import Session, sessionmaker
@@ -12,7 +12,7 @@ from configs import dify_config
from enums.cloud_plan import CloudPlan
from extensions.ext_database import db
from models.workflow import WorkflowRun
from repositories.api_workflow_run_repository import APIWorkflowRunRepository
from repositories.api_workflow_run_repository import APIWorkflowRunRepository, RunsWithRelatedCountsDict
from repositories.factory import DifyAPIRepositoryFactory
from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository
from services.billing_service import BillingService, SubscriptionPlan
@@ -24,6 +24,15 @@ if TYPE_CHECKING:
from opentelemetry.metrics import Counter, Histogram


class RelatedCountsDict(TypedDict):
node_executions: int
offloads: int
app_logs: int
trigger_logs: int
pauses: int
pause_reasons: int


class WorkflowRunCleanupMetrics:
"""
Records low-cardinality OpenTelemetry metrics for workflow run cleanup jobs.
@@ -173,6 +182,9 @@ class WorkflowRunCleanupMetrics:
self._record(self._job_duration_seconds, job_duration_seconds, attributes)


_RELATED_RECORD_KEYS = ("node_executions", "offloads", "app_logs", "trigger_logs", "pauses", "pause_reasons")


class WorkflowRunCleanup:
def __init__(
self,
@@ -230,7 +242,7 @@ class WorkflowRunCleanup:

total_runs_deleted = 0
total_runs_targeted = 0
related_totals = self._empty_related_counts() if self.dry_run else None
related_totals: RelatedCountsDict | None = self._empty_related_counts() if self.dry_run else None
batch_index = 0
last_seen: tuple[datetime.datetime, str] | None = None
status = "success"
@@ -312,8 +324,7 @@ class WorkflowRunCleanup:
int((time.monotonic() - count_start) * 1000),
)
if related_totals is not None:
for key in related_totals:
related_totals[key] += batch_counts.get(key, 0)
self._accumulate_related_counts(related_totals, batch_counts)
sample_ids = ", ".join(run.id for run in free_runs[:5])
click.echo(
click.style(
@@ -332,7 +343,10 @@ class WorkflowRunCleanup:
targeted_runs=len(free_runs),
skipped_runs=paid_or_skipped,
deleted_runs=0,
related_counts={key: batch_counts.get(key, 0) for key in self._empty_related_counts()},
related_counts={
k: batch_counts[k] # type: ignore[literal-required]
for k in _RELATED_RECORD_KEYS
},
related_action="would_delete",
batch_duration_seconds=time.monotonic() - batch_start,
)
@@ -372,7 +386,10 @@ class WorkflowRunCleanup:
targeted_runs=len(free_runs),
skipped_runs=paid_or_skipped,
deleted_runs=counts["runs"],
related_counts={key: counts.get(key, 0) for key in self._empty_related_counts()},
related_counts={
k: counts[k] # type: ignore[literal-required]
for k in _RELATED_RECORD_KEYS
},
related_action="deleted",
batch_duration_seconds=time.monotonic() - batch_start,
)
@@ -506,7 +523,7 @@ class WorkflowRunCleanup:
return trigger_repo.count_by_run_ids(run_ids)

@staticmethod
def _empty_related_counts() -> dict[str, int]:
def _empty_related_counts() -> RelatedCountsDict:
return {
"node_executions": 0,
"offloads": 0,
@@ -517,7 +534,7 @@ class WorkflowRunCleanup:
}

@staticmethod
def _format_related_counts(counts: dict[str, int]) -> str:
def _format_related_counts(counts: RelatedCountsDict) -> str:
return (
f"node_executions {counts['node_executions']}, "
f"offloads {counts['offloads']}, "
@@ -527,6 +544,15 @@ class WorkflowRunCleanup:
f"pause_reasons {counts['pause_reasons']}"
)

@staticmethod
def _accumulate_related_counts(totals: RelatedCountsDict, batch: RunsWithRelatedCountsDict) -> None:
totals["node_executions"] += batch.get("node_executions", 0)
totals["offloads"] += batch.get("offloads", 0)
totals["app_logs"] += batch.get("app_logs", 0)
totals["trigger_logs"] += batch.get("trigger_logs", 0)
totals["pauses"] += batch.get("pauses", 0)
totals["pause_reasons"] += batch.get("pause_reasons", 0)

def _count_node_executions(self, session: Session, runs: Sequence[WorkflowRun]) -> tuple[int, int]:
run_ids = [run.id for run in runs]
repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository(
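The `# type: ignore[literal-required]` comments above exist because indexing a TypedDict with a runtime str variable is rejected by strict type checkers: the key must be a literal so the value type is known. That is also why the diff replaces the `for key in related_totals` loop with an explicit `_accumulate_related_counts` helper. A reduced sketch of the same workaround (two fields instead of six):

from typing import TypedDict


class Counts(TypedDict):
    node_executions: int
    offloads: int


def accumulate(totals: Counts, batch: Counts) -> None:
    # Each known key is accumulated explicitly instead of looping over
    # totals.keys(), so no ignore comment is needed here.
    totals["node_executions"] += batch["node_executions"]
    totals["offloads"] += batch["offloads"]


totals: Counts = {"node_executions": 0, "offloads": 0}
accumulate(totals, {"node_executions": 2, "offloads": 1})
print(totals)  # {'node_executions': 2, 'offloads': 1}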
@@ -14,7 +14,7 @@ from sqlalchemy.orm import Session, sessionmaker

from extensions.ext_database import db
from models.workflow import WorkflowRun
from repositories.api_workflow_run_repository import APIWorkflowRunRepository
from repositories.api_workflow_run_repository import APIWorkflowRunRepository, RunsWithRelatedCountsDict
from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository


@@ -23,7 +23,17 @@ class DeleteResult:
run_id: str
tenant_id: str
success: bool
deleted_counts: dict[str, int] = field(default_factory=dict)
deleted_counts: RunsWithRelatedCountsDict = field(
default_factory=lambda: { # type: ignore[assignment]
"runs": 0,
"node_executions": 0,
"offloads": 0,
"app_logs": 0,
"trigger_logs": 0,
"pauses": 0,
"pause_reasons": 0,
}
)
error: str | None = None
elapsed_time: float = 0.0
@@ -4,7 +4,7 @@ import logging
import time
import uuid
from datetime import UTC, datetime
from typing import Any
from typing import TypedDict, cast

from graphon.model_runtime.entities.llm_entities import LLMUsage
from graphon.model_runtime.entities.model_entities import ModelType
@@ -25,6 +25,22 @@ from models.enums import SummaryStatus
logger = logging.getLogger(__name__)


class SummaryEntryDict(TypedDict):
segment_id: str
segment_position: int
status: str
summary_preview: str | None
error: str | None
created_at: int | None
updated_at: int | None


class DocumentSummaryStatusDetailDict(TypedDict):
total_segments: int
summary_status: dict[str, int]
summaries: list[SummaryEntryDict]


class SummaryIndexService:
"""Service for generating and managing summary indexes."""

@@ -1352,7 +1368,7 @@ class SummaryIndexService:
def get_document_summary_status_detail(
document_id: str,
dataset_id: str,
) -> dict[str, Any]:
) -> DocumentSummaryStatusDetailDict:
"""
Get detailed summary status for a document.

@@ -1403,7 +1419,7 @@ class SummaryIndexService:
SummaryStatus.NOT_STARTED: 0,
}

summary_list = []
summary_list: list[SummaryEntryDict] = []
for segment in segments:
summary = summary_map.get(segment.id)
if summary:
@@ -1438,8 +1454,8 @@ class SummaryIndexService:
}
)

return {
"total_segments": total_segments,
"summary_status": status_counts,
"summaries": summary_list,
}
return DocumentSummaryStatusDetailDict(
total_segments=total_segments,
summary_status=cast(dict[str, int], status_counts),
summaries=summary_list,
)
@@ -2,6 +2,7 @@ import uuid

import sqlalchemy as sa
from flask_login import current_user
from pydantic import BaseModel, Field
from sqlalchemy import func, select
from werkzeug.exceptions import NotFound

@@ -11,6 +12,28 @@ from models.enums import TagType
from models.model import App, Tag, TagBinding


class SaveTagPayload(BaseModel):
name: str = Field(min_length=1, max_length=50)
type: TagType


class UpdateTagPayload(BaseModel):
name: str = Field(min_length=1, max_length=50)
type: TagType


class TagBindingCreatePayload(BaseModel):
tag_ids: list[str]
target_id: str
type: TagType


class TagBindingDeletePayload(BaseModel):
tag_id: str
target_id: str
type: TagType


class TagService:
@staticmethod
def get_tags(tag_type: str, current_tenant_id: str, keyword: str | None = None):
@@ -78,12 +101,12 @@ class TagService:
return tags or []

@staticmethod
def save_tags(args: dict) -> Tag:
if TagService.get_tag_by_tag_name(args["type"], current_user.current_tenant_id, args["name"]):
def save_tags(payload: SaveTagPayload) -> Tag:
if TagService.get_tag_by_tag_name(payload.type, current_user.current_tenant_id, payload.name):
raise ValueError("Tag name already exists")
tag = Tag(
name=args["name"],
type=TagType(args["type"]),
name=payload.name,
type=TagType(payload.type),
created_by=current_user.id,
tenant_id=current_user.current_tenant_id,
)
@@ -93,13 +116,24 @@ class TagService:
return tag

@staticmethod
def update_tags(args: dict, tag_id: str) -> Tag:
if TagService.get_tag_by_tag_name(args.get("type", ""), current_user.current_tenant_id, args.get("name", "")):
raise ValueError("Tag name already exists")
def update_tags(payload: UpdateTagPayload, tag_id: str) -> Tag:
tag = db.session.scalar(select(Tag).where(Tag.id == tag_id).limit(1))
if not tag:
raise NotFound("Tag not found")
tag.name = args["name"]
if payload.name != tag.name:
existing = db.session.scalar(
select(Tag)
.where(
Tag.name == payload.name,
Tag.tenant_id == current_user.current_tenant_id,
Tag.type == tag.type,
Tag.id != tag_id,
)
.limit(1)
)
if existing:
raise ValueError("Tag name already exists")
tag.name = payload.name
db.session.commit()
return tag

@@ -122,21 +156,19 @@ class TagService:
db.session.commit()

@staticmethod
def save_tag_binding(args):
# check if target exists
TagService.check_target_exists(args["type"], args["target_id"])
# save tag binding
for tag_id in args["tag_ids"]:
def save_tag_binding(payload: TagBindingCreatePayload):
TagService.check_target_exists(payload.type, payload.target_id)
for tag_id in payload.tag_ids:
tag_binding = db.session.scalar(
select(TagBinding)
.where(TagBinding.tag_id == tag_id, TagBinding.target_id == args["target_id"])
.where(TagBinding.tag_id == tag_id, TagBinding.target_id == payload.target_id)
.limit(1)
)
if tag_binding:
continue
new_tag_binding = TagBinding(
tag_id=tag_id,
target_id=args["target_id"],
target_id=payload.target_id,
tenant_id=current_user.current_tenant_id,
created_by=current_user.id,
)
@@ -144,17 +176,15 @@ class TagService:
db.session.commit()

@staticmethod
def delete_tag_binding(args):
# check if target exists
TagService.check_target_exists(args["type"], args["target_id"])
# delete tag binding
tag_bindings = db.session.scalar(
def delete_tag_binding(payload: TagBindingDeletePayload):
TagService.check_target_exists(payload.type, payload.target_id)
tag_binding = db.session.scalar(
select(TagBinding)
.where(TagBinding.target_id == args["target_id"], TagBinding.tag_id == args["tag_id"])
.where(TagBinding.target_id == payload.target_id, TagBinding.tag_id == payload.tag_id)
.limit(1)
)
if tag_bindings:
db.session.delete(tag_bindings)
if tag_binding:
db.session.delete(tag_binding)
db.session.commit()

@staticmethod
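TagService's untyped dict arguments become Pydantic models, so field presence and length limits are validated at the boundary and the service body uses attribute access instead of args["..."] indexing. An abbreviated sketch, assuming Pydantic v2 (the type field is simplified to str here; the real code uses the TagType enum):

from pydantic import BaseModel, Field, ValidationError


class SaveTagPayload(BaseModel):  # mirrors the payload model added above
    name: str = Field(min_length=1, max_length=50)
    type: str


def save_tag(payload: SaveTagPayload) -> str:
    # Constraints were enforced at construction, so no manual checks here.
    return f"{payload.type}:{payload.name}"


print(save_tag(SaveTagPayload(name="docs", type="knowledge")))
try:
    SaveTagPayload(name="", type="knowledge")  # rejected: name too short
except ValidationError as exc:
    print(exc.error_count())  # 1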
@@ -285,7 +285,7 @@ class MCPToolManageService:

# Batch query all users to avoid N+1 problem
user_ids = {provider.user_id for provider in mcp_providers}
users = self._session.query(Account).where(Account.id.in_(user_ids)).all()
users = self._session.scalars(select(Account).where(Account.id.in_(user_ids))).all()
user_name_map = {user.id: user.name for user in users}

return [
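This hunk keeps the single IN (...) batch query (one round trip for all user ids instead of one per provider) and only switches to the 2.0-style scalars() call, which unwraps single-entity rows into ORM objects. Minimal sketch (illustrative Account model, not the project's):

from sqlalchemy import String, create_engine, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class Account(Base):  # illustrative model for the sketch only
    __tablename__ = "accounts"
    id: Mapped[str] = mapped_column(String(36), primary_key=True)
    name: Mapped[str] = mapped_column(String(255))


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    user_ids = {"u1", "u2"}
    # One query for all ids; scalars() yields Account objects directly.
    users = session.scalars(select(Account).where(Account.id.in_(user_ids))).all()
    user_name_map = {u.id: u.name for u in users}
    print(user_name_map)  # {} in this empty database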
@@ -1,4 +1,3 @@
import json
import logging
from collections.abc import Mapping
from typing import Any, Union
@@ -21,6 +20,7 @@ from core.tools.entities.tool_entities import (
ApiProviderAuthType,
ToolParameter,
ToolProviderType,
emoji_icon_adapter,
)
from core.tools.plugin_tool.provider import PluginToolProviderController
from core.tools.utils.encryption import create_provider_encrypter, create_tool_provider_encrypter
@@ -53,11 +53,14 @@ class ToolTransformService:
elif provider_type in {ToolProviderType.API, ToolProviderType.WORKFLOW}:
try:
if isinstance(icon, str):
return json.loads(icon)
return icon
except (json.JSONDecodeError, ValueError):
parsed = emoji_icon_adapter.validate_json(icon)
return {"background": parsed["background"], "content": parsed["content"]}
return {"background": icon["background"], "content": icon["content"]}
except (ValueError, ValidationError, KeyError):
return {"background": "#252525", "content": "\ud83d\ude01"}
elif provider_type == ToolProviderType.MCP:
if isinstance(icon, Mapping):
return {"background": icon.get("background", ""), "content": icon.get("content", "")}
return icon
return ""
@@ -3,12 +3,12 @@ import logging
from datetime import datetime

from graphon.model_runtime.utils.encoders import jsonable_encoder
from sqlalchemy import or_, select
from sqlalchemy import delete, or_, select
from sqlalchemy.orm import Session

from core.tools.__base.tool_provider import ToolProviderController
from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity
from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration
from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration, emoji_icon_adapter
from core.tools.tool_label_manager import ToolLabelManager
from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurationUtils
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
@@ -42,20 +42,22 @@ class WorkflowToolManageService:
labels: list[str] | None = None,
):
# check if the name is unique
existing_workflow_tool_provider = (
db.session.query(WorkflowToolProvider)
existing_workflow_tool_provider = db.session.scalar(
select(WorkflowToolProvider)
.where(
WorkflowToolProvider.tenant_id == tenant_id,
# name or app_id
or_(WorkflowToolProvider.name == name, WorkflowToolProvider.app_id == workflow_app_id),
)
.first()
.limit(1)
)

if existing_workflow_tool_provider is not None:
raise ValueError(f"Tool with name {name} or app_id {workflow_app_id} already exists")

app: App | None = db.session.query(App).where(App.id == workflow_app_id, App.tenant_id == tenant_id).first()
app: App | None = db.session.scalar(
select(App).where(App.id == workflow_app_id, App.tenant_id == tenant_id).limit(1)
)

if app is None:
raise ValueError(f"App {workflow_app_id} not found")
@@ -122,30 +124,30 @@ class WorkflowToolManageService:
:return: the updated tool
"""
# check if the name is unique
existing_workflow_tool_provider = (
db.session.query(WorkflowToolProvider)
existing_workflow_tool_provider = db.session.scalar(
select(WorkflowToolProvider)
.where(
WorkflowToolProvider.tenant_id == tenant_id,
WorkflowToolProvider.name == name,
WorkflowToolProvider.id != workflow_tool_id,
)
.first()
.limit(1)
)

if existing_workflow_tool_provider is not None:
raise ValueError(f"Tool with name {name} already exists")

workflow_tool_provider: WorkflowToolProvider | None = (
db.session.query(WorkflowToolProvider)
workflow_tool_provider: WorkflowToolProvider | None = db.session.scalar(
select(WorkflowToolProvider)
.where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id)
.first()
.limit(1)
)

if workflow_tool_provider is None:
raise ValueError(f"Tool {workflow_tool_id} not found")

app: App | None = (
db.session.query(App).where(App.id == workflow_tool_provider.app_id, App.tenant_id == tenant_id).first()
app: App | None = db.session.scalar(
select(App).where(App.id == workflow_tool_provider.app_id, App.tenant_id == tenant_id).limit(1)
)

if app is None:
@@ -234,9 +236,11 @@ class WorkflowToolManageService:
:param tenant_id: the tenant id
:param workflow_tool_id: the workflow tool id
"""
db.session.query(WorkflowToolProvider).where(
WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id
).delete()
db.session.execute(
delete(WorkflowToolProvider).where(
WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id
)
)

db.session.commit()

@@ -251,10 +255,10 @@ class WorkflowToolManageService:
:param workflow_tool_id: the workflow tool id
:return: the tool
"""
db_tool: WorkflowToolProvider | None = (
db.session.query(WorkflowToolProvider)
db_tool: WorkflowToolProvider | None = db.session.scalar(
select(WorkflowToolProvider)
.where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id)
.first()
.limit(1)
)
return cls._get_workflow_tool(tenant_id, db_tool)

@@ -267,10 +271,10 @@ class WorkflowToolManageService:
:param workflow_app_id: the workflow app id
:return: the tool
"""
db_tool: WorkflowToolProvider | None = (
db.session.query(WorkflowToolProvider)
db_tool: WorkflowToolProvider | None = db.session.scalar(
select(WorkflowToolProvider)
.where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.app_id == workflow_app_id)
.first()
.limit(1)
)
return cls._get_workflow_tool(tenant_id, db_tool)

@@ -284,8 +288,8 @@ class WorkflowToolManageService:
if db_tool is None:
raise ValueError("Tool not found")

workflow_app: App | None = (
db.session.query(App).where(App.id == db_tool.app_id, App.tenant_id == db_tool.tenant_id).first()
workflow_app: App | None = db.session.scalar(
select(App).where(App.id == db_tool.app_id, App.tenant_id == db_tool.tenant_id).limit(1)
)

if workflow_app is None:
@@ -309,7 +313,7 @@ class WorkflowToolManageService:
"label": db_tool.label,
"workflow_tool_id": db_tool.id,
"workflow_app_id": db_tool.app_id,
"icon": json.loads(db_tool.icon),
"icon": emoji_icon_adapter.validate_json(db_tool.icon),
"description": db_tool.description,
"parameters": jsonable_encoder(db_tool.parameter_configurations),
"output_schema": output_schema,
@@ -331,10 +335,10 @@ class WorkflowToolManageService:
:param workflow_tool_id: the workflow tool id
:return: the list of tools
"""
db_tool: WorkflowToolProvider | None = (
db.session.query(WorkflowToolProvider)
db_tool: WorkflowToolProvider | None = db.session.scalar(
select(WorkflowToolProvider)
.where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id)
.first()
.limit(1)
)

if db_tool is None:
@@ -3,7 +3,7 @@ import logging
import mimetypes
import secrets
from collections.abc import Callable, Mapping, Sequence
from typing import Any
from typing import Any, NotRequired, TypedDict

import orjson
from flask import request
@@ -51,6 +51,26 @@ logger = logging.getLogger(__name__)
_file_access_controller = DatabaseFileAccessController()


class RawWebhookDataDict(TypedDict):
method: str
headers: dict[str, str]
query_params: dict[str, str]
body: dict[str, Any]
files: dict[str, Any]


class ValidationResultDict(TypedDict):
valid: bool
error: NotRequired[str]


class WorkflowInputsDict(TypedDict):
webhook_data: RawWebhookDataDict
webhook_headers: dict[str, str]
webhook_query_params: dict[str, str]
webhook_body: dict[str, Any]


class WebhookService:
"""Service for handling webhook operations."""

@@ -146,7 +166,7 @@ class WebhookService:
@classmethod
def extract_and_validate_webhook_data(
cls, webhook_trigger: WorkflowWebhookTrigger, node_config: NodeConfigDict
) -> dict[str, Any]:
) -> RawWebhookDataDict:
"""Extract and validate webhook data in a single unified process.

Args:
@@ -166,7 +186,7 @@ class WebhookService:
node_data = WebhookData.model_validate(node_config["data"], from_attributes=True)
validation_result = cls._validate_http_metadata(raw_data, node_data)
if not validation_result["valid"]:
raise ValueError(validation_result["error"])
raise ValueError(validation_result.get("error", "Validation failed"))

# Process and validate data according to configuration
processed_data = cls._process_and_validate_data(raw_data, node_data)
@@ -174,7 +194,7 @@ class WebhookService:
return processed_data

@classmethod
def extract_webhook_data(cls, webhook_trigger: WorkflowWebhookTrigger) -> dict[str, Any]:
def extract_webhook_data(cls, webhook_trigger: WorkflowWebhookTrigger) -> RawWebhookDataDict:
"""Extract raw data from incoming webhook request without type conversion.

Args:
@@ -190,7 +210,7 @@ class WebhookService:
"""
cls._validate_content_length()

data = {
data: RawWebhookDataDict = {
"method": request.method,
"headers": dict(request.headers),
"query_params": dict(request.args),
@@ -224,7 +244,7 @@ class WebhookService:
return data

@classmethod
def _process_and_validate_data(cls, raw_data: dict[str, Any], node_data: WebhookData) -> dict[str, Any]:
def _process_and_validate_data(cls, raw_data: RawWebhookDataDict, node_data: WebhookData) -> RawWebhookDataDict:
"""Process and validate webhook data according to node configuration.

Args:
@@ -665,7 +685,7 @@ class WebhookService:
raise ValueError(f"Required header missing: {header_name}")

@classmethod
def _validate_http_metadata(cls, webhook_data: dict[str, Any], node_data: WebhookData) -> dict[str, Any]:
def _validate_http_metadata(cls, webhook_data: RawWebhookDataDict, node_data: WebhookData) -> ValidationResultDict:
"""Validate HTTP method and content-type.

Args:
@@ -709,7 +729,7 @@ class WebhookService:
return content_type.split(";")[0].strip()

@classmethod
def _validation_error(cls, error_message: str) -> dict[str, Any]:
def _validation_error(cls, error_message: str) -> ValidationResultDict:
"""Create a standard validation error response.

Args:
@@ -730,7 +750,7 @@ class WebhookService:
return False

@classmethod
def build_workflow_inputs(cls, webhook_data: dict[str, Any]) -> dict[str, Any]:
def build_workflow_inputs(cls, webhook_data: RawWebhookDataDict) -> WorkflowInputsDict:
"""Construct workflow inputs payload from webhook data.

Args:
@@ -748,7 +768,7 @@ class WebhookService:

@classmethod
def trigger_workflow_execution(
cls, webhook_trigger: WorkflowWebhookTrigger, webhook_data: dict[str, Any], workflow: Workflow
cls, webhook_trigger: WorkflowWebhookTrigger, webhook_data: RawWebhookDataDict, workflow: Workflow
) -> None:
"""Trigger workflow execution via AsyncWorkflowService.

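ValidationResultDict marks error as NotRequired, which is why the call site above switches from validation_result["error"] to .get(...) with a fallback: the key may legitimately be absent on success. A short sketch of the pattern (typing.NotRequired needs Python 3.11+, or typing_extensions on older versions):

from typing import NotRequired, TypedDict


class ValidationResult(TypedDict):
    valid: bool
    error: NotRequired[str]  # present only on failure


def check(method: str) -> ValidationResult:
    if method != "POST":
        return {"valid": False, "error": f"unsupported method: {method}"}
    return {"valid": True}  # omitting the NotRequired key is legal


result = check("GET")
if not result["valid"]:
    # .get() is the safe read because "error" may be absent.
    print(result.get("error", "Validation failed"))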
@@ -6,6 +6,7 @@ from sqlalchemy import delete, select
from core.model_manager import ModelInstance, ModelManager
from core.rag.datasource.keyword.keyword_factory import Keyword
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.entities import ParentMode
from core.rag.index_processor.constant.doc_type import DocType
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
@@ -15,7 +16,6 @@ from extensions.ext_database import db
from models import UploadFile
from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegment, SegmentAttachmentBinding
from models.dataset import Document as DatasetDocument
from services.entities.knowledge_entities.knowledge_entities import ParentMode

logger = logging.getLogger(__name__)

@@ -3,7 +3,7 @@ from __future__ import annotations
|
||||
import datetime
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
from typing import Any, NotRequired, TypedDict, cast
|
||||
|
||||
import httpx
|
||||
from flask_login import current_user
|
||||
@@ -126,6 +126,15 @@ class WebsiteCrawlStatusApiRequest:
|
||||
return cls(provider=provider, job_id=job_id)
|
||||
|
||||
|
||||
class CrawlStatusDict(TypedDict):
|
||||
status: str
|
||||
job_id: str
|
||||
total: int
|
||||
current: int
|
||||
data: list[Any]
|
||||
time_consuming: NotRequired[str | float]
|
||||
|
||||
|
||||
class WebsiteService:
    """Service class for website crawling operations using different providers."""

@@ -261,13 +270,13 @@ class WebsiteService:
        return {"status": "active", "job_id": response.json().get("data", {}).get("taskId")}

    @classmethod
    def get_crawl_status(cls, job_id: str, provider: str) -> dict[str, Any]:
    def get_crawl_status(cls, job_id: str, provider: str) -> CrawlStatusDict:
        """Get crawl status using string parameters."""
        api_request = WebsiteCrawlStatusApiRequest(provider=provider, job_id=job_id)
        return cls.get_crawl_status_typed(api_request)

    @classmethod
    def get_crawl_status_typed(cls, api_request: WebsiteCrawlStatusApiRequest) -> dict[str, Any]:
    def get_crawl_status_typed(cls, api_request: WebsiteCrawlStatusApiRequest) -> CrawlStatusDict:
        """Get crawl status using typed request."""
        api_key, config = cls._get_credentials_and_config(current_user.current_tenant_id, api_request.provider)

@@ -281,10 +290,10 @@ class WebsiteService:
            raise ValueError("Invalid provider")

    @classmethod
    def _get_firecrawl_status(cls, job_id: str, api_key: str, config: dict) -> dict[str, Any]:
    def _get_firecrawl_status(cls, job_id: str, api_key: str, config: dict) -> CrawlStatusDict:
        firecrawl_app = FirecrawlApp(api_key=api_key, base_url=config.get("base_url"))
        result: CrawlStatusResponse = firecrawl_app.check_crawl_status(job_id)
        crawl_status_data: dict[str, Any] = {
        crawl_status_data: CrawlStatusDict = {
            "status": result["status"],
            "job_id": job_id,
            "total": result["total"] or 0,
@@ -302,18 +311,18 @@ class WebsiteService:
        return crawl_status_data

    @classmethod
    def _get_watercrawl_status(cls, job_id: str, api_key: str, config: dict[str, Any]) -> dict[str, Any]:
        return dict(WaterCrawlProvider(api_key, config.get("base_url")).get_crawl_status(job_id))
    def _get_watercrawl_status(cls, job_id: str, api_key: str, config: dict[str, Any]) -> CrawlStatusDict:
        return cast(CrawlStatusDict, dict(WaterCrawlProvider(api_key, config.get("base_url")).get_crawl_status(job_id)))

    @classmethod
    def _get_jinareader_status(cls, job_id: str, api_key: str) -> dict[str, Any]:
    def _get_jinareader_status(cls, job_id: str, api_key: str) -> CrawlStatusDict:
        response = _adaptive_http_client.post(
            "https://adaptivecrawlstatus-kir3wx7b3a-uc.a.run.app",
            headers={"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"},
            json={"taskId": job_id},
        )
        data = response.json().get("data", {})
        crawl_status_data = {
        crawl_status_data: CrawlStatusDict = {
            "status": data.get("status", "active"),
            "job_id": job_id,
            "total": len(data.get("urls", [])),
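The _get_watercrawl_status change leans on typing.cast, which asserts a type to the checker but performs no runtime validation. A minimal sketch of that trade-off, with a hypothetical narrow() helper standing in for the provider call:

from typing import Any, TypedDict, cast

class CrawlStatusDict(TypedDict):
    status: str
    job_id: str
    total: int
    current: int
    data: list[Any]

def narrow(raw: dict[str, Any]) -> CrawlStatusDict:
    # cast() is a no-op at runtime: if the provider returns an unexpected
    # shape, the mismatch only surfaces when a key is actually accessed.
    return cast(CrawlStatusDict, raw)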
@@ -170,34 +170,38 @@ class WorkflowConverter:

        graph = self._append_node(graph, llm_node)

        if new_app_mode == AppMode.WORKFLOW:
            # convert to end node by app mode
            end_node = self._convert_to_end_node()
            graph = self._append_node(graph, end_node)
        else:
            answer_node = self._convert_to_answer_node()
            graph = self._append_node(graph, answer_node)

        app_model_config_dict = app_config.app_model_config_dict

        # features
        if new_app_mode == AppMode.ADVANCED_CHAT:
            features = {
                "opening_statement": app_model_config_dict.get("opening_statement"),
                "suggested_questions": app_model_config_dict.get("suggested_questions"),
                "suggested_questions_after_answer": app_model_config_dict.get("suggested_questions_after_answer"),
                "speech_to_text": app_model_config_dict.get("speech_to_text"),
                "text_to_speech": app_model_config_dict.get("text_to_speech"),
                "file_upload": app_model_config_dict.get("file_upload"),
                "sensitive_word_avoidance": app_model_config_dict.get("sensitive_word_avoidance"),
                "retriever_resource": app_model_config_dict.get("retriever_resource"),
            }
        else:
            features = {
                "text_to_speech": app_model_config_dict.get("text_to_speech"),
                "file_upload": app_model_config_dict.get("file_upload"),
                "sensitive_word_avoidance": app_model_config_dict.get("sensitive_word_avoidance"),
            }
        match new_app_mode:
            case AppMode.WORKFLOW:
                end_node = self._convert_to_end_node()
                graph = self._append_node(graph, end_node)
                features = {
                    "text_to_speech": app_model_config_dict.get("text_to_speech"),
                    "file_upload": app_model_config_dict.get("file_upload"),
                    "sensitive_word_avoidance": app_model_config_dict.get("sensitive_word_avoidance"),
                }
            case AppMode.ADVANCED_CHAT:
                answer_node = self._convert_to_answer_node()
                graph = self._append_node(graph, answer_node)
                features = {
                    "opening_statement": app_model_config_dict.get("opening_statement"),
                    "suggested_questions": app_model_config_dict.get("suggested_questions"),
                    "suggested_questions_after_answer": app_model_config_dict.get("suggested_questions_after_answer"),
                    "speech_to_text": app_model_config_dict.get("speech_to_text"),
                    "text_to_speech": app_model_config_dict.get("text_to_speech"),
                    "file_upload": app_model_config_dict.get("file_upload"),
                    "sensitive_word_avoidance": app_model_config_dict.get("sensitive_word_avoidance"),
                    "retriever_resource": app_model_config_dict.get("retriever_resource"),
                }
            case _:
                answer_node = self._convert_to_answer_node()
                graph = self._append_node(graph, answer_node)
                features = {
                    "text_to_speech": app_model_config_dict.get("text_to_speech"),
                    "file_upload": app_model_config_dict.get("file_upload"),
                    "sensitive_word_avoidance": app_model_config_dict.get("sensitive_word_avoidance"),
                }

        # create workflow record
        workflow = Workflow(
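The refactor above folds two separate if/else blocks (one choosing the terminal node, one choosing the feature set) into a single match statement, so each app mode is handled exactly once. A condensed sketch of the dispatch shape, with a hypothetical terminal_node_for() and a minimal stand-in enum:

from enum import Enum

class AppMode(Enum):  # minimal stand-in for the real enum
    WORKFLOW = "workflow"
    ADVANCED_CHAT = "advanced-chat"
    CHAT = "chat"

def terminal_node_for(mode: AppMode) -> str:
    match mode:
        case AppMode.WORKFLOW:
            return "end"     # workflow apps terminate in an end node
        case AppMode.ADVANCED_CHAT:
            return "answer"
        case _:
            return "answer"  # other chat-style modes also get an answer node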
@@ -220,19 +224,23 @@ class WorkflowConverter:
    def _convert_to_app_config(self, app_model: App, app_model_config: AppModelConfig) -> EasyUIBasedAppConfig:
        app_mode_enum = AppMode.value_of(app_model.mode)
        app_config: EasyUIBasedAppConfig
        if app_mode_enum == AppMode.AGENT_CHAT or app_model.is_agent:
            app_model.mode = AppMode.AGENT_CHAT
            app_config = AgentChatAppConfigManager.get_app_config(
                app_model=app_model, app_model_config=app_model_config
            )
        elif app_mode_enum == AppMode.CHAT:
            app_config = ChatAppConfigManager.get_app_config(app_model=app_model, app_model_config=app_model_config)
        elif app_mode_enum == AppMode.COMPLETION:
            app_config = CompletionAppConfigManager.get_app_config(
                app_model=app_model, app_model_config=app_model_config
            )
        else:
            raise ValueError("Invalid app mode")
        effective_mode = (
            AppMode.AGENT_CHAT if app_model.is_agent and app_mode_enum != AppMode.AGENT_CHAT else app_mode_enum
        )
        match effective_mode:
            case AppMode.AGENT_CHAT:
                app_model.mode = AppMode.AGENT_CHAT
                app_config = AgentChatAppConfigManager.get_app_config(
                    app_model=app_model, app_model_config=app_model_config
                )
            case AppMode.CHAT:
                app_config = ChatAppConfigManager.get_app_config(app_model=app_model, app_model_config=app_model_config)
            case AppMode.COMPLETION:
                app_config = CompletionAppConfigManager.get_app_config(
                    app_model=app_model, app_model_config=app_model_config
                )
            case _:
                raise ValueError("Invalid app mode")

        return app_config
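This hunk preserves the old behavior that an agent-flagged app is dispatched as AGENT_CHAT even when its stored mode says otherwise, but computes that decision once before matching instead of burying it in the first branch condition. A reduced sketch of the normalization, with a stand-in enum:

from enum import Enum

class AppMode(Enum):  # minimal stand-in for the real enum
    AGENT_CHAT = "agent-chat"
    CHAT = "chat"

def effective_mode(stored_mode: AppMode, is_agent: bool) -> AppMode:
    # Equivalent to the old `app_mode_enum == AppMode.AGENT_CHAT or
    # app_model.is_agent` check, expressed as a single normalization step.
    return AppMode.AGENT_CHAT if is_agent and stored_mode != AppMode.AGENT_CHAT else stored_mode

assert effective_mode(AppMode.CHAT, is_agent=True) is AppMode.AGENT_CHAT
assert effective_mode(AppMode.CHAT, is_agent=False) is AppMode.CHAT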
@@ -38,6 +38,7 @@ from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom, build_dify_run_context
from core.app.file_access import DatabaseFileAccessController
from core.entities import PluginCredentialType
from core.plugin.impl.model_runtime_factory import create_plugin_model_assembly, create_plugin_provider_manager
from core.repositories import DifyCoreRepositoryFactory
from core.repositories.human_input_repository import FormCreateParams, HumanInputFormRepositoryImpl
@@ -66,7 +67,6 @@ from models.tools import WorkflowToolProvider
from models.workflow import Workflow, WorkflowNodeExecutionModel, WorkflowNodeExecutionTriggeredFrom, WorkflowType
from repositories.factory import DifyAPIRepositoryFactory
from services.billing_service import BillingService
from services.enterprise.plugin_manager_service import PluginCredentialType
from services.errors.app import (
    IsDraftWorkflowError,
    TriggerNodeLimitExceededError,
@@ -635,7 +635,7 @@ class WorkflowService:
        # If we can't determine the status, assume load balancing is not enabled
        return False

    def _get_load_balancing_configs(self, tenant_id: str, provider: str, model_name: str) -> list[dict]:
    def _get_load_balancing_configs(self, tenant_id: str, provider: str, model_name: str) -> list[dict[str, Any]]:
        """
        Get all load balancing configurations for a model.

@@ -659,7 +659,7 @@ class WorkflowService:
        _, custom_configs = model_load_balancing_service.get_load_balancing_configs(
            tenant_id=tenant_id, provider=provider, model=model_name, model_type="llm", config_from="custom-model"
        )
        all_configs = configs + custom_configs
        all_configs = cast(list[dict[str, Any]], configs) + cast(list[dict[str, Any]], custom_configs)

        return [config for config in all_configs if config.get("credential_id")]
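The inferred types of configs and custom_configs are not visible in this hunk; the cast calls presumably align both operands with the new list[dict[str, Any]] return type before concatenation. A sketch of the pattern under that assumption:

from typing import Any, cast

def merge_configs(configs: list[dict], custom_configs: list[dict]) -> list[dict[str, Any]]:
    # cast() only informs the type checker; the lists are concatenated
    # unchanged at runtime.
    all_configs = cast(list[dict[str, Any]], configs) + cast(list[dict[str, Any]], custom_configs)
    return [config for config in all_configs if config.get("credential_id")]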
@@ -834,7 +834,7 @@ class WorkflowService:
        if workflow_node_execution is None:
            raise ValueError(f"WorkflowNodeExecution with id {node_execution.id} not found after saving")

        with Session(db.engine) as session:
        with sessionmaker(db.engine).begin() as session:
            outputs = workflow_node_execution.load_full_outputs(session, storage)

        with Session(bind=db.engine) as session, session.begin():
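The switch from Session(db.engine) to sessionmaker(db.engine).begin() moves transaction handling into the context manager: .begin() opens a transaction, commits on clean exit, and rolls back if the block raises. A minimal side-by-side sketch with a placeholder in-memory engine:

from sqlalchemy import create_engine
from sqlalchemy.orm import Session, sessionmaker

engine = create_engine("sqlite://")  # placeholder engine for the sketch

# Plain Session: the caller must commit (and handle rollback) explicitly.
with Session(engine) as session:
    session.commit()

# sessionmaker().begin(): transaction lifecycle is managed automatically.
with sessionmaker(engine).begin() as session:
    pass  # commits here on clean exit, rolls back on exception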
@@ -1417,16 +1417,17 @@ class WorkflowService:
        self._validate_human_input_node_data(node_data)

    def validate_features_structure(self, app_model: App, features: dict):
        if app_model.mode == AppMode.ADVANCED_CHAT:
            return AdvancedChatAppConfigManager.config_validate(
                tenant_id=app_model.tenant_id, config=features, only_structure_validate=True
            )
        elif app_model.mode == AppMode.WORKFLOW:
            return WorkflowAppConfigManager.config_validate(
                tenant_id=app_model.tenant_id, config=features, only_structure_validate=True
            )
        else:
            raise ValueError(f"Invalid app mode: {app_model.mode}")
        match app_model.mode:
            case AppMode.ADVANCED_CHAT:
                return AdvancedChatAppConfigManager.config_validate(
                    tenant_id=app_model.tenant_id, config=features, only_structure_validate=True
                )
            case AppMode.WORKFLOW:
                return WorkflowAppConfigManager.config_validate(
                    tenant_id=app_model.tenant_id, config=features, only_structure_validate=True
                )
            case _:
                raise ValueError(f"Invalid app mode: {app_model.mode}")

    def _validate_human_input_node_data(self, node_data: dict) -> None:
        """