Merge branch 'main' into jzh

JzoNg
2026-04-15 10:20:27 +08:00
691 changed files with 7963 additions and 6893 deletions

View File

@@ -62,7 +62,7 @@ jobs:
- name: Render coverage markdown from structured data
id: render
run: |
comment_body="$(uv run --directory api python api/libs/pyrefly_type_coverage.py \
comment_body="$(uv run --directory api python libs/pyrefly_type_coverage.py \
--base base_report.json \
< pr_report.json)"

View File

@@ -57,6 +57,9 @@ REDIS_SSL_CERTFILE=
REDIS_SSL_KEYFILE=
# Path to client private key file for SSL authentication
REDIS_DB=0
# Optional global prefix for Redis keys, topics, streams, and Celery Redis transport artifacts.
# Leave empty to preserve current unprefixed behavior.
REDIS_KEY_PREFIX=
# redis Sentinel configuration.
REDIS_USE_SENTINEL=false
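A hedged example of how the new setting might be used in practice (the value below is made up; how the prefix is joined to each key is decided by the application code, not by this file):

# e.g. isolate one deployment's keys inside a shared Redis instance
REDIS_KEY_PREFIX=dify_prod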

View File

@@ -160,6 +160,16 @@ class DatabaseConfig(BaseSettings):
default="",
)
DB_SESSION_TIMEZONE_OVERRIDE: str = Field(
description=(
"PostgreSQL session timezone override injected via startup options."
" Default is 'UTC' for out-of-the-box consistency."
" Set to empty string to disable app-level timezone injection, for example when using RDS Proxy"
" together with a database-side default timezone."
),
default="UTC",
)
@computed_field # type: ignore[prop-decorator]
@property
def SQLALCHEMY_DATABASE_URI_SCHEME(self) -> str:
@@ -227,12 +237,13 @@ class DatabaseConfig(BaseSettings):
connect_args: dict[str, str] = {}
# Use the dynamic SQLALCHEMY_DATABASE_URI_SCHEME property
if self.SQLALCHEMY_DATABASE_URI_SCHEME.startswith("postgresql"):
timezone_opt = "-c timezone=UTC"
if options:
merged_options = f"{options} {timezone_opt}"
else:
merged_options = timezone_opt
connect_args = {"options": merged_options}
merged_options = options.strip()
session_timezone_override = self.DB_SESSION_TIMEZONE_OVERRIDE.strip()
if session_timezone_override:
timezone_opt = f"-c timezone={session_timezone_override}"
merged_options = f"{merged_options} {timezone_opt}".strip() if merged_options else timezone_opt
if merged_options:
connect_args = {"options": merged_options}
result: SQLAlchemyEngineOptionsDict = {
"pool_size": self.SQLALCHEMY_POOL_SIZE,

View File

@@ -32,6 +32,11 @@ class RedisConfig(BaseSettings):
default=0,
)
REDIS_KEY_PREFIX: str = Field(
description="Optional global prefix for Redis keys, topics, and transport artifacts",
default="",
)
REDIS_USE_SSL: bool = Field(
description="Enable SSL/TLS for the Redis connection",
default=False,

View File

@@ -6,7 +6,6 @@ from typing import Any, Literal
from flask import request
from flask_restx import Resource
from graphon.enums import WorkflowExecutionStatus
from graphon.file import helpers as file_helpers
from pydantic import AliasChoices, BaseModel, Field, computed_field, field_validator
from sqlalchemy import select
from sqlalchemy.orm import sessionmaker
@@ -31,6 +30,7 @@ from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.trigger.constants import TRIGGER_NODE_TYPES
from extensions.ext_database import db
from fields.base import ResponseModel
from libs.helper import build_icon_url
from libs.login import current_account_with_tenant, login_required
from models import App, DatasetPermissionEnum, Workflow
from models.model import IconType
@@ -161,15 +161,6 @@ def _to_timestamp(value: datetime | int | None) -> int | None:
return value
def _build_icon_url(icon_type: str | IconType | None, icon: str | None) -> str | None:
if icon is None or icon_type is None:
return None
icon_type_value = icon_type.value if isinstance(icon_type, IconType) else str(icon_type)
if icon_type_value.lower() != IconType.IMAGE:
return None
return file_helpers.get_signed_file_url(icon)
class Tag(ResponseModel):
id: str
name: str
@@ -292,7 +283,7 @@ class Site(ResponseModel):
@computed_field(return_type=str | None) # type: ignore
@property
def icon_url(self) -> str | None:
return _build_icon_url(self.icon_type, self.icon)
return build_icon_url(self.icon_type, self.icon)
@field_validator("icon_type", mode="before")
@classmethod
@@ -342,7 +333,7 @@ class AppPartial(ResponseModel):
@computed_field(return_type=str | None) # type: ignore
@property
def icon_url(self) -> str | None:
return _build_icon_url(self.icon_type, self.icon)
return build_icon_url(self.icon_type, self.icon)
@field_validator("created_at", "updated_at", mode="before")
@classmethod
@@ -390,7 +381,7 @@ class AppDetailWithSite(AppDetail):
@computed_field(return_type=str | None) # type: ignore
@property
def icon_url(self) -> str | None:
return _build_icon_url(self.icon_type, self.icon)
return build_icon_url(self.icon_type, self.icon)
class AppPagination(ResponseModel):
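The three icon_url properties above now delegate to a shared libs.helper.build_icon_url. Its behavior presumably matches the local _build_icon_url removed in this hunk; a sketch for reference, using the same imports the removed code relied on:

from graphon.file import helpers as file_helpers
from models.model import IconType

def build_icon_url(icon_type: str | IconType | None, icon: str | None) -> str | None:
    # Only image-type icons are turned into signed file URLs; emoji or missing icons return None.
    if icon is None or icon_type is None:
        return None
    icon_type_value = icon_type.value if isinstance(icon_type, IconType) else str(icon_type)
    if icon_type_value.lower() != IconType.IMAGE:
        return None
    return file_helpers.get_signed_file_url(icon)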

View File

@@ -1,44 +1,86 @@
from __future__ import annotations
from datetime import datetime
from typing import Any
from flask import request
from flask_restx import Resource, fields, marshal_with
from pydantic import BaseModel, Field
from flask_restx import Resource
from pydantic import BaseModel, Field, field_validator
from sqlalchemy import select
from sqlalchemy.orm import sessionmaker
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from extensions.ext_database import db
from fields.conversation_variable_fields import (
conversation_variable_fields,
paginated_conversation_variable_fields,
)
from fields._value_type_serializer import serialize_value_type
from fields.base import ResponseModel
from libs.login import login_required
from models import ConversationVariable
from models.model import AppMode
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class ConversationVariablesQuery(BaseModel):
conversation_id: str = Field(..., description="Conversation ID to filter variables")
console_ns.schema_model(
ConversationVariablesQuery.__name__,
ConversationVariablesQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
def _to_timestamp(value: datetime | int | None) -> int | None:
if isinstance(value, datetime):
return int(value.timestamp())
return value
# Register models for flask_restx to avoid dict type issues in Swagger
# Register base model first
conversation_variable_model = console_ns.model("ConversationVariable", conversation_variable_fields)
# For nested models, need to replace nested dict with registered model
paginated_conversation_variable_fields_copy = paginated_conversation_variable_fields.copy()
paginated_conversation_variable_fields_copy["data"] = fields.List(
fields.Nested(conversation_variable_model), attribute="data"
)
paginated_conversation_variable_model = console_ns.model(
"PaginatedConversationVariable", paginated_conversation_variable_fields_copy
class ConversationVariableResponse(ResponseModel):
id: str
name: str
value_type: str
value: str | None = None
description: str | None = None
created_at: int | None = None
updated_at: int | None = None
@field_validator("value_type", mode="before")
@classmethod
def _normalize_value_type(cls, value: Any) -> str:
exposed_type = getattr(value, "exposed_type", None)
if callable(exposed_type):
return str(exposed_type().value)
if isinstance(value, str):
return value
try:
return serialize_value_type(value)
except Exception:
return serialize_value_type({"value_type": value})
@field_validator("value", mode="before")
@classmethod
def _normalize_value(cls, value: Any | None) -> str | None:
if value is None:
return None
if isinstance(value, str):
return value
return str(value)
@field_validator("created_at", "updated_at", mode="before")
@classmethod
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
return _to_timestamp(value)
class PaginatedConversationVariableResponse(ResponseModel):
page: int
limit: int
total: int
has_more: bool
data: list[ConversationVariableResponse]
register_schema_models(
console_ns,
ConversationVariablesQuery,
ConversationVariableResponse,
PaginatedConversationVariableResponse,
)
@@ -48,12 +90,15 @@ class ConversationVariablesApi(Resource):
@console_ns.doc(description="Get conversation variables for an application")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[ConversationVariablesQuery.__name__])
@console_ns.response(200, "Conversation variables retrieved successfully", paginated_conversation_variable_model)
@console_ns.response(
200,
"Conversation variables retrieved successfully",
console_ns.models[PaginatedConversationVariableResponse.__name__],
)
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=AppMode.ADVANCED_CHAT)
@marshal_with(paginated_conversation_variable_model)
def get(self, app_model):
args = ConversationVariablesQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
@@ -72,17 +117,22 @@ class ConversationVariablesApi(Resource):
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
rows = session.scalars(stmt).all()
return {
"page": page,
"limit": page_size,
"total": len(rows),
"has_more": False,
"data": [
{
"created_at": row.created_at,
"updated_at": row.updated_at,
**row.to_variable().model_dump(),
}
for row in rows
],
}
response = PaginatedConversationVariableResponse.model_validate(
{
"page": page,
"limit": page_size,
"total": len(rows),
"has_more": False,
"data": [
ConversationVariableResponse.model_validate(
{
"created_at": row.created_at,
"updated_at": row.updated_at,
**row.to_variable().model_dump(),
}
)
for row in rows
],
}
)
return response.model_dump(mode="json")
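A short usage sketch of the new response model with made-up row data (assuming fields.base.ResponseModel behaves like a plain Pydantic BaseModel), showing what the before-mode validators normalize:

from datetime import datetime, timezone

resp = ConversationVariableResponse.model_validate(
    {
        "id": "var-1",
        "name": "user_name",
        "value_type": "string",  # plain strings pass through unchanged
        "value": 42,  # non-string values are coerced via str()
        "created_at": datetime(2026, 4, 15, tzinfo=timezone.utc),  # datetimes become int timestamps
        "updated_at": 1765000000,  # ints pass through as-is
    }
)
assert resp.value == "42"
assert isinstance(resp.created_at, int) and resp.updated_at == 1765000000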

View File

@@ -1,8 +1,9 @@
import logging
from datetime import datetime
from typing import Literal
from flask import request
from flask_restx import Resource, fields, marshal_with
from flask_restx import Resource
from graphon.model_runtime.errors.invoke import InvokeError
from pydantic import BaseModel, Field, field_validator
from sqlalchemy import exists, func, select
@@ -25,10 +26,21 @@ from controllers.console.wraps import (
setup_required,
)
from core.app.entities.app_invoke_entities import InvokeFrom
from core.entities.execution_extra_content import ExecutionExtraContentDomainModel
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from extensions.ext_database import db
from fields.raws import FilesContainedField
from libs.helper import TimestampField, uuid_value
from fields.base import ResponseModel
from fields.conversation_fields import (
AgentThought,
ConversationAnnotation,
ConversationAnnotationHitHistory,
Feedback,
JSONValue,
MessageFile,
format_files_contained,
to_timestamp,
)
from libs.helper import uuid_value
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from libs.login import current_account_with_tenant, login_required
from models.enums import FeedbackFromSource, FeedbackRating
@@ -98,6 +110,51 @@ class SuggestedQuestionsResponse(BaseModel):
data: list[str] = Field(description="Suggested question")
class MessageDetailResponse(ResponseModel):
id: str
conversation_id: str
inputs: dict[str, JSONValue]
query: str
message: JSONValue | None = None
message_tokens: int | None = None
answer: str = Field(validation_alias="re_sign_file_url_answer")
answer_tokens: int | None = None
provider_response_latency: float | None = None
from_source: str
from_end_user_id: str | None = None
from_account_id: str | None = None
feedbacks: list[Feedback] = Field(default_factory=list)
workflow_run_id: str | None = None
annotation: ConversationAnnotation | None = None
annotation_hit_history: ConversationAnnotationHitHistory | None = None
created_at: int | None = None
agent_thoughts: list[AgentThought] = Field(default_factory=list)
message_files: list[MessageFile] = Field(default_factory=list)
extra_contents: list[ExecutionExtraContentDomainModel] = Field(default_factory=list)
metadata: JSONValue | None = Field(default=None, validation_alias="message_metadata_dict")
status: str
error: str | None = None
parent_message_id: str | None = None
@field_validator("inputs", mode="before")
@classmethod
def _normalize_inputs(cls, value: JSONValue) -> JSONValue:
return format_files_contained(value)
@field_validator("created_at", mode="before")
@classmethod
def _normalize_created_at(cls, value: datetime | int | None) -> int | None:
if isinstance(value, datetime):
return to_timestamp(value)
return value
class MessageInfiniteScrollPaginationResponse(ResponseModel):
limit: int
has_more: bool
data: list[MessageDetailResponse]
register_schema_models(
console_ns,
ChatMessagesQuery,
@@ -105,124 +162,8 @@ register_schema_models(
FeedbackExportQuery,
AnnotationCountResponse,
SuggestedQuestionsResponse,
)
# Register models for flask_restx to avoid dict type issues in Swagger
# Register in dependency order: base models first, then dependent models
# Base models
simple_account_model = console_ns.model(
"SimpleAccount",
{
"id": fields.String,
"name": fields.String,
"email": fields.String,
},
)
message_file_model = console_ns.model(
"MessageFile",
{
"id": fields.String,
"filename": fields.String,
"type": fields.String,
"url": fields.String,
"mime_type": fields.String,
"size": fields.Integer,
"transfer_method": fields.String,
"belongs_to": fields.String(default="user"),
"upload_file_id": fields.String(default=None),
},
)
agent_thought_model = console_ns.model(
"AgentThought",
{
"id": fields.String,
"chain_id": fields.String,
"message_id": fields.String,
"position": fields.Integer,
"thought": fields.String,
"tool": fields.String,
"tool_labels": fields.Raw,
"tool_input": fields.String,
"created_at": TimestampField,
"observation": fields.String,
"files": fields.List(fields.String),
},
)
# Models that depend on simple_account_model
feedback_model = console_ns.model(
"Feedback",
{
"rating": fields.String,
"content": fields.String,
"from_source": fields.String,
"from_end_user_id": fields.String,
"from_account": fields.Nested(simple_account_model, allow_null=True),
},
)
annotation_model = console_ns.model(
"Annotation",
{
"id": fields.String,
"question": fields.String,
"content": fields.String,
"account": fields.Nested(simple_account_model, allow_null=True),
"created_at": TimestampField,
},
)
annotation_hit_history_model = console_ns.model(
"AnnotationHitHistory",
{
"annotation_id": fields.String(attribute="id"),
"annotation_create_account": fields.Nested(simple_account_model, allow_null=True),
"created_at": TimestampField,
},
)
# Message detail model that depends on multiple models
message_detail_model = console_ns.model(
"MessageDetail",
{
"id": fields.String,
"conversation_id": fields.String,
"inputs": FilesContainedField,
"query": fields.String,
"message": fields.Raw,
"message_tokens": fields.Integer,
"answer": fields.String(attribute="re_sign_file_url_answer"),
"answer_tokens": fields.Integer,
"provider_response_latency": fields.Float,
"from_source": fields.String,
"from_end_user_id": fields.String,
"from_account_id": fields.String,
"feedbacks": fields.List(fields.Nested(feedback_model)),
"workflow_run_id": fields.String,
"annotation": fields.Nested(annotation_model, allow_null=True),
"annotation_hit_history": fields.Nested(annotation_hit_history_model, allow_null=True),
"created_at": TimestampField,
"agent_thoughts": fields.List(fields.Nested(agent_thought_model)),
"message_files": fields.List(fields.Nested(message_file_model)),
"extra_contents": fields.List(fields.Raw),
"metadata": fields.Raw(attribute="message_metadata_dict"),
"status": fields.String,
"error": fields.String,
"parent_message_id": fields.String,
},
)
# Message infinite scroll pagination model
message_infinite_scroll_pagination_model = console_ns.model(
"MessageInfiniteScrollPagination",
{
"limit": fields.Integer,
"has_more": fields.Boolean,
"data": fields.List(fields.Nested(message_detail_model)),
},
MessageDetailResponse,
MessageInfiniteScrollPaginationResponse,
)
@@ -232,13 +173,12 @@ class ChatMessageListApi(Resource):
@console_ns.doc(description="Get chat messages for a conversation with pagination")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[ChatMessagesQuery.__name__])
@console_ns.response(200, "Success", message_infinite_scroll_pagination_model)
@console_ns.response(200, "Success", console_ns.models[MessageInfiniteScrollPaginationResponse.__name__])
@console_ns.response(404, "Conversation not found")
@login_required
@account_initialization_required
@setup_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
@marshal_with(message_infinite_scroll_pagination_model)
@edit_permission_required
def get(self, app_model):
args = ChatMessagesQuery.model_validate(request.args.to_dict())
@@ -298,7 +238,10 @@ class ChatMessageListApi(Resource):
history_messages = list(reversed(history_messages))
attach_message_extra_contents(history_messages)
return InfiniteScrollPagination(data=history_messages, limit=args.limit, has_more=has_more)
return MessageInfiniteScrollPaginationResponse.model_validate(
InfiniteScrollPagination(data=history_messages, limit=args.limit, has_more=has_more),
from_attributes=True,
).model_dump(mode="json")
@console_ns.route("/apps/<uuid:app_id>/feedbacks")
@@ -468,13 +411,12 @@ class MessageApi(Resource):
@console_ns.doc("get_message")
@console_ns.doc(description="Get message details by ID")
@console_ns.doc(params={"app_id": "Application ID", "message_id": "Message ID"})
@console_ns.response(200, "Message retrieved successfully", message_detail_model)
@console_ns.response(200, "Message retrieved successfully", console_ns.models[MessageDetailResponse.__name__])
@console_ns.response(404, "Message not found")
@get_app_model
@setup_required
@login_required
@account_initialization_required
@marshal_with(message_detail_model)
def get(self, app_model, message_id: str):
message_id = str(message_id)
@@ -486,4 +428,4 @@ class MessageApi(Resource):
raise NotFound("Message Not Exists.")
attach_message_extra_contents([message])
return message
return MessageDetailResponse.model_validate(message, from_attributes=True).model_dump(mode="json")

View File

@@ -1,27 +1,26 @@
from datetime import datetime
from typing import Any
from dateutil.parser import isoparse
from flask import request
from flask_restx import Resource, marshal_with
from flask_restx import Resource
from graphon.enums import WorkflowExecutionStatus
from pydantic import BaseModel, Field, field_validator
from sqlalchemy.orm import sessionmaker
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from extensions.ext_database import db
from fields.workflow_app_log_fields import (
build_workflow_app_log_pagination_model,
build_workflow_archived_log_pagination_model,
)
from fields.base import ResponseModel
from fields.end_user_fields import SimpleEndUser
from fields.member_fields import SimpleAccount
from libs.login import login_required
from models import App
from models.model import AppMode
from services.workflow_app_service import WorkflowAppService
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class WorkflowAppLogQuery(BaseModel):
keyword: str | None = Field(default=None, description="Search keyword for filtering logs")
@@ -58,13 +57,113 @@ class WorkflowAppLogQuery(BaseModel):
raise ValueError("Invalid boolean value for detail")
console_ns.schema_model(
WorkflowAppLogQuery.__name__, WorkflowAppLogQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
class WorkflowRunForLogResponse(ResponseModel):
id: str
version: str | None = None
status: str | None = None
triggered_from: str | None = None
error: str | None = None
elapsed_time: float | None = None
total_tokens: int | None = None
total_steps: int | None = None
created_at: int | None = None
finished_at: int | None = None
exceptions_count: int | None = None
# Register model for flask_restx to avoid dict type issues in Swagger
workflow_app_log_pagination_model = build_workflow_app_log_pagination_model(console_ns)
workflow_archived_log_pagination_model = build_workflow_archived_log_pagination_model(console_ns)
@field_validator("status", mode="before")
@classmethod
def _normalize_status(cls, value: Any) -> str | None:
if value is None:
return None
if isinstance(value, str):
return value
return str(getattr(value, "value", value))
@field_validator("created_at", "finished_at", mode="before")
@classmethod
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
if isinstance(value, datetime):
return int(value.timestamp())
return value
class WorkflowRunForArchivedLogResponse(ResponseModel):
id: str
status: str | None = None
triggered_from: str | None = None
elapsed_time: float | None = None
total_tokens: int | None = None
@field_validator("status", mode="before")
@classmethod
def _normalize_status(cls, value: Any) -> str | None:
if value is None:
return None
if isinstance(value, str):
return value
return str(getattr(value, "value", value))
class WorkflowAppLogPartialResponse(ResponseModel):
id: str
workflow_run: WorkflowRunForLogResponse | None = None
details: Any = None
created_from: str | None = None
created_by_role: str | None = None
created_by_account: SimpleAccount | None = None
created_by_end_user: SimpleEndUser | None = None
created_at: int | None = None
@field_validator("created_at", mode="before")
@classmethod
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
if isinstance(value, datetime):
return int(value.timestamp())
return value
class WorkflowArchivedLogPartialResponse(ResponseModel):
id: str
workflow_run: WorkflowRunForArchivedLogResponse | None = None
trigger_metadata: Any = None
created_by_account: SimpleAccount | None = None
created_by_end_user: SimpleEndUser | None = None
created_at: int | None = None
@field_validator("created_at", mode="before")
@classmethod
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
if isinstance(value, datetime):
return int(value.timestamp())
return value
class WorkflowAppLogPaginationResponse(ResponseModel):
page: int
limit: int
total: int
has_more: bool
data: list[WorkflowAppLogPartialResponse]
class WorkflowArchivedLogPaginationResponse(ResponseModel):
page: int
limit: int
total: int
has_more: bool
data: list[WorkflowArchivedLogPartialResponse]
register_schema_models(
console_ns,
WorkflowAppLogQuery,
WorkflowRunForLogResponse,
WorkflowRunForArchivedLogResponse,
WorkflowAppLogPartialResponse,
WorkflowArchivedLogPartialResponse,
WorkflowAppLogPaginationResponse,
WorkflowArchivedLogPaginationResponse,
)
@console_ns.route("/apps/<uuid:app_id>/workflow-app-logs")
@@ -73,12 +172,15 @@ class WorkflowAppLogApi(Resource):
@console_ns.doc(description="Get workflow application execution logs")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[WorkflowAppLogQuery.__name__])
@console_ns.response(200, "Workflow app logs retrieved successfully", workflow_app_log_pagination_model)
@console_ns.response(
200,
"Workflow app logs retrieved successfully",
console_ns.models[WorkflowAppLogPaginationResponse.__name__],
)
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.WORKFLOW])
@marshal_with(workflow_app_log_pagination_model)
def get(self, app_model: App):
"""
Get workflow app logs
@@ -102,7 +204,9 @@ class WorkflowAppLogApi(Resource):
created_by_account=args.created_by_account,
)
return workflow_app_log_pagination
return WorkflowAppLogPaginationResponse.model_validate(
workflow_app_log_pagination, from_attributes=True
).model_dump(mode="json")
@console_ns.route("/apps/<uuid:app_id>/workflow-archived-logs")
@@ -111,12 +215,15 @@ class WorkflowArchivedLogApi(Resource):
@console_ns.doc(description="Get workflow archived execution logs")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[WorkflowAppLogQuery.__name__])
@console_ns.response(200, "Workflow archived logs retrieved successfully", workflow_archived_log_pagination_model)
@console_ns.response(
200,
"Workflow archived logs retrieved successfully",
console_ns.models[WorkflowArchivedLogPaginationResponse.__name__],
)
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.WORKFLOW])
@marshal_with(workflow_archived_log_pagination_model)
def get(self, app_model: App):
"""
Get workflow archived logs
@@ -132,4 +239,6 @@ class WorkflowArchivedLogApi(Resource):
limit=args.limit,
)
return workflow_app_log_pagination
return WorkflowArchivedLogPaginationResponse.model_validate(
workflow_app_log_pagination, from_attributes=True
).model_dump(mode="json")

View File

@@ -1,16 +1,17 @@
import logging
from datetime import datetime
from flask import request
from flask_restx import Resource, fields, marshal_with
from pydantic import BaseModel
from flask_restx import Resource
from pydantic import BaseModel, field_validator
from sqlalchemy import select
from sqlalchemy.orm import sessionmaker
from werkzeug.exceptions import NotFound
from configs import dify_config
from controllers.common.schema import get_or_create_model
from controllers.common.schema import register_schema_models
from extensions.ext_database import db
from fields.workflow_trigger_fields import trigger_fields, triggers_list_fields, webhook_trigger_fields
from fields.base import ResponseModel
from libs.login import current_user, login_required
from models.enums import AppTriggerStatus
from models.model import Account, App, AppMode
@@ -21,15 +22,6 @@ from ..app.wraps import get_app_model
from ..wraps import account_initialization_required, edit_permission_required, setup_required
logger = logging.getLogger(__name__)
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
trigger_model = get_or_create_model("WorkflowTrigger", trigger_fields)
triggers_list_fields_copy = triggers_list_fields.copy()
triggers_list_fields_copy["data"] = fields.List(fields.Nested(trigger_model))
triggers_list_model = get_or_create_model("WorkflowTriggerList", triggers_list_fields_copy)
webhook_trigger_model = get_or_create_model("WebhookTrigger", webhook_trigger_fields)
class Parser(BaseModel):
@@ -41,10 +33,52 @@ class ParserEnable(BaseModel):
enable_trigger: bool
console_ns.schema_model(Parser.__name__, Parser.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
class WorkflowTriggerResponse(ResponseModel):
id: str
trigger_type: str
title: str
node_id: str
provider_name: str
icon: str
status: str
created_at: datetime | None = None
updated_at: datetime | None = None
console_ns.schema_model(
ParserEnable.__name__, ParserEnable.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
@field_validator("id", "trigger_type", "title", "node_id", "provider_name", "icon", "status", mode="before")
@classmethod
def _normalize_string_fields(cls, value: object) -> str:
if isinstance(value, str):
return value
return str(value)
class WorkflowTriggerListResponse(ResponseModel):
data: list[WorkflowTriggerResponse]
class WebhookTriggerResponse(ResponseModel):
id: str
webhook_id: str
webhook_url: str
webhook_debug_url: str
node_id: str
created_at: datetime | None = None
@field_validator("id", "webhook_id", "webhook_url", "webhook_debug_url", "node_id", mode="before")
@classmethod
def _normalize_string_fields(cls, value: object) -> str:
if isinstance(value, str):
return value
return str(value)
register_schema_models(
console_ns,
Parser,
ParserEnable,
WorkflowTriggerResponse,
WorkflowTriggerListResponse,
WebhookTriggerResponse,
)
@@ -57,7 +91,7 @@ class WebhookTriggerApi(Resource):
@login_required
@account_initialization_required
@get_app_model(mode=AppMode.WORKFLOW)
@marshal_with(webhook_trigger_model)
@console_ns.response(200, "Success", console_ns.models[WebhookTriggerResponse.__name__])
def get(self, app_model: App):
"""Get webhook trigger for a node"""
args = Parser.model_validate(request.args.to_dict(flat=True)) # type: ignore
@@ -78,7 +112,7 @@ class WebhookTriggerApi(Resource):
if not webhook_trigger:
raise NotFound("Webhook trigger not found for this node")
return webhook_trigger
return WebhookTriggerResponse.model_validate(webhook_trigger, from_attributes=True).model_dump(mode="json")
@console_ns.route("/apps/<uuid:app_id>/triggers")
@@ -89,7 +123,7 @@ class AppTriggersApi(Resource):
@login_required
@account_initialization_required
@get_app_model(mode=AppMode.WORKFLOW)
@marshal_with(triggers_list_model)
@console_ns.response(200, "Success", console_ns.models[WorkflowTriggerListResponse.__name__])
def get(self, app_model: App):
"""Get app triggers list"""
assert isinstance(current_user, Account)
@@ -118,7 +152,9 @@ class AppTriggersApi(Resource):
else:
trigger.icon = "" # type: ignore
return {"data": triggers}
return WorkflowTriggerListResponse.model_validate({"data": triggers}, from_attributes=True).model_dump(
mode="json"
)
@console_ns.route("/apps/<uuid:app_id>/trigger-enable")
@@ -129,7 +165,7 @@ class AppTriggerEnableApi(Resource):
@account_initialization_required
@edit_permission_required
@get_app_model(mode=AppMode.WORKFLOW)
@marshal_with(trigger_model)
@console_ns.response(200, "Success", console_ns.models[WorkflowTriggerResponse.__name__])
def post(self, app_model: App):
"""Update app trigger (enable/disable)"""
args = ParserEnable.model_validate(console_ns.payload)
@@ -160,4 +196,4 @@ class AppTriggerEnableApi(Resource):
else:
trigger.icon = "" # type: ignore
return trigger
return WorkflowTriggerResponse.model_validate(trigger, from_attributes=True).model_dump(mode="json")

View File

@@ -1,3 +1,5 @@
from typing import Any
from flask import request
from flask_restx import Resource
from pydantic import BaseModel, Field, field_validator
@@ -40,7 +42,7 @@ class ActivatePayload(BaseModel):
class ActivationCheckResponse(BaseModel):
is_valid: bool = Field(description="Whether token is valid")
data: dict | None = Field(default=None, description="Activation data if valid")
data: dict[str, Any] | None = Field(default=None, description="Activation data if valid")
class ActivationResponse(BaseModel):

View File

@@ -1026,7 +1026,7 @@ class DocumentMetadataApi(DocumentResource):
if not isinstance(doc_metadata, dict):
raise ValueError("doc_metadata must be a dictionary.")
metadata_schema: dict = cast(dict, DocumentService.DOCUMENT_METADATA_SCHEMA[doc_type])
metadata_schema: dict[str, Any] = cast(dict[str, Any], DocumentService.DOCUMENT_METADATA_SCHEMA[doc_type])
document.doc_metadata = {}
if doc_type == "others":

View File

@@ -1,13 +1,13 @@
from flask_restx import Resource, fields
from __future__ import annotations
from controllers.common.schema import register_schema_model
from fields.hit_testing_fields import (
child_chunk_fields,
document_fields,
files_fields,
hit_testing_record_fields,
segment_fields,
)
from datetime import datetime
from typing import Any
from flask_restx import Resource
from pydantic import Field, field_validator
from controllers.common.schema import register_schema_models
from fields.base import ResponseModel
from libs.login import login_required
from .. import console_ns
@@ -18,39 +18,92 @@ from ..wraps import (
setup_required,
)
register_schema_model(console_ns, HitTestingPayload)
def _to_timestamp(value: datetime | int | None) -> int | None:
if isinstance(value, datetime):
return int(value.timestamp())
return value
def _get_or_create_model(model_name: str, field_def):
"""Get or create a flask_restx model to avoid dict type issues in Swagger."""
existing = console_ns.models.get(model_name)
if existing is None:
existing = console_ns.model(model_name, field_def)
return existing
class HitTestingDocument(ResponseModel):
id: str | None = None
data_source_type: str | None = None
name: str | None = None
doc_type: str | None = None
doc_metadata: Any | None = None
# Register models for flask_restx to avoid dict type issues in Swagger
document_model = _get_or_create_model("HitTestingDocument", document_fields)
class HitTestingSegment(ResponseModel):
id: str | None = None
position: int | None = None
document_id: str | None = None
content: str | None = None
sign_content: str | None = None
answer: str | None = None
word_count: int | None = None
tokens: int | None = None
keywords: list[str] = Field(default_factory=list)
index_node_id: str | None = None
index_node_hash: str | None = None
hit_count: int | None = None
enabled: bool | None = None
disabled_at: int | None = None
disabled_by: str | None = None
status: str | None = None
created_by: str | None = None
created_at: int | None = None
indexing_at: int | None = None
completed_at: int | None = None
error: str | None = None
stopped_at: int | None = None
document: HitTestingDocument | None = None
segment_fields_copy = segment_fields.copy()
segment_fields_copy["document"] = fields.Nested(document_model)
segment_model = _get_or_create_model("HitTestingSegment", segment_fields_copy)
@field_validator("disabled_at", "created_at", "indexing_at", "completed_at", "stopped_at", mode="before")
@classmethod
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
return _to_timestamp(value)
child_chunk_model = _get_or_create_model("HitTestingChildChunk", child_chunk_fields)
files_model = _get_or_create_model("HitTestingFile", files_fields)
hit_testing_record_fields_copy = hit_testing_record_fields.copy()
hit_testing_record_fields_copy["segment"] = fields.Nested(segment_model)
hit_testing_record_fields_copy["child_chunks"] = fields.List(fields.Nested(child_chunk_model))
hit_testing_record_fields_copy["files"] = fields.List(fields.Nested(files_model))
hit_testing_record_model = _get_or_create_model("HitTestingRecord", hit_testing_record_fields_copy)
class HitTestingChildChunk(ResponseModel):
id: str | None = None
content: str | None = None
position: int | None = None
score: float | None = None
# Response model for hit testing API
hit_testing_response_fields = {
"query": fields.String,
"records": fields.List(fields.Nested(hit_testing_record_model)),
}
hit_testing_response_model = _get_or_create_model("HitTestingResponse", hit_testing_response_fields)
class HitTestingFile(ResponseModel):
id: str | None = None
name: str | None = None
size: int | None = None
extension: str | None = None
mime_type: str | None = None
source_url: str | None = None
class HitTestingRecord(ResponseModel):
segment: HitTestingSegment | None = None
child_chunks: list[HitTestingChildChunk] = Field(default_factory=list)
score: float | None = None
tsne_position: Any | None = None
files: list[HitTestingFile] = Field(default_factory=list)
summary: str | None = None
class HitTestingResponse(ResponseModel):
query: str
records: list[HitTestingRecord] = Field(default_factory=list)
register_schema_models(
console_ns,
HitTestingPayload,
HitTestingDocument,
HitTestingSegment,
HitTestingChildChunk,
HitTestingFile,
HitTestingRecord,
HitTestingResponse,
)
@console_ns.route("/datasets/<uuid:dataset_id>/hit-testing")
@@ -59,7 +112,11 @@ class HitTestingApi(Resource, DatasetsHitTestingBase):
@console_ns.doc(description="Test dataset knowledge retrieval")
@console_ns.doc(params={"dataset_id": "Dataset ID"})
@console_ns.expect(console_ns.models[HitTestingPayload.__name__])
@console_ns.response(200, "Hit testing completed successfully", model=hit_testing_response_model)
@console_ns.response(
200,
"Hit testing completed successfully",
model=console_ns.models[HitTestingResponse.__name__],
)
@console_ns.response(404, "Dataset not found")
@console_ns.response(400, "Invalid parameters")
@setup_required
@@ -74,4 +131,4 @@ class HitTestingApi(Resource, DatasetsHitTestingBase):
args = payload.model_dump(exclude_none=True)
self.hit_testing_args_check(args)
return self.perform_hit_testing(dataset, args)
return HitTestingResponse.model_validate(self.perform_hit_testing(dataset, args)).model_dump(mode="json")

View File

@@ -1,21 +1,24 @@
import logging
from datetime import datetime
from typing import Any
from flask import request
from flask_restx import Resource, fields, marshal_with
from pydantic import BaseModel, Field
from flask_restx import Resource
from graphon.file import helpers as file_helpers
from pydantic import BaseModel, Field, computed_field, field_validator
from sqlalchemy import and_, select
from werkzeug.exceptions import BadRequest, Forbidden, NotFound
from controllers.common.schema import get_or_create_model
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.explore.wraps import InstalledAppResource
from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
from extensions.ext_database import db
from fields.installed_app_fields import app_fields, installed_app_fields, installed_app_list_fields
from fields.base import ResponseModel
from libs.datetime_utils import naive_utc_now
from libs.login import current_account_with_tenant, login_required
from models import App, InstalledApp, RecommendedApp
from models.model import IconType
from services.account_service import TenantService
from services.enterprise.enterprise_service import EnterpriseService
from services.feature_service import FeatureService
@@ -36,22 +39,97 @@ class InstalledAppsListQuery(BaseModel):
logger = logging.getLogger(__name__)
app_model = get_or_create_model("InstalledAppInfo", app_fields)
def _build_icon_url(icon_type: str | IconType | None, icon: str | None) -> str | None:
if icon is None or icon_type is None:
return None
icon_type_value = icon_type.value if isinstance(icon_type, IconType) else str(icon_type)
if icon_type_value.lower() != IconType.IMAGE:
return None
return file_helpers.get_signed_file_url(icon)
installed_app_fields_copy = installed_app_fields.copy()
installed_app_fields_copy["app"] = fields.Nested(app_model)
installed_app_model = get_or_create_model("InstalledApp", installed_app_fields_copy)
installed_app_list_fields_copy = installed_app_list_fields.copy()
installed_app_list_fields_copy["installed_apps"] = fields.List(fields.Nested(installed_app_model))
installed_app_list_model = get_or_create_model("InstalledAppList", installed_app_list_fields_copy)
def _safe_primitive(value: Any) -> Any:
if value is None or isinstance(value, (str, int, float, bool, datetime)):
return value
return None
class InstalledAppInfoResponse(ResponseModel):
id: str
name: str | None = None
mode: str | None = None
icon_type: str | None = None
icon: str | None = None
icon_background: str | None = None
use_icon_as_answer_icon: bool | None = None
@field_validator("mode", "icon_type", mode="before")
@classmethod
def _normalize_enum_like(cls, value: Any) -> str | None:
if value is None:
return None
if isinstance(value, str):
return value
return str(getattr(value, "value", value))
@computed_field(return_type=str | None) # type: ignore[prop-decorator]
@property
def icon_url(self) -> str | None:
return _build_icon_url(self.icon_type, self.icon)
class InstalledAppResponse(ResponseModel):
id: str
app: InstalledAppInfoResponse
app_owner_tenant_id: str
is_pinned: bool
last_used_at: int | None = None
editable: bool
uninstallable: bool
@field_validator("app", mode="before")
@classmethod
def _normalize_app(cls, value: Any) -> Any:
if isinstance(value, dict):
return value
return {
"id": _safe_primitive(getattr(value, "id", "")) or "",
"name": _safe_primitive(getattr(value, "name", None)),
"mode": _safe_primitive(getattr(value, "mode", None)),
"icon_type": _safe_primitive(getattr(value, "icon_type", None)),
"icon": _safe_primitive(getattr(value, "icon", None)),
"icon_background": _safe_primitive(getattr(value, "icon_background", None)),
"use_icon_as_answer_icon": _safe_primitive(getattr(value, "use_icon_as_answer_icon", None)),
}
@field_validator("last_used_at", mode="before")
@classmethod
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
if isinstance(value, datetime):
return int(value.timestamp())
return value
class InstalledAppListResponse(ResponseModel):
installed_apps: list[InstalledAppResponse]
register_schema_models(
console_ns,
InstalledAppCreatePayload,
InstalledAppUpdatePayload,
InstalledAppsListQuery,
InstalledAppInfoResponse,
InstalledAppResponse,
InstalledAppListResponse,
)
@console_ns.route("/installed-apps")
class InstalledAppsListApi(Resource):
@login_required
@account_initialization_required
@marshal_with(installed_app_list_model)
@console_ns.response(200, "Success", console_ns.models[InstalledAppListResponse.__name__])
def get(self):
query = InstalledAppsListQuery.model_validate(request.args.to_dict())
current_user, current_tenant_id = current_account_with_tenant()
@@ -125,7 +203,9 @@ class InstalledAppsListApi(Resource):
)
)
return {"installed_apps": installed_app_list}
return InstalledAppListResponse.model_validate(
{"installed_apps": installed_app_list}, from_attributes=True
).model_dump(mode="json")
@login_required
@account_initialization_required

View File

@@ -1,66 +1,83 @@
from typing import Any
from flask import request
from flask_restx import Resource, fields, marshal_with
from pydantic import BaseModel, Field
from flask_restx import Resource
from pydantic import BaseModel, Field, computed_field, field_validator
from constants.languages import languages
from controllers.common.schema import get_or_create_model
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required
from libs.helper import AppIconUrlField
from fields.base import ResponseModel
from libs.helper import build_icon_url
from libs.login import current_user, login_required
from services.recommended_app_service import RecommendedAppService
app_fields = {
"id": fields.String,
"name": fields.String,
"mode": fields.String,
"icon": fields.String,
"icon_type": fields.String,
"icon_url": AppIconUrlField,
"icon_background": fields.String,
}
app_model = get_or_create_model("RecommendedAppInfo", app_fields)
recommended_app_fields = {
"app": fields.Nested(app_model, attribute="app"),
"app_id": fields.String,
"description": fields.String(attribute="description"),
"copyright": fields.String,
"privacy_policy": fields.String,
"custom_disclaimer": fields.String,
"category": fields.String,
"position": fields.Integer,
"is_listed": fields.Boolean,
"can_trial": fields.Boolean,
}
recommended_app_model = get_or_create_model("RecommendedApp", recommended_app_fields)
recommended_app_list_fields = {
"recommended_apps": fields.List(fields.Nested(recommended_app_model)),
"categories": fields.List(fields.String),
}
recommended_app_list_model = get_or_create_model("RecommendedAppList", recommended_app_list_fields)
class RecommendedAppsQuery(BaseModel):
language: str | None = Field(default=None)
console_ns.schema_model(
RecommendedAppsQuery.__name__,
RecommendedAppsQuery.model_json_schema(ref_template="#/definitions/{model}"),
class RecommendedAppInfoResponse(ResponseModel):
id: str
name: str | None = None
mode: str | None = None
icon: str | None = None
icon_type: str | None = None
icon_background: str | None = None
@staticmethod
def _normalize_enum_like(value: Any) -> str | None:
if value is None:
return None
if isinstance(value, str):
return value
return str(getattr(value, "value", value))
@field_validator("mode", "icon_type", mode="before")
@classmethod
def _normalize_enum_fields(cls, value: Any) -> str | None:
return cls._normalize_enum_like(value)
@computed_field(return_type=str | None) # type: ignore[prop-decorator]
@property
def icon_url(self) -> str | None:
return build_icon_url(self.icon_type, self.icon)
class RecommendedAppResponse(ResponseModel):
app: RecommendedAppInfoResponse | None = None
app_id: str
description: str | None = None
copyright: str | None = None
privacy_policy: str | None = None
custom_disclaimer: str | None = None
category: str | None = None
position: int | None = None
is_listed: bool | None = None
can_trial: bool | None = None
class RecommendedAppListResponse(ResponseModel):
recommended_apps: list[RecommendedAppResponse]
categories: list[str]
register_schema_models(
console_ns,
RecommendedAppsQuery,
RecommendedAppInfoResponse,
RecommendedAppResponse,
RecommendedAppListResponse,
)
@console_ns.route("/explore/apps")
class RecommendedAppListApi(Resource):
@console_ns.expect(console_ns.models[RecommendedAppsQuery.__name__])
@console_ns.response(200, "Success", console_ns.models[RecommendedAppListResponse.__name__])
@login_required
@account_initialization_required
@marshal_with(recommended_app_list_model)
def get(self):
# language args
args = RecommendedAppsQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
@@ -72,7 +89,10 @@ class RecommendedAppListApi(Resource):
else:
language_prefix = languages[0]
return RecommendedAppService.get_recommended_apps_and_categories(language_prefix)
return RecommendedAppListResponse.model_validate(
RecommendedAppService.get_recommended_apps_and_categories(language_prefix),
from_attributes=True,
).model_dump(mode="json")
@console_ns.route("/explore/apps/<uuid:app_id>")

View File

@@ -1,15 +1,18 @@
from datetime import datetime
from typing import Any
from flask import request
from flask_restx import Resource, fields, marshal_with
from pydantic import BaseModel, Field
from flask_restx import Resource
from pydantic import BaseModel, Field, TypeAdapter, field_validator
from constants import HIDDEN_VALUE
from fields.api_based_extension_fields import api_based_extension_fields
from fields.base import ResponseModel
from libs.login import current_account_with_tenant, login_required
from models.api_based_extension import APIBasedExtension
from services.api_based_extension_service import APIBasedExtensionService
from services.code_based_extension_service import CodeBasedExtensionService
from ..common.schema import register_schema_models
from ..common.schema import DEFAULT_REF_TEMPLATE_SWAGGER_2_0, register_schema_models
from . import console_ns
from .wraps import account_initialization_required, setup_required
@@ -24,12 +27,52 @@ class APIBasedExtensionPayload(BaseModel):
api_key: str = Field(description="API key for authentication")
register_schema_models(console_ns, APIBasedExtensionPayload)
class CodeBasedExtensionResponse(ResponseModel):
module: str = Field(description="Module name")
data: Any = Field(description="Extension data")
api_based_extension_model = console_ns.model("ApiBasedExtensionModel", api_based_extension_fields)
def _mask_api_key(api_key: str) -> str:
if not api_key:
return api_key
if len(api_key) <= 8:
return api_key[0] + "******" + api_key[-1]
return api_key[:3] + "******" + api_key[-3:]
api_based_extension_list_model = fields.List(fields.Nested(api_based_extension_model))
def _to_timestamp(value: datetime | int | None) -> int | None:
if isinstance(value, datetime):
return int(value.timestamp())
return value
class APIBasedExtensionResponse(ResponseModel):
id: str
name: str
api_endpoint: str
api_key: str
created_at: int | None = None
@field_validator("api_key", mode="before")
@classmethod
def _normalize_api_key(cls, value: str) -> str:
return _mask_api_key(value)
@field_validator("created_at", mode="before")
@classmethod
def _normalize_created_at(cls, value: datetime | int | None) -> int | None:
return _to_timestamp(value)
register_schema_models(console_ns, APIBasedExtensionPayload, CodeBasedExtensionResponse, APIBasedExtensionResponse)
console_ns.schema_model(
"APIBasedExtensionListResponse",
TypeAdapter(list[APIBasedExtensionResponse]).json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
def _serialize_api_based_extension(extension: APIBasedExtension) -> dict[str, Any]:
return APIBasedExtensionResponse.model_validate(extension, from_attributes=True).model_dump(mode="json")
@console_ns.route("/code-based-extension")
@@ -40,10 +83,7 @@ class CodeBasedExtensionAPI(Resource):
@console_ns.response(
200,
"Success",
console_ns.model(
"CodeBasedExtensionResponse",
{"module": fields.String(description="Module name"), "data": fields.Raw(description="Extension data")},
),
console_ns.models[CodeBasedExtensionResponse.__name__],
)
@setup_required
@login_required
@@ -51,30 +91,34 @@ class CodeBasedExtensionAPI(Resource):
def get(self):
query = CodeBasedExtensionQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
return {"module": query.module, "data": CodeBasedExtensionService.get_code_based_extension(query.module)}
return CodeBasedExtensionResponse(
module=query.module,
data=CodeBasedExtensionService.get_code_based_extension(query.module),
).model_dump(mode="json")
@console_ns.route("/api-based-extension")
class APIBasedExtensionAPI(Resource):
@console_ns.doc("get_api_based_extensions")
@console_ns.doc(description="Get all API-based extensions for current tenant")
@console_ns.response(200, "Success", api_based_extension_list_model)
@console_ns.response(200, "Success", console_ns.models["APIBasedExtensionListResponse"])
@setup_required
@login_required
@account_initialization_required
@marshal_with(api_based_extension_model)
def get(self):
_, tenant_id = current_account_with_tenant()
return APIBasedExtensionService.get_all_by_tenant_id(tenant_id)
return [
_serialize_api_based_extension(extension)
for extension in APIBasedExtensionService.get_all_by_tenant_id(tenant_id)
]
@console_ns.doc("create_api_based_extension")
@console_ns.doc(description="Create a new API-based extension")
@console_ns.expect(console_ns.models[APIBasedExtensionPayload.__name__])
@console_ns.response(201, "Extension created successfully", api_based_extension_model)
@console_ns.response(201, "Extension created successfully", console_ns.models[APIBasedExtensionResponse.__name__])
@setup_required
@login_required
@account_initialization_required
@marshal_with(api_based_extension_model)
def post(self):
payload = APIBasedExtensionPayload.model_validate(console_ns.payload or {})
_, current_tenant_id = current_account_with_tenant()
@@ -86,7 +130,7 @@ class APIBasedExtensionAPI(Resource):
api_key=payload.api_key,
)
return APIBasedExtensionService.save(extension_data)
return _serialize_api_based_extension(APIBasedExtensionService.save(extension_data))
@console_ns.route("/api-based-extension/<uuid:id>")
@@ -94,26 +138,26 @@ class APIBasedExtensionDetailAPI(Resource):
@console_ns.doc("get_api_based_extension")
@console_ns.doc(description="Get API-based extension by ID")
@console_ns.doc(params={"id": "Extension ID"})
@console_ns.response(200, "Success", api_based_extension_model)
@console_ns.response(200, "Success", console_ns.models[APIBasedExtensionResponse.__name__])
@setup_required
@login_required
@account_initialization_required
@marshal_with(api_based_extension_model)
def get(self, id):
api_based_extension_id = str(id)
_, tenant_id = current_account_with_tenant()
return APIBasedExtensionService.get_with_tenant_id(tenant_id, api_based_extension_id)
return _serialize_api_based_extension(
APIBasedExtensionService.get_with_tenant_id(tenant_id, api_based_extension_id)
)
@console_ns.doc("update_api_based_extension")
@console_ns.doc(description="Update API-based extension")
@console_ns.doc(params={"id": "Extension ID"})
@console_ns.expect(console_ns.models[APIBasedExtensionPayload.__name__])
@console_ns.response(200, "Extension updated successfully", api_based_extension_model)
@console_ns.response(200, "Extension updated successfully", console_ns.models[APIBasedExtensionResponse.__name__])
@setup_required
@login_required
@account_initialization_required
@marshal_with(api_based_extension_model)
def post(self, id):
api_based_extension_id = str(id)
_, current_tenant_id = current_account_with_tenant()
@@ -128,7 +172,7 @@ class APIBasedExtensionDetailAPI(Resource):
if payload.api_key != HIDDEN_VALUE:
extension_data_from_db.api_key = payload.api_key
return APIBasedExtensionService.save(extension_data_from_db)
return _serialize_api_based_extension(APIBasedExtensionService.save(extension_data_from_db))
@console_ns.doc("delete_api_based_extension")
@console_ns.doc(description="Delete API-based extension")

View File

@@ -35,22 +35,24 @@ def plugin_permission_required(
return view(*args, **kwargs)
if install_required:
if permission.install_permission == TenantPluginPermission.InstallPermission.NOBODY:
raise Forbidden()
if permission.install_permission == TenantPluginPermission.InstallPermission.ADMINS:
if not user.is_admin_or_owner:
match permission.install_permission:
case TenantPluginPermission.InstallPermission.NOBODY:
raise Forbidden()
if permission.install_permission == TenantPluginPermission.InstallPermission.EVERYONE:
pass
case TenantPluginPermission.InstallPermission.ADMINS:
if not user.is_admin_or_owner:
raise Forbidden()
case TenantPluginPermission.InstallPermission.EVERYONE:
pass
if debug_required:
if permission.debug_permission == TenantPluginPermission.DebugPermission.NOBODY:
raise Forbidden()
if permission.debug_permission == TenantPluginPermission.DebugPermission.ADMINS:
if not user.is_admin_or_owner:
match permission.debug_permission:
case TenantPluginPermission.DebugPermission.NOBODY:
raise Forbidden()
if permission.debug_permission == TenantPluginPermission.DebugPermission.EVERYONE:
pass
case TenantPluginPermission.DebugPermission.ADMINS:
if not user.is_admin_or_owner:
raise Forbidden()
case TenantPluginPermission.DebugPermission.EVERYONE:
pass
return view(*args, **kwargs)
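Because the removed if-chains and the added match statements interleave in this hunk, here is the new install-permission check read straight through (the debug_permission block follows the same shape):

if install_required:
    match permission.install_permission:
        case TenantPluginPermission.InstallPermission.NOBODY:
            raise Forbidden()
        case TenantPluginPermission.InstallPermission.ADMINS:
            if not user.is_admin_or_owner:
                raise Forbidden()
        case TenantPluginPermission.InstallPermission.EVERYONE:
            pass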

View File

@@ -5,7 +5,7 @@ from typing import Any, Literal
import pytz
from flask import request
from flask_restx import Resource, fields, marshal_with
from flask_restx import Resource
from pydantic import BaseModel, Field, field_validator, model_validator
from sqlalchemy import select
@@ -37,9 +37,10 @@ from controllers.console.wraps import (
setup_required,
)
from extensions.ext_database import db
from fields.base import ResponseModel
from fields.member_fields import Account as AccountResponse
from libs.datetime_utils import naive_utc_now
from libs.helper import EmailStr, TimestampField, extract_remote_ip, timezone
from libs.helper import EmailStr, extract_remote_ip, timezone
from libs.login import current_account_with_tenant, login_required
from models import AccountIntegrate, InvitationCode
from models.account import AccountStatus, InvitationCodeStatus
@@ -178,17 +179,57 @@ def _serialize_account(account) -> dict[str, Any]:
return AccountResponse.model_validate(account, from_attributes=True).model_dump(mode="json")
integrate_fields = {
"provider": fields.String,
"created_at": TimestampField,
"is_bound": fields.Boolean,
"link": fields.String,
}
def _to_timestamp(value: datetime | int | None) -> int | None:
if isinstance(value, datetime):
return int(value.timestamp())
return value
integrate_model = console_ns.model("AccountIntegrate", integrate_fields)
integrate_list_model = console_ns.model(
"AccountIntegrateList",
{"data": fields.List(fields.Nested(integrate_model))},
class AccountIntegrateResponse(ResponseModel):
provider: str
created_at: int | None = None
is_bound: bool
link: str | None = None
@field_validator("created_at", mode="before")
@classmethod
def _normalize_created_at(cls, value: datetime | int | None) -> int | None:
return _to_timestamp(value)
class AccountIntegrateListResponse(ResponseModel):
data: list[AccountIntegrateResponse]
class EducationVerifyResponse(ResponseModel):
token: str | None = None
class EducationStatusResponse(ResponseModel):
result: bool | None = None
is_student: bool | None = None
expire_at: int | None = None
allow_refresh: bool | None = None
@field_validator("expire_at", mode="before")
@classmethod
def _normalize_expire_at(cls, value: datetime | int | None) -> int | None:
return _to_timestamp(value)
class EducationAutocompleteResponse(ResponseModel):
data: list[str] = Field(default_factory=list)
curr_page: int | None = None
has_next: bool | None = None
register_schema_models(
console_ns,
AccountIntegrateResponse,
AccountIntegrateListResponse,
EducationVerifyResponse,
EducationStatusResponse,
EducationAutocompleteResponse,
)
@@ -359,7 +400,7 @@ class AccountIntegrateApi(Resource):
@setup_required
@login_required
@account_initialization_required
@marshal_with(integrate_list_model)
@console_ns.response(200, "Success", console_ns.models[AccountIntegrateListResponse.__name__])
def get(self):
account, _ = current_account_with_tenant()
@@ -395,7 +436,9 @@ class AccountIntegrateApi(Resource):
}
)
return {"data": integrate_data}
return AccountIntegrateListResponse(
data=[AccountIntegrateResponse.model_validate(item) for item in integrate_data]
).model_dump(mode="json")
@console_ns.route("/account/delete/verify")
@@ -447,31 +490,22 @@ class AccountDeleteUpdateFeedbackApi(Resource):
@console_ns.route("/account/education/verify")
class EducationVerifyApi(Resource):
verify_fields = {
"token": fields.String,
}
@setup_required
@login_required
@account_initialization_required
@only_edition_cloud
@cloud_edition_billing_enabled
@marshal_with(verify_fields)
@console_ns.response(200, "Success", console_ns.models[EducationVerifyResponse.__name__])
def get(self):
account, _ = current_account_with_tenant()
return BillingService.EducationIdentity.verify(account.id, account.email)
return EducationVerifyResponse.model_validate(
BillingService.EducationIdentity.verify(account.id, account.email) or {}
).model_dump(mode="json")
@console_ns.route("/account/education")
class EducationApi(Resource):
status_fields = {
"result": fields.Boolean,
"is_student": fields.Boolean,
"expire_at": TimestampField,
"allow_refresh": fields.Boolean,
}
@console_ns.expect(console_ns.models[EducationActivatePayload.__name__])
@setup_required
@login_required
@@ -491,37 +525,33 @@ class EducationApi(Resource):
@account_initialization_required
@only_edition_cloud
@cloud_edition_billing_enabled
@marshal_with(status_fields)
@console_ns.response(200, "Success", console_ns.models[EducationStatusResponse.__name__])
def get(self):
account, _ = current_account_with_tenant()
res = BillingService.EducationIdentity.status(account.id)
res = BillingService.EducationIdentity.status(account.id) or {}
# convert expire_at to UTC timestamp from isoformat
if res and "expire_at" in res:
res["expire_at"] = datetime.fromisoformat(res["expire_at"]).astimezone(pytz.utc)
return res
return EducationStatusResponse.model_validate(res).model_dump(mode="json")
@console_ns.route("/account/education/autocomplete")
class EducationAutoCompleteApi(Resource):
data_fields = {
"data": fields.List(fields.String),
"curr_page": fields.Integer,
"has_next": fields.Boolean,
}
@console_ns.expect(console_ns.models[EducationAutocompleteQuery.__name__])
@setup_required
@login_required
@account_initialization_required
@only_edition_cloud
@cloud_edition_billing_enabled
@marshal_with(data_fields)
@console_ns.response(200, "Success", console_ns.models[EducationAutocompleteResponse.__name__])
def get(self):
payload = request.args.to_dict(flat=True)
args = EducationAutocompleteQuery.model_validate(payload)
return BillingService.EducationIdentity.autocomplete(args.keywords, args.page, args.limit)
return EducationAutocompleteResponse.model_validate(
BillingService.EducationIdentity.autocomplete(args.keywords, args.page, args.limit) or {}
).model_dump(mode="json")
@console_ns.route("/account/change-email")

View File

@@ -465,7 +465,7 @@ class ModelProviderModelDisableApi(Resource):
class ParserValidate(BaseModel):
model: str
model_type: ModelType
credentials: dict
credentials: dict[str, Any]
console_ns.schema_model(

View File

@@ -1,8 +1,9 @@
import logging
from datetime import datetime
from flask import request
from flask_restx import Resource, fields, marshal, marshal_with
from pydantic import BaseModel, Field
from flask_restx import Resource, fields, marshal
from pydantic import BaseModel, Field, field_validator
from sqlalchemy import select
from werkzeug.exceptions import Unauthorized
@@ -26,6 +27,7 @@ from controllers.console.wraps import (
)
from enums.cloud_plan import CloudPlan
from extensions.ext_database import db
from fields.base import ResponseModel
from libs.helper import TimestampField
from libs.login import current_account_with_tenant, login_required
from models.account import Tenant, TenantCustomConfigDict, TenantStatus
@@ -58,6 +60,37 @@ class WorkspaceInfoPayload(BaseModel):
name: str
class TenantInfoResponse(ResponseModel):
id: str
name: str | None = None
plan: str | None = None
status: str | None = None
created_at: int | None = None
role: str | None = None
in_trial: bool | None = None
trial_end_reason: str | None = None
custom_config: dict | None = None
trial_credits: int | None = None
trial_credits_used: int | None = None
next_credit_reset_date: int | None = None
@field_validator("plan", "status", "trial_end_reason", mode="before")
@classmethod
def _normalize_enum_like(cls, value):
if value is None:
return None
if isinstance(value, str):
return value
return str(getattr(value, "value", value))
@field_validator("created_at", mode="before")
@classmethod
def _normalize_created_at(cls, value: datetime | int | None):
if isinstance(value, datetime):
return int(value.timestamp())
return value
def reg(cls: type[BaseModel]):
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
@@ -66,6 +99,7 @@ reg(WorkspaceListQuery)
reg(SwitchWorkspacePayload)
reg(WorkspaceCustomConfigPayload)
reg(WorkspaceInfoPayload)
reg(TenantInfoResponse)
provider_fields = {
"provider_name": fields.String,
@@ -180,7 +214,7 @@ class TenantApi(Resource):
@setup_required
@login_required
@account_initialization_required
@marshal_with(tenant_fields)
@console_ns.response(200, "Success", console_ns.models[TenantInfoResponse.__name__])
def post(self):
if request.path == "/info":
logger.warning("Deprecated URL /info was used.")
@@ -200,7 +234,13 @@ class TenantApi(Resource):
else:
raise Unauthorized("workspace is archived")
return WorkspaceService.get_tenant_info(tenant), 200
return (
TenantInfoResponse.model_validate(
WorkspaceService.get_tenant_info(tenant),
from_attributes=True,
).model_dump(mode="json"),
200,
)
@console_ns.route("/workspaces/switch")

View File

@@ -1,7 +1,9 @@
from datetime import datetime
from typing import Any, Literal
from flask import request
from flask_restx import Resource
from graphon.variables.types import SegmentType
from pydantic import BaseModel, Field, TypeAdapter, field_validator
from sqlalchemy.orm import sessionmaker
from werkzeug.exceptions import BadRequest, NotFound
@@ -14,14 +16,12 @@ from controllers.service_api.app.error import NotChatAppError
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from core.app.entities.app_invoke_entities import InvokeFrom
from extensions.ext_database import db
from fields._value_type_serializer import serialize_value_type
from fields.base import ResponseModel
from fields.conversation_fields import (
ConversationInfiniteScrollPagination,
SimpleConversation,
)
from fields.conversation_variable_fields import (
build_conversation_variable_infinite_scroll_pagination_model,
build_conversation_variable_model,
)
from libs.helper import UUIDStrOrEmpty
from models.model import App, AppMode, EndUser
from services.conversation_service import ConversationService
@@ -70,12 +70,70 @@ class ConversationVariableUpdatePayload(BaseModel):
value: Any
class ConversationVariableResponse(ResponseModel):
id: str
name: str
value_type: str
value: str | None = None
description: str | None = None
created_at: int | None = None
updated_at: int | None = None
@field_validator("value_type", mode="before")
@classmethod
def normalize_value_type(cls, value: Any) -> str:
exposed_type = getattr(value, "exposed_type", None)
if callable(exposed_type):
return str(exposed_type().value)
if isinstance(value, str):
try:
return str(SegmentType(value).exposed_type().value)
except ValueError:
return value
try:
return serialize_value_type(value)
except (AttributeError, TypeError, ValueError):
pass
try:
return serialize_value_type({"value_type": value})
except (AttributeError, TypeError, ValueError):
value_attr = getattr(value, "value", None)
if value_attr is not None:
return str(value_attr)
return str(value)
@field_validator("value", mode="before")
@classmethod
def normalize_value(cls, value: Any | None) -> str | None:
if value is None:
return None
if isinstance(value, str):
return value
return str(value)
@field_validator("created_at", "updated_at", mode="before")
@classmethod
def normalize_timestamp(cls, value: datetime | int | None) -> int | None:
if isinstance(value, datetime):
return int(value.timestamp())
return value
class ConversationVariableInfiniteScrollPaginationResponse(ResponseModel):
limit: int
has_more: bool
data: list[ConversationVariableResponse]
register_schema_models(
service_api_ns,
ConversationListQuery,
ConversationRenamePayload,
ConversationVariablesQuery,
ConversationVariableUpdatePayload,
ConversationVariableResponse,
ConversationVariableInfiniteScrollPaginationResponse,
)
@@ -204,8 +262,12 @@ class ConversationVariablesApi(Resource):
404: "Conversation not found",
}
)
@service_api_ns.response(
200,
"Variables retrieved successfully",
service_api_ns.models[ConversationVariableInfiniteScrollPaginationResponse.__name__],
)
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY))
@service_api_ns.marshal_with(build_conversation_variable_infinite_scroll_pagination_model(service_api_ns))
def get(self, app_model: App, end_user: EndUser, c_id):
"""List all variables for a conversation.
@@ -222,9 +284,12 @@ class ConversationVariablesApi(Resource):
last_id = str(query_args.last_id) if query_args.last_id else None
try:
return ConversationService.get_conversational_variable(
pagination = ConversationService.get_conversational_variable(
app_model, conversation_id, end_user, query_args.limit, last_id, query_args.variable_name
)
return ConversationVariableInfiniteScrollPaginationResponse.model_validate(
pagination, from_attributes=True
).model_dump(mode="json")
except services.errors.conversation.ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
@@ -243,8 +308,12 @@ class ConversationVariableDetailApi(Resource):
404: "Conversation or variable not found",
}
)
@service_api_ns.response(
200,
"Variable updated successfully",
service_api_ns.models[ConversationVariableResponse.__name__],
)
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
@service_api_ns.marshal_with(build_conversation_variable_model(service_api_ns))
def put(self, app_model: App, end_user: EndUser, c_id, variable_id):
"""Update a conversation variable's value.
@@ -261,9 +330,10 @@ class ConversationVariableDetailApi(Resource):
payload = ConversationVariableUpdatePayload.model_validate(service_api_ns.payload or {})
try:
return ConversationService.update_conversation_variable(
variable = ConversationService.update_conversation_variable(
app_model, conversation_id, variable_id, end_user, payload.value
)
return ConversationVariableResponse.model_validate(variable, from_attributes=True).model_dump(mode="json")
except services.errors.conversation.ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
except services.errors.conversation.ConversationVariableNotExistsError:

View File

@@ -1,13 +1,15 @@
import logging
from collections.abc import Mapping
from datetime import datetime
from typing import Literal
from dateutil.parser import isoparse
from flask import request
from flask_restx import Namespace, Resource, fields
from flask_restx import Resource, fields
from graphon.enums import WorkflowExecutionStatus
from graphon.graph_engine.manager import GraphEngineManager
from graphon.model_runtime.errors.invoke import InvokeError
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, field_validator
from sqlalchemy.orm import sessionmaker
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
@@ -33,9 +35,10 @@ from core.errors.error import (
from core.helper.trace_id_helper import get_external_trace_id
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from fields.workflow_app_log_fields import build_workflow_app_log_pagination_model
from fields.base import ResponseModel
from fields.end_user_fields import SimpleEndUser
from fields.member_fields import SimpleAccount
from libs import helper
from libs.helper import OptionalTimestampField, TimestampField
from models.model import App, AppMode, EndUser
from models.workflow import WorkflowRun
from repositories.factory import DifyAPIRepositoryFactory
@@ -65,38 +68,142 @@ class WorkflowLogQuery(BaseModel):
register_schema_models(service_api_ns, WorkflowRunPayload, WorkflowLogQuery)
def _to_timestamp(value: datetime | int | None) -> int | None:
if isinstance(value, datetime):
return int(value.timestamp())
return value
def _enum_value(value):
return getattr(value, "value", value)
class WorkflowRunStatusField(fields.Raw):
def output(self, key, obj: WorkflowRun, **kwargs):
return obj.status.value
return _enum_value(obj.status)
class WorkflowRunOutputsField(fields.Raw):
def output(self, key, obj: WorkflowRun, **kwargs):
if obj.status == WorkflowExecutionStatus.PAUSED:
status = _enum_value(obj.status)
if status == WorkflowExecutionStatus.PAUSED.value:
return {}
outputs = obj.outputs_dict
return outputs or {}
workflow_run_fields = {
"id": fields.String,
"workflow_id": fields.String,
"status": WorkflowRunStatusField,
"inputs": fields.Raw,
"outputs": WorkflowRunOutputsField,
"error": fields.String,
"total_steps": fields.Integer,
"total_tokens": fields.Integer,
"created_at": TimestampField,
"finished_at": OptionalTimestampField,
"elapsed_time": fields.Float,
}
class WorkflowRunResponse(ResponseModel):
id: str
workflow_id: str
status: str
inputs: dict | list | str | int | float | bool | None = None
outputs: dict = Field(default_factory=dict)
error: str | None = None
total_steps: int | None = None
total_tokens: int | None = None
created_at: int | None = None
finished_at: int | None = None
elapsed_time: float | int | None = None
@field_validator("created_at", "finished_at", mode="before")
@classmethod
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
return _to_timestamp(value)
def build_workflow_run_model(api_or_ns: Namespace):
"""Build the workflow run model for the API or Namespace."""
return api_or_ns.model("WorkflowRun", workflow_run_fields)
class WorkflowRunForLogResponse(ResponseModel):
id: str
version: str | None = None
status: str | None = None
triggered_from: str | None = None
error: str | None = None
elapsed_time: float | int | None = None
total_tokens: int | None = None
total_steps: int | None = None
created_at: int | None = None
finished_at: int | None = None
exceptions_count: int | None = None
@field_validator("status", "triggered_from", mode="before")
@classmethod
def _normalize_enum(cls, value):
return _enum_value(value)
@field_validator("created_at", "finished_at", mode="before")
@classmethod
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
return _to_timestamp(value)
class WorkflowAppLogPartialResponse(ResponseModel):
id: str
workflow_run: WorkflowRunForLogResponse | None = None
details: dict | list | str | int | float | bool | None = None
created_from: str | None = None
created_by_role: str | None = None
created_by_account: SimpleAccount | None = None
created_by_end_user: SimpleEndUser | None = None
created_at: int | None = None
@field_validator("created_from", "created_by_role", mode="before")
@classmethod
def _normalize_enum(cls, value):
return _enum_value(value)
@field_validator("created_at", mode="before")
@classmethod
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
return _to_timestamp(value)
class WorkflowAppLogPaginationResponse(ResponseModel):
page: int
limit: int
total: int
has_more: bool
data: list[WorkflowAppLogPartialResponse]
register_schema_models(
service_api_ns,
WorkflowRunResponse,
WorkflowRunForLogResponse,
WorkflowAppLogPartialResponse,
WorkflowAppLogPaginationResponse,
)
def _serialize_workflow_run(workflow_run: WorkflowRun) -> dict:
status = _enum_value(workflow_run.status)
raw_outputs = workflow_run.outputs_dict
if status == WorkflowExecutionStatus.PAUSED.value or raw_outputs is None:
outputs: dict = {}
elif isinstance(raw_outputs, dict):
outputs = raw_outputs
elif isinstance(raw_outputs, Mapping):
outputs = dict(raw_outputs)
else:
outputs = {}
return WorkflowRunResponse.model_validate(
{
"id": workflow_run.id,
"workflow_id": workflow_run.workflow_id,
"status": status,
"inputs": workflow_run.inputs,
"outputs": outputs,
"error": workflow_run.error,
"total_steps": workflow_run.total_steps,
"total_tokens": workflow_run.total_tokens,
"created_at": workflow_run.created_at,
"finished_at": workflow_run.finished_at,
"elapsed_time": workflow_run.elapsed_time,
}
).model_dump(mode="json")
def _serialize_workflow_log_pagination(pagination) -> dict:
return WorkflowAppLogPaginationResponse.model_validate(pagination, from_attributes=True).model_dump(mode="json")
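As a rough sketch (run values below are hypothetical, not taken from the codebase), the serializers above produce a JSON-ready dict in which datetimes become Unix timestamps and outputs is always a dict:

from datetime import datetime, timezone

# Hypothetical data; created_at/finished_at are normalized to int timestamps.
run_dict = WorkflowRunResponse.model_validate(
    {
        "id": "run-1",
        "workflow_id": "wf-1",
        "status": "succeeded",
        "inputs": {"query": "hello"},
        "outputs": {"answer": "world"},
        "created_at": datetime(2026, 4, 15, tzinfo=timezone.utc),
        "finished_at": None,
        "elapsed_time": 1.25,
    }
).model_dump(mode="json")
assert isinstance(run_dict["created_at"], int)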
@service_api_ns.route("/workflows/run/<string:workflow_run_id>")
@@ -112,7 +219,11 @@ class WorkflowRunDetailApi(Resource):
}
)
@validate_app_token
@service_api_ns.marshal_with(build_workflow_run_model(service_api_ns))
@service_api_ns.response(
200,
"Workflow run details retrieved successfully",
service_api_ns.models[WorkflowRunResponse.__name__],
)
def get(self, app_model: App, workflow_run_id: str):
"""Get a workflow task running detail.
@@ -133,7 +244,7 @@ class WorkflowRunDetailApi(Resource):
)
if not workflow_run:
raise NotFound("Workflow run not found.")
return workflow_run
return _serialize_workflow_run(workflow_run)
@service_api_ns.route("/workflows/run")
@@ -299,7 +410,11 @@ class WorkflowAppLogApi(Resource):
}
)
@validate_app_token
@service_api_ns.marshal_with(build_workflow_app_log_pagination_model(service_api_ns))
@service_api_ns.response(
200,
"Logs retrieved successfully",
service_api_ns.models[WorkflowAppLogPaginationResponse.__name__],
)
def get(self, app_model: App):
"""Get workflow app logs.
@@ -327,4 +442,4 @@ class WorkflowAppLogApi(Resource):
created_by_account=args.created_by_account,
)
return workflow_app_log_pagination
return _serialize_workflow_log_pagination(workflow_app_log_pagination)

View File

@@ -1,7 +1,7 @@
import json
import re
from collections.abc import Generator
from typing import Union
from typing import Any, Union
from graphon.model_runtime.entities.llm_entities import LLMResultChunk
@@ -11,7 +11,7 @@ from core.agent.entities import AgentScratchpadUnit
class CotAgentOutputParser:
@classmethod
def handle_react_stream_output(
cls, llm_response: Generator[LLMResultChunk, None, None], usage_dict: dict
cls, llm_response: Generator[LLMResultChunk, None, None], usage_dict: dict[str, Any]
) -> Generator[Union[str, AgentScratchpadUnit.Action], None, None]:
def parse_action(action) -> Union[str, AgentScratchpadUnit.Action]:
action_name = None

View File

@@ -84,7 +84,7 @@ class AgentStrategyEntity(BaseModel):
identity: AgentStrategyIdentity
parameters: list[AgentStrategyParameter] = Field(default_factory=list)
description: I18nObject = Field(..., description="The description of the agent strategy")
output_schema: dict | None = None
output_schema: dict[str, Any] | None = None
features: list[AgentFeature] | None = None
meta_version: str | None = None
# pydantic configs

View File

@@ -22,8 +22,8 @@ class SensitiveWordAvoidanceConfigManager:
@classmethod
def validate_and_set_defaults(
cls, tenant_id: str, config: dict, only_structure_validate: bool = False
) -> tuple[dict, list[str]]:
cls, tenant_id: str, config: dict[str, Any], only_structure_validate: bool = False
) -> tuple[dict[str, Any], list[str]]:
if not config.get("sensitive_word_avoidance"):
config["sensitive_word_avoidance"] = {"enabled": False}

View File

@@ -41,7 +41,7 @@ class ModelConfigManager:
)
@classmethod
def validate_and_set_defaults(cls, tenant_id: str, config: Mapping[str, Any]) -> tuple[dict, list[str]]:
def validate_and_set_defaults(cls, tenant_id: str, config: Mapping[str, Any]) -> tuple[dict[str, Any], list[str]]:
"""
Validate and set defaults for model config

View File

@@ -57,7 +57,7 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod
def convert_stream_full_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, Any, None]:
) -> Generator[dict[str, Any] | str, Any, None]:
"""
Convert stream full response.
:param stream_response: stream response
@@ -88,7 +88,7 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod
def convert_stream_simple_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, Any, None]:
) -> Generator[dict[str, Any] | str, Any, None]:
"""
Convert stream simple response.
:param stream_response: stream response

View File

@@ -1,5 +1,5 @@
from collections.abc import Generator
from typing import cast
from typing import Any, cast
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
from core.app.entities.task_entities import (
@@ -56,7 +56,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod
def convert_stream_full_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, None, None]:
) -> Generator[dict[str, Any] | str, None, None]:
"""
Convert stream full response.
:param stream_response: stream response
@@ -87,7 +87,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod
def convert_stream_simple_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, None, None]:
) -> Generator[dict[str, Any] | str, None, None]:
"""
Convert stream simple response.
:param stream_response: stream response

View File

@@ -24,7 +24,7 @@ class AppGenerateResponseConverter(ABC):
return cls.convert_blocking_full_response(response)
else:
def _generate_full_response() -> Generator[dict | str, Any, None]:
def _generate_full_response() -> Generator[dict[str, Any] | str, Any, None]:
yield from cls.convert_stream_full_response(response)
return _generate_full_response()
@@ -33,7 +33,7 @@ class AppGenerateResponseConverter(ABC):
return cls.convert_blocking_simple_response(response)
else:
def _generate_simple_response() -> Generator[dict | str, Any, None]:
def _generate_simple_response() -> Generator[dict[str, Any] | str, Any, None]:
yield from cls.convert_stream_simple_response(response)
return _generate_simple_response()
@@ -52,14 +52,14 @@ class AppGenerateResponseConverter(ABC):
@abstractmethod
def convert_stream_full_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, None, None]:
) -> Generator[dict[str, Any] | str, None, None]:
raise NotImplementedError
@classmethod
@abstractmethod
def convert_stream_simple_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, None, None]:
) -> Generator[dict[str, Any] | str, None, None]:
raise NotImplementedError
@classmethod

View File

@@ -1,5 +1,5 @@
from collections.abc import Generator
from typing import cast
from typing import Any, cast
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
from core.app.entities.task_entities import (
@@ -56,7 +56,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod
def convert_stream_full_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, None, None]:
) -> Generator[dict[str, Any] | str, None, None]:
"""
Convert stream full response.
:param stream_response: stream response
@@ -87,7 +87,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod
def convert_stream_simple_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, None, None]:
) -> Generator[dict[str, Any] | str, None, None]:
"""
Convert stream simple response.
:param stream_response: stream response

View File

@@ -1,5 +1,5 @@
from collections.abc import Generator
from typing import cast
from typing import Any, cast
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
from core.app.entities.task_entities import (
@@ -55,7 +55,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod
def convert_stream_full_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, None, None]:
) -> Generator[dict[str, Any] | str, None, None]:
"""
Convert stream full response.
:param stream_response: stream response
@@ -85,7 +85,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod
def convert_stream_simple_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, None, None]:
) -> Generator[dict[str, Any] | str, None, None]:
"""
Convert stream simple response.
:param stream_response: stream response

View File

@@ -37,7 +37,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod
def convert_stream_full_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, None, None]:
) -> Generator[dict[str, Any] | str, None, None]:
"""
Convert stream full response.
:param stream_response: stream response
@@ -66,7 +66,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod
def convert_stream_simple_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, None, None]:
) -> Generator[dict[str, Any] | str, None, None]:
"""
Convert stream simple response.
:param stream_response: stream response

View File

@@ -1,5 +1,5 @@
from collections.abc import Generator
from typing import cast
from typing import Any, cast
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
from core.app.entities.task_entities import (
@@ -37,7 +37,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod
def convert_stream_full_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, None, None]:
) -> Generator[dict[str, Any] | str, None, None]:
"""
Convert stream full response.
:param stream_response: stream response
@@ -66,7 +66,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod
def convert_stream_simple_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, None, None]:
) -> Generator[dict[str, Any] | str, None, None]:
"""
Convert stream simple response.
:param stream_response: stream response

View File

@@ -682,15 +682,16 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
def _save_workflow_app_log(self, *, session: Session, workflow_run_id: str | None):
invoke_from = self._application_generate_entity.invoke_from
if invoke_from == InvokeFrom.SERVICE_API:
created_from = WorkflowAppLogCreatedFrom.SERVICE_API
elif invoke_from == InvokeFrom.EXPLORE:
created_from = WorkflowAppLogCreatedFrom.INSTALLED_APP
elif invoke_from == InvokeFrom.WEB_APP:
created_from = WorkflowAppLogCreatedFrom.WEB_APP
else:
# do not save logs for debugging runs
return
match invoke_from:
case InvokeFrom.SERVICE_API:
created_from = WorkflowAppLogCreatedFrom.SERVICE_API
case InvokeFrom.EXPLORE:
created_from = WorkflowAppLogCreatedFrom.INSTALLED_APP
case InvokeFrom.WEB_APP:
created_from = WorkflowAppLogCreatedFrom.WEB_APP
case InvokeFrom.DEBUGGER | InvokeFrom.TRIGGER | InvokeFrom.PUBLISHED_PIPELINE | InvokeFrom.VALIDATION:
# do not save logs for debugging runs
return
if not workflow_run_id:
return

View File

@@ -1,4 +1,4 @@
from typing import cast
from typing import Any, cast
import httpx
@@ -14,7 +14,7 @@ class APIBasedExtensionRequestor:
self.api_endpoint = api_endpoint
self.api_key = api_key
def request(self, point: APIBasedExtensionPoint, params: dict):
def request(self, point: APIBasedExtensionPoint, params: dict[str, Any]) -> dict[str, Any]:
"""
Request the api.
@@ -49,4 +49,4 @@ class APIBasedExtensionRequestor:
if response.status_code != 200:
raise ValueError(f"request error, status_code: {response.status_code}, content: {response.text[:100]}")
return cast(dict, response.json())
return cast(dict[str, Any], response.json())

View File

@@ -21,8 +21,8 @@ class ExtensionModule(StrEnum):
class ModuleExtension(BaseModel):
extension_class: Any | None = None
name: str
label: dict | None = None
form_schema: list | None = None
label: dict[str, Any] | None = None
form_schema: list[dict[str, Any]] | None = None
builtin: bool = True
position: int | None = None

View File

@@ -13,7 +13,7 @@ class ExternalDataToolFactory:
)
@classmethod
def validate_config(cls, name: str, tenant_id: str, config: dict[str, Any]):
def validate_config(cls, name: str, tenant_id: str, config: dict[str, Any]) -> None:
"""
Validate the incoming form config data.

View File

@@ -1,3 +1,5 @@
from typing import Any
from flask import Flask
from graphon.model_runtime.entities.model_entities import ModelType
from pydantic import BaseModel
@@ -28,7 +30,7 @@ class FreeHostingQuota(HostingQuota):
class HostingProvider(BaseModel):
enabled: bool = False
credentials: dict | None = None
credentials: dict[str, Any] | None = None
quota_unit: QuotaUnit | None = None
quotas: list[HostingQuota] = []

View File

@@ -737,7 +737,7 @@ class IndexingRunner:
def _update_document_index_status(
document_id: str,
after_indexing_status: IndexingStatus,
extra_update_params: dict[Any, Any] | None = None,
extra_update_params: Mapping[Any, Any] | None = None,
):
"""
Update the document indexing status.
@@ -764,7 +764,7 @@ class IndexingRunner:
db.session.commit()
@staticmethod
def _update_segments_by_document(dataset_document_id: str, update_params: dict[Any, Any]):
def _update_segments_by_document(dataset_document_id: str, update_params: Mapping[Any, Any]):
"""
Update the document segment by document id.
"""

View File

@@ -2,7 +2,7 @@ import json
import logging
import re
from collections.abc import Sequence
from typing import Protocol, TypedDict, cast
from typing import Any, Protocol, TypedDict, cast
import json_repair
from graphon.enums import WorkflowNodeExecutionMetadataKey
@@ -533,7 +533,7 @@ class LLMGenerator:
def __instruction_modify_common(
tenant_id: str,
model_config: ModelConfig,
last_run: dict | None,
last_run: dict[str, Any] | None,
current: str | None,
error_message: str | None,
instruction: str,

View File

@@ -202,7 +202,7 @@ def _handle_native_json_schema(
structured_output_schema: Mapping,
model_parameters: dict[str, Any],
rules: list[ParameterRule],
):
) -> dict[str, Any]:
"""
Handle structured output for models with native JSON schema support.
@@ -224,7 +224,7 @@ def _handle_native_json_schema(
return model_parameters
def _set_response_format(model_parameters: dict[str, Any], rules: list[ParameterRule]):
def _set_response_format(model_parameters: dict[str, Any], rules: list[ParameterRule]) -> None:
"""
Set the appropriate response format parameter based on model rules.
@@ -326,7 +326,7 @@ def _prepare_schema_for_model(provider: str, model_schema: AIModelEntity, schema
return {"schema": processed_schema, "name": "llm_response"}
def remove_additional_properties(schema: dict[str, Any]):
def remove_additional_properties(schema: dict[str, Any]) -> None:
"""
Remove additionalProperties fields from JSON schema.
Used for models like Gemini that don't support this property.
@@ -349,7 +349,7 @@ def remove_additional_properties(schema: dict[str, Any]):
remove_additional_properties(item)
def convert_boolean_to_string(schema: dict):
def convert_boolean_to_string(schema: dict[str, Any]) -> None:
"""
Convert boolean type specifications to string in JSON schema.

View File

@@ -1,6 +1,7 @@
import tempfile
from binascii import hexlify, unhexlify
from collections.abc import Generator
from typing import Any
from graphon.model_runtime.entities.llm_entities import (
LLMResult,
@@ -226,7 +227,7 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
# invoke model
response = model_instance.invoke_tts(content_text=payload.content_text, voice=payload.voice)
def handle() -> Generator[dict, None, None]:
def handle() -> Generator[dict[str, Any], None, None]:
for chunk in response:
yield {"result": hexlify(chunk).decode("utf-8")}

View File

@@ -313,7 +313,7 @@ class SimplePromptTransform(PromptTransform):
return prompt_message
def _get_prompt_rule(self, app_mode: AppMode, provider: str, model: str):
def _get_prompt_rule(self, app_mode: AppMode, provider: str, model: str) -> dict[str, Any]:
"""
Get simple prompt rule.
:param app_mode: app mode
@@ -325,7 +325,7 @@ class SimplePromptTransform(PromptTransform):
# Check if the prompt file is already loaded
if prompt_file_name in prompt_file_contents:
return cast(dict, prompt_file_contents[prompt_file_name])
return cast(dict[str, Any], prompt_file_contents[prompt_file_name])
# Get the absolute path of the subdirectory
prompt_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "prompt_templates")
@@ -338,7 +338,7 @@ class SimplePromptTransform(PromptTransform):
# Store the content of the prompt file
prompt_file_contents[prompt_file_name] = content
return cast(dict, content)
return cast(dict[str, Any], content)
def _prompt_file_name(self, app_mode: AppMode, provider: str, model: str) -> str:
# baichuan

View File

@@ -106,7 +106,7 @@ class CacheEmbedding(Embeddings):
return text_embeddings
def embed_multimodal_documents(self, multimodel_documents: list[dict]) -> list[list[float]]:
def embed_multimodal_documents(self, multimodel_documents: list[dict[str, Any]]) -> list[list[float]]:
"""Embed file documents."""
# use doc embedding cache or store if not exists
multimodel_embeddings: list[Any] = [None for _ in range(len(multimodel_documents))]

View File

@@ -11,7 +11,7 @@ class Embeddings(ABC):
raise NotImplementedError
@abstractmethod
def embed_multimodal_documents(self, multimodel_documents: list[dict]) -> list[list[float]]:
def embed_multimodal_documents(self, multimodel_documents: list[dict[str, Any]]) -> list[list[float]]:
"""Embed file documents."""
raise NotImplementedError

View File

@@ -95,9 +95,9 @@ class ExtractProcessor:
) -> list[Document]:
if extract_setting.datasource_type == DatasourceType.FILE:
with tempfile.TemporaryDirectory() as temp_dir:
upload_file = extract_setting.upload_file
if not file_path:
assert extract_setting.upload_file is not None, "upload_file is required"
upload_file: UploadFile = extract_setting.upload_file
assert upload_file is not None, "upload_file is required"
suffix = Path(upload_file.key).suffix
# FIXME mypy: Cannot determine type of 'tempfile._get_candidate_names' better not use it here
file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}" # type: ignore
@@ -113,6 +113,7 @@ class ExtractProcessor:
if file_extension in {".xlsx", ".xls"}:
extractor = ExcelExtractor(file_path)
elif file_extension == ".pdf":
assert upload_file is not None
extractor = PdfExtractor(file_path, upload_file.tenant_id, upload_file.created_by)
elif file_extension in {".md", ".markdown", ".mdx"}:
extractor = (
@@ -123,6 +124,7 @@ class ExtractProcessor:
elif file_extension in {".htm", ".html"}:
extractor = HtmlExtractor(file_path)
elif file_extension == ".docx":
assert upload_file is not None
extractor = WordExtractor(file_path, upload_file.tenant_id, upload_file.created_by)
elif file_extension == ".doc":
extractor = UnstructuredWordExtractor(file_path, unstructured_api_url, unstructured_api_key)
@@ -149,12 +151,14 @@ class ExtractProcessor:
if file_extension in {".xlsx", ".xls"}:
extractor = ExcelExtractor(file_path)
elif file_extension == ".pdf":
assert upload_file is not None
extractor = PdfExtractor(file_path, upload_file.tenant_id, upload_file.created_by)
elif file_extension in {".md", ".markdown", ".mdx"}:
extractor = MarkdownExtractor(file_path, autodetect_encoding=True)
elif file_extension in {".htm", ".html"}:
extractor = HtmlExtractor(file_path)
elif file_extension == ".docx":
assert upload_file is not None
extractor = WordExtractor(file_path, upload_file.tenant_id, upload_file.created_by)
elif file_extension == ".csv":
extractor = CSVExtractor(file_path, autodetect_encoding=True)

View File

@@ -174,21 +174,25 @@ class FirecrawlApp:
return f"{self.base_url.rstrip('/')}/{path.lstrip('/')}"
def _post_request(self, url, data, headers, retries=3, backoff_factor=0.5) -> httpx.Response:
response: httpx.Response | None = None
for attempt in range(retries):
response = httpx.post(url, headers=headers, json=data)
if response.status_code == 502:
time.sleep(backoff_factor * (2**attempt))
else:
return response
assert response is not None, "retries must be at least 1"
return response
def _get_request(self, url, headers, retries=3, backoff_factor=0.5) -> httpx.Response:
response: httpx.Response | None = None
for attempt in range(retries):
response = httpx.get(url, headers=headers)
if response.status_code == 502:
time.sleep(backoff_factor * (2**attempt))
else:
return response
assert response is not None, "retries must be at least 1"
return response
def _handle_error(self, response, action):
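For reference, a small worked example of the backoff schedule the retry helpers above produce, assuming the default retries=3 and backoff_factor=0.5 and that every attempt returns HTTP 502:

# Illustrative only: the loop sleeps after each 502, then returns the final 502 response.
backoff_factor, retries = 0.5, 3
sleeps = [backoff_factor * (2**attempt) for attempt in range(retries)]
print(sleeps)  # [0.5, 1.0, 2.0] -> roughly 3.5s spent sleeping in total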

View File

@@ -1,7 +1,7 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import NamedTuple, Union
from typing import Any, NamedTuple, Union
@dataclass
@@ -10,7 +10,7 @@ class ReactAction:
tool: str
"""The name of the Tool to execute."""
tool_input: Union[str, dict]
tool_input: Union[str, dict[str, Any]]
"""The input to pass in to the Tool."""
log: str
"""Additional information to log about the action."""
@@ -19,7 +19,7 @@ class ReactAction:
class ReactFinish(NamedTuple):
"""The final return value of an ReactFinish."""
return_values: dict
return_values: dict[str, Any]
"""Dictionary of return values."""
log: str
"""Additional information to log about the return value"""

View File

@@ -1,5 +1,5 @@
from collections.abc import Generator, Sequence
from typing import Union
from typing import Any, Union
from graphon.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool
@@ -139,7 +139,7 @@ class ReactMultiDatasetRouter:
def _invoke_llm(
self,
completion_param: dict,
completion_param: dict[str, Any],
model_instance: ModelInstance,
prompt_messages: list[PromptMessage],
stop: list[str],

View File

@@ -63,7 +63,7 @@ class TextSplitter(BaseDocumentTransformer, ABC):
def split_text(self, text: str) -> list[str]:
"""Split text into multiple components."""
def create_documents(self, texts: list[str], metadatas: list[dict] | None = None) -> list[Document]:
def create_documents(self, texts: list[str], metadatas: list[dict[str, Any]] | None = None) -> list[Document]:
"""Create documents from a list of texts."""
_metadatas = metadatas or [{}] * len(texts)
documents = []

View File

@@ -254,7 +254,7 @@ def resolve_dify_schema_refs(
return resolver.resolve(schema)
def _remove_metadata_fields(schema: dict) -> dict:
def _remove_metadata_fields(schema: dict[str, Any]) -> dict[str, Any]:
"""
Remove metadata fields from schema that shouldn't be included in resolved output

View File

@@ -1,4 +1,5 @@
from collections.abc import Mapping
from typing import Any
from pydantic import BaseModel, Field
@@ -26,6 +27,6 @@ class ApiToolBundle(BaseModel):
# icon
icon: str | None = None
# openapi operation
openapi: dict
openapi: dict[str, Any]
# output schema
output_schema: Mapping[str, object] = Field(default_factory=dict)

View File

@@ -47,7 +47,7 @@ class ToolEngine:
@staticmethod
def agent_invoke(
tool: Tool,
tool_parameters: Union[str, dict],
tool_parameters: Union[str, dict[str, Any]],
user_id: str,
tenant_id: str,
message: Message,
@@ -85,7 +85,7 @@ class ToolEngine:
invocation_meta_dict: dict[str, ToolInvokeMeta] = {}
def message_callback(
invocation_meta_dict: dict[str, Any],
invocation_meta_dict: dict[str, ToolInvokeMeta],
messages: Generator[ToolInvokeMessage | ToolInvokeMeta, None, None],
):
for message in messages:

View File

@@ -1,4 +1,5 @@
from sqlalchemy import delete, select
from sqlalchemy.orm import Session, sessionmaker
from core.tools.__base.tool_provider import ToolProviderController
from core.tools.builtin_tool.provider import BuiltinToolProviderController
@@ -19,10 +20,18 @@ class ToolLabelManager:
return list(set(tool_labels))
@classmethod
def update_tool_labels(cls, controller: ToolProviderController, labels: list[str]):
def update_tool_labels(
cls, controller: ToolProviderController, labels: list[str], session: Session | None = None
) -> None:
"""
Update tool labels
:param controller: tool provider controller
:param labels: list of tool labels
:param session: database session, if None, a new session will be created
:return: None
"""
labels = cls.filter_tool_labels(labels)
if isinstance(controller, ApiToolProviderController | WorkflowToolProviderController):
@@ -30,26 +39,46 @@ class ToolLabelManager:
else:
raise ValueError("Unsupported tool type")
if session is not None:
cls._update_tool_labels_logics(session, provider_id, controller, labels)
else:
with sessionmaker(db.engine).begin() as _session:
cls._update_tool_labels_logics(_session, provider_id, controller, labels)
@classmethod
def _update_tool_labels_logics(
cls, session: Session, provider_id: str, controller: ToolProviderController, labels: list[str]
) -> None:
"""
Core logic for updating tool labels within the given session
:param session: database session
:param provider_id: tool provider ID
:param controller: tool provider controller
:param labels: list of tool labels
:return: None
"""
# delete old labels
db.session.execute(delete(ToolLabelBinding).where(ToolLabelBinding.tool_id == provider_id))
_ = session.execute(
delete(ToolLabelBinding).where(
ToolLabelBinding.tool_id == provider_id, ToolLabelBinding.tool_type == controller.provider_type
)
)
# insert new labels
for label in labels:
db.session.add(
ToolLabelBinding(
tool_id=provider_id,
tool_type=controller.provider_type,
label_name=label,
)
)
db.session.commit()
session.add(ToolLabelBinding(tool_id=provider_id, tool_type=controller.provider_type, label_name=label))
@classmethod
def get_tool_labels(cls, controller: ToolProviderController) -> list[str]:
"""
Get tool labels
:param controller: tool provider controller
:return: list of tool labels (str)
"""
if isinstance(controller, ApiToolProviderController | WorkflowToolProviderController):
provider_id = controller.provider_id
elif isinstance(controller, BuiltinToolProviderController):
@@ -60,9 +89,11 @@ class ToolLabelManager:
ToolLabelBinding.tool_id == provider_id,
ToolLabelBinding.tool_type == controller.provider_type,
)
labels = db.session.scalars(stmt).all()
return list(labels)
with sessionmaker(db.engine, expire_on_commit=False).begin() as _session:
labels: list[str] = list(_session.scalars(stmt).all())
return labels
@classmethod
def get_tools_labels(cls, tool_providers: list[ToolProviderController]) -> dict[str, list[str]]:
@@ -78,16 +109,22 @@ class ToolLabelManager:
if not tool_providers:
return {}
provider_ids: list[str] = []
provider_types: set[str] = set()
for controller in tool_providers:
if not isinstance(controller, ApiToolProviderController | WorkflowToolProviderController):
raise ValueError("Unsupported tool type")
provider_ids = []
for controller in tool_providers:
assert isinstance(controller, ApiToolProviderController | WorkflowToolProviderController)
provider_ids.append(controller.provider_id)
provider_types.add(controller.provider_type)
labels = db.session.scalars(select(ToolLabelBinding).where(ToolLabelBinding.tool_id.in_(provider_ids))).all()
labels: list[ToolLabelBinding] = []
with sessionmaker(db.engine, expire_on_commit=False).begin() as _session:
stmt = select(ToolLabelBinding).where(
ToolLabelBinding.tool_id.in_(provider_ids), ToolLabelBinding.tool_type.in_(list(provider_types))
)
labels = list(_session.scalars(stmt).all())
tool_labels: dict[str, list[str]] = {label.tool_id: [] for label in labels}
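A hedged usage sketch of the new optional session parameter (controller and the label values are hypothetical): without a session the method opens and commits its own transaction, while passing a session keeps the label update inside the caller's transaction:

from sqlalchemy.orm import sessionmaker
from extensions.ext_database import db

# Standalone call: update_tool_labels manages its own transaction.
ToolLabelManager.update_tool_labels(controller, ["search", "productivity"])

# Caller-managed transaction: labels commit or roll back with the outer work.
with sessionmaker(db.engine).begin() as session:
    ToolLabelManager.update_tool_labels(controller, ["search"], session=session)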

View File

@@ -1,4 +1,4 @@
from typing import cast
from typing import Any, cast
from pydantic import BaseModel, Field
from sqlalchemy import select
@@ -39,7 +39,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
dataset_id: str
user_id: str | None = None
retrieve_config: DatasetRetrieveConfigEntity
inputs: dict
inputs: dict[str, Any]
@classmethod
def from_dataset(cls, dataset: Dataset, **kwargs):

View File

@@ -277,7 +277,7 @@ class WorkflowTool(Tool):
session.expunge(app)
return app
def _transform_args(self, tool_parameters: dict[str, Any]) -> tuple[dict[str, Any], list[dict[str, Any]]]:
def _transform_args(self, tool_parameters: dict[str, Any]) -> tuple[dict[str, Any], list[dict[str, str | None]]]:
"""
transform the tool parameters
@@ -355,7 +355,7 @@ class WorkflowTool(Tool):
return result, files
def _update_file_mapping(self, file_dict: dict[str, Any]):
def _update_file_mapping(self, file_dict: dict[str, Any]) -> dict[str, Any]:
file_id = resolve_file_record_id(file_dict.get("reference") or file_dict.get("related_id"))
transfer_method = FileTransferMethod.value_of(file_dict.get("transfer_method"))
match transfer_method:

View File

@@ -329,7 +329,7 @@ class EnterpriseMetricHandler:
return
include_content = exporter.include_content
attrs: dict = {
attrs: dict[str, Any] = {
"dify.message.id": payload.get("message_id"),
"dify.tenant_id": envelope.tenant_id,
"dify.event.id": envelope.event_id,

View File

@@ -9,6 +9,7 @@ from typing_extensions import TypedDict
from configs import dify_config
from dify_app import DifyApp
from extensions.redis_names import normalize_redis_key_prefix
class _CelerySentinelKwargsDict(TypedDict):
@@ -16,9 +17,10 @@ class _CelerySentinelKwargsDict(TypedDict):
password: str | None
class CelerySentinelTransportDict(TypedDict):
class CelerySentinelTransportDict(TypedDict, total=False):
master_name: str | None
sentinel_kwargs: _CelerySentinelKwargsDict
global_keyprefix: str
class CelerySSLOptionsDict(TypedDict):
@@ -61,15 +63,31 @@ def get_celery_ssl_options() -> CelerySSLOptionsDict | None:
def get_celery_broker_transport_options() -> CelerySentinelTransportDict | dict[str, Any]:
"""Get broker transport options (e.g. Redis Sentinel) for Celery connections."""
transport_options: CelerySentinelTransportDict | dict[str, Any]
if dify_config.CELERY_USE_SENTINEL:
return CelerySentinelTransportDict(
transport_options = CelerySentinelTransportDict(
master_name=dify_config.CELERY_SENTINEL_MASTER_NAME,
sentinel_kwargs=_CelerySentinelKwargsDict(
socket_timeout=dify_config.CELERY_SENTINEL_SOCKET_TIMEOUT,
password=dify_config.CELERY_SENTINEL_PASSWORD,
),
)
return {}
else:
transport_options = {}
global_keyprefix = get_celery_redis_global_keyprefix()
if global_keyprefix:
transport_options["global_keyprefix"] = global_keyprefix
return transport_options
def get_celery_redis_global_keyprefix() -> str | None:
"""Return the Redis transport prefix for Celery when namespace isolation is enabled."""
normalized_prefix = normalize_redis_key_prefix(dify_config.REDIS_KEY_PREFIX)
if not normalized_prefix:
return None
return f"{normalized_prefix}:"
def init_app(app: DifyApp) -> Celery:
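A minimal sketch of the resulting broker options, assuming REDIS_KEY_PREFIX="dify" and Sentinel disabled; with no prefix configured the function returns an empty dict and Celery's Redis keys stay unprefixed:

# Illustrative only; values depend on dify_config at runtime.
options = get_celery_broker_transport_options()
# REDIS_KEY_PREFIX="dify"  -> {"global_keyprefix": "dify:"}
# REDIS_KEY_PREFIX unset   -> {}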

View File

@@ -3,7 +3,7 @@ import logging
import ssl
from collections.abc import Callable
from datetime import timedelta
from typing import TYPE_CHECKING, Any, Union
from typing import Any, Union, cast
import redis
from redis import RedisError
@@ -18,17 +18,26 @@ from typing_extensions import TypedDict
from configs import dify_config
from dify_app import DifyApp
from extensions.redis_names import (
normalize_redis_key_prefix,
serialize_redis_name,
serialize_redis_name_arg,
serialize_redis_name_args,
)
from libs.broadcast_channel.channel import BroadcastChannel as BroadcastChannelProtocol
from libs.broadcast_channel.redis.channel import BroadcastChannel as RedisBroadcastChannel
from libs.broadcast_channel.redis.sharded_channel import ShardedRedisBroadcastChannel
from libs.broadcast_channel.redis.streams_channel import StreamsBroadcastChannel
if TYPE_CHECKING:
from redis.lock import Lock
logger = logging.getLogger(__name__)
_normalize_redis_key_prefix = normalize_redis_key_prefix
_serialize_redis_name = serialize_redis_name
_serialize_redis_name_arg = serialize_redis_name_arg
_serialize_redis_name_args = serialize_redis_name_args
class RedisClientWrapper:
"""
A wrapper class for the Redis client that addresses the issue where the global
@@ -59,68 +68,148 @@ class RedisClientWrapper:
if self._client is None:
self._client = client
if TYPE_CHECKING:
# Type hints for IDE support and static analysis
# These are not executed at runtime but provide type information
def get(self, name: str | bytes) -> Any: ...
def set(
self,
name: str | bytes,
value: Any,
ex: int | None = None,
px: int | None = None,
nx: bool = False,
xx: bool = False,
keepttl: bool = False,
get: bool = False,
exat: int | None = None,
pxat: int | None = None,
) -> Any: ...
def setex(self, name: str | bytes, time: int | timedelta, value: Any) -> Any: ...
def setnx(self, name: str | bytes, value: Any) -> Any: ...
def delete(self, *names: str | bytes) -> Any: ...
def incr(self, name: str | bytes, amount: int = 1) -> Any: ...
def expire(
self,
name: str | bytes,
time: int | timedelta,
nx: bool = False,
xx: bool = False,
gt: bool = False,
lt: bool = False,
) -> Any: ...
def lock(
self,
name: str,
timeout: float | None = None,
sleep: float = 0.1,
blocking: bool = True,
blocking_timeout: float | None = None,
thread_local: bool = True,
) -> Lock: ...
def zadd(
self,
name: str | bytes,
mapping: dict[str | bytes | int | float, float | int | str | bytes],
nx: bool = False,
xx: bool = False,
ch: bool = False,
incr: bool = False,
gt: bool = False,
lt: bool = False,
) -> Any: ...
def zremrangebyscore(self, name: str | bytes, min: float | str, max: float | str) -> Any: ...
def zcard(self, name: str | bytes) -> Any: ...
def getdel(self, name: str | bytes) -> Any: ...
def pubsub(self) -> PubSub: ...
def pipeline(self, transaction: bool = True, shard_hint: str | None = None) -> Any: ...
def __getattr__(self, item: str) -> Any:
def _require_client(self) -> redis.Redis | RedisCluster:
if self._client is None:
raise RuntimeError("Redis client is not initialized. Call init_app first.")
return getattr(self._client, item)
return self._client
def _get_prefix(self) -> str:
return dify_config.REDIS_KEY_PREFIX
def get(self, name: str | bytes) -> Any:
return self._require_client().get(_serialize_redis_name_arg(name, self._get_prefix()))
def set(
self,
name: str | bytes,
value: Any,
ex: int | None = None,
px: int | None = None,
nx: bool = False,
xx: bool = False,
keepttl: bool = False,
get: bool = False,
exat: int | None = None,
pxat: int | None = None,
) -> Any:
return self._require_client().set(
_serialize_redis_name_arg(name, self._get_prefix()),
value,
ex=ex,
px=px,
nx=nx,
xx=xx,
keepttl=keepttl,
get=get,
exat=exat,
pxat=pxat,
)
def setex(self, name: str | bytes, time: int | timedelta, value: Any) -> Any:
return self._require_client().setex(_serialize_redis_name_arg(name, self._get_prefix()), time, value)
def setnx(self, name: str | bytes, value: Any) -> Any:
return self._require_client().setnx(_serialize_redis_name_arg(name, self._get_prefix()), value)
def delete(self, *names: str | bytes) -> Any:
return self._require_client().delete(*_serialize_redis_name_args(names, self._get_prefix()))
def incr(self, name: str | bytes, amount: int = 1) -> Any:
return self._require_client().incr(_serialize_redis_name_arg(name, self._get_prefix()), amount)
def expire(
self,
name: str | bytes,
time: int | timedelta,
nx: bool = False,
xx: bool = False,
gt: bool = False,
lt: bool = False,
) -> Any:
return self._require_client().expire(
_serialize_redis_name_arg(name, self._get_prefix()),
time,
nx=nx,
xx=xx,
gt=gt,
lt=lt,
)
def exists(self, *names: str | bytes) -> Any:
return self._require_client().exists(*_serialize_redis_name_args(names, self._get_prefix()))
def ttl(self, name: str | bytes) -> Any:
return self._require_client().ttl(_serialize_redis_name_arg(name, self._get_prefix()))
def getdel(self, name: str | bytes) -> Any:
return self._require_client().getdel(_serialize_redis_name_arg(name, self._get_prefix()))
def lock(
self,
name: str,
timeout: float | None = None,
sleep: float = 0.1,
blocking: bool = True,
blocking_timeout: float | None = None,
thread_local: bool = True,
) -> Any:
return self._require_client().lock(
_serialize_redis_name(name, self._get_prefix()),
timeout=timeout,
sleep=sleep,
blocking=blocking,
blocking_timeout=blocking_timeout,
thread_local=thread_local,
)
def hset(self, name: str | bytes, *args: Any, **kwargs: Any) -> Any:
return self._require_client().hset(_serialize_redis_name_arg(name, self._get_prefix()), *args, **kwargs)
def hgetall(self, name: str | bytes) -> Any:
return self._require_client().hgetall(_serialize_redis_name_arg(name, self._get_prefix()))
def hdel(self, name: str | bytes, *keys: str | bytes) -> Any:
return self._require_client().hdel(_serialize_redis_name_arg(name, self._get_prefix()), *keys)
def hlen(self, name: str | bytes) -> Any:
return self._require_client().hlen(_serialize_redis_name_arg(name, self._get_prefix()))
def zadd(
self,
name: str | bytes,
mapping: dict[str | bytes | int | float, float | int | str | bytes],
nx: bool = False,
xx: bool = False,
ch: bool = False,
incr: bool = False,
gt: bool = False,
lt: bool = False,
) -> Any:
return self._require_client().zadd(
_serialize_redis_name_arg(name, self._get_prefix()),
cast(Any, mapping),
nx=nx,
xx=xx,
ch=ch,
incr=incr,
gt=gt,
lt=lt,
)
def zremrangebyscore(self, name: str | bytes, min: float | str, max: float | str) -> Any:
return self._require_client().zremrangebyscore(_serialize_redis_name_arg(name, self._get_prefix()), min, max)
def zcard(self, name: str | bytes) -> Any:
return self._require_client().zcard(_serialize_redis_name_arg(name, self._get_prefix()))
def pubsub(self) -> PubSub:
return self._require_client().pubsub()
def pipeline(self, transaction: bool = True, shard_hint: str | None = None) -> Any:
return self._require_client().pipeline(transaction=transaction, shard_hint=shard_hint)
def __getattr__(self, item: str) -> Any:
return getattr(self._require_client(), item)
redis_client: RedisClientWrapper = RedisClientWrapper()

View File

@@ -0,0 +1,32 @@
from configs import dify_config
def normalize_redis_key_prefix(prefix: str | None) -> str:
"""Normalize the configured Redis key prefix for consistent runtime use."""
if prefix is None:
return ""
return prefix.strip()
def get_redis_key_prefix() -> str:
"""Read and normalize the current Redis key prefix from config."""
return normalize_redis_key_prefix(dify_config.REDIS_KEY_PREFIX)
def serialize_redis_name(name: str, prefix: str | None = None) -> str:
"""Convert a logical Redis name into the physical name used in Redis."""
normalized_prefix = get_redis_key_prefix() if prefix is None else normalize_redis_key_prefix(prefix)
if not normalized_prefix:
return name
return f"{normalized_prefix}:{name}"
def serialize_redis_name_arg(name: str | bytes, prefix: str | None = None) -> str | bytes:
"""Prefix string Redis names while preserving bytes inputs unchanged."""
if isinstance(name, bytes):
return name
return serialize_redis_name(name, prefix)
def serialize_redis_name_args(names: tuple[str | bytes, ...], prefix: str | None = None) -> tuple[str | bytes, ...]:
return tuple(serialize_redis_name_arg(name, prefix) for name in names)
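A quick sketch of the helpers above with the prefix passed explicitly (key names here are illustrative):

assert serialize_redis_name("setup_status", "dify") == "dify:setup_status"
assert serialize_redis_name("setup_status", "") == "setup_status"
# Bytes names pass through untouched, so pre-encoded keys keep their exact form.
assert serialize_redis_name_arg(b"raw-key", "dify") == b"raw-key"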

View File

@@ -2,6 +2,7 @@ import logging
import os
from collections.abc import Generator
from pathlib import Path
from typing import Any
import opendal
from dotenv import dotenv_values
@@ -19,7 +20,7 @@ def _get_opendal_kwargs(*, scheme: str, env_file_path: str = ".env", prefix: str
if key.startswith(config_prefix):
kwargs[key[len(config_prefix) :].lower()] = value
file_env_vars: dict = dotenv_values(env_file_path) or {}
file_env_vars: dict[str, Any] = dotenv_values(env_file_path) or {}
for key, value in file_env_vars.items():
if key.startswith(config_prefix) and key[len(config_prefix) :].lower() not in kwargs and value:
kwargs[key[len(config_prefix) :].lower()] = value

View File

@@ -2,6 +2,7 @@ from __future__ import annotations
from typing import Any
from extensions.redis_names import serialize_redis_name
from libs.broadcast_channel.channel import Producer, Subscriber, Subscription
from redis import Redis, RedisCluster
@@ -32,12 +33,13 @@ class Topic:
def __init__(self, redis_client: Redis | RedisCluster, topic: str):
self._client = redis_client
self._topic = topic
self._redis_topic = serialize_redis_name(topic)
def as_producer(self) -> Producer:
return self
def publish(self, payload: bytes) -> None:
self._client.publish(self._topic, payload)
self._client.publish(self._redis_topic, payload)
def as_subscriber(self) -> Subscriber:
return self
@@ -46,7 +48,7 @@ class Topic:
return _RedisSubscription(
client=self._client,
pubsub=self._client.pubsub(),
topic=self._topic,
topic=self._redis_topic,
)

View File

@@ -2,6 +2,7 @@ from __future__ import annotations
from typing import Any
from extensions.redis_names import serialize_redis_name
from libs.broadcast_channel.channel import Producer, Subscriber, Subscription
from redis import Redis, RedisCluster
@@ -30,12 +31,13 @@ class ShardedTopic:
def __init__(self, redis_client: Redis | RedisCluster, topic: str):
self._client = redis_client
self._topic = topic
self._redis_topic = serialize_redis_name(topic)
def as_producer(self) -> Producer:
return self
def publish(self, payload: bytes) -> None:
self._client.spublish(self._topic, payload) # type: ignore[attr-defined,union-attr]
self._client.spublish(self._redis_topic, payload) # type: ignore[attr-defined,union-attr]
def as_subscriber(self) -> Subscriber:
return self
@@ -44,7 +46,7 @@ class ShardedTopic:
return _RedisShardedSubscription(
client=self._client,
pubsub=self._client.pubsub(),
topic=self._topic,
topic=self._redis_topic,
)

View File

@@ -6,6 +6,7 @@ import threading
from collections.abc import Iterator
from typing import Self
from extensions.redis_names import serialize_redis_name
from libs.broadcast_channel.channel import Producer, Subscriber, Subscription
from libs.broadcast_channel.exc import SubscriptionClosedError
from redis import Redis, RedisCluster
@@ -35,7 +36,7 @@ class StreamsTopic:
def __init__(self, redis_client: Redis | RedisCluster, topic: str, *, retention_seconds: int = 600):
self._client = redis_client
self._topic = topic
self._key = f"stream:{topic}"
self._key = serialize_redis_name(f"stream:{topic}")
self._retention_seconds = retention_seconds
self.max_length = 5000
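
For the streams-backed channel this means the physical stream key is prefixed as well; a small sketch under the same illustrative "dify" prefix ("app-events" is a made-up topic name):

    topic = StreamsTopic(redis_client, "app-events")
    # topic._key == serialize_redis_name("stream:app-events") -> "dify:stream:app-events"
    # With REDIS_KEY_PREFIX left empty, the key remains "stream:app-events", i.e. the previous unprefixed behavior.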

View File

@@ -103,7 +103,10 @@ class DbMigrationAutoRenewLock:
timeout=self._ttl_seconds,
thread_local=False,
)
acquired = bool(self._lock.acquire(*args, **kwargs))
lock = self._lock
if lock is None:
raise RuntimeError("Redis lock initialization failed.")
acquired = bool(lock.acquire(*args, **kwargs))
self._acquired = acquired
if acquired:
self._start_heartbeat()

View File

@@ -120,10 +120,22 @@ class AppIconUrlField(fields.Raw):
obj = obj["app"]
if isinstance(obj, App | Site) and obj.icon_type == IconType.IMAGE:
return file_helpers.get_signed_file_url(obj.icon)
return build_icon_url(obj.icon_type, obj.icon)
return None
def build_icon_url(icon_type: Any, icon: str | None) -> str | None:
if icon is None or icon_type is None:
return None
from models.model import IconType
icon_type_value = icon_type.value if isinstance(icon_type, IconType) else str(icon_type)
if icon_type_value.lower() != IconType.IMAGE:
return None
return file_helpers.get_signed_file_url(icon)
class AvatarUrlField(fields.Raw):
def output(self, key, obj, **kwargs):
if obj is None:
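
A short sketch of how the extracted build_icon_url helper behaves; the icon values are placeholders:

    build_icon_url(IconType.IMAGE, "icon-file-id")  # signed file URL via file_helpers.get_signed_file_url
    build_icon_url("emoji", "robot")                 # None: only image-type icons resolve to a URL
    build_icon_url(None, "icon-file-id")             # None: a missing icon type short-circuits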

View File

@@ -1551,7 +1551,7 @@ class PipelineBuiltInTemplate(TypeBase):
name: Mapped[str] = mapped_column(sa.String(255), nullable=False)
description: Mapped[str] = mapped_column(LongText, nullable=False)
chunk_structure: Mapped[str] = mapped_column(sa.String(255), nullable=False)
icon: Mapped[dict] = mapped_column(sa.JSON, nullable=False)
icon: Mapped[dict[str, Any]] = mapped_column(sa.JSON, nullable=False)
yaml_content: Mapped[str] = mapped_column(LongText, nullable=False)
copyright: Mapped[str] = mapped_column(sa.String(255), nullable=False)
privacy_policy: Mapped[str] = mapped_column(sa.String(255), nullable=False)
@@ -1584,7 +1584,7 @@ class PipelineCustomizedTemplate(TypeBase):
name: Mapped[str] = mapped_column(sa.String(255), nullable=False)
description: Mapped[str] = mapped_column(LongText, nullable=False)
chunk_structure: Mapped[str] = mapped_column(sa.String(255), nullable=False)
icon: Mapped[dict] = mapped_column(sa.JSON, nullable=False)
icon: Mapped[dict[str, Any]] = mapped_column(sa.JSON, nullable=False)
position: Mapped[int] = mapped_column(sa.Integer, nullable=False)
yaml_content: Mapped[str] = mapped_column(LongText, nullable=False)
install_count: Mapped[int] = mapped_column(sa.Integer, nullable=False)
@@ -1657,7 +1657,7 @@ class DocumentPipelineExecutionLog(TypeBase):
datasource_type: Mapped[str] = mapped_column(sa.String(255), nullable=False)
datasource_info: Mapped[str] = mapped_column(LongText, nullable=False)
datasource_node_id: Mapped[str] = mapped_column(sa.String(255), nullable=False)
input_data: Mapped[dict] = mapped_column(sa.JSON, nullable=False)
input_data: Mapped[dict[str, Any]] = mapped_column(sa.JSON, nullable=False)
created_by: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
created_at: Mapped[datetime] = mapped_column(
sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False

View File

@@ -1007,7 +1007,7 @@ class OAuthProviderApp(TypeBase):
app_icon: Mapped[str] = mapped_column(String(255), nullable=False)
client_id: Mapped[str] = mapped_column(String(255), nullable=False)
client_secret: Mapped[str] = mapped_column(String(255), nullable=False)
app_label: Mapped[dict] = mapped_column(sa.JSON, nullable=False, default_factory=dict)
app_label: Mapped[dict[str, Any]] = mapped_column(sa.JSON, nullable=False, default_factory=dict)
redirect_uris: Mapped[list] = mapped_column(sa.JSON, nullable=False, default_factory=list)
scope: Mapped[str] = mapped_column(
String(255),
@@ -2495,7 +2495,7 @@ class TraceAppConfig(TypeBase):
)
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
tracing_provider: Mapped[str | None] = mapped_column(String(255), nullable=True)
tracing_config: Mapped[dict | None] = mapped_column(sa.JSON, nullable=True)
tracing_config: Mapped[dict[str, Any] | None] = mapped_column(sa.JSON, nullable=True)
created_at: Mapped[datetime] = mapped_column(
sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False
)

View File

@@ -1,4 +1,5 @@
from datetime import datetime
from typing import Any
import sqlalchemy as sa
from sqlalchemy import func
@@ -22,7 +23,7 @@ class DatasourceOauthParamConfig(TypeBase):
)
plugin_id: Mapped[str] = mapped_column(sa.String(255), nullable=False)
provider: Mapped[str] = mapped_column(sa.String(255), nullable=False)
system_credentials: Mapped[dict] = mapped_column(AdjustedJSON, nullable=False)
system_credentials: Mapped[dict[str, Any]] = mapped_column(AdjustedJSON, nullable=False)
class DatasourceProvider(TypeBase):
@@ -40,7 +41,7 @@ class DatasourceProvider(TypeBase):
provider: Mapped[str] = mapped_column(sa.String(128), nullable=False)
plugin_id: Mapped[str] = mapped_column(sa.String(255), nullable=False)
auth_type: Mapped[str] = mapped_column(sa.String(255), nullable=False)
encrypted_credentials: Mapped[dict] = mapped_column(AdjustedJSON, nullable=False)
encrypted_credentials: Mapped[dict[str, Any]] = mapped_column(AdjustedJSON, nullable=False)
avatar_url: Mapped[str] = mapped_column(LongText, nullable=True, default="default")
is_default: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"), default=False)
expires_at: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default="-1", default=-1)
@@ -70,7 +71,7 @@ class DatasourceOauthTenantParamConfig(TypeBase):
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
provider: Mapped[str] = mapped_column(sa.String(255), nullable=False)
plugin_id: Mapped[str] = mapped_column(sa.String(255), nullable=False)
client_params: Mapped[dict] = mapped_column(AdjustedJSON, nullable=False, default_factory=dict)
client_params: Mapped[dict[str, Any]] = mapped_column(AdjustedJSON, nullable=False, default_factory=dict)
enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, default=False)
created_at: Mapped[datetime] = mapped_column(

View File

@@ -25,7 +25,7 @@ class DataSourceOauthBinding(TypeBase):
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
access_token: Mapped[str] = mapped_column(String(255), nullable=False)
provider: Mapped[str] = mapped_column(String(255), nullable=False)
source_info: Mapped[dict] = mapped_column(AdjustedJSON, nullable=False)
source_info: Mapped[dict[str, Any]] = mapped_column(AdjustedJSON, nullable=False)
created_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=func.current_timestamp(), init=False
)

View File

@@ -23,7 +23,7 @@ class ElasticSearchJaVector(ElasticSearchVector):
self,
embeddings: list[list[float]],
metadatas: list[dict[Any, Any]] | None = None,
index_params: dict | None = None,
index_params: dict[str, Any] | None = None,
):
lock_name = f"vector_indexing_lock_{self._collection_name}"
with redis_client.lock(lock_name, timeout=20):

View File

@@ -1,3 +1,4 @@
from typing import Any
from uuid import UUID
from numpy import ndarray
@@ -8,5 +9,5 @@ class CollectionORM(DeclarativeBase):
__tablename__: str
id: Mapped[UUID]
text: Mapped[str]
meta: Mapped[dict]
meta: Mapped[dict[str, Any]]
vector: Mapped[ndarray]

View File

@@ -67,7 +67,7 @@ class PGVectoRS(BaseVector):
primary_key=True,
)
text: Mapped[str]
meta: Mapped[dict] = mapped_column(postgresql.JSONB)
meta: Mapped[dict[str, Any]] = mapped_column(postgresql.JSONB)
vector: Mapped[ndarray] = mapped_column(VECTOR(dim))
self._table = _Table

View File

@@ -136,6 +136,9 @@ dify-vdb-weaviate = { workspace = true }
[tool.uv]
default-groups = ["storage", "tools", "vdb-all"]
package = false
override-dependencies = [
"pyarrow>=18.0.0",
]
[dependency-groups]

View File

@@ -87,7 +87,7 @@ class RetrievalModel(BaseModel):
class MetaDataConfig(BaseModel):
doc_type: str
doc_metadata: dict
doc_metadata: dict[str, Any]
class KnowledgeConfig(BaseModel):

View File

@@ -1,4 +1,5 @@
import logging
from typing import Any
from graphon.model_runtime.entities.model_entities import ModelType, ParameterRule
@@ -168,7 +169,9 @@ class ModelProviderService:
model_name=model,
)
def get_provider_credential(self, tenant_id: str, provider: str, credential_id: str | None = None) -> dict | None:
def get_provider_credential(
self, tenant_id: str, provider: str, credential_id: str | None = None
) -> dict[str, Any] | None:
"""
get provider credentials.
@@ -180,7 +183,7 @@ class ModelProviderService:
provider_configuration = self._get_provider_configuration(tenant_id, provider)
return provider_configuration.get_provider_credential(credential_id=credential_id)
def validate_provider_credentials(self, tenant_id: str, provider: str, credentials: dict):
def validate_provider_credentials(self, tenant_id: str, provider: str, credentials: dict[str, Any]):
"""
validate provider credentials before saving.
@@ -192,7 +195,7 @@ class ModelProviderService:
provider_configuration.validate_provider_credentials(credentials)
def create_provider_credential(
self, tenant_id: str, provider: str, credentials: dict, credential_name: str | None
self, tenant_id: str, provider: str, credentials: dict[str, Any], credential_name: str | None
) -> None:
"""
Create and save new provider credentials.
@@ -210,7 +213,7 @@ class ModelProviderService:
self,
tenant_id: str,
provider: str,
credentials: dict,
credentials: dict[str, Any],
credential_id: str,
credential_name: str | None,
) -> None:
@@ -254,7 +257,7 @@ class ModelProviderService:
def get_model_credential(
self, tenant_id: str, provider: str, model_type: str, model: str, credential_id: str | None
) -> dict | None:
) -> dict[str, Any] | None:
"""
Retrieve model-specific credentials.
@@ -270,7 +273,9 @@ class ModelProviderService:
model_type=ModelType.value_of(model_type), model=model, credential_id=credential_id
)
def validate_model_credentials(self, tenant_id: str, provider: str, model_type: str, model: str, credentials: dict):
def validate_model_credentials(
self, tenant_id: str, provider: str, model_type: str, model: str, credentials: dict[str, Any]
):
"""
validate model credentials.
@@ -287,7 +292,13 @@ class ModelProviderService:
)
def create_model_credential(
self, tenant_id: str, provider: str, model_type: str, model: str, credentials: dict, credential_name: str | None
self,
tenant_id: str,
provider: str,
model_type: str,
model: str,
credentials: dict[str, Any],
credential_name: str | None,
) -> None:
"""
create and save model credentials.
@@ -314,7 +325,7 @@ class ModelProviderService:
provider: str,
model_type: str,
model: str,
credentials: dict,
credentials: dict[str, Any],
credential_id: str,
credential_name: str | None,
) -> None:

View File

@@ -104,7 +104,7 @@ class RagPipelineService:
self._workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
@classmethod
def get_pipeline_templates(cls, type: str = "built-in", language: str = "en-US") -> dict:
def get_pipeline_templates(cls, type: str = "built-in", language: str = "en-US") -> dict[str, Any]:
if type == "built-in":
mode = dify_config.HOSTED_FETCH_PIPELINE_TEMPLATES_MODE
retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)()
@@ -120,7 +120,7 @@ class RagPipelineService:
return result
@classmethod
def get_pipeline_template_detail(cls, template_id: str, type: str = "built-in") -> dict | None:
def get_pipeline_template_detail(cls, template_id: str, type: str = "built-in") -> dict[str, Any] | None:
"""
Get pipeline template detail.
@@ -131,7 +131,7 @@ class RagPipelineService:
if type == "built-in":
mode = dify_config.HOSTED_FETCH_PIPELINE_TEMPLATES_MODE
retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)()
built_in_result: dict | None = retrieval_instance.get_pipeline_template_detail(template_id)
built_in_result: dict[str, Any] | None = retrieval_instance.get_pipeline_template_detail(template_id)
if built_in_result is None:
logger.warning(
"pipeline template retrieval returned empty result, template_id: %s, mode: %s",
@@ -142,7 +142,7 @@ class RagPipelineService:
else:
mode = "customized"
retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)()
customized_result: dict | None = retrieval_instance.get_pipeline_template_detail(template_id)
customized_result: dict[str, Any] | None = retrieval_instance.get_pipeline_template_detail(template_id)
return customized_result
@classmethod
@@ -297,7 +297,7 @@ class RagPipelineService:
self,
*,
pipeline: Pipeline,
graph: dict,
graph: dict[str, Any],
unique_hash: str | None,
account: Account,
environment_variables: Sequence[VariableBase],
@@ -467,7 +467,9 @@ class RagPipelineService:
return default_block_configs
def get_default_block_config(self, node_type: str, filters: dict | None = None) -> Mapping[str, object] | None:
def get_default_block_config(
self, node_type: str, filters: dict[str, Any] | None = None
) -> Mapping[str, object] | None:
"""
Get default config of node.
:param node_type: node type
@@ -500,7 +502,7 @@ class RagPipelineService:
return default_config
def run_draft_workflow_node(
self, pipeline: Pipeline, node_id: str, user_inputs: dict, account: Account
self, pipeline: Pipeline, node_id: str, user_inputs: dict[str, Any], account: Account
) -> WorkflowNodeExecutionModel | None:
"""
Run draft workflow node
@@ -582,7 +584,7 @@ class RagPipelineService:
self,
pipeline: Pipeline,
node_id: str,
user_inputs: dict,
user_inputs: dict[str, Any],
account: Account,
datasource_type: str,
is_published: bool,
@@ -749,7 +751,7 @@ class RagPipelineService:
self,
pipeline: Pipeline,
node_id: str,
user_inputs: dict,
user_inputs: dict[str, Any],
account: Account,
datasource_type: str,
is_published: bool,
@@ -979,7 +981,7 @@ class RagPipelineService:
return workflow_node_execution
def update_workflow(
self, *, session: Session, workflow_id: str, tenant_id: str, account_id: str, data: dict
self, *, session: Session, workflow_id: str, tenant_id: str, account_id: str, data: dict[str, Any]
) -> Workflow | None:
"""
Update workflow attributes
@@ -1099,7 +1101,9 @@ class RagPipelineService:
]
return datasource_provider_variables
def get_rag_pipeline_paginate_workflow_runs(self, pipeline: Pipeline, args: dict) -> InfiniteScrollPagination:
def get_rag_pipeline_paginate_workflow_runs(
self, pipeline: Pipeline, args: dict[str, Any]
) -> InfiniteScrollPagination:
"""
Get debug workflow run list
Only return triggered_from == debugging
@@ -1169,7 +1173,7 @@ class RagPipelineService:
return list(node_executions)
@classmethod
def publish_customized_pipeline_template(cls, pipeline_id: str, args: dict):
def publish_customized_pipeline_template(cls, pipeline_id: str, args: dict[str, Any]):
"""
Publish customized pipeline template
"""
@@ -1259,7 +1263,7 @@ class RagPipelineService:
)
return node_exec
def set_datasource_variables(self, pipeline: Pipeline, args: dict, current_user: Account):
def set_datasource_variables(self, pipeline: Pipeline, args: dict[str, Any], current_user: Account):
"""
Set datasource variables
"""
@@ -1346,7 +1350,7 @@ class RagPipelineService:
)
return workflow_node_execution_db_model
def get_recommended_plugins(self, type: str) -> dict:
def get_recommended_plugins(self, type: str) -> dict[str, Any]:
# Query active recommended plugins
stmt = select(PipelineRecommendedPlugin).where(PipelineRecommendedPlugin.active == True)
if type and type != "all":

View File

@@ -139,62 +139,82 @@ class WorkflowToolManageService:
:param labels: labels
:return: the updated tool
"""
# check if the name is unique
existing_workflow_tool_provider = db.session.scalar(
select(WorkflowToolProvider)
.where(
WorkflowToolProvider.tenant_id == tenant_id,
WorkflowToolProvider.name == name,
WorkflowToolProvider.id != workflow_tool_id,
)
.limit(1)
)
existing_workflow_tool_provider: WorkflowToolProvider | None = None
with sessionmaker(db.engine, expire_on_commit=False).begin() as _session:
# query if the name exists for other tools
existing_workflow_tool_provider = _session.scalar(
select(WorkflowToolProvider)
.where(
WorkflowToolProvider.tenant_id == tenant_id,
WorkflowToolProvider.name == name,
WorkflowToolProvider.id != workflow_tool_id,
)
.limit(1)
)
# if the name exists raise error
if existing_workflow_tool_provider is not None:
raise ValueError(f"Tool with name {name} already exists")
workflow_tool_provider: WorkflowToolProvider | None = db.session.scalar(
select(WorkflowToolProvider)
.where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id)
.limit(1)
)
# query the workflow tool provider
workflow_tool_provider: WorkflowToolProvider | None = None
with sessionmaker(db.engine, expire_on_commit=False).begin() as _session:
workflow_tool_provider = _session.scalar(
select(WorkflowToolProvider)
.where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id)
.limit(1)
)
# if not found raise error
if workflow_tool_provider is None:
raise ValueError(f"Tool {workflow_tool_id} not found")
app: App | None = db.session.scalar(
select(App).where(App.id == workflow_tool_provider.app_id, App.tenant_id == tenant_id).limit(1)
)
# query the app
app: App | None = None
with sessionmaker(db.engine, expire_on_commit=False).begin() as _session:
app = _session.scalar(
select(App).where(App.id == workflow_tool_provider.app_id, App.tenant_id == tenant_id).limit(1)
)
# if not found raise error
if app is None:
raise ValueError(f"App {workflow_tool_provider.app_id} not found")
# query the workflow
workflow: Workflow | None = app.workflow
# if not found raise error
if workflow is None:
raise ValueError(f"Workflow not found for app {workflow_tool_provider.app_id}")
# check if workflow configuration is synced
WorkflowToolConfigurationUtils.ensure_no_human_input_nodes(workflow.graph_dict)
workflow_tool_provider.name = name
workflow_tool_provider.label = label
workflow_tool_provider.icon = json.dumps(icon)
workflow_tool_provider.description = description
workflow_tool_provider.parameter_configuration = json.dumps([p.model_dump() for p in parameters])
workflow_tool_provider.privacy_policy = privacy_policy
workflow_tool_provider.version = workflow.version
workflow_tool_provider.updated_at = datetime.now()
with sessionmaker(db.engine).begin() as _session:
_session.add(workflow_tool_provider)
try:
WorkflowToolProviderController.from_db(workflow_tool_provider)
except Exception as e:
raise ValueError(str(e))
# update workflow tool provider
workflow_tool_provider.name = name
workflow_tool_provider.label = label
workflow_tool_provider.icon = json.dumps(icon)
workflow_tool_provider.description = description
workflow_tool_provider.parameter_configuration = json.dumps([p.model_dump() for p in parameters])
workflow_tool_provider.privacy_policy = privacy_policy
workflow_tool_provider.version = workflow.version
workflow_tool_provider.updated_at = datetime.now()
db.session.commit()
try:
WorkflowToolProviderController.from_db(workflow_tool_provider)
except Exception as e:
raise ValueError(str(e))
if labels is not None:
ToolLabelManager.update_tool_labels(
ToolTransformService.workflow_provider_to_controller(workflow_tool_provider), labels
)
if labels is not None:
ToolLabelManager.update_tool_labels(
ToolTransformService.workflow_provider_to_controller(workflow_tool_provider),
labels,
session=_session,
)
return {"result": "success"}

View File

@@ -241,8 +241,8 @@ class WorkflowService:
self,
*,
app_model: App,
graph: dict,
features: dict,
graph: dict[str, Any],
features: dict[str, Any],
unique_hash: str | None,
account: Account,
environment_variables: Sequence[VariableBase],
@@ -576,7 +576,7 @@ class WorkflowService:
except Exception as e:
raise ValueError(f"Failed to validate default credential for tool provider {provider}: {str(e)}")
def _validate_load_balancing_credentials(self, workflow: Workflow, node_data: dict, node_id: str) -> None:
def _validate_load_balancing_credentials(self, workflow: Workflow, node_data: dict[str, Any], node_id: str) -> None:
"""
Validate load balancing credentials for a workflow node.
@@ -1214,7 +1214,7 @@ class WorkflowService:
return variable_pool
def run_free_workflow_node(
self, node_data: dict, tenant_id: str, user_id: str, node_id: str, user_inputs: dict[str, Any]
self, node_data: dict[str, Any], tenant_id: str, user_id: str, node_id: str, user_inputs: dict[str, Any]
) -> WorkflowNodeExecution:
"""
Run free workflow node
@@ -1361,7 +1361,7 @@ class WorkflowService:
node_execution.status = WorkflowNodeExecutionStatus.FAILED
node_execution.error = error
def convert_to_workflow(self, app_model: App, account: Account, args: dict) -> App:
def convert_to_workflow(self, app_model: App, account: Account, args: dict[str, Any]) -> App:
"""
Basic mode of chatbot app(expert mode) to workflow
Completion App to Workflow App
@@ -1421,7 +1421,7 @@ class WorkflowService:
if node_type == BuiltinNodeTypes.HUMAN_INPUT:
self._validate_human_input_node_data(node_data)
def validate_features_structure(self, app_model: App, features: dict):
def validate_features_structure(self, app_model: App, features: dict[str, Any]):
match app_model.mode:
case AppMode.ADVANCED_CHAT:
return AdvancedChatAppConfigManager.config_validate(
@@ -1434,7 +1434,7 @@ class WorkflowService:
case _:
raise ValueError(f"Invalid app mode: {app_model.mode}")
def _validate_human_input_node_data(self, node_data: dict) -> None:
def _validate_human_input_node_data(self, node_data: dict[str, Any]) -> None:
"""
Validate HumanInput node data format.
@@ -1452,7 +1452,7 @@ class WorkflowService:
raise ValueError(f"Invalid HumanInput node data: {str(e)}")
def update_workflow(
self, *, session: Session, workflow_id: str, tenant_id: str, account_id: str, data: dict
self, *, session: Session, workflow_id: str, tenant_id: str, account_id: str, data: dict[str, Any]
) -> Workflow | None:
"""
Update workflow attributes

View File

@@ -33,6 +33,7 @@ REDIS_USERNAME=
REDIS_PASSWORD=difyai123456
REDIS_USE_SSL=false
REDIS_DB=0
REDIS_KEY_PREFIX=
# PostgreSQL database configuration
DB_USERNAME=postgres

View File

@@ -432,7 +432,7 @@ class TestWorkflowAppLogEndpoints:
monkeypatch.setattr(workflow_app_log_module, "sessionmaker", DummySessionMaker)
def fake_get_paginate(self, **_kwargs):
return {"items": [], "total": 0}
return {"page": 1, "limit": 20, "total": 0, "has_more": False, "data": []}
monkeypatch.setattr(
workflow_app_log_module.WorkflowAppService,
@@ -443,7 +443,7 @@ class TestWorkflowAppLogEndpoints:
with app.test_request_context("/?page=1&limit=20"):
result = method(app_model=SimpleNamespace(id="app-1"))
assert result == {"items": [], "total": 0}
assert result == {"page": 1, "limit": 20, "total": 0, "has_more": False, "data": []}
class TestWorkflowDraftVariableEndpoints:
@@ -608,7 +608,8 @@ class TestWorkflowTriggerEndpoints:
with app.test_request_context("/?node_id=node-1"):
result = method(app_model=SimpleNamespace(id="app-1"))
assert result is trigger
assert isinstance(result, dict)
assert {"id", "webhook_id", "webhook_url", "webhook_debug_url", "node_id", "created_at"} <= set(result.keys())
class TestWrapsEndpoints:

View File

@@ -0,0 +1,73 @@
from datetime import datetime
from unittest.mock import patch
import pytest
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound
from controllers.console.app.conversation import _get_conversation
from models.enums import ConversationFromSource
from models.model import AppMode, Conversation
from tests.test_containers_integration_tests.controllers.console.helpers import (
create_console_account_and_tenant,
create_console_app,
)
def test_get_conversation_mark_read_keeps_updated_at_unchanged(
db_session_with_containers: Session,
):
account, tenant = create_console_account_and_tenant(db_session_with_containers)
app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT)
original_updated_at = datetime(2026, 2, 8, 0, 0, 0)
conversation = Conversation(
app_id=app.id,
name="read timestamp test",
inputs={},
status="normal",
mode=AppMode.CHAT,
from_source=ConversationFromSource.CONSOLE,
from_account_id=account.id,
updated_at=original_updated_at,
)
db_session_with_containers.add(conversation)
db_session_with_containers.commit()
read_at = datetime(2026, 2, 9, 0, 0, 0)
with (
patch(
"controllers.console.app.conversation.current_account_with_tenant",
return_value=(account, tenant.id),
autospec=True,
),
patch(
"controllers.console.app.conversation.naive_utc_now",
return_value=read_at,
autospec=True,
),
):
loaded = _get_conversation(app, conversation.id)
db_session_with_containers.refresh(conversation)
assert loaded.id == conversation.id
assert conversation.read_at == read_at
assert conversation.read_account_id == account.id
assert conversation.updated_at == original_updated_at
def test_get_conversation_raises_not_found_for_missing_conversation(
db_session_with_containers: Session,
):
account, tenant = create_console_account_and_tenant(db_session_with_containers)
app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT)
with patch(
"controllers.console.app.conversation.current_account_with_tenant",
return_value=(account, tenant.id),
autospec=True,
):
with pytest.raises(NotFound):
_get_conversation(app, "00000000-0000-0000-0000-000000000000")

View File

@@ -1,28 +1,48 @@
"""Unit tests for controllers.web.site endpoints."""
"""Testcontainers integration tests for controllers.web.site endpoints."""
from __future__ import annotations
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
from unittest.mock import patch
import pytest
from flask import Flask
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden
from controllers.web.site import AppSiteApi, AppSiteInfo
from models import Tenant, TenantStatus
from models.model import App, AppMode, CustomizeTokenStrategy, Site
def _tenant(*, status: str = "normal") -> SimpleNamespace:
return SimpleNamespace(
id="tenant-1",
status=status,
plan="basic",
custom_config_dict={"remove_webapp_brand": False, "replace_webapp_logo": False},
@pytest.fixture
def app(flask_app_with_containers) -> Flask:
return flask_app_with_containers
def _create_tenant(db_session: Session, *, status: TenantStatus = TenantStatus.NORMAL) -> Tenant:
tenant = Tenant(name="test-tenant", status=status)
db_session.add(tenant)
db_session.commit()
return tenant
def _create_app(db_session: Session, tenant_id: str, *, enable_site: bool = True) -> App:
app_model = App(
tenant_id=tenant_id,
mode=AppMode.CHAT,
name="test-app",
enable_site=enable_site,
enable_api=True,
)
db_session.add(app_model)
db_session.commit()
return app_model
def _site() -> SimpleNamespace:
return SimpleNamespace(
def _create_site(db_session: Session, app_id: str) -> Site:
site = Site(
app_id=app_id,
title="Site",
icon_type="emoji",
icon="robot",
@@ -31,77 +51,64 @@ def _site() -> SimpleNamespace:
default_language="en",
chat_color_theme="light",
chat_color_theme_inverted=False,
copyright=None,
privacy_policy=None,
custom_disclaimer=None,
customize_token_strategy=CustomizeTokenStrategy.NOT_ALLOW,
code=f"code-{app_id[-6:]}",
prompt_public=False,
show_workflow_steps=True,
use_icon_as_answer_icon=False,
)
db_session.add(site)
db_session.commit()
return site
# ---------------------------------------------------------------------------
# AppSiteApi
# ---------------------------------------------------------------------------
class TestAppSiteApi:
@patch("controllers.web.site.FeatureService.get_features")
@patch("controllers.web.site.db")
def test_happy_path(self, mock_db: MagicMock, mock_features: MagicMock, app: Flask) -> None:
def test_happy_path(self, mock_features, app: Flask, db_session_with_containers: Session) -> None:
app.config["RESTX_MASK_HEADER"] = "X-Fields"
mock_features.return_value = SimpleNamespace(can_replace_logo=False)
site_obj = _site()
mock_db.session.scalar.return_value = site_obj
tenant = _tenant()
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", tenant=tenant, enable_site=True)
tenant = _create_tenant(db_session_with_containers)
app_model = _create_app(db_session_with_containers, tenant.id)
_create_site(db_session_with_containers, app_model.id)
end_user = SimpleNamespace(id="eu-1")
mock_features.return_value = SimpleNamespace(can_replace_logo=False)
with app.test_request_context("/site"):
result = AppSiteApi().get(app_model, end_user)
# marshal_with serializes AppSiteInfo to a dict
assert result["app_id"] == "app-1"
assert result["app_id"] == app_model.id
assert result["plan"] == "basic"
assert result["enable_site"] is True
@patch("controllers.web.site.db")
def test_missing_site_raises_forbidden(self, mock_db: MagicMock, app: Flask) -> None:
def test_missing_site_raises_forbidden(self, app: Flask, db_session_with_containers: Session) -> None:
app.config["RESTX_MASK_HEADER"] = "X-Fields"
mock_db.session.scalar.return_value = None
tenant = _tenant()
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", tenant=tenant, enable_site=True)
tenant = _create_tenant(db_session_with_containers)
app_model = _create_app(db_session_with_containers, tenant.id)
end_user = SimpleNamespace(id="eu-1")
with app.test_request_context("/site"):
with pytest.raises(Forbidden):
AppSiteApi().get(app_model, end_user)
@patch("controllers.web.site.db")
def test_archived_tenant_raises_forbidden(self, mock_db: MagicMock, app: Flask) -> None:
@patch("controllers.web.site.FeatureService.get_features")
def test_archived_tenant_raises_forbidden(
self, mock_features, app: Flask, db_session_with_containers: Session
) -> None:
app.config["RESTX_MASK_HEADER"] = "X-Fields"
from models.account import TenantStatus
mock_db.session.scalar.return_value = _site()
tenant = SimpleNamespace(
id="tenant-1",
status=TenantStatus.ARCHIVE,
plan="basic",
custom_config_dict={},
)
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", tenant=tenant)
tenant = _create_tenant(db_session_with_containers, status=TenantStatus.ARCHIVE)
app_model = _create_app(db_session_with_containers, tenant.id)
_create_site(db_session_with_containers, app_model.id)
end_user = SimpleNamespace(id="eu-1")
mock_features.return_value = SimpleNamespace(can_replace_logo=False)
with app.test_request_context("/site"):
with pytest.raises(Forbidden):
AppSiteApi().get(app_model, end_user)
# ---------------------------------------------------------------------------
# AppSiteInfo
# ---------------------------------------------------------------------------
class TestAppSiteInfo:
def test_basic_fields(self) -> None:
tenant = _tenant()
site_obj = _site()
tenant = SimpleNamespace(id="tenant-1", plan="basic", custom_config_dict={})
site_obj = SimpleNamespace()
info = AppSiteInfo(tenant, SimpleNamespace(id="app-1", enable_site=True), site_obj, "eu-1", False)
assert info.app_id == "app-1"
@@ -118,7 +125,7 @@ class TestAppSiteInfo:
plan="pro",
custom_config_dict={"remove_webapp_brand": True, "replace_webapp_logo": True},
)
site_obj = _site()
site_obj = SimpleNamespace()
info = AppSiteInfo(tenant, SimpleNamespace(id="app-1", enable_site=True), site_obj, "eu-1", True)
assert info.can_replace_logo is True

View File

@@ -0,0 +1,613 @@
"""Testcontainers integration tests for DatasetService permission and lifecycle SQL paths."""
from datetime import datetime
from types import SimpleNamespace
from unittest.mock import patch
from uuid import uuid4
import pytest
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound
from core.rag.index_processor.constant.index_type import IndexTechniqueType
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import (
AppDatasetJoin,
Dataset,
DatasetAutoDisableLog,
DatasetCollectionBinding,
DatasetPermission,
DatasetPermissionEnum,
)
from models.enums import DataSourceType
from services.dataset_service import DatasetCollectionBindingService, DatasetPermissionService, DatasetService
from services.errors.account import NoPermissionError
class DatasetPermissionIntegrationFactory:
@staticmethod
def create_account_with_tenant(
db_session_with_containers: Session,
role: TenantAccountRole = TenantAccountRole.OWNER,
) -> tuple[Account, Tenant]:
account = Account(
email=f"{uuid4()}@example.com",
name=f"user-{uuid4()}",
interface_language="en-US",
status="active",
)
tenant = Tenant(name=f"tenant-{uuid4()}", status="normal")
db_session_with_containers.add_all([account, tenant])
db_session_with_containers.flush()
join = TenantAccountJoin(
tenant_id=tenant.id,
account_id=account.id,
role=role,
current=True,
)
db_session_with_containers.add(join)
db_session_with_containers.commit()
account.role = role
account._current_tenant = tenant
return account, tenant
@staticmethod
def create_account_in_tenant(
db_session_with_containers: Session,
tenant: Tenant,
role: TenantAccountRole = TenantAccountRole.EDITOR,
) -> Account:
account = Account(
email=f"{uuid4()}@example.com",
name=f"user-{uuid4()}",
interface_language="en-US",
status="active",
)
db_session_with_containers.add(account)
db_session_with_containers.flush()
join = TenantAccountJoin(
tenant_id=tenant.id,
account_id=account.id,
role=role,
current=True,
)
db_session_with_containers.add(join)
db_session_with_containers.commit()
account.role = role
account._current_tenant = tenant
return account
@staticmethod
def create_dataset(
db_session_with_containers: Session,
*,
tenant_id: str,
created_by: str,
name: str | None = None,
permission: DatasetPermissionEnum = DatasetPermissionEnum.ONLY_ME,
indexing_technique: str | None = IndexTechniqueType.HIGH_QUALITY,
enable_api: bool = True,
) -> Dataset:
dataset = Dataset(
tenant_id=tenant_id,
name=name or f"dataset-{uuid4()}",
description="desc",
data_source_type=DataSourceType.UPLOAD_FILE,
indexing_technique=indexing_technique,
created_by=created_by,
provider="vendor",
permission=permission,
retrieval_model={"top_k": 2},
)
dataset.enable_api = enable_api
db_session_with_containers.add(dataset)
db_session_with_containers.commit()
return dataset
@staticmethod
def create_dataset_permission(
db_session_with_containers: Session,
*,
dataset_id: str,
tenant_id: str,
account_id: str,
) -> DatasetPermission:
permission = DatasetPermission(
dataset_id=dataset_id,
tenant_id=tenant_id,
account_id=account_id,
has_permission=True,
)
db_session_with_containers.add(permission)
db_session_with_containers.commit()
return permission
@staticmethod
def create_app_dataset_join(
db_session_with_containers: Session,
*,
dataset_id: str,
) -> AppDatasetJoin:
join = AppDatasetJoin(
app_id=str(uuid4()),
dataset_id=dataset_id,
)
db_session_with_containers.add(join)
db_session_with_containers.commit()
return join
@staticmethod
def create_collection_binding(
db_session_with_containers: Session,
*,
provider_name: str,
model_name: str,
collection_type: str = "dataset",
) -> DatasetCollectionBinding:
binding = DatasetCollectionBinding(
provider_name=provider_name,
model_name=model_name,
collection_name=f"collection_{uuid4().hex}",
type=collection_type,
)
db_session_with_containers.add(binding)
db_session_with_containers.commit()
return binding
@staticmethod
def create_auto_disable_log(
db_session_with_containers: Session,
*,
tenant_id: str,
dataset_id: str,
document_id: str,
) -> DatasetAutoDisableLog:
log = DatasetAutoDisableLog(
tenant_id=tenant_id,
dataset_id=dataset_id,
document_id=document_id,
)
db_session_with_containers.add(log)
db_session_with_containers.commit()
return log
class TestDatasetServicePermissionsAndLifecycle:
def test_delete_dataset_returns_false_when_dataset_is_missing(self, db_session_with_containers: Session):
owner, _tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(db_session_with_containers)
result = DatasetService.delete_dataset(str(uuid4()), user=owner)
assert result is False
def test_delete_dataset_checks_permission_and_deletes_dataset(self, db_session_with_containers: Session):
owner, tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(db_session_with_containers)
dataset = DatasetPermissionIntegrationFactory.create_dataset(
db_session_with_containers,
tenant_id=tenant.id,
created_by=owner.id,
)
with patch("services.dataset_service.dataset_was_deleted.send") as send_deleted_signal:
result = DatasetService.delete_dataset(dataset.id, user=owner)
assert result is True
assert db_session_with_containers.get(Dataset, dataset.id) is None
send_deleted_signal.assert_called_once_with(dataset)
def test_dataset_use_check_returns_true_when_join_exists(self, db_session_with_containers: Session):
owner, tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(db_session_with_containers)
dataset = DatasetPermissionIntegrationFactory.create_dataset(
db_session_with_containers,
tenant_id=tenant.id,
created_by=owner.id,
)
DatasetPermissionIntegrationFactory.create_app_dataset_join(
db_session_with_containers,
dataset_id=dataset.id,
)
assert DatasetService.dataset_use_check(dataset.id) is True
def test_dataset_use_check_returns_false_when_join_missing(self, db_session_with_containers: Session):
owner, tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(db_session_with_containers)
dataset = DatasetPermissionIntegrationFactory.create_dataset(
db_session_with_containers,
tenant_id=tenant.id,
created_by=owner.id,
)
assert DatasetService.dataset_use_check(dataset.id) is False
def test_check_dataset_permission_rejects_cross_tenant_access(self, db_session_with_containers: Session):
owner, tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(db_session_with_containers)
outsider, _other_tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(
db_session_with_containers
)
dataset = DatasetPermissionIntegrationFactory.create_dataset(
db_session_with_containers,
tenant_id=tenant.id,
created_by=owner.id,
)
with pytest.raises(NoPermissionError, match="do not have permission"):
DatasetService.check_dataset_permission(dataset, outsider)
def test_check_dataset_permission_rejects_only_me_dataset_for_non_creator(
self, db_session_with_containers: Session
):
owner, tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(db_session_with_containers)
member = DatasetPermissionIntegrationFactory.create_account_in_tenant(db_session_with_containers, tenant)
dataset = DatasetPermissionIntegrationFactory.create_dataset(
db_session_with_containers,
tenant_id=tenant.id,
created_by=owner.id,
permission=DatasetPermissionEnum.ONLY_ME,
)
with pytest.raises(NoPermissionError, match="do not have permission"):
DatasetService.check_dataset_permission(dataset, member)
def test_check_dataset_permission_rejects_partial_team_user_without_binding(
self, db_session_with_containers: Session
):
owner, tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(db_session_with_containers)
member = DatasetPermissionIntegrationFactory.create_account_in_tenant(db_session_with_containers, tenant)
dataset = DatasetPermissionIntegrationFactory.create_dataset(
db_session_with_containers,
tenant_id=tenant.id,
created_by=owner.id,
permission=DatasetPermissionEnum.PARTIAL_TEAM,
)
with pytest.raises(NoPermissionError, match="do not have permission"):
DatasetService.check_dataset_permission(dataset, member)
def test_check_dataset_permission_allows_partial_team_creator(self, db_session_with_containers: Session):
creator, tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(
db_session_with_containers,
role=TenantAccountRole.EDITOR,
)
dataset = DatasetPermissionIntegrationFactory.create_dataset(
db_session_with_containers,
tenant_id=tenant.id,
created_by=creator.id,
permission=DatasetPermissionEnum.PARTIAL_TEAM,
)
DatasetService.check_dataset_permission(dataset, creator)
def test_check_dataset_permission_allows_partial_team_member_with_binding(
self, db_session_with_containers: Session
):
owner, tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(db_session_with_containers)
member = DatasetPermissionIntegrationFactory.create_account_in_tenant(db_session_with_containers, tenant)
dataset = DatasetPermissionIntegrationFactory.create_dataset(
db_session_with_containers,
tenant_id=tenant.id,
created_by=owner.id,
permission=DatasetPermissionEnum.PARTIAL_TEAM,
)
DatasetPermissionIntegrationFactory.create_dataset_permission(
db_session_with_containers,
dataset_id=dataset.id,
tenant_id=tenant.id,
account_id=member.id,
)
DatasetService.check_dataset_permission(dataset, member)
def test_check_dataset_operator_permission_rejects_only_me_for_non_creator(
self, db_session_with_containers: Session
):
owner, tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(db_session_with_containers)
operator = DatasetPermissionIntegrationFactory.create_account_in_tenant(
db_session_with_containers,
tenant,
role=TenantAccountRole.EDITOR,
)
dataset = DatasetPermissionIntegrationFactory.create_dataset(
db_session_with_containers,
tenant_id=tenant.id,
created_by=owner.id,
permission=DatasetPermissionEnum.ONLY_ME,
)
with pytest.raises(NoPermissionError, match="do not have permission"):
DatasetService.check_dataset_operator_permission(user=operator, dataset=dataset)
def test_check_dataset_operator_permission_rejects_partial_team_without_binding(
self, db_session_with_containers: Session
):
owner, tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(db_session_with_containers)
operator = DatasetPermissionIntegrationFactory.create_account_in_tenant(
db_session_with_containers,
tenant,
role=TenantAccountRole.EDITOR,
)
dataset = DatasetPermissionIntegrationFactory.create_dataset(
db_session_with_containers,
tenant_id=tenant.id,
created_by=owner.id,
permission=DatasetPermissionEnum.PARTIAL_TEAM,
)
with pytest.raises(NoPermissionError, match="do not have permission"):
DatasetService.check_dataset_operator_permission(user=operator, dataset=dataset)
def test_check_dataset_operator_permission_allows_partial_team_with_binding(
self, db_session_with_containers: Session
):
owner, tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(db_session_with_containers)
operator = DatasetPermissionIntegrationFactory.create_account_in_tenant(
db_session_with_containers,
tenant,
role=TenantAccountRole.EDITOR,
)
dataset = DatasetPermissionIntegrationFactory.create_dataset(
db_session_with_containers,
tenant_id=tenant.id,
created_by=owner.id,
permission=DatasetPermissionEnum.PARTIAL_TEAM,
)
DatasetPermissionIntegrationFactory.create_dataset_permission(
db_session_with_containers,
dataset_id=dataset.id,
tenant_id=tenant.id,
account_id=operator.id,
)
DatasetService.check_dataset_operator_permission(user=operator, dataset=dataset)
def test_update_dataset_api_status_raises_not_found_for_missing_dataset(self, flask_app_with_containers):
with flask_app_with_containers.app_context():
with pytest.raises(NotFound, match="Dataset not found"):
DatasetService.update_dataset_api_status(str(uuid4()), True)
def test_update_dataset_api_status_requires_current_user_id(self, db_session_with_containers: Session):
owner, tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(db_session_with_containers)
dataset = DatasetPermissionIntegrationFactory.create_dataset(
db_session_with_containers,
tenant_id=tenant.id,
created_by=owner.id,
enable_api=False,
)
with patch("services.dataset_service.current_user", SimpleNamespace(id=None)):
with pytest.raises(ValueError, match="Current user or current user id not found"):
DatasetService.update_dataset_api_status(dataset.id, True)
def test_update_dataset_api_status_updates_fields_and_commits(self, db_session_with_containers: Session):
owner, tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(db_session_with_containers)
dataset = DatasetPermissionIntegrationFactory.create_dataset(
db_session_with_containers,
tenant_id=tenant.id,
created_by=owner.id,
enable_api=False,
)
now = datetime(2026, 4, 14, 18, 0, 0)
with (
patch("services.dataset_service.current_user", owner),
patch("services.dataset_service.naive_utc_now", return_value=now),
):
DatasetService.update_dataset_api_status(dataset.id, True)
db_session_with_containers.refresh(dataset)
assert dataset.enable_api is True
assert dataset.updated_by == owner.id
assert dataset.updated_at == now
def test_get_dataset_auto_disable_logs_returns_empty_when_billing_is_disabled(
self, db_session_with_containers: Session
):
owner, tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(db_session_with_containers)
features = SimpleNamespace(
billing=SimpleNamespace(enabled=False, subscription=SimpleNamespace(plan="professional"))
)
with (
patch("services.dataset_service.current_user", owner),
patch("services.dataset_service.FeatureService.get_features", return_value=features),
):
result = DatasetService.get_dataset_auto_disable_logs(str(uuid4()))
assert result == {"document_ids": [], "count": 0}
def test_get_dataset_auto_disable_logs_returns_recent_document_ids(self, db_session_with_containers: Session):
owner, tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(db_session_with_containers)
dataset = DatasetPermissionIntegrationFactory.create_dataset(
db_session_with_containers,
tenant_id=tenant.id,
created_by=owner.id,
)
DatasetPermissionIntegrationFactory.create_auto_disable_log(
db_session_with_containers,
tenant_id=tenant.id,
dataset_id=dataset.id,
document_id=str(uuid4()),
)
DatasetPermissionIntegrationFactory.create_auto_disable_log(
db_session_with_containers,
tenant_id=tenant.id,
dataset_id=dataset.id,
document_id=str(uuid4()),
)
features = SimpleNamespace(
billing=SimpleNamespace(enabled=True, subscription=SimpleNamespace(plan="professional"))
)
with (
patch("services.dataset_service.current_user", owner),
patch("services.dataset_service.FeatureService.get_features", return_value=features),
):
result = DatasetService.get_dataset_auto_disable_logs(dataset.id)
assert result["count"] == 2
assert len(result["document_ids"]) == 2
class TestDatasetCollectionBindingServiceIntegration:
def test_get_dataset_collection_binding_returns_existing_binding(self, db_session_with_containers: Session):
binding = DatasetPermissionIntegrationFactory.create_collection_binding(
db_session_with_containers,
provider_name="provider",
model_name="model",
)
result = DatasetCollectionBindingService.get_dataset_collection_binding("provider", "model")
assert result.id == binding.id
def test_get_dataset_collection_binding_creates_binding_when_missing(self, db_session_with_containers: Session):
result = DatasetCollectionBindingService.get_dataset_collection_binding("provider", "missing-model")
persisted = db_session_with_containers.get(DatasetCollectionBinding, result.id)
assert persisted is not None
assert persisted.provider_name == "provider"
assert persisted.model_name == "missing-model"
assert persisted.type == "dataset"
assert persisted.collection_name
def test_get_dataset_collection_binding_by_id_and_type_raises_when_missing(self, flask_app_with_containers):
with flask_app_with_containers.app_context():
with pytest.raises(ValueError, match="Dataset collection binding not found"):
DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(str(uuid4()))
def test_get_dataset_collection_binding_by_id_and_type_returns_binding(self, db_session_with_containers: Session):
binding = DatasetPermissionIntegrationFactory.create_collection_binding(
db_session_with_containers,
provider_name="provider",
model_name="model",
)
result = DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(binding.id)
assert result.id == binding.id
class TestDatasetPermissionServiceIntegration:
def test_get_dataset_partial_member_list_returns_scalar_results(self, db_session_with_containers: Session):
owner, tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(db_session_with_containers)
member_a = DatasetPermissionIntegrationFactory.create_account_in_tenant(db_session_with_containers, tenant)
member_b = DatasetPermissionIntegrationFactory.create_account_in_tenant(db_session_with_containers, tenant)
dataset = DatasetPermissionIntegrationFactory.create_dataset(
db_session_with_containers,
tenant_id=tenant.id,
created_by=owner.id,
permission=DatasetPermissionEnum.PARTIAL_TEAM,
)
DatasetPermissionIntegrationFactory.create_dataset_permission(
db_session_with_containers,
dataset_id=dataset.id,
tenant_id=tenant.id,
account_id=member_a.id,
)
DatasetPermissionIntegrationFactory.create_dataset_permission(
db_session_with_containers,
dataset_id=dataset.id,
tenant_id=tenant.id,
account_id=member_b.id,
)
result = DatasetPermissionService.get_dataset_partial_member_list(dataset.id)
assert set(result) == {member_a.id, member_b.id}
def test_update_partial_member_list_replaces_permissions_and_commits(self, db_session_with_containers: Session):
owner, tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(db_session_with_containers)
member_a = DatasetPermissionIntegrationFactory.create_account_in_tenant(db_session_with_containers, tenant)
member_b = DatasetPermissionIntegrationFactory.create_account_in_tenant(db_session_with_containers, tenant)
dataset = DatasetPermissionIntegrationFactory.create_dataset(
db_session_with_containers,
tenant_id=tenant.id,
created_by=owner.id,
permission=DatasetPermissionEnum.PARTIAL_TEAM,
)
stale_member = DatasetPermissionIntegrationFactory.create_account_in_tenant(db_session_with_containers, tenant)
DatasetPermissionIntegrationFactory.create_dataset_permission(
db_session_with_containers,
dataset_id=dataset.id,
tenant_id=tenant.id,
account_id=stale_member.id,
)
DatasetPermissionService.update_partial_member_list(
tenant.id,
dataset.id,
[{"user_id": member_a.id}, {"user_id": member_b.id}],
)
permissions = db_session_with_containers.query(DatasetPermission).filter_by(dataset_id=dataset.id).all()
assert {permission.account_id for permission in permissions} == {member_a.id, member_b.id}
def test_check_permission_requires_dataset_editor(self):
user = SimpleNamespace(is_dataset_editor=False, is_dataset_operator=False)
dataset = SimpleNamespace(id="dataset-1", permission=DatasetPermissionEnum.ALL_TEAM)
with pytest.raises(NoPermissionError, match="does not have permission"):
DatasetPermissionService.check_permission(user, dataset, DatasetPermissionEnum.ALL_TEAM, [])
def test_check_permission_prevents_dataset_operator_from_changing_permission_mode(self):
user = SimpleNamespace(is_dataset_editor=True, is_dataset_operator=True)
dataset = SimpleNamespace(id="dataset-1", permission=DatasetPermissionEnum.ALL_TEAM)
with pytest.raises(NoPermissionError, match="cannot change the dataset permissions"):
DatasetPermissionService.check_permission(user, dataset, DatasetPermissionEnum.ONLY_ME, [])
def test_check_permission_requires_partial_member_list_for_partial_members_mode(self):
user = SimpleNamespace(is_dataset_editor=True, is_dataset_operator=True)
dataset = SimpleNamespace(id="dataset-1", permission=DatasetPermissionEnum.PARTIAL_TEAM)
with pytest.raises(ValueError, match="Partial member list is required"):
DatasetPermissionService.check_permission(user, dataset, DatasetPermissionEnum.PARTIAL_TEAM, [])
def test_check_permission_rejects_dataset_operator_member_list_changes(self):
user = SimpleNamespace(is_dataset_editor=True, is_dataset_operator=True)
dataset = SimpleNamespace(id="dataset-1", permission=DatasetPermissionEnum.PARTIAL_TEAM)
with patch.object(DatasetPermissionService, "get_dataset_partial_member_list", return_value=["user-1"]):
with pytest.raises(ValueError, match="cannot change the dataset permissions"):
DatasetPermissionService.check_permission(
user,
dataset,
DatasetPermissionEnum.PARTIAL_TEAM,
[{"user_id": "user-2"}],
)
def test_check_permission_allows_dataset_operator_when_member_list_is_unchanged(self):
user = SimpleNamespace(is_dataset_editor=True, is_dataset_operator=True)
dataset = SimpleNamespace(id="dataset-1", permission=DatasetPermissionEnum.PARTIAL_TEAM)
with patch.object(DatasetPermissionService, "get_dataset_partial_member_list", return_value=["user-1"]):
DatasetPermissionService.check_permission(
user,
dataset,
DatasetPermissionEnum.PARTIAL_TEAM,
[{"user_id": "user-1"}],
)
def test_clear_partial_member_list_deletes_permissions_and_commits(self, db_session_with_containers: Session):
owner, tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(db_session_with_containers)
member = DatasetPermissionIntegrationFactory.create_account_in_tenant(db_session_with_containers, tenant)
dataset = DatasetPermissionIntegrationFactory.create_dataset(
db_session_with_containers,
tenant_id=tenant.id,
created_by=owner.id,
permission=DatasetPermissionEnum.PARTIAL_TEAM,
)
DatasetPermissionIntegrationFactory.create_dataset_permission(
db_session_with_containers,
dataset_id=dataset.id,
tenant_id=tenant.id,
account_id=member.id,
)
DatasetPermissionService.clear_partial_member_list(dataset.id)
remaining = db_session_with_containers.query(DatasetPermission).filter_by(dataset_id=dataset.id).all()
assert remaining == []

View File

@@ -0,0 +1,387 @@
"""Testcontainers integration tests for schedule service SQL-backed behavior."""
from datetime import datetime
from types import SimpleNamespace
from uuid import uuid4
import pytest
from sqlalchemy import delete, select
from sqlalchemy.orm import Session
from core.workflow.nodes.trigger_schedule.entities import ScheduleConfig, SchedulePlanUpdate
from core.workflow.nodes.trigger_schedule.exc import ScheduleNotFoundError
from events.event_handlers.sync_workflow_schedule_when_app_published import sync_schedule_from_workflow
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.trigger import WorkflowSchedulePlan
from services.errors.account import AccountNotFoundError
from services.trigger.schedule_service import ScheduleService
class ScheduleServiceIntegrationFactory:
@staticmethod
def create_account_with_tenant(
db_session_with_containers: Session,
role: TenantAccountRole = TenantAccountRole.OWNER,
) -> tuple[Account, Tenant]:
account = Account(
email=f"{uuid4()}@example.com",
name=f"user-{uuid4()}",
interface_language="en-US",
status="active",
)
tenant = Tenant(name=f"tenant-{uuid4()}", status="normal")
db_session_with_containers.add_all([account, tenant])
db_session_with_containers.flush()
join = TenantAccountJoin(
tenant_id=tenant.id,
account_id=account.id,
role=role,
current=True,
)
db_session_with_containers.add(join)
db_session_with_containers.commit()
account.current_tenant = tenant
return account, tenant
@staticmethod
def create_schedule_plan(
db_session_with_containers: Session,
*,
tenant_id: str,
app_id: str | None = None,
node_id: str = "start",
cron_expression: str = "30 10 * * *",
timezone: str = "UTC",
next_run_at: datetime | None = None,
) -> WorkflowSchedulePlan:
schedule = WorkflowSchedulePlan(
tenant_id=tenant_id,
app_id=app_id or str(uuid4()),
node_id=node_id,
cron_expression=cron_expression,
timezone=timezone,
next_run_at=next_run_at,
)
db_session_with_containers.add(schedule)
db_session_with_containers.commit()
return schedule
def _cron_workflow(
*,
node_id: str = "start",
cron_expression: str = "30 10 * * *",
timezone: str = "UTC",
):
return SimpleNamespace(
graph_dict={
"nodes": [
{
"id": node_id,
"data": {
"type": "trigger-schedule",
"mode": "cron",
"cron_expression": cron_expression,
"timezone": timezone,
},
}
]
}
)
def _no_schedule_workflow():
return SimpleNamespace(
graph_dict={
"nodes": [
{
"id": "node-1",
"data": {"type": "llm"},
}
]
}
)
class TestScheduleServiceIntegration:
def test_create_schedule_persists_schedule(self, db_session_with_containers: Session):
account, tenant = ScheduleServiceIntegrationFactory.create_account_with_tenant(db_session_with_containers)
expected_next_run = datetime(2026, 1, 1, 10, 30, 0)
config = ScheduleConfig(
node_id="start",
cron_expression="30 10 * * *",
timezone="UTC",
)
with pytest.MonkeyPatch.context() as monkeypatch:
monkeypatch.setattr(
"services.trigger.schedule_service.calculate_next_run_at",
lambda *_args, **_kwargs: expected_next_run,
)
schedule = ScheduleService.create_schedule(
session=db_session_with_containers,
tenant_id=tenant.id,
app_id=str(uuid4()),
config=config,
)
persisted = db_session_with_containers.get(WorkflowSchedulePlan, schedule.id)
assert persisted is not None
assert persisted.tenant_id == tenant.id
assert persisted.node_id == "start"
assert persisted.cron_expression == "30 10 * * *"
assert persisted.timezone == "UTC"
assert persisted.next_run_at == expected_next_run
def test_update_schedule_updates_fields_and_recomputes_next_run(self, db_session_with_containers: Session):
_account, tenant = ScheduleServiceIntegrationFactory.create_account_with_tenant(db_session_with_containers)
schedule = ScheduleServiceIntegrationFactory.create_schedule_plan(
db_session_with_containers,
tenant_id=tenant.id,
cron_expression="30 10 * * *",
timezone="UTC",
)
expected_next_run = datetime(2026, 1, 2, 12, 0, 0)
with pytest.MonkeyPatch.context() as monkeypatch:
monkeypatch.setattr(
"services.trigger.schedule_service.calculate_next_run_at",
lambda *_args, **_kwargs: expected_next_run,
)
updated = ScheduleService.update_schedule(
session=db_session_with_containers,
schedule_id=schedule.id,
updates=SchedulePlanUpdate(
cron_expression="0 12 * * *",
timezone="America/New_York",
),
)
db_session_with_containers.refresh(updated)
assert updated.cron_expression == "0 12 * * *"
assert updated.timezone == "America/New_York"
assert updated.next_run_at == expected_next_run
def test_update_schedule_updates_only_node_id_without_recomputing_time(self, db_session_with_containers: Session):
_account, tenant = ScheduleServiceIntegrationFactory.create_account_with_tenant(db_session_with_containers)
initial_next_run = datetime(2026, 1, 1, 10, 0, 0)
schedule = ScheduleServiceIntegrationFactory.create_schedule_plan(
db_session_with_containers,
tenant_id=tenant.id,
next_run_at=initial_next_run,
)
with pytest.MonkeyPatch.context() as monkeypatch:
calls: list[tuple] = []
def _track(*args, **kwargs):
calls.append((args, kwargs))
return datetime(2026, 1, 9, 10, 0, 0)
monkeypatch.setattr("services.trigger.schedule_service.calculate_next_run_at", _track)
updated = ScheduleService.update_schedule(
session=db_session_with_containers,
schedule_id=schedule.id,
updates=SchedulePlanUpdate(node_id="node-new"),
)
db_session_with_containers.refresh(updated)
assert updated.node_id == "node-new"
assert updated.next_run_at == initial_next_run
assert calls == []
def test_update_schedule_not_found_raises(self, db_session_with_containers: Session):
with pytest.raises(ScheduleNotFoundError, match="Schedule not found"):
ScheduleService.update_schedule(
session=db_session_with_containers,
schedule_id=str(uuid4()),
updates=SchedulePlanUpdate(node_id="node-new"),
)
def test_delete_schedule_removes_row(self, db_session_with_containers: Session):
_account, tenant = ScheduleServiceIntegrationFactory.create_account_with_tenant(db_session_with_containers)
schedule = ScheduleServiceIntegrationFactory.create_schedule_plan(
db_session_with_containers,
tenant_id=tenant.id,
)
ScheduleService.delete_schedule(
session=db_session_with_containers,
schedule_id=schedule.id,
)
db_session_with_containers.commit()
assert db_session_with_containers.get(WorkflowSchedulePlan, schedule.id) is None
def test_delete_schedule_not_found_raises(self, db_session_with_containers: Session):
with pytest.raises(ScheduleNotFoundError, match="Schedule not found"):
ScheduleService.delete_schedule(
session=db_session_with_containers,
schedule_id=str(uuid4()),
)
def test_get_tenant_owner_returns_owner_account(self, db_session_with_containers: Session):
owner, tenant = ScheduleServiceIntegrationFactory.create_account_with_tenant(
db_session_with_containers,
role=TenantAccountRole.OWNER,
)
result = ScheduleService.get_tenant_owner(
session=db_session_with_containers,
tenant_id=tenant.id,
)
assert result.id == owner.id
def test_get_tenant_owner_falls_back_to_admin(self, db_session_with_containers: Session):
admin, tenant = ScheduleServiceIntegrationFactory.create_account_with_tenant(
db_session_with_containers,
role=TenantAccountRole.ADMIN,
)
result = ScheduleService.get_tenant_owner(
session=db_session_with_containers,
tenant_id=tenant.id,
)
assert result.id == admin.id
def test_get_tenant_owner_raises_when_account_record_missing(self, db_session_with_containers: Session):
_account, tenant = ScheduleServiceIntegrationFactory.create_account_with_tenant(db_session_with_containers)
db_session_with_containers.execute(delete(TenantAccountJoin))
missing_account_id = str(uuid4())
join = TenantAccountJoin(
tenant_id=tenant.id,
account_id=missing_account_id,
role=TenantAccountRole.OWNER,
current=True,
)
db_session_with_containers.add(join)
db_session_with_containers.commit()
with pytest.raises(AccountNotFoundError, match=missing_account_id):
ScheduleService.get_tenant_owner(session=db_session_with_containers, tenant_id=tenant.id)
def test_get_tenant_owner_raises_when_no_owner_or_admin_found(self, db_session_with_containers: Session):
_account, tenant = ScheduleServiceIntegrationFactory.create_account_with_tenant(db_session_with_containers)
db_session_with_containers.execute(delete(TenantAccountJoin))
db_session_with_containers.commit()
with pytest.raises(AccountNotFoundError, match=tenant.id):
ScheduleService.get_tenant_owner(session=db_session_with_containers, tenant_id=tenant.id)
def test_update_next_run_at_updates_persisted_value(self, db_session_with_containers: Session):
_account, tenant = ScheduleServiceIntegrationFactory.create_account_with_tenant(db_session_with_containers)
schedule = ScheduleServiceIntegrationFactory.create_schedule_plan(
db_session_with_containers,
tenant_id=tenant.id,
)
expected_next_run = datetime(2026, 1, 3, 10, 30, 0)
with pytest.MonkeyPatch.context() as monkeypatch:
monkeypatch.setattr(
"services.trigger.schedule_service.calculate_next_run_at",
lambda *_args, **_kwargs: expected_next_run,
)
result = ScheduleService.update_next_run_at(
session=db_session_with_containers,
schedule_id=schedule.id,
)
db_session_with_containers.refresh(schedule)
assert result == expected_next_run
assert schedule.next_run_at == expected_next_run
def test_update_next_run_at_raises_when_schedule_not_found(self, db_session_with_containers: Session):
with pytest.raises(ScheduleNotFoundError, match="Schedule not found"):
ScheduleService.update_next_run_at(
session=db_session_with_containers,
schedule_id=str(uuid4()),
)
class TestSyncScheduleFromWorkflowIntegration:
def test_sync_schedule_create_new(self, db_session_with_containers: Session):
_account, tenant = ScheduleServiceIntegrationFactory.create_account_with_tenant(db_session_with_containers)
app_id = str(uuid4())
expected_next_run = datetime(2026, 1, 4, 10, 30, 0)
with pytest.MonkeyPatch.context() as monkeypatch:
monkeypatch.setattr(
"services.trigger.schedule_service.calculate_next_run_at",
lambda *_args, **_kwargs: expected_next_run,
)
result = sync_schedule_from_workflow(
tenant_id=tenant.id,
app_id=app_id,
workflow=_cron_workflow(),
)
assert result is not None
persisted = db_session_with_containers.execute(
select(WorkflowSchedulePlan).where(WorkflowSchedulePlan.app_id == app_id)
).scalar_one()
assert persisted.node_id == "start"
assert persisted.cron_expression == "30 10 * * *"
assert persisted.timezone == "UTC"
assert persisted.next_run_at == expected_next_run
def test_sync_schedule_update_existing(self, db_session_with_containers: Session):
_account, tenant = ScheduleServiceIntegrationFactory.create_account_with_tenant(db_session_with_containers)
app_id = str(uuid4())
existing = ScheduleServiceIntegrationFactory.create_schedule_plan(
db_session_with_containers,
tenant_id=tenant.id,
app_id=app_id,
node_id="old-start",
cron_expression="30 10 * * *",
timezone="UTC",
)
existing_id = existing.id
expected_next_run = datetime(2026, 1, 5, 12, 0, 0)
with pytest.MonkeyPatch.context() as monkeypatch:
monkeypatch.setattr(
"services.trigger.schedule_service.calculate_next_run_at",
lambda *_args, **_kwargs: expected_next_run,
)
result = sync_schedule_from_workflow(
tenant_id=tenant.id,
app_id=app_id,
workflow=_cron_workflow(
node_id="start",
cron_expression="0 12 * * *",
timezone="America/New_York",
),
)
assert result is not None
db_session_with_containers.expire_all()
persisted = db_session_with_containers.get(WorkflowSchedulePlan, existing_id)
assert persisted is not None
assert persisted.node_id == "start"
assert persisted.cron_expression == "0 12 * * *"
assert persisted.timezone == "America/New_York"
assert persisted.next_run_at == expected_next_run
def test_sync_schedule_remove_when_no_config(self, db_session_with_containers: Session):
_account, tenant = ScheduleServiceIntegrationFactory.create_account_with_tenant(db_session_with_containers)
app_id = str(uuid4())
existing = ScheduleServiceIntegrationFactory.create_schedule_plan(
db_session_with_containers,
tenant_id=tenant.id,
app_id=app_id,
)
existing_id = existing.id
result = sync_schedule_from_workflow(
tenant_id=tenant.id,
app_id=app_id,
workflow=_no_schedule_workflow(),
)
assert result is None
db_session_with_containers.expire_all()
assert db_session_with_containers.get(WorkflowSchedulePlan, existing_id) is None
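
The tests above monkeypatch calculate_next_run_at rather than exercising it. As a rough illustration of the behavior being stubbed, a croniter-based helper might look like the sketch below; the real helper in schedule_service may differ in signature and timezone handling, and croniter is only an assumed dependency here.

from datetime import datetime
from zoneinfo import ZoneInfo

from croniter import croniter  # assumed dependency for this sketch


def calculate_next_run_at_sketch(
    cron_expression: str,
    timezone: str,
    base_time: datetime | None = None,
) -> datetime:
    """Return the next fire time for a cron expression in the given IANA timezone (illustrative only)."""
    tz = ZoneInfo(timezone)
    now = (base_time or datetime.now(tz)).astimezone(tz)
    # croniter iterates forward from `now`; get_next(datetime) yields the next matching instant.
    return croniter(cron_expression, now).get_next(datetime)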

View File

@@ -2,6 +2,7 @@ from unittest.mock import MagicMock, patch
import pytest
from faker import Faker
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
@@ -530,22 +531,18 @@ class TestAddDocumentToIndexTask:
redis_client.set(indexing_cache_key, "processing", ex=300)
# Verify logs exist before processing
existing_logs = (
db_session_with_containers.query(DatasetAutoDisableLog)
.where(DatasetAutoDisableLog.document_id == document.id)
.all()
)
existing_logs = db_session_with_containers.scalars(
select(DatasetAutoDisableLog).where(DatasetAutoDisableLog.document_id == document.id)
).all()
assert len(existing_logs) == 2
# Act: Execute the task
add_document_to_index_task(document.id)
# Assert: Verify auto disable logs were deleted
remaining_logs = (
db_session_with_containers.query(DatasetAutoDisableLog)
.where(DatasetAutoDisableLog.document_id == document.id)
.all()
)
remaining_logs = db_session_with_containers.scalars(
select(DatasetAutoDisableLog).where(DatasetAutoDisableLog.document_id == document.id)
).all()
assert len(remaining_logs) == 0
# Verify index processing occurred normally

View File

@@ -11,6 +11,7 @@ from unittest.mock import Mock, patch
import pytest
from faker import Faker
from sqlalchemy import func, select
from sqlalchemy.orm import Session
from core.rag.index_processor.constant.index_type import IndexStructureType
@@ -267,11 +268,13 @@ class TestBatchCleanDocumentTask:
db_session_with_containers.commit() # Ensure all changes are committed
# Check that segment is deleted
deleted_segment = db_session_with_containers.query(DocumentSegment).filter_by(id=segment_id).first()
deleted_segment = db_session_with_containers.scalar(
select(DocumentSegment).where(DocumentSegment.id == segment_id).limit(1)
)
assert deleted_segment is None
# Check that upload file is deleted
deleted_file = db_session_with_containers.query(UploadFile).filter_by(id=file_id).first()
deleted_file = db_session_with_containers.scalar(select(UploadFile).where(UploadFile.id == file_id).limit(1))
assert deleted_file is None
def test_batch_clean_document_task_with_image_files(
@@ -319,7 +322,9 @@ class TestBatchCleanDocumentTask:
db_session_with_containers.commit()
# Check that segment is deleted
deleted_segment = db_session_with_containers.query(DocumentSegment).filter_by(id=segment_id).first()
deleted_segment = db_session_with_containers.scalar(
select(DocumentSegment).where(DocumentSegment.id == segment_id).limit(1)
)
assert deleted_segment is None
# Verify that the task completed successfully by checking the log output
@@ -360,14 +365,14 @@ class TestBatchCleanDocumentTask:
db_session_with_containers.commit()
# Check that upload file is deleted
deleted_file = db_session_with_containers.query(UploadFile).filter_by(id=file_id).first()
deleted_file = db_session_with_containers.scalar(select(UploadFile).where(UploadFile.id == file_id).limit(1))
assert deleted_file is None
# Verify database cleanup
db_session_with_containers.commit()
# Check that upload file is deleted
deleted_file = db_session_with_containers.query(UploadFile).filter_by(id=file_id).first()
deleted_file = db_session_with_containers.scalar(select(UploadFile).where(UploadFile.id == file_id).limit(1))
assert deleted_file is None
def test_batch_clean_document_task_dataset_not_found(
@@ -410,7 +415,9 @@ class TestBatchCleanDocumentTask:
db_session_with_containers.commit()
# Document should still exist since cleanup failed
existing_document = db_session_with_containers.query(Document).filter_by(id=document_id).first()
existing_document = db_session_with_containers.scalar(
select(Document).where(Document.id == document_id).limit(1)
)
assert existing_document is not None
def test_batch_clean_document_task_storage_cleanup_failure(
@@ -453,11 +460,13 @@ class TestBatchCleanDocumentTask:
db_session_with_containers.commit()
# Check that segment is deleted from database
deleted_segment = db_session_with_containers.query(DocumentSegment).filter_by(id=segment_id).first()
deleted_segment = db_session_with_containers.scalar(
select(DocumentSegment).where(DocumentSegment.id == segment_id).limit(1)
)
assert deleted_segment is None
# Check that upload file is deleted from database
deleted_file = db_session_with_containers.query(UploadFile).filter_by(id=file_id).first()
deleted_file = db_session_with_containers.scalar(select(UploadFile).where(UploadFile.id == file_id).limit(1))
assert deleted_file is None
def test_batch_clean_document_task_multiple_documents(
@@ -510,12 +519,16 @@ class TestBatchCleanDocumentTask:
# Check that all segments are deleted
for segment_id in segment_ids:
deleted_segment = db_session_with_containers.query(DocumentSegment).filter_by(id=segment_id).first()
deleted_segment = db_session_with_containers.scalar(
select(DocumentSegment).where(DocumentSegment.id == segment_id).limit(1)
)
assert deleted_segment is None
# Check that all upload files are deleted
for file_id in file_ids:
deleted_file = db_session_with_containers.query(UploadFile).filter_by(id=file_id).first()
deleted_file = db_session_with_containers.scalar(
select(UploadFile).where(UploadFile.id == file_id).limit(1)
)
assert deleted_file is None
def test_batch_clean_document_task_different_doc_forms(
@@ -564,7 +577,9 @@ class TestBatchCleanDocumentTask:
db_session_with_containers.commit()
# Check that segment is deleted
deleted_segment = db_session_with_containers.query(DocumentSegment).filter_by(id=segment_id).first()
deleted_segment = db_session_with_containers.scalar(
select(DocumentSegment).where(DocumentSegment.id == segment_id).limit(1)
)
assert deleted_segment is None
except Exception as e:
@@ -574,7 +589,9 @@ class TestBatchCleanDocumentTask:
db_session_with_containers.commit()
# Check if the segment still exists (task may have failed before deletion)
existing_segment = db_session_with_containers.query(DocumentSegment).filter_by(id=segment_id).first()
existing_segment = db_session_with_containers.scalar(
select(DocumentSegment).where(DocumentSegment.id == segment_id).limit(1)
)
if existing_segment is not None:
# If segment still exists, the task failed before deletion
# This is acceptable in test environments with external service issues
@@ -645,12 +662,16 @@ class TestBatchCleanDocumentTask:
# Check that all segments are deleted
for segment_id in segment_ids:
deleted_segment = db_session_with_containers.query(DocumentSegment).filter_by(id=segment_id).first()
deleted_segment = db_session_with_containers.scalar(
select(DocumentSegment).where(DocumentSegment.id == segment_id).limit(1)
)
assert deleted_segment is None
# Check that all upload files are deleted
for file_id in file_ids:
deleted_file = db_session_with_containers.query(UploadFile).filter_by(id=file_id).first()
deleted_file = db_session_with_containers.scalar(
select(UploadFile).where(UploadFile.id == file_id).limit(1)
)
assert deleted_file is None
def test_batch_clean_document_task_integration_with_real_database(
@@ -699,8 +720,16 @@ class TestBatchCleanDocumentTask:
db_session_with_containers.commit()
# Verify initial state
assert db_session_with_containers.query(DocumentSegment).filter_by(document_id=document.id).count() == 3
assert db_session_with_containers.query(UploadFile).filter_by(id=upload_file.id).first() is not None
assert (
db_session_with_containers.scalar(
select(func.count()).select_from(DocumentSegment).where(DocumentSegment.document_id == document.id)
)
== 3
)
assert (
db_session_with_containers.scalar(select(UploadFile).where(UploadFile.id == upload_file.id).limit(1))
is not None
)
# Store original IDs for verification
document_id = document.id
@@ -720,13 +749,20 @@ class TestBatchCleanDocumentTask:
# Check that all segments are deleted
for segment_id in segment_ids:
deleted_segment = db_session_with_containers.query(DocumentSegment).filter_by(id=segment_id).first()
deleted_segment = db_session_with_containers.scalar(
select(DocumentSegment).where(DocumentSegment.id == segment_id).limit(1)
)
assert deleted_segment is None
# Check that upload file is deleted
deleted_file = db_session_with_containers.query(UploadFile).filter_by(id=file_id).first()
deleted_file = db_session_with_containers.scalar(select(UploadFile).where(UploadFile.id == file_id).limit(1))
assert deleted_file is None
# Verify final database state
assert db_session_with_containers.query(DocumentSegment).filter_by(document_id=document_id).count() == 0
assert db_session_with_containers.query(UploadFile).filter_by(id=file_id).first() is None
assert (
db_session_with_containers.scalar(
select(func.count()).select_from(DocumentSegment).where(DocumentSegment.document_id == document_id)
)
== 0
)
assert db_session_with_containers.scalar(select(UploadFile).where(UploadFile.id == file_id).limit(1)) is None
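
The hunks above all apply the same mechanical conversion from the legacy Query API to SQLAlchemy 2.0-style statements. A condensed sketch of the pattern is shown below; the model import paths are assumed from the repository layout, and the identifiers stand in for the test fixtures.

from sqlalchemy import func, select
from sqlalchemy.orm import Session

from models.dataset import DocumentSegment
from models.model import UploadFile  # import path assumed


def query_cleanup_state(session: Session, file_id: str, doc_id: str):
    """2.0-style equivalents of the legacy Query calls replaced throughout these tests (sketch)."""
    # Was: session.query(UploadFile).filter_by(id=file_id).first()
    first_file = session.scalar(select(UploadFile).where(UploadFile.id == file_id).limit(1))
    # Was: session.query(DocumentSegment).filter_by(document_id=doc_id).count()
    segment_count = session.scalar(
        select(func.count()).select_from(DocumentSegment).where(DocumentSegment.document_id == doc_id)
    )
    # Was: session.query(DocumentSegment).filter_by(document_id=doc_id).all()
    segments = session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == doc_id)).all()
    return first_file, segment_count, segments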

View File

@@ -17,6 +17,7 @@ from unittest.mock import MagicMock, patch
import pytest
from faker import Faker
from sqlalchemy import delete, select
from sqlalchemy.orm import Session
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
@@ -37,13 +38,13 @@ class TestBatchCreateSegmentToIndexTask:
from extensions.ext_redis import redis_client
# Clear all test data
db_session_with_containers.query(DocumentSegment).delete()
db_session_with_containers.query(Document).delete()
db_session_with_containers.query(Dataset).delete()
db_session_with_containers.query(UploadFile).delete()
db_session_with_containers.query(TenantAccountJoin).delete()
db_session_with_containers.query(Tenant).delete()
db_session_with_containers.query(Account).delete()
db_session_with_containers.execute(delete(DocumentSegment))
db_session_with_containers.execute(delete(Document))
db_session_with_containers.execute(delete(Dataset))
db_session_with_containers.execute(delete(UploadFile))
db_session_with_containers.execute(delete(TenantAccountJoin))
db_session_with_containers.execute(delete(Tenant))
db_session_with_containers.execute(delete(Account))
db_session_with_containers.commit()
# Clear Redis cache
@@ -292,12 +293,9 @@ class TestBatchCreateSegmentToIndexTask:
# Verify results
# Check that segments were created
segments = (
db_session_with_containers.query(DocumentSegment)
.filter_by(document_id=document.id)
.order_by(DocumentSegment.position)
.all()
)
segments = db_session_with_containers.scalars(
select(DocumentSegment).where(DocumentSegment.document_id == document.id).order_by(DocumentSegment.position)
).all()
assert len(segments) == 3
# Verify segment content and metadata
@@ -367,11 +365,11 @@ class TestBatchCreateSegmentToIndexTask:
# Verify no segments were created (since dataset doesn't exist)
segments = db_session_with_containers.query(DocumentSegment).all()
segments = db_session_with_containers.scalars(select(DocumentSegment)).all()
assert len(segments) == 0
# Verify no documents were modified
documents = db_session_with_containers.query(Document).all()
documents = db_session_with_containers.scalars(select(Document)).all()
assert len(documents) == 0
def test_batch_create_segment_to_index_task_document_not_found(
@@ -415,12 +413,14 @@ class TestBatchCreateSegmentToIndexTask:
# Verify no segments were created
segments = db_session_with_containers.query(DocumentSegment).all()
segments = db_session_with_containers.scalars(select(DocumentSegment)).all()
assert len(segments) == 0
# Verify dataset remains unchanged (no segments were added to the dataset)
db_session_with_containers.refresh(dataset)
segments_for_dataset = db_session_with_containers.query(DocumentSegment).filter_by(dataset_id=dataset.id).all()
segments_for_dataset = db_session_with_containers.scalars(
select(DocumentSegment).where(DocumentSegment.dataset_id == dataset.id)
).all()
assert len(segments_for_dataset) == 0
def test_batch_create_segment_to_index_task_document_not_available(
@@ -516,7 +516,9 @@ class TestBatchCreateSegmentToIndexTask:
assert cache_value == b"error"
# Verify no segments were created
segments = db_session_with_containers.query(DocumentSegment).filter_by(document_id=document.id).all()
segments = db_session_with_containers.scalars(
select(DocumentSegment).where(DocumentSegment.document_id == document.id)
).all()
assert len(segments) == 0
def test_batch_create_segment_to_index_task_upload_file_not_found(
@@ -560,7 +562,7 @@ class TestBatchCreateSegmentToIndexTask:
# Verify no segments were created
segments = db_session_with_containers.query(DocumentSegment).all()
segments = db_session_with_containers.scalars(select(DocumentSegment)).all()
assert len(segments) == 0
# Verify document remains unchanged
@@ -611,7 +613,7 @@ class TestBatchCreateSegmentToIndexTask:
# Verify error handling
# Since exception was raised, no segments should be created
segments = db_session_with_containers.query(DocumentSegment).all()
segments = db_session_with_containers.scalars(select(DocumentSegment)).all()
assert len(segments) == 0
# Verify document remains unchanged
@@ -682,12 +684,9 @@ class TestBatchCreateSegmentToIndexTask:
# Verify results
# Check that new segments were created with correct positions
all_segments = (
db_session_with_containers.query(DocumentSegment)
.filter_by(document_id=document.id)
.order_by(DocumentSegment.position)
.all()
)
all_segments = db_session_with_containers.scalars(
select(DocumentSegment).where(DocumentSegment.document_id == document.id).order_by(DocumentSegment.position)
).all()
assert len(all_segments) == 6 # 3 existing + 3 new
# Verify position ordering

View File

@@ -16,6 +16,7 @@ from unittest.mock import MagicMock, patch
import pytest
from faker import Faker
from sqlalchemy import delete, select
from sqlalchemy.orm import Session
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
@@ -52,18 +53,18 @@ class TestCleanDatasetTask:
from extensions.ext_redis import redis_client
# Clear all test data using the provided session fixture
db_session_with_containers.query(DatasetMetadataBinding).delete()
db_session_with_containers.query(DatasetMetadata).delete()
db_session_with_containers.query(AppDatasetJoin).delete()
db_session_with_containers.query(DatasetQuery).delete()
db_session_with_containers.query(DatasetProcessRule).delete()
db_session_with_containers.query(DocumentSegment).delete()
db_session_with_containers.query(Document).delete()
db_session_with_containers.query(Dataset).delete()
db_session_with_containers.query(UploadFile).delete()
db_session_with_containers.query(TenantAccountJoin).delete()
db_session_with_containers.query(Tenant).delete()
db_session_with_containers.query(Account).delete()
db_session_with_containers.execute(delete(DatasetMetadataBinding))
db_session_with_containers.execute(delete(DatasetMetadata))
db_session_with_containers.execute(delete(AppDatasetJoin))
db_session_with_containers.execute(delete(DatasetQuery))
db_session_with_containers.execute(delete(DatasetProcessRule))
db_session_with_containers.execute(delete(DocumentSegment))
db_session_with_containers.execute(delete(Document))
db_session_with_containers.execute(delete(Dataset))
db_session_with_containers.execute(delete(UploadFile))
db_session_with_containers.execute(delete(TenantAccountJoin))
db_session_with_containers.execute(delete(Tenant))
db_session_with_containers.execute(delete(Account))
db_session_with_containers.commit()
# Clear Redis cache
@@ -302,28 +303,40 @@ class TestCleanDatasetTask:
# Verify results
# Check that dataset-related data was cleaned up
documents = db_session_with_containers.query(Document).filter_by(dataset_id=dataset.id).all()
documents = db_session_with_containers.scalars(select(Document).where(Document.dataset_id == dataset.id)).all()
assert len(documents) == 0
segments = db_session_with_containers.query(DocumentSegment).filter_by(dataset_id=dataset.id).all()
segments = db_session_with_containers.scalars(
select(DocumentSegment).where(DocumentSegment.dataset_id == dataset.id)
).all()
assert len(segments) == 0
# Check that metadata and bindings were cleaned up
metadata = db_session_with_containers.query(DatasetMetadata).filter_by(dataset_id=dataset.id).all()
metadata = db_session_with_containers.scalars(
select(DatasetMetadata).where(DatasetMetadata.dataset_id == dataset.id)
).all()
assert len(metadata) == 0
bindings = db_session_with_containers.query(DatasetMetadataBinding).filter_by(dataset_id=dataset.id).all()
bindings = db_session_with_containers.scalars(
select(DatasetMetadataBinding).where(DatasetMetadataBinding.dataset_id == dataset.id)
).all()
assert len(bindings) == 0
# Check that process rules and queries were cleaned up
process_rules = db_session_with_containers.query(DatasetProcessRule).filter_by(dataset_id=dataset.id).all()
process_rules = db_session_with_containers.scalars(
select(DatasetProcessRule).where(DatasetProcessRule.dataset_id == dataset.id)
).all()
assert len(process_rules) == 0
queries = db_session_with_containers.query(DatasetQuery).filter_by(dataset_id=dataset.id).all()
queries = db_session_with_containers.scalars(
select(DatasetQuery).where(DatasetQuery.dataset_id == dataset.id)
).all()
assert len(queries) == 0
# Check that app dataset joins were cleaned up
app_joins = db_session_with_containers.query(AppDatasetJoin).filter_by(dataset_id=dataset.id).all()
app_joins = db_session_with_containers.scalars(
select(AppDatasetJoin).where(AppDatasetJoin.dataset_id == dataset.id)
).all()
assert len(app_joins) == 0
# Verify index processor was called
@@ -414,24 +427,32 @@ class TestCleanDatasetTask:
# Verify results
# Check that all documents were deleted
remaining_documents = db_session_with_containers.query(Document).filter_by(dataset_id=dataset.id).all()
remaining_documents = db_session_with_containers.scalars(
select(Document).where(Document.dataset_id == dataset.id)
).all()
assert len(remaining_documents) == 0
# Check that all segments were deleted
remaining_segments = db_session_with_containers.query(DocumentSegment).filter_by(dataset_id=dataset.id).all()
remaining_segments = db_session_with_containers.scalars(
select(DocumentSegment).where(DocumentSegment.dataset_id == dataset.id)
).all()
assert len(remaining_segments) == 0
# Check that all upload files were deleted
remaining_files = db_session_with_containers.query(UploadFile).where(UploadFile.id.in_(upload_file_ids)).all()
remaining_files = db_session_with_containers.scalars(
select(UploadFile).where(UploadFile.id.in_(upload_file_ids))
).all()
assert len(remaining_files) == 0
# Check that metadata and bindings were cleaned up
remaining_metadata = db_session_with_containers.query(DatasetMetadata).filter_by(dataset_id=dataset.id).all()
remaining_metadata = db_session_with_containers.scalars(
select(DatasetMetadata).where(DatasetMetadata.dataset_id == dataset.id)
).all()
assert len(remaining_metadata) == 0
remaining_bindings = (
db_session_with_containers.query(DatasetMetadataBinding).filter_by(dataset_id=dataset.id).all()
)
remaining_bindings = db_session_with_containers.scalars(
select(DatasetMetadataBinding).where(DatasetMetadataBinding.dataset_id == dataset.id)
).all()
assert len(remaining_bindings) == 0
# Verify index processor was called
@@ -485,12 +506,14 @@ class TestCleanDatasetTask:
# Check that all data was cleaned up
remaining_documents = db_session_with_containers.query(Document).filter_by(dataset_id=dataset.id).all()
remaining_documents = db_session_with_containers.scalars(
select(Document).where(Document.dataset_id == dataset.id)
).all()
assert len(remaining_documents) == 0
remaining_segments = (
db_session_with_containers.query(DocumentSegment).filter_by(dataset_id=dataset.id).all()
)
remaining_segments = db_session_with_containers.scalars(
select(DocumentSegment).where(DocumentSegment.dataset_id == dataset.id)
).all()
assert len(remaining_segments) == 0
# Recreate data for next test case
@@ -538,11 +561,15 @@ class TestCleanDatasetTask:
# Verify results - even with vector cleanup failure, documents and segments should be deleted
# Check that documents were still deleted despite vector cleanup failure
remaining_documents = db_session_with_containers.query(Document).filter_by(dataset_id=dataset.id).all()
remaining_documents = db_session_with_containers.scalars(
select(Document).where(Document.dataset_id == dataset.id)
).all()
assert len(remaining_documents) == 0
# Check that segments were still deleted despite vector cleanup failure
remaining_segments = db_session_with_containers.query(DocumentSegment).filter_by(dataset_id=dataset.id).all()
remaining_segments = db_session_with_containers.scalars(
select(DocumentSegment).where(DocumentSegment.dataset_id == dataset.id)
).all()
assert len(remaining_segments) == 0
# Verify that index processor was called and failed
@@ -622,18 +649,22 @@ class TestCleanDatasetTask:
# Verify results
# Check that all documents were deleted
remaining_documents = db_session_with_containers.query(Document).filter_by(dataset_id=dataset.id).all()
remaining_documents = db_session_with_containers.scalars(
select(Document).where(Document.dataset_id == dataset.id)
).all()
assert len(remaining_documents) == 0
# Check that all segments were deleted
remaining_segments = db_session_with_containers.query(DocumentSegment).filter_by(dataset_id=dataset.id).all()
remaining_segments = db_session_with_containers.scalars(
select(DocumentSegment).where(DocumentSegment.dataset_id == dataset.id)
).all()
assert len(remaining_segments) == 0
# Check that all image files were deleted from database
image_file_ids = [f.id for f in image_files]
remaining_image_files = (
db_session_with_containers.query(UploadFile).where(UploadFile.id.in_(image_file_ids)).all()
)
remaining_image_files = db_session_with_containers.scalars(
select(UploadFile).where(UploadFile.id.in_(image_file_ids))
).all()
assert len(remaining_image_files) == 0
# Verify that storage.delete was called for each image file
@@ -738,24 +769,32 @@ class TestCleanDatasetTask:
# Verify results
# Check that all documents were deleted
remaining_documents = db_session_with_containers.query(Document).filter_by(dataset_id=dataset.id).all()
remaining_documents = db_session_with_containers.scalars(
select(Document).where(Document.dataset_id == dataset.id)
).all()
assert len(remaining_documents) == 0
# Check that all segments were deleted
remaining_segments = db_session_with_containers.query(DocumentSegment).filter_by(dataset_id=dataset.id).all()
remaining_segments = db_session_with_containers.scalars(
select(DocumentSegment).where(DocumentSegment.dataset_id == dataset.id)
).all()
assert len(remaining_segments) == 0
# Check that all upload files were deleted
remaining_files = db_session_with_containers.query(UploadFile).where(UploadFile.id.in_(upload_file_ids)).all()
remaining_files = db_session_with_containers.scalars(
select(UploadFile).where(UploadFile.id.in_(upload_file_ids))
).all()
assert len(remaining_files) == 0
# Check that all metadata and bindings were deleted
remaining_metadata = db_session_with_containers.query(DatasetMetadata).filter_by(dataset_id=dataset.id).all()
remaining_metadata = db_session_with_containers.scalars(
select(DatasetMetadata).where(DatasetMetadata.dataset_id == dataset.id)
).all()
assert len(remaining_metadata) == 0
remaining_bindings = (
db_session_with_containers.query(DatasetMetadataBinding).filter_by(dataset_id=dataset.id).all()
)
remaining_bindings = db_session_with_containers.scalars(
select(DatasetMetadataBinding).where(DatasetMetadataBinding.dataset_id == dataset.id)
).all()
assert len(remaining_bindings) == 0
# Verify performance expectations
@@ -826,7 +865,9 @@ class TestCleanDatasetTask:
# Check whether the upload file row survived despite the storage failure
# Note: when storage operations fail, the upload file row may or may not be deleted;
# either way, the cleanup process should continue past storage errors
remaining_files = db_session_with_containers.query(UploadFile).filter_by(id=upload_file.id).all()
remaining_files = db_session_with_containers.scalars(
select(UploadFile).where(UploadFile.id == upload_file.id)
).all()
# Ideally the upload file is still removed from the database even if storage cleanup fails,
# but the exact outcome depends on how clean_dataset_task handles the error
if len(remaining_files) > 0:
@@ -976,19 +1017,27 @@ class TestCleanDatasetTask:
# Verify results
# Check that all documents were deleted
remaining_documents = db_session_with_containers.query(Document).filter_by(dataset_id=dataset.id).all()
remaining_documents = db_session_with_containers.scalars(
select(Document).where(Document.dataset_id == dataset.id)
).all()
assert len(remaining_documents) == 0
# Check that all segments were deleted
remaining_segments = db_session_with_containers.query(DocumentSegment).filter_by(dataset_id=dataset.id).all()
remaining_segments = db_session_with_containers.scalars(
select(DocumentSegment).where(DocumentSegment.dataset_id == dataset.id)
).all()
assert len(remaining_segments) == 0
# Check that all upload files were deleted
remaining_files = db_session_with_containers.query(UploadFile).filter_by(id=upload_file_id).all()
remaining_files = db_session_with_containers.scalars(
select(UploadFile).where(UploadFile.id == upload_file_id)
).all()
assert len(remaining_files) == 0
# Check that all metadata was deleted
remaining_metadata = db_session_with_containers.query(DatasetMetadata).filter_by(dataset_id=dataset.id).all()
remaining_metadata = db_session_with_containers.scalars(
select(DatasetMetadata).where(DatasetMetadata.dataset_id == dataset.id)
).all()
assert len(remaining_metadata) == 0
# Verify that storage.delete was called

View File

@@ -11,6 +11,7 @@ from unittest.mock import Mock, patch
import pytest
from faker import Faker
from sqlalchemy import func, select
from core.rag.index_processor.constant.index_type import IndexStructureType
from models.dataset import Dataset, Document, DocumentSegment
@@ -145,11 +146,16 @@ class TestCleanNotionDocumentTask:
db_session_with_containers.commit()
# Verify data exists before cleanup
assert db_session_with_containers.query(Document).filter(Document.id.in_(document_ids)).count() == 3
assert (
db_session_with_containers.query(DocumentSegment)
.filter(DocumentSegment.document_id.in_(document_ids))
.count()
db_session_with_containers.scalar(
select(func.count()).select_from(Document).where(Document.id.in_(document_ids))
)
== 3
)
assert (
db_session_with_containers.scalar(
select(func.count()).select_from(DocumentSegment).where(DocumentSegment.document_id.in_(document_ids))
)
== 6
)
@@ -158,9 +164,9 @@ class TestCleanNotionDocumentTask:
# Verify segments are deleted
assert (
db_session_with_containers.query(DocumentSegment)
.filter(DocumentSegment.document_id.in_(document_ids))
.count()
db_session_with_containers.scalar(
select(func.count()).select_from(DocumentSegment).where(DocumentSegment.document_id.in_(document_ids))
)
== 0
)
@@ -323,9 +329,9 @@ class TestCleanNotionDocumentTask:
# Verify segments are deleted
assert (
db_session_with_containers.query(DocumentSegment)
.filter(DocumentSegment.document_id == document.id)
.count()
db_session_with_containers.scalar(
select(func.count()).select_from(DocumentSegment).where(DocumentSegment.document_id == document.id)
)
== 0
)
@@ -411,7 +417,9 @@ class TestCleanNotionDocumentTask:
# Verify segments are deleted
assert (
db_session_with_containers.query(DocumentSegment).filter(DocumentSegment.document_id == document.id).count()
db_session_with_containers.scalar(
select(func.count()).select_from(DocumentSegment).where(DocumentSegment.document_id == document.id)
)
== 0
)
@@ -499,9 +507,16 @@ class TestCleanNotionDocumentTask:
db_session_with_containers.commit()
# Verify all data exists before cleanup
assert db_session_with_containers.query(Document).filter(Document.dataset_id == dataset.id).count() == 5
assert (
db_session_with_containers.query(DocumentSegment).filter(DocumentSegment.dataset_id == dataset.id).count()
db_session_with_containers.scalar(
select(func.count()).select_from(Document).where(Document.dataset_id == dataset.id)
)
== 5
)
assert (
db_session_with_containers.scalar(
select(func.count()).select_from(DocumentSegment).where(DocumentSegment.dataset_id == dataset.id)
)
== 10
)
@@ -514,19 +529,26 @@ class TestCleanNotionDocumentTask:
# Verify only specified documents' segments are deleted
assert (
db_session_with_containers.query(DocumentSegment)
.filter(DocumentSegment.document_id.in_(documents_to_clean))
.count()
db_session_with_containers.scalar(
select(func.count())
.select_from(DocumentSegment)
.where(DocumentSegment.document_id.in_(documents_to_clean))
)
== 0
)
# Verify remaining documents and segments are intact
remaining_docs = [doc.id for doc in documents[3:]]
assert db_session_with_containers.query(Document).filter(Document.id.in_(remaining_docs)).count() == 2
assert (
db_session_with_containers.query(DocumentSegment)
.filter(DocumentSegment.document_id.in_(remaining_docs))
.count()
db_session_with_containers.scalar(
select(func.count()).select_from(Document).where(Document.id.in_(remaining_docs))
)
== 2
)
assert (
db_session_with_containers.scalar(
select(func.count()).select_from(DocumentSegment).where(DocumentSegment.document_id.in_(remaining_docs))
)
== 4
)
@@ -613,7 +635,9 @@ class TestCleanNotionDocumentTask:
# Verify all segments exist before cleanup
assert (
db_session_with_containers.query(DocumentSegment).filter(DocumentSegment.document_id == document.id).count()
db_session_with_containers.scalar(
select(func.count()).select_from(DocumentSegment).where(DocumentSegment.document_id == document.id)
)
== 4
)
@@ -622,7 +646,9 @@ class TestCleanNotionDocumentTask:
# Verify all segments are deleted regardless of status
assert (
db_session_with_containers.query(DocumentSegment).filter(DocumentSegment.document_id == document.id).count()
db_session_with_containers.scalar(
select(func.count()).select_from(DocumentSegment).where(DocumentSegment.document_id == document.id)
)
== 0
)
@@ -795,11 +821,15 @@ class TestCleanNotionDocumentTask:
# Verify all data exists before cleanup
assert (
db_session_with_containers.query(Document).filter(Document.dataset_id == dataset.id).count()
db_session_with_containers.scalar(
select(func.count()).select_from(Document).where(Document.dataset_id == dataset.id)
)
== num_documents
)
assert (
db_session_with_containers.query(DocumentSegment).filter(DocumentSegment.dataset_id == dataset.id).count()
db_session_with_containers.scalar(
select(func.count()).select_from(DocumentSegment).where(DocumentSegment.dataset_id == dataset.id)
)
== num_documents * num_segments_per_doc
)
@@ -809,7 +839,9 @@ class TestCleanNotionDocumentTask:
# Verify all segments are deleted
assert (
db_session_with_containers.query(DocumentSegment).filter(DocumentSegment.dataset_id == dataset.id).count()
db_session_with_containers.scalar(
select(func.count()).select_from(DocumentSegment).where(DocumentSegment.dataset_id == dataset.id)
)
== 0
)
@@ -906,8 +938,8 @@ class TestCleanNotionDocumentTask:
# Verify all data exists before cleanup
# Note: There may be documents from previous tests, so we check for at least 3
assert db_session_with_containers.query(Document).count() >= 3
assert db_session_with_containers.query(DocumentSegment).count() >= 9
assert db_session_with_containers.scalar(select(func.count()).select_from(Document)) >= 3
assert db_session_with_containers.scalar(select(func.count()).select_from(DocumentSegment)) >= 9
# Clean up documents from only the first dataset
target_dataset = datasets[0]
@@ -919,19 +951,26 @@ class TestCleanNotionDocumentTask:
# Verify only documents' segments from target dataset are deleted
assert (
db_session_with_containers.query(DocumentSegment)
.filter(DocumentSegment.document_id == target_document.id)
.count()
db_session_with_containers.scalar(
select(func.count())
.select_from(DocumentSegment)
.where(DocumentSegment.document_id == target_document.id)
)
== 0
)
# Verify documents from other datasets remain intact
remaining_docs = [doc.id for doc in all_documents[1:]]
assert db_session_with_containers.query(Document).filter(Document.id.in_(remaining_docs)).count() == 2
assert (
db_session_with_containers.query(DocumentSegment)
.filter(DocumentSegment.document_id.in_(remaining_docs))
.count()
db_session_with_containers.scalar(
select(func.count()).select_from(Document).where(Document.id.in_(remaining_docs))
)
== 2
)
assert (
db_session_with_containers.scalar(
select(func.count()).select_from(DocumentSegment).where(DocumentSegment.document_id.in_(remaining_docs))
)
== 6
)
@@ -1028,11 +1067,13 @@ class TestCleanNotionDocumentTask:
db_session_with_containers.commit()
# Verify all data exists before cleanup
assert db_session_with_containers.query(Document).filter(Document.dataset_id == dataset.id).count() == len(
document_statuses
)
assert db_session_with_containers.scalar(
select(func.count()).select_from(Document).where(Document.dataset_id == dataset.id)
) == len(document_statuses)
assert (
db_session_with_containers.query(DocumentSegment).filter(DocumentSegment.dataset_id == dataset.id).count()
db_session_with_containers.scalar(
select(func.count()).select_from(DocumentSegment).where(DocumentSegment.dataset_id == dataset.id)
)
== len(document_statuses) * 2
)
@@ -1042,7 +1083,9 @@ class TestCleanNotionDocumentTask:
# Verify all segments are deleted regardless of status
assert (
db_session_with_containers.query(DocumentSegment).filter(DocumentSegment.dataset_id == dataset.id).count()
db_session_with_containers.scalar(
select(func.count()).select_from(DocumentSegment).where(DocumentSegment.dataset_id == dataset.id)
)
== 0
)
@@ -1142,9 +1185,16 @@ class TestCleanNotionDocumentTask:
db_session_with_containers.commit()
# Verify data exists before cleanup
assert db_session_with_containers.query(Document).filter(Document.id == document.id).count() == 1
assert (
db_session_with_containers.query(DocumentSegment).filter(DocumentSegment.document_id == document.id).count()
db_session_with_containers.scalar(
select(func.count()).select_from(Document).where(Document.id == document.id)
)
== 1
)
assert (
db_session_with_containers.scalar(
select(func.count()).select_from(DocumentSegment).where(DocumentSegment.document_id == document.id)
)
== 3
)
@@ -1153,7 +1203,9 @@ class TestCleanNotionDocumentTask:
# Verify segments are deleted
assert (
db_session_with_containers.query(DocumentSegment).filter(DocumentSegment.document_id == document.id).count()
db_session_with_containers.scalar(
select(func.count()).select_from(DocumentSegment).where(DocumentSegment.document_id == document.id)
)
== 0
)

View File

@@ -6,6 +6,7 @@ from unittest.mock import MagicMock, patch
import pytest
from faker import Faker
from sqlalchemy import select
from core.indexing_runner import DocumentIsPausedError
from core.rag.index_processor.constant.index_type import IndexTechniqueType
@@ -175,7 +176,7 @@ class TestDatasetIndexingTaskIntegration:
def _query_document(self, db_session_with_containers, document_id: str) -> Document | None:
"""Return the latest persisted document state."""
return db_session_with_containers.query(Document).where(Document.id == document_id).first()
return db_session_with_containers.scalar(select(Document).where(Document.id == document_id).limit(1))
def _assert_documents_parsing(self, db_session_with_containers, document_ids: Sequence[str]) -> None:
"""Assert all target documents are persisted in parsing status."""

View File

@@ -9,6 +9,7 @@ The task is responsible for removing document segments from the search index whe
from unittest.mock import MagicMock, patch
from faker import Faker
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
@@ -471,9 +472,9 @@ class TestDisableSegmentsFromIndexTask:
db_session_with_containers.refresh(segments[1])
# Check that segments are re-enabled after error
updated_segments = (
db_session_with_containers.query(DocumentSegment).where(DocumentSegment.id.in_(segment_ids)).all()
)
updated_segments = db_session_with_containers.scalars(
select(DocumentSegment).where(DocumentSegment.id.in_(segment_ids))
).all()
for segment in updated_segments:
assert segment.enabled is True

View File

@@ -12,6 +12,7 @@ from unittest.mock import Mock, patch
from uuid import uuid4
import pytest
from sqlalchemy import delete, func, select, update
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
@@ -254,8 +255,8 @@ class TestDocumentIndexingSyncTask:
"""Test that task raises error when data_source_info is empty."""
# Arrange
context = self._create_notion_sync_context(db_session_with_containers, data_source_info=None)
db_session_with_containers.query(Document).where(Document.id == context["document"].id).update(
{"data_source_info": None}
db_session_with_containers.execute(
update(Document).where(Document.id == context["document"].id).values(data_source_info=None)
)
db_session_with_containers.commit()
@@ -274,8 +275,8 @@ class TestDocumentIndexingSyncTask:
# Assert
db_session_with_containers.expire_all()
updated_document = (
db_session_with_containers.query(Document).where(Document.id == context["document"].id).first()
updated_document = db_session_with_containers.scalar(
select(Document).where(Document.id == context["document"].id).limit(1)
)
assert updated_document is not None
assert updated_document.indexing_status == IndexingStatus.ERROR
@@ -294,13 +295,13 @@ class TestDocumentIndexingSyncTask:
# Assert
db_session_with_containers.expire_all()
updated_document = (
db_session_with_containers.query(Document).where(Document.id == context["document"].id).first()
updated_document = db_session_with_containers.scalar(
select(Document).where(Document.id == context["document"].id).limit(1)
)
remaining_segments = (
db_session_with_containers.query(DocumentSegment)
remaining_segments = db_session_with_containers.scalar(
select(func.count())
.select_from(DocumentSegment)
.where(DocumentSegment.document_id == context["document"].id)
.count()
)
assert updated_document is not None
assert updated_document.indexing_status == IndexingStatus.COMPLETED
@@ -319,13 +320,13 @@ class TestDocumentIndexingSyncTask:
# Assert
db_session_with_containers.expire_all()
updated_document = (
db_session_with_containers.query(Document).where(Document.id == context["document"].id).first()
updated_document = db_session_with_containers.scalar(
select(Document).where(Document.id == context["document"].id).limit(1)
)
remaining_segments = (
db_session_with_containers.query(DocumentSegment)
remaining_segments = db_session_with_containers.scalar(
select(func.count())
.select_from(DocumentSegment)
.where(DocumentSegment.document_id == context["document"].id)
.count()
)
assert updated_document is not None
@@ -354,7 +355,7 @@ class TestDocumentIndexingSyncTask:
context = self._create_notion_sync_context(db_session_with_containers)
def _delete_dataset_before_clean() -> str:
db_session_with_containers.query(Dataset).where(Dataset.id == context["dataset"].id).delete()
db_session_with_containers.execute(delete(Dataset).where(Dataset.id == context["dataset"].id))
db_session_with_containers.commit()
return "2024-01-02T00:00:00Z"
@@ -367,8 +368,8 @@ class TestDocumentIndexingSyncTask:
# Assert
db_session_with_containers.expire_all()
updated_document = (
db_session_with_containers.query(Document).where(Document.id == context["document"].id).first()
updated_document = db_session_with_containers.scalar(
select(Document).where(Document.id == context["document"].id).limit(1)
)
assert updated_document is not None
assert updated_document.indexing_status == IndexingStatus.PARSING
@@ -386,13 +387,13 @@ class TestDocumentIndexingSyncTask:
# Assert
db_session_with_containers.expire_all()
updated_document = (
db_session_with_containers.query(Document).where(Document.id == context["document"].id).first()
updated_document = db_session_with_containers.scalar(
select(Document).where(Document.id == context["document"].id).limit(1)
)
remaining_segments = (
db_session_with_containers.query(DocumentSegment)
remaining_segments = db_session_with_containers.scalar(
select(func.count())
.select_from(DocumentSegment)
.where(DocumentSegment.document_id == context["document"].id)
.count()
)
assert updated_document is not None
assert updated_document.indexing_status == IndexingStatus.PARSING
@@ -410,8 +411,8 @@ class TestDocumentIndexingSyncTask:
# Assert
db_session_with_containers.expire_all()
updated_document = (
db_session_with_containers.query(Document).where(Document.id == context["document"].id).first()
updated_document = db_session_with_containers.scalar(
select(Document).where(Document.id == context["document"].id).limit(1)
)
assert updated_document is not None
assert updated_document.indexing_status == IndexingStatus.PARSING
@@ -428,8 +429,8 @@ class TestDocumentIndexingSyncTask:
# Assert
db_session_with_containers.expire_all()
updated_document = (
db_session_with_containers.query(Document).where(Document.id == context["document"].id).first()
updated_document = db_session_with_containers.scalar(
select(Document).where(Document.id == context["document"].id).limit(1)
)
assert updated_document is not None
assert updated_document.indexing_status == IndexingStatus.ERROR
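
The same migration covers bulk writes in this file: Query.update() and Query.delete() give way to executing update() and delete() constructs. A condensed sketch, with identifiers standing in for the test fixtures:

from sqlalchemy import delete, update
from sqlalchemy.orm import Session

from models.dataset import Dataset, Document


def clear_sync_fixture(session: Session, document_id: str, dataset_id: str) -> None:
    """2.0-style equivalents of the legacy bulk-write calls replaced above (sketch)."""
    # Was: session.query(Document).where(Document.id == document_id).update({"data_source_info": None})
    session.execute(update(Document).where(Document.id == document_id).values(data_source_info=None))
    # Was: session.query(Dataset).where(Dataset.id == dataset_id).delete()
    session.execute(delete(Dataset).where(Dataset.id == dataset_id))
    session.commit()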

View File

@@ -2,6 +2,7 @@ from unittest.mock import MagicMock, patch
import pytest
from faker import Faker
from sqlalchemy import func, select
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
@@ -123,13 +124,13 @@ class TestDocumentIndexingUpdateTask:
db_session_with_containers.expire_all()
# Assert document status updated before reindex
updated = db_session_with_containers.query(Document).where(Document.id == document.id).first()
updated = db_session_with_containers.scalar(select(Document).where(Document.id == document.id).limit(1))
assert updated.indexing_status == IndexingStatus.PARSING
assert updated.processing_started_at is not None
# Segments should be deleted
remaining = (
db_session_with_containers.query(DocumentSegment).where(DocumentSegment.document_id == document.id).count()
remaining = db_session_with_containers.scalar(
select(func.count()).select_from(DocumentSegment).where(DocumentSegment.document_id == document.id)
)
assert remaining == 0
@@ -167,8 +168,8 @@ class TestDocumentIndexingUpdateTask:
mock_external_dependencies["runner_instance"].run.assert_called_once()
# Segments should remain (since clean failed before DB delete)
remaining = (
db_session_with_containers.query(DocumentSegment).where(DocumentSegment.document_id == document.id).count()
remaining = db_session_with_containers.scalar(
select(func.count()).select_from(DocumentSegment).where(DocumentSegment.document_id == document.id)
)
assert remaining > 0

View File

@@ -145,7 +145,7 @@ def test_inner_api_config_exist(monkeypatch: pytest.MonkeyPatch):
def test_db_extras_options_merging(monkeypatch: pytest.MonkeyPatch):
"""Test that DB_EXTRAS options are properly merged with default timezone setting"""
"""Test that DB_EXTRAS options are merged with the default timezone startup option."""
# Set environment variables
monkeypatch.setenv("DB_TYPE", "postgresql")
monkeypatch.setenv("DB_USERNAME", "postgres")
@@ -158,15 +158,28 @@ def test_db_extras_options_merging(monkeypatch: pytest.MonkeyPatch):
# Create config
config = DifyConfig()
# Get engine options
engine_options = config.SQLALCHEMY_ENGINE_OPTIONS
# Verify options contains both search_path and timezone
options = engine_options["connect_args"]["options"]
options = config.SQLALCHEMY_ENGINE_OPTIONS["connect_args"]["options"]
assert "search_path=myschema" in options
assert "timezone=UTC" in options
def test_db_session_timezone_override_can_disable_app_level_timezone_injection(monkeypatch: pytest.MonkeyPatch):
monkeypatch.setenv("DB_TYPE", "postgresql")
monkeypatch.setenv("DB_USERNAME", "postgres")
monkeypatch.setenv("DB_PASSWORD", "postgres")
monkeypatch.setenv("DB_HOST", "localhost")
monkeypatch.setenv("DB_PORT", "5432")
monkeypatch.setenv("DB_DATABASE", "dify")
monkeypatch.setenv("DB_EXTRAS", "options=-c search_path=myschema")
monkeypatch.setenv("DB_SESSION_TIMEZONE_OVERRIDE", "")
config = DifyConfig()
assert config.SQLALCHEMY_ENGINE_OPTIONS["connect_args"] == {
"options": "-c search_path=myschema",
}
def test_pubsub_redis_url_default(monkeypatch: pytest.MonkeyPatch):
os.environ.clear()
@@ -223,6 +236,41 @@ def test_pubsub_redis_url_required_when_default_unavailable(monkeypatch: pytest.
_ = DifyConfig().normalized_pubsub_redis_url
def test_dify_config_exposes_redis_key_prefix_default(monkeypatch: pytest.MonkeyPatch):
os.environ.clear()
monkeypatch.setenv("CONSOLE_API_URL", "https://example.com")
monkeypatch.setenv("CONSOLE_WEB_URL", "https://example.com")
monkeypatch.setenv("DB_TYPE", "postgresql")
monkeypatch.setenv("DB_USERNAME", "postgres")
monkeypatch.setenv("DB_PASSWORD", "postgres")
monkeypatch.setenv("DB_HOST", "localhost")
monkeypatch.setenv("DB_PORT", "5432")
monkeypatch.setenv("DB_DATABASE", "dify")
config = DifyConfig(_env_file=None)
assert config.REDIS_KEY_PREFIX == ""
def test_dify_config_reads_redis_key_prefix_from_env(monkeypatch: pytest.MonkeyPatch):
os.environ.clear()
monkeypatch.setenv("CONSOLE_API_URL", "https://example.com")
monkeypatch.setenv("CONSOLE_WEB_URL", "https://example.com")
monkeypatch.setenv("DB_TYPE", "postgresql")
monkeypatch.setenv("DB_USERNAME", "postgres")
monkeypatch.setenv("DB_PASSWORD", "postgres")
monkeypatch.setenv("DB_HOST", "localhost")
monkeypatch.setenv("DB_PORT", "5432")
monkeypatch.setenv("DB_DATABASE", "dify")
monkeypatch.setenv("REDIS_KEY_PREFIX", "enterprise-a")
config = DifyConfig(_env_file=None)
assert config.REDIS_KEY_PREFIX == "enterprise-a"
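
These two config tests only pin down that REDIS_KEY_PREFIX defaults to an empty string and is read from the environment; how the prefix gets applied is not shown in this file. As a purely hypothetical illustration of the intent, a key-building helper could look like the sketch below (the function name, separator, and behavior are assumptions, not the application's actual API).

def prefixed_redis_key(key: str, prefix: str = "") -> str:
    """Hypothetical helper: prepend an optional global prefix to a Redis key.

    An empty prefix leaves the key unchanged, matching the unprefixed default
    the tests above assert for the config field.
    """
    return f"{prefix}:{key}" if prefix else key


# prefixed_redis_key("celery", "enterprise-a") -> "enterprise-a:celery"
# prefixed_redis_key("celery")                 -> "celery"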
@pytest.mark.parametrize(
("broker_url", "expected_host", "expected_port", "expected_username", "expected_password", "expected_db"),
[

View File

@@ -138,12 +138,15 @@ def app_models(app_module):
def patch_signed_url(monkeypatch, app_module):
"""Ensure icon URL generation uses a deterministic helper for tests."""
def _fake_signed_url(key: str | None) -> str | None:
if not key:
def _fake_build_icon_url(_icon_type, key: str | None) -> str | None:
if key is None:
return None
icon_type = str(_icon_type).lower()
if icon_type != "image":
return None
return f"signed:{key}"
monkeypatch.setattr(app_module.file_helpers, "get_signed_file_url", _fake_signed_url)
monkeypatch.setattr(app_module, "build_icon_url", _fake_build_icon_url)
def _ts(hour: int = 12) -> datetime:

View File

@@ -1,42 +0,0 @@
from datetime import datetime
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
from controllers.console.app.conversation import _get_conversation
def test_get_conversation_mark_read_keeps_updated_at_unchanged():
app_model = SimpleNamespace(id="app-id")
account = SimpleNamespace(id="account-id")
conversation = MagicMock()
conversation.id = "conversation-id"
with (
patch(
"controllers.console.app.conversation.current_account_with_tenant",
return_value=(account, None),
autospec=True,
),
patch(
"controllers.console.app.conversation.naive_utc_now",
return_value=datetime(2026, 2, 9, 0, 0, 0),
autospec=True,
),
patch("controllers.console.app.conversation.db.session", autospec=True) as mock_session,
):
mock_session.scalar.return_value = conversation
_get_conversation(app_model, "conversation-id")
statement = mock_session.execute.call_args[0][0]
compiled = statement.compile()
sql_text = str(compiled).lower()
compact_sql_text = sql_text.replace(" ", "")
params = compiled.params
assert "updated_at=current_timestamp" not in compact_sql_text
assert "updated_at=conversations.updated_at" in compact_sql_text
assert "read_at=:read_at" in compact_sql_text
assert "read_account_id=:read_account_id" in compact_sql_text
assert params["read_at"] == datetime(2026, 2, 9, 0, 0, 0)
assert params["read_account_id"] == "account-id"

View File

@@ -0,0 +1,108 @@
from __future__ import annotations
from contextlib import nullcontext
from datetime import UTC, datetime
from types import SimpleNamespace
import pytest
from graphon.variables.types import SegmentType
from pydantic import ValidationError
from controllers.console.app import conversation_variables as conversation_variables_module
def _unwrap(func):
bound_self = getattr(func, "__self__", None)
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
if bound_self is not None:
return func.__get__(bound_self, bound_self.__class__)
return func
def test_get_conversation_variables_returns_paginated_response(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = conversation_variables_module.ConversationVariablesApi()
method = _unwrap(api.get)
created_at = datetime(2026, 1, 1, tzinfo=UTC)
updated_at = datetime(2026, 1, 2, tzinfo=UTC)
row = SimpleNamespace(
created_at=created_at,
updated_at=updated_at,
to_variable=lambda: SimpleNamespace(
model_dump=lambda: {
"id": "var-1",
"name": "my_var",
"value_type": "string",
"value": "value",
"description": "desc",
}
),
)
session = SimpleNamespace(scalars=lambda _stmt: SimpleNamespace(all=lambda: [row]))
monkeypatch.setattr(conversation_variables_module, "db", SimpleNamespace(engine=object()))
monkeypatch.setattr(
conversation_variables_module,
"sessionmaker",
lambda *_args, **_kwargs: SimpleNamespace(begin=lambda: nullcontext(session)),
)
with app.test_request_context(
"/console/api/apps/app-1/conversation-variables",
method="GET",
query_string={"conversation_id": "conv-1"},
):
response = method(app_model=SimpleNamespace(id="app-1"))
assert response["page"] == 1
assert response["limit"] == 100
assert response["total"] == 1
assert response["has_more"] is False
assert response["data"][0]["id"] == "var-1"
assert response["data"][0]["created_at"] == int(created_at.timestamp())
assert response["data"][0]["updated_at"] == int(updated_at.timestamp())
def test_get_conversation_variables_normalizes_value_type_and_value(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = conversation_variables_module.ConversationVariablesApi()
method = _unwrap(api.get)
row = SimpleNamespace(
created_at=None,
updated_at=None,
to_variable=lambda: SimpleNamespace(
model_dump=lambda: {
"id": "var-2",
"name": "my_var_2",
"value_type": SegmentType.INTEGER,
"value": 42,
"description": None,
}
),
)
session = SimpleNamespace(scalars=lambda _stmt: SimpleNamespace(all=lambda: [row]))
monkeypatch.setattr(conversation_variables_module, "db", SimpleNamespace(engine=object()))
monkeypatch.setattr(
conversation_variables_module,
"sessionmaker",
lambda *_args, **_kwargs: SimpleNamespace(begin=lambda: nullcontext(session)),
)
with app.test_request_context(
"/console/api/apps/app-1/conversation-variables",
method="GET",
query_string={"conversation_id": "conv-1"},
):
response = method(app_model=SimpleNamespace(id="app-1"))
assert response["data"][0]["value_type"] == "number"
assert response["data"][0]["value"] == "42"
def test_get_conversation_variables_requires_conversation_id(app) -> None:
api = conversation_variables_module.ConversationVariablesApi()
method = _unwrap(api.get)
with app.test_request_context("/console/api/apps/app-1/conversation-variables", method="GET"):
with pytest.raises(ValidationError):
method(app_model=SimpleNamespace(id="app-1"))
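The _unwrap helper at the top of this file exists because console resource methods are normally wrapped in decorators (authentication and setup checks, for example) that expect a live request session; walking __wrapped__ lets the tests call the undecorated body while keeping it bound to the resource instance. A minimal, self-contained illustration of the pattern (the require_login decorator is invented for demonstration):

```python
import functools


def require_login(func):
    """Stand-in auth decorator; functools.wraps exposes the original via __wrapped__."""

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        raise RuntimeError("no authenticated session in unit tests")

    return wrapper


def _unwrap(func):
    # Same helper as defined in the test module above.
    bound_self = getattr(func, "__self__", None)
    while hasattr(func, "__wrapped__"):
        func = func.__wrapped__
    if bound_self is not None:
        return func.__get__(bound_self, bound_self.__class__)
    return func


class DemoApi:
    @require_login
    def get(self, app_model):
        return {"app_id": app_model}


api = DemoApi()
bare = _unwrap(api.get)  # peels off require_login, rebinds the body to `api`
assert bare(app_model="app-1") == {"app_id": "app-1"}
```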

View File

@@ -1,5 +1,7 @@
from __future__ import annotations
from datetime import UTC, datetime
import pytest
from controllers.console.app import message as message_module
@@ -120,3 +122,24 @@ def test_suggested_questions_response(app, monkeypatch: pytest.MonkeyPatch) -> N
response = message_module.SuggestedQuestionsResponse(data=["What is AI?", "How does ML work?"])
assert len(response.data) == 2
assert response.data[0] == "What is AI?"
def test_message_detail_response_normalizes_aliases_and_timestamp(app, monkeypatch: pytest.MonkeyPatch) -> None:
"""Test MessageDetailResponse normalizes alias fields and datetime timestamps."""
created_at = datetime(2026, 1, 2, 3, 4, 5, tzinfo=UTC)
response = message_module.MessageDetailResponse.model_validate(
{
"id": "550e8400-e29b-41d4-a716-446655440000",
"conversation_id": "550e8400-e29b-41d4-a716-446655440001",
"inputs": {"foo": "bar"},
"query": "hello",
"re_sign_file_url_answer": "world",
"from_source": "user",
"status": "normal",
"created_at": created_at,
"message_metadata_dict": {"token_usage": 3},
}
)
assert response.answer == "world"
assert response.metadata == {"token_usage": 3}
assert response.created_at == int(created_at.timestamp())
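The assertions above hinge on alias handling: the internal names re_sign_file_url_answer and message_metadata_dict are accepted on input but exposed as answer and metadata, and the datetime is collapsed to an epoch integer. The real MessageDetailResponse lives in the controller module; a stripped-down sketch of the same pydantic pattern (the field set and validator below are assumptions, only the aliases match the test) might look like this:

```python
from datetime import UTC, datetime

from pydantic import AliasChoices, BaseModel, Field, field_validator


class MiniMessageResponse(BaseModel):
    # Accept either the public name or the internal/ORM name on input.
    answer: str = Field(validation_alias=AliasChoices("answer", "re_sign_file_url_answer"))
    metadata: dict = Field(
        default_factory=dict,
        validation_alias=AliasChoices("metadata", "message_metadata_dict"),
    )
    created_at: int

    @field_validator("created_at", mode="before")
    @classmethod
    def _datetime_to_epoch(cls, value):
        # Collapse datetimes to unix seconds; pass integers through untouched.
        if isinstance(value, datetime):
            return int(value.timestamp())
        return value


payload = MiniMessageResponse.model_validate(
    {
        "re_sign_file_url_answer": "world",
        "message_metadata_dict": {"token_usage": 3},
        "created_at": datetime(2026, 1, 2, 3, 4, 5, tzinfo=UTC),
    }
)
assert payload.answer == "world" and payload.metadata == {"token_usage": 3}
```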

View File

@@ -0,0 +1,85 @@
from __future__ import annotations
from datetime import UTC, datetime
from graphon.enums import WorkflowExecutionStatus
from controllers.console.app import workflow_app_log as workflow_app_log_module
def test_workflow_app_log_query_parses_bool_and_datetime():
query = workflow_app_log_module.WorkflowAppLogQuery.model_validate(
{
"detail": "true",
"created_at__before": "2026-01-02T03:04:05Z",
"page": "2",
"limit": "10",
}
)
assert query.detail is True
assert query.created_at__before == datetime(2026, 1, 2, 3, 4, 5, tzinfo=UTC)
assert query.page == 2
assert query.limit == 10
def test_workflow_app_log_pagination_response_normalizes_nested_fields():
created_at = datetime(2026, 1, 2, 3, 4, 5, tzinfo=UTC)
response = workflow_app_log_module.WorkflowAppLogPaginationResponse.model_validate(
{
"page": 1,
"limit": 20,
"total": 1,
"has_more": False,
"data": [
{
"id": "log-1",
"workflow_run": {
"id": "run-1",
"status": WorkflowExecutionStatus.SUCCEEDED,
"created_at": created_at,
"finished_at": created_at,
},
"details": {"trigger_metadata": {}},
"created_by_account": {"id": "acc-1", "name": "acc", "email": "acc@example.com"},
"created_at": created_at,
}
],
}
).model_dump(mode="json")
assert response["data"][0]["workflow_run"]["status"] == "succeeded"
assert response["data"][0]["workflow_run"]["created_at"] == int(created_at.timestamp())
assert response["data"][0]["created_at"] == int(created_at.timestamp())
def test_workflow_archived_log_pagination_response_normalizes_nested_fields():
created_at = datetime(2026, 1, 2, 3, 4, 5, tzinfo=UTC)
response = workflow_app_log_module.WorkflowArchivedLogPaginationResponse.model_validate(
{
"page": 1,
"limit": 20,
"total": 1,
"has_more": False,
"data": [
{
"id": "archived-1",
"workflow_run": {
"id": "run-1",
"status": WorkflowExecutionStatus.FAILED,
},
"trigger_metadata": {"type": "trigger-plugin"},
"created_by_end_user": {
"id": "eu-1",
"type": "anonymous",
"is_anonymous": True,
"session_id": "session-1",
},
"created_at": created_at,
}
],
}
).model_dump(mode="json")
assert response["data"][0]["workflow_run"]["status"] == "failed"
assert response["data"][0]["created_at"] == int(created_at.timestamp())

View File

@@ -0,0 +1,54 @@
from __future__ import annotations
from datetime import UTC, datetime
from types import SimpleNamespace
from controllers.console.app import workflow_trigger as workflow_trigger_module
def test_parser_models_validate():
parser = workflow_trigger_module.Parser(node_id="node-1")
enable_parser = workflow_trigger_module.ParserEnable(
trigger_id="550e8400-e29b-41d4-a716-446655440000", enable_trigger=True
)
assert parser.node_id == "node-1"
assert enable_parser.enable_trigger is True
def test_workflow_trigger_response_serializes_datetime():
created_at = datetime(2026, 1, 2, 3, 4, 5, tzinfo=UTC)
trigger = SimpleNamespace(
id="trigger-1",
trigger_type="trigger-plugin",
title="Trigger",
node_id="node-1",
provider_name="provider",
icon="https://example.com/icon",
status="enabled",
created_at=created_at,
updated_at=created_at,
)
payload = workflow_trigger_module.WorkflowTriggerResponse.model_validate(trigger, from_attributes=True).model_dump(
mode="json"
)
assert payload["id"] == "trigger-1"
assert payload["created_at"] == "2026-01-02T03:04:05Z"
assert payload["updated_at"] == "2026-01-02T03:04:05Z"
def test_webhook_trigger_response_serializes_datetime():
created_at = datetime(2026, 1, 2, 3, 4, 5, tzinfo=UTC)
webhook = {
"id": "webhook-1",
"webhook_id": "whk-1",
"webhook_url": "https://example.com/hook",
"webhook_debug_url": "https://example.com/hook/debug",
"node_id": "node-1",
"created_at": created_at,
}
payload = workflow_trigger_module.WebhookTriggerResponse.model_validate(webhook).model_dump(mode="json")
assert payload["webhook_id"] == "whk-1"
assert payload["created_at"] == "2026-01-02T03:04:05Z"

Some files were not shown because too many files have changed in this diff.