diff --git a/.github/workflows/pyrefly-type-coverage-comment.yml b/.github/workflows/pyrefly-type-coverage-comment.yml
index 51f3ca54b6..974da99aad 100644
--- a/.github/workflows/pyrefly-type-coverage-comment.yml
+++ b/.github/workflows/pyrefly-type-coverage-comment.yml
@@ -62,7 +62,7 @@ jobs:
       - name: Render coverage markdown from structured data
         id: render
         run: |
-          comment_body="$(uv run --directory api python api/libs/pyrefly_type_coverage.py \
+          comment_body="$(uv run --directory api python libs/pyrefly_type_coverage.py \
             --base base_report.json \
             < pr_report.json)"
diff --git a/api/.env.example b/api/.env.example
index a04a18944a..beb820e797 100644
--- a/api/.env.example
+++ b/api/.env.example
@@ -57,6 +57,9 @@ REDIS_SSL_CERTFILE=
 REDIS_SSL_KEYFILE=
 # Path to client private key file for SSL authentication
 REDIS_DB=0
+# Optional global prefix for Redis keys, topics, streams, and Celery Redis transport artifacts.
+# Leave empty to preserve current unprefixed behavior.
+REDIS_KEY_PREFIX=
 
 # redis Sentinel configuration.
 REDIS_USE_SENTINEL=false
diff --git a/api/configs/middleware/__init__.py b/api/configs/middleware/__init__.py
index 817284d26f..c392b8840f 100644
--- a/api/configs/middleware/__init__.py
+++ b/api/configs/middleware/__init__.py
@@ -160,6 +160,16 @@ class DatabaseConfig(BaseSettings):
         default="",
     )
 
+    DB_SESSION_TIMEZONE_OVERRIDE: str = Field(
+        description=(
+            "PostgreSQL session timezone override injected via startup options."
+            " Default is 'UTC' for out-of-the-box consistency."
+            " Set to empty string to disable app-level timezone injection, for example when using RDS Proxy"
+            " together with a database-side default timezone."
+        ),
+        default="UTC",
+    )
+
     @computed_field  # type: ignore[prop-decorator]
     @property
     def SQLALCHEMY_DATABASE_URI_SCHEME(self) -> str:
@@ -227,12 +237,13 @@ class DatabaseConfig(BaseSettings):
         connect_args: dict[str, str] = {}
         # Use the dynamic SQLALCHEMY_DATABASE_URI_SCHEME property
         if self.SQLALCHEMY_DATABASE_URI_SCHEME.startswith("postgresql"):
-            timezone_opt = "-c timezone=UTC"
-            if options:
-                merged_options = f"{options} {timezone_opt}"
-            else:
-                merged_options = timezone_opt
-            connect_args = {"options": merged_options}
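+            # The timezone startup option is appended only when DB_SESSION_TIMEZONE_OVERRIDE is
+            # non-empty, so clearing the setting disables app-level timezone injection entirely.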
+            merged_options = options.strip()
+            session_timezone_override = self.DB_SESSION_TIMEZONE_OVERRIDE.strip()
+            if session_timezone_override:
+                timezone_opt = f"-c timezone={session_timezone_override}"
+                merged_options = f"{merged_options} {timezone_opt}".strip() if merged_options else timezone_opt
+            if merged_options:
+                connect_args = {"options": merged_options}
 
         result: SQLAlchemyEngineOptionsDict = {
             "pool_size": self.SQLALCHEMY_POOL_SIZE,
diff --git a/api/configs/middleware/cache/redis_config.py b/api/configs/middleware/cache/redis_config.py
index b49275758a..2def0a0d4e 100644
--- a/api/configs/middleware/cache/redis_config.py
+++ b/api/configs/middleware/cache/redis_config.py
@@ -32,6 +32,11 @@ class RedisConfig(BaseSettings):
         default=0,
     )
 
+    REDIS_KEY_PREFIX: str = Field(
+        description="Optional global prefix for Redis keys, topics, and transport artifacts",
+        default="",
+    )
+
     REDIS_USE_SSL: bool = Field(
         description="Enable SSL/TLS for the Redis connection",
         default=False,
diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py
index 2018f60215..75569eb596 100644
--- a/api/controllers/console/app/app.py
+++ b/api/controllers/console/app/app.py
@@ -6,7 +6,6 @@ from typing import Any, Literal
 from flask import request
 from flask_restx import Resource
 from graphon.enums import WorkflowExecutionStatus
-from graphon.file import helpers as file_helpers
 from pydantic import AliasChoices, BaseModel, Field, computed_field, field_validator
 from sqlalchemy import select
 from sqlalchemy.orm import sessionmaker
@@ -31,6 +30,7 @@ from core.rag.retrieval.retrieval_methods import RetrievalMethod
 from core.trigger.constants import TRIGGER_NODE_TYPES
 from extensions.ext_database import db
 from fields.base import ResponseModel
+from libs.helper import build_icon_url
 from libs.login import current_account_with_tenant, login_required
 from models import App, DatasetPermissionEnum, Workflow
 from models.model import IconType
@@ -161,15 +161,6 @@ def _to_timestamp(value: datetime | int | None) -> int | None:
     return value
 
 
-def _build_icon_url(icon_type: str | IconType | None, icon: str | None) -> str | None:
-    if icon is None or icon_type is None:
-        return None
-    icon_type_value = icon_type.value if isinstance(icon_type, IconType) else str(icon_type)
-    if icon_type_value.lower() != IconType.IMAGE:
-        return None
-    return file_helpers.get_signed_file_url(icon)
-
-
 class Tag(ResponseModel):
     id: str
     name: str
@@ -292,7 +283,7 @@ class Site(ResponseModel):
     @computed_field(return_type=str | None)  # type: ignore
     @property
     def icon_url(self) -> str | None:
-        return _build_icon_url(self.icon_type, self.icon)
+        return build_icon_url(self.icon_type, self.icon)
 
     @field_validator("icon_type", mode="before")
     @classmethod
@@ -342,7 +333,7 @@ class AppPartial(ResponseModel):
     @computed_field(return_type=str | None)  # type: ignore
     @property
     def icon_url(self) -> str | None:
-        return _build_icon_url(self.icon_type, self.icon)
+        return build_icon_url(self.icon_type, self.icon)
 
     @field_validator("created_at", "updated_at", mode="before")
     @classmethod
@@ -390,7 +381,7 @@ class AppDetailWithSite(AppDetail):
     @computed_field(return_type=str | None)  # type: ignore
     @property
     def icon_url(self) -> str | None:
-        return _build_icon_url(self.icon_type, self.icon)
+        return build_icon_url(self.icon_type, self.icon)
 
 
 class AppPagination(ResponseModel):
diff --git a/api/controllers/console/app/conversation_variables.py b/api/controllers/console/app/conversation_variables.py
index 369c26a80c..cead33d14f 100644
--- a/api/controllers/console/app/conversation_variables.py
+++ b/api/controllers/console/app/conversation_variables.py
@@ -1,44 +1,86 @@
+from __future__ import annotations
+
+from datetime import datetime
+from typing import Any
+
 from flask import request
-from flask_restx import Resource, fields, marshal_with
-from pydantic import BaseModel, Field
+from flask_restx import Resource
+from pydantic import BaseModel, Field, field_validator
 from sqlalchemy import select
 from sqlalchemy.orm import sessionmaker
 
+from controllers.common.schema import register_schema_models
 from controllers.console import console_ns
 from controllers.console.app.wraps import get_app_model
 from controllers.console.wraps import account_initialization_required, setup_required
 from extensions.ext_database import db
-from fields.conversation_variable_fields import (
-    conversation_variable_fields,
-    paginated_conversation_variable_fields,
-)
+from fields._value_type_serializer import serialize_value_type
+from fields.base import ResponseModel
 from libs.login import login_required
 from models import ConversationVariable
 from models.model import AppMode
 
-DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
-
 
 class ConversationVariablesQuery(BaseModel):
     conversation_id: str = Field(..., description="Conversation ID to filter variables")
 
 
-console_ns.schema_model(
-    ConversationVariablesQuery.__name__,
-    ConversationVariablesQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
-)
+def _to_timestamp(value: datetime | int | None) -> int | None:
+    if isinstance(value, datetime):
+        return int(value.timestamp())
+    return value
 
-# Register models for flask_restx to avoid dict type issues in Swagger
-# Register base model first
-conversation_variable_model = console_ns.model("ConversationVariable", conversation_variable_fields)
-# For nested models, need to replace nested dict with registered model
-paginated_conversation_variable_fields_copy = paginated_conversation_variable_fields.copy()
-paginated_conversation_variable_fields_copy["data"] = fields.List(
-    fields.Nested(conversation_variable_model), attribute="data"
-)
-paginated_conversation_variable_model = console_ns.model(
-    "PaginatedConversationVariable", paginated_conversation_variable_fields_copy
+
+class ConversationVariableResponse(ResponseModel):
+    id: str
+    name: str
+    value_type: str
+    value: str | None = None
+    description: str | None = None
+    created_at: int | None = None
+    updated_at: int | None = None
+
+    @field_validator("value_type", mode="before")
+    @classmethod
+    def _normalize_value_type(cls, value: Any) -> str:
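+        # Accept enum-like values exposing exposed_type(), plain strings, or raw segment values,
+        # falling back to serialize_value_type for anything else.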
+        exposed_type = getattr(value, "exposed_type", None)
+        if callable(exposed_type):
+            return str(exposed_type().value)
+        if isinstance(value, str):
+            return value
+        try:
+            return serialize_value_type(value)
+        except Exception:
+            return serialize_value_type({"value_type": value})
+
+    @field_validator("value", mode="before")
+    @classmethod
+    def _normalize_value(cls, value: Any | None) -> str | None:
+        if value is None:
+            return None
+        if isinstance(value, str):
+            return value
+        return str(value)
+
+    @field_validator("created_at", "updated_at", mode="before")
+    @classmethod
+    def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
+        return _to_timestamp(value)
+
+
+class PaginatedConversationVariableResponse(ResponseModel):
+    page: int
+    limit: int
+    total: int
+    has_more: bool
+    data: list[ConversationVariableResponse]
+
+
+register_schema_models(
+    console_ns,
+    ConversationVariablesQuery,
+    ConversationVariableResponse,
+    PaginatedConversationVariableResponse,
 )
 
 
@@ -48,12 +90,15 @@ class ConversationVariablesApi(Resource):
     @console_ns.doc(description="Get conversation variables for an application")
     @console_ns.doc(params={"app_id": "Application ID"})
     @console_ns.expect(console_ns.models[ConversationVariablesQuery.__name__])
-    @console_ns.response(200, "Conversation variables retrieved successfully", paginated_conversation_variable_model)
+    @console_ns.response(
+        200,
+        "Conversation variables retrieved successfully",
+        console_ns.models[PaginatedConversationVariableResponse.__name__],
+    )
     @setup_required
     @login_required
     @account_initialization_required
     @get_app_model(mode=AppMode.ADVANCED_CHAT)
-    @marshal_with(paginated_conversation_variable_model)
     def get(self, app_model):
         args = ConversationVariablesQuery.model_validate(request.args.to_dict(flat=True))  # type: ignore
 
@@ -72,17 +117,22 @@ class ConversationVariablesApi(Resource):
         with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
             rows = session.scalars(stmt).all()
 
-        return {
-            "page": page,
-            "limit": page_size,
-            "total": len(rows),
-            "has_more": False,
-            "data": [
-                {
-                    "created_at": row.created_at,
-                    "updated_at": row.updated_at,
-                    **row.to_variable().model_dump(),
-                }
-                for row in rows
-            ],
-        }
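+        # Round-tripping through the response model normalizes value types and converts
+        # datetimes to epoch seconds before the JSON dump.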
+ "page": page, + "limit": page_size, + "total": len(rows), + "has_more": False, + "data": [ + ConversationVariableResponse.model_validate( + { + "created_at": row.created_at, + "updated_at": row.updated_at, + **row.to_variable().model_dump(), + } + ) + for row in rows + ], + } + ) + return response.model_dump(mode="json") diff --git a/api/controllers/console/app/message.py b/api/controllers/console/app/message.py index 5a19544eab..daeed4627c 100644 --- a/api/controllers/console/app/message.py +++ b/api/controllers/console/app/message.py @@ -1,8 +1,9 @@ import logging +from datetime import datetime from typing import Literal from flask import request -from flask_restx import Resource, fields, marshal_with +from flask_restx import Resource from graphon.model_runtime.errors.invoke import InvokeError from pydantic import BaseModel, Field, field_validator from sqlalchemy import exists, func, select @@ -25,10 +26,21 @@ from controllers.console.wraps import ( setup_required, ) from core.app.entities.app_invoke_entities import InvokeFrom +from core.entities.execution_extra_content import ExecutionExtraContentDomainModel from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from extensions.ext_database import db -from fields.raws import FilesContainedField -from libs.helper import TimestampField, uuid_value +from fields.base import ResponseModel +from fields.conversation_fields import ( + AgentThought, + ConversationAnnotation, + ConversationAnnotationHitHistory, + Feedback, + JSONValue, + MessageFile, + format_files_contained, + to_timestamp, +) +from libs.helper import uuid_value from libs.infinite_scroll_pagination import InfiniteScrollPagination from libs.login import current_account_with_tenant, login_required from models.enums import FeedbackFromSource, FeedbackRating @@ -98,6 +110,51 @@ class SuggestedQuestionsResponse(BaseModel): data: list[str] = Field(description="Suggested question") +class MessageDetailResponse(ResponseModel): + id: str + conversation_id: str + inputs: dict[str, JSONValue] + query: str + message: JSONValue | None = None + message_tokens: int | None = None + answer: str = Field(validation_alias="re_sign_file_url_answer") + answer_tokens: int | None = None + provider_response_latency: float | None = None + from_source: str + from_end_user_id: str | None = None + from_account_id: str | None = None + feedbacks: list[Feedback] = Field(default_factory=list) + workflow_run_id: str | None = None + annotation: ConversationAnnotation | None = None + annotation_hit_history: ConversationAnnotationHitHistory | None = None + created_at: int | None = None + agent_thoughts: list[AgentThought] = Field(default_factory=list) + message_files: list[MessageFile] = Field(default_factory=list) + extra_contents: list[ExecutionExtraContentDomainModel] = Field(default_factory=list) + metadata: JSONValue | None = Field(default=None, validation_alias="message_metadata_dict") + status: str + error: str | None = None + parent_message_id: str | None = None + + @field_validator("inputs", mode="before") + @classmethod + def _normalize_inputs(cls, value: JSONValue) -> JSONValue: + return format_files_contained(value) + + @field_validator("created_at", mode="before") + @classmethod + def _normalize_created_at(cls, value: datetime | int | None) -> int | None: + if isinstance(value, datetime): + return to_timestamp(value) + return value + + +class MessageInfiniteScrollPaginationResponse(ResponseModel): + limit: int + has_more: bool + data: 
+    answer: str = Field(validation_alias="re_sign_file_url_answer")
+    answer_tokens: int | None = None
+    provider_response_latency: float | None = None
+    from_source: str
+    from_end_user_id: str | None = None
+    from_account_id: str | None = None
+    feedbacks: list[Feedback] = Field(default_factory=list)
+    workflow_run_id: str | None = None
+    annotation: ConversationAnnotation | None = None
+    annotation_hit_history: ConversationAnnotationHitHistory | None = None
+    created_at: int | None = None
+    agent_thoughts: list[AgentThought] = Field(default_factory=list)
+    message_files: list[MessageFile] = Field(default_factory=list)
+    extra_contents: list[ExecutionExtraContentDomainModel] = Field(default_factory=list)
+    metadata: JSONValue | None = Field(default=None, validation_alias="message_metadata_dict")
+    status: str
+    error: str | None = None
+    parent_message_id: str | None = None
+
+    @field_validator("inputs", mode="before")
+    @classmethod
+    def _normalize_inputs(cls, value: JSONValue) -> JSONValue:
+        return format_files_contained(value)
+
+    @field_validator("created_at", mode="before")
+    @classmethod
+    def _normalize_created_at(cls, value: datetime | int | None) -> int | None:
+        if isinstance(value, datetime):
+            return to_timestamp(value)
+        return value
+
+
+class MessageInfiniteScrollPaginationResponse(ResponseModel):
+    limit: int
+    has_more: bool
+    data: list[MessageDetailResponse]
+
+
 register_schema_models(
     console_ns,
     ChatMessagesQuery,
@@ -105,124 +162,8 @@ register_schema_models(
     FeedbackExportQuery,
     AnnotationCountResponse,
     SuggestedQuestionsResponse,
-)
-
-# Register models for flask_restx to avoid dict type issues in Swagger
-# Register in dependency order: base models first, then dependent models
-
-# Base models
-simple_account_model = console_ns.model(
-    "SimpleAccount",
-    {
-        "id": fields.String,
-        "name": fields.String,
-        "email": fields.String,
-    },
-)
-
-message_file_model = console_ns.model(
-    "MessageFile",
-    {
-        "id": fields.String,
-        "filename": fields.String,
-        "type": fields.String,
-        "url": fields.String,
-        "mime_type": fields.String,
-        "size": fields.Integer,
-        "transfer_method": fields.String,
-        "belongs_to": fields.String(default="user"),
-        "upload_file_id": fields.String(default=None),
-    },
-)
-
-agent_thought_model = console_ns.model(
-    "AgentThought",
-    {
-        "id": fields.String,
-        "chain_id": fields.String,
-        "message_id": fields.String,
-        "position": fields.Integer,
-        "thought": fields.String,
-        "tool": fields.String,
-        "tool_labels": fields.Raw,
-        "tool_input": fields.String,
-        "created_at": TimestampField,
-        "observation": fields.String,
-        "files": fields.List(fields.String),
-    },
-)
-
-# Models that depend on simple_account_model
-feedback_model = console_ns.model(
-    "Feedback",
-    {
-        "rating": fields.String,
-        "content": fields.String,
-        "from_source": fields.String,
-        "from_end_user_id": fields.String,
-        "from_account": fields.Nested(simple_account_model, allow_null=True),
-    },
-)
-
-annotation_model = console_ns.model(
-    "Annotation",
-    {
-        "id": fields.String,
-        "question": fields.String,
-        "content": fields.String,
-        "account": fields.Nested(simple_account_model, allow_null=True),
-        "created_at": TimestampField,
-    },
-)
-
-annotation_hit_history_model = console_ns.model(
-    "AnnotationHitHistory",
-    {
-        "annotation_id": fields.String(attribute="id"),
-        "annotation_create_account": fields.Nested(simple_account_model, allow_null=True),
-        "created_at": TimestampField,
-    },
-)
-
-# Message detail model that depends on multiple models
-message_detail_model = console_ns.model(
-    "MessageDetail",
-    {
-        "id": fields.String,
-        "conversation_id": fields.String,
-        "inputs": FilesContainedField,
-        "query": fields.String,
-        "message": fields.Raw,
-        "message_tokens": fields.Integer,
-        "answer": fields.String(attribute="re_sign_file_url_answer"),
-        "answer_tokens": fields.Integer,
-        "provider_response_latency": fields.Float,
-        "from_source": fields.String,
-        "from_end_user_id": fields.String,
-        "from_account_id": fields.String,
-        "feedbacks": fields.List(fields.Nested(feedback_model)),
-        "workflow_run_id": fields.String,
-        "annotation": fields.Nested(annotation_model, allow_null=True),
-        "annotation_hit_history": fields.Nested(annotation_hit_history_model, allow_null=True),
-        "created_at": TimestampField,
-        "agent_thoughts": fields.List(fields.Nested(agent_thought_model)),
-        "message_files": fields.List(fields.Nested(message_file_model)),
-        "extra_contents": fields.List(fields.Raw),
-        "metadata": fields.Raw(attribute="message_metadata_dict"),
-        "status": fields.String,
-        "error": fields.String,
-        "parent_message_id": fields.String,
-    },
-)
-
-# Message infinite scroll pagination model
-message_infinite_scroll_pagination_model = console_ns.model(
-    "MessageInfiniteScrollPagination",
-    {
-        "limit": fields.Integer,
-        "has_more": fields.Boolean,
-        "data": fields.List(fields.Nested(message_detail_model)),
-    },
+    MessageDetailResponse,
+    MessageInfiniteScrollPaginationResponse,
 )
 
 
@@ -232,13 +173,12 @@ class ChatMessageListApi(Resource):
     @console_ns.doc(description="Get chat messages for a conversation with pagination")
     @console_ns.doc(params={"app_id": "Application ID"})
     @console_ns.expect(console_ns.models[ChatMessagesQuery.__name__])
-    @console_ns.response(200, "Success", message_infinite_scroll_pagination_model)
+    @console_ns.response(200, "Success", console_ns.models[MessageInfiniteScrollPaginationResponse.__name__])
     @console_ns.response(404, "Conversation not found")
     @login_required
     @account_initialization_required
     @setup_required
     @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
-    @marshal_with(message_infinite_scroll_pagination_model)
     @edit_permission_required
     def get(self, app_model):
         args = ChatMessagesQuery.model_validate(request.args.to_dict())
@@ -298,7 +238,10 @@ class ChatMessageListApi(Resource):
         history_messages = list(reversed(history_messages))
         attach_message_extra_contents(history_messages)
 
-        return InfiniteScrollPagination(data=history_messages, limit=args.limit, has_more=has_more)
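+        # from_attributes=True lets the response model read fields straight off the
+        # InfiniteScrollPagination object.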
+        return MessageInfiniteScrollPaginationResponse.model_validate(
+            InfiniteScrollPagination(data=history_messages, limit=args.limit, has_more=has_more),
+            from_attributes=True,
+        ).model_dump(mode="json")
 
 
 @console_ns.route("/apps/<uuid:app_id>/feedbacks")
@@ -468,13 +411,12 @@ class MessageApi(Resource):
     @console_ns.doc("get_message")
     @console_ns.doc(description="Get message details by ID")
     @console_ns.doc(params={"app_id": "Application ID", "message_id": "Message ID"})
-    @console_ns.response(200, "Message retrieved successfully", message_detail_model)
+    @console_ns.response(200, "Message retrieved successfully", console_ns.models[MessageDetailResponse.__name__])
     @console_ns.response(404, "Message not found")
     @get_app_model
     @setup_required
     @login_required
     @account_initialization_required
-    @marshal_with(message_detail_model)
     def get(self, app_model, message_id: str):
         message_id = str(message_id)
 
@@ -486,4 +428,4 @@ class MessageApi(Resource):
             raise NotFound("Message Not Exists.")
 
         attach_message_extra_contents([message])
-        return message
+        return MessageDetailResponse.model_validate(message, from_attributes=True).model_dump(mode="json")
diff --git a/api/controllers/console/app/workflow_app_log.py b/api/controllers/console/app/workflow_app_log.py
index 8ae6a78a62..6b402898e8 100644
--- a/api/controllers/console/app/workflow_app_log.py
+++ b/api/controllers/console/app/workflow_app_log.py
@@ -1,27 +1,26 @@
 from datetime import datetime
+from typing import Any
 
 from dateutil.parser import isoparse
 from flask import request
-from flask_restx import Resource, marshal_with
+from flask_restx import Resource
 from graphon.enums import WorkflowExecutionStatus
 from pydantic import BaseModel, Field, field_validator
 from sqlalchemy.orm import sessionmaker
 
+from controllers.common.schema import register_schema_models
 from controllers.console import console_ns
 from controllers.console.app.wraps import get_app_model
 from controllers.console.wraps import account_initialization_required, setup_required
 from extensions.ext_database import db
-from fields.workflow_app_log_fields import (
-    build_workflow_app_log_pagination_model,
-    build_workflow_archived_log_pagination_model,
-)
+from fields.base import ResponseModel
+from fields.end_user_fields import SimpleEndUser
+from fields.member_fields import SimpleAccount
 from libs.login import login_required
 from models import App
 from models.model import AppMode
 from services.workflow_app_service import WorkflowAppService
 
-DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
-
 
 class WorkflowAppLogQuery(BaseModel):
     keyword: str | None = Field(default=None, description="Search keyword for filtering logs")
@@ -58,13 +57,113 @@ class WorkflowAppLogQuery(BaseModel):
         raise ValueError("Invalid boolean value for detail")
 
 
-console_ns.schema_model(
-    WorkflowAppLogQuery.__name__, WorkflowAppLogQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
-)
+class WorkflowRunForLogResponse(ResponseModel):
+    id: str
+    version: str | None = None
+    status: str | None = None
+    triggered_from: str | None = None
+    error: str | None = None
+    elapsed_time: float | None = None
+    total_tokens: int | None = None
+    total_steps: int | None = None
+    created_at: int | None = None
+    finished_at: int | None = None
+    exceptions_count: int | None = None
 
-# Register model for flask_restx to avoid dict type issues in Swagger
-workflow_app_log_pagination_model = build_workflow_app_log_pagination_model(console_ns)
-workflow_archived_log_pagination_model = build_workflow_archived_log_pagination_model(console_ns)
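+    # The run status may arrive as a WorkflowExecutionStatus enum; expose its string value.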
+    @field_validator("status", mode="before")
+    @classmethod
+    def _normalize_status(cls, value: Any) -> str | None:
+        if value is None:
+            return None
+        if isinstance(value, str):
+            return value
+        return str(getattr(value, "value", value))
+
+    @field_validator("created_at", "finished_at", mode="before")
+    @classmethod
+    def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
+        if isinstance(value, datetime):
+            return int(value.timestamp())
+        return value
+
+
+class WorkflowRunForArchivedLogResponse(ResponseModel):
+    id: str
+    status: str | None = None
+    triggered_from: str | None = None
+    elapsed_time: float | None = None
+    total_tokens: int | None = None
+
+    @field_validator("status", mode="before")
+    @classmethod
+    def _normalize_status(cls, value: Any) -> str | None:
+        if value is None:
+            return None
+        if isinstance(value, str):
+            return value
+        return str(getattr(value, "value", value))
+
+
+class WorkflowAppLogPartialResponse(ResponseModel):
+    id: str
+    workflow_run: WorkflowRunForLogResponse | None = None
+    details: Any = None
+    created_from: str | None = None
+    created_by_role: str | None = None
+    created_by_account: SimpleAccount | None = None
+    created_by_end_user: SimpleEndUser | None = None
+    created_at: int | None = None
+
+    @field_validator("created_at", mode="before")
+    @classmethod
+    def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
+        if isinstance(value, datetime):
+            return int(value.timestamp())
+        return value
+
+
+class WorkflowArchivedLogPartialResponse(ResponseModel):
+    id: str
+    workflow_run: WorkflowRunForArchivedLogResponse | None = None
+    trigger_metadata: Any = None
+    created_by_account: SimpleAccount | None = None
+    created_by_end_user: SimpleEndUser | None = None
+    created_at: int | None = None
+
+    @field_validator("created_at", mode="before")
+    @classmethod
+    def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
+        if isinstance(value, datetime):
+            return int(value.timestamp())
+        return value
+
+
+class WorkflowAppLogPaginationResponse(ResponseModel):
+    page: int
+    limit: int
+    total: int
+    has_more: bool
+    data: list[WorkflowAppLogPartialResponse]
+
+
+class WorkflowArchivedLogPaginationResponse(ResponseModel):
+    page: int
+    limit: int
+    total: int
+    has_more: bool
+    data: list[WorkflowArchivedLogPartialResponse]
+
+
+register_schema_models(
+    console_ns,
+    WorkflowAppLogQuery,
+    WorkflowRunForLogResponse,
+    WorkflowRunForArchivedLogResponse,
+    WorkflowAppLogPartialResponse,
+    WorkflowArchivedLogPartialResponse,
+    WorkflowAppLogPaginationResponse,
+    WorkflowArchivedLogPaginationResponse,
+)
 
 
 @console_ns.route("/apps/<uuid:app_id>/workflow-app-logs")
@@ -73,12 +172,15 @@ class WorkflowAppLogApi(Resource):
     @console_ns.doc(description="Get workflow application execution logs")
     @console_ns.doc(params={"app_id": "Application ID"})
     @console_ns.expect(console_ns.models[WorkflowAppLogQuery.__name__])
-    @console_ns.response(200, "Workflow app logs retrieved successfully", workflow_app_log_pagination_model)
+    @console_ns.response(
+        200,
+        "Workflow app logs retrieved successfully",
+        console_ns.models[WorkflowAppLogPaginationResponse.__name__],
+    )
     @setup_required
     @login_required
     @account_initialization_required
     @get_app_model(mode=[AppMode.WORKFLOW])
-    @marshal_with(workflow_app_log_pagination_model)
     def get(self, app_model: App):
         """
         Get workflow app logs
@@ -102,7 +204,9 @@ class WorkflowAppLogApi(Resource):
             created_by_account=args.created_by_account,
         )
 
-        return workflow_app_log_pagination
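+        # Serialize the service-layer pagination object through the Pydantic response model.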
+        return WorkflowAppLogPaginationResponse.model_validate(
+            workflow_app_log_pagination, from_attributes=True
+        ).model_dump(mode="json")
 
 
 @console_ns.route("/apps/<uuid:app_id>/workflow-archived-logs")
@@ -111,12 +215,15 @@ class WorkflowArchivedLogApi(Resource):
     @console_ns.doc(description="Get workflow archived execution logs")
     @console_ns.doc(params={"app_id": "Application ID"})
     @console_ns.expect(console_ns.models[WorkflowAppLogQuery.__name__])
-    @console_ns.response(200, "Workflow archived logs retrieved successfully", workflow_archived_log_pagination_model)
+    @console_ns.response(
+        200,
+        "Workflow archived logs retrieved successfully",
+        console_ns.models[WorkflowArchivedLogPaginationResponse.__name__],
+    )
     @setup_required
     @login_required
     @account_initialization_required
     @get_app_model(mode=[AppMode.WORKFLOW])
-    @marshal_with(workflow_archived_log_pagination_model)
     def get(self, app_model: App):
         """
         Get workflow archived logs
@@ -132,4 +239,6 @@ class WorkflowArchivedLogApi(Resource):
             limit=args.limit,
         )
 
-        return workflow_app_log_pagination
+        return WorkflowArchivedLogPaginationResponse.model_validate(
+            workflow_app_log_pagination, from_attributes=True
+        ).model_dump(mode="json")
diff --git a/api/controllers/console/app/workflow_trigger.py b/api/controllers/console/app/workflow_trigger.py
index c457684c15..a6715fa200 100644
--- a/api/controllers/console/app/workflow_trigger.py
+++ b/api/controllers/console/app/workflow_trigger.py
@@ -1,16 +1,17 @@
 import logging
+from datetime import datetime
 
 from flask import request
-from flask_restx import Resource, fields, marshal_with
-from pydantic import BaseModel
+from flask_restx import Resource
+from pydantic import BaseModel, field_validator
 from sqlalchemy import select
 from sqlalchemy.orm import sessionmaker
 from werkzeug.exceptions import NotFound
 
 from configs import dify_config
-from controllers.common.schema import get_or_create_model
+from controllers.common.schema import register_schema_models
 from extensions.ext_database import db
-from fields.workflow_trigger_fields import trigger_fields, triggers_list_fields, webhook_trigger_fields
+from fields.base import ResponseModel
 from libs.login import current_user, login_required
 from models.enums import AppTriggerStatus
 from models.model import Account, App, AppMode
@@ -21,15 +22,6 @@ from ..app.wraps import get_app_model
 from ..wraps import account_initialization_required, edit_permission_required, setup_required
 
 logger = logging.getLogger(__name__)
-DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
-
-trigger_model = get_or_create_model("WorkflowTrigger", trigger_fields)
-
-triggers_list_fields_copy = triggers_list_fields.copy()
-triggers_list_fields_copy["data"] = fields.List(fields.Nested(trigger_model))
-triggers_list_model = get_or_create_model("WorkflowTriggerList", triggers_list_fields_copy)
-
-webhook_trigger_model = get_or_create_model("WebhookTrigger", webhook_trigger_fields)
 
 
 class Parser(BaseModel):
@@ -41,10 +33,52 @@ class ParserEnable(BaseModel):
     enable_trigger: bool
 
 
-console_ns.schema_model(Parser.__name__, Parser.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
+class WorkflowTriggerResponse(ResponseModel):
+    id: str
+    trigger_type: str
+    title: str
+    node_id: str
+    provider_name: str
+    icon: str
+    status: str
+    created_at: datetime | None = None
+    updated_at: datetime | None = None
 
-console_ns.schema_model(
-    ParserEnable.__name__, ParserEnable.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
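+    # Trigger attributes may be UUIDs or enums; coerce identifier-like fields to plain strings.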
+    @field_validator("id", "trigger_type", "title", "node_id", "provider_name", "icon", "status", mode="before")
+    @classmethod
+    def _normalize_string_fields(cls, value: object) -> str:
+        if isinstance(value, str):
+            return value
+        return str(value)
+
+
+class WorkflowTriggerListResponse(ResponseModel):
+    data: list[WorkflowTriggerResponse]
+
+
+class WebhookTriggerResponse(ResponseModel):
+    id: str
+    webhook_id: str
+    webhook_url: str
+    webhook_debug_url: str
+    node_id: str
+    created_at: datetime | None = None
+
+    @field_validator("id", "webhook_id", "webhook_url", "webhook_debug_url", "node_id", mode="before")
+    @classmethod
+    def _normalize_string_fields(cls, value: object) -> str:
+        if isinstance(value, str):
+            return value
+        return str(value)
+
+
+register_schema_models(
+    console_ns,
+    Parser,
+    ParserEnable,
+    WorkflowTriggerResponse,
+    WorkflowTriggerListResponse,
+    WebhookTriggerResponse,
 )
 
 
@@ -57,7 +91,7 @@ class WebhookTriggerApi(Resource):
     @login_required
     @account_initialization_required
     @get_app_model(mode=AppMode.WORKFLOW)
-    @marshal_with(webhook_trigger_model)
+    @console_ns.response(200, "Success", console_ns.models[WebhookTriggerResponse.__name__])
     def get(self, app_model: App):
         """Get webhook trigger for a node"""
         args = Parser.model_validate(request.args.to_dict(flat=True))  # type: ignore
@@ -78,7 +112,7 @@ class WebhookTriggerApi(Resource):
         if not webhook_trigger:
             raise NotFound("Webhook trigger not found for this node")
 
-        return webhook_trigger
+        return WebhookTriggerResponse.model_validate(webhook_trigger, from_attributes=True).model_dump(mode="json")
 
 
 @console_ns.route("/apps/<uuid:app_id>/triggers")
@@ -89,7 +123,7 @@ class AppTriggersApi(Resource):
     @login_required
     @account_initialization_required
     @get_app_model(mode=AppMode.WORKFLOW)
-    @marshal_with(triggers_list_model)
+    @console_ns.response(200, "Success", console_ns.models[WorkflowTriggerListResponse.__name__])
    def get(self, app_model: App):
         """Get app triggers list"""
         assert isinstance(current_user, Account)
@@ -118,7 +152,9 @@ class AppTriggersApi(Resource):
         else:
             trigger.icon = ""  # type: ignore
 
-        return {"data": triggers}
+        return WorkflowTriggerListResponse.model_validate({"data": triggers}, from_attributes=True).model_dump(
+            mode="json"
+        )
 
 
 @console_ns.route("/apps/<uuid:app_id>/trigger-enable")
@@ -129,7 +165,7 @@ class AppTriggerEnableApi(Resource):
     @account_initialization_required
     @edit_permission_required
     @get_app_model(mode=AppMode.WORKFLOW)
-    @marshal_with(trigger_model)
+    @console_ns.response(200, "Success", console_ns.models[WorkflowTriggerResponse.__name__])
     def post(self, app_model: App):
         """Update app trigger (enable/disable)"""
         args = ParserEnable.model_validate(console_ns.payload)
@@ -160,4 +196,4 @@ class AppTriggerEnableApi(Resource):
         else:
             trigger.icon = ""  # type: ignore
 
-        return trigger
+        return WorkflowTriggerResponse.model_validate(trigger, from_attributes=True).model_dump(mode="json")
diff --git a/api/controllers/console/auth/activate.py b/api/controllers/console/auth/activate.py
index e6316bfd62..f7061f820f 100644
--- a/api/controllers/console/auth/activate.py
+++ b/api/controllers/console/auth/activate.py
@@ -1,3 +1,5 @@
+from typing import Any
+
 from flask import request
 from flask_restx import Resource
 from pydantic import BaseModel, Field, field_validator
@@ -40,7 +42,7 @@ class ActivatePayload(BaseModel):
 
 class ActivationCheckResponse(BaseModel):
     is_valid: bool = Field(description="Whether token is valid")
-    data: dict | None = Field(default=None, description="Activation data if valid")
+    data: dict[str, Any] | None = Field(default=None, description="Activation data if valid")
 
 
 class ActivationResponse(BaseModel):
diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py
index de8fe1c0e2..98d4ad9412 100644
--- a/api/controllers/console/datasets/datasets_document.py
+++ b/api/controllers/console/datasets/datasets_document.py
@@ -1026,7 +1026,7 @@ class DocumentMetadataApi(DocumentResource):
         if not isinstance(doc_metadata, dict):
             raise ValueError("doc_metadata must be a dictionary.")
 
-        metadata_schema: dict = cast(dict, DocumentService.DOCUMENT_METADATA_SCHEMA[doc_type])
+        metadata_schema: dict[str, Any] = cast(dict[str, Any], DocumentService.DOCUMENT_METADATA_SCHEMA[doc_type])
 
         document.doc_metadata = {}
         if doc_type == "others":
diff --git a/api/controllers/console/datasets/hit_testing.py b/api/controllers/console/datasets/hit_testing.py
index e62be13c2f..36a7a4bb0e 100644
--- a/api/controllers/console/datasets/hit_testing.py
+++ b/api/controllers/console/datasets/hit_testing.py
@@ -1,13 +1,13 @@
-from flask_restx import Resource, fields
+from __future__ import annotations
 
-from controllers.common.schema import register_schema_model
-from fields.hit_testing_fields import (
-    child_chunk_fields,
-    document_fields,
-    files_fields,
-    hit_testing_record_fields,
-    segment_fields,
-)
+from datetime import datetime
+from typing import Any
+
+from flask_restx import Resource
+from pydantic import Field, field_validator
+
+from controllers.common.schema import register_schema_models
+from fields.base import ResponseModel
 from libs.login import login_required
 
 from .. import console_ns
@@ -18,39 +18,92 @@ from ..wraps import (
     setup_required,
 )
 
-register_schema_model(console_ns, HitTestingPayload)
+
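+# Timestamp columns arrive as datetimes from the ORM; the API exposes them as epoch seconds.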
+def _to_timestamp(value: datetime | int | None) -> int | None:
+    if isinstance(value, datetime):
+        return int(value.timestamp())
+    return value
 
 
-def _get_or_create_model(model_name: str, field_def):
-    """Get or create a flask_restx model to avoid dict type issues in Swagger."""
-    existing = console_ns.models.get(model_name)
-    if existing is None:
-        existing = console_ns.model(model_name, field_def)
-    return existing
+class HitTestingDocument(ResponseModel):
+    id: str | None = None
+    data_source_type: str | None = None
+    name: str | None = None
+    doc_type: str | None = None
+    doc_metadata: Any | None = None
 
 
-# Register models for flask_restx to avoid dict type issues in Swagger
-document_model = _get_or_create_model("HitTestingDocument", document_fields)
+class HitTestingSegment(ResponseModel):
+    id: str | None = None
+    position: int | None = None
+    document_id: str | None = None
+    content: str | None = None
+    sign_content: str | None = None
+    answer: str | None = None
+    word_count: int | None = None
+    tokens: int | None = None
+    keywords: list[str] = Field(default_factory=list)
+    index_node_id: str | None = None
+    index_node_hash: str | None = None
+    hit_count: int | None = None
+    enabled: bool | None = None
+    disabled_at: int | None = None
+    disabled_by: str | None = None
+    status: str | None = None
+    created_by: str | None = None
+    created_at: int | None = None
+    indexing_at: int | None = None
+    completed_at: int | None = None
+    error: str | None = None
+    stopped_at: int | None = None
+    document: HitTestingDocument | None = None
 
-segment_fields_copy = segment_fields.copy()
-segment_fields_copy["document"] = fields.Nested(document_model)
-segment_model = _get_or_create_model("HitTestingSegment", segment_fields_copy)
+    @field_validator("disabled_at", "created_at", "indexing_at", "completed_at", "stopped_at", mode="before")
+    @classmethod
+    def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
+        return _to_timestamp(value)
 
-child_chunk_model = _get_or_create_model("HitTestingChildChunk", child_chunk_fields)
-files_model = _get_or_create_model("HitTestingFile", files_fields)
-hit_testing_record_fields_copy = hit_testing_record_fields.copy()
-hit_testing_record_fields_copy["segment"] = fields.Nested(segment_model)
-hit_testing_record_fields_copy["child_chunks"] = fields.List(fields.Nested(child_chunk_model))
-hit_testing_record_fields_copy["files"] = fields.List(fields.Nested(files_model))
-hit_testing_record_model = _get_or_create_model("HitTestingRecord", hit_testing_record_fields_copy)
+
+class HitTestingChildChunk(ResponseModel):
+    id: str | None = None
+    content: str | None = None
+    position: int | None = None
+    score: float | None = None
 
-# Response model for hit testing API
-hit_testing_response_fields = {
-    "query": fields.String,
-    "records": fields.List(fields.Nested(hit_testing_record_model)),
-}
-hit_testing_response_model = _get_or_create_model("HitTestingResponse", hit_testing_response_fields)
+
+class HitTestingFile(ResponseModel):
+    id: str | None = None
+    name: str | None = None
+    size: int | None = None
+    extension: str | None = None
+    mime_type: str | None = None
+    source_url: str | None = None
+
+
+class HitTestingRecord(ResponseModel):
+    segment: HitTestingSegment | None = None
+    child_chunks: list[HitTestingChildChunk] = Field(default_factory=list)
+    score: float | None = None
+    tsne_position: Any | None = None
+    files: list[HitTestingFile] = Field(default_factory=list)
+    summary: str | None = None
+
+
+class HitTestingResponse(ResponseModel):
+    query: str
+    records: list[HitTestingRecord] = Field(default_factory=list)
+
+
+register_schema_models(
+    console_ns,
+    HitTestingPayload,
+    HitTestingDocument,
+    HitTestingSegment,
+    HitTestingChildChunk,
+    HitTestingFile,
+    HitTestingRecord,
+    HitTestingResponse,
+)
 
 
 @console_ns.route("/datasets/<uuid:dataset_id>/hit-testing")
@@ -59,7 +112,11 @@ class HitTestingApi(Resource, DatasetsHitTestingBase):
     @console_ns.doc(description="Test dataset knowledge retrieval")
     @console_ns.doc(params={"dataset_id": "Dataset ID"})
     @console_ns.expect(console_ns.models[HitTestingPayload.__name__])
-    @console_ns.response(200, "Hit testing completed successfully", model=hit_testing_response_model)
+    @console_ns.response(
+        200,
+        "Hit testing completed successfully",
+        model=console_ns.models[HitTestingResponse.__name__],
+    )
     @console_ns.response(404, "Dataset not found")
     @console_ns.response(400, "Invalid parameters")
     @setup_required
@@ -74,4 +131,4 @@ class HitTestingApi(Resource, DatasetsHitTestingBase):
         args = payload.model_dump(exclude_none=True)
         self.hit_testing_args_check(args)
 
-        return self.perform_hit_testing(dataset, args)
+        return HitTestingResponse.model_validate(self.perform_hit_testing(dataset, args)).model_dump(mode="json")
-installed_app_fields_copy["app"] = fields.Nested(app_model) -installed_app_model = get_or_create_model("InstalledApp", installed_app_fields_copy) -installed_app_list_fields_copy = installed_app_list_fields.copy() -installed_app_list_fields_copy["installed_apps"] = fields.List(fields.Nested(installed_app_model)) -installed_app_list_model = get_or_create_model("InstalledAppList", installed_app_list_fields_copy) +def _safe_primitive(value: Any) -> Any: + if value is None or isinstance(value, (str, int, float, bool, datetime)): + return value + return None + + +class InstalledAppInfoResponse(ResponseModel): + id: str + name: str | None = None + mode: str | None = None + icon_type: str | None = None + icon: str | None = None + icon_background: str | None = None + use_icon_as_answer_icon: bool | None = None + + @field_validator("mode", "icon_type", mode="before") + @classmethod + def _normalize_enum_like(cls, value: Any) -> str | None: + if value is None: + return None + if isinstance(value, str): + return value + return str(getattr(value, "value", value)) + + @computed_field(return_type=str | None) # type: ignore[prop-decorator] + @property + def icon_url(self) -> str | None: + return _build_icon_url(self.icon_type, self.icon) + + +class InstalledAppResponse(ResponseModel): + id: str + app: InstalledAppInfoResponse + app_owner_tenant_id: str + is_pinned: bool + last_used_at: int | None = None + editable: bool + uninstallable: bool + + @field_validator("app", mode="before") + @classmethod + def _normalize_app(cls, value: Any) -> Any: + if isinstance(value, dict): + return value + return { + "id": _safe_primitive(getattr(value, "id", "")) or "", + "name": _safe_primitive(getattr(value, "name", None)), + "mode": _safe_primitive(getattr(value, "mode", None)), + "icon_type": _safe_primitive(getattr(value, "icon_type", None)), + "icon": _safe_primitive(getattr(value, "icon", None)), + "icon_background": _safe_primitive(getattr(value, "icon_background", None)), + "use_icon_as_answer_icon": _safe_primitive(getattr(value, "use_icon_as_answer_icon", None)), + } + + @field_validator("last_used_at", mode="before") + @classmethod + def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: + if isinstance(value, datetime): + return int(value.timestamp()) + return value + + +class InstalledAppListResponse(ResponseModel): + installed_apps: list[InstalledAppResponse] + + +register_schema_models( + console_ns, + InstalledAppCreatePayload, + InstalledAppUpdatePayload, + InstalledAppsListQuery, + InstalledAppInfoResponse, + InstalledAppResponse, + InstalledAppListResponse, +) @console_ns.route("/installed-apps") class InstalledAppsListApi(Resource): @login_required @account_initialization_required - @marshal_with(installed_app_list_model) + @console_ns.response(200, "Success", console_ns.models[InstalledAppListResponse.__name__]) def get(self): query = InstalledAppsListQuery.model_validate(request.args.to_dict()) current_user, current_tenant_id = current_account_with_tenant() @@ -125,7 +203,9 @@ class InstalledAppsListApi(Resource): ) ) - return {"installed_apps": installed_app_list} + return InstalledAppListResponse.model_validate( + {"installed_apps": installed_app_list}, from_attributes=True + ).model_dump(mode="json") @login_required @account_initialization_required diff --git a/api/controllers/console/explore/recommended_app.py b/api/controllers/console/explore/recommended_app.py index c9920c97cf..55bd679b48 100644 --- a/api/controllers/console/explore/recommended_app.py +++ 
+def _safe_primitive(value: Any) -> Any:
+    if value is None or isinstance(value, (str, int, float, bool, datetime)):
+        return value
+    return None
+
+
+class InstalledAppInfoResponse(ResponseModel):
+    id: str
+    name: str | None = None
+    mode: str | None = None
+    icon_type: str | None = None
+    icon: str | None = None
+    icon_background: str | None = None
+    use_icon_as_answer_icon: bool | None = None
+
+    @field_validator("mode", "icon_type", mode="before")
+    @classmethod
+    def _normalize_enum_like(cls, value: Any) -> str | None:
+        if value is None:
+            return None
+        if isinstance(value, str):
+            return value
+        return str(getattr(value, "value", value))
+
+    @computed_field(return_type=str | None)  # type: ignore[prop-decorator]
+    @property
+    def icon_url(self) -> str | None:
+        return _build_icon_url(self.icon_type, self.icon)
+
+
+class InstalledAppResponse(ResponseModel):
+    id: str
+    app: InstalledAppInfoResponse
+    app_owner_tenant_id: str
+    is_pinned: bool
+    last_used_at: int | None = None
+    editable: bool
+    uninstallable: bool
+
+    @field_validator("app", mode="before")
+    @classmethod
+    def _normalize_app(cls, value: Any) -> Any:
+        if isinstance(value, dict):
+            return value
+        return {
+            "id": _safe_primitive(getattr(value, "id", "")) or "",
+            "name": _safe_primitive(getattr(value, "name", None)),
+            "mode": _safe_primitive(getattr(value, "mode", None)),
+            "icon_type": _safe_primitive(getattr(value, "icon_type", None)),
+            "icon": _safe_primitive(getattr(value, "icon", None)),
+            "icon_background": _safe_primitive(getattr(value, "icon_background", None)),
+            "use_icon_as_answer_icon": _safe_primitive(getattr(value, "use_icon_as_answer_icon", None)),
+        }
+
+    @field_validator("last_used_at", mode="before")
+    @classmethod
+    def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
+        if isinstance(value, datetime):
+            return int(value.timestamp())
+        return value
+
+
+class InstalledAppListResponse(ResponseModel):
+    installed_apps: list[InstalledAppResponse]
+
+
+register_schema_models(
+    console_ns,
+    InstalledAppCreatePayload,
+    InstalledAppUpdatePayload,
+    InstalledAppsListQuery,
+    InstalledAppInfoResponse,
+    InstalledAppResponse,
+    InstalledAppListResponse,
+)
 
 
 @console_ns.route("/installed-apps")
 class InstalledAppsListApi(Resource):
     @login_required
     @account_initialization_required
-    @marshal_with(installed_app_list_model)
+    @console_ns.response(200, "Success", console_ns.models[InstalledAppListResponse.__name__])
     def get(self):
         query = InstalledAppsListQuery.model_validate(request.args.to_dict())
         current_user, current_tenant_id = current_account_with_tenant()
@@ -125,7 +203,9 @@ class InstalledAppsListApi(Resource):
             )
         )
 
-        return {"installed_apps": installed_app_list}
+        return InstalledAppListResponse.model_validate(
+            {"installed_apps": installed_app_list}, from_attributes=True
+        ).model_dump(mode="json")
 
     @login_required
     @account_initialization_required
diff --git a/api/controllers/console/explore/recommended_app.py b/api/controllers/console/explore/recommended_app.py
index c9920c97cf..55bd679b48 100644
--- a/api/controllers/console/explore/recommended_app.py
+++ b/api/controllers/console/explore/recommended_app.py
@@ -1,66 +1,83 @@
+from typing import Any
+
 from flask import request
-from flask_restx import Resource, fields, marshal_with
-from pydantic import BaseModel, Field
+from flask_restx import Resource
+from pydantic import BaseModel, Field, computed_field, field_validator
 
 from constants.languages import languages
-from controllers.common.schema import get_or_create_model
+from controllers.common.schema import register_schema_models
 from controllers.console import console_ns
 from controllers.console.wraps import account_initialization_required
-from libs.helper import AppIconUrlField
+from fields.base import ResponseModel
+from libs.helper import build_icon_url
 from libs.login import current_user, login_required
 from services.recommended_app_service import RecommendedAppService
 
-app_fields = {
-    "id": fields.String,
-    "name": fields.String,
-    "mode": fields.String,
-    "icon": fields.String,
-    "icon_type": fields.String,
-    "icon_url": AppIconUrlField,
-    "icon_background": fields.String,
-}
-
-app_model = get_or_create_model("RecommendedAppInfo", app_fields)
-
-recommended_app_fields = {
-    "app": fields.Nested(app_model, attribute="app"),
-    "app_id": fields.String,
-    "description": fields.String(attribute="description"),
-    "copyright": fields.String,
-    "privacy_policy": fields.String,
-    "custom_disclaimer": fields.String,
-    "category": fields.String,
-    "position": fields.Integer,
-    "is_listed": fields.Boolean,
-    "can_trial": fields.Boolean,
-}
-
-recommended_app_model = get_or_create_model("RecommendedApp", recommended_app_fields)
-
-recommended_app_list_fields = {
-    "recommended_apps": fields.List(fields.Nested(recommended_app_model)),
-    "categories": fields.List(fields.String),
-}
-
-recommended_app_list_model = get_or_create_model("RecommendedAppList", recommended_app_list_fields)
-
 
 class RecommendedAppsQuery(BaseModel):
     language: str | None = Field(default=None)
 
 
-console_ns.schema_model(
-    RecommendedAppsQuery.__name__,
-    RecommendedAppsQuery.model_json_schema(ref_template="#/definitions/{model}"),
+class RecommendedAppInfoResponse(ResponseModel):
+    id: str
+    name: str | None = None
+    mode: str | None = None
+    icon: str | None = None
+    icon_type: str | None = None
+    icon_background: str | None = None
+
+    @staticmethod
+    def _normalize_enum_like(value: Any) -> str | None:
+        if value is None:
+            return None
+        if isinstance(value, str):
+            return value
+        return str(getattr(value, "value", value))
+
+    @field_validator("mode", "icon_type", mode="before")
+    @classmethod
+    def _normalize_enum_fields(cls, value: Any) -> str | None:
+        return cls._normalize_enum_like(value)
+
+    @computed_field(return_type=str | None)  # type: ignore[prop-decorator]
+    @property
+    def icon_url(self) -> str | None:
+        return build_icon_url(self.icon_type, self.icon)
+
+
+class RecommendedAppResponse(ResponseModel):
+    app: RecommendedAppInfoResponse | None = None
+    app_id: str
+    description: str | None = None
+    copyright: str | None = None
+    privacy_policy: str | None = None
+    custom_disclaimer: str | None = None
+    category: str | None = None
+    position: int | None = None
+    is_listed: bool | None = None
+    can_trial: bool | None = None
+
+
+class RecommendedAppListResponse(ResponseModel):
+    recommended_apps: list[RecommendedAppResponse]
+    categories: list[str]
+
+
+register_schema_models(
+    console_ns,
+    RecommendedAppsQuery,
+    RecommendedAppInfoResponse,
+    RecommendedAppResponse,
+    RecommendedAppListResponse,
 )
 
 
 @console_ns.route("/explore/apps")
 class RecommendedAppListApi(Resource):
     @console_ns.expect(console_ns.models[RecommendedAppsQuery.__name__])
+    @console_ns.response(200, "Success", console_ns.models[RecommendedAppListResponse.__name__])
     @login_required
     @account_initialization_required
-    @marshal_with(recommended_app_list_model)
     def get(self):
         # language args
         args = RecommendedAppsQuery.model_validate(request.args.to_dict(flat=True))  # type: ignore
@@ -72,7 +89,10 @@ class RecommendedAppListApi(Resource):
         else:
             language_prefix = languages[0]
 
-        return RecommendedAppService.get_recommended_apps_and_categories(language_prefix)
+        return RecommendedAppListResponse.model_validate(
+            RecommendedAppService.get_recommended_apps_and_categories(language_prefix),
+            from_attributes=True,
+        ).model_dump(mode="json")
 
 
 @console_ns.route("/explore/apps/<uuid:app_id>")
diff --git a/api/controllers/console/extension.py b/api/controllers/console/extension.py
index efa46c9779..7a6356d052 100644
--- a/api/controllers/console/extension.py
+++ b/api/controllers/console/extension.py
@@ -1,15 +1,18 @@
+from datetime import datetime
+from typing import Any
+
 from flask import request
-from flask_restx import Resource, fields, marshal_with
-from pydantic import BaseModel, Field
+from flask_restx import Resource
+from pydantic import BaseModel, Field, TypeAdapter, field_validator
 
 from constants import HIDDEN_VALUE
-from fields.api_based_extension_fields import api_based_extension_fields
+from fields.base import ResponseModel
 from libs.login import current_account_with_tenant, login_required
 from models.api_based_extension import APIBasedExtension
 from services.api_based_extension_service import APIBasedExtensionService
 from services.code_based_extension_service import CodeBasedExtensionService
 
-from ..common.schema import register_schema_models
+from ..common.schema import DEFAULT_REF_TEMPLATE_SWAGGER_2_0, register_schema_models
 from . import console_ns
 from .wraps import account_initialization_required, setup_required
 
@@ -24,12 +27,52 @@ class APIBasedExtensionPayload(BaseModel):
     api_key: str = Field(description="API key for authentication")
 
 
-register_schema_models(console_ns, APIBasedExtensionPayload)
+class CodeBasedExtensionResponse(ResponseModel):
+    module: str = Field(description="Module name")
+    data: Any = Field(description="Extension data")
 
-api_based_extension_model = console_ns.model("ApiBasedExtensionModel", api_based_extension_fields)
+
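+# Mask stored API keys, keeping only short prefix/suffix hints for identification.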
API-based extension") @console_ns.expect(console_ns.models[APIBasedExtensionPayload.__name__]) - @console_ns.response(201, "Extension created successfully", api_based_extension_model) + @console_ns.response(201, "Extension created successfully", console_ns.models[APIBasedExtensionResponse.__name__]) @setup_required @login_required @account_initialization_required - @marshal_with(api_based_extension_model) def post(self): payload = APIBasedExtensionPayload.model_validate(console_ns.payload or {}) _, current_tenant_id = current_account_with_tenant() @@ -86,7 +130,7 @@ class APIBasedExtensionAPI(Resource): api_key=payload.api_key, ) - return APIBasedExtensionService.save(extension_data) + return _serialize_api_based_extension(APIBasedExtensionService.save(extension_data)) @console_ns.route("/api-based-extension/") @@ -94,26 +138,26 @@ class APIBasedExtensionDetailAPI(Resource): @console_ns.doc("get_api_based_extension") @console_ns.doc(description="Get API-based extension by ID") @console_ns.doc(params={"id": "Extension ID"}) - @console_ns.response(200, "Success", api_based_extension_model) + @console_ns.response(200, "Success", console_ns.models[APIBasedExtensionResponse.__name__]) @setup_required @login_required @account_initialization_required - @marshal_with(api_based_extension_model) def get(self, id): api_based_extension_id = str(id) _, tenant_id = current_account_with_tenant() - return APIBasedExtensionService.get_with_tenant_id(tenant_id, api_based_extension_id) + return _serialize_api_based_extension( + APIBasedExtensionService.get_with_tenant_id(tenant_id, api_based_extension_id) + ) @console_ns.doc("update_api_based_extension") @console_ns.doc(description="Update API-based extension") @console_ns.doc(params={"id": "Extension ID"}) @console_ns.expect(console_ns.models[APIBasedExtensionPayload.__name__]) - @console_ns.response(200, "Extension updated successfully", api_based_extension_model) + @console_ns.response(200, "Extension updated successfully", console_ns.models[APIBasedExtensionResponse.__name__]) @setup_required @login_required @account_initialization_required - @marshal_with(api_based_extension_model) def post(self, id): api_based_extension_id = str(id) _, current_tenant_id = current_account_with_tenant() @@ -128,7 +172,7 @@ class APIBasedExtensionDetailAPI(Resource): if payload.api_key != HIDDEN_VALUE: extension_data_from_db.api_key = payload.api_key - return APIBasedExtensionService.save(extension_data_from_db) + return _serialize_api_based_extension(APIBasedExtensionService.save(extension_data_from_db)) @console_ns.doc("delete_api_based_extension") @console_ns.doc(description="Delete API-based extension") diff --git a/api/controllers/console/workspace/__init__.py b/api/controllers/console/workspace/__init__.py index 60f712e476..59dd29fdac 100644 --- a/api/controllers/console/workspace/__init__.py +++ b/api/controllers/console/workspace/__init__.py @@ -35,22 +35,24 @@ def plugin_permission_required( return view(*args, **kwargs) if install_required: - if permission.install_permission == TenantPluginPermission.InstallPermission.NOBODY: - raise Forbidden() - if permission.install_permission == TenantPluginPermission.InstallPermission.ADMINS: - if not user.is_admin_or_owner: + match permission.install_permission: + case TenantPluginPermission.InstallPermission.NOBODY: raise Forbidden() - if permission.install_permission == TenantPluginPermission.InstallPermission.EVERYONE: - pass + case TenantPluginPermission.InstallPermission.ADMINS: + if not user.is_admin_or_owner: + raise 
Forbidden() + case TenantPluginPermission.InstallPermission.EVERYONE: + pass if debug_required: - if permission.debug_permission == TenantPluginPermission.DebugPermission.NOBODY: - raise Forbidden() - if permission.debug_permission == TenantPluginPermission.DebugPermission.ADMINS: - if not user.is_admin_or_owner: + match permission.debug_permission: + case TenantPluginPermission.DebugPermission.NOBODY: raise Forbidden() - if permission.debug_permission == TenantPluginPermission.DebugPermission.EVERYONE: - pass + case TenantPluginPermission.DebugPermission.ADMINS: + if not user.is_admin_or_owner: + raise Forbidden() + case TenantPluginPermission.DebugPermission.EVERYONE: + pass return view(*args, **kwargs) diff --git a/api/controllers/console/workspace/account.py b/api/controllers/console/workspace/account.py index c35006a7ee..582c38052e 100644 --- a/api/controllers/console/workspace/account.py +++ b/api/controllers/console/workspace/account.py @@ -5,7 +5,7 @@ from typing import Any, Literal import pytz from flask import request -from flask_restx import Resource, fields, marshal_with +from flask_restx import Resource from pydantic import BaseModel, Field, field_validator, model_validator from sqlalchemy import select @@ -37,9 +37,10 @@ from controllers.console.wraps import ( setup_required, ) from extensions.ext_database import db +from fields.base import ResponseModel from fields.member_fields import Account as AccountResponse from libs.datetime_utils import naive_utc_now -from libs.helper import EmailStr, TimestampField, extract_remote_ip, timezone +from libs.helper import EmailStr, extract_remote_ip, timezone from libs.login import current_account_with_tenant, login_required from models import AccountIntegrate, InvitationCode from models.account import AccountStatus, InvitationCodeStatus @@ -178,17 +179,57 @@ def _serialize_account(account) -> dict[str, Any]: return AccountResponse.model_validate(account, from_attributes=True).model_dump(mode="json") -integrate_fields = { - "provider": fields.String, - "created_at": TimestampField, - "is_bound": fields.Boolean, - "link": fields.String, -} +def _to_timestamp(value: datetime | int | None) -> int | None: + if isinstance(value, datetime): + return int(value.timestamp()) + return value -integrate_model = console_ns.model("AccountIntegrate", integrate_fields) -integrate_list_model = console_ns.model( - "AccountIntegrateList", - {"data": fields.List(fields.Nested(integrate_model))}, + +class AccountIntegrateResponse(ResponseModel): + provider: str + created_at: int | None = None + is_bound: bool + link: str | None = None + + @field_validator("created_at", mode="before") + @classmethod + def _normalize_created_at(cls, value: datetime | int | None) -> int | None: + return _to_timestamp(value) + + +class AccountIntegrateListResponse(ResponseModel): + data: list[AccountIntegrateResponse] + + +class EducationVerifyResponse(ResponseModel): + token: str | None = None + + +class EducationStatusResponse(ResponseModel): + result: bool | None = None + is_student: bool | None = None + expire_at: int | None = None + allow_refresh: bool | None = None + + @field_validator("expire_at", mode="before") + @classmethod + def _normalize_expire_at(cls, value: datetime | int | None) -> int | None: + return _to_timestamp(value) + + +class EducationAutocompleteResponse(ResponseModel): + data: list[str] = Field(default_factory=list) + curr_page: int | None = None + has_next: bool | None = None + + +register_schema_models( + console_ns, + AccountIntegrateResponse, + 
AccountIntegrateListResponse, + EducationVerifyResponse, + EducationStatusResponse, + EducationAutocompleteResponse, ) @@ -359,7 +400,7 @@ class AccountIntegrateApi(Resource): @setup_required @login_required @account_initialization_required - @marshal_with(integrate_list_model) + @console_ns.response(200, "Success", console_ns.models[AccountIntegrateListResponse.__name__]) def get(self): account, _ = current_account_with_tenant() @@ -395,7 +436,9 @@ class AccountIntegrateApi(Resource): } ) - return {"data": integrate_data} + return AccountIntegrateListResponse( + data=[AccountIntegrateResponse.model_validate(item) for item in integrate_data] + ).model_dump(mode="json") @console_ns.route("/account/delete/verify") @@ -447,31 +490,22 @@ class AccountDeleteUpdateFeedbackApi(Resource): @console_ns.route("/account/education/verify") class EducationVerifyApi(Resource): - verify_fields = { - "token": fields.String, - } - @setup_required @login_required @account_initialization_required @only_edition_cloud @cloud_edition_billing_enabled - @marshal_with(verify_fields) + @console_ns.response(200, "Success", console_ns.models[EducationVerifyResponse.__name__]) def get(self): account, _ = current_account_with_tenant() - return BillingService.EducationIdentity.verify(account.id, account.email) + return EducationVerifyResponse.model_validate( + BillingService.EducationIdentity.verify(account.id, account.email) or {} + ).model_dump(mode="json") @console_ns.route("/account/education") class EducationApi(Resource): - status_fields = { - "result": fields.Boolean, - "is_student": fields.Boolean, - "expire_at": TimestampField, - "allow_refresh": fields.Boolean, - } - @console_ns.expect(console_ns.models[EducationActivatePayload.__name__]) @setup_required @login_required @@ -491,37 +525,33 @@ class EducationApi(Resource): @account_initialization_required @only_edition_cloud @cloud_edition_billing_enabled - @marshal_with(status_fields) + @console_ns.response(200, "Success", console_ns.models[EducationStatusResponse.__name__]) def get(self): account, _ = current_account_with_tenant() - res = BillingService.EducationIdentity.status(account.id) + res = BillingService.EducationIdentity.status(account.id) or {} # convert expire_at to UTC timestamp from isoformat if res and "expire_at" in res: res["expire_at"] = datetime.fromisoformat(res["expire_at"]).astimezone(pytz.utc) - return res + return EducationStatusResponse.model_validate(res).model_dump(mode="json") @console_ns.route("/account/education/autocomplete") class EducationAutoCompleteApi(Resource): - data_fields = { - "data": fields.List(fields.String), - "curr_page": fields.Integer, - "has_next": fields.Boolean, - } - @console_ns.expect(console_ns.models[EducationAutocompleteQuery.__name__]) @setup_required @login_required @account_initialization_required @only_edition_cloud @cloud_edition_billing_enabled - @marshal_with(data_fields) + @console_ns.response(200, "Success", console_ns.models[EducationAutocompleteResponse.__name__]) def get(self): payload = request.args.to_dict(flat=True) args = EducationAutocompleteQuery.model_validate(payload) - return BillingService.EducationIdentity.autocomplete(args.keywords, args.page, args.limit) + return EducationAutocompleteResponse.model_validate( + BillingService.EducationIdentity.autocomplete(args.keywords, args.page, args.limit) or {} + ).model_dump(mode="json") @console_ns.route("/account/change-email") diff --git a/api/controllers/console/workspace/models.py b/api/controllers/console/workspace/models.py index 
9182dbb510..f8f95304f0 100644 --- a/api/controllers/console/workspace/models.py +++ b/api/controllers/console/workspace/models.py @@ -465,7 +465,7 @@ class ModelProviderModelDisableApi(Resource): class ParserValidate(BaseModel): model: str model_type: ModelType - credentials: dict + credentials: dict[str, Any] console_ns.schema_model( diff --git a/api/controllers/console/workspace/workspace.py b/api/controllers/console/workspace/workspace.py index 42874e6033..565099db61 100644 --- a/api/controllers/console/workspace/workspace.py +++ b/api/controllers/console/workspace/workspace.py @@ -1,8 +1,9 @@ import logging +from datetime import datetime from flask import request -from flask_restx import Resource, fields, marshal, marshal_with -from pydantic import BaseModel, Field +from flask_restx import Resource, fields, marshal +from pydantic import BaseModel, Field, field_validator from sqlalchemy import select from werkzeug.exceptions import Unauthorized @@ -26,6 +27,7 @@ from controllers.console.wraps import ( ) from enums.cloud_plan import CloudPlan from extensions.ext_database import db +from fields.base import ResponseModel from libs.helper import TimestampField from libs.login import current_account_with_tenant, login_required from models.account import Tenant, TenantCustomConfigDict, TenantStatus @@ -58,6 +60,37 @@ class WorkspaceInfoPayload(BaseModel): name: str +class TenantInfoResponse(ResponseModel): + id: str + name: str | None = None + plan: str | None = None + status: str | None = None + created_at: int | None = None + role: str | None = None + in_trial: bool | None = None + trial_end_reason: str | None = None + custom_config: dict | None = None + trial_credits: int | None = None + trial_credits_used: int | None = None + next_credit_reset_date: int | None = None + + @field_validator("plan", "status", "trial_end_reason", mode="before") + @classmethod + def _normalize_enum_like(cls, value): + if value is None: + return None + if isinstance(value, str): + return value + return str(getattr(value, "value", value)) + + @field_validator("created_at", mode="before") + @classmethod + def _normalize_created_at(cls, value: datetime | int | None): + if isinstance(value, datetime): + return int(value.timestamp()) + return value + + def reg(cls: type[BaseModel]): console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)) @@ -66,6 +99,7 @@ reg(WorkspaceListQuery) reg(SwitchWorkspacePayload) reg(WorkspaceCustomConfigPayload) reg(WorkspaceInfoPayload) +reg(TenantInfoResponse) provider_fields = { "provider_name": fields.String, @@ -180,7 +214,7 @@ class TenantApi(Resource): @setup_required @login_required @account_initialization_required - @marshal_with(tenant_fields) + @console_ns.response(200, "Success", console_ns.models[TenantInfoResponse.__name__]) def post(self): if request.path == "/info": logger.warning("Deprecated URL /info was used.") @@ -200,7 +234,13 @@ class TenantApi(Resource): else: raise Unauthorized("workspace is archived") - return WorkspaceService.get_tenant_info(tenant), 200 + return ( + TenantInfoResponse.model_validate( + WorkspaceService.get_tenant_info(tenant), + from_attributes=True, + ).model_dump(mode="json"), + 200, + ) @console_ns.route("/workspaces/switch") diff --git a/api/controllers/service_api/app/conversation.py b/api/controllers/service_api/app/conversation.py index 1ec289e2a2..50851aea08 100644 --- a/api/controllers/service_api/app/conversation.py +++ b/api/controllers/service_api/app/conversation.py @@ -1,7 +1,9 @@ 
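# The response models introduced above all follow one pattern: a pydantic model
# whose mode="before" validators coerce ORM values (datetimes, enum members)
# into the wire types that flask_restx marshalling used to produce. A minimal
# self-contained sketch, with illustrative names and plain BaseModel standing
# in for fields.base.ResponseModel:
from datetime import datetime

from pydantic import BaseModel, field_validator


class ExampleResponse(BaseModel):
    id: str
    created_at: int | None = None

    @field_validator("created_at", mode="before")
    @classmethod
    def _normalize_created_at(cls, value: datetime | int | None) -> int | None:
        # ORM rows carry datetimes; the API contract is epoch seconds.
        return int(value.timestamp()) if isinstance(value, datetime) else value


# Typical call site, mirroring the handlers above:
#   ExampleResponse.model_validate(row, from_attributes=True).model_dump(mode="json")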
+from datetime import datetime from typing import Any, Literal from flask import request from flask_restx import Resource +from graphon.variables.types import SegmentType from pydantic import BaseModel, Field, TypeAdapter, field_validator from sqlalchemy.orm import sessionmaker from werkzeug.exceptions import BadRequest, NotFound @@ -14,14 +16,12 @@ from controllers.service_api.app.error import NotChatAppError from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token from core.app.entities.app_invoke_entities import InvokeFrom from extensions.ext_database import db +from fields._value_type_serializer import serialize_value_type +from fields.base import ResponseModel from fields.conversation_fields import ( ConversationInfiniteScrollPagination, SimpleConversation, ) -from fields.conversation_variable_fields import ( - build_conversation_variable_infinite_scroll_pagination_model, - build_conversation_variable_model, -) from libs.helper import UUIDStrOrEmpty from models.model import App, AppMode, EndUser from services.conversation_service import ConversationService @@ -70,12 +70,70 @@ class ConversationVariableUpdatePayload(BaseModel): value: Any +class ConversationVariableResponse(ResponseModel): + id: str + name: str + value_type: str + value: str | None = None + description: str | None = None + created_at: int | None = None + updated_at: int | None = None + + @field_validator("value_type", mode="before") + @classmethod + def normalize_value_type(cls, value: Any) -> str: + exposed_type = getattr(value, "exposed_type", None) + if callable(exposed_type): + return str(exposed_type().value) + if isinstance(value, str): + try: + return str(SegmentType(value).exposed_type().value) + except ValueError: + return value + try: + return serialize_value_type(value) + except (AttributeError, TypeError, ValueError): + pass + + try: + return serialize_value_type({"value_type": value}) + except (AttributeError, TypeError, ValueError): + value_attr = getattr(value, "value", None) + if value_attr is not None: + return str(value_attr) + return str(value) + + @field_validator("value", mode="before") + @classmethod + def normalize_value(cls, value: Any | None) -> str | None: + if value is None: + return None + if isinstance(value, str): + return value + return str(value) + + @field_validator("created_at", "updated_at", mode="before") + @classmethod + def normalize_timestamp(cls, value: datetime | int | None) -> int | None: + if isinstance(value, datetime): + return int(value.timestamp()) + return value + + +class ConversationVariableInfiniteScrollPaginationResponse(ResponseModel): + limit: int + has_more: bool + data: list[ConversationVariableResponse] + + register_schema_models( service_api_ns, ConversationListQuery, ConversationRenamePayload, ConversationVariablesQuery, ConversationVariableUpdatePayload, + ConversationVariableResponse, + ConversationVariableInfiniteScrollPaginationResponse, ) @@ -204,8 +262,12 @@ class ConversationVariablesApi(Resource): 404: "Conversation not found", } ) + @service_api_ns.response( + 200, + "Variables retrieved successfully", + service_api_ns.models[ConversationVariableInfiniteScrollPaginationResponse.__name__], + ) @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY)) - @service_api_ns.marshal_with(build_conversation_variable_infinite_scroll_pagination_model(service_api_ns)) def get(self, app_model: App, end_user: EndUser, c_id): """List all variables for a conversation. 
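# The normalize_value_type validator above tries several representations in
# order. A condensed sketch of that fallback chain; SegmentType comes from the
# graphon import added in this hunk, and the helper name here is illustrative:
from graphon.variables.types import SegmentType


def coerce_value_type(value) -> str:
    exposed_type = getattr(value, "exposed_type", None)
    if callable(exposed_type):
        # Already a SegmentType member: report its exposed (public) type name.
        return str(exposed_type().value)
    if isinstance(value, str):
        try:
            # Internal segment-type strings map to their exposed names.
            return str(SegmentType(value).exposed_type().value)
        except ValueError:
            return value  # assume it is already an exposed name
    # Last resort: enum-like objects expose .value; everything else stringifies.
    return str(getattr(value, "value", value))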
@@ -222,9 +284,12 @@ class ConversationVariablesApi(Resource): last_id = str(query_args.last_id) if query_args.last_id else None try: - return ConversationService.get_conversational_variable( + pagination = ConversationService.get_conversational_variable( app_model, conversation_id, end_user, query_args.limit, last_id, query_args.variable_name ) + return ConversationVariableInfiniteScrollPaginationResponse.model_validate( + pagination, from_attributes=True + ).model_dump(mode="json") except services.errors.conversation.ConversationNotExistsError: raise NotFound("Conversation Not Exists.") @@ -243,8 +308,12 @@ class ConversationVariableDetailApi(Resource): 404: "Conversation or variable not found", } ) + @service_api_ns.response( + 200, + "Variable updated successfully", + service_api_ns.models[ConversationVariableResponse.__name__], + ) @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON)) - @service_api_ns.marshal_with(build_conversation_variable_model(service_api_ns)) def put(self, app_model: App, end_user: EndUser, c_id, variable_id): """Update a conversation variable's value. @@ -261,9 +330,10 @@ class ConversationVariableDetailApi(Resource): payload = ConversationVariableUpdatePayload.model_validate(service_api_ns.payload or {}) try: - return ConversationService.update_conversation_variable( + variable = ConversationService.update_conversation_variable( app_model, conversation_id, variable_id, end_user, payload.value ) + return ConversationVariableResponse.model_validate(variable, from_attributes=True).model_dump(mode="json") except services.errors.conversation.ConversationNotExistsError: raise NotFound("Conversation Not Exists.") except services.errors.conversation.ConversationVariableNotExistsError: diff --git a/api/controllers/service_api/app/workflow.py b/api/controllers/service_api/app/workflow.py index e0a64ffe26..d5544ff473 100644 --- a/api/controllers/service_api/app/workflow.py +++ b/api/controllers/service_api/app/workflow.py @@ -1,13 +1,15 @@ import logging +from collections.abc import Mapping +from datetime import datetime from typing import Literal from dateutil.parser import isoparse from flask import request -from flask_restx import Namespace, Resource, fields +from flask_restx import Resource, fields from graphon.enums import WorkflowExecutionStatus from graphon.graph_engine.manager import GraphEngineManager from graphon.model_runtime.errors.invoke import InvokeError -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, field_validator from sqlalchemy.orm import sessionmaker from werkzeug.exceptions import BadRequest, InternalServerError, NotFound @@ -33,9 +35,10 @@ from core.errors.error import ( from core.helper.trace_id_helper import get_external_trace_id from extensions.ext_database import db from extensions.ext_redis import redis_client -from fields.workflow_app_log_fields import build_workflow_app_log_pagination_model +from fields.base import ResponseModel +from fields.end_user_fields import SimpleEndUser +from fields.member_fields import SimpleAccount from libs import helper -from libs.helper import OptionalTimestampField, TimestampField from models.model import App, AppMode, EndUser from models.workflow import WorkflowRun from repositories.factory import DifyAPIRepositoryFactory @@ -65,38 +68,142 @@ class WorkflowLogQuery(BaseModel): register_schema_models(service_api_ns, WorkflowRunPayload, WorkflowLogQuery) +def _to_timestamp(value: datetime | int | None) -> int | None: + if isinstance(value, datetime): + 
return int(value.timestamp()) + return value + + +def _enum_value(value): + return getattr(value, "value", value) + + class WorkflowRunStatusField(fields.Raw): def output(self, key, obj: WorkflowRun, **kwargs): - return obj.status.value + return _enum_value(obj.status) class WorkflowRunOutputsField(fields.Raw): def output(self, key, obj: WorkflowRun, **kwargs): - if obj.status == WorkflowExecutionStatus.PAUSED: + status = _enum_value(obj.status) + if status == WorkflowExecutionStatus.PAUSED.value: return {} outputs = obj.outputs_dict return outputs or {} -workflow_run_fields = { - "id": fields.String, - "workflow_id": fields.String, - "status": WorkflowRunStatusField, - "inputs": fields.Raw, - "outputs": WorkflowRunOutputsField, - "error": fields.String, - "total_steps": fields.Integer, - "total_tokens": fields.Integer, - "created_at": TimestampField, - "finished_at": OptionalTimestampField, - "elapsed_time": fields.Float, -} +class WorkflowRunResponse(ResponseModel): + id: str + workflow_id: str + status: str + inputs: dict | list | str | int | float | bool | None = None + outputs: dict = Field(default_factory=dict) + error: str | None = None + total_steps: int | None = None + total_tokens: int | None = None + created_at: int | None = None + finished_at: int | None = None + elapsed_time: float | int | None = None + + @field_validator("created_at", "finished_at", mode="before") + @classmethod + def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: + return _to_timestamp(value) -def build_workflow_run_model(api_or_ns: Namespace): - """Build the workflow run model for the API or Namespace.""" - return api_or_ns.model("WorkflowRun", workflow_run_fields) +class WorkflowRunForLogResponse(ResponseModel): + id: str + version: str | None = None + status: str | None = None + triggered_from: str | None = None + error: str | None = None + elapsed_time: float | int | None = None + total_tokens: int | None = None + total_steps: int | None = None + created_at: int | None = None + finished_at: int | None = None + exceptions_count: int | None = None + + @field_validator("status", "triggered_from", mode="before") + @classmethod + def _normalize_enum(cls, value): + return _enum_value(value) + + @field_validator("created_at", "finished_at", mode="before") + @classmethod + def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: + return _to_timestamp(value) + + +class WorkflowAppLogPartialResponse(ResponseModel): + id: str + workflow_run: WorkflowRunForLogResponse | None = None + details: dict | list | str | int | float | bool | None = None + created_from: str | None = None + created_by_role: str | None = None + created_by_account: SimpleAccount | None = None + created_by_end_user: SimpleEndUser | None = None + created_at: int | None = None + + @field_validator("created_from", "created_by_role", mode="before") + @classmethod + def _normalize_enum(cls, value): + return _enum_value(value) + + @field_validator("created_at", mode="before") + @classmethod + def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: + return _to_timestamp(value) + + +class WorkflowAppLogPaginationResponse(ResponseModel): + page: int + limit: int + total: int + has_more: bool + data: list[WorkflowAppLogPartialResponse] + + +register_schema_models( + service_api_ns, + WorkflowRunResponse, + WorkflowRunForLogResponse, + WorkflowAppLogPartialResponse, + WorkflowAppLogPaginationResponse, +) + + +def _serialize_workflow_run(workflow_run: WorkflowRun) -> dict: + status = 
_enum_value(workflow_run.status) + raw_outputs = workflow_run.outputs_dict + if status == WorkflowExecutionStatus.PAUSED.value or raw_outputs is None: + outputs: dict = {} + elif isinstance(raw_outputs, dict): + outputs = raw_outputs + elif isinstance(raw_outputs, Mapping): + outputs = dict(raw_outputs) + else: + outputs = {} + return WorkflowRunResponse.model_validate( + { + "id": workflow_run.id, + "workflow_id": workflow_run.workflow_id, + "status": status, + "inputs": workflow_run.inputs, + "outputs": outputs, + "error": workflow_run.error, + "total_steps": workflow_run.total_steps, + "total_tokens": workflow_run.total_tokens, + "created_at": workflow_run.created_at, + "finished_at": workflow_run.finished_at, + "elapsed_time": workflow_run.elapsed_time, + } + ).model_dump(mode="json") + + +def _serialize_workflow_log_pagination(pagination) -> dict: + return WorkflowAppLogPaginationResponse.model_validate(pagination, from_attributes=True).model_dump(mode="json") @service_api_ns.route("/workflows/run/") @@ -112,7 +219,11 @@ class WorkflowRunDetailApi(Resource): } ) @validate_app_token - @service_api_ns.marshal_with(build_workflow_run_model(service_api_ns)) + @service_api_ns.response( + 200, + "Workflow run details retrieved successfully", + service_api_ns.models[WorkflowRunResponse.__name__], + ) def get(self, app_model: App, workflow_run_id: str): """Get a workflow task running detail. @@ -133,7 +244,7 @@ class WorkflowRunDetailApi(Resource): ) if not workflow_run: raise NotFound("Workflow run not found.") - return workflow_run + return _serialize_workflow_run(workflow_run) @service_api_ns.route("/workflows/run") @@ -299,7 +410,11 @@ class WorkflowAppLogApi(Resource): } ) @validate_app_token - @service_api_ns.marshal_with(build_workflow_app_log_pagination_model(service_api_ns)) + @service_api_ns.response( + 200, + "Logs retrieved successfully", + service_api_ns.models[WorkflowAppLogPaginationResponse.__name__], + ) def get(self, app_model: App): """Get workflow app logs. 
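# _serialize_workflow_run above never exposes outputs for paused runs and
# tolerates non-dict mappings. The guard in isolation, as a sketch with an
# illustrative helper name; the enum is the one imported in this file:
from collections.abc import Mapping

from graphon.enums import WorkflowExecutionStatus


def safe_outputs(status: str, raw_outputs) -> dict:
    # Paused runs must not leak partial outputs; None and non-mapping values
    # collapse to an empty dict so the response schema stays stable.
    if status == WorkflowExecutionStatus.PAUSED.value or raw_outputs is None:
        return {}
    if isinstance(raw_outputs, dict):
        return raw_outputs
    return dict(raw_outputs) if isinstance(raw_outputs, Mapping) else {}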
@@ -327,4 +442,4 @@ class WorkflowAppLogApi(Resource): created_by_account=args.created_by_account, ) - return workflow_app_log_pagination + return _serialize_workflow_log_pagination(workflow_app_log_pagination) diff --git a/api/core/agent/output_parser/cot_output_parser.py b/api/core/agent/output_parser/cot_output_parser.py index 46c1f1230d..8cccd2be6d 100644 --- a/api/core/agent/output_parser/cot_output_parser.py +++ b/api/core/agent/output_parser/cot_output_parser.py @@ -1,7 +1,7 @@ import json import re from collections.abc import Generator -from typing import Union +from typing import Any, Union from graphon.model_runtime.entities.llm_entities import LLMResultChunk @@ -11,7 +11,7 @@ from core.agent.entities import AgentScratchpadUnit class CotAgentOutputParser: @classmethod def handle_react_stream_output( - cls, llm_response: Generator[LLMResultChunk, None, None], usage_dict: dict + cls, llm_response: Generator[LLMResultChunk, None, None], usage_dict: dict[str, Any] ) -> Generator[Union[str, AgentScratchpadUnit.Action], None, None]: def parse_action(action) -> Union[str, AgentScratchpadUnit.Action]: action_name = None diff --git a/api/core/agent/plugin_entities.py b/api/core/agent/plugin_entities.py index 90aa7b5fd4..8d25863a91 100644 --- a/api/core/agent/plugin_entities.py +++ b/api/core/agent/plugin_entities.py @@ -84,7 +84,7 @@ class AgentStrategyEntity(BaseModel): identity: AgentStrategyIdentity parameters: list[AgentStrategyParameter] = Field(default_factory=list) description: I18nObject = Field(..., description="The description of the agent strategy") - output_schema: dict | None = None + output_schema: dict[str, Any] | None = None features: list[AgentFeature] | None = None meta_version: str | None = None # pydantic configs diff --git a/api/core/app/app_config/common/sensitive_word_avoidance/manager.py b/api/core/app/app_config/common/sensitive_word_avoidance/manager.py index 7d1b11c008..c8ec7cb44d 100644 --- a/api/core/app/app_config/common/sensitive_word_avoidance/manager.py +++ b/api/core/app/app_config/common/sensitive_word_avoidance/manager.py @@ -22,8 +22,8 @@ class SensitiveWordAvoidanceConfigManager: @classmethod def validate_and_set_defaults( - cls, tenant_id: str, config: dict, only_structure_validate: bool = False - ) -> tuple[dict, list[str]]: + cls, tenant_id: str, config: dict[str, Any], only_structure_validate: bool = False + ) -> tuple[dict[str, Any], list[str]]: if not config.get("sensitive_word_avoidance"): config["sensitive_word_avoidance"] = {"enabled": False} diff --git a/api/core/app/app_config/easy_ui_based_app/model_config/manager.py b/api/core/app/app_config/easy_ui_based_app/model_config/manager.py index 981bd26961..9d980e5ca3 100644 --- a/api/core/app/app_config/easy_ui_based_app/model_config/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/model_config/manager.py @@ -41,7 +41,7 @@ class ModelConfigManager: ) @classmethod - def validate_and_set_defaults(cls, tenant_id: str, config: Mapping[str, Any]) -> tuple[dict, list[str]]: + def validate_and_set_defaults(cls, tenant_id: str, config: Mapping[str, Any]) -> tuple[dict[str, Any], list[str]]: """ Validate and set defaults for model config diff --git a/api/core/app/apps/advanced_chat/generate_response_converter.py b/api/core/app/apps/advanced_chat/generate_response_converter.py index 5c9bc43992..fe2702ed69 100644 --- a/api/core/app/apps/advanced_chat/generate_response_converter.py +++ b/api/core/app/apps/advanced_chat/generate_response_converter.py @@ -57,7 +57,7 @@ class 
AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter): @classmethod def convert_stream_full_response( cls, stream_response: Generator[AppStreamResponse, None, None] - ) -> Generator[dict | str, Any, None]: + ) -> Generator[dict[str, Any] | str, Any, None]: """ Convert stream full response. :param stream_response: stream response @@ -88,7 +88,7 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter): @classmethod def convert_stream_simple_response( cls, stream_response: Generator[AppStreamResponse, None, None] - ) -> Generator[dict | str, Any, None]: + ) -> Generator[dict[str, Any] | str, Any, None]: """ Convert stream simple response. :param stream_response: stream response diff --git a/api/core/app/apps/agent_chat/generate_response_converter.py b/api/core/app/apps/agent_chat/generate_response_converter.py index 0c146c388f..731c6ee12e 100644 --- a/api/core/app/apps/agent_chat/generate_response_converter.py +++ b/api/core/app/apps/agent_chat/generate_response_converter.py @@ -1,5 +1,5 @@ from collections.abc import Generator -from typing import cast +from typing import Any, cast from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter from core.app.entities.task_entities import ( @@ -56,7 +56,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter): @classmethod def convert_stream_full_response( cls, stream_response: Generator[AppStreamResponse, None, None] - ) -> Generator[dict | str, None, None]: + ) -> Generator[dict[str, Any] | str, None, None]: """ Convert stream full response. :param stream_response: stream response @@ -87,7 +87,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter): @classmethod def convert_stream_simple_response( cls, stream_response: Generator[AppStreamResponse, None, None] - ) -> Generator[dict | str, None, None]: + ) -> Generator[dict[str, Any] | str, None, None]: """ Convert stream simple response. 
:param stream_response: stream response diff --git a/api/core/app/apps/base_app_generate_response_converter.py b/api/core/app/apps/base_app_generate_response_converter.py index 6e5a86505c..406d07927e 100644 --- a/api/core/app/apps/base_app_generate_response_converter.py +++ b/api/core/app/apps/base_app_generate_response_converter.py @@ -24,7 +24,7 @@ class AppGenerateResponseConverter(ABC): return cls.convert_blocking_full_response(response) else: - def _generate_full_response() -> Generator[dict | str, Any, None]: + def _generate_full_response() -> Generator[dict[str, Any] | str, Any, None]: yield from cls.convert_stream_full_response(response) return _generate_full_response() @@ -33,7 +33,7 @@ class AppGenerateResponseConverter(ABC): return cls.convert_blocking_simple_response(response) else: - def _generate_simple_response() -> Generator[dict | str, Any, None]: + def _generate_simple_response() -> Generator[dict[str, Any] | str, Any, None]: yield from cls.convert_stream_simple_response(response) return _generate_simple_response() @@ -52,14 +52,14 @@ class AppGenerateResponseConverter(ABC): @abstractmethod def convert_stream_full_response( cls, stream_response: Generator[AppStreamResponse, None, None] - ) -> Generator[dict | str, None, None]: + ) -> Generator[dict[str, Any] | str, None, None]: raise NotImplementedError @classmethod @abstractmethod def convert_stream_simple_response( cls, stream_response: Generator[AppStreamResponse, None, None] - ) -> Generator[dict | str, None, None]: + ) -> Generator[dict[str, Any] | str, None, None]: raise NotImplementedError @classmethod diff --git a/api/core/app/apps/chat/generate_response_converter.py b/api/core/app/apps/chat/generate_response_converter.py index f23ee7f89f..3d0375151d 100644 --- a/api/core/app/apps/chat/generate_response_converter.py +++ b/api/core/app/apps/chat/generate_response_converter.py @@ -1,5 +1,5 @@ from collections.abc import Generator -from typing import cast +from typing import Any, cast from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter from core.app.entities.task_entities import ( @@ -56,7 +56,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter): @classmethod def convert_stream_full_response( cls, stream_response: Generator[AppStreamResponse, None, None] - ) -> Generator[dict | str, None, None]: + ) -> Generator[dict[str, Any] | str, None, None]: """ Convert stream full response. :param stream_response: stream response @@ -87,7 +87,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter): @classmethod def convert_stream_simple_response( cls, stream_response: Generator[AppStreamResponse, None, None] - ) -> Generator[dict | str, None, None]: + ) -> Generator[dict[str, Any] | str, None, None]: """ Convert stream simple response. 
:param stream_response: stream response diff --git a/api/core/app/apps/completion/generate_response_converter.py b/api/core/app/apps/completion/generate_response_converter.py index a4f574642d..71886b39ba 100644 --- a/api/core/app/apps/completion/generate_response_converter.py +++ b/api/core/app/apps/completion/generate_response_converter.py @@ -1,5 +1,5 @@ from collections.abc import Generator -from typing import cast +from typing import Any, cast from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter from core.app.entities.task_entities import ( @@ -55,7 +55,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter): @classmethod def convert_stream_full_response( cls, stream_response: Generator[AppStreamResponse, None, None] - ) -> Generator[dict | str, None, None]: + ) -> Generator[dict[str, Any] | str, None, None]: """ Convert stream full response. :param stream_response: stream response @@ -85,7 +85,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter): @classmethod def convert_stream_simple_response( cls, stream_response: Generator[AppStreamResponse, None, None] - ) -> Generator[dict | str, None, None]: + ) -> Generator[dict[str, Any] | str, None, None]: """ Convert stream simple response. :param stream_response: stream response diff --git a/api/core/app/apps/pipeline/generate_response_converter.py b/api/core/app/apps/pipeline/generate_response_converter.py index a77a978946..02b3160b7c 100644 --- a/api/core/app/apps/pipeline/generate_response_converter.py +++ b/api/core/app/apps/pipeline/generate_response_converter.py @@ -37,7 +37,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter): @classmethod def convert_stream_full_response( cls, stream_response: Generator[AppStreamResponse, None, None] - ) -> Generator[dict | str, None, None]: + ) -> Generator[dict[str, Any] | str, None, None]: """ Convert stream full response. :param stream_response: stream response @@ -66,7 +66,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter): @classmethod def convert_stream_simple_response( cls, stream_response: Generator[AppStreamResponse, None, None] - ) -> Generator[dict | str, None, None]: + ) -> Generator[dict[str, Any] | str, None, None]: """ Convert stream simple response. :param stream_response: stream response diff --git a/api/core/app/apps/workflow/generate_response_converter.py b/api/core/app/apps/workflow/generate_response_converter.py index c64f44a603..c69826cbef 100644 --- a/api/core/app/apps/workflow/generate_response_converter.py +++ b/api/core/app/apps/workflow/generate_response_converter.py @@ -1,5 +1,5 @@ from collections.abc import Generator -from typing import cast +from typing import Any, cast from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter from core.app.entities.task_entities import ( @@ -37,7 +37,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter): @classmethod def convert_stream_full_response( cls, stream_response: Generator[AppStreamResponse, None, None] - ) -> Generator[dict | str, None, None]: + ) -> Generator[dict[str, Any] | str, None, None]: """ Convert stream full response. 
:param stream_response: stream response @@ -66,7 +66,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter): @classmethod def convert_stream_simple_response( cls, stream_response: Generator[AppStreamResponse, None, None] - ) -> Generator[dict | str, None, None]: + ) -> Generator[dict[str, Any] | str, None, None]: """ Convert stream simple response. :param stream_response: stream response diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py index f1b8b08eaa..96387133b1 100644 --- a/api/core/app/apps/workflow/generate_task_pipeline.py +++ b/api/core/app/apps/workflow/generate_task_pipeline.py @@ -682,15 +682,16 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport): def _save_workflow_app_log(self, *, session: Session, workflow_run_id: str | None): invoke_from = self._application_generate_entity.invoke_from - if invoke_from == InvokeFrom.SERVICE_API: - created_from = WorkflowAppLogCreatedFrom.SERVICE_API - elif invoke_from == InvokeFrom.EXPLORE: - created_from = WorkflowAppLogCreatedFrom.INSTALLED_APP - elif invoke_from == InvokeFrom.WEB_APP: - created_from = WorkflowAppLogCreatedFrom.WEB_APP - else: - # not save log for debugging - return + match invoke_from: + case InvokeFrom.SERVICE_API: + created_from = WorkflowAppLogCreatedFrom.SERVICE_API + case InvokeFrom.EXPLORE: + created_from = WorkflowAppLogCreatedFrom.INSTALLED_APP + case InvokeFrom.WEB_APP: + created_from = WorkflowAppLogCreatedFrom.WEB_APP + case InvokeFrom.DEBUGGER | InvokeFrom.TRIGGER | InvokeFrom.PUBLISHED_PIPELINE | InvokeFrom.VALIDATION: + # not save log for debugging + return if not workflow_run_id: return diff --git a/api/core/extension/api_based_extension_requestor.py b/api/core/extension/api_based_extension_requestor.py index f9e6099049..01139d07e2 100644 --- a/api/core/extension/api_based_extension_requestor.py +++ b/api/core/extension/api_based_extension_requestor.py @@ -1,4 +1,4 @@ -from typing import cast +from typing import Any, cast import httpx @@ -14,7 +14,7 @@ class APIBasedExtensionRequestor: self.api_endpoint = api_endpoint self.api_key = api_key - def request(self, point: APIBasedExtensionPoint, params: dict): + def request(self, point: APIBasedExtensionPoint, params: dict[str, Any]) -> dict[str, Any]: """ Request the api. 
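# With the dict[str, Any] annotations above, callers get a typed JSON payload
# back from the requestor. A hedged usage sketch; the endpoint and key are
# placeholders, and APIBasedExtensionPoint.PING is assumed from
# models.api_based_extension:
from typing import Any

requestor = APIBasedExtensionRequestor(
    api_endpoint="https://example.com/hooks/dify",  # placeholder
    api_key="example-key",  # placeholder
)
result: dict[str, Any] = requestor.request(APIBasedExtensionPoint.PING, params={})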
@@ -49,4 +49,4 @@ class APIBasedExtensionRequestor: if response.status_code != 200: raise ValueError(f"request error, status_code: {response.status_code}, content: {response.text[:100]}") - return cast(dict, response.json()) + return cast(dict[str, Any], response.json()) diff --git a/api/core/extension/extensible.py b/api/core/extension/extensible.py index b79dbeb7e0..c08e319aac 100644 --- a/api/core/extension/extensible.py +++ b/api/core/extension/extensible.py @@ -21,8 +21,8 @@ class ExtensionModule(StrEnum): class ModuleExtension(BaseModel): extension_class: Any | None = None name: str - label: dict | None = None - form_schema: list | None = None + label: dict[str, Any] | None = None + form_schema: list[dict[str, Any]] | None = None builtin: bool = True position: int | None = None diff --git a/api/core/external_data_tool/factory.py b/api/core/external_data_tool/factory.py index f7a64cea1b..f404aa7286 100644 --- a/api/core/external_data_tool/factory.py +++ b/api/core/external_data_tool/factory.py @@ -13,7 +13,7 @@ class ExternalDataToolFactory: ) @classmethod - def validate_config(cls, name: str, tenant_id: str, config: dict[str, Any]): + def validate_config(cls, name: str, tenant_id: str, config: dict[str, Any]) -> None: """ Validate the incoming form config data. diff --git a/api/core/hosting_configuration.py b/api/core/hosting_configuration.py index 60f5434bc1..f8f56e12d2 100644 --- a/api/core/hosting_configuration.py +++ b/api/core/hosting_configuration.py @@ -1,3 +1,5 @@ +from typing import Any + from flask import Flask from graphon.model_runtime.entities.model_entities import ModelType from pydantic import BaseModel @@ -28,7 +30,7 @@ class FreeHostingQuota(HostingQuota): class HostingProvider(BaseModel): enabled: bool = False - credentials: dict | None = None + credentials: dict[str, Any] | None = None quota_unit: QuotaUnit | None = None quotas: list[HostingQuota] = [] diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py index 9aaf85dc0f..8d0a8b99b4 100644 --- a/api/core/indexing_runner.py +++ b/api/core/indexing_runner.py @@ -737,7 +737,7 @@ class IndexingRunner: def _update_document_index_status( document_id: str, after_indexing_status: IndexingStatus, - extra_update_params: dict[Any, Any] | None = None, + extra_update_params: Mapping[Any, Any] | None = None, ): """ Update the document indexing status. @@ -764,7 +764,7 @@ class IndexingRunner: db.session.commit() @staticmethod - def _update_segments_by_document(dataset_document_id: str, update_params: dict[Any, Any]): + def _update_segments_by_document(dataset_document_id: str, update_params: Mapping[Any, Any]): """ Update the document segment by document id. 
""" diff --git a/api/core/llm_generator/llm_generator.py b/api/core/llm_generator/llm_generator.py index aa258c9f89..c43c0274cd 100644 --- a/api/core/llm_generator/llm_generator.py +++ b/api/core/llm_generator/llm_generator.py @@ -2,7 +2,7 @@ import json import logging import re from collections.abc import Sequence -from typing import Protocol, TypedDict, cast +from typing import Any, Protocol, TypedDict, cast import json_repair from graphon.enums import WorkflowNodeExecutionMetadataKey @@ -533,7 +533,7 @@ class LLMGenerator: def __instruction_modify_common( tenant_id: str, model_config: ModelConfig, - last_run: dict | None, + last_run: dict[str, Any] | None, current: str | None, error_message: str | None, instruction: str, diff --git a/api/core/llm_generator/output_parser/structured_output.py b/api/core/llm_generator/output_parser/structured_output.py index 9bdca1e83b..a8ad7c9179 100644 --- a/api/core/llm_generator/output_parser/structured_output.py +++ b/api/core/llm_generator/output_parser/structured_output.py @@ -202,7 +202,7 @@ def _handle_native_json_schema( structured_output_schema: Mapping, model_parameters: dict[str, Any], rules: list[ParameterRule], -): +) -> dict[str, Any]: """ Handle structured output for models with native JSON schema support. @@ -224,7 +224,7 @@ def _handle_native_json_schema( return model_parameters -def _set_response_format(model_parameters: dict[str, Any], rules: list[ParameterRule]): +def _set_response_format(model_parameters: dict[str, Any], rules: list[ParameterRule]) -> None: """ Set the appropriate response format parameter based on model rules. @@ -326,7 +326,7 @@ def _prepare_schema_for_model(provider: str, model_schema: AIModelEntity, schema return {"schema": processed_schema, "name": "llm_response"} -def remove_additional_properties(schema: dict[str, Any]): +def remove_additional_properties(schema: dict[str, Any]) -> None: """ Remove additionalProperties fields from JSON schema. Used for models like Gemini that don't support this property. @@ -349,7 +349,7 @@ def remove_additional_properties(schema: dict[str, Any]): remove_additional_properties(item) -def convert_boolean_to_string(schema: dict): +def convert_boolean_to_string(schema: dict[str, Any]) -> None: """ Convert boolean type specifications to string in JSON schema. 
diff --git a/api/core/plugin/backwards_invocation/model.py b/api/core/plugin/backwards_invocation/model.py index c715b9171c..a4b24ff849 100644 --- a/api/core/plugin/backwards_invocation/model.py +++ b/api/core/plugin/backwards_invocation/model.py @@ -1,6 +1,7 @@ import tempfile from binascii import hexlify, unhexlify from collections.abc import Generator +from typing import Any from graphon.model_runtime.entities.llm_entities import ( LLMResult, @@ -226,7 +227,7 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation): # invoke model response = model_instance.invoke_tts(content_text=payload.content_text, voice=payload.voice) - def handle() -> Generator[dict, None, None]: + def handle() -> Generator[dict[str, Any], None, None]: for chunk in response: yield {"result": hexlify(chunk).decode("utf-8")} diff --git a/api/core/prompt/simple_prompt_transform.py b/api/core/prompt/simple_prompt_transform.py index d4e17613a2..dc8391a6a5 100644 --- a/api/core/prompt/simple_prompt_transform.py +++ b/api/core/prompt/simple_prompt_transform.py @@ -313,7 +313,7 @@ class SimplePromptTransform(PromptTransform): return prompt_message - def _get_prompt_rule(self, app_mode: AppMode, provider: str, model: str): + def _get_prompt_rule(self, app_mode: AppMode, provider: str, model: str) -> dict[str, Any]: """ Get simple prompt rule. :param app_mode: app mode @@ -325,7 +325,7 @@ class SimplePromptTransform(PromptTransform): # Check if the prompt file is already loaded if prompt_file_name in prompt_file_contents: - return cast(dict, prompt_file_contents[prompt_file_name]) + return cast(dict[str, Any], prompt_file_contents[prompt_file_name]) # Get the absolute path of the subdirectory prompt_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "prompt_templates") @@ -338,7 +338,7 @@ class SimplePromptTransform(PromptTransform): # Store the content of the prompt file prompt_file_contents[prompt_file_name] = content - return cast(dict, content) + return cast(dict[str, Any], content) def _prompt_file_name(self, app_mode: AppMode, provider: str, model: str) -> str: # baichuan diff --git a/api/core/rag/embedding/cached_embedding.py b/api/core/rag/embedding/cached_embedding.py index f5f5f541da..9f1c73ec88 100644 --- a/api/core/rag/embedding/cached_embedding.py +++ b/api/core/rag/embedding/cached_embedding.py @@ -106,7 +106,7 @@ class CacheEmbedding(Embeddings): return text_embeddings - def embed_multimodal_documents(self, multimodel_documents: list[dict]) -> list[list[float]]: + def embed_multimodal_documents(self, multimodel_documents: list[dict[str, Any]]) -> list[list[float]]: """Embed file documents.""" # use doc embedding cache or store if not exists multimodel_embeddings: list[Any] = [None for _ in range(len(multimodel_documents))] diff --git a/api/core/rag/embedding/embedding_base.py b/api/core/rag/embedding/embedding_base.py index ab190d2c42..7ae5c09ab7 100644 --- a/api/core/rag/embedding/embedding_base.py +++ b/api/core/rag/embedding/embedding_base.py @@ -11,7 +11,7 @@ class Embeddings(ABC): raise NotImplementedError @abstractmethod - def embed_multimodal_documents(self, multimodel_documents: list[dict]) -> list[list[float]]: + def embed_multimodal_documents(self, multimodel_documents: list[dict[str, Any]]) -> list[list[float]]: """Embed file documents.""" raise NotImplementedError diff --git a/api/core/rag/extractor/extract_processor.py b/api/core/rag/extractor/extract_processor.py index 449be6a448..fbd2a6db93 100644 --- a/api/core/rag/extractor/extract_processor.py +++ 
b/api/core/rag/extractor/extract_processor.py @@ -95,9 +95,9 @@ class ExtractProcessor: ) -> list[Document]: if extract_setting.datasource_type == DatasourceType.FILE: with tempfile.TemporaryDirectory() as temp_dir: + upload_file = extract_setting.upload_file if not file_path: - assert extract_setting.upload_file is not None, "upload_file is required" - upload_file: UploadFile = extract_setting.upload_file + assert upload_file is not None, "upload_file is required" suffix = Path(upload_file.key).suffix # FIXME mypy: Cannot determine type of 'tempfile._get_candidate_names' better not use it here file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}" # type: ignore @@ -113,6 +113,7 @@ class ExtractProcessor: if file_extension in {".xlsx", ".xls"}: extractor = ExcelExtractor(file_path) elif file_extension == ".pdf": + assert upload_file is not None extractor = PdfExtractor(file_path, upload_file.tenant_id, upload_file.created_by) elif file_extension in {".md", ".markdown", ".mdx"}: extractor = ( @@ -123,6 +124,7 @@ class ExtractProcessor: elif file_extension in {".htm", ".html"}: extractor = HtmlExtractor(file_path) elif file_extension == ".docx": + assert upload_file is not None extractor = WordExtractor(file_path, upload_file.tenant_id, upload_file.created_by) elif file_extension == ".doc": extractor = UnstructuredWordExtractor(file_path, unstructured_api_url, unstructured_api_key) @@ -149,12 +151,14 @@ class ExtractProcessor: if file_extension in {".xlsx", ".xls"}: extractor = ExcelExtractor(file_path) elif file_extension == ".pdf": + assert upload_file is not None extractor = PdfExtractor(file_path, upload_file.tenant_id, upload_file.created_by) elif file_extension in {".md", ".markdown", ".mdx"}: extractor = MarkdownExtractor(file_path, autodetect_encoding=True) elif file_extension in {".htm", ".html"}: extractor = HtmlExtractor(file_path) elif file_extension == ".docx": + assert upload_file is not None extractor = WordExtractor(file_path, upload_file.tenant_id, upload_file.created_by) elif file_extension == ".csv": extractor = CSVExtractor(file_path, autodetect_encoding=True) diff --git a/api/core/rag/extractor/firecrawl/firecrawl_app.py b/api/core/rag/extractor/firecrawl/firecrawl_app.py index 89bdd56a6c..556158cf00 100644 --- a/api/core/rag/extractor/firecrawl/firecrawl_app.py +++ b/api/core/rag/extractor/firecrawl/firecrawl_app.py @@ -174,21 +174,25 @@ class FirecrawlApp: return f"{self.base_url.rstrip('/')}/{path.lstrip('/')}" def _post_request(self, url, data, headers, retries=3, backoff_factor=0.5) -> httpx.Response: + response: httpx.Response | None = None for attempt in range(retries): response = httpx.post(url, headers=headers, json=data) if response.status_code == 502: time.sleep(backoff_factor * (2**attempt)) else: return response + assert response is not None, "retries must be at least 1" return response def _get_request(self, url, headers, retries=3, backoff_factor=0.5) -> httpx.Response: + response: httpx.Response | None = None for attempt in range(retries): response = httpx.get(url, headers=headers) if response.status_code == 502: time.sleep(backoff_factor * (2**attempt)) else: return response + assert response is not None, "retries must be at least 1" return response def _handle_error(self, response, action): diff --git a/api/core/rag/retrieval/output_parser/react_output.py b/api/core/rag/retrieval/output_parser/react_output.py index 9a14d41716..29abae4280 100644 --- a/api/core/rag/retrieval/output_parser/react_output.py +++ 
b/api/core/rag/retrieval/output_parser/react_output.py @@ -1,7 +1,7 @@ from __future__ import annotations from dataclasses import dataclass -from typing import NamedTuple, Union +from typing import Any, NamedTuple, Union @dataclass @@ -10,7 +10,7 @@ class ReactAction: tool: str """The name of the Tool to execute.""" - tool_input: Union[str, dict] + tool_input: Union[str, dict[str, Any]] """The input to pass in to the Tool.""" log: str """Additional information to log about the action.""" @@ -19,7 +19,7 @@ class ReactAction: class ReactFinish(NamedTuple): """The final return value of an ReactFinish.""" - return_values: dict + return_values: dict[str, Any] """Dictionary of return values.""" log: str """Additional information to log about the return value""" diff --git a/api/core/rag/retrieval/router/multi_dataset_react_route.py b/api/core/rag/retrieval/router/multi_dataset_react_route.py index dd280cdf6a..9b223075d8 100644 --- a/api/core/rag/retrieval/router/multi_dataset_react_route.py +++ b/api/core/rag/retrieval/router/multi_dataset_react_route.py @@ -1,5 +1,5 @@ from collections.abc import Generator, Sequence -from typing import Union +from typing import Any, Union from graphon.model_runtime.entities.llm_entities import LLMResult, LLMUsage from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool @@ -139,7 +139,7 @@ class ReactMultiDatasetRouter: def _invoke_llm( self, - completion_param: dict, + completion_param: dict[str, Any], model_instance: ModelInstance, prompt_messages: list[PromptMessage], stop: list[str], diff --git a/api/core/rag/splitter/text_splitter.py b/api/core/rag/splitter/text_splitter.py index 8977611f93..7f2117e2dd 100644 --- a/api/core/rag/splitter/text_splitter.py +++ b/api/core/rag/splitter/text_splitter.py @@ -63,7 +63,7 @@ class TextSplitter(BaseDocumentTransformer, ABC): def split_text(self, text: str) -> list[str]: """Split text into multiple components.""" - def create_documents(self, texts: list[str], metadatas: list[dict] | None = None) -> list[Document]: + def create_documents(self, texts: list[str], metadatas: list[dict[str, Any]] | None = None) -> list[Document]: """Create documents from a list of texts.""" _metadatas = metadatas or [{}] * len(texts) documents = [] diff --git a/api/core/schemas/resolver.py b/api/core/schemas/resolver.py index 6e26664ac2..e267c1abd9 100644 --- a/api/core/schemas/resolver.py +++ b/api/core/schemas/resolver.py @@ -254,7 +254,7 @@ def resolve_dify_schema_refs( return resolver.resolve(schema) -def _remove_metadata_fields(schema: dict) -> dict: +def _remove_metadata_fields(schema: dict[str, Any]) -> dict[str, Any]: """ Remove metadata fields from schema that shouldn't be included in resolved output diff --git a/api/core/tools/entities/tool_bundle.py b/api/core/tools/entities/tool_bundle.py index 10710c4376..4e07b7157a 100644 --- a/api/core/tools/entities/tool_bundle.py +++ b/api/core/tools/entities/tool_bundle.py @@ -1,4 +1,5 @@ from collections.abc import Mapping +from typing import Any from pydantic import BaseModel, Field @@ -26,6 +27,6 @@ class ApiToolBundle(BaseModel): # icon icon: str | None = None # openapi operation - openapi: dict + openapi: dict[str, Any] # output schema output_schema: Mapping[str, object] = Field(default_factory=dict) diff --git a/api/core/tools/tool_engine.py b/api/core/tools/tool_engine.py index 1afaa9cfaf..d060fa8b49 100644 --- a/api/core/tools/tool_engine.py +++ b/api/core/tools/tool_engine.py @@ -47,7 +47,7 @@ class ToolEngine: @staticmethod 
def agent_invoke( tool: Tool, - tool_parameters: Union[str, dict], + tool_parameters: Union[str, dict[str, Any]], user_id: str, tenant_id: str, message: Message, @@ -85,7 +85,7 @@ class ToolEngine: invocation_meta_dict: dict[str, ToolInvokeMeta] = {} def message_callback( - invocation_meta_dict: dict[str, Any], + invocation_meta_dict: dict[str, ToolInvokeMeta], messages: Generator[ToolInvokeMessage | ToolInvokeMeta, None, None], ): for message in messages: diff --git a/api/core/tools/tool_label_manager.py b/api/core/tools/tool_label_manager.py index 58190d1089..d8969a3391 100644 --- a/api/core/tools/tool_label_manager.py +++ b/api/core/tools/tool_label_manager.py @@ -1,4 +1,5 @@ from sqlalchemy import delete, select +from sqlalchemy.orm import Session, sessionmaker from core.tools.__base.tool_provider import ToolProviderController from core.tools.builtin_tool.provider import BuiltinToolProviderController @@ -19,10 +20,18 @@ class ToolLabelManager: return list(set(tool_labels)) @classmethod - def update_tool_labels(cls, controller: ToolProviderController, labels: list[str]): + def update_tool_labels( + cls, controller: ToolProviderController, labels: list[str], session: Session | None = None + ) -> None: """ Update tool labels + + :param controller: tool provider controller + :param labels: list of tool labels + :param session: database session, if None, a new session will be created + :return: None """ + labels = cls.filter_tool_labels(labels) if isinstance(controller, ApiToolProviderController | WorkflowToolProviderController): @@ -30,26 +39,46 @@ class ToolLabelManager: else: raise ValueError("Unsupported tool type") + if session is not None: + cls._update_tool_labels_logics(session, provider_id, controller, labels) + else: + with sessionmaker(db.engine).begin() as _session: + cls._update_tool_labels_logics(_session, provider_id, controller, labels) + + @classmethod + def _update_tool_labels_logics( + cls, session: Session, provider_id: str, controller: ToolProviderController, labels: list[str] + ) -> None: + """ + Update tool labels logics + + :param session: database session + :param provider_id: tool provider ID + :param controller: tool provider controller + :param labels: list of tool labels + :return: None + """ + # delete old labels - db.session.execute(delete(ToolLabelBinding).where(ToolLabelBinding.tool_id == provider_id)) + _ = session.execute( + delete(ToolLabelBinding).where( + ToolLabelBinding.tool_id == provider_id, ToolLabelBinding.tool_type == controller.provider_type + ) + ) # insert new labels for label in labels: - db.session.add( - ToolLabelBinding( - tool_id=provider_id, - tool_type=controller.provider_type, - label_name=label, - ) - ) - - db.session.commit() + session.add(ToolLabelBinding(tool_id=provider_id, tool_type=controller.provider_type, label_name=label)) @classmethod def get_tool_labels(cls, controller: ToolProviderController) -> list[str]: """ Get tool labels + + :param controller: tool provider controller + :return: list of tool labels (str) """ + if isinstance(controller, ApiToolProviderController | WorkflowToolProviderController): provider_id = controller.provider_id elif isinstance(controller, BuiltinToolProviderController): @@ -60,9 +89,11 @@ class ToolLabelManager: ToolLabelBinding.tool_id == provider_id, ToolLabelBinding.tool_type == controller.provider_type, ) - labels = db.session.scalars(stmt).all() - return list(labels) + with sessionmaker(db.engine, expire_on_commit=False).begin() as _session: + labels: list[str] = 
list(_session.scalars(stmt).all()) + + return labels @classmethod def get_tools_labels(cls, tool_providers: list[ToolProviderController]) -> dict[str, list[str]]: @@ -78,16 +109,22 @@ class ToolLabelManager: if not tool_providers: return {} + provider_ids: list[str] = [] + provider_types: set[str] = set() + for controller in tool_providers: if not isinstance(controller, ApiToolProviderController | WorkflowToolProviderController): raise ValueError("Unsupported tool type") - - provider_ids = [] - for controller in tool_providers: - assert isinstance(controller, ApiToolProviderController | WorkflowToolProviderController) provider_ids.append(controller.provider_id) + provider_types.add(controller.provider_type) - labels = db.session.scalars(select(ToolLabelBinding).where(ToolLabelBinding.tool_id.in_(provider_ids))).all() + labels: list[ToolLabelBinding] = [] + + with sessionmaker(db.engine, expire_on_commit=False).begin() as _session: + stmt = select(ToolLabelBinding).where( + ToolLabelBinding.tool_id.in_(provider_ids), ToolLabelBinding.tool_type.in_(list(provider_types)) + ) + labels = list(_session.scalars(stmt).all()) tool_labels: dict[str, list[str]] = {label.tool_id: [] for label in labels} diff --git a/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py b/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py index 6a189fa6aa..0d1dc7273b 100644 --- a/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py +++ b/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py @@ -1,4 +1,4 @@ -from typing import cast +from typing import Any, cast from pydantic import BaseModel, Field from sqlalchemy import select @@ -39,7 +39,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool): dataset_id: str user_id: str | None = None retrieve_config: DatasetRetrieveConfigEntity - inputs: dict + inputs: dict[str, Any] @classmethod def from_dataset(cls, dataset: Dataset, **kwargs): diff --git a/api/core/tools/workflow_as_tool/tool.py b/api/core/tools/workflow_as_tool/tool.py index 552fbab1a4..7c4f8ee03a 100644 --- a/api/core/tools/workflow_as_tool/tool.py +++ b/api/core/tools/workflow_as_tool/tool.py @@ -277,7 +277,7 @@ class WorkflowTool(Tool): session.expunge(app) return app - def _transform_args(self, tool_parameters: dict[str, Any]) -> tuple[dict[str, Any], list[dict[str, Any]]]: + def _transform_args(self, tool_parameters: dict[str, Any]) -> tuple[dict[str, Any], list[dict[str, str | None]]]: """ transform the tool parameters @@ -355,7 +355,7 @@ class WorkflowTool(Tool): return result, files - def _update_file_mapping(self, file_dict: dict[str, Any]): + def _update_file_mapping(self, file_dict: dict[str, Any]) -> dict[str, Any]: file_id = resolve_file_record_id(file_dict.get("reference") or file_dict.get("related_id")) transfer_method = FileTransferMethod.value_of(file_dict.get("transfer_method")) match transfer_method: diff --git a/api/enterprise/telemetry/metric_handler.py b/api/enterprise/telemetry/metric_handler.py index 9cda0bf90a..c564ace584 100644 --- a/api/enterprise/telemetry/metric_handler.py +++ b/api/enterprise/telemetry/metric_handler.py @@ -329,7 +329,7 @@ class EnterpriseMetricHandler: return include_content = exporter.include_content - attrs: dict = { + attrs: dict[str, Any] = { "dify.message.id": payload.get("message_id"), "dify.tenant_id": envelope.tenant_id, "dify.event.id": envelope.event_id, diff --git a/api/extensions/ext_celery.py b/api/extensions/ext_celery.py index 86b0550187..340f514fcc 100644 --- a/api/extensions/ext_celery.py +++ 
b/api/extensions/ext_celery.py @@ -9,6 +9,7 @@ from typing_extensions import TypedDict from configs import dify_config from dify_app import DifyApp +from extensions.redis_names import normalize_redis_key_prefix class _CelerySentinelKwargsDict(TypedDict): @@ -16,9 +17,10 @@ class _CelerySentinelKwargsDict(TypedDict): password: str | None -class CelerySentinelTransportDict(TypedDict): +class CelerySentinelTransportDict(TypedDict, total=False): master_name: str | None sentinel_kwargs: _CelerySentinelKwargsDict + global_keyprefix: str class CelerySSLOptionsDict(TypedDict): @@ -61,15 +63,31 @@ def get_celery_ssl_options() -> CelerySSLOptionsDict | None: def get_celery_broker_transport_options() -> CelerySentinelTransportDict | dict[str, Any]: """Get broker transport options (e.g. Redis Sentinel) for Celery connections.""" + transport_options: CelerySentinelTransportDict | dict[str, Any] if dify_config.CELERY_USE_SENTINEL: - return CelerySentinelTransportDict( + transport_options = CelerySentinelTransportDict( master_name=dify_config.CELERY_SENTINEL_MASTER_NAME, sentinel_kwargs=_CelerySentinelKwargsDict( socket_timeout=dify_config.CELERY_SENTINEL_SOCKET_TIMEOUT, password=dify_config.CELERY_SENTINEL_PASSWORD, ), ) - return {} + else: + transport_options = {} + + global_keyprefix = get_celery_redis_global_keyprefix() + if global_keyprefix: + transport_options["global_keyprefix"] = global_keyprefix + + return transport_options + + +def get_celery_redis_global_keyprefix() -> str | None: + """Return the Redis transport prefix for Celery when namespace isolation is enabled.""" + normalized_prefix = normalize_redis_key_prefix(dify_config.REDIS_KEY_PREFIX) + if not normalized_prefix: + return None + return f"{normalized_prefix}:" def init_app(app: DifyApp) -> Celery: diff --git a/api/extensions/ext_redis.py b/api/extensions/ext_redis.py index 20f05b8b9e..9f7f73765e 100644 --- a/api/extensions/ext_redis.py +++ b/api/extensions/ext_redis.py @@ -3,7 +3,7 @@ import logging import ssl from collections.abc import Callable from datetime import timedelta -from typing import TYPE_CHECKING, Any, Union +from typing import Any, Union, cast import redis from redis import RedisError @@ -18,17 +18,26 @@ from typing_extensions import TypedDict from configs import dify_config from dify_app import DifyApp +from extensions.redis_names import ( + normalize_redis_key_prefix, + serialize_redis_name, + serialize_redis_name_arg, + serialize_redis_name_args, +) from libs.broadcast_channel.channel import BroadcastChannel as BroadcastChannelProtocol from libs.broadcast_channel.redis.channel import BroadcastChannel as RedisBroadcastChannel from libs.broadcast_channel.redis.sharded_channel import ShardedRedisBroadcastChannel from libs.broadcast_channel.redis.streams_channel import StreamsBroadcastChannel -if TYPE_CHECKING: - from redis.lock import Lock - logger = logging.getLogger(__name__) +_normalize_redis_key_prefix = normalize_redis_key_prefix +_serialize_redis_name = serialize_redis_name +_serialize_redis_name_arg = serialize_redis_name_arg +_serialize_redis_name_args = serialize_redis_name_args + + class RedisClientWrapper: """ A wrapper class for the Redis client that addresses the issue where the global @@ -59,68 +68,148 @@ class RedisClientWrapper: if self._client is None: self._client = client - if TYPE_CHECKING: - # Type hints for IDE support and static analysis - # These are not executed at runtime but provide type information - def get(self, name: str | bytes) -> Any: ... 
- - def set( - self, - name: str | bytes, - value: Any, - ex: int | None = None, - px: int | None = None, - nx: bool = False, - xx: bool = False, - keepttl: bool = False, - get: bool = False, - exat: int | None = None, - pxat: int | None = None, - ) -> Any: ... - - def setex(self, name: str | bytes, time: int | timedelta, value: Any) -> Any: ... - def setnx(self, name: str | bytes, value: Any) -> Any: ... - def delete(self, *names: str | bytes) -> Any: ... - def incr(self, name: str | bytes, amount: int = 1) -> Any: ... - def expire( - self, - name: str | bytes, - time: int | timedelta, - nx: bool = False, - xx: bool = False, - gt: bool = False, - lt: bool = False, - ) -> Any: ... - def lock( - self, - name: str, - timeout: float | None = None, - sleep: float = 0.1, - blocking: bool = True, - blocking_timeout: float | None = None, - thread_local: bool = True, - ) -> Lock: ... - def zadd( - self, - name: str | bytes, - mapping: dict[str | bytes | int | float, float | int | str | bytes], - nx: bool = False, - xx: bool = False, - ch: bool = False, - incr: bool = False, - gt: bool = False, - lt: bool = False, - ) -> Any: ... - def zremrangebyscore(self, name: str | bytes, min: float | str, max: float | str) -> Any: ... - def zcard(self, name: str | bytes) -> Any: ... - def getdel(self, name: str | bytes) -> Any: ... - def pubsub(self) -> PubSub: ... - def pipeline(self, transaction: bool = True, shard_hint: str | None = None) -> Any: ... - - def __getattr__(self, item: str) -> Any: + def _require_client(self) -> redis.Redis | RedisCluster: if self._client is None: raise RuntimeError("Redis client is not initialized. Call init_app first.") - return getattr(self._client, item) + return self._client + + def _get_prefix(self) -> str: + return dify_config.REDIS_KEY_PREFIX + + def get(self, name: str | bytes) -> Any: + return self._require_client().get(_serialize_redis_name_arg(name, self._get_prefix())) + + def set( + self, + name: str | bytes, + value: Any, + ex: int | None = None, + px: int | None = None, + nx: bool = False, + xx: bool = False, + keepttl: bool = False, + get: bool = False, + exat: int | None = None, + pxat: int | None = None, + ) -> Any: + return self._require_client().set( + _serialize_redis_name_arg(name, self._get_prefix()), + value, + ex=ex, + px=px, + nx=nx, + xx=xx, + keepttl=keepttl, + get=get, + exat=exat, + pxat=pxat, + ) + + def setex(self, name: str | bytes, time: int | timedelta, value: Any) -> Any: + return self._require_client().setex(_serialize_redis_name_arg(name, self._get_prefix()), time, value) + + def setnx(self, name: str | bytes, value: Any) -> Any: + return self._require_client().setnx(_serialize_redis_name_arg(name, self._get_prefix()), value) + + def delete(self, *names: str | bytes) -> Any: + return self._require_client().delete(*_serialize_redis_name_args(names, self._get_prefix())) + + def incr(self, name: str | bytes, amount: int = 1) -> Any: + return self._require_client().incr(_serialize_redis_name_arg(name, self._get_prefix()), amount) + + def expire( + self, + name: str | bytes, + time: int | timedelta, + nx: bool = False, + xx: bool = False, + gt: bool = False, + lt: bool = False, + ) -> Any: + return self._require_client().expire( + _serialize_redis_name_arg(name, self._get_prefix()), + time, + nx=nx, + xx=xx, + gt=gt, + lt=lt, + ) + + def exists(self, *names: str | bytes) -> Any: + return self._require_client().exists(*_serialize_redis_name_args(names, self._get_prefix())) + + def ttl(self, name: str | bytes) -> Any: + return 
self._require_client().ttl(_serialize_redis_name_arg(name, self._get_prefix())) + + def getdel(self, name: str | bytes) -> Any: + return self._require_client().getdel(_serialize_redis_name_arg(name, self._get_prefix())) + + def lock( + self, + name: str, + timeout: float | None = None, + sleep: float = 0.1, + blocking: bool = True, + blocking_timeout: float | None = None, + thread_local: bool = True, + ) -> Any: + return self._require_client().lock( + _serialize_redis_name(name, self._get_prefix()), + timeout=timeout, + sleep=sleep, + blocking=blocking, + blocking_timeout=blocking_timeout, + thread_local=thread_local, + ) + + def hset(self, name: str | bytes, *args: Any, **kwargs: Any) -> Any: + return self._require_client().hset(_serialize_redis_name_arg(name, self._get_prefix()), *args, **kwargs) + + def hgetall(self, name: str | bytes) -> Any: + return self._require_client().hgetall(_serialize_redis_name_arg(name, self._get_prefix())) + + def hdel(self, name: str | bytes, *keys: str | bytes) -> Any: + return self._require_client().hdel(_serialize_redis_name_arg(name, self._get_prefix()), *keys) + + def hlen(self, name: str | bytes) -> Any: + return self._require_client().hlen(_serialize_redis_name_arg(name, self._get_prefix())) + + def zadd( + self, + name: str | bytes, + mapping: dict[str | bytes | int | float, float | int | str | bytes], + nx: bool = False, + xx: bool = False, + ch: bool = False, + incr: bool = False, + gt: bool = False, + lt: bool = False, + ) -> Any: + return self._require_client().zadd( + _serialize_redis_name_arg(name, self._get_prefix()), + cast(Any, mapping), + nx=nx, + xx=xx, + ch=ch, + incr=incr, + gt=gt, + lt=lt, + ) + + def zremrangebyscore(self, name: str | bytes, min: float | str, max: float | str) -> Any: + return self._require_client().zremrangebyscore(_serialize_redis_name_arg(name, self._get_prefix()), min, max) + + def zcard(self, name: str | bytes) -> Any: + return self._require_client().zcard(_serialize_redis_name_arg(name, self._get_prefix())) + + def pubsub(self) -> PubSub: + return self._require_client().pubsub() + + def pipeline(self, transaction: bool = True, shard_hint: str | None = None) -> Any: + return self._require_client().pipeline(transaction=transaction, shard_hint=shard_hint) + + def __getattr__(self, item: str) -> Any: + return getattr(self._require_client(), item) redis_client: RedisClientWrapper = RedisClientWrapper() diff --git a/api/extensions/redis_names.py b/api/extensions/redis_names.py new file mode 100644 index 0000000000..9e63416daf --- /dev/null +++ b/api/extensions/redis_names.py @@ -0,0 +1,32 @@ +from configs import dify_config + + +def normalize_redis_key_prefix(prefix: str | None) -> str: + """Normalize the configured Redis key prefix for consistent runtime use.""" + if prefix is None: + return "" + return prefix.strip() + + +def get_redis_key_prefix() -> str: + """Read and normalize the current Redis key prefix from config.""" + return normalize_redis_key_prefix(dify_config.REDIS_KEY_PREFIX) + + +def serialize_redis_name(name: str, prefix: str | None = None) -> str: + """Convert a logical Redis name into the physical name used in Redis.""" + normalized_prefix = get_redis_key_prefix() if prefix is None else normalize_redis_key_prefix(prefix) + if not normalized_prefix: + return name + return f"{normalized_prefix}:{name}" + + +def serialize_redis_name_arg(name: str | bytes, prefix: str | None = None) -> str | bytes: + """Prefix string Redis names while preserving bytes inputs unchanged.""" + if isinstance(name, bytes): + 
return name + return serialize_redis_name(name, prefix) + + +def serialize_redis_name_args(names: tuple[str | bytes, ...], prefix: str | None = None) -> tuple[str | bytes, ...]: + return tuple(serialize_redis_name_arg(name, prefix) for name in names) diff --git a/api/extensions/storage/opendal_storage.py b/api/extensions/storage/opendal_storage.py index 96f5915ff0..cd7f7db295 100644 --- a/api/extensions/storage/opendal_storage.py +++ b/api/extensions/storage/opendal_storage.py @@ -2,6 +2,7 @@ import logging import os from collections.abc import Generator from pathlib import Path +from typing import Any import opendal from dotenv import dotenv_values @@ -19,7 +20,7 @@ def _get_opendal_kwargs(*, scheme: str, env_file_path: str = ".env", prefix: str if key.startswith(config_prefix): kwargs[key[len(config_prefix) :].lower()] = value - file_env_vars: dict = dotenv_values(env_file_path) or {} + file_env_vars: dict[str, Any] = dotenv_values(env_file_path) or {} for key, value in file_env_vars.items(): if key.startswith(config_prefix) and key[len(config_prefix) :].lower() not in kwargs and value: kwargs[key[len(config_prefix) :].lower()] = value diff --git a/api/libs/broadcast_channel/redis/channel.py b/api/libs/broadcast_channel/redis/channel.py index 36aa1cd3e8..b76a23eb3c 100644 --- a/api/libs/broadcast_channel/redis/channel.py +++ b/api/libs/broadcast_channel/redis/channel.py @@ -2,6 +2,7 @@ from __future__ import annotations from typing import Any +from extensions.redis_names import serialize_redis_name from libs.broadcast_channel.channel import Producer, Subscriber, Subscription from redis import Redis, RedisCluster @@ -32,12 +33,13 @@ class Topic: def __init__(self, redis_client: Redis | RedisCluster, topic: str): self._client = redis_client self._topic = topic + self._redis_topic = serialize_redis_name(topic) def as_producer(self) -> Producer: return self def publish(self, payload: bytes) -> None: - self._client.publish(self._topic, payload) + self._client.publish(self._redis_topic, payload) def as_subscriber(self) -> Subscriber: return self @@ -46,7 +48,7 @@ class Topic: return _RedisSubscription( client=self._client, pubsub=self._client.pubsub(), - topic=self._topic, + topic=self._redis_topic, ) diff --git a/api/libs/broadcast_channel/redis/sharded_channel.py b/api/libs/broadcast_channel/redis/sharded_channel.py index dddc92d099..919d8d622e 100644 --- a/api/libs/broadcast_channel/redis/sharded_channel.py +++ b/api/libs/broadcast_channel/redis/sharded_channel.py @@ -2,6 +2,7 @@ from __future__ import annotations from typing import Any +from extensions.redis_names import serialize_redis_name from libs.broadcast_channel.channel import Producer, Subscriber, Subscription from redis import Redis, RedisCluster @@ -30,12 +31,13 @@ class ShardedTopic: def __init__(self, redis_client: Redis | RedisCluster, topic: str): self._client = redis_client self._topic = topic + self._redis_topic = serialize_redis_name(topic) def as_producer(self) -> Producer: return self def publish(self, payload: bytes) -> None: - self._client.spublish(self._topic, payload) # type: ignore[attr-defined,union-attr] + self._client.spublish(self._redis_topic, payload) # type: ignore[attr-defined,union-attr] def as_subscriber(self) -> Subscriber: return self @@ -44,7 +46,7 @@ class ShardedTopic: return _RedisShardedSubscription( client=self._client, pubsub=self._client.pubsub(), - topic=self._topic, + topic=self._redis_topic, ) diff --git a/api/libs/broadcast_channel/redis/streams_channel.py 
b/api/libs/broadcast_channel/redis/streams_channel.py index 983f785027..55ff6cd4f9 100644 --- a/api/libs/broadcast_channel/redis/streams_channel.py +++ b/api/libs/broadcast_channel/redis/streams_channel.py @@ -6,6 +6,7 @@ import threading from collections.abc import Iterator from typing import Self +from extensions.redis_names import serialize_redis_name from libs.broadcast_channel.channel import Producer, Subscriber, Subscription from libs.broadcast_channel.exc import SubscriptionClosedError from redis import Redis, RedisCluster @@ -35,7 +36,7 @@ class StreamsTopic: def __init__(self, redis_client: Redis | RedisCluster, topic: str, *, retention_seconds: int = 600): self._client = redis_client self._topic = topic - self._key = f"stream:{topic}" + self._key = serialize_redis_name(f"stream:{topic}") self._retention_seconds = retention_seconds self.max_length = 5000 diff --git a/api/libs/db_migration_lock.py b/api/libs/db_migration_lock.py index ca8956e397..b5fe38342a 100644 --- a/api/libs/db_migration_lock.py +++ b/api/libs/db_migration_lock.py @@ -103,7 +103,10 @@ class DbMigrationAutoRenewLock: timeout=self._ttl_seconds, thread_local=False, ) - acquired = bool(self._lock.acquire(*args, **kwargs)) + lock = self._lock + if lock is None: + raise RuntimeError("Redis lock initialization failed.") + acquired = bool(lock.acquire(*args, **kwargs)) self._acquired = acquired if acquired: self._start_heartbeat() diff --git a/api/libs/helper.py b/api/libs/helper.py index f28de92927..69bd483515 100644 --- a/api/libs/helper.py +++ b/api/libs/helper.py @@ -120,10 +120,22 @@ class AppIconUrlField(fields.Raw): obj = obj["app"] if isinstance(obj, App | Site) and obj.icon_type == IconType.IMAGE: - return file_helpers.get_signed_file_url(obj.icon) + return build_icon_url(obj.icon_type, obj.icon) return None +def build_icon_url(icon_type: Any, icon: str | None) -> str | None: + if icon is None or icon_type is None: + return None + + from models.model import IconType + + icon_type_value = icon_type.value if isinstance(icon_type, IconType) else str(icon_type) + if icon_type_value.lower() != IconType.IMAGE.value: + return None + return file_helpers.get_signed_file_url(icon) + + class AvatarUrlField(fields.Raw): def output(self, key, obj, **kwargs): if obj is None: diff --git a/api/models/dataset.py b/api/models/dataset.py index a48afa7ea7..4540c29206 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -1551,7 +1551,7 @@ class PipelineBuiltInTemplate(TypeBase): name: Mapped[str] = mapped_column(sa.String(255), nullable=False) description: Mapped[str] = mapped_column(LongText, nullable=False) chunk_structure: Mapped[str] = mapped_column(sa.String(255), nullable=False) - icon: Mapped[dict] = mapped_column(sa.JSON, nullable=False) + icon: Mapped[dict[str, Any]] = mapped_column(sa.JSON, nullable=False) yaml_content: Mapped[str] = mapped_column(LongText, nullable=False) copyright: Mapped[str] = mapped_column(sa.String(255), nullable=False) privacy_policy: Mapped[str] = mapped_column(sa.String(255), nullable=False) @@ -1584,7 +1584,7 @@ class PipelineCustomizedTemplate(TypeBase): name: Mapped[str] = mapped_column(sa.String(255), nullable=False) description: Mapped[str] = mapped_column(LongText, nullable=False) chunk_structure: Mapped[str] = mapped_column(sa.String(255), nullable=False) - icon: Mapped[dict] = mapped_column(sa.JSON, nullable=False) + icon: Mapped[dict[str, Any]] = mapped_column(sa.JSON, nullable=False) position: Mapped[int] = mapped_column(sa.Integer, nullable=False) yaml_content: Mapped[str] = 
mapped_column(LongText, nullable=False) install_count: Mapped[int] = mapped_column(sa.Integer, nullable=False) @@ -1657,7 +1657,7 @@ class DocumentPipelineExecutionLog(TypeBase): datasource_type: Mapped[str] = mapped_column(sa.String(255), nullable=False) datasource_info: Mapped[str] = mapped_column(LongText, nullable=False) datasource_node_id: Mapped[str] = mapped_column(sa.String(255), nullable=False) - input_data: Mapped[dict] = mapped_column(sa.JSON, nullable=False) + input_data: Mapped[dict[str, Any]] = mapped_column(sa.JSON, nullable=False) created_by: Mapped[str | None] = mapped_column(StringUUID, nullable=True) created_at: Mapped[datetime] = mapped_column( sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False diff --git a/api/models/model.py b/api/models/model.py index 47b096d0bf..8eabf45363 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -1007,7 +1007,7 @@ class OAuthProviderApp(TypeBase): app_icon: Mapped[str] = mapped_column(String(255), nullable=False) client_id: Mapped[str] = mapped_column(String(255), nullable=False) client_secret: Mapped[str] = mapped_column(String(255), nullable=False) - app_label: Mapped[dict] = mapped_column(sa.JSON, nullable=False, default_factory=dict) + app_label: Mapped[dict[str, Any]] = mapped_column(sa.JSON, nullable=False, default_factory=dict) redirect_uris: Mapped[list] = mapped_column(sa.JSON, nullable=False, default_factory=list) scope: Mapped[str] = mapped_column( String(255), @@ -2495,7 +2495,7 @@ class TraceAppConfig(TypeBase): ) app_id: Mapped[str] = mapped_column(StringUUID, nullable=False) tracing_provider: Mapped[str | None] = mapped_column(String(255), nullable=True) - tracing_config: Mapped[dict | None] = mapped_column(sa.JSON, nullable=True) + tracing_config: Mapped[dict[str, Any] | None] = mapped_column(sa.JSON, nullable=True) created_at: Mapped[datetime] = mapped_column( sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False ) diff --git a/api/models/oauth.py b/api/models/oauth.py index 1db2552469..bd04d890d3 100644 --- a/api/models/oauth.py +++ b/api/models/oauth.py @@ -1,4 +1,5 @@ from datetime import datetime +from typing import Any import sqlalchemy as sa from sqlalchemy import func @@ -22,7 +23,7 @@ class DatasourceOauthParamConfig(TypeBase): ) plugin_id: Mapped[str] = mapped_column(sa.String(255), nullable=False) provider: Mapped[str] = mapped_column(sa.String(255), nullable=False) - system_credentials: Mapped[dict] = mapped_column(AdjustedJSON, nullable=False) + system_credentials: Mapped[dict[str, Any]] = mapped_column(AdjustedJSON, nullable=False) class DatasourceProvider(TypeBase): @@ -40,7 +41,7 @@ class DatasourceProvider(TypeBase): provider: Mapped[str] = mapped_column(sa.String(128), nullable=False) plugin_id: Mapped[str] = mapped_column(sa.String(255), nullable=False) auth_type: Mapped[str] = mapped_column(sa.String(255), nullable=False) - encrypted_credentials: Mapped[dict] = mapped_column(AdjustedJSON, nullable=False) + encrypted_credentials: Mapped[dict[str, Any]] = mapped_column(AdjustedJSON, nullable=False) avatar_url: Mapped[str] = mapped_column(LongText, nullable=True, default="default") is_default: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"), default=False) expires_at: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default="-1", default=-1) @@ -70,7 +71,7 @@ class DatasourceOauthTenantParamConfig(TypeBase): tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) provider: 
Mapped[str] = mapped_column(sa.String(255), nullable=False) plugin_id: Mapped[str] = mapped_column(sa.String(255), nullable=False) - client_params: Mapped[dict] = mapped_column(AdjustedJSON, nullable=False, default_factory=dict) + client_params: Mapped[dict[str, Any]] = mapped_column(AdjustedJSON, nullable=False, default_factory=dict) enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, default=False) created_at: Mapped[datetime] = mapped_column( diff --git a/api/models/source.py b/api/models/source.py index 8078b32f8c..8fce7df205 100644 --- a/api/models/source.py +++ b/api/models/source.py @@ -25,7 +25,7 @@ class DataSourceOauthBinding(TypeBase): tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) access_token: Mapped[str] = mapped_column(String(255), nullable=False) provider: Mapped[str] = mapped_column(String(255), nullable=False) - source_info: Mapped[dict] = mapped_column(AdjustedJSON, nullable=False) + source_info: Mapped[dict[str, Any]] = mapped_column(AdjustedJSON, nullable=False) created_at: Mapped[datetime] = mapped_column( DateTime, nullable=False, server_default=func.current_timestamp(), init=False ) diff --git a/api/providers/vdb/vdb-elasticsearch/src/dify_vdb_elasticsearch/elasticsearch_ja_vector.py b/api/providers/vdb/vdb-elasticsearch/src/dify_vdb_elasticsearch/elasticsearch_ja_vector.py index 87b9d813ec..e2f390402a 100644 --- a/api/providers/vdb/vdb-elasticsearch/src/dify_vdb_elasticsearch/elasticsearch_ja_vector.py +++ b/api/providers/vdb/vdb-elasticsearch/src/dify_vdb_elasticsearch/elasticsearch_ja_vector.py @@ -23,7 +23,7 @@ class ElasticSearchJaVector(ElasticSearchVector): self, embeddings: list[list[float]], metadatas: list[dict[Any, Any]] | None = None, - index_params: dict | None = None, + index_params: dict[str, Any] | None = None, ): lock_name = f"vector_indexing_lock_{self._collection_name}" with redis_client.lock(lock_name, timeout=20): diff --git a/api/providers/vdb/vdb-pgvecto-rs/src/dify_vdb_pgvecto_rs/collection.py b/api/providers/vdb/vdb-pgvecto-rs/src/dify_vdb_pgvecto_rs/collection.py index c335bc610d..e087ec30a5 100644 --- a/api/providers/vdb/vdb-pgvecto-rs/src/dify_vdb_pgvecto_rs/collection.py +++ b/api/providers/vdb/vdb-pgvecto-rs/src/dify_vdb_pgvecto_rs/collection.py @@ -1,3 +1,4 @@ +from typing import Any from uuid import UUID from numpy import ndarray @@ -8,5 +9,5 @@ class CollectionORM(DeclarativeBase): __tablename__: str id: Mapped[UUID] text: Mapped[str] - meta: Mapped[dict] + meta: Mapped[dict[str, Any]] vector: Mapped[ndarray] diff --git a/api/providers/vdb/vdb-pgvecto-rs/src/dify_vdb_pgvecto_rs/pgvecto_rs.py b/api/providers/vdb/vdb-pgvecto-rs/src/dify_vdb_pgvecto_rs/pgvecto_rs.py index 84b5eba0ac..9c721c8bde 100644 --- a/api/providers/vdb/vdb-pgvecto-rs/src/dify_vdb_pgvecto_rs/pgvecto_rs.py +++ b/api/providers/vdb/vdb-pgvecto-rs/src/dify_vdb_pgvecto_rs/pgvecto_rs.py @@ -67,7 +67,7 @@ class PGVectoRS(BaseVector): primary_key=True, ) text: Mapped[str] - meta: Mapped[dict] = mapped_column(postgresql.JSONB) + meta: Mapped[dict[str, Any]] = mapped_column(postgresql.JSONB) vector: Mapped[ndarray] = mapped_column(VECTOR(dim)) self._table = _Table diff --git a/api/pyproject.toml b/api/pyproject.toml index 3b7e5f8e1f..f22bafb03a 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -136,6 +136,9 @@ dify-vdb-weaviate = { workspace = true } [tool.uv] default-groups = ["storage", "tools", "vdb-all"] package = false +override-dependencies = [ + "pyarrow>=18.0.0", +] [dependency-groups] diff --git 
a/api/services/entities/knowledge_entities/knowledge_entities.py b/api/services/entities/knowledge_entities/knowledge_entities.py index f6a670415b..b1fe352861 100644 --- a/api/services/entities/knowledge_entities/knowledge_entities.py +++ b/api/services/entities/knowledge_entities/knowledge_entities.py @@ -87,7 +87,7 @@ class RetrievalModel(BaseModel): class MetaDataConfig(BaseModel): doc_type: str - doc_metadata: dict + doc_metadata: dict[str, Any] class KnowledgeConfig(BaseModel): diff --git a/api/services/model_provider_service.py b/api/services/model_provider_service.py index 3f37c9b176..bf208c9bc7 100644 --- a/api/services/model_provider_service.py +++ b/api/services/model_provider_service.py @@ -1,4 +1,5 @@ import logging +from typing import Any from graphon.model_runtime.entities.model_entities import ModelType, ParameterRule @@ -168,7 +169,9 @@ class ModelProviderService: model_name=model, ) - def get_provider_credential(self, tenant_id: str, provider: str, credential_id: str | None = None) -> dict | None: + def get_provider_credential( + self, tenant_id: str, provider: str, credential_id: str | None = None + ) -> dict[str, Any] | None: """ get provider credentials. @@ -180,7 +183,7 @@ class ModelProviderService: provider_configuration = self._get_provider_configuration(tenant_id, provider) return provider_configuration.get_provider_credential(credential_id=credential_id) - def validate_provider_credentials(self, tenant_id: str, provider: str, credentials: dict): + def validate_provider_credentials(self, tenant_id: str, provider: str, credentials: dict[str, Any]): """ validate provider credentials before saving. @@ -192,7 +195,7 @@ class ModelProviderService: provider_configuration.validate_provider_credentials(credentials) def create_provider_credential( - self, tenant_id: str, provider: str, credentials: dict, credential_name: str | None + self, tenant_id: str, provider: str, credentials: dict[str, Any], credential_name: str | None ) -> None: """ Create and save new provider credentials. @@ -210,7 +213,7 @@ class ModelProviderService: self, tenant_id: str, provider: str, - credentials: dict, + credentials: dict[str, Any], credential_id: str, credential_name: str | None, ) -> None: @@ -254,7 +257,7 @@ class ModelProviderService: def get_model_credential( self, tenant_id: str, provider: str, model_type: str, model: str, credential_id: str | None - ) -> dict | None: + ) -> dict[str, Any] | None: """ Retrieve model-specific credentials. @@ -270,7 +273,9 @@ class ModelProviderService: model_type=ModelType.value_of(model_type), model=model, credential_id=credential_id ) - def validate_model_credentials(self, tenant_id: str, provider: str, model_type: str, model: str, credentials: dict): + def validate_model_credentials( + self, tenant_id: str, provider: str, model_type: str, model: str, credentials: dict[str, Any] + ): """ validate model credentials. @@ -287,7 +292,13 @@ class ModelProviderService: ) def create_model_credential( - self, tenant_id: str, provider: str, model_type: str, model: str, credentials: dict, credential_name: str | None + self, + tenant_id: str, + provider: str, + model_type: str, + model: str, + credentials: dict[str, Any], + credential_name: str | None, ) -> None: """ create and save model credentials. 
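For reference, a minimal sketch of the key-naming rule that the new api/extensions/redis_names.py module (earlier in this diff) applies wherever the RedisClientWrapper and the broadcast channels touch Redis: the configured REDIS_KEY_PREFIX is stripped of surrounding whitespace and, when non-empty, joined to the logical name with a colon. The prefix "dify" and the key "upload_progress:123" below are illustrative values, not defaults introduced by this change.

def serialize_redis_name(name: str, prefix: str) -> str:
    # Mirrors the helper above: normalize the prefix, prepend it only when set.
    normalized = (prefix or "").strip()
    return f"{normalized}:{name}" if normalized else name

assert serialize_redis_name("upload_progress:123", "dify") == "dify:upload_progress:123"
assert serialize_redis_name("upload_progress:123", "") == "upload_progress:123"  # empty prefix keeps today's unprefixed names
assert serialize_redis_name("upload_progress:123", " dify ") == "dify:upload_progress:123"  # surrounding whitespace is normalized away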
@@ -314,7 +325,7 @@ class ModelProviderService: provider: str, model_type: str, model: str, - credentials: dict, + credentials: dict[str, Any], credential_id: str, credential_name: str | None, ) -> None: diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py index 5fc5b412b3..605689226a 100644 --- a/api/services/rag_pipeline/rag_pipeline.py +++ b/api/services/rag_pipeline/rag_pipeline.py @@ -104,7 +104,7 @@ class RagPipelineService: self._workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker) @classmethod - def get_pipeline_templates(cls, type: str = "built-in", language: str = "en-US") -> dict: + def get_pipeline_templates(cls, type: str = "built-in", language: str = "en-US") -> dict[str, Any]: if type == "built-in": mode = dify_config.HOSTED_FETCH_PIPELINE_TEMPLATES_MODE retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)() @@ -120,7 +120,7 @@ class RagPipelineService: return result @classmethod - def get_pipeline_template_detail(cls, template_id: str, type: str = "built-in") -> dict | None: + def get_pipeline_template_detail(cls, template_id: str, type: str = "built-in") -> dict[str, Any] | None: """ Get pipeline template detail. @@ -131,7 +131,7 @@ class RagPipelineService: if type == "built-in": mode = dify_config.HOSTED_FETCH_PIPELINE_TEMPLATES_MODE retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)() - built_in_result: dict | None = retrieval_instance.get_pipeline_template_detail(template_id) + built_in_result: dict[str, Any] | None = retrieval_instance.get_pipeline_template_detail(template_id) if built_in_result is None: logger.warning( "pipeline template retrieval returned empty result, template_id: %s, mode: %s", @@ -142,7 +142,7 @@ class RagPipelineService: else: mode = "customized" retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)() - customized_result: dict | None = retrieval_instance.get_pipeline_template_detail(template_id) + customized_result: dict[str, Any] | None = retrieval_instance.get_pipeline_template_detail(template_id) return customized_result @classmethod @@ -297,7 +297,7 @@ class RagPipelineService: self, *, pipeline: Pipeline, - graph: dict, + graph: dict[str, Any], unique_hash: str | None, account: Account, environment_variables: Sequence[VariableBase], @@ -467,7 +467,9 @@ class RagPipelineService: return default_block_configs - def get_default_block_config(self, node_type: str, filters: dict | None = None) -> Mapping[str, object] | None: + def get_default_block_config( + self, node_type: str, filters: dict[str, Any] | None = None + ) -> Mapping[str, object] | None: """ Get default config of node. 
:param node_type: node type @@ -500,7 +502,7 @@ class RagPipelineService: return default_config def run_draft_workflow_node( - self, pipeline: Pipeline, node_id: str, user_inputs: dict, account: Account + self, pipeline: Pipeline, node_id: str, user_inputs: dict[str, Any], account: Account ) -> WorkflowNodeExecutionModel | None: """ Run draft workflow node @@ -582,7 +584,7 @@ class RagPipelineService: self, pipeline: Pipeline, node_id: str, - user_inputs: dict, + user_inputs: dict[str, Any], account: Account, datasource_type: str, is_published: bool, @@ -749,7 +751,7 @@ class RagPipelineService: self, pipeline: Pipeline, node_id: str, - user_inputs: dict, + user_inputs: dict[str, Any], account: Account, datasource_type: str, is_published: bool, @@ -979,7 +981,7 @@ class RagPipelineService: return workflow_node_execution def update_workflow( - self, *, session: Session, workflow_id: str, tenant_id: str, account_id: str, data: dict + self, *, session: Session, workflow_id: str, tenant_id: str, account_id: str, data: dict[str, Any] ) -> Workflow | None: """ Update workflow attributes @@ -1099,7 +1101,9 @@ class RagPipelineService: ] return datasource_provider_variables - def get_rag_pipeline_paginate_workflow_runs(self, pipeline: Pipeline, args: dict) -> InfiniteScrollPagination: + def get_rag_pipeline_paginate_workflow_runs( + self, pipeline: Pipeline, args: dict[str, Any] + ) -> InfiniteScrollPagination: """ Get debug workflow run list Only return triggered_from == debugging @@ -1169,7 +1173,7 @@ class RagPipelineService: return list(node_executions) @classmethod - def publish_customized_pipeline_template(cls, pipeline_id: str, args: dict): + def publish_customized_pipeline_template(cls, pipeline_id: str, args: dict[str, Any]): """ Publish customized pipeline template """ @@ -1259,7 +1263,7 @@ class RagPipelineService: ) return node_exec - def set_datasource_variables(self, pipeline: Pipeline, args: dict, current_user: Account): + def set_datasource_variables(self, pipeline: Pipeline, args: dict[str, Any], current_user: Account): """ Set datasource variables """ @@ -1346,7 +1350,7 @@ class RagPipelineService: ) return workflow_node_execution_db_model - def get_recommended_plugins(self, type: str) -> dict: + def get_recommended_plugins(self, type: str) -> dict[str, Any]: # Query active recommended plugins stmt = select(PipelineRecommendedPlugin).where(PipelineRecommendedPlugin.active == True) if type and type != "all": diff --git a/api/services/tools/workflow_tools_manage_service.py b/api/services/tools/workflow_tools_manage_service.py index 779f7c4511..be2572b592 100644 --- a/api/services/tools/workflow_tools_manage_service.py +++ b/api/services/tools/workflow_tools_manage_service.py @@ -139,62 +139,82 @@ class WorkflowToolManageService: :param labels: labels :return: the updated tool """ - # check if the name is unique - existing_workflow_tool_provider = db.session.scalar( - select(WorkflowToolProvider) - .where( - WorkflowToolProvider.tenant_id == tenant_id, - WorkflowToolProvider.name == name, - WorkflowToolProvider.id != workflow_tool_id, - ) - .limit(1) - ) + existing_workflow_tool_provider: WorkflowToolProvider | None = None + with sessionmaker(db.engine, expire_on_commit=False).begin() as _session: + # query if the name exists for other tools + existing_workflow_tool_provider = _session.scalar( + select(WorkflowToolProvider) + .where( + WorkflowToolProvider.tenant_id == tenant_id, + WorkflowToolProvider.name == name, + WorkflowToolProvider.id != workflow_tool_id, + ) + .limit(1) + ) 
+ + # if the name exists raise error if existing_workflow_tool_provider is not None: raise ValueError(f"Tool with name {name} already exists") - workflow_tool_provider: WorkflowToolProvider | None = db.session.scalar( - select(WorkflowToolProvider) - .where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id) - .limit(1) - ) + # query the workflow tool provider + workflow_tool_provider: WorkflowToolProvider | None = None + with sessionmaker(db.engine, expire_on_commit=False).begin() as _session: + workflow_tool_provider = _session.scalar( + select(WorkflowToolProvider) + .where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id) + .limit(1) + ) + # if not found raise error if workflow_tool_provider is None: raise ValueError(f"Tool {workflow_tool_id} not found") - app: App | None = db.session.scalar( - select(App).where(App.id == workflow_tool_provider.app_id, App.tenant_id == tenant_id).limit(1) - ) + # query the app + app: App | None = None + with sessionmaker(db.engine, expire_on_commit=False).begin() as _session: + app = _session.scalar( + select(App).where(App.id == workflow_tool_provider.app_id, App.tenant_id == tenant_id).limit(1) + ) + # if not found raise error if app is None: raise ValueError(f"App {workflow_tool_provider.app_id} not found") + # query the workflow workflow: Workflow | None = app.workflow + + # if not found raise error if workflow is None: raise ValueError(f"Workflow not found for app {workflow_tool_provider.app_id}") + # ensure the workflow graph contains no human input nodes WorkflowToolConfigurationUtils.ensure_no_human_input_nodes(workflow.graph_dict) - workflow_tool_provider.name = name - workflow_tool_provider.label = label - workflow_tool_provider.icon = json.dumps(icon) - workflow_tool_provider.description = description - workflow_tool_provider.parameter_configuration = json.dumps([p.model_dump() for p in parameters]) - workflow_tool_provider.privacy_policy = privacy_policy - workflow_tool_provider.version = workflow.version - workflow_tool_provider.updated_at = datetime.now() + with sessionmaker(db.engine).begin() as _session: + _session.add(workflow_tool_provider) - try: - WorkflowToolProviderController.from_db(workflow_tool_provider) - except Exception as e: - raise ValueError(str(e)) + # update workflow tool provider + workflow_tool_provider.name = name + workflow_tool_provider.label = label + workflow_tool_provider.icon = json.dumps(icon) + workflow_tool_provider.description = description + workflow_tool_provider.parameter_configuration = json.dumps([p.model_dump() for p in parameters]) + workflow_tool_provider.privacy_policy = privacy_policy + workflow_tool_provider.version = workflow.version + workflow_tool_provider.updated_at = datetime.now() - db.session.commit() + try: + WorkflowToolProviderController.from_db(workflow_tool_provider) + except Exception as e: + raise ValueError(str(e)) - if labels is not None: - ToolLabelManager.update_tool_labels( - ToolTransformService.workflow_provider_to_controller(workflow_tool_provider), labels - ) + if labels is not None: + ToolLabelManager.update_tool_labels( + ToolTransformService.workflow_provider_to_controller(workflow_tool_provider), + labels, + session=_session, + ) return {"result": "success"} diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 839b9e3319..0e1864ce9a 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -241,8 +241,8 @@ class WorkflowService: self, *, 
app_model: App, - graph: dict, - features: dict, + graph: dict[str, Any], + features: dict[str, Any], unique_hash: str | None, account: Account, environment_variables: Sequence[VariableBase], @@ -576,7 +576,7 @@ class WorkflowService: except Exception as e: raise ValueError(f"Failed to validate default credential for tool provider {provider}: {str(e)}") - def _validate_load_balancing_credentials(self, workflow: Workflow, node_data: dict, node_id: str) -> None: + def _validate_load_balancing_credentials(self, workflow: Workflow, node_data: dict[str, Any], node_id: str) -> None: """ Validate load balancing credentials for a workflow node. @@ -1214,7 +1214,7 @@ class WorkflowService: return variable_pool def run_free_workflow_node( - self, node_data: dict, tenant_id: str, user_id: str, node_id: str, user_inputs: dict[str, Any] + self, node_data: dict[str, Any], tenant_id: str, user_id: str, node_id: str, user_inputs: dict[str, Any] ) -> WorkflowNodeExecution: """ Run free workflow node @@ -1361,7 +1361,7 @@ class WorkflowService: node_execution.status = WorkflowNodeExecutionStatus.FAILED node_execution.error = error - def convert_to_workflow(self, app_model: App, account: Account, args: dict) -> App: + def convert_to_workflow(self, app_model: App, account: Account, args: dict[str, Any]) -> App: """ Basic mode of chatbot app(expert mode) to workflow Completion App to Workflow App @@ -1421,7 +1421,7 @@ class WorkflowService: if node_type == BuiltinNodeTypes.HUMAN_INPUT: self._validate_human_input_node_data(node_data) - def validate_features_structure(self, app_model: App, features: dict): + def validate_features_structure(self, app_model: App, features: dict[str, Any]): match app_model.mode: case AppMode.ADVANCED_CHAT: return AdvancedChatAppConfigManager.config_validate( @@ -1434,7 +1434,7 @@ class WorkflowService: case _: raise ValueError(f"Invalid app mode: {app_model.mode}") - def _validate_human_input_node_data(self, node_data: dict) -> None: + def _validate_human_input_node_data(self, node_data: dict[str, Any]) -> None: """ Validate HumanInput node data format. 
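On the Celery side, the same prefix reaches kombu's Redis transport through the global_keyprefix transport option, which get_celery_broker_transport_options() (earlier in this diff) derives from the normalized REDIS_KEY_PREFIX plus a trailing colon. A sketch of the resulting configuration, assuming REDIS_KEY_PREFIX=dify, Sentinel disabled, and a placeholder broker URL:

from celery import Celery

app = Celery("dify", broker="redis://localhost:6379/1")
# Equivalent to what init_app wires up when REDIS_KEY_PREFIX="dify":
app.conf.broker_transport_options = {"global_keyprefix": "dify:"}
# kombu then namespaces its own structures, e.g. the default "celery" queue
# list is stored as "dify:celery" in Redis.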
@@ -1452,7 +1452,7 @@ class WorkflowService: raise ValueError(f"Invalid HumanInput node data: {str(e)}") def update_workflow( - self, *, session: Session, workflow_id: str, tenant_id: str, account_id: str, data: dict + self, *, session: Session, workflow_id: str, tenant_id: str, account_id: str, data: dict[str, Any] ) -> Workflow | None: """ Update workflow attributes diff --git a/api/tests/integration_tests/.env.example b/api/tests/integration_tests/.env.example index f84d39aeb5..c07ab6d6bf 100644 --- a/api/tests/integration_tests/.env.example +++ b/api/tests/integration_tests/.env.example @@ -33,6 +33,7 @@ REDIS_USERNAME= REDIS_PASSWORD=difyai123456 REDIS_USE_SSL=false REDIS_DB=0 +REDIS_KEY_PREFIX= # PostgreSQL database configuration DB_USERNAME=postgres diff --git a/api/tests/test_containers_integration_tests/controllers/console/app/test_app_apis.py b/api/tests/test_containers_integration_tests/controllers/console/app/test_app_apis.py index 54e0496dbd..15dec06311 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/app/test_app_apis.py +++ b/api/tests/test_containers_integration_tests/controllers/console/app/test_app_apis.py @@ -432,7 +432,7 @@ class TestWorkflowAppLogEndpoints: monkeypatch.setattr(workflow_app_log_module, "sessionmaker", DummySessionMaker) def fake_get_paginate(self, **_kwargs): - return {"items": [], "total": 0} + return {"page": 1, "limit": 20, "total": 0, "has_more": False, "data": []} monkeypatch.setattr( workflow_app_log_module.WorkflowAppService, @@ -443,7 +443,7 @@ class TestWorkflowAppLogEndpoints: with app.test_request_context("/?page=1&limit=20"): result = method(app_model=SimpleNamespace(id="app-1")) - assert result == {"items": [], "total": 0} + assert result == {"page": 1, "limit": 20, "total": 0, "has_more": False, "data": []} class TestWorkflowDraftVariableEndpoints: @@ -608,7 +608,8 @@ class TestWorkflowTriggerEndpoints: with app.test_request_context("/?node_id=node-1"): result = method(app_model=SimpleNamespace(id="app-1")) - assert result is trigger + assert isinstance(result, dict) + assert {"id", "webhook_id", "webhook_url", "webhook_debug_url", "node_id", "created_at"} <= set(result.keys()) class TestWrapsEndpoints: diff --git a/api/tests/test_containers_integration_tests/controllers/console/app/test_conversation_read_timestamp.py b/api/tests/test_containers_integration_tests/controllers/console/app/test_conversation_read_timestamp.py new file mode 100644 index 0000000000..fad0b8b10e --- /dev/null +++ b/api/tests/test_containers_integration_tests/controllers/console/app/test_conversation_read_timestamp.py @@ -0,0 +1,73 @@ +from datetime import datetime +from unittest.mock import patch + +import pytest +from sqlalchemy.orm import Session +from werkzeug.exceptions import NotFound + +from controllers.console.app.conversation import _get_conversation +from models.enums import ConversationFromSource +from models.model import AppMode, Conversation +from tests.test_containers_integration_tests.controllers.console.helpers import ( + create_console_account_and_tenant, + create_console_app, +) + + +def test_get_conversation_mark_read_keeps_updated_at_unchanged( + db_session_with_containers: Session, +): + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT) + + original_updated_at = datetime(2026, 2, 8, 0, 0, 0) + conversation = Conversation( + app_id=app.id, + name="read timestamp test", + inputs={}, + 
status="normal", + mode=AppMode.CHAT, + from_source=ConversationFromSource.CONSOLE, + from_account_id=account.id, + updated_at=original_updated_at, + ) + db_session_with_containers.add(conversation) + db_session_with_containers.commit() + + read_at = datetime(2026, 2, 9, 0, 0, 0) + + with ( + patch( + "controllers.console.app.conversation.current_account_with_tenant", + return_value=(account, tenant.id), + autospec=True, + ), + patch( + "controllers.console.app.conversation.naive_utc_now", + return_value=read_at, + autospec=True, + ), + ): + loaded = _get_conversation(app, conversation.id) + + db_session_with_containers.refresh(conversation) + + assert loaded.id == conversation.id + assert conversation.read_at == read_at + assert conversation.read_account_id == account.id + assert conversation.updated_at == original_updated_at + + +def test_get_conversation_raises_not_found_for_missing_conversation( + db_session_with_containers: Session, +): + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT) + + with patch( + "controllers.console.app.conversation.current_account_with_tenant", + return_value=(account, tenant.id), + autospec=True, + ): + with pytest.raises(NotFound): + _get_conversation(app, "00000000-0000-0000-0000-000000000000") diff --git a/api/tests/unit_tests/controllers/web/test_site.py b/api/tests/test_containers_integration_tests/controllers/web/test_site.py similarity index 51% rename from api/tests/unit_tests/controllers/web/test_site.py rename to api/tests/test_containers_integration_tests/controllers/web/test_site.py index 6e9d754c43..9adb26ff3d 100644 --- a/api/tests/unit_tests/controllers/web/test_site.py +++ b/api/tests/test_containers_integration_tests/controllers/web/test_site.py @@ -1,28 +1,48 @@ -"""Unit tests for controllers.web.site endpoints.""" +"""Testcontainers integration tests for controllers.web.site endpoints.""" from __future__ import annotations from types import SimpleNamespace -from unittest.mock import MagicMock, patch +from unittest.mock import patch import pytest from flask import Flask +from sqlalchemy.orm import Session from werkzeug.exceptions import Forbidden from controllers.web.site import AppSiteApi, AppSiteInfo +from models import Tenant, TenantStatus +from models.model import App, AppMode, CustomizeTokenStrategy, Site -def _tenant(*, status: str = "normal") -> SimpleNamespace: - return SimpleNamespace( - id="tenant-1", - status=status, - plan="basic", - custom_config_dict={"remove_webapp_brand": False, "replace_webapp_logo": False}, +@pytest.fixture +def app(flask_app_with_containers) -> Flask: + return flask_app_with_containers + + +def _create_tenant(db_session: Session, *, status: TenantStatus = TenantStatus.NORMAL) -> Tenant: + tenant = Tenant(name="test-tenant", status=status) + db_session.add(tenant) + db_session.commit() + return tenant + + +def _create_app(db_session: Session, tenant_id: str, *, enable_site: bool = True) -> App: + app_model = App( + tenant_id=tenant_id, + mode=AppMode.CHAT, + name="test-app", + enable_site=enable_site, + enable_api=True, ) + db_session.add(app_model) + db_session.commit() + return app_model -def _site() -> SimpleNamespace: - return SimpleNamespace( +def _create_site(db_session: Session, app_id: str) -> Site: + site = Site( + app_id=app_id, title="Site", icon_type="emoji", icon="robot", @@ -31,77 +51,64 @@ def _site() -> SimpleNamespace: default_language="en", chat_color_theme="light", 
chat_color_theme_inverted=False, - copyright=None, - privacy_policy=None, - custom_disclaimer=None, + customize_token_strategy=CustomizeTokenStrategy.NOT_ALLOW, + code=f"code-{app_id[-6:]}", prompt_public=False, show_workflow_steps=True, use_icon_as_answer_icon=False, ) + db_session.add(site) + db_session.commit() + return site -# --------------------------------------------------------------------------- -# AppSiteApi -# --------------------------------------------------------------------------- class TestAppSiteApi: @patch("controllers.web.site.FeatureService.get_features") - @patch("controllers.web.site.db") - def test_happy_path(self, mock_db: MagicMock, mock_features: MagicMock, app: Flask) -> None: + def test_happy_path(self, mock_features, app: Flask, db_session_with_containers: Session) -> None: app.config["RESTX_MASK_HEADER"] = "X-Fields" - mock_features.return_value = SimpleNamespace(can_replace_logo=False) - site_obj = _site() - mock_db.session.scalar.return_value = site_obj - tenant = _tenant() - app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", tenant=tenant, enable_site=True) + tenant = _create_tenant(db_session_with_containers) + app_model = _create_app(db_session_with_containers, tenant.id) + _create_site(db_session_with_containers, app_model.id) end_user = SimpleNamespace(id="eu-1") + mock_features.return_value = SimpleNamespace(can_replace_logo=False) with app.test_request_context("/site"): result = AppSiteApi().get(app_model, end_user) - # marshal_with serializes AppSiteInfo to a dict - assert result["app_id"] == "app-1" + assert result["app_id"] == app_model.id assert result["plan"] == "basic" assert result["enable_site"] is True - @patch("controllers.web.site.db") - def test_missing_site_raises_forbidden(self, mock_db: MagicMock, app: Flask) -> None: + def test_missing_site_raises_forbidden(self, app: Flask, db_session_with_containers: Session) -> None: app.config["RESTX_MASK_HEADER"] = "X-Fields" - mock_db.session.scalar.return_value = None - tenant = _tenant() - app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", tenant=tenant, enable_site=True) + tenant = _create_tenant(db_session_with_containers) + app_model = _create_app(db_session_with_containers, tenant.id) end_user = SimpleNamespace(id="eu-1") with app.test_request_context("/site"): with pytest.raises(Forbidden): AppSiteApi().get(app_model, end_user) - @patch("controllers.web.site.db") - def test_archived_tenant_raises_forbidden(self, mock_db: MagicMock, app: Flask) -> None: + @patch("controllers.web.site.FeatureService.get_features") + def test_archived_tenant_raises_forbidden( + self, mock_features, app: Flask, db_session_with_containers: Session + ) -> None: app.config["RESTX_MASK_HEADER"] = "X-Fields" - from models.account import TenantStatus - - mock_db.session.scalar.return_value = _site() - tenant = SimpleNamespace( - id="tenant-1", - status=TenantStatus.ARCHIVE, - plan="basic", - custom_config_dict={}, - ) - app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", tenant=tenant) + tenant = _create_tenant(db_session_with_containers, status=TenantStatus.ARCHIVE) + app_model = _create_app(db_session_with_containers, tenant.id) + _create_site(db_session_with_containers, app_model.id) end_user = SimpleNamespace(id="eu-1") + mock_features.return_value = SimpleNamespace(can_replace_logo=False) with app.test_request_context("/site"): with pytest.raises(Forbidden): AppSiteApi().get(app_model, end_user) -# --------------------------------------------------------------------------- -# 
AppSiteInfo -# --------------------------------------------------------------------------- class TestAppSiteInfo: def test_basic_fields(self) -> None: - tenant = _tenant() - site_obj = _site() + tenant = SimpleNamespace(id="tenant-1", plan="basic", custom_config_dict={}) + site_obj = SimpleNamespace() info = AppSiteInfo(tenant, SimpleNamespace(id="app-1", enable_site=True), site_obj, "eu-1", False) assert info.app_id == "app-1" @@ -118,7 +125,7 @@ class TestAppSiteInfo: plan="pro", custom_config_dict={"remove_webapp_brand": True, "replace_webapp_logo": True}, ) - site_obj = _site() + site_obj = SimpleNamespace() info = AppSiteInfo(tenant, SimpleNamespace(id="app-1", enable_site=True), site_obj, "eu-1", True) assert info.can_replace_logo is True diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service_permissions.py b/api/tests/test_containers_integration_tests/services/test_dataset_service_permissions.py new file mode 100644 index 0000000000..1b4179c9c7 --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/test_dataset_service_permissions.py @@ -0,0 +1,613 @@ +"""Testcontainers integration tests for DatasetService permission and lifecycle SQL paths.""" + +from datetime import datetime +from types import SimpleNamespace +from unittest.mock import patch +from uuid import uuid4 + +import pytest +from sqlalchemy.orm import Session +from werkzeug.exceptions import NotFound + +from core.rag.index_processor.constant.index_type import IndexTechniqueType +from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole +from models.dataset import ( + AppDatasetJoin, + Dataset, + DatasetAutoDisableLog, + DatasetCollectionBinding, + DatasetPermission, + DatasetPermissionEnum, +) +from models.enums import DataSourceType +from services.dataset_service import DatasetCollectionBindingService, DatasetPermissionService, DatasetService +from services.errors.account import NoPermissionError + + +class DatasetPermissionIntegrationFactory: + @staticmethod + def create_account_with_tenant( + db_session_with_containers: Session, + role: TenantAccountRole = TenantAccountRole.OWNER, + ) -> tuple[Account, Tenant]: + account = Account( + email=f"{uuid4()}@example.com", + name=f"user-{uuid4()}", + interface_language="en-US", + status="active", + ) + tenant = Tenant(name=f"tenant-{uuid4()}", status="normal") + db_session_with_containers.add_all([account, tenant]) + db_session_with_containers.flush() + + join = TenantAccountJoin( + tenant_id=tenant.id, + account_id=account.id, + role=role, + current=True, + ) + db_session_with_containers.add(join) + db_session_with_containers.commit() + + account.role = role + account._current_tenant = tenant + return account, tenant + + @staticmethod + def create_account_in_tenant( + db_session_with_containers: Session, + tenant: Tenant, + role: TenantAccountRole = TenantAccountRole.EDITOR, + ) -> Account: + account = Account( + email=f"{uuid4()}@example.com", + name=f"user-{uuid4()}", + interface_language="en-US", + status="active", + ) + db_session_with_containers.add(account) + db_session_with_containers.flush() + + join = TenantAccountJoin( + tenant_id=tenant.id, + account_id=account.id, + role=role, + current=True, + ) + db_session_with_containers.add(join) + db_session_with_containers.commit() + + account.role = role + account._current_tenant = tenant + return account + + @staticmethod + def create_dataset( + db_session_with_containers: Session, + *, + tenant_id: str, + created_by: str, + name: str | None = None, + 
permission: DatasetPermissionEnum = DatasetPermissionEnum.ONLY_ME, + indexing_technique: str | None = IndexTechniqueType.HIGH_QUALITY, + enable_api: bool = True, + ) -> Dataset: + dataset = Dataset( + tenant_id=tenant_id, + name=name or f"dataset-{uuid4()}", + description="desc", + data_source_type=DataSourceType.UPLOAD_FILE, + indexing_technique=indexing_technique, + created_by=created_by, + provider="vendor", + permission=permission, + retrieval_model={"top_k": 2}, + ) + dataset.enable_api = enable_api + db_session_with_containers.add(dataset) + db_session_with_containers.commit() + return dataset + + @staticmethod + def create_dataset_permission( + db_session_with_containers: Session, + *, + dataset_id: str, + tenant_id: str, + account_id: str, + ) -> DatasetPermission: + permission = DatasetPermission( + dataset_id=dataset_id, + tenant_id=tenant_id, + account_id=account_id, + has_permission=True, + ) + db_session_with_containers.add(permission) + db_session_with_containers.commit() + return permission + + @staticmethod + def create_app_dataset_join( + db_session_with_containers: Session, + *, + dataset_id: str, + ) -> AppDatasetJoin: + join = AppDatasetJoin( + app_id=str(uuid4()), + dataset_id=dataset_id, + ) + db_session_with_containers.add(join) + db_session_with_containers.commit() + return join + + @staticmethod + def create_collection_binding( + db_session_with_containers: Session, + *, + provider_name: str, + model_name: str, + collection_type: str = "dataset", + ) -> DatasetCollectionBinding: + binding = DatasetCollectionBinding( + provider_name=provider_name, + model_name=model_name, + collection_name=f"collection_{uuid4().hex}", + type=collection_type, + ) + db_session_with_containers.add(binding) + db_session_with_containers.commit() + return binding + + @staticmethod + def create_auto_disable_log( + db_session_with_containers: Session, + *, + tenant_id: str, + dataset_id: str, + document_id: str, + ) -> DatasetAutoDisableLog: + log = DatasetAutoDisableLog( + tenant_id=tenant_id, + dataset_id=dataset_id, + document_id=document_id, + ) + db_session_with_containers.add(log) + db_session_with_containers.commit() + return log + + +class TestDatasetServicePermissionsAndLifecycle: + def test_delete_dataset_returns_false_when_dataset_is_missing(self, db_session_with_containers: Session): + owner, _tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(db_session_with_containers) + + result = DatasetService.delete_dataset(str(uuid4()), user=owner) + + assert result is False + + def test_delete_dataset_checks_permission_and_deletes_dataset(self, db_session_with_containers: Session): + owner, tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(db_session_with_containers) + dataset = DatasetPermissionIntegrationFactory.create_dataset( + db_session_with_containers, + tenant_id=tenant.id, + created_by=owner.id, + ) + + with patch("services.dataset_service.dataset_was_deleted.send") as send_deleted_signal: + result = DatasetService.delete_dataset(dataset.id, user=owner) + + assert result is True + assert db_session_with_containers.get(Dataset, dataset.id) is None + send_deleted_signal.assert_called_once_with(dataset) + + def test_dataset_use_check_returns_true_when_join_exists(self, db_session_with_containers: Session): + owner, tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(db_session_with_containers) + dataset = DatasetPermissionIntegrationFactory.create_dataset( + db_session_with_containers, + tenant_id=tenant.id, + 
created_by=owner.id, + ) + DatasetPermissionIntegrationFactory.create_app_dataset_join( + db_session_with_containers, + dataset_id=dataset.id, + ) + + assert DatasetService.dataset_use_check(dataset.id) is True + + def test_dataset_use_check_returns_false_when_join_missing(self, db_session_with_containers: Session): + owner, tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(db_session_with_containers) + dataset = DatasetPermissionIntegrationFactory.create_dataset( + db_session_with_containers, + tenant_id=tenant.id, + created_by=owner.id, + ) + + assert DatasetService.dataset_use_check(dataset.id) is False + + def test_check_dataset_permission_rejects_cross_tenant_access(self, db_session_with_containers: Session): + owner, tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(db_session_with_containers) + outsider, _other_tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant( + db_session_with_containers + ) + dataset = DatasetPermissionIntegrationFactory.create_dataset( + db_session_with_containers, + tenant_id=tenant.id, + created_by=owner.id, + ) + + with pytest.raises(NoPermissionError, match="do not have permission"): + DatasetService.check_dataset_permission(dataset, outsider) + + def test_check_dataset_permission_rejects_only_me_dataset_for_non_creator( + self, db_session_with_containers: Session + ): + owner, tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(db_session_with_containers) + member = DatasetPermissionIntegrationFactory.create_account_in_tenant(db_session_with_containers, tenant) + dataset = DatasetPermissionIntegrationFactory.create_dataset( + db_session_with_containers, + tenant_id=tenant.id, + created_by=owner.id, + permission=DatasetPermissionEnum.ONLY_ME, + ) + + with pytest.raises(NoPermissionError, match="do not have permission"): + DatasetService.check_dataset_permission(dataset, member) + + def test_check_dataset_permission_rejects_partial_team_user_without_binding( + self, db_session_with_containers: Session + ): + owner, tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(db_session_with_containers) + member = DatasetPermissionIntegrationFactory.create_account_in_tenant(db_session_with_containers, tenant) + dataset = DatasetPermissionIntegrationFactory.create_dataset( + db_session_with_containers, + tenant_id=tenant.id, + created_by=owner.id, + permission=DatasetPermissionEnum.PARTIAL_TEAM, + ) + + with pytest.raises(NoPermissionError, match="do not have permission"): + DatasetService.check_dataset_permission(dataset, member) + + def test_check_dataset_permission_allows_partial_team_creator(self, db_session_with_containers: Session): + creator, tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant( + db_session_with_containers, + role=TenantAccountRole.EDITOR, + ) + dataset = DatasetPermissionIntegrationFactory.create_dataset( + db_session_with_containers, + tenant_id=tenant.id, + created_by=creator.id, + permission=DatasetPermissionEnum.PARTIAL_TEAM, + ) + + DatasetService.check_dataset_permission(dataset, creator) + + def test_check_dataset_permission_allows_partial_team_member_with_binding( + self, db_session_with_containers: Session + ): + owner, tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(db_session_with_containers) + member = DatasetPermissionIntegrationFactory.create_account_in_tenant(db_session_with_containers, tenant) + dataset = DatasetPermissionIntegrationFactory.create_dataset( + 
db_session_with_containers, + tenant_id=tenant.id, + created_by=owner.id, + permission=DatasetPermissionEnum.PARTIAL_TEAM, + ) + DatasetPermissionIntegrationFactory.create_dataset_permission( + db_session_with_containers, + dataset_id=dataset.id, + tenant_id=tenant.id, + account_id=member.id, + ) + + DatasetService.check_dataset_permission(dataset, member) + + def test_check_dataset_operator_permission_rejects_only_me_for_non_creator( + self, db_session_with_containers: Session + ): + owner, tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(db_session_with_containers) + operator = DatasetPermissionIntegrationFactory.create_account_in_tenant( + db_session_with_containers, + tenant, + role=TenantAccountRole.EDITOR, + ) + dataset = DatasetPermissionIntegrationFactory.create_dataset( + db_session_with_containers, + tenant_id=tenant.id, + created_by=owner.id, + permission=DatasetPermissionEnum.ONLY_ME, + ) + + with pytest.raises(NoPermissionError, match="do not have permission"): + DatasetService.check_dataset_operator_permission(user=operator, dataset=dataset) + + def test_check_dataset_operator_permission_rejects_partial_team_without_binding( + self, db_session_with_containers: Session + ): + owner, tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(db_session_with_containers) + operator = DatasetPermissionIntegrationFactory.create_account_in_tenant( + db_session_with_containers, + tenant, + role=TenantAccountRole.EDITOR, + ) + dataset = DatasetPermissionIntegrationFactory.create_dataset( + db_session_with_containers, + tenant_id=tenant.id, + created_by=owner.id, + permission=DatasetPermissionEnum.PARTIAL_TEAM, + ) + + with pytest.raises(NoPermissionError, match="do not have permission"): + DatasetService.check_dataset_operator_permission(user=operator, dataset=dataset) + + def test_check_dataset_operator_permission_allows_partial_team_with_binding( + self, db_session_with_containers: Session + ): + owner, tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(db_session_with_containers) + operator = DatasetPermissionIntegrationFactory.create_account_in_tenant( + db_session_with_containers, + tenant, + role=TenantAccountRole.EDITOR, + ) + dataset = DatasetPermissionIntegrationFactory.create_dataset( + db_session_with_containers, + tenant_id=tenant.id, + created_by=owner.id, + permission=DatasetPermissionEnum.PARTIAL_TEAM, + ) + DatasetPermissionIntegrationFactory.create_dataset_permission( + db_session_with_containers, + dataset_id=dataset.id, + tenant_id=tenant.id, + account_id=operator.id, + ) + + DatasetService.check_dataset_operator_permission(user=operator, dataset=dataset) + + def test_update_dataset_api_status_raises_not_found_for_missing_dataset(self, flask_app_with_containers): + with flask_app_with_containers.app_context(): + with pytest.raises(NotFound, match="Dataset not found"): + DatasetService.update_dataset_api_status(str(uuid4()), True) + + def test_update_dataset_api_status_requires_current_user_id(self, db_session_with_containers: Session): + owner, tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(db_session_with_containers) + dataset = DatasetPermissionIntegrationFactory.create_dataset( + db_session_with_containers, + tenant_id=tenant.id, + created_by=owner.id, + enable_api=False, + ) + + with patch("services.dataset_service.current_user", SimpleNamespace(id=None)): + with pytest.raises(ValueError, match="Current user or current user id not found"): + 
DatasetService.update_dataset_api_status(dataset.id, True) + + def test_update_dataset_api_status_updates_fields_and_commits(self, db_session_with_containers: Session): + owner, tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(db_session_with_containers) + dataset = DatasetPermissionIntegrationFactory.create_dataset( + db_session_with_containers, + tenant_id=tenant.id, + created_by=owner.id, + enable_api=False, + ) + now = datetime(2026, 4, 14, 18, 0, 0) + + with ( + patch("services.dataset_service.current_user", owner), + patch("services.dataset_service.naive_utc_now", return_value=now), + ): + DatasetService.update_dataset_api_status(dataset.id, True) + + db_session_with_containers.refresh(dataset) + assert dataset.enable_api is True + assert dataset.updated_by == owner.id + assert dataset.updated_at == now + + def test_get_dataset_auto_disable_logs_returns_empty_when_billing_is_disabled( + self, db_session_with_containers: Session + ): + owner, tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(db_session_with_containers) + features = SimpleNamespace( + billing=SimpleNamespace(enabled=False, subscription=SimpleNamespace(plan="professional")) + ) + + with ( + patch("services.dataset_service.current_user", owner), + patch("services.dataset_service.FeatureService.get_features", return_value=features), + ): + result = DatasetService.get_dataset_auto_disable_logs(str(uuid4())) + + assert result == {"document_ids": [], "count": 0} + + def test_get_dataset_auto_disable_logs_returns_recent_document_ids(self, db_session_with_containers: Session): + owner, tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(db_session_with_containers) + dataset = DatasetPermissionIntegrationFactory.create_dataset( + db_session_with_containers, + tenant_id=tenant.id, + created_by=owner.id, + ) + DatasetPermissionIntegrationFactory.create_auto_disable_log( + db_session_with_containers, + tenant_id=tenant.id, + dataset_id=dataset.id, + document_id=str(uuid4()), + ) + DatasetPermissionIntegrationFactory.create_auto_disable_log( + db_session_with_containers, + tenant_id=tenant.id, + dataset_id=dataset.id, + document_id=str(uuid4()), + ) + features = SimpleNamespace( + billing=SimpleNamespace(enabled=True, subscription=SimpleNamespace(plan="professional")) + ) + + with ( + patch("services.dataset_service.current_user", owner), + patch("services.dataset_service.FeatureService.get_features", return_value=features), + ): + result = DatasetService.get_dataset_auto_disable_logs(dataset.id) + + assert result["count"] == 2 + assert len(result["document_ids"]) == 2 + + +class TestDatasetCollectionBindingServiceIntegration: + def test_get_dataset_collection_binding_returns_existing_binding(self, db_session_with_containers: Session): + binding = DatasetPermissionIntegrationFactory.create_collection_binding( + db_session_with_containers, + provider_name="provider", + model_name="model", + ) + + result = DatasetCollectionBindingService.get_dataset_collection_binding("provider", "model") + + assert result.id == binding.id + + def test_get_dataset_collection_binding_creates_binding_when_missing(self, db_session_with_containers: Session): + result = DatasetCollectionBindingService.get_dataset_collection_binding("provider", "missing-model") + + persisted = db_session_with_containers.get(DatasetCollectionBinding, result.id) + assert persisted is not None + assert persisted.provider_name == "provider" + assert persisted.model_name == "missing-model" + assert persisted.type 
== "dataset" + assert persisted.collection_name + + def test_get_dataset_collection_binding_by_id_and_type_raises_when_missing(self, flask_app_with_containers): + with flask_app_with_containers.app_context(): + with pytest.raises(ValueError, match="Dataset collection binding not found"): + DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(str(uuid4())) + + def test_get_dataset_collection_binding_by_id_and_type_returns_binding(self, db_session_with_containers: Session): + binding = DatasetPermissionIntegrationFactory.create_collection_binding( + db_session_with_containers, + provider_name="provider", + model_name="model", + ) + + result = DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(binding.id) + + assert result.id == binding.id + + +class TestDatasetPermissionServiceIntegration: + def test_get_dataset_partial_member_list_returns_scalar_results(self, db_session_with_containers: Session): + owner, tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(db_session_with_containers) + member_a = DatasetPermissionIntegrationFactory.create_account_in_tenant(db_session_with_containers, tenant) + member_b = DatasetPermissionIntegrationFactory.create_account_in_tenant(db_session_with_containers, tenant) + dataset = DatasetPermissionIntegrationFactory.create_dataset( + db_session_with_containers, + tenant_id=tenant.id, + created_by=owner.id, + permission=DatasetPermissionEnum.PARTIAL_TEAM, + ) + DatasetPermissionIntegrationFactory.create_dataset_permission( + db_session_with_containers, + dataset_id=dataset.id, + tenant_id=tenant.id, + account_id=member_a.id, + ) + DatasetPermissionIntegrationFactory.create_dataset_permission( + db_session_with_containers, + dataset_id=dataset.id, + tenant_id=tenant.id, + account_id=member_b.id, + ) + + result = DatasetPermissionService.get_dataset_partial_member_list(dataset.id) + + assert set(result) == {member_a.id, member_b.id} + + def test_update_partial_member_list_replaces_permissions_and_commits(self, db_session_with_containers: Session): + owner, tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(db_session_with_containers) + member_a = DatasetPermissionIntegrationFactory.create_account_in_tenant(db_session_with_containers, tenant) + member_b = DatasetPermissionIntegrationFactory.create_account_in_tenant(db_session_with_containers, tenant) + dataset = DatasetPermissionIntegrationFactory.create_dataset( + db_session_with_containers, + tenant_id=tenant.id, + created_by=owner.id, + permission=DatasetPermissionEnum.PARTIAL_TEAM, + ) + stale_member = DatasetPermissionIntegrationFactory.create_account_in_tenant(db_session_with_containers, tenant) + DatasetPermissionIntegrationFactory.create_dataset_permission( + db_session_with_containers, + dataset_id=dataset.id, + tenant_id=tenant.id, + account_id=stale_member.id, + ) + + DatasetPermissionService.update_partial_member_list( + tenant.id, + dataset.id, + [{"user_id": member_a.id}, {"user_id": member_b.id}], + ) + + permissions = db_session_with_containers.query(DatasetPermission).filter_by(dataset_id=dataset.id).all() + assert {permission.account_id for permission in permissions} == {member_a.id, member_b.id} + + def test_check_permission_requires_dataset_editor(self): + user = SimpleNamespace(is_dataset_editor=False, is_dataset_operator=False) + dataset = SimpleNamespace(id="dataset-1", permission=DatasetPermissionEnum.ALL_TEAM) + + with pytest.raises(NoPermissionError, match="does not have permission"): + 
DatasetPermissionService.check_permission(user, dataset, DatasetPermissionEnum.ALL_TEAM, []) + + def test_check_permission_prevents_dataset_operator_from_changing_permission_mode(self): + user = SimpleNamespace(is_dataset_editor=True, is_dataset_operator=True) + dataset = SimpleNamespace(id="dataset-1", permission=DatasetPermissionEnum.ALL_TEAM) + + with pytest.raises(NoPermissionError, match="cannot change the dataset permissions"): + DatasetPermissionService.check_permission(user, dataset, DatasetPermissionEnum.ONLY_ME, []) + + def test_check_permission_requires_partial_member_list_for_partial_members_mode(self): + user = SimpleNamespace(is_dataset_editor=True, is_dataset_operator=True) + dataset = SimpleNamespace(id="dataset-1", permission=DatasetPermissionEnum.PARTIAL_TEAM) + + with pytest.raises(ValueError, match="Partial member list is required"): + DatasetPermissionService.check_permission(user, dataset, DatasetPermissionEnum.PARTIAL_TEAM, []) + + def test_check_permission_rejects_dataset_operator_member_list_changes(self): + user = SimpleNamespace(is_dataset_editor=True, is_dataset_operator=True) + dataset = SimpleNamespace(id="dataset-1", permission=DatasetPermissionEnum.PARTIAL_TEAM) + + with patch.object(DatasetPermissionService, "get_dataset_partial_member_list", return_value=["user-1"]): + with pytest.raises(ValueError, match="cannot change the dataset permissions"): + DatasetPermissionService.check_permission( + user, + dataset, + DatasetPermissionEnum.PARTIAL_TEAM, + [{"user_id": "user-2"}], + ) + + def test_check_permission_allows_dataset_operator_when_member_list_is_unchanged(self): + user = SimpleNamespace(is_dataset_editor=True, is_dataset_operator=True) + dataset = SimpleNamespace(id="dataset-1", permission=DatasetPermissionEnum.PARTIAL_TEAM) + + with patch.object(DatasetPermissionService, "get_dataset_partial_member_list", return_value=["user-1"]): + DatasetPermissionService.check_permission( + user, + dataset, + DatasetPermissionEnum.PARTIAL_TEAM, + [{"user_id": "user-1"}], + ) + + def test_clear_partial_member_list_deletes_permissions_and_commits(self, db_session_with_containers: Session): + owner, tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(db_session_with_containers) + member = DatasetPermissionIntegrationFactory.create_account_in_tenant(db_session_with_containers, tenant) + dataset = DatasetPermissionIntegrationFactory.create_dataset( + db_session_with_containers, + tenant_id=tenant.id, + created_by=owner.id, + permission=DatasetPermissionEnum.PARTIAL_TEAM, + ) + DatasetPermissionIntegrationFactory.create_dataset_permission( + db_session_with_containers, + dataset_id=dataset.id, + tenant_id=tenant.id, + account_id=member.id, + ) + + DatasetPermissionService.clear_partial_member_list(dataset.id) + + remaining = db_session_with_containers.query(DatasetPermission).filter_by(dataset_id=dataset.id).all() + assert remaining == [] diff --git a/api/tests/test_containers_integration_tests/services/test_schedule_service.py b/api/tests/test_containers_integration_tests/services/test_schedule_service.py new file mode 100644 index 0000000000..87f3306258 --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/test_schedule_service.py @@ -0,0 +1,387 @@ +"""Testcontainers integration tests for schedule service SQL-backed behavior.""" + +from datetime import datetime +from types import SimpleNamespace +from uuid import uuid4 + +import pytest +from sqlalchemy import delete, select +from sqlalchemy.orm import Session + +from 
core.workflow.nodes.trigger_schedule.entities import ScheduleConfig, SchedulePlanUpdate +from core.workflow.nodes.trigger_schedule.exc import ScheduleNotFoundError +from events.event_handlers.sync_workflow_schedule_when_app_published import sync_schedule_from_workflow +from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole +from models.trigger import WorkflowSchedulePlan +from services.errors.account import AccountNotFoundError +from services.trigger.schedule_service import ScheduleService + + +class ScheduleServiceIntegrationFactory: + @staticmethod + def create_account_with_tenant( + db_session_with_containers: Session, + role: TenantAccountRole = TenantAccountRole.OWNER, + ) -> tuple[Account, Tenant]: + account = Account( + email=f"{uuid4()}@example.com", + name=f"user-{uuid4()}", + interface_language="en-US", + status="active", + ) + tenant = Tenant(name=f"tenant-{uuid4()}", status="normal") + db_session_with_containers.add_all([account, tenant]) + db_session_with_containers.flush() + + join = TenantAccountJoin( + tenant_id=tenant.id, + account_id=account.id, + role=role, + current=True, + ) + db_session_with_containers.add(join) + db_session_with_containers.commit() + + account.current_tenant = tenant + return account, tenant + + @staticmethod + def create_schedule_plan( + db_session_with_containers: Session, + *, + tenant_id: str, + app_id: str | None = None, + node_id: str = "start", + cron_expression: str = "30 10 * * *", + timezone: str = "UTC", + next_run_at: datetime | None = None, + ) -> WorkflowSchedulePlan: + schedule = WorkflowSchedulePlan( + tenant_id=tenant_id, + app_id=app_id or str(uuid4()), + node_id=node_id, + cron_expression=cron_expression, + timezone=timezone, + next_run_at=next_run_at, + ) + db_session_with_containers.add(schedule) + db_session_with_containers.commit() + return schedule + + +def _cron_workflow( + *, + node_id: str = "start", + cron_expression: str = "30 10 * * *", + timezone: str = "UTC", +): + return SimpleNamespace( + graph_dict={ + "nodes": [ + { + "id": node_id, + "data": { + "type": "trigger-schedule", + "mode": "cron", + "cron_expression": cron_expression, + "timezone": timezone, + }, + } + ] + } + ) + + +def _no_schedule_workflow(): + return SimpleNamespace( + graph_dict={ + "nodes": [ + { + "id": "node-1", + "data": {"type": "llm"}, + } + ] + } + ) + + +class TestScheduleServiceIntegration: + def test_create_schedule_persists_schedule(self, db_session_with_containers: Session): + account, tenant = ScheduleServiceIntegrationFactory.create_account_with_tenant(db_session_with_containers) + expected_next_run = datetime(2026, 1, 1, 10, 30, 0) + config = ScheduleConfig( + node_id="start", + cron_expression="30 10 * * *", + timezone="UTC", + ) + + with pytest.MonkeyPatch.context() as monkeypatch: + monkeypatch.setattr( + "services.trigger.schedule_service.calculate_next_run_at", + lambda *_args, **_kwargs: expected_next_run, + ) + schedule = ScheduleService.create_schedule( + session=db_session_with_containers, + tenant_id=tenant.id, + app_id=str(uuid4()), + config=config, + ) + + persisted = db_session_with_containers.get(WorkflowSchedulePlan, schedule.id) + assert persisted is not None + assert persisted.tenant_id == tenant.id + assert persisted.node_id == "start" + assert persisted.cron_expression == "30 10 * * *" + assert persisted.timezone == "UTC" + assert persisted.next_run_at == expected_next_run + + def test_update_schedule_updates_fields_and_recomputes_next_run(self, db_session_with_containers: Session): + 
_account, tenant = ScheduleServiceIntegrationFactory.create_account_with_tenant(db_session_with_containers) + schedule = ScheduleServiceIntegrationFactory.create_schedule_plan( + db_session_with_containers, + tenant_id=tenant.id, + cron_expression="30 10 * * *", + timezone="UTC", + ) + expected_next_run = datetime(2026, 1, 2, 12, 0, 0) + + with pytest.MonkeyPatch.context() as monkeypatch: + monkeypatch.setattr( + "services.trigger.schedule_service.calculate_next_run_at", + lambda *_args, **_kwargs: expected_next_run, + ) + updated = ScheduleService.update_schedule( + session=db_session_with_containers, + schedule_id=schedule.id, + updates=SchedulePlanUpdate( + cron_expression="0 12 * * *", + timezone="America/New_York", + ), + ) + + db_session_with_containers.refresh(updated) + assert updated.cron_expression == "0 12 * * *" + assert updated.timezone == "America/New_York" + assert updated.next_run_at == expected_next_run + + def test_update_schedule_updates_only_node_id_without_recomputing_time(self, db_session_with_containers: Session): + _account, tenant = ScheduleServiceIntegrationFactory.create_account_with_tenant(db_session_with_containers) + initial_next_run = datetime(2026, 1, 1, 10, 0, 0) + schedule = ScheduleServiceIntegrationFactory.create_schedule_plan( + db_session_with_containers, + tenant_id=tenant.id, + next_run_at=initial_next_run, + ) + + with pytest.MonkeyPatch.context() as monkeypatch: + calls: list[tuple] = [] + + def _track(*args, **kwargs): + calls.append((args, kwargs)) + return datetime(2026, 1, 9, 10, 0, 0) + + monkeypatch.setattr("services.trigger.schedule_service.calculate_next_run_at", _track) + updated = ScheduleService.update_schedule( + session=db_session_with_containers, + schedule_id=schedule.id, + updates=SchedulePlanUpdate(node_id="node-new"), + ) + + db_session_with_containers.refresh(updated) + assert updated.node_id == "node-new" + assert updated.next_run_at == initial_next_run + assert calls == [] + + def test_update_schedule_not_found_raises(self, db_session_with_containers: Session): + with pytest.raises(ScheduleNotFoundError, match="Schedule not found"): + ScheduleService.update_schedule( + session=db_session_with_containers, + schedule_id=str(uuid4()), + updates=SchedulePlanUpdate(node_id="node-new"), + ) + + def test_delete_schedule_removes_row(self, db_session_with_containers: Session): + _account, tenant = ScheduleServiceIntegrationFactory.create_account_with_tenant(db_session_with_containers) + schedule = ScheduleServiceIntegrationFactory.create_schedule_plan( + db_session_with_containers, + tenant_id=tenant.id, + ) + + ScheduleService.delete_schedule( + session=db_session_with_containers, + schedule_id=schedule.id, + ) + db_session_with_containers.commit() + + assert db_session_with_containers.get(WorkflowSchedulePlan, schedule.id) is None + + def test_delete_schedule_not_found_raises(self, db_session_with_containers: Session): + with pytest.raises(ScheduleNotFoundError, match="Schedule not found"): + ScheduleService.delete_schedule( + session=db_session_with_containers, + schedule_id=str(uuid4()), + ) + + def test_get_tenant_owner_returns_owner_account(self, db_session_with_containers: Session): + owner, tenant = ScheduleServiceIntegrationFactory.create_account_with_tenant( + db_session_with_containers, + role=TenantAccountRole.OWNER, + ) + + result = ScheduleService.get_tenant_owner( + session=db_session_with_containers, + tenant_id=tenant.id, + ) + + assert result.id == owner.id + + def test_get_tenant_owner_falls_back_to_admin(self, 
db_session_with_containers: Session): + admin, tenant = ScheduleServiceIntegrationFactory.create_account_with_tenant( + db_session_with_containers, + role=TenantAccountRole.ADMIN, + ) + + result = ScheduleService.get_tenant_owner( + session=db_session_with_containers, + tenant_id=tenant.id, + ) + + assert result.id == admin.id + + def test_get_tenant_owner_raises_when_account_record_missing(self, db_session_with_containers: Session): + _account, tenant = ScheduleServiceIntegrationFactory.create_account_with_tenant(db_session_with_containers) + db_session_with_containers.execute(delete(TenantAccountJoin)) + missing_account_id = str(uuid4()) + join = TenantAccountJoin( + tenant_id=tenant.id, + account_id=missing_account_id, + role=TenantAccountRole.OWNER, + current=True, + ) + db_session_with_containers.add(join) + db_session_with_containers.commit() + + with pytest.raises(AccountNotFoundError, match=missing_account_id): + ScheduleService.get_tenant_owner(session=db_session_with_containers, tenant_id=tenant.id) + + def test_get_tenant_owner_raises_when_no_owner_or_admin_found(self, db_session_with_containers: Session): + _account, tenant = ScheduleServiceIntegrationFactory.create_account_with_tenant(db_session_with_containers) + db_session_with_containers.execute(delete(TenantAccountJoin)) + db_session_with_containers.commit() + + with pytest.raises(AccountNotFoundError, match=tenant.id): + ScheduleService.get_tenant_owner(session=db_session_with_containers, tenant_id=tenant.id) + + def test_update_next_run_at_updates_persisted_value(self, db_session_with_containers: Session): + _account, tenant = ScheduleServiceIntegrationFactory.create_account_with_tenant(db_session_with_containers) + schedule = ScheduleServiceIntegrationFactory.create_schedule_plan( + db_session_with_containers, + tenant_id=tenant.id, + ) + expected_next_run = datetime(2026, 1, 3, 10, 30, 0) + + with pytest.MonkeyPatch.context() as monkeypatch: + monkeypatch.setattr( + "services.trigger.schedule_service.calculate_next_run_at", + lambda *_args, **_kwargs: expected_next_run, + ) + result = ScheduleService.update_next_run_at( + session=db_session_with_containers, + schedule_id=schedule.id, + ) + + db_session_with_containers.refresh(schedule) + assert result == expected_next_run + assert schedule.next_run_at == expected_next_run + + def test_update_next_run_at_raises_when_schedule_not_found(self, db_session_with_containers: Session): + with pytest.raises(ScheduleNotFoundError, match="Schedule not found"): + ScheduleService.update_next_run_at( + session=db_session_with_containers, + schedule_id=str(uuid4()), + ) + + +class TestSyncScheduleFromWorkflowIntegration: + def test_sync_schedule_create_new(self, db_session_with_containers: Session): + _account, tenant = ScheduleServiceIntegrationFactory.create_account_with_tenant(db_session_with_containers) + app_id = str(uuid4()) + expected_next_run = datetime(2026, 1, 4, 10, 30, 0) + + with pytest.MonkeyPatch.context() as monkeypatch: + monkeypatch.setattr( + "services.trigger.schedule_service.calculate_next_run_at", + lambda *_args, **_kwargs: expected_next_run, + ) + result = sync_schedule_from_workflow( + tenant_id=tenant.id, + app_id=app_id, + workflow=_cron_workflow(), + ) + + assert result is not None + persisted = db_session_with_containers.execute( + select(WorkflowSchedulePlan).where(WorkflowSchedulePlan.app_id == app_id) + ).scalar_one() + assert persisted.node_id == "start" + assert persisted.cron_expression == "30 10 * * *" + assert persisted.timezone == "UTC" + assert 
persisted.next_run_at == expected_next_run + + def test_sync_schedule_update_existing(self, db_session_with_containers: Session): + _account, tenant = ScheduleServiceIntegrationFactory.create_account_with_tenant(db_session_with_containers) + app_id = str(uuid4()) + existing = ScheduleServiceIntegrationFactory.create_schedule_plan( + db_session_with_containers, + tenant_id=tenant.id, + app_id=app_id, + node_id="old-start", + cron_expression="30 10 * * *", + timezone="UTC", + ) + existing_id = existing.id + expected_next_run = datetime(2026, 1, 5, 12, 0, 0) + + with pytest.MonkeyPatch.context() as monkeypatch: + monkeypatch.setattr( + "services.trigger.schedule_service.calculate_next_run_at", + lambda *_args, **_kwargs: expected_next_run, + ) + result = sync_schedule_from_workflow( + tenant_id=tenant.id, + app_id=app_id, + workflow=_cron_workflow( + node_id="start", + cron_expression="0 12 * * *", + timezone="America/New_York", + ), + ) + + assert result is not None + db_session_with_containers.expire_all() + persisted = db_session_with_containers.get(WorkflowSchedulePlan, existing_id) + assert persisted is not None + assert persisted.node_id == "start" + assert persisted.cron_expression == "0 12 * * *" + assert persisted.timezone == "America/New_York" + assert persisted.next_run_at == expected_next_run + + def test_sync_schedule_remove_when_no_config(self, db_session_with_containers: Session): + _account, tenant = ScheduleServiceIntegrationFactory.create_account_with_tenant(db_session_with_containers) + app_id = str(uuid4()) + existing = ScheduleServiceIntegrationFactory.create_schedule_plan( + db_session_with_containers, + tenant_id=tenant.id, + app_id=app_id, + ) + existing_id = existing.id + + result = sync_schedule_from_workflow( + tenant_id=tenant.id, + app_id=app_id, + workflow=_no_schedule_workflow(), + ) + + assert result is None + db_session_with_containers.expire_all() + assert db_session_with_containers.get(WorkflowSchedulePlan, existing_id) is None diff --git a/api/tests/test_containers_integration_tests/tasks/test_add_document_to_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_add_document_to_index_task.py index 4b04c1accb..fcc15aad42 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_add_document_to_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_add_document_to_index_task.py @@ -2,6 +2,7 @@ from unittest.mock import MagicMock, patch import pytest from faker import Faker +from sqlalchemy import select from sqlalchemy.orm import Session from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType @@ -530,22 +531,18 @@ class TestAddDocumentToIndexTask: redis_client.set(indexing_cache_key, "processing", ex=300) # Verify logs exist before processing - existing_logs = ( - db_session_with_containers.query(DatasetAutoDisableLog) - .where(DatasetAutoDisableLog.document_id == document.id) - .all() - ) + existing_logs = db_session_with_containers.scalars( + select(DatasetAutoDisableLog).where(DatasetAutoDisableLog.document_id == document.id) + ).all() assert len(existing_logs) == 2 # Act: Execute the task add_document_to_index_task(document.id) # Assert: Verify auto disable logs were deleted - remaining_logs = ( - db_session_with_containers.query(DatasetAutoDisableLog) - .where(DatasetAutoDisableLog.document_id == document.id) - .all() - ) + remaining_logs = db_session_with_containers.scalars( + select(DatasetAutoDisableLog).where(DatasetAutoDisableLog.document_id == document.id) + 
).all() assert len(remaining_logs) == 0 # Verify index processing occurred normally diff --git a/api/tests/test_containers_integration_tests/tasks/test_batch_clean_document_task.py b/api/tests/test_containers_integration_tests/tasks/test_batch_clean_document_task.py index 6cbbe43137..e29ca7ebab 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_batch_clean_document_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_batch_clean_document_task.py @@ -11,6 +11,7 @@ from unittest.mock import Mock, patch import pytest from faker import Faker +from sqlalchemy import func, select from sqlalchemy.orm import Session from core.rag.index_processor.constant.index_type import IndexStructureType @@ -267,11 +268,13 @@ class TestBatchCleanDocumentTask: db_session_with_containers.commit() # Ensure all changes are committed # Check that segment is deleted - deleted_segment = db_session_with_containers.query(DocumentSegment).filter_by(id=segment_id).first() + deleted_segment = db_session_with_containers.scalar( + select(DocumentSegment).where(DocumentSegment.id == segment_id).limit(1) + ) assert deleted_segment is None # Check that upload file is deleted - deleted_file = db_session_with_containers.query(UploadFile).filter_by(id=file_id).first() + deleted_file = db_session_with_containers.scalar(select(UploadFile).where(UploadFile.id == file_id).limit(1)) assert deleted_file is None def test_batch_clean_document_task_with_image_files( @@ -319,7 +322,9 @@ class TestBatchCleanDocumentTask: db_session_with_containers.commit() # Check that segment is deleted - deleted_segment = db_session_with_containers.query(DocumentSegment).filter_by(id=segment_id).first() + deleted_segment = db_session_with_containers.scalar( + select(DocumentSegment).where(DocumentSegment.id == segment_id).limit(1) + ) assert deleted_segment is None # Verify that the task completed successfully by checking the log output @@ -360,14 +365,14 @@ class TestBatchCleanDocumentTask: db_session_with_containers.commit() # Check that upload file is deleted - deleted_file = db_session_with_containers.query(UploadFile).filter_by(id=file_id).first() + deleted_file = db_session_with_containers.scalar(select(UploadFile).where(UploadFile.id == file_id).limit(1)) assert deleted_file is None # Verify database cleanup db_session_with_containers.commit() # Check that upload file is deleted - deleted_file = db_session_with_containers.query(UploadFile).filter_by(id=file_id).first() + deleted_file = db_session_with_containers.scalar(select(UploadFile).where(UploadFile.id == file_id).limit(1)) assert deleted_file is None def test_batch_clean_document_task_dataset_not_found( @@ -410,7 +415,9 @@ class TestBatchCleanDocumentTask: db_session_with_containers.commit() # Document should still exist since cleanup failed - existing_document = db_session_with_containers.query(Document).filter_by(id=document_id).first() + existing_document = db_session_with_containers.scalar( + select(Document).where(Document.id == document_id).limit(1) + ) assert existing_document is not None def test_batch_clean_document_task_storage_cleanup_failure( @@ -453,11 +460,13 @@ class TestBatchCleanDocumentTask: db_session_with_containers.commit() # Check that segment is deleted from database - deleted_segment = db_session_with_containers.query(DocumentSegment).filter_by(id=segment_id).first() + deleted_segment = db_session_with_containers.scalar( + select(DocumentSegment).where(DocumentSegment.id == segment_id).limit(1) + ) assert deleted_segment is None # Check 
that upload file is deleted from database - deleted_file = db_session_with_containers.query(UploadFile).filter_by(id=file_id).first() + deleted_file = db_session_with_containers.scalar(select(UploadFile).where(UploadFile.id == file_id).limit(1)) assert deleted_file is None def test_batch_clean_document_task_multiple_documents( @@ -510,12 +519,16 @@ class TestBatchCleanDocumentTask: # Check that all segments are deleted for segment_id in segment_ids: - deleted_segment = db_session_with_containers.query(DocumentSegment).filter_by(id=segment_id).first() + deleted_segment = db_session_with_containers.scalar( + select(DocumentSegment).where(DocumentSegment.id == segment_id).limit(1) + ) assert deleted_segment is None # Check that all upload files are deleted for file_id in file_ids: - deleted_file = db_session_with_containers.query(UploadFile).filter_by(id=file_id).first() + deleted_file = db_session_with_containers.scalar( + select(UploadFile).where(UploadFile.id == file_id).limit(1) + ) assert deleted_file is None def test_batch_clean_document_task_different_doc_forms( @@ -564,7 +577,9 @@ class TestBatchCleanDocumentTask: db_session_with_containers.commit() # Check that segment is deleted - deleted_segment = db_session_with_containers.query(DocumentSegment).filter_by(id=segment_id).first() + deleted_segment = db_session_with_containers.scalar( + select(DocumentSegment).where(DocumentSegment.id == segment_id).limit(1) + ) assert deleted_segment is None except Exception as e: @@ -574,7 +589,9 @@ class TestBatchCleanDocumentTask: db_session_with_containers.commit() # Check if the segment still exists (task may have failed before deletion) - existing_segment = db_session_with_containers.query(DocumentSegment).filter_by(id=segment_id).first() + existing_segment = db_session_with_containers.scalar( + select(DocumentSegment).where(DocumentSegment.id == segment_id).limit(1) + ) if existing_segment is not None: # If segment still exists, the task failed before deletion # This is acceptable in test environments with external service issues @@ -645,12 +662,16 @@ class TestBatchCleanDocumentTask: # Check that all segments are deleted for segment_id in segment_ids: - deleted_segment = db_session_with_containers.query(DocumentSegment).filter_by(id=segment_id).first() + deleted_segment = db_session_with_containers.scalar( + select(DocumentSegment).where(DocumentSegment.id == segment_id).limit(1) + ) assert deleted_segment is None # Check that all upload files are deleted for file_id in file_ids: - deleted_file = db_session_with_containers.query(UploadFile).filter_by(id=file_id).first() + deleted_file = db_session_with_containers.scalar( + select(UploadFile).where(UploadFile.id == file_id).limit(1) + ) assert deleted_file is None def test_batch_clean_document_task_integration_with_real_database( @@ -699,8 +720,16 @@ class TestBatchCleanDocumentTask: db_session_with_containers.commit() # Verify initial state - assert db_session_with_containers.query(DocumentSegment).filter_by(document_id=document.id).count() == 3 - assert db_session_with_containers.query(UploadFile).filter_by(id=upload_file.id).first() is not None + assert ( + db_session_with_containers.scalar( + select(func.count()).select_from(DocumentSegment).where(DocumentSegment.document_id == document.id) + ) + == 3 + ) + assert ( + db_session_with_containers.scalar(select(UploadFile).where(UploadFile.id == upload_file.id).limit(1)) + is not None + ) # Store original IDs for verification document_id = document.id @@ -720,13 +749,20 @@ class 
TestBatchCleanDocumentTask: # Check that all segments are deleted for segment_id in segment_ids: - deleted_segment = db_session_with_containers.query(DocumentSegment).filter_by(id=segment_id).first() + deleted_segment = db_session_with_containers.scalar( + select(DocumentSegment).where(DocumentSegment.id == segment_id).limit(1) + ) assert deleted_segment is None # Check that upload file is deleted - deleted_file = db_session_with_containers.query(UploadFile).filter_by(id=file_id).first() + deleted_file = db_session_with_containers.scalar(select(UploadFile).where(UploadFile.id == file_id).limit(1)) assert deleted_file is None # Verify final database state - assert db_session_with_containers.query(DocumentSegment).filter_by(document_id=document_id).count() == 0 - assert db_session_with_containers.query(UploadFile).filter_by(id=file_id).first() is None + assert ( + db_session_with_containers.scalar( + select(func.count()).select_from(DocumentSegment).where(DocumentSegment.document_id == document_id) + ) + == 0 + ) + assert db_session_with_containers.scalar(select(UploadFile).where(UploadFile.id == file_id).limit(1)) is None diff --git a/api/tests/test_containers_integration_tests/tasks/test_batch_create_segment_to_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_batch_create_segment_to_index_task.py index f9ae33b32f..05827112d4 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_batch_create_segment_to_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_batch_create_segment_to_index_task.py @@ -17,6 +17,7 @@ from unittest.mock import MagicMock, patch import pytest from faker import Faker +from sqlalchemy import delete, select from sqlalchemy.orm import Session from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType @@ -37,13 +38,13 @@ class TestBatchCreateSegmentToIndexTask: from extensions.ext_redis import redis_client # Clear all test data - db_session_with_containers.query(DocumentSegment).delete() - db_session_with_containers.query(Document).delete() - db_session_with_containers.query(Dataset).delete() - db_session_with_containers.query(UploadFile).delete() - db_session_with_containers.query(TenantAccountJoin).delete() - db_session_with_containers.query(Tenant).delete() - db_session_with_containers.query(Account).delete() + db_session_with_containers.execute(delete(DocumentSegment)) + db_session_with_containers.execute(delete(Document)) + db_session_with_containers.execute(delete(Dataset)) + db_session_with_containers.execute(delete(UploadFile)) + db_session_with_containers.execute(delete(TenantAccountJoin)) + db_session_with_containers.execute(delete(Tenant)) + db_session_with_containers.execute(delete(Account)) db_session_with_containers.commit() # Clear Redis cache @@ -292,12 +293,9 @@ class TestBatchCreateSegmentToIndexTask: # Verify results # Check that segments were created - segments = ( - db_session_with_containers.query(DocumentSegment) - .filter_by(document_id=document.id) - .order_by(DocumentSegment.position) - .all() - ) + segments = db_session_with_containers.scalars( + select(DocumentSegment).where(DocumentSegment.document_id == document.id).order_by(DocumentSegment.position) + ).all() assert len(segments) == 3 # Verify segment content and metadata @@ -367,11 +365,11 @@ class TestBatchCreateSegmentToIndexTask: # Verify no segments were created (since dataset doesn't exist) - segments = db_session_with_containers.query(DocumentSegment).all() + segments = 
db_session_with_containers.scalars(select(DocumentSegment)).all() assert len(segments) == 0 # Verify no documents were modified - documents = db_session_with_containers.query(Document).all() + documents = db_session_with_containers.scalars(select(Document)).all() assert len(documents) == 0 def test_batch_create_segment_to_index_task_document_not_found( @@ -415,12 +413,14 @@ class TestBatchCreateSegmentToIndexTask: # Verify no segments were created - segments = db_session_with_containers.query(DocumentSegment).all() + segments = db_session_with_containers.scalars(select(DocumentSegment)).all() assert len(segments) == 0 # Verify dataset remains unchanged (no segments were added to the dataset) db_session_with_containers.refresh(dataset) - segments_for_dataset = db_session_with_containers.query(DocumentSegment).filter_by(dataset_id=dataset.id).all() + segments_for_dataset = db_session_with_containers.scalars( + select(DocumentSegment).where(DocumentSegment.dataset_id == dataset.id) + ).all() assert len(segments_for_dataset) == 0 def test_batch_create_segment_to_index_task_document_not_available( @@ -516,7 +516,9 @@ class TestBatchCreateSegmentToIndexTask: assert cache_value == b"error" # Verify no segments were created - segments = db_session_with_containers.query(DocumentSegment).filter_by(document_id=document.id).all() + segments = db_session_with_containers.scalars( + select(DocumentSegment).where(DocumentSegment.document_id == document.id) + ).all() assert len(segments) == 0 def test_batch_create_segment_to_index_task_upload_file_not_found( @@ -560,7 +562,7 @@ class TestBatchCreateSegmentToIndexTask: # Verify no segments were created - segments = db_session_with_containers.query(DocumentSegment).all() + segments = db_session_with_containers.scalars(select(DocumentSegment)).all() assert len(segments) == 0 # Verify document remains unchanged @@ -611,7 +613,7 @@ class TestBatchCreateSegmentToIndexTask: # Verify error handling # Since exception was raised, no segments should be created - segments = db_session_with_containers.query(DocumentSegment).all() + segments = db_session_with_containers.scalars(select(DocumentSegment)).all() assert len(segments) == 0 # Verify document remains unchanged @@ -682,12 +684,9 @@ class TestBatchCreateSegmentToIndexTask: # Verify results # Check that new segments were created with correct positions - all_segments = ( - db_session_with_containers.query(DocumentSegment) - .filter_by(document_id=document.id) - .order_by(DocumentSegment.position) - .all() - ) + all_segments = db_session_with_containers.scalars( + select(DocumentSegment).where(DocumentSegment.document_id == document.id).order_by(DocumentSegment.position) + ).all() assert len(all_segments) == 6 # 3 existing + 3 new # Verify position ordering diff --git a/api/tests/test_containers_integration_tests/tasks/test_clean_dataset_task.py b/api/tests/test_containers_integration_tests/tasks/test_clean_dataset_task.py index 1dd37fbc92..32bc2fc0bd 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_clean_dataset_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_clean_dataset_task.py @@ -16,6 +16,7 @@ from unittest.mock import MagicMock, patch import pytest from faker import Faker +from sqlalchemy import delete, select from sqlalchemy.orm import Session from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType @@ -52,18 +53,18 @@ class TestCleanDatasetTask: from extensions.ext_redis import redis_client # Clear all test data using the 
provided session fixture - db_session_with_containers.query(DatasetMetadataBinding).delete() - db_session_with_containers.query(DatasetMetadata).delete() - db_session_with_containers.query(AppDatasetJoin).delete() - db_session_with_containers.query(DatasetQuery).delete() - db_session_with_containers.query(DatasetProcessRule).delete() - db_session_with_containers.query(DocumentSegment).delete() - db_session_with_containers.query(Document).delete() - db_session_with_containers.query(Dataset).delete() - db_session_with_containers.query(UploadFile).delete() - db_session_with_containers.query(TenantAccountJoin).delete() - db_session_with_containers.query(Tenant).delete() - db_session_with_containers.query(Account).delete() + db_session_with_containers.execute(delete(DatasetMetadataBinding)) + db_session_with_containers.execute(delete(DatasetMetadata)) + db_session_with_containers.execute(delete(AppDatasetJoin)) + db_session_with_containers.execute(delete(DatasetQuery)) + db_session_with_containers.execute(delete(DatasetProcessRule)) + db_session_with_containers.execute(delete(DocumentSegment)) + db_session_with_containers.execute(delete(Document)) + db_session_with_containers.execute(delete(Dataset)) + db_session_with_containers.execute(delete(UploadFile)) + db_session_with_containers.execute(delete(TenantAccountJoin)) + db_session_with_containers.execute(delete(Tenant)) + db_session_with_containers.execute(delete(Account)) db_session_with_containers.commit() # Clear Redis cache @@ -302,28 +303,40 @@ class TestCleanDatasetTask: # Verify results # Check that dataset-related data was cleaned up - documents = db_session_with_containers.query(Document).filter_by(dataset_id=dataset.id).all() + documents = db_session_with_containers.scalars(select(Document).where(Document.dataset_id == dataset.id)).all() assert len(documents) == 0 - segments = db_session_with_containers.query(DocumentSegment).filter_by(dataset_id=dataset.id).all() + segments = db_session_with_containers.scalars( + select(DocumentSegment).where(DocumentSegment.dataset_id == dataset.id) + ).all() assert len(segments) == 0 # Check that metadata and bindings were cleaned up - metadata = db_session_with_containers.query(DatasetMetadata).filter_by(dataset_id=dataset.id).all() + metadata = db_session_with_containers.scalars( + select(DatasetMetadata).where(DatasetMetadata.dataset_id == dataset.id) + ).all() assert len(metadata) == 0 - bindings = db_session_with_containers.query(DatasetMetadataBinding).filter_by(dataset_id=dataset.id).all() + bindings = db_session_with_containers.scalars( + select(DatasetMetadataBinding).where(DatasetMetadataBinding.dataset_id == dataset.id) + ).all() assert len(bindings) == 0 # Check that process rules and queries were cleaned up - process_rules = db_session_with_containers.query(DatasetProcessRule).filter_by(dataset_id=dataset.id).all() + process_rules = db_session_with_containers.scalars( + select(DatasetProcessRule).where(DatasetProcessRule.dataset_id == dataset.id) + ).all() assert len(process_rules) == 0 - queries = db_session_with_containers.query(DatasetQuery).filter_by(dataset_id=dataset.id).all() + queries = db_session_with_containers.scalars( + select(DatasetQuery).where(DatasetQuery.dataset_id == dataset.id) + ).all() assert len(queries) == 0 # Check that app dataset joins were cleaned up - app_joins = db_session_with_containers.query(AppDatasetJoin).filter_by(dataset_id=dataset.id).all() + app_joins = db_session_with_containers.scalars( + select(AppDatasetJoin).where(AppDatasetJoin.dataset_id == 
dataset.id) + ).all() assert len(app_joins) == 0 # Verify index processor was called @@ -414,24 +427,32 @@ class TestCleanDatasetTask: # Verify results # Check that all documents were deleted - remaining_documents = db_session_with_containers.query(Document).filter_by(dataset_id=dataset.id).all() + remaining_documents = db_session_with_containers.scalars( + select(Document).where(Document.dataset_id == dataset.id) + ).all() assert len(remaining_documents) == 0 # Check that all segments were deleted - remaining_segments = db_session_with_containers.query(DocumentSegment).filter_by(dataset_id=dataset.id).all() + remaining_segments = db_session_with_containers.scalars( + select(DocumentSegment).where(DocumentSegment.dataset_id == dataset.id) + ).all() assert len(remaining_segments) == 0 # Check that all upload files were deleted - remaining_files = db_session_with_containers.query(UploadFile).where(UploadFile.id.in_(upload_file_ids)).all() + remaining_files = db_session_with_containers.scalars( + select(UploadFile).where(UploadFile.id.in_(upload_file_ids)) + ).all() assert len(remaining_files) == 0 # Check that metadata and bindings were cleaned up - remaining_metadata = db_session_with_containers.query(DatasetMetadata).filter_by(dataset_id=dataset.id).all() + remaining_metadata = db_session_with_containers.scalars( + select(DatasetMetadata).where(DatasetMetadata.dataset_id == dataset.id) + ).all() assert len(remaining_metadata) == 0 - remaining_bindings = ( - db_session_with_containers.query(DatasetMetadataBinding).filter_by(dataset_id=dataset.id).all() - ) + remaining_bindings = db_session_with_containers.scalars( + select(DatasetMetadataBinding).where(DatasetMetadataBinding.dataset_id == dataset.id) + ).all() assert len(remaining_bindings) == 0 # Verify index processor was called @@ -485,12 +506,14 @@ class TestCleanDatasetTask: # Check that all data was cleaned up - remaining_documents = db_session_with_containers.query(Document).filter_by(dataset_id=dataset.id).all() + remaining_documents = db_session_with_containers.scalars( + select(Document).where(Document.dataset_id == dataset.id) + ).all() assert len(remaining_documents) == 0 - remaining_segments = ( - db_session_with_containers.query(DocumentSegment).filter_by(dataset_id=dataset.id).all() - ) + remaining_segments = db_session_with_containers.scalars( + select(DocumentSegment).where(DocumentSegment.dataset_id == dataset.id) + ).all() assert len(remaining_segments) == 0 # Recreate data for next test case @@ -538,11 +561,15 @@ class TestCleanDatasetTask: # Verify results - even with vector cleanup failure, documents and segments should be deleted # Check that documents were still deleted despite vector cleanup failure - remaining_documents = db_session_with_containers.query(Document).filter_by(dataset_id=dataset.id).all() + remaining_documents = db_session_with_containers.scalars( + select(Document).where(Document.dataset_id == dataset.id) + ).all() assert len(remaining_documents) == 0 # Check that segments were still deleted despite vector cleanup failure - remaining_segments = db_session_with_containers.query(DocumentSegment).filter_by(dataset_id=dataset.id).all() + remaining_segments = db_session_with_containers.scalars( + select(DocumentSegment).where(DocumentSegment.dataset_id == dataset.id) + ).all() assert len(remaining_segments) == 0 # Verify that index processor was called and failed @@ -622,18 +649,22 @@ class TestCleanDatasetTask: # Verify results # Check that all documents were deleted - remaining_documents = 
db_session_with_containers.query(Document).filter_by(dataset_id=dataset.id).all() + remaining_documents = db_session_with_containers.scalars( + select(Document).where(Document.dataset_id == dataset.id) + ).all() assert len(remaining_documents) == 0 # Check that all segments were deleted - remaining_segments = db_session_with_containers.query(DocumentSegment).filter_by(dataset_id=dataset.id).all() + remaining_segments = db_session_with_containers.scalars( + select(DocumentSegment).where(DocumentSegment.dataset_id == dataset.id) + ).all() assert len(remaining_segments) == 0 # Check that all image files were deleted from database image_file_ids = [f.id for f in image_files] - remaining_image_files = ( - db_session_with_containers.query(UploadFile).where(UploadFile.id.in_(image_file_ids)).all() - ) + remaining_image_files = db_session_with_containers.scalars( + select(UploadFile).where(UploadFile.id.in_(image_file_ids)) + ).all() assert len(remaining_image_files) == 0 # Verify that storage.delete was called for each image file @@ -738,24 +769,32 @@ class TestCleanDatasetTask: # Verify results # Check that all documents were deleted - remaining_documents = db_session_with_containers.query(Document).filter_by(dataset_id=dataset.id).all() + remaining_documents = db_session_with_containers.scalars( + select(Document).where(Document.dataset_id == dataset.id) + ).all() assert len(remaining_documents) == 0 # Check that all segments were deleted - remaining_segments = db_session_with_containers.query(DocumentSegment).filter_by(dataset_id=dataset.id).all() + remaining_segments = db_session_with_containers.scalars( + select(DocumentSegment).where(DocumentSegment.dataset_id == dataset.id) + ).all() assert len(remaining_segments) == 0 # Check that all upload files were deleted - remaining_files = db_session_with_containers.query(UploadFile).where(UploadFile.id.in_(upload_file_ids)).all() + remaining_files = db_session_with_containers.scalars( + select(UploadFile).where(UploadFile.id.in_(upload_file_ids)) + ).all() assert len(remaining_files) == 0 # Check that all metadata and bindings were deleted - remaining_metadata = db_session_with_containers.query(DatasetMetadata).filter_by(dataset_id=dataset.id).all() + remaining_metadata = db_session_with_containers.scalars( + select(DatasetMetadata).where(DatasetMetadata.dataset_id == dataset.id) + ).all() assert len(remaining_metadata) == 0 - remaining_bindings = ( - db_session_with_containers.query(DatasetMetadataBinding).filter_by(dataset_id=dataset.id).all() - ) + remaining_bindings = db_session_with_containers.scalars( + select(DatasetMetadataBinding).where(DatasetMetadataBinding.dataset_id == dataset.id) + ).all() assert len(remaining_bindings) == 0 # Verify performance expectations @@ -826,7 +865,9 @@ class TestCleanDatasetTask: # Check that upload file was still deleted from database despite storage failure # Note: When storage operations fail, the upload file may not be deleted # This demonstrates that the cleanup process continues even with storage errors - remaining_files = db_session_with_containers.query(UploadFile).filter_by(id=upload_file.id).all() + remaining_files = db_session_with_containers.scalars( + select(UploadFile).where(UploadFile.id == upload_file.id) + ).all() # The upload file should still be deleted from the database even if storage cleanup fails # However, this depends on the specific implementation of clean_dataset_task if len(remaining_files) > 0: @@ -976,19 +1017,27 @@ class TestCleanDatasetTask: # Verify results # Check that all 
documents were deleted - remaining_documents = db_session_with_containers.query(Document).filter_by(dataset_id=dataset.id).all() + remaining_documents = db_session_with_containers.scalars( + select(Document).where(Document.dataset_id == dataset.id) + ).all() assert len(remaining_documents) == 0 # Check that all segments were deleted - remaining_segments = db_session_with_containers.query(DocumentSegment).filter_by(dataset_id=dataset.id).all() + remaining_segments = db_session_with_containers.scalars( + select(DocumentSegment).where(DocumentSegment.dataset_id == dataset.id) + ).all() assert len(remaining_segments) == 0 # Check that all upload files were deleted - remaining_files = db_session_with_containers.query(UploadFile).filter_by(id=upload_file_id).all() + remaining_files = db_session_with_containers.scalars( + select(UploadFile).where(UploadFile.id == upload_file_id) + ).all() assert len(remaining_files) == 0 # Check that all metadata was deleted - remaining_metadata = db_session_with_containers.query(DatasetMetadata).filter_by(dataset_id=dataset.id).all() + remaining_metadata = db_session_with_containers.scalars( + select(DatasetMetadata).where(DatasetMetadata.dataset_id == dataset.id) + ).all() assert len(remaining_metadata) == 0 # Verify that storage.delete was called diff --git a/api/tests/test_containers_integration_tests/tasks/test_clean_notion_document_task.py b/api/tests/test_containers_integration_tests/tasks/test_clean_notion_document_task.py index 926c839c8b..2fb62e0fc0 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_clean_notion_document_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_clean_notion_document_task.py @@ -11,6 +11,7 @@ from unittest.mock import Mock, patch import pytest from faker import Faker +from sqlalchemy import func, select from core.rag.index_processor.constant.index_type import IndexStructureType from models.dataset import Dataset, Document, DocumentSegment @@ -145,11 +146,16 @@ class TestCleanNotionDocumentTask: db_session_with_containers.commit() # Verify data exists before cleanup - assert db_session_with_containers.query(Document).filter(Document.id.in_(document_ids)).count() == 3 assert ( - db_session_with_containers.query(DocumentSegment) - .filter(DocumentSegment.document_id.in_(document_ids)) - .count() + db_session_with_containers.scalar( + select(func.count()).select_from(Document).where(Document.id.in_(document_ids)) + ) + == 3 + ) + assert ( + db_session_with_containers.scalar( + select(func.count()).select_from(DocumentSegment).where(DocumentSegment.document_id.in_(document_ids)) + ) == 6 ) @@ -158,9 +164,9 @@ class TestCleanNotionDocumentTask: # Verify segments are deleted assert ( - db_session_with_containers.query(DocumentSegment) - .filter(DocumentSegment.document_id.in_(document_ids)) - .count() + db_session_with_containers.scalar( + select(func.count()).select_from(DocumentSegment).where(DocumentSegment.document_id.in_(document_ids)) + ) == 0 ) @@ -323,9 +329,9 @@ class TestCleanNotionDocumentTask: # Verify segments are deleted assert ( - db_session_with_containers.query(DocumentSegment) - .filter(DocumentSegment.document_id == document.id) - .count() + db_session_with_containers.scalar( + select(func.count()).select_from(DocumentSegment).where(DocumentSegment.document_id == document.id) + ) == 0 ) @@ -411,7 +417,9 @@ class TestCleanNotionDocumentTask: # Verify segments are deleted assert ( - db_session_with_containers.query(DocumentSegment).filter(DocumentSegment.document_id == 
document.id).count() + db_session_with_containers.scalar( + select(func.count()).select_from(DocumentSegment).where(DocumentSegment.document_id == document.id) + ) == 0 ) @@ -499,9 +507,16 @@ class TestCleanNotionDocumentTask: db_session_with_containers.commit() # Verify all data exists before cleanup - assert db_session_with_containers.query(Document).filter(Document.dataset_id == dataset.id).count() == 5 assert ( - db_session_with_containers.query(DocumentSegment).filter(DocumentSegment.dataset_id == dataset.id).count() + db_session_with_containers.scalar( + select(func.count()).select_from(Document).where(Document.dataset_id == dataset.id) + ) + == 5 + ) + assert ( + db_session_with_containers.scalar( + select(func.count()).select_from(DocumentSegment).where(DocumentSegment.dataset_id == dataset.id) + ) == 10 ) @@ -514,19 +529,26 @@ class TestCleanNotionDocumentTask: # Verify only specified documents' segments are deleted assert ( - db_session_with_containers.query(DocumentSegment) - .filter(DocumentSegment.document_id.in_(documents_to_clean)) - .count() + db_session_with_containers.scalar( + select(func.count()) + .select_from(DocumentSegment) + .where(DocumentSegment.document_id.in_(documents_to_clean)) + ) == 0 ) # Verify remaining documents and segments are intact remaining_docs = [doc.id for doc in documents[3:]] - assert db_session_with_containers.query(Document).filter(Document.id.in_(remaining_docs)).count() == 2 assert ( - db_session_with_containers.query(DocumentSegment) - .filter(DocumentSegment.document_id.in_(remaining_docs)) - .count() + db_session_with_containers.scalar( + select(func.count()).select_from(Document).where(Document.id.in_(remaining_docs)) + ) + == 2 + ) + assert ( + db_session_with_containers.scalar( + select(func.count()).select_from(DocumentSegment).where(DocumentSegment.document_id.in_(remaining_docs)) + ) == 4 ) @@ -613,7 +635,9 @@ class TestCleanNotionDocumentTask: # Verify all segments exist before cleanup assert ( - db_session_with_containers.query(DocumentSegment).filter(DocumentSegment.document_id == document.id).count() + db_session_with_containers.scalar( + select(func.count()).select_from(DocumentSegment).where(DocumentSegment.document_id == document.id) + ) == 4 ) @@ -622,7 +646,9 @@ class TestCleanNotionDocumentTask: # Verify all segments are deleted regardless of status assert ( - db_session_with_containers.query(DocumentSegment).filter(DocumentSegment.document_id == document.id).count() + db_session_with_containers.scalar( + select(func.count()).select_from(DocumentSegment).where(DocumentSegment.document_id == document.id) + ) == 0 ) @@ -795,11 +821,15 @@ class TestCleanNotionDocumentTask: # Verify all data exists before cleanup assert ( - db_session_with_containers.query(Document).filter(Document.dataset_id == dataset.id).count() + db_session_with_containers.scalar( + select(func.count()).select_from(Document).where(Document.dataset_id == dataset.id) + ) == num_documents ) assert ( - db_session_with_containers.query(DocumentSegment).filter(DocumentSegment.dataset_id == dataset.id).count() + db_session_with_containers.scalar( + select(func.count()).select_from(DocumentSegment).where(DocumentSegment.dataset_id == dataset.id) + ) == num_documents * num_segments_per_doc ) @@ -809,7 +839,9 @@ class TestCleanNotionDocumentTask: # Verify all segments are deleted assert ( - db_session_with_containers.query(DocumentSegment).filter(DocumentSegment.dataset_id == dataset.id).count() + db_session_with_containers.scalar( + 
select(func.count()).select_from(DocumentSegment).where(DocumentSegment.dataset_id == dataset.id) + ) == 0 ) @@ -906,8 +938,8 @@ class TestCleanNotionDocumentTask: # Verify all data exists before cleanup # Note: There may be documents from previous tests, so we check for at least 3 - assert db_session_with_containers.query(Document).count() >= 3 - assert db_session_with_containers.query(DocumentSegment).count() >= 9 + assert db_session_with_containers.scalar(select(func.count()).select_from(Document)) >= 3 + assert db_session_with_containers.scalar(select(func.count()).select_from(DocumentSegment)) >= 9 # Clean up documents from only the first dataset target_dataset = datasets[0] @@ -919,19 +951,26 @@ class TestCleanNotionDocumentTask: # Verify only documents' segments from target dataset are deleted assert ( - db_session_with_containers.query(DocumentSegment) - .filter(DocumentSegment.document_id == target_document.id) - .count() + db_session_with_containers.scalar( + select(func.count()) + .select_from(DocumentSegment) + .where(DocumentSegment.document_id == target_document.id) + ) == 0 ) # Verify documents from other datasets remain intact remaining_docs = [doc.id for doc in all_documents[1:]] - assert db_session_with_containers.query(Document).filter(Document.id.in_(remaining_docs)).count() == 2 assert ( - db_session_with_containers.query(DocumentSegment) - .filter(DocumentSegment.document_id.in_(remaining_docs)) - .count() + db_session_with_containers.scalar( + select(func.count()).select_from(Document).where(Document.id.in_(remaining_docs)) + ) + == 2 + ) + assert ( + db_session_with_containers.scalar( + select(func.count()).select_from(DocumentSegment).where(DocumentSegment.document_id.in_(remaining_docs)) + ) == 6 ) @@ -1028,11 +1067,13 @@ class TestCleanNotionDocumentTask: db_session_with_containers.commit() # Verify all data exists before cleanup - assert db_session_with_containers.query(Document).filter(Document.dataset_id == dataset.id).count() == len( - document_statuses - ) + assert db_session_with_containers.scalar( + select(func.count()).select_from(Document).where(Document.dataset_id == dataset.id) + ) == len(document_statuses) assert ( - db_session_with_containers.query(DocumentSegment).filter(DocumentSegment.dataset_id == dataset.id).count() + db_session_with_containers.scalar( + select(func.count()).select_from(DocumentSegment).where(DocumentSegment.dataset_id == dataset.id) + ) == len(document_statuses) * 2 ) @@ -1042,7 +1083,9 @@ class TestCleanNotionDocumentTask: # Verify all segments are deleted regardless of status assert ( - db_session_with_containers.query(DocumentSegment).filter(DocumentSegment.dataset_id == dataset.id).count() + db_session_with_containers.scalar( + select(func.count()).select_from(DocumentSegment).where(DocumentSegment.dataset_id == dataset.id) + ) == 0 ) @@ -1142,9 +1185,16 @@ class TestCleanNotionDocumentTask: db_session_with_containers.commit() # Verify data exists before cleanup - assert db_session_with_containers.query(Document).filter(Document.id == document.id).count() == 1 assert ( - db_session_with_containers.query(DocumentSegment).filter(DocumentSegment.document_id == document.id).count() + db_session_with_containers.scalar( + select(func.count()).select_from(Document).where(Document.id == document.id) + ) + == 1 + ) + assert ( + db_session_with_containers.scalar( + select(func.count()).select_from(DocumentSegment).where(DocumentSegment.document_id == document.id) + ) == 3 ) @@ -1153,7 +1203,9 @@ class TestCleanNotionDocumentTask: # 
Verify segments are deleted assert ( - db_session_with_containers.query(DocumentSegment).filter(DocumentSegment.document_id == document.id).count() + db_session_with_containers.scalar( + select(func.count()).select_from(DocumentSegment).where(DocumentSegment.document_id == document.id) + ) == 0 ) diff --git a/api/tests/test_containers_integration_tests/tasks/test_dataset_indexing_task.py b/api/tests/test_containers_integration_tests/tasks/test_dataset_indexing_task.py index 13ea94348a..684097851b 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_dataset_indexing_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_dataset_indexing_task.py @@ -6,6 +6,7 @@ from unittest.mock import MagicMock, patch import pytest from faker import Faker +from sqlalchemy import select from core.indexing_runner import DocumentIsPausedError from core.rag.index_processor.constant.index_type import IndexTechniqueType @@ -175,7 +176,7 @@ class TestDatasetIndexingTaskIntegration: def _query_document(self, db_session_with_containers, document_id: str) -> Document | None: """Return the latest persisted document state.""" - return db_session_with_containers.query(Document).where(Document.id == document_id).first() + return db_session_with_containers.scalar(select(Document).where(Document.id == document_id).limit(1)) def _assert_documents_parsing(self, db_session_with_containers, document_ids: Sequence[str]) -> None: """Assert all target documents are persisted in parsing status.""" diff --git a/api/tests/test_containers_integration_tests/tasks/test_disable_segments_from_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_disable_segments_from_index_task.py index 3e9a0c8f7f..6e03bd9351 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_disable_segments_from_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_disable_segments_from_index_task.py @@ -9,6 +9,7 @@ The task is responsible for removing document segments from the search index whe from unittest.mock import MagicMock, patch from faker import Faker +from sqlalchemy import select from sqlalchemy.orm import Session from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType @@ -471,9 +472,9 @@ class TestDisableSegmentsFromIndexTask: db_session_with_containers.refresh(segments[1]) # Check that segments are re-enabled after error - updated_segments = ( - db_session_with_containers.query(DocumentSegment).where(DocumentSegment.id.in_(segment_ids)).all() - ) + updated_segments = db_session_with_containers.scalars( + select(DocumentSegment).where(DocumentSegment.id.in_(segment_ids)) + ).all() for segment in updated_segments: assert segment.enabled is True diff --git a/api/tests/test_containers_integration_tests/tasks/test_document_indexing_sync_task.py b/api/tests/test_containers_integration_tests/tasks/test_document_indexing_sync_task.py index d4021143ef..d94c1fdf24 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_document_indexing_sync_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_document_indexing_sync_task.py @@ -12,6 +12,7 @@ from unittest.mock import Mock, patch from uuid import uuid4 import pytest +from sqlalchemy import delete, func, select, update from core.indexing_runner import DocumentIsPausedError, IndexingRunner from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType @@ -254,8 +255,8 @@ class TestDocumentIndexingSyncTask: """Test that task raises 
error when data_source_info is empty.""" # Arrange context = self._create_notion_sync_context(db_session_with_containers, data_source_info=None) - db_session_with_containers.query(Document).where(Document.id == context["document"].id).update( - {"data_source_info": None} + db_session_with_containers.execute( + update(Document).where(Document.id == context["document"].id).values(data_source_info=None) ) db_session_with_containers.commit() @@ -274,8 +275,8 @@ class TestDocumentIndexingSyncTask: # Assert db_session_with_containers.expire_all() - updated_document = ( - db_session_with_containers.query(Document).where(Document.id == context["document"].id).first() + updated_document = db_session_with_containers.scalar( + select(Document).where(Document.id == context["document"].id).limit(1) ) assert updated_document is not None assert updated_document.indexing_status == IndexingStatus.ERROR @@ -294,13 +295,13 @@ class TestDocumentIndexingSyncTask: # Assert db_session_with_containers.expire_all() - updated_document = ( - db_session_with_containers.query(Document).where(Document.id == context["document"].id).first() + updated_document = db_session_with_containers.scalar( + select(Document).where(Document.id == context["document"].id).limit(1) ) - remaining_segments = ( - db_session_with_containers.query(DocumentSegment) + remaining_segments = db_session_with_containers.scalar( + select(func.count()) + .select_from(DocumentSegment) .where(DocumentSegment.document_id == context["document"].id) - .count() ) assert updated_document is not None assert updated_document.indexing_status == IndexingStatus.COMPLETED @@ -319,13 +320,13 @@ class TestDocumentIndexingSyncTask: # Assert db_session_with_containers.expire_all() - updated_document = ( - db_session_with_containers.query(Document).where(Document.id == context["document"].id).first() + updated_document = db_session_with_containers.scalar( + select(Document).where(Document.id == context["document"].id).limit(1) ) - remaining_segments = ( - db_session_with_containers.query(DocumentSegment) + remaining_segments = db_session_with_containers.scalar( + select(func.count()) + .select_from(DocumentSegment) .where(DocumentSegment.document_id == context["document"].id) - .count() ) assert updated_document is not None @@ -354,7 +355,7 @@ class TestDocumentIndexingSyncTask: context = self._create_notion_sync_context(db_session_with_containers) def _delete_dataset_before_clean() -> str: - db_session_with_containers.query(Dataset).where(Dataset.id == context["dataset"].id).delete() + db_session_with_containers.execute(delete(Dataset).where(Dataset.id == context["dataset"].id)) db_session_with_containers.commit() return "2024-01-02T00:00:00Z" @@ -367,8 +368,8 @@ class TestDocumentIndexingSyncTask: # Assert db_session_with_containers.expire_all() - updated_document = ( - db_session_with_containers.query(Document).where(Document.id == context["document"].id).first() + updated_document = db_session_with_containers.scalar( + select(Document).where(Document.id == context["document"].id).limit(1) ) assert updated_document is not None assert updated_document.indexing_status == IndexingStatus.PARSING @@ -386,13 +387,13 @@ class TestDocumentIndexingSyncTask: # Assert db_session_with_containers.expire_all() - updated_document = ( - db_session_with_containers.query(Document).where(Document.id == context["document"].id).first() + updated_document = db_session_with_containers.scalar( + select(Document).where(Document.id == context["document"].id).limit(1) ) - remaining_segments 
= ( - db_session_with_containers.query(DocumentSegment) + remaining_segments = db_session_with_containers.scalar( + select(func.count()) + .select_from(DocumentSegment) .where(DocumentSegment.document_id == context["document"].id) - .count() ) assert updated_document is not None assert updated_document.indexing_status == IndexingStatus.PARSING @@ -410,8 +411,8 @@ class TestDocumentIndexingSyncTask: # Assert db_session_with_containers.expire_all() - updated_document = ( - db_session_with_containers.query(Document).where(Document.id == context["document"].id).first() + updated_document = db_session_with_containers.scalar( + select(Document).where(Document.id == context["document"].id).limit(1) ) assert updated_document is not None assert updated_document.indexing_status == IndexingStatus.PARSING @@ -428,8 +429,8 @@ class TestDocumentIndexingSyncTask: # Assert db_session_with_containers.expire_all() - updated_document = ( - db_session_with_containers.query(Document).where(Document.id == context["document"].id).first() + updated_document = db_session_with_containers.scalar( + select(Document).where(Document.id == context["document"].id).limit(1) ) assert updated_document is not None assert updated_document.indexing_status == IndexingStatus.ERROR diff --git a/api/tests/test_containers_integration_tests/tasks/test_document_indexing_update_task.py b/api/tests/test_containers_integration_tests/tasks/test_document_indexing_update_task.py index d94abf2b40..a9a8c0f30c 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_document_indexing_update_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_document_indexing_update_task.py @@ -2,6 +2,7 @@ from unittest.mock import MagicMock, patch import pytest from faker import Faker +from sqlalchemy import func, select from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from models import Account, Tenant, TenantAccountJoin, TenantAccountRole @@ -123,13 +124,13 @@ class TestDocumentIndexingUpdateTask: db_session_with_containers.expire_all() # Assert document status updated before reindex - updated = db_session_with_containers.query(Document).where(Document.id == document.id).first() + updated = db_session_with_containers.scalar(select(Document).where(Document.id == document.id).limit(1)) assert updated.indexing_status == IndexingStatus.PARSING assert updated.processing_started_at is not None # Segments should be deleted - remaining = ( - db_session_with_containers.query(DocumentSegment).where(DocumentSegment.document_id == document.id).count() + remaining = db_session_with_containers.scalar( + select(func.count()).select_from(DocumentSegment).where(DocumentSegment.document_id == document.id) ) assert remaining == 0 @@ -167,8 +168,8 @@ class TestDocumentIndexingUpdateTask: mock_external_dependencies["runner_instance"].run.assert_called_once() # Segments should remain (since clean failed before DB delete) - remaining = ( - db_session_with_containers.query(DocumentSegment).where(DocumentSegment.document_id == document.id).count() + remaining = db_session_with_containers.scalar( + select(func.count()).select_from(DocumentSegment).where(DocumentSegment.document_id == document.id) ) assert remaining > 0 diff --git a/api/tests/unit_tests/configs/test_dify_config.py b/api/tests/unit_tests/configs/test_dify_config.py index d6933e2180..bad246a4bb 100644 --- a/api/tests/unit_tests/configs/test_dify_config.py +++ b/api/tests/unit_tests/configs/test_dify_config.py @@ -145,7 +145,7 @@ def 
test_inner_api_config_exist(monkeypatch: pytest.MonkeyPatch): def test_db_extras_options_merging(monkeypatch: pytest.MonkeyPatch): - """Test that DB_EXTRAS options are properly merged with default timezone setting""" + """Test that DB_EXTRAS options are merged with the default timezone startup option.""" # Set environment variables monkeypatch.setenv("DB_TYPE", "postgresql") monkeypatch.setenv("DB_USERNAME", "postgres") @@ -158,15 +158,28 @@ def test_db_extras_options_merging(monkeypatch: pytest.MonkeyPatch): # Create config config = DifyConfig() - # Get engine options - engine_options = config.SQLALCHEMY_ENGINE_OPTIONS - - # Verify options contains both search_path and timezone - options = engine_options["connect_args"]["options"] + options = config.SQLALCHEMY_ENGINE_OPTIONS["connect_args"]["options"] assert "search_path=myschema" in options assert "timezone=UTC" in options +def test_db_session_timezone_override_can_disable_app_level_timezone_injection(monkeypatch: pytest.MonkeyPatch): + monkeypatch.setenv("DB_TYPE", "postgresql") + monkeypatch.setenv("DB_USERNAME", "postgres") + monkeypatch.setenv("DB_PASSWORD", "postgres") + monkeypatch.setenv("DB_HOST", "localhost") + monkeypatch.setenv("DB_PORT", "5432") + monkeypatch.setenv("DB_DATABASE", "dify") + monkeypatch.setenv("DB_EXTRAS", "options=-c search_path=myschema") + monkeypatch.setenv("DB_SESSION_TIMEZONE_OVERRIDE", "") + + config = DifyConfig() + + assert config.SQLALCHEMY_ENGINE_OPTIONS["connect_args"] == { + "options": "-c search_path=myschema", + } + + def test_pubsub_redis_url_default(monkeypatch: pytest.MonkeyPatch): os.environ.clear() @@ -223,6 +236,41 @@ def test_pubsub_redis_url_required_when_default_unavailable(monkeypatch: pytest. _ = DifyConfig().normalized_pubsub_redis_url +def test_dify_config_exposes_redis_key_prefix_default(monkeypatch: pytest.MonkeyPatch): + os.environ.clear() + + monkeypatch.setenv("CONSOLE_API_URL", "https://example.com") + monkeypatch.setenv("CONSOLE_WEB_URL", "https://example.com") + monkeypatch.setenv("DB_TYPE", "postgresql") + monkeypatch.setenv("DB_USERNAME", "postgres") + monkeypatch.setenv("DB_PASSWORD", "postgres") + monkeypatch.setenv("DB_HOST", "localhost") + monkeypatch.setenv("DB_PORT", "5432") + monkeypatch.setenv("DB_DATABASE", "dify") + + config = DifyConfig(_env_file=None) + + assert config.REDIS_KEY_PREFIX == "" + + +def test_dify_config_reads_redis_key_prefix_from_env(monkeypatch: pytest.MonkeyPatch): + os.environ.clear() + + monkeypatch.setenv("CONSOLE_API_URL", "https://example.com") + monkeypatch.setenv("CONSOLE_WEB_URL", "https://example.com") + monkeypatch.setenv("DB_TYPE", "postgresql") + monkeypatch.setenv("DB_USERNAME", "postgres") + monkeypatch.setenv("DB_PASSWORD", "postgres") + monkeypatch.setenv("DB_HOST", "localhost") + monkeypatch.setenv("DB_PORT", "5432") + monkeypatch.setenv("DB_DATABASE", "dify") + monkeypatch.setenv("REDIS_KEY_PREFIX", "enterprise-a") + + config = DifyConfig(_env_file=None) + + assert config.REDIS_KEY_PREFIX == "enterprise-a" + + @pytest.mark.parametrize( ("broker_url", "expected_host", "expected_port", "expected_username", "expected_password", "expected_db"), [ diff --git a/api/tests/unit_tests/controllers/console/app/test_app_response_models.py b/api/tests/unit_tests/controllers/console/app/test_app_response_models.py index 2ac3dc037d..35d07a987d 100644 --- a/api/tests/unit_tests/controllers/console/app/test_app_response_models.py +++ b/api/tests/unit_tests/controllers/console/app/test_app_response_models.py @@ -138,12 +138,15 @@ def 
app_models(app_module): def patch_signed_url(monkeypatch, app_module): """Ensure icon URL generation uses a deterministic helper for tests.""" - def _fake_signed_url(key: str | None) -> str | None: - if not key: + def _fake_build_icon_url(_icon_type, key: str | None) -> str | None: + if key is None: + return None + icon_type = str(_icon_type).lower() + if icon_type != "image": return None return f"signed:{key}" - monkeypatch.setattr(app_module.file_helpers, "get_signed_file_url", _fake_signed_url) + monkeypatch.setattr(app_module, "build_icon_url", _fake_build_icon_url) def _ts(hour: int = 12) -> datetime: diff --git a/api/tests/unit_tests/controllers/console/app/test_conversation_read_timestamp.py b/api/tests/unit_tests/controllers/console/app/test_conversation_read_timestamp.py deleted file mode 100644 index f588ab261d..0000000000 --- a/api/tests/unit_tests/controllers/console/app/test_conversation_read_timestamp.py +++ /dev/null @@ -1,42 +0,0 @@ -from datetime import datetime -from types import SimpleNamespace -from unittest.mock import MagicMock, patch - -from controllers.console.app.conversation import _get_conversation - - -def test_get_conversation_mark_read_keeps_updated_at_unchanged(): - app_model = SimpleNamespace(id="app-id") - account = SimpleNamespace(id="account-id") - conversation = MagicMock() - conversation.id = "conversation-id" - - with ( - patch( - "controllers.console.app.conversation.current_account_with_tenant", - return_value=(account, None), - autospec=True, - ), - patch( - "controllers.console.app.conversation.naive_utc_now", - return_value=datetime(2026, 2, 9, 0, 0, 0), - autospec=True, - ), - patch("controllers.console.app.conversation.db.session", autospec=True) as mock_session, - ): - mock_session.scalar.return_value = conversation - - _get_conversation(app_model, "conversation-id") - - statement = mock_session.execute.call_args[0][0] - compiled = statement.compile() - sql_text = str(compiled).lower() - compact_sql_text = sql_text.replace(" ", "") - params = compiled.params - - assert "updated_at=current_timestamp" not in compact_sql_text - assert "updated_at=conversations.updated_at" in compact_sql_text - assert "read_at=:read_at" in compact_sql_text - assert "read_account_id=:read_account_id" in compact_sql_text - assert params["read_at"] == datetime(2026, 2, 9, 0, 0, 0) - assert params["read_account_id"] == "account-id" diff --git a/api/tests/unit_tests/controllers/console/app/test_conversation_variables_api.py b/api/tests/unit_tests/controllers/console/app/test_conversation_variables_api.py new file mode 100644 index 0000000000..42b3420c31 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/app/test_conversation_variables_api.py @@ -0,0 +1,108 @@ +from __future__ import annotations + +from contextlib import nullcontext +from datetime import UTC, datetime +from types import SimpleNamespace + +import pytest +from graphon.variables.types import SegmentType +from pydantic import ValidationError + +from controllers.console.app import conversation_variables as conversation_variables_module + + +def _unwrap(func): + bound_self = getattr(func, "__self__", None) + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + if bound_self is not None: + return func.__get__(bound_self, bound_self.__class__) + return func + + +def test_get_conversation_variables_returns_paginated_response(app, monkeypatch: pytest.MonkeyPatch) -> None: + api = conversation_variables_module.ConversationVariablesApi() + method = _unwrap(api.get) + + created_at = 
datetime(2026, 1, 1, tzinfo=UTC) + updated_at = datetime(2026, 1, 2, tzinfo=UTC) + row = SimpleNamespace( + created_at=created_at, + updated_at=updated_at, + to_variable=lambda: SimpleNamespace( + model_dump=lambda: { + "id": "var-1", + "name": "my_var", + "value_type": "string", + "value": "value", + "description": "desc", + } + ), + ) + session = SimpleNamespace(scalars=lambda _stmt: SimpleNamespace(all=lambda: [row])) + monkeypatch.setattr(conversation_variables_module, "db", SimpleNamespace(engine=object())) + monkeypatch.setattr( + conversation_variables_module, + "sessionmaker", + lambda *_args, **_kwargs: SimpleNamespace(begin=lambda: nullcontext(session)), + ) + + with app.test_request_context( + "/console/api/apps/app-1/conversation-variables", + method="GET", + query_string={"conversation_id": "conv-1"}, + ): + response = method(app_model=SimpleNamespace(id="app-1")) + + assert response["page"] == 1 + assert response["limit"] == 100 + assert response["total"] == 1 + assert response["has_more"] is False + assert response["data"][0]["id"] == "var-1" + assert response["data"][0]["created_at"] == int(created_at.timestamp()) + assert response["data"][0]["updated_at"] == int(updated_at.timestamp()) + + +def test_get_conversation_variables_normalizes_value_type_and_value(app, monkeypatch: pytest.MonkeyPatch) -> None: + api = conversation_variables_module.ConversationVariablesApi() + method = _unwrap(api.get) + + row = SimpleNamespace( + created_at=None, + updated_at=None, + to_variable=lambda: SimpleNamespace( + model_dump=lambda: { + "id": "var-2", + "name": "my_var_2", + "value_type": SegmentType.INTEGER, + "value": 42, + "description": None, + } + ), + ) + session = SimpleNamespace(scalars=lambda _stmt: SimpleNamespace(all=lambda: [row])) + monkeypatch.setattr(conversation_variables_module, "db", SimpleNamespace(engine=object())) + monkeypatch.setattr( + conversation_variables_module, + "sessionmaker", + lambda *_args, **_kwargs: SimpleNamespace(begin=lambda: nullcontext(session)), + ) + + with app.test_request_context( + "/console/api/apps/app-1/conversation-variables", + method="GET", + query_string={"conversation_id": "conv-1"}, + ): + response = method(app_model=SimpleNamespace(id="app-1")) + + assert response["data"][0]["value_type"] == "number" + assert response["data"][0]["value"] == "42" + + +def test_get_conversation_variables_requires_conversation_id(app) -> None: + api = conversation_variables_module.ConversationVariablesApi() + method = _unwrap(api.get) + + with app.test_request_context("/console/api/apps/app-1/conversation-variables", method="GET"): + with pytest.raises(ValidationError): + method(app_model=SimpleNamespace(id="app-1")) diff --git a/api/tests/unit_tests/controllers/console/app/test_message_api.py b/api/tests/unit_tests/controllers/console/app/test_message_api.py index a76e958829..c984dbef5d 100644 --- a/api/tests/unit_tests/controllers/console/app/test_message_api.py +++ b/api/tests/unit_tests/controllers/console/app/test_message_api.py @@ -1,5 +1,7 @@ from __future__ import annotations +from datetime import UTC, datetime + import pytest from controllers.console.app import message as message_module @@ -120,3 +122,24 @@ def test_suggested_questions_response(app, monkeypatch: pytest.MonkeyPatch) -> N response = message_module.SuggestedQuestionsResponse(data=["What is AI?", "How does ML work?"]) assert len(response.data) == 2 assert response.data[0] == "What is AI?" 
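# A minimal, self-contained sketch (pure stdlib; names are illustrative, not the
# controllers' real API) of the duck-typed session stub the conversation-variable
# tests above rely on: instead of a MagicMock, `sessionmaker` is swapped for a
# SimpleNamespace whose `begin()` hands back a `nullcontext` wrapping the fake
# session, so `with sessionmaker(engine).begin() as session:` yields the stub.
from contextlib import nullcontext
from types import SimpleNamespace


def fake_sessionmaker(*_args, **_kwargs):
    row = SimpleNamespace(name="my_var")
    session = SimpleNamespace(scalars=lambda _stmt: SimpleNamespace(all=lambda: [row]))
    return SimpleNamespace(begin=lambda: nullcontext(session))


with fake_sessionmaker(object()).begin() as session:
    rows = session.scalars("select ...").all()

assert rows[0].name == "my_var"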
+ + +def test_message_detail_response_normalizes_aliases_and_timestamp(app, monkeypatch: pytest.MonkeyPatch) -> None: + """Test MessageDetailResponse normalizes alias fields and datetime timestamps.""" + created_at = datetime(2026, 1, 2, 3, 4, 5, tzinfo=UTC) + response = message_module.MessageDetailResponse.model_validate( + { + "id": "550e8400-e29b-41d4-a716-446655440000", + "conversation_id": "550e8400-e29b-41d4-a716-446655440001", + "inputs": {"foo": "bar"}, + "query": "hello", + "re_sign_file_url_answer": "world", + "from_source": "user", + "status": "normal", + "created_at": created_at, + "message_metadata_dict": {"token_usage": 3}, + } + ) + assert response.answer == "world" + assert response.metadata == {"token_usage": 3} + assert response.created_at == int(created_at.timestamp()) diff --git a/api/tests/unit_tests/controllers/console/app/test_workflow_app_log_api.py b/api/tests/unit_tests/controllers/console/app/test_workflow_app_log_api.py new file mode 100644 index 0000000000..2adb69c704 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/app/test_workflow_app_log_api.py @@ -0,0 +1,85 @@ +from __future__ import annotations + +from datetime import UTC, datetime + +from graphon.enums import WorkflowExecutionStatus + +from controllers.console.app import workflow_app_log as workflow_app_log_module + + +def test_workflow_app_log_query_parses_bool_and_datetime(): + query = workflow_app_log_module.WorkflowAppLogQuery.model_validate( + { + "detail": "true", + "created_at__before": "2026-01-02T03:04:05Z", + "page": "2", + "limit": "10", + } + ) + + assert query.detail is True + assert query.created_at__before == datetime(2026, 1, 2, 3, 4, 5, tzinfo=UTC) + assert query.page == 2 + assert query.limit == 10 + + +def test_workflow_app_log_pagination_response_normalizes_nested_fields(): + created_at = datetime(2026, 1, 2, 3, 4, 5, tzinfo=UTC) + response = workflow_app_log_module.WorkflowAppLogPaginationResponse.model_validate( + { + "page": 1, + "limit": 20, + "total": 1, + "has_more": False, + "data": [ + { + "id": "log-1", + "workflow_run": { + "id": "run-1", + "status": WorkflowExecutionStatus.SUCCEEDED, + "created_at": created_at, + "finished_at": created_at, + }, + "details": {"trigger_metadata": {}}, + "created_by_account": {"id": "acc-1", "name": "acc", "email": "acc@example.com"}, + "created_at": created_at, + } + ], + } + ).model_dump(mode="json") + + assert response["data"][0]["workflow_run"]["status"] == "succeeded" + assert response["data"][0]["workflow_run"]["created_at"] == int(created_at.timestamp()) + assert response["data"][0]["created_at"] == int(created_at.timestamp()) + + +def test_workflow_archived_log_pagination_response_normalizes_nested_fields(): + created_at = datetime(2026, 1, 2, 3, 4, 5, tzinfo=UTC) + response = workflow_app_log_module.WorkflowArchivedLogPaginationResponse.model_validate( + { + "page": 1, + "limit": 20, + "total": 1, + "has_more": False, + "data": [ + { + "id": "archived-1", + "workflow_run": { + "id": "run-1", + "status": WorkflowExecutionStatus.FAILED, + }, + "trigger_metadata": {"type": "trigger-plugin"}, + "created_by_end_user": { + "id": "eu-1", + "type": "anonymous", + "is_anonymous": True, + "session_id": "session-1", + }, + "created_at": created_at, + } + ], + } + ).model_dump(mode="json") + + assert response["data"][0]["workflow_run"]["status"] == "failed" + assert response["data"][0]["created_at"] == int(created_at.timestamp()) diff --git a/api/tests/unit_tests/controllers/console/app/test_workflow_trigger_api.py 
b/api/tests/unit_tests/controllers/console/app/test_workflow_trigger_api.py new file mode 100644 index 0000000000..5363aa154f --- /dev/null +++ b/api/tests/unit_tests/controllers/console/app/test_workflow_trigger_api.py @@ -0,0 +1,54 @@ +from __future__ import annotations + +from datetime import UTC, datetime +from types import SimpleNamespace + +from controllers.console.app import workflow_trigger as workflow_trigger_module + + +def test_parser_models_validate(): + parser = workflow_trigger_module.Parser(node_id="node-1") + enable_parser = workflow_trigger_module.ParserEnable( + trigger_id="550e8400-e29b-41d4-a716-446655440000", enable_trigger=True + ) + + assert parser.node_id == "node-1" + assert enable_parser.enable_trigger is True + + +def test_workflow_trigger_response_serializes_datetime(): + created_at = datetime(2026, 1, 2, 3, 4, 5, tzinfo=UTC) + trigger = SimpleNamespace( + id="trigger-1", + trigger_type="trigger-plugin", + title="Trigger", + node_id="node-1", + provider_name="provider", + icon="https://example.com/icon", + status="enabled", + created_at=created_at, + updated_at=created_at, + ) + + payload = workflow_trigger_module.WorkflowTriggerResponse.model_validate(trigger, from_attributes=True).model_dump( + mode="json" + ) + assert payload["id"] == "trigger-1" + assert payload["created_at"] == "2026-01-02T03:04:05Z" + assert payload["updated_at"] == "2026-01-02T03:04:05Z" + + +def test_webhook_trigger_response_serializes_datetime(): + created_at = datetime(2026, 1, 2, 3, 4, 5, tzinfo=UTC) + webhook = { + "id": "webhook-1", + "webhook_id": "whk-1", + "webhook_url": "https://example.com/hook", + "webhook_debug_url": "https://example.com/hook/debug", + "node_id": "node-1", + "created_at": created_at, + } + + payload = workflow_trigger_module.WebhookTriggerResponse.model_validate(webhook).model_dump(mode="json") + assert payload["webhook_id"] == "whk-1" + assert payload["created_at"] == "2026-01-02T03:04:05Z" diff --git a/api/tests/unit_tests/controllers/console/datasets/test_hit_testing.py b/api/tests/unit_tests/controllers/console/datasets/test_hit_testing.py index 726c0a5cf3..09ed2aaf69 100644 --- a/api/tests/unit_tests/controllers/console/datasets/test_hit_testing.py +++ b/api/tests/unit_tests/controllers/console/datasets/test_hit_testing.py @@ -99,6 +99,57 @@ class TestHitTestingApi: assert "records" in result assert result["records"] == [] + def test_hit_testing_success_with_optional_record_fields(self, app, dataset, dataset_id): + api = HitTestingApi() + method = unwrap(api.post) + + payload = { + "query": "what is vector search", + } + records = [ + { + "segment": None, + "child_chunks": [], + "score": None, + "tsne_position": None, + "files": [], + "summary": None, + } + ] + + with ( + app.test_request_context("/"), + patch.object( + type(console_ns), + "payload", + new_callable=PropertyMock, + return_value=payload, + ), + patch.object( + HitTestingPayload, + "model_validate", + return_value=MagicMock(model_dump=lambda **_: payload), + ), + patch.object( + HitTestingApi, + "get_and_validate_dataset", + return_value=dataset, + ), + patch.object( + HitTestingApi, + "hit_testing_args_check", + ), + patch.object( + HitTestingApi, + "perform_hit_testing", + return_value={"query": payload["query"], "records": records}, + ), + ): + result = method(api, dataset_id) + + assert result["query"] == payload["query"] + assert result["records"] == records + def test_hit_testing_dataset_not_found(self, app, dataset_id): api = HitTestingApi() method = unwrap(api.post) diff --git 
a/api/tests/unit_tests/controllers/console/explore/test_recommended_app.py b/api/tests/unit_tests/controllers/console/explore/test_recommended_app.py index 02c7507ea7..76c863577a 100644 --- a/api/tests/unit_tests/controllers/console/explore/test_recommended_app.py +++ b/api/tests/unit_tests/controllers/console/explore/test_recommended_app.py @@ -1,6 +1,7 @@ from unittest.mock import MagicMock, patch import controllers.console.explore.recommended_app as module +from models.model import AppMode, IconType def unwrap(func): @@ -90,3 +91,48 @@ class TestRecommendedAppApi: service_mock.assert_called_once_with("11111111-1111-1111-1111-111111111111") assert result == result_data + + +class TestRecommendedAppResponseModels: + def test_recommended_app_info_response_computes_icon_url(self): + with patch.object(module, "build_icon_url", return_value="https://signed/icon.png"): + payload = module.RecommendedAppInfoResponse.model_validate( + { + "id": "app-1", + "name": "App", + "mode": AppMode.CHAT, + "icon": "icon.png", + "icon_type": IconType.IMAGE, + "icon_background": "#fff", + } + ).model_dump(mode="json") + + assert payload["icon_url"] == "https://signed/icon.png" + + def test_recommended_app_list_response_serialization(self): + response = module.RecommendedAppListResponse.model_validate( + { + "recommended_apps": [ + { + "app": { + "id": "app-1", + "name": "App", + "mode": "chat", + "icon": "icon.png", + "icon_type": "emoji", + "icon_background": "#fff", + }, + "app_id": "app-1", + "description": "desc", + "category": "cat", + "position": 1, + "is_listed": True, + "can_trial": False, + } + ], + "categories": ["cat"], + } + ).model_dump(mode="json") + + assert response["recommended_apps"][0]["app_id"] == "app-1" + assert response["categories"] == ["cat"] diff --git a/api/tests/unit_tests/controllers/console/workspace/test_workspace.py b/api/tests/unit_tests/controllers/console/workspace/test_workspace.py index b2d13dbbdf..e82a29f045 100644 --- a/api/tests/unit_tests/controllers/console/workspace/test_workspace.py +++ b/api/tests/unit_tests/controllers/console/workspace/test_workspace.py @@ -18,6 +18,7 @@ from controllers.console.workspace.workspace import ( CustomConfigWorkspaceApi, SwitchWorkspaceApi, TenantApi, + TenantInfoResponse, TenantListApi, WebappLogoWorkspaceApi, WorkspaceInfoApi, @@ -435,6 +436,23 @@ class TestTenantApi: assert status == 200 +class TestTenantInfoResponse: + def test_tenant_info_response_normalizes_enum_and_datetime(self): + created_at = naive_utc_now() + payload = TenantInfoResponse.model_validate( + { + "id": "t1", + "status": TenantStatus.NORMAL, + "plan": CloudPlan.TEAM, + "created_at": created_at, + } + ).model_dump(mode="json") + + assert payload["status"] == "normal" + assert payload["plan"] == "team" + assert payload["created_at"] == int(created_at.timestamp()) + + class TestSwitchWorkspaceApi: def test_switch_success(self, app): api = SwitchWorkspaceApi() diff --git a/api/tests/unit_tests/controllers/inner_api/plugin/test_plugin_wraps.py b/api/tests/unit_tests/controllers/inner_api/plugin/test_plugin_wraps.py index 957d7fbd9b..0895fac3a4 100644 --- a/api/tests/unit_tests/controllers/inner_api/plugin/test_plugin_wraps.py +++ b/api/tests/unit_tests/controllers/inner_api/plugin/test_plugin_wraps.py @@ -2,6 +2,7 @@ Unit tests for inner_api plugin decorators """ +from typing import Any from unittest.mock import MagicMock, patch import pytest @@ -232,11 +233,11 @@ class TestGetUserTenant: class PluginTestPayload: """Simple test payload class""" - def __init__(self, 
data: dict): + def __init__(self, data: dict[str, Any]): self.value = data.get("value") @classmethod - def model_validate(cls, data: dict): + def model_validate(cls, data: dict[str, Any]): return cls(data) @@ -277,7 +278,7 @@ class TestPluginData: # Arrange class InvalidPayload: @classmethod - def model_validate(cls, data: dict): + def model_validate(cls, data: dict[str, Any]): raise Exception("Validation failed") @plugin_data(payload_type=InvalidPayload) diff --git a/api/tests/unit_tests/controllers/service_api/app/test_conversation.py b/api/tests/unit_tests/controllers/service_api/app/test_conversation.py index dbd06677d8..97fdf1a011 100644 --- a/api/tests/unit_tests/controllers/service_api/app/test_conversation.py +++ b/api/tests/unit_tests/controllers/service_api/app/test_conversation.py @@ -15,10 +15,12 @@ Focus on: import sys import uuid +from datetime import UTC, datetime from types import SimpleNamespace from unittest.mock import Mock, patch import pytest +from graphon.variables.types import SegmentType from werkzeug.exceptions import BadRequest, NotFound import services @@ -29,6 +31,8 @@ from controllers.service_api.app.conversation import ( ConversationRenameApi, ConversationRenamePayload, ConversationVariableDetailApi, + ConversationVariableInfiniteScrollPaginationResponse, + ConversationVariableResponse, ConversationVariablesApi, ConversationVariablesQuery, ConversationVariableUpdatePayload, @@ -261,6 +265,46 @@ class TestConversationVariableUpdatePayload: assert payload.value == nested +class TestConversationVariableResponseModels: + def test_variable_response_normalizes_value_type_and_timestamps(self): + created_at = datetime(2026, 1, 2, 3, 4, 5, tzinfo=UTC) + response = ConversationVariableResponse.model_validate( + { + "id": "550e8400-e29b-41d4-a716-446655440000", + "name": "foo", + "value_type": SegmentType.INTEGER, + "value": 1, + "description": "desc", + "created_at": created_at, + "updated_at": created_at, + } + ) + assert response.value_type == "number" + assert response.value == "1" + assert response.created_at == int(created_at.timestamp()) + assert response.updated_at == int(created_at.timestamp()) + + def test_variable_pagination_response(self): + response = ConversationVariableInfiniteScrollPaginationResponse.model_validate( + { + "limit": 1, + "has_more": False, + "data": [ + { + "id": "550e8400-e29b-41d4-a716-446655440000", + "name": "foo", + "value_type": "string", + "value": "bar", + } + ], + } + ) + assert response.limit == 1 + assert response.has_more is False + assert len(response.data) == 1 + assert response.data[0].name == "foo" + + class TestConversationAppModeValidation: """Test app mode validation for conversation endpoints.""" @@ -549,6 +593,44 @@ class TestConversationVariablesApiController: with pytest.raises(NotFound): handler(api, app_model=app_model, end_user=end_user, c_id="00000000-0000-0000-0000-000000000001") + def test_success_serializes_response(self, app, monkeypatch: pytest.MonkeyPatch) -> None: + created_at = datetime(2026, 1, 2, 3, 4, 5, tzinfo=UTC) + monkeypatch.setattr( + ConversationService, + "get_conversational_variable", + lambda *_args, **_kwargs: SimpleNamespace( + limit=1, + has_more=False, + data=[ + { + "id": "550e8400-e29b-41d4-a716-446655440000", + "name": "foo", + "value_type": SegmentType.INTEGER, + "value": 1, + "created_at": created_at, + "updated_at": created_at, + } + ], + ), + ) + + api = ConversationVariablesApi() + handler = _unwrap(api.get) + app_model = SimpleNamespace(mode=AppMode.CHAT.value) + end_user = 
SimpleNamespace() + + with app.test_request_context( + "/conversations/1/variables?limit=20", + method="GET", + ): + result = handler(api, app_model=app_model, end_user=end_user, c_id="00000000-0000-0000-0000-000000000001") + + assert result["limit"] == 1 + assert result["has_more"] is False + assert result["data"][0]["value_type"] == "number" + assert result["data"][0]["value"] == "1" + assert result["data"][0]["created_at"] == int(created_at.timestamp()) + class TestConversationVariableDetailApiController: def test_update_type_mismatch(self, app, monkeypatch: pytest.MonkeyPatch) -> None: @@ -602,3 +684,41 @@ class TestConversationVariableDetailApiController: c_id="00000000-0000-0000-0000-000000000001", variable_id="00000000-0000-0000-0000-000000000002", ) + + def test_update_success_serializes_response(self, app, monkeypatch: pytest.MonkeyPatch) -> None: + created_at = datetime(2026, 1, 2, 3, 4, 5, tzinfo=UTC) + monkeypatch.setattr( + ConversationService, + "update_conversation_variable", + lambda *_args, **_kwargs: { + "id": "550e8400-e29b-41d4-a716-446655440000", + "name": "foo", + "value_type": SegmentType.INTEGER, + "value": 1, + "created_at": created_at, + "updated_at": created_at, + }, + ) + + api = ConversationVariableDetailApi() + handler = _unwrap(api.put) + app_model = SimpleNamespace(mode=AppMode.CHAT.value) + end_user = SimpleNamespace() + + with app.test_request_context( + "/conversations/1/variables/2", + method="PUT", + json={"value": 1}, + ): + result = handler( + api, + app_model=app_model, + end_user=end_user, + c_id="00000000-0000-0000-0000-000000000001", + variable_id="00000000-0000-0000-0000-000000000002", + ) + + assert result["id"] == "550e8400-e29b-41d4-a716-446655440000" + assert result["value_type"] == "number" + assert result["value"] == "1" + assert result["created_at"] == int(created_at.timestamp()) diff --git a/api/tests/unit_tests/controllers/service_api/app/test_workflow.py b/api/tests/unit_tests/controllers/service_api/app/test_workflow.py index cfa21bf2dd..74a3c75839 100644 --- a/api/tests/unit_tests/controllers/service_api/app/test_workflow.py +++ b/api/tests/unit_tests/controllers/service_api/app/test_workflow.py @@ -15,6 +15,7 @@ Focus on: import sys import uuid +from datetime import UTC, datetime from types import SimpleNamespace from unittest.mock import Mock, patch @@ -43,6 +44,22 @@ from services.errors.llm import InvokeRateLimitError from services.workflow_app_service import WorkflowAppService +def _make_mock_workflow_run(run_id: str = "run-1"): + run = Mock() + run.id = run_id + run.workflow_id = "wf-1" + run.status = WorkflowExecutionStatus.SUCCEEDED + run.inputs = {"input": "value"} + run.outputs_dict = {"output": "value"} + run.error = None + run.total_steps = 1 + run.total_tokens = 10 + run.created_at = datetime(2026, 1, 1, tzinfo=UTC) + run.finished_at = datetime(2026, 1, 1, tzinfo=UTC) + run.elapsed_time = 0.1 + return run + + class TestWorkflowRunPayload: """Test suite for WorkflowRunPayload Pydantic model.""" @@ -359,7 +376,7 @@ class TestWorkflowRunDetailApi: handler(api, app_model=app_model, workflow_run_id="run") def test_success(self, monkeypatch: pytest.MonkeyPatch) -> None: - run = SimpleNamespace(id="run") + run = _make_mock_workflow_run(run_id="run") repo = SimpleNamespace(get_workflow_run_by_id=lambda **_kwargs: run) workflow_module = sys.modules["controllers.service_api.app.workflow"] monkeypatch.setattr(workflow_module, "db", SimpleNamespace(engine=object())) @@ -373,7 +390,10 @@ class TestWorkflowRunDetailApi: handler = 
_unwrap(api.get) app_model = SimpleNamespace(mode=AppMode.WORKFLOW.value, tenant_id="t1", id="a1") - assert handler(api, app_model=app_model, workflow_run_id="run") == run + result = handler(api, app_model=app_model, workflow_run_id="run") + assert result["id"] == "run" + assert result["workflow_id"] == "wf-1" + assert result["status"] == "succeeded" class TestWorkflowRunApi: @@ -490,7 +510,7 @@ class TestWorkflowAppLogApi: monkeypatch.setattr( WorkflowAppService, "get_paginate_workflow_app_logs", - lambda *_args, **_kwargs: {"items": [], "total": 0}, + lambda *_args, **_kwargs: {"page": 1, "limit": 20, "total": 0, "has_more": False, "data": []}, ) api = WorkflowAppLogApi() @@ -500,7 +520,7 @@ class TestWorkflowAppLogApi: with app.test_request_context("/workflows/logs", method="GET"): response = handler(api, app_model=app_model) - assert response == {"items": [], "total": 0} + assert response == {"page": 1, "limit": 20, "total": 0, "has_more": False, "data": []} # ============================================================================= @@ -527,9 +547,8 @@ def mock_workflow_app(): class TestWorkflowRunDetailApiGet: """Test suite for WorkflowRunDetailApi.get() endpoint. - ``get`` is wrapped by ``@validate_app_token`` (preserves ``__wrapped__``) - and ``@service_api_ns.marshal_with``. We call the unwrapped method - directly; ``marshal_with`` is a no-op when calling directly. + ``get`` is wrapped by ``@validate_app_token`` (preserves ``__wrapped__``), + and we call the unwrapped method directly in tests. """ @patch("controllers.service_api.app.workflow.DifyAPIRepositoryFactory") @@ -542,9 +561,7 @@ class TestWorkflowRunDetailApiGet: mock_workflow_app, ): """Test successful workflow run detail retrieval.""" - mock_run = Mock() - mock_run.id = "run-1" - mock_run.status = "succeeded" + mock_run = _make_mock_workflow_run(run_id="run-1") mock_repo = Mock() mock_repo.get_workflow_run_by_id.return_value = mock_run mock_repo_factory.create_api_workflow_run_repository.return_value = mock_repo @@ -558,7 +575,8 @@ class TestWorkflowRunDetailApiGet: api = WorkflowRunDetailApi() result = _unwrap(api.get)(api, app_model=mock_workflow_app, workflow_run_id=mock_run.id) - assert result == mock_run + assert result["id"] == mock_run.id + assert result["status"] == "succeeded" @patch("controllers.service_api.app.workflow.db") def test_get_workflow_run_wrong_app_mode(self, mock_db, app): @@ -622,8 +640,7 @@ class TestWorkflowTaskStopApiPost: class TestWorkflowAppLogApiGet: """Test suite for WorkflowAppLogApi.get() endpoint. - ``get`` is wrapped by ``@validate_app_token`` and - ``@service_api_ns.marshal_with``. + ``get`` is wrapped by ``@validate_app_token``. 
""" @patch("controllers.service_api.app.workflow.WorkflowAppService") @@ -637,6 +654,10 @@ class TestWorkflowAppLogApiGet: ): """Test successful workflow log retrieval.""" mock_pagination = Mock() + mock_pagination.page = 1 + mock_pagination.limit = 20 + mock_pagination.total = 0 + mock_pagination.has_more = False mock_pagination.data = [] mock_svc_instance = Mock() mock_svc_instance.get_paginate_workflow_app_logs.return_value = mock_pagination @@ -661,4 +682,4 @@ class TestWorkflowAppLogApiGet: api = WorkflowAppLogApi() result = _unwrap(api.get)(api, app_model=mock_workflow_app) - assert result == mock_pagination + assert result == {"page": 1, "limit": 20, "total": 0, "has_more": False, "data": []} diff --git a/api/tests/unit_tests/core/external_data_tool/test_base.py b/api/tests/unit_tests/core/external_data_tool/test_base.py index 216cda83c5..63e887f904 100644 --- a/api/tests/unit_tests/core/external_data_tool/test_base.py +++ b/api/tests/unit_tests/core/external_data_tool/test_base.py @@ -1,3 +1,5 @@ +from typing import Any + import pytest from core.extension.extensible import ExtensionModule @@ -12,10 +14,10 @@ class TestExternalDataTool: # Create a concrete subclass to test init class ConcreteTool(ExternalDataTool): @classmethod - def validate_config(cls, tenant_id: str, config: dict): + def validate_config(cls, tenant_id: str, config: dict[str, Any]): return super().validate_config(tenant_id, config) - def query(self, inputs: dict, query: str | None = None) -> str: + def query(self, inputs: dict[str, Any], query: str | None = None) -> str: return super().query(inputs, query) tool = ConcreteTool(tenant_id="tenant_1", app_id="app_1", variable="var_1", config={"key": "value"}) @@ -28,10 +30,10 @@ class TestExternalDataTool: # Create a concrete subclass to test init class ConcreteTool(ExternalDataTool): @classmethod - def validate_config(cls, tenant_id: str, config: dict): + def validate_config(cls, tenant_id: str, config: dict[str, Any]): pass - def query(self, inputs: dict, query: str | None = None) -> str: + def query(self, inputs: dict[str, Any], query: str | None = None) -> str: return "" tool = ConcreteTool(tenant_id="tenant_1", app_id="app_1", variable="var_1") @@ -43,10 +45,10 @@ class TestExternalDataTool: def test_validate_config_raises_not_implemented(self): class ConcreteTool(ExternalDataTool): @classmethod - def validate_config(cls, tenant_id: str, config: dict): + def validate_config(cls, tenant_id: str, config: dict[str, Any]): return super().validate_config(tenant_id, config) - def query(self, inputs: dict, query: str | None = None) -> str: + def query(self, inputs: dict[str, Any], query: str | None = None) -> str: return "" with pytest.raises(NotImplementedError): @@ -55,10 +57,10 @@ class TestExternalDataTool: def test_query_raises_not_implemented(self): class ConcreteTool(ExternalDataTool): @classmethod - def validate_config(cls, tenant_id: str, config: dict): + def validate_config(cls, tenant_id: str, config: dict[str, Any]): pass - def query(self, inputs: dict, query: str | None = None) -> str: + def query(self, inputs: dict[str, Any], query: str | None = None) -> str: return super().query(inputs, query) tool = ConcreteTool(tenant_id="tenant_1", app_id="app_1", variable="var_1") diff --git a/api/tests/unit_tests/core/moderation/test_content_moderation.py b/api/tests/unit_tests/core/moderation/test_content_moderation.py index 3a97ad5c5d..4c668ee96b 100644 --- a/api/tests/unit_tests/core/moderation/test_content_moderation.py +++ 
b/api/tests/unit_tests/core/moderation/test_content_moderation.py @@ -10,6 +10,7 @@ This module tests all aspects of the content moderation system including: - Configuration validation """ +from typing import Any from unittest.mock import MagicMock, Mock, patch import pytest @@ -28,7 +29,7 @@ class TestKeywordsModeration: """Test suite for custom keyword-based content moderation.""" @pytest.fixture - def keywords_config(self) -> dict: + def keywords_config(self) -> dict[str, Any]: """ Fixture providing a standard keywords moderation configuration. @@ -48,7 +49,7 @@ class TestKeywordsModeration: } @pytest.fixture - def keywords_moderation(self, keywords_config: dict) -> KeywordsModeration: + def keywords_moderation(self, keywords_config: dict[str, Any]) -> KeywordsModeration: """ Fixture providing a KeywordsModeration instance. @@ -64,7 +65,7 @@ class TestKeywordsModeration: config=keywords_config, ) - def test_validate_config_success(self, keywords_config: dict): + def test_validate_config_success(self, keywords_config: dict[str, Any]): """Test successful validation of keywords moderation configuration.""" # Should not raise any exception KeywordsModeration.validate_config("test-tenant", keywords_config) @@ -274,7 +275,7 @@ class TestOpenAIModeration: """Test suite for OpenAI-based content moderation.""" @pytest.fixture - def openai_config(self) -> dict: + def openai_config(self) -> dict[str, Any]: """ Fixture providing OpenAI moderation configuration. @@ -293,7 +294,7 @@ class TestOpenAIModeration: } @pytest.fixture - def openai_moderation(self, openai_config: dict) -> OpenAIModeration: + def openai_moderation(self, openai_config: dict[str, Any]) -> OpenAIModeration: """ Fixture providing an OpenAIModeration instance. @@ -309,7 +310,7 @@ class TestOpenAIModeration: config=openai_config, ) - def test_validate_config_success(self, openai_config: dict): + def test_validate_config_success(self, openai_config: dict[str, Any]): """Test successful validation of OpenAI moderation configuration.""" # Should not raise any exception OpenAIModeration.validate_config("test-tenant", openai_config) diff --git a/api/tests/unit_tests/core/rag/datasource/keyword/jieba/test_jieba.py b/api/tests/unit_tests/core/rag/datasource/keyword/jieba/test_jieba.py index bbdd476914..136ac0c72a 100644 --- a/api/tests/unit_tests/core/rag/datasource/keyword/jieba/test_jieba.py +++ b/api/tests/unit_tests/core/rag/datasource/keyword/jieba/test_jieba.py @@ -1,5 +1,6 @@ import json from types import SimpleNamespace +from typing import Any from unittest.mock import MagicMock import pytest @@ -57,7 +58,7 @@ class _FakeSelect: return self -def _dataset_keyword_table(data_source_type: str = "database", keyword_table_dict: dict | None = None): +def _dataset_keyword_table(data_source_type: str = "database", keyword_table_dict: dict[str, Any] | None = None): return SimpleNamespace( data_source_type=data_source_type, keyword_table_dict=keyword_table_dict, diff --git a/api/tests/unit_tests/core/rag/datasource/test_datasource_retrieval.py b/api/tests/unit_tests/core/rag/datasource/test_datasource_retrieval.py index 8b104597a8..0baf85c314 100644 --- a/api/tests/unit_tests/core/rag/datasource/test_datasource_retrieval.py +++ b/api/tests/unit_tests/core/rag/datasource/test_datasource_retrieval.py @@ -1,4 +1,5 @@ from types import SimpleNamespace +from typing import Any from unittest.mock import MagicMock, Mock, call, patch from uuid import uuid4 @@ -20,7 +21,7 @@ def create_mock_document( doc_id: str, score: float = 0.8, provider: str = 
"dify", - additional_metadata: dict | None = None, + additional_metadata: dict[str, Any] | None = None, ) -> Document: """ Create a mock Document object for testing. diff --git a/api/tests/unit_tests/core/rag/indexing/processor/test_paragraph_index_processor.py b/api/tests/unit_tests/core/rag/indexing/processor/test_paragraph_index_processor.py index d4b987c832..7ae0da03ff 100644 --- a/api/tests/unit_tests/core/rag/indexing/processor/test_paragraph_index_processor.py +++ b/api/tests/unit_tests/core/rag/indexing/processor/test_paragraph_index_processor.py @@ -1,4 +1,5 @@ from types import SimpleNamespace +from typing import Any from unittest.mock import Mock, patch import pytest @@ -71,7 +72,9 @@ class TestParagraphIndexProcessor: with pytest.raises(ValueError, match="No rules found in process rule"): processor.transform([Document(page_content="text", metadata={})], process_rule={"mode": "custom"}) - def test_transform_validates_segmentation(self, processor: ParagraphIndexProcessor, process_rule: dict) -> None: + def test_transform_validates_segmentation( + self, processor: ParagraphIndexProcessor, process_rule: dict[str, Any] + ) -> None: rules_without_segmentation = SimpleNamespace(segmentation=None) with patch( @@ -84,7 +87,9 @@ class TestParagraphIndexProcessor: process_rule={"mode": "custom", "rules": {"enabled": True}}, ) - def test_transform_builds_split_documents(self, processor: ParagraphIndexProcessor, process_rule: dict) -> None: + def test_transform_builds_split_documents( + self, processor: ParagraphIndexProcessor, process_rule: dict[str, Any] + ) -> None: source_document = Document(page_content="source", metadata={"dataset_id": "dataset-1", "document_id": "doc-1"}) splitter = Mock() splitter.split_documents.return_value = [ diff --git a/api/tests/unit_tests/core/rag/indexing/processor/test_qa_index_processor.py b/api/tests/unit_tests/core/rag/indexing/processor/test_qa_index_processor.py index b1b1835a52..bfae9001b7 100644 --- a/api/tests/unit_tests/core/rag/indexing/processor/test_qa_index_processor.py +++ b/api/tests/unit_tests/core/rag/indexing/processor/test_qa_index_processor.py @@ -1,4 +1,5 @@ from types import SimpleNamespace +from typing import Any from unittest.mock import MagicMock, Mock, patch import pandas as pd @@ -77,7 +78,7 @@ class TestQAIndexProcessor: processor.transform([Document(page_content="text", metadata={})], process_rule={"mode": "custom"}) def test_transform_preview_calls_formatter_once( - self, processor: QAIndexProcessor, process_rule: dict, fake_flask_app + self, processor: QAIndexProcessor, process_rule: dict[str, Any], fake_flask_app ) -> None: document = Document(page_content="raw text", metadata={"dataset_id": "dataset-1", "document_id": "doc-1"}) split_node = Document(page_content=".question", metadata={}) @@ -119,7 +120,7 @@ class TestQAIndexProcessor: mock_format.assert_called_once() def test_transform_non_preview_uses_thread_batches( - self, processor: QAIndexProcessor, process_rule: dict, fake_flask_app + self, processor: QAIndexProcessor, process_rule: dict[str, Any], fake_flask_app ) -> None: documents = [ Document(page_content="doc-1", metadata={"document_id": "doc-1", "dataset_id": "dataset-1"}), diff --git a/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py b/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py index 1b17cbc368..508fe80e2b 100644 --- a/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py +++ b/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py @@ -1,6 +1,7 @@ 
import threading from contextlib import contextmanager, nullcontext from types import SimpleNamespace +from typing import Any from unittest.mock import MagicMock, Mock, patch from uuid import uuid4 @@ -45,7 +46,7 @@ def create_mock_document( doc_id: str, score: float = 0.8, provider: str = "dify", - additional_metadata: dict | None = None, + additional_metadata: dict[str, Any] | None = None, ) -> Document: """ Create a mock Document object for testing. @@ -2021,7 +2022,7 @@ def create_mock_document_methods( doc_id: str, score: float = 0.8, provider: str = "dify", - additional_metadata: dict | None = None, + additional_metadata: dict[str, Any] | None = None, ) -> Document: """ Create a mock Document object for testing. @@ -4091,7 +4092,7 @@ def _doc( dataset_id: str = "dataset-1", document_id: str = "document-1", doc_id: str = "node-1", - extra: dict | None = None, + extra: dict[str, Any] | None = None, ) -> Document: metadata = { "score": score, diff --git a/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval_methods.py b/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval_methods.py index 48782515d0..90feb4cf01 100644 --- a/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval_methods.py +++ b/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval_methods.py @@ -1,3 +1,4 @@ +from typing import Any from unittest.mock import MagicMock, Mock, patch from uuid import uuid4 @@ -55,7 +56,7 @@ def create_mock_document( doc_id: str, score: float = 0.8, provider: str = "dify", - additional_metadata: dict | None = None, + additional_metadata: dict[str, Any] | None = None, ) -> Document: """ Create a mock Document object for testing. diff --git a/api/tests/unit_tests/core/tools/test_custom_tool.py b/api/tests/unit_tests/core/tools/test_custom_tool.py index 79b8eaaa87..f35546b025 100644 --- a/api/tests/unit_tests/core/tools/test_custom_tool.py +++ b/api/tests/unit_tests/core/tools/test_custom_tool.py @@ -1,6 +1,7 @@ from __future__ import annotations from types import SimpleNamespace +from typing import Any import httpx import pytest @@ -14,7 +15,7 @@ from core.tools.entities.tool_entities import ToolEntity, ToolIdentity, ToolInvo from core.tools.errors import ToolInvokeError, ToolParameterValidationError, ToolProviderCredentialValidationError -def _build_tool(*, openapi: dict | None = None) -> ApiTool: +def _build_tool(*, openapi: dict[str, Any] | None = None) -> ApiTool: entity = ToolEntity( identity=ToolIdentity( author="author", diff --git a/api/tests/unit_tests/core/tools/test_tool_label_manager.py b/api/tests/unit_tests/core/tools/test_tool_label_manager.py index 8c0e7e9419..e13f430f9b 100644 --- a/api/tests/unit_tests/core/tools/test_tool_label_manager.py +++ b/api/tests/unit_tests/core/tools/test_tool_label_manager.py @@ -2,7 +2,7 @@ from __future__ import annotations from types import SimpleNamespace from typing import Any -from unittest.mock import PropertyMock, patch +from unittest.mock import MagicMock, PropertyMock, patch import pytest @@ -12,11 +12,13 @@ from core.tools.tool_label_manager import ToolLabelManager from core.tools.workflow_as_tool.provider import WorkflowToolProviderController +# Create a mock class for testing abstract/base classes class _ConcreteBuiltinToolProviderController(BuiltinToolProviderController): def _validate_credentials(self, user_id: str, credentials: dict[str, Any]): return None +# Factory function to create a "lightweight" controller for testing def _api_controller(provider_id: str = "api-1") -> 
ApiToolProviderController:
     controller = object.__new__(ApiToolProviderController)
     controller.provider_id = provider_id
@@ -29,6 +31,7 @@ def _workflow_controller(provider_id: str = "wf-1") -> WorkflowToolProviderContr
     return controller


+# Test pure logic: filtering and deduplication
 def test_tool_label_manager_filter_tool_labels():
     filtered = ToolLabelManager.filter_tool_labels(["search", "search", "invalid", "news"])
     assert set(filtered) == {"search", "news"}
@@ -36,22 +39,68 @@
 def test_tool_label_manager_update_tool_labels_db():
+    """
+    Test the database update logic for tool labels.
+    Focus: Verify that labels are filtered, de-duplicated, and safely handled within a database session.
+    """
+    # 1. Set up the expected data from the controller
     controller = _api_controller("api-1")
-    with patch("core.tools.tool_label_manager.db") as mock_db:
+    expected_id = controller.provider_id
+    expected_type = controller.provider_type
+
+    # 2. Patching External Dependencies
+    # - We patch 'db' so the test never touches a real database through the Flask-SQLAlchemy handle.
+    # - We patch 'sessionmaker' to intercept and control the creation of SQLAlchemy sessions.
+    with (
+        patch("core.tools.tool_label_manager.db"),
+        patch("core.tools.tool_label_manager.sessionmaker") as mock_sessionmaker,
+    ):
+        # 3. Constructing the "Mocking Chain"
+        # In the business logic, we use: with sessionmaker(db.engine).begin() as _session:
+        # We need to link our 'mock_session' to the end of this complex context manager chain:
+        # Step A: sessionmaker(db.engine) -> returns an object (mock_sessionmaker.return_value)
+        # Step B: .begin() -> returns a context manager (begin.return_value)
+        # Step C: with ... as _session: -> calls __enter__(), and _session gets the __enter__.return_value
+        mock_session = MagicMock()
+        mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session
+
+        # 4. Trigger the logic under test
+        # Input: ["search", "search", "invalid"]
+        # Logic:
+        # - "invalid" should be filtered out (not in default_tool_label_name_list).
+        # - The duplicate "search" should be merged (unique labels).
         ToolLabelManager.update_tool_labels(controller, ["search", "search", "invalid"])

-        mock_db.session.execute.assert_called_once()
-        # only one valid unique label should be inserted.
-        assert mock_db.session.add.call_count == 1
-        mock_db.session.commit.assert_called_once()
+        # 5. Behavior Assertion: DELETE operation
+        # Verify that the manager first attempts to clear existing labels for this specific tool.
+        # This ensures the update is idempotent.
+        mock_session.execute.assert_called_once()
+
+        # 6. Behavior Assertion: INSERT operation
+        # Verify that only ONE valid label ("search") was added after filtering and deduplication.
+        # If call_count == 1, it proves filter_tool_labels() worked as expected.
+        assert mock_session.add.call_count == 1
+
+        # 7. State Assertion: Data Integrity & Isolation
+        # Inspect the actual object passed to session.add() to ensure it has correct properties.
+        # This confirms that the data isolation (tool_id + tool_type) we refactored is active.
+        call_args = mock_session.add.call_args
+        added_label = call_args[0][0]  # Retrieve the ToolLabelBinding instance
+
+        assert added_label.label_name == "search", "The label name should be 'search' after filtering."
+        assert added_label.tool_id == expected_id, "The tool_id must match the provider_id for correct binding."
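+        # The binding is scoped by the (tool_id, tool_type) pair, so the id
+        # assertion above and the type assertion below together pin down the
+        # isolation contract from step 7.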
+        assert added_label.tool_type == expected_type, "The tool_type must match the provider_type to keep bindings isolated per tool type."


+# Test error handling
 def test_tool_label_manager_update_tool_labels_unsupported():
     with pytest.raises(ValueError, match="Unsupported tool type"):
         ToolLabelManager.update_tool_labels(object(), ["search"])  # type: ignore[arg-type]


+# Test retrieval logic
 def test_tool_label_manager_get_tool_labels_for_builtin_and_db():
+    # Mocking a property (@property) using PropertyMock
     with patch.object(
         _ConcreteBuiltinToolProviderController,
         "tool_labels",
@@ -62,29 +111,67 @@
         assert ToolLabelManager.get_tool_labels(builtin) == ["search", "news"]

     api = _api_controller("api-1")
-    with patch("core.tools.tool_label_manager.db") as mock_db:
-        mock_db.session.scalars.return_value.all.return_value = ["search", "news"]
-        labels = ToolLabelManager.get_tool_labels(api)
-        assert labels == ["search", "news"]
+    with (
+        patch("core.tools.tool_label_manager.db"),
+        patch("core.tools.tool_label_manager.sessionmaker") as mock_sessionmaker,
+    ):
+        mock_session = MagicMock()
+        mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session
+        # Inject mock data into the query result: session.scalars(stmt).all()
+        mock_session.scalars.return_value.all.return_value = ["search", "news"]
+
+        labels = ToolLabelManager.get_tool_labels(api)
+        assert labels == ["search", "news"]
+
+
+def test_tool_label_manager_get_tool_labels_unsupported():
+    """
+    Negative Test: Ensure get_tool_labels raises ValueError for unsupported controller types.
+    This protects the internal API contract against accidental regressions during refactoring.
+    """
+    # Passing a generic object() which doesn't match Api, Workflow, or Builtin controllers.
     with pytest.raises(ValueError, match="Unsupported tool type"):
         ToolLabelManager.get_tool_labels(object())  # type: ignore[arg-type]


+# Test batch processing and mapping
 def test_tool_label_manager_get_tools_labels_batch():
     assert ToolLabelManager.get_tools_labels([]) == {}

     api = _api_controller("api-1")
     wf = _workflow_controller("wf-1")
+
+    # SimpleNamespace is a quick way to simulate SQLAlchemy row objects
     records = [
         SimpleNamespace(tool_id="api-1", label_name="search"),
         SimpleNamespace(tool_id="api-1", label_name="news"),
         SimpleNamespace(tool_id="wf-1", label_name="utilities"),
     ]
-    with patch("core.tools.tool_label_manager.db") as mock_db:
-        mock_db.session.scalars.return_value.all.return_value = records
+
+    with (
+        patch("core.tools.tool_label_manager.db"),
+        patch("core.tools.tool_label_manager.sessionmaker") as mock_sessionmaker,
+    ):
+        mock_session = MagicMock()
+        mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session
+
+        # Simulating the batch query result
+        mock_session.scalars.return_value.all.return_value = records
+
         labels = ToolLabelManager.get_tools_labels([api, wf])
+
+        # Verify the final dictionary mapping
         assert labels == {"api-1": ["search", "news"], "wf-1": ["utilities"]}

+
+def test_tool_label_manager_get_tools_labels_unsupported():
+    """
+    Negative Test: Ensure get_tools_labels raises ValueError if the list contains
+    unsupported controller types, even alongside valid ones.
+ """ + api = _api_controller("api-1") + + # Passing a list with one valid controller and one invalid object() with pytest.raises(ValueError, match="Unsupported tool type"): ToolLabelManager.get_tools_labels([api, object()]) # type: ignore[list-item] diff --git a/api/tests/unit_tests/core/tools/utils/test_message_transformer.py b/api/tests/unit_tests/core/tools/utils/test_message_transformer.py index 6454a5bcd1..5f34135af4 100644 --- a/api/tests/unit_tests/core/tools/utils/test_message_transformer.py +++ b/api/tests/unit_tests/core/tools/utils/test_message_transformer.py @@ -1,3 +1,5 @@ +from typing import Any + import pytest import core.tools.utils.message_transformer as mt @@ -13,7 +15,7 @@ class _FakeToolFile: class _FakeToolFileManager: """Fake ToolFileManager to capture the mimetype passed in.""" - last_call: dict | None = None + last_call: dict[str, Any] | None = None def __init__(self, *args, **kwargs): pass diff --git a/api/tests/unit_tests/core/tools/utils/test_model_invocation_utils.py b/api/tests/unit_tests/core/tools/utils/test_model_invocation_utils.py index 52f262e1cf..84b3f71d5e 100644 --- a/api/tests/unit_tests/core/tools/utils/test_model_invocation_utils.py +++ b/api/tests/unit_tests/core/tools/utils/test_model_invocation_utils.py @@ -10,6 +10,7 @@ from __future__ import annotations from decimal import Decimal from types import SimpleNamespace +from typing import Any from unittest.mock import Mock, patch import pytest @@ -25,7 +26,7 @@ from graphon.model_runtime.errors.invoke import ( from core.tools.utils.model_invocation_utils import InvokeModelError, ModelInvocationUtils -def _mock_model_instance(*, schema: dict | None = None) -> SimpleNamespace: +def _mock_model_instance(*, schema: dict[str, Any] | None = None) -> SimpleNamespace: model_type_instance = Mock() model_type_instance.get_model_schema.return_value = ( SimpleNamespace(model_properties=schema or {}) if schema is not None else None diff --git a/api/tests/unit_tests/core/tools/utils/test_parser.py b/api/tests/unit_tests/core/tools/utils/test_parser.py index 40f91b12a0..032b1377a4 100644 --- a/api/tests/unit_tests/core/tools/utils/test_parser.py +++ b/api/tests/unit_tests/core/tools/utils/test_parser.py @@ -1,4 +1,5 @@ from json.decoder import JSONDecodeError +from typing import Any from unittest.mock import Mock, patch import pytest @@ -259,8 +260,8 @@ def test_parse_openapi_to_tool_bundle_server_env_and_refs(app): }, } - extra_info: dict = {} - warning: dict = {} + extra_info: dict[str, Any] = {} + warning: dict[str, Any] = {} with app.test_request_context(headers={"X-Request-Env": "prod"}): bundles = ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(openapi, extra_info=extra_info, warning=warning) @@ -298,7 +299,7 @@ def test_parse_swagger_to_openapi_branches(): } ) - warning: dict = {"seed": True} + warning: dict[str, Any] = {"seed": True} converted = ApiBasedToolSchemaParser.parse_swagger_to_openapi( { "servers": [{"url": "https://x"}], diff --git a/api/tests/unit_tests/core/trigger/debug/test_debug_event_selectors.py b/api/tests/unit_tests/core/trigger/debug/test_debug_event_selectors.py index 78622b78b6..0c67effe90 100644 --- a/api/tests/unit_tests/core/trigger/debug/test_debug_event_selectors.py +++ b/api/tests/unit_tests/core/trigger/debug/test_debug_event_selectors.py @@ -8,6 +8,7 @@ and select_trigger_debug_events orchestrator. 
from __future__ import annotations from datetime import datetime +from typing import Any from unittest.mock import MagicMock, patch import pytest @@ -30,7 +31,7 @@ from core.trigger.debug.events import PluginTriggerDebugEvent, WebhookDebugEvent from tests.unit_tests.core.trigger.conftest import VALID_PROVIDER_ID -def _make_poller_args(node_config: dict | None = None) -> dict: +def _make_poller_args(node_config: dict[str, Any] | None = None) -> dict[str, Any]: return { "tenant_id": "t1", "user_id": "u1", diff --git a/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_file_conversion.py b/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_file_conversion.py index cccd3fb676..8056217479 100644 --- a/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_file_conversion.py +++ b/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_file_conversion.py @@ -6,6 +6,7 @@ to FileVariable objects, fixing the "Invalid variable type: ObjectVariable" erro when passing files to downstream LLM nodes. """ +from typing import Any from unittest.mock import Mock, patch from graphon.entities import GraphInitParams @@ -97,7 +98,7 @@ def create_test_file_dict( } -def build_webhook_variable_pool(inputs: dict) -> VariablePool: +def build_webhook_variable_pool(inputs: dict[str, Any]) -> VariablePool: return build_test_variable_pool( variables=default_system_variables(), node_id="webhook-node-1", @@ -105,7 +106,7 @@ def build_webhook_variable_pool(inputs: dict) -> VariablePool: ) -def expected_factory_mapping(file_dict: dict) -> dict: +def expected_factory_mapping(file_dict: dict[str, Any]) -> dict[str, Any]: return {**file_dict, "upload_file_id": file_dict["related_id"]} diff --git a/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py b/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py index 34c66a4f9f..c19e28bbd5 100644 --- a/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py @@ -1,3 +1,4 @@ +from typing import Any from unittest.mock import patch import pytest @@ -62,7 +63,7 @@ def create_webhook_node(webhook_data: WebhookData, variable_pool: VariablePool) return node -def build_webhook_variable_pool(inputs: dict) -> VariablePool: +def build_webhook_variable_pool(inputs: dict[str, Any]) -> VariablePool: return build_test_variable_pool( variables=default_system_variables(), node_id="1", diff --git a/api/tests/unit_tests/enterprise/telemetry/test_enterprise_trace.py b/api/tests/unit_tests/enterprise/telemetry/test_enterprise_trace.py index bb1f78b80c..1ce9581aa1 100644 --- a/api/tests/unit_tests/enterprise/telemetry/test_enterprise_trace.py +++ b/api/tests/unit_tests/enterprise/telemetry/test_enterprise_trace.py @@ -4,6 +4,7 @@ from __future__ import annotations import json from datetime import UTC, datetime +from typing import Any from unittest.mock import MagicMock, patch import pytest @@ -57,7 +58,7 @@ _T1 = datetime(2024, 1, 10, 12, 0, 5, tzinfo=UTC) def make_workflow_info(**overrides) -> WorkflowTraceInfo: - defaults: dict = { + defaults: dict[str, Any] = { "workflow_id": "wf-001", "tenant_id": "tenant-abc", "workflow_run_id": "run-001", @@ -86,7 +87,7 @@ def make_workflow_info(**overrides) -> WorkflowTraceInfo: def make_node_info(**overrides) -> WorkflowNodeTraceInfo: - defaults: dict = { + defaults: dict[str, Any] = { "workflow_id": "wf-001", "workflow_run_id": "run-001", "tenant_id": "tenant-abc", @@ -115,7 +116,7 @@ def 
make_node_info(**overrides) -> WorkflowNodeTraceInfo: def make_draft_node_info(**overrides) -> DraftNodeExecutionTrace: - defaults: dict = { + defaults: dict[str, Any] = { "workflow_id": "wf-001", "workflow_run_id": "run-draft-001", "tenant_id": "tenant-abc", @@ -136,7 +137,7 @@ def make_draft_node_info(**overrides) -> DraftNodeExecutionTrace: def make_message_info(**overrides) -> MessageTraceInfo: - defaults: dict = { + defaults: dict[str, Any] = { "message_id": "msg-001", "conversation_model": "gpt-4", "message_tokens": 40, @@ -161,7 +162,7 @@ def make_message_info(**overrides) -> MessageTraceInfo: def make_tool_info(**overrides) -> ToolTraceInfo: - defaults: dict = { + defaults: dict[str, Any] = { "message_id": "msg-001", "tool_name": "web_search", "tool_inputs": {"query": "test"}, @@ -176,7 +177,7 @@ def make_tool_info(**overrides) -> ToolTraceInfo: def make_moderation_info(**overrides) -> ModerationTraceInfo: - defaults: dict = { + defaults: dict[str, Any] = { "message_id": "msg-001", "flagged": False, "action": "pass", @@ -189,7 +190,7 @@ def make_moderation_info(**overrides) -> ModerationTraceInfo: def make_suggested_question_info(**overrides) -> SuggestedQuestionTraceInfo: - defaults: dict = { + defaults: dict[str, Any] = { "message_id": "msg-001", "total_tokens": 30, "suggested_question": ["Question A?", "Question B?"], @@ -206,7 +207,7 @@ def make_suggested_question_info(**overrides) -> SuggestedQuestionTraceInfo: def make_dataset_retrieval_info(**overrides) -> DatasetRetrievalTraceInfo: - defaults: dict = { + defaults: dict[str, Any] = { "message_id": "msg-001", "documents": [ { @@ -236,7 +237,7 @@ def make_dataset_retrieval_info(**overrides) -> DatasetRetrievalTraceInfo: def make_generate_name_info(**overrides) -> GenerateNameTraceInfo: - defaults: dict = { + defaults: dict[str, Any] = { "message_id": "msg-001", "tenant_id": "tenant-abc", "conversation_id": "conv-001", @@ -251,7 +252,7 @@ def make_generate_name_info(**overrides) -> GenerateNameTraceInfo: def make_prompt_generation_info(**overrides) -> PromptGenerationTraceInfo: - defaults: dict = { + defaults: dict[str, Any] = { "tenant_id": "tenant-abc", "user_id": "user-001", "app_id": "app-001", diff --git a/api/tests/unit_tests/extensions/test_celery_ssl.py b/api/tests/unit_tests/extensions/test_celery_ssl.py index 81687ce5f8..366e45d86d 100644 --- a/api/tests/unit_tests/extensions/test_celery_ssl.py +++ b/api/tests/unit_tests/extensions/test_celery_ssl.py @@ -7,6 +7,47 @@ from unittest.mock import MagicMock, patch class TestCelerySSLConfiguration: """Test suite for Celery SSL configuration.""" + def test_get_celery_broker_transport_options_includes_global_keyprefix_for_redis(self): + mock_config = MagicMock() + mock_config.CELERY_USE_SENTINEL = False + mock_config.REDIS_KEY_PREFIX = "enterprise-a" + + with patch("extensions.ext_celery.dify_config", mock_config): + from extensions.ext_celery import get_celery_broker_transport_options + + result = get_celery_broker_transport_options() + + assert result["global_keyprefix"] == "enterprise-a:" + + def test_get_celery_broker_transport_options_omits_global_keyprefix_when_prefix_empty(self): + mock_config = MagicMock() + mock_config.CELERY_USE_SENTINEL = False + mock_config.REDIS_KEY_PREFIX = " " + + with patch("extensions.ext_celery.dify_config", mock_config): + from extensions.ext_celery import get_celery_broker_transport_options + + result = get_celery_broker_transport_options() + + assert "global_keyprefix" not in result + + def 
test_get_celery_broker_transport_options_keeps_sentinel_and_adds_global_keyprefix(self): + mock_config = MagicMock() + mock_config.CELERY_USE_SENTINEL = True + mock_config.CELERY_SENTINEL_MASTER_NAME = "mymaster" + mock_config.CELERY_SENTINEL_SOCKET_TIMEOUT = 0.1 + mock_config.CELERY_SENTINEL_PASSWORD = "secret" + mock_config.REDIS_KEY_PREFIX = "enterprise-a" + + with patch("extensions.ext_celery.dify_config", mock_config): + from extensions.ext_celery import get_celery_broker_transport_options + + result = get_celery_broker_transport_options() + + assert result["master_name"] == "mymaster" + assert result["sentinel_kwargs"]["password"] == "secret" + assert result["global_keyprefix"] == "enterprise-a:" + def test_get_celery_ssl_options_when_ssl_disabled(self): """Test SSL options when BROKER_USE_SSL is False.""" from configs import DifyConfig @@ -151,3 +192,49 @@ class TestCelerySSLConfiguration: # Check that SSL is also applied to Redis backend assert "redis_backend_use_ssl" in celery_app.conf assert celery_app.conf["redis_backend_use_ssl"] is not None + + def test_celery_init_applies_global_keyprefix_to_broker_and_backend_transport(self): + mock_config = MagicMock() + mock_config.BROKER_USE_SSL = False + mock_config.REDIS_KEY_PREFIX = "enterprise-a" + mock_config.HUMAN_INPUT_TIMEOUT_TASK_INTERVAL = 1 + mock_config.CELERY_BROKER_URL = "redis://localhost:6379/0" + mock_config.CELERY_BACKEND = "redis" + mock_config.CELERY_RESULT_BACKEND = "redis://localhost:6379/0" + mock_config.CELERY_USE_SENTINEL = False + mock_config.LOG_FORMAT = "%(message)s" + mock_config.LOG_TZ = "UTC" + mock_config.LOG_FILE = None + mock_config.CELERY_TASK_ANNOTATIONS = {} + + mock_config.CELERY_BEAT_SCHEDULER_TIME = 1 + mock_config.ENABLE_CLEAN_EMBEDDING_CACHE_TASK = False + mock_config.ENABLE_CLEAN_UNUSED_DATASETS_TASK = False + mock_config.ENABLE_CREATE_TIDB_SERVERLESS_TASK = False + mock_config.ENABLE_UPDATE_TIDB_SERVERLESS_STATUS_TASK = False + mock_config.ENABLE_CLEAN_MESSAGES = False + mock_config.ENABLE_MAIL_CLEAN_DOCUMENT_NOTIFY_TASK = False + mock_config.ENABLE_DATASETS_QUEUE_MONITOR = False + mock_config.ENABLE_HUMAN_INPUT_TIMEOUT_TASK = False + mock_config.ENABLE_CHECK_UPGRADABLE_PLUGIN_TASK = False + mock_config.MARKETPLACE_ENABLED = False + mock_config.WORKFLOW_LOG_CLEANUP_ENABLED = False + mock_config.ENABLE_WORKFLOW_RUN_CLEANUP_TASK = False + mock_config.ENABLE_WORKFLOW_SCHEDULE_POLLER_TASK = False + mock_config.WORKFLOW_SCHEDULE_POLLER_INTERVAL = 1 + mock_config.ENABLE_TRIGGER_PROVIDER_REFRESH_TASK = False + mock_config.TRIGGER_PROVIDER_REFRESH_INTERVAL = 15 + mock_config.ENABLE_API_TOKEN_LAST_USED_UPDATE_TASK = False + mock_config.API_TOKEN_LAST_USED_UPDATE_INTERVAL = 30 + mock_config.ENTERPRISE_ENABLED = False + mock_config.ENTERPRISE_TELEMETRY_ENABLED = False + + with patch("extensions.ext_celery.dify_config", mock_config): + from dify_app import DifyApp + from extensions.ext_celery import init_app + + app = DifyApp(__name__) + celery_app = init_app(app) + + assert celery_app.conf["broker_transport_options"]["global_keyprefix"] == "enterprise-a:" + assert celery_app.conf["result_backend_transport_options"]["global_keyprefix"] == "enterprise-a:" diff --git a/api/tests/unit_tests/extensions/test_pubsub_channel.py b/api/tests/unit_tests/extensions/test_pubsub_channel.py index a5b41a7266..926c406ad4 100644 --- a/api/tests/unit_tests/extensions/test_pubsub_channel.py +++ b/api/tests/unit_tests/extensions/test_pubsub_channel.py @@ -6,6 +6,7 @@ from libs.broadcast_channel.redis.sharded_channel import 
ShardedRedisBroadcastCh def test_get_pubsub_broadcast_channel_defaults_to_pubsub(monkeypatch): monkeypatch.setattr(dify_config, "PUBSUB_REDIS_CHANNEL_TYPE", "pubsub") + monkeypatch.setattr(ext_redis, "_pubsub_redis_client", object()) channel = ext_redis.get_pubsub_broadcast_channel() @@ -14,6 +15,7 @@ def test_get_pubsub_broadcast_channel_defaults_to_pubsub(monkeypatch): def test_get_pubsub_broadcast_channel_sharded(monkeypatch): monkeypatch.setattr(dify_config, "PUBSUB_REDIS_CHANNEL_TYPE", "sharded") + monkeypatch.setattr(ext_redis, "_pubsub_redis_client", object()) channel = ext_redis.get_pubsub_broadcast_channel() diff --git a/api/tests/unit_tests/extensions/test_redis.py b/api/tests/unit_tests/extensions/test_redis.py index 5e9be4ab9b..21248439bf 100644 --- a/api/tests/unit_tests/extensions/test_redis.py +++ b/api/tests/unit_tests/extensions/test_redis.py @@ -1,12 +1,15 @@ -from unittest.mock import patch +from unittest.mock import MagicMock, patch from redis import RedisError from redis.retry import Retry from extensions.ext_redis import ( + RedisClientWrapper, _get_base_redis_params, _get_cluster_connection_health_params, _get_connection_health_params, + _normalize_redis_key_prefix, + _serialize_redis_name, redis_fallback, ) @@ -123,3 +126,99 @@ class TestRedisFallback: assert test_func.__name__ == "test_func" assert test_func.__doc__ == "Test function docstring" + + +class TestRedisKeyPrefixHelpers: + def test_normalize_redis_key_prefix_trims_whitespace(self): + assert _normalize_redis_key_prefix(" enterprise-a ") == "enterprise-a" + + def test_normalize_redis_key_prefix_treats_whitespace_only_as_empty(self): + assert _normalize_redis_key_prefix(" ") == "" + + def test_serialize_redis_name_returns_original_when_prefix_empty(self): + assert _serialize_redis_name("model_lb_index:test", "") == "model_lb_index:test" + + def test_serialize_redis_name_adds_single_colon_separator(self): + assert _serialize_redis_name("model_lb_index:test", "enterprise-a") == "enterprise-a:model_lb_index:test" + + +class TestRedisClientWrapperKeyPrefix: + def test_wrapper_get_prefixes_string_keys(self): + mock_client = MagicMock() + wrapper = RedisClientWrapper() + wrapper.initialize(mock_client) + + with patch("extensions.ext_redis.dify_config") as mock_config: + mock_config.REDIS_KEY_PREFIX = "enterprise-a" + + wrapper.get("oauth_state:abc") + + mock_client.get.assert_called_once_with("enterprise-a:oauth_state:abc") + + def test_wrapper_delete_prefixes_multiple_keys(self): + mock_client = MagicMock() + wrapper = RedisClientWrapper() + wrapper.initialize(mock_client) + + with patch("extensions.ext_redis.dify_config") as mock_config: + mock_config.REDIS_KEY_PREFIX = "enterprise-a" + + wrapper.delete("key:a", "key:b") + + mock_client.delete.assert_called_once_with("enterprise-a:key:a", "enterprise-a:key:b") + + def test_wrapper_lock_prefixes_lock_name(self): + mock_client = MagicMock() + wrapper = RedisClientWrapper() + wrapper.initialize(mock_client) + + with patch("extensions.ext_redis.dify_config") as mock_config: + mock_config.REDIS_KEY_PREFIX = "enterprise-a" + + wrapper.lock("resource-lock", timeout=10) + + mock_client.lock.assert_called_once() + args, kwargs = mock_client.lock.call_args + assert args == ("enterprise-a:resource-lock",) + assert kwargs["timeout"] == 10 + + def test_wrapper_hash_operations_prefix_key_name(self): + mock_client = MagicMock() + wrapper = RedisClientWrapper() + wrapper.initialize(mock_client) + + with patch("extensions.ext_redis.dify_config") as mock_config: + 
mock_config.REDIS_KEY_PREFIX = "enterprise-a" + + wrapper.hset("hash:key", "field", "value") + wrapper.hgetall("hash:key") + + mock_client.hset.assert_called_once_with("enterprise-a:hash:key", "field", "value") + mock_client.hgetall.assert_called_once_with("enterprise-a:hash:key") + + def test_wrapper_zadd_prefixes_sorted_set_name(self): + mock_client = MagicMock() + wrapper = RedisClientWrapper() + wrapper.initialize(mock_client) + + with patch("extensions.ext_redis.dify_config") as mock_config: + mock_config.REDIS_KEY_PREFIX = "enterprise-a" + + wrapper.zadd("zset:key", {"member": 1}) + + mock_client.zadd.assert_called_once() + args, kwargs = mock_client.zadd.call_args + assert args == ("enterprise-a:zset:key", {"member": 1}) + assert kwargs["nx"] is False + + def test_wrapper_preserves_keys_when_prefix_is_empty(self): + mock_client = MagicMock() + wrapper = RedisClientWrapper() + wrapper.initialize(mock_client) + + with patch("extensions.ext_redis.dify_config") as mock_config: + mock_config.REDIS_KEY_PREFIX = " " + + wrapper.get("plain:key") + + mock_client.get.assert_called_once_with("plain:key") diff --git a/api/tests/unit_tests/libs/broadcast_channel/redis/test_channel_unit_tests.py b/api/tests/unit_tests/libs/broadcast_channel/redis/test_channel_unit_tests.py index 460374b6f6..8bef01c1ed 100644 --- a/api/tests/unit_tests/libs/broadcast_channel/redis/test_channel_unit_tests.py +++ b/api/tests/unit_tests/libs/broadcast_channel/redis/test_channel_unit_tests.py @@ -139,6 +139,28 @@ class TestTopic: mock_redis_client.publish.assert_called_once_with("test-topic", payload) + def test_publish_prefixes_regular_topic(self, mock_redis_client: MagicMock): + with patch("extensions.redis_names.dify_config") as mock_config: + mock_config.REDIS_KEY_PREFIX = "enterprise-a" + topic = Topic(mock_redis_client, "test-topic") + + topic.publish(b"test message") + + mock_redis_client.publish.assert_called_once_with("enterprise-a:test-topic", b"test message") + + def test_subscribe_prefixes_regular_topic(self, mock_redis_client: MagicMock): + with patch("extensions.redis_names.dify_config") as mock_config: + mock_config.REDIS_KEY_PREFIX = "enterprise-a" + topic = Topic(mock_redis_client, "test-topic") + + subscription = topic.subscribe() + try: + subscription._start_if_needed() + finally: + subscription.close() + + mock_redis_client.pubsub.return_value.subscribe.assert_called_once_with("enterprise-a:test-topic") + class TestShardedTopic: """Test cases for the ShardedTopic class.""" @@ -176,6 +198,15 @@ class TestShardedTopic: mock_redis_client.spublish.assert_called_once_with("test-sharded-topic", payload) + def test_publish_prefixes_sharded_topic(self, mock_redis_client: MagicMock): + with patch("extensions.redis_names.dify_config") as mock_config: + mock_config.REDIS_KEY_PREFIX = "enterprise-a" + sharded_topic = ShardedTopic(mock_redis_client, "test-sharded-topic") + + sharded_topic.publish(b"test sharded message") + + mock_redis_client.spublish.assert_called_once_with("enterprise-a:test-sharded-topic", b"test sharded message") + def test_subscribe_returns_sharded_subscription(self, sharded_topic: ShardedTopic, mock_redis_client: MagicMock): """Test that subscribe() returns a _RedisShardedSubscription instance.""" subscription = sharded_topic.subscribe() @@ -185,6 +216,19 @@ class TestShardedTopic: assert subscription._pubsub is mock_redis_client.pubsub.return_value assert subscription._topic == "test-sharded-topic" + def test_subscribe_prefixes_sharded_topic(self, mock_redis_client: MagicMock): + with 
patch("extensions.redis_names.dify_config") as mock_config: + mock_config.REDIS_KEY_PREFIX = "enterprise-a" + sharded_topic = ShardedTopic(mock_redis_client, "test-sharded-topic") + + subscription = sharded_topic.subscribe() + try: + subscription._start_if_needed() + finally: + subscription.close() + + mock_redis_client.pubsub.return_value.ssubscribe.assert_called_once_with("enterprise-a:test-sharded-topic") + @dataclasses.dataclass(frozen=True) class SubscriptionTestCase: diff --git a/api/tests/unit_tests/libs/broadcast_channel/redis/test_streams_channel_unit_tests.py b/api/tests/unit_tests/libs/broadcast_channel/redis/test_streams_channel_unit_tests.py index 0886b70ee5..c6f57c7e59 100644 --- a/api/tests/unit_tests/libs/broadcast_channel/redis/test_streams_channel_unit_tests.py +++ b/api/tests/unit_tests/libs/broadcast_channel/redis/test_streams_channel_unit_tests.py @@ -1,7 +1,8 @@ import threading import time from dataclasses import dataclass -from typing import cast +from typing import Any, cast +from unittest.mock import patch import pytest @@ -29,7 +30,7 @@ class FakeStreamsRedis: self._dollar_snapshots: dict[str, int] = {} # Publisher API - def xadd(self, key: str, fields: dict, *, maxlen: int | None = None) -> str: + def xadd(self, key: str, fields: dict[str, Any], *, maxlen: int | None = None) -> str: """Append entry to stream; accept optional maxlen for API compatibility. The test double ignores maxlen trimming semantics; only records the entry. @@ -44,7 +45,7 @@ class FakeStreamsRedis: self._expire_calls[key] = self._expire_calls.get(key, 0) + 1 # Consumer API - def xread(self, streams: dict, block: int | None = None, count: int | None = None): + def xread(self, streams: dict[str, Any], block: int | None = None, count: int | None = None): # Expect a single key assert len(streams) == 1 key, last_id = next(iter(streams.items())) @@ -79,7 +80,7 @@ class BlockingRedis: def __init__(self) -> None: self._release = threading.Event() - def xread(self, streams: dict, block: int | None = None, count: int | None = None): + def xread(self, streams: dict[str, Any], block: int | None = None, count: int | None = None): self._release.wait(timeout=block / 1000.0 if block else None) return [] @@ -150,6 +151,25 @@ class TestStreamsBroadcastChannel: # Expire called after publish assert fake_redis._expire_calls.get("stream:beta", 0) >= 1 + def test_topic_uses_prefixed_stream_key(self, fake_redis: FakeStreamsRedis): + with patch("extensions.redis_names.dify_config") as mock_config: + mock_config.REDIS_KEY_PREFIX = "enterprise-a" + + topic = StreamsBroadcastChannel(fake_redis, retention_seconds=60).topic("alpha") + + assert topic._topic == "alpha" + assert topic._key == "enterprise-a:stream:alpha" + + def test_publish_uses_prefixed_stream_key(self, fake_redis: FakeStreamsRedis): + with patch("extensions.redis_names.dify_config") as mock_config: + mock_config.REDIS_KEY_PREFIX = "enterprise-a" + topic = StreamsBroadcastChannel(fake_redis, retention_seconds=60).topic("beta") + + topic.publish(b"hello") + + assert fake_redis._store["enterprise-a:stream:beta"][0][1] == {b"data": b"hello"} + assert fake_redis._expire_calls.get("enterprise-a:stream:beta", 0) >= 1 + def test_topic_exposes_self_as_producer_and_subscriber(self, streams_channel: StreamsBroadcastChannel): topic = streams_channel.topic("producer-subscriber") @@ -225,7 +245,7 @@ class TestStreamsSubscription: self._fields = fields self._calls = 0 - def xread(self, streams: dict, block: int | None = None, count: int | None = None): + def xread(self, 
streams: dict[str, Any], block: int | None = None, count: int | None = None): self._calls += 1 if self._calls == 1: key = next(iter(streams)) diff --git a/api/tests/unit_tests/libs/test_pyrefly_type_coverage.py b/api/tests/unit_tests/libs/test_pyrefly_type_coverage.py index 7087490845..cad9d47bba 100644 --- a/api/tests/unit_tests/libs/test_pyrefly_type_coverage.py +++ b/api/tests/unit_tests/libs/test_pyrefly_type_coverage.py @@ -1,4 +1,5 @@ import json +from typing import Any from libs.pyrefly_type_coverage import ( CoverageSummary, @@ -8,11 +9,11 @@ from libs.pyrefly_type_coverage import ( ) -def _make_report(summary: dict) -> str: +def _make_report(summary: dict[str, Any]) -> str: return json.dumps({"module_reports": [], "summary": summary}) -_SAMPLE_SUMMARY: dict = { +_SAMPLE_SUMMARY: dict[str, Any] = { "n_modules": 100, "n_typable": 1000, "n_typed": 400, diff --git a/api/tests/unit_tests/libs/test_sendgrid_client.py b/api/tests/unit_tests/libs/test_sendgrid_client.py index 85744003c7..a65a9b1882 100644 --- a/api/tests/unit_tests/libs/test_sendgrid_client.py +++ b/api/tests/unit_tests/libs/test_sendgrid_client.py @@ -1,3 +1,4 @@ +from typing import Any from unittest.mock import MagicMock, patch import pytest @@ -6,7 +7,7 @@ from python_http_client.exceptions import UnauthorizedError from libs.sendgrid import SendGridClient -def _mail(to: str = "user@example.com") -> dict: +def _mail(to: str = "user@example.com") -> dict[str, Any]: return {"to": to, "subject": "Hi", "html": "Hi"} diff --git a/api/tests/unit_tests/libs/test_smtp_client.py b/api/tests/unit_tests/libs/test_smtp_client.py index 1edf4899ac..96d62de2d6 100644 --- a/api/tests/unit_tests/libs/test_smtp_client.py +++ b/api/tests/unit_tests/libs/test_smtp_client.py @@ -1,3 +1,4 @@ +from typing import Any from unittest.mock import ANY, MagicMock, patch import pytest @@ -5,7 +6,7 @@ import pytest from libs.smtp import SMTPClient -def _mail() -> dict: +def _mail() -> dict[str, Any]: return {"to": "user@example.com", "subject": "Hi", "html": "Hi"} diff --git a/api/tests/unit_tests/services/auth/test_jina_auth_standalone_module.py b/api/tests/unit_tests/services/auth/test_jina_auth_standalone_module.py index 4b5a97bf3f..b31af996ae 100644 --- a/api/tests/unit_tests/services/auth/test_jina_auth_standalone_module.py +++ b/api/tests/unit_tests/services/auth/test_jina_auth_standalone_module.py @@ -4,6 +4,7 @@ import importlib.util import sys from pathlib import Path from types import ModuleType +from typing import Any from unittest.mock import MagicMock import httpx @@ -30,8 +31,8 @@ def jina_module() -> ModuleType: return module -def _credentials(api_key: str | None = "test_api_key_123", auth_type: str = "bearer") -> dict: - config: dict = {} if api_key is None else {"api_key": api_key} +def _credentials(api_key: str | None = "test_api_key_123", auth_type: str = "bearer") -> dict[str, Any]: + config: dict[str, Any] = {} if api_key is None else {"api_key": api_key} return {"auth_type": auth_type, "config": config} @@ -47,7 +48,7 @@ def test_init_rejects_invalid_auth_type(jina_module: ModuleType) -> None: @pytest.mark.parametrize("credentials", [{"auth_type": "bearer", "config": {}}, {"auth_type": "bearer"}]) -def test_init_requires_api_key(jina_module: ModuleType, credentials: dict) -> None: +def test_init_requires_api_key(jina_module: ModuleType, credentials: dict[str, Any]) -> None: with pytest.raises(ValueError, match="No API key provided"): jina_module.JinaAuth(credentials) diff --git 
a/api/tests/unit_tests/services/dataset_service_test_helpers.py b/api/tests/unit_tests/services/dataset_service_test_helpers.py index da557de8a4..c6972b5ef4 100644 --- a/api/tests/unit_tests/services/dataset_service_test_helpers.py +++ b/api/tests/unit_tests/services/dataset_service_test_helpers.py @@ -7,6 +7,7 @@ document, and segment service test modules that exercise import json from types import SimpleNamespace +from typing import Any from unittest.mock import MagicMock, Mock, create_autospec, patch import pytest @@ -166,7 +167,7 @@ class DatasetServiceUnitDataFactory: built_in_field_enabled: bool = False, doc_form: str | None = "text_model", enable_api: bool = False, - summary_index_setting: dict | None = None, + summary_index_setting: dict[str, Any] | None = None, **kwargs, ) -> Mock: dataset = Mock(spec=Dataset) @@ -214,12 +215,12 @@ class DatasetServiceUnitDataFactory: archived: bool = False, enabled: bool = True, data_source_type: str = "upload_file", - data_source_info_dict: dict | None = None, + data_source_info_dict: dict[str, Any] | None = None, data_source_info: str | None = None, doc_form: str = "text_model", need_summary: bool = True, position: int = 0, - doc_metadata: dict | None = None, + doc_metadata: dict[str, Any] | None = None, name: str = "Document", **kwargs, ) -> Mock: diff --git a/api/tests/unit_tests/services/document_service_status.py b/api/tests/unit_tests/services/document_service_status.py deleted file mode 100644 index 1b682d5762..0000000000 --- a/api/tests/unit_tests/services/document_service_status.py +++ /dev/null @@ -1,70 +0,0 @@ -"""Unit tests for non-SQL validation in DocumentService status management methods.""" - -from unittest.mock import Mock, create_autospec - -import pytest - -from models import Account -from models.dataset import Dataset -from services.dataset_service import DocumentService - - -class DocumentStatusTestDataFactory: - """Factory class for creating test data and mock objects for document status tests.""" - - @staticmethod - def create_dataset_mock( - dataset_id: str = "dataset-123", - tenant_id: str = "tenant-123", - name: str = "Test Dataset", - built_in_field_enabled: bool = False, - **kwargs, - ) -> Mock: - """Create a mock Dataset with specified attributes.""" - dataset = Mock(spec=Dataset) - dataset.id = dataset_id - dataset.tenant_id = tenant_id - dataset.name = name - dataset.built_in_field_enabled = built_in_field_enabled - for key, value in kwargs.items(): - setattr(dataset, key, value) - return dataset - - @staticmethod - def create_user_mock( - user_id: str = "user-123", - tenant_id: str = "tenant-123", - **kwargs, - ) -> Mock: - """Create a mock user (Account) with specified attributes.""" - user = create_autospec(Account, instance=True) - user.id = user_id - user.current_tenant_id = tenant_id - for key, value in kwargs.items(): - setattr(user, key, value) - return user - - -class TestDocumentServiceBatchUpdateDocumentStatus: - """Unit tests for non-SQL path in DocumentService.batch_update_document_status.""" - - def test_batch_update_document_status_invalid_action_error(self): - """ - Test error handling for invalid action. - - Verifies that when an invalid action is provided, a ValueError - is raised. 
- - This test ensures: - - Invalid actions are rejected - - Error message is clear - - Error type is correct - """ - # Arrange - dataset = DocumentStatusTestDataFactory.create_dataset_mock() - user = DocumentStatusTestDataFactory.create_user_mock() - document_ids = ["document-123"] - - # Act & Assert - with pytest.raises(ValueError, match="Invalid action"): - DocumentService.batch_update_document_status(dataset, document_ids, "invalid_action", user) diff --git a/api/tests/unit_tests/services/external_dataset_service.py b/api/tests/unit_tests/services/external_dataset_service.py index dd41c0c97e..83bae370eb 100644 --- a/api/tests/unit_tests/services/external_dataset_service.py +++ b/api/tests/unit_tests/services/external_dataset_service.py @@ -51,7 +51,7 @@ class ExternalDatasetTestDataFactory: tenant_id: str = "tenant-1", name: str = "Test API", description: str = "Description", - settings: dict | None = None, + settings: dict[str, Any] | None = None, ) -> ExternalKnowledgeApis: """ Create a concrete ``ExternalKnowledgeApis`` instance with minimal fields. @@ -220,7 +220,7 @@ class TestExternalDatasetServiceValidateApiList: ({"endpoint": "https://example.com"}, "api_key is required"), ], ) - def test_validate_api_list_failures(self, config: dict, expected_message: str): + def test_validate_api_list_failures(self, config: dict[str, Any], expected_message: str): """ Invalid configs should raise ``ValueError`` with a clear message. """ diff --git a/api/tests/unit_tests/services/hit_service.py b/api/tests/unit_tests/services/hit_service.py index 22ab8503df..ddbc7dc041 100644 --- a/api/tests/unit_tests/services/hit_service.py +++ b/api/tests/unit_tests/services/hit_service.py @@ -6,6 +6,7 @@ which handles retrieval testing operations for datasets, including internal dataset retrieval and external knowledge base retrieval. 
""" +from typing import Any from unittest.mock import MagicMock, Mock, patch import pytest @@ -30,7 +31,7 @@ class HitTestingTestDataFactory: dataset_id: str = "dataset-123", tenant_id: str = "tenant-123", provider: str = "vendor", - retrieval_model: dict | None = None, + retrieval_model: dict[str, Any] | None = None, **kwargs, ) -> Mock: """ @@ -83,7 +84,7 @@ class HitTestingTestDataFactory: @staticmethod def create_document_mock( content: str = "Test document content", - metadata: dict | None = None, + metadata: dict[str, Any] | None = None, **kwargs, ) -> Mock: """ diff --git a/api/tests/unit_tests/services/rag_pipeline/test_rag_pipeline_dsl_service.py b/api/tests/unit_tests/services/rag_pipeline/test_rag_pipeline_dsl_service.py index 6813a1bf2a..434a99b5ae 100644 --- a/api/tests/unit_tests/services/rag_pipeline/test_rag_pipeline_dsl_service.py +++ b/api/tests/unit_tests/services/rag_pipeline/test_rag_pipeline_dsl_service.py @@ -1,5 +1,5 @@ from types import SimpleNamespace -from typing import cast +from typing import Any, cast from unittest.mock import MagicMock, Mock import pytest @@ -558,7 +558,7 @@ def test_append_workflow_export_data_filters_credentials(mocker) -> None: "services.rag_pipeline.rag_pipeline_dsl_service.DependenciesAnalysisService.generate_dependencies", return_value=[], ) - export_data: dict = {} + export_data: dict[str, Any] = {} pipeline = Mock(id="p1", tenant_id="t1") service._append_workflow_export_data(export_data=export_data, pipeline=pipeline, include_secret=False) @@ -641,7 +641,7 @@ def test_append_workflow_export_data_encrypts_knowledge_retrieval_dataset_ids(mo "services.rag_pipeline.rag_pipeline_dsl_service.DependenciesAnalysisService.generate_dependencies", return_value=[], ) - export_data: dict = {} + export_data: dict[str, Any] = {} pipeline = Mock(id="p1", tenant_id="t1") service._append_workflow_export_data(export_data=export_data, pipeline=pipeline, include_secret=False) diff --git a/api/tests/unit_tests/services/test_app_generate_service_streaming_integration.py b/api/tests/unit_tests/services/test_app_generate_service_streaming_integration.py index e66d52f66b..30aa359b45 100644 --- a/api/tests/unit_tests/services/test_app_generate_service_streaming_integration.py +++ b/api/tests/unit_tests/services/test_app_generate_service_streaming_integration.py @@ -1,6 +1,7 @@ import json import uuid from collections import defaultdict, deque +from typing import Any import pytest @@ -60,7 +61,7 @@ class _FakeStreams: self._data: dict[str, list[tuple[str, dict]]] = defaultdict(list) self._seq: dict[str, int] = defaultdict(int) - def xadd(self, key: str, fields: dict, *, maxlen: int | None = None) -> str: + def xadd(self, key: str, fields: dict[str, Any], *, maxlen: int | None = None) -> str: # maxlen is accepted for API compatibility with redis-py; ignored in this test double self._seq[key] += 1 eid = f"{self._seq[key]}-0" @@ -71,7 +72,7 @@ class _FakeStreams: # no-op for tests return None - def xread(self, streams: dict, block: int | None = None, count: int | None = None): + def xread(self, streams: dict[str, Any], block: int | None = None, count: int | None = None): assert len(streams) == 1 key, last_id = next(iter(streams.items())) entries = self._data.get(key, []) diff --git a/api/tests/unit_tests/services/test_audio_service.py b/api/tests/unit_tests/services/test_audio_service.py index af8fc1e84f..83258fd1b7 100644 --- a/api/tests/unit_tests/services/test_audio_service.py +++ b/api/tests/unit_tests/services/test_audio_service.py @@ -53,6 +53,7 @@ Tests 
available voice retrieval: - text_to_speech: Enables TTS functionality """ +from typing import Any from unittest.mock import MagicMock, Mock, create_autospec, patch import pytest @@ -109,7 +110,7 @@ class AudioServiceTestDataFactory: return app @staticmethod - def create_workflow_mock(features_dict: dict | None = None, **kwargs) -> Mock: + def create_workflow_mock(features_dict: dict[str, Any] | None = None, **kwargs) -> Mock: """ Create a mock Workflow object. @@ -128,8 +129,8 @@ class AudioServiceTestDataFactory: @staticmethod def create_app_model_config_mock( - speech_to_text_dict: dict | None = None, - text_to_speech_dict: dict | None = None, + speech_to_text_dict: dict[str, Any] | None = None, + text_to_speech_dict: dict[str, Any] | None = None, **kwargs, ) -> Mock: """ diff --git a/api/tests/unit_tests/services/test_dataset_service_dataset.py b/api/tests/unit_tests/services/test_dataset_service_dataset.py index 2913ae20fe..3d08b6fd09 100644 --- a/api/tests/unit_tests/services/test_dataset_service_dataset.py +++ b/api/tests/unit_tests/services/test_dataset_service_dataset.py @@ -1,29 +1,20 @@ """Unit tests for DatasetService and dataset-related collaborators.""" from .dataset_service_test_helpers import ( - CloudPlan, - Dataset, - DatasetCollectionBindingService, DatasetNameDuplicateError, DatasetPermissionEnum, DatasetPermissionService, - DatasetProcessRule, DatasetService, DatasetServiceUnitDataFactory, - DocumentIndexingError, - DocumentService, LLMBadRequestError, MagicMock, - Mock, ModelFeature, ModelType, NoPermissionError, - NotFound, PipelineIconInfo, ProviderTokenNotInitError, RagPipelineDatasetCreateEntity, SimpleNamespace, - TenantAccountRole, _make_knowledge_configuration, _make_retrieval_model, _make_session_context, @@ -33,127 +24,6 @@ from .dataset_service_test_helpers import ( ) -class TestDatasetServiceQueries: - """Unit tests for DatasetService query composition and fallback branches.""" - - @pytest.fixture - def mock_dataset_query_dependencies(self): - with ( - patch("services.dataset_service.db") as mock_db, - patch("services.dataset_service.helper.escape_like_pattern", return_value="escaped-search") as escape_like, - patch("services.dataset_service.TagService.get_target_ids_by_tag_ids") as get_target_ids, - ): - mock_db.paginate.return_value = SimpleNamespace(items=["dataset"], total=1) - yield { - "db": mock_db, - "escape_like_pattern": escape_like, - "get_target_ids": get_target_ids, - } - - def test_get_datasets_returns_paginated_results_for_public_view(self, mock_dataset_query_dependencies): - items, total = DatasetService.get_datasets(page=1, per_page=20, tenant_id="tenant-1") - - assert items == ["dataset"] - assert total == 1 - mock_dataset_query_dependencies["db"].paginate.assert_called_once() - mock_dataset_query_dependencies["escape_like_pattern"].assert_not_called() - - def test_get_datasets_short_circuits_for_dataset_operator_without_permissions( - self, mock_dataset_query_dependencies - ): - user = DatasetServiceUnitDataFactory.create_user_mock(role=TenantAccountRole.DATASET_OPERATOR) - mock_dataset_query_dependencies["db"].session.scalars.return_value.all.return_value = [] - - items, total = DatasetService.get_datasets(page=1, per_page=20, tenant_id="tenant-1", user=user) - - assert items == [] - assert total == 0 - mock_dataset_query_dependencies["db"].paginate.assert_not_called() - - def test_get_datasets_short_circuits_when_tag_lookup_returns_no_target_ids(self, mock_dataset_query_dependencies): - 
mock_dataset_query_dependencies["get_target_ids"].return_value = [] - - items, total = DatasetService.get_datasets( - page=1, - per_page=20, - tenant_id="tenant-1", - tag_ids=["tag-1"], - ) - - assert items == [] - assert total == 0 - mock_dataset_query_dependencies["get_target_ids"].assert_called_once_with("knowledge", "tenant-1", ["tag-1"]) - mock_dataset_query_dependencies["db"].paginate.assert_not_called() - - def test_get_datasets_search_and_tag_filters_call_collaborators(self, mock_dataset_query_dependencies): - mock_dataset_query_dependencies["get_target_ids"].return_value = ["dataset-1"] - - items, total = DatasetService.get_datasets( - page=2, - per_page=10, - tenant_id="tenant-1", - search="report", - tag_ids=["tag-1"], - ) - - assert items == ["dataset"] - assert total == 1 - mock_dataset_query_dependencies["escape_like_pattern"].assert_called_once_with("report") - mock_dataset_query_dependencies["get_target_ids"].assert_called_once_with("knowledge", "tenant-1", ["tag-1"]) - mock_dataset_query_dependencies["db"].paginate.assert_called_once() - - def test_get_process_rules_returns_latest_rule_when_present(self): - dataset_process_rule = Mock(spec=DatasetProcessRule) - dataset_process_rule.mode = "automatic" - dataset_process_rule.rules_dict = {"delimiter": "\n"} - - with patch("services.dataset_service.db") as mock_db: - (mock_db.session.execute.return_value.scalar_one_or_none.return_value) = dataset_process_rule - - result = DatasetService.get_process_rules("dataset-1") - - assert result == {"mode": "automatic", "rules": {"delimiter": "\n"}} - - def test_get_process_rules_falls_back_to_default_rules_when_missing(self): - with patch("services.dataset_service.db") as mock_db: - (mock_db.session.execute.return_value.scalar_one_or_none.return_value) = None - - result = DatasetService.get_process_rules("dataset-1") - - assert result == { - "mode": DocumentService.DEFAULT_RULES["mode"], - "rules": DocumentService.DEFAULT_RULES["rules"], - } - - def test_get_datasets_by_ids_returns_empty_for_missing_ids(self): - with patch("services.dataset_service.db") as mock_db: - items, total = DatasetService.get_datasets_by_ids([], "tenant-1") - - assert items == [] - assert total == 0 - mock_db.paginate.assert_not_called() - - def test_get_datasets_by_ids_uses_paginate_for_non_empty_input(self): - with patch("services.dataset_service.db") as mock_db: - mock_db.paginate.return_value = SimpleNamespace(items=["dataset-1"], total=1) - - items, total = DatasetService.get_datasets_by_ids(["dataset-1"], "tenant-1") - - assert items == ["dataset-1"] - assert total == 1 - mock_db.paginate.assert_called_once() - - def test_get_dataset_returns_first_match(self): - dataset = DatasetServiceUnitDataFactory.create_dataset_mock() - - with patch("services.dataset_service.db") as mock_db: - mock_db.session.get.return_value = dataset - - result = DatasetService.get_dataset(dataset.id) - - assert result is dataset - - class TestDatasetServiceValidation: """Unit tests for DatasetService validation helpers.""" @@ -1337,103 +1207,6 @@ class TestDatasetServiceRagPipelineSettings: class TestDatasetServicePermissionsAndLifecycle: """Unit tests for dataset permissions, deletion, and metadata helpers.""" - def test_delete_dataset_returns_false_when_dataset_is_missing(self): - with patch.object(DatasetService, "get_dataset", return_value=None): - result = DatasetService.delete_dataset("dataset-1", user=SimpleNamespace(id="user-1")) - - assert result is False - - def 
test_delete_dataset_checks_permission_and_deletes_dataset(self): - dataset = DatasetServiceUnitDataFactory.create_dataset_mock() - - with ( - patch.object(DatasetService, "get_dataset", return_value=dataset), - patch.object(DatasetService, "check_dataset_permission") as check_permission, - patch("services.dataset_service.dataset_was_deleted.send") as send_deleted_signal, - patch("services.dataset_service.db") as mock_db, - ): - result = DatasetService.delete_dataset(dataset.id, user=SimpleNamespace(id="user-1")) - - assert result is True - check_permission.assert_called_once_with(dataset, SimpleNamespace(id="user-1")) - send_deleted_signal.assert_called_once_with(dataset) - mock_db.session.delete.assert_called_once_with(dataset) - mock_db.session.commit.assert_called_once() - - def test_dataset_use_check_returns_scalar_result(self): - with patch("services.dataset_service.db") as mock_db: - mock_db.session.execute.return_value.scalar_one.return_value = True - - result = DatasetService.dataset_use_check("dataset-1") - - assert result is True - - def test_check_dataset_permission_rejects_cross_tenant_access(self): - dataset = DatasetServiceUnitDataFactory.create_dataset_mock(tenant_id="tenant-a") - user = DatasetServiceUnitDataFactory.create_user_mock(tenant_id="tenant-b") - - with pytest.raises(NoPermissionError, match="do not have permission"): - DatasetService.check_dataset_permission(dataset, user) - - def test_check_dataset_permission_rejects_only_me_dataset_for_non_creator(self): - dataset = DatasetServiceUnitDataFactory.create_dataset_mock( - permission=DatasetPermissionEnum.ONLY_ME, - created_by="owner-1", - ) - user = DatasetServiceUnitDataFactory.create_user_mock( - user_id="member-1", - role=TenantAccountRole.EDITOR, - ) - - with pytest.raises(NoPermissionError, match="do not have permission"): - DatasetService.check_dataset_permission(dataset, user) - - def test_check_dataset_permission_rejects_partial_team_user_without_binding(self): - dataset = DatasetServiceUnitDataFactory.create_dataset_mock( - permission=DatasetPermissionEnum.PARTIAL_TEAM, - created_by="owner-1", - ) - user = DatasetServiceUnitDataFactory.create_user_mock( - user_id="member-1", - role=TenantAccountRole.EDITOR, - ) - - with patch("services.dataset_service.db") as mock_db: - mock_db.session.scalar.return_value = None - - with pytest.raises(NoPermissionError, match="do not have permission"): - DatasetService.check_dataset_permission(dataset, user) - - def test_check_dataset_permission_allows_partial_team_creator_without_lookup(self): - dataset = DatasetServiceUnitDataFactory.create_dataset_mock( - permission=DatasetPermissionEnum.PARTIAL_TEAM, - created_by="creator-1", - ) - user = DatasetServiceUnitDataFactory.create_user_mock( - user_id="creator-1", - role=TenantAccountRole.EDITOR, - ) - - with patch("services.dataset_service.db") as mock_db: - DatasetService.check_dataset_permission(dataset, user) - - mock_db.session.scalar.assert_not_called() - - def test_check_dataset_permission_allows_partial_team_member_with_binding(self): - dataset = DatasetServiceUnitDataFactory.create_dataset_mock( - permission=DatasetPermissionEnum.PARTIAL_TEAM, - created_by="owner-1", - ) - user = DatasetServiceUnitDataFactory.create_user_mock( - user_id="member-1", - role=TenantAccountRole.EDITOR, - ) - - with patch("services.dataset_service.db") as mock_db: - mock_db.session.scalar.return_value = object() - - DatasetService.check_dataset_permission(dataset, user) - def 
test_check_dataset_operator_permission_validates_required_arguments(self): with pytest.raises(ValueError, match="Dataset not found"): DatasetService.check_dataset_operator_permission(user=SimpleNamespace(id="user-1"), dataset=None) @@ -1441,279 +1214,14 @@ class TestDatasetServicePermissionsAndLifecycle: with pytest.raises(ValueError, match="User not found"): DatasetService.check_dataset_operator_permission(user=None, dataset=SimpleNamespace(id="dataset-1")) - def test_check_dataset_operator_permission_rejects_only_me_for_non_creator(self): - dataset = DatasetServiceUnitDataFactory.create_dataset_mock( - permission=DatasetPermissionEnum.ONLY_ME, - created_by="owner-1", - ) - user = DatasetServiceUnitDataFactory.create_user_mock( - user_id="member-1", - role=TenantAccountRole.EDITOR, - ) - - with pytest.raises(NoPermissionError, match="do not have permission"): - DatasetService.check_dataset_operator_permission(user=user, dataset=dataset) - - def test_check_dataset_operator_permission_rejects_partial_team_without_binding(self): - dataset = DatasetServiceUnitDataFactory.create_dataset_mock(permission=DatasetPermissionEnum.PARTIAL_TEAM) - user = DatasetServiceUnitDataFactory.create_user_mock( - user_id="member-1", - role=TenantAccountRole.EDITOR, - ) - - with patch("services.dataset_service.db") as mock_db: - mock_db.session.scalars.return_value.all.return_value = [] - - with pytest.raises(NoPermissionError, match="do not have permission"): - DatasetService.check_dataset_operator_permission(user=user, dataset=dataset) - - def test_get_dataset_queries_delegates_to_paginate(self): - with patch("services.dataset_service.db") as mock_db: - mock_db.desc.side_effect = lambda column: column - mock_db.paginate.return_value = SimpleNamespace(items=["query"], total=1) - - items, total = DatasetService.get_dataset_queries("dataset-1", page=1, per_page=20) - - assert items == ["query"] - assert total == 1 - mock_db.paginate.assert_called_once() - - def test_get_related_apps_returns_ordered_query_results(self): - with patch("services.dataset_service.db") as mock_db: - mock_db.desc.side_effect = lambda column: column - mock_db.session.scalars.return_value.all.return_value = ["relation-1"] - - result = DatasetService.get_related_apps("dataset-1") - - assert result == ["relation-1"] - - def test_update_dataset_api_status_raises_not_found_for_missing_dataset(self): - with patch.object(DatasetService, "get_dataset", return_value=None): - with pytest.raises(NotFound, match="Dataset not found"): - DatasetService.update_dataset_api_status("dataset-1", True) - - def test_update_dataset_api_status_requires_current_user_id(self): - dataset = DatasetServiceUnitDataFactory.create_dataset_mock(enable_api=False) - - with ( - patch.object(DatasetService, "get_dataset", return_value=dataset), - patch("services.dataset_service.current_user", SimpleNamespace(id=None)), - ): - with pytest.raises(ValueError, match="Current user or current user id not found"): - DatasetService.update_dataset_api_status(dataset.id, True) - - def test_update_dataset_api_status_updates_fields_and_commits(self): - dataset = DatasetServiceUnitDataFactory.create_dataset_mock(enable_api=False) - now = object() - - with ( - patch.object(DatasetService, "get_dataset", return_value=dataset), - patch("services.dataset_service.current_user", SimpleNamespace(id="user-1")), - patch("services.dataset_service.naive_utc_now", return_value=now), - patch("services.dataset_service.db") as mock_db, - ): - DatasetService.update_dataset_api_status(dataset.id, True) 
- - assert dataset.enable_api is True - assert dataset.updated_by == "user-1" - assert dataset.updated_at is now - mock_db.session.commit.assert_called_once() - - def test_get_dataset_auto_disable_logs_returns_empty_when_billing_is_disabled(self): - class FakeAccount: - pass - - current_user = FakeAccount() - current_user.current_tenant_id = "tenant-1" - - features = SimpleNamespace( - billing=SimpleNamespace(enabled=False, subscription=SimpleNamespace(plan=CloudPlan.PROFESSIONAL)) - ) - - with ( - patch("services.dataset_service.Account", FakeAccount), - patch("services.dataset_service.current_user", current_user), - patch("services.dataset_service.FeatureService.get_features", return_value=features), - patch("services.dataset_service.db") as mock_db, - ): - result = DatasetService.get_dataset_auto_disable_logs("dataset-1") - - assert result == {"document_ids": [], "count": 0} - mock_db.session.scalars.assert_not_called() - - def test_get_dataset_auto_disable_logs_returns_recent_document_ids(self): - class FakeAccount: - pass - - current_user = FakeAccount() - current_user.current_tenant_id = "tenant-1" - logs = [SimpleNamespace(document_id="doc-1"), SimpleNamespace(document_id="doc-2")] - features = SimpleNamespace( - billing=SimpleNamespace(enabled=True, subscription=SimpleNamespace(plan=CloudPlan.PROFESSIONAL)) - ) - - with ( - patch("services.dataset_service.Account", FakeAccount), - patch("services.dataset_service.current_user", current_user), - patch("services.dataset_service.FeatureService.get_features", return_value=features), - patch("services.dataset_service.db") as mock_db, - ): - mock_db.session.scalars.return_value.all.return_value = logs - - result = DatasetService.get_dataset_auto_disable_logs("dataset-1") - - assert result == {"document_ids": ["doc-1", "doc-2"], "count": 2} - - -class TestDatasetServiceDocumentIndexing: - """Unit tests for pause/recover/retry orchestration without SQL assertions.""" - - @pytest.fixture - def mock_document_service_dependencies(self): - with ( - patch("services.dataset_service.redis_client") as mock_redis, - patch("services.dataset_service.db.session") as mock_db_session, - patch("services.dataset_service.current_user") as mock_current_user, - ): - mock_current_user.id = "user-123" - yield { - "redis_client": mock_redis, - "db_session": mock_db_session, - "current_user": mock_current_user, - } - - def test_pause_document_success(self, mock_document_service_dependencies): - document = DatasetServiceUnitDataFactory.create_document_mock(indexing_status="indexing") - - DocumentService.pause_document(document) - - assert document.is_paused is True - assert document.paused_by == "user-123" - mock_document_service_dependencies["db_session"].add.assert_called_once_with(document) - mock_document_service_dependencies["db_session"].commit.assert_called_once() - mock_document_service_dependencies["redis_client"].setnx.assert_called_once_with( - f"document_{document.id}_is_paused", - "True", - ) - - def test_pause_document_invalid_status_error(self, mock_document_service_dependencies): - document = DatasetServiceUnitDataFactory.create_document_mock(indexing_status="completed") - - with pytest.raises(DocumentIndexingError): - DocumentService.pause_document(document) - - def test_recover_document_success(self, mock_document_service_dependencies): - document = DatasetServiceUnitDataFactory.create_document_mock(indexing_status="indexing", is_paused=True) - - with patch("services.dataset_service.recover_document_indexing_task") as recover_task: - 
DocumentService.recover_document(document) - - assert document.is_paused is False - assert document.paused_by is None - assert document.paused_at is None - mock_document_service_dependencies["db_session"].add.assert_called_once_with(document) - mock_document_service_dependencies["db_session"].commit.assert_called_once() - mock_document_service_dependencies["redis_client"].delete.assert_called_once_with( - f"document_{document.id}_is_paused" - ) - recover_task.delay.assert_called_once_with(document.dataset_id, document.id) - - def test_retry_document_indexing_success(self, mock_document_service_dependencies): - dataset_id = "dataset-123" - documents = [ - DatasetServiceUnitDataFactory.create_document_mock(document_id="doc-1", indexing_status="error"), - DatasetServiceUnitDataFactory.create_document_mock(document_id="doc-2", indexing_status="error"), - ] - mock_document_service_dependencies["redis_client"].get.return_value = None - - with patch("services.dataset_service.retry_document_indexing_task") as retry_task: - DocumentService.retry_document(dataset_id, documents) - - assert all(document.indexing_status == "waiting" for document in documents) - assert mock_document_service_dependencies["db_session"].add.call_count == 2 - assert mock_document_service_dependencies["db_session"].commit.call_count == 2 - assert mock_document_service_dependencies["redis_client"].setex.call_count == 2 - retry_task.delay.assert_called_once_with(dataset_id, ["doc-1", "doc-2"], "user-123") - class TestDatasetCollectionBindingService: """Unit tests for dataset collection binding lookups and creation.""" - def test_get_dataset_collection_binding_returns_existing_binding(self): - binding = SimpleNamespace(id="binding-1") - - with patch("services.dataset_service.db") as mock_db: - mock_db.session.scalar.return_value = binding - - result = DatasetCollectionBindingService.get_dataset_collection_binding("provider", "model") - - assert result is binding - mock_db.session.add.assert_not_called() - - def test_get_dataset_collection_binding_creates_binding_when_missing(self): - created_binding = SimpleNamespace(id="binding-2") - - with ( - patch("services.dataset_service.db") as mock_db, - patch("services.dataset_service.select"), - patch("services.dataset_service.DatasetCollectionBinding", return_value=created_binding) as binding_cls, - patch.object(Dataset, "gen_collection_name_by_id", return_value="generated-collection"), - ): - mock_db.session.scalar.return_value = None - - result = DatasetCollectionBindingService.get_dataset_collection_binding("provider", "model", "dataset") - - assert result is created_binding - binding_cls.assert_called_once_with( - provider_name="provider", - model_name="model", - collection_name="generated-collection", - type="dataset", - ) - mock_db.session.add.assert_called_once_with(created_binding) - mock_db.session.commit.assert_called_once() - - def test_get_dataset_collection_binding_by_id_and_type_raises_when_missing(self): - with patch("services.dataset_service.db") as mock_db: - mock_db.session.scalar.return_value = None - - with pytest.raises(ValueError, match="Dataset collection binding not found"): - DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type("binding-1") - - def test_get_dataset_collection_binding_by_id_and_type_returns_binding(self): - binding = SimpleNamespace(id="binding-1") - - with patch("services.dataset_service.db") as mock_db: - mock_db.session.scalar.return_value = binding - - result = 
DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type("binding-1") - - assert result is binding - class TestDatasetPermissionService: """Unit tests for dataset partial-member management helpers.""" - def test_get_dataset_partial_member_list_returns_scalar_results(self): - with patch("services.dataset_service.db") as mock_db: - mock_db.session.scalars.return_value.all.return_value = ["user-1", "user-2"] - - result = DatasetPermissionService.get_dataset_partial_member_list("dataset-1") - - assert result == ["user-1", "user-2"] - - def test_update_partial_member_list_replaces_permissions_and_commits(self): - with patch("services.dataset_service.db") as mock_db: - DatasetPermissionService.update_partial_member_list( - "tenant-1", - "dataset-1", - [{"user_id": "user-1"}, {"user_id": "user-2"}], - ) - - mock_db.session.execute.assert_called() - mock_db.session.add_all.assert_called_once() - mock_db.session.commit.assert_called_once() - def test_update_partial_member_list_rolls_back_on_exception(self): with patch("services.dataset_service.db") as mock_db: mock_db.session.add_all.side_effect = RuntimeError("boom") @@ -1777,13 +1285,6 @@ class TestDatasetPermissionService: [{"user_id": "user-1"}], ) - def test_clear_partial_member_list_deletes_permissions_and_commits(self): - with patch("services.dataset_service.db") as mock_db: - DatasetPermissionService.clear_partial_member_list("dataset-1") - - mock_db.session.execute.assert_called() - mock_db.session.commit.assert_called_once() - def test_clear_partial_member_list_rolls_back_on_exception(self): with patch("services.dataset_service.db") as mock_db: mock_db.session.execute.side_effect = RuntimeError("boom") diff --git a/api/tests/unit_tests/services/test_external_dataset_service.py b/api/tests/unit_tests/services/test_external_dataset_service.py index 9c1a92b4d9..fdea0ba869 100644 --- a/api/tests/unit_tests/services/test_external_dataset_service.py +++ b/api/tests/unit_tests/services/test_external_dataset_service.py @@ -8,6 +8,7 @@ Target: 1500+ lines of comprehensive test coverage. 
 import json
 import re
 from datetime import datetime
+from typing import Any
 from unittest.mock import MagicMock, Mock, patch

 import pytest
@@ -31,7 +32,7 @@ class ExternalDatasetServiceTestDataFactory:
         api_id: str = "api-123",
         tenant_id: str = "tenant-123",
         name: str = "Test API",
-        settings: dict | None = None,
+        settings: dict[str, Any] | None = None,
         **kwargs,
     ) -> Mock:
         """Create a mock ExternalKnowledgeApis object."""
@@ -120,8 +121,8 @@ class ExternalDatasetServiceTestDataFactory:
     def create_api_setting_mock(
         url: str = "https://api.example.com/retrieval",
         request_method: str = "post",
-        headers: dict | None = None,
-        params: dict | None = None,
+        headers: dict[str, Any] | None = None,
+        params: dict[str, Any] | None = None,
     ) -> ExternalKnowledgeApiSetting:
         """Create an ExternalKnowledgeApiSetting object."""
         if headers is None:
diff --git a/api/tests/unit_tests/services/test_messages_clean_service.py b/api/tests/unit_tests/services/test_messages_clean_service.py
index f3efc4463e..5fcad615c8 100644
--- a/api/tests/unit_tests/services/test_messages_clean_service.py
+++ b/api/tests/unit_tests/services/test_messages_clean_service.py
@@ -1,4 +1,5 @@
 import datetime
+from typing import Any
 from unittest.mock import MagicMock, patch

 import pytest
@@ -18,7 +19,7 @@ def make_simple_message(msg_id: str, app_id: str) -> SimpleMessage:
     return SimpleMessage(id=msg_id, app_id=app_id, created_at=datetime.datetime(2024, 1, 1))


-def make_plan_provider(tenant_plans: dict) -> MagicMock:
+def make_plan_provider(tenant_plans: dict[str, Any]) -> MagicMock:
     """Helper to create a mock plan_provider that returns the given tenant_plans."""
     provider = MagicMock()
     provider.return_value = tenant_plans
diff --git a/api/tests/unit_tests/services/test_operation_service.py b/api/tests/unit_tests/services/test_operation_service.py
index a4c69b23ac..e43a7fa649 100644
--- a/api/tests/unit_tests/services/test_operation_service.py
+++ b/api/tests/unit_tests/services/test_operation_service.py
@@ -1,3 +1,4 @@
+from typing import Any
 from unittest.mock import MagicMock, patch

 import httpx
@@ -105,7 +106,7 @@ class TestOperationService:
     )
     @patch.object(OperationService, "_send_request")
     def test_should_map_parameters_correctly_when_record_utm_called(
-        self, mock_send: MagicMock, utm_info: dict, expected_params: dict
+        self, mock_send: MagicMock, utm_info: dict[str, Any], expected_params: dict[str, Any]
     ):
         """Test that record_utm correctly maps utm_info to parameters and calls _send_request"""
         # Arrange
diff --git a/api/tests/unit_tests/services/test_schedule_service.py b/api/tests/unit_tests/services/test_schedule_service.py
index 334062242b..0f8f7ffab5 100644
--- a/api/tests/unit_tests/services/test_schedule_service.py
+++ b/api/tests/unit_tests/services/test_schedule_service.py
@@ -2,23 +2,16 @@ import unittest
 from datetime import UTC, datetime
 from types import SimpleNamespace
 from typing import Any, cast
-from unittest.mock import MagicMock, Mock, patch
+from unittest.mock import MagicMock, Mock

 import pytest
 from sqlalchemy.orm import Session

 from core.trigger.constants import TRIGGER_SCHEDULE_NODE_TYPE
-from core.workflow.nodes.trigger_schedule.entities import ScheduleConfig, SchedulePlanUpdate, VisualConfig
-from core.workflow.nodes.trigger_schedule.exc import ScheduleConfigError, ScheduleNotFoundError
-from events.event_handlers.sync_workflow_schedule_when_app_published import (
-    sync_schedule_from_workflow,
-)
+from core.workflow.nodes.trigger_schedule.entities import VisualConfig
+from core.workflow.nodes.trigger_schedule.exc import ScheduleConfigError
 from libs.schedule_utils import calculate_next_run_at, convert_12h_to_24h
-from models.account import Account, TenantAccountJoin
-from models.trigger import WorkflowSchedulePlan
 from models.workflow import Workflow
-from services.errors.account import AccountNotFoundError
-from services.trigger import schedule_service as service_module
 from services.trigger.schedule_service import ScheduleService

@@ -83,180 +76,6 @@ class TestScheduleService(unittest.TestCase):
         with pytest.raises(UnknownTimeZoneError):
             calculate_next_run_at(cron_expr, timezone)

-    @patch("libs.schedule_utils.calculate_next_run_at")
-    def test_create_schedule(self, mock_calculate_next_run):
-        """Test creating a new schedule."""
-        mock_session = MagicMock(spec=Session)
-        mock_calculate_next_run.return_value = datetime(2025, 8, 30, 10, 30, 0, tzinfo=UTC)
-
-        config = ScheduleConfig(
-            node_id="start",
-            cron_expression="30 10 * * *",
-            timezone="UTC",
-        )
-
-        schedule = ScheduleService.create_schedule(
-            session=mock_session,
-            tenant_id="test-tenant",
-            app_id="test-app",
-            config=config,
-        )
-
-        assert schedule is not None
-        assert schedule.tenant_id == "test-tenant"
-        assert schedule.app_id == "test-app"
-        assert schedule.node_id == "start"
-        assert schedule.cron_expression == "30 10 * * *"
-        assert schedule.timezone == "UTC"
-        assert schedule.next_run_at is not None
-        mock_session.add.assert_called_once()
-        mock_session.flush.assert_called_once()
-
-    @patch("services.trigger.schedule_service.calculate_next_run_at")
-    def test_update_schedule(self, mock_calculate_next_run):
-        """Test updating an existing schedule."""
-        mock_session = MagicMock(spec=Session)
-        mock_schedule = Mock(spec=WorkflowSchedulePlan)
-        mock_schedule.cron_expression = "0 12 * * *"
-        mock_schedule.timezone = "America/New_York"
-        mock_session.get.return_value = mock_schedule
-        mock_calculate_next_run.return_value = datetime(2025, 8, 30, 12, 0, 0, tzinfo=UTC)
-
-        updates = SchedulePlanUpdate(
-            cron_expression="0 12 * * *",
-            timezone="America/New_York",
-        )
-
-        result = ScheduleService.update_schedule(
-            session=mock_session,
-            schedule_id="test-schedule-id",
-            updates=updates,
-        )
-
-        assert result is not None
-        assert result.cron_expression == "0 12 * * *"
-        assert result.timezone == "America/New_York"
-        mock_calculate_next_run.assert_called_once()
-        mock_session.flush.assert_called_once()
-
-    def test_update_schedule_not_found(self):
-        """Test updating a non-existent schedule raises exception."""
-        from core.workflow.nodes.trigger_schedule.exc import ScheduleNotFoundError
-
-        mock_session = MagicMock(spec=Session)
-        mock_session.get.return_value = None
-
-        updates = SchedulePlanUpdate(
-            cron_expression="0 12 * * *",
-        )
-
-        with pytest.raises(ScheduleNotFoundError) as context:
-            ScheduleService.update_schedule(
-                session=mock_session,
-                schedule_id="non-existent-id",
-                updates=updates,
-            )
-
-        assert "Schedule not found: non-existent-id" in str(context.value)
-        mock_session.flush.assert_not_called()
-
-    def test_delete_schedule(self):
-        """Test deleting a schedule."""
-        mock_session = MagicMock(spec=Session)
-        mock_schedule = Mock(spec=WorkflowSchedulePlan)
-        mock_session.get.return_value = mock_schedule
-
-        # Should not raise exception and complete successfully
-        ScheduleService.delete_schedule(
-            session=mock_session,
-            schedule_id="test-schedule-id",
-        )
-
-        mock_session.delete.assert_called_once_with(mock_schedule)
-        mock_session.flush.assert_called_once()
-
-    def test_delete_schedule_not_found(self):
-        """Test deleting a non-existent schedule raises exception."""
-        from core.workflow.nodes.trigger_schedule.exc import ScheduleNotFoundError
-
-        mock_session = MagicMock(spec=Session)
-        mock_session.get.return_value = None
-
-        # Should raise ScheduleNotFoundError
-        with pytest.raises(ScheduleNotFoundError) as context:
-            ScheduleService.delete_schedule(
-                session=mock_session,
-                schedule_id="non-existent-id",
-            )
-
-        assert "Schedule not found: non-existent-id" in str(context.value)
-        mock_session.delete.assert_not_called()
-
-    @patch("services.trigger.schedule_service.select")
-    def test_get_tenant_owner(self, mock_select):
-        """Test getting tenant owner account."""
-        mock_session = MagicMock(spec=Session)
-        mock_account = Mock(spec=Account)
-        mock_account.id = "owner-account-id"
-
-        # Mock owner query
-        mock_owner_result = Mock(spec=TenantAccountJoin)
-        mock_owner_result.account_id = "owner-account-id"
-
-        mock_session.execute.return_value.scalar_one_or_none.return_value = mock_owner_result
-        mock_session.get.return_value = mock_account
-
-        result = ScheduleService.get_tenant_owner(
-            session=mock_session,
-            tenant_id="test-tenant",
-        )
-
-        assert result is not None
-        assert result.id == "owner-account-id"
-
-    @patch("services.trigger.schedule_service.select")
-    def test_get_tenant_owner_fallback_to_admin(self, mock_select):
-        """Test getting tenant owner falls back to admin if no owner."""
-        mock_session = MagicMock(spec=Session)
-        mock_account = Mock(spec=Account)
-        mock_account.id = "admin-account-id"
-
-        # Mock admin query (owner returns None)
-        mock_admin_result = Mock(spec=TenantAccountJoin)
-        mock_admin_result.account_id = "admin-account-id"
-
-        mock_session.execute.return_value.scalar_one_or_none.side_effect = [None, mock_admin_result]
-        mock_session.get.return_value = mock_account
-
-        result = ScheduleService.get_tenant_owner(
-            session=mock_session,
-            tenant_id="test-tenant",
-        )
-
-        assert result is not None
-        assert result.id == "admin-account-id"
-
-    @patch("services.trigger.schedule_service.calculate_next_run_at")
-    def test_update_next_run_at(self, mock_calculate_next_run):
-        """Test updating next run time after schedule triggered."""
-        mock_session = MagicMock(spec=Session)
-        mock_schedule = Mock(spec=WorkflowSchedulePlan)
-        mock_schedule.cron_expression = "30 10 * * *"
-        mock_schedule.timezone = "UTC"
-        mock_session.get.return_value = mock_schedule
-
-        next_time = datetime(2025, 8, 31, 10, 30, 0, tzinfo=UTC)
-        mock_calculate_next_run.return_value = next_time
-
-        result = ScheduleService.update_next_run_at(
-            session=mock_session,
-            schedule_id="test-schedule-id",
-        )
-
-        assert result == next_time
-        assert mock_schedule.next_run_at == next_time
-        mock_session.flush.assert_called_once()
-

 class TestVisualToCron(unittest.TestCase):
     """Test cases for visual configuration to cron conversion."""

@@ -678,108 +497,6 @@ class TestScheduleWithTimezone(unittest.TestCase):
         assert summer_next.hour == 14


-class TestSyncScheduleFromWorkflow(unittest.TestCase):
-    """Test cases for syncing schedule from workflow."""
-
-    @patch("events.event_handlers.sync_workflow_schedule_when_app_published.db")
-    @patch("events.event_handlers.sync_workflow_schedule_when_app_published.ScheduleService")
-    @patch("events.event_handlers.sync_workflow_schedule_when_app_published.select")
-    def test_sync_schedule_create_new(self, mock_select, mock_service, mock_db):
-        """Test creating new schedule when none exists."""
-        mock_session = MagicMock()
-        mock_db.engine = MagicMock()
-        mock_session.__enter__ = MagicMock(return_value=mock_session)
-        mock_session.__exit__ = MagicMock(return_value=None)
-        sessionmaker = MagicMock(return_value=MagicMock(begin=MagicMock(return_value=mock_session)))
-        with patch("events.event_handlers.sync_workflow_schedule_when_app_published.sessionmaker", sessionmaker):
-            mock_session.scalar.return_value = None  # No existing plan
-
-            # Mock extract_schedule_config to return a ScheduleConfig object
-            mock_config = Mock(spec=ScheduleConfig)
-            mock_config.node_id = "start"
-            mock_config.cron_expression = "30 10 * * *"
-            mock_config.timezone = "UTC"
-            mock_service.extract_schedule_config.return_value = mock_config
-
-            mock_new_plan = Mock(spec=WorkflowSchedulePlan)
-            mock_service.create_schedule.return_value = mock_new_plan
-
-            workflow = Mock(spec=Workflow)
-            result = sync_schedule_from_workflow("tenant-id", "app-id", workflow)
-
-            assert result == mock_new_plan
-            mock_service.create_schedule.assert_called_once()
-            mock_session.commit.assert_not_called()
-
-    @patch("events.event_handlers.sync_workflow_schedule_when_app_published.db")
-    @patch("events.event_handlers.sync_workflow_schedule_when_app_published.ScheduleService")
-    @patch("events.event_handlers.sync_workflow_schedule_when_app_published.select")
-    def test_sync_schedule_update_existing(self, mock_select, mock_service, mock_db):
-        """Test updating existing schedule."""
-        mock_session = MagicMock()
-        mock_db.engine = MagicMock()
-        mock_session.__enter__ = MagicMock(return_value=mock_session)
-        mock_session.__exit__ = MagicMock(return_value=None)
-        sessionmaker = MagicMock(return_value=MagicMock(begin=MagicMock(return_value=mock_session)))
-
-        with patch("events.event_handlers.sync_workflow_schedule_when_app_published.sessionmaker", sessionmaker):
-            mock_existing_plan = Mock(spec=WorkflowSchedulePlan)
-            mock_existing_plan.id = "existing-plan-id"
-            mock_session.scalar.return_value = mock_existing_plan
-
-            # Mock extract_schedule_config to return a ScheduleConfig object
-            mock_config = Mock(spec=ScheduleConfig)
-            mock_config.node_id = "start"
-            mock_config.cron_expression = "0 12 * * *"
-            mock_config.timezone = "America/New_York"
-            mock_service.extract_schedule_config.return_value = mock_config
-
-            mock_updated_plan = Mock(spec=WorkflowSchedulePlan)
-            mock_service.update_schedule.return_value = mock_updated_plan
-
-            workflow = Mock(spec=Workflow)
-            result = sync_schedule_from_workflow("tenant-id", "app-id", workflow)
-
-            assert result == mock_updated_plan
-            mock_service.update_schedule.assert_called_once()
-            # Verify the arguments passed to update_schedule
-            call_args = mock_service.update_schedule.call_args
-            assert call_args.kwargs["session"] == mock_session
-            assert call_args.kwargs["schedule_id"] == "existing-plan-id"
-            updates_obj = call_args.kwargs["updates"]
-            assert isinstance(updates_obj, SchedulePlanUpdate)
-            assert updates_obj.node_id == "start"
-            assert updates_obj.cron_expression == "0 12 * * *"
-            assert updates_obj.timezone == "America/New_York"
-            mock_session.commit.assert_not_called()
-
-    @patch("events.event_handlers.sync_workflow_schedule_when_app_published.db")
-    @patch("events.event_handlers.sync_workflow_schedule_when_app_published.ScheduleService")
-    @patch("events.event_handlers.sync_workflow_schedule_when_app_published.select")
-    def test_sync_schedule_remove_when_no_config(self, mock_select, mock_service, mock_db):
-        """Test removing schedule when no schedule config in workflow."""
-        mock_session = MagicMock()
-        mock_db.engine = MagicMock()
-        mock_session.__enter__ = MagicMock(return_value=mock_session)
-        mock_session.__exit__ = MagicMock(return_value=None)
-        sessionmaker = MagicMock(return_value=MagicMock(begin=MagicMock(return_value=mock_session)))
-
-        with patch("events.event_handlers.sync_workflow_schedule_when_app_published.sessionmaker", sessionmaker):
-            mock_existing_plan = Mock(spec=WorkflowSchedulePlan)
-            mock_existing_plan.id = "existing-plan-id"
-            mock_session.scalar.return_value = mock_existing_plan
-
-            mock_service.extract_schedule_config.return_value = None  # No schedule config
-
-            workflow = Mock(spec=Workflow)
-            result = sync_schedule_from_workflow("tenant-id", "app-id", workflow)
-
-            assert result is None
-            # Now using ScheduleService.delete_schedule instead of session.delete
-            mock_service.delete_schedule.assert_called_once_with(session=mock_session, schedule_id="existing-plan-id")
-            mock_session.commit.assert_not_called()
-
-
 @pytest.fixture
 def session_mock() -> MagicMock:
     return MagicMock(spec=Session)
@@ -789,62 +506,6 @@ def _workflow(**kwargs: Any) -> Workflow:
     return cast(Workflow, SimpleNamespace(**kwargs))


-def test_update_schedule_should_update_only_node_id_without_recomputing_time(
-    session_mock: MagicMock,
-    monkeypatch: pytest.MonkeyPatch,
-) -> None:
-    # Arrange
-    schedule = MagicMock(spec=WorkflowSchedulePlan)
-    schedule.cron_expression = "0 10 * * *"
-    schedule.timezone = "UTC"
-    session_mock.get.return_value = schedule
-
-    next_run_mock = MagicMock(return_value=datetime(2026, 1, 1, 10, 0, tzinfo=UTC))
-    monkeypatch.setattr(service_module, "calculate_next_run_at", next_run_mock)
-
-    # Act
-    result = ScheduleService.update_schedule(
-        session=session_mock,
-        schedule_id="schedule-1",
-        updates=SchedulePlanUpdate(node_id="node-new"),
-    )
-
-    # Assert
-    assert result is schedule
-    assert schedule.node_id == "node-new"
-    next_run_mock.assert_not_called()
-    session_mock.flush.assert_called_once()
-
-
-def test_get_tenant_owner_should_raise_when_account_record_missing(session_mock: MagicMock) -> None:
-    # Arrange
-    join = SimpleNamespace(account_id="account-404")
-    session_mock.execute.return_value.scalar_one_or_none.return_value = join
-    session_mock.get.return_value = None
-
-    # Act / Assert
-    with pytest.raises(AccountNotFoundError, match="Account not found: account-404"):
-        ScheduleService.get_tenant_owner(session=session_mock, tenant_id="tenant-1")
-
-
-def test_get_tenant_owner_should_raise_when_no_owner_or_admin_found(session_mock: MagicMock) -> None:
-    # Arrange
-    session_mock.execute.return_value.scalar_one_or_none.side_effect = [None, None]
-
-    # Act / Assert
-    with pytest.raises(AccountNotFoundError, match="Account not found for tenant: tenant-1"):
-        ScheduleService.get_tenant_owner(session=session_mock, tenant_id="tenant-1")
-
-
-def test_update_next_run_at_should_raise_when_schedule_not_found(session_mock: MagicMock) -> None:
-    # Arrange
-    session_mock.get.return_value = None
-
-    # Act / Assert
-    with pytest.raises(ScheduleNotFoundError, match="Schedule not found: schedule-1"):
-        ScheduleService.update_next_run_at(session=session_mock, schedule_id="schedule-1")
-
-
 def test_to_schedule_config_should_build_from_cron_mode() -> None:
     # Arrange
     node_config: dict[str, Any] = {
diff --git a/api/tests/unit_tests/services/test_trigger_provider_service.py b/api/tests/unit_tests/services/test_trigger_provider_service.py
index bd2e936b62..ebf1b36610 100644
--- a/api/tests/unit_tests/services/test_trigger_provider_service.py
+++ b/api/tests/unit_tests/services/test_trigger_provider_service.py
@@ -3,6 +3,7 @@ from __future__ import annotations

 import contextlib
 import json
 from types import SimpleNamespace
+from typing import Any
 from unittest.mock import MagicMock

 import pytest
@@ -28,9 +29,9 @@ def _mock_get_trigger_provider(mocker: MockerFixture, provider: object | None)

 def _encrypter_mock(
     *,
-    decrypted: dict | None = None,
-    encrypted: dict | None = None,
-    masked: dict | None = None,
+    decrypted: dict[str, Any] | None = None,
+    encrypted: dict[str, Any] | None = None,
+    masked: dict[str, Any] | None = None,
 ) -> MagicMock:
     enc = MagicMock()
     enc.decrypt.return_value = decrypted or {}
diff --git a/api/tests/unit_tests/services/test_website_service.py b/api/tests/unit_tests/services/test_website_service.py
index b0ddc7388a..2024aec13a 100644
--- a/api/tests/unit_tests/services/test_website_service.py
+++ b/api/tests/unit_tests/services/test_website_service.py
@@ -89,7 +89,7 @@ def test_website_crawl_api_request_from_args_valid_and_to_crawl_request() -> Non
         ({"provider": "firecrawl", "url": "https://example.com"}, "Options are required"),
     ],
 )
-def test_website_crawl_api_request_from_args_requires_fields(args: dict, missing_msg: str) -> None:
+def test_website_crawl_api_request_from_args_requires_fields(args: dict[str, Any], missing_msg: str) -> None:
     with pytest.raises(ValueError, match=missing_msg):
         WebsiteCrawlApiRequest.from_args(args)
diff --git a/api/tests/unit_tests/services/test_workflow_service.py b/api/tests/unit_tests/services/test_workflow_service.py
index 406b4fb9d0..7906daacfc 100644
--- a/api/tests/unit_tests/services/test_workflow_service.py
+++ b/api/tests/unit_tests/services/test_workflow_service.py
@@ -94,8 +94,8 @@ class TestWorkflowAssociatedDataFactory:
         app_id: str = "app-123",
         version: str = Workflow.VERSION_DRAFT,
         workflow_type: str = WorkflowType.WORKFLOW.value,
-        graph: dict | None = None,
-        features: dict | None = None,
+        graph: dict[str, Any] | None = None,
+        features: dict[str, Any] | None = None,
         unique_hash: str | None = None,
         **kwargs,
     ) -> MagicMock:
@@ -1686,7 +1686,7 @@ class TestWorkflowServiceCredentialValidation:
         """Missing provider or model in node_data should be a no-op."""
         # Arrange
         workflow = self._make_workflow([])
-        node_data: dict = {}  # no model key
+        node_data: dict[str, Any] = {}  # no model key

         # Act + Assert (no error expected)
         service._validate_load_balancing_credentials(workflow, node_data, "node-1")
@@ -2269,7 +2269,7 @@ class TestRebuildFileForUserInputsInStartNode:
         # Arrange
         file_var = self._make_variable("attachment", VariableEntityType.FILE)
         start_data = self._make_start_node_data([file_var])
-        user_inputs: dict = {}  # attachment not provided
+        user_inputs: dict[str, Any] = {}  # attachment not provided

         # Act
         result = _rebuild_file_for_user_inputs_in_start_node(
diff --git a/api/tests/unit_tests/services/vector_service.py b/api/tests/unit_tests/services/vector_service.py
index ee9ba1c6d6..ad80beb4e3 100644
--- a/api/tests/unit_tests/services/vector_service.py
+++ b/api/tests/unit_tests/services/vector_service.py
@@ -114,6 +114,7 @@ This test suite follows a comprehensive testing strategy that covers:
 ================================================================================
 """

+from typing import Any
 from unittest.mock import Mock, patch

 import pytest
@@ -156,7 +157,7 @@ class VectorServiceTestDataFactory:
         indexing_technique: str = IndexTechniqueType.HIGH_QUALITY,
         embedding_model_provider: str = "openai",
         embedding_model: str = "text-embedding-ada-002",
-        index_struct_dict: dict | None = None,
+        index_struct_dict: dict[str, Any] | None = None,
         **kwargs,
     ) -> Mock:
         """
diff --git a/api/tests/unit_tests/tools/test_mcp_tool.py b/api/tests/unit_tests/tools/test_mcp_tool.py
index 68359ba078..544e89fcee 100644
--- a/api/tests/unit_tests/tools/test_mcp_tool.py
+++ b/api/tests/unit_tests/tools/test_mcp_tool.py
@@ -1,5 +1,6 @@
 import base64
 from decimal import Decimal
+from typing import Any
 from unittest.mock import Mock, patch

 import pytest
@@ -20,7 +21,7 @@ from core.tools.entities.tool_entities import ToolEntity, ToolIdentity, ToolInvo
 from core.tools.mcp_tool.tool import MCPTool


-def _make_mcp_tool(output_schema: dict | None = None) -> MCPTool:
+def _make_mcp_tool(output_schema: dict[str, Any] | None = None) -> MCPTool:
     identity = ToolIdentity(
         author="test",
         name="test_mcp_tool",
diff --git a/api/uv.lock b/api/uv.lock
index 71d3a14880..9ed8d16107 100644
--- a/api/uv.lock
+++ b/api/uv.lock
@@ -42,6 +42,7 @@ members = [
     "dify-vdb-vikingdb",
    "dify-vdb-weaviate",
 ]
+overrides = [{ name = "pyarrow", specifier = ">=18.0.0" }]

 [[package]]
 name = "abnf"
@@ -4986,20 +4987,17 @@ wheels = [

 [[package]]
 name = "pyarrow"
-version = "14.0.2"
+version = "23.0.1"
 source = { registry = "https://pypi.org/simple" }
-dependencies = [
-    { name = "numpy" },
-]
-sdist = { url = "https://files.pythonhosted.org/packages/d7/8b/d18b7eb6fb22e5ed6ffcbc073c85dae635778dbd1270a6cf5d750b031e84/pyarrow-14.0.2.tar.gz", hash = "sha256:36cef6ba12b499d864d1def3e990f97949e0b79400d08b7cf74504ffbd3eb025", size = 1063645, upload-time = "2023-12-18T15:43:41.625Z" }
+sdist = { url = "https://files.pythonhosted.org/packages/88/22/134986a4cc224d593c1afde5494d18ff629393d74cc2eddb176669f234a4/pyarrow-23.0.1.tar.gz", hash = "sha256:b8c5873e33440b2bc2f4a79d2b47017a89c5a24116c055625e6f2ee50523f019", size = 1167336, upload-time = "2026-02-16T10:14:12.39Z" }
 wheels = [
-    { url = "https://files.pythonhosted.org/packages/69/5b/d8ab6c20c43b598228710e4e4a6cba03a01f6faa3d08afff9ce76fd0fd47/pyarrow-14.0.2-cp312-cp312-macosx_10_14_x86_64.whl", hash = "sha256:c87824a5ac52be210d32906c715f4ed7053d0180c1060ae3ff9b7e560f53f944", size = 26819585, upload-time = "2023-12-18T15:41:27.59Z" },
-    { url = "https://files.pythonhosted.org/packages/2d/29/bed2643d0dd5e9570405244a61f6db66c7f4704a6e9ce313f84fa5a3675a/pyarrow-14.0.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:a25eb2421a58e861f6ca91f43339d215476f4fe159eca603c55950c14f378cc5", size = 23965222, upload-time = "2023-12-18T15:41:32.449Z" },
-    { url = "https://files.pythonhosted.org/packages/2a/34/da464632e59a8cdd083370d69e6c14eae30221acb284f671c6bc9273fadd/pyarrow-14.0.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5c1da70d668af5620b8ba0a23f229030a4cd6c5f24a616a146f30d2386fec422", size = 35942036, upload-time = "2023-12-18T15:41:38.767Z" },
-    { url = "https://files.pythonhosted.org/packages/a8/ff/cbed4836d543b29f00d2355af67575c934999ff1d43e3f438ab0b1b394f1/pyarrow-14.0.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2cc61593c8e66194c7cdfae594503e91b926a228fba40b5cf25cc593563bcd07", size = 38089266, upload-time = "2023-12-18T15:41:47.617Z" },
-    { url = "https://files.pythonhosted.org/packages/38/41/345011cb831d3dbb2dab762fc244c745a5df94b199223a99af52a5f7dff6/pyarrow-14.0.2-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:78ea56f62fb7c0ae8ecb9afdd7893e3a7dbeb0b04106f5c08dbb23f9c0157591", size = 35404468, upload-time = "2023-12-18T15:41:54.49Z" },
-    { url = "https://files.pythonhosted.org/packages/fd/af/2fc23ca2068ff02068d8dabf0fb85b6185df40ec825973470e613dbd8790/pyarrow-14.0.2-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:37c233ddbce0c67a76c0985612fef27c0c92aef9413cf5aa56952f359fcb7379", size = 38003134, upload-time = "2023-12-18T15:42:01.593Z" },
-    { url = "https://files.pythonhosted.org/packages/95/1f/9d912f66a87e3864f694e000977a6a70a644ea560289eac1d733983f215d/pyarrow-14.0.2-cp312-cp312-win_amd64.whl", hash = "sha256:e4b123ad0f6add92de898214d404e488167b87b5dd86e9a434126bc2b7a5578d", size = 25043754, upload-time = "2023-12-18T15:42:07.108Z" },
+    { url = "https://files.pythonhosted.org/packages/9a/4b/4166bb5abbfe6f750fc60ad337c43ecf61340fa52ab386da6e8dbf9e63c4/pyarrow-23.0.1-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:f4b0dbfa124c0bb161f8b5ebb40f1a680b70279aa0c9901d44a2b5a20806039f", size = 34214575, upload-time = "2026-02-16T10:09:56.225Z" },
+    { url = "https://files.pythonhosted.org/packages/e1/da/3f941e3734ac8088ea588b53e860baeddac8323ea40ce22e3d0baa865cc9/pyarrow-23.0.1-cp312-cp312-macosx_12_0_x86_64.whl", hash = "sha256:7707d2b6673f7de054e2e83d59f9e805939038eebe1763fe811ee8fa5c0cd1a7", size = 35832540, upload-time = "2026-02-16T10:10:03.428Z" },
+    { url = "https://files.pythonhosted.org/packages/88/7c/3d841c366620e906d54430817531b877ba646310296df42ef697308c2705/pyarrow-23.0.1-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:86ff03fb9f1a320266e0de855dee4b17da6794c595d207f89bba40d16b5c78b9", size = 44470940, upload-time = "2026-02-16T10:10:10.704Z" },
+    { url = "https://files.pythonhosted.org/packages/2c/a5/da83046273d990f256cb79796a190bbf7ec999269705ddc609403f8c6b06/pyarrow-23.0.1-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:813d99f31275919c383aab17f0f455a04f5a429c261cc411b1e9a8f5e4aaaa05", size = 47586063, upload-time = "2026-02-16T10:10:17.95Z" },
+    { url = "https://files.pythonhosted.org/packages/5b/3c/b7d2ebcff47a514f47f9da1e74b7949138c58cfeb108cdd4ee62f43f0cf3/pyarrow-23.0.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:bf5842f960cddd2ef757d486041d57c96483efc295a8c4a0e20e704cbbf39c67", size = 48173045, upload-time = "2026-02-16T10:10:25.363Z" },
+    { url = "https://files.pythonhosted.org/packages/43/b2/b40961262213beaba6acfc88698eb773dfce32ecdf34d19291db94c2bd73/pyarrow-23.0.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:564baf97c858ecc03ec01a41062e8f4698abc3e6e2acd79c01c2e97880a19730", size = 50621741, upload-time = "2026-02-16T10:10:33.477Z" },
+    { url = "https://files.pythonhosted.org/packages/f6/70/1fdda42d65b28b078e93d75d371b2185a61da89dda4def8ba6ba41ebdeb4/pyarrow-23.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:07deae7783782ac7250989a7b2ecde9b3c343a643f82e8a4df03d93b633006f0", size = 27620678, upload-time = "2026-02-16T10:10:39.31Z" },
 ]

 [[package]]
diff --git a/docker/.env.example b/docker/.env.example
index 4426a882f1..856b04a3df 100644
--- a/docker/.env.example
+++ b/docker/.env.example
@@ -351,6 +351,9 @@ REDIS_SSL_CERTFILE=
 REDIS_SSL_KEYFILE=
 # Path to client private key file for SSL authentication
 REDIS_DB=0
+# Optional global prefix for Redis keys, topics, streams, and Celery Redis transport artifacts.
+# Leave empty to preserve current unprefixed behavior.
+REDIS_KEY_PREFIX=
 # Optional: limit total Redis connections used by API/Worker (unset for default)
 # Align with API's REDIS_MAX_CONNECTIONS in configs
 REDIS_MAX_CONNECTIONS=
diff --git a/docker/README.md b/docker/README.md
index 4c40317f37..3130fa9886 100644
--- a/docker/README.md
+++ b/docker/README.md
@@ -88,6 +88,7 @@ The `.env.example` file provided in the Docker setup is extensive and covers a w
 1. **Redis Configuration**:

    - `REDIS_HOST`, `REDIS_PORT`, `REDIS_PASSWORD`: Redis server connection settings.
+   - `REDIS_KEY_PREFIX`: Optional global namespace prefix for Redis keys, topics, streams, and Celery Redis transport artifacts.

 1. **Celery Configuration**:
diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml
index 1fc1cfdf9e..c1ddba4f80 100644
--- a/docker/docker-compose.yaml
+++ b/docker/docker-compose.yaml
@@ -90,6 +90,7 @@ x-shared-env: &shared-api-worker-env
   REDIS_SSL_CERTFILE: ${REDIS_SSL_CERTFILE:-}
   REDIS_SSL_KEYFILE: ${REDIS_SSL_KEYFILE:-}
   REDIS_DB: ${REDIS_DB:-0}
+  REDIS_KEY_PREFIX: ${REDIS_KEY_PREFIX:-}
   REDIS_MAX_CONNECTIONS: ${REDIS_MAX_CONNECTIONS:-}
   REDIS_USE_SENTINEL: ${REDIS_USE_SENTINEL:-false}
   REDIS_SENTINELS: ${REDIS_SENTINELS:-}
diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml
index 8901c7948f..4444981601 100644
--- a/pnpm-lock.yaml
+++ b/pnpm-lock.yaml
@@ -24,9 +24,6 @@ catalogs:
     '@cucumber/cucumber':
       specifier: 12.8.0
       version: 12.8.0
-    '@date-fns/tz':
-      specifier: 1.4.1
-      version: 1.4.1
     '@egoist/tailwindcss-icons':
       specifier: 1.9.2
       version: 1.9.2
@@ -270,9 +267,6 @@ catalogs:
     cron-parser:
       specifier: 5.5.0
       version: 5.5.0
-    date-fns:
-      specifier: 4.1.0
-      version: 4.1.0
     dayjs:
       specifier: 1.11.20
       version: 1.11.20
@@ -655,9 +649,6 @@ importers:
       '@base-ui/react':
         specifier: 'catalog:'
        version: 1.4.0(@date-fns/tz@1.4.1)(@types/react@19.2.14)(date-fns@4.1.0)(react-dom@19.2.5(react@19.2.5))(react@19.2.5)
-      '@date-fns/tz':
-        specifier: 'catalog:'
-        version: 1.4.1
       '@emoji-mart/data':
         specifier: 'catalog:'
         version: 1.2.1
@@ -760,9 +751,6 @@ importers:
       cron-parser:
         specifier: 'catalog:'
         version: 5.5.0
-      date-fns:
-        specifier: 'catalog:'
-        version: 4.1.0
       dayjs:
         specifier: 'catalog:'
         version: 1.11.20
diff --git a/web/__tests__/apps/app-card-operations-flow.test.tsx b/web/__tests__/apps/app-card-operations-flow.test.tsx
index 765c7045e5..8e45367db4 100644
--- a/web/__tests__/apps/app-card-operations-flow.test.tsx
+++ b/web/__tests__/apps/app-card-operations-flow.test.tsx
@@ -207,20 +207,6 @@ vi.mock('@/app/components/app/switch-app-modal', () => ({
   },
 }))

-vi.mock('@/app/components/base/confirm', () => ({
-  default: ({ isShow, onConfirm, onCancel, title }: Record) => {
-    if (!isShow)
-      return null
-    return (
-      {title as string}
-
-
-    )
-  },
-}))

 vi.mock('@/app/components/workflow/dsl-export-confirm-modal', () => ({
   default: ({ onConfirm, onClose }: Record) => (
@@ -342,14 +328,16 @@ describe('App Card Operations Flow', () => {
       fireEvent.click(deleteBtn)
     })

-    const confirmBtn = screen.queryByTestId('confirm-delete')
-    if (confirmBtn) {
-      fireEvent.click(confirmBtn)
+    await waitFor(() => {
+      expect(screen.getByText('app.deleteAppConfirmTitle')).toBeInTheDocument()
+    })

-      await waitFor(() => {
-        expect(mockDeleteAppMutation).toHaveBeenCalledWith('app-to-delete')
-      })
-    }
+    fireEvent.change(screen.getByRole('textbox'), { target: { value: 'Deletable App' } })
+    fireEvent.click(screen.getByRole('button', { name: 'common.operation.confirm' }))
+
+    await waitFor(() => {
+      expect(mockDeleteAppMutation).toHaveBeenCalledWith('app-to-delete')
+    })
   }
   })
 })
diff --git a/web/__tests__/tools/tool-provider-detail-flow.test.tsx b/web/__tests__/tools/tool-provider-detail-flow.test.tsx
index 3d66467695..ce5ffe531e 100644
--- a/web/__tests__/tools/tool-provider-detail-flow.test.tsx
+++ b/web/__tests__/tools/tool-provider-detail-flow.test.tsx
@@ -133,26 +133,6 @@ vi.mock('@/app/components/base/drawer', () => ({
   ),
 }))

-vi.mock('@/app/components/base/confirm', () => ({
-  default: ({ title, isShow, onConfirm, onCancel }: {
-    title: string
-    content: string
-    isShow: boolean
-    onConfirm: () => void
-    onCancel: () => void
-  }) => (
-    isShow
-      ? (
-        {title}
-
-
-      )
-      : null
-  ),
-}))

 vi.mock('@/app/components/base/ui/toast', () => ({
   default: { notify: vi.fn() },
   toast: {
@@ -242,6 +222,8 @@ vi.mock('@/app/components/tools/provider/tool-item', () => ({

 const { default: ProviderDetail } = await import('@/app/components/tools/provider/detail')

+const getDeleteConfirmButton = () => screen.getByRole('button', { name: /operation\.confirm$/ })
+
 const makeCollection = (overrides: Partial = {}): Collection => ({
   id: 'test-collection',
   name: 'test_collection',
@@ -465,11 +447,10 @@ describe('Tool Provider Detail Flow Integration', () => {
       fireEvent.click(screen.getByTestId('custom-modal-remove'))

       await waitFor(() => {
-        expect(screen.getByTestId('confirm-dialog')).toBeInTheDocument()
         expect(screen.getByText('Delete Tool')).toBeInTheDocument()
       })

-      fireEvent.click(screen.getByTestId('confirm-ok'))
+      fireEvent.click(getDeleteConfirmButton())
       await waitFor(() => {
         expect(mockRemoveCustomCollection).toHaveBeenCalledWith('test_collection')
         expect(mockOnRefreshData).toHaveBeenCalled()
@@ -527,10 +508,10 @@ describe('Tool Provider Detail Flow Integration', () => {
       fireEvent.click(screen.getByTestId('wf-modal-remove'))

       await waitFor(() => {
-        expect(screen.getByTestId('confirm-dialog')).toBeInTheDocument()
+        expect(screen.getByText('Delete Tool')).toBeInTheDocument()
       })

-      fireEvent.click(screen.getByTestId('confirm-ok'))
+      fireEvent.click(getDeleteConfirmButton())
       await waitFor(() => {
         expect(mockDeleteWorkflowTool).toHaveBeenCalledWith('test-collection')
         expect(mockOnRefreshData).toHaveBeenCalled()
diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/provider-config-modal.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/provider-config-modal.tsx
index 72913b4934..c63d482c8f 100644
--- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/provider-config-modal.tsx
+++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/provider-config-modal.tsx
@@ -5,8 +5,6 @@ import { useBoolean } from 'ahooks'
 import * as React from 'react'
 import { useCallback, useState } from 'react'
 import { useTranslation } from 'react-i18next'
-import Button from '@/app/components/base/button'
-import Confirm from '@/app/components/base/confirm'
 import Divider from '@/app/components/base/divider'
 import { LinkExternal02 } from '@/app/components/base/icons/src/vender/line/general'
 import { Lock01 } from '@/app/components/base/icons/src/vender/solid/security'
@@ -14,6 +12,16 @@ import {
   PortalToFollowElem,
   PortalToFollowElemContent,
 } from '@/app/components/base/portal-to-follow-elem'
+import { Button } from '@/app/components/base/ui/button'
+import {
+  AlertDialog,
+  AlertDialogActions,
+  AlertDialogCancelButton,
+  AlertDialogConfirmButton,
+  AlertDialogContent,
+  AlertDialogDescription,
+  AlertDialogTitle,
+} from '@/app/components/base/ui/alert-dialog'
 import { toast } from '@/app/components/base/ui/toast'
 import { addTracingConfig, removeTracingConfig, updateTracingConfig } from '@/service/apps'
 import { docURL } from './config'
@@ -621,7 +629,7 @@ const ProviderConfigModal: FC = ({
-
+
@@ -679,14 +687,24 @@ const ProviderConfigModal: FC = ({
           )
         : (
-
+           !open && hideRemoveConfirm()}>
+
+
+              {t(`${I18N_PREFIX}.removeConfirmTitle`, { ns: 'app', key: t(`tracing.${type}.title`, { ns: 'app' }) })!}
+
+
+              {t(`${I18N_PREFIX}.removeConfirmContent`, { ns: 'app' })}
+
+
+              {t('operation.cancel', { ns: 'common' })}
+
+              {t('operation.confirm', { ns: 'common' })}
+
+
+
       )}
     )
diff --git a/web/app/(humanInputLayout)/form/[token]/form.tsx b/web/app/(humanInputLayout)/form/[token]/form.tsx
index 221420aade..898dab8f4a 100644
--- a/web/app/(humanInputLayout)/form/[token]/form.tsx
+++ b/web/app/(humanInputLayout)/form/[token]/form.tsx
@@ -1,5 +1,5 @@
 'use client'
-import type { ButtonProps } from '@/app/components/base/button'
+import type { ButtonProps } from '@/app/components/base/ui/button'
 import type { FormInputItem, UserAction } from '@/app/components/workflow/nodes/human-input/types'
 import type { SiteInfo } from '@/models/share'
 import type { HumanInputFormError } from '@/service/use-share'
@@ -13,12 +13,12 @@ import * as React from 'react'
 import { useEffect, useMemo, useState } from 'react'
 import { useTranslation } from 'react-i18next'
 import AppIcon from '@/app/components/base/app-icon'
-import Button from '@/app/components/base/button'
 import ContentItem from '@/app/components/base/chat/chat/answer/human-input-content/content-item'
 import ExpirationTime from '@/app/components/base/chat/chat/answer/human-input-content/expiration-time'
 import { getButtonStyle } from '@/app/components/base/chat/chat/answer/human-input-content/utils'
 import Loading from '@/app/components/base/loading'
 import DifyLogo from '@/app/components/base/logo/dify-logo'
+import { Button } from '@/app/components/base/ui/button'
 import useDocumentTitle from '@/hooks/use-document-title'
 import { useParams } from '@/next/navigation'
 import { useGetHumanInputForm, useSubmitHumanInputForm } from '@/service/use-share'
@@ -100,7 +100,7 @@ const FormContent = () => {
   if (success) {
     return (
-
+
           {t('humanInput.thanks', { ns: 'share' })}
           {t('humanInput.recorded', { ns: 'share' })}
-          {t('humanInput.submissionID', { id: token, ns: 'share' })}
+          {t('humanInput.submissionID', { id: token, ns: 'share' })}
   if (expired) {
     return (
-
+
           {t('humanInput.sorry', { ns: 'share' })}
           {t('humanInput.expired', { ns: 'share' })}
-          {t('humanInput.submissionID', { id: token, ns: 'share' })}
+          {t('humanInput.submissionID', { id: token, ns: 'share' })}
   if (submitted) {
     return (
-
+
           {t('humanInput.sorry', { ns: 'share' })}
           {t('humanInput.completed', { ns: 'share' })}
-          {t('humanInput.submissionID', { id: token, ns: 'share' })}
+          {t('humanInput.submissionID', { id: token, ns: 'share' })}
   if (rateLimitExceeded) {
     return (
-
+
@@ -210,7 +210,7 @@ const FormContent = () => {
   if (!formData) {
     return (
-
+
@@ -245,7 +245,7 @@ const FormContent = () => {
             background={site.icon_background}
             imageUrl={site.icon_url}
           />
-          {site.title}
+          {site.title}
diff --git a/web/app/(shareLayout)/webapp-reset-password/check-code/page.tsx b/web/app/(shareLayout)/webapp-reset-password/check-code/page.tsx
index b71e6b4767..d19e5a7d2d 100644
--- a/web/app/(shareLayout)/webapp-reset-password/check-code/page.tsx
+++ b/web/app/(shareLayout)/webapp-reset-password/check-code/page.tsx
@@ -2,8 +2,8 @@
 import { RiArrowLeftLine, RiMailSendFill } from '@remixicon/react'
 import { useState } from 'react'
 import { useTranslation } from 'react-i18next'
-import Button from '@/app/components/base/button'
 import Input from '@/app/components/base/input'
+import { Button } from '@/app/components/base/ui/button'
 import { toast } from '@/app/components/base/ui/toast'
 import Countdown from '@/app/components/signin/countdown'
 import { useLocale } from '@/context/i18n'
@@ -62,9 +62,9 @@ export default function CheckCode() {
-
+
         {t('checkCode.checkYourEmail', { ns: 'login' })}
-
+
         {t('checkCode.tipsPrefix', { ns: 'login' })}
         {email}
@@ -76,7 +76,7 @@ export default function CheckCode() {
-
+       setVerifyCode(e.target.value)} maxLength={6} className="mt-1" placeholder={t('checkCode.verificationCodePlaceholder', { ns: 'login' }) || ''} />
@@ -88,7 +88,7 @@ export default function CheckCode() {
-      {t('back', { ns: 'login' })}
+      {t('back', { ns: 'login' })}
 )
diff --git a/web/app/(shareLayout)/webapp-reset-password/page.tsx b/web/app/(shareLayout)/webapp-reset-password/page.tsx
index a25b4bb4ef..cb6ece219c 100644
--- a/web/app/(shareLayout)/webapp-reset-password/page.tsx
+++ b/web/app/(shareLayout)/webapp-reset-password/page.tsx
@@ -3,8 +3,8 @@
 import { RiArrowLeftLine, RiLockPasswordLine } from '@remixicon/react'
 import { noop } from 'es-toolkit/function'
 import { useState } from 'react'
 import { useTranslation } from 'react-i18next'
-import Button from '@/app/components/base/button'
 import Input from '@/app/components/base/input'
+import { Button } from '@/app/components/base/ui/button'
 import { toast } from '@/app/components/base/ui/toast'
 import { COUNT_DOWN_KEY, COUNT_DOWN_TIME_MS } from '@/app/components/signin/countdown'
 import { emailRegex } from '@/config'
@@ -64,9 +64,9 @@ export default function CheckCode() {
-
+
         {t('resetPassword', { ns: 'login' })}
-
+
         {t('resetPasswordDesc', { ns: 'login' })}
@@ -74,7 +74,7 @@ export default function CheckCode() {
-
+
         setEmail(e.target.value)} />
@@ -90,7 +90,7 @@ export default function CheckCode() {
-      {t('backToLogin', { ns: 'login' })}
+      {t('backToLogin', { ns: 'login' })}
 )
diff --git a/web/app/(shareLayout)/webapp-reset-password/set-password/page.tsx b/web/app/(shareLayout)/webapp-reset-password/set-password/page.tsx
index bc8f651d17..5b89084ea1 100644
--- a/web/app/(shareLayout)/webapp-reset-password/set-password/page.tsx
+++ b/web/app/(shareLayout)/webapp-reset-password/set-password/page.tsx
@@ -3,8 +3,8 @@
 import { RiCheckboxCircleFill } from '@remixicon/react'
 import { useCountDown } from 'ahooks'
 import { useCallback, useState } from 'react'
 import { useTranslation } from 'react-i18next'
-import Button from '@/app/components/base/button'
 import Input from '@/app/components/base/input'
+import { Button } from '@/app/components/base/ui/button'
 import { toast } from '@/app/components/base/ui/toast'
 import { validPassword } from '@/config'
 import { useRouter, useSearchParams } from '@/next/navigation'
@@ -91,7 +91,7 @@ const ChangePasswordForm = () => {
         {t('changePassword', { ns: 'login' })}
-
+
         {t('changePasswordTip', { ns: 'login' })}
@@ -100,7 +100,7 @@ const ChangePasswordForm = () => {
       {/* Password */}
-
-      {t('error.passwordInvalid', { ns: 'login' })}
+
+      {t('error.passwordInvalid', { ns: 'login' })}
       {/* Confirm Password */}
-
+
 )
diff --git a/web/app/(shareLayout)/webapp-signin/components/mail-and-code-auth.tsx b/web/app/(shareLayout)/webapp-signin/components/mail-and-code-auth.tsx
index fbd6b216df..f600dba8b2 100644
--- a/web/app/(shareLayout)/webapp-signin/components/mail-and-code-auth.tsx
+++ b/web/app/(shareLayout)/webapp-signin/components/mail-and-code-auth.tsx
@@ -1,8 +1,8 @@
 import { noop } from 'es-toolkit/function'
 import { useState } from 'react'
 import { useTranslation } from 'react-i18next'
-import Button from '@/app/components/base/button'
 import Input from '@/app/components/base/input'
+import { Button } from '@/app/components/base/ui/button'
 import { toast } from '@/app/components/base/ui/toast'
 import { COUNT_DOWN_KEY, COUNT_DOWN_TIME_MS } from '@/app/components/signin/countdown'
 import { emailRegex } from '@/config'
@@ -52,7 +52,7 @@ export default function MailAndCodeAuth() {
-
+
         setEmail(e.target.value)} />
diff --git a/web/app/(shareLayout)/webapp-signin/components/mail-and-password-auth.tsx b/web/app/(shareLayout)/webapp-signin/components/mail-and-password-auth.tsx
index 1e9355e7ba..7fe5363927 100644
--- a/web/app/(shareLayout)/webapp-signin/components/mail-and-password-auth.tsx
+++ b/web/app/(shareLayout)/webapp-signin/components/mail-and-password-auth.tsx
@@ -2,8 +2,8 @@
 import { noop } from 'es-toolkit/function'
 import { useCallback, useState } from 'react'
 import { useTranslation } from 'react-i18next'
-import Button from '@/app/components/base/button'
 import Input from '@/app/components/base/input'
+import { Button } from '@/app/components/base/ui/button'
 import { toast } from '@/app/components/base/ui/toast'
 import { emailRegex } from '@/config'
 import { useLocale } from '@/context/i18n'
@@ -103,7 +103,7 @@ export default function MailAndPasswordAuth({ isEmailSetup }: MailAndPasswordAut
   return (
-