Merge branch 'feat/new-biliing-quota' into deploy/dev

hj24
2026-04-08 15:02:50 +08:00
929 changed files with 23722 additions and 15509 deletions

View File

@@ -8,8 +8,8 @@ from hashlib import sha256
from typing import Any, TypedDict, cast
from pydantic import BaseModel, TypeAdapter
from sqlalchemy import func, select
from sqlalchemy.orm import Session
from sqlalchemy import delete, func, select, update
from sqlalchemy.orm import Session, sessionmaker
class InvitationData(TypedDict):
@@ -83,6 +83,12 @@ from tasks.mail_reset_password_task import (
logger = logging.getLogger(__name__)
class InvitationDetailDict(TypedDict):
account: Account
data: InvitationData
tenant: Tenant
def _try_join_enterprise_default_workspace(account_id: str) -> None:
"""Best-effort join to enterprise default workspace."""
if not dify_config.ENTERPRISE_ENABLED:
@@ -144,22 +150,26 @@ class AccountService:
@staticmethod
def load_user(user_id: str) -> None | Account:
account = db.session.query(Account).filter_by(id=user_id).first()
account = db.session.get(Account, user_id)
if not account:
return None
if account.status == AccountStatus.BANNED:
raise Unauthorized("Account is banned.")
current_tenant = db.session.query(TenantAccountJoin).filter_by(account_id=account.id, current=True).first()
current_tenant = db.session.scalar(
select(TenantAccountJoin)
.where(TenantAccountJoin.account_id == account.id, TenantAccountJoin.current == True)
.limit(1)
)
if current_tenant:
account.set_tenant_id(current_tenant.tenant_id)
else:
available_ta = (
db.session.query(TenantAccountJoin)
.filter_by(account_id=account.id)
available_ta = db.session.scalar(
select(TenantAccountJoin)
.where(TenantAccountJoin.account_id == account.id)
.order_by(TenantAccountJoin.id.asc())
.first()
.limit(1)
)
if not available_ta:
return None
@@ -195,7 +205,7 @@ class AccountService:
def authenticate(email: str, password: str, invite_token: str | None = None) -> Account:
"""authenticate account with email and password"""
account = db.session.query(Account).filter_by(email=email).first()
account = db.session.scalar(select(Account).where(Account.email == email).limit(1))
if not account:
raise AccountPasswordError("Invalid email or password.")
@@ -371,8 +381,10 @@ class AccountService:
"""Link account integrate"""
try:
# Query whether there is an existing binding record for the same provider
account_integrate: AccountIntegrate | None = (
db.session.query(AccountIntegrate).filter_by(account_id=account.id, provider=provider).first()
account_integrate: AccountIntegrate | None = db.session.scalar(
select(AccountIntegrate)
.where(AccountIntegrate.account_id == account.id, AccountIntegrate.provider == provider)
.limit(1)
)
if account_integrate:
@@ -416,7 +428,9 @@ class AccountService:
def update_account_email(account: Account, email: str) -> Account:
"""Update account email"""
account.email = email
account_integrate = db.session.query(AccountIntegrate).filter_by(account_id=account.id).first()
account_integrate = db.session.scalar(
select(AccountIntegrate).where(AccountIntegrate.account_id == account.id).limit(1)
)
if account_integrate:
db.session.delete(account_integrate)
db.session.add(account)
@@ -818,7 +832,7 @@ class AccountService:
)
)
account = db.session.query(Account).where(Account.email == email).first()
account = db.session.scalar(select(Account).where(Account.email == email).limit(1))
if not account:
return None
@@ -1018,7 +1032,7 @@ class AccountService:
@staticmethod
def check_email_unique(email: str) -> bool:
return db.session.query(Account).filter_by(email=email).first() is None
return db.session.scalar(select(Account).where(Account.email == email).limit(1)) is None
class TenantService:
@@ -1061,11 +1075,11 @@ class TenantService:
@staticmethod
def create_owner_tenant_if_not_exist(account: Account, name: str | None = None, is_setup: bool | None = False):
"""Check if user have a workspace or not"""
available_ta = (
db.session.query(TenantAccountJoin)
.filter_by(account_id=account.id)
available_ta = db.session.scalar(
select(TenantAccountJoin)
.where(TenantAccountJoin.account_id == account.id)
.order_by(TenantAccountJoin.id.asc())
.first()
.limit(1)
)
if available_ta:
@@ -1096,7 +1110,11 @@ class TenantService:
logger.error("Tenant %s has already an owner.", tenant.id)
raise Exception("Tenant already has an owner.")
ta = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=account.id).first()
ta = db.session.scalar(
select(TenantAccountJoin)
.where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == account.id)
.limit(1)
)
if ta:
ta.role = TenantAccountRole(role)
else:
@@ -1111,11 +1129,12 @@ class TenantService:
@staticmethod
def get_join_tenants(account: Account) -> list[Tenant]:
"""Get account join tenants"""
return (
db.session.query(Tenant)
.join(TenantAccountJoin, Tenant.id == TenantAccountJoin.tenant_id)
.where(TenantAccountJoin.account_id == account.id, Tenant.status == TenantStatus.NORMAL)
.all()
return list(
db.session.scalars(
select(Tenant)
.join(TenantAccountJoin, Tenant.id == TenantAccountJoin.tenant_id)
.where(TenantAccountJoin.account_id == account.id, Tenant.status == TenantStatus.NORMAL)
).all()
)
@staticmethod
@@ -1125,7 +1144,11 @@ class TenantService:
if not tenant:
raise TenantNotFoundError("Tenant not found.")
ta = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=account.id).first()
ta = db.session.scalar(
select(TenantAccountJoin)
.where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == account.id)
.limit(1)
)
if ta:
tenant.role = ta.role
else:
@@ -1140,23 +1163,25 @@ class TenantService:
if tenant_id is None:
raise ValueError("Tenant ID must be provided.")
tenant_account_join = (
db.session.query(TenantAccountJoin)
tenant_account_join = db.session.scalar(
select(TenantAccountJoin)
.join(Tenant, TenantAccountJoin.tenant_id == Tenant.id)
.where(
TenantAccountJoin.account_id == account.id,
TenantAccountJoin.tenant_id == tenant_id,
Tenant.status == TenantStatus.NORMAL,
)
.first()
.limit(1)
)
if not tenant_account_join:
raise AccountNotLinkTenantError("Tenant not found or account is not a member of the tenant.")
else:
db.session.query(TenantAccountJoin).where(
TenantAccountJoin.account_id == account.id, TenantAccountJoin.tenant_id != tenant_id
).update({"current": False})
db.session.execute(
update(TenantAccountJoin)
.where(TenantAccountJoin.account_id == account.id, TenantAccountJoin.tenant_id != tenant_id)
.values(current=False)
)
tenant_account_join.current = True
# Set the current tenant for the account
account.set_tenant_id(tenant_account_join.tenant_id)
@@ -1165,8 +1190,8 @@ class TenantService:
@staticmethod
def get_tenant_members(tenant: Tenant) -> list[Account]:
"""Get tenant members"""
query = (
db.session.query(Account, TenantAccountJoin.role)
stmt = (
select(Account, TenantAccountJoin.role)
.select_from(Account)
.join(TenantAccountJoin, Account.id == TenantAccountJoin.account_id)
.where(TenantAccountJoin.tenant_id == tenant.id)
@@ -1175,7 +1200,7 @@ class TenantService:
# Initialize an empty list to store the updated accounts
updated_accounts = []
for account, role in query:
for account, role in db.session.execute(stmt):
account.role = role
updated_accounts.append(account)
@@ -1184,8 +1209,8 @@ class TenantService:
@staticmethod
def get_dataset_operator_members(tenant: Tenant) -> list[Account]:
"""Get dataset admin members"""
query = (
db.session.query(Account, TenantAccountJoin.role)
stmt = (
select(Account, TenantAccountJoin.role)
.select_from(Account)
.join(TenantAccountJoin, Account.id == TenantAccountJoin.account_id)
.where(TenantAccountJoin.tenant_id == tenant.id)
@@ -1195,7 +1220,7 @@ class TenantService:
# Initialize an empty list to store the updated accounts
updated_accounts = []
for account, role in query:
for account, role in db.session.execute(stmt):
account.role = role
updated_accounts.append(account)
@@ -1208,26 +1233,31 @@ class TenantService:
raise ValueError("all roles must be TenantAccountRole")
return (
db.session.query(TenantAccountJoin)
.where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.role.in_([role.value for role in roles]))
.first()
db.session.scalar(
select(TenantAccountJoin)
.where(
TenantAccountJoin.tenant_id == tenant.id,
TenantAccountJoin.role.in_([role.value for role in roles]),
)
.limit(1)
)
is not None
)
@staticmethod
def get_user_role(account: Account, tenant: Tenant) -> TenantAccountRole | None:
"""Get the role of the current account for a given tenant"""
join = (
db.session.query(TenantAccountJoin)
join = db.session.scalar(
select(TenantAccountJoin)
.where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == account.id)
.first()
.limit(1)
)
return TenantAccountRole(join.role) if join else None
@staticmethod
def get_tenant_count() -> int:
"""Get tenant count"""
return cast(int, db.session.query(func.count(Tenant.id)).scalar())
return cast(int, db.session.scalar(select(func.count(Tenant.id))))
@staticmethod
def check_member_permission(tenant: Tenant, operator: Account, member: Account | None, action: str):
@@ -1244,7 +1274,11 @@ class TenantService:
if operator.id == member.id:
raise CannotOperateSelfError("Cannot operate self.")
ta_operator = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=operator.id).first()
ta_operator = db.session.scalar(
select(TenantAccountJoin)
.where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == operator.id)
.limit(1)
)
if not ta_operator or ta_operator.role not in perms[action]:
raise NoPermissionError(f"No permission to {action} member.")
@@ -1262,7 +1296,11 @@ class TenantService:
TenantService.check_member_permission(tenant, operator, account, "remove")
ta = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=account.id).first()
ta = db.session.scalar(
select(TenantAccountJoin)
.where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == account.id)
.limit(1)
)
if not ta:
raise MemberNotInTenantError("Member not in tenant.")
@@ -1277,7 +1315,12 @@ class TenantService:
should_delete_account = False
if account.status == AccountStatus.PENDING:
# autoflush flushes ta deletion before this query, so 0 means no remaining joins
remaining_joins = db.session.query(TenantAccountJoin).filter_by(account_id=account_id).count()
remaining_joins = (
db.session.scalar(
select(func.count(TenantAccountJoin.id)).where(TenantAccountJoin.account_id == account_id)
)
or 0
)
if remaining_joins == 0:
db.session.delete(account)
should_delete_account = True
@@ -1312,8 +1355,10 @@ class TenantService:
"""Update member role"""
TenantService.check_member_permission(tenant, operator, member, "update")
target_member_join = (
db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=member.id).first()
target_member_join = db.session.scalar(
select(TenantAccountJoin)
.where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == member.id)
.limit(1)
)
if not target_member_join:
@@ -1324,8 +1369,10 @@ class TenantService:
if new_role == "owner":
# Find the current owner and change their role to 'admin'
current_owner_join = (
db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, role="owner").first()
current_owner_join = db.session.scalar(
select(TenantAccountJoin)
.where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.role == "owner")
.limit(1)
)
if current_owner_join:
current_owner_join.role = TenantAccountRole.ADMIN
@@ -1384,10 +1431,10 @@ class RegisterService:
db.session.add(dify_setup)
db.session.commit()
except Exception as e:
db.session.query(DifySetup).delete()
db.session.query(TenantAccountJoin).delete()
db.session.query(Account).delete()
db.session.query(Tenant).delete()
db.session.execute(delete(DifySetup))
db.session.execute(delete(TenantAccountJoin))
db.session.execute(delete(Account))
db.session.execute(delete(Tenant))
db.session.commit()
logger.exception("Setup account failed, email: %s, name: %s", email, name)
@@ -1469,7 +1516,7 @@ class RegisterService:
check_workspace_member_invite_permission(tenant.id)
with Session(db.engine) as session:
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
account = AccountService.get_account_by_email_with_case_fallback(email, session=session)
if not account:
@@ -1488,7 +1535,11 @@ class RegisterService:
TenantService.switch_tenant(account, tenant.id)
else:
TenantService.check_member_permission(tenant, inviter, account, "add")
ta = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=account.id).first()
ta = db.session.scalar(
select(TenantAccountJoin)
.where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == account.id)
.limit(1)
)
if not ta:
TenantService.create_tenant_member(tenant, account, role)
@@ -1540,26 +1591,23 @@ class RegisterService:
@classmethod
def get_invitation_if_token_valid(
cls, workspace_id: str | None, email: str | None, token: str
) -> dict[str, Any] | None:
) -> InvitationDetailDict | None:
invitation_data = cls.get_invitation_by_token(token, workspace_id, email)
if not invitation_data:
return None
tenant = (
db.session.query(Tenant)
.where(Tenant.id == invitation_data["workspace_id"], Tenant.status == "normal")
.first()
tenant = db.session.scalar(
select(Tenant).where(Tenant.id == invitation_data["workspace_id"], Tenant.status == "normal").limit(1)
)
if not tenant:
return None
tenant_account = (
db.session.query(Account, TenantAccountJoin.role)
tenant_account = db.session.execute(
select(Account, TenantAccountJoin.role)
.join(TenantAccountJoin, Account.id == TenantAccountJoin.account_id)
.where(Account.email == invitation_data["email"], TenantAccountJoin.tenant_id == tenant.id)
.first()
)
).first()
if not tenant_account:
return None
@@ -1605,7 +1653,7 @@ class RegisterService:
@classmethod
def get_invitation_with_case_fallback(
cls, workspace_id: str | None, email: str | None, token: str
) -> dict[str, Any] | None:
) -> InvitationDetailDict | None:
invitation = cls.get_invitation_if_token_valid(workspace_id, email, token)
if invitation or not email or email == email.lower():
return invitation
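Most hunks in this file apply the same SQLAlchemy 1.x-to-2.0 rewrite: legacy Query chains like db.session.query(Model).filter_by(...).first() become db.session.scalar(select(Model).where(...).limit(1)), and primary-key lookups become db.session.get(). A minimal runnable sketch of both patterns, using a toy User model and an in-memory SQLite engine (hypothetical, not part of this commit):

from sqlalchemy import String, create_engine, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column

class Base(DeclarativeBase):
    pass

class User(Base):
    __tablename__ = "users"
    id: Mapped[int] = mapped_column(primary_key=True)
    email: Mapped[str] = mapped_column(String(255))

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add(User(id=1, email="a@example.com"))
    session.commit()

    # primary-key lookup: session.get() replaces query(...).filter_by(id=...).first()
    by_pk = session.get(User, 1)

    # arbitrary filter: scalar() returns the first ORM object or None,
    # replacing query(...).filter_by(email=...).first()
    by_email = session.scalar(select(User).where(User.email == "a@example.com").limit(1))

    assert by_pk is by_email  # the identity map yields the same instance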

View File

@@ -32,22 +32,33 @@ class AdvancedPromptTemplateService:
def get_common_prompt(cls, app_mode: str, model_mode: str, has_context: str):
context_prompt = copy.deepcopy(CONTEXT)
if app_mode == AppMode.CHAT:
if model_mode == "completion":
return cls.get_completion_prompt(
copy.deepcopy(CHAT_APP_COMPLETION_PROMPT_CONFIG), has_context, context_prompt
)
elif model_mode == "chat":
return cls.get_chat_prompt(copy.deepcopy(CHAT_APP_CHAT_PROMPT_CONFIG), has_context, context_prompt)
elif app_mode == AppMode.COMPLETION:
if model_mode == "completion":
return cls.get_completion_prompt(
copy.deepcopy(COMPLETION_APP_COMPLETION_PROMPT_CONFIG), has_context, context_prompt
)
elif model_mode == "chat":
return cls.get_chat_prompt(
copy.deepcopy(COMPLETION_APP_CHAT_PROMPT_CONFIG), has_context, context_prompt
)
match app_mode:
case AppMode.CHAT:
match model_mode:
case "completion":
return cls.get_completion_prompt(
copy.deepcopy(CHAT_APP_COMPLETION_PROMPT_CONFIG), has_context, context_prompt
)
case "chat":
return cls.get_chat_prompt(
copy.deepcopy(CHAT_APP_CHAT_PROMPT_CONFIG), has_context, context_prompt
)
case _:
pass
case AppMode.COMPLETION:
match model_mode:
case "completion":
return cls.get_completion_prompt(
copy.deepcopy(COMPLETION_APP_COMPLETION_PROMPT_CONFIG), has_context, context_prompt
)
case "chat":
return cls.get_chat_prompt(
copy.deepcopy(COMPLETION_APP_CHAT_PROMPT_CONFIG), has_context, context_prompt
)
case _:
pass
case _:
pass
# default return empty dict
return {}
@@ -73,25 +84,38 @@ class AdvancedPromptTemplateService:
def get_baichuan_prompt(cls, app_mode: str, model_mode: str, has_context: str):
baichuan_context_prompt = copy.deepcopy(BAICHUAN_CONTEXT)
if app_mode == AppMode.CHAT:
if model_mode == "completion":
return cls.get_completion_prompt(
copy.deepcopy(BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG), has_context, baichuan_context_prompt
)
elif model_mode == "chat":
return cls.get_chat_prompt(
copy.deepcopy(BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG), has_context, baichuan_context_prompt
)
elif app_mode == AppMode.COMPLETION:
if model_mode == "completion":
return cls.get_completion_prompt(
copy.deepcopy(BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG),
has_context,
baichuan_context_prompt,
)
elif model_mode == "chat":
return cls.get_chat_prompt(
copy.deepcopy(BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG), has_context, baichuan_context_prompt
)
match app_mode:
case AppMode.CHAT:
match model_mode:
case "completion":
return cls.get_completion_prompt(
copy.deepcopy(BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG),
has_context,
baichuan_context_prompt,
)
case "chat":
return cls.get_chat_prompt(
copy.deepcopy(BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG), has_context, baichuan_context_prompt
)
case _:
pass
case AppMode.COMPLETION:
match model_mode:
case "completion":
return cls.get_completion_prompt(
copy.deepcopy(BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG),
has_context,
baichuan_context_prompt,
)
case "chat":
return cls.get_chat_prompt(
copy.deepcopy(BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG),
has_context,
baichuan_context_prompt,
)
case _:
pass
case _:
pass
# default return empty dict
return {}
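The if/elif ladders in this file become nested match statements (Python 3.10+), with wildcard arms falling through to the shared empty-dict default. A self-contained sketch of the same dispatch shape, using a hypothetical Mode enum and template strings in place of the real prompt configs:

from enum import StrEnum  # Python 3.11+

class Mode(StrEnum):
    CHAT = "chat"
    COMPLETION = "completion"

def pick_template(app_mode: Mode, model_mode: str) -> dict:
    match app_mode:
        case Mode.CHAT:
            match model_mode:
                case "completion":
                    return {"template": "chat-app/completion-model"}
                case "chat":
                    return {"template": "chat-app/chat-model"}
                case _:
                    pass  # unknown model mode: fall through to the default
        case Mode.COMPLETION:
            match model_mode:
                case "completion":
                    return {"template": "completion-app/completion-model"}
                case "chat":
                    return {"template": "completion-app/chat-model"}
                case _:
                    pass
        case _:
            pass
    # default return empty dict, mirroring the original behaviour
    return {}

print(pick_template(Mode.CHAT, "chat"))  # {'template': 'chat-app/chat-model'}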

View File

@@ -4,7 +4,9 @@ import uuid
import pandas as pd
logger = logging.getLogger(__name__)
from sqlalchemy import or_, select
from typing import TypedDict
from sqlalchemy import delete, or_, select, update
from werkzeug.datastructures import FileStorage
from werkzeug.exceptions import NotFound
@@ -23,15 +25,34 @@ from tasks.annotation.enable_annotation_reply_task import enable_annotation_repl
from tasks.annotation.update_annotation_to_index_task import update_annotation_to_index_task
class AnnotationJobStatusDict(TypedDict):
job_id: str
job_status: str
class EmbeddingModelDict(TypedDict):
embedding_provider_name: str
embedding_model_name: str
class AnnotationSettingDict(TypedDict):
id: str
enabled: bool
score_threshold: float
embedding_model: EmbeddingModelDict | dict
class AnnotationSettingDisabledDict(TypedDict):
enabled: bool
class AppAnnotationService:
@classmethod
def up_insert_app_annotation_from_message(cls, args: dict, app_id: str) -> MessageAnnotation:
# get app info
current_user, current_tenant_id = current_account_with_tenant()
app = (
db.session.query(App)
.where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal")
.first()
app = db.session.scalar(
select(App).where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal").limit(1)
)
if not app:
@@ -43,7 +64,9 @@ class AppAnnotationService:
if args.get("message_id"):
message_id = str(args["message_id"])
message = db.session.query(Message).where(Message.id == message_id, Message.app_id == app.id).first()
message = db.session.scalar(
select(Message).where(Message.id == message_id, Message.app_id == app.id).limit(1)
)
if not message:
raise NotFound("Message Not Exists.")
@@ -72,7 +95,9 @@ class AppAnnotationService:
db.session.add(annotation)
db.session.commit()
annotation_setting = db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first()
annotation_setting = db.session.scalar(
select(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).limit(1)
)
assert current_tenant_id is not None
if annotation_setting:
add_annotation_to_index_task.delay(
@@ -85,7 +110,7 @@ class AppAnnotationService:
return annotation
@classmethod
def enable_app_annotation(cls, args: dict, app_id: str):
def enable_app_annotation(cls, args: dict, app_id: str) -> AnnotationJobStatusDict:
enable_app_annotation_key = f"enable_app_annotation_{str(app_id)}"
cache_result = redis_client.get(enable_app_annotation_key)
if cache_result is not None:
@@ -109,7 +134,7 @@ class AppAnnotationService:
return {"job_id": job_id, "job_status": "waiting"}
@classmethod
def disable_app_annotation(cls, app_id: str):
def disable_app_annotation(cls, app_id: str) -> AnnotationJobStatusDict:
_, current_tenant_id = current_account_with_tenant()
disable_app_annotation_key = f"disable_app_annotation_{str(app_id)}"
cache_result = redis_client.get(disable_app_annotation_key)
@@ -128,10 +153,8 @@ class AppAnnotationService:
def get_annotation_list_by_app_id(cls, app_id: str, page: int, limit: int, keyword: str):
# get app info
_, current_tenant_id = current_account_with_tenant()
app = (
db.session.query(App)
.where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal")
.first()
app = db.session.scalar(
select(App).where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal").limit(1)
)
if not app:
@@ -170,20 +193,17 @@ class AppAnnotationService:
"""
# get app info
_, current_tenant_id = current_account_with_tenant()
app = (
db.session.query(App)
.where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal")
.first()
app = db.session.scalar(
select(App).where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal").limit(1)
)
if not app:
raise NotFound("App not found")
annotations = (
db.session.query(MessageAnnotation)
annotations = db.session.scalars(
select(MessageAnnotation)
.where(MessageAnnotation.app_id == app_id)
.order_by(MessageAnnotation.created_at.desc())
.all()
)
).all()
# Sanitize CSV-injectable fields to prevent formula injection
for annotation in annotations:
@@ -200,10 +220,8 @@ class AppAnnotationService:
def insert_app_annotation_directly(cls, args: dict, app_id: str) -> MessageAnnotation:
# get app info
current_user, current_tenant_id = current_account_with_tenant()
app = (
db.session.query(App)
.where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal")
.first()
app = db.session.scalar(
select(App).where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal").limit(1)
)
if not app:
@@ -219,7 +237,9 @@ class AppAnnotationService:
db.session.add(annotation)
db.session.commit()
# if annotation reply is enabled, add annotation to index
annotation_setting = db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first()
annotation_setting = db.session.scalar(
select(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).limit(1)
)
if annotation_setting:
add_annotation_to_index_task.delay(
annotation.id,
@@ -234,16 +254,14 @@ class AppAnnotationService:
def update_app_annotation_directly(cls, args: dict, app_id: str, annotation_id: str):
# get app info
_, current_tenant_id = current_account_with_tenant()
app = (
db.session.query(App)
.where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal")
.first()
app = db.session.scalar(
select(App).where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal").limit(1)
)
if not app:
raise NotFound("App not found")
annotation = db.session.query(MessageAnnotation).where(MessageAnnotation.id == annotation_id).first()
annotation = db.session.get(MessageAnnotation, annotation_id)
if not annotation:
raise NotFound("Annotation not found")
@@ -257,8 +275,8 @@ class AppAnnotationService:
db.session.commit()
# if annotation reply is enabled, add annotation to index
app_annotation_setting = (
db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first()
app_annotation_setting = db.session.scalar(
select(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).limit(1)
)
if app_annotation_setting:
@@ -276,16 +294,14 @@ class AppAnnotationService:
def delete_app_annotation(cls, app_id: str, annotation_id: str):
# get app info
_, current_tenant_id = current_account_with_tenant()
app = (
db.session.query(App)
.where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal")
.first()
app = db.session.scalar(
select(App).where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal").limit(1)
)
if not app:
raise NotFound("App not found")
annotation = db.session.query(MessageAnnotation).where(MessageAnnotation.id == annotation_id).first()
annotation = db.session.get(MessageAnnotation, annotation_id)
if not annotation:
raise NotFound("Annotation not found")
@@ -301,8 +317,8 @@ class AppAnnotationService:
db.session.commit()
# if annotation reply is enabled, delete annotation index
app_annotation_setting = (
db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first()
app_annotation_setting = db.session.scalar(
select(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).limit(1)
)
if app_annotation_setting:
@@ -314,22 +330,19 @@ class AppAnnotationService:
def delete_app_annotations_in_batch(cls, app_id: str, annotation_ids: list[str]):
# get app info
_, current_tenant_id = current_account_with_tenant()
app = (
db.session.query(App)
.where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal")
.first()
app = db.session.scalar(
select(App).where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal").limit(1)
)
if not app:
raise NotFound("App not found")
# Fetch annotations and their settings in a single query
annotations_to_delete = (
db.session.query(MessageAnnotation, AppAnnotationSetting)
annotations_to_delete = db.session.execute(
select(MessageAnnotation, AppAnnotationSetting)
.outerjoin(AppAnnotationSetting, MessageAnnotation.app_id == AppAnnotationSetting.app_id)
.where(MessageAnnotation.id.in_(annotation_ids))
.all()
)
).all()
if not annotations_to_delete:
return {"deleted_count": 0}
@@ -338,9 +351,9 @@ class AppAnnotationService:
annotation_ids_to_delete = [annotation.id for annotation, _ in annotations_to_delete]
# Step 2: Bulk delete hit histories in a single query
db.session.query(AppAnnotationHitHistory).where(
AppAnnotationHitHistory.annotation_id.in_(annotation_ids_to_delete)
).delete(synchronize_session=False)
db.session.execute(
delete(AppAnnotationHitHistory).where(AppAnnotationHitHistory.annotation_id.in_(annotation_ids_to_delete))
)
# Step 3: Trigger async tasks for search index deletion
for annotation, annotation_setting in annotations_to_delete:
@@ -350,11 +363,10 @@ class AppAnnotationService:
)
# Step 4: Bulk delete annotations in a single query
deleted_count = (
db.session.query(MessageAnnotation)
.where(MessageAnnotation.id.in_(annotation_ids_to_delete))
.delete(synchronize_session=False)
delete_result = db.session.execute(
delete(MessageAnnotation).where(MessageAnnotation.id.in_(annotation_ids_to_delete))
)
deleted_count = getattr(delete_result, "rowcount", 0)
db.session.commit()
return {"deleted_count": deleted_count}
@@ -375,10 +387,8 @@ class AppAnnotationService:
# get app info
current_user, current_tenant_id = current_account_with_tenant()
app = (
db.session.query(App)
.where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal")
.first()
app = db.session.scalar(
select(App).where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal").limit(1)
)
if not app:
@@ -499,16 +509,14 @@ class AppAnnotationService:
def get_annotation_hit_histories(cls, app_id: str, annotation_id: str, page, limit):
_, current_tenant_id = current_account_with_tenant()
# get app info
app = (
db.session.query(App)
.where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal")
.first()
app = db.session.scalar(
select(App).where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal").limit(1)
)
if not app:
raise NotFound("App not found")
annotation = db.session.query(MessageAnnotation).where(MessageAnnotation.id == annotation_id).first()
annotation = db.session.get(MessageAnnotation, annotation_id)
if not annotation:
raise NotFound("Annotation not found")
@@ -528,7 +536,7 @@ class AppAnnotationService:
@classmethod
def get_annotation_by_id(cls, annotation_id: str) -> MessageAnnotation | None:
annotation = db.session.query(MessageAnnotation).where(MessageAnnotation.id == annotation_id).first()
annotation = db.session.get(MessageAnnotation, annotation_id)
if not annotation:
return None
@@ -548,8 +556,10 @@ class AppAnnotationService:
score: float,
):
# add hit count to annotation
db.session.query(MessageAnnotation).where(MessageAnnotation.id == annotation_id).update(
{MessageAnnotation.hit_count: MessageAnnotation.hit_count + 1}, synchronize_session=False
db.session.execute(
update(MessageAnnotation)
.where(MessageAnnotation.id == annotation_id)
.values(hit_count=MessageAnnotation.hit_count + 1)
)
annotation_hit_history = AppAnnotationHitHistory(
@@ -567,19 +577,19 @@ class AppAnnotationService:
db.session.commit()
@classmethod
def get_app_annotation_setting_by_app_id(cls, app_id: str):
def get_app_annotation_setting_by_app_id(cls, app_id: str) -> AnnotationSettingDict | AnnotationSettingDisabledDict:
_, current_tenant_id = current_account_with_tenant()
# get app info
app = (
db.session.query(App)
.where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal")
.first()
app = db.session.scalar(
select(App).where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal").limit(1)
)
if not app:
raise NotFound("App not found")
annotation_setting = db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first()
annotation_setting = db.session.scalar(
select(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).limit(1)
)
if annotation_setting:
collection_binding_detail = annotation_setting.collection_binding_detail
if collection_binding_detail:
@@ -602,25 +612,25 @@ class AppAnnotationService:
return {"enabled": False}
@classmethod
def update_app_annotation_setting(cls, app_id: str, annotation_setting_id: str, args: dict):
def update_app_annotation_setting(
cls, app_id: str, annotation_setting_id: str, args: dict
) -> AnnotationSettingDict:
current_user, current_tenant_id = current_account_with_tenant()
# get app info
app = (
db.session.query(App)
.where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal")
.first()
app = db.session.scalar(
select(App).where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal").limit(1)
)
if not app:
raise NotFound("App not found")
annotation_setting = (
db.session.query(AppAnnotationSetting)
annotation_setting = db.session.scalar(
select(AppAnnotationSetting)
.where(
AppAnnotationSetting.app_id == app_id,
AppAnnotationSetting.id == annotation_setting_id,
)
.first()
.limit(1)
)
if not annotation_setting:
raise NotFound("App annotation not found")
@@ -653,26 +663,26 @@ class AppAnnotationService:
@classmethod
def clear_all_annotations(cls, app_id: str):
_, current_tenant_id = current_account_with_tenant()
app = (
db.session.query(App)
.where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal")
.first()
app = db.session.scalar(
select(App).where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal").limit(1)
)
if not app:
raise NotFound("App not found")
# if annotation reply is enabled, delete annotation index
app_annotation_setting = (
db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first()
app_annotation_setting = db.session.scalar(
select(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).limit(1)
)
annotations_query = db.session.query(MessageAnnotation).where(MessageAnnotation.app_id == app_id)
for annotation in annotations_query.yield_per(100):
annotation_hit_histories_query = db.session.query(AppAnnotationHitHistory).where(
AppAnnotationHitHistory.annotation_id == annotation.id
)
for annotation_hit_history in annotation_hit_histories_query.yield_per(100):
annotations_iter = db.session.scalars(
select(MessageAnnotation).where(MessageAnnotation.app_id == app_id)
).yield_per(100)
for annotation in annotations_iter:
hit_histories_iter = db.session.scalars(
select(AppAnnotationHitHistory).where(AppAnnotationHitHistory.annotation_id == annotation.id)
).yield_per(100)
for annotation_hit_history in hit_histories_iter:
db.session.delete(annotation_hit_history)
# if annotation reply is enabled, delete annotation index

View File

@@ -118,139 +118,143 @@ class AppGenerateService:
try:
request_id = rate_limit.enter(request_id)
quota_charge.commit()
if app_model.mode == AppMode.COMPLETION:
return rate_limit.generate(
CompletionAppGenerator.convert_to_event_stream(
CompletionAppGenerator().generate(
app_model=app_model, user=user, args=args, invoke_from=invoke_from, streaming=streaming
),
),
request_id=request_id,
)
elif app_model.mode == AppMode.AGENT_CHAT or app_model.is_agent:
return rate_limit.generate(
AgentChatAppGenerator.convert_to_event_stream(
AgentChatAppGenerator().generate(
app_model=app_model, user=user, args=args, invoke_from=invoke_from, streaming=streaming
),
),
request_id,
)
elif app_model.mode == AppMode.CHAT:
return rate_limit.generate(
ChatAppGenerator.convert_to_event_stream(
ChatAppGenerator().generate(
app_model=app_model, user=user, args=args, invoke_from=invoke_from, streaming=streaming
),
),
request_id=request_id,
)
elif app_model.mode == AppMode.ADVANCED_CHAT:
workflow_id = args.get("workflow_id")
workflow = cls._get_workflow(app_model, invoke_from, workflow_id)
if streaming:
# Streaming mode: subscribe to SSE and enqueue the execution on first subscriber
with rate_limit_context(rate_limit, request_id):
payload = AppExecutionParams.new(
app_model=app_model,
workflow=workflow,
user=user,
args=args,
invoke_from=invoke_from,
streaming=True,
call_depth=0,
)
payload_json = payload.model_dump_json()
def on_subscribe():
workflow_based_app_execution_task.delay(payload_json)
on_subscribe = cls._build_streaming_task_on_subscribe(on_subscribe)
generator = AdvancedChatAppGenerator()
effective_mode = (
AppMode.AGENT_CHAT if app_model.is_agent and app_model.mode != AppMode.AGENT_CHAT else app_model.mode
)
match effective_mode:
case AppMode.COMPLETION:
return rate_limit.generate(
generator.convert_to_event_stream(
generator.retrieve_events(
AppMode.ADVANCED_CHAT,
payload.workflow_run_id,
on_subscribe=on_subscribe,
CompletionAppGenerator.convert_to_event_stream(
CompletionAppGenerator().generate(
app_model=app_model, user=user, args=args, invoke_from=invoke_from, streaming=streaming
),
),
request_id=request_id,
)
else:
# Blocking mode: run synchronously and return JSON instead of SSE
# Keep behaviour consistent with WORKFLOW blocking branch.
advanced_generator = AdvancedChatAppGenerator()
case AppMode.AGENT_CHAT:
return rate_limit.generate(
advanced_generator.convert_to_event_stream(
advanced_generator.generate(
AgentChatAppGenerator.convert_to_event_stream(
AgentChatAppGenerator().generate(
app_model=app_model, user=user, args=args, invoke_from=invoke_from, streaming=streaming
),
),
request_id,
)
case AppMode.CHAT:
return rate_limit.generate(
ChatAppGenerator.convert_to_event_stream(
ChatAppGenerator().generate(
app_model=app_model, user=user, args=args, invoke_from=invoke_from, streaming=streaming
),
),
request_id=request_id,
)
case AppMode.ADVANCED_CHAT:
workflow_id = args.get("workflow_id")
workflow = cls._get_workflow(app_model, invoke_from, workflow_id)
if streaming:
# Streaming mode: subscribe to SSE and enqueue the execution on first subscriber
with rate_limit_context(rate_limit, request_id):
payload = AppExecutionParams.new(
app_model=app_model,
workflow=workflow,
user=user,
args=args,
invoke_from=invoke_from,
workflow_run_id=str(uuid.uuid4()),
streaming=False,
streaming=True,
call_depth=0,
)
),
request_id=request_id,
)
elif app_model.mode == AppMode.WORKFLOW:
workflow_id = args.get("workflow_id")
workflow = cls._get_workflow(app_model, invoke_from, workflow_id)
if streaming:
with rate_limit_context(rate_limit, request_id):
payload = AppExecutionParams.new(
app_model=app_model,
workflow=workflow,
user=user,
args=args,
invoke_from=invoke_from,
streaming=True,
call_depth=0,
root_node_id=root_node_id,
workflow_run_id=str(uuid.uuid4()),
payload_json = payload.model_dump_json()
def on_subscribe():
workflow_based_app_execution_task.delay(payload_json)
on_subscribe = cls._build_streaming_task_on_subscribe(on_subscribe)
generator = AdvancedChatAppGenerator()
return rate_limit.generate(
generator.convert_to_event_stream(
generator.retrieve_events(
AppMode.ADVANCED_CHAT,
payload.workflow_run_id,
on_subscribe=on_subscribe,
),
),
request_id=request_id,
)
payload_json = payload.model_dump_json()
else:
# Blocking mode: run synchronously and return JSON instead of SSE
# Keep behaviour consistent with WORKFLOW blocking branch.
advanced_generator = AdvancedChatAppGenerator()
return rate_limit.generate(
advanced_generator.convert_to_event_stream(
advanced_generator.generate(
app_model=app_model,
workflow=workflow,
user=user,
args=args,
invoke_from=invoke_from,
workflow_run_id=str(uuid.uuid4()),
streaming=False,
)
),
request_id=request_id,
)
case AppMode.WORKFLOW:
workflow_id = args.get("workflow_id")
workflow = cls._get_workflow(app_model, invoke_from, workflow_id)
if streaming:
with rate_limit_context(rate_limit, request_id):
payload = AppExecutionParams.new(
app_model=app_model,
workflow=workflow,
user=user,
args=args,
invoke_from=invoke_from,
streaming=True,
call_depth=0,
root_node_id=root_node_id,
workflow_run_id=str(uuid.uuid4()),
)
payload_json = payload.model_dump_json()
def on_subscribe():
workflow_based_app_execution_task.delay(payload_json)
on_subscribe = cls._build_streaming_task_on_subscribe(on_subscribe)
return rate_limit.generate(
WorkflowAppGenerator.convert_to_event_stream(
MessageBasedAppGenerator.retrieve_events(
AppMode.WORKFLOW,
payload.workflow_run_id,
on_subscribe=on_subscribe,
),
),
request_id,
)
pause_config = PauseStateLayerConfig(
session_factory=session_factory.get_session_maker(),
state_owner_user_id=workflow.created_by,
)
return rate_limit.generate(
WorkflowAppGenerator.convert_to_event_stream(
MessageBasedAppGenerator.retrieve_events(
AppMode.WORKFLOW,
payload.workflow_run_id,
on_subscribe=on_subscribe,
WorkflowAppGenerator().generate(
app_model=app_model,
workflow=workflow,
user=user,
args=args,
invoke_from=invoke_from,
streaming=False,
root_node_id=root_node_id,
call_depth=0,
pause_state_config=pause_config,
),
),
request_id,
)
pause_config = PauseStateLayerConfig(
session_factory=session_factory.get_session_maker(),
state_owner_user_id=workflow.created_by,
)
return rate_limit.generate(
WorkflowAppGenerator.convert_to_event_stream(
WorkflowAppGenerator().generate(
app_model=app_model,
workflow=workflow,
user=user,
args=args,
invoke_from=invoke_from,
streaming=False,
root_node_id=root_node_id,
call_depth=0,
pause_state_config=pause_config,
),
),
request_id,
)
else:
raise ValueError(f"Invalid app mode {app_model.mode}")
case _:
raise ValueError(f"Invalid app mode {app_model.mode}")
except Exception:
quota_charge.refund()
rate_limit.exit(request_id)
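The whole dispatch above sits in a reserve/commit/refund pattern: the quota is committed once the rate limiter admits the request, and on any exception the charge is refunded and the rate-limit slot released before re-raising. A stripped-down sketch with toy stand-ins for quota_charge and rate_limit (hypothetical; only the control flow is taken from this hunk):

class QuotaCharge:
    """Toy stand-in for the reserved billing quota."""
    def __init__(self) -> None:
        self.committed = False
    def commit(self) -> None:
        self.committed = True
    def refund(self) -> None:
        self.committed = False

class RateLimit:
    """Toy stand-in for the per-app rate limiter."""
    def enter(self, request_id: str) -> str:
        return request_id
    def exit(self, request_id: str) -> None:
        pass

def generate(rate_limit: RateLimit, quota_charge: QuotaCharge, request_id: str, run):
    try:
        request_id = rate_limit.enter(request_id)
        quota_charge.commit()
        return run()
    except Exception:
        # a failed request must not consume quota or hold a rate-limit slot
        quota_charge.refund()
        rate_limit.exit(request_id)
        raise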
@@ -282,43 +286,73 @@ class AppGenerateService:
@classmethod
def generate_single_iteration(cls, app_model: App, user: Account, node_id: str, args: Any, streaming: bool = True):
if app_model.mode == AppMode.ADVANCED_CHAT:
workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER)
return AdvancedChatAppGenerator.convert_to_event_stream(
AdvancedChatAppGenerator().single_iteration_generate(
app_model=app_model, workflow=workflow, node_id=node_id, user=user, args=args, streaming=streaming
match app_model.mode:
case AppMode.COMPLETION | AppMode.CHAT | AppMode.AGENT_CHAT:
raise ValueError(f"Invalid app mode {app_model.mode}")
case AppMode.ADVANCED_CHAT:
workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER)
return AdvancedChatAppGenerator.convert_to_event_stream(
AdvancedChatAppGenerator().single_iteration_generate(
app_model=app_model,
workflow=workflow,
node_id=node_id,
user=user,
args=args,
streaming=streaming,
)
)
)
elif app_model.mode == AppMode.WORKFLOW:
workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER)
return AdvancedChatAppGenerator.convert_to_event_stream(
WorkflowAppGenerator().single_iteration_generate(
app_model=app_model, workflow=workflow, node_id=node_id, user=user, args=args, streaming=streaming
case AppMode.WORKFLOW:
workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER)
return AdvancedChatAppGenerator.convert_to_event_stream(
WorkflowAppGenerator().single_iteration_generate(
app_model=app_model,
workflow=workflow,
node_id=node_id,
user=user,
args=args,
streaming=streaming,
)
)
)
else:
raise ValueError(f"Invalid app mode {app_model.mode}")
case AppMode.CHANNEL | AppMode.RAG_PIPELINE:
raise ValueError(f"Invalid app mode {app_model.mode}")
case _:
raise ValueError(f"Invalid app mode {app_model.mode}")
@classmethod
def generate_single_loop(
cls, app_model: App, user: Account, node_id: str, args: LoopNodeRunPayload, streaming: bool = True
):
if app_model.mode == AppMode.ADVANCED_CHAT:
workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER)
return AdvancedChatAppGenerator.convert_to_event_stream(
AdvancedChatAppGenerator().single_loop_generate(
app_model=app_model, workflow=workflow, node_id=node_id, user=user, args=args, streaming=streaming
match app_model.mode:
case AppMode.COMPLETION | AppMode.CHAT | AppMode.AGENT_CHAT:
raise ValueError(f"Invalid app mode {app_model.mode}")
case AppMode.ADVANCED_CHAT:
workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER)
return AdvancedChatAppGenerator.convert_to_event_stream(
AdvancedChatAppGenerator().single_loop_generate(
app_model=app_model,
workflow=workflow,
node_id=node_id,
user=user,
args=args,
streaming=streaming,
)
)
)
elif app_model.mode == AppMode.WORKFLOW:
workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER)
return AdvancedChatAppGenerator.convert_to_event_stream(
WorkflowAppGenerator().single_loop_generate(
app_model=app_model, workflow=workflow, node_id=node_id, user=user, args=args, streaming=streaming
case AppMode.WORKFLOW:
workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER)
return AdvancedChatAppGenerator.convert_to_event_stream(
WorkflowAppGenerator().single_loop_generate(
app_model=app_model,
workflow=workflow,
node_id=node_id,
user=user,
args=args,
streaming=streaming,
)
)
)
else:
raise ValueError(f"Invalid app mode {app_model.mode}")
case AppMode.CHANNEL | AppMode.RAG_PIPELINE:
raise ValueError(f"Invalid app mode {app_model.mode}")
case _:
raise ValueError(f"Invalid app mode {app_model.mode}")
@classmethod
def generate_more_like_this(

View File

@@ -7,11 +7,12 @@ from models.model import AppMode, AppModelConfigDict
class AppModelConfigService:
@classmethod
def validate_configuration(cls, tenant_id: str, config: dict, app_mode: AppMode) -> AppModelConfigDict:
if app_mode == AppMode.CHAT:
return ChatAppConfigManager.config_validate(tenant_id, config)
elif app_mode == AppMode.AGENT_CHAT:
return AgentChatAppConfigManager.config_validate(tenant_id, config)
elif app_mode == AppMode.COMPLETION:
return CompletionAppConfigManager.config_validate(tenant_id, config)
else:
raise ValueError(f"Invalid app mode: {app_mode}")
match app_mode:
case AppMode.CHAT:
return ChatAppConfigManager.config_validate(tenant_id, config)
case AppMode.AGENT_CHAT:
return AgentChatAppConfigManager.config_validate(tenant_id, config)
case AppMode.COMPLETION:
return CompletionAppConfigManager.config_validate(tenant_id, config)
case AppMode.WORKFLOW | AppMode.ADVANCED_CHAT | AppMode.CHANNEL | AppMode.RAG_PIPELINE:
raise ValueError(f"Invalid app mode: {app_mode}")
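Enumerating every AppMode member in an explicit case arm, rather than ending with a bare else, is what makes the match checkable: a static type checker can flag any newly added member that no arm handles. A sketch of that property with a hypothetical three-member enum; typing.assert_never (Python 3.11+) is part of this sketch, not of the diff:

from enum import StrEnum
from typing import assert_never

class AppKind(StrEnum):
    CHAT = "chat"
    COMPLETION = "completion"
    WORKFLOW = "workflow"

def validate(kind: AppKind) -> str:
    match kind:
        case AppKind.CHAT:
            return "chat config validated"
        case AppKind.COMPLETION:
            return "completion config validated"
        case AppKind.WORKFLOW:
            raise ValueError(f"Invalid app mode: {kind}")
        case _:
            assert_never(kind)  # unreachable while every member has an arm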

View File

@@ -11,7 +11,7 @@ from typing import Any, Union
from celery.result import AsyncResult
from sqlalchemy import select
from sqlalchemy.orm import Session
from sqlalchemy.orm import Session, sessionmaker
from enums.quota_type import QuotaType
from extensions.ext_database import db
@@ -244,7 +244,7 @@ class AsyncWorkflowService:
Returns:
Trigger log as dictionary or None if not found
"""
with Session(db.engine) as session:
with sessionmaker(db.engine).begin() as session:
trigger_log_repo = SQLAlchemyWorkflowTriggerLogRepository(session)
trigger_log = trigger_log_repo.get_by_id(workflow_trigger_log_id, tenant_id)
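The Session(db.engine) blocks in this file become sessionmaker(db.engine).begin(), SQLAlchemy's "begin once" pattern: the context manager opens a transaction, commits on clean exit, and rolls back if the block raises, so no explicit session.commit() is needed inside. A minimal sketch against a throwaway SQLite engine:

from sqlalchemy import create_engine, text
from sqlalchemy.orm import Session, sessionmaker

engine = create_engine("sqlite://")

# before: plain Session; commit (or rollback) is the caller's job
with Session(engine) as session:
    session.execute(text("SELECT 1"))
    session.commit()

# after: transactional context manager; commit is implicit on success,
# rollback is implicit on exception
with sessionmaker(engine).begin() as session:
    session.execute(text("SELECT 1"))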
@@ -270,7 +270,7 @@ class AsyncWorkflowService:
Returns:
List of trigger logs as dictionaries
"""
with Session(db.engine) as session:
with sessionmaker(db.engine).begin() as session:
trigger_log_repo = SQLAlchemyWorkflowTriggerLogRepository(session)
logs = trigger_log_repo.get_recent_logs(
tenant_id=tenant_id, app_id=app_id, hours=hours, limit=limit, offset=offset
@@ -293,7 +293,7 @@ class AsyncWorkflowService:
Returns:
List of failed trigger logs as dictionaries
"""
with Session(db.engine) as session:
with sessionmaker(db.engine).begin() as session:
trigger_log_repo = SQLAlchemyWorkflowTriggerLogRepository(session)
logs = trigger_log_repo.get_failed_for_retry(
tenant_id=tenant_id, max_retry_count=max_retry_count, limit=limit

View File

@@ -1,6 +1,6 @@
import base64
from sqlalchemy import Engine
from sqlalchemy import Engine, select
from sqlalchemy.orm import sessionmaker
from werkzeug.exceptions import NotFound
@@ -22,8 +22,8 @@ class AttachmentService:
raise AssertionError("must be a sessionmaker or an Engine.")
def get_file_base64(self, file_id: str) -> str:
upload_file = (
self._session_maker(expire_on_commit=False).query(UploadFile).where(UploadFile.id == file_id).first()
upload_file = self._session_maker(expire_on_commit=False).scalar(
select(UploadFile).where(UploadFile.id == file_id).limit(1)
)
if not upload_file:
raise NotFound("File not found")

View File

@@ -107,6 +107,124 @@ class BillingInfo(TypedDict):
_billing_info_adapter = TypeAdapter(BillingInfo)
class _TenantFeatureQuota(TypedDict):
usage: int
limit: int
reset_date: NotRequired[int]
class TenantFeatureQuotaInfo(TypedDict):
"""Response of /quota/info.
NOTE (hj24):
- Same convention as BillingInfo: billing may return int fields as str;
always keep non-strict mode to auto-coerce them.
"""
trigger_event: _TenantFeatureQuota
api_rate_limit: _TenantFeatureQuota
_tenant_feature_quota_info_adapter = TypeAdapter(TenantFeatureQuotaInfo)
class _BillingQuota(TypedDict):
size: int
limit: int
class _VectorSpaceQuota(TypedDict):
size: float
limit: int
class _KnowledgeRateLimit(TypedDict):
# NOTE (hj24):
# 1. Returned for sandbox users but null for other plans; it's defined but never used.
# 2. Kept for compatibility for now; it can be deprecated in a future version.
size: NotRequired[int]
# NOTE END
limit: int
class _BillingSubscription(TypedDict):
plan: str
interval: str
education: bool
class BillingInfo(TypedDict):
"""Response of /subscription/info.
NOTE (hj24):
- Fields not listed here (e.g. trigger_event, api_rate_limit) are stripped by TypeAdapter.validate_python()
- To preserve precision, billing may serialize int fields as str; be careful when using TypeAdapter:
1. validate_python in non-strict mode will coerce it to the expected type
2. In strict mode, it will raise ValidationError
3. To preserve compatibility, always keep non-strict mode here and avoid strict mode
"""
enabled: bool
subscription: _BillingSubscription
members: _BillingQuota
apps: _BillingQuota
vector_space: _VectorSpaceQuota
knowledge_rate_limit: _KnowledgeRateLimit
documents_upload_quota: _BillingQuota
annotation_quota_limit: _BillingQuota
docs_processing: str
can_replace_logo: bool
model_load_balancing_enabled: bool
knowledge_pipeline_publish_enabled: bool
next_credit_reset_date: NotRequired[int]
_billing_info_adapter = TypeAdapter(BillingInfo)
class KnowledgeRateLimitDict(TypedDict):
limit: int
subscription_plan: str
class TenantFeaturePlanUsageDict(TypedDict):
result: str
history_id: str
class LangContentDict(TypedDict):
lang: str
title: str
subtitle: str
body: str
title_pic_url: str
class NotificationDict(TypedDict):
notification_id: str
contents: dict[str, LangContentDict]
frequency: Literal["once", "every_page_load"]
class AccountNotificationDict(TypedDict, total=False):
should_show: bool
notification: NotificationDict
shouldShow: bool
notifications: list[dict]
class UpsertNotificationDict(TypedDict):
notification_id: str
class BatchAddNotificationAccountsDict(TypedDict):
count: int
class DismissNotificationDict(TypedDict):
success: bool
class BillingService:
base_url = os.environ.get("BILLING_API_URL", "BILLING_API_URL")
secret_key = os.environ.get("BILLING_API_SECRET_KEY", "BILLING_API_SECRET_KEY")
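The NOTEs above are the crux of these TypedDicts: pydantic's TypeAdapter in its default non-strict mode coerces the numeric strings the billing API may return into ints, and strips keys the TypedDict does not declare. A small sketch of both behaviours (field names hypothetical):

from typing import TypedDict
from pydantic import TypeAdapter

class _Quota(TypedDict):
    usage: int
    limit: int

class QuotaInfo(TypedDict):
    trigger_event: _Quota

adapter = TypeAdapter(QuotaInfo)

# "3" and "100" are coerced to int; the undeclared "extra" key is dropped
raw = {"trigger_event": {"usage": "3", "limit": "100", "extra": "dropped"}}
print(adapter.validate_python(raw))  # {'trigger_event': {'usage': 3, 'limit': 100}}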
@@ -133,9 +251,11 @@ class BillingService:
return usage_info
@classmethod
def get_quota_info(cls, tenant_id: str):
def get_quota_info(cls, tenant_id: str) -> TenantFeatureQuotaInfo:
params = {"tenant_id": tenant_id}
return cls._send_request("GET", "/quota/info", params=params)
return _tenant_feature_quota_info_adapter.validate_python(
cls._send_request("GET", "/quota/info", params=params)
)
@classmethod
def quota_reserve(
@@ -183,7 +303,7 @@ class BillingService:
)
@classmethod
def get_knowledge_rate_limit(cls, tenant_id: str):
def get_knowledge_rate_limit(cls, tenant_id: str) -> KnowledgeRateLimitDict:
params = {"tenant_id": tenant_id}
knowledge_rate_limit = cls._send_request("GET", "/subscription/knowledge-rate-limit", params=params)
@@ -214,7 +334,9 @@ class BillingService:
return cls._send_request("GET", "/invoices", params=params)
@classmethod
def update_tenant_feature_plan_usage(cls, tenant_id: str, feature_key: str, delta: int) -> dict:
def update_tenant_feature_plan_usage(
cls, tenant_id: str, feature_key: str, delta: int
) -> TenantFeaturePlanUsageDict:
"""
Update tenant feature plan usage.
@@ -234,7 +356,7 @@ class BillingService:
)
@classmethod
def refund_tenant_feature_plan_usage(cls, history_id: str) -> dict:
def refund_tenant_feature_plan_usage(cls, history_id: str) -> TenantFeaturePlanUsageDict:
"""
Refund a previous usage charge.
@@ -530,7 +652,7 @@ class BillingService:
return tenant_whitelist
@classmethod
def get_account_notification(cls, account_id: str) -> dict:
def get_account_notification(cls, account_id: str) -> AccountNotificationDict:
"""Return the active in-product notification for account_id, if any.
Calling this endpoint also marks the notification as seen; subsequent
@@ -554,13 +676,13 @@ class BillingService:
@classmethod
def upsert_notification(
cls,
contents: list[dict],
contents: list[LangContentDict],
frequency: str = "once",
status: str = "active",
notification_id: str | None = None,
start_time: str | None = None,
end_time: str | None = None,
) -> dict:
) -> UpsertNotificationDict:
"""Create or update a notification.
contents: list of {"lang": str, "title": str, "subtitle": str, "body": str, "title_pic_url": str}
@@ -581,7 +703,9 @@ class BillingService:
return cls._send_request("POST", "/notifications", json=payload)
@classmethod
def batch_add_notification_accounts(cls, notification_id: str, account_ids: list[str]) -> dict:
def batch_add_notification_accounts(
cls, notification_id: str, account_ids: list[str]
) -> BatchAddNotificationAccountsDict:
"""Register target account IDs for a notification (max 1000 per call).
Returns {"count": int}.
@@ -593,7 +717,7 @@ class BillingService:
)
@classmethod
def dismiss_notification(cls, notification_id: str, account_id: str) -> dict:
def dismiss_notification(cls, notification_id: str, account_id: str) -> DismissNotificationDict:
"""Mark a notification as dismissed for an account.
Returns {"success": bool}.

View File

@@ -346,7 +346,7 @@ class ClearFreePlanTenantExpiredLogs:
started_at = datetime.datetime(2023, 4, 3, 8, 59, 24)
current_time = started_at
with Session(db.engine) as session:
with sessionmaker(db.engine).begin() as session:
total_tenant_count = session.query(Tenant.id).count()
click.echo(click.style(f"Total tenant count: {total_tenant_count}", fg="white"))
@@ -398,7 +398,7 @@ class ClearFreePlanTenantExpiredLogs:
# Initial interval of 1 day, will be dynamically adjusted based on tenant count
interval = datetime.timedelta(days=1)
# Process tenants in this batch
with Session(db.engine) as session:
with sessionmaker(db.engine).begin() as session:
# Calculate tenant count in next batch with current interval
# Try different intervals until we find one with a reasonable tenant count
test_intervals = [

View File

@@ -1,7 +1,7 @@
import logging
from sqlalchemy import select, update
from sqlalchemy.orm import Session
from sqlalchemy.orm import sessionmaker
from configs import dify_config
from core.errors.error import QuotaExceededError
@@ -71,7 +71,7 @@ class CreditPoolService:
actual_credits = min(credits_required, pool.remaining_credits)
try:
with Session(db.engine) as session:
with sessionmaker(db.engine).begin() as session:
stmt = (
update(TenantCreditPool)
.where(
@@ -81,7 +81,6 @@ class CreditPoolService:
.values(quota_used=TenantCreditPool.quota_used + actual_credits)
)
session.execute(stmt)
session.commit()
except Exception:
logger.exception("Failed to deduct credits for tenant %s", tenant_id)
raise QuotaExceededError("Failed to deduct credits")
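The deduction above is issued as a single UPDATE with a server-side expression, so the addition happens inside the database rather than on a possibly stale Python value, and the sessionmaker(...).begin() block makes the removed session.commit() redundant. A toy sketch with a hypothetical CreditPool table:

from sqlalchemy import Integer, String, create_engine, update
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, sessionmaker

class Base(DeclarativeBase):
    pass

class CreditPool(Base):
    __tablename__ = "credit_pools"
    tenant_id: Mapped[str] = mapped_column(String(36), primary_key=True)
    quota_used: Mapped[int] = mapped_column(Integer, default=0)

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)
session_factory = sessionmaker(engine)

with session_factory.begin() as session:
    session.add(CreditPool(tenant_id="t-1"))

with session_factory.begin() as session:  # commits on exit, rolls back on error
    session.execute(
        update(CreditPool)
        .where(CreditPool.tenant_id == "t-1")
        .values(quota_used=CreditPool.quota_used + 5)
    )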

View File

@@ -7,15 +7,15 @@ import time
import uuid
from collections import Counter
from collections.abc import Sequence
from typing import Any, Literal, cast
from typing import Any, Literal, TypedDict, cast
import sqlalchemy as sa
from graphon.file import helpers as file_helpers
from graphon.model_runtime.entities.model_entities import ModelFeature, ModelType
from graphon.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from redis.exceptions import LockNotOwnedError
from sqlalchemy import exists, func, select
from sqlalchemy.orm import Session
from sqlalchemy import delete, exists, func, select, update
from sqlalchemy.orm import Session, sessionmaker
from werkzeug.exceptions import Forbidden, NotFound
from configs import dify_config
@@ -107,6 +107,16 @@ from tasks.sync_website_document_indexing_task import sync_website_document_inde
logger = logging.getLogger(__name__)
class ProcessRulesDict(TypedDict):
mode: str
rules: dict[str, Any]
class AutoDisableLogsDict(TypedDict):
document_ids: list[str]
count: int
class DatasetService:
@staticmethod
def get_datasets(page, per_page, tenant_id=None, user=None, search=None, tag_ids=None, include_all=False):
@@ -114,9 +124,11 @@ class DatasetService:
if user:
# get permitted dataset ids
dataset_permission = (
db.session.query(DatasetPermission).filter_by(account_id=user.id, tenant_id=tenant_id).all()
)
dataset_permission = db.session.scalars(
select(DatasetPermission).where(
DatasetPermission.account_id == user.id, DatasetPermission.tenant_id == tenant_id
)
).all()
permitted_dataset_ids = {dp.dataset_id for dp in dataset_permission} if dataset_permission else None
if user.current_role == TenantAccountRole.DATASET_OPERATOR:
@@ -180,21 +192,20 @@ class DatasetService:
return datasets.items, datasets.total
@staticmethod
def get_process_rules(dataset_id):
def get_process_rules(dataset_id) -> ProcessRulesDict:
# get the latest process rule
dataset_process_rule = (
db.session.query(DatasetProcessRule)
dataset_process_rule = db.session.execute(
select(DatasetProcessRule)
.where(DatasetProcessRule.dataset_id == dataset_id)
.order_by(DatasetProcessRule.created_at.desc())
.limit(1)
.one_or_none()
)
).scalar_one_or_none()
if dataset_process_rule:
mode = dataset_process_rule.mode
rules = dataset_process_rule.rules_dict
rules = dataset_process_rule.rules_dict or {}
else:
mode = DocumentService.DEFAULT_RULES["mode"]
rules = DocumentService.DEFAULT_RULES["rules"]
mode = str(DocumentService.DEFAULT_RULES["mode"])
rules = dict(DocumentService.DEFAULT_RULES.get("rules") or {})
return {"mode": mode, "rules": rules}
@staticmethod
@@ -225,7 +236,7 @@ class DatasetService:
summary_index_setting: dict | None = None,
):
# check if dataset name already exists
if db.session.query(Dataset).filter_by(name=name, tenant_id=tenant_id).first():
if db.session.scalar(select(Dataset).where(Dataset.name == name, Dataset.tenant_id == tenant_id).limit(1)):
raise DatasetNameDuplicateError(f"Dataset with name {name} already exists.")
embedding_model = None
if indexing_technique == IndexTechniqueType.HIGH_QUALITY:
@@ -300,17 +311,17 @@ class DatasetService:
):
if rag_pipeline_dataset_create_entity.name:
# check if dataset name already exists
if (
db.session.query(Dataset)
.filter_by(name=rag_pipeline_dataset_create_entity.name, tenant_id=tenant_id)
.first()
if db.session.scalar(
select(Dataset)
.where(Dataset.name == rag_pipeline_dataset_create_entity.name, Dataset.tenant_id == tenant_id)
.limit(1)
):
raise DatasetNameDuplicateError(
f"Dataset with name {rag_pipeline_dataset_create_entity.name} already exists."
)
else:
# generate an incremental name: Untitled 1, 2, 3 ...
datasets = db.session.query(Dataset).filter_by(tenant_id=tenant_id).all()
datasets = db.session.scalars(select(Dataset).where(Dataset.tenant_id == tenant_id)).all()
names = [dataset.name for dataset in datasets]
rag_pipeline_dataset_create_entity.name = generate_incremental_name(
names,
@@ -344,7 +355,7 @@ class DatasetService:
@staticmethod
def get_dataset(dataset_id) -> Dataset | None:
dataset: Dataset | None = db.session.query(Dataset).filter_by(id=dataset_id).first()
dataset: Dataset | None = db.session.get(Dataset, dataset_id)
return dataset
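Several hunks in this file replace query(Model).filter_by(id=...).first() with db.session.get(Model, pk). Session.get() is the 2.0-style primary-key lookup: it consults the session's identity map first and only emits a SELECT when the object is not already loaded. A minimal sketch under illustrative names:

from sqlalchemy import create_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column

class Base(DeclarativeBase):
    pass

class DatasetRow(Base):
    __tablename__ = "datasets"
    id: Mapped[str] = mapped_column(primary_key=True)
    name: Mapped[str]

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add(DatasetRow(id="ds-1", name="docs"))
    session.commit()
    # returns the instance (from the identity map if already loaded) or None
    ds = session.get(DatasetRow, "ds-1")
    assert ds is not None and ds.name == "docs"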
@staticmethod
@@ -466,14 +477,14 @@ class DatasetService:
@staticmethod
def _has_dataset_same_name(tenant_id: str, dataset_id: str, name: str):
dataset = (
db.session.query(Dataset)
dataset = db.session.scalar(
select(Dataset)
.where(
Dataset.id != dataset_id,
Dataset.name == name,
Dataset.tenant_id == tenant_id,
)
.first()
.limit(1)
)
return dataset is not None
@@ -540,7 +551,7 @@ class DatasetService:
external_knowledge_id: External knowledge identifier
external_knowledge_api_id: External knowledge API identifier
"""
with Session(db.engine) as session:
with sessionmaker(db.engine).begin() as session:
external_knowledge_binding = (
session.query(ExternalKnowledgeBindings).filter_by(dataset_id=dataset_id).first()
)
@@ -548,14 +559,14 @@ class DatasetService:
if not external_knowledge_binding:
raise ValueError("External knowledge binding not found.")
# Update binding if values have changed
if (
external_knowledge_binding.external_knowledge_id != external_knowledge_id
or external_knowledge_binding.external_knowledge_api_id != external_knowledge_api_id
):
external_knowledge_binding.external_knowledge_id = external_knowledge_id
external_knowledge_binding.external_knowledge_api_id = external_knowledge_api_id
db.session.add(external_knowledge_binding)
# Update binding if values have changed
if (
external_knowledge_binding.external_knowledge_id != external_knowledge_id
or external_knowledge_binding.external_knowledge_api_id != external_knowledge_api_id
):
external_knowledge_binding.external_knowledge_id = external_knowledge_id
external_knowledge_binding.external_knowledge_api_id = external_knowledge_api_id
session.add(external_knowledge_binding)
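sessionmaker(db.engine).begin() hands back a session inside a transaction that commits on clean exit and rolls back on an exception, which is why the rewritten block above needs no explicit commit. A minimal sketch with a toy model:

from sqlalchemy import create_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, sessionmaker

class Base(DeclarativeBase):
    pass

class Binding(Base):
    __tablename__ = "bindings"
    id: Mapped[int] = mapped_column(primary_key=True)
    external_id: Mapped[str]

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with sessionmaker(engine).begin() as session:
    session.add(Binding(external_id="k-1"))
    # no session.commit() here: .begin() commits on clean exit and
    # rolls back automatically if the block raises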
@staticmethod
def _update_internal_dataset(dataset, data, user):
@@ -596,7 +607,7 @@ class DatasetService:
filtered_data["icon_info"] = data.get("icon_info")
# Update dataset in database
db.session.query(Dataset).filter_by(id=dataset.id).update(filtered_data)
db.session.execute(update(Dataset).where(Dataset.id == dataset.id).values(**filtered_data))
db.session.commit()
# Reload dataset to get updated values
@@ -631,7 +642,7 @@ class DatasetService:
if dataset.runtime_mode != DatasetRuntimeMode.RAG_PIPELINE:
return
pipeline = db.session.query(Pipeline).filter_by(id=dataset.pipeline_id).first()
pipeline = db.session.get(Pipeline, dataset.pipeline_id)
if not pipeline:
return
@@ -1138,8 +1149,10 @@ class DatasetService:
if dataset.permission == DatasetPermissionEnum.PARTIAL_TEAM:
# For partial team permission, user needs explicit permission or be the creator
if dataset.created_by != user.id:
user_permission = (
db.session.query(DatasetPermission).filter_by(dataset_id=dataset.id, account_id=user.id).first()
user_permission = db.session.scalar(
select(DatasetPermission)
.where(DatasetPermission.dataset_id == dataset.id, DatasetPermission.account_id == user.id)
.limit(1)
)
if not user_permission:
logger.debug("User %s does not have permission to access dataset %s", user.id, dataset.id)
@@ -1161,7 +1174,9 @@ class DatasetService:
elif dataset.permission == DatasetPermissionEnum.PARTIAL_TEAM:
if not any(
dp.dataset_id == dataset.id
for dp in db.session.query(DatasetPermission).filter_by(account_id=user.id).all()
for dp in db.session.scalars(
select(DatasetPermission).where(DatasetPermission.account_id == user.id)
).all()
):
raise NoPermissionError("You do not have permission to access this dataset.")
@@ -1175,12 +1190,11 @@ class DatasetService:
@staticmethod
def get_related_apps(dataset_id: str):
return (
db.session.query(AppDatasetJoin)
return db.session.scalars(
select(AppDatasetJoin)
.where(AppDatasetJoin.dataset_id == dataset_id)
.order_by(db.desc(AppDatasetJoin.created_at))
.all()
)
.order_by(AppDatasetJoin.created_at.desc())
).all()
@staticmethod
def update_dataset_api_status(dataset_id: str, status: bool):
@@ -1195,7 +1209,7 @@ class DatasetService:
db.session.commit()
@staticmethod
def get_dataset_auto_disable_logs(dataset_id: str):
def get_dataset_auto_disable_logs(dataset_id: str) -> AutoDisableLogsDict:
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
features = FeatureService.get_features(current_user.current_tenant_id)
@@ -1396,8 +1410,8 @@ class DocumentService:
@staticmethod
def get_document(dataset_id: str, document_id: str | None = None) -> Document | None:
if document_id:
document = (
db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
document = db.session.scalar(
select(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).limit(1)
)
return document
else:
@@ -1626,7 +1640,7 @@ class DocumentService:
@staticmethod
def get_document_by_id(document_id: str) -> Document | None:
document = db.session.query(Document).where(Document.id == document_id).first()
document = db.session.get(Document, document_id)
return document
@@ -1691,7 +1705,7 @@ class DocumentService:
@staticmethod
def get_document_file_detail(file_id: str):
file_detail = db.session.query(UploadFile).where(UploadFile.id == file_id).one_or_none()
file_detail = db.session.get(UploadFile, file_id)
return file_detail
@staticmethod
@@ -1765,9 +1779,11 @@ class DocumentService:
document.name = name
db.session.add(document)
if document.data_source_info_dict and "upload_file_id" in document.data_source_info_dict:
db.session.query(UploadFile).where(
UploadFile.id == document.data_source_info_dict["upload_file_id"]
).update({UploadFile.name: name})
db.session.execute(
update(UploadFile)
.where(UploadFile.id == document.data_source_info_dict["upload_file_id"])
.values(name=name)
)
db.session.commit()
@@ -1854,8 +1870,8 @@ class DocumentService:
@staticmethod
def get_documents_position(dataset_id):
document = (
db.session.query(Document).filter_by(dataset_id=dataset_id).order_by(Document.position.desc()).first()
document = db.session.scalar(
select(Document).where(Document.dataset_id == dataset_id).order_by(Document.position.desc()).limit(1)
)
if document:
return document.position + 1
@@ -2012,28 +2028,28 @@ class DocumentService:
if not knowledge_config.data_source.info_list.file_info_list:
raise ValueError("File source info is required")
upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids
files = (
db.session.query(UploadFile)
.where(
UploadFile.tenant_id == dataset.tenant_id,
UploadFile.id.in_(upload_file_list),
)
.all()
files = list(
db.session.scalars(
select(UploadFile).where(
UploadFile.tenant_id == dataset.tenant_id,
UploadFile.id.in_(upload_file_list),
)
).all()
)
if len(files) != len(set(upload_file_list)):
raise FileNotExistsError("One or more files not found.")
file_names = [file.name for file in files]
db_documents = (
db.session.query(Document)
.where(
Document.dataset_id == dataset.id,
Document.tenant_id == current_user.current_tenant_id,
Document.data_source_type == DataSourceType.UPLOAD_FILE,
Document.enabled == True,
Document.name.in_(file_names),
)
.all()
db_documents = list(
db.session.scalars(
select(Document).where(
Document.dataset_id == dataset.id,
Document.tenant_id == current_user.current_tenant_id,
Document.data_source_type == DataSourceType.UPLOAD_FILE,
Document.enabled == True,
Document.name.in_(file_names),
)
).all()
)
documents_map = {document.name: document for document in db_documents}
for file in files:
@@ -2079,15 +2095,15 @@ class DocumentService:
raise ValueError("No notion info list found.")
exist_page_ids = []
exist_document = {}
documents = (
db.session.query(Document)
.filter_by(
dataset_id=dataset.id,
tenant_id=current_user.current_tenant_id,
data_source_type=DataSourceType.NOTION_IMPORT,
enabled=True,
)
.all()
documents = list(
db.session.scalars(
select(Document).where(
Document.dataset_id == dataset.id,
Document.tenant_id == current_user.current_tenant_id,
Document.data_source_type == DataSourceType.NOTION_IMPORT,
Document.enabled == True,
)
).all()
)
if documents:
for document in documents:
@@ -2518,14 +2534,15 @@ class DocumentService:
assert isinstance(current_user, Account)
documents_count = (
db.session.query(Document)
.where(
Document.completed_at.isnot(None),
Document.enabled == True,
Document.archived == False,
Document.tenant_id == current_user.current_tenant_id,
db.session.scalar(
select(func.count(Document.id)).where(
Document.completed_at.isnot(None),
Document.enabled == True,
Document.archived == False,
Document.tenant_id == current_user.current_tenant_id,
)
)
.count()
or 0
)
return documents_count
@@ -2575,10 +2592,10 @@ class DocumentService:
raise ValueError("No file info list found.")
upload_file_list = document_data.data_source.info_list.file_info_list.file_ids
for file_id in upload_file_list:
file = (
db.session.query(UploadFile)
file = db.session.scalar(
select(UploadFile)
.where(UploadFile.tenant_id == dataset.tenant_id, UploadFile.id == file_id)
.first()
.limit(1)
)
# raise error if file not found
@@ -2595,8 +2612,8 @@ class DocumentService:
notion_info_list = document_data.data_source.info_list.notion_info_list
for notion_info in notion_info_list:
workspace_id = notion_info.workspace_id
data_source_binding = (
db.session.query(DataSourceOauthBinding)
data_source_binding = db.session.scalar(
select(DataSourceOauthBinding)
.where(
sa.and_(
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
@@ -2605,7 +2622,7 @@ class DocumentService:
DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"',
)
)
.first()
.limit(1)
)
if not data_source_binding:
raise ValueError("Data source binding not found.")
@@ -2650,8 +2667,10 @@ class DocumentService:
db.session.commit()
# update document segment
db.session.query(DocumentSegment).filter_by(document_id=document.id).update(
{DocumentSegment.status: SegmentStatus.RE_SEGMENT}
db.session.execute(
update(DocumentSegment)
.where(DocumentSegment.document_id == document.id)
.values(status=SegmentStatus.RE_SEGMENT)
)
db.session.commit()
# trigger async task
@@ -3143,10 +3162,8 @@ class SegmentService:
lock_name = f"add_segment_lock_document_id_{document.id}"
try:
with redis_client.lock(lock_name, timeout=600):
max_position = (
db.session.query(func.max(DocumentSegment.position))
.where(DocumentSegment.document_id == document.id)
.scalar()
max_position = db.session.scalar(
select(func.max(DocumentSegment.position)).where(DocumentSegment.document_id == document.id)
)
segment_document = DocumentSegment(
tenant_id=current_user.current_tenant_id,
@@ -3198,7 +3215,7 @@ class SegmentService:
segment_document.status = SegmentStatus.ERROR
segment_document.error = str(e)
db.session.commit()
segment = db.session.query(DocumentSegment).where(DocumentSegment.id == segment_document.id).first()
segment = db.session.get(DocumentSegment, segment_document.id)
return segment
except LockNotOwnedError:
pass
@@ -3221,10 +3238,8 @@ class SegmentService:
model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model,
)
max_position = (
db.session.query(func.max(DocumentSegment.position))
.where(DocumentSegment.document_id == document.id)
.scalar()
max_position = db.session.scalar(
select(func.max(DocumentSegment.position)).where(DocumentSegment.document_id == document.id)
)
pre_segment_data_list = []
segment_data_list = []
@@ -3369,11 +3384,7 @@ class SegmentService:
else:
raise ValueError("The knowledge base index technique is not high quality!")
# get the process rule
processing_rule = (
db.session.query(DatasetProcessRule)
.where(DatasetProcessRule.id == document.dataset_process_rule_id)
.first()
)
processing_rule = db.session.get(DatasetProcessRule, document.dataset_process_rule_id)
if processing_rule:
VectorService.generate_child_chunks(
segment, document, dataset, embedding_model_instance, processing_rule, True
@@ -3391,13 +3402,13 @@ class SegmentService:
# Query existing summary from database
from models.dataset import DocumentSegmentSummary
existing_summary = (
db.session.query(DocumentSegmentSummary)
existing_summary = db.session.scalar(
select(DocumentSegmentSummary)
.where(
DocumentSegmentSummary.chunk_id == segment.id,
DocumentSegmentSummary.dataset_id == dataset.id,
)
.first()
.limit(1)
)
# Check if summary has changed
@@ -3473,11 +3484,7 @@ class SegmentService:
else:
raise ValueError("The knowledge base index technique is not high quality!")
# get the process rule
processing_rule = (
db.session.query(DatasetProcessRule)
.where(DatasetProcessRule.id == document.dataset_process_rule_id)
.first()
)
processing_rule = db.session.get(DatasetProcessRule, document.dataset_process_rule_id)
if processing_rule:
VectorService.generate_child_chunks(
segment, document, dataset, embedding_model_instance, processing_rule, True
@@ -3489,13 +3496,13 @@ class SegmentService:
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
from models.dataset import DocumentSegmentSummary
existing_summary = (
db.session.query(DocumentSegmentSummary)
existing_summary = db.session.scalar(
select(DocumentSegmentSummary)
.where(
DocumentSegmentSummary.chunk_id == segment.id,
DocumentSegmentSummary.dataset_id == dataset.id,
)
.first()
.limit(1)
)
if args.summary is None:
@@ -3561,7 +3568,7 @@ class SegmentService:
segment.status = SegmentStatus.ERROR
segment.error = str(e)
db.session.commit()
new_segment = db.session.query(DocumentSegment).where(DocumentSegment.id == segment.id).first()
new_segment = db.session.get(DocumentSegment, segment.id)
if not new_segment:
raise ValueError("new_segment is not found")
return new_segment
@@ -3581,15 +3588,14 @@ class SegmentService:
# Get child chunk IDs before parent segment is deleted
child_node_ids = []
if segment.index_node_id:
child_chunks = (
db.session.query(ChildChunk.index_node_id)
.where(
ChildChunk.segment_id == segment.id,
ChildChunk.dataset_id == dataset.id,
)
.all()
child_node_ids = list(
db.session.scalars(
select(ChildChunk.index_node_id).where(
ChildChunk.segment_id == segment.id,
ChildChunk.dataset_id == dataset.id,
)
).all()
)
child_node_ids = [chunk[0] for chunk in child_chunks if chunk[0]]
delete_segment_from_index_task.delay(
[segment.index_node_id], dataset.id, document.id, [segment.id], child_node_ids
@@ -3608,17 +3614,14 @@ class SegmentService:
# Check that segment_ids is not empty to avoid emitting a WHERE false condition
if not segment_ids or len(segment_ids) == 0:
return
segments_info = (
db.session.query(DocumentSegment)
.with_entities(DocumentSegment.index_node_id, DocumentSegment.id, DocumentSegment.word_count)
.where(
segments_info = db.session.execute(
select(DocumentSegment.index_node_id, DocumentSegment.id, DocumentSegment.word_count).where(
DocumentSegment.id.in_(segment_ids),
DocumentSegment.dataset_id == dataset.id,
DocumentSegment.document_id == document.id,
DocumentSegment.tenant_id == current_user.current_tenant_id,
)
.all()
)
).all()
if not segments_info:
return
@@ -3630,15 +3633,16 @@ class SegmentService:
# Get child chunk IDs before parent segments are deleted
child_node_ids = []
if index_node_ids:
child_chunks = (
db.session.query(ChildChunk.index_node_id)
.where(
ChildChunk.segment_id.in_(segment_db_ids),
ChildChunk.dataset_id == dataset.id,
)
.all()
)
child_node_ids = [chunk[0] for chunk in child_chunks if chunk[0]]
child_node_ids = [
nid
for nid in db.session.scalars(
select(ChildChunk.index_node_id).where(
ChildChunk.segment_id.in_(segment_db_ids),
ChildChunk.dataset_id == dataset.id,
)
).all()
if nid
]
# Start async cleanup with both parent and child node IDs
if index_node_ids or child_node_ids:
@@ -3654,7 +3658,7 @@ class SegmentService:
db.session.add(document)
# Delete database records
db.session.query(DocumentSegment).where(DocumentSegment.id.in_(segment_ids)).delete()
db.session.execute(delete(DocumentSegment).where(DocumentSegment.id.in_(segment_ids)))
db.session.commit()
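The bulk statements these hunks switch to, update(Model).where(...).values(...) and delete(Model).where(...), are the Core replacements for the legacy Query.update()/Query.delete(). A minimal sketch against a toy table:

from sqlalchemy import Column, MetaData, String, Table, create_engine, delete, update

metadata = MetaData()
segments = Table(
    "segments", metadata,
    Column("id", String, primary_key=True),
    Column("document_id", String),
    Column("status", String),
)
engine = create_engine("sqlite://")
metadata.create_all(engine)

with engine.begin() as conn:
    # UPDATE segments SET status = 're_segment' WHERE document_id = 'doc-1'
    conn.execute(
        update(segments)
        .where(segments.c.document_id == "doc-1")
        .values(status="re_segment")
    )
    # DELETE FROM segments WHERE id IN ('s1', 's2')
    conn.execute(delete(segments).where(segments.c.id.in_(["s1", "s2"])))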
@classmethod
@@ -3728,15 +3732,13 @@ class SegmentService:
with redis_client.lock(lock_name, timeout=20):
index_node_id = str(uuid.uuid4())
index_node_hash = helper.generate_text_hash(content)
max_position = (
db.session.query(func.max(ChildChunk.position))
.where(
max_position = db.session.scalar(
select(func.max(ChildChunk.position)).where(
ChildChunk.tenant_id == current_user.current_tenant_id,
ChildChunk.dataset_id == dataset.id,
ChildChunk.document_id == document.id,
ChildChunk.segment_id == segment.id,
)
.scalar()
)
child_chunk = ChildChunk(
tenant_id=current_user.current_tenant_id,
@@ -3896,10 +3898,8 @@ class SegmentService:
@classmethod
def get_child_chunk_by_id(cls, child_chunk_id: str, tenant_id: str) -> ChildChunk | None:
"""Get a child chunk by its ID."""
result = (
db.session.query(ChildChunk)
.where(ChildChunk.id == child_chunk_id, ChildChunk.tenant_id == tenant_id)
.first()
result = db.session.scalar(
select(ChildChunk).where(ChildChunk.id == child_chunk_id, ChildChunk.tenant_id == tenant_id).limit(1)
)
return result if isinstance(result, ChildChunk) else None
@@ -3934,10 +3934,10 @@ class SegmentService:
@classmethod
def get_segment_by_id(cls, segment_id: str, tenant_id: str) -> DocumentSegment | None:
"""Get a segment by its ID."""
result = (
db.session.query(DocumentSegment)
result = db.session.scalar(
select(DocumentSegment)
.where(DocumentSegment.id == segment_id, DocumentSegment.tenant_id == tenant_id)
.first()
.limit(1)
)
return result if isinstance(result, DocumentSegment) else None
@@ -3980,15 +3980,15 @@ class DatasetCollectionBindingService:
def get_dataset_collection_binding(
cls, provider_name: str, model_name: str, collection_type: str = "dataset"
) -> DatasetCollectionBinding:
dataset_collection_binding = (
db.session.query(DatasetCollectionBinding)
dataset_collection_binding = db.session.scalar(
select(DatasetCollectionBinding)
.where(
DatasetCollectionBinding.provider_name == provider_name,
DatasetCollectionBinding.model_name == model_name,
DatasetCollectionBinding.type == collection_type,
)
.order_by(DatasetCollectionBinding.created_at)
.first()
.limit(1)
)
if not dataset_collection_binding:
@@ -4006,13 +4006,13 @@ class DatasetCollectionBindingService:
def get_dataset_collection_binding_by_id_and_type(
cls, collection_binding_id: str, collection_type: str = "dataset"
) -> DatasetCollectionBinding:
dataset_collection_binding = (
db.session.query(DatasetCollectionBinding)
dataset_collection_binding = db.session.scalar(
select(DatasetCollectionBinding)
.where(
DatasetCollectionBinding.id == collection_binding_id, DatasetCollectionBinding.type == collection_type
)
.order_by(DatasetCollectionBinding.created_at)
.first()
.limit(1)
)
if not dataset_collection_binding:
raise ValueError("Dataset collection binding not found")
@@ -4034,7 +4034,7 @@ class DatasetPermissionService:
@classmethod
def update_partial_member_list(cls, tenant_id, dataset_id, user_list):
try:
db.session.query(DatasetPermission).where(DatasetPermission.dataset_id == dataset_id).delete()
db.session.execute(delete(DatasetPermission).where(DatasetPermission.dataset_id == dataset_id))
permissions = []
for user in user_list:
permission = DatasetPermission(
@@ -4070,7 +4070,7 @@ class DatasetPermissionService:
@classmethod
def clear_partial_member_list(cls, dataset_id):
try:
db.session.query(DatasetPermission).where(DatasetPermission.dataset_id == dataset_id).delete()
db.session.execute(delete(DatasetPermission).where(DatasetPermission.dataset_id == dataset_id))
db.session.commit()
except Exception as e:
db.session.rollback()

View File

@@ -4,6 +4,7 @@ from collections.abc import Mapping
from typing import Any
from graphon.model_runtime.entities.provider_entities import FormType
from sqlalchemy import func, select
from sqlalchemy.orm import Session
from configs import dify_config
@@ -367,16 +368,16 @@ class DatasourceProviderService:
check whether tenant oauth params are enabled
"""
return (
db.session.query(DatasourceOauthTenantParamConfig)
.filter_by(
tenant_id=tenant_id,
provider=datasource_provider_id.provider_name,
plugin_id=datasource_provider_id.plugin_id,
enabled=True,
db.session.scalar(
select(func.count(DatasourceOauthTenantParamConfig.id)).where(
DatasourceOauthTenantParamConfig.tenant_id == tenant_id,
DatasourceOauthTenantParamConfig.provider == datasource_provider_id.provider_name,
DatasourceOauthTenantParamConfig.plugin_id == datasource_provider_id.plugin_id,
DatasourceOauthTenantParamConfig.enabled == True,
)
)
.count()
> 0
)
or 0
) > 0
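The rewritten check above counts matching rows with select(func.count(...)) read via scalar(); since COUNT never returns NULL, the `or 0` mainly narrows the Optional that scalar() is typed to return. When only existence matters, a select(...).exists() subquery stops at the first match. A sketch with an illustrative table:

from sqlalchemy import Column, Integer, MetaData, String, Table, create_engine, func, select

metadata = MetaData()
configs = Table(
    "oauth_params", metadata,
    Column("id", Integer, primary_key=True),
    Column("tenant_id", String),
    Column("enabled", Integer),
)
engine = create_engine("sqlite://")
metadata.create_all(engine)

with engine.connect() as conn:
    # count-as-scalar, as in the hunk above
    n = conn.execute(
        select(func.count(configs.c.id)).where(
            configs.c.tenant_id == "t-1", configs.c.enabled == 1
        )
    ).scalar() or 0
    has_any = n > 0
    # equivalent pure existence test that can short-circuit
    base = select(configs.c.id).where(
        configs.c.tenant_id == "t-1", configs.c.enabled == 1
    )
    has_any_fast = conn.execute(select(base.exists())).scalar()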
def get_tenant_oauth_client(
self, tenant_id: str, datasource_provider_id: DatasourceProviderID, mask: bool = False
@@ -384,14 +385,14 @@ class DatasourceProviderService:
"""
get tenant oauth client
"""
tenant_oauth_client_params = (
db.session.query(DatasourceOauthTenantParamConfig)
.filter_by(
tenant_id=tenant_id,
provider=datasource_provider_id.provider_name,
plugin_id=datasource_provider_id.plugin_id,
tenant_oauth_client_params = db.session.scalar(
select(DatasourceOauthTenantParamConfig)
.where(
DatasourceOauthTenantParamConfig.tenant_id == tenant_id,
DatasourceOauthTenantParamConfig.provider == datasource_provider_id.provider_name,
DatasourceOauthTenantParamConfig.plugin_id == datasource_provider_id.plugin_id,
)
.first()
.limit(1)
)
if tenant_oauth_client_params:
encrypter, _ = self.get_oauth_encrypter(tenant_id, datasource_provider_id)
@@ -707,24 +708,27 @@ class DatasourceProviderService:
:return:
"""
# Get all provider configurations of the current workspace
datasource_providers: list[DatasourceProvider] = (
db.session.query(DatasourceProvider)
datasource_providers: list[DatasourceProvider] = list(
db.session.scalars(
select(DatasourceProvider).where(
DatasourceProvider.tenant_id == tenant_id,
DatasourceProvider.provider == provider,
DatasourceProvider.plugin_id == plugin_id,
)
).all()
)
if not datasource_providers:
return []
copy_credentials_list = []
default_provider = db.session.execute(
select(DatasourceProvider.id)
.where(
DatasourceProvider.tenant_id == tenant_id,
DatasourceProvider.provider == provider,
DatasourceProvider.plugin_id == plugin_id,
)
.all()
)
if not datasource_providers:
return []
copy_credentials_list = []
default_provider = (
db.session.query(DatasourceProvider.id)
.filter_by(tenant_id=tenant_id, provider=provider, plugin_id=plugin_id)
.order_by(DatasourceProvider.is_default.desc(), DatasourceProvider.created_at.asc())
.first()
)
).first()
default_provider_id = default_provider.id if default_provider else None
for datasource_provider in datasource_providers:
encrypted_credentials = datasource_provider.encrypted_credentials
@@ -880,14 +884,14 @@ class DatasourceProviderService:
:return:
"""
# Get all provider configurations of the current workspace
datasource_providers: list[DatasourceProvider] = (
db.session.query(DatasourceProvider)
.where(
DatasourceProvider.tenant_id == tenant_id,
DatasourceProvider.provider == provider,
DatasourceProvider.plugin_id == plugin_id,
)
.all()
datasource_providers: list[DatasourceProvider] = list(
db.session.scalars(
select(DatasourceProvider).where(
DatasourceProvider.tenant_id == tenant_id,
DatasourceProvider.provider == provider,
DatasourceProvider.plugin_id == plugin_id,
)
).all()
)
if not datasource_providers:
return []
@@ -987,10 +991,15 @@ class DatasourceProviderService:
:param plugin_id: plugin id
:return:
"""
datasource_provider = (
db.session.query(DatasourceProvider)
.filter_by(tenant_id=tenant_id, id=auth_id, provider=provider, plugin_id=plugin_id)
.first()
datasource_provider = db.session.scalar(
select(DatasourceProvider)
.where(
DatasourceProvider.tenant_id == tenant_id,
DatasourceProvider.id == auth_id,
DatasourceProvider.provider == provider,
DatasourceProvider.plugin_id == plugin_id,
)
.limit(1)
)
if datasource_provider:
db.session.delete(datasource_provider)

View File

@@ -1,7 +1,7 @@
import logging
from collections.abc import Mapping
from sqlalchemy import case
from sqlalchemy import case, select
from sqlalchemy.orm import Session
from core.app.entities.app_invoke_entities import InvokeFrom
@@ -25,14 +25,14 @@ class EndUserService:
"""
with Session(db.engine, expire_on_commit=False) as session:
return (
session.query(EndUser)
return session.scalar(
select(EndUser)
.where(
EndUser.id == end_user_id,
EndUser.tenant_id == tenant_id,
EndUser.app_id == app_id,
)
.first()
.limit(1)
)
@classmethod
@@ -57,8 +57,8 @@ class EndUserService:
with Session(db.engine, expire_on_commit=False) as session:
# Query with ORDER BY to prioritize exact type matches while maintaining backward compatibility
# This single query approach is more efficient than separate queries
end_user = (
session.query(EndUser)
end_user = session.scalar(
select(EndUser)
.where(
EndUser.tenant_id == tenant_id,
EndUser.app_id == app_id,
@@ -68,7 +68,7 @@ class EndUserService:
# Prioritize records with matching type (0 = match, 1 = no match)
case((EndUser.type == type, 0), else_=1)
)
.first()
.limit(1)
)
if end_user:
@@ -137,15 +137,15 @@ class EndUserService:
with Session(db.engine, expire_on_commit=False) as session:
# Fetch existing end users for all target apps in a single query
existing_end_users: list[EndUser] = (
session.query(EndUser)
.where(
EndUser.tenant_id == tenant_id,
EndUser.app_id.in_(unique_app_ids),
EndUser.session_id == user_id,
EndUser.type == type,
)
.all()
existing_end_users: list[EndUser] = list(
session.scalars(
select(EndUser).where(
EndUser.tenant_id == tenant_id,
EndUser.app_id.in_(unique_app_ids),
EndUser.session_id == user_id,
EndUser.type == type,
)
).all()
)
found_app_ids: set[str] = set()
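Note the list(...) wrapper these hunks add around scalars(...).all(): in SQLAlchemy 2.0 .all() is typed as returning a Sequence, so the wrapper is what satisfies a list[...] annotation and yields a mutable list. A minimal sketch:

from sqlalchemy import create_engine, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column

class Base(DeclarativeBase):
    pass

class EndUserRow(Base):
    __tablename__ = "end_users"
    id: Mapped[int] = mapped_column(primary_key=True)
    app_id: Mapped[str]

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add_all([EndUserRow(app_id="a1"), EndUserRow(app_id="a2")])
    session.commit()
    rows: list[EndUserRow] = list(
        session.scalars(
            select(EndUserRow).where(EndUserRow.app_id.in_(["a1", "a2"]))
        ).all()
    )
    assert len(rows) == 2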

View File

@@ -1,23 +1,15 @@
import enum
import logging
from pydantic import BaseModel
from configs import dify_config
from core.entities import PluginCredentialType
from services.enterprise.base import EnterprisePluginManagerRequest
from services.errors.base import BaseServiceError
logger = logging.getLogger(__name__)
class PluginCredentialType(enum.Enum):
MODEL = 0 # must be 0 for API contract compatibility
TOOL = 1 # must be 1 for API contract compatibility
def to_number(self):
return self.value
class CheckCredentialPolicyComplianceRequest(BaseModel):
dify_credential_id: str
provider: str

View File

@@ -0,0 +1,31 @@
from pydantic import BaseModel, Field, field_validator
from libs.helper import EmailStr
from libs.password import valid_password
class LoginPayloadBase(BaseModel):
email: EmailStr
password: str
class ForgotPasswordSendPayload(BaseModel):
email: EmailStr
language: str | None = None
class ForgotPasswordCheckPayload(BaseModel):
email: EmailStr
code: str
token: str = Field(min_length=1)
class ForgotPasswordResetPayload(BaseModel):
token: str = Field(min_length=1)
new_password: str
password_confirm: str
@field_validator("new_password", "password_confirm")
@classmethod
def validate_password(cls, value: str) -> str:
return valid_password(value)

View File

@@ -1,17 +1,12 @@
from enum import StrEnum
from typing import Literal
from pydantic import BaseModel, field_validator
from core.rag.entities import Rule
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.retrieval.retrieval_methods import RetrievalMethod
class ParentMode(StrEnum):
FULL_DOC = "full-doc"
PARAGRAPH = "paragraph"
class NotionIcon(BaseModel):
type: str
url: str | None = None
@@ -53,24 +48,6 @@ class DataSource(BaseModel):
info_list: InfoList
class PreProcessingRule(BaseModel):
id: str
enabled: bool
class Segmentation(BaseModel):
separator: str = "\n"
max_tokens: int
chunk_overlap: int = 0
class Rule(BaseModel):
pre_processing_rules: list[PreProcessingRule] | None = None
segmentation: Segmentation | None = None
parent_mode: Literal["full-doc", "paragraph"] | None = None
subchunk_segmentation: Segmentation | None = None
class ProcessRule(BaseModel):
mode: Literal["automatic", "custom", "hierarchical"]
rules: Rule | None = None

View File

@@ -2,6 +2,7 @@ from typing import Literal
from pydantic import BaseModel, field_validator
from core.rag.entities import KeywordSetting, VectorSetting
from core.rag.retrieval.retrieval_methods import RetrievalMethod
@@ -36,24 +37,6 @@ class RerankingModelConfig(BaseModel):
reranking_model_name: str | None = ""
class VectorSetting(BaseModel):
"""
Vector Setting.
"""
vector_weight: float
embedding_provider_name: str
embedding_model_name: str
class KeywordSetting(BaseModel):
"""
Keyword Setting.
"""
keyword_weight: float
class WeightedScoreConfig(BaseModel):
"""
Weighted score Config.
@@ -63,23 +46,6 @@ class WeightedScoreConfig(BaseModel):
keyword_setting: KeywordSetting | None
class EmbeddingSetting(BaseModel):
"""
Embedding Setting.
"""
embedding_provider_name: str
embedding_model_name: str
class EconomySetting(BaseModel):
"""
Economy Setting.
"""
keyword_number: int
class RetrievalSetting(BaseModel):
"""
Retrieval Setting.
@@ -95,16 +61,6 @@ class RetrievalSetting(BaseModel):
weights: WeightedScoreConfig | None = None
class IndexMethod(BaseModel):
"""
Knowledge Index Setting.
"""
indexing_technique: Literal["high_quality", "economy"]
embedding_setting: EmbeddingSetting
economy_setting: EconomySetting
class KnowledgeConfiguration(BaseModel):
"""
Knowledge Base Configuration.

View File

@@ -5,11 +5,11 @@ from urllib.parse import urlparse
import httpx
from graphon.nodes.http_request.exc import InvalidHttpMethodError
from sqlalchemy import select
from sqlalchemy import func, select
from constants import HIDDEN_VALUE
from core.helper import ssrf_proxy
from core.rag.entities.metadata_entities import MetadataCondition
from core.rag.entities import MetadataFilteringCondition
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models.dataset import (
@@ -103,8 +103,10 @@ class ExternalDatasetService:
@staticmethod
def get_external_knowledge_api(external_knowledge_api_id: str, tenant_id: str) -> ExternalKnowledgeApis:
external_knowledge_api: ExternalKnowledgeApis | None = (
db.session.query(ExternalKnowledgeApis).filter_by(id=external_knowledge_api_id, tenant_id=tenant_id).first()
external_knowledge_api: ExternalKnowledgeApis | None = db.session.scalar(
select(ExternalKnowledgeApis)
.where(ExternalKnowledgeApis.id == external_knowledge_api_id, ExternalKnowledgeApis.tenant_id == tenant_id)
.limit(1)
)
if external_knowledge_api is None:
raise ValueError("api template not found")
@@ -112,8 +114,10 @@ class ExternalDatasetService:
@staticmethod
def update_external_knowledge_api(tenant_id, user_id, external_knowledge_api_id, args) -> ExternalKnowledgeApis:
external_knowledge_api: ExternalKnowledgeApis | None = (
db.session.query(ExternalKnowledgeApis).filter_by(id=external_knowledge_api_id, tenant_id=tenant_id).first()
external_knowledge_api: ExternalKnowledgeApis | None = db.session.scalar(
select(ExternalKnowledgeApis)
.where(ExternalKnowledgeApis.id == external_knowledge_api_id, ExternalKnowledgeApis.tenant_id == tenant_id)
.limit(1)
)
if external_knowledge_api is None:
raise ValueError("api template not found")
@@ -132,8 +136,10 @@ class ExternalDatasetService:
@staticmethod
def delete_external_knowledge_api(tenant_id: str, external_knowledge_api_id: str):
external_knowledge_api = (
db.session.query(ExternalKnowledgeApis).filter_by(id=external_knowledge_api_id, tenant_id=tenant_id).first()
external_knowledge_api = db.session.scalar(
select(ExternalKnowledgeApis)
.where(ExternalKnowledgeApis.id == external_knowledge_api_id, ExternalKnowledgeApis.tenant_id == tenant_id)
.limit(1)
)
if external_knowledge_api is None:
raise ValueError("api template not found")
@@ -144,9 +150,12 @@ class ExternalDatasetService:
@staticmethod
def external_knowledge_api_use_check(external_knowledge_api_id: str) -> tuple[bool, int]:
count = (
db.session.query(ExternalKnowledgeBindings)
.filter_by(external_knowledge_api_id=external_knowledge_api_id)
.count()
db.session.scalar(
select(func.count(ExternalKnowledgeBindings.id)).where(
ExternalKnowledgeBindings.external_knowledge_api_id == external_knowledge_api_id
)
)
or 0
)
if count > 0:
return True, count
@@ -154,8 +163,10 @@ class ExternalDatasetService:
@staticmethod
def get_external_knowledge_binding_with_dataset_id(tenant_id: str, dataset_id: str) -> ExternalKnowledgeBindings:
external_knowledge_binding: ExternalKnowledgeBindings | None = (
db.session.query(ExternalKnowledgeBindings).filter_by(dataset_id=dataset_id, tenant_id=tenant_id).first()
external_knowledge_binding: ExternalKnowledgeBindings | None = db.session.scalar(
select(ExternalKnowledgeBindings)
.where(ExternalKnowledgeBindings.dataset_id == dataset_id, ExternalKnowledgeBindings.tenant_id == tenant_id)
.limit(1)
)
if not external_knowledge_binding:
raise ValueError("external knowledge binding not found")
@@ -163,8 +174,10 @@ class ExternalDatasetService:
@staticmethod
def document_create_args_validate(tenant_id: str, external_knowledge_api_id: str, process_parameter: dict):
external_knowledge_api = (
db.session.query(ExternalKnowledgeApis).filter_by(id=external_knowledge_api_id, tenant_id=tenant_id).first()
external_knowledge_api = db.session.scalar(
select(ExternalKnowledgeApis)
.where(ExternalKnowledgeApis.id == external_knowledge_api_id, ExternalKnowledgeApis.tenant_id == tenant_id)
.limit(1)
)
if external_knowledge_api is None or external_knowledge_api.settings is None:
raise ValueError("api template not found")
@@ -238,12 +251,17 @@ class ExternalDatasetService:
@staticmethod
def create_external_dataset(tenant_id: str, user_id: str, args: dict) -> Dataset:
# check if dataset name already exists
if db.session.query(Dataset).filter_by(name=args.get("name"), tenant_id=tenant_id).first():
if db.session.scalar(
select(Dataset).where(Dataset.name == args.get("name"), Dataset.tenant_id == tenant_id).limit(1)
):
raise DatasetNameDuplicateError(f"Dataset with name {args.get('name')} already exists.")
external_knowledge_api = (
db.session.query(ExternalKnowledgeApis)
.filter_by(id=args.get("external_knowledge_api_id"), tenant_id=tenant_id)
.first()
external_knowledge_api = db.session.scalar(
select(ExternalKnowledgeApis)
.where(
ExternalKnowledgeApis.id == args.get("external_knowledge_api_id"),
ExternalKnowledgeApis.tenant_id == tenant_id,
)
.limit(1)
)
if external_knowledge_api is None:
@@ -284,18 +302,20 @@ class ExternalDatasetService:
dataset_id: str,
query: str,
external_retrieval_parameters: dict,
metadata_condition: MetadataCondition | None = None,
metadata_condition: MetadataFilteringCondition | None = None,
):
external_knowledge_binding = (
db.session.query(ExternalKnowledgeBindings).filter_by(dataset_id=dataset_id, tenant_id=tenant_id).first()
external_knowledge_binding = db.session.scalar(
select(ExternalKnowledgeBindings)
.where(ExternalKnowledgeBindings.dataset_id == dataset_id, ExternalKnowledgeBindings.tenant_id == tenant_id)
.limit(1)
)
if not external_knowledge_binding:
raise ValueError("external knowledge binding not found")
external_knowledge_api = (
db.session.query(ExternalKnowledgeApis)
.filter_by(id=external_knowledge_binding.external_knowledge_api_id)
.first()
external_knowledge_api = db.session.scalar(
select(ExternalKnowledgeApis)
.where(ExternalKnowledgeApis.id == external_knowledge_binding.external_knowledge_api_id)
.limit(1)
)
if external_knowledge_api is None or external_knowledge_api.settings is None:
raise ValueError("external api template not found")

View File

@@ -1,7 +1,7 @@
import json
import logging
import time
from typing import Any
from typing import Any, TypedDict
from graphon.model_runtime.entities import LLMMode
@@ -18,6 +18,16 @@ from models.enums import CreatorUserRole, DatasetQuerySource
logger = logging.getLogger(__name__)
class QueryDict(TypedDict):
content: str
class RetrieveResponseDict(TypedDict):
query: QueryDict
records: list[dict[str, Any]]
default_retrieval_model = {
"search_method": RetrievalMethod.SEMANTIC_SEARCH,
"reranking_enable": False,
@@ -34,7 +44,7 @@ class HitTestingService:
dataset: Dataset,
query: str,
account: Account,
retrieval_model: Any, # FIXME drop this any
retrieval_model: dict | None,
external_retrieval_model: dict,
attachment_ids: list | None = None,
limit: int = 10,
@@ -44,12 +54,13 @@ class HitTestingService:
# get the retrieval model; if it is not set, use the default
if not retrieval_model:
retrieval_model = dataset.retrieval_model or default_retrieval_model
assert isinstance(retrieval_model, dict)
document_ids_filter = None
metadata_filtering_conditions = retrieval_model.get("metadata_filtering_conditions", {})
if metadata_filtering_conditions and query:
dataset_retrieval = DatasetRetrieval()
from core.app.app_config.entities import MetadataFilteringCondition
from core.rag.entities import MetadataFilteringCondition
metadata_filtering_conditions = MetadataFilteringCondition.model_validate(metadata_filtering_conditions)
@@ -150,7 +161,7 @@ class HitTestingService:
return dict(cls.compact_external_retrieve_response(dataset, query, all_documents))
@classmethod
def compact_retrieve_response(cls, query: str, documents: list[Document]) -> dict[Any, Any]:
def compact_retrieve_response(cls, query: str, documents: list[Document]) -> RetrieveResponseDict:
records = RetrievalService.format_retrieval_documents(documents)
return {
@@ -161,7 +172,7 @@ class HitTestingService:
}
@classmethod
def compact_external_retrieve_response(cls, dataset: Dataset, query: str, documents: list) -> dict[Any, Any]:
def compact_external_retrieve_response(cls, dataset: Dataset, query: str, documents: list) -> RetrieveResponseDict:
records = []
if dataset.provider == "external":
for document in documents:
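The TypedDict pair introduced at the top of this file lets the type checker verify the response shape at every construction site instead of passing dict[Any, Any] around. A minimal sketch of the same pattern, mirroring the field names above with fabricated data:

from typing import Any, TypedDict

class QueryDict(TypedDict):
    content: str

class RetrieveResponseDict(TypedDict):
    query: QueryDict
    records: list[dict[str, Any]]

def compact(query: str, records: list[dict[str, Any]]) -> RetrieveResponseDict:
    # mypy/pyright verify both key names and value types here, which a
    # dict[Any, Any] return type never could
    return {"query": {"content": query}, "records": records}

resp = compact("what is RAG?", [{"score": 0.92}])
assert resp["query"]["content"] == "what is RAG?"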

View File

@@ -1,6 +1,8 @@
import copy
import logging
from sqlalchemy import delete, func, select
from core.rag.index_processor.constant.built_in_field import BuiltInField, MetadataDataSource
from extensions.ext_database import db
from extensions.ext_redis import redis_client
@@ -25,10 +27,14 @@ class MetadataService:
raise ValueError("Metadata name cannot exceed 255 characters.")
current_user, current_tenant_id = current_account_with_tenant()
# check if metadata name already exists
if (
db.session.query(DatasetMetadata)
.filter_by(tenant_id=current_tenant_id, dataset_id=dataset_id, name=metadata_args.name)
.first()
if db.session.scalar(
select(DatasetMetadata)
.where(
DatasetMetadata.tenant_id == current_tenant_id,
DatasetMetadata.dataset_id == dataset_id,
DatasetMetadata.name == metadata_args.name,
)
.limit(1)
):
raise ValueError("Metadata name already exists.")
for field in BuiltInField:
@@ -54,10 +60,14 @@ class MetadataService:
lock_key = f"dataset_metadata_lock_{dataset_id}"
# check if metadata name already exists
current_user, current_tenant_id = current_account_with_tenant()
if (
db.session.query(DatasetMetadata)
.filter_by(tenant_id=current_tenant_id, dataset_id=dataset_id, name=name)
.first()
if db.session.scalar(
select(DatasetMetadata)
.where(
DatasetMetadata.tenant_id == current_tenant_id,
DatasetMetadata.dataset_id == dataset_id,
DatasetMetadata.name == name,
)
.limit(1)
):
raise ValueError("Metadata name already exists.")
for field in BuiltInField:
@@ -65,7 +75,11 @@ class MetadataService:
raise ValueError("Metadata name already exists in Built-in fields.")
try:
MetadataService.knowledge_base_metadata_lock_check(dataset_id, None)
metadata = db.session.query(DatasetMetadata).filter_by(id=metadata_id, dataset_id=dataset_id).first()
metadata = db.session.scalar(
select(DatasetMetadata)
.where(DatasetMetadata.id == metadata_id, DatasetMetadata.dataset_id == dataset_id)
.limit(1)
)
if metadata is None:
raise ValueError("Metadata not found.")
old_name = metadata.name
@@ -74,9 +88,9 @@ class MetadataService:
metadata.updated_at = naive_utc_now()
# update related documents
dataset_metadata_bindings = (
db.session.query(DatasetMetadataBinding).filter_by(metadata_id=metadata_id).all()
)
dataset_metadata_bindings = db.session.scalars(
select(DatasetMetadataBinding).where(DatasetMetadataBinding.metadata_id == metadata_id)
).all()
if dataset_metadata_bindings:
document_ids = [binding.document_id for binding in dataset_metadata_bindings]
documents = DocumentService.get_document_by_ids(document_ids)
@@ -101,15 +115,19 @@ class MetadataService:
lock_key = f"dataset_metadata_lock_{dataset_id}"
try:
MetadataService.knowledge_base_metadata_lock_check(dataset_id, None)
metadata = db.session.query(DatasetMetadata).filter_by(id=metadata_id, dataset_id=dataset_id).first()
metadata = db.session.scalar(
select(DatasetMetadata)
.where(DatasetMetadata.id == metadata_id, DatasetMetadata.dataset_id == dataset_id)
.limit(1)
)
if metadata is None:
raise ValueError("Metadata not found.")
db.session.delete(metadata)
# handle related documents
dataset_metadata_bindings = (
db.session.query(DatasetMetadataBinding).filter_by(metadata_id=metadata_id).all()
)
dataset_metadata_bindings = db.session.scalars(
select(DatasetMetadataBinding).where(DatasetMetadataBinding.metadata_id == metadata_id)
).all()
if dataset_metadata_bindings:
document_ids = [binding.document_id for binding in dataset_metadata_bindings]
documents = DocumentService.get_document_by_ids(document_ids)
@@ -224,16 +242,23 @@ class MetadataService:
# deal metadata binding (in the same transaction as the doc_metadata update)
if not operation.partial_update:
db.session.query(DatasetMetadataBinding).filter_by(document_id=operation.document_id).delete()
db.session.execute(
delete(DatasetMetadataBinding).where(
DatasetMetadataBinding.document_id == operation.document_id
)
)
current_user, current_tenant_id = current_account_with_tenant()
for metadata_value in operation.metadata_list:
# check if binding already exists
if operation.partial_update:
existing_binding = (
db.session.query(DatasetMetadataBinding)
.filter_by(document_id=operation.document_id, metadata_id=metadata_value.id)
.first()
existing_binding = db.session.scalar(
select(DatasetMetadataBinding)
.where(
DatasetMetadataBinding.document_id == operation.document_id,
DatasetMetadataBinding.metadata_id == metadata_value.id,
)
.limit(1)
)
if existing_binding:
continue
@@ -275,9 +300,13 @@ class MetadataService:
"id": item.get("id"),
"name": item.get("name"),
"type": item.get("type"),
"count": db.session.query(DatasetMetadataBinding)
.filter_by(metadata_id=item.get("id"), dataset_id=dataset.id)
.count(),
"count": db.session.scalar(
select(func.count(DatasetMetadataBinding.id)).where(
DatasetMetadataBinding.metadata_id == item.get("id"),
DatasetMetadataBinding.dataset_id == dataset.id,
)
)
or 0,
}
for item in dataset.doc_metadata or []
if item.get("id") != "built-in"

View File

@@ -1,6 +1,6 @@
import json
import logging
from typing import Any, Union
from typing import Any, TypedDict, Union
from graphon.model_runtime.entities.model_entities import ModelType
from graphon.model_runtime.entities.provider_entities import (
@@ -25,6 +25,23 @@ from models.provider import LoadBalancingModelConfig, ProviderCredential, Provid
logger = logging.getLogger(__name__)
class LoadBalancingConfigDetailDict(TypedDict):
id: str
name: str
credentials: dict[str, Any]
enabled: bool
class LoadBalancingConfigSummaryDict(TypedDict):
id: str
name: str
credentials: dict[str, Any]
credential_id: str | None
enabled: bool
in_cooldown: bool
ttl: int
class ModelLoadBalancingService:
@staticmethod
def _get_provider_manager(tenant_id: str) -> ProviderManager:
@@ -74,7 +91,7 @@ class ModelLoadBalancingService:
def get_load_balancing_configs(
self, tenant_id: str, provider: str, model: str, model_type: str, config_from: str = ""
) -> tuple[bool, list[dict]]:
) -> tuple[bool, list[LoadBalancingConfigSummaryDict]]:
"""
Get load balancing configurations.
:param tenant_id: workspace id
@@ -156,7 +173,7 @@ class ModelLoadBalancingService:
decoding_rsa_key, decoding_cipher_rsa = encrypter.get_decrypt_decoding(tenant_id)
# fetch status and ttl for each config
datas = []
datas: list[LoadBalancingConfigSummaryDict] = []
for load_balancing_config in load_balancing_configs:
in_cooldown, ttl = LBModelManager.get_config_in_cooldown_and_ttl(
tenant_id=tenant_id,
@@ -214,7 +231,7 @@ class ModelLoadBalancingService:
def get_load_balancing_config(
self, tenant_id: str, provider: str, model: str, model_type: str, config_id: str
) -> dict | None:
) -> LoadBalancingConfigDetailDict | None:
"""
Get load balancing configuration.
:param tenant_id: workspace id
@@ -267,12 +284,13 @@ class ModelLoadBalancingService:
credentials=credentials, credential_form_schemas=credential_schemas.credential_form_schemas
)
return {
result: LoadBalancingConfigDetailDict = {
"id": load_balancing_model_config.id,
"name": load_balancing_model_config.name,
"credentials": credentials,
"enabled": load_balancing_model_config.enabled,
}
return result
def _init_inherit_config(
self, tenant_id: str, provider: str, model: str, model_type: ModelType

View File

@@ -2,7 +2,7 @@ import enum
import uuid
from sqlalchemy import select
from sqlalchemy.orm import Session
from sqlalchemy.orm import sessionmaker
from werkzeug.exceptions import BadRequest
from extensions.ext_database import db
@@ -29,7 +29,7 @@ class OAuthServerService:
def get_oauth_provider_app(client_id: str) -> OAuthProviderApp | None:
query = select(OAuthProviderApp).where(OAuthProviderApp.client_id == client_id)
with Session(db.engine) as session:
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
return session.execute(query).scalar_one_or_none()
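expire_on_commit=False is what makes returning an ORM object out of this block safe: by default a commit expires all loaded attributes, so touching the returned object after the session closes would raise DetachedInstanceError. A self-contained sketch with a toy model:

from sqlalchemy import create_engine, select
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, sessionmaker

class Base(DeclarativeBase):
    pass

class OAuthApp(Base):
    __tablename__ = "oauth_apps"
    id: Mapped[int] = mapped_column(primary_key=True)
    client_id: Mapped[str]

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with sessionmaker(engine).begin() as session:
    session.add(OAuthApp(client_id="abc"))

factory = sessionmaker(engine, expire_on_commit=False)
with factory.begin() as session:
    app = session.execute(
        select(OAuthApp).where(OAuthApp.client_id == "abc")
    ).scalar_one_or_none()

# without expire_on_commit=False, this access after commit would raise
# DetachedInstanceError instead of reading the still-loaded attribute
assert app is not None and app.client_id == "abc"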
@staticmethod

View File

@@ -5,7 +5,7 @@ import time
from collections.abc import Mapping, Sequence
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from typing import Any, TypedDict
from typing import TypedDict
from uuid import uuid4
import click
@@ -42,6 +42,16 @@ class _TenantPluginRecord(TypedDict):
_tenant_plugin_adapter: TypeAdapter[_TenantPluginRecord] = TypeAdapter(_TenantPluginRecord)
class ExtractedPluginsDict(TypedDict):
plugins: dict[str, str]
plugin_not_exist: list[str]
class PluginInstallResultDict(TypedDict):
success: list[str]
failed: list[str]
class PluginMigration:
@classmethod
def extract_plugins(cls, filepath: str, workers: int):
@@ -310,7 +320,7 @@ class PluginMigration:
Path(output_file).write_text(json.dumps(cls.extract_unique_plugins(extracted_plugins)))
@classmethod
def extract_unique_plugins(cls, extracted_plugins: str) -> Mapping[str, Any]:
def extract_unique_plugins(cls, extracted_plugins: str) -> ExtractedPluginsDict:
plugins: dict[str, str] = {}
plugin_ids = []
plugin_not_exist = []
@@ -524,7 +534,7 @@ class PluginMigration:
@classmethod
def handle_plugin_instance_install(
cls, tenant_id: str, plugin_identifiers_map: Mapping[str, str]
) -> Mapping[str, Any]:
) -> PluginInstallResultDict:
"""
Install plugins for a tenant.
"""

View File

@@ -1,6 +1,7 @@
from collections.abc import Mapping, Sequence
from typing import Any, Literal
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.plugin.entities.parameters import PluginParameterOption
@@ -56,24 +57,24 @@ class PluginParameterService:
# fetch credentials from db
with Session(db.engine) as session:
if credential_id:
db_record = (
session.query(BuiltinToolProvider)
db_record = session.scalar(
select(BuiltinToolProvider)
.where(
BuiltinToolProvider.tenant_id == tenant_id,
BuiltinToolProvider.provider == provider,
BuiltinToolProvider.id == credential_id,
)
.first()
.limit(1)
)
else:
db_record = (
session.query(BuiltinToolProvider)
db_record = session.scalar(
select(BuiltinToolProvider)
.where(
BuiltinToolProvider.tenant_id == tenant_id,
BuiltinToolProvider.provider == provider,
)
.order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc())
.first()
.limit(1)
)
if db_record is None:

View File

@@ -38,11 +38,7 @@ from core.datasource.online_document.online_document_plugin import OnlineDocumen
from core.datasource.online_drive.online_drive_plugin import OnlineDriveDatasourcePlugin
from core.datasource.website_crawl.website_crawl_plugin import WebsiteCrawlDatasourcePlugin
from core.helper import marketplace
from core.rag.entities.event import (
DatasourceCompletedEvent,
DatasourceErrorEvent,
DatasourceProcessingEvent,
)
from core.rag.entities import DatasourceCompletedEvent, DatasourceErrorEvent, DatasourceProcessingEvent
from core.repositories.factory import DifyCoreRepositoryFactory, OrderConfig
from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository
from core.workflow.node_factory import LATEST_VERSION, get_node_type_classes_mapping
@@ -156,27 +152,27 @@ class RagPipelineService:
:param template_id: template id
:param template_info: template info
"""
customized_template: PipelineCustomizedTemplate | None = (
db.session.query(PipelineCustomizedTemplate)
customized_template: PipelineCustomizedTemplate | None = db.session.scalar(
select(PipelineCustomizedTemplate)
.where(
PipelineCustomizedTemplate.id == template_id,
PipelineCustomizedTemplate.tenant_id == current_user.current_tenant_id,
)
.first()
.limit(1)
)
if not customized_template:
raise ValueError("Customized pipeline template not found.")
# check whether the template name already exists
template_name = template_info.name
if template_name:
template = (
db.session.query(PipelineCustomizedTemplate)
template = db.session.scalar(
select(PipelineCustomizedTemplate)
.where(
PipelineCustomizedTemplate.name == template_name,
PipelineCustomizedTemplate.tenant_id == current_user.current_tenant_id,
PipelineCustomizedTemplate.id != template_id,
)
.first()
.limit(1)
)
if template:
raise ValueError("Template name is already exists")
@@ -192,13 +188,13 @@ class RagPipelineService:
"""
Delete customized pipeline template.
"""
customized_template: PipelineCustomizedTemplate | None = (
db.session.query(PipelineCustomizedTemplate)
customized_template: PipelineCustomizedTemplate | None = db.session.scalar(
select(PipelineCustomizedTemplate)
.where(
PipelineCustomizedTemplate.id == template_id,
PipelineCustomizedTemplate.tenant_id == current_user.current_tenant_id,
)
.first()
.limit(1)
)
if not customized_template:
raise ValueError("Customized pipeline template not found.")
@@ -210,14 +206,14 @@ class RagPipelineService:
Get draft workflow
"""
# fetch draft workflow by rag pipeline
workflow = (
db.session.query(Workflow)
workflow = db.session.scalar(
select(Workflow)
.where(
Workflow.tenant_id == pipeline.tenant_id,
Workflow.app_id == pipeline.id,
Workflow.version == "draft",
)
.first()
.limit(1)
)
# return draft workflow
@@ -232,28 +228,28 @@ class RagPipelineService:
return None
# fetch published workflow by workflow_id
workflow = (
db.session.query(Workflow)
workflow = db.session.scalar(
select(Workflow)
.where(
Workflow.tenant_id == pipeline.tenant_id,
Workflow.app_id == pipeline.id,
Workflow.id == pipeline.workflow_id,
)
.first()
.limit(1)
)
return workflow
def get_published_workflow_by_id(self, pipeline: Pipeline, workflow_id: str) -> Workflow | None:
"""Fetch a published workflow snapshot by ID for restore operations."""
workflow = (
db.session.query(Workflow)
workflow = db.session.scalar(
select(Workflow)
.where(
Workflow.tenant_id == pipeline.tenant_id,
Workflow.app_id == pipeline.id,
Workflow.id == workflow_id,
)
.first()
.limit(1)
)
if workflow and workflow.version == Workflow.VERSION_DRAFT:
raise IsDraftWorkflowError("source workflow must be published")
@@ -974,7 +970,7 @@ class RagPipelineService:
if invoke_from.value == InvokeFrom.PUBLISHED_PIPELINE:
document_id = get_system_segment(variable_pool, SystemVariableKey.DOCUMENT_ID)
if document_id:
document = db.session.query(Document).where(Document.id == document_id.value).first()
document = db.session.get(Document, document_id.value)
if document:
document.indexing_status = IndexingStatus.ERROR
document.error = error
@@ -1178,15 +1174,15 @@ class RagPipelineService:
"""
Publish customized pipeline template
"""
pipeline = db.session.query(Pipeline).where(Pipeline.id == pipeline_id).first()
pipeline = db.session.get(Pipeline, pipeline_id)
if not pipeline:
raise ValueError("Pipeline not found")
if not pipeline.workflow_id:
raise ValueError("Pipeline workflow not found")
workflow = db.session.query(Workflow).where(Workflow.id == pipeline.workflow_id).first()
workflow = db.session.get(Workflow, pipeline.workflow_id)
if not workflow:
raise ValueError("Workflow not found")
with Session(db.engine) as session:
with sessionmaker(db.engine).begin() as session:
dataset = pipeline.retrieve_dataset(session=session)
if not dataset:
raise ValueError("Dataset not found")
@@ -1194,26 +1190,26 @@ class RagPipelineService:
# check whether the template name already exists
template_name = args.get("name")
if template_name:
template = (
db.session.query(PipelineCustomizedTemplate)
template = db.session.scalar(
select(PipelineCustomizedTemplate)
.where(
PipelineCustomizedTemplate.name == template_name,
PipelineCustomizedTemplate.tenant_id == pipeline.tenant_id,
)
.first()
.limit(1)
)
if template:
raise ValueError("Template name is already exists")
max_position = (
db.session.query(func.max(PipelineCustomizedTemplate.position))
.where(PipelineCustomizedTemplate.tenant_id == pipeline.tenant_id)
.scalar()
max_position = db.session.scalar(
select(func.max(PipelineCustomizedTemplate.position)).where(
PipelineCustomizedTemplate.tenant_id == pipeline.tenant_id
)
)
from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService
with Session(db.engine) as session:
with sessionmaker(db.engine).begin() as session:
rag_pipeline_dsl_service = RagPipelineDslService(session)
dsl = rag_pipeline_dsl_service.export_rag_pipeline_dsl(pipeline=pipeline, include_secret=True)
if args.get("icon_info") is None:
@@ -1239,13 +1235,14 @@ class RagPipelineService:
def is_workflow_exist(self, pipeline: Pipeline) -> bool:
return (
db.session.query(Workflow)
.where(
Workflow.tenant_id == pipeline.tenant_id,
Workflow.app_id == pipeline.id,
Workflow.version == Workflow.VERSION_DRAFT,
db.session.scalar(
select(func.count(Workflow.id)).where(
Workflow.tenant_id == pipeline.tenant_id,
Workflow.app_id == pipeline.id,
Workflow.version == Workflow.VERSION_DRAFT,
)
)
.count()
or 0
) > 0
def get_node_last_run(
@@ -1353,11 +1350,11 @@ class RagPipelineService:
def get_recommended_plugins(self, type: str) -> dict:
# Query active recommended plugins
query = db.session.query(PipelineRecommendedPlugin).where(PipelineRecommendedPlugin.active == True)
stmt = select(PipelineRecommendedPlugin).where(PipelineRecommendedPlugin.active == True)
if type and type != "all":
query = query.where(PipelineRecommendedPlugin.type == type)
stmt = stmt.where(PipelineRecommendedPlugin.type == type)
pipeline_recommended_plugins = query.order_by(PipelineRecommendedPlugin.position.asc()).all()
pipeline_recommended_plugins = db.session.scalars(stmt.order_by(PipelineRecommendedPlugin.position.asc())).all()
if not pipeline_recommended_plugins:
return {
@@ -1396,14 +1393,12 @@ class RagPipelineService:
"""
Retry a document that ended in the error state
"""
document_pipeline_execution_log = (
db.session.query(DocumentPipelineExecutionLog)
.where(DocumentPipelineExecutionLog.document_id == document.id)
.first()
document_pipeline_execution_log = db.session.scalar(
select(DocumentPipelineExecutionLog).where(DocumentPipelineExecutionLog.document_id == document.id).limit(1)
)
if not document_pipeline_execution_log:
raise ValueError("Document pipeline execution log not found")
pipeline = db.session.query(Pipeline).where(Pipeline.id == document_pipeline_execution_log.pipeline_id).first()
pipeline = db.session.get(Pipeline, document_pipeline_execution_log.pipeline_id)
if not pipeline:
raise ValueError("Pipeline not found")
# convert to app config
@@ -1432,23 +1427,23 @@ class RagPipelineService:
"""
Get datasource plugins
"""
dataset: Dataset | None = (
db.session.query(Dataset)
dataset: Dataset | None = db.session.scalar(
select(Dataset)
.where(
Dataset.id == dataset_id,
Dataset.tenant_id == tenant_id,
)
.first()
.limit(1)
)
if not dataset:
raise ValueError("Dataset not found")
pipeline: Pipeline | None = (
db.session.query(Pipeline)
pipeline: Pipeline | None = db.session.scalar(
select(Pipeline)
.where(
Pipeline.id == dataset.pipeline_id,
Pipeline.tenant_id == tenant_id,
)
.first()
.limit(1)
)
if not pipeline:
raise ValueError("Pipeline not found")
@@ -1530,23 +1525,23 @@ class RagPipelineService:
"""
Get pipeline
"""
dataset: Dataset | None = (
db.session.query(Dataset)
dataset: Dataset | None = db.session.scalar(
select(Dataset)
.where(
Dataset.id == dataset_id,
Dataset.tenant_id == tenant_id,
)
.first()
.limit(1)
)
if not dataset:
raise ValueError("Dataset not found")
pipeline: Pipeline | None = (
db.session.query(Pipeline)
pipeline: Pipeline | None = db.session.scalar(
select(Pipeline)
.where(
Pipeline.id == dataset.pipeline_id,
Pipeline.tenant_id == tenant_id,
)
.first()
.limit(1)
)
if not pipeline:
raise ValueError("Pipeline not found")

View File

@@ -3,7 +3,7 @@ import logging
import random
import time
from collections.abc import Sequence
from typing import TYPE_CHECKING, cast
from typing import TYPE_CHECKING, TypedDict, cast
import sqlalchemy as sa
from sqlalchemy import delete, select, tuple_
@@ -158,6 +158,13 @@ class MessagesCleanupMetrics:
self._record(self._job_duration_seconds, job_duration_seconds, attributes)
class MessagesCleanStatsDict(TypedDict):
batches: int
total_messages: int
filtered_messages: int
total_deleted: int
class MessagesCleanService:
"""
Service for cleaning expired messages based on retention policies.
@@ -299,7 +306,7 @@ class MessagesCleanService:
task_label=task_label,
)
def run(self) -> dict[str, int]:
def run(self) -> MessagesCleanStatsDict:
"""
Execute the message cleanup operation.
@@ -319,7 +326,7 @@ class MessagesCleanService:
job_duration_seconds=time.monotonic() - run_start,
)
def _clean_messages_by_time_range(self) -> dict[str, int]:
def _clean_messages_by_time_range(self) -> MessagesCleanStatsDict:
"""
Clean messages within a time range using cursor-based pagination.
@@ -334,7 +341,7 @@ class MessagesCleanService:
Returns:
Dict with statistics: batches, total_messages, filtered_messages, total_deleted
"""
stats = {
stats: MessagesCleanStatsDict = {
"batches": 0,
"total_messages": 0,
"filtered_messages": 0,

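`MessagesCleanStatsDict` replaces the loosely typed `dict[str, int]` return. A small sketch of what the TypedDict buys, using the same field names as the hunk above; the function body is illustrative, not the service's real logic:

```python
# Minimal sketch of the TypedDict pattern used for the cleanup stats.
from typing import TypedDict


class MessagesCleanStatsDict(TypedDict):
    batches: int
    total_messages: int
    filtered_messages: int
    total_deleted: int


def run_cleanup() -> MessagesCleanStatsDict:
    stats: MessagesCleanStatsDict = {
        "batches": 0,
        "total_messages": 0,
        "filtered_messages": 0,
        "total_deleted": 0,
    }
    # A type checker now rejects stats["total_deletd"] or a str value,
    # which an untyped dict[str, int] return would silently allow.
    stats["batches"] += 1
    return stats


print(run_cleanup())
```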
View File

@@ -24,7 +24,7 @@ import zipfile
from collections.abc import Sequence
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from typing import Any
from typing import Any, TypedDict
import click
from graphon.enums import WorkflowType
@@ -49,6 +49,23 @@ from services.retention.workflow_run.constants import ARCHIVE_BUNDLE_NAME, ARCHI
logger = logging.getLogger(__name__)
class TableStatsManifestEntry(TypedDict):
row_count: int
checksum: str
size_bytes: int
class ArchiveManifestDict(TypedDict):
schema_version: str
workflow_run_id: str
tenant_id: str
app_id: str
workflow_id: str
created_at: str
archived_at: str
tables: dict[str, TableStatsManifestEntry]
@dataclass
class TableStats:
"""Statistics for a single archived table."""
@@ -472,25 +489,26 @@ class WorkflowRunArchiver:
self,
run: WorkflowRun,
table_stats: list[TableStats],
) -> dict[str, Any]:
) -> ArchiveManifestDict:
"""Generate a manifest for the archived workflow run."""
return {
"schema_version": ARCHIVE_SCHEMA_VERSION,
"workflow_run_id": run.id,
"tenant_id": run.tenant_id,
"app_id": run.app_id,
"workflow_id": run.workflow_id,
"created_at": run.created_at.isoformat(),
"archived_at": datetime.datetime.now(datetime.UTC).isoformat(),
"tables": {
stat.table_name: {
"row_count": stat.row_count,
"checksum": stat.checksum,
"size_bytes": stat.size_bytes,
}
for stat in table_stats
},
tables: dict[str, TableStatsManifestEntry] = {
stat.table_name: {
"row_count": stat.row_count,
"checksum": stat.checksum,
"size_bytes": stat.size_bytes,
}
for stat in table_stats
}
return ArchiveManifestDict(
schema_version=ARCHIVE_SCHEMA_VERSION,
workflow_run_id=run.id,
tenant_id=run.tenant_id,
app_id=run.app_id,
workflow_id=run.workflow_id,
created_at=run.created_at.isoformat(),
archived_at=datetime.datetime.now(datetime.UTC).isoformat(),
tables=tables,
)
def _build_archive_bundle(self, manifest_data: bytes, table_payloads: dict[str, bytes]) -> bytes:
buffer = io.BytesIO()

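The manifest refactor above introduces nested TypedDicts and keyword construction. A trimmed sketch under the same shape, with placeholder values and only a subset of the real fields:

```python
# Sketch of the nested-TypedDict manifest; values are placeholders and
# only some of the real manifest fields are shown.
import datetime
from typing import TypedDict


class TableStatsManifestEntry(TypedDict):
    row_count: int
    checksum: str
    size_bytes: int


class ArchiveManifestDict(TypedDict):
    schema_version: str
    workflow_run_id: str
    archived_at: str
    tables: dict[str, TableStatsManifestEntry]


def build_manifest() -> ArchiveManifestDict:
    tables: dict[str, TableStatsManifestEntry] = {
        "workflow_runs": {"row_count": 1, "checksum": "sha256:...", "size_bytes": 128},
    }
    # TypedDicts can be built with keyword arguments; at runtime this is
    # an ordinary dict, but checkers verify every key and value type.
    return ArchiveManifestDict(
        schema_version="1.0",
        workflow_run_id="run-1",
        archived_at=datetime.datetime.now(datetime.UTC).isoformat(),
        tables=tables,
    )


print(build_manifest()["tables"]["workflow_runs"]["row_count"])  # 1
```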
View File

@@ -3,7 +3,7 @@ import logging
import random
import time
from collections.abc import Iterable, Sequence
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, TypedDict
import click
from sqlalchemy.orm import Session, sessionmaker
@@ -12,7 +12,7 @@ from configs import dify_config
from enums.cloud_plan import CloudPlan
from extensions.ext_database import db
from models.workflow import WorkflowRun
from repositories.api_workflow_run_repository import APIWorkflowRunRepository
from repositories.api_workflow_run_repository import APIWorkflowRunRepository, RunsWithRelatedCountsDict
from repositories.factory import DifyAPIRepositoryFactory
from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository
from services.billing_service import BillingService, SubscriptionPlan
@@ -24,6 +24,15 @@ if TYPE_CHECKING:
from opentelemetry.metrics import Counter, Histogram
class RelatedCountsDict(TypedDict):
node_executions: int
offloads: int
app_logs: int
trigger_logs: int
pauses: int
pause_reasons: int
class WorkflowRunCleanupMetrics:
"""
Records low-cardinality OpenTelemetry metrics for workflow run cleanup jobs.
@@ -173,6 +182,9 @@ class WorkflowRunCleanupMetrics:
self._record(self._job_duration_seconds, job_duration_seconds, attributes)
_RELATED_RECORD_KEYS = ("node_executions", "offloads", "app_logs", "trigger_logs", "pauses", "pause_reasons")
class WorkflowRunCleanup:
def __init__(
self,
@@ -230,7 +242,7 @@ class WorkflowRunCleanup:
total_runs_deleted = 0
total_runs_targeted = 0
related_totals = self._empty_related_counts() if self.dry_run else None
related_totals: RelatedCountsDict | None = self._empty_related_counts() if self.dry_run else None
batch_index = 0
last_seen: tuple[datetime.datetime, str] | None = None
status = "success"
@@ -312,8 +324,7 @@ class WorkflowRunCleanup:
int((time.monotonic() - count_start) * 1000),
)
if related_totals is not None:
for key in related_totals:
related_totals[key] += batch_counts.get(key, 0)
self._accumulate_related_counts(related_totals, batch_counts)
sample_ids = ", ".join(run.id for run in free_runs[:5])
click.echo(
click.style(
@@ -332,7 +343,10 @@ class WorkflowRunCleanup:
targeted_runs=len(free_runs),
skipped_runs=paid_or_skipped,
deleted_runs=0,
related_counts={key: batch_counts.get(key, 0) for key in self._empty_related_counts()},
related_counts={
k: batch_counts[k] # type: ignore[literal-required]
for k in _RELATED_RECORD_KEYS
},
related_action="would_delete",
batch_duration_seconds=time.monotonic() - batch_start,
)
@@ -372,7 +386,10 @@ class WorkflowRunCleanup:
targeted_runs=len(free_runs),
skipped_runs=paid_or_skipped,
deleted_runs=counts["runs"],
related_counts={key: counts.get(key, 0) for key in self._empty_related_counts()},
related_counts={
k: counts[k] # type: ignore[literal-required]
for k in _RELATED_RECORD_KEYS
},
related_action="deleted",
batch_duration_seconds=time.monotonic() - batch_start,
)
@@ -506,7 +523,7 @@ class WorkflowRunCleanup:
return trigger_repo.count_by_run_ids(run_ids)
@staticmethod
def _empty_related_counts() -> dict[str, int]:
def _empty_related_counts() -> RelatedCountsDict:
return {
"node_executions": 0,
"offloads": 0,
@@ -517,7 +534,7 @@ class WorkflowRunCleanup:
}
@staticmethod
def _format_related_counts(counts: dict[str, int]) -> str:
def _format_related_counts(counts: RelatedCountsDict) -> str:
return (
f"node_executions {counts['node_executions']}, "
f"offloads {counts['offloads']}, "
@@ -527,6 +544,15 @@ class WorkflowRunCleanup:
f"pause_reasons {counts['pause_reasons']}"
)
@staticmethod
def _accumulate_related_counts(totals: RelatedCountsDict, batch: RunsWithRelatedCountsDict) -> None:
totals["node_executions"] += batch.get("node_executions", 0)
totals["offloads"] += batch.get("offloads", 0)
totals["app_logs"] += batch.get("app_logs", 0)
totals["trigger_logs"] += batch.get("trigger_logs", 0)
totals["pauses"] += batch.get("pauses", 0)
totals["pause_reasons"] += batch.get("pause_reasons", 0)
def _count_node_executions(self, session: Session, runs: Sequence[WorkflowRun]) -> tuple[int, int]:
run_ids = [run.id for run in runs]
repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository(

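`_accumulate_related_counts` spells out each key instead of looping because indexing a TypedDict with a variable key fails mypy's `literal-required` check. A compact sketch of the trade-off, with a two-field stand-in for `RelatedCountsDict`:

```python
# Sketch of why the per-key accumulator replaces the `for key in totals`
# loop: a variable key trips literal-required, literal keys check cleanly.
from typing import TypedDict


class RelatedCountsDict(TypedDict):
    node_executions: int
    offloads: int


def accumulate(totals: RelatedCountsDict, batch: dict[str, int]) -> None:
    # for key in totals: totals[key] += batch.get(key, 0)
    #   -> mypy: TypedDict key must be a string literal [literal-required]
    totals["node_executions"] += batch.get("node_executions", 0)
    totals["offloads"] += batch.get("offloads", 0)


totals: RelatedCountsDict = {"node_executions": 0, "offloads": 0}
accumulate(totals, {"node_executions": 3, "offloads": 1})
print(totals)  # {'node_executions': 3, 'offloads': 1}
```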
View File

@@ -14,7 +14,7 @@ from sqlalchemy.orm import Session, sessionmaker
from extensions.ext_database import db
from models.workflow import WorkflowRun
from repositories.api_workflow_run_repository import APIWorkflowRunRepository
from repositories.api_workflow_run_repository import APIWorkflowRunRepository, RunsWithRelatedCountsDict
from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository
@@ -23,7 +23,17 @@ class DeleteResult:
run_id: str
tenant_id: str
success: bool
deleted_counts: dict[str, int] = field(default_factory=dict)
deleted_counts: RunsWithRelatedCountsDict = field(
default_factory=lambda: { # type: ignore[assignment]
"runs": 0,
"node_executions": 0,
"offloads": 0,
"app_logs": 0,
"trigger_logs": 0,
"pauses": 0,
"pause_reasons": 0,
}
)
error: str | None = None
elapsed_time: float = 0.0

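The `# type: ignore[assignment]` above is needed because a dict literal inside a lambda is inferred as `dict[str, int]`, not the TypedDict. A sketch of the pattern plus one ignore-free alternative (a typed helper function, which is an option rather than what the codebase chose), using stand-in fields:

```python
# Sketch of a TypedDict default for a dataclass field. A small typed
# helper avoids the ignore that the lambda-based default requires.
from dataclasses import dataclass, field
from typing import TypedDict


class RunsWithRelatedCountsDict(TypedDict):
    runs: int
    node_executions: int


def _empty_counts() -> RunsWithRelatedCountsDict:
    return {"runs": 0, "node_executions": 0}


@dataclass
class DeleteResult:
    run_id: str
    success: bool
    deleted_counts: RunsWithRelatedCountsDict = field(default_factory=_empty_counts)


print(DeleteResult(run_id="r1", success=True).deleted_counts)
```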
View File

@@ -4,7 +4,7 @@ import logging
import time
import uuid
from datetime import UTC, datetime
from typing import Any
from typing import TypedDict, cast
from graphon.model_runtime.entities.llm_entities import LLMUsage
from graphon.model_runtime.entities.model_entities import ModelType
@@ -25,6 +25,22 @@ from models.enums import SummaryStatus
logger = logging.getLogger(__name__)
class SummaryEntryDict(TypedDict):
segment_id: str
segment_position: int
status: str
summary_preview: str | None
error: str | None
created_at: int | None
updated_at: int | None
class DocumentSummaryStatusDetailDict(TypedDict):
total_segments: int
summary_status: dict[str, int]
summaries: list[SummaryEntryDict]
class SummaryIndexService:
"""Service for generating and managing summary indexes."""
@@ -1352,7 +1368,7 @@ class SummaryIndexService:
def get_document_summary_status_detail(
document_id: str,
dataset_id: str,
) -> dict[str, Any]:
) -> DocumentSummaryStatusDetailDict:
"""
Get detailed summary status for a document.
@@ -1403,7 +1419,7 @@ class SummaryIndexService:
SummaryStatus.NOT_STARTED: 0,
}
summary_list = []
summary_list: list[SummaryEntryDict] = []
for segment in segments:
summary = summary_map.get(segment.id)
if summary:
@@ -1438,8 +1454,8 @@ class SummaryIndexService:
}
)
return {
"total_segments": total_segments,
"summary_status": status_counts,
"summaries": summary_list,
}
return DocumentSummaryStatusDetailDict(
total_segments=total_segments,
summary_status=cast(dict[str, int], status_counts),
summaries=summary_list,
)

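`DocumentSummaryStatusDetailDict(...)` is built here with a `cast` on `status_counts`, whose keys are `SummaryStatus` members rather than plain strings. A sketch of why that cast is plausible only for a string-valued enum (the real `SummaryStatus` is assumed, not shown, to be one):

```python
# Sketch of the cast at the return site. cast() does nothing at runtime;
# it only silences the checker. Treating SummaryStatus-keyed counts as
# dict[str, int] is sound only if SummaryStatus is a str-based enum,
# which is assumed here.
from enum import StrEnum
from typing import cast


class SummaryStatus(StrEnum):
    COMPLETED = "completed"
    ERROR = "error"


status_counts: dict[SummaryStatus, int] = {
    SummaryStatus.COMPLETED: 2,
    SummaryStatus.ERROR: 1,
}
as_str_keys = cast(dict[str, int], status_counts)
# Lookup by plain string works because StrEnum members hash and compare
# as their string values.
print(as_str_keys["completed"])  # 2
```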
View File

@@ -2,6 +2,7 @@ import uuid
import sqlalchemy as sa
from flask_login import current_user
from pydantic import BaseModel, Field
from sqlalchemy import func, select
from werkzeug.exceptions import NotFound
@@ -11,6 +12,28 @@ from models.enums import TagType
from models.model import App, Tag, TagBinding
class SaveTagPayload(BaseModel):
name: str = Field(min_length=1, max_length=50)
type: TagType
class UpdateTagPayload(BaseModel):
name: str = Field(min_length=1, max_length=50)
type: TagType
class TagBindingCreatePayload(BaseModel):
tag_ids: list[str]
target_id: str
type: TagType
class TagBindingDeletePayload(BaseModel):
tag_id: str
target_id: str
type: TagType
class TagService:
@staticmethod
def get_tags(tag_type: str, current_tenant_id: str, keyword: str | None = None):
@@ -78,12 +101,12 @@ class TagService:
return tags or []
@staticmethod
def save_tags(args: dict) -> Tag:
if TagService.get_tag_by_tag_name(args["type"], current_user.current_tenant_id, args["name"]):
def save_tags(payload: SaveTagPayload) -> Tag:
if TagService.get_tag_by_tag_name(payload.type, current_user.current_tenant_id, payload.name):
raise ValueError("Tag name already exists")
tag = Tag(
name=args["name"],
type=TagType(args["type"]),
name=payload.name,
type=TagType(payload.type),
created_by=current_user.id,
tenant_id=current_user.current_tenant_id,
)
@@ -93,13 +116,24 @@ class TagService:
return tag
@staticmethod
def update_tags(args: dict, tag_id: str) -> Tag:
if TagService.get_tag_by_tag_name(args.get("type", ""), current_user.current_tenant_id, args.get("name", "")):
raise ValueError("Tag name already exists")
def update_tags(payload: UpdateTagPayload, tag_id: str) -> Tag:
tag = db.session.scalar(select(Tag).where(Tag.id == tag_id).limit(1))
if not tag:
raise NotFound("Tag not found")
tag.name = args["name"]
if payload.name != tag.name:
existing = db.session.scalar(
select(Tag)
.where(
Tag.name == payload.name,
Tag.tenant_id == current_user.current_tenant_id,
Tag.type == tag.type,
Tag.id != tag_id,
)
.limit(1)
)
if existing:
raise ValueError("Tag name already exists")
tag.name = payload.name
db.session.commit()
return tag
@@ -122,21 +156,19 @@ class TagService:
db.session.commit()
@staticmethod
def save_tag_binding(args):
# check if target exists
TagService.check_target_exists(args["type"], args["target_id"])
# save tag binding
for tag_id in args["tag_ids"]:
def save_tag_binding(payload: TagBindingCreatePayload):
TagService.check_target_exists(payload.type, payload.target_id)
for tag_id in payload.tag_ids:
tag_binding = db.session.scalar(
select(TagBinding)
.where(TagBinding.tag_id == tag_id, TagBinding.target_id == args["target_id"])
.where(TagBinding.tag_id == tag_id, TagBinding.target_id == payload.target_id)
.limit(1)
)
if tag_binding:
continue
new_tag_binding = TagBinding(
tag_id=tag_id,
target_id=args["target_id"],
target_id=payload.target_id,
tenant_id=current_user.current_tenant_id,
created_by=current_user.id,
)
@@ -144,17 +176,15 @@ class TagService:
db.session.commit()
@staticmethod
def delete_tag_binding(args):
# check if target exists
TagService.check_target_exists(args["type"], args["target_id"])
# delete tag binding
tag_bindings = db.session.scalar(
def delete_tag_binding(payload: TagBindingDeletePayload):
TagService.check_target_exists(payload.type, payload.target_id)
tag_binding = db.session.scalar(
select(TagBinding)
.where(TagBinding.target_id == args["target_id"], TagBinding.tag_id == args["tag_id"])
.where(TagBinding.target_id == payload.target_id, TagBinding.tag_id == payload.tag_id)
.limit(1)
)
if tag_bindings:
db.session.delete(tag_bindings)
if tag_binding:
db.session.delete(tag_binding)
db.session.commit()
@staticmethod

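The `TagService` hunks replace untyped `args` dicts with Pydantic payload models, so field constraints are enforced at the boundary before any service logic runs. A runnable sketch with a stubbed `TagType` (the real enum lives in `models.enums`):

```python
# Sketch of the payload-model pattern replacing raw dict args.
from enum import StrEnum

from pydantic import BaseModel, Field, ValidationError


class TagType(StrEnum):
    APP = "app"
    KNOWLEDGE = "knowledge"


class SaveTagPayload(BaseModel):
    name: str = Field(min_length=1, max_length=50)
    type: TagType


# Validation happens at the boundary; the service sees typed attributes.
payload = SaveTagPayload.model_validate({"name": "billing", "type": "app"})
print(payload.type is TagType.APP)  # True

try:
    SaveTagPayload.model_validate({"name": "", "type": "app"})
except ValidationError as exc:
    print(exc.error_count(), "validation error")  # 1 validation error
```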
View File

@@ -285,7 +285,7 @@ class MCPToolManageService:
# Batch query all users to avoid N+1 problem
user_ids = {provider.user_id for provider in mcp_providers}
users = self._session.query(Account).where(Account.id.in_(user_ids)).all()
users = self._session.scalars(select(Account).where(Account.id.in_(user_ids))).all()
user_name_map = {user.id: user.name for user in users}
return [

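The one-line change above keeps the batch-lookup shape: one `IN (...)` query plus an in-memory map instead of a query per provider. A stand-alone sketch with a stand-in `Account` model:

```python
# Sketch of the batch lookup that avoids N+1 queries.
from sqlalchemy import String, create_engine, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class Account(Base):
    __tablename__ = "accounts"
    id: Mapped[str] = mapped_column(String(36), primary_key=True)
    name: Mapped[str] = mapped_column(String(255))


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add_all([Account(id="u1", name="Ada"), Account(id="u2", name="Lin")])
    session.commit()

    user_ids = {"u1", "u2", "u1"}  # duplicates collapse in the set
    users = session.scalars(select(Account).where(Account.id.in_(user_ids))).all()
    user_name_map = {user.id: user.name for user in users}
    print(user_name_map["u2"])  # Lin
```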
View File

@@ -1,4 +1,3 @@
import json
import logging
from collections.abc import Mapping
from typing import Any, Union
@@ -21,6 +20,7 @@ from core.tools.entities.tool_entities import (
ApiProviderAuthType,
ToolParameter,
ToolProviderType,
emoji_icon_adapter,
)
from core.tools.plugin_tool.provider import PluginToolProviderController
from core.tools.utils.encryption import create_provider_encrypter, create_tool_provider_encrypter
@@ -53,11 +53,14 @@ class ToolTransformService:
elif provider_type in {ToolProviderType.API, ToolProviderType.WORKFLOW}:
try:
if isinstance(icon, str):
return json.loads(icon)
return icon
except (json.JSONDecodeError, ValueError):
parsed = emoji_icon_adapter.validate_json(icon)
return {"background": parsed["background"], "content": parsed["content"]}
return {"background": icon["background"], "content": icon["content"]}
except (ValueError, ValidationError, KeyError):
return {"background": "#252525", "content": "\ud83d\ude01"}
elif provider_type == ToolProviderType.MCP:
if isinstance(icon, Mapping):
return {"background": icon.get("background", ""), "content": icon.get("content", "")}
return icon
return ""

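`emoji_icon_adapter` now does the JSON parsing that `json.loads` used to, with schema validation included. Its real definition lives in `core.tools.entities.tool_entities`; a plausible sketch built on Pydantic's `TypeAdapter` (an assumption about how it is implemented, not a copy of it):

```python
# Hypothetical shape of emoji_icon_adapter, assumed to wrap a TypedDict
# in a pydantic TypeAdapter; the project's real definition may differ.
from pydantic import TypeAdapter, ValidationError
from typing_extensions import TypedDict  # pydantic-compatible on all versions


class EmojiIcon(TypedDict):
    background: str
    content: str


emoji_icon_adapter: TypeAdapter[EmojiIcon] = TypeAdapter(EmojiIcon)

parsed = emoji_icon_adapter.validate_json('{"background": "#252525", "content": "😁"}')
print(parsed["content"])

try:
    emoji_icon_adapter.validate_json('"not an icon object"')
except ValidationError:
    # The service falls back to a default icon on any parse failure.
    print("fell back to default icon")
```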
View File

@@ -3,12 +3,12 @@ import logging
from datetime import datetime
from graphon.model_runtime.utils.encoders import jsonable_encoder
from sqlalchemy import or_, select
from sqlalchemy import delete, or_, select
from sqlalchemy.orm import Session
from core.tools.__base.tool_provider import ToolProviderController
from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity
from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration
from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration, emoji_icon_adapter
from core.tools.tool_label_manager import ToolLabelManager
from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurationUtils
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
@@ -42,20 +42,22 @@ class WorkflowToolManageService:
labels: list[str] | None = None,
):
# check if the name is unique
existing_workflow_tool_provider = (
db.session.query(WorkflowToolProvider)
existing_workflow_tool_provider = db.session.scalar(
select(WorkflowToolProvider)
.where(
WorkflowToolProvider.tenant_id == tenant_id,
# name or app_id
or_(WorkflowToolProvider.name == name, WorkflowToolProvider.app_id == workflow_app_id),
)
.first()
.limit(1)
)
if existing_workflow_tool_provider is not None:
raise ValueError(f"Tool with name {name} or app_id {workflow_app_id} already exists")
app: App | None = db.session.query(App).where(App.id == workflow_app_id, App.tenant_id == tenant_id).first()
app: App | None = db.session.scalar(
select(App).where(App.id == workflow_app_id, App.tenant_id == tenant_id).limit(1)
)
if app is None:
raise ValueError(f"App {workflow_app_id} not found")
@@ -122,30 +124,30 @@ class WorkflowToolManageService:
:return: the updated tool
"""
# check if the name is unique
existing_workflow_tool_provider = (
db.session.query(WorkflowToolProvider)
existing_workflow_tool_provider = db.session.scalar(
select(WorkflowToolProvider)
.where(
WorkflowToolProvider.tenant_id == tenant_id,
WorkflowToolProvider.name == name,
WorkflowToolProvider.id != workflow_tool_id,
)
.first()
.limit(1)
)
if existing_workflow_tool_provider is not None:
raise ValueError(f"Tool with name {name} already exists")
workflow_tool_provider: WorkflowToolProvider | None = (
db.session.query(WorkflowToolProvider)
workflow_tool_provider: WorkflowToolProvider | None = db.session.scalar(
select(WorkflowToolProvider)
.where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id)
.first()
.limit(1)
)
if workflow_tool_provider is None:
raise ValueError(f"Tool {workflow_tool_id} not found")
app: App | None = (
db.session.query(App).where(App.id == workflow_tool_provider.app_id, App.tenant_id == tenant_id).first()
app: App | None = db.session.scalar(
select(App).where(App.id == workflow_tool_provider.app_id, App.tenant_id == tenant_id).limit(1)
)
if app is None:
@@ -234,9 +236,11 @@ class WorkflowToolManageService:
:param tenant_id: the tenant id
:param workflow_tool_id: the workflow tool id
"""
db.session.query(WorkflowToolProvider).where(
WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id
).delete()
db.session.execute(
delete(WorkflowToolProvider).where(
WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id
)
)
db.session.commit()
@@ -251,10 +255,10 @@ class WorkflowToolManageService:
:param workflow_tool_id: the workflow tool id
:return: the tool
"""
db_tool: WorkflowToolProvider | None = (
db.session.query(WorkflowToolProvider)
db_tool: WorkflowToolProvider | None = db.session.scalar(
select(WorkflowToolProvider)
.where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id)
.first()
.limit(1)
)
return cls._get_workflow_tool(tenant_id, db_tool)
@@ -267,10 +271,10 @@ class WorkflowToolManageService:
:param workflow_app_id: the workflow app id
:return: the tool
"""
db_tool: WorkflowToolProvider | None = (
db.session.query(WorkflowToolProvider)
db_tool: WorkflowToolProvider | None = db.session.scalar(
select(WorkflowToolProvider)
.where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.app_id == workflow_app_id)
.first()
.limit(1)
)
return cls._get_workflow_tool(tenant_id, db_tool)
@@ -284,8 +288,8 @@ class WorkflowToolManageService:
if db_tool is None:
raise ValueError("Tool not found")
workflow_app: App | None = (
db.session.query(App).where(App.id == db_tool.app_id, App.tenant_id == db_tool.tenant_id).first()
workflow_app: App | None = db.session.scalar(
select(App).where(App.id == db_tool.app_id, App.tenant_id == db_tool.tenant_id).limit(1)
)
if workflow_app is None:
@@ -309,7 +313,7 @@ class WorkflowToolManageService:
"label": db_tool.label,
"workflow_tool_id": db_tool.id,
"workflow_app_id": db_tool.app_id,
"icon": json.loads(db_tool.icon),
"icon": emoji_icon_adapter.validate_json(db_tool.icon),
"description": db_tool.description,
"parameters": jsonable_encoder(db_tool.parameter_configurations),
"output_schema": output_schema,
@@ -331,10 +335,10 @@ class WorkflowToolManageService:
:param workflow_tool_id: the workflow tool id
:return: the list of tools
"""
db_tool: WorkflowToolProvider | None = (
db.session.query(WorkflowToolProvider)
db_tool: WorkflowToolProvider | None = db.session.scalar(
select(WorkflowToolProvider)
.where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id)
.first()
.limit(1)
)
if db_tool is None:

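The delete hunk in this file moves from `Query.delete()` to an explicit Core `delete()` statement passed to `session.execute()`. A self-contained sketch with a stand-in model:

```python
# Sketch of the Core delete() that replaces Query.delete(); the model is
# a stand-in. execute() returns a result whose rowcount reports matches.
from sqlalchemy import String, create_engine, delete
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class WorkflowToolProvider(Base):
    __tablename__ = "tool_providers"
    id: Mapped[str] = mapped_column(String(36), primary_key=True)
    tenant_id: Mapped[str] = mapped_column(String(36))


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add(WorkflowToolProvider(id="tool-1", tenant_id="t1"))
    session.commit()

    result = session.execute(
        delete(WorkflowToolProvider).where(
            WorkflowToolProvider.tenant_id == "t1",
            WorkflowToolProvider.id == "tool-1",
        )
    )
    session.commit()
    print(result.rowcount)  # 1
```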
View File

@@ -3,7 +3,7 @@ import logging
import mimetypes
import secrets
from collections.abc import Callable, Mapping, Sequence
from typing import Any
from typing import Any, NotRequired, TypedDict
import orjson
from flask import request
@@ -51,6 +51,26 @@ logger = logging.getLogger(__name__)
_file_access_controller = DatabaseFileAccessController()
class RawWebhookDataDict(TypedDict):
method: str
headers: dict[str, str]
query_params: dict[str, str]
body: dict[str, Any]
files: dict[str, Any]
class ValidationResultDict(TypedDict):
valid: bool
error: NotRequired[str]
class WorkflowInputsDict(TypedDict):
webhook_data: RawWebhookDataDict
webhook_headers: dict[str, str]
webhook_query_params: dict[str, str]
webhook_body: dict[str, Any]
class WebhookService:
"""Service for handling webhook operations."""
@@ -146,7 +166,7 @@ class WebhookService:
@classmethod
def extract_and_validate_webhook_data(
cls, webhook_trigger: WorkflowWebhookTrigger, node_config: NodeConfigDict
) -> dict[str, Any]:
) -> RawWebhookDataDict:
"""Extract and validate webhook data in a single unified process.
Args:
@@ -166,7 +186,7 @@ class WebhookService:
node_data = WebhookData.model_validate(node_config["data"], from_attributes=True)
validation_result = cls._validate_http_metadata(raw_data, node_data)
if not validation_result["valid"]:
raise ValueError(validation_result["error"])
raise ValueError(validation_result.get("error", "Validation failed"))
# Process and validate data according to configuration
processed_data = cls._process_and_validate_data(raw_data, node_data)
@@ -174,7 +194,7 @@ class WebhookService:
return processed_data
@classmethod
def extract_webhook_data(cls, webhook_trigger: WorkflowWebhookTrigger) -> dict[str, Any]:
def extract_webhook_data(cls, webhook_trigger: WorkflowWebhookTrigger) -> RawWebhookDataDict:
"""Extract raw data from incoming webhook request without type conversion.
Args:
@@ -190,7 +210,7 @@ class WebhookService:
"""
cls._validate_content_length()
data = {
data: RawWebhookDataDict = {
"method": request.method,
"headers": dict(request.headers),
"query_params": dict(request.args),
@@ -224,7 +244,7 @@ class WebhookService:
return data
@classmethod
def _process_and_validate_data(cls, raw_data: dict[str, Any], node_data: WebhookData) -> dict[str, Any]:
def _process_and_validate_data(cls, raw_data: RawWebhookDataDict, node_data: WebhookData) -> RawWebhookDataDict:
"""Process and validate webhook data according to node configuration.
Args:
@@ -665,7 +685,7 @@ class WebhookService:
raise ValueError(f"Required header missing: {header_name}")
@classmethod
def _validate_http_metadata(cls, webhook_data: dict[str, Any], node_data: WebhookData) -> dict[str, Any]:
def _validate_http_metadata(cls, webhook_data: RawWebhookDataDict, node_data: WebhookData) -> ValidationResultDict:
"""Validate HTTP method and content-type.
Args:
@@ -709,7 +729,7 @@ class WebhookService:
return content_type.split(";")[0].strip()
@classmethod
def _validation_error(cls, error_message: str) -> dict[str, Any]:
def _validation_error(cls, error_message: str) -> ValidationResultDict:
"""Create a standard validation error response.
Args:
@@ -730,7 +750,7 @@ class WebhookService:
return False
@classmethod
def build_workflow_inputs(cls, webhook_data: dict[str, Any]) -> dict[str, Any]:
def build_workflow_inputs(cls, webhook_data: RawWebhookDataDict) -> WorkflowInputsDict:
"""Construct workflow inputs payload from webhook data.
Args:
@@ -748,7 +768,7 @@ class WebhookService:
@classmethod
def trigger_workflow_execution(
cls, webhook_trigger: WorkflowWebhookTrigger, webhook_data: dict[str, Any], workflow: Workflow
cls, webhook_trigger: WorkflowWebhookTrigger, webhook_data: RawWebhookDataDict, workflow: Workflow
) -> None:
"""Trigger workflow execution via AsyncWorkflowService.

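`ValidationResultDict` marks `error` as `NotRequired`, which is why the caller switched from indexing to `.get(...)` with a fallback message. A short sketch of the contract:

```python
# Sketch of the NotRequired contract: `error` may be absent on success,
# so callers must use .get() rather than direct indexing.
from typing import NotRequired, TypedDict


class ValidationResultDict(TypedDict):
    valid: bool
    error: NotRequired[str]


def validate(method: str) -> ValidationResultDict:
    if method not in {"POST", "PUT"}:
        return {"valid": False, "error": f"Method not allowed: {method}"}
    return {"valid": True}  # no "error" key on the success path


result = validate("GET")
if not result["valid"]:
    print(result.get("error", "Validation failed"))
```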
View File

@@ -6,6 +6,7 @@ from sqlalchemy import delete, select
from core.model_manager import ModelInstance, ModelManager
from core.rag.datasource.keyword.keyword_factory import Keyword
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.entities import ParentMode
from core.rag.index_processor.constant.doc_type import DocType
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
@@ -15,7 +16,6 @@ from extensions.ext_database import db
from models import UploadFile
from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegment, SegmentAttachmentBinding
from models.dataset import Document as DatasetDocument
from services.entities.knowledge_entities.knowledge_entities import ParentMode
logger = logging.getLogger(__name__)

View File

@@ -3,7 +3,7 @@ from __future__ import annotations
import datetime
import json
from dataclasses import dataclass
from typing import Any
from typing import Any, NotRequired, TypedDict, cast
import httpx
from flask_login import current_user
@@ -126,6 +126,15 @@ class WebsiteCrawlStatusApiRequest:
return cls(provider=provider, job_id=job_id)
class CrawlStatusDict(TypedDict):
status: str
job_id: str
total: int
current: int
data: list[Any]
time_consuming: NotRequired[str | float]
class WebsiteService:
"""Service class for website crawling operations using different providers."""
@@ -261,13 +270,13 @@ class WebsiteService:
return {"status": "active", "job_id": response.json().get("data", {}).get("taskId")}
@classmethod
def get_crawl_status(cls, job_id: str, provider: str) -> dict[str, Any]:
def get_crawl_status(cls, job_id: str, provider: str) -> CrawlStatusDict:
"""Get crawl status using string parameters."""
api_request = WebsiteCrawlStatusApiRequest(provider=provider, job_id=job_id)
return cls.get_crawl_status_typed(api_request)
@classmethod
def get_crawl_status_typed(cls, api_request: WebsiteCrawlStatusApiRequest) -> dict[str, Any]:
def get_crawl_status_typed(cls, api_request: WebsiteCrawlStatusApiRequest) -> CrawlStatusDict:
"""Get crawl status using typed request."""
api_key, config = cls._get_credentials_and_config(current_user.current_tenant_id, api_request.provider)
@@ -281,10 +290,10 @@ class WebsiteService:
raise ValueError("Invalid provider")
@classmethod
def _get_firecrawl_status(cls, job_id: str, api_key: str, config: dict) -> dict[str, Any]:
def _get_firecrawl_status(cls, job_id: str, api_key: str, config: dict) -> CrawlStatusDict:
firecrawl_app = FirecrawlApp(api_key=api_key, base_url=config.get("base_url"))
result: CrawlStatusResponse = firecrawl_app.check_crawl_status(job_id)
crawl_status_data: dict[str, Any] = {
crawl_status_data: CrawlStatusDict = {
"status": result["status"],
"job_id": job_id,
"total": result["total"] or 0,
@@ -302,18 +311,18 @@ class WebsiteService:
return crawl_status_data
@classmethod
def _get_watercrawl_status(cls, job_id: str, api_key: str, config: dict[str, Any]) -> dict[str, Any]:
return dict(WaterCrawlProvider(api_key, config.get("base_url")).get_crawl_status(job_id))
def _get_watercrawl_status(cls, job_id: str, api_key: str, config: dict[str, Any]) -> CrawlStatusDict:
return cast(CrawlStatusDict, dict(WaterCrawlProvider(api_key, config.get("base_url")).get_crawl_status(job_id)))
@classmethod
def _get_jinareader_status(cls, job_id: str, api_key: str) -> dict[str, Any]:
def _get_jinareader_status(cls, job_id: str, api_key: str) -> CrawlStatusDict:
response = _adaptive_http_client.post(
"https://adaptivecrawlstatus-kir3wx7b3a-uc.a.run.app",
headers={"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"},
json={"taskId": job_id},
)
data = response.json().get("data", {})
crawl_status_data = {
crawl_status_data: CrawlStatusDict = {
"status": data.get("status", "active"),
"job_id": job_id,
"total": len(data.get("urls", [])),

View File

@@ -170,34 +170,38 @@ class WorkflowConverter:
graph = self._append_node(graph, llm_node)
if new_app_mode == AppMode.WORKFLOW:
# convert to end node by app mode
end_node = self._convert_to_end_node()
graph = self._append_node(graph, end_node)
else:
answer_node = self._convert_to_answer_node()
graph = self._append_node(graph, answer_node)
app_model_config_dict = app_config.app_model_config_dict
# features
if new_app_mode == AppMode.ADVANCED_CHAT:
features = {
"opening_statement": app_model_config_dict.get("opening_statement"),
"suggested_questions": app_model_config_dict.get("suggested_questions"),
"suggested_questions_after_answer": app_model_config_dict.get("suggested_questions_after_answer"),
"speech_to_text": app_model_config_dict.get("speech_to_text"),
"text_to_speech": app_model_config_dict.get("text_to_speech"),
"file_upload": app_model_config_dict.get("file_upload"),
"sensitive_word_avoidance": app_model_config_dict.get("sensitive_word_avoidance"),
"retriever_resource": app_model_config_dict.get("retriever_resource"),
}
else:
features = {
"text_to_speech": app_model_config_dict.get("text_to_speech"),
"file_upload": app_model_config_dict.get("file_upload"),
"sensitive_word_avoidance": app_model_config_dict.get("sensitive_word_avoidance"),
}
match new_app_mode:
case AppMode.WORKFLOW:
end_node = self._convert_to_end_node()
graph = self._append_node(graph, end_node)
features = {
"text_to_speech": app_model_config_dict.get("text_to_speech"),
"file_upload": app_model_config_dict.get("file_upload"),
"sensitive_word_avoidance": app_model_config_dict.get("sensitive_word_avoidance"),
}
case AppMode.ADVANCED_CHAT:
answer_node = self._convert_to_answer_node()
graph = self._append_node(graph, answer_node)
features = {
"opening_statement": app_model_config_dict.get("opening_statement"),
"suggested_questions": app_model_config_dict.get("suggested_questions"),
"suggested_questions_after_answer": app_model_config_dict.get("suggested_questions_after_answer"),
"speech_to_text": app_model_config_dict.get("speech_to_text"),
"text_to_speech": app_model_config_dict.get("text_to_speech"),
"file_upload": app_model_config_dict.get("file_upload"),
"sensitive_word_avoidance": app_model_config_dict.get("sensitive_word_avoidance"),
"retriever_resource": app_model_config_dict.get("retriever_resource"),
}
case _:
answer_node = self._convert_to_answer_node()
graph = self._append_node(graph, answer_node)
features = {
"text_to_speech": app_model_config_dict.get("text_to_speech"),
"file_upload": app_model_config_dict.get("file_upload"),
"sensitive_word_avoidance": app_model_config_dict.get("sensitive_word_avoidance"),
}
# create workflow record
workflow = Workflow(
@@ -220,19 +224,23 @@ class WorkflowConverter:
def _convert_to_app_config(self, app_model: App, app_model_config: AppModelConfig) -> EasyUIBasedAppConfig:
app_mode_enum = AppMode.value_of(app_model.mode)
app_config: EasyUIBasedAppConfig
if app_mode_enum == AppMode.AGENT_CHAT or app_model.is_agent:
app_model.mode = AppMode.AGENT_CHAT
app_config = AgentChatAppConfigManager.get_app_config(
app_model=app_model, app_model_config=app_model_config
)
elif app_mode_enum == AppMode.CHAT:
app_config = ChatAppConfigManager.get_app_config(app_model=app_model, app_model_config=app_model_config)
elif app_mode_enum == AppMode.COMPLETION:
app_config = CompletionAppConfigManager.get_app_config(
app_model=app_model, app_model_config=app_model_config
)
else:
raise ValueError("Invalid app mode")
effective_mode = (
AppMode.AGENT_CHAT if app_model.is_agent and app_mode_enum != AppMode.AGENT_CHAT else app_mode_enum
)
match effective_mode:
case AppMode.AGENT_CHAT:
app_model.mode = AppMode.AGENT_CHAT
app_config = AgentChatAppConfigManager.get_app_config(
app_model=app_model, app_model_config=app_model_config
)
case AppMode.CHAT:
app_config = ChatAppConfigManager.get_app_config(app_model=app_model, app_model_config=app_model_config)
case AppMode.COMPLETION:
app_config = CompletionAppConfigManager.get_app_config(
app_model=app_model, app_model_config=app_model_config
)
case _:
raise ValueError("Invalid app mode")
return app_config

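Both refactors in this file rely on `match` with enum members as value patterns (dotted names match by equality) and `case _` standing in for the old trailing `else`. A condensed sketch with a stubbed `AppMode`:

```python
# Sketch of the match-on-enum refactor; AppMode is a stand-in here.
from enum import StrEnum


class AppMode(StrEnum):
    WORKFLOW = "workflow"
    ADVANCED_CHAT = "advanced-chat"
    CHAT = "chat"


def terminal_node_type(mode: AppMode) -> str:
    match mode:
        case AppMode.WORKFLOW:
            return "end"
        case AppMode.ADVANCED_CHAT | AppMode.CHAT:  # or-pattern of values
            return "answer"
        case _:
            raise ValueError(f"Invalid app mode: {mode}")


print(terminal_node_type(AppMode.WORKFLOW))       # end
print(terminal_node_type(AppMode.ADVANCED_CHAT))  # answer
```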
View File

@@ -38,6 +38,7 @@ from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom, build_dify_run_context
from core.app.file_access import DatabaseFileAccessController
from core.entities import PluginCredentialType
from core.plugin.impl.model_runtime_factory import create_plugin_model_assembly, create_plugin_provider_manager
from core.repositories import DifyCoreRepositoryFactory
from core.repositories.human_input_repository import FormCreateParams, HumanInputFormRepositoryImpl
@@ -66,7 +67,6 @@ from models.tools import WorkflowToolProvider
from models.workflow import Workflow, WorkflowNodeExecutionModel, WorkflowNodeExecutionTriggeredFrom, WorkflowType
from repositories.factory import DifyAPIRepositoryFactory
from services.billing_service import BillingService
from services.enterprise.plugin_manager_service import PluginCredentialType
from services.errors.app import (
IsDraftWorkflowError,
TriggerNodeLimitExceededError,
@@ -635,7 +635,7 @@ class WorkflowService:
# If we can't determine the status, assume load balancing is not enabled
return False
def _get_load_balancing_configs(self, tenant_id: str, provider: str, model_name: str) -> list[dict]:
def _get_load_balancing_configs(self, tenant_id: str, provider: str, model_name: str) -> list[dict[str, Any]]:
"""
Get all load balancing configurations for a model.
@@ -659,7 +659,7 @@ class WorkflowService:
_, custom_configs = model_load_balancing_service.get_load_balancing_configs(
tenant_id=tenant_id, provider=provider, model=model_name, model_type="llm", config_from="custom-model"
)
all_configs = configs + custom_configs
all_configs = cast(list[dict[str, Any]], configs) + cast(list[dict[str, Any]], custom_configs)
return [config for config in all_configs if config.get("credential_id")]
@@ -834,7 +834,7 @@ class WorkflowService:
if workflow_node_execution is None:
raise ValueError(f"WorkflowNodeExecution with id {node_execution.id} not found after saving")
with Session(db.engine) as session:
with sessionmaker(db.engine).begin() as session:
outputs = workflow_node_execution.load_full_outputs(session, storage)
with Session(bind=db.engine) as session, session.begin():
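The hunk above trades a bare `Session(db.engine)` for `sessionmaker(db.engine).begin()`, whose context manager commits on clean exit and rolls back on exception. A minimal sketch of the difference:

```python
# Sketch of sessionmaker(...).begin(): the context manager opens a
# session, begins a transaction, commits on clean exit, and rolls back
# on exception. A bare Session(engine) leaves that to the caller.
from sqlalchemy import create_engine, text
from sqlalchemy.orm import sessionmaker

engine = create_engine("sqlite://")

with sessionmaker(engine).begin() as session:
    session.execute(text("CREATE TABLE notes (body TEXT)"))
    session.execute(text("INSERT INTO notes VALUES ('archived')"))
    # exiting the block commits; raising inside it would roll back

with sessionmaker(engine).begin() as session:
    print(session.execute(text("SELECT body FROM notes")).scalar_one())  # archived
```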
@@ -1417,16 +1417,17 @@ class WorkflowService:
self._validate_human_input_node_data(node_data)
def validate_features_structure(self, app_model: App, features: dict):
if app_model.mode == AppMode.ADVANCED_CHAT:
return AdvancedChatAppConfigManager.config_validate(
tenant_id=app_model.tenant_id, config=features, only_structure_validate=True
)
elif app_model.mode == AppMode.WORKFLOW:
return WorkflowAppConfigManager.config_validate(
tenant_id=app_model.tenant_id, config=features, only_structure_validate=True
)
else:
raise ValueError(f"Invalid app mode: {app_model.mode}")
match app_model.mode:
case AppMode.ADVANCED_CHAT:
return AdvancedChatAppConfigManager.config_validate(
tenant_id=app_model.tenant_id, config=features, only_structure_validate=True
)
case AppMode.WORKFLOW:
return WorkflowAppConfigManager.config_validate(
tenant_id=app_model.tenant_id, config=features, only_structure_validate=True
)
case _:
raise ValueError(f"Invalid app mode: {app_model.mode}")
def _validate_human_input_node_data(self, node_data: dict) -> None:
"""