diff --git a/.gemini/config.yaml b/.gemini/config.yaml new file mode 100644 index 0000000000..15c697730a --- /dev/null +++ b/.gemini/config.yaml @@ -0,0 +1,13 @@ +have_fun: false +memory_config: + disabled: false +code_review: + disable: true + comment_severity_threshold: MEDIUM + max_review_comments: -1 + pull_request_opened: + help: false + summary: false + code_review: false + include_drafts: false +ignore_patterns: [] diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 1bb7d06232..3f53811f85 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -36,7 +36,7 @@ /api/core/workflow/graph/ @laipz8200 @QuantumGhost /api/core/workflow/graph_events/ @laipz8200 @QuantumGhost /api/core/workflow/node_events/ @laipz8200 @QuantumGhost -/api/dify_graph/model_runtime/ @laipz8200 @QuantumGhost +/api/graphon/model_runtime/ @laipz8200 @WH-2099 # Backend - Workflow - Nodes (Agent, Iteration, Loop, LLM) /api/core/workflow/nodes/agent/ @Nov1c444 diff --git a/.github/actions/setup-web/action.yml b/.github/actions/setup-web/action.yml index 6f3b3c08b4..24af948732 100644 --- a/.github/actions/setup-web/action.yml +++ b/.github/actions/setup-web/action.yml @@ -4,10 +4,9 @@ runs: using: composite steps: - name: Setup Vite+ - uses: voidzero-dev/setup-vp@4a524139920f87f9f7080d3b8545acac019e1852 # v1.0.0 + uses: voidzero-dev/setup-vp@20553a7a7429c429a74894104a2835d7fed28a72 # v1.3.0 with: - node-version-file: web/.nvmrc + working-directory: web + node-version-file: .nvmrc cache: true - cache-dependency-path: web/pnpm-lock.yaml - run-install: | - cwd: ./web + run-install: true diff --git a/.github/workflows/style.yml b/.github/workflows/style.yml index 657a481f74..23ae36f7b1 100644 --- a/.github/workflows/style.yml +++ b/.github/workflows/style.yml @@ -84,20 +84,20 @@ jobs: if: steps.changed-files.outputs.any_changed == 'true' uses: ./.github/actions/setup-web + - name: Restore ESLint cache + if: steps.changed-files.outputs.any_changed == 'true' + id: eslint-cache-restore + uses: actions/cache/restore@668228422ae6a00e4ad889ee87cd7109ec5666a7 # v5.0.4 + with: + path: web/.eslintcache + key: ${{ runner.os }}-web-eslint-${{ hashFiles('web/package.json', 'web/pnpm-lock.yaml', 'web/eslint.config.mjs', 'web/eslint.constants.mjs', 'web/plugins/eslint/**') }}-${{ github.sha }} + restore-keys: | + ${{ runner.os }}-web-eslint-${{ hashFiles('web/package.json', 'web/pnpm-lock.yaml', 'web/eslint.config.mjs', 'web/eslint.constants.mjs', 'web/plugins/eslint/**') }}- + - name: Web style check if: steps.changed-files.outputs.any_changed == 'true' working-directory: ./web - run: | - vp run lint:ci - # pnpm run lint:report - # continue-on-error: true - - # - name: Annotate Code - # if: steps.changed-files.outputs.any_changed == 'true' && github.event_name == 'pull_request' - # uses: DerLev/eslint-annotations@51347b3a0abfb503fc8734d5ae31c4b151297fae - # with: - # eslint-report: web/eslint_report.json - # github-token: ${{ secrets.GITHUB_TOKEN }} + run: vp run lint:ci - name: Web tsslint if: steps.changed-files.outputs.any_changed == 'true' @@ -114,6 +114,13 @@ jobs: working-directory: ./web run: vp run knip + - name: Save ESLint cache + if: steps.changed-files.outputs.any_changed == 'true' && success() && steps.eslint-cache-restore.outputs.cache-hit != 'true' + uses: actions/cache/save@668228422ae6a00e4ad889ee87cd7109ec5666a7 # v5.0.4 + with: + path: web/.eslintcache + key: ${{ steps.eslint-cache-restore.outputs.cache-primary-key }} + superlinter: name: SuperLinter runs-on: ubuntu-latest diff --git a/.github/workflows/translate-i18n-claude.yml b/.github/workflows/translate-i18n-claude.yml index 84f8000a01..1869254295 100644 --- a/.github/workflows/translate-i18n-claude.yml +++ b/.github/workflows/translate-i18n-claude.yml @@ -120,7 +120,7 @@ jobs: - name: Run Claude Code for Translation Sync if: steps.detect_changes.outputs.CHANGED_FILES != '' - uses: anthropics/claude-code-action@6062f3709600659be5e47fcddf2cf76993c235c2 # v1.0.76 + uses: anthropics/claude-code-action@ff9acae5886d41a99ed4ec14b7dc147d55834722 # v1.0.77 with: anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }} github_token: ${{ secrets.GITHUB_TOKEN }} diff --git a/api/.env.example b/api/.env.example index 40e1c2dfdf..9672a99d55 100644 --- a/api/.env.example +++ b/api/.env.example @@ -353,6 +353,9 @@ BAIDU_VECTOR_DB_SHARD=1 BAIDU_VECTOR_DB_REPLICAS=3 BAIDU_VECTOR_DB_INVERTED_INDEX_ANALYZER=DEFAULT_ANALYZER BAIDU_VECTOR_DB_INVERTED_INDEX_PARSER_MODE=COARSE_MODE +BAIDU_VECTOR_DB_AUTO_BUILD_ROW_COUNT_INCREMENT=500 +BAIDU_VECTOR_DB_AUTO_BUILD_ROW_COUNT_INCREMENT_RATIO=0.05 +BAIDU_VECTOR_DB_REBUILD_INDEX_TIMEOUT_IN_SECONDS=300 # Upstash configuration UPSTASH_VECTOR_URL=your-server-url diff --git a/api/.importlinter b/api/.importlinter index a836d09088..c2841f64d2 100644 --- a/api/.importlinter +++ b/api/.importlinter @@ -1,10 +1,14 @@ [importlinter] root_packages = core - dify_graph + constants + context + graphon configs controllers extensions + factories + libs models tasks services @@ -22,40 +26,30 @@ layers = runtime entities containers = - dify_graph + graphon ignore_imports = - dify_graph.nodes.base.node -> dify_graph.graph_events - dify_graph.nodes.iteration.iteration_node -> dify_graph.graph_events - dify_graph.nodes.loop.loop_node -> dify_graph.graph_events + graphon.nodes.base.node -> graphon.graph_events + graphon.nodes.iteration.iteration_node -> graphon.graph_events + graphon.nodes.loop.loop_node -> graphon.graph_events - dify_graph.nodes.iteration.iteration_node -> dify_graph.graph_engine - dify_graph.nodes.loop.loop_node -> dify_graph.graph_engine + graphon.nodes.iteration.iteration_node -> graphon.graph_engine + graphon.nodes.loop.loop_node -> graphon.graph_engine # TODO(QuantumGhost): fix the import violation later - dify_graph.entities.pause_reason -> dify_graph.nodes.human_input.entities - -[importlinter:contract:workflow-infrastructure-dependencies] -name = Workflow Infrastructure Dependencies -type = forbidden -source_modules = - dify_graph -forbidden_modules = - extensions.ext_database - extensions.ext_redis -allow_indirect_imports = True -ignore_imports = - dify_graph.nodes.llm.node -> extensions.ext_database - dify_graph.model_runtime.model_providers.__base.ai_model -> extensions.ext_redis - dify_graph.model_runtime.model_providers.model_provider_factory -> extensions.ext_redis + graphon.entities.pause_reason -> graphon.nodes.human_input.entities [importlinter:contract:workflow-external-imports] name = Workflow External Imports type = forbidden source_modules = - dify_graph + graphon forbidden_modules = + constants configs + context controllers extensions + factories + libs models services tasks @@ -88,46 +82,14 @@ forbidden_modules = core.tools core.trigger core.variables -ignore_imports = - dify_graph.nodes.llm.llm_utils -> core.model_manager - dify_graph.nodes.llm.protocols -> core.model_manager - dify_graph.nodes.llm.llm_utils -> dify_graph.model_runtime.model_providers.__base.large_language_model - dify_graph.nodes.llm.node -> core.tools.signature - dify_graph.nodes.tool.tool_node -> core.callback_handler.workflow_tool_callback_handler - dify_graph.nodes.tool.tool_node -> core.tools.tool_engine - dify_graph.nodes.tool.tool_node -> core.tools.tool_manager - dify_graph.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.advanced_prompt_transform - dify_graph.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.simple_prompt_transform - dify_graph.nodes.parameter_extractor.parameter_extractor_node -> dify_graph.model_runtime.model_providers.__base.large_language_model - dify_graph.nodes.question_classifier.question_classifier_node -> core.prompt.simple_prompt_transform - dify_graph.nodes.parameter_extractor.parameter_extractor_node -> core.model_manager - dify_graph.nodes.question_classifier.question_classifier_node -> core.model_manager - dify_graph.nodes.tool.tool_node -> core.tools.utils.message_transformer - dify_graph.nodes.llm.node -> core.llm_generator.output_parser.errors - dify_graph.nodes.llm.node -> core.llm_generator.output_parser.structured_output - dify_graph.nodes.llm.node -> core.model_manager - dify_graph.nodes.llm.entities -> core.prompt.entities.advanced_prompt_entities - dify_graph.nodes.llm.node -> core.prompt.entities.advanced_prompt_entities - dify_graph.nodes.llm.node -> core.prompt.utils.prompt_message_util - dify_graph.nodes.parameter_extractor.entities -> core.prompt.entities.advanced_prompt_entities - dify_graph.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.entities.advanced_prompt_entities - dify_graph.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.utils.prompt_message_util - dify_graph.nodes.question_classifier.entities -> core.prompt.entities.advanced_prompt_entities - dify_graph.nodes.question_classifier.question_classifier_node -> core.prompt.utils.prompt_message_util - dify_graph.nodes.llm.node -> models.dataset - dify_graph.nodes.llm.file_saver -> core.tools.signature - dify_graph.nodes.llm.file_saver -> core.tools.tool_file_manager - dify_graph.nodes.tool.tool_node -> core.tools.errors - dify_graph.nodes.llm.node -> extensions.ext_database - dify_graph.nodes.llm.node -> models.model - dify_graph.nodes.tool.tool_node -> services - dify_graph.model_runtime.model_providers.__base.ai_model -> configs - dify_graph.model_runtime.model_providers.__base.ai_model -> extensions.ext_redis - dify_graph.model_runtime.model_providers.__base.large_language_model -> configs - dify_graph.model_runtime.model_providers.__base.text_embedding_model -> core.entities.embedding_type - dify_graph.model_runtime.model_providers.model_provider_factory -> configs - dify_graph.model_runtime.model_providers.model_provider_factory -> extensions.ext_redis - dify_graph.model_runtime.model_providers.model_provider_factory -> models.provider_ids + +[importlinter:contract:workflow-third-party-imports] +name = Workflow Third-Party Imports +type = forbidden +source_modules = + graphon +forbidden_modules = + sqlalchemy [importlinter:contract:rsc] name = RSC @@ -136,7 +98,7 @@ layers = graph_engine response_coordinator containers = - dify_graph.graph_engine + graphon.graph_engine [importlinter:contract:worker] name = Worker @@ -145,7 +107,7 @@ layers = graph_engine worker containers = - dify_graph.graph_engine + graphon.graph_engine [importlinter:contract:graph-engine-architecture] name = Graph Engine Architecture @@ -161,28 +123,28 @@ layers = worker_management domain containers = - dify_graph.graph_engine + graphon.graph_engine [importlinter:contract:domain-isolation] name = Domain Model Isolation type = forbidden source_modules = - dify_graph.graph_engine.domain + graphon.graph_engine.domain forbidden_modules = - dify_graph.graph_engine.worker_management - dify_graph.graph_engine.command_channels - dify_graph.graph_engine.layers - dify_graph.graph_engine.protocols + graphon.graph_engine.worker_management + graphon.graph_engine.command_channels + graphon.graph_engine.layers + graphon.graph_engine.protocols [importlinter:contract:worker-management] name = Worker Management type = forbidden source_modules = - dify_graph.graph_engine.worker_management + graphon.graph_engine.worker_management forbidden_modules = - dify_graph.graph_engine.orchestration - dify_graph.graph_engine.command_processing - dify_graph.graph_engine.event_management + graphon.graph_engine.orchestration + graphon.graph_engine.command_processing + graphon.graph_engine.event_management [importlinter:contract:graph-traversal-components] @@ -192,11 +154,11 @@ layers = edge_processor skip_propagator containers = - dify_graph.graph_engine.graph_traversal + graphon.graph_engine.graph_traversal [importlinter:contract:command-channels] name = Command Channels Independence type = independence modules = - dify_graph.graph_engine.command_channels.in_memory_channel - dify_graph.graph_engine.command_channels.redis_channel + graphon.graph_engine.command_channels.in_memory_channel + graphon.graph_engine.command_channels.redis_channel diff --git a/api/.ruff.toml b/api/.ruff.toml index b0947eb619..4b1252a861 100644 --- a/api/.ruff.toml +++ b/api/.ruff.toml @@ -100,7 +100,7 @@ ignore = [ "configs/*" = [ "N802", # invalid-function-name ] -"dify_graph/model_runtime/callbacks/base_callback.py" = ["T201"] +"graphon/model_runtime/callbacks/base_callback.py" = ["T201"] "core/workflow/callbacks/workflow_logging_callback.py" = ["T201"] "libs/gmpy2_pkcs10aep_cipher.py" = [ "N803", # invalid-argument-name diff --git a/api/commands/vector.py b/api/commands/vector.py index 4cf11c9ad1..cb7eb7c452 100644 --- a/api/commands/vector.py +++ b/api/commands/vector.py @@ -10,6 +10,7 @@ from configs import dify_config from core.rag.datasource.vdb.vector_factory import Vector from core.rag.datasource.vdb.vector_type import VectorType from core.rag.index_processor.constant.built_in_field import BuiltInField +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from core.rag.models.document import ChildDocument, Document from extensions.ext_database import db from models.dataset import Dataset, DatasetCollectionBinding, DatasetMetadata, DatasetMetadataBinding, DocumentSegment @@ -85,7 +86,7 @@ def migrate_annotation_vector_database(): dataset = Dataset( id=app.id, tenant_id=app.tenant_id, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, embedding_model_provider=dataset_collection_binding.provider_name, embedding_model=dataset_collection_binding.model_name, collection_binding_id=dataset_collection_binding.id, @@ -177,7 +178,9 @@ def migrate_knowledge_vector_database(): while True: try: stmt = ( - select(Dataset).where(Dataset.indexing_technique == "high_quality").order_by(Dataset.created_at.desc()) + select(Dataset) + .where(Dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY) + .order_by(Dataset.created_at.desc()) ) datasets = db.paginate(select=stmt, page=page, per_page=50, max_per_page=50, error_out=False) @@ -269,7 +272,7 @@ def migrate_knowledge_vector_database(): "dataset_id": segment.dataset_id, }, ) - if dataset_document.doc_form == "hierarchical_model": + if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX: child_chunks = segment.get_child_chunks() if child_chunks: child_documents = [] diff --git a/api/configs/middleware/vdb/baidu_vector_config.py b/api/configs/middleware/vdb/baidu_vector_config.py index 8f956745b1..c8e4f7309f 100644 --- a/api/configs/middleware/vdb/baidu_vector_config.py +++ b/api/configs/middleware/vdb/baidu_vector_config.py @@ -51,3 +51,18 @@ class BaiduVectorDBConfig(BaseSettings): description="Parser mode for inverted index in Baidu Vector Database (default is COARSE_MODE)", default="COARSE_MODE", ) + + BAIDU_VECTOR_DB_AUTO_BUILD_ROW_COUNT_INCREMENT: int = Field( + description="Auto build row count increment threshold (default is 500)", + default=500, + ) + + BAIDU_VECTOR_DB_AUTO_BUILD_ROW_COUNT_INCREMENT_RATIO: float = Field( + description="Auto build row count increment ratio threshold (default is 0.05)", + default=0.05, + ) + + BAIDU_VECTOR_DB_REBUILD_INDEX_TIMEOUT_IN_SECONDS: int = Field( + description="Timeout in seconds for rebuilding the index in Baidu Vector Database (default is 3600 seconds)", + default=300, + ) diff --git a/api/context/__init__.py b/api/context/__init__.py index 969e5f583d..8df37138e8 100644 --- a/api/context/__init__.py +++ b/api/context/__init__.py @@ -1,74 +1,36 @@ """ -Core Context - Framework-agnostic context management. +Application-layer context adapters. -This module provides context management that is independent of any specific -web framework. Framework-specific implementations register their context -capture functions at application initialization time. - -This ensures the workflow layer remains completely decoupled from Flask -or any other web framework. +Concrete execution-context implementations live here so `graphon` only +depends on injected context managers rather than framework state capture. """ -import contextvars -from collections.abc import Callable - -from dify_graph.context.execution_context import ( +from context.execution_context import ( + AppContext, + ContextProviderNotFoundError, ExecutionContext, + ExecutionContextBuilder, IExecutionContext, NullAppContext, + capture_current_context, + read_context, + register_context, + register_context_capturer, + reset_context_provider, ) - -# Global capturer function - set by framework-specific modules -_capturer: Callable[[], IExecutionContext] | None = None - - -def register_context_capturer(capturer: Callable[[], IExecutionContext]) -> None: - """ - Register a context capture function. - - This should be called by framework-specific modules (e.g., Flask) - during application initialization. - - Args: - capturer: Function that captures current context and returns IExecutionContext - """ - global _capturer - _capturer = capturer - - -def capture_current_context() -> IExecutionContext: - """ - Capture current execution context. - - This function uses the registered context capturer. If no capturer - is registered, it returns a minimal context with only contextvars - (suitable for non-framework environments like tests or standalone scripts). - - Returns: - IExecutionContext with captured context - """ - if _capturer is None: - # No framework registered - return minimal context - return ExecutionContext( - app_context=NullAppContext(), - context_vars=contextvars.copy_context(), - ) - - return _capturer() - - -def reset_context_provider() -> None: - """ - Reset the context capturer. - - This is primarily useful for testing to ensure a clean state. - """ - global _capturer - _capturer = None - +from context.models import SandboxContext __all__ = [ + "AppContext", + "ContextProviderNotFoundError", + "ExecutionContext", + "ExecutionContextBuilder", + "IExecutionContext", + "NullAppContext", + "SandboxContext", "capture_current_context", + "read_context", + "register_context", "register_context_capturer", "reset_context_provider", ] diff --git a/api/dify_graph/context/execution_context.py b/api/context/execution_context.py similarity index 60% rename from api/dify_graph/context/execution_context.py rename to api/context/execution_context.py index e3007530f0..ba9a24d4f3 100644 --- a/api/dify_graph/context/execution_context.py +++ b/api/context/execution_context.py @@ -1,5 +1,8 @@ """ -Execution Context - Abstracted context management for workflow execution. +Application-layer execution context adapters. + +Concrete context capture lives outside `graphon` so the graph package only +consumes injected context managers when it needs to preserve thread-local state. """ import contextvars @@ -16,33 +19,33 @@ class AppContext(ABC): """ Abstract application context interface. - This abstraction allows workflow execution to work with or without Flask - by providing a common interface for application context management. + Application adapters can implement this to restore framework-specific state + such as Flask app context around worker execution. """ @abstractmethod def get_config(self, key: str, default: Any = None) -> Any: """Get configuration value by key.""" - pass + raise NotImplementedError @abstractmethod def get_extension(self, name: str) -> Any: - """Get Flask extension by name (e.g., 'db', 'cache').""" - pass + """Get application extension by name.""" + raise NotImplementedError @abstractmethod def enter(self) -> AbstractContextManager[None]: """Enter the application context.""" - pass + raise NotImplementedError @runtime_checkable class IExecutionContext(Protocol): """ - Protocol for execution context. + Protocol for enterable execution context objects. - This protocol defines the interface that all execution contexts must implement, - allowing both ExecutionContext and FlaskExecutionContext to be used interchangeably. + Concrete implementations may carry extra framework state, but callers only + depend on standard context-manager behavior plus optional user metadata. """ def __enter__(self) -> "IExecutionContext": @@ -62,14 +65,10 @@ class IExecutionContext(Protocol): @final class ExecutionContext: """ - Execution context for workflow execution in worker threads. + Generic execution context used by application-layer adapters. - This class encapsulates all context needed for workflow execution: - - Application context (Flask app or standalone) - - Context variables for Python contextvars - - User information (optional) - - It is designed to be serializable and passable to worker threads. + It restores captured `contextvars` and optionally enters an application + context before the worker executes graph logic. """ def __init__( @@ -78,14 +77,6 @@ class ExecutionContext: context_vars: contextvars.Context | None = None, user: Any = None, ) -> None: - """ - Initialize execution context. - - Args: - app_context: Application context (Flask or standalone) - context_vars: Python contextvars to preserve - user: User object (optional) - """ self._app_context = app_context self._context_vars = context_vars self._user = user @@ -98,27 +89,21 @@ class ExecutionContext: @property def context_vars(self) -> contextvars.Context | None: - """Get context variables.""" + """Get captured context variables.""" return self._context_vars @property def user(self) -> Any: - """Get user object.""" + """Get captured user object.""" return self._user @contextmanager def enter(self) -> Generator[None, None, None]: - """ - Enter this execution context. - - This is a convenience method that creates a context manager. - """ - # Restore context variables if provided + """Enter this execution context.""" if self._context_vars: for var, val in self._context_vars.items(): var.set(val) - # Enter app context if available if self._app_context is not None: with self._app_context.enter(): yield @@ -141,18 +126,10 @@ class ExecutionContext: class NullAppContext(AppContext): """ - Null implementation of AppContext for non-Flask environments. - - This is used when running without Flask (e.g., in tests or standalone mode). + Null application context for non-framework environments. """ def __init__(self, config: dict[str, Any] | None = None) -> None: - """ - Initialize null app context. - - Args: - config: Optional configuration dictionary - """ self._config = config or {} self._extensions: dict[str, Any] = {} @@ -165,7 +142,7 @@ class NullAppContext(AppContext): return self._extensions.get(name) def set_extension(self, name: str, extension: Any) -> None: - """Set extension by name.""" + """Register an extension for tests or standalone execution.""" self._extensions[name] = extension @contextmanager @@ -176,9 +153,7 @@ class NullAppContext(AppContext): class ExecutionContextBuilder: """ - Builder for creating ExecutionContext instances. - - This provides a fluent API for building execution contexts. + Builder for creating `ExecutionContext` instances. """ def __init__(self) -> None: @@ -211,63 +186,42 @@ class ExecutionContextBuilder: _capturer: Callable[[], IExecutionContext] | None = None - -# Tenant-scoped providers using tuple keys for clarity and constant-time lookup. -# Key mapping: -# (name, tenant_id) -> provider -# - name: namespaced identifier (recommend prefixing, e.g. "workflow.sandbox") -# - tenant_id: tenant identifier string -# Value: -# provider: Callable[[], BaseModel] returning the typed context value -# Type-safety note: -# - This registry cannot enforce that all providers for a given name return the same BaseModel type. -# - Implementors SHOULD provide typed wrappers around register/read (like Go's context best practice), -# e.g. def register_sandbox_ctx(tenant_id: str, p: Callable[[], SandboxContext]) and -# def read_sandbox_ctx(tenant_id: str) -> SandboxContext. _tenant_context_providers: dict[tuple[str, str], Callable[[], BaseModel]] = {} T = TypeVar("T", bound=BaseModel) class ContextProviderNotFoundError(KeyError): - """Raised when a tenant-scoped context provider is missing for a given (name, tenant_id).""" + """Raised when a tenant-scoped context provider is missing.""" pass def register_context_capturer(capturer: Callable[[], IExecutionContext]) -> None: - """Register a single enterable execution context capturer (e.g., Flask).""" + """Register an enterable execution context capturer.""" global _capturer _capturer = capturer def register_context(name: str, tenant_id: str, provider: Callable[[], BaseModel]) -> None: - """Register a tenant-specific provider for a named context. - - Tip: use a namespaced "name" (e.g., "workflow.sandbox") to avoid key collisions. - Consider adding a typed wrapper for this registration in your feature module. - """ + """Register a tenant-specific provider for a named context.""" _tenant_context_providers[(name, tenant_id)] = provider def read_context(name: str, *, tenant_id: str) -> BaseModel: - """ - Read a context value for a specific tenant. - - Raises KeyError if the provider for (name, tenant_id) is not registered. - """ - prov = _tenant_context_providers.get((name, tenant_id)) - if prov is None: + """Read a context value for a specific tenant.""" + provider = _tenant_context_providers.get((name, tenant_id)) + if provider is None: raise ContextProviderNotFoundError(f"Context provider '{name}' not registered for tenant '{tenant_id}'") - return prov() + return provider() def capture_current_context() -> IExecutionContext: """ Capture current execution context from the calling environment. - If a capturer is registered (e.g., Flask), use it. Otherwise, return a minimal - context with NullAppContext + copy of current contextvars. + If no framework adapter is registered, return a minimal context that only + restores `contextvars`. """ if _capturer is None: return ExecutionContext( @@ -278,7 +232,22 @@ def capture_current_context() -> IExecutionContext: def reset_context_provider() -> None: - """Reset the capturer and all tenant-scoped context providers (primarily for tests).""" + """Reset the capturer and tenant-scoped providers.""" global _capturer _capturer = None _tenant_context_providers.clear() + + +__all__ = [ + "AppContext", + "ContextProviderNotFoundError", + "ExecutionContext", + "ExecutionContextBuilder", + "IExecutionContext", + "NullAppContext", + "capture_current_context", + "read_context", + "register_context", + "register_context_capturer", + "reset_context_provider", +] diff --git a/api/context/flask_app_context.py b/api/context/flask_app_context.py index 324a9ee8b4..eddd6448d8 100644 --- a/api/context/flask_app_context.py +++ b/api/context/flask_app_context.py @@ -10,11 +10,7 @@ from typing import Any, final from flask import Flask, current_app, g -from dify_graph.context import register_context_capturer -from dify_graph.context.execution_context import ( - AppContext, - IExecutionContext, -) +from context.execution_context import AppContext, IExecutionContext, register_context_capturer @final diff --git a/api/dify_graph/context/models.py b/api/context/models.py similarity index 100% rename from api/dify_graph/context/models.py rename to api/context/models.py diff --git a/api/contexts/__init__.py b/api/contexts/__init__.py index c52dcf8a57..764f9f8ee2 100644 --- a/api/contexts/__init__.py +++ b/api/contexts/__init__.py @@ -6,7 +6,6 @@ from contexts.wrapper import RecyclableContextVar if TYPE_CHECKING: from core.datasource.__base.datasource_provider import DatasourcePluginProviderController - from core.plugin.entities.plugin_daemon import PluginModelProviderEntity from core.tools.plugin_tool.provider import PluginToolProviderController from core.trigger.provider import PluginTriggerProviderController @@ -20,14 +19,6 @@ plugin_tool_providers: RecyclableContextVar[dict[str, "PluginToolProviderControl plugin_tool_providers_lock: RecyclableContextVar[Lock] = RecyclableContextVar(ContextVar("plugin_tool_providers_lock")) -plugin_model_providers: RecyclableContextVar[list["PluginModelProviderEntity"] | None] = RecyclableContextVar( - ContextVar("plugin_model_providers") -) - -plugin_model_providers_lock: RecyclableContextVar[Lock] = RecyclableContextVar( - ContextVar("plugin_model_providers_lock") -) - datasource_plugin_providers: RecyclableContextVar[dict[str, "DatasourcePluginProviderController"]] = ( RecyclableContextVar(ContextVar("datasource_plugin_providers")) ) diff --git a/api/controllers/common/fields.py b/api/controllers/common/fields.py index ff5326dade..515a6a5125 100644 --- a/api/controllers/common/fields.py +++ b/api/controllers/common/fields.py @@ -4,7 +4,7 @@ from typing import Any, TypeAlias from pydantic import BaseModel, ConfigDict, computed_field -from dify_graph.file import helpers as file_helpers +from graphon.file import helpers as file_helpers from models.model import IconType JSONValue: TypeAlias = str | int | float | bool | None | dict[str, Any] | list[Any] diff --git a/api/controllers/console/apikey.py b/api/controllers/console/apikey.py index 6c54be84a8..783cb5c444 100644 --- a/api/controllers/console/apikey.py +++ b/api/controllers/console/apikey.py @@ -9,6 +9,7 @@ from extensions.ext_database import db from libs.helper import TimestampField from libs.login import current_account_with_tenant, login_required from models.dataset import Dataset +from models.enums import ApiTokenType from models.model import ApiToken, App from services.api_token_service import ApiTokenCache @@ -47,7 +48,7 @@ def _get_resource(resource_id, tenant_id, resource_model): class BaseApiKeyListResource(Resource): method_decorators = [account_initialization_required, login_required, setup_required] - resource_type: str | None = None + resource_type: ApiTokenType | None = None resource_model: type | None = None resource_id_field: str | None = None token_prefix: str | None = None @@ -91,6 +92,7 @@ class BaseApiKeyListResource(Resource): ) key = ApiToken.generate_api_key(self.token_prefix or "", 24) + assert self.resource_type is not None, "resource_type must be set" api_token = ApiToken() setattr(api_token, self.resource_id_field, resource_id) api_token.tenant_id = current_tenant_id @@ -104,7 +106,7 @@ class BaseApiKeyListResource(Resource): class BaseApiKeyResource(Resource): method_decorators = [account_initialization_required, login_required, setup_required] - resource_type: str | None = None + resource_type: ApiTokenType | None = None resource_model: type | None = None resource_id_field: str | None = None @@ -159,7 +161,7 @@ class AppApiKeyListResource(BaseApiKeyListResource): """Create a new API key for an app""" return super().post(resource_id) - resource_type = "app" + resource_type = ApiTokenType.APP resource_model = App resource_id_field = "app_id" token_prefix = "app-" @@ -175,7 +177,7 @@ class AppApiKeyResource(BaseApiKeyResource): """Delete an API key for an app""" return super().delete(resource_id, api_key_id) - resource_type = "app" + resource_type = ApiTokenType.APP resource_model = App resource_id_field = "app_id" @@ -199,7 +201,7 @@ class DatasetApiKeyListResource(BaseApiKeyListResource): """Create a new API key for a dataset""" return super().post(resource_id) - resource_type = "dataset" + resource_type = ApiTokenType.DATASET resource_model = Dataset resource_id_field = "dataset_id" token_prefix = "ds-" @@ -215,6 +217,6 @@ class DatasetApiKeyResource(BaseApiKeyResource): """Delete an API key for a dataset""" return super().delete(resource_id, api_key_id) - resource_type = "dataset" + resource_type = ApiTokenType.DATASET resource_model = Dataset resource_id_field = "dataset_id" diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py index 5ac0e342e6..357697ed30 100644 --- a/api/controllers/console/app/app.py +++ b/api/controllers/console/app/app.py @@ -26,9 +26,9 @@ from controllers.console.wraps import ( from core.ops.ops_trace_manager import OpsTraceManager from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.trigger.constants import TRIGGER_NODE_TYPES -from dify_graph.enums import WorkflowExecutionStatus -from dify_graph.file import helpers as file_helpers from extensions.ext_database import db +from graphon.enums import WorkflowExecutionStatus +from graphon.file import helpers as file_helpers from libs.login import current_account_with_tenant, login_required from models import App, DatasetPermissionEnum, Workflow from models.model import IconType @@ -95,7 +95,7 @@ class CreateAppPayload(BaseModel): name: str = Field(..., min_length=1, description="App name") description: str | None = Field(default=None, description="App description (max 400 chars)", max_length=400) mode: Literal["chat", "agent-chat", "advanced-chat", "workflow", "completion"] = Field(..., description="App mode") - icon_type: str | None = Field(default=None, description="Icon type") + icon_type: IconType | None = Field(default=None, description="Icon type") icon: str | None = Field(default=None, description="Icon") icon_background: str | None = Field(default=None, description="Icon background color") @@ -103,7 +103,7 @@ class CreateAppPayload(BaseModel): class UpdateAppPayload(BaseModel): name: str = Field(..., min_length=1, description="App name") description: str | None = Field(default=None, description="App description (max 400 chars)", max_length=400) - icon_type: str | None = Field(default=None, description="Icon type") + icon_type: IconType | None = Field(default=None, description="Icon type") icon: str | None = Field(default=None, description="Icon") icon_background: str | None = Field(default=None, description="Icon background color") use_icon_as_answer_icon: bool | None = Field(default=None, description="Use icon as answer icon") @@ -113,7 +113,7 @@ class UpdateAppPayload(BaseModel): class CopyAppPayload(BaseModel): name: str | None = Field(default=None, description="Name for the copied app") description: str | None = Field(default=None, description="Description for the copied app", max_length=400) - icon_type: str | None = Field(default=None, description="Icon type") + icon_type: IconType | None = Field(default=None, description="Icon type") icon: str | None = Field(default=None, description="Icon") icon_background: str | None = Field(default=None, description="Icon background color") @@ -594,7 +594,7 @@ class AppApi(Resource): args_dict: AppService.ArgsDict = { "name": args.name, "description": args.description or "", - "icon_type": args.icon_type or "", + "icon_type": args.icon_type, "icon": args.icon or "", "icon_background": args.icon_background or "", "use_icon_as_answer_icon": args.use_icon_as_answer_icon or False, diff --git a/api/controllers/console/app/audio.py b/api/controllers/console/app/audio.py index 2c5e8d29ee..91fbe4a85a 100644 --- a/api/controllers/console/app/audio.py +++ b/api/controllers/console/app/audio.py @@ -22,7 +22,7 @@ from controllers.console.app.error import ( from controllers.console.app.wraps import get_app_model from controllers.console.wraps import account_initialization_required, setup_required from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError -from dify_graph.model_runtime.errors.invoke import InvokeError +from graphon.model_runtime.errors.invoke import InvokeError from libs.login import login_required from models import App, AppMode from services.audio_service import AudioService diff --git a/api/controllers/console/app/completion.py b/api/controllers/console/app/completion.py index 4d7ddfea13..fe274e4c9a 100644 --- a/api/controllers/console/app/completion.py +++ b/api/controllers/console/app/completion.py @@ -26,7 +26,7 @@ from core.errors.error import ( QuotaExceededError, ) from core.helper.trace_id_helper import get_external_trace_id -from dify_graph.model_runtime.errors.invoke import InvokeError +from graphon.model_runtime.errors.invoke import InvokeError from libs import helper from libs.helper import uuid_value from libs.login import current_user, login_required diff --git a/api/controllers/console/app/conversation.py b/api/controllers/console/app/conversation.py index 74750981dd..d329d22309 100644 --- a/api/controllers/console/app/conversation.py +++ b/api/controllers/console/app/conversation.py @@ -458,9 +458,7 @@ class ChatConversationApi(Resource): args = ChatConversationQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore subquery = ( - db.session.query( - Conversation.id.label("conversation_id"), EndUser.session_id.label("from_end_user_session_id") - ) + sa.select(Conversation.id.label("conversation_id"), EndUser.session_id.label("from_end_user_session_id")) .outerjoin(EndUser, Conversation.from_end_user_id == EndUser.id) .subquery() ) @@ -595,10 +593,8 @@ class ChatConversationDetailApi(Resource): def _get_conversation(app_model, conversation_id): current_user, _ = current_account_with_tenant() - conversation = ( - db.session.query(Conversation) - .where(Conversation.id == conversation_id, Conversation.app_id == app_model.id) - .first() + conversation = db.session.scalar( + sa.select(Conversation).where(Conversation.id == conversation_id, Conversation.app_id == app_model.id).limit(1) ) if not conversation: diff --git a/api/controllers/console/app/generator.py b/api/controllers/console/app/generator.py index af4ac450bb..c720a5e074 100644 --- a/api/controllers/console/app/generator.py +++ b/api/controllers/console/app/generator.py @@ -18,8 +18,8 @@ from core.helper.code_executor.javascript.javascript_code_provider import Javasc from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider from core.llm_generator.entities import RuleCodeGeneratePayload, RuleGeneratePayload, RuleStructuredOutputPayload from core.llm_generator.llm_generator import LLMGenerator -from dify_graph.model_runtime.errors.invoke import InvokeError from extensions.ext_database import db +from graphon.model_runtime.errors.invoke import InvokeError from libs.login import current_account_with_tenant, login_required from models import App from services.workflow_service import WorkflowService @@ -168,7 +168,7 @@ class InstructionGenerateApi(Resource): try: # Generate from nothing for a workflow node if (args.current in (code_template, "")) and args.node_id != "": - app = db.session.query(App).where(App.id == args.flow_id).first() + app = db.session.get(App, args.flow_id) if not app: return {"error": f"app {args.flow_id} not found"}, 400 workflow = WorkflowService().get_draft_workflow(app_model=app) diff --git a/api/controllers/console/app/mcp_server.py b/api/controllers/console/app/mcp_server.py index 4b20418b53..412fc8795a 100644 --- a/api/controllers/console/app/mcp_server.py +++ b/api/controllers/console/app/mcp_server.py @@ -2,6 +2,7 @@ import json from flask_restx import Resource, marshal_with from pydantic import BaseModel, Field +from sqlalchemy import select from werkzeug.exceptions import NotFound from controllers.console import console_ns @@ -47,7 +48,7 @@ class AppMCPServerController(Resource): @get_app_model @marshal_with(app_server_model) def get(self, app_model): - server = db.session.query(AppMCPServer).where(AppMCPServer.app_id == app_model.id).first() + server = db.session.scalar(select(AppMCPServer).where(AppMCPServer.app_id == app_model.id).limit(1)) return server @console_ns.doc("create_app_mcp_server") @@ -98,7 +99,7 @@ class AppMCPServerController(Resource): @edit_permission_required def put(self, app_model): payload = MCPServerUpdatePayload.model_validate(console_ns.payload or {}) - server = db.session.query(AppMCPServer).where(AppMCPServer.id == payload.id).first() + server = db.session.get(AppMCPServer, payload.id) if not server: raise NotFound() @@ -135,11 +136,10 @@ class AppMCPServerRefreshController(Resource): @edit_permission_required def get(self, server_id): _, current_tenant_id = current_account_with_tenant() - server = ( - db.session.query(AppMCPServer) - .where(AppMCPServer.id == server_id) - .where(AppMCPServer.tenant_id == current_tenant_id) - .first() + server = db.session.scalar( + select(AppMCPServer) + .where(AppMCPServer.id == server_id, AppMCPServer.tenant_id == current_tenant_id) + .limit(1) ) if not server: raise NotFound() diff --git a/api/controllers/console/app/message.py b/api/controllers/console/app/message.py index 736e7dbe17..dc752939ae 100644 --- a/api/controllers/console/app/message.py +++ b/api/controllers/console/app/message.py @@ -24,9 +24,9 @@ from controllers.console.wraps import ( ) from core.app.entities.app_invoke_entities import InvokeFrom from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError -from dify_graph.model_runtime.errors.invoke import InvokeError from extensions.ext_database import db from fields.raws import FilesContainedField +from graphon.model_runtime.errors.invoke import InvokeError from libs.helper import TimestampField, uuid_value from libs.infinite_scroll_pagination import InfiniteScrollPagination from libs.login import current_account_with_tenant, login_required diff --git a/api/controllers/console/app/model_config.py b/api/controllers/console/app/model_config.py index a85e54fb51..8bb5aa2c1b 100644 --- a/api/controllers/console/app/model_config.py +++ b/api/controllers/console/app/model_config.py @@ -69,9 +69,7 @@ class ModelConfigResource(Resource): if app_model.mode == AppMode.AGENT_CHAT or app_model.is_agent: # get original app model config - original_app_model_config = ( - db.session.query(AppModelConfig).where(AppModelConfig.id == app_model.app_model_config_id).first() - ) + original_app_model_config = db.session.get(AppModelConfig, app_model.app_model_config_id) if original_app_model_config is None: raise ValueError("Original app model config not found") agent_mode = original_app_model_config.agent_mode_dict @@ -90,6 +88,7 @@ class ModelConfigResource(Resource): tenant_id=current_tenant_id, app_id=app_model.id, agent_tool=agent_tool_entity, + user_id=current_user.id, ) manager = ToolParameterConfigurationManager( tenant_id=current_tenant_id, @@ -129,6 +128,7 @@ class ModelConfigResource(Resource): tenant_id=current_tenant_id, app_id=app_model.id, agent_tool=agent_tool_entity, + user_id=current_user.id, ) except Exception: continue diff --git a/api/controllers/console/app/site.py b/api/controllers/console/app/site.py index db218d8b81..7f44a99ff1 100644 --- a/api/controllers/console/app/site.py +++ b/api/controllers/console/app/site.py @@ -2,6 +2,7 @@ from typing import Literal from flask_restx import Resource, marshal_with from pydantic import BaseModel, Field, field_validator +from sqlalchemy import select from werkzeug.exceptions import NotFound from constants.languages import supported_language @@ -75,7 +76,7 @@ class AppSite(Resource): def post(self, app_model): args = AppSiteUpdatePayload.model_validate(console_ns.payload or {}) current_user, _ = current_account_with_tenant() - site = db.session.query(Site).where(Site.app_id == app_model.id).first() + site = db.session.scalar(select(Site).where(Site.app_id == app_model.id).limit(1)) if not site: raise NotFound @@ -124,7 +125,7 @@ class AppSiteAccessTokenReset(Resource): @marshal_with(app_site_model) def post(self, app_model): current_user, _ = current_account_with_tenant() - site = db.session.query(Site).where(Site.app_id == app_model.id).first() + site = db.session.scalar(select(Site).where(Site.app_id == app_model.id).limit(1)) if not site: raise NotFound diff --git a/api/controllers/console/app/workflow.py b/api/controllers/console/app/workflow.py index d59aa44718..2737dd1dfd 100644 --- a/api/controllers/console/app/workflow.py +++ b/api/controllers/console/app/workflow.py @@ -20,6 +20,7 @@ from core.app.app_config.features.file_upload.manager import FileUploadConfigMan from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.apps.workflow.app_generator import SKIP_PREPARE_USER_INPUTS_KEY from core.app.entities.app_invoke_entities import InvokeFrom +from core.app.file_access import DatabaseFileAccessController from core.helper.trace_id_helper import get_external_trace_id from core.plugin.impl.exc import PluginInvokeError from core.trigger.constants import TRIGGER_SCHEDULE_NODE_TYPE @@ -29,15 +30,15 @@ from core.trigger.debug.event_selectors import ( create_event_poller, select_trigger_debug_events, ) -from dify_graph.enums import NodeType -from dify_graph.file.models import File -from dify_graph.graph_engine.manager import GraphEngineManager -from dify_graph.model_runtime.utils.encoders import jsonable_encoder from extensions.ext_database import db from extensions.ext_redis import redis_client from factories import file_factory, variable_factory from fields.member_fields import simple_account_fields from fields.workflow_fields import workflow_fields, workflow_pagination_fields +from graphon.enums import NodeType +from graphon.file.models import File +from graphon.graph_engine.manager import GraphEngineManager +from graphon.model_runtime.utils.encoders import jsonable_encoder from libs import helper from libs.datetime_utils import naive_utc_now from libs.helper import TimestampField, uuid_value @@ -51,6 +52,7 @@ from services.errors.llm import InvokeRateLimitError from services.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError, WorkflowService logger = logging.getLogger(__name__) +_file_access_controller = DatabaseFileAccessController() LISTENING_RETRY_IN = 2000 DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" RESTORE_SOURCE_WORKFLOW_MUST_BE_PUBLISHED_MESSAGE = "source workflow must be published" @@ -204,6 +206,7 @@ def _parse_file(workflow: Workflow, files: list[dict] | None = None) -> Sequence mappings=files, tenant_id=workflow.tenant_id, config=file_extra_config, + access_controller=_file_access_controller, ) return file_objs diff --git a/api/controllers/console/app/workflow_app_log.py b/api/controllers/console/app/workflow_app_log.py index 9b148c3f18..8cf0004b09 100644 --- a/api/controllers/console/app/workflow_app_log.py +++ b/api/controllers/console/app/workflow_app_log.py @@ -9,12 +9,12 @@ from sqlalchemy.orm import Session from controllers.console import console_ns from controllers.console.app.wraps import get_app_model from controllers.console.wraps import account_initialization_required, setup_required -from dify_graph.enums import WorkflowExecutionStatus from extensions.ext_database import db from fields.workflow_app_log_fields import ( build_workflow_app_log_pagination_model, build_workflow_archived_log_pagination_model, ) +from graphon.enums import WorkflowExecutionStatus from libs.login import login_required from models import App from models.model import AppMode diff --git a/api/controllers/console/app/workflow_draft_variable.py b/api/controllers/console/app/workflow_draft_variable.py index b78d97a382..657b072490 100644 --- a/api/controllers/console/app/workflow_draft_variable.py +++ b/api/controllers/console/app/workflow_draft_variable.py @@ -15,14 +15,15 @@ from controllers.console.app.error import ( from controllers.console.app.wraps import get_app_model from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required from controllers.web.error import InvalidArgumentError, NotFoundError -from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID -from dify_graph.file import helpers as file_helpers -from dify_graph.variables.segment_group import SegmentGroup -from dify_graph.variables.segments import ArrayFileSegment, FileSegment, Segment -from dify_graph.variables.types import SegmentType +from core.app.file_access import DatabaseFileAccessController +from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID from extensions.ext_database import db from factories.file_factory import build_from_mapping, build_from_mappings from factories.variable_factory import build_segment_with_type +from graphon.file import helpers as file_helpers +from graphon.variables.segment_group import SegmentGroup +from graphon.variables.segments import ArrayFileSegment, FileSegment, Segment +from graphon.variables.types import SegmentType from libs.login import current_user, login_required from models import App, AppMode from models.workflow import WorkflowDraftVariable @@ -30,6 +31,7 @@ from services.workflow_draft_variable_service import WorkflowDraftVariableList, from services.workflow_service import WorkflowService logger = logging.getLogger(__name__) +_file_access_controller = DatabaseFileAccessController() DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" @@ -389,13 +391,21 @@ class VariableApi(Resource): if variable.value_type == SegmentType.FILE: if not isinstance(raw_value, dict): raise InvalidArgumentError(description=f"expected dict for file, got {type(raw_value)}") - raw_value = build_from_mapping(mapping=raw_value, tenant_id=app_model.tenant_id) + raw_value = build_from_mapping( + mapping=raw_value, + tenant_id=app_model.tenant_id, + access_controller=_file_access_controller, + ) elif variable.value_type == SegmentType.ARRAY_FILE: if not isinstance(raw_value, list): raise InvalidArgumentError(description=f"expected list for files, got {type(raw_value)}") if len(raw_value) > 0 and not isinstance(raw_value[0], dict): raise InvalidArgumentError(description=f"expected dict for files[0], got {type(raw_value)}") - raw_value = build_from_mappings(mappings=raw_value, tenant_id=app_model.tenant_id) + raw_value = build_from_mappings( + mappings=raw_value, + tenant_id=app_model.tenant_id, + access_controller=_file_access_controller, + ) new_value = build_segment_with_type(variable.value_type, raw_value) draft_var_srv.update_variable(variable, name=new_name, value=new_value) db.session.commit() diff --git a/api/controllers/console/app/workflow_run.py b/api/controllers/console/app/workflow_run.py index 7ac653395e..29fa96c4e6 100644 --- a/api/controllers/console/app/workflow_run.py +++ b/api/controllers/console/app/workflow_run.py @@ -12,8 +12,7 @@ from controllers.console import console_ns from controllers.console.app.wraps import get_app_model from controllers.console.wraps import account_initialization_required, setup_required from controllers.web.error import NotFoundError -from dify_graph.entities.pause_reason import HumanInputRequired -from dify_graph.enums import WorkflowExecutionStatus +from core.workflow.human_input_forms import load_form_tokens_by_form_id as _load_form_tokens_by_form_id from extensions.ext_database import db from fields.end_user_fields import simple_end_user_fields from fields.member_fields import simple_account_fields @@ -27,6 +26,8 @@ from fields.workflow_run_fields import ( workflow_run_node_execution_list_fields, workflow_run_pagination_fields, ) +from graphon.entities.pause_reason import HumanInputRequired +from graphon.enums import WorkflowExecutionStatus from libs.archive_storage import ArchiveStorageNotConfiguredError, get_archive_storage from libs.custom_inputs import time_duration from libs.helper import uuid_value @@ -496,6 +497,9 @@ class ConsoleWorkflowPauseDetailsApi(Resource): pause_entity = workflow_run_repo.get_workflow_pause(workflow_run_id) pause_reasons = pause_entity.get_pause_reasons() if pause_entity else [] + form_tokens_by_form_id = _load_form_tokens_by_form_id( + [reason.form_id for reason in pause_reasons if isinstance(reason, HumanInputRequired)] + ) # Build response paused_at = pause_entity.paused_at if pause_entity else None @@ -514,7 +518,9 @@ class ConsoleWorkflowPauseDetailsApi(Resource): "pause_type": { "type": "human_input", "form_id": reason.form_id, - "backstage_input_url": _build_backstage_input_url(reason.form_token), + "backstage_input_url": _build_backstage_input_url( + form_tokens_by_form_id.get(reason.form_id) + ), }, } ) diff --git a/api/controllers/console/app/wraps.py b/api/controllers/console/app/wraps.py index e687d980fa..493022ffea 100644 --- a/api/controllers/console/app/wraps.py +++ b/api/controllers/console/app/wraps.py @@ -2,6 +2,8 @@ from collections.abc import Callable from functools import wraps from typing import ParamSpec, TypeVar, Union +from sqlalchemy import select + from controllers.console.app.error import AppNotFoundError from extensions.ext_database import db from libs.login import current_account_with_tenant @@ -15,16 +17,14 @@ R1 = TypeVar("R1") def _load_app_model(app_id: str) -> App | None: _, current_tenant_id = current_account_with_tenant() - app_model = ( - db.session.query(App) - .where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal") - .first() + app_model = db.session.scalar( + select(App).where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal").limit(1) ) return app_model def _load_app_model_with_trial(app_id: str) -> App | None: - app_model = db.session.query(App).where(App.id == app_id, App.status == "normal").first() + app_model = db.session.scalar(select(App).where(App.id == app_id, App.status == "normal").limit(1)) return app_model diff --git a/api/controllers/console/auth/email_register.py b/api/controllers/console/auth/email_register.py index c2a95ddad2..9e7faa09c5 100644 --- a/api/controllers/console/auth/email_register.py +++ b/api/controllers/console/auth/email_register.py @@ -1,7 +1,7 @@ from flask import request from flask_restx import Resource from pydantic import BaseModel, Field, field_validator -from sqlalchemy.orm import Session +from sqlalchemy.orm import sessionmaker from configs import dify_config from constants.languages import languages @@ -73,7 +73,7 @@ class EmailRegisterSendEmailApi(Resource): if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(normalized_email): raise AccountInFreezeError() - with Session(db.engine) as session: + with sessionmaker(db.engine).begin() as session: account = AccountService.get_account_by_email_with_case_fallback(args.email, session=session) token = AccountService.send_email_register_email(email=normalized_email, account=account, language=language) return {"result": "success", "data": token} @@ -145,7 +145,7 @@ class EmailRegisterResetApi(Resource): email = register_data.get("email", "") normalized_email = email.lower() - with Session(db.engine) as session: + with sessionmaker(db.engine).begin() as session: account = AccountService.get_account_by_email_with_case_fallback(email, session=session) if account: diff --git a/api/controllers/console/auth/forgot_password.py b/api/controllers/console/auth/forgot_password.py index 1ed931b0d7..844f3c91ff 100644 --- a/api/controllers/console/auth/forgot_password.py +++ b/api/controllers/console/auth/forgot_password.py @@ -4,7 +4,7 @@ import secrets from flask import request from flask_restx import Resource from pydantic import BaseModel, Field, field_validator -from sqlalchemy.orm import Session +from sqlalchemy.orm import sessionmaker from controllers.common.schema import register_schema_models from controllers.console import console_ns @@ -102,7 +102,7 @@ class ForgotPasswordSendEmailApi(Resource): else: language = "en-US" - with Session(db.engine) as session: + with sessionmaker(db.engine).begin() as session: account = AccountService.get_account_by_email_with_case_fallback(args.email, session=session) token = AccountService.send_reset_password_email( @@ -201,7 +201,7 @@ class ForgotPasswordResetApi(Resource): password_hashed = hash_password(args.new_password, salt) email = reset_data.get("email", "") - with Session(db.engine) as session: + with sessionmaker(db.engine).begin() as session: account = AccountService.get_account_by_email_with_case_fallback(email, session=session) if account: @@ -215,7 +215,6 @@ class ForgotPasswordResetApi(Resource): # Update existing account credentials account.password = base64.b64encode(password_hashed).decode() account.password_salt = base64.b64encode(salt).decode() - session.commit() # Create workspace if needed if ( diff --git a/api/controllers/console/auth/oauth.py b/api/controllers/console/auth/oauth.py index 5c9023f27b..5c7011fd22 100644 --- a/api/controllers/console/auth/oauth.py +++ b/api/controllers/console/auth/oauth.py @@ -4,7 +4,7 @@ import urllib.parse import httpx from flask import current_app, redirect, request from flask_restx import Resource -from sqlalchemy.orm import Session +from sqlalchemy.orm import sessionmaker from werkzeug.exceptions import Unauthorized from configs import dify_config @@ -180,7 +180,7 @@ def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) -> account: Account | None = Account.get_by_openid(provider, user_info.id) if not account: - with Session(db.engine) as session: + with sessionmaker(db.engine).begin() as session: account = AccountService.get_account_by_email_with_case_fallback(user_info.email, session=session) return account diff --git a/api/controllers/console/auth/oauth_server.py b/api/controllers/console/auth/oauth_server.py index 6e59d4203c..665a80802d 100644 --- a/api/controllers/console/auth/oauth_server.py +++ b/api/controllers/console/auth/oauth_server.py @@ -8,7 +8,7 @@ from pydantic import BaseModel from werkzeug.exceptions import BadRequest, NotFound from controllers.console.wraps import account_initialization_required, setup_required -from dify_graph.model_runtime.utils.encoders import jsonable_encoder +from graphon.model_runtime.utils.encoders import jsonable_encoder from libs.login import current_account_with_tenant, login_required from models import Account from models.model import OAuthProviderApp diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py index ef65b87923..6dcea380ea 100644 --- a/api/controllers/console/datasets/datasets.py +++ b/api/controllers/console/datasets/datasets.py @@ -5,7 +5,7 @@ from urllib.parse import quote from flask import Response, request from flask_restx import Resource, fields, marshal, marshal_with from pydantic import BaseModel, Field, field_validator -from sqlalchemy import select +from sqlalchemy import func, select from sqlalchemy.orm import Session from werkzeug.exceptions import BadRequest, Forbidden, NotFound @@ -29,12 +29,12 @@ from controllers.console.wraps import ( from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError from core.evaluation.entities.evaluation_entity import EvaluationCategory, EvaluationConfigData, EvaluationRunRequest from core.indexing_runner import IndexingRunner -from core.provider_manager import ProviderManager +from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager from core.rag.datasource.vdb.vector_type import VectorType from core.rag.extractor.entity.datasource_type import DatasourceType from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo, WebsiteInfo +from core.rag.index_processor.constant.index_type import IndexTechniqueType from core.rag.retrieval.retrieval_methods import RetrievalMethod -from dify_graph.model_runtime.entities.model_entities import ModelType from extensions.ext_database import db from extensions.ext_storage import storage from fields.app_fields import app_detail_kernel_fields, related_app_list @@ -56,10 +56,11 @@ from fields.dataset_fields import ( weighted_score_fields, ) from fields.document_fields import document_status_fields +from graphon.model_runtime.entities.model_entities import ModelType from libs.login import current_account_with_tenant, login_required from models import ApiToken, Dataset, Document, DocumentSegment, EvaluationRun, EvaluationTargetType, UploadFile from models.dataset import DatasetPermission, DatasetPermissionEnum -from models.enums import SegmentStatus +from models.enums import ApiTokenType, SegmentStatus from models.provider_ids import ModelProviderID from services.api_token_service import ApiTokenCache from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService @@ -343,7 +344,7 @@ class DatasetListApi(Resource): ) # check embedding setting - provider_manager = ProviderManager() + provider_manager = create_plugin_provider_manager(tenant_id=current_tenant_id) configurations = provider_manager.get_configurations(tenant_id=current_tenant_id) embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True) @@ -367,7 +368,7 @@ class DatasetListApi(Resource): for item in data: # convert embedding_model_provider to plugin standard format - if item["indexing_technique"] == "high_quality" and item["embedding_model_provider"]: + if item["indexing_technique"] == IndexTechniqueType.HIGH_QUALITY and item["embedding_model_provider"]: item["embedding_model_provider"] = str(ModelProviderID(item["embedding_model_provider"])) item_model = f"{item['embedding_model']}:{item['embedding_model_provider']}" if item_model in model_names: @@ -448,7 +449,7 @@ class DatasetApi(Resource): except services.errors.account.NoPermissionError as e: raise Forbidden(str(e)) data = cast(dict[str, Any], marshal(dataset, dataset_detail_fields)) - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: if dataset.embedding_model_provider: provider_id = ModelProviderID(dataset.embedding_model_provider) data["embedding_model_provider"] = str(provider_id) @@ -457,7 +458,7 @@ class DatasetApi(Resource): data.update({"partial_member_list": part_users_list}) # check embedding setting - provider_manager = ProviderManager() + provider_manager = create_plugin_provider_manager(tenant_id=current_tenant_id) configurations = provider_manager.get_configurations(tenant_id=current_tenant_id) embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True) @@ -466,7 +467,7 @@ class DatasetApi(Resource): for embedding_model in embedding_models: model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}") - if data["indexing_technique"] == "high_quality": + if data["indexing_technique"] == IndexTechniqueType.HIGH_QUALITY: item_model = f"{data['embedding_model']}:{data['embedding_model_provider']}" if item_model in model_names: data["embedding_available"] = True @@ -497,7 +498,7 @@ class DatasetApi(Resource): current_user, current_tenant_id = current_account_with_tenant() # check embedding model setting if ( - payload.indexing_technique == "high_quality" + payload.indexing_technique == IndexTechniqueType.HIGH_QUALITY and payload.embedding_model_provider is not None and payload.embedding_model is not None ): @@ -750,20 +751,23 @@ class DatasetIndexingStatusApi(Resource): documents_status = [] for document in documents: completed_segments = ( - db.session.query(DocumentSegment) - .where( - DocumentSegment.completed_at.isnot(None), - DocumentSegment.document_id == str(document.id), - DocumentSegment.status != SegmentStatus.RE_SEGMENT, + db.session.scalar( + select(func.count(DocumentSegment.id)).where( + DocumentSegment.completed_at.isnot(None), + DocumentSegment.document_id == str(document.id), + DocumentSegment.status != SegmentStatus.RE_SEGMENT, + ) ) - .count() + or 0 ) total_segments = ( - db.session.query(DocumentSegment) - .where( - DocumentSegment.document_id == str(document.id), DocumentSegment.status != SegmentStatus.RE_SEGMENT + db.session.scalar( + select(func.count(DocumentSegment.id)).where( + DocumentSegment.document_id == str(document.id), + DocumentSegment.status != SegmentStatus.RE_SEGMENT, + ) ) - .count() + or 0 ) # Create a dictionary with document attributes and additional fields document_dict = { @@ -789,7 +793,7 @@ class DatasetIndexingStatusApi(Resource): class DatasetApiKeyApi(Resource): max_keys = 10 token_prefix = "dataset-" - resource_type = "dataset" + resource_type = ApiTokenType.DATASET @console_ns.doc("get_dataset_api_keys") @console_ns.doc(description="Get dataset API keys") @@ -814,9 +818,12 @@ class DatasetApiKeyApi(Resource): _, current_tenant_id = current_account_with_tenant() current_key_count = ( - db.session.query(ApiToken) - .where(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_tenant_id) - .count() + db.session.scalar( + select(func.count(ApiToken.id)).where( + ApiToken.type == self.resource_type, ApiToken.tenant_id == current_tenant_id + ) + ) + or 0 ) if current_key_count >= self.max_keys: @@ -838,7 +845,7 @@ class DatasetApiKeyApi(Resource): @console_ns.route("/datasets/api-keys/") class DatasetApiDeleteApi(Resource): - resource_type = "dataset" + resource_type = ApiTokenType.DATASET @console_ns.doc("delete_dataset_api_key") @console_ns.doc(description="Delete dataset API key") @@ -851,14 +858,14 @@ class DatasetApiDeleteApi(Resource): def delete(self, api_key_id): _, current_tenant_id = current_account_with_tenant() api_key_id = str(api_key_id) - key = ( - db.session.query(ApiToken) + key = db.session.scalar( + select(ApiToken) .where( ApiToken.tenant_id == current_tenant_id, ApiToken.type == self.resource_type, ApiToken.id == api_key_id, ) - .first() + .limit(1) ) if key is None: @@ -869,7 +876,7 @@ class DatasetApiDeleteApi(Resource): assert key is not None # nosec - for type checker only ApiTokenCache.delete(key.token, key.type) - db.session.query(ApiToken).where(ApiToken.id == api_key_id).delete() + db.session.delete(key) db.session.commit() return {"result": "success"}, 204 diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py index bc90c4ffbd..edb738aad8 100644 --- a/api/controllers/console/datasets/datasets_document.py +++ b/api/controllers/console/datasets/datasets_document.py @@ -10,7 +10,7 @@ import sqlalchemy as sa from flask import request, send_file from flask_restx import Resource, fields, marshal, marshal_with from pydantic import BaseModel, Field -from sqlalchemy import asc, desc, select +from sqlalchemy import asc, desc, func, select from werkzeug.exceptions import Forbidden, NotFound import services @@ -27,8 +27,7 @@ from core.model_manager import ModelManager from core.plugin.impl.exc import PluginDaemonClientSideError from core.rag.extractor.entity.datasource_type import DatasourceType from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo, WebsiteInfo -from dify_graph.model_runtime.entities.model_entities import ModelType -from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError +from core.rag.index_processor.constant.index_type import IndexTechniqueType from extensions.ext_database import db from fields.dataset_fields import dataset_fields from fields.document_fields import ( @@ -38,6 +37,8 @@ from fields.document_fields import ( document_status_fields, document_with_segments_fields, ) +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError from libs.datetime_utils import naive_utc_now from libs.login import current_account_with_tenant, login_required from models import DatasetProcessRule, Document, DocumentSegment, UploadFile @@ -211,12 +212,11 @@ class GetProcessRuleApi(Resource): raise Forbidden(str(e)) # get the latest process rule - dataset_process_rule = ( - db.session.query(DatasetProcessRule) + dataset_process_rule = db.session.scalar( + select(DatasetProcessRule) .where(DatasetProcessRule.dataset_id == document.dataset_id) .order_by(DatasetProcessRule.created_at.desc()) .limit(1) - .one_or_none() ) if dataset_process_rule: mode = dataset_process_rule.mode @@ -330,21 +330,23 @@ class DatasetDocumentListApi(Resource): if fetch: for document in documents: completed_segments = ( - db.session.query(DocumentSegment) - .where( - DocumentSegment.completed_at.isnot(None), - DocumentSegment.document_id == str(document.id), - DocumentSegment.status != SegmentStatus.RE_SEGMENT, + db.session.scalar( + select(func.count(DocumentSegment.id)).where( + DocumentSegment.completed_at.isnot(None), + DocumentSegment.document_id == str(document.id), + DocumentSegment.status != SegmentStatus.RE_SEGMENT, + ) ) - .count() + or 0 ) total_segments = ( - db.session.query(DocumentSegment) - .where( - DocumentSegment.document_id == str(document.id), - DocumentSegment.status != SegmentStatus.RE_SEGMENT, + db.session.scalar( + select(func.count(DocumentSegment.id)).where( + DocumentSegment.document_id == str(document.id), + DocumentSegment.status != SegmentStatus.RE_SEGMENT, + ) ) - .count() + or 0 ) document.completed_segments = completed_segments document.total_segments = total_segments @@ -448,11 +450,11 @@ class DatasetInitApi(Resource): raise Forbidden() knowledge_config = KnowledgeConfig.model_validate(console_ns.payload or {}) - if knowledge_config.indexing_technique == "high_quality": + if knowledge_config.indexing_technique == IndexTechniqueType.HIGH_QUALITY: if knowledge_config.embedding_model is None or knowledge_config.embedding_model_provider is None: raise ValueError("embedding model and embedding model provider are required for high quality indexing.") try: - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=current_tenant_id) model_manager.get_model_instance( tenant_id=current_tenant_id, provider=knowledge_config.embedding_model_provider, @@ -462,7 +464,7 @@ class DatasetInitApi(Resource): is_multimodal = DatasetService.check_is_multimodal_model( current_tenant_id, knowledge_config.embedding_model_provider, knowledge_config.embedding_model ) - knowledge_config.is_multimodal = is_multimodal + knowledge_config.is_multimodal = is_multimodal # pyrefly: ignore[bad-assignment] except InvokeAuthorizationError: raise ProviderNotInitializeError( "No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider." @@ -521,10 +523,10 @@ class DocumentIndexingEstimateApi(DocumentResource): if data_source_info and "upload_file_id" in data_source_info: file_id = data_source_info["upload_file_id"] - file = ( - db.session.query(UploadFile) + file = db.session.scalar( + select(UploadFile) .where(UploadFile.tenant_id == document.tenant_id, UploadFile.id == file_id) - .first() + .limit(1) ) # raise error if file not found @@ -586,10 +588,10 @@ class DocumentBatchIndexingEstimateApi(DocumentResource): if not data_source_info: continue file_id = data_source_info["upload_file_id"] - file_detail = ( - db.session.query(UploadFile) + file_detail = db.session.scalar( + select(UploadFile) .where(UploadFile.tenant_id == current_tenant_id, UploadFile.id == file_id) - .first() + .limit(1) ) if file_detail is None: @@ -672,20 +674,23 @@ class DocumentBatchIndexingStatusApi(DocumentResource): documents_status = [] for document in documents: completed_segments = ( - db.session.query(DocumentSegment) - .where( - DocumentSegment.completed_at.isnot(None), - DocumentSegment.document_id == str(document.id), - DocumentSegment.status != SegmentStatus.RE_SEGMENT, + db.session.scalar( + select(func.count(DocumentSegment.id)).where( + DocumentSegment.completed_at.isnot(None), + DocumentSegment.document_id == str(document.id), + DocumentSegment.status != SegmentStatus.RE_SEGMENT, + ) ) - .count() + or 0 ) total_segments = ( - db.session.query(DocumentSegment) - .where( - DocumentSegment.document_id == str(document.id), DocumentSegment.status != SegmentStatus.RE_SEGMENT + db.session.scalar( + select(func.count(DocumentSegment.id)).where( + DocumentSegment.document_id == str(document.id), + DocumentSegment.status != SegmentStatus.RE_SEGMENT, + ) ) - .count() + or 0 ) # Create a dictionary with document attributes and additional fields document_dict = { @@ -723,18 +728,23 @@ class DocumentIndexingStatusApi(DocumentResource): document = self.get_document(dataset_id, document_id) completed_segments = ( - db.session.query(DocumentSegment) - .where( - DocumentSegment.completed_at.isnot(None), - DocumentSegment.document_id == str(document_id), - DocumentSegment.status != SegmentStatus.RE_SEGMENT, + db.session.scalar( + select(func.count(DocumentSegment.id)).where( + DocumentSegment.completed_at.isnot(None), + DocumentSegment.document_id == str(document_id), + DocumentSegment.status != SegmentStatus.RE_SEGMENT, + ) ) - .count() + or 0 ) total_segments = ( - db.session.query(DocumentSegment) - .where(DocumentSegment.document_id == str(document_id), DocumentSegment.status != SegmentStatus.RE_SEGMENT) - .count() + db.session.scalar( + select(func.count(DocumentSegment.id)).where( + DocumentSegment.document_id == str(document_id), + DocumentSegment.status != SegmentStatus.RE_SEGMENT, + ) + ) + or 0 ) # Create a dictionary with document attributes and additional fields @@ -1258,11 +1268,11 @@ class DocumentPipelineExecutionLogApi(DocumentResource): document = DocumentService.get_document(dataset.id, document_id) if not document: raise NotFound("Document not found.") - log = ( - db.session.query(DocumentPipelineExecutionLog) - .filter_by(document_id=document_id) + log = db.session.scalar( + select(DocumentPipelineExecutionLog) + .where(DocumentPipelineExecutionLog.document_id == document_id) .order_by(DocumentPipelineExecutionLog.created_at.desc()) - .first() + .limit(1) ) if not log: return { @@ -1328,7 +1338,7 @@ class DocumentGenerateSummaryApi(Resource): raise BadRequest("document_list cannot be empty.") # Check if dataset configuration supports summary generation - if dataset.indexing_technique != "high_quality": + if dataset.indexing_technique != IndexTechniqueType.HIGH_QUALITY: raise ValueError( f"Summary generation is only available for 'high_quality' indexing technique. " f"Current indexing technique: {dataset.indexing_technique}" diff --git a/api/controllers/console/datasets/datasets_segments.py b/api/controllers/console/datasets/datasets_segments.py index 3fd0f3b712..2fd84303d7 100644 --- a/api/controllers/console/datasets/datasets_segments.py +++ b/api/controllers/console/datasets/datasets_segments.py @@ -26,10 +26,11 @@ from controllers.console.wraps import ( ) from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError from core.model_manager import ModelManager -from dify_graph.model_runtime.entities.model_entities import ModelType +from core.rag.index_processor.constant.index_type import IndexTechniqueType from extensions.ext_database import db from extensions.ext_redis import redis_client from fields.segment_fields import child_chunk_fields, segment_fields +from graphon.model_runtime.entities.model_entities import ModelType from libs.helper import escape_like_pattern from libs.login import current_account_with_tenant, login_required from models.dataset import ChildChunk, DocumentSegment @@ -45,7 +46,7 @@ def _get_segment_with_summary(segment, dataset_id): """Helper function to marshal segment and add summary information.""" from services.summary_index_service import SummaryIndexService - segment_dict = dict(marshal(segment, segment_fields)) + segment_dict = dict(marshal(segment, segment_fields)) # type: ignore # Query summary for this segment (only enabled summaries) summary = SummaryIndexService.get_segment_summary(segment_id=segment.id, dataset_id=dataset_id) segment_dict["summary"] = summary.summary_content if summary else None @@ -206,7 +207,7 @@ class DatasetDocumentSegmentListApi(Resource): # Add summary to each segment segments_with_summary = [] for segment in segments.items: - segment_dict = dict(marshal(segment, segment_fields)) + segment_dict = dict(marshal(segment, segment_fields)) # type: ignore segment_dict["summary"] = summaries.get(segment.id) segments_with_summary.append(segment_dict) @@ -279,10 +280,10 @@ class DatasetDocumentSegmentApi(Resource): DatasetService.check_dataset_permission(dataset, current_user) except services.errors.account.NoPermissionError as e: raise Forbidden(str(e)) - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: # check embedding model setting try: - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=current_tenant_id) model_manager.get_model_instance( tenant_id=current_tenant_id, provider=dataset.embedding_model_provider, @@ -333,9 +334,9 @@ class DatasetDocumentSegmentAddApi(Resource): if not current_user.is_dataset_editor: raise Forbidden() # check embedding model setting - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: try: - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=current_tenant_id) model_manager.get_model_instance( tenant_id=current_tenant_id, provider=dataset.embedding_model_provider, @@ -383,10 +384,10 @@ class DatasetDocumentSegmentUpdateApi(Resource): document = DocumentService.get_document(dataset_id, document_id) if not document: raise NotFound("Document not found.") - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: # check embedding model setting try: - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=current_tenant_id) model_manager.get_model_instance( tenant_id=current_tenant_id, provider=dataset.embedding_model_provider, @@ -401,10 +402,10 @@ class DatasetDocumentSegmentUpdateApi(Resource): raise ProviderNotInitializeError(ex.description) # check segment segment_id = str(segment_id) - segment = ( - db.session.query(DocumentSegment) + segment = db.session.scalar( + select(DocumentSegment) .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id) - .first() + .limit(1) ) if not segment: raise NotFound("Segment not found.") @@ -447,10 +448,10 @@ class DatasetDocumentSegmentUpdateApi(Resource): raise NotFound("Document not found.") # check segment segment_id = str(segment_id) - segment = ( - db.session.query(DocumentSegment) + segment = db.session.scalar( + select(DocumentSegment) .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id) - .first() + .limit(1) ) if not segment: raise NotFound("Segment not found.") @@ -494,7 +495,7 @@ class DatasetDocumentSegmentBatchImportApi(Resource): payload = BatchImportPayload.model_validate(console_ns.payload or {}) upload_file_id = payload.upload_file_id - upload_file = db.session.query(UploadFile).where(UploadFile.id == upload_file_id).first() + upload_file = db.session.scalar(select(UploadFile).where(UploadFile.id == upload_file_id).limit(1)) if not upload_file: raise NotFound("UploadFile not found.") @@ -559,19 +560,19 @@ class ChildChunkAddApi(Resource): raise NotFound("Document not found.") # check segment segment_id = str(segment_id) - segment = ( - db.session.query(DocumentSegment) + segment = db.session.scalar( + select(DocumentSegment) .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id) - .first() + .limit(1) ) if not segment: raise NotFound("Segment not found.") if not current_user.is_dataset_editor: raise Forbidden() # check embedding model setting - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: try: - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=current_tenant_id) model_manager.get_model_instance( tenant_id=current_tenant_id, provider=dataset.embedding_model_provider, @@ -616,10 +617,10 @@ class ChildChunkAddApi(Resource): raise NotFound("Document not found.") # check segment segment_id = str(segment_id) - segment = ( - db.session.query(DocumentSegment) + segment = db.session.scalar( + select(DocumentSegment) .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id) - .first() + .limit(1) ) if not segment: raise NotFound("Segment not found.") @@ -666,10 +667,10 @@ class ChildChunkAddApi(Resource): raise NotFound("Document not found.") # check segment segment_id = str(segment_id) - segment = ( - db.session.query(DocumentSegment) + segment = db.session.scalar( + select(DocumentSegment) .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id) - .first() + .limit(1) ) if not segment: raise NotFound("Segment not found.") @@ -714,24 +715,24 @@ class ChildChunkUpdateApi(Resource): raise NotFound("Document not found.") # check segment segment_id = str(segment_id) - segment = ( - db.session.query(DocumentSegment) + segment = db.session.scalar( + select(DocumentSegment) .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id) - .first() + .limit(1) ) if not segment: raise NotFound("Segment not found.") # check child chunk child_chunk_id = str(child_chunk_id) - child_chunk = ( - db.session.query(ChildChunk) + child_chunk = db.session.scalar( + select(ChildChunk) .where( ChildChunk.id == str(child_chunk_id), ChildChunk.tenant_id == current_tenant_id, ChildChunk.segment_id == segment.id, ChildChunk.document_id == document_id, ) - .first() + .limit(1) ) if not child_chunk: raise NotFound("Child chunk not found.") @@ -771,24 +772,24 @@ class ChildChunkUpdateApi(Resource): raise NotFound("Document not found.") # check segment segment_id = str(segment_id) - segment = ( - db.session.query(DocumentSegment) + segment = db.session.scalar( + select(DocumentSegment) .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id) - .first() + .limit(1) ) if not segment: raise NotFound("Segment not found.") # check child chunk child_chunk_id = str(child_chunk_id) - child_chunk = ( - db.session.query(ChildChunk) + child_chunk = db.session.scalar( + select(ChildChunk) .where( ChildChunk.id == str(child_chunk_id), ChildChunk.tenant_id == current_tenant_id, ChildChunk.segment_id == segment.id, ChildChunk.document_id == document_id, ) - .first() + .limit(1) ) if not child_chunk: raise NotFound("Child chunk not found.") diff --git a/api/controllers/console/datasets/external.py b/api/controllers/console/datasets/external.py index 86090bcd10..fc6896f123 100644 --- a/api/controllers/console/datasets/external.py +++ b/api/controllers/console/datasets/external.py @@ -25,7 +25,7 @@ from libs.login import current_account_with_tenant, login_required from services.dataset_service import DatasetService from services.external_knowledge_service import ExternalDatasetService from services.hit_testing_service import HitTestingService -from services.knowledge_service import ExternalDatasetTestService +from services.knowledge_service import BedrockRetrievalSetting, ExternalDatasetTestService def _build_dataset_detail_model(): @@ -86,7 +86,7 @@ class ExternalHitTestingPayload(BaseModel): class BedrockRetrievalPayload(BaseModel): - retrieval_setting: dict[str, object] + retrieval_setting: "BedrockRetrievalSetting" query: str knowledge_id: str diff --git a/api/controllers/console/datasets/hit_testing_base.py b/api/controllers/console/datasets/hit_testing_base.py index cd568cf835..699fa599c8 100644 --- a/api/controllers/console/datasets/hit_testing_base.py +++ b/api/controllers/console/datasets/hit_testing_base.py @@ -19,8 +19,8 @@ from core.errors.error import ( ProviderTokenNotInitError, QuotaExceededError, ) -from dify_graph.model_runtime.errors.invoke import InvokeError from fields.hit_testing_fields import hit_testing_record_fields +from graphon.model_runtime.errors.invoke import InvokeError from libs.login import current_user from models.account import Account from services.dataset_service import DatasetService diff --git a/api/controllers/console/datasets/rag_pipeline/datasource_auth.py b/api/controllers/console/datasets/rag_pipeline/datasource_auth.py index a4498005d8..946fa599e6 100644 --- a/api/controllers/console/datasets/rag_pipeline/datasource_auth.py +++ b/api/controllers/console/datasets/rag_pipeline/datasource_auth.py @@ -10,8 +10,8 @@ from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required from core.plugin.impl.oauth import OAuthHandler -from dify_graph.model_runtime.errors.validate import CredentialsValidateFailedError -from dify_graph.model_runtime.utils.encoders import jsonable_encoder +from graphon.model_runtime.errors.validate import CredentialsValidateFailedError +from graphon.model_runtime.utils.encoders import jsonable_encoder from libs.login import current_account_with_tenant, login_required from models.provider_ids import DatasourceProviderID from services.datasource_provider_service import DatasourceProviderService diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py index c5dadb75f5..977ae93c03 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py @@ -21,11 +21,12 @@ from controllers.console.app.workflow_draft_variable import ( from controllers.console.datasets.wraps import get_rag_pipeline from controllers.console.wraps import account_initialization_required, setup_required from controllers.web.error import InvalidArgumentError, NotFoundError -from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID -from dify_graph.variables.types import SegmentType +from core.app.file_access import DatabaseFileAccessController +from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID from extensions.ext_database import db from factories.file_factory import build_from_mapping, build_from_mappings from factories.variable_factory import build_segment_with_type +from graphon.variables.types import SegmentType from libs.login import current_user, login_required from models import Account from models.dataset import Pipeline @@ -33,6 +34,7 @@ from services.rag_pipeline.rag_pipeline import RagPipelineService from services.workflow_draft_variable_service import WorkflowDraftVariableList, WorkflowDraftVariableService logger = logging.getLogger(__name__) +_file_access_controller = DatabaseFileAccessController() def _create_pagination_parser(): @@ -223,13 +225,21 @@ class RagPipelineVariableApi(Resource): if variable.value_type == SegmentType.FILE: if not isinstance(raw_value, dict): raise InvalidArgumentError(description=f"expected dict for file, got {type(raw_value)}") - raw_value = build_from_mapping(mapping=raw_value, tenant_id=pipeline.tenant_id) + raw_value = build_from_mapping( + mapping=raw_value, + tenant_id=pipeline.tenant_id, + access_controller=_file_access_controller, + ) elif variable.value_type == SegmentType.ARRAY_FILE: if not isinstance(raw_value, list): raise InvalidArgumentError(description=f"expected list for files, got {type(raw_value)}") if len(raw_value) > 0 and not isinstance(raw_value[0], dict): raise InvalidArgumentError(description=f"expected dict for files[0], got {type(raw_value)}") - raw_value = build_from_mappings(mappings=raw_value, tenant_id=pipeline.tenant_id) + raw_value = build_from_mappings( + mappings=raw_value, + tenant_id=pipeline.tenant_id, + access_controller=_file_access_controller, + ) new_value = build_segment_with_type(variable.value_type, raw_value) draft_var_srv.update_variable(variable, name=new_name, value=new_value) db.session.commit() diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py index 3912cc73ca..9079fbc29a 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py @@ -37,9 +37,9 @@ from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpErr from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.apps.pipeline.pipeline_generator import PipelineGenerator from core.app.entities.app_invoke_entities import InvokeFrom -from dify_graph.model_runtime.utils.encoders import jsonable_encoder from extensions.ext_database import db from factories import variable_factory +from graphon.model_runtime.utils.encoders import jsonable_encoder from libs import helper from libs.helper import TimestampField, UUIDStrOrEmpty from libs.login import current_account_with_tenant, current_user, login_required diff --git a/api/controllers/console/datasets/wraps.py b/api/controllers/console/datasets/wraps.py index 3ef1341abc..d533e6c5b1 100644 --- a/api/controllers/console/datasets/wraps.py +++ b/api/controllers/console/datasets/wraps.py @@ -2,6 +2,8 @@ from collections.abc import Callable from functools import wraps from typing import ParamSpec, TypeVar +from sqlalchemy import select + from controllers.console.datasets.error import PipelineNotFoundError from extensions.ext_database import db from libs.login import current_account_with_tenant @@ -24,10 +26,8 @@ def get_rag_pipeline(view_func: Callable[P, R]): del kwargs["pipeline_id"] - pipeline = ( - db.session.query(Pipeline) - .where(Pipeline.id == pipeline_id, Pipeline.tenant_id == current_tenant_id) - .first() + pipeline = db.session.scalar( + select(Pipeline).where(Pipeline.id == pipeline_id, Pipeline.tenant_id == current_tenant_id).limit(1) ) if not pipeline: diff --git a/api/controllers/console/explore/audio.py b/api/controllers/console/explore/audio.py index ffb9e5bb6e..bc78ee6d2d 100644 --- a/api/controllers/console/explore/audio.py +++ b/api/controllers/console/explore/audio.py @@ -19,7 +19,7 @@ from controllers.console.app.error import ( ) from controllers.console.explore.wraps import InstalledAppResource from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError -from dify_graph.model_runtime.errors.invoke import InvokeError +from graphon.model_runtime.errors.invoke import InvokeError from services.audio_service import AudioService from services.errors.audio import ( AudioTooLargeServiceError, diff --git a/api/controllers/console/explore/completion.py b/api/controllers/console/explore/completion.py index fcd52d2818..ccdccceaa6 100644 --- a/api/controllers/console/explore/completion.py +++ b/api/controllers/console/explore/completion.py @@ -24,8 +24,8 @@ from core.errors.error import ( ProviderTokenNotInitError, QuotaExceededError, ) -from dify_graph.model_runtime.errors.invoke import InvokeError from extensions.ext_database import db +from graphon.model_runtime.errors.invoke import InvokeError from libs import helper from libs.datetime_utils import naive_utc_now from libs.login import current_user diff --git a/api/controllers/console/explore/message.py b/api/controllers/console/explore/message.py index 15e1aea361..a72cf6328a 100644 --- a/api/controllers/console/explore/message.py +++ b/api/controllers/console/explore/message.py @@ -21,9 +21,9 @@ from controllers.console.explore.error import ( from controllers.console.explore.wraps import InstalledAppResource from core.app.entities.app_invoke_entities import InvokeFrom from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError -from dify_graph.model_runtime.errors.invoke import InvokeError from fields.conversation_fields import ResultResponse from fields.message_fields import MessageInfiniteScrollPagination, MessageListItem, SuggestedQuestionsResponse +from graphon.model_runtime.errors.invoke import InvokeError from libs import helper from libs.helper import UUIDStrOrEmpty from libs.login import current_account_with_tenant diff --git a/api/controllers/console/explore/trial.py b/api/controllers/console/explore/trial.py index a8d8036f0f..26aa086aac 100644 --- a/api/controllers/console/explore/trial.py +++ b/api/controllers/console/explore/trial.py @@ -42,8 +42,6 @@ from core.errors.error import ( ProviderTokenNotInitError, QuotaExceededError, ) -from dify_graph.graph_engine.manager import GraphEngineManager -from dify_graph.model_runtime.errors.invoke import InvokeError from extensions.ext_database import db from extensions.ext_redis import redis_client from fields.app_fields import ( @@ -61,6 +59,8 @@ from fields.workflow_fields import ( workflow_fields, workflow_partial_fields, ) +from graphon.graph_engine.manager import GraphEngineManager +from graphon.model_runtime.errors.invoke import InvokeError from libs import helper from libs.helper import uuid_value from libs.login import current_user diff --git a/api/controllers/console/explore/workflow.py b/api/controllers/console/explore/workflow.py index 7801cee473..17dbbdd534 100644 --- a/api/controllers/console/explore/workflow.py +++ b/api/controllers/console/explore/workflow.py @@ -21,9 +21,9 @@ from core.errors.error import ( ProviderTokenNotInitError, QuotaExceededError, ) -from dify_graph.graph_engine.manager import GraphEngineManager -from dify_graph.model_runtime.errors.invoke import InvokeError from extensions.ext_redis import redis_client +from graphon.graph_engine.manager import GraphEngineManager +from graphon.model_runtime.errors.invoke import InvokeError from libs import helper from libs.login import current_account_with_tenant from models.model import AppMode, InstalledApp diff --git a/api/controllers/console/remote_files.py b/api/controllers/console/remote_files.py index 49162d4dae..2a46d2250a 100644 --- a/api/controllers/console/remote_files.py +++ b/api/controllers/console/remote_files.py @@ -13,9 +13,9 @@ from controllers.common.errors import ( ) from controllers.console import console_ns from core.helper import ssrf_proxy -from dify_graph.file import helpers as file_helpers from extensions.ext_database import db from fields.file_fields import FileWithSignedUrl, RemoteFileInfo +from graphon.file import helpers as file_helpers from libs.login import current_account_with_tenant, login_required from services.file_service import FileService diff --git a/api/controllers/console/workspace/agent_providers.py b/api/controllers/console/workspace/agent_providers.py index e2b504751b..764f488755 100644 --- a/api/controllers/console/workspace/agent_providers.py +++ b/api/controllers/console/workspace/agent_providers.py @@ -2,7 +2,7 @@ from flask_restx import Resource, fields from controllers.console import console_ns from controllers.console.wraps import account_initialization_required, setup_required -from dify_graph.model_runtime.utils.encoders import jsonable_encoder +from graphon.model_runtime.utils.encoders import jsonable_encoder from libs.login import current_account_with_tenant, login_required from services.agent_service import AgentService diff --git a/api/controllers/console/workspace/endpoint.py b/api/controllers/console/workspace/endpoint.py index 538c5fb561..f45b72f390 100644 --- a/api/controllers/console/workspace/endpoint.py +++ b/api/controllers/console/workspace/endpoint.py @@ -8,7 +8,7 @@ from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required from core.plugin.impl.exc import PluginPermissionDeniedError -from dify_graph.model_runtime.utils.encoders import jsonable_encoder +from graphon.model_runtime.utils.encoders import jsonable_encoder from libs.login import current_account_with_tenant, login_required from services.plugin.endpoint_service import EndpointService diff --git a/api/controllers/console/workspace/load_balancing_config.py b/api/controllers/console/workspace/load_balancing_config.py index 0a9e54de99..2a6f37aec8 100644 --- a/api/controllers/console/workspace/load_balancing_config.py +++ b/api/controllers/console/workspace/load_balancing_config.py @@ -5,8 +5,8 @@ from werkzeug.exceptions import Forbidden from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.wraps import account_initialization_required, setup_required -from dify_graph.model_runtime.entities.model_entities import ModelType -from dify_graph.model_runtime.errors.validate import CredentialsValidateFailedError +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.errors.validate import CredentialsValidateFailedError from libs.login import current_account_with_tenant, login_required from models import TenantAccountRole from services.model_load_balancing_service import ModelLoadBalancingService diff --git a/api/controllers/console/workspace/model_providers.py b/api/controllers/console/workspace/model_providers.py index db3b02ae94..b22b91706e 100644 --- a/api/controllers/console/workspace/model_providers.py +++ b/api/controllers/console/workspace/model_providers.py @@ -7,9 +7,9 @@ from pydantic import BaseModel, Field, field_validator from controllers.console import console_ns from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required -from dify_graph.model_runtime.entities.model_entities import ModelType -from dify_graph.model_runtime.errors.validate import CredentialsValidateFailedError -from dify_graph.model_runtime.utils.encoders import jsonable_encoder +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.errors.validate import CredentialsValidateFailedError +from graphon.model_runtime.utils.encoders import jsonable_encoder from libs.helper import uuid_value from libs.login import current_account_with_tenant, login_required from services.billing_service import BillingService diff --git a/api/controllers/console/workspace/models.py b/api/controllers/console/workspace/models.py index d7eceb656c..3c7b97d7fc 100644 --- a/api/controllers/console/workspace/models.py +++ b/api/controllers/console/workspace/models.py @@ -8,9 +8,9 @@ from pydantic import BaseModel, Field, field_validator from controllers.common.schema import register_enum_models, register_schema_models from controllers.console import console_ns from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required -from dify_graph.model_runtime.entities.model_entities import ModelType -from dify_graph.model_runtime.errors.validate import CredentialsValidateFailedError -from dify_graph.model_runtime.utils.encoders import jsonable_encoder +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.errors.validate import CredentialsValidateFailedError +from graphon.model_runtime.utils.encoders import jsonable_encoder from libs.helper import uuid_value from libs.login import current_account_with_tenant, login_required from services.model_load_balancing_service import ModelLoadBalancingService @@ -282,14 +282,18 @@ class ModelProviderModelCredentialApi(Resource): ) if args.config_from == "predefined-model": - available_credentials = model_provider_service.provider_manager.get_provider_available_credentials( - tenant_id=tenant_id, provider_name=provider + available_credentials = model_provider_service.get_provider_available_credentials( + tenant_id=tenant_id, + provider=provider, ) else: # Normalize model_type to the origin value stored in DB (e.g., "text-generation" for LLM) normalized_model_type = args.model_type.to_origin_model_type() - available_credentials = model_provider_service.provider_manager.get_provider_model_available_credentials( - tenant_id=tenant_id, provider_name=provider, model_type=normalized_model_type, model_name=args.model + available_credentials = model_provider_service.get_provider_model_available_credentials( + tenant_id=tenant_id, + provider=provider, + model_type=normalized_model_type, + model=args.model, ) return jsonable_encoder( diff --git a/api/controllers/console/workspace/plugin.py b/api/controllers/console/workspace/plugin.py index ee537367c7..6564ff5e7f 100644 --- a/api/controllers/console/workspace/plugin.py +++ b/api/controllers/console/workspace/plugin.py @@ -14,7 +14,7 @@ from controllers.console import console_ns from controllers.console.workspace import plugin_permission_required from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required from core.plugin.impl.exc import PluginDaemonClientSideError -from dify_graph.model_runtime.utils.encoders import jsonable_encoder +from graphon.model_runtime.utils.encoders import jsonable_encoder from libs.login import current_account_with_tenant, login_required from models.account import TenantPluginAutoUpgradeStrategy, TenantPluginPermission from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService diff --git a/api/controllers/console/workspace/tool_providers.py b/api/controllers/console/workspace/tool_providers.py index b38f05795a..1273b85bc3 100644 --- a/api/controllers/console/workspace/tool_providers.py +++ b/api/controllers/console/workspace/tool_providers.py @@ -26,8 +26,8 @@ from core.mcp.mcp_client import MCPClient from core.plugin.entities.plugin_daemon import CredentialType from core.plugin.impl.oauth import OAuthHandler from core.tools.entities.tool_entities import ApiProviderSchemaType, WorkflowToolParameterConfiguration -from dify_graph.model_runtime.utils.encoders import jsonable_encoder from extensions.ext_database import db +from graphon.model_runtime.utils.encoders import jsonable_encoder from libs.helper import alphanumeric, uuid_value from libs.login import current_account_with_tenant, login_required from models.provider_ids import ToolProviderID diff --git a/api/controllers/console/workspace/trigger_providers.py b/api/controllers/console/workspace/trigger_providers.py index ad78d2a623..feedf074b7 100644 --- a/api/controllers/console/workspace/trigger_providers.py +++ b/api/controllers/console/workspace/trigger_providers.py @@ -14,8 +14,8 @@ from core.plugin.entities.plugin_daemon import CredentialType from core.plugin.impl.oauth import OAuthHandler from core.trigger.entities.entities import SubscriptionBuilderUpdater from core.trigger.trigger_manager import TriggerManager -from dify_graph.model_runtime.utils.encoders import jsonable_encoder from extensions.ext_database import db +from graphon.model_runtime.utils.encoders import jsonable_encoder from libs.login import current_user, login_required from models.account import Account from models.provider_ids import TriggerProviderID diff --git a/api/controllers/files/tool_files.py b/api/controllers/files/tool_files.py index 9e3fb3a90b..2f1e2f28bd 100644 --- a/api/controllers/files/tool_files.py +++ b/api/controllers/files/tool_files.py @@ -70,22 +70,25 @@ class ToolFileApi(Resource): except Exception: raise UnsupportedFileTypeError() + mime_type = tool_file.mime_type + filename = tool_file.filename + response = Response( stream, - mimetype=tool_file.mimetype, + mimetype=mime_type, direct_passthrough=True, headers={}, ) if tool_file.size > 0: response.headers["Content-Length"] = str(tool_file.size) - if args.as_attachment: - encoded_filename = quote(tool_file.name) + if args.as_attachment and filename: + encoded_filename = quote(filename) response.headers["Content-Disposition"] = f"attachment; filename*=UTF-8''{encoded_filename}" enforce_download_for_html( response, - mime_type=tool_file.mimetype, - filename=tool_file.name, + mime_type=mime_type, + filename=filename, extension=extension, ) diff --git a/api/controllers/files/upload.py b/api/controllers/files/upload.py index 52690a12e1..ed3278a28b 100644 --- a/api/controllers/files/upload.py +++ b/api/controllers/files/upload.py @@ -7,8 +7,8 @@ from pydantic import BaseModel, Field from werkzeug.exceptions import Forbidden import services +from core.tools.signature import verify_plugin_file_signature from core.tools.tool_file_manager import ToolFileManager -from dify_graph.file.helpers import verify_plugin_file_signature from fields.file_fields import FileResponse from ..common.errors import ( diff --git a/api/controllers/inner_api/__init__.py b/api/controllers/inner_api/__init__.py index 74005217ef..b38994f055 100644 --- a/api/controllers/inner_api/__init__.py +++ b/api/controllers/inner_api/__init__.py @@ -16,12 +16,14 @@ api = ExternalApi( inner_api_ns = Namespace("inner_api", description="Internal API operations", path="/") from . import mail as _mail +from .app import dsl as _app_dsl from .plugin import plugin as _plugin from .workspace import workspace as _workspace api.add_namespace(inner_api_ns) __all__ = [ + "_app_dsl", "_mail", "_plugin", "_workspace", diff --git a/api/controllers/inner_api/app/__init__.py b/api/controllers/inner_api/app/__init__.py new file mode 100644 index 0000000000..8b13789179 --- /dev/null +++ b/api/controllers/inner_api/app/__init__.py @@ -0,0 +1 @@ + diff --git a/api/controllers/inner_api/app/dsl.py b/api/controllers/inner_api/app/dsl.py new file mode 100644 index 0000000000..56730cf37a --- /dev/null +++ b/api/controllers/inner_api/app/dsl.py @@ -0,0 +1,110 @@ +"""Inner API endpoints for app DSL import/export. + +Called by the enterprise admin-api service. Import requires ``creator_email`` +to attribute the created app; workspace/membership validation is done by the +Go admin-api caller. +""" + +from flask import request +from flask_restx import Resource +from pydantic import BaseModel, Field +from sqlalchemy.orm import Session + +from controllers.common.schema import register_schema_model +from controllers.console.wraps import setup_required +from controllers.inner_api import inner_api_ns +from controllers.inner_api.wraps import enterprise_inner_api_only +from extensions.ext_database import db +from models import Account, App +from models.account import AccountStatus +from services.app_dsl_service import AppDslService, ImportMode, ImportStatus + + +class InnerAppDSLImportPayload(BaseModel): + yaml_content: str = Field(description="YAML DSL content") + creator_email: str = Field(description="Email of the workspace member who will own the imported app") + name: str | None = Field(default=None, description="Override app name from DSL") + description: str | None = Field(default=None, description="Override app description from DSL") + + +register_schema_model(inner_api_ns, InnerAppDSLImportPayload) + + +@inner_api_ns.route("/enterprise/workspaces//dsl/import") +class EnterpriseAppDSLImport(Resource): + @setup_required + @enterprise_inner_api_only + @inner_api_ns.doc("enterprise_app_dsl_import") + @inner_api_ns.expect(inner_api_ns.models[InnerAppDSLImportPayload.__name__]) + @inner_api_ns.doc( + responses={ + 200: "Import completed", + 202: "Import pending (DSL version mismatch requires confirmation)", + 400: "Import failed (business error)", + 404: "Creator account not found or inactive", + } + ) + def post(self, workspace_id: str): + """Import a DSL into a workspace on behalf of a specified creator.""" + args = InnerAppDSLImportPayload.model_validate(inner_api_ns.payload or {}) + + account = _get_active_account(args.creator_email) + if account is None: + return {"message": f"account '{args.creator_email}' not found or inactive"}, 404 + + account.set_tenant_id(workspace_id) + + with Session(db.engine) as session: + dsl_service = AppDslService(session) + result = dsl_service.import_app( + account=account, + import_mode=ImportMode.YAML_CONTENT, + yaml_content=args.yaml_content, + name=args.name, + description=args.description, + ) + session.commit() + + if result.status == ImportStatus.FAILED: + return result.model_dump(mode="json"), 400 + if result.status == ImportStatus.PENDING: + return result.model_dump(mode="json"), 202 + return result.model_dump(mode="json"), 200 + + +@inner_api_ns.route("/enterprise/apps//dsl") +class EnterpriseAppDSLExport(Resource): + @setup_required + @enterprise_inner_api_only + @inner_api_ns.doc( + "enterprise_app_dsl_export", + responses={ + 200: "Export successful", + 404: "App not found", + }, + ) + def get(self, app_id: str): + """Export an app's DSL as YAML.""" + include_secret = request.args.get("include_secret", "false").lower() == "true" + + app_model = db.session.query(App).filter_by(id=app_id).first() + if not app_model: + return {"message": "app not found"}, 404 + + data = AppDslService.export_dsl( + app_model=app_model, + include_secret=include_secret, + ) + + return {"data": data}, 200 + + +def _get_active_account(email: str) -> Account | None: + """Look up an active account by email. + + Workspace membership is already validated by the Go admin-api caller. + """ + account = db.session.query(Account).filter_by(email=email).first() + if account is None or account.status != AccountStatus.ACTIVE: + return None + return account diff --git a/api/controllers/inner_api/plugin/plugin.py b/api/controllers/inner_api/plugin/plugin.py index 9b8b3950e6..72cab3de73 100644 --- a/api/controllers/inner_api/plugin/plugin.py +++ b/api/controllers/inner_api/plugin/plugin.py @@ -28,8 +28,8 @@ from core.plugin.entities.request import ( RequestRequestUploadFile, ) from core.tools.entities.tool_entities import ToolProviderType -from dify_graph.file.helpers import get_signed_file_url_for_plugin -from dify_graph.model_runtime.utils.encoders import jsonable_encoder +from core.tools.signature import get_signed_file_url_for_plugin +from graphon.model_runtime.utils.encoders import jsonable_encoder from libs.helper import length_prefixed_response from models import Account, Tenant from models.model import EndUser diff --git a/api/controllers/mcp/mcp.py b/api/controllers/mcp/mcp.py index 9ddaaa315b..869fb73cf5 100644 --- a/api/controllers/mcp/mcp.py +++ b/api/controllers/mcp/mcp.py @@ -9,8 +9,8 @@ from controllers.common.schema import register_schema_model from controllers.mcp import mcp_ns from core.mcp import types as mcp_types from core.mcp.server.streamable_http import handle_mcp_request -from dify_graph.variables.input_entities import VariableEntity from extensions.ext_database import db +from graphon.variables.input_entities import VariableEntity from libs import helper from models.enums import AppMCPServerStatus from models.model import App, AppMCPServer, AppMode, EndUser diff --git a/api/controllers/service_api/app/audio.py b/api/controllers/service_api/app/audio.py index 38d292d0b9..86d88ddafb 100644 --- a/api/controllers/service_api/app/audio.py +++ b/api/controllers/service_api/app/audio.py @@ -21,7 +21,7 @@ from controllers.service_api.app.error import ( ) from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError -from dify_graph.model_runtime.errors.invoke import InvokeError +from graphon.model_runtime.errors.invoke import InvokeError from models.model import App, EndUser from services.audio_service import AudioService from services.errors.audio import ( diff --git a/api/controllers/service_api/app/completion.py b/api/controllers/service_api/app/completion.py index 98f09c44a1..31f2797d66 100644 --- a/api/controllers/service_api/app/completion.py +++ b/api/controllers/service_api/app/completion.py @@ -28,7 +28,7 @@ from core.errors.error import ( QuotaExceededError, ) from core.helper.trace_id_helper import get_external_trace_id -from dify_graph.model_runtime.errors.invoke import InvokeError +from graphon.model_runtime.errors.invoke import InvokeError from libs import helper from libs.helper import UUIDStrOrEmpty from models.model import App, AppMode, EndUser diff --git a/api/controllers/service_api/app/file_preview.py b/api/controllers/service_api/app/file_preview.py index f853a124ef..5e7847d784 100644 --- a/api/controllers/service_api/app/file_preview.py +++ b/api/controllers/service_api/app/file_preview.py @@ -4,6 +4,7 @@ from urllib.parse import quote from flask import Response, request from flask_restx import Resource from pydantic import BaseModel, Field +from sqlalchemy import select from controllers.common.file_response import enforce_download_for_html from controllers.common.schema import register_schema_model @@ -102,27 +103,27 @@ class FilePreviewApi(Resource): raise FileAccessDeniedError("Invalid file or app identifier") # First, find the MessageFile that references this upload file - message_file = db.session.query(MessageFile).where(MessageFile.upload_file_id == file_id).first() + message_file = db.session.scalar(select(MessageFile).where(MessageFile.upload_file_id == file_id).limit(1)) if not message_file: raise FileNotFoundError("File not found in message context") # Get the message and verify it belongs to the requesting app - message = ( - db.session.query(Message).where(Message.id == message_file.message_id, Message.app_id == app_id).first() + message = db.session.scalar( + select(Message).where(Message.id == message_file.message_id, Message.app_id == app_id).limit(1) ) if not message: raise FileAccessDeniedError("File access denied: not owned by requesting app") # Get the actual upload file record - upload_file = db.session.query(UploadFile).where(UploadFile.id == file_id).first() + upload_file = db.session.get(UploadFile, file_id) if not upload_file: raise FileNotFoundError("Upload file record not found") # Additional security: verify tenant isolation - app = db.session.query(App).where(App.id == app_id).first() + app = db.session.get(App, app_id) if app and upload_file.tenant_id != app.tenant_id: raise FileAccessDeniedError("File access denied: tenant mismatch") diff --git a/api/controllers/service_api/app/site.py b/api/controllers/service_api/app/site.py index 8b47a887bb..bc06e8f386 100644 --- a/api/controllers/service_api/app/site.py +++ b/api/controllers/service_api/app/site.py @@ -1,4 +1,5 @@ from flask_restx import Resource +from sqlalchemy import select from werkzeug.exceptions import Forbidden from controllers.common.fields import Site as SiteResponse @@ -28,7 +29,7 @@ class AppSiteApi(Resource): Returns the site configuration for the application including theme, icons, and text. """ - site = db.session.query(Site).where(Site.app_id == app_model.id).first() + site = db.session.scalar(select(Site).where(Site.app_id == app_model.id).limit(1)) if not site: raise Forbidden() diff --git a/api/controllers/service_api/app/workflow.py b/api/controllers/service_api/app/workflow.py index 35dd22c801..94afd47f7f 100644 --- a/api/controllers/service_api/app/workflow.py +++ b/api/controllers/service_api/app/workflow.py @@ -27,12 +27,12 @@ from core.errors.error import ( QuotaExceededError, ) from core.helper.trace_id_helper import get_external_trace_id -from dify_graph.enums import WorkflowExecutionStatus -from dify_graph.graph_engine.manager import GraphEngineManager -from dify_graph.model_runtime.errors.invoke import InvokeError from extensions.ext_database import db from extensions.ext_redis import redis_client from fields.workflow_app_log_fields import build_workflow_app_log_pagination_model +from graphon.enums import WorkflowExecutionStatus +from graphon.graph_engine.manager import GraphEngineManager +from graphon.model_runtime.errors.invoke import InvokeError from libs import helper from libs.helper import OptionalTimestampField, TimestampField from models.model import App, AppMode, EndUser diff --git a/api/controllers/service_api/dataset/dataset.py b/api/controllers/service_api/dataset/dataset.py index 83d07087ab..dcf788f7a8 100644 --- a/api/controllers/service_api/dataset/dataset.py +++ b/api/controllers/service_api/dataset/dataset.py @@ -14,10 +14,11 @@ from controllers.service_api.wraps import ( DatasetApiResource, cloud_edition_billing_rate_limit_check, ) -from core.provider_manager import ProviderManager -from dify_graph.model_runtime.entities.model_entities import ModelType +from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager +from core.rag.index_processor.constant.index_type import IndexTechniqueType from fields.dataset_fields import dataset_detail_fields from fields.tag_fields import DataSetTag +from graphon.model_runtime.entities.model_entities import ModelType from libs.login import current_user from models.account import Account from models.dataset import DatasetPermissionEnum @@ -139,10 +140,10 @@ class DatasetListApi(DatasetApiResource): query.page, query.limit, tenant_id, current_user, query.keyword, query.tag_ids, query.include_all ) # check embedding setting - provider_manager = ProviderManager() assert isinstance(current_user, Account) cid = current_user.current_tenant_id assert cid is not None + provider_manager = create_plugin_provider_manager(tenant_id=cid) configurations = provider_manager.get_configurations(tenant_id=cid) embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True) @@ -153,15 +154,20 @@ class DatasetListApi(DatasetApiResource): data = marshal(datasets, dataset_detail_fields) for item in data: - if item["indexing_technique"] == "high_quality" and item["embedding_model_provider"]: - item["embedding_model_provider"] = str(ModelProviderID(item["embedding_model_provider"])) - item_model = f"{item['embedding_model']}:{item['embedding_model_provider']}" + if ( + item["indexing_technique"] == IndexTechniqueType.HIGH_QUALITY # pyrefly: ignore[bad-index] + and item["embedding_model_provider"] # pyrefly: ignore[bad-index] + ): + item["embedding_model_provider"] = str( # pyrefly: ignore[unsupported-operation] + ModelProviderID(item["embedding_model_provider"]) # pyrefly: ignore[bad-index] + ) + item_model = f"{item['embedding_model']}:{item['embedding_model_provider']}" # pyrefly: ignore[bad-index] if item_model in model_names: - item["embedding_available"] = True + item["embedding_available"] = True # type: ignore else: - item["embedding_available"] = False + item["embedding_available"] = False # type: ignore else: - item["embedding_available"] = True + item["embedding_available"] = True # type: ignore response = { "data": data, "has_more": len(datasets) == query.limit, @@ -253,10 +259,10 @@ class DatasetApi(DatasetApiResource): raise Forbidden(str(e)) data = cast(dict[str, Any], marshal(dataset, dataset_detail_fields)) # check embedding setting - provider_manager = ProviderManager() assert isinstance(current_user, Account) cid = current_user.current_tenant_id assert cid is not None + provider_manager = create_plugin_provider_manager(tenant_id=cid) configurations = provider_manager.get_configurations(tenant_id=cid) embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True) @@ -265,7 +271,7 @@ class DatasetApi(DatasetApiResource): for embedding_model in embedding_models: model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}") - if data.get("indexing_technique") == "high_quality": + if data.get("indexing_technique") == IndexTechniqueType.HIGH_QUALITY: item_model = f"{data.get('embedding_model')}:{data.get('embedding_model_provider')}" if item_model in model_names: data["embedding_available"] = True @@ -315,7 +321,7 @@ class DatasetApi(DatasetApiResource): # check embedding model setting embedding_model_provider = payload.embedding_model_provider embedding_model = payload.embedding_model - if payload.indexing_technique == "high_quality" or embedding_model_provider: + if payload.indexing_technique == IndexTechniqueType.HIGH_QUALITY or embedding_model_provider: if embedding_model_provider and embedding_model: DatasetService.check_embedding_model_setting( dataset.tenant_id, embedding_model_provider, embedding_model diff --git a/api/controllers/service_api/dataset/document.py b/api/controllers/service_api/dataset/document.py index d34b4124ae..2c094aa3e6 100644 --- a/api/controllers/service_api/dataset/document.py +++ b/api/controllers/service_api/dataset/document.py @@ -6,7 +6,7 @@ from uuid import UUID from flask import request, send_file from flask_restx import marshal from pydantic import BaseModel, Field, field_validator, model_validator -from sqlalchemy import desc, select +from sqlalchemy import desc, func, select from werkzeug.exceptions import Forbidden, NotFound import services @@ -155,7 +155,9 @@ class DocumentAddByTextApi(DatasetApiResource): dataset_id = str(dataset_id) tenant_id = str(tenant_id) - dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() + dataset = db.session.scalar( + select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1) + ) if not dataset: raise ValueError("Dataset does not exist.") @@ -238,7 +240,9 @@ class DocumentUpdateByTextApi(DatasetApiResource): def post(self, tenant_id: str, dataset_id: UUID, document_id: UUID): """Update document by text.""" payload = DocumentTextUpdate.model_validate(service_api_ns.payload or {}) - dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == str(dataset_id)).first() + dataset = db.session.scalar( + select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == str(dataset_id)).limit(1) + ) args = payload.model_dump(exclude_none=True) if not dataset: raise ValueError("Dataset does not exist.") @@ -315,7 +319,9 @@ class DocumentAddByFileApi(DatasetApiResource): @cloud_edition_billing_rate_limit_check("knowledge", "dataset") def post(self, tenant_id, dataset_id): """Create document by upload file.""" - dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() + dataset = db.session.scalar( + select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1) + ) if not dataset: raise ValueError("Dataset does not exist.") @@ -425,7 +431,9 @@ class DocumentUpdateByFileApi(DatasetApiResource): @cloud_edition_billing_rate_limit_check("knowledge", "dataset") def post(self, tenant_id, dataset_id, document_id): """Update document by upload file.""" - dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() + dataset = db.session.scalar( + select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1) + ) if not dataset: raise ValueError("Dataset does not exist.") @@ -515,7 +523,9 @@ class DocumentListApi(DatasetApiResource): dataset_id = str(dataset_id) tenant_id = str(tenant_id) query_params = DocumentListQuery.model_validate(request.args.to_dict()) - dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() + dataset = db.session.scalar( + select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1) + ) if not dataset: raise NotFound("Dataset not found.") @@ -609,7 +619,9 @@ class DocumentIndexingStatusApi(DatasetApiResource): batch = str(batch) tenant_id = str(tenant_id) # get dataset - dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() + dataset = db.session.scalar( + select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1) + ) if not dataset: raise NotFound("Dataset not found.") # get documents @@ -619,20 +631,23 @@ class DocumentIndexingStatusApi(DatasetApiResource): documents_status = [] for document in documents: completed_segments = ( - db.session.query(DocumentSegment) - .where( - DocumentSegment.completed_at.isnot(None), - DocumentSegment.document_id == str(document.id), - DocumentSegment.status != SegmentStatus.RE_SEGMENT, + db.session.scalar( + select(func.count(DocumentSegment.id)).where( + DocumentSegment.completed_at.isnot(None), + DocumentSegment.document_id == str(document.id), + DocumentSegment.status != SegmentStatus.RE_SEGMENT, + ) ) - .count() + or 0 ) total_segments = ( - db.session.query(DocumentSegment) - .where( - DocumentSegment.document_id == str(document.id), DocumentSegment.status != SegmentStatus.RE_SEGMENT + db.session.scalar( + select(func.count(DocumentSegment.id)).where( + DocumentSegment.document_id == str(document.id), + DocumentSegment.status != SegmentStatus.RE_SEGMENT, + ) ) - .count() + or 0 ) # Create a dictionary with document attributes and additional fields document_dict = { @@ -822,7 +837,9 @@ class DocumentApi(DatasetApiResource): tenant_id = str(tenant_id) # get dataset info - dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() + dataset = db.session.scalar( + select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1) + ) if not dataset: raise ValueError("Dataset does not exist.") diff --git a/api/controllers/service_api/dataset/segment.py b/api/controllers/service_api/dataset/segment.py index 2e3b7fd85e..28fa915117 100644 --- a/api/controllers/service_api/dataset/segment.py +++ b/api/controllers/service_api/dataset/segment.py @@ -3,6 +3,7 @@ from typing import Any from flask import request from flask_restx import marshal from pydantic import BaseModel, Field +from sqlalchemy import select from werkzeug.exceptions import NotFound from configs import dify_config @@ -17,9 +18,10 @@ from controllers.service_api.wraps import ( ) from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError from core.model_manager import ModelManager -from dify_graph.model_runtime.entities.model_entities import ModelType +from core.rag.index_processor.constant.index_type import IndexTechniqueType from extensions.ext_database import db from fields.segment_fields import child_chunk_fields, segment_fields +from graphon.model_runtime.entities.model_entities import ModelType from libs.login import current_account_with_tenant from models.dataset import Dataset from services.dataset_service import DatasetService, DocumentService, SegmentService @@ -91,7 +93,9 @@ class SegmentApi(DatasetApiResource): _, current_tenant_id = current_account_with_tenant() """Create single segment.""" # check dataset - dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() + dataset = db.session.scalar( + select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1) + ) if not dataset: raise NotFound("Dataset not found.") # check document @@ -103,9 +107,9 @@ class SegmentApi(DatasetApiResource): if not document.enabled: raise NotFound("Document is disabled.") # check embedding model setting - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: try: - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=current_tenant_id) model_manager.get_model_instance( tenant_id=current_tenant_id, provider=dataset.embedding_model_provider, @@ -149,7 +153,9 @@ class SegmentApi(DatasetApiResource): # check dataset page = request.args.get("page", default=1, type=int) limit = request.args.get("limit", default=20, type=int) - dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() + dataset = db.session.scalar( + select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1) + ) if not dataset: raise NotFound("Dataset not found.") # check document @@ -157,9 +163,9 @@ class SegmentApi(DatasetApiResource): if not document: raise NotFound("Document not found.") # check embedding model setting - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: try: - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=current_tenant_id) model_manager.get_model_instance( tenant_id=current_tenant_id, provider=dataset.embedding_model_provider, @@ -219,7 +225,9 @@ class DatasetSegmentApi(DatasetApiResource): def delete(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str): _, current_tenant_id = current_account_with_tenant() # check dataset - dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() + dataset = db.session.scalar( + select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1) + ) if not dataset: raise NotFound("Dataset not found.") # check user's model setting @@ -253,7 +261,9 @@ class DatasetSegmentApi(DatasetApiResource): def post(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str): _, current_tenant_id = current_account_with_tenant() # check dataset - dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() + dataset = db.session.scalar( + select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1) + ) if not dataset: raise NotFound("Dataset not found.") # check user's model setting @@ -262,10 +272,10 @@ class DatasetSegmentApi(DatasetApiResource): document = DocumentService.get_document(dataset_id, document_id) if not document: raise NotFound("Document not found.") - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: # check embedding model setting try: - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=current_tenant_id) model_manager.get_model_instance( tenant_id=current_tenant_id, provider=dataset.embedding_model_provider, @@ -300,7 +310,9 @@ class DatasetSegmentApi(DatasetApiResource): def get(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str): _, current_tenant_id = current_account_with_tenant() # check dataset - dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() + dataset = db.session.scalar( + select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1) + ) if not dataset: raise NotFound("Dataset not found.") # check user's model setting @@ -343,7 +355,9 @@ class ChildChunkApi(DatasetApiResource): _, current_tenant_id = current_account_with_tenant() """Create child chunk.""" # check dataset - dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() + dataset = db.session.scalar( + select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1) + ) if not dataset: raise NotFound("Dataset not found.") @@ -358,9 +372,9 @@ class ChildChunkApi(DatasetApiResource): raise NotFound("Segment not found.") # check embedding model setting - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: try: - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=current_tenant_id) model_manager.get_model_instance( tenant_id=current_tenant_id, provider=dataset.embedding_model_provider, @@ -401,7 +415,9 @@ class ChildChunkApi(DatasetApiResource): _, current_tenant_id = current_account_with_tenant() """Get child chunks.""" # check dataset - dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() + dataset = db.session.scalar( + select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1) + ) if not dataset: raise NotFound("Dataset not found.") @@ -467,7 +483,9 @@ class DatasetChildChunkApi(DatasetApiResource): _, current_tenant_id = current_account_with_tenant() """Delete child chunk.""" # check dataset - dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() + dataset = db.session.scalar( + select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1) + ) if not dataset: raise NotFound("Dataset not found.") @@ -526,7 +544,9 @@ class DatasetChildChunkApi(DatasetApiResource): _, current_tenant_id = current_account_with_tenant() """Update child chunk.""" # check dataset - dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() + dataset = db.session.scalar( + select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1) + ) if not dataset: raise NotFound("Dataset not found.") diff --git a/api/controllers/service_api/workspace/models.py b/api/controllers/service_api/workspace/models.py index 35aed40a59..5ac65fc4e6 100644 --- a/api/controllers/service_api/workspace/models.py +++ b/api/controllers/service_api/workspace/models.py @@ -3,7 +3,7 @@ from flask_restx import Resource from controllers.service_api import service_api_ns from controllers.service_api.wraps import validate_dataset_token -from dify_graph.model_runtime.utils.encoders import jsonable_encoder +from graphon.model_runtime.utils.encoders import jsonable_encoder from services.model_provider_service import ModelProviderService diff --git a/api/controllers/service_api/wraps.py b/api/controllers/service_api/wraps.py index 7aa5b2f092..1d52b8a737 100644 --- a/api/controllers/service_api/wraps.py +++ b/api/controllers/service_api/wraps.py @@ -9,6 +9,7 @@ from flask import current_app, request from flask_login import user_logged_in from flask_restx import Resource from pydantic import BaseModel +from sqlalchemy import select from werkzeug.exceptions import Forbidden, NotFound, Unauthorized from enums.cloud_plan import CloudPlan @@ -62,7 +63,7 @@ def validate_app_token( def decorated_view(*args: P.args, **kwargs: P.kwargs) -> R: api_token = validate_and_get_api_token("app") - app_model = db.session.query(App).where(App.id == api_token.app_id).first() + app_model = db.session.get(App, api_token.app_id) if not app_model: raise Forbidden("The app no longer exists.") @@ -72,7 +73,7 @@ def validate_app_token( if not app_model.enable_api: raise Forbidden("The app's API service has been disabled.") - tenant = db.session.query(Tenant).where(Tenant.id == app_model.tenant_id).first() + tenant = db.session.get(Tenant, app_model.tenant_id) if tenant is None: raise ValueError("Tenant does not exist.") if tenant.status == TenantStatus.ARCHIVE: @@ -106,8 +107,8 @@ def validate_app_token( else: # For service API without end-user context, ensure an Account is logged in # so services relying on current_account_with_tenant() work correctly. - tenant_owner_info = ( - db.session.query(Tenant, Account) + tenant_owner_info = db.session.execute( + select(Tenant, Account) .join(TenantAccountJoin, Tenant.id == TenantAccountJoin.tenant_id) .join(Account, TenantAccountJoin.account_id == Account.id) .where( @@ -115,8 +116,7 @@ def validate_app_token( TenantAccountJoin.role == "owner", Tenant.status == TenantStatus.NORMAL, ) - .one_or_none() - ) + ).one_or_none() if tenant_owner_info: tenant_model, account = tenant_owner_info @@ -277,29 +277,28 @@ def validate_dataset_token( # Validate dataset if dataset_id is provided if dataset_id: dataset_id = str(dataset_id) - dataset = ( - db.session.query(Dataset) + dataset = db.session.scalar( + select(Dataset) .where( Dataset.id == dataset_id, Dataset.tenant_id == api_token.tenant_id, ) - .first() + .limit(1) ) if not dataset: raise NotFound("Dataset not found.") if not dataset.enable_api: raise Forbidden("Dataset api access is not enabled.") - tenant_account_join = ( - db.session.query(Tenant, TenantAccountJoin) + tenant_account_join = db.session.execute( + select(Tenant, TenantAccountJoin) .where(Tenant.id == api_token.tenant_id) .where(TenantAccountJoin.tenant_id == Tenant.id) .where(TenantAccountJoin.role.in_(["owner"])) .where(Tenant.status == TenantStatus.NORMAL) - .one_or_none() - ) # TODO: only owner information is required, so only one is returned. + ).one_or_none() # TODO: only owner information is required, so only one is returned. if tenant_account_join: tenant, ta = tenant_account_join - account = db.session.query(Account).where(Account.id == ta.account_id).first() + account = db.session.get(Account, ta.account_id) # Login admin if account: account.current_tenant = tenant @@ -360,7 +359,9 @@ class DatasetApiResource(Resource): method_decorators = [validate_dataset_token] def get_dataset(self, dataset_id: str, tenant_id: str) -> Dataset: - dataset = db.session.query(Dataset).where(Dataset.id == dataset_id, Dataset.tenant_id == tenant_id).first() + dataset = db.session.scalar( + select(Dataset).where(Dataset.id == dataset_id, Dataset.tenant_id == tenant_id).limit(1) + ) if not dataset: raise NotFound("Dataset not found.") diff --git a/api/controllers/web/audio.py b/api/controllers/web/audio.py index 2b8f752668..8081dee0bd 100644 --- a/api/controllers/web/audio.py +++ b/api/controllers/web/audio.py @@ -20,7 +20,7 @@ from controllers.web.error import ( ) from controllers.web.wraps import WebApiResource from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError -from dify_graph.model_runtime.errors.invoke import InvokeError +from graphon.model_runtime.errors.invoke import InvokeError from libs.helper import uuid_value from models.model import App from services.audio_service import AudioService diff --git a/api/controllers/web/completion.py b/api/controllers/web/completion.py index 8634c1f43c..0528184d79 100644 --- a/api/controllers/web/completion.py +++ b/api/controllers/web/completion.py @@ -25,7 +25,7 @@ from core.errors.error import ( ProviderTokenNotInitError, QuotaExceededError, ) -from dify_graph.model_runtime.errors.invoke import InvokeError +from graphon.model_runtime.errors.invoke import InvokeError from libs import helper from libs.helper import uuid_value from models.model import AppMode diff --git a/api/controllers/web/message.py b/api/controllers/web/message.py index aa56292614..4274b8c9ab 100644 --- a/api/controllers/web/message.py +++ b/api/controllers/web/message.py @@ -20,9 +20,9 @@ from controllers.web.error import ( from controllers.web.wraps import WebApiResource from core.app.entities.app_invoke_entities import InvokeFrom from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError -from dify_graph.model_runtime.errors.invoke import InvokeError from fields.conversation_fields import ResultResponse from fields.message_fields import SuggestedQuestionsResponse, WebMessageInfiniteScrollPagination, WebMessageListItem +from graphon.model_runtime.errors.invoke import InvokeError from libs import helper from libs.helper import uuid_value from models.enums import FeedbackRating diff --git a/api/controllers/web/remote_files.py b/api/controllers/web/remote_files.py index 6a93ef6748..fe31e9d4ac 100644 --- a/api/controllers/web/remote_files.py +++ b/api/controllers/web/remote_files.py @@ -11,9 +11,9 @@ from controllers.common.errors import ( UnsupportedFileTypeError, ) from core.helper import ssrf_proxy -from dify_graph.file import helpers as file_helpers from extensions.ext_database import db from fields.file_fields import FileWithSignedUrl, RemoteFileInfo +from graphon.file import helpers as file_helpers from services.file_service import FileService from ..common.schema import register_schema_models diff --git a/api/controllers/web/workflow.py b/api/controllers/web/workflow.py index 508d1a756a..ccef6e5b7f 100644 --- a/api/controllers/web/workflow.py +++ b/api/controllers/web/workflow.py @@ -22,9 +22,9 @@ from core.errors.error import ( ProviderTokenNotInitError, QuotaExceededError, ) -from dify_graph.graph_engine.manager import GraphEngineManager -from dify_graph.model_runtime.errors.invoke import InvokeError from extensions.ext_redis import redis_client +from graphon.graph_engine.manager import GraphEngineManager +from graphon.model_runtime.errors.invoke import InvokeError from libs import helper from models.model import App, AppMode, EndUser from services.app_generate_service import AppGenerateService diff --git a/api/core/agent/base_agent_runner.py b/api/core/agent/base_agent_runner.py index 1bdc8df813..a846cf4b0f 100644 --- a/api/core/agent/base_agent_runner.py +++ b/api/core/agent/base_agent_runner.py @@ -15,6 +15,7 @@ from core.app.entities.app_invoke_entities import ( AgentChatAppGenerateEntity, ModelConfigWithCredentialsEntity, ) +from core.app.file_access import DatabaseFileAccessController from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler from core.memory.token_buffer_memory import TokenBufferMemory @@ -26,8 +27,10 @@ from core.tools.entities.tool_entities import ( ) from core.tools.tool_manager import ToolManager from core.tools.utils.dataset_retriever_tool import DatasetRetrieverTool -from dify_graph.file import file_manager -from dify_graph.model_runtime.entities import ( +from extensions.ext_database import db +from factories import file_factory +from graphon.file import file_manager +from graphon.model_runtime.entities import ( AssistantPromptMessage, LLMUsage, PromptMessage, @@ -37,15 +40,14 @@ from dify_graph.model_runtime.entities import ( ToolPromptMessage, UserPromptMessage, ) -from dify_graph.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes -from dify_graph.model_runtime.entities.model_entities import ModelFeature -from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from extensions.ext_database import db -from factories import file_factory +from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes +from graphon.model_runtime.entities.model_entities import ModelFeature +from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from models.enums import CreatorUserRole from models.model import Conversation, Message, MessageAgentThought, MessageFile logger = logging.getLogger(__name__) +_file_access_controller = DatabaseFileAccessController() class BaseAgentRunner(AppRunner): @@ -138,6 +140,7 @@ class BaseAgentRunner(AppRunner): tenant_id=self.tenant_id, app_id=self.app_config.app_id, agent_tool=tool, + user_id=self.user_id, invoke_from=self.application_generate_entity.invoke_from, ) assert tool_entity.entity.description @@ -524,7 +527,10 @@ class BaseAgentRunner(AppRunner): image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW file_objs = file_factory.build_from_message_files( - message_files=files, tenant_id=self.tenant_id, config=file_extra_config + message_files=files, + tenant_id=self.tenant_id, + config=file_extra_config, + access_controller=_file_access_controller, ) if not file_objs: return UserPromptMessage(content=message.query) diff --git a/api/core/agent/cot_agent_runner.py b/api/core/agent/cot_agent_runner.py index 9271ed10bd..0a0fdfdd29 100644 --- a/api/core/agent/cot_agent_runner.py +++ b/api/core/agent/cot_agent_runner.py @@ -15,8 +15,8 @@ from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransfo from core.tools.__base.tool import Tool from core.tools.entities.tool_entities import ToolInvokeMeta from core.tools.tool_engine import ToolEngine -from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage -from dify_graph.model_runtime.entities.message_entities import ( +from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage +from graphon.model_runtime.entities.message_entities import ( AssistantPromptMessage, PromptMessage, PromptMessageTool, @@ -122,7 +122,6 @@ class CotAgentRunner(BaseAgentRunner, ABC): tools=[], stop=app_generate_entity.model_conf.stop, stream=True, - user=self.user_id, callbacks=[], ) diff --git a/api/core/agent/cot_chat_agent_runner.py b/api/core/agent/cot_chat_agent_runner.py index 89451a0498..b3fc8d42e6 100644 --- a/api/core/agent/cot_chat_agent_runner.py +++ b/api/core/agent/cot_chat_agent_runner.py @@ -1,16 +1,16 @@ import json from core.agent.cot_agent_runner import CotAgentRunner -from dify_graph.file import file_manager -from dify_graph.model_runtime.entities import ( +from graphon.file import file_manager +from graphon.model_runtime.entities import ( AssistantPromptMessage, PromptMessage, SystemPromptMessage, TextPromptMessageContent, UserPromptMessage, ) -from dify_graph.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes -from dify_graph.model_runtime.utils.encoders import jsonable_encoder +from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes +from graphon.model_runtime.utils.encoders import jsonable_encoder class CotChatAgentRunner(CotAgentRunner): diff --git a/api/core/agent/cot_completion_agent_runner.py b/api/core/agent/cot_completion_agent_runner.py index 3023b9bc4d..51a30998ae 100644 --- a/api/core/agent/cot_completion_agent_runner.py +++ b/api/core/agent/cot_completion_agent_runner.py @@ -1,13 +1,13 @@ import json from core.agent.cot_agent_runner import CotAgentRunner -from dify_graph.model_runtime.entities.message_entities import ( +from graphon.model_runtime.entities.message_entities import ( AssistantPromptMessage, PromptMessage, TextPromptMessageContent, UserPromptMessage, ) -from dify_graph.model_runtime.utils.encoders import jsonable_encoder +from graphon.model_runtime.utils.encoders import jsonable_encoder class CotCompletionAgentRunner(CotAgentRunner): diff --git a/api/core/agent/fc_agent_runner.py b/api/core/agent/fc_agent_runner.py index 5e13a13b21..d38d24d1e7 100644 --- a/api/core/agent/fc_agent_runner.py +++ b/api/core/agent/fc_agent_runner.py @@ -11,8 +11,8 @@ from core.app.entities.queue_entities import QueueAgentThoughtEvent, QueueMessag from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform from core.tools.entities.tool_entities import ToolInvokeMeta from core.tools.tool_engine import ToolEngine -from dify_graph.file import file_manager -from dify_graph.model_runtime.entities import ( +from graphon.file import file_manager +from graphon.model_runtime.entities import ( AssistantPromptMessage, LLMResult, LLMResultChunk, @@ -25,7 +25,7 @@ from dify_graph.model_runtime.entities import ( ToolPromptMessage, UserPromptMessage, ) -from dify_graph.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes +from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes from models.model import Message logger = logging.getLogger(__name__) @@ -96,7 +96,6 @@ class FunctionCallAgentRunner(BaseAgentRunner): tools=prompt_messages_tools, stop=app_generate_entity.model_conf.stop, stream=self.stream_tool_call, - user=self.user_id, callbacks=[], ) diff --git a/api/core/agent/output_parser/cot_output_parser.py b/api/core/agent/output_parser/cot_output_parser.py index 82676f1ebd..c3e56fe011 100644 --- a/api/core/agent/output_parser/cot_output_parser.py +++ b/api/core/agent/output_parser/cot_output_parser.py @@ -4,7 +4,7 @@ from collections.abc import Generator from typing import Union from core.agent.entities import AgentScratchpadUnit -from dify_graph.model_runtime.entities.llm_entities import LLMResultChunk +from graphon.model_runtime.entities.llm_entities import LLMResultChunk class CotAgentOutputParser: diff --git a/api/core/app/app_config/easy_ui_based_app/model_config/converter.py b/api/core/app/app_config/easy_ui_based_app/model_config/converter.py index 558b6e69a0..dbd7527fc6 100644 --- a/api/core/app/app_config/easy_ui_based_app/model_config/converter.py +++ b/api/core/app/app_config/easy_ui_based_app/model_config/converter.py @@ -4,10 +4,10 @@ from core.app.app_config.entities import EasyUIBasedAppConfig from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.entities.model_entities import ModelStatus from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError -from core.provider_manager import ProviderManager -from dify_graph.model_runtime.entities.llm_entities import LLMMode -from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey, ModelType -from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager +from graphon.model_runtime.entities.llm_entities import LLMMode +from graphon.model_runtime.entities.model_entities import ModelPropertyKey, ModelType +from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel class ModelConfigConverter: @@ -21,7 +21,7 @@ class ModelConfigConverter: """ model_config = app_config.model - provider_manager = ProviderManager() + provider_manager = create_plugin_provider_manager(tenant_id=app_config.tenant_id) provider_model_bundle = provider_manager.get_provider_model_bundle( tenant_id=app_config.tenant_id, provider=model_config.provider, model_type=ModelType.LLM ) diff --git a/api/core/app/app_config/easy_ui_based_app/model_config/manager.py b/api/core/app/app_config/easy_ui_based_app/model_config/manager.py index 0929f52e33..f279f769aa 100644 --- a/api/core/app/app_config/easy_ui_based_app/model_config/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/model_config/manager.py @@ -2,9 +2,8 @@ from collections.abc import Mapping from typing import Any from core.app.app_config.entities import ModelConfigEntity -from core.provider_manager import ProviderManager -from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey, ModelType -from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory +from core.plugin.impl.model_runtime_factory import create_plugin_model_assembly +from graphon.model_runtime.entities.model_entities import ModelPropertyKey, ModelType from models.model import AppModelConfigDict from models.provider_ids import ModelProviderID @@ -54,9 +53,12 @@ class ModelConfigManager: if not isinstance(config["model"], dict): raise ValueError("model must be of object type") + # Keep provider discovery and provider-backed model listing on the same + # request-scoped runtime so caller scope and provider caches stay aligned. + assembly = create_plugin_model_assembly(tenant_id=tenant_id) + # model.provider - model_provider_factory = ModelProviderFactory(tenant_id) - provider_entities = model_provider_factory.get_providers() + provider_entities = assembly.model_provider_factory.get_providers() model_provider_names = [provider.provider for provider in provider_entities] if "provider" not in config["model"]: raise ValueError(f"model.provider is required and must be in {str(model_provider_names)}") @@ -71,8 +73,7 @@ class ModelConfigManager: if "name" not in config["model"]: raise ValueError("model.name is required") - provider_manager = ProviderManager() - models = provider_manager.get_configurations(tenant_id).get_models( + models = assembly.provider_manager.get_configurations(tenant_id).get_models( provider=config["model"]["provider"], model_type=ModelType.LLM ) diff --git a/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py b/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py index b7073898d6..7715a5330a 100644 --- a/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py @@ -7,7 +7,7 @@ from core.app.app_config.entities import ( PromptTemplateEntity, ) from core.prompt.simple_prompt_transform import ModelMode -from dify_graph.model_runtime.entities.message_entities import PromptMessageRole +from graphon.model_runtime.entities.message_entities import PromptMessageRole from models.model import AppMode, AppModelConfigDict diff --git a/api/core/app/app_config/easy_ui_based_app/variables/manager.py b/api/core/app/app_config/easy_ui_based_app/variables/manager.py index 8de1224a89..6d63ae04d3 100644 --- a/api/core/app/app_config/easy_ui_based_app/variables/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/variables/manager.py @@ -3,7 +3,7 @@ from typing import cast from core.app.app_config.entities import ExternalDataVariableEntity from core.external_data_tool.factory import ExternalDataToolFactory -from dify_graph.variables.input_entities import VariableEntity, VariableEntityType +from graphon.variables.input_entities import VariableEntity, VariableEntityType from models.model import AppModelConfigDict _ALLOWED_VARIABLE_ENTITY_TYPE = frozenset( diff --git a/api/core/app/app_config/entities.py b/api/core/app/app_config/entities.py index 95ea70bc40..c67412cc29 100644 --- a/api/core/app/app_config/entities.py +++ b/api/core/app/app_config/entities.py @@ -5,10 +5,10 @@ from typing import Any, Literal from pydantic import BaseModel, Field from core.rag.data_post_processor.data_post_processor import RerankingModelDict, WeightsDict -from dify_graph.file import FileUploadConfig -from dify_graph.model_runtime.entities.llm_entities import LLMMode -from dify_graph.model_runtime.entities.message_entities import PromptMessageRole -from dify_graph.variables.input_entities import VariableEntity as WorkflowVariableEntity +from graphon.file import FileUploadConfig +from graphon.model_runtime.entities.llm_entities import LLMMode +from graphon.model_runtime.entities.message_entities import PromptMessageRole +from graphon.variables.input_entities import VariableEntity as WorkflowVariableEntity from models.model import AppMode diff --git a/api/core/app/app_config/features/file_upload/manager.py b/api/core/app/app_config/features/file_upload/manager.py index 0c4266fbeb..9092c1a17d 100644 --- a/api/core/app/app_config/features/file_upload/manager.py +++ b/api/core/app/app_config/features/file_upload/manager.py @@ -2,7 +2,7 @@ from collections.abc import Mapping from typing import Any from constants import DEFAULT_FILE_NUMBER_LIMITS -from dify_graph.file import FileUploadConfig +from graphon.file import FileUploadConfig class FileUploadConfigManager: diff --git a/api/core/app/app_config/workflow_ui_based_app/variables/manager.py b/api/core/app/app_config/workflow_ui_based_app/variables/manager.py index d2a9a73380..13ace32fd6 100644 --- a/api/core/app/app_config/workflow_ui_based_app/variables/manager.py +++ b/api/core/app/app_config/workflow_ui_based_app/variables/manager.py @@ -1,7 +1,7 @@ import re from core.app.app_config.entities import RagPipelineVariableEntity -from dify_graph.variables.input_entities import VariableEntity +from graphon.variables.input_entities import VariableEntity from models.workflow import Workflow diff --git a/api/core/app/apps/advanced_chat/app_generator.py b/api/core/app/apps/advanced_chat/app_generator.py index 5d974335ff..853cbb426c 100644 --- a/api/core/app/apps/advanced_chat/app_generator.py +++ b/api/core/app/apps/advanced_chat/app_generator.py @@ -24,6 +24,7 @@ from core.app.apps.advanced_chat.app_runner import AdvancedChatAppRunner from core.app.apps.advanced_chat.generate_response_converter import AdvancedChatAppGenerateResponseConverter from core.app.apps.advanced_chat.generate_task_pipeline import AdvancedChatAppGenerateTaskPipeline from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom +from core.app.apps.draft_variable_saver import DraftVariableSaverFactory from core.app.apps.exc import GenerateTaskStoppedError from core.app.apps.message_based_app_generator import MessageBasedAppGenerator from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager @@ -34,17 +35,13 @@ from core.helper.trace_id_helper import extract_external_trace_id_from_args from core.ops.ops_trace_manager import TraceQueueManager from core.prompt.utils.get_thread_messages_length import get_thread_messages_length from core.repositories import DifyCoreRepositoryFactory -from dify_graph.graph_engine.layers.base import GraphEngineLayer -from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError -from dify_graph.repositories.draft_variable_repository import ( - DraftVariableSaverFactory, -) -from dify_graph.repositories.workflow_execution_repository import WorkflowExecutionRepository -from dify_graph.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository -from dify_graph.runtime import GraphRuntimeState -from dify_graph.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader +from core.repositories.factory import WorkflowExecutionRepository, WorkflowNodeExecutionRepository from extensions.ext_database import db from factories import file_factory +from graphon.graph_engine.layers.base import GraphEngineLayer +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError +from graphon.runtime import GraphRuntimeState +from graphon.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader from libs.flask_utils import preserve_flask_contexts from models import Account, App, Conversation, EndUser, Message, Workflow, WorkflowNodeExecutionTriggeredFrom from models.base import Base @@ -150,85 +147,87 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): # # For implementation reference, see the `_parse_file` function and # `DraftWorkflowNodeRunApi` class which handle this properly. - files = args["files"] if args.get("files") else [] - file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False) - if file_extra_config: - file_objs = file_factory.build_from_mappings( - mappings=files, - tenant_id=app_model.tenant_id, - config=file_extra_config, + with self._bind_file_access_scope(tenant_id=app_model.tenant_id, user=user, invoke_from=invoke_from): + files = args["files"] if args.get("files") else [] + file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False) + if file_extra_config: + file_objs = file_factory.build_from_mappings( + mappings=files, + tenant_id=app_model.tenant_id, + config=file_extra_config, + access_controller=self._file_access_controller, + ) + else: + file_objs = [] + + # convert to app config + app_config = AdvancedChatAppConfigManager.get_app_config(app_model=app_model, workflow=workflow) + + # get tracing instance + trace_manager = TraceQueueManager( + app_id=app_model.id, user_id=user.id if isinstance(user, Account) else user.session_id ) - else: - file_objs = [] - # convert to app config - app_config = AdvancedChatAppConfigManager.get_app_config(app_model=app_model, workflow=workflow) + if invoke_from == InvokeFrom.DEBUGGER: + # always enable retriever resource in debugger mode + app_config.additional_features.show_retrieve_source = True # type: ignore - # get tracing instance - trace_manager = TraceQueueManager( - app_id=app_model.id, user_id=user.id if isinstance(user, Account) else user.session_id - ) + # init application generate entity + application_generate_entity = AdvancedChatAppGenerateEntity( + task_id=str(uuid.uuid4()), + app_config=app_config, + file_upload_config=file_extra_config, + conversation_id=conversation.id if conversation else None, + inputs=self._prepare_user_inputs( + user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id + ), + query=query, + files=list(file_objs), + parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL, + user_id=user.id, + stream=streaming, + invoke_from=invoke_from, + extras=extras, + trace_manager=trace_manager, + workflow_run_id=str(workflow_run_id), + ) + contexts.plugin_tool_providers.set({}) + contexts.plugin_tool_providers_lock.set(threading.Lock()) - if invoke_from == InvokeFrom.DEBUGGER: - # always enable retriever resource in debugger mode - app_config.additional_features.show_retrieve_source = True # type: ignore + # Create repositories + # + # Create session factory + session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) + # Create workflow execution(aka workflow run) repository + if invoke_from == InvokeFrom.DEBUGGER: + workflow_triggered_from = WorkflowRunTriggeredFrom.DEBUGGING + else: + workflow_triggered_from = WorkflowRunTriggeredFrom.APP_RUN + workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository( + session_factory=session_factory, + user=user, + app_id=application_generate_entity.app_config.app_id, + triggered_from=workflow_triggered_from, + ) + # Create workflow node execution repository + workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository( + session_factory=session_factory, + user=user, + app_id=application_generate_entity.app_config.app_id, + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + ) - # init application generate entity - application_generate_entity = AdvancedChatAppGenerateEntity( - task_id=str(uuid.uuid4()), - app_config=app_config, - file_upload_config=file_extra_config, - conversation_id=conversation.id if conversation else None, - inputs=self._prepare_user_inputs( - user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id - ), - query=query, - files=list(file_objs), - parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL, - user_id=user.id, - stream=streaming, - invoke_from=invoke_from, - extras=extras, - trace_manager=trace_manager, - workflow_run_id=str(workflow_run_id), - ) - contexts.plugin_tool_providers.set({}) - contexts.plugin_tool_providers_lock.set(threading.Lock()) - - # Create repositories - # - # Create session factory - session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) - # Create workflow execution(aka workflow run) repository - if invoke_from == InvokeFrom.DEBUGGER: - workflow_triggered_from = WorkflowRunTriggeredFrom.DEBUGGING - else: - workflow_triggered_from = WorkflowRunTriggeredFrom.APP_RUN - workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository( - session_factory=session_factory, - user=user, - app_id=application_generate_entity.app_config.app_id, - triggered_from=workflow_triggered_from, - ) - # Create workflow node execution repository - workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository( - session_factory=session_factory, - user=user, - app_id=application_generate_entity.app_config.app_id, - triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, - ) - - return self._generate( - workflow=workflow, - user=user, - invoke_from=invoke_from, - application_generate_entity=application_generate_entity, - workflow_execution_repository=workflow_execution_repository, - workflow_node_execution_repository=workflow_node_execution_repository, - conversation=conversation, - stream=streaming, - pause_state_config=pause_state_config, - ) + return self._generate( + workflow=workflow, + user=user, + invoke_from=invoke_from, + application_generate_entity=application_generate_entity, + workflow_execution_repository=workflow_execution_repository, + workflow_node_execution_repository=workflow_node_execution_repository, + conversation=conversation, + stream=streaming, + pause_state_config=pause_state_config, + ) def resume( self, @@ -460,94 +459,90 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): :param conversation: conversation :param stream: is stream """ - is_first_conversation = conversation is None + with self._bind_file_access_scope( + tenant_id=application_generate_entity.app_config.tenant_id, + user=user, + invoke_from=invoke_from, + ): + is_first_conversation = conversation is None - if conversation is not None and message is not None: - pass - else: - conversation, message = self._init_generate_records(application_generate_entity, conversation) + if conversation is not None and message is not None: + pass + else: + conversation, message = self._init_generate_records(application_generate_entity, conversation) - if is_first_conversation: - # update conversation features - conversation.override_model_configs = workflow.features - db.session.commit() - db.session.refresh(conversation) + if is_first_conversation: + # update conversation features + conversation.override_model_configs = workflow.features + db.session.commit() + db.session.refresh(conversation) - # get conversation dialogue count - # NOTE: dialogue_count should not start from 0, - # because during the first conversation, dialogue_count should be 1. - self._dialogue_count = get_thread_messages_length(conversation.id) + 1 + # get conversation dialogue count + # NOTE: dialogue_count should not start from 0, + # because during the first conversation, dialogue_count should be 1. + self._dialogue_count = get_thread_messages_length(conversation.id) + 1 - # init queue manager - queue_manager = MessageBasedAppQueueManager( - task_id=application_generate_entity.task_id, - user_id=application_generate_entity.user_id, - invoke_from=application_generate_entity.invoke_from, - conversation_id=conversation.id, - app_mode=conversation.mode, - message_id=message.id, - ) - - graph_layers: list[GraphEngineLayer] = list(graph_engine_layers) - if pause_state_config is not None: - graph_layers.append( - PauseStatePersistenceLayer( - session_factory=pause_state_config.session_factory, - generate_entity=application_generate_entity, - state_owner_user_id=pause_state_config.state_owner_user_id, - ) + # init queue manager + queue_manager = MessageBasedAppQueueManager( + task_id=application_generate_entity.task_id, + user_id=application_generate_entity.user_id, + invoke_from=application_generate_entity.invoke_from, + conversation_id=conversation.id, + app_mode=conversation.mode, + message_id=message.id, ) - # new thread with request context and contextvars - context = contextvars.copy_context() + graph_layers: list[GraphEngineLayer] = list(graph_engine_layers) + if pause_state_config is not None: + graph_layers.append( + PauseStatePersistenceLayer( + session_factory=pause_state_config.session_factory, + generate_entity=application_generate_entity, + state_owner_user_id=pause_state_config.state_owner_user_id, + ) + ) - worker_thread = threading.Thread( - target=self._generate_worker, - kwargs={ - "flask_app": current_app._get_current_object(), # type: ignore - "application_generate_entity": application_generate_entity, - "queue_manager": queue_manager, - "conversation_id": conversation.id, - "message_id": message.id, - "context": context, - "variable_loader": variable_loader, - "workflow_execution_repository": workflow_execution_repository, - "workflow_node_execution_repository": workflow_node_execution_repository, - "graph_engine_layers": tuple(graph_layers), - "graph_runtime_state": graph_runtime_state, - }, - ) + # new thread with request context and contextvars + context = contextvars.copy_context() - worker_thread.start() + worker_thread = threading.Thread( + target=self._generate_worker, + kwargs={ + "flask_app": current_app._get_current_object(), # type: ignore + "application_generate_entity": application_generate_entity, + "queue_manager": queue_manager, + "conversation_id": conversation.id, + "message_id": message.id, + "context": context, + "variable_loader": variable_loader, + "workflow_execution_repository": workflow_execution_repository, + "workflow_node_execution_repository": workflow_node_execution_repository, + "graph_engine_layers": tuple(graph_layers), + "graph_runtime_state": graph_runtime_state, + }, + ) - # release database connection, because the following new thread operations may take a long time - with Session(bind=db.engine, expire_on_commit=False) as session: - workflow = _refresh_model(session, workflow) - message = _refresh_model(session, message) - # workflow_ = session.get(Workflow, workflow.id) - # assert workflow_ is not None - # workflow = workflow_ - # message_ = session.get(Message, message.id) - # assert message_ is not None - # message = message_ - # db.session.refresh(workflow) - # db.session.refresh(message) - # db.session.refresh(user) - db.session.close() + worker_thread.start() - # return response or stream generator - response = self._handle_advanced_chat_response( - application_generate_entity=application_generate_entity, - workflow=workflow, - queue_manager=queue_manager, - conversation=conversation, - message=message, - user=user, - stream=stream, - draft_var_saver_factory=self._get_draft_var_saver_factory(invoke_from, account=user), - ) + # release database connection, because the following new thread operations may take a long time + with Session(bind=db.engine, expire_on_commit=False) as session: + workflow = _refresh_model(session, workflow) + message = _refresh_model(session, message) + db.session.close() - return AdvancedChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from) + # return response or stream generator + response = self._handle_advanced_chat_response( + application_generate_entity=application_generate_entity, + workflow=workflow, + queue_manager=queue_manager, + conversation=conversation, + message=message, + user=user, + stream=stream, + draft_var_saver_factory=self._get_draft_var_saver_factory(invoke_from, account=user), + ) + + return AdvancedChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from) def _generate_worker( self, diff --git a/api/core/app/apps/advanced_chat/app_runner.py b/api/core/app/apps/advanced_chat/app_runner.py index 66037696af..d21fce144e 100644 --- a/api/core/app/apps/advanced_chat/app_runner.py +++ b/api/core/app/apps/advanced_chat/app_runner.py @@ -25,19 +25,24 @@ from core.app.workflow.layers.persistence import PersistenceWorkflowInfo, Workfl from core.db.session_factory import session_factory from core.moderation.base import ModerationError from core.moderation.input_moderation import InputModeration +from core.repositories.factory import WorkflowExecutionRepository, WorkflowNodeExecutionRepository +from core.workflow.node_factory import get_default_root_node_id +from core.workflow.system_variables import ( + build_bootstrap_variables, + build_system_variables, + system_variables_to_mapping, +) +from core.workflow.variable_pool_initializer import add_node_inputs_to_pool, add_variables_to_pool from core.workflow.workflow_entry import WorkflowEntry -from dify_graph.enums import WorkflowType -from dify_graph.graph_engine.command_channels.redis_channel import RedisChannel -from dify_graph.graph_engine.layers.base import GraphEngineLayer -from dify_graph.repositories.workflow_execution_repository import WorkflowExecutionRepository -from dify_graph.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable -from dify_graph.variable_loader import VariableLoader -from dify_graph.variables.variables import Variable from extensions.ext_database import db from extensions.ext_redis import redis_client from extensions.otel import WorkflowAppRunnerHandler, trace_span +from graphon.enums import WorkflowType +from graphon.graph_engine.command_channels.redis_channel import RedisChannel +from graphon.graph_engine.layers.base import GraphEngineLayer +from graphon.runtime import GraphRuntimeState, VariablePool +from graphon.variable_loader import VariableLoader +from graphon.variables.variables import Variable from models import Workflow from models.model import App, Conversation, Message, MessageAnnotation from models.workflow import ConversationVariable @@ -90,7 +95,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): app_config = self.application_generate_entity.app_config app_config = cast(AdvancedChatAppConfig, app_config) - system_inputs = SystemVariable( + system_inputs = build_system_variables( query=self.application_generate_entity.query, files=self.application_generate_entity.files, conversation_id=self.conversation.id, @@ -132,6 +137,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): workflow=self._workflow, single_iteration_run=self.application_generate_entity.single_iteration_run, single_loop_run=self.application_generate_entity.single_loop_run, + user_id=self.application_generate_entity.user_id, ) else: inputs = self.application_generate_entity.inputs @@ -150,7 +156,10 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): self.application_generate_entity.inputs = new_inputs self.application_generate_entity.query = new_query - system_inputs.query = new_query + system_inputs = build_system_variables( + system_variables_to_mapping(system_inputs), + query=new_query, + ) # annotation reply if self.handle_annotation_reply( @@ -166,14 +175,17 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): # Create a variable pool. # init variable pool - variable_pool = VariablePool( - system_variables=system_inputs, - user_inputs=new_inputs, - environment_variables=self._workflow.environment_variables, - # Based on the definition of `Variable`, - # `VariableBase` instances can be safely used as `Variable` since they are compatible. - conversation_variables=conversation_variables, + variable_pool = VariablePool() + add_variables_to_pool( + variable_pool, + build_bootstrap_variables( + system_variables=system_inputs, + environment_variables=self._workflow.environment_variables, + conversation_variables=conversation_variables, + ), ) + root_node_id = get_default_root_node_id(self._workflow.graph_dict) + add_node_inputs_to_pool(variable_pool, node_id=root_node_id, inputs=new_inputs) # init graph graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.time()) @@ -185,6 +197,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): user_id=self.application_generate_entity.user_id, user_from=user_from, invoke_from=invoke_from, + root_node_id=root_node_id, ) db.session.close() diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index f7b5030d33..51febed32a 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -14,6 +14,7 @@ from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom from core.app.apps.common.graph_runtime_state_support import GraphRuntimeStateSupport from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter +from core.app.apps.draft_variable_saver import DraftVariableSaverFactory from core.app.entities.app_invoke_entities import ( AdvancedChatAppGenerateEntity, InvokeFrom, @@ -65,15 +66,15 @@ from core.app.task_pipeline.message_cycle_manager import MessageCycleManager from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk from core.ops.ops_trace_manager import TraceQueueManager from core.repositories.human_input_repository import HumanInputFormRepositoryImpl -from dify_graph.entities.pause_reason import HumanInputRequired -from dify_graph.enums import WorkflowExecutionStatus -from dify_graph.model_runtime.entities.llm_entities import LLMUsage -from dify_graph.model_runtime.utils.encoders import jsonable_encoder -from dify_graph.nodes import BuiltinNodeTypes -from dify_graph.repositories.draft_variable_repository import DraftVariableSaverFactory -from dify_graph.runtime import GraphRuntimeState -from dify_graph.system_variable import SystemVariable +from core.workflow.file_reference import resolve_file_record_id +from core.workflow.system_variables import build_system_variables from extensions.ext_database import db +from graphon.entities.pause_reason import HumanInputRequired +from graphon.enums import WorkflowExecutionStatus +from graphon.model_runtime.entities.llm_entities import LLMUsage +from graphon.model_runtime.utils.encoders import jsonable_encoder +from graphon.nodes import BuiltinNodeTypes +from graphon.runtime import GraphRuntimeState from libs.datetime_utils import naive_utc_now from models import Account, Conversation, EndUser, Message, MessageFile from models.enums import CreatorUserRole, MessageFileBelongsTo, MessageStatus @@ -117,7 +118,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport): else: raise NotImplementedError(f"User type not supported: {type(user)}") - self._workflow_system_variables = SystemVariable( + self._workflow_system_variables = build_system_variables( query=message.query, files=application_generate_entity.files, conversation_id=conversation.id, @@ -741,8 +742,9 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport): def _load_human_input_form_id(self, *, node_id: str) -> str | None: form_repository = HumanInputFormRepositoryImpl( tenant_id=self._workflow_tenant_id, + workflow_execution_id=self._workflow_run_id, ) - form = form_repository.get_form(self._workflow_run_id, node_id) + form = form_repository.get_form(node_id) if form is None: return None return form.id @@ -933,21 +935,23 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport): metadata = self._task_state.metadata.model_dump() message.message_metadata = json.dumps(jsonable_encoder(metadata)) - message_files = [ - MessageFile( - message_id=message.id, - type=file["type"], - transfer_method=file["transfer_method"], - url=file["remote_url"], - belongs_to=MessageFileBelongsTo.ASSISTANT, - upload_file_id=file["related_id"], - created_by_role=CreatorUserRole.ACCOUNT - if message.invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} - else CreatorUserRole.END_USER, - created_by=message.from_account_id or message.from_end_user_id or "", + message_files: list[MessageFile] = [] + for file in self._recorded_files: + reference = file.get("reference") or file.get("related_id") + message_files.append( + MessageFile( + message_id=message.id, + type=file["type"], + transfer_method=file["transfer_method"], + url=file["remote_url"], + belongs_to=MessageFileBelongsTo.ASSISTANT, + upload_file_id=resolve_file_record_id(reference if isinstance(reference, str) else None), + created_by_role=CreatorUserRole.ACCOUNT + if message.invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} + else CreatorUserRole.END_USER, + created_by=message.from_account_id or message.from_end_user_id or "", + ) ) - for file in self._recorded_files - ] session.add_all(message_files) def _seed_graph_runtime_state_from_queue_manager(self) -> None: @@ -1003,13 +1007,11 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport): return message def _save_output_for_event(self, event: QueueNodeSucceededEvent | QueueNodeExceptionEvent, node_execution_id: str): - with Session(db.engine) as session, session.begin(): - saver = self._draft_var_saver_factory( - session=session, - app_id=self._application_generate_entity.app_config.app_id, - node_id=event.node_id, - node_type=event.node_type, - node_execution_id=node_execution_id, - enclosing_node_id=event.in_loop_id or event.in_iteration_id, - ) - saver.save(event.process_data, event.outputs) + saver = self._draft_var_saver_factory( + app_id=self._application_generate_entity.app_config.app_id, + node_id=event.node_id, + node_type=event.node_type, + node_execution_id=node_execution_id, + enclosing_node_id=event.in_loop_id or event.in_iteration_id, + ) + saver.save(event.process_data, event.outputs) diff --git a/api/core/app/apps/agent_chat/app_generator.py b/api/core/app/apps/agent_chat/app_generator.py index 76a067d7b6..1a44cc235e 100644 --- a/api/core/app/apps/agent_chat/app_generator.py +++ b/api/core/app/apps/agent_chat/app_generator.py @@ -21,9 +21,9 @@ from core.app.apps.message_based_app_generator import MessageBasedAppGenerator from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, InvokeFrom from core.ops.ops_trace_manager import TraceQueueManager -from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError from extensions.ext_database import db from factories import file_factory +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError from libs.flask_utils import preserve_flask_contexts from models import Account, App, EndUser from services.conversation_service import ConversationService @@ -129,89 +129,93 @@ class AgentChatAppGenerator(MessageBasedAppGenerator): # # For implementation reference, see the `_parse_file` function and # `DraftWorkflowNodeRunApi` class which handle this properly. - files = args.get("files") or [] - file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict()) - if file_extra_config: - file_objs = file_factory.build_from_mappings( - mappings=files, - tenant_id=app_model.tenant_id, - config=file_extra_config, + with self._bind_file_access_scope(tenant_id=app_model.tenant_id, user=user, invoke_from=invoke_from): + files = args.get("files") or [] + file_extra_config = FileUploadConfigManager.convert( + override_model_config_dict or app_model_config.to_dict() ) - else: - file_objs = [] + if file_extra_config: + file_objs = file_factory.build_from_mappings( + mappings=files, + tenant_id=app_model.tenant_id, + config=file_extra_config, + access_controller=self._file_access_controller, + ) + else: + file_objs = [] - # convert to app config - app_config = AgentChatAppConfigManager.get_app_config( - app_model=app_model, - app_model_config=app_model_config, - conversation=conversation, - override_config_dict=override_model_config_dict, - ) + # convert to app config + app_config = AgentChatAppConfigManager.get_app_config( + app_model=app_model, + app_model_config=app_model_config, + conversation=conversation, + override_config_dict=override_model_config_dict, + ) - # get tracing instance - trace_manager = TraceQueueManager(app_model.id, user.id if isinstance(user, Account) else user.session_id) + # get tracing instance + trace_manager = TraceQueueManager(app_model.id, user.id if isinstance(user, Account) else user.session_id) - # init application generate entity - application_generate_entity = AgentChatAppGenerateEntity( - task_id=str(uuid.uuid4()), - app_config=app_config, - model_conf=ModelConfigConverter.convert(app_config), - file_upload_config=file_extra_config, - conversation_id=conversation.id if conversation else None, - inputs=self._prepare_user_inputs( - user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id - ), - query=query, - files=list(file_objs), - parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL, - user_id=user.id, - stream=streaming, - invoke_from=invoke_from, - extras=extras, - call_depth=0, - trace_manager=trace_manager, - ) + # init application generate entity + application_generate_entity = AgentChatAppGenerateEntity( + task_id=str(uuid.uuid4()), + app_config=app_config, + model_conf=ModelConfigConverter.convert(app_config), + file_upload_config=file_extra_config, + conversation_id=conversation.id if conversation else None, + inputs=self._prepare_user_inputs( + user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id + ), + query=query, + files=list(file_objs), + parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL, + user_id=user.id, + stream=streaming, + invoke_from=invoke_from, + extras=extras, + call_depth=0, + trace_manager=trace_manager, + ) - # init generate records - (conversation, message) = self._init_generate_records(application_generate_entity, conversation) + # init generate records + (conversation, message) = self._init_generate_records(application_generate_entity, conversation) - # init queue manager - queue_manager = MessageBasedAppQueueManager( - task_id=application_generate_entity.task_id, - user_id=application_generate_entity.user_id, - invoke_from=application_generate_entity.invoke_from, - conversation_id=conversation.id, - app_mode=conversation.mode, - message_id=message.id, - ) + # init queue manager + queue_manager = MessageBasedAppQueueManager( + task_id=application_generate_entity.task_id, + user_id=application_generate_entity.user_id, + invoke_from=application_generate_entity.invoke_from, + conversation_id=conversation.id, + app_mode=conversation.mode, + message_id=message.id, + ) - # new thread with request context and contextvars - context = contextvars.copy_context() + # new thread with request context and contextvars + context = contextvars.copy_context() - worker_thread = threading.Thread( - target=self._generate_worker, - kwargs={ - "flask_app": current_app._get_current_object(), # type: ignore - "context": context, - "application_generate_entity": application_generate_entity, - "queue_manager": queue_manager, - "conversation_id": conversation.id, - "message_id": message.id, - }, - ) + worker_thread = threading.Thread( + target=self._generate_worker, + kwargs={ + "flask_app": current_app._get_current_object(), # type: ignore + "context": context, + "application_generate_entity": application_generate_entity, + "queue_manager": queue_manager, + "conversation_id": conversation.id, + "message_id": message.id, + }, + ) - worker_thread.start() + worker_thread.start() - # return response or stream generator - response = self._handle_response( - application_generate_entity=application_generate_entity, - queue_manager=queue_manager, - conversation=conversation, - message=message, - user=user, - stream=streaming, - ) - return AgentChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from) + # return response or stream generator + response = self._handle_response( + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + conversation=conversation, + message=message, + user=user, + stream=streaming, + ) + return AgentChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from) def _generate_worker( self, diff --git a/api/core/app/apps/agent_chat/app_runner.py b/api/core/app/apps/agent_chat/app_runner.py index a81da2e91c..09ddce327e 100644 --- a/api/core/app/apps/agent_chat/app_runner.py +++ b/api/core/app/apps/agent_chat/app_runner.py @@ -15,10 +15,10 @@ from core.app.entities.queue_entities import QueueAnnotationReplyEvent from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance from core.moderation.base import ModerationError -from dify_graph.model_runtime.entities.llm_entities import LLMMode -from dify_graph.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey -from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from extensions.ext_database import db +from graphon.model_runtime.entities.llm_entities import LLMMode +from graphon.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey +from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from models.model import App, Conversation, Message logger = logging.getLogger(__name__) diff --git a/api/core/app/apps/base_app_generate_response_converter.py b/api/core/app/apps/base_app_generate_response_converter.py index a92e3dd2ea..5c9ba4567a 100644 --- a/api/core/app/apps/base_app_generate_response_converter.py +++ b/api/core/app/apps/base_app_generate_response_converter.py @@ -6,7 +6,7 @@ from typing import Any, Union from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.task_entities import AppBlockingResponse, AppStreamResponse from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError -from dify_graph.model_runtime.errors.invoke import InvokeError +from graphon.model_runtime.errors.invoke import InvokeError logger = logging.getLogger(__name__) diff --git a/api/core/app/apps/base_app_generator.py b/api/core/app/apps/base_app_generator.py index 20e6ac98ea..8e8ccf2b90 100644 --- a/api/core/app/apps/base_app_generator.py +++ b/api/core/app/apps/base_app_generator.py @@ -1,27 +1,89 @@ from collections.abc import Generator, Mapping, Sequence +from contextlib import AbstractContextManager, nullcontext from typing import TYPE_CHECKING, Any, Union, final from sqlalchemy.orm import Session -from core.app.entities.app_invoke_entities import InvokeFrom -from dify_graph.enums import NodeType -from dify_graph.file import File, FileUploadConfig -from dify_graph.repositories.draft_variable_repository import ( +from core.app.apps.draft_variable_saver import ( DraftVariableSaver, DraftVariableSaverFactory, NoopDraftVariableSaver, ) -from dify_graph.variables.input_entities import VariableEntityType +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom +from core.app.file_access import DatabaseFileAccessController, FileAccessScope, bind_file_access_scope +from extensions.ext_database import db from factories import file_factory +from graphon.enums import NodeType +from graphon.file import File, FileUploadConfig +from graphon.variables.input_entities import VariableEntityType from libs.orjson import orjson_dumps from models import Account, EndUser from services.workflow_draft_variable_service import DraftVariableSaver as DraftVariableSaverImpl if TYPE_CHECKING: - from dify_graph.variables.input_entities import VariableEntity + from graphon.variables.input_entities import VariableEntity + + +@final +class _DebuggerDraftVariableSaver: + """Adapter that binds SQLAlchemy session setup outside the saver port.""" + + def __init__( + self, + *, + account: Account, + app_id: str, + node_id: str, + node_type: NodeType, + node_execution_id: str, + enclosing_node_id: str | None = None, + ) -> None: + self._account = account + self._app_id = app_id + self._node_id = node_id + self._node_type = node_type + self._node_execution_id = node_execution_id + self._enclosing_node_id = enclosing_node_id + + def save(self, process_data: Mapping[str, Any] | None, outputs: Mapping[str, Any] | None) -> None: + with Session(db.engine) as session, session.begin(): + DraftVariableSaverImpl( + session=session, + app_id=self._app_id, + node_id=self._node_id, + node_type=self._node_type, + node_execution_id=self._node_execution_id, + enclosing_node_id=self._enclosing_node_id, + user=self._account, + ).save(process_data, outputs) class BaseAppGenerator: + _file_access_controller: DatabaseFileAccessController = DatabaseFileAccessController() + + @staticmethod + def _bind_file_access_scope( + *, + tenant_id: str, + user: Account | EndUser, + invoke_from: InvokeFrom, + ) -> AbstractContextManager[None]: + """Bind request-scoped file ownership markers for downstream file lookups.""" + + user_id = getattr(user, "id", None) + if not isinstance(user_id, str) or not user_id: + return nullcontext() + + user_from = UserFrom.ACCOUNT if isinstance(user, Account) else UserFrom.END_USER + return bind_file_access_scope( + FileAccessScope( + tenant_id=tenant_id, + user_id=user_id, + user_from=user_from, + invoke_from=invoke_from, + ) + ) + def _prepare_user_inputs( self, *, @@ -50,6 +112,7 @@ class BaseAppGenerator: allowed_file_upload_methods=entity_dictionary[k].allowed_file_upload_methods or [], ), strict_type_validation=strict_type_validation, + access_controller=self._file_access_controller, ) for k, v in user_inputs.items() if isinstance(v, dict) and entity_dictionary[k].type == VariableEntityType.FILE @@ -64,6 +127,7 @@ class BaseAppGenerator: allowed_file_extensions=entity_dictionary[k].allowed_file_extensions or [], allowed_file_upload_methods=entity_dictionary[k].allowed_file_upload_methods or [], ), + access_controller=self._file_access_controller, ) for k, v in user_inputs.items() if isinstance(v, list) @@ -226,32 +290,30 @@ class BaseAppGenerator: assert isinstance(account, Account) def draft_var_saver_factory( - session: Session, app_id: str, node_id: str, node_type: NodeType, node_execution_id: str, enclosing_node_id: str | None = None, ) -> DraftVariableSaver: - return DraftVariableSaverImpl( - session=session, + return _DebuggerDraftVariableSaver( + account=account, app_id=app_id, node_id=node_id, node_type=node_type, node_execution_id=node_execution_id, enclosing_node_id=enclosing_node_id, - user=account, ) else: def draft_var_saver_factory( - session: Session, app_id: str, node_id: str, node_type: NodeType, node_execution_id: str, enclosing_node_id: str | None = None, ) -> DraftVariableSaver: + _ = app_id, node_id, node_type, node_execution_id, enclosing_node_id return NoopDraftVariableSaver() return draft_var_saver_factory diff --git a/api/core/app/apps/base_app_queue_manager.py b/api/core/app/apps/base_app_queue_manager.py index 5addd41815..d1771452c5 100644 --- a/api/core/app/apps/base_app_queue_manager.py +++ b/api/core/app/apps/base_app_queue_manager.py @@ -20,8 +20,8 @@ from core.app.entities.queue_entities import ( QueueStopEvent, WorkflowQueueMessage, ) -from dify_graph.runtime import GraphRuntimeState from extensions.ext_redis import redis_client +from graphon.runtime import GraphRuntimeState logger = logging.getLogger(__name__) @@ -61,27 +61,30 @@ class AppQueueManager(ABC): listen_timeout = dify_config.APP_MAX_EXECUTION_TIME start_time = time.time() last_ping_time: int | float = 0 - while True: - try: - message = self._q.get(timeout=1) - if message is None: - break + try: + while True: + try: + message = self._q.get(timeout=1) + if message is None: + break - yield message - except queue.Empty: - continue - finally: - elapsed_time = time.time() - start_time - if elapsed_time >= listen_timeout or self._is_stopped(): - # publish two messages to make sure the client can receive the stop signal - # and stop listening after the stop signal processed - self.publish( - QueueStopEvent(stopped_by=QueueStopEvent.StopBy.USER_MANUAL), PublishFrom.TASK_PIPELINE - ) + yield message + except queue.Empty: + continue + finally: + elapsed_time = time.time() - start_time + if elapsed_time >= listen_timeout or self._is_stopped(): + # publish two messages to make sure the client can receive the stop signal + # and stop listening after the stop signal processed + self.publish( + QueueStopEvent(stopped_by=QueueStopEvent.StopBy.USER_MANUAL), PublishFrom.TASK_PIPELINE + ) - if elapsed_time // 10 > last_ping_time: - self.publish(QueuePingEvent(), PublishFrom.TASK_PIPELINE) - last_ping_time = elapsed_time // 10 + if elapsed_time // 10 > last_ping_time: + self.publish(QueuePingEvent(), PublishFrom.TASK_PIPELINE) + last_ping_time = elapsed_time // 10 + finally: + self._graph_runtime_state = None # Release reference once consumers finish or close the generator. def stop_listen(self): """ @@ -90,7 +93,6 @@ class AppQueueManager(ABC): """ self._clear_task_belong_cache() self._q.put(None) - self._graph_runtime_state = None # Release reference to allow GC to reclaim memory def _clear_task_belong_cache(self) -> None: """ diff --git a/api/core/app/apps/base_app_runner.py b/api/core/app/apps/base_app_runner.py index 11fcbb7561..4a4c8b535d 100644 --- a/api/core/app/apps/base_app_runner.py +++ b/api/core/app/apps/base_app_runner.py @@ -29,22 +29,22 @@ from core.prompt.advanced_prompt_transform import AdvancedPromptTransform from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig from core.prompt.simple_prompt_transform import ModelMode, SimplePromptTransform from core.tools.tool_file_manager import ToolFileManager -from dify_graph.file.enums import FileTransferMethod, FileType -from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage -from dify_graph.model_runtime.entities.message_entities import ( +from extensions.ext_database import db +from graphon.file.enums import FileTransferMethod, FileType +from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage +from graphon.model_runtime.entities.message_entities import ( AssistantPromptMessage, ImagePromptMessageContent, PromptMessage, TextPromptMessageContent, ) -from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey -from dify_graph.model_runtime.errors.invoke import InvokeBadRequestError -from extensions.ext_database import db +from graphon.model_runtime.entities.model_entities import ModelPropertyKey +from graphon.model_runtime.errors.invoke import InvokeBadRequestError from models.enums import CreatorUserRole, MessageFileBelongsTo from models.model import App, AppMode, Message, MessageAnnotation, MessageFile if TYPE_CHECKING: - from dify_graph.file.models import File + from graphon.file.models import File _logger = logging.getLogger(__name__) diff --git a/api/core/app/apps/chat/app_generator.py b/api/core/app/apps/chat/app_generator.py index 91cf54c774..db3a98c7ac 100644 --- a/api/core/app/apps/chat/app_generator.py +++ b/api/core/app/apps/chat/app_generator.py @@ -1,3 +1,4 @@ +import contextvars import logging import threading import uuid @@ -20,9 +21,9 @@ from core.app.apps.message_based_app_generator import MessageBasedAppGenerator from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager from core.app.entities.app_invoke_entities import ChatAppGenerateEntity, InvokeFrom from core.ops.ops_trace_manager import TraceQueueManager -from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError from extensions.ext_database import db from factories import file_factory +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError from models import Account from models.model import App, EndUser from services.conversation_service import ConversationService @@ -120,89 +121,96 @@ class ChatAppGenerator(MessageBasedAppGenerator): # # For implementation reference, see the `_parse_file` function and # `DraftWorkflowNodeRunApi` class which handle this properly. - files = args["files"] if args.get("files") else [] - file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict()) - if file_extra_config: - file_objs = file_factory.build_from_mappings( - mappings=files, - tenant_id=app_model.tenant_id, - config=file_extra_config, + with self._bind_file_access_scope(tenant_id=app_model.tenant_id, user=user, invoke_from=invoke_from): + files = args["files"] if args.get("files") else [] + file_extra_config = FileUploadConfigManager.convert( + override_model_config_dict or app_model_config.to_dict() ) - else: - file_objs = [] + if file_extra_config: + file_objs = file_factory.build_from_mappings( + mappings=files, + tenant_id=app_model.tenant_id, + config=file_extra_config, + access_controller=self._file_access_controller, + ) + else: + file_objs = [] - # convert to app config - app_config = ChatAppConfigManager.get_app_config( - app_model=app_model, - app_model_config=app_model_config, - conversation=conversation, - override_config_dict=override_model_config_dict, - ) + # convert to app config + app_config = ChatAppConfigManager.get_app_config( + app_model=app_model, + app_model_config=app_model_config, + conversation=conversation, + override_config_dict=override_model_config_dict, + ) - # get tracing instance - trace_manager = TraceQueueManager( - app_id=app_model.id, user_id=user.id if isinstance(user, Account) else user.session_id - ) + # get tracing instance + trace_manager = TraceQueueManager( + app_id=app_model.id, user_id=user.id if isinstance(user, Account) else user.session_id + ) - # init application generate entity - application_generate_entity = ChatAppGenerateEntity( - task_id=str(uuid.uuid4()), - app_config=app_config, - model_conf=ModelConfigConverter.convert(app_config), - file_upload_config=file_extra_config, - conversation_id=conversation.id if conversation else None, - inputs=self._prepare_user_inputs( - user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id - ), - query=query, - files=list(file_objs), - parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL, - user_id=user.id, - invoke_from=invoke_from, - extras=extras, - trace_manager=trace_manager, - stream=streaming, - ) + # init application generate entity + application_generate_entity = ChatAppGenerateEntity( + task_id=str(uuid.uuid4()), + app_config=app_config, + model_conf=ModelConfigConverter.convert(app_config), + file_upload_config=file_extra_config, + conversation_id=conversation.id if conversation else None, + inputs=self._prepare_user_inputs( + user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id + ), + query=query, + files=list(file_objs), + parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL, + user_id=user.id, + invoke_from=invoke_from, + extras=extras, + trace_manager=trace_manager, + stream=streaming, + ) - # init generate records - (conversation, message) = self._init_generate_records(application_generate_entity, conversation) + # init generate records + (conversation, message) = self._init_generate_records(application_generate_entity, conversation) - # init queue manager - queue_manager = MessageBasedAppQueueManager( - task_id=application_generate_entity.task_id, - user_id=application_generate_entity.user_id, - invoke_from=application_generate_entity.invoke_from, - conversation_id=conversation.id, - app_mode=conversation.mode, - message_id=message.id, - ) - - # new thread with request context - @copy_current_request_context - def worker_with_context(): - return self._generate_worker( - flask_app=current_app._get_current_object(), # type: ignore - application_generate_entity=application_generate_entity, - queue_manager=queue_manager, + # init queue manager + queue_manager = MessageBasedAppQueueManager( + task_id=application_generate_entity.task_id, + user_id=application_generate_entity.user_id, + invoke_from=application_generate_entity.invoke_from, conversation_id=conversation.id, + app_mode=conversation.mode, message_id=message.id, ) - worker_thread = threading.Thread(target=worker_with_context) + context = contextvars.copy_context() - worker_thread.start() + # new thread with request context + @copy_current_request_context + def worker_with_context(): + return context.run( + self._generate_worker, + flask_app=current_app._get_current_object(), # type: ignore + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + conversation_id=conversation.id, + message_id=message.id, + ) - # return response or stream generator - response = self._handle_response( - application_generate_entity=application_generate_entity, - queue_manager=queue_manager, - conversation=conversation, - message=message, - user=user, - stream=streaming, - ) + worker_thread = threading.Thread(target=worker_with_context) - return ChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from) + worker_thread.start() + + # return response or stream generator + response = self._handle_response( + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + conversation=conversation, + message=message, + user=user, + stream=streaming, + ) + + return ChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from) def _generate_worker( self, diff --git a/api/core/app/apps/chat/app_runner.py b/api/core/app/apps/chat/app_runner.py index f63b38fc86..077c5239f3 100644 --- a/api/core/app/apps/chat/app_runner.py +++ b/api/core/app/apps/chat/app_runner.py @@ -15,9 +15,9 @@ from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance from core.moderation.base import ModerationError from core.rag.retrieval.dataset_retrieval import DatasetRetrieval -from dify_graph.file import File -from dify_graph.model_runtime.entities.message_entities import ImagePromptMessageContent from extensions.ext_database import db +from graphon.file import File +from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent from models.model import App, Conversation, Message logger = logging.getLogger(__name__) @@ -223,7 +223,6 @@ class ChatAppRunner(AppRunner): model_parameters=application_generate_entity.model_conf.parameters, stop=stop, stream=application_generate_entity.stream, - user=application_generate_entity.user_id, ) # handle invoke result diff --git a/api/core/app/apps/common/graph_runtime_state_support.py b/api/core/app/apps/common/graph_runtime_state_support.py index 6a8e436163..2a90fbdad0 100644 --- a/api/core/app/apps/common/graph_runtime_state_support.py +++ b/api/core/app/apps/common/graph_runtime_state_support.py @@ -4,7 +4,8 @@ from __future__ import annotations from typing import TYPE_CHECKING -from dify_graph.runtime import GraphRuntimeState +from core.workflow.system_variables import SystemVariableKey, get_system_text +from graphon.runtime import GraphRuntimeState if TYPE_CHECKING: from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline @@ -30,10 +31,10 @@ class GraphRuntimeStateSupport: return self._resolve_graph_runtime_state(graph_runtime_state) def _extract_workflow_run_id(self, graph_runtime_state: GraphRuntimeState) -> str: - system_variables = graph_runtime_state.variable_pool.system_variables - if not system_variables or not system_variables.workflow_execution_id: + workflow_run_id = get_system_text(graph_runtime_state.variable_pool, SystemVariableKey.WORKFLOW_EXECUTION_ID) + if not workflow_run_id: raise ValueError("workflow_execution_id missing from runtime state") - return str(system_variables.workflow_execution_id) + return workflow_run_id def _resolve_graph_runtime_state( self, diff --git a/api/core/app/apps/common/workflow_response_converter.py b/api/core/app/apps/common/workflow_response_converter.py index 621b0d8cf3..e4aa2ff650 100644 --- a/api/core/app/apps/common/workflow_response_converter.py +++ b/api/core/app/apps/common/workflow_response_converter.py @@ -1,3 +1,4 @@ +import json import logging import time from collections.abc import Mapping, Sequence @@ -50,22 +51,23 @@ from core.tools.entities.tool_entities import ToolProviderType from core.tools.tool_manager import ToolManager from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE from core.trigger.trigger_manager import TriggerManager +from core.workflow.human_input_forms import load_form_tokens_by_form_id +from core.workflow.system_variables import SystemVariableKey, system_variables_to_mapping from core.workflow.workflow_entry import WorkflowEntry -from dify_graph.entities.pause_reason import HumanInputRequired -from dify_graph.entities.workflow_start_reason import WorkflowStartReason -from dify_graph.enums import ( +from extensions.ext_database import db +from graphon.entities.pause_reason import HumanInputRequired +from graphon.entities.workflow_start_reason import WorkflowStartReason +from graphon.enums import ( BuiltinNodeTypes, - SystemVariableKey, WorkflowExecutionStatus, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus, ) -from dify_graph.file import FILE_MODEL_IDENTITY, File -from dify_graph.runtime import GraphRuntimeState -from dify_graph.system_variable import SystemVariable -from dify_graph.variables.segments import ArrayFileSegment, FileSegment, Segment -from dify_graph.workflow_type_encoder import WorkflowRuntimeTypeConverter -from extensions.ext_database import db +from graphon.file import FILE_MODEL_IDENTITY, File +from graphon.runtime import GraphRuntimeState +from graphon.variables.segments import ArrayFileSegment, FileSegment, Segment +from graphon.variables.variables import Variable +from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter from libs.datetime_utils import naive_utc_now from models import Account, EndUser from models.human_input import HumanInputForm @@ -111,11 +113,11 @@ class WorkflowResponseConverter: *, application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity], user: Union[Account, EndUser], - system_variables: SystemVariable, + system_variables: Sequence[Variable], ): self._application_generate_entity = application_generate_entity self._user = user - self._system_variables = system_variables + self._system_variables = system_variables_to_mapping(system_variables) self._workflow_inputs = self._prepare_workflow_inputs() # Disable truncation for SERVICE_API calls to keep backward compatibility. @@ -133,7 +135,7 @@ class WorkflowResponseConverter: # ------------------------------------------------------------------ def _prepare_workflow_inputs(self) -> Mapping[str, Any]: inputs = dict(self._application_generate_entity.inputs) - for field_name, value in self._system_variables.to_dict().items(): + for field_name, value in self._system_variables.items(): # TODO(@future-refactor): store system variables separately from user inputs so we don't # need to flatten `sys.*` entries into the input payload just for rerun/export tooling. if field_name == SystemVariableKey.CONVERSATION_ID: @@ -318,13 +320,23 @@ class WorkflowResponseConverter: pause_reasons = [reason.model_dump(mode="json") for reason in event.reasons] human_input_form_ids = [reason.form_id for reason in event.reasons if isinstance(reason, HumanInputRequired)] expiration_times_by_form_id: dict[str, datetime] = {} + display_in_ui_by_form_id: dict[str, bool] = {} + form_token_by_form_id: dict[str, str] = {} if human_input_form_ids: - stmt = select(HumanInputForm.id, HumanInputForm.expiration_time).where( - HumanInputForm.id.in_(human_input_form_ids) - ) + stmt = select( + HumanInputForm.id, + HumanInputForm.expiration_time, + HumanInputForm.form_definition, + ).where(HumanInputForm.id.in_(human_input_form_ids)) with Session(bind=db.engine) as session: - for form_id, expiration_time in session.execute(stmt): + for form_id, expiration_time, form_definition in session.execute(stmt): expiration_times_by_form_id[str(form_id)] = expiration_time + try: + definition_payload = json.loads(form_definition) if form_definition else {} + except (TypeError, json.JSONDecodeError): + definition_payload = {} + display_in_ui_by_form_id[str(form_id)] = bool(definition_payload.get("display_in_ui")) + form_token_by_form_id = load_form_tokens_by_form_id(human_input_form_ids, session=session) responses: list[StreamResponse] = [] @@ -344,8 +356,8 @@ class WorkflowResponseConverter: form_content=reason.form_content, inputs=reason.inputs, actions=reason.actions, - display_in_ui=reason.display_in_ui, - form_token=reason.form_token, + display_in_ui=display_in_ui_by_form_id.get(reason.form_id, False), + form_token=form_token_by_form_id.get(reason.form_id), resolved_default_values=reason.resolved_default_values, expiration_time=int(expiration_time.timestamp()), ), diff --git a/api/core/app/apps/completion/app_generator.py b/api/core/app/apps/completion/app_generator.py index 002b914ef1..c418fe9759 100644 --- a/api/core/app/apps/completion/app_generator.py +++ b/api/core/app/apps/completion/app_generator.py @@ -1,3 +1,4 @@ +import contextvars import logging import threading import uuid @@ -20,9 +21,9 @@ from core.app.apps.message_based_app_generator import MessageBasedAppGenerator from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager from core.app.entities.app_invoke_entities import CompletionAppGenerateEntity, InvokeFrom from core.ops.ops_trace_manager import TraceQueueManager -from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError from extensions.ext_database import db from factories import file_factory +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError from models import Account, App, EndUser, Message from services.errors.app import MoreLikeThisDisabledError from services.errors.message import MessageNotExistsError @@ -108,83 +109,90 @@ class CompletionAppGenerator(MessageBasedAppGenerator): # # For implementation reference, see the `_parse_file` function and # `DraftWorkflowNodeRunApi` class which handle this properly. - files = args["files"] if args.get("files") else [] - file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict()) - if file_extra_config: - file_objs = file_factory.build_from_mappings( - mappings=files, - tenant_id=app_model.tenant_id, - config=file_extra_config, + with self._bind_file_access_scope(tenant_id=app_model.tenant_id, user=user, invoke_from=invoke_from): + files = args["files"] if args.get("files") else [] + file_extra_config = FileUploadConfigManager.convert( + override_model_config_dict or app_model_config.to_dict() ) - else: - file_objs = [] + if file_extra_config: + file_objs = file_factory.build_from_mappings( + mappings=files, + tenant_id=app_model.tenant_id, + config=file_extra_config, + access_controller=self._file_access_controller, + ) + else: + file_objs = [] - # convert to app config - app_config = CompletionAppConfigManager.get_app_config( - app_model=app_model, app_model_config=app_model_config, override_config_dict=override_model_config_dict - ) + # convert to app config + app_config = CompletionAppConfigManager.get_app_config( + app_model=app_model, app_model_config=app_model_config, override_config_dict=override_model_config_dict + ) - # get tracing instance - trace_manager = TraceQueueManager( - app_id=app_model.id, user_id=user.id if isinstance(user, Account) else user.session_id - ) + # get tracing instance + trace_manager = TraceQueueManager( + app_id=app_model.id, user_id=user.id if isinstance(user, Account) else user.session_id + ) - # init application generate entity - application_generate_entity = CompletionAppGenerateEntity( - task_id=str(uuid.uuid4()), - app_config=app_config, - model_conf=ModelConfigConverter.convert(app_config), - file_upload_config=file_extra_config, - inputs=self._prepare_user_inputs( - user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id - ), - query=query, - files=list(file_objs), - user_id=user.id, - stream=streaming, - invoke_from=invoke_from, - extras={}, - trace_manager=trace_manager, - ) + # init application generate entity + application_generate_entity = CompletionAppGenerateEntity( + task_id=str(uuid.uuid4()), + app_config=app_config, + model_conf=ModelConfigConverter.convert(app_config), + file_upload_config=file_extra_config, + inputs=self._prepare_user_inputs( + user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id + ), + query=query, + files=list(file_objs), + user_id=user.id, + stream=streaming, + invoke_from=invoke_from, + extras={}, + trace_manager=trace_manager, + ) - # init generate records - (conversation, message) = self._init_generate_records(application_generate_entity) + # init generate records + (conversation, message) = self._init_generate_records(application_generate_entity) - # init queue manager - queue_manager = MessageBasedAppQueueManager( - task_id=application_generate_entity.task_id, - user_id=application_generate_entity.user_id, - invoke_from=application_generate_entity.invoke_from, - conversation_id=conversation.id, - app_mode=conversation.mode, - message_id=message.id, - ) - - # new thread with request context - @copy_current_request_context - def worker_with_context(): - return self._generate_worker( - flask_app=current_app._get_current_object(), # type: ignore - application_generate_entity=application_generate_entity, - queue_manager=queue_manager, + # init queue manager + queue_manager = MessageBasedAppQueueManager( + task_id=application_generate_entity.task_id, + user_id=application_generate_entity.user_id, + invoke_from=application_generate_entity.invoke_from, + conversation_id=conversation.id, + app_mode=conversation.mode, message_id=message.id, ) - worker_thread = threading.Thread(target=worker_with_context) + context = contextvars.copy_context() - worker_thread.start() + # new thread with request context + @copy_current_request_context + def worker_with_context(): + return context.run( + self._generate_worker, + flask_app=current_app._get_current_object(), # type: ignore + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + message_id=message.id, + ) - # return response or stream generator - response = self._handle_response( - application_generate_entity=application_generate_entity, - queue_manager=queue_manager, - conversation=conversation, - message=message, - user=user, - stream=streaming, - ) + worker_thread = threading.Thread(target=worker_with_context) - return CompletionAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from) + worker_thread.start() + + # return response or stream generator + response = self._handle_response( + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + conversation=conversation, + message=message, + user=user, + stream=streaming, + ) + + return CompletionAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from) def _generate_worker( self, @@ -280,71 +288,76 @@ class CompletionAppGenerator(MessageBasedAppGenerator): model_dict["completion_params"] = completion_params override_model_config_dict["model"] = model_dict - # parse files - file_extra_config = FileUploadConfigManager.convert(override_model_config_dict) - if file_extra_config: - file_objs = file_factory.build_from_mappings( - mappings=message.message_files, - tenant_id=app_model.tenant_id, - config=file_extra_config, + with self._bind_file_access_scope(tenant_id=app_model.tenant_id, user=user, invoke_from=invoke_from): + # parse files + file_extra_config = FileUploadConfigManager.convert(override_model_config_dict) + if file_extra_config: + file_objs = file_factory.build_from_mappings( + mappings=message.message_files, + tenant_id=app_model.tenant_id, + config=file_extra_config, + access_controller=self._file_access_controller, + ) + else: + file_objs = [] + + # convert to app config + app_config = CompletionAppConfigManager.get_app_config( + app_model=app_model, app_model_config=app_model_config, override_config_dict=override_model_config_dict ) - else: - file_objs = [] - # convert to app config - app_config = CompletionAppConfigManager.get_app_config( - app_model=app_model, app_model_config=app_model_config, override_config_dict=override_model_config_dict - ) + # init application generate entity + application_generate_entity = CompletionAppGenerateEntity( + task_id=str(uuid.uuid4()), + app_config=app_config, + model_conf=ModelConfigConverter.convert(app_config), + inputs=message.inputs, + query=message.query, + files=list(file_objs), + user_id=user.id, + stream=stream, + invoke_from=invoke_from, + extras={}, + ) - # init application generate entity - application_generate_entity = CompletionAppGenerateEntity( - task_id=str(uuid.uuid4()), - app_config=app_config, - model_conf=ModelConfigConverter.convert(app_config), - inputs=message.inputs, - query=message.query, - files=list(file_objs), - user_id=user.id, - stream=stream, - invoke_from=invoke_from, - extras={}, - ) + # init generate records + (conversation, message) = self._init_generate_records(application_generate_entity) - # init generate records - (conversation, message) = self._init_generate_records(application_generate_entity) - - # init queue manager - queue_manager = MessageBasedAppQueueManager( - task_id=application_generate_entity.task_id, - user_id=application_generate_entity.user_id, - invoke_from=application_generate_entity.invoke_from, - conversation_id=conversation.id, - app_mode=conversation.mode, - message_id=message.id, - ) - - # new thread with request context - @copy_current_request_context - def worker_with_context(): - return self._generate_worker( - flask_app=current_app._get_current_object(), # type: ignore - application_generate_entity=application_generate_entity, - queue_manager=queue_manager, + # init queue manager + queue_manager = MessageBasedAppQueueManager( + task_id=application_generate_entity.task_id, + user_id=application_generate_entity.user_id, + invoke_from=application_generate_entity.invoke_from, + conversation_id=conversation.id, + app_mode=conversation.mode, message_id=message.id, ) - worker_thread = threading.Thread(target=worker_with_context) + context = contextvars.copy_context() - worker_thread.start() + # new thread with request context + @copy_current_request_context + def worker_with_context(): + return context.run( + self._generate_worker, + flask_app=current_app._get_current_object(), # type: ignore + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + message_id=message.id, + ) - # return response or stream generator - response = self._handle_response( - application_generate_entity=application_generate_entity, - queue_manager=queue_manager, - conversation=conversation, - message=message, - user=user, - stream=stream, - ) + worker_thread = threading.Thread(target=worker_with_context) - return CompletionAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from) + worker_thread.start() + + # return response or stream generator + response = self._handle_response( + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + conversation=conversation, + message=message, + user=user, + stream=stream, + ) + + return CompletionAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from) diff --git a/api/core/app/apps/completion/app_runner.py b/api/core/app/apps/completion/app_runner.py index 56a4519879..6bb1ecdcb1 100644 --- a/api/core/app/apps/completion/app_runner.py +++ b/api/core/app/apps/completion/app_runner.py @@ -13,9 +13,9 @@ from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCa from core.model_manager import ModelInstance from core.moderation.base import ModerationError from core.rag.retrieval.dataset_retrieval import DatasetRetrieval -from dify_graph.file import File -from dify_graph.model_runtime.entities.message_entities import ImagePromptMessageContent from extensions.ext_database import db +from graphon.file import File +from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent from models.model import App, Message logger = logging.getLogger(__name__) @@ -181,7 +181,6 @@ class CompletionAppRunner(AppRunner): model_parameters=application_generate_entity.model_conf.parameters, stop=stop, stream=application_generate_entity.stream, - user=application_generate_entity.user_id, ) # handle invoke result diff --git a/api/dify_graph/repositories/draft_variable_repository.py b/api/core/app/apps/draft_variable_saver.py similarity index 65% rename from api/dify_graph/repositories/draft_variable_repository.py rename to api/core/app/apps/draft_variable_saver.py index b2ebfacffd..24018012c5 100644 --- a/api/dify_graph/repositories/draft_variable_repository.py +++ b/api/core/app/apps/draft_variable_saver.py @@ -4,31 +4,30 @@ import abc from collections.abc import Mapping from typing import Any, Protocol -from sqlalchemy.orm import Session - -from dify_graph.enums import NodeType +from graphon.enums import NodeType class DraftVariableSaver(Protocol): @abc.abstractmethod - def save(self, process_data: Mapping[str, Any] | None, outputs: Mapping[str, Any] | None): - pass + def save(self, process_data: Mapping[str, Any] | None, outputs: Mapping[str, Any] | None) -> None: + """Persist node draft variables for a completed execution.""" + raise NotImplementedError class DraftVariableSaverFactory(Protocol): @abc.abstractmethod def __call__( self, - session: Session, app_id: str, node_id: str, node_type: NodeType, node_execution_id: str, enclosing_node_id: str | None = None, ) -> DraftVariableSaver: - pass + """Build a saver bound to a concrete node execution.""" + raise NotImplementedError class NoopDraftVariableSaver(DraftVariableSaver): - def save(self, process_data: Mapping[str, Any] | None, outputs: Mapping[str, Any] | None): - pass + def save(self, process_data: Mapping[str, Any] | None, outputs: Mapping[str, Any] | None) -> None: + return None diff --git a/api/core/app/apps/message_based_app_generator.py b/api/core/app/apps/message_based_app_generator.py index 44d10d79b8..fe61224ada 100644 --- a/api/core/app/apps/message_based_app_generator.py +++ b/api/core/app/apps/message_based_app_generator.py @@ -28,6 +28,7 @@ from core.app.entities.task_entities import ( ) from core.app.task_pipeline.easy_ui_based_generate_task_pipeline import EasyUIBasedGenerateTaskPipeline from core.prompt.utils.prompt_template_parser import PromptTemplateParser +from core.workflow.file_reference import resolve_file_record_id from extensions.ext_database import db from extensions.ext_redis import get_pubsub_broadcast_channel from libs.broadcast_channel.channel import Topic @@ -227,7 +228,7 @@ class MessageBasedAppGenerator(BaseAppGenerator): transfer_method=file.transfer_method, belongs_to=MessageFileBelongsTo.USER, url=file.remote_url, - upload_file_id=file.related_id, + upload_file_id=resolve_file_record_id(file.reference), created_by_role=(CreatorUserRole.ACCOUNT if account_id else CreatorUserRole.END_USER), created_by=account_id or end_user_id or "", ) diff --git a/api/core/app/apps/pipeline/pipeline_generator.py b/api/core/app/apps/pipeline/pipeline_generator.py index 19d67eb108..48457b5326 100644 --- a/api/core/app/apps/pipeline/pipeline_generator.py +++ b/api/core/app/apps/pipeline/pipeline_generator.py @@ -18,6 +18,7 @@ import contexts from configs import dify_config from core.app.apps.base_app_generator import BaseAppGenerator from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom +from core.app.apps.draft_variable_saver import DraftVariableSaverFactory from core.app.apps.exc import GenerateTaskStoppedError from core.app.apps.pipeline.pipeline_config_manager import PipelineConfigManager from core.app.apps.pipeline.pipeline_queue_manager import PipelineQueueManager @@ -34,13 +35,14 @@ from core.datasource.entities.datasource_entities import ( from core.datasource.online_drive.online_drive_plugin import OnlineDriveDatasourcePlugin from core.entities.knowledge_entities import PipelineDataset, PipelineDocument from core.rag.index_processor.constant.built_in_field import BuiltInField -from core.repositories.factory import DifyCoreRepositoryFactory -from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError -from dify_graph.repositories.draft_variable_repository import DraftVariableSaverFactory -from dify_graph.repositories.workflow_execution_repository import WorkflowExecutionRepository -from dify_graph.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository -from dify_graph.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader +from core.repositories.factory import ( + DifyCoreRepositoryFactory, + WorkflowExecutionRepository, + WorkflowNodeExecutionRepository, +) from extensions.ext_database import db +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError +from graphon.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader from libs.flask_utils import preserve_flask_contexts from models import Account, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom from models.dataset import Document, DocumentPipelineExecutionLog, Pipeline diff --git a/api/core/app/apps/pipeline/pipeline_runner.py b/api/core/app/apps/pipeline/pipeline_runner.py index e767766bdb..44d2450f74 100644 --- a/api/core/app/apps/pipeline/pipeline_runner.py +++ b/api/core/app/apps/pipeline/pipeline_runner.py @@ -12,19 +12,19 @@ from core.app.entities.app_invoke_entities import ( build_dify_run_context, ) from core.app.workflow.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer +from core.repositories.factory import WorkflowExecutionRepository, WorkflowNodeExecutionRepository from core.workflow.node_factory import DifyNodeFactory, get_default_root_node_id +from core.workflow.system_variables import build_bootstrap_variables, build_system_variables +from core.workflow.variable_pool_initializer import add_node_inputs_to_pool, add_variables_to_pool from core.workflow.workflow_entry import WorkflowEntry -from dify_graph.entities.graph_init_params import GraphInitParams -from dify_graph.enums import WorkflowType -from dify_graph.graph import Graph -from dify_graph.graph_events import GraphEngineEvent, GraphRunFailedEvent -from dify_graph.repositories.workflow_execution_repository import WorkflowExecutionRepository -from dify_graph.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable -from dify_graph.variable_loader import VariableLoader -from dify_graph.variables.variables import RAGPipelineVariable, RAGPipelineVariableInput from extensions.ext_database import db +from graphon.entities.graph_init_params import GraphInitParams +from graphon.enums import WorkflowType +from graphon.graph import Graph +from graphon.graph_events import GraphEngineEvent, GraphRunFailedEvent +from graphon.runtime import GraphRuntimeState, VariablePool +from graphon.variable_loader import VariableLoader +from graphon.variables.variables import RAGPipelineVariable, RAGPipelineVariableInput from models.dataset import Document, Pipeline from models.model import EndUser from models.workflow import Workflow @@ -106,13 +106,14 @@ class PipelineRunner(WorkflowBasedAppRunner): workflow=workflow, single_iteration_run=self.application_generate_entity.single_iteration_run, single_loop_run=self.application_generate_entity.single_loop_run, + user_id=self.application_generate_entity.user_id, ) else: inputs = self.application_generate_entity.inputs files = self.application_generate_entity.files # Create a variable pool. - system_inputs = SystemVariable( + system_inputs = build_system_variables( files=files, user_id=user_id, app_id=app_config.app_id, @@ -142,19 +143,25 @@ class PipelineRunner(WorkflowBasedAppRunner): ) ) - variable_pool = VariablePool( - system_variables=system_inputs, - user_inputs=inputs, - environment_variables=workflow.environment_variables, - conversation_variables=[], - rag_pipeline_variables=rag_pipeline_variables, + variable_pool = VariablePool() + add_variables_to_pool( + variable_pool, + build_bootstrap_variables( + system_variables=system_inputs, + environment_variables=workflow.environment_variables, + rag_pipeline_variables=rag_pipeline_variables, + ), ) + root_node_id = self.application_generate_entity.start_node_id or get_default_root_node_id( + workflow.graph_dict + ) + add_node_inputs_to_pool(variable_pool, node_id=root_node_id, inputs=inputs) graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) # init graph graph = self._init_rag_pipeline_graph( graph_runtime_state=graph_runtime_state, - start_node_id=self.application_generate_entity.start_node_id, + start_node_id=root_node_id, workflow=workflow, user_from=user_from, invoke_from=invoke_from, diff --git a/api/core/app/apps/workflow/app_generator.py b/api/core/app/apps/workflow/app_generator.py index 822019e170..f854035822 100644 --- a/api/core/app/apps/workflow/app_generator.py +++ b/api/core/app/apps/workflow/app_generator.py @@ -17,6 +17,7 @@ from configs import dify_config from core.app.app_config.features.file_upload.manager import FileUploadConfigManager from core.app.apps.base_app_generator import BaseAppGenerator from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom +from core.app.apps.draft_variable_saver import DraftVariableSaverFactory from core.app.apps.exc import GenerateTaskStoppedError from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager from core.app.apps.workflow.app_queue_manager import WorkflowAppQueueManager @@ -30,15 +31,13 @@ from core.db.session_factory import session_factory from core.helper.trace_id_helper import extract_external_trace_id_from_args from core.ops.ops_trace_manager import TraceQueueManager from core.repositories import DifyCoreRepositoryFactory -from dify_graph.graph_engine.layers.base import GraphEngineLayer -from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError -from dify_graph.repositories.draft_variable_repository import DraftVariableSaverFactory -from dify_graph.repositories.workflow_execution_repository import WorkflowExecutionRepository -from dify_graph.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository -from dify_graph.runtime import GraphRuntimeState -from dify_graph.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader +from core.repositories.factory import WorkflowExecutionRepository, WorkflowNodeExecutionRepository from extensions.ext_database import db from factories import file_factory +from graphon.graph_engine.layers.base import GraphEngineLayer +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError +from graphon.runtime import GraphRuntimeState +from graphon.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader from libs.flask_utils import preserve_flask_contexts from models.account import Account from models.enums import WorkflowRunTriggeredFrom @@ -148,107 +147,109 @@ class WorkflowAppGenerator(BaseAppGenerator): graph_engine_layers: Sequence[GraphEngineLayer] = (), pause_state_config: PauseStateLayerConfig | None = None, ) -> Union[Mapping[str, Any], Generator[Mapping[str, Any] | str, None, None]]: - files: Sequence[Mapping[str, Any]] = args.get("files") or [] + with self._bind_file_access_scope(tenant_id=app_model.tenant_id, user=user, invoke_from=invoke_from): + files: Sequence[Mapping[str, Any]] = args.get("files") or [] - # parse files - # TODO(QuantumGhost): Move file parsing logic to the API controller layer - # for better separation of concerns. - # - # For implementation reference, see the `_parse_file` function and - # `DraftWorkflowNodeRunApi` class which handle this properly. - file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False) - system_files = file_factory.build_from_mappings( - mappings=files, - tenant_id=app_model.tenant_id, - config=file_extra_config, - strict_type_validation=True if invoke_from == InvokeFrom.SERVICE_API else False, - ) - - # convert to app config - app_config = WorkflowAppConfigManager.get_app_config( - app_model=app_model, - workflow=workflow, - ) - - # get tracing instance - trace_manager = TraceQueueManager( - app_id=app_model.id, - user_id=user.id if isinstance(user, Account) else user.session_id, - ) - - inputs: Mapping[str, Any] = args["inputs"] - - extras = { - **extract_external_trace_id_from_args(args), - } - workflow_run_id = str(workflow_run_id or uuid.uuid4()) - # FIXME (Yeuoly): we need to remove the SKIP_PREPARE_USER_INPUTS_KEY from the args - # trigger shouldn't prepare user inputs - if self._should_prepare_user_inputs(args): - inputs = self._prepare_user_inputs( - user_inputs=inputs, - variables=app_config.variables, + # parse files + # TODO(QuantumGhost): Move file parsing logic to the API controller layer + # for better separation of concerns. + # + # For implementation reference, see the `_parse_file` function and + # `DraftWorkflowNodeRunApi` class which handle this properly. + file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False) + system_files = file_factory.build_from_mappings( + mappings=files, tenant_id=app_model.tenant_id, + config=file_extra_config, strict_type_validation=True if invoke_from == InvokeFrom.SERVICE_API else False, + access_controller=self._file_access_controller, ) - # init application generate entity - application_generate_entity = WorkflowAppGenerateEntity( - task_id=str(uuid.uuid4()), - app_config=app_config, - file_upload_config=file_extra_config, - inputs=inputs, - files=list(system_files), - user_id=user.id, - stream=streaming, - invoke_from=invoke_from, - call_depth=call_depth, - trace_manager=trace_manager, - workflow_execution_id=workflow_run_id, - extras=extras, - ) - contexts.plugin_tool_providers.set({}) - contexts.plugin_tool_providers_lock.set(threading.Lock()) + # convert to app config + app_config = WorkflowAppConfigManager.get_app_config( + app_model=app_model, + workflow=workflow, + ) - # Create repositories - # - # Create session factory - session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) - # Create workflow execution(aka workflow run) repository - if triggered_from is not None: - # Use explicitly provided triggered_from (for async triggers) - workflow_triggered_from = triggered_from - elif invoke_from == InvokeFrom.DEBUGGER: - workflow_triggered_from = WorkflowRunTriggeredFrom.DEBUGGING - else: - workflow_triggered_from = WorkflowRunTriggeredFrom.APP_RUN - workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository( - session_factory=session_factory, - user=user, - app_id=application_generate_entity.app_config.app_id, - triggered_from=workflow_triggered_from, - ) - # Create workflow node execution repository - workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository( - session_factory=session_factory, - user=user, - app_id=application_generate_entity.app_config.app_id, - triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, - ) + # get tracing instance + trace_manager = TraceQueueManager( + app_id=app_model.id, + user_id=user.id if isinstance(user, Account) else user.session_id, + ) - return self._generate( - app_model=app_model, - workflow=workflow, - user=user, - application_generate_entity=application_generate_entity, - invoke_from=invoke_from, - workflow_execution_repository=workflow_execution_repository, - workflow_node_execution_repository=workflow_node_execution_repository, - streaming=streaming, - root_node_id=root_node_id, - graph_engine_layers=graph_engine_layers, - pause_state_config=pause_state_config, - ) + inputs: Mapping[str, Any] = args["inputs"] + + extras = { + **extract_external_trace_id_from_args(args), + } + workflow_run_id = str(workflow_run_id or uuid.uuid4()) + # FIXME (Yeuoly): we need to remove the SKIP_PREPARE_USER_INPUTS_KEY from the args + # trigger shouldn't prepare user inputs + if self._should_prepare_user_inputs(args): + inputs = self._prepare_user_inputs( + user_inputs=inputs, + variables=app_config.variables, + tenant_id=app_model.tenant_id, + strict_type_validation=True if invoke_from == InvokeFrom.SERVICE_API else False, + ) + # init application generate entity + application_generate_entity = WorkflowAppGenerateEntity( + task_id=str(uuid.uuid4()), + app_config=app_config, + file_upload_config=file_extra_config, + inputs=inputs, + files=list(system_files), + user_id=user.id, + stream=streaming, + invoke_from=invoke_from, + call_depth=call_depth, + trace_manager=trace_manager, + workflow_execution_id=workflow_run_id, + extras=extras, + ) + + contexts.plugin_tool_providers.set({}) + contexts.plugin_tool_providers_lock.set(threading.Lock()) + + # Create repositories + # + # Create session factory + session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) + # Create workflow execution(aka workflow run) repository + if triggered_from is not None: + # Use explicitly provided triggered_from (for async triggers) + workflow_triggered_from = triggered_from + elif invoke_from == InvokeFrom.DEBUGGER: + workflow_triggered_from = WorkflowRunTriggeredFrom.DEBUGGING + else: + workflow_triggered_from = WorkflowRunTriggeredFrom.APP_RUN + workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository( + session_factory=session_factory, + user=user, + app_id=application_generate_entity.app_config.app_id, + triggered_from=workflow_triggered_from, + ) + # Create workflow node execution repository + workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository( + session_factory=session_factory, + user=user, + app_id=application_generate_entity.app_config.app_id, + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + ) + + return self._generate( + app_model=app_model, + workflow=workflow, + user=user, + application_generate_entity=application_generate_entity, + invoke_from=invoke_from, + workflow_execution_repository=workflow_execution_repository, + workflow_node_execution_repository=workflow_node_execution_repository, + streaming=streaming, + root_node_id=root_node_id, + graph_engine_layers=graph_engine_layers, + pause_state_config=pause_state_config, + ) def resume( self, @@ -311,62 +312,67 @@ class WorkflowAppGenerator(BaseAppGenerator): :param workflow_node_execution_repository: repository for workflow node execution :param streaming: is stream """ - graph_layers: list[GraphEngineLayer] = list(graph_engine_layers) + with self._bind_file_access_scope( + tenant_id=application_generate_entity.app_config.tenant_id, + user=user, + invoke_from=invoke_from, + ): + graph_layers: list[GraphEngineLayer] = list(graph_engine_layers) - # init queue manager - queue_manager = WorkflowAppQueueManager( - task_id=application_generate_entity.task_id, - user_id=application_generate_entity.user_id, - invoke_from=application_generate_entity.invoke_from, - app_mode=app_model.mode, - ) - - if pause_state_config is not None: - graph_layers.append( - PauseStatePersistenceLayer( - session_factory=pause_state_config.session_factory, - generate_entity=application_generate_entity, - state_owner_user_id=pause_state_config.state_owner_user_id, - ) + # init queue manager + queue_manager = WorkflowAppQueueManager( + task_id=application_generate_entity.task_id, + user_id=application_generate_entity.user_id, + invoke_from=application_generate_entity.invoke_from, + app_mode=app_model.mode, ) - # new thread with request context and contextvars - context = contextvars.copy_context() + if pause_state_config is not None: + graph_layers.append( + PauseStatePersistenceLayer( + session_factory=pause_state_config.session_factory, + generate_entity=application_generate_entity, + state_owner_user_id=pause_state_config.state_owner_user_id, + ) + ) - # release database connection, because the following new thread operations may take a long time - db.session.close() + # new thread with request context and contextvars + context = contextvars.copy_context() - worker_thread = threading.Thread( - target=self._generate_worker, - kwargs={ - "flask_app": current_app._get_current_object(), # type: ignore - "application_generate_entity": application_generate_entity, - "queue_manager": queue_manager, - "context": context, - "variable_loader": variable_loader, - "root_node_id": root_node_id, - "workflow_execution_repository": workflow_execution_repository, - "workflow_node_execution_repository": workflow_node_execution_repository, - "graph_engine_layers": tuple(graph_layers), - "graph_runtime_state": graph_runtime_state, - }, - ) + # release database connection, because the following new thread operations may take a long time + db.session.close() - worker_thread.start() + worker_thread = threading.Thread( + target=self._generate_worker, + kwargs={ + "flask_app": current_app._get_current_object(), # type: ignore + "application_generate_entity": application_generate_entity, + "queue_manager": queue_manager, + "context": context, + "variable_loader": variable_loader, + "root_node_id": root_node_id, + "workflow_execution_repository": workflow_execution_repository, + "workflow_node_execution_repository": workflow_node_execution_repository, + "graph_engine_layers": tuple(graph_layers), + "graph_runtime_state": graph_runtime_state, + }, + ) - draft_var_saver_factory = self._get_draft_var_saver_factory(invoke_from, user) + worker_thread.start() - # return response or stream generator - response = self._handle_response( - application_generate_entity=application_generate_entity, - workflow=workflow, - queue_manager=queue_manager, - user=user, - draft_var_saver_factory=draft_var_saver_factory, - stream=streaming, - ) + draft_var_saver_factory = self._get_draft_var_saver_factory(invoke_from, user) - return WorkflowAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from) + # return response or stream generator + response = self._handle_response( + application_generate_entity=application_generate_entity, + workflow=workflow, + queue_manager=queue_manager, + user=user, + draft_var_saver_factory=draft_var_saver_factory, + stream=streaming, + ) + + return WorkflowAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from) def single_iteration_generate( self, diff --git a/api/core/app/apps/workflow/app_runner.py b/api/core/app/apps/workflow/app_runner.py index caea8b6b95..c02c0b16e9 100644 --- a/api/core/app/apps/workflow/app_runner.py +++ b/api/core/app/apps/workflow/app_runner.py @@ -8,17 +8,18 @@ from core.app.apps.workflow.app_config_manager import WorkflowAppConfig from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity from core.app.workflow.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer +from core.repositories.factory import WorkflowExecutionRepository, WorkflowNodeExecutionRepository +from core.workflow.node_factory import get_default_root_node_id +from core.workflow.system_variables import build_bootstrap_variables, build_system_variables +from core.workflow.variable_pool_initializer import add_node_inputs_to_pool, add_variables_to_pool from core.workflow.workflow_entry import WorkflowEntry -from dify_graph.enums import WorkflowType -from dify_graph.graph_engine.command_channels.redis_channel import RedisChannel -from dify_graph.graph_engine.layers.base import GraphEngineLayer -from dify_graph.repositories.workflow_execution_repository import WorkflowExecutionRepository -from dify_graph.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable -from dify_graph.variable_loader import VariableLoader from extensions.ext_redis import redis_client from extensions.otel import WorkflowAppRunnerHandler, trace_span +from graphon.enums import WorkflowType +from graphon.graph_engine.command_channels.redis_channel import RedisChannel +from graphon.graph_engine.layers.base import GraphEngineLayer +from graphon.runtime import GraphRuntimeState, VariablePool +from graphon.variable_loader import VariableLoader from libs.datetime_utils import naive_utc_now from models.workflow import Workflow @@ -91,12 +92,13 @@ class WorkflowAppRunner(WorkflowBasedAppRunner): workflow=self._workflow, single_iteration_run=self.application_generate_entity.single_iteration_run, single_loop_run=self.application_generate_entity.single_loop_run, + user_id=self.application_generate_entity.user_id, ) else: inputs = self.application_generate_entity.inputs # Create a variable pool. - system_inputs = SystemVariable( + system_inputs = build_system_variables( files=self.application_generate_entity.files, user_id=self._sys_user_id, app_id=app_config.app_id, @@ -104,12 +106,16 @@ class WorkflowAppRunner(WorkflowBasedAppRunner): workflow_id=app_config.workflow_id, workflow_execution_id=self.application_generate_entity.workflow_execution_id, ) - variable_pool = VariablePool( - system_variables=system_inputs, - user_inputs=inputs, - environment_variables=self._workflow.environment_variables, - conversation_variables=[], + variable_pool = VariablePool() + add_variables_to_pool( + variable_pool, + build_bootstrap_variables( + system_variables=system_inputs, + environment_variables=self._workflow.environment_variables, + ), ) + root_node_id = self._root_node_id or get_default_root_node_id(self._workflow.graph_dict) + add_node_inputs_to_pool(variable_pool, node_id=root_node_id, inputs=inputs) graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) graph = self._init_graph( @@ -120,7 +126,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner): user_id=self.application_generate_entity.user_id, user_from=user_from, invoke_from=invoke_from, - root_node_id=self._root_node_id, + root_node_id=root_node_id, ) # RUN WORKFLOW diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py index 96dd8c5445..e0c5b44ee4 100644 --- a/api/core/app/apps/workflow/generate_task_pipeline.py +++ b/api/core/app/apps/workflow/generate_task_pipeline.py @@ -10,6 +10,7 @@ from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.apps.common.graph_runtime_state_support import GraphRuntimeStateSupport from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter +from core.app.apps.draft_variable_saver import DraftVariableSaverFactory from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity from core.app.entities.queue_entities import ( AppQueueEvent, @@ -55,12 +56,11 @@ from core.app.entities.task_entities import ( from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk from core.ops.ops_trace_manager import TraceQueueManager -from dify_graph.entities.workflow_start_reason import WorkflowStartReason -from dify_graph.enums import WorkflowExecutionStatus -from dify_graph.repositories.draft_variable_repository import DraftVariableSaverFactory -from dify_graph.runtime import GraphRuntimeState -from dify_graph.system_variable import SystemVariable +from core.workflow.system_variables import build_system_variables from extensions.ext_database import db +from graphon.entities.workflow_start_reason import WorkflowStartReason +from graphon.enums import WorkflowExecutionStatus +from graphon.runtime import GraphRuntimeState from models import Account from models.enums import CreatorUserRole from models.model import EndUser @@ -104,7 +104,7 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport): self._invoke_from = queue_manager.invoke_from self._draft_var_saver_factory = draft_var_saver_factory self._workflow = workflow - self._workflow_system_variables = SystemVariable( + self._workflow_system_variables = build_system_variables( files=application_generate_entity.files, user_id=user_session_id, app_id=application_generate_entity.app_config.app_id, @@ -705,7 +705,7 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport): app_id=self._application_generate_entity.app_config.app_id, workflow_id=self._workflow.id, workflow_run_id=workflow_run_id, - created_from=created_from.value, + created_from=created_from, created_by_role=self._created_by_role, created_by=self._user_id, ) @@ -728,13 +728,11 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport): return response def _save_output_for_event(self, event: QueueNodeSucceededEvent | QueueNodeExceptionEvent, node_execution_id: str): - with Session(db.engine) as session, session.begin(): - saver = self._draft_var_saver_factory( - session=session, - app_id=self._application_generate_entity.app_config.app_id, - node_id=event.node_id, - node_type=event.node_type, - node_execution_id=node_execution_id, - enclosing_node_id=event.in_loop_id or event.in_iteration_id, - ) - saver.save(event.process_data, event.outputs) + saver = self._draft_var_saver_factory( + app_id=self._application_generate_entity.app_config.app_id, + node_id=event.node_id, + node_type=event.node_type, + node_execution_id=node_execution_id, + enclosing_node_id=event.in_loop_id or event.in_iteration_id, + ) + saver.save(event.process_data, event.outputs) diff --git a/api/core/app/apps/workflow_app_runner.py b/api/core/app/apps/workflow_app_runner.py index adc6cce9af..d7d3bd27de 100644 --- a/api/core/app/apps/workflow_app_runner.py +++ b/api/core/app/apps/workflow_app_runner.py @@ -34,13 +34,22 @@ from core.app.entities.queue_entities import ( ) from core.rag.entities.citation_metadata import RetrievalSourceMetadata from core.workflow.node_factory import DifyNodeFactory, get_default_root_node_id, resolve_workflow_node_class +from core.workflow.system_variables import ( + build_bootstrap_variables, + default_system_variables, + get_node_creation_preload_selectors, + inject_default_system_variable_mappings, + preload_node_creation_variables, +) +from core.workflow.variable_pool_initializer import add_variables_to_pool from core.workflow.workflow_entry import WorkflowEntry -from dify_graph.entities import GraphInitParams -from dify_graph.entities.graph_config import NodeConfigDictAdapter -from dify_graph.entities.pause_reason import HumanInputRequired -from dify_graph.graph import Graph -from dify_graph.graph_engine.layers.base import GraphEngineLayer -from dify_graph.graph_events import ( +from core.workflow.workflow_run_outputs import project_node_outputs_for_workflow_run +from graphon.entities import GraphInitParams +from graphon.entities.graph_config import NodeConfigDictAdapter +from graphon.entities.pause_reason import HumanInputRequired +from graphon.graph import Graph +from graphon.graph_engine.layers.base import GraphEngineLayer +from graphon.graph_events import ( GraphEngineEvent, GraphRunFailedEvent, GraphRunPartialSucceededEvent, @@ -66,10 +75,9 @@ from dify_graph.graph_events import ( NodeRunStreamChunkEvent, NodeRunSucceededEvent, ) -from dify_graph.graph_events.graph import GraphRunAbortedEvent -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable -from dify_graph.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool +from graphon.graph_events.graph import GraphRunAbortedEvent +from graphon.runtime import GraphRuntimeState, VariablePool +from graphon.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool from models.workflow import Workflow from tasks.mail_human_input_delivery_task import dispatch_human_input_email_task @@ -156,6 +164,8 @@ class WorkflowBasedAppRunner: workflow: Workflow, single_iteration_run: Any | None = None, single_loop_run: Any | None = None, + *, + user_id: str, ) -> tuple[Graph, VariablePool, GraphRuntimeState]: """ Prepare graph, variable pool, and runtime state for single node execution @@ -173,14 +183,15 @@ class WorkflowBasedAppRunner: ValueError: If neither single_iteration_run nor single_loop_run is specified """ # Create initial runtime state with variable pool containing environment variables - graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool( - system_variables=SystemVariable.default(), - user_inputs={}, + variable_pool = VariablePool() + add_variables_to_pool( + variable_pool, + build_bootstrap_variables( + system_variables=default_system_variables(), environment_variables=workflow.environment_variables, ), - start_at=time.time(), ) + graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.time()) # Determine which type of single node execution and get graph/variable_pool if single_iteration_run: @@ -191,6 +202,7 @@ class WorkflowBasedAppRunner: graph_runtime_state=graph_runtime_state, node_type_filter_key="iteration_id", node_type_label="iteration", + user_id=user_id, ) elif single_loop_run: graph, variable_pool = self._get_graph_and_variable_pool_for_single_node_run( @@ -200,6 +212,7 @@ class WorkflowBasedAppRunner: graph_runtime_state=graph_runtime_state, node_type_filter_key="loop_id", node_type_label="loop", + user_id=user_id, ) else: raise ValueError("Neither single_iteration_run nor single_loop_run is specified") @@ -216,6 +229,8 @@ class WorkflowBasedAppRunner: graph_runtime_state: GraphRuntimeState, node_type_filter_key: str, # 'iteration_id' or 'loop_id' node_type_label: str = "node", # 'iteration' or 'loop' for error messages + *, + user_id: str = "", ) -> tuple[Graph, VariablePool]: """ Get graph and variable pool for single node execution (iteration or loop). @@ -272,6 +287,8 @@ class WorkflowBasedAppRunner: graph_config["edges"] = edge_configs + typed_node_configs = [NodeConfigDictAdapter.validate_python(node) for node in node_configs] + # Create required parameters for Graph.init graph_init_params = GraphInitParams( workflow_id=workflow.id, @@ -279,7 +296,7 @@ class WorkflowBasedAppRunner: run_context=build_dify_run_context( tenant_id=workflow.tenant_id, app_id=self._app_id, - user_id="", + user_id=user_id, user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.DEBUGGER, ), @@ -291,26 +308,15 @@ class WorkflowBasedAppRunner: graph_runtime_state=graph_runtime_state, ) - # init graph - graph = Graph.init( - graph_config=graph_config, node_factory=node_factory, root_node_id=node_id, skip_validation=True - ) - - if not graph: - raise ValueError("graph not found in workflow") - - # fetch node config from node id target_node_config = None - for node in node_configs: - if node.get("id") == node_id: + for node in typed_node_configs: + if node["id"] == node_id: target_node_config = node break if not target_node_config: raise ValueError(f"{node_type_label} node id not found in workflow graph") - target_node_config = NodeConfigDictAdapter.validate_python(target_node_config) - # Get node class node_type = target_node_config["data"].type node_version = str(target_node_config["data"].version) @@ -319,12 +325,31 @@ class WorkflowBasedAppRunner: # Use the variable pool from graph_runtime_state instead of creating a new one variable_pool = graph_runtime_state.variable_pool + preload_node_creation_variables( + variable_loader=self._variable_loader, + variable_pool=variable_pool, + selectors=[ + selector + for node_config in typed_node_configs + for selector in get_node_creation_preload_selectors( + node_type=node_config["data"].type, + node_data=node_config["data"], + ) + ], + ) + try: variable_mapping = node_cls.extract_variable_selector_to_variable_mapping( graph_config=workflow.graph_dict, config=target_node_config ) except NotImplementedError: variable_mapping = {} + variable_mapping = inject_default_system_variable_mappings( + node_id=target_node_config["id"], + node_type=node_type, + node_data=target_node_config["data"], + variable_mapping=variable_mapping, + ) load_into_variable_pool( variable_loader=self._variable_loader, @@ -340,6 +365,14 @@ class WorkflowBasedAppRunner: tenant_id=workflow.tenant_id, ) + # init graph after constructor-time context has been loaded + graph = Graph.init( + graph_config=graph_config, node_factory=node_factory, root_node_id=node_id, skip_validation=True + ) + + if not graph: + raise ValueError("graph not found in workflow") + return graph, variable_pool @staticmethod @@ -408,7 +441,11 @@ class WorkflowBasedAppRunner: node_run_result = event.node_run_result inputs = node_run_result.inputs process_data = node_run_result.process_data - outputs = node_run_result.outputs + outputs = project_node_outputs_for_workflow_run( + node_type=event.node_type, + inputs=inputs, + outputs=node_run_result.outputs, + ) execution_metadata = node_run_result.metadata self._publish_event( QueueNodeRetryEvent( @@ -448,7 +485,11 @@ class WorkflowBasedAppRunner: node_run_result = event.node_run_result inputs = node_run_result.inputs process_data = node_run_result.process_data - outputs = node_run_result.outputs + outputs = project_node_outputs_for_workflow_run( + node_type=event.node_type, + inputs=inputs, + outputs=node_run_result.outputs, + ) execution_metadata = node_run_result.metadata self._publish_event( QueueNodeSucceededEvent( @@ -466,6 +507,11 @@ class WorkflowBasedAppRunner: ) ) elif isinstance(event, NodeRunFailedEvent): + outputs = project_node_outputs_for_workflow_run( + node_type=event.node_type, + inputs=event.node_run_result.inputs, + outputs=event.node_run_result.outputs, + ) self._publish_event( QueueNodeFailedEvent( node_execution_id=event.id, @@ -475,7 +521,7 @@ class WorkflowBasedAppRunner: finished_at=event.finished_at, inputs=event.node_run_result.inputs, process_data=event.node_run_result.process_data, - outputs=event.node_run_result.outputs, + outputs=outputs, error=event.node_run_result.error or "Unknown error", execution_metadata=event.node_run_result.metadata, in_iteration_id=event.in_iteration_id, @@ -483,6 +529,11 @@ class WorkflowBasedAppRunner: ) ) elif isinstance(event, NodeRunExceptionEvent): + outputs = project_node_outputs_for_workflow_run( + node_type=event.node_type, + inputs=event.node_run_result.inputs, + outputs=event.node_run_result.outputs, + ) self._publish_event( QueueNodeExceptionEvent( node_execution_id=event.id, @@ -492,7 +543,7 @@ class WorkflowBasedAppRunner: finished_at=event.finished_at, inputs=event.node_run_result.inputs, process_data=event.node_run_result.process_data, - outputs=event.node_run_result.outputs, + outputs=outputs, error=event.node_run_result.error or "Unknown error", execution_metadata=event.node_run_result.metadata, in_iteration_id=event.in_iteration_id, diff --git a/api/core/app/entities/app_invoke_entities.py b/api/core/app/entities/app_invoke_entities.py index ecbb1cf2f3..d8d851c505 100644 --- a/api/core/app/entities/app_invoke_entities.py +++ b/api/core/app/entities/app_invoke_entities.py @@ -7,14 +7,16 @@ from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validat from constants import UUID_NIL from core.app.app_config.entities import EasyUIBasedAppConfig, WorkflowUIBasedAppConfig from core.entities.provider_configuration import ProviderModelBundle -from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY -from dify_graph.file import File, FileUploadConfig -from dify_graph.model_runtime.entities.model_entities import AIModelEntity +from graphon.file import File, FileUploadConfig +from graphon.model_runtime.entities.model_entities import AIModelEntity if TYPE_CHECKING: from core.ops.ops_trace_manager import TraceQueueManager +DIFY_RUN_CONTEXT_KEY = "_dify" + + class UserFrom(StrEnum): ACCOUNT = "account" END_USER = "end-user" diff --git a/api/core/app/entities/queue_entities.py b/api/core/app/entities/queue_entities.py index d2a36f2a0d..63857bfff2 100644 --- a/api/core/app/entities/queue_entities.py +++ b/api/core/app/entities/queue_entities.py @@ -7,10 +7,10 @@ from pydantic import BaseModel, ConfigDict, Field from core.app.entities.agent_strategy import AgentStrategyInfo from core.rag.entities.citation_metadata import RetrievalSourceMetadata -from dify_graph.entities.pause_reason import PauseReason -from dify_graph.entities.workflow_start_reason import WorkflowStartReason -from dify_graph.enums import NodeType, WorkflowNodeExecutionMetadataKey -from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk +from graphon.entities.pause_reason import PauseReason +from graphon.entities.workflow_start_reason import WorkflowStartReason +from graphon.enums import NodeType, WorkflowNodeExecutionMetadataKey +from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk class QueueEvent(StrEnum): diff --git a/api/core/app/entities/task_entities.py b/api/core/app/entities/task_entities.py index 46a8ab52f2..719027bd23 100644 --- a/api/core/app/entities/task_entities.py +++ b/api/core/app/entities/task_entities.py @@ -6,10 +6,10 @@ from pydantic import BaseModel, ConfigDict, Field from core.app.entities.agent_strategy import AgentStrategyInfo from core.rag.entities.citation_metadata import RetrievalSourceMetadata -from dify_graph.entities.workflow_start_reason import WorkflowStartReason -from dify_graph.enums import WorkflowExecutionStatus, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus -from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMUsage -from dify_graph.nodes.human_input.entities import FormInput, UserAction +from graphon.entities.workflow_start_reason import WorkflowStartReason +from graphon.enums import WorkflowExecutionStatus, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus +from graphon.model_runtime.entities.llm_entities import LLMResult, LLMUsage +from graphon.nodes.human_input.entities import FormInput, UserAction class AnnotationReplyAccount(BaseModel): diff --git a/api/core/app/features/annotation_reply/annotation_reply.py b/api/core/app/features/annotation_reply/annotation_reply.py index 87d4772815..0bd904811a 100644 --- a/api/core/app/features/annotation_reply/annotation_reply.py +++ b/api/core/app/features/annotation_reply/annotation_reply.py @@ -4,6 +4,7 @@ from sqlalchemy import select from core.app.entities.app_invoke_entities import InvokeFrom from core.rag.datasource.vdb.vector_factory import Vector +from core.rag.index_processor.constant.index_type import IndexTechniqueType from extensions.ext_database import db from models.dataset import Dataset from models.enums import CollectionBindingType, ConversationFromSource @@ -50,7 +51,7 @@ class AnnotationReplyFeature: dataset = Dataset( id=app_record.id, tenant_id=app_record.tenant_id, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, embedding_model_provider=embedding_provider_name, embedding_model=embedding_model_name, collection_binding_id=dataset_collection_binding.id, diff --git a/api/core/app/features/hosting_moderation/hosting_moderation.py b/api/core/app/features/hosting_moderation/hosting_moderation.py index 5ed1fadc41..d59f5125e3 100644 --- a/api/core/app/features/hosting_moderation/hosting_moderation.py +++ b/api/core/app/features/hosting_moderation/hosting_moderation.py @@ -2,7 +2,7 @@ import logging from core.app.entities.app_invoke_entities import EasyUIBasedAppGenerateEntity from core.helper import moderation -from dify_graph.model_runtime.entities.message_entities import PromptMessage +from graphon.model_runtime.entities.message_entities import PromptMessage logger = logging.getLogger(__name__) diff --git a/api/core/app/features/rate_limiting/rate_limit.py b/api/core/app/features/rate_limiting/rate_limit.py index 2ca1275a8a..e0f1759e5e 100644 --- a/api/core/app/features/rate_limiting/rate_limit.py +++ b/api/core/app/features/rate_limiting/rate_limit.py @@ -19,6 +19,7 @@ class RateLimit: _REQUEST_MAX_ALIVE_TIME = 10 * 60 # 10 minutes _ACTIVE_REQUESTS_COUNT_FLUSH_INTERVAL = 5 * 60 # recalculate request_count from request_detail every 5 minutes _instance_dict: dict[str, "RateLimit"] = {} + max_active_requests: int def __new__(cls, client_id: str, max_active_requests: int): if client_id not in cls._instance_dict: @@ -27,7 +28,13 @@ class RateLimit: return cls._instance_dict[client_id] def __init__(self, client_id: str, max_active_requests: int): + flush_cache = hasattr(self, "max_active_requests") and self.max_active_requests != max_active_requests self.max_active_requests = max_active_requests + # Only flush here if this instance has already been fully initialized, + # i.e. the Redis key attributes exist. Otherwise, rely on the flush at + # the end of initialization below. + if flush_cache and hasattr(self, "active_requests_key") and hasattr(self, "max_active_requests_key"): + self.flush_cache(use_local_value=True) # must be called after max_active_requests is set if self.disabled(): return @@ -41,8 +48,6 @@ class RateLimit: self.flush_cache(use_local_value=True) def flush_cache(self, use_local_value=False): - if self.disabled(): - return self.last_recalculate_time = time.time() # flush max active requests if use_local_value or not redis_client.exists(self.max_active_requests_key): @@ -50,7 +55,8 @@ class RateLimit: else: self.max_active_requests = int(redis_client.get(self.max_active_requests_key).decode("utf-8")) redis_client.expire(self.max_active_requests_key, timedelta(days=1)) - + if self.disabled(): + return # flush max active requests (in-transit request list) if not redis_client.exists(self.active_requests_key): return diff --git a/api/core/app/file_access/__init__.py b/api/core/app/file_access/__init__.py new file mode 100644 index 0000000000..a75ab9781b --- /dev/null +++ b/api/core/app/file_access/__init__.py @@ -0,0 +1,11 @@ +from .controller import DatabaseFileAccessController +from .protocols import FileAccessControllerProtocol +from .scope import FileAccessScope, bind_file_access_scope, get_current_file_access_scope + +__all__ = [ + "DatabaseFileAccessController", + "FileAccessControllerProtocol", + "FileAccessScope", + "bind_file_access_scope", + "get_current_file_access_scope", +] diff --git a/api/core/app/file_access/controller.py b/api/core/app/file_access/controller.py new file mode 100644 index 0000000000..300c187083 --- /dev/null +++ b/api/core/app/file_access/controller.py @@ -0,0 +1,103 @@ +from __future__ import annotations + +from collections.abc import Callable + +from sqlalchemy import select +from sqlalchemy.orm import Session +from sqlalchemy.sql import Select + +from models import ToolFile, UploadFile +from models.enums import CreatorUserRole + +from .protocols import FileAccessControllerProtocol +from .scope import FileAccessScope, get_current_file_access_scope + + +class DatabaseFileAccessController(FileAccessControllerProtocol): + """Workflow-layer authorization helper for database-backed file lookups. + + Tenant scoping remains mandatory. When the current execution belongs to an + end user, the lookup is additionally constrained to that end user's file + ownership markers. + """ + + _scope_getter: Callable[[], FileAccessScope | None] + + def __init__( + self, + *, + scope_getter: Callable[[], FileAccessScope | None] = get_current_file_access_scope, + ) -> None: + self._scope_getter = scope_getter + + def current_scope(self) -> FileAccessScope | None: + return self._scope_getter() + + def apply_upload_file_filters( + self, + stmt: Select[tuple[UploadFile]], + *, + scope: FileAccessScope | None = None, + ) -> Select[tuple[UploadFile]]: + resolved_scope = scope or self.current_scope() + if resolved_scope is None: + return stmt + + scoped_stmt = stmt.where(UploadFile.tenant_id == resolved_scope.tenant_id) + if not resolved_scope.requires_user_ownership: + return scoped_stmt + + return scoped_stmt.where( + UploadFile.created_by_role == CreatorUserRole.END_USER, + UploadFile.created_by == resolved_scope.user_id, + ) + + def apply_tool_file_filters( + self, + stmt: Select[tuple[ToolFile]], + *, + scope: FileAccessScope | None = None, + ) -> Select[tuple[ToolFile]]: + resolved_scope = scope or self.current_scope() + if resolved_scope is None: + return stmt + + scoped_stmt = stmt.where(ToolFile.tenant_id == resolved_scope.tenant_id) + if not resolved_scope.requires_user_ownership: + return scoped_stmt + + return scoped_stmt.where(ToolFile.user_id == resolved_scope.user_id) + + def get_upload_file( + self, + *, + session: Session, + file_id: str, + scope: FileAccessScope | None = None, + ) -> UploadFile | None: + resolved_scope = scope or self.current_scope() + if resolved_scope is None: + return session.get(UploadFile, file_id) + + stmt = self.apply_upload_file_filters( + select(UploadFile).where(UploadFile.id == file_id), + scope=resolved_scope, + ) + return session.scalar(stmt) + + def get_tool_file( + self, + *, + session: Session, + file_id: str, + scope: FileAccessScope | None = None, + ) -> ToolFile | None: + resolved_scope = scope or self.current_scope() + if resolved_scope is None: + return session.get(ToolFile, file_id) + + stmt = self.apply_tool_file_filters( + select(ToolFile).where(ToolFile.id == file_id), + scope=resolved_scope, + ) + return session.scalar(stmt) diff --git a/api/core/app/file_access/protocols.py b/api/core/app/file_access/protocols.py new file mode 100644 index 0000000000..8bb3eb9924 --- /dev/null +++ b/api/core/app/file_access/protocols.py @@ -0,0 +1,81 @@ +from __future__ import annotations + +from typing import Protocol + +from sqlalchemy.orm import Session +from sqlalchemy.sql import Select + +from models import ToolFile, UploadFile + +from .scope import FileAccessScope + + +class FileAccessControllerProtocol(Protocol): + """Contract for applying access rules to file lookups. + + Implementations translate an optional execution scope into query constraints + and authorized record retrieval. The contract is intentionally limited to + ownership and tenancy rules for workflow-layer file access. + """ + + def current_scope(self) -> FileAccessScope | None: + """Return the scope active for the current execution, if one exists. + + Callers use this to decide whether embedded file metadata may be trusted + or whether a fresh authorized lookup is required. + """ + ... + + def apply_upload_file_filters( + self, + stmt: Select[tuple[UploadFile]], + *, + scope: FileAccessScope | None = None, + ) -> Select[tuple[UploadFile]]: + """Return an upload-file query constrained by the supplied access scope. + + The returned statement must preserve the caller's existing predicates and + append only access-control conditions. + """ + ... + + def apply_tool_file_filters( + self, + stmt: Select[tuple[ToolFile]], + *, + scope: FileAccessScope | None = None, + ) -> Select[tuple[ToolFile]]: + """Return a tool-file query constrained by the supplied access scope. + + The returned statement must preserve the caller's existing predicates and + append only access-control conditions. + """ + ... + + def get_upload_file( + self, + *, + session: Session, + file_id: str, + scope: FileAccessScope | None = None, + ) -> UploadFile | None: + """Load one authorized upload-file record for the given identifier. + + Returns ``None`` when the file does not exist or when the scope does not + permit access to that record. + """ + ... + + def get_tool_file( + self, + *, + session: Session, + file_id: str, + scope: FileAccessScope | None = None, + ) -> ToolFile | None: + """Load one authorized tool-file record for the given identifier. + + Returns ``None`` when the file does not exist or when the scope does not + permit access to that record. + """ + ... diff --git a/api/core/app/file_access/scope.py b/api/core/app/file_access/scope.py new file mode 100644 index 0000000000..80d504ef1c --- /dev/null +++ b/api/core/app/file_access/scope.py @@ -0,0 +1,40 @@ +from __future__ import annotations + +from collections.abc import Iterator +from contextlib import contextmanager +from contextvars import ContextVar +from dataclasses import dataclass + +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom + +_current_file_access_scope: ContextVar[FileAccessScope | None] = ContextVar( + "current_file_access_scope", + default=None, +) + + +@dataclass(frozen=True, slots=True) +class FileAccessScope: + """Request-scoped ownership context used by workflow-layer file lookups.""" + + tenant_id: str + user_id: str + user_from: UserFrom + invoke_from: InvokeFrom + + @property + def requires_user_ownership(self) -> bool: + return self.user_from == UserFrom.END_USER + + +def get_current_file_access_scope() -> FileAccessScope | None: + return _current_file_access_scope.get() + + +@contextmanager +def bind_file_access_scope(scope: FileAccessScope) -> Iterator[None]: + token = _current_file_access_scope.set(scope) + try: + yield + finally: + _current_file_access_scope.reset(token) diff --git a/api/core/app/layers/conversation_variable_persist_layer.py b/api/core/app/layers/conversation_variable_persist_layer.py index d227e4e904..eeb9abbbfa 100644 --- a/api/core/app/layers/conversation_variable_persist_layer.py +++ b/api/core/app/layers/conversation_variable_persist_layer.py @@ -1,12 +1,19 @@ +""" +Persist conversation-scoped variable updates emitted by the graph engine. + +The graph package emits generic variable update events and stays unaware of +conversation identity or storage concerns. This layer lives in the application +core, listens to those generic events, and persists only the `conversation.*` +scope updates that matter to chat applications. +""" + import logging -from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID -from dify_graph.conversation_variable_updater import ConversationVariableUpdater -from dify_graph.enums import BuiltinNodeTypes -from dify_graph.graph_engine.layers.base import GraphEngineLayer -from dify_graph.graph_events import GraphEngineEvent, NodeRunSucceededEvent -from dify_graph.nodes.variable_assigner.common import helpers as common_helpers -from dify_graph.variables import VariableBase +from core.workflow.system_variables import SystemVariableKey, get_system_text +from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID +from graphon.graph_engine.layers.base import GraphEngineLayer +from graphon.graph_events import GraphEngineEvent, NodeRunVariableUpdatedEvent +from services.conversation_variable_updater import ConversationVariableUpdater logger = logging.getLogger(__name__) @@ -20,41 +27,22 @@ class ConversationVariablePersistenceLayer(GraphEngineLayer): pass def on_event(self, event: GraphEngineEvent) -> None: - if not isinstance(event, NodeRunSucceededEvent): - return - if event.node_type != BuiltinNodeTypes.VARIABLE_ASSIGNER: - return - if self.graph_runtime_state is None: + if not isinstance(event, NodeRunVariableUpdatedEvent): return - updated_variables = common_helpers.get_updated_variables(event.node_run_result.process_data) or [] - if not updated_variables: + selector = event.variable.selector + if len(selector) < 2: + logger.warning("Conversation variable selector invalid. selector=%s", selector) return - conversation_id = self.graph_runtime_state.system_variable.conversation_id + conversation_id = get_system_text(self.graph_runtime_state.variable_pool, SystemVariableKey.CONVERSATION_ID) if conversation_id is None: return - updated_any = False - for item in updated_variables: - selector = item.selector - if len(selector) < 2: - logger.warning("Conversation variable selector invalid. selector=%s", selector) - continue - if selector[0] != CONVERSATION_VARIABLE_NODE_ID: - continue - variable = self.graph_runtime_state.variable_pool.get(selector) - if not isinstance(variable, VariableBase): - logger.warning( - "Conversation variable not found in variable pool. selector=%s", - selector, - ) - continue - self._conversation_variable_updater.update(conversation_id=conversation_id, variable=variable) - updated_any = True + if selector[0] != CONVERSATION_VARIABLE_NODE_ID: + return - if updated_any: - self._conversation_variable_updater.flush() + self._conversation_variable_updater.update(conversation_id=conversation_id, variable=event.variable) def on_graph_end(self, error: Exception | None) -> None: pass diff --git a/api/core/app/layers/pause_state_persist_layer.py b/api/core/app/layers/pause_state_persist_layer.py index 4370c01a0b..98e2257b1f 100644 --- a/api/core/app/layers/pause_state_persist_layer.py +++ b/api/core/app/layers/pause_state_persist_layer.py @@ -6,9 +6,10 @@ from sqlalchemy import Engine from sqlalchemy.orm import Session, sessionmaker from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity -from dify_graph.graph_engine.layers.base import GraphEngineLayer -from dify_graph.graph_events.base import GraphEngineEvent -from dify_graph.graph_events.graph import GraphRunPausedEvent +from core.workflow.system_variables import SystemVariableKey, get_system_text +from graphon.graph_engine.layers.base import GraphEngineLayer +from graphon.graph_events.base import GraphEngineEvent +from graphon.graph_events.graph import GraphRunPausedEvent from models.model import AppMode from repositories.api_workflow_run_repository import APIWorkflowRunRepository from repositories.factory import DifyAPIRepositoryFactory @@ -119,7 +120,10 @@ class PauseStatePersistenceLayer(GraphEngineLayer): generate_entity=entity_wrapper, ) - workflow_run_id: str | None = self.graph_runtime_state.system_variable.workflow_execution_id + workflow_run_id = get_system_text( + self.graph_runtime_state.variable_pool, + SystemVariableKey.WORKFLOW_EXECUTION_ID, + ) assert workflow_run_id is not None repo = self._get_repo() repo.create_workflow_pause( diff --git a/api/core/app/layers/suspend_layer.py b/api/core/app/layers/suspend_layer.py index 2adaf14a35..172306f271 100644 --- a/api/core/app/layers/suspend_layer.py +++ b/api/core/app/layers/suspend_layer.py @@ -1,21 +1,28 @@ -from dify_graph.graph_engine.layers.base import GraphEngineLayer -from dify_graph.graph_events.base import GraphEngineEvent -from dify_graph.graph_events.graph import GraphRunPausedEvent +from graphon.graph_engine.layers.base import GraphEngineLayer +from graphon.graph_events.base import GraphEngineEvent +from graphon.graph_events.graph import GraphRunPausedEvent class SuspendLayer(GraphEngineLayer): """ """ + def __init__(self) -> None: + super().__init__() + self._paused = False + def on_graph_start(self): - pass + self._paused = False def on_event(self, event: GraphEngineEvent): """ Handle the paused event, stash runtime state into storage and wait for resume. """ if isinstance(event, GraphRunPausedEvent): - pass + self._paused = True def on_graph_end(self, error: Exception | None): """ """ - pass + self._paused = False + + def is_paused(self) -> bool: + return self._paused diff --git a/api/core/app/layers/timeslice_layer.py b/api/core/app/layers/timeslice_layer.py index d7ca45f209..fef12df504 100644 --- a/api/core/app/layers/timeslice_layer.py +++ b/api/core/app/layers/timeslice_layer.py @@ -4,9 +4,9 @@ from typing import ClassVar from apscheduler.schedulers.background import BackgroundScheduler # type: ignore -from dify_graph.graph_engine.entities.commands import CommandType, GraphEngineCommand -from dify_graph.graph_engine.layers.base import GraphEngineLayer -from dify_graph.graph_events.base import GraphEngineEvent +from graphon.graph_engine.entities.commands import CommandType, GraphEngineCommand +from graphon.graph_engine.layers.base import GraphEngineLayer +from graphon.graph_events.base import GraphEngineEvent from services.workflow.entities import WorkflowScheduleCFSPlanEntity from services.workflow.scheduler import CFSPlanScheduler, SchedulerCommand diff --git a/api/core/app/layers/trigger_post_layer.py b/api/core/app/layers/trigger_post_layer.py index a4019a83e1..781a0aa3d3 100644 --- a/api/core/app/layers/trigger_post_layer.py +++ b/api/core/app/layers/trigger_post_layer.py @@ -5,9 +5,10 @@ from typing import Any, ClassVar from pydantic import TypeAdapter from core.db.session_factory import session_factory -from dify_graph.graph_engine.layers.base import GraphEngineLayer -from dify_graph.graph_events.base import GraphEngineEvent -from dify_graph.graph_events.graph import GraphRunFailedEvent, GraphRunPausedEvent, GraphRunSucceededEvent +from core.workflow.system_variables import SystemVariableKey, get_system_text +from graphon.graph_engine.layers.base import GraphEngineLayer +from graphon.graph_events.base import GraphEngineEvent +from graphon.graph_events.graph import GraphRunFailedEvent, GraphRunPausedEvent, GraphRunSucceededEvent from models.enums import WorkflowTriggerStatus from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository from tasks.workflow_cfs_scheduler.cfs_scheduler import AsyncWorkflowCFSPlanEntity @@ -59,7 +60,10 @@ class TriggerPostLayer(GraphEngineLayer): outputs = self.graph_runtime_state.outputs # BASICLY, workflow_execution_id is the same as workflow_run_id - workflow_run_id = self.graph_runtime_state.system_variable.workflow_execution_id + workflow_run_id = get_system_text( + self.graph_runtime_state.variable_pool, + SystemVariableKey.WORKFLOW_EXECUTION_ID, + ) assert workflow_run_id, "Workflow run id is not set" total_tokens = self.graph_runtime_state.total_tokens diff --git a/api/core/app/llm/model_access.py b/api/core/app/llm/model_access.py index a63ff39fa5..c49c4eb0ac 100644 --- a/api/core/app/llm/model_access.py +++ b/api/core/app/llm/model_access.py @@ -2,23 +2,34 @@ from __future__ import annotations from typing import Any -from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity +from core.app.entities.app_invoke_entities import DifyRunContext, ModelConfigWithCredentialsEntity from core.errors.error import ProviderTokenNotInitError from core.model_manager import ModelInstance, ModelManager +from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager from core.provider_manager import ProviderManager -from dify_graph.model_runtime.entities.model_entities import ModelType -from dify_graph.nodes.llm.entities import ModelConfig -from dify_graph.nodes.llm.exc import LLMModeRequiredError, ModelNotExistError -from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.nodes.llm.entities import ModelConfig +from graphon.nodes.llm.exc import LLMModeRequiredError, ModelNotExistError +from graphon.nodes.llm.protocols import CredentialsProvider class DifyCredentialsProvider: tenant_id: str provider_manager: ProviderManager - def __init__(self, tenant_id: str, provider_manager: ProviderManager | None = None) -> None: - self.tenant_id = tenant_id - self.provider_manager = provider_manager or ProviderManager() + def __init__( + self, + *, + run_context: DifyRunContext, + provider_manager: ProviderManager | None = None, + ) -> None: + self.tenant_id = run_context.tenant_id + if provider_manager is None: + provider_manager = create_plugin_provider_manager( + tenant_id=run_context.tenant_id, + user_id=run_context.user_id, + ) + self.provider_manager = provider_manager def fetch(self, provider_name: str, model_name: str) -> dict[str, Any]: provider_configurations = self.provider_manager.get_configurations(self.tenant_id) @@ -42,9 +53,21 @@ class DifyModelFactory: tenant_id: str model_manager: ModelManager - def __init__(self, tenant_id: str, model_manager: ModelManager | None = None) -> None: - self.tenant_id = tenant_id - self.model_manager = model_manager or ModelManager() + def __init__( + self, + *, + run_context: DifyRunContext, + model_manager: ModelManager | None = None, + ) -> None: + self.tenant_id = run_context.tenant_id + if model_manager is None: + model_manager = ModelManager( + provider_manager=create_plugin_provider_manager( + tenant_id=run_context.tenant_id, + user_id=run_context.user_id, + ) + ) + self.model_manager = model_manager def init_model_instance(self, provider_name: str, model_name: str) -> ModelInstance: return self.model_manager.get_model_instance( @@ -55,18 +78,42 @@ class DifyModelFactory: ) -def build_dify_model_access(tenant_id: str) -> tuple[CredentialsProvider, ModelFactory]: - return ( - DifyCredentialsProvider(tenant_id=tenant_id), - DifyModelFactory(tenant_id=tenant_id), +def build_dify_model_access(run_context: DifyRunContext) -> tuple[CredentialsProvider, DifyModelFactory]: + """Create LLM access adapters that share the same tenant-bound manager graph.""" + provider_manager = create_plugin_provider_manager( + tenant_id=run_context.tenant_id, + user_id=run_context.user_id, ) + model_manager = ModelManager(provider_manager=provider_manager) + + return ( + DifyCredentialsProvider(run_context=run_context, provider_manager=provider_manager), + DifyModelFactory(run_context=run_context, model_manager=model_manager), + ) + + +def _normalize_completion_params(completion_params: dict[str, Any]) -> tuple[dict[str, Any], list[str]]: + """ + Split node-level completion params into provider parameters and stop sequences. + + Workflow LLM-compatible nodes still consume runtime invocation settings from + ``ModelInstance.parameters`` and ``ModelInstance.stop``. Keep the + ``ModelInstance`` view and the returned config entity aligned here so callers + do not need to duplicate normalization logic. + """ + normalized_parameters = dict(completion_params) + stop = normalized_parameters.pop("stop", []) + if not isinstance(stop, list) or not all(isinstance(item, str) for item in stop): + stop = [] + + return normalized_parameters, stop def fetch_model_config( *, node_data_model: ModelConfig, credentials_provider: CredentialsProvider, - model_factory: ModelFactory, + model_factory: DifyModelFactory, ) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]: if not node_data_model.mode: raise LLMModeRequiredError("LLM mode is required.") @@ -80,22 +127,18 @@ def fetch_model_config( model_type=ModelType.LLM, ) if provider_model is None: - raise ModelNotExistError(f"Model {node_data_model.name} not exist.") + raise ModelNotExistError(f"Model {node_data_model.name} does not exist.") provider_model.raise_for_status() - completion_params = dict(node_data_model.completion_params) - stop = completion_params.pop("stop", []) - if not isinstance(stop, list): - stop = [] - model_schema = model_instance.model_type_instance.get_model_schema(node_data_model.name, credentials) - if not model_schema: - raise ModelNotExistError(f"Model {node_data_model.name} not exist.") + if model_schema is None: + raise ModelNotExistError(f"Model {node_data_model.name} schema does not exist.") + parameters, stop = _normalize_completion_params(node_data_model.completion_params) model_instance.provider = node_data_model.provider model_instance.model_name = node_data_model.name model_instance.credentials = credentials - model_instance.parameters = completion_params + model_instance.parameters = parameters model_instance.stop = tuple(stop) return model_instance, ModelConfigWithCredentialsEntity( @@ -103,8 +146,8 @@ def fetch_model_config( model=node_data_model.name, model_schema=model_schema, mode=node_data_model.mode, - provider_model_bundle=provider_model_bundle, credentials=credentials, - parameters=completion_params, + parameters=parameters, stop=stop, + provider_model_bundle=provider_model_bundle, ) diff --git a/api/core/app/llm/quota.py b/api/core/app/llm/quota.py index 7aa3bf15ab..65a3f39d64 100644 --- a/api/core/app/llm/quota.py +++ b/api/core/app/llm/quota.py @@ -6,8 +6,8 @@ from core.entities.model_entities import ModelStatus from core.entities.provider_entities import ProviderQuotaType, QuotaUnit from core.errors.error import QuotaExceededError from core.model_manager import ModelInstance -from dify_graph.model_runtime.entities.llm_entities import LLMUsage from extensions.ext_database import db +from graphon.model_runtime.entities.llm_entities import LLMUsage from libs.datetime_utils import naive_utc_now from models.provider import Provider, ProviderType from models.provider_ids import ModelProviderID diff --git a/api/core/app/task_pipeline/based_generate_task_pipeline.py b/api/core/app/task_pipeline/based_generate_task_pipeline.py index 0d5e0acec6..9e688589db 100644 --- a/api/core/app/task_pipeline/based_generate_task_pipeline.py +++ b/api/core/app/task_pipeline/based_generate_task_pipeline.py @@ -17,7 +17,7 @@ from core.app.entities.task_entities import ( ) from core.errors.error import QuotaExceededError from core.moderation.output_moderation import ModerationRule, OutputModeration -from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from models.enums import MessageStatus from models.model import Message diff --git a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py index b530fe1ce4..cf9cb6d051 100644 --- a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py +++ b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py @@ -51,15 +51,15 @@ from core.ops.entities.trace_entity import TraceTaskName from core.ops.ops_trace_manager import TraceQueueManager, TraceTask from core.prompt.utils.prompt_message_util import PromptMessageUtil from core.prompt.utils.prompt_template_parser import PromptTemplateParser -from dify_graph.file.enums import FileTransferMethod -from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage -from dify_graph.model_runtime.entities.message_entities import ( +from events.message_event import message_was_created +from extensions.ext_database import db +from graphon.file.enums import FileTransferMethod +from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage +from graphon.model_runtime.entities.message_entities import ( AssistantPromptMessage, TextPromptMessageContent, ) -from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from events.message_event import message_was_created -from extensions.ext_database import db +from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from libs.datetime_utils import naive_utc_now from models.model import AppMode, Conversation, Message, MessageAgentThought, MessageFile, UploadFile diff --git a/api/core/app/task_pipeline/message_file_utils.py b/api/core/app/task_pipeline/message_file_utils.py index fc8b6c6b5a..45f622c469 100644 --- a/api/core/app/task_pipeline/message_file_utils.py +++ b/api/core/app/task_pipeline/message_file_utils.py @@ -1,8 +1,8 @@ from typing import TypedDict from core.tools.signature import sign_tool_file -from dify_graph.file import helpers as file_helpers -from dify_graph.file.enums import FileTransferMethod +from graphon.file import helpers as file_helpers +from graphon.file.enums import FileTransferMethod from models.model import MessageFile, UploadFile MAX_TOOL_FILE_EXTENSION_LENGTH = 10 diff --git a/api/core/app/workflow/file_runtime.py b/api/core/app/workflow/file_runtime.py index e0f8d27111..aa5291bad5 100644 --- a/api/core/app/workflow/file_runtime.py +++ b/api/core/app/workflow/file_runtime.py @@ -1,33 +1,42 @@ from __future__ import annotations +import base64 +import hashlib +import hmac +import os +import time +import urllib.parse from collections.abc import Generator +from typing import TYPE_CHECKING, Literal from configs import dify_config +from core.app.file_access import DatabaseFileAccessController, FileAccessControllerProtocol +from core.db.session_factory import session_factory from core.helper.ssrf_proxy import ssrf_proxy from core.tools.signature import sign_tool_file -from dify_graph.file.protocols import HttpResponseProtocol, WorkflowFileRuntimeProtocol -from dify_graph.file.runtime import set_workflow_file_runtime +from core.workflow.file_reference import parse_file_reference from extensions.ext_storage import storage +from graphon.file.enums import FileTransferMethod +from graphon.file.protocols import HttpResponseProtocol, WorkflowFileRuntimeProtocol +from graphon.file.runtime import set_workflow_file_runtime + +if TYPE_CHECKING: + from graphon.file.models import File class DifyWorkflowFileRuntime(WorkflowFileRuntimeProtocol): - """Production runtime wiring for ``dify_graph.file``.""" + """Production runtime wiring for ``graphon.file``. - @property - def files_url(self) -> str: - return dify_config.FILES_URL + Opaque file references are resolved back to canonical database records before + URLs are signed or storage keys are used. When a request-scoped file access + scope is present, those lookups additionally enforce tenant and end-user + ownership filters. + """ - @property - def internal_files_url(self) -> str | None: - return dify_config.INTERNAL_FILES_URL + _file_access_controller: FileAccessControllerProtocol - @property - def secret_key(self) -> str: - return dify_config.SECRET_KEY - - @property - def files_access_timeout(self) -> int: - return dify_config.FILES_ACCESS_TIMEOUT + def __init__(self, *, file_access_controller: FileAccessControllerProtocol) -> None: + self._file_access_controller = file_access_controller @property def multimodal_send_format(self) -> str: @@ -39,9 +48,137 @@ class DifyWorkflowFileRuntime(WorkflowFileRuntimeProtocol): def storage_load(self, path: str, *, stream: bool = False) -> bytes | Generator: return storage.load(path, stream=stream) - def sign_tool_file(self, *, tool_file_id: str, extension: str, for_external: bool = True) -> str: + def load_file_bytes(self, *, file: File) -> bytes: + storage_key = self._resolve_storage_key(file=file) + data = storage.load(storage_key, stream=False) + if not isinstance(data, bytes): + raise ValueError(f"file {storage_key} is not a bytes object") + return data + + def resolve_file_url(self, *, file: File, for_external: bool = True) -> str | None: + if file.transfer_method == FileTransferMethod.REMOTE_URL: + return file.remote_url + parsed_reference = parse_file_reference(file.reference) + if parsed_reference is None: + raise ValueError("Missing file reference") + if file.transfer_method == FileTransferMethod.LOCAL_FILE: + return self.resolve_upload_file_url( + upload_file_id=parsed_reference.record_id, + for_external=for_external, + ) + if file.transfer_method == FileTransferMethod.DATASOURCE_FILE: + if file.extension is None: + raise ValueError("Missing file extension") + self._assert_upload_file_access(upload_file_id=parsed_reference.record_id) + return sign_tool_file( + tool_file_id=parsed_reference.record_id, + extension=file.extension, + for_external=for_external, + ) + if file.transfer_method == FileTransferMethod.TOOL_FILE: + if file.extension is None: + raise ValueError("Missing file extension") + return self.resolve_tool_file_url( + tool_file_id=parsed_reference.record_id, + extension=file.extension, + for_external=for_external, + ) + return None + + def resolve_upload_file_url( + self, + *, + upload_file_id: str, + as_attachment: bool = False, + for_external: bool = True, + ) -> str: + self._assert_upload_file_access(upload_file_id=upload_file_id) + base_url = self._base_url(for_external=for_external) + url = f"{base_url}/files/{upload_file_id}/file-preview" + query = self._sign_query(payload=f"file-preview|{upload_file_id}") + if as_attachment: + query["as_attachment"] = "true" + return f"{url}?{urllib.parse.urlencode(query)}" + + def resolve_tool_file_url(self, *, tool_file_id: str, extension: str, for_external: bool = True) -> str: + self._assert_tool_file_access(tool_file_id=tool_file_id) return sign_tool_file(tool_file_id=tool_file_id, extension=extension, for_external=for_external) + def verify_preview_signature( + self, + *, + preview_kind: Literal["image", "file"], + file_id: str, + timestamp: str, + nonce: str, + sign: str, + ) -> bool: + payload = f"{preview_kind}-preview|{file_id}|{timestamp}|{nonce}" + recalculated = hmac.new(self._secret_key(), payload.encode(), hashlib.sha256).digest() + if sign != base64.urlsafe_b64encode(recalculated).decode(): + return False + return int(time.time()) - int(timestamp) <= dify_config.FILES_ACCESS_TIMEOUT + + @staticmethod + def _base_url(*, for_external: bool) -> str: + if for_external: + return dify_config.FILES_URL + return dify_config.INTERNAL_FILES_URL or dify_config.FILES_URL + + @staticmethod + def _secret_key() -> bytes: + return dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b"" + + def _sign_query(self, *, payload: str) -> dict[str, str]: + timestamp = str(int(time.time())) + nonce = os.urandom(16).hex() + sign = hmac.new(self._secret_key(), f"{payload}|{timestamp}|{nonce}".encode(), hashlib.sha256).digest() + return { + "timestamp": timestamp, + "nonce": nonce, + "sign": base64.urlsafe_b64encode(sign).decode(), + } + + def _resolve_storage_key(self, *, file: File) -> str: + parsed_reference = parse_file_reference(file.reference) + if parsed_reference is None: + raise ValueError("Missing file reference") + + record_id = parsed_reference.record_id + with session_factory.create_session() as session: + if file.transfer_method in { + FileTransferMethod.LOCAL_FILE, + FileTransferMethod.REMOTE_URL, + FileTransferMethod.DATASOURCE_FILE, + }: + upload_file = self._file_access_controller.get_upload_file(session=session, file_id=record_id) + if upload_file is None: + raise ValueError(f"Upload file {record_id} not found") + return upload_file.key + + tool_file = self._file_access_controller.get_tool_file(session=session, file_id=record_id) + if tool_file is None: + raise ValueError(f"Tool file {record_id} not found") + return tool_file.file_key + + def _assert_upload_file_access(self, *, upload_file_id: str) -> None: + if self._file_access_controller.current_scope() is None: + return + + with session_factory.create_session() as session: + upload_file = self._file_access_controller.get_upload_file(session=session, file_id=upload_file_id) + if upload_file is None: + raise ValueError(f"Upload file {upload_file_id} not found") + + def _assert_tool_file_access(self, *, tool_file_id: str) -> None: + if self._file_access_controller.current_scope() is None: + return + + with session_factory.create_session() as session: + tool_file = self._file_access_controller.get_tool_file(session=session, file_id=tool_file_id) + if tool_file is None: + raise ValueError(f"Tool file {tool_file_id} not found") + def bind_dify_workflow_file_runtime() -> None: - set_workflow_file_runtime(DifyWorkflowFileRuntime()) + set_workflow_file_runtime(DifyWorkflowFileRuntime(file_access_controller=DatabaseFileAccessController())) diff --git a/api/core/app/workflow/layers/llm_quota.py b/api/core/app/workflow/layers/llm_quota.py index faf1516c40..5666bf1191 100644 --- a/api/core/app/workflow/layers/llm_quota.py +++ b/api/core/app/workflow/layers/llm_quota.py @@ -9,20 +9,21 @@ from typing import TYPE_CHECKING, cast, final from typing_extensions import override +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext from core.app.llm import deduct_llm_quota, ensure_llm_quota_available from core.errors.error import QuotaExceededError from core.model_manager import ModelInstance -from dify_graph.enums import BuiltinNodeTypes -from dify_graph.graph_engine.entities.commands import AbortCommand, CommandType -from dify_graph.graph_engine.layers.base import GraphEngineLayer -from dify_graph.graph_events import GraphEngineEvent, GraphNodeEventBase -from dify_graph.graph_events.node import NodeRunSucceededEvent -from dify_graph.nodes.base.node import Node +from graphon.enums import BuiltinNodeTypes +from graphon.graph_engine.entities.commands import AbortCommand, CommandType +from graphon.graph_engine.layers.base import GraphEngineLayer +from graphon.graph_events import GraphEngineEvent, GraphNodeEventBase +from graphon.graph_events.node import NodeRunSucceededEvent +from graphon.nodes.base.node import Node if TYPE_CHECKING: - from dify_graph.nodes.llm.node import LLMNode - from dify_graph.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode - from dify_graph.nodes.question_classifier.question_classifier_node import QuestionClassifierNode + from graphon.nodes.llm.node import LLMNode + from graphon.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode + from graphon.nodes.question_classifier.question_classifier_node import QuestionClassifierNode logger = logging.getLogger(__name__) @@ -75,7 +76,7 @@ class LLMQuotaLayer(GraphEngineLayer): return try: - dify_ctx = node.require_dify_context() + dify_ctx = DifyRunContext.model_validate(node.require_run_context_value(DIFY_RUN_CONTEXT_KEY)) deduct_llm_quota( tenant_id=dify_ctx.tenant_id, model_instance=model_instance, @@ -114,11 +115,11 @@ class LLMQuotaLayer(GraphEngineLayer): try: match node.node_type: case BuiltinNodeTypes.LLM: - return cast("LLMNode", node).model_instance + model_instance = cast("LLMNode", node).model_instance case BuiltinNodeTypes.PARAMETER_EXTRACTOR: - return cast("ParameterExtractorNode", node).model_instance + model_instance = cast("ParameterExtractorNode", node).model_instance case BuiltinNodeTypes.QUESTION_CLASSIFIER: - return cast("QuestionClassifierNode", node).model_instance + model_instance = cast("QuestionClassifierNode", node).model_instance case _: return None except AttributeError: @@ -127,3 +128,12 @@ class LLMQuotaLayer(GraphEngineLayer): node.id, ) return None + + if isinstance(model_instance, ModelInstance): + return model_instance + + raw_model_instance = getattr(model_instance, "_model_instance", None) + if isinstance(raw_model_instance, ModelInstance): + return raw_model_instance + + return None diff --git a/api/core/app/workflow/layers/observability.py b/api/core/app/workflow/layers/observability.py index 4b20477a7f..837bf7ff81 100644 --- a/api/core/app/workflow/layers/observability.py +++ b/api/core/app/workflow/layers/observability.py @@ -16,10 +16,6 @@ from opentelemetry.trace import Span, SpanKind, Tracer, get_tracer, set_span_in_ from typing_extensions import override from configs import dify_config -from dify_graph.enums import BuiltinNodeTypes, NodeType -from dify_graph.graph_engine.layers.base import GraphEngineLayer -from dify_graph.graph_events import GraphNodeEventBase -from dify_graph.nodes.base.node import Node from extensions.otel.parser import ( DefaultNodeOTelParser, LLMNodeOTelParser, @@ -28,6 +24,10 @@ from extensions.otel.parser import ( ToolNodeOTelParser, ) from extensions.otel.runtime import is_instrument_flag_enabled +from graphon.enums import BuiltinNodeTypes, NodeType +from graphon.graph_engine.layers.base import GraphEngineLayer +from graphon.graph_events import GraphNodeEventBase +from graphon.nodes.base.node import Node logger = logging.getLogger(__name__) diff --git a/api/core/app/workflow/layers/persistence.py b/api/core/app/workflow/layers/persistence.py index 99b64b3ab5..e540733de2 100644 --- a/api/core/app/workflow/layers/persistence.py +++ b/api/core/app/workflow/layers/persistence.py @@ -17,17 +17,19 @@ from typing import Any, Union from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity from core.ops.entities.trace_entity import TraceTaskName from core.ops.ops_trace_manager import TraceQueueManager, TraceTask -from dify_graph.constants import SYSTEM_VARIABLE_NODE_ID -from dify_graph.entities import WorkflowExecution, WorkflowNodeExecution -from dify_graph.enums import ( - SystemVariableKey, +from core.repositories.factory import WorkflowExecutionRepository, WorkflowNodeExecutionRepository +from core.workflow.system_variables import SystemVariableKey +from core.workflow.variable_prefixes import SYSTEM_VARIABLE_NODE_ID +from core.workflow.workflow_run_outputs import project_node_outputs_for_workflow_run +from graphon.entities import WorkflowExecution, WorkflowNodeExecution +from graphon.enums import ( WorkflowExecutionStatus, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus, WorkflowType, ) -from dify_graph.graph_engine.layers.base import GraphEngineLayer -from dify_graph.graph_events import ( +from graphon.graph_engine.layers.base import GraphEngineLayer +from graphon.graph_events import ( GraphEngineEvent, GraphRunAbortedEvent, GraphRunFailedEvent, @@ -42,9 +44,7 @@ from dify_graph.graph_events import ( NodeRunStartedEvent, NodeRunSucceededEvent, ) -from dify_graph.node_events import NodeRunResult -from dify_graph.repositories.workflow_execution_repository import WorkflowExecutionRepository -from dify_graph.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository +from graphon.node_events import NodeRunResult from libs.datetime_utils import naive_utc_now @@ -128,14 +128,14 @@ class WorkflowPersistenceLayer(GraphEngineLayer): self._handle_graph_run_paused(event) return - if isinstance(event, NodeRunStartedEvent): - self._handle_node_started(event) - return - if isinstance(event, NodeRunRetryEvent): self._handle_node_retry(event) return + if isinstance(event, NodeRunStartedEvent): + self._handle_node_started(event) + return + if isinstance(event, NodeRunSucceededEvent): self._handle_node_succeeded(event) return @@ -372,10 +372,15 @@ class WorkflowPersistenceLayer(GraphEngineLayer): domain_execution.error = error if update_outputs: + projected_outputs = project_node_outputs_for_workflow_run( + node_type=domain_execution.node_type, + inputs=node_result.inputs, + outputs=node_result.outputs, + ) domain_execution.update_from_mapping( inputs=node_result.inputs, process_data=node_result.process_data, - outputs=node_result.outputs, + outputs=projected_outputs, metadata=node_result.metadata, ) diff --git a/api/core/base/tts/app_generator_tts_publisher.py b/api/core/base/tts/app_generator_tts_publisher.py index beda515666..9e3c187210 100644 --- a/api/core/base/tts/app_generator_tts_publisher.py +++ b/api/core/base/tts/app_generator_tts_publisher.py @@ -15,8 +15,8 @@ from core.app.entities.queue_entities import ( WorkflowQueueMessage, ) from core.model_manager import ModelInstance, ModelManager -from dify_graph.model_runtime.entities.message_entities import TextPromptMessageContent -from dify_graph.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.entities.message_entities import TextPromptMessageContent +from graphon.model_runtime.entities.model_entities import ModelType class AudioTrunk: @@ -25,12 +25,10 @@ class AudioTrunk: self.status = status -def _invoice_tts(text_content: str, model_instance: ModelInstance, tenant_id: str, voice: str): +def _invoice_tts(text_content: str, model_instance: ModelInstance, voice: str): if not text_content or text_content.isspace(): return - return model_instance.invoke_tts( - content_text=text_content.strip(), user="responding_tts", tenant_id=tenant_id, voice=voice - ) + return model_instance.invoke_tts(content_text=text_content.strip(), voice=voice) def _process_future( @@ -62,7 +60,7 @@ class AppGeneratorTTSPublisher: self._audio_queue: queue.Queue[AudioTrunk] = queue.Queue() self._msg_queue: queue.Queue[WorkflowQueueMessage | MessageQueueMessage | None] = queue.Queue() self.match = re.compile(r"[。.!?]") - self.model_manager = ModelManager() + self.model_manager = ModelManager.for_tenant(tenant_id=self.tenant_id, user_id="responding_tts") self.model_instance = self.model_manager.get_default_model_instance( tenant_id=self.tenant_id, model_type=ModelType.TTS ) @@ -89,7 +87,7 @@ class AppGeneratorTTSPublisher: if message is None: if self.msg_text and len(self.msg_text.strip()) > 0: futures_result = self.executor.submit( - _invoice_tts, self.msg_text, self.model_instance, self.tenant_id, self.voice + _invoice_tts, self.msg_text, self.model_instance, self.voice ) future_queue.put(futures_result) break @@ -117,9 +115,7 @@ class AppGeneratorTTSPublisher: if len(sentence_arr) >= min(self.max_sentence, 7): self.max_sentence += 1 text_content = "".join(sentence_arr) - futures_result = self.executor.submit( - _invoice_tts, text_content, self.model_instance, self.tenant_id, self.voice - ) + futures_result = self.executor.submit(_invoice_tts, text_content, self.model_instance, self.voice) future_queue.put(futures_result) if isinstance(text_tmp, str): self.msg_text = text_tmp diff --git a/api/core/datasource/datasource_file_manager.py b/api/core/datasource/datasource_file_manager.py index 24243add17..fe40d8f0e5 100644 --- a/api/core/datasource/datasource_file_manager.py +++ b/api/core/datasource/datasource_file_manager.py @@ -214,6 +214,6 @@ class DatasourceFileManager: # init tool_file_parser -# from dify_graph.file.datasource_file_parser import datasource_file_manager +# from graphon.file.datasource_file_parser import datasource_file_manager # # datasource_file_manager["manager"] = DatasourceFileManager diff --git a/api/core/datasource/datasource_manager.py b/api/core/datasource/datasource_manager.py index 4fa941ae16..8a9875e4d7 100644 --- a/api/core/datasource/datasource_manager.py +++ b/api/core/datasource/datasource_manager.py @@ -6,6 +6,7 @@ from typing import Any, cast from sqlalchemy import select import contexts +from core.app.file_access import DatabaseFileAccessController from core.datasource.__base.datasource_plugin import DatasourcePlugin from core.datasource.__base.datasource_provider import DatasourcePluginProviderController from core.datasource.entities.datasource_entities import ( @@ -24,18 +25,20 @@ from core.datasource.utils.message_transformer import DatasourceFileMessageTrans from core.datasource.website_crawl.website_crawl_provider import WebsiteCrawlDatasourcePluginProviderController from core.db.session_factory import session_factory from core.plugin.impl.datasource import PluginDatasourceManager +from core.workflow.file_reference import build_file_reference from core.workflow.nodes.datasource.entities import DatasourceParameter, OnlineDriveDownloadFileParam -from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus -from dify_graph.enums import WorkflowNodeExecutionMetadataKey -from dify_graph.file import File -from dify_graph.file.enums import FileTransferMethod, FileType -from dify_graph.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent from factories import file_factory +from graphon.entities.workflow_node_execution import WorkflowNodeExecutionStatus +from graphon.enums import WorkflowNodeExecutionMetadataKey +from graphon.file import File, get_file_type_by_mime_type +from graphon.file.enums import FileTransferMethod, FileType +from graphon.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent from models.model import UploadFile from models.tools import ToolFile from services.datasource_provider_service import DatasourceProviderService logger = logging.getLogger(__name__) +_file_access_controller = DatabaseFileAccessController() class DatasourceManager: @@ -279,11 +282,15 @@ class DatasourceManager: if datasource_file is not None: mapping = { "tool_file_id": datasource_file_id, - "type": file_factory.get_file_type_by_mime_type(mime_type), + "type": get_file_type_by_mime_type(mime_type), "transfer_method": FileTransferMethod.TOOL_FILE, "url": url, } - file_out = file_factory.build_from_mapping(mapping=mapping, tenant_id=tenant_id) + file_out = file_factory.build_from_mapping( + mapping=mapping, + tenant_id=tenant_id, + access_controller=_file_access_controller, + ) elif mtype == DatasourceMessage.MessageType.TEXT: assert isinstance(message.message, DatasourceMessage.TextMessage) yield StreamChunkEvent(selector=[node_id, "text"], chunk=message.message.text, is_final=False) @@ -351,11 +358,10 @@ class DatasourceManager: filename=upload_file.name, extension="." + upload_file.extension, mime_type=upload_file.mime_type, - tenant_id=tenant_id, type=FileType.CUSTOM, transfer_method=FileTransferMethod.LOCAL_FILE, remote_url=upload_file.source_url, - related_id=upload_file.id, + reference=build_file_reference(record_id=str(upload_file.id)), size=upload_file.size, storage_key=upload_file.key, url=upload_file.source_url, diff --git a/api/core/datasource/entities/api_entities.py b/api/core/datasource/entities/api_entities.py index 4c9ff64479..84dd653772 100644 --- a/api/core/datasource/entities/api_entities.py +++ b/api/core/datasource/entities/api_entities.py @@ -4,7 +4,7 @@ from pydantic import BaseModel, Field, field_validator from core.datasource.entities.datasource_entities import DatasourceParameter from core.tools.entities.common_entities import I18nObject -from dify_graph.model_runtime.utils.encoders import jsonable_encoder +from graphon.model_runtime.utils.encoders import jsonable_encoder class DatasourceApiEntity(BaseModel): diff --git a/api/core/datasource/utils/message_transformer.py b/api/core/datasource/utils/message_transformer.py index 2881888e27..089b8b8e59 100644 --- a/api/core/datasource/utils/message_transformer.py +++ b/api/core/datasource/utils/message_transformer.py @@ -4,7 +4,8 @@ from mimetypes import guess_extension, guess_type from core.datasource.entities.datasource_entities import DatasourceMessage from core.tools.tool_file_manager import ToolFileManager -from dify_graph.file import File, FileTransferMethod, FileType +from core.workflow.file_reference import parse_file_reference +from graphon.file import File, FileTransferMethod, FileType from models.tools import ToolFile logger = logging.getLogger(__name__) @@ -103,8 +104,14 @@ class DatasourceFileMessageTransformer: file: File | None = meta.get("file") if isinstance(file, File): if file.transfer_method == FileTransferMethod.TOOL_FILE: - assert file.related_id is not None - url = cls.get_datasource_file_url(datasource_file_id=file.related_id, extension=file.extension) + reference = getattr(file, "reference", None) or getattr(file, "related_id", None) + parsed_reference = parse_file_reference(reference) if isinstance(reference, str) else None + if parsed_reference is None: + raise ValueError("datasource file is missing reference") + url = cls.get_datasource_file_url( + datasource_file_id=parsed_reference.record_id, + extension=file.extension, + ) if file.type == FileType.IMAGE: yield DatasourceMessage( type=DatasourceMessage.MessageType.IMAGE_LINK, diff --git a/api/core/entities/embedding_type.py b/api/core/entities/embedding_type.py index 89b48fd2ef..f49cbf9ffe 100644 --- a/api/core/entities/embedding_type.py +++ b/api/core/entities/embedding_type.py @@ -1,10 +1,5 @@ -from enum import StrEnum, auto +"""Compatibility wrapper for the runtime embedding input enum.""" +from graphon.model_runtime.entities.text_embedding_entities import EmbeddingInputType -class EmbeddingInputType(StrEnum): - """ - Enum for embedding input type. - """ - - DOCUMENT = auto() - QUERY = auto() +__all__ = ["EmbeddingInputType"] diff --git a/api/core/entities/execution_extra_content.py b/api/core/entities/execution_extra_content.py index 1343bd8e82..9d970d5db1 100644 --- a/api/core/entities/execution_extra_content.py +++ b/api/core/entities/execution_extra_content.py @@ -5,7 +5,7 @@ from typing import Any, TypeAlias from pydantic import BaseModel, ConfigDict, Field -from dify_graph.nodes.human_input.entities import FormInput, UserAction +from graphon.nodes.human_input.entities import FormInput, UserAction from models.execution_extra_content import ExecutionContentType diff --git a/api/core/entities/mcp_provider.py b/api/core/entities/mcp_provider.py index d214652e9c..bfa4f56915 100644 --- a/api/core/entities/mcp_provider.py +++ b/api/core/entities/mcp_provider.py @@ -15,7 +15,7 @@ from core.helper.provider_cache import NoOpProviderCredentialCache from core.mcp.types import OAuthClientInformation, OAuthClientMetadata, OAuthTokens from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ToolProviderType -from dify_graph.file import helpers as file_helpers +from graphon.file import helpers as file_helpers if TYPE_CHECKING: from models.tools import MCPToolProvider diff --git a/api/core/entities/model_entities.py b/api/core/entities/model_entities.py index 3427fc54b1..e99a131500 100644 --- a/api/core/entities/model_entities.py +++ b/api/core/entities/model_entities.py @@ -3,9 +3,9 @@ from enum import StrEnum, auto from pydantic import BaseModel, ConfigDict -from dify_graph.model_runtime.entities.common_entities import I18nObject -from dify_graph.model_runtime.entities.model_entities import ModelType, ProviderModel -from dify_graph.model_runtime.entities.provider_entities import ProviderEntity +from graphon.model_runtime.entities.common_entities import I18nObject +from graphon.model_runtime.entities.model_entities import ModelType, ProviderModel +from graphon.model_runtime.entities.provider_entities import ProviderEntity class ModelStatus(StrEnum): diff --git a/api/core/entities/provider_configuration.py b/api/core/entities/provider_configuration.py index a9f2300ba2..d90afd3f7b 100644 --- a/api/core/entities/provider_configuration.py +++ b/api/core/entities/provider_configuration.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import json import logging import re @@ -5,7 +7,7 @@ from collections import defaultdict from collections.abc import Iterator, Sequence from json import JSONDecodeError -from pydantic import BaseModel, ConfigDict, Field, model_validator +from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator from sqlalchemy import func, select from sqlalchemy.orm import Session @@ -19,15 +21,17 @@ from core.entities.provider_entities import ( ) from core.helper import encrypter from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType -from dify_graph.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType -from dify_graph.model_runtime.entities.provider_entities import ( +from core.plugin.impl.model_runtime_factory import create_plugin_model_provider_factory +from graphon.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType +from graphon.model_runtime.entities.provider_entities import ( ConfigurateMethod, CredentialFormSchema, FormType, ProviderEntity, ) -from dify_graph.model_runtime.model_providers.__base.ai_model import AIModel -from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory +from graphon.model_runtime.model_providers.__base.ai_model import AIModel +from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory +from graphon.model_runtime.runtime import ModelRuntime from libs.datetime_utils import naive_utc_now from models.engine import db from models.enums import CredentialSourceType @@ -60,6 +64,10 @@ class ProviderConfiguration(BaseModel): - Load balancing configurations - Model enablement/disablement + Request flows can bind a pre-scoped runtime via ``bind_model_runtime()`` so + nested schema and model lookups reuse the caller scope that was already + resolved by the composition layer. + TODO: lots of logic in a BaseModel entity should be separated, the exceptions should be classified """ @@ -73,6 +81,7 @@ class ProviderConfiguration(BaseModel): # pydantic configs model_config = ConfigDict(protected_namespaces=()) + _bound_model_runtime: ModelRuntime | None = PrivateAttr(default=None) @model_validator(mode="after") def _(self): @@ -92,6 +101,16 @@ class ProviderConfiguration(BaseModel): self.provider.configurate_methods.append(ConfigurateMethod.PREDEFINED_MODEL) return self + def bind_model_runtime(self, model_runtime: ModelRuntime) -> None: + """Attach the already-composed runtime for request-bound call chains.""" + self._bound_model_runtime = model_runtime + + def get_model_provider_factory(self) -> ModelProviderFactory: + """Return a provider factory that preserves any request-bound runtime.""" + if self._bound_model_runtime is not None: + return ModelProviderFactory(model_runtime=self._bound_model_runtime) + return create_plugin_model_provider_factory(tenant_id=self.tenant_id) + def get_current_credentials(self, model_type: ModelType, model: str) -> dict | None: """ Get current credentials. @@ -343,7 +362,7 @@ class ProviderConfiguration(BaseModel): tenant_id=self.tenant_id, token=original_credentials[key] ) - model_provider_factory = ModelProviderFactory(self.tenant_id) + model_provider_factory = self.get_model_provider_factory() validated_credentials = model_provider_factory.provider_credentials_validate( provider=self.provider.provider, credentials=credentials ) @@ -902,7 +921,7 @@ class ProviderConfiguration(BaseModel): tenant_id=self.tenant_id, token=original_credentials[key] ) - model_provider_factory = ModelProviderFactory(self.tenant_id) + model_provider_factory = self.get_model_provider_factory() validated_credentials = model_provider_factory.model_credentials_validate( provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials ) @@ -1388,7 +1407,7 @@ class ProviderConfiguration(BaseModel): :param model_type: model type :return: """ - model_provider_factory = ModelProviderFactory(self.tenant_id) + model_provider_factory = self.get_model_provider_factory() # Get model instance of LLM return model_provider_factory.get_model_type_instance(provider=self.provider.provider, model_type=model_type) @@ -1397,7 +1416,7 @@ class ProviderConfiguration(BaseModel): """ Get model schema """ - model_provider_factory = ModelProviderFactory(self.tenant_id) + model_provider_factory = self.get_model_provider_factory() return model_provider_factory.get_model_schema( provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials ) @@ -1499,7 +1518,7 @@ class ProviderConfiguration(BaseModel): :param model: model name :return: """ - model_provider_factory = ModelProviderFactory(self.tenant_id) + model_provider_factory = self.get_model_provider_factory() provider_schema = model_provider_factory.get_provider_schema(self.provider.provider) model_types: list[ModelType] = [] diff --git a/api/core/entities/provider_entities.py b/api/core/entities/provider_entities.py index a830f227a9..dffc7f2fc1 100644 --- a/api/core/entities/provider_entities.py +++ b/api/core/entities/provider_entities.py @@ -12,7 +12,7 @@ from core.entities.parameter_entities import ( ToolSelectorScope, ) from core.tools.entities.common_entities import I18nObject -from dify_graph.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.entities.model_entities import ModelType class ProviderQuotaType(StrEnum): diff --git a/api/core/helper/code_executor/code_executor.py b/api/core/helper/code_executor/code_executor.py index 4251cfd30b..951e065b2c 100644 --- a/api/core/helper/code_executor/code_executor.py +++ b/api/core/helper/code_executor/code_executor.py @@ -13,7 +13,7 @@ from core.helper.code_executor.jinja2.jinja2_transformer import Jinja2TemplateTr from core.helper.code_executor.python3.python3_transformer import Python3TemplateTransformer from core.helper.code_executor.template_transformer import TemplateTransformer from core.helper.http_client_pooling import get_pooled_http_client -from dify_graph.nodes.code.entities import CodeLanguage +from graphon.nodes.code.entities import CodeLanguage logger = logging.getLogger(__name__) code_execution_endpoint_url = URL(str(dify_config.CODE_EXECUTION_ENDPOINT)) diff --git a/api/core/helper/code_executor/template_transformer.py b/api/core/helper/code_executor/template_transformer.py index c569e066f4..b96a9ce380 100644 --- a/api/core/helper/code_executor/template_transformer.py +++ b/api/core/helper/code_executor/template_transformer.py @@ -5,7 +5,7 @@ from base64 import b64encode from collections.abc import Mapping from typing import Any -from dify_graph.variables.utils import dumps_with_segments +from graphon.variables.utils import dumps_with_segments class TemplateTransformer(ABC): diff --git a/api/core/helper/moderation.py b/api/core/helper/moderation.py index 873f6a4093..dc37a36943 100644 --- a/api/core/helper/moderation.py +++ b/api/core/helper/moderation.py @@ -4,11 +4,11 @@ from typing import cast from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.entities import DEFAULT_PLUGIN_ID -from dify_graph.model_runtime.entities.model_entities import ModelType -from dify_graph.model_runtime.errors.invoke import InvokeBadRequestError -from dify_graph.model_runtime.model_providers.__base.moderation_model import ModerationModel -from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory +from core.plugin.impl.model_runtime_factory import create_plugin_model_provider_factory from extensions.ext_hosting_provider import hosting_configuration +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.errors.invoke import InvokeBadRequestError +from graphon.model_runtime.model_providers.__base.moderation_model import ModerationModel from models.provider import ProviderType logger = logging.getLogger(__name__) @@ -41,7 +41,7 @@ def check_moderation(tenant_id: str, model_config: ModelConfigWithCredentialsEnt text_chunk = secrets.choice(text_chunks) try: - model_provider_factory = ModelProviderFactory(tenant_id) + model_provider_factory = create_plugin_model_provider_factory(tenant_id=tenant_id) # Get model instance of LLM model_type_instance = model_provider_factory.get_model_type_instance( diff --git a/api/core/hosting_configuration.py b/api/core/hosting_configuration.py index 600a444357..eb762c3508 100644 --- a/api/core/hosting_configuration.py +++ b/api/core/hosting_configuration.py @@ -4,7 +4,7 @@ from pydantic import BaseModel from configs import dify_config from core.entities import DEFAULT_PLUGIN_ID from core.entities.provider_entities import ProviderQuotaType, QuotaUnit, RestrictModel -from dify_graph.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.entities.model_entities import ModelType class HostingQuota(BaseModel): diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py index 52776ee626..46bf1d6937 100644 --- a/api/core/indexing_runner.py +++ b/api/core/indexing_runner.py @@ -21,7 +21,7 @@ from core.rag.datasource.keyword.keyword_factory import Keyword from core.rag.docstore.dataset_docstore import DatasetDocumentStore from core.rag.extractor.entity.datasource_type import DatasourceType from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo, WebsiteInfo -from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from core.rag.index_processor.index_processor_base import BaseIndexProcessor from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from core.rag.models.document import ChildDocument, Document @@ -31,10 +31,10 @@ from core.rag.splitter.fixed_text_splitter import ( ) from core.rag.splitter.text_splitter import TextSplitter from core.tools.utils.web_reader_tool import get_image_upload_file_ids -from dify_graph.model_runtime.entities.model_entities import ModelType from extensions.ext_database import db from extensions.ext_redis import redis_client from extensions.ext_storage import storage +from graphon.model_runtime.entities.model_entities import ModelType from libs import helper from libs.datetime_utils import naive_utc_now from models import Account @@ -50,7 +50,10 @@ logger = logging.getLogger(__name__) class IndexingRunner: def __init__(self): self.storage = storage - self.model_manager = ModelManager() + + @staticmethod + def _get_model_manager(tenant_id: str) -> ModelManager: + return ModelManager.for_tenant(tenant_id=tenant_id) def _handle_indexing_error(self, document_id: str, error: Exception) -> None: """Handle indexing errors by updating document status.""" @@ -271,7 +274,7 @@ class IndexingRunner: doc_form: str | None = None, doc_language: str = "English", dataset_id: str | None = None, - indexing_technique: str = "economy", + indexing_technique: str = IndexTechniqueType.ECONOMY, ) -> IndexingEstimate: """ Estimate the indexing for the document. @@ -289,22 +292,22 @@ class IndexingRunner: dataset = db.session.query(Dataset).filter_by(id=dataset_id).first() if not dataset: raise ValueError("Dataset not found.") - if dataset.indexing_technique == "high_quality" or indexing_technique == "high_quality": + if IndexTechniqueType.HIGH_QUALITY in {dataset.indexing_technique, indexing_technique}: if dataset.embedding_model_provider: - embedding_model_instance = self.model_manager.get_model_instance( + embedding_model_instance = self._get_model_manager(tenant_id).get_model_instance( tenant_id=tenant_id, provider=dataset.embedding_model_provider, model_type=ModelType.TEXT_EMBEDDING, model=dataset.embedding_model, ) else: - embedding_model_instance = self.model_manager.get_default_model_instance( + embedding_model_instance = self._get_model_manager(tenant_id).get_default_model_instance( tenant_id=tenant_id, model_type=ModelType.TEXT_EMBEDDING, ) else: - if indexing_technique == "high_quality": - embedding_model_instance = self.model_manager.get_default_model_instance( + if indexing_technique == IndexTechniqueType.HIGH_QUALITY: + embedding_model_instance = self._get_model_manager(tenant_id).get_default_model_instance( tenant_id=tenant_id, model_type=ModelType.TEXT_EMBEDDING, ) @@ -573,8 +576,8 @@ class IndexingRunner: """ embedding_model_instance = None - if dataset.indexing_technique == "high_quality": - embedding_model_instance = self.model_manager.get_model_instance( + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: + embedding_model_instance = self._get_model_manager(dataset.tenant_id).get_model_instance( tenant_id=dataset.tenant_id, provider=dataset.embedding_model_provider, model_type=ModelType.TEXT_EMBEDDING, @@ -587,7 +590,7 @@ class IndexingRunner: create_keyword_thread = None if ( dataset_document.doc_form != IndexStructureType.PARENT_CHILD_INDEX - and dataset.indexing_technique == "economy" + and dataset.indexing_technique == IndexTechniqueType.ECONOMY ): # create keyword index create_keyword_thread = threading.Thread( @@ -597,7 +600,7 @@ class IndexingRunner: create_keyword_thread.start() max_workers = 10 - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: futures = [] @@ -628,7 +631,7 @@ class IndexingRunner: tokens += future.result() if ( dataset_document.doc_form != IndexStructureType.PARENT_CHILD_INDEX - and dataset.indexing_technique == "economy" + and dataset.indexing_technique == IndexTechniqueType.ECONOMY and create_keyword_thread is not None ): create_keyword_thread.join() @@ -654,7 +657,7 @@ class IndexingRunner: raise ValueError("no dataset found") keyword = Keyword(dataset) keyword.create(documents) - if dataset.indexing_technique != "high_quality": + if dataset.indexing_technique != IndexTechniqueType.HIGH_QUALITY: document_ids = [document.metadata["doc_id"] for document in documents] db.session.query(DocumentSegment).where( DocumentSegment.document_id == document_id, @@ -764,16 +767,16 @@ class IndexingRunner: ) -> list[Document]: # get embedding model instance embedding_model_instance = None - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: if dataset.embedding_model_provider: - embedding_model_instance = self.model_manager.get_model_instance( + embedding_model_instance = self._get_model_manager(dataset.tenant_id).get_model_instance( tenant_id=dataset.tenant_id, provider=dataset.embedding_model_provider, model_type=ModelType.TEXT_EMBEDDING, model=dataset.embedding_model, ) else: - embedding_model_instance = self.model_manager.get_default_model_instance( + embedding_model_instance = self._get_model_manager(dataset.tenant_id).get_default_model_instance( tenant_id=dataset.tenant_id, model_type=ModelType.TEXT_EMBEDDING, ) diff --git a/api/core/llm_generator/llm_generator.py b/api/core/llm_generator/llm_generator.py index c8848336d9..3712374305 100644 --- a/api/core/llm_generator/llm_generator.py +++ b/api/core/llm_generator/llm_generator.py @@ -27,13 +27,13 @@ from core.ops.entities.trace_entity import TraceTaskName from core.ops.ops_trace_manager import TraceQueueManager, TraceTask from core.ops.utils import measure_time from core.prompt.utils.prompt_template_parser import PromptTemplateParser -from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey -from dify_graph.model_runtime.entities.llm_entities import LLMResult -from dify_graph.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage, UserPromptMessage -from dify_graph.model_runtime.entities.model_entities import ModelType -from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from extensions.ext_database import db from extensions.ext_storage import storage +from graphon.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey +from graphon.model_runtime.entities.llm_entities import LLMResult +from graphon.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage, UserPromptMessage +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from models import App, Message, WorkflowNodeExecutionModel from models.workflow import Workflow @@ -62,7 +62,7 @@ class LLMGenerator: prompt += query + "\n" - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=tenant_id) model_instance = model_manager.get_default_model_instance( tenant_id=tenant_id, model_type=ModelType.LLM, @@ -120,7 +120,7 @@ class LLMGenerator: prompt = prompt_template.format({"histories": histories, "format_instructions": format_instructions}) try: - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=tenant_id) model_instance = model_manager.get_default_model_instance( tenant_id=tenant_id, model_type=ModelType.LLM, @@ -172,7 +172,7 @@ class LLMGenerator: prompt_messages = [UserPromptMessage(content=prompt_generate)] - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=tenant_id) model_instance = model_manager.get_model_instance( tenant_id=tenant_id, @@ -219,7 +219,7 @@ class LLMGenerator: prompt_messages = [UserPromptMessage(content=prompt_generate_prompt)] # get model instance - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=tenant_id) model_instance = model_manager.get_model_instance( tenant_id=tenant_id, model_type=ModelType.LLM, @@ -306,7 +306,7 @@ class LLMGenerator: remove_template_variables=False, ) - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=tenant_id) model_instance = model_manager.get_model_instance( tenant_id=tenant_id, model_type=ModelType.LLM, @@ -337,7 +337,7 @@ class LLMGenerator: def generate_qa_document(cls, tenant_id: str, query, document_language: str): prompt = GENERATOR_QA_PROMPT.format(language=document_language) - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=tenant_id) model_instance = model_manager.get_default_model_instance( tenant_id=tenant_id, model_type=ModelType.LLM, @@ -362,7 +362,7 @@ class LLMGenerator: @classmethod def generate_structured_output(cls, tenant_id: str, args: RuleStructuredOutputPayload): - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=tenant_id) model_instance = model_manager.get_model_instance( tenant_id=tenant_id, model_type=ModelType.LLM, @@ -536,7 +536,7 @@ class LLMGenerator: injected_instruction = injected_instruction.replace(CURRENT, current or "null") if ERROR_MESSAGE in injected_instruction: injected_instruction = injected_instruction.replace(ERROR_MESSAGE, error_message or "null") - model_instance = ModelManager().get_model_instance( + model_instance = ModelManager.for_tenant(tenant_id=tenant_id).get_model_instance( tenant_id=tenant_id, model_type=ModelType.LLM, provider=model_config.provider, diff --git a/api/core/llm_generator/output_parser/structured_output.py b/api/core/llm_generator/output_parser/structured_output.py index 77ea1713ea..81672ee7aa 100644 --- a/api/core/llm_generator/output_parser/structured_output.py +++ b/api/core/llm_generator/output_parser/structured_output.py @@ -10,22 +10,22 @@ from pydantic import TypeAdapter, ValidationError from core.llm_generator.output_parser.errors import OutputParserError from core.llm_generator.prompts import STRUCTURED_OUTPUT_PROMPT from core.model_manager import ModelInstance -from dify_graph.model_runtime.callbacks.base_callback import Callback -from dify_graph.model_runtime.entities.llm_entities import ( +from graphon.model_runtime.callbacks.base_callback import Callback +from graphon.model_runtime.entities.llm_entities import ( LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMResultChunkWithStructuredOutput, LLMResultWithStructuredOutput, ) -from dify_graph.model_runtime.entities.message_entities import ( +from graphon.model_runtime.entities.message_entities import ( AssistantPromptMessage, PromptMessage, PromptMessageTool, SystemPromptMessage, TextPromptMessageContent, ) -from dify_graph.model_runtime.entities.model_entities import AIModelEntity, ParameterRule +from graphon.model_runtime.entities.model_entities import AIModelEntity, ParameterRule class ResponseFormat(StrEnum): @@ -55,7 +55,6 @@ def invoke_llm_with_structured_output( tools: Sequence[PromptMessageTool] | None = None, stop: list[str] | None = None, stream: Literal[True], - user: str | None = None, callbacks: list[Callback] | None = None, ) -> Generator[LLMResultChunkWithStructuredOutput, None, None]: ... @overload @@ -70,7 +69,6 @@ def invoke_llm_with_structured_output( tools: Sequence[PromptMessageTool] | None = None, stop: list[str] | None = None, stream: Literal[False], - user: str | None = None, callbacks: list[Callback] | None = None, ) -> LLMResultWithStructuredOutput: ... @overload @@ -85,7 +83,6 @@ def invoke_llm_with_structured_output( tools: Sequence[PromptMessageTool] | None = None, stop: list[str] | None = None, stream: bool = True, - user: str | None = None, callbacks: list[Callback] | None = None, ) -> LLMResultWithStructuredOutput | Generator[LLMResultChunkWithStructuredOutput, None, None]: ... def invoke_llm_with_structured_output( @@ -99,7 +96,6 @@ def invoke_llm_with_structured_output( tools: Sequence[PromptMessageTool] | None = None, stop: list[str] | None = None, stream: bool = True, - user: str | None = None, callbacks: list[Callback] | None = None, ) -> LLMResultWithStructuredOutput | Generator[LLMResultChunkWithStructuredOutput, None, None]: """ @@ -113,7 +109,6 @@ def invoke_llm_with_structured_output( :param tools: tools for tool calling :param stop: stop words :param stream: is stream response - :param user: unique user id :param callbacks: callbacks :return: full response or stream response chunk generator result """ @@ -143,7 +138,6 @@ def invoke_llm_with_structured_output( tools=tools, stop=stop, stream=stream, - user=user, callbacks=callbacks, ) diff --git a/api/core/mcp/server/streamable_http.py b/api/core/mcp/server/streamable_http.py index de68eb268b..92d23c6dc9 100644 --- a/api/core/mcp/server/streamable_http.py +++ b/api/core/mcp/server/streamable_http.py @@ -7,7 +7,7 @@ from configs import dify_config from core.app.entities.app_invoke_entities import InvokeFrom from core.app.features.rate_limiting.rate_limit import RateLimitGenerator from core.mcp import types as mcp_types -from dify_graph.variables.input_entities import VariableEntity, VariableEntityType +from graphon.variables.input_entities import VariableEntity, VariableEntityType from models.model import App, AppMCPServer, AppMode, EndUser from services.app_generate_service import AppGenerateService diff --git a/api/core/mcp/utils.py b/api/core/mcp/utils.py index db9cb726d7..7b5a7635f1 100644 --- a/api/core/mcp/utils.py +++ b/api/core/mcp/utils.py @@ -8,7 +8,7 @@ from httpx_sse import connect_sse from configs import dify_config from core.mcp.types import ErrorData, JSONRPCError -from dify_graph.model_runtime.utils.encoders import jsonable_encoder +from graphon.model_runtime.utils.encoders import jsonable_encoder HTTP_REQUEST_NODE_SSL_VERIFY = dify_config.HTTP_REQUEST_NODE_SSL_VERIFY diff --git a/api/core/memory/token_buffer_memory.py b/api/core/memory/token_buffer_memory.py index 1156a98af1..658206128d 100644 --- a/api/core/memory/token_buffer_memory.py +++ b/api/core/memory/token_buffer_memory.py @@ -4,10 +4,13 @@ from sqlalchemy import select from sqlalchemy.orm import sessionmaker from core.app.app_config.features.file_upload.manager import FileUploadConfigManager +from core.app.file_access import DatabaseFileAccessController from core.model_manager import ModelInstance from core.prompt.utils.extract_thread_messages import extract_thread_messages -from dify_graph.file import file_manager -from dify_graph.model_runtime.entities import ( +from extensions.ext_database import db +from factories import file_factory +from graphon.file import file_manager +from graphon.model_runtime.entities import ( AssistantPromptMessage, ImagePromptMessageContent, PromptMessage, @@ -15,14 +18,14 @@ from dify_graph.model_runtime.entities import ( TextPromptMessageContent, UserPromptMessage, ) -from dify_graph.model_runtime.entities.message_entities import PromptMessageContentUnionTypes -from extensions.ext_database import db -from factories import file_factory +from graphon.model_runtime.entities.message_entities import PromptMessageContentUnionTypes from models.model import AppMode, Conversation, Message, MessageFile from models.workflow import Workflow from repositories.api_workflow_run_repository import APIWorkflowRunRepository from repositories.factory import DifyAPIRepositoryFactory +_file_access_controller = DatabaseFileAccessController() + class TokenBufferMemory: def __init__( @@ -85,7 +88,10 @@ class TokenBufferMemory: # Build files directly without filtering by belongs_to file_objs = [ file_factory.build_from_message_file( - message_file=message_file, tenant_id=app_record.tenant_id, config=file_extra_config + message_file=message_file, + tenant_id=app_record.tenant_id, + config=file_extra_config, + access_controller=_file_access_controller, ) for message_file in message_files ] diff --git a/api/core/model_manager.py b/api/core/model_manager.py index 0f710a8fcf..f5ff375f65 100644 --- a/api/core/model_manager.py +++ b/api/core/model_manager.py @@ -7,21 +7,22 @@ from core.entities.embedding_type import EmbeddingInputType from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle from core.entities.provider_entities import ModelLoadBalancingConfiguration from core.errors.error import ProviderTokenNotInitError +from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager from core.provider_manager import ProviderManager -from dify_graph.model_runtime.callbacks.base_callback import Callback -from dify_graph.model_runtime.entities.llm_entities import LLMResult -from dify_graph.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool -from dify_graph.model_runtime.entities.model_entities import ModelFeature, ModelType -from dify_graph.model_runtime.entities.rerank_entities import RerankResult -from dify_graph.model_runtime.entities.text_embedding_entities import EmbeddingResult -from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeConnectionError, InvokeRateLimitError -from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from dify_graph.model_runtime.model_providers.__base.moderation_model import ModerationModel -from dify_graph.model_runtime.model_providers.__base.rerank_model import RerankModel -from dify_graph.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel -from dify_graph.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel -from dify_graph.model_runtime.model_providers.__base.tts_model import TTSModel from extensions.ext_redis import redis_client +from graphon.model_runtime.callbacks.base_callback import Callback +from graphon.model_runtime.entities.llm_entities import LLMResult +from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool +from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelFeature, ModelType +from graphon.model_runtime.entities.rerank_entities import RerankResult +from graphon.model_runtime.entities.text_embedding_entities import EmbeddingResult +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeConnectionError, InvokeRateLimitError +from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from graphon.model_runtime.model_providers.__base.moderation_model import ModerationModel +from graphon.model_runtime.model_providers.__base.rerank_model import RerankModel +from graphon.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel +from graphon.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel +from graphon.model_runtime.model_providers.__base.tts_model import TTSModel from models.provider import ProviderType from services.enterprise.plugin_manager_service import PluginCredentialType @@ -30,7 +31,7 @@ logger = logging.getLogger(__name__) class ModelInstance: """ - Model instance class + Model instance class. """ def __init__(self, provider_model_bundle: ProviderModelBundle, model: str): @@ -49,6 +50,13 @@ class ModelInstance: credentials=self.credentials, ) + def get_model_schema(self) -> AIModelEntity: + """Return the resolved schema for the current model instance.""" + model_schema = self.model_type_instance.get_model_schema(self.model_name, self.credentials) + if model_schema is None: + raise ValueError(f"model schema not found for {self.model_name}") + return model_schema + @staticmethod def _fetch_credentials_from_bundle(provider_model_bundle: ProviderModelBundle, model: str): """ @@ -110,7 +118,6 @@ class ModelInstance: tools: Sequence[PromptMessageTool] | None = None, stop: list[str] | None = None, stream: Literal[True] = True, - user: str | None = None, callbacks: list[Callback] | None = None, ) -> Generator: ... @@ -122,7 +129,6 @@ class ModelInstance: tools: Sequence[PromptMessageTool] | None = None, stop: list[str] | None = None, stream: Literal[False] = False, - user: str | None = None, callbacks: list[Callback] | None = None, ) -> LLMResult: ... @@ -134,7 +140,6 @@ class ModelInstance: tools: Sequence[PromptMessageTool] | None = None, stop: list[str] | None = None, stream: bool = True, - user: str | None = None, callbacks: list[Callback] | None = None, ) -> Union[LLMResult, Generator]: ... @@ -145,7 +150,6 @@ class ModelInstance: tools: Sequence[PromptMessageTool] | None = None, stop: Sequence[str] | None = None, stream: bool = True, - user: str | None = None, callbacks: list[Callback] | None = None, ) -> Union[LLMResult, Generator]: """ @@ -156,7 +160,6 @@ class ModelInstance: :param tools: tools for tool calling :param stop: stop words :param stream: is stream response - :param user: unique user id :param callbacks: callbacks :return: full response or stream response chunk generator result """ @@ -173,7 +176,6 @@ class ModelInstance: tools=tools, stop=stop, stream=stream, - user=user, callbacks=callbacks, ), ) @@ -202,13 +204,12 @@ class ModelInstance: ) def invoke_text_embedding( - self, texts: list[str], user: str | None = None, input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT + self, texts: list[str], input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT ) -> EmbeddingResult: """ Invoke large language model :param texts: texts to embed - :param user: unique user id :param input_type: input type :return: embeddings result """ @@ -221,7 +222,6 @@ class ModelInstance: model=self.model_name, credentials=self.credentials, texts=texts, - user=user, input_type=input_type, ), ) @@ -229,14 +229,12 @@ class ModelInstance: def invoke_multimodal_embedding( self, multimodel_documents: list[dict], - user: str | None = None, input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT, ) -> EmbeddingResult: """ Invoke large language model :param multimodel_documents: multimodel documents to embed - :param user: unique user id :param input_type: input type :return: embeddings result """ @@ -249,7 +247,6 @@ class ModelInstance: model=self.model_name, credentials=self.credentials, multimodel_documents=multimodel_documents, - user=user, input_type=input_type, ), ) @@ -279,7 +276,6 @@ class ModelInstance: docs: list[str], score_threshold: float | None = None, top_n: int | None = None, - user: str | None = None, ) -> RerankResult: """ Invoke rerank model @@ -288,7 +284,6 @@ class ModelInstance: :param docs: docs for reranking :param score_threshold: score threshold :param top_n: top n - :param user: unique user id :return: rerank result """ if not isinstance(self.model_type_instance, RerankModel): @@ -303,7 +298,6 @@ class ModelInstance: docs=docs, score_threshold=score_threshold, top_n=top_n, - user=user, ), ) @@ -313,7 +307,6 @@ class ModelInstance: docs: list[dict], score_threshold: float | None = None, top_n: int | None = None, - user: str | None = None, ) -> RerankResult: """ Invoke rerank model @@ -322,7 +315,6 @@ class ModelInstance: :param docs: docs for reranking :param score_threshold: score threshold :param top_n: top n - :param user: unique user id :return: rerank result """ if not isinstance(self.model_type_instance, RerankModel): @@ -337,16 +329,14 @@ class ModelInstance: docs=docs, score_threshold=score_threshold, top_n=top_n, - user=user, ), ) - def invoke_moderation(self, text: str, user: str | None = None) -> bool: + def invoke_moderation(self, text: str) -> bool: """ Invoke moderation model :param text: text to moderate - :param user: unique user id :return: false if text is safe, true otherwise """ if not isinstance(self.model_type_instance, ModerationModel): @@ -358,16 +348,14 @@ class ModelInstance: model=self.model_name, credentials=self.credentials, text=text, - user=user, ), ) - def invoke_speech2text(self, file: IO[bytes], user: str | None = None) -> str: + def invoke_speech2text(self, file: IO[bytes]) -> str: """ Invoke large language model :param file: audio file - :param user: unique user id :return: text for given audio file """ if not isinstance(self.model_type_instance, Speech2TextModel): @@ -379,18 +367,15 @@ class ModelInstance: model=self.model_name, credentials=self.credentials, file=file, - user=user, ), ) - def invoke_tts(self, content_text: str, tenant_id: str, voice: str, user: str | None = None) -> Iterable[bytes]: + def invoke_tts(self, content_text: str, voice: str = "") -> Iterable[bytes]: """ Invoke large language tts model :param content_text: text content to be translated - :param tenant_id: user tenant id :param voice: model timbre - :param user: unique user id :return: text for given audio file """ if not isinstance(self.model_type_instance, TTSModel): @@ -402,8 +387,6 @@ class ModelInstance: model=self.model_name, credentials=self.credentials, content_text=content_text, - user=user, - tenant_id=tenant_id, voice=voice, ), ) @@ -477,10 +460,20 @@ class ModelInstance: class ModelManager: - def __init__(self): - self._provider_manager = ProviderManager() + def __init__(self, provider_manager: ProviderManager): + self._provider_manager = provider_manager - def get_model_instance(self, tenant_id: str, provider: str, model_type: ModelType, model: str) -> ModelInstance: + @classmethod + def for_tenant(cls, tenant_id: str, user_id: str | None = None) -> "ModelManager": + return cls(provider_manager=create_plugin_provider_manager(tenant_id=tenant_id, user_id=user_id)) + + def get_model_instance( + self, + tenant_id: str, + provider: str, + model_type: ModelType, + model: str, + ) -> ModelInstance: """ Get model instance :param tenant_id: tenant id @@ -496,7 +489,8 @@ class ModelManager: tenant_id=tenant_id, provider=provider, model_type=model_type ) - return ModelInstance(provider_model_bundle, model) + model_instance = ModelInstance(provider_model_bundle, model) + return model_instance def get_default_provider_model_name(self, tenant_id: str, model_type: ModelType) -> tuple[str | None, str | None]: """ diff --git a/api/core/moderation/openai_moderation/openai_moderation.py b/api/core/moderation/openai_moderation/openai_moderation.py index 06676f5cf4..35d4469bc1 100644 --- a/api/core/moderation/openai_moderation/openai_moderation.py +++ b/api/core/moderation/openai_moderation/openai_moderation.py @@ -1,6 +1,6 @@ from core.model_manager import ModelManager from core.moderation.base import Moderation, ModerationAction, ModerationInputsResult, ModerationOutputsResult -from dify_graph.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.entities.model_entities import ModelType class OpenAIModeration(Moderation): @@ -50,7 +50,7 @@ class OpenAIModeration(Moderation): def _is_violated(self, inputs: dict): text = "\n".join(str(inputs.values())) - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=self.tenant_id) model_instance = model_manager.get_model_instance( tenant_id=self.tenant_id, provider="openai", model_type=ModelType.MODERATION, model="omni-moderation-latest" ) diff --git a/api/core/ops/aliyun_trace/aliyun_trace.py b/api/core/ops/aliyun_trace/aliyun_trace.py index 18f35b5b9c..76e81242f4 100644 --- a/api/core/ops/aliyun_trace/aliyun_trace.py +++ b/api/core/ops/aliyun_trace/aliyun_trace.py @@ -57,9 +57,9 @@ from core.ops.entities.trace_entity import ( WorkflowTraceInfo, ) from core.repositories import DifyCoreRepositoryFactory -from dify_graph.entities import WorkflowNodeExecution -from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey from extensions.ext_database import db +from graphon.entities import WorkflowNodeExecution +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey from models import WorkflowNodeExecutionTriggeredFrom logger = logging.getLogger(__name__) @@ -296,7 +296,9 @@ class AliyunDataTrace(BaseTraceInstance): triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, ) - return workflow_node_execution_repository.get_by_workflow_run(workflow_run_id=trace_info.workflow_run_id) + return workflow_node_execution_repository.get_by_workflow_execution( + workflow_execution_id=trace_info.workflow_run_id + ) def build_workflow_node_span( self, node_execution: WorkflowNodeExecution, trace_info: WorkflowTraceInfo, trace_metadata: TraceMetadata diff --git a/api/core/ops/aliyun_trace/utils.py b/api/core/ops/aliyun_trace/utils.py index 45319f24c1..43b204b78c 100644 --- a/api/core/ops/aliyun_trace/utils.py +++ b/api/core/ops/aliyun_trace/utils.py @@ -14,9 +14,9 @@ from core.ops.aliyun_trace.entities.semconv import ( GenAISpanKind, ) from core.rag.models.document import Document -from dify_graph.entities import WorkflowNodeExecution -from dify_graph.enums import WorkflowNodeExecutionStatus from extensions.ext_database import db +from graphon.entities import WorkflowNodeExecution +from graphon.enums import WorkflowNodeExecutionStatus from models import EndUser # Constants diff --git a/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py b/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py index f54461e99a..e354c3909a 100644 --- a/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py +++ b/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py @@ -271,8 +271,8 @@ class ArizePhoenixDataTrace(BaseTraceInstance): ) # Get all executions for this workflow run - workflow_node_executions = workflow_node_execution_repository.get_by_workflow_run( - workflow_run_id=trace_info.workflow_run_id + workflow_node_executions = workflow_node_execution_repository.get_by_workflow_execution( + workflow_execution_id=trace_info.workflow_run_id ) try: diff --git a/api/core/ops/langfuse_trace/langfuse_trace.py b/api/core/ops/langfuse_trace/langfuse_trace.py index 6e62387a1f..4a634e2e57 100644 --- a/api/core/ops/langfuse_trace/langfuse_trace.py +++ b/api/core/ops/langfuse_trace/langfuse_trace.py @@ -28,8 +28,8 @@ from core.ops.langfuse_trace.entities.langfuse_trace_entity import ( ) from core.ops.utils import filter_none_values from core.repositories import DifyCoreRepositoryFactory -from dify_graph.enums import BuiltinNodeTypes from extensions.ext_database import db +from graphon.enums import BuiltinNodeTypes from models import EndUser, WorkflowNodeExecutionTriggeredFrom from models.enums import MessageStatus @@ -130,8 +130,8 @@ class LangFuseDataTrace(BaseTraceInstance): ) # Get all executions for this workflow run - workflow_node_executions = workflow_node_execution_repository.get_by_workflow_run( - workflow_run_id=trace_info.workflow_run_id + workflow_node_executions = workflow_node_execution_repository.get_by_workflow_execution( + workflow_execution_id=trace_info.workflow_run_id ) for node_execution in workflow_node_executions: diff --git a/api/core/ops/langsmith_trace/langsmith_trace.py b/api/core/ops/langsmith_trace/langsmith_trace.py index 32a0c77fe2..9f7d73b4ca 100644 --- a/api/core/ops/langsmith_trace/langsmith_trace.py +++ b/api/core/ops/langsmith_trace/langsmith_trace.py @@ -28,8 +28,8 @@ from core.ops.langsmith_trace.entities.langsmith_trace_entity import ( ) from core.ops.utils import filter_none_values, generate_dotted_order from core.repositories import DifyCoreRepositoryFactory -from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey from extensions.ext_database import db +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey from models import EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom logger = logging.getLogger(__name__) @@ -152,8 +152,8 @@ class LangSmithDataTrace(BaseTraceInstance): ) # Get all executions for this workflow run - workflow_node_executions = workflow_node_execution_repository.get_by_workflow_run( - workflow_run_id=trace_info.workflow_run_id + workflow_node_executions = workflow_node_execution_repository.get_by_workflow_execution( + workflow_execution_id=trace_info.workflow_run_id ) for node_execution in workflow_node_executions: diff --git a/api/core/ops/mlflow_trace/mlflow_trace.py b/api/core/ops/mlflow_trace/mlflow_trace.py index ab4a7650ec..8ec69e3542 100644 --- a/api/core/ops/mlflow_trace/mlflow_trace.py +++ b/api/core/ops/mlflow_trace/mlflow_trace.py @@ -23,8 +23,8 @@ from core.ops.entities.trace_entity import ( TraceTaskName, WorkflowTraceInfo, ) -from dify_graph.enums import BuiltinNodeTypes from extensions.ext_database import db +from graphon.enums import BuiltinNodeTypes from models import EndUser from models.workflow import WorkflowNodeExecutionModel diff --git a/api/core/ops/opik_trace/opik_trace.py b/api/core/ops/opik_trace/opik_trace.py index fb72bc2381..a3ead548bb 100644 --- a/api/core/ops/opik_trace/opik_trace.py +++ b/api/core/ops/opik_trace/opik_trace.py @@ -23,8 +23,8 @@ from core.ops.entities.trace_entity import ( WorkflowTraceInfo, ) from core.repositories import DifyCoreRepositoryFactory -from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey from extensions.ext_database import db +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey from models import EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom logger = logging.getLogger(__name__) @@ -176,8 +176,8 @@ class OpikDataTrace(BaseTraceInstance): ) # Get all executions for this workflow run - workflow_node_executions = workflow_node_execution_repository.get_by_workflow_run( - workflow_run_id=trace_info.workflow_run_id + workflow_node_executions = workflow_node_execution_repository.get_by_workflow_execution( + workflow_execution_id=trace_info.workflow_run_id ) for node_execution in workflow_node_executions: diff --git a/api/core/ops/ops_trace_manager.py b/api/core/ops/ops_trace_manager.py index 9ac753240b..87a7579f3a 100644 --- a/api/core/ops/ops_trace_manager.py +++ b/api/core/ops/ops_trace_manager.py @@ -35,7 +35,7 @@ from models.workflow import WorkflowAppLog from tasks.ops_trace_task import process_trace_tasks if TYPE_CHECKING: - from dify_graph.entities import WorkflowExecution + from graphon.entities import WorkflowExecution logger = logging.getLogger(__name__) diff --git a/api/core/ops/tencent_trace/span_builder.py b/api/core/ops/tencent_trace/span_builder.py index 0a6013e244..4f06458157 100644 --- a/api/core/ops/tencent_trace/span_builder.py +++ b/api/core/ops/tencent_trace/span_builder.py @@ -41,7 +41,7 @@ from core.ops.tencent_trace.entities.semconv import ( from core.ops.tencent_trace.entities.tencent_trace_entity import SpanData from core.ops.tencent_trace.utils import TencentTraceUtils from core.rag.models.document import Document -from dify_graph.entities.workflow_node_execution import ( +from graphon.entities.workflow_node_execution import ( WorkflowNodeExecution, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus, diff --git a/api/core/ops/tencent_trace/tencent_trace.py b/api/core/ops/tencent_trace/tencent_trace.py index 7e56b1effa..1b1b1025bc 100644 --- a/api/core/ops/tencent_trace/tencent_trace.py +++ b/api/core/ops/tencent_trace/tencent_trace.py @@ -24,11 +24,11 @@ from core.ops.tencent_trace.entities.tencent_trace_entity import SpanData from core.ops.tencent_trace.span_builder import TencentSpanBuilder from core.ops.tencent_trace.utils import TencentTraceUtils from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository -from dify_graph.entities.workflow_node_execution import ( +from extensions.ext_database import db +from graphon.entities.workflow_node_execution import ( WorkflowNodeExecution, ) -from dify_graph.nodes import BuiltinNodeTypes -from extensions.ext_database import db +from graphon.nodes import BuiltinNodeTypes from models import Account, App, TenantAccountJoin, WorkflowNodeExecutionTriggeredFrom logger = logging.getLogger(__name__) @@ -256,7 +256,7 @@ class TencentDataTrace(BaseTraceInstance): triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, ) - executions = repository.get_by_workflow_run(workflow_run_id=trace_info.workflow_run_id) + executions = repository.get_by_workflow_execution(workflow_execution_id=trace_info.workflow_run_id) return list(executions) except Exception: diff --git a/api/core/ops/weave_trace/entities/weave_trace_entity.py b/api/core/ops/weave_trace/entities/weave_trace_entity.py index ef1a3be45b..ed6a7dabbb 100644 --- a/api/core/ops/weave_trace/entities/weave_trace_entity.py +++ b/api/core/ops/weave_trace/entities/weave_trace_entity.py @@ -67,7 +67,8 @@ class WeaveTraceModel(WeaveTokenUsage, WeaveMultiModel): if field_name == "inputs": data = { "messages": [ - dict(msg, **{"usage_metadata": usage_metadata, "file_list": file_list}) for msg in v + dict(msg, **{"usage_metadata": usage_metadata, "file_list": file_list}) # type: ignore + for msg in v ] if isinstance(v, list) else v, diff --git a/api/core/ops/weave_trace/weave_trace.py b/api/core/ops/weave_trace/weave_trace.py index 2a657b672c..a55505822a 100644 --- a/api/core/ops/weave_trace/weave_trace.py +++ b/api/core/ops/weave_trace/weave_trace.py @@ -31,8 +31,8 @@ from core.ops.entities.trace_entity import ( ) from core.ops.weave_trace.entities.weave_trace_entity import WeaveTraceModel from core.repositories import DifyCoreRepositoryFactory -from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey from extensions.ext_database import db +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey from models import EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom logger = logging.getLogger(__name__) @@ -161,8 +161,8 @@ class WeaveDataTrace(BaseTraceInstance): ) # Get all executions for this workflow run - workflow_node_executions = workflow_node_execution_repository.get_by_workflow_run( - workflow_run_id=trace_info.workflow_run_id + workflow_node_executions = workflow_node_execution_repository.get_by_workflow_execution( + workflow_execution_id=trace_info.workflow_run_id ) # rearrange workflow_node_executions by starting time diff --git a/api/core/plugin/backwards_invocation/model.py b/api/core/plugin/backwards_invocation/model.py index 11c9191bac..85625fc87d 100644 --- a/api/core/plugin/backwards_invocation/model.py +++ b/api/core/plugin/backwards_invocation/model.py @@ -18,22 +18,39 @@ from core.plugin.entities.request import ( ) from core.tools.entities.tool_entities import ToolProviderType from core.tools.utils.model_invocation_utils import ModelInvocationUtils -from dify_graph.model_runtime.entities.llm_entities import ( +from graphon.model_runtime.entities.llm_entities import ( LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMResultChunkWithStructuredOutput, LLMResultWithStructuredOutput, ) -from dify_graph.model_runtime.entities.message_entities import ( +from graphon.model_runtime.entities.message_entities import ( PromptMessage, SystemPromptMessage, UserPromptMessage, ) +from graphon.model_runtime.entities.model_entities import ModelType from models.account import Tenant class PluginModelBackwardsInvocation(BaseBackwardsInvocation): + @staticmethod + def _get_bound_model_instance( + *, + tenant_id: str, + user_id: str | None, + provider: str, + model_type: ModelType, + model: str, + ): + return ModelManager.for_tenant(tenant_id=tenant_id, user_id=user_id).get_model_instance( + tenant_id=tenant_id, + provider=provider, + model_type=model_type, + model=model, + ) + @classmethod def invoke_llm( cls, user_id: str, tenant: Tenant, payload: RequestInvokeLLM @@ -41,8 +58,9 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation): """ invoke llm """ - model_instance = ModelManager().get_model_instance( + model_instance = cls._get_bound_model_instance( tenant_id=tenant.id, + user_id=user_id, provider=payload.provider, model_type=payload.model_type, model=payload.model, @@ -55,7 +73,6 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation): tools=payload.tools, stop=payload.stop, stream=True if payload.stream is None else payload.stream, - user=user_id, ) if isinstance(response, Generator): @@ -94,8 +111,9 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation): """ invoke llm with structured output """ - model_instance = ModelManager().get_model_instance( + model_instance = cls._get_bound_model_instance( tenant_id=tenant.id, + user_id=user_id, provider=payload.provider, model_type=payload.model_type, model=payload.model, @@ -115,7 +133,6 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation): tools=payload.tools, stop=payload.stop, stream=True if payload.stream is None else payload.stream, - user=user_id, model_parameters=payload.completion_params, ) @@ -156,18 +173,16 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation): """ invoke text embedding """ - model_instance = ModelManager().get_model_instance( + model_instance = cls._get_bound_model_instance( tenant_id=tenant.id, + user_id=user_id, provider=payload.provider, model_type=payload.model_type, model=payload.model, ) # invoke model - response = model_instance.invoke_text_embedding( - texts=payload.texts, - user=user_id, - ) + response = model_instance.invoke_text_embedding(texts=payload.texts) return response @@ -176,8 +191,9 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation): """ invoke rerank """ - model_instance = ModelManager().get_model_instance( + model_instance = cls._get_bound_model_instance( tenant_id=tenant.id, + user_id=user_id, provider=payload.provider, model_type=payload.model_type, model=payload.model, @@ -189,7 +205,6 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation): docs=payload.docs, score_threshold=payload.score_threshold, top_n=payload.top_n, - user=user_id, ) return response @@ -199,20 +214,16 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation): """ invoke tts """ - model_instance = ModelManager().get_model_instance( + model_instance = cls._get_bound_model_instance( tenant_id=tenant.id, + user_id=user_id, provider=payload.provider, model_type=payload.model_type, model=payload.model, ) # invoke model - response = model_instance.invoke_tts( - content_text=payload.content_text, - tenant_id=tenant.id, - voice=payload.voice, - user=user_id, - ) + response = model_instance.invoke_tts(content_text=payload.content_text, voice=payload.voice) def handle() -> Generator[dict, None, None]: for chunk in response: @@ -225,8 +236,9 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation): """ invoke speech2text """ - model_instance = ModelManager().get_model_instance( + model_instance = cls._get_bound_model_instance( tenant_id=tenant.id, + user_id=user_id, provider=payload.provider, model_type=payload.model_type, model=payload.model, @@ -238,10 +250,7 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation): temp.flush() temp.seek(0) - response = model_instance.invoke_speech2text( - file=temp, - user=user_id, - ) + response = model_instance.invoke_speech2text(file=temp) return { "result": response, @@ -252,36 +261,38 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation): """ invoke moderation """ - model_instance = ModelManager().get_model_instance( + model_instance = cls._get_bound_model_instance( tenant_id=tenant.id, + user_id=user_id, provider=payload.provider, model_type=payload.model_type, model=payload.model, ) # invoke model - response = model_instance.invoke_moderation( - text=payload.text, - user=user_id, - ) + response = model_instance.invoke_moderation(text=payload.text) return { "result": response, } @classmethod - def get_system_model_max_tokens(cls, tenant_id: str) -> int: + def get_system_model_max_tokens(cls, tenant_id: str, user_id: str | None = None) -> int: """ get system model max tokens """ - return ModelInvocationUtils.get_max_llm_context_tokens(tenant_id=tenant_id) + return ModelInvocationUtils.get_max_llm_context_tokens(tenant_id=tenant_id, user_id=user_id) @classmethod - def get_prompt_tokens(cls, tenant_id: str, prompt_messages: list[PromptMessage]) -> int: + def get_prompt_tokens(cls, tenant_id: str, prompt_messages: list[PromptMessage], user_id: str | None = None) -> int: """ get prompt tokens """ - return ModelInvocationUtils.calculate_tokens(tenant_id=tenant_id, prompt_messages=prompt_messages) + return ModelInvocationUtils.calculate_tokens( + tenant_id=tenant_id, + prompt_messages=prompt_messages, + user_id=user_id, + ) @classmethod def invoke_system_model( @@ -299,6 +310,7 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation): tool_type=ToolProviderType.PLUGIN, tool_name="plugin", prompt_messages=prompt_messages, + caller_user_id=user_id, ) @classmethod @@ -306,7 +318,7 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation): """ invoke summary """ - max_tokens = cls.get_system_model_max_tokens(tenant_id=tenant.id) + max_tokens = cls.get_system_model_max_tokens(tenant_id=tenant.id, user_id=user_id) content = payload.text SUMMARY_PROMPT = """You are a professional language researcher, you are interested in the language @@ -325,6 +337,7 @@ Here is the extra instruction you need to follow: cls.get_prompt_tokens( tenant_id=tenant.id, prompt_messages=[UserPromptMessage(content=content)], + user_id=user_id, ) < max_tokens * 0.6 ): @@ -337,6 +350,7 @@ Here is the extra instruction you need to follow: SystemPromptMessage(content=SUMMARY_PROMPT.replace("{payload.instruction}", payload.instruction)), UserPromptMessage(content=content), ], + user_id=user_id, ) def summarize(content: str) -> str: @@ -394,6 +408,7 @@ Here is the extra instruction you need to follow: cls.get_prompt_tokens( tenant_id=tenant.id, prompt_messages=[UserPromptMessage(content=result)], + user_id=user_id, ) > max_tokens * 0.7 ): diff --git a/api/core/plugin/backwards_invocation/node.py b/api/core/plugin/backwards_invocation/node.py index d6aef93fc4..248f8ef3e6 100644 --- a/api/core/plugin/backwards_invocation/node.py +++ b/api/core/plugin/backwards_invocation/node.py @@ -1,17 +1,17 @@ from core.plugin.backwards_invocation.base import BaseBackwardsInvocation -from dify_graph.enums import BuiltinNodeTypes -from dify_graph.nodes.parameter_extractor.entities import ( +from graphon.enums import BuiltinNodeTypes +from graphon.nodes.parameter_extractor.entities import ( ModelConfig as ParameterExtractorModelConfig, ) -from dify_graph.nodes.parameter_extractor.entities import ( +from graphon.nodes.parameter_extractor.entities import ( ParameterConfig, ParameterExtractorNodeData, ) -from dify_graph.nodes.question_classifier.entities import ( +from graphon.nodes.question_classifier.entities import ( ClassConfig, QuestionClassifierNodeData, ) -from dify_graph.nodes.question_classifier.entities import ( +from graphon.nodes.question_classifier.entities import ( ModelConfig as QuestionClassifierModelConfig, ) from services.workflow_service import WorkflowService diff --git a/api/core/plugin/backwards_invocation/tool.py b/api/core/plugin/backwards_invocation/tool.py index c2d1574e67..0585494269 100644 --- a/api/core/plugin/backwards_invocation/tool.py +++ b/api/core/plugin/backwards_invocation/tool.py @@ -31,7 +31,13 @@ class PluginToolBackwardsInvocation(BaseBackwardsInvocation): # get tool runtime try: tool_runtime = ToolManager.get_tool_runtime_from_plugin( - tool_type, tenant_id, provider, tool_name, tool_parameters, credential_id + tool_type, + tenant_id, + provider, + tool_name, + tool_parameters, + user_id=user_id, + credential_id=credential_id, ) response = ToolEngine.generic_invoke( tool_runtime, tool_parameters, user_id, DifyWorkflowCallbackHandler(), workflow_call_depth=1 diff --git a/api/core/plugin/entities/marketplace.py b/api/core/plugin/entities/marketplace.py index 81e1e12c5f..1bd239a831 100644 --- a/api/core/plugin/entities/marketplace.py +++ b/api/core/plugin/entities/marketplace.py @@ -4,7 +4,7 @@ from core.plugin.entities.endpoint import EndpointProviderDeclaration from core.plugin.entities.plugin import PluginResourceRequirements from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ToolProviderEntity -from dify_graph.model_runtime.entities.provider_entities import ProviderEntity +from graphon.model_runtime.entities.provider_entities import ProviderEntity class MarketplacePluginDeclaration(BaseModel): diff --git a/api/core/plugin/entities/plugin.py b/api/core/plugin/entities/plugin.py index 7a3780f7de..6aefc41400 100644 --- a/api/core/plugin/entities/plugin.py +++ b/api/core/plugin/entities/plugin.py @@ -13,7 +13,7 @@ from core.plugin.entities.endpoint import EndpointProviderDeclaration from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ToolProviderEntity from core.trigger.entities.entities import TriggerProviderEntity -from dify_graph.model_runtime.entities.provider_entities import ProviderEntity +from graphon.model_runtime.entities.provider_entities import ProviderEntity class PluginInstallationSource(StrEnum): diff --git a/api/core/plugin/entities/plugin_daemon.py b/api/core/plugin/entities/plugin_daemon.py index 416e0f6b4d..864e4b8dd7 100644 --- a/api/core/plugin/entities/plugin_daemon.py +++ b/api/core/plugin/entities/plugin_daemon.py @@ -16,8 +16,8 @@ from core.plugin.entities.plugin import PluginDeclaration, PluginEntity from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ToolProviderEntityWithPlugin from core.trigger.entities.entities import TriggerProviderEntity -from dify_graph.model_runtime.entities.model_entities import AIModelEntity -from dify_graph.model_runtime.entities.provider_entities import ProviderEntity +from graphon.model_runtime.entities.model_entities import AIModelEntity +from graphon.model_runtime.entities.provider_entities import ProviderEntity T = TypeVar("T", bound=(BaseModel | dict | list | bool | str)) diff --git a/api/core/plugin/entities/request.py b/api/core/plugin/entities/request.py index c15e9b0385..704cacae2a 100644 --- a/api/core/plugin/entities/request.py +++ b/api/core/plugin/entities/request.py @@ -8,7 +8,7 @@ from pydantic import BaseModel, ConfigDict, Field, field_validator from core.entities.provider_entities import BasicProviderConfig from core.plugin.utils.http_parser import deserialize_response -from dify_graph.model_runtime.entities.message_entities import ( +from graphon.model_runtime.entities.message_entities import ( AssistantPromptMessage, PromptMessage, PromptMessageRole, @@ -17,17 +17,17 @@ from dify_graph.model_runtime.entities.message_entities import ( ToolPromptMessage, UserPromptMessage, ) -from dify_graph.model_runtime.entities.model_entities import ModelType -from dify_graph.nodes.parameter_extractor.entities import ( +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.nodes.parameter_extractor.entities import ( ModelConfig as ParameterExtractorModelConfig, ) -from dify_graph.nodes.parameter_extractor.entities import ( +from graphon.nodes.parameter_extractor.entities import ( ParameterConfig, ) -from dify_graph.nodes.question_classifier.entities import ( +from graphon.nodes.question_classifier.entities import ( ClassConfig, ) -from dify_graph.nodes.question_classifier.entities import ( +from graphon.nodes.question_classifier.entities import ( ModelConfig as QuestionClassifierModelConfig, ) diff --git a/api/core/plugin/impl/base.py b/api/core/plugin/impl/base.py index 737d204105..f6580d3707 100644 --- a/api/core/plugin/impl/base.py +++ b/api/core/plugin/impl/base.py @@ -27,14 +27,14 @@ from core.trigger.errors import ( TriggerPluginInvokeError, TriggerProviderCredentialValidationError, ) -from dify_graph.model_runtime.errors.invoke import ( +from graphon.model_runtime.errors.invoke import ( InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError, InvokeRateLimitError, InvokeServerUnavailableError, ) -from dify_graph.model_runtime.errors.validate import CredentialsValidateFailedError +from graphon.model_runtime.errors.validate import CredentialsValidateFailedError plugin_daemon_inner_api_baseurl = URL(str(dify_config.PLUGIN_DAEMON_URL)) _plugin_daemon_timeout_config = cast( diff --git a/api/core/plugin/impl/model.py b/api/core/plugin/impl/model.py index 49ee5d79cb..c91fa71374 100644 --- a/api/core/plugin/impl/model.py +++ b/api/core/plugin/impl/model.py @@ -1,6 +1,6 @@ import binascii from collections.abc import Generator, Sequence -from typing import IO +from typing import IO, Any from core.plugin.entities.plugin_daemon import ( PluginBasicBooleanResponse, @@ -13,15 +13,22 @@ from core.plugin.entities.plugin_daemon import ( PluginVoicesResponse, ) from core.plugin.impl.base import BasePluginClient -from dify_graph.model_runtime.entities.llm_entities import LLMResultChunk -from dify_graph.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool -from dify_graph.model_runtime.entities.model_entities import AIModelEntity -from dify_graph.model_runtime.entities.rerank_entities import RerankResult -from dify_graph.model_runtime.entities.text_embedding_entities import EmbeddingResult -from dify_graph.model_runtime.utils.encoders import jsonable_encoder +from graphon.model_runtime.entities.llm_entities import LLMResultChunk +from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool +from graphon.model_runtime.entities.model_entities import AIModelEntity +from graphon.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult +from graphon.model_runtime.entities.text_embedding_entities import EmbeddingResult +from graphon.model_runtime.utils.encoders import jsonable_encoder class PluginModelClient(BasePluginClient): + @staticmethod + def _dispatch_payload(*, user_id: str | None, data: dict[str, Any]) -> dict[str, Any]: + payload: dict[str, Any] = {"data": data} + if user_id is not None: + payload["user_id"] = user_id + return payload + def fetch_model_providers(self, tenant_id: str) -> Sequence[PluginModelProviderEntity]: """ Fetch model providers for the given tenant. @@ -37,7 +44,7 @@ class PluginModelClient(BasePluginClient): def get_model_schema( self, tenant_id: str, - user_id: str, + user_id: str | None, plugin_id: str, provider: str, model_type: str, @@ -51,15 +58,15 @@ class PluginModelClient(BasePluginClient): "POST", f"plugin/{tenant_id}/dispatch/model/schema", PluginModelSchemaEntity, - data={ - "user_id": user_id, - "data": { + data=self._dispatch_payload( + user_id=user_id, + data={ "provider": provider, "model_type": model_type, "model": model, "credentials": credentials, }, - }, + ), headers={ "X-Plugin-ID": plugin_id, "Content-Type": "application/json", @@ -72,7 +79,7 @@ class PluginModelClient(BasePluginClient): return None def validate_provider_credentials( - self, tenant_id: str, user_id: str, plugin_id: str, provider: str, credentials: dict + self, tenant_id: str, user_id: str | None, plugin_id: str, provider: str, credentials: dict ) -> bool: """ validate the credentials of the provider @@ -81,13 +88,13 @@ class PluginModelClient(BasePluginClient): "POST", f"plugin/{tenant_id}/dispatch/model/validate_provider_credentials", PluginBasicBooleanResponse, - data={ - "user_id": user_id, - "data": { + data=self._dispatch_payload( + user_id=user_id, + data={ "provider": provider, "credentials": credentials, }, - }, + ), headers={ "X-Plugin-ID": plugin_id, "Content-Type": "application/json", @@ -105,7 +112,7 @@ class PluginModelClient(BasePluginClient): def validate_model_credentials( self, tenant_id: str, - user_id: str, + user_id: str | None, plugin_id: str, provider: str, model_type: str, @@ -119,15 +126,15 @@ class PluginModelClient(BasePluginClient): "POST", f"plugin/{tenant_id}/dispatch/model/validate_model_credentials", PluginBasicBooleanResponse, - data={ - "user_id": user_id, - "data": { + data=self._dispatch_payload( + user_id=user_id, + data={ "provider": provider, "model_type": model_type, "model": model, "credentials": credentials, }, - }, + ), headers={ "X-Plugin-ID": plugin_id, "Content-Type": "application/json", @@ -145,7 +152,7 @@ class PluginModelClient(BasePluginClient): def invoke_llm( self, tenant_id: str, - user_id: str, + user_id: str | None, plugin_id: str, provider: str, model: str, @@ -164,9 +171,9 @@ class PluginModelClient(BasePluginClient): path=f"plugin/{tenant_id}/dispatch/llm/invoke", type_=LLMResultChunk, data=jsonable_encoder( - { - "user_id": user_id, - "data": { + self._dispatch_payload( + user_id=user_id, + data={ "provider": provider, "model_type": "llm", "model": model, @@ -177,7 +184,7 @@ class PluginModelClient(BasePluginClient): "stop": stop, "stream": stream, }, - } + ) ), headers={ "X-Plugin-ID": plugin_id, @@ -193,7 +200,7 @@ class PluginModelClient(BasePluginClient): def get_llm_num_tokens( self, tenant_id: str, - user_id: str, + user_id: str | None, plugin_id: str, provider: str, model_type: str, @@ -210,9 +217,9 @@ class PluginModelClient(BasePluginClient): path=f"plugin/{tenant_id}/dispatch/llm/num_tokens", type_=PluginLLMNumTokensResponse, data=jsonable_encoder( - { - "user_id": user_id, - "data": { + self._dispatch_payload( + user_id=user_id, + data={ "provider": provider, "model_type": model_type, "model": model, @@ -220,7 +227,7 @@ class PluginModelClient(BasePluginClient): "prompt_messages": prompt_messages, "tools": tools, }, - } + ) ), headers={ "X-Plugin-ID": plugin_id, @@ -236,7 +243,7 @@ class PluginModelClient(BasePluginClient): def invoke_text_embedding( self, tenant_id: str, - user_id: str, + user_id: str | None, plugin_id: str, provider: str, model: str, @@ -252,9 +259,9 @@ class PluginModelClient(BasePluginClient): path=f"plugin/{tenant_id}/dispatch/text_embedding/invoke", type_=EmbeddingResult, data=jsonable_encoder( - { - "user_id": user_id, - "data": { + self._dispatch_payload( + user_id=user_id, + data={ "provider": provider, "model_type": "text-embedding", "model": model, @@ -262,7 +269,7 @@ class PluginModelClient(BasePluginClient): "texts": texts, "input_type": input_type, }, - } + ) ), headers={ "X-Plugin-ID": plugin_id, @@ -278,7 +285,7 @@ class PluginModelClient(BasePluginClient): def invoke_multimodal_embedding( self, tenant_id: str, - user_id: str, + user_id: str | None, plugin_id: str, provider: str, model: str, @@ -294,9 +301,9 @@ class PluginModelClient(BasePluginClient): path=f"plugin/{tenant_id}/dispatch/multimodal_embedding/invoke", type_=EmbeddingResult, data=jsonable_encoder( - { - "user_id": user_id, - "data": { + self._dispatch_payload( + user_id=user_id, + data={ "provider": provider, "model_type": "text-embedding", "model": model, @@ -304,7 +311,7 @@ class PluginModelClient(BasePluginClient): "documents": documents, "input_type": input_type, }, - } + ) ), headers={ "X-Plugin-ID": plugin_id, @@ -320,7 +327,7 @@ class PluginModelClient(BasePluginClient): def get_text_embedding_num_tokens( self, tenant_id: str, - user_id: str, + user_id: str | None, plugin_id: str, provider: str, model: str, @@ -335,16 +342,16 @@ class PluginModelClient(BasePluginClient): path=f"plugin/{tenant_id}/dispatch/text_embedding/num_tokens", type_=PluginTextEmbeddingNumTokensResponse, data=jsonable_encoder( - { - "user_id": user_id, - "data": { + self._dispatch_payload( + user_id=user_id, + data={ "provider": provider, "model_type": "text-embedding", "model": model, "credentials": credentials, "texts": texts, }, - } + ) ), headers={ "X-Plugin-ID": plugin_id, @@ -360,7 +367,7 @@ class PluginModelClient(BasePluginClient): def invoke_rerank( self, tenant_id: str, - user_id: str, + user_id: str | None, plugin_id: str, provider: str, model: str, @@ -378,9 +385,9 @@ class PluginModelClient(BasePluginClient): path=f"plugin/{tenant_id}/dispatch/rerank/invoke", type_=RerankResult, data=jsonable_encoder( - { - "user_id": user_id, - "data": { + self._dispatch_payload( + user_id=user_id, + data={ "provider": provider, "model_type": "rerank", "model": model, @@ -390,7 +397,7 @@ class PluginModelClient(BasePluginClient): "score_threshold": score_threshold, "top_n": top_n, }, - } + ) ), headers={ "X-Plugin-ID": plugin_id, @@ -406,13 +413,13 @@ class PluginModelClient(BasePluginClient): def invoke_multimodal_rerank( self, tenant_id: str, - user_id: str, + user_id: str | None, plugin_id: str, provider: str, model: str, credentials: dict, - query: dict, - docs: list[dict], + query: MultimodalRerankInput, + docs: list[MultimodalRerankInput], score_threshold: float | None = None, top_n: int | None = None, ) -> RerankResult: @@ -424,9 +431,9 @@ class PluginModelClient(BasePluginClient): path=f"plugin/{tenant_id}/dispatch/multimodal_rerank/invoke", type_=RerankResult, data=jsonable_encoder( - { - "user_id": user_id, - "data": { + self._dispatch_payload( + user_id=user_id, + data={ "provider": provider, "model_type": "rerank", "model": model, @@ -436,7 +443,7 @@ class PluginModelClient(BasePluginClient): "score_threshold": score_threshold, "top_n": top_n, }, - } + ) ), headers={ "X-Plugin-ID": plugin_id, @@ -451,7 +458,7 @@ class PluginModelClient(BasePluginClient): def invoke_tts( self, tenant_id: str, - user_id: str, + user_id: str | None, plugin_id: str, provider: str, model: str, @@ -467,9 +474,9 @@ class PluginModelClient(BasePluginClient): path=f"plugin/{tenant_id}/dispatch/tts/invoke", type_=PluginStringResultResponse, data=jsonable_encoder( - { - "user_id": user_id, - "data": { + self._dispatch_payload( + user_id=user_id, + data={ "provider": provider, "model_type": "tts", "model": model, @@ -478,7 +485,7 @@ class PluginModelClient(BasePluginClient): "content_text": content_text, "voice": voice, }, - } + ) ), headers={ "X-Plugin-ID": plugin_id, @@ -496,7 +503,7 @@ class PluginModelClient(BasePluginClient): def get_tts_model_voices( self, tenant_id: str, - user_id: str, + user_id: str | None, plugin_id: str, provider: str, model: str, @@ -511,16 +518,16 @@ class PluginModelClient(BasePluginClient): path=f"plugin/{tenant_id}/dispatch/tts/model/voices", type_=PluginVoicesResponse, data=jsonable_encoder( - { - "user_id": user_id, - "data": { + self._dispatch_payload( + user_id=user_id, + data={ "provider": provider, "model_type": "tts", "model": model, "credentials": credentials, "language": language, }, - } + ) ), headers={ "X-Plugin-ID": plugin_id, @@ -540,7 +547,7 @@ class PluginModelClient(BasePluginClient): def invoke_speech_to_text( self, tenant_id: str, - user_id: str, + user_id: str | None, plugin_id: str, provider: str, model: str, @@ -555,16 +562,16 @@ class PluginModelClient(BasePluginClient): path=f"plugin/{tenant_id}/dispatch/speech2text/invoke", type_=PluginStringResultResponse, data=jsonable_encoder( - { - "user_id": user_id, - "data": { + self._dispatch_payload( + user_id=user_id, + data={ "provider": provider, "model_type": "speech2text", "model": model, "credentials": credentials, "file": binascii.hexlify(file.read()).decode(), }, - } + ) ), headers={ "X-Plugin-ID": plugin_id, @@ -580,7 +587,7 @@ class PluginModelClient(BasePluginClient): def invoke_moderation( self, tenant_id: str, - user_id: str, + user_id: str | None, plugin_id: str, provider: str, model: str, @@ -595,16 +602,16 @@ class PluginModelClient(BasePluginClient): path=f"plugin/{tenant_id}/dispatch/moderation/invoke", type_=PluginBasicBooleanResponse, data=jsonable_encoder( - { - "user_id": user_id, - "data": { + self._dispatch_payload( + user_id=user_id, + data={ "provider": provider, "model_type": "moderation", "model": model, "credentials": credentials, "text": text, }, - } + ) ), headers={ "X-Plugin-ID": plugin_id, diff --git a/api/core/plugin/impl/model_runtime.py b/api/core/plugin/impl/model_runtime.py new file mode 100644 index 0000000000..e3fba4ef3a --- /dev/null +++ b/api/core/plugin/impl/model_runtime.py @@ -0,0 +1,499 @@ +from __future__ import annotations + +import hashlib +import logging +from collections.abc import Generator, Iterable, Sequence +from threading import Lock +from typing import IO, Any, Union + +from pydantic import ValidationError +from redis import RedisError + +from configs import dify_config +from core.plugin.entities.plugin_daemon import PluginModelProviderEntity +from core.plugin.impl.asset import PluginAssetManager +from core.plugin.impl.model import PluginModelClient +from extensions.ext_redis import redis_client +from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk +from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool +from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelType +from graphon.model_runtime.entities.provider_entities import ProviderEntity +from graphon.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult +from graphon.model_runtime.entities.text_embedding_entities import EmbeddingInputType, EmbeddingResult +from graphon.model_runtime.runtime import ModelRuntime +from models.provider_ids import ModelProviderID + +logger = logging.getLogger(__name__) + +# `TS` means tenant scope +TENANT_SCOPE_SCHEMA_CACHE_USER_ID = "__DIFY_TS__" + + +class PluginModelRuntime(ModelRuntime): + """Plugin-backed runtime adapter bound to tenant context and optional caller scope.""" + + tenant_id: str + user_id: str | None + client: PluginModelClient + _provider_entities: tuple[ProviderEntity, ...] | None + _provider_entities_lock: Lock + + def __init__(self, tenant_id: str, user_id: str | None, client: PluginModelClient) -> None: + if client is None: + raise ValueError("client is required.") + self.tenant_id = tenant_id + self.user_id = user_id + self.client = client + self._provider_entities = None + self._provider_entities_lock = Lock() + + def fetch_model_providers(self) -> Sequence[ProviderEntity]: + if self._provider_entities is not None: + return self._provider_entities + + with self._provider_entities_lock: + if self._provider_entities is None: + self._provider_entities = tuple( + self._to_provider_entity(provider) for provider in self.client.fetch_model_providers(self.tenant_id) + ) + + return self._provider_entities + + def get_provider_icon(self, *, provider: str, icon_type: str, lang: str) -> tuple[bytes, str]: + provider_schema = self._get_provider_schema(provider) + + if icon_type.lower() == "icon_small": + if not provider_schema.icon_small: + raise ValueError(f"Provider {provider} does not have small icon.") + file_name = ( + provider_schema.icon_small.zh_Hans if lang.lower() == "zh_hans" else provider_schema.icon_small.en_US + ) + elif icon_type.lower() == "icon_small_dark": + if not provider_schema.icon_small_dark: + raise ValueError(f"Provider {provider} does not have small dark icon.") + file_name = ( + provider_schema.icon_small_dark.zh_Hans + if lang.lower() == "zh_hans" + else provider_schema.icon_small_dark.en_US + ) + else: + raise ValueError(f"Unsupported icon type: {icon_type}.") + + if not file_name: + raise ValueError(f"Provider {provider} does not have icon.") + + image_mime_types = { + "jpg": "image/jpeg", + "jpeg": "image/jpeg", + "png": "image/png", + "gif": "image/gif", + "bmp": "image/bmp", + "tiff": "image/tiff", + "tif": "image/tiff", + "webp": "image/webp", + "svg": "image/svg+xml", + "ico": "image/vnd.microsoft.icon", + "heif": "image/heif", + "heic": "image/heic", + } + + extension = file_name.split(".")[-1] + mime_type = image_mime_types.get(extension, "image/png") + return PluginAssetManager().fetch_asset(tenant_id=self.tenant_id, id=file_name), mime_type + + def validate_provider_credentials(self, *, provider: str, credentials: dict[str, Any]) -> None: + plugin_id, provider_name = self._split_provider(provider) + self.client.validate_provider_credentials( + tenant_id=self.tenant_id, + user_id=self.user_id, + plugin_id=plugin_id, + provider=provider_name, + credentials=credentials, + ) + + def validate_model_credentials( + self, + *, + provider: str, + model_type: ModelType, + model: str, + credentials: dict[str, Any], + ) -> None: + plugin_id, provider_name = self._split_provider(provider) + self.client.validate_model_credentials( + tenant_id=self.tenant_id, + user_id=self.user_id, + plugin_id=plugin_id, + provider=provider_name, + model_type=model_type.value, + model=model, + credentials=credentials, + ) + + def get_model_schema( + self, + *, + provider: str, + model_type: ModelType, + model: str, + credentials: dict[str, Any], + ) -> AIModelEntity | None: + cache_key = self._get_schema_cache_key( + provider=provider, + model_type=model_type, + model=model, + credentials=credentials, + ) + + cached_schema_json = None + try: + cached_schema_json = redis_client.get(cache_key) + except (RedisError, RuntimeError) as exc: + logger.warning( + "Failed to read plugin model schema cache for model %s: %s", + model, + str(exc), + exc_info=True, + ) + + if cached_schema_json: + try: + return AIModelEntity.model_validate_json(cached_schema_json) + except ValidationError: + logger.warning("Failed to validate cached plugin model schema for model %s", model, exc_info=True) + try: + redis_client.delete(cache_key) + except (RedisError, RuntimeError) as exc: + logger.warning( + "Failed to delete invalid plugin model schema cache for model %s: %s", + model, + str(exc), + exc_info=True, + ) + + plugin_id, provider_name = self._split_provider(provider) + schema = self.client.get_model_schema( + tenant_id=self.tenant_id, + user_id=self.user_id, + plugin_id=plugin_id, + provider=provider_name, + model_type=model_type.value, + model=model, + credentials=credentials, + ) + + if schema: + try: + redis_client.setex(cache_key, dify_config.PLUGIN_MODEL_SCHEMA_CACHE_TTL, schema.model_dump_json()) + except (RedisError, RuntimeError) as exc: + logger.warning( + "Failed to write plugin model schema cache for model %s: %s", + model, + str(exc), + exc_info=True, + ) + + return schema + + def invoke_llm( + self, + *, + provider: str, + model: str, + credentials: dict[str, Any], + model_parameters: dict[str, Any], + prompt_messages: Sequence[PromptMessage], + tools: list[PromptMessageTool] | None, + stop: Sequence[str] | None, + stream: bool, + ) -> Union[LLMResult, Generator[LLMResultChunk, None, None]]: + plugin_id, provider_name = self._split_provider(provider) + return self.client.invoke_llm( + tenant_id=self.tenant_id, + user_id=self.user_id, + plugin_id=plugin_id, + provider=provider_name, + model=model, + credentials=credentials, + model_parameters=model_parameters, + prompt_messages=list(prompt_messages), + tools=tools, + stop=list(stop) if stop else None, + stream=stream, + ) + + def get_llm_num_tokens( + self, + *, + provider: str, + model_type: ModelType, + model: str, + credentials: dict[str, Any], + prompt_messages: Sequence[PromptMessage], + tools: Sequence[PromptMessageTool] | None, + ) -> int: + if not dify_config.PLUGIN_BASED_TOKEN_COUNTING_ENABLED: + return 0 + + plugin_id, provider_name = self._split_provider(provider) + return self.client.get_llm_num_tokens( + tenant_id=self.tenant_id, + user_id=self.user_id, + plugin_id=plugin_id, + provider=provider_name, + model_type=model_type.value, + model=model, + credentials=credentials, + prompt_messages=list(prompt_messages), + tools=list(tools) if tools else None, + ) + + def invoke_text_embedding( + self, + *, + provider: str, + model: str, + credentials: dict[str, Any], + texts: list[str], + input_type: EmbeddingInputType, + ) -> EmbeddingResult: + plugin_id, provider_name = self._split_provider(provider) + return self.client.invoke_text_embedding( + tenant_id=self.tenant_id, + user_id=self.user_id, + plugin_id=plugin_id, + provider=provider_name, + model=model, + credentials=credentials, + texts=texts, + input_type=input_type, + ) + + def invoke_multimodal_embedding( + self, + *, + provider: str, + model: str, + credentials: dict[str, Any], + documents: list[dict[str, Any]], + input_type: EmbeddingInputType, + ) -> EmbeddingResult: + plugin_id, provider_name = self._split_provider(provider) + return self.client.invoke_multimodal_embedding( + tenant_id=self.tenant_id, + user_id=self.user_id, + plugin_id=plugin_id, + provider=provider_name, + model=model, + credentials=credentials, + documents=documents, + input_type=input_type, + ) + + def get_text_embedding_num_tokens( + self, + *, + provider: str, + model: str, + credentials: dict[str, Any], + texts: list[str], + ) -> list[int]: + plugin_id, provider_name = self._split_provider(provider) + return self.client.get_text_embedding_num_tokens( + tenant_id=self.tenant_id, + user_id=self.user_id, + plugin_id=plugin_id, + provider=provider_name, + model=model, + credentials=credentials, + texts=texts, + ) + + def invoke_rerank( + self, + *, + provider: str, + model: str, + credentials: dict[str, Any], + query: str, + docs: list[str], + score_threshold: float | None, + top_n: int | None, + ) -> RerankResult: + plugin_id, provider_name = self._split_provider(provider) + return self.client.invoke_rerank( + tenant_id=self.tenant_id, + user_id=self.user_id, + plugin_id=plugin_id, + provider=provider_name, + model=model, + credentials=credentials, + query=query, + docs=docs, + score_threshold=score_threshold, + top_n=top_n, + ) + + def invoke_multimodal_rerank( + self, + *, + provider: str, + model: str, + credentials: dict[str, Any], + query: MultimodalRerankInput, + docs: list[MultimodalRerankInput], + score_threshold: float | None, + top_n: int | None, + ) -> RerankResult: + plugin_id, provider_name = self._split_provider(provider) + return self.client.invoke_multimodal_rerank( + tenant_id=self.tenant_id, + user_id=self.user_id, + plugin_id=plugin_id, + provider=provider_name, + model=model, + credentials=credentials, + query=query, + docs=docs, + score_threshold=score_threshold, + top_n=top_n, + ) + + def invoke_tts( + self, + *, + provider: str, + model: str, + credentials: dict[str, Any], + content_text: str, + voice: str, + ) -> Iterable[bytes]: + plugin_id, provider_name = self._split_provider(provider) + return self.client.invoke_tts( + tenant_id=self.tenant_id, + user_id=self.user_id, + plugin_id=plugin_id, + provider=provider_name, + model=model, + credentials=credentials, + content_text=content_text, + voice=voice, + ) + + def get_tts_model_voices( + self, + *, + provider: str, + model: str, + credentials: dict[str, Any], + language: str | None, + ) -> Any: + plugin_id, provider_name = self._split_provider(provider) + return self.client.get_tts_model_voices( + tenant_id=self.tenant_id, + user_id=self.user_id, + plugin_id=plugin_id, + provider=provider_name, + model=model, + credentials=credentials, + language=language, + ) + + def invoke_speech_to_text( + self, + *, + provider: str, + model: str, + credentials: dict[str, Any], + file: IO[bytes], + ) -> str: + plugin_id, provider_name = self._split_provider(provider) + return self.client.invoke_speech_to_text( + tenant_id=self.tenant_id, + user_id=self.user_id, + plugin_id=plugin_id, + provider=provider_name, + model=model, + credentials=credentials, + file=file, + ) + + def invoke_moderation( + self, + *, + provider: str, + model: str, + credentials: dict[str, Any], + text: str, + ) -> bool: + plugin_id, provider_name = self._split_provider(provider) + return self.client.invoke_moderation( + tenant_id=self.tenant_id, + user_id=self.user_id, + plugin_id=plugin_id, + provider=provider_name, + model=model, + credentials=credentials, + text=text, + ) + + def _get_provider_short_name_alias(self, provider: PluginModelProviderEntity) -> str: + """ + Expose a bare provider alias only for the canonical provider mapping. + + Multiple plugins can publish the same short provider slug. If every + provider entity keeps that slug in ``provider_name``, callers that still + resolve by short name become order-dependent. Restrict the alias to the + provider selected by ``ModelProviderID`` so legacy short-name lookups + remain deterministic while the runtime surface stays canonical. + """ + try: + canonical_provider_id = ModelProviderID(provider.provider) + except ValueError: + return "" + + if canonical_provider_id.plugin_id != provider.plugin_id: + return "" + if canonical_provider_id.provider_name != provider.provider: + return "" + + return provider.provider + + def _to_provider_entity(self, provider: PluginModelProviderEntity) -> ProviderEntity: + declaration = provider.declaration.model_copy(deep=True) + declaration.provider = f"{provider.plugin_id}/{provider.provider}" + declaration.provider_name = self._get_provider_short_name_alias(provider) + return declaration + + def _get_provider_schema(self, provider: str) -> ProviderEntity: + providers = self.fetch_model_providers() + provider_entity = next((item for item in providers if item.provider == provider), None) + if provider_entity is None: + provider_entity = next((item for item in providers if provider == item.provider_name), None) + if provider_entity is None: + raise ValueError(f"Invalid provider: {provider}") + return provider_entity + + def _get_schema_cache_key( + self, + *, + provider: str, + model_type: ModelType, + model: str, + credentials: dict[str, Any], + ) -> str: + # The plugin daemon distinguishes ``None`` from an explicit empty-string + # caller id, so the cache must only collapse ``None`` into tenant scope. + cache_user_id = TENANT_SCOPE_SCHEMA_CACHE_USER_ID if self.user_id is None else self.user_id + cache_key = f"{self.tenant_id}:{provider}:{model_type.value}:{model}:{cache_user_id}" + sorted_credentials = sorted(credentials.items()) if credentials else [] + if not sorted_credentials: + return cache_key + hashed_credentials = ":".join( + [hashlib.md5(f"{key}:{value}".encode()).hexdigest() for key, value in sorted_credentials] + ) + return f"{cache_key}:{hashed_credentials}" + + def _split_provider(self, provider: str) -> tuple[str, str]: + provider_id = ModelProviderID(provider) + return provider_id.plugin_id, provider_id.provider_name diff --git a/api/core/plugin/impl/model_runtime_factory.py b/api/core/plugin/impl/model_runtime_factory.py new file mode 100644 index 0000000000..35abd2ae8c --- /dev/null +++ b/api/core/plugin/impl/model_runtime_factory.py @@ -0,0 +1,89 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from core.plugin.impl.model import PluginModelClient +from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory + +if TYPE_CHECKING: + from core.model_manager import ModelManager + from core.plugin.impl.model_runtime import PluginModelRuntime + from core.provider_manager import ProviderManager + + +class PluginModelAssembly: + """Compose request-scoped model views on top of a single plugin runtime.""" + + tenant_id: str + user_id: str | None + _model_runtime: PluginModelRuntime | None + _model_provider_factory: ModelProviderFactory | None + _provider_manager: ProviderManager | None + _model_manager: ModelManager | None + + def __init__(self, *, tenant_id: str, user_id: str | None = None) -> None: + self.tenant_id = tenant_id + self.user_id = user_id + self._model_runtime = None + self._model_provider_factory = None + self._provider_manager = None + self._model_manager = None + + @property + def model_runtime(self) -> PluginModelRuntime: + if self._model_runtime is None: + self._model_runtime = create_plugin_model_runtime(tenant_id=self.tenant_id, user_id=self.user_id) + return self._model_runtime + + @property + def model_provider_factory(self) -> ModelProviderFactory: + if self._model_provider_factory is None: + self._model_provider_factory = ModelProviderFactory(model_runtime=self.model_runtime) + return self._model_provider_factory + + @property + def provider_manager(self) -> ProviderManager: + if self._provider_manager is None: + from core.provider_manager import ProviderManager + + self._provider_manager = ProviderManager(model_runtime=self.model_runtime) + return self._provider_manager + + @property + def model_manager(self) -> ModelManager: + if self._model_manager is None: + from core.model_manager import ModelManager + + self._model_manager = ModelManager(provider_manager=self.provider_manager) + return self._model_manager + + +def create_plugin_model_assembly(*, tenant_id: str, user_id: str | None = None) -> PluginModelAssembly: + """Create a request-scoped assembly that shares one plugin runtime across model views.""" + return PluginModelAssembly(tenant_id=tenant_id, user_id=user_id) + + +def create_plugin_model_runtime(*, tenant_id: str, user_id: str | None = None) -> PluginModelRuntime: + """Create a plugin runtime with its client dependency fully composed.""" + from core.plugin.impl.model_runtime import PluginModelRuntime + + return PluginModelRuntime( + tenant_id=tenant_id, + user_id=user_id, + client=PluginModelClient(), + ) + + +def create_plugin_model_provider_factory(*, tenant_id: str, user_id: str | None = None) -> ModelProviderFactory: + """Create a tenant-bound model provider factory for service flows.""" + return create_plugin_model_assembly(tenant_id=tenant_id, user_id=user_id).model_provider_factory + + +def create_plugin_provider_manager(*, tenant_id: str, user_id: str | None = None) -> ProviderManager: + """Create a tenant-bound provider manager for service flows.""" + return create_plugin_model_assembly(tenant_id=tenant_id, user_id=user_id).provider_manager + + +def create_plugin_model_manager(*, tenant_id: str, user_id: str | None = None) -> ModelManager: + """Create a tenant-bound model manager for service flows.""" + return create_plugin_model_assembly(tenant_id=tenant_id, user_id=user_id).model_manager diff --git a/api/core/plugin/impl/plugin.py b/api/core/plugin/impl/plugin.py index 0bbb62af93..ec4858ae2e 100644 --- a/api/core/plugin/impl/plugin.py +++ b/api/core/plugin/impl/plugin.py @@ -209,8 +209,7 @@ class PluginInstaller(BasePluginClient): "GET", f"plugin/{tenant_id}/management/decode/from_identifier", PluginDecodeResponse, - data={"plugin_unique_identifier": plugin_unique_identifier}, - headers={"Content-Type": "application/json"}, + params={"plugin_unique_identifier": plugin_unique_identifier}, ) def fetch_plugin_installation_by_ids( diff --git a/api/core/plugin/utils/converter.py b/api/core/plugin/utils/converter.py index 53bcd9e9c6..322f78ab4e 100644 --- a/api/core/plugin/utils/converter.py +++ b/api/core/plugin/utils/converter.py @@ -1,7 +1,7 @@ from typing import Any from core.tools.entities.tool_entities import ToolSelector -from dify_graph.file.models import File +from graphon.file.models import File def convert_parameters_to_plugin_format(parameters: dict[str, Any]) -> dict[str, Any]: diff --git a/api/core/prompt/advanced_prompt_transform.py b/api/core/prompt/advanced_prompt_transform.py index ce9f7e64b2..de87a09652 100644 --- a/api/core/prompt/advanced_prompt_transform.py +++ b/api/core/prompt/advanced_prompt_transform.py @@ -8,9 +8,9 @@ from core.model_manager import ModelInstance from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig from core.prompt.prompt_transform import PromptTransform from core.prompt.utils.prompt_template_parser import PromptTemplateParser -from dify_graph.file import file_manager -from dify_graph.file.models import File -from dify_graph.model_runtime.entities import ( +from graphon.file import file_manager +from graphon.file.models import File +from graphon.model_runtime.entities import ( AssistantPromptMessage, PromptMessage, PromptMessageRole, @@ -18,8 +18,8 @@ from dify_graph.model_runtime.entities import ( TextPromptMessageContent, UserPromptMessage, ) -from dify_graph.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes -from dify_graph.runtime import VariablePool +from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes +from graphon.runtime import VariablePool class AdvancedPromptTransform(PromptTransform): diff --git a/api/core/prompt/agent_history_prompt_transform.py b/api/core/prompt/agent_history_prompt_transform.py index d09a46bfde..8f1d51f08a 100644 --- a/api/core/prompt/agent_history_prompt_transform.py +++ b/api/core/prompt/agent_history_prompt_transform.py @@ -5,12 +5,12 @@ from core.app.entities.app_invoke_entities import ( ) from core.memory.token_buffer_memory import TokenBufferMemory from core.prompt.prompt_transform import PromptTransform -from dify_graph.model_runtime.entities.message_entities import ( +from graphon.model_runtime.entities.message_entities import ( PromptMessage, SystemPromptMessage, UserPromptMessage, ) -from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel class AgentHistoryPromptTransform(PromptTransform): diff --git a/api/core/prompt/entities/advanced_prompt_entities.py b/api/core/prompt/entities/advanced_prompt_entities.py index 667f5ef099..b98fd8c179 100644 --- a/api/core/prompt/entities/advanced_prompt_entities.py +++ b/api/core/prompt/entities/advanced_prompt_entities.py @@ -1,50 +1,7 @@ -from typing import Literal +from graphon.prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig -from pydantic import BaseModel - -from dify_graph.model_runtime.entities.message_entities import PromptMessageRole - - -class ChatModelMessage(BaseModel): - """ - Chat Message. - """ - - text: str - role: PromptMessageRole - edition_type: Literal["basic", "jinja2"] | None = None - - -class CompletionModelPromptTemplate(BaseModel): - """ - Completion Model Prompt Template. - """ - - text: str - edition_type: Literal["basic", "jinja2"] | None = None - - -class MemoryConfig(BaseModel): - """ - Memory Config. - """ - - class RolePrefix(BaseModel): - """ - Role Prefix. - """ - - user: str - assistant: str - - class WindowConfig(BaseModel): - """ - Window Config. - """ - - enabled: bool - size: int | None = None - - role_prefix: RolePrefix | None = None - window: WindowConfig - query_prompt_template: str | None = None +__all__ = [ + "ChatModelMessage", + "CompletionModelPromptTemplate", + "MemoryConfig", +] diff --git a/api/core/prompt/prompt_transform.py b/api/core/prompt/prompt_transform.py index 951736831f..6ff2f44cdc 100644 --- a/api/core/prompt/prompt_transform.py +++ b/api/core/prompt/prompt_transform.py @@ -4,8 +4,8 @@ from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEnti from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance from core.prompt.entities.advanced_prompt_entities import MemoryConfig -from dify_graph.model_runtime.entities.message_entities import PromptMessage -from dify_graph.model_runtime.entities.model_entities import AIModelEntity, ModelPropertyKey +from graphon.model_runtime.entities.message_entities import PromptMessage +from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelPropertyKey class PromptTransform: diff --git a/api/core/prompt/simple_prompt_transform.py b/api/core/prompt/simple_prompt_transform.py index 10c44349ae..e091215b80 100644 --- a/api/core/prompt/simple_prompt_transform.py +++ b/api/core/prompt/simple_prompt_transform.py @@ -10,8 +10,8 @@ from core.memory.token_buffer_memory import TokenBufferMemory from core.prompt.entities.advanced_prompt_entities import MemoryConfig from core.prompt.prompt_transform import PromptTransform from core.prompt.utils.prompt_template_parser import PromptTemplateParser -from dify_graph.file import file_manager -from dify_graph.model_runtime.entities.message_entities import ( +from graphon.file import file_manager +from graphon.model_runtime.entities.message_entities import ( ImagePromptMessageContent, PromptMessage, PromptMessageContentUnionTypes, @@ -22,7 +22,7 @@ from dify_graph.model_runtime.entities.message_entities import ( from models.model import AppMode if TYPE_CHECKING: - from dify_graph.file.models import File + from graphon.file.models import File class ModelMode(StrEnum): diff --git a/api/core/prompt/utils/prompt_message_util.py b/api/core/prompt/utils/prompt_message_util.py index 85a2201395..ba76eb0c4e 100644 --- a/api/core/prompt/utils/prompt_message_util.py +++ b/api/core/prompt/utils/prompt_message_util.py @@ -2,7 +2,7 @@ from collections.abc import Sequence from typing import Any, cast from core.prompt.simple_prompt_transform import ModelMode -from dify_graph.model_runtime.entities import ( +from graphon.model_runtime.entities import ( AssistantPromptMessage, AudioPromptMessageContent, ImagePromptMessageContent, diff --git a/api/core/provider_manager.py b/api/core/provider_manager.py index 3c3fbd6dd2..79fd78fe80 100644 --- a/api/core/provider_manager.py +++ b/api/core/provider_manager.py @@ -1,9 +1,11 @@ +from __future__ import annotations + import contextlib import json from collections import defaultdict from collections.abc import Sequence from json import JSONDecodeError -from typing import Any, cast +from typing import TYPE_CHECKING, Any, cast from sqlalchemy import select from sqlalchemy.exc import IntegrityError @@ -28,17 +30,17 @@ from core.entities.provider_entities import ( from core.helper import encrypter from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType from core.helper.position_helper import is_filtered -from dify_graph.model_runtime.entities.model_entities import ModelType -from dify_graph.model_runtime.entities.provider_entities import ( +from extensions import ext_hosting_provider +from extensions.ext_database import db +from extensions.ext_redis import redis_client +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.entities.provider_entities import ( ConfigurateMethod, CredentialFormSchema, FormType, ProviderEntity, ) -from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory -from extensions import ext_hosting_provider -from extensions.ext_database import db -from extensions.ext_redis import redis_client +from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory from models.provider import ( LoadBalancingModelConfig, Provider, @@ -53,15 +55,25 @@ from models.provider import ( from models.provider_ids import ModelProviderID from services.feature_service import FeatureService +if TYPE_CHECKING: + from graphon.model_runtime.runtime import ModelRuntime + class ProviderManager: """ - ProviderManager is a class that manages the model providers includes Hosting and Customize Model Providers. + ProviderManager manages tenant-scoped model provider configuration. + + The runtime adapter is injected by the composition layer so this class stays + focused on configuration assembly instead of constructing plugin runtimes. + Request-bound managers may carry caller identity in that runtime, and the + resulting ``ProviderConfiguration`` objects must reuse it for downstream + model-type and schema lookups. """ - def __init__(self): + def __init__(self, model_runtime: ModelRuntime): self.decoding_rsa_key = None self.decoding_cipher_rsa = None + self._model_runtime = model_runtime def get_configurations(self, tenant_id: str) -> ProviderConfigurations: """ @@ -127,7 +139,7 @@ class ProviderManager: ) # Get all provider entities - model_provider_factory = ModelProviderFactory(tenant_id) + model_provider_factory = ModelProviderFactory(model_runtime=self._model_runtime) provider_entities = model_provider_factory.get_providers() # Get All preferred provider types of the workspace @@ -255,6 +267,7 @@ class ProviderManager: custom_configuration=custom_configuration, model_settings=model_settings, ) + provider_configuration.bind_model_runtime(self._model_runtime) provider_configurations[str(provider_id_entity)] = provider_configuration @@ -321,7 +334,7 @@ class ProviderManager: if not default_model: return None - model_provider_factory = ModelProviderFactory(tenant_id) + model_provider_factory = ModelProviderFactory(model_runtime=self._model_runtime) provider_schema = model_provider_factory.get_provider_schema(provider=default_model.provider_name) return DefaultModelEntity( @@ -392,7 +405,7 @@ class ProviderManager: # create default model default_model = TenantDefaultModel( tenant_id=tenant_id, - model_type=model_type.value, + model_type=model_type.to_origin_model_type(), provider_name=provider, model_name=model, ) @@ -918,11 +931,11 @@ class ProviderManager: trail_pool = CreditPoolService.get_pool( tenant_id=tenant_id, - pool_type=ProviderQuotaType.TRIAL.value, + pool_type=ProviderQuotaType.TRIAL, ) paid_pool = CreditPoolService.get_pool( tenant_id=tenant_id, - pool_type=ProviderQuotaType.PAID.value, + pool_type=ProviderQuotaType.PAID, ) else: trail_pool = None diff --git a/api/core/rag/data_post_processor/data_post_processor.py b/api/core/rag/data_post_processor/data_post_processor.py index 33eb5f963a..2c81653559 100644 --- a/api/core/rag/data_post_processor/data_post_processor.py +++ b/api/core/rag/data_post_processor/data_post_processor.py @@ -8,8 +8,8 @@ from core.rag.rerank.entity.weight import KeywordSetting, VectorSetting, Weights from core.rag.rerank.rerank_base import BaseRerankRunner from core.rag.rerank.rerank_factory import RerankRunnerFactory from core.rag.rerank.rerank_type import RerankMode -from dify_graph.model_runtime.entities.model_entities import ModelType -from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError class RerankingModelDict(TypedDict): @@ -52,11 +52,10 @@ class DataPostProcessor: documents: list[Document], score_threshold: float | None = None, top_n: int | None = None, - user: str | None = None, query_type: QueryType = QueryType.TEXT_QUERY, ) -> list[Document]: if self.rerank_runner: - documents = self.rerank_runner.run(query, documents, score_threshold, top_n, user, query_type) + documents = self.rerank_runner.run(query, documents, score_threshold, top_n, query_type) if self.reorder_runner: documents = self.reorder_runner.run(documents) @@ -106,9 +105,9 @@ class DataPostProcessor: ) -> ModelInstance | None: if reranking_model: try: - model_manager = ModelManager() - reranking_provider_name = reranking_model["reranking_provider_name"] - reranking_model_name = reranking_model["reranking_model_name"] + model_manager = ModelManager.for_tenant(tenant_id=tenant_id) + reranking_provider_name = reranking_model.get("reranking_provider_name") + reranking_model_name = reranking_model.get("reranking_model_name") if not reranking_provider_name or not reranking_model_name: return None rerank_model_instance = model_manager.get_model_instance( diff --git a/api/core/rag/datasource/retrieval_service.py b/api/core/rag/datasource/retrieval_service.py index 713319ab9d..1e4aa24287 100644 --- a/api/core/rag/datasource/retrieval_service.py +++ b/api/core/rag/datasource/retrieval_service.py @@ -23,8 +23,8 @@ from core.rag.models.document import Document from core.rag.rerank.rerank_type import RerankMode from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.tools.signature import sign_upload_file -from dify_graph.model_runtime.entities.model_entities import ModelType from extensions.ext_database import db +from graphon.model_runtime.entities.model_entities import ModelType from models.dataset import ( ChildChunk, Dataset, @@ -328,7 +328,7 @@ class RetrievalService: str(dataset.tenant_id), str(RerankMode.RERANKING_MODEL), reranking_model, None, False ) if dataset.is_multimodal: - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=dataset.tenant_id) is_support_vision = model_manager.check_model_support_vision( tenant_id=dataset.tenant_id, provider=reranking_model["reranking_provider_name"], diff --git a/api/core/rag/datasource/vdb/baidu/baidu_vector.py b/api/core/rag/datasource/vdb/baidu/baidu_vector.py index 144d834495..9f5842e449 100644 --- a/api/core/rag/datasource/vdb/baidu/baidu_vector.py +++ b/api/core/rag/datasource/vdb/baidu/baidu_vector.py @@ -13,6 +13,7 @@ from pymochow.exception import ServerError # type: ignore from pymochow.model.database import Database from pymochow.model.enum import FieldType, IndexState, IndexType, MetricType, ServerErrCode, TableState # type: ignore from pymochow.model.schema import ( + AutoBuildRowCountIncrement, Field, FilteringIndex, HNSWParams, @@ -51,6 +52,9 @@ class BaiduConfig(BaseModel): replicas: int = 3 inverted_index_analyzer: str = "DEFAULT_ANALYZER" inverted_index_parser_mode: str = "COARSE_MODE" + auto_build_row_count_increment: int = 500 + auto_build_row_count_increment_ratio: float = 0.05 + rebuild_index_timeout_in_seconds: int = 300 @model_validator(mode="before") @classmethod @@ -107,18 +111,6 @@ class BaiduVector(BaseVector): rows.append(row) table.upsert(rows=rows) - # rebuild vector index after upsert finished - table.rebuild_index(self.vector_index) - timeout = 3600 # 1 hour timeout - start_time = time.time() - while True: - time.sleep(1) - index = table.describe_index(self.vector_index) - if index.state == IndexState.NORMAL: - break - if time.time() - start_time > timeout: - raise TimeoutError(f"Index rebuild timeout after {timeout} seconds") - def text_exists(self, id: str) -> bool: res = self._db.table(self._collection_name).query(primary_key={VDBField.PRIMARY_KEY: id}) if res and res.code == 0: @@ -232,8 +224,14 @@ class BaiduVector(BaseVector): return self._client.database(self._client_config.database) def _table_existed(self) -> bool: - tables = self._db.list_table() - return any(table.table_name == self._collection_name for table in tables) + try: + table = self._db.table(self._collection_name) + except ServerError as e: + if e.code == ServerErrCode.TABLE_NOT_EXIST: + return False + else: + raise + return True def _create_table(self, dimension: int): # Try to grab distributed lock and create table @@ -287,6 +285,11 @@ class BaiduVector(BaseVector): field=VDBField.VECTOR, metric_type=metric_type, params=HNSWParams(m=16, efconstruction=200), + auto_build=True, + auto_build_index_policy=AutoBuildRowCountIncrement( + row_count_increment=self._client_config.auto_build_row_count_increment, + row_count_increment_ratio=self._client_config.auto_build_row_count_increment_ratio, + ), ) ) @@ -335,7 +338,7 @@ class BaiduVector(BaseVector): ) # Wait for table created - timeout = 300 # 5 minutes timeout + timeout = self._client_config.rebuild_index_timeout_in_seconds # default 5 minutes timeout start_time = time.time() while True: time.sleep(1) @@ -345,6 +348,20 @@ class BaiduVector(BaseVector): if time.time() - start_time > timeout: raise TimeoutError(f"Table creation timeout after {timeout} seconds") redis_client.set(table_exist_cache_key, 1, ex=3600) + # rebuild vector index immediately after table created, make sure index is ready + table.rebuild_index(self.vector_index) + timeout = 3600 # 1 hour timeout + self._wait_for_index_ready(table, timeout) + + def _wait_for_index_ready(self, table, timeout: int = 3600): + start_time = time.time() + while True: + time.sleep(1) + index = table.describe_index(self.vector_index) + if index.state == IndexState.NORMAL: + break + if time.time() - start_time > timeout: + raise TimeoutError(f"Index rebuild timeout after {timeout} seconds") class BaiduVectorFactory(AbstractVectorFactory): @@ -369,5 +386,8 @@ class BaiduVectorFactory(AbstractVectorFactory): replicas=dify_config.BAIDU_VECTOR_DB_REPLICAS, inverted_index_analyzer=dify_config.BAIDU_VECTOR_DB_INVERTED_INDEX_ANALYZER, inverted_index_parser_mode=dify_config.BAIDU_VECTOR_DB_INVERTED_INDEX_PARSER_MODE, + auto_build_row_count_increment=dify_config.BAIDU_VECTOR_DB_AUTO_BUILD_ROW_COUNT_INCREMENT, + auto_build_row_count_increment_ratio=dify_config.BAIDU_VECTOR_DB_AUTO_BUILD_ROW_COUNT_INCREMENT_RATIO, + rebuild_index_timeout_in_seconds=dify_config.BAIDU_VECTOR_DB_REBUILD_INDEX_TIMEOUT_IN_SECONDS, ), ) diff --git a/api/core/rag/datasource/vdb/huawei/huawei_cloud_vector.py b/api/core/rag/datasource/vdb/huawei/huawei_cloud_vector.py index c7b6593a8f..df02c584ed 100644 --- a/api/core/rag/datasource/vdb/huawei/huawei_cloud_vector.py +++ b/api/core/rag/datasource/vdb/huawei/huawei_cloud_vector.py @@ -124,13 +124,13 @@ class HuaweiCloudVector(BaseVector): ) ) + score_threshold = float(kwargs.get("score_threshold") or 0.0) docs = [] for doc, score in docs_and_scores: - score_threshold = float(kwargs.get("score_threshold") or 0.0) if score >= score_threshold: if doc.metadata is not None: doc.metadata["score"] = score - docs.append(doc) + docs.append(doc) return docs diff --git a/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py b/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py index 71b6fa0a9b..3c1d5e015f 100644 --- a/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py +++ b/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py @@ -33,6 +33,7 @@ from core.rag.models.document import Document from extensions.ext_database import db from extensions.ext_redis import redis_client from models.dataset import Dataset, TidbAuthBinding +from models.enums import TidbAuthBindingStatus if TYPE_CHECKING: from qdrant_client import grpc # noqa @@ -452,7 +453,7 @@ class TidbOnQdrantVectorFactory(AbstractVectorFactory): password=new_cluster["password"], tenant_id=dataset.tenant_id, active=True, - status="ACTIVE", + status=TidbAuthBindingStatus.ACTIVE, ) db.session.add(new_tidb_auth_binding) db.session.commit() diff --git a/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_service.py b/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_service.py index 754c149241..06b17b9e62 100644 --- a/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_service.py +++ b/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_service.py @@ -9,6 +9,7 @@ from configs import dify_config from extensions.ext_database import db from extensions.ext_redis import redis_client from models.dataset import TidbAuthBinding +from models.enums import TidbAuthBindingStatus class TidbService: @@ -170,7 +171,7 @@ class TidbService: userPrefix = item["userPrefix"] if state == "ACTIVE" and len(userPrefix) > 0: cluster_info = tidb_serverless_list_map[item["clusterId"]] - cluster_info.status = "ACTIVE" + cluster_info.status = TidbAuthBindingStatus.ACTIVE cluster_info.account = f"{userPrefix}.root" db.session.add(cluster_info) db.session.commit() diff --git a/api/core/rag/datasource/vdb/vector_factory.py b/api/core/rag/datasource/vdb/vector_factory.py index cd12cd3fae..a77458706a 100644 --- a/api/core/rag/datasource/vdb/vector_factory.py +++ b/api/core/rag/datasource/vdb/vector_factory.py @@ -14,10 +14,10 @@ from core.rag.embedding.cached_embedding import CacheEmbedding from core.rag.embedding.embedding_base import Embeddings from core.rag.index_processor.constant.doc_type import DocType from core.rag.models.document import Document -from dify_graph.model_runtime.entities.model_entities import ModelType from extensions.ext_database import db from extensions.ext_redis import redis_client from extensions.ext_storage import storage +from graphon.model_runtime.entities.model_entities import ModelType from models.dataset import Dataset, Whitelist from models.model import UploadFile @@ -303,7 +303,7 @@ class Vector: redis_client.delete(collection_exist_cache_key) def _get_embeddings(self) -> Embeddings: - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=self._dataset.tenant_id) embedding_model = model_manager.get_model_instance( tenant_id=self._dataset.tenant_id, diff --git a/api/core/rag/docstore/dataset_docstore.py b/api/core/rag/docstore/dataset_docstore.py index 16a5588024..369159767e 100644 --- a/api/core/rag/docstore/dataset_docstore.py +++ b/api/core/rag/docstore/dataset_docstore.py @@ -6,9 +6,10 @@ from typing import Any from sqlalchemy import func, select from core.model_manager import ModelManager +from core.rag.index_processor.constant.index_type import IndexTechniqueType from core.rag.models.document import AttachmentDocument, Document -from dify_graph.model_runtime.entities.model_entities import ModelType from extensions.ext_database import db +from graphon.model_runtime.entities.model_entities import ModelType from models.dataset import ChildChunk, Dataset, DocumentSegment, SegmentAttachmentBinding @@ -71,8 +72,8 @@ class DatasetDocumentStore: if max_position is None: max_position = 0 embedding_model = None - if self._dataset.indexing_technique == "high_quality": - model_manager = ModelManager() + if self._dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: + model_manager = ModelManager.for_tenant(tenant_id=self._dataset.tenant_id) embedding_model = model_manager.get_model_instance( tenant_id=self._dataset.tenant_id, provider=self._dataset.embedding_model_provider, diff --git a/api/core/rag/embedding/cached_embedding.py b/api/core/rag/embedding/cached_embedding.py index 6d1b65a055..b12a0ae2d6 100644 --- a/api/core/rag/embedding/cached_embedding.py +++ b/api/core/rag/embedding/cached_embedding.py @@ -10,10 +10,10 @@ from configs import dify_config from core.entities.embedding_type import EmbeddingInputType from core.model_manager import ModelInstance from core.rag.embedding.embedding_base import Embeddings -from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey -from dify_graph.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel from extensions.ext_database import db from extensions.ext_redis import redis_client +from graphon.model_runtime.entities.model_entities import ModelPropertyKey +from graphon.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel from libs import helper from models.dataset import Embedding @@ -21,9 +21,8 @@ logger = logging.getLogger(__name__) class CacheEmbedding(Embeddings): - def __init__(self, model_instance: ModelInstance, user: str | None = None): + def __init__(self, model_instance: ModelInstance): self._model_instance = model_instance - self._user = user def embed_documents(self, texts: list[str]) -> list[list[float]]: """Embed search docs in batches of 10.""" @@ -65,7 +64,7 @@ class CacheEmbedding(Embeddings): batch_texts = embedding_queue_texts[i : i + max_chunks] embedding_result = self._model_instance.invoke_text_embedding( - texts=batch_texts, user=self._user, input_type=EmbeddingInputType.DOCUMENT + texts=batch_texts, input_type=EmbeddingInputType.DOCUMENT ) for vector in embedding_result.embeddings: @@ -147,7 +146,6 @@ class CacheEmbedding(Embeddings): embedding_result = self._model_instance.invoke_multimodal_embedding( multimodel_documents=batch_multimodel_documents, - user=self._user, input_type=EmbeddingInputType.DOCUMENT, ) @@ -202,7 +200,7 @@ class CacheEmbedding(Embeddings): return [float(x) for x in decoded_embedding] try: embedding_result = self._model_instance.invoke_text_embedding( - texts=[text], user=self._user, input_type=EmbeddingInputType.QUERY + texts=[text], input_type=EmbeddingInputType.QUERY ) embedding_results = embedding_result.embeddings[0] @@ -245,7 +243,7 @@ class CacheEmbedding(Embeddings): return [float(x) for x in decoded_embedding] try: embedding_result = self._model_instance.invoke_multimodal_embedding( - multimodel_documents=[multimodel_document], user=self._user, input_type=EmbeddingInputType.QUERY + multimodel_documents=[multimodel_document], input_type=EmbeddingInputType.QUERY ) embedding_results = embedding_result.embeddings[0] diff --git a/api/core/rag/extractor/firecrawl/firecrawl_app.py b/api/core/rag/extractor/firecrawl/firecrawl_app.py index 371f7b0865..e1ddd2dd96 100644 --- a/api/core/rag/extractor/firecrawl/firecrawl_app.py +++ b/api/core/rag/extractor/firecrawl/firecrawl_app.py @@ -95,15 +95,11 @@ class FirecrawlApp: if response.status_code == 200: crawl_status_response = response.json() if crawl_status_response.get("status") == "completed": - total = crawl_status_response.get("total", 0) - if total == 0: + # Normalize to avoid None bypassing the zero-guard when the API returns null. + total = crawl_status_response.get("total") or 0 + if total <= 0: raise Exception("Failed to check crawl status. Error: No page found") - data = crawl_status_response.get("data", []) - url_data_list: list[FirecrawlDocumentData] = [] - for item in data: - if isinstance(item, dict) and "metadata" in item and "markdown" in item: - url_data = self._extract_common_fields(item) - url_data_list.append(url_data) + url_data_list = self._collect_all_crawl_pages(crawl_status_response, headers) if url_data_list: file_key = "website_files/" + job_id + ".txt" try: @@ -120,6 +116,36 @@ class FirecrawlApp: self._handle_error(response, "check crawl status") raise RuntimeError("unreachable: _handle_error always raises") + def _collect_all_crawl_pages( + self, first_page: dict[str, Any], headers: dict[str, str] + ) -> list[FirecrawlDocumentData]: + """Collect all crawl result pages by following pagination links. + + Raises an exception if any paginated request fails, to avoid returning + partial data that is inconsistent with the reported total. + + The number of pages processed is capped at ``total`` (the + server-reported page count) to guard against infinite loops caused by + a misbehaving server that keeps returning a ``next`` URL. + """ + total: int = first_page.get("total") or 0 + url_data_list: list[FirecrawlDocumentData] = [] + current_page = first_page + pages_processed = 0 + while True: + for item in current_page.get("data", []): + if isinstance(item, dict) and "metadata" in item and "markdown" in item: + url_data_list.append(self._extract_common_fields(item)) + next_url: str | None = current_page.get("next") + pages_processed += 1 + if not next_url or pages_processed >= total: + break + response = self._get_request(next_url, headers) + if response.status_code != 200: + self._handle_error(response, "fetch next crawl page") + current_page = response.json() + return url_data_list + def _format_crawl_status_response( self, status: str, diff --git a/api/core/rag/index_processor/index_processor.py b/api/core/rag/index_processor/index_processor.py index d9145023ac..a6d1db214b 100644 --- a/api/core/rag/index_processor/index_processor.py +++ b/api/core/rag/index_processor/index_processor.py @@ -9,6 +9,7 @@ from flask import current_app from sqlalchemy import delete, func, select from core.db.session_factory import session_factory +from core.rag.index_processor.constant.index_type import IndexTechniqueType from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict from core.workflow.nodes.knowledge_index.exc import KnowledgeIndexNodeError from core.workflow.nodes.knowledge_index.protocols import Preview, PreviewItem, QaPreview @@ -159,7 +160,7 @@ class IndexProcessor: tenant_id = dataset.tenant_id preview_output = self.format_preview(chunk_structure, chunks) - if indexing_technique != "high_quality": + if indexing_technique != IndexTechniqueType.HIGH_QUALITY: return preview_output if not summary_index_setting or not summary_index_setting.get("enable"): diff --git a/api/core/rag/index_processor/processor/paragraph_index_processor.py b/api/core/rag/index_processor/processor/paragraph_index_processor.py index 80163b1707..9f36b7a225 100644 --- a/api/core/rag/index_processor/processor/paragraph_index_processor.py +++ b/api/core/rag/index_processor/processor/paragraph_index_processor.py @@ -8,11 +8,12 @@ from typing import Any, cast logger = logging.getLogger(__name__) +from core.app.file_access import DatabaseFileAccessController from core.app.llm import deduct_llm_quota from core.entities.knowledge_entities import PreviewDetail from core.llm_generator.prompts import DEFAULT_GENERATOR_SUMMARY_PROMPT from core.model_manager import ModelInstance -from core.provider_manager import ProviderManager +from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager from core.rag.cleaner.clean_processor import CleanProcessor from core.rag.data_post_processor.data_post_processor import RerankingModelDict from core.rag.datasource.keyword.keyword_factory import Keyword @@ -22,23 +23,24 @@ from core.rag.docstore.dataset_docstore import DatasetDocumentStore from core.rag.extractor.entity.extract_setting import ExtractSetting from core.rag.extractor.extract_processor import ExtractProcessor from core.rag.index_processor.constant.doc_type import DocType -from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from core.rag.index_processor.index_processor_base import BaseIndexProcessor, SummaryIndexSettingDict from core.rag.models.document import AttachmentDocument, Document, MultimodalGeneralStructureChunk from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.tools.utils.text_processing_utils import remove_leading_symbols -from dify_graph.file import File, FileTransferMethod, FileType, file_manager -from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMUsage -from dify_graph.model_runtime.entities.message_entities import ( +from core.workflow.file_reference import build_file_reference +from extensions.ext_database import db +from factories.file_factory import build_from_mapping +from graphon.file import File, FileTransferMethod, FileType, file_manager +from graphon.model_runtime.entities.llm_entities import LLMResult, LLMUsage +from graphon.model_runtime.entities.message_entities import ( ImagePromptMessageContent, PromptMessage, PromptMessageContentUnionTypes, TextPromptMessageContent, UserPromptMessage, ) -from dify_graph.model_runtime.entities.model_entities import ModelFeature, ModelType -from extensions.ext_database import db -from factories.file_factory import build_from_mapping +from graphon.model_runtime.entities.model_entities import ModelFeature, ModelType from libs import helper from models import UploadFile from models.account import Account @@ -48,6 +50,8 @@ from services.account_service import AccountService from services.entities.knowledge_entities.knowledge_entities import Rule from services.summary_index_service import SummaryIndexService +_file_access_controller = DatabaseFileAccessController() + class ParagraphIndexProcessor(BaseIndexProcessor): def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]: @@ -117,7 +121,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor): with_keywords: bool = True, **kwargs, ) -> None: - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: vector = Vector(dataset) vector.create(documents) if multimodal_documents and dataset.is_multimodal: @@ -155,7 +159,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor): # Delete all summaries for the dataset SummaryIndexService.delete_summaries_for_segments(dataset, None) - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: vector = Vector(dataset) if node_ids: vector.delete_by_ids(node_ids) @@ -253,12 +257,12 @@ class ParagraphIndexProcessor(BaseIndexProcessor): doc_store = DatasetDocumentStore(dataset=dataset, user_id=document.created_by, document_id=document.id) # add document segments doc_store.add_documents(docs=documents, save_child=False) - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: vector = Vector(dataset) vector.create(documents) if all_multimodal_documents and dataset.is_multimodal: vector.create_multimodal(all_multimodal_documents) - elif dataset.indexing_technique == "economy": + elif dataset.indexing_technique == IndexTechniqueType.ECONOMY: keyword = Keyword(dataset) keyword.add_texts(documents) @@ -410,7 +414,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor): # If default prompt doesn't have {language} placeholder, use it as-is pass - provider_manager = ProviderManager() + provider_manager = create_plugin_provider_manager(tenant_id=tenant_id) provider_model_bundle = provider_manager.get_provider_model_bundle( tenant_id, model_provider_name, ModelType.LLM ) @@ -555,6 +559,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor): file_obj = build_from_mapping( mapping=mapping, tenant_id=tenant_id, + access_controller=_file_access_controller, ) file_objects.append(file_obj) except Exception as e: @@ -604,11 +609,12 @@ class ParagraphIndexProcessor(BaseIndexProcessor): filename=upload_file.name, extension="." + upload_file.extension, mime_type=upload_file.mime_type, - tenant_id=tenant_id, type=FileType.IMAGE, transfer_method=FileTransferMethod.LOCAL_FILE, remote_url=upload_file.source_url, - related_id=upload_file.id, + reference=build_file_reference( + record_id=str(upload_file.id), + ), size=upload_file.size, storage_key=upload_file.key, ) diff --git a/api/core/rag/index_processor/processor/parent_child_index_processor.py b/api/core/rag/index_processor/processor/parent_child_index_processor.py index df0761ca73..70504e6e50 100644 --- a/api/core/rag/index_processor/processor/parent_child_index_processor.py +++ b/api/core/rag/index_processor/processor/parent_child_index_processor.py @@ -18,7 +18,7 @@ from core.rag.docstore.dataset_docstore import DatasetDocumentStore from core.rag.extractor.entity.extract_setting import ExtractSetting from core.rag.extractor.extract_processor import ExtractProcessor from core.rag.index_processor.constant.doc_type import DocType -from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from core.rag.index_processor.index_processor_base import BaseIndexProcessor, SummaryIndexSettingDict from core.rag.models.document import AttachmentDocument, ChildDocument, Document, ParentChildStructureChunk from core.rag.retrieval.retrieval_methods import RetrievalMethod @@ -128,7 +128,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor): with_keywords: bool = True, **kwargs, ) -> None: - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: vector = Vector(dataset) for document in documents: child_documents = document.children @@ -166,7 +166,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor): # Delete all summaries for the dataset SummaryIndexService.delete_summaries_for_segments(dataset, None) - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: delete_child_chunks = kwargs.get("delete_child_chunks") or False precomputed_child_node_ids = kwargs.get("precomputed_child_node_ids") vector = Vector(dataset) @@ -332,7 +332,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor): doc_store = DatasetDocumentStore(dataset=dataset, user_id=document.created_by, document_id=document.id) # add document segments doc_store.add_documents(docs=documents, save_child=True) - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: all_child_documents = [] all_multimodal_documents = [] for doc in documents: diff --git a/api/core/rag/index_processor/processor/qa_index_processor.py b/api/core/rag/index_processor/processor/qa_index_processor.py index 62f88b7760..6874603a83 100644 --- a/api/core/rag/index_processor/processor/qa_index_processor.py +++ b/api/core/rag/index_processor/processor/qa_index_processor.py @@ -21,7 +21,7 @@ from core.rag.datasource.vdb.vector_factory import Vector from core.rag.docstore.dataset_docstore import DatasetDocumentStore from core.rag.extractor.entity.extract_setting import ExtractSetting from core.rag.extractor.extract_processor import ExtractProcessor -from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from core.rag.index_processor.index_processor_base import BaseIndexProcessor, SummaryIndexSettingDict from core.rag.models.document import AttachmentDocument, Document, QAStructureChunk from core.rag.retrieval.retrieval_methods import RetrievalMethod @@ -141,7 +141,7 @@ class QAIndexProcessor(BaseIndexProcessor): with_keywords: bool = True, **kwargs, ) -> None: - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: vector = Vector(dataset) vector.create(documents) if multimodal_documents and dataset.is_multimodal: @@ -224,7 +224,7 @@ class QAIndexProcessor(BaseIndexProcessor): # save node to document segment doc_store = DatasetDocumentStore(dataset=dataset, user_id=document.created_by, document_id=document.id) doc_store.add_documents(docs=documents, save_child=False) - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: vector = Vector(dataset) vector.create(documents) else: diff --git a/api/core/rag/models/document.py b/api/core/rag/models/document.py index dc3b771406..4ebf095904 100644 --- a/api/core/rag/models/document.py +++ b/api/core/rag/models/document.py @@ -4,7 +4,7 @@ from typing import Any from pydantic import BaseModel, Field -from dify_graph.file import File +from graphon.file import File class ChildDocument(BaseModel): diff --git a/api/core/rag/rerank/rerank_base.py b/api/core/rag/rerank/rerank_base.py index 88acb75133..cc65262527 100644 --- a/api/core/rag/rerank/rerank_base.py +++ b/api/core/rag/rerank/rerank_base.py @@ -12,7 +12,6 @@ class BaseRerankRunner(ABC): documents: list[Document], score_threshold: float | None = None, top_n: int | None = None, - user: str | None = None, query_type: QueryType = QueryType.TEXT_QUERY, ) -> list[Document]: """ @@ -21,7 +20,6 @@ class BaseRerankRunner(ABC): :param documents: documents for reranking :param score_threshold: score threshold :param top_n: top n - :param user: unique user id if needed :return: """ raise NotImplementedError diff --git a/api/core/rag/rerank/rerank_model.py b/api/core/rag/rerank/rerank_model.py index fcb14ffc52..6c6b077cc2 100644 --- a/api/core/rag/rerank/rerank_model.py +++ b/api/core/rag/rerank/rerank_model.py @@ -5,10 +5,10 @@ from core.rag.index_processor.constant.doc_type import DocType from core.rag.index_processor.constant.query_type import QueryType from core.rag.models.document import Document from core.rag.rerank.rerank_base import BaseRerankRunner -from dify_graph.model_runtime.entities.model_entities import ModelType -from dify_graph.model_runtime.entities.rerank_entities import RerankResult from extensions.ext_database import db from extensions.ext_storage import storage +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.entities.rerank_entities import RerankResult from models.model import UploadFile @@ -22,7 +22,6 @@ class RerankModelRunner(BaseRerankRunner): documents: list[Document], score_threshold: float | None = None, top_n: int | None = None, - user: str | None = None, query_type: QueryType = QueryType.TEXT_QUERY, ) -> list[Document]: """ @@ -31,10 +30,11 @@ class RerankModelRunner(BaseRerankRunner): :param documents: documents for reranking :param score_threshold: score threshold :param top_n: top n - :param user: unique user id if needed :return: """ - model_manager = ModelManager() + model_manager = ModelManager.for_tenant( + tenant_id=self.rerank_model_instance.provider_model_bundle.configuration.tenant_id + ) is_support_vision = model_manager.check_model_support_vision( tenant_id=self.rerank_model_instance.provider_model_bundle.configuration.tenant_id, provider=self.rerank_model_instance.provider, @@ -43,12 +43,12 @@ class RerankModelRunner(BaseRerankRunner): ) if not is_support_vision: if query_type == QueryType.TEXT_QUERY: - rerank_result, unique_documents = self.fetch_text_rerank(query, documents, score_threshold, top_n, user) + rerank_result, unique_documents = self.fetch_text_rerank(query, documents, score_threshold, top_n) else: return documents else: rerank_result, unique_documents = self.fetch_multimodal_rerank( - query, documents, score_threshold, top_n, user, query_type + query, documents, score_threshold, top_n, query_type ) rerank_documents = [] @@ -73,7 +73,6 @@ class RerankModelRunner(BaseRerankRunner): documents: list[Document], score_threshold: float | None = None, top_n: int | None = None, - user: str | None = None, ) -> tuple[RerankResult, list[Document]]: """ Fetch text rerank @@ -81,7 +80,6 @@ class RerankModelRunner(BaseRerankRunner): :param documents: documents for reranking :param score_threshold: score threshold :param top_n: top n - :param user: unique user id if needed :return: """ docs = [] @@ -103,7 +101,7 @@ class RerankModelRunner(BaseRerankRunner): unique_documents.append(document) rerank_result = self.rerank_model_instance.invoke_rerank( - query=query, docs=docs, score_threshold=score_threshold, top_n=top_n, user=user + query=query, docs=docs, score_threshold=score_threshold, top_n=top_n ) return rerank_result, unique_documents @@ -113,7 +111,6 @@ class RerankModelRunner(BaseRerankRunner): documents: list[Document], score_threshold: float | None = None, top_n: int | None = None, - user: str | None = None, query_type: QueryType = QueryType.TEXT_QUERY, ) -> tuple[RerankResult, list[Document]]: """ @@ -122,7 +119,6 @@ class RerankModelRunner(BaseRerankRunner): :param documents: documents for reranking :param score_threshold: score threshold :param top_n: top n - :param user: unique user id if needed :param query_type: query type :return: rerank result """ @@ -168,7 +164,7 @@ class RerankModelRunner(BaseRerankRunner): documents = unique_documents if query_type == QueryType.TEXT_QUERY: - rerank_result, unique_documents = self.fetch_text_rerank(query, documents, score_threshold, top_n, user) + rerank_result, unique_documents = self.fetch_text_rerank(query, documents, score_threshold, top_n) return rerank_result, unique_documents elif query_type == QueryType.IMAGE_QUERY: # Query file info within db.session context to ensure thread-safe access @@ -181,7 +177,7 @@ class RerankModelRunner(BaseRerankRunner): "content_type": DocType.IMAGE, } rerank_result = self.rerank_model_instance.invoke_multimodal_rerank( - query=file_query_dict, docs=docs, score_threshold=score_threshold, top_n=top_n, user=user + query=file_query_dict, docs=docs, score_threshold=score_threshold, top_n=top_n ) return rerank_result, unique_documents else: diff --git a/api/core/rag/rerank/weight_rerank.py b/api/core/rag/rerank/weight_rerank.py index 7edd05d2d1..d0732b269a 100644 --- a/api/core/rag/rerank/weight_rerank.py +++ b/api/core/rag/rerank/weight_rerank.py @@ -11,7 +11,7 @@ from core.rag.index_processor.constant.query_type import QueryType from core.rag.models.document import Document from core.rag.rerank.entity.weight import VectorSetting, Weights from core.rag.rerank.rerank_base import BaseRerankRunner -from dify_graph.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.entities.model_entities import ModelType class WeightRerankRunner(BaseRerankRunner): @@ -25,7 +25,6 @@ class WeightRerankRunner(BaseRerankRunner): documents: list[Document], score_threshold: float | None = None, top_n: int | None = None, - user: str | None = None, query_type: QueryType = QueryType.TEXT_QUERY, ) -> list[Document]: """ @@ -34,7 +33,6 @@ class WeightRerankRunner(BaseRerankRunner): :param documents: documents for reranking :param score_threshold: score threshold :param top_n: top n - :param user: unique user id if needed :return: """ @@ -163,7 +161,7 @@ class WeightRerankRunner(BaseRerankRunner): """ query_vector_scores = [] - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=tenant_id) embedding_model = model_manager.get_model_instance( tenant_id=tenant_id, diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py index 78a97f79a5..49b91707ec 100644 --- a/api/core/rag/retrieval/dataset_retrieval.py +++ b/api/core/rag/retrieval/dataset_retrieval.py @@ -56,6 +56,7 @@ from core.rag.retrieval.template_prompts import ( ) from core.tools.signature import sign_upload_file from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool +from core.workflow.file_reference import build_file_reference from core.workflow.nodes.knowledge_retrieval import exc from core.workflow.nodes.knowledge_retrieval.retrieval import ( KnowledgeRetrievalRequest, @@ -63,13 +64,14 @@ from core.workflow.nodes.knowledge_retrieval.retrieval import ( SourceChildChunk, SourceMetadata, ) -from dify_graph.file import File, FileTransferMethod, FileType -from dify_graph.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMUsage -from dify_graph.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool -from dify_graph.model_runtime.entities.model_entities import ModelFeature, ModelType -from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from extensions.ext_database import db from extensions.ext_redis import redis_client +from graphon.file import File, FileTransferMethod, FileType +from graphon.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMUsage +from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool +from graphon.model_runtime.entities.model_entities import ModelFeature, ModelType +from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from libs.helper import parse_uuid_str_or_none from libs.json_in_md_parser import parse_and_check_json_markdown from models import UploadFile from models.dataset import ( @@ -160,7 +162,7 @@ class DatasetRetrieval: if request.model_provider is None or request.model_name is None or request.query is None: raise ValueError("model_provider, model_name, and query are required for single retrieval mode") - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=request.tenant_id, user_id=request.user_id) model_instance = model_manager.get_model_instance( tenant_id=request.tenant_id, model_type=ModelType.LLM, @@ -383,23 +385,27 @@ class DatasetRetrieval: return None, [] retrieve_config = config.retrieve_config - # check model is support tool calling - model_type_instance = model_config.provider_model_bundle.model_type_instance - model_type_instance = cast(LargeLanguageModel, model_type_instance) - - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=tenant_id, user_id=user_id) model_instance = model_manager.get_model_instance( tenant_id=tenant_id, model_type=ModelType.LLM, provider=model_config.provider, model=model_config.model ) + model_type_instance = cast(LargeLanguageModel, model_instance.model_type_instance) - # get model schema + # Reuse the caller-bound model instance for both schema resolution and + # downstream planner/invoke calls so a single request never mixes + # tenant-scope and request-bound runtimes. model_schema = model_type_instance.get_model_schema( - model=model_config.model, credentials=model_config.credentials + model=model_instance.model_name, + credentials=model_instance.credentials, ) if not model_schema: return None, [] + model_config.provider_model_bundle = model_instance.provider_model_bundle + model_config.credentials = model_instance.credentials + model_config.model_schema = model_schema + planning_strategy = PlanningStrategy.REACT_ROUTER features = model_schema.features if features: @@ -517,11 +523,12 @@ class DatasetRetrieval: filename=upload_file.name, extension="." + upload_file.extension, mime_type=upload_file.mime_type, - tenant_id=segment.tenant_id, type=FileType.IMAGE, transfer_method=FileTransferMethod.LOCAL_FILE, remote_url=upload_file.source_url, - related_id=upload_file.id, + reference=build_file_reference( + record_id=str(upload_file.id), + ), size=upload_file.size, storage_key=upload_file.key, url=sign_upload_file(upload_file.id, upload_file.extension), @@ -675,7 +682,7 @@ class DatasetRetrieval: # get top k top_k = retrieval_model_config["top_k"] # get retrieval method - if selected_dataset.indexing_technique == "economy": + if selected_dataset.indexing_technique == IndexTechniqueType.ECONOMY: retrieval_method = RetrievalMethod.KEYWORD_SEARCH else: retrieval_method = retrieval_model_config["search_method"] @@ -752,7 +759,7 @@ class DatasetRetrieval: "The configured knowledge base list have different indexing technique, please set reranking model." ) index_type = available_datasets[0].indexing_technique - if index_type == "high_quality": + if index_type == IndexTechniqueType.HIGH_QUALITY: embedding_model_check = all( item.embedding_model == available_datasets[0].embedding_model for item in available_datasets ) @@ -986,6 +993,24 @@ class DatasetRetrieval: ) ) + @staticmethod + def _resolve_creator_user_role(user_from: str) -> CreatorUserRole | None: + """Map runtime user source values to dataset query audit roles. + + Workflow run context uses the hyphenated ``end-user`` value, while + ``DatasetQuery.created_by_role`` persists the underscore-based + ``CreatorUserRole.END_USER`` enum. Query logging is a side effect, so an + unsupported value should be skipped instead of aborting retrieval. + """ + normalized_user_from = str(user_from).strip().lower().replace("-", "_") + if normalized_user_from == CreatorUserRole.ACCOUNT.value: + return CreatorUserRole.ACCOUNT + if normalized_user_from == CreatorUserRole.END_USER.value: + return CreatorUserRole.END_USER + + logger.warning("Skipping dataset query audit log for unsupported user_from=%r", user_from) + return None + def _on_query( self, query: str | None, @@ -996,10 +1021,18 @@ class DatasetRetrieval: user_id: str, ): """ - Handle query. + Persist dataset query audit rows for retrieval requests. """ if not query and not attachment_ids: return + created_by = parse_uuid_str_or_none(user_id) + if created_by is None: + logger.debug( + "Skipping dataset query log: empty created_by user_id (user_from=%s, app_id=%s)", + user_from, + app_id, + ) + return dataset_queries = [] for dataset_id in dataset_ids: contents = [] @@ -1015,7 +1048,7 @@ class DatasetRetrieval: source=DatasetQuerySource.APP, source_app_id=app_id, created_by_role=CreatorUserRole(user_from), - created_by=user_id, + created_by=created_by, ) dataset_queries.append(dataset_query) if dataset_queries: @@ -1068,7 +1101,7 @@ class DatasetRetrieval: else default_retrieval_model ) - if dataset.indexing_technique == "economy": + if dataset.indexing_technique == IndexTechniqueType.ECONOMY: # use keyword table query documents = RetrievalService.retrieve( retrieval_method=RetrievalMethod.KEYWORD_SEARCH, @@ -1411,7 +1444,7 @@ class DatasetRetrieval: raise ValueError("metadata_model_config is required") # get metadata model instance # fetch model config - model_instance, model_config = self._fetch_model_config(tenant_id, metadata_model_config) + model_instance, model_config = self._fetch_model_config(tenant_id, metadata_model_config, user_id=user_id) # fetch prompt messages prompt_messages, stop = self._get_prompt_template( @@ -1430,7 +1463,6 @@ class DatasetRetrieval: model_parameters=model_config.parameters, stop=stop, stream=True, - user=user_id, ), ) @@ -1533,7 +1565,7 @@ class DatasetRetrieval: return filters def _fetch_model_config( - self, tenant_id: str, model: ModelConfig + self, tenant_id: str, model: ModelConfig, user_id: str | None = None ) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]: """ Fetch model config @@ -1543,7 +1575,7 @@ class DatasetRetrieval: model_name = model.name provider_name = model.provider - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=tenant_id, user_id=user_id) model_instance = model_manager.get_model_instance( tenant_id=tenant_id, model_type=ModelType.LLM, provider=provider_name, model=model_name ) diff --git a/api/core/rag/retrieval/router/multi_dataset_function_call_router.py b/api/core/rag/retrieval/router/multi_dataset_function_call_router.py index 23a2ac8386..e617a9660e 100644 --- a/api/core/rag/retrieval/router/multi_dataset_function_call_router.py +++ b/api/core/rag/retrieval/router/multi_dataset_function_call_router.py @@ -2,8 +2,8 @@ from typing import Union from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.model_manager import ModelInstance -from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMUsage -from dify_graph.model_runtime.entities.message_entities import PromptMessageTool, SystemPromptMessage, UserPromptMessage +from graphon.model_runtime.entities.llm_entities import LLMResult, LLMUsage +from graphon.model_runtime.entities.message_entities import PromptMessageTool, SystemPromptMessage, UserPromptMessage class FunctionCallMultiDatasetRouter: diff --git a/api/core/rag/retrieval/router/multi_dataset_react_route.py b/api/core/rag/retrieval/router/multi_dataset_react_route.py index ea110fa0a7..83e58fe0f9 100644 --- a/api/core/rag/retrieval/router/multi_dataset_react_route.py +++ b/api/core/rag/retrieval/router/multi_dataset_react_route.py @@ -3,13 +3,14 @@ from typing import Union from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.app.llm import deduct_llm_quota -from core.model_manager import ModelInstance +from core.model_manager import ModelInstance, ModelManager from core.prompt.advanced_prompt_transform import AdvancedPromptTransform from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate from core.rag.retrieval.output_parser.react_output import ReactAction from core.rag.retrieval.output_parser.structured_chat import StructuredChatOutputParser -from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMUsage -from dify_graph.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool +from graphon.model_runtime.entities.llm_entities import LLMResult, LLMUsage +from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool +from graphon.model_runtime.entities.model_entities import ModelType PREFIX = """Respond to the human as helpfully and accurately as possible. You have access to the following tools:""" @@ -119,6 +120,7 @@ class ReactMultiDatasetRouter: memory_config=None, memory=None, model_config=model_config, + model_instance=model_instance, ) result_text, usage = self._invoke_llm( completion_param=model_config.parameters, @@ -150,19 +152,24 @@ class ReactMultiDatasetRouter: :param stop: stop :return: """ - invoke_result: Generator[LLMResult, None, None] = model_instance.invoke_llm( + bound_model_instance = ModelManager.for_tenant(tenant_id=tenant_id, user_id=user_id).get_model_instance( + tenant_id=tenant_id, + provider=model_instance.provider, + model_type=ModelType.LLM, + model=model_instance.model_name, + ) + invoke_result: Generator[LLMResult, None, None] = bound_model_instance.invoke_llm( prompt_messages=prompt_messages, model_parameters=completion_param, stop=stop, stream=True, - user=user_id, ) # handle invoke result text, usage = self._handle_invoke_result(invoke_result=invoke_result) # deduct quota - deduct_llm_quota(tenant_id=tenant_id, model_instance=model_instance, usage=usage) + deduct_llm_quota(tenant_id=tenant_id, model_instance=bound_model_instance, usage=usage) return text, usage diff --git a/api/core/rag/splitter/fixed_text_splitter.py b/api/core/rag/splitter/fixed_text_splitter.py index 7a00e8a886..2c27ac3cf6 100644 --- a/api/core/rag/splitter/fixed_text_splitter.py +++ b/api/core/rag/splitter/fixed_text_splitter.py @@ -15,7 +15,7 @@ from core.rag.splitter.text_splitter import ( Set, Union, ) -from dify_graph.model_runtime.model_providers.__base.tokenizers.gpt2_tokenizer import GPT2Tokenizer +from graphon.model_runtime.model_providers.__base.tokenizers.gpt2_tokenizer import GPT2Tokenizer class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter): diff --git a/api/core/rag/summary_index/summary_index.py b/api/core/rag/summary_index/summary_index.py index 31d21dbeee..6f120bd471 100644 --- a/api/core/rag/summary_index/summary_index.py +++ b/api/core/rag/summary_index/summary_index.py @@ -2,6 +2,7 @@ import concurrent.futures import logging from core.db.session_factory import session_factory +from core.rag.index_processor.constant.index_type import IndexTechniqueType from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict from models.dataset import Dataset, Document, DocumentSegment, DocumentSegmentSummary from services.summary_index_service import SummaryIndexService @@ -21,7 +22,7 @@ class SummaryIndex: if is_preview: with session_factory.create_session() as session: dataset = session.query(Dataset).filter_by(id=dataset_id).first() - if not dataset or dataset.indexing_technique != "high_quality": + if not dataset or dataset.indexing_technique != IndexTechniqueType.HIGH_QUALITY: return if summary_index_setting is None: diff --git a/api/core/repositories/__init__.py b/api/core/repositories/__init__.py index 6f2826f634..cfa9962ea8 100644 --- a/api/core/repositories/__init__.py +++ b/api/core/repositories/__init__.py @@ -4,7 +4,13 @@ from __future__ import annotations from .celery_workflow_execution_repository import CeleryWorkflowExecutionRepository from .celery_workflow_node_execution_repository import CeleryWorkflowNodeExecutionRepository -from .factory import DifyCoreRepositoryFactory, RepositoryImportError +from .factory import ( + DifyCoreRepositoryFactory, + OrderConfig, + RepositoryImportError, + WorkflowExecutionRepository, + WorkflowNodeExecutionRepository, +) from .sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository from .sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository @@ -12,7 +18,10 @@ __all__ = [ "CeleryWorkflowExecutionRepository", "CeleryWorkflowNodeExecutionRepository", "DifyCoreRepositoryFactory", + "OrderConfig", "RepositoryImportError", "SQLAlchemyWorkflowExecutionRepository", "SQLAlchemyWorkflowNodeExecutionRepository", + "WorkflowExecutionRepository", + "WorkflowNodeExecutionRepository", ] diff --git a/api/core/repositories/celery_workflow_execution_repository.py b/api/core/repositories/celery_workflow_execution_repository.py index 57764574d7..d0164b76dc 100644 --- a/api/core/repositories/celery_workflow_execution_repository.py +++ b/api/core/repositories/celery_workflow_execution_repository.py @@ -11,8 +11,8 @@ from typing import Union from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker -from dify_graph.entities.workflow_execution import WorkflowExecution -from dify_graph.repositories.workflow_execution_repository import WorkflowExecutionRepository +from core.repositories.factory import WorkflowExecutionRepository +from graphon.entities.workflow_execution import WorkflowExecution from libs.helper import extract_tenant_id from models import Account, CreatorUserRole, EndUser from models.enums import WorkflowRunTriggeredFrom diff --git a/api/core/repositories/celery_workflow_node_execution_repository.py b/api/core/repositories/celery_workflow_node_execution_repository.py index 650cf79550..52361cf6dc 100644 --- a/api/core/repositories/celery_workflow_node_execution_repository.py +++ b/api/core/repositories/celery_workflow_node_execution_repository.py @@ -12,11 +12,11 @@ from typing import Union from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker -from dify_graph.entities.workflow_node_execution import WorkflowNodeExecution -from dify_graph.repositories.workflow_node_execution_repository import ( +from core.repositories.factory import ( OrderConfig, WorkflowNodeExecutionRepository, ) +from graphon.entities.workflow_node_execution import WorkflowNodeExecution from libs.helper import extract_tenant_id from models import Account, CreatorUserRole, EndUser from models.workflow import WorkflowNodeExecutionTriggeredFrom @@ -148,24 +148,24 @@ class CeleryWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository): # For now, we'll re-raise the exception raise - def get_by_workflow_run( + def get_by_workflow_execution( self, - workflow_run_id: str, + workflow_execution_id: str, order_config: OrderConfig | None = None, ) -> Sequence[WorkflowNodeExecution]: """ - Retrieve all WorkflowNodeExecution instances for a specific workflow run from cache. + Retrieve all workflow node executions for a workflow execution from cache. Args: - workflow_run_id: The workflow run ID + workflow_execution_id: The workflow execution identifier order_config: Optional configuration for ordering results Returns: A sequence of WorkflowNodeExecution instances """ try: - # Get execution IDs for this workflow run from cache - execution_ids = self._workflow_execution_mapping.get(workflow_run_id, []) + # Get execution IDs for this workflow execution from cache + execution_ids = self._workflow_execution_mapping.get(workflow_execution_id, []) # Retrieve executions from cache result = [] @@ -182,9 +182,16 @@ class CeleryWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository): for field_name in reversed(order_config.order_by): result.sort(key=lambda x: getattr(x, field_name, 0), reverse=reverse) - logger.debug("Retrieved %d workflow node executions for run %s from cache", len(result), workflow_run_id) + logger.debug( + "Retrieved %d workflow node executions for execution %s from cache", + len(result), + workflow_execution_id, + ) return result except Exception: - logger.exception("Failed to get workflow node executions for run %s from cache", workflow_run_id) + logger.exception( + "Failed to get workflow node executions for execution %s from cache", + workflow_execution_id, + ) return [] diff --git a/api/core/repositories/factory.py b/api/core/repositories/factory.py index dc9f8c96bf..dafdbf641a 100644 --- a/api/core/repositories/factory.py +++ b/api/core/repositories/factory.py @@ -5,20 +5,45 @@ This module provides a Django-like settings system for repository implementation allowing users to configure different repository backends through string paths. """ -from typing import Union +from collections.abc import Sequence +from dataclasses import dataclass +from typing import Literal, Protocol, Union from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker from configs import dify_config -from dify_graph.repositories.workflow_execution_repository import WorkflowExecutionRepository -from dify_graph.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository +from graphon.entities import WorkflowExecution, WorkflowNodeExecution from libs.module_loading import import_string from models import Account, EndUser from models.enums import WorkflowRunTriggeredFrom from models.workflow import WorkflowNodeExecutionTriggeredFrom +@dataclass +class OrderConfig: + """Configuration for ordering node execution instances.""" + + order_by: list[str] + order_direction: Literal["asc", "desc"] | None = None + + +class WorkflowExecutionRepository(Protocol): + def save(self, execution: WorkflowExecution): ... + + +class WorkflowNodeExecutionRepository(Protocol): + def save(self, execution: WorkflowNodeExecution): ... + + def save_execution_data(self, execution: WorkflowNodeExecution): ... + + def get_by_workflow_execution( + self, + workflow_execution_id: str, + order_config: OrderConfig | None = None, + ) -> Sequence[WorkflowNodeExecution]: ... + + class RepositoryImportError(Exception): """Raised when a repository implementation cannot be imported or instantiated.""" diff --git a/api/core/repositories/human_input_repository.py b/api/core/repositories/human_input_repository.py index 6607a87032..02625e242f 100644 --- a/api/core/repositories/human_input_repository.py +++ b/api/core/repositories/human_input_repository.py @@ -2,33 +2,23 @@ import dataclasses import json from collections.abc import Mapping, Sequence from datetime import datetime -from typing import Any +from typing import Any, Protocol from sqlalchemy import select from sqlalchemy.orm import Session, selectinload from core.db.session_factory import session_factory -from dify_graph.nodes.human_input.entities import ( +from core.workflow.human_input_compat import ( + BoundRecipient, DeliveryChannelConfig, EmailDeliveryMethod, EmailRecipients, ExternalRecipient, - FormDefinition, - HumanInputNodeData, - MemberRecipient, - WebAppDeliveryMethod, -) -from dify_graph.nodes.human_input.enums import ( - DeliveryMethodType, - HumanInputFormKind, - HumanInputFormStatus, -) -from dify_graph.repositories.human_input_form_repository import ( - FormCreateParams, - FormNotFoundError, - HumanInputFormEntity, - HumanInputFormRecipientEntity, + InteractiveSurfaceDeliveryMethod, + is_human_input_webapp_enabled, ) +from graphon.nodes.human_input.entities import FormDefinition, HumanInputNodeData +from graphon.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus from libs.datetime_utils import naive_utc_now from libs.uuid_utils import uuidv7 from models.account import Account, TenantAccountJoin @@ -36,6 +26,7 @@ from models.human_input import ( BackstageRecipientPayload, ConsoleDeliveryPayload, ConsoleRecipientPayload, + DeliveryMethodType, EmailExternalRecipientPayload, EmailMemberRecipientPayload, HumanInputDelivery, @@ -58,6 +49,65 @@ class _WorkspaceMemberInfo: email: str +class FormNotFoundError(Exception): + pass + + +@dataclasses.dataclass +class FormCreateParams: + workflow_execution_id: str | None + node_id: str + form_config: HumanInputNodeData + rendered_content: str + delivery_methods: Sequence[DeliveryChannelConfig] + display_in_ui: bool + resolved_default_values: Mapping[str, Any] + form_kind: HumanInputFormKind = HumanInputFormKind.RUNTIME + + +class HumanInputFormRecipientEntity(Protocol): + @property + def id(self) -> str: ... + + @property + def token(self) -> str: ... + + +class HumanInputFormEntity(Protocol): + @property + def id(self) -> str: ... + + @property + def submission_token(self) -> str | None: ... + + @property + def recipients(self) -> list[HumanInputFormRecipientEntity]: ... + + @property + def rendered_content(self) -> str: ... + + @property + def selected_action_id(self) -> str | None: ... + + @property + def submitted_data(self) -> Mapping[str, Any] | None: ... + + @property + def submitted(self) -> bool: ... + + @property + def status(self) -> HumanInputFormStatus: ... + + @property + def expiration_time(self) -> datetime: ... + + +class HumanInputFormRepository(Protocol): + def get_form(self, node_id: str) -> HumanInputFormEntity | None: ... + + def create_form(self, params: FormCreateParams) -> HumanInputFormEntity: ... + + class _HumanInputFormRecipientEntityImpl(HumanInputFormRecipientEntity): def __init__(self, recipient_model: HumanInputFormRecipient): self._recipient_model = recipient_model @@ -77,7 +127,7 @@ class _HumanInputFormEntityImpl(HumanInputFormEntity): def __init__(self, form_model: HumanInputForm, recipient_models: Sequence[HumanInputFormRecipient]): self._form_model = form_model self._recipients = [_HumanInputFormRecipientEntityImpl(recipient) for recipient in recipient_models] - self._web_app_recipient = next( + self._interactive_surface_recipient = next( ( recipient for recipient in recipient_models @@ -98,12 +148,12 @@ class _HumanInputFormEntityImpl(HumanInputFormEntity): return self._form_model.id @property - def web_app_token(self): + def submission_token(self) -> str | None: if self._console_recipient is not None: return self._console_recipient.access_token - if self._web_app_recipient is None: + if self._interactive_surface_recipient is None: return None - return self._web_app_recipient.access_token + return self._interactive_surface_recipient.access_token @property def recipients(self) -> list[HumanInputFormRecipientEntity]: @@ -201,8 +251,16 @@ class HumanInputFormRepositoryImpl: self, *, tenant_id: str, - ): + app_id: str | None = None, + workflow_execution_id: str | None = None, + invoke_source: str | None = None, + submission_actor_id: str | None = None, + ) -> None: self._tenant_id = tenant_id + self._app_id = app_id + self._workflow_execution_id = workflow_execution_id + self._invoke_source = invoke_source + self._submission_actor_id = submission_actor_id def _delivery_method_to_model( self, @@ -219,7 +277,7 @@ class HumanInputFormRepositoryImpl: channel_payload=delivery_method.model_dump_json(), ) recipients: list[HumanInputFormRecipient] = [] - if isinstance(delivery_method, WebAppDeliveryMethod): + if isinstance(delivery_method, InteractiveSurfaceDeliveryMethod): recipient_model = HumanInputFormRecipient( form_id=form_id, delivery_id=delivery_id, @@ -247,16 +305,16 @@ class HumanInputFormRepositoryImpl: delivery_id: str, recipients_config: EmailRecipients, ) -> list[HumanInputFormRecipient]: - member_user_ids = [ - recipient.user_id for recipient in recipients_config.items if isinstance(recipient, MemberRecipient) + bound_reference_ids = [ + recipient.reference_id for recipient in recipients_config.items if isinstance(recipient, BoundRecipient) ] external_emails = [ recipient.email for recipient in recipients_config.items if isinstance(recipient, ExternalRecipient) ] - if recipients_config.whole_workspace: + if recipients_config.include_bound_group: members = self._query_all_workspace_members(session=session) else: - members = self._query_workspace_members_by_ids(session=session, restrict_to_user_ids=member_user_ids) + members = self._query_workspace_members_by_ids(session=session, restrict_to_user_ids=bound_reference_ids) return self._create_email_recipients_from_resolved( form_id=form_id, @@ -338,8 +396,33 @@ class HumanInputFormRepositoryImpl: rows = session.execute(stmt).all() return [_WorkspaceMemberInfo(user_id=account_id, email=email) for account_id, email in rows] + def _should_create_console_recipient( + self, + *, + form_config: HumanInputNodeData, + form_kind: HumanInputFormKind, + ) -> bool: + if form_kind != HumanInputFormKind.RUNTIME: + return False + if self._invoke_source == "debugger": + return True + if self._invoke_source == "explore": + return is_human_input_webapp_enabled(form_config) + return False + + def _should_create_backstage_recipient(self, *, form_kind: HumanInputFormKind) -> bool: + return form_kind == HumanInputFormKind.RUNTIME and ( + self._invoke_source is not None or self._submission_actor_id is not None + ) + def create_form(self, params: FormCreateParams) -> HumanInputFormEntity: form_config: HumanInputNodeData = params.form_config + app_id = self._app_id + if not app_id: + raise ValueError("app_id is required to create a human input form") + workflow_execution_id = params.workflow_execution_id or self._workflow_execution_id + if params.form_kind == HumanInputFormKind.RUNTIME and workflow_execution_id is None: + raise ValueError("workflow_execution_id is required for runtime human input forms") with session_factory.create_session() as session, session.begin(): # Generate unique form ID @@ -359,8 +442,8 @@ class HumanInputFormRepositoryImpl: form_model = HumanInputForm( id=form_id, tenant_id=self._tenant_id, - app_id=params.app_id, - workflow_run_id=params.workflow_execution_id, + app_id=app_id, + workflow_run_id=workflow_execution_id, form_kind=params.form_kind, node_id=params.node_id, form_definition=form_definition.model_dump_json(), @@ -379,7 +462,7 @@ class HumanInputFormRepositoryImpl: session.add(delivery_and_recipients.delivery) session.add_all(delivery_and_recipients.recipients) recipient_models.extend(delivery_and_recipients.recipients) - if params.console_recipient_required and not any( + if self._should_create_console_recipient(form_config=form_config, form_kind=params.form_kind) and not any( recipient.recipient_type == RecipientType.CONSOLE for recipient in recipient_models ): console_delivery_id = str(uuidv7()) @@ -395,13 +478,13 @@ class HumanInputFormRepositoryImpl: delivery_id=console_delivery_id, recipient_type=RecipientType.CONSOLE, recipient_payload=ConsoleRecipientPayload( - account_id=params.console_creator_account_id, + account_id=self._submission_actor_id, ).model_dump_json(), ) session.add(console_delivery) session.add(console_recipient) recipient_models.append(console_recipient) - if params.backstage_recipient_required and not any( + if self._should_create_backstage_recipient(form_kind=params.form_kind) and not any( recipient.recipient_type == RecipientType.BACKSTAGE for recipient in recipient_models ): backstage_delivery_id = str(uuidv7()) @@ -417,7 +500,7 @@ class HumanInputFormRepositoryImpl: delivery_id=backstage_delivery_id, recipient_type=RecipientType.BACKSTAGE, recipient_payload=BackstageRecipientPayload( - account_id=params.console_creator_account_id, + account_id=self._submission_actor_id, ).model_dump_json(), ) session.add(backstage_delivery) @@ -427,9 +510,12 @@ class HumanInputFormRepositoryImpl: return _HumanInputFormEntityImpl(form_model=form_model, recipient_models=recipient_models) - def get_form(self, workflow_execution_id: str, node_id: str) -> HumanInputFormEntity | None: + def get_form(self, node_id: str) -> HumanInputFormEntity | None: + if self._workflow_execution_id is None: + raise ValueError("workflow_execution_id is required to load runtime human input forms") + form_query = select(HumanInputForm).where( - HumanInputForm.workflow_run_id == workflow_execution_id, + HumanInputForm.workflow_run_id == self._workflow_execution_id, HumanInputForm.node_id == node_id, HumanInputForm.tenant_id == self._tenant_id, ) diff --git a/api/core/repositories/sqlalchemy_workflow_execution_repository.py b/api/core/repositories/sqlalchemy_workflow_execution_repository.py index 55e96515ac..1ee5d4ae77 100644 --- a/api/core/repositories/sqlalchemy_workflow_execution_repository.py +++ b/api/core/repositories/sqlalchemy_workflow_execution_repository.py @@ -9,10 +9,10 @@ from typing import Union from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker -from dify_graph.entities import WorkflowExecution -from dify_graph.enums import WorkflowExecutionStatus, WorkflowType -from dify_graph.repositories.workflow_execution_repository import WorkflowExecutionRepository -from dify_graph.workflow_type_encoder import WorkflowRuntimeTypeConverter +from core.repositories.factory import WorkflowExecutionRepository +from graphon.entities import WorkflowExecution +from graphon.enums import WorkflowExecutionStatus, WorkflowType +from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter from libs.helper import extract_tenant_id from models import ( Account, diff --git a/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py b/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py index 7373ebc7cc..749ab44a14 100644 --- a/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py +++ b/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py @@ -17,12 +17,12 @@ from sqlalchemy.orm import sessionmaker from tenacity import before_sleep_log, retry, retry_if_exception, stop_after_attempt from configs import dify_config -from dify_graph.entities import WorkflowNodeExecution -from dify_graph.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus -from dify_graph.model_runtime.utils.encoders import jsonable_encoder -from dify_graph.repositories.workflow_node_execution_repository import OrderConfig, WorkflowNodeExecutionRepository -from dify_graph.workflow_type_encoder import WorkflowRuntimeTypeConverter +from core.repositories.factory import OrderConfig, WorkflowNodeExecutionRepository from extensions.ext_storage import storage +from graphon.entities import WorkflowNodeExecution +from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus +from graphon.model_runtime.utils.encoders import jsonable_encoder +from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter from libs.helper import extract_tenant_id from libs.uuid_utils import uuidv7 from models import ( @@ -518,29 +518,28 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository) return db_models - def get_by_workflow_run( + def get_by_workflow_execution( self, - workflow_run_id: str, + workflow_execution_id: str, order_config: OrderConfig | None = None, triggered_from: WorkflowNodeExecutionTriggeredFrom = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, ) -> Sequence[WorkflowNodeExecution]: """ - Retrieve all NodeExecution instances for a specific workflow run. + Retrieve all node executions for a workflow execution. This method always queries the database to ensure complete and ordered results, but updates the cache with any retrieved executions. Args: - workflow_run_id: The workflow run ID + workflow_execution_id: The workflow execution identifier order_config: Optional configuration for ordering results order_config.order_by: List of fields to order by (e.g., ["index", "created_at"]) order_config.order_direction: Direction to order ("asc" or "desc") Returns: - A list of NodeExecution instances + A list of node execution instances """ - # Get the database models using the new method - db_models = self.get_db_models_by_workflow_run(workflow_run_id, order_config, triggered_from) + db_models = self.get_db_models_by_workflow_run(workflow_execution_id, order_config, triggered_from) with ThreadPoolExecutor(max_workers=10) as executor: domain_models = executor.map(self._to_domain_model, db_models, timeout=30) diff --git a/api/core/tools/__base/tool_runtime.py b/api/core/tools/__base/tool_runtime.py index 961d13f90a..5154bc9805 100644 --- a/api/core/tools/__base/tool_runtime.py +++ b/api/core/tools/__base/tool_runtime.py @@ -9,10 +9,14 @@ from core.tools.entities.tool_entities import ToolInvokeFrom class ToolRuntime(BaseModel): """ - Meta data of a tool call processing + Meta data of a tool call processing. + + ``user_id`` is optional so read-only tooling flows can stay tenant-scoped, + while execution paths may bind caller identity for model runtime lookups. """ tenant_id: str + user_id: str | None = None tool_id: str | None = None invoke_from: InvokeFrom | None = None tool_invoke_from: ToolInvokeFrom | None = None diff --git a/api/core/tools/builtin_tool/providers/audio/tools/asr.py b/api/core/tools/builtin_tool/providers/audio/tools/asr.py index dacc49c746..40bf2e98c2 100644 --- a/api/core/tools/builtin_tool/providers/audio/tools/asr.py +++ b/api/core/tools/builtin_tool/providers/audio/tools/asr.py @@ -7,9 +7,9 @@ from core.plugin.entities.parameters import PluginParameterOption from core.tools.builtin_tool.tool import BuiltinTool from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter -from dify_graph.file.enums import FileType -from dify_graph.file.file_manager import download -from dify_graph.model_runtime.entities.model_entities import ModelType +from graphon.file.enums import FileType +from graphon.file.file_manager import download +from graphon.model_runtime.entities.model_entities import ModelType from services.model_provider_service import ModelProviderService @@ -22,6 +22,9 @@ class ASRTool(BuiltinTool): app_id: str | None = None, message_id: str | None = None, ) -> Generator[ToolInvokeMessage, None, None]: + if not self.runtime: + raise ValueError("Runtime is required") + runtime = self.runtime file = tool_parameters.get("audio_file") if file.type != FileType.AUDIO: # type: ignore yield self.create_text_message("not a valid audio file") @@ -29,20 +32,19 @@ class ASRTool(BuiltinTool): audio_binary = io.BytesIO(download(file)) # type: ignore audio_binary.name = "temp.mp3" provider, model = tool_parameters.get("model").split("#") # type: ignore - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=runtime.tenant_id, user_id=user_id) model_instance = model_manager.get_model_instance( - tenant_id=self.runtime.tenant_id, + tenant_id=runtime.tenant_id, provider=provider, model_type=ModelType.SPEECH2TEXT, model=model, ) - text = model_instance.invoke_speech2text( - file=audio_binary, - user=user_id, - ) + text = model_instance.invoke_speech2text(file=audio_binary) yield self.create_text_message(text) def get_available_models(self) -> list[tuple[str, str]]: + if not self.runtime: + raise ValueError("Runtime is required") model_provider_service = ModelProviderService() models = model_provider_service.get_models_by_model_type( tenant_id=self.runtime.tenant_id, model_type="speech2text" diff --git a/api/core/tools/builtin_tool/providers/audio/tools/tts.py b/api/core/tools/builtin_tool/providers/audio/tools/tts.py index 7818bff0ab..ac3820f1ab 100644 --- a/api/core/tools/builtin_tool/providers/audio/tools/tts.py +++ b/api/core/tools/builtin_tool/providers/audio/tools/tts.py @@ -7,7 +7,7 @@ from core.plugin.entities.parameters import PluginParameterOption from core.tools.builtin_tool.tool import BuiltinTool from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter -from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey, ModelType +from graphon.model_runtime.entities.model_entities import ModelPropertyKey, ModelType from services.model_provider_service import ModelProviderService @@ -20,13 +20,14 @@ class TTSTool(BuiltinTool): app_id: str | None = None, message_id: str | None = None, ) -> Generator[ToolInvokeMessage, None, None]: - provider, model = tool_parameters.get("model").split("#") # type: ignore - voice = tool_parameters.get(f"voice#{provider}#{model}") - model_manager = ModelManager() if not self.runtime: raise ValueError("Runtime is required") + runtime = self.runtime + provider, model = tool_parameters.get("model").split("#") # type: ignore + voice = tool_parameters.get(f"voice#{provider}#{model}") + model_manager = ModelManager.for_tenant(tenant_id=runtime.tenant_id, user_id=user_id) model_instance = model_manager.get_model_instance( - tenant_id=self.runtime.tenant_id or "", + tenant_id=runtime.tenant_id or "", provider=provider, model_type=ModelType.TTS, model=model, @@ -39,12 +40,7 @@ class TTSTool(BuiltinTool): raise ValueError("Sorry, no voice available.") else: raise ValueError("Sorry, no voice available.") - tts = model_instance.invoke_tts( - content_text=tool_parameters.get("text"), # type: ignore - user=user_id, - tenant_id=self.runtime.tenant_id, - voice=voice, - ) + tts = model_instance.invoke_tts(content_text=tool_parameters.get("text"), voice=voice) # type: ignore[arg-type] buffer = io.BytesIO() for chunk in tts: buffer.write(chunk) diff --git a/api/core/tools/builtin_tool/providers/time/tools/current_time.py b/api/core/tools/builtin_tool/providers/time/tools/current_time.py index 44f94c2723..e07ca0d919 100644 --- a/api/core/tools/builtin_tool/providers/time/tools/current_time.py +++ b/api/core/tools/builtin_tool/providers/time/tools/current_time.py @@ -2,7 +2,7 @@ from collections.abc import Generator from datetime import UTC, datetime from typing import Any -from pytz import timezone as pytz_timezone +from pytz import timezone as pytz_timezone # type: ignore[import-untyped] from core.tools.builtin_tool.tool import BuiltinTool from core.tools.entities.tool_entities import ToolInvokeMessage diff --git a/api/core/tools/builtin_tool/providers/time/tools/localtime_to_timestamp.py b/api/core/tools/builtin_tool/providers/time/tools/localtime_to_timestamp.py index d0a41b940f..dc49b64dd8 100644 --- a/api/core/tools/builtin_tool/providers/time/tools/localtime_to_timestamp.py +++ b/api/core/tools/builtin_tool/providers/time/tools/localtime_to_timestamp.py @@ -2,7 +2,7 @@ from collections.abc import Generator from datetime import datetime from typing import Any -import pytz +import pytz # type: ignore[import-untyped] from core.tools.builtin_tool.tool import BuiltinTool from core.tools.entities.tool_entities import ToolInvokeMessage diff --git a/api/core/tools/builtin_tool/providers/time/tools/timestamp_to_localtime.py b/api/core/tools/builtin_tool/providers/time/tools/timestamp_to_localtime.py index 462e4be5ce..8045e4b980 100644 --- a/api/core/tools/builtin_tool/providers/time/tools/timestamp_to_localtime.py +++ b/api/core/tools/builtin_tool/providers/time/tools/timestamp_to_localtime.py @@ -2,7 +2,7 @@ from collections.abc import Generator from datetime import datetime from typing import Any -import pytz +import pytz # type: ignore[import-untyped] from core.tools.builtin_tool.tool import BuiltinTool from core.tools.entities.tool_entities import ToolInvokeMessage diff --git a/api/core/tools/builtin_tool/providers/time/tools/timezone_conversion.py b/api/core/tools/builtin_tool/providers/time/tools/timezone_conversion.py index e23ae3b001..e2570811d6 100644 --- a/api/core/tools/builtin_tool/providers/time/tools/timezone_conversion.py +++ b/api/core/tools/builtin_tool/providers/time/tools/timezone_conversion.py @@ -2,7 +2,7 @@ from collections.abc import Generator from datetime import datetime from typing import Any -import pytz +import pytz # type: ignore[import-untyped] from core.tools.builtin_tool.tool import BuiltinTool from core.tools.entities.tool_entities import ToolInvokeMessage diff --git a/api/core/tools/builtin_tool/tool.py b/api/core/tools/builtin_tool/tool.py index bcf58394ba..d41503e1e6 100644 --- a/api/core/tools/builtin_tool/tool.py +++ b/api/core/tools/builtin_tool/tool.py @@ -4,8 +4,8 @@ from core.tools.__base.tool import Tool from core.tools.__base.tool_runtime import ToolRuntime from core.tools.entities.tool_entities import ToolProviderType from core.tools.utils.model_invocation_utils import ModelInvocationUtils -from dify_graph.model_runtime.entities.llm_entities import LLMResult -from dify_graph.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage, UserPromptMessage +from graphon.model_runtime.entities.llm_entities import LLMResult +from graphon.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage, UserPromptMessage _SUMMARY_PROMPT = """You are a professional language researcher, you are interested in the language and you can quickly aimed at the main point of an webpage and reproduce it in your own words but @@ -53,6 +53,7 @@ class BuiltinTool(Tool): tool_type=ToolProviderType.BUILT_IN, tool_name=self.entity.identity.name, prompt_messages=prompt_messages, + caller_user_id=self.runtime.user_id, ) def tool_provider_type(self) -> ToolProviderType: @@ -69,6 +70,7 @@ class BuiltinTool(Tool): return ModelInvocationUtils.get_max_llm_context_tokens( tenant_id=self.runtime.tenant_id or "", + user_id=self.runtime.user_id, ) def get_prompt_tokens(self, prompt_messages: list[PromptMessage]) -> int: @@ -82,7 +84,9 @@ class BuiltinTool(Tool): raise ValueError("runtime is required") return ModelInvocationUtils.calculate_tokens( - tenant_id=self.runtime.tenant_id or "", prompt_messages=prompt_messages + tenant_id=self.runtime.tenant_id or "", + prompt_messages=prompt_messages, + user_id=self.runtime.user_id, ) def summary(self, user_id: str, content: str) -> str: diff --git a/api/core/tools/custom_tool/tool.py b/api/core/tools/custom_tool/tool.py index c6a84e27c6..168e5f4493 100644 --- a/api/core/tools/custom_tool/tool.py +++ b/api/core/tools/custom_tool/tool.py @@ -13,7 +13,7 @@ from core.tools.__base.tool_runtime import ToolRuntime from core.tools.entities.tool_bundle import ApiToolBundle from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, ToolProviderType from core.tools.errors import ToolInvokeError, ToolParameterValidationError, ToolProviderCredentialValidationError -from dify_graph.file.file_manager import download +from graphon.file.file_manager import download API_TOOL_DEFAULT_TIMEOUT = ( int(getenv("API_TOOL_DEFAULT_CONNECT_TIMEOUT", "10")), diff --git a/api/core/tools/entities/api_entities.py b/api/core/tools/entities/api_entities.py index 2545290b57..08640befb4 100644 --- a/api/core/tools/entities/api_entities.py +++ b/api/core/tools/entities/api_entities.py @@ -9,7 +9,7 @@ from core.plugin.entities.plugin_daemon import CredentialType from core.tools.__base.tool import ToolParameter from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ToolProviderType -from dify_graph.model_runtime.utils.encoders import jsonable_encoder +from graphon.model_runtime.utils.encoders import jsonable_encoder class ToolApiEntity(BaseModel): diff --git a/api/core/tools/mcp_tool/tool.py b/api/core/tools/mcp_tool/tool.py index 9025ff6ef1..00fc8a8282 100644 --- a/api/core/tools/mcp_tool/tool.py +++ b/api/core/tools/mcp_tool/tool.py @@ -21,7 +21,7 @@ from core.tools.__base.tool import Tool from core.tools.__base.tool_runtime import ToolRuntime from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, ToolProviderType from core.tools.errors import ToolInvokeError -from dify_graph.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata +from graphon.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata logger = logging.getLogger(__name__) diff --git a/api/core/tools/signature.py b/api/core/tools/signature.py index 22e099deba..1807226924 100644 --- a/api/core/tools/signature.py +++ b/api/core/tools/signature.py @@ -3,6 +3,7 @@ import hashlib import hmac import os import time +import urllib.parse from configs import dify_config @@ -58,3 +59,43 @@ def verify_tool_file_signature(file_id: str, timestamp: str, nonce: str, sign: s current_time = int(time.time()) return current_time - int(timestamp) <= dify_config.FILES_ACCESS_TIMEOUT + + +def get_signed_file_url_for_plugin(filename: str, mimetype: str, tenant_id: str, user_id: str) -> str: + """Build the signed upload URL used by the plugin-facing file upload endpoint.""" + + base_url = dify_config.INTERNAL_FILES_URL or dify_config.FILES_URL + upload_url = f"{base_url}/files/upload/for-plugin" + timestamp = str(int(time.time())) + nonce = os.urandom(16).hex() + data_to_sign = f"upload|{filename}|{mimetype}|{tenant_id}|{user_id}|{timestamp}|{nonce}" + secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b"" + sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest() + encoded_sign = base64.urlsafe_b64encode(sign).decode() + query = urllib.parse.urlencode( + { + "timestamp": timestamp, + "nonce": nonce, + "sign": encoded_sign, + "user_id": user_id, + "tenant_id": tenant_id, + } + ) + return f"{upload_url}?{query}" + + +def verify_plugin_file_signature( + *, filename: str, mimetype: str, tenant_id: str, user_id: str, timestamp: str, nonce: str, sign: str +) -> bool: + """Verify the signature used by the plugin-facing file upload endpoint.""" + + data_to_sign = f"upload|{filename}|{mimetype}|{tenant_id}|{user_id}|{timestamp}|{nonce}" + secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b"" + recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest() + recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode() + + if sign != recalculated_encoded_sign: + return False + + current_time = int(time.time()) + return current_time - int(timestamp) <= dify_config.FILES_ACCESS_TIMEOUT diff --git a/api/core/tools/tool_engine.py b/api/core/tools/tool_engine.py index 64212a2636..1fd259f3bb 100644 --- a/api/core/tools/tool_engine.py +++ b/api/core/tools/tool_engine.py @@ -31,9 +31,9 @@ from core.tools.errors import ( ) from core.tools.utils.message_transformer import ToolFileMessageTransformer, safe_json_value from core.tools.workflow_as_tool.tool import WorkflowTool -from dify_graph.file import FileType -from dify_graph.file.models import FileTransferMethod from extensions.ext_database import db +from graphon.file import FileType +from graphon.file.models import FileTransferMethod from models.enums import CreatorUserRole, MessageFileBelongsTo from models.model import Message, MessageFile diff --git a/api/core/tools/tool_file_manager.py b/api/core/tools/tool_file_manager.py index 210f488afc..2ec292602c 100644 --- a/api/core/tools/tool_file_manager.py +++ b/api/core/tools/tool_file_manager.py @@ -14,8 +14,9 @@ import httpx from configs import dify_config from core.db.session_factory import session_factory from core.helper import ssrf_proxy -from dify_graph.file.models import ToolFile as ToolFilePydanticModel +from core.workflow.file_reference import build_file_reference from extensions.ext_storage import storage +from graphon.file import File, FileTransferMethod, get_file_type_by_mime_type from models.model import MessageFile from models.tools import ToolFile @@ -23,6 +24,21 @@ logger = logging.getLogger(__name__) class ToolFileManager: + @staticmethod + def _build_graph_file_reference(tool_file: ToolFile) -> File: + extension = guess_extension(tool_file.mimetype) or ".bin" + return File( + type=get_file_type_by_mime_type(tool_file.mimetype), + transfer_method=FileTransferMethod.TOOL_FILE, + remote_url=tool_file.original_url, + reference=build_file_reference(record_id=str(tool_file.id)), + filename=tool_file.name, + extension=extension, + mime_type=tool_file.mimetype, + size=tool_file.size, + storage_key=tool_file.file_key, + ) + @staticmethod def sign_file(tool_file_id: str, extension: str) -> str: """ @@ -209,9 +225,7 @@ class ToolFileManager: return blob, tool_file.mimetype - def get_file_generator_by_tool_file_id( - self, tool_file_id: str - ) -> tuple[Generator | None, ToolFilePydanticModel | None]: + def get_file_generator_by_tool_file_id(self, tool_file_id: str) -> tuple[Generator | None, File | None]: """ get file binary @@ -233,11 +247,11 @@ class ToolFileManager: stream = storage.load_stream(tool_file.file_key) - return stream, ToolFilePydanticModel.model_validate(tool_file) + return stream, self._build_graph_file_reference(tool_file) # init tool_file_parser -from dify_graph.file.tool_file_parser import set_tool_file_manager_factory +from graphon.file.tool_file_parser import set_tool_file_manager_factory def _factory() -> ToolFileManager: diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index 23a877b7e3..4870adb7b5 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -5,7 +5,7 @@ import time from collections.abc import Generator, Mapping from os import listdir, path from threading import Lock -from typing import TYPE_CHECKING, Any, Literal, Optional, TypedDict, Union, cast +from typing import TYPE_CHECKING, Any, Literal, Optional, Protocol, TypedDict, Union, cast import sqlalchemy as sa from sqlalchemy import select @@ -24,14 +24,14 @@ from core.tools.plugin_tool.provider import PluginToolProviderController from core.tools.plugin_tool.tool import PluginTool from core.tools.utils.uuid_utils import is_valid_uuid from core.tools.workflow_as_tool.provider import WorkflowToolProviderController -from dify_graph.runtime.variable_pool import VariablePool from extensions.ext_database import db +from graphon.runtime.variable_pool import VariablePool from models.provider_ids import ToolProviderID from services.enterprise.plugin_manager_service import PluginCredentialType from services.tools.mcp_tools_manage_service import MCPToolManageService if TYPE_CHECKING: - from dify_graph.nodes.tool.entities import ToolEntity + pass from core.agent.entities import AgentToolEntity from core.app.entities.app_invoke_entities import InvokeFrom @@ -57,12 +57,12 @@ from core.tools.tool_label_manager import ToolLabelManager from core.tools.utils.configuration import ToolParameterConfigurationManager from core.tools.utils.encryption import create_provider_encrypter, create_tool_provider_encrypter from core.tools.workflow_as_tool.tool import WorkflowTool -from dify_graph.model_runtime.utils.encoders import jsonable_encoder +from graphon.model_runtime.utils.encoders import jsonable_encoder from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider from services.tools.tools_transform_service import ToolTransformService if TYPE_CHECKING: - from dify_graph.nodes.tool.entities import ToolEntity + pass logger = logging.getLogger(__name__) @@ -77,6 +77,23 @@ class EmojiIconDict(TypedDict): content: str +class WorkflowToolRuntimeSpec(Protocol): + @property + def provider_type(self) -> ToolProviderType: ... + + @property + def provider_id(self) -> str: ... + + @property + def tool_name(self) -> str: ... + + @property + def tool_configurations(self) -> Mapping[str, Any]: ... + + @property + def credential_id(self) -> str | None: ... + + class ToolManager: _builtin_provider_lock = Lock() _hardcoded_providers: dict[str, BuiltinToolProviderController] = {} @@ -167,6 +184,7 @@ class ToolManager: provider_id: str, tool_name: str, tenant_id: str, + user_id: str | None = None, invoke_from: InvokeFrom = InvokeFrom.DEBUGGER, tool_invoke_from: ToolInvokeFrom = ToolInvokeFrom.AGENT, credential_id: str | None = None, @@ -178,6 +196,7 @@ class ToolManager: :param provider_id: the id of the provider :param tool_name: the name of the tool :param tenant_id: the tenant id + :param user_id: the caller id bound to runtime-scoped model/tool lookups :param invoke_from: invoke from :param tool_invoke_from: the tool invoke from :param credential_id: the credential id @@ -196,6 +215,7 @@ class ToolManager: return builtin_tool.fork_tool_runtime( runtime=ToolRuntime( tenant_id=tenant_id, + user_id=user_id, credentials={}, invoke_from=invoke_from, tool_invoke_from=tool_invoke_from, @@ -304,6 +324,7 @@ class ToolManager: return builtin_tool.fork_tool_runtime( runtime=ToolRuntime( tenant_id=tenant_id, + user_id=user_id, credentials=dict(decrypted_credentials), credential_type=CredentialType.of(builtin_provider.credential_type), runtime_parameters={}, @@ -321,6 +342,7 @@ class ToolManager: return api_provider.get_tool(tool_name).fork_tool_runtime( runtime=ToolRuntime( tenant_id=tenant_id, + user_id=user_id, credentials=dict(encrypter.decrypt(credentials)), invoke_from=invoke_from, tool_invoke_from=tool_invoke_from, @@ -344,6 +366,7 @@ class ToolManager: return controller.get_tools(tenant_id=workflow_provider.tenant_id)[0].fork_tool_runtime( runtime=ToolRuntime( tenant_id=tenant_id, + user_id=user_id, credentials={}, invoke_from=invoke_from, tool_invoke_from=tool_invoke_from, @@ -352,9 +375,21 @@ class ToolManager: elif provider_type == ToolProviderType.APP: raise NotImplementedError("app provider not implemented") elif provider_type == ToolProviderType.PLUGIN: - return cls.get_plugin_provider(provider_id, tenant_id).get_tool(tool_name) + plugin_tool = cls.get_plugin_provider(provider_id, tenant_id).get_tool(tool_name) + runtime = getattr(plugin_tool, "runtime", None) + if runtime is not None: + runtime.user_id = user_id + runtime.invoke_from = invoke_from + runtime.tool_invoke_from = tool_invoke_from + return plugin_tool elif provider_type == ToolProviderType.MCP: - return cls.get_mcp_provider_controller(tenant_id, provider_id).get_tool(tool_name) + mcp_tool = cls.get_mcp_provider_controller(tenant_id, provider_id).get_tool(tool_name) + runtime = getattr(mcp_tool, "runtime", None) + if runtime is not None: + runtime.user_id = user_id + runtime.invoke_from = invoke_from + runtime.tool_invoke_from = tool_invoke_from + return mcp_tool else: raise ToolProviderNotFoundError(f"provider type {provider_type.value} not found") @@ -364,6 +399,7 @@ class ToolManager: tenant_id: str, app_id: str, agent_tool: AgentToolEntity, + user_id: str | None = None, invoke_from: InvokeFrom = InvokeFrom.DEBUGGER, variable_pool: Optional["VariablePool"] = None, ) -> Tool: @@ -375,6 +411,7 @@ class ToolManager: provider_id=agent_tool.provider_id, tool_name=agent_tool.tool_name, tenant_id=tenant_id, + user_id=user_id, invoke_from=invoke_from, tool_invoke_from=ToolInvokeFrom.AGENT, credential_id=agent_tool.credential_id, @@ -405,7 +442,8 @@ class ToolManager: tenant_id: str, app_id: str, node_id: str, - workflow_tool: "ToolEntity", + workflow_tool: WorkflowToolRuntimeSpec, + user_id: str | None = None, invoke_from: InvokeFrom = InvokeFrom.DEBUGGER, variable_pool: Optional["VariablePool"] = None, ) -> Tool: @@ -418,6 +456,7 @@ class ToolManager: provider_id=workflow_tool.provider_id, tool_name=workflow_tool.tool_name, tenant_id=tenant_id, + user_id=user_id, invoke_from=invoke_from, tool_invoke_from=ToolInvokeFrom.WORKFLOW, credential_id=workflow_tool.credential_id, @@ -450,6 +489,7 @@ class ToolManager: provider: str, tool_name: str, tool_parameters: dict[str, Any], + user_id: str | None = None, credential_id: str | None = None, ) -> Tool: """ @@ -460,6 +500,7 @@ class ToolManager: provider_id=provider, tool_name=tool_name, tenant_id=tenant_id, + user_id=user_id, invoke_from=InvokeFrom.SERVICE_API, tool_invoke_from=ToolInvokeFrom.PLUGIN, credential_id=credential_id, @@ -1015,14 +1056,14 @@ class ToolManager: cls, parameters: list[ToolParameter], variable_pool: Optional["VariablePool"], - tool_configurations: dict[str, Any], + tool_configurations: Mapping[str, Any], typ: Literal["agent", "workflow", "tool"] = "workflow", ) -> dict[str, Any]: """ Convert tool parameters type """ - from dify_graph.nodes.tool.entities import ToolNodeData - from dify_graph.nodes.tool.exc import ToolParameterError + from graphon.nodes.tool.entities import ToolNodeData + from graphon.nodes.tool.exc import ToolParameterError runtime_parameters = {} for parameter in parameters: diff --git a/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py b/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py index c2b520fa99..dad5133a7a 100644 --- a/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py +++ b/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py @@ -8,13 +8,14 @@ from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCa from core.model_manager import ModelManager from core.rag.datasource.retrieval_service import RetrievalService from core.rag.entities.citation_metadata import RetrievalSourceMetadata +from core.rag.index_processor.constant.index_type import IndexTechniqueType from core.rag.models.document import Document as RagDocument from core.rag.rerank.rerank_model import RerankModelRunner from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool from core.tools.utils.dataset_retriever.dataset_retriever_tool import DefaultRetrievalModelDict -from dify_graph.model_runtime.entities.model_entities import ModelType from extensions.ext_database import db +from graphon.model_runtime.entities.model_entities import ModelType from models.dataset import Dataset, Document, DocumentSegment default_retrieval_model: DefaultRetrievalModelDict = { @@ -65,7 +66,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool): for thread in threads: thread.join() # do rerank for searched documents - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=self.tenant_id) rerank_model_instance = model_manager.get_model_instance( tenant_id=self.tenant_id, provider=self.reranking_provider_name, @@ -169,7 +170,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool): # get retrieval model , if the model is not setting , using default retrieval_model = dataset.retrieval_model or default_retrieval_model - if dataset.indexing_technique == "economy": + if dataset.indexing_technique == IndexTechniqueType.ECONOMY: # use keyword table query documents = RetrievalService.retrieve( retrieval_method=RetrievalMethod.KEYWORD_SEARCH, diff --git a/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py b/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py index 429b7e6622..f3d390ed59 100644 --- a/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py +++ b/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py @@ -8,6 +8,7 @@ from core.rag.data_post_processor.data_post_processor import RerankingModelDict, from core.rag.datasource.retrieval_service import RetrievalService from core.rag.entities.citation_metadata import RetrievalSourceMetadata from core.rag.entities.context_entities import DocumentContext +from core.rag.index_processor.constant.index_type import IndexTechniqueType from core.rag.models.document import Document as RetrievalDocument from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from core.rag.retrieval.retrieval_methods import RetrievalMethod @@ -140,7 +141,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool): # get retrieval model , if the model is not setting , using default retrieval_model = dataset.retrieval_model or default_retrieval_model retrieval_resource_list: list[RetrievalSourceMetadata] = [] - if dataset.indexing_technique == "economy": + if dataset.indexing_technique == IndexTechniqueType.ECONOMY: # use keyword table query documents = RetrievalService.retrieve( retrieval_method=RetrievalMethod.KEYWORD_SEARCH, @@ -173,7 +174,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool): for hit_callback in self.hit_callbacks: hit_callback.on_tool_end(documents) document_score_list = {} - if dataset.indexing_technique != "economy": + if dataset.indexing_technique != IndexTechniqueType.ECONOMY: for item in documents: if item.metadata is not None and item.metadata.get("score"): document_score_list[item.metadata["doc_id"]] = item.metadata["score"] diff --git a/api/core/tools/utils/message_transformer.py b/api/core/tools/utils/message_transformer.py index 6fc5fead2d..5cf46b2564 100644 --- a/api/core/tools/utils/message_transformer.py +++ b/api/core/tools/utils/message_transformer.py @@ -1,4 +1,5 @@ import logging +import re from collections.abc import Generator from datetime import date, datetime from decimal import Decimal @@ -10,12 +11,15 @@ import pytz from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.tool_file_manager import ToolFileManager -from dify_graph.file import File, FileTransferMethod, FileType +from core.workflow.file_reference import parse_file_reference +from graphon.file import File, FileTransferMethod, FileType from libs.login import current_user from models import Account logger = logging.getLogger(__name__) +_TOOL_FILE_URL_PATTERN = re.compile(r"(?:^|/+)files/tools/(?P[^/?#.]+)") + def safe_json_value(v): if isinstance(v, datetime): @@ -82,11 +86,15 @@ class ToolFileMessageTransformer: ) url = f"/files/tools/{tool_file.id}{guess_extension(tool_file.mimetype) or '.png'}" + meta = cls._with_tool_file_meta( + message.meta, + tool_file_id=str(tool_file.id), + ) yield ToolInvokeMessage( type=ToolInvokeMessage.MessageType.IMAGE_LINK, message=ToolInvokeMessage.TextMessage(text=url), - meta=message.meta.copy() if message.meta is not None else {}, + meta=meta, ) except Exception as e: yield ToolInvokeMessage( @@ -122,38 +130,45 @@ class ToolFileMessageTransformer: ) url = cls.get_tool_file_url(tool_file_id=tool_file.id, extension=guess_extension(tool_file.mimetype)) + meta = cls._with_tool_file_meta(meta, tool_file_id=str(tool_file.id)) # check if file is image if "image" in mimetype: yield ToolInvokeMessage( type=ToolInvokeMessage.MessageType.IMAGE_LINK, message=ToolInvokeMessage.TextMessage(text=url), - meta=meta.copy() if meta is not None else {}, + meta=meta, ) else: yield ToolInvokeMessage( type=ToolInvokeMessage.MessageType.BINARY_LINK, message=ToolInvokeMessage.TextMessage(text=url), - meta=meta.copy() if meta is not None else {}, + meta=meta, ) elif message.type == ToolInvokeMessage.MessageType.FILE: meta = message.meta or {} file = meta.get("file", None) if isinstance(file, File): if file.transfer_method == FileTransferMethod.TOOL_FILE: - assert file.related_id is not None - url = cls.get_tool_file_url(tool_file_id=file.related_id, extension=file.extension) + parsed_reference = parse_file_reference(file.reference) + if parsed_reference is None: + raise ValueError("tool file is missing reference") + url = cls.get_tool_file_url( + tool_file_id=parsed_reference.record_id, + extension=file.extension, + ) + tool_file_meta = cls._with_tool_file_meta(meta, tool_file_id=parsed_reference.record_id) if file.type == FileType.IMAGE: yield ToolInvokeMessage( type=ToolInvokeMessage.MessageType.IMAGE_LINK, message=ToolInvokeMessage.TextMessage(text=url), - meta=meta.copy() if meta is not None else {}, + meta=tool_file_meta, ) else: yield ToolInvokeMessage( type=ToolInvokeMessage.MessageType.LINK, message=ToolInvokeMessage.TextMessage(text=url), - meta=meta.copy() if meta is not None else {}, + meta=tool_file_meta, ) else: yield message @@ -162,9 +177,40 @@ class ToolFileMessageTransformer: if isinstance(message.message, ToolInvokeMessage.JsonMessage): message.message.json_object = safe_json_value(message.message.json_object) yield message + elif message.type in { + ToolInvokeMessage.MessageType.IMAGE_LINK, + ToolInvokeMessage.MessageType.BINARY_LINK, + } and isinstance(message.message, ToolInvokeMessage.TextMessage): + yield ToolInvokeMessage( + type=message.type, + message=message.message, + meta=cls._with_tool_file_meta(message.meta, url=message.message.text), + ) else: yield message @classmethod def get_tool_file_url(cls, tool_file_id: str, extension: str | None) -> str: return f"/files/tools/{tool_file_id}{extension or '.bin'}" + + @staticmethod + def _with_tool_file_meta( + meta: dict | None, + *, + tool_file_id: str | None = None, + url: str | None = None, + ) -> dict: + normalized_meta = meta.copy() if meta is not None else {} + resolved_tool_file_id = tool_file_id or ToolFileMessageTransformer._extract_tool_file_id(url) + if resolved_tool_file_id and "tool_file_id" not in normalized_meta: + normalized_meta["tool_file_id"] = resolved_tool_file_id + return normalized_meta + + @staticmethod + def _extract_tool_file_id(url: str | None) -> str | None: + if not url: + return None + match = _TOOL_FILE_URL_PATTERN.search(url) + if match is None: + return None + return match.group("tool_file_id") diff --git a/api/core/tools/utils/model_invocation_utils.py b/api/core/tools/utils/model_invocation_utils.py index 373bd1b1c8..9e1d41cb39 100644 --- a/api/core/tools/utils/model_invocation_utils.py +++ b/api/core/tools/utils/model_invocation_utils.py @@ -10,19 +10,19 @@ from typing import cast from core.model_manager import ModelManager from core.tools.entities.tool_entities import ToolProviderType -from dify_graph.model_runtime.entities.llm_entities import LLMResult -from dify_graph.model_runtime.entities.message_entities import PromptMessage -from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey, ModelType -from dify_graph.model_runtime.errors.invoke import ( +from extensions.ext_database import db +from graphon.model_runtime.entities.llm_entities import LLMResult +from graphon.model_runtime.entities.message_entities import PromptMessage +from graphon.model_runtime.entities.model_entities import ModelPropertyKey, ModelType +from graphon.model_runtime.errors.invoke import ( InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError, InvokeRateLimitError, InvokeServerUnavailableError, ) -from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from dify_graph.model_runtime.utils.encoders import jsonable_encoder -from extensions.ext_database import db +from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from graphon.model_runtime.utils.encoders import jsonable_encoder from models.tools import ToolModelInvoke @@ -34,11 +34,12 @@ class ModelInvocationUtils: @staticmethod def get_max_llm_context_tokens( tenant_id: str, + user_id: str | None = None, ) -> int: """ get max llm context tokens of the model """ - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=tenant_id, user_id=user_id) model_instance = model_manager.get_default_model_instance( tenant_id=tenant_id, model_type=ModelType.LLM, @@ -60,13 +61,13 @@ class ModelInvocationUtils: return max_tokens @staticmethod - def calculate_tokens(tenant_id: str, prompt_messages: list[PromptMessage]) -> int: + def calculate_tokens(tenant_id: str, prompt_messages: list[PromptMessage], user_id: str | None = None) -> int: """ calculate tokens from prompt messages and model parameters """ # get model instance - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=tenant_id, user_id=user_id) model_instance = model_manager.get_default_model_instance(tenant_id=tenant_id, model_type=ModelType.LLM) if not model_instance: @@ -79,7 +80,12 @@ class ModelInvocationUtils: @staticmethod def invoke( - user_id: str, tenant_id: str, tool_type: ToolProviderType, tool_name: str, prompt_messages: list[PromptMessage] + user_id: str, + tenant_id: str, + tool_type: ToolProviderType, + tool_name: str, + prompt_messages: list[PromptMessage], + caller_user_id: str | None = None, ) -> LLMResult: """ invoke model with parameters in user's own context @@ -93,7 +99,7 @@ class ModelInvocationUtils: """ # get model manager - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=tenant_id, user_id=caller_user_id or user_id) # get model instance model_instance = model_manager.get_default_model_instance( tenant_id=tenant_id, @@ -137,7 +143,6 @@ class ModelInvocationUtils: tools=[], stop=[], stream=False, - user=user_id, callbacks=[], ) except InvokeRateLimitError as e: diff --git a/api/core/tools/utils/workflow_configuration_sync.py b/api/core/tools/utils/workflow_configuration_sync.py index 28f1376655..1e4f3ed2a7 100644 --- a/api/core/tools/utils/workflow_configuration_sync.py +++ b/api/core/tools/utils/workflow_configuration_sync.py @@ -3,9 +3,9 @@ from typing import Any from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration from core.tools.errors import WorkflowToolHumanInputNotSupportedError -from dify_graph.enums import BuiltinNodeTypes -from dify_graph.nodes.base.entities import OutputVariableEntity -from dify_graph.variables.input_entities import VariableEntity +from graphon.enums import BuiltinNodeTypes +from graphon.nodes.base.entities import OutputVariableEntity +from graphon.variables.input_entities import VariableEntity class WorkflowToolConfigurationUtils: diff --git a/api/core/tools/workflow_as_tool/provider.py b/api/core/tools/workflow_as_tool/provider.py index aef8b3f779..716368c191 100644 --- a/api/core/tools/workflow_as_tool/provider.py +++ b/api/core/tools/workflow_as_tool/provider.py @@ -22,8 +22,8 @@ from core.tools.entities.tool_entities import ( ) from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurationUtils from core.tools.workflow_as_tool.tool import WorkflowTool -from dify_graph.variables.input_entities import VariableEntity, VariableEntityType from extensions.ext_database import db +from graphon.variables.input_entities import VariableEntity, VariableEntityType from models.account import Account from models.model import App, AppMode from models.tools import WorkflowToolProvider diff --git a/api/core/tools/workflow_as_tool/tool.py b/api/core/tools/workflow_as_tool/tool.py index 9b9aa7a741..495fcd48b3 100644 --- a/api/core/tools/workflow_as_tool/tool.py +++ b/api/core/tools/workflow_as_tool/tool.py @@ -7,6 +7,7 @@ from typing import Any, cast from sqlalchemy import select +from core.app.file_access import DatabaseFileAccessController from core.db.session_factory import session_factory from core.tools.__base.tool import Tool from core.tools.__base.tool_runtime import ToolRuntime @@ -17,14 +18,17 @@ from core.tools.entities.tool_entities import ( ToolProviderType, ) from core.tools.errors import ToolInvokeError -from dify_graph.file import FILE_MODEL_IDENTITY, File, FileTransferMethod -from dify_graph.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata +from core.workflow.file_reference import resolve_file_record_id from factories.file_factory import build_from_mapping +from graphon.file import FILE_MODEL_IDENTITY, File, FileTransferMethod +from graphon.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata from models import Account, Tenant from models.model import App, EndUser +from models.utils.file_input_compat import build_file_from_stored_mapping from models.workflow import Workflow logger = logging.getLogger(__name__) +_file_access_controller = DatabaseFileAccessController() class WorkflowTool(Tool): @@ -288,16 +292,25 @@ class WorkflowTool(Tool): file = tool_parameters.get(parameter.name) if file: try: - file_var_list = [File.model_validate(f) for f in file] + file_var_list = [ + build_file_from_stored_mapping( + file_mapping=cast(Mapping[str, Any], f), + tenant_id=str(self.runtime.tenant_id), + ) + for f in file + if isinstance(f, Mapping) + ] for file in file_var_list: file_dict: dict[str, str | None] = { "transfer_method": file.transfer_method.value, "type": file.type.value, } if file.transfer_method == FileTransferMethod.TOOL_FILE: - file_dict["tool_file_id"] = file.related_id + file_dict["tool_file_id"] = resolve_file_record_id(file.reference) elif file.transfer_method == FileTransferMethod.LOCAL_FILE: - file_dict["upload_file_id"] = file.related_id + file_dict["upload_file_id"] = resolve_file_record_id(file.reference) + elif file.transfer_method == FileTransferMethod.DATASOURCE_FILE: + file_dict["datasource_file_id"] = resolve_file_record_id(file.reference) elif file.transfer_method == FileTransferMethod.REMOTE_URL: file_dict["url"] = file.generate_url() @@ -325,6 +338,7 @@ class WorkflowTool(Tool): file = build_from_mapping( mapping=item, tenant_id=str(self.runtime.tenant_id), + access_controller=_file_access_controller, ) files.append(file) elif isinstance(value, dict) and value.get("dify_model_identity") == FILE_MODEL_IDENTITY: @@ -332,6 +346,7 @@ class WorkflowTool(Tool): file = build_from_mapping( mapping=value, tenant_id=str(self.runtime.tenant_id), + access_controller=_file_access_controller, ) files.append(file) @@ -340,9 +355,10 @@ class WorkflowTool(Tool): return result, files def _update_file_mapping(self, file_dict: dict): + file_id = resolve_file_record_id(file_dict.get("reference") or file_dict.get("related_id")) transfer_method = FileTransferMethod.value_of(file_dict.get("transfer_method")) if transfer_method == FileTransferMethod.TOOL_FILE: - file_dict["tool_file_id"] = file_dict.get("related_id") + file_dict["tool_file_id"] = file_id elif transfer_method == FileTransferMethod.LOCAL_FILE: - file_dict["upload_file_id"] = file_dict.get("related_id") + file_dict["upload_file_id"] = file_id return file_dict diff --git a/api/core/trigger/debug/event_selectors.py b/api/core/trigger/debug/event_selectors.py index 2a133b2b94..24c1271488 100644 --- a/api/core/trigger/debug/event_selectors.py +++ b/api/core/trigger/debug/event_selectors.py @@ -26,8 +26,8 @@ from core.trigger.debug.events import ( ) from core.workflow.nodes.trigger_plugin.entities import TriggerEventNodeData from core.workflow.nodes.trigger_schedule.entities import ScheduleConfig -from dify_graph.entities.graph_config import NodeConfigDict from extensions.ext_redis import redis_client +from graphon.entities.graph_config import NodeConfigDict from libs.datetime_utils import ensure_naive_utc, naive_utc_now from libs.schedule_utils import calculate_next_run_at from models.model import App diff --git a/api/core/workflow/file_reference.py b/api/core/workflow/file_reference.py new file mode 100644 index 0000000000..c80acb3783 --- /dev/null +++ b/api/core/workflow/file_reference.py @@ -0,0 +1,52 @@ +from __future__ import annotations + +import base64 +import json +from dataclasses import dataclass + +_FILE_REFERENCE_PREFIX = "dify-file-ref:" + + +@dataclass(frozen=True) +class FileReference: + record_id: str + storage_key: str | None = None + + +def build_file_reference(*, record_id: str, storage_key: str | None = None) -> str: + payload = {"record_id": record_id} + if storage_key is not None: + payload["storage_key"] = storage_key + encoded_payload = base64.urlsafe_b64encode(json.dumps(payload, separators=(",", ":")).encode()).decode() + return f"{_FILE_REFERENCE_PREFIX}{encoded_payload}" + + +def parse_file_reference(reference: str | None) -> FileReference | None: + if not reference: + return None + + if not reference.startswith(_FILE_REFERENCE_PREFIX): + return FileReference(record_id=reference) + + encoded_payload = reference.removeprefix(_FILE_REFERENCE_PREFIX) + try: + payload = json.loads(base64.urlsafe_b64decode(encoded_payload.encode())) + except (ValueError, json.JSONDecodeError): + return FileReference(record_id=reference) + + record_id = payload.get("record_id") + if not isinstance(record_id, str) or not record_id: + return FileReference(record_id=reference) + + storage_key = payload.get("storage_key") + if storage_key is not None and not isinstance(storage_key, str): + storage_key = None + + return FileReference(record_id=record_id, storage_key=storage_key) + + +def resolve_file_record_id(reference: str | None) -> str | None: + parsed_reference = parse_file_reference(reference) + if parsed_reference is None: + return None + return parsed_reference.record_id diff --git a/api/core/workflow/human_input_compat.py b/api/core/workflow/human_input_compat.py new file mode 100644 index 0000000000..75a0a0c202 --- /dev/null +++ b/api/core/workflow/human_input_compat.py @@ -0,0 +1,299 @@ +"""Workflow-layer adapters for legacy human-input payload keys. + +Stored workflow graphs and editor payloads may still use Dify-specific human +input recipient keys. Normalize them here before handing configs to +`graphon` so graph-owned models only see graph-neutral field names. +""" + +from __future__ import annotations + +import enum +import uuid +from collections.abc import Mapping, Sequence +from typing import Annotated, Any, ClassVar, Literal + +import bleach +import markdown +from markdown.extensions.tables import TableExtension +from pydantic import AliasChoices, BaseModel, ConfigDict, Field, TypeAdapter + +from graphon.enums import BuiltinNodeTypes +from graphon.nodes.base.variable_template_parser import VariableTemplateParser +from graphon.runtime import VariablePool +from graphon.variables.consts import SELECTORS_LENGTH + + +class DeliveryMethodType(enum.StrEnum): + WEBAPP = enum.auto() + EMAIL = enum.auto() + + +class EmailRecipientType(enum.StrEnum): + BOUND = "member" + MEMBER = BOUND + EXTERNAL = "external" + + +class _InteractiveSurfaceDeliveryConfig(BaseModel): + pass + + +class BoundRecipient(BaseModel): + model_config = ConfigDict(extra="forbid") + + type: Literal[EmailRecipientType.BOUND] = EmailRecipientType.BOUND + reference_id: str + + +class ExternalRecipient(BaseModel): + model_config = ConfigDict(extra="forbid") + + type: Literal[EmailRecipientType.EXTERNAL] = EmailRecipientType.EXTERNAL + email: str + + +MemberRecipient = BoundRecipient +EmailRecipient = Annotated[BoundRecipient | ExternalRecipient, Field(discriminator="type")] + + +class EmailRecipients(BaseModel): + model_config = ConfigDict(extra="forbid") + + include_bound_group: bool = Field( + default=False, + validation_alias=AliasChoices("include_bound_group", "whole_workspace"), + ) + items: list[EmailRecipient] = Field(default_factory=list) + + +class EmailDeliveryConfig(BaseModel): + URL_PLACEHOLDER: ClassVar[str] = "{{#url#}}" + _ALLOWED_HTML_TAGS: ClassVar[list[str]] = [ + "a", + "br", + "code", + "em", + "li", + "ol", + "p", + "pre", + "strong", + "table", + "tbody", + "td", + "th", + "thead", + "tr", + "ul", + ] + _ALLOWED_HTML_ATTRIBUTES: ClassVar[dict[str, list[str]]] = { + "a": ["href", "title"], + "td": ["align"], + "th": ["align"], + } + _ALLOWED_PROTOCOLS: ClassVar[set[str]] = set(bleach.sanitizer.ALLOWED_PROTOCOLS) | {"mailto"} + + recipients: EmailRecipients + subject: str + body: str + debug_mode: bool = False + + def with_recipients(self, recipients: EmailRecipients) -> EmailDeliveryConfig: + return self.model_copy(update={"recipients": recipients}) + + @classmethod + def replace_url_placeholder(cls, body: str, url: str | None) -> str: + return body.replace(cls.URL_PLACEHOLDER, url or "") + + @classmethod + def render_body_template( + cls, + *, + body: str, + url: str | None, + variable_pool: VariablePool | None = None, + ) -> str: + templated_body = cls.replace_url_placeholder(body, url) + if variable_pool is None: + return templated_body + return variable_pool.convert_template(templated_body).text + + @classmethod + def render_markdown_body(cls, body: str) -> str: + stripped_body = bleach.clean(body, tags=[], attributes={}, strip=True) + rendered = markdown.markdown( + stripped_body, + extensions=[TableExtension(use_align_attribute=True)], + output_format="html", + ) + return bleach.clean( + rendered, + tags=cls._ALLOWED_HTML_TAGS, + attributes=cls._ALLOWED_HTML_ATTRIBUTES, + protocols=cls._ALLOWED_PROTOCOLS, + strip=True, + ) + + @staticmethod + def sanitize_subject(subject: str) -> str: + sanitized = subject.replace("\r", " ").replace("\n", " ") + sanitized = bleach.clean(sanitized, tags=[], strip=True) + return " ".join(sanitized.split()) + + +class _DeliveryMethodBase(BaseModel): + enabled: bool = True + id: uuid.UUID = Field(default_factory=uuid.uuid4) + + def extract_variable_selectors(self) -> Sequence[Sequence[str]]: + return () + + +class InteractiveSurfaceDeliveryMethod(_DeliveryMethodBase): + type: Literal[DeliveryMethodType.WEBAPP] = DeliveryMethodType.WEBAPP + config: _InteractiveSurfaceDeliveryConfig = Field(default_factory=_InteractiveSurfaceDeliveryConfig) + + +class EmailDeliveryMethod(_DeliveryMethodBase): + type: Literal[DeliveryMethodType.EMAIL] = DeliveryMethodType.EMAIL + config: EmailDeliveryConfig + + def extract_variable_selectors(self) -> Sequence[Sequence[str]]: + variable_template_parser = VariableTemplateParser(template=self.config.body) + selectors: list[Sequence[str]] = [] + for variable_selector in variable_template_parser.extract_variable_selectors(): + value_selector = list(variable_selector.value_selector) + if len(value_selector) < SELECTORS_LENGTH: + continue + selectors.append(value_selector[:SELECTORS_LENGTH]) + return selectors + + +WebAppDeliveryMethod = InteractiveSurfaceDeliveryMethod +_WebAppDeliveryConfig = _InteractiveSurfaceDeliveryConfig + +DeliveryChannelConfig = Annotated[InteractiveSurfaceDeliveryMethod | EmailDeliveryMethod, Field(discriminator="type")] + +_DELIVERY_METHODS_ADAPTER = TypeAdapter(list[DeliveryChannelConfig]) + + +def _copy_mapping(value: object) -> dict[str, Any] | None: + if isinstance(value, BaseModel): + return value.model_dump(mode="python") + if isinstance(value, Mapping): + return dict(value) + return None + + +def normalize_human_input_node_data_for_graph(node_data: Mapping[str, Any] | BaseModel) -> dict[str, Any]: + normalized = _copy_mapping(node_data) + if normalized is None: + raise TypeError(f"human-input node data must be a mapping, got {type(node_data).__name__}") + + delivery_methods = normalized.get("delivery_methods") + if not isinstance(delivery_methods, list): + return normalized + + normalized_methods: list[Any] = [] + for method in delivery_methods: + method_mapping = _copy_mapping(method) + if method_mapping is None: + normalized_methods.append(method) + continue + + config_mapping = _copy_mapping(method_mapping.get("config")) + if config_mapping is not None: + recipients_mapping = _copy_mapping(config_mapping.get("recipients")) + if recipients_mapping is not None: + config_mapping["recipients"] = _normalize_email_recipients(recipients_mapping) + method_mapping["config"] = config_mapping + + normalized_methods.append(method_mapping) + + normalized["delivery_methods"] = normalized_methods + return normalized + + +def parse_human_input_delivery_methods(node_data: Mapping[str, Any] | BaseModel) -> list[DeliveryChannelConfig]: + normalized = normalize_human_input_node_data_for_graph(node_data) + raw_delivery_methods = normalized.get("delivery_methods") + if not isinstance(raw_delivery_methods, list): + return [] + return list(_DELIVERY_METHODS_ADAPTER.validate_python(raw_delivery_methods)) + + +def is_human_input_webapp_enabled(node_data: Mapping[str, Any] | BaseModel) -> bool: + for method in parse_human_input_delivery_methods(node_data): + if method.enabled and method.type == DeliveryMethodType.WEBAPP: + return True + return False + + +def normalize_node_data_for_graph(node_data: Mapping[str, Any] | BaseModel) -> dict[str, Any]: + normalized = _copy_mapping(node_data) + if normalized is None: + raise TypeError(f"node data must be a mapping, got {type(node_data).__name__}") + + if normalized.get("type") != BuiltinNodeTypes.HUMAN_INPUT: + return normalized + return normalize_human_input_node_data_for_graph(normalized) + + +def normalize_node_config_for_graph(node_config: Mapping[str, Any] | BaseModel) -> dict[str, Any]: + normalized = _copy_mapping(node_config) + if normalized is None: + raise TypeError(f"node config must be a mapping, got {type(node_config).__name__}") + + data_mapping = _copy_mapping(normalized.get("data")) + if data_mapping is None: + return normalized + + normalized["data"] = normalize_node_data_for_graph(data_mapping) + return normalized + + +def _normalize_email_recipients(recipients: Mapping[str, Any]) -> dict[str, Any]: + normalized = dict(recipients) + + legacy_include_bound_group = normalized.pop("whole_workspace", None) + if "include_bound_group" not in normalized and legacy_include_bound_group is not None: + normalized["include_bound_group"] = legacy_include_bound_group + + items = normalized.get("items") + if not isinstance(items, list): + return normalized + + normalized_items: list[Any] = [] + for item in items: + item_mapping = _copy_mapping(item) + if item_mapping is None: + normalized_items.append(item) + continue + + legacy_reference_id = item_mapping.pop("user_id", None) + if "reference_id" not in item_mapping and legacy_reference_id is not None: + item_mapping["reference_id"] = legacy_reference_id + normalized_items.append(item_mapping) + + normalized["items"] = normalized_items + return normalized + + +__all__ = [ + "BoundRecipient", + "DeliveryChannelConfig", + "DeliveryMethodType", + "EmailDeliveryConfig", + "EmailDeliveryMethod", + "EmailRecipientType", + "EmailRecipients", + "ExternalRecipient", + "MemberRecipient", + "WebAppDeliveryMethod", + "_WebAppDeliveryConfig", + "is_human_input_webapp_enabled", + "normalize_human_input_node_data_for_graph", + "normalize_node_config_for_graph", + "normalize_node_data_for_graph", + "parse_human_input_delivery_methods", +] diff --git a/api/core/workflow/human_input_forms.py b/api/core/workflow/human_input_forms.py new file mode 100644 index 0000000000..f124b321d4 --- /dev/null +++ b/api/core/workflow/human_input_forms.py @@ -0,0 +1,55 @@ +"""Shared helpers for workflow pause-time human input form lookups. + +Both controllers and streaming response converters need the same recipient +priority when exposing resume links for paused human input forms. Keep that +selection logic here so all API surfaces stay consistent. +""" + +from __future__ import annotations + +from collections.abc import Sequence + +from sqlalchemy import select +from sqlalchemy.orm import Session + +from extensions.ext_database import db +from models.human_input import HumanInputFormRecipient, RecipientType + +_FORM_TOKEN_PRIORITY = { + RecipientType.BACKSTAGE: 0, + RecipientType.CONSOLE: 1, + RecipientType.STANDALONE_WEB_APP: 2, +} + + +def load_form_tokens_by_form_id( + form_ids: Sequence[str], + *, + session: Session | None = None, +) -> dict[str, str]: + """Load the preferred access token for each human input form.""" + unique_form_ids = list(dict.fromkeys(form_ids)) + if not unique_form_ids: + return {} + + if session is not None: + return _load_form_tokens_by_form_id(session, unique_form_ids) + + with Session(bind=db.engine, expire_on_commit=False) as new_session: + return _load_form_tokens_by_form_id(new_session, unique_form_ids) + + +def _load_form_tokens_by_form_id(session: Session, form_ids: Sequence[str]) -> dict[str, str]: + tokens_by_form_id: dict[str, tuple[int, str]] = {} + stmt = select(HumanInputFormRecipient).where(HumanInputFormRecipient.form_id.in_(form_ids)) + for recipient in session.scalars(stmt): + priority = _FORM_TOKEN_PRIORITY.get(recipient.recipient_type) + if priority is None or not recipient.access_token: + continue + + candidate = (priority, recipient.access_token) + current = tokens_by_form_id.get(recipient.form_id) + if current is None or candidate[0] < current[0]: + tokens_by_form_id[recipient.form_id] = candidate + + return {form_id: token for form_id, (_, token) in tokens_by_form_id.items()} diff --git a/api/core/workflow/node_factory.py b/api/core/workflow/node_factory.py index ab34263a79..028e38fbee 100644 --- a/api/core/workflow/node_factory.py +++ b/api/core/workflow/node_factory.py @@ -9,8 +9,8 @@ from sqlalchemy.orm import Session from typing_extensions import override from configs import dify_config -from core.app.entities.app_invoke_entities import DifyRunContext -from core.app.llm.model_access import build_dify_model_access +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext +from core.app.llm.model_access import build_dify_model_access, fetch_model_config from core.helper.code_executor.code_executor import ( CodeExecutionError, CodeExecutor, @@ -19,45 +19,48 @@ from core.helper.ssrf_proxy import ssrf_proxy from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance from core.prompt.entities.advanced_prompt_entities import MemoryConfig -from core.repositories.human_input_repository import HumanInputFormRepositoryImpl -from core.tools.tool_file_manager import ToolFileManager from core.trigger.constants import TRIGGER_NODE_TYPES +from core.workflow.human_input_compat import normalize_node_config_for_graph +from core.workflow.node_runtime import ( + DifyFileReferenceFactory, + DifyHumanInputNodeRuntime, + DifyPreparedLLM, + DifyPromptMessageSerializer, + DifyRetrieverAttachmentLoader, + DifyToolFileManager, + DifyToolNodeRuntime, + build_dify_llm_file_saver, +) from core.workflow.nodes.agent.message_transformer import AgentMessageTransformer from core.workflow.nodes.agent.plugin_strategy_adapter import ( PluginAgentStrategyPresentationProvider, PluginAgentStrategyResolver, ) from core.workflow.nodes.agent.runtime_support import AgentRuntimeSupport -from dify_graph.entities.base_node_data import BaseNodeData -from dify_graph.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter -from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY -from dify_graph.enums import BuiltinNodeTypes, NodeType, SystemVariableKey -from dify_graph.file.file_manager import file_manager -from dify_graph.graph.graph import NodeFactory -from dify_graph.model_runtime.entities.model_entities import ModelType -from dify_graph.model_runtime.memory import PromptMessageMemory -from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from dify_graph.nodes.base.node import Node -from dify_graph.nodes.code.code_node import WorkflowCodeExecutor -from dify_graph.nodes.code.entities import CodeLanguage -from dify_graph.nodes.code.limits import CodeNodeLimits -from dify_graph.nodes.document_extractor import UnstructuredApiConfig -from dify_graph.nodes.http_request import build_http_request_config -from dify_graph.nodes.llm.entities import LLMNodeData -from dify_graph.nodes.llm.exc import LLMModeRequiredError, ModelNotExistError -from dify_graph.nodes.llm.protocols import TemplateRenderer -from dify_graph.nodes.parameter_extractor.entities import ParameterExtractorNodeData -from dify_graph.nodes.question_classifier.entities import QuestionClassifierNodeData -from dify_graph.nodes.template_transform.template_renderer import ( - CodeExecutorJinja2TemplateRenderer, -) -from dify_graph.variables.segments import StringSegment +from core.workflow.system_variables import SystemVariableKey, get_system_text, system_variable_selector +from core.workflow.template_rendering import CodeExecutorJinja2TemplateRenderer from extensions.ext_database import db +from graphon.entities.base_node_data import BaseNodeData +from graphon.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter +from graphon.enums import BuiltinNodeTypes, NodeType +from graphon.file.file_manager import file_manager +from graphon.graph.graph import NodeFactory +from graphon.model_runtime.memory import PromptMessageMemory +from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from graphon.nodes.base.node import Node +from graphon.nodes.code.code_node import WorkflowCodeExecutor +from graphon.nodes.code.entities import CodeLanguage +from graphon.nodes.code.limits import CodeNodeLimits +from graphon.nodes.document_extractor import UnstructuredApiConfig +from graphon.nodes.http_request import build_http_request_config +from graphon.nodes.llm.entities import LLMNodeData +from graphon.nodes.parameter_extractor.entities import ParameterExtractorNodeData +from graphon.nodes.question_classifier.entities import QuestionClassifierNodeData from models.model import Conversation if TYPE_CHECKING: - from dify_graph.entities import GraphInitParams - from dify_graph.runtime import GraphRuntimeState + from graphon.entities import GraphInitParams + from graphon.runtime import GraphRuntimeState LATEST_VERSION = "latest" _START_NODE_TYPES: frozenset[NodeType] = frozenset( @@ -76,7 +79,7 @@ def _import_node_package(package_name: str, *, excluded_modules: frozenset[str] @lru_cache(maxsize=1) def register_nodes() -> None: """Import production node modules so they self-register with ``Node``.""" - _import_node_package("dify_graph.nodes") + _import_node_package("graphon.nodes") _import_node_package("core.workflow.nodes") @@ -84,7 +87,7 @@ def get_node_type_classes_mapping() -> Mapping[NodeType, Mapping[str, type[Node] """Return a read-only snapshot of the current production node registry. The workflow layer owns node bootstrap because it must compose built-in - `dify_graph.nodes.*` implementations with workflow-local nodes under + `graphon.nodes.*` implementations with workflow-local nodes under `core.workflow.nodes.*`. Keeping this import side effect here avoids reintroducing registry bootstrapping into lower-level graph primitives. """ @@ -115,7 +118,7 @@ def get_default_root_node_id(graph_config: Mapping[str, Any]) -> str: This workflow-layer helper depends on start-node semantics defined by `is_start_node_type`, so it intentionally lives next to the node registry - instead of in the raw `dify_graph.entities.graph_config` schema module. + instead of in the raw `graphon.entities.graph_config` schema module. """ nodes = graph_config.get("nodes") if not isinstance(nodes, list): @@ -229,16 +232,6 @@ class DefaultWorkflowCodeExecutor: return isinstance(error, CodeExecutionError) -class DefaultLLMTemplateRenderer(TemplateRenderer): - def render_jinja2(self, *, template: str, inputs: Mapping[str, Any]) -> str: - result = CodeExecutor.execute_workflow_code_template( - language=CodeLanguage.JINJA2, - code=template, - inputs=inputs, - ) - return str(result.get("result", "")) - - @final class DifyNodeFactory(NodeFactory): """ @@ -264,11 +257,31 @@ class DifyNodeFactory(NodeFactory): max_string_array_length=dify_config.CODE_MAX_STRING_ARRAY_LENGTH, max_object_array_length=dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH, ) - self._template_renderer = CodeExecutorJinja2TemplateRenderer(code_executor=self._code_executor) - self._llm_template_renderer: TemplateRenderer = DefaultLLMTemplateRenderer() + self._jinja2_template_renderer = CodeExecutorJinja2TemplateRenderer() self._template_transform_max_output_length = dify_config.TEMPLATE_TRANSFORM_MAX_LENGTH self._http_request_http_client = ssrf_proxy - self._http_request_tool_file_manager_factory = ToolFileManager + self._bound_tool_file_manager_factory = lambda: DifyToolFileManager( + self._dify_context, + conversation_id_getter=self._conversation_id, + ) + self._file_reference_factory = DifyFileReferenceFactory(self._dify_context) + self._prompt_message_serializer = DifyPromptMessageSerializer() + self._retriever_attachment_loader = DifyRetrieverAttachmentLoader( + file_reference_factory=self._file_reference_factory, + ) + self._llm_file_saver = build_dify_llm_file_saver( + run_context=self._dify_context, + http_client=self._http_request_http_client, + conversation_id_getter=self._conversation_id, + ) + self._human_input_runtime = DifyHumanInputNodeRuntime( + self._dify_context, + workflow_execution_id_getter=lambda: get_system_text( + self.graph_runtime_state.variable_pool, + SystemVariableKey.WORKFLOW_EXECUTION_ID, + ), + ) + self._tool_runtime = DifyToolNodeRuntime(self._dify_context) self._http_request_file_manager = file_manager self._document_extractor_unstructured_api_config = UnstructuredApiConfig( api_url=dify_config.UNSTRUCTURED_API_URL, @@ -284,7 +297,7 @@ class DifyNodeFactory(NodeFactory): ssrf_default_max_retries=dify_config.SSRF_DEFAULT_MAX_RETRIES, ) - self._llm_credentials_provider, self._llm_model_factory = build_dify_model_access(self._dify_context.tenant_id) + self._llm_credentials_provider, self._llm_model_factory = build_dify_model_access(self._dify_context) self._agent_strategy_resolver = PluginAgentStrategyResolver() self._agent_strategy_presentation_provider = PluginAgentStrategyPresentationProvider() self._agent_runtime_support = AgentRuntimeSupport() @@ -299,6 +312,9 @@ class DifyNodeFactory(NodeFactory): return raw_ctx return DifyRunContext.model_validate(raw_ctx) + def _conversation_id(self) -> str | None: + return get_system_text(self.graph_runtime_state.variable_pool, SystemVariableKey.CONVERSATION_ID) + @override def create_node(self, node_config: dict[str, Any] | NodeConfigDict) -> Node: """ @@ -310,7 +326,7 @@ class DifyNodeFactory(NodeFactory): (including pydantic ValidationError, which subclasses ValueError), if node type is unknown, or if no implementation exists for the resolved version """ - typed_node_config = NodeConfigDictAdapter.validate_python(node_config) + typed_node_config = NodeConfigDictAdapter.validate_python(normalize_node_config_for_graph(node_config)) node_id = typed_node_config["id"] node_data = typed_node_config["data"] node_class = self._resolve_node_class(node_type=node_data.type, node_version=str(node_data.version)) @@ -321,22 +337,29 @@ class DifyNodeFactory(NodeFactory): "code_limits": self._code_limits, }, BuiltinNodeTypes.TEMPLATE_TRANSFORM: lambda: { - "template_renderer": self._template_renderer, + "jinja2_template_renderer": self._jinja2_template_renderer, "max_output_length": self._template_transform_max_output_length, }, BuiltinNodeTypes.HTTP_REQUEST: lambda: { "http_request_config": self._http_request_config, "http_client": self._http_request_http_client, - "tool_file_manager_factory": self._http_request_tool_file_manager_factory, + "tool_file_manager_factory": self._bound_tool_file_manager_factory, "file_manager": self._http_request_file_manager, + "file_reference_factory": self._file_reference_factory, }, BuiltinNodeTypes.HUMAN_INPUT: lambda: { - "form_repository": HumanInputFormRepositoryImpl(tenant_id=self._dify_context.tenant_id), + "runtime": self._human_input_runtime, + "form_repository": self._human_input_runtime.build_form_repository(), }, BuiltinNodeTypes.LLM: lambda: self._build_llm_compatible_node_init_kwargs( node_class=node_class, node_data=node_data, + wrap_model_instance=True, include_http_client=True, + include_llm_file_saver=True, + include_prompt_message_serializer=True, + include_retriever_attachment_loader=True, + include_jinja2_template_renderer=True, ), BuiltinNodeTypes.DOCUMENT_EXTRACTOR: lambda: { "unstructured_api_config": self._document_extractor_unstructured_api_config, @@ -345,15 +368,26 @@ class DifyNodeFactory(NodeFactory): BuiltinNodeTypes.QUESTION_CLASSIFIER: lambda: self._build_llm_compatible_node_init_kwargs( node_class=node_class, node_data=node_data, + wrap_model_instance=True, include_http_client=True, + include_llm_file_saver=True, + include_prompt_message_serializer=True, + include_retriever_attachment_loader=False, + include_jinja2_template_renderer=False, ), BuiltinNodeTypes.PARAMETER_EXTRACTOR: lambda: self._build_llm_compatible_node_init_kwargs( node_class=node_class, node_data=node_data, + wrap_model_instance=True, include_http_client=False, + include_llm_file_saver=False, + include_prompt_message_serializer=True, + include_retriever_attachment_loader=False, + include_jinja2_template_renderer=False, ), BuiltinNodeTypes.TOOL: lambda: { - "tool_file_manager_factory": self._http_request_tool_file_manager_factory(), + "tool_file_manager_factory": self._bound_tool_file_manager_factory(), + "runtime": self._tool_runtime, }, BuiltinNodeTypes.AGENT: lambda: { "strategy_resolver": self._agent_strategy_resolver, @@ -387,7 +421,12 @@ class DifyNodeFactory(NodeFactory): *, node_class: type[Node], node_data: BaseNodeData, + wrap_model_instance: bool, include_http_client: bool, + include_llm_file_saver: bool, + include_prompt_message_serializer: bool, + include_retriever_attachment_loader: bool, + include_jinja2_template_renderer: bool, ) -> dict[str, object]: validated_node_data = cast( LLMCompatibleNodeData, @@ -397,49 +436,35 @@ class DifyNodeFactory(NodeFactory): node_init_kwargs: dict[str, object] = { "credentials_provider": self._llm_credentials_provider, "model_factory": self._llm_model_factory, - "model_instance": model_instance, + "model_instance": DifyPreparedLLM(model_instance) if wrap_model_instance else model_instance, "memory": self._build_memory_for_llm_node( node_data=validated_node_data, model_instance=model_instance, ), } - if validated_node_data.type in {BuiltinNodeTypes.LLM, BuiltinNodeTypes.QUESTION_CLASSIFIER}: - node_init_kwargs["template_renderer"] = self._llm_template_renderer + if validated_node_data.type == BuiltinNodeTypes.QUESTION_CLASSIFIER: + node_init_kwargs["template_renderer"] = self._jinja2_template_renderer if include_http_client: node_init_kwargs["http_client"] = self._http_request_http_client + if include_llm_file_saver: + node_init_kwargs["llm_file_saver"] = self._llm_file_saver + if include_prompt_message_serializer: + node_init_kwargs["prompt_message_serializer"] = self._prompt_message_serializer + if include_retriever_attachment_loader: + node_init_kwargs["retriever_attachment_loader"] = self._retriever_attachment_loader + if include_jinja2_template_renderer: + node_init_kwargs["jinja2_template_renderer"] = self._jinja2_template_renderer + if validated_node_data.type == BuiltinNodeTypes.LLM: + node_init_kwargs["default_query_selector"] = system_variable_selector(SystemVariableKey.QUERY) return node_init_kwargs def _build_model_instance_for_llm_node(self, node_data: LLMCompatibleNodeData) -> ModelInstance: node_data_model = node_data.model - if not node_data_model.mode: - raise LLMModeRequiredError("LLM mode is required.") - - credentials = self._llm_credentials_provider.fetch(node_data_model.provider, node_data_model.name) - model_instance = self._llm_model_factory.init_model_instance(node_data_model.provider, node_data_model.name) - provider_model_bundle = model_instance.provider_model_bundle - - provider_model = provider_model_bundle.configuration.get_provider_model( - model=node_data_model.name, - model_type=ModelType.LLM, + model_instance, _ = fetch_model_config( + node_data_model=node_data_model, + credentials_provider=self._llm_credentials_provider, + model_factory=self._llm_model_factory, ) - if provider_model is None: - raise ModelNotExistError(f"Model {node_data_model.name} not exist.") - provider_model.raise_for_status() - - completion_params = dict(node_data_model.completion_params) - stop = completion_params.pop("stop", []) - if not isinstance(stop, list): - stop = [] - - model_schema = model_instance.model_type_instance.get_model_schema(node_data_model.name, credentials) - if not model_schema: - raise ModelNotExistError(f"Model {node_data_model.name} not exist.") - - model_instance.provider = node_data_model.provider - model_instance.model_name = node_data_model.name - model_instance.credentials = credentials - model_instance.parameters = completion_params - model_instance.stop = tuple(stop) model_instance.model_type_instance = cast(LargeLanguageModel, model_instance.model_type_instance) return model_instance @@ -452,12 +477,7 @@ class DifyNodeFactory(NodeFactory): if node_data.memory is None: return None - conversation_id_variable = self.graph_runtime_state.variable_pool.get( - ["sys", SystemVariableKey.CONVERSATION_ID] - ) - conversation_id = ( - conversation_id_variable.value if isinstance(conversation_id_variable, StringSegment) else None - ) + conversation_id = get_system_text(self.graph_runtime_state.variable_pool, SystemVariableKey.CONVERSATION_ID) return fetch_memory( conversation_id=conversation_id, app_id=self._dify_context.app_id, diff --git a/api/core/workflow/node_runtime.py b/api/core/workflow/node_runtime.py new file mode 100644 index 0000000000..2e632e56f0 --- /dev/null +++ b/api/core/workflow/node_runtime.py @@ -0,0 +1,670 @@ +from __future__ import annotations + +from collections.abc import Callable, Generator, Mapping, Sequence +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, cast + +from sqlalchemy import select +from sqlalchemy.orm import Session + +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext +from core.app.file_access import DatabaseFileAccessController +from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler +from core.llm_generator.output_parser.errors import OutputParserError +from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output +from core.model_manager import ModelInstance +from core.plugin.impl.exc import PluginDaemonClientSideError, PluginInvokeError +from core.plugin.impl.plugin import PluginInstaller +from core.prompt.utils.prompt_message_util import PromptMessageUtil +from core.repositories.human_input_repository import ( + FormCreateParams, + HumanInputFormRepository, + HumanInputFormRepositoryImpl, +) +from core.tools.entities.tool_entities import ToolProviderType as CoreToolProviderType +from core.tools.errors import ToolInvokeError +from core.tools.tool_engine import ToolEngine +from core.tools.tool_file_manager import ToolFileManager +from core.tools.tool_manager import ToolManager +from core.tools.utils.message_transformer import ToolFileMessageTransformer +from core.workflow.file_reference import build_file_reference +from extensions.ext_database import db +from factories import file_factory +from graphon.file import FileTransferMethod, FileType +from graphon.model_runtime.entities import LLMMode +from graphon.model_runtime.entities.llm_entities import ( + LLMResult, + LLMResultChunk, + LLMResultChunkWithStructuredOutput, + LLMResultWithStructuredOutput, + LLMUsage, +) +from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool +from graphon.model_runtime.entities.model_entities import AIModelEntity +from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from graphon.nodes.human_input.entities import HumanInputNodeData +from graphon.nodes.llm.runtime_protocols import ( + PreparedLLMProtocol, + PromptMessageSerializerProtocol, + RetrieverAttachmentLoaderProtocol, +) +from graphon.nodes.protocols import FileReferenceFactoryProtocol, HttpClientProtocol, ToolFileManagerProtocol +from graphon.nodes.runtime import ( + HumanInputFormStateProtocol, + HumanInputNodeRuntimeProtocol, + ToolNodeRuntimeProtocol, +) +from graphon.nodes.tool.exc import ToolNodeError, ToolRuntimeInvocationError, ToolRuntimeResolutionError +from graphon.nodes.tool_runtime_entities import ( + ToolRuntimeHandle, + ToolRuntimeMessage, + ToolRuntimeParameter, +) +from models.dataset import SegmentAttachmentBinding +from models.model import UploadFile +from services.tools.builtin_tools_manage_service import BuiltinToolManageService + +from .human_input_compat import ( + BoundRecipient, + DeliveryChannelConfig, + DeliveryMethodType, + EmailDeliveryMethod, + EmailRecipients, + is_human_input_webapp_enabled, + parse_human_input_delivery_methods, +) +from .system_variables import SystemVariableKey, get_system_text + +if TYPE_CHECKING: + from core.tools.__base.tool import Tool + from core.tools.entities.tool_entities import ToolInvokeMessage as CoreToolInvokeMessage + from graphon.file import File + from graphon.nodes.llm.file_saver import LLMFileSaver + from graphon.nodes.tool.entities import ToolNodeData + + +_file_access_controller = DatabaseFileAccessController() + + +def resolve_dify_run_context(run_context: Mapping[str, Any] | DifyRunContext) -> DifyRunContext: + if isinstance(run_context, DifyRunContext): + return run_context + + raw_ctx = run_context.get(DIFY_RUN_CONTEXT_KEY) + if raw_ctx is None: + raise ValueError(f"run_context missing required key: {DIFY_RUN_CONTEXT_KEY}") + if isinstance(raw_ctx, DifyRunContext): + return raw_ctx + return DifyRunContext.model_validate(raw_ctx) + + +def apply_dify_debug_email_recipient( + method: DeliveryChannelConfig, + *, + enabled: bool, + actor_id: str | None, +) -> DeliveryChannelConfig: + """Apply the Dify debugger-specific email recipient override outside `graphon`.""" + if not enabled: + return method + if not isinstance(method, EmailDeliveryMethod): + return method + if not method.config.debug_mode: + return method + + if actor_id is None: + debug_recipients = EmailRecipients(include_bound_group=False, items=[]) + else: + debug_recipients = EmailRecipients( + include_bound_group=False, + items=[BoundRecipient(reference_id=actor_id)], + ) + debug_config = method.config.with_recipients(debug_recipients) + return method.model_copy(update={"config": debug_config}) + + +class DifyFileReferenceFactory(FileReferenceFactoryProtocol): + def __init__(self, run_context: Mapping[str, Any] | DifyRunContext) -> None: + self._run_context = resolve_dify_run_context(run_context) + + def build_from_mapping(self, *, mapping: Mapping[str, Any]): + return file_factory.build_from_mapping( + mapping=mapping, + tenant_id=self._run_context.tenant_id, + access_controller=_file_access_controller, + ) + + +class DifyPreparedLLM(PreparedLLMProtocol): + """Workflow-layer adapter that hides the full `ModelInstance` API from `graphon` nodes.""" + + def __init__(self, model_instance: ModelInstance) -> None: + self._model_instance = model_instance + + @property + def provider(self) -> str: + return self._model_instance.provider + + @property + def model_name(self) -> str: + return self._model_instance.model_name + + @property + def parameters(self) -> Mapping[str, Any]: + return self._model_instance.parameters + + @parameters.setter + def parameters(self, value: Mapping[str, Any]) -> None: + self._model_instance.parameters = value + + @property + def stop(self) -> Sequence[str] | None: + return self._model_instance.stop + + def get_model_schema(self) -> AIModelEntity: + model_schema = cast(LargeLanguageModel, self._model_instance.model_type_instance).get_model_schema( + self._model_instance.model_name, + self._model_instance.credentials, + ) + if model_schema is None: + raise ValueError(f"Model schema not found for {self._model_instance.model_name}") + return model_schema + + def get_llm_num_tokens(self, prompt_messages: Sequence[PromptMessage]) -> int: + return self._model_instance.get_llm_num_tokens(prompt_messages) + + def invoke_llm( + self, + *, + prompt_messages: Sequence[PromptMessage], + model_parameters: Mapping[str, Any], + tools: Sequence[PromptMessageTool] | None, + stop: Sequence[str] | None, + stream: bool, + ) -> LLMResult | Generator[LLMResultChunk, None, None]: + return self._model_instance.invoke_llm( + prompt_messages=list(prompt_messages), + model_parameters=dict(model_parameters), + tools=list(tools or []), + stop=list(stop or []), + stream=stream, + ) + + def invoke_llm_with_structured_output( + self, + *, + prompt_messages: Sequence[PromptMessage], + json_schema: Mapping[str, Any], + model_parameters: Mapping[str, Any], + stop: Sequence[str] | None, + stream: bool, + ) -> LLMResultWithStructuredOutput | Generator[LLMResultChunkWithStructuredOutput, None, None]: + return invoke_llm_with_structured_output( + provider=self.provider, + model_schema=self.get_model_schema(), + model_instance=self._model_instance, + prompt_messages=prompt_messages, + json_schema=json_schema, + model_parameters=model_parameters, + stop=list(stop or []), + stream=stream, + ) + + def is_structured_output_parse_error(self, error: Exception) -> bool: + return isinstance(error, OutputParserError) + + +class DifyPromptMessageSerializer(PromptMessageSerializerProtocol): + def serialize( + self, + *, + model_mode: LLMMode, + prompt_messages: Sequence[PromptMessage], + ) -> Any: + return PromptMessageUtil.prompt_messages_to_prompt_for_saving( + model_mode=model_mode, + prompt_messages=prompt_messages, + ) + + +class DifyRetrieverAttachmentLoader(RetrieverAttachmentLoaderProtocol): + """Resolve retriever attachments through Dify persistence and return graph file references.""" + + def __init__(self, *, file_reference_factory: FileReferenceFactoryProtocol) -> None: + self._file_reference_factory = file_reference_factory + + def load(self, *, segment_id: str) -> Sequence[File]: + with Session(db.engine, expire_on_commit=False) as session: + attachments_with_bindings = session.execute( + select(SegmentAttachmentBinding, UploadFile) + .join(UploadFile, UploadFile.id == SegmentAttachmentBinding.attachment_id) + .where(SegmentAttachmentBinding.segment_id == segment_id) + ).all() + + return [ + self._file_reference_factory.build_from_mapping( + mapping={ + "id": upload_file.id, + "filename": upload_file.name, + "extension": "." + upload_file.extension, + "mime_type": upload_file.mime_type, + "type": FileType.IMAGE, + "transfer_method": FileTransferMethod.LOCAL_FILE, + "remote_url": upload_file.source_url, + "reference": build_file_reference(record_id=str(upload_file.id)), + "size": upload_file.size, + } + ) + for _, upload_file in attachments_with_bindings + ] + + +class DifyToolFileManager(ToolFileManagerProtocol): + """Workflow adapter that resolves conversation scope outside `graphon`.""" + + _conversation_id_getter: Callable[[], str | None] | None + + def __init__( + self, + run_context: Mapping[str, Any] | DifyRunContext, + *, + conversation_id_getter: Callable[[], str | None] | None = None, + ) -> None: + self._run_context = resolve_dify_run_context(run_context) + self._manager = ToolFileManager() + self._conversation_id_getter = conversation_id_getter + + def create_file_by_raw( + self, + *, + file_binary: bytes, + mimetype: str, + filename: str | None = None, + ) -> Any: + conversation_id = self._conversation_id_getter() if self._conversation_id_getter is not None else None + return self._manager.create_file_by_raw( + user_id=self._run_context.user_id, + tenant_id=self._run_context.tenant_id, + conversation_id=conversation_id, + file_binary=file_binary, + mimetype=mimetype, + filename=filename, + ) + + def get_file_generator_by_tool_file_id(self, tool_file_id: str): + return self._manager.get_file_generator_by_tool_file_id(tool_file_id) + + +@dataclass(frozen=True, slots=True) +class _WorkflowToolRuntimeSpec: + provider_type: CoreToolProviderType + provider_id: str + tool_name: str + tool_configurations: dict[str, Any] + credential_id: str | None = None + + +@dataclass(frozen=True, slots=True) +class _WorkflowToolRuntimeBinding: + """Workflow-private runtime state stored inside the opaque graph handle. + + The binding keeps conversation scope in `core.workflow` while `graphon` + continues to treat the handle as an opaque token. + """ + + tool: Tool + conversation_id: str | None = None + + +class DifyToolNodeRuntime(ToolNodeRuntimeProtocol): + def __init__(self, run_context: Mapping[str, Any] | DifyRunContext) -> None: + self._run_context = resolve_dify_run_context(run_context) + self._file_reference_factory = DifyFileReferenceFactory(self._run_context) + + @property + def file_reference_factory(self) -> FileReferenceFactoryProtocol: + return self._file_reference_factory + + def build_file_reference(self, *, mapping: Mapping[str, Any]): + return self._file_reference_factory.build_from_mapping(mapping=mapping) + + def get_runtime( + self, + *, + node_id: str, + node_data: ToolNodeData, + variable_pool, + ) -> ToolRuntimeHandle: + try: + tool_runtime = ToolManager.get_workflow_tool_runtime( + self._run_context.tenant_id, + self._run_context.app_id, + node_id, + self._build_tool_runtime_spec(node_data), + self._run_context.user_id, + self._run_context.invoke_from, + variable_pool, + ) + except ToolNodeError: + raise + except Exception as exc: + raise ToolRuntimeResolutionError(str(exc)) from exc + + conversation_id = ( + None if variable_pool is None else get_system_text(variable_pool, SystemVariableKey.CONVERSATION_ID) + ) + return ToolRuntimeHandle(raw=_WorkflowToolRuntimeBinding(tool=tool_runtime, conversation_id=conversation_id)) + + def get_runtime_parameters( + self, + *, + tool_runtime: ToolRuntimeHandle, + ) -> Sequence[ToolRuntimeParameter]: + tool = self._tool_from_handle(tool_runtime) + return [ + ToolRuntimeParameter(name=parameter.name, required=parameter.required) + for parameter in (tool.get_merged_runtime_parameters() or []) + ] + + def invoke( + self, + *, + tool_runtime: ToolRuntimeHandle, + tool_parameters: Mapping[str, Any], + workflow_call_depth: int, + provider_name: str, + ) -> Generator[ToolRuntimeMessage, None, None]: + runtime_binding = self._binding_from_handle(tool_runtime) + tool = runtime_binding.tool + callback = DifyWorkflowCallbackHandler() + + try: + messages = ToolEngine.generic_invoke( + tool=tool, + tool_parameters=dict(tool_parameters), + user_id=self._run_context.user_id, + workflow_tool_callback=callback, + workflow_call_depth=workflow_call_depth, + app_id=self._run_context.app_id, + conversation_id=runtime_binding.conversation_id, + ) + except Exception as exc: + raise self._map_invocation_exception(exc, provider_name=provider_name) from exc + + transformed_messages = ToolFileMessageTransformer.transform_tool_invoke_messages( + messages=messages, + user_id=self._run_context.user_id, + tenant_id=self._run_context.tenant_id, + conversation_id=runtime_binding.conversation_id, + ) + + return self._adapt_messages(transformed_messages, provider_name=provider_name) + + def get_usage( + self, + *, + tool_runtime: ToolRuntimeHandle, + ) -> LLMUsage: + latest = getattr(self._binding_from_handle(tool_runtime).tool, "latest_usage", None) + if isinstance(latest, LLMUsage): + return latest + if isinstance(latest, dict): + return LLMUsage.model_validate(latest) + return LLMUsage.empty_usage() + + def resolve_provider_icons( + self, + *, + provider_name: str, + default_icon: str | None = None, + ) -> tuple[str | Mapping[str, str] | None, str | Mapping[str, str] | None]: + icon: str | Mapping[str, str] | None = default_icon + icon_dark: str | Mapping[str, str] | None = None + + manager = PluginInstaller() + plugins = manager.list_plugins(self._run_context.tenant_id) + try: + current_plugin = next(plugin for plugin in plugins if f"{plugin.plugin_id}/{plugin.name}" == provider_name) + icon = current_plugin.declaration.icon + except StopIteration: + pass + + try: + builtin_tool = next( + provider + for provider in BuiltinToolManageService.list_builtin_tools( + self._run_context.user_id, + self._run_context.tenant_id, + ) + if provider.name == provider_name + ) + icon = builtin_tool.icon + icon_dark = builtin_tool.icon_dark + except StopIteration: + pass + + return icon, icon_dark + + @staticmethod + def _tool_from_handle(tool_runtime: ToolRuntimeHandle) -> Tool: + return DifyToolNodeRuntime._binding_from_handle(tool_runtime).tool + + @staticmethod + def _binding_from_handle(tool_runtime: ToolRuntimeHandle) -> _WorkflowToolRuntimeBinding: + if isinstance(tool_runtime.raw, _WorkflowToolRuntimeBinding): + return tool_runtime.raw + return _WorkflowToolRuntimeBinding(tool=cast("Tool", tool_runtime.raw)) + + @staticmethod + def _build_tool_runtime_spec(node_data: ToolNodeData) -> _WorkflowToolRuntimeSpec: + return _WorkflowToolRuntimeSpec( + provider_type=CoreToolProviderType(node_data.provider_type.value), + provider_id=node_data.provider_id, + tool_name=node_data.tool_name, + tool_configurations=dict(node_data.tool_configurations), + credential_id=node_data.credential_id, + ) + + def _adapt_messages( + self, + messages: Generator[CoreToolInvokeMessage, None, None], + *, + provider_name: str, + ) -> Generator[ToolRuntimeMessage, None, None]: + try: + for message in messages: + yield self._convert_message(message) + except Exception as exc: + raise self._map_invocation_exception(exc, provider_name=provider_name) from exc + + def _convert_message(self, message: CoreToolInvokeMessage) -> ToolRuntimeMessage: + graph_message_type = ToolRuntimeMessage.MessageType(message.type.value) + graph_message = self._convert_message_payload(message.message) + graph_meta = message.meta.copy() if message.meta is not None else None + return ToolRuntimeMessage(type=graph_message_type, message=graph_message, meta=graph_meta) + + def _convert_message_payload( + self, + message: CoreToolInvokeMessage.TextMessage + | CoreToolInvokeMessage.JsonMessage + | CoreToolInvokeMessage.BlobChunkMessage + | CoreToolInvokeMessage.BlobMessage + | CoreToolInvokeMessage.LogMessage + | CoreToolInvokeMessage.FileMessage + | CoreToolInvokeMessage.VariableMessage + | CoreToolInvokeMessage.RetrieverResourceMessage + | None, + ) -> ( + ToolRuntimeMessage.TextMessage + | ToolRuntimeMessage.JsonMessage + | ToolRuntimeMessage.BlobChunkMessage + | ToolRuntimeMessage.BlobMessage + | ToolRuntimeMessage.LogMessage + | ToolRuntimeMessage.FileMessage + | ToolRuntimeMessage.VariableMessage + | ToolRuntimeMessage.RetrieverResourceMessage + | None + ): + if message is None: + return None + + from core.tools.entities.tool_entities import ToolInvokeMessage as CoreToolInvokeMessage + + if isinstance(message, CoreToolInvokeMessage.TextMessage): + return ToolRuntimeMessage.TextMessage(text=message.text) + if isinstance(message, CoreToolInvokeMessage.JsonMessage): + return ToolRuntimeMessage.JsonMessage( + json_object=message.json_object, + suppress_output=message.suppress_output, + ) + if isinstance(message, CoreToolInvokeMessage.BlobMessage): + return ToolRuntimeMessage.BlobMessage(blob=message.blob) + if isinstance(message, CoreToolInvokeMessage.BlobChunkMessage): + return ToolRuntimeMessage.BlobChunkMessage( + id=message.id, + sequence=message.sequence, + total_length=message.total_length, + blob=message.blob, + end=message.end, + ) + if isinstance(message, CoreToolInvokeMessage.FileMessage): + return ToolRuntimeMessage.FileMessage(file_marker=message.file_marker) + if isinstance(message, CoreToolInvokeMessage.VariableMessage): + return ToolRuntimeMessage.VariableMessage( + variable_name=message.variable_name, + variable_value=message.variable_value, + stream=message.stream, + ) + if isinstance(message, CoreToolInvokeMessage.LogMessage): + return ToolRuntimeMessage.LogMessage( + id=message.id, + label=message.label, + parent_id=message.parent_id, + error=message.error, + status=ToolRuntimeMessage.LogMessage.LogStatus(message.status.value), + data=dict(message.data), + metadata=dict(message.metadata), + ) + if isinstance(message, CoreToolInvokeMessage.RetrieverResourceMessage): + retriever_resources = [ + resource.model_dump() if hasattr(resource, "model_dump") else dict(resource) + for resource in message.retriever_resources + ] + return ToolRuntimeMessage.RetrieverResourceMessage( + retriever_resources=retriever_resources, + context=message.context, + ) + + raise TypeError(f"unsupported tool message payload: {type(message).__name__}") + + @staticmethod + def _map_invocation_exception(exc: Exception, *, provider_name: str) -> ToolNodeError: + if isinstance(exc, ToolNodeError): + return exc + if isinstance(exc, PluginInvokeError): + return ToolRuntimeInvocationError(exc.to_user_friendly_error(plugin_name=provider_name)) + if isinstance(exc, PluginDaemonClientSideError): + return ToolRuntimeInvocationError(f"Failed to invoke tool, error: {exc.description}") + if isinstance(exc, ToolInvokeError): + return ToolRuntimeInvocationError(f"Failed to invoke tool {provider_name}: {exc}") + return ToolRuntimeInvocationError(str(exc)) + + +class DifyHumanInputNodeRuntime(HumanInputNodeRuntimeProtocol): + def __init__( + self, + run_context: Mapping[str, Any] | DifyRunContext, + *, + workflow_execution_id_getter: Callable[[], str | None] | None = None, + form_repository: HumanInputFormRepository | None = None, + ) -> None: + self._run_context = resolve_dify_run_context(run_context) + self._workflow_execution_id_getter = workflow_execution_id_getter + self._form_repository = form_repository + + def _invoke_source(self) -> str: + invoke_from = self._run_context.invoke_from + if isinstance(invoke_from, str): + return invoke_from + return str(getattr(invoke_from, "value", invoke_from)) + + def _resolve_delivery_methods(self, *, node_data: HumanInputNodeData) -> Sequence[DeliveryChannelConfig]: + invoke_source = self._invoke_source() + methods = [method for method in parse_human_input_delivery_methods(node_data) if method.enabled] + if invoke_source in {"debugger", "explore"}: + methods = [method for method in methods if method.type != DeliveryMethodType.WEBAPP] + return [ + apply_dify_debug_email_recipient( + method, + enabled=invoke_source == "debugger", + actor_id=self._run_context.user_id, + ) + for method in methods + ] + + def _display_in_ui(self, *, node_data: HumanInputNodeData) -> bool: + if self._invoke_source() == "debugger": + return True + return is_human_input_webapp_enabled(node_data) + + def build_form_repository(self) -> HumanInputFormRepository: + if self._form_repository is not None: + return self._form_repository + + return self._build_form_repository() + + def _build_form_repository(self) -> HumanInputFormRepository: + invoke_source = self._invoke_source() + return HumanInputFormRepositoryImpl( + tenant_id=self._run_context.tenant_id, + app_id=self._run_context.app_id, + workflow_execution_id=self._workflow_execution_id_getter() if self._workflow_execution_id_getter else None, + invoke_source=invoke_source, + submission_actor_id=self._run_context.user_id if invoke_source in {"debugger", "explore"} else None, + ) + + def with_form_repository(self, form_repository: HumanInputFormRepository) -> DifyHumanInputNodeRuntime: + return DifyHumanInputNodeRuntime( + self._run_context, + workflow_execution_id_getter=self._workflow_execution_id_getter, + form_repository=form_repository, + ) + + def get_form(self, *, node_id: str) -> HumanInputFormStateProtocol | None: + repo = self.build_form_repository() + return repo.get_form(node_id) + + def create_form( + self, + *, + node_id: str, + node_data: HumanInputNodeData, + rendered_content: str, + resolved_default_values: Mapping[str, Any], + ) -> HumanInputFormStateProtocol: + repo = self.build_form_repository() + params = FormCreateParams( + workflow_execution_id=self._workflow_execution_id_getter() if self._workflow_execution_id_getter else None, + node_id=node_id, + form_config=node_data, + rendered_content=rendered_content, + delivery_methods=self._resolve_delivery_methods(node_data=node_data), + display_in_ui=self._display_in_ui(node_data=node_data), + resolved_default_values=resolved_default_values, + ) + return repo.create_form(params) + + +def build_dify_llm_file_saver( + *, + run_context: Mapping[str, Any] | DifyRunContext, + http_client: HttpClientProtocol, + conversation_id_getter: Callable[[], str | None] | None = None, +) -> LLMFileSaver: + from graphon.nodes.llm.file_saver import FileSaverImpl + + return FileSaverImpl( + tool_file_manager=DifyToolFileManager(run_context, conversation_id_getter=conversation_id_getter), + file_reference_factory=DifyFileReferenceFactory(run_context), + http_client=http_client, + ) diff --git a/api/core/workflow/nodes/agent/agent_node.py b/api/core/workflow/nodes/agent/agent_node.py index 5699ccf404..7b000101b0 100644 --- a/api/core/workflow/nodes/agent/agent_node.py +++ b/api/core/workflow/nodes/agent/agent_node.py @@ -3,11 +3,13 @@ from __future__ import annotations from collections.abc import Generator, Mapping, Sequence from typing import TYPE_CHECKING, Any -from dify_graph.entities.graph_config import NodeConfigDict -from dify_graph.enums import BuiltinNodeTypes, SystemVariableKey, WorkflowNodeExecutionStatus -from dify_graph.node_events import NodeEventBase, NodeRunResult, StreamCompletedEvent -from dify_graph.nodes.base.node import Node -from dify_graph.nodes.base.variable_template_parser import VariableTemplateParser +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext +from core.workflow.system_variables import SystemVariableKey, get_system_text +from graphon.entities.graph_config import NodeConfigDict +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus +from graphon.node_events import NodeEventBase, NodeRunResult, StreamCompletedEvent +from graphon.nodes.base.node import Node +from graphon.nodes.base.variable_template_parser import VariableTemplateParser from .entities import AgentNodeData from .exceptions import ( @@ -19,8 +21,8 @@ from .runtime_support import AgentRuntimeSupport from .strategy_protocols import AgentStrategyPresentationProvider, AgentStrategyResolver if TYPE_CHECKING: - from dify_graph.entities import GraphInitParams - from dify_graph.runtime import GraphRuntimeState + from graphon.entities import GraphInitParams + from graphon.runtime import GraphRuntimeState class AgentNode(Node[AgentNodeData]): @@ -59,7 +61,7 @@ class AgentNode(Node[AgentNodeData]): return "1" def populate_start_event(self, event) -> None: - dify_ctx = self.require_dify_context() + dify_ctx = DifyRunContext.model_validate(self.require_run_context_value(DIFY_RUN_CONTEXT_KEY)) event.extras["agent_strategy"] = { "name": self.node_data.agent_strategy_name, "icon": self._presentation_provider.get_icon( @@ -71,7 +73,7 @@ class AgentNode(Node[AgentNodeData]): def _run(self) -> Generator[NodeEventBase, None, None]: from core.plugin.impl.exc import PluginDaemonClientSideError - dify_ctx = self.require_dify_context() + dify_ctx = DifyRunContext.model_validate(self.require_run_context_value(DIFY_RUN_CONTEXT_KEY)) try: strategy = self._strategy_resolver.resolve( @@ -97,6 +99,7 @@ class AgentNode(Node[AgentNodeData]): node_data=self.node_data, strategy=strategy, tenant_id=dify_ctx.tenant_id, + user_id=dify_ctx.user_id, app_id=dify_ctx.app_id, invoke_from=dify_ctx.invoke_from, ) @@ -106,20 +109,21 @@ class AgentNode(Node[AgentNodeData]): node_data=self.node_data, strategy=strategy, tenant_id=dify_ctx.tenant_id, + user_id=dify_ctx.user_id, app_id=dify_ctx.app_id, invoke_from=dify_ctx.invoke_from, for_log=True, ) credentials = self._runtime_support.build_credentials(parameters=parameters) - conversation_id = self.graph_runtime_state.variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID]) + conversation_id = get_system_text(self.graph_runtime_state.variable_pool, SystemVariableKey.CONVERSATION_ID) try: message_stream = strategy.invoke( params=parameters, user_id=dify_ctx.user_id, app_id=dify_ctx.app_id, - conversation_id=conversation_id.text if conversation_id else None, + conversation_id=conversation_id, credentials=credentials, ) except Exception as e: @@ -146,6 +150,7 @@ class AgentNode(Node[AgentNodeData]): parameters_for_log=parameters_for_log, user_id=dify_ctx.user_id, tenant_id=dify_ctx.tenant_id, + conversation_id=conversation_id, node_type=self.node_type, node_id=self._node_id, node_execution_id=self.id, diff --git a/api/core/workflow/nodes/agent/entities.py b/api/core/workflow/nodes/agent/entities.py index 91fed39795..51452c29a3 100644 --- a/api/core/workflow/nodes/agent/entities.py +++ b/api/core/workflow/nodes/agent/entities.py @@ -5,8 +5,8 @@ from pydantic import BaseModel from core.prompt.entities.advanced_prompt_entities import MemoryConfig from core.tools.entities.tool_entities import ToolSelector -from dify_graph.entities.base_node_data import BaseNodeData -from dify_graph.enums import BuiltinNodeTypes, NodeType +from graphon.entities.base_node_data import BaseNodeData +from graphon.enums import BuiltinNodeTypes, NodeType class AgentNodeData(BaseNodeData): diff --git a/api/core/workflow/nodes/agent/message_transformer.py b/api/core/workflow/nodes/agent/message_transformer.py index f58a5665f4..f44681377d 100644 --- a/api/core/workflow/nodes/agent/message_transformer.py +++ b/api/core/workflow/nodes/agent/message_transformer.py @@ -6,27 +6,30 @@ from typing import Any, cast from sqlalchemy import select from sqlalchemy.orm import Session +from core.app.file_access import DatabaseFileAccessController from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.utils.message_transformer import ToolFileMessageTransformer -from dify_graph.enums import BuiltinNodeTypes, NodeType, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus -from dify_graph.file import File, FileTransferMethod -from dify_graph.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata -from dify_graph.model_runtime.utils.encoders import jsonable_encoder -from dify_graph.node_events import ( +from extensions.ext_database import db +from factories import file_factory +from graphon.enums import BuiltinNodeTypes, NodeType, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus +from graphon.file import File, FileTransferMethod, get_file_type_by_mime_type +from graphon.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata +from graphon.model_runtime.utils.encoders import jsonable_encoder +from graphon.node_events import ( AgentLogEvent, NodeEventBase, NodeRunResult, StreamChunkEvent, StreamCompletedEvent, ) -from dify_graph.variables.segments import ArrayFileSegment -from extensions.ext_database import db -from factories import file_factory +from graphon.variables.segments import ArrayFileSegment from models import ToolFile from services.tools.builtin_tools_manage_service import BuiltinToolManageService from .exceptions import AgentNodeError, AgentVariableTypeError, ToolFileNotFoundError +_file_access_controller = DatabaseFileAccessController() + class AgentMessageTransformer: def transform( @@ -37,6 +40,7 @@ class AgentMessageTransformer: parameters_for_log: dict[str, Any], user_id: str, tenant_id: str, + conversation_id: str | None, node_type: NodeType, node_id: str, node_execution_id: str, @@ -47,7 +51,7 @@ class AgentMessageTransformer: messages=messages, user_id=user_id, tenant_id=tenant_id, - conversation_id=None, + conversation_id=conversation_id, ) text = "" @@ -70,10 +74,12 @@ class AgentMessageTransformer: url = message.message.text if message.meta: transfer_method = message.meta.get("transfer_method", FileTransferMethod.TOOL_FILE) + tool_file_id = message.meta.get("tool_file_id") else: transfer_method = FileTransferMethod.TOOL_FILE - - tool_file_id = str(url).split("/")[-1].split(".")[0] + tool_file_id = None + if not isinstance(tool_file_id, str) or not tool_file_id: + raise ToolFileNotFoundError("missing tool_file_id metadata") with Session(db.engine) as session: stmt = select(ToolFile).where(ToolFile.id == tool_file_id) @@ -83,20 +89,23 @@ class AgentMessageTransformer: mapping = { "tool_file_id": tool_file_id, - "type": file_factory.get_file_type_by_mime_type(tool_file.mimetype), + "type": get_file_type_by_mime_type(tool_file.mimetype), "transfer_method": transfer_method, "url": url, } file = file_factory.build_from_mapping( mapping=mapping, tenant_id=tenant_id, + access_controller=_file_access_controller, ) files.append(file) elif message.type == ToolInvokeMessage.MessageType.BLOB: assert isinstance(message.message, ToolInvokeMessage.TextMessage) assert message.meta - tool_file_id = message.message.text.split("/")[-1].split(".")[0] + tool_file_id = message.meta.get("tool_file_id") + if not isinstance(tool_file_id, str) or not tool_file_id: + raise ToolFileNotFoundError("missing tool_file_id metadata") with Session(db.engine) as session: stmt = select(ToolFile).where(ToolFile.id == tool_file_id) tool_file = session.scalar(stmt) @@ -111,6 +120,7 @@ class AgentMessageTransformer: file_factory.build_from_mapping( mapping=mapping, tenant_id=tenant_id, + access_controller=_file_access_controller, ) ) elif message.type == ToolInvokeMessage.MessageType.TEXT: diff --git a/api/core/workflow/nodes/agent/runtime_support.py b/api/core/workflow/nodes/agent/runtime_support.py index 2ff7c964b9..a872774c98 100644 --- a/api/core/workflow/nodes/agent/runtime_support.py +++ b/api/core/workflow/nodes/agent/runtime_support.py @@ -12,16 +12,15 @@ from sqlalchemy.orm import Session from core.agent.entities import AgentToolEntity from core.agent.plugin_entities import AgentStrategyParameter from core.memory.token_buffer_memory import TokenBufferMemory -from core.model_manager import ModelInstance, ModelManager +from core.model_manager import ModelInstance from core.plugin.entities.request import InvokeCredentials -from core.provider_manager import ProviderManager +from core.plugin.impl.model_runtime_factory import create_plugin_model_assembly from core.tools.entities.tool_entities import ToolIdentity, ToolParameter, ToolProviderType from core.tools.tool_manager import ToolManager -from dify_graph.enums import SystemVariableKey -from dify_graph.model_runtime.entities.model_entities import AIModelEntity, ModelType -from dify_graph.runtime import VariablePool -from dify_graph.variables.segments import StringSegment +from core.workflow.system_variables import SystemVariableKey, get_system_text from extensions.ext_database import db +from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelType +from graphon.runtime import VariablePool from models.model import Conversation from .entities import AgentNodeData, AgentOldVersionModelFeatures, ParamsAutoGenerated @@ -38,6 +37,7 @@ class AgentRuntimeSupport: node_data: AgentNodeData, strategy: ResolvedAgentStrategy, tenant_id: str, + user_id: str, app_id: str, invoke_from: Any, for_log: bool = False, @@ -141,6 +141,7 @@ class AgentRuntimeSupport: tenant_id, app_id, entity, + user_id, invoke_from, runtime_variable_pool, ) @@ -174,7 +175,11 @@ class AgentRuntimeSupport: value = tool_value if parameter.type == AgentStrategyParameter.AgentStrategyParameterType.MODEL_SELECTOR: value = cast(dict[str, Any], value) - model_instance, model_schema = self.fetch_model(tenant_id=tenant_id, value=value) + model_instance, model_schema = self.fetch_model( + tenant_id=tenant_id, + user_id=user_id, + value=value, + ) history_prompt_messages = [] if node_data.memory: memory = self.fetch_memory( @@ -219,10 +224,9 @@ class AgentRuntimeSupport: app_id: str, model_instance: ModelInstance, ) -> TokenBufferMemory | None: - conversation_id_variable = variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID]) - if not isinstance(conversation_id_variable, StringSegment): + conversation_id = get_system_text(variable_pool, SystemVariableKey.CONVERSATION_ID) + if conversation_id is None: return None - conversation_id = conversation_id_variable.value with Session(db.engine, expire_on_commit=False) as session: stmt = select(Conversation).where(Conversation.app_id == app_id, Conversation.id == conversation_id) @@ -232,9 +236,15 @@ class AgentRuntimeSupport: return TokenBufferMemory(conversation=conversation, model_instance=model_instance) - def fetch_model(self, *, tenant_id: str, value: dict[str, Any]) -> tuple[ModelInstance, AIModelEntity | None]: - provider_manager = ProviderManager() - provider_model_bundle = provider_manager.get_provider_model_bundle( + def fetch_model( + self, + *, + tenant_id: str, + user_id: str, + value: dict[str, Any], + ) -> tuple[ModelInstance, AIModelEntity | None]: + assembly = create_plugin_model_assembly(tenant_id=tenant_id, user_id=user_id) + provider_model_bundle = assembly.provider_manager.get_provider_model_bundle( tenant_id=tenant_id, provider=value.get("provider", ""), model_type=ModelType.LLM, @@ -246,7 +256,7 @@ class AgentRuntimeSupport: ) provider_name = provider_model_bundle.configuration.provider.provider model_type_instance = provider_model_bundle.model_type_instance - model_instance = ModelManager().get_model_instance( + model_instance = assembly.model_manager.get_model_instance( tenant_id=tenant_id, provider=provider_name, model_type=ModelType(value.get("model_type", "")), diff --git a/api/core/workflow/nodes/datasource/datasource_node.py b/api/core/workflow/nodes/datasource/datasource_node.py index 44f4a23a5a..38f39b3f94 100644 --- a/api/core/workflow/nodes/datasource/datasource_node.py +++ b/api/core/workflow/nodes/datasource/datasource_node.py @@ -1,22 +1,25 @@ from collections.abc import Generator, Mapping, Sequence from typing import TYPE_CHECKING, Any +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext from core.datasource.datasource_manager import DatasourceManager from core.datasource.entities.datasource_entities import DatasourceProviderType from core.plugin.impl.exc import PluginDaemonClientSideError -from dify_graph.entities.graph_config import NodeConfigDict -from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus -from dify_graph.enums import BuiltinNodeTypes, NodeExecutionType, SystemVariableKey, WorkflowNodeExecutionMetadataKey -from dify_graph.node_events import NodeRunResult, StreamCompletedEvent -from dify_graph.nodes.base.node import Node -from dify_graph.nodes.base.variable_template_parser import VariableTemplateParser +from core.workflow.file_reference import resolve_file_record_id +from core.workflow.system_variables import SystemVariableKey, get_system_segment +from graphon.entities.graph_config import NodeConfigDict +from graphon.entities.workflow_node_execution import WorkflowNodeExecutionStatus +from graphon.enums import BuiltinNodeTypes, NodeExecutionType, WorkflowNodeExecutionMetadataKey +from graphon.node_events import NodeRunResult, StreamCompletedEvent +from graphon.nodes.base.node import Node +from graphon.nodes.base.variable_template_parser import VariableTemplateParser from .entities import DatasourceNodeData, DatasourceParameter, OnlineDriveDownloadFileParam from .exc import DatasourceNodeError if TYPE_CHECKING: - from dify_graph.entities import GraphInitParams - from dify_graph.runtime import GraphRuntimeState + from graphon.entities import GraphInitParams + from graphon.runtime import GraphRuntimeState class DatasourceNode(Node[DatasourceNodeData]): @@ -50,15 +53,14 @@ class DatasourceNode(Node[DatasourceNodeData]): """ Run the datasource node """ - - dify_ctx = self.require_dify_context() + dify_ctx = DifyRunContext.model_validate(self.require_run_context_value(DIFY_RUN_CONTEXT_KEY)) node_data = self.node_data variable_pool = self.graph_runtime_state.variable_pool - datasource_type_segment = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_TYPE]) + datasource_type_segment = get_system_segment(variable_pool, SystemVariableKey.DATASOURCE_TYPE) if not datasource_type_segment: raise DatasourceNodeError("Datasource type is not set") datasource_type = str(datasource_type_segment.value) if datasource_type_segment.value else None - datasource_info_segment = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_INFO]) + datasource_info_segment = get_system_segment(variable_pool, SystemVariableKey.DATASOURCE_INFO) if not datasource_info_segment: raise DatasourceNodeError("Datasource info is not set") datasource_info_value = datasource_info_segment.value @@ -131,12 +133,14 @@ class DatasourceNode(Node[DatasourceNodeData]): ) ) case DatasourceProviderType.LOCAL_FILE: - related_id = datasource_info.get("related_id") - if not related_id: + file_id = resolve_file_record_id( + datasource_info.get("reference") or datasource_info.get("related_id") + ) + if not file_id: raise DatasourceNodeError("File is not exist") file_info = self.datasource_manager.get_upload_file_by_id( - file_id=related_id, tenant_id=dify_ctx.tenant_id + file_id=file_id, tenant_id=dify_ctx.tenant_id ) variable_pool.add([self._node_id, "file"], file_info) # variable_pool.add([self.node_id, "file"], file_info.to_dict()) diff --git a/api/core/workflow/nodes/datasource/entities.py b/api/core/workflow/nodes/datasource/entities.py index 65864474b0..28966f2392 100644 --- a/api/core/workflow/nodes/datasource/entities.py +++ b/api/core/workflow/nodes/datasource/entities.py @@ -3,8 +3,8 @@ from typing import Any, Literal, Union from pydantic import BaseModel, field_validator from pydantic_core.core_schema import ValidationInfo -from dify_graph.entities.base_node_data import BaseNodeData -from dify_graph.enums import BuiltinNodeTypes, NodeType +from graphon.entities.base_node_data import BaseNodeData +from graphon.enums import BuiltinNodeTypes, NodeType class DatasourceEntity(BaseModel): diff --git a/api/core/workflow/nodes/datasource/protocols.py b/api/core/workflow/nodes/datasource/protocols.py index c006e0885c..776e267317 100644 --- a/api/core/workflow/nodes/datasource/protocols.py +++ b/api/core/workflow/nodes/datasource/protocols.py @@ -1,8 +1,8 @@ from collections.abc import Generator from typing import Any, Protocol -from dify_graph.file import File -from dify_graph.node_events import StreamChunkEvent, StreamCompletedEvent +from graphon.file import File +from graphon.node_events import StreamChunkEvent, StreamCompletedEvent from .entities import DatasourceParameter, OnlineDriveDownloadFileParam diff --git a/api/core/workflow/nodes/knowledge_index/entities.py b/api/core/workflow/nodes/knowledge_index/entities.py index 8d2e9bf3cb..11339bb122 100644 --- a/api/core/workflow/nodes/knowledge_index/entities.py +++ b/api/core/workflow/nodes/knowledge_index/entities.py @@ -5,8 +5,8 @@ from pydantic import BaseModel from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.workflow.nodes.knowledge_index import KNOWLEDGE_INDEX_NODE_TYPE -from dify_graph.entities.base_node_data import BaseNodeData -from dify_graph.enums import NodeType +from graphon.entities.base_node_data import BaseNodeData +from graphon.enums import NodeType class RerankingModelConfig(BaseModel): diff --git a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py index 4ea9091c5b..b465a2d8ff 100644 --- a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py +++ b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py @@ -6,12 +6,13 @@ from core.rag.index_processor.index_processor import IndexProcessor from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict from core.rag.summary_index.summary_index import SummaryIndex from core.workflow.nodes.knowledge_index import KNOWLEDGE_INDEX_NODE_TYPE -from dify_graph.entities.graph_config import NodeConfigDict -from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus -from dify_graph.enums import NodeExecutionType, SystemVariableKey -from dify_graph.node_events import NodeRunResult -from dify_graph.nodes.base.node import Node -from dify_graph.nodes.base.template import Template +from core.workflow.system_variables import SystemVariableKey, get_system_segment, get_system_text +from graphon.entities.graph_config import NodeConfigDict +from graphon.entities.workflow_node_execution import WorkflowNodeExecutionStatus +from graphon.enums import NodeExecutionType +from graphon.node_events import NodeRunResult +from graphon.nodes.base.node import Node +from graphon.nodes.base.template import Template from .entities import KnowledgeIndexNodeData from .exc import ( @@ -19,8 +20,8 @@ from .exc import ( ) if TYPE_CHECKING: - from dify_graph.entities import GraphInitParams - from dify_graph.runtime import GraphRuntimeState + from graphon.entities import GraphInitParams + from graphon.runtime import GraphRuntimeState logger = logging.getLogger(__name__) _INVOKE_FROM_DEBUGGER = "debugger" @@ -46,21 +47,20 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]): variable_pool = self.graph_runtime_state.variable_pool # get dataset id as string - dataset_id_segment = variable_pool.get(["sys", SystemVariableKey.DATASET_ID]) + dataset_id_segment = get_system_segment(variable_pool, SystemVariableKey.DATASET_ID) if not dataset_id_segment: raise KnowledgeIndexNodeError("Dataset ID is required.") dataset_id: str = dataset_id_segment.value # get document id as string (may be empty when not provided) - document_id_segment = variable_pool.get(["sys", SystemVariableKey.DOCUMENT_ID]) + document_id_segment = get_system_segment(variable_pool, SystemVariableKey.DOCUMENT_ID) document_id: str = document_id_segment.value if document_id_segment else "" # extract variables variable = variable_pool.get(node_data.index_chunk_variable_selector) if not variable: raise KnowledgeIndexNodeError("Index chunk variable is required.") - invoke_from = variable_pool.get(["sys", SystemVariableKey.INVOKE_FROM]) - invoke_from_value = str(invoke_from.value) if invoke_from else None + invoke_from_value = get_system_text(variable_pool, SystemVariableKey.INVOKE_FROM) is_preview = invoke_from_value == _INVOKE_FROM_DEBUGGER chunks = variable.value @@ -87,8 +87,8 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]): outputs=outputs.model_dump(exclude_none=True), ) - original_document_id_segment = variable_pool.get(["sys", SystemVariableKey.ORIGINAL_DOCUMENT_ID]) - batch = variable_pool.get(["sys", SystemVariableKey.BATCH]) + original_document_id_segment = get_system_segment(variable_pool, SystemVariableKey.ORIGINAL_DOCUMENT_ID) + batch = get_system_segment(variable_pool, SystemVariableKey.BATCH) if not batch: raise KnowledgeIndexNodeError("Batch is required.") diff --git a/api/core/workflow/nodes/knowledge_retrieval/entities.py b/api/core/workflow/nodes/knowledge_retrieval/entities.py index bc5618685a..3f7cc364d3 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/entities.py +++ b/api/core/workflow/nodes/knowledge_retrieval/entities.py @@ -3,9 +3,9 @@ from typing import Literal from pydantic import BaseModel, Field -from dify_graph.entities.base_node_data import BaseNodeData -from dify_graph.enums import BuiltinNodeTypes, NodeType -from dify_graph.nodes.llm.entities import ModelConfig, VisionConfig +from graphon.entities.base_node_data import BaseNodeData +from graphon.enums import BuiltinNodeTypes, NodeType +from graphon.nodes.llm.entities import ModelConfig, VisionConfig class RerankingModelConfig(BaseModel): diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index 80f59140be..117f426ade 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -9,26 +9,28 @@ from collections.abc import Mapping, Sequence from typing import TYPE_CHECKING, Any, Literal from core.app.app_config.entities import DatasetRetrieveConfigEntity +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext from core.rag.data_post_processor.data_post_processor import RerankingModelDict, WeightsDict from core.rag.retrieval.dataset_retrieval import DatasetRetrieval -from dify_graph.entities import GraphInitParams -from dify_graph.entities.graph_config import NodeConfigDict -from dify_graph.enums import ( +from core.workflow.file_reference import parse_file_reference +from graphon.entities import GraphInitParams +from graphon.entities.graph_config import NodeConfigDict +from graphon.enums import ( BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus, ) -from dify_graph.model_runtime.entities.llm_entities import LLMUsage -from dify_graph.model_runtime.utils.encoders import jsonable_encoder -from dify_graph.node_events import NodeRunResult -from dify_graph.nodes.base import LLMUsageTrackingMixin -from dify_graph.nodes.base.node import Node -from dify_graph.variables import ( +from graphon.model_runtime.entities.llm_entities import LLMUsage +from graphon.model_runtime.utils.encoders import jsonable_encoder +from graphon.node_events import NodeRunResult +from graphon.nodes.base import LLMUsageTrackingMixin +from graphon.nodes.base.node import Node +from graphon.variables import ( ArrayFileSegment, FileSegment, StringSegment, ) -from dify_graph.variables.segments import ArrayObjectSegment +from graphon.variables.segments import ArrayObjectSegment from .entities import ( Condition, @@ -42,8 +44,8 @@ from .exc import ( from .retrieval import KnowledgeRetrievalRequest, Source if TYPE_CHECKING: - from dify_graph.file.models import File - from dify_graph.runtime import GraphRuntimeState + from graphon.file.models import File + from graphon.runtime import GraphRuntimeState logger = logging.getLogger(__name__) @@ -160,7 +162,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD def _fetch_dataset_retriever( self, node_data: KnowledgeRetrievalNodeData, variables: dict[str, Any] ) -> tuple[list[Source], LLMUsage]: - dify_ctx = self.require_dify_context() + dify_ctx = DifyRunContext.model_validate(self.require_run_context_value(DIFY_RUN_CONTEXT_KEY)) dataset_ids = node_data.dataset_ids query = variables.get("query") attachments = variables.get("attachments") @@ -254,7 +256,13 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD metadata_model_config=node_data.metadata_model_config, metadata_filtering_conditions=resolved_metadata_conditions, metadata_filtering_mode=metadata_filtering_mode, - attachment_ids=[attachment.related_id for attachment in attachments] if attachments else None, + attachment_ids=[ + parsed_reference.record_id + for attachment in attachments + if (parsed_reference := parse_file_reference(attachment.reference)) is not None + ] + if attachments + else None, ) ) diff --git a/api/core/workflow/nodes/knowledge_retrieval/retrieval.py b/api/core/workflow/nodes/knowledge_retrieval/retrieval.py index e1311ab962..ea45dcf5c2 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/retrieval.py +++ b/api/core/workflow/nodes/knowledge_retrieval/retrieval.py @@ -3,8 +3,8 @@ from typing import Any, Literal, Protocol from pydantic import BaseModel, Field from core.rag.data_post_processor.data_post_processor import RerankingModelDict, WeightsDict -from dify_graph.model_runtime.entities import LLMUsage -from dify_graph.nodes.llm.entities import ModelConfig +from graphon.model_runtime.entities import LLMUsage +from graphon.nodes.llm.entities import ModelConfig from .entities import MetadataFilteringCondition @@ -54,7 +54,7 @@ class KnowledgeRetrievalRequest(BaseModel): tenant_id: str = Field(description="Tenant unique identifier") user_id: str = Field(description="User unique identifier") app_id: str = Field(description="Application unique identifier") - user_from: str = Field(description="Source of the user request (e.g., 'workflow', 'api')") + user_from: str = Field(description="User identity source for audit logging (e.g., 'account', 'end-user')") dataset_ids: list[str] = Field(description="List of dataset IDs to retrieve from") query: str | None = Field(default=None, description="Query text for knowledge retrieval") retrieval_mode: str = Field(description="Retrieval strategy: 'single' or 'multiple'") diff --git a/api/core/workflow/nodes/trigger_plugin/entities.py b/api/core/workflow/nodes/trigger_plugin/entities.py index ea7d20befe..23ed2cd408 100644 --- a/api/core/workflow/nodes/trigger_plugin/entities.py +++ b/api/core/workflow/nodes/trigger_plugin/entities.py @@ -5,8 +5,8 @@ from pydantic import BaseModel, Field, ValidationInfo, field_validator from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE from core.trigger.entities.entities import EventParameter -from dify_graph.entities.base_node_data import BaseNodeData -from dify_graph.enums import NodeType +from graphon.entities.base_node_data import BaseNodeData +from graphon.enums import NodeType from .exc import TriggerEventParameterError diff --git a/api/core/workflow/nodes/trigger_plugin/trigger_event_node.py b/api/core/workflow/nodes/trigger_plugin/trigger_event_node.py index 118c2f2668..a2c952a899 100644 --- a/api/core/workflow/nodes/trigger_plugin/trigger_event_node.py +++ b/api/core/workflow/nodes/trigger_plugin/trigger_event_node.py @@ -2,11 +2,11 @@ from collections.abc import Mapping from typing import Any from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE -from dify_graph.constants import SYSTEM_VARIABLE_NODE_ID -from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus -from dify_graph.enums import NodeExecutionType, WorkflowNodeExecutionMetadataKey -from dify_graph.node_events import NodeRunResult -from dify_graph.nodes.base.node import Node +from core.workflow.variable_prefixes import SYSTEM_VARIABLE_NODE_ID +from graphon.entities.workflow_node_execution import WorkflowNodeExecutionStatus +from graphon.enums import NodeExecutionType, WorkflowNodeExecutionMetadataKey +from graphon.node_events import NodeRunResult +from graphon.nodes.base.node import Node from .entities import TriggerEventNodeData @@ -53,13 +53,11 @@ class TriggerEventNode(Node[TriggerEventNodeData]): "plugin_unique_identifier": self.node_data.plugin_unique_identifier, }, } - node_inputs = dict(self.graph_runtime_state.variable_pool.user_inputs) - system_inputs = self.graph_runtime_state.variable_pool.system_variables.to_dict() + node_inputs = dict(self.graph_runtime_state.variable_pool.get_by_prefix(self.id)) + system_inputs = self.graph_runtime_state.variable_pool.get_by_prefix(SYSTEM_VARIABLE_NODE_ID) - # TODO: System variables should be directly accessible, no need for special handling - # Set system variables as node outputs. - for var in system_inputs: - node_inputs[SYSTEM_VARIABLE_NODE_ID + "." + var] = system_inputs[var] + for variable_name, value in system_inputs.items(): + node_inputs[f"{SYSTEM_VARIABLE_NODE_ID}.{variable_name}"] = value outputs = dict(node_inputs) return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, diff --git a/api/core/workflow/nodes/trigger_schedule/entities.py b/api/core/workflow/nodes/trigger_schedule/entities.py index 95a2548678..207c1e7253 100644 --- a/api/core/workflow/nodes/trigger_schedule/entities.py +++ b/api/core/workflow/nodes/trigger_schedule/entities.py @@ -3,8 +3,8 @@ from typing import Literal, Union from pydantic import BaseModel, Field from core.trigger.constants import TRIGGER_SCHEDULE_NODE_TYPE -from dify_graph.entities.base_node_data import BaseNodeData -from dify_graph.enums import NodeType +from graphon.entities.base_node_data import BaseNodeData +from graphon.enums import NodeType class TriggerScheduleNodeData(BaseNodeData): diff --git a/api/core/workflow/nodes/trigger_schedule/exc.py b/api/core/workflow/nodes/trigger_schedule/exc.py index 336d64d58f..10962c3de4 100644 --- a/api/core/workflow/nodes/trigger_schedule/exc.py +++ b/api/core/workflow/nodes/trigger_schedule/exc.py @@ -1,4 +1,4 @@ -from dify_graph.entities.exc import BaseNodeError +from graphon.entities.exc import BaseNodeError class ScheduleNodeError(BaseNodeError): diff --git a/api/core/workflow/nodes/trigger_schedule/trigger_schedule_node.py b/api/core/workflow/nodes/trigger_schedule/trigger_schedule_node.py index b9580e6ab1..dd80617dfc 100644 --- a/api/core/workflow/nodes/trigger_schedule/trigger_schedule_node.py +++ b/api/core/workflow/nodes/trigger_schedule/trigger_schedule_node.py @@ -1,11 +1,11 @@ from collections.abc import Mapping from core.trigger.constants import TRIGGER_SCHEDULE_NODE_TYPE -from dify_graph.constants import SYSTEM_VARIABLE_NODE_ID -from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus -from dify_graph.enums import NodeExecutionType -from dify_graph.node_events import NodeRunResult -from dify_graph.nodes.base.node import Node +from core.workflow.variable_prefixes import SYSTEM_VARIABLE_NODE_ID +from graphon.entities.workflow_node_execution import WorkflowNodeExecutionStatus +from graphon.enums import NodeExecutionType +from graphon.node_events import NodeRunResult +from graphon.nodes.base.node import Node from .entities import TriggerScheduleNodeData @@ -31,13 +31,11 @@ class TriggerScheduleNode(Node[TriggerScheduleNodeData]): } def _run(self) -> NodeRunResult: - node_inputs = dict(self.graph_runtime_state.variable_pool.user_inputs) - system_inputs = self.graph_runtime_state.variable_pool.system_variables.to_dict() + node_inputs = dict(self.graph_runtime_state.variable_pool.get_by_prefix(self.id)) + system_inputs = self.graph_runtime_state.variable_pool.get_by_prefix(SYSTEM_VARIABLE_NODE_ID) - # TODO: System variables should be directly accessible, no need for special handling - # Set system variables as node outputs. - for var in system_inputs: - node_inputs[SYSTEM_VARIABLE_NODE_ID + "." + var] = system_inputs[var] + for variable_name, value in system_inputs.items(): + node_inputs[f"{SYSTEM_VARIABLE_NODE_ID}.{variable_name}"] = value outputs = dict(node_inputs) return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, diff --git a/api/core/workflow/nodes/trigger_webhook/entities.py b/api/core/workflow/nodes/trigger_webhook/entities.py index 242bf5ef6a..3125fe17e6 100644 --- a/api/core/workflow/nodes/trigger_webhook/entities.py +++ b/api/core/workflow/nodes/trigger_webhook/entities.py @@ -4,9 +4,9 @@ from enum import StrEnum from pydantic import BaseModel, Field, field_validator from core.trigger.constants import TRIGGER_WEBHOOK_NODE_TYPE -from dify_graph.entities.base_node_data import BaseNodeData -from dify_graph.enums import NodeType -from dify_graph.variables.types import SegmentType +from graphon.entities.base_node_data import BaseNodeData +from graphon.enums import NodeType +from graphon.variables.types import SegmentType _WEBHOOK_HEADER_ALLOWED_TYPES = frozenset( { diff --git a/api/core/workflow/nodes/trigger_webhook/exc.py b/api/core/workflow/nodes/trigger_webhook/exc.py index 4d87f2a069..00b0b3baad 100644 --- a/api/core/workflow/nodes/trigger_webhook/exc.py +++ b/api/core/workflow/nodes/trigger_webhook/exc.py @@ -1,4 +1,4 @@ -from dify_graph.entities.exc import BaseNodeError +from graphon.entities.exc import BaseNodeError class WebhookNodeError(BaseNodeError): diff --git a/api/core/workflow/nodes/trigger_webhook/node.py b/api/core/workflow/nodes/trigger_webhook/node.py index 317844cbda..6858d6dc35 100644 --- a/api/core/workflow/nodes/trigger_webhook/node.py +++ b/api/core/workflow/nodes/trigger_webhook/node.py @@ -3,16 +3,17 @@ from collections.abc import Mapping from typing import Any from core.trigger.constants import TRIGGER_WEBHOOK_NODE_TYPE -from dify_graph.constants import SYSTEM_VARIABLE_NODE_ID -from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus -from dify_graph.enums import NodeExecutionType -from dify_graph.file import FileTransferMethod -from dify_graph.node_events import NodeRunResult -from dify_graph.nodes.base.node import Node -from dify_graph.variables.types import SegmentType -from dify_graph.variables.variables import FileVariable -from factories import file_factory +from core.workflow.file_reference import resolve_file_record_id +from core.workflow.variable_prefixes import SYSTEM_VARIABLE_NODE_ID from factories.variable_factory import build_segment_with_type +from graphon.entities.workflow_node_execution import WorkflowNodeExecutionStatus +from graphon.enums import NodeExecutionType +from graphon.file import FileTransferMethod +from graphon.node_events import NodeRunResult +from graphon.nodes.base.node import Node +from graphon.nodes.protocols import FileReferenceFactoryProtocol +from graphon.variables.types import SegmentType +from graphon.variables.variables import FileVariable from .entities import ContentType, WebhookData @@ -23,6 +24,13 @@ class TriggerWebhookNode(Node[WebhookData]): node_type = TRIGGER_WEBHOOK_NODE_TYPE execution_type = NodeExecutionType.ROOT + _file_reference_factory: FileReferenceFactoryProtocol + + def post_init(self) -> None: + from core.workflow.node_runtime import DifyFileReferenceFactory + + self._file_reference_factory = DifyFileReferenceFactory(self.graph_init_params.run_context) + @classmethod def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]: return { @@ -53,16 +61,14 @@ class TriggerWebhookNode(Node[WebhookData]): happens in the trigger controller. """ # Get webhook data from variable pool (injected by Celery task) - webhook_inputs = dict(self.graph_runtime_state.variable_pool.user_inputs) + webhook_inputs = dict(self.graph_runtime_state.variable_pool.get_by_prefix(self.id)) # Extract webhook-specific outputs based on node configuration outputs = self._extract_configured_outputs(webhook_inputs) - system_inputs = self.graph_runtime_state.variable_pool.system_variables.to_dict() + system_inputs = self.graph_runtime_state.variable_pool.get_by_prefix(SYSTEM_VARIABLE_NODE_ID) - # TODO: System variables should be directly accessible, no need for special handling - # Set system variables as node outputs. - for var in system_inputs: - outputs[SYSTEM_VARIABLE_NODE_ID + "." + var] = system_inputs[var] + for variable_name, value in system_inputs.items(): + outputs[f"{SYSTEM_VARIABLE_NODE_ID}.{variable_name}"] = value return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=webhook_inputs, @@ -70,24 +76,20 @@ class TriggerWebhookNode(Node[WebhookData]): ) def generate_file_var(self, param_name: str, file: dict): - dify_ctx = self.require_dify_context() - related_id = file.get("related_id") + file_id = resolve_file_record_id(file.get("reference") or file.get("related_id")) transfer_method_value = file.get("transfer_method") if transfer_method_value: transfer_method = FileTransferMethod.value_of(transfer_method_value) match transfer_method: case FileTransferMethod.LOCAL_FILE | FileTransferMethod.REMOTE_URL: - file["upload_file_id"] = related_id + file["upload_file_id"] = file_id case FileTransferMethod.TOOL_FILE: - file["tool_file_id"] = related_id + file["tool_file_id"] = file_id case FileTransferMethod.DATASOURCE_FILE: - file["datasource_file_id"] = related_id + file["datasource_file_id"] = file_id try: - file_obj = file_factory.build_from_mapping( - mapping=file, - tenant_id=dify_ctx.tenant_id, - ) + file_obj = self._file_reference_factory.build_from_mapping(mapping=file) file_segment = build_segment_with_type(SegmentType.FILE, file_obj) return FileVariable(name=param_name, value=file_segment.value, selector=[self.id, param_name]) except ValueError: diff --git a/api/core/workflow/system_variables.py b/api/core/workflow/system_variables.py new file mode 100644 index 0000000000..9d15a3fcea --- /dev/null +++ b/api/core/workflow/system_variables.py @@ -0,0 +1,236 @@ +from __future__ import annotations + +from collections import defaultdict +from collections.abc import Mapping, Sequence +from enum import StrEnum +from typing import Any, Protocol, cast +from uuid import uuid4 + +from graphon.enums import BuiltinNodeTypes +from graphon.variables import build_segment, segment_to_variable +from graphon.variables.segments import Segment +from graphon.variables.variables import RAGPipelineVariableInput, Variable + +from .variable_prefixes import ( + CONVERSATION_VARIABLE_NODE_ID, + ENVIRONMENT_VARIABLE_NODE_ID, + RAG_PIPELINE_VARIABLE_NODE_ID, + SYSTEM_VARIABLE_NODE_ID, +) + + +class SystemVariableKey(StrEnum): + QUERY = "query" + FILES = "files" + CONVERSATION_ID = "conversation_id" + USER_ID = "user_id" + DIALOGUE_COUNT = "dialogue_count" + APP_ID = "app_id" + WORKFLOW_ID = "workflow_id" + WORKFLOW_EXECUTION_ID = "workflow_run_id" + TIMESTAMP = "timestamp" + DOCUMENT_ID = "document_id" + ORIGINAL_DOCUMENT_ID = "original_document_id" + BATCH = "batch" + DATASET_ID = "dataset_id" + DATASOURCE_TYPE = "datasource_type" + DATASOURCE_INFO = "datasource_info" + INVOKE_FROM = "invoke_from" + + +class _VariablePoolReader(Protocol): + def get(self, selector: Sequence[str], /) -> Segment | None: ... + + def get_by_prefix(self, prefix: str, /) -> Mapping[str, object]: ... + + +class _VariablePoolWriter(_VariablePoolReader, Protocol): + def add(self, selector: Sequence[str], value: object, /) -> None: ... + + +class _VariableLoader(Protocol): + def load_variables(self, selectors: list[list[str]]) -> Sequence[object]: ... + + +def system_variable_name(key: str | SystemVariableKey) -> str: + return key.value if isinstance(key, SystemVariableKey) else key + + +def system_variable_selector(key: str | SystemVariableKey) -> tuple[str, str]: + return SYSTEM_VARIABLE_NODE_ID, system_variable_name(key) + + +def _normalize_system_variable_values(values: Mapping[str, Any] | None = None, /, **kwargs: Any) -> dict[str, Any]: + raw_values = dict(values or {}) + raw_values.update(kwargs) + + workflow_execution_id = raw_values.pop("workflow_execution_id", None) + if workflow_execution_id is not None and SystemVariableKey.WORKFLOW_EXECUTION_ID.value not in raw_values: + raw_values[SystemVariableKey.WORKFLOW_EXECUTION_ID.value] = workflow_execution_id + + normalized: dict[str, Any] = {} + for key, value in raw_values.items(): + if value is None: + continue + normalized[system_variable_name(key)] = value + + normalized.setdefault(SystemVariableKey.FILES.value, []) + return normalized + + +def build_system_variables(values: Mapping[str, Any] | None = None, /, **kwargs: Any) -> list[Variable]: + normalized = _normalize_system_variable_values(values, **kwargs) + + return [ + cast( + Variable, + segment_to_variable( + segment=build_segment(value), + selector=system_variable_selector(key), + name=key, + ), + ) + for key, value in normalized.items() + ] + + +def default_system_variables() -> list[Variable]: + return build_system_variables(workflow_run_id=str(uuid4())) + + +def system_variables_to_mapping(system_variables: Sequence[Variable]) -> dict[str, Any]: + return {variable.name: variable.value for variable in system_variables} + + +def _with_selector(variable: Variable, node_id: str) -> Variable: + selector = [node_id, variable.name] + if list(variable.selector) == selector: + return variable + return variable.model_copy(update={"selector": selector}) + + +def build_bootstrap_variables( + *, + system_variables: Sequence[Variable] = (), + environment_variables: Sequence[Variable] = (), + conversation_variables: Sequence[Variable] = (), + rag_pipeline_variables: Sequence[RAGPipelineVariableInput] = (), +) -> list[Variable]: + variables = [ + *(_with_selector(variable, SYSTEM_VARIABLE_NODE_ID) for variable in system_variables), + *(_with_selector(variable, ENVIRONMENT_VARIABLE_NODE_ID) for variable in environment_variables), + *(_with_selector(variable, CONVERSATION_VARIABLE_NODE_ID) for variable in conversation_variables), + ] + + rag_pipeline_variables_map: defaultdict[str, dict[str, Any]] = defaultdict(dict) + for rag_var in rag_pipeline_variables: + node_id = rag_var.variable.belong_to_node_id + key = rag_var.variable.variable + rag_pipeline_variables_map[node_id][key] = rag_var.value + + for node_id, value in rag_pipeline_variables_map.items(): + variables.append( + cast( + Variable, + segment_to_variable( + segment=build_segment(value), + selector=(RAG_PIPELINE_VARIABLE_NODE_ID, node_id), + name=node_id, + ), + ) + ) + + return variables + + +def get_system_segment(variable_pool: _VariablePoolReader, key: str | SystemVariableKey) -> Segment | None: + return variable_pool.get(system_variable_selector(key)) + + +def get_system_value(variable_pool: _VariablePoolReader, key: str | SystemVariableKey) -> Any: + segment = get_system_segment(variable_pool, key) + return None if segment is None else segment.value + + +def get_system_text(variable_pool: _VariablePoolReader, key: str | SystemVariableKey) -> str | None: + segment = get_system_segment(variable_pool, key) + if segment is None: + return None + text = getattr(segment, "text", None) + return text if isinstance(text, str) else None + + +def get_all_system_variables(variable_pool: _VariablePoolReader) -> Mapping[str, object]: + return variable_pool.get_by_prefix(SYSTEM_VARIABLE_NODE_ID) + + +_MEMORY_BOOTSTRAP_NODE_TYPES = frozenset( + ( + BuiltinNodeTypes.LLM, + BuiltinNodeTypes.QUESTION_CLASSIFIER, + BuiltinNodeTypes.PARAMETER_EXTRACTOR, + ) +) + + +def get_node_creation_preload_selectors( + *, + node_type: str, + node_data: object, +) -> tuple[tuple[str, str], ...]: + """Return selectors that must exist before node construction begins.""" + + if node_type not in _MEMORY_BOOTSTRAP_NODE_TYPES or getattr(node_data, "memory", None) is None: + return () + + return (system_variable_selector(SystemVariableKey.CONVERSATION_ID),) + + +def preload_node_creation_variables( + *, + variable_loader: _VariableLoader, + variable_pool: _VariablePoolWriter, + selectors: Sequence[Sequence[str]], +) -> None: + """Load constructor-time variables before node or graph creation.""" + + seen_selectors: set[tuple[str, ...]] = set() + selectors_to_load: list[list[str]] = [] + for selector in selectors: + normalized_selector = tuple(selector) + if len(normalized_selector) < 2: + raise ValueError(f"Invalid preload selector: {selector}") + if normalized_selector in seen_selectors: + continue + seen_selectors.add(normalized_selector) + if variable_pool.get(normalized_selector) is None: + selectors_to_load.append(list(normalized_selector)) + + loaded_variables = variable_loader.load_variables(selectors_to_load) + for variable in loaded_variables: + raw_selector = getattr(variable, "selector", ()) + loaded_selector = list(raw_selector) + if len(loaded_selector) < 2: + raise ValueError(f"Invalid loaded variable selector: {raw_selector}") + variable_pool.add(loaded_selector[:2], variable) + + +def inject_default_system_variable_mappings( + *, + node_id: str, + node_type: str, + node_data: object, + variable_mapping: Mapping[str, Sequence[str]], +) -> Mapping[str, Sequence[str]]: + """Add workflow-owned implicit sys mappings that `graphon` should not know about.""" + + if node_type != BuiltinNodeTypes.LLM or getattr(node_data, "memory", None) is None: + return variable_mapping + + query_mapping_key = f"{node_id}.#sys.query#" + if query_mapping_key in variable_mapping: + return variable_mapping + + augmented_mapping = dict(variable_mapping) + augmented_mapping[query_mapping_key] = system_variable_selector(SystemVariableKey.QUERY) + return augmented_mapping diff --git a/api/core/workflow/template_rendering.py b/api/core/workflow/template_rendering.py new file mode 100644 index 0000000000..b4ffb37549 --- /dev/null +++ b/api/core/workflow/template_rendering.py @@ -0,0 +1,29 @@ +from __future__ import annotations + +from collections.abc import Mapping +from typing import Any + +from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor +from graphon.nodes.code.entities import CodeLanguage +from graphon.template_rendering import Jinja2TemplateRenderer, TemplateRenderError + + +class CodeExecutorJinja2TemplateRenderer(Jinja2TemplateRenderer): + """Sandbox-backed Jinja2 renderer for workflow-owned node composition.""" + + def render_template(self, template: str, variables: Mapping[str, Any]) -> str: + try: + result = CodeExecutor.execute_workflow_code_template( + language=CodeLanguage.JINJA2, + code=template, + inputs=variables, + ) + except Exception as exc: + if isinstance(exc, CodeExecutionError): + raise TemplateRenderError(str(exc)) from exc + raise + + rendered = result.get("result") + if not isinstance(rendered, str): + raise TemplateRenderError("Template render result must be a string.") + return rendered diff --git a/api/core/workflow/variable_pool_initializer.py b/api/core/workflow/variable_pool_initializer.py new file mode 100644 index 0000000000..43523e01b2 --- /dev/null +++ b/api/core/workflow/variable_pool_initializer.py @@ -0,0 +1,15 @@ +from collections.abc import Mapping, Sequence +from typing import Any + +from graphon.runtime import VariablePool +from graphon.variables.variables import Variable + + +def add_variables_to_pool(variable_pool: VariablePool, variables: Sequence[Variable]) -> None: + for variable in variables: + variable_pool.add(variable.selector, variable) + + +def add_node_inputs_to_pool(variable_pool: VariablePool, *, node_id: str, inputs: Mapping[str, Any]) -> None: + for key, value in inputs.items(): + variable_pool.add((node_id, key), value) diff --git a/api/dify_graph/constants.py b/api/core/workflow/variable_prefixes.py similarity index 100% rename from api/dify_graph/constants.py rename to api/core/workflow/variable_prefixes.py diff --git a/api/core/workflow/workflow_entry.py b/api/core/workflow/workflow_entry.py index 2e51a06bab..7429c95c7c 100644 --- a/api/core/workflow/workflow_entry.py +++ b/api/core/workflow/workflow_entry.py @@ -1,36 +1,44 @@ import logging import time from collections.abc import Generator, Mapping, Sequence -from typing import Any, cast +from typing import Any from configs import dify_config +from context import capture_current_context from core.app.apps.exc import GenerateTaskStoppedError from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom, build_dify_run_context +from core.app.file_access import DatabaseFileAccessController from core.app.workflow.layers.llm_quota import LLMQuotaLayer from core.app.workflow.layers.observability import ObservabilityLayer -from core.workflow.node_factory import DifyNodeFactory, resolve_workflow_node_class -from dify_graph.constants import ENVIRONMENT_VARIABLE_NODE_ID -from dify_graph.entities import GraphInitParams -from dify_graph.entities.graph_config import NodeConfigDictAdapter -from dify_graph.errors import WorkflowNodeRunFailedError -from dify_graph.file.models import File -from dify_graph.graph import Graph -from dify_graph.graph_engine import GraphEngine, GraphEngineConfig -from dify_graph.graph_engine.command_channels import InMemoryChannel -from dify_graph.graph_engine.layers import DebugLoggingLayer, ExecutionLimitsLayer -from dify_graph.graph_engine.layers.base import GraphEngineLayer -from dify_graph.graph_engine.protocols.command_channel import CommandChannel -from dify_graph.graph_events import GraphEngineEvent, GraphNodeEventBase, GraphRunFailedEvent -from dify_graph.nodes import BuiltinNodeTypes -from dify_graph.nodes.base.node import Node -from dify_graph.runtime import ChildGraphNotFoundError, GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable -from dify_graph.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool +from core.workflow.node_factory import DifyNodeFactory, is_start_node_type, resolve_workflow_node_class +from core.workflow.system_variables import ( + default_system_variables, + get_node_creation_preload_selectors, + inject_default_system_variable_mappings, + preload_node_creation_variables, +) +from core.workflow.variable_pool_initializer import add_node_inputs_to_pool, add_variables_to_pool +from core.workflow.variable_prefixes import ENVIRONMENT_VARIABLE_NODE_ID from extensions.otel.runtime import is_instrument_flag_enabled from factories import file_factory +from graphon.entities import GraphInitParams +from graphon.entities.graph_config import NodeConfigDictAdapter +from graphon.errors import WorkflowNodeRunFailedError +from graphon.file.models import File +from graphon.graph import Graph +from graphon.graph_engine import GraphEngine, GraphEngineConfig +from graphon.graph_engine.command_channels import InMemoryChannel +from graphon.graph_engine.layers import DebugLoggingLayer, ExecutionLimitsLayer +from graphon.graph_engine.protocols.command_channel import CommandChannel +from graphon.graph_events import GraphEngineEvent, GraphNodeEventBase, GraphRunFailedEvent +from graphon.nodes import BuiltinNodeTypes +from graphon.nodes.base.node import Node +from graphon.runtime import ChildGraphNotFoundError, GraphRuntimeState, VariablePool +from graphon.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool from models.workflow import Workflow logger = logging.getLogger(__name__) +_file_access_controller = DatabaseFileAccessController() class _WorkflowChildEngineBuilder: @@ -59,16 +67,22 @@ class _WorkflowChildEngineBuilder: *, workflow_id: str, graph_init_params: GraphInitParams, - graph_runtime_state: GraphRuntimeState, - graph_config: Mapping[str, Any], + parent_graph_runtime_state: GraphRuntimeState, root_node_id: str, - layers: Sequence[object] = (), + variable_pool: VariablePool | None = None, ) -> GraphEngine: + """Build a child engine with a fresh runtime state and only child-safe layers.""" + child_graph_runtime_state = GraphRuntimeState( + variable_pool=variable_pool if variable_pool is not None else parent_graph_runtime_state.variable_pool, + start_at=time.perf_counter(), + execution_context=parent_graph_runtime_state.execution_context, + ) node_factory = DifyNodeFactory( graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, + graph_runtime_state=child_graph_runtime_state, ) + graph_config = graph_init_params.graph_config has_root_node = self._has_node_id(graph_config=graph_config, node_id=root_node_id) if has_root_node is False: raise ChildGraphNotFoundError(f"child graph root node '{root_node_id}' not found") @@ -79,17 +93,17 @@ class _WorkflowChildEngineBuilder: root_node_id=root_node_id, ) + command_channel = InMemoryChannel() + config = GraphEngineConfig() child_engine = GraphEngine( workflow_id=workflow_id, graph=child_graph, - graph_runtime_state=graph_runtime_state, - command_channel=InMemoryChannel(), - config=GraphEngineConfig(), + graph_runtime_state=child_graph_runtime_state, + command_channel=command_channel, + config=config, child_engine_builder=self, ) child_engine.layer(LLMQuotaLayer()) - for layer in layers: - child_engine.layer(cast(GraphEngineLayer, layer)) return child_engine @@ -136,6 +150,8 @@ class WorkflowEntry: command_channel = InMemoryChannel() self.command_channel = command_channel + execution_context = capture_current_context() + graph_runtime_state.execution_context = execution_context self._child_engine_builder = _WorkflowChildEngineBuilder() self.graph_engine = GraphEngine( workflow_id=workflow_id, @@ -212,6 +228,8 @@ class WorkflowEntry: # Get node type node_type = node_config_data.type + node_version = str(node_config_data.version) + node_cls = resolve_workflow_node_class(node_type=node_type, node_version=node_version) # init graph init params and runtime state graph_init_params = GraphInitParams( @@ -226,15 +244,23 @@ class WorkflowEntry: ), call_depth=0, ) - graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) - - # init workflow run state - node_factory = DifyNodeFactory( - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, + graph_runtime_state = GraphRuntimeState( + variable_pool=variable_pool, + start_at=time.perf_counter(), + execution_context=capture_current_context(), + ) + + if is_start_node_type(node_type): + add_node_inputs_to_pool(variable_pool, node_id=node_id, inputs=user_inputs) + + preload_node_creation_variables( + variable_loader=variable_loader, + variable_pool=variable_pool, + selectors=get_node_creation_preload_selectors( + node_type=node_type, + node_data=node_config_data, + ), ) - node = node_factory.create_node(node_config) - node_cls = type(node) try: # variable selector to variable mapping @@ -243,6 +269,12 @@ class WorkflowEntry: ) except NotImplementedError: variable_mapping = {} + variable_mapping = inject_default_system_variable_mappings( + node_id=node_id, + node_type=node_type, + node_data=node_config_data, + variable_mapping=variable_mapping, + ) # Loading missing variable from draft var here, and set it into # variable_pool. @@ -260,6 +292,13 @@ class WorkflowEntry: tenant_id=workflow.tenant_id, ) + # init workflow run state + node_factory = DifyNodeFactory( + graph_init_params=graph_init_params, + graph_runtime_state=graph_runtime_state, + ) + node = node_factory.create_node(node_config) + try: generator = cls._traced_node_run(node) except Exception as e: @@ -347,11 +386,8 @@ class WorkflowEntry: raise ValueError(f"Node class not found for node type {node_type}") # init variable pool - variable_pool = VariablePool( - system_variables=SystemVariable.default(), - user_inputs={}, - environment_variables=[], - ) + variable_pool = VariablePool() + add_variables_to_pool(variable_pool, default_system_variables()) # init graph init params and runtime state graph_init_params = GraphInitParams( @@ -366,7 +402,11 @@ class WorkflowEntry: ), call_depth=0, ) - graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) + graph_runtime_state = GraphRuntimeState( + variable_pool=variable_pool, + start_at=time.perf_counter(), + execution_context=capture_current_context(), + ) # init workflow run state node_config = NodeConfigDictAdapter.validate_python({"id": node_id, "data": node_data}) @@ -384,6 +424,12 @@ class WorkflowEntry: ) except NotImplementedError: variable_mapping = {} + variable_mapping = inject_default_system_variable_mappings( + node_id=node_id, + node_type=node_type, + node_data=node_data, + variable_mapping=variable_mapping, + ) cls.mapping_user_inputs_to_variable_pool( variable_mapping=variable_mapping, @@ -477,13 +523,21 @@ class WorkflowEntry: continue if isinstance(input_value, dict) and "type" in input_value and "transfer_method" in input_value: - input_value = file_factory.build_from_mapping(mapping=input_value, tenant_id=tenant_id) + input_value = file_factory.build_from_mapping( + mapping=input_value, + tenant_id=tenant_id, + access_controller=_file_access_controller, + ) if ( isinstance(input_value, list) and all(isinstance(item, dict) for item in input_value) and all("type" in item and "transfer_method" in item for item in input_value) ): - input_value = file_factory.build_from_mappings(mappings=input_value, tenant_id=tenant_id) + input_value = file_factory.build_from_mappings( + mappings=input_value, + tenant_id=tenant_id, + access_controller=_file_access_controller, + ) # append variable and value to variable pool if variable_node_id != ENVIRONMENT_VARIABLE_NODE_ID: diff --git a/api/core/workflow/workflow_run_outputs.py b/api/core/workflow/workflow_run_outputs.py new file mode 100644 index 0000000000..bd89f7c441 --- /dev/null +++ b/api/core/workflow/workflow_run_outputs.py @@ -0,0 +1,18 @@ +from collections.abc import Mapping +from typing import Any + +from graphon.enums import BuiltinNodeTypes, NodeType + + +def project_node_outputs_for_workflow_run( + *, + node_type: NodeType, + inputs: Mapping[str, Any], + outputs: Mapping[str, Any], +) -> dict[str, Any]: + """Project internal node outputs onto the workflow-run public contract.""" + + if node_type == BuiltinNodeTypes.START: + return dict(inputs) + + return dict(outputs) diff --git a/api/dify_graph/context/__init__.py b/api/dify_graph/context/__init__.py deleted file mode 100644 index 103f526bec..0000000000 --- a/api/dify_graph/context/__init__.py +++ /dev/null @@ -1,34 +0,0 @@ -""" -Execution Context - Context management for workflow execution. - -This package provides Flask-independent context management for workflow -execution in multi-threaded environments. -""" - -from dify_graph.context.execution_context import ( - AppContext, - ContextProviderNotFoundError, - ExecutionContext, - IExecutionContext, - NullAppContext, - capture_current_context, - read_context, - register_context, - register_context_capturer, - reset_context_provider, -) -from dify_graph.context.models import SandboxContext - -__all__ = [ - "AppContext", - "ContextProviderNotFoundError", - "ExecutionContext", - "IExecutionContext", - "NullAppContext", - "SandboxContext", - "capture_current_context", - "read_context", - "register_context", - "register_context_capturer", - "reset_context_provider", -] diff --git a/api/dify_graph/conversation_variable_updater.py b/api/dify_graph/conversation_variable_updater.py deleted file mode 100644 index 17b19f2502..0000000000 --- a/api/dify_graph/conversation_variable_updater.py +++ /dev/null @@ -1,39 +0,0 @@ -import abc -from typing import Protocol - -from dify_graph.variables import VariableBase - - -class ConversationVariableUpdater(Protocol): - """ - ConversationVariableUpdater defines an abstraction for updating conversation variable values. - - It is intended for use by `v1.VariableAssignerNode` and `v2.VariableAssignerNode` when updating - conversation variables. - - Implementations may choose to batch updates. If batching is used, the `flush` method - should be implemented to persist buffered changes, and `update` - should handle buffering accordingly. - - Note: Since implementations may buffer updates, instances of ConversationVariableUpdater - are not thread-safe. Each VariableAssignerNode should create its own instance during execution. - """ - - @abc.abstractmethod - def update(self, conversation_id: str, variable: "VariableBase"): - """ - Updates the value of the specified conversation variable in the underlying storage. - - :param conversation_id: The ID of the conversation to update. Typically references `ConversationVariable.id`. - :param variable: The `VariableBase` instance containing the updated value. - """ - pass - - @abc.abstractmethod - def flush(self): - """ - Flushes all pending updates to the underlying storage system. - - If the implementation does not buffer updates, this method can be a no-op. - """ - pass diff --git a/api/dify_graph/file/constants.py b/api/dify_graph/file/constants.py deleted file mode 100644 index 0665ed7e0d..0000000000 --- a/api/dify_graph/file/constants.py +++ /dev/null @@ -1,11 +0,0 @@ -from typing import Any - -# TODO(QuantumGhost): Refactor variable type identification. Instead of directly -# comparing `dify_model_identity` with constants throughout the codebase, extract -# this logic into a dedicated function. This would encapsulate the implementation -# details of how different variable types are identified. -FILE_MODEL_IDENTITY = "__dify__file__" - - -def maybe_file_object(o: Any) -> bool: - return isinstance(o, dict) and o.get("dify_model_identity") == FILE_MODEL_IDENTITY diff --git a/api/dify_graph/file/helpers.py b/api/dify_graph/file/helpers.py deleted file mode 100644 index 310cb1310b..0000000000 --- a/api/dify_graph/file/helpers.py +++ /dev/null @@ -1,92 +0,0 @@ -from __future__ import annotations - -import base64 -import hashlib -import hmac -import os -import time -import urllib.parse - -from .runtime import get_workflow_file_runtime - - -def get_signed_file_url(upload_file_id: str, as_attachment: bool = False, for_external: bool = True) -> str: - runtime = get_workflow_file_runtime() - base_url = runtime.files_url if for_external else (runtime.internal_files_url or runtime.files_url) - url = f"{base_url}/files/{upload_file_id}/file-preview" - - timestamp = str(int(time.time())) - nonce = os.urandom(16).hex() - key = runtime.secret_key.encode() - msg = f"file-preview|{upload_file_id}|{timestamp}|{nonce}" - sign = hmac.new(key, msg.encode(), hashlib.sha256).digest() - encoded_sign = base64.urlsafe_b64encode(sign).decode() - query: dict[str, str] = {"timestamp": timestamp, "nonce": nonce, "sign": encoded_sign} - if as_attachment: - query["as_attachment"] = "true" - query_string = urllib.parse.urlencode(query) - - return f"{url}?{query_string}" - - -def get_signed_file_url_for_plugin(filename: str, mimetype: str, tenant_id: str, user_id: str) -> str: - runtime = get_workflow_file_runtime() - # Plugin access should use internal URL for Docker network communication. - base_url = runtime.internal_files_url or runtime.files_url - url = f"{base_url}/files/upload/for-plugin" - timestamp = str(int(time.time())) - nonce = os.urandom(16).hex() - key = runtime.secret_key.encode() - msg = f"upload|{filename}|{mimetype}|{tenant_id}|{user_id}|{timestamp}|{nonce}" - sign = hmac.new(key, msg.encode(), hashlib.sha256).digest() - encoded_sign = base64.urlsafe_b64encode(sign).decode() - return f"{url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}&user_id={user_id}&tenant_id={tenant_id}" - - -def get_signed_tool_file_url(tool_file_id: str, extension: str, for_external: bool = True) -> str: - runtime = get_workflow_file_runtime() - return runtime.sign_tool_file(tool_file_id=tool_file_id, extension=extension, for_external=for_external) - - -def verify_plugin_file_signature( - *, filename: str, mimetype: str, tenant_id: str, user_id: str, timestamp: str, nonce: str, sign: str -) -> bool: - runtime = get_workflow_file_runtime() - data_to_sign = f"upload|{filename}|{mimetype}|{tenant_id}|{user_id}|{timestamp}|{nonce}" - secret_key = runtime.secret_key.encode() - recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest() - recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode() - - if sign != recalculated_encoded_sign: - return False - - current_time = int(time.time()) - return current_time - int(timestamp) <= runtime.files_access_timeout - - -def verify_image_signature(*, upload_file_id: str, timestamp: str, nonce: str, sign: str) -> bool: - runtime = get_workflow_file_runtime() - data_to_sign = f"image-preview|{upload_file_id}|{timestamp}|{nonce}" - secret_key = runtime.secret_key.encode() - recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest() - recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode() - - if sign != recalculated_encoded_sign: - return False - - current_time = int(time.time()) - return current_time - int(timestamp) <= runtime.files_access_timeout - - -def verify_file_signature(*, upload_file_id: str, timestamp: str, nonce: str, sign: str) -> bool: - runtime = get_workflow_file_runtime() - data_to_sign = f"file-preview|{upload_file_id}|{timestamp}|{nonce}" - secret_key = runtime.secret_key.encode() - recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest() - recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode() - - if sign != recalculated_encoded_sign: - return False - - current_time = int(time.time()) - return current_time - int(timestamp) <= runtime.files_access_timeout diff --git a/api/dify_graph/file/protocols.py b/api/dify_graph/file/protocols.py deleted file mode 100644 index 24cbb42735..0000000000 --- a/api/dify_graph/file/protocols.py +++ /dev/null @@ -1,43 +0,0 @@ -from __future__ import annotations - -from collections.abc import Generator -from typing import Protocol - - -class HttpResponseProtocol(Protocol): - """Subset of response behavior needed by workflow file helpers.""" - - @property - def content(self) -> bytes: ... - - def raise_for_status(self) -> object: ... - - -class WorkflowFileRuntimeProtocol(Protocol): - """Runtime dependencies required by ``dify_graph.file``. - - Implementations are expected to be provided by integration layers (for example, - ``core.app.workflow.file_runtime``) so the workflow package avoids importing - application infrastructure modules directly. - """ - - @property - def files_url(self) -> str: ... - - @property - def internal_files_url(self) -> str | None: ... - - @property - def secret_key(self) -> str: ... - - @property - def files_access_timeout(self) -> int: ... - - @property - def multimodal_send_format(self) -> str: ... - - def http_get(self, url: str, *, follow_redirects: bool = True) -> HttpResponseProtocol: ... - - def storage_load(self, path: str, *, stream: bool = False) -> bytes | Generator: ... - - def sign_tool_file(self, *, tool_file_id: str, extension: str, for_external: bool = True) -> str: ... diff --git a/api/dify_graph/model_runtime/model_providers/__base/moderation_model.py b/api/dify_graph/model_runtime/model_providers/__base/moderation_model.py deleted file mode 100644 index 5fa3d1634b..0000000000 --- a/api/dify_graph/model_runtime/model_providers/__base/moderation_model.py +++ /dev/null @@ -1,45 +0,0 @@ -import time - -from pydantic import ConfigDict - -from dify_graph.model_runtime.entities.model_entities import ModelType -from dify_graph.model_runtime.model_providers.__base.ai_model import AIModel - - -class ModerationModel(AIModel): - """ - Model class for moderation model. - """ - - model_type: ModelType = ModelType.MODERATION - - # pydantic configs - model_config = ConfigDict(protected_namespaces=()) - - def invoke(self, model: str, credentials: dict, text: str, user: str | None = None) -> bool: - """ - Invoke moderation model - - :param model: model name - :param credentials: model credentials - :param text: text to moderate - :param user: unique user id - :return: false if text is safe, true otherwise - """ - self.started_at = time.perf_counter() - - try: - from core.plugin.impl.model import PluginModelClient - - plugin_model_manager = PluginModelClient() - return plugin_model_manager.invoke_moderation( - tenant_id=self.tenant_id, - user_id=user or "unknown", - plugin_id=self.plugin_id, - provider=self.provider_name, - model=model, - credentials=credentials, - text=text, - ) - except Exception as e: - raise self._transform_invoke_error(e) diff --git a/api/dify_graph/model_runtime/model_providers/__base/speech2text_model.py b/api/dify_graph/model_runtime/model_providers/__base/speech2text_model.py deleted file mode 100644 index e69069a85d..0000000000 --- a/api/dify_graph/model_runtime/model_providers/__base/speech2text_model.py +++ /dev/null @@ -1,43 +0,0 @@ -from typing import IO - -from pydantic import ConfigDict - -from dify_graph.model_runtime.entities.model_entities import ModelType -from dify_graph.model_runtime.model_providers.__base.ai_model import AIModel - - -class Speech2TextModel(AIModel): - """ - Model class for speech2text model. - """ - - model_type: ModelType = ModelType.SPEECH2TEXT - - # pydantic configs - model_config = ConfigDict(protected_namespaces=()) - - def invoke(self, model: str, credentials: dict, file: IO[bytes], user: str | None = None) -> str: - """ - Invoke speech to text model - - :param model: model name - :param credentials: model credentials - :param file: audio file - :param user: unique user id - :return: text for given audio file - """ - try: - from core.plugin.impl.model import PluginModelClient - - plugin_model_manager = PluginModelClient() - return plugin_model_manager.invoke_speech_to_text( - tenant_id=self.tenant_id, - user_id=user or "unknown", - plugin_id=self.plugin_id, - provider=self.provider_name, - model=model, - credentials=credentials, - file=file, - ) - except Exception as e: - raise self._transform_invoke_error(e) diff --git a/api/dify_graph/model_runtime/model_providers/model_provider_factory.py b/api/dify_graph/model_runtime/model_providers/model_provider_factory.py deleted file mode 100644 index de0677a348..0000000000 --- a/api/dify_graph/model_runtime/model_providers/model_provider_factory.py +++ /dev/null @@ -1,387 +0,0 @@ -from __future__ import annotations - -import hashlib -import logging -from collections.abc import Sequence -from threading import Lock - -from pydantic import ValidationError -from redis import RedisError - -import contexts -from configs import dify_config -from core.plugin.entities.plugin_daemon import PluginModelProviderEntity -from dify_graph.model_runtime.entities.model_entities import AIModelEntity, ModelType -from dify_graph.model_runtime.entities.provider_entities import ProviderConfig, ProviderEntity, SimpleProviderEntity -from dify_graph.model_runtime.model_providers.__base.ai_model import AIModel -from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from dify_graph.model_runtime.model_providers.__base.moderation_model import ModerationModel -from dify_graph.model_runtime.model_providers.__base.rerank_model import RerankModel -from dify_graph.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel -from dify_graph.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel -from dify_graph.model_runtime.model_providers.__base.tts_model import TTSModel -from dify_graph.model_runtime.schema_validators.model_credential_schema_validator import ModelCredentialSchemaValidator -from dify_graph.model_runtime.schema_validators.provider_credential_schema_validator import ( - ProviderCredentialSchemaValidator, -) -from extensions.ext_redis import redis_client -from models.provider_ids import ModelProviderID - -logger = logging.getLogger(__name__) - - -class ModelProviderFactory: - def __init__(self, tenant_id: str): - from core.plugin.impl.model import PluginModelClient - - self.tenant_id = tenant_id - self.plugin_model_manager = PluginModelClient() - - def get_providers(self) -> Sequence[ProviderEntity]: - """ - Get all providers - :return: list of providers - """ - # FIXME(-LAN-): Removed position map sorting since providers are fetched from plugin server - # The plugin server should return providers in the desired order - plugin_providers = self.get_plugin_model_providers() - return [provider.declaration for provider in plugin_providers] - - def get_plugin_model_providers(self) -> Sequence[PluginModelProviderEntity]: - """ - Get all plugin model providers - :return: list of plugin model providers - """ - # check if context is set - try: - contexts.plugin_model_providers.get() - except LookupError: - contexts.plugin_model_providers.set(None) - contexts.plugin_model_providers_lock.set(Lock()) - - with contexts.plugin_model_providers_lock.get(): - plugin_model_providers = contexts.plugin_model_providers.get() - if plugin_model_providers is not None: - return plugin_model_providers - - plugin_model_providers = [] - contexts.plugin_model_providers.set(plugin_model_providers) - - # Fetch plugin model providers - plugin_providers = self.plugin_model_manager.fetch_model_providers(self.tenant_id) - - for provider in plugin_providers: - provider.declaration.provider = provider.plugin_id + "/" + provider.declaration.provider - plugin_model_providers.append(provider) - - return plugin_model_providers - - def get_provider_schema(self, provider: str) -> ProviderEntity: - """ - Get provider schema - :param provider: provider name - :return: provider schema - """ - plugin_model_provider_entity = self.get_plugin_model_provider(provider=provider) - return plugin_model_provider_entity.declaration - - def get_plugin_model_provider(self, provider: str) -> PluginModelProviderEntity: - """ - Get plugin model provider - :param provider: provider name - :return: provider schema - """ - if "/" not in provider: - provider = str(ModelProviderID(provider)) - - # fetch plugin model providers - plugin_model_provider_entities = self.get_plugin_model_providers() - - # get the provider - plugin_model_provider_entity = next( - (p for p in plugin_model_provider_entities if p.declaration.provider == provider), - None, - ) - - if not plugin_model_provider_entity: - raise ValueError(f"Invalid provider: {provider}") - - return plugin_model_provider_entity - - def provider_credentials_validate(self, *, provider: str, credentials: dict): - """ - Validate provider credentials - - :param provider: provider name - :param credentials: provider credentials, credentials form defined in `provider_credential_schema`. - :return: - """ - # fetch plugin model provider - plugin_model_provider_entity = self.get_plugin_model_provider(provider=provider) - - # get provider_credential_schema and validate credentials according to the rules - provider_credential_schema = plugin_model_provider_entity.declaration.provider_credential_schema - if not provider_credential_schema: - raise ValueError(f"Provider {provider} does not have provider_credential_schema") - - # validate provider credential schema - validator = ProviderCredentialSchemaValidator(provider_credential_schema) - filtered_credentials = validator.validate_and_filter(credentials) - - # validate the credentials, raise exception if validation failed - self.plugin_model_manager.validate_provider_credentials( - tenant_id=self.tenant_id, - user_id="unknown", - plugin_id=plugin_model_provider_entity.plugin_id, - provider=plugin_model_provider_entity.provider, - credentials=filtered_credentials, - ) - - return filtered_credentials - - def model_credentials_validate(self, *, provider: str, model_type: ModelType, model: str, credentials: dict): - """ - Validate model credentials - - :param provider: provider name - :param model_type: model type - :param model: model name - :param credentials: model credentials, credentials form defined in `model_credential_schema`. - :return: - """ - # fetch plugin model provider - plugin_model_provider_entity = self.get_plugin_model_provider(provider=provider) - - # get model_credential_schema and validate credentials according to the rules - model_credential_schema = plugin_model_provider_entity.declaration.model_credential_schema - if not model_credential_schema: - raise ValueError(f"Provider {provider} does not have model_credential_schema") - - # validate model credential schema - validator = ModelCredentialSchemaValidator(model_type, model_credential_schema) - filtered_credentials = validator.validate_and_filter(credentials) - - # call validate_credentials method of model type to validate credentials, raise exception if validation failed - self.plugin_model_manager.validate_model_credentials( - tenant_id=self.tenant_id, - user_id="unknown", - plugin_id=plugin_model_provider_entity.plugin_id, - provider=plugin_model_provider_entity.provider, - model_type=model_type.value, - model=model, - credentials=filtered_credentials, - ) - - return filtered_credentials - - def get_model_schema( - self, *, provider: str, model_type: ModelType, model: str, credentials: dict | None - ) -> AIModelEntity | None: - """ - Get model schema - """ - plugin_id, provider_name = self.get_plugin_id_and_provider_name_from_provider(provider) - cache_key = f"{self.tenant_id}:{plugin_id}:{provider_name}:{model_type.value}:{model}" - sorted_credentials = sorted(credentials.items()) if credentials else [] - cache_key += ":".join([hashlib.md5(f"{k}:{v}".encode()).hexdigest() for k, v in sorted_credentials]) - - cached_schema_json = None - try: - cached_schema_json = redis_client.get(cache_key) - except (RedisError, RuntimeError) as exc: - logger.warning( - "Failed to read plugin model schema cache for model %s: %s", - model, - str(exc), - exc_info=True, - ) - if cached_schema_json: - try: - return AIModelEntity.model_validate_json(cached_schema_json) - except ValidationError: - logger.warning( - "Failed to validate cached plugin model schema for model %s", - model, - exc_info=True, - ) - try: - redis_client.delete(cache_key) - except (RedisError, RuntimeError) as exc: - logger.warning( - "Failed to delete invalid plugin model schema cache for model %s: %s", - model, - str(exc), - exc_info=True, - ) - - schema = self.plugin_model_manager.get_model_schema( - tenant_id=self.tenant_id, - user_id="unknown", - plugin_id=plugin_id, - provider=provider_name, - model_type=model_type.value, - model=model, - credentials=credentials or {}, - ) - - if schema: - try: - redis_client.setex(cache_key, dify_config.PLUGIN_MODEL_SCHEMA_CACHE_TTL, schema.model_dump_json()) - except (RedisError, RuntimeError) as exc: - logger.warning( - "Failed to write plugin model schema cache for model %s: %s", - model, - str(exc), - exc_info=True, - ) - - return schema - - def get_models( - self, - *, - provider: str | None = None, - model_type: ModelType | None = None, - provider_configs: list[ProviderConfig] | None = None, - ) -> list[SimpleProviderEntity]: - """ - Get all models for given model type - - :param provider: provider name - :param model_type: model type - :param provider_configs: list of provider configs - :return: list of models - """ - provider_configs = provider_configs or [] - - # scan all providers - plugin_model_provider_entities = self.get_plugin_model_providers() - - # traverse all model_provider_extensions - providers = [] - for plugin_model_provider_entity in plugin_model_provider_entities: - # filter by provider if provider is present - if provider and plugin_model_provider_entity.declaration.provider != provider: - continue - - # get provider schema - provider_schema = plugin_model_provider_entity.declaration - - model_types = provider_schema.supported_model_types - if model_type: - if model_type not in model_types: - continue - - model_types = [model_type] - - all_model_type_models = [] - for model_schema in provider_schema.models: - if model_schema.model_type != model_type: - continue - - all_model_type_models.append(model_schema) - - simple_provider_schema = provider_schema.to_simple_provider() - if model_type: - simple_provider_schema.models = all_model_type_models - - providers.append(simple_provider_schema) - - return providers - - def get_model_type_instance(self, provider: str, model_type: ModelType) -> AIModel: - """ - Get model type instance by provider name and model type - :param provider: provider name - :param model_type: model type - :return: model type instance - """ - plugin_id, provider_name = self.get_plugin_id_and_provider_name_from_provider(provider) - init_params = { - "tenant_id": self.tenant_id, - "plugin_id": plugin_id, - "provider_name": provider_name, - "plugin_model_provider": self.get_plugin_model_provider(provider), - } - - if model_type == ModelType.LLM: - return LargeLanguageModel.model_validate(init_params) - elif model_type == ModelType.TEXT_EMBEDDING: - return TextEmbeddingModel.model_validate(init_params) - elif model_type == ModelType.RERANK: - return RerankModel.model_validate(init_params) - elif model_type == ModelType.SPEECH2TEXT: - return Speech2TextModel.model_validate(init_params) - elif model_type == ModelType.MODERATION: - return ModerationModel.model_validate(init_params) - elif model_type == ModelType.TTS: - return TTSModel.model_validate(init_params) - - raise ValueError(f"Unsupported model type: {model_type}") - - def get_provider_icon(self, provider: str, icon_type: str, lang: str) -> tuple[bytes, str]: - """ - Get provider icon - :param provider: provider name - :param icon_type: icon type (icon_small or icon_small_dark) - :param lang: language (zh_Hans or en_US) - :return: provider icon - """ - # get the provider schema - provider_schema = self.get_provider_schema(provider) - - if icon_type.lower() == "icon_small": - if not provider_schema.icon_small: - raise ValueError(f"Provider {provider} does not have small icon.") - - if lang.lower() == "zh_hans": - file_name = provider_schema.icon_small.zh_Hans - else: - file_name = provider_schema.icon_small.en_US - elif icon_type.lower() == "icon_small_dark": - if not provider_schema.icon_small_dark: - raise ValueError(f"Provider {provider} does not have small dark icon.") - - if lang.lower() == "zh_hans": - file_name = provider_schema.icon_small_dark.zh_Hans - else: - file_name = provider_schema.icon_small_dark.en_US - else: - raise ValueError(f"Unsupported icon type: {icon_type}.") - - if not file_name: - raise ValueError(f"Provider {provider} does not have icon.") - - image_mime_types = { - "jpg": "image/jpeg", - "jpeg": "image/jpeg", - "png": "image/png", - "gif": "image/gif", - "bmp": "image/bmp", - "tiff": "image/tiff", - "tif": "image/tiff", - "webp": "image/webp", - "svg": "image/svg+xml", - "ico": "image/vnd.microsoft.icon", - "heif": "image/heif", - "heic": "image/heic", - } - - extension = file_name.split(".")[-1] - mime_type = image_mime_types.get(extension, "image/png") - - # get icon bytes from plugin asset manager - from core.plugin.impl.asset import PluginAssetManager - - plugin_asset_manager = PluginAssetManager() - return plugin_asset_manager.fetch_asset(tenant_id=self.tenant_id, id=file_name), mime_type - - def get_plugin_id_and_provider_name_from_provider(self, provider: str) -> tuple[str, str]: - """ - Get plugin id and provider name from provider name - :param provider: provider name - :return: plugin id and provider name - """ - - provider_id = ModelProviderID(provider) - return provider_id.plugin_id, provider_id.provider_name diff --git a/api/dify_graph/nodes/__init__.py b/api/dify_graph/nodes/__init__.py deleted file mode 100644 index 0223149bb8..0000000000 --- a/api/dify_graph/nodes/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from dify_graph.enums import BuiltinNodeTypes - -__all__ = ["BuiltinNodeTypes"] diff --git a/api/dify_graph/nodes/human_input/entities.py b/api/dify_graph/nodes/human_input/entities.py deleted file mode 100644 index 2a33b4a0a8..0000000000 --- a/api/dify_graph/nodes/human_input/entities.py +++ /dev/null @@ -1,424 +0,0 @@ -""" -Human Input node entities. -""" - -import re -import uuid -from collections.abc import Mapping, Sequence -from datetime import datetime, timedelta -from typing import Annotated, Any, ClassVar, Literal, Self - -import bleach -import markdown -from pydantic import BaseModel, Field, field_validator, model_validator - -from dify_graph.entities.base_node_data import BaseNodeData -from dify_graph.enums import BuiltinNodeTypes, NodeType -from dify_graph.nodes.base.variable_template_parser import VariableTemplateParser -from dify_graph.runtime import VariablePool -from dify_graph.variables.consts import SELECTORS_LENGTH - -from .enums import ButtonStyle, DeliveryMethodType, EmailRecipientType, FormInputType, PlaceholderType, TimeoutUnit - -_OUTPUT_VARIABLE_PATTERN = re.compile(r"\{\{#\$output\.(?P[a-zA-Z_][a-zA-Z0-9_]{0,29})#\}\}") - - -class _WebAppDeliveryConfig(BaseModel): - """Configuration for webapp delivery method.""" - - pass # Empty for webapp delivery - - -class MemberRecipient(BaseModel): - """Member recipient for email delivery.""" - - type: Literal[EmailRecipientType.MEMBER] = EmailRecipientType.MEMBER - user_id: str - - -class ExternalRecipient(BaseModel): - """External recipient for email delivery.""" - - type: Literal[EmailRecipientType.EXTERNAL] = EmailRecipientType.EXTERNAL - email: str - - -EmailRecipient = Annotated[MemberRecipient | ExternalRecipient, Field(discriminator="type")] - - -class EmailRecipients(BaseModel): - """Email recipients configuration.""" - - # When true, recipients are the union of all workspace members and external items. - # Member items are ignored because they are already covered by the workspace scope. - # De-duplication is applied by email, with member recipients taking precedence. - whole_workspace: bool = False - items: list[EmailRecipient] = Field(default_factory=list) - - -class EmailDeliveryConfig(BaseModel): - """Configuration for email delivery method.""" - - URL_PLACEHOLDER: ClassVar[str] = "{{#url#}}" - _SUBJECT_NEWLINE_PATTERN: ClassVar[re.Pattern[str]] = re.compile(r"[\r\n]+") - _ALLOWED_HTML_TAGS: ClassVar[list[str]] = [ - "a", - "blockquote", - "br", - "code", - "em", - "h1", - "h2", - "h3", - "h4", - "h5", - "h6", - "hr", - "li", - "ol", - "p", - "pre", - "strong", - "table", - "tbody", - "td", - "th", - "thead", - "tr", - "ul", - ] - _ALLOWED_HTML_ATTRIBUTES: ClassVar[dict[str, list[str]]] = { - "a": ["href", "title"], - "td": ["align"], - "th": ["align"], - } - _ALLOWED_PROTOCOLS: ClassVar[list[str]] = ["http", "https", "mailto"] - - recipients: EmailRecipients - - # the subject of email - subject: str - - # Body is the content of email.It may contain the speical placeholder `{{#url#}}`, which - # represent the url to submit the form. - # - # It may also reference the output variable of the previous node with the syntax - # `{{#.#}}`. - body: str - debug_mode: bool = False - - def with_debug_recipient(self, user_id: str | None) -> "EmailDeliveryConfig": - if user_id is None: - debug_recipients = EmailRecipients(whole_workspace=False, items=[]) - return self.model_copy(update={"recipients": debug_recipients}) - debug_recipients = EmailRecipients(whole_workspace=False, items=[MemberRecipient(user_id=user_id)]) - return self.model_copy(update={"recipients": debug_recipients}) - - @classmethod - def replace_url_placeholder(cls, body: str, url: str | None) -> str: - """Replace the url placeholder with provided value.""" - return body.replace(cls.URL_PLACEHOLDER, url or "") - - @classmethod - def render_body_template( - cls, - *, - body: str, - url: str | None, - variable_pool: VariablePool | None = None, - ) -> str: - """Render email body by replacing placeholders with runtime values.""" - templated_body = cls.replace_url_placeholder(body, url) - if variable_pool is None: - return templated_body - return variable_pool.convert_template(templated_body).text - - @classmethod - def render_markdown_body(cls, body: str) -> str: - """Render markdown to safe HTML for email delivery.""" - sanitized_markdown = bleach.clean( - body, - tags=[], - attributes={}, - strip=True, - strip_comments=True, - ) - rendered_html = markdown.markdown( - sanitized_markdown, - extensions=["nl2br", "tables"], - extension_configs={"tables": {"use_align_attribute": True}}, - ) - return bleach.clean( - rendered_html, - tags=cls._ALLOWED_HTML_TAGS, - attributes=cls._ALLOWED_HTML_ATTRIBUTES, - protocols=cls._ALLOWED_PROTOCOLS, - strip=True, - strip_comments=True, - ) - - @classmethod - def sanitize_subject(cls, subject: str) -> str: - """Sanitize email subject to plain text and prevent CRLF injection.""" - sanitized_subject = bleach.clean( - subject, - tags=[], - attributes={}, - strip=True, - strip_comments=True, - ) - sanitized_subject = cls._SUBJECT_NEWLINE_PATTERN.sub(" ", sanitized_subject) - return " ".join(sanitized_subject.split()) - - -class _DeliveryMethodBase(BaseModel): - """Base delivery method configuration.""" - - enabled: bool = True - id: uuid.UUID = Field(default_factory=uuid.uuid4) - - def extract_variable_selectors(self) -> Sequence[Sequence[str]]: - return () - - -class WebAppDeliveryMethod(_DeliveryMethodBase): - """Webapp delivery method configuration.""" - - type: Literal[DeliveryMethodType.WEBAPP] = DeliveryMethodType.WEBAPP - # The config field is not used currently. - config: _WebAppDeliveryConfig = Field(default_factory=_WebAppDeliveryConfig) - - -class EmailDeliveryMethod(_DeliveryMethodBase): - """Email delivery method configuration.""" - - type: Literal[DeliveryMethodType.EMAIL] = DeliveryMethodType.EMAIL - config: EmailDeliveryConfig - - def extract_variable_selectors(self) -> Sequence[Sequence[str]]: - variable_template_parser = VariableTemplateParser(template=self.config.body) - selectors: list[Sequence[str]] = [] - for variable_selector in variable_template_parser.extract_variable_selectors(): - value_selector = list(variable_selector.value_selector) - if len(value_selector) < SELECTORS_LENGTH: - continue - selectors.append(value_selector[:SELECTORS_LENGTH]) - return selectors - - -DeliveryChannelConfig = Annotated[WebAppDeliveryMethod | EmailDeliveryMethod, Field(discriminator="type")] - - -def apply_debug_email_recipient( - method: DeliveryChannelConfig, - *, - enabled: bool, - user_id: str | None, -) -> DeliveryChannelConfig: - if not enabled: - return method - if not isinstance(method, EmailDeliveryMethod): - return method - if not method.config.debug_mode: - return method - debug_config = method.config.with_debug_recipient(user_id) - return method.model_copy(update={"config": debug_config}) - - -class FormInputDefault(BaseModel): - """Default configuration for form inputs.""" - - # NOTE: Ideally, a discriminated union would be used to model - # FormInputDefault. However, the UI requires preserving the previous - # value when switching between `VARIABLE` and `CONSTANT` types. This - # necessitates retaining all fields, making a discriminated union unsuitable. - - type: PlaceholderType - - # The selector of default variable, used when `type` is `VARIABLE`. - selector: Sequence[str] = Field(default_factory=tuple) # - - # The value of the default, used when `type` is `CONSTANT`. - # TODO: How should we express JSON values? - value: str = "" - - @model_validator(mode="after") - def _validate_selector(self) -> Self: - if self.type == PlaceholderType.CONSTANT: - return self - if len(self.selector) < SELECTORS_LENGTH: - raise ValueError(f"the length of selector should be at least {SELECTORS_LENGTH}, selector={self.selector}") - return self - - -class FormInput(BaseModel): - """Form input definition.""" - - type: FormInputType - output_variable_name: str - default: FormInputDefault | None = None - - -_IDENTIFIER_PATTERN = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$") - - -class UserAction(BaseModel): - """User action configuration.""" - - # id is the identifier for this action. - # It also serves as the identifiers of output handle. - # - # The id must be a valid identifier (satisfy the _IDENTIFIER_PATTERN above.) - id: str = Field(max_length=20) - title: str = Field(max_length=20) - button_style: ButtonStyle = ButtonStyle.DEFAULT - - @field_validator("id") - @classmethod - def _validate_id(cls, value: str) -> str: - if not _IDENTIFIER_PATTERN.match(value): - raise ValueError( - f"'{value}' is not a valid identifier. It must start with a letter or underscore, " - f"and contain only letters, numbers, or underscores." - ) - return value - - -class HumanInputNodeData(BaseNodeData): - """Human Input node data.""" - - type: NodeType = BuiltinNodeTypes.HUMAN_INPUT - delivery_methods: list[DeliveryChannelConfig] = Field(default_factory=list) - form_content: str = "" - inputs: list[FormInput] = Field(default_factory=list) - user_actions: list[UserAction] = Field(default_factory=list) - timeout: int = 36 - timeout_unit: TimeoutUnit = TimeoutUnit.HOUR - - @field_validator("inputs") - @classmethod - def _validate_inputs(cls, inputs: list[FormInput]) -> list[FormInput]: - seen_names: set[str] = set() - for form_input in inputs: - name = form_input.output_variable_name - if name in seen_names: - raise ValueError(f"duplicated output_variable_name '{name}' in inputs") - seen_names.add(name) - return inputs - - @field_validator("user_actions") - @classmethod - def _validate_user_actions(cls, user_actions: list[UserAction]) -> list[UserAction]: - seen_ids: set[str] = set() - for action in user_actions: - action_id = action.id - if action_id in seen_ids: - raise ValueError(f"duplicated user action id '{action_id}'") - seen_ids.add(action_id) - return user_actions - - def is_webapp_enabled(self) -> bool: - for dm in self.delivery_methods: - if not dm.enabled: - continue - if dm.type == DeliveryMethodType.WEBAPP: - return True - return False - - def expiration_time(self, start_time: datetime) -> datetime: - if self.timeout_unit == TimeoutUnit.HOUR: - return start_time + timedelta(hours=self.timeout) - elif self.timeout_unit == TimeoutUnit.DAY: - return start_time + timedelta(days=self.timeout) - else: - raise AssertionError("unknown timeout unit.") - - def outputs_field_names(self) -> Sequence[str]: - field_names = [] - for match in _OUTPUT_VARIABLE_PATTERN.finditer(self.form_content): - field_names.append(match.group("field_name")) - return field_names - - def extract_variable_selector_to_variable_mapping(self, node_id: str) -> Mapping[str, Sequence[str]]: - variable_mappings: dict[str, Sequence[str]] = {} - - def _add_variable_selectors(selectors: Sequence[Sequence[str]]) -> None: - for selector in selectors: - if len(selector) < SELECTORS_LENGTH: - continue - qualified_variable_mapping_key = f"{node_id}.#{'.'.join(selector[:SELECTORS_LENGTH])}#" - variable_mappings[qualified_variable_mapping_key] = list(selector[:SELECTORS_LENGTH]) - - form_template_parser = VariableTemplateParser(template=self.form_content) - _add_variable_selectors( - [selector.value_selector for selector in form_template_parser.extract_variable_selectors()] - ) - for delivery_method in self.delivery_methods: - if not delivery_method.enabled: - continue - _add_variable_selectors(delivery_method.extract_variable_selectors()) - - for input in self.inputs: - default_value = input.default - if default_value is None: - continue - if default_value.type == PlaceholderType.CONSTANT: - continue - default_value_key = ".".join(default_value.selector) - qualified_variable_mapping_key = f"{node_id}.#{default_value_key}#" - variable_mappings[qualified_variable_mapping_key] = default_value.selector - - return variable_mappings - - def find_action_text(self, action_id: str) -> str: - """ - Resolve action display text by id. - """ - for action in self.user_actions: - if action.id == action_id: - return action.title - return action_id - - -class FormDefinition(BaseModel): - form_content: str - inputs: list[FormInput] = Field(default_factory=list) - user_actions: list[UserAction] = Field(default_factory=list) - rendered_content: str - expiration_time: datetime - - # this is used to store the resolved default values - default_values: dict[str, Any] = Field(default_factory=dict) - - # node_title records the title of the HumanInput node. - node_title: str | None = None - - # display_in_ui controls whether the form should be displayed in UI surfaces. - display_in_ui: bool | None = None - - -class HumanInputSubmissionValidationError(ValueError): - pass - - -def validate_human_input_submission( - *, - inputs: Sequence[FormInput], - user_actions: Sequence[UserAction], - selected_action_id: str, - form_data: Mapping[str, Any], -) -> None: - available_actions = {action.id for action in user_actions} - if selected_action_id not in available_actions: - raise HumanInputSubmissionValidationError(f"Invalid action: {selected_action_id}") - - provided_inputs = set(form_data.keys()) - missing_inputs = [ - form_input.output_variable_name - for form_input in inputs - if form_input.output_variable_name not in provided_inputs - ] - - if missing_inputs: - missing_list = ", ".join(missing_inputs) - raise HumanInputSubmissionValidationError(f"Missing required inputs: {missing_list}") diff --git a/api/dify_graph/nodes/llm/protocols.py b/api/dify_graph/nodes/llm/protocols.py deleted file mode 100644 index 9e95d341c9..0000000000 --- a/api/dify_graph/nodes/llm/protocols.py +++ /dev/null @@ -1,30 +0,0 @@ -from __future__ import annotations - -from collections.abc import Mapping -from typing import Any, Protocol - -from core.model_manager import ModelInstance - - -class CredentialsProvider(Protocol): - """Port for loading runtime credentials for a provider/model pair.""" - - def fetch(self, provider_name: str, model_name: str) -> dict[str, Any]: - """Return credentials for the target provider/model or raise a domain error.""" - ... - - -class ModelFactory(Protocol): - """Port for creating initialized LLM model instances for execution.""" - - def init_model_instance(self, provider_name: str, model_name: str) -> ModelInstance: - """Create a model instance that is ready for schema lookup and invocation.""" - ... - - -class TemplateRenderer(Protocol): - """Port for rendering prompt templates used by LLM-compatible nodes.""" - - def render_jinja2(self, *, template: str, inputs: Mapping[str, Any]) -> str: - """Render the given Jinja2 template into plain text.""" - ... diff --git a/api/dify_graph/nodes/template_transform/template_renderer.py b/api/dify_graph/nodes/template_transform/template_renderer.py deleted file mode 100644 index 9b679d4497..0000000000 --- a/api/dify_graph/nodes/template_transform/template_renderer.py +++ /dev/null @@ -1,41 +0,0 @@ -from __future__ import annotations - -from collections.abc import Mapping -from typing import Any, Protocol - -from dify_graph.nodes.code.code_node import WorkflowCodeExecutor -from dify_graph.nodes.code.entities import CodeLanguage - - -class TemplateRenderError(ValueError): - """Raised when rendering a Jinja2 template fails.""" - - -class Jinja2TemplateRenderer(Protocol): - """Render Jinja2 templates for template transform nodes.""" - - def render_template(self, template: str, variables: Mapping[str, Any]) -> str: - """Render a Jinja2 template with provided variables.""" - raise NotImplementedError - - -class CodeExecutorJinja2TemplateRenderer(Jinja2TemplateRenderer): - """Adapter that renders Jinja2 templates via CodeExecutor.""" - - _code_executor: WorkflowCodeExecutor - - def __init__(self, code_executor: WorkflowCodeExecutor) -> None: - self._code_executor = code_executor - - def render_template(self, template: str, variables: Mapping[str, Any]) -> str: - try: - result = self._code_executor.execute(language=CodeLanguage.JINJA2, code=template, inputs=variables) - except Exception as exc: - if self._code_executor.is_execution_error(exc): - raise TemplateRenderError(str(exc)) from exc - raise - - rendered = result.get("result") - if not isinstance(rendered, str): - raise TemplateRenderError("Template render result must be a string.") - return rendered diff --git a/api/dify_graph/repositories/__init__.py b/api/dify_graph/repositories/__init__.py deleted file mode 100644 index ef70eb09cc..0000000000 --- a/api/dify_graph/repositories/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -""" -Repository interfaces for data access. - -This package contains repository interfaces that define the contract -for accessing and manipulating data, regardless of the underlying -storage mechanism. -""" - -from dify_graph.repositories.workflow_node_execution_repository import OrderConfig, WorkflowNodeExecutionRepository - -__all__ = [ - "OrderConfig", - "WorkflowNodeExecutionRepository", -] diff --git a/api/dify_graph/repositories/human_input_form_repository.py b/api/dify_graph/repositories/human_input_form_repository.py deleted file mode 100644 index 88966831cb..0000000000 --- a/api/dify_graph/repositories/human_input_form_repository.py +++ /dev/null @@ -1,152 +0,0 @@ -import abc -import dataclasses -from collections.abc import Mapping, Sequence -from datetime import datetime -from typing import Any, Protocol - -from dify_graph.nodes.human_input.entities import DeliveryChannelConfig, HumanInputNodeData -from dify_graph.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus - - -class HumanInputError(Exception): - pass - - -class FormNotFoundError(HumanInputError): - pass - - -@dataclasses.dataclass -class FormCreateParams: - # app_id is the identifier for the app that the form belongs to. - # It is a string with uuid format. - app_id: str - # None when creating a delivery test form; set for runtime forms. - workflow_execution_id: str | None - - # node_id is the identifier for a specific - # node in the graph. - # - # TODO: for node inside loop / iteration, this would - # cause problems, as a single node may be executed multiple times. - node_id: str - - form_config: HumanInputNodeData - rendered_content: str - # Delivery methods already filtered by runtime context (invoke_from). - delivery_methods: Sequence[DeliveryChannelConfig] - # UI display flag computed by runtime context. - display_in_ui: bool - - # resolved_default_values saves the values for defaults with - # type = VARIABLE. - # - # For type = CONSTANT, the value is not stored inside `resolved_default_values` - resolved_default_values: Mapping[str, Any] - form_kind: HumanInputFormKind = HumanInputFormKind.RUNTIME - - # Force creating a console-only recipient for submission in Console. - console_recipient_required: bool = False - console_creator_account_id: str | None = None - # Force creating a backstage recipient for submission in Console. - backstage_recipient_required: bool = False - - -class HumanInputFormEntity(abc.ABC): - @property - @abc.abstractmethod - def id(self) -> str: - """id returns the identifer of the form.""" - pass - - @property - @abc.abstractmethod - def web_app_token(self) -> str | None: - """web_app_token returns the token for submission inside webapp. - - For console/debug execution, this may point to the console submission token - if the form is configured to require console delivery. - """ - - # TODO: what if the users are allowed to add multiple - # webapp delivery? - pass - - @property - @abc.abstractmethod - def recipients(self) -> list["HumanInputFormRecipientEntity"]: ... - - @property - @abc.abstractmethod - def rendered_content(self) -> str: - """Rendered markdown content associated with the form.""" - ... - - @property - @abc.abstractmethod - def selected_action_id(self) -> str | None: - """Identifier of the selected user action if the form has been submitted.""" - ... - - @property - @abc.abstractmethod - def submitted_data(self) -> Mapping[str, Any] | None: - """Submitted form data if available.""" - ... - - @property - @abc.abstractmethod - def submitted(self) -> bool: - """Whether the form has been submitted.""" - ... - - @property - @abc.abstractmethod - def status(self) -> HumanInputFormStatus: - """Current status of the form.""" - ... - - @property - @abc.abstractmethod - def expiration_time(self) -> datetime: - """When the form expires.""" - ... - - -class HumanInputFormRecipientEntity(abc.ABC): - @property - @abc.abstractmethod - def id(self) -> str: - """id returns the identifer of this recipient.""" - ... - - @property - @abc.abstractmethod - def token(self) -> str: - """token returns a random string used to submit form""" - ... - - -class HumanInputFormRepository(Protocol): - """ - Repository interface for HumanInputForm. - - This interface defines the contract for accessing and manipulating - HumanInputForm data, regardless of the underlying storage mechanism. - - Note: Domain-specific concepts like multi-tenancy (tenant_id), application context (app_id), - and other implementation details should be handled at the implementation level, not in - the core interface. This keeps the core domain model clean and independent of specific - application domains or deployment scenarios. - """ - - def get_form(self, workflow_execution_id: str, node_id: str) -> HumanInputFormEntity | None: - """Get the form created for a given human input node in a workflow execution. Returns - `None` if the form has not been created yet.""" - ... - - def create_form(self, params: FormCreateParams) -> HumanInputFormEntity: - """ - Create a human input form from form definition. - """ - ... diff --git a/api/dify_graph/repositories/workflow_execution_repository.py b/api/dify_graph/repositories/workflow_execution_repository.py deleted file mode 100644 index ef83f07649..0000000000 --- a/api/dify_graph/repositories/workflow_execution_repository.py +++ /dev/null @@ -1,30 +0,0 @@ -from typing import Protocol - -from dify_graph.entities import WorkflowExecution - - -class WorkflowExecutionRepository(Protocol): - """ - Repository interface for WorkflowExecution. - - This interface defines the contract for accessing and manipulating - WorkflowExecution data, regardless of the underlying storage mechanism. - - Note: Domain-specific concepts like multi-tenancy (tenant_id), application context (app_id), - and other implementation details should be handled at the implementation level, not in - the core interface. This keeps the core domain model clean and independent of specific - application domains or deployment scenarios. - """ - - def save(self, execution: WorkflowExecution): - """ - Save or update a WorkflowExecution instance. - - This method handles both creating new records and updating existing ones. - The implementation should determine whether to create or update based on - the execution's ID or other identifying fields. - - Args: - execution: The WorkflowExecution instance to save or update - """ - ... diff --git a/api/dify_graph/repositories/workflow_node_execution_repository.py b/api/dify_graph/repositories/workflow_node_execution_repository.py deleted file mode 100644 index e6c1c3e497..0000000000 --- a/api/dify_graph/repositories/workflow_node_execution_repository.py +++ /dev/null @@ -1,73 +0,0 @@ -from collections.abc import Sequence -from dataclasses import dataclass -from typing import Literal, Protocol - -from dify_graph.entities import WorkflowNodeExecution - - -@dataclass -class OrderConfig: - """Configuration for ordering NodeExecution instances.""" - - order_by: list[str] - order_direction: Literal["asc", "desc"] | None = None - - -class WorkflowNodeExecutionRepository(Protocol): - """ - Repository interface for NodeExecution. - - This interface defines the contract for accessing and manipulating - NodeExecution data, regardless of the underlying storage mechanism. - - Note: Domain-specific concepts like multi-tenancy (tenant_id), application context (app_id), - and trigger sources (triggered_from) should be handled at the implementation level, not in - the core interface. This keeps the core domain model clean and independent of specific - application domains or deployment scenarios. - """ - - def save(self, execution: WorkflowNodeExecution): - """ - Save or update a NodeExecution instance. - - This method saves all data on the `WorkflowNodeExecution` object, except for `inputs`, `process_data`, - and `outputs`. Its primary purpose is to persist the status and various metadata, such as execution time - and execution-related details. - - It's main purpose is to save the status and various metadata (execution time, execution metadata etc.) - - This method handles both creating new records and updating existing ones. - The implementation should determine whether to create or update based on - the execution's ID or other identifying fields. - - Args: - execution: The NodeExecution instance to save or update - """ - ... - - def save_execution_data(self, execution: WorkflowNodeExecution): - """Save or update the inputs, process_data, or outputs associated with a specific - node_execution record. - - If any of the inputs, process_data, or outputs are None, those fields will not be updated. - """ - ... - - def get_by_workflow_run( - self, - workflow_run_id: str, - order_config: OrderConfig | None = None, - ) -> Sequence[WorkflowNodeExecution]: - """ - Retrieve all NodeExecution instances for a specific workflow run. - - Args: - workflow_run_id: The workflow run ID - order_config: Optional configuration for ordering results - order_config.order_by: List of fields to order by (e.g., ["index", "created_at"]) - order_config.order_direction: Direction to order ("asc" or "desc") - - Returns: - A list of NodeExecution instances - """ - ... diff --git a/api/dify_graph/system_variable.py b/api/dify_graph/system_variable.py deleted file mode 100644 index cc5deda892..0000000000 --- a/api/dify_graph/system_variable.py +++ /dev/null @@ -1,217 +0,0 @@ -from __future__ import annotations - -from collections.abc import Mapping, Sequence -from types import MappingProxyType -from typing import Any -from uuid import uuid4 - -from pydantic import AliasChoices, BaseModel, ConfigDict, Field, model_validator - -from dify_graph.enums import SystemVariableKey -from dify_graph.file.models import File - - -class SystemVariable(BaseModel): - """A model for managing system variables. - - Fields with a value of `None` are treated as absent and will not be included - in the variable pool. - """ - - model_config = ConfigDict( - extra="forbid", - serialize_by_alias=True, - validate_by_alias=True, - ) - - user_id: str | None = None - - # Ideally, `app_id` and `workflow_id` should be required and not `None`. - # However, there are scenarios in the codebase where these fields are not set. - # To maintain compatibility, they are marked as optional here. - app_id: str | None = None - workflow_id: str | None = None - - timestamp: int | None = None - - files: Sequence[File] = Field(default_factory=list) - - # NOTE: The `workflow_execution_id` field was previously named `workflow_run_id`. - # To maintain compatibility with existing workflows, it must be serialized - # as `workflow_run_id` in dictionaries or JSON objects, and also referenced - # as `workflow_run_id` in the variable pool. - workflow_execution_id: str | None = Field( - validation_alias=AliasChoices("workflow_execution_id", "workflow_run_id"), - serialization_alias="workflow_run_id", - default=None, - ) - # Chatflow related fields. - query: str | None = None - conversation_id: str | None = None - dialogue_count: int | None = None - document_id: str | None = None - original_document_id: str | None = None - dataset_id: str | None = None - batch: str | None = None - datasource_type: str | None = None - datasource_info: Mapping[str, Any] | None = None - invoke_from: str | None = None - - @model_validator(mode="before") - @classmethod - def validate_json_fields(cls, data): - if isinstance(data, dict): - # For JSON validation, only allow workflow_run_id - if "workflow_execution_id" in data and "workflow_run_id" not in data: - # This is likely from direct instantiation, allow it - return data - elif "workflow_execution_id" in data and "workflow_run_id" in data: - # Both present, remove workflow_execution_id - data = data.copy() - data.pop("workflow_execution_id") - return data - return data - - @classmethod - def default(cls) -> SystemVariable: - return cls(workflow_execution_id=str(uuid4())) - - def to_dict(self) -> dict[SystemVariableKey, Any]: - # NOTE: This method is provided for compatibility with legacy code. - # New code should use the `SystemVariable` object directly instead of converting - # it to a dictionary, as this conversion results in the loss of type information - # for each key, making static analysis more difficult. - - d: dict[SystemVariableKey, Any] = { - SystemVariableKey.FILES: self.files, - } - if self.user_id is not None: - d[SystemVariableKey.USER_ID] = self.user_id - if self.app_id is not None: - d[SystemVariableKey.APP_ID] = self.app_id - if self.workflow_id is not None: - d[SystemVariableKey.WORKFLOW_ID] = self.workflow_id - if self.workflow_execution_id is not None: - d[SystemVariableKey.WORKFLOW_EXECUTION_ID] = self.workflow_execution_id - if self.query is not None: - d[SystemVariableKey.QUERY] = self.query - if self.conversation_id is not None: - d[SystemVariableKey.CONVERSATION_ID] = self.conversation_id - if self.dialogue_count is not None: - d[SystemVariableKey.DIALOGUE_COUNT] = self.dialogue_count - if self.document_id is not None: - d[SystemVariableKey.DOCUMENT_ID] = self.document_id - if self.original_document_id is not None: - d[SystemVariableKey.ORIGINAL_DOCUMENT_ID] = self.original_document_id - if self.dataset_id is not None: - d[SystemVariableKey.DATASET_ID] = self.dataset_id - if self.batch is not None: - d[SystemVariableKey.BATCH] = self.batch - if self.datasource_type is not None: - d[SystemVariableKey.DATASOURCE_TYPE] = self.datasource_type - if self.datasource_info is not None: - d[SystemVariableKey.DATASOURCE_INFO] = self.datasource_info - if self.invoke_from is not None: - d[SystemVariableKey.INVOKE_FROM] = self.invoke_from - if self.timestamp is not None: - d[SystemVariableKey.TIMESTAMP] = self.timestamp - return d - - def as_view(self) -> SystemVariableReadOnlyView: - return SystemVariableReadOnlyView(self) - - -class SystemVariableReadOnlyView: - """ - A read-only view of a SystemVariable that implements the ReadOnlySystemVariable protocol. - - This class wraps a SystemVariable instance and provides read-only access to all its fields. - It always reads the latest data from the wrapped instance and prevents any write operations. - """ - - def __init__(self, system_variable: SystemVariable) -> None: - """ - Initialize the read-only view with a SystemVariable instance. - - Args: - system_variable: The SystemVariable instance to wrap - """ - self._system_variable = system_variable - - @property - def user_id(self) -> str | None: - return self._system_variable.user_id - - @property - def app_id(self) -> str | None: - return self._system_variable.app_id - - @property - def workflow_id(self) -> str | None: - return self._system_variable.workflow_id - - @property - def workflow_execution_id(self) -> str | None: - return self._system_variable.workflow_execution_id - - @property - def query(self) -> str | None: - return self._system_variable.query - - @property - def conversation_id(self) -> str | None: - return self._system_variable.conversation_id - - @property - def dialogue_count(self) -> int | None: - return self._system_variable.dialogue_count - - @property - def document_id(self) -> str | None: - return self._system_variable.document_id - - @property - def original_document_id(self) -> str | None: - return self._system_variable.original_document_id - - @property - def dataset_id(self) -> str | None: - return self._system_variable.dataset_id - - @property - def batch(self) -> str | None: - return self._system_variable.batch - - @property - def datasource_type(self) -> str | None: - return self._system_variable.datasource_type - - @property - def invoke_from(self) -> str | None: - return self._system_variable.invoke_from - - @property - def files(self) -> Sequence[File]: - """ - Get a copy of the files from the wrapped SystemVariable. - - Returns: - A defensive copy of the files sequence to prevent modification - """ - return tuple(self._system_variable.files) # Convert to immutable tuple - - @property - def datasource_info(self) -> Mapping[str, Any] | None: - """ - Get a copy of the datasource info from the wrapped SystemVariable. - - Returns: - A view of the datasource info mapping to prevent modification - """ - if self._system_variable.datasource_info is None: - return None - return MappingProxyType(self._system_variable.datasource_info) - - def __repr__(self) -> str: - """Return a string representation of the read-only view.""" - return f"SystemVariableReadOnlyView(system_variable={self._system_variable!r})" diff --git a/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py b/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py index c43e99f0f4..ba9758175f 100644 --- a/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py +++ b/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py @@ -1,10 +1,11 @@ import logging +from core.tools.entities.tool_entities import ToolProviderType from core.tools.tool_manager import ToolManager from core.tools.utils.configuration import ToolParameterConfigurationManager -from dify_graph.nodes import BuiltinNodeTypes -from dify_graph.nodes.tool.entities import ToolEntity from events.app_event import app_draft_workflow_was_synced +from graphon.nodes import BuiltinNodeTypes +from graphon.nodes.tool.entities import ToolEntity logger = logging.getLogger(__name__) @@ -19,8 +20,9 @@ def handle(sender, **kwargs): if node_data.get("data", {}).get("type") == BuiltinNodeTypes.TOOL: try: tool_entity = ToolEntity.model_validate(node_data["data"]) + provider_type = ToolProviderType(tool_entity.provider_type.value) tool_runtime = ToolManager.get_tool_runtime( - provider_type=tool_entity.provider_type, + provider_type=provider_type, provider_id=tool_entity.provider_id, tool_name=tool_entity.tool_name, tenant_id=app.tenant_id, @@ -30,7 +32,7 @@ def handle(sender, **kwargs): tenant_id=app.tenant_id, tool_runtime=tool_runtime, provider_name=tool_entity.provider_name, - provider_type=tool_entity.provider_type, + provider_type=provider_type, identity_id=f"WORKFLOW.{app.id}.{node_data.get('id')}", ) manager.delete_tool_parameters_cache() diff --git a/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py b/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py index 20852b818e..6769b94cde 100644 --- a/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py +++ b/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py @@ -3,9 +3,9 @@ from typing import cast from sqlalchemy import delete, select from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData -from dify_graph.nodes import BuiltinNodeTypes from events.app_event import app_published_workflow_was_updated from extensions.ext_database import db +from graphon.nodes import BuiltinNodeTypes from models.dataset import AppDatasetJoin from models.workflow import Workflow diff --git a/api/extensions/ext_celery.py b/api/extensions/ext_celery.py index 7b6a73af52..367a4c1ede 100644 --- a/api/extensions/ext_celery.py +++ b/api/extensions/ext_celery.py @@ -2,7 +2,7 @@ import ssl from datetime import timedelta from typing import Any -import pytz +import pytz # type: ignore[import-untyped] from celery import Celery, Task from celery.schedules import crontab diff --git a/api/extensions/ext_sentry.py b/api/extensions/ext_sentry.py index 9a34acb0c1..120febecfb 100644 --- a/api/extensions/ext_sentry.py +++ b/api/extensions/ext_sentry.py @@ -10,7 +10,7 @@ def init_app(app: DifyApp): from sentry_sdk.integrations.flask import FlaskIntegration from werkzeug.exceptions import HTTPException - from dify_graph.model_runtime.errors.invoke import InvokeRateLimitError + from graphon.model_runtime.errors.invoke import InvokeRateLimitError def before_send(event, hint): if "exc_info" in hint: diff --git a/api/extensions/logstore/repositories/logstore_api_workflow_node_execution_repository.py b/api/extensions/logstore/repositories/logstore_api_workflow_node_execution_repository.py index a94d75ec76..bdfa984874 100644 --- a/api/extensions/logstore/repositories/logstore_api_workflow_node_execution_repository.py +++ b/api/extensions/logstore/repositories/logstore_api_workflow_node_execution_repository.py @@ -13,10 +13,10 @@ from typing import Any from sqlalchemy.orm import sessionmaker -from dify_graph.enums import WorkflowNodeExecutionStatus from extensions.logstore.aliyun_logstore import AliyunLogStore from extensions.logstore.repositories import safe_float, safe_int from extensions.logstore.sql_escape import escape_identifier, escape_logstore_query_value +from graphon.enums import WorkflowNodeExecutionStatus from models.enums import CreatorUserRole from models.workflow import WorkflowNodeExecutionModel, WorkflowNodeExecutionTriggeredFrom from repositories.api_workflow_node_execution_repository import DifyAPIWorkflowNodeExecutionRepository diff --git a/api/extensions/logstore/repositories/logstore_api_workflow_run_repository.py b/api/extensions/logstore/repositories/logstore_api_workflow_run_repository.py index bdfc81bd1c..5208f8f37e 100644 --- a/api/extensions/logstore/repositories/logstore_api_workflow_run_repository.py +++ b/api/extensions/logstore/repositories/logstore_api_workflow_run_repository.py @@ -22,10 +22,10 @@ from typing import Any, cast from sqlalchemy.orm import sessionmaker -from dify_graph.enums import WorkflowExecutionStatus from extensions.logstore.aliyun_logstore import AliyunLogStore from extensions.logstore.repositories import safe_float, safe_int from extensions.logstore.sql_escape import escape_identifier, escape_logstore_query_value, escape_sql_string +from graphon.enums import WorkflowExecutionStatus from libs.infinite_scroll_pagination import InfiniteScrollPagination from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom from models.workflow import WorkflowRun, WorkflowType diff --git a/api/extensions/logstore/repositories/logstore_workflow_execution_repository.py b/api/extensions/logstore/repositories/logstore_workflow_execution_repository.py index c58aa6adbb..ea4a2b3dd1 100644 --- a/api/extensions/logstore/repositories/logstore_workflow_execution_repository.py +++ b/api/extensions/logstore/repositories/logstore_workflow_execution_repository.py @@ -7,11 +7,11 @@ from typing import Union from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker +from core.repositories.factory import WorkflowExecutionRepository from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository -from dify_graph.entities import WorkflowExecution -from dify_graph.repositories.workflow_execution_repository import WorkflowExecutionRepository -from dify_graph.workflow_type_encoder import WorkflowRuntimeTypeConverter from extensions.logstore.aliyun_logstore import AliyunLogStore +from graphon.entities import WorkflowExecution +from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter from libs.helper import extract_tenant_id from models import ( Account, diff --git a/api/extensions/logstore/repositories/logstore_workflow_node_execution_repository.py b/api/extensions/logstore/repositories/logstore_workflow_node_execution_repository.py index d84c0bc432..976b5db8e3 100644 --- a/api/extensions/logstore/repositories/logstore_workflow_node_execution_repository.py +++ b/api/extensions/logstore/repositories/logstore_workflow_node_execution_repository.py @@ -17,14 +17,14 @@ from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository -from dify_graph.entities import WorkflowNodeExecution -from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus -from dify_graph.model_runtime.utils.encoders import jsonable_encoder -from dify_graph.repositories.workflow_node_execution_repository import OrderConfig, WorkflowNodeExecutionRepository -from dify_graph.workflow_type_encoder import WorkflowRuntimeTypeConverter +from core.repositories.factory import OrderConfig, WorkflowNodeExecutionRepository from extensions.logstore.aliyun_logstore import AliyunLogStore from extensions.logstore.repositories import safe_float, safe_int from extensions.logstore.sql_escape import escape_identifier +from graphon.entities import WorkflowNodeExecution +from graphon.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus +from graphon.model_runtime.utils.encoders import jsonable_encoder +from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter from libs.helper import extract_tenant_id from models import ( Account, @@ -304,35 +304,39 @@ class LogstoreWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository): logger.exception("Failed to dual-write node execution data to SQL database: id=%s", execution.id) # Don't raise - LogStore write succeeded, SQL is just a backup - def get_by_workflow_run( + def get_by_workflow_execution( self, - workflow_run_id: str, + workflow_execution_id: str, order_config: OrderConfig | None = None, ) -> Sequence[WorkflowNodeExecution]: """ - Retrieve all NodeExecution instances for a specific workflow run. + Retrieve all node executions for a workflow execution. Uses LogStore SQL query with window function to get the latest version of each node execution. This ensures we only get the most recent version of each node execution record. Args: - workflow_run_id: The workflow run ID + workflow_execution_id: The workflow execution identifier order_config: Optional configuration for ordering results order_config.order_by: List of fields to order by (e.g., ["index", "created_at"]) order_config.order_direction: Direction to order ("asc" or "desc") Returns: - A list of NodeExecution instances + A list of workflow node execution instances Note: This method uses ROW_NUMBER() window function partitioned by node_execution_id to get the latest version (highest log_version) of each node execution. """ - logger.debug("get_by_workflow_run: workflow_run_id=%s, order_config=%s", workflow_run_id, order_config) + logger.debug( + "get_by_workflow_execution: workflow_execution_id=%s, order_config=%s", + workflow_execution_id, + order_config, + ) # Build SQL query with deduplication using window function # ROW_NUMBER() OVER (PARTITION BY node_execution_id ORDER BY log_version DESC) # ensures we get the latest version of each node execution # Escape parameters to prevent SQL injection - escaped_workflow_run_id = escape_identifier(workflow_run_id) + escaped_workflow_execution_id = escape_identifier(workflow_execution_id) escaped_tenant_id = escape_identifier(self._tenant_id) # Build ORDER BY clause for outer query @@ -360,7 +364,7 @@ class LogstoreWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository): SELECT * FROM ( SELECT *, ROW_NUMBER() OVER (PARTITION BY node_execution_id ORDER BY log_version DESC) AS rn FROM {AliyunLogStore.workflow_node_execution_logstore} - WHERE workflow_run_id='{escaped_workflow_run_id}' + WHERE workflow_run_id='{escaped_workflow_execution_id}' AND tenant_id='{escaped_tenant_id}' {app_id_filter} ) t @@ -391,5 +395,8 @@ class LogstoreWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository): return executions except Exception: - logger.exception("Failed to retrieve node executions from LogStore: workflow_run_id=%s", workflow_run_id) + logger.exception( + "Failed to retrieve node executions from LogStore: workflow_execution_id=%s", + workflow_execution_id, + ) raise diff --git a/api/extensions/otel/parser/base.py b/api/extensions/otel/parser/base.py index 544ef3fe18..a2f552cac1 100644 --- a/api/extensions/otel/parser/base.py +++ b/api/extensions/otel/parser/base.py @@ -9,12 +9,12 @@ from opentelemetry.trace import Span from opentelemetry.trace.status import Status, StatusCode from pydantic import BaseModel -from dify_graph.enums import BuiltinNodeTypes -from dify_graph.file.models import File -from dify_graph.graph_events import GraphNodeEventBase -from dify_graph.nodes.base.node import Node -from dify_graph.variables import Segment from extensions.otel.semconv.gen_ai import ChainAttributes, GenAIAttributes +from graphon.enums import BuiltinNodeTypes +from graphon.file.models import File +from graphon.graph_events import GraphNodeEventBase +from graphon.nodes.base.node import Node +from graphon.variables import Segment def safe_json_dumps(obj: Any, ensure_ascii: bool = False) -> str: diff --git a/api/extensions/otel/parser/llm.py b/api/extensions/otel/parser/llm.py index 3da9a9e97d..ec3c78a12d 100644 --- a/api/extensions/otel/parser/llm.py +++ b/api/extensions/otel/parser/llm.py @@ -8,10 +8,10 @@ from typing import Any from opentelemetry.trace import Span -from dify_graph.graph_events import GraphNodeEventBase -from dify_graph.nodes.base.node import Node from extensions.otel.parser.base import DefaultNodeOTelParser, safe_json_dumps from extensions.otel.semconv.gen_ai import LLMAttributes +from graphon.graph_events import GraphNodeEventBase +from graphon.nodes.base.node import Node logger = logging.getLogger(__name__) diff --git a/api/extensions/otel/parser/retrieval.py b/api/extensions/otel/parser/retrieval.py index dd658b250b..56672d1fd4 100644 --- a/api/extensions/otel/parser/retrieval.py +++ b/api/extensions/otel/parser/retrieval.py @@ -8,11 +8,11 @@ from typing import Any from opentelemetry.trace import Span -from dify_graph.graph_events import GraphNodeEventBase -from dify_graph.nodes.base.node import Node -from dify_graph.variables import Segment from extensions.otel.parser.base import DefaultNodeOTelParser, safe_json_dumps from extensions.otel.semconv.gen_ai import RetrieverAttributes +from graphon.graph_events import GraphNodeEventBase +from graphon.nodes.base.node import Node +from graphon.variables import Segment logger = logging.getLogger(__name__) diff --git a/api/extensions/otel/parser/tool.py b/api/extensions/otel/parser/tool.py index f4e6a18b4d..75ddbba448 100644 --- a/api/extensions/otel/parser/tool.py +++ b/api/extensions/otel/parser/tool.py @@ -4,12 +4,12 @@ Parser for tool nodes that captures tool-specific metadata. from opentelemetry.trace import Span -from dify_graph.enums import WorkflowNodeExecutionMetadataKey -from dify_graph.graph_events import GraphNodeEventBase -from dify_graph.nodes.base.node import Node -from dify_graph.nodes.tool.entities import ToolNodeData from extensions.otel.parser.base import DefaultNodeOTelParser, safe_json_dumps from extensions.otel.semconv.gen_ai import ToolAttributes +from graphon.enums import WorkflowNodeExecutionMetadataKey +from graphon.graph_events import GraphNodeEventBase +from graphon.nodes.base.node import Node +from graphon.nodes.tool.entities import ToolNodeData class ToolNodeOTelParser: diff --git a/api/factories/file_factory.py b/api/factories/file_factory.py deleted file mode 100644 index cb07ba58ae..0000000000 --- a/api/factories/file_factory.py +++ /dev/null @@ -1,618 +0,0 @@ -import logging -import mimetypes -import os -import re -import urllib.parse -import uuid -from collections.abc import Callable, Mapping, Sequence -from typing import Any - -import httpx -from sqlalchemy import select -from sqlalchemy.orm import Session -from werkzeug.http import parse_options_header - -from constants import AUDIO_EXTENSIONS, DOCUMENT_EXTENSIONS, IMAGE_EXTENSIONS, VIDEO_EXTENSIONS -from core.helper import ssrf_proxy -from dify_graph.file import File, FileBelongsTo, FileTransferMethod, FileType, FileUploadConfig, helpers -from extensions.ext_database import db -from models import MessageFile, ToolFile, UploadFile - -logger = logging.getLogger(__name__) - - -def build_from_message_files( - *, - message_files: Sequence["MessageFile"], - tenant_id: str, - config: FileUploadConfig | None = None, -) -> Sequence[File]: - results = [ - build_from_message_file(message_file=file, tenant_id=tenant_id, config=config) - for file in message_files - if file.belongs_to != FileBelongsTo.ASSISTANT - ] - return results - - -def build_from_message_file( - *, - message_file: "MessageFile", - tenant_id: str, - config: FileUploadConfig | None, -): - mapping = { - "transfer_method": message_file.transfer_method, - "url": message_file.url, - "type": message_file.type, - } - - # Only include id if it exists (message_file has been committed to DB) - if message_file.id: - mapping["id"] = message_file.id - - # Set the correct ID field based on transfer method - if message_file.transfer_method == FileTransferMethod.TOOL_FILE: - mapping["tool_file_id"] = message_file.upload_file_id - else: - mapping["upload_file_id"] = message_file.upload_file_id - - return build_from_mapping( - mapping=mapping, - tenant_id=tenant_id, - config=config, - ) - - -def build_from_mapping( - *, - mapping: Mapping[str, Any], - tenant_id: str, - config: FileUploadConfig | None = None, - strict_type_validation: bool = False, -) -> File: - transfer_method_value = mapping.get("transfer_method") - if not transfer_method_value: - raise ValueError("transfer_method is required in file mapping") - transfer_method = FileTransferMethod.value_of(transfer_method_value) - - build_functions: dict[FileTransferMethod, Callable] = { - FileTransferMethod.LOCAL_FILE: _build_from_local_file, - FileTransferMethod.REMOTE_URL: _build_from_remote_url, - FileTransferMethod.TOOL_FILE: _build_from_tool_file, - FileTransferMethod.DATASOURCE_FILE: _build_from_datasource_file, - } - - build_func = build_functions.get(transfer_method) - if not build_func: - raise ValueError(f"Invalid file transfer method: {transfer_method}") - - file: File = build_func( - mapping=mapping, - tenant_id=tenant_id, - transfer_method=transfer_method, - strict_type_validation=strict_type_validation, - ) - - if config and not _is_file_valid_with_config( - input_file_type=mapping.get("type", FileType.CUSTOM), - file_extension=file.extension or "", - file_transfer_method=file.transfer_method, - config=config, - ): - raise ValueError(f"File validation failed for file: {file.filename}") - - return file - - -def build_from_mappings( - *, - mappings: Sequence[Mapping[str, Any]], - config: FileUploadConfig | None = None, - tenant_id: str, - strict_type_validation: bool = False, -) -> Sequence[File]: - # TODO(QuantumGhost): Performance concern - each mapping triggers a separate database query. - # Implement batch processing to reduce database load when handling multiple files. - # Filter out None/empty mappings to avoid errors - def is_valid_mapping(m: Mapping[str, Any]) -> bool: - if not m or not m.get("transfer_method"): - return False - # For REMOTE_URL transfer method, ensure url or remote_url is provided and not None - transfer_method = m.get("transfer_method") - if transfer_method == FileTransferMethod.REMOTE_URL: - url = m.get("url") or m.get("remote_url") - if not url: - return False - return True - - valid_mappings = [m for m in mappings if is_valid_mapping(m)] - files = [ - build_from_mapping( - mapping=mapping, - tenant_id=tenant_id, - config=config, - strict_type_validation=strict_type_validation, - ) - for mapping in valid_mappings - ] - - if ( - config - # If image config is set. - and config.image_config - # And the number of image files exceeds the maximum limit - and sum(1 for _ in (filter(lambda x: x.type == FileType.IMAGE, files))) > config.image_config.number_limits - ): - raise ValueError(f"Number of image files exceeds the maximum limit {config.image_config.number_limits}") - if config and config.number_limits and len(files) > config.number_limits: - raise ValueError(f"Number of files exceeds the maximum limit {config.number_limits}") - - return files - - -def _build_from_local_file( - *, - mapping: Mapping[str, Any], - tenant_id: str, - transfer_method: FileTransferMethod, - strict_type_validation: bool = False, -) -> File: - upload_file_id = mapping.get("upload_file_id") - if not upload_file_id: - raise ValueError("Invalid upload file id") - # check if upload_file_id is a valid uuid - try: - uuid.UUID(upload_file_id) - except ValueError: - raise ValueError("Invalid upload file id format") - stmt = select(UploadFile).where( - UploadFile.id == upload_file_id, - UploadFile.tenant_id == tenant_id, - ) - - row = db.session.scalar(stmt) - if row is None: - raise ValueError("Invalid upload file") - - detected_file_type = _standardize_file_type(extension="." + row.extension, mime_type=row.mime_type) - specified_type = mapping.get("type", "custom") - - if strict_type_validation and detected_file_type.value != specified_type: - raise ValueError("Detected file type does not match the specified type. Please verify the file.") - - if specified_type and specified_type != "custom": - file_type = FileType(specified_type) - else: - file_type = detected_file_type - - return File( - id=mapping.get("id"), - filename=row.name, - extension="." + row.extension, - mime_type=row.mime_type, - tenant_id=tenant_id, - type=file_type, - transfer_method=transfer_method, - remote_url=row.source_url, - related_id=mapping.get("upload_file_id"), - size=row.size, - storage_key=row.key, - ) - - -def _build_from_remote_url( - *, - mapping: Mapping[str, Any], - tenant_id: str, - transfer_method: FileTransferMethod, - strict_type_validation: bool = False, -) -> File: - upload_file_id = mapping.get("upload_file_id") - if upload_file_id: - try: - uuid.UUID(upload_file_id) - except ValueError: - raise ValueError("Invalid upload file id format") - stmt = select(UploadFile).where( - UploadFile.id == upload_file_id, - UploadFile.tenant_id == tenant_id, - ) - - upload_file = db.session.scalar(stmt) - if upload_file is None: - raise ValueError("Invalid upload file") - - detected_file_type = _standardize_file_type( - extension="." + upload_file.extension, mime_type=upload_file.mime_type - ) - - specified_type = mapping.get("type") - - if strict_type_validation and specified_type and detected_file_type.value != specified_type: - raise ValueError("Detected file type does not match the specified type. Please verify the file.") - - if specified_type and specified_type != "custom": - file_type = FileType(specified_type) - else: - file_type = detected_file_type - - return File( - id=mapping.get("id"), - filename=upload_file.name, - extension="." + upload_file.extension, - mime_type=upload_file.mime_type, - tenant_id=tenant_id, - type=file_type, - transfer_method=transfer_method, - remote_url=helpers.get_signed_file_url(upload_file_id=str(upload_file_id)), - related_id=mapping.get("upload_file_id"), - size=upload_file.size, - storage_key=upload_file.key, - ) - url = mapping.get("url") or mapping.get("remote_url") - if not url: - raise ValueError("Invalid file url") - - mime_type, filename, file_size = _get_remote_file_info(url) - extension = mimetypes.guess_extension(mime_type) or ("." + filename.split(".")[-1] if "." in filename else ".bin") - - detected_file_type = _standardize_file_type(extension=extension, mime_type=mime_type) - specified_type = mapping.get("type") - - if strict_type_validation and specified_type and detected_file_type.value != specified_type: - raise ValueError("Detected file type does not match the specified type. Please verify the file.") - - if specified_type and specified_type != "custom": - file_type = FileType(specified_type) - else: - file_type = detected_file_type - - return File( - id=mapping.get("id"), - filename=filename, - tenant_id=tenant_id, - type=file_type, - transfer_method=transfer_method, - remote_url=url, - mime_type=mime_type, - extension=extension, - size=file_size, - storage_key="", - ) - - -def _extract_filename(url_path: str, content_disposition: str | None) -> str | None: - filename: str | None = None - # Try to extract from Content-Disposition header first - if content_disposition: - # Manually extract filename* parameter since parse_options_header doesn't support it - filename_star_match = re.search(r"filename\*=([^;]+)", content_disposition) - if filename_star_match: - raw_star = filename_star_match.group(1).strip() - # Remove trailing quotes if present - raw_star = raw_star.removesuffix('"') - # format: charset'lang'value - try: - parts = raw_star.split("'", 2) - charset = (parts[0] or "utf-8").lower() if len(parts) >= 1 else "utf-8" - value = parts[2] if len(parts) == 3 else parts[-1] - filename = urllib.parse.unquote(value, encoding=charset, errors="replace") - except Exception: - # Fallback: try to extract value after the last single quote - if "''" in raw_star: - filename = urllib.parse.unquote(raw_star.split("''")[-1]) - else: - filename = urllib.parse.unquote(raw_star) - - if not filename: - # Fallback to regular filename parameter - _, params = parse_options_header(content_disposition) - raw = params.get("filename") - if raw: - # Strip surrounding quotes and percent-decode if present - if len(raw) >= 2 and raw[0] == raw[-1] == '"': - raw = raw[1:-1] - filename = urllib.parse.unquote(raw) - # Fallback to URL path if no filename from header - if not filename: - candidate = os.path.basename(url_path) - filename = urllib.parse.unquote(candidate) if candidate else None - # Defense-in-depth: ensure basename only - if filename: - filename = os.path.basename(filename) - # Return None if filename is empty or only whitespace - if not filename or not filename.strip(): - filename = None - return filename or None - - -def _guess_mime_type(filename: str) -> str: - """Guess MIME type from filename, returning empty string if None.""" - guessed_mime, _ = mimetypes.guess_type(filename) - return guessed_mime or "" - - -def _get_remote_file_info(url: str): - file_size = -1 - parsed_url = urllib.parse.urlparse(url) - url_path = parsed_url.path - filename = os.path.basename(url_path) - - # Initialize mime_type from filename as fallback - mime_type = _guess_mime_type(filename) - - resp = ssrf_proxy.head(url, follow_redirects=True) - if resp.status_code == httpx.codes.OK: - content_disposition = resp.headers.get("Content-Disposition") - extracted_filename = _extract_filename(url_path, content_disposition) - if extracted_filename: - filename = extracted_filename - mime_type = _guess_mime_type(filename) - file_size = int(resp.headers.get("Content-Length", file_size)) - # Fallback to Content-Type header if mime_type is still empty - if not mime_type: - mime_type = resp.headers.get("Content-Type", "").split(";")[0].strip() - - if not filename: - extension = mimetypes.guess_extension(mime_type) or ".bin" - filename = f"{uuid.uuid4().hex}{extension}" - if not mime_type: - mime_type = _guess_mime_type(filename) - - return mime_type, filename, file_size - - -def _build_from_tool_file( - *, - mapping: Mapping[str, Any], - tenant_id: str, - transfer_method: FileTransferMethod, - strict_type_validation: bool = False, -) -> File: - # Backward/interop compatibility: allow tool_file_id to come from related_id or URL - tool_file_id = mapping.get("tool_file_id") - - if not tool_file_id: - raise ValueError(f"ToolFile {tool_file_id} not found") - tool_file = db.session.scalar( - select(ToolFile).where( - ToolFile.id == tool_file_id, - ToolFile.tenant_id == tenant_id, - ) - ) - - if tool_file is None: - raise ValueError(f"ToolFile {tool_file_id} not found") - - extension = "." + tool_file.file_key.split(".")[-1] if "." in tool_file.file_key else ".bin" - - detected_file_type = _standardize_file_type(extension=extension, mime_type=tool_file.mimetype) - - specified_type = mapping.get("type") - - if strict_type_validation and specified_type and detected_file_type.value != specified_type: - raise ValueError("Detected file type does not match the specified type. Please verify the file.") - - if specified_type and specified_type != "custom": - file_type = FileType(specified_type) - else: - file_type = detected_file_type - - return File( - id=mapping.get("id"), - tenant_id=tenant_id, - filename=tool_file.name, - type=file_type, - transfer_method=transfer_method, - remote_url=tool_file.original_url, - related_id=tool_file.id, - extension=extension, - mime_type=tool_file.mimetype, - size=tool_file.size, - storage_key=tool_file.file_key, - ) - - -def _build_from_datasource_file( - *, - mapping: Mapping[str, Any], - tenant_id: str, - transfer_method: FileTransferMethod, - strict_type_validation: bool = False, -) -> File: - datasource_file_id = mapping.get("datasource_file_id") - if not datasource_file_id: - raise ValueError(f"DatasourceFile {datasource_file_id} not found") - datasource_file = db.session.scalar( - select(UploadFile).where( - UploadFile.id == datasource_file_id, - UploadFile.tenant_id == tenant_id, - ) - ) - - if datasource_file is None: - raise ValueError(f"DatasourceFile {mapping.get('datasource_file_id')} not found") - - extension = "." + datasource_file.key.split(".")[-1] if "." in datasource_file.key else ".bin" - - detected_file_type = _standardize_file_type(extension="." + extension, mime_type=datasource_file.mime_type) - - specified_type = mapping.get("type") - - if strict_type_validation and specified_type and detected_file_type.value != specified_type: - raise ValueError("Detected file type does not match the specified type. Please verify the file.") - - if specified_type and specified_type != "custom": - file_type = FileType(specified_type) - else: - file_type = detected_file_type - - return File( - id=mapping.get("datasource_file_id"), - tenant_id=tenant_id, - filename=datasource_file.name, - type=file_type, - transfer_method=FileTransferMethod.TOOL_FILE, - remote_url=datasource_file.source_url, - related_id=datasource_file.id, - extension=extension, - mime_type=datasource_file.mime_type, - size=datasource_file.size, - storage_key=datasource_file.key, - url=datasource_file.source_url, - ) - - -def _is_file_valid_with_config( - *, - input_file_type: str, - file_extension: str, - file_transfer_method: FileTransferMethod, - config: FileUploadConfig, -) -> bool: - # FIXME(QIN2DIM): Always allow tool files (files generated by the assistant/model) - # These are internally generated and should bypass user upload restrictions - if file_transfer_method == FileTransferMethod.TOOL_FILE: - return True - - if ( - config.allowed_file_types - and input_file_type not in config.allowed_file_types - and input_file_type != FileType.CUSTOM - ): - return False - - if ( - input_file_type == FileType.CUSTOM - and config.allowed_file_extensions is not None - and file_extension not in config.allowed_file_extensions - ): - return False - - if input_file_type == FileType.IMAGE: - if ( - config.image_config - and config.image_config.transfer_methods - and file_transfer_method not in config.image_config.transfer_methods - ): - return False - elif config.allowed_file_upload_methods and file_transfer_method not in config.allowed_file_upload_methods: - return False - - return True - - -def _standardize_file_type(*, extension: str = "", mime_type: str = "") -> FileType: - """ - Infer the possible actual type of the file based on the extension and mime_type - """ - guessed_type = None - if extension: - guessed_type = _get_file_type_by_extension(extension) - if guessed_type is None and mime_type: - guessed_type = _get_file_type_by_mimetype(mime_type) - return guessed_type or FileType.CUSTOM - - -def _get_file_type_by_extension(extension: str) -> FileType | None: - extension = extension.lstrip(".") - if extension in IMAGE_EXTENSIONS: - return FileType.IMAGE - elif extension in VIDEO_EXTENSIONS: - return FileType.VIDEO - elif extension in AUDIO_EXTENSIONS: - return FileType.AUDIO - elif extension in DOCUMENT_EXTENSIONS: - return FileType.DOCUMENT - return None - - -def _get_file_type_by_mimetype(mime_type: str) -> FileType | None: - if "image" in mime_type: - file_type = FileType.IMAGE - elif "video" in mime_type: - file_type = FileType.VIDEO - elif "audio" in mime_type: - file_type = FileType.AUDIO - elif "text" in mime_type or "pdf" in mime_type: - file_type = FileType.DOCUMENT - else: - file_type = FileType.CUSTOM - return file_type - - -def get_file_type_by_mime_type(mime_type: str) -> FileType: - return _get_file_type_by_mimetype(mime_type) or FileType.CUSTOM - - -class StorageKeyLoader: - """FileKeyLoader load the storage key from database for a list of files. - This loader is batched, the database query count is constant regardless of the input size. - """ - - def __init__(self, session: Session, tenant_id: str): - self._session = session - self._tenant_id = tenant_id - - def _load_upload_files(self, upload_file_ids: Sequence[uuid.UUID]) -> Mapping[uuid.UUID, UploadFile]: - stmt = select(UploadFile).where( - UploadFile.id.in_(upload_file_ids), - UploadFile.tenant_id == self._tenant_id, - ) - - return {uuid.UUID(i.id): i for i in self._session.scalars(stmt)} - - def _load_tool_files(self, tool_file_ids: Sequence[uuid.UUID]) -> Mapping[uuid.UUID, ToolFile]: - stmt = select(ToolFile).where( - ToolFile.id.in_(tool_file_ids), - ToolFile.tenant_id == self._tenant_id, - ) - return {uuid.UUID(i.id): i for i in self._session.scalars(stmt)} - - def load_storage_keys(self, files: Sequence[File]): - """Loads storage keys for a sequence of files by retrieving the corresponding - `UploadFile` or `ToolFile` records from the database based on their transfer method. - - This method doesn't modify the input sequence structure but updates the `_storage_key` - property of each file object by extracting the relevant key from its database record. - - Performance note: This is a batched operation where database query count remains constant - regardless of input size. However, for optimal performance, input sequences should contain - fewer than 1000 files. For larger collections, split into smaller batches and process each - batch separately. - """ - - upload_file_ids: list[uuid.UUID] = [] - tool_file_ids: list[uuid.UUID] = [] - for file in files: - related_model_id = file.related_id - if file.related_id is None: - raise ValueError("file id should not be None.") - if file.tenant_id != self._tenant_id: - err_msg = ( - f"invalid file, expected tenant_id={self._tenant_id}, " - f"got tenant_id={file.tenant_id}, file_id={file.id}, related_model_id={related_model_id}" - ) - raise ValueError(err_msg) - model_id = uuid.UUID(related_model_id) - - if file.transfer_method in (FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL): - upload_file_ids.append(model_id) - elif file.transfer_method == FileTransferMethod.TOOL_FILE: - tool_file_ids.append(model_id) - - tool_files = self._load_tool_files(tool_file_ids) - upload_files = self._load_upload_files(upload_file_ids) - for file in files: - model_id = uuid.UUID(file.related_id) - if file.transfer_method in (FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL): - upload_file_row = upload_files.get(model_id) - if upload_file_row is None: - raise ValueError(f"Upload file not found for id: {model_id}") - file.storage_key = upload_file_row.key - elif file.transfer_method == FileTransferMethod.TOOL_FILE: - tool_file_row = tool_files.get(model_id) - if tool_file_row is None: - raise ValueError(f"Tool file not found for id: {model_id}") - file.storage_key = tool_file_row.file_key diff --git a/api/factories/file_factory/__init__.py b/api/factories/file_factory/__init__.py new file mode 100644 index 0000000000..ae0cd972ec --- /dev/null +++ b/api/factories/file_factory/__init__.py @@ -0,0 +1,18 @@ +"""Workflow file factory package. + +This package normalizes workflow-layer file payloads into graph-layer ``File`` +values. It keeps tenancy and ownership checks in the application layer and +exports the workflow-facing file builders for callers. +""" + +from .builders import build_from_mapping, build_from_mappings +from .message_files import build_from_message_file, build_from_message_files +from .storage_keys import StorageKeyLoader + +__all__ = [ + "StorageKeyLoader", + "build_from_mapping", + "build_from_mappings", + "build_from_message_file", + "build_from_message_files", +] diff --git a/api/factories/file_factory/builders.py b/api/factories/file_factory/builders.py new file mode 100644 index 0000000000..bc87510d43 --- /dev/null +++ b/api/factories/file_factory/builders.py @@ -0,0 +1,329 @@ +"""Core builders for workflow file mappings.""" + +from __future__ import annotations + +import mimetypes +import uuid +from collections.abc import Mapping, Sequence +from typing import Any + +from sqlalchemy import select + +from core.app.file_access import FileAccessControllerProtocol +from core.workflow.file_reference import build_file_reference +from extensions.ext_database import db +from graphon.file import File, FileTransferMethod, FileType, FileUploadConfig, helpers +from graphon.file.file_factory import standardize_file_type +from models import ToolFile, UploadFile + +from .common import resolve_mapping_file_id +from .remote import get_remote_file_info +from .validation import is_file_valid_with_config + + +def build_from_mapping( + *, + mapping: Mapping[str, Any], + tenant_id: str, + config: FileUploadConfig | None = None, + strict_type_validation: bool = False, + access_controller: FileAccessControllerProtocol, +) -> File: + transfer_method_value = mapping.get("transfer_method") + if not transfer_method_value: + raise ValueError("transfer_method is required in file mapping") + + transfer_method = FileTransferMethod.value_of(transfer_method_value) + build_func = _get_build_function(transfer_method) + file = build_func( + mapping=mapping, + tenant_id=tenant_id, + transfer_method=transfer_method, + strict_type_validation=strict_type_validation, + access_controller=access_controller, + ) + + if config and not is_file_valid_with_config( + input_file_type=mapping.get("type", FileType.CUSTOM), + file_extension=file.extension or "", + file_transfer_method=file.transfer_method, + config=config, + ): + raise ValueError(f"File validation failed for file: {file.filename}") + + return file + + +def build_from_mappings( + *, + mappings: Sequence[Mapping[str, Any]], + config: FileUploadConfig | None = None, + tenant_id: str, + strict_type_validation: bool = False, + access_controller: FileAccessControllerProtocol, +) -> Sequence[File]: + # TODO(QuantumGhost): Performance concern - each mapping triggers a separate database query. + # Implement batch processing to reduce database load when handling multiple files. + valid_mappings = [mapping for mapping in mappings if _is_valid_mapping(mapping)] + files = [ + build_from_mapping( + mapping=mapping, + tenant_id=tenant_id, + config=config, + strict_type_validation=strict_type_validation, + access_controller=access_controller, + ) + for mapping in valid_mappings + ] + + if ( + config + and config.image_config + and sum(1 for file in files if file.type == FileType.IMAGE) > config.image_config.number_limits + ): + raise ValueError(f"Number of image files exceeds the maximum limit {config.image_config.number_limits}") + if config and config.number_limits and len(files) > config.number_limits: + raise ValueError(f"Number of files exceeds the maximum limit {config.number_limits}") + + return files + + +def _get_build_function(transfer_method: FileTransferMethod): + build_functions = { + FileTransferMethod.LOCAL_FILE: _build_from_local_file, + FileTransferMethod.REMOTE_URL: _build_from_remote_url, + FileTransferMethod.TOOL_FILE: _build_from_tool_file, + FileTransferMethod.DATASOURCE_FILE: _build_from_datasource_file, + } + build_func = build_functions.get(transfer_method) + if build_func is None: + raise ValueError(f"Invalid file transfer method: {transfer_method}") + return build_func + + +def _resolve_file_type( + *, + detected_file_type: FileType, + specified_type: str | None, + strict_type_validation: bool, +) -> FileType: + if strict_type_validation and specified_type and detected_file_type.value != specified_type: + raise ValueError("Detected file type does not match the specified type. Please verify the file.") + + if specified_type and specified_type != "custom": + return FileType(specified_type) + return detected_file_type + + +def _build_from_local_file( + *, + mapping: Mapping[str, Any], + tenant_id: str, + transfer_method: FileTransferMethod, + strict_type_validation: bool = False, + access_controller: FileAccessControllerProtocol, +) -> File: + upload_file_id = resolve_mapping_file_id(mapping, "upload_file_id") + if not upload_file_id: + raise ValueError("Invalid upload file id") + + try: + uuid.UUID(upload_file_id) + except ValueError as exc: + raise ValueError("Invalid upload file id format") from exc + + stmt = select(UploadFile).where( + UploadFile.id == upload_file_id, + UploadFile.tenant_id == tenant_id, + ) + row = db.session.scalar(access_controller.apply_upload_file_filters(stmt)) + if row is None: + raise ValueError("Invalid upload file") + + detected_file_type = standardize_file_type(extension="." + row.extension, mime_type=row.mime_type) + file_type = _resolve_file_type( + detected_file_type=detected_file_type, + specified_type=mapping.get("type", "custom"), + strict_type_validation=strict_type_validation, + ) + + return File( + id=mapping.get("id"), + filename=row.name, + extension="." + row.extension, + mime_type=row.mime_type, + type=file_type, + transfer_method=transfer_method, + remote_url=row.source_url, + reference=build_file_reference(record_id=str(row.id)), + size=row.size, + storage_key=row.key, + ) + + +def _build_from_remote_url( + *, + mapping: Mapping[str, Any], + tenant_id: str, + transfer_method: FileTransferMethod, + strict_type_validation: bool = False, + access_controller: FileAccessControllerProtocol, +) -> File: + upload_file_id = resolve_mapping_file_id(mapping, "upload_file_id") + if upload_file_id: + try: + uuid.UUID(upload_file_id) + except ValueError as exc: + raise ValueError("Invalid upload file id format") from exc + + stmt = select(UploadFile).where( + UploadFile.id == upload_file_id, + UploadFile.tenant_id == tenant_id, + ) + upload_file = db.session.scalar(access_controller.apply_upload_file_filters(stmt)) + if upload_file is None: + raise ValueError("Invalid upload file") + + detected_file_type = standardize_file_type( + extension="." + upload_file.extension, + mime_type=upload_file.mime_type, + ) + file_type = _resolve_file_type( + detected_file_type=detected_file_type, + specified_type=mapping.get("type"), + strict_type_validation=strict_type_validation, + ) + + return File( + id=mapping.get("id"), + filename=upload_file.name, + extension="." + upload_file.extension, + mime_type=upload_file.mime_type, + type=file_type, + transfer_method=transfer_method, + remote_url=helpers.get_signed_file_url(upload_file_id=str(upload_file_id)), + reference=build_file_reference(record_id=str(upload_file.id)), + size=upload_file.size, + storage_key=upload_file.key, + ) + + url = mapping.get("url") or mapping.get("remote_url") + if not url: + raise ValueError("Invalid file url") + + mime_type, filename, file_size = get_remote_file_info(url) + extension = mimetypes.guess_extension(mime_type) or ("." + filename.split(".")[-1] if "." in filename else ".bin") + detected_file_type = standardize_file_type(extension=extension, mime_type=mime_type) + file_type = _resolve_file_type( + detected_file_type=detected_file_type, + specified_type=mapping.get("type"), + strict_type_validation=strict_type_validation, + ) + + return File( + id=mapping.get("id"), + filename=filename, + type=file_type, + transfer_method=transfer_method, + remote_url=url, + mime_type=mime_type, + extension=extension, + size=file_size, + ) + + +def _build_from_tool_file( + *, + mapping: Mapping[str, Any], + tenant_id: str, + transfer_method: FileTransferMethod, + strict_type_validation: bool = False, + access_controller: FileAccessControllerProtocol, +) -> File: + tool_file_id = resolve_mapping_file_id(mapping, "tool_file_id") + if not tool_file_id: + raise ValueError(f"ToolFile {tool_file_id} not found") + + stmt = select(ToolFile).where( + ToolFile.id == tool_file_id, + ToolFile.tenant_id == tenant_id, + ) + tool_file = db.session.scalar(access_controller.apply_tool_file_filters(stmt)) + if tool_file is None: + raise ValueError(f"ToolFile {tool_file_id} not found") + + extension = "." + tool_file.file_key.split(".")[-1] if "." in tool_file.file_key else ".bin" + detected_file_type = standardize_file_type(extension=extension, mime_type=tool_file.mimetype) + file_type = _resolve_file_type( + detected_file_type=detected_file_type, + specified_type=mapping.get("type"), + strict_type_validation=strict_type_validation, + ) + + return File( + id=mapping.get("id"), + filename=tool_file.name, + type=file_type, + transfer_method=transfer_method, + remote_url=tool_file.original_url, + reference=build_file_reference(record_id=str(tool_file.id)), + extension=extension, + mime_type=tool_file.mimetype, + size=tool_file.size, + storage_key=tool_file.file_key, + ) + + +def _build_from_datasource_file( + *, + mapping: Mapping[str, Any], + tenant_id: str, + transfer_method: FileTransferMethod, + strict_type_validation: bool = False, + access_controller: FileAccessControllerProtocol, +) -> File: + datasource_file_id = resolve_mapping_file_id(mapping, "datasource_file_id") + if not datasource_file_id: + raise ValueError(f"DatasourceFile {datasource_file_id} not found") + + stmt = select(UploadFile).where( + UploadFile.id == datasource_file_id, + UploadFile.tenant_id == tenant_id, + ) + datasource_file = db.session.scalar(access_controller.apply_upload_file_filters(stmt)) + if datasource_file is None: + raise ValueError(f"DatasourceFile {mapping.get('datasource_file_id')} not found") + + extension = "." + datasource_file.key.split(".")[-1] if "." in datasource_file.key else ".bin" + detected_file_type = standardize_file_type(extension="." + extension, mime_type=datasource_file.mime_type) + file_type = _resolve_file_type( + detected_file_type=detected_file_type, + specified_type=mapping.get("type"), + strict_type_validation=strict_type_validation, + ) + + return File( + id=mapping.get("datasource_file_id"), + filename=datasource_file.name, + type=file_type, + transfer_method=FileTransferMethod.TOOL_FILE, + remote_url=datasource_file.source_url, + reference=build_file_reference(record_id=str(datasource_file.id)), + extension=extension, + mime_type=datasource_file.mime_type, + size=datasource_file.size, + storage_key=datasource_file.key, + url=datasource_file.source_url, + ) + + +def _is_valid_mapping(mapping: Mapping[str, Any]) -> bool: + if not mapping or not mapping.get("transfer_method"): + return False + + if mapping.get("transfer_method") == FileTransferMethod.REMOTE_URL: + url = mapping.get("url") or mapping.get("remote_url") + if not url: + return False + + return True diff --git a/api/factories/file_factory/common.py b/api/factories/file_factory/common.py new file mode 100644 index 0000000000..2e1c95ab3f --- /dev/null +++ b/api/factories/file_factory/common.py @@ -0,0 +1,27 @@ +"""Shared helpers for workflow file factory modules.""" + +from __future__ import annotations + +from collections.abc import Mapping +from typing import Any + +from core.workflow.file_reference import resolve_file_record_id + + +def resolve_mapping_file_id(mapping: Mapping[str, Any], *keys: str) -> str | None: + """Resolve historical file identifiers from persisted mapping payloads. + + Workflow and model payloads can outlive file schema changes. Older rows may + still carry concrete identifiers in legacy fields such as ``related_id``, + while newer payloads use opaque references. Keep this compatibility lookup in + the factory layer so historical data remains readable without reintroducing + storage details into graph-layer ``File`` values. + """ + + for key in (*keys, "reference", "related_id"): + raw_value = mapping.get(key) + if isinstance(raw_value, str) and raw_value: + resolved_value = resolve_file_record_id(raw_value) + if resolved_value: + return resolved_value + return None diff --git a/api/factories/file_factory/message_files.py b/api/factories/file_factory/message_files.py new file mode 100644 index 0000000000..4b3d514238 --- /dev/null +++ b/api/factories/file_factory/message_files.py @@ -0,0 +1,59 @@ +"""Adapters from persisted message files to graph-layer file values.""" + +from __future__ import annotations + +from collections.abc import Sequence + +from core.app.file_access import FileAccessControllerProtocol +from graphon.file import File, FileBelongsTo, FileTransferMethod, FileUploadConfig +from models import MessageFile + +from .builders import build_from_mapping + + +def build_from_message_files( + *, + message_files: Sequence[MessageFile], + tenant_id: str, + config: FileUploadConfig | None = None, + access_controller: FileAccessControllerProtocol, +) -> Sequence[File]: + return [ + build_from_message_file( + message_file=message_file, + tenant_id=tenant_id, + config=config, + access_controller=access_controller, + ) + for message_file in message_files + if message_file.belongs_to != FileBelongsTo.ASSISTANT + ] + + +def build_from_message_file( + *, + message_file: MessageFile, + tenant_id: str, + config: FileUploadConfig | None, + access_controller: FileAccessControllerProtocol, +) -> File: + mapping = { + "transfer_method": message_file.transfer_method, + "url": message_file.url, + "type": message_file.type, + } + + if message_file.id: + mapping["id"] = message_file.id + + if message_file.transfer_method == FileTransferMethod.TOOL_FILE: + mapping["tool_file_id"] = message_file.upload_file_id + else: + mapping["upload_file_id"] = message_file.upload_file_id + + return build_from_mapping( + mapping=mapping, + tenant_id=tenant_id, + config=config, + access_controller=access_controller, + ) diff --git a/api/factories/file_factory/remote.py b/api/factories/file_factory/remote.py new file mode 100644 index 0000000000..e5a7186007 --- /dev/null +++ b/api/factories/file_factory/remote.py @@ -0,0 +1,91 @@ +"""Remote file metadata helpers used by workflow file normalization. + +These helpers are part of the ``factories.file_factory`` package surface +because both workflow builders and tests rely on the same RFC5987 filename +parsing and HEAD-response normalization rules. +""" + +from __future__ import annotations + +import mimetypes +import os +import re +import urllib.parse +import uuid + +import httpx +from werkzeug.http import parse_options_header + +from core.helper import ssrf_proxy + + +def extract_filename(url_path: str, content_disposition: str | None) -> str | None: + """Extract a safe filename from Content-Disposition or the request URL path.""" + filename: str | None = None + if content_disposition: + filename_star_match = re.search(r"filename\*=([^;]+)", content_disposition) + if filename_star_match: + raw_star = filename_star_match.group(1).strip() + raw_star = raw_star.removesuffix('"') + try: + parts = raw_star.split("'", 2) + charset = (parts[0] or "utf-8").lower() if len(parts) >= 1 else "utf-8" + value = parts[2] if len(parts) == 3 else parts[-1] + filename = urllib.parse.unquote(value, encoding=charset, errors="replace") + except Exception: + if "''" in raw_star: + filename = urllib.parse.unquote(raw_star.split("''")[-1]) + else: + filename = urllib.parse.unquote(raw_star) + + if not filename: + _, params = parse_options_header(content_disposition) + raw = params.get("filename") + if raw: + if len(raw) >= 2 and raw[0] == raw[-1] == '"': + raw = raw[1:-1] + filename = urllib.parse.unquote(raw) + + if not filename: + candidate = os.path.basename(url_path) + filename = urllib.parse.unquote(candidate) if candidate else None + + if filename: + filename = os.path.basename(filename) + if not filename or not filename.strip(): + filename = None + + return filename or None + + +def _guess_mime_type(filename: str) -> str: + guessed_mime, _ = mimetypes.guess_type(filename) + return guessed_mime or "" + + +def get_remote_file_info(url: str) -> tuple[str, str, int]: + """Resolve remote file metadata with SSRF-safe HEAD probing.""" + file_size = -1 + parsed_url = urllib.parse.urlparse(url) + url_path = parsed_url.path + filename = os.path.basename(url_path) + mime_type = _guess_mime_type(filename) + + resp = ssrf_proxy.head(url, follow_redirects=True) + if resp.status_code == httpx.codes.OK: + content_disposition = resp.headers.get("Content-Disposition") + extracted_filename = extract_filename(url_path, content_disposition) + if extracted_filename: + filename = extracted_filename + mime_type = _guess_mime_type(filename) + file_size = int(resp.headers.get("Content-Length", file_size)) + if not mime_type: + mime_type = resp.headers.get("Content-Type", "").split(";")[0].strip() + + if not filename: + extension = mimetypes.guess_extension(mime_type) or ".bin" + filename = f"{uuid.uuid4().hex}{extension}" + if not mime_type: + mime_type = _guess_mime_type(filename) + + return mime_type, filename, file_size diff --git a/api/factories/file_factory/storage_keys.py b/api/factories/file_factory/storage_keys.py new file mode 100644 index 0000000000..dba4c84407 --- /dev/null +++ b/api/factories/file_factory/storage_keys.py @@ -0,0 +1,106 @@ +"""Batched storage-key hydration for workflow files.""" + +from __future__ import annotations + +import uuid +from collections.abc import Mapping, Sequence + +from sqlalchemy import select +from sqlalchemy.orm import Session + +from core.app.file_access import FileAccessControllerProtocol +from core.workflow.file_reference import build_file_reference, parse_file_reference +from graphon.file import File, FileTransferMethod +from models import ToolFile, UploadFile + + +class StorageKeyLoader: + """Load storage keys for files with a constant number of database queries.""" + + _session: Session + _tenant_id: str + _access_controller: FileAccessControllerProtocol + + def __init__( + self, + session: Session, + tenant_id: str, + access_controller: FileAccessControllerProtocol, + ) -> None: + self._session = session + self._tenant_id = tenant_id + self._access_controller = access_controller + + def _load_upload_files(self, upload_file_ids: Sequence[uuid.UUID]) -> Mapping[uuid.UUID, UploadFile]: + stmt = select(UploadFile).where( + UploadFile.id.in_(upload_file_ids), + UploadFile.tenant_id == self._tenant_id, + ) + scoped_stmt = self._access_controller.apply_upload_file_filters(stmt) + return {uuid.UUID(upload_file.id): upload_file for upload_file in self._session.scalars(scoped_stmt)} + + def _load_tool_files(self, tool_file_ids: Sequence[uuid.UUID]) -> Mapping[uuid.UUID, ToolFile]: + stmt = select(ToolFile).where( + ToolFile.id.in_(tool_file_ids), + ToolFile.tenant_id == self._tenant_id, + ) + scoped_stmt = self._access_controller.apply_tool_file_filters(stmt) + return {uuid.UUID(tool_file.id): tool_file for tool_file in self._session.scalars(scoped_stmt)} + + def load_storage_keys(self, files: Sequence[File]) -> None: + """Hydrate storage keys by loading their backing file rows in batches. + + The sequence shape is preserved. Each file is updated in place with a + canonical record reference and storage key loaded from an authorized + database row. Tenant scoping is enforced by this loader's context + rather than by embedding tenant identity or storage paths inside + graph-layer ``File`` values. + + For best performance, prefer batches smaller than 1000 files. + """ + + upload_file_ids: list[uuid.UUID] = [] + tool_file_ids: list[uuid.UUID] = [] + for file in files: + parsed_reference = parse_file_reference(file.reference) + if parsed_reference is None: + raise ValueError("file id should not be None.") + + model_id = uuid.UUID(parsed_reference.record_id) + if file.transfer_method in ( + FileTransferMethod.LOCAL_FILE, + FileTransferMethod.REMOTE_URL, + FileTransferMethod.DATASOURCE_FILE, + ): + upload_file_ids.append(model_id) + elif file.transfer_method == FileTransferMethod.TOOL_FILE: + tool_file_ids.append(model_id) + + tool_files = self._load_tool_files(tool_file_ids) + upload_files = self._load_upload_files(upload_file_ids) + for file in files: + parsed_reference = parse_file_reference(file.reference) + if parsed_reference is None: + raise ValueError("file id should not be None.") + + model_id = uuid.UUID(parsed_reference.record_id) + if file.transfer_method in ( + FileTransferMethod.LOCAL_FILE, + FileTransferMethod.REMOTE_URL, + FileTransferMethod.DATASOURCE_FILE, + ): + upload_file_row = upload_files.get(model_id) + if upload_file_row is None: + raise ValueError(f"Upload file not found for id: {model_id}") + file.reference = build_file_reference( + record_id=str(upload_file_row.id), + ) + file.storage_key = upload_file_row.key + elif file.transfer_method == FileTransferMethod.TOOL_FILE: + tool_file_row = tool_files.get(model_id) + if tool_file_row is None: + raise ValueError(f"Tool file not found for id: {model_id}") + file.reference = build_file_reference( + record_id=str(tool_file_row.id), + ) + file.storage_key = tool_file_row.file_key diff --git a/api/factories/file_factory/validation.py b/api/factories/file_factory/validation.py new file mode 100644 index 0000000000..4c4f6150e4 --- /dev/null +++ b/api/factories/file_factory/validation.py @@ -0,0 +1,44 @@ +"""Validation helpers for workflow file inputs.""" + +from __future__ import annotations + +from graphon.file import FileTransferMethod, FileType, FileUploadConfig + + +def is_file_valid_with_config( + *, + input_file_type: str, + file_extension: str, + file_transfer_method: FileTransferMethod, + config: FileUploadConfig, +) -> bool: + # FIXME(QIN2DIM): Always allow tool files (files generated by the assistant/model) + # These are internally generated and should bypass user upload restrictions + if file_transfer_method == FileTransferMethod.TOOL_FILE: + return True + + if ( + config.allowed_file_types + and input_file_type not in config.allowed_file_types + and input_file_type != FileType.CUSTOM + ): + return False + + if ( + input_file_type == FileType.CUSTOM + and config.allowed_file_extensions is not None + and file_extension not in config.allowed_file_extensions + ): + return False + + if input_file_type == FileType.IMAGE: + if ( + config.image_config + and config.image_config.transfer_methods + and file_transfer_method not in config.image_config.transfer_methods + ): + return False + elif config.allowed_file_upload_methods and file_transfer_method not in config.allowed_file_upload_methods: + return False + + return True diff --git a/api/factories/variable_factory.py b/api/factories/variable_factory.py index 14a56bf4a2..fd7acb14d3 100644 --- a/api/factories/variable_factory.py +++ b/api/factories/variable_factory.py @@ -1,75 +1,51 @@ +"""Compatibility factory for non-graph variable bootstrapping. + +Graph runtime segment/variable conversions live under `graphon.variables`. +This module keeps the application-layer mapping helpers and re-exports the +shared conversion functions for legacy callers and tests. +""" + from collections.abc import Mapping, Sequence from typing import Any, cast -from uuid import uuid4 from configs import dify_config -from dify_graph.constants import ( +from core.workflow.variable_prefixes import ( CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, ) -from dify_graph.file import File -from dify_graph.variables.exc import VariableError -from dify_graph.variables.segments import ( - ArrayAnySegment, - ArrayBooleanSegment, - ArrayFileSegment, - ArrayNumberSegment, - ArrayObjectSegment, - ArraySegment, - ArrayStringSegment, - BooleanSegment, - FileSegment, - FloatSegment, - IntegerSegment, - NoneSegment, - ObjectSegment, - Segment, - StringSegment, +from graphon.variables.exc import VariableError +from graphon.variables.factory import ( + TypeMismatchError, + UnsupportedSegmentTypeError, + build_segment, + build_segment_with_type, + segment_to_variable, ) -from dify_graph.variables.types import SegmentType -from dify_graph.variables.variables import ( - ArrayAnyVariable, +from graphon.variables.types import SegmentType +from graphon.variables.variables import ( ArrayBooleanVariable, - ArrayFileVariable, ArrayNumberVariable, ArrayObjectVariable, ArrayStringVariable, BooleanVariable, - FileVariable, FloatVariable, IntegerVariable, - NoneVariable, ObjectVariable, SecretVariable, StringVariable, VariableBase, ) - -class UnsupportedSegmentTypeError(Exception): - pass - - -class TypeMismatchError(Exception): - pass - - -# Define the constant -SEGMENT_TO_VARIABLE_MAP: Mapping[type[Segment], type[VariableBase]] = { - ArrayAnySegment: ArrayAnyVariable, - ArrayBooleanSegment: ArrayBooleanVariable, - ArrayFileSegment: ArrayFileVariable, - ArrayNumberSegment: ArrayNumberVariable, - ArrayObjectSegment: ArrayObjectVariable, - ArrayStringSegment: ArrayStringVariable, - BooleanSegment: BooleanVariable, - FileSegment: FileVariable, - FloatSegment: FloatVariable, - IntegerSegment: IntegerVariable, - NoneSegment: NoneVariable, - ObjectSegment: ObjectVariable, - StringSegment: StringVariable, -} +__all__ = [ + "TypeMismatchError", + "UnsupportedSegmentTypeError", + "build_conversation_variable_from_mapping", + "build_environment_variable_from_mapping", + "build_pipeline_variable_from_mapping", + "build_segment", + "build_segment_with_type", + "segment_to_variable", +] def build_conversation_variable_from_mapping(mapping: Mapping[str, Any], /) -> VariableBase: @@ -135,172 +111,3 @@ def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequen if not result.selector: result = result.model_copy(update={"selector": selector}) return cast(VariableBase, result) - - -def build_segment(value: Any, /) -> Segment: - # NOTE: If you have runtime type information available, consider using the `build_segment_with_type` - # below - if value is None: - return NoneSegment() - if isinstance(value, Segment): - return value - if isinstance(value, str): - return StringSegment(value=value) - if isinstance(value, bool): - return BooleanSegment(value=value) - if isinstance(value, int): - return IntegerSegment(value=value) - if isinstance(value, float): - return FloatSegment(value=value) - if isinstance(value, dict): - return ObjectSegment(value=value) - if isinstance(value, File): - return FileSegment(value=value) - if isinstance(value, list): - items = [build_segment(item) for item in value] - types = {item.value_type for item in items} - if all(isinstance(item, ArraySegment) for item in items): - return ArrayAnySegment(value=value) - elif len(types) != 1: - if types.issubset({SegmentType.NUMBER, SegmentType.INTEGER, SegmentType.FLOAT}): - return ArrayNumberSegment(value=value) - return ArrayAnySegment(value=value) - - match types.pop(): - case SegmentType.STRING: - return ArrayStringSegment(value=value) - case SegmentType.NUMBER | SegmentType.INTEGER | SegmentType.FLOAT: - return ArrayNumberSegment(value=value) - case SegmentType.BOOLEAN: - return ArrayBooleanSegment(value=value) - case SegmentType.OBJECT: - return ArrayObjectSegment(value=value) - case SegmentType.FILE: - return ArrayFileSegment(value=value) - case SegmentType.NONE: - return ArrayAnySegment(value=value) - case _: - # This should be unreachable. - raise ValueError(f"not supported value {value}") - raise ValueError(f"not supported value {value}") - - -_segment_factory: Mapping[SegmentType, type[Segment]] = { - SegmentType.NONE: NoneSegment, - SegmentType.STRING: StringSegment, - SegmentType.INTEGER: IntegerSegment, - SegmentType.FLOAT: FloatSegment, - SegmentType.FILE: FileSegment, - SegmentType.BOOLEAN: BooleanSegment, - SegmentType.OBJECT: ObjectSegment, - # Array types - SegmentType.ARRAY_ANY: ArrayAnySegment, - SegmentType.ARRAY_STRING: ArrayStringSegment, - SegmentType.ARRAY_NUMBER: ArrayNumberSegment, - SegmentType.ARRAY_OBJECT: ArrayObjectSegment, - SegmentType.ARRAY_FILE: ArrayFileSegment, - SegmentType.ARRAY_BOOLEAN: ArrayBooleanSegment, -} - - -def build_segment_with_type(segment_type: SegmentType, value: Any) -> Segment: - """ - Build a segment with explicit type checking. - - This function creates a segment from a value while enforcing type compatibility - with the specified segment_type. It provides stricter type validation compared - to the standard build_segment function. - - Args: - segment_type: The expected SegmentType for the resulting segment - value: The value to be converted into a segment - - Returns: - Segment: A segment instance of the appropriate type - - Raises: - TypeMismatchError: If the value type doesn't match the expected segment_type - - Special Cases: - - For empty list [] values, if segment_type is array[*], returns the corresponding array type - - Type validation is performed before segment creation - - Examples: - >>> build_segment_with_type(SegmentType.STRING, "hello") - StringSegment(value="hello") - - >>> build_segment_with_type(SegmentType.ARRAY_STRING, []) - ArrayStringSegment(value=[]) - - >>> build_segment_with_type(SegmentType.STRING, 123) - # Raises TypeMismatchError - """ - # Handle None values - if value is None: - if segment_type == SegmentType.NONE: - return NoneSegment() - else: - raise TypeMismatchError(f"Type mismatch: expected {segment_type}, but got None") - - # Handle empty list special case for array types - if isinstance(value, list) and len(value) == 0: - if segment_type == SegmentType.ARRAY_ANY: - return ArrayAnySegment(value=value) - elif segment_type == SegmentType.ARRAY_STRING: - return ArrayStringSegment(value=value) - elif segment_type == SegmentType.ARRAY_BOOLEAN: - return ArrayBooleanSegment(value=value) - elif segment_type == SegmentType.ARRAY_NUMBER: - return ArrayNumberSegment(value=value) - elif segment_type == SegmentType.ARRAY_OBJECT: - return ArrayObjectSegment(value=value) - elif segment_type == SegmentType.ARRAY_FILE: - return ArrayFileSegment(value=value) - else: - raise TypeMismatchError(f"Type mismatch: expected {segment_type}, but got empty list") - - inferred_type = SegmentType.infer_segment_type(value) - # Type compatibility checking - if inferred_type is None: - raise TypeMismatchError( - f"Type mismatch: expected {segment_type}, but got python object, type={type(value)}, value={value}" - ) - if inferred_type == segment_type: - segment_class = _segment_factory[segment_type] - return segment_class(value_type=segment_type, value=value) - elif segment_type == SegmentType.NUMBER and inferred_type in ( - SegmentType.INTEGER, - SegmentType.FLOAT, - ): - segment_class = _segment_factory[inferred_type] - return segment_class(value_type=inferred_type, value=value) - else: - raise TypeMismatchError(f"Type mismatch: expected {segment_type}, but got {inferred_type}, value={value}") - - -def segment_to_variable( - *, - segment: Segment, - selector: Sequence[str], - id: str | None = None, - name: str | None = None, - description: str = "", -) -> VariableBase: - if isinstance(segment, VariableBase): - return segment - name = name or selector[-1] - id = id or str(uuid4()) - - segment_type = type(segment) - if segment_type not in SEGMENT_TO_VARIABLE_MAP: - raise UnsupportedSegmentTypeError(f"not supported segment type {segment_type}") - - variable_class = SEGMENT_TO_VARIABLE_MAP[segment_type] - return variable_class( - id=id, - name=name, - description=description, - value_type=segment.value_type, - value=segment.value, - selector=list(selector), - ) diff --git a/api/fields/_value_type_serializer.py b/api/fields/_value_type_serializer.py index ac7c5376fb..b5acbbbcb4 100644 --- a/api/fields/_value_type_serializer.py +++ b/api/fields/_value_type_serializer.py @@ -1,7 +1,7 @@ from typing import TypedDict -from dify_graph.variables.segments import Segment -from dify_graph.variables.types import SegmentType +from graphon.variables.segments import Segment +from graphon.variables.types import SegmentType class _VarTypedDict(TypedDict, total=False): diff --git a/api/fields/conversation_fields.py b/api/fields/conversation_fields.py index a5c7ddbb11..801949747e 100644 --- a/api/fields/conversation_fields.py +++ b/api/fields/conversation_fields.py @@ -5,7 +5,7 @@ from typing import Any, TypeAlias from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator -from dify_graph.file import File +from graphon.file import File JSONValue: TypeAlias = Any @@ -311,7 +311,9 @@ def to_timestamp(value: datetime | None) -> int | None: def format_files_contained(value: JSONValue) -> JSONValue: if isinstance(value, File): - return value.model_dump() + # Response payloads must preserve legacy file keys like `related_id`/`url` + # while still exposing the new graph-layer `reference` field. + return value.to_dict() if isinstance(value, dict): return {k: format_files_contained(v) for k, v in value.items()} if isinstance(value, list): diff --git a/api/fields/member_fields.py b/api/fields/member_fields.py index 7ee628726b..4e201e66e6 100644 --- a/api/fields/member_fields.py +++ b/api/fields/member_fields.py @@ -5,7 +5,7 @@ from datetime import datetime from flask_restx import fields from pydantic import BaseModel, ConfigDict, computed_field, field_validator -from dify_graph.file import helpers as file_helpers +from graphon.file import helpers as file_helpers simple_account_fields = { "id": fields.String, diff --git a/api/fields/message_fields.py b/api/fields/message_fields.py index 428f92ed33..86c4f285cd 100644 --- a/api/fields/message_fields.py +++ b/api/fields/message_fields.py @@ -7,8 +7,8 @@ from uuid import uuid4 from pydantic import BaseModel, ConfigDict, Field, field_validator from core.entities.execution_extra_content import ExecutionExtraContentDomainModel -from dify_graph.file import File from fields.conversation_fields import AgentThought, JSONValue, MessageFile +from graphon.file import File JSONValueType: TypeAlias = JSONValue @@ -133,7 +133,9 @@ def to_timestamp(value: datetime | None) -> int | None: def format_files_contained(value: JSONValueType) -> JSONValueType: if isinstance(value, File): - return value.model_dump() + # Response payloads must preserve legacy file keys like `related_id`/`url` + # while still exposing the new graph-layer `reference` field. + return value.to_dict() if isinstance(value, dict): return {k: format_files_contained(v) for k, v in value.items()} if isinstance(value, list): diff --git a/api/fields/raws.py b/api/fields/raws.py index 318dedc25c..ee6f53b360 100644 --- a/api/fields/raws.py +++ b/api/fields/raws.py @@ -1,6 +1,6 @@ from flask_restx import fields -from dify_graph.file import File +from graphon.file import File class FilesContainedField(fields.Raw): diff --git a/api/fields/workflow_fields.py b/api/fields/workflow_fields.py index 7ce2139687..f9b5e98936 100644 --- a/api/fields/workflow_fields.py +++ b/api/fields/workflow_fields.py @@ -1,8 +1,8 @@ from flask_restx import fields from core.helper import encrypter -from dify_graph.variables import SecretVariable, SegmentType, VariableBase from fields.member_fields import simple_account_fields +from graphon.variables import SecretVariable, SegmentType, VariableBase from libs.helper import TimestampField from ._value_type_serializer import serialize_value_type diff --git a/api/dify_graph/README.md b/api/graphon/README.md similarity index 98% rename from api/dify_graph/README.md rename to api/graphon/README.md index 2fc5b8b890..725f122cd8 100644 --- a/api/dify_graph/README.md +++ b/api/graphon/README.md @@ -114,7 +114,7 @@ The codebase enforces strict layering via import-linter: 1. Inherit from `BaseNode` or appropriate base class 1. Implement `_run()` method 1. Ensure the node module is importable under `nodes//` -1. Add tests in `tests/unit_tests/dify_graph/nodes/` +1. Add tests in `tests/unit_tests/graphon/nodes/` ### Implementing a Custom Layer diff --git a/api/dify_graph/__init__.py b/api/graphon/__init__.py similarity index 100% rename from api/dify_graph/__init__.py rename to api/graphon/__init__.py diff --git a/api/dify_graph/entities/__init__.py b/api/graphon/entities/__init__.py similarity index 100% rename from api/dify_graph/entities/__init__.py rename to api/graphon/entities/__init__.py diff --git a/api/dify_graph/entities/base_node_data.py b/api/graphon/entities/base_node_data.py similarity index 98% rename from api/dify_graph/entities/base_node_data.py rename to api/graphon/entities/base_node_data.py index 47b37c9daf..e8267043a9 100644 --- a/api/dify_graph/entities/base_node_data.py +++ b/api/graphon/entities/base_node_data.py @@ -8,8 +8,8 @@ from typing import Any, Union from pydantic import BaseModel, ConfigDict, Field, model_validator -from dify_graph.entities.exc import DefaultValueTypeError -from dify_graph.enums import ErrorStrategy, NodeType +from graphon.entities.exc import DefaultValueTypeError +from graphon.enums import ErrorStrategy, NodeType # Project supports Python 3.11+, where `typing.Union[...]` is valid in `isinstance`. _NumberType = Union[int, float] diff --git a/api/dify_graph/entities/exc.py b/api/graphon/entities/exc.py similarity index 100% rename from api/dify_graph/entities/exc.py rename to api/graphon/entities/exc.py diff --git a/api/dify_graph/entities/graph_config.py b/api/graphon/entities/graph_config.py similarity index 89% rename from api/dify_graph/entities/graph_config.py rename to api/graphon/entities/graph_config.py index 36f7b94e82..392241c631 100644 --- a/api/dify_graph/entities/graph_config.py +++ b/api/graphon/entities/graph_config.py @@ -4,7 +4,7 @@ import sys from pydantic import TypeAdapter, with_config -from dify_graph.entities.base_node_data import BaseNodeData +from graphon.entities.base_node_data import BaseNodeData if sys.version_info >= (3, 12): from typing import TypedDict diff --git a/api/dify_graph/entities/graph_init_params.py b/api/graphon/entities/graph_init_params.py similarity index 100% rename from api/dify_graph/entities/graph_init_params.py rename to api/graphon/entities/graph_init_params.py diff --git a/api/dify_graph/entities/pause_reason.py b/api/graphon/entities/pause_reason.py similarity index 80% rename from api/dify_graph/entities/pause_reason.py rename to api/graphon/entities/pause_reason.py index 86d8c8ca16..ba2973fd45 100644 --- a/api/dify_graph/entities/pause_reason.py +++ b/api/graphon/entities/pause_reason.py @@ -4,7 +4,7 @@ from typing import Annotated, Any, Literal, TypeAlias from pydantic import BaseModel, Field -from dify_graph.nodes.human_input.entities import FormInput, UserAction +from graphon.nodes.human_input.entities import FormInput, UserAction class PauseReasonType(StrEnum): @@ -18,7 +18,6 @@ class HumanInputRequired(BaseModel): form_content: str inputs: list[FormInput] = Field(default_factory=list) actions: list[UserAction] = Field(default_factory=list) - display_in_ui: bool = False node_id: str node_title: str @@ -33,13 +32,6 @@ class HumanInputRequired(BaseModel): # Only form inputs with default value type `VARIABLE` will be resolved and stored in `resolved_default_values`. resolved_default_values: Mapping[str, Any] = Field(default_factory=dict) - # The `form_token` is the token used to submit the form via UI surfaces. It corresponds to - # `HumanInputFormRecipient.access_token`. - # - # This field is `None` if webapp delivery is not set and not - # in orchestrating mode. - form_token: str | None = None - class SchedulingPause(BaseModel): TYPE: Literal[PauseReasonType.SCHEDULED_PAUSE] = PauseReasonType.SCHEDULED_PAUSE diff --git a/api/dify_graph/entities/workflow_execution.py b/api/graphon/entities/workflow_execution.py similarity index 79% rename from api/dify_graph/entities/workflow_execution.py rename to api/graphon/entities/workflow_execution.py index 459ac46415..b8de7eed1a 100644 --- a/api/dify_graph/entities/workflow_execution.py +++ b/api/graphon/entities/workflow_execution.py @@ -1,26 +1,23 @@ """ Domain entities for workflow execution. -Models are independent of the storage mechanism and don't contain -implementation details like tenant_id, app_id, etc. +Models describe graph runtime state and avoid infrastructure-specific details. """ from __future__ import annotations from collections.abc import Mapping -from datetime import datetime +from datetime import UTC, datetime from typing import Any from pydantic import BaseModel, Field -from dify_graph.enums import WorkflowExecutionStatus, WorkflowType -from libs.datetime_utils import naive_utc_now +from graphon.enums import WorkflowExecutionStatus, WorkflowType class WorkflowExecution(BaseModel): """ - Domain model for workflow execution based on WorkflowRun but without - user, tenant, and app attributes. + Domain model for a workflow execution within the graph runtime. """ id_: str = Field(...) @@ -47,7 +44,7 @@ class WorkflowExecution(BaseModel): Calculate elapsed time in seconds. If workflow is not finished, use current time. """ - end_time = self.finished_at or naive_utc_now() + end_time = self.finished_at or datetime.now(UTC).replace(tzinfo=None) return (end_time - self.started_at).total_seconds() @classmethod diff --git a/api/dify_graph/entities/workflow_node_execution.py b/api/graphon/entities/workflow_node_execution.py similarity index 85% rename from api/dify_graph/entities/workflow_node_execution.py rename to api/graphon/entities/workflow_node_execution.py index bc7e0d02e5..5458572e7e 100644 --- a/api/dify_graph/entities/workflow_node_execution.py +++ b/api/graphon/entities/workflow_node_execution.py @@ -1,9 +1,8 @@ """ Domain entities for workflow node execution. -This module contains the domain model for workflow node execution, which is used -by the core workflow module. These models are independent of the storage mechanism -and don't contain implementation details like tenant_id, app_id, etc. +These models capture node-level execution state for the graph runtime without +describing storage or application-layer concerns. """ from collections.abc import Mapping @@ -12,20 +11,15 @@ from typing import Any from pydantic import BaseModel, Field, PrivateAttr -from dify_graph.enums import NodeType, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus +from graphon.enums import NodeType, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus class WorkflowNodeExecution(BaseModel): """ Domain model for workflow node execution. - This model represents the core business entity of a node execution, - without implementation details like tenant_id, app_id, etc. - - Note: User/context-specific fields (triggered_from, created_by, created_by_role) - have been moved to the repository implementation to keep the domain model clean. - These fields are still accepted in the constructor for backward compatibility, - but they are not stored in the model. + This model represents the graph-level record of a node execution and + contains only execution state relevant to the runtime. """ # --------- Core identification fields --------- @@ -41,7 +35,7 @@ class WorkflowNodeExecution(BaseModel): # In most scenarios, `id` should be used as the primary identifier. node_execution_id: str | None = None workflow_id: str # ID of the workflow this node belongs to - workflow_execution_id: str | None = None # ID of the specific workflow run (null for single-step debugging) + workflow_execution_id: str | None = None # ID of the workflow execution (null for single-step debugging) # --------- Core identification fields ends --------- # Execution positioning and flow diff --git a/api/dify_graph/entities/workflow_start_reason.py b/api/graphon/entities/workflow_start_reason.py similarity index 100% rename from api/dify_graph/entities/workflow_start_reason.py rename to api/graphon/entities/workflow_start_reason.py diff --git a/api/dify_graph/enums.py b/api/graphon/enums.py similarity index 93% rename from api/dify_graph/enums.py rename to api/graphon/enums.py index f0333b3e1c..c8ee388751 100644 --- a/api/dify_graph/enums.py +++ b/api/graphon/enums.py @@ -10,30 +10,6 @@ class NodeState(StrEnum): SKIPPED = "skipped" -class SystemVariableKey(StrEnum): - """ - System Variables. - """ - - QUERY = "query" - FILES = "files" - CONVERSATION_ID = "conversation_id" - USER_ID = "user_id" - DIALOGUE_COUNT = "dialogue_count" - APP_ID = "app_id" - WORKFLOW_ID = "workflow_id" - WORKFLOW_EXECUTION_ID = "workflow_run_id" - TIMESTAMP = "timestamp" - # RAG Pipeline - DOCUMENT_ID = "document_id" - ORIGINAL_DOCUMENT_ID = "original_document_id" - BATCH = "batch" - DATASET_ID = "dataset_id" - DATASOURCE_TYPE = "datasource_type" - DATASOURCE_INFO = "datasource_info" - INVOKE_FROM = "invoke_from" - - NodeType: TypeAlias = str @@ -41,7 +17,7 @@ class BuiltinNodeTypes: """Built-in node type string constants. `node_type` values are plain strings throughout the graph runtime. This namespace - only exposes the built-in values shipped by `dify_graph`; downstream packages can + only exposes the built-in values shipped by `graphon`; downstream packages can use additional strings without extending this class. """ diff --git a/api/dify_graph/errors.py b/api/graphon/errors.py similarity index 89% rename from api/dify_graph/errors.py rename to api/graphon/errors.py index 463d17713e..7eb007524d 100644 --- a/api/dify_graph/errors.py +++ b/api/graphon/errors.py @@ -1,4 +1,4 @@ -from dify_graph.nodes.base.node import Node +from graphon.nodes.base.node import Node class WorkflowNodeRunFailedError(Exception): diff --git a/api/dify_graph/file/__init__.py b/api/graphon/file/__init__.py similarity index 75% rename from api/dify_graph/file/__init__.py rename to api/graphon/file/__init__.py index 44749ebec3..4908ae9795 100644 --- a/api/dify_graph/file/__init__.py +++ b/api/graphon/file/__init__.py @@ -1,5 +1,6 @@ from .constants import FILE_MODEL_IDENTITY from .enums import ArrayFileAttribute, FileAttribute, FileBelongsTo, FileTransferMethod, FileType +from .file_factory import get_file_type_by_mime_type, standardize_file_type from .models import ( File, FileUploadConfig, @@ -16,4 +17,6 @@ __all__ = [ "FileType", "FileUploadConfig", "ImageConfig", + "get_file_type_by_mime_type", + "standardize_file_type", ] diff --git a/api/graphon/file/constants.py b/api/graphon/file/constants.py new file mode 100644 index 0000000000..56b95b5f0d --- /dev/null +++ b/api/graphon/file/constants.py @@ -0,0 +1,48 @@ +from collections.abc import Iterable +from typing import Any + +# TODO(QuantumGhost): Refactor variable type identification. Instead of directly +# comparing `dify_model_identity` with constants throughout the codebase, extract +# this logic into a dedicated function. This would encapsulate the implementation +# details of how different variable types are identified. +FILE_MODEL_IDENTITY = "__dify__file__" +DEFAULT_MIME_TYPE = "application/octet-stream" +DEFAULT_EXTENSION = ".bin" + + +def _with_case_variants(extensions: Iterable[str]) -> frozenset[str]: + normalized = {extension.lower() for extension in extensions} + return frozenset(normalized | {extension.upper() for extension in normalized}) + + +IMAGE_EXTENSIONS = _with_case_variants({"jpg", "jpeg", "png", "webp", "gif", "svg"}) +VIDEO_EXTENSIONS = _with_case_variants({"mp4", "mov", "mpeg", "webm"}) +AUDIO_EXTENSIONS = _with_case_variants({"mp3", "m4a", "wav", "amr", "mpga"}) +DOCUMENT_EXTENSIONS = _with_case_variants( + { + "txt", + "markdown", + "md", + "mdx", + "pdf", + "html", + "htm", + "xlsx", + "xls", + "vtt", + "properties", + "doc", + "docx", + "csv", + "eml", + "msg", + "ppt", + "pptx", + "xml", + "epub", + } +) + + +def maybe_file_object(o: Any) -> bool: + return isinstance(o, dict) and o.get("dify_model_identity") == FILE_MODEL_IDENTITY diff --git a/api/dify_graph/file/enums.py b/api/graphon/file/enums.py similarity index 100% rename from api/dify_graph/file/enums.py rename to api/graphon/file/enums.py diff --git a/api/graphon/file/file_factory.py b/api/graphon/file/file_factory.py new file mode 100644 index 0000000000..3d20b9377d --- /dev/null +++ b/api/graphon/file/file_factory.py @@ -0,0 +1,39 @@ +from .constants import AUDIO_EXTENSIONS, DOCUMENT_EXTENSIONS, IMAGE_EXTENSIONS, VIDEO_EXTENSIONS +from .enums import FileType + + +def standardize_file_type(*, extension: str = "", mime_type: str = "") -> FileType: + """ + Infer the actual file type from extension and mime type. + """ + guessed_type = None + if extension: + guessed_type = _get_file_type_by_extension(extension) + if guessed_type is None and mime_type: + guessed_type = get_file_type_by_mime_type(mime_type) + return guessed_type or FileType.CUSTOM + + +def _get_file_type_by_extension(extension: str) -> FileType | None: + normalized_extension = extension.lstrip(".") + if normalized_extension in IMAGE_EXTENSIONS: + return FileType.IMAGE + if normalized_extension in VIDEO_EXTENSIONS: + return FileType.VIDEO + if normalized_extension in AUDIO_EXTENSIONS: + return FileType.AUDIO + if normalized_extension in DOCUMENT_EXTENSIONS: + return FileType.DOCUMENT + return None + + +def get_file_type_by_mime_type(mime_type: str) -> FileType: + if "image" in mime_type: + return FileType.IMAGE + if "video" in mime_type: + return FileType.VIDEO + if "audio" in mime_type: + return FileType.AUDIO + if "text" in mime_type or "pdf" in mime_type: + return FileType.DOCUMENT + return FileType.CUSTOM diff --git a/api/dify_graph/file/file_manager.py b/api/graphon/file/file_manager.py similarity index 74% rename from api/dify_graph/file/file_manager.py rename to api/graphon/file/file_manager.py index 8d998054db..d7e4d472e7 100644 --- a/api/dify_graph/file/file_manager.py +++ b/api/graphon/file/file_manager.py @@ -3,16 +3,15 @@ from __future__ import annotations import base64 from collections.abc import Mapping -from dify_graph.model_runtime.entities import ( +from graphon.model_runtime.entities import ( AudioPromptMessageContent, DocumentPromptMessageContent, ImagePromptMessageContent, TextPromptMessageContent, VideoPromptMessageContent, ) -from dify_graph.model_runtime.entities.message_entities import PromptMessageContentUnionTypes +from graphon.model_runtime.entities.message_entities import PromptMessageContentUnionTypes -from . import helpers from .enums import FileAttribute from .models import File, FileTransferMethod, FileType from .runtime import get_workflow_file_runtime @@ -80,7 +79,7 @@ def download(f: File, /) -> bytes: FileTransferMethod.LOCAL_FILE, FileTransferMethod.DATASOURCE_FILE, ): - return _download_file_content(f.storage_key) + return _download_file_content(f) elif f.transfer_method == FileTransferMethod.REMOTE_URL: if f.remote_url is None: raise ValueError("Missing file remote_url") @@ -90,12 +89,9 @@ def download(f: File, /) -> bytes: raise ValueError(f"unsupported transfer method: {f.transfer_method}") -def _download_file_content(path: str, /) -> bytes: +def _download_file_content(file: File, /) -> bytes: """Download and return a file from storage as bytes.""" - data = get_workflow_file_runtime().storage_load(path, stream=False) - if not isinstance(data, bytes): - raise ValueError(f"file {path} is not a bytes object") - return data + return get_workflow_file_runtime().load_file_bytes(file=file) def _get_encoded_string(f: File, /) -> str: @@ -107,30 +103,20 @@ def _get_encoded_string(f: File, /) -> str: response.raise_for_status() data = response.content case FileTransferMethod.LOCAL_FILE: - data = _download_file_content(f.storage_key) + data = _download_file_content(f) case FileTransferMethod.TOOL_FILE: - data = _download_file_content(f.storage_key) + data = _download_file_content(f) case FileTransferMethod.DATASOURCE_FILE: - data = _download_file_content(f.storage_key) + data = _download_file_content(f) return base64.b64encode(data).decode("utf-8") def _to_url(f: File, /): - if f.transfer_method == FileTransferMethod.REMOTE_URL: - if f.remote_url is None: - raise ValueError("Missing file remote_url") - return f.remote_url - elif f.transfer_method == FileTransferMethod.LOCAL_FILE: - if f.related_id is None: - raise ValueError("Missing file related_id") - return f.remote_url or helpers.get_signed_file_url(upload_file_id=f.related_id) - elif f.transfer_method == FileTransferMethod.TOOL_FILE: - if f.related_id is None or f.extension is None: - raise ValueError("Missing file related_id or extension") - return helpers.get_signed_tool_file_url(tool_file_id=f.related_id, extension=f.extension) - else: + url = f.generate_url() + if url is None: raise ValueError(f"Unsupported transfer method: {f.transfer_method}") + return url class FileManager: diff --git a/api/graphon/file/helpers.py b/api/graphon/file/helpers.py new file mode 100644 index 0000000000..dade761227 --- /dev/null +++ b/api/graphon/file/helpers.py @@ -0,0 +1,48 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from .runtime import get_workflow_file_runtime + +if TYPE_CHECKING: + from .models import File + + +def resolve_file_url(file: File, /, *, for_external: bool = True) -> str | None: + return get_workflow_file_runtime().resolve_file_url(file=file, for_external=for_external) + + +def get_signed_file_url(upload_file_id: str, as_attachment: bool = False, for_external: bool = True) -> str: + return get_workflow_file_runtime().resolve_upload_file_url( + upload_file_id=upload_file_id, + as_attachment=as_attachment, + for_external=for_external, + ) + + +def get_signed_tool_file_url(tool_file_id: str, extension: str, for_external: bool = True) -> str: + return get_workflow_file_runtime().resolve_tool_file_url( + tool_file_id=tool_file_id, + extension=extension, + for_external=for_external, + ) + + +def verify_image_signature(*, upload_file_id: str, timestamp: str, nonce: str, sign: str) -> bool: + return get_workflow_file_runtime().verify_preview_signature( + preview_kind="image", + file_id=upload_file_id, + timestamp=timestamp, + nonce=nonce, + sign=sign, + ) + + +def verify_file_signature(*, upload_file_id: str, timestamp: str, nonce: str, sign: str) -> bool: + return get_workflow_file_runtime().verify_preview_signature( + preview_kind="file", + file_id=upload_file_id, + timestamp=timestamp, + nonce=nonce, + sign=sign, + ) diff --git a/api/dify_graph/file/models.py b/api/graphon/file/models.py similarity index 61% rename from api/dify_graph/file/models.py rename to api/graphon/file/models.py index dcba00978e..ccd7584371 100644 --- a/api/dify_graph/file/models.py +++ b/api/graphon/file/models.py @@ -1,17 +1,20 @@ from __future__ import annotations +import base64 +import json from collections.abc import Mapping, Sequence from typing import Any -from uuid import UUID, uuid4 from pydantic import BaseModel, Field, model_validator -from dify_graph.model_runtime.entities.message_entities import ImagePromptMessageContent +from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent from . import helpers from .constants import FILE_MODEL_IDENTITY from .enums import FileTransferMethod, FileType +_FILE_REFERENCE_PREFIX = "dify-file-ref:" + def sign_tool_file(*, tool_file_id: str, extension: str, for_external: bool = True) -> str: """Compatibility shim for tests and legacy callers patching ``models.sign_tool_file``.""" @@ -44,57 +47,68 @@ class FileUploadConfig(BaseModel): number_limits: int = 0 -class ToolFile(BaseModel): - id: UUID = Field(default_factory=uuid4, description="Unique identifier for the file") - user_id: UUID = Field(..., description="ID of the user who owns this file") - tenant_id: UUID = Field(..., description="ID of the tenant/organization") - conversation_id: UUID | None = Field(None, description="ID of the associated conversation") - file_key: str = Field(..., max_length=255, description="Storage key for the file") - mimetype: str = Field(..., max_length=255, description="MIME type of the file") - original_url: str | None = Field( - None, max_length=2048, description="Original URL if file was fetched from external source" - ) - name: str = Field(default="", max_length=255, description="Display name of the file") - size: int = Field(default=-1, ge=-1, description="File size in bytes (-1 if unknown)") +def _parse_reference(reference: str | None) -> tuple[str | None, str | None]: + """Best-effort parser for record references and historical storage-key payloads.""" + if not reference: + return None, None - class Config: - from_attributes = True # Enable ORM mode for SQLAlchemy compatibility - populate_by_name = True + if not reference.startswith(_FILE_REFERENCE_PREFIX): + return reference, None + + encoded_payload = reference.removeprefix(_FILE_REFERENCE_PREFIX) + try: + payload = json.loads(base64.urlsafe_b64decode(encoded_payload.encode())) + except (ValueError, json.JSONDecodeError): + return reference, None + + record_id = payload.get("record_id") + if not isinstance(record_id, str) or not record_id: + return reference, None + + storage_key = payload.get("storage_key") + if not isinstance(storage_key, str): + storage_key = None + + return record_id, storage_key class File(BaseModel): + """Graph-owned file reference. + + The graph layer deliberately keeps only the metadata required to route, + serialize, and render files. Application ownership concerns such as + tenant/user/conversation identity stay in the workflow/storage layer. + """ + # NOTE: dify_model_identity is a special identifier used to distinguish between # new and old data formats during serialization and deserialization. dify_model_identity: str = FILE_MODEL_IDENTITY id: str | None = None # message file id - tenant_id: str type: FileType transfer_method: FileTransferMethod # If `transfer_method` is `FileTransferMethod.remote_url`, the # `remote_url` attribute must not be `None`. remote_url: str | None = None # remote url - # If `transfer_method` is `FileTransferMethod.local_file` or - # `FileTransferMethod.tool_file`, the `related_id` attribute must not be `None`. - # - # It should be set to `ToolFile.id` when `transfer_method` is `tool_file`. - related_id: str | None = None + # Opaque workflow-layer reference for files resolved outside ``graphon``. + # New payloads only carry the backing record id; historical payloads may + # still include storage_key and must remain readable. + reference: str | None = None filename: str | None = None extension: str | None = Field(default=None, description="File extension, should contain dot") mime_type: str | None = None size: int = -1 - - # Those properties are private, should not be exposed to the outside. _storage_key: str def __init__( self, *, id: str | None = None, - tenant_id: str, + tenant_id: str | None = None, type: FileType, transfer_method: FileTransferMethod, remote_url: str | None = None, + reference: str | None = None, related_id: str | None = None, filename: str | None = None, extension: str | None = None, @@ -103,18 +117,23 @@ class File(BaseModel): storage_key: str | None = None, dify_model_identity: str | None = FILE_MODEL_IDENTITY, url: str | None = None, - # Legacy compatibility fields - explicitly handle known extra fields + # Legacy compatibility fields - explicitly accept known extra fields tool_file_id: str | None = None, upload_file_id: str | None = None, datasource_file_id: str | None = None, ): + legacy_record_id = related_id or tool_file_id or upload_file_id or datasource_file_id + normalized_reference = reference + if normalized_reference is None and legacy_record_id is not None: + normalized_reference = str(legacy_record_id) + _, parsed_storage_key = _parse_reference(normalized_reference) + super().__init__( id=id, - tenant_id=tenant_id, type=type, transfer_method=transfer_method, remote_url=remote_url, - related_id=related_id, + reference=normalized_reference, filename=filename, extension=extension, mime_type=mime_type, @@ -122,12 +141,15 @@ class File(BaseModel): dify_model_identity=dify_model_identity, url=url, ) - self._storage_key = str(storage_key) + # Accept legacy constructor fields without promoting them back into the graph model. + _ = tenant_id + self._storage_key = storage_key or parsed_storage_key or "" def to_dict(self) -> Mapping[str, str | int | None]: data = self.model_dump(mode="json") return { **data, + "related_id": self.related_id, "url": self.generate_url(), } @@ -142,21 +164,7 @@ class File(BaseModel): return text def generate_url(self, for_external: bool = True) -> str | None: - if self.transfer_method == FileTransferMethod.REMOTE_URL: - return self.remote_url - elif self.transfer_method == FileTransferMethod.LOCAL_FILE: - if self.related_id is None: - raise ValueError("Missing file related_id") - return helpers.get_signed_file_url(upload_file_id=self.related_id, for_external=for_external) - elif self.transfer_method in [FileTransferMethod.TOOL_FILE, FileTransferMethod.DATASOURCE_FILE]: - assert self.related_id is not None - assert self.extension is not None - return sign_tool_file( - tool_file_id=self.related_id, - extension=self.extension, - for_external=for_external, - ) - return None + return helpers.resolve_file_url(self, for_external=for_external) def to_plugin_parameter(self) -> dict[str, Any]: return { @@ -178,19 +186,29 @@ class File(BaseModel): if not isinstance(self.remote_url, str) or not self.remote_url.startswith("http"): raise ValueError("Invalid file url") case FileTransferMethod.LOCAL_FILE: - if not self.related_id: - raise ValueError("Missing file related_id") + if not self.reference: + raise ValueError("Missing file reference") case FileTransferMethod.TOOL_FILE: - if not self.related_id: - raise ValueError("Missing file related_id") + if not self.reference: + raise ValueError("Missing file reference") case FileTransferMethod.DATASOURCE_FILE: - if not self.related_id: - raise ValueError("Missing file related_id") + if not self.reference: + raise ValueError("Missing file reference") return self + @property + def related_id(self) -> str | None: + record_id, _ = _parse_reference(self.reference) + return record_id + + @related_id.setter + def related_id(self, value: str | None) -> None: + self.reference = value + @property def storage_key(self) -> str: - return self._storage_key + _, storage_key = _parse_reference(self.reference) + return storage_key or self._storage_key @storage_key.setter def storage_key(self, value: str) -> None: diff --git a/api/graphon/file/protocols.py b/api/graphon/file/protocols.py new file mode 100644 index 0000000000..0acabe35e5 --- /dev/null +++ b/api/graphon/file/protocols.py @@ -0,0 +1,56 @@ +from __future__ import annotations + +from collections.abc import Generator +from typing import TYPE_CHECKING, Literal, Protocol + +if TYPE_CHECKING: + from .models import File + + +class HttpResponseProtocol(Protocol): + """Subset of response behavior needed by workflow file helpers.""" + + @property + def content(self) -> bytes: ... + + def raise_for_status(self) -> object: ... + + +class WorkflowFileRuntimeProtocol(Protocol): + """Runtime dependencies required by ``graphon.file``. + + Implementations are expected to be provided by integration layers (for example, + ``core.app.workflow.file_runtime``) so the workflow package avoids importing + application infrastructure modules directly. + """ + + @property + def multimodal_send_format(self) -> str: ... + + def http_get(self, url: str, *, follow_redirects: bool = True) -> HttpResponseProtocol: ... + + def storage_load(self, path: str, *, stream: bool = False) -> bytes | Generator: ... + + def load_file_bytes(self, *, file: File) -> bytes: ... + + def resolve_file_url(self, *, file: File, for_external: bool = True) -> str | None: ... + + def resolve_upload_file_url( + self, + *, + upload_file_id: str, + as_attachment: bool = False, + for_external: bool = True, + ) -> str: ... + + def resolve_tool_file_url(self, *, tool_file_id: str, extension: str, for_external: bool = True) -> str: ... + + def verify_preview_signature( + self, + *, + preview_kind: Literal["image", "file"], + file_id: str, + timestamp: str, + nonce: str, + sign: str, + ) -> bool: ... diff --git a/api/dify_graph/file/runtime.py b/api/graphon/file/runtime.py similarity index 63% rename from api/dify_graph/file/runtime.py rename to api/graphon/file/runtime.py index 94253e0255..1c5d1c3ca4 100644 --- a/api/dify_graph/file/runtime.py +++ b/api/graphon/file/runtime.py @@ -1,10 +1,13 @@ from __future__ import annotations from collections.abc import Generator -from typing import NoReturn +from typing import TYPE_CHECKING, Literal, NoReturn from .protocols import HttpResponseProtocol, WorkflowFileRuntimeProtocol +if TYPE_CHECKING: + from .models import File + class WorkflowFileRuntimeNotConfiguredError(RuntimeError): """Raised when workflow file runtime dependencies were not configured.""" @@ -16,22 +19,6 @@ class _UnconfiguredWorkflowFileRuntime(WorkflowFileRuntimeProtocol): "workflow file runtime is not configured, call set_workflow_file_runtime(...) first" ) - @property - def files_url(self) -> str: - self._raise() - - @property - def internal_files_url(self) -> str | None: - self._raise() - - @property - def secret_key(self) -> str: - self._raise() - - @property - def files_access_timeout(self) -> int: - self._raise() - @property def multimodal_send_format(self) -> str: self._raise() @@ -42,7 +29,33 @@ class _UnconfiguredWorkflowFileRuntime(WorkflowFileRuntimeProtocol): def storage_load(self, path: str, *, stream: bool = False) -> bytes | Generator: self._raise() - def sign_tool_file(self, *, tool_file_id: str, extension: str, for_external: bool = True) -> str: + def load_file_bytes(self, *, file: File) -> bytes: + self._raise() + + def resolve_file_url(self, *, file: File, for_external: bool = True) -> str | None: + self._raise() + + def resolve_upload_file_url( + self, + *, + upload_file_id: str, + as_attachment: bool = False, + for_external: bool = True, + ) -> str: + self._raise() + + def resolve_tool_file_url(self, *, tool_file_id: str, extension: str, for_external: bool = True) -> str: + self._raise() + + def verify_preview_signature( + self, + *, + preview_kind: Literal["image", "file"], + file_id: str, + timestamp: str, + nonce: str, + sign: str, + ) -> bool: self._raise() diff --git a/api/dify_graph/file/tool_file_parser.py b/api/graphon/file/tool_file_parser.py similarity index 100% rename from api/dify_graph/file/tool_file_parser.py rename to api/graphon/file/tool_file_parser.py diff --git a/api/dify_graph/graph/__init__.py b/api/graphon/graph/__init__.py similarity index 100% rename from api/dify_graph/graph/__init__.py rename to api/graphon/graph/__init__.py diff --git a/api/dify_graph/graph/edge.py b/api/graphon/graph/edge.py similarity index 91% rename from api/dify_graph/graph/edge.py rename to api/graphon/graph/edge.py index f4f67ea6be..1f8a2884e3 100644 --- a/api/dify_graph/graph/edge.py +++ b/api/graphon/graph/edge.py @@ -1,7 +1,7 @@ import uuid from dataclasses import dataclass, field -from dify_graph.enums import NodeState +from graphon.enums import NodeState @dataclass diff --git a/api/dify_graph/graph/graph.py b/api/graphon/graph/graph.py similarity index 98% rename from api/dify_graph/graph/graph.py rename to api/graphon/graph/graph.py index 85117583e0..0f4cd8925f 100644 --- a/api/dify_graph/graph/graph.py +++ b/api/graphon/graph/graph.py @@ -7,10 +7,9 @@ from typing import Protocol, cast, final from pydantic import TypeAdapter -from dify_graph.entities.graph_config import NodeConfigDict -from dify_graph.enums import ErrorStrategy, NodeExecutionType, NodeState -from dify_graph.nodes.base.node import Node -from libs.typing import is_str +from graphon.entities.graph_config import NodeConfigDict +from graphon.enums import ErrorStrategy, NodeExecutionType, NodeState +from graphon.nodes.base.node import Node from .edge import Edge from .validation import get_graph_validator @@ -102,7 +101,7 @@ class Graph: source = edge_config.get("source") target = edge_config.get("target") - if not is_str(source) or not is_str(target): + if not isinstance(source, str) or not isinstance(target, str): continue # Create edge @@ -110,7 +109,7 @@ class Graph: edge_counter += 1 source_handle = edge_config.get("sourceHandle", "source") - if not is_str(source_handle): + if not isinstance(source_handle, str): continue edge = Edge( diff --git a/api/dify_graph/graph/graph_template.py b/api/graphon/graph/graph_template.py similarity index 100% rename from api/dify_graph/graph/graph_template.py rename to api/graphon/graph/graph_template.py diff --git a/api/dify_graph/graph/validation.py b/api/graphon/graph/validation.py similarity index 98% rename from api/dify_graph/graph/validation.py rename to api/graphon/graph/validation.py index 50d1440b04..04b501fd33 100644 --- a/api/dify_graph/graph/validation.py +++ b/api/graphon/graph/validation.py @@ -4,7 +4,7 @@ from collections.abc import Sequence from dataclasses import dataclass from typing import TYPE_CHECKING, Protocol -from dify_graph.enums import BuiltinNodeTypes, NodeExecutionType, NodeType +from graphon.enums import BuiltinNodeTypes, NodeExecutionType, NodeType if TYPE_CHECKING: from .graph import Graph diff --git a/api/dify_graph/graph_engine/__init__.py b/api/graphon/graph_engine/__init__.py similarity index 100% rename from api/dify_graph/graph_engine/__init__.py rename to api/graphon/graph_engine/__init__.py diff --git a/api/dify_graph/graph_engine/_engine_utils.py b/api/graphon/graph_engine/_engine_utils.py similarity index 100% rename from api/dify_graph/graph_engine/_engine_utils.py rename to api/graphon/graph_engine/_engine_utils.py diff --git a/api/dify_graph/graph_engine/command_channels/README.md b/api/graphon/graph_engine/command_channels/README.md similarity index 100% rename from api/dify_graph/graph_engine/command_channels/README.md rename to api/graphon/graph_engine/command_channels/README.md diff --git a/api/dify_graph/graph_engine/command_channels/__init__.py b/api/graphon/graph_engine/command_channels/__init__.py similarity index 100% rename from api/dify_graph/graph_engine/command_channels/__init__.py rename to api/graphon/graph_engine/command_channels/__init__.py diff --git a/api/dify_graph/graph_engine/command_channels/in_memory_channel.py b/api/graphon/graph_engine/command_channels/in_memory_channel.py similarity index 100% rename from api/dify_graph/graph_engine/command_channels/in_memory_channel.py rename to api/graphon/graph_engine/command_channels/in_memory_channel.py diff --git a/api/dify_graph/graph_engine/command_channels/redis_channel.py b/api/graphon/graph_engine/command_channels/redis_channel.py similarity index 100% rename from api/dify_graph/graph_engine/command_channels/redis_channel.py rename to api/graphon/graph_engine/command_channels/redis_channel.py diff --git a/api/dify_graph/graph_engine/command_processing/__init__.py b/api/graphon/graph_engine/command_processing/__init__.py similarity index 100% rename from api/dify_graph/graph_engine/command_processing/__init__.py rename to api/graphon/graph_engine/command_processing/__init__.py diff --git a/api/dify_graph/graph_engine/command_processing/command_handlers.py b/api/graphon/graph_engine/command_processing/command_handlers.py similarity index 95% rename from api/dify_graph/graph_engine/command_processing/command_handlers.py rename to api/graphon/graph_engine/command_processing/command_handlers.py index eefd0c366b..ad92fd1abb 100644 --- a/api/dify_graph/graph_engine/command_processing/command_handlers.py +++ b/api/graphon/graph_engine/command_processing/command_handlers.py @@ -3,8 +3,8 @@ from typing import final from typing_extensions import override -from dify_graph.entities.pause_reason import SchedulingPause -from dify_graph.runtime import VariablePool +from graphon.entities.pause_reason import SchedulingPause +from graphon.runtime import VariablePool from ..domain.graph_execution import GraphExecution from ..entities.commands import AbortCommand, GraphEngineCommand, PauseCommand, UpdateVariablesCommand diff --git a/api/dify_graph/graph_engine/command_processing/command_processor.py b/api/graphon/graph_engine/command_processing/command_processor.py similarity index 100% rename from api/dify_graph/graph_engine/command_processing/command_processor.py rename to api/graphon/graph_engine/command_processing/command_processor.py diff --git a/api/dify_graph/graph_engine/config.py b/api/graphon/graph_engine/config.py similarity index 100% rename from api/dify_graph/graph_engine/config.py rename to api/graphon/graph_engine/config.py diff --git a/api/dify_graph/graph_engine/domain/__init__.py b/api/graphon/graph_engine/domain/__init__.py similarity index 100% rename from api/dify_graph/graph_engine/domain/__init__.py rename to api/graphon/graph_engine/domain/__init__.py diff --git a/api/dify_graph/graph_engine/domain/graph_execution.py b/api/graphon/graph_engine/domain/graph_execution.py similarity index 97% rename from api/dify_graph/graph_engine/domain/graph_execution.py rename to api/graphon/graph_engine/domain/graph_execution.py index 0ee4a9f9a7..9c0c7d1624 100644 --- a/api/dify_graph/graph_engine/domain/graph_execution.py +++ b/api/graphon/graph_engine/domain/graph_execution.py @@ -8,9 +8,9 @@ from typing import Literal from pydantic import BaseModel, Field -from dify_graph.entities.pause_reason import PauseReason -from dify_graph.enums import NodeState -from dify_graph.runtime.graph_runtime_state import GraphExecutionProtocol +from graphon.entities.pause_reason import PauseReason +from graphon.enums import NodeState +from graphon.runtime.graph_runtime_state import GraphExecutionProtocol from .node_execution import NodeExecution diff --git a/api/dify_graph/graph_engine/domain/node_execution.py b/api/graphon/graph_engine/domain/node_execution.py similarity index 96% rename from api/dify_graph/graph_engine/domain/node_execution.py rename to api/graphon/graph_engine/domain/node_execution.py index ae8f9a5e50..dafd6ccd8a 100644 --- a/api/dify_graph/graph_engine/domain/node_execution.py +++ b/api/graphon/graph_engine/domain/node_execution.py @@ -4,7 +4,7 @@ NodeExecution entity representing a node's execution state. from dataclasses import dataclass -from dify_graph.enums import NodeState +from graphon.enums import NodeState @dataclass diff --git a/api/dify_graph/graph_engine/entities/__init__.py b/api/graphon/graph_engine/entities/__init__.py similarity index 100% rename from api/dify_graph/graph_engine/entities/__init__.py rename to api/graphon/graph_engine/entities/__init__.py diff --git a/api/dify_graph/graph_engine/entities/commands.py b/api/graphon/graph_engine/entities/commands.py similarity index 97% rename from api/dify_graph/graph_engine/entities/commands.py rename to api/graphon/graph_engine/entities/commands.py index c56845cfc4..25ebc804b6 100644 --- a/api/dify_graph/graph_engine/entities/commands.py +++ b/api/graphon/graph_engine/entities/commands.py @@ -11,7 +11,7 @@ from typing import Any from pydantic import BaseModel, Field -from dify_graph.variables.variables import Variable +from graphon.variables.variables import Variable class CommandType(StrEnum): diff --git a/api/dify_graph/graph_engine/error_handler.py b/api/graphon/graph_engine/error_handler.py similarity index 97% rename from api/dify_graph/graph_engine/error_handler.py rename to api/graphon/graph_engine/error_handler.py index e206f21592..43ce8bb502 100644 --- a/api/dify_graph/graph_engine/error_handler.py +++ b/api/graphon/graph_engine/error_handler.py @@ -6,21 +6,21 @@ import logging import time from typing import TYPE_CHECKING, final -from dify_graph.enums import ( +from graphon.enums import ( ErrorStrategy as ErrorStrategyEnum, ) -from dify_graph.enums import ( +from graphon.enums import ( WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus, ) -from dify_graph.graph import Graph -from dify_graph.graph_events import ( +from graphon.graph import Graph +from graphon.graph_events import ( GraphNodeEventBase, NodeRunExceptionEvent, NodeRunFailedEvent, NodeRunRetryEvent, ) -from dify_graph.node_events import NodeRunResult +from graphon.node_events import NodeRunResult if TYPE_CHECKING: from .domain import GraphExecution diff --git a/api/dify_graph/graph_engine/event_management/__init__.py b/api/graphon/graph_engine/event_management/__init__.py similarity index 100% rename from api/dify_graph/graph_engine/event_management/__init__.py rename to api/graphon/graph_engine/event_management/__init__.py diff --git a/api/dify_graph/graph_engine/event_management/event_handlers.py b/api/graphon/graph_engine/event_management/event_handlers.py similarity index 93% rename from api/dify_graph/graph_engine/event_management/event_handlers.py rename to api/graphon/graph_engine/event_management/event_handlers.py index 7f5ad40e0e..184148280d 100644 --- a/api/dify_graph/graph_engine/event_management/event_handlers.py +++ b/api/graphon/graph_engine/event_management/event_handlers.py @@ -7,9 +7,9 @@ from collections.abc import Mapping from functools import singledispatchmethod from typing import TYPE_CHECKING, final -from dify_graph.enums import ErrorStrategy, NodeExecutionType, NodeState -from dify_graph.graph import Graph -from dify_graph.graph_events import ( +from graphon.enums import ErrorStrategy, NodeExecutionType, NodeState +from graphon.graph import Graph +from graphon.graph_events import ( GraphNodeEventBase, NodeRunAgentLogEvent, NodeRunExceptionEvent, @@ -28,9 +28,10 @@ from dify_graph.graph_events import ( NodeRunStartedEvent, NodeRunStreamChunkEvent, NodeRunSucceededEvent, + NodeRunVariableUpdatedEvent, ) -from dify_graph.model_runtime.entities.llm_entities import LLMUsage -from dify_graph.runtime import GraphRuntimeState +from graphon.model_runtime.entities.llm_entities import LLMUsage +from graphon.runtime import GraphRuntimeState from ..domain.graph_execution import GraphExecution from ..response_coordinator import ResponseStreamCoordinator @@ -93,6 +94,10 @@ class EventHandler: Args: event: The event to handle """ + if isinstance(event, NodeRunVariableUpdatedEvent): + self._dispatch(event) + return + # Events in loops or iterations are always collected if event.in_loop_id or event.in_iteration_id: self._event_collector.collect(event) @@ -153,6 +158,17 @@ class EventHandler: for stream_event in streaming_events: self._event_collector.collect(stream_event) + @_dispatch.register + def _(self, event: NodeRunVariableUpdatedEvent) -> None: + """ + Apply a node-requested variable mutation before downstream observers run. + + The event is collected like other node events so parent/container engines can + forward the updated payload to outer layers, including persistence listeners. + """ + self._graph_runtime_state.variable_pool.add(event.variable.selector, event.variable) + self._event_collector.collect(event) + @_dispatch.register def _(self, event: NodeRunSucceededEvent) -> None: """ diff --git a/api/dify_graph/graph_engine/event_management/event_manager.py b/api/graphon/graph_engine/event_management/event_manager.py similarity index 99% rename from api/dify_graph/graph_engine/event_management/event_manager.py rename to api/graphon/graph_engine/event_management/event_manager.py index 616f621c3e..5b2fb365e9 100644 --- a/api/dify_graph/graph_engine/event_management/event_manager.py +++ b/api/graphon/graph_engine/event_management/event_manager.py @@ -9,7 +9,7 @@ from collections.abc import Generator from contextlib import contextmanager from typing import final -from dify_graph.graph_events import GraphEngineEvent +from graphon.graph_events import GraphEngineEvent from ..layers.base import GraphEngineLayer diff --git a/api/dify_graph/graph_engine/graph_engine.py b/api/graphon/graph_engine/graph_engine.py similarity index 91% rename from api/dify_graph/graph_engine/graph_engine.py rename to api/graphon/graph_engine/graph_engine.py index ea98a46b06..32e0e60502 100644 --- a/api/dify_graph/graph_engine/graph_engine.py +++ b/api/graphon/graph_engine/graph_engine.py @@ -9,14 +9,13 @@ from __future__ import annotations import logging import queue -from collections.abc import Generator, Mapping +from collections.abc import Generator from typing import TYPE_CHECKING, cast, final -from dify_graph.context import capture_current_context -from dify_graph.entities.workflow_start_reason import WorkflowStartReason -from dify_graph.enums import NodeExecutionType -from dify_graph.graph import Graph -from dify_graph.graph_events import ( +from graphon.entities.workflow_start_reason import WorkflowStartReason +from graphon.enums import NodeExecutionType +from graphon.graph import Graph +from graphon.graph_events import ( GraphEngineEvent, GraphNodeEventBase, GraphRunAbortedEvent, @@ -26,11 +25,11 @@ from dify_graph.graph_events import ( GraphRunStartedEvent, GraphRunSucceededEvent, ) -from dify_graph.runtime import GraphRuntimeState, ReadOnlyGraphRuntimeStateWrapper -from dify_graph.runtime.graph_runtime_state import ChildGraphEngineBuilderProtocol +from graphon.runtime import GraphRuntimeState, ReadOnlyGraphRuntimeStateWrapper, VariablePool +from graphon.runtime.graph_runtime_state import ChildGraphEngineBuilderProtocol if TYPE_CHECKING: # pragma: no cover - used only for static analysis - from dify_graph.runtime.graph_runtime_state import GraphProtocol + from graphon.runtime.graph_runtime_state import GraphProtocol from .command_processing import ( AbortCommandHandler, @@ -50,9 +49,9 @@ from .protocols.command_channel import CommandChannel from .worker_management import WorkerPool if TYPE_CHECKING: - from dify_graph.entities import GraphInitParams - from dify_graph.graph_engine.domain.graph_execution import GraphExecution - from dify_graph.graph_engine.response_coordinator import ResponseStreamCoordinator + from graphon.entities import GraphInitParams + from graphon.graph_engine.domain.graph_execution import GraphExecution + from graphon.graph_engine.response_coordinator import ResponseStreamCoordinator logger = logging.getLogger(__name__) @@ -86,6 +85,7 @@ class GraphEngine: self._graph_runtime_state.configure(graph=cast("GraphProtocol", graph)) self._command_channel = command_channel self._config = config + self._layers: list[GraphEngineLayer] = [] self._child_engine_builder = child_engine_builder if child_engine_builder is not None: self._graph_runtime_state.bind_child_engine_builder(child_engine_builder) @@ -149,21 +149,14 @@ class GraphEngine: update_variables_handler = UpdateVariablesCommandHandler(self._graph_runtime_state.variable_pool) self._command_processor.register_handler(UpdateVariablesCommand, update_variables_handler) - # === Extensibility === - # Layers allow plugins to extend engine functionality - self._layers: list[GraphEngineLayer] = [] - # === Worker Pool Setup === - # Capture execution context for worker threads - execution_context = capture_current_context() - # Create worker pool for parallel node execution self._worker_pool = WorkerPool( ready_queue=self._ready_queue, event_queue=self._event_queue, graph=self._graph, layers=self._layers, - execution_context=execution_context, + execution_context=self._graph_runtime_state.execution_context, config=self._config, ) @@ -220,23 +213,23 @@ class GraphEngine: self._bind_layer_context(layer) return self + def request_abort(self, reason: str | None = None) -> None: + """Queue an abort command for this engine.""" + self._command_channel.send_command(AbortCommand(reason=reason or "User requested abort")) + def create_child_engine( self, *, workflow_id: str, graph_init_params: GraphInitParams, - graph_runtime_state: GraphRuntimeState, - graph_config: dict[str, object] | Mapping[str, object], root_node_id: str, - layers: list[GraphEngineLayer] | tuple[GraphEngineLayer, ...] = (), + variable_pool: VariablePool | None = None, ) -> GraphEngine: return self._graph_runtime_state.create_child_engine( workflow_id=workflow_id, graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - graph_config=graph_config, root_node_id=root_node_id, - layers=layers, + variable_pool=variable_pool, ) def run(self) -> Generator[GraphEngineEvent, None, None]: diff --git a/api/dify_graph/graph_engine/graph_state_manager.py b/api/graphon/graph_engine/graph_state_manager.py similarity index 99% rename from api/dify_graph/graph_engine/graph_state_manager.py rename to api/graphon/graph_engine/graph_state_manager.py index 922a968435..ade8e403a8 100644 --- a/api/dify_graph/graph_engine/graph_state_manager.py +++ b/api/graphon/graph_engine/graph_state_manager.py @@ -6,8 +6,8 @@ import threading from collections.abc import Sequence from typing import TypedDict, final -from dify_graph.enums import NodeState -from dify_graph.graph import Edge, Graph +from graphon.enums import NodeState +from graphon.graph import Edge, Graph from .ready_queue import ReadyQueue diff --git a/api/dify_graph/graph_engine/graph_traversal/__init__.py b/api/graphon/graph_engine/graph_traversal/__init__.py similarity index 100% rename from api/dify_graph/graph_engine/graph_traversal/__init__.py rename to api/graphon/graph_engine/graph_traversal/__init__.py diff --git a/api/dify_graph/graph_engine/graph_traversal/edge_processor.py b/api/graphon/graph_engine/graph_traversal/edge_processor.py similarity index 97% rename from api/dify_graph/graph_engine/graph_traversal/edge_processor.py rename to api/graphon/graph_engine/graph_traversal/edge_processor.py index c4625a8ff7..e51eee8a69 100644 --- a/api/dify_graph/graph_engine/graph_traversal/edge_processor.py +++ b/api/graphon/graph_engine/graph_traversal/edge_processor.py @@ -5,9 +5,9 @@ Edge processing logic for graph traversal. from collections.abc import Sequence from typing import TYPE_CHECKING, final -from dify_graph.enums import NodeExecutionType -from dify_graph.graph import Edge, Graph -from dify_graph.graph_events import NodeRunStreamChunkEvent +from graphon.enums import NodeExecutionType +from graphon.graph import Edge, Graph +from graphon.graph_events import NodeRunStreamChunkEvent from ..graph_state_manager import GraphStateManager from ..response_coordinator import ResponseStreamCoordinator diff --git a/api/dify_graph/graph_engine/graph_traversal/skip_propagator.py b/api/graphon/graph_engine/graph_traversal/skip_propagator.py similarity index 98% rename from api/dify_graph/graph_engine/graph_traversal/skip_propagator.py rename to api/graphon/graph_engine/graph_traversal/skip_propagator.py index 76445bccd2..bdb83b38ad 100644 --- a/api/dify_graph/graph_engine/graph_traversal/skip_propagator.py +++ b/api/graphon/graph_engine/graph_traversal/skip_propagator.py @@ -5,7 +5,7 @@ Skip state propagation through the graph. from collections.abc import Sequence from typing import final -from dify_graph.graph import Edge, Graph +from graphon.graph import Edge, Graph from ..graph_state_manager import GraphStateManager diff --git a/api/dify_graph/graph_engine/layers/README.md b/api/graphon/graph_engine/layers/README.md similarity index 100% rename from api/dify_graph/graph_engine/layers/README.md rename to api/graphon/graph_engine/layers/README.md diff --git a/api/dify_graph/graph_engine/layers/__init__.py b/api/graphon/graph_engine/layers/__init__.py similarity index 100% rename from api/dify_graph/graph_engine/layers/__init__.py rename to api/graphon/graph_engine/layers/__init__.py diff --git a/api/dify_graph/graph_engine/layers/base.py b/api/graphon/graph_engine/layers/base.py similarity index 94% rename from api/dify_graph/graph_engine/layers/base.py rename to api/graphon/graph_engine/layers/base.py index 890336c1ca..605615d347 100644 --- a/api/dify_graph/graph_engine/layers/base.py +++ b/api/graphon/graph_engine/layers/base.py @@ -7,10 +7,10 @@ intercept and respond to GraphEngine events. from abc import ABC, abstractmethod -from dify_graph.graph_engine.protocols.command_channel import CommandChannel -from dify_graph.graph_events import GraphEngineEvent, GraphNodeEventBase -from dify_graph.nodes.base.node import Node -from dify_graph.runtime import ReadOnlyGraphRuntimeState +from graphon.graph_engine.protocols.command_channel import CommandChannel +from graphon.graph_events import GraphEngineEvent, GraphNodeEventBase +from graphon.nodes.base.node import Node +from graphon.runtime import ReadOnlyGraphRuntimeState class GraphEngineLayerNotInitializedError(Exception): diff --git a/api/dify_graph/graph_engine/layers/debug_logging.py b/api/graphon/graph_engine/layers/debug_logging.py similarity index 99% rename from api/dify_graph/graph_engine/layers/debug_logging.py rename to api/graphon/graph_engine/layers/debug_logging.py index 1af2e2db9e..e6585fb3b9 100644 --- a/api/dify_graph/graph_engine/layers/debug_logging.py +++ b/api/graphon/graph_engine/layers/debug_logging.py @@ -11,7 +11,7 @@ from typing import Any, final from typing_extensions import override -from dify_graph.graph_events import ( +from graphon.graph_events import ( GraphEngineEvent, GraphRunAbortedEvent, GraphRunFailedEvent, diff --git a/api/dify_graph/graph_engine/layers/execution_limits.py b/api/graphon/graph_engine/layers/execution_limits.py similarity index 94% rename from api/dify_graph/graph_engine/layers/execution_limits.py rename to api/graphon/graph_engine/layers/execution_limits.py index 48ba5608d9..2742b3acd3 100644 --- a/api/dify_graph/graph_engine/layers/execution_limits.py +++ b/api/graphon/graph_engine/layers/execution_limits.py @@ -15,13 +15,13 @@ from typing import final from typing_extensions import override -from dify_graph.graph_engine.entities.commands import AbortCommand, CommandType -from dify_graph.graph_engine.layers import GraphEngineLayer -from dify_graph.graph_events import ( +from graphon.graph_engine.entities.commands import AbortCommand, CommandType +from graphon.graph_engine.layers import GraphEngineLayer +from graphon.graph_events import ( GraphEngineEvent, NodeRunStartedEvent, ) -from dify_graph.graph_events.node import NodeRunFailedEvent, NodeRunSucceededEvent +from graphon.graph_events.node import NodeRunFailedEvent, NodeRunSucceededEvent class LimitType(StrEnum): diff --git a/api/dify_graph/graph_engine/manager.py b/api/graphon/graph_engine/manager.py similarity index 94% rename from api/dify_graph/graph_engine/manager.py rename to api/graphon/graph_engine/manager.py index 955c149069..c728ff6986 100644 --- a/api/dify_graph/graph_engine/manager.py +++ b/api/graphon/graph_engine/manager.py @@ -10,8 +10,8 @@ import logging from collections.abc import Sequence from typing import final -from dify_graph.graph_engine.command_channels.redis_channel import RedisChannel, RedisClientProtocol -from dify_graph.graph_engine.entities.commands import ( +from graphon.graph_engine.command_channels.redis_channel import RedisChannel, RedisClientProtocol +from graphon.graph_engine.entities.commands import ( AbortCommand, GraphEngineCommand, PauseCommand, diff --git a/api/dify_graph/graph_engine/orchestration/__init__.py b/api/graphon/graph_engine/orchestration/__init__.py similarity index 100% rename from api/dify_graph/graph_engine/orchestration/__init__.py rename to api/graphon/graph_engine/orchestration/__init__.py diff --git a/api/dify_graph/graph_engine/orchestration/dispatcher.py b/api/graphon/graph_engine/orchestration/dispatcher.py similarity index 99% rename from api/dify_graph/graph_engine/orchestration/dispatcher.py rename to api/graphon/graph_engine/orchestration/dispatcher.py index f8aaf20b2f..f75bbee08e 100644 --- a/api/dify_graph/graph_engine/orchestration/dispatcher.py +++ b/api/graphon/graph_engine/orchestration/dispatcher.py @@ -8,7 +8,7 @@ import threading import time from typing import TYPE_CHECKING, final -from dify_graph.graph_events import ( +from graphon.graph_events import ( GraphNodeEventBase, NodeRunExceptionEvent, NodeRunFailedEvent, diff --git a/api/dify_graph/graph_engine/orchestration/execution_coordinator.py b/api/graphon/graph_engine/orchestration/execution_coordinator.py similarity index 100% rename from api/dify_graph/graph_engine/orchestration/execution_coordinator.py rename to api/graphon/graph_engine/orchestration/execution_coordinator.py diff --git a/api/dify_graph/graph_engine/protocols/command_channel.py b/api/graphon/graph_engine/protocols/command_channel.py similarity index 100% rename from api/dify_graph/graph_engine/protocols/command_channel.py rename to api/graphon/graph_engine/protocols/command_channel.py diff --git a/api/dify_graph/graph_engine/ready_queue/__init__.py b/api/graphon/graph_engine/ready_queue/__init__.py similarity index 100% rename from api/dify_graph/graph_engine/ready_queue/__init__.py rename to api/graphon/graph_engine/ready_queue/__init__.py diff --git a/api/dify_graph/graph_engine/ready_queue/factory.py b/api/graphon/graph_engine/ready_queue/factory.py similarity index 100% rename from api/dify_graph/graph_engine/ready_queue/factory.py rename to api/graphon/graph_engine/ready_queue/factory.py diff --git a/api/dify_graph/graph_engine/ready_queue/in_memory.py b/api/graphon/graph_engine/ready_queue/in_memory.py similarity index 100% rename from api/dify_graph/graph_engine/ready_queue/in_memory.py rename to api/graphon/graph_engine/ready_queue/in_memory.py diff --git a/api/dify_graph/graph_engine/ready_queue/protocol.py b/api/graphon/graph_engine/ready_queue/protocol.py similarity index 100% rename from api/dify_graph/graph_engine/ready_queue/protocol.py rename to api/graphon/graph_engine/ready_queue/protocol.py diff --git a/api/dify_graph/graph_engine/response_coordinator/__init__.py b/api/graphon/graph_engine/response_coordinator/__init__.py similarity index 100% rename from api/dify_graph/graph_engine/response_coordinator/__init__.py rename to api/graphon/graph_engine/response_coordinator/__init__.py diff --git a/api/dify_graph/graph_engine/response_coordinator/coordinator.py b/api/graphon/graph_engine/response_coordinator/coordinator.py similarity index 98% rename from api/dify_graph/graph_engine/response_coordinator/coordinator.py rename to api/graphon/graph_engine/response_coordinator/coordinator.py index 941a8a496b..a6562f0223 100644 --- a/api/dify_graph/graph_engine/response_coordinator/coordinator.py +++ b/api/graphon/graph_engine/response_coordinator/coordinator.py @@ -14,11 +14,11 @@ from uuid import uuid4 from pydantic import BaseModel, Field -from dify_graph.enums import NodeExecutionType, NodeState -from dify_graph.graph_events import NodeRunStreamChunkEvent, NodeRunSucceededEvent -from dify_graph.nodes.base.template import TextSegment, VariableSegment -from dify_graph.runtime import VariablePool -from dify_graph.runtime.graph_runtime_state import GraphProtocol +from graphon.enums import NodeExecutionType, NodeState +from graphon.graph_events import NodeRunStreamChunkEvent, NodeRunSucceededEvent +from graphon.nodes.base.template import TextSegment, VariableSegment +from graphon.runtime import VariablePool +from graphon.runtime.graph_runtime_state import GraphProtocol from .path import Path from .session import ResponseSession diff --git a/api/dify_graph/graph_engine/response_coordinator/path.py b/api/graphon/graph_engine/response_coordinator/path.py similarity index 100% rename from api/dify_graph/graph_engine/response_coordinator/path.py rename to api/graphon/graph_engine/response_coordinator/path.py diff --git a/api/dify_graph/graph_engine/response_coordinator/session.py b/api/graphon/graph_engine/response_coordinator/session.py similarity index 94% rename from api/dify_graph/graph_engine/response_coordinator/session.py rename to api/graphon/graph_engine/response_coordinator/session.py index 11a9f5dac5..cb877f1504 100644 --- a/api/dify_graph/graph_engine/response_coordinator/session.py +++ b/api/graphon/graph_engine/response_coordinator/session.py @@ -10,8 +10,8 @@ from __future__ import annotations from dataclasses import dataclass from typing import Protocol, cast -from dify_graph.nodes.base.template import Template -from dify_graph.runtime.graph_runtime_state import NodeProtocol +from graphon.nodes.base.template import Template +from graphon.runtime.graph_runtime_state import NodeProtocol class _ResponseSessionNodeProtocol(NodeProtocol, Protocol): diff --git a/api/dify_graph/graph_engine/worker.py b/api/graphon/graph_engine/worker.py similarity index 92% rename from api/dify_graph/graph_engine/worker.py rename to api/graphon/graph_engine/worker.py index 988c20d72a..a0844ee48e 100644 --- a/api/dify_graph/graph_engine/worker.py +++ b/api/graphon/graph_engine/worker.py @@ -9,19 +9,18 @@ import queue import threading import time from collections.abc import Sequence -from datetime import datetime +from contextlib import AbstractContextManager +from datetime import UTC, datetime from typing import TYPE_CHECKING, final from typing_extensions import override -from dify_graph.context import IExecutionContext -from dify_graph.enums import WorkflowNodeExecutionStatus -from dify_graph.graph import Graph -from dify_graph.graph_engine.layers.base import GraphEngineLayer -from dify_graph.graph_events import GraphNodeEventBase, NodeRunFailedEvent, NodeRunStartedEvent, is_node_result_event -from dify_graph.node_events import NodeRunResult -from dify_graph.nodes.base.node import Node -from libs.datetime_utils import naive_utc_now +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.graph import Graph +from graphon.graph_engine.layers.base import GraphEngineLayer +from graphon.graph_events import GraphNodeEventBase, NodeRunFailedEvent, NodeRunStartedEvent, is_node_result_event +from graphon.node_events import NodeRunResult +from graphon.nodes.base.node import Node from .ready_queue import ReadyQueue @@ -46,7 +45,7 @@ class Worker(threading.Thread): graph: Graph, layers: Sequence[GraphEngineLayer], worker_id: int = 0, - execution_context: IExecutionContext | None = None, + execution_context: AbstractContextManager[object] | None = None, ) -> None: """ Initialize worker thread. @@ -187,7 +186,7 @@ class Worker(threading.Thread): self, node: Node, error: Exception, *, started_at: datetime | None = None ) -> NodeRunFailedEvent: """Build a failed event when worker-level execution aborts before a node emits its own result event.""" - failure_time = naive_utc_now() + failure_time = datetime.now(UTC).replace(tzinfo=None) error_message = str(error) return NodeRunFailedEvent( id=node.execution_id, diff --git a/api/dify_graph/graph_engine/worker_management/__init__.py b/api/graphon/graph_engine/worker_management/__init__.py similarity index 100% rename from api/dify_graph/graph_engine/worker_management/__init__.py rename to api/graphon/graph_engine/worker_management/__init__.py diff --git a/api/dify_graph/graph_engine/worker_management/worker_pool.py b/api/graphon/graph_engine/worker_management/worker_pool.py similarity index 97% rename from api/dify_graph/graph_engine/worker_management/worker_pool.py rename to api/graphon/graph_engine/worker_management/worker_pool.py index cc93087783..85cdf1ca21 100644 --- a/api/dify_graph/graph_engine/worker_management/worker_pool.py +++ b/api/graphon/graph_engine/worker_management/worker_pool.py @@ -8,11 +8,11 @@ DynamicScaler, and WorkerFactory into a single class. import logging import queue import threading +from contextlib import AbstractContextManager from typing import final -from dify_graph.context import IExecutionContext -from dify_graph.graph import Graph -from dify_graph.graph_events import GraphNodeEventBase +from graphon.graph import Graph +from graphon.graph_events import GraphNodeEventBase from ..config import GraphEngineConfig from ..layers.base import GraphEngineLayer @@ -38,7 +38,7 @@ class WorkerPool: graph: Graph, layers: list[GraphEngineLayer], config: GraphEngineConfig, - execution_context: IExecutionContext | None = None, + execution_context: AbstractContextManager[object] | None = None, ) -> None: """ Initialize the simple worker pool. diff --git a/api/dify_graph/graph_events/__init__.py b/api/graphon/graph_events/__init__.py similarity index 96% rename from api/dify_graph/graph_events/__init__.py rename to api/graphon/graph_events/__init__.py index 56ea642092..7cec587a05 100644 --- a/api/dify_graph/graph_events/__init__.py +++ b/api/graphon/graph_events/__init__.py @@ -46,6 +46,7 @@ from .node import ( NodeRunStartedEvent, NodeRunStreamChunkEvent, NodeRunSucceededEvent, + NodeRunVariableUpdatedEvent, is_node_result_event, ) @@ -78,5 +79,6 @@ __all__ = [ "NodeRunStartedEvent", "NodeRunStreamChunkEvent", "NodeRunSucceededEvent", + "NodeRunVariableUpdatedEvent", "is_node_result_event", ] diff --git a/api/dify_graph/graph_events/agent.py b/api/graphon/graph_events/agent.py similarity index 100% rename from api/dify_graph/graph_events/agent.py rename to api/graphon/graph_events/agent.py diff --git a/api/dify_graph/graph_events/base.py b/api/graphon/graph_events/base.py similarity index 88% rename from api/dify_graph/graph_events/base.py rename to api/graphon/graph_events/base.py index 4560cf5085..4ea9787b9a 100644 --- a/api/dify_graph/graph_events/base.py +++ b/api/graphon/graph_events/base.py @@ -1,7 +1,7 @@ from pydantic import BaseModel, Field -from dify_graph.enums import NodeType -from dify_graph.node_events import NodeRunResult +from graphon.enums import NodeType +from graphon.node_events import NodeRunResult class GraphEngineEvent(BaseModel): diff --git a/api/dify_graph/graph_events/graph.py b/api/graphon/graph_events/graph.py similarity index 90% rename from api/dify_graph/graph_events/graph.py rename to api/graphon/graph_events/graph.py index f4aaba64d6..3782cb49bc 100644 --- a/api/dify_graph/graph_events/graph.py +++ b/api/graphon/graph_events/graph.py @@ -1,8 +1,8 @@ from pydantic import Field -from dify_graph.entities.pause_reason import PauseReason -from dify_graph.entities.workflow_start_reason import WorkflowStartReason -from dify_graph.graph_events import BaseGraphEvent +from graphon.entities.pause_reason import PauseReason +from graphon.entities.workflow_start_reason import WorkflowStartReason +from graphon.graph_events import BaseGraphEvent class GraphRunStartedEvent(BaseGraphEvent): diff --git a/api/dify_graph/graph_events/human_input.py b/api/graphon/graph_events/human_input.py similarity index 100% rename from api/dify_graph/graph_events/human_input.py rename to api/graphon/graph_events/human_input.py diff --git a/api/dify_graph/graph_events/iteration.py b/api/graphon/graph_events/iteration.py similarity index 100% rename from api/dify_graph/graph_events/iteration.py rename to api/graphon/graph_events/iteration.py diff --git a/api/dify_graph/graph_events/loop.py b/api/graphon/graph_events/loop.py similarity index 100% rename from api/dify_graph/graph_events/loop.py rename to api/graphon/graph_events/loop.py diff --git a/api/dify_graph/graph_events/node.py b/api/graphon/graph_events/node.py similarity index 86% rename from api/dify_graph/graph_events/node.py rename to api/graphon/graph_events/node.py index df19d6c03b..471ae08ee7 100644 --- a/api/dify_graph/graph_events/node.py +++ b/api/graphon/graph_events/node.py @@ -1,10 +1,11 @@ -from collections.abc import Sequence +from collections.abc import Mapping, Sequence from datetime import datetime +from typing import Any from pydantic import Field -from core.rag.entities.citation_metadata import RetrievalSourceMetadata -from dify_graph.entities.pause_reason import PauseReason +from graphon.entities.pause_reason import PauseReason +from graphon.variables.variables import Variable from .base import GraphNodeEventBase @@ -30,7 +31,7 @@ class NodeRunStreamChunkEvent(GraphNodeEventBase): class NodeRunRetrieverResourceEvent(GraphNodeEventBase): - retriever_resources: Sequence[RetrievalSourceMetadata] = Field(..., description="retriever resources") + retriever_resources: Sequence[Mapping[str, Any]] = Field(..., description="retriever resources") context: str = Field(..., description="context") @@ -39,6 +40,12 @@ class NodeRunSucceededEvent(GraphNodeEventBase): finished_at: datetime | None = Field(default=None, description="node finish time") +class NodeRunVariableUpdatedEvent(GraphNodeEventBase): + """Request that the engine apply a variable update before downstream observers continue.""" + + variable: Variable = Field(..., description="Updated variable payload to apply.") + + class NodeRunFailedEvent(GraphNodeEventBase): error: str = Field(..., description="error") start_at: datetime = Field(..., description="node start time") diff --git a/api/dify_graph/model_runtime/README.md b/api/graphon/model_runtime/README.md similarity index 100% rename from api/dify_graph/model_runtime/README.md rename to api/graphon/model_runtime/README.md diff --git a/api/dify_graph/model_runtime/README_CN.md b/api/graphon/model_runtime/README_CN.md similarity index 100% rename from api/dify_graph/model_runtime/README_CN.md rename to api/graphon/model_runtime/README_CN.md diff --git a/api/dify_graph/model_runtime/__init__.py b/api/graphon/model_runtime/__init__.py similarity index 100% rename from api/dify_graph/model_runtime/__init__.py rename to api/graphon/model_runtime/__init__.py diff --git a/api/dify_graph/model_runtime/callbacks/__init__.py b/api/graphon/model_runtime/callbacks/__init__.py similarity index 100% rename from api/dify_graph/model_runtime/callbacks/__init__.py rename to api/graphon/model_runtime/callbacks/__init__.py diff --git a/api/dify_graph/model_runtime/callbacks/base_callback.py b/api/graphon/model_runtime/callbacks/base_callback.py similarity index 78% rename from api/dify_graph/model_runtime/callbacks/base_callback.py rename to api/graphon/model_runtime/callbacks/base_callback.py index 20faf3d6cd..cd85cf6301 100644 --- a/api/dify_graph/model_runtime/callbacks/base_callback.py +++ b/api/graphon/model_runtime/callbacks/base_callback.py @@ -1,9 +1,9 @@ from abc import ABC, abstractmethod -from collections.abc import Sequence +from collections.abc import Mapping, Sequence -from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk -from dify_graph.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool -from dify_graph.model_runtime.model_providers.__base.ai_model import AIModel +from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk +from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool +from graphon.model_runtime.model_providers.__base.ai_model import AIModel _TEXT_COLOR_MAPPING = { "blue": "36;1", @@ -34,6 +34,7 @@ class Callback(ABC): stop: Sequence[str] | None = None, stream: bool = True, user: str | None = None, + invocation_context: Mapping[str, object] | None = None, ): """ Before invoke callback @@ -46,7 +47,8 @@ class Callback(ABC): :param tools: tools for tool calling :param stop: stop words :param stream: is stream response - :param user: unique user id + :param user: optional end-user identifier for the invocation + :param invocation_context: opaque request metadata for the current invocation """ raise NotImplementedError() @@ -63,6 +65,7 @@ class Callback(ABC): stop: Sequence[str] | None = None, stream: bool = True, user: str | None = None, + invocation_context: Mapping[str, object] | None = None, ): """ On new chunk callback @@ -76,7 +79,8 @@ class Callback(ABC): :param tools: tools for tool calling :param stop: stop words :param stream: is stream response - :param user: unique user id + :param user: optional end-user identifier for the invocation + :param invocation_context: opaque request metadata for the current invocation """ raise NotImplementedError() @@ -93,6 +97,7 @@ class Callback(ABC): stop: Sequence[str] | None = None, stream: bool = True, user: str | None = None, + invocation_context: Mapping[str, object] | None = None, ): """ After invoke callback @@ -106,7 +111,8 @@ class Callback(ABC): :param tools: tools for tool calling :param stop: stop words :param stream: is stream response - :param user: unique user id + :param user: optional end-user identifier for the invocation + :param invocation_context: opaque request metadata for the current invocation """ raise NotImplementedError() @@ -123,6 +129,7 @@ class Callback(ABC): stop: Sequence[str] | None = None, stream: bool = True, user: str | None = None, + invocation_context: Mapping[str, object] | None = None, ): """ Invoke error callback @@ -136,7 +143,8 @@ class Callback(ABC): :param tools: tools for tool calling :param stop: stop words :param stream: is stream response - :param user: unique user id + :param user: optional end-user identifier for the invocation + :param invocation_context: opaque request metadata for the current invocation """ raise NotImplementedError() diff --git a/api/dify_graph/model_runtime/callbacks/logging_callback.py b/api/graphon/model_runtime/callbacks/logging_callback.py similarity index 81% rename from api/dify_graph/model_runtime/callbacks/logging_callback.py rename to api/graphon/model_runtime/callbacks/logging_callback.py index 49b9ab27eb..f96eb446fc 100644 --- a/api/dify_graph/model_runtime/callbacks/logging_callback.py +++ b/api/graphon/model_runtime/callbacks/logging_callback.py @@ -1,13 +1,13 @@ import json import logging import sys -from collections.abc import Sequence +from collections.abc import Mapping, Sequence from typing import cast -from dify_graph.model_runtime.callbacks.base_callback import Callback -from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk -from dify_graph.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool -from dify_graph.model_runtime.model_providers.__base.ai_model import AIModel +from graphon.model_runtime.callbacks.base_callback import Callback +from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk +from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool +from graphon.model_runtime.model_providers.__base.ai_model import AIModel logger = logging.getLogger(__name__) @@ -24,6 +24,7 @@ class LoggingCallback(Callback): stop: Sequence[str] | None = None, stream: bool = True, user: str | None = None, + invocation_context: Mapping[str, object] | None = None, ): """ Before invoke callback @@ -36,7 +37,8 @@ class LoggingCallback(Callback): :param tools: tools for tool calling :param stop: stop words :param stream: is stream response - :param user: unique user id + :param user: optional end-user identifier for the invocation + :param invocation_context: opaque request metadata for the current invocation """ self.print_text("\n[on_llm_before_invoke]\n", color="blue") self.print_text(f"Model: {model}\n", color="blue") @@ -53,10 +55,12 @@ class LoggingCallback(Callback): self.print_text(f"\t\t{tool.name}\n", color="blue") self.print_text(f"Stream: {stream}\n", color="blue") - if user: self.print_text(f"User: {user}\n", color="blue") + if invocation_context: + self.print_text(f"Invocation context: {dict(invocation_context)}\n", color="blue") + self.print_text("Prompt messages:\n", color="blue") for prompt_message in prompt_messages: if prompt_message.name: @@ -80,6 +84,7 @@ class LoggingCallback(Callback): stop: Sequence[str] | None = None, stream: bool = True, user: str | None = None, + invocation_context: Mapping[str, object] | None = None, ): """ On new chunk callback @@ -93,8 +98,9 @@ class LoggingCallback(Callback): :param tools: tools for tool calling :param stop: stop words :param stream: is stream response - :param user: unique user id + :param invocation_context: opaque request metadata for the current invocation """ + _ = user, invocation_context sys.stdout.write(cast(str, chunk.delta.message.content)) sys.stdout.flush() @@ -110,6 +116,7 @@ class LoggingCallback(Callback): stop: Sequence[str] | None = None, stream: bool = True, user: str | None = None, + invocation_context: Mapping[str, object] | None = None, ): """ After invoke callback @@ -123,8 +130,9 @@ class LoggingCallback(Callback): :param tools: tools for tool calling :param stop: stop words :param stream: is stream response - :param user: unique user id + :param invocation_context: opaque request metadata for the current invocation """ + _ = user, invocation_context self.print_text("\n[on_llm_after_invoke]\n", color="yellow") self.print_text(f"Content: {result.message.content}\n", color="yellow") @@ -151,6 +159,7 @@ class LoggingCallback(Callback): stop: Sequence[str] | None = None, stream: bool = True, user: str | None = None, + invocation_context: Mapping[str, object] | None = None, ): """ Invoke error callback @@ -164,7 +173,8 @@ class LoggingCallback(Callback): :param tools: tools for tool calling :param stop: stop words :param stream: is stream response - :param user: unique user id + :param invocation_context: opaque request metadata for the current invocation """ + _ = user, invocation_context self.print_text("\n[on_llm_invoke_error]\n", color="red") logger.exception(ex) diff --git a/api/dify_graph/model_runtime/entities/__init__.py b/api/graphon/model_runtime/entities/__init__.py similarity index 100% rename from api/dify_graph/model_runtime/entities/__init__.py rename to api/graphon/model_runtime/entities/__init__.py diff --git a/api/dify_graph/model_runtime/entities/common_entities.py b/api/graphon/model_runtime/entities/common_entities.py similarity index 100% rename from api/dify_graph/model_runtime/entities/common_entities.py rename to api/graphon/model_runtime/entities/common_entities.py diff --git a/api/dify_graph/model_runtime/entities/defaults.py b/api/graphon/model_runtime/entities/defaults.py similarity index 98% rename from api/dify_graph/model_runtime/entities/defaults.py rename to api/graphon/model_runtime/entities/defaults.py index 53b732e5c6..bcce17c5d5 100644 --- a/api/dify_graph/model_runtime/entities/defaults.py +++ b/api/graphon/model_runtime/entities/defaults.py @@ -1,4 +1,4 @@ -from dify_graph.model_runtime.entities.model_entities import DefaultParameterName +from graphon.model_runtime.entities.model_entities import DefaultParameterName PARAMETER_RULE_TEMPLATE: dict[DefaultParameterName, dict] = { DefaultParameterName.TEMPERATURE: { diff --git a/api/dify_graph/model_runtime/entities/llm_entities.py b/api/graphon/model_runtime/entities/llm_entities.py similarity index 97% rename from api/dify_graph/model_runtime/entities/llm_entities.py rename to api/graphon/model_runtime/entities/llm_entities.py index eec682a2ae..bfc80f21c5 100644 --- a/api/dify_graph/model_runtime/entities/llm_entities.py +++ b/api/graphon/model_runtime/entities/llm_entities.py @@ -7,8 +7,8 @@ from typing import Any, TypedDict, Union from pydantic import BaseModel, Field -from dify_graph.model_runtime.entities.message_entities import AssistantPromptMessage, PromptMessage -from dify_graph.model_runtime.entities.model_entities import ModelUsage, PriceInfo +from graphon.model_runtime.entities.message_entities import AssistantPromptMessage, PromptMessage +from graphon.model_runtime.entities.model_entities import ModelUsage, PriceInfo class LLMMode(StrEnum): diff --git a/api/dify_graph/model_runtime/entities/message_entities.py b/api/graphon/model_runtime/entities/message_entities.py similarity index 100% rename from api/dify_graph/model_runtime/entities/message_entities.py rename to api/graphon/model_runtime/entities/message_entities.py diff --git a/api/dify_graph/model_runtime/entities/model_entities.py b/api/graphon/model_runtime/entities/model_entities.py similarity index 98% rename from api/dify_graph/model_runtime/entities/model_entities.py rename to api/graphon/model_runtime/entities/model_entities.py index fbcde6740a..5ec4970faf 100644 --- a/api/dify_graph/model_runtime/entities/model_entities.py +++ b/api/graphon/model_runtime/entities/model_entities.py @@ -6,7 +6,7 @@ from typing import Any from pydantic import BaseModel, ConfigDict, model_validator -from dify_graph.model_runtime.entities.common_entities import I18nObject +from graphon.model_runtime.entities.common_entities import I18nObject class ModelType(StrEnum): diff --git a/api/dify_graph/model_runtime/entities/provider_entities.py b/api/graphon/model_runtime/entities/provider_entities.py similarity index 84% rename from api/dify_graph/model_runtime/entities/provider_entities.py rename to api/graphon/model_runtime/entities/provider_entities.py index 97a99ea7ce..8e6c516fb9 100644 --- a/api/dify_graph/model_runtime/entities/provider_entities.py +++ b/api/graphon/model_runtime/entities/provider_entities.py @@ -3,8 +3,8 @@ from enum import StrEnum, auto from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator -from dify_graph.model_runtime.entities.common_entities import I18nObject -from dify_graph.model_runtime.entities.model_entities import AIModelEntity, ModelType +from graphon.model_runtime.entities.common_entities import I18nObject +from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelType class ConfigurateMethod(StrEnum): @@ -93,10 +93,14 @@ class ModelCredentialSchema(BaseModel): class SimpleProviderEntity(BaseModel): """ - Simple model class for provider. + Simplified provider schema exposed to callers. + + `provider` is the canonical runtime identifier. `provider_name` is an optional + compatibility alias for short-name lookups and is empty when no alias exists. """ provider: str + provider_name: str = "" label: I18nObject icon_small: I18nObject | None = None icon_small_dark: I18nObject | None = None @@ -115,10 +119,15 @@ class ProviderHelpEntity(BaseModel): class ProviderEntity(BaseModel): """ - Model class for provider. + Runtime-native provider schema. + + `provider` is the canonical runtime identifier. `provider_name` is a + compatibility alias for callers that still resolve providers by short name and + is empty when no alias exists. """ provider: str + provider_name: str = "" label: I18nObject description: I18nObject | None = None icon_small: I18nObject | None = None @@ -153,6 +162,7 @@ class ProviderEntity(BaseModel): """ return SimpleProviderEntity( provider=self.provider, + provider_name=self.provider_name, label=self.label, icon_small=self.icon_small, supported_model_types=self.supported_model_types, diff --git a/api/dify_graph/model_runtime/entities/rerank_entities.py b/api/graphon/model_runtime/entities/rerank_entities.py similarity index 72% rename from api/dify_graph/model_runtime/entities/rerank_entities.py rename to api/graphon/model_runtime/entities/rerank_entities.py index 99709e1bcd..8a0bb5fac2 100644 --- a/api/dify_graph/model_runtime/entities/rerank_entities.py +++ b/api/graphon/model_runtime/entities/rerank_entities.py @@ -1,6 +1,13 @@ +from typing import TypedDict + from pydantic import BaseModel +class MultimodalRerankInput(TypedDict): + content: str + content_type: str + + class RerankDocument(BaseModel): """ Model class for rerank document. diff --git a/api/dify_graph/model_runtime/entities/text_embedding_entities.py b/api/graphon/model_runtime/entities/text_embedding_entities.py similarity index 71% rename from api/dify_graph/model_runtime/entities/text_embedding_entities.py rename to api/graphon/model_runtime/entities/text_embedding_entities.py index a0210c169d..08ffd83b5b 100644 --- a/api/dify_graph/model_runtime/entities/text_embedding_entities.py +++ b/api/graphon/model_runtime/entities/text_embedding_entities.py @@ -1,8 +1,16 @@ from decimal import Decimal +from enum import StrEnum, auto from pydantic import BaseModel -from dify_graph.model_runtime.entities.model_entities import ModelUsage +from graphon.model_runtime.entities.model_entities import ModelUsage + + +class EmbeddingInputType(StrEnum): + """Embedding request input variants understood by the model runtime.""" + + DOCUMENT = auto() + QUERY = auto() class EmbeddingUsage(ModelUsage): diff --git a/api/dify_graph/model_runtime/errors/__init__.py b/api/graphon/model_runtime/errors/__init__.py similarity index 100% rename from api/dify_graph/model_runtime/errors/__init__.py rename to api/graphon/model_runtime/errors/__init__.py diff --git a/api/dify_graph/model_runtime/errors/invoke.py b/api/graphon/model_runtime/errors/invoke.py similarity index 100% rename from api/dify_graph/model_runtime/errors/invoke.py rename to api/graphon/model_runtime/errors/invoke.py diff --git a/api/dify_graph/model_runtime/errors/validate.py b/api/graphon/model_runtime/errors/validate.py similarity index 100% rename from api/dify_graph/model_runtime/errors/validate.py rename to api/graphon/model_runtime/errors/validate.py diff --git a/api/dify_graph/model_runtime/memory/__init__.py b/api/graphon/model_runtime/memory/__init__.py similarity index 100% rename from api/dify_graph/model_runtime/memory/__init__.py rename to api/graphon/model_runtime/memory/__init__.py diff --git a/api/dify_graph/model_runtime/memory/prompt_message_memory.py b/api/graphon/model_runtime/memory/prompt_message_memory.py similarity index 89% rename from api/dify_graph/model_runtime/memory/prompt_message_memory.py rename to api/graphon/model_runtime/memory/prompt_message_memory.py index a76a7faf71..03e26e9ff5 100644 --- a/api/dify_graph/model_runtime/memory/prompt_message_memory.py +++ b/api/graphon/model_runtime/memory/prompt_message_memory.py @@ -3,7 +3,7 @@ from __future__ import annotations from collections.abc import Sequence from typing import Protocol -from dify_graph.model_runtime.entities import PromptMessage +from graphon.model_runtime.entities import PromptMessage DEFAULT_MEMORY_MAX_TOKEN_LIMIT = 2000 diff --git a/api/dify_graph/model_runtime/model_providers/__base/__init__.py b/api/graphon/model_runtime/model_providers/__base/__init__.py similarity index 100% rename from api/dify_graph/model_runtime/model_providers/__base/__init__.py rename to api/graphon/model_runtime/model_providers/__base/__init__.py diff --git a/api/dify_graph/model_runtime/model_providers/__base/ai_model.py b/api/graphon/model_runtime/model_providers/__base/ai_model.py similarity index 64% rename from api/dify_graph/model_runtime/model_providers/__base/ai_model.py rename to api/graphon/model_runtime/model_providers/__base/ai_model.py index ac7ae9925b..1700ec9740 100644 --- a/api/dify_graph/model_runtime/model_providers/__base/ai_model.py +++ b/api/graphon/model_runtime/model_providers/__base/ai_model.py @@ -1,15 +1,8 @@ import decimal -import hashlib -import logging -from pydantic import BaseModel, ConfigDict, Field, ValidationError -from redis import RedisError - -from configs import dify_config -from core.plugin.entities.plugin_daemon import PluginModelProviderEntity -from dify_graph.model_runtime.entities.common_entities import I18nObject -from dify_graph.model_runtime.entities.defaults import PARAMETER_RULE_TEMPLATE -from dify_graph.model_runtime.entities.model_entities import ( +from graphon.model_runtime.entities.common_entities import I18nObject +from graphon.model_runtime.entities.defaults import PARAMETER_RULE_TEMPLATE +from graphon.model_runtime.entities.model_entities import ( AIModelEntity, DefaultParameterName, ModelType, @@ -17,7 +10,8 @@ from dify_graph.model_runtime.entities.model_entities import ( PriceInfo, PriceType, ) -from dify_graph.model_runtime.errors.invoke import ( +from graphon.model_runtime.entities.provider_entities import ProviderEntity +from graphon.model_runtime.errors.invoke import ( InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError, @@ -25,45 +19,61 @@ from dify_graph.model_runtime.errors.invoke import ( InvokeRateLimitError, InvokeServerUnavailableError, ) -from extensions.ext_redis import redis_client - -logger = logging.getLogger(__name__) +from graphon.model_runtime.runtime import ModelRuntime -class AIModel(BaseModel): +class AIModel: """ - Base class for all models. + Runtime-facing base class for all model providers. + + This stays a regular Python class because instances hold live collaborators + such as the provider schema and runtime adapter rather than user input that + benefits from Pydantic validation. Subclasses must pin ``model_type`` via a + class attribute; the base class is not meant to be instantiated directly. """ - tenant_id: str = Field(description="Tenant ID") - model_type: ModelType = Field(description="Model type") - plugin_id: str = Field(description="Plugin ID") - provider_name: str = Field(description="Provider") - plugin_model_provider: PluginModelProviderEntity = Field(description="Plugin model provider") - started_at: float = Field(description="Invoke start time", default=0) + model_type: ModelType + provider_schema: ProviderEntity + model_runtime: ModelRuntime + started_at: float - # pydantic configs - model_config = ConfigDict(protected_namespaces=()) + def __init__( + self, + provider_schema: ProviderEntity, + model_runtime: ModelRuntime, + *, + started_at: float = 0, + ) -> None: + if getattr(type(self), "model_type", None) is None: + raise TypeError("AIModel subclasses must define model_type as a class attribute") + + self.model_type = type(self).model_type + self.provider_schema = provider_schema + self.model_runtime = model_runtime + self.started_at = started_at + + @property + def provider(self) -> str: + return self.provider_schema.provider + + @property + def provider_display_name(self) -> str: + return self.provider_schema.label.en_US @property def _invoke_error_mapping(self) -> dict[type[Exception], list[type[Exception]]]: """ - Map model invoke error to unified error - The key is the error type thrown to the caller - The value is the error type thrown by the model, - which needs to be converted into a unified error type for the caller. + Map model invoke error to unified error. - :return: Invoke error mapping + The key is the error type thrown to the caller, and the value contains + runtime-facing exception types that should be normalized to it. """ - from core.plugin.entities.plugin_daemon import PluginDaemonInnerError - return { InvokeConnectionError: [InvokeConnectionError], InvokeServerUnavailableError: [InvokeServerUnavailableError], InvokeRateLimitError: [InvokeRateLimitError], InvokeAuthorizationError: [InvokeAuthorizationError], InvokeBadRequestError: [InvokeBadRequestError], - PluginDaemonInnerError: [PluginDaemonInnerError], ValueError: [ValueError], } @@ -79,15 +89,18 @@ class AIModel(BaseModel): if invoke_error == InvokeAuthorizationError: return InvokeAuthorizationError( description=( - f"[{self.provider_name}] Incorrect model credentials provided, please check and try again." + f"[{self.provider_display_name}] Incorrect model credentials provided, " + "please check and try again." ) ) elif isinstance(invoke_error, InvokeError): - return InvokeError(description=f"[{self.provider_name}] {invoke_error.description}, {str(error)}") + return InvokeError( + description=f"[{self.provider_display_name}] {invoke_error.description}, {str(error)}" + ) else: return error - return InvokeError(description=f"[{self.provider_name}] Error: {str(error)}") + return InvokeError(description=f"[{self.provider_display_name}] Error: {str(error)}") def get_price(self, model: str, credentials: dict, price_type: PriceType, tokens: int) -> PriceInfo: """ @@ -144,65 +157,13 @@ class AIModel(BaseModel): :param credentials: model credentials :return: model schema """ - from core.plugin.impl.model import PluginModelClient - - plugin_model_manager = PluginModelClient() - cache_key = f"{self.tenant_id}:{self.plugin_id}:{self.provider_name}:{self.model_type.value}:{model}" - sorted_credentials = sorted(credentials.items()) if credentials else [] - cache_key += ":".join([hashlib.md5(f"{k}:{v}".encode()).hexdigest() for k, v in sorted_credentials]) - - cached_schema_json = None - try: - cached_schema_json = redis_client.get(cache_key) - except (RedisError, RuntimeError) as exc: - logger.warning( - "Failed to read plugin model schema cache for model %s: %s", - model, - str(exc), - exc_info=True, - ) - if cached_schema_json: - try: - return AIModelEntity.model_validate_json(cached_schema_json) - except ValidationError: - logger.warning( - "Failed to validate cached plugin model schema for model %s", - model, - exc_info=True, - ) - try: - redis_client.delete(cache_key) - except (RedisError, RuntimeError) as exc: - logger.warning( - "Failed to delete invalid plugin model schema cache for model %s: %s", - model, - str(exc), - exc_info=True, - ) - - schema = plugin_model_manager.get_model_schema( - tenant_id=self.tenant_id, - user_id="unknown", - plugin_id=self.plugin_id, - provider=self.provider_name, - model_type=self.model_type.value, + return self.model_runtime.get_model_schema( + provider=self.provider, + model_type=self.model_type, model=model, credentials=credentials or {}, ) - if schema: - try: - redis_client.setex(cache_key, dify_config.PLUGIN_MODEL_SCHEMA_CACHE_TTL, schema.model_dump_json()) - except (RedisError, RuntimeError) as exc: - logger.warning( - "Failed to write plugin model schema cache for model %s: %s", - model, - str(exc), - exc_info=True, - ) - - return schema - def get_customizable_model_schema_from_credentials(self, model: str, credentials: dict) -> AIModelEntity | None: """ Get customizable model schema from credentials diff --git a/api/dify_graph/model_runtime/model_providers/__base/large_language_model.py b/api/graphon/model_runtime/model_providers/__base/large_language_model.py similarity index 88% rename from api/dify_graph/model_runtime/model_providers/__base/large_language_model.py rename to api/graphon/model_runtime/model_providers/__base/large_language_model.py index bf864ca227..0f909646a1 100644 --- a/api/dify_graph/model_runtime/model_providers/__base/large_language_model.py +++ b/api/graphon/model_runtime/model_providers/__base/large_language_model.py @@ -1,27 +1,24 @@ import logging import time import uuid -from collections.abc import Callable, Generator, Iterator, Sequence +from collections.abc import Callable, Generator, Iterator, Mapping, Sequence from typing import Union -from pydantic import ConfigDict - -from configs import dify_config -from dify_graph.model_runtime.callbacks.base_callback import Callback -from dify_graph.model_runtime.callbacks.logging_callback import LoggingCallback -from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMUsage -from dify_graph.model_runtime.entities.message_entities import ( +from graphon.model_runtime.callbacks.base_callback import Callback +from graphon.model_runtime.callbacks.logging_callback import LoggingCallback +from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMUsage +from graphon.model_runtime.entities.message_entities import ( AssistantPromptMessage, PromptMessage, PromptMessageContentUnionTypes, PromptMessageTool, TextPromptMessageContent, ) -from dify_graph.model_runtime.entities.model_entities import ( +from graphon.model_runtime.entities.model_entities import ( ModelType, PriceType, ) -from dify_graph.model_runtime.model_providers.__base.ai_model import AIModel +from graphon.model_runtime.model_providers.__base.ai_model import AIModel logger = logging.getLogger(__name__) @@ -140,11 +137,9 @@ def _build_llm_result_from_chunks( ) -def _invoke_llm_via_plugin( +def _invoke_llm_via_runtime( *, - tenant_id: str, - user_id: str, - plugin_id: str, + llm_model: "LargeLanguageModel", provider: str, model: str, credentials: dict, @@ -154,25 +149,19 @@ def _invoke_llm_via_plugin( stop: Sequence[str] | None, stream: bool, ) -> Union[LLMResult, Generator[LLMResultChunk, None, None]]: - from core.plugin.impl.model import PluginModelClient - - plugin_model_manager = PluginModelClient() - return plugin_model_manager.invoke_llm( - tenant_id=tenant_id, - user_id=user_id, - plugin_id=plugin_id, + return llm_model.model_runtime.invoke_llm( provider=provider, model=model, credentials=credentials, model_parameters=model_parameters, prompt_messages=list(prompt_messages), tools=tools, - stop=list(stop) if stop else None, + stop=stop, stream=stream, ) -def _normalize_non_stream_plugin_result( +def _normalize_non_stream_runtime_result( model: str, prompt_messages: Sequence[PromptMessage], result: Union[LLMResult, Iterator[LLMResultChunk]], @@ -208,9 +197,6 @@ class LargeLanguageModel(AIModel): model_type: ModelType = ModelType.LLM - # pydantic configs - model_config = ConfigDict(protected_namespaces=()) - def invoke( self, model: str, @@ -220,7 +206,6 @@ class LargeLanguageModel(AIModel): tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None, stream: bool = True, - user: str | None = None, callbacks: list[Callback] | None = None, ) -> Union[LLMResult, Generator[LLMResultChunk, None, None]]: """ @@ -233,7 +218,6 @@ class LargeLanguageModel(AIModel): :param tools: tools for tool calling :param stop: stop words :param stream: is stream response - :param user: unique user id :param callbacks: callbacks :return: full response or stream response chunk generator result """ @@ -245,7 +229,7 @@ class LargeLanguageModel(AIModel): callbacks = callbacks or [] - if dify_config.DEBUG: + if logger.isEnabledFor(logging.DEBUG): callbacks.append(LoggingCallback()) # trigger before invoke callbacks @@ -257,18 +241,15 @@ class LargeLanguageModel(AIModel): tools=tools, stop=stop, stream=stream, - user=user, callbacks=callbacks, ) result: Union[LLMResult, Generator[LLMResultChunk, None, None]] try: - result = _invoke_llm_via_plugin( - tenant_id=self.tenant_id, - user_id=user or "unknown", - plugin_id=self.plugin_id, - provider=self.provider_name, + result = _invoke_llm_via_runtime( + llm_model=self, + provider=self.provider, model=model, credentials=credentials, model_parameters=model_parameters, @@ -279,7 +260,7 @@ class LargeLanguageModel(AIModel): ) if not stream: - result = _normalize_non_stream_plugin_result( + result = _normalize_non_stream_runtime_result( model=model, prompt_messages=prompt_messages, result=result ) except Exception as e: @@ -292,7 +273,6 @@ class LargeLanguageModel(AIModel): tools=tools, stop=stop, stream=stream, - user=user, callbacks=callbacks, ) @@ -309,7 +289,6 @@ class LargeLanguageModel(AIModel): tools=tools, stop=stop, stream=stream, - user=user, callbacks=callbacks, ) elif isinstance(result, LLMResult): @@ -322,7 +301,6 @@ class LargeLanguageModel(AIModel): tools=tools, stop=stop, stream=stream, - user=user, callbacks=callbacks, ) # Following https://github.com/langgenius/dify/issues/17799, @@ -342,7 +320,7 @@ class LargeLanguageModel(AIModel): tools: list[PromptMessageTool] | None = None, stop: Sequence[str] | None = None, stream: bool = True, - user: str | None = None, + invocation_context: Mapping[str, object] | None = None, callbacks: list[Callback] | None = None, ) -> Generator[LLMResultChunk, None, None]: """ @@ -384,7 +362,7 @@ class LargeLanguageModel(AIModel): tools=tools, stop=stop, stream=stream, - user=user, + invocation_context=invocation_context, callbacks=callbacks, ) @@ -415,7 +393,7 @@ class LargeLanguageModel(AIModel): tools=tools, stop=stop, stream=stream, - user=user, + invocation_context=invocation_context, callbacks=callbacks, ) @@ -435,22 +413,14 @@ class LargeLanguageModel(AIModel): :param tools: tools for tool calling :return: """ - if dify_config.PLUGIN_BASED_TOKEN_COUNTING_ENABLED: - from core.plugin.impl.model import PluginModelClient - - plugin_model_manager = PluginModelClient() - return plugin_model_manager.get_llm_num_tokens( - tenant_id=self.tenant_id, - user_id="unknown", - plugin_id=self.plugin_id, - provider=self.provider_name, - model_type=self.model_type.value, - model=model, - credentials=credentials, - prompt_messages=prompt_messages, - tools=tools, - ) - return 0 + return self.model_runtime.get_llm_num_tokens( + provider=self.provider, + model_type=self.model_type, + model=model, + credentials=credentials, + prompt_messages=prompt_messages, + tools=tools, + ) def calc_response_usage( self, model: str, credentials: dict, prompt_tokens: int, completion_tokens: int @@ -504,7 +474,7 @@ class LargeLanguageModel(AIModel): tools: list[PromptMessageTool] | None = None, stop: Sequence[str] | None = None, stream: bool = True, - user: str | None = None, + invocation_context: Mapping[str, object] | None = None, callbacks: list[Callback] | None = None, ): """ @@ -517,7 +487,7 @@ class LargeLanguageModel(AIModel): :param tools: tools for tool calling :param stop: stop words :param stream: is stream response - :param user: unique user id + :param invocation_context: opaque request metadata for the current invocation :param callbacks: callbacks """ _run_callbacks( @@ -532,7 +502,7 @@ class LargeLanguageModel(AIModel): tools=tools, stop=stop, stream=stream, - user=user, + invocation_context=invocation_context, ), ) @@ -546,7 +516,7 @@ class LargeLanguageModel(AIModel): tools: list[PromptMessageTool] | None = None, stop: Sequence[str] | None = None, stream: bool = True, - user: str | None = None, + invocation_context: Mapping[str, object] | None = None, callbacks: list[Callback] | None = None, ): """ @@ -560,7 +530,7 @@ class LargeLanguageModel(AIModel): :param tools: tools for tool calling :param stop: stop words :param stream: is stream response - :param user: unique user id + :param invocation_context: opaque request metadata for the current invocation """ _run_callbacks( callbacks, @@ -575,7 +545,7 @@ class LargeLanguageModel(AIModel): tools=tools, stop=stop, stream=stream, - user=user, + invocation_context=invocation_context, ), ) @@ -589,7 +559,7 @@ class LargeLanguageModel(AIModel): tools: list[PromptMessageTool] | None = None, stop: Sequence[str] | None = None, stream: bool = True, - user: str | None = None, + invocation_context: Mapping[str, object] | None = None, callbacks: list[Callback] | None = None, ): """ @@ -603,7 +573,7 @@ class LargeLanguageModel(AIModel): :param tools: tools for tool calling :param stop: stop words :param stream: is stream response - :param user: unique user id + :param invocation_context: opaque request metadata for the current invocation :param callbacks: callbacks """ _run_callbacks( @@ -619,7 +589,7 @@ class LargeLanguageModel(AIModel): tools=tools, stop=stop, stream=stream, - user=user, + invocation_context=invocation_context, ), ) @@ -633,7 +603,7 @@ class LargeLanguageModel(AIModel): tools: list[PromptMessageTool] | None = None, stop: Sequence[str] | None = None, stream: bool = True, - user: str | None = None, + invocation_context: Mapping[str, object] | None = None, callbacks: list[Callback] | None = None, ): """ @@ -647,7 +617,7 @@ class LargeLanguageModel(AIModel): :param tools: tools for tool calling :param stop: stop words :param stream: is stream response - :param user: unique user id + :param invocation_context: opaque request metadata for the current invocation :param callbacks: callbacks """ _run_callbacks( @@ -663,6 +633,6 @@ class LargeLanguageModel(AIModel): tools=tools, stop=stop, stream=stream, - user=user, + invocation_context=invocation_context, ), ) diff --git a/api/graphon/model_runtime/model_providers/__base/moderation_model.py b/api/graphon/model_runtime/model_providers/__base/moderation_model.py new file mode 100644 index 0000000000..01f6842998 --- /dev/null +++ b/api/graphon/model_runtime/model_providers/__base/moderation_model.py @@ -0,0 +1,33 @@ +import time + +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.model_providers.__base.ai_model import AIModel + + +class ModerationModel(AIModel): + """ + Model class for moderation model. + """ + + model_type: ModelType = ModelType.MODERATION + + def invoke(self, model: str, credentials: dict, text: str) -> bool: + """ + Invoke moderation model + + :param model: model name + :param credentials: model credentials + :param text: text to moderate + :return: false if text is safe, true otherwise + """ + self.started_at = time.perf_counter() + + try: + return self.model_runtime.invoke_moderation( + provider=self.provider, + model=model, + credentials=credentials, + text=text, + ) + except Exception as e: + raise self._transform_invoke_error(e) diff --git a/api/dify_graph/model_runtime/model_providers/__base/rerank_model.py b/api/graphon/model_runtime/model_providers/__base/rerank_model.py similarity index 61% rename from api/dify_graph/model_runtime/model_providers/__base/rerank_model.py rename to api/graphon/model_runtime/model_providers/__base/rerank_model.py index 5da2b84b95..94b2b5a4fb 100644 --- a/api/dify_graph/model_runtime/model_providers/__base/rerank_model.py +++ b/api/graphon/model_runtime/model_providers/__base/rerank_model.py @@ -1,6 +1,6 @@ -from dify_graph.model_runtime.entities.model_entities import ModelType -from dify_graph.model_runtime.entities.rerank_entities import RerankResult -from dify_graph.model_runtime.model_providers.__base.ai_model import AIModel +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult +from graphon.model_runtime.model_providers.__base.ai_model import AIModel class RerankModel(AIModel): @@ -18,7 +18,6 @@ class RerankModel(AIModel): docs: list[str], score_threshold: float | None = None, top_n: int | None = None, - user: str | None = None, ) -> RerankResult: """ Invoke rerank model @@ -29,18 +28,11 @@ class RerankModel(AIModel): :param docs: docs for reranking :param score_threshold: score threshold :param top_n: top n - :param user: unique user id :return: rerank result """ try: - from core.plugin.impl.model import PluginModelClient - - plugin_model_manager = PluginModelClient() - return plugin_model_manager.invoke_rerank( - tenant_id=self.tenant_id, - user_id=user or "unknown", - plugin_id=self.plugin_id, - provider=self.provider_name, + return self.model_runtime.invoke_rerank( + provider=self.provider, model=model, credentials=credentials, query=query, @@ -55,11 +47,10 @@ class RerankModel(AIModel): self, model: str, credentials: dict, - query: dict, - docs: list[dict], + query: MultimodalRerankInput, + docs: list[MultimodalRerankInput], score_threshold: float | None = None, top_n: int | None = None, - user: str | None = None, ) -> RerankResult: """ Invoke multimodal rerank model @@ -69,18 +60,11 @@ class RerankModel(AIModel): :param docs: docs for reranking :param score_threshold: score threshold :param top_n: top n - :param user: unique user id :return: rerank result """ try: - from core.plugin.impl.model import PluginModelClient - - plugin_model_manager = PluginModelClient() - return plugin_model_manager.invoke_multimodal_rerank( - tenant_id=self.tenant_id, - user_id=user or "unknown", - plugin_id=self.plugin_id, - provider=self.provider_name, + return self.model_runtime.invoke_multimodal_rerank( + provider=self.provider, model=model, credentials=credentials, query=query, diff --git a/api/graphon/model_runtime/model_providers/__base/speech2text_model.py b/api/graphon/model_runtime/model_providers/__base/speech2text_model.py new file mode 100644 index 0000000000..4f5d648639 --- /dev/null +++ b/api/graphon/model_runtime/model_providers/__base/speech2text_model.py @@ -0,0 +1,31 @@ +from typing import IO + +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.model_providers.__base.ai_model import AIModel + + +class Speech2TextModel(AIModel): + """ + Model class for speech2text model. + """ + + model_type: ModelType = ModelType.SPEECH2TEXT + + def invoke(self, model: str, credentials: dict, file: IO[bytes]) -> str: + """ + Invoke speech to text model + + :param model: model name + :param credentials: model credentials + :param file: audio file + :return: text for given audio file + """ + try: + return self.model_runtime.invoke_speech_to_text( + provider=self.provider, + model=model, + credentials=credentials, + file=file, + ) + except Exception as e: + raise self._transform_invoke_error(e) diff --git a/api/dify_graph/model_runtime/model_providers/__base/text_embedding_model.py b/api/graphon/model_runtime/model_providers/__base/text_embedding_model.py similarity index 65% rename from api/dify_graph/model_runtime/model_providers/__base/text_embedding_model.py rename to api/graphon/model_runtime/model_providers/__base/text_embedding_model.py index 3438da2ada..c8b4a0a6af 100644 --- a/api/dify_graph/model_runtime/model_providers/__base/text_embedding_model.py +++ b/api/graphon/model_runtime/model_providers/__base/text_embedding_model.py @@ -1,9 +1,6 @@ -from pydantic import ConfigDict - -from core.entities.embedding_type import EmbeddingInputType -from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey, ModelType -from dify_graph.model_runtime.entities.text_embedding_entities import EmbeddingResult -from dify_graph.model_runtime.model_providers.__base.ai_model import AIModel +from graphon.model_runtime.entities.model_entities import ModelPropertyKey, ModelType +from graphon.model_runtime.entities.text_embedding_entities import EmbeddingInputType, EmbeddingResult +from graphon.model_runtime.model_providers.__base.ai_model import AIModel class TextEmbeddingModel(AIModel): @@ -13,16 +10,12 @@ class TextEmbeddingModel(AIModel): model_type: ModelType = ModelType.TEXT_EMBEDDING - # pydantic configs - model_config = ConfigDict(protected_namespaces=()) - def invoke( self, model: str, credentials: dict, texts: list[str] | None = None, multimodel_documents: list[dict] | None = None, - user: str | None = None, input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT, ) -> EmbeddingResult: """ @@ -32,31 +25,21 @@ class TextEmbeddingModel(AIModel): :param credentials: model credentials :param texts: texts to embed :param files: files to embed - :param user: unique user id :param input_type: input type :return: embeddings result """ - from core.plugin.impl.model import PluginModelClient - try: - plugin_model_manager = PluginModelClient() if texts: - return plugin_model_manager.invoke_text_embedding( - tenant_id=self.tenant_id, - user_id=user or "unknown", - plugin_id=self.plugin_id, - provider=self.provider_name, + return self.model_runtime.invoke_text_embedding( + provider=self.provider, model=model, credentials=credentials, texts=texts, input_type=input_type, ) if multimodel_documents: - return plugin_model_manager.invoke_multimodal_embedding( - tenant_id=self.tenant_id, - user_id=user or "unknown", - plugin_id=self.plugin_id, - provider=self.provider_name, + return self.model_runtime.invoke_multimodal_embedding( + provider=self.provider, model=model, credentials=credentials, documents=multimodel_documents, @@ -75,14 +58,8 @@ class TextEmbeddingModel(AIModel): :param texts: texts to embed :return: """ - from core.plugin.impl.model import PluginModelClient - - plugin_model_manager = PluginModelClient() - return plugin_model_manager.get_text_embedding_num_tokens( - tenant_id=self.tenant_id, - user_id="unknown", - plugin_id=self.plugin_id, - provider=self.provider_name, + return self.model_runtime.get_text_embedding_num_tokens( + provider=self.provider, model=model, credentials=credentials, texts=texts, diff --git a/api/dify_graph/model_runtime/model_providers/__base/tokenizers/gpt2_tokenizer.py b/api/graphon/model_runtime/model_providers/__base/tokenizers/gpt2_tokenizer.py similarity index 100% rename from api/dify_graph/model_runtime/model_providers/__base/tokenizers/gpt2_tokenizer.py rename to api/graphon/model_runtime/model_providers/__base/tokenizers/gpt2_tokenizer.py diff --git a/api/dify_graph/model_runtime/model_providers/__base/tts_model.py b/api/graphon/model_runtime/model_providers/__base/tts_model.py similarity index 57% rename from api/dify_graph/model_runtime/model_providers/__base/tts_model.py rename to api/graphon/model_runtime/model_providers/__base/tts_model.py index 0656529f22..6846f3c403 100644 --- a/api/dify_graph/model_runtime/model_providers/__base/tts_model.py +++ b/api/graphon/model_runtime/model_providers/__base/tts_model.py @@ -1,10 +1,8 @@ import logging from collections.abc import Iterable -from pydantic import ConfigDict - -from dify_graph.model_runtime.entities.model_entities import ModelType -from dify_graph.model_runtime.model_providers.__base.ai_model import AIModel +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.model_providers.__base.ai_model import AIModel logger = logging.getLogger(__name__) @@ -16,38 +14,25 @@ class TTSModel(AIModel): model_type: ModelType = ModelType.TTS - # pydantic configs - model_config = ConfigDict(protected_namespaces=()) - def invoke( self, model: str, - tenant_id: str, credentials: dict, content_text: str, voice: str, - user: str | None = None, ) -> Iterable[bytes]: """ Invoke large language model :param model: model name - :param tenant_id: user tenant id :param credentials: model credentials :param voice: model timbre :param content_text: text content to be translated - :param user: unique user id :return: translated audio file """ try: - from core.plugin.impl.model import PluginModelClient - - plugin_model_manager = PluginModelClient() - return plugin_model_manager.invoke_tts( - tenant_id=self.tenant_id, - user_id=user or "unknown", - plugin_id=self.plugin_id, - provider=self.provider_name, + return self.model_runtime.invoke_tts( + provider=self.provider, model=model, credentials=credentials, content_text=content_text, @@ -65,14 +50,8 @@ class TTSModel(AIModel): :param credentials: The credentials required to access the TTS model. :return: A list of voices supported by the TTS model. """ - from core.plugin.impl.model import PluginModelClient - - plugin_model_manager = PluginModelClient() - return plugin_model_manager.get_tts_model_voices( - tenant_id=self.tenant_id, - user_id="unknown", - plugin_id=self.plugin_id, - provider=self.provider_name, + return self.model_runtime.get_tts_model_voices( + provider=self.provider, model=model, credentials=credentials, language=language, diff --git a/api/dify_graph/model_runtime/model_providers/__init__.py b/api/graphon/model_runtime/model_providers/__init__.py similarity index 100% rename from api/dify_graph/model_runtime/model_providers/__init__.py rename to api/graphon/model_runtime/model_providers/__init__.py diff --git a/api/dify_graph/model_runtime/model_providers/_position.yaml b/api/graphon/model_runtime/model_providers/_position.yaml similarity index 100% rename from api/dify_graph/model_runtime/model_providers/_position.yaml rename to api/graphon/model_runtime/model_providers/_position.yaml diff --git a/api/graphon/model_runtime/model_providers/model_provider_factory.py b/api/graphon/model_runtime/model_providers/model_provider_factory.py new file mode 100644 index 0000000000..1ea30c7120 --- /dev/null +++ b/api/graphon/model_runtime/model_providers/model_provider_factory.py @@ -0,0 +1,173 @@ +from __future__ import annotations + +from collections.abc import Sequence + +from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelType +from graphon.model_runtime.entities.provider_entities import ProviderConfig, ProviderEntity, SimpleProviderEntity +from graphon.model_runtime.model_providers.__base.ai_model import AIModel +from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from graphon.model_runtime.model_providers.__base.moderation_model import ModerationModel +from graphon.model_runtime.model_providers.__base.rerank_model import RerankModel +from graphon.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel +from graphon.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel +from graphon.model_runtime.model_providers.__base.tts_model import TTSModel +from graphon.model_runtime.runtime import ModelRuntime +from graphon.model_runtime.schema_validators.model_credential_schema_validator import ModelCredentialSchemaValidator +from graphon.model_runtime.schema_validators.provider_credential_schema_validator import ( + ProviderCredentialSchemaValidator, +) + + +class ModelProviderFactory: + """Factory for provider schemas and model-type instances backed by a runtime adapter.""" + + def __init__(self, model_runtime: ModelRuntime): + if model_runtime is None: + raise ValueError("model_runtime is required.") + self.model_runtime = model_runtime + + def get_providers(self) -> Sequence[ProviderEntity]: + """ + Get all providers. + """ + return list(self.get_model_providers()) + + def get_model_providers(self) -> Sequence[ProviderEntity]: + """ + Get all model providers exposed by the runtime adapter. + """ + return self.model_runtime.fetch_model_providers() + + def get_provider_schema(self, provider: str) -> ProviderEntity: + """ + Get provider schema. + """ + return self.get_model_provider(provider=provider) + + def get_model_provider(self, provider: str) -> ProviderEntity: + """ + Get provider schema. + """ + provider_entity = self._resolve_provider(provider) + if provider_entity is None: + raise ValueError(f"Invalid provider: {provider}") + + return provider_entity + + def provider_credentials_validate(self, *, provider: str, credentials: dict): + """ + Validate provider credentials. + """ + provider_entity = self.get_model_provider(provider=provider) + + provider_credential_schema = provider_entity.provider_credential_schema + if not provider_credential_schema: + raise ValueError(f"Provider {provider} does not have provider_credential_schema") + + validator = ProviderCredentialSchemaValidator(provider_credential_schema) + filtered_credentials = validator.validate_and_filter(credentials) + + self.model_runtime.validate_provider_credentials( + provider=provider_entity.provider, + credentials=filtered_credentials, + ) + + return filtered_credentials + + def model_credentials_validate(self, *, provider: str, model_type: ModelType, model: str, credentials: dict): + """ + Validate model credentials. + """ + provider_entity = self.get_model_provider(provider=provider) + + model_credential_schema = provider_entity.model_credential_schema + if not model_credential_schema: + raise ValueError(f"Provider {provider} does not have model_credential_schema") + + validator = ModelCredentialSchemaValidator(model_type, model_credential_schema) + filtered_credentials = validator.validate_and_filter(credentials) + + self.model_runtime.validate_model_credentials( + provider=provider_entity.provider, + model_type=model_type, + model=model, + credentials=filtered_credentials, + ) + + return filtered_credentials + + def get_model_schema( + self, *, provider: str, model_type: ModelType, model: str, credentials: dict | None + ) -> AIModelEntity | None: + """ + Get model schema. + """ + provider_entity = self.get_model_provider(provider) + return self.model_runtime.get_model_schema( + provider=provider_entity.provider, + model_type=model_type, + model=model, + credentials=credentials or {}, + ) + + def get_models( + self, + *, + provider: str | None = None, + model_type: ModelType | None = None, + provider_configs: list[ProviderConfig] | None = None, + ) -> list[SimpleProviderEntity]: + """ + Get all models for given model type. + """ + providers = [] + for provider_entity in self.get_model_providers(): + if provider and not self._matches_provider(provider_entity, provider): + continue + + if model_type and model_type not in provider_entity.supported_model_types: + continue + + simple_provider_schema = provider_entity.to_simple_provider() + if model_type is not None: + simple_provider_schema.models = [ + model_schema for model_schema in provider_entity.models if model_schema.model_type == model_type + ] + providers.append(simple_provider_schema) + + return providers + + def get_model_type_instance(self, provider: str, model_type: ModelType) -> AIModel: + """ + Get model type instance by provider name and model type. + """ + provider_schema = self.get_model_provider(provider) + + if model_type == ModelType.LLM: + return LargeLanguageModel(provider_schema=provider_schema, model_runtime=self.model_runtime) + if model_type == ModelType.TEXT_EMBEDDING: + return TextEmbeddingModel(provider_schema=provider_schema, model_runtime=self.model_runtime) + if model_type == ModelType.RERANK: + return RerankModel(provider_schema=provider_schema, model_runtime=self.model_runtime) + if model_type == ModelType.SPEECH2TEXT: + return Speech2TextModel(provider_schema=provider_schema, model_runtime=self.model_runtime) + if model_type == ModelType.MODERATION: + return ModerationModel(provider_schema=provider_schema, model_runtime=self.model_runtime) + if model_type == ModelType.TTS: + return TTSModel(provider_schema=provider_schema, model_runtime=self.model_runtime) + + raise ValueError(f"Unsupported model type: {model_type}") + + def get_provider_icon(self, provider: str, icon_type: str, lang: str) -> tuple[bytes, str]: + """ + Get provider icon. + """ + provider_entity = self.get_model_provider(provider) + return self.model_runtime.get_provider_icon(provider=provider_entity.provider, icon_type=icon_type, lang=lang) + + def _resolve_provider(self, provider: str) -> ProviderEntity | None: + return next((item for item in self.get_model_providers() if self._matches_provider(item, provider)), None) + + @staticmethod + def _matches_provider(provider_entity: ProviderEntity, provider: str) -> bool: + return provider in (provider_entity.provider, provider_entity.provider_name) diff --git a/api/graphon/model_runtime/runtime.py b/api/graphon/model_runtime/runtime.py new file mode 100644 index 0000000000..79862bab8b --- /dev/null +++ b/api/graphon/model_runtime/runtime.py @@ -0,0 +1,159 @@ +from __future__ import annotations + +from collections.abc import Generator, Iterable, Sequence +from typing import IO, Any, Protocol, Union, runtime_checkable + +from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk +from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool +from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelType +from graphon.model_runtime.entities.provider_entities import ProviderEntity +from graphon.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult +from graphon.model_runtime.entities.text_embedding_entities import EmbeddingInputType, EmbeddingResult + + +@runtime_checkable +class ModelRuntime(Protocol): + """Port for provider discovery, schema lookup, and model execution. + + `provider` is the model runtime's canonical provider identifier. Adapters may + derive transport-specific details from it, but those details stay outside + this boundary. + """ + + def fetch_model_providers(self) -> Sequence[ProviderEntity]: ... + + def get_provider_icon(self, *, provider: str, icon_type: str, lang: str) -> tuple[bytes, str]: ... + + def validate_provider_credentials(self, *, provider: str, credentials: dict[str, Any]) -> None: ... + + def validate_model_credentials( + self, + *, + provider: str, + model_type: ModelType, + model: str, + credentials: dict[str, Any], + ) -> None: ... + + def get_model_schema( + self, + *, + provider: str, + model_type: ModelType, + model: str, + credentials: dict[str, Any], + ) -> AIModelEntity | None: ... + + def invoke_llm( + self, + *, + provider: str, + model: str, + credentials: dict[str, Any], + model_parameters: dict[str, Any], + prompt_messages: Sequence[PromptMessage], + tools: list[PromptMessageTool] | None, + stop: Sequence[str] | None, + stream: bool, + ) -> Union[LLMResult, Generator[LLMResultChunk, None, None]]: ... + + def get_llm_num_tokens( + self, + *, + provider: str, + model_type: ModelType, + model: str, + credentials: dict[str, Any], + prompt_messages: Sequence[PromptMessage], + tools: Sequence[PromptMessageTool] | None, + ) -> int: ... + + def invoke_text_embedding( + self, + *, + provider: str, + model: str, + credentials: dict[str, Any], + texts: list[str], + input_type: EmbeddingInputType, + ) -> EmbeddingResult: ... + + def invoke_multimodal_embedding( + self, + *, + provider: str, + model: str, + credentials: dict[str, Any], + documents: list[dict[str, Any]], + input_type: EmbeddingInputType, + ) -> EmbeddingResult: ... + + def get_text_embedding_num_tokens( + self, + *, + provider: str, + model: str, + credentials: dict[str, Any], + texts: list[str], + ) -> list[int]: ... + + def invoke_rerank( + self, + *, + provider: str, + model: str, + credentials: dict[str, Any], + query: str, + docs: list[str], + score_threshold: float | None, + top_n: int | None, + ) -> RerankResult: ... + + def invoke_multimodal_rerank( + self, + *, + provider: str, + model: str, + credentials: dict[str, Any], + query: MultimodalRerankInput, + docs: list[MultimodalRerankInput], + score_threshold: float | None, + top_n: int | None, + ) -> RerankResult: ... + + def invoke_tts( + self, + *, + provider: str, + model: str, + credentials: dict[str, Any], + content_text: str, + voice: str, + ) -> Iterable[bytes]: ... + + def get_tts_model_voices( + self, + *, + provider: str, + model: str, + credentials: dict[str, Any], + language: str | None, + ) -> Any: ... + + def invoke_speech_to_text( + self, + *, + provider: str, + model: str, + credentials: dict[str, Any], + file: IO[bytes], + ) -> str: ... + + def invoke_moderation( + self, + *, + provider: str, + model: str, + credentials: dict[str, Any], + text: str, + ) -> bool: ... diff --git a/api/dify_graph/model_runtime/schema_validators/__init__.py b/api/graphon/model_runtime/schema_validators/__init__.py similarity index 100% rename from api/dify_graph/model_runtime/schema_validators/__init__.py rename to api/graphon/model_runtime/schema_validators/__init__.py diff --git a/api/dify_graph/model_runtime/schema_validators/common_validator.py b/api/graphon/model_runtime/schema_validators/common_validator.py similarity index 97% rename from api/dify_graph/model_runtime/schema_validators/common_validator.py rename to api/graphon/model_runtime/schema_validators/common_validator.py index 04cdb8e4f7..984507081b 100644 --- a/api/dify_graph/model_runtime/schema_validators/common_validator.py +++ b/api/graphon/model_runtime/schema_validators/common_validator.py @@ -1,6 +1,6 @@ from typing import Union, cast -from dify_graph.model_runtime.entities.provider_entities import CredentialFormSchema, FormType +from graphon.model_runtime.entities.provider_entities import CredentialFormSchema, FormType class CommonValidator: diff --git a/api/dify_graph/model_runtime/schema_validators/model_credential_schema_validator.py b/api/graphon/model_runtime/schema_validators/model_credential_schema_validator.py similarity index 78% rename from api/dify_graph/model_runtime/schema_validators/model_credential_schema_validator.py rename to api/graphon/model_runtime/schema_validators/model_credential_schema_validator.py index a97796e98f..9e4830c1b7 100644 --- a/api/dify_graph/model_runtime/schema_validators/model_credential_schema_validator.py +++ b/api/graphon/model_runtime/schema_validators/model_credential_schema_validator.py @@ -1,6 +1,6 @@ -from dify_graph.model_runtime.entities.model_entities import ModelType -from dify_graph.model_runtime.entities.provider_entities import ModelCredentialSchema -from dify_graph.model_runtime.schema_validators.common_validator import CommonValidator +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.entities.provider_entities import ModelCredentialSchema +from graphon.model_runtime.schema_validators.common_validator import CommonValidator class ModelCredentialSchemaValidator(CommonValidator): diff --git a/api/dify_graph/model_runtime/schema_validators/provider_credential_schema_validator.py b/api/graphon/model_runtime/schema_validators/provider_credential_schema_validator.py similarity index 79% rename from api/dify_graph/model_runtime/schema_validators/provider_credential_schema_validator.py rename to api/graphon/model_runtime/schema_validators/provider_credential_schema_validator.py index 2fed75a76c..05fd3ce142 100644 --- a/api/dify_graph/model_runtime/schema_validators/provider_credential_schema_validator.py +++ b/api/graphon/model_runtime/schema_validators/provider_credential_schema_validator.py @@ -1,5 +1,5 @@ -from dify_graph.model_runtime.entities.provider_entities import ProviderCredentialSchema -from dify_graph.model_runtime.schema_validators.common_validator import CommonValidator +from graphon.model_runtime.entities.provider_entities import ProviderCredentialSchema +from graphon.model_runtime.schema_validators.common_validator import CommonValidator class ProviderCredentialSchemaValidator(CommonValidator): diff --git a/api/dify_graph/model_runtime/utils/__init__.py b/api/graphon/model_runtime/utils/__init__.py similarity index 100% rename from api/dify_graph/model_runtime/utils/__init__.py rename to api/graphon/model_runtime/utils/__init__.py diff --git a/api/dify_graph/model_runtime/utils/encoders.py b/api/graphon/model_runtime/utils/encoders.py similarity index 83% rename from api/dify_graph/model_runtime/utils/encoders.py rename to api/graphon/model_runtime/utils/encoders.py index c85152463e..13abf74767 100644 --- a/api/dify_graph/model_runtime/utils/encoders.py +++ b/api/graphon/model_runtime/utils/encoders.py @@ -1,7 +1,7 @@ import dataclasses import datetime from collections import defaultdict, deque -from collections.abc import Callable +from collections.abc import Callable, Sequence from decimal import Decimal from enum import Enum from ipaddress import IPv4Address, IPv4Interface, IPv4Network, IPv6Address, IPv6Interface, IPv6Network @@ -99,7 +99,7 @@ def jsonable_encoder( exclude_defaults: bool = False, exclude_none: bool = False, custom_encoder: dict[Any, Callable[[Any], Any]] | None = None, - sqlalchemy_safe: bool = True, + excluded_key_prefixes: Sequence[str] = (), ) -> Any: custom_encoder = custom_encoder or {} if custom_encoder: @@ -126,7 +126,7 @@ def jsonable_encoder( obj_dict, exclude_none=exclude_none, exclude_defaults=exclude_defaults, - sqlalchemy_safe=sqlalchemy_safe, + excluded_key_prefixes=excluded_key_prefixes, ) if dataclasses.is_dataclass(obj): # Ensure obj is a dataclass instance, not a dataclass type @@ -139,7 +139,7 @@ def jsonable_encoder( exclude_defaults=exclude_defaults, exclude_none=exclude_none, custom_encoder=custom_encoder, - sqlalchemy_safe=sqlalchemy_safe, + excluded_key_prefixes=excluded_key_prefixes, ) if isinstance(obj, Enum): return obj.value @@ -152,26 +152,28 @@ def jsonable_encoder( if isinstance(obj, dict): encoded_dict = {} for key, value in obj.items(): - if (not sqlalchemy_safe or (not isinstance(key, str)) or (not key.startswith("_sa"))) and ( - value is not None or not exclude_none - ): - encoded_key = jsonable_encoder( - key, - by_alias=by_alias, - exclude_unset=exclude_unset, - exclude_none=exclude_none, - custom_encoder=custom_encoder, - sqlalchemy_safe=sqlalchemy_safe, - ) - encoded_value = jsonable_encoder( - value, - by_alias=by_alias, - exclude_unset=exclude_unset, - exclude_none=exclude_none, - custom_encoder=custom_encoder, - sqlalchemy_safe=sqlalchemy_safe, - ) - encoded_dict[encoded_key] = encoded_value + if isinstance(key, str) and any(key.startswith(prefix) for prefix in excluded_key_prefixes): + continue + if value is None and exclude_none: + continue + + encoded_key = jsonable_encoder( + key, + by_alias=by_alias, + exclude_unset=exclude_unset, + exclude_none=exclude_none, + custom_encoder=custom_encoder, + excluded_key_prefixes=excluded_key_prefixes, + ) + encoded_value = jsonable_encoder( + value, + by_alias=by_alias, + exclude_unset=exclude_unset, + exclude_none=exclude_none, + custom_encoder=custom_encoder, + excluded_key_prefixes=excluded_key_prefixes, + ) + encoded_dict[encoded_key] = encoded_value return encoded_dict if isinstance(obj, list | set | frozenset | GeneratorType | tuple | deque): encoded_list = [] @@ -184,7 +186,7 @@ def jsonable_encoder( exclude_defaults=exclude_defaults, exclude_none=exclude_none, custom_encoder=custom_encoder, - sqlalchemy_safe=sqlalchemy_safe, + excluded_key_prefixes=excluded_key_prefixes, ) ) return encoded_list @@ -212,5 +214,5 @@ def jsonable_encoder( exclude_defaults=exclude_defaults, exclude_none=exclude_none, custom_encoder=custom_encoder, - sqlalchemy_safe=sqlalchemy_safe, + excluded_key_prefixes=excluded_key_prefixes, ) diff --git a/api/dify_graph/node_events/__init__.py b/api/graphon/node_events/__init__.py similarity index 95% rename from api/dify_graph/node_events/__init__.py rename to api/graphon/node_events/__init__.py index a9bef8f9a2..a2bbf9f176 100644 --- a/api/dify_graph/node_events/__init__.py +++ b/api/graphon/node_events/__init__.py @@ -21,6 +21,7 @@ from .node import ( RunRetryEvent, StreamChunkEvent, StreamCompletedEvent, + VariableUpdatedEvent, ) __all__ = [ @@ -43,4 +44,5 @@ __all__ = [ "RunRetryEvent", "StreamChunkEvent", "StreamCompletedEvent", + "VariableUpdatedEvent", ] diff --git a/api/dify_graph/node_events/agent.py b/api/graphon/node_events/agent.py similarity index 100% rename from api/dify_graph/node_events/agent.py rename to api/graphon/node_events/agent.py diff --git a/api/dify_graph/node_events/base.py b/api/graphon/node_events/base.py similarity index 86% rename from api/dify_graph/node_events/base.py rename to api/graphon/node_events/base.py index 2f6259ae7d..dcd1672428 100644 --- a/api/dify_graph/node_events/base.py +++ b/api/graphon/node_events/base.py @@ -3,8 +3,8 @@ from typing import Any from pydantic import BaseModel, Field -from dify_graph.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus -from dify_graph.model_runtime.entities.llm_entities import LLMUsage +from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus +from graphon.model_runtime.entities.llm_entities import LLMUsage class NodeEventBase(BaseModel): diff --git a/api/dify_graph/node_events/iteration.py b/api/graphon/node_events/iteration.py similarity index 100% rename from api/dify_graph/node_events/iteration.py rename to api/graphon/node_events/iteration.py diff --git a/api/dify_graph/node_events/loop.py b/api/graphon/node_events/loop.py similarity index 100% rename from api/dify_graph/node_events/loop.py rename to api/graphon/node_events/loop.py diff --git a/api/dify_graph/node_events/node.py b/api/graphon/node_events/node.py similarity index 79% rename from api/dify_graph/node_events/node.py rename to api/graphon/node_events/node.py index 2e3973b8fa..17f1494cf2 100644 --- a/api/dify_graph/node_events/node.py +++ b/api/graphon/node_events/node.py @@ -4,10 +4,11 @@ from typing import Any from pydantic import Field -from dify_graph.entities.pause_reason import PauseReason -from dify_graph.file import File -from dify_graph.model_runtime.entities.llm_entities import LLMUsage -from dify_graph.node_events import NodeRunResult +from graphon.entities.pause_reason import PauseReason +from graphon.file import File +from graphon.model_runtime.entities.llm_entities import LLMUsage +from graphon.node_events import NodeRunResult +from graphon.variables.variables import Variable from .base import NodeEventBase @@ -45,6 +46,12 @@ class StreamCompletedEvent(NodeEventBase): node_run_result: NodeRunResult = Field(..., description="run result") +class VariableUpdatedEvent(NodeEventBase): + """Notify the engine that a single variable should be applied to the shared pool.""" + + variable: Variable = Field(..., description="Updated variable payload to apply.") + + class PauseRequestedEvent(NodeEventBase): reason: PauseReason = Field(..., description="pause reason") diff --git a/api/graphon/nodes/__init__.py b/api/graphon/nodes/__init__.py new file mode 100644 index 0000000000..2d376d104d --- /dev/null +++ b/api/graphon/nodes/__init__.py @@ -0,0 +1,3 @@ +from graphon.enums import BuiltinNodeTypes + +__all__ = ["BuiltinNodeTypes"] diff --git a/api/dify_graph/nodes/answer/__init__.py b/api/graphon/nodes/answer/__init__.py similarity index 100% rename from api/dify_graph/nodes/answer/__init__.py rename to api/graphon/nodes/answer/__init__.py diff --git a/api/dify_graph/nodes/answer/answer_node.py b/api/graphon/nodes/answer/answer_node.py similarity index 83% rename from api/dify_graph/nodes/answer/answer_node.py rename to api/graphon/nodes/answer/answer_node.py index 4286e1a492..c5261a7939 100644 --- a/api/dify_graph/nodes/answer/answer_node.py +++ b/api/graphon/nodes/answer/answer_node.py @@ -1,13 +1,13 @@ from collections.abc import Mapping, Sequence from typing import Any -from dify_graph.enums import BuiltinNodeTypes, NodeExecutionType, WorkflowNodeExecutionStatus -from dify_graph.node_events import NodeRunResult -from dify_graph.nodes.answer.entities import AnswerNodeData -from dify_graph.nodes.base.node import Node -from dify_graph.nodes.base.template import Template -from dify_graph.nodes.base.variable_template_parser import VariableTemplateParser -from dify_graph.variables import ArrayFileSegment, FileSegment, Segment +from graphon.enums import BuiltinNodeTypes, NodeExecutionType, WorkflowNodeExecutionStatus +from graphon.node_events import NodeRunResult +from graphon.nodes.answer.entities import AnswerNodeData +from graphon.nodes.base.node import Node +from graphon.nodes.base.template import Template +from graphon.nodes.base.variable_template_parser import VariableTemplateParser +from graphon.variables import ArrayFileSegment, FileSegment, Segment class AnswerNode(Node[AnswerNodeData]): diff --git a/api/dify_graph/nodes/answer/entities.py b/api/graphon/nodes/answer/entities.py similarity index 93% rename from api/dify_graph/nodes/answer/entities.py rename to api/graphon/nodes/answer/entities.py index cd82df1ac4..c49f1f3895 100644 --- a/api/dify_graph/nodes/answer/entities.py +++ b/api/graphon/nodes/answer/entities.py @@ -3,8 +3,8 @@ from enum import StrEnum, auto from pydantic import BaseModel, Field -from dify_graph.entities.base_node_data import BaseNodeData -from dify_graph.enums import BuiltinNodeTypes, NodeType +from graphon.entities.base_node_data import BaseNodeData +from graphon.enums import BuiltinNodeTypes, NodeType class AnswerNodeData(BaseNodeData): diff --git a/api/dify_graph/nodes/base/__init__.py b/api/graphon/nodes/base/__init__.py similarity index 100% rename from api/dify_graph/nodes/base/__init__.py rename to api/graphon/nodes/base/__init__.py diff --git a/api/dify_graph/nodes/base/entities.py b/api/graphon/nodes/base/entities.py similarity index 96% rename from api/dify_graph/nodes/base/entities.py rename to api/graphon/nodes/base/entities.py index 4f8b2682e1..94b88c097d 100644 --- a/api/dify_graph/nodes/base/entities.py +++ b/api/graphon/nodes/base/entities.py @@ -6,7 +6,7 @@ from typing import Any from pydantic import BaseModel, field_validator -from dify_graph.entities.base_node_data import BaseNodeData +from graphon.entities.base_node_data import BaseNodeData class VariableSelector(BaseModel): diff --git a/api/dify_graph/nodes/base/node.py b/api/graphon/nodes/base/node.py similarity index 91% rename from api/dify_graph/nodes/base/node.py rename to api/graphon/nodes/base/node.py index 56b46a5894..613ff4f037 100644 --- a/api/dify_graph/nodes/base/node.py +++ b/api/graphon/nodes/base/node.py @@ -4,23 +4,23 @@ import logging import operator from abc import abstractmethod from collections.abc import Generator, Mapping, Sequence +from datetime import UTC, datetime from functools import singledispatchmethod from types import MappingProxyType -from typing import Any, ClassVar, Generic, Protocol, TypeVar, cast, get_args, get_origin +from typing import Any, ClassVar, Generic, TypeVar, cast, get_args, get_origin from uuid import uuid4 -from dify_graph.entities import GraphInitParams -from dify_graph.entities.base_node_data import BaseNodeData, RetryConfig -from dify_graph.entities.graph_config import NodeConfigDict -from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY -from dify_graph.enums import ( +from graphon.entities import GraphInitParams +from graphon.entities.base_node_data import BaseNodeData, RetryConfig +from graphon.entities.graph_config import NodeConfigDict +from graphon.enums import ( ErrorStrategy, NodeExecutionType, NodeState, NodeType, WorkflowNodeExecutionStatus, ) -from dify_graph.graph_events import ( +from graphon.graph_events import ( GraphNodeEventBase, NodeRunAgentLogEvent, NodeRunFailedEvent, @@ -39,8 +39,9 @@ from dify_graph.graph_events import ( NodeRunStartedEvent, NodeRunStreamChunkEvent, NodeRunSucceededEvent, + NodeRunVariableUpdatedEvent, ) -from dify_graph.node_events import ( +from graphon.node_events import ( AgentLogEvent, HumanInputFormFilledEvent, HumanInputFormTimeoutEvent, @@ -58,9 +59,9 @@ from dify_graph.node_events import ( RunRetrieverResourceEvent, StreamChunkEvent, StreamCompletedEvent, + VariableUpdatedEvent, ) -from dify_graph.runtime import GraphRuntimeState -from libs.datetime_utils import naive_utc_now +from graphon.runtime import GraphRuntimeState NodeDataT = TypeVar("NodeDataT", bound=BaseNodeData) _MISSING_RUN_CONTEXT_VALUE = object() @@ -68,23 +69,6 @@ _MISSING_RUN_CONTEXT_VALUE = object() logger = logging.getLogger(__name__) -class DifyRunContextProtocol(Protocol): - tenant_id: str - app_id: str - user_id: str - user_from: Any - invoke_from: Any - - -class _MappingDifyRunContext: - def __init__(self, mapping: Mapping[str, Any]) -> None: - self.tenant_id = str(mapping["tenant_id"]) - self.app_id = str(mapping["app_id"]) - self.user_id = str(mapping["user_id"]) - self.user_from = mapping["user_from"] - self.invoke_from = mapping["invoke_from"] - - class Node(Generic[NodeDataT]): """BaseNode serves as the foundational class for all node implementations. @@ -177,8 +161,9 @@ class Node(Generic[NodeDataT]): # Skip base class itself if cls is Node: return - # Only register production node implementations defined under the - # canonical workflow namespaces. + # Only treat nodes from the base graphon package as production + # registrations. Higher-layer packages may still register subclasses, + # but graphon itself should not know their module identities. # This prevents test helper subclasses from polluting the global registry and # accidentally overriding real node types (e.g., a test Answer node). module_name = getattr(cls, "__module__", "") @@ -186,7 +171,7 @@ class Node(Generic[NodeDataT]): node_type = cls.node_type version = cls.version() bucket = Node._registry.setdefault(node_type, {}) - if module_name.startswith(("dify_graph.nodes.", "core.workflow.nodes.")): + if module_name.startswith("graphon.nodes."): # Production node definitions take precedence and may override bucket[version] = cls # type: ignore[index] else: @@ -263,16 +248,25 @@ class Node(Generic[NodeDataT]): self._node_id = node_id self._node_execution_id: str = "" - self._start_at = naive_utc_now() + self._start_at = datetime.now(UTC).replace(tzinfo=None) self._node_data = self.validate_node_data(config["data"]) self.post_init() @classmethod - def validate_node_data(cls, node_data: BaseNodeData) -> NodeDataT: - """Validate shared graph node payloads against the subclass-declared NodeData model.""" - return cast(NodeDataT, cls._node_data_type.model_validate(node_data, from_attributes=True)) + def validate_node_data(cls, node_data: BaseNodeData | Mapping[str, Any]) -> NodeDataT: + """Validate shared graph node payloads against the subclass-declared NodeData model. + + Re-validate from a dumped payload instead of `from_attributes=True` so compatibility + extras stored on `BaseNodeData` survive the handoff to the concrete node data model. + Human Input delivery methods are one such extra field until graphon owns that schema. + """ + if isinstance(node_data, BaseNodeData): + payload = node_data.model_dump(mode="python") + else: + payload = dict(node_data) + return cast(NodeDataT, cls._node_data_type.model_validate(payload)) def init_node_data(self, data: BaseNodeData | Mapping[str, Any]) -> None: """Hydrate `_node_data` for legacy callers that bypass `__init__`.""" @@ -299,25 +293,6 @@ class Node(Generic[NodeDataT]): raise ValueError(f"run_context missing required key: {key}") return value - def require_dify_context(self) -> DifyRunContextProtocol: - raw_ctx = self.require_run_context_value(DIFY_RUN_CONTEXT_KEY) - if raw_ctx is None: - raise ValueError(f"run_context missing required key: {DIFY_RUN_CONTEXT_KEY}") - - if isinstance(raw_ctx, Mapping): - missing_keys = [ - key for key in ("tenant_id", "app_id", "user_id", "user_from", "invoke_from") if key not in raw_ctx - ] - if missing_keys: - raise ValueError(f"dify context missing required keys: {', '.join(missing_keys)}") - return _MappingDifyRunContext(raw_ctx) - - for attr in ("tenant_id", "app_id", "user_id", "user_from", "invoke_from"): - if not hasattr(raw_ctx, attr): - raise TypeError(f"invalid dify context object, missing attribute: {attr}") - - return cast(DifyRunContextProtocol, raw_ctx) - @property def execution_id(self) -> str: return self._node_execution_id @@ -364,7 +339,7 @@ class Node(Generic[NodeDataT]): def run(self) -> Generator[GraphNodeEventBase, None, None]: execution_id = self.ensure_execution_id() - self._start_at = naive_utc_now() + self._start_at = datetime.now(UTC).replace(tzinfo=None) # Create and push start event with required fields start_event = NodeRunStartedEvent( @@ -406,7 +381,7 @@ class Node(Generic[NodeDataT]): error=str(e), error_type="WorkflowNodeError", ) - finished_at = naive_utc_now() + finished_at = datetime.now(UTC).replace(tzinfo=None) yield NodeRunFailedEvent( id=self.execution_id, node_id=self._node_id, @@ -570,7 +545,7 @@ class Node(Generic[NodeDataT]): return self._node_data def _convert_node_run_result_to_graph_node_event(self, result: NodeRunResult) -> GraphNodeEventBase: - finished_at = naive_utc_now() + finished_at = datetime.now(UTC).replace(tzinfo=None) match result.status: case WorkflowNodeExecutionStatus.FAILED: return NodeRunFailedEvent( @@ -611,7 +586,7 @@ class Node(Generic[NodeDataT]): @_dispatch.register def _(self, event: StreamCompletedEvent) -> NodeRunSucceededEvent | NodeRunFailedEvent: - finished_at = naive_utc_now() + finished_at = datetime.now(UTC).replace(tzinfo=None) match event.node_run_result.status: case WorkflowNodeExecutionStatus.SUCCEEDED: return NodeRunSucceededEvent( @@ -637,6 +612,15 @@ class Node(Generic[NodeDataT]): f"Node {self._node_id} does not support status {event.node_run_result.status}" ) + @_dispatch.register + def _(self, event: VariableUpdatedEvent) -> NodeRunVariableUpdatedEvent: + return NodeRunVariableUpdatedEvent( + id=self.execution_id, + node_id=self._node_id, + node_type=self.node_type, + variable=event.variable, + ) + @_dispatch.register def _(self, event: PauseRequestedEvent) -> NodeRunPauseRequestedEvent: return NodeRunPauseRequestedEvent( @@ -793,16 +777,11 @@ class Node(Generic[NodeDataT]): @_dispatch.register def _(self, event: RunRetrieverResourceEvent) -> NodeRunRetrieverResourceEvent: - from core.rag.entities.citation_metadata import RetrievalSourceMetadata - - retriever_resources = [ - RetrievalSourceMetadata.model_validate(resource) for resource in event.retriever_resources - ] return NodeRunRetrieverResourceEvent( id=self.execution_id, node_id=self._node_id, node_type=self.node_type, - retriever_resources=retriever_resources, + retriever_resources=event.retriever_resources, context=event.context, node_version=self.version(), ) diff --git a/api/dify_graph/nodes/base/template.py b/api/graphon/nodes/base/template.py similarity index 98% rename from api/dify_graph/nodes/base/template.py rename to api/graphon/nodes/base/template.py index 5976e808e3..311de4a6ea 100644 --- a/api/dify_graph/nodes/base/template.py +++ b/api/graphon/nodes/base/template.py @@ -11,7 +11,7 @@ from collections.abc import Sequence from dataclasses import dataclass from typing import Any, Union -from dify_graph.nodes.base.variable_template_parser import VariableTemplateParser +from graphon.nodes.base.variable_template_parser import VariableTemplateParser @dataclass(frozen=True) diff --git a/api/dify_graph/nodes/base/usage_tracking_mixin.py b/api/graphon/nodes/base/usage_tracking_mixin.py similarity index 89% rename from api/dify_graph/nodes/base/usage_tracking_mixin.py rename to api/graphon/nodes/base/usage_tracking_mixin.py index bd49419fd3..955bfe6726 100644 --- a/api/dify_graph/nodes/base/usage_tracking_mixin.py +++ b/api/graphon/nodes/base/usage_tracking_mixin.py @@ -1,5 +1,5 @@ -from dify_graph.model_runtime.entities.llm_entities import LLMUsage -from dify_graph.runtime import GraphRuntimeState +from graphon.model_runtime.entities.llm_entities import LLMUsage +from graphon.runtime import GraphRuntimeState class LLMUsageTrackingMixin: diff --git a/api/dify_graph/nodes/base/variable_template_parser.py b/api/graphon/nodes/base/variable_template_parser.py similarity index 100% rename from api/dify_graph/nodes/base/variable_template_parser.py rename to api/graphon/nodes/base/variable_template_parser.py diff --git a/api/dify_graph/nodes/code/__init__.py b/api/graphon/nodes/code/__init__.py similarity index 100% rename from api/dify_graph/nodes/code/__init__.py rename to api/graphon/nodes/code/__init__.py diff --git a/api/dify_graph/nodes/code/code_node.py b/api/graphon/nodes/code/code_node.py similarity index 97% rename from api/dify_graph/nodes/code/code_node.py rename to api/graphon/nodes/code/code_node.py index 82d5fced62..c2eea0bec1 100644 --- a/api/dify_graph/nodes/code/code_node.py +++ b/api/graphon/nodes/code/code_node.py @@ -3,14 +3,14 @@ from decimal import Decimal from textwrap import dedent from typing import TYPE_CHECKING, Any, Protocol, cast -from dify_graph.entities.graph_config import NodeConfigDict -from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus -from dify_graph.node_events import NodeRunResult -from dify_graph.nodes.base.node import Node -from dify_graph.nodes.code.entities import CodeLanguage, CodeNodeData -from dify_graph.nodes.code.limits import CodeNodeLimits -from dify_graph.variables.segments import ArrayFileSegment -from dify_graph.variables.types import SegmentType +from graphon.entities.graph_config import NodeConfigDict +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus +from graphon.node_events import NodeRunResult +from graphon.nodes.base.node import Node +from graphon.nodes.code.entities import CodeLanguage, CodeNodeData +from graphon.nodes.code.limits import CodeNodeLimits +from graphon.variables.segments import ArrayFileSegment +from graphon.variables.types import SegmentType from .exc import ( CodeNodeError, @@ -19,8 +19,8 @@ from .exc import ( ) if TYPE_CHECKING: - from dify_graph.entities import GraphInitParams - from dify_graph.runtime import GraphRuntimeState + from graphon.entities import GraphInitParams + from graphon.runtime import GraphRuntimeState class WorkflowCodeExecutor(Protocol): diff --git a/api/dify_graph/nodes/code/entities.py b/api/graphon/nodes/code/entities.py similarity index 85% rename from api/dify_graph/nodes/code/entities.py rename to api/graphon/nodes/code/entities.py index 55b4ee4862..dc89d64495 100644 --- a/api/dify_graph/nodes/code/entities.py +++ b/api/graphon/nodes/code/entities.py @@ -3,10 +3,10 @@ from typing import Annotated, Literal from pydantic import AfterValidator, BaseModel -from dify_graph.entities.base_node_data import BaseNodeData -from dify_graph.enums import BuiltinNodeTypes, NodeType -from dify_graph.nodes.base.entities import VariableSelector -from dify_graph.variables.types import SegmentType +from graphon.entities.base_node_data import BaseNodeData +from graphon.enums import BuiltinNodeTypes, NodeType +from graphon.nodes.base.entities import VariableSelector +from graphon.variables.types import SegmentType class CodeLanguage(StrEnum): diff --git a/api/dify_graph/nodes/code/exc.py b/api/graphon/nodes/code/exc.py similarity index 100% rename from api/dify_graph/nodes/code/exc.py rename to api/graphon/nodes/code/exc.py diff --git a/api/dify_graph/nodes/code/limits.py b/api/graphon/nodes/code/limits.py similarity index 100% rename from api/dify_graph/nodes/code/limits.py rename to api/graphon/nodes/code/limits.py diff --git a/api/dify_graph/nodes/document_extractor/__init__.py b/api/graphon/nodes/document_extractor/__init__.py similarity index 100% rename from api/dify_graph/nodes/document_extractor/__init__.py rename to api/graphon/nodes/document_extractor/__init__.py diff --git a/api/dify_graph/nodes/document_extractor/entities.py b/api/graphon/nodes/document_extractor/entities.py similarity index 73% rename from api/dify_graph/nodes/document_extractor/entities.py rename to api/graphon/nodes/document_extractor/entities.py index 1110cc2710..026a0cd224 100644 --- a/api/dify_graph/nodes/document_extractor/entities.py +++ b/api/graphon/nodes/document_extractor/entities.py @@ -1,8 +1,8 @@ from collections.abc import Sequence from dataclasses import dataclass -from dify_graph.entities.base_node_data import BaseNodeData -from dify_graph.enums import BuiltinNodeTypes, NodeType +from graphon.entities.base_node_data import BaseNodeData +from graphon.enums import BuiltinNodeTypes, NodeType class DocumentExtractorNodeData(BaseNodeData): diff --git a/api/dify_graph/nodes/document_extractor/exc.py b/api/graphon/nodes/document_extractor/exc.py similarity index 100% rename from api/dify_graph/nodes/document_extractor/exc.py rename to api/graphon/nodes/document_extractor/exc.py diff --git a/api/dify_graph/nodes/document_extractor/node.py b/api/graphon/nodes/document_extractor/node.py similarity index 98% rename from api/dify_graph/nodes/document_extractor/node.py rename to api/graphon/nodes/document_extractor/node.py index 27196f1aca..be46481e7d 100644 --- a/api/dify_graph/nodes/document_extractor/node.py +++ b/api/graphon/nodes/document_extractor/node.py @@ -21,14 +21,14 @@ from docx.oxml.text.paragraph import CT_P from docx.table import Table from docx.text.paragraph import Paragraph -from dify_graph.entities.graph_config import NodeConfigDict -from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus -from dify_graph.file import File, FileTransferMethod, file_manager -from dify_graph.node_events import NodeRunResult -from dify_graph.nodes.base.node import Node -from dify_graph.nodes.protocols import HttpClientProtocol -from dify_graph.variables import ArrayFileSegment -from dify_graph.variables.segments import ArrayStringSegment, FileSegment +from graphon.entities.graph_config import NodeConfigDict +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus +from graphon.file import File, FileTransferMethod, file_manager +from graphon.node_events import NodeRunResult +from graphon.nodes.base.node import Node +from graphon.nodes.protocols import HttpClientProtocol +from graphon.variables import ArrayFileSegment +from graphon.variables.segments import ArrayStringSegment, FileSegment from .entities import DocumentExtractorNodeData, UnstructuredApiConfig from .exc import DocumentExtractorError, FileDownloadError, TextExtractionError, UnsupportedFileTypeError @@ -36,8 +36,8 @@ from .exc import DocumentExtractorError, FileDownloadError, TextExtractionError, logger = logging.getLogger(__name__) if TYPE_CHECKING: - from dify_graph.entities import GraphInitParams - from dify_graph.runtime import GraphRuntimeState + from graphon.entities import GraphInitParams + from graphon.runtime import GraphRuntimeState class DocumentExtractorNode(Node[DocumentExtractorNodeData]): diff --git a/api/dify_graph/nodes/end/__init__.py b/api/graphon/nodes/end/__init__.py similarity index 100% rename from api/dify_graph/nodes/end/__init__.py rename to api/graphon/nodes/end/__init__.py diff --git a/api/dify_graph/nodes/end/end_node.py b/api/graphon/nodes/end/end_node.py similarity index 82% rename from api/dify_graph/nodes/end/end_node.py rename to api/graphon/nodes/end/end_node.py index 1f5cfab22b..11b9e58644 100644 --- a/api/dify_graph/nodes/end/end_node.py +++ b/api/graphon/nodes/end/end_node.py @@ -1,8 +1,8 @@ -from dify_graph.enums import BuiltinNodeTypes, NodeExecutionType, WorkflowNodeExecutionStatus -from dify_graph.node_events import NodeRunResult -from dify_graph.nodes.base.node import Node -from dify_graph.nodes.base.template import Template -from dify_graph.nodes.end.entities import EndNodeData +from graphon.enums import BuiltinNodeTypes, NodeExecutionType, WorkflowNodeExecutionStatus +from graphon.node_events import NodeRunResult +from graphon.nodes.base.node import Node +from graphon.nodes.base.template import Template +from graphon.nodes.end.entities import EndNodeData class EndNode(Node[EndNodeData]): diff --git a/api/dify_graph/nodes/end/entities.py b/api/graphon/nodes/end/entities.py similarity index 76% rename from api/dify_graph/nodes/end/entities.py rename to api/graphon/nodes/end/entities.py index be7f0c8de8..839aed7e4b 100644 --- a/api/dify_graph/nodes/end/entities.py +++ b/api/graphon/nodes/end/entities.py @@ -1,8 +1,8 @@ from pydantic import BaseModel, Field -from dify_graph.entities.base_node_data import BaseNodeData -from dify_graph.enums import BuiltinNodeTypes, NodeType -from dify_graph.nodes.base.entities import OutputVariableEntity +from graphon.entities.base_node_data import BaseNodeData +from graphon.enums import BuiltinNodeTypes, NodeType +from graphon.nodes.base.entities import OutputVariableEntity class EndNodeData(BaseNodeData): diff --git a/api/dify_graph/nodes/http_request/__init__.py b/api/graphon/nodes/http_request/__init__.py similarity index 100% rename from api/dify_graph/nodes/http_request/__init__.py rename to api/graphon/nodes/http_request/__init__.py diff --git a/api/dify_graph/nodes/http_request/config.py b/api/graphon/nodes/http_request/config.py similarity index 100% rename from api/dify_graph/nodes/http_request/config.py rename to api/graphon/nodes/http_request/config.py diff --git a/api/dify_graph/nodes/http_request/entities.py b/api/graphon/nodes/http_request/entities.py similarity index 98% rename from api/dify_graph/nodes/http_request/entities.py rename to api/graphon/nodes/http_request/entities.py index f594d58ae6..6fa067bdd1 100644 --- a/api/dify_graph/nodes/http_request/entities.py +++ b/api/graphon/nodes/http_request/entities.py @@ -8,8 +8,8 @@ import charset_normalizer import httpx from pydantic import BaseModel, Field, ValidationInfo, field_validator -from dify_graph.entities.base_node_data import BaseNodeData -from dify_graph.enums import BuiltinNodeTypes, NodeType +from graphon.entities.base_node_data import BaseNodeData +from graphon.enums import BuiltinNodeTypes, NodeType HTTP_REQUEST_CONFIG_FILTER_KEY = "http_request_config" diff --git a/api/dify_graph/nodes/http_request/exc.py b/api/graphon/nodes/http_request/exc.py similarity index 100% rename from api/dify_graph/nodes/http_request/exc.py rename to api/graphon/nodes/http_request/exc.py diff --git a/api/dify_graph/nodes/http_request/executor.py b/api/graphon/nodes/http_request/executor.py similarity index 98% rename from api/dify_graph/nodes/http_request/executor.py rename to api/graphon/nodes/http_request/executor.py index 892b0fc688..0c6f4ecd3a 100644 --- a/api/dify_graph/nodes/http_request/executor.py +++ b/api/graphon/nodes/http_request/executor.py @@ -10,9 +10,9 @@ from urllib.parse import urlencode, urlparse import httpx from json_repair import repair_json -from dify_graph.file.enums import FileTransferMethod -from dify_graph.runtime import VariablePool -from dify_graph.variables.segments import ArrayFileSegment, FileSegment +from graphon.file.enums import FileTransferMethod +from graphon.runtime import VariablePool +from graphon.variables.segments import ArrayFileSegment, FileSegment from ..protocols import FileManagerProtocol, HttpClientProtocol from .entities import ( @@ -246,7 +246,7 @@ class Executor: files: dict[str, list[tuple[str | None, bytes, str]]] = {} for key, files_in_segment in files_list: for file in files_in_segment: - if file.related_id is not None or ( + if file.reference is not None or ( file.transfer_method == FileTransferMethod.REMOTE_URL and file.remote_url is not None ): file_tuple = ( diff --git a/api/dify_graph/nodes/http_request/node.py b/api/graphon/nodes/http_request/node.py similarity index 89% rename from api/dify_graph/nodes/http_request/node.py rename to api/graphon/nodes/http_request/node.py index 3e5253d809..3d74347a7f 100644 --- a/api/dify_graph/nodes/http_request/node.py +++ b/api/graphon/nodes/http_request/node.py @@ -3,17 +3,21 @@ import mimetypes from collections.abc import Callable, Mapping, Sequence from typing import TYPE_CHECKING, Any -from dify_graph.entities.graph_config import NodeConfigDict -from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus -from dify_graph.file import File, FileTransferMethod -from dify_graph.node_events import NodeRunResult -from dify_graph.nodes.base import variable_template_parser -from dify_graph.nodes.base.entities import VariableSelector -from dify_graph.nodes.base.node import Node -from dify_graph.nodes.http_request.executor import Executor -from dify_graph.nodes.protocols import FileManagerProtocol, HttpClientProtocol, ToolFileManagerProtocol -from dify_graph.variables.segments import ArrayFileSegment -from factories import file_factory +from graphon.entities.graph_config import NodeConfigDict +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus +from graphon.file import File, FileTransferMethod +from graphon.node_events import NodeRunResult +from graphon.nodes.base import variable_template_parser +from graphon.nodes.base.entities import VariableSelector +from graphon.nodes.base.node import Node +from graphon.nodes.http_request.executor import Executor +from graphon.nodes.protocols import ( + FileManagerProtocol, + FileReferenceFactoryProtocol, + HttpClientProtocol, + ToolFileManagerProtocol, +) +from graphon.variables.segments import ArrayFileSegment from .config import build_http_request_config, resolve_http_request_config from .entities import ( @@ -28,8 +32,8 @@ from .exc import HttpRequestNodeError, RequestBodyError logger = logging.getLogger(__name__) if TYPE_CHECKING: - from dify_graph.entities import GraphInitParams - from dify_graph.runtime import GraphRuntimeState + from graphon.entities import GraphInitParams + from graphon.runtime import GraphRuntimeState class HttpRequestNode(Node[HttpRequestNodeData]): @@ -46,6 +50,7 @@ class HttpRequestNode(Node[HttpRequestNodeData]): http_client: HttpClientProtocol, tool_file_manager_factory: Callable[[], ToolFileManagerProtocol], file_manager: FileManagerProtocol, + file_reference_factory: FileReferenceFactoryProtocol, ) -> None: super().__init__( id=id, @@ -58,6 +63,7 @@ class HttpRequestNode(Node[HttpRequestNodeData]): self._http_client = http_client self._tool_file_manager_factory = tool_file_manager_factory self._file_manager = file_manager + self._file_reference_factory = file_reference_factory @classmethod def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]: @@ -212,7 +218,6 @@ class HttpRequestNode(Node[HttpRequestNodeData]): """ Extract files from response by checking both Content-Type header and URL """ - dify_ctx = self.require_dify_context() files: list[File] = [] is_file = response.is_file content_type = response.content_type @@ -237,20 +242,15 @@ class HttpRequestNode(Node[HttpRequestNodeData]): tool_file_manager = self._tool_file_manager_factory() tool_file = tool_file_manager.create_file_by_raw( - user_id=dify_ctx.user_id, - tenant_id=dify_ctx.tenant_id, - conversation_id=None, file_binary=content, mimetype=mime_type, ) - mapping = { - "tool_file_id": tool_file.id, - "transfer_method": FileTransferMethod.TOOL_FILE, - } - file = file_factory.build_from_mapping( - mapping=mapping, - tenant_id=dify_ctx.tenant_id, + file = self._file_reference_factory.build_from_mapping( + mapping={ + "tool_file_id": tool_file.id, + "transfer_method": FileTransferMethod.TOOL_FILE, + } ) files.append(file) diff --git a/api/dify_graph/nodes/human_input/__init__.py b/api/graphon/nodes/human_input/__init__.py similarity index 100% rename from api/dify_graph/nodes/human_input/__init__.py rename to api/graphon/nodes/human_input/__init__.py diff --git a/api/graphon/nodes/human_input/entities.py b/api/graphon/nodes/human_input/entities.py new file mode 100644 index 0000000000..aa01bde145 --- /dev/null +++ b/api/graphon/nodes/human_input/entities.py @@ -0,0 +1,208 @@ +"""Human Input node entities. + +The graph package owns the workflow-facing form schema and keeps it transportable +across runtimes. Dify-specific delivery surface and recipient translation stay +outside `graphon`. +""" + +import re +from collections.abc import Mapping, Sequence +from datetime import datetime, timedelta +from typing import Any, Self + +from pydantic import BaseModel, Field, field_validator, model_validator + +from graphon.entities.base_node_data import BaseNodeData +from graphon.enums import BuiltinNodeTypes, NodeType +from graphon.nodes.base.variable_template_parser import VariableTemplateParser +from graphon.variables.consts import SELECTORS_LENGTH + +from .enums import ButtonStyle, FormInputType, PlaceholderType, TimeoutUnit + +_OUTPUT_VARIABLE_PATTERN = re.compile(r"\{\{#\$output\.(?P[a-zA-Z_][a-zA-Z0-9_]{0,29})#\}\}") + + +class FormInputDefault(BaseModel): + """Default configuration for form inputs.""" + + # NOTE: Ideally, a discriminated union would be used to model + # FormInputDefault. However, the UI requires preserving the previous + # value when switching between `VARIABLE` and `CONSTANT` types. This + # necessitates retaining all fields, making a discriminated union unsuitable. + + type: PlaceholderType + + # The selector of default variable, used when `type` is `VARIABLE`. + selector: Sequence[str] = Field(default_factory=tuple) # + + # The value of the default, used when `type` is `CONSTANT`. + # TODO: How should we express JSON values? + value: str = "" + + @model_validator(mode="after") + def _validate_selector(self) -> Self: + if self.type == PlaceholderType.CONSTANT: + return self + if len(self.selector) < SELECTORS_LENGTH: + raise ValueError(f"the length of selector should be at least {SELECTORS_LENGTH}, selector={self.selector}") + return self + + +class FormInput(BaseModel): + """Form input definition.""" + + type: FormInputType + output_variable_name: str + default: FormInputDefault | None = None + + +_IDENTIFIER_PATTERN = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$") + + +class UserAction(BaseModel): + """User action configuration.""" + + # id is the identifier for this action. + # It also serves as the identifiers of output handle. + # + # The id must be a valid identifier (satisfy the _IDENTIFIER_PATTERN above.) + id: str = Field(max_length=20) + title: str = Field(max_length=20) + button_style: ButtonStyle = ButtonStyle.DEFAULT + + @field_validator("id") + @classmethod + def _validate_id(cls, value: str) -> str: + if not _IDENTIFIER_PATTERN.match(value): + raise ValueError( + f"'{value}' is not a valid identifier. It must start with a letter or underscore, " + f"and contain only letters, numbers, or underscores." + ) + return value + + +class HumanInputNodeData(BaseNodeData): + """Human Input node data.""" + + type: NodeType = BuiltinNodeTypes.HUMAN_INPUT + form_content: str = "" + inputs: list[FormInput] = Field(default_factory=list) + user_actions: list[UserAction] = Field(default_factory=list) + timeout: int = 36 + timeout_unit: TimeoutUnit = TimeoutUnit.HOUR + + @field_validator("inputs") + @classmethod + def _validate_inputs(cls, inputs: list[FormInput]) -> list[FormInput]: + seen_names: set[str] = set() + for form_input in inputs: + name = form_input.output_variable_name + if name in seen_names: + raise ValueError(f"duplicated output_variable_name '{name}' in inputs") + seen_names.add(name) + return inputs + + @field_validator("user_actions") + @classmethod + def _validate_user_actions(cls, user_actions: list[UserAction]) -> list[UserAction]: + seen_ids: set[str] = set() + for action in user_actions: + action_id = action.id + if action_id in seen_ids: + raise ValueError(f"duplicated user action id '{action_id}'") + seen_ids.add(action_id) + return user_actions + + def expiration_time(self, start_time: datetime) -> datetime: + if self.timeout_unit == TimeoutUnit.HOUR: + return start_time + timedelta(hours=self.timeout) + elif self.timeout_unit == TimeoutUnit.DAY: + return start_time + timedelta(days=self.timeout) + else: + raise AssertionError("unknown timeout unit.") + + def outputs_field_names(self) -> Sequence[str]: + field_names = [] + for match in _OUTPUT_VARIABLE_PATTERN.finditer(self.form_content): + field_names.append(match.group("field_name")) + return field_names + + def extract_variable_selector_to_variable_mapping(self, node_id: str) -> Mapping[str, Sequence[str]]: + variable_mappings: dict[str, Sequence[str]] = {} + + def _add_variable_selectors(selectors: Sequence[Sequence[str]]) -> None: + for selector in selectors: + if len(selector) < SELECTORS_LENGTH: + continue + qualified_variable_mapping_key = f"{node_id}.#{'.'.join(selector[:SELECTORS_LENGTH])}#" + variable_mappings[qualified_variable_mapping_key] = list(selector[:SELECTORS_LENGTH]) + + form_template_parser = VariableTemplateParser(template=self.form_content) + _add_variable_selectors( + [selector.value_selector for selector in form_template_parser.extract_variable_selectors()] + ) + + for input in self.inputs: + default_value = input.default + if default_value is None: + continue + if default_value.type == PlaceholderType.CONSTANT: + continue + default_value_key = ".".join(default_value.selector) + qualified_variable_mapping_key = f"{node_id}.#{default_value_key}#" + variable_mappings[qualified_variable_mapping_key] = default_value.selector + + return variable_mappings + + def find_action_text(self, action_id: str) -> str: + """ + Resolve action display text by id. + """ + for action in self.user_actions: + if action.id == action_id: + return action.title + return action_id + + +class FormDefinition(BaseModel): + form_content: str + inputs: list[FormInput] = Field(default_factory=list) + user_actions: list[UserAction] = Field(default_factory=list) + rendered_content: str + expiration_time: datetime + + # this is used to store the resolved default values + default_values: dict[str, Any] = Field(default_factory=dict) + + # node_title records the title of the HumanInput node. + node_title: str | None = None + + # display_in_ui controls whether the form should be displayed in UI surfaces. + display_in_ui: bool | None = None + + +class HumanInputSubmissionValidationError(ValueError): + pass + + +def validate_human_input_submission( + *, + inputs: Sequence[FormInput], + user_actions: Sequence[UserAction], + selected_action_id: str, + form_data: Mapping[str, Any], +) -> None: + available_actions = {action.id for action in user_actions} + if selected_action_id not in available_actions: + raise HumanInputSubmissionValidationError(f"Invalid action: {selected_action_id}") + + provided_inputs = set(form_data.keys()) + missing_inputs = [ + form_input.output_variable_name + for form_input in inputs + if form_input.output_variable_name not in provided_inputs + ] + + if missing_inputs: + missing_list = ", ".join(missing_inputs) + raise HumanInputSubmissionValidationError(f"Missing required inputs: {missing_list}") diff --git a/api/dify_graph/nodes/human_input/enums.py b/api/graphon/nodes/human_input/enums.py similarity index 76% rename from api/dify_graph/nodes/human_input/enums.py rename to api/graphon/nodes/human_input/enums.py index da85728828..3fb0ab4499 100644 --- a/api/dify_graph/nodes/human_input/enums.py +++ b/api/graphon/nodes/human_input/enums.py @@ -25,16 +25,6 @@ class HumanInputFormKind(enum.StrEnum): DELIVERY_TEST = enum.auto() # Form created for delivery tests. -class DeliveryMethodType(enum.StrEnum): - """Delivery method types for human input forms.""" - - # WEBAPP controls whether the form is delivered to the web app. It not only controls - # the standalone web app, but also controls the installed apps in the console. - WEBAPP = enum.auto() - - EMAIL = enum.auto() - - class ButtonStyle(enum.StrEnum): """Button styles for user actions.""" @@ -63,10 +53,3 @@ class PlaceholderType(enum.StrEnum): VARIABLE = enum.auto() CONSTANT = enum.auto() - - -class EmailRecipientType(enum.StrEnum): - """Email recipient types.""" - - MEMBER = enum.auto() - EXTERNAL = enum.auto() diff --git a/api/dify_graph/nodes/human_input/human_input_node.py b/api/graphon/nodes/human_input/human_input_node.py similarity index 65% rename from api/dify_graph/nodes/human_input/human_input_node.py rename to api/graphon/nodes/human_input/human_input_node.py index 794e33d92e..fe04022877 100644 --- a/api/dify_graph/nodes/human_input/human_input_node.py +++ b/api/graphon/nodes/human_input/human_input_node.py @@ -1,39 +1,33 @@ import json import logging from collections.abc import Generator, Mapping, Sequence -from typing import TYPE_CHECKING, Any +from datetime import UTC, datetime +from typing import TYPE_CHECKING, Any, cast -from dify_graph.entities.graph_config import NodeConfigDict -from dify_graph.entities.pause_reason import HumanInputRequired -from dify_graph.enums import BuiltinNodeTypes, NodeExecutionType, WorkflowNodeExecutionStatus -from dify_graph.node_events import ( +from graphon.entities.graph_config import NodeConfigDict +from graphon.entities.pause_reason import HumanInputRequired +from graphon.enums import BuiltinNodeTypes, NodeExecutionType, WorkflowNodeExecutionStatus +from graphon.node_events import ( HumanInputFormFilledEvent, HumanInputFormTimeoutEvent, NodeRunResult, PauseRequestedEvent, ) -from dify_graph.node_events.base import NodeEventBase -from dify_graph.node_events.node import StreamCompletedEvent -from dify_graph.nodes.base.node import Node -from dify_graph.repositories.human_input_form_repository import ( - FormCreateParams, - HumanInputFormEntity, - HumanInputFormRepository, -) -from dify_graph.workflow_type_encoder import WorkflowRuntimeTypeConverter -from libs.datetime_utils import naive_utc_now +from graphon.node_events.base import NodeEventBase +from graphon.node_events.node import StreamCompletedEvent +from graphon.nodes.base.node import Node +from graphon.nodes.runtime import HumanInputFormStateProtocol, HumanInputNodeRuntimeProtocol +from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter -from .entities import DeliveryChannelConfig, HumanInputNodeData, apply_debug_email_recipient -from .enums import DeliveryMethodType, HumanInputFormStatus, PlaceholderType +from .entities import HumanInputNodeData +from .enums import HumanInputFormStatus, PlaceholderType if TYPE_CHECKING: - from dify_graph.entities.graph_init_params import GraphInitParams - from dify_graph.runtime.graph_runtime_state import GraphRuntimeState + from graphon.entities.graph_init_params import GraphInitParams + from graphon.runtime.graph_runtime_state import GraphRuntimeState _SELECTED_BRANCH_KEY = "selected_branch" -_INVOKE_FROM_DEBUGGER = "debugger" -_INVOKE_FROM_EXPLORE = "explore" logger = logging.getLogger(__name__) @@ -56,7 +50,6 @@ class HumanInputNode(Node[HumanInputNodeData]): ) _node_data: HumanInputNodeData - _form_repository: HumanInputFormRepository _OUTPUT_FIELD_ACTION_ID = "__action_id" _OUTPUT_FIELD_RENDERED_CONTENT = "__rendered_content" _TIMEOUT_HANDLE = _TIMEOUT_ACTION_ID = "__timeout" @@ -67,7 +60,8 @@ class HumanInputNode(Node[HumanInputNodeData]): config: NodeConfigDict, graph_init_params: "GraphInitParams", graph_runtime_state: "GraphRuntimeState", - form_repository: HumanInputFormRepository, + runtime: HumanInputNodeRuntimeProtocol | None = None, + form_repository: object | None = None, ) -> None: super().__init__( id=id, @@ -75,7 +69,14 @@ class HumanInputNode(Node[HumanInputNodeData]): graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, ) - self._form_repository = form_repository + resolved_runtime = runtime + if resolved_runtime is None: + raise ValueError("runtime is required") + if form_repository is not None: + with_form_repository = getattr(resolved_runtime, "with_form_repository", None) + if callable(with_form_repository): + resolved_runtime = cast(HumanInputNodeRuntimeProtocol, with_form_repository(form_repository)) + self._runtime: HumanInputNodeRuntimeProtocol = resolved_runtime @classmethod def version(cls) -> str: @@ -128,13 +129,7 @@ class HumanInputNode(Node[HumanInputNodeData]): return None - @property - def _workflow_execution_id(self) -> str: - workflow_exec_id = self.graph_runtime_state.variable_pool.system_variables.workflow_execution_id - assert workflow_exec_id is not None - return workflow_exec_id - - def _form_to_pause_event(self, form_entity: HumanInputFormEntity): + def _form_to_pause_event(self, form_entity: HumanInputFormStateProtocol): required_event = self._human_input_required_event(form_entity) pause_requested_event = PauseRequestedEvent(reason=required_event) return pause_requested_event @@ -157,56 +152,16 @@ class HumanInputNode(Node[HumanInputNodeData]): return resolved_defaults - def _should_require_console_recipient(self) -> bool: - invoke_from = self._invoke_from_value() - if invoke_from == _INVOKE_FROM_DEBUGGER: - return True - if invoke_from == _INVOKE_FROM_EXPLORE: - return self._node_data.is_webapp_enabled() - return False - - def _display_in_ui(self) -> bool: - if self._invoke_from_value() == _INVOKE_FROM_DEBUGGER: - return True - return self._node_data.is_webapp_enabled() - - def _effective_delivery_methods(self) -> Sequence[DeliveryChannelConfig]: - dify_ctx = self.require_dify_context() - invoke_from = self._invoke_from_value() - enabled_methods = [method for method in self._node_data.delivery_methods if method.enabled] - if invoke_from in {_INVOKE_FROM_DEBUGGER, _INVOKE_FROM_EXPLORE}: - enabled_methods = [method for method in enabled_methods if method.type != DeliveryMethodType.WEBAPP] - return [ - apply_debug_email_recipient( - method, - enabled=invoke_from == _INVOKE_FROM_DEBUGGER, - user_id=dify_ctx.user_id, - ) - for method in enabled_methods - ] - - def _invoke_from_value(self) -> str: - invoke_from = self.require_dify_context().invoke_from - if isinstance(invoke_from, str): - return invoke_from - return str(getattr(invoke_from, "value", invoke_from)) - - def _human_input_required_event(self, form_entity: HumanInputFormEntity) -> HumanInputRequired: + def _human_input_required_event(self, form_entity: HumanInputFormStateProtocol) -> HumanInputRequired: node_data = self._node_data resolved_default_values = self.resolve_default_values() - display_in_ui = self._display_in_ui() - form_token = form_entity.web_app_token - if display_in_ui and form_token is None: - raise AssertionError("Form token should be available for UI execution.") return HumanInputRequired( form_id=form_entity.id, form_content=form_entity.rendered_content, inputs=node_data.inputs, actions=node_data.user_actions, - display_in_ui=display_in_ui, node_id=self.id, node_title=node_data.title, - form_token=form_token, resolved_default_values=resolved_default_values, ) @@ -217,49 +172,32 @@ class HumanInputNode(Node[HumanInputNodeData]): This method will: 1. Generate a unique form ID 2. Create form content with variable substitution - 3. Create form in database + 3. Persist the form through the configured repository 4. Send form via configured delivery methods 5. Suspend workflow execution 6. Wait for form submission to resume """ - repo = self._form_repository - form = repo.get_form(self._workflow_execution_id, self.id) - dify_ctx = self.require_dify_context() + form = self._runtime.get_form(node_id=self.id) if form is None: - display_in_ui = self._display_in_ui() - params = FormCreateParams( - app_id=dify_ctx.app_id, - workflow_execution_id=self._workflow_execution_id, + form_entity = self._runtime.create_form( node_id=self.id, - form_config=self._node_data, + node_data=self._node_data, rendered_content=self.render_form_content_before_submission(), - delivery_methods=self._effective_delivery_methods(), - display_in_ui=display_in_ui, resolved_default_values=self.resolve_default_values(), - console_recipient_required=self._should_require_console_recipient(), - console_creator_account_id=( - dify_ctx.user_id - if self._invoke_from_value() in {_INVOKE_FROM_DEBUGGER, _INVOKE_FROM_EXPLORE} - else None - ), - backstage_recipient_required=True, ) - form_entity = self._form_repository.create_form(params) - # Create human input required event logger.info( - "Human Input node suspended workflow for form. workflow_run_id=%s, node_id=%s, form_id=%s", - self.graph_runtime_state.variable_pool.system_variables.workflow_execution_id, + "Human Input node suspended workflow for form. node_id=%s, form_id=%s", self.id, form_entity.id, ) yield self._form_to_pause_event(form_entity) return - if ( - form.status in {HumanInputFormStatus.TIMEOUT, HumanInputFormStatus.EXPIRED} - or form.expiration_time <= naive_utc_now() - ): + if form.status in { + HumanInputFormStatus.TIMEOUT, + HumanInputFormStatus.EXPIRED, + } or form.expiration_time <= datetime.now(UTC).replace(tzinfo=None): yield HumanInputFormTimeoutEvent( node_title=self._node_data.title, expiration_time=form.expiration_time, diff --git a/api/dify_graph/nodes/if_else/__init__.py b/api/graphon/nodes/if_else/__init__.py similarity index 100% rename from api/dify_graph/nodes/if_else/__init__.py rename to api/graphon/nodes/if_else/__init__.py diff --git a/api/dify_graph/nodes/if_else/entities.py b/api/graphon/nodes/if_else/entities.py similarity index 77% rename from api/dify_graph/nodes/if_else/entities.py rename to api/graphon/nodes/if_else/entities.py index ff09f3c023..d59b782747 100644 --- a/api/dify_graph/nodes/if_else/entities.py +++ b/api/graphon/nodes/if_else/entities.py @@ -2,9 +2,9 @@ from typing import Literal from pydantic import BaseModel, Field -from dify_graph.entities.base_node_data import BaseNodeData -from dify_graph.enums import BuiltinNodeTypes, NodeType -from dify_graph.utils.condition.entities import Condition +from graphon.entities.base_node_data import BaseNodeData +from graphon.enums import BuiltinNodeTypes, NodeType +from graphon.utils.condition.entities import Condition class IfElseNodeData(BaseNodeData): diff --git a/api/dify_graph/nodes/if_else/if_else_node.py b/api/graphon/nodes/if_else/if_else_node.py similarity index 87% rename from api/dify_graph/nodes/if_else/if_else_node.py rename to api/graphon/nodes/if_else/if_else_node.py index 7c0370e48c..81e934971a 100644 --- a/api/dify_graph/nodes/if_else/if_else_node.py +++ b/api/graphon/nodes/if_else/if_else_node.py @@ -3,13 +3,13 @@ from typing import Any, Literal from typing_extensions import deprecated -from dify_graph.enums import BuiltinNodeTypes, NodeExecutionType, WorkflowNodeExecutionStatus -from dify_graph.node_events import NodeRunResult -from dify_graph.nodes.base.node import Node -from dify_graph.nodes.if_else.entities import IfElseNodeData -from dify_graph.runtime import VariablePool -from dify_graph.utils.condition.entities import Condition -from dify_graph.utils.condition.processor import ConditionProcessor +from graphon.enums import BuiltinNodeTypes, NodeExecutionType, WorkflowNodeExecutionStatus +from graphon.node_events import NodeRunResult +from graphon.nodes.base.node import Node +from graphon.nodes.if_else.entities import IfElseNodeData +from graphon.runtime import VariablePool +from graphon.utils.condition.entities import Condition +from graphon.utils.condition.processor import ConditionProcessor class IfElseNode(Node[IfElseNodeData]): @@ -57,8 +57,8 @@ class IfElseNode(Node[IfElseNodeData]): break else: - # TODO: Update database then remove this - # Fallback to old structure if cases are not defined + # TODO: Remove this once all graph definitions use the `cases` structure. + # Fallback to the legacy node shape when `cases` are not defined. input_conditions, group_result, final_result = _should_not_use_old_function( # pyright: ignore [reportDeprecated] condition_processor=condition_processor, variable_pool=self.graph_runtime_state.variable_pool, diff --git a/api/dify_graph/nodes/iteration/__init__.py b/api/graphon/nodes/iteration/__init__.py similarity index 100% rename from api/dify_graph/nodes/iteration/__init__.py rename to api/graphon/nodes/iteration/__init__.py diff --git a/api/dify_graph/nodes/iteration/entities.py b/api/graphon/nodes/iteration/entities.py similarity index 89% rename from api/dify_graph/nodes/iteration/entities.py rename to api/graphon/nodes/iteration/entities.py index 58fd112b12..30b6e4bea8 100644 --- a/api/dify_graph/nodes/iteration/entities.py +++ b/api/graphon/nodes/iteration/entities.py @@ -3,9 +3,9 @@ from typing import Any from pydantic import Field -from dify_graph.entities.base_node_data import BaseNodeData -from dify_graph.enums import BuiltinNodeTypes, NodeType -from dify_graph.nodes.base import BaseIterationNodeData, BaseIterationState +from graphon.entities.base_node_data import BaseNodeData +from graphon.enums import BuiltinNodeTypes, NodeType +from graphon.nodes.base import BaseIterationNodeData, BaseIterationState class ErrorHandleMode(StrEnum): diff --git a/api/dify_graph/nodes/iteration/exc.py b/api/graphon/nodes/iteration/exc.py similarity index 82% rename from api/dify_graph/nodes/iteration/exc.py rename to api/graphon/nodes/iteration/exc.py index d9947e09bc..7b6af61b9d 100644 --- a/api/dify_graph/nodes/iteration/exc.py +++ b/api/graphon/nodes/iteration/exc.py @@ -20,3 +20,7 @@ class IterationGraphNotFoundError(IterationNodeError): class IterationIndexNotFoundError(IterationNodeError): """Raised when the iteration index is not found.""" + + +class ChildGraphAbortedError(IterationNodeError): + """Raised when a child graph aborts and the container must stop immediately.""" diff --git a/api/dify_graph/nodes/iteration/iteration_node.py b/api/graphon/nodes/iteration/iteration_node.py similarity index 74% rename from api/dify_graph/nodes/iteration/iteration_node.py rename to api/graphon/nodes/iteration/iteration_node.py index 033ec8672f..c013739653 100644 --- a/api/dify_graph/nodes/iteration/iteration_node.py +++ b/api/graphon/nodes/iteration/iteration_node.py @@ -1,27 +1,29 @@ import logging from collections.abc import Generator, Mapping, Sequence from concurrent.futures import Future, ThreadPoolExecutor, as_completed +from contextlib import suppress from datetime import UTC, datetime +from threading import Lock from typing import TYPE_CHECKING, Any, NewType, cast from typing_extensions import TypeIs -from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID -from dify_graph.entities.graph_config import NodeConfigDictAdapter -from dify_graph.enums import ( +from graphon.entities.graph_config import NodeConfigDictAdapter +from graphon.enums import ( BuiltinNodeTypes, NodeExecutionType, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus, ) -from dify_graph.graph_events import ( +from graphon.graph_events import ( GraphNodeEventBase, + GraphRunAbortedEvent, GraphRunFailedEvent, GraphRunPartialSucceededEvent, GraphRunSucceededEvent, ) -from dify_graph.model_runtime.entities.llm_entities import LLMUsage -from dify_graph.node_events import ( +from graphon.model_runtime.entities.llm_entities import LLMUsage +from graphon.node_events import ( IterationFailedEvent, IterationNextEvent, IterationStartedEvent, @@ -30,16 +32,15 @@ from dify_graph.node_events import ( NodeRunResult, StreamCompletedEvent, ) -from dify_graph.nodes.base import LLMUsageTrackingMixin -from dify_graph.nodes.base.node import Node -from dify_graph.nodes.iteration.entities import ErrorHandleMode, IterationNodeData -from dify_graph.runtime import VariablePool -from dify_graph.variables import IntegerVariable, NoneSegment -from dify_graph.variables.segments import ArrayAnySegment, ArraySegment -from dify_graph.variables.variables import Variable -from libs.datetime_utils import naive_utc_now +from graphon.nodes.base import LLMUsageTrackingMixin +from graphon.nodes.base.node import Node +from graphon.nodes.iteration.entities import ErrorHandleMode, IterationNodeData +from graphon.runtime import VariablePool +from graphon.variables import IntegerVariable, NoneSegment +from graphon.variables.segments import ArrayAnySegment, ArraySegment from .exc import ( + ChildGraphAbortedError, InvalidIteratorValueError, IterationGraphNotFoundError, IterationIndexNotFoundError, @@ -49,10 +50,10 @@ from .exc import ( ) if TYPE_CHECKING: - from dify_graph.context import IExecutionContext - from dify_graph.graph_engine import GraphEngine + from graphon.graph_engine import GraphEngine logger = logging.getLogger(__name__) +_DEFAULT_CHILD_ABORT_REASON = "child graph aborted" EmptyArraySegment = NewType("EmptyArraySegment", ArraySegment) @@ -93,7 +94,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]): self._validate_start_node() - started_at = naive_utc_now() + started_at = datetime.now(UTC).replace(tzinfo=None) iter_run_map: dict[str, float] = {} outputs: list[object] = [] usage_accumulator = [LLMUsage.empty_usage()] @@ -199,23 +200,14 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]): graph_engine = self._create_graph_engine(index, item) # Run the iteration - yield from self._run_single_iter( - variable_pool=graph_engine.graph_runtime_state.variable_pool, - outputs=outputs, - graph_engine=graph_engine, - ) - - # Sync conversation variables after each iteration completes - self._sync_conversation_variables_from_snapshot( - self._extract_conversation_variable_snapshot( - variable_pool=graph_engine.graph_runtime_state.variable_pool + try: + yield from self._run_single_iter( + variable_pool=graph_engine.graph_runtime_state.variable_pool, + outputs=outputs, + graph_engine=graph_engine, ) - ) - - # Accumulate usage from this iteration - usage_accumulator[0] = self._merge_usage( - usage_accumulator[0], graph_engine.graph_runtime_state.llm_usage - ) + finally: + self._merge_graph_engine_usage(usage_accumulator=usage_accumulator, graph_engine=graph_engine) iter_run_map[str(index)] = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds() def _execute_parallel_iterations( @@ -233,13 +225,15 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]): with ThreadPoolExecutor(max_workers=max_workers) as executor: # Submit all iteration tasks + started_child_engines: dict[int, GraphEngine] = {} + started_child_engines_lock = Lock() + merged_usage_indexes: set[int] = set() future_to_index: dict[ Future[ tuple[ float, list[GraphNodeEventBase], object | None, - dict[str, Variable], LLMUsage, ] ], @@ -248,10 +242,11 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]): for index, item in enumerate(iterator_list_value): yield IterationNextEvent(index=index) future = executor.submit( - self._execute_single_iteration_parallel, + self._execute_tracked_iteration_parallel, index=index, item=item, - execution_context=self._capture_execution_context(), + started_child_engines=started_child_engines, + started_child_engines_lock=started_child_engines_lock, ) future_to_index[future] = index @@ -264,7 +259,6 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]): iteration_duration, events, output_value, - conversation_snapshot, iteration_usage, ) = result @@ -279,11 +273,31 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]): iter_run_map[str(index)] = iteration_duration usage_accumulator[0] = self._merge_usage(usage_accumulator[0], iteration_usage) - - # Sync conversation variables after iteration completion - self._sync_conversation_variables_from_snapshot(conversation_snapshot) + merged_usage_indexes.add(index) except Exception as e: + if index not in merged_usage_indexes: + self._merge_graph_engine_usage( + usage_accumulator=usage_accumulator, + graph_engine=started_child_engines.get(index), + ) + merged_usage_indexes.add(index) + if isinstance(e, ChildGraphAbortedError): + self._abort_parallel_siblings( + future_to_index=future_to_index, + current_future=future, + started_child_engines=started_child_engines, + reason=str(e) or _DEFAULT_CHILD_ABORT_REASON, + ) + self._drain_parallel_siblings( + future_to_index=future_to_index, + current_future=future, + started_child_engines=started_child_engines, + usage_accumulator=usage_accumulator, + merged_usage_indexes=merged_usage_indexes, + ) + raise e + # Handle errors based on error_handle_mode match self.node_data.error_handle_mode: case ErrorHandleMode.TERMINATED: @@ -301,48 +315,118 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]): if self.node_data.error_handle_mode == ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT: outputs[:] = [output for output in outputs if output is not None] + @staticmethod + def _merge_graph_engine_usage( + *, + usage_accumulator: list[LLMUsage], + graph_engine: "GraphEngine | None", + ) -> None: + if graph_engine is None: + return + usage_accumulator[0] = IterationNode._merge_usage( + usage_accumulator[0], graph_engine.graph_runtime_state.llm_usage + ) + + def _abort_parallel_siblings( + self, + *, + future_to_index: Mapping[Future[tuple[float, list[GraphNodeEventBase], object | None, LLMUsage]], int], + current_future: Future[tuple[float, list[GraphNodeEventBase], object | None, LLMUsage]], + started_child_engines: Mapping[int, "GraphEngine"], + reason: str, + ) -> None: + for future, index in future_to_index.items(): + if future == current_future: + continue + + graph_engine = started_child_engines.get(index) + if graph_engine is not None: + graph_engine.request_abort(reason) + + future.cancel() + + def _drain_parallel_siblings( + self, + *, + future_to_index: Mapping[Future[tuple[float, list[GraphNodeEventBase], object | None, LLMUsage]], int], + current_future: Future[tuple[float, list[GraphNodeEventBase], object | None, LLMUsage]], + started_child_engines: Mapping[int, "GraphEngine"], + usage_accumulator: list[LLMUsage], + merged_usage_indexes: set[int], + ) -> None: + for future, index in future_to_index.items(): + if future == current_future: + continue + if future.cancelled(): + continue + + with suppress(Exception): + future.result() + + if index in merged_usage_indexes: + continue + + self._merge_graph_engine_usage( + usage_accumulator=usage_accumulator, + graph_engine=started_child_engines.get(index), + ) + merged_usage_indexes.add(index) + + def _execute_tracked_iteration_parallel( + self, + *, + index: int, + item: object, + started_child_engines: dict[int, "GraphEngine"], + started_child_engines_lock: Lock, + ) -> tuple[float, list[GraphNodeEventBase], object | None, LLMUsage]: + graph_engine = self._create_graph_engine(index, item) + with started_child_engines_lock: + started_child_engines[index] = graph_engine + + return self._execute_parallel_iteration_with_graph_engine( + index=index, + graph_engine=graph_engine, + ) + def _execute_single_iteration_parallel( self, index: int, item: object, - execution_context: "IExecutionContext", - ) -> tuple[float, list[GraphNodeEventBase], object | None, dict[str, Variable], LLMUsage]: + ) -> tuple[float, list[GraphNodeEventBase], object | None, LLMUsage]: """Execute a single iteration in parallel mode and return results.""" - with execution_context: - iter_start_at = datetime.now(UTC).replace(tzinfo=None) - events: list[GraphNodeEventBase] = [] - outputs_temp: list[object] = [] + graph_engine = self._create_graph_engine(index, item) + return self._execute_parallel_iteration_with_graph_engine(index=index, graph_engine=graph_engine) - graph_engine = self._create_graph_engine(index, item) + def _execute_parallel_iteration_with_graph_engine( + self, + *, + index: int, + graph_engine: "GraphEngine", + ) -> tuple[float, list[GraphNodeEventBase], object | None, LLMUsage]: + """Execute a prepared child engine in parallel mode and return results.""" + iter_start_at = datetime.now(UTC).replace(tzinfo=None) + events: list[GraphNodeEventBase] = [] + outputs_temp: list[object] = [] - # Collect events instead of yielding them directly - for event in self._run_single_iter( - variable_pool=graph_engine.graph_runtime_state.variable_pool, - outputs=outputs_temp, - graph_engine=graph_engine, - ): - events.append(event) + # Collect events instead of yielding them directly + for event in self._run_single_iter( + variable_pool=graph_engine.graph_runtime_state.variable_pool, + outputs=outputs_temp, + graph_engine=graph_engine, + ): + events.append(event) - # Get the output value from the temporary outputs list - output_value = outputs_temp[0] if outputs_temp else None - conversation_snapshot = self._extract_conversation_variable_snapshot( - variable_pool=graph_engine.graph_runtime_state.variable_pool - ) - iteration_duration = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds() + # Get the output value from the temporary outputs list + output_value = outputs_temp[0] if outputs_temp else None + iteration_duration = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds() - return ( - iteration_duration, - events, - output_value, - conversation_snapshot, - graph_engine.graph_runtime_state.llm_usage, - ) - - def _capture_execution_context(self) -> "IExecutionContext": - """Capture current execution context for parallel iterations.""" - from dify_graph.context import capture_current_context - - return capture_current_context() + return ( + iteration_duration, + events, + output_value, + graph_engine.graph_runtime_state.llm_usage, + ) def _handle_iteration_success( self, @@ -516,23 +600,6 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]): return variable_mapping - def _extract_conversation_variable_snapshot(self, *, variable_pool: VariablePool) -> dict[str, Variable]: - conversation_variables = variable_pool.variable_dictionary.get(CONVERSATION_VARIABLE_NODE_ID, {}) - return {name: variable.model_copy(deep=True) for name, variable in conversation_variables.items()} - - def _sync_conversation_variables_from_snapshot(self, snapshot: dict[str, Variable]) -> None: - parent_pool = self.graph_runtime_state.variable_pool - parent_conversations = parent_pool.variable_dictionary.get(CONVERSATION_VARIABLE_NODE_ID, {}) - - current_keys = set(parent_conversations.keys()) - snapshot_keys = set(snapshot.keys()) - - for removed_key in current_keys - snapshot_keys: - parent_pool.remove((CONVERSATION_VARIABLE_NODE_ID, removed_key)) - - for name, variable in snapshot.items(): - parent_pool.add((CONVERSATION_VARIABLE_NODE_ID, name), variable) - def _append_iteration_info_to_event( self, event: GraphNodeEventBase, @@ -575,6 +642,8 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]): else: outputs.append(result.to_object()) return + elif isinstance(event, GraphRunAbortedEvent): + raise ChildGraphAbortedError(event.reason or _DEFAULT_CHILD_ABORT_REASON) elif isinstance(event, GraphRunFailedEvent): match self.node_data.error_handle_mode: case ErrorHandleMode.TERMINATED: @@ -586,8 +655,8 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]): return def _create_graph_engine(self, index: int, item: object): - from dify_graph.entities import GraphInitParams - from dify_graph.runtime import ChildGraphNotFoundError, GraphRuntimeState + from graphon.entities import GraphInitParams + from graphon.runtime import ChildGraphNotFoundError # Create GraphInitParams for child graph execution. graph_init_params = GraphInitParams( @@ -602,14 +671,6 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]): # append iteration variable (item, index) to variable pool variable_pool_copy.add([self._node_id, "index"], index) variable_pool_copy.add([self._node_id, "item"], item) - - # Create a new GraphRuntimeState for this iteration - graph_runtime_state_copy = GraphRuntimeState( - variable_pool=variable_pool_copy, - start_at=self.graph_runtime_state.start_at, - total_tokens=0, - node_run_steps=0, - ) root_node_id = self.node_data.start_node_id if root_node_id is None: raise StartNodeIdNotFoundError(f"field start_node_id in iteration {self._node_id} not found") @@ -618,9 +679,8 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]): return self.graph_runtime_state.create_child_engine( workflow_id=self.workflow_id, graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state_copy, - graph_config=self.graph_config, root_node_id=root_node_id, + variable_pool=variable_pool_copy, ) except ChildGraphNotFoundError as exc: raise IterationGraphNotFoundError("iteration graph not found") from exc diff --git a/api/dify_graph/nodes/iteration/iteration_start_node.py b/api/graphon/nodes/iteration/iteration_start_node.py similarity index 61% rename from api/dify_graph/nodes/iteration/iteration_start_node.py rename to api/graphon/nodes/iteration/iteration_start_node.py index a8ecf3d83b..3a44d3d81d 100644 --- a/api/dify_graph/nodes/iteration/iteration_start_node.py +++ b/api/graphon/nodes/iteration/iteration_start_node.py @@ -1,7 +1,7 @@ -from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus -from dify_graph.node_events import NodeRunResult -from dify_graph.nodes.base.node import Node -from dify_graph.nodes.iteration.entities import IterationStartNodeData +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus +from graphon.node_events import NodeRunResult +from graphon.nodes.base.node import Node +from graphon.nodes.iteration.entities import IterationStartNodeData class IterationStartNode(Node[IterationStartNodeData]): diff --git a/api/dify_graph/nodes/list_operator/__init__.py b/api/graphon/nodes/list_operator/__init__.py similarity index 100% rename from api/dify_graph/nodes/list_operator/__init__.py rename to api/graphon/nodes/list_operator/__init__.py diff --git a/api/dify_graph/nodes/list_operator/entities.py b/api/graphon/nodes/list_operator/entities.py similarity index 93% rename from api/dify_graph/nodes/list_operator/entities.py rename to api/graphon/nodes/list_operator/entities.py index 41b3a40b78..0db1c75cdd 100644 --- a/api/dify_graph/nodes/list_operator/entities.py +++ b/api/graphon/nodes/list_operator/entities.py @@ -3,8 +3,8 @@ from enum import StrEnum from pydantic import BaseModel, Field -from dify_graph.entities.base_node_data import BaseNodeData -from dify_graph.enums import BuiltinNodeTypes, NodeType +from graphon.entities.base_node_data import BaseNodeData +from graphon.enums import BuiltinNodeTypes, NodeType class FilterOperator(StrEnum): diff --git a/api/dify_graph/nodes/list_operator/exc.py b/api/graphon/nodes/list_operator/exc.py similarity index 100% rename from api/dify_graph/nodes/list_operator/exc.py rename to api/graphon/nodes/list_operator/exc.py diff --git a/api/dify_graph/nodes/list_operator/node.py b/api/graphon/nodes/list_operator/node.py similarity index 97% rename from api/dify_graph/nodes/list_operator/node.py rename to api/graphon/nodes/list_operator/node.py index dc8b8904f7..dad17a8f4a 100644 --- a/api/dify_graph/nodes/list_operator/node.py +++ b/api/graphon/nodes/list_operator/node.py @@ -1,12 +1,12 @@ from collections.abc import Callable, Sequence from typing import Any, TypeAlias, TypeVar -from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus -from dify_graph.file import File -from dify_graph.node_events import NodeRunResult -from dify_graph.nodes.base.node import Node -from dify_graph.variables import ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment -from dify_graph.variables.segments import ArrayAnySegment, ArrayBooleanSegment, ArraySegment +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus +from graphon.file import File +from graphon.node_events import NodeRunResult +from graphon.nodes.base.node import Node +from graphon.variables import ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment +from graphon.variables.segments import ArrayAnySegment, ArrayBooleanSegment, ArraySegment from .entities import FilterOperator, ListOperatorNodeData, Order from .exc import InvalidConditionError, InvalidFilterValueError, InvalidKeyError, ListOperatorError diff --git a/api/dify_graph/nodes/llm/__init__.py b/api/graphon/nodes/llm/__init__.py similarity index 100% rename from api/dify_graph/nodes/llm/__init__.py rename to api/graphon/nodes/llm/__init__.py diff --git a/api/dify_graph/nodes/llm/entities.py b/api/graphon/nodes/llm/entities.py similarity index 89% rename from api/dify_graph/nodes/llm/entities.py rename to api/graphon/nodes/llm/entities.py index 6ca01a21da..196152548c 100644 --- a/api/dify_graph/nodes/llm/entities.py +++ b/api/graphon/nodes/llm/entities.py @@ -3,11 +3,11 @@ from typing import Any, Literal from pydantic import BaseModel, Field, field_validator -from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig -from dify_graph.entities.base_node_data import BaseNodeData -from dify_graph.enums import BuiltinNodeTypes, NodeType -from dify_graph.model_runtime.entities import ImagePromptMessageContent, LLMMode -from dify_graph.nodes.base.entities import VariableSelector +from graphon.entities.base_node_data import BaseNodeData +from graphon.enums import BuiltinNodeTypes, NodeType +from graphon.model_runtime.entities import ImagePromptMessageContent, LLMMode +from graphon.nodes.base.entities import VariableSelector +from graphon.prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig class ModelConfig(BaseModel): diff --git a/api/dify_graph/nodes/llm/exc.py b/api/graphon/nodes/llm/exc.py similarity index 100% rename from api/dify_graph/nodes/llm/exc.py rename to api/graphon/nodes/llm/exc.py diff --git a/api/dify_graph/nodes/llm/file_saver.py b/api/graphon/nodes/llm/file_saver.py similarity index 77% rename from api/dify_graph/nodes/llm/file_saver.py rename to api/graphon/nodes/llm/file_saver.py index 50e52a3b6f..0bedb42f3a 100644 --- a/api/dify_graph/nodes/llm/file_saver.py +++ b/api/graphon/nodes/llm/file_saver.py @@ -1,11 +1,9 @@ import mimetypes import typing as tp -from constants.mimetypes import DEFAULT_EXTENSION, DEFAULT_MIME_TYPE -from core.tools.signature import sign_tool_file -from core.tools.tool_file_manager import ToolFileManager -from dify_graph.file import File, FileTransferMethod, FileType -from dify_graph.nodes.protocols import HttpClientProtocol +from graphon.file import File, FileTransferMethod, FileType +from graphon.file.constants import DEFAULT_EXTENSION, DEFAULT_MIME_TYPE +from graphon.nodes.protocols import FileReferenceFactoryProtocol, HttpClientProtocol, ToolFileManagerProtocol class LLMFileSaver(tp.Protocol): @@ -57,17 +55,20 @@ class LLMFileSaver(tp.Protocol): class FileSaverImpl(LLMFileSaver): - _tenant_id: str - _user_id: str + _tool_file_manager: ToolFileManagerProtocol + _file_reference_factory: FileReferenceFactoryProtocol - def __init__(self, user_id: str, tenant_id: str, http_client: HttpClientProtocol): - self._user_id = user_id - self._tenant_id = tenant_id + def __init__( + self, + *, + tool_file_manager: ToolFileManagerProtocol, + file_reference_factory: FileReferenceFactoryProtocol, + http_client: HttpClientProtocol, + ): + self._tool_file_manager = tool_file_manager + self._file_reference_factory = file_reference_factory self._http_client = http_client - def _get_tool_file_manager(self): - return ToolFileManager() - def save_remote_url(self, url: str, file_type: FileType) -> File: http_response = self._http_client.get(url) http_response.raise_for_status() @@ -83,30 +84,24 @@ class FileSaverImpl(LLMFileSaver): file_type: FileType, extension_override: str | None = None, ) -> File: - tool_file_manager = self._get_tool_file_manager() - tool_file = tool_file_manager.create_file_by_raw( - user_id=self._user_id, - tenant_id=self._tenant_id, - # TODO(QuantumGhost): what is conversation id? - conversation_id=None, + tool_file = self._tool_file_manager.create_file_by_raw( file_binary=data, mimetype=mime_type, ) extension_override = _validate_extension_override(extension_override) extension = _get_extension(mime_type, extension_override) - url = sign_tool_file(tool_file.id, extension) - - return File( - tenant_id=self._tenant_id, - type=file_type, - transfer_method=FileTransferMethod.TOOL_FILE, - filename=tool_file.name, - extension=extension, - mime_type=mime_type, - size=len(data), - related_id=tool_file.id, - url=url, - storage_key=tool_file.file_key, + return self._file_reference_factory.build_from_mapping( + mapping={ + "type": file_type, + "transfer_method": FileTransferMethod.TOOL_FILE, + "filename": tool_file.name, + "extension": extension, + "mime_type": mime_type, + "size": len(data), + "tool_file_id": str(tool_file.id), + "related_id": str(tool_file.id), + "storage_key": tool_file.file_key, + } ) diff --git a/api/dify_graph/nodes/llm/llm_utils.py b/api/graphon/nodes/llm/llm_utils.py similarity index 78% rename from api/dify_graph/nodes/llm/llm_utils.py rename to api/graphon/nodes/llm/llm_utils.py index 2be391a424..11a1d83a9d 100644 --- a/api/dify_graph/nodes/llm/llm_utils.py +++ b/api/graphon/nodes/llm/llm_utils.py @@ -1,31 +1,33 @@ from __future__ import annotations -from collections.abc import Sequence -from typing import Any, cast +import json +import logging +import re +from collections.abc import Mapping, Sequence +from typing import Any -from core.model_manager import ModelInstance -from dify_graph.file import FileType, file_manager -from dify_graph.file.models import File -from dify_graph.model_runtime.entities import ( +from graphon.file import FileType, file_manager +from graphon.file.models import File +from graphon.model_runtime.entities import ( ImagePromptMessageContent, PromptMessage, PromptMessageContentType, PromptMessageRole, TextPromptMessageContent, ) -from dify_graph.model_runtime.entities.message_entities import ( +from graphon.model_runtime.entities.message_entities import ( AssistantPromptMessage, PromptMessageContentUnionTypes, SystemPromptMessage, UserPromptMessage, ) -from dify_graph.model_runtime.entities.model_entities import AIModelEntity, ModelFeature, ModelPropertyKey -from dify_graph.model_runtime.memory import PromptMessageMemory -from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from dify_graph.nodes.base.entities import VariableSelector -from dify_graph.runtime import VariablePool -from dify_graph.variables import ArrayFileSegment, FileSegment -from dify_graph.variables.segments import ArrayAnySegment, NoneSegment +from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelFeature, ModelPropertyKey +from graphon.model_runtime.memory import PromptMessageMemory +from graphon.nodes.base.entities import VariableSelector +from graphon.runtime import VariablePool +from graphon.template_rendering import Jinja2TemplateRenderer +from graphon.variables import ArrayFileSegment, FileSegment +from graphon.variables.segments import ArrayAnySegment, NoneSegment from .entities import LLMNodeChatModelMessage, LLMNodeCompletionModelPromptTemplate, MemoryConfig from .exc import ( @@ -34,16 +36,20 @@ from .exc import ( NoPromptFoundError, TemplateTypeNotSupportError, ) -from .protocols import TemplateRenderer +from .runtime_protocols import PreparedLLMProtocol + +CONTEXT_PLACEHOLDER = "{{#context#}}" + +logger = logging.getLogger(__name__) + +VARIABLE_PATTERN = re.compile(r"\{\{#[^#]+#\}\}") +MAX_RESOLVED_VALUE_LENGTH = 1024 -def fetch_model_schema(*, model_instance: ModelInstance) -> AIModelEntity: - model_schema = cast(LargeLanguageModel, model_instance.model_type_instance).get_model_schema( - model_instance.model_name, - dict(model_instance.credentials), - ) +def fetch_model_schema(*, model_instance: PreparedLLMProtocol) -> AIModelEntity: + model_schema = model_instance.get_model_schema() if not model_schema: - raise ValueError(f"Model schema not found for {model_instance.model_name}") + raise ValueError(f"Model schema not found for {getattr(model_instance, 'model_name', 'unknown model')}") return model_schema @@ -114,9 +120,9 @@ def fetch_prompt_messages( *, sys_query: str | None = None, sys_files: Sequence[File], - context: str | None = None, + context: str = "", memory: PromptMessageMemory | None = None, - model_instance: ModelInstance, + model_instance: PreparedLLMProtocol, prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate, stop: Sequence[str] | None = None, memory_config: MemoryConfig | None = None, @@ -125,7 +131,7 @@ def fetch_prompt_messages( variable_pool: VariablePool, jinja2_variables: Sequence[VariableSelector], context_files: list[File] | None = None, - template_renderer: TemplateRenderer | None = None, + template_renderer: Jinja2TemplateRenderer | None = None, ) -> tuple[Sequence[PromptMessage], Sequence[str] | None]: prompt_messages: list[PromptMessage] = [] model_schema = fetch_model_schema(model_instance=model_instance) @@ -277,11 +283,11 @@ def fetch_prompt_messages( def handle_list_messages( *, messages: Sequence[LLMNodeChatModelMessage], - context: str | None, + context: str, jinja2_variables: Sequence[VariableSelector], variable_pool: VariablePool, vision_detail_config: ImagePromptMessageContent.DETAIL, - template_renderer: TemplateRenderer | None = None, + template_renderer: Jinja2TemplateRenderer | None = None, ) -> Sequence[PromptMessage]: prompt_messages: list[PromptMessage] = [] for message in messages: @@ -300,7 +306,7 @@ def handle_list_messages( ) continue - template = message.text.replace("{#context#}", context) if context else message.text + template = message.text.replace(CONTEXT_PLACEHOLDER, context) segment_group = variable_pool.convert_template(template) file_contents: list[PromptMessageContentUnionTypes] = [] for segment in segment_group.value: @@ -335,7 +341,7 @@ def render_jinja2_message( template: str, jinja2_variables: Sequence[VariableSelector], variable_pool: VariablePool, - template_renderer: TemplateRenderer | None = None, + template_renderer: Jinja2TemplateRenderer | None = None, ) -> str: if not template: return "" @@ -346,16 +352,16 @@ def render_jinja2_message( for jinja2_variable in jinja2_variables: variable = variable_pool.get(jinja2_variable.value_selector) jinja2_inputs[jinja2_variable.variable] = variable.to_object() if variable else "" - return template_renderer.render_jinja2(template=template, inputs=jinja2_inputs) + return template_renderer.render_template(template, jinja2_inputs) def handle_completion_template( *, template: LLMNodeCompletionModelPromptTemplate, - context: str | None, + context: str, jinja2_variables: Sequence[VariableSelector], variable_pool: VariablePool, - template_renderer: TemplateRenderer | None = None, + template_renderer: Jinja2TemplateRenderer | None = None, ) -> Sequence[PromptMessage]: if template.edition_type == "jinja2": result_text = render_jinja2_message( @@ -365,7 +371,7 @@ def handle_completion_template( template_renderer=template_renderer, ) else: - template_text = template.text.replace("{#context#}", context) if context else template.text + template_text = template.text.replace(CONTEXT_PLACEHOLDER, context) result_text = variable_pool.convert_template(template_text).text return [ combine_message_content_with_role( @@ -391,7 +397,11 @@ def combine_message_content_with_role( raise NotImplementedError(f"Role {role} is not supported") -def calculate_rest_token(*, prompt_messages: list[PromptMessage], model_instance: ModelInstance) -> int: +def calculate_rest_token( + *, + prompt_messages: list[PromptMessage], + model_instance: PreparedLLMProtocol, +) -> int: rest_tokens = 2000 runtime_model_schema = fetch_model_schema(model_instance=model_instance) runtime_model_parameters = model_instance.parameters @@ -421,7 +431,7 @@ def handle_memory_chat_mode( *, memory: PromptMessageMemory | None, memory_config: MemoryConfig | None, - model_instance: ModelInstance, + model_instance: PreparedLLMProtocol, ) -> Sequence[PromptMessage]: if not memory or not memory_config: return [] @@ -436,7 +446,7 @@ def handle_memory_completion_mode( *, memory: PromptMessageMemory | None, memory_config: MemoryConfig | None, - model_instance: ModelInstance, + model_instance: PreparedLLMProtocol, ) -> str: if not memory or not memory_config: return "" @@ -475,3 +485,61 @@ def _append_file_prompts( prompt_messages[-1] = UserPromptMessage(content=file_prompts + existing_contents) else: prompt_messages.append(UserPromptMessage(content=file_prompts)) + + +def _coerce_resolved_value(raw: str) -> int | float | bool | str: + """Try to restore the original type from a resolved template string. + + Variable references are always resolved to text, but completion params may + expect numeric or boolean values (e.g. a variable that holds "0.7" mapped to + the ``temperature`` parameter). This helper attempts a JSON parse so that + ``"0.7"`` → ``0.7``, ``"true"`` → ``True``, etc. Plain strings that are not + valid JSON literals are returned as-is. + """ + stripped = raw.strip() + if not stripped: + return raw + + try: + parsed: object = json.loads(stripped) + except (json.JSONDecodeError, ValueError): + return raw + + if isinstance(parsed, (int, float, bool)): + return parsed + return raw + + +def resolve_completion_params_variables( + completion_params: Mapping[str, Any], + variable_pool: VariablePool, +) -> dict[str, Any]: + """Resolve variable references (``{{#node_id.var#}}``) in string-typed completion params. + + Security notes: + - Resolved values are length-capped to ``MAX_RESOLVED_VALUE_LENGTH`` to + prevent denial-of-service through excessively large variable payloads. + - This follows the same ``VariablePool.convert_template`` pattern used across + Dify (Answer Node, HTTP Request Node, Agent Node, etc.). The downstream + model plugin receives these values as structured JSON key-value pairs — they + are never concatenated into raw HTTP headers or SQL queries. + - Numeric/boolean coercion is applied so that variables holding ``"0.7"`` are + restored to their native type rather than sent as a bare string. + """ + resolved: dict[str, Any] = {} + for key, value in completion_params.items(): + if isinstance(value, str) and VARIABLE_PATTERN.search(value): + segment_group = variable_pool.convert_template(value) + text = segment_group.text + if len(text) > MAX_RESOLVED_VALUE_LENGTH: + logger.warning( + "Resolved value for param '%s' truncated from %d to %d chars", + key, + len(text), + MAX_RESOLVED_VALUE_LENGTH, + ) + text = text[:MAX_RESOLVED_VALUE_LENGTH] + resolved[key] = _coerce_resolved_value(text) + else: + resolved[key] = value + return resolved diff --git a/api/dify_graph/nodes/llm/node.py b/api/graphon/nodes/llm/node.py similarity index 62% rename from api/dify_graph/nodes/llm/node.py rename to api/graphon/nodes/llm/node.py index 5ed90ed7e3..4de2a95465 100644 --- a/api/dify_graph/nodes/llm/node.py +++ b/api/graphon/nodes/llm/node.py @@ -7,33 +7,24 @@ import logging import re import time from collections.abc import Generator, Mapping, Sequence -from typing import TYPE_CHECKING, Any, Literal +from typing import TYPE_CHECKING, Any, Literal, cast -from sqlalchemy import select - -from core.llm_generator.output_parser.errors import OutputParserError -from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output -from core.model_manager import ModelInstance -from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig -from core.prompt.utils.prompt_message_util import PromptMessageUtil -from core.tools.signature import sign_upload_file -from dify_graph.constants import SYSTEM_VARIABLE_NODE_ID -from dify_graph.entities import GraphInitParams -from dify_graph.entities.graph_config import NodeConfigDict -from dify_graph.enums import ( +from graphon.entities import GraphInitParams +from graphon.entities.graph_config import NodeConfigDict +from graphon.enums import ( BuiltinNodeTypes, NodeType, - SystemVariableKey, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus, ) -from dify_graph.file import File, FileTransferMethod, FileType -from dify_graph.model_runtime.entities import ( +from graphon.file import File, FileType, file_manager +from graphon.model_runtime.entities import ( ImagePromptMessageContent, PromptMessage, + PromptMessageContentType, TextPromptMessageContent, ) -from dify_graph.model_runtime.entities.llm_entities import ( +from graphon.model_runtime.entities.llm_entities import ( LLMResult, LLMResultChunk, LLMResultChunkWithStructuredOutput, @@ -41,10 +32,17 @@ from dify_graph.model_runtime.entities.llm_entities import ( LLMStructuredOutput, LLMUsage, ) -from dify_graph.model_runtime.entities.message_entities import PromptMessageContentUnionTypes -from dify_graph.model_runtime.memory import PromptMessageMemory -from dify_graph.model_runtime.utils.encoders import jsonable_encoder -from dify_graph.node_events import ( +from graphon.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + PromptMessageContentUnionTypes, + PromptMessageRole, + SystemPromptMessage, + UserPromptMessage, +) +from graphon.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey +from graphon.model_runtime.memory import PromptMessageMemory +from graphon.model_runtime.utils.encoders import jsonable_encoder +from graphon.node_events import ( ModelInvokeCompletedEvent, NodeEventBase, NodeRunResult, @@ -52,22 +50,26 @@ from dify_graph.node_events import ( StreamChunkEvent, StreamCompletedEvent, ) -from dify_graph.nodes.base.entities import VariableSelector -from dify_graph.nodes.base.node import Node -from dify_graph.nodes.base.variable_template_parser import VariableTemplateParser -from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory, TemplateRenderer -from dify_graph.nodes.protocols import HttpClientProtocol -from dify_graph.runtime import VariablePool -from dify_graph.variables import ( +from graphon.nodes.base.entities import VariableSelector +from graphon.nodes.base.node import Node +from graphon.nodes.base.variable_template_parser import VariableTemplateParser +from graphon.nodes.llm.runtime_protocols import ( + PreparedLLMProtocol, + PromptMessageSerializerProtocol, + RetrieverAttachmentLoaderProtocol, +) +from graphon.nodes.protocols import HttpClientProtocol +from graphon.prompt_entities import CompletionModelPromptTemplate, MemoryConfig +from graphon.runtime import VariablePool +from graphon.template_rendering import Jinja2TemplateRenderer, TemplateRenderError +from graphon.variables import ( ArrayFileSegment, ArraySegment, + FileSegment, NoneSegment, ObjectSegment, StringSegment, ) -from extensions.ext_database import db -from models.dataset import SegmentAttachmentBinding -from models.model import UploadFile from . import llm_utils from .entities import ( @@ -79,13 +81,16 @@ from .exc import ( InvalidContextStructureError, InvalidVariableTypeError, LLMNodeError, + MemoryRolePrefixRequiredError, + NoPromptFoundError, + TemplateTypeNotSupportError, VariableNotFoundError, ) -from .file_saver import FileSaverImpl, LLMFileSaver +from .file_saver import LLMFileSaver if TYPE_CHECKING: - from dify_graph.file.models import File - from dify_graph.runtime import GraphRuntimeState + from graphon.file.models import File + from graphon.runtime import GraphRuntimeState logger = logging.getLogger(__name__) @@ -101,11 +106,12 @@ class LLMNode(Node[LLMNodeData]): _file_outputs: list[File] _llm_file_saver: LLMFileSaver - _credentials_provider: CredentialsProvider - _model_factory: ModelFactory - _model_instance: ModelInstance + _retriever_attachment_loader: RetrieverAttachmentLoaderProtocol | None + _prompt_message_serializer: PromptMessageSerializerProtocol + _jinja2_template_renderer: Jinja2TemplateRenderer | None + _model_instance: PreparedLLMProtocol _memory: PromptMessageMemory | None - _template_renderer: TemplateRenderer + _default_query_selector: tuple[str, ...] | None def __init__( self, @@ -114,13 +120,16 @@ class LLMNode(Node[LLMNodeData]): graph_init_params: GraphInitParams, graph_runtime_state: GraphRuntimeState, *, - credentials_provider: CredentialsProvider, - model_factory: ModelFactory, - model_instance: ModelInstance, + credentials_provider: object | None = None, + model_factory: object | None = None, + model_instance: PreparedLLMProtocol, http_client: HttpClientProtocol, - template_renderer: TemplateRenderer, memory: PromptMessageMemory | None = None, - llm_file_saver: LLMFileSaver | None = None, + llm_file_saver: LLMFileSaver, + prompt_message_serializer: PromptMessageSerializerProtocol, + retriever_attachment_loader: RetrieverAttachmentLoaderProtocol | None = None, + jinja2_template_renderer: Jinja2TemplateRenderer | None = None, + default_query_selector: Sequence[str] | None = None, ): super().__init__( id=id, @@ -131,20 +140,15 @@ class LLMNode(Node[LLMNodeData]): # LLM file outputs, used for MultiModal outputs. self._file_outputs = [] - self._credentials_provider = credentials_provider - self._model_factory = model_factory + _ = credentials_provider, model_factory, http_client self._model_instance = model_instance self._memory = memory - self._template_renderer = template_renderer - if llm_file_saver is None: - dify_ctx = self.require_dify_context() - llm_file_saver = FileSaverImpl( - user_id=dify_ctx.user_id, - tenant_id=dify_ctx.tenant_id, - http_client=http_client, - ) self._llm_file_saver = llm_file_saver + self._prompt_message_serializer = prompt_message_serializer + self._retriever_attachment_loader = retriever_attachment_loader + self._jinja2_template_renderer = jinja2_template_renderer + self._default_query_selector = tuple(default_query_selector) if default_query_selector is not None else None @classmethod def version(cls) -> str: @@ -190,10 +194,11 @@ class LLMNode(Node[LLMNodeData]): generator = self._fetch_context(node_data=self.node_data) context = None context_files: list[File] = [] - for event in generator: - context = event.context - context_files = event.context_files or [] - yield event + if generator is not None: + for event in generator: + context = event.context + context_files = event.context_files or [] + yield event if context: node_inputs["#context#"] = context @@ -202,6 +207,10 @@ class LLMNode(Node[LLMNodeData]): # fetch model config model_instance = self._model_instance + # Resolve variable references in string-typed completion params + model_instance.parameters = llm_utils.resolve_completion_params_variables( + model_instance.parameters, variable_pool + ) model_name = model_instance.model_name model_provider = model_instance.provider model_stop = model_instance.stop @@ -211,15 +220,17 @@ class LLMNode(Node[LLMNodeData]): query: str | None = None if self.node_data.memory: query = self.node_data.memory.query_prompt_template - if not query and ( - query_variable := variable_pool.get((SYSTEM_VARIABLE_NODE_ID, SystemVariableKey.QUERY)) + if ( + not query + and self._default_query_selector + and (query_variable := variable_pool.get(self._default_query_selector)) ): query = query_variable.text prompt_messages, stop = LLMNode.fetch_prompt_messages( sys_query=query, sys_files=files, - context=context, + context=context or "", memory=memory, model_instance=model_instance, stop=model_stop, @@ -230,7 +241,7 @@ class LLMNode(Node[LLMNodeData]): variable_pool=variable_pool, jinja2_variables=self.node_data.prompt_config.jinja2_variables, context_files=context_files, - template_renderer=self._template_renderer, + jinja2_template_renderer=self._jinja2_template_renderer, ) # handle invoke result @@ -238,7 +249,6 @@ class LLMNode(Node[LLMNodeData]): model_instance=model_instance, prompt_messages=prompt_messages, stop=stop, - user_id=self.require_dify_context().user_id, structured_output_enabled=self.node_data.structured_output_enabled, structured_output=self.node_data.structured_output, file_saver=self._llm_file_saver, @@ -281,7 +291,7 @@ class LLMNode(Node[LLMNodeData]): process_data = { "model_mode": self.node_data.model.mode, - "prompts": PromptMessageUtil.prompt_messages_to_prompt_for_saving( + "prompts": self._prompt_message_serializer.serialize( model_mode=self.node_data.model.mode, prompt_messages=prompt_messages ), "usage": jsonable_encoder(usage), @@ -349,10 +359,9 @@ class LLMNode(Node[LLMNodeData]): @staticmethod def invoke_llm( *, - model_instance: ModelInstance, + model_instance: PreparedLLMProtocol, prompt_messages: Sequence[PromptMessage], stop: Sequence[str] | None = None, - user_id: str, structured_output_enabled: bool, structured_output: Mapping[str, Any] | None = None, file_saver: LLMFileSaver, @@ -363,35 +372,35 @@ class LLMNode(Node[LLMNodeData]): ) -> Generator[NodeEventBase | LLMStructuredOutput, None, None]: model_parameters = model_instance.parameters invoke_model_parameters = dict(model_parameters) - - model_schema = llm_utils.fetch_model_schema(model_instance=model_instance) - + invoke_result: LLMResult | Generator[LLMResultChunk | LLMStructuredOutput, None, None] if structured_output_enabled: output_schema = LLMNode.fetch_structured_output_schema( structured_output=structured_output or {}, ) request_start_time = time.perf_counter() - invoke_result = invoke_llm_with_structured_output( - provider=model_instance.provider, - model_schema=model_schema, - model_instance=model_instance, - prompt_messages=prompt_messages, - json_schema=output_schema, - model_parameters=invoke_model_parameters, - stop=list(stop or []), - stream=True, - user=user_id, + invoke_result = cast( + LLMResult | Generator[LLMResultChunk | LLMStructuredOutput, None, None], + model_instance.invoke_llm_with_structured_output( + prompt_messages=prompt_messages, + json_schema=output_schema, + model_parameters=invoke_model_parameters, + stop=stop, + stream=True, + ), ) else: request_start_time = time.perf_counter() - invoke_result = model_instance.invoke_llm( - prompt_messages=list(prompt_messages), - model_parameters=invoke_model_parameters, - stop=list(stop or []), - stream=True, - user=user_id, + invoke_result = cast( + LLMResult | Generator[LLMResultChunk | LLMStructuredOutput, None, None], + model_instance.invoke_llm( + prompt_messages=prompt_messages, + model_parameters=invoke_model_parameters, + tools=None, + stop=stop, + stream=True, + ), ) return LLMNode.handle_invoke_result( @@ -400,6 +409,7 @@ class LLMNode(Node[LLMNodeData]): file_outputs=file_outputs, node_id=node_id, node_type=node_type, + model_instance=model_instance, reasoning_format=reasoning_format, request_start_time=request_start_time, ) @@ -412,6 +422,7 @@ class LLMNode(Node[LLMNodeData]): file_outputs: list[File], node_id: str, node_type: NodeType, + model_instance: PreparedLLMProtocol | object, reasoning_format: Literal["separated", "tagged"] = "tagged", request_start_time: float | None = None, ) -> Generator[NodeEventBase | LLMStructuredOutput, None, None]: @@ -483,8 +494,14 @@ class LLMNode(Node[LLMNodeData]): usage = result.delta.usage if finish_reason is None and result.delta.finish_reason: finish_reason = result.delta.finish_reason - except OutputParserError as e: - raise LLMNodeError(f"Failed to parse structured output: {e}") + except Exception as e: + if hasattr(model_instance, "is_structured_output_parse_error") and cast( + PreparedLLMProtocol, model_instance + ).is_structured_output_parse_error(e): + raise LLMNodeError(f"Failed to parse structured output: {e}") from e + if type(e).__name__ == "OutputParserError": + raise LLMNodeError(f"Failed to parse structured output: {e}") from e + raise # Extract reasoning content from tags in the main text full_text = full_text_buffer.getvalue() @@ -687,30 +704,8 @@ class LLMNode(Node[LLMNodeData]): segment_id = retriever_resource.get("segment_id") if not segment_id: continue - attachments_with_bindings = db.session.execute( - select(SegmentAttachmentBinding, UploadFile) - .join(UploadFile, UploadFile.id == SegmentAttachmentBinding.attachment_id) - .where( - SegmentAttachmentBinding.segment_id == segment_id, - ) - ).all() - if attachments_with_bindings: - for _, upload_file in attachments_with_bindings: - attachment_info = File( - id=upload_file.id, - filename=upload_file.name, - extension="." + upload_file.extension, - mime_type=upload_file.mime_type, - tenant_id=self.require_dify_context().tenant_id, - type=FileType.IMAGE, - transfer_method=FileTransferMethod.LOCAL_FILE, - remote_url=upload_file.source_url, - related_id=upload_file.id, - size=upload_file.size, - storage_key=upload_file.key, - url=sign_upload_file(upload_file.id, upload_file.extension), - ) - context_files.append(attachment_info) + if self._retriever_attachment_loader is not None: + context_files.extend(self._retriever_attachment_loader.load(segment_id=segment_id)) yield RunRetrieverResourceEvent( retriever_resources=original_retriever_resource, context=context_str.strip(), @@ -753,9 +748,9 @@ class LLMNode(Node[LLMNodeData]): *, sys_query: str | None = None, sys_files: Sequence[File], - context: str | None = None, + context: str = "", memory: PromptMessageMemory | None = None, - model_instance: ModelInstance, + model_instance: PreparedLLMProtocol, prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate, stop: Sequence[str] | None = None, memory_config: MemoryConfig | None = None, @@ -764,24 +759,186 @@ class LLMNode(Node[LLMNodeData]): variable_pool: VariablePool, jinja2_variables: Sequence[VariableSelector], context_files: list[File] | None = None, - template_renderer: TemplateRenderer | None = None, + jinja2_template_renderer: Jinja2TemplateRenderer | None = None, ) -> tuple[Sequence[PromptMessage], Sequence[str] | None]: - return llm_utils.fetch_prompt_messages( - sys_query=sys_query, - sys_files=sys_files, - context=context, - memory=memory, - model_instance=model_instance, - prompt_template=prompt_template, - stop=stop, - memory_config=memory_config, - vision_enabled=vision_enabled, - vision_detail=vision_detail, - variable_pool=variable_pool, - jinja2_variables=jinja2_variables, - context_files=context_files, - template_renderer=template_renderer, - ) + prompt_messages: list[PromptMessage] = [] + model_schema = llm_utils.fetch_model_schema(model_instance=model_instance) + + if isinstance(prompt_template, list): + # For chat model + prompt_messages.extend( + LLMNode.handle_list_messages( + messages=prompt_template, + context=context, + jinja2_variables=jinja2_variables, + variable_pool=variable_pool, + vision_detail_config=vision_detail, + jinja2_template_renderer=jinja2_template_renderer, + ) + ) + + # Get memory messages for chat mode + memory_messages = _handle_memory_chat_mode( + memory=memory, + memory_config=memory_config, + model_instance=model_instance, + ) + # Extend prompt_messages with memory messages + prompt_messages.extend(memory_messages) + + # Add current query to the prompt messages + if sys_query: + message = LLMNodeChatModelMessage( + text=sys_query, + role=PromptMessageRole.USER, + edition_type="basic", + ) + prompt_messages.extend( + LLMNode.handle_list_messages( + messages=[message], + context="", + jinja2_variables=[], + variable_pool=variable_pool, + vision_detail_config=vision_detail, + jinja2_template_renderer=jinja2_template_renderer, + ) + ) + + elif isinstance(prompt_template, LLMNodeCompletionModelPromptTemplate): + # For completion model + prompt_messages.extend( + _handle_completion_template( + template=prompt_template, + context=context, + jinja2_variables=jinja2_variables, + variable_pool=variable_pool, + jinja2_template_renderer=jinja2_template_renderer, + ) + ) + + # Get memory text for completion model + memory_text = _handle_memory_completion_mode( + memory=memory, + memory_config=memory_config, + model_instance=model_instance, + ) + # Insert histories into the prompt + prompt_content = prompt_messages[0].content + # For issue #11247 - Check if prompt content is a string or a list + if isinstance(prompt_content, str): + prompt_content = str(prompt_content) + if "#histories#" in prompt_content: + prompt_content = prompt_content.replace("#histories#", memory_text) + else: + prompt_content = memory_text + "\n" + prompt_content + prompt_messages[0].content = prompt_content + elif isinstance(prompt_content, list): + for content_item in prompt_content: + if isinstance(content_item, TextPromptMessageContent): + if "#histories#" in content_item.data: + content_item.data = content_item.data.replace("#histories#", memory_text) + else: + content_item.data = memory_text + "\n" + content_item.data + else: + raise ValueError("Invalid prompt content type") + + # Add current query to the prompt message + if sys_query: + if isinstance(prompt_content, str): + prompt_content = str(prompt_messages[0].content).replace("#sys.query#", sys_query) + prompt_messages[0].content = prompt_content + elif isinstance(prompt_content, list): + for content_item in prompt_content: + if isinstance(content_item, TextPromptMessageContent): + content_item.data = sys_query + "\n" + content_item.data + else: + raise ValueError("Invalid prompt content type") + else: + raise TemplateTypeNotSupportError(type_name=str(type(prompt_template))) + + # The sys_files will be deprecated later + if vision_enabled and sys_files: + file_prompts = [] + for file in sys_files: + file_prompt = file_manager.to_prompt_message_content(file, image_detail_config=vision_detail) + file_prompts.append(file_prompt) + # If last prompt is a user prompt, add files into its contents, + # otherwise append a new user prompt + if ( + len(prompt_messages) > 0 + and isinstance(prompt_messages[-1], UserPromptMessage) + and isinstance(prompt_messages[-1].content, list) + ): + prompt_messages[-1] = UserPromptMessage(content=file_prompts + prompt_messages[-1].content) + else: + prompt_messages.append(UserPromptMessage(content=file_prompts)) + + # The context_files + if vision_enabled and context_files: + file_prompts = [] + for file in context_files: + file_prompt = file_manager.to_prompt_message_content(file, image_detail_config=vision_detail) + file_prompts.append(file_prompt) + # If last prompt is a user prompt, add files into its contents, + # otherwise append a new user prompt + if ( + len(prompt_messages) > 0 + and isinstance(prompt_messages[-1], UserPromptMessage) + and isinstance(prompt_messages[-1].content, list) + ): + prompt_messages[-1] = UserPromptMessage(content=file_prompts + prompt_messages[-1].content) + else: + prompt_messages.append(UserPromptMessage(content=file_prompts)) + + # Remove empty messages and filter unsupported content + filtered_prompt_messages = [] + for prompt_message in prompt_messages: + if isinstance(prompt_message.content, list): + prompt_message_content: list[PromptMessageContentUnionTypes] = [] + for content_item in prompt_message.content: + # Skip content if features are not defined + if not model_schema.features: + if content_item.type != PromptMessageContentType.TEXT: + continue + prompt_message_content.append(content_item) + continue + + # Skip content if corresponding feature is not supported + if ( + ( + content_item.type == PromptMessageContentType.IMAGE + and ModelFeature.VISION not in model_schema.features + ) + or ( + content_item.type == PromptMessageContentType.DOCUMENT + and ModelFeature.DOCUMENT not in model_schema.features + ) + or ( + content_item.type == PromptMessageContentType.VIDEO + and ModelFeature.VIDEO not in model_schema.features + ) + or ( + content_item.type == PromptMessageContentType.AUDIO + and ModelFeature.AUDIO not in model_schema.features + ) + ): + continue + prompt_message_content.append(content_item) + if len(prompt_message_content) == 1 and prompt_message_content[0].type == PromptMessageContentType.TEXT: + prompt_message.content = prompt_message_content[0].data + else: + prompt_message.content = prompt_message_content + if prompt_message.is_empty(): + continue + filtered_prompt_messages.append(prompt_message) + + if len(filtered_prompt_messages) == 0: + raise NoPromptFoundError( + "No prompt found in the LLM configuration. " + "Please ensure a prompt is properly configured before proceeding." + ) + + return filtered_prompt_messages, stop @classmethod def _extract_variable_selector_to_variable_mapping( @@ -825,9 +982,6 @@ class LLMNode(Node[LLMNodeData]): if node_data.vision.enabled: variable_mapping["#files#"] = node_data.vision.configs.variable_selector - if node_data.memory: - variable_mapping["#sys.query#"] = ["sys", SystemVariableKey.QUERY] - if node_data.prompt_config: enable_jinja = False @@ -877,20 +1031,62 @@ class LLMNode(Node[LLMNodeData]): def handle_list_messages( *, messages: Sequence[LLMNodeChatModelMessage], - context: str | None, + context: str, jinja2_variables: Sequence[VariableSelector], variable_pool: VariablePool, vision_detail_config: ImagePromptMessageContent.DETAIL, - template_renderer: TemplateRenderer | None = None, + jinja2_template_renderer: Jinja2TemplateRenderer | None = None, ) -> Sequence[PromptMessage]: - return llm_utils.handle_list_messages( - messages=messages, - context=context, - jinja2_variables=jinja2_variables, - variable_pool=variable_pool, - vision_detail_config=vision_detail_config, - template_renderer=template_renderer, - ) + prompt_messages: list[PromptMessage] = [] + for message in messages: + if message.edition_type == "jinja2": + result_text = _render_jinja2_message( + template=message.jinja2_text or "", + jinja2_variables=jinja2_variables, + variable_pool=variable_pool, + jinja2_template_renderer=jinja2_template_renderer, + ) + prompt_message = _combine_message_content_with_role( + contents=[TextPromptMessageContent(data=result_text)], role=message.role + ) + prompt_messages.append(prompt_message) + else: + # Get segment group from basic message + template = message.text.replace(llm_utils.CONTEXT_PLACEHOLDER, context) + segment_group = variable_pool.convert_template(template) + + # Process segments for images + file_contents = [] + for segment in segment_group.value: + if isinstance(segment, ArrayFileSegment): + for file in segment.value: + if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO, FileType.DOCUMENT}: + file_content = file_manager.to_prompt_message_content( + file, image_detail_config=vision_detail_config + ) + file_contents.append(file_content) + elif isinstance(segment, FileSegment): + file = segment.value + if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO, FileType.DOCUMENT}: + file_content = file_manager.to_prompt_message_content( + file, image_detail_config=vision_detail_config + ) + file_contents.append(file_content) + + # Create message with text from all segments + plain_text = segment_group.text + if plain_text: + prompt_message = _combine_message_content_with_role( + contents=[TextPromptMessageContent(data=plain_text)], role=message.role + ) + prompt_messages.append(prompt_message) + + if file_contents: + # Create message with image contents + prompt_message = _combine_message_content_with_role(contents=file_contents, role=message.role) + prompt_messages.append(prompt_message) + + return prompt_messages @staticmethod def handle_blocking_result( @@ -1027,5 +1223,150 @@ class LLMNode(Node[LLMNodeData]): return self.node_data.retry_config.retry_enabled @property - def model_instance(self) -> ModelInstance: + def model_instance(self) -> PreparedLLMProtocol: return self._model_instance + + +def _combine_message_content_with_role( + *, contents: str | list[PromptMessageContentUnionTypes] | None = None, role: PromptMessageRole +): + match role: + case PromptMessageRole.USER: + return UserPromptMessage(content=contents) + case PromptMessageRole.ASSISTANT: + return AssistantPromptMessage(content=contents) + case PromptMessageRole.SYSTEM: + return SystemPromptMessage(content=contents) + case _: + raise NotImplementedError(f"Role {role} is not supported") + + +def _render_jinja2_message( + *, + template: str, + jinja2_variables: Sequence[VariableSelector], + variable_pool: VariablePool, + jinja2_template_renderer: Jinja2TemplateRenderer | None, +): + if not template: + return "" + + jinja2_inputs = {} + for jinja2_variable in jinja2_variables: + variable = variable_pool.get(jinja2_variable.value_selector) + jinja2_inputs[jinja2_variable.variable] = variable.to_object() if variable else "" + if jinja2_template_renderer is None: + raise TemplateRenderError("LLMNode requires an injected jinja2_template_renderer for jinja2 prompts.") + return jinja2_template_renderer.render_template(template, jinja2_inputs) + + +def _calculate_rest_token( + *, + prompt_messages: list[PromptMessage], + model_instance: PreparedLLMProtocol, +) -> int: + rest_tokens = 2000 + runtime_model_schema = llm_utils.fetch_model_schema(model_instance=model_instance) + runtime_model_parameters = model_instance.parameters + + model_context_tokens = runtime_model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE) + if model_context_tokens: + curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages) + + max_tokens = 0 + for parameter_rule in runtime_model_schema.parameter_rules: + if parameter_rule.name == "max_tokens" or ( + parameter_rule.use_template and parameter_rule.use_template == "max_tokens" + ): + max_tokens = ( + runtime_model_parameters.get(parameter_rule.name) + or runtime_model_parameters.get(str(parameter_rule.use_template)) + or 0 + ) + + rest_tokens = model_context_tokens - max_tokens - curr_message_tokens + rest_tokens = max(rest_tokens, 0) + + return rest_tokens + + +def _handle_memory_chat_mode( + *, + memory: PromptMessageMemory | None, + memory_config: MemoryConfig | None, + model_instance: PreparedLLMProtocol, +) -> Sequence[PromptMessage]: + memory_messages: Sequence[PromptMessage] = [] + # Get messages from memory for chat model + if memory and memory_config: + rest_tokens = _calculate_rest_token( + prompt_messages=[], + model_instance=model_instance, + ) + memory_messages = memory.get_history_prompt_messages( + max_token_limit=rest_tokens, + message_limit=memory_config.window.size if memory_config.window.enabled else None, + ) + return memory_messages + + +def _handle_memory_completion_mode( + *, + memory: PromptMessageMemory | None, + memory_config: MemoryConfig | None, + model_instance: PreparedLLMProtocol, +) -> str: + memory_text = "" + # Get history text from memory for completion model + if memory and memory_config: + rest_tokens = _calculate_rest_token( + prompt_messages=[], + model_instance=model_instance, + ) + if not memory_config.role_prefix: + raise MemoryRolePrefixRequiredError("Memory role prefix is required for completion model.") + memory_text = llm_utils.fetch_memory_text( + memory=memory, + max_token_limit=rest_tokens, + message_limit=memory_config.window.size if memory_config.window.enabled else None, + human_prefix=memory_config.role_prefix.user, + ai_prefix=memory_config.role_prefix.assistant, + ) + return memory_text + + +def _handle_completion_template( + *, + template: LLMNodeCompletionModelPromptTemplate, + context: str, + jinja2_variables: Sequence[VariableSelector], + variable_pool: VariablePool, + jinja2_template_renderer: Jinja2TemplateRenderer | None = None, +) -> Sequence[PromptMessage]: + """Handle completion template processing outside of LLMNode class. + + Args: + template: The completion model prompt template + context: Context string + jinja2_variables: Variables for jinja2 template rendering + variable_pool: Variable pool for template conversion + + Returns: + Sequence of prompt messages + """ + prompt_messages = [] + if template.edition_type == "jinja2": + result_text = _render_jinja2_message( + template=template.jinja2_text or "", + jinja2_variables=jinja2_variables, + variable_pool=variable_pool, + jinja2_template_renderer=jinja2_template_renderer, + ) + else: + template_text = template.text.replace(llm_utils.CONTEXT_PLACEHOLDER, context) + result_text = variable_pool.convert_template(template_text).text + prompt_message = _combine_message_content_with_role( + contents=[TextPromptMessageContent(data=result_text)], role=PromptMessageRole.USER + ) + prompt_messages.append(prompt_message) + return prompt_messages diff --git a/api/graphon/nodes/llm/protocols.py b/api/graphon/nodes/llm/protocols.py new file mode 100644 index 0000000000..65bfd533d1 --- /dev/null +++ b/api/graphon/nodes/llm/protocols.py @@ -0,0 +1,21 @@ +from __future__ import annotations + +from typing import Any, Protocol + +from graphon.nodes.llm.runtime_protocols import PreparedLLMProtocol + + +class CredentialsProvider(Protocol): + """Port for loading runtime credentials for a provider/model pair.""" + + def fetch(self, provider_name: str, model_name: str) -> dict[str, Any]: + """Return credentials for the target provider/model or raise a domain error.""" + ... + + +class ModelFactory(Protocol): + """Port for creating prepared graph-facing LLM runtimes for execution.""" + + def init_model_instance(self, provider_name: str, model_name: str) -> PreparedLLMProtocol: + """Create a prepared LLM runtime that is ready for graph execution.""" + ... diff --git a/api/graphon/nodes/llm/runtime_protocols.py b/api/graphon/nodes/llm/runtime_protocols.py new file mode 100644 index 0000000000..dbe415d363 --- /dev/null +++ b/api/graphon/nodes/llm/runtime_protocols.py @@ -0,0 +1,77 @@ +from __future__ import annotations + +from collections.abc import Generator, Mapping, Sequence +from typing import Any, Protocol + +from graphon.file import File +from graphon.model_runtime.entities import LLMMode, PromptMessage +from graphon.model_runtime.entities.llm_entities import ( + LLMResult, + LLMResultChunk, + LLMResultChunkWithStructuredOutput, + LLMResultWithStructuredOutput, +) +from graphon.model_runtime.entities.message_entities import PromptMessageTool +from graphon.model_runtime.entities.model_entities import AIModelEntity + + +class PreparedLLMProtocol(Protocol): + """A graph-facing LLM runtime with provider-specific setup already applied.""" + + @property + def provider(self) -> str: ... + + @property + def model_name(self) -> str: ... + + @property + def parameters(self) -> Mapping[str, Any]: ... + + @parameters.setter + def parameters(self, value: Mapping[str, Any]) -> None: ... + + @property + def stop(self) -> Sequence[str] | None: ... + + def get_model_schema(self) -> AIModelEntity: ... + + def get_llm_num_tokens(self, prompt_messages: Sequence[PromptMessage]) -> int: ... + + def invoke_llm( + self, + *, + prompt_messages: Sequence[PromptMessage], + model_parameters: Mapping[str, Any], + tools: Sequence[PromptMessageTool] | None, + stop: Sequence[str] | None, + stream: bool, + ) -> LLMResult | Generator[LLMResultChunk, None, None]: ... + + def invoke_llm_with_structured_output( + self, + *, + prompt_messages: Sequence[PromptMessage], + json_schema: Mapping[str, Any], + model_parameters: Mapping[str, Any], + stop: Sequence[str] | None, + stream: bool, + ) -> LLMResultWithStructuredOutput | Generator[LLMResultChunkWithStructuredOutput, None, None]: ... + + def is_structured_output_parse_error(self, error: Exception) -> bool: ... + + +class PromptMessageSerializerProtocol(Protocol): + """Port for converting compiled prompt messages into persisted process data.""" + + def serialize( + self, + *, + model_mode: LLMMode, + prompt_messages: Sequence[PromptMessage], + ) -> Any: ... + + +class RetrieverAttachmentLoaderProtocol(Protocol): + """Port for resolving retriever segment attachments into graph file references.""" + + def load(self, *, segment_id: str) -> Sequence[File]: ... diff --git a/api/dify_graph/nodes/loop/__init__.py b/api/graphon/nodes/loop/__init__.py similarity index 100% rename from api/dify_graph/nodes/loop/__init__.py rename to api/graphon/nodes/loop/__init__.py diff --git a/api/dify_graph/nodes/loop/entities.py b/api/graphon/nodes/loop/entities.py similarity index 88% rename from api/dify_graph/nodes/loop/entities.py rename to api/graphon/nodes/loop/entities.py index f0bfad5a0f..e7362769e9 100644 --- a/api/dify_graph/nodes/loop/entities.py +++ b/api/graphon/nodes/loop/entities.py @@ -3,11 +3,11 @@ from typing import Annotated, Any, Literal from pydantic import AfterValidator, BaseModel, Field, field_validator -from dify_graph.entities.base_node_data import BaseNodeData -from dify_graph.enums import BuiltinNodeTypes, NodeType -from dify_graph.nodes.base import BaseLoopNodeData, BaseLoopState -from dify_graph.utils.condition.entities import Condition -from dify_graph.variables.types import SegmentType +from graphon.entities.base_node_data import BaseNodeData +from graphon.enums import BuiltinNodeTypes, NodeType +from graphon.nodes.base import BaseLoopNodeData, BaseLoopState +from graphon.utils.condition.entities import Condition +from graphon.variables.types import SegmentType _VALID_VAR_TYPE = frozenset( [ diff --git a/api/dify_graph/nodes/loop/loop_end_node.py b/api/graphon/nodes/loop/loop_end_node.py similarity index 60% rename from api/dify_graph/nodes/loop/loop_end_node.py rename to api/graphon/nodes/loop/loop_end_node.py index 0287708fb3..c0562b59c4 100644 --- a/api/dify_graph/nodes/loop/loop_end_node.py +++ b/api/graphon/nodes/loop/loop_end_node.py @@ -1,7 +1,7 @@ -from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus -from dify_graph.node_events import NodeRunResult -from dify_graph.nodes.base.node import Node -from dify_graph.nodes.loop.entities import LoopEndNodeData +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus +from graphon.node_events import NodeRunResult +from graphon.nodes.base.node import Node +from graphon.nodes.loop.entities import LoopEndNodeData class LoopEndNode(Node[LoopEndNodeData]): diff --git a/api/dify_graph/nodes/loop/loop_node.py b/api/graphon/nodes/loop/loop_node.py similarity index 91% rename from api/dify_graph/nodes/loop/loop_node.py rename to api/graphon/nodes/loop/loop_node.py index 3c546ffa23..d574e9f7ae 100644 --- a/api/dify_graph/nodes/loop/loop_node.py +++ b/api/graphon/nodes/loop/loop_node.py @@ -2,23 +2,24 @@ import contextlib import json import logging from collections.abc import Callable, Generator, Mapping, Sequence -from datetime import datetime +from datetime import UTC, datetime from typing import TYPE_CHECKING, Any, Literal, cast -from dify_graph.entities.graph_config import NodeConfigDictAdapter -from dify_graph.enums import ( +from graphon.entities.graph_config import NodeConfigDictAdapter +from graphon.enums import ( BuiltinNodeTypes, NodeExecutionType, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus, ) -from dify_graph.graph_events import ( +from graphon.graph_events import ( GraphNodeEventBase, + GraphRunAbortedEvent, GraphRunFailedEvent, NodeRunSucceededEvent, ) -from dify_graph.model_runtime.entities.llm_entities import LLMUsage -from dify_graph.node_events import ( +from graphon.model_runtime.entities.llm_entities import LLMUsage +from graphon.node_events import ( LoopFailedEvent, LoopNextEvent, LoopStartedEvent, @@ -27,18 +28,17 @@ from dify_graph.node_events import ( NodeRunResult, StreamCompletedEvent, ) -from dify_graph.nodes.base import LLMUsageTrackingMixin -from dify_graph.nodes.base.node import Node -from dify_graph.nodes.loop.entities import LoopCompletedReason, LoopNodeData, LoopVariableData -from dify_graph.utils.condition.processor import ConditionProcessor -from dify_graph.variables import Segment, SegmentType -from factories.variable_factory import TypeMismatchError, build_segment_with_type, segment_to_variable -from libs.datetime_utils import naive_utc_now +from graphon.nodes.base import LLMUsageTrackingMixin +from graphon.nodes.base.node import Node +from graphon.nodes.loop.entities import LoopCompletedReason, LoopNodeData, LoopVariableData +from graphon.utils.condition.processor import ConditionProcessor +from graphon.variables import Segment, SegmentType, TypeMismatchError, build_segment_with_type, segment_to_variable if TYPE_CHECKING: - from dify_graph.graph_engine import GraphEngine + from graphon.graph_engine import GraphEngine logger = logging.getLogger(__name__) +_DEFAULT_CHILD_ABORT_REASON = "child graph aborted" class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]): @@ -91,7 +91,7 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]): loop_variable_selectors[loop_variable.label] = variable_selector inputs[loop_variable.label] = processed_segment.value - start_at = naive_utc_now() + start_at = datetime.now(UTC).replace(tzinfo=None) condition_processor = ConditionProcessor() loop_duration_map: dict[str, float] = {} @@ -124,10 +124,13 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]): self._clear_loop_subgraph_variables(loop_node_ids) graph_engine = self._create_graph_engine(start_at=start_at, root_node_id=root_node_id) - loop_start_time = naive_utc_now() - reach_break_node = yield from self._run_single_loop(graph_engine=graph_engine, current_index=i) + loop_start_time = datetime.now(UTC).replace(tzinfo=None) + try: + reach_break_node = yield from self._run_single_loop(graph_engine=graph_engine, current_index=i) + finally: + loop_usage = self._merge_usage(loop_usage, graph_engine.graph_runtime_state.llm_usage) # Track loop duration - loop_duration_map[str(i)] = (naive_utc_now() - loop_start_time).total_seconds() + loop_duration_map[str(i)] = (datetime.now(UTC).replace(tzinfo=None) - loop_start_time).total_seconds() # Accumulate outputs from the sub-graph's response nodes for key, value in graph_engine.graph_runtime_state.outputs.items(): @@ -142,9 +145,6 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]): # For other outputs, just update self.graph_runtime_state.set_output(key, value) - # Accumulate usage from the sub-graph execution - loop_usage = self._merge_usage(loop_usage, graph_engine.graph_runtime_state.llm_usage) - # Collect loop variable values after iteration single_loop_variable = {} for key, selector in loop_variable_selectors.items(): @@ -256,6 +256,8 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]): yield event if isinstance(event, NodeRunSucceededEvent) and event.node_type == BuiltinNodeTypes.LOOP_END: reach_break_node = True + if isinstance(event, GraphRunAbortedEvent): + raise RuntimeError(event.reason or _DEFAULT_CHILD_ABORT_REASON) if isinstance(event, GraphRunFailedEvent): raise Exception(event.error) @@ -409,8 +411,7 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]): return build_segment_with_type(var_type, value) def _create_graph_engine(self, start_at: datetime, root_node_id: str): - from dify_graph.entities import GraphInitParams - from dify_graph.runtime import GraphRuntimeState + from graphon.entities import GraphInitParams # Create GraphInitParams for child graph execution. graph_init_params = GraphInitParams( @@ -420,16 +421,8 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]): call_depth=self.workflow_call_depth, ) - # Create a new GraphRuntimeState for this iteration - graph_runtime_state_copy = GraphRuntimeState( - variable_pool=self.graph_runtime_state.variable_pool, - start_at=start_at.timestamp(), - ) - return self.graph_runtime_state.create_child_engine( workflow_id=self.workflow_id, graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state_copy, - graph_config=self.graph_config, root_node_id=root_node_id, ) diff --git a/api/dify_graph/nodes/loop/loop_start_node.py b/api/graphon/nodes/loop/loop_start_node.py similarity index 60% rename from api/dify_graph/nodes/loop/loop_start_node.py rename to api/graphon/nodes/loop/loop_start_node.py index e171b4df2f..2b17054ae2 100644 --- a/api/dify_graph/nodes/loop/loop_start_node.py +++ b/api/graphon/nodes/loop/loop_start_node.py @@ -1,7 +1,7 @@ -from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus -from dify_graph.node_events import NodeRunResult -from dify_graph.nodes.base.node import Node -from dify_graph.nodes.loop.entities import LoopStartNodeData +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus +from graphon.node_events import NodeRunResult +from graphon.nodes.base.node import Node +from graphon.nodes.loop.entities import LoopStartNodeData class LoopStartNode(Node[LoopStartNodeData]): diff --git a/api/dify_graph/nodes/parameter_extractor/__init__.py b/api/graphon/nodes/parameter_extractor/__init__.py similarity index 100% rename from api/dify_graph/nodes/parameter_extractor/__init__.py rename to api/graphon/nodes/parameter_extractor/__init__.py diff --git a/api/dify_graph/nodes/parameter_extractor/entities.py b/api/graphon/nodes/parameter_extractor/entities.py similarity index 93% rename from api/dify_graph/nodes/parameter_extractor/entities.py rename to api/graphon/nodes/parameter_extractor/entities.py index 2fb042c16c..8fda1b9e79 100644 --- a/api/dify_graph/nodes/parameter_extractor/entities.py +++ b/api/graphon/nodes/parameter_extractor/entities.py @@ -7,11 +7,11 @@ from pydantic import ( field_validator, ) -from core.prompt.entities.advanced_prompt_entities import MemoryConfig -from dify_graph.entities.base_node_data import BaseNodeData -from dify_graph.enums import BuiltinNodeTypes, NodeType -from dify_graph.nodes.llm.entities import ModelConfig, VisionConfig -from dify_graph.variables.types import SegmentType +from graphon.entities.base_node_data import BaseNodeData +from graphon.enums import BuiltinNodeTypes, NodeType +from graphon.nodes.llm.entities import ModelConfig, VisionConfig +from graphon.prompt_entities import MemoryConfig +from graphon.variables.types import SegmentType _OLD_BOOL_TYPE_NAME = "bool" _OLD_SELECT_TYPE_NAME = "select" diff --git a/api/dify_graph/nodes/parameter_extractor/exc.py b/api/graphon/nodes/parameter_extractor/exc.py similarity index 97% rename from api/dify_graph/nodes/parameter_extractor/exc.py rename to api/graphon/nodes/parameter_extractor/exc.py index c25b809d1c..faa90313c1 100644 --- a/api/dify_graph/nodes/parameter_extractor/exc.py +++ b/api/graphon/nodes/parameter_extractor/exc.py @@ -1,6 +1,6 @@ from typing import Any -from dify_graph.variables.types import SegmentType +from graphon.variables.types import SegmentType class ParameterExtractorNodeError(ValueError): diff --git a/api/dify_graph/nodes/parameter_extractor/parameter_extractor_node.py b/api/graphon/nodes/parameter_extractor/parameter_extractor_node.py similarity index 83% rename from api/dify_graph/nodes/parameter_extractor/parameter_extractor_node.py rename to api/graphon/nodes/parameter_extractor/parameter_extractor_node.py index 3913a27697..25379e325c 100644 --- a/api/dify_graph/nodes/parameter_extractor/parameter_extractor_node.py +++ b/api/graphon/nodes/parameter_extractor/parameter_extractor_node.py @@ -5,21 +5,16 @@ import uuid from collections.abc import Mapping, Sequence from typing import TYPE_CHECKING, Any, cast -from core.model_manager import ModelInstance -from core.prompt.advanced_prompt_transform import AdvancedPromptTransform -from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate -from core.prompt.simple_prompt_transform import ModelMode -from core.prompt.utils.prompt_message_util import PromptMessageUtil -from dify_graph.entities.graph_config import NodeConfigDict -from dify_graph.enums import ( +from graphon.entities.graph_config import NodeConfigDict +from graphon.enums import ( BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus, ) -from dify_graph.file import File -from dify_graph.model_runtime.entities import ImagePromptMessageContent -from dify_graph.model_runtime.entities.llm_entities import LLMUsage -from dify_graph.model_runtime.entities.message_entities import ( +from graphon.file import File +from graphon.model_runtime.entities import ImagePromptMessageContent, LLMMode +from graphon.model_runtime.entities.llm_entities import LLMResult, LLMUsage +from graphon.model_runtime.entities.message_entities import ( AssistantPromptMessage, PromptMessage, PromptMessageRole, @@ -27,17 +22,18 @@ from dify_graph.model_runtime.entities.message_entities import ( ToolPromptMessage, UserPromptMessage, ) -from dify_graph.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey -from dify_graph.model_runtime.memory import PromptMessageMemory -from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from dify_graph.model_runtime.utils.encoders import jsonable_encoder -from dify_graph.node_events import NodeRunResult -from dify_graph.nodes.base import variable_template_parser -from dify_graph.nodes.base.node import Node -from dify_graph.nodes.llm import llm_utils -from dify_graph.runtime import VariablePool -from dify_graph.variables.types import ArrayValidation, SegmentType -from factories.variable_factory import build_segment_with_type +from graphon.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey, ModelType +from graphon.model_runtime.memory import PromptMessageMemory +from graphon.model_runtime.utils.encoders import jsonable_encoder +from graphon.node_events import NodeRunResult +from graphon.nodes.base import variable_template_parser +from graphon.nodes.base.node import Node +from graphon.nodes.llm import LLMNode, llm_utils +from graphon.nodes.llm.entities import LLMNodeChatModelMessage, LLMNodeCompletionModelPromptTemplate +from graphon.nodes.llm.runtime_protocols import PreparedLLMProtocol, PromptMessageSerializerProtocol +from graphon.runtime import VariablePool +from graphon.variables import build_segment_with_type +from graphon.variables.types import ArrayValidation, SegmentType from .entities import ParameterExtractorNodeData from .exc import ( @@ -65,9 +61,8 @@ from .prompts import ( logger = logging.getLogger(__name__) if TYPE_CHECKING: - from dify_graph.entities import GraphInitParams - from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory - from dify_graph.runtime import GraphRuntimeState + from graphon.entities import GraphInitParams + from graphon.runtime import GraphRuntimeState def extract_json(text): @@ -99,9 +94,8 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): node_type = BuiltinNodeTypes.PARAMETER_EXTRACTOR - _model_instance: ModelInstance - _credentials_provider: "CredentialsProvider" - _model_factory: "ModelFactory" + _model_instance: PreparedLLMProtocol + _prompt_message_serializer: PromptMessageSerializerProtocol _memory: PromptMessageMemory | None def __init__( @@ -111,10 +105,11 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): graph_init_params: "GraphInitParams", graph_runtime_state: "GraphRuntimeState", *, - credentials_provider: "CredentialsProvider", - model_factory: "ModelFactory", - model_instance: ModelInstance, + credentials_provider: object | None = None, + model_factory: object | None = None, + model_instance: PreparedLLMProtocol, memory: PromptMessageMemory | None = None, + prompt_message_serializer: PromptMessageSerializerProtocol, ) -> None: super().__init__( id=id, @@ -122,9 +117,9 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, ) - self._credentials_provider = credentials_provider - self._model_factory = model_factory + _ = credentials_provider, model_factory self._model_instance = model_instance + self._prompt_message_serializer = prompt_message_serializer self._memory = memory @classmethod @@ -164,13 +159,16 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): ) model_instance = self._model_instance - if not isinstance(model_instance.model_type_instance, LargeLanguageModel): - raise InvalidModelTypeError("Model is not a Large Language Model") - + # Resolve variable references in string-typed completion params + model_instance.parameters = llm_utils.resolve_completion_params_variables( + model_instance.parameters, variable_pool + ) try: model_schema = llm_utils.fetch_model_schema(model_instance=model_instance) except ValueError as exc: raise ModelSchemaNotFoundError("Model schema not found") from exc + if model_schema.model_type != ModelType.LLM: + raise InvalidModelTypeError("Model is not a Large Language Model") memory = self._memory if ( @@ -210,8 +208,9 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): process_data = { "model_mode": node_data.model.mode, - "prompts": PromptMessageUtil.prompt_messages_to_prompt_for_saving( - model_mode=node_data.model.mode, prompt_messages=prompt_messages + "prompts": self._prompt_message_serializer.serialize( + model_mode=node_data.model.mode, + prompt_messages=prompt_messages, ), "usage": None, "function": {} if not prompt_message_tools else jsonable_encoder(prompt_message_tools[0]), @@ -287,18 +286,20 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): def _invoke( self, - model_instance: ModelInstance, + model_instance: PreparedLLMProtocol, prompt_messages: list[PromptMessage], tools: list[PromptMessageTool], - stop: Sequence[str], + stop: Sequence[str] | None, ) -> tuple[str, LLMUsage, AssistantPromptMessage.ToolCall | None]: - invoke_result = model_instance.invoke_llm( - prompt_messages=prompt_messages, - model_parameters=dict(model_instance.parameters), - tools=tools, - stop=list(stop), - stream=False, - user=self.require_dify_context().user_id, + invoke_result = cast( + LLMResult, + model_instance.invoke_llm( + prompt_messages=prompt_messages, + model_parameters=dict(model_instance.parameters), + tools=tools or None, + stop=stop, + stream=False, + ), ) # handle invoke result @@ -317,7 +318,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): node_data: ParameterExtractorNodeData, query: str, variable_pool: VariablePool, - model_instance: ModelInstance, + model_instance: PreparedLLMProtocol, memory: PromptMessageMemory | None, files: Sequence[File], vision_detail: ImagePromptMessageContent.DETAIL | None = None, @@ -329,7 +330,6 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): content=query, structure=json.dumps(node_data.get_parameter_json_schema()) ) - prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True) rest_token = self._calculate_rest_token( node_data=node_data, query=query, @@ -340,15 +340,11 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): prompt_template = self._get_function_calling_prompt_template( node_data, query, variable_pool, memory, rest_token ) - prompt_messages = prompt_transform.get_prompt( - prompt_template=prompt_template, - inputs={}, - query="", - files=files, - context="", - memory_config=node_data.memory, - memory=None, + prompt_messages = self._compile_prompt_messages( model_instance=model_instance, + prompt_template=prompt_template, + files=files, + vision_enabled=node_data.vision.enabled, image_detail_config=vision_detail, ) @@ -405,7 +401,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): data: ParameterExtractorNodeData, query: str, variable_pool: VariablePool, - model_instance: ModelInstance, + model_instance: PreparedLLMProtocol, memory: PromptMessageMemory | None, files: Sequence[File], vision_detail: ImagePromptMessageContent.DETAIL | None = None, @@ -413,9 +409,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): """ Generate prompt engineering prompt. """ - model_mode = ModelMode(data.model.mode) - - if model_mode == ModelMode.COMPLETION: + if data.model.mode == LLMMode.COMPLETION: return self._generate_prompt_engineering_completion_prompt( node_data=data, query=query, @@ -425,7 +419,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): files=files, vision_detail=vision_detail, ) - elif model_mode == ModelMode.CHAT: + if data.model.mode == LLMMode.CHAT: return self._generate_prompt_engineering_chat_prompt( node_data=data, query=query, @@ -435,15 +429,14 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): files=files, vision_detail=vision_detail, ) - else: - raise InvalidModelModeError(f"Invalid model mode: {model_mode}") + raise InvalidModelModeError(f"Invalid model mode: {data.model.mode}") def _generate_prompt_engineering_completion_prompt( self, node_data: ParameterExtractorNodeData, query: str, variable_pool: VariablePool, - model_instance: ModelInstance, + model_instance: PreparedLLMProtocol, memory: PromptMessageMemory | None, files: Sequence[File], vision_detail: ImagePromptMessageContent.DETAIL | None = None, @@ -451,7 +444,6 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): """ Generate completion prompt. """ - prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True) rest_token = self._calculate_rest_token( node_data=node_data, query=query, @@ -462,27 +454,20 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): prompt_template = self._get_prompt_engineering_prompt_template( node_data=node_data, query=query, variable_pool=variable_pool, memory=memory, max_token_limit=rest_token ) - prompt_messages = prompt_transform.get_prompt( - prompt_template=prompt_template, - inputs={"structure": json.dumps(node_data.get_parameter_json_schema())}, - query="", - files=files, - context="", - memory_config=node_data.memory, - # AdvancedPromptTransform is still typed against TokenBufferMemory. - memory=cast(Any, memory), + return self._compile_prompt_messages( model_instance=model_instance, + prompt_template=prompt_template, + files=files, + vision_enabled=node_data.vision.enabled, image_detail_config=vision_detail, ) - return prompt_messages - def _generate_prompt_engineering_chat_prompt( self, node_data: ParameterExtractorNodeData, query: str, variable_pool: VariablePool, - model_instance: ModelInstance, + model_instance: PreparedLLMProtocol, memory: PromptMessageMemory | None, files: Sequence[File], vision_detail: ImagePromptMessageContent.DETAIL | None = None, @@ -490,7 +475,6 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): """ Generate chat prompt. """ - prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True) rest_token = self._calculate_rest_token( node_data=node_data, query=query, @@ -508,15 +492,11 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): max_token_limit=rest_token, ) - prompt_messages = prompt_transform.get_prompt( - prompt_template=prompt_template, - inputs={}, - query="", - files=files, - context="", - memory_config=node_data.memory, - memory=None, + prompt_messages = self._compile_prompt_messages( model_instance=model_instance, + prompt_template=prompt_template, + files=files, + vision_enabled=node_data.vision.enabled, image_detail_config=vision_detail, ) @@ -717,8 +697,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): variable_pool: VariablePool, memory: PromptMessageMemory | None, max_token_limit: int = 2000, - ) -> list[ChatModelMessage]: - model_mode = ModelMode(node_data.model.mode) + ) -> list[LLMNodeChatModelMessage]: input_text = query memory_str = "" instruction = variable_pool.convert_template(node_data.instruction or "").text @@ -727,15 +706,14 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): memory_str = llm_utils.fetch_memory_text( memory=memory, max_token_limit=max_token_limit, message_limit=node_data.memory.window.size ) - if model_mode == ModelMode.CHAT: - system_prompt_messages = ChatModelMessage( + if node_data.model.mode == LLMMode.CHAT: + system_prompt_messages = LLMNodeChatModelMessage( role=PromptMessageRole.SYSTEM, text=FUNCTION_CALLING_EXTRACTOR_SYSTEM_PROMPT.format(histories=memory_str, instruction=instruction), ) - user_prompt_message = ChatModelMessage(role=PromptMessageRole.USER, text=input_text) + user_prompt_message = LLMNodeChatModelMessage(role=PromptMessageRole.USER, text=input_text) return [system_prompt_messages, user_prompt_message] - else: - raise InvalidModelModeError(f"Model mode {model_mode} not support.") + raise InvalidModelModeError(f"Model mode {node_data.model.mode} not support.") def _get_prompt_engineering_prompt_template( self, @@ -744,8 +722,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): variable_pool: VariablePool, memory: PromptMessageMemory | None, max_token_limit: int = 2000, - ): - model_mode = ModelMode(node_data.model.mode) + ) -> list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate: input_text = query memory_str = "" instruction = variable_pool.convert_template(node_data.instruction or "").text @@ -754,64 +731,54 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): memory_str = llm_utils.fetch_memory_text( memory=memory, max_token_limit=max_token_limit, message_limit=node_data.memory.window.size ) - if model_mode == ModelMode.CHAT: - system_prompt_messages = ChatModelMessage( + if node_data.model.mode == LLMMode.CHAT: + system_prompt_messages = LLMNodeChatModelMessage( role=PromptMessageRole.SYSTEM, text=CHAT_GENERATE_JSON_PROMPT.format(histories=memory_str, instructions=instruction), ) - user_prompt_message = ChatModelMessage(role=PromptMessageRole.USER, text=input_text) + user_prompt_message = LLMNodeChatModelMessage(role=PromptMessageRole.USER, text=input_text) return [system_prompt_messages, user_prompt_message] - elif model_mode == ModelMode.COMPLETION: - return CompletionModelPromptTemplate( + if node_data.model.mode == LLMMode.COMPLETION: + return LLMNodeCompletionModelPromptTemplate( text=COMPLETION_GENERATE_JSON_PROMPT.format( histories=memory_str, text=input_text, instruction=instruction ) .replace("{γγγ", "") .replace("}γγγ", "") + .replace("{ structure }", json.dumps(node_data.get_parameter_json_schema())), ) - else: - raise InvalidModelModeError(f"Model mode {model_mode} not support.") + raise InvalidModelModeError(f"Model mode {node_data.model.mode} not support.") def _calculate_rest_token( self, node_data: ParameterExtractorNodeData, query: str, variable_pool: VariablePool, - model_instance: ModelInstance, + model_instance: PreparedLLMProtocol, context: str | None, ) -> int: try: model_schema = llm_utils.fetch_model_schema(model_instance=model_instance) except ValueError as exc: raise ModelSchemaNotFoundError("Model schema not found") from exc - prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True) + prompt_template: list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate if set(model_schema.features or []) & {ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL}: prompt_template = self._get_function_calling_prompt_template(node_data, query, variable_pool, None, 2000) else: prompt_template = self._get_prompt_engineering_prompt_template(node_data, query, variable_pool, None, 2000) - prompt_messages = prompt_transform.get_prompt( - prompt_template=prompt_template, - inputs={}, - query="", - files=[], - context=context, - memory_config=node_data.memory, - memory=None, + prompt_messages = self._compile_prompt_messages( model_instance=model_instance, + prompt_template=prompt_template, + files=[], + vision_enabled=False, + context=context, ) rest_tokens = 2000 - model_context_tokens = model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE) if model_context_tokens: - model_type_instance = cast(LargeLanguageModel, model_instance.model_type_instance) - curr_message_tokens = ( - model_type_instance.get_num_tokens( - model_instance.model_name, model_instance.credentials, prompt_messages - ) - + 1000 - ) # add 1000 to ensure tool call messages + curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages) + 1000 max_tokens = 0 for parameter_rule in model_schema.parameter_rules: @@ -828,8 +795,34 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): return rest_tokens + def _compile_prompt_messages( + self, + *, + model_instance: PreparedLLMProtocol, + prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate, + files: Sequence[File], + vision_enabled: bool, + context: str | None = "", + image_detail_config: ImagePromptMessageContent.DETAIL | None = None, + ) -> list[PromptMessage]: + prompt_messages, _ = LLMNode.fetch_prompt_messages( + sys_query="", + sys_files=files, + context=context or "", + memory=None, + model_instance=model_instance, + prompt_template=prompt_template, + stop=model_instance.stop, + memory_config=None, + vision_enabled=vision_enabled, + vision_detail=image_detail_config or ImagePromptMessageContent.DETAIL.HIGH, + variable_pool=self.graph_runtime_state.variable_pool, + jinja2_variables=[], + ) + return list(prompt_messages) + @property - def model_instance(self) -> ModelInstance: + def model_instance(self) -> PreparedLLMProtocol: return self._model_instance @classmethod diff --git a/api/dify_graph/nodes/parameter_extractor/prompts.py b/api/graphon/nodes/parameter_extractor/prompts.py similarity index 100% rename from api/dify_graph/nodes/parameter_extractor/prompts.py rename to api/graphon/nodes/parameter_extractor/prompts.py diff --git a/api/dify_graph/nodes/protocols.py b/api/graphon/nodes/protocols.py similarity index 81% rename from api/dify_graph/nodes/protocols.py rename to api/graphon/nodes/protocols.py index 62d3bcdca1..4b050c113c 100644 --- a/api/dify_graph/nodes/protocols.py +++ b/api/graphon/nodes/protocols.py @@ -1,10 +1,9 @@ -from collections.abc import Generator +from collections.abc import Generator, Mapping from typing import Any, Protocol import httpx -from dify_graph.file import File -from dify_graph.file.models import ToolFile +from graphon.file import File class HttpClientProtocol(Protocol): @@ -35,12 +34,13 @@ class ToolFileManagerProtocol(Protocol): def create_file_by_raw( self, *, - user_id: str, - tenant_id: str, - conversation_id: str | None, file_binary: bytes, mimetype: str, filename: str | None = None, ) -> Any: ... - def get_file_generator_by_tool_file_id(self, tool_file_id: str) -> tuple[Generator | None, ToolFile | None]: ... + def get_file_generator_by_tool_file_id(self, tool_file_id: str) -> tuple[Generator | None, File | None]: ... + + +class FileReferenceFactoryProtocol(Protocol): + def build_from_mapping(self, *, mapping: Mapping[str, Any]) -> File: ... diff --git a/api/dify_graph/nodes/question_classifier/__init__.py b/api/graphon/nodes/question_classifier/__init__.py similarity index 100% rename from api/dify_graph/nodes/question_classifier/__init__.py rename to api/graphon/nodes/question_classifier/__init__.py diff --git a/api/dify_graph/nodes/question_classifier/entities.py b/api/graphon/nodes/question_classifier/entities.py similarity index 76% rename from api/dify_graph/nodes/question_classifier/entities.py rename to api/graphon/nodes/question_classifier/entities.py index 0c1601d439..8d5f117315 100644 --- a/api/dify_graph/nodes/question_classifier/entities.py +++ b/api/graphon/nodes/question_classifier/entities.py @@ -1,9 +1,9 @@ from pydantic import BaseModel, Field -from core.prompt.entities.advanced_prompt_entities import MemoryConfig -from dify_graph.entities.base_node_data import BaseNodeData -from dify_graph.enums import BuiltinNodeTypes, NodeType -from dify_graph.nodes.llm import ModelConfig, VisionConfig +from graphon.entities.base_node_data import BaseNodeData +from graphon.enums import BuiltinNodeTypes, NodeType +from graphon.nodes.llm import ModelConfig, VisionConfig +from graphon.prompt_entities import MemoryConfig class ClassConfig(BaseModel): diff --git a/api/dify_graph/nodes/question_classifier/exc.py b/api/graphon/nodes/question_classifier/exc.py similarity index 100% rename from api/dify_graph/nodes/question_classifier/exc.py rename to api/graphon/nodes/question_classifier/exc.py diff --git a/api/dify_graph/nodes/question_classifier/question_classifier_node.py b/api/graphon/nodes/question_classifier/question_classifier_node.py similarity index 85% rename from api/dify_graph/nodes/question_classifier/question_classifier_node.py rename to api/graphon/nodes/question_classifier/question_classifier_node.py index 59d0a2a4d8..a30ffbb149 100644 --- a/api/dify_graph/nodes/question_classifier/question_classifier_node.py +++ b/api/graphon/nodes/question_classifier/question_classifier_node.py @@ -3,34 +3,32 @@ import re from collections.abc import Mapping, Sequence from typing import TYPE_CHECKING, Any -from core.model_manager import ModelInstance -from core.prompt.simple_prompt_transform import ModelMode -from core.prompt.utils.prompt_message_util import PromptMessageUtil -from dify_graph.entities import GraphInitParams -from dify_graph.entities.graph_config import NodeConfigDict -from dify_graph.enums import ( +from graphon.entities import GraphInitParams +from graphon.entities.graph_config import NodeConfigDict +from graphon.enums import ( BuiltinNodeTypes, NodeExecutionType, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus, ) -from dify_graph.model_runtime.entities import LLMUsage, ModelPropertyKey, PromptMessageRole -from dify_graph.model_runtime.memory import PromptMessageMemory -from dify_graph.model_runtime.utils.encoders import jsonable_encoder -from dify_graph.node_events import ModelInvokeCompletedEvent, NodeRunResult -from dify_graph.nodes.base.entities import VariableSelector -from dify_graph.nodes.base.node import Node -from dify_graph.nodes.base.variable_template_parser import VariableTemplateParser -from dify_graph.nodes.llm import ( +from graphon.model_runtime.entities import LLMMode, LLMUsage, ModelPropertyKey, PromptMessageRole +from graphon.model_runtime.memory import PromptMessageMemory +from graphon.model_runtime.utils.encoders import jsonable_encoder +from graphon.node_events import ModelInvokeCompletedEvent, NodeRunResult +from graphon.nodes.base.entities import VariableSelector +from graphon.nodes.base.node import Node +from graphon.nodes.base.variable_template_parser import VariableTemplateParser +from graphon.nodes.llm import ( LLMNode, LLMNodeChatModelMessage, LLMNodeCompletionModelPromptTemplate, llm_utils, ) -from dify_graph.nodes.llm.file_saver import FileSaverImpl, LLMFileSaver -from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory, TemplateRenderer -from dify_graph.nodes.protocols import HttpClientProtocol -from libs.json_in_md_parser import parse_and_check_json_markdown +from graphon.nodes.llm.file_saver import LLMFileSaver +from graphon.nodes.llm.runtime_protocols import PreparedLLMProtocol, PromptMessageSerializerProtocol +from graphon.nodes.protocols import HttpClientProtocol +from graphon.template_rendering import Jinja2TemplateRenderer +from graphon.utils.json_in_md_parser import parse_and_check_json_markdown from .entities import QuestionClassifierNodeData from .exc import InvalidModelTypeError @@ -45,8 +43,14 @@ from .template_prompts import ( ) if TYPE_CHECKING: - from dify_graph.file.models import File - from dify_graph.runtime import GraphRuntimeState + from graphon.file.models import File + from graphon.runtime import GraphRuntimeState + + +class _PassthroughPromptMessageSerializer: + def serialize(self, *, model_mode: Any, prompt_messages: Sequence[Any]) -> Any: + _ = model_mode + return list(prompt_messages) class QuestionClassifierNode(Node[QuestionClassifierNodeData]): @@ -55,11 +59,10 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]): _file_outputs: list["File"] _llm_file_saver: LLMFileSaver - _credentials_provider: "CredentialsProvider" - _model_factory: "ModelFactory" - _model_instance: ModelInstance + _prompt_message_serializer: PromptMessageSerializerProtocol + _model_instance: PreparedLLMProtocol _memory: PromptMessageMemory | None - _template_renderer: TemplateRenderer + _template_renderer: Jinja2TemplateRenderer def __init__( self, @@ -68,13 +71,14 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]): graph_init_params: "GraphInitParams", graph_runtime_state: "GraphRuntimeState", *, - credentials_provider: "CredentialsProvider", - model_factory: "ModelFactory", - model_instance: ModelInstance, + credentials_provider: object | None = None, + model_factory: object | None = None, + model_instance: PreparedLLMProtocol, http_client: HttpClientProtocol, - template_renderer: TemplateRenderer, + template_renderer: Jinja2TemplateRenderer, memory: PromptMessageMemory | None = None, - llm_file_saver: LLMFileSaver | None = None, + llm_file_saver: LLMFileSaver, + prompt_message_serializer: PromptMessageSerializerProtocol | None = None, ): super().__init__( id=id, @@ -85,20 +89,13 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]): # LLM file outputs, used for MultiModal outputs. self._file_outputs = [] - self._credentials_provider = credentials_provider - self._model_factory = model_factory + _ = credentials_provider, model_factory, http_client self._model_instance = model_instance self._memory = memory self._template_renderer = template_renderer - if llm_file_saver is None: - dify_ctx = self.require_dify_context() - llm_file_saver = FileSaverImpl( - user_id=dify_ctx.user_id, - tenant_id=dify_ctx.tenant_id, - http_client=http_client, - ) self._llm_file_saver = llm_file_saver + self._prompt_message_serializer = prompt_message_serializer or _PassthroughPromptMessageSerializer() @classmethod def version(cls): @@ -114,6 +111,10 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]): variables = {"query": query} # fetch model instance model_instance = self._model_instance + # Resolve variable references in string-typed completion params + model_instance.parameters = llm_utils.resolve_completion_params_variables( + model_instance.parameters, variable_pool + ) memory = self._memory # fetch instruction node_data.instruction = node_data.instruction or "" @@ -169,7 +170,6 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]): model_instance=model_instance, prompt_messages=prompt_messages, stop=stop, - user_id=self.require_dify_context().user_id, structured_output_enabled=False, structured_output=None, file_saver=self._llm_file_saver, @@ -205,7 +205,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]): category_id = category_id_result process_data = { "model_mode": node_data.model.mode, - "prompts": PromptMessageUtil.prompt_messages_to_prompt_for_saving( + "prompts": self._prompt_message_serializer.serialize( model_mode=node_data.model.mode, prompt_messages=prompt_messages ), "usage": jsonable_encoder(usage), @@ -247,7 +247,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]): ) @property - def model_instance(self) -> ModelInstance: + def model_instance(self) -> PreparedLLMProtocol: return self._model_instance @classmethod @@ -285,7 +285,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]): self, node_data: QuestionClassifierNodeData, query: str, - model_instance: ModelInstance, + model_instance: PreparedLLMProtocol, context: str | None, ) -> int: model_schema = llm_utils.fetch_model_schema(model_instance=model_instance) @@ -295,7 +295,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]): prompt_template=prompt_template, sys_query="", sys_files=[], - context=context, + context=context or "", memory=None, model_instance=model_instance, stop=model_instance.stop, @@ -334,7 +334,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]): memory: PromptMessageMemory | None, max_token_limit: int = 2000, ): - model_mode = ModelMode(node_data.model.mode) + model_mode = LLMMode(node_data.model.mode) classes = node_data.classes categories = [] for class_ in classes: @@ -350,7 +350,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]): message_limit=node_data.memory.window.size if node_data.memory and node_data.memory.window else None, ) prompt_messages: list[LLMNodeChatModelMessage] = [] - if model_mode == ModelMode.CHAT: + if model_mode == LLMMode.CHAT: system_prompt_messages = LLMNodeChatModelMessage( role=PromptMessageRole.SYSTEM, text=QUESTION_CLASSIFIER_SYSTEM_PROMPT.format(histories=memory_str) ) @@ -381,7 +381,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]): ) prompt_messages.append(user_prompt_message_3) return prompt_messages - elif model_mode == ModelMode.COMPLETION: + elif model_mode == LLMMode.COMPLETION: return LLMNodeCompletionModelPromptTemplate( text=QUESTION_CLASSIFIER_COMPLETION_PROMPT.format( histories=memory_str, diff --git a/api/dify_graph/nodes/question_classifier/template_prompts.py b/api/graphon/nodes/question_classifier/template_prompts.py similarity index 100% rename from api/dify_graph/nodes/question_classifier/template_prompts.py rename to api/graphon/nodes/question_classifier/template_prompts.py diff --git a/api/graphon/nodes/runtime.py b/api/graphon/nodes/runtime.py new file mode 100644 index 0000000000..650299898c --- /dev/null +++ b/api/graphon/nodes/runtime.py @@ -0,0 +1,106 @@ +from __future__ import annotations + +from collections.abc import Generator, Mapping, Sequence +from datetime import datetime +from typing import TYPE_CHECKING, Any, Protocol + +from graphon.model_runtime.entities.llm_entities import LLMUsage +from graphon.nodes.tool_runtime_entities import ( + ToolRuntimeHandle, + ToolRuntimeMessage, + ToolRuntimeParameter, +) + +if TYPE_CHECKING: + from graphon.nodes.human_input.entities import HumanInputNodeData + from graphon.nodes.human_input.enums import HumanInputFormStatus + from graphon.nodes.tool.entities import ToolNodeData + from graphon.runtime import VariablePool + + +class ToolNodeRuntimeProtocol(Protocol): + """Workflow-layer adapter owned by `core.workflow` and consumed by `graphon`. + + The graph package depends only on these DTOs and lets the workflow layer + translate between graph-owned abstractions and `core.tools` internals. + """ + + def get_runtime( + self, + *, + node_id: str, + node_data: ToolNodeData, + variable_pool: VariablePool | None, + ) -> ToolRuntimeHandle: ... + + def get_runtime_parameters( + self, + *, + tool_runtime: ToolRuntimeHandle, + ) -> Sequence[ToolRuntimeParameter]: ... + + def invoke( + self, + *, + tool_runtime: ToolRuntimeHandle, + tool_parameters: Mapping[str, Any], + workflow_call_depth: int, + provider_name: str, + ) -> Generator[ToolRuntimeMessage, None, None]: ... + + def get_usage( + self, + *, + tool_runtime: ToolRuntimeHandle, + ) -> LLMUsage: ... + + def build_file_reference(self, *, mapping: Mapping[str, Any]) -> Any: ... + + def resolve_provider_icons( + self, + *, + provider_name: str, + default_icon: str | None = None, + ) -> tuple[str | Mapping[str, str] | None, str | Mapping[str, str] | None]: ... + + +class HumanInputNodeRuntimeProtocol(Protocol): + """Workflow-layer adapter for human-input runtime persistence and delivery.""" + + def get_form( + self, + *, + node_id: str, + ) -> HumanInputFormStateProtocol | None: ... + + def create_form( + self, + *, + node_id: str, + node_data: HumanInputNodeData, + rendered_content: str, + resolved_default_values: Mapping[str, Any], + ) -> HumanInputFormStateProtocol: ... + + +class HumanInputFormStateProtocol(Protocol): + @property + def id(self) -> str: ... + + @property + def rendered_content(self) -> str: ... + + @property + def selected_action_id(self) -> str | None: ... + + @property + def submitted_data(self) -> Mapping[str, Any] | None: ... + + @property + def submitted(self) -> bool: ... + + @property + def status(self) -> HumanInputFormStatus: ... + + @property + def expiration_time(self) -> datetime: ... diff --git a/api/dify_graph/nodes/start/__init__.py b/api/graphon/nodes/start/__init__.py similarity index 100% rename from api/dify_graph/nodes/start/__init__.py rename to api/graphon/nodes/start/__init__.py diff --git a/api/dify_graph/nodes/start/entities.py b/api/graphon/nodes/start/entities.py similarity index 58% rename from api/dify_graph/nodes/start/entities.py rename to api/graphon/nodes/start/entities.py index 92ebd1a2ec..7df62e1b2b 100644 --- a/api/dify_graph/nodes/start/entities.py +++ b/api/graphon/nodes/start/entities.py @@ -2,9 +2,9 @@ from collections.abc import Sequence from pydantic import Field -from dify_graph.entities.base_node_data import BaseNodeData -from dify_graph.enums import BuiltinNodeTypes, NodeType -from dify_graph.variables.input_entities import VariableEntity +from graphon.entities.base_node_data import BaseNodeData +from graphon.enums import BuiltinNodeTypes, NodeType +from graphon.variables.input_entities import VariableEntity class StartNodeData(BaseNodeData): diff --git a/api/dify_graph/nodes/start/start_node.py b/api/graphon/nodes/start/start_node.py similarity index 69% rename from api/dify_graph/nodes/start/start_node.py rename to api/graphon/nodes/start/start_node.py index 5e6055ea34..cb3f4c1e7d 100644 --- a/api/dify_graph/nodes/start/start_node.py +++ b/api/graphon/nodes/start/start_node.py @@ -2,12 +2,11 @@ from typing import Any from jsonschema import Draft7Validator, ValidationError -from dify_graph.constants import SYSTEM_VARIABLE_NODE_ID -from dify_graph.enums import BuiltinNodeTypes, NodeExecutionType, WorkflowNodeExecutionStatus -from dify_graph.node_events import NodeRunResult -from dify_graph.nodes.base.node import Node -from dify_graph.nodes.start.entities import StartNodeData -from dify_graph.variables.input_entities import VariableEntityType +from graphon.enums import BuiltinNodeTypes, NodeExecutionType, WorkflowNodeExecutionStatus +from graphon.node_events import NodeRunResult +from graphon.nodes.base.node import Node +from graphon.nodes.start.entities import StartNodeData +from graphon.variables.input_entities import VariableEntityType class StartNode(Node[StartNodeData]): @@ -19,15 +18,10 @@ class StartNode(Node[StartNodeData]): return "1" def _run(self) -> NodeRunResult: - node_inputs = dict(self.graph_runtime_state.variable_pool.user_inputs) + node_inputs = dict(self.graph_runtime_state.variable_pool.get_by_prefix(self.id)) self._validate_and_normalize_json_object_inputs(node_inputs) - system_inputs = self.graph_runtime_state.variable_pool.system_variables.to_dict() - - # TODO: System variables should be directly accessible, no need for special handling - # Set system variables as node outputs. - for var in system_inputs: - node_inputs[SYSTEM_VARIABLE_NODE_ID + "." + var] = system_inputs[var] - outputs = dict(node_inputs) + outputs = dict(self.graph_runtime_state.variable_pool.flatten(unprefixed_node_id=self.id)) + outputs.update(node_inputs) return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=node_inputs, outputs=outputs) diff --git a/api/dify_graph/nodes/template_transform/__init__.py b/api/graphon/nodes/template_transform/__init__.py similarity index 100% rename from api/dify_graph/nodes/template_transform/__init__.py rename to api/graphon/nodes/template_transform/__init__.py diff --git a/api/dify_graph/nodes/template_transform/entities.py b/api/graphon/nodes/template_transform/entities.py similarity index 54% rename from api/dify_graph/nodes/template_transform/entities.py rename to api/graphon/nodes/template_transform/entities.py index ac29239958..a27a57f34f 100644 --- a/api/dify_graph/nodes/template_transform/entities.py +++ b/api/graphon/nodes/template_transform/entities.py @@ -1,6 +1,6 @@ -from dify_graph.entities.base_node_data import BaseNodeData -from dify_graph.enums import BuiltinNodeTypes, NodeType -from dify_graph.nodes.base.entities import VariableSelector +from graphon.entities.base_node_data import BaseNodeData +from graphon.enums import BuiltinNodeTypes, NodeType +from graphon.nodes.base.entities import VariableSelector class TemplateTransformNodeData(BaseNodeData): diff --git a/api/dify_graph/nodes/template_transform/template_transform_node.py b/api/graphon/nodes/template_transform/template_transform_node.py similarity index 57% rename from api/dify_graph/nodes/template_transform/template_transform_node.py rename to api/graphon/nodes/template_transform/template_transform_node.py index dc6fce2b0a..4206fb0c1a 100644 --- a/api/dify_graph/nodes/template_transform/template_transform_node.py +++ b/api/graphon/nodes/template_transform/template_transform_node.py @@ -1,26 +1,27 @@ from collections.abc import Mapping, Sequence from typing import TYPE_CHECKING, Any -from dify_graph.entities.graph_config import NodeConfigDict -from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus -from dify_graph.node_events import NodeRunResult -from dify_graph.nodes.base.node import Node -from dify_graph.nodes.template_transform.entities import TemplateTransformNodeData -from dify_graph.nodes.template_transform.template_renderer import ( +from graphon.entities.graph_config import NodeConfigDict +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus +from graphon.node_events import NodeRunResult +from graphon.nodes.base.entities import VariableSelector +from graphon.nodes.base.node import Node +from graphon.nodes.template_transform.entities import TemplateTransformNodeData +from graphon.template_rendering import ( Jinja2TemplateRenderer, TemplateRenderError, ) if TYPE_CHECKING: - from dify_graph.entities import GraphInitParams - from dify_graph.runtime import GraphRuntimeState + from graphon.entities import GraphInitParams + from graphon.runtime import GraphRuntimeState DEFAULT_TEMPLATE_TRANSFORM_MAX_OUTPUT_LENGTH = 400_000 class TemplateTransformNode(Node[TemplateTransformNodeData]): node_type = BuiltinNodeTypes.TEMPLATE_TRANSFORM - _template_renderer: Jinja2TemplateRenderer + _jinja2_template_renderer: Jinja2TemplateRenderer _max_output_length: int def __init__( @@ -30,7 +31,7 @@ class TemplateTransformNode(Node[TemplateTransformNodeData]): graph_init_params: "GraphInitParams", graph_runtime_state: "GraphRuntimeState", *, - template_renderer: Jinja2TemplateRenderer, + jinja2_template_renderer: Jinja2TemplateRenderer, max_output_length: int | None = None, ) -> None: super().__init__( @@ -39,7 +40,7 @@ class TemplateTransformNode(Node[TemplateTransformNodeData]): graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, ) - self._template_renderer = template_renderer + self._jinja2_template_renderer = jinja2_template_renderer if max_output_length is not None and max_output_length <= 0: raise ValueError("max_output_length must be a positive integer") @@ -70,7 +71,7 @@ class TemplateTransformNode(Node[TemplateTransformNodeData]): variables[variable_name] = value.to_object() if value else None # Run code try: - rendered = self._template_renderer.render_template(self.node_data.template, variables) + rendered = self._jinja2_template_renderer.render_template(self.node_data.template, variables) except TemplateRenderError as e: return NodeRunResult(inputs=variables, status=WorkflowNodeExecutionStatus.FAILED, error=str(e)) @@ -87,9 +88,32 @@ class TemplateTransformNode(Node[TemplateTransformNodeData]): @classmethod def _extract_variable_selector_to_variable_mapping( - cls, *, graph_config: Mapping[str, Any], node_id: str, node_data: TemplateTransformNodeData + cls, + *, + graph_config: Mapping[str, Any], + node_id: str, + node_data: TemplateTransformNodeData | Mapping[str, Any], ) -> Mapping[str, Sequence[str]]: - return { - node_id + "." + variable_selector.variable: variable_selector.value_selector - for variable_selector in node_data.variables - } + _ = graph_config + raw_variables = ( + node_data.variables if isinstance(node_data, TemplateTransformNodeData) else node_data.get("variables", []) + ) + variable_mapping: dict[str, Sequence[str]] = {} + for variable_selector in raw_variables: + if isinstance(variable_selector, VariableSelector): + variable_mapping[node_id + "." + variable_selector.variable] = variable_selector.value_selector + continue + + if not isinstance(variable_selector, Mapping): + continue + + variable = variable_selector.get("variable") + value_selector = variable_selector.get("value_selector") + if ( + isinstance(variable, str) + and isinstance(value_selector, Sequence) + and all(isinstance(selector_part, str) for selector_part in value_selector) + ): + variable_mapping[node_id + "." + variable] = list(value_selector) + + return variable_mapping diff --git a/api/dify_graph/nodes/tool/__init__.py b/api/graphon/nodes/tool/__init__.py similarity index 100% rename from api/dify_graph/nodes/tool/__init__.py rename to api/graphon/nodes/tool/__init__.py diff --git a/api/dify_graph/nodes/tool/entities.py b/api/graphon/nodes/tool/entities.py similarity index 88% rename from api/dify_graph/nodes/tool/entities.py rename to api/graphon/nodes/tool/entities.py index b041ee66fd..54e6048033 100644 --- a/api/dify_graph/nodes/tool/entities.py +++ b/api/graphon/nodes/tool/entities.py @@ -1,11 +1,25 @@ +from enum import StrEnum, auto from typing import Any, Literal, Union from pydantic import BaseModel, field_validator from pydantic_core.core_schema import ValidationInfo -from core.tools.entities.tool_entities import ToolProviderType -from dify_graph.entities.base_node_data import BaseNodeData -from dify_graph.enums import BuiltinNodeTypes, NodeType +from graphon.entities.base_node_data import BaseNodeData +from graphon.enums import BuiltinNodeTypes, NodeType + + +class ToolProviderType(StrEnum): + """ + Graph-owned enum for persisted tool provider kinds. + """ + + PLUGIN = auto() + BUILT_IN = "builtin" + WORKFLOW = auto() + API = auto() + APP = auto() + DATASET_RETRIEVAL = "dataset-retrieval" + MCP = auto() class ToolEntity(BaseModel): diff --git a/api/dify_graph/nodes/tool/exc.py b/api/graphon/nodes/tool/exc.py similarity index 53% rename from api/dify_graph/nodes/tool/exc.py rename to api/graphon/nodes/tool/exc.py index 7212e8bfc0..1a309e1084 100644 --- a/api/dify_graph/nodes/tool/exc.py +++ b/api/graphon/nodes/tool/exc.py @@ -4,6 +4,18 @@ class ToolNodeError(ValueError): pass +class ToolRuntimeResolutionError(ToolNodeError): + """Raised when the workflow layer cannot construct a tool runtime.""" + + pass + + +class ToolRuntimeInvocationError(ToolNodeError): + """Raised when the workflow layer fails while invoking a tool runtime.""" + + pass + + class ToolParameterError(ToolNodeError): """Exception raised for errors in tool parameters.""" diff --git a/api/dify_graph/nodes/tool/tool_node.py b/api/graphon/nodes/tool/tool_node.py similarity index 60% rename from api/dify_graph/nodes/tool/tool_node.py rename to api/graphon/nodes/tool/tool_node.py index 598f0da92e..57ab8ce5d6 100644 --- a/api/dify_graph/nodes/tool/tool_node.py +++ b/api/graphon/nodes/tool/tool_node.py @@ -1,29 +1,25 @@ from collections.abc import Generator, Mapping, Sequence from typing import TYPE_CHECKING, Any -from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler -from core.tools.__base.tool import Tool -from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter -from core.tools.errors import ToolInvokeError -from core.tools.tool_engine import ToolEngine -from core.tools.utils.message_transformer import ToolFileMessageTransformer -from dify_graph.entities.graph_config import NodeConfigDict -from dify_graph.enums import ( +from graphon.entities.graph_config import NodeConfigDict +from graphon.enums import ( BuiltinNodeTypes, - SystemVariableKey, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus, ) -from dify_graph.file import File, FileTransferMethod -from dify_graph.model_runtime.entities.llm_entities import LLMUsage -from dify_graph.node_events import NodeEventBase, NodeRunResult, StreamChunkEvent, StreamCompletedEvent -from dify_graph.nodes.base.node import Node -from dify_graph.nodes.base.variable_template_parser import VariableTemplateParser -from dify_graph.nodes.protocols import ToolFileManagerProtocol -from dify_graph.variables.segments import ArrayAnySegment, ArrayFileSegment -from dify_graph.variables.variables import ArrayAnyVariable -from factories import file_factory -from services.tools.builtin_tools_manage_service import BuiltinToolManageService +from graphon.file import File, FileTransferMethod, get_file_type_by_mime_type +from graphon.model_runtime.entities.llm_entities import LLMUsage +from graphon.node_events import NodeEventBase, NodeRunResult, StreamChunkEvent, StreamCompletedEvent +from graphon.nodes.base.node import Node +from graphon.nodes.base.variable_template_parser import VariableTemplateParser +from graphon.nodes.protocols import ToolFileManagerProtocol +from graphon.nodes.runtime import ToolNodeRuntimeProtocol +from graphon.nodes.tool_runtime_entities import ( + ToolRuntimeHandle, + ToolRuntimeMessage, + ToolRuntimeParameter, +) +from graphon.variables.segments import ArrayFileSegment from .entities import ToolNodeData from .exc import ( @@ -33,8 +29,8 @@ from .exc import ( ) if TYPE_CHECKING: - from dify_graph.entities import GraphInitParams - from dify_graph.runtime import GraphRuntimeState, VariablePool + from graphon.entities import GraphInitParams + from graphon.runtime import GraphRuntimeState, VariablePool class ToolNode(Node[ToolNodeData]): @@ -52,6 +48,7 @@ class ToolNode(Node[ToolNodeData]): graph_runtime_state: "GraphRuntimeState", *, tool_file_manager_factory: ToolFileManagerProtocol, + runtime: ToolNodeRuntimeProtocol | None = None, ): super().__init__( id=id, @@ -60,6 +57,9 @@ class ToolNode(Node[ToolNodeData]): graph_runtime_state=graph_runtime_state, ) self._tool_file_manager_factory = tool_file_manager_factory + if runtime is None: + raise ValueError("runtime is required") + self._runtime = runtime @classmethod def version(cls) -> str: @@ -73,10 +73,6 @@ class ToolNode(Node[ToolNodeData]): """ Run the tool node """ - from core.plugin.impl.exc import PluginDaemonClientSideError, PluginInvokeError - - dify_ctx = self.require_dify_context() - # fetch tool icon tool_info = { "provider_type": self.node_data.provider_type.value, @@ -86,8 +82,6 @@ class ToolNode(Node[ToolNodeData]): # get tool runtime try: - from core.tools.tool_manager import ToolManager - # This is an issue that caused problems before. # Logically, we shouldn't use the node_data.version field for judgment # But for backward compatibility with historical data @@ -95,13 +89,10 @@ class ToolNode(Node[ToolNodeData]): variable_pool: VariablePool | None = None if self.node_data.version != "1" or self.node_data.tool_node_version is not None: variable_pool = self.graph_runtime_state.variable_pool - tool_runtime = ToolManager.get_workflow_tool_runtime( - dify_ctx.tenant_id, - dify_ctx.app_id, - self._node_id, - self.node_data, - dify_ctx.invoke_from, - variable_pool, + tool_runtime = self._runtime.get_runtime( + node_id=self._node_id, + node_data=self.node_data, + variable_pool=variable_pool, ) except ToolNodeError as e: yield StreamCompletedEvent( @@ -116,7 +107,7 @@ class ToolNode(Node[ToolNodeData]): return # get parameters - tool_parameters = tool_runtime.get_merged_runtime_parameters() or [] + tool_parameters = self._runtime.get_runtime_parameters(tool_runtime=tool_runtime) parameters = self._generate_parameters( tool_parameters=tool_parameters, variable_pool=self.graph_runtime_state.variable_pool, @@ -128,18 +119,12 @@ class ToolNode(Node[ToolNodeData]): node_data=self.node_data, for_log=True, ) - # get conversation id - conversation_id = self.graph_runtime_state.variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID]) - try: - message_stream = ToolEngine.generic_invoke( - tool=tool_runtime, + message_stream = self._runtime.invoke( + tool_runtime=tool_runtime, tool_parameters=parameters, - user_id=dify_ctx.user_id, - workflow_tool_callback=DifyWorkflowCallbackHandler(), workflow_call_depth=self.workflow_call_depth, - app_id=dify_ctx.app_id, - conversation_id=conversation_id.text if conversation_id else None, + provider_name=self.node_data.provider_name, ) except ToolNodeError as e: yield StreamCompletedEvent( @@ -159,38 +144,16 @@ class ToolNode(Node[ToolNodeData]): messages=message_stream, tool_info=tool_info, parameters_for_log=parameters_for_log, - user_id=dify_ctx.user_id, - tenant_id=dify_ctx.tenant_id, node_id=self._node_id, tool_runtime=tool_runtime, ) - except ToolInvokeError as e: + except ToolNodeError as e: yield StreamCompletedEvent( node_run_result=NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, inputs=parameters_for_log, metadata={WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info}, - error=f"Failed to invoke tool {self.node_data.provider_name}: {str(e)}", - error_type=type(e).__name__, - ) - ) - except PluginInvokeError as e: - yield StreamCompletedEvent( - node_run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - inputs=parameters_for_log, - metadata={WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info}, - error=e.to_user_friendly_error(plugin_name=self.node_data.provider_name), - error_type=type(e).__name__, - ) - ) - except PluginDaemonClientSideError as e: - yield StreamCompletedEvent( - node_run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - inputs=parameters_for_log, - metadata={WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info}, - error=f"Failed to invoke tool, error: {e.description}", + error=str(e), error_type=type(e).__name__, ) ) @@ -198,7 +161,7 @@ class ToolNode(Node[ToolNodeData]): def _generate_parameters( self, *, - tool_parameters: Sequence[ToolParameter], + tool_parameters: Sequence[ToolRuntimeParameter], variable_pool: "VariablePool", node_data: ToolNodeData, for_log: bool = False, @@ -207,7 +170,7 @@ class ToolNode(Node[ToolNodeData]): Generate parameters based on the given tool parameters, variable pool, and node data. Args: - tool_parameters (Sequence[ToolParameter]): The list of tool parameters. + tool_parameters (Sequence[ToolRuntimeParameter]): The list of tool parameters. variable_pool (VariablePool): The variable pool containing the variables. node_data (ToolNodeData): The data associated with the tool node. @@ -240,107 +203,89 @@ class ToolNode(Node[ToolNodeData]): return result - def _fetch_files(self, variable_pool: "VariablePool") -> list[File]: - variable = variable_pool.get(["sys", SystemVariableKey.FILES]) - assert isinstance(variable, ArrayAnyVariable | ArrayAnySegment) - return list(variable.value) if variable else [] - def _transform_message( self, - messages: Generator[ToolInvokeMessage, None, None], + messages: Generator[ToolRuntimeMessage, None, None], tool_info: Mapping[str, Any], parameters_for_log: dict[str, Any], - user_id: str, - tenant_id: str, node_id: str, - tool_runtime: Tool, + tool_runtime: ToolRuntimeHandle, + **_: Any, ) -> Generator[NodeEventBase, None, LLMUsage]: """ - Convert ToolInvokeMessages into tuple[plain_text, files] + Convert graph-owned tool runtime messages into node outputs. """ - # transform message and handle file storage - from core.plugin.impl.plugin import PluginInstaller - - message_stream = ToolFileMessageTransformer.transform_tool_invoke_messages( - messages=messages, - user_id=user_id, - tenant_id=tenant_id, - conversation_id=None, - ) - text = "" files: list[File] = [] json: list[dict | list] = [] variables: dict[str, Any] = {} - for message in message_stream: + for message in messages: if message.type in { - ToolInvokeMessage.MessageType.IMAGE_LINK, - ToolInvokeMessage.MessageType.BINARY_LINK, - ToolInvokeMessage.MessageType.IMAGE, + ToolRuntimeMessage.MessageType.IMAGE_LINK, + ToolRuntimeMessage.MessageType.BINARY_LINK, + ToolRuntimeMessage.MessageType.IMAGE, }: - assert isinstance(message.message, ToolInvokeMessage.TextMessage) + assert isinstance(message.message, ToolRuntimeMessage.TextMessage) url = message.message.text if message.meta: transfer_method = message.meta.get("transfer_method", FileTransferMethod.TOOL_FILE) + tool_file_id = message.meta.get("tool_file_id") else: transfer_method = FileTransferMethod.TOOL_FILE + tool_file_id = None + if not isinstance(tool_file_id, str) or not tool_file_id: + raise ToolFileError("tool message is missing tool_file_id metadata") - tool_file_id = str(url).split("/")[-1].split(".")[0] - - _, tool_file = self._tool_file_manager_factory.get_file_generator_by_tool_file_id(tool_file_id) + _stream, tool_file = self._tool_file_manager_factory.get_file_generator_by_tool_file_id(tool_file_id) if not tool_file: raise ToolFileError(f"tool file {tool_file_id} not found") + if tool_file.mime_type is None: + raise ToolFileError(f"tool file {tool_file_id} is missing mime type") - mapping = { + file_mapping: dict[str, Any] = { "tool_file_id": tool_file_id, - "type": file_factory.get_file_type_by_mime_type(tool_file.mimetype), + "type": get_file_type_by_mime_type(tool_file.mime_type), "transfer_method": transfer_method, "url": url, } - file = file_factory.build_from_mapping( - mapping=mapping, - tenant_id=tenant_id, - ) + file = self._runtime.build_file_reference(mapping=file_mapping) files.append(file) - elif message.type == ToolInvokeMessage.MessageType.BLOB: + elif message.type == ToolRuntimeMessage.MessageType.BLOB: # get tool file id - assert isinstance(message.message, ToolInvokeMessage.TextMessage) + assert isinstance(message.message, ToolRuntimeMessage.TextMessage) assert message.meta - tool_file_id = message.message.text.split("/")[-1].split(".")[0] - _, tool_file = self._tool_file_manager_factory.get_file_generator_by_tool_file_id(tool_file_id) + tool_file_id = message.meta.get("tool_file_id") + if not isinstance(tool_file_id, str) or not tool_file_id: + raise ToolFileError("tool blob message is missing tool_file_id metadata") + _stream, tool_file = self._tool_file_manager_factory.get_file_generator_by_tool_file_id(tool_file_id) if not tool_file: raise ToolFileError(f"tool file {tool_file_id} not exists") - mapping = { + blob_file_mapping: dict[str, Any] = { "tool_file_id": tool_file_id, "transfer_method": FileTransferMethod.TOOL_FILE, } - files.append( - file_factory.build_from_mapping( - mapping=mapping, - tenant_id=tenant_id, - ) - ) - elif message.type == ToolInvokeMessage.MessageType.TEXT: - assert isinstance(message.message, ToolInvokeMessage.TextMessage) + files.append(self._runtime.build_file_reference(mapping=blob_file_mapping)) + elif message.type == ToolRuntimeMessage.MessageType.TEXT: + assert isinstance(message.message, ToolRuntimeMessage.TextMessage) text += message.message.text yield StreamChunkEvent( selector=[node_id, "text"], chunk=message.message.text, is_final=False, ) - elif message.type == ToolInvokeMessage.MessageType.JSON: - assert isinstance(message.message, ToolInvokeMessage.JsonMessage) + elif message.type == ToolRuntimeMessage.MessageType.JSON: + assert isinstance(message.message, ToolRuntimeMessage.JsonMessage) # JSON message handling for tool node if message.message.json_object: json.append(message.message.json_object) - elif message.type == ToolInvokeMessage.MessageType.LINK: - assert isinstance(message.message, ToolInvokeMessage.TextMessage) + elif message.type == ToolRuntimeMessage.MessageType.LINK: + assert isinstance(message.message, ToolRuntimeMessage.TextMessage) # Check if this LINK message is a file link file_obj = (message.meta or {}).get("file") @@ -356,8 +301,8 @@ class ToolNode(Node[ToolNodeData]): chunk=stream_text, is_final=False, ) - elif message.type == ToolInvokeMessage.MessageType.VARIABLE: - assert isinstance(message.message, ToolInvokeMessage.VariableMessage) + elif message.type == ToolRuntimeMessage.MessageType.VARIABLE: + assert isinstance(message.message, ToolRuntimeMessage.VariableMessage) variable_name = message.message.variable_name variable_value = message.message.variable_value if message.message.stream: @@ -374,7 +319,7 @@ class ToolNode(Node[ToolNodeData]): ) else: variables[variable_name] = variable_value - elif message.type == ToolInvokeMessage.MessageType.FILE: + elif message.type == ToolRuntimeMessage.MessageType.FILE: assert message.meta is not None assert isinstance(message.meta, dict) # Validate that meta contains a 'file' key @@ -385,38 +330,16 @@ class ToolNode(Node[ToolNodeData]): if not isinstance(message.meta["file"], File): raise ToolNodeError(f"Expected File object but got {type(message.meta['file']).__name__}") files.append(message.meta["file"]) - elif message.type == ToolInvokeMessage.MessageType.LOG: - assert isinstance(message.message, ToolInvokeMessage.LogMessage) + elif message.type == ToolRuntimeMessage.MessageType.LOG: + assert isinstance(message.message, ToolRuntimeMessage.LogMessage) if message.message.metadata: icon = tool_info.get("icon", "") dict_metadata = dict(message.message.metadata) if dict_metadata.get("provider"): - manager = PluginInstaller() - plugins = manager.list_plugins(tenant_id) - try: - current_plugin = next( - plugin - for plugin in plugins - if f"{plugin.plugin_id}/{plugin.name}" == dict_metadata["provider"] - ) - icon = current_plugin.declaration.icon - except StopIteration: - pass - icon_dark = None - try: - builtin_tool = next( - provider - for provider in BuiltinToolManageService.list_builtin_tools( - user_id, - tenant_id, - ) - if provider.name == dict_metadata["provider"] - ) - icon = builtin_tool.icon - icon_dark = builtin_tool.icon_dark - except StopIteration: - pass - + icon, icon_dark = self._runtime.resolve_provider_icons( + provider_name=dict_metadata["provider"], + default_icon=icon, + ) dict_metadata["icon"] = icon dict_metadata["icon_dark"] = icon_dark message.message.metadata = dict_metadata @@ -446,7 +369,7 @@ class ToolNode(Node[ToolNodeData]): is_final=True, ) - usage = self._extract_tool_usage(tool_runtime) + usage = self._runtime.get_usage(tool_runtime=tool_runtime) metadata: dict[WorkflowNodeExecutionMetadataKey, Any] = { WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info, @@ -468,21 +391,6 @@ class ToolNode(Node[ToolNodeData]): return usage - @staticmethod - def _extract_tool_usage(tool_runtime: Tool) -> LLMUsage: - # Avoid importing WorkflowTool at module import time; rely on duck typing - # Some runtimes expose `latest_usage`; mocks may synthesize arbitrary attributes. - latest = getattr(tool_runtime, "latest_usage", None) - # Normalize into a concrete LLMUsage. MagicMock returns truthy attribute objects - # for any name, so we must type-check here. - if isinstance(latest, LLMUsage): - return latest - if isinstance(latest, dict): - # Allow dict payloads from external runtimes - return LLMUsage.model_validate(latest) - # Fallback to empty usage when attribute is missing or not a valid payload - return LLMUsage.empty_usage() - @classmethod def _extract_variable_selector_to_variable_mapping( cls, diff --git a/api/graphon/nodes/tool_runtime_entities.py b/api/graphon/nodes/tool_runtime_entities.py new file mode 100644 index 0000000000..5bb0c16573 --- /dev/null +++ b/api/graphon/nodes/tool_runtime_entities.py @@ -0,0 +1,105 @@ +from __future__ import annotations + +from dataclasses import dataclass +from enum import StrEnum, auto +from typing import Any + +from pydantic import BaseModel, ConfigDict, Field + + +class _ToolRuntimeModel(BaseModel): + model_config = ConfigDict(extra="forbid") + + +@dataclass(frozen=True, slots=True) +class ToolRuntimeHandle: + """Opaque graph-owned handle for a workflow-layer tool runtime. + + Workflow-specific execution context must stay behind `raw` so the graph + contract does not absorb application-owned concepts. + """ + + raw: object + + +@dataclass(frozen=True, slots=True) +class ToolRuntimeParameter: + """Graph-owned parameter shape used by tool nodes.""" + + name: str + required: bool = False + + +class ToolRuntimeMessage(_ToolRuntimeModel): + """Graph-owned tool invocation message DTO.""" + + class TextMessage(_ToolRuntimeModel): + text: str + + class JsonMessage(_ToolRuntimeModel): + json_object: dict[str, Any] | list[Any] + suppress_output: bool = Field(default=False) + + class BlobMessage(_ToolRuntimeModel): + blob: bytes + + class BlobChunkMessage(_ToolRuntimeModel): + id: str + sequence: int + total_length: int + blob: bytes + end: bool + + class FileMessage(_ToolRuntimeModel): + file_marker: str = Field(default="file_marker") + + class VariableMessage(_ToolRuntimeModel): + variable_name: str + variable_value: dict[str, Any] | list[Any] | str | int | float | bool | None + stream: bool = Field(default=False) + + class LogMessage(_ToolRuntimeModel): + class LogStatus(StrEnum): + START = auto() + ERROR = auto() + SUCCESS = auto() + + id: str + label: str + parent_id: str | None = None + error: str | None = None + status: LogStatus + data: dict[str, Any] + metadata: dict[str, Any] = Field(default_factory=dict) + + class RetrieverResourceMessage(_ToolRuntimeModel): + retriever_resources: list[dict[str, Any]] + context: str + + class MessageType(StrEnum): + TEXT = auto() + IMAGE = auto() + LINK = auto() + BLOB = auto() + JSON = auto() + IMAGE_LINK = auto() + BINARY_LINK = auto() + VARIABLE = auto() + FILE = auto() + LOG = auto() + BLOB_CHUNK = auto() + RETRIEVER_RESOURCES = auto() + + type: MessageType = MessageType.TEXT + message: ( + JsonMessage + | TextMessage + | BlobChunkMessage + | BlobMessage + | LogMessage + | FileMessage + | None + | VariableMessage + | RetrieverResourceMessage + ) + meta: dict[str, Any] | None = None diff --git a/api/dify_graph/nodes/variable_aggregator/__init__.py b/api/graphon/nodes/variable_aggregator/__init__.py similarity index 100% rename from api/dify_graph/nodes/variable_aggregator/__init__.py rename to api/graphon/nodes/variable_aggregator/__init__.py diff --git a/api/dify_graph/nodes/variable_aggregator/entities.py b/api/graphon/nodes/variable_aggregator/entities.py similarity index 77% rename from api/dify_graph/nodes/variable_aggregator/entities.py rename to api/graphon/nodes/variable_aggregator/entities.py index 4779ebd9a9..136fd28f8c 100644 --- a/api/dify_graph/nodes/variable_aggregator/entities.py +++ b/api/graphon/nodes/variable_aggregator/entities.py @@ -1,8 +1,8 @@ from pydantic import BaseModel -from dify_graph.entities.base_node_data import BaseNodeData -from dify_graph.enums import BuiltinNodeTypes, NodeType -from dify_graph.variables.types import SegmentType +from graphon.entities.base_node_data import BaseNodeData +from graphon.enums import BuiltinNodeTypes, NodeType +from graphon.variables.types import SegmentType class AdvancedSettings(BaseModel): diff --git a/api/dify_graph/nodes/variable_aggregator/variable_aggregator_node.py b/api/graphon/nodes/variable_aggregator/variable_aggregator_node.py similarity index 81% rename from api/dify_graph/nodes/variable_aggregator/variable_aggregator_node.py rename to api/graphon/nodes/variable_aggregator/variable_aggregator_node.py index 7d26de6232..71b221e196 100644 --- a/api/dify_graph/nodes/variable_aggregator/variable_aggregator_node.py +++ b/api/graphon/nodes/variable_aggregator/variable_aggregator_node.py @@ -1,10 +1,10 @@ from collections.abc import Mapping -from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus -from dify_graph.node_events import NodeRunResult -from dify_graph.nodes.base.node import Node -from dify_graph.nodes.variable_aggregator.entities import VariableAggregatorNodeData -from dify_graph.variables.segments import Segment +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus +from graphon.node_events import NodeRunResult +from graphon.nodes.base.node import Node +from graphon.nodes.variable_aggregator.entities import VariableAggregatorNodeData +from graphon.variables.segments import Segment class VariableAggregatorNode(Node[VariableAggregatorNodeData]): diff --git a/api/dify_graph/nodes/variable_assigner/__init__.py b/api/graphon/nodes/variable_assigner/__init__.py similarity index 100% rename from api/dify_graph/nodes/variable_assigner/__init__.py rename to api/graphon/nodes/variable_assigner/__init__.py diff --git a/api/dify_graph/nodes/variable_assigner/common/__init__.py b/api/graphon/nodes/variable_assigner/common/__init__.py similarity index 100% rename from api/dify_graph/nodes/variable_assigner/common/__init__.py rename to api/graphon/nodes/variable_assigner/common/__init__.py diff --git a/api/dify_graph/nodes/variable_assigner/common/exc.py b/api/graphon/nodes/variable_assigner/common/exc.py similarity index 100% rename from api/dify_graph/nodes/variable_assigner/common/exc.py rename to api/graphon/nodes/variable_assigner/common/exc.py diff --git a/api/dify_graph/nodes/variable_assigner/common/helpers.py b/api/graphon/nodes/variable_assigner/common/helpers.py similarity index 91% rename from api/dify_graph/nodes/variable_assigner/common/helpers.py rename to api/graphon/nodes/variable_assigner/common/helpers.py index f0b22904a9..4c30e009f2 100644 --- a/api/dify_graph/nodes/variable_assigner/common/helpers.py +++ b/api/graphon/nodes/variable_assigner/common/helpers.py @@ -3,9 +3,9 @@ from typing import Any, TypeVar from pydantic import BaseModel -from dify_graph.variables import Segment -from dify_graph.variables.consts import SELECTORS_LENGTH -from dify_graph.variables.types import SegmentType +from graphon.variables import Segment +from graphon.variables.consts import SELECTORS_LENGTH +from graphon.variables.types import SegmentType # Use double underscore (`__`) prefix for internal variables # to minimize risk of collision with user-defined variable names. diff --git a/api/dify_graph/nodes/variable_assigner/v1/__init__.py b/api/graphon/nodes/variable_assigner/v1/__init__.py similarity index 100% rename from api/dify_graph/nodes/variable_assigner/v1/__init__.py rename to api/graphon/nodes/variable_assigner/v1/__init__.py diff --git a/api/dify_graph/nodes/variable_assigner/v1/node.py b/api/graphon/nodes/variable_assigner/v1/node.py similarity index 64% rename from api/dify_graph/nodes/variable_assigner/v1/node.py rename to api/graphon/nodes/variable_assigner/v1/node.py index f9b261b191..19ded5f123 100644 --- a/api/dify_graph/nodes/variable_assigner/v1/node.py +++ b/api/graphon/nodes/variable_assigner/v1/node.py @@ -1,20 +1,19 @@ -from collections.abc import Mapping, Sequence -from typing import TYPE_CHECKING, Any +from collections.abc import Generator, Mapping, Sequence +from typing import TYPE_CHECKING, Any, cast -from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID -from dify_graph.entities import GraphInitParams -from dify_graph.entities.graph_config import NodeConfigDict -from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus -from dify_graph.node_events import NodeRunResult -from dify_graph.nodes.base.node import Node -from dify_graph.nodes.variable_assigner.common import helpers as common_helpers -from dify_graph.nodes.variable_assigner.common.exc import VariableOperatorNodeError -from dify_graph.variables import SegmentType, VariableBase +from graphon.entities import GraphInitParams +from graphon.entities.graph_config import NodeConfigDict +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus +from graphon.node_events import NodeEventBase, NodeRunResult, StreamCompletedEvent, VariableUpdatedEvent +from graphon.nodes.base.node import Node +from graphon.nodes.variable_assigner.common import helpers as common_helpers +from graphon.nodes.variable_assigner.common.exc import VariableOperatorNodeError +from graphon.variables import SegmentType, Variable, VariableBase from .node_data import VariableAssignerData, WriteMode if TYPE_CHECKING: - from dify_graph.runtime import GraphRuntimeState + from graphon.runtime import GraphRuntimeState class VariableAssignerNode(Node[VariableAssignerData]): @@ -56,18 +55,16 @@ class VariableAssignerNode(Node[VariableAssignerData]): node_data: VariableAssignerData, ) -> Mapping[str, Sequence[str]]: mapping = {} - assigned_variable_node_id = node_data.assigned_variable_selector[0] - if assigned_variable_node_id == CONVERSATION_VARIABLE_NODE_ID: - selector_key = ".".join(node_data.assigned_variable_selector) - key = f"{node_id}.#{selector_key}#" - mapping[key] = node_data.assigned_variable_selector + selector_key = ".".join(node_data.assigned_variable_selector) + key = f"{node_id}.#{selector_key}#" + mapping[key] = node_data.assigned_variable_selector selector_key = ".".join(node_data.input_variable_selector) key = f"{node_id}.#{selector_key}#" mapping[key] = node_data.input_variable_selector return mapping - def _run(self) -> NodeRunResult: + def _run(self) -> Generator[NodeEventBase, None, None]: assigned_variable_selector = self.node_data.assigned_variable_selector # Should be String, Number, Object, ArrayString, ArrayNumber, ArrayObject original_variable = self.graph_runtime_state.variable_pool.get(assigned_variable_selector) @@ -92,18 +89,18 @@ class VariableAssignerNode(Node[VariableAssignerData]): income_value = SegmentType.get_zero_value(original_variable.value_type) updated_variable = original_variable.model_copy(update={"value": income_value.to_object()}) - # Over write the variable. - self.graph_runtime_state.variable_pool.add(assigned_variable_selector, updated_variable) - updated_variables = [common_helpers.variable_to_processed_data(assigned_variable_selector, updated_variable)] - return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - inputs={ - "value": income_value.to_object(), - }, - # NOTE(QuantumGhost): although only one variable is updated in `v1.VariableAssignerNode`, - # we still set `output_variables` as a list to ensure the schema of output is - # compatible with `v2.VariableAssignerNode`. - process_data=common_helpers.set_updated_variables({}, updated_variables), - outputs={}, + yield VariableUpdatedEvent(variable=cast(Variable, updated_variable)) + yield StreamCompletedEvent( + node_run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + inputs={ + "value": income_value.to_object(), + }, + # NOTE(QuantumGhost): although only one variable is updated in `v1.VariableAssignerNode`, + # we still set `output_variables` as a list to ensure the schema of output is + # compatible with `v2.VariableAssignerNode`. + process_data=common_helpers.set_updated_variables({}, updated_variables), + outputs={}, + ) ) diff --git a/api/dify_graph/nodes/variable_assigner/v1/node_data.py b/api/graphon/nodes/variable_assigner/v1/node_data.py similarity index 76% rename from api/dify_graph/nodes/variable_assigner/v1/node_data.py rename to api/graphon/nodes/variable_assigner/v1/node_data.py index 57acb29535..4f630bc76c 100644 --- a/api/dify_graph/nodes/variable_assigner/v1/node_data.py +++ b/api/graphon/nodes/variable_assigner/v1/node_data.py @@ -1,8 +1,8 @@ from collections.abc import Sequence from enum import StrEnum -from dify_graph.entities.base_node_data import BaseNodeData -from dify_graph.enums import BuiltinNodeTypes, NodeType +from graphon.entities.base_node_data import BaseNodeData +from graphon.enums import BuiltinNodeTypes, NodeType class WriteMode(StrEnum): diff --git a/api/dify_graph/nodes/variable_assigner/v2/__init__.py b/api/graphon/nodes/variable_assigner/v2/__init__.py similarity index 100% rename from api/dify_graph/nodes/variable_assigner/v2/__init__.py rename to api/graphon/nodes/variable_assigner/v2/__init__.py diff --git a/api/dify_graph/nodes/variable_assigner/v2/entities.py b/api/graphon/nodes/variable_assigner/v2/entities.py similarity index 89% rename from api/dify_graph/nodes/variable_assigner/v2/entities.py rename to api/graphon/nodes/variable_assigner/v2/entities.py index 2b2bbe85de..d1c68c8e8c 100644 --- a/api/dify_graph/nodes/variable_assigner/v2/entities.py +++ b/api/graphon/nodes/variable_assigner/v2/entities.py @@ -3,8 +3,8 @@ from typing import Any from pydantic import BaseModel, Field -from dify_graph.entities.base_node_data import BaseNodeData -from dify_graph.enums import BuiltinNodeTypes, NodeType +from graphon.entities.base_node_data import BaseNodeData +from graphon.enums import BuiltinNodeTypes, NodeType from .enums import InputType, Operation diff --git a/api/dify_graph/nodes/variable_assigner/v2/enums.py b/api/graphon/nodes/variable_assigner/v2/enums.py similarity index 100% rename from api/dify_graph/nodes/variable_assigner/v2/enums.py rename to api/graphon/nodes/variable_assigner/v2/enums.py diff --git a/api/dify_graph/nodes/variable_assigner/v2/exc.py b/api/graphon/nodes/variable_assigner/v2/exc.py similarity index 93% rename from api/dify_graph/nodes/variable_assigner/v2/exc.py rename to api/graphon/nodes/variable_assigner/v2/exc.py index c50aab8668..90d7648574 100644 --- a/api/dify_graph/nodes/variable_assigner/v2/exc.py +++ b/api/graphon/nodes/variable_assigner/v2/exc.py @@ -1,7 +1,7 @@ from collections.abc import Sequence from typing import Any -from dify_graph.nodes.variable_assigner.common.exc import VariableOperatorNodeError +from graphon.nodes.variable_assigner.common.exc import VariableOperatorNodeError from .enums import InputType, Operation diff --git a/api/dify_graph/nodes/variable_assigner/v2/helpers.py b/api/graphon/nodes/variable_assigner/v2/helpers.py similarity index 98% rename from api/dify_graph/nodes/variable_assigner/v2/helpers.py rename to api/graphon/nodes/variable_assigner/v2/helpers.py index 38c69cbe3c..ebc6c79476 100644 --- a/api/dify_graph/nodes/variable_assigner/v2/helpers.py +++ b/api/graphon/nodes/variable_assigner/v2/helpers.py @@ -1,6 +1,6 @@ from typing import Any -from dify_graph.variables import SegmentType +from graphon.variables import SegmentType from .enums import Operation diff --git a/api/dify_graph/nodes/variable_assigner/v2/node.py b/api/graphon/nodes/variable_assigner/v2/node.py similarity index 71% rename from api/dify_graph/nodes/variable_assigner/v2/node.py rename to api/graphon/nodes/variable_assigner/v2/node.py index f04a6b3b80..887bd1b604 100644 --- a/api/dify_graph/nodes/variable_assigner/v2/node.py +++ b/api/graphon/nodes/variable_assigner/v2/node.py @@ -1,16 +1,15 @@ import json -from collections.abc import Mapping, MutableMapping, Sequence -from typing import TYPE_CHECKING, Any +from collections.abc import Generator, Mapping, MutableMapping, Sequence +from typing import TYPE_CHECKING, Any, cast -from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID -from dify_graph.entities.graph_config import NodeConfigDict -from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus -from dify_graph.node_events import NodeRunResult -from dify_graph.nodes.base.node import Node -from dify_graph.nodes.variable_assigner.common import helpers as common_helpers -from dify_graph.nodes.variable_assigner.common.exc import VariableOperatorNodeError -from dify_graph.variables import SegmentType, VariableBase -from dify_graph.variables.consts import SELECTORS_LENGTH +from graphon.entities.graph_config import NodeConfigDict +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus +from graphon.node_events import NodeEventBase, NodeRunResult, StreamCompletedEvent, VariableUpdatedEvent +from graphon.nodes.base.node import Node +from graphon.nodes.variable_assigner.common import helpers as common_helpers +from graphon.nodes.variable_assigner.common.exc import VariableOperatorNodeError +from graphon.variables import SegmentType, Variable, VariableBase +from graphon.variables.consts import SELECTORS_LENGTH from . import helpers from .entities import VariableAssignerNodeData, VariableOperationItem @@ -24,14 +23,11 @@ from .exc import ( ) if TYPE_CHECKING: - from dify_graph.entities import GraphInitParams - from dify_graph.runtime import GraphRuntimeState + from graphon.entities import GraphInitParams + from graphon.runtime import GraphRuntimeState def _target_mapping_from_item(mapping: MutableMapping[str, Sequence[str]], node_id: str, item: VariableOperationItem): - selector_node_id = item.variable_selector[0] - if selector_node_id != CONVERSATION_VARIABLE_NODE_ID: - return selector_str = ".".join(item.variable_selector) key = f"{node_id}.#{selector_str}#" mapping[key] = item.variable_selector @@ -103,15 +99,18 @@ class VariableAssignerNode(Node[VariableAssignerNodeData]): _source_mapping_from_item(var_mapping, node_id, item) return var_mapping - def _run(self) -> NodeRunResult: + def _run(self) -> Generator[NodeEventBase, None, None]: inputs = self.node_data.model_dump() process_data: dict[str, Any] = {} # NOTE: This node has no outputs updated_variable_selectors: list[Sequence[str]] = [] + # Preserve intra-node read-after-write behavior without mutating the shared pool + # until the engine processes the emitted VariableUpdatedEvent instances. + working_variable_pool = self.graph_runtime_state.variable_pool.model_copy(deep=True) try: for item in self.node_data.items: - variable = self.graph_runtime_state.variable_pool.get(item.variable_selector) + variable = working_variable_pool.get(item.variable_selector) # ==================== Validation Part @@ -136,60 +135,64 @@ class VariableAssignerNode(Node[VariableAssignerNodeData]): raise InputTypeNotSupportedError(input_type=InputType.CONSTANT, operation=item.operation) # Get value from variable pool + input_value = item.value if ( item.input_type == InputType.VARIABLE and item.operation not in {Operation.CLEAR, Operation.REMOVE_FIRST, Operation.REMOVE_LAST} and item.value is not None ): - value = self.graph_runtime_state.variable_pool.get(item.value) + value = working_variable_pool.get(item.value) if value is None: raise VariableNotFoundError(variable_selector=item.value) # Skip if value is NoneSegment if value.value_type == SegmentType.NONE: continue - item.value = value.value + input_value = value.value # If set string / bytes / bytearray to object, try convert string to object. if ( item.operation == Operation.SET and variable.value_type == SegmentType.OBJECT - and isinstance(item.value, str | bytes | bytearray) + and isinstance(input_value, str | bytes | bytearray) ): try: - item.value = json.loads(item.value) + input_value = json.loads(input_value) except json.JSONDecodeError: - raise InvalidInputValueError(value=item.value) + raise InvalidInputValueError(value=input_value) # Check if input value is valid if not helpers.is_input_value_valid( - variable_type=variable.value_type, operation=item.operation, value=item.value + variable_type=variable.value_type, operation=item.operation, value=input_value ): - raise InvalidInputValueError(value=item.value) + raise InvalidInputValueError(value=input_value) # ==================== Execution Part updated_value = self._handle_item( variable=variable, operation=item.operation, - value=item.value, + value=input_value, ) - variable = variable.model_copy(update={"value": updated_value}) - self.graph_runtime_state.variable_pool.add(variable.selector, variable) - updated_variable_selectors.append(variable.selector) + updated_variable = variable.model_copy(update={"value": updated_value}) + working_variable_pool.add(updated_variable.selector, updated_variable) + updated_variable_selectors.append(updated_variable.selector) except VariableOperatorNodeError as e: - return NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - inputs=inputs, - process_data=process_data, - error=str(e), + yield StreamCompletedEvent( + node_run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + inputs=inputs, + process_data=process_data, + error=str(e), + ) ) + return # The `updated_variable_selectors` is a list contains list[str] which not hashable, - # remove the duplicated items first. - updated_variable_selectors = list(set(map(tuple, updated_variable_selectors))) + # remove duplicated items while preserving the first update order. + updated_variable_selectors = list(dict.fromkeys(map(tuple, updated_variable_selectors))) for selector in updated_variable_selectors: - variable = self.graph_runtime_state.variable_pool.get(selector) + variable = working_variable_pool.get(selector) if not isinstance(variable, VariableBase): raise VariableNotFoundError(variable_selector=selector) process_data[variable.name] = variable.value @@ -197,15 +200,23 @@ class VariableAssignerNode(Node[VariableAssignerNodeData]): updated_variables = [ common_helpers.variable_to_processed_data(selector, seg) for selector in updated_variable_selectors - if (seg := self.graph_runtime_state.variable_pool.get(selector)) is not None + if (seg := working_variable_pool.get(selector)) is not None ] process_data = common_helpers.set_updated_variables(process_data, updated_variables) - return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - inputs=inputs, - process_data=process_data, - outputs={}, + for selector in updated_variable_selectors: + variable = working_variable_pool.get(selector) + if not isinstance(variable, VariableBase): + raise VariableNotFoundError(variable_selector=selector) + yield VariableUpdatedEvent(variable=cast(Variable, variable)) + + yield StreamCompletedEvent( + node_run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + inputs=inputs, + process_data=process_data, + outputs={}, + ) ) def _handle_item( diff --git a/api/graphon/prompt_entities.py b/api/graphon/prompt_entities.py new file mode 100644 index 0000000000..2b8b106c6c --- /dev/null +++ b/api/graphon/prompt_entities.py @@ -0,0 +1,47 @@ +from typing import Literal + +from pydantic import BaseModel + +from graphon.model_runtime.entities.message_entities import PromptMessageRole + + +class ChatModelMessage(BaseModel): + """Graph-owned chat prompt template message.""" + + text: str + role: PromptMessageRole + edition_type: Literal["basic", "jinja2"] | None = None + + +class CompletionModelPromptTemplate(BaseModel): + """Graph-owned completion prompt template.""" + + text: str + edition_type: Literal["basic", "jinja2"] | None = None + + +class MemoryConfig(BaseModel): + """Graph-owned memory configuration for prompt assembly.""" + + class RolePrefix(BaseModel): + """Role labels used when serializing completion-model histories.""" + + user: str + assistant: str + + class WindowConfig(BaseModel): + """History windowing controls.""" + + enabled: bool + size: int | None = None + + role_prefix: RolePrefix | None = None + window: WindowConfig + query_prompt_template: str | None = None + + +__all__ = [ + "ChatModelMessage", + "CompletionModelPromptTemplate", + "MemoryConfig", +] diff --git a/api/dify_graph/runtime/__init__.py b/api/graphon/runtime/__init__.py similarity index 100% rename from api/dify_graph/runtime/__init__.py rename to api/graphon/runtime/__init__.py diff --git a/api/dify_graph/runtime/graph_runtime_state.py b/api/graphon/runtime/graph_runtime_state.py similarity index 93% rename from api/dify_graph/runtime/graph_runtime_state.py rename to api/graphon/runtime/graph_runtime_state.py index 41acc6db35..6e4ed202b5 100644 --- a/api/dify_graph/runtime/graph_runtime_state.py +++ b/api/graphon/runtime/graph_runtime_state.py @@ -3,6 +3,7 @@ from __future__ import annotations import importlib import json from collections.abc import Mapping, Sequence +from contextlib import AbstractContextManager, nullcontext from copy import deepcopy from dataclasses import dataclass from typing import TYPE_CHECKING, Any, ClassVar, Protocol @@ -10,13 +11,13 @@ from typing import TYPE_CHECKING, Any, ClassVar, Protocol from pydantic import BaseModel, Field from pydantic.json import pydantic_encoder -from dify_graph.enums import NodeExecutionType, NodeState, NodeType -from dify_graph.model_runtime.entities.llm_entities import LLMUsage -from dify_graph.runtime.variable_pool import VariablePool +from graphon.enums import NodeExecutionType, NodeState, NodeType +from graphon.model_runtime.entities.llm_entities import LLMUsage +from graphon.runtime.variable_pool import VariablePool if TYPE_CHECKING: - from dify_graph.entities import GraphInitParams - from dify_graph.entities.pause_reason import PauseReason + from graphon.entities import GraphInitParams + from graphon.entities.pause_reason import PauseReason class ReadyQueueProtocol(Protocol): @@ -142,10 +143,9 @@ class ChildGraphEngineBuilderProtocol(Protocol): *, workflow_id: str, graph_init_params: GraphInitParams, - graph_runtime_state: GraphRuntimeState, - graph_config: Mapping[str, Any], + parent_graph_runtime_state: GraphRuntimeState, root_node_id: str, - layers: Sequence[object] = (), + variable_pool: VariablePool | None = None, ) -> Any: ... @@ -211,6 +211,7 @@ class GraphRuntimeState: graph_execution: GraphExecutionProtocol | None = None, response_coordinator: ResponseStreamCoordinatorProtocol | None = None, graph: GraphProtocol | None = None, + execution_context: AbstractContextManager[object] | None = None, ) -> None: self._variable_pool = variable_pool self._start_at = start_at @@ -231,6 +232,9 @@ class GraphRuntimeState: self._ready_queue = ready_queue self._graph_execution = graph_execution self._response_coordinator = response_coordinator + # Application code injects this when worker threads must restore request + # or framework-local state. It is intentionally excluded from snapshots. + self._execution_context = execution_context if execution_context is not None else nullcontext(None) self._pending_response_coordinator_dump: str | None = None self._pending_graph_execution_workflow_id: str | None = None self._paused_nodes: set[str] = set() @@ -285,21 +289,19 @@ class GraphRuntimeState: *, workflow_id: str, graph_init_params: GraphInitParams, - graph_runtime_state: GraphRuntimeState, - graph_config: Mapping[str, Any], root_node_id: str, - layers: Sequence[object] = (), + variable_pool: VariablePool | None = None, ) -> Any: + """Create a child graph engine that derives its runtime state from the parent.""" if self._child_engine_builder is None: raise ChildEngineBuilderNotConfiguredError("Child engine builder is not configured.") return self._child_engine_builder.build_child_engine( workflow_id=workflow_id, graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - graph_config=graph_config, + parent_graph_runtime_state=self, root_node_id=root_node_id, - layers=layers, + variable_pool=variable_pool, ) # ------------------------------------------------------------------ @@ -329,6 +331,14 @@ class GraphRuntimeState: self._response_coordinator = self._build_response_coordinator(self._graph) return self._response_coordinator + @property + def execution_context(self) -> AbstractContextManager[object]: + return self._execution_context + + @execution_context.setter + def execution_context(self, value: AbstractContextManager[object] | None) -> None: + self._execution_context = value if value is not None else nullcontext(None) + # ------------------------------------------------------------------ # Scalar state # ------------------------------------------------------------------ @@ -485,13 +495,13 @@ class GraphRuntimeState: # ------------------------------------------------------------------ def _build_ready_queue(self) -> ReadyQueueProtocol: # Import lazily to avoid breaching architecture boundaries enforced by import-linter. - module = importlib.import_module("dify_graph.graph_engine.ready_queue") + module = importlib.import_module("graphon.graph_engine.ready_queue") in_memory_cls = module.InMemoryReadyQueue return in_memory_cls() def _build_graph_execution(self) -> GraphExecutionProtocol: # Lazily import to keep the runtime domain decoupled from graph_engine modules. - module = importlib.import_module("dify_graph.graph_engine.domain.graph_execution") + module = importlib.import_module("graphon.graph_engine.domain.graph_execution") graph_execution_cls = module.GraphExecution workflow_id = self._pending_graph_execution_workflow_id or "" self._pending_graph_execution_workflow_id = None @@ -499,7 +509,7 @@ class GraphRuntimeState: def _build_response_coordinator(self, graph: GraphProtocol) -> ResponseStreamCoordinatorProtocol: # Lazily import to keep the runtime domain decoupled from graph_engine modules. - module = importlib.import_module("dify_graph.graph_engine.response_coordinator") + module = importlib.import_module("graphon.graph_engine.response_coordinator") coordinator_cls = module.ResponseStreamCoordinator return coordinator_cls(variable_pool=self.variable_pool, graph=graph) diff --git a/api/dify_graph/runtime/graph_runtime_state_protocol.py b/api/graphon/runtime/graph_runtime_state_protocol.py similarity index 89% rename from api/dify_graph/runtime/graph_runtime_state_protocol.py rename to api/graphon/runtime/graph_runtime_state_protocol.py index 7e55ece3f1..856625a5d3 100644 --- a/api/dify_graph/runtime/graph_runtime_state_protocol.py +++ b/api/graphon/runtime/graph_runtime_state_protocol.py @@ -1,9 +1,8 @@ from collections.abc import Mapping, Sequence from typing import Any, Protocol -from dify_graph.model_runtime.entities.llm_entities import LLMUsage -from dify_graph.system_variable import SystemVariableReadOnlyView -from dify_graph.variables.segments import Segment +from graphon.model_runtime.entities.llm_entities import LLMUsage +from graphon.variables.segments import Segment class ReadOnlyVariablePool(Protocol): @@ -31,9 +30,6 @@ class ReadOnlyGraphRuntimeState(Protocol): All methods return defensive copies to ensure immutability. """ - @property - def system_variable(self) -> SystemVariableReadOnlyView: ... - @property def variable_pool(self) -> ReadOnlyVariablePool: """Get read-only access to the variable pool.""" diff --git a/api/dify_graph/runtime/read_only_wrappers.py b/api/graphon/runtime/read_only_wrappers.py similarity index 88% rename from api/dify_graph/runtime/read_only_wrappers.py rename to api/graphon/runtime/read_only_wrappers.py index ca06d88c3d..aaef255204 100644 --- a/api/dify_graph/runtime/read_only_wrappers.py +++ b/api/graphon/runtime/read_only_wrappers.py @@ -4,9 +4,8 @@ from collections.abc import Mapping, Sequence from copy import deepcopy from typing import Any -from dify_graph.model_runtime.entities.llm_entities import LLMUsage -from dify_graph.system_variable import SystemVariableReadOnlyView -from dify_graph.variables.segments import Segment +from graphon.model_runtime.entities.llm_entities import LLMUsage +from graphon.variables.segments import Segment from .graph_runtime_state import GraphRuntimeState from .variable_pool import VariablePool @@ -43,10 +42,6 @@ class ReadOnlyGraphRuntimeStateWrapper: self._state = state self._variable_pool_wrapper = ReadOnlyVariablePoolWrapper(state.variable_pool) - @property - def system_variable(self) -> SystemVariableReadOnlyView: - return self._state.variable_pool.system_variables.as_view() - @property def variable_pool(self) -> ReadOnlyVariablePoolWrapper: return self._variable_pool_wrapper diff --git a/api/dify_graph/runtime/variable_pool.py b/api/graphon/runtime/variable_pool.py similarity index 63% rename from api/dify_graph/runtime/variable_pool.py rename to api/graphon/runtime/variable_pool.py index e3ef6a2897..b44d1a8abe 100644 --- a/api/dify_graph/runtime/variable_pool.py +++ b/api/graphon/runtime/variable_pool.py @@ -6,84 +6,84 @@ from collections.abc import Mapping, Sequence from copy import deepcopy from typing import Annotated, Any, Union, cast -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, model_validator -from dify_graph.constants import ( - CONVERSATION_VARIABLE_NODE_ID, - ENVIRONMENT_VARIABLE_NODE_ID, - RAG_PIPELINE_VARIABLE_NODE_ID, - SYSTEM_VARIABLE_NODE_ID, -) -from dify_graph.file import File, FileAttribute, file_manager -from dify_graph.system_variable import SystemVariable -from dify_graph.variables import Segment, SegmentGroup, VariableBase -from dify_graph.variables.consts import SELECTORS_LENGTH -from dify_graph.variables.segments import FileSegment, ObjectSegment -from dify_graph.variables.variables import RAGPipelineVariableInput, Variable -from factories import variable_factory +from graphon.file import File, FileAttribute, file_manager +from graphon.variables import Segment, SegmentGroup, VariableBase, build_segment, segment_to_variable +from graphon.variables.consts import SELECTORS_LENGTH +from graphon.variables.segments import FileSegment, ObjectSegment +from graphon.variables.variables import RAGPipelineVariableInput, Variable VariableValue = Union[str, int, float, dict[str, object], list[object], File] VARIABLE_PATTERN = re.compile(r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}") +def _default_variable_dictionary() -> defaultdict[str, dict[str, Variable]]: + return defaultdict(dict) + + class VariablePool(BaseModel): + _SYSTEM_VARIABLE_NODE_ID = "sys" + _ENVIRONMENT_VARIABLE_NODE_ID = "env" + _CONVERSATION_VARIABLE_NODE_ID = "conversation" + _RAG_PIPELINE_VARIABLE_NODE_ID = "rag" + # Variable dictionary is a dictionary for looking up variables by their selector. # The first element of the selector is the node id, it's the first-level key in the dictionary. # Other elements of the selector are the keys in the second-level dictionary. To get the key, we hash the # elements of the selector except the first one. variable_dictionary: defaultdict[str, Annotated[dict[str, Variable], Field(default_factory=dict)]] = Field( description="Variables mapping", - default=defaultdict(dict), + default_factory=_default_variable_dictionary, ) + system_variables: Sequence[Variable] = Field(default_factory=tuple, exclude=True) + environment_variables: Sequence[Variable] = Field(default_factory=tuple, exclude=True) + conversation_variables: Sequence[Variable] = Field(default_factory=tuple, exclude=True) + rag_pipeline_variables: Sequence[RAGPipelineVariableInput] = Field(default_factory=tuple, exclude=True) + user_inputs: Mapping[str, Any] = Field(default_factory=dict, exclude=True) - # The `user_inputs` is used only when constructing the inputs for the `StartNode`. It's not used elsewhere. - user_inputs: Mapping[str, Any] = Field( - description="User inputs", - default_factory=dict, - ) - system_variables: SystemVariable = Field( - description="System variables", - default_factory=SystemVariable.default, - ) - environment_variables: Sequence[Variable] = Field( - description="Environment variables.", - default_factory=list[Variable], - ) - conversation_variables: Sequence[Variable] = Field( - description="Conversation variables.", - default_factory=list[Variable], - ) - rag_pipeline_variables: list[RAGPipelineVariableInput] = Field( - description="RAG pipeline variables.", - default_factory=list, - ) + @model_validator(mode="after") + def _load_legacy_bootstrap_inputs(self) -> VariablePool: + """ + Accept legacy constructor kwargs that still appear throughout the workflow + layer while keeping serialized state focused on `variable_dictionary`. + """ - def model_post_init(self, context: Any, /): - # Create a mapping from field names to SystemVariableKey enum values - self._add_system_variables(self.system_variables) - # Add environment variables to the variable pool - for var in self.environment_variables: - self.add((ENVIRONMENT_VARIABLE_NODE_ID, var.name), var) - # Add conversation variables to the variable pool. When restoring from a serialized - # snapshot, `variable_dictionary` already carries the latest runtime values. - # In that case, keep existing entries instead of overwriting them with the - # bootstrap list. - for var in self.conversation_variables: - selector = (CONVERSATION_VARIABLE_NODE_ID, var.name) - if self._has(selector): - continue - self.add(selector, var) - # Add rag pipeline variables to the variable pool - if self.rag_pipeline_variables: - rag_pipeline_variables_map: defaultdict[Any, dict[Any, Any]] = defaultdict(dict) - for rag_var in self.rag_pipeline_variables: - node_id = rag_var.variable.belong_to_node_id - key = rag_var.variable.variable - value = rag_var.value - rag_pipeline_variables_map[node_id][key] = value - for key, value in rag_pipeline_variables_map.items(): - self.add((RAG_PIPELINE_VARIABLE_NODE_ID, key), value) + self._ingest_legacy_variables(self.system_variables, node_id=self._SYSTEM_VARIABLE_NODE_ID) + self._ingest_legacy_variables(self.environment_variables, node_id=self._ENVIRONMENT_VARIABLE_NODE_ID) + self._ingest_legacy_variables(self.conversation_variables, node_id=self._CONVERSATION_VARIABLE_NODE_ID) + self._ingest_legacy_rag_variables(self.rag_pipeline_variables) + + # These kwargs are accepted for compatibility but should not affect the + # stable serialized form or model equality. + self.system_variables = () + self.environment_variables = () + self.conversation_variables = () + self.rag_pipeline_variables = () + self.user_inputs = {} + return self + + def _ingest_legacy_variables(self, variables: Sequence[Variable], *, node_id: str) -> None: + for variable in variables: + selector = [node_id, variable.name] + normalized_variable = variable + if list(variable.selector) != selector: + normalized_variable = variable.model_copy(update={"selector": selector}) + self.add(normalized_variable.selector, normalized_variable) + + def _ingest_legacy_rag_variables(self, rag_pipeline_variables: Sequence[RAGPipelineVariableInput]) -> None: + if not rag_pipeline_variables: + return + + values_by_node_id: defaultdict[str, dict[str, Any]] = defaultdict(dict) + for rag_variable_input in rag_pipeline_variables: + values_by_node_id[rag_variable_input.variable.belong_to_node_id][rag_variable_input.variable.variable] = ( + rag_variable_input.value + ) + + for node_id, value in values_by_node_id.items(): + self.add((self._RAG_PIPELINE_VARIABLE_NODE_ID, node_id), value) def add(self, selector: Sequence[str], value: Any, /): """ @@ -114,10 +114,10 @@ class VariablePool(BaseModel): if isinstance(value, VariableBase): variable = value elif isinstance(value, Segment): - variable = variable_factory.segment_to_variable(segment=value, selector=selector) + variable = segment_to_variable(segment=value, selector=selector) else: - segment = variable_factory.build_segment(value) - variable = variable_factory.segment_to_variable(segment=segment, selector=selector) + segment = build_segment(value) + variable = segment_to_variable(segment=segment, selector=selector) node_id, name = self._selector_to_keys(selector) # Based on the definition of `Variable`, @@ -180,7 +180,7 @@ class VariablePool(BaseModel): return None attr = FileAttribute(attr) attr_value = file_manager.get_attr(file=segment.value, attr=attr) - return variable_factory.build_segment(attr_value) + return build_segment(attr_value) # Navigate through nested attributes result: Any = segment @@ -191,7 +191,7 @@ class VariablePool(BaseModel): return None # Return result as Segment - return result if isinstance(result, Segment) else variable_factory.build_segment(result) + return result if isinstance(result, Segment) else build_segment(result) def _extract_value(self, obj: Any): """Extract the actual value from an ObjectSegment.""" @@ -212,7 +212,7 @@ class VariablePool(BaseModel): """ if not isinstance(obj, dict) or attr not in obj: return None - return variable_factory.build_segment(obj.get(attr)) + return build_segment(obj.get(attr)) def remove(self, selector: Sequence[str], /): """ @@ -239,7 +239,7 @@ class VariablePool(BaseModel): if "." in part and (variable := self.get(part.split("."))): segments.append(variable) else: - segments.append(variable_factory.build_segment(part)) + segments.append(build_segment(part)) return SegmentGroup(value=segments) def get_file(self, selector: Sequence[str], /) -> FileSegment | None: @@ -262,19 +262,18 @@ class VariablePool(BaseModel): return result - def _add_system_variables(self, system_variable: SystemVariable): - sys_var_mapping = system_variable.to_dict() - for key, value in sys_var_mapping.items(): - if value is None: - continue - selector = (SYSTEM_VARIABLE_NODE_ID, key) - # If the system variable already exists, do not add it again. - # This ensures that we can keep the id of the system variables intact. - if self._has(selector): - continue - self.add(selector, value) + def flatten(self, *, unprefixed_node_id: str | None = None) -> Mapping[str, object]: + """Return a selector-style snapshot of the entire variable pool.""" + + result: dict[str, object] = {} + for node_id, variables in self.variable_dictionary.items(): + for name, variable in variables.items(): + output_name = name if node_id == unprefixed_node_id else f"{node_id}.{name}" + result[output_name] = deepcopy(variable.value) + + return result @classmethod def empty(cls) -> VariablePool: """Create an empty variable pool.""" - return cls(system_variables=SystemVariable.default()) + return cls() diff --git a/api/graphon/template_rendering.py b/api/graphon/template_rendering.py new file mode 100644 index 0000000000..0527e58f6d --- /dev/null +++ b/api/graphon/template_rendering.py @@ -0,0 +1,18 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from collections.abc import Mapping +from typing import Any + + +class TemplateRenderError(ValueError): + """Raised when rendering a template fails.""" + + +class Jinja2TemplateRenderer(ABC): + """Nominal renderer contract for Jinja2 template rendering in graph nodes.""" + + @abstractmethod + def render_template(self, template: str, variables: Mapping[str, Any]) -> str: + """Render the template into plain text.""" + raise NotImplementedError diff --git a/api/dify_graph/utils/__init__.py b/api/graphon/utils/__init__.py similarity index 100% rename from api/dify_graph/utils/__init__.py rename to api/graphon/utils/__init__.py diff --git a/api/dify_graph/utils/condition/__init__.py b/api/graphon/utils/condition/__init__.py similarity index 100% rename from api/dify_graph/utils/condition/__init__.py rename to api/graphon/utils/condition/__init__.py diff --git a/api/dify_graph/utils/condition/entities.py b/api/graphon/utils/condition/entities.py similarity index 100% rename from api/dify_graph/utils/condition/entities.py rename to api/graphon/utils/condition/entities.py diff --git a/api/dify_graph/utils/condition/processor.py b/api/graphon/utils/condition/processor.py similarity index 98% rename from api/dify_graph/utils/condition/processor.py rename to api/graphon/utils/condition/processor.py index dea72d96c2..03535927cb 100644 --- a/api/dify_graph/utils/condition/processor.py +++ b/api/graphon/utils/condition/processor.py @@ -2,10 +2,10 @@ import json from collections.abc import Mapping, Sequence from typing import Literal, NamedTuple -from dify_graph.file import FileAttribute, file_manager -from dify_graph.runtime import VariablePool -from dify_graph.variables import ArrayFileSegment -from dify_graph.variables.segments import ArrayBooleanSegment, BooleanSegment +from graphon.file import FileAttribute, file_manager +from graphon.runtime import VariablePool +from graphon.variables import ArrayFileSegment +from graphon.variables.segments import ArrayBooleanSegment, BooleanSegment from .entities import Condition, SubCondition, SupportedComparisonOperator diff --git a/api/graphon/utils/json_in_md_parser.py b/api/graphon/utils/json_in_md_parser.py new file mode 100644 index 0000000000..4416b4774b --- /dev/null +++ b/api/graphon/utils/json_in_md_parser.py @@ -0,0 +1,58 @@ +from __future__ import annotations + +import json + + +class OutputParserError(ValueError): + """Raised when a markdown-wrapped JSON payload cannot be parsed or validated.""" + + +def parse_json_markdown(json_string: str) -> dict | list: + """Extract and parse the first JSON object or array embedded in markdown text.""" + json_string = json_string.strip() + starts = ["```json", "```", "``", "`", "{", "["] + ends = ["```", "``", "`", "}", "]"] + end_index = -1 + start_index = 0 + + for start_marker in starts: + start_index = json_string.find(start_marker) + if start_index != -1: + if json_string[start_index] not in ("{", "["): + start_index += len(start_marker) + break + + if start_index != -1: + for end_marker in ends: + end_index = json_string.rfind(end_marker, start_index) + if end_index != -1: + if json_string[end_index] in ("}", "]"): + end_index += 1 + break + + if start_index == -1 or end_index == -1 or start_index >= end_index: + raise ValueError("could not find json block in the output.") + + extracted_content = json_string[start_index:end_index].strip() + return json.loads(extracted_content) + + +def parse_and_check_json_markdown(text: str, expected_keys: list[str]) -> dict: + try: + json_obj = parse_json_markdown(text) + except json.JSONDecodeError as exc: + raise OutputParserError(f"got invalid json object. error: {exc}") from exc + + if isinstance(json_obj, list): + if len(json_obj) == 1 and isinstance(json_obj[0], dict): + json_obj = json_obj[0] + else: + raise OutputParserError(f"got invalid return object. obj:{json_obj}") + + for key in expected_keys: + if key not in json_obj: + raise OutputParserError( + f"got invalid return object. expected key `{key}` to be present, but got {json_obj}" + ) + + return json_obj diff --git a/api/dify_graph/variable_loader.py b/api/graphon/variable_loader.py similarity index 82% rename from api/dify_graph/variable_loader.py rename to api/graphon/variable_loader.py index d263450334..03db920d3d 100644 --- a/api/dify_graph/variable_loader.py +++ b/api/graphon/variable_loader.py @@ -2,9 +2,9 @@ import abc from collections.abc import Mapping, Sequence from typing import Any, Protocol -from dify_graph.runtime import VariablePool -from dify_graph.variables import VariableBase -from dify_graph.variables.consts import SELECTORS_LENGTH +from graphon.runtime import VariablePool +from graphon.variables import VariableBase +from graphon.variables.consts import SELECTORS_LENGTH class VariableLoader(Protocol): @@ -13,14 +13,6 @@ class VariableLoader(Protocol): A `VariableLoader` is responsible for retrieving additional variables required during the execution of a single node, which are not provided as user inputs. - NOTE(QuantumGhost): Typically, all variables loaded by a `VariableLoader` should belong to the same - application and share the same `app_id`. However, this interface does not enforce that constraint, - and the `app_id` parameter is intentionally omitted from `load_variables` to achieve separation of - concern and allow for flexible implementations. - - Implementations of `VariableLoader` should almost always have an `app_id` parameter in - their constructor. - TODO(QuantumGhost): this is a temporally workaround. If we can move the creation of node instance into `WorkflowService.single_step_run`, we may get rid of this interface. """ diff --git a/api/dify_graph/variables/__init__.py b/api/graphon/variables/__init__.py similarity index 83% rename from api/dify_graph/variables/__init__.py rename to api/graphon/variables/__init__.py index be3fc8d97a..e9beb6cb95 100644 --- a/api/dify_graph/variables/__init__.py +++ b/api/graphon/variables/__init__.py @@ -1,3 +1,10 @@ +from .factory import ( + TypeMismatchError, + UnsupportedSegmentTypeError, + build_segment, + build_segment_with_type, + segment_to_variable, +) from .input_entities import VariableEntity, VariableEntityType from .segment_group import SegmentGroup from .segments import ( @@ -63,8 +70,13 @@ __all__ = [ "SegmentType", "StringSegment", "StringVariable", + "TypeMismatchError", + "UnsupportedSegmentTypeError", "Variable", "VariableBase", "VariableEntity", "VariableEntityType", + "build_segment", + "build_segment_with_type", + "segment_to_variable", ] diff --git a/api/dify_graph/variables/consts.py b/api/graphon/variables/consts.py similarity index 100% rename from api/dify_graph/variables/consts.py rename to api/graphon/variables/consts.py diff --git a/api/dify_graph/variables/exc.py b/api/graphon/variables/exc.py similarity index 100% rename from api/dify_graph/variables/exc.py rename to api/graphon/variables/exc.py diff --git a/api/graphon/variables/factory.py b/api/graphon/variables/factory.py new file mode 100644 index 0000000000..ac693914a7 --- /dev/null +++ b/api/graphon/variables/factory.py @@ -0,0 +1,202 @@ +"""Graph-owned helpers for converting runtime values, segments, and variables. + +These conversions are part of the `graphon` runtime model and must stay +independent from top-level API factory modules so graph nodes and state +containers can operate without importing application-layer packages. +""" + +from collections.abc import Mapping, Sequence +from typing import Any, cast +from uuid import uuid4 + +from graphon.file import File + +from .segments import ( + ArrayAnySegment, + ArrayBooleanSegment, + ArrayFileSegment, + ArrayNumberSegment, + ArrayObjectSegment, + ArraySegment, + ArrayStringSegment, + BooleanSegment, + FileSegment, + FloatSegment, + IntegerSegment, + NoneSegment, + ObjectSegment, + Segment, + StringSegment, +) +from .types import SegmentType +from .variables import ( + ArrayAnyVariable, + ArrayBooleanVariable, + ArrayFileVariable, + ArrayNumberVariable, + ArrayObjectVariable, + ArrayStringVariable, + BooleanVariable, + FileVariable, + FloatVariable, + IntegerVariable, + NoneVariable, + ObjectVariable, + StringVariable, + VariableBase, +) + + +class UnsupportedSegmentTypeError(Exception): + pass + + +class TypeMismatchError(Exception): + pass + + +SEGMENT_TO_VARIABLE_MAP: Mapping[type[Segment], type[Any]] = { + ArrayAnySegment: ArrayAnyVariable, + ArrayBooleanSegment: ArrayBooleanVariable, + ArrayFileSegment: ArrayFileVariable, + ArrayNumberSegment: ArrayNumberVariable, + ArrayObjectSegment: ArrayObjectVariable, + ArrayStringSegment: ArrayStringVariable, + BooleanSegment: BooleanVariable, + FileSegment: FileVariable, + FloatSegment: FloatVariable, + IntegerSegment: IntegerVariable, + NoneSegment: NoneVariable, + ObjectSegment: ObjectVariable, + StringSegment: StringVariable, +} + + +def build_segment(value: Any, /) -> Segment: + """Build a runtime segment from a Python value.""" + if value is None: + return NoneSegment() + if isinstance(value, Segment): + return value + if isinstance(value, str): + return StringSegment(value=value) + if isinstance(value, bool): + return BooleanSegment(value=value) + if isinstance(value, int): + return IntegerSegment(value=value) + if isinstance(value, float): + return FloatSegment(value=value) + if isinstance(value, dict): + return ObjectSegment(value=value) + if isinstance(value, File): + return FileSegment(value=value) + if isinstance(value, list): + items = [build_segment(item) for item in value] + types = {item.value_type for item in items} + if all(isinstance(item, ArraySegment) for item in items): + return ArrayAnySegment(value=value) + if len(types) != 1: + if types.issubset({SegmentType.NUMBER, SegmentType.INTEGER, SegmentType.FLOAT}): + return ArrayNumberSegment(value=value) + return ArrayAnySegment(value=value) + + match types.pop(): + case SegmentType.STRING: + return ArrayStringSegment(value=value) + case SegmentType.NUMBER | SegmentType.INTEGER | SegmentType.FLOAT: + return ArrayNumberSegment(value=value) + case SegmentType.BOOLEAN: + return ArrayBooleanSegment(value=value) + case SegmentType.OBJECT: + return ArrayObjectSegment(value=value) + case SegmentType.FILE: + return ArrayFileSegment(value=value) + case SegmentType.NONE: + return ArrayAnySegment(value=value) + case _: + raise ValueError(f"not supported value {value}") + raise ValueError(f"not supported value {value}") + + +_SEGMENT_FACTORY: Mapping[SegmentType, type[Segment]] = { + SegmentType.NONE: NoneSegment, + SegmentType.STRING: StringSegment, + SegmentType.INTEGER: IntegerSegment, + SegmentType.FLOAT: FloatSegment, + SegmentType.FILE: FileSegment, + SegmentType.BOOLEAN: BooleanSegment, + SegmentType.OBJECT: ObjectSegment, + SegmentType.ARRAY_ANY: ArrayAnySegment, + SegmentType.ARRAY_STRING: ArrayStringSegment, + SegmentType.ARRAY_NUMBER: ArrayNumberSegment, + SegmentType.ARRAY_OBJECT: ArrayObjectSegment, + SegmentType.ARRAY_FILE: ArrayFileSegment, + SegmentType.ARRAY_BOOLEAN: ArrayBooleanSegment, +} + + +def build_segment_with_type(segment_type: SegmentType, value: Any) -> Segment: + """Build a segment while enforcing compatibility with the expected runtime type.""" + if value is None: + if segment_type == SegmentType.NONE: + return NoneSegment() + raise TypeMismatchError(f"Type mismatch: expected {segment_type}, but got None") + + if isinstance(value, list) and len(value) == 0: + if segment_type == SegmentType.ARRAY_ANY: + return ArrayAnySegment(value=value) + if segment_type == SegmentType.ARRAY_STRING: + return ArrayStringSegment(value=value) + if segment_type == SegmentType.ARRAY_BOOLEAN: + return ArrayBooleanSegment(value=value) + if segment_type == SegmentType.ARRAY_NUMBER: + return ArrayNumberSegment(value=value) + if segment_type == SegmentType.ARRAY_OBJECT: + return ArrayObjectSegment(value=value) + if segment_type == SegmentType.ARRAY_FILE: + return ArrayFileSegment(value=value) + raise TypeMismatchError(f"Type mismatch: expected {segment_type}, but got empty list") + + inferred_type = SegmentType.infer_segment_type(value) + if inferred_type is None: + raise TypeMismatchError( + f"Type mismatch: expected {segment_type}, but got python object, type={type(value)}, value={value}" + ) + if inferred_type == segment_type: + segment_class = _SEGMENT_FACTORY[segment_type] + return segment_class(value_type=segment_type, value=value) + if segment_type == SegmentType.NUMBER and inferred_type in (SegmentType.INTEGER, SegmentType.FLOAT): + segment_class = _SEGMENT_FACTORY[inferred_type] + return segment_class(value_type=inferred_type, value=value) + raise TypeMismatchError(f"Type mismatch: expected {segment_type}, but got {inferred_type}, value={value}") + + +def segment_to_variable( + *, + segment: Segment, + selector: Sequence[str], + id: str | None = None, + name: str | None = None, + description: str = "", +) -> VariableBase: + """Convert a runtime segment into a runtime variable for storage in the pool.""" + if isinstance(segment, VariableBase): + return segment + name = name or selector[-1] + id = id or str(uuid4()) + + segment_type = type(segment) + if segment_type not in SEGMENT_TO_VARIABLE_MAP: + raise UnsupportedSegmentTypeError(f"not supported segment type {segment_type}") + + variable_class = SEGMENT_TO_VARIABLE_MAP[segment_type] + return cast( + VariableBase, + variable_class( + id=id, + name=name, + description=description, + value=segment.value, + selector=list(selector), + ), + ) diff --git a/api/dify_graph/variables/input_entities.py b/api/graphon/variables/input_entities.py similarity index 97% rename from api/dify_graph/variables/input_entities.py rename to api/graphon/variables/input_entities.py index e6a68ea359..c46ee47714 100644 --- a/api/dify_graph/variables/input_entities.py +++ b/api/graphon/variables/input_entities.py @@ -5,7 +5,7 @@ from typing import Any from jsonschema import Draft7Validator, SchemaError from pydantic import BaseModel, Field, field_validator -from dify_graph.file import FileTransferMethod, FileType +from graphon.file import FileTransferMethod, FileType class VariableEntityType(StrEnum): diff --git a/api/dify_graph/variables/segment_group.py b/api/graphon/variables/segment_group.py similarity index 100% rename from api/dify_graph/variables/segment_group.py rename to api/graphon/variables/segment_group.py diff --git a/api/dify_graph/variables/segments.py b/api/graphon/variables/segments.py similarity index 99% rename from api/dify_graph/variables/segments.py rename to api/graphon/variables/segments.py index bdb213ed48..8902ddc7e9 100644 --- a/api/dify_graph/variables/segments.py +++ b/api/graphon/variables/segments.py @@ -5,7 +5,7 @@ from typing import Annotated, Any, TypeAlias from pydantic import BaseModel, ConfigDict, Discriminator, Tag, field_validator -from dify_graph.file import File +from graphon.file import File from .types import SegmentType diff --git a/api/dify_graph/variables/types.py b/api/graphon/variables/types.py similarity index 93% rename from api/dify_graph/variables/types.py rename to api/graphon/variables/types.py index 53bf495a27..949a693ad2 100644 --- a/api/dify_graph/variables/types.py +++ b/api/graphon/variables/types.py @@ -4,10 +4,10 @@ from collections.abc import Mapping from enum import StrEnum from typing import TYPE_CHECKING, Any -from dify_graph.file.models import File +from graphon.file.models import File if TYPE_CHECKING: - from dify_graph.variables.segments import Segment + from graphon.variables.segments import Segment class ArrayValidation(StrEnum): @@ -220,8 +220,8 @@ class SegmentType(StrEnum): @staticmethod def get_zero_value(t: SegmentType) -> Segment: - # Lazy import to avoid circular dependency - from factories import variable_factory + # Lazy import to avoid circular dependency between segment types and factory helpers. + from graphon.variables.factory import build_segment, build_segment_with_type match t: case ( @@ -231,19 +231,19 @@ class SegmentType(StrEnum): | SegmentType.ARRAY_NUMBER | SegmentType.ARRAY_BOOLEAN ): - return variable_factory.build_segment_with_type(t, []) + return build_segment_with_type(t, []) case SegmentType.OBJECT: - return variable_factory.build_segment({}) + return build_segment({}) case SegmentType.STRING: - return variable_factory.build_segment("") + return build_segment("") case SegmentType.INTEGER: - return variable_factory.build_segment(0) + return build_segment(0) case SegmentType.FLOAT: - return variable_factory.build_segment(0.0) + return build_segment(0.0) case SegmentType.NUMBER: - return variable_factory.build_segment(0) + return build_segment(0) case SegmentType.BOOLEAN: - return variable_factory.build_segment(False) + return build_segment(False) case _: raise ValueError(f"unsupported variable type: {t}") diff --git a/api/dify_graph/variables/utils.py b/api/graphon/variables/utils.py similarity index 100% rename from api/dify_graph/variables/utils.py rename to api/graphon/variables/utils.py diff --git a/api/dify_graph/variables/variables.py b/api/graphon/variables/variables.py similarity index 100% rename from api/dify_graph/variables/variables.py rename to api/graphon/variables/variables.py diff --git a/api/dify_graph/workflow_type_encoder.py b/api/graphon/workflow_type_encoder.py similarity index 95% rename from api/dify_graph/workflow_type_encoder.py rename to api/graphon/workflow_type_encoder.py index 3dd846b3cb..7cdc83ebdb 100644 --- a/api/dify_graph/workflow_type_encoder.py +++ b/api/graphon/workflow_type_encoder.py @@ -4,8 +4,8 @@ from typing import Any, overload from pydantic import BaseModel -from dify_graph.file.models import File -from dify_graph.variables import Segment +from graphon.file.models import File +from graphon.variables import Segment class WorkflowRuntimeTypeConverter: diff --git a/api/libs/broadcast_channel/channel.py b/api/libs/broadcast_channel/channel.py index d4cb3e9971..8eeac37232 100644 --- a/api/libs/broadcast_channel/channel.py +++ b/api/libs/broadcast_channel/channel.py @@ -125,7 +125,8 @@ class BroadcastChannel(Protocol): a specific topic, all subscription should receive the published message. There are no restriction for the persistence of messages. Once a subscription is created, it - should receive all subsequent messages published. + should receive all subsequent messages published. However, a subscription should not receive + any message published before the subscription is established. `BroadcastChannel` implementations must be thread-safe and support concurrent use by multiple threads. """ diff --git a/api/libs/broadcast_channel/redis/streams_channel.py b/api/libs/broadcast_channel/redis/streams_channel.py index d6ec5504ca..983f785027 100644 --- a/api/libs/broadcast_channel/redis/streams_channel.py +++ b/api/libs/broadcast_channel/redis/streams_channel.py @@ -63,21 +63,45 @@ class _StreamsSubscription(Subscription): def __init__(self, client: Redis | RedisCluster, key: str): self._client = client self._key = key - self._closed = threading.Event() - self._last_id = "0-0" + self._queue: queue.Queue[object] = queue.Queue() - self._start_lock = threading.Lock() + + # The `_lock` lock is used to + # + # 1. protect the _listener attribute + # 2. prevent repeated releases of underlying resoueces. (The _closed flag.) + # + # INVARIANT: the implementation must hold the lock while + # reading and writing the _listener / `_closed` attribute. + self._lock = threading.Lock() + self._closed: bool = False + # self._closed = threading.Event() self._listener: threading.Thread | None = None def _listen(self) -> None: - try: - while not self._closed.is_set(): - streams = self._client.xread({self._key: self._last_id}, block=1000, count=100) + """The `_listen` method handles the message retrieval loop. It requires a dedicated thread + and is not intended for direct invocation. + The thread is started by `_start_if_needed`. + """ + + # since this method runs in a dedicated thread, acquiring `_lock` inside this method won't cause + # deadlock. + + # Setting initial last id to `$` to signal redis that we only want new messages. + # + # ref: https://redis.io/docs/latest/commands/xread/#the-special--id + last_id = "$" + try: + while True: + with self._lock: + if self._closed: + break + streams = self._client.xread({self._key: last_id}, block=1000, count=100) if not streams: continue - for _key, entries in streams: + for _, entries in streams: for entry_id, fields in entries: data = None if isinstance(fields, dict): @@ -89,37 +113,48 @@ class _StreamsSubscription(Subscription): data_bytes = bytes(data) if data_bytes is not None: self._queue.put_nowait(data_bytes) - self._last_id = entry_id + last_id = entry_id finally: self._queue.put_nowait(self._SENTINEL) - self._listener = None + with self._lock: + self._listener = None + self._closed = True def _start_if_needed(self) -> None: + """This method must be called with `_lock` held.""" if self._listener is not None: return # Ensure only one listener thread is created under concurrent calls - with self._start_lock: - if self._listener is not None or self._closed.is_set(): - return - self._listener = threading.Thread( - target=self._listen, - name=f"redis-streams-sub-{self._key}", - daemon=True, - ) - self._listener.start() + if self._listener is not None or self._closed: + return + self._listener = threading.Thread( + target=self._listen, + name=f"redis-streams-sub-{self._key}", + daemon=True, + ) + self._listener.start() def __iter__(self) -> Iterator[bytes]: # Iterator delegates to receive with timeout; stops on closure. - self._start_if_needed() - while not self._closed.is_set(): - item = self.receive(timeout=1) + with self._lock: + self._start_if_needed() + + while True: + with self._lock: + if self._closed: + return + try: + item = self.receive(timeout=1) + except SubscriptionClosedError: + return if item is not None: yield item def receive(self, timeout: float | None = 0.1) -> bytes | None: - if self._closed.is_set(): - raise SubscriptionClosedError("The Redis streams subscription is closed") - self._start_if_needed() + with self._lock: + if self._closed: + raise SubscriptionClosedError("The Redis streams subscription is closed") + self._start_if_needed() try: if timeout is None: @@ -129,29 +164,33 @@ class _StreamsSubscription(Subscription): except queue.Empty: return None - if item is self._SENTINEL or self._closed.is_set(): + if item is self._SENTINEL: raise SubscriptionClosedError("The Redis streams subscription is closed") assert isinstance(item, (bytes, bytearray)), "Unexpected item type in stream queue" return bytes(item) def close(self) -> None: - if self._closed.is_set(): - return - self._closed.set() - listener = self._listener - if listener is not None: + with self._lock: + if self._closed: + return + self._closed = True + listener = self._listener + if listener is not None: + self._listener = None + # We close the listener outside of the with block to avoid holding the + # lock for a long time. + if listener is not None and listener.is_alive(): listener.join(timeout=2.0) if listener.is_alive(): logger.warning( "Streams subscription listener for key %s did not stop within timeout; keeping reference.", self._key, ) - else: - self._listener = None # Context manager helpers def __enter__(self) -> Self: - self._start_if_needed() + with self._lock: + self._start_if_needed() return self def __exit__(self, exc_type, exc_value, traceback) -> bool | None: diff --git a/api/libs/datetime_utils.py b/api/libs/datetime_utils.py index c08578981b..e0a6ec2cac 100644 --- a/api/libs/datetime_utils.py +++ b/api/libs/datetime_utils.py @@ -2,7 +2,7 @@ import abc import datetime from typing import Protocol -import pytz +import pytz # type: ignore[import-untyped] class _NowFunction(Protocol): diff --git a/api/libs/helper.py b/api/libs/helper.py index e7572cc025..b1815859a5 100644 --- a/api/libs/helper.py +++ b/api/libs/helper.py @@ -21,9 +21,9 @@ from pydantic.functional_validators import AfterValidator from configs import dify_config from core.app.features.rate_limiting.rate_limit import RateLimitGenerator -from dify_graph.file import helpers as file_helpers -from dify_graph.model_runtime.utils.encoders import jsonable_encoder from extensions.ext_redis import redis_client +from graphon.file import helpers as file_helpers +from graphon.model_runtime.utils.encoders import jsonable_encoder if TYPE_CHECKING: from models import Account @@ -174,6 +174,18 @@ def normalize_uuid(value: str | UUID) -> str: raise ValueError("must be a valid UUID") from exc +def parse_uuid_str_or_none(value: str | None) -> str | None: + """ + Return None for missing/empty UUID-like values. + + Keep non-empty values unchanged to avoid changing behavior in paths that + currently pass placeholder IDs in tests/mocks. + """ + if value is None or not str(value).strip(): + return None + return str(value) + + UUIDStrOrEmpty = Annotated[str, AfterValidator(normalize_uuid)] diff --git a/api/libs/login.py b/api/libs/login.py index bd5cb5f30d..dce332b01d 100644 --- a/api/libs/login.py +++ b/api/libs/login.py @@ -18,15 +18,23 @@ if TYPE_CHECKING: from models.model import EndUser +def _resolve_current_user() -> EndUser | Account | None: + """ + Resolve the current user proxy to its underlying user object. + This keeps unit tests working when they patch `current_user` directly + instead of bootstrapping a full Flask-Login manager. + """ + user_proxy = current_user + get_current_object = getattr(user_proxy, "_get_current_object", None) + return get_current_object() if callable(get_current_object) else user_proxy # type: ignore + + def current_account_with_tenant(): """ Resolve the underlying account for the current user proxy and ensure tenant context exists. Allows tests to supply plain Account mocks without the LocalProxy helper. """ - user_proxy = current_user - - get_current_object = getattr(user_proxy, "_get_current_object", None) - user = get_current_object() if callable(get_current_object) else user_proxy # type: ignore + user = _resolve_current_user() if not isinstance(user, Account): raise ValueError("current_user must be an Account instance") @@ -79,9 +87,10 @@ def login_required(func: Callable[P, R]) -> Callable[P, R | ResponseReturnValue] if request.method in EXEMPT_METHODS or dify_config.LOGIN_DISABLED: return current_app.ensure_sync(func)(*args, **kwargs) - user = _get_user() + user = _resolve_current_user() if user is None or not user.is_authenticated: return current_app.login_manager.unauthorized() # type: ignore + g._login_user = user # we put csrf validation here for less conflicts # TODO: maybe find a better place for it. check_csrf_token(request, user.id) diff --git a/api/libs/oauth.py b/api/libs/oauth.py index 1afb42304d..76e741301c 100644 --- a/api/libs/oauth.py +++ b/api/libs/oauth.py @@ -28,6 +28,7 @@ class AccessTokenResponse(TypedDict, total=False): class GitHubEmailRecord(TypedDict, total=False): email: str primary: bool + verified: bool class GitHubRawUserInfo(TypedDict): @@ -130,25 +131,51 @@ class GitHubOAuth(OAuth): response.raise_for_status() user_info = GITHUB_RAW_USER_INFO_ADAPTER.validate_python(_json_object(response)) + # Only call the /user/emails endpoint when the profile email is absent, + # i.e. the user has "Keep my email addresses private" enabled. + resolved_email = user_info.get("email") or "" + if not resolved_email: + resolved_email = self._get_email_from_emails_endpoint(headers) + + return {**user_info, "email": resolved_email} + + @staticmethod + def _get_email_from_emails_endpoint(headers: dict[str, str]) -> str: + """Fetch the best available email from GitHub's /user/emails endpoint. + + Prefers the primary email, then falls back to any verified email. + Returns an empty string when no usable email is found. + """ try: - email_response = httpx.get(self._EMAIL_INFO_URL, headers=headers) + email_response = httpx.get(GitHubOAuth._EMAIL_INFO_URL, headers=headers) email_response.raise_for_status() - email_info = GITHUB_EMAIL_RECORDS_ADAPTER.validate_python(_json_list(email_response)) - primary_email = next((email for email in email_info if email.get("primary") is True), None) + email_records = GITHUB_EMAIL_RECORDS_ADAPTER.validate_python(_json_list(email_response)) except (httpx.HTTPStatusError, ValidationError): logger.warning("Failed to retrieve email from GitHub /user/emails endpoint", exc_info=True) - primary_email = None + return "" - return {**user_info, "email": primary_email.get("email", "") if primary_email else ""} + primary = next((r for r in email_records if r.get("primary") is True), None) + if primary: + return primary.get("email", "") + + # No primary email; try any verified email as a fallback. + verified = next((r for r in email_records if r.get("verified") is True), None) + if verified: + return verified.get("email", "") + + return "" def _transform_user_info(self, raw_info: JsonObject) -> OAuthUserInfo: payload = GITHUB_RAW_USER_INFO_ADAPTER.validate_python(raw_info) - email = payload.get("email") + email = payload.get("email") or "" if not email: - raise ValueError( - 'Dify currently not supports the "Keep my email addresses private" feature,' - " please disable it and login again" - ) + # When no email is available from the profile or /user/emails endpoint, + # fall back to GitHub's noreply address so sign-in can still proceed. + # Use only the numeric ID (not the login) so the address stays stable + # even if the user renames their GitHub account. + github_id = payload["id"] + email = f"{github_id}@users.noreply.github.com" + logger.info("GitHub user %s has no public email; using noreply address", payload["login"]) return OAuthUserInfo(id=str(payload["id"]), name=str(payload.get("name") or ""), email=email) diff --git a/api/libs/schedule_utils.py b/api/libs/schedule_utils.py index 1ab5f499e9..b80a5ea722 100644 --- a/api/libs/schedule_utils.py +++ b/api/libs/schedule_utils.py @@ -1,6 +1,6 @@ from datetime import UTC, datetime -import pytz +import pytz # type: ignore[import-untyped] from croniter import croniter diff --git a/api/models/dataset.py b/api/models/dataset.py index d0163e6984..e323ccfd7f 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -20,7 +20,7 @@ from sqlalchemy.orm import Mapped, Session, mapped_column from configs import dify_config from core.rag.index_processor.constant.built_in_field import BuiltInField, MetadataDataSource -from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from core.rag.index_processor.constant.query_type import QueryType from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.tools.signature import sign_upload_file @@ -43,7 +43,9 @@ from .enums import ( IndexingStatus, ProcessRuleMode, SegmentStatus, + SegmentType, SummaryStatus, + TidbAuthBindingStatus, ) from .model import App, Tag, TagBinding, UploadFile from .types import AdjustedJSON, BinaryData, EnumText, LongText, StringUUID, adjusted_json_index @@ -135,7 +137,7 @@ class Dataset(Base): default=DatasetPermissionEnum.ONLY_ME, ) data_source_type = mapped_column(EnumText(DataSourceType, length=255)) - indexing_technique: Mapped[str | None] = mapped_column(String(255)) + indexing_technique: Mapped[IndexTechniqueType | None] = mapped_column(EnumText(IndexTechniqueType, length=255)) index_struct = mapped_column(LongText, nullable=True) created_by = mapped_column(StringUUID, nullable=False) created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) @@ -494,7 +496,9 @@ class Document(Base): ) doc_type = mapped_column(EnumText(DocumentDocType, length=40), nullable=True) doc_metadata = mapped_column(AdjustedJSON, nullable=True) - doc_form = mapped_column(String(255), nullable=False, server_default=sa.text("'text_model'")) + doc_form: Mapped[IndexStructureType] = mapped_column( + EnumText(IndexStructureType, length=255), nullable=False, server_default=sa.text("'text_model'") + ) doc_language = mapped_column(String(255), nullable=True) need_summary: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false")) @@ -998,7 +1002,9 @@ class ChildChunk(Base): # indexing fields index_node_id = mapped_column(String(255), nullable=True) index_node_hash = mapped_column(String(255), nullable=True) - type = mapped_column(String(255), nullable=False, server_default=sa.text("'automatic'")) + type: Mapped[SegmentType] = mapped_column( + EnumText(SegmentType, length=255), nullable=False, server_default=sa.text("'automatic'") + ) created_by = mapped_column(StringUUID, nullable=False) created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=sa.func.current_timestamp()) updated_by = mapped_column(StringUUID, nullable=True) @@ -1239,7 +1245,9 @@ class TidbAuthBinding(TypeBase): cluster_id: Mapped[str] = mapped_column(String(255), nullable=False) cluster_name: Mapped[str] = mapped_column(String(255), nullable=False) active: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false")) - status: Mapped[str] = mapped_column(sa.String(255), nullable=False, server_default=sa.text("'CREATING'")) + status: Mapped[TidbAuthBindingStatus] = mapped_column( + EnumText(TidbAuthBindingStatus, length=255), nullable=False, server_default=sa.text("'CREATING'") + ) account: Mapped[str] = mapped_column(String(255), nullable=False) password: Mapped[str] = mapped_column(String(255), nullable=False) created_at: Mapped[datetime] = mapped_column( diff --git a/api/models/enums.py b/api/models/enums.py index 4849099d30..cdec7b2f12 100644 --- a/api/models/enums.py +++ b/api/models/enums.py @@ -222,6 +222,13 @@ class DatasetMetadataType(StrEnum): TIME = "time" +class SegmentType(StrEnum): + """Document segment type""" + + AUTOMATIC = "automatic" + CUSTOMIZED = "customized" + + class SegmentStatus(StrEnum): """Document segment status""" @@ -323,3 +330,10 @@ class ProviderQuotaType(StrEnum): if member.value == value: return member raise ValueError(f"No matching enum found for value '{value}'") + + +class ApiTokenType(StrEnum): + """API Token type""" + + APP = "app" + DATASET = "dataset" diff --git a/api/models/execution_extra_content.py b/api/models/execution_extra_content.py index d0bd34efec..b2d09a7732 100644 --- a/api/models/execution_extra_content.py +++ b/api/models/execution_extra_content.py @@ -66,8 +66,8 @@ class HumanInputContent(ExecutionExtraContent): form_id: Mapped[str] = mapped_column(StringUUID, nullable=True) @classmethod - def new(cls, form_id: str, message_id: str | None) -> "HumanInputContent": - return cls(form_id=form_id, message_id=message_id) + def new(cls, *, workflow_run_id: str, form_id: str, message_id: str | None) -> "HumanInputContent": + return cls(workflow_run_id=workflow_run_id, form_id=form_id, message_id=message_id) form: Mapped["HumanInputForm"] = relationship( "HumanInputForm", diff --git a/api/models/human_input.py b/api/models/human_input.py index 48e7fbb9ea..b4c7a634b6 100644 --- a/api/models/human_input.py +++ b/api/models/human_input.py @@ -6,11 +6,8 @@ import sqlalchemy as sa from pydantic import BaseModel, Field from sqlalchemy.orm import Mapped, mapped_column, relationship -from dify_graph.nodes.human_input.enums import ( - DeliveryMethodType, - HumanInputFormKind, - HumanInputFormStatus, -) +from core.workflow.human_input_compat import DeliveryMethodType +from graphon.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus from libs.helper import generate_string from .base import Base, DefaultFieldsMixin diff --git a/api/models/model.py b/api/models/model.py index b098966052..bcb142db56 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -3,10 +3,11 @@ from __future__ import annotations import json import re import uuid -from collections.abc import Mapping, Sequence +from collections.abc import Callable, Mapping, Sequence from datetime import datetime from decimal import Decimal from enum import StrEnum, auto +from functools import lru_cache from typing import TYPE_CHECKING, Any, Literal, NotRequired, cast from uuid import uuid4 @@ -20,17 +21,19 @@ from typing_extensions import TypedDict from configs import dify_config from constants import DEFAULT_FILE_NUMBER_LIMITS from core.tools.signature import sign_tool_file -from dify_graph.enums import WorkflowExecutionStatus -from dify_graph.file import FILE_MODEL_IDENTITY, File, FileTransferMethod -from dify_graph.file import helpers as file_helpers from extensions.storage.storage_type import StorageType +from graphon.enums import WorkflowExecutionStatus +from graphon.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType +from graphon.file import helpers as file_helpers from libs.helper import generate_string # type: ignore[import-not-found] from libs.uuid_utils import uuidv7 +from models.utils.file_input_compat import build_file_from_input_mapping from .account import Account, Tenant from .base import Base, TypeBase, gen_uuidv4_string from .engine import db from .enums import ( + ApiTokenType, AppMCPServerStatus, AppStatus, BannerStatus, @@ -43,6 +46,7 @@ from .enums import ( MessageChainType, MessageFileBelongsTo, MessageStatus, + ProviderQuotaType, TagType, ) from .provider_ids import GenericProviderID @@ -55,6 +59,32 @@ if TYPE_CHECKING: # --- TypedDict definitions for structured dict return types --- +@lru_cache(maxsize=1) +def _get_file_access_controller(): + from core.app.file_access import DatabaseFileAccessController + + return DatabaseFileAccessController() + + +def _resolve_app_tenant_id(app_id: str) -> str: + resolved_tenant_id = db.session.scalar(select(App.tenant_id).where(App.id == app_id)) + if not resolved_tenant_id: + raise ValueError(f"Unable to resolve tenant_id for app {app_id}") + return resolved_tenant_id + + +def _build_app_tenant_resolver(app_id: str, owner_tenant_id: str | None = None) -> Callable[[], str]: + resolved_tenant_id = owner_tenant_id + + def resolve_owner_tenant_id() -> str: + nonlocal resolved_tenant_id + if resolved_tenant_id is None: + resolved_tenant_id = _resolve_app_tenant_id(app_id) + return resolved_tenant_id + + return resolve_owner_tenant_id + + class EnabledConfig(TypedDict): enabled: bool @@ -587,7 +617,9 @@ class AppModelConfig(TypeBase): __tablename__ = "app_model_configs" __table_args__ = (sa.PrimaryKeyConstraint("id", name="app_model_config_pkey"), sa.Index("app_app_id_idx", "app_id")) - id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False) + id: Mapped[str] = mapped_column( + StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False + ) app_id: Mapped[str] = mapped_column(StringUUID, nullable=False) provider: Mapped[str | None] = mapped_column(String(255), nullable=True, default=None) model_id: Mapped[str | None] = mapped_column(String(255), nullable=True, default=None) @@ -936,7 +968,9 @@ class AccountTrialAppRecord(Base): class ExporleBanner(TypeBase): __tablename__ = "exporle_banners" __table_args__ = (sa.PrimaryKeyConstraint("id", name="exporler_banner_pkey"),) - id: Mapped[str] = mapped_column(StringUUID, default=gen_uuidv4_string, init=False) + id: Mapped[str] = mapped_column( + StringUUID, insert_default=gen_uuidv4_string, default_factory=gen_uuidv4_string, init=False + ) content: Mapped[dict[str, Any]] = mapped_column(sa.JSON, nullable=False) link: Mapped[str] = mapped_column(String(255), nullable=False) sort: Mapped[int] = mapped_column(sa.Integer, nullable=False) @@ -1051,23 +1085,26 @@ class Conversation(Base): @property def inputs(self) -> dict[str, Any]: inputs = self._inputs.copy() + # Compatibility bridge: stored input payloads may come from before or after the + # graph-layer file refactor. Newer rows may omit `tenant_id`, so keep tenant + # resolution at the SQLAlchemy model boundary instead of pushing ownership back + # into `graphon.file.File`. + tenant_resolver = _build_app_tenant_resolver( + app_id=self.app_id, + owner_tenant_id=cast(str | None, getattr(self, "_owner_tenant_id", None)), + ) # Convert file mapping to File object for key, value in inputs.items(): - # NOTE: It's not the best way to implement this, but it's the only way to avoid circular import for now. - from factories import file_factory - if ( isinstance(value, dict) and cast(dict[str, Any], value).get("dify_model_identity") == FILE_MODEL_IDENTITY ): value_dict = cast(dict[str, Any], value) - if value_dict["transfer_method"] == FileTransferMethod.TOOL_FILE: - value_dict["tool_file_id"] = value_dict["related_id"] - elif value_dict["transfer_method"] in [FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL]: - value_dict["upload_file_id"] = value_dict["related_id"] - tenant_id = cast(str, value_dict.get("tenant_id", "")) - inputs[key] = file_factory.build_from_mapping(mapping=value_dict, tenant_id=tenant_id) + inputs[key] = build_file_from_input_mapping( + file_mapping=value_dict, + tenant_resolver=tenant_resolver, + ) elif isinstance(value, list): value_list = cast(list[Any], value) if all( @@ -1080,15 +1117,12 @@ class Conversation(Base): if not isinstance(item, dict): continue item_dict = cast(dict[str, Any], item) - if item_dict["transfer_method"] == FileTransferMethod.TOOL_FILE: - item_dict["tool_file_id"] = item_dict["related_id"] - elif item_dict["transfer_method"] in [ - FileTransferMethod.LOCAL_FILE, - FileTransferMethod.REMOTE_URL, - ]: - item_dict["upload_file_id"] = item_dict["related_id"] - tenant_id = cast(str, item_dict.get("tenant_id", "")) - file_list.append(file_factory.build_from_mapping(mapping=item_dict, tenant_id=tenant_id)) + file_list.append( + build_file_from_input_mapping( + file_mapping=item_dict, + tenant_resolver=tenant_resolver, + ) + ) inputs[key] = file_list return inputs @@ -1396,21 +1430,23 @@ class Message(Base): @property def inputs(self) -> dict[str, Any]: inputs = self._inputs.copy() + # Compatibility bridge: message inputs are persisted as JSON and must remain + # readable across file payload shape changes. Do not assume `tenant_id` + # is serialized into each file mapping going forward. + tenant_resolver = _build_app_tenant_resolver( + app_id=self.app_id, + owner_tenant_id=cast(str | None, getattr(self, "_owner_tenant_id", None)), + ) for key, value in inputs.items(): - # NOTE: It's not the best way to implement this, but it's the only way to avoid circular import for now. - from factories import file_factory - if ( isinstance(value, dict) and cast(dict[str, Any], value).get("dify_model_identity") == FILE_MODEL_IDENTITY ): value_dict = cast(dict[str, Any], value) - if value_dict["transfer_method"] == FileTransferMethod.TOOL_FILE: - value_dict["tool_file_id"] = value_dict["related_id"] - elif value_dict["transfer_method"] in [FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL]: - value_dict["upload_file_id"] = value_dict["related_id"] - tenant_id = cast(str, value_dict.get("tenant_id", "")) - inputs[key] = file_factory.build_from_mapping(mapping=value_dict, tenant_id=tenant_id) + inputs[key] = build_file_from_input_mapping( + file_mapping=value_dict, + tenant_resolver=tenant_resolver, + ) elif isinstance(value, list): value_list = cast(list[Any], value) if all( @@ -1423,15 +1459,12 @@ class Message(Base): if not isinstance(item, dict): continue item_dict = cast(dict[str, Any], item) - if item_dict["transfer_method"] == FileTransferMethod.TOOL_FILE: - item_dict["tool_file_id"] = item_dict["related_id"] - elif item_dict["transfer_method"] in [ - FileTransferMethod.LOCAL_FILE, - FileTransferMethod.REMOTE_URL, - ]: - item_dict["upload_file_id"] = item_dict["related_id"] - tenant_id = cast(str, item_dict.get("tenant_id", "")) - file_list.append(file_factory.build_from_mapping(mapping=item_dict, tenant_id=tenant_id)) + file_list.append( + build_file_from_input_mapping( + file_mapping=item_dict, + tenant_resolver=tenant_resolver, + ) + ) inputs[key] = file_list return inputs @@ -1606,6 +1639,7 @@ class Message(Base): "upload_file_id": message_file.upload_file_id, }, tenant_id=current_app.tenant_id, + access_controller=_get_file_access_controller(), ) elif message_file.transfer_method == FileTransferMethod.REMOTE_URL: if message_file.url is None: @@ -1619,6 +1653,7 @@ class Message(Base): "url": message_file.url, }, tenant_id=current_app.tenant_id, + access_controller=_get_file_access_controller(), ) elif message_file.transfer_method == FileTransferMethod.TOOL_FILE: if message_file.upload_file_id is None: @@ -1633,6 +1668,7 @@ class Message(Base): file = file_factory.build_from_mapping( mapping=mapping, tenant_id=current_app.tenant_id, + access_controller=_get_file_access_controller(), ) else: raise ValueError( @@ -1783,7 +1819,7 @@ class MessageFile(TypeBase): StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False ) message_id: Mapped[str] = mapped_column(StringUUID, nullable=False) - type: Mapped[str] = mapped_column(String(255), nullable=False) + type: Mapped[FileType] = mapped_column(EnumText(FileType, length=255), nullable=False) transfer_method: Mapped[FileTransferMethod] = mapped_column( EnumText(FileTransferMethod, length=255), nullable=False ) @@ -1845,7 +1881,9 @@ class AppAnnotationHitHistory(TypeBase): sa.Index("app_annotation_hit_histories_message_idx", "message_id"), ) - id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False) + id: Mapped[str] = mapped_column( + StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False + ) app_id: Mapped[str] = mapped_column(StringUUID, nullable=False) annotation_id: Mapped[str] = mapped_column(StringUUID, nullable=False) source: Mapped[str] = mapped_column(LongText, nullable=False) @@ -2095,7 +2133,7 @@ class ApiToken(Base): # bug: this uses setattr so idk the field. id = mapped_column(StringUUID, default=lambda: str(uuid4())) app_id = mapped_column(StringUUID, nullable=True) tenant_id = mapped_column(StringUUID, nullable=True) - type = mapped_column(String(16), nullable=False) + type: Mapped[ApiTokenType] = mapped_column(EnumText(ApiTokenType, length=16), nullable=False) token: Mapped[str] = mapped_column(String(255), nullable=False) last_used_at = mapped_column(sa.DateTime, nullable=True) created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) @@ -2490,7 +2528,9 @@ class TenantCreditPool(TypeBase): StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False ) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) - pool_type: Mapped[str] = mapped_column(String(40), nullable=False, default="trial", server_default="trial") + pool_type: Mapped[ProviderQuotaType] = mapped_column( + EnumText(ProviderQuotaType, length=40), nullable=False, default=ProviderQuotaType.TRIAL, server_default="trial" + ) quota_limit: Mapped[int] = mapped_column(BigInteger, nullable=False, default=0) quota_used: Mapped[int] = mapped_column(BigInteger, nullable=False, default=0) created_at: Mapped[datetime] = mapped_column( diff --git a/api/models/tools.py b/api/models/tools.py index 01182af867..63b27b9413 100644 --- a/api/models/tools.py +++ b/api/models/tools.py @@ -145,7 +145,9 @@ class ApiToolProvider(TypeBase): icon: Mapped[str] = mapped_column(String(255), nullable=False) # original schema schema: Mapped[str] = mapped_column(LongText, nullable=False) - schema_type_str: Mapped[str] = mapped_column(String(40), nullable=False) + schema_type_str: Mapped[ApiProviderSchemaType] = mapped_column( + EnumText(ApiProviderSchemaType, length=40), nullable=False + ) # who created this tool user_id: Mapped[str] = mapped_column(StringUUID, nullable=False) # tenant id diff --git a/api/models/utils/__init__.py b/api/models/utils/__init__.py new file mode 100644 index 0000000000..b390b8106b --- /dev/null +++ b/api/models/utils/__init__.py @@ -0,0 +1,3 @@ +from .file_input_compat import build_file_from_input_mapping + +__all__ = ["build_file_from_input_mapping"] diff --git a/api/models/utils/file_input_compat.py b/api/models/utils/file_input_compat.py new file mode 100644 index 0000000000..dee1cc507a --- /dev/null +++ b/api/models/utils/file_input_compat.py @@ -0,0 +1,116 @@ +from __future__ import annotations + +from collections.abc import Callable, Mapping +from functools import lru_cache +from typing import Any + +from core.workflow.file_reference import parse_file_reference +from graphon.file import File, FileTransferMethod + + +@lru_cache(maxsize=1) +def _get_file_access_controller(): + from core.app.file_access import DatabaseFileAccessController + + return DatabaseFileAccessController() + + +def resolve_file_record_id(file_mapping: Mapping[str, Any]) -> str | None: + reference = file_mapping.get("reference") + if isinstance(reference, str) and reference: + parsed_reference = parse_file_reference(reference) + if parsed_reference is not None: + return parsed_reference.record_id + + related_id = file_mapping.get("related_id") + if isinstance(related_id, str) and related_id: + parsed_reference = parse_file_reference(related_id) + if parsed_reference is not None: + return parsed_reference.record_id + + return None + + +def resolve_file_mapping_tenant_id( + *, + file_mapping: Mapping[str, Any], + tenant_resolver: Callable[[], str], +) -> str: + tenant_id = file_mapping.get("tenant_id") + if isinstance(tenant_id, str) and tenant_id: + return tenant_id + + return tenant_resolver() + + +def build_file_from_stored_mapping( + *, + file_mapping: Mapping[str, Any], + tenant_id: str, +) -> File: + """ + Canonicalize a persisted file payload against the current tenant context. + + Stored JSON rows can outlive file schema changes, so rebuild storage-backed + files through the workflow factory instead of trusting serialized metadata. + Pure external ``REMOTE_URL`` payloads without a backing upload row are + passed through because there is no server-owned record to rebind. + """ + + # NOTE: It's not the best way to implement this, but it's the only way to avoid circular import for now. + from factories import file_factory + + mapping = dict(file_mapping) + mapping.pop("tenant_id", None) + record_id = resolve_file_record_id(mapping) + transfer_method = FileTransferMethod.value_of(mapping["transfer_method"]) + + if transfer_method == FileTransferMethod.TOOL_FILE and record_id: + mapping["tool_file_id"] = record_id + elif transfer_method in [FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL] and record_id: + mapping["upload_file_id"] = record_id + elif transfer_method == FileTransferMethod.DATASOURCE_FILE and record_id: + mapping["datasource_file_id"] = record_id + + if transfer_method == FileTransferMethod.REMOTE_URL and record_id is None: + remote_url = mapping.get("remote_url") + if not isinstance(remote_url, str) or not remote_url: + url = mapping.get("url") + if isinstance(url, str) and url: + mapping["remote_url"] = url + return File.model_validate(mapping) + + return file_factory.build_from_mapping( + mapping=mapping, + tenant_id=tenant_id, + access_controller=_get_file_access_controller(), + ) + + +def build_file_from_input_mapping( + *, + file_mapping: Mapping[str, Any], + tenant_resolver: Callable[[], str], +) -> File: + """ + Rehydrate persisted model input payloads into graph `File` objects. + + This compatibility layer exists because model JSON rows can outlive file payload + schema changes. Legacy rows may carry `related_id` and `tenant_id`, while newer + rows may only carry `reference`. Keep ownership resolution here, at the model + boundary, instead of pushing tenant data back into `graphon.file.File`. + """ + + transfer_method = FileTransferMethod.value_of(file_mapping["transfer_method"]) + record_id = resolve_file_record_id(file_mapping) + if transfer_method == FileTransferMethod.REMOTE_URL and record_id is None: + return build_file_from_stored_mapping( + file_mapping=file_mapping, + tenant_id="", + ) + + tenant_id = resolve_file_mapping_tenant_id(file_mapping=file_mapping, tenant_resolver=tenant_resolver) + return build_file_from_stored_mapping( + file_mapping=file_mapping, + tenant_id=tenant_id, + ) diff --git a/api/models/workflow.py b/api/models/workflow.py index 0b13c0a074..fa9ebc76a5 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -24,19 +24,20 @@ from sqlalchemy.orm import Mapped, mapped_column from typing_extensions import deprecated from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE -from dify_graph.constants import ( +from core.workflow.human_input_compat import normalize_node_config_for_graph +from core.workflow.variable_prefixes import ( CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID, ) -from dify_graph.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter -from dify_graph.entities.pause_reason import HumanInputRequired, PauseReason, PauseReasonType, SchedulingPause -from dify_graph.enums import BuiltinNodeTypes, NodeType, WorkflowExecutionStatus, WorkflowNodeExecutionMetadataKey -from dify_graph.file.constants import maybe_file_object -from dify_graph.file.models import File -from dify_graph.variables import utils as variable_utils -from dify_graph.variables.variables import FloatVariable, IntegerVariable, RAGPipelineVariable, StringVariable from extensions.ext_storage import Storage from factories.variable_factory import TypeMismatchError, build_segment_with_type +from graphon.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter +from graphon.entities.pause_reason import HumanInputRequired, PauseReason, PauseReasonType, SchedulingPause +from graphon.enums import BuiltinNodeTypes, NodeType, WorkflowExecutionStatus, WorkflowNodeExecutionMetadataKey +from graphon.file.constants import maybe_file_object +from graphon.file.models import File +from graphon.variables import utils as variable_utils +from graphon.variables.variables import FloatVariable, IntegerVariable, RAGPipelineVariable, StringVariable from libs.datetime_utils import naive_utc_now from libs.uuid_utils import uuidv7 @@ -48,8 +49,8 @@ if TYPE_CHECKING: from constants import DEFAULT_FILE_NUMBER_LIMITS, HIDDEN_VALUE from core.helper import encrypter -from dify_graph.variables import SecretVariable, Segment, SegmentType, VariableBase from factories import variable_factory +from graphon.variables import SecretVariable, Segment, SegmentType, VariableBase from libs import helper from .account import Account @@ -57,6 +58,7 @@ from .base import Base, DefaultFieldsMixin, TypeBase from .engine import db from .enums import CreatorUserRole, DraftVariableType, ExecutionOffLoadType, WorkflowRunTriggeredFrom from .types import EnumText, LongText, StringUUID +from .utils.file_input_compat import build_file_from_stored_mapping logger = logging.getLogger(__name__) @@ -64,6 +66,15 @@ SerializedWorkflowValue = dict[str, Any] SerializedWorkflowVariables = dict[str, SerializedWorkflowValue] +def _resolve_workflow_app_tenant_id(app_id: str) -> str: + from .model import App + + tenant_id = db.session.scalar(select(App.tenant_id).where(App.id == app_id)) + if not tenant_id: + raise ValueError(f"Unable to resolve tenant_id for app {app_id}") + return tenant_id + + class WorkflowContentDict(TypedDict): graph: Mapping[str, Any] features: dict[str, Any] @@ -274,7 +285,7 @@ class Workflow(Base): # bug node_config: dict[str, Any] = next(filter(lambda node: node["id"] == node_id, nodes)) except StopIteration: raise NodeNotFoundError(node_id) - return NodeConfigDictAdapter.validate_python(node_config) + return NodeConfigDictAdapter.validate_python(normalize_node_config_for_graph(node_config)) @staticmethod def get_node_type_from_node_config(node_config: NodeConfigDict) -> NodeType: @@ -419,7 +430,7 @@ class Workflow(Base): # bug "selected": false, } - For specific node type, refer to `dify_graph.nodes` + For specific node type, refer to `graphon.nodes` """ graph_dict = self.graph_dict if "nodes" not in graph_dict: @@ -1222,7 +1233,9 @@ class WorkflowAppLog(TypeBase): app_id: Mapped[str] = mapped_column(StringUUID) workflow_id: Mapped[str] = mapped_column(StringUUID, nullable=False) workflow_run_id: Mapped[str] = mapped_column(StringUUID) - created_from: Mapped[str] = mapped_column(String(255), nullable=False) + created_from: Mapped[WorkflowAppLogCreatedFrom] = mapped_column( + EnumText(WorkflowAppLogCreatedFrom, length=255), nullable=False + ) created_by_role: Mapped[CreatorUserRole] = mapped_column(EnumText(CreatorUserRole, length=255), nullable=False) created_by: Mapped[str] = mapped_column(StringUUID, nullable=False) created_at: Mapped[datetime] = mapped_column( @@ -1302,10 +1315,14 @@ class WorkflowArchiveLog(TypeBase): log_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True) log_created_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True) - log_created_from: Mapped[str | None] = mapped_column(String(255), nullable=True) + log_created_from: Mapped[WorkflowAppLogCreatedFrom | None] = mapped_column( + EnumText(WorkflowAppLogCreatedFrom, length=255), nullable=True + ) run_version: Mapped[str] = mapped_column(String(255), nullable=False) - run_status: Mapped[str] = mapped_column(String(255), nullable=False) + run_status: Mapped[WorkflowExecutionStatus] = mapped_column( + EnumText(WorkflowExecutionStatus, length=255), nullable=False + ) run_triggered_from: Mapped[WorkflowRunTriggeredFrom] = mapped_column( EnumText(WorkflowRunTriggeredFrom, length=255), nullable=False ) @@ -1445,7 +1462,7 @@ class WorkflowDraftVariable(Base): # From `VARIABLE_PATTERN`, we may conclude that the length of a top level variable is less than # 80 chars. # - # ref: api/dify_graph/entities/variable_pool.py:18 + # ref: api/graphon/entities/variable_pool.py:18 name: Mapped[str] = mapped_column(sa.String(255), nullable=False) description: Mapped[str] = mapped_column( sa.String(255), @@ -1560,10 +1577,9 @@ class WorkflowDraftVariable(Base): def _loads_value(self) -> Segment: value = json.loads(self.value) - return self.build_segment_with_type(self.value_type, value) + return self.build_segment_from_serialized_value(self.value_type, value) - @staticmethod - def rebuild_file_types(value: Any): + def _rebuild_file_types(self, value: Any): # NOTE(QuantumGhost): Temporary workaround for structured data handling. # By this point, `output` has been converted to dict by # `WorkflowEntry.handle_special_values`, so we need to @@ -1577,13 +1593,72 @@ class WorkflowDraftVariable(Base): if isinstance(value, dict): if not maybe_file_object(value): return cast(Any, value) - return File.model_validate(value) + tenant_id = _resolve_workflow_app_tenant_id(self.app_id) + return build_file_from_stored_mapping( + file_mapping=cast(dict[str, Any], value), + tenant_id=tenant_id, + ) elif isinstance(value, list) and value: value_list = cast(list[Any], value) first: Any = value_list[0] if not maybe_file_object(first): return cast(Any, value) - file_list: list[File] = [File.model_validate(cast(dict[str, Any], i)) for i in value_list] + tenant_id = _resolve_workflow_app_tenant_id(self.app_id) + file_list: list[File] = [] + for item in value_list: + file_list.append( + build_file_from_stored_mapping( + file_mapping=cast(dict[str, Any], item), + tenant_id=tenant_id, + ) + ) + return cast(Any, file_list) + else: + return cast(Any, value) + + def build_segment_from_serialized_value(self, segment_type: SegmentType, value: Any) -> Segment: + # Persisted draft variable rows may contain historical file payloads. + # Rebuild them through the file factory so tenant ownership, signed URLs, + # and storage-backed metadata come from canonical records instead of the + # serialized JSON blob. + if segment_type == SegmentType.FILE: + if isinstance(value, File): + return build_segment_with_type(segment_type, value) + elif isinstance(value, dict): + file = self._rebuild_file_types(value) + return build_segment_with_type(segment_type, file) + else: + raise TypeMismatchError(f"expected dict or File for FileSegment, got {type(value)}") + if segment_type == SegmentType.ARRAY_FILE: + if not isinstance(value, list): + raise TypeMismatchError(f"expected list for ArrayFileSegment, got {type(value)}") + file_list = self._rebuild_file_types(value) + return build_segment_with_type(segment_type=segment_type, value=file_list) + + return build_segment_with_type(segment_type=segment_type, value=value) + + @staticmethod + def rebuild_file_types(value: Any): + # Keep the class-level fallback for callers that only need lightweight + # structural reconstruction. Persisted draft-variable payloads should go + # through `build_segment_from_serialized_value()` so file metadata is + # rebuilt from canonical storage records. + if isinstance(value, dict): + if not maybe_file_object(value): + return cast(Any, value) + normalized_file = dict(value) + normalized_file.pop("tenant_id", None) + return File.model_validate(normalized_file) + elif isinstance(value, list) and value: + value_list = cast(list[Any], value) + first: Any = value_list[0] + if not maybe_file_object(first): + return cast(Any, value) + file_list: list[File] = [] + for item in value_list: + normalized_file = dict(cast(dict[str, Any], item)) + normalized_file.pop("tenant_id", None) + file_list.append(File.model_validate(normalized_file)) return cast(Any, file_list) else: return cast(Any, value) diff --git a/api/pyproject.toml b/api/pyproject.toml index 7e70bc76d3..cd1a730e53 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "dify-api" -version = "1.13.2" +version = "1.13.3" requires-python = ">=3.11,<3.13" dependencies = [ @@ -174,7 +174,7 @@ dev = [ "sseclient-py>=1.8.0", "pytest-timeout>=2.4.0", "pytest-xdist>=3.8.0", - "pyrefly>=0.55.0", + "pyrefly>=0.57.1", ] ############################################################ @@ -209,7 +209,7 @@ evaluation = ["ragas>=0.2.0", "deepeval>=2.0.0"] # Required by vector store clients ############################################################ vdb = [ - "alibabacloud_gpdb20160503~=3.8.0", + "alibabacloud_gpdb20160503~=5.1.0", "alibabacloud_tea_openapi~=0.4.3", "chromadb==0.5.20", "clickhouse-connect~=0.14.1", @@ -247,7 +247,7 @@ module = [ "configs.middleware.cache.redis_pubsub_config", "extensions.ext_redis", "tasks.workflow_execution_tasks", - "dify_graph.nodes.base.node", + "graphon.nodes.base.node", "services.human_input_delivery_test_service", "core.app.apps.advanced_chat.app_generator", "controllers.console.human_input_form", diff --git a/api/pyrefly-local-excludes.txt b/api/pyrefly-local-excludes.txt index ad3c1e8389..dc0adbf50d 100644 --- a/api/pyrefly-local-excludes.txt +++ b/api/pyrefly-local-excludes.txt @@ -109,34 +109,34 @@ core/trigger/debug/event_selectors.py core/trigger/entities/entities.py core/trigger/provider.py core/workflow/workflow_entry.py -dify_graph/entities/workflow_execution.py -dify_graph/file/file_manager.py -dify_graph/graph_engine/error_handler.py -dify_graph/graph_engine/layers/execution_limits.py -dify_graph/nodes/agent/agent_node.py -dify_graph/nodes/base/node.py -dify_graph/nodes/code/code_node.py -dify_graph/nodes/datasource/datasource_node.py -dify_graph/nodes/document_extractor/node.py -dify_graph/nodes/human_input/human_input_node.py -dify_graph/nodes/if_else/if_else_node.py -dify_graph/nodes/iteration/iteration_node.py -dify_graph/nodes/knowledge_index/knowledge_index_node.py +graphon/entities/workflow_execution.py +graphon/file/file_manager.py +graphon/graph_engine/error_handler.py +graphon/graph_engine/layers/execution_limits.py +graphon/nodes/agent/agent_node.py +graphon/nodes/base/node.py +graphon/nodes/code/code_node.py +graphon/nodes/datasource/datasource_node.py +graphon/nodes/document_extractor/node.py +graphon/nodes/human_input/human_input_node.py +graphon/nodes/if_else/if_else_node.py +graphon/nodes/iteration/iteration_node.py +graphon/nodes/knowledge_index/knowledge_index_node.py core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py -dify_graph/nodes/list_operator/node.py -dify_graph/nodes/llm/node.py -dify_graph/nodes/loop/loop_node.py -dify_graph/nodes/parameter_extractor/parameter_extractor_node.py -dify_graph/nodes/question_classifier/question_classifier_node.py -dify_graph/nodes/start/start_node.py -dify_graph/nodes/template_transform/template_transform_node.py -dify_graph/nodes/tool/tool_node.py -dify_graph/nodes/trigger_plugin/trigger_event_node.py -dify_graph/nodes/trigger_schedule/trigger_schedule_node.py -dify_graph/nodes/trigger_webhook/node.py -dify_graph/nodes/variable_aggregator/variable_aggregator_node.py -dify_graph/nodes/variable_assigner/v1/node.py -dify_graph/nodes/variable_assigner/v2/node.py +graphon/nodes/list_operator/node.py +graphon/nodes/llm/node.py +graphon/nodes/loop/loop_node.py +graphon/nodes/parameter_extractor/parameter_extractor_node.py +graphon/nodes/question_classifier/question_classifier_node.py +graphon/nodes/start/start_node.py +graphon/nodes/template_transform/template_transform_node.py +graphon/nodes/tool/tool_node.py +graphon/nodes/trigger_plugin/trigger_event_node.py +graphon/nodes/trigger_schedule/trigger_schedule_node.py +graphon/nodes/trigger_webhook/node.py +graphon/nodes/variable_aggregator/variable_aggregator_node.py +graphon/nodes/variable_assigner/v1/node.py +graphon/nodes/variable_assigner/v2/node.py extensions/logstore/repositories/logstore_api_workflow_run_repository.py extensions/otel/instrumentation.py extensions/otel/runtime.py diff --git a/api/repositories/api_workflow_node_execution_repository.py b/api/repositories/api_workflow_node_execution_repository.py index 2fa065bcc8..3595ea33f0 100644 --- a/api/repositories/api_workflow_node_execution_repository.py +++ b/api/repositories/api_workflow_node_execution_repository.py @@ -16,7 +16,7 @@ from typing import Protocol from sqlalchemy.orm import Session -from dify_graph.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository +from core.repositories.factory import WorkflowNodeExecutionRepository from models.workflow import WorkflowNodeExecutionModel, WorkflowNodeExecutionOffload diff --git a/api/repositories/api_workflow_run_repository.py b/api/repositories/api_workflow_run_repository.py index a96c4acb31..ffc17e92cf 100644 --- a/api/repositories/api_workflow_run_repository.py +++ b/api/repositories/api_workflow_run_repository.py @@ -40,9 +40,9 @@ from typing import Protocol from sqlalchemy.orm import Session -from dify_graph.entities.pause_reason import PauseReason -from dify_graph.enums import WorkflowType -from dify_graph.repositories.workflow_execution_repository import WorkflowExecutionRepository +from core.repositories.factory import WorkflowExecutionRepository +from graphon.entities.pause_reason import PauseReason +from graphon.enums import WorkflowType from libs.infinite_scroll_pagination import InfiniteScrollPagination from models.enums import WorkflowRunTriggeredFrom from models.workflow import WorkflowAppLog, WorkflowArchiveLog, WorkflowPause, WorkflowPauseReason, WorkflowRun diff --git a/api/repositories/entities/workflow_pause.py b/api/repositories/entities/workflow_pause.py index be28b7e613..03ce574dca 100644 --- a/api/repositories/entities/workflow_pause.py +++ b/api/repositories/entities/workflow_pause.py @@ -10,7 +10,7 @@ from abc import ABC, abstractmethod from collections.abc import Sequence from datetime import datetime -from dify_graph.entities.pause_reason import PauseReason +from graphon.entities.pause_reason import PauseReason class WorkflowPauseEntity(ABC): diff --git a/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py b/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py index 77e40fc6fc..44735eb769 100644 --- a/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py +++ b/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py @@ -14,7 +14,7 @@ from sqlalchemy import asc, delete, desc, func, select from sqlalchemy.engine import CursorResult from sqlalchemy.orm import Session, sessionmaker -from dify_graph.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus +from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus from models.workflow import WorkflowNodeExecutionModel, WorkflowNodeExecutionOffload from repositories.api_workflow_node_execution_repository import ( DifyAPIWorkflowNodeExecutionRepository, diff --git a/api/repositories/sqlalchemy_api_workflow_run_repository.py b/api/repositories/sqlalchemy_api_workflow_run_repository.py index fdd3e123e4..5bb0c74ada 100644 --- a/api/repositories/sqlalchemy_api_workflow_run_repository.py +++ b/api/repositories/sqlalchemy_api_workflow_run_repository.py @@ -33,17 +33,17 @@ from sqlalchemy import and_, delete, func, null, or_, select, tuple_ from sqlalchemy.engine import CursorResult from sqlalchemy.orm import Session, selectinload, sessionmaker -from dify_graph.entities.pause_reason import HumanInputRequired, PauseReason, PauseReasonType, SchedulingPause -from dify_graph.enums import WorkflowExecutionStatus, WorkflowType -from dify_graph.nodes.human_input.entities import FormDefinition from extensions.ext_storage import storage +from graphon.entities.pause_reason import HumanInputRequired, PauseReason, PauseReasonType, SchedulingPause +from graphon.enums import WorkflowExecutionStatus, WorkflowType +from graphon.nodes.human_input.entities import FormDefinition from libs.datetime_utils import naive_utc_now from libs.helper import convert_datetime_to_date from libs.infinite_scroll_pagination import InfiniteScrollPagination from libs.time_parser import get_time_threshold from libs.uuid_utils import uuidv7 from models.enums import WorkflowRunTriggeredFrom -from models.human_input import HumanInputForm, HumanInputFormRecipient, RecipientType +from models.human_input import HumanInputForm from models.workflow import WorkflowAppLog, WorkflowArchiveLog, WorkflowPause, WorkflowPauseReason, WorkflowRun from repositories.api_workflow_run_repository import APIWorkflowRunRepository from repositories.entities.workflow_pause import WorkflowPauseEntity @@ -61,25 +61,13 @@ class _WorkflowRunError(Exception): pass -def _select_recipient_token( - recipients: Sequence[HumanInputFormRecipient], - recipient_type: RecipientType, -) -> str | None: - for recipient in recipients: - if recipient.recipient_type == recipient_type and recipient.access_token: - return recipient.access_token - return None - - def _build_human_input_required_reason( reason_model: WorkflowPauseReason, form_model: HumanInputForm | None, - recipients: Sequence[HumanInputFormRecipient], ) -> HumanInputRequired: form_content = "" inputs = [] actions = [] - display_in_ui = False resolved_default_values: dict[str, Any] = {} node_title = "Human Input" form_id = reason_model.form_id @@ -99,25 +87,16 @@ def _build_human_input_required_reason( form_content = definition.form_content inputs = list(definition.inputs) actions = list(definition.user_actions) - display_in_ui = bool(definition.display_in_ui) resolved_default_values = dict(definition.default_values) node_title = definition.node_title or node_title - form_token = ( - _select_recipient_token(recipients, RecipientType.BACKSTAGE) - or _select_recipient_token(recipients, RecipientType.CONSOLE) - or _select_recipient_token(recipients, RecipientType.STANDALONE_WEB_APP) - ) - return HumanInputRequired( form_id=form_id, form_content=form_content, inputs=inputs, actions=actions, - display_in_ui=display_in_ui, node_id=node_id, node_title=node_title, - form_token=form_token, resolved_default_values=resolved_default_values, ) @@ -823,22 +802,16 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository): if reason.type_ == PauseReasonType.HUMAN_INPUT_REQUIRED and reason.form_id ] form_models: dict[str, HumanInputForm] = {} - recipient_models_by_form: dict[str, list[HumanInputFormRecipient]] = {} if form_ids: form_stmt = select(HumanInputForm).where(HumanInputForm.id.in_(form_ids)) for form in session.scalars(form_stmt).all(): form_models[form.id] = form - recipient_stmt = select(HumanInputFormRecipient).where(HumanInputFormRecipient.form_id.in_(form_ids)) - for recipient in session.scalars(recipient_stmt).all(): - recipient_models_by_form.setdefault(recipient.form_id, []).append(recipient) - pause_reasons: list[PauseReason] = [] for reason in pause_reason_models: if reason.type_ == PauseReasonType.HUMAN_INPUT_REQUIRED: form_model = form_models.get(reason.form_id) - recipients = recipient_models_by_form.get(reason.form_id, []) - pause_reasons.append(_build_human_input_required_reason(reason, form_model, recipients)) + pause_reasons.append(_build_human_input_required_reason(reason, form_model)) else: pause_reasons.append(reason.to_entity()) return pause_reasons diff --git a/api/repositories/sqlalchemy_execution_extra_content_repository.py b/api/repositories/sqlalchemy_execution_extra_content_repository.py index 508db22eb0..67f8795d3f 100644 --- a/api/repositories/sqlalchemy_execution_extra_content_repository.py +++ b/api/repositories/sqlalchemy_execution_extra_content_repository.py @@ -18,9 +18,9 @@ from core.entities.execution_extra_content import ( from core.entities.execution_extra_content import ( HumanInputContent as HumanInputContentDomainModel, ) -from dify_graph.nodes.human_input.entities import FormDefinition -from dify_graph.nodes.human_input.enums import HumanInputFormStatus -from dify_graph.nodes.human_input.human_input_node import HumanInputNode +from graphon.nodes.human_input.entities import FormDefinition +from graphon.nodes.human_input.enums import HumanInputFormStatus +from graphon.nodes.human_input.human_input_node import HumanInputNode from models.execution_extra_content import ( ExecutionExtraContent as ExecutionExtraContentModel, ) diff --git a/api/schedule/create_tidb_serverless_task.py b/api/schedule/create_tidb_serverless_task.py index 8b9d973d6d..6ceb3ef856 100644 --- a/api/schedule/create_tidb_serverless_task.py +++ b/api/schedule/create_tidb_serverless_task.py @@ -8,6 +8,7 @@ from configs import dify_config from core.rag.datasource.vdb.tidb_on_qdrant.tidb_service import TidbService from extensions.ext_database import db from models.dataset import TidbAuthBinding +from models.enums import TidbAuthBindingStatus @app.celery.task(queue="dataset") @@ -57,7 +58,7 @@ def create_clusters(batch_size): account=new_cluster["account"], password=new_cluster["password"], active=False, - status="CREATING", + status=TidbAuthBindingStatus.CREATING, ) db.session.add(tidb_auth_binding) db.session.commit() diff --git a/api/schedule/update_tidb_serverless_status_task.py b/api/schedule/update_tidb_serverless_status_task.py index 1befa0e8b5..10003b1b97 100644 --- a/api/schedule/update_tidb_serverless_status_task.py +++ b/api/schedule/update_tidb_serverless_status_task.py @@ -9,6 +9,7 @@ from configs import dify_config from core.rag.datasource.vdb.tidb_on_qdrant.tidb_service import TidbService from extensions.ext_database import db from models.dataset import TidbAuthBinding +from models.enums import TidbAuthBindingStatus @app.celery.task(queue="dataset") @@ -18,7 +19,10 @@ def update_tidb_serverless_status_task(): try: # check the number of idle tidb serverless tidb_serverless_list = db.session.scalars( - select(TidbAuthBinding).where(TidbAuthBinding.active == False, TidbAuthBinding.status == "CREATING") + select(TidbAuthBinding).where( + TidbAuthBinding.active == False, + TidbAuthBinding.status == TidbAuthBindingStatus.CREATING, + ) ).all() if len(tidb_serverless_list) == 0: return diff --git a/api/services/app_dsl_service.py b/api/services/app_dsl_service.py index 68cb3438ca..643a2a2a84 100644 --- a/api/services/app_dsl_service.py +++ b/api/services/app_dsl_service.py @@ -27,15 +27,15 @@ from core.trigger.constants import ( ) from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData from core.workflow.nodes.trigger_schedule.trigger_schedule_node import TriggerScheduleNode -from dify_graph.enums import BuiltinNodeTypes -from dify_graph.model_runtime.utils.encoders import jsonable_encoder -from dify_graph.nodes.llm.entities import LLMNodeData -from dify_graph.nodes.parameter_extractor.entities import ParameterExtractorNodeData -from dify_graph.nodes.question_classifier.entities import QuestionClassifierNodeData -from dify_graph.nodes.tool.entities import ToolNodeData from events.app_event import app_model_config_was_updated, app_was_created from extensions.ext_redis import redis_client from factories import variable_factory +from graphon.enums import BuiltinNodeTypes +from graphon.model_runtime.utils.encoders import jsonable_encoder +from graphon.nodes.llm.entities import LLMNodeData +from graphon.nodes.parameter_extractor.entities import ParameterExtractorNodeData +from graphon.nodes.question_classifier.entities import QuestionClassifierNodeData +from graphon.nodes.tool.entities import ToolNodeData from libs.datetime_utils import naive_utc_now from models import Account, App, AppMode from models.model import AppModelConfig, AppModelConfigDict, IconType diff --git a/api/services/app_service.py b/api/services/app_service.py index c5d1479a20..a9ec357455 100644 --- a/api/services/app_service.py +++ b/api/services/app_service.py @@ -12,10 +12,10 @@ from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError from core.model_manager import ModelManager from core.tools.tool_manager import ToolManager from core.tools.utils.configuration import ToolParameterConfigurationManager -from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey, ModelType -from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from events.app_event import app_was_created from extensions.ext_database import db +from graphon.model_runtime.entities.model_entities import ModelPropertyKey, ModelType +from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from libs.datetime_utils import naive_utc_now from libs.login import current_user from models import Account @@ -92,7 +92,7 @@ class AppService: default_model_config = default_model_config.copy() if default_model_config else None if default_model_config and "model" in default_model_config: # get model provider - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=account.current_tenant_id or "") # get default model instance try: @@ -124,11 +124,19 @@ class AppService: "completion_params": {}, } else: - provider, model = model_manager.get_default_provider_model_name( - tenant_id=account.current_tenant_id or "", model_type=ModelType.LLM - ) - default_model_config["model"]["provider"] = provider - default_model_config["model"]["name"] = model + try: + provider, model = model_manager.get_default_provider_model_name( + tenant_id=account.current_tenant_id or "", model_type=ModelType.LLM + ) + except Exception: + logger.exception("Get default provider model failed, tenant_id: %s", tenant_id) + provider = default_model_config["model"].get("provider") + model = default_model_config["model"].get("name") + + if provider: + default_model_config["model"]["provider"] = provider + if model: + default_model_config["model"]["name"] = model default_model_dict = default_model_config["model"] default_model_config["model"] = json.dumps(default_model_dict) @@ -197,6 +205,7 @@ class AppService: tenant_id=current_user.current_tenant_id, app_id=app.id, agent_tool=agent_tool_entity, + user_id=current_user.id, ) manager = ToolParameterConfigurationManager( tenant_id=current_user.current_tenant_id, @@ -241,7 +250,7 @@ class AppService: class ArgsDict(TypedDict): name: str description: str - icon_type: str + icon_type: IconType | str | None icon: str icon_background: str use_icon_as_answer_icon: bool @@ -257,7 +266,13 @@ class AppService: assert current_user is not None app.name = args["name"] app.description = args["description"] - app.icon_type = IconType(args["icon_type"]) if args["icon_type"] else None + icon_type = args.get("icon_type") + if icon_type is None: + resolved_icon_type = app.icon_type + else: + resolved_icon_type = IconType(icon_type) + + app.icon_type = resolved_icon_type app.icon = args["icon"] app.icon_background = args["icon_background"] app.use_icon_as_answer_icon = args.get("use_icon_as_answer_icon", False) diff --git a/api/services/app_task_service.py b/api/services/app_task_service.py index d556230044..6e9d6b1c73 100644 --- a/api/services/app_task_service.py +++ b/api/services/app_task_service.py @@ -7,8 +7,8 @@ new GraphEngine command channel mechanism. from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.entities.app_invoke_entities import InvokeFrom -from dify_graph.graph_engine.manager import GraphEngineManager from extensions.ext_redis import redis_client +from graphon.graph_engine.manager import GraphEngineManager from models.model import AppMode diff --git a/api/services/audio_service.py b/api/services/audio_service.py index 1794ea9947..9e743bf7b1 100644 --- a/api/services/audio_service.py +++ b/api/services/audio_service.py @@ -9,8 +9,8 @@ from werkzeug.datastructures import FileStorage from constants import AUDIO_EXTENSIONS from core.model_manager import ModelManager -from dify_graph.model_runtime.entities.model_entities import ModelType from extensions.ext_database import db +from graphon.model_runtime.entities.model_entities import ModelType from models.enums import MessageStatus from models.model import App, AppMode, Message from services.errors.audio import ( @@ -61,7 +61,7 @@ class AudioService: message = f"Audio size larger than {FILE_SIZE} mb" raise AudioTooLargeServiceError(message) - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=app_model.tenant_id, user_id=end_user) model_instance = model_manager.get_default_model_instance( tenant_id=app_model.tenant_id, model_type=ModelType.SPEECH2TEXT ) @@ -71,7 +71,7 @@ class AudioService: buffer = io.BytesIO(file_content) buffer.name = "temp.mp3" - return {"text": model_instance.invoke_speech2text(file=buffer, user=end_user)} + return {"text": model_instance.invoke_speech2text(file=buffer)} @classmethod def transcript_tts( @@ -109,7 +109,7 @@ class AudioService: voice = cast(str | None, text_to_speech_dict.get("voice")) - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=app_model.tenant_id, user_id=end_user) model_instance = model_manager.get_default_model_instance( tenant_id=app_model.tenant_id, model_type=ModelType.TTS ) @@ -123,9 +123,7 @@ class AudioService: else: raise ValueError("Sorry, no voice available.") - return model_instance.invoke_tts( - content_text=text_content.strip(), user=end_user, tenant_id=app_model.tenant_id, voice=voice - ) + return model_instance.invoke_tts(content_text=text_content.strip(), voice=voice) except Exception as e: raise e @@ -155,7 +153,7 @@ class AudioService: @classmethod def transcript_tts_voices(cls, tenant_id: str, language: str): - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=tenant_id) model_instance = model_manager.get_default_model_instance(tenant_id=tenant_id, model_type=ModelType.TTS) if model_instance is None: raise ProviderNotSupportTextToSpeechServiceError() diff --git a/api/services/auth/api_key_auth_base.py b/api/services/auth/api_key_auth_base.py index dd74a8f1b5..2e1b723e82 100644 --- a/api/services/auth/api_key_auth_base.py +++ b/api/services/auth/api_key_auth_base.py @@ -1,8 +1,16 @@ from abc import ABC, abstractmethod +from typing import Any + +from typing_extensions import TypedDict + + +class AuthCredentials(TypedDict): + auth_type: str + config: dict[str, Any] class ApiKeyAuthBase(ABC): - def __init__(self, credentials: dict): + def __init__(self, credentials: AuthCredentials): self.credentials = credentials @abstractmethod diff --git a/api/services/auth/api_key_auth_factory.py b/api/services/auth/api_key_auth_factory.py index 7ae31b0768..6e183b70e3 100644 --- a/api/services/auth/api_key_auth_factory.py +++ b/api/services/auth/api_key_auth_factory.py @@ -1,9 +1,9 @@ -from services.auth.api_key_auth_base import ApiKeyAuthBase +from services.auth.api_key_auth_base import ApiKeyAuthBase, AuthCredentials from services.auth.auth_type import AuthType class ApiKeyAuthFactory: - def __init__(self, provider: str, credentials: dict): + def __init__(self, provider: str, credentials: AuthCredentials): auth_factory = self.get_apikey_auth_factory(provider) self.auth = auth_factory(credentials) diff --git a/api/services/auth/firecrawl/firecrawl.py b/api/services/auth/firecrawl/firecrawl.py index b002706931..c9e5610aea 100644 --- a/api/services/auth/firecrawl/firecrawl.py +++ b/api/services/auth/firecrawl/firecrawl.py @@ -2,11 +2,11 @@ import json import httpx -from services.auth.api_key_auth_base import ApiKeyAuthBase +from services.auth.api_key_auth_base import ApiKeyAuthBase, AuthCredentials class FirecrawlAuth(ApiKeyAuthBase): - def __init__(self, credentials: dict): + def __init__(self, credentials: AuthCredentials): super().__init__(credentials) auth_type = credentials.get("auth_type") if auth_type != "bearer": diff --git a/api/services/auth/jina.py b/api/services/auth/jina.py index afaed28ac9..e5e2319ce1 100644 --- a/api/services/auth/jina.py +++ b/api/services/auth/jina.py @@ -2,11 +2,11 @@ import json import httpx -from services.auth.api_key_auth_base import ApiKeyAuthBase +from services.auth.api_key_auth_base import ApiKeyAuthBase, AuthCredentials class JinaAuth(ApiKeyAuthBase): - def __init__(self, credentials: dict): + def __init__(self, credentials: AuthCredentials): super().__init__(credentials) auth_type = credentials.get("auth_type") if auth_type != "bearer": diff --git a/api/services/auth/jina/jina.py b/api/services/auth/jina/jina.py index afaed28ac9..e5e2319ce1 100644 --- a/api/services/auth/jina/jina.py +++ b/api/services/auth/jina/jina.py @@ -2,11 +2,11 @@ import json import httpx -from services.auth.api_key_auth_base import ApiKeyAuthBase +from services.auth.api_key_auth_base import ApiKeyAuthBase, AuthCredentials class JinaAuth(ApiKeyAuthBase): - def __init__(self, credentials: dict): + def __init__(self, credentials: AuthCredentials): super().__init__(credentials) auth_type = credentials.get("auth_type") if auth_type != "bearer": diff --git a/api/services/auth/watercrawl/watercrawl.py b/api/services/auth/watercrawl/watercrawl.py index b2d28a83d1..cbdc908690 100644 --- a/api/services/auth/watercrawl/watercrawl.py +++ b/api/services/auth/watercrawl/watercrawl.py @@ -3,11 +3,11 @@ from urllib.parse import urljoin import httpx -from services.auth.api_key_auth_base import ApiKeyAuthBase +from services.auth.api_key_auth_base import ApiKeyAuthBase, AuthCredentials class WatercrawlAuth(ApiKeyAuthBase): - def __init__(self, credentials: dict): + def __init__(self, credentials: AuthCredentials): super().__init__(credentials) auth_type = credentials.get("auth_type") if auth_type != "x-api-key": diff --git a/api/services/clear_free_plan_tenant_expired_logs.py b/api/services/clear_free_plan_tenant_expired_logs.py index 0e0eab00ad..c6b32b373e 100644 --- a/api/services/clear_free_plan_tenant_expired_logs.py +++ b/api/services/clear_free_plan_tenant_expired_logs.py @@ -10,10 +10,10 @@ from sqlalchemy import select from sqlalchemy.orm import Session, sessionmaker from configs import dify_config -from dify_graph.model_runtime.utils.encoders import jsonable_encoder from enums.cloud_plan import CloudPlan from extensions.ext_database import db from extensions.ext_storage import storage +from graphon.model_runtime.utils.encoders import jsonable_encoder from models.account import Tenant from models.model import ( App, diff --git a/api/services/conversation_service.py b/api/services/conversation_service.py index 566c27c0f3..545c5048d5 100644 --- a/api/services/conversation_service.py +++ b/api/services/conversation_service.py @@ -10,9 +10,9 @@ from configs import dify_config from core.app.entities.app_invoke_entities import InvokeFrom from core.db.session_factory import session_factory from core.llm_generator.llm_generator import LLMGenerator -from dify_graph.variables.types import SegmentType from extensions.ext_database import db from factories import variable_factory +from graphon.variables.types import SegmentType from libs.datetime_utils import naive_utc_now from libs.infinite_scroll_pagination import InfiniteScrollPagination from models import Account, ConversationVariable diff --git a/api/services/conversation_variable_updater.py b/api/services/conversation_variable_updater.py index f00e3fe01e..287d513f48 100644 --- a/api/services/conversation_variable_updater.py +++ b/api/services/conversation_variable_updater.py @@ -1,7 +1,7 @@ from sqlalchemy import select from sqlalchemy.orm import Session, sessionmaker -from dify_graph.variables.variables import VariableBase +from graphon.variables.variables import VariableBase from models import ConversationVariable diff --git a/api/services/credit_pool_service.py b/api/services/credit_pool_service.py index 1954602571..2894826935 100644 --- a/api/services/credit_pool_service.py +++ b/api/services/credit_pool_service.py @@ -7,6 +7,7 @@ from configs import dify_config from core.errors.error import QuotaExceededError from extensions.ext_database import db from models import TenantCreditPool +from models.enums import ProviderQuotaType logger = logging.getLogger(__name__) @@ -16,7 +17,10 @@ class CreditPoolService: def create_default_pool(cls, tenant_id: str) -> TenantCreditPool: """create default credit pool for new tenant""" credit_pool = TenantCreditPool( - tenant_id=tenant_id, quota_limit=dify_config.HOSTED_POOL_CREDITS, quota_used=0, pool_type="trial" + tenant_id=tenant_id, + quota_limit=dify_config.HOSTED_POOL_CREDITS, + quota_used=0, + pool_type=ProviderQuotaType.TRIAL, ) db.session.add(credit_pool) db.session.commit() diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index cdab90a3dc..3e2342b1a7 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -21,16 +21,16 @@ from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError from core.helper.name_generator import generate_incremental_name from core.model_manager import ModelManager from core.rag.index_processor.constant.built_in_field import BuiltInField -from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from core.rag.retrieval.retrieval_methods import RetrievalMethod -from dify_graph.file import helpers as file_helpers -from dify_graph.model_runtime.entities.model_entities import ModelFeature, ModelType -from dify_graph.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel from enums.cloud_plan import CloudPlan from events.dataset_event import dataset_was_deleted from events.document_event import document_was_deleted from extensions.ext_database import db from extensions.ext_redis import redis_client +from graphon.file import helpers as file_helpers +from graphon.model_runtime.entities.model_entities import ModelFeature, ModelType +from graphon.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel from libs import helper from libs.datetime_utils import naive_utc_now from libs.login import current_user @@ -58,6 +58,7 @@ from models.enums import ( IndexingStatus, ProcessRuleMode, SegmentStatus, + SegmentType, ) from models.model import UploadFile from models.provider_ids import ModelProviderID @@ -227,8 +228,8 @@ class DatasetService: if db.session.query(Dataset).filter_by(name=name, tenant_id=tenant_id).first(): raise DatasetNameDuplicateError(f"Dataset with name {name} already exists.") embedding_model = None - if indexing_technique == "high_quality": - model_manager = ModelManager() + if indexing_technique == IndexTechniqueType.HIGH_QUALITY: + model_manager = ModelManager.for_tenant(tenant_id=tenant_id) if embedding_model_provider and embedding_model_name: # check if embedding model setting is valid DatasetService.check_embedding_model_setting(tenant_id, embedding_model_provider, embedding_model_name) @@ -253,7 +254,10 @@ class DatasetService: retrieval_model.reranking_model.reranking_provider_name, retrieval_model.reranking_model.reranking_model_name, ) - dataset = Dataset(name=name, indexing_technique=indexing_technique) + dataset = Dataset( + name=name, + indexing_technique=IndexTechniqueType(indexing_technique) if indexing_technique else None, + ) # dataset = Dataset(name=name, provider=provider, config=config) dataset.description = description dataset.created_by = account.id @@ -348,9 +352,9 @@ class DatasetService: @staticmethod def check_dataset_model_setting(dataset): - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: try: - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=dataset.tenant_id) model_manager.get_model_instance( tenant_id=dataset.tenant_id, provider=dataset.embedding_model_provider, @@ -367,7 +371,7 @@ class DatasetService: @staticmethod def check_embedding_model_setting(tenant_id: str, embedding_model_provider: str, embedding_model: str): try: - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=tenant_id) model_manager.get_model_instance( tenant_id=tenant_id, provider=embedding_model_provider, @@ -384,7 +388,7 @@ class DatasetService: @staticmethod def check_is_multimodal_model(tenant_id: str, model_provider: str, model: str): try: - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=tenant_id) model_instance = model_manager.get_model_instance( tenant_id=tenant_id, provider=model_provider, @@ -405,7 +409,7 @@ class DatasetService: @staticmethod def check_reranking_model_setting(tenant_id: str, reranking_model_provider: str, reranking_model: str): try: - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=tenant_id) model_manager.get_model_instance( tenant_id=tenant_id, provider=reranking_model_provider, @@ -716,13 +720,13 @@ class DatasetService: if "indexing_technique" not in data: return None if dataset.indexing_technique != data["indexing_technique"]: - if data["indexing_technique"] == "economy": + if data["indexing_technique"] == IndexTechniqueType.ECONOMY: # Remove embedding model configuration for economy mode filtered_data["embedding_model"] = None filtered_data["embedding_model_provider"] = None filtered_data["collection_binding_id"] = None return "remove" - elif data["indexing_technique"] == "high_quality": + elif data["indexing_technique"] == IndexTechniqueType.HIGH_QUALITY: # Configure embedding model for high quality mode DatasetService._configure_embedding_model_for_high_quality(data, filtered_data) return "add" @@ -742,7 +746,7 @@ class DatasetService: """ # assert isinstance(current_user, Account) and current_user.current_tenant_id is not None try: - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=current_user.current_tenant_id) assert isinstance(current_user, Account) assert current_user.current_tenant_id is not None embedding_model = model_manager.get_model_instance( @@ -860,7 +864,7 @@ class DatasetService: """ # assert isinstance(current_user, Account) and current_user.current_tenant_id is not None - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=current_user.current_tenant_id) try: assert isinstance(current_user, Account) assert current_user.current_tenant_id is not None @@ -952,9 +956,9 @@ class DatasetService: dataset = session.merge(dataset) if not has_published: dataset.chunk_structure = knowledge_configuration.chunk_structure - dataset.indexing_technique = knowledge_configuration.indexing_technique - if knowledge_configuration.indexing_technique == "high_quality": - model_manager = ModelManager() + dataset.indexing_technique = IndexTechniqueType(knowledge_configuration.indexing_technique) + if knowledge_configuration.indexing_technique == IndexTechniqueType.HIGH_QUALITY: + model_manager = ModelManager.for_tenant(tenant_id=current_user.current_tenant_id) embedding_model = model_manager.get_model_instance( tenant_id=current_user.current_tenant_id, # ignore type error provider=knowledge_configuration.embedding_model_provider or "", @@ -975,7 +979,7 @@ class DatasetService: embedding_model_name, ) dataset.collection_binding_id = dataset_collection_binding.id - elif knowledge_configuration.indexing_technique == "economy": + elif knowledge_configuration.indexing_technique == IndexTechniqueType.ECONOMY: dataset.keyword_number = knowledge_configuration.keyword_number else: raise ValueError("Invalid index method") @@ -990,13 +994,13 @@ class DatasetService: action = None if dataset.indexing_technique != knowledge_configuration.indexing_technique: # if update indexing_technique - if knowledge_configuration.indexing_technique == "economy": + if knowledge_configuration.indexing_technique == IndexTechniqueType.ECONOMY: raise ValueError("Knowledge base indexing technique is not allowed to be updated to economy.") - elif knowledge_configuration.indexing_technique == "high_quality": + elif knowledge_configuration.indexing_technique == IndexTechniqueType.HIGH_QUALITY: action = "add" # get embedding model setting try: - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=current_user.current_tenant_id) embedding_model = model_manager.get_model_instance( tenant_id=current_user.current_tenant_id, provider=knowledge_configuration.embedding_model_provider, @@ -1017,7 +1021,7 @@ class DatasetService: ) dataset.is_multimodal = is_multimodal dataset.collection_binding_id = dataset_collection_binding.id - dataset.indexing_technique = knowledge_configuration.indexing_technique + dataset.indexing_technique = IndexTechniqueType(knowledge_configuration.indexing_technique) except LLMBadRequestError: raise ValueError( "No Embedding Model available. Please configure a valid provider " @@ -1028,7 +1032,7 @@ class DatasetService: else: # add default plugin id to both setting sets, to make sure the plugin model provider is consistent # Skip embedding model checks if not provided in the update request - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: skip_embedding_update = False try: # Handle existing model provider @@ -1049,7 +1053,7 @@ class DatasetService: or knowledge_configuration.embedding_model != dataset.embedding_model ): action = "update" - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=current_user.current_tenant_id) embedding_model = None try: embedding_model = model_manager.get_model_instance( @@ -1088,7 +1092,7 @@ class DatasetService: ) except ProviderTokenNotInitError as ex: raise ValueError(ex.description) - elif dataset.indexing_technique == "economy": + elif dataset.indexing_technique == IndexTechniqueType.ECONOMY: if dataset.keyword_number != knowledge_configuration.keyword_number: dataset.keyword_number = knowledge_configuration.keyword_number dataset.retrieval_model = knowledge_configuration.retrieval_model.model_dump() @@ -1439,7 +1443,7 @@ class DocumentService: .filter( Document.id.in_(document_id_list), Document.dataset_id == dataset_id, - Document.doc_form != "qa_model", # Skip qa_model documents + Document.doc_form != IndexStructureType.QA_INDEX, # Skip qa_model documents ) .update({Document.need_summary: need_summary}, synchronize_session=False) ) @@ -1906,9 +1910,9 @@ class DocumentService: if knowledge_config.indexing_technique not in Dataset.INDEXING_TECHNIQUE_LIST: raise ValueError("Indexing technique is invalid") - dataset.indexing_technique = knowledge_config.indexing_technique - if knowledge_config.indexing_technique == "high_quality": - model_manager = ModelManager() + dataset.indexing_technique = IndexTechniqueType(knowledge_config.indexing_technique) + if knowledge_config.indexing_technique == IndexTechniqueType.HIGH_QUALITY: + model_manager = ModelManager.for_tenant(tenant_id=current_user.current_tenant_id) if knowledge_config.embedding_model and knowledge_config.embedding_model_provider: dataset_embedding_model = knowledge_config.embedding_model dataset_embedding_model_provider = knowledge_config.embedding_model_provider @@ -2039,7 +2043,7 @@ class DocumentService: document.dataset_process_rule_id = dataset_process_rule.id document.updated_at = naive_utc_now() document.created_from = created_from - document.doc_form = knowledge_config.doc_form + document.doc_form = IndexStructureType(knowledge_config.doc_form) document.doc_language = knowledge_config.doc_language document.data_source_info = json.dumps(data_source_info) document.batch = batch @@ -2220,7 +2224,7 @@ class DocumentService: # dataset.indexing_technique = knowledge_config.indexing_technique # if knowledge_config.indexing_technique == "high_quality": - # model_manager = ModelManager() + # model_manager = ModelManager.for_tenant(tenant_id=current_user.current_tenant_id) # if knowledge_config.embedding_model and knowledge_config.embedding_model_provider: # dataset_embedding_model = knowledge_config.embedding_model # dataset_embedding_model_provider = knowledge_config.embedding_model_provider @@ -2639,7 +2643,7 @@ class DocumentService: document.splitting_completed_at = None document.updated_at = naive_utc_now() document.created_from = created_from - document.doc_form = document_data.doc_form + document.doc_form = IndexStructureType(document_data.doc_form) db.session.add(document) db.session.commit() # update document segment @@ -2688,7 +2692,7 @@ class DocumentService: dataset_collection_binding_id = None retrieval_model = None - if knowledge_config.indexing_technique == "high_quality": + if knowledge_config.indexing_technique == IndexTechniqueType.HIGH_QUALITY: assert knowledge_config.embedding_model_provider assert knowledge_config.embedding_model dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( @@ -2711,7 +2715,7 @@ class DocumentService: tenant_id=tenant_id, name="", data_source_type=knowledge_config.data_source.info_list.data_source_type, - indexing_technique=knowledge_config.indexing_technique, + indexing_technique=IndexTechniqueType(knowledge_config.indexing_technique), created_by=account.id, embedding_model=knowledge_config.embedding_model, embedding_model_provider=knowledge_config.embedding_model_provider, @@ -3100,7 +3104,7 @@ class DocumentService: class SegmentService: @classmethod def segment_create_args_validate(cls, args: dict, document: Document): - if document.doc_form == "qa_model": + if document.doc_form == IndexStructureType.QA_INDEX: if "answer" not in args or not args["answer"]: raise ValueError("Answer is required") if not args["answer"].strip(): @@ -3124,8 +3128,8 @@ class SegmentService: doc_id = str(uuid.uuid4()) segment_hash = helper.generate_text_hash(content) tokens = 0 - if dataset.indexing_technique == "high_quality": - model_manager = ModelManager() + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: + model_manager = ModelManager.for_tenant(tenant_id=current_user.current_tenant_id) embedding_model = model_manager.get_model_instance( tenant_id=current_user.current_tenant_id, provider=dataset.embedding_model_provider, @@ -3157,7 +3161,7 @@ class SegmentService: completed_at=naive_utc_now(), created_by=current_user.id, ) - if document.doc_form == "qa_model": + if document.doc_form == IndexStructureType.QA_INDEX: segment_document.word_count += len(args["answer"]) segment_document.answer = args["answer"] @@ -3207,8 +3211,8 @@ class SegmentService: try: with redis_client.lock(lock_name, timeout=600): embedding_model = None - if dataset.indexing_technique == "high_quality": - model_manager = ModelManager() + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: + model_manager = ModelManager.for_tenant(tenant_id=current_user.current_tenant_id) embedding_model = model_manager.get_model_instance( tenant_id=current_user.current_tenant_id, provider=dataset.embedding_model_provider, @@ -3229,9 +3233,9 @@ class SegmentService: doc_id = str(uuid.uuid4()) segment_hash = helper.generate_text_hash(content) tokens = 0 - if dataset.indexing_technique == "high_quality" and embedding_model: + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY and embedding_model: # calc embedding use tokens - if document.doc_form == "qa_model": + if document.doc_form == IndexStructureType.QA_INDEX: tokens = embedding_model.get_text_embedding_num_tokens( texts=[content + segment_item["answer"]] )[0] @@ -3254,7 +3258,7 @@ class SegmentService: completed_at=naive_utc_now(), created_by=current_user.id, ) - if document.doc_form == "qa_model": + if document.doc_form == IndexStructureType.QA_INDEX: segment_document.answer = segment_item["answer"] segment_document.word_count += len(segment_item["answer"]) increment_word_count += segment_document.word_count @@ -3321,7 +3325,7 @@ class SegmentService: content = args.content or segment.content if segment.content == content: segment.word_count = len(content) - if document.doc_form == "qa_model": + if document.doc_form == IndexStructureType.QA_INDEX: segment.answer = args.answer segment.word_count += len(args.answer) if args.answer else 0 word_count_change = segment.word_count - word_count_change @@ -3344,9 +3348,9 @@ class SegmentService: if document.doc_form == IndexStructureType.PARENT_CHILD_INDEX and args.regenerate_child_chunks: # regenerate child chunks # get embedding model instance - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: # check embedding model setting - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=dataset.tenant_id) if dataset.embedding_model_provider: embedding_model_instance = model_manager.get_model_instance( @@ -3381,7 +3385,7 @@ class SegmentService: # When user manually provides summary, allow saving even if summary_index_setting doesn't exist # summary_index_setting is only needed for LLM generation, not for manual summary vectorization # Vectorization uses dataset.embedding_model, which doesn't require summary_index_setting - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: # Query existing summary from database from models.dataset import DocumentSegmentSummary @@ -3408,8 +3412,8 @@ class SegmentService: else: segment_hash = helper.generate_text_hash(content) tokens = 0 - if dataset.indexing_technique == "high_quality": - model_manager = ModelManager() + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: + model_manager = ModelManager.for_tenant(tenant_id=current_user.current_tenant_id) embedding_model = model_manager.get_model_instance( tenant_id=current_user.current_tenant_id, provider=dataset.embedding_model_provider, @@ -3418,7 +3422,7 @@ class SegmentService: ) # calc embedding use tokens - if document.doc_form == "qa_model": + if document.doc_form == IndexStructureType.QA_INDEX: segment.answer = args.answer tokens = embedding_model.get_text_embedding_num_tokens(texts=[content + segment.answer])[0] # type: ignore else: @@ -3435,7 +3439,7 @@ class SegmentService: segment.enabled = True segment.disabled_at = None segment.disabled_by = None - if document.doc_form == "qa_model": + if document.doc_form == IndexStructureType.QA_INDEX: segment.answer = args.answer segment.word_count += len(args.answer) if args.answer else 0 word_count_change = segment.word_count - word_count_change @@ -3448,9 +3452,9 @@ class SegmentService: db.session.commit() if document.doc_form == IndexStructureType.PARENT_CHILD_INDEX and args.regenerate_child_chunks: # get embedding model instance - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: # check embedding model setting - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=dataset.tenant_id) if dataset.embedding_model_provider: embedding_model_instance = model_manager.get_model_instance( @@ -3480,7 +3484,7 @@ class SegmentService: # update segment vector index VectorService.update_segment_vector(args.keywords, segment, dataset) # Handle summary index when content changed - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: from models.dataset import DocumentSegmentSummary existing_summary = ( @@ -3786,7 +3790,7 @@ class SegmentService: child_chunk.word_count = len(child_chunk.content) child_chunk.updated_by = current_user.id child_chunk.updated_at = naive_utc_now() - child_chunk.type = "customized" + child_chunk.type = SegmentType.CUSTOMIZED update_child_chunks.append(child_chunk) else: new_child_chunks_args.append(child_chunk_update_args) @@ -3845,7 +3849,7 @@ class SegmentService: child_chunk.word_count = len(content) child_chunk.updated_by = current_user.id child_chunk.updated_at = naive_utc_now() - child_chunk.type = "customized" + child_chunk.type = SegmentType.CUSTOMIZED db.session.add(child_chunk) VectorService.update_child_chunk_vector([], [child_chunk], [], dataset) db.session.commit() diff --git a/api/services/datasource_provider_service.py b/api/services/datasource_provider_service.py index f3b2adb965..2b7bebb01e 100644 --- a/api/services/datasource_provider_service.py +++ b/api/services/datasource_provider_service.py @@ -14,9 +14,9 @@ from core.plugin.entities.plugin_daemon import CredentialType from core.plugin.impl.datasource import PluginDatasourceManager from core.plugin.impl.oauth import OAuthHandler from core.tools.utils.encryption import ProviderConfigCache, ProviderConfigEncrypter, create_provider_encrypter -from dify_graph.model_runtime.entities.provider_entities import FormType from extensions.ext_database import db from extensions.ext_redis import redis_client +from graphon.model_runtime.entities.provider_entities import FormType from models.oauth import DatasourceOauthParamConfig, DatasourceOauthTenantParamConfig, DatasourceProvider from models.provider_ids import DatasourceProviderID from services.plugin.plugin_service import PluginService diff --git a/api/services/entities/model_provider_entities.py b/api/services/entities/model_provider_entities.py index 9dd595f516..6679c08ebd 100644 --- a/api/services/entities/model_provider_entities.py +++ b/api/services/entities/model_provider_entities.py @@ -15,9 +15,9 @@ from core.entities.provider_entities import ( QuotaConfiguration, UnaddedModelConfiguration, ) -from dify_graph.model_runtime.entities.common_entities import I18nObject -from dify_graph.model_runtime.entities.model_entities import ModelType -from dify_graph.model_runtime.entities.provider_entities import ( +from graphon.model_runtime.entities.common_entities import I18nObject +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.entities.provider_entities import ( ConfigurateMethod, ModelCredentialSchema, ProviderCredentialSchema, diff --git a/api/services/external_knowledge_service.py b/api/services/external_knowledge_service.py index 4cf42b7f44..d2fa98f5e2 100644 --- a/api/services/external_knowledge_service.py +++ b/api/services/external_knowledge_service.py @@ -9,8 +9,8 @@ from sqlalchemy import select from constants import HIDDEN_VALUE from core.helper import ssrf_proxy from core.rag.entities.metadata_entities import MetadataCondition -from dify_graph.nodes.http_request.exc import InvalidHttpMethodError from extensions.ext_database import db +from graphon.nodes.http_request.exc import InvalidHttpMethodError from libs.datetime_utils import naive_utc_now from models.dataset import ( Dataset, diff --git a/api/services/file_service.py b/api/services/file_service.py index a7060f3b92..c11f018f52 100644 --- a/api/services/file_service.py +++ b/api/services/file_service.py @@ -20,10 +20,10 @@ from constants import ( VIDEO_EXTENSIONS, ) from core.rag.extractor.extract_processor import ExtractProcessor -from dify_graph.file import helpers as file_helpers from extensions.ext_database import db from extensions.ext_storage import storage from extensions.storage.storage_type import StorageType +from graphon.file import helpers as file_helpers from libs.datetime_utils import naive_utc_now from libs.helper import extract_tenant_id from models import Account diff --git a/api/services/hit_testing_service.py b/api/services/hit_testing_service.py index 9993d24c70..d490ad1561 100644 --- a/api/services/hit_testing_service.py +++ b/api/services/hit_testing_service.py @@ -9,8 +9,8 @@ from core.rag.index_processor.constant.query_type import QueryType from core.rag.models.document import Document from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from core.rag.retrieval.retrieval_methods import RetrievalMethod -from dify_graph.model_runtime.entities import LLMMode from extensions.ext_database import db +from graphon.model_runtime.entities import LLMMode from models import Account from models.dataset import Dataset, DatasetQuery from models.enums import CreatorUserRole, DatasetQuerySource diff --git a/api/services/human_input_delivery_test_service.py b/api/services/human_input_delivery_test_service.py index 229e6608da..861d952c93 100644 --- a/api/services/human_input_delivery_test_service.py +++ b/api/services/human_input_delivery_test_service.py @@ -8,16 +8,16 @@ from sqlalchemy import Engine, select from sqlalchemy.orm import sessionmaker from configs import dify_config -from dify_graph.nodes.human_input.entities import ( +from core.workflow.human_input_compat import ( DeliveryChannelConfig, EmailDeliveryConfig, EmailDeliveryMethod, ExternalRecipient, MemberRecipient, ) -from dify_graph.runtime import VariablePool from extensions.ext_database import db from extensions.ext_mail import mail +from graphon.runtime import VariablePool from libs.email_template_renderer import render_email_template from models import Account, TenantAccountJoin from services.feature_service import FeatureService @@ -177,21 +177,21 @@ class EmailDeliveryTestHandler: def _resolve_recipients(self, *, tenant_id: str, method: EmailDeliveryMethod) -> list[str]: recipients = method.config.recipients emails: list[str] = [] - member_user_ids: list[str] = [] + bound_reference_ids: list[str] = [] for recipient in recipients.items: if isinstance(recipient, MemberRecipient): - member_user_ids.append(recipient.user_id) + bound_reference_ids.append(recipient.reference_id) elif isinstance(recipient, ExternalRecipient): if recipient.email: emails.append(recipient.email) - if recipients.whole_workspace: - member_user_ids = [] + if recipients.include_bound_group: + bound_reference_ids = [] member_emails = self._query_workspace_member_emails(tenant_id=tenant_id, user_ids=None) emails.extend(member_emails.values()) - elif member_user_ids: - member_emails = self._query_workspace_member_emails(tenant_id=tenant_id, user_ids=member_user_ids) - for user_id in member_user_ids: + elif bound_reference_ids: + member_emails = self._query_workspace_member_emails(tenant_id=tenant_id, user_ids=bound_reference_ids) + for user_id in bound_reference_ids: email = member_emails.get(user_id) if email: emails.append(email) diff --git a/api/services/human_input_service.py b/api/services/human_input_service.py index 2e74c50963..76598d31ac 100644 --- a/api/services/human_input_service.py +++ b/api/services/human_input_service.py @@ -11,12 +11,12 @@ from core.repositories.human_input_repository import ( HumanInputFormRecord, HumanInputFormSubmissionRepository, ) -from dify_graph.nodes.human_input.entities import ( +from graphon.nodes.human_input.entities import ( FormDefinition, HumanInputSubmissionValidationError, validate_human_input_submission, ) -from dify_graph.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus +from graphon.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus from libs.datetime_utils import ensure_naive_utc, naive_utc_now from libs.exception import BaseHTTPException from models.human_input import RecipientType diff --git a/api/services/knowledge_service.py b/api/services/knowledge_service.py index 02fe1d19bc..3d6fdb08a3 100644 --- a/api/services/knowledge_service.py +++ b/api/services/knowledge_service.py @@ -1,12 +1,20 @@ import boto3 +from pydantic import BaseModel, Field from configs import dify_config +class BedrockRetrievalSetting(BaseModel): + """Retrieval settings for Amazon Bedrock knowledge base queries.""" + + top_k: int | None = Field(default=None, description="Maximum number of results to retrieve") + score_threshold: float = Field(default=0.0, description="Minimum relevance score threshold") + + class ExternalDatasetTestService: # this service is only for internal testing @staticmethod - def knowledge_retrieval(retrieval_setting: dict, query: str, knowledge_id: str): + def knowledge_retrieval(retrieval_setting: BedrockRetrievalSetting, query: str, knowledge_id: str): # get bedrock client client = boto3.client( "bedrock-agent-runtime", @@ -20,7 +28,7 @@ class ExternalDatasetTestService: knowledgeBaseId=knowledge_id, retrievalConfiguration={ "vectorSearchConfiguration": { - "numberOfResults": retrieval_setting.get("top_k"), + "numberOfResults": retrieval_setting.top_k, "overrideSearchType": "HYBRID", } }, @@ -33,7 +41,7 @@ class ExternalDatasetTestService: retrieval_results = response.get("retrievalResults") for retrieval_result in retrieval_results: # filter out results with score less than threshold - if retrieval_result.get("score") < retrieval_setting.get("score_threshold", 0.0): + if retrieval_result.get("score") < retrieval_setting.score_threshold: continue result = { "metadata": retrieval_result.get("metadata"), diff --git a/api/services/message_service.py b/api/services/message_service.py index fc87802f51..0c4a334b47 100644 --- a/api/services/message_service.py +++ b/api/services/message_service.py @@ -12,8 +12,8 @@ from core.model_manager import ModelManager from core.ops.entities.trace_entity import TraceTaskName from core.ops.ops_trace_manager import TraceQueueManager, TraceTask from core.ops.utils import measure_time -from dify_graph.model_runtime.entities.model_entities import ModelType from extensions.ext_database import db +from graphon.model_runtime.entities.model_entities import ModelType from libs.infinite_scroll_pagination import InfiniteScrollPagination from models import Account from models.enums import FeedbackFromSource, FeedbackRating @@ -255,7 +255,7 @@ class MessageService: app_model=app_model, conversation_id=message.conversation_id, user=user ) - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=app_model.tenant_id) if app_model.mode == AppMode.ADVANCED_CHAT: workflow_service = WorkflowService() diff --git a/api/services/model_load_balancing_service.py b/api/services/model_load_balancing_service.py index bf3b6db3ed..469357d6e0 100644 --- a/api/services/model_load_balancing_service.py +++ b/api/services/model_load_balancing_service.py @@ -10,14 +10,15 @@ from core.entities.provider_configuration import ProviderConfiguration from core.helper import encrypter from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType from core.model_manager import LBModelManager +from core.plugin.impl.model_runtime_factory import create_plugin_model_assembly, create_plugin_provider_manager from core.provider_manager import ProviderManager -from dify_graph.model_runtime.entities.model_entities import ModelType -from dify_graph.model_runtime.entities.provider_entities import ( +from extensions.ext_database import db +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.entities.provider_entities import ( ModelCredentialSchema, ProviderCredentialSchema, ) -from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory -from extensions.ext_database import db +from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory from libs.datetime_utils import naive_utc_now from models.enums import CredentialSourceType from models.provider import LoadBalancingModelConfig, ProviderCredential, ProviderModelCredential @@ -26,8 +27,9 @@ logger = logging.getLogger(__name__) class ModelLoadBalancingService: - def __init__(self): - self.provider_manager = ProviderManager() + @staticmethod + def _get_provider_manager(tenant_id: str) -> ProviderManager: + return create_plugin_provider_manager(tenant_id=tenant_id) def enable_model_load_balancing(self, tenant_id: str, provider: str, model: str, model_type: str): """ @@ -40,7 +42,7 @@ class ModelLoadBalancingService: :return: """ # Get all provider configurations of the current workspace - provider_configurations = self.provider_manager.get_configurations(tenant_id) + provider_configurations = self._get_provider_manager(tenant_id).get_configurations(tenant_id) # Get provider configuration provider_configuration = provider_configurations.get(provider) @@ -61,7 +63,7 @@ class ModelLoadBalancingService: :return: """ # Get all provider configurations of the current workspace - provider_configurations = self.provider_manager.get_configurations(tenant_id) + provider_configurations = self._get_provider_manager(tenant_id).get_configurations(tenant_id) # Get provider configuration provider_configuration = provider_configurations.get(provider) @@ -83,7 +85,7 @@ class ModelLoadBalancingService: :return: """ # Get all provider configurations of the current workspace - provider_configurations = self.provider_manager.get_configurations(tenant_id) + provider_configurations = self._get_provider_manager(tenant_id).get_configurations(tenant_id) # Get provider configuration provider_configuration = provider_configurations.get(provider) @@ -222,8 +224,8 @@ class ModelLoadBalancingService: :param config_id: load balancing config id :return: """ - # Get all provider configurations of the current workspace - provider_configurations = self.provider_manager.get_configurations(tenant_id) + provider_manager = create_plugin_provider_manager(tenant_id=tenant_id) + provider_configurations = provider_manager.get_configurations(tenant_id) # Get provider configuration provider_configuration = provider_configurations.get(provider) @@ -310,7 +312,7 @@ class ModelLoadBalancingService: :return: """ # Get all provider configurations of the current workspace - provider_configurations = self.provider_manager.get_configurations(tenant_id) + provider_configurations = self._get_provider_manager(tenant_id).get_configurations(tenant_id) # Get provider configuration provider_configuration = provider_configurations.get(provider) @@ -495,8 +497,8 @@ class ModelLoadBalancingService: :param config_id: load balancing config id :return: """ - # Get all provider configurations of the current workspace - provider_configurations = self.provider_manager.get_configurations(tenant_id) + assembly = create_plugin_model_assembly(tenant_id=tenant_id) + provider_configurations = assembly.provider_manager.get_configurations(tenant_id) # Get provider configuration provider_configuration = provider_configurations.get(provider) @@ -532,6 +534,7 @@ class ModelLoadBalancingService: model=model, credentials=credentials, load_balancing_model_config=load_balancing_model_config, + model_provider_factory=assembly.model_provider_factory, ) def _custom_credentials_validate( @@ -542,6 +545,7 @@ class ModelLoadBalancingService: model: str, credentials: dict, load_balancing_model_config: LoadBalancingModelConfig | None = None, + model_provider_factory: ModelProviderFactory | None = None, validate: bool = True, ): """ @@ -552,6 +556,7 @@ class ModelLoadBalancingService: :param model: model name :param credentials: credentials :param load_balancing_model_config: load balancing model config + :param model_provider_factory: model provider factory sharing the active runtime :param validate: validate credentials :return: """ @@ -581,7 +586,8 @@ class ModelLoadBalancingService: credentials[key] = encrypter.decrypt_token(tenant_id, original_credentials[key]) if validate: - model_provider_factory = ModelProviderFactory(tenant_id) + if model_provider_factory is None: + model_provider_factory = provider_configuration.get_model_provider_factory() if isinstance(credential_schemas, ModelCredentialSchema): credentials = model_provider_factory.model_credentials_validate( provider=provider_configuration.provider.provider, diff --git a/api/services/model_provider_service.py b/api/services/model_provider_service.py index 0ddd6b9b1a..e634f90603 100644 --- a/api/services/model_provider_service.py +++ b/api/services/model_provider_service.py @@ -1,9 +1,9 @@ import logging from core.entities.model_entities import ModelWithProviderEntity, ProviderModelWithStatusEntity +from core.plugin.impl.model_runtime_factory import create_plugin_model_provider_factory, create_plugin_provider_manager from core.provider_manager import ProviderManager -from dify_graph.model_runtime.entities.model_entities import ModelType, ParameterRule -from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory +from graphon.model_runtime.entities.model_entities import ModelType, ParameterRule from models.provider import ProviderType from services.entities.model_provider_entities import ( CustomConfigurationResponse, @@ -25,8 +25,9 @@ class ModelProviderService: Model Provider Service """ - def __init__(self): - self.provider_manager = ProviderManager() + @staticmethod + def _get_provider_manager(tenant_id: str) -> ProviderManager: + return create_plugin_provider_manager(tenant_id=tenant_id) def _get_provider_configuration(self, tenant_id: str, provider: str): """ @@ -43,7 +44,7 @@ class ModelProviderService: ProviderNotFoundError: If provider doesn't exist """ # Get all provider configurations of the current workspace - provider_configurations = self.provider_manager.get_configurations(tenant_id) + provider_configurations = self._get_provider_manager(tenant_id).get_configurations(tenant_id) provider_configuration = provider_configurations.get(provider) if not provider_configuration: @@ -60,7 +61,7 @@ class ModelProviderService: :return: """ # Get all provider configurations of the current workspace - provider_configurations = self.provider_manager.get_configurations(tenant_id) + provider_configurations = self._get_provider_manager(tenant_id).get_configurations(tenant_id) provider_responses = [] for provider_configuration in provider_configurations.values(): @@ -138,7 +139,7 @@ class ModelProviderService: :return: """ # Get all provider configurations of the current workspace - provider_configurations = self.provider_manager.get_configurations(tenant_id) + provider_configurations = self._get_provider_manager(tenant_id).get_configurations(tenant_id) # Get provider available models return [ @@ -146,6 +147,26 @@ class ModelProviderService: for model in provider_configurations.get_models(provider=provider) ] + def get_provider_available_credentials(self, tenant_id: str, provider: str): + return self._get_provider_manager(tenant_id).get_provider_available_credentials( + tenant_id=tenant_id, + provider_name=provider, + ) + + def get_provider_model_available_credentials( + self, + tenant_id: str, + provider: str, + model_type: str, + model: str, + ): + return self._get_provider_manager(tenant_id).get_provider_model_available_credentials( + tenant_id=tenant_id, + provider_name=provider, + model_type=model_type, + model_name=model, + ) + def get_provider_credential(self, tenant_id: str, provider: str, credential_id: str | None = None) -> dict | None: """ get provider credentials. @@ -391,7 +412,7 @@ class ModelProviderService: :return: """ # Get all provider configurations of the current workspace - provider_configurations = self.provider_manager.get_configurations(tenant_id) + provider_configurations = self._get_provider_manager(tenant_id).get_configurations(tenant_id) # Get provider available models models = provider_configurations.get_models(model_type=ModelType.value_of(model_type), only_active=True) @@ -476,7 +497,9 @@ class ModelProviderService: model_type_enum = ModelType.value_of(model_type) try: - result = self.provider_manager.get_default_model(tenant_id=tenant_id, model_type=model_type_enum) + result = self._get_provider_manager(tenant_id).get_default_model( + tenant_id=tenant_id, model_type=model_type_enum + ) return ( DefaultModelResponse( model=result.model, @@ -507,7 +530,7 @@ class ModelProviderService: :return: """ model_type_enum = ModelType.value_of(model_type) - self.provider_manager.update_default_model_record( + self._get_provider_manager(tenant_id).update_default_model_record( tenant_id=tenant_id, model_type=model_type_enum, provider=provider, model=model ) @@ -523,7 +546,7 @@ class ModelProviderService: :param lang: language (zh_Hans or en_US) :return: """ - model_provider_factory = ModelProviderFactory(tenant_id) + model_provider_factory = create_plugin_model_provider_factory(tenant_id=tenant_id) byte_data, mime_type = model_provider_factory.get_provider_icon(provider, icon_type, lang) return byte_data, mime_type diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py index 296b9f0890..8a28537528 100644 --- a/api/services/rag_pipeline/rag_pipeline.py +++ b/api/services/rag_pipeline/rag_pipeline.py @@ -34,26 +34,32 @@ from core.rag.entities.event import ( DatasourceErrorEvent, DatasourceProcessingEvent, ) -from core.repositories.factory import DifyCoreRepositoryFactory +from core.repositories.factory import DifyCoreRepositoryFactory, OrderConfig from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository from core.workflow.node_factory import LATEST_VERSION, get_node_type_classes_mapping +from core.workflow.system_variables import ( + SystemVariableKey, + build_bootstrap_variables, + build_system_variables, + default_system_variables, + get_system_segment, +) +from core.workflow.variable_pool_initializer import add_variables_to_pool from core.workflow.workflow_entry import WorkflowEntry -from dify_graph.entities.workflow_node_execution import ( +from extensions.ext_database import db +from graphon.entities.workflow_node_execution import ( WorkflowNodeExecution, WorkflowNodeExecutionStatus, ) -from dify_graph.enums import BuiltinNodeTypes, ErrorStrategy, NodeType, SystemVariableKey -from dify_graph.errors import WorkflowNodeRunFailedError -from dify_graph.graph_events import NodeRunFailedEvent, NodeRunSucceededEvent -from dify_graph.graph_events.base import GraphNodeEventBase -from dify_graph.node_events.base import NodeRunResult -from dify_graph.nodes.base.node import Node -from dify_graph.nodes.http_request import HTTP_REQUEST_CONFIG_FILTER_KEY, build_http_request_config -from dify_graph.repositories.workflow_node_execution_repository import OrderConfig -from dify_graph.runtime import VariablePool -from dify_graph.system_variable import SystemVariable -from dify_graph.variables.variables import VariableBase -from extensions.ext_database import db +from graphon.enums import BuiltinNodeTypes, ErrorStrategy, NodeType +from graphon.errors import WorkflowNodeRunFailedError +from graphon.graph_events import NodeRunFailedEvent, NodeRunSucceededEvent +from graphon.graph_events.base import GraphNodeEventBase +from graphon.node_events.base import NodeRunResult +from graphon.nodes.base.node import Node +from graphon.nodes.http_request import HTTP_REQUEST_CONFIG_FILTER_KEY, build_http_request_config +from graphon.runtime import VariablePool +from graphon.variables.variables import Variable, VariableBase from libs.infinite_scroll_pagination import InfiniteScrollPagination from models import Account from models.dataset import ( # type: ignore @@ -88,6 +94,12 @@ from services.workflow_restore import apply_published_workflow_snapshot_to_draft logger = logging.getLogger(__name__) +def _build_seeded_variable_pool(variables: Sequence[Variable]) -> VariablePool: + variable_pool = VariablePool() + add_variables_to_pool(variable_pool, variables) + return variable_pool + + class RagPipelineService: def __init__(self, session_maker: sessionmaker | None = None): """Initialize RagPipelineService with repository dependencies.""" @@ -521,13 +533,7 @@ class RagPipelineService: node_id=node_id, user_inputs=user_inputs, user_id=account.id, - variable_pool=VariablePool( - system_variables=SystemVariable.default(), - user_inputs=user_inputs, - environment_variables=[], - conversation_variables=[], - rag_pipeline_variables=[], - ), + variable_pool=_build_seeded_variable_pool(default_system_variables()), variable_loader=DraftVarLoader( engine=db.engine, app_id=pipeline.id, @@ -959,10 +965,10 @@ class RagPipelineService: workflow_node_execution.error = error # update document status variable_pool = node_instance.graph_runtime_state.variable_pool - invoke_from = variable_pool.get(["sys", SystemVariableKey.INVOKE_FROM]) + invoke_from = get_system_segment(variable_pool, SystemVariableKey.INVOKE_FROM) if invoke_from: if invoke_from.value == InvokeFrom.PUBLISHED_PIPELINE: - document_id = variable_pool.get(["sys", SystemVariableKey.DOCUMENT_ID]) + document_id = get_system_segment(variable_pool, SystemVariableKey.DOCUMENT_ID) if document_id: document = db.session.query(Document).where(Document.id == document_id.value).first() if document: @@ -1276,7 +1282,7 @@ class RagPipelineService: else: enclosing_node_id = None - system_inputs = SystemVariable( + system_inputs = build_system_variables( datasource_type=args.get("datasource_type", "online_document"), datasource_info=args.get("datasource_info", {}), ) @@ -1287,12 +1293,11 @@ class RagPipelineService: node_id=node_id, user_inputs={}, user_id=current_user.id, - variable_pool=VariablePool( - system_variables=system_inputs, - user_inputs={}, - environment_variables=[], - conversation_variables=[], - rag_pipeline_variables=[], + variable_pool=_build_seeded_variable_pool( + build_bootstrap_variables( + system_variables=system_inputs, + rag_pipeline_variables=(), + ) ), variable_loader=DraftVarLoader( engine=db.engine, diff --git a/api/services/rag_pipeline/rag_pipeline_dsl_service.py b/api/services/rag_pipeline/rag_pipeline_dsl_service.py index deb59da8d3..1b8207cc31 100644 --- a/api/services/rag_pipeline/rag_pipeline_dsl_service.py +++ b/api/services/rag_pipeline/rag_pipeline_dsl_service.py @@ -22,17 +22,18 @@ from sqlalchemy.orm import Session from core.helper import ssrf_proxy from core.helper.name_generator import generate_incremental_name from core.plugin.entities.plugin import PluginDependency +from core.rag.index_processor.constant.index_type import IndexTechniqueType from core.workflow.nodes.datasource.entities import DatasourceNodeData from core.workflow.nodes.knowledge_index import KNOWLEDGE_INDEX_NODE_TYPE from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData -from dify_graph.enums import BuiltinNodeTypes -from dify_graph.model_runtime.utils.encoders import jsonable_encoder -from dify_graph.nodes.llm.entities import LLMNodeData -from dify_graph.nodes.parameter_extractor.entities import ParameterExtractorNodeData -from dify_graph.nodes.question_classifier.entities import QuestionClassifierNodeData -from dify_graph.nodes.tool.entities import ToolNodeData from extensions.ext_redis import redis_client from factories import variable_factory +from graphon.enums import BuiltinNodeTypes +from graphon.model_runtime.utils.encoders import jsonable_encoder +from graphon.nodes.llm.entities import LLMNodeData +from graphon.nodes.parameter_extractor.entities import ParameterExtractorNodeData +from graphon.nodes.question_classifier.entities import QuestionClassifierNodeData +from graphon.nodes.tool.entities import ToolNodeData from models import Account from models.dataset import Dataset, DatasetCollectionBinding, Pipeline from models.enums import CollectionBindingType, DatasetRuntimeMode @@ -311,13 +312,13 @@ class RagPipelineDslService: "icon_background": icon_background, "icon_url": icon_url, }, - indexing_technique=knowledge_configuration.indexing_technique, + indexing_technique=IndexTechniqueType(knowledge_configuration.indexing_technique), created_by=account.id, retrieval_model=knowledge_configuration.retrieval_model.model_dump(), runtime_mode=DatasetRuntimeMode.RAG_PIPELINE, chunk_structure=knowledge_configuration.chunk_structure, ) - if knowledge_configuration.indexing_technique == "high_quality": + if knowledge_configuration.indexing_technique == IndexTechniqueType.HIGH_QUALITY: dataset_collection_binding = ( self._session.query(DatasetCollectionBinding) .where( @@ -343,7 +344,7 @@ class RagPipelineDslService: dataset.collection_binding_id = dataset_collection_binding_id dataset.embedding_model = knowledge_configuration.embedding_model dataset.embedding_model_provider = knowledge_configuration.embedding_model_provider - elif knowledge_configuration.indexing_technique == "economy": + elif knowledge_configuration.indexing_technique == IndexTechniqueType.ECONOMY: dataset.keyword_number = knowledge_configuration.keyword_number # Update summary_index_setting if provided if knowledge_configuration.summary_index_setting is not None: @@ -443,18 +444,18 @@ class RagPipelineDslService: "icon_background": icon_background, "icon_url": icon_url, }, - indexing_technique=knowledge_configuration.indexing_technique, + indexing_technique=IndexTechniqueType(knowledge_configuration.indexing_technique), created_by=account.id, retrieval_model=knowledge_configuration.retrieval_model.model_dump(), runtime_mode=DatasetRuntimeMode.RAG_PIPELINE, chunk_structure=knowledge_configuration.chunk_structure, ) else: - dataset.indexing_technique = knowledge_configuration.indexing_technique + dataset.indexing_technique = IndexTechniqueType(knowledge_configuration.indexing_technique) dataset.retrieval_model = knowledge_configuration.retrieval_model.model_dump() dataset.runtime_mode = DatasetRuntimeMode.RAG_PIPELINE dataset.chunk_structure = knowledge_configuration.chunk_structure - if knowledge_configuration.indexing_technique == "high_quality": + if knowledge_configuration.indexing_technique == IndexTechniqueType.HIGH_QUALITY: dataset_collection_binding = ( self._session.query(DatasetCollectionBinding) .where( @@ -480,7 +481,7 @@ class RagPipelineDslService: dataset.collection_binding_id = dataset_collection_binding_id dataset.embedding_model = knowledge_configuration.embedding_model dataset.embedding_model_provider = knowledge_configuration.embedding_model_provider - elif knowledge_configuration.indexing_technique == "economy": + elif knowledge_configuration.indexing_technique == IndexTechniqueType.ECONOMY: dataset.keyword_number = knowledge_configuration.keyword_number # Update summary_index_setting if provided if knowledge_configuration.summary_index_setting is not None: @@ -772,7 +773,7 @@ class RagPipelineDslService: ) case _ if typ == KNOWLEDGE_INDEX_NODE_TYPE: knowledge_index_entity = KnowledgeConfiguration.model_validate(node["data"]) - if knowledge_index_entity.indexing_technique == "high_quality": + if knowledge_index_entity.indexing_technique == IndexTechniqueType.HIGH_QUALITY: if knowledge_index_entity.embedding_model_provider: dependencies.append( DependenciesAnalysisService.analyze_model_provider_dependency( diff --git a/api/services/rag_pipeline/rag_pipeline_transform_service.py b/api/services/rag_pipeline/rag_pipeline_transform_service.py index 1d0aafd5fd..215a8c8528 100644 --- a/api/services/rag_pipeline/rag_pipeline_transform_service.py +++ b/api/services/rag_pipeline/rag_pipeline_transform_service.py @@ -9,6 +9,7 @@ from flask_login import current_user from constants import DOCUMENT_EXTENSIONS from core.plugin.impl.plugin import PluginInstaller +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from core.rag.retrieval.retrieval_methods import RetrievalMethod from extensions.ext_database import db from factories import variable_factory @@ -79,9 +80,9 @@ class RagPipelineTransformService: pipeline = self._create_pipeline(pipeline_yaml) # save chunk structure to dataset - if doc_form == "hierarchical_model": + if doc_form == IndexStructureType.PARENT_CHILD_INDEX: dataset.chunk_structure = "hierarchical_model" - elif doc_form == "text_model": + elif doc_form == IndexStructureType.PARAGRAPH_INDEX: dataset.chunk_structure = "text_model" else: raise ValueError("Unsupported doc form") @@ -101,38 +102,38 @@ class RagPipelineTransformService: def _get_transform_yaml(self, doc_form: str, datasource_type: str, indexing_technique: str | None): pipeline_yaml = {} - if doc_form == "text_model": + if doc_form == IndexStructureType.PARAGRAPH_INDEX: match datasource_type: case DataSourceType.UPLOAD_FILE: - if indexing_technique == "high_quality": + if indexing_technique == IndexTechniqueType.HIGH_QUALITY: # get graph from transform.file-general-high-quality.yml with open(f"{Path(__file__).parent}/transform/file-general-high-quality.yml") as f: pipeline_yaml = yaml.safe_load(f) - if indexing_technique == "economy": + if indexing_technique == IndexTechniqueType.ECONOMY: # get graph from transform.file-general-economy.yml with open(f"{Path(__file__).parent}/transform/file-general-economy.yml") as f: pipeline_yaml = yaml.safe_load(f) case DataSourceType.NOTION_IMPORT: - if indexing_technique == "high_quality": + if indexing_technique == IndexTechniqueType.HIGH_QUALITY: # get graph from transform.notion-general-high-quality.yml with open(f"{Path(__file__).parent}/transform/notion-general-high-quality.yml") as f: pipeline_yaml = yaml.safe_load(f) - if indexing_technique == "economy": + if indexing_technique == IndexTechniqueType.ECONOMY: # get graph from transform.notion-general-economy.yml with open(f"{Path(__file__).parent}/transform/notion-general-economy.yml") as f: pipeline_yaml = yaml.safe_load(f) case DataSourceType.WEBSITE_CRAWL: - if indexing_technique == "high_quality": + if indexing_technique == IndexTechniqueType.HIGH_QUALITY: # get graph from transform.website-crawl-general-high-quality.yml with open(f"{Path(__file__).parent}/transform/website-crawl-general-high-quality.yml") as f: pipeline_yaml = yaml.safe_load(f) - if indexing_technique == "economy": + if indexing_technique == IndexTechniqueType.ECONOMY: # get graph from transform.website-crawl-general-economy.yml with open(f"{Path(__file__).parent}/transform/website-crawl-general-economy.yml") as f: pipeline_yaml = yaml.safe_load(f) case _: raise ValueError("Unsupported datasource type") - elif doc_form == "hierarchical_model": + elif doc_form == IndexStructureType.PARENT_CHILD_INDEX: match datasource_type: case DataSourceType.UPLOAD_FILE: # get graph from transform.file-parentchild.yml @@ -169,11 +170,11 @@ class RagPipelineTransformService: ): knowledge_configuration_dict = node.get("data", {}) - if indexing_technique == "high_quality": + if indexing_technique == IndexTechniqueType.HIGH_QUALITY: knowledge_configuration.embedding_model = dataset.embedding_model knowledge_configuration.embedding_model_provider = dataset.embedding_model_provider if retrieval_model: - if indexing_technique == "economy": + if indexing_technique == IndexTechniqueType.ECONOMY: retrieval_model.search_method = RetrievalMethod.KEYWORD_SEARCH knowledge_configuration.retrieval_model = retrieval_model else: diff --git a/api/services/retention/workflow_run/archive_paid_plan_workflow_run.py b/api/services/retention/workflow_run/archive_paid_plan_workflow_run.py index 00a2144800..c91f621ffb 100644 --- a/api/services/retention/workflow_run/archive_paid_plan_workflow_run.py +++ b/api/services/retention/workflow_run/archive_paid_plan_workflow_run.py @@ -31,9 +31,9 @@ from sqlalchemy import inspect from sqlalchemy.orm import Session, sessionmaker from configs import dify_config -from dify_graph.enums import WorkflowType from enums.cloud_plan import CloudPlan from extensions.ext_database import db +from graphon.enums import WorkflowType from libs.archive_storage import ( ArchiveStorage, ArchiveStorageNotConfiguredError, diff --git a/api/services/summary_index_service.py b/api/services/summary_index_service.py index 943dfc972b..4334412c8b 100644 --- a/api/services/summary_index_service.py +++ b/api/services/summary_index_service.py @@ -12,10 +12,11 @@ from core.db.session_factory import session_factory from core.model_manager import ModelManager from core.rag.datasource.vdb.vector_factory import Vector from core.rag.index_processor.constant.doc_type import DocType +from core.rag.index_processor.constant.index_type import IndexTechniqueType from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict from core.rag.models.document import Document -from dify_graph.model_runtime.entities.llm_entities import LLMUsage -from dify_graph.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.entities.llm_entities import LLMUsage +from graphon.model_runtime.entities.model_entities import ModelType from libs import helper from models.dataset import Dataset, DocumentSegment, DocumentSegmentSummary from models.dataset import Document as DatasetDocument @@ -140,7 +141,7 @@ class SummaryIndexService: session: Optional SQLAlchemy session. If provided, uses this session instead of creating a new one. If not provided, creates a new session and commits automatically. """ - if dataset.indexing_technique != "high_quality": + if dataset.indexing_technique != IndexTechniqueType.HIGH_QUALITY: logger.warning( "Summary vectorization skipped for dataset %s: indexing_technique is not high_quality", dataset.id, @@ -191,7 +192,7 @@ class SummaryIndexService: # Calculate embedding tokens for summary (for logging and statistics) embedding_tokens = 0 try: - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=dataset.tenant_id) embedding_model = model_manager.get_model_instance( tenant_id=dataset.tenant_id, provider=dataset.embedding_model_provider, @@ -200,7 +201,8 @@ class SummaryIndexService: ) if embedding_model: tokens_list = embedding_model.get_text_embedding_num_tokens([summary_content]) - embedding_tokens = tokens_list[0] if tokens_list else 0 + raw_embedding_tokens = tokens_list[0] if tokens_list else 0 + embedding_tokens = raw_embedding_tokens if isinstance(raw_embedding_tokens, int) else 0 except Exception as e: logger.warning("Failed to calculate embedding tokens for summary: %s", str(e)) @@ -724,7 +726,7 @@ class SummaryIndexService: List of created DocumentSegmentSummary instances """ # Only generate summary index for high_quality indexing technique - if dataset.indexing_technique != "high_quality": + if dataset.indexing_technique != IndexTechniqueType.HIGH_QUALITY: logger.info( "Skipping summary generation for dataset %s: indexing_technique is %s, not 'high_quality'", dataset.id, @@ -851,7 +853,7 @@ class SummaryIndexService: ) # Remove from vector database (but keep records) - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: summary_node_ids = [s.summary_index_node_id for s in summaries if s.summary_index_node_id] if summary_node_ids: try: @@ -889,7 +891,7 @@ class SummaryIndexService: segment_ids: List of segment IDs to enable summaries for. If None, enable all. """ # Only enable summary index for high_quality indexing technique - if dataset.indexing_technique != "high_quality": + if dataset.indexing_technique != IndexTechniqueType.HIGH_QUALITY: return with session_factory.create_session() as session: @@ -981,7 +983,7 @@ class SummaryIndexService: return # Delete from vector database - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: summary_node_ids = [s.summary_index_node_id for s in summaries if s.summary_index_node_id] if summary_node_ids: vector = Vector(dataset) @@ -1012,7 +1014,7 @@ class SummaryIndexService: Updated DocumentSegmentSummary instance, or None if indexing technique is not high_quality """ # Only update summary index for high_quality indexing technique - if dataset.indexing_technique != "high_quality": + if dataset.indexing_technique != IndexTechniqueType.HIGH_QUALITY: return None # When user manually provides summary, allow saving even if summary_index_setting doesn't exist diff --git a/api/services/tools/api_tools_manage_service.py b/api/services/tools/api_tools_manage_service.py index 408b1c22d1..9190a67249 100644 --- a/api/services/tools/api_tools_manage_service.py +++ b/api/services/tools/api_tools_manage_service.py @@ -20,8 +20,8 @@ from core.tools.tool_label_manager import ToolLabelManager from core.tools.tool_manager import ToolManager from core.tools.utils.encryption import create_tool_provider_encrypter from core.tools.utils.parser import ApiBasedToolSchemaParser -from dify_graph.model_runtime.utils.encoders import jsonable_encoder from extensions.ext_database import db +from graphon.model_runtime.utils.encoders import jsonable_encoder from models.tools import ApiToolProvider from services.tools.tools_transform_service import ToolTransformService diff --git a/api/services/tools/workflow_tools_manage_service.py b/api/services/tools/workflow_tools_manage_service.py index 101b2fe5a2..931ca5021a 100644 --- a/api/services/tools/workflow_tools_manage_service.py +++ b/api/services/tools/workflow_tools_manage_service.py @@ -12,8 +12,8 @@ from core.tools.tool_label_manager import ToolLabelManager from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurationUtils from core.tools.workflow_as_tool.provider import WorkflowToolProviderController from core.tools.workflow_as_tool.tool import WorkflowTool -from dify_graph.model_runtime.utils.encoders import jsonable_encoder from extensions.ext_database import db +from graphon.model_runtime.utils.encoders import jsonable_encoder from models.model import App from models.tools import WorkflowToolProvider from models.workflow import Workflow diff --git a/api/services/trigger/schedule_service.py b/api/services/trigger/schedule_service.py index 7e9d010d2f..a827222c1d 100644 --- a/api/services/trigger/schedule_service.py +++ b/api/services/trigger/schedule_service.py @@ -13,7 +13,7 @@ from core.workflow.nodes.trigger_schedule.entities import ( VisualConfig, ) from core.workflow.nodes.trigger_schedule.exc import ScheduleConfigError, ScheduleNotFoundError -from dify_graph.entities.graph_config import NodeConfigDict +from graphon.entities.graph_config import NodeConfigDict from libs.schedule_utils import calculate_next_run_at, convert_12h_to_24h from models.account import Account, TenantAccountJoin from models.trigger import WorkflowSchedulePlan diff --git a/api/services/trigger/trigger_service.py b/api/services/trigger/trigger_service.py index 24bbeda329..dca00a466b 100644 --- a/api/services/trigger/trigger_service.py +++ b/api/services/trigger/trigger_service.py @@ -18,9 +18,9 @@ from core.trigger.provider import PluginTriggerProviderController from core.trigger.trigger_manager import TriggerManager from core.trigger.utils.encryption import create_trigger_provider_encrypter_for_subscription from core.workflow.nodes.trigger_plugin.entities import TriggerEventNodeData -from dify_graph.entities.graph_config import NodeConfigDict from extensions.ext_database import db from extensions.ext_redis import redis_client +from graphon.entities.graph_config import NodeConfigDict from models.model import App from models.provider_ids import TriggerProviderID from models.trigger import TriggerSubscription, WorkflowPluginTrigger diff --git a/api/services/trigger/webhook_service.py b/api/services/trigger/webhook_service.py index 3c1a4cc747..5d9be84c06 100644 --- a/api/services/trigger/webhook_service.py +++ b/api/services/trigger/webhook_service.py @@ -15,6 +15,7 @@ from werkzeug.exceptions import RequestEntityTooLarge from configs import dify_config from core.app.entities.app_invoke_entities import InvokeFrom +from core.app.file_access import DatabaseFileAccessController from core.tools.tool_file_manager import ToolFileManager from core.trigger.constants import TRIGGER_WEBHOOK_NODE_TYPE from core.workflow.nodes.trigger_webhook.entities import ( @@ -23,13 +24,13 @@ from core.workflow.nodes.trigger_webhook.entities import ( WebhookData, WebhookParameter, ) -from dify_graph.entities.graph_config import NodeConfigDict -from dify_graph.file.models import FileTransferMethod -from dify_graph.variables.types import ArrayValidation, SegmentType from enums.quota_type import QuotaType from extensions.ext_database import db from extensions.ext_redis import redis_client from factories import file_factory +from graphon.entities.graph_config import NodeConfigDict +from graphon.file.models import FileTransferMethod +from graphon.variables.types import ArrayValidation, SegmentType from models.enums import AppTriggerStatus, AppTriggerType from models.model import App from models.trigger import AppTrigger, WorkflowWebhookTrigger @@ -46,6 +47,7 @@ except ImportError: magic = None # type: ignore[assignment] logger = logging.getLogger(__name__) +_file_access_controller = DatabaseFileAccessController() class WebhookService: @@ -422,6 +424,7 @@ class WebhookService: return file_factory.build_from_mapping( mapping=mapping, tenant_id=webhook_trigger.tenant_id, + access_controller=_file_access_controller, ) @classmethod diff --git a/api/services/variable_truncator.py b/api/services/variable_truncator.py index 60dc1dedb8..d0a4317065 100644 --- a/api/services/variable_truncator.py +++ b/api/services/variable_truncator.py @@ -6,9 +6,9 @@ from collections.abc import Mapping from typing import Any, Generic, TypeAlias, TypeVar, overload from configs import dify_config -from dify_graph.file.models import File -from dify_graph.nodes.variable_assigner.common.helpers import UpdatedVariable -from dify_graph.variables.segments import ( +from graphon.file.models import File +from graphon.nodes.variable_assigner.common.helpers import UpdatedVariable +from graphon.variables.segments import ( ArrayFileSegment, ArraySegment, BooleanSegment, @@ -20,7 +20,7 @@ from dify_graph.variables.segments import ( Segment, StringSegment, ) -from dify_graph.variables.utils import dumps_with_segments +from graphon.variables.utils import dumps_with_segments _MAX_DEPTH = 100 diff --git a/api/services/vector_service.py b/api/services/vector_service.py index b66fdd7a20..5fd310b689 100644 --- a/api/services/vector_service.py +++ b/api/services/vector_service.py @@ -4,12 +4,12 @@ from core.model_manager import ModelInstance, ModelManager from core.rag.datasource.keyword.keyword_factory import Keyword from core.rag.datasource.vdb.vector_factory import Vector from core.rag.index_processor.constant.doc_type import DocType -from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from core.rag.index_processor.index_processor_base import BaseIndexProcessor from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from core.rag.models.document import AttachmentDocument, Document -from dify_graph.model_runtime.entities.model_entities import ModelType from extensions.ext_database import db +from graphon.model_runtime.entities.model_entities import ModelType from models import UploadFile from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegment, SegmentAttachmentBinding from models.dataset import Document as DatasetDocument @@ -45,9 +45,9 @@ class VectorService: if not processing_rule: raise ValueError("No processing rule found.") # get embedding model instance - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: # check embedding model setting - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=dataset.tenant_id) if dataset.embedding_model_provider: embedding_model_instance = model_manager.get_model_instance( @@ -112,7 +112,7 @@ class VectorService: "dataset_id": segment.dataset_id, }, ) - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: # update vector index vector = Vector(dataset=dataset) vector.delete_by_ids([segment.index_node_id]) @@ -197,7 +197,7 @@ class VectorService: "dataset_id": child_segment.dataset_id, }, ) - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: # save vector index vector = Vector(dataset=dataset) vector.add_texts([child_document], duplicate_check=True) @@ -237,7 +237,7 @@ class VectorService: delete_node_ids.append(update_child_chunk.index_node_id) for delete_child_chunk in delete_child_chunks: delete_node_ids.append(delete_child_chunk.index_node_id) - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: # update vector index vector = Vector(dataset=dataset) if delete_node_ids: @@ -252,7 +252,7 @@ class VectorService: @classmethod def update_multimodel_vector(cls, segment: DocumentSegment, attachment_ids: list[str], dataset: Dataset): - if dataset.indexing_technique != "high_quality": + if dataset.indexing_technique != IndexTechniqueType.HIGH_QUALITY: return attachments = segment.attachments diff --git a/api/services/workflow/workflow_converter.py b/api/services/workflow/workflow_converter.py index f0596e44c8..1f3993505c 100644 --- a/api/services/workflow/workflow_converter.py +++ b/api/services/workflow/workflow_converter.py @@ -17,13 +17,13 @@ from core.app.apps.completion.app_config_manager import CompletionAppConfigManag from core.helper import encrypter from core.prompt.simple_prompt_transform import SimplePromptTransform from core.prompt.utils.prompt_template_parser import PromptTemplateParser -from dify_graph.file.models import FileUploadConfig -from dify_graph.model_runtime.entities.llm_entities import LLMMode -from dify_graph.model_runtime.utils.encoders import jsonable_encoder -from dify_graph.nodes import BuiltinNodeTypes -from dify_graph.variables.input_entities import VariableEntity from events.app_event import app_was_created from extensions.ext_database import db +from graphon.file.models import FileUploadConfig +from graphon.model_runtime.entities.llm_entities import LLMMode +from graphon.model_runtime.utils.encoders import jsonable_encoder +from graphon.nodes import BuiltinNodeTypes +from graphon.variables.input_entities import VariableEntity from models import Account from models.api_based_extension import APIBasedExtension, APIBasedExtensionPoint from models.model import App, AppMode, AppModelConfig, IconType diff --git a/api/services/workflow_app_service.py b/api/services/workflow_app_service.py index f1601cd6be..ea42ed0a8b 100644 --- a/api/services/workflow_app_service.py +++ b/api/services/workflow_app_service.py @@ -7,7 +7,7 @@ from sqlalchemy import and_, func, or_, select from sqlalchemy.orm import Session from typing_extensions import TypedDict -from dify_graph.enums import WorkflowExecutionStatus +from graphon.enums import WorkflowExecutionStatus from models import Account, App, EndUser, TenantAccountJoin, WorkflowAppLog, WorkflowArchiveLog, WorkflowRun from models.enums import AppTriggerType, CreatorUserRole from models.trigger import WorkflowTriggerLog diff --git a/api/services/workflow_draft_variable_service.py b/api/services/workflow_draft_variable_service.py index f124e137c3..0b5c89e574 100644 --- a/api/services/workflow_draft_variable_service.py +++ b/api/services/workflow_draft_variable_service.py @@ -14,28 +14,36 @@ from sqlalchemy.sql.expression import and_, or_ from configs import dify_config from core.app.entities.app_invoke_entities import InvokeFrom +from core.app.file_access import DatabaseFileAccessController from core.trigger.constants import is_trigger_node_type -from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID -from dify_graph.enums import NodeType, SystemVariableKey -from dify_graph.file.models import File -from dify_graph.nodes import BuiltinNodeTypes -from dify_graph.nodes.variable_assigner.common.helpers import get_updated_variables -from dify_graph.variable_loader import VariableLoader -from dify_graph.variables import Segment, StringSegment, VariableBase -from dify_graph.variables.consts import SELECTORS_LENGTH -from dify_graph.variables.segments import ( - ArrayFileSegment, - FileSegment, +from core.workflow.system_variables import SystemVariableKey +from core.workflow.variable_prefixes import ( + CONVERSATION_VARIABLE_NODE_ID, + ENVIRONMENT_VARIABLE_NODE_ID, + RAG_PIPELINE_VARIABLE_NODE_ID, + SYSTEM_VARIABLE_NODE_ID, ) -from dify_graph.variables.types import SegmentType -from dify_graph.variables.utils import dumps_with_segments from extensions.ext_storage import storage from factories.file_factory import StorageKeyLoader from factories.variable_factory import build_segment, segment_to_variable +from graphon.enums import NodeType +from graphon.file.models import File +from graphon.nodes import BuiltinNodeTypes +from graphon.nodes.variable_assigner.common.helpers import get_updated_variables +from graphon.variable_loader import VariableLoader +from graphon.variables import Segment, StringSegment, VariableBase +from graphon.variables.consts import SELECTORS_LENGTH +from graphon.variables.segments import ( + ArrayFileSegment, + FileSegment, +) +from graphon.variables.types import SegmentType +from graphon.variables.utils import dumps_with_segments from libs.datetime_utils import naive_utc_now from libs.uuid_utils import uuidv7 from models import Account, App, Conversation from models.enums import ConversationFromSource, DraftVariableType +from models.utils.file_input_compat import build_file_from_stored_mapping from models.workflow import Workflow, WorkflowDraftVariable, WorkflowDraftVariableFile, is_system_variable_editable from repositories.factory import DifyAPIRepositoryFactory from services.file_service import FileService @@ -71,7 +79,7 @@ class UpdateNotSupportedError(WorkflowDraftVariableError): class DraftVarLoader(VariableLoader): # This implements the VariableLoader interface for loading draft variables. # - # ref: dify_graph.variable_loader.VariableLoader + # ref: graphon.variable_loader.VariableLoader # Database engine used for loading variables. _engine: Engine @@ -120,7 +128,11 @@ class DraftVarLoader(VariableLoader): elif isinstance(value, ArrayFileSegment): files.extend(value.value) with Session(bind=self._engine) as session: - storage_key_loader = StorageKeyLoader(session, tenant_id=self._tenant_id) + storage_key_loader = StorageKeyLoader( + session, + tenant_id=self._tenant_id, + access_controller=DatabaseFileAccessController(), + ) storage_key_loader.load_storage_keys(files) offloaded_draft_vars = [] @@ -174,7 +186,7 @@ class DraftVarLoader(VariableLoader): return (draft_var.node_id, draft_var.name), variable deserialized = json.loads(content) - segment = WorkflowDraftVariable.build_segment_with_type(variable_file.value_type, deserialized) + segment = draft_var.build_segment_from_serialized_value(variable_file.value_type, deserialized) variable = segment_to_variable( segment=segment, selector=draft_var.get_selector(), @@ -838,6 +850,12 @@ class DraftVariableSaver: self._user = user self._enclosing_node_id = enclosing_node_id + def _resolve_app_tenant_id(self) -> str: + tenant_id = self._session.scalar(select(App.tenant_id).where(App.id == self._app_id)) + if not tenant_id: + raise ValueError(f"Unable to resolve tenant_id for app {self._app_id}") + return tenant_id + def _create_dummy_output_variable(self): return WorkflowDraftVariable.new_node_variable( app_id=self._app_id, @@ -892,27 +910,18 @@ class DraftVariableSaver: for name, value in output.items(): value_seg = _build_segment_for_serialized_values(value) node_id, name = self._normalize_variable_for_start_node(name) - # If node_id is not `sys`, it means that the variable is a user-defined input field - # in `Start` node. - if node_id != SYSTEM_VARIABLE_NODE_ID: - draft_vars.append( - WorkflowDraftVariable.new_node_variable( - app_id=self._app_id, - user_id=self._user.id, - node_id=self._node_id, - name=name, - node_execution_id=self._node_execution_id, - value=value_seg, - visible=True, - editable=True, - ) - ) - has_non_sys_variables = True - else: + if node_id == SYSTEM_VARIABLE_NODE_ID: if name == SystemVariableKey.FILES: # Here we know the type of variable must be `array[file]`, we - # just build files from the value. - files = [File.model_validate(v) for v in value] + # just rebuild files from the serialized payload. + tenant_id = self._resolve_app_tenant_id() + files = [ + build_file_from_stored_mapping( + file_mapping=v, + tenant_id=tenant_id, + ) + for v in value + ] if files: value_seg = WorkflowDraftVariable.build_segment_with_type(SegmentType.ARRAY_FILE, files) else: @@ -928,15 +937,47 @@ class DraftVariableSaver: editable=self._should_variable_be_editable(node_id, name), ) ) + elif node_id == CONVERSATION_VARIABLE_NODE_ID: + draft_vars.append( + WorkflowDraftVariable.new_conversation_variable( + app_id=self._app_id, + user_id=self._user.id, + name=name, + value=value_seg, + ) + ) + has_non_sys_variables = True + else: + draft_vars.append( + WorkflowDraftVariable.new_node_variable( + app_id=self._app_id, + user_id=self._user.id, + node_id=node_id, + name=name, + node_execution_id=self._node_execution_id, + value=value_seg, + visible=self._should_variable_be_visible(node_id, self._node_type, name), + editable=self._should_variable_be_editable(node_id, name), + ) + ) + has_non_sys_variables = True if not has_non_sys_variables: draft_vars.append(self._create_dummy_output_variable()) return draft_vars def _normalize_variable_for_start_node(self, name: str) -> tuple[str, str]: - if not name.startswith(f"{SYSTEM_VARIABLE_NODE_ID}."): - return self._node_id, name - _, name_ = name.split(".", maxsplit=1) - return SYSTEM_VARIABLE_NODE_ID, name_ + for reserved_node_id in ( + SYSTEM_VARIABLE_NODE_ID, + ENVIRONMENT_VARIABLE_NODE_ID, + CONVERSATION_VARIABLE_NODE_ID, + RAG_PIPELINE_VARIABLE_NODE_ID, + ): + prefix = f"{reserved_node_id}." + if name.startswith(prefix): + _, name_ = name.split(".", maxsplit=1) + return reserved_node_id, name_ + + return self._node_id, name def _build_variables_from_mapping(self, output: Mapping[str, Any]) -> list[WorkflowDraftVariable]: draft_vars = [] diff --git a/api/services/workflow_event_snapshot_service.py b/api/services/workflow_event_snapshot_service.py index 8f323ebb8b..5fca444723 100644 --- a/api/services/workflow_event_snapshot_service.py +++ b/api/services/workflow_event_snapshot_service.py @@ -22,10 +22,10 @@ from core.app.entities.task_entities import ( WorkflowStartStreamResponse, ) from core.app.layers.pause_state_persist_layer import WorkflowResumptionContext -from dify_graph.entities import WorkflowStartReason -from dify_graph.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus -from dify_graph.runtime import GraphRuntimeState -from dify_graph.workflow_type_encoder import WorkflowRuntimeTypeConverter +from graphon.entities import WorkflowStartReason +from graphon.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus +from graphon.runtime import GraphRuntimeState +from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter from models.model import AppMode, Message from models.workflow import WorkflowNodeExecutionTriggeredFrom, WorkflowRun from repositories.api_workflow_node_execution_repository import WorkflowNodeExecutionSnapshot diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 66976058c0..785f6f108c 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -12,48 +12,51 @@ from configs import dify_config from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom, build_dify_run_context +from core.app.file_access import DatabaseFileAccessController +from core.plugin.impl.model_runtime_factory import create_plugin_model_assembly, create_plugin_provider_manager from core.repositories import DifyCoreRepositoryFactory -from core.repositories.human_input_repository import HumanInputFormRepositoryImpl +from core.repositories.human_input_repository import FormCreateParams, HumanInputFormRepositoryImpl from core.trigger.constants import is_trigger_node_type -from core.workflow.node_factory import LATEST_VERSION, get_node_type_classes_mapping, is_start_node_type -from core.workflow.workflow_entry import WorkflowEntry -from dify_graph.entities import GraphInitParams, WorkflowNodeExecution -from dify_graph.entities.graph_config import NodeConfigDict -from dify_graph.entities.pause_reason import HumanInputRequired -from dify_graph.enums import ( - ErrorStrategy, - NodeType, - WorkflowNodeExecutionMetadataKey, - WorkflowNodeExecutionStatus, -) -from dify_graph.errors import WorkflowNodeRunFailedError -from dify_graph.file import File -from dify_graph.graph_events import GraphNodeEventBase, NodeRunFailedEvent, NodeRunSucceededEvent -from dify_graph.node_events import NodeRunResult -from dify_graph.nodes import BuiltinNodeTypes -from dify_graph.nodes.base.node import Node -from dify_graph.nodes.http_request import HTTP_REQUEST_CONFIG_FILTER_KEY, build_http_request_config -from dify_graph.nodes.human_input.entities import ( +from core.workflow.human_input_compat import ( DeliveryChannelConfig, - HumanInputNodeData, - apply_debug_email_recipient, - validate_human_input_submission, + normalize_human_input_node_data_for_graph, + parse_human_input_delivery_methods, ) -from dify_graph.nodes.human_input.enums import HumanInputFormKind -from dify_graph.nodes.human_input.human_input_node import HumanInputNode -from dify_graph.nodes.start.entities import StartNodeData -from dify_graph.repositories.human_input_form_repository import FormCreateParams -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable -from dify_graph.variable_loader import load_into_variable_pool -from dify_graph.variables import VariableBase -from dify_graph.variables.input_entities import VariableEntityType -from dify_graph.variables.variables import Variable +from core.workflow.node_factory import LATEST_VERSION, get_node_type_classes_mapping, is_start_node_type +from core.workflow.node_runtime import DifyHumanInputNodeRuntime, apply_dify_debug_email_recipient +from core.workflow.system_variables import build_bootstrap_variables, build_system_variables, default_system_variables +from core.workflow.variable_pool_initializer import add_node_inputs_to_pool, add_variables_to_pool +from core.workflow.workflow_entry import WorkflowEntry from enums.cloud_plan import CloudPlan from events.app_event import app_draft_workflow_was_synced, app_published_workflow_was_updated from extensions.ext_database import db from extensions.ext_storage import storage from factories.file_factory import build_from_mapping, build_from_mappings +from graphon.entities import GraphInitParams, WorkflowNodeExecution +from graphon.entities.graph_config import NodeConfigDict +from graphon.entities.pause_reason import HumanInputRequired +from graphon.enums import ( + ErrorStrategy, + NodeType, + WorkflowNodeExecutionMetadataKey, + WorkflowNodeExecutionStatus, +) +from graphon.errors import WorkflowNodeRunFailedError +from graphon.file import File +from graphon.graph_events import GraphNodeEventBase, NodeRunFailedEvent, NodeRunSucceededEvent +from graphon.node_events import NodeRunResult +from graphon.nodes import BuiltinNodeTypes +from graphon.nodes.base.node import Node +from graphon.nodes.http_request import HTTP_REQUEST_CONFIG_FILTER_KEY, build_http_request_config +from graphon.nodes.human_input.entities import HumanInputNodeData, validate_human_input_submission +from graphon.nodes.human_input.enums import HumanInputFormKind +from graphon.nodes.human_input.human_input_node import HumanInputNode +from graphon.nodes.start.entities import StartNodeData +from graphon.runtime import GraphRuntimeState, VariablePool +from graphon.variable_loader import load_into_variable_pool +from graphon.variables import VariableBase +from graphon.variables.input_entities import VariableEntityType +from graphon.variables.variables import Variable from libs.datetime_utils import naive_utc_now from models import Account from models.human_input import HumanInputFormRecipient, RecipientType @@ -82,6 +85,8 @@ from .human_input_delivery_test_service import ( from .workflow_draft_variable_service import DraftVariableSaver, DraftVarLoader, WorkflowDraftVariableService from .workflow_restore import apply_published_workflow_snapshot_to_draft +_file_access_controller = DatabaseFileAccessController() + class WorkflowService: """ @@ -486,13 +491,15 @@ class WorkflowService: :raises ValueError: If the model configuration is invalid or credentials fail policy checks """ try: - from core.model_manager import ModelManager - from core.provider_manager import ProviderManager - from dify_graph.model_runtime.entities.model_entities import ModelType + from graphon.model_runtime.entities.model_entities import ModelType + + # Model instance resolution and provider status lookup must reuse the + # same request-scoped runtime so validation does not silently split + # provider discovery and credential reads across different caches. + assembly = create_plugin_model_assembly(tenant_id=tenant_id) # Get model instance to validate provider+model combination - model_manager = ModelManager() - model_manager.get_model_instance( + assembly.model_manager.get_model_instance( tenant_id=tenant_id, provider=provider, model_type=ModelType.LLM, model=model_name ) @@ -501,8 +508,7 @@ class WorkflowService: # If it fails, an exception will be raised # Additionally, check the model status to ensure it's ACTIVE - provider_manager = ProviderManager() - provider_configurations = provider_manager.get_configurations(tenant_id) + provider_configurations = assembly.provider_manager.get_configurations(tenant_id) models = provider_configurations.get_models(provider=provider, model_type=ModelType.LLM) target_model = None @@ -607,11 +613,10 @@ class WorkflowService: :return: True if load balancing is enabled, False otherwise """ try: - from core.provider_manager import ProviderManager - from dify_graph.model_runtime.entities.model_entities import ModelType + from graphon.model_runtime.entities.model_entities import ModelType # Get provider configurations - provider_manager = ProviderManager() + provider_manager = create_plugin_provider_manager(tenant_id=tenant_id) provider_configurations = provider_manager.get_configurations(tenant_id) provider_configuration = provider_configurations.get(provider) @@ -765,6 +770,7 @@ class WorkflowService: user_id=account.id, user_inputs=user_inputs, workflow=draft_workflow, + node_id=node_id, # NOTE(QuantumGhost): We rely on `DraftVarLoader` to load conversation variables. conversation_variables=[], node_type=node_type, @@ -772,11 +778,13 @@ class WorkflowService: ) else: - variable_pool = VariablePool( - system_variables=SystemVariable.default(), - user_inputs=user_inputs, - environment_variables=draft_workflow.environment_variables, - conversation_variables=[], + variable_pool = VariablePool() + add_variables_to_pool( + variable_pool, + build_bootstrap_variables( + system_variables=default_system_variables(), + environment_variables=draft_workflow.environment_variables, + ), ) variable_loader = DraftVarLoader( @@ -895,7 +903,6 @@ class WorkflowService: node_id=node_id, node_title=node.title, resolved_default_values=resolved_default_values, - form_token=None, ) return human_input_required.model_dump(mode="json") @@ -995,17 +1002,20 @@ class WorkflowService: if node_type != BuiltinNodeTypes.HUMAN_INPUT: raise ValueError("Node type must be human-input.") - node_data = HumanInputNodeData.model_validate(node_config["data"], from_attributes=True) + node_data = HumanInputNodeData.model_validate( + normalize_human_input_node_data_for_graph(node_config["data"]), + from_attributes=True, + ) delivery_method = self._resolve_human_input_delivery_method( node_data=node_data, delivery_method_id=delivery_method_id, ) if delivery_method is None: raise ValueError("Delivery method not found.") - delivery_method = apply_debug_email_recipient( + delivery_method = apply_dify_debug_email_recipient( delivery_method, enabled=True, - user_id=account.id, + actor_id=account.id, ) variable_pool = self._build_human_input_variable_pool( @@ -1055,7 +1065,7 @@ class WorkflowService: node_data: HumanInputNodeData, delivery_method_id: str, ) -> DeliveryChannelConfig | None: - for method in node_data.delivery_methods: + for method in parse_human_input_delivery_methods(node_data): if str(method.id) == delivery_method_id: return method return None @@ -1070,9 +1080,8 @@ class WorkflowService: rendered_content: str, resolved_default_values: Mapping[str, Any], ) -> tuple[str, list[DeliveryTestEmailRecipient]]: - repo = HumanInputFormRepositoryImpl(tenant_id=app_model.tenant_id) + repo = HumanInputFormRepositoryImpl(tenant_id=app_model.tenant_id, app_id=app_model.id) params = FormCreateParams( - app_id=app_model.id, workflow_execution_id=None, node_id=node_id, form_config=node_data, @@ -1138,7 +1147,7 @@ class WorkflowService: config=node_config, graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, - form_repository=HumanInputFormRepositoryImpl(tenant_id=workflow.tenant_id), + runtime=DifyHumanInputNodeRuntime(graph_init_params.run_context), ) return node @@ -1155,11 +1164,13 @@ class WorkflowService: draft_var_srv = WorkflowDraftVariableService(session) draft_var_srv.prefill_conversation_variable_default_values(workflow, user_id=user_id) - variable_pool = VariablePool( - system_variables=SystemVariable.default(), - user_inputs={}, - environment_variables=workflow.environment_variables, - conversation_variables=[], + variable_pool = VariablePool() + add_variables_to_pool( + variable_pool, + build_bootstrap_variables( + system_variables=default_system_variables(), + environment_variables=workflow.environment_variables, + ), ) variable_loader = DraftVarLoader( @@ -1419,10 +1430,10 @@ class WorkflowService: Raises: ValueError: If the node data format is invalid """ - from dify_graph.nodes.human_input.entities import HumanInputNodeData + from graphon.nodes.human_input.entities import HumanInputNodeData try: - HumanInputNodeData.model_validate(node_data) + HumanInputNodeData.model_validate(normalize_human_input_node_data_for_graph(node_data)) except Exception as e: raise ValueError(f"Invalid HumanInput node data: {str(e)}") @@ -1511,38 +1522,48 @@ def _setup_variable_pool( user_id: str, user_inputs: Mapping[str, Any], workflow: Workflow, + node_id: str, node_type: NodeType, conversation_id: str, conversation_variables: list[VariableBase], ): # Only inject system variables for START node type. if is_start_node_type(node_type): - system_variable = SystemVariable( - user_id=user_id, - app_id=workflow.app_id, - timestamp=int(naive_utc_now().timestamp()), - workflow_id=workflow.id, - files=files or [], - workflow_execution_id=str(uuid.uuid4()), - ) + system_variable_values: dict[str, Any] = { + "user_id": user_id, + "app_id": workflow.app_id, + "timestamp": int(naive_utc_now().timestamp()), + "workflow_id": workflow.id, + "files": files or [], + "workflow_execution_id": str(uuid.uuid4()), + } - # Only add chatflow-specific variables for non-workflow types + # Only add chatflow-specific variables for non-workflow types. if workflow.type != WorkflowType.WORKFLOW: - system_variable.query = query - system_variable.conversation_id = conversation_id - system_variable.dialogue_count = 1 + system_variable_values.update( + { + "query": query, + "conversation_id": conversation_id, + "dialogue_count": 1, + } + ) + + system_variable = build_system_variables(system_variable_values) else: - system_variable = SystemVariable.default() + system_variable = default_system_variables() # init variable pool - variable_pool = VariablePool( - system_variables=system_variable, - user_inputs=user_inputs, - environment_variables=workflow.environment_variables, - # Based on the definition of `Variable`, - # `VariableBase` instances can be safely used as `Variable` since they are compatible. - conversation_variables=cast(list[Variable], conversation_variables), # + variable_pool = VariablePool() + add_variables_to_pool( + variable_pool, + build_bootstrap_variables( + system_variables=system_variable, + environment_variables=workflow.environment_variables, + conversation_variables=cast(list[Variable], conversation_variables), + ), ) + if is_start_node_type(node_type): + add_node_inputs_to_pool(variable_pool, node_id=node_id, inputs=user_inputs) return variable_pool @@ -1567,7 +1588,7 @@ def _rebuild_single_file(tenant_id: str, value: Any, variable_entity_type: Varia if variable_entity_type == VariableEntityType.FILE: if not isinstance(value, dict): raise ValueError(f"expected dict for file object, got {type(value)}") - return build_from_mapping(mapping=value, tenant_id=tenant_id) + return build_from_mapping(mapping=value, tenant_id=tenant_id, access_controller=_file_access_controller) elif variable_entity_type == VariableEntityType.FILE_LIST: if not isinstance(value, list): raise ValueError(f"expected list for file list object, got {type(value)}") @@ -1575,6 +1596,6 @@ def _rebuild_single_file(tenant_id: str, value: Any, variable_entity_type: Varia return [] if not isinstance(value[0], dict): raise ValueError(f"expected dict for first element in the file list, got {type(value)}") - return build_from_mappings(mappings=value, tenant_id=tenant_id) + return build_from_mappings(mappings=value, tenant_id=tenant_id, access_controller=_file_access_controller) else: raise Exception("unreachable") diff --git a/api/tasks/annotation/add_annotation_to_index_task.py b/api/tasks/annotation/add_annotation_to_index_task.py index a9a8b892c2..dafa36cc34 100644 --- a/api/tasks/annotation/add_annotation_to_index_task.py +++ b/api/tasks/annotation/add_annotation_to_index_task.py @@ -5,6 +5,7 @@ import click from celery import shared_task from core.rag.datasource.vdb.vector_factory import Vector +from core.rag.index_processor.constant.index_type import IndexTechniqueType from core.rag.models.document import Document from models.dataset import Dataset from services.dataset_service import DatasetCollectionBindingService @@ -36,7 +37,7 @@ def add_annotation_to_index_task( dataset = Dataset( id=app_id, tenant_id=tenant_id, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, embedding_model_provider=dataset_collection_binding.provider_name, embedding_model=dataset_collection_binding.model_name, collection_binding_id=dataset_collection_binding.id, diff --git a/api/tasks/annotation/batch_import_annotations_task.py b/api/tasks/annotation/batch_import_annotations_task.py index fc6bf03454..c734e1321b 100644 --- a/api/tasks/annotation/batch_import_annotations_task.py +++ b/api/tasks/annotation/batch_import_annotations_task.py @@ -7,6 +7,7 @@ from werkzeug.exceptions import NotFound from core.db.session_factory import session_factory from core.rag.datasource.vdb.vector_factory import Vector +from core.rag.index_processor.constant.index_type import IndexTechniqueType from core.rag.models.document import Document from extensions.ext_redis import redis_client from models.dataset import Dataset @@ -67,7 +68,7 @@ def batch_import_annotations_task(job_id: str, content_list: list[dict], app_id: dataset = Dataset( id=app_id, tenant_id=tenant_id, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, embedding_model_provider=dataset_collection_binding.provider_name, embedding_model=dataset_collection_binding.model_name, collection_binding_id=dataset_collection_binding.id, diff --git a/api/tasks/annotation/delete_annotation_index_task.py b/api/tasks/annotation/delete_annotation_index_task.py index 432732af95..c9aa8fadb7 100644 --- a/api/tasks/annotation/delete_annotation_index_task.py +++ b/api/tasks/annotation/delete_annotation_index_task.py @@ -5,6 +5,7 @@ import click from celery import shared_task from core.rag.datasource.vdb.vector_factory import Vector +from core.rag.index_processor.constant.index_type import IndexTechniqueType from models.dataset import Dataset from services.dataset_service import DatasetCollectionBindingService @@ -26,7 +27,7 @@ def delete_annotation_index_task(annotation_id: str, app_id: str, tenant_id: str dataset = Dataset( id=app_id, tenant_id=tenant_id, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, collection_binding_id=dataset_collection_binding.id, ) diff --git a/api/tasks/annotation/disable_annotation_reply_task.py b/api/tasks/annotation/disable_annotation_reply_task.py index 7b5cd46b00..41cf7ccbf6 100644 --- a/api/tasks/annotation/disable_annotation_reply_task.py +++ b/api/tasks/annotation/disable_annotation_reply_task.py @@ -7,6 +7,7 @@ from sqlalchemy import exists, select from core.db.session_factory import session_factory from core.rag.datasource.vdb.vector_factory import Vector +from core.rag.index_processor.constant.index_type import IndexTechniqueType from extensions.ext_redis import redis_client from models.dataset import Dataset from models.model import App, AppAnnotationSetting, MessageAnnotation @@ -44,7 +45,7 @@ def disable_annotation_reply_task(job_id: str, app_id: str, tenant_id: str): dataset = Dataset( id=app_id, tenant_id=tenant_id, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, collection_binding_id=app_annotation_setting.collection_binding_id, ) diff --git a/api/tasks/annotation/enable_annotation_reply_task.py b/api/tasks/annotation/enable_annotation_reply_task.py index 1fe43c3d62..2c07fe0f31 100644 --- a/api/tasks/annotation/enable_annotation_reply_task.py +++ b/api/tasks/annotation/enable_annotation_reply_task.py @@ -7,6 +7,7 @@ from sqlalchemy import select from core.db.session_factory import session_factory from core.rag.datasource.vdb.vector_factory import Vector +from core.rag.index_processor.constant.index_type import IndexTechniqueType from core.rag.models.document import Document from extensions.ext_redis import redis_client from libs.datetime_utils import naive_utc_now @@ -64,7 +65,7 @@ def enable_annotation_reply_task( old_dataset = Dataset( id=app_id, tenant_id=tenant_id, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, embedding_model_provider=old_dataset_collection_binding.provider_name, embedding_model=old_dataset_collection_binding.model_name, collection_binding_id=old_dataset_collection_binding.id, @@ -93,7 +94,7 @@ def enable_annotation_reply_task( dataset = Dataset( id=app_id, tenant_id=tenant_id, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, embedding_model_provider=embedding_provider_name, embedding_model=embedding_model_name, collection_binding_id=dataset_collection_binding.id, diff --git a/api/tasks/annotation/update_annotation_to_index_task.py b/api/tasks/annotation/update_annotation_to_index_task.py index 6ff34c0e74..f41da1d373 100644 --- a/api/tasks/annotation/update_annotation_to_index_task.py +++ b/api/tasks/annotation/update_annotation_to_index_task.py @@ -5,6 +5,7 @@ import click from celery import shared_task from core.rag.datasource.vdb.vector_factory import Vector +from core.rag.index_processor.constant.index_type import IndexTechniqueType from core.rag.models.document import Document from models.dataset import Dataset from services.dataset_service import DatasetCollectionBindingService @@ -37,7 +38,7 @@ def update_annotation_to_index_task( dataset = Dataset( id=app_id, tenant_id=tenant_id, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, embedding_model_provider=dataset_collection_binding.provider_name, embedding_model=dataset_collection_binding.model_name, collection_binding_id=dataset_collection_binding.id, diff --git a/api/tasks/app_generate/workflow_execute_task.py b/api/tasks/app_generate/workflow_execute_task.py index 174aa50343..458099d99e 100644 --- a/api/tasks/app_generate/workflow_execute_task.py +++ b/api/tasks/app_generate/workflow_execute_task.py @@ -21,8 +21,8 @@ from core.app.entities.app_invoke_entities import ( ) from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig, WorkflowResumptionContext from core.repositories import DifyCoreRepositoryFactory -from dify_graph.runtime import GraphRuntimeState from extensions.ext_database import db +from graphon.runtime import GraphRuntimeState from libs.flask_utils import set_login_user from models.account import Account from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom @@ -239,13 +239,18 @@ def _resolve_user_for_run(session: Session, workflow_run: WorkflowRun) -> Accoun def _publish_streaming_response( - response_stream: Generator[str | Mapping[str, Any], None, None], workflow_run_id: str, app_mode: AppMode + response_stream: Generator[str | Mapping[str, Any] | BaseModel, None, None], + workflow_run_id: str, + app_mode: AppMode, ) -> None: topic = MessageBasedAppGenerator.get_response_topic(app_mode, workflow_run_id) for event in response_stream: try: - payload = json.dumps(event) - except TypeError: + if isinstance(event, BaseModel): + payload = json.dumps(event.model_dump(mode="json"), ensure_ascii=False) + else: + payload = json.dumps(event, ensure_ascii=False, default=str) + except (TypeError, ValueError): logger.exception("error while encoding event") continue diff --git a/api/tasks/async_workflow_tasks.py b/api/tasks/async_workflow_tasks.py index d247cf5cf7..6365400dd1 100644 --- a/api/tasks/async_workflow_tasks.py +++ b/api/tasks/async_workflow_tasks.py @@ -21,8 +21,8 @@ from core.app.layers.timeslice_layer import TimeSliceLayer from core.app.layers.trigger_post_layer import TriggerPostLayer from core.db.session_factory import session_factory from core.repositories import DifyCoreRepositoryFactory -from dify_graph.runtime import GraphRuntimeState from extensions.ext_database import db +from graphon.runtime import GraphRuntimeState from models.account import Account from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom, WorkflowTriggerStatus from models.model import App, EndUser, Tenant diff --git a/api/tasks/batch_create_segment_to_index_task.py b/api/tasks/batch_create_segment_to_index_task.py index 49dee00919..ed8a24b336 100644 --- a/api/tasks/batch_create_segment_to_index_task.py +++ b/api/tasks/batch_create_segment_to_index_task.py @@ -11,9 +11,10 @@ from sqlalchemy import func from core.db.session_factory import session_factory from core.model_manager import ModelManager -from dify_graph.model_runtime.entities.model_entities import ModelType +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from extensions.ext_redis import redis_client from extensions.ext_storage import storage +from graphon.model_runtime.entities.model_entities import ModelType from libs import helper from libs.datetime_utils import naive_utc_now from models.dataset import Dataset, Document, DocumentSegment @@ -109,7 +110,7 @@ def batch_create_segment_to_index_task( df = pd.read_csv(file_path) content = [] for _, row in df.iterrows(): - if document_config["doc_form"] == "qa_model": + if document_config["doc_form"] == IndexStructureType.QA_INDEX: data = {"content": row.iloc[0], "answer": row.iloc[1]} else: data = {"content": row.iloc[0]} @@ -119,8 +120,8 @@ def batch_create_segment_to_index_task( document_segments = [] embedding_model = None - if dataset_config["indexing_technique"] == "high_quality": - model_manager = ModelManager() + if dataset_config["indexing_technique"] == IndexTechniqueType.HIGH_QUALITY: + model_manager = ModelManager.for_tenant(tenant_id=dataset_config["tenant_id"]) embedding_model = model_manager.get_model_instance( tenant_id=dataset_config["tenant_id"], provider=dataset_config["embedding_model_provider"], @@ -159,7 +160,7 @@ def batch_create_segment_to_index_task( status="completed", completed_at=naive_utc_now(), ) - if document_config["doc_form"] == "qa_model": + if document_config["doc_form"] == IndexStructureType.QA_INDEX: segment_document.answer = segment["answer"] segment_document.word_count += len(segment["answer"]) word_count_change += segment_document.word_count diff --git a/api/tasks/document_indexing_task.py b/api/tasks/document_indexing_task.py index e05d63426c..23a80fa106 100644 --- a/api/tasks/document_indexing_task.py +++ b/api/tasks/document_indexing_task.py @@ -10,6 +10,7 @@ from configs import dify_config from core.db.session_factory import session_factory from core.entities.document_task import DocumentTask from core.indexing_runner import DocumentIsPausedError, IndexingRunner +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from core.rag.pipeline.queue import TenantIsolatedTaskQueue from enums.cloud_plan import CloudPlan from libs.datetime_utils import naive_utc_now @@ -126,7 +127,7 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]): logger.warning("Dataset %s not found after indexing", dataset_id) return - if dataset.indexing_technique == "high_quality": + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: summary_index_setting = dataset.summary_index_setting if summary_index_setting and summary_index_setting.get("enable"): # expire all session to get latest document's indexing status @@ -150,7 +151,7 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]): ) if ( document.indexing_status == IndexingStatus.COMPLETED - and document.doc_form != "qa_model" + and document.doc_form != IndexStructureType.QA_INDEX and document.need_summary is True ): try: diff --git a/api/tasks/generate_summary_index_task.py b/api/tasks/generate_summary_index_task.py index 6493833edc..e3d82d2851 100644 --- a/api/tasks/generate_summary_index_task.py +++ b/api/tasks/generate_summary_index_task.py @@ -7,6 +7,7 @@ import click from celery import shared_task from core.db.session_factory import session_factory +from core.rag.index_processor.constant.index_type import IndexTechniqueType from models.dataset import Dataset, DocumentSegment from models.dataset import Document as DatasetDocument from services.summary_index_service import SummaryIndexService @@ -59,7 +60,7 @@ def generate_summary_index_task(dataset_id: str, document_id: str, segment_ids: return # Only generate summary index for high_quality indexing technique - if dataset.indexing_technique != "high_quality": + if dataset.indexing_technique != IndexTechniqueType.HIGH_QUALITY: logger.info( click.style( f"Skipping summary generation for dataset {dataset_id}: " diff --git a/api/tasks/human_input_timeout_tasks.py b/api/tasks/human_input_timeout_tasks.py index dd3b6a4530..fd743205a1 100644 --- a/api/tasks/human_input_timeout_tasks.py +++ b/api/tasks/human_input_timeout_tasks.py @@ -7,10 +7,10 @@ from sqlalchemy.orm import sessionmaker from configs import dify_config from core.repositories.human_input_repository import HumanInputFormSubmissionRepository -from dify_graph.enums import WorkflowExecutionStatus -from dify_graph.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus from extensions.ext_database import db from extensions.ext_storage import storage +from graphon.enums import WorkflowExecutionStatus +from graphon.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus from libs.datetime_utils import ensure_naive_utc, naive_utc_now from models.human_input import HumanInputForm from models.workflow import WorkflowPause, WorkflowRun diff --git a/api/tasks/mail_human_input_delivery_task.py b/api/tasks/mail_human_input_delivery_task.py index d241783359..f8ae3f4b6e 100644 --- a/api/tasks/mail_human_input_delivery_task.py +++ b/api/tasks/mail_human_input_delivery_task.py @@ -11,10 +11,10 @@ from sqlalchemy.orm import Session, sessionmaker from configs import dify_config from core.app.layers.pause_state_persist_layer import WorkflowResumptionContext -from dify_graph.nodes.human_input.entities import EmailDeliveryConfig, EmailDeliveryMethod -from dify_graph.runtime import GraphRuntimeState, VariablePool +from core.workflow.human_input_compat import EmailDeliveryConfig, EmailDeliveryMethod from extensions.ext_database import db from extensions.ext_mail import mail +from graphon.runtime import GraphRuntimeState, VariablePool from models.human_input import ( DeliveryMethodType, HumanInputDelivery, diff --git a/api/tasks/regenerate_summary_index_task.py b/api/tasks/regenerate_summary_index_task.py index 39c2f4103e..6f490ab7ea 100644 --- a/api/tasks/regenerate_summary_index_task.py +++ b/api/tasks/regenerate_summary_index_task.py @@ -9,6 +9,7 @@ from celery import shared_task from sqlalchemy import or_, select from core.db.session_factory import session_factory +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from models.dataset import Dataset, DocumentSegment, DocumentSegmentSummary from models.dataset import Document as DatasetDocument from services.summary_index_service import SummaryIndexService @@ -52,7 +53,7 @@ def regenerate_summary_index_task( return # Only regenerate summary index for high_quality indexing technique - if dataset.indexing_technique != "high_quality": + if dataset.indexing_technique != IndexTechniqueType.HIGH_QUALITY: logger.info( click.style( f"Skipping summary regeneration for dataset {dataset_id}: " @@ -106,7 +107,7 @@ def regenerate_summary_index_task( ), DatasetDocument.enabled == True, # Document must be enabled DatasetDocument.archived == False, # Document must not be archived - DatasetDocument.doc_form != "qa_model", # Skip qa_model documents + DatasetDocument.doc_form != IndexStructureType.QA_INDEX, # Skip qa_model documents ) .order_by(DocumentSegment.document_id.asc(), DocumentSegment.position.asc()) .all() @@ -209,7 +210,7 @@ def regenerate_summary_index_task( for dataset_document in dataset_documents: # Skip qa_model documents - if dataset_document.doc_form == "qa_model": + if dataset_document.doc_form == IndexStructureType.QA_INDEX: continue try: diff --git a/api/tasks/trigger_processing_tasks.py b/api/tasks/trigger_processing_tasks.py index 75ae1f6316..25ea53dfac 100644 --- a/api/tasks/trigger_processing_tasks.py +++ b/api/tasks/trigger_processing_tasks.py @@ -27,8 +27,8 @@ from core.trigger.entities.entities import TriggerProviderEntity from core.trigger.provider import PluginTriggerProviderController from core.trigger.trigger_manager import TriggerManager from core.workflow.nodes.trigger_plugin.entities import TriggerEventNodeData -from dify_graph.enums import WorkflowExecutionStatus from enums.quota_type import QuotaType, unlimited +from graphon.enums import WorkflowExecutionStatus from models.enums import ( AppTriggerType, CreatorUserRole, @@ -179,7 +179,7 @@ def _record_trigger_failure_log( app_id=workflow.app_id, workflow_id=workflow.id, workflow_run_id=workflow_run.id, - created_from=WorkflowAppLogCreatedFrom.SERVICE_API.value, + created_from=WorkflowAppLogCreatedFrom.SERVICE_API, created_by_role=created_by_role, created_by=created_by, ) diff --git a/api/tasks/workflow_execution_tasks.py b/api/tasks/workflow_execution_tasks.py index f41118e592..ae1c2991c9 100644 --- a/api/tasks/workflow_execution_tasks.py +++ b/api/tasks/workflow_execution_tasks.py @@ -12,8 +12,8 @@ from celery import shared_task from sqlalchemy import select from core.db.session_factory import session_factory -from dify_graph.entities.workflow_execution import WorkflowExecution -from dify_graph.workflow_type_encoder import WorkflowRuntimeTypeConverter +from graphon.entities.workflow_execution import WorkflowExecution +from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter from models import CreatorUserRole, WorkflowRun from models.enums import WorkflowRunTriggeredFrom diff --git a/api/tasks/workflow_node_execution_tasks.py b/api/tasks/workflow_node_execution_tasks.py index 466ef6c858..a0fd739325 100644 --- a/api/tasks/workflow_node_execution_tasks.py +++ b/api/tasks/workflow_node_execution_tasks.py @@ -12,10 +12,10 @@ from celery import shared_task from sqlalchemy import select from core.db.session_factory import session_factory -from dify_graph.entities.workflow_node_execution import ( +from graphon.entities.workflow_node_execution import ( WorkflowNodeExecution, ) -from dify_graph.workflow_type_encoder import WorkflowRuntimeTypeConverter +from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter from models import CreatorUserRole, WorkflowNodeExecutionModel from models.workflow import WorkflowNodeExecutionTriggeredFrom diff --git a/api/tests/integration_tests/core/datasource/test_datasource_manager_integration.py b/api/tests/integration_tests/core/datasource/test_datasource_manager_integration.py index 4fdbb7d9f3..a876b0c4aa 100644 --- a/api/tests/integration_tests/core/datasource/test_datasource_manager_integration.py +++ b/api/tests/integration_tests/core/datasource/test_datasource_manager_integration.py @@ -2,7 +2,7 @@ from collections.abc import Generator from core.datasource.datasource_manager import DatasourceManager from core.datasource.entities.datasource_entities import DatasourceMessage -from dify_graph.node_events import StreamCompletedEvent +from graphon.node_events import StreamCompletedEvent def _gen_var_stream() -> Generator[DatasourceMessage, None, None]: diff --git a/api/tests/integration_tests/core/workflow/nodes/datasource/test_datasource_node_integration.py b/api/tests/integration_tests/core/workflow/nodes/datasource/test_datasource_node_integration.py index 3e79792b5b..b2de11b068 100644 --- a/api/tests/integration_tests/core/workflow/nodes/datasource/test_datasource_node_integration.py +++ b/api/tests/integration_tests/core/workflow/nodes/datasource/test_datasource_node_integration.py @@ -1,7 +1,7 @@ +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY from core.workflow.nodes.datasource.datasource_node import DatasourceNode -from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY -from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus -from dify_graph.node_events import NodeRunResult, StreamCompletedEvent +from graphon.entities.workflow_node_execution import WorkflowNodeExecutionStatus +from graphon.node_events import NodeRunResult, StreamCompletedEvent class _Seg: diff --git a/api/tests/integration_tests/factories/test_storage_key_loader.py b/api/tests/integration_tests/factories/test_storage_key_loader.py index db4bbc1ca1..878d9b24df 100644 --- a/api/tests/integration_tests/factories/test_storage_key_loader.py +++ b/api/tests/integration_tests/factories/test_storage_key_loader.py @@ -6,10 +6,11 @@ from uuid import uuid4 import pytest from sqlalchemy.orm import Session -from dify_graph.file import File, FileTransferMethod, FileType +from core.app.file_access import DatabaseFileAccessController from extensions.ext_database import db from extensions.storage.storage_type import StorageType from factories.file_factory import StorageKeyLoader +from graphon.file import File, FileTransferMethod, FileType from models import ToolFile, UploadFile from models.enums import CreatorUserRole @@ -35,7 +36,11 @@ class TestStorageKeyLoader(unittest.TestCase): self.test_tool_files = [] # Create StorageKeyLoader instance - self.loader = StorageKeyLoader(self.session, self.tenant_id) + self.loader = StorageKeyLoader( + self.session, + self.tenant_id, + access_controller=DatabaseFileAccessController(), + ) def tearDown(self): """Clean up test data after each test method.""" @@ -192,19 +197,16 @@ class TestStorageKeyLoader(unittest.TestCase): # Should not raise any exceptions self.loader.load_storage_keys([]) - def test_load_storage_keys_tenant_mismatch(self): - """Test tenant_id validation.""" - # Create file with different tenant_id + def test_load_storage_keys_ignores_legacy_file_tenant_id(self): + """Legacy file tenant_id should not override the loader tenant scope.""" upload_file = self._create_upload_file() file = self._create_file( related_id=upload_file.id, transfer_method=FileTransferMethod.LOCAL_FILE, tenant_id=str(uuid4()) ) - # Should raise ValueError for tenant mismatch - with pytest.raises(ValueError) as context: - self.loader.load_storage_keys([file]) + self.loader.load_storage_keys([file]) - assert "invalid file, expected tenant_id" in str(context.value) + assert file._storage_key == upload_file.key def test_load_storage_keys_missing_file_id(self): """Test with None file.related_id.""" @@ -313,7 +315,7 @@ class TestStorageKeyLoader(unittest.TestCase): with pytest.raises(ValueError) as context: self.loader.load_storage_keys([file_other]) - assert "invalid file, expected tenant_id" in str(context.value) + assert "Upload file not found for id:" in str(context.value) # Current tenant's file should still work self.loader.load_storage_keys([file_current]) @@ -337,7 +339,7 @@ class TestStorageKeyLoader(unittest.TestCase): with pytest.raises(ValueError) as context: self.loader.load_storage_keys([file_current, file_other]) - assert "invalid file, expected tenant_id" in str(context.value) + assert "Upload file not found for id:" in str(context.value) def test_load_storage_keys_duplicate_file_ids(self): """Test handling of duplicate file IDs in the batch.""" @@ -364,6 +366,10 @@ class TestStorageKeyLoader(unittest.TestCase): # Create loader with different session (same underlying connection) with Session(bind=db.engine) as other_session: - other_loader = StorageKeyLoader(other_session, self.tenant_id) + other_loader = StorageKeyLoader( + other_session, + self.tenant_id, + access_controller=DatabaseFileAccessController(), + ) with pytest.raises(ValueError): other_loader.load_storage_keys([file]) diff --git a/api/tests/integration_tests/libs/test_api_token_cache_integration.py b/api/tests/integration_tests/libs/test_api_token_cache_integration.py index 1d7b835fd2..a942690cbd 100644 --- a/api/tests/integration_tests/libs/test_api_token_cache_integration.py +++ b/api/tests/integration_tests/libs/test_api_token_cache_integration.py @@ -13,6 +13,7 @@ from unittest.mock import patch import pytest from extensions.ext_redis import redis_client +from models.enums import ApiTokenType from models.model import ApiToken from services.api_token_service import ApiTokenCache, CachedApiToken @@ -279,7 +280,7 @@ class TestEndToEndCacheFlow: test_token = ApiToken() test_token.id = "test-e2e-id" test_token.token = test_token_value - test_token.type = test_scope + test_token.type = ApiTokenType.APP test_token.app_id = "test-app" test_token.tenant_id = "test-tenant" test_token.last_used_at = None diff --git a/api/tests/integration_tests/model_runtime/__mock/plugin_model.py b/api/tests/integration_tests/model_runtime/__mock/plugin_model.py index 4e184c93fd..c4146d5ccd 100644 --- a/api/tests/integration_tests/model_runtime/__mock/plugin_model.py +++ b/api/tests/integration_tests/model_runtime/__mock/plugin_model.py @@ -8,23 +8,23 @@ from core.plugin.entities.plugin_daemon import PluginModelProviderEntity from core.plugin.impl.model import PluginModelClient # import monkeypatch -from dify_graph.model_runtime.entities.common_entities import I18nObject -from dify_graph.model_runtime.entities.llm_entities import ( +from graphon.model_runtime.entities.common_entities import I18nObject +from graphon.model_runtime.entities.llm_entities import ( LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage, ) -from dify_graph.model_runtime.entities.message_entities import AssistantPromptMessage, PromptMessage, PromptMessageTool -from dify_graph.model_runtime.entities.model_entities import ( +from graphon.model_runtime.entities.message_entities import AssistantPromptMessage, PromptMessage, PromptMessageTool +from graphon.model_runtime.entities.model_entities import ( AIModelEntity, FetchFrom, ModelFeature, ModelPropertyKey, ModelType, ) -from dify_graph.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity +from graphon.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity class MockModelClass(PluginModelClient): diff --git a/api/tests/integration_tests/services/test_workflow_draft_variable_service.py b/api/tests/integration_tests/services/test_workflow_draft_variable_service.py index 9d3a869691..0b21ff1d2a 100644 --- a/api/tests/integration_tests/services/test_workflow_draft_variable_service.py +++ b/api/tests/integration_tests/services/test_workflow_draft_variable_service.py @@ -6,15 +6,15 @@ import pytest from sqlalchemy import delete from sqlalchemy.orm import Session -from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID -from dify_graph.nodes import BuiltinNodeTypes -from dify_graph.variables.segments import StringSegment -from dify_graph.variables.types import SegmentType -from dify_graph.variables.variables import StringVariable +from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID from extensions.ext_database import db from extensions.ext_storage import storage from extensions.storage.storage_type import StorageType from factories.variable_factory import build_segment +from graphon.nodes import BuiltinNodeTypes +from graphon.variables.segments import StringSegment +from graphon.variables.types import SegmentType +from graphon.variables.variables import StringVariable from libs import datetime_utils from models.enums import CreatorUserRole from models.model import UploadFile diff --git a/api/tests/integration_tests/tasks/test_remove_app_and_related_data_task.py b/api/tests/integration_tests/tasks/test_remove_app_and_related_data_task.py index bc83c6cc12..f6f4cf260b 100644 --- a/api/tests/integration_tests/tasks/test_remove_app_and_related_data_task.py +++ b/api/tests/integration_tests/tasks/test_remove_app_and_related_data_task.py @@ -5,8 +5,8 @@ import pytest from sqlalchemy import delete from core.db.session_factory import session_factory -from dify_graph.variables.segments import StringSegment from extensions.storage.storage_type import StorageType +from graphon.variables.segments import StringSegment from models import Tenant from models.enums import CreatorUserRole from models.model import App, UploadFile @@ -192,7 +192,7 @@ class TestDeleteDraftVariablesWithOffloadIntegration: @pytest.fixture def setup_offload_test_data(self, app_and_tenant): tenant, app = app_and_tenant - from dify_graph.variables.types import SegmentType + from graphon.variables.types import SegmentType from libs.datetime_utils import naive_utc_now with session_factory.create_session() as session: @@ -423,7 +423,7 @@ class TestDeleteDraftVariablesSessionCommit: @pytest.fixture def setup_offload_test_data(self, app_and_tenant): """Create test data with offload files for session commit tests.""" - from dify_graph.variables.types import SegmentType + from graphon.variables.types import SegmentType from libs.datetime_utils import naive_utc_now tenant, app = app_and_tenant diff --git a/api/tests/integration_tests/workflow/nodes/__mock/model.py b/api/tests/integration_tests/workflow/nodes/__mock/model.py index 5b0f86fed1..a9a2617bae 100644 --- a/api/tests/integration_tests/workflow/nodes/__mock/model.py +++ b/api/tests/integration_tests/workflow/nodes/__mock/model.py @@ -4,8 +4,8 @@ from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEnti from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle from core.entities.provider_entities import CustomConfiguration, CustomProviderConfiguration, SystemConfiguration from core.model_manager import ModelInstance -from dify_graph.model_runtime.entities.model_entities import ModelType -from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory +from core.plugin.impl.model_runtime_factory import create_plugin_model_provider_factory +from graphon.model_runtime.entities.model_entities import ModelType from models.provider import ProviderType @@ -15,7 +15,7 @@ def get_mocked_fetch_model_config( mode: str, credentials: dict, ): - model_provider_factory = ModelProviderFactory(tenant_id="9d2074fc-6f86-45a9-b09d-6ecc63b9056b") + model_provider_factory = create_plugin_model_provider_factory(tenant_id="9d2074fc-6f86-45a9-b09d-6ecc63b9056b") model_type_instance = model_provider_factory.get_model_type_instance(provider, ModelType.LLM) provider_model_bundle = ProviderModelBundle( configuration=ProviderConfiguration( diff --git a/api/tests/integration_tests/workflow/nodes/test_code.py b/api/tests/integration_tests/workflow/nodes/test_code.py index e3a2b6b866..7573e00872 100644 --- a/api/tests/integration_tests/workflow/nodes/test_code.py +++ b/api/tests/integration_tests/workflow/nodes/test_code.py @@ -6,13 +6,13 @@ import pytest from configs import dify_config from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.workflow.node_factory import DifyNodeFactory -from dify_graph.enums import WorkflowNodeExecutionStatus -from dify_graph.graph import Graph -from dify_graph.node_events import NodeRunResult -from dify_graph.nodes.code.code_node import CodeNode -from dify_graph.nodes.code.limits import CodeNodeLimits -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable +from core.workflow.system_variables import build_system_variables +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.graph import Graph +from graphon.node_events import NodeRunResult +from graphon.nodes.code.code_node import CodeNode +from graphon.nodes.code.limits import CodeNodeLimits +from graphon.runtime import GraphRuntimeState, VariablePool from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock from tests.workflow_test_utils import build_test_graph_init_params @@ -44,7 +44,7 @@ def init_code_node(code_config: dict): # construct variable pool variable_pool = VariablePool( - system_variables=SystemVariable(user_id="aaa", files=[]), + system_variables=build_system_variables(user_id="aaa", files=[]), user_inputs={}, environment_variables=[], conversation_variables=[], diff --git a/api/tests/integration_tests/workflow/nodes/test_http.py b/api/tests/integration_tests/workflow/nodes/test_http.py index f885f69e55..17ea7de881 100644 --- a/api/tests/integration_tests/workflow/nodes/test_http.py +++ b/api/tests/integration_tests/workflow/nodes/test_http.py @@ -9,12 +9,13 @@ from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.helper.ssrf_proxy import ssrf_proxy from core.tools.tool_file_manager import ToolFileManager from core.workflow.node_factory import DifyNodeFactory -from dify_graph.enums import WorkflowNodeExecutionStatus -from dify_graph.file.file_manager import file_manager -from dify_graph.graph import Graph -from dify_graph.nodes.http_request import HttpRequestNode, HttpRequestNodeConfig -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable +from core.workflow.node_runtime import DifyFileReferenceFactory +from core.workflow.system_variables import build_system_variables +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.file.file_manager import file_manager +from graphon.graph import Graph +from graphon.nodes.http_request import HttpRequestNode, HttpRequestNodeConfig +from graphon.runtime import GraphRuntimeState, VariablePool from tests.integration_tests.workflow.nodes.__mock.http import setup_http_mock from tests.workflow_test_utils import build_test_graph_init_params @@ -54,7 +55,7 @@ def init_http_node(config: dict): # construct variable pool variable_pool = VariablePool( - system_variables=SystemVariable(user_id="aaa", files=[]), + system_variables=build_system_variables(user_id="aaa", files=[]), user_inputs={}, environment_variables=[], conversation_variables=[], @@ -81,6 +82,7 @@ def init_http_node(config: dict): http_client=ssrf_proxy, tool_file_manager_factory=ToolFileManager, file_manager=file_manager, + file_reference_factory=DifyFileReferenceFactory(init_params.run_context), ) return node @@ -189,20 +191,20 @@ def test_custom_authorization_header(setup_http_mock): @pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True) def test_custom_auth_with_empty_api_key_raises_error(setup_http_mock): """Test: In custom authentication mode, when the api_key is empty, AuthorizationConfigError should be raised.""" - from dify_graph.enums import BuiltinNodeTypes - from dify_graph.nodes.http_request.entities import ( + from core.workflow.system_variables import build_system_variables + from graphon.enums import BuiltinNodeTypes + from graphon.nodes.http_request.entities import ( HttpRequestNodeAuthorization, HttpRequestNodeData, HttpRequestNodeTimeout, ) - from dify_graph.nodes.http_request.exc import AuthorizationConfigError - from dify_graph.nodes.http_request.executor import Executor - from dify_graph.runtime import VariablePool - from dify_graph.system_variable import SystemVariable + from graphon.nodes.http_request.exc import AuthorizationConfigError + from graphon.nodes.http_request.executor import Executor + from graphon.runtime import VariablePool # Create variable pool variable_pool = VariablePool( - system_variables=SystemVariable(user_id="test", files=[]), + system_variables=build_system_variables(user_id="test", files=[]), user_inputs={}, environment_variables=[], conversation_variables=[], @@ -700,7 +702,7 @@ def test_nested_object_variable_selector(setup_http_mock): # Create independent variable pool for this test only variable_pool = VariablePool( - system_variables=SystemVariable(user_id="aaa", files=[]), + system_variables=build_system_variables(user_id="aaa", files=[]), user_inputs={}, environment_variables=[], conversation_variables=[], @@ -728,6 +730,7 @@ def test_nested_object_variable_selector(setup_http_mock): http_client=ssrf_proxy, tool_file_manager_factory=ToolFileManager, file_manager=file_manager, + file_reference_factory=DifyFileReferenceFactory(init_params.run_context), ) result = node._run() diff --git a/api/tests/integration_tests/workflow/nodes/test_llm.py b/api/tests/integration_tests/workflow/nodes/test_llm.py index d628348f1e..fa5d63cfbf 100644 --- a/api/tests/integration_tests/workflow/nodes/test_llm.py +++ b/api/tests/integration_tests/workflow/nodes/test_llm.py @@ -7,14 +7,16 @@ from unittest.mock import MagicMock, patch from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.llm_generator.output_parser.structured_output import _parse_structured_output from core.model_manager import ModelInstance -from dify_graph.enums import WorkflowNodeExecutionStatus -from dify_graph.node_events import StreamCompletedEvent -from dify_graph.nodes.llm.node import LLMNode -from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory, TemplateRenderer -from dify_graph.nodes.protocols import HttpClientProtocol -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable +from core.workflow.system_variables import build_system_variables from extensions.ext_database import db +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.node_events import StreamCompletedEvent +from graphon.nodes.llm.file_saver import LLMFileSaver +from graphon.nodes.llm.node import LLMNode +from graphon.nodes.llm.protocols import CredentialsProvider, ModelFactory +from graphon.nodes.llm.runtime_protocols import PromptMessageSerializerProtocol +from graphon.nodes.protocols import HttpClientProtocol +from graphon.runtime import GraphRuntimeState, VariablePool from tests.workflow_test_utils import build_test_graph_init_params """FOR MOCK FIXTURES, DO NOT REMOVE""" @@ -51,7 +53,7 @@ def init_llm_node(config: dict) -> LLMNode: # construct variable pool variable_pool = VariablePool( - system_variables=SystemVariable( + system_variables=build_system_variables( user_id="aaa", app_id=app_id, workflow_id=workflow_id, @@ -66,6 +68,11 @@ def init_llm_node(config: dict) -> LLMNode: variable_pool.add(["abc", "output"], "sunny") graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) + prompt_message_serializer = MagicMock(spec=PromptMessageSerializerProtocol) + prompt_message_serializer.serialize.side_effect = lambda *, model_mode, prompt_messages: [ + message.model_dump(mode="json") for message in prompt_messages + ] + llm_file_saver = MagicMock(spec=LLMFileSaver) node = LLMNode( id=str(uuid.uuid4()), @@ -75,7 +82,8 @@ def init_llm_node(config: dict) -> LLMNode: credentials_provider=MagicMock(spec=CredentialsProvider), model_factory=MagicMock(spec=ModelFactory), model_instance=MagicMock(spec=ModelInstance), - template_renderer=MagicMock(spec=TemplateRenderer), + llm_file_saver=llm_file_saver, + prompt_message_serializer=prompt_message_serializer, http_client=MagicMock(spec=HttpClientProtocol), ) @@ -115,8 +123,8 @@ def test_execute_llm(): from decimal import Decimal from unittest.mock import MagicMock - from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMUsage - from dify_graph.model_runtime.entities.message_entities import AssistantPromptMessage + from graphon.model_runtime.entities.llm_entities import LLMResult, LLMUsage + from graphon.model_runtime.entities.message_entities import AssistantPromptMessage # Create mock model instance mock_model_instance = MagicMock(spec=ModelInstance) @@ -159,8 +167,8 @@ def test_execute_llm(): return mock_model_instance # Mock fetch_prompt_messages to avoid database calls - def mock_fetch_prompt_messages_1(*_args, **_kwargs): - from dify_graph.model_runtime.entities.message_entities import SystemPromptMessage, UserPromptMessage + def mock_fetch_prompt_messages_1(**_kwargs): + from graphon.model_runtime.entities.message_entities import SystemPromptMessage, UserPromptMessage return [ SystemPromptMessage(content="you are a helpful assistant. today's weather is sunny."), @@ -231,8 +239,8 @@ def test_execute_llm_with_jinja2(): from decimal import Decimal from unittest.mock import MagicMock - from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMUsage - from dify_graph.model_runtime.entities.message_entities import AssistantPromptMessage + from graphon.model_runtime.entities.llm_entities import LLMResult, LLMUsage + from graphon.model_runtime.entities.message_entities import AssistantPromptMessage # Create mock model instance mock_model_instance = MagicMock(spec=ModelInstance) @@ -276,7 +284,7 @@ def test_execute_llm_with_jinja2(): # Mock fetch_prompt_messages to avoid database calls def mock_fetch_prompt_messages_2(**_kwargs): - from dify_graph.model_runtime.entities.message_entities import SystemPromptMessage, UserPromptMessage + from graphon.model_runtime.entities.message_entities import SystemPromptMessage, UserPromptMessage return [ SystemPromptMessage(content="you are a helpful assistant. today's weather is sunny."), diff --git a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py index 62d9af0196..367b5bbc11 100644 --- a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py +++ b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py @@ -5,13 +5,14 @@ from unittest.mock import MagicMock from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.model_manager import ModelInstance -from dify_graph.enums import WorkflowNodeExecutionStatus -from dify_graph.model_runtime.entities import AssistantPromptMessage, UserPromptMessage -from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory -from dify_graph.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable +from core.workflow.node_runtime import DifyPromptMessageSerializer +from core.workflow.system_variables import build_system_variables from extensions.ext_database import db +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.model_runtime.entities import AssistantPromptMessage, UserPromptMessage +from graphon.nodes.llm.protocols import CredentialsProvider, ModelFactory +from graphon.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode +from graphon.runtime import GraphRuntimeState, VariablePool from tests.integration_tests.workflow.nodes.__mock.model import get_mocked_fetch_model_instance from tests.workflow_test_utils import build_test_graph_init_params @@ -56,7 +57,7 @@ def init_parameter_extractor_node(config: dict, memory=None): # construct variable pool variable_pool = VariablePool( - system_variables=SystemVariable( + system_variables=build_system_variables( user_id="aaa", files=[], query="what's the weather in SF", conversation_id="abababa" ), user_inputs={}, @@ -77,6 +78,7 @@ def init_parameter_extractor_node(config: dict, memory=None): model_factory=MagicMock(spec=ModelFactory), model_instance=MagicMock(spec=ModelInstance), memory=memory, + prompt_message_serializer=DifyPromptMessageSerializer(), ) return node diff --git a/api/tests/integration_tests/workflow/nodes/test_template_transform.py b/api/tests/integration_tests/workflow/nodes/test_template_transform.py index 7bb4f905c3..9e3e1a47e3 100644 --- a/api/tests/integration_tests/workflow/nodes/test_template_transform.py +++ b/api/tests/integration_tests/workflow/nodes/test_template_transform.py @@ -3,12 +3,12 @@ import uuid from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.workflow.node_factory import DifyNodeFactory -from dify_graph.enums import WorkflowNodeExecutionStatus -from dify_graph.graph import Graph -from dify_graph.nodes.template_transform.template_renderer import TemplateRenderError -from dify_graph.nodes.template_transform.template_transform_node import TemplateTransformNode -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable +from core.workflow.system_variables import build_system_variables +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.graph import Graph +from graphon.nodes.template_transform.template_transform_node import TemplateTransformNode +from graphon.runtime import GraphRuntimeState, VariablePool +from graphon.template_rendering import TemplateRenderError from tests.workflow_test_utils import build_test_graph_init_params @@ -66,7 +66,7 @@ def test_execute_template_transform(): # construct variable pool variable_pool = VariablePool( - system_variables=SystemVariable(user_id="aaa", files=[]), + system_variables=build_system_variables(user_id="aaa", files=[]), user_inputs={}, environment_variables=[], conversation_variables=[], @@ -90,7 +90,7 @@ def test_execute_template_transform(): config=config, graph_init_params=init_params, graph_runtime_state=graph_runtime_state, - template_renderer=_SimpleJinja2Renderer(), + jinja2_template_renderer=_SimpleJinja2Renderer(), ) # execute node diff --git a/api/tests/integration_tests/workflow/nodes/test_tool.py b/api/tests/integration_tests/workflow/nodes/test_tool.py index a6717ada31..f9ec51ee10 100644 --- a/api/tests/integration_tests/workflow/nodes/test_tool.py +++ b/api/tests/integration_tests/workflow/nodes/test_tool.py @@ -5,13 +5,14 @@ from unittest.mock import MagicMock, patch from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.tools.utils.configuration import ToolParameterConfigurationManager from core.workflow.node_factory import DifyNodeFactory -from dify_graph.enums import WorkflowNodeExecutionStatus -from dify_graph.graph import Graph -from dify_graph.node_events import StreamCompletedEvent -from dify_graph.nodes.protocols import ToolFileManagerProtocol -from dify_graph.nodes.tool.tool_node import ToolNode -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable +from core.workflow.node_runtime import DifyToolNodeRuntime +from core.workflow.system_variables import build_system_variables +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.graph import Graph +from graphon.node_events import StreamCompletedEvent +from graphon.nodes.protocols import ToolFileManagerProtocol +from graphon.nodes.tool.tool_node import ToolNode +from graphon.runtime import GraphRuntimeState, VariablePool from tests.workflow_test_utils import build_test_graph_init_params @@ -40,7 +41,7 @@ def init_tool_node(config: dict): # construct variable pool variable_pool = VariablePool( - system_variables=SystemVariable(user_id="aaa", files=[]), + system_variables=build_system_variables(user_id="aaa", files=[]), user_inputs={}, environment_variables=[], conversation_variables=[], @@ -64,11 +65,12 @@ def init_tool_node(config: dict): graph_init_params=init_params, graph_runtime_state=graph_runtime_state, tool_file_manager_factory=tool_file_manager_factory, + runtime=DifyToolNodeRuntime(init_params.run_context), ) return node -def test_tool_variable_invoke(): +def test_tool_variable_invoke(monkeypatch): node = init_tool_node( config={ "id": "1", @@ -103,7 +105,7 @@ def test_tool_variable_invoke(): assert item.node_run_result.outputs.get("text") is not None -def test_tool_mixed_invoke(): +def test_tool_mixed_invoke(monkeypatch): node = init_tool_node( config={ "id": "1", diff --git a/api/tests/test_containers_integration_tests/conftest.py b/api/tests/test_containers_integration_tests/conftest.py index ef0ca4232d..48bf3ca446 100644 --- a/api/tests/test_containers_integration_tests/conftest.py +++ b/api/tests/test_containers_integration_tests/conftest.py @@ -33,6 +33,9 @@ from extensions.ext_database import db logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s") logger = logging.getLogger(__name__) +DEFAULT_SANDBOX_TEST_IMAGE = "langgenius/dify-sandbox:0.2.14" +SANDBOX_TEST_IMAGE_ENV = "DIFY_SANDBOX_TEST_IMAGE" + class _CloserProtocol(Protocol): """_Closer is any type which implement the close() method.""" @@ -163,11 +166,11 @@ class DifyTestContainers: wait_for_logs(self.redis, "Ready to accept connections", timeout=30) logger.info("Redis container is ready and accepting connections") - # Start Dify Sandbox container for code execution environment - # Dify Sandbox provides a secure environment for executing user code - # Use pinned version 0.2.12 to match production docker-compose configuration + # Start Dify Sandbox container for code execution environment. + # Default to the production-pinned image while allowing local overrides for debugging. logger.info("Initializing Dify Sandbox container...") - self.dify_sandbox = DockerContainer(image="langgenius/dify-sandbox:0.2.12").with_network(self.network) + sandbox_image = os.getenv(SANDBOX_TEST_IMAGE_ENV, DEFAULT_SANDBOX_TEST_IMAGE) + self.dify_sandbox = DockerContainer(image=sandbox_image).with_network(self.network) self.dify_sandbox.with_exposed_ports(8194) self.dify_sandbox.env = { "API_KEY": "test_api_key", @@ -177,7 +180,12 @@ class DifyTestContainers: sandbox_port = self.dify_sandbox.get_exposed_port(8194) os.environ["CODE_EXECUTION_ENDPOINT"] = f"http://{sandbox_host}:{sandbox_port}" os.environ["CODE_EXECUTION_API_KEY"] = "test_api_key" - logger.info("Dify Sandbox container started successfully - Host: %s, Port: %s", sandbox_host, sandbox_port) + logger.info( + "Dify Sandbox container started successfully - Image: %s Host: %s, Port: %s", + sandbox_image, + sandbox_host, + sandbox_port, + ) # Wait for Dify Sandbox to be ready logger.info("Waiting for Dify Sandbox to be ready to accept connections...") @@ -187,7 +195,7 @@ class DifyTestContainers: # Start Dify Plugin Daemon container for plugin management # Dify Plugin Daemon provides plugin lifecycle management and execution logger.info("Initializing Dify Plugin Daemon container...") - self.dify_plugin_daemon = DockerContainer(image="langgenius/dify-plugin-daemon:0.5.4-local").with_network( + self.dify_plugin_daemon = DockerContainer(image="langgenius/dify-plugin-daemon:0.5.3-local").with_network( self.network ) self.dify_plugin_daemon.with_exposed_ports(5002) diff --git a/api/tests/test_containers_integration_tests/controllers/console/app/test_chat_conversation_status_count_api.py b/api/tests/test_containers_integration_tests/controllers/console/app/test_chat_conversation_status_count_api.py index 4f606dccb8..5b51510388 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/app/test_chat_conversation_status_count_api.py +++ b/api/tests/test_containers_integration_tests/controllers/console/app/test_chat_conversation_status_count_api.py @@ -8,7 +8,7 @@ from sqlalchemy.orm import Session from configs import dify_config from constants import HEADER_NAME_CSRF_TOKEN -from dify_graph.enums import WorkflowExecutionStatus +from graphon.enums import WorkflowExecutionStatus from libs.datetime_utils import naive_utc_now from libs.token import _real_cookie_name, generate_csrf_token from models import Account, DifySetup, Tenant, TenantAccountJoin diff --git a/api/tests/test_containers_integration_tests/controllers/console/app/test_message.py b/api/tests/test_containers_integration_tests/controllers/console/app/test_message.py new file mode 100644 index 0000000000..6b51ec98bc --- /dev/null +++ b/api/tests/test_containers_integration_tests/controllers/console/app/test_message.py @@ -0,0 +1,342 @@ +"""Authenticated controller integration tests for console message APIs.""" + +from datetime import timedelta +from decimal import Decimal +from unittest.mock import patch +from uuid import uuid4 + +import pytest +from flask.testing import FlaskClient +from sqlalchemy import select +from sqlalchemy.orm import Session + +from controllers.console.app.message import ChatMessagesQuery, FeedbackExportQuery, MessageFeedbackPayload +from controllers.console.app.message import attach_message_extra_contents as _attach_message_extra_contents +from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError +from libs.datetime_utils import naive_utc_now +from models.enums import ConversationFromSource, FeedbackRating +from models.model import AppMode, Conversation, Message, MessageAnnotation, MessageFeedback +from services.errors.conversation import ConversationNotExistsError +from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError +from tests.test_containers_integration_tests.controllers.console.helpers import ( + authenticate_console_client, + create_console_account_and_tenant, + create_console_app, +) + + +def _create_conversation(db_session: Session, app_id: str, account_id: str, mode: AppMode) -> Conversation: + conversation = Conversation( + app_id=app_id, + app_model_config_id=None, + model_provider=None, + model_id="", + override_model_configs=None, + mode=mode, + name="Test Conversation", + inputs={}, + introduction="", + system_instruction="", + system_instruction_tokens=0, + status="normal", + from_source=ConversationFromSource.CONSOLE, + from_account_id=account_id, + ) + db_session.add(conversation) + db_session.commit() + return conversation + + +def _create_message( + db_session: Session, + app_id: str, + conversation_id: str, + account_id: str, + *, + created_at_offset_seconds: int = 0, +) -> Message: + created_at = naive_utc_now() + timedelta(seconds=created_at_offset_seconds) + message = Message( + app_id=app_id, + model_provider=None, + model_id="", + override_model_configs=None, + conversation_id=conversation_id, + inputs={}, + query="Hello", + message={"type": "text", "content": "Hello"}, + message_tokens=1, + message_unit_price=Decimal("0.0001"), + message_price_unit=Decimal("0.001"), + answer="Hi there", + answer_tokens=1, + answer_unit_price=Decimal("0.0001"), + answer_price_unit=Decimal("0.001"), + parent_message_id=None, + provider_response_latency=0, + total_price=Decimal("0.0002"), + currency="USD", + from_source=ConversationFromSource.CONSOLE, + from_account_id=account_id, + created_at=created_at, + updated_at=created_at, + app_mode=AppMode.CHAT, + ) + db_session.add(message) + db_session.commit() + return message + + +class TestMessageValidators: + def test_chat_messages_query_validators(self) -> None: + assert ChatMessagesQuery.empty_to_none("") is None + assert ChatMessagesQuery.empty_to_none("val") == "val" + assert ChatMessagesQuery.validate_uuid(None) is None + assert ( + ChatMessagesQuery.validate_uuid("123e4567-e89b-12d3-a456-426614174000") + == "123e4567-e89b-12d3-a456-426614174000" + ) + + def test_message_feedback_validators(self) -> None: + assert ( + MessageFeedbackPayload.validate_message_id("123e4567-e89b-12d3-a456-426614174000") + == "123e4567-e89b-12d3-a456-426614174000" + ) + + def test_feedback_export_validators(self) -> None: + assert FeedbackExportQuery.parse_bool(None) is None + assert FeedbackExportQuery.parse_bool(True) is True + assert FeedbackExportQuery.parse_bool("1") is True + assert FeedbackExportQuery.parse_bool("0") is False + assert FeedbackExportQuery.parse_bool("off") is False + + with pytest.raises(ValueError): + FeedbackExportQuery.parse_bool("invalid") + + +def test_chat_message_list_not_found( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT) + + response = test_client_with_containers.get( + f"/console/api/apps/{app.id}/chat-messages", + query_string={"conversation_id": str(uuid4())}, + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 404 + payload = response.get_json() + assert payload is not None + assert payload["code"] == "not_found" + + +def test_chat_message_list_success( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT) + conversation = _create_conversation(db_session_with_containers, app.id, account.id, app.mode) + _create_message(db_session_with_containers, app.id, conversation.id, account.id, created_at_offset_seconds=0) + second = _create_message( + db_session_with_containers, + app.id, + conversation.id, + account.id, + created_at_offset_seconds=1, + ) + + with patch( + "controllers.console.app.message.attach_message_extra_contents", + side_effect=_attach_message_extra_contents, + ): + response = test_client_with_containers.get( + f"/console/api/apps/{app.id}/chat-messages", + query_string={"conversation_id": conversation.id, "limit": 1}, + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 200 + payload = response.get_json() + assert payload is not None + assert payload["limit"] == 1 + assert payload["has_more"] is True + assert len(payload["data"]) == 1 + assert payload["data"][0]["id"] == second.id + + +def test_message_feedback_not_found( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT) + + response = test_client_with_containers.post( + f"/console/api/apps/{app.id}/feedbacks", + json={"message_id": str(uuid4()), "rating": "like"}, + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 404 + payload = response.get_json() + assert payload is not None + assert payload["code"] == "not_found" + + +def test_message_feedback_success( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT) + conversation = _create_conversation(db_session_with_containers, app.id, account.id, app.mode) + message = _create_message(db_session_with_containers, app.id, conversation.id, account.id) + + response = test_client_with_containers.post( + f"/console/api/apps/{app.id}/feedbacks", + json={"message_id": message.id, "rating": "like"}, + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 200 + assert response.get_json() == {"result": "success"} + + feedback = db_session_with_containers.scalar( + select(MessageFeedback).where(MessageFeedback.message_id == message.id) + ) + assert feedback is not None + assert feedback.rating == FeedbackRating.LIKE + assert feedback.from_account_id == account.id + + +def test_message_annotation_count( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT) + conversation = _create_conversation(db_session_with_containers, app.id, account.id, app.mode) + message = _create_message(db_session_with_containers, app.id, conversation.id, account.id) + db_session_with_containers.add( + MessageAnnotation( + app_id=app.id, + conversation_id=conversation.id, + message_id=message.id, + question="Q", + content="A", + account_id=account.id, + ) + ) + db_session_with_containers.commit() + + response = test_client_with_containers.get( + f"/console/api/apps/{app.id}/annotations/count", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 200 + assert response.get_json() == {"count": 1} + + +def test_message_suggested_questions_success( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT) + message_id = str(uuid4()) + + with patch( + "controllers.console.app.message.MessageService.get_suggested_questions_after_answer", + return_value=["q1", "q2"], + ): + response = test_client_with_containers.get( + f"/console/api/apps/{app.id}/chat-messages/{message_id}/suggested-questions", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 200 + assert response.get_json() == {"data": ["q1", "q2"]} + + +@pytest.mark.parametrize( + ("exc", "expected_status", "expected_code"), + [ + (MessageNotExistsError(), 404, "not_found"), + (ConversationNotExistsError(), 404, "not_found"), + (ProviderTokenNotInitError(), 400, "provider_not_initialize"), + (QuotaExceededError(), 400, "provider_quota_exceeded"), + (ModelCurrentlyNotSupportError(), 400, "model_currently_not_support"), + (SuggestedQuestionsAfterAnswerDisabledError(), 403, "app_suggested_questions_after_answer_disabled"), + (Exception(), 500, "internal_server_error"), + ], +) +def test_message_suggested_questions_errors( + exc: Exception, + expected_status: int, + expected_code: str, + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT) + message_id = str(uuid4()) + + with patch( + "controllers.console.app.message.MessageService.get_suggested_questions_after_answer", + side_effect=exc, + ): + response = test_client_with_containers.get( + f"/console/api/apps/{app.id}/chat-messages/{message_id}/suggested-questions", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == expected_status + payload = response.get_json() + assert payload is not None + assert payload["code"] == expected_code + + +def test_message_feedback_export_success( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT) + + with patch("services.feedback_service.FeedbackService.export_feedbacks", return_value={"exported": True}): + response = test_client_with_containers.get( + f"/console/api/apps/{app.id}/feedbacks/export", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 200 + assert response.get_json() == {"exported": True} + + +def test_message_api_get_success( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT) + conversation = _create_conversation(db_session_with_containers, app.id, account.id, app.mode) + message = _create_message(db_session_with_containers, app.id, conversation.id, account.id) + + with patch( + "controllers.console.app.message.attach_message_extra_contents", + side_effect=_attach_message_extra_contents, + ): + response = test_client_with_containers.get( + f"/console/api/apps/{app.id}/messages/{message.id}", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 200 + payload = response.get_json() + assert payload is not None + assert payload["id"] == message.id diff --git a/api/tests/test_containers_integration_tests/controllers/console/app/test_statistic.py b/api/tests/test_containers_integration_tests/controllers/console/app/test_statistic.py new file mode 100644 index 0000000000..963cfe53e5 --- /dev/null +++ b/api/tests/test_containers_integration_tests/controllers/console/app/test_statistic.py @@ -0,0 +1,334 @@ +"""Controller integration tests for console statistic routes.""" + +from datetime import timedelta +from decimal import Decimal +from unittest.mock import patch +from uuid import uuid4 + +from flask.testing import FlaskClient +from sqlalchemy.orm import Session + +from core.app.entities.app_invoke_entities import InvokeFrom +from libs.datetime_utils import naive_utc_now +from models.enums import ConversationFromSource, FeedbackFromSource, FeedbackRating +from models.model import AppMode, Conversation, Message, MessageFeedback +from tests.test_containers_integration_tests.controllers.console.helpers import ( + authenticate_console_client, + create_console_account_and_tenant, + create_console_app, +) + + +def _create_conversation( + db_session: Session, + app_id: str, + account_id: str, + *, + mode: AppMode, + created_at_offset_days: int = 0, +) -> Conversation: + created_at = naive_utc_now() + timedelta(days=created_at_offset_days) + conversation = Conversation( + app_id=app_id, + app_model_config_id=None, + model_provider=None, + model_id="", + override_model_configs=None, + mode=mode, + name="Stats Conversation", + inputs={}, + introduction="", + system_instruction="", + system_instruction_tokens=0, + status="normal", + from_source=ConversationFromSource.CONSOLE, + from_account_id=account_id, + created_at=created_at, + updated_at=created_at, + ) + db_session.add(conversation) + db_session.commit() + return conversation + + +def _create_message( + db_session: Session, + app_id: str, + conversation_id: str, + *, + from_account_id: str | None, + from_end_user_id: str | None = None, + message_tokens: int = 1, + answer_tokens: int = 1, + total_price: Decimal = Decimal("0.01"), + provider_response_latency: float = 1.0, + created_at_offset_days: int = 0, +) -> Message: + created_at = naive_utc_now() + timedelta(days=created_at_offset_days) + message = Message( + app_id=app_id, + model_provider=None, + model_id="", + override_model_configs=None, + conversation_id=conversation_id, + inputs={}, + query="Hello", + message={"type": "text", "content": "Hello"}, + message_tokens=message_tokens, + message_unit_price=Decimal("0.001"), + message_price_unit=Decimal("0.001"), + answer="Hi there", + answer_tokens=answer_tokens, + answer_unit_price=Decimal("0.001"), + answer_price_unit=Decimal("0.001"), + parent_message_id=None, + provider_response_latency=provider_response_latency, + total_price=total_price, + currency="USD", + invoke_from=InvokeFrom.EXPLORE, + from_source=ConversationFromSource.CONSOLE, + from_end_user_id=from_end_user_id, + from_account_id=from_account_id, + created_at=created_at, + updated_at=created_at, + app_mode=AppMode.CHAT, + ) + db_session.add(message) + db_session.commit() + return message + + +def _create_like_feedback( + db_session: Session, + app_id: str, + conversation_id: str, + message_id: str, + account_id: str, +) -> None: + db_session.add( + MessageFeedback( + app_id=app_id, + conversation_id=conversation_id, + message_id=message_id, + rating=FeedbackRating.LIKE, + from_source=FeedbackFromSource.ADMIN, + from_account_id=account_id, + ) + ) + db_session.commit() + + +def test_daily_message_statistic( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT) + conversation = _create_conversation(db_session_with_containers, app.id, account.id, mode=app.mode) + _create_message(db_session_with_containers, app.id, conversation.id, from_account_id=account.id) + + response = test_client_with_containers.get( + f"/console/api/apps/{app.id}/statistics/daily-messages", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 200 + assert response.get_json()["data"][0]["message_count"] == 1 + + +def test_daily_conversation_statistic( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT) + conversation = _create_conversation(db_session_with_containers, app.id, account.id, mode=app.mode) + _create_message(db_session_with_containers, app.id, conversation.id, from_account_id=account.id) + _create_message(db_session_with_containers, app.id, conversation.id, from_account_id=account.id) + + response = test_client_with_containers.get( + f"/console/api/apps/{app.id}/statistics/daily-conversations", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 200 + assert response.get_json()["data"][0]["conversation_count"] == 1 + + +def test_daily_terminals_statistic( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT) + conversation = _create_conversation(db_session_with_containers, app.id, account.id, mode=app.mode) + _create_message( + db_session_with_containers, + app.id, + conversation.id, + from_account_id=None, + from_end_user_id=str(uuid4()), + ) + + response = test_client_with_containers.get( + f"/console/api/apps/{app.id}/statistics/daily-end-users", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 200 + assert response.get_json()["data"][0]["terminal_count"] == 1 + + +def test_daily_token_cost_statistic( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT) + conversation = _create_conversation(db_session_with_containers, app.id, account.id, mode=app.mode) + _create_message( + db_session_with_containers, + app.id, + conversation.id, + from_account_id=account.id, + message_tokens=40, + answer_tokens=60, + total_price=Decimal("0.02"), + ) + + response = test_client_with_containers.get( + f"/console/api/apps/{app.id}/statistics/token-costs", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 200 + payload = response.get_json() + assert payload["data"][0]["token_count"] == 100 + assert Decimal(payload["data"][0]["total_price"]) == Decimal("0.02") + + +def test_average_session_interaction_statistic( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT) + conversation = _create_conversation(db_session_with_containers, app.id, account.id, mode=app.mode) + _create_message(db_session_with_containers, app.id, conversation.id, from_account_id=account.id) + _create_message(db_session_with_containers, app.id, conversation.id, from_account_id=account.id) + + response = test_client_with_containers.get( + f"/console/api/apps/{app.id}/statistics/average-session-interactions", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 200 + assert response.get_json()["data"][0]["interactions"] == 2.0 + + +def test_user_satisfaction_rate_statistic( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT) + conversation = _create_conversation(db_session_with_containers, app.id, account.id, mode=app.mode) + first = _create_message(db_session_with_containers, app.id, conversation.id, from_account_id=account.id) + for _ in range(9): + _create_message(db_session_with_containers, app.id, conversation.id, from_account_id=account.id) + _create_like_feedback(db_session_with_containers, app.id, conversation.id, first.id, account.id) + + response = test_client_with_containers.get( + f"/console/api/apps/{app.id}/statistics/user-satisfaction-rate", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 200 + assert response.get_json()["data"][0]["rate"] == 100.0 + + +def test_average_response_time_statistic( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.COMPLETION) + conversation = _create_conversation(db_session_with_containers, app.id, account.id, mode=app.mode) + _create_message( + db_session_with_containers, + app.id, + conversation.id, + from_account_id=account.id, + provider_response_latency=1.234, + ) + + response = test_client_with_containers.get( + f"/console/api/apps/{app.id}/statistics/average-response-time", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 200 + assert response.get_json()["data"][0]["latency"] == 1234.0 + + +def test_tokens_per_second_statistic( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT) + conversation = _create_conversation(db_session_with_containers, app.id, account.id, mode=app.mode) + _create_message( + db_session_with_containers, + app.id, + conversation.id, + from_account_id=account.id, + answer_tokens=31, + provider_response_latency=2.0, + ) + + response = test_client_with_containers.get( + f"/console/api/apps/{app.id}/statistics/tokens-per-second", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 200 + assert response.get_json()["data"][0]["tps"] == 15.5 + + +def test_invalid_time_range( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT) + + with patch("controllers.console.app.statistic.parse_time_range", side_effect=ValueError("Invalid time")): + response = test_client_with_containers.get( + f"/console/api/apps/{app.id}/statistics/daily-messages?start=invalid&end=invalid", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 400 + assert response.get_json()["message"] == "Invalid time" + + +def test_time_range_params_passed( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + import datetime + + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT) + start = datetime.datetime.now() + end = datetime.datetime.now() + + with patch("controllers.console.app.statistic.parse_time_range", return_value=(start, end)) as mock_parse: + response = test_client_with_containers.get( + f"/console/api/apps/{app.id}/statistics/daily-messages?start=something&end=something", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 200 + mock_parse.assert_called_once_with("something", "something", "UTC") diff --git a/api/tests/test_containers_integration_tests/controllers/console/app/test_workflow_draft_variable.py b/api/tests/test_containers_integration_tests/controllers/console/app/test_workflow_draft_variable.py new file mode 100644 index 0000000000..290be87697 --- /dev/null +++ b/api/tests/test_containers_integration_tests/controllers/console/app/test_workflow_draft_variable.py @@ -0,0 +1,415 @@ +"""Authenticated controller integration tests for workflow draft variable APIs.""" + +import uuid + +from flask.testing import FlaskClient +from sqlalchemy import select +from sqlalchemy.orm import Session + +from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID +from factories.variable_factory import segment_to_variable +from graphon.variables.segments import StringSegment +from models import Workflow +from models.model import AppMode +from models.workflow import WorkflowDraftVariable +from tests.test_containers_integration_tests.controllers.console.helpers import ( + authenticate_console_client, + create_console_account_and_tenant, + create_console_app, +) + + +def _create_draft_workflow( + db_session: Session, + app_id: str, + tenant_id: str, + account_id: str, + *, + environment_variables: list | None = None, + conversation_variables: list | None = None, +) -> Workflow: + workflow = Workflow.new( + tenant_id=tenant_id, + app_id=app_id, + type="workflow", + version=Workflow.VERSION_DRAFT, + graph='{"nodes": [], "edges": []}', + features="{}", + created_by=account_id, + environment_variables=environment_variables or [], + conversation_variables=conversation_variables or [], + rag_pipeline_variables=[], + ) + db_session.add(workflow) + db_session.commit() + return workflow + + +def _create_node_variable( + db_session: Session, + app_id: str, + user_id: str, + *, + node_id: str = "node_1", + name: str = "test_var", +) -> WorkflowDraftVariable: + variable = WorkflowDraftVariable.new_node_variable( + app_id=app_id, + user_id=user_id, + node_id=node_id, + name=name, + value=StringSegment(value="test_value"), + node_execution_id=str(uuid.uuid4()), + visible=True, + editable=True, + ) + db_session.add(variable) + db_session.commit() + return variable + + +def _create_system_variable( + db_session: Session, app_id: str, user_id: str, name: str = "query" +) -> WorkflowDraftVariable: + variable = WorkflowDraftVariable.new_sys_variable( + app_id=app_id, + user_id=user_id, + name=name, + value=StringSegment(value="system-value"), + node_execution_id=str(uuid.uuid4()), + editable=True, + ) + db_session.add(variable) + db_session.commit() + return variable + + +def _build_environment_variable(name: str, value: str): + return segment_to_variable( + segment=StringSegment(value=value), + selector=[ENVIRONMENT_VARIABLE_NODE_ID, name], + name=name, + description=f"Environment variable {name}", + ) + + +def _build_conversation_variable(name: str, value: str): + return segment_to_variable( + segment=StringSegment(value=value), + selector=[CONVERSATION_VARIABLE_NODE_ID, name], + name=name, + description=f"Conversation variable {name}", + ) + + +def test_workflow_variable_collection_get_success( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.WORKFLOW) + _create_draft_workflow(db_session_with_containers, app.id, tenant.id, account.id) + + response = test_client_with_containers.get( + f"/console/api/apps/{app.id}/workflows/draft/variables?page=1&limit=20", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 200 + assert response.get_json() == {"items": [], "total": 0} + + +def test_workflow_variable_collection_get_not_exist( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.WORKFLOW) + + response = test_client_with_containers.get( + f"/console/api/apps/{app.id}/workflows/draft/variables", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 404 + payload = response.get_json() + assert payload is not None + assert payload["code"] == "draft_workflow_not_exist" + + +def test_workflow_variable_collection_delete( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.WORKFLOW) + _create_node_variable(db_session_with_containers, app.id, account.id) + _create_node_variable(db_session_with_containers, app.id, account.id, node_id="node_2", name="other_var") + + response = test_client_with_containers.delete( + f"/console/api/apps/{app.id}/workflows/draft/variables", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 204 + remaining = db_session_with_containers.scalars( + select(WorkflowDraftVariable).where( + WorkflowDraftVariable.app_id == app.id, + WorkflowDraftVariable.user_id == account.id, + ) + ).all() + assert remaining == [] + + +def test_node_variable_collection_get_success( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.WORKFLOW) + node_variable = _create_node_variable(db_session_with_containers, app.id, account.id, node_id="node_123") + _create_node_variable(db_session_with_containers, app.id, account.id, node_id="node_456", name="other") + + response = test_client_with_containers.get( + f"/console/api/apps/{app.id}/workflows/draft/nodes/node_123/variables", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 200 + payload = response.get_json() + assert payload is not None + assert [item["id"] for item in payload["items"]] == [node_variable.id] + + +def test_node_variable_collection_get_invalid_node_id( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.WORKFLOW) + + response = test_client_with_containers.get( + f"/console/api/apps/{app.id}/workflows/draft/nodes/sys/variables", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 400 + payload = response.get_json() + assert payload is not None + assert payload["code"] == "invalid_param" + + +def test_node_variable_collection_delete( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.WORKFLOW) + target = _create_node_variable(db_session_with_containers, app.id, account.id, node_id="node_123") + untouched = _create_node_variable(db_session_with_containers, app.id, account.id, node_id="node_456") + target_id = target.id + untouched_id = untouched.id + + response = test_client_with_containers.delete( + f"/console/api/apps/{app.id}/workflows/draft/nodes/node_123/variables", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 204 + assert ( + db_session_with_containers.scalar(select(WorkflowDraftVariable).where(WorkflowDraftVariable.id == target_id)) + is None + ) + assert ( + db_session_with_containers.scalar(select(WorkflowDraftVariable).where(WorkflowDraftVariable.id == untouched_id)) + is not None + ) + + +def test_variable_api_get_success( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.WORKFLOW) + _create_draft_workflow(db_session_with_containers, app.id, tenant.id, account.id) + variable = _create_node_variable(db_session_with_containers, app.id, account.id) + + response = test_client_with_containers.get( + f"/console/api/apps/{app.id}/workflows/draft/variables/{variable.id}", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 200 + payload = response.get_json() + assert payload is not None + assert payload["id"] == variable.id + assert payload["name"] == "test_var" + + +def test_variable_api_get_not_found( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.WORKFLOW) + _create_draft_workflow(db_session_with_containers, app.id, tenant.id, account.id) + + response = test_client_with_containers.get( + f"/console/api/apps/{app.id}/workflows/draft/variables/{uuid.uuid4()}", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 404 + payload = response.get_json() + assert payload is not None + assert payload["code"] == "not_found" + + +def test_variable_api_patch_success( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.WORKFLOW) + _create_draft_workflow(db_session_with_containers, app.id, tenant.id, account.id) + variable = _create_node_variable(db_session_with_containers, app.id, account.id) + + response = test_client_with_containers.patch( + f"/console/api/apps/{app.id}/workflows/draft/variables/{variable.id}", + headers=authenticate_console_client(test_client_with_containers, account), + json={"name": "renamed_var"}, + ) + + assert response.status_code == 200 + payload = response.get_json() + assert payload is not None + assert payload["id"] == variable.id + assert payload["name"] == "renamed_var" + + refreshed = db_session_with_containers.scalar( + select(WorkflowDraftVariable).where(WorkflowDraftVariable.id == variable.id) + ) + assert refreshed is not None + assert refreshed.name == "renamed_var" + + +def test_variable_api_delete_success( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.WORKFLOW) + _create_draft_workflow(db_session_with_containers, app.id, tenant.id, account.id) + variable = _create_node_variable(db_session_with_containers, app.id, account.id) + + response = test_client_with_containers.delete( + f"/console/api/apps/{app.id}/workflows/draft/variables/{variable.id}", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 204 + assert ( + db_session_with_containers.scalar(select(WorkflowDraftVariable).where(WorkflowDraftVariable.id == variable.id)) + is None + ) + + +def test_variable_reset_api_put_success_returns_no_content_without_execution( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.WORKFLOW) + _create_draft_workflow(db_session_with_containers, app.id, tenant.id, account.id) + variable = _create_node_variable(db_session_with_containers, app.id, account.id) + + response = test_client_with_containers.put( + f"/console/api/apps/{app.id}/workflows/draft/variables/{variable.id}/reset", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 204 + assert ( + db_session_with_containers.scalar(select(WorkflowDraftVariable).where(WorkflowDraftVariable.id == variable.id)) + is None + ) + + +def test_conversation_variable_collection_get( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.WORKFLOW) + _create_draft_workflow( + db_session_with_containers, + app.id, + tenant.id, + account.id, + conversation_variables=[_build_conversation_variable("session_name", "Alice")], + ) + + response = test_client_with_containers.get( + f"/console/api/apps/{app.id}/workflows/draft/conversation-variables", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 200 + payload = response.get_json() + assert payload is not None + assert [item["name"] for item in payload["items"]] == ["session_name"] + + created = db_session_with_containers.scalars( + select(WorkflowDraftVariable).where( + WorkflowDraftVariable.app_id == app.id, + WorkflowDraftVariable.user_id == account.id, + WorkflowDraftVariable.node_id == CONVERSATION_VARIABLE_NODE_ID, + ) + ).all() + assert len(created) == 1 + + +def test_system_variable_collection_get( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.WORKFLOW) + variable = _create_system_variable(db_session_with_containers, app.id, account.id) + + response = test_client_with_containers.get( + f"/console/api/apps/{app.id}/workflows/draft/system-variables", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 200 + payload = response.get_json() + assert payload is not None + assert [item["id"] for item in payload["items"]] == [variable.id] + + +def test_environment_variable_collection_get( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.WORKFLOW) + _create_draft_workflow( + db_session_with_containers, + app.id, + tenant.id, + account.id, + environment_variables=[_build_environment_variable("api_key", "secret-value")], + ) + + response = test_client_with_containers.get( + f"/console/api/apps/{app.id}/workflows/draft/environment-variables", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 200 + payload = response.get_json() + assert payload is not None + assert payload["items"][0]["name"] == "api_key" + assert payload["items"][0]["value"] == "secret-value" diff --git a/api/tests/test_containers_integration_tests/controllers/console/auth/test_data_source_bearer_auth.py b/api/tests/test_containers_integration_tests/controllers/console/auth/test_data_source_bearer_auth.py new file mode 100644 index 0000000000..00309c25d6 --- /dev/null +++ b/api/tests/test_containers_integration_tests/controllers/console/auth/test_data_source_bearer_auth.py @@ -0,0 +1,131 @@ +"""Controller integration tests for API key data source auth routes.""" + +import json +from unittest.mock import patch + +from flask.testing import FlaskClient +from sqlalchemy import select +from sqlalchemy.orm import Session + +from models.source import DataSourceApiKeyAuthBinding +from tests.test_containers_integration_tests.controllers.console.helpers import ( + authenticate_console_client, + create_console_account_and_tenant, +) + + +def test_get_api_key_auth_data_source( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + binding = DataSourceApiKeyAuthBinding( + tenant_id=tenant.id, + category="api_key", + provider="custom_provider", + credentials=json.dumps({"auth_type": "api_key", "config": {"api_key": "encrypted"}}), + disabled=False, + ) + db_session_with_containers.add(binding) + db_session_with_containers.commit() + + response = test_client_with_containers.get( + "/console/api/api-key-auth/data-source", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 200 + payload = response.get_json() + assert payload is not None + assert len(payload["sources"]) == 1 + assert payload["sources"][0]["provider"] == "custom_provider" + + +def test_get_api_key_auth_data_source_empty( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, _tenant = create_console_account_and_tenant(db_session_with_containers) + + response = test_client_with_containers.get( + "/console/api/api-key-auth/data-source", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 200 + assert response.get_json() == {"sources": []} + + +def test_create_binding_successful( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, _tenant = create_console_account_and_tenant(db_session_with_containers) + + with ( + patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.validate_api_key_auth_args"), + patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.create_provider_auth"), + ): + response = test_client_with_containers.post( + "/console/api/api-key-auth/data-source/binding", + json={"category": "api_key", "provider": "custom", "credentials": {"key": "value"}}, + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 200 + assert response.get_json() == {"result": "success"} + + +def test_create_binding_failure( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, _tenant = create_console_account_and_tenant(db_session_with_containers) + + with ( + patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.validate_api_key_auth_args"), + patch( + "controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.create_provider_auth", + side_effect=ValueError("Invalid structure"), + ), + ): + response = test_client_with_containers.post( + "/console/api/api-key-auth/data-source/binding", + json={"category": "api_key", "provider": "custom", "credentials": {"key": "value"}}, + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 500 + payload = response.get_json() + assert payload is not None + assert payload["code"] == "auth_failed" + assert payload["message"] == "Invalid structure" + + +def test_delete_binding_successful( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + binding = DataSourceApiKeyAuthBinding( + tenant_id=tenant.id, + category="api_key", + provider="custom_provider", + credentials=json.dumps({"auth_type": "api_key", "config": {"api_key": "encrypted"}}), + disabled=False, + ) + db_session_with_containers.add(binding) + db_session_with_containers.commit() + + response = test_client_with_containers.delete( + f"/console/api/api-key-auth/data-source/{binding.id}", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 204 + assert ( + db_session_with_containers.scalar( + select(DataSourceApiKeyAuthBinding).where(DataSourceApiKeyAuthBinding.id == binding.id) + ) + is None + ) diff --git a/api/tests/test_containers_integration_tests/controllers/console/auth/test_data_source_oauth.py b/api/tests/test_containers_integration_tests/controllers/console/auth/test_data_source_oauth.py new file mode 100644 index 0000000000..81b5423261 --- /dev/null +++ b/api/tests/test_containers_integration_tests/controllers/console/auth/test_data_source_oauth.py @@ -0,0 +1,120 @@ +"""Controller integration tests for console OAuth data source routes.""" + +from unittest.mock import MagicMock, patch + +from flask.testing import FlaskClient +from sqlalchemy.orm import Session + +from models.source import DataSourceOauthBinding +from tests.test_containers_integration_tests.controllers.console.helpers import ( + authenticate_console_client, + create_console_account_and_tenant, +) + + +def test_get_oauth_url_successful( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + provider = MagicMock() + provider.get_authorization_url.return_value = "http://oauth.provider/auth" + + with ( + patch("controllers.console.auth.data_source_oauth.get_oauth_providers", return_value={"notion": provider}), + patch("controllers.console.auth.data_source_oauth.dify_config.NOTION_INTEGRATION_TYPE", None), + ): + response = test_client_with_containers.get( + "/console/api/oauth/data-source/notion", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert tenant.id == account.current_tenant_id + assert response.status_code == 200 + assert response.get_json() == {"data": "http://oauth.provider/auth"} + provider.get_authorization_url.assert_called_once() + + +def test_get_oauth_url_invalid_provider( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, _tenant = create_console_account_and_tenant(db_session_with_containers) + + with patch("controllers.console.auth.data_source_oauth.get_oauth_providers", return_value={"notion": MagicMock()}): + response = test_client_with_containers.get( + "/console/api/oauth/data-source/unknown_provider", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 400 + assert response.get_json() == {"error": "Invalid provider"} + + +def test_oauth_callback_successful(test_client_with_containers: FlaskClient) -> None: + with patch("controllers.console.auth.data_source_oauth.get_oauth_providers", return_value={"notion": MagicMock()}): + response = test_client_with_containers.get("/console/api/oauth/data-source/callback/notion?code=mock_code") + + assert response.status_code == 302 + assert "code=mock_code" in response.location + + +def test_oauth_callback_missing_code(test_client_with_containers: FlaskClient) -> None: + with patch("controllers.console.auth.data_source_oauth.get_oauth_providers", return_value={"notion": MagicMock()}): + response = test_client_with_containers.get("/console/api/oauth/data-source/callback/notion") + + assert response.status_code == 302 + assert "error=Access%20denied" in response.location + + +def test_oauth_callback_invalid_provider(test_client_with_containers: FlaskClient) -> None: + with patch("controllers.console.auth.data_source_oauth.get_oauth_providers", return_value={"notion": MagicMock()}): + response = test_client_with_containers.get("/console/api/oauth/data-source/callback/invalid?code=mock_code") + + assert response.status_code == 400 + assert response.get_json() == {"error": "Invalid provider"} + + +def test_get_binding_successful(test_client_with_containers: FlaskClient) -> None: + provider = MagicMock() + with patch("controllers.console.auth.data_source_oauth.get_oauth_providers", return_value={"notion": provider}): + response = test_client_with_containers.get("/console/api/oauth/data-source/binding/notion?code=auth_code_123") + + assert response.status_code == 200 + assert response.get_json() == {"result": "success"} + provider.get_access_token.assert_called_once_with("auth_code_123") + + +def test_get_binding_missing_code(test_client_with_containers: FlaskClient) -> None: + with patch("controllers.console.auth.data_source_oauth.get_oauth_providers", return_value={"notion": MagicMock()}): + response = test_client_with_containers.get("/console/api/oauth/data-source/binding/notion?code=") + + assert response.status_code == 400 + assert response.get_json() == {"error": "Invalid code"} + + +def test_sync_successful( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, tenant = create_console_account_and_tenant(db_session_with_containers) + binding = DataSourceOauthBinding( + tenant_id=tenant.id, + access_token="test-access-token", + provider="notion", + source_info={"workspace_name": "Workspace", "workspace_icon": None, "workspace_id": tenant.id, "pages": []}, + disabled=False, + ) + db_session_with_containers.add(binding) + db_session_with_containers.commit() + + provider = MagicMock() + with patch("controllers.console.auth.data_source_oauth.get_oauth_providers", return_value={"notion": provider}): + response = test_client_with_containers.get( + f"/console/api/oauth/data-source/notion/{binding.id}/sync", + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 200 + assert response.get_json() == {"result": "success"} + provider.sync_data_source.assert_called_once_with(binding.id) diff --git a/api/tests/unit_tests/controllers/console/auth/test_email_register.py b/api/tests/test_containers_integration_tests/controllers/console/auth/test_email_register.py similarity index 82% rename from api/tests/unit_tests/controllers/console/auth/test_email_register.py rename to api/tests/test_containers_integration_tests/controllers/console/auth/test_email_register.py index 724c80f18c..879c337319 100644 --- a/api/tests/unit_tests/controllers/console/auth/test_email_register.py +++ b/api/tests/test_containers_integration_tests/controllers/console/auth/test_email_register.py @@ -1,8 +1,11 @@ +"""Testcontainers integration tests for email register controller endpoints.""" + +from __future__ import annotations + from types import SimpleNamespace from unittest.mock import MagicMock, patch import pytest -from flask import Flask from controllers.console.auth.email_register import ( EmailRegisterCheckApi, @@ -13,14 +16,11 @@ from services.account_service import AccountService @pytest.fixture -def app(): - flask_app = Flask(__name__) - flask_app.config["TESTING"] = True - return flask_app +def app(flask_app_with_containers): + return flask_app_with_containers class TestEmailRegisterSendEmailApi: - @patch("controllers.console.auth.email_register.Session") @patch("controllers.console.auth.email_register.AccountService.get_account_by_email_with_case_fallback") @patch("controllers.console.auth.email_register.AccountService.send_email_register_email") @patch("controllers.console.auth.email_register.BillingService.is_email_in_freeze") @@ -33,20 +33,15 @@ class TestEmailRegisterSendEmailApi: mock_is_freeze, mock_send_mail, mock_get_account, - mock_session_cls, app, ): mock_send_mail.return_value = "token-123" mock_is_freeze.return_value = False mock_account = MagicMock() - - mock_session = MagicMock() - mock_session_cls.return_value.__enter__.return_value = mock_session mock_get_account.return_value = mock_account feature_flags = SimpleNamespace(enable_email_password_login=True, is_allow_register=True) with ( - patch("controllers.console.auth.email_register.db", SimpleNamespace(engine="engine")), patch("controllers.console.auth.email_register.dify_config", SimpleNamespace(BILLING_ENABLED=True)), patch("controllers.console.wraps.dify_config", SimpleNamespace(EDITION="CLOUD")), patch("controllers.console.wraps.FeatureService.get_system_features", return_value=feature_flags), @@ -61,7 +56,6 @@ class TestEmailRegisterSendEmailApi: assert response == {"result": "success", "data": "token-123"} mock_is_freeze.assert_called_once_with("invitee@example.com") mock_send_mail.assert_called_once_with(email="invitee@example.com", account=mock_account, language="en-US") - mock_get_account.assert_called_once_with("Invitee@Example.com", session=mock_session) mock_extract_ip.assert_called_once() mock_is_email_send_ip_limit.assert_called_once_with("127.0.0.1") @@ -89,7 +83,6 @@ class TestEmailRegisterCheckApi: feature_flags = SimpleNamespace(enable_email_password_login=True, is_allow_register=True) with ( - patch("controllers.console.auth.email_register.db", SimpleNamespace(engine="engine")), patch("controllers.console.wraps.dify_config", SimpleNamespace(EDITION="CLOUD")), patch("controllers.console.wraps.FeatureService.get_system_features", return_value=feature_flags), ): @@ -114,7 +107,6 @@ class TestEmailRegisterResetApi: @patch("controllers.console.auth.email_register.AccountService.reset_login_error_rate_limit") @patch("controllers.console.auth.email_register.AccountService.login") @patch("controllers.console.auth.email_register.EmailRegisterResetApi._create_new_account") - @patch("controllers.console.auth.email_register.Session") @patch("controllers.console.auth.email_register.AccountService.get_account_by_email_with_case_fallback") @patch("controllers.console.auth.email_register.AccountService.revoke_email_register_token") @patch("controllers.console.auth.email_register.AccountService.get_email_register_data") @@ -125,7 +117,6 @@ class TestEmailRegisterResetApi: mock_get_data, mock_revoke_token, mock_get_account, - mock_session_cls, mock_create_account, mock_login, mock_reset_login_rate, @@ -136,14 +127,10 @@ class TestEmailRegisterResetApi: token_pair = MagicMock() token_pair.model_dump.return_value = {"access_token": "a", "refresh_token": "r"} mock_login.return_value = token_pair - - mock_session = MagicMock() - mock_session_cls.return_value.__enter__.return_value = mock_session mock_get_account.return_value = None feature_flags = SimpleNamespace(enable_email_password_login=True, is_allow_register=True) with ( - patch("controllers.console.auth.email_register.db", SimpleNamespace(engine="engine")), patch("controllers.console.wraps.dify_config", SimpleNamespace(EDITION="CLOUD")), patch("controllers.console.wraps.FeatureService.get_system_features", return_value=feature_flags), ): @@ -159,19 +146,19 @@ class TestEmailRegisterResetApi: mock_reset_login_rate.assert_called_once_with("invitee@example.com") mock_revoke_token.assert_called_once_with("token-123") mock_extract_ip.assert_called_once() - mock_get_account.assert_called_once_with("Invitee@Example.com", session=mock_session) -def test_get_account_by_email_with_case_fallback_uses_lowercase_lookup(): +def test_get_account_by_email_with_case_fallback_falls_back_to_lowercase(): + """Test that case fallback tries lowercase when exact match fails.""" mock_session = MagicMock() - first_query = MagicMock() - first_query.scalar_one_or_none.return_value = None + first_result = MagicMock() + first_result.scalar_one_or_none.return_value = None expected_account = MagicMock() - second_query = MagicMock() - second_query.scalar_one_or_none.return_value = expected_account - mock_session.execute.side_effect = [first_query, second_query] + second_result = MagicMock() + second_result.scalar_one_or_none.return_value = expected_account + mock_session.execute.side_effect = [first_result, second_result] - account = AccountService.get_account_by_email_with_case_fallback("Case@Test.com", session=mock_session) + result = AccountService.get_account_by_email_with_case_fallback("Case@Test.com", session=mock_session) - assert account is expected_account + assert result is expected_account assert mock_session.execute.call_count == 2 diff --git a/api/tests/unit_tests/controllers/console/auth/test_forgot_password.py b/api/tests/test_containers_integration_tests/controllers/console/auth/test_forgot_password.py similarity index 82% rename from api/tests/unit_tests/controllers/console/auth/test_forgot_password.py rename to api/tests/test_containers_integration_tests/controllers/console/auth/test_forgot_password.py index 8403777dc9..7b7393dade 100644 --- a/api/tests/unit_tests/controllers/console/auth/test_forgot_password.py +++ b/api/tests/test_containers_integration_tests/controllers/console/auth/test_forgot_password.py @@ -1,8 +1,11 @@ +"""Testcontainers integration tests for forgot password controller endpoints.""" + +from __future__ import annotations + from types import SimpleNamespace from unittest.mock import MagicMock, patch import pytest -from flask import Flask from controllers.console.auth.forgot_password import ( ForgotPasswordCheckApi, @@ -13,14 +16,11 @@ from services.account_service import AccountService @pytest.fixture -def app(): - flask_app = Flask(__name__) - flask_app.config["TESTING"] = True - return flask_app +def app(flask_app_with_containers): + return flask_app_with_containers class TestForgotPasswordSendEmailApi: - @patch("controllers.console.auth.forgot_password.Session") @patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback") @patch("controllers.console.auth.forgot_password.AccountService.send_reset_password_email") @patch("controllers.console.auth.forgot_password.AccountService.is_email_send_ip_limit", return_value=False) @@ -31,19 +31,15 @@ class TestForgotPasswordSendEmailApi: mock_is_ip_limit, mock_send_email, mock_get_account, - mock_session_cls, app, ): mock_account = MagicMock() mock_get_account.return_value = mock_account mock_send_email.return_value = "token-123" - mock_session = MagicMock() - mock_session_cls.return_value.__enter__.return_value = mock_session wraps_features = SimpleNamespace(enable_email_password_login=True, is_allow_register=True) controller_features = SimpleNamespace(is_allow_register=True) with ( - patch("controllers.console.auth.forgot_password.db", SimpleNamespace(engine="engine")), patch( "controllers.console.auth.forgot_password.FeatureService.get_system_features", return_value=controller_features, @@ -59,7 +55,6 @@ class TestForgotPasswordSendEmailApi: response = ForgotPasswordSendEmailApi().post() assert response == {"result": "success", "data": "token-123"} - mock_get_account.assert_called_once_with("User@Example.com", session=mock_session) mock_send_email.assert_called_once_with( account=mock_account, email="user@example.com", @@ -117,7 +112,6 @@ class TestForgotPasswordCheckApi: class TestForgotPasswordResetApi: @patch("controllers.console.auth.forgot_password.ForgotPasswordResetApi._update_existing_account") - @patch("controllers.console.auth.forgot_password.Session") @patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback") @patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token") @patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data") @@ -126,7 +120,6 @@ class TestForgotPasswordResetApi: mock_get_reset_data, mock_revoke_token, mock_get_account, - mock_session_cls, mock_update_account, app, ): @@ -134,12 +127,8 @@ class TestForgotPasswordResetApi: mock_account = MagicMock() mock_get_account.return_value = mock_account - mock_session = MagicMock() - mock_session_cls.return_value.__enter__.return_value = mock_session - wraps_features = SimpleNamespace(enable_email_password_login=True) with ( - patch("controllers.console.auth.forgot_password.db", SimpleNamespace(engine="engine")), patch("controllers.console.wraps.dify_config", SimpleNamespace(EDITION="CLOUD")), patch("controllers.console.wraps.FeatureService.get_system_features", return_value=wraps_features), ): @@ -157,20 +146,22 @@ class TestForgotPasswordResetApi: assert response == {"result": "success"} mock_get_reset_data.assert_called_once_with("token-123") mock_revoke_token.assert_called_once_with("token-123") - mock_get_account.assert_called_once_with("User@Example.com", session=mock_session) mock_update_account.assert_called_once() -def test_get_account_by_email_with_case_fallback_uses_lowercase_lookup(): +def test_get_account_by_email_with_case_fallback_falls_back_to_lowercase(): + """Test that case fallback tries lowercase when exact match fails.""" + from unittest.mock import MagicMock + mock_session = MagicMock() - first_query = MagicMock() - first_query.scalar_one_or_none.return_value = None + first_result = MagicMock() + first_result.scalar_one_or_none.return_value = None expected_account = MagicMock() - second_query = MagicMock() - second_query.scalar_one_or_none.return_value = expected_account - mock_session.execute.side_effect = [first_query, second_query] + second_result = MagicMock() + second_result.scalar_one_or_none.return_value = expected_account + mock_session.execute.side_effect = [first_result, second_result] - account = AccountService.get_account_by_email_with_case_fallback("Mixed@Test.com", session=mock_session) + result = AccountService.get_account_by_email_with_case_fallback("Mixed@Test.com", session=mock_session) - assert account is expected_account + assert result is expected_account assert mock_session.execute.call_count == 2 diff --git a/api/tests/unit_tests/controllers/console/auth/test_oauth.py b/api/tests/test_containers_integration_tests/controllers/console/auth/test_oauth.py similarity index 92% rename from api/tests/unit_tests/controllers/console/auth/test_oauth.py rename to api/tests/test_containers_integration_tests/controllers/console/auth/test_oauth.py index 6345c2ab23..a2f1328579 100644 --- a/api/tests/unit_tests/controllers/console/auth/test_oauth.py +++ b/api/tests/test_containers_integration_tests/controllers/console/auth/test_oauth.py @@ -1,7 +1,10 @@ +"""Testcontainers integration tests for OAuth controller endpoints.""" + +from __future__ import annotations + from unittest.mock import MagicMock, patch import pytest -from flask import Flask from controllers.console.auth.oauth import ( OAuthCallback, @@ -18,10 +21,8 @@ from services.errors.account import AccountRegisterError class TestGetOAuthProviders: @pytest.fixture - def app(self): - app = Flask(__name__) - app.config["TESTING"] = True - return app + def app(self, flask_app_with_containers): + return flask_app_with_containers @pytest.mark.parametrize( ("github_config", "google_config", "expected_github", "expected_google"), @@ -64,10 +65,8 @@ class TestOAuthLogin: return OAuthLogin() @pytest.fixture - def app(self): - app = Flask(__name__) - app.config["TESTING"] = True - return app + def app(self, flask_app_with_containers): + return flask_app_with_containers @pytest.fixture def mock_oauth_provider(self): @@ -131,10 +130,8 @@ class TestOAuthCallback: return OAuthCallback() @pytest.fixture - def app(self): - app = Flask(__name__) - app.config["TESTING"] = True - return app + def app(self, flask_app_with_containers): + return flask_app_with_containers @pytest.fixture def oauth_setup(self): @@ -190,15 +187,8 @@ class TestOAuthCallback: (KeyError("Missing key"), "OAuth process failed"), ], ) - @patch("controllers.console.auth.oauth.db") @patch("controllers.console.auth.oauth.get_oauth_providers") - def test_should_handle_oauth_exceptions( - self, mock_get_providers, mock_db, resource, app, exception, expected_error - ): - # Mock database session - mock_db.session = MagicMock() - mock_db.session.rollback = MagicMock() - + def test_should_handle_oauth_exceptions(self, mock_get_providers, resource, app, exception, expected_error): # Import the real requests module to create a proper exception import httpx @@ -258,7 +248,6 @@ class TestOAuthCallback: ) @patch("controllers.console.auth.oauth.AccountService") @patch("controllers.console.auth.oauth.TenantService") - @patch("controllers.console.auth.oauth.db") @patch("controllers.console.auth.oauth.dify_config") @patch("controllers.console.auth.oauth.get_oauth_providers") @patch("controllers.console.auth.oauth._generate_account") @@ -269,7 +258,6 @@ class TestOAuthCallback: mock_generate_account, mock_get_providers, mock_config, - mock_db, mock_tenant_service, mock_account_service, resource, @@ -278,10 +266,6 @@ class TestOAuthCallback: account_status, expected_redirect, ): - # Mock database session - mock_db.session = MagicMock() - mock_db.session.rollback = MagicMock() - mock_db.session.commit = MagicMock() mock_config.CONSOLE_WEB_URL = "http://localhost:3000" mock_get_providers.return_value = {"github": oauth_setup["provider"]} @@ -306,14 +290,12 @@ class TestOAuthCallback: @patch("controllers.console.auth.oauth.dify_config") @patch("controllers.console.auth.oauth.get_oauth_providers") @patch("controllers.console.auth.oauth._generate_account") - @patch("controllers.console.auth.oauth.db") @patch("controllers.console.auth.oauth.TenantService") @patch("controllers.console.auth.oauth.AccountService") def test_should_activate_pending_account( self, mock_account_service, mock_tenant_service, - mock_db, mock_generate_account, mock_get_providers, mock_config, @@ -338,12 +320,10 @@ class TestOAuthCallback: assert mock_account.status == AccountStatus.ACTIVE assert mock_account.initialized_at is not None - mock_db.session.commit.assert_called_once() @patch("controllers.console.auth.oauth.dify_config") @patch("controllers.console.auth.oauth.get_oauth_providers") @patch("controllers.console.auth.oauth._generate_account") - @patch("controllers.console.auth.oauth.db") @patch("controllers.console.auth.oauth.TenantService") @patch("controllers.console.auth.oauth.AccountService") @patch("controllers.console.auth.oauth.redirect") @@ -352,7 +332,6 @@ class TestOAuthCallback: mock_redirect, mock_account_service, mock_tenant_service, - mock_db, mock_generate_account, mock_get_providers, mock_config, @@ -414,6 +393,10 @@ class TestOAuthCallback: class TestAccountGeneration: + @pytest.fixture + def app(self, flask_app_with_containers): + return flask_app_with_containers + @pytest.fixture def user_info(self): return OAuthUserInfo(id="123", name="Test User", email="test@example.com") @@ -425,15 +408,10 @@ class TestAccountGeneration: return account @patch("controllers.console.auth.oauth.AccountService.get_account_by_email_with_case_fallback") - @patch("controllers.console.auth.oauth.Session") @patch("controllers.console.auth.oauth.Account") - @patch("controllers.console.auth.oauth.db") def test_should_get_account_by_openid_or_email( - self, mock_db, mock_account_model, mock_session, mock_get_account, user_info, mock_account + self, mock_account_model, mock_get_account, flask_req_ctx_with_containers, user_info, mock_account ): - # Mock db.engine for Session creation - mock_db.engine = MagicMock() - # Test OpenID found mock_account_model.get_by_openid.return_value = mock_account result = _get_account_by_openid_or_email("github", user_info) @@ -443,15 +421,14 @@ class TestAccountGeneration: # Test fallback to email lookup mock_account_model.get_by_openid.return_value = None - mock_session_instance = MagicMock() - mock_session.return_value.__enter__.return_value = mock_session_instance mock_get_account.return_value = mock_account result = _get_account_by_openid_or_email("github", user_info) assert result == mock_account - mock_get_account.assert_called_once_with(user_info.email, session=mock_session_instance) + mock_get_account.assert_called_once() - def test_get_account_by_email_with_case_fallback_uses_lowercase_lookup(self): + def test_get_account_by_email_with_case_fallback_falls_back_to_lowercase(self): + """Test that case fallback tries lowercase when exact match fails.""" mock_session = MagicMock() first_result = MagicMock() first_result.scalar_one_or_none.return_value = None @@ -462,7 +439,7 @@ class TestAccountGeneration: result = AccountService.get_account_by_email_with_case_fallback("Case@Test.com", session=mock_session) - assert result == expected_account + assert result is expected_account assert mock_session.execute.call_count == 2 @pytest.mark.parametrize( @@ -478,10 +455,8 @@ class TestAccountGeneration: @patch("controllers.console.auth.oauth.RegisterService") @patch("controllers.console.auth.oauth.AccountService") @patch("controllers.console.auth.oauth.TenantService") - @patch("controllers.console.auth.oauth.db") def test_should_handle_account_generation_scenarios( self, - mock_db, mock_tenant_service, mock_account_service, mock_register_service, @@ -519,10 +494,8 @@ class TestAccountGeneration: @patch("controllers.console.auth.oauth.RegisterService") @patch("controllers.console.auth.oauth.AccountService") @patch("controllers.console.auth.oauth.TenantService") - @patch("controllers.console.auth.oauth.db") def test_should_register_with_lowercase_email( self, - mock_db, mock_tenant_service, mock_account_service, mock_register_service, diff --git a/api/tests/test_containers_integration_tests/controllers/console/auth/test_oauth_server.py b/api/tests/test_containers_integration_tests/controllers/console/auth/test_oauth_server.py new file mode 100644 index 0000000000..2ef27133d8 --- /dev/null +++ b/api/tests/test_containers_integration_tests/controllers/console/auth/test_oauth_server.py @@ -0,0 +1,365 @@ +"""Controller integration tests for console OAuth server routes.""" + +from unittest.mock import patch + +from flask.testing import FlaskClient +from sqlalchemy.orm import Session + +from models.model import OAuthProviderApp +from services.oauth_server import OAUTH_ACCESS_TOKEN_EXPIRES_IN +from tests.test_containers_integration_tests.controllers.console.helpers import ( + authenticate_console_client, + create_console_account_and_tenant, + ensure_dify_setup, +) + + +def _build_oauth_provider_app() -> OAuthProviderApp: + return OAuthProviderApp( + app_icon="icon_url", + client_id="test_client_id", + client_secret="test_secret", + app_label={"en-US": "Test App"}, + redirect_uris=["http://localhost/callback"], + scope="read,write", + ) + + +def test_oauth_provider_successful_post( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + ensure_dify_setup(db_session_with_containers) + + with patch( + "controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app", + return_value=_build_oauth_provider_app(), + ): + response = test_client_with_containers.post( + "/console/api/oauth/provider", + json={"client_id": "test_client_id", "redirect_uri": "http://localhost/callback"}, + ) + + assert response.status_code == 200 + payload = response.get_json() + assert payload is not None + assert payload["app_icon"] == "icon_url" + assert payload["app_label"] == {"en-US": "Test App"} + assert payload["scope"] == "read,write" + + +def test_oauth_provider_invalid_redirect_uri( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + ensure_dify_setup(db_session_with_containers) + + with patch( + "controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app", + return_value=_build_oauth_provider_app(), + ): + response = test_client_with_containers.post( + "/console/api/oauth/provider", + json={"client_id": "test_client_id", "redirect_uri": "http://invalid/callback"}, + ) + + assert response.status_code == 400 + payload = response.get_json() + assert payload is not None + assert "redirect_uri is invalid" in payload["message"] + + +def test_oauth_provider_invalid_client_id( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + ensure_dify_setup(db_session_with_containers) + + response = test_client_with_containers.post( + "/console/api/oauth/provider", + json={"client_id": "test_invalid_client_id", "redirect_uri": "http://localhost/callback"}, + ) + + assert response.status_code == 404 + payload = response.get_json() + assert payload is not None + assert "client_id is invalid" in payload["message"] + + +def test_oauth_authorize_successful( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + account, _tenant = create_console_account_and_tenant(db_session_with_containers) + + with ( + patch( + "controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app", + return_value=_build_oauth_provider_app(), + ), + patch( + "controllers.console.auth.oauth_server.OAuthServerService.sign_oauth_authorization_code", + return_value="auth_code_123", + ) as mock_sign, + ): + response = test_client_with_containers.post( + "/console/api/oauth/provider/authorize", + json={"client_id": "test_client_id"}, + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 200 + assert response.get_json() == {"code": "auth_code_123"} + mock_sign.assert_called_once_with("test_client_id", account.id) + + +def test_oauth_token_authorization_code_grant( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + ensure_dify_setup(db_session_with_containers) + + with ( + patch( + "controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app", + return_value=_build_oauth_provider_app(), + ), + patch( + "controllers.console.auth.oauth_server.OAuthServerService.sign_oauth_access_token", + return_value=("access_123", "refresh_123"), + ), + ): + response = test_client_with_containers.post( + "/console/api/oauth/provider/token", + json={ + "client_id": "test_client_id", + "grant_type": "authorization_code", + "code": "auth_code", + "client_secret": "test_secret", + "redirect_uri": "http://localhost/callback", + }, + ) + + assert response.status_code == 200 + assert response.get_json() == { + "access_token": "access_123", + "token_type": "Bearer", + "expires_in": OAUTH_ACCESS_TOKEN_EXPIRES_IN, + "refresh_token": "refresh_123", + } + + +def test_oauth_token_authorization_code_grant_missing_code( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + ensure_dify_setup(db_session_with_containers) + + with patch( + "controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app", + return_value=_build_oauth_provider_app(), + ): + response = test_client_with_containers.post( + "/console/api/oauth/provider/token", + json={ + "client_id": "test_client_id", + "grant_type": "authorization_code", + "client_secret": "test_secret", + "redirect_uri": "http://localhost/callback", + }, + ) + + assert response.status_code == 400 + assert response.get_json()["message"] == "code is required" + + +def test_oauth_token_authorization_code_grant_invalid_secret( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + ensure_dify_setup(db_session_with_containers) + + with patch( + "controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app", + return_value=_build_oauth_provider_app(), + ): + response = test_client_with_containers.post( + "/console/api/oauth/provider/token", + json={ + "client_id": "test_client_id", + "grant_type": "authorization_code", + "code": "auth_code", + "client_secret": "invalid_secret", + "redirect_uri": "http://localhost/callback", + }, + ) + + assert response.status_code == 400 + assert response.get_json()["message"] == "client_secret is invalid" + + +def test_oauth_token_authorization_code_grant_invalid_redirect_uri( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + ensure_dify_setup(db_session_with_containers) + + with patch( + "controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app", + return_value=_build_oauth_provider_app(), + ): + response = test_client_with_containers.post( + "/console/api/oauth/provider/token", + json={ + "client_id": "test_client_id", + "grant_type": "authorization_code", + "code": "auth_code", + "client_secret": "test_secret", + "redirect_uri": "http://invalid/callback", + }, + ) + + assert response.status_code == 400 + assert response.get_json()["message"] == "redirect_uri is invalid" + + +def test_oauth_token_refresh_token_grant( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + ensure_dify_setup(db_session_with_containers) + + with ( + patch( + "controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app", + return_value=_build_oauth_provider_app(), + ), + patch( + "controllers.console.auth.oauth_server.OAuthServerService.sign_oauth_access_token", + return_value=("new_access", "new_refresh"), + ), + ): + response = test_client_with_containers.post( + "/console/api/oauth/provider/token", + json={"client_id": "test_client_id", "grant_type": "refresh_token", "refresh_token": "refresh_123"}, + ) + + assert response.status_code == 200 + assert response.get_json() == { + "access_token": "new_access", + "token_type": "Bearer", + "expires_in": OAUTH_ACCESS_TOKEN_EXPIRES_IN, + "refresh_token": "new_refresh", + } + + +def test_oauth_token_refresh_token_grant_missing_token( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + ensure_dify_setup(db_session_with_containers) + + with patch( + "controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app", + return_value=_build_oauth_provider_app(), + ): + response = test_client_with_containers.post( + "/console/api/oauth/provider/token", + json={"client_id": "test_client_id", "grant_type": "refresh_token"}, + ) + + assert response.status_code == 400 + assert response.get_json()["message"] == "refresh_token is required" + + +def test_oauth_token_invalid_grant_type( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + ensure_dify_setup(db_session_with_containers) + + with patch( + "controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app", + return_value=_build_oauth_provider_app(), + ): + response = test_client_with_containers.post( + "/console/api/oauth/provider/token", + json={"client_id": "test_client_id", "grant_type": "invalid_grant"}, + ) + + assert response.status_code == 400 + assert response.get_json()["message"] == "invalid grant_type" + + +def test_oauth_account_successful_retrieval( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + ensure_dify_setup(db_session_with_containers) + account, _tenant = create_console_account_and_tenant(db_session_with_containers) + account.avatar = "avatar_url" + db_session_with_containers.commit() + + with ( + patch( + "controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app", + return_value=_build_oauth_provider_app(), + ), + patch( + "controllers.console.auth.oauth_server.OAuthServerService.validate_oauth_access_token", + return_value=account, + ), + ): + response = test_client_with_containers.post( + "/console/api/oauth/provider/account", + json={"client_id": "test_client_id"}, + headers={"Authorization": "Bearer valid_access_token"}, + ) + + assert response.status_code == 200 + assert response.get_json() == { + "name": "Test User", + "email": account.email, + "avatar": "avatar_url", + "interface_language": "en-US", + "timezone": "UTC", + } + + +def test_oauth_account_missing_authorization_header( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + ensure_dify_setup(db_session_with_containers) + + with patch( + "controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app", + return_value=_build_oauth_provider_app(), + ): + response = test_client_with_containers.post( + "/console/api/oauth/provider/account", + json={"client_id": "test_client_id"}, + ) + + assert response.status_code == 401 + assert response.get_json() == {"error": "Authorization header is required"} + + +def test_oauth_account_invalid_authorization_header_format( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + ensure_dify_setup(db_session_with_containers) + + with patch( + "controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app", + return_value=_build_oauth_provider_app(), + ): + response = test_client_with_containers.post( + "/console/api/oauth/provider/account", + json={"client_id": "test_client_id"}, + headers={"Authorization": "InvalidFormat"}, + ) + + assert response.status_code == 401 + assert response.get_json() == {"error": "Invalid Authorization header format"} diff --git a/api/tests/unit_tests/controllers/console/auth/test_password_reset.py b/api/tests/test_containers_integration_tests/controllers/console/auth/test_password_reset.py similarity index 81% rename from api/tests/unit_tests/controllers/console/auth/test_password_reset.py rename to api/tests/test_containers_integration_tests/controllers/console/auth/test_password_reset.py index 9488cf528e..8f9db287e3 100644 --- a/api/tests/unit_tests/controllers/console/auth/test_password_reset.py +++ b/api/tests/test_containers_integration_tests/controllers/console/auth/test_password_reset.py @@ -1,17 +1,10 @@ -""" -Test suite for password reset authentication flows. +"""Testcontainers integration tests for password reset authentication flows.""" -This module tests the password reset mechanism including: -- Password reset email sending -- Verification code validation -- Password reset with token -- Rate limiting and security checks -""" +from __future__ import annotations from unittest.mock import MagicMock, patch import pytest -from flask import Flask from controllers.console.auth.error import ( EmailCodeError, @@ -28,31 +21,12 @@ from controllers.console.auth.forgot_password import ( from controllers.console.error import AccountNotFound, EmailSendIpLimitError -@pytest.fixture(autouse=True) -def _mock_forgot_password_session(): - with patch("controllers.console.auth.forgot_password.Session") as mock_session_cls: - mock_session = MagicMock() - mock_session_cls.return_value.__enter__.return_value = mock_session - mock_session_cls.return_value.__exit__.return_value = None - yield mock_session - - -@pytest.fixture(autouse=True) -def _mock_forgot_password_db(): - with patch("controllers.console.auth.forgot_password.db") as mock_db: - mock_db.engine = MagicMock() - yield mock_db - - class TestForgotPasswordSendEmailApi: """Test cases for sending password reset emails.""" @pytest.fixture - def app(self): - """Create Flask test application.""" - app = Flask(__name__) - app.config["TESTING"] = True - return app + def app(self, flask_app_with_containers): + return flask_app_with_containers @pytest.fixture def mock_account(self): @@ -62,7 +36,6 @@ class TestForgotPasswordSendEmailApi: account.name = "Test User" return account - @patch("controllers.console.wraps.db") @patch("controllers.console.auth.forgot_password.AccountService.is_email_send_ip_limit") @patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback") @patch("controllers.console.auth.forgot_password.AccountService.send_reset_password_email") @@ -73,20 +46,10 @@ class TestForgotPasswordSendEmailApi: mock_send_email, mock_get_account, mock_is_ip_limit, - mock_wraps_db, app, mock_account, ): - """ - Test successful password reset email sending. - - Verifies that: - - Email is sent to valid account - - Reset token is generated and returned - - IP rate limiting is checked - """ # Arrange - mock_wraps_db.session.query.return_value.first.return_value = MagicMock() mock_is_ip_limit.return_value = False mock_get_account.return_value = mock_account mock_send_email.return_value = "reset_token_123" @@ -104,9 +67,8 @@ class TestForgotPasswordSendEmailApi: assert response["data"] == "reset_token_123" mock_send_email.assert_called_once() - @patch("controllers.console.wraps.db") @patch("controllers.console.auth.forgot_password.AccountService.is_email_send_ip_limit") - def test_send_reset_email_ip_rate_limited(self, mock_is_ip_limit, mock_db, app): + def test_send_reset_email_ip_rate_limited(self, mock_is_ip_limit, app): """ Test password reset email blocked by IP rate limit. @@ -115,7 +77,6 @@ class TestForgotPasswordSendEmailApi: - No email is sent when rate limited """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_is_ip_limit.return_value = True # Act & Assert @@ -133,7 +94,6 @@ class TestForgotPasswordSendEmailApi: (None, "en-US"), # Defaults to en-US when not provided ], ) - @patch("controllers.console.wraps.db") @patch("controllers.console.auth.forgot_password.AccountService.is_email_send_ip_limit") @patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback") @patch("controllers.console.auth.forgot_password.AccountService.send_reset_password_email") @@ -144,7 +104,6 @@ class TestForgotPasswordSendEmailApi: mock_send_email, mock_get_account, mock_is_ip_limit, - mock_wraps_db, app, mock_account, language_input, @@ -158,7 +117,6 @@ class TestForgotPasswordSendEmailApi: - Unsupported languages default to en-US """ # Arrange - mock_wraps_db.session.query.return_value.first.return_value = MagicMock() mock_is_ip_limit.return_value = False mock_get_account.return_value = mock_account mock_send_email.return_value = "token" @@ -180,13 +138,9 @@ class TestForgotPasswordCheckApi: """Test cases for verifying password reset codes.""" @pytest.fixture - def app(self): - """Create Flask test application.""" - app = Flask(__name__) - app.config["TESTING"] = True - return app + def app(self, flask_app_with_containers): + return flask_app_with_containers - @patch("controllers.console.wraps.db") @patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit") @patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data") @patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token") @@ -199,7 +153,6 @@ class TestForgotPasswordCheckApi: mock_revoke_token, mock_get_data, mock_is_rate_limit, - mock_db, app, ): """ @@ -212,7 +165,6 @@ class TestForgotPasswordCheckApi: - Rate limit is reset on success """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_is_rate_limit.return_value = False mock_get_data.return_value = {"email": "test@example.com", "code": "123456"} mock_generate_token.return_value = (None, "new_token") @@ -236,7 +188,6 @@ class TestForgotPasswordCheckApi: ) mock_reset_rate_limit.assert_called_once_with("test@example.com") - @patch("controllers.console.wraps.db") @patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit") @patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data") @patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token") @@ -249,10 +200,8 @@ class TestForgotPasswordCheckApi: mock_revoke_token, mock_get_data, mock_is_rate_limit, - mock_db, app, ): - mock_db.session.query.return_value.first.return_value = MagicMock() mock_is_rate_limit.return_value = False mock_get_data.return_value = {"email": "User@Example.com", "code": "999888"} mock_generate_token.return_value = (None, "fresh-token") @@ -271,9 +220,8 @@ class TestForgotPasswordCheckApi: mock_revoke_token.assert_called_once_with("upper_token") mock_reset_rate_limit.assert_called_once_with("user@example.com") - @patch("controllers.console.wraps.db") @patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit") - def test_verify_code_rate_limited(self, mock_is_rate_limit, mock_db, app): + def test_verify_code_rate_limited(self, mock_is_rate_limit, app): """ Test code verification blocked by rate limit. @@ -282,7 +230,6 @@ class TestForgotPasswordCheckApi: - Prevents brute force attacks on verification codes """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_is_rate_limit.return_value = True # Act & Assert @@ -295,10 +242,9 @@ class TestForgotPasswordCheckApi: with pytest.raises(EmailPasswordResetLimitError): api.post() - @patch("controllers.console.wraps.db") @patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit") @patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data") - def test_verify_code_invalid_token(self, mock_get_data, mock_is_rate_limit, mock_db, app): + def test_verify_code_invalid_token(self, mock_get_data, mock_is_rate_limit, app): """ Test code verification with invalid token. @@ -306,7 +252,6 @@ class TestForgotPasswordCheckApi: - InvalidTokenError is raised for invalid/expired tokens """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_is_rate_limit.return_value = False mock_get_data.return_value = None @@ -320,10 +265,9 @@ class TestForgotPasswordCheckApi: with pytest.raises(InvalidTokenError): api.post() - @patch("controllers.console.wraps.db") @patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit") @patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data") - def test_verify_code_email_mismatch(self, mock_get_data, mock_is_rate_limit, mock_db, app): + def test_verify_code_email_mismatch(self, mock_get_data, mock_is_rate_limit, app): """ Test code verification with mismatched email. @@ -332,7 +276,6 @@ class TestForgotPasswordCheckApi: - Prevents token abuse """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_is_rate_limit.return_value = False mock_get_data.return_value = {"email": "original@example.com", "code": "123456"} @@ -346,11 +289,10 @@ class TestForgotPasswordCheckApi: with pytest.raises(InvalidEmailError): api.post() - @patch("controllers.console.wraps.db") @patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit") @patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data") @patch("controllers.console.auth.forgot_password.AccountService.add_forgot_password_error_rate_limit") - def test_verify_code_wrong_code(self, mock_add_rate_limit, mock_get_data, mock_is_rate_limit, mock_db, app): + def test_verify_code_wrong_code(self, mock_add_rate_limit, mock_get_data, mock_is_rate_limit, app): """ Test code verification with incorrect code. @@ -359,7 +301,6 @@ class TestForgotPasswordCheckApi: - Rate limit counter is incremented """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_is_rate_limit.return_value = False mock_get_data.return_value = {"email": "test@example.com", "code": "123456"} @@ -380,11 +321,8 @@ class TestForgotPasswordResetApi: """Test cases for resetting password with verified token.""" @pytest.fixture - def app(self): - """Create Flask test application.""" - app = Flask(__name__) - app.config["TESTING"] = True - return app + def app(self, flask_app_with_containers): + return flask_app_with_containers @pytest.fixture def mock_account(self): @@ -394,7 +332,6 @@ class TestForgotPasswordResetApi: account.name = "Test User" return account - @patch("controllers.console.wraps.db") @patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data") @patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token") @patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback") @@ -405,7 +342,6 @@ class TestForgotPasswordResetApi: mock_get_account, mock_revoke_token, mock_get_data, - mock_wraps_db, app, mock_account, ): @@ -418,7 +354,6 @@ class TestForgotPasswordResetApi: - Success response is returned """ # Arrange - mock_wraps_db.session.query.return_value.first.return_value = MagicMock() mock_get_data.return_value = {"email": "test@example.com", "phase": "reset"} mock_get_account.return_value = mock_account mock_get_tenants.return_value = [MagicMock()] @@ -436,9 +371,8 @@ class TestForgotPasswordResetApi: assert response["result"] == "success" mock_revoke_token.assert_called_once_with("valid_token") - @patch("controllers.console.wraps.db") @patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data") - def test_reset_password_mismatch(self, mock_get_data, mock_db, app): + def test_reset_password_mismatch(self, mock_get_data, app): """ Test password reset with mismatched passwords. @@ -447,7 +381,6 @@ class TestForgotPasswordResetApi: - No password update occurs """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_get_data.return_value = {"email": "test@example.com", "phase": "reset"} # Act & Assert @@ -460,9 +393,8 @@ class TestForgotPasswordResetApi: with pytest.raises(PasswordMismatchError): api.post() - @patch("controllers.console.wraps.db") @patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data") - def test_reset_password_invalid_token(self, mock_get_data, mock_db, app): + def test_reset_password_invalid_token(self, mock_get_data, app): """ Test password reset with invalid token. @@ -470,7 +402,6 @@ class TestForgotPasswordResetApi: - InvalidTokenError is raised for invalid/expired tokens """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_get_data.return_value = None # Act & Assert @@ -483,9 +414,8 @@ class TestForgotPasswordResetApi: with pytest.raises(InvalidTokenError): api.post() - @patch("controllers.console.wraps.db") @patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data") - def test_reset_password_wrong_phase(self, mock_get_data, mock_db, app): + def test_reset_password_wrong_phase(self, mock_get_data, app): """ Test password reset with token not in reset phase. @@ -494,7 +424,6 @@ class TestForgotPasswordResetApi: - Prevents use of verification-phase tokens for reset """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_get_data.return_value = {"email": "test@example.com", "phase": "verify"} # Act & Assert @@ -507,13 +436,10 @@ class TestForgotPasswordResetApi: with pytest.raises(InvalidTokenError): api.post() - @patch("controllers.console.wraps.db") @patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data") @patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token") @patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback") - def test_reset_password_account_not_found( - self, mock_get_account, mock_revoke_token, mock_get_data, mock_wraps_db, app - ): + def test_reset_password_account_not_found(self, mock_get_account, mock_revoke_token, mock_get_data, app): """ Test password reset for non-existent account. @@ -521,7 +447,6 @@ class TestForgotPasswordResetApi: - AccountNotFound is raised when account doesn't exist """ # Arrange - mock_wraps_db.session.query.return_value.first.return_value = MagicMock() mock_get_data.return_value = {"email": "nonexistent@example.com", "phase": "reset"} mock_get_account.return_value = None diff --git a/api/tests/test_containers_integration_tests/controllers/console/helpers.py b/api/tests/test_containers_integration_tests/controllers/console/helpers.py new file mode 100644 index 0000000000..9e2084f393 --- /dev/null +++ b/api/tests/test_containers_integration_tests/controllers/console/helpers.py @@ -0,0 +1,85 @@ +"""Shared helpers for authenticated console controller integration tests.""" + +import uuid + +from flask.testing import FlaskClient +from sqlalchemy import select +from sqlalchemy.orm import Session + +from configs import dify_config +from constants import HEADER_NAME_CSRF_TOKEN +from libs.datetime_utils import naive_utc_now +from libs.token import _real_cookie_name, generate_csrf_token +from models import Account, DifySetup, Tenant, TenantAccountJoin +from models.account import AccountStatus, TenantAccountRole +from models.model import App, AppMode +from services.account_service import AccountService + + +def ensure_dify_setup(db_session: Session) -> None: + """Create a setup marker once so setup-protected console routes can be exercised.""" + if db_session.scalar(select(DifySetup).limit(1)) is not None: + return + + db_session.add(DifySetup(version=dify_config.project.version)) + db_session.commit() + + +def create_console_account_and_tenant(db_session: Session) -> tuple[Account, Tenant]: + """Create an initialized owner account with a current tenant.""" + account = Account( + email=f"test-{uuid.uuid4()}@example.com", + name="Test User", + interface_language="en-US", + status=AccountStatus.ACTIVE, + ) + account.initialized_at = naive_utc_now() + db_session.add(account) + db_session.commit() + + tenant = Tenant(name="Test Tenant", status="normal") + db_session.add(tenant) + db_session.commit() + + db_session.add( + TenantAccountJoin( + tenant_id=tenant.id, + account_id=account.id, + role=TenantAccountRole.OWNER, + current=True, + ) + ) + db_session.commit() + + account.set_tenant_id(tenant.id) + account.timezone = "UTC" + db_session.commit() + + ensure_dify_setup(db_session) + return account, tenant + + +def create_console_app(db_session: Session, tenant_id: str, account_id: str, mode: AppMode) -> App: + """Create a minimal app row that can be loaded by get_app_model.""" + app = App( + tenant_id=tenant_id, + name="Test App", + mode=mode, + enable_site=True, + enable_api=True, + created_by=account_id, + ) + db_session.add(app) + db_session.commit() + return app + + +def authenticate_console_client(test_client: FlaskClient, account: Account) -> dict[str, str]: + """Attach console auth cookies/headers for endpoints guarded by login_required.""" + access_token = AccountService.get_account_jwt_token(account) + csrf_token = generate_csrf_token(account.id) + test_client.set_cookie(_real_cookie_name("csrf_token"), csrf_token, domain="localhost") + return { + "Authorization": f"Bearer {access_token}", + HEADER_NAME_CSRF_TOKEN: csrf_token, + } diff --git a/api/tests/test_containers_integration_tests/core/app/layers/test_pause_state_persist_layer.py b/api/tests/test_containers_integration_tests/core/app/layers/test_pause_state_persist_layer.py index 96fb7ea293..b8840c4ba8 100644 --- a/api/tests/test_containers_integration_tests/core/app/layers/test_pause_state_persist_layer.py +++ b/api/tests/test_containers_integration_tests/core/app/layers/test_pause_state_persist_layer.py @@ -31,17 +31,18 @@ from core.app.layers.pause_state_persist_layer import ( PauseStatePersistenceLayer, WorkflowResumptionContext, ) -from dify_graph.entities.pause_reason import SchedulingPause -from dify_graph.enums import WorkflowExecutionStatus -from dify_graph.graph_engine.entities.commands import GraphEngineCommand -from dify_graph.graph_engine.layers.base import GraphEngineLayerNotInitializedError -from dify_graph.graph_events.graph import GraphRunPausedEvent -from dify_graph.model_runtime.entities.llm_entities import LLMUsage -from dify_graph.runtime.graph_runtime_state import GraphRuntimeState -from dify_graph.runtime.graph_runtime_state_protocol import ReadOnlyGraphRuntimeState -from dify_graph.runtime.read_only_wrappers import ReadOnlyGraphRuntimeStateWrapper -from dify_graph.runtime.variable_pool import SystemVariable, VariablePool +from core.workflow.system_variables import build_system_variables from extensions.ext_storage import storage +from graphon.entities.pause_reason import SchedulingPause +from graphon.enums import WorkflowExecutionStatus +from graphon.graph_engine.entities.commands import GraphEngineCommand +from graphon.graph_engine.layers.base import GraphEngineLayerNotInitializedError +from graphon.graph_events.graph import GraphRunPausedEvent +from graphon.model_runtime.entities.llm_entities import LLMUsage +from graphon.runtime.graph_runtime_state import GraphRuntimeState +from graphon.runtime.graph_runtime_state_protocol import ReadOnlyGraphRuntimeState +from graphon.runtime.read_only_wrappers import ReadOnlyGraphRuntimeStateWrapper +from graphon.runtime.variable_pool import VariablePool from libs.datetime_utils import naive_utc_now from models import Account from models import WorkflowPause as WorkflowPauseModel @@ -212,7 +213,7 @@ class TestPauseStatePersistenceLayerTestContainers: execution_id = workflow_run_id or getattr(self, "test_workflow_run_id", None) or str(uuid.uuid4()) # Create variable pool - variable_pool = VariablePool(system_variables=SystemVariable(workflow_execution_id=execution_id)) + variable_pool = VariablePool(system_variables=build_system_variables(workflow_execution_id=execution_id)) if variables: for (node_id, var_key), value in variables.items(): variable_pool.add([node_id, var_key], value) @@ -544,7 +545,7 @@ class TestPauseStatePersistenceLayerTestContainers: layer.initialize(graph_runtime_state, command_channel) # Import other event types - from dify_graph.graph_events.graph import ( + from graphon.graph_events.graph import ( GraphRunFailedEvent, GraphRunStartedEvent, GraphRunSucceededEvent, diff --git a/api/tests/test_containers_integration_tests/core/rag/retrieval/test_dataset_retrieval_integration.py b/api/tests/test_containers_integration_tests/core/rag/retrieval/test_dataset_retrieval_integration.py index 781e297fa4..00d7496a40 100644 --- a/api/tests/test_containers_integration_tests/core/rag/retrieval/test_dataset_retrieval_integration.py +++ b/api/tests/test_containers_integration_tests/core/rag/retrieval/test_dataset_retrieval_integration.py @@ -4,6 +4,7 @@ from unittest.mock import patch import pytest from faker import Faker +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from core.workflow.nodes.knowledge_retrieval.retrieval import KnowledgeRetrievalRequest from models.dataset import Dataset, Document @@ -38,7 +39,7 @@ class TestGetAvailableDatasetsIntegration: provider="dify", data_source_type=DataSourceType.UPLOAD_FILE, created_by=account.id, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, ) db_session_with_containers.add(dataset) db_session_with_containers.flush() @@ -55,7 +56,7 @@ class TestGetAvailableDatasetsIntegration: name=f"Document {i}", created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, doc_language="en", indexing_status=IndexingStatus.COMPLETED, enabled=True, @@ -112,7 +113,7 @@ class TestGetAvailableDatasetsIntegration: created_from=DocumentCreatedFrom.WEB, name=f"Archived Document {i}", created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, indexing_status=IndexingStatus.COMPLETED, enabled=True, archived=True, # Archived @@ -165,7 +166,7 @@ class TestGetAvailableDatasetsIntegration: created_from=DocumentCreatedFrom.WEB, name=f"Disabled Document {i}", created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, indexing_status=IndexingStatus.COMPLETED, enabled=False, # Disabled archived=False, @@ -218,7 +219,7 @@ class TestGetAvailableDatasetsIntegration: created_from=DocumentCreatedFrom.WEB, name=f"Document {status}", created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, indexing_status=status, # Not completed enabled=True, archived=False, @@ -336,7 +337,7 @@ class TestGetAvailableDatasetsIntegration: created_from=DocumentCreatedFrom.WEB, name=f"Document for {dataset.name}", created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, indexing_status=IndexingStatus.COMPLETED, enabled=True, archived=False, @@ -416,7 +417,7 @@ class TestGetAvailableDatasetsIntegration: created_from=DocumentCreatedFrom.WEB, name=f"Document {i}", created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, indexing_status=IndexingStatus.COMPLETED, enabled=True, archived=False, @@ -459,7 +460,7 @@ class TestKnowledgeRetrievalIntegration: provider="dify", data_source_type=DataSourceType.UPLOAD_FILE, created_by=account.id, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, ) db_session_with_containers.add(dataset) @@ -476,7 +477,7 @@ class TestKnowledgeRetrievalIntegration: indexing_status=IndexingStatus.COMPLETED, enabled=True, archived=False, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, ) db_session_with_containers.add(document) db_session_with_containers.commit() diff --git a/api/tests/test_containers_integration_tests/core/repositories/test_human_input_form_repository_impl.py b/api/tests/test_containers_integration_tests/core/repositories/test_human_input_form_repository_impl.py index 9d0fad4b12..e0c58f0f5c 100644 --- a/api/tests/test_containers_integration_tests/core/repositories/test_human_input_form_repository_impl.py +++ b/api/tests/test_containers_integration_tests/core/repositories/test_human_input_form_repository_impl.py @@ -7,20 +7,17 @@ from uuid import uuid4 from sqlalchemy import Engine, select from sqlalchemy.orm import Session -from core.repositories.human_input_repository import HumanInputFormRepositoryImpl -from dify_graph.nodes.human_input.entities import ( +from core.repositories.human_input_repository import FormCreateParams, HumanInputFormRepositoryImpl +from core.workflow.human_input_compat import ( DeliveryChannelConfig, EmailDeliveryConfig, EmailDeliveryMethod, EmailRecipients, ExternalRecipient, - FormDefinition, - HumanInputNodeData, MemberRecipient, - UserAction, WebAppDeliveryMethod, ) -from dify_graph.repositories.human_input_form_repository import FormCreateParams +from graphon.nodes.human_input.entities import FormDefinition, HumanInputNodeData, UserAction from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.human_input import ( EmailExternalRecipientPayload, @@ -68,7 +65,6 @@ def _build_form_params(delivery_methods: list[DeliveryChannelConfig]) -> FormCre user_actions=[UserAction(id="approve", title="Approve")], ) return FormCreateParams( - app_id=str(uuid4()), workflow_execution_id=str(uuid4()), node_id="human-input-node", form_config=form_config, @@ -84,7 +80,7 @@ def _build_email_delivery( ) -> EmailDeliveryMethod: return EmailDeliveryMethod( config=EmailDeliveryConfig( - recipients=EmailRecipients(whole_workspace=whole_workspace, items=recipients), + recipients=EmailRecipients(include_bound_group=whole_workspace, items=recipients), subject="Approval Needed", body="Please review", ) @@ -100,7 +96,7 @@ class TestHumanInputFormRepositoryImplWithContainers: member_emails=["member1@example.com", "member2@example.com"], ) - repository = HumanInputFormRepositoryImpl(tenant_id=tenant.id) + repository = HumanInputFormRepositoryImpl(tenant_id=tenant.id, app_id=str(uuid4())) params = _build_form_params( delivery_methods=[_build_email_delivery(whole_workspace=True, recipients=[])], ) @@ -129,13 +125,13 @@ class TestHumanInputFormRepositoryImplWithContainers: member_emails=["primary@example.com", "secondary@example.com"], ) - repository = HumanInputFormRepositoryImpl(tenant_id=tenant.id) + repository = HumanInputFormRepositoryImpl(tenant_id=tenant.id, app_id=str(uuid4())) params = _build_form_params( delivery_methods=[ _build_email_delivery( whole_workspace=False, recipients=[ - MemberRecipient(user_id=members[0].id), + MemberRecipient(reference_id=members[0].id), ExternalRecipient(email="external@example.com"), ], ) @@ -173,10 +169,9 @@ class TestHumanInputFormRepositoryImplWithContainers: member_emails=["prefill@example.com"], ) - repository = HumanInputFormRepositoryImpl(tenant_id=tenant.id) + repository = HumanInputFormRepositoryImpl(tenant_id=tenant.id, app_id=str(uuid4())) resolved_values = {"greeting": "Hello!"} params = FormCreateParams( - app_id=str(uuid4()), workflow_execution_id=str(uuid4()), node_id="human-input-node", form_config=HumanInputNodeData( @@ -210,9 +205,8 @@ class TestHumanInputFormRepositoryImplWithContainers: member_emails=["ui@example.com"], ) - repository = HumanInputFormRepositoryImpl(tenant_id=tenant.id) + repository = HumanInputFormRepositoryImpl(tenant_id=tenant.id, app_id=str(uuid4())) params = FormCreateParams( - app_id=str(uuid4()), workflow_execution_id=str(uuid4()), node_id="human-input-node", form_config=HumanInputNodeData( diff --git a/api/tests/test_containers_integration_tests/core/workflow/test_human_input_resume_node_execution.py b/api/tests/test_containers_integration_tests/core/workflow/test_human_input_resume_node_execution.py index 9733735df3..ae8c0716a4 100644 --- a/api/tests/test_containers_integration_tests/core/workflow/test_human_input_resume_node_execution.py +++ b/api/tests/test_containers_integration_tests/core/workflow/test_human_input_resume_node_execution.py @@ -10,22 +10,23 @@ from sqlalchemy.orm import Session from core.app.app_config.entities import WorkflowUIBasedAppConfig from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity from core.app.workflow.layers import PersistenceWorkflowInfo, WorkflowPersistenceLayer +from core.repositories.human_input_repository import HumanInputFormEntity, HumanInputFormRepository from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository -from dify_graph.enums import WorkflowType -from dify_graph.graph import Graph -from dify_graph.graph_engine.command_channels.in_memory_channel import InMemoryChannel -from dify_graph.graph_engine.graph_engine import GraphEngine -from dify_graph.nodes.end.end_node import EndNode -from dify_graph.nodes.end.entities import EndNodeData -from dify_graph.nodes.human_input.entities import HumanInputNodeData, UserAction -from dify_graph.nodes.human_input.enums import HumanInputFormStatus -from dify_graph.nodes.human_input.human_input_node import HumanInputNode -from dify_graph.nodes.start.entities import StartNodeData -from dify_graph.nodes.start.start_node import StartNode -from dify_graph.repositories.human_input_form_repository import HumanInputFormEntity, HumanInputFormRepository -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable +from core.workflow.node_runtime import DifyHumanInputNodeRuntime +from core.workflow.system_variables import build_system_variables +from graphon.enums import WorkflowType +from graphon.graph import Graph +from graphon.graph_engine.command_channels.in_memory_channel import InMemoryChannel +from graphon.graph_engine.graph_engine import GraphEngine +from graphon.nodes.end.end_node import EndNode +from graphon.nodes.end.entities import EndNodeData +from graphon.nodes.human_input.entities import HumanInputNodeData, UserAction +from graphon.nodes.human_input.enums import HumanInputFormStatus +from graphon.nodes.human_input.human_input_node import HumanInputNode +from graphon.nodes.start.entities import StartNodeData +from graphon.nodes.start.start_node import StartNode +from graphon.runtime import GraphRuntimeState, VariablePool from libs.datetime_utils import naive_utc_now from models import Account from models.account import Tenant, TenantAccountJoin, TenantAccountRole @@ -39,7 +40,7 @@ def _mock_form_repository_without_submission() -> HumanInputFormRepository: repo = MagicMock(spec=HumanInputFormRepository) form_entity = MagicMock(spec=HumanInputFormEntity) form_entity.id = "test-form-id" - form_entity.web_app_token = "test-form-token" + form_entity.submission_token = "test-form-token" form_entity.recipients = [] form_entity.rendered_content = "rendered" form_entity.submitted = False @@ -52,7 +53,7 @@ def _mock_form_repository_with_submission(action_id: str) -> HumanInputFormRepos repo = MagicMock(spec=HumanInputFormRepository) form_entity = MagicMock(spec=HumanInputFormEntity) form_entity.id = "test-form-id" - form_entity.web_app_token = "test-form-token" + form_entity.submission_token = "test-form-token" form_entity.recipients = [] form_entity.rendered_content = "rendered" form_entity.submitted = True @@ -66,7 +67,7 @@ def _mock_form_repository_with_submission(action_id: str) -> HumanInputFormRepos def _build_runtime_state(workflow_execution_id: str, app_id: str, workflow_id: str, user_id: str) -> GraphRuntimeState: variable_pool = VariablePool( - system_variables=SystemVariable( + system_variables=build_system_variables( workflow_execution_id=workflow_execution_id, app_id=app_id, workflow_id=workflow_id, @@ -120,6 +121,7 @@ def _build_graph( graph_init_params=params, graph_runtime_state=runtime_state, form_repository=form_repository, + runtime=DifyHumanInputNodeRuntime(params.run_context), ) end_data = EndNodeData( diff --git a/api/tests/test_containers_integration_tests/factories/test_storage_key_loader.py b/api/tests/test_containers_integration_tests/factories/test_storage_key_loader.py index 8e70fc0bb0..2e207ddc67 100644 --- a/api/tests/test_containers_integration_tests/factories/test_storage_key_loader.py +++ b/api/tests/test_containers_integration_tests/factories/test_storage_key_loader.py @@ -6,10 +6,11 @@ from uuid import uuid4 import pytest from sqlalchemy.orm import Session -from dify_graph.file import File, FileTransferMethod, FileType +from core.app.file_access import DatabaseFileAccessController from extensions.ext_database import db from extensions.storage.storage_type import StorageType from factories.file_factory import StorageKeyLoader +from graphon.file import File, FileTransferMethod, FileType from models import ToolFile, UploadFile from models.enums import CreatorUserRole @@ -35,7 +36,11 @@ class TestStorageKeyLoader(unittest.TestCase): self.test_tool_files = [] # Create StorageKeyLoader instance - self.loader = StorageKeyLoader(self.session, self.tenant_id) + self.loader = StorageKeyLoader( + self.session, + self.tenant_id, + access_controller=DatabaseFileAccessController(), + ) def tearDown(self): """Clean up test data after each test method.""" @@ -193,19 +198,16 @@ class TestStorageKeyLoader(unittest.TestCase): # Should not raise any exceptions self.loader.load_storage_keys([]) - def test_load_storage_keys_tenant_mismatch(self): - """Test tenant_id validation.""" - # Create file with different tenant_id + def test_load_storage_keys_ignores_legacy_file_tenant_id(self): + """Legacy file tenant_id should not override the loader tenant scope.""" upload_file = self._create_upload_file() file = self._create_file( related_id=upload_file.id, transfer_method=FileTransferMethod.LOCAL_FILE, tenant_id=str(uuid4()) ) - # Should raise ValueError for tenant mismatch - with pytest.raises(ValueError) as context: - self.loader.load_storage_keys([file]) + self.loader.load_storage_keys([file]) - assert "invalid file, expected tenant_id" in str(context.value) + assert file._storage_key == upload_file.key def test_load_storage_keys_missing_file_id(self): """Test with None file.related_id.""" @@ -314,7 +316,7 @@ class TestStorageKeyLoader(unittest.TestCase): with pytest.raises(ValueError) as context: self.loader.load_storage_keys([file_other]) - assert "invalid file, expected tenant_id" in str(context.value) + assert "Upload file not found for id:" in str(context.value) # Current tenant's file should still work self.loader.load_storage_keys([file_current]) @@ -338,7 +340,7 @@ class TestStorageKeyLoader(unittest.TestCase): with pytest.raises(ValueError) as context: self.loader.load_storage_keys([file_current, file_other]) - assert "invalid file, expected tenant_id" in str(context.value) + assert "Upload file not found for id:" in str(context.value) def test_load_storage_keys_duplicate_file_ids(self): """Test handling of duplicate file IDs in the batch.""" @@ -365,6 +367,10 @@ class TestStorageKeyLoader(unittest.TestCase): # Create loader with different session (same underlying connection) with Session(bind=db.engine) as other_session: - other_loader = StorageKeyLoader(other_session, self.tenant_id) + other_loader = StorageKeyLoader( + other_session, + self.tenant_id, + access_controller=DatabaseFileAccessController(), + ) with pytest.raises(ValueError): other_loader.load_storage_keys([file]) diff --git a/api/tests/test_containers_integration_tests/helpers/execution_extra_content.py b/api/tests/test_containers_integration_tests/helpers/execution_extra_content.py index fb8d1808f9..0fd03813da 100644 --- a/api/tests/test_containers_integration_tests/helpers/execution_extra_content.py +++ b/api/tests/test_containers_integration_tests/helpers/execution_extra_content.py @@ -5,7 +5,7 @@ from datetime import datetime, timedelta from decimal import Decimal from uuid import uuid4 -from dify_graph.nodes.human_input.entities import FormDefinition, UserAction +from graphon.nodes.human_input.entities import FormDefinition, UserAction from models.account import Account, Tenant, TenantAccountJoin from models.enums import ConversationFromSource, InvokeFrom from models.execution_extra_content import HumanInputContent diff --git a/api/tests/test_containers_integration_tests/libs/broadcast_channel/redis/test_streams_channel.py b/api/tests/test_containers_integration_tests/libs/broadcast_channel/redis/test_streams_channel.py new file mode 100644 index 0000000000..a79208f649 --- /dev/null +++ b/api/tests/test_containers_integration_tests/libs/broadcast_channel/redis/test_streams_channel.py @@ -0,0 +1,227 @@ +""" +Integration tests for Redis Streams broadcast channel implementation using TestContainers. + +This suite focuses on the semantics that differ from Redis Pub/Sub: +- Every active subscription should receive each newly published message. +- Each subscription should only observe messages published after its listener starts. +""" + +import threading +import time +import uuid +from collections.abc import Iterator +from concurrent.futures import ThreadPoolExecutor, as_completed + +import pytest +import redis +from testcontainers.redis import RedisContainer + +from libs.broadcast_channel.channel import BroadcastChannel, Subscription, Topic +from libs.broadcast_channel.exc import SubscriptionClosedError +from libs.broadcast_channel.redis.streams_channel import StreamsBroadcastChannel + + +class TestRedisStreamsBroadcastChannelIntegration: + """Integration tests for Redis Streams broadcast channel with a real Redis instance.""" + + @pytest.fixture(scope="class") + def redis_container(self) -> Iterator[RedisContainer]: + """Create a Redis container for integration testing.""" + with RedisContainer(image="redis:6-alpine") as container: + yield container + + @pytest.fixture(scope="class") + def redis_client(self, redis_container: RedisContainer) -> redis.Redis: + """Create a Redis client connected to the test container.""" + host = redis_container.get_container_host_ip() + port = redis_container.get_exposed_port(6379) + return redis.Redis(host=host, port=port, decode_responses=False) + + @pytest.fixture + def broadcast_channel(self, redis_client: redis.Redis) -> BroadcastChannel: + """Create a StreamsBroadcastChannel instance with a real Redis client.""" + return StreamsBroadcastChannel(redis_client) + + @classmethod + def _get_test_topic_name(cls) -> str: + return f"test_streams_topic_{uuid.uuid4()}" + + @staticmethod + def _start_subscription(subscription: Subscription) -> None: + """Start the background listener and confirm the subscription queue is empty.""" + assert subscription.receive(timeout=0.05) is None + + @staticmethod + def _receive_message(subscription: Subscription, *, timeout_seconds: float = 2.0) -> bytes: + """Poll until a message is received or the timeout expires.""" + deadline = time.monotonic() + timeout_seconds + while time.monotonic() < deadline: + message = subscription.receive(timeout=0.1) + if message is not None: + return message + pytest.fail("Timed out waiting for a message") + + def test_close_an_active_subscription_should_stop_iteration(self, broadcast_channel: BroadcastChannel) -> None: + """Closing an active subscription should terminate the iterator cleanly.""" + topic = broadcast_channel.topic(self._get_test_topic_name()) + subscription = topic.subscribe() + consuming_event = threading.Event() + + def consume() -> list[bytes]: + messages: list[bytes] = [] + consuming_event.set() + for message in subscription: + messages.append(message) + return messages + + with ThreadPoolExecutor(max_workers=1) as executor: + consumer_future = executor.submit(consume) + assert consuming_event.wait(timeout=1.0) + subscription.close() + assert consumer_future.result(timeout=2.0) == [] + + def test_end_to_end_messaging(self, broadcast_channel: BroadcastChannel) -> None: + """A producer should publish a message that a live subscription can consume.""" + topic = broadcast_channel.topic(self._get_test_topic_name()) + producer = topic.as_producer() + subscription = topic.subscribe() + message = b"hello streams" + + try: + self._start_subscription(subscription) + producer.publish(message) + + assert self._receive_message(subscription) == message + assert subscription.receive(timeout=0.1) is None + finally: + subscription.close() + + def test_multiple_subscriptions_each_receive_each_new_message(self, broadcast_channel: BroadcastChannel) -> None: + """Each active subscription should receive the same newly published message.""" + topic = broadcast_channel.topic(self._get_test_topic_name()) + subscriptions = [topic.subscribe() for _ in range(3)] + new_message = b"message-visible-to-every-subscriber" + + try: + for subscription in subscriptions: + self._start_subscription(subscription) + + topic.publish(new_message) + + for subscription in subscriptions: + assert self._receive_message(subscription) == new_message + assert subscription.receive(timeout=0.1) is None + finally: + for subscription in subscriptions: + subscription.close() + + def test_each_subscription_only_receives_messages_published_after_it_starts( + self, + broadcast_channel: BroadcastChannel, + ) -> None: + """A late subscription should not replay messages that existed before its listener started.""" + topic = broadcast_channel.topic(self._get_test_topic_name()) + first_subscription = topic.subscribe() + second_subscription = topic.subscribe() + message_before_any_subscription = b"before-any-subscription" + message_after_first_subscription = b"after-first-subscription" + message_after_second_subscription = b"after-second-subscription" + + try: + topic.publish(message_before_any_subscription) + + self._start_subscription(first_subscription) + topic.publish(message_after_first_subscription) + + assert self._receive_message(first_subscription) == message_after_first_subscription + assert first_subscription.receive(timeout=0.1) is None + + self._start_subscription(second_subscription) + topic.publish(message_after_second_subscription) + + assert self._receive_message(first_subscription) == message_after_second_subscription + assert self._receive_message(second_subscription) == message_after_second_subscription + assert first_subscription.receive(timeout=0.1) is None + assert second_subscription.receive(timeout=0.1) is None + finally: + first_subscription.close() + second_subscription.close() + + def test_topic_isolation(self, broadcast_channel: BroadcastChannel) -> None: + """Messages from different topics should remain isolated.""" + topic1 = broadcast_channel.topic(self._get_test_topic_name()) + topic2 = broadcast_channel.topic(self._get_test_topic_name()) + message1 = b"message-for-topic-1" + message2 = b"message-for-topic-2" + + def consume_single_message(topic: Topic) -> bytes: + subscription = topic.subscribe() + try: + self._start_subscription(subscription) + return self._receive_message(subscription) + finally: + subscription.close() + + with ThreadPoolExecutor(max_workers=3) as executor: + consumer1_future = executor.submit(consume_single_message, topic1) + consumer2_future = executor.submit(consume_single_message, topic2) + time.sleep(0.1) + topic1.publish(message1) + topic2.publish(message2) + + assert consumer1_future.result(timeout=5.0) == message1 + assert consumer2_future.result(timeout=5.0) == message2 + + def test_concurrent_producers_publish_all_messages(self, broadcast_channel: BroadcastChannel) -> None: + """Concurrent producers should not lose messages for a live subscription.""" + topic = broadcast_channel.topic(self._get_test_topic_name()) + subscription = topic.subscribe() + producer_count = 4 + messages_per_producer = 4 + expected_total = producer_count * messages_per_producer + consumer_ready = threading.Event() + + def produce_messages(producer_idx: int) -> set[bytes]: + producer = topic.as_producer() + produced: set[bytes] = set() + for message_idx in range(messages_per_producer): + payload = f"producer-{producer_idx}-message-{message_idx}".encode() + produced.add(payload) + producer.publish(payload) + time.sleep(0.001) + return produced + + def consume_messages() -> set[bytes]: + received: set[bytes] = set() + try: + self._start_subscription(subscription) + consumer_ready.set() + while len(received) < expected_total: + message = subscription.receive(timeout=0.2) + if message is not None: + received.add(message) + return received + finally: + subscription.close() + + with ThreadPoolExecutor(max_workers=producer_count + 1) as executor: + consumer_future = executor.submit(consume_messages) + assert consumer_ready.wait(timeout=2.0) + + producer_futures = [executor.submit(produce_messages, idx) for idx in range(producer_count)] + expected_messages: set[bytes] = set() + for future in as_completed(producer_futures, timeout=10.0): + expected_messages.update(future.result()) + + assert consumer_future.result(timeout=10.0) == expected_messages + + def test_receive_raises_subscription_closed_after_close(self, broadcast_channel: BroadcastChannel) -> None: + """Calling receive on a closed subscription should raise SubscriptionClosedError.""" + topic = broadcast_channel.topic(self._get_test_topic_name()) + subscription = topic.subscribe() + + self._start_subscription(subscription) + subscription.close() + + with pytest.raises(SubscriptionClosedError): + subscription.receive(timeout=0.1) diff --git a/api/tests/test_containers_integration_tests/repositories/test_execution_extra_content_repository.py b/api/tests/test_containers_integration_tests/repositories/test_execution_extra_content_repository.py deleted file mode 100644 index c9058626d1..0000000000 --- a/api/tests/test_containers_integration_tests/repositories/test_execution_extra_content_repository.py +++ /dev/null @@ -1,27 +0,0 @@ -from __future__ import annotations - -from sqlalchemy.orm import sessionmaker - -from extensions.ext_database import db -from repositories.sqlalchemy_execution_extra_content_repository import SQLAlchemyExecutionExtraContentRepository -from tests.test_containers_integration_tests.helpers.execution_extra_content import ( - create_human_input_message_fixture, -) - - -def test_get_by_message_ids_returns_human_input_content(db_session_with_containers): - fixture = create_human_input_message_fixture(db_session_with_containers) - repository = SQLAlchemyExecutionExtraContentRepository( - session_maker=sessionmaker(bind=db.engine, expire_on_commit=False) - ) - - results = repository.get_by_message_ids([fixture.message.id]) - - assert len(results) == 1 - assert len(results[0]) == 1 - content = results[0][0] - assert content.submitted is True - assert content.form_submission_data is not None - assert content.form_submission_data.action_id == fixture.action_id - assert content.form_submission_data.action_text == fixture.action_text - assert content.form_submission_data.rendered_content == fixture.form.rendered_content diff --git a/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_node_execution_repository.py b/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_node_execution_repository.py index 458862b0ec..641399c7f9 100644 --- a/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_node_execution_repository.py +++ b/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_node_execution_repository.py @@ -8,7 +8,7 @@ from uuid import uuid4 from sqlalchemy import Engine, delete from sqlalchemy.orm import Session, sessionmaker -from dify_graph.enums import WorkflowNodeExecutionStatus +from graphon.enums import WorkflowNodeExecutionStatus from libs.datetime_utils import naive_utc_now from models.enums import CreatorUserRole from models.workflow import WorkflowNodeExecutionModel diff --git a/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py b/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py index c3ed79656f..cb00752b35 100644 --- a/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py +++ b/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py @@ -2,7 +2,6 @@ from __future__ import annotations -import secrets from dataclasses import dataclass, field from datetime import datetime, timedelta from unittest.mock import Mock @@ -12,22 +11,20 @@ import pytest from sqlalchemy import Engine, delete, select from sqlalchemy.orm import Session, sessionmaker -from dify_graph.entities import WorkflowExecution -from dify_graph.entities.pause_reason import HumanInputRequired, PauseReasonType -from dify_graph.enums import WorkflowExecutionStatus -from dify_graph.nodes.human_input.entities import FormDefinition, FormInput, UserAction -from dify_graph.nodes.human_input.enums import DeliveryMethodType, FormInputType, HumanInputFormStatus from extensions.ext_storage import storage +from graphon.entities import WorkflowExecution +from graphon.entities.pause_reason import HumanInputRequired, PauseReasonType +from graphon.enums import WorkflowExecutionStatus +from graphon.nodes.human_input.entities import FormDefinition, FormInput, UserAction +from graphon.nodes.human_input.enums import FormInputType, HumanInputFormStatus from libs.datetime_utils import naive_utc_now from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom from models.human_input import ( - BackstageRecipientPayload, HumanInputDelivery, HumanInputForm, HumanInputFormRecipient, - RecipientType, ) -from models.workflow import WorkflowAppLog, WorkflowPause, WorkflowPauseReason, WorkflowRun +from models.workflow import WorkflowAppLog, WorkflowAppLogCreatedFrom, WorkflowPause, WorkflowPauseReason, WorkflowRun from repositories.entities.workflow_pause import WorkflowPauseEntity from repositories.sqlalchemy_api_workflow_run_repository import ( DifyAPISQLAlchemyWorkflowRunRepository, @@ -218,7 +215,7 @@ class TestDeleteRunsWithRelated: app_id=test_scope.app_id, workflow_id=test_scope.workflow_id, workflow_run_id=workflow_run.id, - created_from="service-api", + created_from=WorkflowAppLogCreatedFrom.SERVICE_API, created_by_role=CreatorUserRole.ACCOUNT, created_by=test_scope.user_id, ) @@ -278,7 +275,7 @@ class TestCountRunsWithRelated: app_id=test_scope.app_id, workflow_id=test_scope.workflow_id, workflow_run_id=workflow_run.id, - created_from="service-api", + created_from=WorkflowAppLogCreatedFrom.SERVICE_API, created_by_role=CreatorUserRole.ACCOUNT, created_by=test_scope.user_id, ) @@ -636,12 +633,12 @@ class TestPrivateWorkflowPauseEntity: class TestBuildHumanInputRequiredReason: """Integration tests for _build_human_input_required_reason using real DB models.""" - def test_prefers_backstage_token_when_available( + def test_builds_reason_from_form_definition( self, db_session_with_containers: Session, test_scope: _TestScope, ) -> None: - """Use backstage token when multiple recipient types may exist.""" + """Build the graph pause reason from the stored form definition.""" expiration_time = naive_utc_now() form_definition = FormDefinition( @@ -668,25 +665,6 @@ class TestBuildHumanInputRequiredReason: db_session_with_containers.add(form_model) db_session_with_containers.flush() - delivery = HumanInputDelivery( - form_id=form_model.id, - delivery_method_type=DeliveryMethodType.WEBAPP, - channel_payload="{}", - ) - db_session_with_containers.add(delivery) - db_session_with_containers.flush() - - access_token = secrets.token_urlsafe(8) - recipient = HumanInputFormRecipient( - form_id=form_model.id, - delivery_id=delivery.id, - recipient_type=RecipientType.BACKSTAGE, - recipient_payload=BackstageRecipientPayload().model_dump_json(), - access_token=access_token, - ) - db_session_with_containers.add(recipient) - db_session_with_containers.flush() - # Create a pause so the reason has a valid pause_id workflow_run = _create_workflow_run( db_session_with_containers, @@ -716,13 +694,12 @@ class TestBuildHumanInputRequiredReason: # Refresh to ensure we have DB-round-tripped objects db_session_with_containers.refresh(form_model) db_session_with_containers.refresh(reason_model) - db_session_with_containers.refresh(recipient) - reason = _build_human_input_required_reason(reason_model, form_model, [recipient]) + reason = _build_human_input_required_reason(reason_model, form_model) assert isinstance(reason, HumanInputRequired) - assert reason.form_token == access_token assert reason.node_title == "Ask Name" assert reason.form_content == "content" assert reason.inputs[0].output_variable_name == "name" assert reason.actions[0].id == "approve" + assert reason.resolved_default_values == {"name": "Alice"} diff --git a/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_execution_extra_content_repository.py b/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_execution_extra_content_repository.py new file mode 100644 index 0000000000..3d4ec25150 --- /dev/null +++ b/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_execution_extra_content_repository.py @@ -0,0 +1,407 @@ +"""Integration tests for SQLAlchemyExecutionExtraContentRepository using Testcontainers. + +Part of #32454 — replaces the mock-based unit tests with real database interactions. +""" + +from __future__ import annotations + +from collections.abc import Generator +from dataclasses import dataclass +from datetime import datetime, timedelta +from decimal import Decimal +from uuid import uuid4 + +import pytest +from sqlalchemy import Engine, delete, select +from sqlalchemy.orm import Session, sessionmaker + +from graphon.nodes.human_input.entities import FormDefinition, UserAction +from graphon.nodes.human_input.enums import HumanInputFormStatus +from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole +from models.enums import ConversationFromSource, InvokeFrom +from models.execution_extra_content import ExecutionExtraContent, HumanInputContent +from models.human_input import ( + ConsoleRecipientPayload, + HumanInputDelivery, + HumanInputForm, + HumanInputFormRecipient, + RecipientType, +) +from models.model import App, Conversation, Message +from repositories.sqlalchemy_execution_extra_content_repository import SQLAlchemyExecutionExtraContentRepository + + +@dataclass +class _TestScope: + """Per-test data scope used to isolate DB rows. + + IDs are populated after flushing the base entities to the database. + """ + + tenant_id: str = "" + app_id: str = "" + user_id: str = "" + + +def _cleanup_scope_data(session: Session, scope: _TestScope) -> None: + """Remove test-created DB rows for a test scope.""" + form_ids_subquery = select(HumanInputForm.id).where( + HumanInputForm.tenant_id == scope.tenant_id, + ) + session.execute(delete(HumanInputFormRecipient).where(HumanInputFormRecipient.form_id.in_(form_ids_subquery))) + session.execute(delete(HumanInputDelivery).where(HumanInputDelivery.form_id.in_(form_ids_subquery))) + session.execute( + delete(ExecutionExtraContent).where( + ExecutionExtraContent.workflow_run_id.in_( + select(HumanInputForm.workflow_run_id).where(HumanInputForm.tenant_id == scope.tenant_id) + ) + ) + ) + session.execute(delete(HumanInputForm).where(HumanInputForm.tenant_id == scope.tenant_id)) + session.execute(delete(Message).where(Message.app_id == scope.app_id)) + session.execute(delete(Conversation).where(Conversation.app_id == scope.app_id)) + session.execute(delete(App).where(App.id == scope.app_id)) + session.execute(delete(TenantAccountJoin).where(TenantAccountJoin.tenant_id == scope.tenant_id)) + session.execute(delete(Account).where(Account.id == scope.user_id)) + session.execute(delete(Tenant).where(Tenant.id == scope.tenant_id)) + session.commit() + + +def _seed_base_entities(session: Session, scope: _TestScope) -> None: + """Create the base tenant, account, and app needed by tests.""" + tenant = Tenant(name="Test Tenant") + session.add(tenant) + session.flush() + scope.tenant_id = tenant.id + + account = Account( + name="Test Account", + email=f"test_{uuid4()}@example.com", + password="hashed-password", + password_salt="salt", + interface_language="en-US", + timezone="UTC", + ) + session.add(account) + session.flush() + scope.user_id = account.id + + tenant_join = TenantAccountJoin( + tenant_id=scope.tenant_id, + account_id=scope.user_id, + role=TenantAccountRole.OWNER, + current=True, + ) + session.add(tenant_join) + + app = App( + tenant_id=scope.tenant_id, + name="Test App", + description="", + mode="chat", + icon_type="emoji", + icon="bot", + icon_background="#FFFFFF", + enable_site=False, + enable_api=True, + api_rpm=100, + api_rph=100, + is_demo=False, + is_public=False, + is_universal=False, + created_by=scope.user_id, + updated_by=scope.user_id, + ) + session.add(app) + session.flush() + scope.app_id = app.id + + +def _create_conversation(session: Session, scope: _TestScope) -> Conversation: + conversation = Conversation( + app_id=scope.app_id, + mode="chat", + name="Test Conversation", + summary="", + introduction="", + system_instruction="", + status="normal", + invoke_from=InvokeFrom.EXPLORE, + from_source=ConversationFromSource.CONSOLE, + from_account_id=scope.user_id, + from_end_user_id=None, + ) + conversation.inputs = {} + session.add(conversation) + session.flush() + return conversation + + +def _create_message( + session: Session, + scope: _TestScope, + conversation_id: str, + workflow_run_id: str, +) -> Message: + message = Message( + app_id=scope.app_id, + conversation_id=conversation_id, + inputs={}, + query="test query", + message={"messages": []}, + answer="test answer", + message_tokens=50, + message_unit_price=Decimal("0.001"), + answer_tokens=80, + answer_unit_price=Decimal("0.001"), + provider_response_latency=0.5, + currency="USD", + from_source=ConversationFromSource.CONSOLE, + from_account_id=scope.user_id, + workflow_run_id=workflow_run_id, + ) + session.add(message) + session.flush() + return message + + +def _create_submitted_form( + session: Session, + scope: _TestScope, + *, + workflow_run_id: str, + action_id: str = "approve", + action_title: str = "Approve", + node_title: str = "Approval", +) -> HumanInputForm: + expiration_time = datetime.utcnow() + timedelta(days=1) + form_definition = FormDefinition( + form_content="content", + inputs=[], + user_actions=[UserAction(id=action_id, title=action_title)], + rendered_content="rendered", + expiration_time=expiration_time, + node_title=node_title, + display_in_ui=True, + ) + form = HumanInputForm( + tenant_id=scope.tenant_id, + app_id=scope.app_id, + workflow_run_id=workflow_run_id, + node_id="node-id", + form_definition=form_definition.model_dump_json(), + rendered_content=f"Rendered {action_title}", + status=HumanInputFormStatus.SUBMITTED, + expiration_time=expiration_time, + selected_action_id=action_id, + ) + session.add(form) + session.flush() + return form + + +def _create_waiting_form( + session: Session, + scope: _TestScope, + *, + workflow_run_id: str, + default_values: dict | None = None, +) -> HumanInputForm: + expiration_time = datetime.utcnow() + timedelta(days=1) + form_definition = FormDefinition( + form_content="content", + inputs=[], + user_actions=[UserAction(id="approve", title="Approve")], + rendered_content="rendered", + expiration_time=expiration_time, + default_values=default_values or {"name": "John"}, + node_title="Approval", + display_in_ui=True, + ) + form = HumanInputForm( + tenant_id=scope.tenant_id, + app_id=scope.app_id, + workflow_run_id=workflow_run_id, + node_id="node-id", + form_definition=form_definition.model_dump_json(), + rendered_content="Rendered block", + status=HumanInputFormStatus.WAITING, + expiration_time=expiration_time, + ) + session.add(form) + session.flush() + return form + + +def _create_human_input_content( + session: Session, + *, + workflow_run_id: str, + message_id: str, + form_id: str, +) -> HumanInputContent: + content = HumanInputContent.new( + workflow_run_id=workflow_run_id, + message_id=message_id, + form_id=form_id, + ) + session.add(content) + return content + + +def _create_recipient( + session: Session, + *, + form_id: str, + delivery_id: str, + recipient_type: RecipientType = RecipientType.CONSOLE, + access_token: str = "token-1", +) -> HumanInputFormRecipient: + payload = ConsoleRecipientPayload(account_id=None) + recipient = HumanInputFormRecipient( + form_id=form_id, + delivery_id=delivery_id, + recipient_type=recipient_type, + recipient_payload=payload.model_dump_json(), + access_token=access_token, + ) + session.add(recipient) + return recipient + + +def _create_delivery(session: Session, *, form_id: str) -> HumanInputDelivery: + from core.workflow.human_input_compat import DeliveryMethodType + from models.human_input import ConsoleDeliveryPayload + + delivery = HumanInputDelivery( + form_id=form_id, + delivery_method_type=DeliveryMethodType.WEBAPP, + channel_payload=ConsoleDeliveryPayload().model_dump_json(), + ) + session.add(delivery) + session.flush() + return delivery + + +@pytest.fixture +def repository(db_session_with_containers: Session) -> SQLAlchemyExecutionExtraContentRepository: + """Build a repository backed by the testcontainers database engine.""" + engine = db_session_with_containers.get_bind() + assert isinstance(engine, Engine) + return SQLAlchemyExecutionExtraContentRepository(sessionmaker(bind=engine, expire_on_commit=False)) + + +@pytest.fixture +def test_scope(db_session_with_containers: Session) -> Generator[_TestScope]: + """Provide an isolated scope and clean related data after each test.""" + scope = _TestScope() + _seed_base_entities(db_session_with_containers, scope) + db_session_with_containers.commit() + yield scope + _cleanup_scope_data(db_session_with_containers, scope) + + +class TestGetByMessageIds: + """Tests for SQLAlchemyExecutionExtraContentRepository.get_by_message_ids.""" + + def test_groups_contents_by_message( + self, + db_session_with_containers: Session, + repository: SQLAlchemyExecutionExtraContentRepository, + test_scope: _TestScope, + ) -> None: + """Submitted forms are correctly mapped and grouped by message ID.""" + workflow_run_id = str(uuid4()) + conversation = _create_conversation(db_session_with_containers, test_scope) + msg1 = _create_message(db_session_with_containers, test_scope, conversation.id, workflow_run_id) + msg2 = _create_message(db_session_with_containers, test_scope, conversation.id, workflow_run_id) + + form = _create_submitted_form( + db_session_with_containers, + test_scope, + workflow_run_id=workflow_run_id, + action_id="approve", + action_title="Approve", + ) + _create_human_input_content( + db_session_with_containers, + workflow_run_id=workflow_run_id, + message_id=msg1.id, + form_id=form.id, + ) + db_session_with_containers.commit() + + result = repository.get_by_message_ids([msg1.id, msg2.id]) + + assert len(result) == 2 + # msg1 has one submitted content + assert len(result[0]) == 1 + content = result[0][0] + assert content.submitted is True + assert content.workflow_run_id == workflow_run_id + assert content.form_submission_data is not None + assert content.form_submission_data.action_id == "approve" + assert content.form_submission_data.action_text == "Approve" + assert content.form_submission_data.rendered_content == "Rendered Approve" + assert content.form_submission_data.node_id == "node-id" + assert content.form_submission_data.node_title == "Approval" + # msg2 has no content + assert result[1] == [] + + def test_returns_unsubmitted_form_definition( + self, + db_session_with_containers: Session, + repository: SQLAlchemyExecutionExtraContentRepository, + test_scope: _TestScope, + ) -> None: + """Waiting forms return full form_definition with resolved token and defaults.""" + workflow_run_id = str(uuid4()) + conversation = _create_conversation(db_session_with_containers, test_scope) + msg = _create_message(db_session_with_containers, test_scope, conversation.id, workflow_run_id) + + form = _create_waiting_form( + db_session_with_containers, + test_scope, + workflow_run_id=workflow_run_id, + default_values={"name": "John"}, + ) + delivery = _create_delivery(db_session_with_containers, form_id=form.id) + _create_recipient( + db_session_with_containers, + form_id=form.id, + delivery_id=delivery.id, + access_token="token-1", + ) + _create_human_input_content( + db_session_with_containers, + workflow_run_id=workflow_run_id, + message_id=msg.id, + form_id=form.id, + ) + db_session_with_containers.commit() + + result = repository.get_by_message_ids([msg.id]) + + assert len(result) == 1 + assert len(result[0]) == 1 + domain_content = result[0][0] + assert domain_content.submitted is False + assert domain_content.workflow_run_id == workflow_run_id + assert domain_content.form_definition is not None + form_def = domain_content.form_definition + assert form_def.form_id == form.id + assert form_def.node_id == "node-id" + assert form_def.node_title == "Approval" + assert form_def.form_content == "Rendered block" + assert form_def.display_in_ui is True + assert form_def.form_token == "token-1" + assert form_def.resolved_default_values == {"name": "John"} + assert form_def.expiration_time == int(form.expiration_time.timestamp()) + + def test_empty_message_ids_returns_empty_list( + self, + repository: SQLAlchemyExecutionExtraContentRepository, + ) -> None: + """Passing no message IDs returns an empty list without hitting the DB.""" + result = repository.get_by_message_ids([]) + assert result == [] diff --git a/api/tests/test_containers_integration_tests/repositories/test_workflow_run_repository.py b/api/tests/test_containers_integration_tests/repositories/test_workflow_run_repository.py index 1568d5d65c..d6f0657380 100644 --- a/api/tests/test_containers_integration_tests/repositories/test_workflow_run_repository.py +++ b/api/tests/test_containers_integration_tests/repositories/test_workflow_run_repository.py @@ -11,8 +11,8 @@ from sqlalchemy import Engine, delete from sqlalchemy import exc as sa_exc from sqlalchemy.orm import Session, sessionmaker -from dify_graph.entities import WorkflowExecution -from dify_graph.enums import WorkflowExecutionStatus +from graphon.entities import WorkflowExecution +from graphon.enums import WorkflowExecutionStatus from libs.datetime_utils import naive_utc_now from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom from models.workflow import WorkflowRun, WorkflowType diff --git a/api/tests/test_containers_integration_tests/services/dataset_service_update_delete.py b/api/tests/test_containers_integration_tests/services/dataset_service_update_delete.py index 6b35f867d7..02c3d1a80e 100644 --- a/api/tests/test_containers_integration_tests/services/dataset_service_update_delete.py +++ b/api/tests/test_containers_integration_tests/services/dataset_service_update_delete.py @@ -13,6 +13,7 @@ import pytest from sqlalchemy.orm import Session from werkzeug.exceptions import NotFound +from core.rag.index_processor.constant.index_type import IndexTechniqueType from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import AppDatasetJoin, Dataset, DatasetPermissionEnum from models.enums import DataSourceType @@ -74,7 +75,7 @@ class DatasetUpdateDeleteTestDataFactory: name=name, description="Test description", data_source_type=DataSourceType.UPLOAD_FILE, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, created_by=created_by, permission=permission, provider="vendor", diff --git a/api/tests/test_containers_integration_tests/services/document_service_status.py b/api/tests/test_containers_integration_tests/services/document_service_status.py index f995ac7bef..42d587b7f7 100644 --- a/api/tests/test_containers_integration_tests/services/document_service_status.py +++ b/api/tests/test_containers_integration_tests/services/document_service_status.py @@ -13,6 +13,7 @@ from uuid import uuid4 import pytest +from core.rag.index_processor.constant.index_type import IndexStructureType from extensions.storage.storage_type import StorageType from models import Account from models.dataset import Dataset, Document @@ -91,7 +92,7 @@ class DocumentStatusTestDataFactory: name=name, created_from=DocumentCreatedFrom.WEB, created_by=created_by, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, ) document.id = document_id document.indexing_status = indexing_status diff --git a/api/tests/unit_tests/dify_graph/model_runtime/__base/__init__.py b/api/tests/test_containers_integration_tests/services/enterprise/__init__.py similarity index 100% rename from api/tests/unit_tests/dify_graph/model_runtime/__base/__init__.py rename to api/tests/test_containers_integration_tests/services/enterprise/__init__.py diff --git a/api/tests/test_containers_integration_tests/services/enterprise/test_account_deletion_sync.py b/api/tests/test_containers_integration_tests/services/enterprise/test_account_deletion_sync.py new file mode 100644 index 0000000000..4e8255d8ed --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/enterprise/test_account_deletion_sync.py @@ -0,0 +1,200 @@ +"""Integration tests for account deletion synchronization. + +Verifies enterprise account deletion sync functionality including +Redis queuing, error handling, and community vs enterprise behavior. +""" + +from __future__ import annotations + +from unittest.mock import patch +from uuid import uuid4 + +import pytest +from redis import RedisError + +from extensions.ext_redis import redis_client +from models.account import TenantAccountJoin +from services.enterprise.account_deletion_sync import ( + _queue_task, + sync_account_deletion, + sync_workspace_member_removal, +) + + +class TestQueueTask: + def test_queue_task_success(self): + workspace_id = str(uuid4()) + member_id = str(uuid4()) + + result = _queue_task(workspace_id=workspace_id, member_id=member_id, source="test_source") + + assert result is True + + import json + + raw = redis_client.rpop("enterprise:member:sync:queue") + assert raw is not None + task_data = json.loads(raw) + assert task_data["workspace_id"] == workspace_id + assert task_data["member_id"] == member_id + assert task_data["source"] == "test_source" + assert task_data["type"] == "sync_member_deletion_from_workspace" + assert task_data["retry_count"] == 0 + assert "task_id" in task_data + assert "created_at" in task_data + + def test_queue_task_redis_error(self, caplog): + with patch("services.enterprise.account_deletion_sync.redis_client") as mock_redis: + mock_redis.lpush.side_effect = RedisError("Connection failed") + + result = _queue_task(workspace_id="ws-123", member_id="member-456", source="test_source") + + assert result is False + assert "Failed to queue account deletion sync" in caplog.text + + def test_queue_task_type_error(self, caplog): + with patch("services.enterprise.account_deletion_sync.redis_client") as mock_redis: + mock_redis.lpush.side_effect = TypeError("Cannot serialize") + + result = _queue_task(workspace_id="ws-123", member_id="member-456", source="test_source") + + assert result is False + assert "Failed to queue account deletion sync" in caplog.text + + +class TestSyncWorkspaceMemberRemoval: + @pytest.fixture + def mock_queue_task(self): + with patch("services.enterprise.account_deletion_sync._queue_task") as mock_queue: + mock_queue.return_value = True + yield mock_queue + + def test_sync_workspace_member_removal_enterprise_enabled(self, mock_queue_task): + workspace_id = str(uuid4()) + member_id = str(uuid4()) + + with patch("services.enterprise.account_deletion_sync.dify_config") as mock_config: + mock_config.ENTERPRISE_ENABLED = True + + result = sync_workspace_member_removal(workspace_id=workspace_id, member_id=member_id, source="removed") + + assert result is True + mock_queue_task.assert_called_once_with(workspace_id=workspace_id, member_id=member_id, source="removed") + + def test_sync_workspace_member_removal_enterprise_disabled(self, mock_queue_task): + with patch("services.enterprise.account_deletion_sync.dify_config") as mock_config: + mock_config.ENTERPRISE_ENABLED = False + + result = sync_workspace_member_removal( + workspace_id=str(uuid4()), member_id=str(uuid4()), source="test_source" + ) + + assert result is True + mock_queue_task.assert_not_called() + + def test_sync_workspace_member_removal_queue_failure(self, mock_queue_task): + mock_queue_task.return_value = False + + with patch("services.enterprise.account_deletion_sync.dify_config") as mock_config: + mock_config.ENTERPRISE_ENABLED = True + + result = sync_workspace_member_removal( + workspace_id=str(uuid4()), member_id=str(uuid4()), source="test_source" + ) + + assert result is False + + +class TestSyncAccountDeletion: + @pytest.fixture + def mock_queue_task(self): + with patch("services.enterprise.account_deletion_sync._queue_task") as mock_queue: + mock_queue.return_value = True + yield mock_queue + + def test_sync_account_deletion_enterprise_disabled(self, mock_queue_task): + with patch("services.enterprise.account_deletion_sync.dify_config") as mock_config: + mock_config.ENTERPRISE_ENABLED = False + + result = sync_account_deletion(account_id=str(uuid4()), source="account_deleted") + + assert result is True + mock_queue_task.assert_not_called() + + def test_sync_account_deletion_multiple_workspaces( + self, flask_app_with_containers, db_session_with_containers, mock_queue_task + ): + account_id = str(uuid4()) + tenant_ids = [str(uuid4()) for _ in range(3)] + + for tenant_id in tenant_ids: + join = TenantAccountJoin(tenant_id=tenant_id, account_id=account_id) + db_session_with_containers.add(join) + db_session_with_containers.commit() + + with patch("services.enterprise.account_deletion_sync.dify_config") as mock_config: + mock_config.ENTERPRISE_ENABLED = True + + result = sync_account_deletion(account_id=account_id, source="account_deleted") + + assert result is True + assert mock_queue_task.call_count == 3 + + queued_workspace_ids = {call.kwargs["workspace_id"] for call in mock_queue_task.call_args_list} + assert queued_workspace_ids == set(tenant_ids) + + def test_sync_account_deletion_no_workspaces( + self, flask_app_with_containers, db_session_with_containers, mock_queue_task + ): + with patch("services.enterprise.account_deletion_sync.dify_config") as mock_config: + mock_config.ENTERPRISE_ENABLED = True + + result = sync_account_deletion(account_id=str(uuid4()), source="account_deleted") + + assert result is True + mock_queue_task.assert_not_called() + + def test_sync_account_deletion_partial_failure( + self, flask_app_with_containers, db_session_with_containers, mock_queue_task + ): + account_id = str(uuid4()) + tenant_ids = [str(uuid4()) for _ in range(3)] + fail_tenant = tenant_ids[1] + + for tenant_id in tenant_ids: + join = TenantAccountJoin(tenant_id=tenant_id, account_id=account_id) + db_session_with_containers.add(join) + db_session_with_containers.commit() + + def queue_side_effect(workspace_id, member_id, source): + return workspace_id != fail_tenant + + mock_queue_task.side_effect = queue_side_effect + + with patch("services.enterprise.account_deletion_sync.dify_config") as mock_config: + mock_config.ENTERPRISE_ENABLED = True + + result = sync_account_deletion(account_id=account_id, source="account_deleted") + + assert result is False + assert mock_queue_task.call_count == 3 + + def test_sync_account_deletion_all_failures( + self, flask_app_with_containers, db_session_with_containers, mock_queue_task + ): + account_id = str(uuid4()) + tenant_id = str(uuid4()) + + join = TenantAccountJoin(tenant_id=tenant_id, account_id=account_id) + db_session_with_containers.add(join) + db_session_with_containers.commit() + + mock_queue_task.return_value = False + + with patch("services.enterprise.account_deletion_sync.dify_config") as mock_config: + mock_config.ENTERPRISE_ENABLED = True + + result = sync_account_deletion(account_id=account_id, source="account_deleted") + + assert result is False + mock_queue_task.assert_called_once() diff --git a/api/tests/unit_tests/dify_graph/model_runtime/__init__.py b/api/tests/test_containers_integration_tests/services/plugin/__init__.py similarity index 100% rename from api/tests/unit_tests/dify_graph/model_runtime/__init__.py rename to api/tests/test_containers_integration_tests/services/plugin/__init__.py diff --git a/api/tests/unit_tests/services/plugin/test_plugin_parameter_service.py b/api/tests/test_containers_integration_tests/services/plugin/test_plugin_parameter_service.py similarity index 78% rename from api/tests/unit_tests/services/plugin/test_plugin_parameter_service.py rename to api/tests/test_containers_integration_tests/services/plugin/test_plugin_parameter_service.py index bfa9fe976b..3885137221 100644 --- a/api/tests/unit_tests/services/plugin/test_plugin_parameter_service.py +++ b/api/tests/test_containers_integration_tests/services/plugin/test_plugin_parameter_service.py @@ -6,10 +6,13 @@ HIDDEN_VALUE replacement, and error handling for missing records. from __future__ import annotations +import json from unittest.mock import MagicMock, patch +from uuid import uuid4 import pytest +from models.tools import BuiltinToolProvider from services.plugin.plugin_parameter_service import PluginParameterService @@ -39,67 +42,73 @@ class TestGetDynamicSelectOptionsTool: @patch("services.plugin.plugin_parameter_service.DynamicSelectClient") @patch("services.plugin.plugin_parameter_service.create_tool_provider_encrypter") - @patch("services.plugin.plugin_parameter_service.db") @patch("services.plugin.plugin_parameter_service.ToolManager") - def test_fetches_credentials_with_credential_id(self, mock_tool_mgr, mock_db, mock_encrypter_fn, mock_client_cls): + def test_fetches_credentials_with_credential_id( + self, + mock_tool_mgr, + mock_encrypter_fn, + mock_client_cls, + flask_app_with_containers, + db_session_with_containers, + ): + tenant_id = str(uuid4()) provider_ctrl = MagicMock() provider_ctrl.need_credentials = True mock_tool_mgr.get_builtin_provider.return_value = provider_ctrl encrypter = MagicMock() encrypter.decrypt.return_value = {"api_key": "decrypted"} mock_encrypter_fn.return_value = (encrypter, None) + mock_client_cls.return_value.fetch_dynamic_select_options.return_value.options = ["opt1"] - # Mock the Session/query chain - db_record = MagicMock() - db_record.credentials = {"api_key": "encrypted"} - db_record.credential_type = "api_key" + db_record = BuiltinToolProvider( + tenant_id=tenant_id, + user_id=str(uuid4()), + provider="google", + name="API KEY 1", + encrypted_credentials=json.dumps({"api_key": "encrypted"}), + credential_type="api_key", + ) + db_session_with_containers.add(db_record) + db_session_with_containers.commit() - with patch("services.plugin.plugin_parameter_service.Session") as mock_session_cls: - mock_session = MagicMock() - mock_session_cls.return_value.__enter__ = MagicMock(return_value=mock_session) - mock_session_cls.return_value.__exit__ = MagicMock(return_value=False) - mock_session.query.return_value.where.return_value.first.return_value = db_record - mock_client_cls.return_value.fetch_dynamic_select_options.return_value.options = ["opt1"] - - result = PluginParameterService.get_dynamic_select_options( - tenant_id="t1", - user_id="u1", - plugin_id="p1", - provider="google", - action="search", - parameter="engine", - credential_id="cred-1", - provider_type="tool", - ) + result = PluginParameterService.get_dynamic_select_options( + tenant_id=tenant_id, + user_id="u1", + plugin_id="p1", + provider="google", + action="search", + parameter="engine", + credential_id=db_record.id, + provider_type="tool", + ) assert result == ["opt1"] @patch("services.plugin.plugin_parameter_service.create_tool_provider_encrypter") - @patch("services.plugin.plugin_parameter_service.db") @patch("services.plugin.plugin_parameter_service.ToolManager") - def test_raises_when_tool_provider_not_found(self, mock_tool_mgr, mock_db, mock_encrypter_fn): + def test_raises_when_tool_provider_not_found( + self, + mock_tool_mgr, + mock_encrypter_fn, + flask_app_with_containers, + db_session_with_containers, + ): provider_ctrl = MagicMock() provider_ctrl.need_credentials = True mock_tool_mgr.get_builtin_provider.return_value = provider_ctrl mock_encrypter_fn.return_value = (MagicMock(), None) - with patch("services.plugin.plugin_parameter_service.Session") as mock_session_cls: - mock_session = MagicMock() - mock_session_cls.return_value.__enter__ = MagicMock(return_value=mock_session) - mock_session_cls.return_value.__exit__ = MagicMock(return_value=False) - mock_session.query.return_value.where.return_value.order_by.return_value.first.return_value = None - - with pytest.raises(ValueError, match="not found"): - PluginParameterService.get_dynamic_select_options( - tenant_id="t1", - user_id="u1", - plugin_id="p1", - provider="google", - action="search", - parameter="engine", - credential_id=None, - provider_type="tool", - ) + with pytest.raises(ValueError, match="not found"): + PluginParameterService.get_dynamic_select_options( + tenant_id=str(uuid4()), + user_id="u1", + plugin_id="p1", + provider="google", + action="search", + parameter="engine", + credential_id=None, + provider_type="tool", + ) class TestGetDynamicSelectOptionsTrigger: diff --git a/api/tests/test_containers_integration_tests/services/recommend_app/__init__.py b/api/tests/test_containers_integration_tests/services/recommend_app/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/test_containers_integration_tests/services/recommend_app/test_database_retrieval.py b/api/tests/test_containers_integration_tests/services/recommend_app/test_database_retrieval.py new file mode 100644 index 0000000000..2b842629a7 --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/recommend_app/test_database_retrieval.py @@ -0,0 +1,184 @@ +from __future__ import annotations + +from unittest.mock import patch +from uuid import uuid4 + +from models.model import App, RecommendedApp, Site +from services.recommend_app.database.database_retrieval import DatabaseRecommendAppRetrieval +from services.recommend_app.recommend_app_type import RecommendAppType + + +def _create_app(db_session, *, tenant_id: str, is_public: bool = True) -> App: + app = App( + tenant_id=tenant_id, + name=f"app-{uuid4()}", + mode="chat", + enable_site=True, + enable_api=True, + is_public=is_public, + ) + app.id = str(uuid4()) + db_session.add(app) + db_session.commit() + return app + + +def _create_site(db_session, *, app_id: str) -> Site: + site = Site( + app_id=app_id, + title=f"site-{uuid4()}", + default_language="en-US", + customize_token_strategy="not_allow", + description="desc", + copyright="copy", + privacy_policy="pp", + custom_disclaimer="cd", + ) + site.id = str(uuid4()) + db_session.add(site) + db_session.commit() + return site + + +def _create_recommended_app( + db_session, + *, + app_id: str, + category: str = "chat", + language: str = "en-US", + is_listed: bool = True, + position: int = 1, +) -> RecommendedApp: + rec = RecommendedApp( + app_id=app_id, + description={"en-US": "test"}, + copyright="copy", + privacy_policy="pp", + category=category, + language=language, + is_listed=is_listed, + position=position, + ) + rec.id = str(uuid4()) + db_session.add(rec) + db_session.commit() + return rec + + +class TestDatabaseRecommendAppRetrieval: + def test_get_type(self): + assert DatabaseRecommendAppRetrieval().get_type() == RecommendAppType.DATABASE + + def test_get_recommended_apps_delegates(self): + with patch.object( + DatabaseRecommendAppRetrieval, + "fetch_recommended_apps_from_db", + return_value={"recommended_apps": [], "categories": []}, + ) as mock_fetch: + result = DatabaseRecommendAppRetrieval().get_recommended_apps_and_categories("en-US") + mock_fetch.assert_called_once_with("en-US") + assert result == {"recommended_apps": [], "categories": []} + + def test_get_recommend_app_detail_delegates(self): + with patch.object( + DatabaseRecommendAppRetrieval, + "fetch_recommended_app_detail_from_db", + return_value={"id": "app-1"}, + ) as mock_fetch: + result = DatabaseRecommendAppRetrieval().get_recommend_app_detail("app-1") + mock_fetch.assert_called_once_with("app-1") + assert result == {"id": "app-1"} + + +class TestFetchRecommendedAppsFromDb: + def test_returns_apps_and_sorted_categories(self, flask_app_with_containers, db_session_with_containers): + tenant_id = str(uuid4()) + app1 = _create_app(db_session_with_containers, tenant_id=tenant_id) + _create_site(db_session_with_containers, app_id=app1.id) + _create_recommended_app(db_session_with_containers, app_id=app1.id, category="writing") + + app2 = _create_app(db_session_with_containers, tenant_id=tenant_id) + _create_site(db_session_with_containers, app_id=app2.id) + _create_recommended_app(db_session_with_containers, app_id=app2.id, category="assistant") + + db_session_with_containers.expire_all() + + result = DatabaseRecommendAppRetrieval.fetch_recommended_apps_from_db("en-US") + + app_ids = {r["app_id"] for r in result["recommended_apps"]} + assert app1.id in app_ids + assert app2.id in app_ids + assert "assistant" in result["categories"] + assert "writing" in result["categories"] + + def test_falls_back_to_default_language_when_empty(self, flask_app_with_containers, db_session_with_containers): + tenant_id = str(uuid4()) + app1 = _create_app(db_session_with_containers, tenant_id=tenant_id) + _create_site(db_session_with_containers, app_id=app1.id) + _create_recommended_app(db_session_with_containers, app_id=app1.id, language="en-US") + + db_session_with_containers.expire_all() + + result = DatabaseRecommendAppRetrieval.fetch_recommended_apps_from_db("fr-FR") + + app_ids = {r["app_id"] for r in result["recommended_apps"]} + assert app1.id in app_ids + + def test_skips_non_public_apps(self, flask_app_with_containers, db_session_with_containers): + tenant_id = str(uuid4()) + app1 = _create_app(db_session_with_containers, tenant_id=tenant_id, is_public=False) + _create_site(db_session_with_containers, app_id=app1.id) + _create_recommended_app(db_session_with_containers, app_id=app1.id) + + db_session_with_containers.expire_all() + + result = DatabaseRecommendAppRetrieval.fetch_recommended_apps_from_db("en-US") + + app_ids = {r["app_id"] for r in result["recommended_apps"]} + assert app1.id not in app_ids + + def test_skips_apps_without_site(self, flask_app_with_containers, db_session_with_containers): + tenant_id = str(uuid4()) + app1 = _create_app(db_session_with_containers, tenant_id=tenant_id) + _create_recommended_app(db_session_with_containers, app_id=app1.id) + + db_session_with_containers.expire_all() + + result = DatabaseRecommendAppRetrieval.fetch_recommended_apps_from_db("en-US") + + app_ids = {r["app_id"] for r in result["recommended_apps"]} + assert app1.id not in app_ids + + +class TestFetchRecommendedAppDetailFromDb: + def test_returns_none_when_not_listed(self, flask_app_with_containers, db_session_with_containers): + result = DatabaseRecommendAppRetrieval.fetch_recommended_app_detail_from_db(str(uuid4())) + + assert result is None + + def test_returns_none_when_app_not_public(self, flask_app_with_containers, db_session_with_containers): + tenant_id = str(uuid4()) + app1 = _create_app(db_session_with_containers, tenant_id=tenant_id, is_public=False) + _create_recommended_app(db_session_with_containers, app_id=app1.id) + + db_session_with_containers.expire_all() + + result = DatabaseRecommendAppRetrieval.fetch_recommended_app_detail_from_db(app1.id) + + assert result is None + + @patch("services.recommend_app.database.database_retrieval.AppDslService") + def test_returns_detail_on_success(self, mock_dsl, flask_app_with_containers, db_session_with_containers): + tenant_id = str(uuid4()) + app1 = _create_app(db_session_with_containers, tenant_id=tenant_id) + _create_site(db_session_with_containers, app_id=app1.id) + _create_recommended_app(db_session_with_containers, app_id=app1.id) + mock_dsl.export_dsl.return_value = "exported_yaml" + + db_session_with_containers.expire_all() + + result = DatabaseRecommendAppRetrieval.fetch_recommended_app_detail_from_db(app1.id) + + assert result is not None + assert result["id"] == app1.id + assert result["export_data"] == "exported_yaml" diff --git a/api/tests/test_containers_integration_tests/services/test_agent_service.py b/api/tests/test_containers_integration_tests/services/test_agent_service.py index b51fbc3a42..00a2f9a59f 100644 --- a/api/tests/test_containers_integration_tests/services/test_agent_service.py +++ b/api/tests/test_containers_integration_tests/services/test_agent_service.py @@ -28,7 +28,7 @@ class TestAgentService: patch("services.agent_service.current_user", create_autospec(Account, instance=True)) as mock_current_user, patch("services.app_service.FeatureService", autospec=True) as mock_feature_service, patch("services.app_service.EnterpriseService", autospec=True) as mock_enterprise_service, - patch("services.app_service.ModelManager", autospec=True) as mock_model_manager, + patch("services.app_service.ModelManager.for_tenant", autospec=True) as mock_model_manager, patch("services.account_service.FeatureService", autospec=True) as mock_account_feature_service, ): # Setup default mock returns for agent service @@ -841,7 +841,7 @@ class TestAgentService: app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) conversation, message = self._create_test_conversation_and_message(db_session_with_containers, app, account) - from dify_graph.file import FileTransferMethod, FileType + from graphon.file import FileTransferMethod, FileType from models.enums import CreatorUserRole # Add files to message diff --git a/api/tests/test_containers_integration_tests/services/test_api_based_extension_service.py b/api/tests/test_containers_integration_tests/services/test_api_based_extension_service.py index 7ce7357b41..b8e022503f 100644 --- a/api/tests/test_containers_integration_tests/services/test_api_based_extension_service.py +++ b/api/tests/test_containers_integration_tests/services/test_api_based_extension_service.py @@ -525,3 +525,147 @@ class TestAPIBasedExtensionService: # Try to get extension with wrong tenant ID with pytest.raises(ValueError, match="API based extension is not found"): APIBasedExtensionService.get_with_tenant_id(tenant2.id, created_extension.id) + + def test_save_extension_api_key_exactly_four_chars_rejected( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """API key with exactly 4 characters should be rejected (boundary).""" + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + assert tenant is not None + + extension_data = APIBasedExtension( + tenant_id=tenant.id, + name=fake.company(), + api_endpoint=f"https://{fake.domain_name()}/api", + api_key="1234", + ) + + with pytest.raises(ValueError, match="api_key must be at least 5 characters"): + APIBasedExtensionService.save(extension_data) + + def test_save_extension_api_key_exactly_five_chars_accepted( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """API key with exactly 5 characters should be accepted (boundary).""" + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + assert tenant is not None + + extension_data = APIBasedExtension( + tenant_id=tenant.id, + name=fake.company(), + api_endpoint=f"https://{fake.domain_name()}/api", + api_key="12345", + ) + + saved = APIBasedExtensionService.save(extension_data) + assert saved.id is not None + + def test_save_extension_requestor_constructor_error( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """Exception raised by requestor constructor is wrapped in ValueError.""" + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + assert tenant is not None + + mock_external_service_dependencies["requestor"].side_effect = RuntimeError("bad config") + + extension_data = APIBasedExtension( + tenant_id=tenant.id, + name=fake.company(), + api_endpoint=f"https://{fake.domain_name()}/api", + api_key=fake.password(length=20), + ) + + with pytest.raises(ValueError, match="connection error: bad config"): + APIBasedExtensionService.save(extension_data) + + def test_save_extension_network_exception( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """Network exceptions during ping are wrapped in ValueError.""" + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + assert tenant is not None + + mock_external_service_dependencies["requestor_instance"].request.side_effect = ConnectionError( + "network failure" + ) + + extension_data = APIBasedExtension( + tenant_id=tenant.id, + name=fake.company(), + api_endpoint=f"https://{fake.domain_name()}/api", + api_key=fake.password(length=20), + ) + + with pytest.raises(ValueError, match="connection error: network failure"): + APIBasedExtensionService.save(extension_data) + + def test_save_extension_update_duplicate_name_rejected( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """Updating an existing extension to use another extension's name should fail.""" + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + assert tenant is not None + + ext1 = APIBasedExtensionService.save( + APIBasedExtension( + tenant_id=tenant.id, + name="Extension Alpha", + api_endpoint=f"https://{fake.domain_name()}/api", + api_key=fake.password(length=20), + ) + ) + ext2 = APIBasedExtensionService.save( + APIBasedExtension( + tenant_id=tenant.id, + name="Extension Beta", + api_endpoint=f"https://{fake.domain_name()}/api", + api_key=fake.password(length=20), + ) + ) + + # Try to rename ext2 to ext1's name + ext2.name = "Extension Alpha" + with pytest.raises(ValueError, match="name must be unique, it is already existed"): + APIBasedExtensionService.save(ext2) + + def test_get_all_returns_empty_for_different_tenant( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """Extensions from one tenant should not be visible to another.""" + fake = Faker() + _, tenant1 = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + _, tenant2 = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + assert tenant1 is not None + + APIBasedExtensionService.save( + APIBasedExtension( + tenant_id=tenant1.id, + name=fake.company(), + api_endpoint=f"https://{fake.domain_name()}/api", + api_key=fake.password(length=20), + ) + ) + + assert tenant2 is not None + result = APIBasedExtensionService.get_all_by_tenant_id(tenant2.id) + assert result == [] diff --git a/api/tests/test_containers_integration_tests/services/test_app_dsl_service.py b/api/tests/test_containers_integration_tests/services/test_app_dsl_service.py index 8a362e1f5e..33955d5d84 100644 --- a/api/tests/test_containers_integration_tests/services/test_app_dsl_service.py +++ b/api/tests/test_containers_integration_tests/services/test_app_dsl_service.py @@ -26,7 +26,7 @@ class TestAppDslService: patch("services.app_dsl_service.redis_client") as mock_redis_client, patch("services.app_dsl_service.app_was_created") as mock_app_was_created, patch("services.app_dsl_service.app_model_config_was_updated") as mock_app_model_config_was_updated, - patch("services.app_service.ModelManager") as mock_model_manager, + patch("services.app_service.ModelManager.for_tenant") as mock_model_manager, patch("services.app_service.FeatureService") as mock_feature_service, patch("services.app_service.EnterpriseService") as mock_enterprise_service, ): diff --git a/api/tests/test_containers_integration_tests/services/test_app_service.py b/api/tests/test_containers_integration_tests/services/test_app_service.py index d79f80c009..fa57dd4a6f 100644 --- a/api/tests/test_containers_integration_tests/services/test_app_service.py +++ b/api/tests/test_containers_integration_tests/services/test_app_service.py @@ -6,7 +6,7 @@ from sqlalchemy.orm import Session from constants.model_template import default_app_templates from models import Account -from models.model import App, Site +from models.model import App, IconType, Site from services.account_service import AccountService, TenantService from tests.test_containers_integration_tests.helpers import generate_valid_password @@ -23,7 +23,7 @@ class TestAppService: with ( patch("services.app_service.FeatureService") as mock_feature_service, patch("services.app_service.EnterpriseService") as mock_enterprise_service, - patch("services.app_service.ModelManager") as mock_model_manager, + patch("services.app_service.ModelManager.for_tenant") as mock_model_manager, patch("services.account_service.FeatureService") as mock_account_feature_service, ): # Setup default mock returns for app service @@ -463,6 +463,109 @@ class TestAppService: assert updated_app.tenant_id == app.tenant_id assert updated_app.created_by == app.created_by + def test_update_app_should_preserve_icon_type_when_omitted( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """ + Test update_app keeps the persisted icon_type when the update payload omits it. + """ + fake = Faker() + + account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=generate_valid_password(fake), + ) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + tenant = account.current_tenant + + from services.app_service import AppService + + app_service = AppService() + app = app_service.create_app( + tenant.id, + { + "name": fake.company(), + "description": fake.text(max_nb_chars=100), + "mode": "chat", + "icon_type": "emoji", + "icon": "🎯", + "icon_background": "#45B7D1", + }, + account, + ) + + mock_current_user = create_autospec(Account, instance=True) + mock_current_user.id = account.id + mock_current_user.current_tenant_id = account.current_tenant_id + + with patch("services.app_service.current_user", mock_current_user): + updated_app = app_service.update_app( + app, + { + "name": "Updated App Name", + "description": "Updated app description", + "icon_type": None, + "icon": "🔄", + "icon_background": "#FF8C42", + "use_icon_as_answer_icon": True, + }, + ) + + assert updated_app.icon_type == IconType.EMOJI + + def test_update_app_should_reject_empty_icon_type( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """ + Test update_app rejects an explicit empty icon_type. + """ + fake = Faker() + + account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=generate_valid_password(fake), + ) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + tenant = account.current_tenant + + from services.app_service import AppService + + app_service = AppService() + app = app_service.create_app( + tenant.id, + { + "name": fake.company(), + "description": fake.text(max_nb_chars=100), + "mode": "chat", + "icon_type": "emoji", + "icon": "🎯", + "icon_background": "#45B7D1", + }, + account, + ) + + mock_current_user = create_autospec(Account, instance=True) + mock_current_user.id = account.id + mock_current_user.current_tenant_id = account.current_tenant_id + + with patch("services.app_service.current_user", mock_current_user): + with pytest.raises(ValueError): + app_service.update_app( + app, + { + "name": "Updated App Name", + "description": "Updated app description", + "icon_type": "", + "icon": "🔄", + "icon_background": "#FF8C42", + "use_icon_as_answer_icon": True, + }, + ) + def test_update_app_name_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful app name update. @@ -1142,3 +1245,51 @@ class TestAppService: assert paginated_apps is not None assert paginated_apps.total == 1 assert all("50%" in app.name for app in paginated_apps.items) + + def test_get_app_code_by_id_not_found( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """Test get_app_code_by_id raises ValueError when site is missing.""" + from uuid import uuid4 + + from services.app_service import AppService + + with pytest.raises(ValueError, match="not found"): + AppService.get_app_code_by_id(str(uuid4())) + + def test_get_app_id_by_code_not_found( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """Test get_app_id_by_code raises ValueError when code does not exist.""" + from services.app_service import AppService + + with pytest.raises(ValueError, match="not found"): + AppService.get_app_id_by_code("nonexistent-code") + + def test_get_app_meta_returns_empty_when_workflow_missing( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """Test get_app_meta returns empty tool_icons when workflow is None.""" + from types import SimpleNamespace + + from services.app_service import AppService + + app_service = AppService() + workflow_app = SimpleNamespace(mode="workflow", workflow=None) + + meta = app_service.get_app_meta(workflow_app) + assert meta == {"tool_icons": {}} + + def test_get_app_meta_returns_empty_when_model_config_missing( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """Test get_app_meta returns empty tool_icons when app_model_config is None.""" + from types import SimpleNamespace + + from services.app_service import AppService + + app_service = AppService() + chat_app = SimpleNamespace(mode="chat", app_model_config=None) + + meta = app_service.get_app_meta(chat_app) + assert meta == {"tool_icons": {}} diff --git a/api/tests/test_containers_integration_tests/services/test_conversation_variable_updater.py b/api/tests/test_containers_integration_tests/services/test_conversation_variable_updater.py index 42a2215896..02ab3f8314 100644 --- a/api/tests/test_containers_integration_tests/services/test_conversation_variable_updater.py +++ b/api/tests/test_containers_integration_tests/services/test_conversation_variable_updater.py @@ -5,8 +5,8 @@ from uuid import uuid4 import pytest from sqlalchemy.orm import sessionmaker -from dify_graph.variables import StringVariable from extensions.ext_database import db +from graphon.variables import StringVariable from models.workflow import ConversationVariable from services.conversation_variable_updater import ConversationVariableNotFoundError, ConversationVariableUpdater diff --git a/api/tests/test_containers_integration_tests/services/test_credit_pool_service.py b/api/tests/test_containers_integration_tests/services/test_credit_pool_service.py index 25de0588fa..0f63d98642 100644 --- a/api/tests/test_containers_integration_tests/services/test_credit_pool_service.py +++ b/api/tests/test_containers_integration_tests/services/test_credit_pool_service.py @@ -6,6 +6,7 @@ import pytest from core.errors.error import QuotaExceededError from models import TenantCreditPool +from models.enums import ProviderQuotaType from services.credit_pool_service import CreditPoolService @@ -20,7 +21,7 @@ class TestCreditPoolService: assert isinstance(pool, TenantCreditPool) assert pool.tenant_id == tenant_id - assert pool.pool_type == "trial" + assert pool.pool_type == ProviderQuotaType.TRIAL assert pool.quota_used == 0 assert pool.quota_limit > 0 @@ -28,14 +29,14 @@ class TestCreditPoolService: tenant_id = self._create_tenant_id() CreditPoolService.create_default_pool(tenant_id) - result = CreditPoolService.get_pool(tenant_id=tenant_id, pool_type="trial") + result = CreditPoolService.get_pool(tenant_id=tenant_id, pool_type=ProviderQuotaType.TRIAL) assert result is not None assert result.tenant_id == tenant_id - assert result.pool_type == "trial" + assert result.pool_type == ProviderQuotaType.TRIAL def test_get_pool_returns_none_when_not_exists(self, db_session_with_containers): - result = CreditPoolService.get_pool(tenant_id=self._create_tenant_id(), pool_type="trial") + result = CreditPoolService.get_pool(tenant_id=self._create_tenant_id(), pool_type=ProviderQuotaType.TRIAL) assert result is None diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_permission_service.py b/api/tests/test_containers_integration_tests/services/test_dataset_permission_service.py index 55bfb64e18..71c8874f79 100644 --- a/api/tests/test_containers_integration_tests/services/test_dataset_permission_service.py +++ b/api/tests/test_containers_integration_tests/services/test_dataset_permission_service.py @@ -9,6 +9,7 @@ from uuid import uuid4 import pytest +from core.rag.index_processor.constant.index_type import IndexTechniqueType from extensions.ext_database import db from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import ( @@ -69,7 +70,7 @@ class DatasetPermissionTestDataFactory: name=name, description="desc", data_source_type=DataSourceType.UPLOAD_FILE, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, created_by=created_by, permission=permission, provider="vendor", diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service.py b/api/tests/test_containers_integration_tests/services/test_dataset_service.py index ac3d9f9604..0de3c64c4f 100644 --- a/api/tests/test_containers_integration_tests/services/test_dataset_service.py +++ b/api/tests/test_containers_integration_tests/services/test_dataset_service.py @@ -11,8 +11,9 @@ from uuid import uuid4 import pytest from sqlalchemy.orm import Session +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from core.rag.retrieval.retrieval_methods import RetrievalMethod -from dify_graph.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.entities.model_entities import ModelType from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, DatasetPermissionEnum, Document, ExternalKnowledgeBindings, Pipeline from models.enums import DatasetRuntimeMode, DataSourceType, DocumentCreatedFrom, IndexingStatus @@ -62,7 +63,7 @@ class DatasetServiceIntegrationDataFactory: name: str = "Test Dataset", description: str | None = "Test description", provider: str = "vendor", - indexing_technique: str | None = "high_quality", + indexing_technique: str | None = IndexTechniqueType.HIGH_QUALITY, permission: str = DatasetPermissionEnum.ONLY_ME, retrieval_model: dict | None = None, embedding_model_provider: str | None = None, @@ -106,7 +107,7 @@ class DatasetServiceIntegrationDataFactory: created_from=DocumentCreatedFrom.WEB, created_by=created_by, indexing_status=IndexingStatus.COMPLETED, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, ) db_session_with_containers.add(document) db_session_with_containers.flush() @@ -156,13 +157,13 @@ class TestDatasetServiceCreateDataset: tenant_id=tenant.id, name="Economy Dataset", description=None, - indexing_technique="economy", + indexing_technique=IndexTechniqueType.ECONOMY, account=account, ) # Assert db_session_with_containers.refresh(result) - assert result.indexing_technique == "economy" + assert result.indexing_technique == IndexTechniqueType.ECONOMY assert result.embedding_model_provider is None assert result.embedding_model is None @@ -173,20 +174,20 @@ class TestDatasetServiceCreateDataset: embedding_model = DatasetServiceIntegrationDataFactory.create_embedding_model() # Act - with patch("services.dataset_service.ModelManager") as mock_model_manager: + with patch("services.dataset_service.ModelManager.for_tenant") as mock_model_manager: mock_model_manager.return_value.get_default_model_instance.return_value = embedding_model result = DatasetService.create_empty_dataset( tenant_id=tenant.id, name="High Quality Dataset", description=None, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, account=account, ) # Assert db_session_with_containers.refresh(result) - assert result.indexing_technique == "high_quality" + assert result.indexing_technique == IndexTechniqueType.HIGH_QUALITY assert result.embedding_model_provider == embedding_model.provider assert result.embedding_model == embedding_model.model_name mock_model_manager.return_value.get_default_model_instance.assert_called_once_with( @@ -263,7 +264,7 @@ class TestDatasetServiceCreateDataset: # Act with ( - patch("services.dataset_service.ModelManager") as mock_model_manager, + patch("services.dataset_service.ModelManager.for_tenant") as mock_model_manager, patch("services.dataset_service.DatasetService.check_reranking_model_setting") as mock_check_reranking, ): mock_model_manager.return_value.get_default_model_instance.return_value = embedding_model @@ -272,7 +273,7 @@ class TestDatasetServiceCreateDataset: tenant_id=tenant.id, name="Dataset With Reranking", description=None, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, account=account, retrieval_model=retrieval_model, ) @@ -296,7 +297,7 @@ class TestDatasetServiceCreateDataset: # Act with ( - patch("services.dataset_service.ModelManager") as mock_model_manager, + patch("services.dataset_service.ModelManager.for_tenant") as mock_model_manager, patch("services.dataset_service.DatasetService.check_embedding_model_setting") as mock_check_embedding, ): mock_model_manager.return_value.get_model_instance.return_value = embedding_model @@ -305,7 +306,7 @@ class TestDatasetServiceCreateDataset: tenant_id=tenant.id, name="Custom Embedding Dataset", description=None, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, account=account, embedding_model_provider=embedding_provider, embedding_model_name=embedding_model_name, @@ -313,7 +314,7 @@ class TestDatasetServiceCreateDataset: # Assert db_session_with_containers.refresh(result) - assert result.indexing_technique == "high_quality" + assert result.indexing_technique == IndexTechniqueType.HIGH_QUALITY assert result.embedding_model_provider == embedding_provider assert result.embedding_model == embedding_model_name mock_check_embedding.assert_called_once_with(tenant.id, embedding_provider, embedding_model_name) @@ -588,7 +589,7 @@ class TestDatasetServiceUpdateAndDeleteDataset: db_session_with_containers, tenant_id=tenant.id, created_by=account.id, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, chunk_structure="text_model", ) DatasetServiceIntegrationDataFactory.create_document( @@ -684,14 +685,14 @@ class TestDatasetServiceRetrievalConfiguration: db_session_with_containers, tenant_id=tenant.id, created_by=account.id, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, retrieval_model={"search_method": "semantic_search", "top_k": 2, "score_threshold": 0.0}, embedding_model_provider="openai", embedding_model="text-embedding-ada-002", collection_binding_id=str(uuid4()), ) update_data = { - "indexing_technique": "high_quality", + "indexing_technique": IndexTechniqueType.HIGH_QUALITY, "retrieval_model": { "search_method": "full_text_search", "top_k": 10, @@ -706,3 +707,104 @@ class TestDatasetServiceRetrievalConfiguration: db_session_with_containers.refresh(dataset) assert result.id == dataset.id assert dataset.retrieval_model == update_data["retrieval_model"] + + +class TestDocumentServicePauseRecoverRetry: + """Tests for pause/recover/retry orchestration using real DB and Redis.""" + + def _create_indexing_document(self, db_session_with_containers, indexing_status="indexing"): + factory = DatasetServiceIntegrationDataFactory + account, tenant = factory.create_account_with_tenant(db_session_with_containers) + dataset = factory.create_dataset(db_session_with_containers, tenant.id, account.id) + doc = factory.create_document(db_session_with_containers, dataset, account.id) + doc.indexing_status = indexing_status + db_session_with_containers.commit() + return doc, account + + def test_pause_document_success(self, db_session_with_containers): + from extensions.ext_redis import redis_client + from services.dataset_service import DocumentService + + doc, account = self._create_indexing_document(db_session_with_containers, indexing_status="indexing") + + with patch("services.dataset_service.current_user") as mock_user: + mock_user.id = account.id + DocumentService.pause_document(doc) + + db_session_with_containers.refresh(doc) + assert doc.is_paused is True + assert doc.paused_by == account.id + assert doc.paused_at is not None + + cache_key = f"document_{doc.id}_is_paused" + assert redis_client.get(cache_key) is not None + redis_client.delete(cache_key) + + def test_pause_document_invalid_status_error(self, db_session_with_containers): + from services.dataset_service import DocumentService + from services.errors.document import DocumentIndexingError + + doc, account = self._create_indexing_document(db_session_with_containers, indexing_status="completed") + + with patch("services.dataset_service.current_user") as mock_user: + mock_user.id = account.id + with pytest.raises(DocumentIndexingError): + DocumentService.pause_document(doc) + + def test_recover_document_success(self, db_session_with_containers): + from extensions.ext_redis import redis_client + from services.dataset_service import DocumentService + + doc, account = self._create_indexing_document(db_session_with_containers, indexing_status="indexing") + + # Pause first + with patch("services.dataset_service.current_user") as mock_user: + mock_user.id = account.id + DocumentService.pause_document(doc) + + # Recover + with patch("services.dataset_service.recover_document_indexing_task") as recover_task: + DocumentService.recover_document(doc) + + db_session_with_containers.refresh(doc) + assert doc.is_paused is False + assert doc.paused_by is None + assert doc.paused_at is None + + cache_key = f"document_{doc.id}_is_paused" + assert redis_client.get(cache_key) is None + recover_task.delay.assert_called_once_with(doc.dataset_id, doc.id) + + def test_retry_document_indexing_success(self, db_session_with_containers): + from extensions.ext_redis import redis_client + from services.dataset_service import DocumentService + + factory = DatasetServiceIntegrationDataFactory + account, tenant = factory.create_account_with_tenant(db_session_with_containers) + dataset = factory.create_dataset(db_session_with_containers, tenant.id, account.id) + doc1 = factory.create_document(db_session_with_containers, dataset, account.id, name="doc1.txt") + doc2 = factory.create_document(db_session_with_containers, dataset, account.id, name="doc2.txt") + doc2.position = 2 + doc1.indexing_status = "error" + doc2.indexing_status = "error" + db_session_with_containers.commit() + + with ( + patch("services.dataset_service.current_user") as mock_user, + patch("services.dataset_service.retry_document_indexing_task") as retry_task, + ): + mock_user.id = account.id + DocumentService.retry_document(dataset.id, [doc1, doc2]) + + db_session_with_containers.refresh(doc1) + db_session_with_containers.refresh(doc2) + assert doc1.indexing_status == "waiting" + assert doc2.indexing_status == "waiting" + + # Verify redis keys were set + assert redis_client.get(f"document_{doc1.id}_is_retried") is not None + assert redis_client.get(f"document_{doc2.id}_is_retried") is not None + retry_task.delay.assert_called_once_with(dataset.id, [doc1.id, doc2.id], account.id) + + # Cleanup + redis_client.delete(f"document_{doc1.id}_is_retried", f"document_{doc2.id}_is_retried") diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service_batch_update_document_status.py b/api/tests/test_containers_integration_tests/services/test_dataset_service_batch_update_document_status.py index 7983b1cd93..c1d088755c 100644 --- a/api/tests/test_containers_integration_tests/services/test_dataset_service_batch_update_document_status.py +++ b/api/tests/test_containers_integration_tests/services/test_dataset_service_batch_update_document_status.py @@ -13,6 +13,7 @@ from uuid import uuid4 import pytest from sqlalchemy.orm import Session +from core.rag.index_processor.constant.index_type import IndexStructureType from models.dataset import Dataset, Document from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus from services.dataset_service import DocumentService @@ -79,7 +80,7 @@ class DocumentBatchUpdateIntegrationDataFactory: name=name, created_from=DocumentCreatedFrom.WEB, created_by=created_by or str(uuid4()), - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, ) document.id = document_id or str(uuid4()) document.enabled = enabled @@ -694,3 +695,19 @@ class TestDatasetServiceBatchUpdateDocumentStatus: patched_dependencies["redis_client"].setex.assert_called_once_with(f"document_{doc1.id}_indexing", 600, 1) patched_dependencies["add_task"].delay.assert_called_once_with(doc1.id) + + def test_batch_update_invalid_action_raises_value_error( + self, db_session_with_containers: Session, patched_dependencies + ): + """Test that an invalid action raises ValueError.""" + factory = DocumentBatchUpdateIntegrationDataFactory + dataset = factory.create_dataset(db_session_with_containers) + doc = factory.create_document(db_session_with_containers, dataset) + user = UserDouble(id=str(uuid4())) + + patched_dependencies["redis_client"].get.return_value = None + + with pytest.raises(ValueError, match="Invalid action"): + DocumentService.batch_update_document_status( + dataset=dataset, document_ids=[doc.id], action="invalid_action", user=user + ) diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service_create_dataset.py b/api/tests/test_containers_integration_tests/services/test_dataset_service_create_dataset.py new file mode 100644 index 0000000000..c486ff5613 --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/test_dataset_service_create_dataset.py @@ -0,0 +1,60 @@ +"""Testcontainers integration tests for DatasetService.create_empty_rag_pipeline_dataset.""" + +from __future__ import annotations + +from unittest.mock import Mock, patch +from uuid import uuid4 + +import pytest + +from models.account import Account, Tenant, TenantAccountJoin +from services.dataset_service import DatasetService +from services.entities.knowledge_entities.rag_pipeline_entities import IconInfo, RagPipelineDatasetCreateEntity + + +class TestDatasetServiceCreateRagPipelineDataset: + def _create_tenant_and_account(self, db_session_with_containers) -> tuple[Tenant, Account]: + tenant = Tenant(name=f"Tenant {uuid4()}") + db_session_with_containers.add(tenant) + db_session_with_containers.flush() + + account = Account( + name=f"Account {uuid4()}", + email=f"ds_create_{uuid4()}@example.com", + password="hashed", + password_salt="salt", + interface_language="en-US", + timezone="UTC", + ) + db_session_with_containers.add(account) + db_session_with_containers.flush() + + join = TenantAccountJoin( + tenant_id=tenant.id, + account_id=account.id, + role="owner", + current=True, + ) + db_session_with_containers.add(join) + db_session_with_containers.commit() + return tenant, account + + def _build_entity(self, name: str = "Test Dataset") -> RagPipelineDatasetCreateEntity: + icon_info = IconInfo(icon="\U0001f4d9", icon_background="#FFF4ED", icon_type="emoji") + return RagPipelineDatasetCreateEntity( + name=name, + description="", + icon_info=icon_info, + permission="only_me", + ) + + def test_create_rag_pipeline_dataset_raises_when_current_user_id_is_none(self, db_session_with_containers): + tenant, _ = self._create_tenant_and_account(db_session_with_containers) + + mock_user = Mock(id=None) + with patch("services.dataset_service.current_user", mock_user): + with pytest.raises(ValueError, match="Current user or current user id not found"): + DatasetService.create_empty_rag_pipeline_dataset( + tenant_id=tenant.id, + rag_pipeline_dataset_create_entity=self._build_entity(), + ) diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service_delete_dataset.py b/api/tests/test_containers_integration_tests/services/test_dataset_service_delete_dataset.py index ed070527c9..3cac964d89 100644 --- a/api/tests/test_containers_integration_tests/services/test_dataset_service_delete_dataset.py +++ b/api/tests/test_containers_integration_tests/services/test_dataset_service_delete_dataset.py @@ -3,6 +3,7 @@ from unittest.mock import patch from uuid import uuid4 +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, Document from models.enums import DataSourceType, DocumentCreatedFrom @@ -78,7 +79,7 @@ class DatasetDeleteIntegrationDataFactory: tenant_id: str, dataset_id: str, created_by: str, - doc_form: str = "text_model", + doc_form: str = IndexStructureType.PARAGRAPH_INDEX, ) -> Document: """Persist a document so dataset.doc_form resolves through the real document path.""" document = Document( @@ -108,7 +109,7 @@ class TestDatasetServiceDeleteDataset: db_session_with_containers, tenant_id=tenant.id, created_by=owner.id, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, chunk_structure=None, index_struct='{"type": "paragraph"}', collection_binding_id=str(uuid4()), @@ -119,7 +120,7 @@ class TestDatasetServiceDeleteDataset: tenant_id=tenant.id, dataset_id=dataset.id, created_by=owner.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, ) # Act @@ -207,7 +208,7 @@ class TestDatasetServiceDeleteDataset: db_session_with_containers, tenant_id=tenant.id, created_by=owner.id, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, chunk_structure=None, index_struct='{"type": "paragraph"}', collection_binding_id=str(uuid4()), diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service_get_segments.py b/api/tests/test_containers_integration_tests/services/test_dataset_service_get_segments.py index c4b3a57bb2..87239b2cb3 100644 --- a/api/tests/test_containers_integration_tests/services/test_dataset_service_get_segments.py +++ b/api/tests/test_containers_integration_tests/services/test_dataset_service_get_segments.py @@ -12,6 +12,7 @@ from uuid import uuid4 from sqlalchemy.orm import Session +from core.rag.index_processor.constant.index_type import IndexTechniqueType from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, DatasetPermissionEnum, Document, DocumentSegment from models.enums import DataSourceType, DocumentCreatedFrom @@ -64,7 +65,7 @@ class SegmentServiceTestDataFactory: name=f"Test Dataset {uuid4()}", description="Test description", data_source_type=DataSourceType.UPLOAD_FILE, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, created_by=created_by, permission=DatasetPermissionEnum.ONLY_ME, provider="vendor", diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service_retrieval.py b/api/tests/test_containers_integration_tests/services/test_dataset_service_retrieval.py index 3021d8984d..2f90d16176 100644 --- a/api/tests/test_containers_integration_tests/services/test_dataset_service_retrieval.py +++ b/api/tests/test_containers_integration_tests/services/test_dataset_service_retrieval.py @@ -15,6 +15,7 @@ from uuid import uuid4 from sqlalchemy.orm import Session +from core.rag.index_processor.constant.index_type import IndexTechniqueType from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import ( AppDatasetJoin, @@ -102,7 +103,7 @@ class DatasetRetrievalTestDataFactory: name=name, description="desc", data_source_type=DataSourceType.UPLOAD_FILE, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, created_by=created_by, permission=permission, provider="vendor", diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service_update_dataset.py b/api/tests/test_containers_integration_tests/services/test_dataset_service_update_dataset.py index fd81948247..883c3c3feb 100644 --- a/api/tests/test_containers_integration_tests/services/test_dataset_service_update_dataset.py +++ b/api/tests/test_containers_integration_tests/services/test_dataset_service_update_dataset.py @@ -4,7 +4,8 @@ from uuid import uuid4 import pytest from sqlalchemy.orm import Session -from dify_graph.model_runtime.entities.model_entities import ModelType +from core.rag.index_processor.constant.index_type import IndexTechniqueType +from graphon.model_runtime.entities.model_entities import ModelType from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, ExternalKnowledgeBindings from models.enums import DataSourceType @@ -53,7 +54,7 @@ class DatasetUpdateTestDataFactory: provider: str = "vendor", name: str = "old_name", description: str = "old_description", - indexing_technique: str = "high_quality", + indexing_technique: str = IndexTechniqueType.HIGH_QUALITY, retrieval_model: str = "old_model", permission: str = "only_me", embedding_model_provider: str | None = None, @@ -241,7 +242,7 @@ class TestDatasetServiceUpdateDataset: tenant_id=tenant.id, created_by=user.id, provider="vendor", - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, embedding_model_provider="openai", embedding_model="text-embedding-ada-002", collection_binding_id=existing_binding_id, @@ -250,7 +251,7 @@ class TestDatasetServiceUpdateDataset: update_data = { "name": "new_name", "description": "new_description", - "indexing_technique": "high_quality", + "indexing_technique": IndexTechniqueType.HIGH_QUALITY, "retrieval_model": "new_model", "embedding_model_provider": "openai", "embedding_model": "text-embedding-ada-002", @@ -261,7 +262,7 @@ class TestDatasetServiceUpdateDataset: assert dataset.name == "new_name" assert dataset.description == "new_description" - assert dataset.indexing_technique == "high_quality" + assert dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY assert dataset.retrieval_model == "new_model" assert dataset.embedding_model_provider == "openai" assert dataset.embedding_model == "text-embedding-ada-002" @@ -276,7 +277,7 @@ class TestDatasetServiceUpdateDataset: tenant_id=tenant.id, created_by=user.id, provider="vendor", - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, embedding_model_provider="openai", embedding_model="text-embedding-ada-002", collection_binding_id=existing_binding_id, @@ -285,7 +286,7 @@ class TestDatasetServiceUpdateDataset: update_data = { "name": "new_name", "description": None, - "indexing_technique": "high_quality", + "indexing_technique": IndexTechniqueType.HIGH_QUALITY, "retrieval_model": "new_model", "embedding_model_provider": None, "embedding_model": None, @@ -312,14 +313,14 @@ class TestDatasetServiceUpdateDataset: tenant_id=tenant.id, created_by=user.id, provider="vendor", - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, embedding_model_provider="openai", embedding_model="text-embedding-ada-002", collection_binding_id=existing_binding_id, ) update_data = { - "indexing_technique": "economy", + "indexing_technique": IndexTechniqueType.ECONOMY, "retrieval_model": "new_model", } @@ -328,7 +329,7 @@ class TestDatasetServiceUpdateDataset: mock_task.delay.assert_called_once_with(dataset.id, "remove") db_session_with_containers.refresh(dataset) - assert dataset.indexing_technique == "economy" + assert dataset.indexing_technique == IndexTechniqueType.ECONOMY assert dataset.embedding_model is None assert dataset.embedding_model_provider is None assert dataset.collection_binding_id is None @@ -343,7 +344,7 @@ class TestDatasetServiceUpdateDataset: tenant_id=tenant.id, created_by=user.id, provider="vendor", - indexing_technique="economy", + indexing_technique=IndexTechniqueType.ECONOMY, ) embedding_model = Mock() @@ -354,7 +355,7 @@ class TestDatasetServiceUpdateDataset: binding.id = str(uuid4()) update_data = { - "indexing_technique": "high_quality", + "indexing_technique": IndexTechniqueType.HIGH_QUALITY, "embedding_model_provider": "openai", "embedding_model": "text-embedding-ada-002", "retrieval_model": "new_model", @@ -362,7 +363,7 @@ class TestDatasetServiceUpdateDataset: with ( patch("services.dataset_service.current_user", user), - patch("services.dataset_service.ModelManager") as mock_model_manager, + patch("services.dataset_service.ModelManager.for_tenant") as mock_model_manager, patch( "services.dataset_service.DatasetCollectionBindingService.get_dataset_collection_binding" ) as mock_get_binding, @@ -383,7 +384,7 @@ class TestDatasetServiceUpdateDataset: mock_task.delay.assert_called_once_with(dataset.id, "add") db_session_with_containers.refresh(dataset) - assert dataset.indexing_technique == "high_quality" + assert dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY assert dataset.embedding_model == "text-embedding-ada-002" assert dataset.embedding_model_provider == "openai" assert dataset.collection_binding_id == binding.id @@ -403,7 +404,7 @@ class TestDatasetServiceUpdateDataset: tenant_id=tenant.id, created_by=user.id, provider="vendor", - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, embedding_model_provider="openai", embedding_model="text-embedding-ada-002", collection_binding_id=existing_binding_id, @@ -411,7 +412,7 @@ class TestDatasetServiceUpdateDataset: update_data = { "name": "new_name", - "indexing_technique": "high_quality", + "indexing_technique": IndexTechniqueType.HIGH_QUALITY, "retrieval_model": "new_model", } @@ -419,7 +420,7 @@ class TestDatasetServiceUpdateDataset: db_session_with_containers.refresh(dataset) assert dataset.name == "new_name" - assert dataset.indexing_technique == "high_quality" + assert dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY assert dataset.embedding_model_provider == "openai" assert dataset.embedding_model == "text-embedding-ada-002" assert dataset.collection_binding_id == existing_binding_id @@ -435,7 +436,7 @@ class TestDatasetServiceUpdateDataset: tenant_id=tenant.id, created_by=user.id, provider="vendor", - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, embedding_model_provider="openai", embedding_model="text-embedding-ada-002", collection_binding_id=existing_binding_id, @@ -449,7 +450,7 @@ class TestDatasetServiceUpdateDataset: binding.id = str(uuid4()) update_data = { - "indexing_technique": "high_quality", + "indexing_technique": IndexTechniqueType.HIGH_QUALITY, "embedding_model_provider": "openai", "embedding_model": "text-embedding-3-small", "retrieval_model": "new_model", @@ -457,7 +458,7 @@ class TestDatasetServiceUpdateDataset: with ( patch("services.dataset_service.current_user", user), - patch("services.dataset_service.ModelManager") as mock_model_manager, + patch("services.dataset_service.ModelManager.for_tenant") as mock_model_manager, patch( "services.dataset_service.DatasetCollectionBindingService.get_dataset_collection_binding" ) as mock_get_binding, @@ -531,11 +532,11 @@ class TestDatasetServiceUpdateDataset: tenant_id=tenant.id, created_by=user.id, provider="vendor", - indexing_technique="economy", + indexing_technique=IndexTechniqueType.ECONOMY, ) update_data = { - "indexing_technique": "high_quality", + "indexing_technique": IndexTechniqueType.HIGH_QUALITY, "embedding_model_provider": "invalid_provider", "embedding_model": "invalid_model", "retrieval_model": "new_model", @@ -543,7 +544,7 @@ class TestDatasetServiceUpdateDataset: with ( patch("services.dataset_service.current_user", user), - patch("services.dataset_service.ModelManager") as mock_model_manager, + patch("services.dataset_service.ModelManager.for_tenant") as mock_model_manager, ): mock_model_manager.return_value.get_model_instance.side_effect = Exception("No Embedding Model available") diff --git a/api/tests/test_containers_integration_tests/services/test_delete_archived_workflow_run.py b/api/tests/test_containers_integration_tests/services/test_delete_archived_workflow_run.py index 5f86cb2ae9..fe426ae516 100644 --- a/api/tests/test_containers_integration_tests/services/test_delete_archived_workflow_run.py +++ b/api/tests/test_containers_integration_tests/services/test_delete_archived_workflow_run.py @@ -7,7 +7,7 @@ from uuid import uuid4 from sqlalchemy import select -from dify_graph.enums import WorkflowExecutionStatus +from graphon.enums import WorkflowExecutionStatus from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom from models.workflow import WorkflowArchiveLog, WorkflowRun from services.retention.workflow_run.delete_archived_workflow_run import ArchivedWorkflowRunDeletion @@ -141,3 +141,73 @@ class TestArchivedWorkflowRunDeletion: db_session_with_containers.expunge_all() deleted_run = db_session_with_containers.get(WorkflowRun, run_id) assert deleted_run is None + + def test_delete_run_dry_run(self, db_session_with_containers): + """Dry run should return success without actually deleting.""" + tenant_id = str(uuid4()) + run = self._create_workflow_run( + db_session_with_containers, + tenant_id=tenant_id, + created_at=datetime.now(UTC), + ) + run_id = run.id + deleter = ArchivedWorkflowRunDeletion(dry_run=True) + + result = deleter._delete_run(run) + + assert result.success is True + assert result.run_id == run_id + # Run should still exist because it's a dry run + db_session_with_containers.expire_all() + assert db_session_with_containers.get(WorkflowRun, run_id) is not None + + def test_delete_run_exception_returns_error(self, db_session_with_containers): + """Exception during deletion should return failure result.""" + from unittest.mock import MagicMock, patch + + tenant_id = str(uuid4()) + run = self._create_workflow_run( + db_session_with_containers, + tenant_id=tenant_id, + created_at=datetime.now(UTC), + ) + deleter = ArchivedWorkflowRunDeletion(dry_run=False) + + with patch.object(deleter, "_get_workflow_run_repo") as mock_get_repo: + mock_repo = MagicMock() + mock_get_repo.return_value = mock_repo + mock_repo.delete_runs_with_related.side_effect = Exception("Database error") + + result = deleter._delete_run(run) + + assert result.success is False + assert result.error == "Database error" + + def test_delete_by_run_id_success(self, db_session_with_containers): + """Successfully delete an archived workflow run by ID.""" + tenant_id = str(uuid4()) + base_time = datetime.now(UTC) + run = self._create_workflow_run( + db_session_with_containers, + tenant_id=tenant_id, + created_at=base_time, + ) + self._create_archive_log(db_session_with_containers, run=run) + run_id = run.id + + deleter = ArchivedWorkflowRunDeletion() + result = deleter.delete_by_run_id(run_id) + + assert result.success is True + db_session_with_containers.expunge_all() + assert db_session_with_containers.get(WorkflowRun, run_id) is None + + def test_get_workflow_run_repo_caches_instance(self, db_session_with_containers): + """_get_workflow_run_repo should return a cached repo on subsequent calls.""" + deleter = ArchivedWorkflowRunDeletion() + + repo1 = deleter._get_workflow_run_repo() + repo2 = deleter._get_workflow_run_repo() + + assert repo1 is repo2 + assert deleter.workflow_run_repo is repo1 diff --git a/api/tests/test_containers_integration_tests/services/test_document_service_display_status.py b/api/tests/test_containers_integration_tests/services/test_document_service_display_status.py index c6aa89c733..c0047df810 100644 --- a/api/tests/test_containers_integration_tests/services/test_document_service_display_status.py +++ b/api/tests/test_containers_integration_tests/services/test_document_service_display_status.py @@ -3,6 +3,7 @@ from uuid import uuid4 from sqlalchemy import select +from core.rag.index_processor.constant.index_type import IndexStructureType from models.dataset import Dataset, Document from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus from services.dataset_service import DocumentService @@ -42,7 +43,7 @@ def _create_document( name=f"doc-{uuid4()}", created_from=DocumentCreatedFrom.WEB, created_by=str(uuid4()), - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, ) document.id = str(uuid4()) document.indexing_status = indexing_status @@ -142,3 +143,11 @@ def test_apply_display_status_filter_returns_same_when_invalid(db_session_with_c rows = db_session_with_containers.scalars(filtered).all() assert {row.id for row in rows} == {doc1.id, doc2.id} + + +def test_normalize_display_status_alias_mapping(): + """Test that normalize_display_status maps aliases correctly.""" + assert DocumentService.normalize_display_status("ACTIVE") == "available" + assert DocumentService.normalize_display_status("enabled") == "available" + assert DocumentService.normalize_display_status("archived") == "archived" + assert DocumentService.normalize_display_status("unknown") is None diff --git a/api/tests/test_containers_integration_tests/services/test_document_service_rename_document.py b/api/tests/test_containers_integration_tests/services/test_document_service_rename_document.py index bffa520ce6..34532ed7f8 100644 --- a/api/tests/test_containers_integration_tests/services/test_document_service_rename_document.py +++ b/api/tests/test_containers_integration_tests/services/test_document_service_rename_document.py @@ -7,6 +7,7 @@ from uuid import uuid4 import pytest +from core.rag.index_processor.constant.index_type import IndexStructureType from extensions.storage.storage_type import StorageType from models import Account from models.dataset import Dataset, Document @@ -69,7 +70,7 @@ def make_document( name=name, created_from=DocumentCreatedFrom.WEB, created_by=str(uuid4()), - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, ) doc.id = document_id doc.indexing_status = "completed" diff --git a/api/tests/test_containers_integration_tests/services/test_end_user_service.py b/api/tests/test_containers_integration_tests/services/test_end_user_service.py index ae811db768..cafabc939b 100644 --- a/api/tests/test_containers_integration_tests/services/test_end_user_service.py +++ b/api/tests/test_containers_integration_tests/services/test_end_user_service.py @@ -414,3 +414,144 @@ class TestEndUserServiceGetEndUserById: ) assert result is None + + +class TestEndUserServiceCreateBatch: + """Integration tests for EndUserService.create_end_user_batch.""" + + @pytest.fixture + def factory(self): + return TestEndUserServiceFactory() + + def _create_multiple_apps(self, db_session_with_containers, factory, count: int = 3): + """Create multiple apps under the same tenant.""" + first_app = factory.create_app_and_account(db_session_with_containers) + tenant_id = first_app.tenant_id + apps = [first_app] + for _ in range(count - 1): + app = App( + tenant_id=tenant_id, + name=f"App {uuid4()}", + description="", + mode="chat", + icon_type="emoji", + icon="bot", + icon_background="#FFFFFF", + enable_site=False, + enable_api=True, + api_rpm=100, + api_rph=100, + is_demo=False, + is_public=False, + is_universal=False, + created_by=first_app.created_by, + updated_by=first_app.updated_by, + ) + db_session_with_containers.add(app) + db_session_with_containers.commit() + all_apps = db_session_with_containers.query(App).filter(App.tenant_id == tenant_id).all() + return tenant_id, all_apps + + def test_create_batch_empty_app_ids(self, db_session_with_containers): + result = EndUserService.create_end_user_batch( + type=InvokeFrom.SERVICE_API, tenant_id=str(uuid4()), app_ids=[], user_id="user-1" + ) + assert result == {} + + def test_create_batch_creates_users_for_all_apps(self, db_session_with_containers, factory): + tenant_id, apps = self._create_multiple_apps(db_session_with_containers, factory, count=3) + app_ids = [a.id for a in apps] + user_id = f"user-{uuid4()}" + + result = EndUserService.create_end_user_batch( + type=InvokeFrom.SERVICE_API, tenant_id=tenant_id, app_ids=app_ids, user_id=user_id + ) + + assert len(result) == 3 + for app_id in app_ids: + assert app_id in result + assert result[app_id].session_id == user_id + assert result[app_id].type == InvokeFrom.SERVICE_API + + def test_create_batch_default_session_id(self, db_session_with_containers, factory): + tenant_id, apps = self._create_multiple_apps(db_session_with_containers, factory, count=2) + app_ids = [a.id for a in apps] + + result = EndUserService.create_end_user_batch( + type=InvokeFrom.SERVICE_API, tenant_id=tenant_id, app_ids=app_ids, user_id="" + ) + + assert len(result) == 2 + for end_user in result.values(): + assert end_user.session_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID + assert end_user._is_anonymous is True + + def test_create_batch_deduplicate_app_ids(self, db_session_with_containers, factory): + tenant_id, apps = self._create_multiple_apps(db_session_with_containers, factory, count=2) + app_ids = [apps[0].id, apps[1].id, apps[0].id, apps[1].id] + user_id = f"user-{uuid4()}" + + result = EndUserService.create_end_user_batch( + type=InvokeFrom.SERVICE_API, tenant_id=tenant_id, app_ids=app_ids, user_id=user_id + ) + + assert len(result) == 2 + + def test_create_batch_returns_existing_users(self, db_session_with_containers, factory): + tenant_id, apps = self._create_multiple_apps(db_session_with_containers, factory, count=2) + app_ids = [a.id for a in apps] + user_id = f"user-{uuid4()}" + + # Create batch first time + first_result = EndUserService.create_end_user_batch( + type=InvokeFrom.SERVICE_API, tenant_id=tenant_id, app_ids=app_ids, user_id=user_id + ) + + # Create batch second time — should return existing users + second_result = EndUserService.create_end_user_batch( + type=InvokeFrom.SERVICE_API, tenant_id=tenant_id, app_ids=app_ids, user_id=user_id + ) + + assert len(second_result) == 2 + for app_id in app_ids: + assert first_result[app_id].id == second_result[app_id].id + + def test_create_batch_partial_existing_users(self, db_session_with_containers, factory): + tenant_id, apps = self._create_multiple_apps(db_session_with_containers, factory, count=3) + user_id = f"user-{uuid4()}" + + # Create for first 2 apps + first_result = EndUserService.create_end_user_batch( + type=InvokeFrom.SERVICE_API, + tenant_id=tenant_id, + app_ids=[apps[0].id, apps[1].id], + user_id=user_id, + ) + + # Create for all 3 apps — should reuse first 2, create 3rd + all_result = EndUserService.create_end_user_batch( + type=InvokeFrom.SERVICE_API, + tenant_id=tenant_id, + app_ids=[a.id for a in apps], + user_id=user_id, + ) + + assert len(all_result) == 3 + assert all_result[apps[0].id].id == first_result[apps[0].id].id + assert all_result[apps[1].id].id == first_result[apps[1].id].id + assert all_result[apps[2].id].session_id == user_id + + @pytest.mark.parametrize( + "invoke_type", + [InvokeFrom.SERVICE_API, InvokeFrom.WEB_APP, InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER], + ) + def test_create_batch_all_invoke_types(self, db_session_with_containers, invoke_type, factory): + tenant_id, apps = self._create_multiple_apps(db_session_with_containers, factory, count=1) + user_id = f"user-{uuid4()}" + + result = EndUserService.create_end_user_batch( + type=invoke_type, tenant_id=tenant_id, app_ids=[apps[0].id], user_id=user_id + ) + + assert len(result) == 1 + assert result[apps[0].id].type == invoke_type diff --git a/api/tests/test_containers_integration_tests/services/test_file_service_zip_and_lookup.py b/api/tests/test_containers_integration_tests/services/test_file_service_zip_and_lookup.py new file mode 100644 index 0000000000..4e0a726cc7 --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/test_file_service_zip_and_lookup.py @@ -0,0 +1,96 @@ +""" +Testcontainers integration tests for FileService helpers. + +Covers: +- ZIP tempfile building (sanitization + deduplication + content writes) +- tenant-scoped batch lookup behavior (get_upload_files_by_ids) +""" + +from __future__ import annotations + +from datetime import UTC, datetime +from types import SimpleNamespace +from typing import Any +from uuid import uuid4 +from zipfile import ZipFile + +import pytest + +import services.file_service as file_service_module +from extensions.storage.storage_type import StorageType +from models.enums import CreatorUserRole +from models.model import UploadFile +from services.file_service import FileService + + +def _create_upload_file(db_session, *, tenant_id: str, key: str, name: str) -> UploadFile: + upload_file = UploadFile( + tenant_id=tenant_id, + storage_type=StorageType.OPENDAL, + key=key, + name=name, + size=100, + extension="txt", + mime_type="text/plain", + created_by_role=CreatorUserRole.ACCOUNT, + created_by=str(uuid4()), + created_at=datetime.now(UTC), + used=False, + ) + db_session.add(upload_file) + db_session.commit() + return upload_file + + +def test_build_upload_files_zip_tempfile_sanitizes_and_dedupes_names(monkeypatch: pytest.MonkeyPatch) -> None: + """Ensure ZIP entry names are safe and unique while preserving extensions.""" + upload_files: list[Any] = [ + SimpleNamespace(name="a/b.txt", key="k1"), + SimpleNamespace(name="c/b.txt", key="k2"), + SimpleNamespace(name="../b.txt", key="k3"), + ] + + data_by_key: dict[str, list[bytes]] = {"k1": [b"one"], "k2": [b"two"], "k3": [b"three"]} + + def _load(key: str, stream: bool = True) -> list[bytes]: + assert stream is True + return data_by_key[key] + + monkeypatch.setattr(file_service_module.storage, "load", _load) + + with FileService.build_upload_files_zip_tempfile(upload_files=upload_files) as tmp: + with ZipFile(tmp, mode="r") as zf: + assert zf.namelist() == ["b.txt", "b (1).txt", "b (2).txt"] + assert zf.read("b.txt") == b"one" + assert zf.read("b (1).txt") == b"two" + assert zf.read("b (2).txt") == b"three" + + +def test_get_upload_files_by_ids_returns_empty_when_no_ids(db_session_with_containers) -> None: + """Ensure empty input returns an empty mapping without hitting the database.""" + assert FileService.get_upload_files_by_ids(str(uuid4()), []) == {} + + +def test_get_upload_files_by_ids_returns_id_keyed_mapping(db_session_with_containers) -> None: + """Ensure batch lookup returns a dict keyed by stringified UploadFile ids.""" + tenant_id = str(uuid4()) + file1 = _create_upload_file(db_session_with_containers, tenant_id=tenant_id, key="k1", name="file1.txt") + file2 = _create_upload_file(db_session_with_containers, tenant_id=tenant_id, key="k2", name="file2.txt") + + result = FileService.get_upload_files_by_ids(tenant_id, [file1.id, file1.id, file2.id]) + + assert set(result.keys()) == {file1.id, file2.id} + assert result[file1.id].id == file1.id + assert result[file2.id].id == file2.id + + +def test_get_upload_files_by_ids_filters_by_tenant(db_session_with_containers) -> None: + """Ensure files from other tenants are not returned.""" + tenant_a = str(uuid4()) + tenant_b = str(uuid4()) + file_a = _create_upload_file(db_session_with_containers, tenant_id=tenant_a, key="ka", name="a.txt") + _create_upload_file(db_session_with_containers, tenant_id=tenant_b, key="kb", name="b.txt") + + result = FileService.get_upload_files_by_ids(tenant_a, [file_a.id]) + + assert set(result.keys()) == {file_a.id} diff --git a/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test.py b/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test.py index 70d05792ce..18c5320d0a 100644 --- a/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test.py +++ b/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test.py @@ -4,14 +4,14 @@ from unittest.mock import MagicMock import pytest -from dify_graph.enums import BuiltinNodeTypes -from dify_graph.nodes.human_input.entities import ( +from core.workflow.human_input_compat import ( EmailDeliveryConfig, EmailDeliveryMethod, EmailRecipients, ExternalRecipient, - HumanInputNodeData, ) +from graphon.enums import BuiltinNodeTypes +from graphon.nodes.human_input.entities import HumanInputNodeData from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.model import App, AppMode from models.workflow import Workflow, WorkflowType @@ -54,7 +54,7 @@ def _create_app_with_draft_workflow(session, *, delivery_method_id: uuid.UUID) - enabled=True, config=EmailDeliveryConfig( recipients=EmailRecipients( - whole_workspace=False, + include_bound_group=False, items=[ExternalRecipient(email="recipient@example.com")], ), subject="Test {{recipient_email}}", diff --git a/api/tests/test_containers_integration_tests/services/test_message_service.py b/api/tests/test_containers_integration_tests/services/test_message_service.py index 85dc04b162..bdf6d9b951 100644 --- a/api/tests/test_containers_integration_tests/services/test_message_service.py +++ b/api/tests/test_containers_integration_tests/services/test_message_service.py @@ -25,7 +25,7 @@ class TestMessageService: """Mock setup for external service dependencies.""" with ( patch("services.account_service.FeatureService") as mock_account_feature_service, - patch("services.message_service.ModelManager") as mock_model_manager, + patch("services.message_service.ModelManager.for_tenant") as mock_model_manager, patch("services.message_service.WorkflowService") as mock_workflow_service, patch("services.message_service.AdvancedChatAppConfigManager") as mock_app_config_manager, patch("services.message_service.LLMGenerator") as mock_llm_generator, diff --git a/api/tests/test_containers_integration_tests/services/test_messages_clean_service.py b/api/tests/test_containers_integration_tests/services/test_messages_clean_service.py index 8707f2e827..c0c1c25f1e 100644 --- a/api/tests/test_containers_integration_tests/services/test_messages_clean_service.py +++ b/api/tests/test_containers_integration_tests/services/test_messages_clean_service.py @@ -10,6 +10,7 @@ from sqlalchemy.orm import Session from enums.cloud_plan import CloudPlan from extensions.ext_redis import redis_client +from graphon.file.enums import FileType from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.enums import ( ConversationFromSource, @@ -253,7 +254,7 @@ class TestMessagesCleanServiceIntegration: # MessageFile file = MessageFile( message_id=message.id, - type="image", + type=FileType.IMAGE, transfer_method="local_file", url="http://example.com/test.jpg", belongs_to=MessageFileBelongsTo.USER, diff --git a/api/tests/test_containers_integration_tests/services/test_metadata_partial_update.py b/api/tests/test_containers_integration_tests/services/test_metadata_partial_update.py new file mode 100644 index 0000000000..b55a19eaa9 --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/test_metadata_partial_update.py @@ -0,0 +1,182 @@ +from __future__ import annotations + +from unittest.mock import Mock, patch +from uuid import uuid4 + +import pytest +from sqlalchemy import select + +from models.dataset import Dataset, DatasetMetadataBinding, Document +from models.enums import DataSourceType, DocumentCreatedFrom +from services.entities.knowledge_entities.knowledge_entities import ( + DocumentMetadataOperation, + MetadataDetail, + MetadataOperationData, +) +from services.metadata_service import MetadataService + + +def _create_dataset(db_session, *, tenant_id: str, built_in_field_enabled: bool = False) -> Dataset: + dataset = Dataset( + tenant_id=tenant_id, + name=f"dataset-{uuid4()}", + data_source_type=DataSourceType.UPLOAD_FILE, + created_by=str(uuid4()), + ) + dataset.id = str(uuid4()) + dataset.built_in_field_enabled = built_in_field_enabled + db_session.add(dataset) + db_session.commit() + return dataset + + +def _create_document(db_session, *, dataset_id: str, tenant_id: str, doc_metadata: dict | None = None) -> Document: + document = Document( + tenant_id=tenant_id, + dataset_id=dataset_id, + position=1, + data_source_type=DataSourceType.UPLOAD_FILE, + data_source_info="{}", + batch=f"batch-{uuid4()}", + name=f"doc-{uuid4()}", + created_from=DocumentCreatedFrom.WEB, + created_by=str(uuid4()), + ) + document.id = str(uuid4()) + document.doc_metadata = doc_metadata + db_session.add(document) + db_session.commit() + return document + + +class TestMetadataPartialUpdate: + @pytest.fixture + def tenant_id(self) -> str: + return str(uuid4()) + + @pytest.fixture + def user_id(self) -> str: + return str(uuid4()) + + @pytest.fixture + def mock_current_account(self, user_id, tenant_id): + account = Mock(id=user_id, current_tenant_id=tenant_id) + with patch("services.metadata_service.current_account_with_tenant", return_value=(account, tenant_id)): + yield account + + def test_partial_update_merges_metadata( + self, flask_app_with_containers, db_session_with_containers, tenant_id, mock_current_account + ): + dataset = _create_dataset(db_session_with_containers, tenant_id=tenant_id) + document = _create_document( + db_session_with_containers, + dataset_id=dataset.id, + tenant_id=tenant_id, + doc_metadata={"existing_key": "existing_value"}, + ) + + meta_id = str(uuid4()) + operation = DocumentMetadataOperation( + document_id=document.id, + metadata_list=[MetadataDetail(id=meta_id, name="new_key", value="new_value")], + partial_update=True, + ) + metadata_args = MetadataOperationData(operation_data=[operation]) + + MetadataService.update_documents_metadata(dataset, metadata_args) + db_session_with_containers.expire_all() + + updated_doc = db_session_with_containers.get(Document, document.id) + assert updated_doc is not None + assert updated_doc.doc_metadata["existing_key"] == "existing_value" + assert updated_doc.doc_metadata["new_key"] == "new_value" + + def test_full_update_replaces_metadata( + self, flask_app_with_containers, db_session_with_containers, tenant_id, mock_current_account + ): + dataset = _create_dataset(db_session_with_containers, tenant_id=tenant_id) + document = _create_document( + db_session_with_containers, + dataset_id=dataset.id, + tenant_id=tenant_id, + doc_metadata={"existing_key": "existing_value"}, + ) + + meta_id = str(uuid4()) + operation = DocumentMetadataOperation( + document_id=document.id, + metadata_list=[MetadataDetail(id=meta_id, name="new_key", value="new_value")], + partial_update=False, + ) + metadata_args = MetadataOperationData(operation_data=[operation]) + + MetadataService.update_documents_metadata(dataset, metadata_args) + db_session_with_containers.expire_all() + + updated_doc = db_session_with_containers.get(Document, document.id) + assert updated_doc is not None + assert updated_doc.doc_metadata == {"new_key": "new_value"} + assert "existing_key" not in updated_doc.doc_metadata + + def test_partial_update_skips_existing_binding( + self, flask_app_with_containers, db_session_with_containers, tenant_id, user_id, mock_current_account + ): + dataset = _create_dataset(db_session_with_containers, tenant_id=tenant_id) + document = _create_document( + db_session_with_containers, + dataset_id=dataset.id, + tenant_id=tenant_id, + doc_metadata={"existing_key": "existing_value"}, + ) + + meta_id = str(uuid4()) + existing_binding = DatasetMetadataBinding( + tenant_id=tenant_id, + dataset_id=dataset.id, + document_id=document.id, + metadata_id=meta_id, + created_by=user_id, + ) + db_session_with_containers.add(existing_binding) + db_session_with_containers.commit() + + operation = DocumentMetadataOperation( + document_id=document.id, + metadata_list=[MetadataDetail(id=meta_id, name="existing_key", value="existing_value")], + partial_update=True, + ) + metadata_args = MetadataOperationData(operation_data=[operation]) + + MetadataService.update_documents_metadata(dataset, metadata_args) + db_session_with_containers.expire_all() + + bindings = db_session_with_containers.scalars( + select(DatasetMetadataBinding).where( + DatasetMetadataBinding.document_id == document.id, + DatasetMetadataBinding.metadata_id == meta_id, + ) + ).all() + assert len(bindings) == 1 + + def test_rollback_called_on_commit_failure( + self, flask_app_with_containers, db_session_with_containers, tenant_id, mock_current_account + ): + dataset = _create_dataset(db_session_with_containers, tenant_id=tenant_id) + document = _create_document( + db_session_with_containers, + dataset_id=dataset.id, + tenant_id=tenant_id, + doc_metadata={"existing_key": "existing_value"}, + ) + + meta_id = str(uuid4()) + operation = DocumentMetadataOperation( + document_id=document.id, + metadata_list=[MetadataDetail(id=meta_id, name="key", value="value")], + partial_update=True, + ) + metadata_args = MetadataOperationData(operation_data=[operation]) + + with patch("services.metadata_service.db.session.commit", side_effect=RuntimeError("database connection lost")): + with pytest.raises(RuntimeError, match="database connection lost"): + MetadataService.update_documents_metadata(dataset, metadata_args) diff --git a/api/tests/test_containers_integration_tests/services/test_metadata_service.py b/api/tests/test_containers_integration_tests/services/test_metadata_service.py index e847329c5b..8b1349be9a 100644 --- a/api/tests/test_containers_integration_tests/services/test_metadata_service.py +++ b/api/tests/test_containers_integration_tests/services/test_metadata_service.py @@ -5,6 +5,7 @@ from faker import Faker from sqlalchemy.orm import Session from core.rag.index_processor.constant.built_in_field import BuiltInField +from core.rag.index_processor.constant.index_type import IndexStructureType from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, DatasetMetadata, DatasetMetadataBinding, Document from models.enums import DatasetMetadataType, DataSourceType, DocumentCreatedFrom @@ -139,7 +140,7 @@ class TestMetadataService: name=fake.file_name(), created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="text", + doc_form=IndexStructureType.PARAGRAPH_INDEX, doc_language="en", ) diff --git a/api/tests/test_containers_integration_tests/services/test_model_load_balancing_service.py b/api/tests/test_containers_integration_tests/services/test_model_load_balancing_service.py index 989df42499..ca6e7afeab 100644 --- a/api/tests/test_containers_integration_tests/services/test_model_load_balancing_service.py +++ b/api/tests/test_containers_integration_tests/services/test_model_load_balancing_service.py @@ -18,11 +18,10 @@ class TestModelLoadBalancingService: def mock_external_service_dependencies(self): """Mock setup for external service dependencies.""" with ( - patch("services.model_load_balancing_service.ProviderManager", autospec=True) as mock_provider_manager, - patch("services.model_load_balancing_service.LBModelManager", autospec=True) as mock_lb_model_manager, patch( - "services.model_load_balancing_service.ModelProviderFactory", autospec=True - ) as mock_model_provider_factory, + "services.model_load_balancing_service.create_plugin_provider_manager", autospec=True + ) as mock_provider_manager, + patch("services.model_load_balancing_service.LBModelManager", autospec=True) as mock_lb_model_manager, patch("services.model_load_balancing_service.encrypter", autospec=True) as mock_encrypter, ): # Setup default mock returns @@ -46,9 +45,6 @@ class TestModelLoadBalancingService: # Mock LBModelManager mock_lb_model_manager.get_config_in_cooldown_and_ttl.return_value = (False, 0) - # Mock ModelProviderFactory - mock_model_provider_factory_instance = mock_model_provider_factory.return_value - # Mock credential schemas mock_credential_schema = MagicMock() mock_credential_schema.credential_form_schemas = [] @@ -61,7 +57,6 @@ class TestModelLoadBalancingService: yield { "provider_manager": mock_provider_manager, "lb_model_manager": mock_lb_model_manager, - "model_provider_factory": mock_model_provider_factory, "encrypter": mock_encrypter, "provider_config": mock_provider_config, "provider_model_setting": mock_provider_model_setting, diff --git a/api/tests/test_containers_integration_tests/services/test_model_provider_service.py b/api/tests/test_containers_integration_tests/services/test_model_provider_service.py index 6afc5aa43c..8955a3b5f2 100644 --- a/api/tests/test_containers_integration_tests/services/test_model_provider_service.py +++ b/api/tests/test_containers_integration_tests/services/test_model_provider_service.py @@ -5,7 +5,7 @@ from faker import Faker from sqlalchemy.orm import Session from core.entities.model_entities import ModelStatus -from dify_graph.model_runtime.entities.model_entities import FetchFrom, ModelType +from graphon.model_runtime.entities.model_entities import FetchFrom, ModelType from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.provider import Provider, ProviderModel, ProviderModelSetting, ProviderType from services.model_provider_service import ModelProviderService @@ -18,8 +18,12 @@ class TestModelProviderService: def mock_external_service_dependencies(self): """Mock setup for external service dependencies.""" with ( - patch("services.model_provider_service.ProviderManager", autospec=True) as mock_provider_manager, - patch("services.model_provider_service.ModelProviderFactory", autospec=True) as mock_model_provider_factory, + patch( + "services.model_provider_service.create_plugin_provider_manager", autospec=True + ) as mock_provider_manager, + patch( + "services.model_provider_service.create_plugin_model_provider_factory", autospec=True + ) as mock_model_provider_factory, ): # Setup default mock returns mock_provider_manager.return_value.get_configurations.return_value = MagicMock() @@ -402,8 +406,8 @@ class TestModelProviderService: # Create mock models from core.entities.model_entities import ModelWithProviderEntity, SimpleModelProviderEntity - from dify_graph.model_runtime.entities.common_entities import I18nObject - from dify_graph.model_runtime.entities.provider_entities import ProviderEntity + from graphon.model_runtime.entities.common_entities import I18nObject + from graphon.model_runtime.entities.provider_entities import ProviderEntity # Create real model objects instead of mocks provider_entity_1 = SimpleModelProviderEntity( @@ -640,7 +644,7 @@ class TestModelProviderService: # Create mock default model response from core.entities.model_entities import DefaultModelEntity, DefaultModelProviderEntity - from dify_graph.model_runtime.entities.common_entities import I18nObject + from graphon.model_runtime.entities.common_entities import I18nObject mock_default_model = DefaultModelEntity( model="gpt-3.5-turbo", diff --git a/api/tests/test_containers_integration_tests/services/test_oauth_server_service.py b/api/tests/test_containers_integration_tests/services/test_oauth_server_service.py new file mode 100644 index 0000000000..c146a5924b --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/test_oauth_server_service.py @@ -0,0 +1,174 @@ +"""Testcontainers integration tests for OAuthServerService.""" + +from __future__ import annotations + +import uuid +from typing import cast +from unittest.mock import MagicMock, patch +from uuid import uuid4 + +import pytest +from werkzeug.exceptions import BadRequest + +from models.model import OAuthProviderApp +from services.oauth_server import ( + OAUTH_ACCESS_TOKEN_EXPIRES_IN, + OAUTH_ACCESS_TOKEN_REDIS_KEY, + OAUTH_AUTHORIZATION_CODE_REDIS_KEY, + OAUTH_REFRESH_TOKEN_EXPIRES_IN, + OAUTH_REFRESH_TOKEN_REDIS_KEY, + OAuthGrantType, + OAuthServerService, +) + + +class TestOAuthServerServiceGetProviderApp: + """DB-backed tests for get_oauth_provider_app.""" + + def _create_oauth_provider_app(self, db_session_with_containers, *, client_id: str) -> OAuthProviderApp: + app = OAuthProviderApp( + app_icon="icon.png", + client_id=client_id, + client_secret=str(uuid4()), + app_label={"en-US": "Test OAuth App"}, + redirect_uris=["https://example.com/callback"], + scope="read", + ) + db_session_with_containers.add(app) + db_session_with_containers.commit() + return app + + def test_get_oauth_provider_app_returns_app_when_exists(self, db_session_with_containers): + client_id = f"client-{uuid4()}" + created = self._create_oauth_provider_app(db_session_with_containers, client_id=client_id) + + result = OAuthServerService.get_oauth_provider_app(client_id) + + assert result is not None + assert result.client_id == client_id + assert result.id == created.id + + def test_get_oauth_provider_app_returns_none_when_not_exists(self, db_session_with_containers): + result = OAuthServerService.get_oauth_provider_app(f"nonexistent-{uuid4()}") + + assert result is None + + +class TestOAuthServerServiceTokenOperations: + """Redis-backed tests for token sign/validate operations.""" + + @pytest.fixture + def mock_redis(self): + with patch("services.oauth_server.redis_client") as mock: + yield mock + + def test_sign_authorization_code_stores_and_returns_code(self, mock_redis): + deterministic_uuid = uuid.UUID("00000000-0000-0000-0000-000000000111") + with patch("services.oauth_server.uuid.uuid4", return_value=deterministic_uuid): + code = OAuthServerService.sign_oauth_authorization_code("client-1", "user-1") + + assert code == str(deterministic_uuid) + mock_redis.set.assert_called_once_with( + OAUTH_AUTHORIZATION_CODE_REDIS_KEY.format(client_id="client-1", code=code), + "user-1", + ex=600, + ) + + def test_sign_access_token_raises_bad_request_for_invalid_code(self, mock_redis): + mock_redis.get.return_value = None + + with pytest.raises(BadRequest, match="invalid code"): + OAuthServerService.sign_oauth_access_token( + grant_type=OAuthGrantType.AUTHORIZATION_CODE, + code="bad-code", + client_id="client-1", + ) + + def test_sign_access_token_issues_tokens_for_valid_code(self, mock_redis): + token_uuids = [ + uuid.UUID("00000000-0000-0000-0000-000000000201"), + uuid.UUID("00000000-0000-0000-0000-000000000202"), + ] + with patch("services.oauth_server.uuid.uuid4", side_effect=token_uuids): + mock_redis.get.return_value = b"user-1" + + access_token, refresh_token = OAuthServerService.sign_oauth_access_token( + grant_type=OAuthGrantType.AUTHORIZATION_CODE, + code="code-1", + client_id="client-1", + ) + + assert access_token == str(token_uuids[0]) + assert refresh_token == str(token_uuids[1]) + code_key = OAUTH_AUTHORIZATION_CODE_REDIS_KEY.format(client_id="client-1", code="code-1") + mock_redis.delete.assert_called_once_with(code_key) + mock_redis.set.assert_any_call( + OAUTH_ACCESS_TOKEN_REDIS_KEY.format(client_id="client-1", token=access_token), + b"user-1", + ex=OAUTH_ACCESS_TOKEN_EXPIRES_IN, + ) + mock_redis.set.assert_any_call( + OAUTH_REFRESH_TOKEN_REDIS_KEY.format(client_id="client-1", token=refresh_token), + b"user-1", + ex=OAUTH_REFRESH_TOKEN_EXPIRES_IN, + ) + + def test_sign_access_token_raises_bad_request_for_invalid_refresh_token(self, mock_redis): + mock_redis.get.return_value = None + + with pytest.raises(BadRequest, match="invalid refresh token"): + OAuthServerService.sign_oauth_access_token( + grant_type=OAuthGrantType.REFRESH_TOKEN, + refresh_token="stale-token", + client_id="client-1", + ) + + def test_sign_access_token_issues_new_token_for_valid_refresh(self, mock_redis): + deterministic_uuid = uuid.UUID("00000000-0000-0000-0000-000000000301") + with patch("services.oauth_server.uuid.uuid4", return_value=deterministic_uuid): + mock_redis.get.return_value = b"user-1" + + access_token, returned_refresh = OAuthServerService.sign_oauth_access_token( + grant_type=OAuthGrantType.REFRESH_TOKEN, + refresh_token="refresh-1", + client_id="client-1", + ) + + assert access_token == str(deterministic_uuid) + assert returned_refresh == "refresh-1" + + def test_sign_access_token_returns_none_for_unknown_grant_type(self, mock_redis): + grant_type = cast(OAuthGrantType, "invalid-grant-type") + + result = OAuthServerService.sign_oauth_access_token(grant_type=grant_type, client_id="client-1") + + assert result is None + + def test_sign_refresh_token_stores_with_expected_expiry(self, mock_redis): + deterministic_uuid = uuid.UUID("00000000-0000-0000-0000-000000000401") + with patch("services.oauth_server.uuid.uuid4", return_value=deterministic_uuid): + refresh_token = OAuthServerService._sign_oauth_refresh_token("client-2", "user-2") + + assert refresh_token == str(deterministic_uuid) + mock_redis.set.assert_called_once_with( + OAUTH_REFRESH_TOKEN_REDIS_KEY.format(client_id="client-2", token=refresh_token), + "user-2", + ex=OAUTH_REFRESH_TOKEN_EXPIRES_IN, + ) + + def test_validate_access_token_returns_none_when_not_found(self, mock_redis): + mock_redis.get.return_value = None + + result = OAuthServerService.validate_oauth_access_token("client-1", "missing-token") + + assert result is None + + def test_validate_access_token_loads_user_when_exists(self, mock_redis): + mock_redis.get.return_value = b"user-88" + expected_user = MagicMock() + + with patch("services.oauth_server.AccountService.load_user", return_value=expected_user) as mock_load: + result = OAuthServerService.validate_oauth_access_token("client-1", "access-token") + + assert result is expected_user + mock_load.assert_called_once_with("user-88") diff --git a/api/tests/test_containers_integration_tests/services/test_restore_archived_workflow_run.py b/api/tests/test_containers_integration_tests/services/test_restore_archived_workflow_run.py index ba4310e22e..7036524918 100644 --- a/api/tests/test_containers_integration_tests/services/test_restore_archived_workflow_run.py +++ b/api/tests/test_containers_integration_tests/services/test_restore_archived_workflow_run.py @@ -2,17 +2,43 @@ Testcontainers integration tests for workflow run restore functionality. """ +from __future__ import annotations + +from datetime import datetime from uuid import uuid4 from sqlalchemy import select -from models.workflow import WorkflowPause +from models.workflow import WorkflowPause, WorkflowRun from services.retention.workflow_run.restore_archived_workflow_run import WorkflowRunRestore class TestWorkflowRunRestore: """Tests for the WorkflowRunRestore class.""" + def test_restore_initialization(self): + """Restore service should respect dry_run flag.""" + restore = WorkflowRunRestore(dry_run=True) + + assert restore.dry_run is True + + def test_convert_datetime_fields(self): + """ISO datetime strings should be converted to datetime objects.""" + record = { + "id": "test-id", + "created_at": "2024-01-01T12:00:00", + "finished_at": "2024-01-01T12:05:00", + "name": "test", + } + + restore = WorkflowRunRestore() + result = restore._convert_datetime_fields(record, WorkflowRun) + + assert isinstance(result["created_at"], datetime) + assert result["created_at"].year == 2024 + assert result["created_at"].month == 1 + assert result["name"] == "test" + def test_restore_table_records_returns_rowcount(self, db_session_with_containers): """Restore should return inserted rowcount.""" restore = WorkflowRunRestore() diff --git a/api/tests/test_containers_integration_tests/services/test_saved_message_service.py b/api/tests/test_containers_integration_tests/services/test_saved_message_service.py index 94a4e62560..70aa813142 100644 --- a/api/tests/test_containers_integration_tests/services/test_saved_message_service.py +++ b/api/tests/test_containers_integration_tests/services/test_saved_message_service.py @@ -20,7 +20,7 @@ class TestSavedMessageService: """Mock setup for external service dependencies.""" with ( patch("services.account_service.FeatureService") as mock_account_feature_service, - patch("services.app_service.ModelManager") as mock_model_manager, + patch("services.app_service.ModelManager.for_tenant") as mock_model_manager, patch("services.saved_message_service.MessageService") as mock_message_service, ): # Setup default mock returns @@ -396,11 +396,6 @@ class TestSavedMessageService: assert "User is required" in str(exc_info.value) - # Verify no database operations were performed - - saved_messages = db_session_with_containers.query(SavedMessage).all() - assert len(saved_messages) == 0 - def test_save_error_no_user(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test error handling when saving message with no user. @@ -497,124 +492,140 @@ class TestSavedMessageService: # The message should still exist, only the saved_message should be deleted assert db_session_with_containers.query(Message).where(Message.id == message.id).first() is not None - def test_pagination_by_last_id_error_no_user( - self, db_session_with_containers: Session, mock_external_service_dependencies - ): - """ - Test error handling when no user is provided. - - This test verifies: - - Proper error handling for missing user - - ValueError is raised when user is None - - No database operations are performed - """ - # Arrange: Create test data - fake = Faker() + def test_save_for_end_user(self, db_session_with_containers: Session, mock_external_service_dependencies): + """Test saving a message for an EndUser.""" app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + end_user = self._create_test_end_user(db_session_with_containers, app) + message = self._create_test_message(db_session_with_containers, app, end_user) - # Act & Assert: Verify proper error handling - with pytest.raises(ValueError) as exc_info: - SavedMessageService.pagination_by_last_id(app_model=app, user=None, last_id=None, limit=10) + mock_external_service_dependencies["message_service"].get_message.return_value = message - assert "User is required" in str(exc_info.value) + SavedMessageService.save(app_model=app, user=end_user, message_id=message.id) - # Verify no database operations were performed for this specific test - # Note: We don't check total count as other tests may have created data - # Instead, we verify that the error was properly raised - pass - - def test_save_error_no_user(self, db_session_with_containers: Session, mock_external_service_dependencies): - """ - Test error handling when saving message with no user. - - This test verifies: - - Method returns early when user is None - - No database operations are performed - - No exceptions are raised - """ - # Arrange: Create test data - fake = Faker() - app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) - message = self._create_test_message(db_session_with_containers, app, account) - - # Act: Execute the method under test with None user - result = SavedMessageService.save(app_model=app, user=None, message_id=message.id) - - # Assert: Verify the expected outcomes - assert result is None - - # Verify no saved message was created - - saved_message = ( + saved = ( db_session_with_containers.query(SavedMessage) - .where( - SavedMessage.app_id == app.id, - SavedMessage.message_id == message.id, - ) + .where(SavedMessage.app_id == app.id, SavedMessage.message_id == message.id) .first() ) + assert saved is not None + assert saved.created_by == end_user.id + assert saved.created_by_role == "end_user" - assert saved_message is None - - def test_delete_success_existing_message( + def test_save_duplicate_is_idempotent( self, db_session_with_containers: Session, mock_external_service_dependencies ): - """ - Test successful deletion of an existing saved message. - - This test verifies: - - Proper deletion of existing saved message - - Correct database state after deletion - - No errors during deletion process - """ - # Arrange: Create test data - fake = Faker() + """Test that saving an already-saved message does not create a duplicate.""" app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) message = self._create_test_message(db_session_with_containers, app, account) - # Create a saved message first - saved_message = SavedMessage( - app_id=app.id, - message_id=message.id, - created_by_role="account", - created_by=account.id, - ) + mock_external_service_dependencies["message_service"].get_message.return_value = message - db_session_with_containers.add(saved_message) + # Save once + SavedMessageService.save(app_model=app, user=account, message_id=message.id) + # Save again + SavedMessageService.save(app_model=app, user=account, message_id=message.id) + + count = ( + db_session_with_containers.query(SavedMessage) + .where(SavedMessage.app_id == app.id, SavedMessage.message_id == message.id) + .count() + ) + assert count == 1 + + def test_delete_without_user_does_nothing( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """Test that deleting without a user is a no-op.""" + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + message = self._create_test_message(db_session_with_containers, app, account) + + # Pre-create a saved message + saved = SavedMessage(app_id=app.id, message_id=message.id, created_by_role="account", created_by=account.id) + db_session_with_containers.add(saved) db_session_with_containers.commit() - # Verify saved message exists + SavedMessageService.delete(app_model=app, user=None, message_id=message.id) + + # Should still exist + assert ( + db_session_with_containers.query(SavedMessage) + .where(SavedMessage.app_id == app.id, SavedMessage.message_id == message.id) + .first() + is not None + ) + + def test_delete_non_existent_does_nothing( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """Test that deleting a non-existent saved message is a no-op.""" + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Should not raise — use a valid UUID that doesn't exist in DB + from uuid import uuid4 + + SavedMessageService.delete(app_model=app, user=account, message_id=str(uuid4())) + + def test_delete_for_end_user(self, db_session_with_containers: Session, mock_external_service_dependencies): + """Test deleting a saved message for an EndUser.""" + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + end_user = self._create_test_end_user(db_session_with_containers, app) + message = self._create_test_message(db_session_with_containers, app, end_user) + + saved = SavedMessage(app_id=app.id, message_id=message.id, created_by_role="end_user", created_by=end_user.id) + db_session_with_containers.add(saved) + db_session_with_containers.commit() + + SavedMessageService.delete(app_model=app, user=end_user, message_id=message.id) + + assert ( + db_session_with_containers.query(SavedMessage) + .where(SavedMessage.app_id == app.id, SavedMessage.message_id == message.id) + .first() + is None + ) + + def test_delete_only_affects_own_saved_messages( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """Test that delete only removes the requesting user's saved message.""" + app, account1 = self._create_test_app_and_account( + db_session_with_containers, mock_external_service_dependencies + ) + end_user = self._create_test_end_user(db_session_with_containers, app) + message = self._create_test_message(db_session_with_containers, app, account1) + + # Both users save the same message + saved_account = SavedMessage( + app_id=app.id, message_id=message.id, created_by_role="account", created_by=account1.id + ) + saved_end_user = SavedMessage( + app_id=app.id, message_id=message.id, created_by_role="end_user", created_by=end_user.id + ) + db_session_with_containers.add_all([saved_account, saved_end_user]) + db_session_with_containers.commit() + + # Delete only account1's saved message + SavedMessageService.delete(app_model=app, user=account1, message_id=message.id) + + # Account's saved message should be gone assert ( db_session_with_containers.query(SavedMessage) .where( SavedMessage.app_id == app.id, SavedMessage.message_id == message.id, - SavedMessage.created_by_role == "account", - SavedMessage.created_by == account.id, + SavedMessage.created_by == account1.id, ) .first() - is not None + is None ) - - # Act: Execute the method under test - SavedMessageService.delete(app_model=app, user=account, message_id=message.id) - - # Assert: Verify the expected outcomes - # Check if saved message was deleted from database - deleted_saved_message = ( + # End user's saved message should still exist + assert ( db_session_with_containers.query(SavedMessage) .where( SavedMessage.app_id == app.id, SavedMessage.message_id == message.id, - SavedMessage.created_by_role == "account", - SavedMessage.created_by == account.id, + SavedMessage.created_by == end_user.id, ) .first() + is not None ) - - assert deleted_saved_message is None - - # Verify database state - db_session_with_containers.commit() - # The message should still exist, only the saved_message should be deleted - assert db_session_with_containers.query(Message).where(Message.id == message.id).first() is not None diff --git a/api/tests/test_containers_integration_tests/services/test_tag_service.py b/api/tests/test_containers_integration_tests/services/test_tag_service.py index 1a72e3b6c2..f504f35589 100644 --- a/api/tests/test_containers_integration_tests/services/test_tag_service.py +++ b/api/tests/test_containers_integration_tests/services/test_tag_service.py @@ -7,6 +7,7 @@ from sqlalchemy import select from sqlalchemy.orm import Session from werkzeug.exceptions import NotFound +from core.rag.index_processor.constant.index_type import IndexTechniqueType from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset from models.enums import DataSourceType, TagType @@ -102,7 +103,7 @@ class TestTagService: provider="vendor", permission="only_me", data_source_type=DataSourceType.UPLOAD_FILE, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, tenant_id=tenant_id, created_by=mock_external_service_dependencies["current_user"].id, ) diff --git a/api/tests/test_containers_integration_tests/services/test_web_conversation_service.py b/api/tests/test_containers_integration_tests/services/test_web_conversation_service.py index 6b95954480..f2307fbd7d 100644 --- a/api/tests/test_containers_integration_tests/services/test_web_conversation_service.py +++ b/api/tests/test_containers_integration_tests/services/test_web_conversation_service.py @@ -25,7 +25,7 @@ class TestWebConversationService: with ( patch("services.app_service.FeatureService") as mock_feature_service, patch("services.app_service.EnterpriseService") as mock_enterprise_service, - patch("services.app_service.ModelManager") as mock_model_manager, + patch("services.app_service.ModelManager.for_tenant") as mock_model_manager, patch("services.account_service.FeatureService") as mock_account_feature_service, ): # Setup default mock returns for app service diff --git a/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py b/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py index 8ab8df2a5a..2a18345c87 100644 --- a/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py +++ b/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py @@ -1,20 +1,24 @@ +from __future__ import annotations + import json import uuid from datetime import UTC, datetime, timedelta +from types import SimpleNamespace from unittest.mock import patch import pytest from faker import Faker from sqlalchemy.orm import Session -from dify_graph.entities.workflow_execution import WorkflowExecutionStatus -from models import EndUser, Workflow, WorkflowAppLog, WorkflowRun -from models.enums import CreatorUserRole +from graphon.entities.workflow_execution import WorkflowExecutionStatus +from models import EndUser, Workflow, WorkflowAppLog, WorkflowArchiveLog, WorkflowRun +from models.enums import AppTriggerType, CreatorUserRole, WorkflowRunTriggeredFrom +from models.workflow import WorkflowAppLogCreatedFrom from services.account_service import AccountService, TenantService # Delay import of AppService to avoid circular dependency # from services.app_service import AppService -from services.workflow_app_service import WorkflowAppService +from services.workflow_app_service import LogView, WorkflowAppService from tests.test_containers_integration_tests.helpers import generate_valid_password @@ -27,7 +31,7 @@ class TestWorkflowAppService: with ( patch("services.app_service.FeatureService") as mock_feature_service, patch("services.app_service.EnterpriseService") as mock_enterprise_service, - patch("services.app_service.ModelManager") as mock_model_manager, + patch("services.app_service.ModelManager.for_tenant") as mock_model_manager, patch("services.account_service.FeatureService") as mock_account_feature_service, ): # Setup default mock returns for app service @@ -221,7 +225,7 @@ class TestWorkflowAppService: app_id=app.id, workflow_id=workflow.id, workflow_run_id=workflow_run.id, - created_from="service-api", + created_from=WorkflowAppLogCreatedFrom.SERVICE_API, created_by_role=CreatorUserRole.ACCOUNT, created_by=account.id, ) @@ -357,7 +361,7 @@ class TestWorkflowAppService: app_id=app.id, workflow_id=workflow.id, workflow_run_id=workflow_run_1.id, - created_from="service-api", + created_from=WorkflowAppLogCreatedFrom.SERVICE_API, created_by_role=CreatorUserRole.ACCOUNT, created_by=account.id, ) @@ -399,7 +403,7 @@ class TestWorkflowAppService: app_id=app.id, workflow_id=workflow.id, workflow_run_id=workflow_run_2.id, - created_from="service-api", + created_from=WorkflowAppLogCreatedFrom.SERVICE_API, created_by_role=CreatorUserRole.ACCOUNT, created_by=account.id, ) @@ -441,7 +445,7 @@ class TestWorkflowAppService: app_id=app.id, workflow_id=workflow.id, workflow_run_id=workflow_run_4.id, - created_from="service-api", + created_from=WorkflowAppLogCreatedFrom.SERVICE_API, created_by_role=CreatorUserRole.ACCOUNT, created_by=account.id, ) @@ -521,7 +525,7 @@ class TestWorkflowAppService: app_id=app.id, workflow_id=workflow.id, workflow_run_id=workflow_run.id, - created_from="service-api", + created_from=WorkflowAppLogCreatedFrom.SERVICE_API, created_by_role=CreatorUserRole.ACCOUNT, created_by=account.id, ) @@ -627,7 +631,7 @@ class TestWorkflowAppService: app_id=app.id, workflow_id=workflow.id, workflow_run_id=workflow_run.id, - created_from="service-api", + created_from=WorkflowAppLogCreatedFrom.SERVICE_API, created_by_role=CreatorUserRole.ACCOUNT, created_by=account.id, ) @@ -732,7 +736,7 @@ class TestWorkflowAppService: app_id=app.id, workflow_id=workflow.id, workflow_run_id=workflow_run.id, - created_from="service-api", + created_from=WorkflowAppLogCreatedFrom.SERVICE_API, created_by_role=CreatorUserRole.ACCOUNT, created_by=account.id, ) @@ -860,7 +864,7 @@ class TestWorkflowAppService: app_id=app.id, workflow_id=workflow.id, workflow_run_id=workflow_run.id, - created_from="service-api", + created_from=WorkflowAppLogCreatedFrom.SERVICE_API, created_by_role=CreatorUserRole.ACCOUNT, created_by=account.id, ) @@ -902,7 +906,7 @@ class TestWorkflowAppService: app_id=app.id, workflow_id=workflow.id, workflow_run_id=workflow_run.id, - created_from="web-app", + created_from=WorkflowAppLogCreatedFrom.WEB_APP, created_by_role=CreatorUserRole.END_USER, created_by=end_user.id, ) @@ -1037,7 +1041,7 @@ class TestWorkflowAppService: app_id=app.id, workflow_id=workflow.id, workflow_run_id=workflow_run.id, - created_from="service-api", + created_from=WorkflowAppLogCreatedFrom.SERVICE_API, created_by_role=CreatorUserRole.ACCOUNT, created_by=account.id, ) @@ -1125,7 +1129,7 @@ class TestWorkflowAppService: app_id=app.id, workflow_id=workflow.id, workflow_run_id=workflow_run.id, - created_from="service-api", + created_from=WorkflowAppLogCreatedFrom.SERVICE_API, created_by_role=CreatorUserRole.ACCOUNT, created_by=account.id, ) @@ -1279,7 +1283,7 @@ class TestWorkflowAppService: app_id=app.id, workflow_id=workflow.id, workflow_run_id=workflow_run.id, - created_from="service-api", + created_from=WorkflowAppLogCreatedFrom.SERVICE_API, created_by_role=CreatorUserRole.ACCOUNT, created_by=account.id, ) @@ -1379,7 +1383,7 @@ class TestWorkflowAppService: app_id=app.id, workflow_id=workflow.id, workflow_run_id=workflow_run.id, - created_from="service-api", + created_from=WorkflowAppLogCreatedFrom.SERVICE_API, created_by_role=CreatorUserRole.ACCOUNT, created_by=account.id, ) @@ -1481,7 +1485,7 @@ class TestWorkflowAppService: app_id=app.id, workflow_id=workflow.id, workflow_run_id=workflow_run.id, - created_from="service-api", + created_from=WorkflowAppLogCreatedFrom.SERVICE_API, created_by_role=CreatorUserRole.ACCOUNT, created_by=account.id, ) @@ -1524,3 +1528,168 @@ class TestWorkflowAppService: # Should not find tenant2's data when searching from tenant1's context assert result_cross_tenant["total"] == 0 + + def test_get_paginate_workflow_app_logs_raises_when_account_filter_email_not_found( + self, db_session_with_containers, mock_external_service_dependencies + ): + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + service = WorkflowAppService() + + with pytest.raises(ValueError, match="Account not found: nonexistent@example.com"): + service.get_paginate_workflow_app_logs( + session=db_session_with_containers, + app_model=app, + created_by_account="nonexistent@example.com", + ) + + def test_get_paginate_workflow_app_logs_filters_by_account( + self, db_session_with_containers, mock_external_service_dependencies + ): + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + service = WorkflowAppService() + workflow, workflow_run, _log = self._create_test_workflow_data(db_session_with_containers, app, account) + + result = service.get_paginate_workflow_app_logs( + session=db_session_with_containers, + app_model=app, + created_by_account=account.email, + ) + + assert result["total"] >= 0 + assert isinstance(result["data"], list) + + def test_get_paginate_workflow_archive_logs(self, db_session_with_containers, mock_external_service_dependencies): + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + service = WorkflowAppService() + + end_user = EndUser( + tenant_id=app.tenant_id, + app_id=app.id, + type="browser", + is_anonymous=False, + session_id="session-1", + ) + db_session_with_containers.add(end_user) + db_session_with_containers.commit() + + now = datetime.now(UTC) + archive_defaults = { + "workflow_id": str(uuid.uuid4()), + "run_version": "1.0.0", + "run_status": WorkflowExecutionStatus.SUCCEEDED, + "run_triggered_from": WorkflowRunTriggeredFrom.APP_RUN, + "run_error": None, + "run_elapsed_time": 1.0, + "run_total_tokens": 0, + "run_total_steps": 0, + "run_created_at": now, + "run_finished_at": now, + "run_exceptions_count": 0, + "trigger_metadata": '{"type":"trigger-webhook"}', + "log_created_at": now, + "log_created_from": WorkflowAppLogCreatedFrom.SERVICE_API, + } + archive_account = WorkflowArchiveLog( + tenant_id=app.tenant_id, + app_id=app.id, + workflow_run_id=str(uuid.uuid4()), + log_id=str(uuid.uuid4()), + created_by=account.id, + created_by_role=CreatorUserRole.ACCOUNT, + **archive_defaults, + ) + archive_end_user = WorkflowArchiveLog( + tenant_id=app.tenant_id, + app_id=app.id, + workflow_run_id=str(uuid.uuid4()), + log_id=str(uuid.uuid4()), + created_by=end_user.id, + created_by_role=CreatorUserRole.END_USER, + **archive_defaults, + ) + db_session_with_containers.add_all([archive_account, archive_end_user]) + db_session_with_containers.commit() + + result = service.get_paginate_workflow_archive_logs( + session=db_session_with_containers, + app_model=app, + page=1, + limit=20, + ) + + assert result["total"] == 2 + assert len(result["data"]) == 2 + account_item = next(d for d in result["data"] if d["created_by_account"] is not None) + end_user_item = next(d for d in result["data"] if d["created_by_end_user"] is not None) + assert account_item["created_by_account"].id == account.id + assert end_user_item["created_by_end_user"].id == end_user.id + + +class TestLogView: + def test_details_and_proxy_attributes(self): + log = SimpleNamespace(id="log-1", status="succeeded") + view = LogView(log=log, details={"trigger_metadata": {"type": "plugin"}}) + + assert view.details == {"trigger_metadata": {"type": "plugin"}} + assert view.status == "succeeded" + + +class TestHandleTriggerMetadata: + def test_returns_empty_dict_when_metadata_missing(self): + service = WorkflowAppService() + assert service.handle_trigger_metadata("tenant-1", None) == {} + + def test_enriches_plugin_icons(self): + service = WorkflowAppService() + meta = { + "type": AppTriggerType.TRIGGER_PLUGIN.value, + "icon_filename": "light.png", + "icon_dark_filename": "dark.png", + } + with patch( + "services.workflow_app_service.PluginService.get_plugin_icon_url", + side_effect=["https://cdn/light.png", "https://cdn/dark.png"], + ) as mock_icon: + result = service.handle_trigger_metadata("tenant-1", json.dumps(meta)) + + assert result["icon"] == "https://cdn/light.png" + assert result["icon_dark"] == "https://cdn/dark.png" + assert mock_icon.call_count == 2 + + def test_non_plugin_metadata_without_icon_lookup(self): + service = WorkflowAppService() + meta = {"type": AppTriggerType.TRIGGER_WEBHOOK.value} + with patch("services.workflow_app_service.PluginService.get_plugin_icon_url") as mock_icon: + result = service.handle_trigger_metadata("tenant-1", json.dumps(meta)) + + assert result["type"] == AppTriggerType.TRIGGER_WEBHOOK.value + mock_icon.assert_not_called() + + +class TestSafeJsonLoads: + @pytest.mark.parametrize( + ("value", "expected"), + [ + (None, None), + ("", None), + ('{"k":"v"}', {"k": "v"}), + ("not-json", None), + ({"raw": True}, {"raw": True}), + ], + ) + def test_handles_various_inputs(self, value, expected): + assert WorkflowAppService._safe_json_loads(value) == expected + + +class TestSafeParseUuid: + def test_returns_none_for_short_or_invalid_values(self): + service = WorkflowAppService() + assert service._safe_parse_uuid("short") is None + assert service._safe_parse_uuid("x" * 40) is None + + def test_returns_uuid_for_valid_string(self): + service = WorkflowAppService() + raw = str(uuid.uuid4()) + result = service._safe_parse_uuid(raw) + assert result is not None + assert str(result) == raw diff --git a/api/tests/test_containers_integration_tests/services/test_workflow_draft_variable_service.py b/api/tests/test_containers_integration_tests/services/test_workflow_draft_variable_service.py index 572cf72fa0..86cf2327c7 100644 --- a/api/tests/test_containers_integration_tests/services/test_workflow_draft_variable_service.py +++ b/api/tests/test_containers_integration_tests/services/test_workflow_draft_variable_service.py @@ -2,8 +2,8 @@ import pytest from faker import Faker from sqlalchemy.orm import Session -from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID -from dify_graph.variables.segments import StringSegment +from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID +from graphon.variables.segments import StringSegment from models import App, Workflow from models.enums import DraftVariableType from models.workflow import WorkflowDraftVariable @@ -482,7 +482,7 @@ class TestWorkflowDraftVariableService: fake = Faker() app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake) workflow = self._create_test_workflow(db_session_with_containers, app, fake=fake) - from dify_graph.variables.variables import StringVariable + from graphon.variables.variables import StringVariable conv_var = StringVariable( id=fake.uuid4(), @@ -734,7 +734,7 @@ class TestWorkflowDraftVariableService: fake = Faker() app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake) workflow = self._create_test_workflow(db_session_with_containers, app, fake=fake) - from dify_graph.variables.variables import StringVariable + from graphon.variables.variables import StringVariable conv_var1 = StringVariable( id=fake.uuid4(), diff --git a/api/tests/test_containers_integration_tests/services/test_workflow_run_service.py b/api/tests/test_containers_integration_tests/services/test_workflow_run_service.py index 731770e01a..d02a078281 100644 --- a/api/tests/test_containers_integration_tests/services/test_workflow_run_service.py +++ b/api/tests/test_containers_integration_tests/services/test_workflow_run_service.py @@ -27,7 +27,7 @@ class TestWorkflowRunService: with ( patch("services.app_service.FeatureService") as mock_feature_service, patch("services.app_service.EnterpriseService") as mock_enterprise_service, - patch("services.app_service.ModelManager") as mock_model_manager, + patch("services.app_service.ModelManager.for_tenant") as mock_model_manager, patch("services.account_service.FeatureService") as mock_account_feature_service, ): # Setup default mock returns for app service diff --git a/api/tests/test_containers_integration_tests/services/test_workflow_service.py b/api/tests/test_containers_integration_tests/services/test_workflow_service.py index a5fe052206..ee7b68e6aa 100644 --- a/api/tests/test_containers_integration_tests/services/test_workflow_service.py +++ b/api/tests/test_containers_integration_tests/services/test_workflow_service.py @@ -1503,10 +1503,10 @@ class TestWorkflowService: import uuid from datetime import datetime - from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus - from dify_graph.graph_events import NodeRunSucceededEvent - from dify_graph.node_events import NodeRunResult - from dify_graph.nodes.base.node import Node + from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus + from graphon.graph_events import NodeRunSucceededEvent + from graphon.node_events import NodeRunResult + from graphon.nodes.base.node import Node # Create mock node mock_node = MagicMock(spec=Node) @@ -1548,12 +1548,12 @@ class TestWorkflowService: # Assert assert result is not None assert result.node_id == node_id - from dify_graph.enums import BuiltinNodeTypes + from graphon.enums import BuiltinNodeTypes assert result.node_type == BuiltinNodeTypes.START # Should match the mock node type assert result.title == "Test Node" # Import the enum for comparison - from dify_graph.enums import WorkflowNodeExecutionStatus + from graphon.enums import WorkflowNodeExecutionStatus assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED assert result.inputs is not None @@ -1578,10 +1578,10 @@ class TestWorkflowService: import uuid from datetime import datetime - from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus - from dify_graph.graph_events import NodeRunFailedEvent - from dify_graph.node_events import NodeRunResult - from dify_graph.nodes.base.node import Node + from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus + from graphon.graph_events import NodeRunFailedEvent + from graphon.node_events import NodeRunResult + from graphon.nodes.base.node import Node # Create mock node mock_node = MagicMock(spec=Node) @@ -1623,7 +1623,7 @@ class TestWorkflowService: assert result is not None assert result.node_id == node_id # Import the enum for comparison - from dify_graph.enums import WorkflowNodeExecutionStatus + from graphon.enums import WorkflowNodeExecutionStatus assert result.status == WorkflowNodeExecutionStatus.FAILED assert result.error is not None @@ -1647,10 +1647,10 @@ class TestWorkflowService: import uuid from datetime import datetime - from dify_graph.enums import BuiltinNodeTypes, ErrorStrategy, WorkflowNodeExecutionStatus - from dify_graph.graph_events import NodeRunFailedEvent - from dify_graph.node_events import NodeRunResult - from dify_graph.nodes.base.node import Node + from graphon.enums import BuiltinNodeTypes, ErrorStrategy, WorkflowNodeExecutionStatus + from graphon.graph_events import NodeRunFailedEvent + from graphon.node_events import NodeRunResult + from graphon.nodes.base.node import Node # Create mock node with continue_on_error mock_node = MagicMock(spec=Node) @@ -1693,7 +1693,7 @@ class TestWorkflowService: assert result is not None assert result.node_id == node_id # Import the enum for comparison - from dify_graph.enums import WorkflowNodeExecutionStatus + from graphon.enums import WorkflowNodeExecutionStatus assert result.status == WorkflowNodeExecutionStatus.EXCEPTION # Should be EXCEPTION, not FAILED assert result.outputs is not None diff --git a/api/tests/test_containers_integration_tests/services/tools/test_api_tools_manage_service.py b/api/tests/test_containers_integration_tests/services/tools/test_api_tools_manage_service.py index bffdca623a..d3e765055a 100644 --- a/api/tests/test_containers_integration_tests/services/tools/test_api_tools_manage_service.py +++ b/api/tests/test_containers_integration_tests/services/tools/test_api_tools_manage_service.py @@ -536,3 +536,151 @@ class TestApiToolManageService: # Verify mock interactions mock_external_service_dependencies["encrypter"].assert_called_once() mock_external_service_dependencies["provider_controller"].from_db.assert_called_once() + + def test_delete_api_tool_provider_success( + self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies + ): + """Test successful deletion of an API tool provider.""" + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + schema = self._create_test_openapi_schema() + provider_name = fake.unique.word() + + ApiToolManageService.create_api_tool_provider( + user_id=account.id, + tenant_id=tenant.id, + provider_name=provider_name, + icon={"content": "🔧", "background": "#FFF"}, + credentials={"auth_type": "none"}, + schema_type=ApiProviderSchemaType.OPENAPI, + schema=schema, + privacy_policy="", + custom_disclaimer="", + labels=[], + ) + + provider = ( + db_session_with_containers.query(ApiToolProvider) + .filter(ApiToolProvider.tenant_id == tenant.id, ApiToolProvider.name == provider_name) + .first() + ) + assert provider is not None + + result = ApiToolManageService.delete_api_tool_provider(account.id, tenant.id, provider_name) + + assert result == {"result": "success"} + deleted = ( + db_session_with_containers.query(ApiToolProvider) + .filter(ApiToolProvider.tenant_id == tenant.id, ApiToolProvider.name == provider_name) + .first() + ) + assert deleted is None + + def test_delete_api_tool_provider_not_found( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """Test deletion raises ValueError when provider not found.""" + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + with pytest.raises(ValueError, match="you have not added provider"): + ApiToolManageService.delete_api_tool_provider(account.id, tenant.id, "nonexistent") + + def test_update_api_tool_provider_not_found( + self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies + ): + """Test update raises ValueError when original provider not found.""" + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + with pytest.raises(ValueError, match="does not exists"): + ApiToolManageService.update_api_tool_provider( + user_id=account.id, + tenant_id=tenant.id, + provider_name="new-name", + original_provider="nonexistent", + icon={}, + credentials={"auth_type": "none"}, + _schema_type=ApiProviderSchemaType.OPENAPI, + schema=self._create_test_openapi_schema(), + privacy_policy=None, + custom_disclaimer="", + labels=[], + ) + + def test_update_api_tool_provider_missing_auth_type( + self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies + ): + """Test update raises ValueError when auth_type is missing from credentials.""" + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + schema = self._create_test_openapi_schema() + provider_name = fake.unique.word() + + ApiToolManageService.create_api_tool_provider( + user_id=account.id, + tenant_id=tenant.id, + provider_name=provider_name, + icon={"content": "🔧", "background": "#FFF"}, + credentials={"auth_type": "none"}, + schema_type=ApiProviderSchemaType.OPENAPI, + schema=schema, + privacy_policy="", + custom_disclaimer="", + labels=[], + ) + + with pytest.raises(ValueError, match="auth_type is required"): + ApiToolManageService.update_api_tool_provider( + user_id=account.id, + tenant_id=tenant.id, + provider_name=provider_name, + original_provider=provider_name, + icon={}, + credentials={}, + _schema_type=ApiProviderSchemaType.OPENAPI, + schema=schema, + privacy_policy=None, + custom_disclaimer="", + labels=[], + ) + + def test_list_api_tool_provider_tools_not_found( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """Test listing tools raises ValueError when provider not found.""" + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + with pytest.raises(ValueError, match="you have not added provider"): + ApiToolManageService.list_api_tool_provider_tools(account.id, tenant.id, "nonexistent") + + def test_test_api_tool_preview_invalid_schema_type( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """Test preview raises ValueError for invalid schema type.""" + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + with pytest.raises(ValueError, match="invalid schema type"): + ApiToolManageService.test_api_tool_preview( + tenant_id=tenant.id, + provider_name="provider-a", + tool_name="tool-a", + credentials={"auth_type": "none"}, + parameters={}, + schema_type="bad-schema-type", + schema="schema", + ) diff --git a/api/tests/test_containers_integration_tests/services/tools/test_tools_transform_service.py b/api/tests/test_containers_integration_tests/services/tools/test_tools_transform_service.py index 0f38218c51..2dc50cc720 100644 --- a/api/tests/test_containers_integration_tests/services/tools/test_tools_transform_service.py +++ b/api/tests/test_containers_integration_tests/services/tools/test_tools_transform_service.py @@ -1,12 +1,24 @@ +from __future__ import annotations + from unittest.mock import Mock, patch import pytest from faker import Faker from sqlalchemy.orm import Session -from core.tools.entities.api_entities import ToolProviderApiEntity +from core.tools.__base.tool import Tool +from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity from core.tools.entities.common_entities import I18nObject -from core.tools.entities.tool_entities import ToolProviderType +from core.tools.entities.tool_entities import ( + ApiProviderSchemaType, + ToolDescription, + ToolEntity, + ToolIdentity, + ToolParameter, + ToolProviderEntity, + ToolProviderIdentity, + ToolProviderType, +) from models.tools import ApiToolProvider, BuiltinToolProvider, MCPToolProvider, WorkflowToolProvider from services.plugin.plugin_service import PluginService from services.tools.tools_transform_service import ToolTransformService @@ -52,7 +64,7 @@ class TestToolTransformService: user_id="test_user_id", credentials_str='{"auth_type": "api_key_header", "api_key": "test_key"}', schema="{}", - schema_type_str="openapi", + schema_type_str=ApiProviderSchemaType.OPENAPI, tools_str="[]", ) elif provider_type == "builtin": @@ -659,7 +671,7 @@ class TestToolTransformService: user_id=fake.uuid4(), credentials_str='{"auth_type": "api_key_header", "api_key": "test_key"}', schema="{}", - schema_type_str="openapi", + schema_type_str=ApiProviderSchemaType.OPENAPI, tools_str="[]", ) @@ -695,7 +707,7 @@ class TestToolTransformService: user_id=fake.uuid4(), credentials_str='{"auth_type": "api_key_query", "api_key": "test_key"}', schema="{}", - schema_type_str="openapi", + schema_type_str=ApiProviderSchemaType.OPENAPI, tools_str="[]", ) @@ -731,7 +743,7 @@ class TestToolTransformService: user_id=fake.uuid4(), credentials_str='{"auth_type": "api_key", "api_key": "test_key"}', schema="{}", - schema_type_str="openapi", + schema_type_str=ApiProviderSchemaType.OPENAPI, tools_str="[]", ) @@ -786,3 +798,192 @@ class TestToolTransformService: assert result is not None assert result == mock_controller mock_from_db.assert_called_once_with(provider) + + +def _mock_tool(*, base_params, runtime_params): + """Helper to build a Mock tool with real entity objects. + + Tool is abstract and requires runtime behaviour (fork_tool_runtime, + get_runtime_parameters), so it stays as a Mock. Everything else uses + real Pydantic instances. + """ + entity = ToolEntity( + identity=ToolIdentity( + author="test_author", + name="test_tool", + label=I18nObject(en_US="Test Tool"), + provider="test_provider", + ), + parameters=base_params or [], + description=ToolDescription( + human=I18nObject(en_US="Test description"), + llm="Test description for LLM", + ), + output_schema={}, + ) + mock_tool = Mock(spec=Tool) + mock_tool.entity = entity + mock_tool.get_runtime_parameters.return_value = runtime_params + mock_tool.fork_tool_runtime.return_value = mock_tool + return mock_tool + + +def _param(name, *, form=ToolParameter.ToolParameterForm.FORM, label=None): + return ToolParameter( + name=name, + label=I18nObject(en_US=label or name), + human_description=I18nObject(en_US=name), + type=ToolParameter.ToolParameterType.STRING, + form=form, + ) + + +class TestConvertToolEntityToApiEntity: + """Tests for ToolTransformService.convert_tool_entity_to_api_entity.""" + + def test_parameter_override(self): + base = [_param("param1", label="Base 1"), _param("param2", label="Base 2")] + runtime = [_param("param1", label="Runtime 1")] + tool = _mock_tool(base_params=base, runtime_params=runtime) + + result = ToolTransformService.convert_tool_entity_to_api_entity(tool, "t", None) + + assert isinstance(result, ToolApiEntity) + assert len(result.parameters) == 2 + assert next(p for p in result.parameters if p.name == "param1").label.en_US == "Runtime 1" + assert next(p for p in result.parameters if p.name == "param2").label.en_US == "Base 2" + + def test_additional_runtime_parameters(self): + base = [_param("param1", label="Base 1")] + runtime = [_param("param1", label="Runtime 1"), _param("runtime_only", label="Runtime Only")] + tool = _mock_tool(base_params=base, runtime_params=runtime) + + result = ToolTransformService.convert_tool_entity_to_api_entity(tool, "t", None) + + assert len(result.parameters) == 2 + names = [p.name for p in result.parameters] + assert "param1" in names + assert "runtime_only" in names + + def test_non_form_runtime_parameters_excluded(self): + base = [_param("param1")] + runtime = [ + _param("param1", label="Runtime 1"), + _param("llm_param", form=ToolParameter.ToolParameterForm.LLM), + ] + tool = _mock_tool(base_params=base, runtime_params=runtime) + + result = ToolTransformService.convert_tool_entity_to_api_entity(tool, "t", None) + + assert len(result.parameters) == 1 + assert result.parameters[0].name == "param1" + + def test_empty_parameters(self): + tool = _mock_tool(base_params=[], runtime_params=[]) + + result = ToolTransformService.convert_tool_entity_to_api_entity(tool, "t", None) + + assert isinstance(result, ToolApiEntity) + assert len(result.parameters) == 0 + + def test_none_parameters(self): + tool = _mock_tool(base_params=None, runtime_params=[]) + + result = ToolTransformService.convert_tool_entity_to_api_entity(tool, "t", None) + + assert isinstance(result, ToolApiEntity) + assert len(result.parameters) == 0 + + def test_parameter_order_preserved(self): + base = [_param("p1", label="B1"), _param("p2", label="B2"), _param("p3", label="B3")] + runtime = [_param("p2", label="R2"), _param("p4", label="R4")] + tool = _mock_tool(base_params=base, runtime_params=runtime) + + result = ToolTransformService.convert_tool_entity_to_api_entity(tool, "t", None) + + assert [p.name for p in result.parameters] == ["p1", "p2", "p3", "p4"] + assert result.parameters[1].label.en_US == "R2" + + +class TestWorkflowProviderToUserProvider: + """Tests for ToolTransformService.workflow_provider_to_user_provider.""" + + @staticmethod + def _make_controller(provider_id="provider_123", **identity_overrides): + from core.tools.workflow_as_tool.provider import WorkflowToolProviderController + + defaults = { + "author": "test_author", + "name": "test_workflow_tool", + "description": I18nObject(en_US="Test description"), + "icon": '{"type": "emoji", "content": "🔧"}', + "icon_dark": None, + "label": I18nObject(en_US="Test Workflow Tool"), + } + defaults.update(identity_overrides) + identity = ToolProviderIdentity(**defaults) + entity = ToolProviderEntity(identity=identity) + return WorkflowToolProviderController(entity=entity, provider_id=provider_id) + + def test_with_workflow_app_id(self): + ctrl = self._make_controller() + + result = ToolTransformService.workflow_provider_to_user_provider( + provider_controller=ctrl, + labels=["l1", "l2"], + workflow_app_id="app_123", + ) + + assert isinstance(result, ToolProviderApiEntity) + assert result.id == "provider_123" + assert result.type == ToolProviderType.WORKFLOW + assert result.workflow_app_id == "app_123" + assert result.labels == ["l1", "l2"] + assert result.is_team_authorization is True + + def test_without_workflow_app_id(self): + ctrl = self._make_controller() + + result = ToolTransformService.workflow_provider_to_user_provider( + provider_controller=ctrl, + labels=["l1"], + ) + + assert result.workflow_app_id is None + + def test_workflow_app_id_none_explicit(self): + ctrl = self._make_controller() + + result = ToolTransformService.workflow_provider_to_user_provider( + provider_controller=ctrl, + labels=None, + workflow_app_id=None, + ) + + assert result.workflow_app_id is None + assert result.labels == [] + + def test_preserves_other_fields(self): + ctrl = self._make_controller( + "provider_456", + author="another_author", + name="another_workflow_tool", + description=I18nObject(en_US="Another desc", zh_Hans="Another desc"), + icon='{"type": "emoji", "content": "⚙️"}', + icon_dark='{"type": "emoji", "content": "🔧"}', + label=I18nObject(en_US="Another Tool", zh_Hans="Another Tool"), + ) + + result = ToolTransformService.workflow_provider_to_user_provider( + provider_controller=ctrl, + labels=["automation"], + workflow_app_id="app_456", + ) + + assert result.id == "provider_456" + assert result.author == "another_author" + assert result.name == "another_workflow_tool" + assert result.type == ToolProviderType.WORKFLOW + assert result.workflow_app_id == "app_456" + assert result.is_team_authorization is True + assert result.allow_delete is True diff --git a/api/tests/test_containers_integration_tests/services/tools/test_workflow_tools_manage_service.py b/api/tests/test_containers_integration_tests/services/tools/test_workflow_tools_manage_service.py index 34906a4e54..21a1975879 100644 --- a/api/tests/test_containers_integration_tests/services/tools/test_workflow_tools_manage_service.py +++ b/api/tests/test_containers_integration_tests/services/tools/test_workflow_tools_manage_service.py @@ -25,7 +25,7 @@ class TestWorkflowToolManageService: with ( patch("services.app_service.FeatureService") as mock_feature_service, patch("services.app_service.EnterpriseService") as mock_enterprise_service, - patch("services.app_service.ModelManager") as mock_model_manager, + patch("services.app_service.ModelManager.for_tenant") as mock_model_manager, patch("services.account_service.FeatureService") as mock_account_feature_service, patch( "services.tools.workflow_tools_manage_service.WorkflowToolProviderController" @@ -1043,3 +1043,112 @@ class TestWorkflowToolManageService: # After the fix, this should always be 0 # For now, we document that the record may exist, demonstrating the bug # assert tool_count == 0 # Expected after fix + + def test_delete_workflow_tool_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """Test successful deletion of a workflow tool.""" + fake = Faker() + app, account, workflow = self._create_test_app_and_account( + db_session_with_containers, mock_external_service_dependencies + ) + tool_name = fake.unique.word() + + WorkflowToolManageService.create_workflow_tool( + user_id=account.id, + tenant_id=account.current_tenant.id, + workflow_app_id=app.id, + name=tool_name, + label=fake.word(), + icon={"type": "emoji", "emoji": "🔧"}, + description=fake.text(max_nb_chars=200), + parameters=self._create_test_workflow_tool_parameters(), + ) + + tool = ( + db_session_with_containers.query(WorkflowToolProvider) + .where(WorkflowToolProvider.tenant_id == account.current_tenant.id, WorkflowToolProvider.name == tool_name) + .first() + ) + assert tool is not None + + result = WorkflowToolManageService.delete_workflow_tool(account.id, account.current_tenant.id, tool.id) + + assert result == {"result": "success"} + deleted = ( + db_session_with_containers.query(WorkflowToolProvider).where(WorkflowToolProvider.id == tool.id).first() + ) + assert deleted is None + + def test_list_tenant_workflow_tools_empty( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """Test listing workflow tools when none exist returns empty list.""" + fake = Faker() + app, account, workflow = self._create_test_app_and_account( + db_session_with_containers, mock_external_service_dependencies + ) + + result = WorkflowToolManageService.list_tenant_workflow_tools(account.id, account.current_tenant.id) + + assert result == [] + + def test_get_workflow_tool_by_tool_id_not_found( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """Test that get_workflow_tool_by_tool_id raises ValueError when tool not found.""" + fake = Faker() + app, account, workflow = self._create_test_app_and_account( + db_session_with_containers, mock_external_service_dependencies + ) + + with pytest.raises(ValueError, match="Tool not found"): + WorkflowToolManageService.get_workflow_tool_by_tool_id(account.id, account.current_tenant.id, fake.uuid4()) + + def test_get_workflow_tool_by_app_id_not_found( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """Test that get_workflow_tool_by_app_id raises ValueError when tool not found.""" + fake = Faker() + app, account, workflow = self._create_test_app_and_account( + db_session_with_containers, mock_external_service_dependencies + ) + + with pytest.raises(ValueError, match="Tool not found"): + WorkflowToolManageService.get_workflow_tool_by_app_id(account.id, account.current_tenant.id, fake.uuid4()) + + def test_list_single_workflow_tools_not_found( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """Test that list_single_workflow_tools raises ValueError when tool not found.""" + fake = Faker() + app, account, workflow = self._create_test_app_and_account( + db_session_with_containers, mock_external_service_dependencies + ) + + with pytest.raises(ValueError, match="not found"): + WorkflowToolManageService.list_single_workflow_tools(account.id, account.current_tenant.id, fake.uuid4()) + + def test_create_workflow_tool_with_labels( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """Test that labels are forwarded to ToolLabelManager when provided.""" + fake = Faker() + app, account, workflow = self._create_test_app_and_account( + db_session_with_containers, mock_external_service_dependencies + ) + + result = WorkflowToolManageService.create_workflow_tool( + user_id=account.id, + tenant_id=account.current_tenant.id, + workflow_app_id=app.id, + name=fake.unique.word(), + label=fake.word(), + icon={"type": "emoji", "emoji": "🔧"}, + description=fake.text(max_nb_chars=200), + parameters=self._create_test_workflow_tool_parameters(), + labels=["label-1", "label-2"], + ) + + assert result == {"result": "success"} + mock_external_service_dependencies["tool_label_manager"].update_tool_labels.assert_called_once() diff --git a/api/tests/test_containers_integration_tests/services/workflow/test_workflow_converter.py b/api/tests/test_containers_integration_tests/services/workflow/test_workflow_converter.py index c3fe6a2950..ce5c2bd162 100644 --- a/api/tests/test_containers_integration_tests/services/workflow/test_workflow_converter.py +++ b/api/tests/test_containers_integration_tests/services/workflow/test_workflow_converter.py @@ -1,11 +1,16 @@ +from __future__ import annotations + import json -from unittest.mock import patch +from unittest.mock import MagicMock, patch import pytest from faker import Faker from sqlalchemy.orm import Session from core.app.app_config.entities import ( + AdvancedChatMessageEntity, + AdvancedChatPromptTemplateEntity, + AdvancedCompletionPromptTemplateEntity, DatasetEntity, DatasetRetrieveConfigEntity, ExternalDataVariableEntity, @@ -13,10 +18,11 @@ from core.app.app_config.entities import ( PromptTemplateEntity, ) from core.prompt.utils.prompt_template_parser import PromptTemplateParser -from dify_graph.model_runtime.entities.llm_entities import LLMMode -from dify_graph.variables.input_entities import VariableEntity, VariableEntityType +from graphon.model_runtime.entities.llm_entities import LLMMode +from graphon.model_runtime.entities.message_entities import PromptMessageRole +from graphon.variables.input_entities import VariableEntity, VariableEntityType from models import Account, Tenant -from models.api_based_extension import APIBasedExtension +from models.api_based_extension import APIBasedExtension, APIBasedExtensionPoint from models.model import App, AppMode, AppModelConfig from models.workflow import Workflow from services.workflow.workflow_converter import WorkflowConverter @@ -548,3 +554,198 @@ class TestWorkflowConverter: # Verify single retrieval config is None for multiple strategy assert node["data"]["single_retrieval_config"] is None + + +@pytest.fixture +def default_variables(): + return [ + VariableEntity(variable="text_input", label="text-input", type=VariableEntityType.TEXT_INPUT), + VariableEntity(variable="paragraph", label="paragraph", type=VariableEntityType.PARAGRAPH), + VariableEntity(variable="select", label="select", type=VariableEntityType.SELECT), + ] + + +class TestConvertToHttpRequestNodeVariants: + """Tests for chatbot vs workflow differences in HTTP request node conversion.""" + + @staticmethod + def _setup(app_mode, default_variables): + app_model = App( + tenant_id="tenant_id", + mode=app_mode, + name="test", + icon_type="emoji", + icon="🤖", + icon_background="#FFFFFF", + ) + + ext = APIBasedExtension(tenant_id="tenant_id", name="api-1", api_key="enc", api_endpoint="https://dify.ai") + ext.id = "ext_id" + + converter = WorkflowConverter() + converter._get_api_based_extension = MagicMock(return_value=ext) + + from core.helper import encrypter + + encrypter.decrypt_token = MagicMock(return_value="api_key") + + ext_vars = [ + ExternalDataVariableEntity( + variable="external_variable", type="api", config={"api_based_extension_id": "ext_id"} + ) + ] + nodes, _ = converter._convert_to_http_request_node( + app_model=app_model, + variables=default_variables, + external_data_variables=ext_vars, + ) + return nodes + + def test_chatbot_query_uses_sys_query(self, default_variables): + nodes = self._setup(AppMode.CHAT, default_variables) + + body = json.loads(nodes[0]["data"]["body"]["data"]) + assert body["params"]["query"] == "{{#sys.query#}}" + assert body["point"] == APIBasedExtensionPoint.APP_EXTERNAL_DATA_TOOL_QUERY + assert nodes[1]["data"]["type"] == "code" + + def test_workflow_query_is_empty(self, default_variables): + nodes = self._setup(AppMode.WORKFLOW, default_variables) + + body = json.loads(nodes[0]["data"]["body"]["data"]) + assert body["params"]["query"] == "" + + +class TestConvertToKnowledgeRetrievalNodeVariants: + """Tests for chatbot vs workflow differences in knowledge retrieval node.""" + + @staticmethod + def _dataset_config(query_variable=None): + return DatasetEntity( + dataset_ids=["ds1", "ds2"], + retrieve_config=DatasetRetrieveConfigEntity( + query_variable=query_variable, + retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE, + top_k=5, + score_threshold=0.8, + reranking_model={"reranking_provider_name": "cohere", "reranking_model_name": "rerank-english-v2.0"}, + reranking_enabled=True, + ), + ) + + @staticmethod + def _model_config(): + return ModelConfigEntity(provider="openai", model="gpt-4", mode="chat", parameters={}, stop=[]) + + def test_chatbot_uses_sys_query(self): + node = WorkflowConverter()._convert_to_knowledge_retrieval_node( + new_app_mode=AppMode.ADVANCED_CHAT, + dataset_config=self._dataset_config(), + model_config=self._model_config(), + ) + assert node["data"]["query_variable_selector"] == ["sys", "query"] + + def test_workflow_uses_start_variable(self): + node = WorkflowConverter()._convert_to_knowledge_retrieval_node( + new_app_mode=AppMode.WORKFLOW, + dataset_config=self._dataset_config(query_variable="query"), + model_config=self._model_config(), + ) + assert node["data"]["query_variable_selector"] == ["start", "query"] + + +class TestConvertToLlmNode: + """Tests for LLM node conversion across model modes and prompt types.""" + + @staticmethod + def _model_config(model, mode): + return ModelConfigEntity( + provider="openai", + model=model, + mode=mode.value, + parameters={}, + stop=[], + ) + + @staticmethod + def _graph(default_variables): + start = WorkflowConverter()._convert_to_start_node(default_variables) + return {"nodes": [start], "edges": []} + + def test_simple_chat_model(self, default_variables): + prompt = PromptTemplateEntity( + prompt_type=PromptTemplateEntity.PromptType.SIMPLE, + simple_prompt_template="You are helpful {{text_input}}, {{paragraph}}, {{select}}.", + ) + node = WorkflowConverter()._convert_to_llm_node( + original_app_mode=AppMode.CHAT, + new_app_mode=AppMode.ADVANCED_CHAT, + model_config=self._model_config("gpt-4", LLMMode.CHAT), + graph=self._graph(default_variables), + prompt_template=prompt, + ) + assert node["data"]["type"] == "llm" + assert node["data"]["model"]["mode"] == LLMMode.CHAT.value + assert node["data"]["context"]["enabled"] is False + expected = "You are helpful {{#start.text_input#}}, {{#start.paragraph#}}, {{#start.select#}}.\n" + assert node["data"]["prompt_template"][0]["text"] == expected + + def test_simple_completion_model(self, default_variables): + prompt = PromptTemplateEntity( + prompt_type=PromptTemplateEntity.PromptType.SIMPLE, + simple_prompt_template="You are helpful {{text_input}}, {{paragraph}}, {{select}}.", + ) + node = WorkflowConverter()._convert_to_llm_node( + original_app_mode=AppMode.CHAT, + new_app_mode=AppMode.ADVANCED_CHAT, + model_config=self._model_config("gpt-3.5-turbo-instruct", LLMMode.COMPLETION), + graph=self._graph(default_variables), + prompt_template=prompt, + ) + assert node["data"]["model"]["mode"] == LLMMode.COMPLETION.value + expected = "You are helpful {{#start.text_input#}}, {{#start.paragraph#}}, {{#start.select#}}.\n" + assert node["data"]["prompt_template"]["text"] == expected + + def test_advanced_chat_model(self, default_variables): + prompt = PromptTemplateEntity( + prompt_type=PromptTemplateEntity.PromptType.ADVANCED, + advanced_chat_prompt_template=AdvancedChatPromptTemplateEntity( + messages=[ + AdvancedChatMessageEntity( + text="You are helpful named {{name}}.\n\nContext:\n{{#context#}}", + role=PromptMessageRole.SYSTEM, + ), + AdvancedChatMessageEntity(text="Hi.", role=PromptMessageRole.USER), + AdvancedChatMessageEntity(text="Hello!", role=PromptMessageRole.ASSISTANT), + ] + ), + ) + node = WorkflowConverter()._convert_to_llm_node( + original_app_mode=AppMode.CHAT, + new_app_mode=AppMode.ADVANCED_CHAT, + model_config=self._model_config("gpt-4", LLMMode.CHAT), + graph=self._graph(default_variables), + prompt_template=prompt, + ) + assert isinstance(node["data"]["prompt_template"], list) + assert len(node["data"]["prompt_template"]) == 3 + + def test_advanced_completion_model(self, default_variables): + prompt = PromptTemplateEntity( + prompt_type=PromptTemplateEntity.PromptType.ADVANCED, + advanced_completion_prompt_template=AdvancedCompletionPromptTemplateEntity( + prompt="You are helpful named {{name}}.\n\nContext:\n{{#context#}}\n\nHuman: hi\nAssistant: ", + role_prefix=AdvancedCompletionPromptTemplateEntity.RolePrefixEntity( + user="Human", assistant="Assistant" + ), + ), + ) + node = WorkflowConverter()._convert_to_llm_node( + original_app_mode=AppMode.CHAT, + new_app_mode=AppMode.ADVANCED_CHAT, + model_config=self._model_config("gpt-3.5-turbo-instruct", LLMMode.COMPLETION), + graph=self._graph(default_variables), + prompt_template=prompt, + ) + assert isinstance(node["data"]["prompt_template"], dict) + assert "text" in node["data"]["prompt_template"] diff --git a/api/tests/test_containers_integration_tests/services/workflow/test_workflow_node_execution_service_repository.py b/api/tests/test_containers_integration_tests/services/workflow/test_workflow_node_execution_service_repository.py index af9e8d0b2c..4dab895135 100644 --- a/api/tests/test_containers_integration_tests/services/workflow/test_workflow_node_execution_service_repository.py +++ b/api/tests/test_containers_integration_tests/services/workflow/test_workflow_node_execution_service_repository.py @@ -4,7 +4,7 @@ from uuid import uuid4 from sqlalchemy import Engine, select from sqlalchemy.orm import Session, sessionmaker -from dify_graph.enums import WorkflowNodeExecutionStatus +from graphon.enums import WorkflowNodeExecutionStatus from libs.datetime_utils import naive_utc_now from models.enums import CreatorUserRole from models.workflow import WorkflowNodeExecutionModel diff --git a/api/tests/test_containers_integration_tests/tasks/test_add_document_to_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_add_document_to_index_task.py index 94173c34bf..4b04c1accb 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_add_document_to_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_add_document_to_index_task.py @@ -4,7 +4,7 @@ import pytest from faker import Faker from sqlalchemy.orm import Session -from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from extensions.ext_redis import redis_client from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, DatasetAutoDisableLog, Document, DocumentSegment @@ -81,7 +81,7 @@ class TestAddDocumentToIndexTask: name=fake.company(), description=fake.text(max_nb_chars=100), data_source_type=DataSourceType.UPLOAD_FILE, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, created_by=account.id, ) db_session_with_containers.add(dataset) diff --git a/api/tests/test_containers_integration_tests/tasks/test_batch_clean_document_task.py b/api/tests/test_containers_integration_tests/tasks/test_batch_clean_document_task.py index 210d9eb39e..6cbbe43137 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_batch_clean_document_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_batch_clean_document_task.py @@ -13,6 +13,7 @@ import pytest from faker import Faker from sqlalchemy.orm import Session +from core.rag.index_processor.constant.index_type import IndexStructureType from extensions.storage.storage_type import StorageType from libs.datetime_utils import naive_utc_now from models import Account, Tenant, TenantAccountJoin, TenantAccountRole @@ -152,7 +153,7 @@ class TestBatchCleanDocumentTask: created_from=DocumentCreatedFrom.WEB, created_by=account.id, indexing_status=IndexingStatus.COMPLETED, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, ) db_session_with_containers.add(document) @@ -392,7 +393,12 @@ class TestBatchCleanDocumentTask: db_session_with_containers.commit() # Execute the task with non-existent dataset - batch_clean_document_task(document_ids=[document_id], dataset_id=dataset_id, doc_form="text_model", file_ids=[]) + batch_clean_document_task( + document_ids=[document_id], + dataset_id=dataset_id, + doc_form=IndexStructureType.PARAGRAPH_INDEX, + file_ids=[], + ) # Verify that no index processing occurred mock_external_service_dependencies["index_processor"].clean.assert_not_called() @@ -525,7 +531,11 @@ class TestBatchCleanDocumentTask: account = self._create_test_account(db_session_with_containers) # Test different doc_form types - doc_forms = ["text_model", "qa_model", "hierarchical_model"] + doc_forms = [ + IndexStructureType.PARAGRAPH_INDEX, + IndexStructureType.QA_INDEX, + IndexStructureType.PARENT_CHILD_INDEX, + ] for doc_form in doc_forms: dataset = self._create_test_dataset(db_session_with_containers, account) diff --git a/api/tests/test_containers_integration_tests/tasks/test_batch_create_segment_to_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_batch_create_segment_to_index_task.py index 202ccb0098..f9ae33b32f 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_batch_create_segment_to_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_batch_create_segment_to_index_task.py @@ -19,6 +19,7 @@ import pytest from faker import Faker from sqlalchemy.orm import Session +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from extensions.storage.storage_type import StorageType from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, Document, DocumentSegment @@ -53,7 +54,10 @@ class TestBatchCreateSegmentToIndexTask: """Mock setup for external service dependencies.""" with ( patch("tasks.batch_create_segment_to_index_task.storage", autospec=True) as mock_storage, - patch("tasks.batch_create_segment_to_index_task.ModelManager", autospec=True) as mock_model_manager, + patch( + "tasks.batch_create_segment_to_index_task.ModelManager.for_tenant", + autospec=True, + ) as mock_model_manager, patch("tasks.batch_create_segment_to_index_task.VectorService", autospec=True) as mock_vector_service, ): # Setup default mock returns @@ -141,7 +145,7 @@ class TestBatchCreateSegmentToIndexTask: name=fake.company(), description=fake.text(), data_source_type=DataSourceType.UPLOAD_FILE, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, embedding_model="text-embedding-ada-002", embedding_model_provider="openai", created_by=account.id, @@ -179,7 +183,7 @@ class TestBatchCreateSegmentToIndexTask: indexing_status=IndexingStatus.COMPLETED, enabled=True, archived=False, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, word_count=0, ) @@ -221,17 +225,17 @@ class TestBatchCreateSegmentToIndexTask: return upload_file - def _create_test_csv_content(self, content_type="text_model"): + def _create_test_csv_content(self, content_type=IndexStructureType.PARAGRAPH_INDEX): """ Helper method to create test CSV content. Args: - content_type: Type of content to create ("text_model" or "qa_model") + content_type: Type of content to create (IndexStructureType.PARAGRAPH_INDEX or IndexStructureType.QA_INDEX) Returns: str: CSV content as string """ - if content_type == "qa_model": + if content_type == IndexStructureType.QA_INDEX: csv_content = "content,answer\n" csv_content += "This is the first segment content,This is the first answer\n" csv_content += "This is the second segment content,This is the second answer\n" @@ -264,7 +268,7 @@ class TestBatchCreateSegmentToIndexTask: upload_file = self._create_test_upload_file(db_session_with_containers, account, tenant) # Create CSV content - csv_content = self._create_test_csv_content("text_model") + csv_content = self._create_test_csv_content(IndexStructureType.PARAGRAPH_INDEX) # Mock storage to return our CSV content mock_storage = mock_external_service_dependencies["storage"] @@ -451,7 +455,7 @@ class TestBatchCreateSegmentToIndexTask: indexing_status=IndexingStatus.COMPLETED, enabled=False, # Document is disabled archived=False, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, word_count=0, ), # Archived document @@ -467,7 +471,7 @@ class TestBatchCreateSegmentToIndexTask: indexing_status=IndexingStatus.COMPLETED, enabled=True, archived=True, # Document is archived - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, word_count=0, ), # Document with incomplete indexing @@ -483,7 +487,7 @@ class TestBatchCreateSegmentToIndexTask: indexing_status=IndexingStatus.INDEXING, # Not completed enabled=True, archived=False, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, word_count=0, ), ] @@ -655,7 +659,7 @@ class TestBatchCreateSegmentToIndexTask: db_session_with_containers.commit() # Create CSV content - csv_content = self._create_test_csv_content("text_model") + csv_content = self._create_test_csv_content(IndexStructureType.PARAGRAPH_INDEX) # Mock storage to return our CSV content mock_storage = mock_external_service_dependencies["storage"] diff --git a/api/tests/test_containers_integration_tests/tasks/test_clean_dataset_task.py b/api/tests/test_containers_integration_tests/tasks/test_clean_dataset_task.py index 1cd698b870..1dd37fbc92 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_clean_dataset_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_clean_dataset_task.py @@ -18,6 +18,7 @@ import pytest from faker import Faker from sqlalchemy.orm import Session +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from extensions.storage.storage_type import StorageType from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import ( @@ -153,7 +154,7 @@ class TestCleanDatasetTask: tenant_id=tenant.id, name="test_dataset", description="Test dataset for cleanup testing", - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, index_struct='{"type": "paragraph"}', collection_binding_id=str(uuid.uuid4()), created_by=account.id, @@ -192,7 +193,7 @@ class TestCleanDatasetTask: indexing_status=IndexingStatus.COMPLETED, enabled=True, archived=False, - doc_form="paragraph_index", + doc_form=IndexStructureType.PARAGRAPH_INDEX, word_count=100, created_at=datetime.now(), updated_at=datetime.now(), @@ -869,7 +870,7 @@ class TestCleanDatasetTask: tenant_id=tenant.id, name=long_name, description=long_description, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, index_struct='{"type": "paragraph", "max_length": 10000}', collection_binding_id=str(uuid.uuid4()), created_by=account.id, diff --git a/api/tests/test_containers_integration_tests/tasks/test_clean_notion_document_task.py b/api/tests/test_containers_integration_tests/tasks/test_clean_notion_document_task.py index a2a190fd69..926c839c8b 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_clean_notion_document_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_clean_notion_document_task.py @@ -12,6 +12,7 @@ from unittest.mock import Mock, patch import pytest from faker import Faker +from core.rag.index_processor.constant.index_type import IndexStructureType from models.dataset import Dataset, Document, DocumentSegment from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus, SegmentStatus from services.account_service import AccountService, TenantService @@ -114,7 +115,7 @@ class TestCleanNotionDocumentTask: name=f"Notion Page {i}", created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="text_model", # Set doc_form to ensure dataset.doc_form works + doc_form=IndexStructureType.PARAGRAPH_INDEX, # Set doc_form to ensure dataset.doc_form works doc_language="en", indexing_status=IndexingStatus.COMPLETED, ) @@ -261,7 +262,7 @@ class TestCleanNotionDocumentTask: # Test different index types # Note: Only testing text_model to avoid dependency on external services - index_types = ["text_model"] + index_types = [IndexStructureType.PARAGRAPH_INDEX] for index_type in index_types: # Create dataset (doc_form will be set via document creation) diff --git a/api/tests/test_containers_integration_tests/tasks/test_create_segment_to_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_create_segment_to_index_task.py index 132f43c320..9f8e37fc9e 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_create_segment_to_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_create_segment_to_index_task.py @@ -12,6 +12,7 @@ from uuid import uuid4 import pytest from faker import Faker +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from extensions.ext_redis import redis_client from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, Document, DocumentSegment @@ -120,7 +121,7 @@ class TestCreateSegmentToIndexTask: description=fake.text(max_nb_chars=100), tenant_id=tenant_id, data_source_type=DataSourceType.UPLOAD_FILE, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, embedding_model_provider="openai", embedding_model="text-embedding-ada-002", created_by=account_id, @@ -141,7 +142,7 @@ class TestCreateSegmentToIndexTask: enabled=True, archived=False, indexing_status=IndexingStatus.COMPLETED, - doc_form="qa_model", + doc_form=IndexStructureType.QA_INDEX, ) db_session_with_containers.add(document) db_session_with_containers.commit() @@ -301,7 +302,7 @@ class TestCreateSegmentToIndexTask: enabled=True, archived=False, indexing_status=IndexingStatus.COMPLETED, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, ) db_session_with_containers.add(document) db_session_with_containers.commit() @@ -552,7 +553,11 @@ class TestCreateSegmentToIndexTask: - Processing completes successfully for different forms """ # Arrange: Test different doc_forms - doc_forms = ["qa_model", "text_model", "web_model"] + doc_forms = [ + IndexStructureType.QA_INDEX, + IndexStructureType.PARAGRAPH_INDEX, + IndexStructureType.PARAGRAPH_INDEX, + ] for doc_form in doc_forms: # Create fresh test data for each form diff --git a/api/tests/test_containers_integration_tests/tasks/test_dataset_indexing_task.py b/api/tests/test_containers_integration_tests/tasks/test_dataset_indexing_task.py index 67f9dc7011..13ea94348a 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_dataset_indexing_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_dataset_indexing_task.py @@ -8,6 +8,7 @@ import pytest from faker import Faker from core.indexing_runner import DocumentIsPausedError +from core.rag.index_processor.constant.index_type import IndexTechniqueType from enums.cloud_plan import CloudPlan from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, Document @@ -141,7 +142,7 @@ class TestDatasetIndexingTaskIntegration: name=fake.company(), description=fake.text(max_nb_chars=100), data_source_type=DataSourceType.UPLOAD_FILE, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, created_by=account.id, ) db_session_with_containers.add(dataset) diff --git a/api/tests/test_containers_integration_tests/tasks/test_deal_dataset_vector_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_deal_dataset_vector_index_task.py index e80b37ac1b..d457b59d58 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_deal_dataset_vector_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_deal_dataset_vector_index_task.py @@ -12,6 +12,7 @@ from unittest.mock import ANY, Mock, patch import pytest from faker import Faker +from core.rag.index_processor.constant.index_type import IndexStructureType from models.dataset import Dataset, Document, DocumentSegment from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus, SegmentStatus from services.account_service import AccountService, TenantService @@ -107,7 +108,7 @@ class TestDealDatasetVectorIndexTask: name="Document for doc_form", created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, doc_language="en", indexing_status=IndexingStatus.COMPLETED, enabled=True, @@ -167,7 +168,7 @@ class TestDealDatasetVectorIndexTask: name="Document for doc_form", created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, doc_language="en", indexing_status=IndexingStatus.COMPLETED, enabled=True, @@ -187,7 +188,7 @@ class TestDealDatasetVectorIndexTask: name="Test Document", created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, doc_language="en", indexing_status=IndexingStatus.COMPLETED, enabled=True, @@ -268,7 +269,7 @@ class TestDealDatasetVectorIndexTask: name="Document for doc_form", created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="parent_child_index", + doc_form=IndexStructureType.PARENT_CHILD_INDEX, doc_language="en", indexing_status=IndexingStatus.COMPLETED, enabled=True, @@ -288,7 +289,7 @@ class TestDealDatasetVectorIndexTask: name="Test Document", created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="parent_child_index", + doc_form=IndexStructureType.PARENT_CHILD_INDEX, doc_language="en", indexing_status=IndexingStatus.COMPLETED, enabled=True, @@ -416,7 +417,7 @@ class TestDealDatasetVectorIndexTask: name="Test Document", created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, doc_language="en", indexing_status=IndexingStatus.COMPLETED, enabled=True, @@ -505,7 +506,7 @@ class TestDealDatasetVectorIndexTask: name="Document for doc_form", created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, doc_language="en", indexing_status=IndexingStatus.COMPLETED, enabled=True, @@ -525,7 +526,7 @@ class TestDealDatasetVectorIndexTask: name="Test Document", created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, doc_language="en", indexing_status=IndexingStatus.COMPLETED, enabled=True, @@ -601,7 +602,7 @@ class TestDealDatasetVectorIndexTask: name="Test Document", created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="qa_index", + doc_form=IndexStructureType.QA_INDEX, doc_language="en", indexing_status=IndexingStatus.COMPLETED, enabled=True, @@ -638,7 +639,7 @@ class TestDealDatasetVectorIndexTask: assert updated_document.indexing_status == IndexingStatus.COMPLETED # Verify index processor was initialized with custom index type - mock_index_processor_factory.assert_called_once_with("qa_index") + mock_index_processor_factory.assert_called_once_with(IndexStructureType.QA_INDEX) mock_factory = mock_index_processor_factory.return_value mock_processor = mock_factory.init_index_processor.return_value mock_processor.load.assert_called_once() @@ -677,7 +678,7 @@ class TestDealDatasetVectorIndexTask: name="Test Document", created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, doc_language="en", indexing_status=IndexingStatus.COMPLETED, enabled=True, @@ -714,7 +715,7 @@ class TestDealDatasetVectorIndexTask: assert updated_document.indexing_status == IndexingStatus.COMPLETED # Verify index processor was initialized with the document's index type - mock_index_processor_factory.assert_called_once_with("text_model") + mock_index_processor_factory.assert_called_once_with(IndexStructureType.PARAGRAPH_INDEX) mock_factory = mock_index_processor_factory.return_value mock_processor = mock_factory.init_index_processor.return_value mock_processor.load.assert_called_once() @@ -753,7 +754,7 @@ class TestDealDatasetVectorIndexTask: name="Document for doc_form", created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, doc_language="en", indexing_status=IndexingStatus.COMPLETED, enabled=True, @@ -775,7 +776,7 @@ class TestDealDatasetVectorIndexTask: name=f"Test Document {i}", created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, doc_language="en", indexing_status=IndexingStatus.COMPLETED, enabled=True, @@ -856,7 +857,7 @@ class TestDealDatasetVectorIndexTask: name="Document for doc_form", created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, doc_language="en", indexing_status=IndexingStatus.COMPLETED, enabled=True, @@ -876,7 +877,7 @@ class TestDealDatasetVectorIndexTask: name="Test Document", created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, doc_language="en", indexing_status=IndexingStatus.COMPLETED, enabled=True, @@ -953,7 +954,7 @@ class TestDealDatasetVectorIndexTask: name="Document for doc_form", created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, doc_language="en", indexing_status=IndexingStatus.COMPLETED, enabled=True, @@ -973,7 +974,7 @@ class TestDealDatasetVectorIndexTask: name="Enabled Document", created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, doc_language="en", indexing_status=IndexingStatus.COMPLETED, enabled=True, @@ -992,7 +993,7 @@ class TestDealDatasetVectorIndexTask: name="Disabled Document", created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, doc_language="en", indexing_status=IndexingStatus.COMPLETED, enabled=False, # This document should be skipped @@ -1074,7 +1075,7 @@ class TestDealDatasetVectorIndexTask: name="Document for doc_form", created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, doc_language="en", indexing_status=IndexingStatus.COMPLETED, enabled=True, @@ -1094,7 +1095,7 @@ class TestDealDatasetVectorIndexTask: name="Active Document", created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, doc_language="en", indexing_status=IndexingStatus.COMPLETED, enabled=True, @@ -1113,7 +1114,7 @@ class TestDealDatasetVectorIndexTask: name="Archived Document", created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, doc_language="en", indexing_status=IndexingStatus.COMPLETED, enabled=True, @@ -1195,7 +1196,7 @@ class TestDealDatasetVectorIndexTask: name="Document for doc_form", created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, doc_language="en", indexing_status=IndexingStatus.COMPLETED, enabled=True, @@ -1215,7 +1216,7 @@ class TestDealDatasetVectorIndexTask: name="Completed Document", created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, doc_language="en", indexing_status=IndexingStatus.COMPLETED, enabled=True, @@ -1234,7 +1235,7 @@ class TestDealDatasetVectorIndexTask: name="Incomplete Document", created_from=DocumentCreatedFrom.WEB, created_by=account.id, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, doc_language="en", indexing_status=IndexingStatus.INDEXING, # This document should be skipped enabled=True, diff --git a/api/tests/test_containers_integration_tests/tasks/test_delete_segment_from_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_delete_segment_from_index_task.py index 6fc2a53f9c..8a69707b38 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_delete_segment_from_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_delete_segment_from_index_task.py @@ -12,7 +12,7 @@ from unittest.mock import MagicMock, patch from faker import Faker -from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from models import Account, Dataset, Document, DocumentSegment, Tenant from models.enums import DataSourceType, DocumentCreatedFrom, DocumentDocType, IndexingStatus, SegmentStatus from tasks.delete_segment_from_index_task import delete_segment_from_index_task @@ -108,7 +108,7 @@ class TestDeleteSegmentFromIndexTask: dataset.provider = "vendor" dataset.permission = "only_me" dataset.data_source_type = DataSourceType.UPLOAD_FILE - dataset.indexing_technique = "high_quality" + dataset.indexing_technique = IndexTechniqueType.HIGH_QUALITY dataset.index_struct = '{"type": "paragraph"}' dataset.created_by = account.id dataset.created_at = fake.date_time_this_year() diff --git a/api/tests/test_containers_integration_tests/tasks/test_disable_segment_from_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_disable_segment_from_index_task.py index da42fc7167..5bdf7d1389 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_disable_segment_from_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_disable_segment_from_index_task.py @@ -15,6 +15,7 @@ import pytest from faker import Faker from sqlalchemy.orm import Session +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from extensions.ext_redis import redis_client from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, Document, DocumentSegment @@ -99,7 +100,7 @@ class TestDisableSegmentFromIndexTask: name=fake.sentence(nb_words=3), description=fake.text(max_nb_chars=200), data_source_type=DataSourceType.UPLOAD_FILE, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, created_by=account.id, ) db_session_with_containers.add(dataset) @@ -113,7 +114,7 @@ class TestDisableSegmentFromIndexTask: dataset: Dataset, tenant: Tenant, account: Account, - doc_form: str = "text_model", + doc_form: str = IndexStructureType.PARAGRAPH_INDEX, ) -> Document: """ Helper method to create a test document. @@ -476,7 +477,11 @@ class TestDisableSegmentFromIndexTask: - Index processor clean method is called correctly """ # Test different document forms - doc_forms = ["text_model", "qa_model", "table_model"] + doc_forms = [ + IndexStructureType.PARAGRAPH_INDEX, + IndexStructureType.QA_INDEX, + IndexStructureType.PARENT_CHILD_INDEX, + ] for doc_form in doc_forms: # Arrange: Create test data for each form diff --git a/api/tests/test_containers_integration_tests/tasks/test_disable_segments_from_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_disable_segments_from_index_task.py index 4bc9bb4749..3e9a0c8f7f 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_disable_segments_from_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_disable_segments_from_index_task.py @@ -11,6 +11,7 @@ from unittest.mock import MagicMock, patch from faker import Faker from sqlalchemy.orm import Session +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from models import Account, Dataset, DocumentSegment from models import Document as DatasetDocument from models.dataset import DatasetProcessRule @@ -102,7 +103,7 @@ class TestDisableSegmentsFromIndexTask: provider="vendor", permission="only_me", data_source_type=DataSourceType.UPLOAD_FILE, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, created_by=account.id, updated_by=account.id, embedding_model="text-embedding-ada-002", @@ -153,7 +154,7 @@ class TestDisableSegmentsFromIndexTask: document.indexing_status = "completed" document.enabled = True document.archived = False - document.doc_form = "text_model" # Use text_model form for testing + document.doc_form = IndexStructureType.PARAGRAPH_INDEX # Use text_model form for testing document.doc_language = "en" db_session_with_containers.add(document) db_session_with_containers.commit() @@ -500,7 +501,11 @@ class TestDisableSegmentsFromIndexTask: segment_ids = [segment.id for segment in segments] # Test different document forms - doc_forms = ["text_model", "qa_model", "hierarchical_model"] + doc_forms = [ + IndexStructureType.PARAGRAPH_INDEX, + IndexStructureType.QA_INDEX, + IndexStructureType.PARENT_CHILD_INDEX, + ] for doc_form in doc_forms: # Update document form diff --git a/api/tests/test_containers_integration_tests/tasks/test_document_indexing_sync_task.py b/api/tests/test_containers_integration_tests/tasks/test_document_indexing_sync_task.py index 6a17a19a54..d4021143ef 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_document_indexing_sync_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_document_indexing_sync_task.py @@ -14,6 +14,7 @@ from uuid import uuid4 import pytest from core.indexing_runner import DocumentIsPausedError, IndexingRunner +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, Document, DocumentSegment from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus, SegmentStatus @@ -56,7 +57,7 @@ class DocumentIndexingSyncTaskTestDataFactory: name=f"dataset-{uuid4()}", description="sync test dataset", data_source_type=DataSourceType.NOTION_IMPORT, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, created_by=created_by, ) db_session_with_containers.add(dataset) @@ -85,7 +86,7 @@ class DocumentIndexingSyncTaskTestDataFactory: created_by=created_by, indexing_status=indexing_status, enabled=True, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, doc_language="en", ) db_session_with_containers.add(document) diff --git a/api/tests/test_containers_integration_tests/tasks/test_document_indexing_task.py b/api/tests/test_containers_integration_tests/tasks/test_document_indexing_task.py index 9421b07285..cf1a8666f3 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_document_indexing_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_document_indexing_task.py @@ -5,6 +5,7 @@ import pytest from faker import Faker from core.entities.document_task import DocumentTask +from core.rag.index_processor.constant.index_type import IndexTechniqueType from enums.cloud_plan import CloudPlan from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, Document @@ -99,7 +100,7 @@ class TestDocumentIndexingTasks: name=fake.company(), description=fake.text(max_nb_chars=100), data_source_type=DataSourceType.UPLOAD_FILE, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, created_by=account.id, ) db_session_with_containers.add(dataset) @@ -181,7 +182,7 @@ class TestDocumentIndexingTasks: name=fake.company(), description=fake.text(max_nb_chars=100), data_source_type=DataSourceType.UPLOAD_FILE, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, created_by=account.id, ) db_session_with_containers.add(dataset) diff --git a/api/tests/test_containers_integration_tests/tasks/test_document_indexing_update_task.py b/api/tests/test_containers_integration_tests/tasks/test_document_indexing_update_task.py index 2fbea1388c..d94abf2b40 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_document_indexing_update_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_document_indexing_update_task.py @@ -3,6 +3,7 @@ from unittest.mock import MagicMock, patch import pytest from faker import Faker +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, Document, DocumentSegment from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus, SegmentStatus @@ -63,7 +64,7 @@ class TestDocumentIndexingUpdateTask: name=fake.company(), description=fake.text(max_nb_chars=64), data_source_type=DataSourceType.UPLOAD_FILE, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, created_by=account.id, ) db_session_with_containers.add(dataset) @@ -80,7 +81,7 @@ class TestDocumentIndexingUpdateTask: created_by=account.id, indexing_status=IndexingStatus.WAITING, enabled=True, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, ) db_session_with_containers.add(document) db_session_with_containers.commit() diff --git a/api/tests/test_containers_integration_tests/tasks/test_duplicate_document_indexing_task.py b/api/tests/test_containers_integration_tests/tasks/test_duplicate_document_indexing_task.py index f1f5a4b105..6a8e186958 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_duplicate_document_indexing_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_duplicate_document_indexing_task.py @@ -4,6 +4,7 @@ import pytest from faker import Faker from core.indexing_runner import DocumentIsPausedError +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from enums.cloud_plan import CloudPlan from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, Document, DocumentSegment @@ -109,7 +110,7 @@ class TestDuplicateDocumentIndexingTasks: name=fake.company(), description=fake.text(max_nb_chars=100), data_source_type=DataSourceType.UPLOAD_FILE, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, created_by=account.id, ) db_session_with_containers.add(dataset) @@ -130,7 +131,7 @@ class TestDuplicateDocumentIndexingTasks: created_by=account.id, indexing_status=IndexingStatus.WAITING, enabled=True, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, ) db_session_with_containers.add(document) documents.append(document) @@ -244,7 +245,7 @@ class TestDuplicateDocumentIndexingTasks: name=fake.company(), description=fake.text(max_nb_chars=100), data_source_type=DataSourceType.UPLOAD_FILE, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, created_by=account.id, ) db_session_with_containers.add(dataset) @@ -265,7 +266,7 @@ class TestDuplicateDocumentIndexingTasks: created_by=account.id, indexing_status=IndexingStatus.WAITING, enabled=True, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, ) db_session_with_containers.add(document) documents.append(document) @@ -524,7 +525,7 @@ class TestDuplicateDocumentIndexingTasks: created_by=dataset.created_by, indexing_status=IndexingStatus.WAITING, enabled=True, - doc_form="text_model", + doc_form=IndexStructureType.PARAGRAPH_INDEX, ) db_session_with_containers.add(document) extra_documents.append(document) diff --git a/api/tests/test_containers_integration_tests/tasks/test_enable_segments_to_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_enable_segments_to_index_task.py index 54b50016a8..e2f35067e3 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_enable_segments_to_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_enable_segments_to_index_task.py @@ -4,7 +4,7 @@ import pytest from faker import Faker from sqlalchemy.orm import Session -from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from extensions.ext_redis import redis_client from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, Document, DocumentSegment @@ -81,7 +81,7 @@ class TestEnableSegmentsToIndexTask: name=fake.company(), description=fake.text(max_nb_chars=100), data_source_type=DataSourceType.UPLOAD_FILE, - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, created_by=account.id, ) db_session_with_containers.add(dataset) diff --git a/api/tests/test_containers_integration_tests/tasks/test_mail_human_input_delivery_task.py b/api/tests/test_containers_integration_tests/tasks/test_mail_human_input_delivery_task.py index 0876a39f82..d341c5ce99 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_mail_human_input_delivery_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_mail_human_input_delivery_task.py @@ -9,17 +9,17 @@ from core.app.app_config.entities import WorkflowUIBasedAppConfig from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity from core.app.layers.pause_state_persist_layer import WorkflowResumptionContext from core.repositories.human_input_repository import FormCreateParams, HumanInputFormRepositoryImpl -from dify_graph.enums import WorkflowExecutionStatus -from dify_graph.nodes.human_input.entities import ( +from core.workflow.human_input_compat import ( EmailDeliveryConfig, EmailDeliveryMethod, EmailRecipients, ExternalRecipient, - HumanInputNodeData, MemberRecipient, ) -from dify_graph.runtime import GraphRuntimeState, VariablePool from extensions.ext_storage import storage +from graphon.enums import WorkflowExecutionStatus +from graphon.nodes.human_input.entities import HumanInputNodeData +from graphon.runtime import GraphRuntimeState, VariablePool from models.account import Account, AccountStatus, Tenant, TenantAccountJoin, TenantAccountRole from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom from models.human_input import HumanInputDelivery, HumanInputForm, HumanInputFormRecipient @@ -79,9 +79,9 @@ def _build_form(db_session_with_containers, tenant, account, *, app_id: str, wor delivery_method = EmailDeliveryMethod( config=EmailDeliveryConfig( recipients=EmailRecipients( - whole_workspace=False, + include_bound_group=False, items=[ - MemberRecipient(user_id=account.id), + MemberRecipient(reference_id=account.id), ExternalRecipient(email="external@example.com"), ], ), @@ -96,9 +96,8 @@ def _build_form(db_session_with_containers, tenant, account, *, app_id: str, wor delivery_methods=[delivery_method], ) - repo = HumanInputFormRepositoryImpl(tenant_id=tenant.id) + repo = HumanInputFormRepositoryImpl(tenant_id=tenant.id, app_id=app_id) params = FormCreateParams( - app_id=app_id, workflow_execution_id=workflow_execution_id, node_id="node-1", form_config=node_data, diff --git a/api/tests/test_containers_integration_tests/tasks/test_remove_app_and_related_data_task.py b/api/tests/test_containers_integration_tests/tasks/test_remove_app_and_related_data_task.py index 5bded4d670..9a7507a2f9 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_remove_app_and_related_data_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_remove_app_and_related_data_task.py @@ -4,9 +4,9 @@ from unittest.mock import ANY, call, patch import pytest from core.db.session_factory import session_factory -from dify_graph.variables.segments import StringSegment -from dify_graph.variables.types import SegmentType from extensions.storage.storage_type import StorageType +from graphon.variables.segments import StringSegment +from graphon.variables.types import SegmentType from libs.datetime_utils import naive_utc_now from models import Tenant from models.enums import CreatorUserRole diff --git a/api/tests/test_containers_integration_tests/test_workflow_pause_integration.py b/api/tests/test_containers_integration_tests/test_workflow_pause_integration.py index ca76fa0a4b..b9f513a6d0 100644 --- a/api/tests/test_containers_integration_tests/test_workflow_pause_integration.py +++ b/api/tests/test_containers_integration_tests/test_workflow_pause_integration.py @@ -27,9 +27,9 @@ import pytest from sqlalchemy import delete, select from sqlalchemy.orm import Session, selectinload, sessionmaker -from dify_graph.entities import WorkflowExecution -from dify_graph.enums import WorkflowExecutionStatus from extensions.ext_storage import storage +from graphon.entities import WorkflowExecution +from graphon.enums import WorkflowExecutionStatus from libs.datetime_utils import naive_utc_now from models import Account from models import WorkflowPause as WorkflowPauseModel diff --git a/api/tests/test_containers_integration_tests/trigger/test_trigger_e2e.py b/api/tests/test_containers_integration_tests/trigger/test_trigger_e2e.py index 4ea8d8c1c7..8854ef5e04 100644 --- a/api/tests/test_containers_integration_tests/trigger/test_trigger_e2e.py +++ b/api/tests/test_containers_integration_tests/trigger/test_trigger_e2e.py @@ -23,7 +23,7 @@ from core.trigger.debug import event_selectors from core.trigger.debug.event_bus import TriggerDebugEventBus from core.trigger.debug.event_selectors import PluginTriggerDebugEventPoller, WebhookTriggerDebugEventPoller from core.trigger.debug.events import PluginTriggerDebugEvent, build_plugin_pool_key -from dify_graph.enums import BuiltinNodeTypes +from graphon.enums import BuiltinNodeTypes from libs.datetime_utils import naive_utc_now from models.account import Account, Tenant from models.enums import AppTriggerStatus, AppTriggerType, CreatorUserRole, WorkflowTriggerStatus diff --git a/api/tests/unit_tests/conftest.py b/api/tests/unit_tests/conftest.py index 3f75fd2851..55873b06a8 100644 --- a/api/tests/unit_tests/conftest.py +++ b/api/tests/unit_tests/conftest.py @@ -123,27 +123,26 @@ def _configure_session_factory(_unit_test_engine): def setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account): """ - Helper to set up the mock DB query chain for tenant/account authentication. + Helper to set up the mock DB execute chain for tenant/account authentication. - This configures the mock to return (tenant, account) for the join query used - by validate_app_token and validate_dataset_token decorators. + This configures the mock to return (tenant, account) for the + db.session.execute(select(...).join().join().where()).one_or_none() + query used by validate_app_token decorator. Args: mock_db: The mocked db object mock_tenant: Mock tenant object to return mock_account: Mock account object to return """ - query = mock_db.session.query.return_value - join_chain = query.join.return_value.join.return_value - where_chain = join_chain.where.return_value - where_chain.one_or_none.return_value = (mock_tenant, mock_account) + mock_db.session.execute.return_value.one_or_none.return_value = (mock_tenant, mock_account) def setup_mock_dataset_tenant_query(mock_db, mock_tenant, mock_ta): """ - Helper to set up the mock DB query chain for dataset tenant authentication. + Helper to set up the mock DB execute chain for dataset tenant authentication. - This configures the mock to return (tenant, tenant_account) for the where chain + This configures the mock to return (tenant, tenant_account) for the + db.session.execute(select(...).where().where().where().where()).one_or_none() query used by validate_dataset_token decorator. Args: @@ -151,6 +150,4 @@ def setup_mock_dataset_tenant_query(mock_db, mock_tenant, mock_ta): mock_tenant: Mock tenant object to return mock_ta: Mock tenant account object to return """ - query = mock_db.session.query.return_value - where_chain = query.where.return_value.where.return_value.where.return_value.where.return_value - where_chain.one_or_none.return_value = (mock_tenant, mock_ta) + mock_db.session.execute.return_value.one_or_none.return_value = (mock_tenant, mock_ta) diff --git a/api/tests/unit_tests/controllers/console/app/test_app_apis.py b/api/tests/unit_tests/controllers/console/app/test_app_apis.py index 60b8ee96fe..1d1e119fd6 100644 --- a/api/tests/unit_tests/controllers/console/app/test_app_apis.py +++ b/api/tests/unit_tests/controllers/console/app/test_app_apis.py @@ -7,14 +7,19 @@ from __future__ import annotations import uuid from types import SimpleNamespace -from unittest.mock import MagicMock +from unittest.mock import MagicMock, patch import pytest +from pydantic import ValidationError from werkzeug.exceptions import BadRequest, NotFound +from controllers.console import console_ns from controllers.console.app import ( annotation as annotation_module, ) +from controllers.console.app import ( + app as app_module, +) from controllers.console.app import ( completion as completion_module, ) @@ -203,6 +208,48 @@ class TestCompletionEndpoints: method(app_model=MagicMock(id="app-1")) +class TestAppEndpoints: + """Tests for app endpoints.""" + + def test_app_put_should_preserve_icon_type_when_payload_omits_it(self, app, monkeypatch): + api = app_module.AppApi() + method = _unwrap(api.put) + payload = { + "name": "Updated App", + "description": "Updated description", + "icon": "🤖", + "icon_background": "#FFFFFF", + } + app_service = MagicMock() + app_service.update_app.return_value = SimpleNamespace() + response_model = MagicMock() + response_model.model_dump.return_value = {"id": "app-1"} + + monkeypatch.setattr(app_module, "AppService", lambda: app_service) + monkeypatch.setattr(app_module.AppDetailWithSite, "model_validate", MagicMock(return_value=response_model)) + + with ( + app.test_request_context("/console/api/apps/app-1", method="PUT", json=payload), + patch.object(type(console_ns), "payload", payload), + ): + response = method(app_model=SimpleNamespace(icon_type=app_module.IconType.EMOJI)) + + assert response == {"id": "app-1"} + assert app_service.update_app.call_args.args[1]["icon_type"] is None + + def test_update_app_payload_should_reject_empty_icon_type(self): + with pytest.raises(ValidationError): + app_module.UpdateAppPayload.model_validate( + { + "name": "Updated App", + "description": "Updated description", + "icon_type": "", + "icon": "🤖", + "icon_background": "#FFFFFF", + } + ) + + # ========== OpsTrace Tests ========== class TestOpsTraceEndpoints: """Tests for ops_trace endpoint.""" @@ -281,12 +328,10 @@ class TestSiteEndpoints: method = _unwrap(api.post) site = MagicMock() - query = MagicMock() - query.where.return_value.first.return_value = site monkeypatch.setattr( site_module.db, "session", - MagicMock(query=lambda *_args, **_kwargs: query, commit=lambda: None), + MagicMock(scalar=lambda *_args, **_kwargs: site, commit=lambda: None), ) monkeypatch.setattr( site_module, @@ -305,12 +350,10 @@ class TestSiteEndpoints: method = _unwrap(api.post) site = MagicMock() - query = MagicMock() - query.where.return_value.first.return_value = site monkeypatch.setattr( site_module.db, "session", - MagicMock(query=lambda *_args, **_kwargs: query, commit=lambda: None), + MagicMock(scalar=lambda *_args, **_kwargs: site, commit=lambda: None), ) monkeypatch.setattr(site_module.Site, "generate_code", lambda *_args, **_kwargs: "code") monkeypatch.setattr( diff --git a/api/tests/unit_tests/controllers/console/app/test_audio.py b/api/tests/unit_tests/controllers/console/app/test_audio.py index 021e9a0784..2d218dac7e 100644 --- a/api/tests/unit_tests/controllers/console/app/test_audio.py +++ b/api/tests/unit_tests/controllers/console/app/test_audio.py @@ -20,7 +20,7 @@ from controllers.console.app.error import ( UnsupportedAudioTypeError, ) from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError -from dify_graph.model_runtime.errors.invoke import InvokeError +from graphon.model_runtime.errors.invoke import InvokeError from services.audio_service import AudioService from services.errors.app_model_config import AppModelConfigBrokenError from services.errors.audio import ( diff --git a/api/tests/unit_tests/controllers/console/app/test_conversation_api.py b/api/tests/unit_tests/controllers/console/app/test_conversation_api.py index 5db8e5c332..11b3b3470d 100644 --- a/api/tests/unit_tests/controllers/console/app/test_conversation_api.py +++ b/api/tests/unit_tests/controllers/console/app/test_conversation_api.py @@ -82,12 +82,8 @@ def test_chat_conversation_list_advanced_chat_calls_paginate(app, monkeypatch: p def test_get_conversation_updates_read_at(monkeypatch: pytest.MonkeyPatch) -> None: conversation = SimpleNamespace(id="c1", app_id="app-1") - query = MagicMock() - query.where.return_value = query - query.first.return_value = conversation - session = MagicMock() - session.query.return_value = query + session.scalar.return_value = conversation monkeypatch.setattr(conversation_module, "current_account_with_tenant", lambda: (_make_account(), "t1")) monkeypatch.setattr(conversation_module.db, "session", session) @@ -101,12 +97,8 @@ def test_get_conversation_updates_read_at(monkeypatch: pytest.MonkeyPatch) -> No def test_get_conversation_missing_raises_not_found(monkeypatch: pytest.MonkeyPatch) -> None: - query = MagicMock() - query.where.return_value = query - query.first.return_value = None - session = MagicMock() - session.query.return_value = query + session.scalar.return_value = None monkeypatch.setattr(conversation_module, "current_account_with_tenant", lambda: (_make_account(), "t1")) monkeypatch.setattr(conversation_module.db, "session", session) diff --git a/api/tests/unit_tests/controllers/console/app/test_conversation_read_timestamp.py b/api/tests/unit_tests/controllers/console/app/test_conversation_read_timestamp.py index 460da06ecc..f588ab261d 100644 --- a/api/tests/unit_tests/controllers/console/app/test_conversation_read_timestamp.py +++ b/api/tests/unit_tests/controllers/console/app/test_conversation_read_timestamp.py @@ -24,7 +24,7 @@ def test_get_conversation_mark_read_keeps_updated_at_unchanged(): ), patch("controllers.console.app.conversation.db.session", autospec=True) as mock_session, ): - mock_session.query.return_value.where.return_value.first.return_value = conversation + mock_session.scalar.return_value = conversation _get_conversation(app_model, "conversation-id") diff --git a/api/tests/unit_tests/controllers/console/app/test_generator_api.py b/api/tests/unit_tests/controllers/console/app/test_generator_api.py index f83bc18da3..e64c508b82 100644 --- a/api/tests/unit_tests/controllers/console/app/test_generator_api.py +++ b/api/tests/unit_tests/controllers/console/app/test_generator_api.py @@ -73,8 +73,7 @@ def test_instruction_generate_app_not_found(app, monkeypatch: pytest.MonkeyPatch monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1")) - query = SimpleNamespace(where=lambda *_args, **_kwargs: query, first=lambda: None) - monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(query=lambda *_args, **_kwargs: query)) + monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(get=lambda *_args, **_kwargs: None)) with app.test_request_context( "/console/api/instruction-generate", @@ -99,8 +98,7 @@ def test_instruction_generate_workflow_not_found(app, monkeypatch: pytest.Monkey monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1")) app_model = SimpleNamespace(id="app-1") - query = SimpleNamespace(where=lambda *_args, **_kwargs: query, first=lambda: app_model) - monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(query=lambda *_args, **_kwargs: query)) + monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(get=lambda *_args, **_kwargs: app_model)) _install_workflow_service(monkeypatch, workflow=None) with app.test_request_context( @@ -126,8 +124,7 @@ def test_instruction_generate_node_missing(app, monkeypatch: pytest.MonkeyPatch) monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1")) app_model = SimpleNamespace(id="app-1") - query = SimpleNamespace(where=lambda *_args, **_kwargs: query, first=lambda: app_model) - monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(query=lambda *_args, **_kwargs: query)) + monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(get=lambda *_args, **_kwargs: app_model)) workflow = SimpleNamespace(graph_dict={"nodes": []}) _install_workflow_service(monkeypatch, workflow=workflow) @@ -155,8 +152,7 @@ def test_instruction_generate_code_node(app, monkeypatch: pytest.MonkeyPatch) -> monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1")) app_model = SimpleNamespace(id="app-1") - query = SimpleNamespace(where=lambda *_args, **_kwargs: query, first=lambda: app_model) - monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(query=lambda *_args, **_kwargs: query)) + monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(get=lambda *_args, **_kwargs: app_model)) workflow = SimpleNamespace( graph_dict={ diff --git a/api/tests/unit_tests/controllers/console/app/test_message.py b/api/tests/unit_tests/controllers/console/app/test_message.py deleted file mode 100644 index e6dfc0d3bd..0000000000 --- a/api/tests/unit_tests/controllers/console/app/test_message.py +++ /dev/null @@ -1,320 +0,0 @@ -from unittest.mock import MagicMock, patch - -import pytest -from flask import Flask, request -from werkzeug.exceptions import InternalServerError, NotFound -from werkzeug.local import LocalProxy - -from controllers.console.app.error import ( - ProviderModelCurrentlyNotSupportError, - ProviderNotInitializeError, - ProviderQuotaExceededError, -) -from controllers.console.app.message import ( - ChatMessageListApi, - ChatMessagesQuery, - FeedbackExportQuery, - MessageAnnotationCountApi, - MessageApi, - MessageFeedbackApi, - MessageFeedbackExportApi, - MessageFeedbackPayload, - MessageSuggestedQuestionApi, -) -from controllers.console.explore.error import AppSuggestedQuestionsAfterAnswerDisabledError -from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError -from models import App, AppMode -from services.errors.conversation import ConversationNotExistsError -from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError - - -@pytest.fixture -def app(): - flask_app = Flask(__name__) - flask_app.config["TESTING"] = True - flask_app.config["RESTX_MASK_HEADER"] = "X-Fields" - return flask_app - - -@pytest.fixture -def mock_account(): - from models.account import Account, AccountStatus - - account = MagicMock(spec=Account) - account.id = "user_123" - account.timezone = "UTC" - account.status = AccountStatus.ACTIVE - account.is_admin_or_owner = True - account.current_tenant.current_role = "owner" - account.has_edit_permission = True - return account - - -@pytest.fixture -def mock_app_model(): - app_model = MagicMock(spec=App) - app_model.id = "app_123" - app_model.mode = AppMode.CHAT - app_model.tenant_id = "tenant_123" - return app_model - - -@pytest.fixture(autouse=True) -def mock_csrf(): - with patch("libs.login.check_csrf_token") as mock: - yield mock - - -import contextlib - - -@contextlib.contextmanager -def setup_test_context( - test_app, endpoint_class, route_path, method, mock_account, mock_app_model, payload=None, qs=None -): - with ( - patch("extensions.ext_database.db") as mock_db, - patch("controllers.console.app.wraps.db", mock_db), - patch("controllers.console.wraps.db", mock_db), - patch("controllers.console.app.message.db", mock_db), - patch("controllers.console.app.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")), - patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")), - patch("controllers.console.app.message.current_account_with_tenant", return_value=(mock_account, "tenant_123")), - ): - # Set up a generic query mock that usually returns mock_app_model when getting app - app_query_mock = MagicMock() - app_query_mock.filter.return_value.first.return_value = mock_app_model - app_query_mock.filter.return_value.filter.return_value.first.return_value = mock_app_model - app_query_mock.where.return_value.first.return_value = mock_app_model - app_query_mock.where.return_value.where.return_value.first.return_value = mock_app_model - - data_query_mock = MagicMock() - - def query_side_effect(*args, **kwargs): - if args and hasattr(args[0], "__name__") and args[0].__name__ == "App": - return app_query_mock - return data_query_mock - - mock_db.session.query.side_effect = query_side_effect - mock_db.data_query = data_query_mock - - # Let the caller override the stat db query logic - proxy_mock = LocalProxy(lambda: mock_account) - - query_string = "&".join([f"{k}={v}" for k, v in (qs or {}).items()]) - full_path = f"{route_path}?{query_string}" if qs else route_path - - with ( - patch("libs.login.current_user", proxy_mock), - patch("flask_login.current_user", proxy_mock), - patch("controllers.console.app.message.attach_message_extra_contents", return_value=None), - ): - with test_app.test_request_context(full_path, method=method, json=payload): - request.view_args = {"app_id": "app_123"} - - if "suggested-questions" in route_path: - # simplistic extraction for message_id - parts = route_path.split("chat-messages/") - if len(parts) > 1: - request.view_args["message_id"] = parts[1].split("/")[0] - elif "messages/" in route_path and "chat-messages" not in route_path: - parts = route_path.split("messages/") - if len(parts) > 1: - request.view_args["message_id"] = parts[1].split("/")[0] - - api_instance = endpoint_class() - - # Check if it has a dispatch_request or method - if hasattr(api_instance, method.lower()): - yield api_instance, mock_db, request.view_args - - -class TestMessageValidators: - def test_chat_messages_query_validators(self): - # Test empty_to_none - assert ChatMessagesQuery.empty_to_none("") is None - assert ChatMessagesQuery.empty_to_none("val") == "val" - - # Test validate_uuid - assert ChatMessagesQuery.validate_uuid(None) is None - assert ( - ChatMessagesQuery.validate_uuid("123e4567-e89b-12d3-a456-426614174000") - == "123e4567-e89b-12d3-a456-426614174000" - ) - - def test_message_feedback_validators(self): - assert ( - MessageFeedbackPayload.validate_message_id("123e4567-e89b-12d3-a456-426614174000") - == "123e4567-e89b-12d3-a456-426614174000" - ) - - def test_feedback_export_validators(self): - assert FeedbackExportQuery.parse_bool(None) is None - assert FeedbackExportQuery.parse_bool(True) is True - assert FeedbackExportQuery.parse_bool("1") is True - assert FeedbackExportQuery.parse_bool("0") is False - assert FeedbackExportQuery.parse_bool("off") is False - - with pytest.raises(ValueError): - FeedbackExportQuery.parse_bool("invalid") - - -class TestMessageEndpoints: - def test_chat_message_list_not_found(self, app, mock_account, mock_app_model): - with setup_test_context( - app, - ChatMessageListApi, - "/apps/app_123/chat-messages", - "GET", - mock_account, - mock_app_model, - qs={"conversation_id": "123e4567-e89b-12d3-a456-426614174000"}, - ) as (api, mock_db, v_args): - mock_db.session.scalar.return_value = None - - with pytest.raises(NotFound): - api.get(**v_args) - - def test_chat_message_list_success(self, app, mock_account, mock_app_model): - with setup_test_context( - app, - ChatMessageListApi, - "/apps/app_123/chat-messages", - "GET", - mock_account, - mock_app_model, - qs={"conversation_id": "123e4567-e89b-12d3-a456-426614174000", "limit": 1}, - ) as (api, mock_db, v_args): - mock_conv = MagicMock() - mock_conv.id = "123e4567-e89b-12d3-a456-426614174000" - mock_msg = MagicMock() - mock_msg.id = "msg_123" - mock_msg.feedbacks = [] - mock_msg.annotation = None - mock_msg.annotation_hit_history = None - mock_msg.agent_thoughts = [] - mock_msg.message_files = [] - mock_msg.extra_contents = [] - mock_msg.message = {} - mock_msg.message_metadata_dict = {} - - # scalar() is called twice: first for conversation lookup, second for has_more check - mock_db.session.scalar.side_effect = [mock_conv, False] - scalars_result = MagicMock() - scalars_result.all.return_value = [mock_msg] - mock_db.session.scalars.return_value = scalars_result - - resp = api.get(**v_args) - assert resp["limit"] == 1 - assert resp["has_more"] is False - assert len(resp["data"]) == 1 - - def test_message_feedback_not_found(self, app, mock_account, mock_app_model): - with setup_test_context( - app, - MessageFeedbackApi, - "/apps/app_123/feedbacks", - "POST", - mock_account, - mock_app_model, - payload={"message_id": "123e4567-e89b-12d3-a456-426614174000"}, - ) as (api, mock_db, v_args): - mock_db.session.scalar.return_value = None - - with pytest.raises(NotFound): - api.post(**v_args) - - def test_message_feedback_success(self, app, mock_account, mock_app_model): - payload = {"message_id": "123e4567-e89b-12d3-a456-426614174000", "rating": "like"} - with setup_test_context( - app, MessageFeedbackApi, "/apps/app_123/feedbacks", "POST", mock_account, mock_app_model, payload=payload - ) as (api, mock_db, v_args): - mock_msg = MagicMock() - mock_msg.admin_feedback = None - mock_db.session.scalar.return_value = mock_msg - - resp = api.post(**v_args) - assert resp == {"result": "success"} - - def test_message_annotation_count(self, app, mock_account, mock_app_model): - with setup_test_context( - app, MessageAnnotationCountApi, "/apps/app_123/annotations/count", "GET", mock_account, mock_app_model - ) as (api, mock_db, v_args): - mock_db.session.scalar.return_value = 5 - - resp = api.get(**v_args) - assert resp == {"count": 5} - - @patch("controllers.console.app.message.MessageService") - def test_message_suggested_questions_success(self, mock_msg_srv, app, mock_account, mock_app_model): - mock_msg_srv.get_suggested_questions_after_answer.return_value = ["q1", "q2"] - - with setup_test_context( - app, - MessageSuggestedQuestionApi, - "/apps/app_123/chat-messages/msg_123/suggested-questions", - "GET", - mock_account, - mock_app_model, - ) as (api, mock_db, v_args): - resp = api.get(**v_args) - assert resp == {"data": ["q1", "q2"]} - - @pytest.mark.parametrize( - ("exc", "expected_exc"), - [ - (MessageNotExistsError, NotFound), - (ConversationNotExistsError, NotFound), - (ProviderTokenNotInitError, ProviderNotInitializeError), - (QuotaExceededError, ProviderQuotaExceededError), - (ModelCurrentlyNotSupportError, ProviderModelCurrentlyNotSupportError), - (SuggestedQuestionsAfterAnswerDisabledError, AppSuggestedQuestionsAfterAnswerDisabledError), - (Exception, InternalServerError), - ], - ) - @patch("controllers.console.app.message.MessageService") - def test_message_suggested_questions_errors( - self, mock_msg_srv, exc, expected_exc, app, mock_account, mock_app_model - ): - mock_msg_srv.get_suggested_questions_after_answer.side_effect = exc() - - with setup_test_context( - app, - MessageSuggestedQuestionApi, - "/apps/app_123/chat-messages/msg_123/suggested-questions", - "GET", - mock_account, - mock_app_model, - ) as (api, mock_db, v_args): - with pytest.raises(expected_exc): - api.get(**v_args) - - @patch("services.feedback_service.FeedbackService.export_feedbacks") - def test_message_feedback_export_success(self, mock_export, app, mock_account, mock_app_model): - mock_export.return_value = {"exported": True} - - with setup_test_context( - app, MessageFeedbackExportApi, "/apps/app_123/feedbacks/export", "GET", mock_account, mock_app_model - ) as (api, mock_db, v_args): - resp = api.get(**v_args) - assert resp == {"exported": True} - - def test_message_api_get_success(self, app, mock_account, mock_app_model): - with setup_test_context( - app, MessageApi, "/apps/app_123/messages/msg_123", "GET", mock_account, mock_app_model - ) as (api, mock_db, v_args): - mock_msg = MagicMock() - mock_msg.id = "msg_123" - mock_msg.feedbacks = [] - mock_msg.annotation = None - mock_msg.annotation_hit_history = None - mock_msg.agent_thoughts = [] - mock_msg.message_files = [] - mock_msg.extra_contents = [] - mock_msg.message = {} - mock_msg.message_metadata_dict = {} - - mock_db.session.scalar.return_value = mock_msg - - resp = api.get(**v_args) - assert resp["id"] == "msg_123" diff --git a/api/tests/unit_tests/controllers/console/app/test_model_config_api.py b/api/tests/unit_tests/controllers/console/app/test_model_config_api.py index 61d92bb5c7..a0e2edb8cf 100644 --- a/api/tests/unit_tests/controllers/console/app/test_model_config_api.py +++ b/api/tests/unit_tests/controllers/console/app/test_model_config_api.py @@ -92,10 +92,7 @@ def test_post_encrypts_agent_tool_parameters(app, monkeypatch: pytest.MonkeyPatc ) session = MagicMock() - query = MagicMock() - query.where.return_value = query - query.first.return_value = original_config - session.query.return_value = query + session.get.return_value = original_config monkeypatch.setattr(model_config_module.db, "session", session) monkeypatch.setattr( diff --git a/api/tests/unit_tests/controllers/console/app/test_statistic.py b/api/tests/unit_tests/controllers/console/app/test_statistic.py deleted file mode 100644 index beba23385d..0000000000 --- a/api/tests/unit_tests/controllers/console/app/test_statistic.py +++ /dev/null @@ -1,275 +0,0 @@ -from decimal import Decimal -from unittest.mock import MagicMock, patch - -import pytest -from flask import Flask, request -from werkzeug.local import LocalProxy - -from controllers.console.app.statistic import ( - AverageResponseTimeStatistic, - AverageSessionInteractionStatistic, - DailyConversationStatistic, - DailyMessageStatistic, - DailyTerminalsStatistic, - DailyTokenCostStatistic, - TokensPerSecondStatistic, - UserSatisfactionRateStatistic, -) -from models import App, AppMode - - -@pytest.fixture -def app(): - flask_app = Flask(__name__) - flask_app.config["TESTING"] = True - return flask_app - - -@pytest.fixture -def mock_account(): - from models.account import Account, AccountStatus - - account = MagicMock(spec=Account) - account.id = "user_123" - account.timezone = "UTC" - account.status = AccountStatus.ACTIVE - account.is_admin_or_owner = True - account.current_tenant.current_role = "owner" - account.has_edit_permission = True - return account - - -@pytest.fixture -def mock_app_model(): - app_model = MagicMock(spec=App) - app_model.id = "app_123" - app_model.mode = AppMode.CHAT - app_model.tenant_id = "tenant_123" - return app_model - - -@pytest.fixture(autouse=True) -def mock_csrf(): - with patch("libs.login.check_csrf_token") as mock: - yield mock - - -def setup_test_context( - test_app, endpoint_class, route_path, mock_account, mock_app_model, mock_rs, mock_parse_ret=(None, None) -): - with ( - patch("controllers.console.app.statistic.db") as mock_db_stat, - patch("controllers.console.app.wraps.db") as mock_db_wraps, - patch("controllers.console.wraps.db", mock_db_wraps), - patch( - "controllers.console.app.statistic.current_account_with_tenant", return_value=(mock_account, "tenant_123") - ), - patch("controllers.console.app.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")), - patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")), - ): - mock_conn = MagicMock() - mock_conn.execute.return_value = mock_rs - - mock_begin = MagicMock() - mock_begin.__enter__.return_value = mock_conn - mock_db_stat.engine.begin.return_value = mock_begin - - mock_query = MagicMock() - mock_query.filter.return_value.first.return_value = mock_app_model - mock_query.filter.return_value.filter.return_value.first.return_value = mock_app_model - mock_query.where.return_value.first.return_value = mock_app_model - mock_query.where.return_value.where.return_value.first.return_value = mock_app_model - mock_db_wraps.session.query.return_value = mock_query - - proxy_mock = LocalProxy(lambda: mock_account) - - with patch("libs.login.current_user", proxy_mock), patch("flask_login.current_user", proxy_mock): - with test_app.test_request_context(route_path, method="GET"): - request.view_args = {"app_id": "app_123"} - api_instance = endpoint_class() - response = api_instance.get(app_id="app_123") - return response - - -class TestStatisticEndpoints: - def test_daily_message_statistic(self, app, mock_account, mock_app_model): - mock_row = MagicMock() - mock_row.date = "2023-01-01" - mock_row.message_count = 10 - mock_row.interactions = Decimal(0) - - with patch("controllers.console.app.statistic.parse_time_range", return_value=(None, None)): - response = setup_test_context( - app, - DailyMessageStatistic, - "/apps/app_123/statistics/daily-messages?start=2023-01-01 00:00&end=2023-01-02 00:00", - mock_account, - mock_app_model, - [mock_row], - ) - assert response.status_code == 200 - assert response.json["data"][0]["message_count"] == 10 - - def test_daily_conversation_statistic(self, app, mock_account, mock_app_model): - mock_row = MagicMock() - mock_row.date = "2023-01-01" - mock_row.conversation_count = 5 - mock_row.interactions = Decimal(0) - - with patch("controllers.console.app.statistic.parse_time_range", return_value=(None, None)): - response = setup_test_context( - app, - DailyConversationStatistic, - "/apps/app_123/statistics/daily-conversations", - mock_account, - mock_app_model, - [mock_row], - ) - assert response.status_code == 200 - assert response.json["data"][0]["conversation_count"] == 5 - - def test_daily_terminals_statistic(self, app, mock_account, mock_app_model): - mock_row = MagicMock() - mock_row.date = "2023-01-01" - mock_row.terminal_count = 2 - mock_row.interactions = Decimal(0) - - with patch("controllers.console.app.statistic.parse_time_range", return_value=(None, None)): - response = setup_test_context( - app, - DailyTerminalsStatistic, - "/apps/app_123/statistics/daily-end-users", - mock_account, - mock_app_model, - [mock_row], - ) - assert response.status_code == 200 - assert response.json["data"][0]["terminal_count"] == 2 - - def test_daily_token_cost_statistic(self, app, mock_account, mock_app_model): - mock_row = MagicMock() - mock_row.date = "2023-01-01" - mock_row.token_count = 100 - mock_row.total_price = Decimal("0.02") - mock_row.interactions = Decimal(0) - - with patch("controllers.console.app.statistic.parse_time_range", return_value=(None, None)): - response = setup_test_context( - app, - DailyTokenCostStatistic, - "/apps/app_123/statistics/token-costs", - mock_account, - mock_app_model, - [mock_row], - ) - assert response.status_code == 200 - assert response.json["data"][0]["token_count"] == 100 - assert response.json["data"][0]["total_price"] == "0.02" - - def test_average_session_interaction_statistic(self, app, mock_account, mock_app_model): - mock_row = MagicMock() - mock_row.date = "2023-01-01" - mock_row.interactions = Decimal("3.523") - - with patch("controllers.console.app.statistic.parse_time_range", return_value=(None, None)): - response = setup_test_context( - app, - AverageSessionInteractionStatistic, - "/apps/app_123/statistics/average-session-interactions", - mock_account, - mock_app_model, - [mock_row], - ) - assert response.status_code == 200 - assert response.json["data"][0]["interactions"] == 3.52 - - def test_user_satisfaction_rate_statistic(self, app, mock_account, mock_app_model): - mock_row = MagicMock() - mock_row.date = "2023-01-01" - mock_row.message_count = 100 - mock_row.feedback_count = 10 - mock_row.interactions = Decimal(0) - - with patch("controllers.console.app.statistic.parse_time_range", return_value=(None, None)): - response = setup_test_context( - app, - UserSatisfactionRateStatistic, - "/apps/app_123/statistics/user-satisfaction-rate", - mock_account, - mock_app_model, - [mock_row], - ) - assert response.status_code == 200 - assert response.json["data"][0]["rate"] == 100.0 - - def test_average_response_time_statistic(self, app, mock_account, mock_app_model): - mock_app_model.mode = AppMode.COMPLETION - mock_row = MagicMock() - mock_row.date = "2023-01-01" - mock_row.latency = 1.234 - mock_row.interactions = Decimal(0) - - with patch("controllers.console.app.statistic.parse_time_range", return_value=(None, None)): - response = setup_test_context( - app, - AverageResponseTimeStatistic, - "/apps/app_123/statistics/average-response-time", - mock_account, - mock_app_model, - [mock_row], - ) - assert response.status_code == 200 - assert response.json["data"][0]["latency"] == 1234.0 - - def test_tokens_per_second_statistic(self, app, mock_account, mock_app_model): - mock_row = MagicMock() - mock_row.date = "2023-01-01" - mock_row.tokens_per_second = 15.5 - mock_row.interactions = Decimal(0) - - with patch("controllers.console.app.statistic.parse_time_range", return_value=(None, None)): - response = setup_test_context( - app, - TokensPerSecondStatistic, - "/apps/app_123/statistics/tokens-per-second", - mock_account, - mock_app_model, - [mock_row], - ) - assert response.status_code == 200 - assert response.json["data"][0]["tps"] == 15.5 - - @patch("controllers.console.app.statistic.parse_time_range") - def test_invalid_time_range(self, mock_parse, app, mock_account, mock_app_model): - mock_parse.side_effect = ValueError("Invalid time") - - from werkzeug.exceptions import BadRequest - - with pytest.raises(BadRequest): - setup_test_context( - app, - DailyMessageStatistic, - "/apps/app_123/statistics/daily-messages?start=invalid&end=invalid", - mock_account, - mock_app_model, - [], - ) - - @patch("controllers.console.app.statistic.parse_time_range") - def test_time_range_params_passed(self, mock_parse, app, mock_account, mock_app_model): - import datetime - - start = datetime.datetime.now() - end = datetime.datetime.now() - mock_parse.return_value = (start, end) - - response = setup_test_context( - app, - DailyMessageStatistic, - "/apps/app_123/statistics/daily-messages?start=something&end=something", - mock_account, - mock_app_model, - [], - ) - assert response.status_code == 200 - mock_parse.assert_called_once() diff --git a/api/tests/unit_tests/controllers/console/app/test_workflow.py b/api/tests/unit_tests/controllers/console/app/test_workflow.py index 0e22db9f9b..341efc05ca 100644 --- a/api/tests/unit_tests/controllers/console/app/test_workflow.py +++ b/api/tests/unit_tests/controllers/console/app/test_workflow.py @@ -9,8 +9,8 @@ from werkzeug.exceptions import HTTPException, NotFound from controllers.console.app import workflow as workflow_module from controllers.console.app.error import DraftWorkflowNotExist, DraftWorkflowNotSync -from dify_graph.file.enums import FileTransferMethod, FileType -from dify_graph.file.models import File +from graphon.file.enums import FileTransferMethod, FileType +from graphon.file.models import File def _unwrap(func): diff --git a/api/tests/unit_tests/controllers/console/app/test_workflow_draft_variable.py b/api/tests/unit_tests/controllers/console/app/test_workflow_draft_variable.py deleted file mode 100644 index 9b5d47c208..0000000000 --- a/api/tests/unit_tests/controllers/console/app/test_workflow_draft_variable.py +++ /dev/null @@ -1,313 +0,0 @@ -from unittest.mock import MagicMock, patch - -import pytest -from flask import Flask, request -from werkzeug.local import LocalProxy - -from controllers.console.app.error import DraftWorkflowNotExist -from controllers.console.app.workflow_draft_variable import ( - ConversationVariableCollectionApi, - EnvironmentVariableCollectionApi, - NodeVariableCollectionApi, - SystemVariableCollectionApi, - VariableApi, - VariableResetApi, - WorkflowVariableCollectionApi, -) -from controllers.web.error import InvalidArgumentError, NotFoundError -from models import App, AppMode -from models.enums import DraftVariableType - - -@pytest.fixture -def app(): - flask_app = Flask(__name__) - flask_app.config["TESTING"] = True - flask_app.config["RESTX_MASK_HEADER"] = "X-Fields" - return flask_app - - -@pytest.fixture -def mock_account(): - from models.account import Account, AccountStatus - - account = MagicMock(spec=Account) - account.id = "user_123" - account.timezone = "UTC" - account.status = AccountStatus.ACTIVE - account.is_admin_or_owner = True - account.current_tenant.current_role = "owner" - account.has_edit_permission = True - return account - - -@pytest.fixture -def mock_app_model(): - app_model = MagicMock(spec=App) - app_model.id = "app_123" - app_model.mode = AppMode.WORKFLOW - app_model.tenant_id = "tenant_123" - return app_model - - -@pytest.fixture(autouse=True) -def mock_csrf(): - with patch("libs.login.check_csrf_token") as mock: - yield mock - - -def setup_test_context(test_app, endpoint_class, route_path, method, mock_account, mock_app_model, payload=None): - with ( - patch("controllers.console.app.wraps.db") as mock_db_wraps, - patch("controllers.console.wraps.db", mock_db_wraps), - patch("controllers.console.app.workflow_draft_variable.db"), - patch("controllers.console.app.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")), - patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")), - ): - mock_query = MagicMock() - mock_query.filter.return_value.first.return_value = mock_app_model - mock_query.filter.return_value.filter.return_value.first.return_value = mock_app_model - mock_query.where.return_value.first.return_value = mock_app_model - mock_query.where.return_value.where.return_value.first.return_value = mock_app_model - mock_db_wraps.session.query.return_value = mock_query - - proxy_mock = LocalProxy(lambda: mock_account) - - with patch("libs.login.current_user", proxy_mock), patch("flask_login.current_user", proxy_mock): - with test_app.test_request_context(route_path, method=method, json=payload): - request.view_args = {"app_id": "app_123"} - # extract node_id or variable_id from path manually since view_args overrides - if "nodes/" in route_path: - request.view_args["node_id"] = route_path.split("nodes/")[1].split("/")[0] - if "variables/" in route_path: - # simplistic extraction - parts = route_path.split("variables/") - if len(parts) > 1 and parts[1] and parts[1] != "reset": - request.view_args["variable_id"] = parts[1].split("/")[0] - - api_instance = endpoint_class() - # we just call dispatch_request to avoid manual argument passing - if hasattr(api_instance, method.lower()): - func = getattr(api_instance, method.lower()) - return func(**request.view_args) - - -class TestWorkflowDraftVariableEndpoints: - @staticmethod - def _mock_workflow_variable(variable_type: DraftVariableType = DraftVariableType.NODE) -> MagicMock: - class DummyValueType: - def exposed_type(self): - return DraftVariableType.NODE - - mock_var = MagicMock() - mock_var.app_id = "app_123" - mock_var.id = "var_123" - mock_var.name = "test_var" - mock_var.description = "" - mock_var.get_variable_type.return_value = variable_type - mock_var.get_selector.return_value = [] - mock_var.value_type = DummyValueType() - mock_var.edited = False - mock_var.visible = True - mock_var.file_id = None - mock_var.variable_file = None - mock_var.is_truncated.return_value = False - mock_var.get_value.return_value.model_copy.return_value.value = "test_value" - return mock_var - - @patch("controllers.console.app.workflow_draft_variable.WorkflowService") - @patch("controllers.console.app.workflow_draft_variable.WorkflowDraftVariableService") - def test_workflow_variable_collection_get_success( - self, mock_draft_srv, mock_wf_srv, app, mock_account, mock_app_model - ): - mock_wf_srv.return_value.is_workflow_exist.return_value = True - from services.workflow_draft_variable_service import WorkflowDraftVariableList - - mock_draft_srv.return_value.list_variables_without_values.return_value = WorkflowDraftVariableList( - variables=[], total=0 - ) - - resp = setup_test_context( - app, - WorkflowVariableCollectionApi, - "/apps/app_123/workflows/draft/variables?page=1&limit=20", - "GET", - mock_account, - mock_app_model, - ) - assert resp == {"items": [], "total": 0} - - @patch("controllers.console.app.workflow_draft_variable.WorkflowService") - def test_workflow_variable_collection_get_not_exist(self, mock_wf_srv, app, mock_account, mock_app_model): - mock_wf_srv.return_value.is_workflow_exist.return_value = False - - with pytest.raises(DraftWorkflowNotExist): - setup_test_context( - app, - WorkflowVariableCollectionApi, - "/apps/app_123/workflows/draft/variables", - "GET", - mock_account, - mock_app_model, - ) - - @patch("controllers.console.app.workflow_draft_variable.WorkflowDraftVariableService") - def test_workflow_variable_collection_delete(self, mock_draft_srv, app, mock_account, mock_app_model): - resp = setup_test_context( - app, - WorkflowVariableCollectionApi, - "/apps/app_123/workflows/draft/variables", - "DELETE", - mock_account, - mock_app_model, - ) - assert resp.status_code == 204 - - @patch("controllers.console.app.workflow_draft_variable.WorkflowDraftVariableService") - def test_node_variable_collection_get_success(self, mock_draft_srv, app, mock_account, mock_app_model): - from services.workflow_draft_variable_service import WorkflowDraftVariableList - - mock_draft_srv.return_value.list_node_variables.return_value = WorkflowDraftVariableList(variables=[]) - resp = setup_test_context( - app, - NodeVariableCollectionApi, - "/apps/app_123/workflows/draft/nodes/node_123/variables", - "GET", - mock_account, - mock_app_model, - ) - assert resp == {"items": []} - - def test_node_variable_collection_get_invalid_node_id(self, app, mock_account, mock_app_model): - with pytest.raises(InvalidArgumentError): - setup_test_context( - app, - NodeVariableCollectionApi, - "/apps/app_123/workflows/draft/nodes/sys/variables", - "GET", - mock_account, - mock_app_model, - ) - - @patch("controllers.console.app.workflow_draft_variable.WorkflowDraftVariableService") - def test_node_variable_collection_delete(self, mock_draft_srv, app, mock_account, mock_app_model): - resp = setup_test_context( - app, - NodeVariableCollectionApi, - "/apps/app_123/workflows/draft/nodes/node_123/variables", - "DELETE", - mock_account, - mock_app_model, - ) - assert resp.status_code == 204 - - @patch("controllers.console.app.workflow_draft_variable.WorkflowDraftVariableService") - def test_variable_api_get_success(self, mock_draft_srv, app, mock_account, mock_app_model): - mock_draft_srv.return_value.get_variable.return_value = self._mock_workflow_variable() - - resp = setup_test_context( - app, VariableApi, "/apps/app_123/workflows/draft/variables/var_123", "GET", mock_account, mock_app_model - ) - assert resp["id"] == "var_123" - - @patch("controllers.console.app.workflow_draft_variable.WorkflowDraftVariableService") - def test_variable_api_get_not_found(self, mock_draft_srv, app, mock_account, mock_app_model): - mock_draft_srv.return_value.get_variable.return_value = None - - with pytest.raises(NotFoundError): - setup_test_context( - app, VariableApi, "/apps/app_123/workflows/draft/variables/var_123", "GET", mock_account, mock_app_model - ) - - @patch("controllers.console.app.workflow_draft_variable.WorkflowDraftVariableService") - def test_variable_api_patch_success(self, mock_draft_srv, app, mock_account, mock_app_model): - mock_draft_srv.return_value.get_variable.return_value = self._mock_workflow_variable() - - resp = setup_test_context( - app, - VariableApi, - "/apps/app_123/workflows/draft/variables/var_123", - "PATCH", - mock_account, - mock_app_model, - payload={"name": "new_name"}, - ) - assert resp["id"] == "var_123" - mock_draft_srv.return_value.update_variable.assert_called_once() - - @patch("controllers.console.app.workflow_draft_variable.WorkflowDraftVariableService") - def test_variable_api_delete_success(self, mock_draft_srv, app, mock_account, mock_app_model): - mock_draft_srv.return_value.get_variable.return_value = self._mock_workflow_variable() - - resp = setup_test_context( - app, VariableApi, "/apps/app_123/workflows/draft/variables/var_123", "DELETE", mock_account, mock_app_model - ) - assert resp.status_code == 204 - mock_draft_srv.return_value.delete_variable.assert_called_once() - - @patch("controllers.console.app.workflow_draft_variable.WorkflowService") - @patch("controllers.console.app.workflow_draft_variable.WorkflowDraftVariableService") - def test_variable_reset_api_put_success(self, mock_draft_srv, mock_wf_srv, app, mock_account, mock_app_model): - mock_wf_srv.return_value.get_draft_workflow.return_value = MagicMock() - mock_draft_srv.return_value.get_variable.return_value = self._mock_workflow_variable() - mock_draft_srv.return_value.reset_variable.return_value = None # means no content - - resp = setup_test_context( - app, - VariableResetApi, - "/apps/app_123/workflows/draft/variables/var_123/reset", - "PUT", - mock_account, - mock_app_model, - ) - assert resp.status_code == 204 - - @patch("controllers.console.app.workflow_draft_variable.WorkflowService") - @patch("controllers.console.app.workflow_draft_variable.WorkflowDraftVariableService") - def test_conversation_variable_collection_get(self, mock_draft_srv, mock_wf_srv, app, mock_account, mock_app_model): - mock_wf_srv.return_value.get_draft_workflow.return_value = MagicMock() - from services.workflow_draft_variable_service import WorkflowDraftVariableList - - mock_draft_srv.return_value.list_conversation_variables.return_value = WorkflowDraftVariableList(variables=[]) - - resp = setup_test_context( - app, - ConversationVariableCollectionApi, - "/apps/app_123/workflows/draft/conversation-variables", - "GET", - mock_account, - mock_app_model, - ) - assert resp == {"items": []} - - @patch("controllers.console.app.workflow_draft_variable.WorkflowDraftVariableService") - def test_system_variable_collection_get(self, mock_draft_srv, app, mock_account, mock_app_model): - from services.workflow_draft_variable_service import WorkflowDraftVariableList - - mock_draft_srv.return_value.list_system_variables.return_value = WorkflowDraftVariableList(variables=[]) - - resp = setup_test_context( - app, - SystemVariableCollectionApi, - "/apps/app_123/workflows/draft/system-variables", - "GET", - mock_account, - mock_app_model, - ) - assert resp == {"items": []} - - @patch("controllers.console.app.workflow_draft_variable.WorkflowService") - def test_environment_variable_collection_get(self, mock_wf_srv, app, mock_account, mock_app_model): - mock_wf = MagicMock() - mock_wf.environment_variables = [] - mock_wf_srv.return_value.get_draft_workflow.return_value = mock_wf - - resp = setup_test_context( - app, - EnvironmentVariableCollectionApi, - "/apps/app_123/workflows/draft/environment-variables", - "GET", - mock_account, - mock_app_model, - ) - assert resp == {"items": []} diff --git a/api/tests/unit_tests/controllers/console/app/test_workflow_pause_details_api.py b/api/tests/unit_tests/controllers/console/app/test_workflow_pause_details_api.py index 83601dc1b9..c4a8148446 100644 --- a/api/tests/unit_tests/controllers/console/app/test_workflow_pause_details_api.py +++ b/api/tests/unit_tests/controllers/console/app/test_workflow_pause_details_api.py @@ -10,10 +10,10 @@ from flask import Flask from controllers.console import wraps as console_wraps from controllers.console.app import workflow_run as workflow_run_module from controllers.web.error import NotFoundError -from dify_graph.entities.pause_reason import HumanInputRequired -from dify_graph.enums import WorkflowExecutionStatus -from dify_graph.nodes.human_input.entities import FormInput, UserAction -from dify_graph.nodes.human_input.enums import FormInputType +from graphon.entities.pause_reason import HumanInputRequired +from graphon.enums import WorkflowExecutionStatus +from graphon.nodes.human_input.entities import FormInput, UserAction +from graphon.nodes.human_input.enums import FormInputType from libs import login as login_lib from models.account import Account, AccountStatus, TenantAccountRole from models.workflow import WorkflowRun @@ -67,7 +67,6 @@ def test_pause_details_returns_backstage_input_url(app: Flask, monkeypatch: pyte actions=[UserAction(id="approve", title="Approve")], node_id="node-1", node_title="Ask Name", - form_token="backstage-token", ) pause_entity = _PauseEntity(paused_at=datetime(2024, 1, 1, 12, 0, 0), reasons=[reason]) @@ -78,6 +77,11 @@ def test_pause_details_returns_backstage_input_url(app: Flask, monkeypatch: pyte "create_api_workflow_run_repository", lambda *_, **__: repo, ) + monkeypatch.setattr( + workflow_run_module, + "_load_form_tokens_by_form_id", + lambda _form_ids: {"form-1": "backstage-token"}, + ) with app.test_request_context("/console/api/workflow/run-1/pause-details", method="GET"): response, status = workflow_run_module.ConsoleWorkflowPauseDetailsApi().get(workflow_run_id="run-1") diff --git a/api/tests/unit_tests/controllers/console/app/test_wraps.py b/api/tests/unit_tests/controllers/console/app/test_wraps.py index 7664e492da..b5f751f5a5 100644 --- a/api/tests/unit_tests/controllers/console/app/test_wraps.py +++ b/api/tests/unit_tests/controllers/console/app/test_wraps.py @@ -11,10 +11,8 @@ from models.model import AppMode def test_get_app_model_injects_model(monkeypatch: pytest.MonkeyPatch) -> None: app_model = SimpleNamespace(id="app-1", mode=AppMode.CHAT.value, status="normal", tenant_id="t1") - query = SimpleNamespace(where=lambda *_args, **_kwargs: query, first=lambda: app_model) - monkeypatch.setattr(wraps_module, "current_account_with_tenant", lambda: (None, "t1")) - monkeypatch.setattr(wraps_module.db, "session", SimpleNamespace(query=lambda *_args, **_kwargs: query)) + monkeypatch.setattr(wraps_module.db, "session", SimpleNamespace(scalar=lambda *_args, **_kwargs: app_model)) @wraps_module.get_app_model def handler(app_model): @@ -25,10 +23,8 @@ def test_get_app_model_injects_model(monkeypatch: pytest.MonkeyPatch) -> None: def test_get_app_model_rejects_wrong_mode(monkeypatch: pytest.MonkeyPatch) -> None: app_model = SimpleNamespace(id="app-1", mode=AppMode.CHAT.value, status="normal", tenant_id="t1") - query = SimpleNamespace(where=lambda *_args, **_kwargs: query, first=lambda: app_model) - monkeypatch.setattr(wraps_module, "current_account_with_tenant", lambda: (None, "t1")) - monkeypatch.setattr(wraps_module.db, "session", SimpleNamespace(query=lambda *_args, **_kwargs: query)) + monkeypatch.setattr(wraps_module.db, "session", SimpleNamespace(scalar=lambda *_args, **_kwargs: app_model)) @wraps_module.get_app_model(mode=[AppMode.COMPLETION]) def handler(app_model): diff --git a/api/tests/unit_tests/controllers/console/app/workflow_draft_variables_test.py b/api/tests/unit_tests/controllers/console/app/workflow_draft_variables_test.py index f34702a257..559b5fea09 100644 --- a/api/tests/unit_tests/controllers/console/app/workflow_draft_variables_test.py +++ b/api/tests/unit_tests/controllers/console/app/workflow_draft_variables_test.py @@ -13,9 +13,9 @@ from controllers.console.app.workflow_draft_variable import ( _WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS, _serialize_full_content, ) -from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID -from dify_graph.variables.types import SegmentType +from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID from factories.variable_factory import build_segment +from graphon.variables.types import SegmentType from libs.datetime_utils import naive_utc_now from libs.uuid_utils import uuidv7 from models.workflow import WorkflowDraftVariable, WorkflowDraftVariableFile @@ -310,13 +310,12 @@ def test_workflow_node_variables_fields(): def test_workflow_file_variable_with_signed_url(): """Test that File type variables include signed URLs in API responses.""" - from dify_graph.file.enums import FileTransferMethod, FileType - from dify_graph.file.models import File + from graphon.file.enums import FileTransferMethod, FileType + from graphon.file.models import File # Create a File object with LOCAL_FILE transfer method (which generates signed URLs) test_file = File( id="test_file_id", - tenant_id="test_tenant_id", type=FileType.IMAGE, transfer_method=FileTransferMethod.LOCAL_FILE, related_id="test_upload_file_id", @@ -368,13 +367,12 @@ def test_workflow_file_variable_with_signed_url(): def test_workflow_file_variable_remote_url(): """Test that File type variables with REMOTE_URL transfer method return the remote URL.""" - from dify_graph.file.enums import FileTransferMethod, FileType - from dify_graph.file.models import File + from graphon.file.enums import FileTransferMethod, FileType + from graphon.file.models import File # Create a File object with REMOTE_URL transfer method test_file = File( id="test_file_id", - tenant_id="test_tenant_id", type=FileType.IMAGE, transfer_method=FileTransferMethod.REMOTE_URL, remote_url="https://example.com/test.jpg", diff --git a/api/tests/unit_tests/controllers/console/auth/test_data_source_bearer_auth.py b/api/tests/unit_tests/controllers/console/auth/test_data_source_bearer_auth.py deleted file mode 100644 index bc4c7e0993..0000000000 --- a/api/tests/unit_tests/controllers/console/auth/test_data_source_bearer_auth.py +++ /dev/null @@ -1,209 +0,0 @@ -from unittest.mock import MagicMock, patch - -import pytest -from flask import Flask - -from controllers.console.auth.data_source_bearer_auth import ( - ApiKeyAuthDataSource, - ApiKeyAuthDataSourceBinding, - ApiKeyAuthDataSourceBindingDelete, -) -from controllers.console.auth.error import ApiKeyAuthFailedError - - -class TestApiKeyAuthDataSource: - @pytest.fixture - def app(self): - app = Flask(__name__) - app.config["TESTING"] = True - app.config["WTF_CSRF_ENABLED"] = False - return app - - @patch("libs.login.check_csrf_token") - @patch("controllers.console.wraps.db") - @patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.get_provider_auth_list") - def test_get_api_key_auth_data_source(self, mock_get_list, mock_db, mock_csrf, app): - from models.account import Account, AccountStatus - - mock_account = MagicMock(spec=Account) - mock_account.id = "user_123" - mock_account.status = AccountStatus.ACTIVE - mock_account.is_admin_or_owner = True - mock_account.current_tenant.current_role = "owner" - - mock_binding = MagicMock() - mock_binding.id = "bind_123" - mock_binding.category = "api_key" - mock_binding.provider = "custom_provider" - mock_binding.disabled = False - mock_binding.created_at.timestamp.return_value = 1620000000 - mock_binding.updated_at.timestamp.return_value = 1620000001 - - mock_get_list.return_value = [mock_binding] - - with ( - patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")), - patch( - "controllers.console.auth.data_source_bearer_auth.current_account_with_tenant", - return_value=(mock_account, "tenant_123"), - ), - ): - with app.test_request_context("/console/api/api-key-auth/data-source", method="GET"): - proxy_mock = MagicMock() - proxy_mock._get_current_object.return_value = mock_account - with patch("libs.login.current_user", proxy_mock): - api_instance = ApiKeyAuthDataSource() - response = api_instance.get() - - assert "sources" in response - assert len(response["sources"]) == 1 - assert response["sources"][0]["provider"] == "custom_provider" - - @patch("libs.login.check_csrf_token") - @patch("controllers.console.wraps.db") - @patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.get_provider_auth_list") - def test_get_api_key_auth_data_source_empty(self, mock_get_list, mock_db, mock_csrf, app): - from models.account import Account, AccountStatus - - mock_account = MagicMock(spec=Account) - mock_account.id = "user_123" - mock_account.status = AccountStatus.ACTIVE - mock_account.is_admin_or_owner = True - mock_account.current_tenant.current_role = "owner" - - mock_get_list.return_value = None - - with ( - patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")), - patch( - "controllers.console.auth.data_source_bearer_auth.current_account_with_tenant", - return_value=(mock_account, "tenant_123"), - ), - ): - with app.test_request_context("/console/api/api-key-auth/data-source", method="GET"): - proxy_mock = MagicMock() - proxy_mock._get_current_object.return_value = mock_account - with patch("libs.login.current_user", proxy_mock): - api_instance = ApiKeyAuthDataSource() - response = api_instance.get() - - assert "sources" in response - assert len(response["sources"]) == 0 - - -class TestApiKeyAuthDataSourceBinding: - @pytest.fixture - def app(self): - app = Flask(__name__) - app.config["TESTING"] = True - app.config["WTF_CSRF_ENABLED"] = False - return app - - @patch("libs.login.check_csrf_token") - @patch("controllers.console.wraps.db") - @patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.create_provider_auth") - @patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.validate_api_key_auth_args") - def test_create_binding_successful(self, mock_validate, mock_create, mock_db, mock_csrf, app): - from models.account import Account, AccountStatus - - mock_account = MagicMock(spec=Account) - mock_account.id = "user_123" - mock_account.status = AccountStatus.ACTIVE - mock_account.is_admin_or_owner = True - mock_account.current_tenant.current_role = "owner" - - with ( - patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")), - patch( - "controllers.console.auth.data_source_bearer_auth.current_account_with_tenant", - return_value=(mock_account, "tenant_123"), - ), - ): - with app.test_request_context( - "/console/api/api-key-auth/data-source/binding", - method="POST", - json={"category": "api_key", "provider": "custom", "credentials": {"key": "value"}}, - ): - proxy_mock = MagicMock() - proxy_mock._get_current_object.return_value = mock_account - with patch("libs.login.current_user", proxy_mock), patch("flask_login.current_user", proxy_mock): - api_instance = ApiKeyAuthDataSourceBinding() - response = api_instance.post() - - assert response[0]["result"] == "success" - assert response[1] == 200 - mock_validate.assert_called_once() - mock_create.assert_called_once() - - @patch("libs.login.check_csrf_token") - @patch("controllers.console.wraps.db") - @patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.create_provider_auth") - @patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.validate_api_key_auth_args") - def test_create_binding_failure(self, mock_validate, mock_create, mock_db, mock_csrf, app): - from models.account import Account, AccountStatus - - mock_account = MagicMock(spec=Account) - mock_account.id = "user_123" - mock_account.status = AccountStatus.ACTIVE - mock_account.is_admin_or_owner = True - mock_account.current_tenant.current_role = "owner" - - mock_create.side_effect = ValueError("Invalid structure") - - with ( - patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")), - patch( - "controllers.console.auth.data_source_bearer_auth.current_account_with_tenant", - return_value=(mock_account, "tenant_123"), - ), - ): - with app.test_request_context( - "/console/api/api-key-auth/data-source/binding", - method="POST", - json={"category": "api_key", "provider": "custom", "credentials": {"key": "value"}}, - ): - proxy_mock = MagicMock() - proxy_mock._get_current_object.return_value = mock_account - with patch("libs.login.current_user", proxy_mock), patch("flask_login.current_user", proxy_mock): - api_instance = ApiKeyAuthDataSourceBinding() - with pytest.raises(ApiKeyAuthFailedError, match="Invalid structure"): - api_instance.post() - - -class TestApiKeyAuthDataSourceBindingDelete: - @pytest.fixture - def app(self): - app = Flask(__name__) - app.config["TESTING"] = True - app.config["WTF_CSRF_ENABLED"] = False - return app - - @patch("libs.login.check_csrf_token") - @patch("controllers.console.wraps.db") - @patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.delete_provider_auth") - def test_delete_binding_successful(self, mock_delete, mock_db, mock_csrf, app): - from models.account import Account, AccountStatus - - mock_account = MagicMock(spec=Account) - mock_account.id = "user_123" - mock_account.status = AccountStatus.ACTIVE - mock_account.is_admin_or_owner = True - mock_account.current_tenant.current_role = "owner" - - with ( - patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")), - patch( - "controllers.console.auth.data_source_bearer_auth.current_account_with_tenant", - return_value=(mock_account, "tenant_123"), - ), - ): - with app.test_request_context("/console/api/api-key-auth/data-source/binding_123", method="DELETE"): - proxy_mock = MagicMock() - proxy_mock._get_current_object.return_value = mock_account - with patch("libs.login.current_user", proxy_mock), patch("flask_login.current_user", proxy_mock): - api_instance = ApiKeyAuthDataSourceBindingDelete() - response = api_instance.delete("binding_123") - - assert response[0]["result"] == "success" - assert response[1] == 204 - mock_delete.assert_called_once_with("tenant_123", "binding_123") diff --git a/api/tests/unit_tests/controllers/console/auth/test_data_source_oauth.py b/api/tests/unit_tests/controllers/console/auth/test_data_source_oauth.py deleted file mode 100644 index f369565946..0000000000 --- a/api/tests/unit_tests/controllers/console/auth/test_data_source_oauth.py +++ /dev/null @@ -1,192 +0,0 @@ -from unittest.mock import MagicMock, patch - -import pytest -from flask import Flask -from werkzeug.local import LocalProxy - -from controllers.console.auth.data_source_oauth import ( - OAuthDataSource, - OAuthDataSourceBinding, - OAuthDataSourceCallback, - OAuthDataSourceSync, -) - - -class TestOAuthDataSource: - @pytest.fixture - def app(self): - app = Flask(__name__) - app.config["TESTING"] = True - return app - - @patch("controllers.console.auth.data_source_oauth.get_oauth_providers") - @patch("flask_login.current_user") - @patch("libs.login.current_user") - @patch("libs.login.check_csrf_token") - @patch("controllers.console.wraps.db") - @patch("controllers.console.auth.data_source_oauth.dify_config.NOTION_INTEGRATION_TYPE", None) - def test_get_oauth_url_successful( - self, mock_db, mock_csrf, mock_libs_user, mock_flask_user, mock_get_providers, app - ): - mock_oauth_provider = MagicMock() - mock_oauth_provider.get_authorization_url.return_value = "http://oauth.provider/auth" - mock_get_providers.return_value = {"notion": mock_oauth_provider} - - from models.account import Account, AccountStatus - - mock_account = MagicMock(spec=Account) - mock_account.id = "user_123" - mock_account.status = AccountStatus.ACTIVE - mock_account.is_admin_or_owner = True - mock_account.current_tenant.current_role = "owner" - mock_libs_user.return_value = mock_account - mock_flask_user.return_value = mock_account - - # also patch current_account_with_tenant - with patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, MagicMock())): - with app.test_request_context("/console/api/oauth/data-source/notion", method="GET"): - proxy_mock = LocalProxy(lambda: mock_account) - with patch("libs.login.current_user", proxy_mock): - api_instance = OAuthDataSource() - response = api_instance.get("notion") - - assert response[0]["data"] == "http://oauth.provider/auth" - assert response[1] == 200 - mock_oauth_provider.get_authorization_url.assert_called_once() - - @patch("controllers.console.auth.data_source_oauth.get_oauth_providers") - @patch("flask_login.current_user") - @patch("libs.login.check_csrf_token") - @patch("controllers.console.wraps.db") - def test_get_oauth_url_invalid_provider(self, mock_db, mock_csrf, mock_flask_user, mock_get_providers, app): - mock_get_providers.return_value = {"notion": MagicMock()} - - from models.account import Account, AccountStatus - - mock_account = MagicMock(spec=Account) - mock_account.id = "user_123" - mock_account.status = AccountStatus.ACTIVE - mock_account.is_admin_or_owner = True - mock_account.current_tenant.current_role = "owner" - - with patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, MagicMock())): - with app.test_request_context("/console/api/oauth/data-source/unknown_provider", method="GET"): - proxy_mock = LocalProxy(lambda: mock_account) - with patch("libs.login.current_user", proxy_mock): - api_instance = OAuthDataSource() - response = api_instance.get("unknown_provider") - - assert response[0]["error"] == "Invalid provider" - assert response[1] == 400 - - -class TestOAuthDataSourceCallback: - @pytest.fixture - def app(self): - app = Flask(__name__) - app.config["TESTING"] = True - return app - - @patch("controllers.console.auth.data_source_oauth.get_oauth_providers") - def test_oauth_callback_successful(self, mock_get_providers, app): - provider_mock = MagicMock() - mock_get_providers.return_value = {"notion": provider_mock} - - with app.test_request_context("/console/api/oauth/data-source/notion/callback?code=mock_code", method="GET"): - api_instance = OAuthDataSourceCallback() - response = api_instance.get("notion") - - assert response.status_code == 302 - assert "code=mock_code" in response.location - - @patch("controllers.console.auth.data_source_oauth.get_oauth_providers") - def test_oauth_callback_missing_code(self, mock_get_providers, app): - provider_mock = MagicMock() - mock_get_providers.return_value = {"notion": provider_mock} - - with app.test_request_context("/console/api/oauth/data-source/notion/callback", method="GET"): - api_instance = OAuthDataSourceCallback() - response = api_instance.get("notion") - - assert response.status_code == 302 - assert "error=Access denied" in response.location - - @patch("controllers.console.auth.data_source_oauth.get_oauth_providers") - def test_oauth_callback_invalid_provider(self, mock_get_providers, app): - mock_get_providers.return_value = {"notion": MagicMock()} - - with app.test_request_context("/console/api/oauth/data-source/invalid/callback?code=mock_code", method="GET"): - api_instance = OAuthDataSourceCallback() - response = api_instance.get("invalid") - - assert response[0]["error"] == "Invalid provider" - assert response[1] == 400 - - -class TestOAuthDataSourceBinding: - @pytest.fixture - def app(self): - app = Flask(__name__) - app.config["TESTING"] = True - return app - - @patch("controllers.console.auth.data_source_oauth.get_oauth_providers") - def test_get_binding_successful(self, mock_get_providers, app): - mock_provider = MagicMock() - mock_provider.get_access_token.return_value = None - mock_get_providers.return_value = {"notion": mock_provider} - - with app.test_request_context("/console/api/oauth/data-source/notion/binding?code=auth_code_123", method="GET"): - api_instance = OAuthDataSourceBinding() - response = api_instance.get("notion") - - assert response[0]["result"] == "success" - assert response[1] == 200 - mock_provider.get_access_token.assert_called_once_with("auth_code_123") - - @patch("controllers.console.auth.data_source_oauth.get_oauth_providers") - def test_get_binding_missing_code(self, mock_get_providers, app): - mock_get_providers.return_value = {"notion": MagicMock()} - - with app.test_request_context("/console/api/oauth/data-source/notion/binding?code=", method="GET"): - api_instance = OAuthDataSourceBinding() - response = api_instance.get("notion") - - assert response[0]["error"] == "Invalid code" - assert response[1] == 400 - - -class TestOAuthDataSourceSync: - @pytest.fixture - def app(self): - app = Flask(__name__) - app.config["TESTING"] = True - return app - - @patch("controllers.console.auth.data_source_oauth.get_oauth_providers") - @patch("libs.login.check_csrf_token") - @patch("controllers.console.wraps.db") - def test_sync_successful(self, mock_db, mock_csrf, mock_get_providers, app): - mock_provider = MagicMock() - mock_provider.sync_data_source.return_value = None - mock_get_providers.return_value = {"notion": mock_provider} - - from models.account import Account, AccountStatus - - mock_account = MagicMock(spec=Account) - mock_account.id = "user_123" - mock_account.status = AccountStatus.ACTIVE - mock_account.is_admin_or_owner = True - mock_account.current_tenant.current_role = "owner" - - with patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, MagicMock())): - with app.test_request_context("/console/api/oauth/data-source/notion/binding_123/sync", method="GET"): - proxy_mock = LocalProxy(lambda: mock_account) - with patch("libs.login.current_user", proxy_mock): - api_instance = OAuthDataSourceSync() - # The route pattern uses , so we just pass a string for unit testing - response = api_instance.get("notion", "binding_123") - - assert response[0]["result"] == "success" - assert response[1] == 200 - mock_provider.sync_data_source.assert_called_once_with("binding_123") diff --git a/api/tests/unit_tests/controllers/console/auth/test_oauth_server.py b/api/tests/unit_tests/controllers/console/auth/test_oauth_server.py deleted file mode 100644 index fc5663e72d..0000000000 --- a/api/tests/unit_tests/controllers/console/auth/test_oauth_server.py +++ /dev/null @@ -1,417 +0,0 @@ -from unittest.mock import MagicMock, patch - -import pytest -from flask import Flask -from werkzeug.exceptions import BadRequest, NotFound - -from controllers.console.auth.oauth_server import ( - OAuthServerAppApi, - OAuthServerUserAccountApi, - OAuthServerUserAuthorizeApi, - OAuthServerUserTokenApi, -) - - -class TestOAuthServerAppApi: - @pytest.fixture - def app(self): - app = Flask(__name__) - app.config["TESTING"] = True - return app - - @pytest.fixture - def mock_oauth_provider_app(self): - from models.model import OAuthProviderApp - - oauth_app = MagicMock(spec=OAuthProviderApp) - oauth_app.client_id = "test_client_id" - oauth_app.redirect_uris = ["http://localhost/callback"] - oauth_app.app_icon = "icon_url" - oauth_app.app_label = "Test App" - oauth_app.scope = "read,write" - return oauth_app - - @patch("controllers.console.wraps.db") - @patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app") - def test_successful_post(self, mock_get_app, mock_db, app, mock_oauth_provider_app): - mock_db.session.query.return_value.first.return_value = MagicMock() - mock_get_app.return_value = mock_oauth_provider_app - - with app.test_request_context( - "/oauth/provider", - method="POST", - json={"client_id": "test_client_id", "redirect_uri": "http://localhost/callback"}, - ): - api_instance = OAuthServerAppApi() - response = api_instance.post() - - assert response["app_icon"] == "icon_url" - assert response["app_label"] == "Test App" - assert response["scope"] == "read,write" - - @patch("controllers.console.wraps.db") - @patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app") - def test_invalid_redirect_uri(self, mock_get_app, mock_db, app, mock_oauth_provider_app): - mock_db.session.query.return_value.first.return_value = MagicMock() - mock_get_app.return_value = mock_oauth_provider_app - - with app.test_request_context( - "/oauth/provider", - method="POST", - json={"client_id": "test_client_id", "redirect_uri": "http://invalid/callback"}, - ): - api_instance = OAuthServerAppApi() - with pytest.raises(BadRequest, match="redirect_uri is invalid"): - api_instance.post() - - @patch("controllers.console.wraps.db") - @patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app") - def test_invalid_client_id(self, mock_get_app, mock_db, app): - mock_db.session.query.return_value.first.return_value = MagicMock() - mock_get_app.return_value = None - - with app.test_request_context( - "/oauth/provider", - method="POST", - json={"client_id": "test_invalid_client_id", "redirect_uri": "http://localhost/callback"}, - ): - api_instance = OAuthServerAppApi() - with pytest.raises(NotFound, match="client_id is invalid"): - api_instance.post() - - -class TestOAuthServerUserAuthorizeApi: - @pytest.fixture - def app(self): - app = Flask(__name__) - app.config["TESTING"] = True - return app - - @pytest.fixture - def mock_oauth_provider_app(self): - oauth_app = MagicMock() - oauth_app.client_id = "test_client_id" - return oauth_app - - @patch("controllers.console.wraps.db") - @patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app") - @patch("controllers.console.auth.oauth_server.current_account_with_tenant") - @patch("controllers.console.wraps.current_account_with_tenant") - @patch("controllers.console.auth.oauth_server.OAuthServerService.sign_oauth_authorization_code") - @patch("libs.login.check_csrf_token") - def test_successful_authorize( - self, mock_csrf, mock_sign, mock_wrap_current, mock_current, mock_get_app, mock_db, app, mock_oauth_provider_app - ): - mock_db.session.query.return_value.first.return_value = MagicMock() - mock_get_app.return_value = mock_oauth_provider_app - - mock_account = MagicMock() - mock_account.id = "user_123" - from models.account import AccountStatus - - mock_account.status = AccountStatus.ACTIVE - - mock_current.return_value = (mock_account, MagicMock()) - mock_wrap_current.return_value = (mock_account, MagicMock()) - - mock_sign.return_value = "auth_code_123" - - with app.test_request_context("/oauth/provider/authorize", method="POST", json={"client_id": "test_client_id"}): - with patch("libs.login.current_user", mock_account): - api_instance = OAuthServerUserAuthorizeApi() - response = api_instance.post() - - assert response["code"] == "auth_code_123" - mock_sign.assert_called_once_with("test_client_id", "user_123") - - -class TestOAuthServerUserTokenApi: - @pytest.fixture - def app(self): - app = Flask(__name__) - app.config["TESTING"] = True - return app - - @pytest.fixture - def mock_oauth_provider_app(self): - from models.model import OAuthProviderApp - - oauth_app = MagicMock(spec=OAuthProviderApp) - oauth_app.client_id = "test_client_id" - oauth_app.client_secret = "test_secret" - oauth_app.redirect_uris = ["http://localhost/callback"] - return oauth_app - - @patch("controllers.console.wraps.db") - @patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app") - @patch("controllers.console.auth.oauth_server.OAuthServerService.sign_oauth_access_token") - def test_authorization_code_grant(self, mock_sign, mock_get_app, mock_db, app, mock_oauth_provider_app): - mock_db.session.query.return_value.first.return_value = MagicMock() - mock_get_app.return_value = mock_oauth_provider_app - mock_sign.return_value = ("access_123", "refresh_123") - - with app.test_request_context( - "/oauth/provider/token", - method="POST", - json={ - "client_id": "test_client_id", - "grant_type": "authorization_code", - "code": "auth_code", - "client_secret": "test_secret", - "redirect_uri": "http://localhost/callback", - }, - ): - api_instance = OAuthServerUserTokenApi() - response = api_instance.post() - - assert response["access_token"] == "access_123" - assert response["refresh_token"] == "refresh_123" - assert response["token_type"] == "Bearer" - - @patch("controllers.console.wraps.db") - @patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app") - def test_authorization_code_grant_missing_code(self, mock_get_app, mock_db, app, mock_oauth_provider_app): - mock_db.session.query.return_value.first.return_value = MagicMock() - mock_get_app.return_value = mock_oauth_provider_app - - with app.test_request_context( - "/oauth/provider/token", - method="POST", - json={ - "client_id": "test_client_id", - "grant_type": "authorization_code", - "client_secret": "test_secret", - "redirect_uri": "http://localhost/callback", - }, - ): - api_instance = OAuthServerUserTokenApi() - with pytest.raises(BadRequest, match="code is required"): - api_instance.post() - - @patch("controllers.console.wraps.db") - @patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app") - def test_authorization_code_grant_invalid_secret(self, mock_get_app, mock_db, app, mock_oauth_provider_app): - mock_db.session.query.return_value.first.return_value = MagicMock() - mock_get_app.return_value = mock_oauth_provider_app - - with app.test_request_context( - "/oauth/provider/token", - method="POST", - json={ - "client_id": "test_client_id", - "grant_type": "authorization_code", - "code": "auth_code", - "client_secret": "invalid_secret", - "redirect_uri": "http://localhost/callback", - }, - ): - api_instance = OAuthServerUserTokenApi() - with pytest.raises(BadRequest, match="client_secret is invalid"): - api_instance.post() - - @patch("controllers.console.wraps.db") - @patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app") - def test_authorization_code_grant_invalid_redirect_uri(self, mock_get_app, mock_db, app, mock_oauth_provider_app): - mock_db.session.query.return_value.first.return_value = MagicMock() - mock_get_app.return_value = mock_oauth_provider_app - - with app.test_request_context( - "/oauth/provider/token", - method="POST", - json={ - "client_id": "test_client_id", - "grant_type": "authorization_code", - "code": "auth_code", - "client_secret": "test_secret", - "redirect_uri": "http://invalid/callback", - }, - ): - api_instance = OAuthServerUserTokenApi() - with pytest.raises(BadRequest, match="redirect_uri is invalid"): - api_instance.post() - - @patch("controllers.console.wraps.db") - @patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app") - @patch("controllers.console.auth.oauth_server.OAuthServerService.sign_oauth_access_token") - def test_refresh_token_grant(self, mock_sign, mock_get_app, mock_db, app, mock_oauth_provider_app): - mock_db.session.query.return_value.first.return_value = MagicMock() - mock_get_app.return_value = mock_oauth_provider_app - mock_sign.return_value = ("new_access", "new_refresh") - - with app.test_request_context( - "/oauth/provider/token", - method="POST", - json={"client_id": "test_client_id", "grant_type": "refresh_token", "refresh_token": "refresh_123"}, - ): - api_instance = OAuthServerUserTokenApi() - response = api_instance.post() - - assert response["access_token"] == "new_access" - assert response["refresh_token"] == "new_refresh" - - @patch("controllers.console.wraps.db") - @patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app") - def test_refresh_token_grant_missing_token(self, mock_get_app, mock_db, app, mock_oauth_provider_app): - mock_db.session.query.return_value.first.return_value = MagicMock() - mock_get_app.return_value = mock_oauth_provider_app - - with app.test_request_context( - "/oauth/provider/token", - method="POST", - json={ - "client_id": "test_client_id", - "grant_type": "refresh_token", - }, - ): - api_instance = OAuthServerUserTokenApi() - with pytest.raises(BadRequest, match="refresh_token is required"): - api_instance.post() - - @patch("controllers.console.wraps.db") - @patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app") - def test_invalid_grant_type(self, mock_get_app, mock_db, app, mock_oauth_provider_app): - mock_db.session.query.return_value.first.return_value = MagicMock() - mock_get_app.return_value = mock_oauth_provider_app - - with app.test_request_context( - "/oauth/provider/token", - method="POST", - json={ - "client_id": "test_client_id", - "grant_type": "invalid_grant", - }, - ): - api_instance = OAuthServerUserTokenApi() - with pytest.raises(BadRequest, match="invalid grant_type"): - api_instance.post() - - -class TestOAuthServerUserAccountApi: - @pytest.fixture - def app(self): - app = Flask(__name__) - app.config["TESTING"] = True - return app - - @pytest.fixture - def mock_oauth_provider_app(self): - from models.model import OAuthProviderApp - - oauth_app = MagicMock(spec=OAuthProviderApp) - oauth_app.client_id = "test_client_id" - return oauth_app - - @patch("controllers.console.wraps.db") - @patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app") - @patch("controllers.console.auth.oauth_server.OAuthServerService.validate_oauth_access_token") - def test_successful_account_retrieval(self, mock_validate, mock_get_app, mock_db, app, mock_oauth_provider_app): - mock_db.session.query.return_value.first.return_value = MagicMock() - mock_get_app.return_value = mock_oauth_provider_app - - mock_account = MagicMock() - mock_account.name = "Test User" - mock_account.email = "test@example.com" - mock_account.avatar = "avatar_url" - mock_account.interface_language = "en-US" - mock_account.timezone = "UTC" - mock_validate.return_value = mock_account - - with app.test_request_context( - "/oauth/provider/account", - method="POST", - json={"client_id": "test_client_id"}, - headers={"Authorization": "Bearer valid_access_token"}, - ): - api_instance = OAuthServerUserAccountApi() - response = api_instance.post() - - assert response["name"] == "Test User" - assert response["email"] == "test@example.com" - assert response["avatar"] == "avatar_url" - - @patch("controllers.console.wraps.db") - @patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app") - def test_missing_authorization_header(self, mock_get_app, mock_db, app, mock_oauth_provider_app): - mock_db.session.query.return_value.first.return_value = MagicMock() - mock_get_app.return_value = mock_oauth_provider_app - - with app.test_request_context("/oauth/provider/account", method="POST", json={"client_id": "test_client_id"}): - api_instance = OAuthServerUserAccountApi() - response = api_instance.post() - - assert response.status_code == 401 - assert response.json["error"] == "Authorization header is required" - - @patch("controllers.console.wraps.db") - @patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app") - def test_invalid_authorization_header_format(self, mock_get_app, mock_db, app, mock_oauth_provider_app): - mock_db.session.query.return_value.first.return_value = MagicMock() - mock_get_app.return_value = mock_oauth_provider_app - - with app.test_request_context( - "/oauth/provider/account", - method="POST", - json={"client_id": "test_client_id"}, - headers={"Authorization": "InvalidFormat"}, - ): - api_instance = OAuthServerUserAccountApi() - response = api_instance.post() - - assert response.status_code == 401 - assert response.json["error"] == "Invalid Authorization header format" - - @patch("controllers.console.wraps.db") - @patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app") - def test_invalid_token_type(self, mock_get_app, mock_db, app, mock_oauth_provider_app): - mock_db.session.query.return_value.first.return_value = MagicMock() - mock_get_app.return_value = mock_oauth_provider_app - - with app.test_request_context( - "/oauth/provider/account", - method="POST", - json={"client_id": "test_client_id"}, - headers={"Authorization": "Basic something"}, - ): - api_instance = OAuthServerUserAccountApi() - response = api_instance.post() - - assert response.status_code == 401 - assert response.json["error"] == "token_type is invalid" - - @patch("controllers.console.wraps.db") - @patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app") - def test_missing_access_token(self, mock_get_app, mock_db, app, mock_oauth_provider_app): - mock_db.session.query.return_value.first.return_value = MagicMock() - mock_get_app.return_value = mock_oauth_provider_app - - with app.test_request_context( - "/oauth/provider/account", - method="POST", - json={"client_id": "test_client_id"}, - headers={"Authorization": "Bearer "}, - ): - api_instance = OAuthServerUserAccountApi() - response = api_instance.post() - - assert response.status_code == 401 - assert response.json["error"] == "Invalid Authorization header format" - - @patch("controllers.console.wraps.db") - @patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app") - @patch("controllers.console.auth.oauth_server.OAuthServerService.validate_oauth_access_token") - def test_invalid_access_token(self, mock_validate, mock_get_app, mock_db, app, mock_oauth_provider_app): - mock_db.session.query.return_value.first.return_value = MagicMock() - mock_get_app.return_value = mock_oauth_provider_app - mock_validate.return_value = None - - with app.test_request_context( - "/oauth/provider/account", - method="POST", - json={"client_id": "test_client_id"}, - headers={"Authorization": "Bearer invalid_token"}, - ): - api_instance = OAuthServerUserAccountApi() - response = api_instance.post() - - assert response.status_code == 401 - assert response.json["error"] == "access_token or client_id is invalid" diff --git a/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_datasource_auth.py b/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_datasource_auth.py index 9014edc39e..5136922e88 100644 --- a/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_datasource_auth.py +++ b/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_datasource_auth.py @@ -17,7 +17,7 @@ from controllers.console.datasets.rag_pipeline.datasource_auth import ( DatasourceUpdateProviderNameApi, ) from core.plugin.impl.oauth import OAuthHandler -from dify_graph.model_runtime.errors.validate import CredentialsValidateFailedError +from graphon.model_runtime.errors.validate import CredentialsValidateFailedError from services.datasource_provider_service import DatasourceProviderService from services.plugin.oauth_service import OAuthProxyService diff --git a/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_draft_variable.py b/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_draft_variable.py index b4c0903f63..63950736c5 100644 --- a/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_draft_variable.py +++ b/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_draft_variable.py @@ -14,8 +14,8 @@ from controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable impor RagPipelineVariableResetApi, ) from controllers.web.error import InvalidArgumentError, NotFoundError -from dify_graph.constants import SYSTEM_VARIABLE_NODE_ID -from dify_graph.variables.types import SegmentType +from core.workflow.variable_prefixes import SYSTEM_VARIABLE_NODE_ID +from graphon.variables.types import SegmentType from models.account import Account diff --git a/api/tests/unit_tests/controllers/console/datasets/test_data_source.py b/api/tests/unit_tests/controllers/console/datasets/test_data_source.py index 3060062adf..d841f67f9b 100644 --- a/api/tests/unit_tests/controllers/console/datasets/test_data_source.py +++ b/api/tests/unit_tests/controllers/console/datasets/test_data_source.py @@ -11,6 +11,7 @@ from controllers.console.datasets.data_source import ( DataSourceNotionDocumentSyncApi, DataSourceNotionListApi, ) +from core.rag.index_processor.constant.index_type import IndexStructureType def unwrap(func): @@ -343,7 +344,7 @@ class TestDataSourceNotionApi: } ], "process_rule": {"rules": {}}, - "doc_form": "text_model", + "doc_form": IndexStructureType.PARAGRAPH_INDEX, "doc_language": "English", } diff --git a/api/tests/unit_tests/controllers/console/datasets/test_datasets.py b/api/tests/unit_tests/controllers/console/datasets/test_datasets.py index 0ee76e504b..8555900f4e 100644 --- a/api/tests/unit_tests/controllers/console/datasets/test_datasets.py +++ b/api/tests/unit_tests/controllers/console/datasets/test_datasets.py @@ -28,6 +28,7 @@ from controllers.console.datasets.datasets import ( from controllers.console.datasets.error import DatasetInUseError, DatasetNameDuplicateError, IndexingEstimateError from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError from core.provider_manager import ProviderManager +from core.rag.index_processor.constant.index_type import IndexStructureType from extensions.storage.storage_type import StorageType from models.enums import CreatorUserRole from models.model import ApiToken, UploadFile @@ -416,7 +417,7 @@ class TestDatasetApiGet: "check_dataset_permission", return_value=None, ), - patch("controllers.console.datasets.datasets.ProviderManager") as provider_manager_mock, + patch("controllers.console.datasets.datasets.create_plugin_provider_manager") as provider_manager_mock, ): # embedding models exist → embedding_available stays True provider_manager_mock.return_value.get_configurations.return_value.get_models.return_value = [] @@ -520,7 +521,7 @@ class TestDatasetApiGet: "check_dataset_permission", return_value=None, ), - patch("controllers.console.datasets.datasets.ProviderManager") as provider_manager_mock, + patch("controllers.console.datasets.datasets.create_plugin_provider_manager") as provider_manager_mock, ): # embedding model NOT configured provider_manager_mock.return_value.get_configurations.return_value.get_models.return_value = [] @@ -579,7 +580,7 @@ class TestDatasetApiGet: "get_dataset_partial_member_list", return_value=partial_members, ), - patch("controllers.console.datasets.datasets.ProviderManager") as provider_manager_mock, + patch("controllers.console.datasets.datasets.create_plugin_provider_manager") as provider_manager_mock, ): provider_manager_mock.return_value.get_configurations.return_value.get_models.return_value = [] @@ -1146,7 +1147,7 @@ class TestDatasetIndexingEstimateApi: }, "process_rule": {"chunk_size": 100}, "indexing_technique": "high_quality", - "doc_form": "text_model", + "doc_form": IndexStructureType.PARAGRAPH_INDEX, "doc_language": "English", "dataset_id": None, } @@ -1475,8 +1476,8 @@ class TestDatasetIndexingStatusApi: return_value=MagicMock(all=lambda: [document]), ), patch( - "controllers.console.datasets.datasets.db.session.query", - return_value=MagicMock(where=lambda *args, **kwargs: MagicMock(count=lambda: 3)), + "controllers.console.datasets.datasets.db.session.scalar", + return_value=3, ), ): response, status = method(api, "dataset-1") @@ -1525,13 +1526,6 @@ class TestDatasetIndexingStatusApi: document.error = None document.stopped_at = None - # First count = completed segments, second = total segments - query_mock = MagicMock() - query_mock.where.side_effect = [ - MagicMock(count=lambda: 2), - MagicMock(count=lambda: 5), - ] - with ( app.test_request_context("/"), patch( @@ -1543,8 +1537,8 @@ class TestDatasetIndexingStatusApi: return_value=MagicMock(all=lambda: [document]), ), patch( - "controllers.console.datasets.datasets.db.session.query", - return_value=query_mock, + "controllers.console.datasets.datasets.db.session.scalar", + side_effect=[2, 5], ), ): response, status = method(api, "dataset-1") @@ -1590,8 +1584,8 @@ class TestDatasetApiKeyApi: return_value=(MagicMock(), "tenant-1"), ), patch( - "controllers.console.datasets.datasets.db.session.query", - return_value=MagicMock(where=lambda *args, **kwargs: MagicMock(count=lambda: 3)), + "controllers.console.datasets.datasets.db.session.scalar", + return_value=3, ), patch( "controllers.console.datasets.datasets.ApiToken.generate_api_key", @@ -1624,8 +1618,8 @@ class TestDatasetApiKeyApi: return_value=(MagicMock(), "tenant-1"), ), patch( - "controllers.console.datasets.datasets.db.session.query", - return_value=MagicMock(where=lambda *args, **kwargs: MagicMock(count=lambda: 10)), + "controllers.console.datasets.datasets.db.session.scalar", + return_value=10, ), ): with pytest.raises(BadRequest) as exc_info: @@ -1652,8 +1646,8 @@ class TestDatasetApiDeleteApi: return_value=(MagicMock(), "tenant-1"), ), patch( - "controllers.console.datasets.datasets.db.session.query", - return_value=MagicMock(where=lambda *args, **kwargs: MagicMock(first=lambda: mock_key)), + "controllers.console.datasets.datasets.db.session.scalar", + return_value=mock_key, ), patch( "controllers.console.datasets.datasets.db.session.commit", @@ -1680,8 +1674,8 @@ class TestDatasetApiDeleteApi: return_value=(MagicMock(), "tenant-1"), ), patch( - "controllers.console.datasets.datasets.db.session.query", - return_value=MagicMock(where=lambda *args, **kwargs: MagicMock(first=lambda: None)), + "controllers.console.datasets.datasets.db.session.scalar", + return_value=None, ), ): with pytest.raises(NotFound): diff --git a/api/tests/unit_tests/controllers/console/datasets/test_datasets_document.py b/api/tests/unit_tests/controllers/console/datasets/test_datasets_document.py index f23dd5b44a..ce2278de4f 100644 --- a/api/tests/unit_tests/controllers/console/datasets/test_datasets_document.py +++ b/api/tests/unit_tests/controllers/console/datasets/test_datasets_document.py @@ -30,6 +30,7 @@ from controllers.console.datasets.error import ( InvalidActionError, InvalidMetadataError, ) +from core.rag.index_processor.constant.index_type import IndexStructureType from models.enums import DataSourceType, IndexingStatus @@ -66,7 +67,7 @@ def document(): indexing_status=IndexingStatus.INDEXING, data_source_type=DataSourceType.UPLOAD_FILE, data_source_info_dict={"upload_file_id": "file-1"}, - doc_form="text", + doc_form=IndexStructureType.PARAGRAPH_INDEX, archived=False, is_paused=False, dataset_process_rule=None, @@ -139,8 +140,8 @@ class TestDatasetDocumentListApi: return_value=pagination, ), patch( - "controllers.console.datasets.datasets_document.db.session.query", - return_value=MagicMock(where=lambda *a, **k: MagicMock(count=count_mock)), + "controllers.console.datasets.datasets_document.db.session.scalar", + return_value=2, ), patch( "controllers.console.datasets.datasets_document.DocumentService.enrich_documents_with_summary_index_status", @@ -699,10 +700,8 @@ class TestDocumentPipelineExecutionLogApi: return_value=MagicMock(), ), patch( - "controllers.console.datasets.datasets_document.db.session.query", - return_value=MagicMock( - filter_by=lambda **k: MagicMock(order_by=lambda *a: MagicMock(first=lambda: log)) - ), + "controllers.console.datasets.datasets_document.db.session.scalar", + return_value=log, ), ): response, status = method(api, "ds-1", "doc-1") @@ -765,8 +764,8 @@ class TestDocumentGenerateSummaryApi: summary_index_setting={"enable": True}, ) - doc1 = MagicMock(id="doc-1", doc_form="qa_model") - doc2 = MagicMock(id="doc-2", doc_form="text") + doc1 = MagicMock(id="doc-1", doc_form=IndexStructureType.QA_INDEX) + doc2 = MagicMock(id="doc-2", doc_form=IndexStructureType.PARAGRAPH_INDEX) payload = {"document_list": ["doc-1", "doc-2"]} @@ -822,19 +821,16 @@ class TestDocumentIndexingEstimateApi: data_source_type=DataSourceType.UPLOAD_FILE, data_source_info_dict={"upload_file_id": "file-1"}, tenant_id="tenant-1", - doc_form="text", + doc_form=IndexStructureType.PARAGRAPH_INDEX, dataset_process_rule=None, ) - query_mock = MagicMock() - query_mock.where.return_value.first.return_value = None - with ( app.test_request_context("/"), patch.object(api, "get_document", return_value=document), patch( - "controllers.console.datasets.datasets_document.db.session.query", - return_value=query_mock, + "controllers.console.datasets.datasets_document.db.session.scalar", + return_value=None, ), ): with pytest.raises(NotFound): @@ -849,7 +845,7 @@ class TestDocumentIndexingEstimateApi: data_source_type=DataSourceType.UPLOAD_FILE, data_source_info_dict={"upload_file_id": "file-1"}, tenant_id="tenant-1", - doc_form="text", + doc_form=IndexStructureType.PARAGRAPH_INDEX, dataset_process_rule=None, ) @@ -862,10 +858,8 @@ class TestDocumentIndexingEstimateApi: app.test_request_context("/"), patch.object(api, "get_document", return_value=document), patch( - "controllers.console.datasets.datasets_document.db.session.query", - return_value=MagicMock( - where=MagicMock(return_value=MagicMock(first=MagicMock(return_value=upload_file))) - ), + "controllers.console.datasets.datasets_document.db.session.scalar", + return_value=upload_file, ), patch( "controllers.console.datasets.datasets_document.ExtractSetting", @@ -973,7 +967,7 @@ class TestDocumentBatchIndexingEstimateApi: "mode": "single", "only_main_content": True, }, - doc_form="text", + doc_form=IndexStructureType.PARAGRAPH_INDEX, ) with ( @@ -1001,7 +995,7 @@ class TestDocumentBatchIndexingEstimateApi: "notion_page_id": "p1", "type": "page", }, - doc_form="text", + doc_form=IndexStructureType.PARAGRAPH_INDEX, ) with ( @@ -1024,7 +1018,7 @@ class TestDocumentBatchIndexingEstimateApi: indexing_status=IndexingStatus.INDEXING, data_source_type="unknown", data_source_info_dict={}, - doc_form="text", + doc_form=IndexStructureType.PARAGRAPH_INDEX, ) with app.test_request_context("/"), patch.object(api, "get_batch_documents", return_value=[document]): @@ -1238,12 +1232,8 @@ class TestDocumentPermissionCases: return_value=None, ), patch( - "controllers.console.datasets.datasets_document.db.session.query", - return_value=MagicMock( - where=lambda *a: MagicMock( - order_by=lambda *b: MagicMock(limit=lambda n: MagicMock(one_or_none=lambda: process_rule)) - ) - ), + "controllers.console.datasets.datasets_document.db.session.scalar", + return_value=process_rule, ), ): result = method(api) @@ -1353,7 +1343,7 @@ class TestDocumentIndexingEdgeCases: data_source_type=DataSourceType.UPLOAD_FILE, data_source_info_dict={"upload_file_id": "file-1"}, tenant_id="tenant-1", - doc_form="text", + doc_form=IndexStructureType.PARAGRAPH_INDEX, dataset_process_rule=None, ) @@ -1363,8 +1353,8 @@ class TestDocumentIndexingEdgeCases: app.test_request_context("/"), patch.object(api, "get_document", return_value=document), patch( - "controllers.console.datasets.datasets_document.db.session.query", - return_value=MagicMock(where=lambda *a: MagicMock(first=lambda: upload_file)), + "controllers.console.datasets.datasets_document.db.session.scalar", + return_value=upload_file, ), patch( "controllers.console.datasets.datasets_document.ExtractSetting", diff --git a/api/tests/unit_tests/controllers/console/datasets/test_datasets_segments.py b/api/tests/unit_tests/controllers/console/datasets/test_datasets_segments.py index e67e4daad9..306a772fd1 100644 --- a/api/tests/unit_tests/controllers/console/datasets/test_datasets_segments.py +++ b/api/tests/unit_tests/controllers/console/datasets/test_datasets_segments.py @@ -24,6 +24,7 @@ from controllers.console.datasets.error import ( InvalidActionError, ) from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError +from core.rag.index_processor.constant.index_type import IndexStructureType from models.dataset import ChildChunk, DocumentSegment from models.model import UploadFile @@ -366,7 +367,7 @@ class TestDatasetDocumentSegmentAddApi: dataset.indexing_technique = "economy" document = MagicMock() - document.doc_form = "text" + document.doc_form = IndexStructureType.PARAGRAPH_INDEX segment = MagicMock() segment.id = "seg-1" @@ -505,7 +506,7 @@ class TestDatasetDocumentSegmentUpdateApi: dataset.indexing_technique = "economy" document = MagicMock() - document.doc_form = "text" + document.doc_form = IndexStructureType.PARAGRAPH_INDEX segment = MagicMock() @@ -525,8 +526,8 @@ class TestDatasetDocumentSegmentUpdateApi: return_value=document, ), patch( - "controllers.console.datasets.datasets_segments.db.session.query", - return_value=MagicMock(where=lambda *a, **k: MagicMock(first=lambda: segment)), + "controllers.console.datasets.datasets_segments.db.session.scalar", + return_value=segment, ), patch( "controllers.console.datasets.datasets_segments.DatasetService.check_dataset_permission", @@ -620,8 +621,8 @@ class TestDatasetDocumentSegmentBatchImportApi: return_value=MagicMock(), ), patch( - "controllers.console.datasets.datasets_segments.db.session.query", - return_value=MagicMock(where=lambda *a, **k: MagicMock(first=lambda: upload_file)), + "controllers.console.datasets.datasets_segments.db.session.scalar", + return_value=upload_file, ), patch( "controllers.console.datasets.datasets_segments.redis_client.setnx", @@ -705,8 +706,8 @@ class TestDatasetDocumentSegmentBatchImportApi: return_value=MagicMock(), ), patch( - "controllers.console.datasets.datasets_segments.db.session.query", - return_value=MagicMock(where=lambda *a, **k: MagicMock(first=lambda: None)), + "controllers.console.datasets.datasets_segments.db.session.scalar", + return_value=None, ), ): with pytest.raises(NotFound): @@ -737,8 +738,8 @@ class TestDatasetDocumentSegmentBatchImportApi: return_value=MagicMock(), ), patch( - "controllers.console.datasets.datasets_segments.db.session.query", - return_value=MagicMock(where=lambda *a, **k: MagicMock(first=lambda: upload_file)), + "controllers.console.datasets.datasets_segments.db.session.scalar", + return_value=upload_file, ), ): with pytest.raises(ValueError): @@ -769,8 +770,8 @@ class TestDatasetDocumentSegmentBatchImportApi: return_value=MagicMock(), ), patch( - "controllers.console.datasets.datasets_segments.db.session.query", - return_value=MagicMock(where=lambda *a, **k: MagicMock(first=lambda: upload_file)), + "controllers.console.datasets.datasets_segments.db.session.scalar", + return_value=upload_file, ), patch( "controllers.console.datasets.datasets_segments.redis_client.setnx", @@ -830,8 +831,8 @@ class TestChildChunkAddApi: return_value=document, ), patch( - "controllers.console.datasets.datasets_segments.db.session.query", - return_value=MagicMock(where=lambda *a, **k: MagicMock(first=lambda: segment)), + "controllers.console.datasets.datasets_segments.db.session.scalar", + return_value=segment, ), patch( "controllers.console.datasets.datasets_segments.DatasetService.check_dataset_permission", @@ -879,8 +880,8 @@ class TestChildChunkAddApi: return_value=document, ), patch( - "controllers.console.datasets.datasets_segments.db.session.query", - return_value=MagicMock(where=lambda *a, **k: MagicMock(first=lambda: segment)), + "controllers.console.datasets.datasets_segments.db.session.scalar", + return_value=segment, ), patch( "controllers.console.datasets.datasets_segments.DatasetService.check_dataset_permission", @@ -923,11 +924,8 @@ class TestChildChunkUpdateApi: return_value=document, ), patch( - "controllers.console.datasets.datasets_segments.db.session.query", - side_effect=[ - MagicMock(where=lambda *a, **k: MagicMock(first=lambda: segment)), - MagicMock(where=lambda *a, **k: MagicMock(first=lambda: child_chunk)), - ], + "controllers.console.datasets.datasets_segments.db.session.scalar", + side_effect=[segment, child_chunk], ), patch( "controllers.console.datasets.datasets_segments.DatasetService.check_dataset_permission", @@ -969,11 +967,8 @@ class TestChildChunkUpdateApi: return_value=document, ), patch( - "controllers.console.datasets.datasets_segments.db.session.query", - side_effect=[ - MagicMock(where=lambda *a, **k: MagicMock(first=lambda: segment)), - MagicMock(where=lambda *a, **k: MagicMock(first=lambda: child_chunk)), - ], + "controllers.console.datasets.datasets_segments.db.session.scalar", + side_effect=[segment, child_chunk], ), patch( "controllers.console.datasets.datasets_segments.DatasetService.check_dataset_permission", @@ -1179,8 +1174,8 @@ class TestSegmentOperationCases: return_value=document, ), patch( - "controllers.console.datasets.datasets_segments.db.session.query", - return_value=MagicMock(where=lambda *a, **k: MagicMock(first=lambda: upload_file)), + "controllers.console.datasets.datasets_segments.db.session.scalar", + return_value=upload_file, ), ): with pytest.raises(NotFound): @@ -1214,8 +1209,8 @@ class TestSegmentOperationCases: return_value=document, ), patch( - "controllers.console.datasets.datasets_segments.db.session.query", - return_value=MagicMock(where=lambda *a, **k: MagicMock(first=lambda: upload_file)), + "controllers.console.datasets.datasets_segments.db.session.scalar", + return_value=upload_file, ), patch( "controllers.console.datasets.datasets_segments.DatasetService.check_dataset_permission", diff --git a/api/tests/unit_tests/controllers/console/datasets/test_hit_testing_base.py b/api/tests/unit_tests/controllers/console/datasets/test_hit_testing_base.py index e7ae37ae45..e4acd91b76 100644 --- a/api/tests/unit_tests/controllers/console/datasets/test_hit_testing_base.py +++ b/api/tests/unit_tests/controllers/console/datasets/test_hit_testing_base.py @@ -20,7 +20,7 @@ from core.errors.error import ( ProviderTokenNotInitError, QuotaExceededError, ) -from dify_graph.model_runtime.errors.invoke import InvokeError +from graphon.model_runtime.errors.invoke import InvokeError from models.account import Account from services.dataset_service import DatasetService from services.hit_testing_service import HitTestingService diff --git a/api/tests/unit_tests/controllers/console/datasets/test_wraps.py b/api/tests/unit_tests/controllers/console/datasets/test_wraps.py index 90f00711c1..e358435de4 100644 --- a/api/tests/unit_tests/controllers/console/datasets/test_wraps.py +++ b/api/tests/unit_tests/controllers/console/datasets/test_wraps.py @@ -26,12 +26,9 @@ class TestGetRagPipeline: return_value=(Mock(), "tenant-1"), ) - mock_query = Mock() - mock_query.where.return_value.first.return_value = None - mocker.patch( - "controllers.console.datasets.wraps.db.session.query", - return_value=mock_query, + "controllers.console.datasets.wraps.db.session.scalar", + return_value=None, ) with pytest.raises(PipelineNotFoundError): @@ -51,12 +48,9 @@ class TestGetRagPipeline: return_value=(Mock(), "tenant-1"), ) - mock_query = Mock() - mock_query.where.return_value.first.return_value = pipeline - mocker.patch( - "controllers.console.datasets.wraps.db.session.query", - return_value=mock_query, + "controllers.console.datasets.wraps.db.session.scalar", + return_value=pipeline, ) result = dummy_view(pipeline_id="pipeline-1") @@ -76,12 +70,9 @@ class TestGetRagPipeline: return_value=(Mock(), "tenant-1"), ) - mock_query = Mock() - mock_query.where.return_value.first.return_value = pipeline - mocker.patch( - "controllers.console.datasets.wraps.db.session.query", - return_value=mock_query, + "controllers.console.datasets.wraps.db.session.scalar", + return_value=pipeline, ) result = dummy_view(pipeline_id="pipeline-1") @@ -100,18 +91,15 @@ class TestGetRagPipeline: return_value=(Mock(), "tenant-1"), ) - def where_side_effect(*args, **kwargs): - assert args[0].right.value == "123" - return Mock(first=lambda: pipeline) - - mock_query = Mock() - mock_query.where.side_effect = where_side_effect - - mocker.patch( - "controllers.console.datasets.wraps.db.session.query", - return_value=mock_query, + mock_scalar = mocker.patch( + "controllers.console.datasets.wraps.db.session.scalar", + return_value=pipeline, ) result = dummy_view(pipeline_id=123) assert result is pipeline + # Verify the pipeline_id was cast to string in the where clause + stmt = mock_scalar.call_args[0][0] + where_clauses = stmt.whereclause.clauses + assert where_clauses[0].right.value == "123" diff --git a/api/tests/unit_tests/controllers/console/explore/test_audio.py b/api/tests/unit_tests/controllers/console/explore/test_audio.py index 0afbc5a8f7..b4b57022e2 100644 --- a/api/tests/unit_tests/controllers/console/explore/test_audio.py +++ b/api/tests/unit_tests/controllers/console/explore/test_audio.py @@ -19,7 +19,7 @@ from core.errors.error import ( ProviderTokenNotInitError, QuotaExceededError, ) -from dify_graph.model_runtime.errors.invoke import InvokeError +from graphon.model_runtime.errors.invoke import InvokeError from services.errors.audio import ( AudioTooLargeServiceError, NoAudioUploadedServiceError, diff --git a/api/tests/unit_tests/controllers/console/explore/test_message.py b/api/tests/unit_tests/controllers/console/explore/test_message.py index 6b5c304884..145cc9cdd7 100644 --- a/api/tests/unit_tests/controllers/console/explore/test_message.py +++ b/api/tests/unit_tests/controllers/console/explore/test_message.py @@ -21,7 +21,7 @@ from core.errors.error import ( ProviderTokenNotInitError, QuotaExceededError, ) -from dify_graph.model_runtime.errors.invoke import InvokeError +from graphon.model_runtime.errors.invoke import InvokeError from services.errors.conversation import ConversationNotExistsError from services.errors.message import ( FirstMessageNotExistsError, diff --git a/api/tests/unit_tests/controllers/console/explore/test_trial.py b/api/tests/unit_tests/controllers/console/explore/test_trial.py index 5a03daecbc..03eadcdb4e 100644 --- a/api/tests/unit_tests/controllers/console/explore/test_trial.py +++ b/api/tests/unit_tests/controllers/console/explore/test_trial.py @@ -25,7 +25,7 @@ from core.errors.error import ( ProviderTokenNotInitError, QuotaExceededError, ) -from dify_graph.model_runtime.errors.invoke import InvokeError +from graphon.model_runtime.errors.invoke import InvokeError from models import Account from models.account import TenantStatus from models.model import AppMode diff --git a/api/tests/unit_tests/controllers/console/test_apikey.py b/api/tests/unit_tests/controllers/console/test_apikey.py index c18dd044a7..2dff9c4037 100644 --- a/api/tests/unit_tests/controllers/console/test_apikey.py +++ b/api/tests/unit_tests/controllers/console/test_apikey.py @@ -8,6 +8,7 @@ from controllers.console.apikey import ( BaseApiKeyResource, _get_resource, ) +from models.enums import ApiTokenType @pytest.fixture @@ -45,14 +46,14 @@ def bypass_permissions(): class DummyApiKeyListResource(BaseApiKeyListResource): - resource_type = "app" + resource_type = ApiTokenType.APP resource_model = MagicMock() resource_id_field = "app_id" token_prefix = "app-" class DummyApiKeyResource(BaseApiKeyResource): - resource_type = "app" + resource_type = ApiTokenType.APP resource_model = MagicMock() resource_id_field = "app_id" diff --git a/api/tests/unit_tests/controllers/console/workspace/test_load_balancing_config.py b/api/tests/unit_tests/controllers/console/workspace/test_load_balancing_config.py index f2e57eb65f..b2f949c6e2 100644 --- a/api/tests/unit_tests/controllers/console/workspace/test_load_balancing_config.py +++ b/api/tests/unit_tests/controllers/console/workspace/test_load_balancing_config.py @@ -13,8 +13,8 @@ from flask import Flask from flask.views import MethodView from werkzeug.exceptions import Forbidden -from dify_graph.model_runtime.entities.model_entities import ModelType -from dify_graph.model_runtime.errors.validate import CredentialsValidateFailedError +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.errors.validate import CredentialsValidateFailedError if not hasattr(builtins, "MethodView"): builtins.MethodView = MethodView # type: ignore[attr-defined] diff --git a/api/tests/unit_tests/controllers/console/workspace/test_model_providers.py b/api/tests/unit_tests/controllers/console/workspace/test_model_providers.py index af0c2c5594..168479af1e 100644 --- a/api/tests/unit_tests/controllers/console/workspace/test_model_providers.py +++ b/api/tests/unit_tests/controllers/console/workspace/test_model_providers.py @@ -13,7 +13,7 @@ from controllers.console.workspace.model_providers import ( ModelProviderValidateApi, PreferredProviderTypeUpdateApi, ) -from dify_graph.model_runtime.errors.validate import CredentialsValidateFailedError +from graphon.model_runtime.errors.validate import CredentialsValidateFailedError VALID_UUID = "123e4567-e89b-12d3-a456-426614174000" INVALID_UUID = "123" diff --git a/api/tests/unit_tests/controllers/console/workspace/test_models.py b/api/tests/unit_tests/controllers/console/workspace/test_models.py index 43b8e1ac2e..f0d32f81fb 100644 --- a/api/tests/unit_tests/controllers/console/workspace/test_models.py +++ b/api/tests/unit_tests/controllers/console/workspace/test_models.py @@ -14,8 +14,8 @@ from controllers.console.workspace.models import ( ModelProviderModelParameterRuleApi, ModelProviderModelValidateApi, ) -from dify_graph.model_runtime.entities.model_entities import ModelType -from dify_graph.model_runtime.errors.validate import CredentialsValidateFailedError +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.errors.validate import CredentialsValidateFailedError def unwrap(func): diff --git a/api/tests/unit_tests/controllers/console/workspace/test_tool_providers.py b/api/tests/unit_tests/controllers/console/workspace/test_tool_providers.py index 94c3019d5e..44feacf2ad 100644 --- a/api/tests/unit_tests/controllers/console/workspace/test_tool_providers.py +++ b/api/tests/unit_tests/controllers/console/workspace/test_tool_providers.py @@ -4,7 +4,7 @@ from __future__ import annotations import builtins import importlib -from contextlib import contextmanager +from contextlib import ExitStack, contextmanager from types import ModuleType, SimpleNamespace from unittest.mock import MagicMock, patch @@ -18,7 +18,6 @@ if not hasattr(builtins, "MethodView"): _CONTROLLER_MODULE: ModuleType | None = None _WRAPS_MODULE: ModuleType | None = None -_CONTROLLER_PATCHERS: list[patch] = [] @contextmanager @@ -37,6 +36,14 @@ def app() -> Flask: @pytest.fixture def controller_module(monkeypatch: pytest.MonkeyPatch): + """ + Import the controller with auth decorators neutralized only during import. + + The imported view classes retain those no-op decorators after import, so we + can restore the original globals immediately and avoid leaking auth patches + into unrelated tests such as libs.login unit coverage. + """ + module_name = "controllers.console.workspace.tool_providers" global _CONTROLLER_MODULE if _CONTROLLER_MODULE is None: @@ -51,13 +58,12 @@ def controller_module(monkeypatch: pytest.MonkeyPatch): ("controllers.console.wraps.is_admin_or_owner_required", _noop), ("controllers.console.wraps.enterprise_license_required", _noop), ] - for target, value in patch_targets: - patcher = patch(target, value) - patcher.start() - _CONTROLLER_PATCHERS.append(patcher) monkeypatch.setenv("DIFY_SETUP_READY", "true") - with _mock_db(): - _CONTROLLER_MODULE = importlib.import_module(module_name) + with ExitStack() as stack: + for target, value in patch_targets: + stack.enter_context(patch(target, value)) + with _mock_db(): + _CONTROLLER_MODULE = importlib.import_module(module_name) module = _CONTROLLER_MODULE monkeypatch.setattr(module, "jsonable_encoder", lambda payload: payload) diff --git a/api/tests/unit_tests/controllers/files/test_tool_files.py b/api/tests/unit_tests/controllers/files/test_tool_files.py index e5df7a1eea..edb91c3f26 100644 --- a/api/tests/unit_tests/controllers/files/test_tool_files.py +++ b/api/tests/unit_tests/controllers/files/test_tool_files.py @@ -18,10 +18,10 @@ def fake_request(args: dict): class DummyToolFile: - def __init__(self, mimetype="text/plain", size=10, name="tool.txt"): - self.mimetype = mimetype + def __init__(self, mime_type="text/plain", size=10, filename="tool.txt"): + self.mime_type = mime_type self.size = size - self.name = name + self.filename = filename @pytest.fixture(autouse=True) @@ -87,8 +87,8 @@ class TestToolFileApi: stream = iter([b"data"]) tool_file = DummyToolFile( - mimetype="application/pdf", - name="doc.pdf", + mime_type="application/pdf", + filename="doc.pdf", ) mock_tool_file_manager.return_value.get_file_generator_by_tool_file_id.return_value = ( diff --git a/api/tests/unit_tests/controllers/inner_api/app/__init__.py b/api/tests/unit_tests/controllers/inner_api/app/__init__.py new file mode 100644 index 0000000000..8b13789179 --- /dev/null +++ b/api/tests/unit_tests/controllers/inner_api/app/__init__.py @@ -0,0 +1 @@ + diff --git a/api/tests/unit_tests/controllers/inner_api/app/test_dsl.py b/api/tests/unit_tests/controllers/inner_api/app/test_dsl.py new file mode 100644 index 0000000000..5862239142 --- /dev/null +++ b/api/tests/unit_tests/controllers/inner_api/app/test_dsl.py @@ -0,0 +1,245 @@ +"""Unit tests for inner_api app DSL import/export endpoints. + +Tests Pydantic model validation, endpoint handler logic, and the +_get_active_account helper. Auth/setup decorators are tested separately +in test_auth_wraps.py; handler tests use inspect.unwrap() to bypass them. +""" + +import inspect +from unittest.mock import MagicMock, patch + +import pytest +from flask import Flask +from pydantic import ValidationError + +from controllers.inner_api.app.dsl import ( + EnterpriseAppDSLExport, + EnterpriseAppDSLImport, + InnerAppDSLImportPayload, + _get_active_account, +) +from services.app_dsl_service import ImportStatus + + +class TestInnerAppDSLImportPayload: + """Test InnerAppDSLImportPayload Pydantic model validation.""" + + def test_valid_payload_all_fields(self): + data = { + "yaml_content": "version: 0.6.0\nkind: app\n", + "creator_email": "user@example.com", + "name": "My App", + "description": "A test app", + } + payload = InnerAppDSLImportPayload.model_validate(data) + assert payload.yaml_content == data["yaml_content"] + assert payload.creator_email == "user@example.com" + assert payload.name == "My App" + assert payload.description == "A test app" + + def test_valid_payload_optional_fields_omitted(self): + data = { + "yaml_content": "version: 0.6.0\n", + "creator_email": "user@example.com", + } + payload = InnerAppDSLImportPayload.model_validate(data) + assert payload.name is None + assert payload.description is None + + def test_missing_yaml_content_fails(self): + with pytest.raises(ValidationError) as exc_info: + InnerAppDSLImportPayload.model_validate({"creator_email": "a@b.com"}) + assert "yaml_content" in str(exc_info.value) + + def test_missing_creator_email_fails(self): + with pytest.raises(ValidationError) as exc_info: + InnerAppDSLImportPayload.model_validate({"yaml_content": "test"}) + assert "creator_email" in str(exc_info.value) + + +class TestGetActiveAccount: + """Test the _get_active_account helper function.""" + + @patch("controllers.inner_api.app.dsl.db") + def test_returns_active_account(self, mock_db): + mock_account = MagicMock() + mock_account.status = "active" + mock_db.session.query.return_value.filter_by.return_value.first.return_value = mock_account + + result = _get_active_account("user@example.com") + + assert result is mock_account + mock_db.session.query.return_value.filter_by.assert_called_once_with(email="user@example.com") + + @patch("controllers.inner_api.app.dsl.db") + def test_returns_none_for_inactive_account(self, mock_db): + mock_account = MagicMock() + mock_account.status = "banned" + mock_db.session.query.return_value.filter_by.return_value.first.return_value = mock_account + + result = _get_active_account("banned@example.com") + + assert result is None + + @patch("controllers.inner_api.app.dsl.db") + def test_returns_none_for_nonexistent_email(self, mock_db): + mock_db.session.query.return_value.filter_by.return_value.first.return_value = None + + result = _get_active_account("missing@example.com") + + assert result is None + + +class TestEnterpriseAppDSLImport: + """Test EnterpriseAppDSLImport endpoint handler logic. + + Uses inspect.unwrap() to bypass auth/setup decorators. + """ + + @pytest.fixture + def api_instance(self): + return EnterpriseAppDSLImport() + + @pytest.fixture + def _mock_import_deps(self): + """Patch db, Session, and AppDslService for import handler tests.""" + with ( + patch("controllers.inner_api.app.dsl.db"), + patch("controllers.inner_api.app.dsl.Session") as mock_session, + patch("controllers.inner_api.app.dsl.AppDslService") as mock_dsl_cls, + ): + mock_session.return_value.__enter__ = MagicMock(return_value=MagicMock()) + mock_session.return_value.__exit__ = MagicMock(return_value=False) + self._mock_dsl = MagicMock() + mock_dsl_cls.return_value = self._mock_dsl + yield + + def _make_import_result(self, status: ImportStatus, **kwargs) -> "Import": + from services.app_dsl_service import Import + + result = Import( + id="import-id", + status=status, + app_id=kwargs.get("app_id", "app-123"), + app_mode=kwargs.get("app_mode", "workflow"), + ) + return result + + @pytest.mark.usefixtures("_mock_import_deps") + @patch("controllers.inner_api.app.dsl._get_active_account") + def test_import_success_returns_200(self, mock_get_account, api_instance, app: Flask): + mock_account = MagicMock() + mock_get_account.return_value = mock_account + self._mock_dsl.import_app.return_value = self._make_import_result(ImportStatus.COMPLETED) + + unwrapped = inspect.unwrap(api_instance.post) + with app.test_request_context(): + with patch("controllers.inner_api.app.dsl.inner_api_ns") as mock_ns: + mock_ns.payload = { + "yaml_content": "version: 0.6.0\n", + "creator_email": "user@example.com", + } + result = unwrapped(api_instance, workspace_id="ws-123") + + body, status_code = result + assert status_code == 200 + assert body["status"] == "completed" + mock_account.set_tenant_id.assert_called_once_with("ws-123") + + @pytest.mark.usefixtures("_mock_import_deps") + @patch("controllers.inner_api.app.dsl._get_active_account") + def test_import_pending_returns_202(self, mock_get_account, api_instance, app: Flask): + mock_get_account.return_value = MagicMock() + self._mock_dsl.import_app.return_value = self._make_import_result(ImportStatus.PENDING) + + unwrapped = inspect.unwrap(api_instance.post) + with app.test_request_context(): + with patch("controllers.inner_api.app.dsl.inner_api_ns") as mock_ns: + mock_ns.payload = {"yaml_content": "test", "creator_email": "u@e.com"} + body, status_code = unwrapped(api_instance, workspace_id="ws-123") + + assert status_code == 202 + assert body["status"] == "pending" + + @pytest.mark.usefixtures("_mock_import_deps") + @patch("controllers.inner_api.app.dsl._get_active_account") + def test_import_failed_returns_400(self, mock_get_account, api_instance, app: Flask): + mock_get_account.return_value = MagicMock() + self._mock_dsl.import_app.return_value = self._make_import_result(ImportStatus.FAILED) + + unwrapped = inspect.unwrap(api_instance.post) + with app.test_request_context(): + with patch("controllers.inner_api.app.dsl.inner_api_ns") as mock_ns: + mock_ns.payload = {"yaml_content": "test", "creator_email": "u@e.com"} + body, status_code = unwrapped(api_instance, workspace_id="ws-123") + + assert status_code == 400 + assert body["status"] == "failed" + + @patch("controllers.inner_api.app.dsl._get_active_account") + def test_import_account_not_found_returns_404(self, mock_get_account, api_instance, app: Flask): + mock_get_account.return_value = None + + unwrapped = inspect.unwrap(api_instance.post) + with app.test_request_context(): + with patch("controllers.inner_api.app.dsl.inner_api_ns") as mock_ns: + mock_ns.payload = {"yaml_content": "test", "creator_email": "missing@e.com"} + result = unwrapped(api_instance, workspace_id="ws-123") + + body, status_code = result + assert status_code == 404 + assert "missing@e.com" in body["message"] + + +class TestEnterpriseAppDSLExport: + """Test EnterpriseAppDSLExport endpoint handler logic. + + Uses inspect.unwrap() to bypass auth/setup decorators. + """ + + @pytest.fixture + def api_instance(self): + return EnterpriseAppDSLExport() + + @patch("controllers.inner_api.app.dsl.AppDslService") + @patch("controllers.inner_api.app.dsl.db") + def test_export_success_returns_200(self, mock_db, mock_dsl_cls, api_instance, app: Flask): + mock_app = MagicMock() + mock_db.session.query.return_value.filter_by.return_value.first.return_value = mock_app + mock_dsl_cls.export_dsl.return_value = "version: 0.6.0\nkind: app\n" + + unwrapped = inspect.unwrap(api_instance.get) + with app.test_request_context("?include_secret=false"): + result = unwrapped(api_instance, app_id="app-123") + + body, status_code = result + assert status_code == 200 + assert body["data"] == "version: 0.6.0\nkind: app\n" + mock_dsl_cls.export_dsl.assert_called_once_with(app_model=mock_app, include_secret=False) + + @patch("controllers.inner_api.app.dsl.AppDslService") + @patch("controllers.inner_api.app.dsl.db") + def test_export_with_secret(self, mock_db, mock_dsl_cls, api_instance, app: Flask): + mock_app = MagicMock() + mock_db.session.query.return_value.filter_by.return_value.first.return_value = mock_app + mock_dsl_cls.export_dsl.return_value = "yaml-data" + + unwrapped = inspect.unwrap(api_instance.get) + with app.test_request_context("?include_secret=true"): + result = unwrapped(api_instance, app_id="app-123") + + body, status_code = result + assert status_code == 200 + mock_dsl_cls.export_dsl.assert_called_once_with(app_model=mock_app, include_secret=True) + + @patch("controllers.inner_api.app.dsl.db") + def test_export_app_not_found_returns_404(self, mock_db, api_instance, app: Flask): + mock_db.session.query.return_value.filter_by.return_value.first.return_value = None + + unwrapped = inspect.unwrap(api_instance.get) + with app.test_request_context("?include_secret=false"): + result = unwrapped(api_instance, app_id="nonexistent") + + body, status_code = result + assert status_code == 404 + assert "app not found" in body["message"] diff --git a/api/tests/unit_tests/controllers/service_api/app/test_app.py b/api/tests/unit_tests/controllers/service_api/app/test_app.py index f8e9cf9b80..1507bf7a5f 100644 --- a/api/tests/unit_tests/controllers/service_api/app/test_app.py +++ b/api/tests/unit_tests/controllers/service_api/app/test_app.py @@ -65,7 +65,7 @@ class TestAppParameterApi: mock_tenant.status = "normal" # Mock DB queries for app and tenant - mock_db.session.query.return_value.where.return_value.first.side_effect = [ + mock_db.session.get.side_effect = [ mock_app_model, mock_tenant, ] @@ -112,7 +112,7 @@ class TestAppParameterApi: mock_tenant = Mock() mock_tenant.status = "normal" - mock_db.session.query.return_value.where.return_value.first.side_effect = [ + mock_db.session.get.side_effect = [ mock_app_model, mock_tenant, ] @@ -153,7 +153,7 @@ class TestAppParameterApi: mock_tenant = Mock() mock_tenant.status = "normal" - mock_db.session.query.return_value.where.return_value.first.side_effect = [ + mock_db.session.get.side_effect = [ mock_app_model, mock_tenant, ] @@ -192,7 +192,7 @@ class TestAppParameterApi: mock_tenant = Mock() mock_tenant.status = "normal" - mock_db.session.query.return_value.where.return_value.first.side_effect = [ + mock_db.session.get.side_effect = [ mock_app_model, mock_tenant, ] @@ -255,7 +255,7 @@ class TestAppMetaApi: mock_tenant = Mock() mock_tenant.status = "normal" - mock_db.session.query.return_value.where.return_value.first.side_effect = [ + mock_db.session.get.side_effect = [ mock_app_model, mock_tenant, ] @@ -323,7 +323,7 @@ class TestAppInfoApi: mock_tenant = Mock() mock_tenant.status = "normal" - mock_db.session.query.return_value.where.return_value.first.side_effect = [ + mock_db.session.get.side_effect = [ mock_app_model, mock_tenant, ] @@ -380,7 +380,7 @@ class TestAppInfoApi: mock_tenant = Mock() mock_tenant.status = "normal" - mock_db.session.query.return_value.where.return_value.first.side_effect = [ + mock_db.session.get.side_effect = [ mock_app, mock_tenant, ] @@ -426,7 +426,7 @@ class TestAppInfoApi: mock_tenant = Mock() mock_tenant.status = "normal" - mock_db.session.query.return_value.where.return_value.first.side_effect = [ + mock_db.session.get.side_effect = [ mock_app, mock_tenant, ] @@ -478,7 +478,7 @@ class TestAppInfoApi: mock_tenant = Mock() mock_tenant.status = "normal" - mock_db.session.query.return_value.where.return_value.first.side_effect = [ + mock_db.session.get.side_effect = [ mock_app, mock_tenant, ] diff --git a/api/tests/unit_tests/controllers/service_api/app/test_audio.py b/api/tests/unit_tests/controllers/service_api/app/test_audio.py index 1923ab7fa7..e81e612803 100644 --- a/api/tests/unit_tests/controllers/service_api/app/test_audio.py +++ b/api/tests/unit_tests/controllers/service_api/app/test_audio.py @@ -29,7 +29,7 @@ from controllers.service_api.app.error import ( UnsupportedAudioTypeError, ) from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError -from dify_graph.model_runtime.errors.invoke import InvokeError +from graphon.model_runtime.errors.invoke import InvokeError from services.audio_service import AudioService from services.errors.app_model_config import AppModelConfigBrokenError from services.errors.audio import ( diff --git a/api/tests/unit_tests/controllers/service_api/app/test_completion.py b/api/tests/unit_tests/controllers/service_api/app/test_completion.py index 4e4482f704..3364c07e62 100644 --- a/api/tests/unit_tests/controllers/service_api/app/test_completion.py +++ b/api/tests/unit_tests/controllers/service_api/app/test_completion.py @@ -34,7 +34,7 @@ from controllers.service_api.app.error import ( NotChatAppError, ) from core.errors.error import QuotaExceededError -from dify_graph.model_runtime.errors.invoke import InvokeError +from graphon.model_runtime.errors.invoke import InvokeError from models.model import App, AppMode, EndUser from services.app_generate_service import AppGenerateService from services.app_task_service import AppTaskService diff --git a/api/tests/unit_tests/controllers/service_api/app/test_file_preview.py b/api/tests/unit_tests/controllers/service_api/app/test_file_preview.py index 1bdcd0f1a3..d83c22f2cf 100644 --- a/api/tests/unit_tests/controllers/service_api/app/test_file_preview.py +++ b/api/tests/unit_tests/controllers/service_api/app/test_file_preview.py @@ -79,10 +79,13 @@ class TestFilePreviewApi: mock_message_file.message_id = mock_message.id with patch("controllers.service_api.app.file_preview.db") as mock_db: - # Mock database queries - mock_db.session.query.return_value.where.return_value.first.side_effect = [ + # Mock scalar() for MessageFile and Message queries + mock_db.session.scalar.side_effect = [ mock_message_file, # MessageFile query mock_message, # Message query + ] + # Mock get() for UploadFile and App PK lookups + mock_db.session.get.side_effect = [ mock_upload_file, # UploadFile query mock_app, # App query for tenant validation ] @@ -100,8 +103,8 @@ class TestFilePreviewApi: app_id = str(uuid.uuid4()) with patch("controllers.service_api.app.file_preview.db") as mock_db: - # Mock MessageFile not found - mock_db.session.query.return_value.where.return_value.first.return_value = None + # Mock MessageFile not found via scalar() + mock_db.session.scalar.return_value = None # Execute and assert exception with pytest.raises(FileNotFoundError) as exc_info: @@ -115,8 +118,8 @@ class TestFilePreviewApi: app_id = str(uuid.uuid4()) with patch("controllers.service_api.app.file_preview.db") as mock_db: - # Mock MessageFile found but Message not owned by app - mock_db.session.query.return_value.where.return_value.first.side_effect = [ + # Mock MessageFile found but Message not owned by app via scalar() + mock_db.session.scalar.side_effect = [ mock_message_file, # MessageFile query - found None, # Message query - not found (access denied) ] @@ -133,12 +136,13 @@ class TestFilePreviewApi: app_id = str(uuid.uuid4()) with patch("controllers.service_api.app.file_preview.db") as mock_db: - # Mock MessageFile and Message found but UploadFile not found - mock_db.session.query.return_value.where.return_value.first.side_effect = [ + # Mock scalar() for MessageFile and Message + mock_db.session.scalar.side_effect = [ mock_message_file, # MessageFile query - found mock_message, # Message query - found - None, # UploadFile query - not found ] + # Mock get() for UploadFile - not found + mock_db.session.get.return_value = None # Execute and assert exception with pytest.raises(FileNotFoundError) as exc_info: @@ -161,10 +165,13 @@ class TestFilePreviewApi: mock_message_file.message_id = mock_message.id with patch("controllers.service_api.app.file_preview.db") as mock_db: - # Mock database queries - mock_db.session.query.return_value.where.return_value.first.side_effect = [ + # Mock scalar() for MessageFile and Message queries + mock_db.session.scalar.side_effect = [ mock_message_file, # MessageFile query mock_message, # Message query + ] + # Mock get() for UploadFile and App PK lookups + mock_db.session.get.side_effect = [ mock_upload_file, # UploadFile query mock_app, # App query for tenant validation ] @@ -262,10 +269,13 @@ class TestFilePreviewApi: mock_storage.load.return_value = mock_generator with patch("controllers.service_api.app.file_preview.db") as mock_db: - # Mock database queries - mock_db.session.query.return_value.where.return_value.first.side_effect = [ + # Mock scalar() for MessageFile and Message queries + mock_db.session.scalar.side_effect = [ mock_message_file, # MessageFile query mock_message, # Message query + ] + # Mock get() for UploadFile and App PK lookups + mock_db.session.get.side_effect = [ mock_upload_file, # UploadFile query mock_app, # App query for tenant validation ] @@ -301,10 +311,13 @@ class TestFilePreviewApi: mock_storage.load.side_effect = Exception("Storage error") with patch("controllers.service_api.app.file_preview.db") as mock_db: - # Mock database queries for validation - mock_db.session.query.return_value.where.return_value.first.side_effect = [ + # Mock scalar() for MessageFile and Message queries + mock_db.session.scalar.side_effect = [ mock_message_file, # MessageFile query mock_message, # Message query + ] + # Mock get() for UploadFile and App PK lookups + mock_db.session.get.side_effect = [ mock_upload_file, # UploadFile query mock_app, # App query for tenant validation ] @@ -327,8 +340,8 @@ class TestFilePreviewApi: app_id = str(uuid.uuid4()) with patch("controllers.service_api.app.file_preview.db") as mock_db: - # Mock database query to raise unexpected exception - mock_db.session.query.side_effect = Exception("Unexpected database error") + # Mock database scalar to raise unexpected exception + mock_db.session.scalar.side_effect = Exception("Unexpected database error") # Execute and assert exception with pytest.raises(FileAccessDeniedError) as exc_info: diff --git a/api/tests/unit_tests/controllers/service_api/app/test_workflow.py b/api/tests/unit_tests/controllers/service_api/app/test_workflow.py index 4eada73b82..6543c27037 100644 --- a/api/tests/unit_tests/controllers/service_api/app/test_workflow.py +++ b/api/tests/unit_tests/controllers/service_api/app/test_workflow.py @@ -35,7 +35,7 @@ from controllers.service_api.app.workflow import ( WorkflowTaskStopApi, ) from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError -from dify_graph.enums import WorkflowExecutionStatus +from graphon.enums import WorkflowExecutionStatus from models.model import App, AppMode from services.app_generate_service import AppGenerateService from services.errors.app import IsDraftWorkflowError, WorkflowNotFoundError @@ -315,7 +315,7 @@ class TestWorkflowStopMechanism: def test_graph_engine_manager_has_send_stop_command(self): """Test GraphEngineManager has send_stop_command method.""" - from dify_graph.graph_engine.manager import GraphEngineManager + from graphon.graph_engine.manager import GraphEngineManager assert hasattr(GraphEngineManager, "send_stop_command") diff --git a/api/tests/unit_tests/controllers/service_api/app/test_workflow_fields.py b/api/tests/unit_tests/controllers/service_api/app/test_workflow_fields.py index 9e95f45a0a..eda270258d 100644 --- a/api/tests/unit_tests/controllers/service_api/app/test_workflow_fields.py +++ b/api/tests/unit_tests/controllers/service_api/app/test_workflow_fields.py @@ -1,7 +1,7 @@ from types import SimpleNamespace from controllers.service_api.app.workflow import WorkflowRunOutputsField, WorkflowRunStatusField -from dify_graph.enums import WorkflowExecutionStatus +from graphon.enums import WorkflowExecutionStatus def test_workflow_run_status_field_with_enum() -> None: diff --git a/api/tests/unit_tests/controllers/service_api/conftest.py b/api/tests/unit_tests/controllers/service_api/conftest.py index 4337a0c8c0..eddba5a517 100644 --- a/api/tests/unit_tests/controllers/service_api/conftest.py +++ b/api/tests/unit_tests/controllers/service_api/conftest.py @@ -12,6 +12,7 @@ from unittest.mock import Mock import pytest from flask import Flask +from core.rag.index_processor.constant.index_type import IndexStructureType from models.account import TenantStatus from models.model import App, AppMode, EndUser from tests.unit_tests.conftest import setup_mock_tenant_account_query @@ -118,11 +119,8 @@ class AuthenticationMocker: @staticmethod def setup_db_queries(mock_db, mock_app, mock_tenant, mock_account=None): - """Configure mock_db to return app and tenant in sequence.""" - mock_db.session.query.return_value.where.return_value.first.side_effect = [ - mock_app, - mock_tenant, - ] + """Configure mock_db to return app and tenant via session.get().""" + mock_db.session.get.side_effect = [mock_app, mock_tenant] if mock_account: mock_ta = Mock() @@ -135,11 +133,9 @@ class AuthenticationMocker: mock_ta = Mock() mock_ta.account_id = mock_account.id - mock_query = mock_db.session.query.return_value - target_mock = mock_query.where.return_value.where.return_value.where.return_value.where.return_value - target_mock.one_or_none.return_value = (mock_tenant, mock_ta) + mock_db.session.execute.return_value.one_or_none.return_value = (mock_tenant, mock_ta) - mock_db.session.query.return_value.where.return_value.first.return_value = mock_account + mock_db.session.get.return_value = mock_account @pytest.fixture @@ -175,7 +171,7 @@ def mock_document(): document.name = "test_document.txt" document.indexing_status = "completed" document.enabled = True - document.doc_form = "text_model" + document.doc_form = IndexStructureType.PARAGRAPH_INDEX return document diff --git a/api/tests/unit_tests/controllers/service_api/dataset/test_dataset.py b/api/tests/unit_tests/controllers/service_api/dataset/test_dataset.py index 8fe41cd19f..910d781cd0 100644 --- a/api/tests/unit_tests/controllers/service_api/dataset/test_dataset.py +++ b/api/tests/unit_tests/controllers/service_api/dataset/test_dataset.py @@ -942,11 +942,11 @@ class TestDatasetListApiGet: """Test suite for DatasetListApi.get() endpoint. ``get`` has no billing decorators but calls ``current_user``, - ``DatasetService``, ``ProviderManager``, and ``marshal``. + ``DatasetService``, ``create_plugin_provider_manager``, and ``marshal``. """ @patch("controllers.service_api.dataset.dataset.marshal") - @patch("controllers.service_api.dataset.dataset.ProviderManager") + @patch("controllers.service_api.dataset.dataset.create_plugin_provider_manager") @patch("controllers.service_api.dataset.dataset.current_user") @patch("controllers.service_api.dataset.dataset.DatasetService") def test_list_datasets_success( @@ -1044,12 +1044,12 @@ class TestDatasetApiGet: """Test suite for DatasetApi.get() endpoint. ``get`` has no billing decorators but calls ``DatasetService``, - ``ProviderManager``, ``marshal``, and ``current_user``. + ``create_plugin_provider_manager``, ``marshal``, and ``current_user``. """ @patch("controllers.service_api.dataset.dataset.DatasetPermissionService") @patch("controllers.service_api.dataset.dataset.marshal") - @patch("controllers.service_api.dataset.dataset.ProviderManager") + @patch("controllers.service_api.dataset.dataset.create_plugin_provider_manager") @patch("controllers.service_api.dataset.dataset.current_user") @patch("controllers.service_api.dataset.dataset.DatasetService") def test_get_dataset_success( diff --git a/api/tests/unit_tests/controllers/service_api/dataset/test_dataset_segment.py b/api/tests/unit_tests/controllers/service_api/dataset/test_dataset_segment.py index 5c48ef1804..7f5d6b0839 100644 --- a/api/tests/unit_tests/controllers/service_api/dataset/test_dataset_segment.py +++ b/api/tests/unit_tests/controllers/service_api/dataset/test_dataset_segment.py @@ -31,6 +31,7 @@ from controllers.service_api.dataset.segment import ( SegmentCreatePayload, SegmentListQuery, ) +from core.rag.index_processor.constant.index_type import IndexStructureType from models.dataset import ChildChunk, Dataset, Document, DocumentSegment from models.enums import IndexingStatus from services.dataset_service import DocumentService, SegmentService @@ -787,8 +788,8 @@ class TestSegmentApiGet: """Test successful segment list retrieval.""" # Arrange mock_account_fn.return_value = (Mock(), mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset - mock_doc_svc.get_document.return_value = Mock(doc_form="text_model") + mock_db.session.scalar.return_value = mock_dataset + mock_doc_svc.get_document.return_value = Mock(doc_form=IndexStructureType.PARAGRAPH_INDEX) mock_seg_svc.get_segments.return_value = ([mock_segment], 1) mock_marshal.return_value = [{"id": mock_segment.id}] @@ -812,7 +813,7 @@ class TestSegmentApiGet: """Test 404 when dataset not found.""" # Arrange mock_account_fn.return_value = (Mock(), mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = None + mock_db.session.scalar.return_value = None # Act & Assert with app.test_request_context( @@ -832,7 +833,7 @@ class TestSegmentApiGet: """Test 404 when document not found.""" # Arrange mock_account_fn.return_value = (Mock(), mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_doc_svc.get_document.return_value = None # Act & Assert @@ -898,12 +899,12 @@ class TestSegmentApiPost: mock_account_fn.return_value = (Mock(), mock_tenant.id) mock_dataset.indexing_technique = "economy" - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_doc = Mock() mock_doc.indexing_status = "completed" mock_doc.enabled = True - mock_doc.doc_form = "text_model" + mock_doc.doc_form = IndexStructureType.PARAGRAPH_INDEX mock_doc_svc.get_document.return_value = mock_doc mock_seg_svc.segment_create_args_validate.return_value = None @@ -949,7 +950,7 @@ class TestSegmentApiPost: mock_account_fn.return_value = (Mock(), mock_tenant.id) mock_dataset.indexing_technique = "economy" - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_doc = Mock() mock_doc.indexing_status = "completed" @@ -991,7 +992,7 @@ class TestSegmentApiPost: self._setup_billing_mocks(mock_validate_token, mock_feature_svc, mock_tenant.id) mock_account_fn.return_value = (Mock(), mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_doc = Mock() mock_doc.indexing_status = "indexing" # Not completed @@ -1042,7 +1043,7 @@ class TestDatasetSegmentApiDelete: """Test successful segment deletion.""" # Arrange mock_account_fn.return_value = (Mock(), mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_dataset_svc.check_dataset_model_setting.return_value = None mock_doc = Mock() @@ -1086,12 +1087,12 @@ class TestDatasetSegmentApiDelete: """Test 404 when segment not found.""" # Arrange mock_account_fn.return_value = (Mock(), mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_doc = Mock() mock_doc.indexing_status = "completed" mock_doc.enabled = True - mock_doc.doc_form = "text_model" + mock_doc.doc_form = IndexStructureType.PARAGRAPH_INDEX mock_doc_svc.get_document.return_value = mock_doc mock_seg_svc.get_segment_by_id.return_value = None # Segment not found @@ -1128,7 +1129,7 @@ class TestDatasetSegmentApiDelete: """Test 404 when dataset not found for delete.""" # Arrange mock_account_fn.return_value = (Mock(), mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = None + mock_db.session.scalar.return_value = None # Act & Assert with app.test_request_context( @@ -1162,7 +1163,7 @@ class TestDatasetSegmentApiDelete: """Test 404 when document not found for delete.""" # Arrange mock_account_fn.return_value = (Mock(), mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_dataset_svc.check_dataset_model_setting.return_value = None mock_doc_svc.get_document.return_value = None @@ -1232,7 +1233,7 @@ class TestDatasetSegmentApiUpdate: self._setup_billing_mocks(mock_validate_token, mock_feature_svc, mock_tenant.id) mock_account_fn.return_value = (Mock(), mock_tenant.id) mock_dataset.indexing_technique = "economy" - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_dataset_svc.check_dataset_model_setting.return_value = None mock_doc_svc.get_document.return_value = Mock() mock_seg_svc.get_segment_by_id.return_value = mock_segment @@ -1279,7 +1280,7 @@ class TestDatasetSegmentApiUpdate: """Test 404 when dataset not found for update.""" self._setup_billing_mocks(mock_validate_token, mock_feature_svc, mock_tenant.id) mock_account_fn.return_value = (Mock(), mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = None + mock_db.session.scalar.return_value = None with app.test_request_context( f"/datasets/{mock_dataset.id}/documents/doc-id/segments/seg-id", @@ -1320,7 +1321,7 @@ class TestDatasetSegmentApiUpdate: self._setup_billing_mocks(mock_validate_token, mock_feature_svc, mock_tenant.id) mock_account_fn.return_value = (Mock(), mock_tenant.id) mock_dataset.indexing_technique = "economy" - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_dataset_svc.check_dataset_model_setting.return_value = None mock_doc_svc.get_document.return_value = Mock() mock_seg_svc.get_segment_by_id.return_value = None @@ -1369,9 +1370,9 @@ class TestDatasetSegmentApiGetSingle: ): """Test successful single segment retrieval.""" mock_account_fn.return_value = (Mock(), mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_dataset_svc.check_dataset_model_setting.return_value = None - mock_doc = Mock(doc_form="text_model") + mock_doc = Mock(doc_form=IndexStructureType.PARAGRAPH_INDEX) mock_doc_svc.get_document.return_value = mock_doc mock_seg_svc.get_segment_by_id.return_value = mock_segment mock_marshal.return_value = {"id": mock_segment.id} @@ -1390,7 +1391,7 @@ class TestDatasetSegmentApiGetSingle: assert status == 200 assert "data" in response - assert response["doc_form"] == "text_model" + assert response["doc_form"] == IndexStructureType.PARAGRAPH_INDEX @patch("controllers.service_api.dataset.segment.current_account_with_tenant") @patch("controllers.service_api.dataset.segment.db") @@ -1404,7 +1405,7 @@ class TestDatasetSegmentApiGetSingle: ): """Test 404 when dataset not found.""" mock_account_fn.return_value = (Mock(), mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = None + mock_db.session.scalar.return_value = None with app.test_request_context( f"/datasets/{mock_dataset.id}/documents/doc-id/segments/seg-id", @@ -1435,7 +1436,7 @@ class TestDatasetSegmentApiGetSingle: ): """Test 404 when document not found.""" mock_account_fn.return_value = (Mock(), mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_dataset_svc.check_dataset_model_setting.return_value = None mock_doc_svc.get_document.return_value = None @@ -1470,7 +1471,7 @@ class TestDatasetSegmentApiGetSingle: ): """Test 404 when segment not found.""" mock_account_fn.return_value = (Mock(), mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_dataset_svc.check_dataset_model_setting.return_value = None mock_doc_svc.get_document.return_value = Mock() mock_seg_svc.get_segment_by_id.return_value = None @@ -1514,7 +1515,7 @@ class TestChildChunkApiGet: ): """Test successful child chunk list retrieval.""" mock_account_fn.return_value = (Mock(), mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_doc_svc.get_document.return_value = Mock() mock_seg_svc.get_segment_by_id.return_value = Mock() @@ -1553,7 +1554,7 @@ class TestChildChunkApiGet: ): """Test 404 when dataset not found.""" mock_account_fn.return_value = (Mock(), mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = None + mock_db.session.scalar.return_value = None with app.test_request_context( f"/datasets/{mock_dataset.id}/documents/doc-id/segments/seg-id/child_chunks", @@ -1582,7 +1583,7 @@ class TestChildChunkApiGet: ): """Test 404 when document not found.""" mock_account_fn.return_value = (Mock(), mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_doc_svc.get_document.return_value = None with app.test_request_context( @@ -1614,7 +1615,7 @@ class TestChildChunkApiGet: ): """Test 404 when segment not found.""" mock_account_fn.return_value = (Mock(), mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_doc_svc.get_document.return_value = Mock() mock_seg_svc.get_segment_by_id.return_value = None @@ -1675,7 +1676,7 @@ class TestChildChunkApiPost: self._setup_billing_mocks(mock_validate_token, mock_feature_svc, mock_tenant.id) mock_account_fn.return_value = (Mock(), mock_tenant.id) mock_dataset.indexing_technique = "economy" - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_doc_svc.get_document.return_value = Mock() mock_seg_svc.get_segment_by_id.return_value = Mock() mock_child = Mock() @@ -1716,7 +1717,7 @@ class TestChildChunkApiPost: """Test 404 when dataset not found.""" self._setup_billing_mocks(mock_validate_token, mock_feature_svc, mock_tenant.id) mock_account_fn.return_value = (Mock(), mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = None + mock_db.session.scalar.return_value = None with app.test_request_context( f"/datasets/{mock_dataset.id}/documents/doc-id/segments/seg-id/child_chunks", @@ -1754,7 +1755,7 @@ class TestChildChunkApiPost: """Test 404 when segment not found.""" self._setup_billing_mocks(mock_validate_token, mock_feature_svc, mock_tenant.id) mock_account_fn.return_value = (Mock(), mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_doc_svc.get_document.return_value = Mock() mock_seg_svc.get_segment_by_id.return_value = None @@ -1807,7 +1808,7 @@ class TestDatasetChildChunkApiDelete: ): """Test successful child chunk deletion.""" mock_account_fn.return_value = (Mock(), mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_doc = Mock() mock_doc_svc.get_document.return_value = mock_doc @@ -1857,7 +1858,7 @@ class TestDatasetChildChunkApiDelete: ): """Test 404 when child chunk not found.""" mock_account_fn.return_value = (Mock(), mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_doc_svc.get_document.return_value = Mock() segment_id = str(uuid.uuid4()) @@ -1898,7 +1899,7 @@ class TestDatasetChildChunkApiDelete: ): """Test 404 when segment does not belong to the document.""" mock_account_fn.return_value = (Mock(), mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_doc_svc.get_document.return_value = Mock() segment_id = str(uuid.uuid4()) @@ -1938,7 +1939,7 @@ class TestDatasetChildChunkApiDelete: ): """Test 404 when child chunk does not belong to the segment.""" mock_account_fn.return_value = (Mock(), mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_doc_svc.get_document.return_value = Mock() segment_id = str(uuid.uuid4()) diff --git a/api/tests/unit_tests/controllers/service_api/dataset/test_document.py b/api/tests/unit_tests/controllers/service_api/dataset/test_document.py index e6e841be19..12d5e7345d 100644 --- a/api/tests/unit_tests/controllers/service_api/dataset/test_document.py +++ b/api/tests/unit_tests/controllers/service_api/dataset/test_document.py @@ -35,6 +35,7 @@ from controllers.service_api.dataset.document import ( InvalidMetadataError, ) from controllers.service_api.dataset.error import ArchivedDocumentImmutableError +from core.rag.index_processor.constant.index_type import IndexStructureType from models.enums import IndexingStatus from services.dataset_service import DocumentService from services.entities.knowledge_entities.knowledge_entities import ProcessRule, RetrievalModel @@ -52,7 +53,7 @@ class TestDocumentTextCreatePayload: def test_payload_with_defaults(self): """Test payload default values.""" payload = DocumentTextCreatePayload(name="Doc", text="Content") - assert payload.doc_form == "text_model" + assert payload.doc_form == IndexStructureType.PARAGRAPH_INDEX assert payload.doc_language == "English" assert payload.process_rule is None assert payload.indexing_technique is None @@ -62,14 +63,14 @@ class TestDocumentTextCreatePayload: payload = DocumentTextCreatePayload( name="Full Document", text="Complete document content here", - doc_form="qa_model", + doc_form=IndexStructureType.QA_INDEX, doc_language="Chinese", indexing_technique="high_quality", embedding_model="text-embedding-ada-002", embedding_model_provider="openai", ) assert payload.name == "Full Document" - assert payload.doc_form == "qa_model" + assert payload.doc_form == IndexStructureType.QA_INDEX assert payload.doc_language == "Chinese" assert payload.indexing_technique == "high_quality" assert payload.embedding_model == "text-embedding-ada-002" @@ -147,8 +148,8 @@ class TestDocumentTextUpdate: def test_payload_with_doc_form_update(self): """Test payload with doc_form update.""" - payload = DocumentTextUpdate(doc_form="qa_model") - assert payload.doc_form == "qa_model" + payload = DocumentTextUpdate(doc_form=IndexStructureType.QA_INDEX) + assert payload.doc_form == IndexStructureType.QA_INDEX def test_payload_with_language_update(self): """Test payload with doc_language update.""" @@ -158,7 +159,7 @@ class TestDocumentTextUpdate: def test_payload_default_values(self): """Test payload default values.""" payload = DocumentTextUpdate() - assert payload.doc_form == "text_model" + assert payload.doc_form == IndexStructureType.PARAGRAPH_INDEX assert payload.doc_language == "English" @@ -272,14 +273,24 @@ class TestDocumentDocForm: def test_text_model_form(self): """Test text_model form.""" - doc_form = "text_model" - valid_forms = ["text_model", "qa_model", "hierarchical_model", "parent_child_model"] + doc_form = IndexStructureType.PARAGRAPH_INDEX + valid_forms = [ + IndexStructureType.PARAGRAPH_INDEX, + IndexStructureType.QA_INDEX, + IndexStructureType.PARENT_CHILD_INDEX, + "parent_child_model", + ] assert doc_form in valid_forms def test_qa_model_form(self): """Test qa_model form.""" - doc_form = "qa_model" - valid_forms = ["text_model", "qa_model", "hierarchical_model", "parent_child_model"] + doc_form = IndexStructureType.QA_INDEX + valid_forms = [ + IndexStructureType.PARAGRAPH_INDEX, + IndexStructureType.QA_INDEX, + IndexStructureType.PARENT_CHILD_INDEX, + "parent_child_model", + ] assert doc_form in valid_forms @@ -504,7 +515,7 @@ class TestDocumentApiGet: doc.name = "test_document.txt" doc.indexing_status = "completed" doc.enabled = True - doc.doc_form = "text_model" + doc.doc_form = IndexStructureType.PARAGRAPH_INDEX doc.doc_language = "English" doc.doc_type = "book" doc.doc_metadata_details = {"source": "upload"} @@ -706,7 +717,7 @@ class TestDocumentApiDelete: dataset_id = str(uuid.uuid4()) mock_dataset = Mock() mock_dataset.id = dataset_id - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_doc_svc.get_document.return_value = mock_document mock_doc_svc.check_archived.return_value = False @@ -735,7 +746,7 @@ class TestDocumentApiDelete: document_id = str(uuid.uuid4()) mock_dataset = Mock() mock_dataset.id = dataset_id - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_doc_svc.get_document.return_value = None @@ -756,7 +767,7 @@ class TestDocumentApiDelete: dataset_id = str(uuid.uuid4()) mock_dataset = Mock() mock_dataset.id = dataset_id - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_doc_svc.get_document.return_value = mock_document mock_doc_svc.check_archived.return_value = True @@ -777,7 +788,7 @@ class TestDocumentApiDelete: # Arrange dataset_id = str(uuid.uuid4()) document_id = str(uuid.uuid4()) - mock_db.session.query.return_value.where.return_value.first.return_value = None + mock_db.session.scalar.return_value = None # Act & Assert with app.test_request_context( @@ -798,7 +809,7 @@ class TestDocumentListApi: def test_list_documents_success(self, mock_db, mock_doc_svc, mock_marshal, app, mock_tenant, mock_dataset): """Test successful document list retrieval.""" # Arrange - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_pagination = Mock() mock_pagination.items = [Mock(), Mock()] @@ -827,7 +838,7 @@ class TestDocumentListApi: def test_list_documents_dataset_not_found(self, mock_db, app, mock_tenant, mock_dataset): """Test 404 when dataset not found.""" # Arrange - mock_db.session.query.return_value.where.return_value.first.return_value = None + mock_db.session.scalar.return_value = None # Act & Assert with app.test_request_context( @@ -849,8 +860,6 @@ class TestDocumentIndexingStatusApi: """Test successful indexing status retrieval.""" # Arrange batch_id = "batch_123" - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset - mock_doc = Mock() mock_doc.id = str(uuid.uuid4()) mock_doc.is_paused = False @@ -866,8 +875,8 @@ class TestDocumentIndexingStatusApi: mock_doc_svc.get_batch_documents.return_value = [mock_doc] - # Mock segment count queries - mock_db.session.query.return_value.where.return_value.where.return_value.count.return_value = 5 + # scalar() called 3 times: dataset lookup, completed_segments count, total_segments count + mock_db.session.scalar.side_effect = [mock_dataset, 5, 5] mock_marshal.return_value = {"id": mock_doc.id, "indexing_status": "completed"} # Act @@ -887,7 +896,7 @@ class TestDocumentIndexingStatusApi: """Test 404 when dataset not found.""" # Arrange batch_id = "batch_123" - mock_db.session.query.return_value.where.return_value.first.return_value = None + mock_db.session.scalar.return_value = None # Act & Assert with app.test_request_context( @@ -904,7 +913,7 @@ class TestDocumentIndexingStatusApi: """Test 404 when no documents found for batch.""" # Arrange batch_id = "batch_empty" - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_doc_svc.get_batch_documents.return_value = [] # Act & Assert @@ -975,7 +984,7 @@ class TestDocumentAddByTextApi: # Arrange — neutralise billing decorators self._setup_billing_mocks(mock_validate_token, mock_feature_svc, mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_dataset.indexing_technique = "economy" mock_current_user.id = str(uuid.uuid4()) @@ -1024,7 +1033,7 @@ class TestDocumentAddByTextApi: # Arrange — neutralise billing decorators self._setup_billing_mocks(mock_validate_token, mock_feature_svc, mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = None + mock_db.session.scalar.return_value = None # Act & Assert with app.test_request_context( @@ -1053,7 +1062,7 @@ class TestDocumentAddByTextApi: self._setup_billing_mocks(mock_validate_token, mock_feature_svc, mock_tenant.id) mock_dataset.indexing_technique = None - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset # Act & Assert with app.test_request_context( @@ -1139,7 +1148,7 @@ class TestDocumentUpdateByTextApiPost: _setup_billing_mocks(mock_validate_token, mock_feature_svc, mock_tenant.id) mock_dataset.indexing_technique = "economy" mock_dataset.latest_process_rule = Mock() - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_current_user.id = "user-1" mock_upload = Mock() @@ -1182,7 +1191,7 @@ class TestDocumentUpdateByTextApiPost: ): """Test ValueError when dataset not found.""" _setup_billing_mocks(mock_validate_token, mock_feature_svc, mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = None + mock_db.session.scalar.return_value = None doc_id = str(uuid.uuid4()) with app.test_request_context( @@ -1221,7 +1230,7 @@ class TestDocumentAddByFileApiPost: ): """Test ValueError when dataset not found.""" _setup_billing_mocks(mock_validate_token, mock_feature_svc, mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = None + mock_db.session.scalar.return_value = None from io import BytesIO @@ -1252,7 +1261,7 @@ class TestDocumentAddByFileApiPost: """Test ValueError when dataset is external.""" _setup_billing_mocks(mock_validate_token, mock_feature_svc, mock_tenant.id) mock_dataset.provider = "external" - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset from io import BytesIO @@ -1287,7 +1296,7 @@ class TestDocumentAddByFileApiPost: mock_dataset.provider = "vendor" mock_dataset.indexing_technique = "economy" mock_dataset.chunk_structure = None - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset with app.test_request_context( f"/datasets/{mock_dataset.id}/document/create_by_file", @@ -1317,7 +1326,7 @@ class TestDocumentAddByFileApiPost: mock_dataset.provider = "vendor" mock_dataset.indexing_technique = None mock_dataset.chunk_structure = None - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset from io import BytesIO @@ -1355,7 +1364,7 @@ class TestDocumentUpdateByFileApiPost: ): """Test ValueError when dataset not found.""" _setup_billing_mocks(mock_validate_token, mock_feature_svc, mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = None + mock_db.session.scalar.return_value = None from io import BytesIO @@ -1391,7 +1400,7 @@ class TestDocumentUpdateByFileApiPost: """Test ValueError when dataset is external.""" _setup_billing_mocks(mock_validate_token, mock_feature_svc, mock_tenant.id) mock_dataset.provider = "external" - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset from io import BytesIO @@ -1439,7 +1448,7 @@ class TestDocumentUpdateByFileApiPost: mock_dataset.chunk_structure = None mock_dataset.latest_process_rule = Mock() mock_dataset.created_by_account = Mock() - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_current_user.id = "user-1" mock_upload = Mock() diff --git a/api/tests/unit_tests/controllers/service_api/test_site.py b/api/tests/unit_tests/controllers/service_api/test_site.py index b58caf3be1..c0b40d070a 100644 --- a/api/tests/unit_tests/controllers/service_api/test_site.py +++ b/api/tests/unit_tests/controllers/service_api/test_site.py @@ -88,7 +88,7 @@ class TestAppSiteApi: mock_app_model.tenant = mock_tenant # Mock wraps.db for authentication - mock_wraps_db.session.query.return_value.where.return_value.first.side_effect = [ + mock_wraps_db.session.get.side_effect = [ mock_app_model, mock_tenant, ] @@ -98,7 +98,7 @@ class TestAppSiteApi: setup_mock_tenant_account_query(mock_wraps_db, mock_tenant, mock_account) # Mock site.db for site query - mock_db.session.query.return_value.where.return_value.first.return_value = mock_site + mock_db.session.scalar.return_value = mock_site # Act with app.test_request_context("/site", method="GET", headers={"Authorization": "Bearer test_token"}): @@ -109,7 +109,7 @@ class TestAppSiteApi: assert response["title"] == "Test Site" assert response["icon"] == "icon-url" assert response["description"] == "Site description" - mock_db.session.query.assert_called_once_with(Site) + mock_db.session.scalar.assert_called_once() @patch("controllers.service_api.wraps.user_logged_in") @patch("controllers.service_api.app.site.db") @@ -140,7 +140,7 @@ class TestAppSiteApi: mock_tenant.status = TenantStatus.NORMAL mock_app_model.tenant = mock_tenant - mock_wraps_db.session.query.return_value.where.return_value.first.side_effect = [ + mock_wraps_db.session.get.side_effect = [ mock_app_model, mock_tenant, ] @@ -150,7 +150,7 @@ class TestAppSiteApi: setup_mock_tenant_account_query(mock_wraps_db, mock_tenant, mock_account) # Mock site query to return None - mock_db.session.query.return_value.where.return_value.first.return_value = None + mock_db.session.scalar.return_value = None # Act & Assert with app.test_request_context("/site", method="GET", headers={"Authorization": "Bearer test_token"}): @@ -187,7 +187,7 @@ class TestAppSiteApi: mock_tenant = Mock() mock_tenant.status = TenantStatus.NORMAL - mock_wraps_db.session.query.return_value.where.return_value.first.side_effect = [ + mock_wraps_db.session.get.side_effect = [ mock_app_model, mock_tenant, ] @@ -197,7 +197,7 @@ class TestAppSiteApi: setup_mock_tenant_account_query(mock_wraps_db, mock_tenant, mock_account) # Mock site query - mock_db.session.query.return_value.where.return_value.first.return_value = mock_site + mock_db.session.scalar.return_value = mock_site # Set tenant status to archived AFTER authentication mock_app_model.tenant.status = TenantStatus.ARCHIVE @@ -230,7 +230,7 @@ class TestAppSiteApi: mock_tenant.status = TenantStatus.NORMAL mock_app_model.tenant = mock_tenant - mock_wraps_db.session.query.return_value.where.return_value.first.side_effect = [ + mock_wraps_db.session.get.side_effect = [ mock_app_model, mock_tenant, ] @@ -258,7 +258,7 @@ class TestAppSiteApi: mock_site.icon_type = "image" mock_site.created_at = "2024-01-01T00:00:00" mock_site.updated_at = "2024-01-01T00:00:00" - mock_db.session.query.return_value.where.return_value.first.return_value = mock_site + mock_db.session.scalar.return_value = mock_site # Act with app.test_request_context("/site", method="GET", headers={"Authorization": "Bearer test_token"}): @@ -267,4 +267,4 @@ class TestAppSiteApi: # Assert # The query was executed successfully (site returned), which validates the correct query was made - mock_db.session.query.assert_called_once_with(Site) + mock_db.session.scalar.assert_called_once() diff --git a/api/tests/unit_tests/controllers/service_api/test_wraps.py b/api/tests/unit_tests/controllers/service_api/test_wraps.py index 9c2d075f41..a2008e024b 100644 --- a/api/tests/unit_tests/controllers/service_api/test_wraps.py +++ b/api/tests/unit_tests/controllers/service_api/test_wraps.py @@ -144,14 +144,10 @@ class TestValidateAppToken: mock_ta = Mock() mock_ta.account_id = mock_account.id - # Use side_effect to return app first, then tenant - mock_db.session.query.return_value.where.return_value.first.side_effect = [ - mock_app, - mock_tenant, - mock_account, - ] + # Use side_effect to return app first, then tenant via session.get() + mock_db.session.get.side_effect = [mock_app, mock_tenant] - # Mock the tenant owner query + # Mock the tenant owner query (execute(select(...)).one_or_none()) setup_mock_tenant_account_query(mock_db, mock_tenant, mock_ta) @validate_app_token @@ -175,7 +171,7 @@ class TestValidateAppToken: mock_api_token.app_id = str(uuid.uuid4()) mock_validate_token.return_value = mock_api_token - mock_db.session.query.return_value.where.return_value.first.return_value = None + mock_db.session.get.return_value = None @validate_app_token def protected_view(**kwargs): @@ -198,7 +194,7 @@ class TestValidateAppToken: mock_app = Mock() mock_app.status = "abnormal" - mock_db.session.query.return_value.where.return_value.first.return_value = mock_app + mock_db.session.get.return_value = mock_app @validate_app_token def protected_view(**kwargs): @@ -222,7 +218,7 @@ class TestValidateAppToken: mock_app = Mock() mock_app.status = "normal" mock_app.enable_api = False - mock_db.session.query.return_value.where.return_value.first.return_value = mock_app + mock_db.session.get.return_value = mock_app @validate_app_token def protected_view(**kwargs): @@ -474,11 +470,11 @@ class TestValidateDatasetToken: mock_account.id = mock_ta.account_id mock_account.current_tenant = mock_tenant - # Mock the tenant account join query + # Mock the tenant account join query (execute(select(...)).one_or_none()) setup_mock_dataset_tenant_query(mock_db, mock_tenant, mock_ta) - # Mock the account query - mock_db.session.query.return_value.where.return_value.first.return_value = mock_account + # Mock the account lookup via session.get() + mock_db.session.get.return_value = mock_account @validate_dataset_token def protected_view(tenant_id): @@ -501,7 +497,7 @@ class TestValidateDatasetToken: mock_api_token.tenant_id = str(uuid.uuid4()) mock_validate_token.return_value = mock_api_token - mock_db.session.query.return_value.where.return_value.first.return_value = None + mock_db.session.scalar.return_value = None @validate_dataset_token def protected_view(dataset_id=None, **kwargs): diff --git a/api/tests/unit_tests/controllers/web/test_audio.py b/api/tests/unit_tests/controllers/web/test_audio.py index 01f34345aa..a6ca441801 100644 --- a/api/tests/unit_tests/controllers/web/test_audio.py +++ b/api/tests/unit_tests/controllers/web/test_audio.py @@ -21,7 +21,7 @@ from controllers.web.error import ( UnsupportedAudioTypeError, ) from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError -from dify_graph.model_runtime.errors.invoke import InvokeError +from graphon.model_runtime.errors.invoke import InvokeError from services.errors.audio import ( AudioTooLargeServiceError, NoAudioUploadedServiceError, diff --git a/api/tests/unit_tests/controllers/web/test_completion.py b/api/tests/unit_tests/controllers/web/test_completion.py index e88bcf2ae6..4f8d848637 100644 --- a/api/tests/unit_tests/controllers/web/test_completion.py +++ b/api/tests/unit_tests/controllers/web/test_completion.py @@ -18,7 +18,7 @@ from controllers.web.error import ( ProviderQuotaExceededError, ) from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError -from dify_graph.model_runtime.errors.invoke import InvokeError +from graphon.model_runtime.errors.invoke import InvokeError def _completion_app() -> SimpleNamespace: diff --git a/api/tests/unit_tests/core/agent/test_cot_agent_runner.py b/api/tests/unit_tests/core/agent/test_cot_agent_runner.py index f6d1edbaf0..cde8820e00 100644 --- a/api/tests/unit_tests/core/agent/test_cot_agent_runner.py +++ b/api/tests/unit_tests/core/agent/test_cot_agent_runner.py @@ -6,7 +6,7 @@ import pytest from core.agent.cot_agent_runner import CotAgentRunner from core.agent.entities import AgentScratchpadUnit from core.agent.errors import AgentMaxIterationError -from dify_graph.model_runtime.entities.llm_entities import LLMUsage +from graphon.model_runtime.entities.llm_entities import LLMUsage class DummyRunner(CotAgentRunner): @@ -387,7 +387,7 @@ class TestRun: runner.update_prompt_message_tool.assert_called_once() def test_historic_with_assistant_and_tool_calls(self, runner): - from dify_graph.model_runtime.entities.message_entities import AssistantPromptMessage, ToolPromptMessage + from graphon.model_runtime.entities.message_entities import AssistantPromptMessage, ToolPromptMessage assistant = AssistantPromptMessage(content="thinking") assistant.tool_calls = [MagicMock(function=MagicMock(name="tool", arguments='{"a":1}'))] @@ -400,7 +400,7 @@ class TestRun: assert isinstance(result, list) def test_historic_final_flush_branch(self, runner): - from dify_graph.model_runtime.entities.message_entities import AssistantPromptMessage + from graphon.model_runtime.entities.message_entities import AssistantPromptMessage assistant = AssistantPromptMessage(content="final") runner.history_prompt_messages = [assistant] @@ -458,7 +458,7 @@ class TestFillInputsEdgeCases: class TestOrganizeHistoricPromptMessagesExtended: def test_user_message_flushes_scratchpad(self, runner, mocker): - from dify_graph.model_runtime.entities.message_entities import UserPromptMessage + from graphon.model_runtime.entities.message_entities import UserPromptMessage user_message = UserPromptMessage(content="Hi") @@ -473,7 +473,7 @@ class TestOrganizeHistoricPromptMessagesExtended: assert result == ["final"] def test_tool_message_without_scratchpad_raises(self, runner): - from dify_graph.model_runtime.entities.message_entities import ToolPromptMessage + from graphon.model_runtime.entities.message_entities import ToolPromptMessage runner.history_prompt_messages = [ToolPromptMessage(content="obs", tool_call_id="1")] diff --git a/api/tests/unit_tests/core/agent/test_cot_chat_agent_runner.py b/api/tests/unit_tests/core/agent/test_cot_chat_agent_runner.py index f9d69d1196..ea8cc8aa86 100644 --- a/api/tests/unit_tests/core/agent/test_cot_chat_agent_runner.py +++ b/api/tests/unit_tests/core/agent/test_cot_chat_agent_runner.py @@ -3,7 +3,7 @@ from unittest.mock import MagicMock, patch import pytest from core.agent.cot_chat_agent_runner import CotChatAgentRunner -from dify_graph.model_runtime.entities.message_entities import TextPromptMessageContent +from graphon.model_runtime.entities.message_entities import TextPromptMessageContent from tests.unit_tests.core.agent.conftest import ( DummyAgentConfig, DummyAppConfig, @@ -93,7 +93,7 @@ class TestOrganizeUserQuery: @patch("core.agent.cot_chat_agent_runner.UserPromptMessage") @patch("core.agent.cot_chat_agent_runner.file_manager.to_prompt_message_content") def test_organize_user_query_with_image_file_default_config(self, mock_to_prompt, mock_user_prompt, runner): - from dify_graph.model_runtime.entities.message_entities import ImagePromptMessageContent + from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent mock_content = ImagePromptMessageContent( url="http://test", @@ -118,7 +118,7 @@ class TestOrganizeUserQuery: @patch("core.agent.cot_chat_agent_runner.UserPromptMessage") @patch("core.agent.cot_chat_agent_runner.file_manager.to_prompt_message_content") def test_organize_user_query_with_image_file_high_detail(self, mock_to_prompt, mock_user_prompt, runner): - from dify_graph.model_runtime.entities.message_entities import ImagePromptMessageContent + from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent mock_content = ImagePromptMessageContent( url="http://test", diff --git a/api/tests/unit_tests/core/agent/test_cot_completion_agent_runner.py b/api/tests/unit_tests/core/agent/test_cot_completion_agent_runner.py index ab822bb57d..2f5873d865 100644 --- a/api/tests/unit_tests/core/agent/test_cot_completion_agent_runner.py +++ b/api/tests/unit_tests/core/agent/test_cot_completion_agent_runner.py @@ -3,7 +3,7 @@ import json import pytest from core.agent.cot_completion_agent_runner import CotCompletionAgentRunner -from dify_graph.model_runtime.entities.message_entities import ( +from graphon.model_runtime.entities.message_entities import ( AssistantPromptMessage, ImagePromptMessageContent, TextPromptMessageContent, diff --git a/api/tests/unit_tests/core/agent/test_fc_agent_runner.py b/api/tests/unit_tests/core/agent/test_fc_agent_runner.py index 299c9b31d2..17ab5babcb 100644 --- a/api/tests/unit_tests/core/agent/test_fc_agent_runner.py +++ b/api/tests/unit_tests/core/agent/test_fc_agent_runner.py @@ -8,8 +8,8 @@ from core.agent.errors import AgentMaxIterationError from core.agent.fc_agent_runner import FunctionCallAgentRunner from core.app.apps.base_app_queue_manager import PublishFrom from core.app.entities.queue_entities import QueueMessageFileEvent -from dify_graph.model_runtime.entities.llm_entities import LLMUsage -from dify_graph.model_runtime.entities.message_entities import ( +from graphon.model_runtime.entities.llm_entities import LLMUsage +from graphon.model_runtime.entities.message_entities import ( DocumentPromptMessageContent, ImagePromptMessageContent, TextPromptMessageContent, diff --git a/api/tests/unit_tests/core/app/app_config/__init__.py b/api/tests/unit_tests/core/app/app_config/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/core/app/app_config/common/test_parameters_mapping.py b/api/tests/unit_tests/core/app/app_config/common/test_parameters_mapping.py new file mode 100644 index 0000000000..1c5b6ed944 --- /dev/null +++ b/api/tests/unit_tests/core/app/app_config/common/test_parameters_mapping.py @@ -0,0 +1,227 @@ +from unittest.mock import MagicMock + +import pytest + +# Module under test +from core.app.app_config.common import parameters_mapping + + +class TestGetParametersFromFeatureDict: + """Test suite for get_parameters_from_feature_dict""" + + @pytest.fixture + def mock_config(self, monkeypatch): + """Mock dify_config values""" + mock = MagicMock() + mock.UPLOAD_IMAGE_FILE_SIZE_LIMIT = 1 + mock.UPLOAD_VIDEO_FILE_SIZE_LIMIT = 2 + mock.UPLOAD_AUDIO_FILE_SIZE_LIMIT = 3 + mock.UPLOAD_FILE_SIZE_LIMIT = 4 + mock.WORKFLOW_FILE_UPLOAD_LIMIT = 5 + + monkeypatch.setattr(parameters_mapping, "dify_config", mock) + return mock + + @pytest.fixture + def mock_default_file_limits(self, monkeypatch): + """Mock DEFAULT_FILE_NUMBER_LIMITS constant""" + monkeypatch.setattr(parameters_mapping, "DEFAULT_FILE_NUMBER_LIMITS", 99) + return 99 + + @pytest.fixture + def minimal_inputs(self): + return {}, [] + + @pytest.mark.parametrize( + ("feature_key", "expected_default"), + [ + ("suggested_questions", []), + ("suggested_questions_after_answer", {"enabled": False}), + ("speech_to_text", {"enabled": False}), + ("text_to_speech", {"enabled": False}), + ("retriever_resource", {"enabled": False}), + ("annotation_reply", {"enabled": False}), + ("more_like_this", {"enabled": False}), + ( + "sensitive_word_avoidance", + {"enabled": False, "type": "", "configs": []}, + ), + ], + ) + def test_defaults_when_key_missing( + self, + feature_key, + expected_default, + mock_config, + mock_default_file_limits, + ): + # Arrange + features = {} + user_input = [] + + # Act + result = parameters_mapping.get_parameters_from_feature_dict( + features_dict=features, + user_input_form=user_input, + ) + + # Assert + assert result[feature_key] == expected_default + + def test_opening_statement_present(self, mock_config, mock_default_file_limits): + # Arrange + features = {"opening_statement": "Hello"} + + # Act + result = parameters_mapping.get_parameters_from_feature_dict( + features_dict=features, + user_input_form=[], + ) + + # Assert + assert result["opening_statement"] == "Hello" + + def test_opening_statement_missing_returns_none(self, mock_config, mock_default_file_limits): + # Arrange + features = {} + + # Act + result = parameters_mapping.get_parameters_from_feature_dict( + features_dict=features, + user_input_form=[], + ) + + # Assert + assert result["opening_statement"] is None + + def test_all_features_provided(self, mock_config, mock_default_file_limits): + # Arrange + features = { + "opening_statement": "Hi", + "suggested_questions": ["Q1"], + "suggested_questions_after_answer": {"enabled": True}, + "speech_to_text": {"enabled": True}, + "text_to_speech": {"enabled": True}, + "retriever_resource": {"enabled": True}, + "annotation_reply": {"enabled": True}, + "more_like_this": {"enabled": True}, + "sensitive_word_avoidance": { + "enabled": True, + "type": "strict", + "configs": ["a"], + }, + "file_upload": { + "image": { + "enabled": True, + "number_limits": 10, + "detail": "low", + "transfer_methods": ["local_file"], + } + }, + } + user_input = [{"name": "field1"}] + + # Act + result = parameters_mapping.get_parameters_from_feature_dict( + features_dict=features, + user_input_form=user_input, + ) + + # Assert + for key in features: + assert result[key] == features[key] + assert result["user_input_form"] == user_input + + def test_file_upload_default_structure(self, mock_config, mock_default_file_limits): + # Arrange + features = {} + + # Act + result = parameters_mapping.get_parameters_from_feature_dict( + features_dict=features, + user_input_form=[], + ) + + # Assert + file_upload = result["file_upload"] + assert file_upload["image"]["enabled"] is False + assert file_upload["image"]["number_limits"] == 99 + assert file_upload["image"]["detail"] == "high" + assert "remote_url" in file_upload["image"]["transfer_methods"] + assert "local_file" in file_upload["image"]["transfer_methods"] + + def test_system_parameters_from_config(self, mock_config, mock_default_file_limits): + # Arrange + features = {} + + # Act + result = parameters_mapping.get_parameters_from_feature_dict( + features_dict=features, + user_input_form=[], + ) + + # Assert + system_params = result["system_parameters"] + assert system_params["image_file_size_limit"] == 1 + assert system_params["video_file_size_limit"] == 2 + assert system_params["audio_file_size_limit"] == 3 + assert system_params["file_size_limit"] == 4 + assert system_params["workflow_file_upload_limit"] == 5 + + @pytest.mark.parametrize( + ("features_dict", "user_input_form"), + [ + (None, []), + ([], []), + ("invalid", []), + ], + ) + def test_invalid_features_dict_type_raises(self, features_dict, user_input_form): + # Act & Assert + with pytest.raises(AttributeError): + parameters_mapping.get_parameters_from_feature_dict( + features_dict=features_dict, + user_input_form=user_input_form, + ) + + @pytest.mark.parametrize( + "user_input_form", + [None, "invalid", 123], + ) + def test_user_input_form_invalid_type(self, mock_config, mock_default_file_limits, user_input_form): + # Arrange + features = {} + + # Act + result = parameters_mapping.get_parameters_from_feature_dict( + features_dict=features, + user_input_form=user_input_form, + ) + + # Assert + assert result["user_input_form"] == user_input_form + + def test_empty_user_input_form(self, mock_config, mock_default_file_limits): + features = {} + user_input = [] + + result = parameters_mapping.get_parameters_from_feature_dict( + features_dict=features, + user_input_form=user_input, + ) + + assert result["user_input_form"] == [] + + def test_feature_values_none(self, mock_config, mock_default_file_limits): + features = { + "suggested_questions": None, + "speech_to_text": None, + } + + result = parameters_mapping.get_parameters_from_feature_dict( + features_dict=features, + user_input_form=[], + ) + + assert result["suggested_questions"] is None + assert result["speech_to_text"] is None diff --git a/api/tests/unit_tests/core/app/app_config/common/test_sensitive_word_avoidance_manager.py b/api/tests/unit_tests/core/app/app_config/common/test_sensitive_word_avoidance_manager.py new file mode 100644 index 0000000000..013ed0cbc4 --- /dev/null +++ b/api/tests/unit_tests/core/app/app_config/common/test_sensitive_word_avoidance_manager.py @@ -0,0 +1,202 @@ +from unittest.mock import MagicMock + +import pytest + +from core.app.app_config.common.sensitive_word_avoidance.manager import ( + SensitiveWordAvoidanceConfigManager, +) + + +class TestSensitiveWordAvoidanceConfigManagerConvert: + """Tests for convert classmethod""" + + @pytest.mark.parametrize( + "config", + [ + {}, + {"sensitive_word_avoidance": None}, + {"sensitive_word_avoidance": {}}, + {"sensitive_word_avoidance": {"enabled": False}}, + ], + ) + def test_convert_returns_none_when_disabled_or_missing(self, config): + # Act + result = SensitiveWordAvoidanceConfigManager.convert(config) + + # Assert + assert result is None + + def test_convert_returns_entity_when_enabled(self, mocker): + # Arrange + mock_entity = MagicMock() + mocker.patch( + "core.app.app_config.common.sensitive_word_avoidance.manager.SensitiveWordAvoidanceEntity", + return_value=mock_entity, + ) + + config = { + "sensitive_word_avoidance": { + "enabled": True, + "type": "mock_type", + "config": {"key": "value"}, + } + } + + # Act + result = SensitiveWordAvoidanceConfigManager.convert(config) + + # Assert + assert result == mock_entity + + def test_convert_enabled_without_type_or_config(self, mocker): + # Arrange + mock_entity = MagicMock() + patched = mocker.patch( + "core.app.app_config.common.sensitive_word_avoidance.manager.SensitiveWordAvoidanceEntity", + return_value=mock_entity, + ) + + config = {"sensitive_word_avoidance": {"enabled": True}} + + # Act + result = SensitiveWordAvoidanceConfigManager.convert(config) + + # Assert + patched.assert_called_once_with(type=None, config={}) + assert result == mock_entity + + +class TestSensitiveWordAvoidanceConfigManagerValidateAndSetDefaults: + """Tests for validate_and_set_defaults classmethod""" + + @pytest.fixture + def base_config(self): + return {} + + def test_validate_sets_default_when_missing(self, base_config): + # Act + config, fields = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults( + tenant_id="tenant1", config=base_config.copy() + ) + + # Assert + assert config["sensitive_word_avoidance"]["enabled"] is False + assert fields == ["sensitive_word_avoidance"] + + def test_validate_raises_when_not_dict(self): + config = {"sensitive_word_avoidance": "invalid"} + + with pytest.raises(ValueError, match="must be of dict type"): + SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(tenant_id="tenant1", config=config) + + @pytest.mark.parametrize( + "config", + [ + {"sensitive_word_avoidance": {"enabled": False}}, + {"sensitive_word_avoidance": {"enabled": None}}, + {"sensitive_word_avoidance": {}}, + ], + ) + def test_validate_disables_when_enabled_false_or_missing(self, config): + # Act + result_config, _ = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults( + tenant_id="tenant1", config=config + ) + + # Assert + assert result_config["sensitive_word_avoidance"]["enabled"] is False + + def test_validate_raises_when_enabled_true_without_type(self): + config = {"sensitive_word_avoidance": {"enabled": True}} + + with pytest.raises(ValueError, match="type is required"): + SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(tenant_id="tenant1", config=config) + + def test_validate_raises_when_type_not_string(self): + config = { + "sensitive_word_avoidance": { + "enabled": True, + "type": 123, + } + } + + with pytest.raises(ValueError, match="must be a string"): + SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(tenant_id="tenant1", config=config) + + def test_validate_raises_when_config_not_dict(self): + config = { + "sensitive_word_avoidance": { + "enabled": True, + "type": "mock_type", + "config": "invalid", + } + } + + with pytest.raises(ValueError, match="must be a dict"): + SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(tenant_id="tenant1", config=config) + + def test_validate_calls_moderation_factory(self, mocker): + # Arrange + mock_validate = mocker.patch( + "core.app.app_config.common.sensitive_word_avoidance.manager.ModerationFactory.validate_config" + ) + + config = { + "sensitive_word_avoidance": { + "enabled": True, + "type": "mock_type", + "config": {"k": "v"}, + } + } + + # Act + result_config, fields = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults( + tenant_id="tenant1", config=config + ) + + # Assert + mock_validate.assert_called_once_with(name="mock_type", tenant_id="tenant1", config={"k": "v"}) + assert result_config["sensitive_word_avoidance"]["enabled"] is True + assert fields == ["sensitive_word_avoidance"] + + def test_validate_sets_empty_dict_when_config_none(self, mocker): + # Arrange + mock_validate = mocker.patch( + "core.app.app_config.common.sensitive_word_avoidance.manager.ModerationFactory.validate_config" + ) + + config = { + "sensitive_word_avoidance": { + "enabled": True, + "type": "mock_type", + "config": None, + } + } + + # Act + SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(tenant_id="tenant1", config=config) + + # Assert + mock_validate.assert_called_once_with(name="mock_type", tenant_id="tenant1", config={}) + + def test_validate_only_structure_validate_skips_factory(self, mocker): + # Arrange + mock_validate = mocker.patch( + "core.app.app_config.common.sensitive_word_avoidance.manager.ModerationFactory.validate_config" + ) + + config = { + "sensitive_word_avoidance": { + "enabled": True, + "type": "mock_type", + "config": {"k": "v"}, + } + } + + # Act + SensitiveWordAvoidanceConfigManager.validate_and_set_defaults( + tenant_id="tenant1", config=config, only_structure_validate=True + ) + + # Assert + mock_validate.assert_not_called() diff --git a/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_agent_manager.py b/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_agent_manager.py new file mode 100644 index 0000000000..992b580376 --- /dev/null +++ b/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_agent_manager.py @@ -0,0 +1,236 @@ +from unittest.mock import MagicMock + +import pytest + +from core.app.app_config.easy_ui_based_app.agent.manager import AgentConfigManager + + +class TestAgentConfigManagerConvert: + @pytest.fixture + def base_config(self): + return { + "agent_mode": { + "enabled": True, + "strategy": "cot", + "tools": [], + }, + "model": { + "provider": "openai", + "name": "gpt-4", + "mode": "completion", + }, + } + + def test_convert_returns_none_when_agent_mode_missing(self): + config = {"model": {"provider": "openai", "name": "gpt-4"}} + + result = AgentConfigManager.convert(config) + + assert result is None + + @pytest.mark.parametrize("agent_mode_value", [None, {}, {"enabled": False}]) + def test_convert_returns_none_when_agent_mode_disabled(self, agent_mode_value, base_config): + config = base_config.copy() + config["agent_mode"] = agent_mode_value + + result = AgentConfigManager.convert(config) + + assert result is None + + @pytest.mark.parametrize( + ("strategy_input", "expected_enum"), + [ + ("function_call", "FUNCTION_CALLING"), + ("cot", "CHAIN_OF_THOUGHT"), + ("react", "CHAIN_OF_THOUGHT"), + ], + ) + def test_convert_strategy_mapping(self, strategy_input, expected_enum, base_config): + config = base_config.copy() + config["agent_mode"] = { + "enabled": True, + "strategy": strategy_input, + "tools": [], + } + + result = AgentConfigManager.convert(config) + + assert result is not None + assert result.strategy.name == expected_enum + + def test_convert_unknown_strategy_openai_defaults_to_function_calling(self, base_config): + config = base_config.copy() + config["agent_mode"] = { + "enabled": True, + "strategy": "unknown_strategy", + "tools": [], + } + config["model"]["provider"] = "openai" + + result = AgentConfigManager.convert(config) + + assert result.strategy.name == "FUNCTION_CALLING" + + def test_convert_unknown_strategy_non_openai_defaults_to_chain_of_thought(self, base_config): + config = base_config.copy() + config["agent_mode"] = { + "enabled": True, + "strategy": "unknown_strategy", + "tools": [], + } + config["model"]["provider"] = "anthropic" + + result = AgentConfigManager.convert(config) + + assert result.strategy.name == "CHAIN_OF_THOUGHT" + + def test_convert_skips_disabled_tools(self, mocker, base_config): + # Patch AgentEntity to bypass pydantic validation + mock_agent_entity = mocker.patch( + "core.app.app_config.easy_ui_based_app.agent.manager.AgentEntity", + return_value=MagicMock(), + ) + + mock_validate = mocker.patch( + "core.app.app_config.easy_ui_based_app.agent.manager.AgentToolEntity.model_validate", + return_value={ + "provider_type": "type2", + "provider_id": "id2", + "tool_name": "tool2", + "tool_parameters": {}, + "credential_id": None, + }, + ) + + config = base_config.copy() + config["agent_mode"] = { + "enabled": True, + "strategy": "cot", + "tools": [ + { + "provider_type": "type1", + "provider_id": "id1", + "tool_name": "tool1", + "enabled": False, + }, + { + "provider_type": "type2", + "provider_id": "id2", + "tool_name": "tool2", + "enabled": True, + "extra_key": "x", + }, + ], + } + + AgentConfigManager.convert(config) + + mock_validate.assert_called_once() + mock_agent_entity.assert_called_once() + + def test_convert_tool_requires_minimum_keys(self, mocker, base_config): + mock_validate = mocker.patch( + "core.app.app_config.easy_ui_based_app.agent.manager.AgentToolEntity.model_validate", + return_value=MagicMock(), + ) + + config = base_config.copy() + config["agent_mode"] = { + "enabled": True, + "strategy": "cot", + "tools": [ + {"a": 1, "b": 2}, # insufficient keys + ], + } + + result = AgentConfigManager.convert(config) + + assert result is not None + assert result.tools == [] + mock_validate.assert_not_called() + + def test_convert_completion_mode_prompt_defaults(self, base_config): + config = base_config.copy() + config["agent_mode"]["prompt"] = {} + config["model"]["mode"] = "completion" + + result = AgentConfigManager.convert(config) + + assert result is not None + assert result.prompt.first_prompt is not None + assert result.prompt.next_iteration is not None + + def test_convert_chat_mode_prompt_defaults(self, base_config): + config = base_config.copy() + config["agent_mode"]["prompt"] = {} + config["model"]["mode"] = "chat" + + result = AgentConfigManager.convert(config) + + assert result is not None + assert result.prompt.first_prompt is not None + assert result.prompt.next_iteration is not None + + def test_convert_router_strategy_returns_none(self, base_config): + config = base_config.copy() + config["agent_mode"] = { + "enabled": True, + "strategy": "router", + "tools": [], + } + + result = AgentConfigManager.convert(config) + + assert result is None + + def test_convert_react_router_strategy_returns_none(self, base_config): + config = base_config.copy() + config["agent_mode"] = { + "enabled": True, + "strategy": "react_router", + "tools": [], + } + + result = AgentConfigManager.convert(config) + + assert result is None + + def test_convert_max_iteration_default(self, base_config): + config = base_config.copy() + config["agent_mode"].pop("max_iteration", None) + + result = AgentConfigManager.convert(config) + + assert result.max_iteration == 10 + + def test_convert_custom_max_iteration(self, base_config): + config = base_config.copy() + config["agent_mode"]["max_iteration"] = 25 + + result = AgentConfigManager.convert(config) + + assert result.max_iteration == 25 + + def test_convert_missing_model_raises_key_error(self, base_config): + config = base_config.copy() + del config["model"] + + with pytest.raises(KeyError): + AgentConfigManager.convert(config) + + @pytest.mark.parametrize( + ("invalid_config", "should_raise"), + [ + (None, True), + (123, True), + ("", False), + ([], False), + ], + ) + def test_convert_invalid_input_type_behavior(self, invalid_config, should_raise): + if should_raise: + with pytest.raises(TypeError): + AgentConfigManager.convert(invalid_config) # type: ignore + else: + result = AgentConfigManager.convert(invalid_config) # type: ignore + assert result is None diff --git a/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_dataset_manager.py b/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_dataset_manager.py new file mode 100644 index 0000000000..a688e2a5c5 --- /dev/null +++ b/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_dataset_manager.py @@ -0,0 +1,319 @@ +import uuid +from unittest.mock import MagicMock + +import pytest + +from core.app.app_config.easy_ui_based_app.dataset.manager import DatasetConfigManager +from core.entities.agent_entities import PlanningStrategy +from models.model import AppMode + +# ============================== +# Fixtures +# ============================== + + +@pytest.fixture +def valid_uuid(): + return str(uuid.uuid4()) + + +@pytest.fixture +def base_config(valid_uuid): + return { + "dataset_configs": { + "retrieval_model": "multiple", + "datasets": { + "strategy": "router", + "datasets": [{"dataset": {"id": valid_uuid, "enabled": True}}], + }, + } + } + + +@pytest.fixture +def mock_dataset_service(mocker, valid_uuid): + mock_dataset = MagicMock() + mock_dataset.tenant_id = "tenant1" + + mocker.patch( + "core.app.app_config.easy_ui_based_app.dataset.manager.DatasetService.get_dataset", + return_value=mock_dataset, + ) + + +# ============================== +# convert tests +# ============================== + + +class TestDatasetConfigManagerConvert: + def test_convert_returns_none_when_no_datasets(self): + config = {"dataset_configs": {"datasets": {"datasets": []}}} + result = DatasetConfigManager.convert(config) + assert result is None + + def test_convert_single_retrieval(self, valid_uuid): + config = { + "dataset_query_variable": "query", + "dataset_configs": { + "retrieval_model": "single", + "datasets": { + "strategy": "router", + "datasets": [{"dataset": {"id": valid_uuid, "enabled": True}}], + }, + }, + } + + result = DatasetConfigManager.convert(config) + assert result is not None + assert result.dataset_ids == [valid_uuid] + assert result.retrieve_config.query_variable == "query" + + def test_convert_single_with_metadata_configs(self, valid_uuid, mocker): + mock_retrieve_config = MagicMock() + mock_entity = MagicMock() + mock_entity.dataset_ids = [valid_uuid] + mock_entity.retrieve_config = mock_retrieve_config + + mocker.patch( + "core.app.app_config.easy_ui_based_app.dataset.manager.ModelConfig", + return_value={"mock": "model"}, + ) + mocker.patch( + "core.app.app_config.easy_ui_based_app.dataset.manager.MetadataFilteringCondition", + return_value={"mock": "condition"}, + ) + mocker.patch( + "core.app.app_config.easy_ui_based_app.dataset.manager.DatasetRetrieveConfigEntity", + return_value=mock_retrieve_config, + ) + mocker.patch( + "core.app.app_config.easy_ui_based_app.dataset.manager.DatasetEntity", + return_value=mock_entity, + ) + + config = { + "dataset_query_variable": "query", + "dataset_configs": { + "retrieval_model": "single", + "metadata_filtering_mode": "manual", + "metadata_model_config": {"any": "value"}, + "metadata_filtering_conditions": {"any": "value"}, + "datasets": { + "strategy": "router", + "datasets": [{"dataset": {"id": valid_uuid, "enabled": True}}], + }, + }, + } + result = DatasetConfigManager.convert(config) + assert result.dataset_ids == [valid_uuid] + assert result.retrieve_config is mock_retrieve_config + + def test_convert_multiple_defaults(self, valid_uuid): + config = { + "dataset_configs": { + "retrieval_model": "multiple", + "datasets": { + "strategy": "router", + "datasets": [{"dataset": {"id": valid_uuid, "enabled": True}}], + }, + } + } + result = DatasetConfigManager.convert(config) + assert result.retrieve_config.top_k == 4 + assert result.retrieve_config.score_threshold is None + assert result.retrieve_config.reranking_enabled is True + + def test_convert_agent_mode_disabled_tool(self, valid_uuid): + config = { + "agent_mode": { + "enabled": True, + "tools": [{"dataset": {"id": valid_uuid, "enabled": False}}], + } + } + result = DatasetConfigManager.convert(config) + assert result is None + + def test_convert_dataset_configs_none(self): + config = {"dataset_configs": None} + with pytest.raises(TypeError): + DatasetConfigManager.convert(config) + + def test_convert_agent_mode_old_style_old_format(self, valid_uuid): + config = { + "agent_mode": { + "enabled": True, + "tools": [{"dataset": {"id": valid_uuid, "enabled": True}}], + } + } + result = DatasetConfigManager.convert(config) + assert result.dataset_ids == [valid_uuid] + assert result.retrieve_config.query_variable is None + + def test_convert_multiple_with_score_threshold(self, valid_uuid): + config = { + "dataset_query_variable": "query", + "dataset_configs": { + "retrieval_model": "multiple", + "top_k": 5, + "score_threshold": 0.8, + "score_threshold_enabled": True, + "datasets": { + "strategy": "router", + "datasets": [{"dataset": {"id": valid_uuid, "enabled": True}}], + }, + }, + } + + result = DatasetConfigManager.convert(config) + assert result.retrieve_config.top_k == 5 + assert result.retrieve_config.score_threshold == 0.8 + + @pytest.mark.parametrize( + "dataset_entry", + [ + {}, + {"invalid": {}}, + {"dataset": {"id": None, "enabled": True}}, + {"dataset": {"id": "", "enabled": False}}, + ], + ) + def test_convert_ignores_invalid_dataset_entries(self, dataset_entry): + config = { + "dataset_configs": { + "retrieval_model": "multiple", + "datasets": {"strategy": "router", "datasets": [dataset_entry]}, + } + } + result = DatasetConfigManager.convert(config) + assert result is None + + def test_convert_agent_mode_old_style(self, valid_uuid): + config = { + "agent_mode": { + "enabled": True, + "tools": [{"dataset": {"id": valid_uuid, "enabled": True}}], + } + } + result = DatasetConfigManager.convert(config) + assert result.dataset_ids == [valid_uuid] + + +# ============================== +# validate_and_set_defaults tests +# ============================== + + +class TestValidateAndSetDefaults: + def test_validate_sets_defaults(self): + config = {} + updated, fields = DatasetConfigManager.validate_and_set_defaults("tenant1", AppMode.CHAT, config) + assert "dataset_configs" in updated + assert updated["dataset_configs"]["retrieval_model"] == "single" + assert isinstance(fields, list) + + def test_validate_raises_when_dataset_configs_not_dict(self): + config = {"dataset_configs": "invalid"} + with pytest.raises(AttributeError): + DatasetConfigManager.validate_and_set_defaults("tenant1", AppMode.CHAT, config) + + def test_validate_requires_query_variable_in_completion_mode(self, valid_uuid): + config = { + "dataset_configs": { + "datasets": { + "strategy": "router", + "datasets": [{"dataset": {"id": valid_uuid, "enabled": True}}], + } + } + } + with pytest.raises(ValueError): + DatasetConfigManager.validate_and_set_defaults("tenant1", AppMode.COMPLETION, config) + + +# ============================== +# extract_dataset_config_for_legacy_compatibility tests +# ============================== + + +class TestExtractDatasetConfig: + def test_extract_sets_defaults(self): + config = {} + result = DatasetConfigManager.extract_dataset_config_for_legacy_compatibility("tenant1", AppMode.CHAT, config) + assert "agent_mode" in result + assert result["agent_mode"]["enabled"] is False + assert result["agent_mode"]["tools"] == [] + + def test_extract_invalid_agent_mode_type(self): + config = {"agent_mode": "invalid"} + with pytest.raises(ValueError): + DatasetConfigManager.extract_dataset_config_for_legacy_compatibility("tenant1", AppMode.CHAT, config) + + def test_extract_invalid_enabled_type(self): + config = {"agent_mode": {"enabled": "yes"}} + with pytest.raises(ValueError): + DatasetConfigManager.extract_dataset_config_for_legacy_compatibility("tenant1", AppMode.CHAT, config) + + def test_extract_invalid_tools_type(self): + config = {"agent_mode": {"enabled": True, "tools": "invalid"}} + with pytest.raises(ValueError): + DatasetConfigManager.extract_dataset_config_for_legacy_compatibility("tenant1", AppMode.CHAT, config) + + def test_extract_invalid_uuid(self, mocker): + invalid_uuid = "not-a-uuid" + config = { + "agent_mode": { + "enabled": True, + "strategy": PlanningStrategy.ROUTER, + "tools": [{"dataset": {"id": invalid_uuid, "enabled": True}}], + } + } + with pytest.raises(ValueError): + DatasetConfigManager.extract_dataset_config_for_legacy_compatibility("tenant1", AppMode.CHAT, config) + + def test_extract_dataset_not_exists(self, valid_uuid, mocker): + mocker.patch( + "core.app.app_config.easy_ui_based_app.dataset.manager.DatasetService.get_dataset", + return_value=None, + ) + config = { + "agent_mode": { + "enabled": True, + "strategy": PlanningStrategy.ROUTER, + "tools": [{"dataset": {"id": valid_uuid, "enabled": True}}], + } + } + with pytest.raises(ValueError): + DatasetConfigManager.extract_dataset_config_for_legacy_compatibility("tenant1", AppMode.CHAT, config) + + +# ============================== +# is_dataset_exists tests +# ============================== + + +class TestIsDatasetExists: + def test_dataset_exists_true(self, mocker, valid_uuid): + mock_dataset = MagicMock() + mock_dataset.tenant_id = "tenant1" + mocker.patch( + "core.app.app_config.easy_ui_based_app.dataset.manager.DatasetService.get_dataset", + return_value=mock_dataset, + ) + + assert DatasetConfigManager.is_dataset_exists("tenant1", valid_uuid) + + def test_dataset_exists_false_when_not_found(self, mocker, valid_uuid): + mocker.patch( + "core.app.app_config.easy_ui_based_app.dataset.manager.DatasetService.get_dataset", + return_value=None, + ) + assert not DatasetConfigManager.is_dataset_exists("tenant1", valid_uuid) + + def test_dataset_exists_false_when_tenant_mismatch(self, mocker, valid_uuid): + mock_dataset = MagicMock() + mock_dataset.tenant_id = "other" + mocker.patch( + "core.app.app_config.easy_ui_based_app.dataset.manager.DatasetService.get_dataset", + return_value=mock_dataset, + ) + assert not DatasetConfigManager.is_dataset_exists("tenant1", valid_uuid) diff --git a/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_model_config_converter.py b/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_model_config_converter.py new file mode 100644 index 0000000000..186b4a501d --- /dev/null +++ b/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_model_config_converter.py @@ -0,0 +1,234 @@ +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest + +from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter +from core.entities.model_entities import ModelStatus +from core.errors.error import ( + ModelCurrentlyNotSupportError, + ProviderTokenNotInitError, + QuotaExceededError, +) +from graphon.model_runtime.entities.llm_entities import LLMMode +from graphon.model_runtime.entities.model_entities import ModelPropertyKey + + +class TestModelConfigConverter: + @pytest.fixture(autouse=True) + def patch_response_entity(self, mocker): + """ + Patch ModelConfigWithCredentialsEntity to bypass Pydantic validation + and return a simple namespace object instead. + """ + + def _factory(**kwargs): + return SimpleNamespace(**kwargs) + + mocker.patch( + "core.app.app_config.easy_ui_based_app.model_config.converter.ModelConfigWithCredentialsEntity", + side_effect=_factory, + ) + + @pytest.fixture + def mock_app_config(self): + app_config = MagicMock() + app_config.tenant_id = "tenant_1" + + model_config = MagicMock() + model_config.provider = "openai" + model_config.model = "gpt-4" + model_config.parameters = {"temperature": 0.5} + model_config.mode = None + + app_config.model = model_config + return app_config + + @pytest.fixture + def mock_provider_bundle(self): + bundle = MagicMock() + + # configuration + configuration = MagicMock() + configuration.provider.provider = "openai" + configuration.get_current_credentials.return_value = {"api_key": "key"} + + provider_model = MagicMock() + provider_model.status = ModelStatus.ACTIVE + configuration.get_provider_model.return_value = provider_model + + bundle.configuration = configuration + + # model type instance + model_type_instance = MagicMock() + model_schema = MagicMock() + model_schema.model_properties = {} + model_type_instance.get_model_schema.return_value = model_schema + bundle.model_type_instance = model_type_instance + + return bundle + + @pytest.fixture + def patch_provider_manager(self, mocker, mock_provider_bundle): + mock_manager = MagicMock() + mock_manager.get_provider_model_bundle.return_value = mock_provider_bundle + mocker.patch( + "core.app.app_config.easy_ui_based_app.model_config.converter.create_plugin_provider_manager", + return_value=mock_manager, + ) + return mock_manager + + # ============================= + # Positive Scenarios + # ============================= + + def test_convert_success_default_mode(self, mock_app_config, patch_provider_manager): + result = ModelConfigConverter.convert(mock_app_config) + + assert result.provider == "openai" + assert result.model == "gpt-4" + assert result.mode == LLMMode.CHAT + assert result.parameters == {"temperature": 0.5} + assert result.stop == [] + + def test_convert_success_with_stop_parameter(self, mock_app_config, patch_provider_manager): + mock_app_config.model.parameters = {"temperature": 0.7, "stop": ["\n"]} + + result = ModelConfigConverter.convert(mock_app_config) + + assert result.parameters == {"temperature": 0.7} + assert result.stop == ["\n"] + + def test_convert_mode_from_schema_valid(self, mock_app_config, mock_provider_bundle, mocker): + mock_app_config.model.mode = None + + mock_provider_bundle.model_type_instance.get_model_schema.return_value.model_properties = { + ModelPropertyKey.MODE: LLMMode.COMPLETION.value + } + + mock_manager = MagicMock() + mock_manager.get_provider_model_bundle.return_value = mock_provider_bundle + mocker.patch( + "core.app.app_config.easy_ui_based_app.model_config.converter.create_plugin_provider_manager", + return_value=mock_manager, + ) + + result = ModelConfigConverter.convert(mock_app_config) + assert result.mode == LLMMode.COMPLETION + + def test_convert_mode_from_schema_invalid_fallback(self, mock_app_config, mock_provider_bundle, mocker): + mock_provider_bundle.model_type_instance.get_model_schema.return_value.model_properties = { + ModelPropertyKey.MODE: "invalid" + } + + mock_manager = MagicMock() + mock_manager.get_provider_model_bundle.return_value = mock_provider_bundle + mocker.patch( + "core.app.app_config.easy_ui_based_app.model_config.converter.create_plugin_provider_manager", + return_value=mock_manager, + ) + + result = ModelConfigConverter.convert(mock_app_config) + assert result.mode == LLMMode.CHAT + + # ============================= + # Credential Errors + # ============================= + + def test_convert_credentials_none_raises(self, mock_app_config, mock_provider_bundle, mocker): + mock_provider_bundle.configuration.get_current_credentials.return_value = None + + mock_manager = MagicMock() + mock_manager.get_provider_model_bundle.return_value = mock_provider_bundle + mocker.patch( + "core.app.app_config.easy_ui_based_app.model_config.converter.create_plugin_provider_manager", + return_value=mock_manager, + ) + + with pytest.raises(ProviderTokenNotInitError): + ModelConfigConverter.convert(mock_app_config) + + # ============================= + # Provider Model Errors + # ============================= + + def test_convert_provider_model_none_raises(self, mock_app_config, mock_provider_bundle, mocker): + mock_provider_bundle.configuration.get_provider_model.return_value = None + + mock_manager = MagicMock() + mock_manager.get_provider_model_bundle.return_value = mock_provider_bundle + mocker.patch( + "core.app.app_config.easy_ui_based_app.model_config.converter.create_plugin_provider_manager", + return_value=mock_manager, + ) + + with pytest.raises(ValueError): + ModelConfigConverter.convert(mock_app_config) + + @pytest.mark.parametrize( + ("status", "expected_exception"), + [ + (ModelStatus.NO_CONFIGURE, ProviderTokenNotInitError), + (ModelStatus.NO_PERMISSION, ModelCurrentlyNotSupportError), + (ModelStatus.QUOTA_EXCEEDED, QuotaExceededError), + ], + ) + def test_convert_provider_model_status_errors( + self, mock_app_config, mock_provider_bundle, mocker, status, expected_exception + ): + mock_provider = MagicMock() + mock_provider.status = status + mock_provider_bundle.configuration.get_provider_model.return_value = mock_provider + + mock_manager = MagicMock() + mock_manager.get_provider_model_bundle.return_value = mock_provider_bundle + mocker.patch( + "core.app.app_config.easy_ui_based_app.model_config.converter.create_plugin_provider_manager", + return_value=mock_manager, + ) + + with pytest.raises(expected_exception): + ModelConfigConverter.convert(mock_app_config) + + # ============================= + # Schema Errors + # ============================= + + def test_convert_model_schema_none_raises(self, mock_app_config, mock_provider_bundle, mocker): + mock_provider_bundle.model_type_instance.get_model_schema.return_value = None + + mock_manager = MagicMock() + mock_manager.get_provider_model_bundle.return_value = mock_provider_bundle + mocker.patch( + "core.app.app_config.easy_ui_based_app.model_config.converter.create_plugin_provider_manager", + return_value=mock_manager, + ) + + with pytest.raises(ValueError): + ModelConfigConverter.convert(mock_app_config) + + # ============================= + # Edge Cases + # ============================= + + @pytest.mark.parametrize( + "parameters", + [ + {}, + {"stop": []}, + {"stop": ["END"], "max_tokens": 100}, + ], + ) + def test_convert_parameter_edge_cases(self, mock_app_config, patch_provider_manager, parameters): + mock_app_config.model.parameters = parameters.copy() + + result = ModelConfigConverter.convert(mock_app_config) + + if "stop" in parameters: + assert result.stop == parameters.get("stop") + expected_params = parameters.copy() + expected_params.pop("stop", None) + assert result.parameters == expected_params + else: + assert result.stop == [] + assert result.parameters == parameters diff --git a/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_model_config_manager.py b/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_model_config_manager.py new file mode 100644 index 0000000000..68bca485bb --- /dev/null +++ b/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_model_config_manager.py @@ -0,0 +1,216 @@ +from unittest.mock import MagicMock + +import pytest + +# Target +from core.app.app_config.easy_ui_based_app.model_config.manager import ModelConfigManager + +# ----------------------------- +# Fixtures +# ----------------------------- + + +@pytest.fixture +def valid_completion_params(): + return {"temperature": 0.7, "stop": ["\n"]} + + +@pytest.fixture +def valid_model_list(): + model = MagicMock() + model.model = "gpt-4" + model.model_properties = {"mode": "chat"} + return [model] + + +@pytest.fixture +def provider_entities(): + provider = MagicMock() + provider.provider = "openai/gpt" + return [provider] + + +@pytest.fixture +def valid_config(): + return { + "model": {"provider": "openai/gpt", "name": "gpt-4", "completion_params": {"temperature": 0.5, "stop": ["END"]}} + } + + +# ----------------------------- +# Test Class +# ----------------------------- + + +class TestModelConfigManager: + @staticmethod + def _patch_model_assembly(mocker, *, provider_entities, model_list): + assembly = MagicMock() + assembly.model_provider_factory.get_providers.return_value = provider_entities + assembly.provider_manager.get_configurations.return_value.get_models.return_value = model_list + mocker.patch( + "core.app.app_config.easy_ui_based_app.model_config.manager.create_plugin_model_assembly", + return_value=assembly, + ) + return assembly + + # ========================================================== + # convert + # ========================================================== + + def test_convert_success(self, valid_config): + result = ModelConfigManager.convert(valid_config) + + assert result.provider == "openai/gpt" + assert result.model == "gpt-4" + assert result.parameters == {"temperature": 0.5} + assert result.stop == ["END"] + + def test_convert_missing_model(self): + with pytest.raises(ValueError, match="model is required"): + ModelConfigManager.convert({}) + + def test_convert_without_stop(self): + config = {"model": {"provider": "openai/gpt", "name": "gpt-4", "completion_params": {"temperature": 0.9}}} + result = ModelConfigManager.convert(config) + assert result.stop == [] + assert result.parameters == {"temperature": 0.9} + + # ========================================================== + # validate_model_completion_params + # ========================================================== + + @pytest.mark.parametrize( + "invalid_cp", + [None, "string", 123, []], + ) + def test_validate_model_completion_params_invalid_type(self, invalid_cp): + with pytest.raises(ValueError, match="must be of object type"): + ModelConfigManager.validate_model_completion_params(invalid_cp) + + def test_validate_model_completion_params_default_stop(self): + cp = {"temperature": 0.2} + result = ModelConfigManager.validate_model_completion_params(cp) + assert result["stop"] == [] + + def test_validate_model_completion_params_invalid_stop_type(self): + cp = {"stop": "invalid"} + with pytest.raises(ValueError, match="must be of list type"): + ModelConfigManager.validate_model_completion_params(cp) + + def test_validate_model_completion_params_stop_length_exceeded(self): + cp = {"stop": [1, 2, 3, 4, 5]} + with pytest.raises(ValueError, match="less than 4"): + ModelConfigManager.validate_model_completion_params(cp) + + # ========================================================== + # validate_and_set_defaults + # ========================================================== + + def test_validate_and_set_defaults_success(self, mocker, valid_config, provider_entities, valid_model_list): + self._patch_model_assembly( + mocker, + provider_entities=provider_entities, + model_list=valid_model_list, + ) + + updated_config, keys = ModelConfigManager.validate_and_set_defaults("tenant1", valid_config) + + assert updated_config["model"]["mode"] == "chat" + assert keys == ["model"] + + def test_validate_and_set_defaults_missing_model(self): + with pytest.raises(ValueError, match="model is required"): + ModelConfigManager.validate_and_set_defaults("tenant1", {}) + + def test_validate_and_set_defaults_model_not_dict(self): + with pytest.raises(ValueError, match="object type"): + ModelConfigManager.validate_and_set_defaults("tenant1", {"model": "invalid"}) + + def test_validate_and_set_defaults_missing_provider(self, mocker, provider_entities): + config = {"model": {"name": "gpt-4", "completion_params": {}}} + self._patch_model_assembly(mocker, provider_entities=provider_entities, model_list=[]) + + with pytest.raises(ValueError, match="model.provider is required"): + ModelConfigManager.validate_and_set_defaults("tenant1", config) + + def test_validate_and_set_defaults_invalid_provider(self, mocker, provider_entities): + config = {"model": {"provider": "invalid/provider", "name": "gpt-4", "completion_params": {}}} + self._patch_model_assembly(mocker, provider_entities=provider_entities, model_list=[]) + + with pytest.raises(ValueError, match="model.provider is required"): + ModelConfigManager.validate_and_set_defaults("tenant1", config) + + def test_validate_and_set_defaults_missing_name(self, mocker, provider_entities): + config = {"model": {"provider": "openai/gpt", "completion_params": {}}} + self._patch_model_assembly(mocker, provider_entities=provider_entities, model_list=[]) + + with pytest.raises(ValueError, match="model.name is required"): + ModelConfigManager.validate_and_set_defaults("tenant1", config) + + def test_validate_and_set_defaults_empty_models(self, mocker, provider_entities): + config = {"model": {"provider": "openai/gpt", "name": "gpt-4", "completion_params": {}}} + self._patch_model_assembly(mocker, provider_entities=provider_entities, model_list=[]) + + with pytest.raises(ValueError, match="must be in the specified model list"): + ModelConfigManager.validate_and_set_defaults("tenant1", config) + + def test_validate_and_set_defaults_invalid_model_name(self, mocker, provider_entities, valid_model_list): + config = {"model": {"provider": "openai/gpt", "name": "invalid", "completion_params": {}}} + self._patch_model_assembly( + mocker, + provider_entities=provider_entities, + model_list=valid_model_list, + ) + + with pytest.raises(ValueError, match="must be in the specified model list"): + ModelConfigManager.validate_and_set_defaults("tenant1", config) + + def test_validate_and_set_defaults_default_mode_when_missing(self, mocker, provider_entities): + model = MagicMock() + model.model = "gpt-4" + model.model_properties = {} + + config = {"model": {"provider": "openai/gpt", "name": "gpt-4", "completion_params": {}}} + self._patch_model_assembly(mocker, provider_entities=provider_entities, model_list=[model]) + + updated_config, _ = ModelConfigManager.validate_and_set_defaults("tenant1", config) + + assert updated_config["model"]["mode"] == "completion" + + def test_validate_and_set_defaults_missing_completion_params(self, mocker, provider_entities, valid_model_list): + config = {"model": {"provider": "openai/gpt", "name": "gpt-4"}} + self._patch_model_assembly( + mocker, + provider_entities=provider_entities, + model_list=valid_model_list, + ) + + with pytest.raises(ValueError, match="completion_params is required"): + ModelConfigManager.validate_and_set_defaults("tenant1", config) + + def test_validate_and_set_defaults_provider_without_slash_converted(self, mocker, valid_model_list): + """ + Covers branch where provider does not contain '/' and + ModelProviderID conversion is triggered (line 64). + """ + config = { + "model": { + "provider": "openai", # no slash -> triggers conversion + "name": "gpt-4", + "completion_params": {}, + } + } + + # Mock ModelProviderID to return formatted provider + mock_provider_id = mocker.patch("core.app.app_config.easy_ui_based_app.model_config.manager.ModelProviderID") + mock_provider_id.return_value = "openai/gpt" + provider_entity = MagicMock() + provider_entity.provider = "openai/gpt" + self._patch_model_assembly(mocker, provider_entities=[provider_entity], model_list=valid_model_list) + + updated_config, _ = ModelConfigManager.validate_and_set_defaults("tenant1", config) + + # Ensure conversion happened + mock_provider_id.assert_called_once_with("openai") + assert updated_config["model"]["provider"] == "openai/gpt" diff --git a/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_prompt_template_manager.py b/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_prompt_template_manager.py new file mode 100644 index 0000000000..fd49072cd5 --- /dev/null +++ b/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_prompt_template_manager.py @@ -0,0 +1,292 @@ +from unittest.mock import MagicMock + +import pytest + +from core.app.app_config.easy_ui_based_app.prompt_template.manager import ( + PromptTemplateConfigManager, +) + +# ----------------------------- +# Helpers +# ----------------------------- + + +class DummyEnumValue: + def __init__(self, value): + self.value = value + + +class DummyPromptType: + def __init__(self): + self.SIMPLE = "simple" + self.ADVANCED = "advanced" + + def value_of(self, value): + return value + + def __iter__(self): + return iter([DummyEnumValue("simple"), DummyEnumValue("advanced")]) + + +# ----------------------------- +# Convert Tests +# ----------------------------- + + +class TestPromptTemplateConfigManagerConvert: + def test_convert_missing_prompt_type_raises(self): + with pytest.raises(ValueError, match="prompt_type is required"): + PromptTemplateConfigManager.convert({}) + + def test_convert_simple_prompt(self, mocker): + mock_prompt_entity_cls = MagicMock() + mock_prompt_entity_cls.PromptType = DummyPromptType() + + mocker.patch( + "core.app.app_config.easy_ui_based_app.prompt_template.manager.PromptTemplateEntity", + mock_prompt_entity_cls, + ) + + mock_prompt_entity_cls.return_value = "simple_entity" + + config = {"prompt_type": "simple", "pre_prompt": "hello"} + + result = PromptTemplateConfigManager.convert(config) + + assert result == "simple_entity" + mock_prompt_entity_cls.assert_called_once_with(prompt_type="simple", simple_prompt_template="hello") + + def test_convert_advanced_chat_valid(self, mocker): + mock_prompt_entity_cls = MagicMock() + mock_prompt_entity_cls.PromptType = DummyPromptType() + mock_prompt_entity_cls.return_value = "advanced_entity" + + mocker.patch( + "core.app.app_config.easy_ui_based_app.prompt_template.manager.PromptTemplateEntity", + mock_prompt_entity_cls, + ) + + mocker.patch( + "core.app.app_config.easy_ui_based_app.prompt_template.manager.PromptMessageRole.value_of", + return_value="role_enum", + ) + + mocker.patch( + "core.app.app_config.easy_ui_based_app.prompt_template.manager.AdvancedChatMessageEntity", + return_value="chat_msg", + ) + + mocker.patch( + "core.app.app_config.easy_ui_based_app.prompt_template.manager.AdvancedChatPromptTemplateEntity", + return_value="chat_template", + ) + + config = { + "prompt_type": "advanced", + "chat_prompt_config": {"prompt": [{"text": "hi", "role": "user"}]}, + } + + result = PromptTemplateConfigManager.convert(config) + + assert result == "advanced_entity" + + @pytest.mark.parametrize( + "message", + [ + {"text": 123, "role": "user"}, + {"text": "hi", "role": 123}, + ], + ) + def test_convert_advanced_invalid_message_fields(self, mocker, message): + mock_prompt_entity_cls = MagicMock() + mock_prompt_entity_cls.PromptType = DummyPromptType() + + mocker.patch( + "core.app.app_config.easy_ui_based_app.prompt_template.manager.PromptTemplateEntity", + mock_prompt_entity_cls, + ) + + config = { + "prompt_type": "advanced", + "chat_prompt_config": {"prompt": [message]}, + } + + with pytest.raises(ValueError): + PromptTemplateConfigManager.convert(config) + + def test_convert_advanced_completion_with_roles(self, mocker): + mock_prompt_entity_cls = MagicMock() + mock_prompt_entity_cls.PromptType = DummyPromptType() + mock_prompt_entity_cls.return_value = "advanced_entity" + + mocker.patch( + "core.app.app_config.easy_ui_based_app.prompt_template.manager.PromptTemplateEntity", + mock_prompt_entity_cls, + ) + + mocker.patch( + "core.app.app_config.easy_ui_based_app.prompt_template.manager.AdvancedCompletionPromptTemplateEntity", + return_value="completion_template", + ) + + config = { + "prompt_type": "advanced", + "completion_prompt_config": { + "prompt": {"text": "complete"}, + "conversation_histories_role": { + "user_prefix": "U", + "assistant_prefix": "A", + }, + }, + } + + result = PromptTemplateConfigManager.convert(config) + + assert result == "advanced_entity" + + +# ----------------------------- +# validate_and_set_defaults +# ----------------------------- + + +class TestValidateAndSetDefaults: + def setup_method(self): + self.valid_model = {"mode": "chat"} + + def _patch_prompt_type(self, mocker): + mock_prompt_entity_cls = MagicMock() + mock_prompt_entity_cls.PromptType = DummyPromptType() + mocker.patch( + "core.app.app_config.easy_ui_based_app.prompt_template.manager.PromptTemplateEntity", + mock_prompt_entity_cls, + ) + return mock_prompt_entity_cls + + def test_default_prompt_type_set(self, mocker): + self._patch_prompt_type(mocker) + + config = {"model": self.valid_model} + + result, keys = PromptTemplateConfigManager.validate_and_set_defaults("chat_app", config) + + assert result["prompt_type"] == "simple" + assert isinstance(keys, list) + + def test_invalid_prompt_type_raises(self, mocker): + class InvalidEnum(DummyPromptType): + def __iter__(self): + return iter([DummyEnumValue("valid")]) + + mock_prompt_entity_cls = MagicMock() + mock_prompt_entity_cls.PromptType = InvalidEnum() + + mocker.patch( + "core.app.app_config.easy_ui_based_app.prompt_template.manager.PromptTemplateEntity", + mock_prompt_entity_cls, + ) + + config = {"prompt_type": "invalid", "model": self.valid_model} + + with pytest.raises(ValueError): + PromptTemplateConfigManager.validate_and_set_defaults("chat_app", config) + + def test_invalid_chat_prompt_config_type(self, mocker): + self._patch_prompt_type(mocker) + + config = { + "prompt_type": "simple", + "chat_prompt_config": "invalid", + "model": self.valid_model, + } + + with pytest.raises(ValueError): + PromptTemplateConfigManager.validate_and_set_defaults("chat_app", config) + + def test_simple_mode_invalid_pre_prompt_type(self, mocker): + self._patch_prompt_type(mocker) + + config = { + "prompt_type": "simple", + "pre_prompt": 123, + "model": self.valid_model, + } + + with pytest.raises(ValueError): + PromptTemplateConfigManager.validate_and_set_defaults("chat_app", config) + + def test_advanced_requires_one_config(self, mocker): + self._patch_prompt_type(mocker) + + config = { + "prompt_type": "advanced", + "chat_prompt_config": {}, + "completion_prompt_config": {}, + "model": {"mode": "chat"}, + } + + with pytest.raises(ValueError): + PromptTemplateConfigManager.validate_and_set_defaults("chat_app", config) + + def test_advanced_invalid_model_mode(self, mocker): + self._patch_prompt_type(mocker) + + config = { + "prompt_type": "advanced", + "chat_prompt_config": {"prompt": []}, + "model": {"mode": "invalid"}, + } + + with pytest.raises(ValueError): + PromptTemplateConfigManager.validate_and_set_defaults("chat_app", config) + + def test_advanced_chat_prompt_length_exceeds(self, mocker): + self._patch_prompt_type(mocker) + + config = { + "prompt_type": "advanced", + "chat_prompt_config": {"prompt": [{}] * 11}, + "model": {"mode": "chat"}, + } + + with pytest.raises(ValueError): + PromptTemplateConfigManager.validate_and_set_defaults("chat_app", config) + + def test_completion_prefix_defaults_set_when_empty(self, mocker): + self._patch_prompt_type(mocker) + + config = { + "prompt_type": "advanced", + "completion_prompt_config": { + "prompt": {"text": "hi"}, + "conversation_histories_role": { + "user_prefix": "", + "assistant_prefix": "", + }, + }, + "model": {"mode": "completion"}, + } + + updated, _ = PromptTemplateConfigManager.validate_and_set_defaults("chat", config) + + roles = updated["completion_prompt_config"]["conversation_histories_role"] + assert roles["user_prefix"] == "Human" + assert roles["assistant_prefix"] == "Assistant" + + +# ----------------------------- +# validate_post_prompt +# ----------------------------- + + +class TestValidatePostPrompt: + @pytest.mark.parametrize("value", [None, ""]) + def test_post_prompt_defaults(self, value): + config = {"post_prompt": value} + result = PromptTemplateConfigManager.validate_post_prompt_and_set_defaults(config) + assert result["post_prompt"] == "" + + def test_post_prompt_invalid_type(self): + config = {"post_prompt": 123} + with pytest.raises(ValueError): + PromptTemplateConfigManager.validate_post_prompt_and_set_defaults(config) diff --git a/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_variables_manager.py b/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_variables_manager.py new file mode 100644 index 0000000000..d9fe7004ff --- /dev/null +++ b/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_variables_manager.py @@ -0,0 +1,286 @@ +import pytest + +from core.app.app_config.easy_ui_based_app.variables.manager import ( + BasicVariablesConfigManager, +) +from graphon.variables.input_entities import VariableEntityType + + +class TestBasicVariablesConfigManagerConvert: + def test_convert_empty_config(self): + config = {} + + variables, external = BasicVariablesConfigManager.convert(config) + + assert variables == [] + assert external == [] + + def test_convert_external_data_tools_enabled_and_disabled(self, mocker): + config = { + "external_data_tools": [ + {"enabled": False}, + { + "enabled": True, + "variable": "ext_var", + "type": "tool_type", + "config": {"k": "v"}, + }, + ] + } + + variables, external = BasicVariablesConfigManager.convert(config) + + assert variables == [] + assert len(external) == 1 + assert external[0].variable == "ext_var" + assert external[0].type == "tool_type" + + def test_convert_user_input_form_variable_types(self): + config = { + "user_input_form": [ + { + VariableEntityType.TEXT_INPUT: { + "variable": "name", + "label": "Name", + "description": "desc", + "required": True, + "max_length": 50, + } + }, + { + VariableEntityType.SELECT: { + "variable": "choice", + "label": "Choice", + "options": ["a", "b"], + } + }, + { + VariableEntityType.EXTERNAL_DATA_TOOL: { + "variable": "ext", + "type": "tool", + "config": {"x": 1}, + } + }, + ] + } + + variables, external = BasicVariablesConfigManager.convert(config) + + assert len(variables) == 2 + assert len(external) == 1 + + def test_convert_external_data_tool_without_config_skipped(self): + config = { + "user_input_form": [ + { + VariableEntityType.EXTERNAL_DATA_TOOL: { + "variable": "ext", + "type": "tool", + } + } + ] + } + + variables, external = BasicVariablesConfigManager.convert(config) + + assert variables == [] + assert external == [] + + +class TestValidateVariablesAndSetDefaults: + def test_validate_sets_empty_user_input_form_if_missing(self): + config = {} + + updated, keys = BasicVariablesConfigManager.validate_variables_and_set_defaults(config) + + assert updated["user_input_form"] == [] + assert "user_input_form" in keys + + def test_validate_user_input_form_not_list_raises(self): + config = {"user_input_form": "invalid"} + + with pytest.raises(ValueError): + BasicVariablesConfigManager.validate_variables_and_set_defaults(config) + + def test_validate_invalid_key_raises(self): + config = {"user_input_form": [{"invalid": {}}]} + + with pytest.raises(ValueError): + BasicVariablesConfigManager.validate_variables_and_set_defaults(config) + + def test_validate_missing_label_raises(self): + config = {"user_input_form": [{VariableEntityType.TEXT_INPUT: {"variable": "name"}}]} + + with pytest.raises(ValueError): + BasicVariablesConfigManager.validate_variables_and_set_defaults(config) + + def test_validate_label_not_string_raises(self): + config = {"user_input_form": [{VariableEntityType.TEXT_INPUT: {"variable": "name", "label": 123}}]} + + with pytest.raises(ValueError): + BasicVariablesConfigManager.validate_variables_and_set_defaults(config) + + def test_validate_missing_variable_raises(self): + config = {"user_input_form": [{VariableEntityType.TEXT_INPUT: {"label": "Name"}}]} + + with pytest.raises(ValueError): + BasicVariablesConfigManager.validate_variables_and_set_defaults(config) + + def test_validate_variable_not_string_raises(self): + config = {"user_input_form": [{VariableEntityType.TEXT_INPUT: {"label": "Name", "variable": 123}}]} + + with pytest.raises(ValueError): + BasicVariablesConfigManager.validate_variables_and_set_defaults(config) + + @pytest.mark.parametrize( + "variable_name", + ["1invalid", "invalid space", "", None], + ) + def test_validate_variable_invalid_pattern_raises(self, variable_name): + config = { + "user_input_form": [ + { + VariableEntityType.TEXT_INPUT: { + "label": "Name", + "variable": variable_name, + } + } + ] + } + + with pytest.raises(ValueError): + BasicVariablesConfigManager.validate_variables_and_set_defaults(config) + + def test_validate_required_default_and_type(self): + config = { + "user_input_form": [ + { + VariableEntityType.TEXT_INPUT: { + "label": "Name", + "variable": "valid_name", + } + } + ] + } + + updated, _ = BasicVariablesConfigManager.validate_variables_and_set_defaults(config) + + assert updated["user_input_form"][0][VariableEntityType.TEXT_INPUT]["required"] is False + + def test_validate_required_not_bool_raises(self): + config = { + "user_input_form": [ + { + VariableEntityType.TEXT_INPUT: { + "label": "Name", + "variable": "valid_name", + "required": "yes", + } + } + ] + } + + with pytest.raises(ValueError): + BasicVariablesConfigManager.validate_variables_and_set_defaults(config) + + def test_validate_select_options_default_not_in_options_raises(self): + config = { + "user_input_form": [ + { + VariableEntityType.SELECT: { + "label": "Choice", + "variable": "choice", + "options": ["a", "b"], + "default": "c", + } + } + ] + } + + with pytest.raises(ValueError): + BasicVariablesConfigManager.validate_variables_and_set_defaults(config) + + def test_validate_select_options_not_list_raises(self): + config = { + "user_input_form": [ + { + VariableEntityType.SELECT: { + "label": "Choice", + "variable": "choice", + "options": "not_list", + } + } + ] + } + + with pytest.raises(ValueError): + BasicVariablesConfigManager.validate_variables_and_set_defaults(config) + + +class TestValidateExternalDataToolsAndSetDefaults: + def test_validate_sets_empty_external_data_tools_if_missing(self): + config = {} + + updated, keys = BasicVariablesConfigManager.validate_external_data_tools_and_set_defaults("tenant", config) + + assert updated["external_data_tools"] == [] + assert "external_data_tools" in keys + + def test_validate_external_data_tools_not_list_raises(self): + config = {"external_data_tools": "invalid"} + + with pytest.raises(ValueError): + BasicVariablesConfigManager.validate_external_data_tools_and_set_defaults("tenant", config) + + def test_validate_disabled_tool_skipped(self, mocker): + config = {"external_data_tools": [{"enabled": False}]} + + spy = mocker.patch( + "core.app.app_config.easy_ui_based_app.variables.manager.ExternalDataToolFactory.validate_config" + ) + + updated, _ = BasicVariablesConfigManager.validate_external_data_tools_and_set_defaults("tenant", config) + + spy.assert_not_called() + assert updated["external_data_tools"][0]["enabled"] is False + + def test_validate_enabled_tool_missing_type_raises(self): + config = {"external_data_tools": [{"enabled": True, "config": {}}]} + + with pytest.raises(ValueError): + BasicVariablesConfigManager.validate_external_data_tools_and_set_defaults("tenant", config) + + def test_validate_enabled_tool_calls_factory(self, mocker): + config = {"external_data_tools": [{"enabled": True, "type": "tool", "config": {"a": 1}}]} + + spy = mocker.patch( + "core.app.app_config.easy_ui_based_app.variables.manager.ExternalDataToolFactory.validate_config" + ) + + BasicVariablesConfigManager.validate_external_data_tools_and_set_defaults("tenant_id", config) + + spy.assert_called_once_with(name="tool", tenant_id="tenant_id", config={"a": 1}) + + +class TestValidateAndSetDefaultsIntegration: + def test_validate_and_set_defaults_calls_both(self, mocker): + config = {} + + spy_var = mocker.patch.object( + BasicVariablesConfigManager, + "validate_variables_and_set_defaults", + return_value=(config, ["user_input_form"]), + ) + spy_ext = mocker.patch.object( + BasicVariablesConfigManager, + "validate_external_data_tools_and_set_defaults", + return_value=(config, ["external_data_tools"]), + ) + + updated, keys = BasicVariablesConfigManager.validate_and_set_defaults("tenant", config) + + spy_var.assert_called_once() + spy_ext.assert_called_once() + assert "user_input_form" in keys + assert "external_data_tools" in keys + assert updated == config diff --git a/api/tests/unit_tests/core/app/app_config/features/__init__.py b/api/tests/unit_tests/core/app/app_config/features/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/core/app/app_config/features/file_upload/test_manager.py b/api/tests/unit_tests/core/app/app_config/features/file_upload/test_manager.py index de99833aac..11fc15c94d 100644 --- a/api/tests/unit_tests/core/app/app_config/features/file_upload/test_manager.py +++ b/api/tests/unit_tests/core/app/app_config/features/file_upload/test_manager.py @@ -1,6 +1,6 @@ from core.app.app_config.features.file_upload.manager import FileUploadConfigManager -from dify_graph.file.models import FileTransferMethod, FileUploadConfig, ImageConfig -from dify_graph.model_runtime.entities.message_entities import ImagePromptMessageContent +from graphon.file.models import FileTransferMethod, FileUploadConfig, ImageConfig +from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent def test_convert_with_vision(): diff --git a/api/tests/unit_tests/core/app/app_config/features/test_additional_feature_managers.py b/api/tests/unit_tests/core/app/app_config/features/test_additional_feature_managers.py new file mode 100644 index 0000000000..dd00c3defc --- /dev/null +++ b/api/tests/unit_tests/core/app/app_config/features/test_additional_feature_managers.py @@ -0,0 +1,115 @@ +import pytest + +from core.app.app_config.entities import TextToSpeechEntity +from core.app.app_config.features.more_like_this.manager import MoreLikeThisConfigManager +from core.app.app_config.features.opening_statement.manager import OpeningStatementConfigManager +from core.app.app_config.features.retrieval_resource.manager import RetrievalResourceConfigManager +from core.app.app_config.features.speech_to_text.manager import SpeechToTextConfigManager +from core.app.app_config.features.suggested_questions_after_answer.manager import ( + SuggestedQuestionsAfterAnswerConfigManager, +) +from core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager + + +class TestAdditionalFeatureManagers: + def test_opening_statement_validate_defaults(self): + config, keys = OpeningStatementConfigManager.validate_and_set_defaults({}) + assert config["opening_statement"] == "" + assert config["suggested_questions"] == [] + assert set(keys) == {"opening_statement", "suggested_questions"} + + def test_opening_statement_validate_types(self): + with pytest.raises(ValueError): + OpeningStatementConfigManager.validate_and_set_defaults({"opening_statement": 123}) + with pytest.raises(ValueError): + OpeningStatementConfigManager.validate_and_set_defaults( + {"opening_statement": "hi", "suggested_questions": "bad"} + ) + with pytest.raises(ValueError): + OpeningStatementConfigManager.validate_and_set_defaults( + {"opening_statement": "hi", "suggested_questions": [1]} + ) + + def test_opening_statement_convert(self): + opening, questions = OpeningStatementConfigManager.convert( + {"opening_statement": "hello", "suggested_questions": ["q1"]} + ) + assert opening == "hello" + assert questions == ["q1"] + + def test_retrieval_resource_validate(self): + config, keys = RetrievalResourceConfigManager.validate_and_set_defaults({}) + assert config["retriever_resource"]["enabled"] is False + assert keys == ["retriever_resource"] + + with pytest.raises(ValueError): + RetrievalResourceConfigManager.validate_and_set_defaults({"retriever_resource": "bad"}) + with pytest.raises(ValueError): + RetrievalResourceConfigManager.validate_and_set_defaults({"retriever_resource": {"enabled": "yes"}}) + + def test_retrieval_resource_convert(self): + assert RetrievalResourceConfigManager.convert({"retriever_resource": {"enabled": True}}) is True + assert RetrievalResourceConfigManager.convert({"retriever_resource": {"enabled": False}}) is False + + def test_speech_to_text_validate_and_convert(self): + config, keys = SpeechToTextConfigManager.validate_and_set_defaults({}) + assert config["speech_to_text"]["enabled"] is False + assert keys == ["speech_to_text"] + + with pytest.raises(ValueError): + SpeechToTextConfigManager.validate_and_set_defaults({"speech_to_text": "bad"}) + with pytest.raises(ValueError): + SpeechToTextConfigManager.validate_and_set_defaults({"speech_to_text": {"enabled": "yes"}}) + + assert SpeechToTextConfigManager.convert({"speech_to_text": {"enabled": True}}) is True + assert SpeechToTextConfigManager.convert({"speech_to_text": {"enabled": False}}) is False + + def test_suggested_questions_after_answer_validate_and_convert(self): + config, keys = SuggestedQuestionsAfterAnswerConfigManager.validate_and_set_defaults({}) + assert config["suggested_questions_after_answer"]["enabled"] is False + assert keys == ["suggested_questions_after_answer"] + + with pytest.raises(ValueError): + SuggestedQuestionsAfterAnswerConfigManager.validate_and_set_defaults( + {"suggested_questions_after_answer": "bad"} + ) + with pytest.raises(ValueError): + SuggestedQuestionsAfterAnswerConfigManager.validate_and_set_defaults( + {"suggested_questions_after_answer": {"enabled": "yes"}} + ) + + assert ( + SuggestedQuestionsAfterAnswerConfigManager.convert({"suggested_questions_after_answer": {"enabled": True}}) + is True + ) + assert ( + SuggestedQuestionsAfterAnswerConfigManager.convert({"suggested_questions_after_answer": {"enabled": False}}) + is False + ) + + def test_text_to_speech_validate_and_convert(self): + config, keys = TextToSpeechConfigManager.validate_and_set_defaults({}) + assert config["text_to_speech"]["enabled"] is False + assert keys == ["text_to_speech"] + + with pytest.raises(ValueError): + TextToSpeechConfigManager.validate_and_set_defaults({"text_to_speech": "bad"}) + with pytest.raises(ValueError): + TextToSpeechConfigManager.validate_and_set_defaults({"text_to_speech": {"enabled": "yes"}}) + + result = TextToSpeechConfigManager.convert( + {"text_to_speech": {"enabled": True, "voice": "v", "language": "en"}} + ) + assert isinstance(result, TextToSpeechEntity) + assert result.voice == "v" + assert result.language == "en" + + def test_more_like_this_convert_and_validate(self): + config, keys = MoreLikeThisConfigManager.validate_and_set_defaults({}) + assert config["more_like_this"]["enabled"] is False + assert keys == ["more_like_this"] + + assert MoreLikeThisConfigManager.convert({"more_like_this": {"enabled": True}}) is True + assert MoreLikeThisConfigManager.convert({"more_like_this": {"enabled": False}}) is False + with pytest.raises(ValueError): + MoreLikeThisConfigManager.validate_and_set_defaults({"more_like_this": "bad"}) diff --git a/api/tests/unit_tests/core/app/app_config/test_base_app_config_manager.py b/api/tests/unit_tests/core/app/app_config/test_base_app_config_manager.py new file mode 100644 index 0000000000..e99852cf76 --- /dev/null +++ b/api/tests/unit_tests/core/app/app_config/test_base_app_config_manager.py @@ -0,0 +1,180 @@ +from collections import UserDict +from unittest.mock import MagicMock + +import pytest + +from core.app.app_config.base_app_config_manager import BaseAppConfigManager + + +class TestBaseAppConfigManager: + @pytest.fixture + def mock_config_dict(self): + return {"key": "value", "another": 123} + + @pytest.fixture + def mock_app_additional_features(self, mocker): + mock_instance = MagicMock() + mocker.patch( + "core.app.app_config.base_app_config_manager.AppAdditionalFeatures", + return_value=mock_instance, + ) + return mock_instance + + @pytest.fixture + def mock_managers(self, mocker): + retrieval = mocker.patch( + "core.app.app_config.base_app_config_manager.RetrievalResourceConfigManager.convert", + return_value="retrieval_result", + ) + file_upload = mocker.patch( + "core.app.app_config.base_app_config_manager.FileUploadConfigManager.convert", + return_value="file_upload_result", + ) + opening_statement = mocker.patch( + "core.app.app_config.base_app_config_manager.OpeningStatementConfigManager.convert", + return_value=("opening_result", "suggested_result"), + ) + suggested_after = mocker.patch( + "core.app.app_config.base_app_config_manager.SuggestedQuestionsAfterAnswerConfigManager.convert", + return_value="suggested_after_result", + ) + more_like_this = mocker.patch( + "core.app.app_config.base_app_config_manager.MoreLikeThisConfigManager.convert", + return_value="more_like_this_result", + ) + speech_to_text = mocker.patch( + "core.app.app_config.base_app_config_manager.SpeechToTextConfigManager.convert", + return_value="speech_to_text_result", + ) + text_to_speech = mocker.patch( + "core.app.app_config.base_app_config_manager.TextToSpeechConfigManager.convert", + return_value="text_to_speech_result", + ) + + return { + "retrieval": retrieval, + "file_upload": file_upload, + "opening_statement": opening_statement, + "suggested_after": suggested_after, + "more_like_this": more_like_this, + "speech_to_text": speech_to_text, + "text_to_speech": text_to_speech, + } + + @pytest.mark.parametrize( + ("app_mode", "expected_is_vision"), + [ + ("CHAT", True), + ("COMPLETION", True), + ("AGENT_CHAT", True), + ("OTHER", False), + ], + ) + def test_convert_features_all_modes( + self, + mocker, + mock_config_dict, + mock_app_additional_features, + mock_managers, + app_mode, + expected_is_vision, + ): + # Arrange + mock_app_mode = MagicMock() + mock_app_mode.CHAT = "CHAT" + mock_app_mode.COMPLETION = "COMPLETION" + mock_app_mode.AGENT_CHAT = "AGENT_CHAT" + + mocker.patch( + "core.app.app_config.base_app_config_manager.AppMode", + mock_app_mode, + ) + + # Act + result = BaseAppConfigManager.convert_features(mock_config_dict, app_mode) + + # Assert + assert result == mock_app_additional_features + mock_managers["retrieval"].assert_called_once_with(config=dict(mock_config_dict.items())) + mock_managers["file_upload"].assert_called_once() + _, kwargs = mock_managers["file_upload"].call_args + assert kwargs["config"] == dict(mock_config_dict.items()) + assert kwargs["is_vision"] is expected_is_vision + + mock_managers["opening_statement"].assert_called_once_with(config=dict(mock_config_dict.items())) + mock_managers["suggested_after"].assert_called_once_with(config=dict(mock_config_dict.items())) + mock_managers["more_like_this"].assert_called_once_with(config=dict(mock_config_dict.items())) + mock_managers["speech_to_text"].assert_called_once_with(config=dict(mock_config_dict.items())) + mock_managers["text_to_speech"].assert_called_once_with(config=dict(mock_config_dict.items())) + + def test_convert_features_empty_config(self, mocker, mock_app_additional_features, mock_managers): + # Arrange + empty_config = {} + mock_app_mode = MagicMock() + mock_app_mode.CHAT = "CHAT" + mock_app_mode.COMPLETION = "COMPLETION" + mock_app_mode.AGENT_CHAT = "AGENT_CHAT" + + mocker.patch( + "core.app.app_config.base_app_config_manager.AppMode", + mock_app_mode, + ) + + # Act + result = BaseAppConfigManager.convert_features(empty_config, "CHAT") + + # Assert + assert result == mock_app_additional_features + for manager in mock_managers.values(): + assert manager.called + + @pytest.mark.parametrize( + "invalid_config", + [ + None, + "string", + 123, + 12.34, + [], + ], + ) + def test_convert_features_invalid_config_raises(self, invalid_config): + # Act & Assert + with pytest.raises((TypeError, AttributeError)): + BaseAppConfigManager.convert_features(invalid_config, "CHAT") + + def test_convert_features_manager_exception_propagates(self, mocker, mock_config_dict): + # Arrange + mocker.patch( + "core.app.app_config.base_app_config_manager.RetrievalResourceConfigManager.convert", + side_effect=RuntimeError("manager failure"), + ) + + # Act & Assert + with pytest.raises(RuntimeError): + BaseAppConfigManager.convert_features(mock_config_dict, "CHAT") + + def test_convert_features_mapping_subclass(self, mocker, mock_app_additional_features, mock_managers): + # Arrange + class CustomMapping(UserDict): + pass + + custom_config = CustomMapping({"a": 1}) + + mock_app_mode = MagicMock() + mock_app_mode.CHAT = "CHAT" + mock_app_mode.COMPLETION = "COMPLETION" + mock_app_mode.AGENT_CHAT = "AGENT_CHAT" + + mocker.patch( + "core.app.app_config.base_app_config_manager.AppMode", + mock_app_mode, + ) + + # Act + result = BaseAppConfigManager.convert_features(custom_config, "CHAT") + + # Assert + assert result == mock_app_additional_features + for manager in mock_managers.values(): + assert manager.called diff --git a/api/tests/unit_tests/core/app/app_config/test_entities.py b/api/tests/unit_tests/core/app/app_config/test_entities.py new file mode 100644 index 0000000000..f2bc3076da --- /dev/null +++ b/api/tests/unit_tests/core/app/app_config/test_entities.py @@ -0,0 +1,43 @@ +import pytest + +from core.app.app_config.entities import ( + DatasetRetrieveConfigEntity, + PromptTemplateEntity, +) +from graphon.variables.input_entities import VariableEntity, VariableEntityType + + +class TestAppConfigEntities: + def test_variable_entity_coerces_none_description_and_options(self): + entity = VariableEntity( + variable="query", + label="Query", + description=None, + type=VariableEntityType.TEXT_INPUT, + options=None, + ) + + assert entity.description == "" + assert entity.options == [] + + def test_variable_entity_rejects_invalid_json_schema(self): + with pytest.raises(ValueError): + VariableEntity( + variable="query", + label="Query", + type=VariableEntityType.TEXT_INPUT, + json_schema={"type": "string", "minLength": "bad"}, + ) + + def test_prompt_template_value_of(self): + assert PromptTemplateEntity.PromptType.value_of("simple") == PromptTemplateEntity.PromptType.SIMPLE + with pytest.raises(ValueError): + PromptTemplateEntity.PromptType.value_of("missing") + + def test_dataset_retrieve_strategy_value_of(self): + assert ( + DatasetRetrieveConfigEntity.RetrieveStrategy.value_of("single") + == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE + ) + with pytest.raises(ValueError): + DatasetRetrieveConfigEntity.RetrieveStrategy.value_of("missing") diff --git a/api/tests/unit_tests/core/app/app_config/workflow_ui_based_app/test_workflow_ui_based_app_manager.py b/api/tests/unit_tests/core/app/app_config/workflow_ui_based_app/test_workflow_ui_based_app_manager.py new file mode 100644 index 0000000000..fa128aca87 --- /dev/null +++ b/api/tests/unit_tests/core/app/app_config/workflow_ui_based_app/test_workflow_ui_based_app_manager.py @@ -0,0 +1,222 @@ +import pytest + +from core.app.app_config.workflow_ui_based_app.variables.manager import ( + WorkflowVariablesConfigManager, +) + +# ============================= +# Fixtures +# ============================= + + +@pytest.fixture +def mock_workflow(mocker): + workflow = mocker.MagicMock() + workflow.graph_dict = {"nodes": []} + return workflow + + +@pytest.fixture +def mock_variable_entity(mocker): + return mocker.patch("core.app.app_config.workflow_ui_based_app.variables.manager.VariableEntity") + + +@pytest.fixture +def mock_rag_entity(mocker): + return mocker.patch("core.app.app_config.workflow_ui_based_app.variables.manager.RagPipelineVariableEntity") + + +# ============================= +# Test Convert (user_input_form) +# ============================= + + +class TestWorkflowVariablesConfigManagerConvert: + def test_convert_success_multiple_variables(self, mock_workflow, mock_variable_entity): + # Arrange + input_variables = [{"name": "var1"}, {"name": "var2"}] + mock_workflow.user_input_form.return_value = input_variables + mock_variable_entity.model_validate.side_effect = lambda x: {"validated": x} + + # Act + result = WorkflowVariablesConfigManager.convert(mock_workflow) + + # Assert + assert result == [{"validated": v} for v in input_variables] + assert mock_variable_entity.model_validate.call_count == 2 + + def test_convert_empty_list(self, mock_workflow, mock_variable_entity): + # Arrange + mock_workflow.user_input_form.return_value = [] + + # Act + result = WorkflowVariablesConfigManager.convert(mock_workflow) + + # Assert + assert result == [] + mock_variable_entity.model_validate.assert_not_called() + + def test_convert_none_returned_raises(self, mock_workflow): + # Arrange + mock_workflow.user_input_form.return_value = None + + # Act & Assert + with pytest.raises(TypeError): + WorkflowVariablesConfigManager.convert(mock_workflow) + + def test_convert_validation_error_propagates(self, mock_workflow, mock_variable_entity): + # Arrange + mock_workflow.user_input_form.return_value = [{"invalid": "data"}] + mock_variable_entity.model_validate.side_effect = ValueError("validation error") + + # Act & Assert + with pytest.raises(ValueError): + WorkflowVariablesConfigManager.convert(mock_workflow) + + +# ============================= +# Test convert_rag_pipeline_variable +# ============================= + + +class TestWorkflowVariablesConfigManagerConvertRag: + def test_no_rag_pipeline_variables(self, mock_workflow): + # Arrange + mock_workflow.rag_pipeline_variables = [] + + # Act + result = WorkflowVariablesConfigManager.convert_rag_pipeline_variable(mock_workflow, "node1") + + # Assert + assert result == [] + + def test_rag_pipeline_none(self, mock_workflow): + # Arrange + mock_workflow.rag_pipeline_variables = None + + # Act + result = WorkflowVariablesConfigManager.convert_rag_pipeline_variable(mock_workflow, "node1") + + # Assert + assert result == [] + + def test_no_matching_node_keeps_all(self, mock_workflow, mock_rag_entity): + # Arrange + mock_workflow.rag_pipeline_variables = [ + {"variable": "var1", "belong_to_node_id": "node1"}, + ] + mock_workflow.graph_dict = {"nodes": []} + mock_rag_entity.model_validate.side_effect = lambda x: {"validated": x} + + # Act + result = WorkflowVariablesConfigManager.convert_rag_pipeline_variable(mock_workflow, "node1") + + # Assert + assert result == [{"validated": mock_workflow.rag_pipeline_variables[0]}] + + def test_string_pattern_removes_variable(self, mock_workflow, mock_rag_entity): + # Arrange + mock_workflow.rag_pipeline_variables = [ + {"variable": "var1", "belong_to_node_id": "node1"}, + {"variable": "var2", "belong_to_node_id": "node1"}, + ] + + mock_workflow.graph_dict = { + "nodes": [ + { + "id": "node1", + "data": {"datasource_parameters": {"param1": {"value": "{{#parent.var1#}}"}}}, + } + ] + } + + mock_rag_entity.model_validate.side_effect = lambda x: {"validated": x} + + # Act + result = WorkflowVariablesConfigManager.convert_rag_pipeline_variable(mock_workflow, "node1") + + # Assert + assert len(result) == 1 + assert result[0]["validated"]["variable"] == "var2" + + def test_list_value_removes_variable(self, mock_workflow, mock_rag_entity): + # Arrange + mock_workflow.rag_pipeline_variables = [ + {"variable": "var1", "belong_to_node_id": "node1"}, + {"variable": "var2", "belong_to_node_id": "node1"}, + ] + + mock_workflow.graph_dict = { + "nodes": [ + { + "id": "node1", + "data": {"datasource_parameters": {"param1": {"value": ["x", "var1"]}}}, + } + ] + } + + mock_rag_entity.model_validate.side_effect = lambda x: {"validated": x} + + # Act + result = WorkflowVariablesConfigManager.convert_rag_pipeline_variable(mock_workflow, "node1") + + # Assert + assert len(result) == 1 + assert result[0]["validated"]["variable"] == "var2" + + @pytest.mark.parametrize( + ("belong_to_node_id", "expected_count"), + [ + ("node1", 1), + ("shared", 1), + ("other_node", 0), + ], + ) + def test_belong_to_node_filtering(self, mock_workflow, mock_rag_entity, belong_to_node_id, expected_count): + # Arrange + mock_workflow.rag_pipeline_variables = [ + {"variable": "var1", "belong_to_node_id": belong_to_node_id}, + ] + mock_workflow.graph_dict = {"nodes": []} + mock_rag_entity.model_validate.side_effect = lambda x: {"validated": x} + + # Act + result = WorkflowVariablesConfigManager.convert_rag_pipeline_variable(mock_workflow, "node1") + + # Assert + assert len(result) == expected_count + + def test_invalid_pattern_does_not_remove(self, mock_workflow, mock_rag_entity): + # Arrange + mock_workflow.rag_pipeline_variables = [ + {"variable": "var1", "belong_to_node_id": "node1"}, + ] + + mock_workflow.graph_dict = { + "nodes": [ + { + "id": "node1", + "data": {"datasource_parameters": {"param1": {"value": "invalid_pattern"}}}, + } + ] + } + + mock_rag_entity.model_validate.side_effect = lambda x: {"validated": x} + + # Act + result = WorkflowVariablesConfigManager.convert_rag_pipeline_variable(mock_workflow, "node1") + + # Assert + assert len(result) == 1 + + def test_validation_error_propagates(self, mock_workflow, mock_rag_entity): + # Arrange + mock_workflow.rag_pipeline_variables = [ + {"variable": "var1", "belong_to_node_id": "node1"}, + ] + mock_workflow.graph_dict = {"nodes": []} + mock_rag_entity.model_validate.side_effect = RuntimeError("validation failed") + + # Act & Assert + with pytest.raises(RuntimeError): + WorkflowVariablesConfigManager.convert_rag_pipeline_variable(mock_workflow, "node1") diff --git a/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_generator.py b/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_generator.py index 441d2fcd17..8b0ff7b6c1 100644 --- a/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_generator.py +++ b/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_generator.py @@ -1053,7 +1053,7 @@ class TestAdvancedChatAppGeneratorInternals: _ = kwargs def run(self): - from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError + from graphon.model_runtime.errors.invoke import InvokeAuthorizationError raise InvokeAuthorizationError("bad key") diff --git a/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_conversation_variables.py b/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_conversation_variables.py index 15aceef2c7..ef7df5e1da 100644 --- a/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_conversation_variables.py +++ b/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_conversation_variables.py @@ -7,10 +7,23 @@ from sqlalchemy.orm import Session from core.app.apps.advanced_chat.app_runner import AdvancedChatAppRunner from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom -from dify_graph.variables import SegmentType from factories import variable_factory +from graphon.variables import SegmentType from models import ConversationVariable, Workflow +MINIMAL_GRAPH = { + "nodes": [ + { + "id": "start", + "data": { + "type": "start", + "title": "Start", + }, + } + ], + "edges": [], +} + class TestAdvancedChatAppRunnerConversationVariables: """Test that AdvancedChatAppRunner correctly handles conversation variables.""" @@ -49,7 +62,7 @@ class TestAdvancedChatAppRunnerConversationVariables: mock_workflow.app_id = app_id mock_workflow.id = workflow_id mock_workflow.type = "chat" - mock_workflow.graph_dict = {} + mock_workflow.graph_dict = MINIMAL_GRAPH mock_workflow.environment_variables = [] # Create existing conversation variable (only var1 exists in DB) @@ -200,7 +213,7 @@ class TestAdvancedChatAppRunnerConversationVariables: mock_workflow.app_id = app_id mock_workflow.id = workflow_id mock_workflow.type = "chat" - mock_workflow.graph_dict = {} + mock_workflow.graph_dict = MINIMAL_GRAPH mock_workflow.environment_variables = [] # Mock conversation and message @@ -349,7 +362,7 @@ class TestAdvancedChatAppRunnerConversationVariables: mock_workflow.app_id = app_id mock_workflow.id = workflow_id mock_workflow.type = "chat" - mock_workflow.graph_dict = {} + mock_workflow.graph_dict = MINIMAL_GRAPH mock_workflow.environment_variables = [] # Create existing conversation variables (both exist in DB) diff --git a/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_input_moderation.py b/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_input_moderation.py index 5792a2f1e2..079df0b4e6 100644 --- a/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_input_moderation.py +++ b/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_input_moderation.py @@ -8,6 +8,19 @@ from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, from core.app.entities.queue_entities import QueueStopEvent from core.moderation.base import ModerationError +MINIMAL_GRAPH = { + "nodes": [ + { + "id": "start", + "data": { + "type": "start", + "title": "Start", + }, + } + ], + "edges": [], +} + @pytest.fixture def build_runner(): @@ -30,7 +43,7 @@ def build_runner(): mock_workflow.tenant_id = str(uuid4()) mock_workflow.app_id = app_id mock_workflow.type = "chat" - mock_workflow.graph_dict = {} + mock_workflow.graph_dict = MINIMAL_GRAPH mock_workflow.environment_variables = [] mock_app_config = MagicMock() diff --git a/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_response_converter.py b/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_response_converter.py index 5b199e0c52..f2df35d7d0 100644 --- a/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_response_converter.py +++ b/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_response_converter.py @@ -10,7 +10,7 @@ from core.app.entities.task_entities import ( NodeStartStreamResponse, PingStreamResponse, ) -from dify_graph.enums import WorkflowNodeExecutionStatus +from graphon.enums import WorkflowNodeExecutionStatus class TestAdvancedChatGenerateResponseConverter: diff --git a/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline.py b/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline.py index 83a6e0f231..56919d7f65 100644 --- a/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline.py +++ b/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline.py @@ -17,8 +17,8 @@ from core.app.entities.queue_entities import ( QueueWorkflowSucceededEvent, ) from core.app.entities.task_entities import StreamEvent -from dify_graph.entities.pause_reason import HumanInputRequired -from dify_graph.enums import WorkflowExecutionStatus +from graphon.entities.pause_reason import HumanInputRequired +from graphon.enums import WorkflowExecutionStatus from models.enums import MessageStatus from models.execution_extra_content import HumanInputContent from models.model import EndUser @@ -137,7 +137,6 @@ def test_handle_workflow_paused_event_persists_human_input_extra_content() -> No actions=[], node_id="node-1", node_title="Approval", - form_token="token-1", resolved_default_values={}, ) event = QueueWorkflowPausedEvent(reasons=[reason], outputs={}, paused_nodes=["node-1"]) diff --git a/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline_core.py b/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline_core.py index 0a244b3fea..c78844d173 100644 --- a/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline_core.py +++ b/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline_core.py @@ -42,11 +42,12 @@ from core.app.entities.task_entities import ( PingStreamResponse, ) from core.base.tts.app_generator_tts_publisher import AudioTrunk -from dify_graph.enums import BuiltinNodeTypes -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable +from core.workflow.system_variables import build_system_variables +from graphon.enums import BuiltinNodeTypes +from graphon.runtime import GraphRuntimeState, VariablePool from models.enums import MessageStatus from models.model import AppMode, EndUser +from tests.workflow_test_utils import build_test_variable_pool def _make_pipeline(): @@ -166,7 +167,7 @@ class TestAdvancedChatGenerateTaskPipeline: def test_handle_workflow_started_event_sets_run_id(self, monkeypatch): pipeline = _make_pipeline() pipeline._graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool(system_variables=SystemVariable(workflow_execution_id="run-id")), + variable_pool=build_test_variable_pool(variables=build_system_variables(workflow_execution_id="run-id")), start_at=0.0, ) pipeline._workflow_response_converter.workflow_start_to_stream_response = lambda **kwargs: "started" @@ -311,7 +312,7 @@ class TestAdvancedChatGenerateTaskPipeline: pipeline = _make_pipeline() pipeline._workflow_run_id = "run-id" pipeline._graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool(system_variables=SystemVariable(workflow_execution_id="run-id")), + variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-id")), start_at=0.0, ) pipeline._workflow_response_converter.workflow_finish_to_stream_response = lambda **kwargs: "finish" @@ -522,7 +523,7 @@ class TestAdvancedChatGenerateTaskPipeline: self.items = items graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool(system_variables=SystemVariable(workflow_execution_id="run-id")), + variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-id")), start_at=0.0, ) @@ -556,7 +557,7 @@ class TestAdvancedChatGenerateTaskPipeline: def test_handle_message_end_event_applies_output_moderation(self, monkeypatch): pipeline = _make_pipeline() pipeline._graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool(system_variables=SystemVariable(workflow_execution_id="run-id")), + variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-id")), start_at=0.0, ) pipeline._base_task_pipeline.handle_output_moderation_when_task_finished = lambda answer: "safe" diff --git a/api/tests/unit_tests/core/app/apps/agent_chat/test_agent_chat_app_generator.py b/api/tests/unit_tests/core/app/apps/agent_chat/test_agent_chat_app_generator.py index 53f26d1592..80f7f94b1a 100644 --- a/api/tests/unit_tests/core/app/apps/agent_chat/test_agent_chat_app_generator.py +++ b/api/tests/unit_tests/core/app/apps/agent_chat/test_agent_chat_app_generator.py @@ -6,7 +6,7 @@ from pydantic import ValidationError from core.app.apps.agent_chat.app_generator import AgentChatAppGenerator from core.app.apps.exc import GenerateTaskStoppedError from core.app.entities.app_invoke_entities import InvokeFrom -from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError class DummyAccount: diff --git a/api/tests/unit_tests/core/app/apps/agent_chat/test_agent_chat_app_runner.py b/api/tests/unit_tests/core/app/apps/agent_chat/test_agent_chat_app_runner.py index 5603115b30..4567b35480 100644 --- a/api/tests/unit_tests/core/app/apps/agent_chat/test_agent_chat_app_runner.py +++ b/api/tests/unit_tests/core/app/apps/agent_chat/test_agent_chat_app_runner.py @@ -3,8 +3,8 @@ import pytest from core.agent.entities import AgentEntity from core.app.apps.agent_chat.app_runner import AgentChatAppRunner from core.moderation.base import ModerationError -from dify_graph.model_runtime.entities.llm_entities import LLMMode -from dify_graph.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey +from graphon.model_runtime.entities.llm_entities import LLMMode +from graphon.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey @pytest.fixture diff --git a/api/tests/unit_tests/core/app/apps/chat/test_app_generator_and_runner.py b/api/tests/unit_tests/core/app/apps/chat/test_app_generator_and_runner.py index 3cdffbb4cd..8f3c41701b 100644 --- a/api/tests/unit_tests/core/app/apps/chat/test_app_generator_and_runner.py +++ b/api/tests/unit_tests/core/app/apps/chat/test_app_generator_and_runner.py @@ -9,7 +9,7 @@ from core.app.apps.exc import GenerateTaskStoppedError from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.queue_entities import QueueAnnotationReplyEvent from core.moderation.base import ModerationError -from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError from models.model import AppMode diff --git a/api/tests/unit_tests/core/app/apps/chat/test_base_app_runner_multimodal.py b/api/tests/unit_tests/core/app/apps/chat/test_base_app_runner_multimodal.py index 67b3777c40..f56ca8de99 100644 --- a/api/tests/unit_tests/core/app/apps/chat/test_base_app_runner_multimodal.py +++ b/api/tests/unit_tests/core/app/apps/chat/test_base_app_runner_multimodal.py @@ -9,8 +9,8 @@ from core.app.apps.base_app_queue_manager import PublishFrom from core.app.apps.base_app_runner import AppRunner from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.queue_entities import QueueMessageFileEvent -from dify_graph.file.enums import FileTransferMethod, FileType -from dify_graph.model_runtime.entities.message_entities import ImagePromptMessageContent +from graphon.file.enums import FileTransferMethod, FileType +from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent from models.enums import CreatorUserRole diff --git a/api/tests/unit_tests/core/app/apps/common/test_graph_runtime_state_support.py b/api/tests/unit_tests/core/app/apps/common/test_graph_runtime_state_support.py index b0789bbc1e..d6f7a05cdc 100644 --- a/api/tests/unit_tests/core/app/apps/common/test_graph_runtime_state_support.py +++ b/api/tests/unit_tests/core/app/apps/common/test_graph_runtime_state_support.py @@ -3,13 +3,15 @@ from types import SimpleNamespace import pytest from core.app.apps.common.graph_runtime_state_support import GraphRuntimeStateSupport -from dify_graph.runtime import GraphRuntimeState -from dify_graph.runtime.variable_pool import VariablePool -from dify_graph.system_variable import SystemVariable +from core.workflow.system_variables import build_system_variables +from core.workflow.variable_pool_initializer import add_variables_to_pool +from graphon.runtime import GraphRuntimeState +from graphon.runtime.variable_pool import VariablePool def _make_state(workflow_run_id: str | None) -> GraphRuntimeState: - variable_pool = VariablePool(system_variables=SystemVariable(workflow_execution_id=workflow_run_id)) + variable_pool = VariablePool() + add_variables_to_pool(variable_pool, build_system_variables(workflow_execution_id=workflow_run_id)) return GraphRuntimeState(variable_pool=variable_pool, start_at=0.0) diff --git a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter.py b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter.py index 72430a3347..3ab63aed25 100644 --- a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter.py +++ b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter.py @@ -1,8 +1,8 @@ from collections.abc import Mapping, Sequence from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter -from dify_graph.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType -from dify_graph.variables.segments import ArrayFileSegment, FileSegment +from graphon.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType +from graphon.variables.segments import ArrayFileSegment, FileSegment class TestWorkflowResponseConverterFetchFilesFromVariableValue: @@ -12,7 +12,6 @@ class TestWorkflowResponseConverterFetchFilesFromVariableValue: """Create a test File object""" return File( id=file_id, - tenant_id="test_tenant", type=FileType.DOCUMENT, transfer_method=FileTransferMethod.LOCAL_FILE, related_id="related_123", @@ -223,7 +222,7 @@ class TestWorkflowResponseConverterFetchFilesFromVariableValue: assert len(result) == 1 file_dict = result[0] assert file_dict["id"] == "property_test" - assert file_dict["tenant_id"] == "test_tenant" + assert "tenant_id" not in file_dict assert file_dict["type"] == "document" assert file_dict["transfer_method"] == "local_file" assert file_dict["filename"] == "property_test.txt" diff --git a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_human_input.py b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_human_input.py index 4ed7d73cd0..e8946281ac 100644 --- a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_human_input.py +++ b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_human_input.py @@ -4,13 +4,13 @@ from types import SimpleNamespace from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.queue_entities import QueueHumanInputFormFilledEvent, QueueHumanInputFormTimeoutEvent -from dify_graph.entities.workflow_start_reason import WorkflowStartReason -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable +from core.workflow.system_variables import build_system_variables +from graphon.entities.workflow_start_reason import WorkflowStartReason +from graphon.runtime import GraphRuntimeState, VariablePool def _build_converter(): - system_variables = SystemVariable( + system_variables = build_system_variables( files=[], user_id="user-1", app_id="app-1", diff --git a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_resumption.py b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_resumption.py index 5879e8fb9b..492e11ee0f 100644 --- a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_resumption.py +++ b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_resumption.py @@ -2,14 +2,14 @@ from types import SimpleNamespace from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter from core.app.entities.app_invoke_entities import InvokeFrom -from dify_graph.entities.workflow_start_reason import WorkflowStartReason -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable +from core.workflow.system_variables import build_system_variables +from graphon.entities.workflow_start_reason import WorkflowStartReason +from graphon.runtime import GraphRuntimeState, VariablePool def _build_converter() -> WorkflowResponseConverter: """Construct a minimal WorkflowResponseConverter for testing.""" - system_variables = SystemVariable( + system_variables = build_system_variables( files=[], user_id="user-1", app_id="app-1", diff --git a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_truncation.py b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_truncation.py index 374af5ddc4..7ee375d884 100644 --- a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_truncation.py +++ b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_truncation.py @@ -24,9 +24,9 @@ from core.app.entities.queue_entities import ( QueueNodeStartedEvent, QueueNodeSucceededEvent, ) -from dify_graph.entities.workflow_start_reason import WorkflowStartReason -from dify_graph.enums import BuiltinNodeTypes -from dify_graph.system_variable import SystemVariable +from core.workflow.system_variables import build_system_variables +from graphon.entities.workflow_start_reason import WorkflowStartReason +from graphon.enums import BuiltinNodeTypes from libs.datetime_utils import naive_utc_now from models import Account from models.model import AppMode @@ -54,7 +54,7 @@ class TestWorkflowResponseConverter: mock_user.name = "Test User" mock_user.email = "test@example.com" - system_variables = SystemVariable(workflow_id="wf-id", workflow_execution_id="initial-run-id") + system_variables = build_system_variables(workflow_id="wf-id", workflow_execution_id="initial-run-id") return WorkflowResponseConverter( application_generate_entity=mock_entity, user=mock_user, @@ -451,9 +451,9 @@ class TestWorkflowResponseConverterServiceApiTruncation: account.id = "test_user_id" return account - def create_test_system_variables(self) -> SystemVariable: + def create_test_system_variables(self): """Create test system variables.""" - return SystemVariable() + return build_system_variables() def create_test_converter(self, invoke_from: InvokeFrom) -> WorkflowResponseConverter: """Create WorkflowResponseConverter with specified invoke_from.""" diff --git a/api/tests/unit_tests/core/app/apps/completion/test_app_runner.py b/api/tests/unit_tests/core/app/apps/completion/test_app_runner.py index 51f33bac35..aa2085177e 100644 --- a/api/tests/unit_tests/core/app/apps/completion/test_app_runner.py +++ b/api/tests/unit_tests/core/app/apps/completion/test_app_runner.py @@ -6,7 +6,7 @@ import pytest import core.app.apps.completion.app_runner as module from core.app.apps.completion.app_runner import CompletionAppRunner from core.moderation.base import ModerationError -from dify_graph.model_runtime.entities.message_entities import ImagePromptMessageContent +from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent @pytest.fixture diff --git a/api/tests/unit_tests/core/app/apps/completion/test_completion_completion_app_generator.py b/api/tests/unit_tests/core/app/apps/completion/test_completion_completion_app_generator.py index 2714757353..f2e35f9900 100644 --- a/api/tests/unit_tests/core/app/apps/completion/test_completion_completion_app_generator.py +++ b/api/tests/unit_tests/core/app/apps/completion/test_completion_completion_app_generator.py @@ -9,7 +9,7 @@ import core.app.apps.completion.app_generator as module from core.app.apps.completion.app_generator import CompletionAppGenerator from core.app.apps.exc import GenerateTaskStoppedError from core.app.entities.app_invoke_entities import InvokeFrom -from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError from services.errors.app import MoreLikeThisDisabledError from services.errors.message import MessageNotExistsError diff --git a/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_generate_response_converter.py b/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_generate_response_converter.py index 94ed8166b9..cfe797aa76 100644 --- a/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_generate_response_converter.py +++ b/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_generate_response_converter.py @@ -10,7 +10,7 @@ from core.app.entities.task_entities import ( WorkflowAppBlockingResponse, WorkflowAppStreamResponse, ) -from dify_graph.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus +from graphon.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus def test_convert_blocking_full_and_simple_response(): diff --git a/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_queue_manager.py b/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_queue_manager.py index 72f7552bd1..9db83f5531 100644 --- a/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_queue_manager.py +++ b/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_queue_manager.py @@ -13,7 +13,7 @@ from core.app.entities.queue_entities import ( QueueWorkflowPartialSuccessEvent, QueueWorkflowSucceededEvent, ) -from dify_graph.model_runtime.entities.llm_entities import LLMResult +from graphon.model_runtime.entities.llm_entities import LLMResult def test_publish_sets_stop_listen_and_raises_on_stopped(mocker): diff --git a/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_runner.py b/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_runner.py index eec95b7f39..fb19d6d761 100644 --- a/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_runner.py +++ b/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_runner.py @@ -26,7 +26,7 @@ import pytest import core.app.apps.pipeline.pipeline_runner as module from core.app.apps.pipeline.pipeline_runner import PipelineRunner from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom -from dify_graph.graph_events import GraphRunFailedEvent +from graphon.graph_events import GraphRunFailedEvent def _build_app_generate_entity() -> SimpleNamespace: @@ -284,7 +284,12 @@ def test_run_normal_path_builds_graph(mocker): return_value=SimpleNamespace(belong_to_node_id="start", variable="input1"), ) mocker.patch.object(module, "RAGPipelineVariableInput", side_effect=lambda **kwargs: SimpleNamespace(**kwargs)) - mocker.patch.object(module, "VariablePool", side_effect=lambda **kwargs: SimpleNamespace(**kwargs)) + + class FakeVariablePool: + def add(self, selector, value): + return None + + mocker.patch.object(module, "VariablePool", return_value=FakeVariablePool()) workflow_entry = MagicMock() workflow_entry.graph_engine = MagicMock() diff --git a/api/tests/unit_tests/core/app/apps/test_base_app_generator.py b/api/tests/unit_tests/core/app/apps/test_base_app_generator.py index a3ced02394..b0f8b423e1 100644 --- a/api/tests/unit_tests/core/app/apps/test_base_app_generator.py +++ b/api/tests/unit_tests/core/app/apps/test_base_app_generator.py @@ -1,9 +1,7 @@ -from unittest.mock import MagicMock - import pytest from core.app.apps.base_app_generator import BaseAppGenerator -from dify_graph.variables.input_entities import VariableEntity, VariableEntityType +from graphon.variables.input_entities import VariableEntity, VariableEntityType def test_validate_inputs_with_zero(): @@ -403,11 +401,11 @@ class TestBaseAppGeneratorExtras: monkeypatch.setattr( "core.app.apps.base_app_generator.file_factory.build_from_mapping", - lambda mapping, tenant_id, config, strict_type_validation=False: "file-object", + lambda mapping, tenant_id, config, strict_type_validation=False, access_controller=None: "file-object", ) monkeypatch.setattr( "core.app.apps.base_app_generator.file_factory.build_from_mappings", - lambda mappings, tenant_id, config: ["file-1", "file-2"], + lambda mappings, tenant_id, config, access_controller=None: ["file-1", "file-2"], ) user_inputs = { @@ -479,7 +477,7 @@ class TestBaseAppGeneratorExtras: def test_get_draft_var_saver_factory_debugger(self): from core.app.entities.app_invoke_entities import InvokeFrom - from dify_graph.enums import BuiltinNodeTypes + from graphon.enums import BuiltinNodeTypes from models import Account base_app_generator = BaseAppGenerator() @@ -489,7 +487,6 @@ class TestBaseAppGeneratorExtras: factory = base_app_generator._get_draft_var_saver_factory(InvokeFrom.DEBUGGER, account) saver = factory( - session=MagicMock(), app_id="app-id", node_id="node-id", node_type=BuiltinNodeTypes.START, diff --git a/api/tests/unit_tests/core/app/apps/test_base_app_queue_manager.py b/api/tests/unit_tests/core/app/apps/test_base_app_queue_manager.py index c6dc20ffc6..842d14bbd2 100644 --- a/api/tests/unit_tests/core/app/apps/test_base_app_queue_manager.py +++ b/api/tests/unit_tests/core/app/apps/test_base_app_queue_manager.py @@ -59,3 +59,18 @@ class TestBaseAppQueueManager: bad = SimpleNamespace(_sa_instance_state=True) with pytest.raises(TypeError): manager._check_for_sqlalchemy_models(bad) + + def test_stop_listen_defers_graph_runtime_state_cleanup_until_listener_exits(self): + with patch("core.app.apps.base_app_queue_manager.redis_client") as mock_redis: + mock_redis.setex.return_value = True + mock_redis.get.return_value = None + manager = DummyQueueManager(task_id="t1", user_id="u1", invoke_from=InvokeFrom.SERVICE_API) + + runtime_state = SimpleNamespace(name="runtime-state") + manager.graph_runtime_state = runtime_state + + manager.stop_listen() + + assert manager.graph_runtime_state is runtime_state + assert list(manager.listen()) == [] + assert manager.graph_runtime_state is None diff --git a/api/tests/unit_tests/core/app/apps/test_base_app_runner.py b/api/tests/unit_tests/core/app/apps/test_base_app_runner.py index aabeb54553..17de39ca99 100644 --- a/api/tests/unit_tests/core/app/apps/test_base_app_runner.py +++ b/api/tests/unit_tests/core/app/apps/test_base_app_runner.py @@ -14,15 +14,15 @@ from core.app.app_config.entities import ( from core.app.apps.base_app_runner import AppRunner from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.queue_entities import QueueAgentMessageEvent, QueueLLMChunkEvent, QueueMessageEndEvent -from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage -from dify_graph.model_runtime.entities.message_entities import ( +from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage +from graphon.model_runtime.entities.message_entities import ( AssistantPromptMessage, ImagePromptMessageContent, PromptMessageRole, TextPromptMessageContent, ) -from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey -from dify_graph.model_runtime.errors.invoke import InvokeBadRequestError +from graphon.model_runtime.entities.model_entities import ModelPropertyKey +from graphon.model_runtime.errors.invoke import InvokeBadRequestError from models.model import AppMode diff --git a/api/tests/unit_tests/core/app/apps/test_pause_resume.py b/api/tests/unit_tests/core/app/apps/test_pause_resume.py index 2f73a8cda8..3673b7f68e 100644 --- a/api/tests/unit_tests/core/app/apps/test_pause_resume.py +++ b/api/tests/unit_tests/core/app/apps/test_pause_resume.py @@ -3,33 +3,33 @@ import time from types import ModuleType, SimpleNamespace from typing import Any -import dify_graph.nodes.human_input.entities # noqa: F401 +import graphon.nodes.human_input.entities # noqa: F401 from core.app.apps.advanced_chat import app_generator as adv_app_gen_module from core.app.apps.workflow import app_generator as wf_app_gen_module from core.app.entities.app_invoke_entities import InvokeFrom from core.workflow.node_factory import DifyNodeFactory -from dify_graph.entities.base_node_data import BaseNodeData, RetryConfig -from dify_graph.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter -from dify_graph.entities.pause_reason import SchedulingPause -from dify_graph.entities.workflow_start_reason import WorkflowStartReason -from dify_graph.enums import BuiltinNodeTypes, NodeType, WorkflowNodeExecutionStatus -from dify_graph.graph import Graph -from dify_graph.graph_engine import GraphEngine -from dify_graph.graph_engine.command_channels.in_memory_channel import InMemoryChannel -from dify_graph.graph_events import ( +from core.workflow.system_variables import build_system_variables +from graphon.entities.base_node_data import BaseNodeData, RetryConfig +from graphon.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter +from graphon.entities.pause_reason import SchedulingPause +from graphon.entities.workflow_start_reason import WorkflowStartReason +from graphon.enums import BuiltinNodeTypes, NodeType, WorkflowNodeExecutionStatus +from graphon.graph import Graph +from graphon.graph_engine import GraphEngine +from graphon.graph_engine.command_channels.in_memory_channel import InMemoryChannel +from graphon.graph_events import ( GraphEngineEvent, GraphRunPausedEvent, GraphRunStartedEvent, GraphRunSucceededEvent, NodeRunSucceededEvent, ) -from dify_graph.node_events import NodeRunResult, PauseRequestedEvent -from dify_graph.nodes.base.entities import OutputVariableEntity -from dify_graph.nodes.base.node import Node -from dify_graph.nodes.end.entities import EndNodeData -from dify_graph.nodes.start.entities import StartNodeData -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable +from graphon.node_events import NodeRunResult, PauseRequestedEvent +from graphon.nodes.base.entities import OutputVariableEntity +from graphon.nodes.base.node import Node +from graphon.nodes.end.entities import EndNodeData +from graphon.nodes.start.entities import StartNodeData +from graphon.runtime import GraphRuntimeState, VariablePool from tests.workflow_test_utils import build_test_graph_init_params if "core.ops.ops_trace_manager" not in sys.modules: @@ -162,11 +162,11 @@ def _build_graph(runtime_state: GraphRuntimeState, *, pause_on: str | None) -> G def _build_runtime_state(run_id: str) -> GraphRuntimeState: variable_pool = VariablePool( - system_variables=SystemVariable(user_id="user", app_id="app", workflow_id="workflow"), + system_variables=build_system_variables(user_id="user", app_id="app", workflow_id="workflow"), user_inputs={}, conversation_variables=[], ) - variable_pool.system_variables.workflow_execution_id = run_id + variable_pool.add(("sys", "workflow_run_id"), run_id) return GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) diff --git a/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_core.py b/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_core.py index 3f1dd14569..58c7bfa4bc 100644 --- a/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_core.py +++ b/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_core.py @@ -1,6 +1,6 @@ from __future__ import annotations -from datetime import datetime +from datetime import UTC, datetime from types import SimpleNamespace import pytest @@ -11,25 +11,35 @@ from core.app.entities.queue_entities import ( QueueAgentLogEvent, QueueIterationCompletedEvent, QueueLoopCompletedEvent, + QueueNodeExceptionEvent, + QueueNodeFailedEvent, + QueueNodeRetryEvent, + QueueNodeSucceededEvent, QueueTextChunkEvent, QueueWorkflowPausedEvent, QueueWorkflowStartedEvent, QueueWorkflowSucceededEvent, ) -from dify_graph.entities.pause_reason import HumanInputRequired -from dify_graph.enums import BuiltinNodeTypes -from dify_graph.graph_events import ( +from core.workflow.system_variables import default_system_variables +from graphon.entities.pause_reason import HumanInputRequired +from graphon.enums import BuiltinNodeTypes +from graphon.graph_events import ( GraphRunPausedEvent, GraphRunStartedEvent, GraphRunSucceededEvent, NodeRunAgentLogEvent, + NodeRunExceptionEvent, + NodeRunFailedEvent, NodeRunIterationSucceededEvent, NodeRunLoopFailedEvent, + NodeRunRetryEvent, NodeRunStartedEvent, NodeRunStreamChunkEvent, + NodeRunSucceededEvent, ) -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable +from graphon.node_events import NodeRunResult +from graphon.runtime import GraphRuntimeState, VariablePool +from graphon.variables.variables import StringVariable class TestWorkflowBasedAppRunner: @@ -44,7 +54,7 @@ class TestWorkflowBasedAppRunner: runner = WorkflowBasedAppRunner(queue_manager=SimpleNamespace(), app_id="app") runtime_state = GraphRuntimeState( - variable_pool=VariablePool(system_variables=SystemVariable.default()), + variable_pool=VariablePool(system_variables=default_system_variables()), start_at=0.0, ) @@ -78,12 +88,12 @@ class TestWorkflowBasedAppRunner: workflow = SimpleNamespace(environment_variables=[], graph_dict={}) with pytest.raises(ValueError, match="Neither single_iteration_run nor single_loop_run"): - runner._prepare_single_node_execution(workflow, None, None) + runner._prepare_single_node_execution(workflow, None, None, user_id="00000000-0000-0000-0000-000000000001") def test_get_graph_and_variable_pool_for_single_node_run(self, monkeypatch): runner = WorkflowBasedAppRunner(queue_manager=SimpleNamespace(), app_id="app") graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool(system_variables=SystemVariable.default()), + variable_pool=VariablePool(system_variables=default_system_variables()), start_at=0.0, ) @@ -126,11 +136,102 @@ class TestWorkflowBasedAppRunner: graph_runtime_state=graph_runtime_state, node_type_filter_key="iteration_id", node_type_label="iteration", + user_id="00000000-0000-0000-0000-000000000001", ) assert graph is not None assert variable_pool is graph_runtime_state.variable_pool + def test_get_graph_and_variable_pool_preloads_constructor_variables_before_graph_init(self, monkeypatch): + variable_loader = SimpleNamespace( + load_variables=lambda selectors: ( + [ + StringVariable( + name="conversation_id", + value="conv-1", + selector=["sys", "conversation_id"], + ) + ] + if selectors + else [] + ) + ) + runner = WorkflowBasedAppRunner( + queue_manager=SimpleNamespace(), + variable_loader=variable_loader, + app_id="app", + ) + graph_runtime_state = GraphRuntimeState( + variable_pool=VariablePool(system_variables=default_system_variables()), + start_at=0.0, + ) + + workflow = SimpleNamespace( + tenant_id="tenant", + id="workflow", + graph_dict={ + "nodes": [ + {"id": "loop-node", "data": {"type": "loop", "version": "1", "title": "Loop"}}, + { + "id": "llm-child", + "data": { + "type": "llm", + "version": "1", + "loop_id": "loop-node", + "memory": object(), + }, + }, + ], + "edges": [], + }, + ) + + class _LoopNodeCls: + @staticmethod + def extract_variable_selector_to_variable_mapping(graph_config, config): + return {} + + def _validate_node_config(value): + return {"id": value["id"], "data": SimpleNamespace(**value["data"])} + + def _graph_init(**kwargs): + variable_pool = graph_runtime_state.variable_pool + assert variable_pool.get(["sys", "conversation_id"]) is not None + return SimpleNamespace() + + monkeypatch.setattr( + "core.app.apps.workflow_app_runner.NodeConfigDictAdapter.validate_python", + _validate_node_config, + ) + monkeypatch.setattr( + "core.app.apps.workflow_app_runner.Graph.init", + _graph_init, + ) + monkeypatch.setattr( + "core.app.apps.workflow_app_runner.resolve_workflow_node_class", + lambda **_kwargs: _LoopNodeCls, + ) + monkeypatch.setattr( + "core.app.apps.workflow_app_runner.load_into_variable_pool", + lambda **kwargs: None, + ) + monkeypatch.setattr( + "core.app.apps.workflow_app_runner.WorkflowEntry.mapping_user_inputs_to_variable_pool", + lambda **kwargs: None, + ) + + graph, variable_pool = runner._get_graph_and_variable_pool_for_single_node_run( + workflow=workflow, + node_id="loop-node", + user_inputs={}, + graph_runtime_state=graph_runtime_state, + node_type_filter_key="loop_id", + node_type_label="loop", + ) + + assert graph is not None + assert variable_pool.get(["sys", "conversation_id"]).value == "conv-1" + def test_handle_graph_run_events_and_pause_notifications(self, monkeypatch): published: list[object] = [] @@ -140,7 +241,7 @@ class TestWorkflowBasedAppRunner: runner = WorkflowBasedAppRunner(queue_manager=_QueueManager(), app_id="app") graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool(system_variables=SystemVariable.default()), + variable_pool=VariablePool(system_variables=default_system_variables()), start_at=0.0, ) graph_runtime_state.register_paused_node("node-1") @@ -183,7 +284,7 @@ class TestWorkflowBasedAppRunner: runner = WorkflowBasedAppRunner(queue_manager=_QueueManager(), app_id="app") graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool(system_variables=SystemVariable.default()), + variable_pool=VariablePool(system_variables=default_system_variables()), start_at=0.0, ) workflow_entry = SimpleNamespace(graph_engine=SimpleNamespace(graph_runtime_state=graph_runtime_state)) @@ -195,7 +296,7 @@ class TestWorkflowBasedAppRunner: node_id="node", node_type=BuiltinNodeTypes.START, node_title="Start", - start_at=datetime.utcnow(), + start_at=datetime.now(UTC), ), ) runner._handle_event( @@ -232,7 +333,7 @@ class TestWorkflowBasedAppRunner: node_id="node", node_type=BuiltinNodeTypes.LLM, node_title="Iter", - start_at=datetime.utcnow(), + start_at=datetime.now(UTC), inputs={}, outputs={"ok": True}, metadata={}, @@ -246,7 +347,7 @@ class TestWorkflowBasedAppRunner: node_id="node", node_type=BuiltinNodeTypes.LLM, node_title="Loop", - start_at=datetime.utcnow(), + start_at=datetime.now(UTC), inputs={}, outputs={}, metadata={}, @@ -259,3 +360,87 @@ class TestWorkflowBasedAppRunner: assert any(isinstance(event, QueueAgentLogEvent) for event in published) assert any(isinstance(event, QueueIterationCompletedEvent) for event in published) assert any(isinstance(event, QueueLoopCompletedEvent) for event in published) + + @pytest.mark.parametrize( + ("event_factory", "queue_event_cls"), + [ + ( + lambda result, start_at, finished_at: NodeRunSucceededEvent( + id="exec", + node_id="node", + node_type=BuiltinNodeTypes.START, + start_at=start_at, + finished_at=finished_at, + node_run_result=result, + ), + QueueNodeSucceededEvent, + ), + ( + lambda result, start_at, finished_at: NodeRunFailedEvent( + id="exec", + node_id="node", + node_type=BuiltinNodeTypes.START, + start_at=start_at, + finished_at=finished_at, + error="boom", + node_run_result=result, + ), + QueueNodeFailedEvent, + ), + ( + lambda result, start_at, finished_at: NodeRunExceptionEvent( + id="exec", + node_id="node", + node_type=BuiltinNodeTypes.START, + start_at=start_at, + finished_at=finished_at, + error="boom", + node_run_result=result, + ), + QueueNodeExceptionEvent, + ), + ( + lambda result, start_at, _finished_at: NodeRunRetryEvent( + id="exec", + node_id="node", + node_type=BuiltinNodeTypes.START, + node_title="Start", + start_at=start_at, + error="boom", + retry_index=1, + node_run_result=result, + ), + QueueNodeRetryEvent, + ), + ], + ) + def test_handle_start_node_result_events_project_outputs(self, event_factory, queue_event_cls): + published: list[object] = [] + + class _QueueManager: + def publish(self, event, publish_from): + published.append(event) + + runner = WorkflowBasedAppRunner(queue_manager=_QueueManager(), app_id="app") + graph_runtime_state = GraphRuntimeState( + variable_pool=VariablePool(system_variables=default_system_variables()), + start_at=0.0, + ) + workflow_entry = SimpleNamespace(graph_engine=SimpleNamespace(graph_runtime_state=graph_runtime_state)) + started_at = datetime.now(UTC) + finished_at = datetime.now(UTC) + result = NodeRunResult( + inputs={"question": "hello"}, + outputs={ + "question": "hello", + "sys.query": "hello", + "env.API_KEY": "secret", + "conversation.session_id": "session-1", + }, + ) + + runner._handle_event(workflow_entry, event_factory(result, started_at, finished_at)) + + queue_event = published[-1] + assert isinstance(queue_event, queue_event_cls) + assert queue_event.outputs == {"question": "hello"} diff --git a/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_notifications.py b/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_notifications.py index 1388279221..38a947986f 100644 --- a/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_notifications.py +++ b/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_notifications.py @@ -4,8 +4,8 @@ import pytest from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner from core.app.entities.queue_entities import QueueWorkflowPausedEvent -from dify_graph.entities.pause_reason import HumanInputRequired -from dify_graph.graph_events.graph import GraphRunPausedEvent +from graphon.entities.pause_reason import HumanInputRequired +from graphon.graph_events.graph import GraphRunPausedEvent class _DummyQueueManager: diff --git a/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_single_node.py b/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_single_node.py index 178e26118e..620a153204 100644 --- a/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_single_node.py +++ b/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_single_node.py @@ -9,15 +9,15 @@ from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.apps.workflow.app_runner import WorkflowAppRunner from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity -from dify_graph.entities.graph_config import NodeConfigDictAdapter -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable +from core.workflow.system_variables import default_system_variables +from graphon.entities.graph_config import NodeConfigDictAdapter +from graphon.runtime import GraphRuntimeState, VariablePool from models.workflow import Workflow def _make_graph_state(): variable_pool = VariablePool( - system_variables=SystemVariable.default(), + system_variables=default_system_variables(), user_inputs={}, environment_variables=[], conversation_variables=[], @@ -100,6 +100,7 @@ def test_run_uses_single_node_execution_branch( workflow=workflow, single_iteration_run=single_iteration_run, single_loop_run=single_loop_run, + user_id="user", ) init_graph.assert_not_called() @@ -158,6 +159,7 @@ def test_single_node_run_validates_target_node_config(monkeypatch) -> None: graph_runtime_state=graph_runtime_state, node_type_filter_key="loop_id", node_type_label="loop", + user_id="00000000-0000-0000-0000-000000000001", ) assert seen_configs == [workflow.graph_dict["nodes"][0]] diff --git a/api/tests/unit_tests/core/app/apps/test_workflow_pause_events.py b/api/tests/unit_tests/core/app/apps/test_workflow_pause_events.py index 65c6bd6654..ef0edf4096 100644 --- a/api/tests/unit_tests/core/app/apps/test_workflow_pause_events.py +++ b/api/tests/unit_tests/core/app/apps/test_workflow_pause_events.py @@ -10,13 +10,14 @@ from core.app.apps.workflow.app_runner import WorkflowAppRunner from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.queue_entities import QueueWorkflowPausedEvent from core.app.entities.task_entities import HumanInputRequiredResponse, WorkflowPauseStreamResponse -from dify_graph.entities.pause_reason import HumanInputRequired -from dify_graph.entities.workflow_start_reason import WorkflowStartReason -from dify_graph.graph_events.graph import GraphRunPausedEvent -from dify_graph.nodes.human_input.entities import FormInput, UserAction -from dify_graph.nodes.human_input.enums import FormInputType -from dify_graph.system_variable import SystemVariable +from core.workflow.system_variables import build_system_variables +from graphon.entities.pause_reason import HumanInputRequired +from graphon.entities.workflow_start_reason import WorkflowStartReason +from graphon.graph_events.graph import GraphRunPausedEvent +from graphon.nodes.human_input.entities import FormInput, UserAction +from graphon.nodes.human_input.enums import FormInputType from models.account import Account +from models.human_input import RecipientType class _RecordingWorkflowAppRunner(WorkflowAppRunner): @@ -74,7 +75,6 @@ def test_graph_run_paused_event_emits_queue_pause_event(): actions=[], node_id="node-human", node_title="Human Step", - form_token="tok", ) event = GraphRunPausedEvent(reasons=[reason], outputs={"foo": "bar"}) workflow_entry = SimpleNamespace( @@ -98,7 +98,7 @@ def _build_converter(): invoke_from=InvokeFrom.SERVICE_API, app_config=SimpleNamespace(app_id="app-id", tenant_id="tenant-id"), ) - system_variables = SystemVariable( + system_variables = build_system_variables( user_id="user", app_id="app-id", workflow_id="workflow-id", @@ -128,7 +128,21 @@ def test_queue_workflow_paused_event_to_stream_responses(monkeypatch: pytest.Mon class _FakeSession: def execute(self, _stmt): - return [("form-1", expiration_time)] + return [("form-1", expiration_time, '{"display_in_ui": true}')] + + def scalars(self, _stmt): + return [ + SimpleNamespace( + form_id="form-1", + recipient_type=RecipientType.CONSOLE, + access_token="console-token", + ), + SimpleNamespace( + form_id="form-1", + recipient_type=RecipientType.BACKSTAGE, + access_token="backstage-token", + ), + ] def __enter__(self): return self @@ -146,10 +160,8 @@ def test_queue_workflow_paused_event_to_stream_responses(monkeypatch: pytest.Mon FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="field", default=None), ], actions=[UserAction(id="approve", title="Approve")], - display_in_ui=True, node_id="node-id", node_title="Human Step", - form_token="token", ) queue_event = QueueWorkflowPausedEvent( reasons=[reason], @@ -170,7 +182,6 @@ def test_queue_workflow_paused_event_to_stream_responses(monkeypatch: pytest.Mon assert pause_resp.data.paused_nodes == ["node-id"] assert pause_resp.data.outputs == {} assert pause_resp.data.reasons[0]["form_id"] == "form-1" - assert pause_resp.data.reasons[0]["display_in_ui"] is True assert isinstance(responses[0], HumanInputRequiredResponse) hi_resp = responses[0] @@ -180,4 +191,5 @@ def test_queue_workflow_paused_event_to_stream_responses(monkeypatch: pytest.Mon assert hi_resp.data.inputs[0].output_variable_name == "field" assert hi_resp.data.actions[0].id == "approve" assert hi_resp.data.display_in_ui is True + assert hi_resp.data.form_token == "backstage-token" assert hi_resp.data.expiration_time == int(expiration_time.timestamp()) diff --git a/api/tests/unit_tests/core/app/apps/workflow/test_generate_response_converter.py b/api/tests/unit_tests/core/app/apps/workflow/test_generate_response_converter.py index 62e94a7580..7dd7ffd727 100644 --- a/api/tests/unit_tests/core/app/apps/workflow/test_generate_response_converter.py +++ b/api/tests/unit_tests/core/app/apps/workflow/test_generate_response_converter.py @@ -9,7 +9,7 @@ from core.app.entities.task_entities import ( WorkflowAppBlockingResponse, WorkflowAppStreamResponse, ) -from dify_graph.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus +from graphon.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus class TestWorkflowGenerateResponseConverter: diff --git a/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline.py b/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline.py index 5b23e71035..a0a999cbc5 100644 --- a/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline.py +++ b/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline.py @@ -7,11 +7,12 @@ from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity from core.app.entities.queue_entities import QueueWorkflowStartedEvent -from dify_graph.entities.workflow_start_reason import WorkflowStartReason -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable +from core.workflow.system_variables import build_system_variables +from graphon.entities.workflow_start_reason import WorkflowStartReason +from graphon.runtime import GraphRuntimeState from models.account import Account from models.model import AppMode +from tests.workflow_test_utils import build_test_variable_pool def _build_workflow_app_config() -> WorkflowUIBasedAppConfig: @@ -37,11 +38,7 @@ def _build_generate_entity(run_id: str) -> WorkflowAppGenerateEntity: def _build_runtime_state(run_id: str) -> GraphRuntimeState: - variable_pool = VariablePool( - system_variables=SystemVariable(workflow_execution_id=run_id), - user_inputs={}, - conversation_variables=[], - ) + variable_pool = build_test_variable_pool(variables=build_system_variables(workflow_execution_id=run_id)) return GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) diff --git a/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline_core.py b/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline_core.py index f35710d207..601c3989b9 100644 --- a/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline_core.py +++ b/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline_core.py @@ -44,11 +44,12 @@ from core.app.entities.task_entities import ( WorkflowStartStreamResponse, ) from core.base.tts.app_generator_tts_publisher import AudioTrunk -from dify_graph.enums import BuiltinNodeTypes, WorkflowExecutionStatus -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable +from core.workflow.system_variables import build_system_variables, system_variables_to_mapping +from graphon.enums import BuiltinNodeTypes, WorkflowExecutionStatus +from graphon.runtime import GraphRuntimeState, VariablePool from models.enums import CreatorUserRole from models.model import AppMode, EndUser +from tests.workflow_test_utils import build_test_variable_pool def _make_pipeline(): @@ -164,7 +165,7 @@ class TestWorkflowGenerateTaskPipeline: def test_handle_workflow_started_event_sets_run_id(self, monkeypatch): pipeline = _make_pipeline() pipeline._graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool(system_variables=SystemVariable(workflow_execution_id="run-id")), + variable_pool=build_test_variable_pool(variables=build_system_variables(workflow_execution_id="run-id")), start_at=0.0, ) pipeline._workflow_response_converter.workflow_start_to_stream_response = lambda **kwargs: "started" @@ -205,7 +206,7 @@ class TestWorkflowGenerateTaskPipeline: pipeline = _make_pipeline() pipeline._workflow_execution_id = "run-id" pipeline._graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool(system_variables=SystemVariable(workflow_execution_id="run-id")), + variable_pool=build_test_variable_pool(variables=build_system_variables(workflow_execution_id="run-id")), start_at=0.0, ) pipeline._workflow_response_converter.workflow_finish_to_stream_response = lambda **kwargs: "finish" @@ -257,7 +258,7 @@ class TestWorkflowGenerateTaskPipeline: pipeline = _make_pipeline() pipeline._workflow_execution_id = "run-id" pipeline._graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool(system_variables=SystemVariable(workflow_execution_id="run-id")), + variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-id")), start_at=0.0, ) pipeline._workflow_response_converter.workflow_finish_to_stream_response = lambda **kwargs: "finish" @@ -451,7 +452,7 @@ class TestWorkflowGenerateTaskPipeline: ) assert pipeline._created_by_role == CreatorUserRole.END_USER - assert pipeline._workflow_system_variables.user_id == "session-id" + assert system_variables_to_mapping(pipeline._workflow_system_variables)["user_id"] == "session-id" def test_process_returns_stream_and_blocking_variants(self): pipeline = _make_pipeline() @@ -699,7 +700,7 @@ class TestWorkflowGenerateTaskPipeline: pipeline = _make_pipeline() pipeline._workflow_execution_id = "run-id" pipeline._graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool(system_variables=SystemVariable(workflow_execution_id="run-id")), + variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-id")), start_at=0.0, ) @@ -727,7 +728,7 @@ class TestWorkflowGenerateTaskPipeline: pipeline = _make_pipeline() pipeline._workflow_execution_id = "run-id" pipeline._graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool(system_variables=SystemVariable(workflow_execution_id="run-id")), + variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-id")), start_at=0.0, ) pipeline._handle_ping_event = lambda event, **kwargs: iter(["ping"]) @@ -743,7 +744,7 @@ class TestWorkflowGenerateTaskPipeline: def test_process_stream_response_main_match_paths_and_cleanup(self): pipeline = _make_pipeline() pipeline._graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool(system_variables=SystemVariable(workflow_execution_id="run-id")), + variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-id")), start_at=0.0, ) pipeline._base_task_pipeline.queue_manager.listen = lambda: iter( @@ -815,7 +816,7 @@ class TestWorkflowGenerateTaskPipeline: pipeline._save_workflow_app_log(session=_Session(), workflow_run_id=None) assert len(added) == count_before - def test_save_output_for_event_writes_draft_variables(self, monkeypatch): + def test_save_output_for_event_writes_draft_variables(self): pipeline = _make_pipeline() saver_calls: list[tuple[object, object]] = [] captured_factory_args: dict[str, object] = {} @@ -828,29 +829,7 @@ class TestWorkflowGenerateTaskPipeline: captured_factory_args.update(kwargs) return _Saver() - class _Begin: - def __enter__(self): - return None - - def __exit__(self, exc_type, exc, tb): - return False - - class _Session: - def __init__(self, *args, **kwargs): - _ = args, kwargs - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc, tb): - return False - - def begin(self): - return _Begin() - pipeline._draft_var_saver_factory = _factory - monkeypatch.setattr("core.app.apps.workflow.generate_task_pipeline.Session", _Session) - monkeypatch.setattr("core.app.apps.workflow.generate_task_pipeline.db", SimpleNamespace(engine=object())) event = QueueNodeSucceededEvent( node_execution_id="exec-id", diff --git a/api/tests/unit_tests/core/app/entities/test_queue_entities.py b/api/tests/unit_tests/core/app/entities/test_queue_entities.py new file mode 100644 index 0000000000..7c21b00966 --- /dev/null +++ b/api/tests/unit_tests/core/app/entities/test_queue_entities.py @@ -0,0 +1,12 @@ +from core.app.entities.queue_entities import QueueStopEvent + + +class TestQueueEntities: + def test_get_stop_reason_for_known_stop_by(self): + event = QueueStopEvent(stopped_by=QueueStopEvent.StopBy.USER_MANUAL) + assert event.get_stop_reason() == "Stopped by user." + + def test_get_stop_reason_for_unknown_stop_by(self): + event = QueueStopEvent(stopped_by=QueueStopEvent.StopBy.USER_MANUAL) + event.stopped_by = "unknown" + assert event.get_stop_reason() == "Stopped by unknown reason." diff --git a/api/tests/unit_tests/core/app/entities/test_rag_pipeline_invoke_entities.py b/api/tests/unit_tests/core/app/entities/test_rag_pipeline_invoke_entities.py new file mode 100644 index 0000000000..1e0ef6d6d6 --- /dev/null +++ b/api/tests/unit_tests/core/app/entities/test_rag_pipeline_invoke_entities.py @@ -0,0 +1,17 @@ +from core.app.entities.rag_pipeline_invoke_entities import RagPipelineInvokeEntity + + +class TestRagPipelineInvokeEntity: + def test_defaults_and_fields(self): + entity = RagPipelineInvokeEntity( + pipeline_id="pipe-1", + application_generate_entity={"foo": "bar"}, + user_id="user-1", + tenant_id="tenant-1", + workflow_id="workflow-1", + streaming=True, + ) + + assert entity.workflow_execution_id is None + assert entity.workflow_thread_pool_id is None + assert entity.streaming is True diff --git a/api/tests/unit_tests/core/app/entities/test_task_entities.py b/api/tests/unit_tests/core/app/entities/test_task_entities.py new file mode 100644 index 0000000000..7c79780641 --- /dev/null +++ b/api/tests/unit_tests/core/app/entities/test_task_entities.py @@ -0,0 +1,78 @@ +from core.app.entities.task_entities import ( + NodeFinishStreamResponse, + NodeRetryStreamResponse, + NodeStartStreamResponse, + StreamEvent, +) +from graphon.enums import WorkflowNodeExecutionStatus + + +class TestTaskEntities: + def test_node_start_to_ignore_detail_dict(self): + data = NodeStartStreamResponse.Data( + id="exec-1", + node_id="node-1", + node_type="answer", + title="Answer", + index=1, + predecessor_node_id=None, + inputs={"foo": "bar"}, + created_at=1, + ) + response = NodeStartStreamResponse(task_id="task-1", workflow_run_id="run-1", data=data) + + payload = response.to_ignore_detail_dict() + + assert payload["event"] == StreamEvent.NODE_STARTED.value + assert payload["data"]["inputs"] is None + assert payload["data"]["extras"] == {} + + def test_node_finish_to_ignore_detail_dict(self): + data = NodeFinishStreamResponse.Data( + id="exec-1", + node_id="node-1", + node_type="answer", + title="Answer", + index=1, + predecessor_node_id=None, + inputs={"foo": "bar"}, + process_data={"step": 1}, + outputs={"answer": "ok"}, + status=WorkflowNodeExecutionStatus.SUCCEEDED, + elapsed_time=0.1, + created_at=1, + finished_at=2, + ) + response = NodeFinishStreamResponse(task_id="task-1", workflow_run_id="run-1", data=data) + + payload = response.to_ignore_detail_dict() + + assert payload["event"] == StreamEvent.NODE_FINISHED.value + assert payload["data"]["inputs"] is None + assert payload["data"]["outputs"] is None + assert payload["data"]["files"] == [] + + def test_node_retry_to_ignore_detail_dict(self): + data = NodeRetryStreamResponse.Data( + id="exec-1", + node_id="node-1", + node_type="answer", + title="Answer", + index=1, + predecessor_node_id=None, + inputs={"foo": "bar"}, + process_data={"step": 1}, + outputs={"answer": "ok"}, + status=WorkflowNodeExecutionStatus.RETRY, + elapsed_time=0.1, + created_at=1, + finished_at=2, + retry_index=2, + ) + response = NodeRetryStreamResponse(task_id="task-1", workflow_run_id="run-1", data=data) + + payload = response.to_ignore_detail_dict() + + assert payload["event"] == StreamEvent.NODE_RETRY.value + assert payload["data"]["retry_index"] == 2 + assert payload["data"]["outputs"] is None diff --git a/api/tests/unit_tests/core/app/features/rate_limiting/test_rate_limit.py b/api/tests/unit_tests/core/app/features/rate_limiting/test_rate_limit.py index 3db10c1c72..538b130cac 100644 --- a/api/tests/unit_tests/core/app/features/rate_limiting/test_rate_limit.py +++ b/api/tests/unit_tests/core/app/features/rate_limiting/test_rate_limit.py @@ -68,8 +68,8 @@ class TestRateLimit: assert rate_limit.disabled() assert not hasattr(rate_limit, "initialized") - def test_should_skip_reinitialization_of_existing_instance(self, redis_patch): - """Test that existing instance doesn't reinitialize.""" + def test_should_flush_cache_when_reinitializing_existing_instance(self, redis_patch): + """Test existing instance refreshes Redis cache on reinitialization.""" redis_patch.configure_mock( **{ "exists.return_value": False, @@ -82,7 +82,37 @@ class TestRateLimit: RateLimit("client1", 10) + redis_patch.setex.assert_called_once_with( + "dify:rate_limit:client1:max_active_requests", + timedelta(days=1), + 10, + ) + + def test_should_reinitialize_after_being_disabled(self, redis_patch): + """Test disabled instance can be reinitialized and writes max_active_requests to Redis.""" + redis_patch.configure_mock( + **{ + "exists.return_value": False, + "setex.return_value": True, + } + ) + + # First construct with max_active_requests = 0 (disabled), which should skip initialization. + RateLimit("client1", 0) + + # Redis should not have been written to during disabled initialization. redis_patch.setex.assert_not_called() + redis_patch.reset_mock() + + # Reinitialize with a positive max_active_requests value; this should not raise + # and must write the max_active_requests key to Redis. + RateLimit("client1", 10) + + redis_patch.setex.assert_called_once_with( + "dify:rate_limit:client1:max_active_requests", + timedelta(days=1), + 10, + ) def test_should_be_disabled_when_max_requests_is_zero_or_negative(self): """Test disabled state for zero or negative limits.""" diff --git a/api/tests/unit_tests/core/app/features/test_annotation_reply.py b/api/tests/unit_tests/core/app/features/test_annotation_reply.py new file mode 100644 index 0000000000..e721a77079 --- /dev/null +++ b/api/tests/unit_tests/core/app/features/test_annotation_reply.py @@ -0,0 +1,163 @@ +import logging +from types import SimpleNamespace +from unittest.mock import Mock, patch + +from core.app.entities.app_invoke_entities import InvokeFrom +from core.app.features.annotation_reply.annotation_reply import AnnotationReplyFeature + + +class TestAnnotationReplyFeature: + def test_query_returns_none_when_setting_missing(self): + feature = AnnotationReplyFeature() + + with patch("core.app.features.annotation_reply.annotation_reply.db") as mock_db: + mock_db.session.scalar.return_value = None + + result = feature.query( + app_record=SimpleNamespace(id="app-1", tenant_id="tenant-1"), + message=SimpleNamespace(id="msg-1"), + query="hi", + user_id="user-1", + invoke_from=InvokeFrom.SERVICE_API, + ) + + assert result is None + + def test_query_returns_none_when_binding_missing(self): + feature = AnnotationReplyFeature() + annotation_setting = SimpleNamespace(collection_binding_detail=None) + + with patch("core.app.features.annotation_reply.annotation_reply.db") as mock_db: + mock_db.session.scalar.return_value = annotation_setting + + result = feature.query( + app_record=SimpleNamespace(id="app-1", tenant_id="tenant-1"), + message=SimpleNamespace(id="msg-1"), + query="hi", + user_id="user-1", + invoke_from=InvokeFrom.SERVICE_API, + ) + + assert result is None + + def test_query_returns_annotation_and_records_history_for_api(self): + feature = AnnotationReplyFeature() + annotation_setting = SimpleNamespace( + score_threshold=None, + collection_binding_detail=SimpleNamespace(provider_name="prov", model_name="model"), + ) + dataset_binding = SimpleNamespace(id="binding-1") + annotation = SimpleNamespace( + id="ann-1", + question_text="question", + content="content", + account_id="acct-1", + account=SimpleNamespace(name="Alice"), + ) + document = SimpleNamespace(metadata={"annotation_id": "ann-1", "score": 0.8}) + vector_instance = Mock() + vector_instance.search_by_vector.return_value = [document] + + with ( + patch("core.app.features.annotation_reply.annotation_reply.db") as mock_db, + patch( + "core.app.features.annotation_reply.annotation_reply.DatasetCollectionBindingService" + ) as mock_binding_service, + patch("core.app.features.annotation_reply.annotation_reply.Vector") as mock_vector, + patch( + "core.app.features.annotation_reply.annotation_reply.AppAnnotationService" + ) as mock_annotation_service, + ): + mock_db.session.scalar.return_value = annotation_setting + mock_binding_service.get_dataset_collection_binding.return_value = dataset_binding + mock_vector.return_value = vector_instance + mock_annotation_service.get_annotation_by_id.return_value = annotation + + result = feature.query( + app_record=SimpleNamespace(id="app-1", tenant_id="tenant-1"), + message=SimpleNamespace(id="msg-1"), + query="hi", + user_id="user-1", + invoke_from=InvokeFrom.SERVICE_API, + ) + + assert result == annotation + mock_annotation_service.add_annotation_history.assert_called_once() + _, _, _, _, _, _, _, from_source, score = mock_annotation_service.add_annotation_history.call_args[0] + assert from_source == "api" + assert score == 0.8 + + def test_query_returns_annotation_and_records_history_for_console(self): + feature = AnnotationReplyFeature() + annotation_setting = SimpleNamespace( + score_threshold=0.5, + collection_binding_detail=SimpleNamespace(provider_name="prov", model_name="model"), + ) + dataset_binding = SimpleNamespace(id="binding-1") + annotation = SimpleNamespace( + id="ann-1", + question_text="question", + content="content", + account_id="acct-1", + account=None, + ) + document = SimpleNamespace(metadata={"annotation_id": "ann-1", "score": 0.6}) + vector_instance = Mock() + vector_instance.search_by_vector.return_value = [document] + + with ( + patch("core.app.features.annotation_reply.annotation_reply.db") as mock_db, + patch( + "core.app.features.annotation_reply.annotation_reply.DatasetCollectionBindingService" + ) as mock_binding_service, + patch("core.app.features.annotation_reply.annotation_reply.Vector") as mock_vector, + patch( + "core.app.features.annotation_reply.annotation_reply.AppAnnotationService" + ) as mock_annotation_service, + ): + mock_db.session.scalar.return_value = annotation_setting + mock_binding_service.get_dataset_collection_binding.return_value = dataset_binding + mock_vector.return_value = vector_instance + mock_annotation_service.get_annotation_by_id.return_value = annotation + + result = feature.query( + app_record=SimpleNamespace(id="app-1", tenant_id="tenant-1"), + message=SimpleNamespace(id="msg-1"), + query="hi", + user_id="user-1", + invoke_from=InvokeFrom.EXPLORE, + ) + + assert result == annotation + _, _, _, _, _, _, _, from_source, _ = mock_annotation_service.add_annotation_history.call_args[0] + assert from_source == "console" + + def test_query_logs_and_returns_none_on_exception(self, caplog): + feature = AnnotationReplyFeature() + annotation_setting = SimpleNamespace( + score_threshold=None, + collection_binding_detail=SimpleNamespace(provider_name="prov", model_name="model"), + ) + + with ( + patch("core.app.features.annotation_reply.annotation_reply.db") as mock_db, + patch( + "core.app.features.annotation_reply.annotation_reply.DatasetCollectionBindingService" + ) as mock_binding_service, + patch("core.app.features.annotation_reply.annotation_reply.Vector") as mock_vector, + ): + mock_db.session.scalar.return_value = annotation_setting + mock_binding_service.get_dataset_collection_binding.return_value = SimpleNamespace(id="binding-1") + mock_vector.return_value.search_by_vector.side_effect = RuntimeError("boom") + + with caplog.at_level(logging.WARNING): + result = feature.query( + app_record=SimpleNamespace(id="app-1", tenant_id="tenant-1"), + message=SimpleNamespace(id="msg-1"), + query="hi", + user_id="user-1", + invoke_from=InvokeFrom.SERVICE_API, + ) + + assert result is None + assert "Query annotation failed" in caplog.text diff --git a/api/tests/unit_tests/core/app/features/test_hosting_moderation.py b/api/tests/unit_tests/core/app/features/test_hosting_moderation.py new file mode 100644 index 0000000000..01194c16f5 --- /dev/null +++ b/api/tests/unit_tests/core/app/features/test_hosting_moderation.py @@ -0,0 +1,30 @@ +from types import SimpleNamespace +from unittest.mock import Mock, patch + +from core.app.features.hosting_moderation.hosting_moderation import HostingModerationFeature + + +class TestHostingModerationFeature: + def test_check_aggregates_text_and_calls_moderation(self): + application_generate_entity = Mock() + application_generate_entity.model_conf = {"model": "mock"} + application_generate_entity.app_config = SimpleNamespace(tenant_id="tenant-1") + + prompt_messages = [ + SimpleNamespace(content="hello"), + SimpleNamespace(content=123), + SimpleNamespace(content="world"), + ] + + with patch("core.app.features.hosting_moderation.hosting_moderation.moderation.check_moderation") as mock_check: + mock_check.return_value = True + + feature = HostingModerationFeature() + result = feature.check(application_generate_entity, prompt_messages) + + assert result is True + mock_check.assert_called_once_with( + tenant_id="tenant-1", + model_config=application_generate_entity.model_conf, + text="hello\nworld\n", + ) diff --git a/api/tests/unit_tests/core/app/layers/test_conversation_variable_persist_layer.py b/api/tests/unit_tests/core/app/layers/test_conversation_variable_persist_layer.py index bdc889d941..28745a2091 100644 --- a/api/tests/unit_tests/core/app/layers/test_conversation_variable_persist_layer.py +++ b/api/tests/unit_tests/core/app/layers/test_conversation_variable_persist_layer.py @@ -3,16 +3,15 @@ from datetime import datetime from unittest.mock import Mock from core.app.layers.conversation_variable_persist_layer import ConversationVariablePersistenceLayer -from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID -from dify_graph.enums import BuiltinNodeTypes, NodeType, WorkflowNodeExecutionStatus -from dify_graph.graph_engine.protocols.command_channel import CommandChannel -from dify_graph.graph_events.node import NodeRunSucceededEvent -from dify_graph.node_events import NodeRunResult -from dify_graph.nodes.variable_assigner.common import helpers as common_helpers -from dify_graph.runtime.graph_runtime_state_protocol import ReadOnlyGraphRuntimeState -from dify_graph.system_variable import SystemVariable -from dify_graph.variables import StringVariable -from dify_graph.variables.segments import Segment +from core.workflow.system_variables import SystemVariableKey +from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus +from graphon.graph_engine.protocols.command_channel import CommandChannel +from graphon.graph_events.node import NodeRunSucceededEvent, NodeRunVariableUpdatedEvent +from graphon.node_events import NodeRunResult +from graphon.runtime.graph_runtime_state_protocol import ReadOnlyGraphRuntimeState +from graphon.variables import StringVariable +from graphon.variables.segments import Segment, StringSegment class MockReadOnlyVariablePool: @@ -36,31 +35,38 @@ def _build_graph_runtime_state( conversation_id: str | None = None, ) -> ReadOnlyGraphRuntimeState: graph_runtime_state = Mock(spec=ReadOnlyGraphRuntimeState) + if conversation_id is not None: + variable_pool._variables[("sys", SystemVariableKey.CONVERSATION_ID.value)] = StringSegment( + value=conversation_id + ) graph_runtime_state.variable_pool = variable_pool - graph_runtime_state.system_variable = SystemVariable(conversation_id=conversation_id).as_view() return graph_runtime_state -def _build_node_run_succeeded_event( - *, - node_type: NodeType, - outputs: dict[str, object] | None = None, - process_data: dict[str, object] | None = None, -) -> NodeRunSucceededEvent: +def _build_node_run_succeeded_event() -> NodeRunSucceededEvent: return NodeRunSucceededEvent( id="node-exec-id", node_id="assigner", - node_type=node_type, + node_type=BuiltinNodeTypes.LLM, start_at=datetime.utcnow(), node_run_result=NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, - outputs=outputs or {}, - process_data=process_data or {}, + outputs={}, + process_data={}, ), ) -def test_persists_conversation_variables_from_assigner_output(): +def _build_variable_updated_event(variable: StringVariable) -> NodeRunVariableUpdatedEvent: + return NodeRunVariableUpdatedEvent( + id="node-exec-id", + node_id="assigner", + node_type=BuiltinNodeTypes.VARIABLE_ASSIGNER, + variable=variable, + ) + + +def test_persists_conversation_variables_from_variable_update_event(): conversation_id = "conv-123" variable = StringVariable( id="var-1", @@ -68,55 +74,26 @@ def test_persists_conversation_variables_from_assigner_output(): value="updated", selector=[CONVERSATION_VARIABLE_NODE_ID, "name"], ) - process_data = common_helpers.set_updated_variables( - {}, [common_helpers.variable_to_processed_data(variable.selector, variable)] - ) - - variable_pool = MockReadOnlyVariablePool({(CONVERSATION_VARIABLE_NODE_ID, "name"): variable}) - updater = Mock() layer = ConversationVariablePersistenceLayer(updater) - layer.initialize(_build_graph_runtime_state(variable_pool, conversation_id), Mock(spec=CommandChannel)) + layer.initialize(_build_graph_runtime_state(MockReadOnlyVariablePool(), conversation_id), Mock(spec=CommandChannel)) - event = _build_node_run_succeeded_event(node_type=BuiltinNodeTypes.VARIABLE_ASSIGNER, process_data=process_data) + event = _build_variable_updated_event(variable) layer.on_event(event) updater.update.assert_called_once_with(conversation_id=conversation_id, variable=variable) - updater.flush.assert_called_once() -def test_skips_when_outputs_missing(): +def test_skips_non_variable_update_events(): conversation_id = "conv-456" - variable = StringVariable( - id="var-2", - name="name", - value="updated", - selector=[CONVERSATION_VARIABLE_NODE_ID, "name"], - ) - - variable_pool = MockReadOnlyVariablePool({(CONVERSATION_VARIABLE_NODE_ID, "name"): variable}) - updater = Mock() layer = ConversationVariablePersistenceLayer(updater) - layer.initialize(_build_graph_runtime_state(variable_pool, conversation_id), Mock(spec=CommandChannel)) + layer.initialize(_build_graph_runtime_state(MockReadOnlyVariablePool(), conversation_id), Mock(spec=CommandChannel)) - event = _build_node_run_succeeded_event(node_type=BuiltinNodeTypes.VARIABLE_ASSIGNER) + event = _build_node_run_succeeded_event() layer.on_event(event) updater.update.assert_not_called() - updater.flush.assert_not_called() - - -def test_skips_non_assigner_nodes(): - updater = Mock() - layer = ConversationVariablePersistenceLayer(updater) - layer.initialize(_build_graph_runtime_state(MockReadOnlyVariablePool()), Mock(spec=CommandChannel)) - - event = _build_node_run_succeeded_event(node_type=BuiltinNodeTypes.LLM) - layer.on_event(event) - - updater.update.assert_not_called() - updater.flush.assert_not_called() def test_skips_non_conversation_variables(): @@ -127,18 +104,11 @@ def test_skips_non_conversation_variables(): value="updated", selector=["environment", "name"], ) - process_data = common_helpers.set_updated_variables( - {}, [common_helpers.variable_to_processed_data(non_conversation_variable.selector, non_conversation_variable)] - ) - - variable_pool = MockReadOnlyVariablePool() - updater = Mock() layer = ConversationVariablePersistenceLayer(updater) - layer.initialize(_build_graph_runtime_state(variable_pool, conversation_id), Mock(spec=CommandChannel)) + layer.initialize(_build_graph_runtime_state(MockReadOnlyVariablePool(), conversation_id), Mock(spec=CommandChannel)) - event = _build_node_run_succeeded_event(node_type=BuiltinNodeTypes.VARIABLE_ASSIGNER, process_data=process_data) + event = _build_variable_updated_event(non_conversation_variable) layer.on_event(event) updater.update.assert_not_called() - updater.flush.assert_not_called() diff --git a/api/tests/unit_tests/core/app/layers/test_pause_state_persist_layer.py b/api/tests/unit_tests/core/app/layers/test_pause_state_persist_layer.py index 035f0ee05c..92a7788f6e 100644 --- a/api/tests/unit_tests/core/app/layers/test_pause_state_persist_layer.py +++ b/api/tests/unit_tests/core/app/layers/test_pause_state_persist_layer.py @@ -13,17 +13,18 @@ from core.app.layers.pause_state_persist_layer import ( _AdvancedChatAppGenerateEntityWrapper, _WorkflowGenerateEntityWrapper, ) -from dify_graph.entities.pause_reason import SchedulingPause -from dify_graph.graph_engine.entities.commands import GraphEngineCommand -from dify_graph.graph_engine.layers.base import GraphEngineLayerNotInitializedError -from dify_graph.graph_events.graph import ( +from core.workflow.system_variables import SystemVariableKey +from graphon.entities.pause_reason import SchedulingPause +from graphon.graph_engine.entities.commands import GraphEngineCommand +from graphon.graph_engine.layers.base import GraphEngineLayerNotInitializedError +from graphon.graph_events.graph import ( GraphRunFailedEvent, GraphRunPausedEvent, GraphRunStartedEvent, GraphRunSucceededEvent, ) -from dify_graph.runtime.graph_runtime_state_protocol import ReadOnlyVariablePool -from dify_graph.variables.segments import Segment +from graphon.runtime.graph_runtime_state_protocol import ReadOnlyVariablePool +from graphon.variables.segments import Segment from models.model import AppMode from repositories.factory import DifyAPIRepositoryFactory @@ -51,17 +52,6 @@ class TestDataFactory: return GraphRunFailedEvent(error=error, exceptions_count=exceptions_count) -class MockSystemVariableReadOnlyView: - """Minimal read-only system variable view for testing.""" - - def __init__(self, workflow_execution_id: str | None = None) -> None: - self._workflow_execution_id = workflow_execution_id - - @property - def workflow_execution_id(self) -> str | None: - return self._workflow_execution_id - - class MockReadOnlyVariablePool: """Mock implementation of ReadOnlyVariablePool for testing.""" @@ -76,13 +66,14 @@ class MockReadOnlyVariablePool: return None mock_segment = Mock(spec=Segment) mock_segment.value = value + mock_segment.text = value if isinstance(value, str) else None return mock_segment def get_all_by_node(self, node_id: str) -> dict[str, object]: return {key: value for (nid, key), value in self._variables.items() if nid == node_id} def get_by_prefix(self, prefix: str) -> dict[str, object]: - return {f"{nid}.{key}": value for (nid, key), value in self._variables.items() if nid.startswith(prefix)} + return {key: value for (nid, key), value in self._variables.items() if nid == prefix} class MockReadOnlyGraphRuntimeState: @@ -105,12 +96,10 @@ class MockReadOnlyGraphRuntimeState: self._ready_queue_size = ready_queue_size self._exceptions_count = exceptions_count self._outputs = outputs or {} - self._variable_pool = MockReadOnlyVariablePool(variables) - self._system_variable = MockSystemVariableReadOnlyView(workflow_execution_id) - - @property - def system_variable(self) -> MockSystemVariableReadOnlyView: - return self._system_variable + resolved_variables = dict(variables or {}) + if workflow_execution_id is not None: + resolved_variables[("sys", SystemVariableKey.WORKFLOW_EXECUTION_ID.value)] = workflow_execution_id + self._variable_pool = MockReadOnlyVariablePool(resolved_variables) @property def variable_pool(self) -> ReadOnlyVariablePool: @@ -161,7 +150,9 @@ class MockReadOnlyGraphRuntimeState: "exceptions_count": self._exceptions_count, "outputs": self._outputs, "variables": {f"{k[0]}.{k[1]}": v for k, v in self._variable_pool._variables.items()}, - "workflow_execution_id": self._system_variable.workflow_execution_id, + "workflow_execution_id": self._variable_pool._variables.get( + ("sys", SystemVariableKey.WORKFLOW_EXECUTION_ID.value) + ), } ) diff --git a/api/tests/unit_tests/core/app/layers/test_suspend_layer.py b/api/tests/unit_tests/core/app/layers/test_suspend_layer.py new file mode 100644 index 0000000000..56705f1a7e --- /dev/null +++ b/api/tests/unit_tests/core/app/layers/test_suspend_layer.py @@ -0,0 +1,19 @@ +from core.app.layers.suspend_layer import SuspendLayer +from graphon.graph_events.graph import GraphRunPausedEvent + + +class TestSuspendLayer: + def test_on_event_accepts_paused_event(self): + layer = SuspendLayer() + assert layer.is_paused() is False + layer.on_graph_start() + assert layer.is_paused() is False + layer.on_event(GraphRunPausedEvent()) + assert layer.is_paused() is True + + def test_on_event_ignores_other_events(self): + layer = SuspendLayer() + layer.on_graph_start() + initial_state = layer.is_paused() + layer.on_event(object()) + assert layer.is_paused() is initial_state diff --git a/api/tests/unit_tests/core/app/layers/test_timeslice_layer.py b/api/tests/unit_tests/core/app/layers/test_timeslice_layer.py new file mode 100644 index 0000000000..1ac9a4d8c0 --- /dev/null +++ b/api/tests/unit_tests/core/app/layers/test_timeslice_layer.py @@ -0,0 +1,98 @@ +from unittest.mock import Mock, patch + +from core.app.layers.timeslice_layer import TimeSliceLayer +from graphon.graph_engine.entities.commands import CommandType, GraphEngineCommand +from services.workflow.entities import WorkflowScheduleCFSPlanEntity +from services.workflow.scheduler import SchedulerCommand + + +class TestTimeSliceLayer: + def test_init_starts_scheduler_when_not_running(self): + scheduler = Mock() + scheduler.running = False + + with patch("core.app.layers.timeslice_layer.TimeSliceLayer.scheduler", scheduler): + _ = TimeSliceLayer(cfs_plan_scheduler=Mock(plan=Mock())) + + scheduler.start.assert_called_once() + + def test_on_graph_start_adds_job_for_time_slice(self): + scheduler = Mock() + scheduler.running = True + plan = WorkflowScheduleCFSPlanEntity( + schedule_strategy=WorkflowScheduleCFSPlanEntity.Strategy.TimeSlice, + granularity=3, + ) + cfs_plan_scheduler = Mock(plan=plan) + + with ( + patch("core.app.layers.timeslice_layer.TimeSliceLayer.scheduler", scheduler), + patch("core.app.layers.timeslice_layer.uuid.uuid4") as mock_uuid, + ): + mock_uuid.return_value.hex = "job-1" + layer = TimeSliceLayer(cfs_plan_scheduler=cfs_plan_scheduler) + layer.on_graph_start() + + assert layer.schedule_id == "job-1" + scheduler.add_job.assert_called_once() + + def test_on_graph_end_removes_job(self): + scheduler = Mock() + scheduler.running = True + plan = WorkflowScheduleCFSPlanEntity( + schedule_strategy=WorkflowScheduleCFSPlanEntity.Strategy.TimeSlice, + granularity=3, + ) + cfs_plan_scheduler = Mock(plan=plan) + + with patch("core.app.layers.timeslice_layer.TimeSliceLayer.scheduler", scheduler): + layer = TimeSliceLayer(cfs_plan_scheduler=cfs_plan_scheduler) + layer.schedule_id = "job-1" + layer.on_graph_end(None) + + scheduler.remove_job.assert_called_once_with("job-1") + + def test_checker_job_removes_when_stopped(self): + scheduler = Mock() + scheduler.running = True + cfs_plan_scheduler = Mock(plan=Mock()) + + with patch("core.app.layers.timeslice_layer.TimeSliceLayer.scheduler", scheduler): + layer = TimeSliceLayer(cfs_plan_scheduler=cfs_plan_scheduler) + layer.stopped = True + layer._checker_job("job-1") + + scheduler.remove_job.assert_called_once_with("job-1") + + def test_checker_job_handles_resource_limit_without_command_channel(self): + scheduler = Mock() + scheduler.running = True + cfs_plan_scheduler = Mock(plan=Mock()) + cfs_plan_scheduler.can_schedule.return_value = SchedulerCommand.RESOURCE_LIMIT_REACHED + + with ( + patch("core.app.layers.timeslice_layer.TimeSliceLayer.scheduler", scheduler), + patch("core.app.layers.timeslice_layer.logger") as mock_logger, + ): + layer = TimeSliceLayer(cfs_plan_scheduler=cfs_plan_scheduler) + layer._checker_job("job-1") + + scheduler.remove_job.assert_called_once_with("job-1") + mock_logger.exception.assert_called_once() + + def test_checker_job_sends_pause_command(self): + scheduler = Mock() + scheduler.running = True + cfs_plan_scheduler = Mock(plan=Mock()) + cfs_plan_scheduler.can_schedule.return_value = SchedulerCommand.RESOURCE_LIMIT_REACHED + + with patch("core.app.layers.timeslice_layer.TimeSliceLayer.scheduler", scheduler): + layer = TimeSliceLayer(cfs_plan_scheduler=cfs_plan_scheduler) + layer.command_channel = Mock() + layer._checker_job("job-1") + + scheduler.remove_job.assert_called_once_with("job-1") + layer.command_channel.send_command.assert_called_once() + sent_command = layer.command_channel.send_command.call_args[0][0] + assert isinstance(sent_command, GraphEngineCommand) + assert sent_command.command_type == CommandType.PAUSE diff --git a/api/tests/unit_tests/core/app/layers/test_trigger_post_layer.py b/api/tests/unit_tests/core/app/layers/test_trigger_post_layer.py new file mode 100644 index 0000000000..ecc431936c --- /dev/null +++ b/api/tests/unit_tests/core/app/layers/test_trigger_post_layer.py @@ -0,0 +1,108 @@ +from datetime import UTC, datetime, timedelta +from types import SimpleNamespace +from unittest.mock import Mock, patch + +from core.app.layers.trigger_post_layer import TriggerPostLayer +from core.workflow.system_variables import build_system_variables +from graphon.graph_events.graph import GraphRunFailedEvent, GraphRunSucceededEvent +from graphon.runtime import VariablePool +from models.enums import WorkflowTriggerStatus + + +class TestTriggerPostLayer: + def test_on_event_updates_trigger_log(self): + trigger_log = SimpleNamespace( + status=None, + workflow_run_id=None, + outputs=None, + elapsed_time=None, + total_tokens=None, + finished_at=None, + ) + runtime_state = SimpleNamespace( + outputs={"answer": "ok"}, + variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-1")), + total_tokens=12, + ) + + with ( + patch("core.app.layers.trigger_post_layer.session_factory") as mock_session_factory, + patch("core.app.layers.trigger_post_layer.SQLAlchemyWorkflowTriggerLogRepository") as mock_repo_cls, + patch("core.app.layers.trigger_post_layer.datetime") as mock_datetime, + ): + mock_datetime.now.return_value = datetime(2026, 2, 20, tzinfo=UTC) + + session = Mock() + mock_session_factory.create_session.return_value.__enter__.return_value = session + + repo = Mock() + repo.get_by_id.return_value = trigger_log + mock_repo_cls.return_value = repo + + layer = TriggerPostLayer( + cfs_plan_scheduler_entity=Mock(), + start_time=datetime(2026, 2, 20, tzinfo=UTC) - timedelta(seconds=10), + trigger_log_id="log-1", + ) + layer.initialize(runtime_state, Mock()) + + layer.on_event(GraphRunSucceededEvent()) + + assert trigger_log.status == WorkflowTriggerStatus.SUCCEEDED + assert trigger_log.workflow_run_id == "run-1" + assert trigger_log.outputs is not None + assert trigger_log.elapsed_time is not None + assert trigger_log.total_tokens == 12 + assert trigger_log.finished_at is not None + repo.update.assert_called_once_with(trigger_log) + session.commit.assert_called_once() + + def test_on_event_handles_missing_trigger_log(self): + runtime_state = SimpleNamespace( + outputs={}, + variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-1")), + total_tokens=0, + ) + + with ( + patch("core.app.layers.trigger_post_layer.session_factory") as mock_session_factory, + patch("core.app.layers.trigger_post_layer.SQLAlchemyWorkflowTriggerLogRepository") as mock_repo_cls, + patch("core.app.layers.trigger_post_layer.logger") as mock_logger, + ): + session = Mock() + mock_session_factory.create_session.return_value.__enter__.return_value = session + + repo = Mock() + repo.get_by_id.return_value = None + mock_repo_cls.return_value = repo + + layer = TriggerPostLayer( + cfs_plan_scheduler_entity=Mock(), + start_time=datetime(2026, 2, 20, tzinfo=UTC), + trigger_log_id="missing", + ) + layer.initialize(runtime_state, Mock()) + + layer.on_event(GraphRunFailedEvent(error="boom")) + + mock_logger.exception.assert_called_once() + session.commit.assert_not_called() + + def test_on_event_ignores_non_status_events(self): + runtime_state = SimpleNamespace( + outputs={}, + variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-1")), + total_tokens=0, + ) + + with patch("core.app.layers.trigger_post_layer.session_factory") as mock_session_factory: + layer = TriggerPostLayer( + cfs_plan_scheduler_entity=Mock(), + start_time=datetime(2026, 2, 20, tzinfo=UTC), + trigger_log_id="log-1", + ) + layer.initialize(runtime_state, Mock()) + + layer.on_event(Mock()) + + mock_session_factory.create_session.assert_not_called() diff --git a/api/tests/unit_tests/core/app/task_pipeline/__init__.py b/api/tests/unit_tests/core/app/task_pipeline/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/core/app/task_pipeline/test_based_generate_task_pipeline.py b/api/tests/unit_tests/core/app/task_pipeline/test_based_generate_task_pipeline.py new file mode 100644 index 0000000000..c246f7b783 --- /dev/null +++ b/api/tests/unit_tests/core/app/task_pipeline/test_based_generate_task_pipeline.py @@ -0,0 +1,91 @@ +from types import SimpleNamespace +from unittest.mock import Mock + +import pytest + +from core.app.entities.queue_entities import QueueErrorEvent +from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline +from core.errors.error import QuotaExceededError +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError +from models.enums import MessageStatus + + +class TestBasedGenerateTaskPipeline: + @pytest.fixture + def pipeline(self): + app_config = SimpleNamespace( + tenant_id="tenant-1", + app_id="app-1", + sensitive_word_avoidance=None, + ) + app_generate_entity = SimpleNamespace(task_id="task-1", app_config=app_config) + return BasedGenerateTaskPipeline( + application_generate_entity=app_generate_entity, + queue_manager=Mock(), + stream=True, + ) + + def test_error_to_desc_quota_exceeded(self, pipeline): + message = pipeline._error_to_desc(QuotaExceededError()) + assert "quota" in message.lower() + + def test_handle_error_wraps_invoke_authorization(self, pipeline): + event = QueueErrorEvent(error=InvokeAuthorizationError()) + err = pipeline.handle_error(event=event) + assert isinstance(err, InvokeAuthorizationError) + assert str(err) == "Incorrect API key provided" + + def test_handle_error_preserves_invoke_error(self, pipeline): + event = QueueErrorEvent(error=InvokeError("bad")) + err = pipeline.handle_error(event=event) + assert err is event.error + + def test_handle_error_updates_message_when_found(self, pipeline): + event = QueueErrorEvent(error=ValueError("oops")) + message = SimpleNamespace(status=MessageStatus.NORMAL, error=None) + session = Mock() + session.scalar.return_value = message + + err = pipeline.handle_error(event=event, session=session, message_id="msg-1") + + assert err is event.error + assert message.status == MessageStatus.ERROR + assert message.error == "oops" + + def test_handle_error_returns_err_when_message_missing(self, pipeline): + event = QueueErrorEvent(error=ValueError("oops")) + session = Mock() + session.scalar.return_value = None + + err = pipeline.handle_error(event=event, session=session, message_id="msg-1") + + assert err is event.error + + def test_error_to_stream_response_and_ping(self, pipeline): + error_response = pipeline.error_to_stream_response(ValueError("boom")) + ping_response = pipeline.ping_stream_response() + + assert error_response.task_id == "task-1" + assert ping_response.task_id == "task-1" + + def test_handle_output_moderation_when_flagged(self, pipeline): + handler = Mock() + handler.moderation_completion.return_value = ("filtered", True) + pipeline.output_moderation_handler = handler + + result = pipeline.handle_output_moderation_when_task_finished("raw") + + assert result == "filtered" + handler.stop_thread.assert_called_once() + assert pipeline.output_moderation_handler is None + + def test_handle_output_moderation_when_not_flagged(self, pipeline): + handler = Mock() + handler.moderation_completion.return_value = ("safe", False) + pipeline.output_moderation_handler = handler + + result = pipeline.handle_output_moderation_when_task_finished("raw") + + assert result is None + handler.stop_thread.assert_called_once() + assert pipeline.output_moderation_handler is None diff --git a/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_based_generate_task_pipeline.py b/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_based_generate_task_pipeline.py index 13fbca6e26..1c1bf391d3 100644 --- a/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_based_generate_task_pipeline.py +++ b/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_based_generate_task_pipeline.py @@ -26,8 +26,8 @@ from core.app.entities.task_entities import ( from core.app.task_pipeline.easy_ui_based_generate_task_pipeline import EasyUIBasedGenerateTaskPipeline from core.base.tts import AppGeneratorTTSPublisher from core.ops.ops_trace_manager import TraceQueueManager -from dify_graph.model_runtime.entities.llm_entities import LLMResult as RuntimeLLMResult -from dify_graph.model_runtime.entities.message_entities import TextPromptMessageContent +from graphon.model_runtime.entities.llm_entities import LLMResult as RuntimeLLMResult +from graphon.model_runtime.entities.message_entities import TextPromptMessageContent from models.model import AppMode diff --git a/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_based_generate_task_pipeline_core.py b/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_based_generate_task_pipeline_core.py new file mode 100644 index 0000000000..ea000f3886 --- /dev/null +++ b/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_based_generate_task_pipeline_core.py @@ -0,0 +1,1228 @@ +from __future__ import annotations + +from datetime import UTC, datetime +from types import SimpleNamespace +from unittest.mock import Mock + +import pytest + +from core.app.app_config.entities import ( + AppAdditionalFeatures, + EasyUIBasedAppConfig, + EasyUIBasedAppModelConfigFrom, + ModelConfigEntity, + PromptTemplateEntity, +) +from core.app.entities.app_invoke_entities import ChatAppGenerateEntity, CompletionAppGenerateEntity, InvokeFrom +from core.app.entities.queue_entities import ( + QueueAgentMessageEvent, + QueueAgentThoughtEvent, + QueueAnnotationReplyEvent, + QueueErrorEvent, + QueueLLMChunkEvent, + QueueMessageEndEvent, + QueueMessageFileEvent, + QueueMessageReplaceEvent, + QueuePingEvent, + QueueRetrieverResourcesEvent, + QueueStopEvent, +) +from core.app.entities.task_entities import ( + ChatbotAppStreamResponse, + CompletionAppStreamResponse, + ErrorStreamResponse, + MessageAudioEndStreamResponse, + MessageAudioStreamResponse, + MessageEndStreamResponse, + PingStreamResponse, +) +from core.app.task_pipeline.easy_ui_based_generate_task_pipeline import EasyUIBasedGenerateTaskPipeline +from core.base.tts import AudioTrunk +from graphon.file.enums import FileTransferMethod +from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage +from graphon.model_runtime.entities.message_entities import AssistantPromptMessage, TextPromptMessageContent +from models.model import AppMode + + +class _DummyModelConf: + def __init__(self) -> None: + self.model = "mock" + + +def _make_app_config(app_mode: AppMode) -> EasyUIBasedAppConfig: + return EasyUIBasedAppConfig( + tenant_id="tenant", + app_id="app", + app_mode=app_mode, + app_model_config_from=EasyUIBasedAppModelConfigFrom.APP_LATEST_CONFIG, + app_model_config_id="model-config", + app_model_config_dict={}, + model=ModelConfigEntity(provider="mock", model="mock"), + prompt_template=PromptTemplateEntity( + prompt_type=PromptTemplateEntity.PromptType.SIMPLE, + simple_prompt_template="hi", + ), + additional_features=AppAdditionalFeatures(), + variables=[], + ) + + +def _make_entity(entity_cls, app_mode: AppMode): + app_config = _make_app_config(app_mode) + return entity_cls.model_construct( + task_id="task", + app_config=app_config, + model_conf=_DummyModelConf(), + file_upload_config=None, + conversation_id=None, + inputs={}, + query="hello", + files=[], + parent_message_id=None, + user_id="user", + stream=False, + invoke_from=InvokeFrom.WEB_APP, + extras={}, + call_depth=0, + trace_manager=None, + ) + + +class TestEasyUiBasedGenerateTaskPipeline: + def test_to_blocking_response_chat(self): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=False, + ) + pipeline._task_state.llm_result.message.content = "answer" + + def _gen(): + yield MessageEndStreamResponse(task_id="task", id="msg") + + response = pipeline._to_blocking_response(_gen()) + + assert response.data.answer == "answer" + + def test_to_blocking_response_completion(self): + conversation = SimpleNamespace(id="conv", mode=AppMode.COMPLETION) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(CompletionAppGenerateEntity, AppMode.COMPLETION), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=False, + ) + pipeline._task_state.llm_result.message.content = "answer" + + def _gen(): + yield MessageEndStreamResponse(task_id="task", id="msg") + + response = pipeline._to_blocking_response(_gen()) + + assert response.data.answer == "answer" + + def test_listen_audio_msg_returns_none_when_no_publisher(self): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=False, + ) + + assert pipeline._listen_audio_msg(publisher=None, task_id="task") is None + + def test_process_stream_response_handles_chunks_and_end(self, monkeypatch): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=True, + ) + + chunk = LLMResultChunk( + model="mock", + prompt_messages=[], + delta=LLMResultChunkDelta( + index=0, + message=AssistantPromptMessage( + content=[TextPromptMessageContent(data="hi"), TextPromptMessageContent(data="yo")] + ), + ), + ) + llm_result = LLMResult( + model="mock", + prompt_messages=[], + message=AssistantPromptMessage(content="done"), + usage=LLMUsage.empty_usage(), + ) + + events = [ + SimpleNamespace(event=QueueLLMChunkEvent(chunk=chunk)), + SimpleNamespace(event=QueueMessageReplaceEvent(text="replace", reason="output_moderation")), + SimpleNamespace(event=QueuePingEvent()), + SimpleNamespace(event=QueueMessageEndEvent(llm_result=llm_result)), + ] + + pipeline.queue_manager.listen = lambda: iter(events) + pipeline._message_cycle_manager.get_message_event_type = lambda message_id: None + pipeline._message_cycle_manager.message_to_stream_response = lambda **kwargs: "chunk" + pipeline._message_cycle_manager.message_replace_to_stream_response = lambda **kwargs: "replace" + pipeline.handle_output_moderation_when_task_finished = lambda completion: None + pipeline._message_end_to_stream_response = lambda: "end" + pipeline._save_message = lambda **kwargs: None + + class _Session: + def __init__(self, *args, **kwargs): + pass + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + def commit(self): + return None + + monkeypatch.setattr( + "core.app.task_pipeline.easy_ui_based_generate_task_pipeline.Session", + _Session, + ) + monkeypatch.setattr( + "core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db", + SimpleNamespace(engine=object()), + ) + + responses = list(pipeline._process_stream_response(publisher=None)) + + assert "chunk" in responses + assert "replace" in responses + assert any(isinstance(item, PingStreamResponse) for item in responses) + assert responses[-1] == "end" + + def test_handle_output_moderation_chunk_directs_output(self): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=True, + ) + events: list[object] = [] + + class _Moderation: + def should_direct_output(self): + return True + + def get_final_output(self): + return "final" + + pipeline.output_moderation_handler = _Moderation() + pipeline.queue_manager.publish = lambda event, publish_from: events.append(event) + + result = pipeline._handle_output_moderation_chunk("token") + + assert result is True + assert any(isinstance(event, QueueLLMChunkEvent) for event in events) + assert any(isinstance(event, QueueStopEvent) for event in events) + + def test_handle_stop_updates_usage(self, monkeypatch): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + + class _ModelType: + def calc_response_usage(self, model, credentials, prompt_tokens, completion_tokens): + return LLMUsage.from_metadata( + { + "prompt_tokens": prompt_tokens, + "completion_tokens": completion_tokens, + } + ) + + class _ModelConf: + def __init__(self) -> None: + self.model = "mock" + self.credentials = {} + self.provider_model_bundle = SimpleNamespace(model_type_instance=_ModelType()) + + app_config = _make_app_config(AppMode.CHAT) + application_generate_entity = ChatAppGenerateEntity.model_construct( + task_id="task", + app_config=app_config, + model_conf=_ModelConf(), + file_upload_config=None, + conversation_id=None, + inputs={}, + query="hello", + files=[], + parent_message_id=None, + user_id="user", + stream=False, + invoke_from=InvokeFrom.WEB_APP, + extras={}, + call_depth=0, + trace_manager=None, + ) + + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=application_generate_entity, + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=False, + ) + pipeline._task_state.llm_result.prompt_messages = [AssistantPromptMessage(content="prompt")] + pipeline._task_state.llm_result.message = AssistantPromptMessage(content="answer") + + calls: list[int] = [] + + class _FakeModelInstance: + def __init__(self, provider_model_bundle, model): + pass + + def get_llm_num_tokens(self, messages): + calls.append(1) + return 10 if len(calls) == 1 else 5 + + monkeypatch.setattr( + "core.app.task_pipeline.easy_ui_based_generate_task_pipeline.ModelInstance", + _FakeModelInstance, + ) + + pipeline._handle_stop(QueueStopEvent(stopped_by=QueueStopEvent.StopBy.USER_MANUAL)) + + assert pipeline._task_state.llm_result.usage.prompt_tokens == 10 + assert pipeline._task_state.llm_result.usage.completion_tokens == 5 + + def test_record_files_builds_file_payloads(self, monkeypatch): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=False, + ) + + message_files = [ + SimpleNamespace( + id="mf-1", + message_id="msg", + transfer_method=FileTransferMethod.REMOTE_URL, + url="http://example.com/a.png", + upload_file_id=None, + type="image", + ), + SimpleNamespace( + id="mf-2", + message_id="msg", + transfer_method=FileTransferMethod.LOCAL_FILE, + url="", + upload_file_id="upload-1", + type="image", + ), + SimpleNamespace( + id="mf-3", + message_id="msg", + transfer_method=FileTransferMethod.TOOL_FILE, + url="tool/file.bin", + upload_file_id=None, + type="file", + ), + ] + upload_files = [ + SimpleNamespace( + id="upload-1", + name="local.png", + mime_type="image/png", + size=123, + extension="png", + ) + ] + + class _Result: + def __init__(self, items): + self._items = items + + def all(self): + return self._items + + class _Session: + def __init__(self, *args, **kwargs): + self.calls = 0 + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + def scalars(self, *args, **kwargs): + self.calls += 1 + return _Result(message_files if self.calls == 1 else upload_files) + + monkeypatch.setattr( + "core.app.task_pipeline.easy_ui_based_generate_task_pipeline.Session", + _Session, + ) + monkeypatch.setattr( + "core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db", + SimpleNamespace(engine=object()), + ) + monkeypatch.setattr( + "core.app.task_pipeline.message_file_utils.file_helpers.get_signed_file_url", + lambda **kwargs: "signed-url", + ) + monkeypatch.setattr( + "core.app.task_pipeline.message_file_utils.sign_tool_file", + lambda **kwargs: "signed-tool", + ) + + response = pipeline._message_end_to_stream_response() + files = response.files + + assert files + assert len(files) == 3 + + def test_process_stream_response_handles_annotation_and_error(self, monkeypatch): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=True, + ) + + agent_chunk = LLMResultChunk( + model="mock", + prompt_messages=[], + delta=LLMResultChunkDelta( + index=0, + message=AssistantPromptMessage(content="agent"), + ), + ) + + events = [ + SimpleNamespace(event=QueueAnnotationReplyEvent(message_annotation_id="ann")), + SimpleNamespace(event=QueueAgentThoughtEvent(agent_thought_id="thought")), + SimpleNamespace(event=QueueMessageFileEvent(message_file_id="file")), + SimpleNamespace(event=QueueAgentMessageEvent(chunk=agent_chunk)), + SimpleNamespace(event=QueueErrorEvent(error=ValueError("boom"))), + ] + + pipeline.queue_manager.listen = lambda: iter(events) + pipeline._message_cycle_manager.handle_annotation_reply = lambda event: SimpleNamespace(content="annotated") + pipeline._agent_thought_to_stream_response = lambda event: "thought" + pipeline._message_cycle_manager.message_file_to_stream_response = lambda event: "file" + pipeline._agent_message_to_stream_response = lambda **kwargs: "agent" + pipeline.handle_error = lambda **kwargs: ValueError("boom") + pipeline.error_to_stream_response = lambda err: err + + class _Session: + def __init__(self, *args, **kwargs): + pass + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + def commit(self): + return None + + monkeypatch.setattr( + "core.app.task_pipeline.easy_ui_based_generate_task_pipeline.Session", + _Session, + ) + monkeypatch.setattr( + "core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db", + SimpleNamespace(engine=object()), + ) + + responses = list(pipeline._process_stream_response(publisher=None)) + + assert "thought" in responses + assert "file" in responses + assert "agent" in responses + assert isinstance(responses[-1], ValueError) + assert pipeline._task_state.llm_result.message.content == "annotatedagent" + + def test_agent_thought_to_stream_response_returns_payload(self, monkeypatch): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=True, + ) + + agent_thought = SimpleNamespace( + id="thought", + position=1, + thought="t", + observation="o", + tool="tool", + tool_labels={}, + tool_input="input", + files=[], + ) + + class _Session: + def __init__(self, *args, **kwargs): + pass + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + def query(self, *args, **kwargs): + return self + + def where(self, *args, **kwargs): + return self + + def first(self): + return agent_thought + + monkeypatch.setattr( + "core.app.task_pipeline.easy_ui_based_generate_task_pipeline.Session", + _Session, + ) + monkeypatch.setattr( + "core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db", + SimpleNamespace(engine=object()), + ) + + response = pipeline._agent_thought_to_stream_response(QueueAgentThoughtEvent(agent_thought_id="thought")) + + assert response is not None + assert response.id == "thought" + + def test_process_routes_to_stream_and_starts_conversation_name_generation(self): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=True, + ) + pipeline._message_cycle_manager.generate_conversation_name = Mock(return_value=object()) + pipeline._wrapper_process_stream_response = lambda trace_manager: iter(["payload"]) + pipeline._to_stream_response = lambda generator: "streamed" + + result = pipeline.process() + + assert result == "streamed" + pipeline._message_cycle_manager.generate_conversation_name.assert_called_once_with( + conversation_id="conv", query="hello" + ) + + def test_process_routes_to_blocking_for_completion_mode(self): + conversation = SimpleNamespace(id="conv", mode=AppMode.COMPLETION) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(CompletionAppGenerateEntity, AppMode.COMPLETION), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=False, + ) + pipeline._message_cycle_manager.generate_conversation_name = Mock() + pipeline._wrapper_process_stream_response = lambda trace_manager: iter(["payload"]) + pipeline._to_blocking_response = lambda generator: "blocking" + + result = pipeline.process() + + assert result == "blocking" + pipeline._message_cycle_manager.generate_conversation_name.assert_not_called() + + def test_to_blocking_response_raises_error_stream_exception(self): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=False, + ) + + def _gen(): + yield ErrorStreamResponse(task_id="task", err=ValueError("stream error")) + + with pytest.raises(ValueError, match="stream error"): + pipeline._to_blocking_response(_gen()) + + def test_to_blocking_response_raises_when_generator_ends_without_message_end(self): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=False, + ) + + def _gen(): + yield PingStreamResponse(task_id="task") + + with pytest.raises(RuntimeError, match="queue listening stopped unexpectedly"): + pipeline._to_blocking_response(_gen()) + + def test_to_stream_response_wraps_completion_stream_events(self): + conversation = SimpleNamespace(id="conv", mode=AppMode.COMPLETION) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(CompletionAppGenerateEntity, AppMode.COMPLETION), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=True, + ) + + def _gen(): + yield PingStreamResponse(task_id="task") + + response = list(pipeline._to_stream_response(_gen()))[0] + + assert isinstance(response, CompletionAppStreamResponse) + assert response.message_id == "msg" + + def test_to_stream_response_wraps_chat_stream_events(self): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=True, + ) + + def _gen(): + yield PingStreamResponse(task_id="task") + + response = list(pipeline._to_stream_response(_gen()))[0] + + assert isinstance(response, ChatbotAppStreamResponse) + assert response.conversation_id == "conv" + + def test_listen_audio_msg_returns_audio_response_for_non_finish_audio(self): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=True, + ) + publisher = SimpleNamespace(check_and_get_audio=lambda: AudioTrunk("responding", "abc")) + + response = pipeline._listen_audio_msg(publisher=publisher, task_id="task") + + assert isinstance(response, MessageAudioStreamResponse) + assert response.audio == "abc" + + def test_listen_audio_msg_returns_none_for_finish_audio(self): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=True, + ) + publisher = SimpleNamespace(check_and_get_audio=lambda: AudioTrunk("finish", "abc")) + + assert pipeline._listen_audio_msg(publisher=publisher, task_id="task") is None + + def test_wrapper_process_stream_response_without_tts_publisher(self): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=True, + ) + pipeline._process_stream_response = lambda publisher, trace_manager: iter(["payload"]) + + responses = list(pipeline._wrapper_process_stream_response()) + + assert responses == ["payload"] + + def test_wrapper_process_stream_response_with_tts_publisher(self, monkeypatch): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + entity = _make_entity(ChatAppGenerateEntity, AppMode.CHAT) + entity.app_config.app_model_config_dict = { + "text_to_speech": {"autoPlay": "enabled", "enabled": True, "voice": "v", "language": "en"} + } + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=entity, + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=True, + ) + + class _Publisher: + def check_and_get_audio(self): + return AudioTrunk("finish", "") + + inline_audio = MessageAudioStreamResponse(task_id="task", audio="inline") + audio_calls = iter([inline_audio, None]) + pipeline._listen_audio_msg = lambda publisher, task_id: next(audio_calls) + pipeline._process_stream_response = lambda publisher, trace_manager: iter(["payload"]) + monkeypatch.setattr( + "core.app.task_pipeline.easy_ui_based_generate_task_pipeline.AppGeneratorTTSPublisher", + lambda tenant_id, voice, language: _Publisher(), + ) + + responses = list(pipeline._wrapper_process_stream_response()) + + assert responses[0] == inline_audio + assert responses[1] == "payload" + assert isinstance(responses[-1], MessageAudioEndStreamResponse) + + def test_wrapper_process_stream_response_timeout_yields_audio_chunk(self, monkeypatch): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + entity = _make_entity(ChatAppGenerateEntity, AppMode.CHAT) + entity.app_config.app_model_config_dict = { + "text_to_speech": {"autoPlay": "enabled", "enabled": True, "voice": "v", "language": "en"} + } + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=entity, + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=True, + ) + + class _Publisher: + def __init__(self): + self._events = iter([None, AudioTrunk("responding", "later"), AudioTrunk("finish", "")]) + + def check_and_get_audio(self): + return next(self._events) + + clock = {"value": 0.0} + + def _fake_time(): + clock["value"] += 0.1 + return clock["value"] + + pipeline._process_stream_response = lambda publisher, trace_manager: iter([]) + monkeypatch.setattr( + "core.app.task_pipeline.easy_ui_based_generate_task_pipeline.AppGeneratorTTSPublisher", + lambda tenant_id, voice, language: _Publisher(), + ) + monkeypatch.setattr("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.time.time", _fake_time) + monkeypatch.setattr("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.time.sleep", lambda _: None) + + responses = list(pipeline._wrapper_process_stream_response()) + + assert any(isinstance(item, MessageAudioStreamResponse) for item in responses) + assert isinstance(responses[-1], MessageAudioEndStreamResponse) + + def test_process_stream_response_handles_stop_event_and_output_replacement(self, monkeypatch): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=True, + ) + pipeline._task_state.llm_result.message.content = "raw answer" + pipeline.queue_manager.listen = lambda: iter( + [SimpleNamespace(event=QueueStopEvent(stopped_by=QueueStopEvent.StopBy.USER_MANUAL))] + ) + pipeline._handle_stop = Mock() + pipeline.handle_output_moderation_when_task_finished = lambda answer: "moderated answer" + pipeline._message_cycle_manager.message_replace_to_stream_response = lambda answer: f"replace:{answer}" + pipeline._save_message = lambda **kwargs: None + pipeline._message_end_to_stream_response = lambda: "end" + + class _Session: + def __init__(self, *args, **kwargs): + pass + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + def commit(self): + return None + + monkeypatch.setattr("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.Session", _Session) + monkeypatch.setattr( + "core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db", + SimpleNamespace(engine=object()), + ) + + responses = list(pipeline._process_stream_response(publisher=None)) + + assert responses == ["replace:moderated answer", "end"] + pipeline._handle_stop.assert_called_once() + + def test_process_stream_response_handles_retriever_unknown_and_empty_chunk(self): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=True, + ) + retriever_event = QueueRetrieverResourcesEvent(retriever_resources=[]) + chunk = LLMResultChunk( + model="mock", + prompt_messages=[], + delta=LLMResultChunkDelta(index=0, message=AssistantPromptMessage(content=None)), + ) + handled = {"retriever": 0} + + def _handle_retriever_resources(event): + handled["retriever"] += 1 + + pipeline._message_cycle_manager.handle_retriever_resources = _handle_retriever_resources + pipeline.queue_manager.listen = lambda: iter( + [ + SimpleNamespace(event=retriever_event), + SimpleNamespace(event=SimpleNamespace()), + SimpleNamespace(event=QueueLLMChunkEvent(chunk=chunk)), + ] + ) + + responses = list(pipeline._process_stream_response(publisher=None)) + + assert responses == [] + assert handled["retriever"] == 1 + + def test_process_stream_response_skips_when_output_moderation_directs_chunk(self): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=True, + ) + chunk = LLMResultChunk( + model="mock", + prompt_messages=[], + delta=LLMResultChunkDelta(index=0, message=AssistantPromptMessage(content="x")), + ) + pipeline._handle_output_moderation_chunk = lambda text: True + pipeline.queue_manager.listen = lambda: iter([SimpleNamespace(event=QueueLLMChunkEvent(chunk=chunk))]) + + responses = list(pipeline._process_stream_response(publisher=None)) + + assert responses == [] + + def test_process_stream_response_ignores_unsupported_chunk_content_types(self): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=True, + ) + chunk = SimpleNamespace( + prompt_messages=[], + delta=SimpleNamespace(message=SimpleNamespace(content=[object(), "ok"])), + ) + pipeline._message_cycle_manager.get_message_event_type = lambda message_id: None + pipeline._message_cycle_manager.message_to_stream_response = lambda **kwargs: kwargs["answer"] + pipeline.queue_manager.listen = lambda: iter( + [SimpleNamespace(event=QueueLLMChunkEvent.model_construct(chunk=chunk))] + ) + + responses = list(pipeline._process_stream_response(publisher=None)) + + assert responses == ["ok"] + + def test_process_stream_response_reaches_post_loop_branch_with_thread_reference(self): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=True, + ) + pipeline._conversation_name_generate_thread = object() + pipeline.queue_manager.listen = lambda: iter([]) + + assert list(pipeline._process_stream_response(publisher=None)) == [] + + def test_save_message_persists_fields_and_emits_trace(self, monkeypatch): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=False, + ) + pipeline.start_at = 10.0 + pipeline._model_config = SimpleNamespace(mode="chat") + pipeline._task_state.llm_result.prompt_messages = [AssistantPromptMessage(content="prompt")] + pipeline._task_state.llm_result.message = AssistantPromptMessage(content=" {{name}} hello ") + pipeline._task_state.llm_result.usage = LLMUsage.from_metadata( + {"prompt_tokens": 3, "completion_tokens": 5, "total_price": "1.23"} + ) + + message_obj = SimpleNamespace(id="msg") + conversation_obj = SimpleNamespace(id="conv") + session = Mock() + session.scalar.side_effect = [message_obj, conversation_obj] + trace_manager = SimpleNamespace(add_trace_task=Mock()) + sent_payloads: list[tuple[tuple[object, ...], dict[str, object]]] = [] + + monkeypatch.setattr( + "core.app.task_pipeline.easy_ui_based_generate_task_pipeline.PromptMessageUtil.prompt_messages_to_prompt_for_saving", + lambda mode, prompt_messages: "serialized-prompt", + ) + monkeypatch.setattr( + "core.app.task_pipeline.easy_ui_based_generate_task_pipeline.PromptTemplateParser.remove_template_variables", + lambda text: text.replace("{{name}}", "").strip(), + ) + monkeypatch.setattr( + "core.app.task_pipeline.easy_ui_based_generate_task_pipeline.naive_utc_now", + lambda: datetime(2024, 1, 1, tzinfo=UTC), + ) + monkeypatch.setattr( + "core.app.task_pipeline.easy_ui_based_generate_task_pipeline.time.perf_counter", lambda: 15.0 + ) + monkeypatch.setattr( + "core.app.task_pipeline.easy_ui_based_generate_task_pipeline.message_was_created.send", + lambda *args, **kwargs: sent_payloads.append((args, kwargs)), + ) + + pipeline._save_message(session=session, trace_manager=trace_manager) + + assert message_obj.message == "serialized-prompt" + assert message_obj.answer == "hello" + assert message_obj.provider_response_latency == 5.0 + assert trace_manager.add_trace_task.called + assert len(sent_payloads) == 1 + + def test_save_message_raises_when_message_not_found(self): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=False, + ) + session = Mock() + session.scalar.return_value = None + + with pytest.raises(ValueError, match="message msg not found"): + pipeline._save_message(session=session) + + def test_save_message_raises_when_conversation_not_found(self): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=False, + ) + session = Mock() + session.scalar.side_effect = [SimpleNamespace(id="msg"), None] + + with pytest.raises(ValueError, match="Conversation conv not found"): + pipeline._save_message(session=session) + + def test_message_end_to_stream_response_includes_usage_metadata(self, monkeypatch): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=False, + ) + pipeline._task_state.llm_result.usage = LLMUsage.from_metadata({"prompt_tokens": 1, "completion_tokens": 2}) + + class _Result: + def all(self): + return [] + + class _Session: + def __init__(self, *args, **kwargs): + pass + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + def scalars(self, *args, **kwargs): + return _Result() + + monkeypatch.setattr("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.Session", _Session) + monkeypatch.setattr( + "core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db", + SimpleNamespace(engine=object()), + ) + + response = pipeline._message_end_to_stream_response() + + assert response.id == "msg" + assert response.metadata["usage"]["prompt_tokens"] == 1 + + def test_record_files_returns_none_when_message_has_no_files(self, monkeypatch): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=False, + ) + + class _Result: + def all(self): + return [] + + class _Session: + def __init__(self, *args, **kwargs): + pass + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + def scalars(self, *args, **kwargs): + return _Result() + + monkeypatch.setattr("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.Session", _Session) + monkeypatch.setattr( + "core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db", + SimpleNamespace(engine=object()), + ) + + response = pipeline._message_end_to_stream_response() + + assert response.files is None + + def test_record_files_handles_local_fallback_and_tool_url_variants(self, monkeypatch): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=False, + ) + message_files = [ + SimpleNamespace( + id="mf-local-fallback", + message_id="msg", + transfer_method=FileTransferMethod.LOCAL_FILE, + url="", + upload_file_id="upload-missing", + type="file", + ), + SimpleNamespace( + id="mf-tool-http", + message_id="msg", + transfer_method=FileTransferMethod.TOOL_FILE, + url="http://cdn.example.com/file.txt?x=1", + upload_file_id=None, + type="file", + ), + SimpleNamespace( + id="mf-tool-noext", + message_id="msg", + transfer_method=FileTransferMethod.TOOL_FILE, + url="tool/path/toolid", + upload_file_id=None, + type="file", + ), + ] + + class _Result: + def __init__(self, items): + self._items = items + + def all(self): + return self._items + + class _Session: + def __init__(self, *args, **kwargs): + self.calls = 0 + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + def scalars(self, *args, **kwargs): + self.calls += 1 + return _Result(message_files if self.calls == 1 else []) + + monkeypatch.setattr("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.Session", _Session) + monkeypatch.setattr( + "core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db", + SimpleNamespace(engine=object()), + ) + monkeypatch.setattr( + "core.app.task_pipeline.message_file_utils.file_helpers.get_signed_file_url", + lambda **kwargs: "local-fallback-signed", + ) + monkeypatch.setattr( + "core.app.task_pipeline.message_file_utils.sign_tool_file", + lambda **kwargs: "tool-signed", + ) + + response = pipeline._message_end_to_stream_response() + files = response.files + + assert files is not None + assert files[0]["url"] == "local-fallback-signed" + assert files[1]["filename"] == "file.txt" + assert files[2]["extension"] == ".bin" + + def test_agent_message_to_stream_response_builds_payload(self): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=True, + ) + + response = pipeline._agent_message_to_stream_response(answer="hello", message_id="msg") + + assert response.id == "msg" + assert response.answer == "hello" + + def test_agent_thought_to_stream_response_returns_none_when_not_found(self, monkeypatch): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=True, + ) + + class _Session: + def __init__(self, *args, **kwargs): + pass + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + def query(self, *args, **kwargs): + return self + + def where(self, *args, **kwargs): + return self + + def first(self): + return None + + monkeypatch.setattr("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.Session", _Session) + monkeypatch.setattr( + "core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db", + SimpleNamespace(engine=object()), + ) + + response = pipeline._agent_thought_to_stream_response(QueueAgentThoughtEvent(agent_thought_id="missing")) + + assert response is None + + def test_handle_output_moderation_chunk_appends_token_when_not_directing(self): + conversation = SimpleNamespace(id="conv", mode=AppMode.CHAT) + message = SimpleNamespace(id="msg", created_at=datetime.now(UTC)) + pipeline = EasyUIBasedGenerateTaskPipeline( + application_generate_entity=_make_entity(ChatAppGenerateEntity, AppMode.CHAT), + queue_manager=SimpleNamespace(), + conversation=conversation, + message=message, + stream=True, + ) + appended_tokens: list[str] = [] + + class _Moderation: + def should_direct_output(self): + return False + + def append_new_token(self, text): + appended_tokens.append(text) + + pipeline.output_moderation_handler = _Moderation() + + result = pipeline._handle_output_moderation_chunk("next-token") + + assert result is False + assert appended_tokens == ["next-token"] diff --git a/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_message_end_files.py b/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_message_end_files.py index 582990c88a..abfbcdb941 100644 --- a/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_message_end_files.py +++ b/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_message_end_files.py @@ -21,7 +21,7 @@ from sqlalchemy.orm import Session from core.app.entities.task_entities import MessageEndStreamResponse from core.app.task_pipeline.easy_ui_based_generate_task_pipeline import EasyUIBasedGenerateTaskPipeline -from dify_graph.file.enums import FileTransferMethod +from graphon.file.enums import FileTransferMethod, FileType from models.model import MessageFile, UploadFile @@ -51,7 +51,7 @@ class TestMessageEndStreamResponseFiles: message_file.transfer_method = FileTransferMethod.LOCAL_FILE message_file.upload_file_id = str(uuid.uuid4()) message_file.url = None - message_file.type = "image" + message_file.type = FileType.IMAGE return message_file @pytest.fixture @@ -63,7 +63,7 @@ class TestMessageEndStreamResponseFiles: message_file.transfer_method = FileTransferMethod.REMOTE_URL message_file.upload_file_id = None message_file.url = "https://example.com/image.jpg" - message_file.type = "image" + message_file.type = FileType.IMAGE return message_file @pytest.fixture @@ -75,7 +75,7 @@ class TestMessageEndStreamResponseFiles: message_file.transfer_method = FileTransferMethod.TOOL_FILE message_file.upload_file_id = None message_file.url = "tool_file_123.png" - message_file.type = "image" + message_file.type = FileType.IMAGE return message_file @pytest.fixture diff --git a/api/tests/unit_tests/core/app/task_pipeline/test_exc.py b/api/tests/unit_tests/core/app/task_pipeline/test_exc.py new file mode 100644 index 0000000000..9ea7e96e73 --- /dev/null +++ b/api/tests/unit_tests/core/app/task_pipeline/test_exc.py @@ -0,0 +1,11 @@ +from core.app.task_pipeline.exc import RecordNotFoundError, WorkflowRunNotFoundError + + +class TestTaskPipelineExceptions: + def test_record_not_found_error_message(self): + err = RecordNotFoundError("Message", "msg-1") + assert str(err) == "Message with id msg-1 not found" + + def test_workflow_run_not_found_error_message(self): + err = WorkflowRunNotFoundError("run-1") + assert str(err) == "WorkflowRun with id run-1 not found" diff --git a/api/tests/unit_tests/core/app/task_pipeline/test_message_cycle_manager_optimization.py b/api/tests/unit_tests/core/app/task_pipeline/test_message_cycle_manager_optimization.py index c0c636715d..07ee75ed35 100644 --- a/api/tests/unit_tests/core/app/task_pipeline/test_message_cycle_manager_optimization.py +++ b/api/tests/unit_tests/core/app/task_pipeline/test_message_cycle_manager_optimization.py @@ -1,12 +1,16 @@ """Unit tests for the message cycle manager optimization.""" +from types import SimpleNamespace from unittest.mock import Mock, patch import pytest -from flask import current_app +from flask import Flask, current_app -from core.app.entities.task_entities import MessageStreamResponse, StreamEvent +from core.app.entities.queue_entities import QueueAnnotationReplyEvent, QueueRetrieverResourcesEvent +from core.app.entities.task_entities import MessageStreamResponse, StreamEvent, TaskStateMetadata from core.app.task_pipeline.message_cycle_manager import MessageCycleManager +from core.rag.entities.citation_metadata import RetrievalSourceMetadata +from models.model import AppMode class TestMessageCycleManagerOptimization: @@ -90,6 +94,16 @@ class TestMessageCycleManagerOptimization: assert result == StreamEvent.MESSAGE mock_session.scalar.assert_called_once() + def test_get_message_event_type_uses_cache_without_query(self, message_cycle_manager): + """Return MESSAGE_FILE directly from in-memory cache without opening a DB session.""" + message_cycle_manager._message_has_file.add("cached-message") + + with patch("core.app.task_pipeline.message_cycle_manager.session_factory") as mock_session_factory: + result = message_cycle_manager.get_message_event_type("cached-message") + + assert result == StreamEvent.MESSAGE_FILE + mock_session_factory.create_session.assert_not_called() + def test_message_to_stream_response_with_precomputed_event_type(self, message_cycle_manager): """MessageCycleManager.message_to_stream_response expects a valid event_type; callers should precompute it.""" with patch("core.app.task_pipeline.message_cycle_manager.session_factory") as mock_session_factory: @@ -180,3 +194,390 @@ class TestMessageCycleManagerOptimization: assert chunk2_response.event == StreamEvent.MESSAGE assert chunk1_response.answer == "Chunk 1" assert chunk2_response.answer == "Chunk 2" + + def test_generate_conversation_name_returns_none_for_completion(self, message_cycle_manager): + """Return None when completion entities are used for conversation naming. + + Args: message_cycle_manager with DummyCompletion injected as CompletionAppGenerateEntity. + Returns: None, indicating no name generation for completion apps. + Side effects: None expected. + """ + + class DummyCompletion: + pass + + with patch("core.app.task_pipeline.message_cycle_manager.CompletionAppGenerateEntity", DummyCompletion): + message_cycle_manager._application_generate_entity = DummyCompletion() + result = message_cycle_manager.generate_conversation_name(conversation_id="c1", query="hi") + + assert result is None + + def test_generate_conversation_name_starts_thread_and_flips_first_message_flag(self, message_cycle_manager): + """Spawn background generation thread for the first chat message.""" + message_cycle_manager._application_generate_entity.is_new_conversation = True + message_cycle_manager._application_generate_entity.extras = {"auto_generate_conversation_name": True} + flask_app = object() + + class DummyTimer: + def __init__(self, interval, function, args=None, kwargs=None): + self.interval = interval + self.function = function + self.args = args or [] + self.kwargs = kwargs + self.daemon = False + self.started = False + + def start(self): + self.started = True + + with ( + patch( + "core.app.task_pipeline.message_cycle_manager.current_app", + new=SimpleNamespace(_get_current_object=lambda: flask_app), + ), + patch("core.app.task_pipeline.message_cycle_manager.Timer", DummyTimer), + ): + thread = message_cycle_manager.generate_conversation_name(conversation_id="conv-1", query="hello") + + assert isinstance(thread, DummyTimer) + assert thread.interval == 1 + assert thread.function == message_cycle_manager._generate_conversation_name_worker + assert thread.started is True + assert thread.daemon is True + assert thread.kwargs["flask_app"] is flask_app + assert thread.kwargs["conversation_id"] == "conv-1" + assert thread.kwargs["query"] == "hello" + assert message_cycle_manager._application_generate_entity.is_new_conversation is False + + def test_generate_conversation_name_skips_thread_when_auto_generate_disabled(self, message_cycle_manager): + """Skip thread creation when auto naming is disabled but still mark conversation as not new.""" + message_cycle_manager._application_generate_entity.is_new_conversation = True + message_cycle_manager._application_generate_entity.extras = {"auto_generate_conversation_name": False} + + with patch("core.app.task_pipeline.message_cycle_manager.Timer") as mock_timer: + result = message_cycle_manager.generate_conversation_name(conversation_id="conv-2", query="hello") + + assert result is None + assert message_cycle_manager._application_generate_entity.is_new_conversation is False + mock_timer.assert_not_called() + + def test_generate_conversation_name_worker_returns_when_conversation_missing(self, message_cycle_manager): + """Return early when the conversation cannot be found.""" + flask_app = Flask(__name__) + db_session = Mock() + db_session.scalar.return_value = None + + with patch("core.app.task_pipeline.message_cycle_manager.db") as mock_db: + mock_db.session = db_session + message_cycle_manager._generate_conversation_name_worker(flask_app, "conv-missing", "hello") + + db_session.commit.assert_not_called() + db_session.close.assert_not_called() + + def test_generate_conversation_name_worker_returns_when_app_missing(self, message_cycle_manager): + """Return early when non-completion conversation has no app relation.""" + flask_app = Flask(__name__) + conversation = SimpleNamespace(mode=AppMode.CHAT, app=None, app_id="app-id") + db_session = Mock() + db_session.scalar.return_value = conversation + + with patch("core.app.task_pipeline.message_cycle_manager.db") as mock_db: + mock_db.session = db_session + message_cycle_manager._generate_conversation_name_worker(flask_app, "conv-1", "hello") + + db_session.commit.assert_not_called() + db_session.close.assert_not_called() + + def test_generate_conversation_name_worker_uses_cached_name(self, message_cycle_manager): + """Use cached conversation name when present and avoid LLM call.""" + flask_app = Flask(__name__) + conversation = SimpleNamespace( + mode=AppMode.CHAT, + app=SimpleNamespace(tenant_id="tenant-1"), + app_id="app-id", + name="", + ) + db_session = Mock() + db_session.scalar.return_value = conversation + + with ( + patch("core.app.task_pipeline.message_cycle_manager.db") as mock_db, + patch("core.app.task_pipeline.message_cycle_manager.redis_client") as mock_redis, + patch("core.app.task_pipeline.message_cycle_manager.LLMGenerator") as mock_llm_generator, + ): + mock_db.session = db_session + mock_redis.get.return_value = b"cached-title" + + message_cycle_manager._generate_conversation_name_worker(flask_app, "conv-1", "hello") + + assert conversation.name == "cached-title" + db_session.commit.assert_called_once() + db_session.close.assert_called_once() + mock_llm_generator.generate_conversation_name.assert_not_called() + mock_redis.setex.assert_not_called() + + def test_generate_conversation_name_worker_generates_and_caches_name(self, message_cycle_manager): + """Generate conversation name and write it to redis cache on cache miss.""" + flask_app = Flask(__name__) + conversation = SimpleNamespace( + mode=AppMode.CHAT, + app=SimpleNamespace(tenant_id="tenant-1"), + app_id="app-id", + name="", + ) + db_session = Mock() + db_session.scalar.return_value = conversation + + with ( + patch("core.app.task_pipeline.message_cycle_manager.db") as mock_db, + patch("core.app.task_pipeline.message_cycle_manager.redis_client") as mock_redis, + patch("core.app.task_pipeline.message_cycle_manager.LLMGenerator") as mock_llm_generator, + ): + mock_db.session = db_session + mock_redis.get.return_value = None + mock_llm_generator.generate_conversation_name.return_value = "generated-title" + + message_cycle_manager._generate_conversation_name_worker(flask_app, "conv-1", "hello") + + assert conversation.name == "generated-title" + db_session.commit.assert_called_once() + db_session.close.assert_called_once() + mock_redis.setex.assert_called_once() + + def test_generate_conversation_name_worker_falls_back_when_generation_fails(self, message_cycle_manager): + """Fallback to truncated query when LLM generation fails.""" + flask_app = Flask(__name__) + conversation = SimpleNamespace( + mode=AppMode.CHAT, + app=SimpleNamespace(tenant_id="tenant-1"), + app_id="app-id", + name="", + ) + db_session = Mock() + db_session.scalar.return_value = conversation + long_query = "q" * 60 + + with ( + patch("core.app.task_pipeline.message_cycle_manager.db") as mock_db, + patch("core.app.task_pipeline.message_cycle_manager.redis_client") as mock_redis, + patch("core.app.task_pipeline.message_cycle_manager.LLMGenerator") as mock_llm_generator, + patch("core.app.task_pipeline.message_cycle_manager.dify_config") as mock_dify_config, + patch("core.app.task_pipeline.message_cycle_manager.logger") as mock_logger, + ): + mock_db.session = db_session + mock_redis.get.return_value = None + mock_llm_generator.generate_conversation_name.side_effect = RuntimeError("generation failed") + mock_dify_config.DEBUG = True + + message_cycle_manager._generate_conversation_name_worker(flask_app, "conv-1", long_query) + + assert conversation.name == (long_query[:47] + "...") + db_session.commit.assert_called_once() + db_session.close.assert_called_once() + mock_logger.exception.assert_called_once() + + def test_handle_annotation_reply_sets_metadata(self, message_cycle_manager): + """Populate task metadata from annotation reply events. + + Args: message_cycle_manager with TaskStateMetadata and a mocked AppAnnotationService. + Returns: The fetched annotation object. + Side effects: Updates metadata.annotation_reply with id and account name. + """ + message_cycle_manager._task_state = SimpleNamespace(metadata=TaskStateMetadata()) + + annotation = SimpleNamespace( + id="ann-1", + account_id="acct-1", + account=SimpleNamespace(name="Alice"), + ) + + with patch("core.app.task_pipeline.message_cycle_manager.AppAnnotationService") as mock_service: + mock_service.get_annotation_by_id.return_value = annotation + + result = message_cycle_manager.handle_annotation_reply( + QueueAnnotationReplyEvent(message_annotation_id="ann-1") + ) + + assert result == annotation + assert message_cycle_manager._task_state.metadata.annotation_reply.id == "ann-1" + assert message_cycle_manager._task_state.metadata.annotation_reply.account.name == "Alice" + + def test_handle_annotation_reply_returns_none_when_missing(self, message_cycle_manager): + """Return None and keep metadata unchanged when annotation is not found.""" + message_cycle_manager._task_state = SimpleNamespace(metadata=TaskStateMetadata()) + + with patch("core.app.task_pipeline.message_cycle_manager.AppAnnotationService") as mock_service: + mock_service.get_annotation_by_id.return_value = None + + result = message_cycle_manager.handle_annotation_reply( + QueueAnnotationReplyEvent(message_annotation_id="missing") + ) + + assert result is None + assert message_cycle_manager._task_state.metadata.annotation_reply is None + + def test_handle_retriever_resources_merges_and_deduplicates(self, message_cycle_manager): + """Merge retriever resources, deduplicate, and preserve ordering positions. + + Args: message_cycle_manager with show_retrieve_source enabled and existing metadata. + Returns: None. + Side effects: Updates metadata.retriever_resources with unique items and positions. + """ + message_cycle_manager._application_generate_entity.app_config = SimpleNamespace( + additional_features=SimpleNamespace(show_retrieve_source=True) + ) + existing = RetrievalSourceMetadata(dataset_id="d1", document_id="doc1") + message_cycle_manager._task_state = SimpleNamespace(metadata=TaskStateMetadata(retriever_resources=[existing])) + + duplicate = RetrievalSourceMetadata(dataset_id="d1", document_id="doc1") + new_resource = RetrievalSourceMetadata(dataset_id="d2", document_id="doc2") + + event = QueueRetrieverResourcesEvent(retriever_resources=[duplicate, new_resource]) + message_cycle_manager.handle_retriever_resources(event) + + assert len(message_cycle_manager._task_state.metadata.retriever_resources) == 2 + assert message_cycle_manager._task_state.metadata.retriever_resources[0].position == 1 + assert message_cycle_manager._task_state.metadata.retriever_resources[1].position == 2 + + def test_message_file_to_stream_response_builds_signed_url(self, message_cycle_manager): + """Build a stream response with a signed tool file URL. + + Args: message_cycle_manager with mocked Session/db and sign_tool_file. + Returns: MessageStreamResponse with signed url and belongs_to normalized to user. + Side effects: Calls sign_tool_file for tool file ids. + """ + message_cycle_manager._application_generate_entity.task_id = "task-1" + + message_file = SimpleNamespace( + id="file-1", + type="image", + belongs_to=None, + url="tool://file.verylongextension", + message_id="msg-1", + ) + + session = Mock() + session.scalar.return_value = message_file + + with ( + patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_cls, + patch("core.app.task_pipeline.message_cycle_manager.sign_tool_file") as mock_sign, + patch("core.app.task_pipeline.message_cycle_manager.db") as mock_db, + ): + mock_db.engine = Mock() + mock_session_cls.return_value.__enter__.return_value = session + mock_sign.return_value = "signed-url" + + response = message_cycle_manager.message_file_to_stream_response(SimpleNamespace(message_file_id="file-1")) + + assert response.url == "signed-url" + assert response.belongs_to == "user" + mock_sign.assert_called_once_with(tool_file_id="file", extension=".bin") + + def test_handle_retriever_resources_requires_features(self, message_cycle_manager): + """Raise when retriever resources are handled without feature config. + + Args: message_cycle_manager with additional_features unset and empty metadata. + Raises: ValueError when show_retrieve_source configuration is missing. + """ + message_cycle_manager._application_generate_entity.app_config = SimpleNamespace(additional_features=None) + message_cycle_manager._task_state = SimpleNamespace(metadata=TaskStateMetadata()) + + with pytest.raises(ValueError): + message_cycle_manager.handle_retriever_resources(QueueRetrieverResourcesEvent(retriever_resources=[])) + + def test_handle_retriever_resources_skips_none_entries(self, message_cycle_manager): + """Ignore null resource entries while preserving valid resources.""" + message_cycle_manager._application_generate_entity.app_config = SimpleNamespace( + additional_features=SimpleNamespace(show_retrieve_source=True) + ) + message_cycle_manager._task_state = SimpleNamespace(metadata=TaskStateMetadata(retriever_resources=[])) + resource = RetrievalSourceMetadata(dataset_id="d1", document_id="doc1") + + message_cycle_manager.handle_retriever_resources(SimpleNamespace(retriever_resources=[None, resource])) + + assert len(message_cycle_manager._task_state.metadata.retriever_resources) == 1 + assert message_cycle_manager._task_state.metadata.retriever_resources[0].position == 1 + + def test_message_file_to_stream_response_uses_http_url_directly(self, message_cycle_manager): + """Use original URL when message file URL is already HTTP.""" + message_cycle_manager._application_generate_entity.task_id = "task-http" + message_file = SimpleNamespace( + id="file-http", + type="image", + belongs_to="assistant", + url="http://example.com/pic.png", + message_id="msg-http", + ) + + session = Mock() + session.scalar.return_value = message_file + + with ( + patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_cls, + patch("core.app.task_pipeline.message_cycle_manager.db") as mock_db, + ): + mock_db.engine = Mock() + mock_session_cls.return_value.__enter__.return_value = session + + response = message_cycle_manager.message_file_to_stream_response( + SimpleNamespace(message_file_id="file-http") + ) + + assert response is not None + assert response.url == "http://example.com/pic.png" + assert "msg-http" in message_cycle_manager._message_has_file + + def test_message_file_to_stream_response_defaults_extension_to_bin_without_dot(self, message_cycle_manager): + """Default tool file extension to .bin when URL has no extension part.""" + message_cycle_manager._application_generate_entity.task_id = "task-bin" + message_file = SimpleNamespace( + id="file-bin", + type="file", + belongs_to="assistant", + url="tool-file-id", + message_id="msg-bin", + ) + + session = Mock() + session.scalar.return_value = message_file + + with ( + patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_cls, + patch("core.app.task_pipeline.message_cycle_manager.sign_tool_file") as mock_sign, + patch("core.app.task_pipeline.message_cycle_manager.db") as mock_db, + ): + mock_db.engine = Mock() + mock_session_cls.return_value.__enter__.return_value = session + mock_sign.return_value = "signed-bin-url" + + response = message_cycle_manager.message_file_to_stream_response( + SimpleNamespace(message_file_id="file-bin") + ) + + assert response is not None + assert response.url == "signed-bin-url" + mock_sign.assert_called_once_with(tool_file_id="tool-file-id", extension=".bin") + + def test_message_file_to_stream_response_returns_none_when_file_missing(self, message_cycle_manager): + """Return None when message file lookup does not find a record.""" + session = Mock() + session.scalar.return_value = None + + with ( + patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_cls, + patch("core.app.task_pipeline.message_cycle_manager.db") as mock_db, + ): + mock_db.engine = Mock() + mock_session_cls.return_value.__enter__.return_value = session + + response = message_cycle_manager.message_file_to_stream_response(SimpleNamespace(message_file_id="missing")) + + assert response is None + + def test_message_replace_to_stream_response_returns_reason(self, message_cycle_manager): + """Include the provided replacement reason in the stream payload.""" + response = message_cycle_manager.message_replace_to_stream_response("replaced", reason="moderation") + + assert response.answer == "replaced" + assert response.reason == "moderation" diff --git a/api/tests/unit_tests/core/app/test_easy_ui_model_config_manager.py b/api/tests/unit_tests/core/app/test_easy_ui_model_config_manager.py new file mode 100644 index 0000000000..21c761c579 --- /dev/null +++ b/api/tests/unit_tests/core/app/test_easy_ui_model_config_manager.py @@ -0,0 +1,57 @@ +from types import SimpleNamespace +from unittest.mock import patch + +from core.app.app_config.easy_ui_based_app.model_config.manager import ModelConfigManager +from core.app.app_config.entities import ModelConfigEntity +from graphon.model_runtime.entities.model_entities import ModelPropertyKey +from models.provider_ids import ModelProviderID + + +def test_validate_and_set_defaults_reuses_single_model_assembly(): + provider_name = str(ModelProviderID("openai")) + provider_entity = SimpleNamespace(provider=provider_name) + model = SimpleNamespace(model="gpt-4o-mini", model_properties={ModelPropertyKey.MODE: "chat"}) + provider_configurations = SimpleNamespace(get_models=lambda **kwargs: [model]) + assembly = SimpleNamespace( + model_provider_factory=SimpleNamespace(get_providers=lambda: [provider_entity]), + provider_manager=SimpleNamespace(get_configurations=lambda tenant_id: provider_configurations), + ) + config = { + "model": { + "provider": "openai", + "name": "gpt-4o-mini", + "completion_params": {"stop": []}, + } + } + + with patch( + "core.app.app_config.easy_ui_based_app.model_config.manager.create_plugin_model_assembly", + return_value=assembly, + ) as mock_assembly: + result, keys = ModelConfigManager.validate_and_set_defaults("tenant-1", config) + + assert result["model"]["provider"] == provider_name + assert result["model"]["mode"] == "chat" + assert keys == ["model"] + mock_assembly.assert_called_once_with(tenant_id="tenant-1") + + +def test_convert_keeps_model_config_shape(): + config = { + "model": { + "provider": "openai", + "name": "gpt-4o-mini", + "mode": "chat", + "completion_params": {"temperature": 0.3, "stop": ["END"]}, + } + } + + result = ModelConfigManager.convert(config) + + assert result == ModelConfigEntity( + provider="openai", + model="gpt-4o-mini", + mode="chat", + parameters={"temperature": 0.3}, + stop=["END"], + ) diff --git a/api/tests/unit_tests/core/app/workflow/layers/test_persistence.py b/api/tests/unit_tests/core/app/workflow/layers/test_persistence.py index 0f8a846d11..5c50cb78da 100644 --- a/api/tests/unit_tests/core/app/workflow/layers/test_persistence.py +++ b/api/tests/unit_tests/core/app/workflow/layers/test_persistence.py @@ -8,8 +8,8 @@ from core.app.workflow.layers.persistence import ( WorkflowPersistenceLayer, _NodeRuntimeSnapshot, ) -from dify_graph.enums import WorkflowNodeExecutionStatus, WorkflowType -from dify_graph.node_events import NodeRunResult +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus, WorkflowType +from graphon.node_events import NodeRunResult def _build_layer() -> WorkflowPersistenceLayer: @@ -58,3 +58,42 @@ def test_update_node_execution_prefers_event_finished_at(monkeypatch: pytest.Mon assert node_execution.finished_at == event_finished_at assert node_execution.elapsed_time == 2.0 + + +def test_update_node_execution_projects_start_outputs() -> None: + layer = _build_layer() + node_execution = Mock() + node_execution.id = "node-exec-2" + node_execution.node_type = BuiltinNodeTypes.START + node_execution.created_at = datetime(2024, 1, 1, 0, 0, 0, tzinfo=UTC).replace(tzinfo=None) + node_execution.update_from_mapping = Mock() + + layer._node_snapshots[node_execution.id] = _NodeRuntimeSnapshot( + node_id="start", + title="Start", + predecessor_node_id=None, + iteration_id=None, + loop_id=None, + created_at=node_execution.created_at, + ) + + layer._update_node_execution( + node_execution, + NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + inputs={"question": "hello"}, + outputs={ + "question": "hello", + "sys.query": "hello", + "env.API_KEY": "secret", + }, + ), + WorkflowNodeExecutionStatus.SUCCEEDED, + ) + + node_execution.update_from_mapping.assert_called_once_with( + inputs={"question": "hello"}, + process_data={}, + outputs={"question": "hello"}, + metadata={}, + ) diff --git a/api/tests/unit_tests/core/app/workflow/test_file_runtime.py b/api/tests/unit_tests/core/app/workflow/test_file_runtime.py new file mode 100644 index 0000000000..cddd03f4b0 --- /dev/null +++ b/api/tests/unit_tests/core/app/workflow/test_file_runtime.py @@ -0,0 +1,370 @@ +from __future__ import annotations + +import base64 +import hashlib +import hmac +from types import SimpleNamespace +from unittest.mock import MagicMock, patch +from urllib.parse import parse_qs, urlparse + +import pytest + +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom +from core.app.file_access import DatabaseFileAccessController, FileAccessScope +from core.app.workflow import file_runtime +from core.app.workflow.file_runtime import DifyWorkflowFileRuntime, bind_dify_workflow_file_runtime +from core.workflow.file_reference import build_file_reference +from graphon.file import File, FileTransferMethod, FileType +from models import ToolFile, UploadFile + + +def _build_file( + *, + transfer_method: FileTransferMethod, + reference: str | None = None, + remote_url: str | None = None, + extension: str | None = None, +) -> File: + return File( + id="file-id", + type=FileType.IMAGE, + transfer_method=transfer_method, + reference=reference, + remote_url=remote_url, + filename="diagram.png", + extension=extension, + mime_type="image/png", + size=128, + ) + + +def _build_runtime() -> DifyWorkflowFileRuntime: + return DifyWorkflowFileRuntime(file_access_controller=DatabaseFileAccessController()) + + +def test_resolve_file_url_returns_remote_url() -> None: + runtime = _build_runtime() + file = _build_file( + transfer_method=FileTransferMethod.REMOTE_URL, + remote_url="https://example.com/diagram.png", + ) + + assert runtime.resolve_file_url(file=file) == "https://example.com/diagram.png" + + +def test_resolve_file_url_requires_file_reference() -> None: + runtime = _build_runtime() + file = SimpleNamespace(transfer_method=FileTransferMethod.LOCAL_FILE, reference=None) + + with pytest.raises(ValueError, match="Missing file reference"): + runtime.resolve_file_url(file=file) + + +def test_resolve_file_url_requires_extension_for_tool_files() -> None: + runtime = _build_runtime() + file = _build_file( + transfer_method=FileTransferMethod.TOOL_FILE, + reference=build_file_reference(record_id="tool-file-id"), + extension=None, + ) + + with pytest.raises(ValueError, match="Missing file extension"): + runtime.resolve_file_url(file=file) + + +def test_resolve_file_url_uses_tool_signatures_for_tool_and_datasource_files( + monkeypatch: pytest.MonkeyPatch, +) -> None: + sign_tool_file = MagicMock(return_value="https://signed.example.com/file") + monkeypatch.setattr(file_runtime, "sign_tool_file", sign_tool_file) + runtime = _build_runtime() + + tool_file = _build_file( + transfer_method=FileTransferMethod.TOOL_FILE, + reference=build_file_reference(record_id="tool-file-id"), + extension=".png", + ) + datasource_file = _build_file( + transfer_method=FileTransferMethod.DATASOURCE_FILE, + reference=build_file_reference(record_id="datasource-file-id"), + extension=".png", + ) + + assert runtime.resolve_file_url(file=tool_file) == "https://signed.example.com/file" + assert runtime.resolve_file_url(file=datasource_file) == "https://signed.example.com/file" + assert sign_tool_file.call_count == 2 + + +def test_resolve_upload_file_url_signs_internal_urls_and_supports_attachments( + monkeypatch: pytest.MonkeyPatch, +) -> None: + monkeypatch.setattr("core.app.workflow.file_runtime.time.time", lambda: 1700000000) + monkeypatch.setattr("core.app.workflow.file_runtime.os.urandom", lambda _: b"\x01" * 16) + monkeypatch.setattr("core.app.workflow.file_runtime.dify_config.SECRET_KEY", "unit-secret") + monkeypatch.setattr("core.app.workflow.file_runtime.dify_config.FILES_URL", "https://files.example.com") + monkeypatch.setattr( + "core.app.workflow.file_runtime.dify_config.INTERNAL_FILES_URL", + "https://internal.example.com", + ) + + runtime = _build_runtime() + url = runtime.resolve_upload_file_url( + upload_file_id="upload-file-id", + as_attachment=True, + for_external=False, + ) + parsed = urlparse(url) + query = parse_qs(parsed.query) + + assert parsed.netloc == "internal.example.com" + assert parsed.path == "/files/upload-file-id/file-preview" + assert query["as_attachment"] == ["true"] + assert query["timestamp"] == ["1700000000"] + + +def test_verify_preview_signature_validates_signature_and_expiration(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr("core.app.workflow.file_runtime.time.time", lambda: 1700000000) + monkeypatch.setattr("core.app.workflow.file_runtime.dify_config.SECRET_KEY", "unit-secret") + monkeypatch.setattr("core.app.workflow.file_runtime.dify_config.FILES_ACCESS_TIMEOUT", 60) + runtime = _build_runtime() + payload = "file-preview|upload-file-id|1700000000|nonce" + sign = base64.urlsafe_b64encode(hmac.new(b"unit-secret", payload.encode(), hashlib.sha256).digest()).decode() + + assert ( + runtime.verify_preview_signature( + preview_kind="file", + file_id="upload-file-id", + timestamp="1700000000", + nonce="nonce", + sign=sign, + ) + is True + ) + assert ( + runtime.verify_preview_signature( + preview_kind="file", + file_id="upload-file-id", + timestamp="1700000000", + nonce="nonce", + sign="bad-signature", + ) + is False + ) + + monkeypatch.setattr("core.app.workflow.file_runtime.time.time", lambda: 1700000100) + assert ( + runtime.verify_preview_signature( + preview_kind="file", + file_id="upload-file-id", + timestamp="1700000000", + nonce="nonce", + sign=sign, + ) + is False + ) + + +def test_load_file_bytes_returns_bytes_and_rejects_non_bytes(monkeypatch: pytest.MonkeyPatch) -> None: + runtime = _build_runtime() + file = _build_file( + transfer_method=FileTransferMethod.LOCAL_FILE, + reference=build_file_reference(record_id="upload-file-id"), + ) + session = MagicMock() + session.get.return_value = SimpleNamespace(key="canonical-storage-key") + + class _SessionContext: + def __enter__(self): + return session + + def __exit__(self, exc_type, exc, tb): + return False + + monkeypatch.setattr(file_runtime.session_factory, "create_session", lambda: _SessionContext()) + monkeypatch.setattr(file_runtime.storage, "load", lambda *args, **kwargs: b"image-bytes") + + assert runtime.load_file_bytes(file=file) == b"image-bytes" + session.get.assert_called_with(UploadFile, "upload-file-id") + + monkeypatch.setattr(file_runtime.storage, "load", lambda *args, **kwargs: "not-bytes") + with pytest.raises(ValueError, match="is not a bytes object"): + runtime.load_file_bytes(file=file) + + +def test_resolve_storage_key_ignores_encoded_reference_when_unscoped(monkeypatch: pytest.MonkeyPatch) -> None: + runtime = _build_runtime() + file = _build_file( + transfer_method=FileTransferMethod.LOCAL_FILE, + reference=build_file_reference(record_id="upload-file-id", storage_key="tampered-storage-key"), + ) + session = MagicMock() + session.get.return_value = SimpleNamespace(key="canonical-storage-key") + + class _SessionContext: + def __enter__(self): + return session + + def __exit__(self, exc_type, exc, tb): + return False + + monkeypatch.setattr(file_runtime.session_factory, "create_session", lambda: _SessionContext()) + + assert runtime._resolve_storage_key(file=file) == "canonical-storage-key" + session.get.assert_called_once_with(UploadFile, "upload-file-id") + + +def test_resolve_storage_key_uses_canonical_record_when_scope_is_bound(monkeypatch: pytest.MonkeyPatch) -> None: + controller = MagicMock() + controller.current_scope.return_value = FileAccessScope( + tenant_id="tenant-id", + user_id="end-user-id", + user_from=UserFrom.END_USER, + invoke_from=InvokeFrom.WEB_APP, + ) + controller.get_upload_file.return_value = SimpleNamespace(key="canonical-storage-key") + runtime = DifyWorkflowFileRuntime(file_access_controller=controller) + file = _build_file( + transfer_method=FileTransferMethod.LOCAL_FILE, + reference=build_file_reference(record_id="upload-file-id", storage_key="tampered-storage-key"), + ) + session = MagicMock() + + class _SessionContext: + def __enter__(self): + return session + + def __exit__(self, exc_type, exc, tb): + return False + + monkeypatch.setattr(file_runtime.session_factory, "create_session", lambda: _SessionContext()) + + assert runtime._resolve_storage_key(file=file) == "canonical-storage-key" + controller.get_upload_file.assert_called_once_with(session=session, file_id="upload-file-id") + + +def test_resolve_upload_file_url_rejects_unauthorized_scoped_access(monkeypatch: pytest.MonkeyPatch) -> None: + controller = MagicMock() + controller.current_scope.return_value = FileAccessScope( + tenant_id="tenant-id", + user_id="end-user-id", + user_from=UserFrom.END_USER, + invoke_from=InvokeFrom.WEB_APP, + ) + controller.get_upload_file.return_value = None + runtime = DifyWorkflowFileRuntime(file_access_controller=controller) + session = MagicMock() + + class _SessionContext: + def __enter__(self): + return session + + def __exit__(self, exc_type, exc, tb): + return False + + monkeypatch.setattr(file_runtime.session_factory, "create_session", lambda: _SessionContext()) + + with pytest.raises(ValueError, match="Upload file upload-file-id not found"): + runtime.resolve_upload_file_url(upload_file_id="upload-file-id") + + +@pytest.mark.parametrize( + ("transfer_method", "record_id", "expected_storage_key"), + [ + (FileTransferMethod.LOCAL_FILE, "upload-file-id", "upload-storage-key"), + (FileTransferMethod.DATASOURCE_FILE, "upload-file-id", "upload-storage-key"), + (FileTransferMethod.TOOL_FILE, "tool-file-id", "tool-storage-key"), + ], +) +def test_resolve_storage_key_loads_database_records( + monkeypatch: pytest.MonkeyPatch, + transfer_method: FileTransferMethod, + record_id: str, + expected_storage_key: str, +) -> None: + runtime = _build_runtime() + file = _build_file( + transfer_method=transfer_method, + reference=build_file_reference(record_id=record_id), + extension=".png", + ) + session = MagicMock() + + def get(model_class, value): + if transfer_method in {FileTransferMethod.LOCAL_FILE, FileTransferMethod.DATASOURCE_FILE}: + assert model_class is UploadFile + return SimpleNamespace(key="upload-storage-key") + assert model_class is ToolFile + return SimpleNamespace(file_key="tool-storage-key") + + session.get.side_effect = get + + class _SessionContext: + def __enter__(self): + return session + + def __exit__(self, exc_type, exc, tb): + return False + + monkeypatch.setattr(file_runtime.session_factory, "create_session", lambda: _SessionContext()) + + assert runtime._resolve_storage_key(file=file) == expected_storage_key + + +@pytest.mark.parametrize( + ("transfer_method", "expected_message"), + [ + (FileTransferMethod.LOCAL_FILE, "Upload file upload-file-id not found"), + (FileTransferMethod.TOOL_FILE, "Tool file tool-file-id not found"), + ], +) +def test_resolve_storage_key_raises_when_records_are_missing( + monkeypatch: pytest.MonkeyPatch, + transfer_method: FileTransferMethod, + expected_message: str, +) -> None: + runtime = _build_runtime() + record_id = "upload-file-id" if transfer_method == FileTransferMethod.LOCAL_FILE else "tool-file-id" + file = _build_file( + transfer_method=transfer_method, + reference=build_file_reference(record_id=record_id), + extension=".png", + ) + session = MagicMock() + session.get.return_value = None + + class _SessionContext: + def __enter__(self): + return session + + def __exit__(self, exc_type, exc, tb): + return False + + monkeypatch.setattr(file_runtime.session_factory, "create_session", lambda: _SessionContext()) + + with pytest.raises(ValueError, match=expected_message): + runtime._resolve_storage_key(file=file) + + +def test_runtime_helper_wrappers_delegate_to_config_and_io(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr("core.app.workflow.file_runtime.dify_config.MULTIMODAL_SEND_FORMAT", "url") + runtime = _build_runtime() + + assert runtime.multimodal_send_format == "url" + + with patch.object(file_runtime.ssrf_proxy, "get", return_value="response") as mock_get: + assert runtime.http_get("http://example", follow_redirects=False) == "response" + mock_get.assert_called_once_with("http://example", follow_redirects=False) + + with patch.object(file_runtime.storage, "load", return_value=b"data") as mock_load: + assert runtime.storage_load("path", stream=True) == b"data" + mock_load.assert_called_once_with("path", stream=True) + + +def test_bind_dify_workflow_file_runtime_registers_runtime(monkeypatch: pytest.MonkeyPatch) -> None: + set_runtime = MagicMock() + monkeypatch.setattr(file_runtime, "set_workflow_file_runtime", set_runtime) + + bind_dify_workflow_file_runtime() + + set_runtime.assert_called_once() + assert isinstance(set_runtime.call_args.args[0], DifyWorkflowFileRuntime) diff --git a/api/tests/unit_tests/core/app/workflow/test_node_factory.py b/api/tests/unit_tests/core/app/workflow/test_node_factory.py new file mode 100644 index 0000000000..c4bfb23272 --- /dev/null +++ b/api/tests/unit_tests/core/app/workflow/test_node_factory.py @@ -0,0 +1,161 @@ +from types import SimpleNamespace + +import pytest + +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom, build_dify_run_context +from core.workflow.node_factory import DifyNodeFactory +from graphon.enums import BuiltinNodeTypes + + +class DummyNode: + def __init__(self, *, id, config, graph_init_params, graph_runtime_state, **kwargs): + self.id = id + self.config = config + self.graph_init_params = graph_init_params + self.graph_runtime_state = graph_runtime_state + self.kwargs = kwargs + + +class DummyCodeNode(DummyNode): + @classmethod + def default_code_providers(cls): + return () + + +class DummyTemplateTransformNode(DummyNode): + pass + + +class DummyHttpRequestNode(DummyNode): + pass + + +class DummyKnowledgeRetrievalNode(DummyNode): + pass + + +class DummyDocumentExtractorNode(DummyNode): + pass + + +class TestDifyNodeFactory: + @staticmethod + def _stub_node_resolution(monkeypatch, node_class): + monkeypatch.setattr( + "core.workflow.node_factory.resolve_workflow_node_class", + lambda **_kwargs: node_class, + ) + + def _factory(self, monkeypatch): + monkeypatch.setattr("core.workflow.node_factory.dify_config.CODE_MAX_STRING_LENGTH", 10) + monkeypatch.setattr("core.workflow.node_factory.dify_config.CODE_MAX_NUMBER", 10) + monkeypatch.setattr("core.workflow.node_factory.dify_config.CODE_MIN_NUMBER", -10) + monkeypatch.setattr("core.workflow.node_factory.dify_config.CODE_MAX_PRECISION", 4) + monkeypatch.setattr("core.workflow.node_factory.dify_config.CODE_MAX_DEPTH", 2) + monkeypatch.setattr("core.workflow.node_factory.dify_config.CODE_MAX_NUMBER_ARRAY_LENGTH", 2) + monkeypatch.setattr("core.workflow.node_factory.dify_config.CODE_MAX_STRING_ARRAY_LENGTH", 2) + monkeypatch.setattr("core.workflow.node_factory.dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH", 2) + monkeypatch.setattr("core.workflow.node_factory.dify_config.TEMPLATE_TRANSFORM_MAX_LENGTH", 100) + monkeypatch.setattr("core.workflow.node_factory.dify_config.UNSTRUCTURED_API_URL", "http://u") + monkeypatch.setattr("core.workflow.node_factory.dify_config.UNSTRUCTURED_API_KEY", "key") + + run_context = build_dify_run_context( + tenant_id="tenant", + app_id="app", + user_id="user", + user_from=UserFrom.END_USER, + invoke_from=InvokeFrom.WEB_APP, + ) + + return DifyNodeFactory( + graph_init_params=SimpleNamespace(run_context=run_context), + graph_runtime_state=SimpleNamespace(), + ) + + def test_create_node_unknown_type(self, monkeypatch): + factory = self._factory(monkeypatch) + + with pytest.raises(ValueError): + factory.create_node({"id": "node-1", "data": {"type": "unknown"}}) + + def test_create_node_missing_mapping(self, monkeypatch): + factory = self._factory(monkeypatch) + monkeypatch.setattr("core.workflow.node_factory.get_node_type_classes_mapping", lambda: {}) + + with pytest.raises(ValueError): + factory.create_node({"id": "node-1", "data": {"type": BuiltinNodeTypes.START}}) + + def test_create_node_missing_latest_class(self, monkeypatch): + factory = self._factory(monkeypatch) + monkeypatch.setattr( + "core.workflow.node_factory.get_node_type_classes_mapping", + lambda: {BuiltinNodeTypes.START: {"1": None}}, + ) + monkeypatch.setattr("core.workflow.node_factory.LATEST_VERSION", "latest") + + with pytest.raises(ValueError): + factory.create_node({"id": "node-1", "data": {"type": BuiltinNodeTypes.START}}) + + def test_create_node_selects_versioned_class(self, monkeypatch): + factory = self._factory(monkeypatch) + selected_versions: list[tuple[str, str]] = [] + + class DummyNodeV2(DummyNode): + pass + + def _get_mapping(): + selected_versions.append(("snapshot", "called")) + return {BuiltinNodeTypes.START: {"1": DummyNode, "2": DummyNodeV2}} + + monkeypatch.setattr("core.workflow.node_factory.get_node_type_classes_mapping", _get_mapping) + + node = factory.create_node({"id": "node-1", "data": {"type": BuiltinNodeTypes.START, "version": "2"}}) + + assert isinstance(node, DummyNodeV2) + assert node.id == "node-1" + assert selected_versions == [("snapshot", "called")] + + def test_create_node_code_branch(self, monkeypatch): + factory = self._factory(monkeypatch) + self._stub_node_resolution(monkeypatch, DummyCodeNode) + + node = factory.create_node({"id": "node-1", "data": {"type": BuiltinNodeTypes.CODE}}) + + assert isinstance(node, DummyCodeNode) + assert node.id == "node-1" + + def test_create_node_template_transform_branch(self, monkeypatch): + factory = self._factory(monkeypatch) + self._stub_node_resolution(monkeypatch, DummyTemplateTransformNode) + + node = factory.create_node({"id": "node-1", "data": {"type": BuiltinNodeTypes.TEMPLATE_TRANSFORM}}) + + assert isinstance(node, DummyTemplateTransformNode) + assert "jinja2_template_renderer" in node.kwargs + + def test_create_node_http_request_branch(self, monkeypatch): + factory = self._factory(monkeypatch) + self._stub_node_resolution(monkeypatch, DummyHttpRequestNode) + + node = factory.create_node({"id": "node-1", "data": {"type": BuiltinNodeTypes.HTTP_REQUEST}}) + + assert isinstance(node, DummyHttpRequestNode) + assert "http_request_config" in node.kwargs + + def test_create_node_knowledge_retrieval_branch(self, monkeypatch): + factory = self._factory(monkeypatch) + self._stub_node_resolution(monkeypatch, DummyKnowledgeRetrievalNode) + + node = factory.create_node({"id": "node-1", "data": {"type": BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL}}) + + assert isinstance(node, DummyKnowledgeRetrievalNode) + assert node.kwargs == {} + + def test_create_node_document_extractor_branch(self, monkeypatch): + factory = self._factory(monkeypatch) + self._stub_node_resolution(monkeypatch, DummyDocumentExtractorNode) + + node = factory.create_node({"id": "node-1", "data": {"type": BuiltinNodeTypes.DOCUMENT_EXTRACTOR}}) + + assert isinstance(node, DummyDocumentExtractorNode) + assert "unstructured_api_config" in node.kwargs diff --git a/api/tests/unit_tests/core/app/workflow/test_observability_layer_extra.py b/api/tests/unit_tests/core/app/workflow/test_observability_layer_extra.py new file mode 100644 index 0000000000..82552470a9 --- /dev/null +++ b/api/tests/unit_tests/core/app/workflow/test_observability_layer_extra.py @@ -0,0 +1,209 @@ +from __future__ import annotations + +from types import SimpleNamespace + +from core.app.workflow.layers.observability import ObservabilityLayer +from graphon.enums import BuiltinNodeTypes + + +class TestObservabilityLayerExtras: + def test_init_tracer_enabled_sets_tracer(self, monkeypatch): + tracer = object() + monkeypatch.setattr("core.app.workflow.layers.observability.dify_config.ENABLE_OTEL", True) + monkeypatch.setattr("core.app.workflow.layers.observability.is_instrument_flag_enabled", lambda: False) + monkeypatch.setattr("core.app.workflow.layers.observability.get_tracer", lambda _: tracer) + + layer = ObservabilityLayer() + + assert layer._is_disabled is False + assert layer._tracer is tracer + + def test_init_tracer_disables_when_get_tracer_fails(self, monkeypatch, caplog): + monkeypatch.setattr("core.app.workflow.layers.observability.dify_config.ENABLE_OTEL", True) + monkeypatch.setattr("core.app.workflow.layers.observability.is_instrument_flag_enabled", lambda: False) + + def _raise(*_args, **_kwargs): + raise RuntimeError("tracer init failed") + + monkeypatch.setattr("core.app.workflow.layers.observability.get_tracer", _raise) + + layer = ObservabilityLayer() + + assert layer._is_disabled is True + assert layer._tracer is None + assert "Failed to get OpenTelemetry tracer" in caplog.text + + def test_init_tracer_disables_when_otel_disabled(self, monkeypatch): + monkeypatch.setattr("core.app.workflow.layers.observability.dify_config.ENABLE_OTEL", False) + monkeypatch.setattr("core.app.workflow.layers.observability.is_instrument_flag_enabled", lambda: False) + + layer = ObservabilityLayer() + + assert layer._is_disabled is True + + def test_get_parser_uses_registry_when_node_type_matches(self): + layer = ObservabilityLayer() + + parser = layer._get_parser(SimpleNamespace(node_type=BuiltinNodeTypes.TOOL)) + + assert parser is layer._parsers[BuiltinNodeTypes.TOOL] + + def test_get_parser_defaults_when_node_type_missing(self): + layer = ObservabilityLayer() + + parser = layer._get_parser(SimpleNamespace(node_type=None)) + + assert parser is layer._default_parser + + def test_on_graph_start_clears_contexts(self): + layer = ObservabilityLayer() + layer._node_contexts["exec"] = SimpleNamespace(span=object(), token="token") + + layer.on_graph_start() + + assert layer._node_contexts == {} + + def test_on_event_is_noop(self): + layer = ObservabilityLayer() + + layer.on_event(object()) + + def test_on_graph_end_clears_unfinished_contexts(self, caplog): + layer = ObservabilityLayer() + layer._node_contexts["exec"] = SimpleNamespace(span=object(), token="token") + + layer.on_graph_end(error=None) + + assert layer._node_contexts == {} + assert "node spans were not properly ended" in caplog.text + + def test_on_node_run_start_skips_without_execution_id(self): + layer = ObservabilityLayer() + layer._is_disabled = False + layer._tracer = None + + layer.on_node_run_start(SimpleNamespace(execution_id=None, title="node", id="node")) + + assert layer._node_contexts == {} + + def test_on_node_run_start_skips_when_disabled(self): + layer = ObservabilityLayer() + layer._is_disabled = True + layer._tracer = SimpleNamespace(start_span=lambda *_args, **_kwargs: object()) + + layer.on_node_run_start(SimpleNamespace(execution_id="exec", title="node", id="node")) + + assert layer._node_contexts == {} + + def test_on_node_run_start_skips_when_execution_id_missing_even_with_tracer(self): + layer = ObservabilityLayer() + layer._is_disabled = False + calls: list[str] = [] + layer._tracer = SimpleNamespace(start_span=lambda *_args, **_kwargs: calls.append("called")) + + layer.on_node_run_start(SimpleNamespace(execution_id=None, title="node", id="node")) + + assert calls == [] + + def test_on_node_run_start_logs_warning_when_span_creation_fails(self, caplog): + layer = ObservabilityLayer() + layer._is_disabled = False + + def _raise(*_args, **_kwargs): + raise RuntimeError("start failed") + + layer._tracer = SimpleNamespace(start_span=_raise) + + layer.on_node_run_start(SimpleNamespace(execution_id="exec", title="node", id="node")) + + assert "Failed to create OpenTelemetry span for node" in caplog.text + + def test_on_node_run_end_without_context_noop(self): + layer = ObservabilityLayer() + layer._is_disabled = False + + layer.on_node_run_end(SimpleNamespace(execution_id="missing", id="node"), error=None) + + assert layer._node_contexts == {} + + def test_on_node_run_end_skips_when_disabled(self): + layer = ObservabilityLayer() + layer._is_disabled = True + layer._node_contexts["exec"] = SimpleNamespace(span=object(), token="token") + + layer.on_node_run_end(SimpleNamespace(execution_id="exec", id="node"), error=None) + + assert "exec" in layer._node_contexts + + def test_on_node_run_end_skips_without_execution_id(self): + layer = ObservabilityLayer() + layer._is_disabled = False + + layer.on_node_run_end(SimpleNamespace(execution_id=None, id="node"), error=None) + + assert layer._node_contexts == {} + + def test_on_node_run_end_calls_span_end(self, monkeypatch): + layer = ObservabilityLayer() + layer._is_disabled = False + ended: list[str] = [] + + class _Parser: + def parse(self, **_kwargs): + return None + + span = SimpleNamespace(end=lambda: ended.append("ended")) + layer._default_parser = _Parser() + layer._node_contexts["exec"] = SimpleNamespace(span=span, token="token") + + monkeypatch.setattr("core.app.workflow.layers.observability.context_api.detach", lambda _token: None) + + node = SimpleNamespace(execution_id="exec", title="Node", id="node", node_type=None) + layer.on_node_run_end(node, error=None) + + assert ended == ["ended"] + assert "exec" not in layer._node_contexts + + def test_on_node_run_end_logs_detach_failure(self, monkeypatch, caplog): + layer = ObservabilityLayer() + layer._is_disabled = False + + class _Parser: + def parse(self, **_kwargs): + return None + + layer._default_parser = _Parser() + layer._node_contexts["exec"] = SimpleNamespace(span=SimpleNamespace(end=lambda: None), token="bad-token") + + def _raise(*_args, **_kwargs): + raise RuntimeError("detach failed") + + monkeypatch.setattr("core.app.workflow.layers.observability.context_api.detach", _raise) + + node = SimpleNamespace(execution_id="exec", title="Node", id="node", node_type=None) + layer.on_node_run_end(node, error=None) + + assert "Failed to detach OpenTelemetry token" in caplog.text + assert "exec" not in layer._node_contexts + + def test_on_node_run_start_and_end_creates_span(self, monkeypatch): + layer = ObservabilityLayer() + layer._is_disabled = False + + span = SimpleNamespace(end=lambda: None) + tracer = SimpleNamespace(start_span=lambda *args, **kwargs: span) + + monkeypatch.setattr("core.app.workflow.layers.observability.context_api.get_current", lambda: object()) + monkeypatch.setattr("core.app.workflow.layers.observability.set_span_in_context", lambda s: object()) + monkeypatch.setattr("core.app.workflow.layers.observability.context_api.attach", lambda ctx: "token") + monkeypatch.setattr("core.app.workflow.layers.observability.context_api.detach", lambda token: None) + + layer._tracer = tracer + + node = SimpleNamespace(execution_id="exec", title="Node", id="node", node_type=None) + + layer.on_node_run_start(node) + assert "exec" in layer._node_contexts + + layer.on_node_run_end(node, error=None) + assert "exec" not in layer._node_contexts diff --git a/api/tests/unit_tests/core/app/workflow/test_persistence_layer.py b/api/tests/unit_tests/core/app/workflow/test_persistence_layer.py new file mode 100644 index 0000000000..9863f34aba --- /dev/null +++ b/api/tests/unit_tests/core/app/workflow/test_persistence_layer.py @@ -0,0 +1,500 @@ +from __future__ import annotations + +from datetime import UTC, datetime +from types import SimpleNamespace + +import pytest + +from core.app.entities.app_invoke_entities import WorkflowAppGenerateEntity +from core.app.workflow.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer +from core.workflow.system_variables import SystemVariableKey, build_system_variables +from graphon.entities.pause_reason import SchedulingPause +from graphon.entities.workflow_node_execution import WorkflowNodeExecution +from graphon.enums import ( + BuiltinNodeTypes, + WorkflowExecutionStatus, + WorkflowNodeExecutionStatus, + WorkflowType, +) +from graphon.graph_events.graph import ( + GraphRunAbortedEvent, + GraphRunFailedEvent, + GraphRunPartialSucceededEvent, + GraphRunPausedEvent, + GraphRunStartedEvent, + GraphRunSucceededEvent, +) +from graphon.graph_events.node import ( + NodeRunExceptionEvent, + NodeRunFailedEvent, + NodeRunPauseRequestedEvent, + NodeRunRetryEvent, + NodeRunStartedEvent, + NodeRunSucceededEvent, +) +from graphon.node_events import NodeRunResult +from graphon.runtime import GraphRuntimeState, ReadOnlyGraphRuntimeStateWrapper, VariablePool + + +class _RepoRecorder: + def __init__(self) -> None: + self.saved: list[object] = [] + self.saved_exec_data: list[object] = [] + + def save(self, entity): + self.saved.append(entity) + + def save_execution_data(self, entity): + self.saved_exec_data.append(entity) + + +def _naive_utc_now() -> datetime: + return datetime.now(UTC).replace(tzinfo=None) + + +def _make_layer( + system_variables: list | None = None, + *, + extras: dict | None = None, + trace_manager: object | None = None, +): + system_variables = system_variables or build_system_variables( + workflow_execution_id="run-id", + conversation_id="conv-id", + ) + runtime_state = GraphRuntimeState(variable_pool=VariablePool(system_variables=system_variables), start_at=0.0) + read_only_state = ReadOnlyGraphRuntimeStateWrapper(runtime_state) + + application_generate_entity = WorkflowAppGenerateEntity.model_construct( + task_id="task", + app_config=SimpleNamespace(app_id="app", tenant_id="tenant"), + inputs={"foo": "bar"}, + files=[], + user_id="user", + stream=False, + invoke_from=None, + trace_manager=None, + workflow_execution_id="run-id", + extras=extras or {}, + call_depth=0, + ) + + workflow_info = PersistenceWorkflowInfo( + workflow_id="workflow-id", + workflow_type=WorkflowType.WORKFLOW, + version="1", + graph_data={"nodes": [], "edges": []}, + ) + + workflow_execution_repo = _RepoRecorder() + workflow_node_execution_repo = _RepoRecorder() + + layer = WorkflowPersistenceLayer( + application_generate_entity=application_generate_entity, + workflow_info=workflow_info, + workflow_execution_repository=workflow_execution_repo, + workflow_node_execution_repository=workflow_node_execution_repo, + trace_manager=trace_manager, + ) + layer.initialize(read_only_state, command_channel=None) + + return layer, workflow_execution_repo, workflow_node_execution_repo, runtime_state + + +class TestWorkflowPersistenceLayer: + def test_on_graph_start_resets_state(self): + layer, _, _, _ = _make_layer() + layer._workflow_execution = object() + layer._node_execution_cache["cached"] = object() + layer._node_snapshots["cached"] = object() + layer._node_sequence = 9 + + layer.on_graph_start() + + assert layer._workflow_execution is None + assert layer._node_execution_cache == {} + assert layer._node_snapshots == {} + assert layer._node_sequence == 0 + + def test_get_execution_id_requires_system_variable(self): + layer, _, _, _ = _make_layer(build_system_variables()) + + with pytest.raises(ValueError, match="workflow_execution_id must be provided"): + layer._get_execution_id() + + def test_prepare_workflow_inputs_excludes_conversation_id(self, monkeypatch): + layer, _, _, _ = _make_layer() + + monkeypatch.setattr( + "core.workflow.workflow_entry.WorkflowEntry.handle_special_values", + lambda inputs: inputs, + ) + + inputs = layer._prepare_workflow_inputs() + + assert "sys.conversation_id" not in inputs + assert inputs[f"sys.{SystemVariableKey.WORKFLOW_EXECUTION_ID.value}"] == "run-id" + + def test_fail_running_node_executions_marks_failed(self): + layer, _, node_repo, _ = _make_layer() + + execution = WorkflowNodeExecution( + id="exec-id", + workflow_id="workflow-id", + workflow_execution_id="run-id", + index=1, + node_id="node", + node_type=BuiltinNodeTypes.START, + title="Start", + created_at=_naive_utc_now(), + ) + layer._node_execution_cache[execution.id] = execution + + layer._fail_running_node_executions(error_message="boom") + + assert execution.status == WorkflowNodeExecutionStatus.FAILED + assert node_repo.saved + + def test_handle_graph_run_started_saves_execution(self): + layer, exec_repo, _, _ = _make_layer() + + layer._handle_graph_run_started() + + assert exec_repo.saved + + def test_handle_graph_run_succeeded_updates_execution(self): + layer, exec_repo, _, runtime_state = _make_layer() + layer._handle_graph_run_started() + runtime_state.total_tokens = 3 + runtime_state.node_run_steps = 2 + runtime_state.outputs = {"out": "v"} + + layer._handle_graph_run_succeeded(GraphRunSucceededEvent(outputs={"ok": True})) + + saved = exec_repo.saved[-1] + assert saved.status == WorkflowExecutionStatus.SUCCEEDED + assert saved.total_tokens == 3 + assert saved.total_steps == 2 + + def test_handle_graph_run_partial_succeeded_updates_execution(self): + layer, exec_repo, _, runtime_state = _make_layer() + layer._handle_graph_run_started() + runtime_state.total_tokens = 5 + runtime_state.node_run_steps = 4 + runtime_state._graph_execution = SimpleNamespace(exceptions_count=2) + + layer._handle_graph_run_partial_succeeded( + GraphRunPartialSucceededEvent(outputs={"ok": True}, exceptions_count=2) + ) + + saved = exec_repo.saved[-1] + assert saved.status == WorkflowExecutionStatus.PARTIAL_SUCCEEDED + assert saved.exceptions_count == 2 + assert saved.total_tokens == 5 + + def test_handle_graph_run_failed_marks_nodes_and_enqueues_trace(self): + trace_tasks: list[object] = [] + trace_manager = SimpleNamespace(user_id="user", add_trace_task=lambda task: trace_tasks.append(task)) + layer, exec_repo, node_repo, _ = _make_layer(extras={"external_trace_id": "trace"}, trace_manager=trace_manager) + layer._handle_graph_run_started() + + running = WorkflowNodeExecution( + id="node-exec", + workflow_id="workflow-id", + workflow_execution_id="run-id", + index=1, + node_id="node", + node_type=BuiltinNodeTypes.START, + title="Start", + created_at=_naive_utc_now(), + ) + layer._node_execution_cache[running.id] = running + + layer._handle_graph_run_failed(GraphRunFailedEvent(error="boom", exceptions_count=1)) + + assert node_repo.saved + assert exec_repo.saved[-1].status == WorkflowExecutionStatus.FAILED + assert trace_tasks + + def test_handle_graph_run_aborted_sets_status(self): + layer, exec_repo, _, _ = _make_layer() + layer._handle_graph_run_started() + + layer._handle_graph_run_aborted(GraphRunAbortedEvent(reason=None, outputs={})) + + saved = exec_repo.saved[-1] + assert saved.status == WorkflowExecutionStatus.STOPPED + assert saved.error_message + + def test_handle_graph_run_paused_updates_outputs(self): + layer, exec_repo, _, runtime_state = _make_layer() + layer._handle_graph_run_started() + runtime_state.total_tokens = 7 + runtime_state.node_run_steps = 5 + + layer._handle_graph_run_paused(GraphRunPausedEvent(outputs={"pause": True})) + + saved = exec_repo.saved[-1] + assert saved.status == WorkflowExecutionStatus.PAUSED + assert saved.outputs == {"pause": True} + assert saved.finished_at is None + + def test_handle_node_started_and_retry(self): + layer, _, node_repo, _ = _make_layer() + layer._handle_graph_run_started() + + start_event = NodeRunStartedEvent( + id="exec", + node_id="node", + node_type=BuiltinNodeTypes.START, + node_title="Start", + start_at=_naive_utc_now(), + predecessor_node_id="prev", + in_iteration_id="iter", + in_loop_id="loop", + ) + layer._handle_node_started(start_event) + + assert node_repo.saved + assert "exec" in layer._node_execution_cache + assert layer._node_snapshots["exec"].node_id == "node" + + retry_event = NodeRunRetryEvent( + id="exec", + node_id="node", + node_type=BuiltinNodeTypes.START, + node_title="Start", + start_at=_naive_utc_now(), + error="retry", + retry_index=1, + ) + layer._handle_node_retry(retry_event) + assert node_repo.saved_exec_data + + def test_handle_node_result_events_update_execution(self): + layer, _, node_repo, _ = _make_layer() + layer._handle_graph_run_started() + + start_event = NodeRunStartedEvent( + id="exec", + node_id="node", + node_type=BuiltinNodeTypes.LLM, + node_title="LLM", + start_at=_naive_utc_now(), + ) + layer._handle_node_started(start_event) + + result = NodeRunResult(inputs={"a": 1}, process_data={"b": 2}, outputs={"c": 3}, metadata={}) + success_event = NodeRunSucceededEvent( + id="exec", + node_id="node", + node_type=BuiltinNodeTypes.LLM, + start_at=_naive_utc_now(), + node_run_result=result, + ) + layer._handle_node_succeeded(success_event) + + failed_event = NodeRunFailedEvent( + id="exec", + node_id="node", + node_type=BuiltinNodeTypes.LLM, + start_at=_naive_utc_now(), + error="boom", + node_run_result=result, + ) + layer._handle_node_failed(failed_event) + + exception_event = NodeRunExceptionEvent( + id="exec", + node_id="node", + node_type=BuiltinNodeTypes.LLM, + start_at=_naive_utc_now(), + error="err", + node_run_result=result, + ) + layer._handle_node_exception(exception_event) + + assert node_repo.saved_exec_data + + def test_handle_node_pause_requested_skips_outputs(self): + layer, _, _, _ = _make_layer() + layer._handle_graph_run_started() + start_event = NodeRunStartedEvent( + id="exec", + node_id="node", + node_type=BuiltinNodeTypes.LLM, + node_title="LLM", + start_at=_naive_utc_now(), + ) + layer._handle_node_started(start_event) + + domain_execution = layer._node_execution_cache["exec"] + domain_execution.inputs = {"old": True} + + result = NodeRunResult(inputs={"new": True}, outputs={"out": 1}, process_data={"p": 1}, metadata={}) + pause_event = NodeRunPauseRequestedEvent( + id="exec", + node_id="node", + node_type=BuiltinNodeTypes.LLM, + reason=SchedulingPause(message="pause"), + node_run_result=result, + ) + layer._handle_node_pause_requested(pause_event) + + assert domain_execution.status == WorkflowNodeExecutionStatus.PAUSED + assert domain_execution.inputs == {"old": True} + + def test_get_node_execution_raises_for_missing(self): + layer, _, _, _ = _make_layer() + with pytest.raises(ValueError, match="Node execution not found"): + layer._get_node_execution("missing") + + def test_get_workflow_execution_raises_when_uninitialized(self): + layer, _, _, _ = _make_layer() + + with pytest.raises(ValueError, match="workflow execution not initialized"): + layer._get_workflow_execution() + + def test_next_node_sequence_increments(self): + layer, _, _, _ = _make_layer() + assert layer._next_node_sequence() == 1 + assert layer._next_node_sequence() == 2 + + def test_on_graph_end_is_noop(self): + layer, _, _, _ = _make_layer() + + assert layer.on_graph_end(error=None) is None + + def test_on_event_dispatches_to_all_known_handlers(self): + layer, _, _, _ = _make_layer() + called: list[str] = [] + + def _record(name: str): + def _handler(*_args, **_kwargs): + called.append(name) + + return _handler + + layer._handle_graph_run_started = _record("started") + layer._handle_graph_run_succeeded = _record("succeeded") + layer._handle_graph_run_partial_succeeded = _record("partial") + layer._handle_graph_run_failed = _record("failed") + layer._handle_graph_run_aborted = _record("aborted") + layer._handle_graph_run_paused = _record("paused") + layer._handle_node_started = _record("node_started") + layer._handle_node_retry = _record("node_retry") + layer._handle_node_succeeded = _record("node_succeeded") + layer._handle_node_failed = _record("node_failed") + layer._handle_node_exception = _record("node_exception") + layer._handle_node_pause_requested = _record("node_paused") + + node_result = NodeRunResult() + now = _naive_utc_now() + events = [ + GraphRunStartedEvent(), + GraphRunSucceededEvent(outputs={"ok": True}), + GraphRunPartialSucceededEvent(outputs={"ok": True}, exceptions_count=1), + GraphRunFailedEvent(error="boom", exceptions_count=1), + GraphRunAbortedEvent(reason="stop", outputs={"x": 1}), + GraphRunPausedEvent(outputs={"pause": True}), + NodeRunStartedEvent( + id="exec", + node_id="node", + node_type=BuiltinNodeTypes.START, + node_title="Start", + start_at=now, + ), + NodeRunRetryEvent( + id="exec", + node_id="node", + node_type=BuiltinNodeTypes.START, + node_title="Start", + start_at=now, + error="retry", + retry_index=1, + ), + NodeRunSucceededEvent( + id="exec", + node_id="node", + node_type=BuiltinNodeTypes.START, + start_at=now, + node_run_result=node_result, + ), + NodeRunFailedEvent( + id="exec", + node_id="node", + node_type=BuiltinNodeTypes.START, + start_at=now, + error="failed", + node_run_result=node_result, + ), + NodeRunExceptionEvent( + id="exec", + node_id="node", + node_type=BuiltinNodeTypes.START, + start_at=now, + error="error", + node_run_result=node_result, + ), + NodeRunPauseRequestedEvent( + id="exec", + node_id="node", + node_type=BuiltinNodeTypes.START, + reason=SchedulingPause(message="pause"), + node_run_result=node_result, + ), + ] + expected_order = [ + "started", + "succeeded", + "partial", + "failed", + "aborted", + "paused", + "node_started", + "node_retry", + "node_succeeded", + "node_failed", + "node_exception", + "node_paused", + ] + + for event in events: + layer.on_event(event) + + assert called == expected_order + + def test_on_event_dispatches_retry_before_started_for_retry_event(self): + layer, _, _, _ = _make_layer() + called: list[str] = [] + + def _record(name: str): + def _handler(*_args, **_kwargs): + called.append(name) + + return _handler + + layer._handle_node_started = _record("node_started") + layer._handle_node_retry = _record("node_retry") + + layer.on_event( + NodeRunRetryEvent( + id="exec", + node_id="node", + node_type=BuiltinNodeTypes.START, + node_title="Start", + start_at=_naive_utc_now(), + error="retry", + retry_index=1, + ) + ) + + assert called == ["node_retry"] + + def test_enqueue_trace_task_skips_when_disabled(self): + trace_tasks: list[object] = [] + layer, exec_repo, _, _ = _make_layer() + layer._handle_graph_run_started() + layer._handle_graph_run_succeeded(GraphRunSucceededEvent(outputs={"ok": True})) + assert exec_repo.saved + assert not trace_tasks diff --git a/api/tests/unit_tests/core/base/test_app_generator_tts_publisher.py b/api/tests/unit_tests/core/base/test_app_generator_tts_publisher.py index 3759b6aa37..7b433ab57b 100644 --- a/api/tests/unit_tests/core/base/test_app_generator_tts_publisher.py +++ b/api/tests/unit_tests/core/base/test_app_generator_tts_publisher.py @@ -28,10 +28,7 @@ def mock_model_instance(mocker): def mock_model_manager(mocker, mock_model_instance): manager = mocker.MagicMock() manager.get_default_model_instance.return_value = mock_model_instance - mocker.patch( - "core.base.tts.app_generator_tts_publisher.ModelManager", - return_value=manager, - ) + mocker.patch("core.base.tts.app_generator_tts_publisher.ModelManager.for_tenant", return_value=manager) return manager @@ -64,16 +61,14 @@ class TestInvoiceTTS: [None, "", " "], ) def test_invoice_tts_empty_or_none_returns_none(self, text, mock_model_instance): - result = _invoice_tts(text, mock_model_instance, "tenant", "voice1") + result = _invoice_tts(text, mock_model_instance, "voice1") assert result is None mock_model_instance.invoke_tts.assert_not_called() def test_invoice_tts_valid_text(self, mock_model_instance): - result = _invoice_tts(" hello ", mock_model_instance, "tenant", "voice1") + result = _invoice_tts(" hello ", mock_model_instance, "voice1") mock_model_instance.invoke_tts.assert_called_once_with( content_text="hello", - user="responding_tts", - tenant_id="tenant", voice="voice1", ) assert result == [b"audio1", b"audio2"] @@ -307,8 +302,8 @@ class TestAppGeneratorTTSPublisher: publisher.executor = MagicMock() from core.app.entities.queue_entities import QueueAgentMessageEvent - from dify_graph.model_runtime.entities.llm_entities import LLMResultChunk, LLMResultChunkDelta - from dify_graph.model_runtime.entities.message_entities import ( + from graphon.model_runtime.entities.llm_entities import LLMResultChunk, LLMResultChunkDelta + from graphon.model_runtime.entities.message_entities import ( AssistantPromptMessage, ImagePromptMessageContent, TextPromptMessageContent, @@ -342,8 +337,8 @@ class TestAppGeneratorTTSPublisher: publisher.executor = MagicMock() from core.app.entities.queue_entities import QueueAgentMessageEvent - from dify_graph.model_runtime.entities.llm_entities import LLMResultChunk, LLMResultChunkDelta - from dify_graph.model_runtime.entities.message_entities import AssistantPromptMessage + from graphon.model_runtime.entities.llm_entities import LLMResultChunk, LLMResultChunkDelta + from graphon.model_runtime.entities.message_entities import AssistantPromptMessage chunk = LLMResultChunk( model="model", diff --git a/api/tests/unit_tests/core/datasource/test_datasource_manager.py b/api/tests/unit_tests/core/datasource/test_datasource_manager.py index d5eeae912c..af992e4e9f 100644 --- a/api/tests/unit_tests/core/datasource/test_datasource_manager.py +++ b/api/tests/unit_tests/core/datasource/test_datasource_manager.py @@ -7,10 +7,11 @@ from contexts.wrapper import RecyclableContextVar from core.datasource.datasource_manager import DatasourceManager from core.datasource.entities.datasource_entities import DatasourceMessage, DatasourceProviderType from core.datasource.errors import DatasourceProviderNotFoundError -from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus -from dify_graph.file import File -from dify_graph.file.enums import FileTransferMethod, FileType -from dify_graph.node_events import StreamChunkEvent, StreamCompletedEvent +from core.workflow.file_reference import parse_file_reference +from graphon.entities.workflow_node_execution import WorkflowNodeExecutionStatus +from graphon.file import File +from graphon.file.enums import FileTransferMethod, FileType +from graphon.node_events import StreamChunkEvent, StreamCompletedEvent def _gen_messages_text_only(text: str) -> Generator[DatasourceMessage, None, None]: @@ -428,11 +429,8 @@ def test_stream_node_events_builds_file_and_variables_from_messages(mocker): return fake_tool_file mocker.patch("core.datasource.datasource_manager.session_factory.create_session", return_value=_Session()) - mocker.patch( - "core.datasource.datasource_manager.file_factory.get_file_type_by_mime_type", return_value=FileType.IMAGE - ) + mocker.patch("core.datasource.datasource_manager.get_file_type_by_mime_type", return_value=FileType.IMAGE) built = File( - tenant_id="t1", type=FileType.IMAGE, transfer_method=FileTransferMethod.TOOL_FILE, related_id="tool_file_1", @@ -533,7 +531,6 @@ def test_stream_node_events_online_drive_sets_variable_pool_file_and_outputs(moc mocker.patch.object(DatasourceManager, "stream_online_results", return_value=_gen_messages_text_only("ignored")) file_in = File( - tenant_id="t1", type=FileType.DOCUMENT, transfer_method=FileTransferMethod.TOOL_FILE, related_id="tf", @@ -664,6 +661,8 @@ def test_get_upload_file_by_id_builds_file(mocker): f = DatasourceManager.get_upload_file_by_id(file_id="fid", tenant_id="t1") assert f.related_id == "fid" assert f.extension == ".txt" + assert parse_file_reference(f.reference).storage_key is None + assert f.storage_key == "k" def test_get_upload_file_by_id_raises_when_missing(mocker): diff --git a/api/tests/unit_tests/core/datasource/utils/test_message_transformer.py b/api/tests/unit_tests/core/datasource/utils/test_message_transformer.py index 43f582feb7..0b91d59953 100644 --- a/api/tests/unit_tests/core/datasource/utils/test_message_transformer.py +++ b/api/tests/unit_tests/core/datasource/utils/test_message_transformer.py @@ -4,8 +4,8 @@ import pytest from core.datasource.entities.datasource_entities import DatasourceMessage from core.datasource.utils.message_transformer import DatasourceFileMessageTransformer -from dify_graph.file import File -from dify_graph.file.enums import FileTransferMethod, FileType +from graphon.file import File +from graphon.file.enums import FileTransferMethod, FileType from models.tools import ToolFile diff --git a/api/tests/unit_tests/core/entities/test_entities_execution_extra_content.py b/api/tests/unit_tests/core/entities/test_entities_execution_extra_content.py index 2e4f6d34fb..ef8f360dbf 100644 --- a/api/tests/unit_tests/core/entities/test_entities_execution_extra_content.py +++ b/api/tests/unit_tests/core/entities/test_entities_execution_extra_content.py @@ -4,8 +4,8 @@ from core.entities.execution_extra_content import ( HumanInputFormDefinition, HumanInputFormSubmissionData, ) -from dify_graph.nodes.human_input.entities import FormInput, UserAction -from dify_graph.nodes.human_input.enums import FormInputType +from graphon.nodes.human_input.entities import FormInput, UserAction +from graphon.nodes.human_input.enums import FormInputType from models.execution_extra_content import ExecutionContentType diff --git a/api/tests/unit_tests/core/entities/test_entities_model_entities.py b/api/tests/unit_tests/core/entities/test_entities_model_entities.py index 7a3d5e84ed..a0b2820157 100644 --- a/api/tests/unit_tests/core/entities/test_entities_model_entities.py +++ b/api/tests/unit_tests/core/entities/test_entities_model_entities.py @@ -16,9 +16,9 @@ from core.entities.model_entities import ( ProviderModelWithStatusEntity, SimpleModelProviderEntity, ) -from dify_graph.model_runtime.entities.common_entities import I18nObject -from dify_graph.model_runtime.entities.model_entities import FetchFrom, ModelType -from dify_graph.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity +from graphon.model_runtime.entities.common_entities import I18nObject +from graphon.model_runtime.entities.model_entities import FetchFrom, ModelType +from graphon.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity def _build_model_with_status(status: ModelStatus) -> ProviderModelWithStatusEntity: diff --git a/api/tests/unit_tests/core/entities/test_entities_provider_configuration.py b/api/tests/unit_tests/core/entities/test_entities_provider_configuration.py index 95d58757f1..fe2c226843 100644 --- a/api/tests/unit_tests/core/entities/test_entities_provider_configuration.py +++ b/api/tests/unit_tests/core/entities/test_entities_provider_configuration.py @@ -24,9 +24,9 @@ from core.entities.provider_entities import ( SystemConfiguration, SystemConfigurationStatus, ) -from dify_graph.model_runtime.entities.common_entities import I18nObject -from dify_graph.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType -from dify_graph.model_runtime.entities.provider_entities import ( +from graphon.model_runtime.entities.common_entities import I18nObject +from graphon.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType +from graphon.model_runtime.entities.provider_entities import ( ConfigurateMethod, CredentialFormSchema, FieldModelSchema, @@ -350,7 +350,7 @@ def test_validate_provider_credentials_handles_hidden_secret_value() -> None: mock_factory = Mock() mock_factory.provider_credentials_validate.return_value = {"openai_api_key": "restored-key", "region": "us"} - with patch("core.entities.provider_configuration.ModelProviderFactory", return_value=mock_factory): + with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory): with patch("core.entities.provider_configuration.encrypter.decrypt_token", return_value="restored-key"): with patch( "core.entities.provider_configuration.encrypter.encrypt_token", @@ -380,7 +380,9 @@ def test_validate_provider_credentials_opens_session_when_not_passed() -> None: with patch("core.entities.provider_configuration.db") as mock_db: mock_db.engine = Mock() mock_session_cls.return_value.__enter__.return_value = mock_session - with patch("core.entities.provider_configuration.ModelProviderFactory", return_value=mock_factory): + with patch( + "core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory + ): validated = configuration.validate_provider_credentials(credentials={"region": "us"}) assert validated == {"region": "us"} @@ -434,12 +436,16 @@ def test_get_model_type_instance_and_schema_delegate_to_factory() -> None: mock_factory.get_model_type_instance.return_value = mock_model_type_instance mock_factory.get_model_schema.return_value = mock_schema - with patch("core.entities.provider_configuration.ModelProviderFactory", return_value=mock_factory): + with patch( + "core.entities.provider_configuration.create_plugin_model_provider_factory", + return_value=mock_factory, + ) as mock_factory_builder: model_type_instance = configuration.get_model_type_instance(ModelType.LLM) model_schema = configuration.get_model_schema(ModelType.LLM, "gpt-4o", {"api_key": "x"}) assert model_type_instance is mock_model_type_instance assert model_schema is mock_schema + assert mock_factory_builder.call_count == 2 mock_factory.get_model_type_instance.assert_called_once_with(provider="openai", model_type=ModelType.LLM) mock_factory.get_model_schema.assert_called_once_with( provider="openai", @@ -449,6 +455,33 @@ def test_get_model_type_instance_and_schema_delegate_to_factory() -> None: ) +def test_get_model_type_instance_and_schema_reuse_bound_runtime_factory() -> None: + configuration = _build_provider_configuration() + bound_runtime = Mock() + configuration.bind_model_runtime(bound_runtime) + + mock_factory = Mock() + mock_model_type_instance = Mock() + mock_schema = _build_ai_model("gpt-4o") + mock_factory.get_model_type_instance.return_value = mock_model_type_instance + mock_factory.get_model_schema.return_value = mock_schema + + with ( + patch( + "core.entities.provider_configuration.ModelProviderFactory", return_value=mock_factory + ) as mock_factory_cls, + patch("core.entities.provider_configuration.create_plugin_model_provider_factory") as mock_factory_builder, + ): + model_type_instance = configuration.get_model_type_instance(ModelType.LLM) + model_schema = configuration.get_model_schema(ModelType.LLM, "gpt-4o", {"api_key": "x"}) + + assert model_type_instance is mock_model_type_instance + assert model_schema is mock_schema + assert mock_factory_cls.call_count == 2 + mock_factory_cls.assert_called_with(model_runtime=bound_runtime) + mock_factory_builder.assert_not_called() + + def test_get_provider_model_returns_none_when_model_not_found() -> None: configuration = _build_provider_configuration() fake_model = SimpleNamespace(model="other-model") @@ -475,7 +508,7 @@ def test_get_provider_models_system_deduplicates_sorts_and_filters_active() -> N mock_factory = Mock() mock_factory.get_provider_schema.return_value = provider_schema - with patch("core.entities.provider_configuration.ModelProviderFactory", return_value=mock_factory): + with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory): all_models = configuration.get_provider_models(model_type=ModelType.LLM, only_active=False) active_models = configuration.get_provider_models(model_type=ModelType.LLM, only_active=True) @@ -689,7 +722,7 @@ def test_validate_provider_credentials_handles_invalid_original_json() -> None: mock_factory = Mock() mock_factory.provider_credentials_validate.return_value = {"openai_api_key": "new-key"} - with patch("core.entities.provider_configuration.ModelProviderFactory", return_value=mock_factory): + with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory): with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-key"): validated = configuration.validate_provider_credentials( credentials={"openai_api_key": HIDDEN_VALUE}, @@ -1034,7 +1067,7 @@ def test_validate_custom_model_credentials_supports_hidden_reuse_and_sessionless mock_factory = Mock() mock_factory.model_credentials_validate.return_value = {"openai_api_key": "raw"} - with patch("core.entities.provider_configuration.ModelProviderFactory", return_value=mock_factory): + with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory): with patch("core.entities.provider_configuration.encrypter.decrypt_token", return_value="raw"): with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-new"): validated = configuration.validate_custom_model_credentials( @@ -1050,7 +1083,9 @@ def test_validate_custom_model_credentials_supports_hidden_reuse_and_sessionless mock_factory = Mock() mock_factory.model_credentials_validate.return_value = {"region": "us"} with _patched_session(session): - with patch("core.entities.provider_configuration.ModelProviderFactory", return_value=mock_factory): + with patch( + "core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory + ): validated = configuration.validate_custom_model_credentials( model_type=ModelType.LLM, model="gpt-4o", @@ -1540,7 +1575,7 @@ def test_validate_provider_credentials_uses_empty_original_when_record_missing() mock_factory = Mock() mock_factory.provider_credentials_validate.return_value = {"openai_api_key": "raw"} - with patch("core.entities.provider_configuration.ModelProviderFactory", return_value=mock_factory): + with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory): with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-new"): validated = configuration.validate_provider_credentials( credentials={"openai_api_key": HIDDEN_VALUE}, @@ -1662,7 +1697,7 @@ def test_validate_custom_model_credentials_handles_invalid_original_json() -> No mock_factory = Mock() mock_factory.model_credentials_validate.return_value = {"openai_api_key": "raw"} - with patch("core.entities.provider_configuration.ModelProviderFactory", return_value=mock_factory): + with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory): with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-new"): validated = configuration.validate_custom_model_credentials( model_type=ModelType.LLM, diff --git a/api/tests/unit_tests/core/entities/test_entities_provider_entities.py b/api/tests/unit_tests/core/entities/test_entities_provider_entities.py index c5bfd05a1e..a159d3ad4d 100644 --- a/api/tests/unit_tests/core/entities/test_entities_provider_entities.py +++ b/api/tests/unit_tests/core/entities/test_entities_provider_entities.py @@ -8,7 +8,7 @@ from core.entities.provider_entities import ( ProviderQuotaType, ) from core.tools.entities.common_entities import I18nObject -from dify_graph.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.entities.model_entities import ModelType def test_provider_quota_type_value_of_returns_enum_member() -> None: diff --git a/api/tests/unit_tests/core/file/test_models.py b/api/tests/unit_tests/core/file/test_models.py index deebf41320..bb6e40e224 100644 --- a/api/tests/unit_tests/core/file/test_models.py +++ b/api/tests/unit_tests/core/file/test_models.py @@ -1,4 +1,4 @@ -from dify_graph.file import File, FileTransferMethod, FileType +from graphon.file import File, FileTransferMethod, FileType def test_file(): @@ -15,18 +15,17 @@ def test_file(): storage_key="test-storage-key", url="https://example.com/image.png", ) - assert file.tenant_id == "test-tenant-id" assert file.type == FileType.IMAGE assert file.transfer_method == FileTransferMethod.TOOL_FILE assert file.related_id == "test-related-id" + assert file.storage_key == "test-storage-key" assert file.filename == "image.png" assert file.extension == ".png" assert file.mime_type == "image/png" assert file.size == 67 -def test_file_model_validate_with_legacy_fields(): - """Test `File` model can handle data containing compatibility fields.""" +def test_file_model_validate_accepts_legacy_tenant_id(): data = { "id": "test-file", "tenant_id": "test-tenant-id", @@ -45,10 +44,8 @@ def test_file_model_validate_with_legacy_fields(): "datasource_file_id": "datasource-file-789", } - # Should be able to create `File` object without raising an exception file = File.model_validate(data) - # The File object does not have tool_file_id, upload_file_id, or datasource_file_id as attributes. - # Instead, check it does not expose unrecognized legacy fields (should raise on getattr). - for legacy_field in ("tool_file_id", "upload_file_id", "datasource_file_id"): - assert not hasattr(file, legacy_field) + assert file.related_id == "test-related-id" + assert file.storage_key == "test-storage-key" + assert "tenant_id" not in file.model_dump() diff --git a/api/tests/unit_tests/core/llm_generator/output_parser/test_structured_output.py b/api/tests/unit_tests/core/llm_generator/output_parser/test_structured_output.py index 46c9dc6f9c..6ed9ddb476 100644 --- a/api/tests/unit_tests/core/llm_generator/output_parser/test_structured_output.py +++ b/api/tests/unit_tests/core/llm_generator/output_parser/test_structured_output.py @@ -16,20 +16,20 @@ from core.llm_generator.output_parser.structured_output import ( remove_additional_properties, ) from core.model_manager import ModelInstance -from dify_graph.model_runtime.entities.llm_entities import ( +from graphon.model_runtime.entities.llm_entities import ( LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMResultWithStructuredOutput, LLMUsage, ) -from dify_graph.model_runtime.entities.message_entities import ( +from graphon.model_runtime.entities.message_entities import ( AssistantPromptMessage, SystemPromptMessage, TextPromptMessageContent, UserPromptMessage, ) -from dify_graph.model_runtime.entities.model_entities import AIModelEntity, ParameterRule, ParameterType +from graphon.model_runtime.entities.model_entities import AIModelEntity, ParameterRule, ParameterType class TestStructuredOutput: diff --git a/api/tests/unit_tests/core/llm_generator/test_llm_generator.py b/api/tests/unit_tests/core/llm_generator/test_llm_generator.py index 5b7640696f..b3a5885814 100644 --- a/api/tests/unit_tests/core/llm_generator/test_llm_generator.py +++ b/api/tests/unit_tests/core/llm_generator/test_llm_generator.py @@ -6,14 +6,14 @@ import pytest from core.app.app_config.entities import ModelConfig from core.llm_generator.entities import RuleCodeGeneratePayload, RuleGeneratePayload, RuleStructuredOutputPayload from core.llm_generator.llm_generator import LLMGenerator -from dify_graph.model_runtime.entities.llm_entities import LLMMode, LLMResult -from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError +from graphon.model_runtime.entities.llm_entities import LLMMode, LLMResult +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError class TestLLMGenerator: @pytest.fixture def mock_model_instance(self): - with patch("core.llm_generator.llm_generator.ModelManager") as mock_manager: + with patch("core.llm_generator.llm_generator.ModelManager.for_tenant") as mock_manager: instance = MagicMock() mock_manager.return_value.get_default_model_instance.return_value = instance mock_manager.return_value.get_model_instance.return_value = instance @@ -98,7 +98,7 @@ class TestLLMGenerator: assert questions[0] == "Question 1?" def test_generate_suggested_questions_after_answer_auth_error(self, mock_model_instance): - with patch("core.llm_generator.llm_generator.ModelManager") as mock_manager: + with patch("core.llm_generator.llm_generator.ModelManager.for_tenant") as mock_manager: mock_manager.return_value.get_default_model_instance.side_effect = InvokeAuthorizationError("Auth failed") questions = LLMGenerator.generate_suggested_questions_after_answer("tenant_id", "histories") assert questions == [] @@ -528,7 +528,7 @@ class TestLLMGenerator: assert "An unexpected error occurred" in result["error"] def test_instruction_modify_common_other_node_type(self, mock_model_instance, model_config_entity): - with patch("core.llm_generator.llm_generator.ModelManager") as mock_manager: + with patch("core.llm_generator.llm_generator.ModelManager.for_tenant") as mock_manager: instance = MagicMock() mock_manager.return_value.get_model_instance.return_value = instance mock_response = MagicMock() diff --git a/api/tests/unit_tests/core/mcp/server/test_streamable_http.py b/api/tests/unit_tests/core/mcp/server/test_streamable_http.py index f982765b1a..bfb1fde502 100644 --- a/api/tests/unit_tests/core/mcp/server/test_streamable_http.py +++ b/api/tests/unit_tests/core/mcp/server/test_streamable_http.py @@ -18,7 +18,7 @@ from core.mcp.server.streamable_http import ( prepare_tool_arguments, process_mapping_response, ) -from dify_graph.variables.input_entities import VariableEntity, VariableEntityType +from graphon.variables.input_entities import VariableEntity, VariableEntityType from models.model import App, AppMCPServer, AppMode, EndUser diff --git a/api/tests/unit_tests/core/memory/test_token_buffer_memory.py b/api/tests/unit_tests/core/memory/test_token_buffer_memory.py index 5ecfe01808..f459250b8e 100644 --- a/api/tests/unit_tests/core/memory/test_token_buffer_memory.py +++ b/api/tests/unit_tests/core/memory/test_token_buffer_memory.py @@ -6,7 +6,7 @@ from uuid import uuid4 import pytest from core.memory.token_buffer_memory import TokenBufferMemory -from dify_graph.model_runtime.entities import ( +from graphon.model_runtime.entities import ( AssistantPromptMessage, ImagePromptMessageContent, PromptMessageRole, diff --git a/api/tests/unit_tests/core/model_runtime/test_model_provider_factory.py b/api/tests/unit_tests/core/model_runtime/test_model_provider_factory.py new file mode 100644 index 0000000000..249ecb5006 --- /dev/null +++ b/api/tests/unit_tests/core/model_runtime/test_model_provider_factory.py @@ -0,0 +1,420 @@ +from unittest.mock import Mock + +import pytest + +from graphon.model_runtime.entities.common_entities import I18nObject +from graphon.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType +from graphon.model_runtime.entities.provider_entities import ( + ConfigurateMethod, + CredentialFormSchema, + FieldModelSchema, + FormType, + ModelCredentialSchema, + ProviderCredentialSchema, + ProviderEntity, +) +from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from graphon.model_runtime.model_providers.__base.moderation_model import ModerationModel +from graphon.model_runtime.model_providers.__base.rerank_model import RerankModel +from graphon.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel +from graphon.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel +from graphon.model_runtime.model_providers.__base.tts_model import TTSModel +from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory + + +def _build_model(model: str, model_type: ModelType) -> AIModelEntity: + return AIModelEntity( + model=model, + label=I18nObject(en_US=model), + model_type=model_type, + fetch_from=FetchFrom.PREDEFINED_MODEL, + model_properties={}, + ) + + +def _build_provider( + *, + provider: str, + provider_name: str, + supported_model_types: list[ModelType], + models: list[AIModelEntity] | None = None, + provider_credential_schema: ProviderCredentialSchema | None = None, + model_credential_schema: ModelCredentialSchema | None = None, +) -> ProviderEntity: + return ProviderEntity( + provider=provider, + provider_name=provider_name, + label=I18nObject(en_US=provider_name or provider), + supported_model_types=supported_model_types, + configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL], + models=models or [], + provider_credential_schema=provider_credential_schema, + model_credential_schema=model_credential_schema, + ) + + +class _FakeModelRuntime: + def __init__(self, providers: list[ProviderEntity]) -> None: + self._providers = providers + self.validate_provider_credentials = Mock() + self.validate_model_credentials = Mock() + self.get_model_schema = Mock() + self.get_provider_icon = Mock() + + def fetch_model_providers(self) -> list[ProviderEntity]: + return self._providers + + +def test_model_provider_factory_resolves_runtime_provider_name() -> None: + provider = ProviderEntity( + provider="langgenius/openai/openai", + provider_name="openai", + label=I18nObject(en_US="OpenAI"), + supported_model_types=[ModelType.LLM], + configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL], + ) + factory = ModelProviderFactory(model_runtime=_FakeModelRuntime([provider])) + + provider_schema = factory.get_model_provider("openai") + + assert provider_schema.provider == "langgenius/openai/openai" + assert provider_schema.provider_name == "openai" + + +def test_model_provider_factory_resolves_canonical_short_name_independent_of_provider_order() -> None: + providers = [ + ProviderEntity( + provider="acme/openai/openai", + provider_name="", + label=I18nObject(en_US="Acme OpenAI"), + supported_model_types=[ModelType.LLM], + configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL], + ), + ProviderEntity( + provider="langgenius/openai/openai", + provider_name="openai", + label=I18nObject(en_US="OpenAI"), + supported_model_types=[ModelType.LLM], + configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL], + ), + ] + factory = ModelProviderFactory(model_runtime=_FakeModelRuntime(providers)) + + provider_schema = factory.get_model_provider("openai") + + assert provider_schema.provider == "langgenius/openai/openai" + assert provider_schema.provider_name == "openai" + + +def test_model_provider_factory_requires_runtime() -> None: + with pytest.raises(ValueError, match="model_runtime is required"): + ModelProviderFactory(model_runtime=None) # type: ignore[arg-type] + + +def test_model_provider_factory_get_providers_returns_runtime_providers() -> None: + providers = [ + _build_provider( + provider="langgenius/openai/openai", + provider_name="openai", + supported_model_types=[ModelType.LLM], + ) + ] + factory = ModelProviderFactory(model_runtime=_FakeModelRuntime(providers)) + + result = factory.get_providers() + + assert list(result) == providers + assert result is not providers + + +def test_model_provider_factory_get_provider_schema_delegates_to_provider_lookup() -> None: + provider = _build_provider( + provider="langgenius/openai/openai", + provider_name="openai", + supported_model_types=[ModelType.LLM], + ) + factory = ModelProviderFactory(model_runtime=_FakeModelRuntime([provider])) + + result = factory.get_provider_schema("openai") + + assert result is provider + + +def test_model_provider_factory_raises_for_unknown_provider() -> None: + factory = ModelProviderFactory( + model_runtime=_FakeModelRuntime( + [ + _build_provider( + provider="langgenius/openai/openai", + provider_name="openai", + supported_model_types=[ModelType.LLM], + ) + ] + ) + ) + + with pytest.raises(ValueError, match="Invalid provider: anthropic"): + factory.get_model_provider("anthropic") + + +def test_model_provider_factory_get_models_filters_provider_and_model_type() -> None: + providers = [ + _build_provider( + provider="langgenius/openai/openai", + provider_name="openai", + supported_model_types=[ModelType.LLM, ModelType.TTS], + models=[_build_model("gpt-4o-mini", ModelType.LLM), _build_model("tts-1", ModelType.TTS)], + ), + _build_provider( + provider="langgenius/cohere/cohere", + provider_name="cohere", + supported_model_types=[ModelType.RERANK], + models=[_build_model("rerank-v3", ModelType.RERANK)], + ), + ] + factory = ModelProviderFactory(model_runtime=_FakeModelRuntime(providers)) + + results = factory.get_models(provider="openai", model_type=ModelType.LLM) + + assert len(results) == 1 + assert results[0].provider == "langgenius/openai/openai" + assert [model.model for model in results[0].models] == ["gpt-4o-mini"] + + +def test_model_provider_factory_get_models_skips_providers_without_requested_model_type() -> None: + providers = [ + _build_provider( + provider="langgenius/openai/openai", + provider_name="openai", + supported_model_types=[ModelType.LLM], + models=[_build_model("gpt-4o-mini", ModelType.LLM)], + ), + _build_provider( + provider="langgenius/elevenlabs/elevenlabs", + provider_name="elevenlabs", + supported_model_types=[ModelType.TTS], + models=[_build_model("eleven_multilingual_v2", ModelType.TTS)], + ), + ] + factory = ModelProviderFactory(model_runtime=_FakeModelRuntime(providers)) + + results = factory.get_models(model_type=ModelType.TTS) + + assert len(results) == 1 + assert results[0].provider == "langgenius/elevenlabs/elevenlabs" + assert [model.model for model in results[0].models] == ["eleven_multilingual_v2"] + + +def test_model_provider_factory_get_models_without_model_type_keeps_all_provider_models() -> None: + providers = [ + _build_provider( + provider="langgenius/openai/openai", + provider_name="openai", + supported_model_types=[ModelType.LLM, ModelType.TTS], + models=[_build_model("gpt-4o-mini", ModelType.LLM), _build_model("tts-1", ModelType.TTS)], + ) + ] + factory = ModelProviderFactory(model_runtime=_FakeModelRuntime(providers)) + + results = factory.get_models(provider="openai") + + assert len(results) == 1 + assert [model.model for model in results[0].models] == ["gpt-4o-mini", "tts-1"] + + +def test_model_provider_factory_validates_provider_credentials() -> None: + runtime = _FakeModelRuntime( + [ + _build_provider( + provider="langgenius/openai/openai", + provider_name="openai", + supported_model_types=[ModelType.LLM], + provider_credential_schema=ProviderCredentialSchema( + credential_form_schemas=[ + CredentialFormSchema( + variable="api_key", + label=I18nObject(en_US="API key"), + type=FormType.SECRET_INPUT, + required=True, + ) + ] + ), + ) + ] + ) + factory = ModelProviderFactory(model_runtime=runtime) + + filtered = factory.provider_credentials_validate( + provider="openai", + credentials={"api_key": "secret", "ignored": "value"}, + ) + + assert filtered == {"api_key": "secret"} + runtime.validate_provider_credentials.assert_called_once_with( + provider="langgenius/openai/openai", + credentials={"api_key": "secret"}, + ) + + +def test_model_provider_factory_provider_credentials_validate_requires_schema() -> None: + factory = ModelProviderFactory( + model_runtime=_FakeModelRuntime( + [ + _build_provider( + provider="langgenius/openai/openai", + provider_name="openai", + supported_model_types=[ModelType.LLM], + ) + ] + ) + ) + + with pytest.raises(ValueError, match="Provider openai does not have provider_credential_schema"): + factory.provider_credentials_validate(provider="openai", credentials={"api_key": "secret"}) + + +def test_model_provider_factory_validates_model_credentials() -> None: + runtime = _FakeModelRuntime( + [ + _build_provider( + provider="langgenius/openai/openai", + provider_name="openai", + supported_model_types=[ModelType.LLM], + model_credential_schema=ModelCredentialSchema( + model=FieldModelSchema(label=I18nObject(en_US="Model")), + credential_form_schemas=[ + CredentialFormSchema( + variable="api_key", + label=I18nObject(en_US="API key"), + type=FormType.SECRET_INPUT, + required=True, + ) + ], + ), + ) + ] + ) + factory = ModelProviderFactory(model_runtime=runtime) + + filtered = factory.model_credentials_validate( + provider="openai", + model_type=ModelType.LLM, + model="gpt-4o-mini", + credentials={"api_key": "secret", "ignored": "value"}, + ) + + assert filtered == {"api_key": "secret"} + runtime.validate_model_credentials.assert_called_once_with( + provider="langgenius/openai/openai", + model_type=ModelType.LLM, + model="gpt-4o-mini", + credentials={"api_key": "secret"}, + ) + + +def test_model_provider_factory_model_credentials_validate_requires_schema() -> None: + factory = ModelProviderFactory( + model_runtime=_FakeModelRuntime( + [ + _build_provider( + provider="langgenius/openai/openai", + provider_name="openai", + supported_model_types=[ModelType.LLM], + ) + ] + ) + ) + + with pytest.raises(ValueError, match="Provider openai does not have model_credential_schema"): + factory.model_credentials_validate( + provider="openai", + model_type=ModelType.LLM, + model="gpt-4o-mini", + credentials={"api_key": "secret"}, + ) + + +def test_model_provider_factory_get_model_schema_and_icon_use_canonical_provider() -> None: + runtime = _FakeModelRuntime( + [ + _build_provider( + provider="langgenius/openai/openai", + provider_name="openai", + supported_model_types=[ModelType.LLM], + ) + ] + ) + runtime.get_model_schema.return_value = "schema" + runtime.get_provider_icon.return_value = (b"icon", "image/png") + factory = ModelProviderFactory(model_runtime=runtime) + + assert ( + factory.get_model_schema( + provider="openai", + model_type=ModelType.LLM, + model="gpt-4o-mini", + credentials=None, + ) + == "schema" + ) + assert factory.get_provider_icon("openai", "icon_small", "en_US") == (b"icon", "image/png") + runtime.get_model_schema.assert_called_once_with( + provider="langgenius/openai/openai", + model_type=ModelType.LLM, + model="gpt-4o-mini", + credentials={}, + ) + runtime.get_provider_icon.assert_called_once_with( + provider="langgenius/openai/openai", + icon_type="icon_small", + lang="en_US", + ) + + +@pytest.mark.parametrize( + ("model_type", "expected_type"), + [ + (ModelType.LLM, LargeLanguageModel), + (ModelType.TEXT_EMBEDDING, TextEmbeddingModel), + (ModelType.RERANK, RerankModel), + (ModelType.SPEECH2TEXT, Speech2TextModel), + (ModelType.MODERATION, ModerationModel), + (ModelType.TTS, TTSModel), + ], +) +def test_model_provider_factory_builds_model_type_instances( + model_type: ModelType, + expected_type: type[object], +) -> None: + factory = ModelProviderFactory( + model_runtime=_FakeModelRuntime( + [ + _build_provider( + provider="langgenius/openai/openai", + provider_name="openai", + supported_model_types=[model_type], + ) + ] + ) + ) + + instance = factory.get_model_type_instance("openai", model_type) + + assert isinstance(instance, expected_type) + + +def test_model_provider_factory_rejects_unsupported_model_type() -> None: + factory = ModelProviderFactory( + model_runtime=_FakeModelRuntime( + [ + _build_provider( + provider="langgenius/openai/openai", + provider_name="openai", + supported_model_types=[ModelType.LLM], + ) + ] + ) + ) + + with pytest.raises(ValueError, match="Unsupported model type: unsupported"): + factory.get_model_type_instance("openai", "unsupported") # type: ignore[arg-type] diff --git a/api/tests/unit_tests/core/moderation/test_content_moderation.py b/api/tests/unit_tests/core/moderation/test_content_moderation.py index e61cde22e7..3a97ad5c5d 100644 --- a/api/tests/unit_tests/core/moderation/test_content_moderation.py +++ b/api/tests/unit_tests/core/moderation/test_content_moderation.py @@ -324,7 +324,7 @@ class TestOpenAIModeration: with pytest.raises(ValueError, match="At least one of inputs_config or outputs_config must be enabled"): OpenAIModeration.validate_config("test-tenant", config) - @patch("core.moderation.openai_moderation.openai_moderation.ModelManager", autospec=True) + @patch("core.moderation.openai_moderation.openai_moderation.ModelManager.for_tenant", autospec=True) def test_moderation_for_inputs_no_violation(self, mock_model_manager: Mock, openai_moderation: OpenAIModeration): """Test input moderation when OpenAI API returns no violations.""" # Mock the model manager and instance @@ -341,7 +341,7 @@ class TestOpenAIModeration: assert result.action == ModerationAction.DIRECT_OUTPUT assert result.preset_response == "Content flagged by OpenAI moderation." - @patch("core.moderation.openai_moderation.openai_moderation.ModelManager", autospec=True) + @patch("core.moderation.openai_moderation.openai_moderation.ModelManager.for_tenant", autospec=True) def test_moderation_for_inputs_with_violation(self, mock_model_manager: Mock, openai_moderation: OpenAIModeration): """Test input moderation when OpenAI API detects violations.""" # Mock the model manager to return violation @@ -358,7 +358,7 @@ class TestOpenAIModeration: assert result.action == ModerationAction.DIRECT_OUTPUT assert result.preset_response == "Content flagged by OpenAI moderation." - @patch("core.moderation.openai_moderation.openai_moderation.ModelManager", autospec=True) + @patch("core.moderation.openai_moderation.openai_moderation.ModelManager.for_tenant", autospec=True) def test_moderation_for_inputs_query_included(self, mock_model_manager: Mock, openai_moderation: OpenAIModeration): """Test that query is included in moderation check with special key.""" mock_instance = MagicMock() @@ -385,7 +385,7 @@ class TestOpenAIModeration: assert "u" in moderated_text assert "e" in moderated_text - @patch("core.moderation.openai_moderation.openai_moderation.ModelManager", autospec=True) + @patch("core.moderation.openai_moderation.openai_moderation.ModelManager.for_tenant", autospec=True) def test_moderation_for_inputs_disabled(self, mock_model_manager: Mock): """Test input moderation when inputs_config is disabled.""" config = { @@ -400,7 +400,7 @@ class TestOpenAIModeration: # Should not call the API when disabled mock_model_manager.assert_not_called() - @patch("core.moderation.openai_moderation.openai_moderation.ModelManager", autospec=True) + @patch("core.moderation.openai_moderation.openai_moderation.ModelManager.for_tenant", autospec=True) def test_moderation_for_outputs_no_violation(self, mock_model_manager: Mock, openai_moderation: OpenAIModeration): """Test output moderation when OpenAI API returns no violations.""" mock_instance = MagicMock() @@ -414,7 +414,7 @@ class TestOpenAIModeration: assert result.action == ModerationAction.DIRECT_OUTPUT assert result.preset_response == "Response blocked by moderation." - @patch("core.moderation.openai_moderation.openai_moderation.ModelManager", autospec=True) + @patch("core.moderation.openai_moderation.openai_moderation.ModelManager.for_tenant", autospec=True) def test_moderation_for_outputs_with_violation(self, mock_model_manager: Mock, openai_moderation: OpenAIModeration): """Test output moderation when OpenAI API detects violations.""" mock_instance = MagicMock() @@ -427,7 +427,7 @@ class TestOpenAIModeration: assert result.flagged is True assert result.action == ModerationAction.DIRECT_OUTPUT - @patch("core.moderation.openai_moderation.openai_moderation.ModelManager", autospec=True) + @patch("core.moderation.openai_moderation.openai_moderation.ModelManager.for_tenant", autospec=True) def test_moderation_for_outputs_disabled(self, mock_model_manager: Mock): """Test output moderation when outputs_config is disabled.""" config = { @@ -441,7 +441,7 @@ class TestOpenAIModeration: assert result.flagged is False mock_model_manager.assert_not_called() - @patch("core.moderation.openai_moderation.openai_moderation.ModelManager", autospec=True) + @patch("core.moderation.openai_moderation.openai_moderation.ModelManager.for_tenant", autospec=True) def test_model_manager_called_with_correct_params( self, mock_model_manager: Mock, openai_moderation: OpenAIModeration ): @@ -629,7 +629,7 @@ class TestPresetManagement: assert result.flagged is True assert result.preset_response == "Custom output blocked message" - @patch("core.moderation.openai_moderation.openai_moderation.ModelManager", autospec=True) + @patch("core.moderation.openai_moderation.openai_moderation.ModelManager.for_tenant", autospec=True) def test_openai_preset_response_in_inputs(self, mock_model_manager: Mock): """Test preset response is properly returned for OpenAI input violations.""" mock_instance = MagicMock() @@ -650,7 +650,7 @@ class TestPresetManagement: assert result.flagged is True assert result.preset_response == "OpenAI input blocked" - @patch("core.moderation.openai_moderation.openai_moderation.ModelManager", autospec=True) + @patch("core.moderation.openai_moderation.openai_moderation.ModelManager.for_tenant", autospec=True) def test_openai_preset_response_in_outputs(self, mock_model_manager: Mock): """Test preset response is properly returned for OpenAI output violations.""" mock_instance = MagicMock() @@ -989,7 +989,7 @@ class TestOpenAIModerationAdvanced: - Performance considerations """ - @patch("core.moderation.openai_moderation.openai_moderation.ModelManager", autospec=True) + @patch("core.moderation.openai_moderation.openai_moderation.ModelManager.for_tenant", autospec=True) def test_openai_api_timeout_handling(self, mock_model_manager: Mock): """ Test graceful handling of OpenAI API timeouts. @@ -1012,7 +1012,7 @@ class TestOpenAIModerationAdvanced: with pytest.raises(TimeoutError): moderation.moderation_for_inputs({"text": "test"}, "") - @patch("core.moderation.openai_moderation.openai_moderation.ModelManager", autospec=True) + @patch("core.moderation.openai_moderation.openai_moderation.ModelManager.for_tenant", autospec=True) def test_openai_api_rate_limit_handling(self, mock_model_manager: Mock): """ Test handling of OpenAI API rate limit errors. @@ -1035,7 +1035,7 @@ class TestOpenAIModerationAdvanced: with pytest.raises(Exception, match="Rate limit exceeded"): moderation.moderation_for_inputs({"text": "test"}, "") - @patch("core.moderation.openai_moderation.openai_moderation.ModelManager", autospec=True) + @patch("core.moderation.openai_moderation.openai_moderation.ModelManager.for_tenant", autospec=True) def test_openai_with_multiple_input_fields(self, mock_model_manager: Mock): """ Test OpenAI moderation with multiple input fields. @@ -1079,7 +1079,7 @@ class TestOpenAIModerationAdvanced: assert "u" in moderated_text assert "e" in moderated_text - @patch("core.moderation.openai_moderation.openai_moderation.ModelManager", autospec=True) + @patch("core.moderation.openai_moderation.openai_moderation.ModelManager.for_tenant", autospec=True) def test_openai_empty_text_handling(self, mock_model_manager: Mock): """ Test OpenAI moderation with empty text inputs. @@ -1103,7 +1103,7 @@ class TestOpenAIModerationAdvanced: assert result.flagged is False mock_instance.invoke_moderation.assert_called_once() - @patch("core.moderation.openai_moderation.openai_moderation.ModelManager", autospec=True) + @patch("core.moderation.openai_moderation.openai_moderation.ModelManager.for_tenant", autospec=True) def test_openai_model_instance_fetched_on_each_call(self, mock_model_manager: Mock): """ Test that ModelManager fetches a fresh model instance on each call. diff --git a/api/tests/unit_tests/core/ops/aliyun_trace/test_aliyun_trace.py b/api/tests/unit_tests/core/ops/aliyun_trace/test_aliyun_trace.py index dfd61acfa7..c2324fdec4 100644 --- a/api/tests/unit_tests/core/ops/aliyun_trace/test_aliyun_trace.py +++ b/api/tests/unit_tests/core/ops/aliyun_trace/test_aliyun_trace.py @@ -34,8 +34,8 @@ from core.ops.entities.trace_entity import ( ToolTraceInfo, WorkflowTraceInfo, ) -from dify_graph.entities import WorkflowNodeExecution -from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey +from graphon.entities import WorkflowNodeExecution +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey class RecordingTraceClient: @@ -396,14 +396,14 @@ def test_get_workflow_node_executions_builds_repo_and_fetches( monkeypatch.setattr(aliyun_trace_module, "db", SimpleNamespace(engine="engine")) repo = MagicMock() - repo.get_by_workflow_run.return_value = ["node1"] + repo.get_by_workflow_execution.return_value = ["node1"] mock_factory = MagicMock() mock_factory.create_workflow_node_execution_repository.return_value = repo monkeypatch.setattr(aliyun_trace_module, "DifyCoreRepositoryFactory", mock_factory) result = trace_instance.get_workflow_node_executions(trace_info) assert result == ["node1"] - repo.get_by_workflow_run.assert_called_once_with(workflow_run_id=trace_info.workflow_run_id) + repo.get_by_workflow_execution.assert_called_once_with(workflow_execution_id=trace_info.workflow_run_id) def test_build_workflow_node_span_routes_llm_type(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch): diff --git a/api/tests/unit_tests/core/ops/aliyun_trace/test_aliyun_trace_utils.py b/api/tests/unit_tests/core/ops/aliyun_trace/test_aliyun_trace_utils.py index 763fc90710..fa885e9320 100644 --- a/api/tests/unit_tests/core/ops/aliyun_trace/test_aliyun_trace_utils.py +++ b/api/tests/unit_tests/core/ops/aliyun_trace/test_aliyun_trace_utils.py @@ -24,8 +24,8 @@ from core.ops.aliyun_trace.utils import ( serialize_json_data, ) from core.rag.models.document import Document -from dify_graph.entities import WorkflowNodeExecution -from dify_graph.enums import WorkflowNodeExecutionStatus +from graphon.entities import WorkflowNodeExecution +from graphon.enums import WorkflowNodeExecutionStatus from models import EndUser diff --git a/api/tests/unit_tests/core/ops/arize_phoenix_trace/test_arize_phoenix_trace.py b/api/tests/unit_tests/core/ops/arize_phoenix_trace/test_arize_phoenix_trace.py index 1cee2f5b68..4ce9e22fd7 100644 --- a/api/tests/unit_tests/core/ops/arize_phoenix_trace/test_arize_phoenix_trace.py +++ b/api/tests/unit_tests/core/ops/arize_phoenix_trace/test_arize_phoenix_trace.py @@ -254,7 +254,7 @@ def test_workflow_trace_full(mock_db, mock_repo_factory, mock_sessionmaker, trac node1.id = "n1" node1.error = None - repo.get_by_workflow_run.return_value = [node1] + repo.get_by_workflow_execution.return_value = [node1] with patch.object(trace_instance, "get_service_account_with_tenant"): trace_instance.workflow_trace(info) diff --git a/api/tests/unit_tests/core/ops/langfuse_trace/test_langfuse_trace.py b/api/tests/unit_tests/core/ops/langfuse_trace/test_langfuse_trace.py index 0ff135562c..fdf66d4d40 100644 --- a/api/tests/unit_tests/core/ops/langfuse_trace/test_langfuse_trace.py +++ b/api/tests/unit_tests/core/ops/langfuse_trace/test_langfuse_trace.py @@ -25,7 +25,7 @@ from core.ops.langfuse_trace.entities.langfuse_trace_entity import ( UnitEnum, ) from core.ops.langfuse_trace.langfuse_trace import LangFuseDataTrace -from dify_graph.enums import BuiltinNodeTypes +from graphon.enums import BuiltinNodeTypes from models import EndUser from models.enums import MessageStatus @@ -174,7 +174,7 @@ def test_workflow_trace_with_message_id(trace_instance, monkeypatch): node_other.metadata = None repo = MagicMock() - repo.get_by_workflow_run.return_value = [node_llm, node_other] + repo.get_by_workflow_execution.return_value = [node_llm, node_other] mock_factory = MagicMock() mock_factory.create_workflow_node_execution_repository.return_value = repo @@ -244,7 +244,7 @@ def test_workflow_trace_no_message_id(trace_instance, monkeypatch): monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.sessionmaker", lambda bind: lambda: MagicMock()) monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.db", MagicMock(engine="engine")) repo = MagicMock() - repo.get_by_workflow_run.return_value = [] + repo.get_by_workflow_execution.return_value = [] mock_factory = MagicMock() mock_factory.create_workflow_node_execution_repository.return_value = repo monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.DifyCoreRepositoryFactory", mock_factory) @@ -680,7 +680,7 @@ def test_workflow_trace_handles_usage_extraction_error(trace_instance, monkeypat node.outputs = {} repo = MagicMock() - repo.get_by_workflow_run.return_value = [node] + repo.get_by_workflow_execution.return_value = [node] mock_factory = MagicMock() mock_factory.create_workflow_node_execution_repository.return_value = repo monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.DifyCoreRepositoryFactory", mock_factory) diff --git a/api/tests/unit_tests/core/ops/langsmith_trace/test_langsmith_trace.py b/api/tests/unit_tests/core/ops/langsmith_trace/test_langsmith_trace.py index f656f7435f..e89359c25b 100644 --- a/api/tests/unit_tests/core/ops/langsmith_trace/test_langsmith_trace.py +++ b/api/tests/unit_tests/core/ops/langsmith_trace/test_langsmith_trace.py @@ -21,7 +21,7 @@ from core.ops.langsmith_trace.entities.langsmith_trace_entity import ( LangSmithRunUpdateModel, ) from core.ops.langsmith_trace.langsmith_trace import LangSmithDataTrace -from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey from models import EndUser @@ -184,7 +184,7 @@ def test_workflow_trace(trace_instance, monkeypatch): node_retrieval.metadata = {} repo = MagicMock() - repo.get_by_workflow_run.return_value = [node_llm, node_other, node_retrieval] + repo.get_by_workflow_execution.return_value = [node_llm, node_other, node_retrieval] mock_factory = MagicMock() mock_factory.create_workflow_node_execution_repository.return_value = repo @@ -255,7 +255,7 @@ def test_workflow_trace_no_start_time(trace_instance, monkeypatch): monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.sessionmaker", lambda bind: lambda: mock_session) monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.db", MagicMock(engine="engine")) repo = MagicMock() - repo.get_by_workflow_run.return_value = [] + repo.get_by_workflow_execution.return_value = [] mock_factory = MagicMock() mock_factory.create_workflow_node_execution_repository.return_value = repo monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.DifyCoreRepositoryFactory", mock_factory) @@ -565,7 +565,7 @@ def test_workflow_trace_usage_extraction_error(trace_instance, monkeypatch, capl node_llm.metadata = {} repo = MagicMock() - repo.get_by_workflow_run.return_value = [node_llm] + repo.get_by_workflow_execution.return_value = [node_llm] mock_factory = MagicMock() mock_factory.create_workflow_node_execution_repository.return_value = repo diff --git a/api/tests/unit_tests/core/ops/mlflow_trace/test_mlflow_trace.py b/api/tests/unit_tests/core/ops/mlflow_trace/test_mlflow_trace.py index cccedaa08c..7ff6f7dcfd 100644 --- a/api/tests/unit_tests/core/ops/mlflow_trace/test_mlflow_trace.py +++ b/api/tests/unit_tests/core/ops/mlflow_trace/test_mlflow_trace.py @@ -21,7 +21,7 @@ from core.ops.entities.trace_entity import ( WorkflowTraceInfo, ) from core.ops.mlflow_trace.mlflow_trace import MLflowDataTrace, datetime_to_nanoseconds -from dify_graph.enums import BuiltinNodeTypes +from graphon.enums import BuiltinNodeTypes # ── Helpers ────────────────────────────────────────────────────────────────── diff --git a/api/tests/unit_tests/core/ops/opik_trace/test_opik_trace.py b/api/tests/unit_tests/core/ops/opik_trace/test_opik_trace.py index b2cb7d5109..6625cb719f 100644 --- a/api/tests/unit_tests/core/ops/opik_trace/test_opik_trace.py +++ b/api/tests/unit_tests/core/ops/opik_trace/test_opik_trace.py @@ -18,7 +18,7 @@ from core.ops.entities.trace_entity import ( WorkflowTraceInfo, ) from core.ops.opik_trace.opik_trace import OpikDataTrace, prepare_opik_uuid, wrap_dict, wrap_metadata -from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey from models import EndUser from models.enums import MessageStatus @@ -199,7 +199,7 @@ def test_workflow_trace_with_message_id(trace_instance, monkeypatch): node_other.metadata = {WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS.value: 10} repo = MagicMock() - repo.get_by_workflow_run.return_value = [node_llm, node_other] + repo.get_by_workflow_execution.return_value = [node_llm, node_other] mock_factory = MagicMock() mock_factory.create_workflow_node_execution_repository.return_value = repo @@ -253,7 +253,7 @@ def test_workflow_trace_no_message_id(trace_instance, monkeypatch): monkeypatch.setattr("core.ops.opik_trace.opik_trace.sessionmaker", lambda bind: lambda: MagicMock()) monkeypatch.setattr("core.ops.opik_trace.opik_trace.db", MagicMock(engine="engine")) repo = MagicMock() - repo.get_by_workflow_run.return_value = [] + repo.get_by_workflow_execution.return_value = [] mock_factory = MagicMock() mock_factory.create_workflow_node_execution_repository.return_value = repo monkeypatch.setattr("core.ops.opik_trace.opik_trace.DifyCoreRepositoryFactory", mock_factory) @@ -657,7 +657,7 @@ def test_workflow_trace_usage_extraction_error_fixed(trace_instance, monkeypatch node.outputs = {} repo = MagicMock() - repo.get_by_workflow_run.return_value = [node] + repo.get_by_workflow_execution.return_value = [node] mock_factory = MagicMock() mock_factory.create_workflow_node_execution_repository.return_value = repo monkeypatch.setattr("core.ops.opik_trace.opik_trace.DifyCoreRepositoryFactory", mock_factory) diff --git a/api/tests/unit_tests/core/ops/tencent_trace/test_span_builder.py b/api/tests/unit_tests/core/ops/tencent_trace/test_span_builder.py index a0b6d52720..6113e5c6c8 100644 --- a/api/tests/unit_tests/core/ops/tencent_trace/test_span_builder.py +++ b/api/tests/unit_tests/core/ops/tencent_trace/test_span_builder.py @@ -25,8 +25,8 @@ from core.ops.tencent_trace.entities.semconv import ( ) from core.ops.tencent_trace.span_builder import TencentSpanBuilder from core.rag.models.document import Document -from dify_graph.entities import WorkflowNodeExecution -from dify_graph.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus +from graphon.entities import WorkflowNodeExecution +from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus class TestTencentSpanBuilder: diff --git a/api/tests/unit_tests/core/ops/tencent_trace/test_tencent_trace.py b/api/tests/unit_tests/core/ops/tencent_trace/test_tencent_trace.py index f259e4639f..265652381c 100644 --- a/api/tests/unit_tests/core/ops/tencent_trace/test_tencent_trace.py +++ b/api/tests/unit_tests/core/ops/tencent_trace/test_tencent_trace.py @@ -14,8 +14,8 @@ from core.ops.entities.trace_entity import ( WorkflowTraceInfo, ) from core.ops.tencent_trace.tencent_trace import TencentDataTrace -from dify_graph.entities import WorkflowNodeExecution -from dify_graph.enums import BuiltinNodeTypes +from graphon.entities import WorkflowNodeExecution +from graphon.enums import BuiltinNodeTypes from models import Account, App, TenantAccountJoin logger = logging.getLogger(__name__) @@ -413,7 +413,7 @@ class TestTencentDataTrace: with patch( "core.ops.tencent_trace.tencent_trace.SQLAlchemyWorkflowNodeExecutionRepository" ) as mock_repo: - mock_repo.return_value.get_by_workflow_run.return_value = mock_executions + mock_repo.return_value.get_by_workflow_execution.return_value = mock_executions results = tencent_data_trace._get_workflow_node_executions(trace_info) diff --git a/api/tests/unit_tests/core/ops/test_arize_phoenix_trace.py b/api/tests/unit_tests/core/ops/test_arize_phoenix_trace.py index 49d6b698ef..4b925390d9 100644 --- a/api/tests/unit_tests/core/ops/test_arize_phoenix_trace.py +++ b/api/tests/unit_tests/core/ops/test_arize_phoenix_trace.py @@ -1,7 +1,7 @@ from openinference.semconv.trace import OpenInferenceSpanKindValues from core.ops.arize_phoenix_trace.arize_phoenix_trace import _NODE_TYPE_TO_SPAN_KIND, _get_node_span_kind -from dify_graph.enums import BUILT_IN_NODE_TYPES, BuiltinNodeTypes +from graphon.enums import BUILT_IN_NODE_TYPES, BuiltinNodeTypes class TestGetNodeSpanKind: diff --git a/api/tests/unit_tests/core/ops/test_opik_trace.py b/api/tests/unit_tests/core/ops/test_opik_trace.py index 7660967183..ad9d0846be 100644 --- a/api/tests/unit_tests/core/ops/test_opik_trace.py +++ b/api/tests/unit_tests/core/ops/test_opik_trace.py @@ -130,7 +130,7 @@ class TestWorkflowTraceWithoutMessageId: def _run(self, trace_info: WorkflowTraceInfo, node_executions: list | None = None): instance = _make_opik_trace_instance() fake_repo = MagicMock() - fake_repo.get_by_workflow_run.return_value = node_executions or [] + fake_repo.get_by_workflow_execution.return_value = node_executions or [] with ( patch("core.ops.opik_trace.opik_trace.db") as mock_db, @@ -262,7 +262,7 @@ class TestWorkflowTraceWithMessageId: def _run(self, trace_info: WorkflowTraceInfo, node_executions: list | None = None): instance = _make_opik_trace_instance() fake_repo = MagicMock() - fake_repo.get_by_workflow_run.return_value = node_executions or [] + fake_repo.get_by_workflow_execution.return_value = node_executions or [] with ( patch("core.ops.opik_trace.opik_trace.db") as mock_db, diff --git a/api/tests/unit_tests/core/ops/weave_trace/test_weave_trace.py b/api/tests/unit_tests/core/ops/weave_trace/test_weave_trace.py index 8057bbbad5..8987b6682c 100644 --- a/api/tests/unit_tests/core/ops/weave_trace/test_weave_trace.py +++ b/api/tests/unit_tests/core/ops/weave_trace/test_weave_trace.py @@ -22,7 +22,7 @@ from core.ops.entities.trace_entity import ( ) from core.ops.weave_trace.entities.weave_trace_entity import WeaveTraceModel from core.ops.weave_trace.weave_trace import WeaveDataTrace -from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey # ── Helpers ────────────────────────────────────────────────────────────────── @@ -589,7 +589,7 @@ class TestWorkflowTrace: nodes = [] repo = MagicMock() - repo.get_by_workflow_run.return_value = nodes + repo.get_by_workflow_execution.return_value = nodes mock_factory = MagicMock() mock_factory.create_workflow_node_execution_repository.return_value = repo diff --git a/api/tests/unit_tests/core/plugin/impl/test_model_runtime_factory.py b/api/tests/unit_tests/core/plugin/impl/test_model_runtime_factory.py new file mode 100644 index 0000000000..7491e79f30 --- /dev/null +++ b/api/tests/unit_tests/core/plugin/impl/test_model_runtime_factory.py @@ -0,0 +1,36 @@ +from unittest.mock import Mock, patch + +from core.plugin.impl.model_runtime_factory import create_plugin_model_assembly + + +def test_plugin_model_assembly_reuses_single_runtime_across_views(): + runtime = Mock(name="runtime") + provider_factory = Mock(name="provider_factory") + provider_manager = Mock(name="provider_manager") + model_manager = Mock(name="model_manager") + + with ( + patch( + "core.plugin.impl.model_runtime_factory.create_plugin_model_runtime", + return_value=runtime, + ) as mock_runtime_factory, + patch( + "core.plugin.impl.model_runtime_factory.ModelProviderFactory", + return_value=provider_factory, + ) as mock_provider_factory_cls, + patch("core.provider_manager.ProviderManager", return_value=provider_manager) as mock_provider_manager_cls, + patch("core.model_manager.ModelManager", return_value=model_manager) as mock_model_manager_cls, + ): + assembly = create_plugin_model_assembly(tenant_id="tenant-1", user_id="user-1") + + assert assembly.model_provider_factory is provider_factory + assert assembly.provider_manager is provider_manager + assert assembly.model_manager is model_manager + assert assembly.model_provider_factory is provider_factory + assert assembly.provider_manager is provider_manager + assert assembly.model_manager is model_manager + + mock_runtime_factory.assert_called_once_with(tenant_id="tenant-1", user_id="user-1") + mock_provider_factory_cls.assert_called_once_with(model_runtime=runtime) + mock_provider_manager_cls.assert_called_once_with(model_runtime=runtime) + mock_model_manager_cls.assert_called_once_with(provider_manager=provider_manager) diff --git a/api/tests/unit_tests/core/plugin/test_backwards_invocation_model.py b/api/tests/unit_tests/core/plugin/test_backwards_invocation_model.py new file mode 100644 index 0000000000..c24d3ac012 --- /dev/null +++ b/api/tests/unit_tests/core/plugin/test_backwards_invocation_model.py @@ -0,0 +1,61 @@ +from types import SimpleNamespace +from unittest.mock import patch + +from core.plugin.backwards_invocation.model import PluginModelBackwardsInvocation +from core.plugin.entities.request import RequestInvokeSummary +from graphon.model_runtime.entities.message_entities import UserPromptMessage + + +def test_system_model_helpers_forward_user_id(): + with ( + patch( + "core.plugin.backwards_invocation.model.ModelInvocationUtils.get_max_llm_context_tokens", + return_value=4096, + ) as mock_max_tokens, + patch( + "core.plugin.backwards_invocation.model.ModelInvocationUtils.calculate_tokens", + return_value=7, + ) as mock_prompt_tokens, + ): + assert PluginModelBackwardsInvocation.get_system_model_max_tokens("tenant-1", user_id="user-1") == 4096 + assert ( + PluginModelBackwardsInvocation.get_prompt_tokens( + "tenant-1", + [UserPromptMessage(content="hello")], + user_id="user-1", + ) + == 7 + ) + + mock_max_tokens.assert_called_once_with(tenant_id="tenant-1", user_id="user-1") + mock_prompt_tokens.assert_called_once_with( + tenant_id="tenant-1", + prompt_messages=[UserPromptMessage(content="hello")], + user_id="user-1", + ) + + +def test_invoke_summary_uses_same_user_scope_for_token_helpers(): + tenant = SimpleNamespace(id="tenant-1") + payload = RequestInvokeSummary(text="short", instruction="keep it concise") + + with ( + patch.object( + PluginModelBackwardsInvocation, + "get_system_model_max_tokens", + return_value=100, + ) as mock_max_tokens, + patch.object( + PluginModelBackwardsInvocation, + "get_prompt_tokens", + return_value=10, + ) as mock_prompt_tokens, + ): + assert PluginModelBackwardsInvocation.invoke_summary("user-1", tenant, payload) == "short" + + mock_max_tokens.assert_called_once_with(tenant_id="tenant-1", user_id="user-1") + mock_prompt_tokens.assert_called_once_with( + tenant_id="tenant-1", + prompt_messages=[UserPromptMessage(content="short")], + user_id="user-1", + ) diff --git a/api/tests/unit_tests/core/plugin/test_model_runtime_adapter.py b/api/tests/unit_tests/core/plugin/test_model_runtime_adapter.py new file mode 100644 index 0000000000..68aa130518 --- /dev/null +++ b/api/tests/unit_tests/core/plugin/test_model_runtime_adapter.py @@ -0,0 +1,506 @@ +"""Unit tests for the plugin-backed model runtime adapter.""" + +import datetime +import uuid +from types import SimpleNamespace +from unittest.mock import Mock, sentinel + +import pytest + +from core.plugin.entities.plugin_daemon import PluginModelProviderEntity +from core.plugin.impl import model_runtime as model_runtime_module +from core.plugin.impl.model import PluginModelClient +from core.plugin.impl.model_runtime import TENANT_SCOPE_SCHEMA_CACHE_USER_ID, PluginModelRuntime +from core.plugin.impl.model_runtime_factory import create_plugin_model_runtime +from graphon.model_runtime.entities.common_entities import I18nObject +from graphon.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType +from graphon.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity + + +def _build_model_schema() -> AIModelEntity: + return AIModelEntity( + model="gpt-4o-mini", + label=I18nObject(en_US="GPT-4o mini"), + model_type=ModelType.LLM, + fetch_from=FetchFrom.PREDEFINED_MODEL, + model_properties={}, + ) + + +class TestPluginModelRuntime: + """Validate the adapter keeps plugin-specific routing out of the runtime port.""" + + def test_fetch_model_providers_returns_runtime_entities(self) -> None: + client = Mock(spec=PluginModelClient) + client.fetch_model_providers.return_value = [ + PluginModelProviderEntity( + id=uuid.uuid4().hex, + created_at=datetime.datetime.now(), + updated_at=datetime.datetime.now(), + provider="openai", + tenant_id="tenant", + plugin_unique_identifier="langgenius/openai/openai", + plugin_id="langgenius/openai", + declaration=ProviderEntity( + provider="openai", + label=I18nObject(en_US="OpenAI"), + supported_model_types=[], + configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL], + ), + ) + ] + runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client) + + providers = runtime.fetch_model_providers() + + assert len(providers) == 1 + assert providers[0].provider == "langgenius/openai/openai" + assert providers[0].provider_name == "openai" + assert providers[0].label.en_US == "OpenAI" + client.fetch_model_providers.assert_called_once_with("tenant") + + def test_fetch_model_providers_only_exposes_short_name_for_canonical_provider(self) -> None: + client = Mock(spec=PluginModelClient) + client.fetch_model_providers.return_value = [ + PluginModelProviderEntity( + id=uuid.uuid4().hex, + created_at=datetime.datetime.now(), + updated_at=datetime.datetime.now(), + provider="openai", + tenant_id="tenant", + plugin_unique_identifier="acme/openai/openai", + plugin_id="acme/openai", + declaration=ProviderEntity( + provider="openai", + label=I18nObject(en_US="Acme OpenAI"), + supported_model_types=[], + configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL], + ), + ), + PluginModelProviderEntity( + id=uuid.uuid4().hex, + created_at=datetime.datetime.now(), + updated_at=datetime.datetime.now(), + provider="openai", + tenant_id="tenant", + plugin_unique_identifier="langgenius/openai/openai", + plugin_id="langgenius/openai", + declaration=ProviderEntity( + provider="openai", + label=I18nObject(en_US="OpenAI"), + supported_model_types=[], + configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL], + ), + ), + ] + runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client) + + providers = runtime.fetch_model_providers() + + provider_aliases = {provider.provider: provider.provider_name for provider in providers} + assert provider_aliases["acme/openai/openai"] == "" + assert provider_aliases["langgenius/openai/openai"] == "openai" + + def test_fetch_model_providers_keeps_google_alias_on_canonical_gemini_provider(self) -> None: + client = Mock(spec=PluginModelClient) + client.fetch_model_providers.return_value = [ + PluginModelProviderEntity( + id=uuid.uuid4().hex, + created_at=datetime.datetime.now(), + updated_at=datetime.datetime.now(), + provider="google", + tenant_id="tenant", + plugin_unique_identifier="langgenius/gemini/google", + plugin_id="langgenius/gemini", + declaration=ProviderEntity( + provider="google", + label=I18nObject(en_US="Google"), + supported_model_types=[], + configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL], + ), + ) + ] + runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client) + + providers = runtime.fetch_model_providers() + + assert providers[0].provider == "langgenius/gemini/google" + assert providers[0].provider_name == "google" + + def test_validate_provider_credentials_resolves_plugin_fields(self) -> None: + client = Mock(spec=PluginModelClient) + runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client) + + runtime.validate_provider_credentials( + provider="langgenius/openai/openai", + credentials={"api_key": "secret"}, + ) + + client.validate_provider_credentials.assert_called_once_with( + tenant_id="tenant", + user_id="user", + plugin_id="langgenius/openai", + provider="openai", + credentials={"api_key": "secret"}, + ) + + def test_invoke_llm_resolves_plugin_fields(self) -> None: + client = Mock(spec=PluginModelClient) + client.invoke_llm.return_value = sentinel.result + runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client) + + result = runtime.invoke_llm( + provider="langgenius/openai/openai", + model="gpt-4o-mini", + credentials={"api_key": "secret"}, + model_parameters={"temperature": 0.3}, + prompt_messages=[], + tools=None, + stop=None, + stream=False, + ) + + assert result is sentinel.result + client.invoke_llm.assert_called_once_with( + tenant_id="tenant", + user_id="user", + plugin_id="langgenius/openai", + provider="openai", + model="gpt-4o-mini", + credentials={"api_key": "secret"}, + model_parameters={"temperature": 0.3}, + prompt_messages=[], + tools=None, + stop=None, + stream=False, + ) + + def test_invoke_llm_rejects_per_call_user_override(self) -> None: + client = Mock(spec=PluginModelClient) + client.invoke_llm.return_value = sentinel.result + runtime = PluginModelRuntime(tenant_id="tenant", user_id="bound-user", client=client) + + with pytest.raises(TypeError, match="unexpected keyword argument 'user_id'"): + runtime.invoke_llm( # type: ignore[call-arg] + provider="langgenius/openai/openai", + model="gpt-4o-mini", + credentials={"api_key": "secret"}, + model_parameters={"temperature": 0.3}, + prompt_messages=[], + tools=None, + stop=None, + stream=False, + user_id="request-user", + ) + + client.invoke_llm.assert_not_called() + + def test_invoke_tts_uses_bound_runtime_user_when_runtime_is_unbound(self) -> None: + client = Mock(spec=PluginModelClient) + client.invoke_tts.return_value = iter([b"chunk"]) + runtime = PluginModelRuntime(tenant_id="tenant", user_id=None, client=client) + + result = runtime.invoke_tts( + provider="langgenius/openai/openai", + model="tts-1", + credentials={"api_key": "secret"}, + content_text="hello", + voice="alloy", + ) + + assert list(result) == [b"chunk"] + client.invoke_tts.assert_called_once_with( + tenant_id="tenant", + user_id=None, + plugin_id="langgenius/openai", + provider="openai", + model="tts-1", + credentials={"api_key": "secret"}, + content_text="hello", + voice="alloy", + ) + + def test_fetch_model_providers_uses_bound_runtime_cache(self) -> None: + client = Mock(spec=PluginModelClient) + client.fetch_model_providers.return_value = [] + runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client) + + runtime.fetch_model_providers() + runtime.fetch_model_providers() + + client.fetch_model_providers.assert_called_once_with("tenant") + + +def test_create_plugin_model_runtime_without_user_context() -> None: + runtime = create_plugin_model_runtime(tenant_id="tenant") + + assert runtime.user_id is None + + +def test_plugin_model_runtime_requires_client() -> None: + with pytest.raises(ValueError, match="client is required"): + PluginModelRuntime(tenant_id="tenant", user_id="user", client=None) # type: ignore[arg-type] + + +def test_get_model_schema_uses_cached_schema_without_hitting_client(monkeypatch: pytest.MonkeyPatch) -> None: + client = Mock(spec=PluginModelClient) + schema = _build_model_schema() + monkeypatch.setattr( + model_runtime_module, + "redis_client", + SimpleNamespace( + get=Mock(return_value=schema.model_dump_json()), + delete=Mock(), + setex=Mock(), + ), + ) + + runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client) + result = runtime.get_model_schema( + provider="langgenius/openai/openai", + model_type=ModelType.LLM, + model="gpt-4o-mini", + credentials={"api_key": "secret"}, + ) + + assert result == schema + client.get_model_schema.assert_not_called() + + +def test_get_model_schema_deletes_invalid_cache_and_refetches(monkeypatch: pytest.MonkeyPatch) -> None: + client = Mock(spec=PluginModelClient) + schema = _build_model_schema() + delete = Mock() + setex = Mock() + monkeypatch.setattr( + model_runtime_module, + "redis_client", + SimpleNamespace( + get=Mock(return_value="not-json"), + delete=delete, + setex=setex, + ), + ) + monkeypatch.setattr(model_runtime_module.dify_config, "PLUGIN_MODEL_SCHEMA_CACHE_TTL", 300) + client.get_model_schema.return_value = schema + runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client) + + result = runtime.get_model_schema( + provider="langgenius/openai/openai", + model_type=ModelType.LLM, + model="gpt-4o-mini", + credentials={"api_key": "secret"}, + ) + + assert result == schema + delete.assert_called_once() + client.get_model_schema.assert_called_once_with( + tenant_id="tenant", + user_id="user", + plugin_id="langgenius/openai", + provider="openai", + model_type=ModelType.LLM.value, + model="gpt-4o-mini", + credentials={"api_key": "secret"}, + ) + setex.assert_called_once() + + +def test_get_llm_num_tokens_returns_zero_when_plugin_counting_is_disabled(monkeypatch: pytest.MonkeyPatch) -> None: + client = Mock(spec=PluginModelClient) + monkeypatch.setattr(model_runtime_module.dify_config, "PLUGIN_BASED_TOKEN_COUNTING_ENABLED", False) + runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client) + + assert ( + runtime.get_llm_num_tokens( + provider="langgenius/openai/openai", + model_type=ModelType.LLM, + model="gpt-4o-mini", + credentials={"api_key": "secret"}, + prompt_messages=[], + tools=None, + ) + == 0 + ) + client.get_llm_num_tokens.assert_not_called() + + +def test_get_provider_icon_reads_requested_variant_and_detects_svg_mime(monkeypatch: pytest.MonkeyPatch) -> None: + client = Mock(spec=PluginModelClient) + client.fetch_model_providers.return_value = [ + PluginModelProviderEntity( + id=uuid.uuid4().hex, + created_at=datetime.datetime.now(), + updated_at=datetime.datetime.now(), + provider="openai", + tenant_id="tenant", + plugin_unique_identifier="langgenius/openai/openai", + plugin_id="langgenius/openai", + declaration=ProviderEntity( + provider="openai", + label=I18nObject(en_US="OpenAI"), + icon_small=I18nObject(en_US="logo.svg"), + icon_small_dark=I18nObject(en_US="logo-dark.png"), + supported_model_types=[], + configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL], + ), + ) + ] + fetch_asset = Mock(return_value=b"") + monkeypatch.setattr(model_runtime_module.PluginAssetManager, "fetch_asset", fetch_asset) + runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client) + + icon_bytes, mime_type = runtime.get_provider_icon( + provider="langgenius/openai/openai", + icon_type="icon_small", + lang="en_US", + ) + + assert icon_bytes == b"" + assert mime_type == "image/svg+xml" + fetch_asset.assert_called_once_with(tenant_id="tenant", id="logo.svg") + + +def test_get_provider_icon_rejects_unsupported_types_and_missing_variants() -> None: + client = Mock(spec=PluginModelClient) + client.fetch_model_providers.return_value = [ + PluginModelProviderEntity( + id=uuid.uuid4().hex, + created_at=datetime.datetime.now(), + updated_at=datetime.datetime.now(), + provider="openai", + tenant_id="tenant", + plugin_unique_identifier="langgenius/openai/openai", + plugin_id="langgenius/openai", + declaration=ProviderEntity( + provider="openai", + label=I18nObject(en_US="OpenAI"), + supported_model_types=[], + configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL], + ), + ) + ] + runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client) + + with pytest.raises(ValueError, match="does not have small dark icon"): + runtime.get_provider_icon( + provider="langgenius/openai/openai", + icon_type="icon_small_dark", + lang="en_US", + ) + + with pytest.raises(ValueError, match="Unsupported icon type"): + runtime.get_provider_icon( + provider="langgenius/openai/openai", + icon_type="icon_large", + lang="en_US", + ) + + +def test_get_schema_cache_key_is_stable_across_credential_order() -> None: + runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=Mock(spec=PluginModelClient)) + + first = runtime._get_schema_cache_key( + provider="langgenius/openai/openai", + model_type=ModelType.LLM, + model="gpt-4o-mini", + credentials={"b": "2", "a": "1"}, + ) + second = runtime._get_schema_cache_key( + provider="langgenius/openai/openai", + model_type=ModelType.LLM, + model="gpt-4o-mini", + credentials={"a": "1", "b": "2"}, + ) + + assert first == second + + +def test_get_schema_cache_key_separates_distinct_user_scopes() -> None: + first_runtime = PluginModelRuntime(tenant_id="tenant", user_id="user-a", client=Mock(spec=PluginModelClient)) + second_runtime = PluginModelRuntime(tenant_id="tenant", user_id="user-b", client=Mock(spec=PluginModelClient)) + + first = first_runtime._get_schema_cache_key( + provider="langgenius/openai/openai", + model_type=ModelType.LLM, + model="gpt-4o-mini", + credentials={"a": "1"}, + ) + second = second_runtime._get_schema_cache_key( + provider="langgenius/openai/openai", + model_type=ModelType.LLM, + model="gpt-4o-mini", + credentials={"a": "1"}, + ) + + assert first != second + + +def test_get_schema_cache_key_separates_tenant_scope_from_user_scope() -> None: + tenant_runtime = PluginModelRuntime(tenant_id="tenant", user_id=None, client=Mock(spec=PluginModelClient)) + user_runtime = PluginModelRuntime(tenant_id="tenant", user_id="user-a", client=Mock(spec=PluginModelClient)) + + tenant_key = tenant_runtime._get_schema_cache_key( + provider="langgenius/openai/openai", + model_type=ModelType.LLM, + model="gpt-4o-mini", + credentials={"a": "1"}, + ) + user_key = user_runtime._get_schema_cache_key( + provider="langgenius/openai/openai", + model_type=ModelType.LLM, + model="gpt-4o-mini", + credentials={"a": "1"}, + ) + + assert tenant_key != user_key + assert f":{TENANT_SCOPE_SCHEMA_CACHE_USER_ID}" in tenant_key + + +def test_get_schema_cache_key_separates_tenant_scope_from_empty_string_user_scope() -> None: + tenant_runtime = PluginModelRuntime(tenant_id="tenant", user_id=None, client=Mock(spec=PluginModelClient)) + empty_user_runtime = PluginModelRuntime(tenant_id="tenant", user_id="", client=Mock(spec=PluginModelClient)) + + tenant_key = tenant_runtime._get_schema_cache_key( + provider="langgenius/openai/openai", + model_type=ModelType.LLM, + model="gpt-4o-mini", + credentials={}, + ) + empty_user_key = empty_user_runtime._get_schema_cache_key( + provider="langgenius/openai/openai", + model_type=ModelType.LLM, + model="gpt-4o-mini", + credentials={}, + ) + + assert tenant_key != empty_user_key + assert empty_user_key.endswith(":") + assert TENANT_SCOPE_SCHEMA_CACHE_USER_ID not in empty_user_key + + +def test_get_provider_schema_supports_short_alias_and_rejects_invalid_provider() -> None: + client = Mock(spec=PluginModelClient) + client.fetch_model_providers.return_value = [ + PluginModelProviderEntity( + id=uuid.uuid4().hex, + created_at=datetime.datetime.now(), + updated_at=datetime.datetime.now(), + provider="openai", + tenant_id="tenant", + plugin_unique_identifier="langgenius/openai/openai", + plugin_id="langgenius/openai", + declaration=ProviderEntity( + provider="openai", + label=I18nObject(en_US="OpenAI"), + supported_model_types=[], + configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL], + ), + ) + ] + runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client) + + assert runtime._get_provider_schema("openai").provider == "langgenius/openai/openai" + + with pytest.raises(ValueError, match="Invalid provider"): + runtime._get_provider_schema("missing") diff --git a/api/tests/unit_tests/core/plugin/test_plugin_entities.py b/api/tests/unit_tests/core/plugin/test_plugin_entities.py index b0b64a601b..f1c4c7e700 100644 --- a/api/tests/unit_tests/core/plugin/test_plugin_entities.py +++ b/api/tests/unit_tests/core/plugin/test_plugin_entities.py @@ -25,7 +25,7 @@ from core.plugin.entities.request import ( ) from core.plugin.utils.http_parser import serialize_response from core.tools.entities.common_entities import I18nObject -from dify_graph.model_runtime.entities.message_entities import ( +from graphon.model_runtime.entities.message_entities import ( AssistantPromptMessage, SystemPromptMessage, ToolPromptMessage, diff --git a/api/tests/unit_tests/core/plugin/test_plugin_runtime.py b/api/tests/unit_tests/core/plugin/test_plugin_runtime.py index 4f038d4a5b..eae9d9459e 100644 --- a/api/tests/unit_tests/core/plugin/test_plugin_runtime.py +++ b/api/tests/unit_tests/core/plugin/test_plugin_runtime.py @@ -36,14 +36,14 @@ from core.plugin.impl.exc import ( ) from core.plugin.impl.plugin import PluginInstaller from core.plugin.impl.tool import PluginToolManager -from dify_graph.model_runtime.errors.invoke import ( +from graphon.model_runtime.errors.invoke import ( InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError, InvokeRateLimitError, InvokeServerUnavailableError, ) -from dify_graph.model_runtime.errors.validate import CredentialsValidateFailedError +from graphon.model_runtime.errors.validate import CredentialsValidateFailedError class TestPluginRuntimeExecution: diff --git a/api/tests/unit_tests/core/plugin/utils/test_chunk_merger.py b/api/tests/unit_tests/core/plugin/utils/test_chunk_merger.py index c7e94aa4cf..4d4313dd84 100644 --- a/api/tests/unit_tests/core/plugin/utils/test_chunk_merger.py +++ b/api/tests/unit_tests/core/plugin/utils/test_chunk_merger.py @@ -6,8 +6,8 @@ from core.agent.entities import AgentInvokeMessage from core.plugin.utils.chunk_merger import FileChunk, merge_blob_chunks from core.plugin.utils.converter import convert_parameters_to_plugin_format from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter, ToolSelector -from dify_graph.file.enums import FileTransferMethod, FileType -from dify_graph.file.models import File +from graphon.file.enums import FileTransferMethod, FileType +from graphon.file.models import File class TestChunkMerger: diff --git a/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py index 3d08525aba..395d392127 100644 --- a/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py +++ b/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py @@ -9,8 +9,8 @@ from core.memory.token_buffer_memory import TokenBufferMemory from core.prompt.advanced_prompt_transform import AdvancedPromptTransform from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig from core.prompt.utils.prompt_template_parser import PromptTemplateParser -from dify_graph.file import File, FileTransferMethod, FileType -from dify_graph.model_runtime.entities.message_entities import ( +from graphon.file import File, FileTransferMethod, FileType +from graphon.model_runtime.entities.message_entities import ( AssistantPromptMessage, ImagePromptMessageContent, PromptMessageRole, @@ -145,7 +145,7 @@ def test__get_chat_model_prompt_messages_with_files_no_memory(get_chat_model_arg prompt_transform = AdvancedPromptTransform() prompt_transform._calculate_rest_token = MagicMock(return_value=2000) - with patch("dify_graph.file.file_manager.to_prompt_message_content", autospec=True) as mock_get_encoded_string: + with patch("graphon.file.file_manager.to_prompt_message_content", autospec=True) as mock_get_encoded_string: mock_get_encoded_string.return_value = ImagePromptMessageContent( url=str(files[0].remote_url), format="jpg", mime_type="image/jpg" ) diff --git a/api/tests/unit_tests/core/prompt/test_agent_history_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_agent_history_prompt_transform.py index 634703740c..803afa54d7 100644 --- a/api/tests/unit_tests/core/prompt/test_agent_history_prompt_transform.py +++ b/api/tests/unit_tests/core/prompt/test_agent_history_prompt_transform.py @@ -6,13 +6,13 @@ from core.app.entities.app_invoke_entities import ( from core.entities.provider_configuration import ProviderModelBundle from core.memory.token_buffer_memory import TokenBufferMemory from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform -from dify_graph.model_runtime.entities.message_entities import ( +from graphon.model_runtime.entities.message_entities import ( AssistantPromptMessage, SystemPromptMessage, ToolPromptMessage, UserPromptMessage, ) -from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from models.model import Conversation diff --git a/api/tests/unit_tests/core/prompt/test_prompt_message.py b/api/tests/unit_tests/core/prompt/test_prompt_message.py index 9fc300348a..5d865d934c 100644 --- a/api/tests/unit_tests/core/prompt/test_prompt_message.py +++ b/api/tests/unit_tests/core/prompt/test_prompt_message.py @@ -1,6 +1,6 @@ from core.prompt.simple_prompt_transform import ModelMode from core.prompt.utils.prompt_message_util import PromptMessageUtil -from dify_graph.model_runtime.entities.message_entities import ( +from graphon.model_runtime.entities.message_entities import ( AssistantPromptMessage, AudioPromptMessageContent, ImagePromptMessageContent, diff --git a/api/tests/unit_tests/core/prompt/test_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_prompt_transform.py index d379e3067a..9f9ea33695 100644 --- a/api/tests/unit_tests/core/prompt/test_prompt_transform.py +++ b/api/tests/unit_tests/core/prompt/test_prompt_transform.py @@ -4,14 +4,14 @@ from unittest.mock import MagicMock, patch import pytest from core.prompt.prompt_transform import PromptTransform -from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey +from graphon.model_runtime.entities.model_entities import ModelPropertyKey # from core.app.app_config.entities import ModelConfigEntity # from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle -# from dify_graph.model_runtime.entities.message_entities import UserPromptMessage -# from dify_graph.model_runtime.entities.model_entities import AIModelEntity, ModelPropertyKey, ParameterRule -# from dify_graph.model_runtime.entities.provider_entities import ProviderEntity -# from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +# from graphon.model_runtime.entities.message_entities import UserPromptMessage +# from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelPropertyKey, ParameterRule +# from graphon.model_runtime.entities.provider_entities import ProviderEntity +# from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel # from core.prompt.prompt_transform import PromptTransform diff --git a/api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py index e6d28224d7..0dc74b33df 100644 --- a/api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py +++ b/api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py @@ -18,7 +18,7 @@ from core.prompt.prompt_templates.advanced_prompt_templates import ( CONTEXT, ) from core.prompt.simple_prompt_transform import SimplePromptTransform -from dify_graph.model_runtime.entities.message_entities import ( +from graphon.model_runtime.entities.message_entities import ( AssistantPromptMessage, ImagePromptMessageContent, TextPromptMessageContent, diff --git a/api/tests/unit_tests/core/rag/cleaner/test_clean_processor.py b/api/tests/unit_tests/core/rag/cleaner/test_clean_processor.py index 65ee62b8dd..c7a4265a95 100644 --- a/api/tests/unit_tests/core/rag/cleaner/test_clean_processor.py +++ b/api/tests/unit_tests/core/rag/cleaner/test_clean_processor.py @@ -211,3 +211,16 @@ class TestCleanProcessor: text = "[Text with (parens) and symbols](https://example.com)" expected = "[Text with (parens) and symbols](https://example.com)" assert CleanProcessor.clean(text, process_rule) == expected + + def test_clean_remove_urls_emails_preserves_markdown_image_links(self): + """Remove plain URLs and emails while preserving markdown image links.""" + process_rule = {"rules": {"pre_processing_rules": [{"id": "remove_urls_emails", "enabled": True}]}} + text = "Email test@example.com and remove https://remove.com but keep ![diagram](https://example.com/image.png)" + result = CleanProcessor.clean(text, process_rule) + + assert result == "Email and remove but keep ![diagram](https://example.com/image.png)" + + def test_filter_string_returns_input_text(self): + """Test filter_string passthrough behavior.""" + processor = CleanProcessor() + assert processor.filter_string("raw text") == "raw text" diff --git a/api/tests/unit_tests/core/rag/data_post_processor/test_data_post_processor.py b/api/tests/unit_tests/core/rag/data_post_processor/test_data_post_processor.py new file mode 100644 index 0000000000..1f3247590c --- /dev/null +++ b/api/tests/unit_tests/core/rag/data_post_processor/test_data_post_processor.py @@ -0,0 +1,246 @@ +from unittest.mock import MagicMock, patch + +from core.rag.data_post_processor.data_post_processor import DataPostProcessor +from core.rag.data_post_processor.reorder import ReorderRunner +from core.rag.index_processor.constant.query_type import QueryType +from core.rag.models.document import Document +from core.rag.rerank.rerank_type import RerankMode +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError + + +def _doc(content: str) -> Document: + return Document(page_content=content) + + +class TestDataPostProcessor: + def test_init_sets_rerank_and_reorder_runners(self): + rerank_runner = object() + reorder_runner = object() + + with patch.object(DataPostProcessor, "_get_rerank_runner", return_value=rerank_runner) as rerank_mock: + with patch.object(DataPostProcessor, "_get_reorder_runner", return_value=reorder_runner) as reorder_mock: + processor = DataPostProcessor( + tenant_id="tenant-1", + reranking_mode=RerankMode.WEIGHTED_SCORE, + reranking_model={"config": "value"}, + weights={"weight": "value"}, + reorder_enabled=True, + ) + + assert processor.rerank_runner is rerank_runner + assert processor.reorder_runner is reorder_runner + rerank_mock.assert_called_once_with( + RerankMode.WEIGHTED_SCORE, + "tenant-1", + {"config": "value"}, + {"weight": "value"}, + ) + reorder_mock.assert_called_once_with(True) + + def test_invoke_applies_rerank_then_reorder(self): + original_documents = [_doc("doc-a")] + reranked_documents = [_doc("doc-b")] + reordered_documents = [_doc("doc-c")] + + processor = DataPostProcessor.__new__(DataPostProcessor) + processor.rerank_runner = MagicMock() + processor.rerank_runner.run.return_value = reranked_documents + processor.reorder_runner = MagicMock() + processor.reorder_runner.run.return_value = reordered_documents + + result = processor.invoke( + query="how to test", + documents=original_documents, + score_threshold=0.3, + top_n=2, + query_type=QueryType.IMAGE_QUERY, + ) + + processor.rerank_runner.run.assert_called_once_with( + "how to test", + original_documents, + 0.3, + 2, + QueryType.IMAGE_QUERY, + ) + processor.reorder_runner.run.assert_called_once_with(reranked_documents) + assert result == reordered_documents + + def test_invoke_returns_original_documents_when_no_runner_is_configured(self): + documents = [_doc("doc-a"), _doc("doc-b")] + + processor = DataPostProcessor.__new__(DataPostProcessor) + processor.rerank_runner = None + processor.reorder_runner = None + + assert processor.invoke(query="query", documents=documents) == documents + + def test_get_rerank_runner_for_weighted_score(self): + weights_config = { + "vector_setting": { + "vector_weight": 0.7, + "embedding_provider_name": "provider-x", + "embedding_model_name": "embedding-y", + }, + "keyword_setting": {"keyword_weight": 0.3}, + } + expected_runner = object() + processor = DataPostProcessor.__new__(DataPostProcessor) + + with patch( + "core.rag.data_post_processor.data_post_processor.RerankRunnerFactory.create_rerank_runner", + return_value=expected_runner, + ) as factory_mock: + result = processor._get_rerank_runner( + reranking_mode=RerankMode.WEIGHTED_SCORE, + tenant_id="tenant-1", + reranking_model=None, + weights=weights_config, + ) + + assert result is expected_runner + kwargs = factory_mock.call_args.kwargs + assert kwargs["runner_type"] == RerankMode.WEIGHTED_SCORE + assert kwargs["tenant_id"] == "tenant-1" + assert kwargs["weights"].vector_setting.vector_weight == 0.7 + assert kwargs["weights"].vector_setting.embedding_provider_name == "provider-x" + assert kwargs["weights"].vector_setting.embedding_model_name == "embedding-y" + assert kwargs["weights"].keyword_setting.keyword_weight == 0.3 + + def test_get_rerank_runner_for_reranking_model_returns_none_without_model_instance(self): + processor = DataPostProcessor.__new__(DataPostProcessor) + reranking_model = { + "reranking_provider_name": "provider-x", + "reranking_model_name": "model-y", + } + + with patch.object(DataPostProcessor, "_get_rerank_model_instance", return_value=None) as model_mock: + with patch( + "core.rag.data_post_processor.data_post_processor.RerankRunnerFactory.create_rerank_runner" + ) as factory_mock: + result = processor._get_rerank_runner( + reranking_mode=RerankMode.RERANKING_MODEL, + tenant_id="tenant-1", + reranking_model=reranking_model, + weights=None, + ) + + assert result is None + model_mock.assert_called_once_with("tenant-1", reranking_model) + factory_mock.assert_not_called() + + def test_get_rerank_runner_for_reranking_model_creates_runner_with_model_instance(self): + processor = DataPostProcessor.__new__(DataPostProcessor) + model_instance = object() + expected_runner = object() + + with patch.object(DataPostProcessor, "_get_rerank_model_instance", return_value=model_instance): + with patch( + "core.rag.data_post_processor.data_post_processor.RerankRunnerFactory.create_rerank_runner", + return_value=expected_runner, + ) as factory_mock: + result = processor._get_rerank_runner( + reranking_mode=RerankMode.RERANKING_MODEL, + tenant_id="tenant-1", + reranking_model={ + "reranking_provider_name": "provider-x", + "reranking_model_name": "model-y", + }, + weights=None, + ) + + assert result is expected_runner + factory_mock.assert_called_once_with( + runner_type=RerankMode.RERANKING_MODEL, + rerank_model_instance=model_instance, + ) + + def test_get_rerank_runner_returns_none_for_unsupported_mode(self): + processor = DataPostProcessor.__new__(DataPostProcessor) + + assert processor._get_rerank_runner("unsupported", "tenant-1", None, None) is None + assert processor._get_rerank_runner(RerankMode.WEIGHTED_SCORE, "tenant-1", None, None) is None + + def test_get_reorder_runner_by_flag(self): + processor = DataPostProcessor.__new__(DataPostProcessor) + + assert isinstance(processor._get_reorder_runner(True), ReorderRunner) + assert processor._get_reorder_runner(False) is None + + def test_get_rerank_model_instance_returns_none_when_config_is_missing(self): + processor = DataPostProcessor.__new__(DataPostProcessor) + assert processor._get_rerank_model_instance("tenant-1", None) is None + + def test_get_rerank_model_instance_returns_none_for_incomplete_config(self): + processor = DataPostProcessor.__new__(DataPostProcessor) + + with patch("core.rag.data_post_processor.data_post_processor.ModelManager.for_tenant") as for_tenant_mock: + result = processor._get_rerank_model_instance( + tenant_id="tenant-1", + reranking_model={"reranking_provider_name": "provider-x"}, + ) + + assert result is None + for_tenant_mock.assert_called_once_with(tenant_id="tenant-1") + + def test_get_rerank_model_instance_success(self): + processor = DataPostProcessor.__new__(DataPostProcessor) + model_instance = object() + + with patch("core.rag.data_post_processor.data_post_processor.ModelManager.for_tenant") as for_tenant_mock: + manager_instance = for_tenant_mock.return_value + manager_instance.get_model_instance.return_value = model_instance + + result = processor._get_rerank_model_instance( + tenant_id="tenant-1", + reranking_model={ + "reranking_provider_name": "provider-x", + "reranking_model_name": "reranker-1", + }, + ) + + assert result is model_instance + for_tenant_mock.assert_called_once_with(tenant_id="tenant-1") + manager_instance.get_model_instance.assert_called_once_with( + tenant_id="tenant-1", + provider="provider-x", + model_type=ModelType.RERANK, + model="reranker-1", + ) + + def test_get_rerank_model_instance_handles_authorization_error(self): + processor = DataPostProcessor.__new__(DataPostProcessor) + + with patch("core.rag.data_post_processor.data_post_processor.ModelManager.for_tenant") as for_tenant_mock: + manager_instance = for_tenant_mock.return_value + manager_instance.get_model_instance.side_effect = InvokeAuthorizationError("not authorized") + + result = processor._get_rerank_model_instance( + tenant_id="tenant-1", + reranking_model={ + "reranking_provider_name": "provider-x", + "reranking_model_name": "reranker-1", + }, + ) + + assert result is None + for_tenant_mock.assert_called_once_with(tenant_id="tenant-1") + + +class TestReorderRunner: + def test_run_reorders_even_sized_document_list(self): + documents = [_doc("0"), _doc("1"), _doc("2"), _doc("3"), _doc("4"), _doc("5")] + + reordered = ReorderRunner().run(documents) + + assert [document.page_content for document in reordered] == ["0", "2", "4", "5", "3", "1"] + + def test_run_handles_odd_sized_and_empty_document_lists(self): + odd_documents = [_doc("0"), _doc("1"), _doc("2"), _doc("3"), _doc("4")] + runner = ReorderRunner() + + odd_reordered = runner.run(odd_documents) + + assert [document.page_content for document in odd_reordered] == ["0", "2", "4", "3", "1"] + assert runner.run([]) == [] diff --git a/api/tests/unit_tests/core/rag/datasource/keyword/jieba/test_jieba.py b/api/tests/unit_tests/core/rag/datasource/keyword/jieba/test_jieba.py new file mode 100644 index 0000000000..795a325a6b --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/keyword/jieba/test_jieba.py @@ -0,0 +1,414 @@ +import json +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest + +import core.rag.datasource.keyword.jieba.jieba as jieba_module +from core.rag.datasource.keyword.jieba.jieba import Jieba, dumps_with_sets, set_orjson_default +from core.rag.models.document import Document + + +class _DummyLock: + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + +class _Field: + def __init__(self, name: str): + self._name = name + + def __eq__(self, other): + return ("eq", self._name, other) + + def in_(self, values): + return ("in", self._name, tuple(values)) + + +class _FakeQuery: + def __init__(self): + self.where_calls: list[tuple] = [] + + def where(self, *conditions): + self.where_calls.append(conditions) + return self + + +class _FakeExecuteResult: + def __init__(self, segments: list[SimpleNamespace]): + self._segments = segments + + def scalars(self): + return self + + def all(self): + return self._segments + + +class _FakeSelect: + def __init__(self): + self.where_conditions: tuple | None = None + + def where(self, *conditions): + self.where_conditions = conditions + return self + + +def _dataset_keyword_table(data_source_type: str = "database", keyword_table_dict: dict | None = None): + return SimpleNamespace( + data_source_type=data_source_type, + keyword_table_dict=keyword_table_dict, + keyword_table="", + ) + + +def _dataset(dataset_keyword_table=None, keyword_number=None): + return SimpleNamespace( + id="dataset-1", + tenant_id="tenant-1", + keyword_number=keyword_number, + dataset_keyword_table=dataset_keyword_table, + ) + + +@pytest.fixture +def patched_runtime(monkeypatch): + session = MagicMock() + db = SimpleNamespace(session=session) + storage = MagicMock() + lock = MagicMock(return_value=_DummyLock()) + redis_client = SimpleNamespace(lock=lock) + + monkeypatch.setattr(jieba_module, "db", db) + monkeypatch.setattr(jieba_module, "storage", storage) + monkeypatch.setattr(jieba_module, "redis_client", redis_client) + + return SimpleNamespace(session=session, storage=storage, lock=lock) + + +def test_create_indexes_documents_and_returns_self(monkeypatch, patched_runtime): + dataset = _dataset(_dataset_keyword_table(), keyword_number=2) + keyword = Jieba(dataset) + handler = MagicMock() + handler.extract_keywords.return_value = {"kw1", "kw2"} + + monkeypatch.setattr(jieba_module, "JiebaKeywordTableHandler", lambda: handler) + monkeypatch.setattr(keyword, "_get_dataset_keyword_table", MagicMock(return_value={})) + monkeypatch.setattr(keyword, "_update_segment_keywords", MagicMock()) + monkeypatch.setattr(keyword, "_save_dataset_keyword_table", MagicMock()) + + result = keyword.create( + [ + Document(page_content="alpha", metadata={"doc_id": "node-1"}), + SimpleNamespace(page_content="ignored", metadata=None), + ] + ) + + assert result is keyword + keyword._update_segment_keywords.assert_called_once() + call_args = keyword._update_segment_keywords.call_args.args + assert call_args[0] == "dataset-1" + assert call_args[1] == "node-1" + assert set(call_args[2]) == {"kw1", "kw2"} + saved_table = keyword._save_dataset_keyword_table.call_args.args[0] + assert saved_table["kw1"] == {"node-1"} + assert saved_table["kw2"] == {"node-1"} + patched_runtime.lock.assert_called_once_with("keyword_indexing_lock_dataset-1", timeout=600) + + +def test_add_texts_supports_keywords_list_and_extract_fallback(monkeypatch, patched_runtime): + keyword = Jieba(_dataset(_dataset_keyword_table(), keyword_number=3)) + handler = MagicMock() + handler.extract_keywords.return_value = {"auto"} + + monkeypatch.setattr(jieba_module, "JiebaKeywordTableHandler", lambda: handler) + monkeypatch.setattr(keyword, "_get_dataset_keyword_table", MagicMock(return_value={})) + monkeypatch.setattr(keyword, "_update_segment_keywords", MagicMock()) + monkeypatch.setattr(keyword, "_save_dataset_keyword_table", MagicMock()) + + texts = [ + Document(page_content="extract-this", metadata={"doc_id": "node-1"}), + Document(page_content="use-manual", metadata={"doc_id": "node-2"}), + ] + keyword.add_texts(texts, keywords_list=[[], ["manual"]]) + + assert keyword._update_segment_keywords.call_count == 2 + first_call = keyword._update_segment_keywords.call_args_list[0].args + second_call = keyword._update_segment_keywords.call_args_list[1].args + assert set(first_call[2]) == {"auto"} + assert second_call[2] == ["manual"] + keyword._save_dataset_keyword_table.assert_called_once() + + +def test_add_texts_without_keywords_list_always_uses_extractor(monkeypatch, patched_runtime): + keyword = Jieba(_dataset(_dataset_keyword_table(), keyword_number=1)) + handler = MagicMock() + handler.extract_keywords.return_value = {"from-extractor"} + + monkeypatch.setattr(jieba_module, "JiebaKeywordTableHandler", lambda: handler) + monkeypatch.setattr(keyword, "_get_dataset_keyword_table", MagicMock(return_value={})) + monkeypatch.setattr(keyword, "_update_segment_keywords", MagicMock()) + monkeypatch.setattr(keyword, "_save_dataset_keyword_table", MagicMock()) + + keyword.add_texts([Document(page_content="content", metadata={"doc_id": "node-1"})]) + + handler.extract_keywords.assert_called_once_with("content", 1) + assert set(keyword._update_segment_keywords.call_args.args[2]) == {"from-extractor"} + + +def test_text_exists_handles_missing_and_existing_keyword_table(monkeypatch): + keyword = Jieba(_dataset(_dataset_keyword_table())) + + monkeypatch.setattr(keyword, "_get_dataset_keyword_table", MagicMock(return_value=None)) + assert keyword.text_exists("node-1") is False + + monkeypatch.setattr(keyword, "_get_dataset_keyword_table", MagicMock(return_value={"k": {"node-1", "node-2"}})) + assert keyword.text_exists("node-2") is True + assert keyword.text_exists("node-x") is False + + +def test_delete_by_ids_updates_table_when_present(monkeypatch, patched_runtime): + keyword = Jieba(_dataset(_dataset_keyword_table())) + monkeypatch.setattr(keyword, "_get_dataset_keyword_table", MagicMock(return_value={"k": {"node-1", "node-2"}})) + monkeypatch.setattr(keyword, "_delete_ids_from_keyword_table", MagicMock(return_value={"k": {"node-2"}})) + monkeypatch.setattr(keyword, "_save_dataset_keyword_table", MagicMock()) + + keyword.delete_by_ids(["node-1"]) + + keyword._delete_ids_from_keyword_table.assert_called_once_with({"k": {"node-1", "node-2"}}, ["node-1"]) + keyword._save_dataset_keyword_table.assert_called_once_with({"k": {"node-2"}}) + + +def test_delete_by_ids_saves_none_when_keyword_table_is_missing(monkeypatch, patched_runtime): + keyword = Jieba(_dataset(_dataset_keyword_table())) + monkeypatch.setattr(keyword, "_get_dataset_keyword_table", MagicMock(return_value=None)) + monkeypatch.setattr(keyword, "_delete_ids_from_keyword_table", MagicMock()) + monkeypatch.setattr(keyword, "_save_dataset_keyword_table", MagicMock()) + + keyword.delete_by_ids(["node-1"]) + + keyword._delete_ids_from_keyword_table.assert_not_called() + keyword._save_dataset_keyword_table.assert_called_once_with(None) + + +def test_search_returns_documents_in_rank_order_and_applies_filter(monkeypatch, patched_runtime): + class _FakeDocumentSegment: + dataset_id = _Field("dataset_id") + index_node_id = _Field("index_node_id") + document_id = _Field("document_id") + + keyword = Jieba(_dataset(_dataset_keyword_table())) + query_stmt = _FakeQuery() + patched_runtime.session.query.return_value = query_stmt + patched_runtime.session.execute.return_value = _FakeExecuteResult( + [ + SimpleNamespace( + index_node_id="node-2", + content="segment-content", + index_node_hash="hash-2", + document_id="doc-2", + dataset_id="dataset-1", + ) + ] + ) + + monkeypatch.setattr(jieba_module, "DocumentSegment", _FakeDocumentSegment) + monkeypatch.setattr(keyword, "_get_dataset_keyword_table", MagicMock(return_value={"k": {"node-1", "node-2"}})) + monkeypatch.setattr(keyword, "_retrieve_ids_by_query", MagicMock(return_value=["node-1", "node-2"])) + + documents = keyword.search("query", top_k=2, document_ids_filter=["doc-2"]) + + assert len(query_stmt.where_calls) == 2 + assert len(documents) == 1 + assert documents[0].page_content == "segment-content" + assert documents[0].metadata["doc_id"] == "node-2" + assert documents[0].metadata["doc_hash"] == "hash-2" + + +def test_delete_removes_keyword_table_and_optional_file(monkeypatch, patched_runtime): + db_keyword = _dataset_keyword_table(data_source_type="database") + file_keyword = _dataset_keyword_table(data_source_type="object_storage") + + keyword_db = Jieba(_dataset(db_keyword)) + keyword_db.delete() + patched_runtime.storage.delete.assert_not_called() + + keyword_file = Jieba(_dataset(file_keyword)) + keyword_file.delete() + + patched_runtime.storage.delete.assert_called_once_with("keyword_files/tenant-1/dataset-1.txt") + assert patched_runtime.session.delete.call_count == 2 + assert patched_runtime.session.commit.call_count == 2 + + +def test_save_dataset_keyword_table_to_database(monkeypatch, patched_runtime): + dataset_keyword_table = _dataset_keyword_table(data_source_type="database") + keyword = Jieba(_dataset(dataset_keyword_table)) + + keyword._save_dataset_keyword_table({"kw": {"node-1"}}) + + assert '"__type__":"keyword_table"' in dataset_keyword_table.keyword_table + assert '"index_id":"dataset-1"' in dataset_keyword_table.keyword_table + patched_runtime.session.commit.assert_called_once() + + +def test_save_dataset_keyword_table_to_file_storage(monkeypatch, patched_runtime): + dataset_keyword_table = _dataset_keyword_table(data_source_type="file") + keyword = Jieba(_dataset(dataset_keyword_table)) + patched_runtime.storage.exists.return_value = True + + keyword._save_dataset_keyword_table({"kw": {"node-1"}}) + + patched_runtime.storage.delete.assert_called_once_with("keyword_files/tenant-1/dataset-1.txt") + patched_runtime.storage.save.assert_called_once() + save_args = patched_runtime.storage.save.call_args.args + assert save_args[0] == "keyword_files/tenant-1/dataset-1.txt" + assert isinstance(save_args[1], bytes) + + +def test_get_dataset_keyword_table_returns_existing_table_data(monkeypatch, patched_runtime): + existing = _dataset_keyword_table( + keyword_table_dict={"__type__": "keyword_table", "__data__": {"table": {"kw": ["node-1"]}}} + ) + keyword = Jieba(_dataset(existing)) + assert keyword._get_dataset_keyword_table() == {"kw": ["node-1"]} + + missing_payload = _dataset_keyword_table(keyword_table_dict=None) + keyword_with_missing_payload = Jieba(_dataset(missing_payload)) + assert keyword_with_missing_payload._get_dataset_keyword_table() == {} + + +def test_get_dataset_keyword_table_creates_table_when_missing(monkeypatch, patched_runtime): + created_tables: list[SimpleNamespace] = [] + + def _fake_dataset_keyword_table(**kwargs): + kwargs.setdefault("keyword_table", "") + kwargs.setdefault("keyword_table_dict", None) + table = SimpleNamespace(**kwargs) + created_tables.append(table) + return table + + keyword = Jieba(_dataset(dataset_keyword_table=None)) + monkeypatch.setattr(jieba_module, "DatasetKeywordTable", _fake_dataset_keyword_table) + monkeypatch.setattr(jieba_module.dify_config, "KEYWORD_DATA_SOURCE_TYPE", "database") + + result = keyword._get_dataset_keyword_table() + + assert result == {} + assert len(created_tables) == 1 + assert created_tables[0].dataset_id == "dataset-1" + assert created_tables[0].data_source_type == "database" + assert '"index_id":"dataset-1"' in created_tables[0].keyword_table + patched_runtime.session.add.assert_called_once_with(created_tables[0]) + patched_runtime.session.commit.assert_called_once() + + +def test_add_and_delete_ids_from_keyword_table_helpers(): + keyword = Jieba(_dataset(_dataset_keyword_table())) + keyword_table = {"kw1": {"node-1"}, "kw2": {"node-1", "node-2"}} + + updated = keyword._add_text_to_keyword_table(keyword_table, "node-3", ["kw1", "kw3"]) + assert updated["kw1"] == {"node-1", "node-3"} + assert updated["kw3"] == {"node-3"} + + deleted = keyword._delete_ids_from_keyword_table(updated, ["node-1", "node-3"]) + assert "kw3" not in deleted + assert "kw1" not in deleted + assert deleted["kw2"] == {"node-2"} + + +def test_retrieve_ids_by_query_ranks_by_keyword_frequency(monkeypatch): + keyword = Jieba(_dataset(_dataset_keyword_table())) + handler = MagicMock() + handler.extract_keywords.return_value = ["kw-a", "kw-b"] + monkeypatch.setattr(jieba_module, "JiebaKeywordTableHandler", lambda: handler) + + ranked_ids = keyword._retrieve_ids_by_query( + {"kw-a": {"node-1", "node-2"}, "kw-b": {"node-2"}, "kw-c": {"node-3"}}, + "query", + k=1, + ) + + assert ranked_ids == ["node-2"] + + +def test_update_segment_keywords_updates_when_segment_exists(monkeypatch, patched_runtime): + class _FakeDocumentSegment: + dataset_id = _Field("dataset_id") + index_node_id = _Field("index_node_id") + + monkeypatch.setattr(jieba_module, "DocumentSegment", _FakeDocumentSegment) + monkeypatch.setattr(jieba_module, "select", lambda *_: _FakeSelect()) + + keyword = Jieba(_dataset(_dataset_keyword_table())) + segment = SimpleNamespace(keywords=[]) + patched_runtime.session.scalar.return_value = segment + + keyword._update_segment_keywords("dataset-1", "node-1", ["kw1", "kw2"]) + + assert segment.keywords == ["kw1", "kw2"] + patched_runtime.session.add.assert_called_once_with(segment) + patched_runtime.session.commit.assert_called_once() + + patched_runtime.session.reset_mock() + patched_runtime.session.scalar.return_value = None + + keyword._update_segment_keywords("dataset-1", "node-missing", ["kw3"]) + + patched_runtime.session.add.assert_not_called() + patched_runtime.session.commit.assert_not_called() + + +def test_create_segment_keywords_and_update_segment_keywords_index(monkeypatch): + keyword = Jieba(_dataset(_dataset_keyword_table())) + monkeypatch.setattr(keyword, "_get_dataset_keyword_table", MagicMock(return_value={})) + monkeypatch.setattr(keyword, "_update_segment_keywords", MagicMock()) + monkeypatch.setattr(keyword, "_save_dataset_keyword_table", MagicMock()) + + keyword.create_segment_keywords("node-1", ["kw"]) + keyword._update_segment_keywords.assert_called_once_with("dataset-1", "node-1", ["kw"]) + keyword._save_dataset_keyword_table.assert_called_once() + + keyword._save_dataset_keyword_table.reset_mock() + keyword.update_segment_keywords_index("node-2", ["kw2"]) + keyword._save_dataset_keyword_table.assert_called_once() + + +def test_multi_create_segment_keywords_uses_provided_and_extracted_keywords(monkeypatch): + keyword = Jieba(_dataset(_dataset_keyword_table(), keyword_number=2)) + handler = MagicMock() + handler.extract_keywords.return_value = {"auto"} + monkeypatch.setattr(jieba_module, "JiebaKeywordTableHandler", lambda: handler) + monkeypatch.setattr(keyword, "_get_dataset_keyword_table", MagicMock(return_value={})) + monkeypatch.setattr(keyword, "_save_dataset_keyword_table", MagicMock()) + + first_segment = SimpleNamespace(index_node_id="node-1", content="first content", keywords=None) + second_segment = SimpleNamespace(index_node_id="node-2", content="second content", keywords=None) + + keyword.multi_create_segment_keywords( + [ + {"segment": first_segment, "keywords": ["manual"]}, + {"segment": second_segment, "keywords": []}, + ] + ) + + assert first_segment.keywords == ["manual"] + assert second_segment.keywords == ["auto"] + saved_table = keyword._save_dataset_keyword_table.call_args.args[0] + assert saved_table["manual"] == {"node-1"} + assert saved_table["auto"] == {"node-2"} + + +def test_set_orjson_default_and_dumps_with_sets(): + assert set(set_orjson_default({"a", "b"})) == {"a", "b"} + + with pytest.raises(TypeError, match="is not JSON serializable"): + set_orjson_default(("not", "a", "set")) + + payload = {"items": {"a", "b"}} + json_payload = dumps_with_sets(payload) + decoded = json.loads(json_payload) + assert set(decoded["items"]) == {"a", "b"} diff --git a/api/tests/unit_tests/core/rag/datasource/keyword/jieba/test_jieba_keyword_table_handler.py b/api/tests/unit_tests/core/rag/datasource/keyword/jieba/test_jieba_keyword_table_handler.py new file mode 100644 index 0000000000..a4586c141b --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/keyword/jieba/test_jieba_keyword_table_handler.py @@ -0,0 +1,142 @@ +import sys +import types +from types import SimpleNamespace + +from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler +from core.rag.datasource.keyword.jieba.stopwords import STOPWORDS + + +class _DummyTFIDF: + def __init__(self): + self.stop_words = set() + + @staticmethod + def extract_tags(sentence: str, top_k: int | None = 20, **kwargs): + return ["alpha_beta", "during", "gamma"] + + +def _install_fake_jieba_modules( + monkeypatch, + analyse_module: types.ModuleType, + jieba_attrs: dict[str, object] | None = None, + tfidf_module: types.ModuleType | None = None, +): + jieba_module = types.ModuleType("jieba") + jieba_module.__path__ = [] + if jieba_attrs: + for key, value in jieba_attrs.items(): + setattr(jieba_module, key, value) + + jieba_module.analyse = analyse_module + analyse_module.__package__ = "jieba" + + monkeypatch.setitem(sys.modules, "jieba", jieba_module) + monkeypatch.setitem(sys.modules, "jieba.analyse", analyse_module) + if tfidf_module is not None: + monkeypatch.setitem(sys.modules, "jieba.analyse.tfidf", tfidf_module) + else: + monkeypatch.delitem(sys.modules, "jieba.analyse.tfidf", raising=False) + + +def test_init_uses_existing_default_tfidf(monkeypatch): + analyse_module = types.ModuleType("jieba.analyse") + default_tfidf = _DummyTFIDF() + analyse_module.default_tfidf = default_tfidf + + _install_fake_jieba_modules(monkeypatch, analyse_module) + + handler = JiebaKeywordTableHandler() + + assert handler._tfidf is default_tfidf + assert handler._tfidf.stop_words == STOPWORDS + + +def test_load_tfidf_extractor_uses_tfidf_class_and_caches_default(monkeypatch): + analyse_module = types.ModuleType("jieba.analyse") + analyse_module.default_tfidf = None + + class _TFIDFFactory(_DummyTFIDF): + pass + + analyse_module.TFIDF = _TFIDFFactory + _install_fake_jieba_modules(monkeypatch, analyse_module) + + handler = JiebaKeywordTableHandler() + + assert isinstance(handler._tfidf, _TFIDFFactory) + assert analyse_module.default_tfidf is handler._tfidf + + +def test_load_tfidf_extractor_imports_from_tfidf_submodule(monkeypatch): + analyse_module = types.ModuleType("jieba.analyse") + analyse_module.default_tfidf = None + + tfidf_module = types.ModuleType("jieba.analyse.tfidf") + + class _ImportedTFIDF(_DummyTFIDF): + pass + + tfidf_module.TFIDF = _ImportedTFIDF + _install_fake_jieba_modules(monkeypatch, analyse_module, tfidf_module=tfidf_module) + + handler = JiebaKeywordTableHandler() + + assert isinstance(handler._tfidf, _ImportedTFIDF) + assert analyse_module.default_tfidf is handler._tfidf + + +def test_load_tfidf_extractor_falls_back_when_tfidf_unavailable(monkeypatch): + analyse_module = types.ModuleType("jieba.analyse") + analyse_module.default_tfidf = None + _install_fake_jieba_modules(monkeypatch, analyse_module) + + handler = JiebaKeywordTableHandler() + fallback_keywords = handler._tfidf.extract_tags("one two two and three", topK=1) + + assert fallback_keywords == ["two"] + + +def test_build_fallback_tfidf_uses_lcut_when_available(monkeypatch): + analyse_module = types.ModuleType("jieba.analyse") + _install_fake_jieba_modules(monkeypatch, analyse_module, jieba_attrs={"lcut": lambda _: ["x", "x", "y"]}) + + tfidf = JiebaKeywordTableHandler._build_fallback_tfidf() + + assert tfidf.extract_tags("ignored", topK=1) == ["x"] + + +def test_build_fallback_tfidf_uses_cut_when_lcut_is_missing(monkeypatch): + analyse_module = types.ModuleType("jieba.analyse") + _install_fake_jieba_modules( + monkeypatch, + analyse_module, + jieba_attrs={"cut": lambda _: iter(["foo", "foo", "bar"])}, + ) + + tfidf = JiebaKeywordTableHandler._build_fallback_tfidf() + + assert tfidf.extract_tags("ignored", topK=1) == ["foo"] + + +def test_extract_keywords_expands_subtokens(): + handler = JiebaKeywordTableHandler.__new__(JiebaKeywordTableHandler) + handler._tfidf = SimpleNamespace(extract_tags=lambda *_args, **_kwargs: ["alpha-beta", "during", "gamma"]) + + keywords = handler.extract_keywords("input text", max_keywords_per_chunk=3) + + assert "alpha-beta" in keywords + assert "alpha" in keywords + assert "beta" in keywords + assert "during" in keywords + assert "gamma" in keywords + + +def test_expand_tokens_with_subtokens_filters_stopwords_from_subtokens(): + handler = JiebaKeywordTableHandler.__new__(JiebaKeywordTableHandler) + + expanded = handler._expand_tokens_with_subtokens({"alpha-during-beta"}) + + assert "alpha-during-beta" in expanded + assert "alpha" in expanded + assert "beta" in expanded + assert "during" not in expanded diff --git a/api/tests/unit_tests/core/rag/datasource/keyword/jieba/test_stopwords.py b/api/tests/unit_tests/core/rag/datasource/keyword/jieba/test_stopwords.py new file mode 100644 index 0000000000..1b1541ddd6 --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/keyword/jieba/test_stopwords.py @@ -0,0 +1,6 @@ +from core.rag.datasource.keyword.jieba.stopwords import STOPWORDS + + +def test_stopwords_loaded(): + assert "during" in STOPWORDS + assert "the" in STOPWORDS diff --git a/api/tests/unit_tests/core/rag/datasource/keyword/test_keyword_base.py b/api/tests/unit_tests/core/rag/datasource/keyword/test_keyword_base.py new file mode 100644 index 0000000000..55e22aea0a --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/keyword/test_keyword_base.py @@ -0,0 +1,97 @@ +from types import SimpleNamespace + +import pytest + +from core.rag.datasource.keyword.keyword_base import BaseKeyword +from core.rag.models.document import Document + + +class _KeywordThatRaises(BaseKeyword): + def create(self, texts: list[Document], **kwargs): + return super().create(texts, **kwargs) + + def add_texts(self, texts: list[Document], **kwargs): + return super().add_texts(texts, **kwargs) + + def text_exists(self, id: str) -> bool: + return super().text_exists(id) + + def delete_by_ids(self, ids: list[str]): + return super().delete_by_ids(ids) + + def delete(self): + return super().delete() + + def search(self, query: str, **kwargs): + return super().search(query, **kwargs) + + +class _KeywordForHelpers(BaseKeyword): + def __init__(self, dataset, existing_ids: set[str] | None = None): + super().__init__(dataset) + self._existing_ids = existing_ids or set() + + def create(self, texts: list[Document], **kwargs): + return self + + def add_texts(self, texts: list[Document], **kwargs): + return None + + def text_exists(self, id: str) -> bool: + return id in self._existing_ids + + def delete_by_ids(self, ids: list[str]): + return None + + def delete(self): + return None + + def search(self, query: str, **kwargs): + return [] + + +def test_abstract_methods_raise_not_implemented(): + keyword = _KeywordThatRaises(SimpleNamespace(id="dataset-1")) + + with pytest.raises(NotImplementedError): + keyword.create([]) + + with pytest.raises(NotImplementedError): + keyword.add_texts([]) + + with pytest.raises(NotImplementedError): + keyword.text_exists("doc-1") + + with pytest.raises(NotImplementedError): + keyword.delete_by_ids(["doc-1"]) + + with pytest.raises(NotImplementedError): + keyword.delete() + + with pytest.raises(NotImplementedError): + keyword.search("query") + + +def test_filter_duplicate_texts_removes_existing_doc_ids(): + keyword = _KeywordForHelpers(SimpleNamespace(id="dataset-1"), existing_ids={"duplicate"}) + texts = [ + Document(page_content="keep", metadata={"doc_id": "keep"}), + Document(page_content="duplicate", metadata={"doc_id": "duplicate"}), + SimpleNamespace(page_content="without-metadata", metadata=None), + ] + + filtered = keyword._filter_duplicate_texts(texts) + + assert [text.metadata["doc_id"] for text in filtered if text.metadata] == ["keep"] + assert any(text.metadata is None for text in filtered) + + +def test_get_uuids_returns_only_docs_with_metadata(): + keyword = _KeywordForHelpers(SimpleNamespace(id="dataset-1")) + texts = [ + Document(page_content="doc-1", metadata={"doc_id": "doc-1"}), + Document(page_content="doc-2", metadata={"doc_id": "doc-2"}), + SimpleNamespace(page_content="doc-3", metadata=None), + ] + + assert keyword._get_uuids(texts) == ["doc-1", "doc-2"] diff --git a/api/tests/unit_tests/core/rag/datasource/keyword/test_keyword_factory.py b/api/tests/unit_tests/core/rag/datasource/keyword/test_keyword_factory.py new file mode 100644 index 0000000000..0d969a3270 --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/keyword/test_keyword_factory.py @@ -0,0 +1,84 @@ +import sys +import types +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest + +from core.rag.datasource.keyword.keyword_factory import Keyword +from core.rag.datasource.keyword.keyword_type import KeyWordType +from core.rag.models.document import Document + + +def test_get_keyword_factory_returns_jieba_factory(monkeypatch): + fake_module = types.ModuleType("core.rag.datasource.keyword.jieba.jieba") + + class FakeJieba: + pass + + fake_module.Jieba = FakeJieba + monkeypatch.setitem(sys.modules, "core.rag.datasource.keyword.jieba.jieba", fake_module) + + assert Keyword.get_keyword_factory(KeyWordType.JIEBA) is FakeJieba + + +def test_get_keyword_factory_raises_for_unsupported_type(): + with pytest.raises(ValueError, match="Keyword store unsupported is not supported"): + Keyword.get_keyword_factory("unsupported") + + +def test_keyword_initialization_uses_configured_factory(monkeypatch): + dataset = SimpleNamespace(id="dataset-1") + fake_processor = MagicMock() + + monkeypatch.setattr("core.rag.datasource.keyword.keyword_factory.dify_config.KEYWORD_STORE", KeyWordType.JIEBA) + monkeypatch.setattr(Keyword, "get_keyword_factory", staticmethod(lambda keyword_type: lambda _: fake_processor)) + + keyword = Keyword(dataset) + + assert keyword._keyword_processor is fake_processor + + +def test_keyword_methods_forward_to_processor(): + processor = MagicMock() + processor.text_exists.return_value = True + processor.search.return_value = [Document(page_content="matched", metadata={"doc_id": "doc-1"})] + + keyword = Keyword.__new__(Keyword) + keyword._keyword_processor = processor + + docs = [Document(page_content="doc", metadata={"doc_id": "doc-1"})] + keyword.create(docs, foo="bar") + keyword.add_texts(docs, batch=True) + assert keyword.text_exists("doc-1") is True + keyword.delete_by_ids(["doc-1"]) + keyword.delete() + assert keyword.search("query", top_k=1) == processor.search.return_value + + processor.create.assert_called_once_with(docs, foo="bar") + processor.add_texts.assert_called_once_with(docs, batch=True) + processor.text_exists.assert_called_once_with("doc-1") + processor.delete_by_ids.assert_called_once_with(["doc-1"]) + processor.delete.assert_called_once() + processor.search.assert_called_once_with("query", top_k=1) + + +def test_keyword_getattr_returns_callable_and_raises_for_invalid_attributes(): + class Processor: + value = 1 + + @staticmethod + def custom(): + return "ok" + + keyword = Keyword.__new__(Keyword) + keyword._keyword_processor = Processor() + + assert keyword.custom() == "ok" + + with pytest.raises(AttributeError): + _ = keyword.value + + keyword._keyword_processor = None + with pytest.raises(AttributeError): + _ = keyword.missing_method diff --git a/api/tests/unit_tests/core/rag/datasource/test_datasource_retrieval.py b/api/tests/unit_tests/core/rag/datasource/test_datasource_retrieval.py new file mode 100644 index 0000000000..63de4b8af2 --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/test_datasource_retrieval.py @@ -0,0 +1,1176 @@ +from types import SimpleNamespace +from unittest.mock import MagicMock, Mock, call, patch +from uuid import uuid4 + +import pytest + +from core.rag.datasource import retrieval_service as retrieval_service_module +from core.rag.datasource.retrieval_service import RetrievalService +from core.rag.index_processor.constant.doc_type import DocType +from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.constant.query_type import QueryType +from core.rag.models.document import Document +from core.rag.rerank.rerank_type import RerankMode +from core.rag.retrieval.retrieval_methods import RetrievalMethod +from models.dataset import Dataset + + +def create_mock_document( + content: str, + doc_id: str, + score: float = 0.8, + provider: str = "dify", + additional_metadata: dict | None = None, +) -> Document: + """ + Create a mock Document object for testing. + + This helper function standardizes document creation across tests, + ensuring consistent structure and reducing code duplication. + + Args: + content: The text content of the document + doc_id: Unique identifier for the document chunk + score: Relevance score (0.0 to 1.0) + provider: Document provider ("dify" or "external") + additional_metadata: Optional extra metadata fields + + Returns: + Document: A properly structured Document object + + Example: + >>> doc = create_mock_document("Python is great", "doc1", score=0.95) + >>> assert doc.metadata["score"] == 0.95 + """ + metadata = { + "doc_id": doc_id, + "document_id": str(uuid4()), + "dataset_id": str(uuid4()), + "score": score, + } + + # Merge additional metadata if provided + if additional_metadata: + metadata.update(additional_metadata) + + return Document( + page_content=content, + metadata=metadata, + provider=provider, + ) + + +class _ImmediateFuture: + def __init__(self, exception: Exception | None = None) -> None: + self._exception = exception + self.cancel_called = False + + def exception(self) -> Exception | None: + return self._exception + + def cancel(self) -> None: + self.cancel_called = True + + +class _ImmediateExecutor: + def __init__(self) -> None: + self.futures: list[_ImmediateFuture] = [] + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb) -> bool: + return False + + def submit(self, fn, *args, **kwargs): + try: + fn(*args, **kwargs) + future = _ImmediateFuture() + except Exception as exc: # pragma: no cover - only for defensive parity with Future semantics + future = _ImmediateFuture(exc) + self.futures.append(future) + return future + + +class _FakeExecuteScalarResult: + def __init__(self, data: list) -> None: + self._data = data + + def all(self) -> list: + return self._data + + +class _FakeExecuteResult: + def __init__(self, data: list) -> None: + self._data = data + + def scalars(self) -> _FakeExecuteScalarResult: + return _FakeExecuteScalarResult(self._data) + + +class _FakeSummaryQuery: + def __init__(self, summaries: list) -> None: + self._summaries = summaries + + def filter(self, *args, **kwargs): + return self + + def all(self) -> list: + return self._summaries + + +class _FakeSession: + def __init__(self, execute_payloads: list[list], summaries: list) -> None: + self._payloads = list(execute_payloads) + self._summaries = summaries + + def execute(self, stmt): + data = self._payloads.pop(0) if self._payloads else [] + return _FakeExecuteResult(data) + + def query(self, model): + return _FakeSummaryQuery(self._summaries) + + +class _FakeSessionContext: + def __init__(self, session: _FakeSession) -> None: + self._session = session + + def __enter__(self) -> _FakeSession: + return self._session + + def __exit__(self, exc_type, exc, tb) -> bool: + return False + + +class _SimpleRetrievalChildChunk: + def __init__(self, id: str, content: str, score: float, position: int) -> None: + self.id = id + self.content = content + self.score = score + self.position = position + + +class _SimpleRetrievalSegment: + def __init__( + self, + segment, + child_chunks: list[_SimpleRetrievalChildChunk] | None = None, + score: float | None = None, + files: list[dict[str, str | int]] | None = None, + summary: str | None = None, + ) -> None: + self.segment = segment + self.child_chunks = child_chunks + self.score = score + self.files = files + self.summary = summary + + +class TestRetrievalServiceInternals: + @pytest.fixture + def internal_dataset(self) -> Dataset: + dataset = Mock(spec=Dataset) + dataset.id = "dataset-id" + dataset.tenant_id = "tenant-id" + dataset.is_multimodal = False + dataset.doc_form = IndexStructureType.PARENT_CHILD_INDEX + return dataset + + @pytest.fixture + def internal_flask_app(self): + app = MagicMock() + app.app_context.return_value.__enter__ = Mock() + app.app_context.return_value.__exit__.return_value = False + return app + + def test_retrieve_with_attachment_ids_only(self, monkeypatch, internal_dataset): + with ( + patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset", return_value=internal_dataset), + patch("core.rag.datasource.retrieval_service.RetrievalService._retrieve") as mock_retrieve, + ): + executor = _ImmediateExecutor() + monkeypatch.setattr(retrieval_service_module, "ThreadPoolExecutor", lambda *args, **kwargs: executor) + monkeypatch.setattr( + retrieval_service_module.concurrent.futures, + "as_completed", + lambda futures, timeout=None: iter(futures), + ) + + def side_effect( + flask_app, + retrieval_method, + dataset, + all_documents, + exceptions, + query=None, + top_k=4, + score_threshold=0.0, + reranking_model=None, + reranking_mode="reranking_model", + weights=None, + document_ids_filter=None, + attachment_id=None, + ): + all_documents.append(create_mock_document(f"content-{attachment_id}", attachment_id or "none", 0.9)) + + mock_retrieve.side_effect = side_effect + + results = RetrievalService.retrieve( + retrieval_method=RetrievalMethod.SEMANTIC_SEARCH, + dataset_id=internal_dataset.id, + query="", + attachment_ids=["att-1", "att-2"], + ) + + assert len(results) == 2 + assert {doc.metadata["doc_id"] for doc in results} == {"att-1", "att-2"} + assert mock_retrieve.call_count == 2 + + @patch("core.rag.datasource.retrieval_service.ExternalDatasetService.fetch_external_knowledge_retrieval") + @patch("core.rag.datasource.retrieval_service.MetadataCondition.model_validate") + @patch("core.rag.datasource.retrieval_service.db.session.scalar") + def test_external_retrieve_with_metadata_conditions(self, mock_scalar, mock_validate, mock_fetch): + mock_scalar.return_value = SimpleNamespace(tenant_id="tenant-1") + mock_validate.return_value = "validated-condition" + expected_documents = [create_mock_document("external-doc", "external-1", 0.8, provider="external")] + mock_fetch.return_value = expected_documents + + results = RetrievalService.external_retrieve( + dataset_id="dataset-1", + query="test query", + external_retrieval_model={"top_k": 3}, + metadata_filtering_conditions={"field": "source", "operator": "contains", "value": "manual"}, + ) + + assert results == expected_documents + mock_validate.assert_called_once() + mock_fetch.assert_called_once_with( + "tenant-1", + "dataset-1", + "test query", + {"top_k": 3}, + metadata_condition="validated-condition", + ) + + @patch("core.rag.datasource.retrieval_service.db.session.scalar") + def test_external_retrieve_returns_empty_when_dataset_not_found(self, mock_scalar): + mock_scalar.return_value = None + + results = RetrievalService.external_retrieve(dataset_id="missing", query="q") + + assert results == [] + + @patch("core.rag.datasource.retrieval_service.Session") + def test_get_dataset_queries_by_id(self, mock_session_class): + expected_dataset = Mock(spec=Dataset) + mock_session = Mock() + mock_session.query.return_value.where.return_value.first.return_value = expected_dataset + mock_session_class.return_value.__enter__.return_value = mock_session + + with patch.object(retrieval_service_module, "db", SimpleNamespace(engine=Mock())): + result = RetrievalService._get_dataset("dataset-123") + + assert result == expected_dataset + mock_session.query.assert_called_once() + + @patch("core.rag.datasource.retrieval_service.Keyword") + @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset") + def test_keyword_search_success(self, mock_get_dataset, mock_keyword_class, internal_dataset, internal_flask_app): + mock_get_dataset.return_value = internal_dataset + keyword_instance = Mock() + keyword_instance.search.return_value = [create_mock_document("keyword-content", "kw-1", 0.91)] + mock_keyword_class.return_value = keyword_instance + all_documents: list[Document] = [] + exceptions: list[str] = [] + + RetrievalService.keyword_search( + flask_app=internal_flask_app, + dataset_id=internal_dataset.id, + query='query "with quotes"', + top_k=5, + all_documents=all_documents, + exceptions=exceptions, + ) + + assert len(all_documents) == 1 + assert exceptions == [] + keyword_instance.search.assert_called_once() + + @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset") + def test_keyword_search_appends_exception_when_dataset_missing(self, mock_get_dataset, internal_flask_app): + mock_get_dataset.return_value = None + all_documents: list[Document] = [] + exceptions: list[str] = [] + + RetrievalService.keyword_search( + flask_app=internal_flask_app, + dataset_id="dataset-id", + query="query", + top_k=2, + all_documents=all_documents, + exceptions=exceptions, + ) + + assert all_documents == [] + assert exceptions == ["dataset not found"] + + @patch("core.rag.datasource.retrieval_service.Keyword") + @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset") + def test_keyword_search_appends_exception_when_search_fails( + self, mock_get_dataset, mock_keyword_class, internal_dataset, internal_flask_app + ): + mock_get_dataset.return_value = internal_dataset + keyword_instance = Mock() + keyword_instance.search.side_effect = RuntimeError("keyword failed") + mock_keyword_class.return_value = keyword_instance + all_documents: list[Document] = [] + exceptions: list[str] = [] + + RetrievalService.keyword_search( + flask_app=internal_flask_app, + dataset_id=internal_dataset.id, + query="query", + top_k=2, + all_documents=all_documents, + exceptions=exceptions, + ) + + assert all_documents == [] + assert exceptions == ["keyword failed"] + + @patch("core.rag.datasource.retrieval_service.Vector") + @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset") + def test_embedding_search_text_without_reranking( + self, mock_get_dataset, mock_vector_class, internal_dataset, internal_flask_app + ): + internal_dataset.is_multimodal = False + mock_get_dataset.return_value = internal_dataset + vector_instance = Mock() + vector_instance.search_by_vector.return_value = [create_mock_document("vector-content", "vec-1", 0.7)] + mock_vector_class.return_value = vector_instance + all_documents: list[Document] = [] + exceptions: list[str] = [] + + RetrievalService.embedding_search( + flask_app=internal_flask_app, + dataset_id=internal_dataset.id, + query="query", + top_k=4, + score_threshold=0.5, + reranking_model=None, + all_documents=all_documents, + retrieval_method=RetrievalMethod.SEMANTIC_SEARCH, + exceptions=exceptions, + document_ids_filter=["doc-1"], + query_type=QueryType.TEXT_QUERY, + ) + + assert len(all_documents) == 1 + assert exceptions == [] + vector_instance.search_by_vector.assert_called_once() + + @patch("core.rag.datasource.retrieval_service.Vector") + @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset") + def test_embedding_search_image_non_multimodal_returns_early( + self, mock_get_dataset, mock_vector_class, internal_dataset, internal_flask_app + ): + internal_dataset.is_multimodal = False + mock_get_dataset.return_value = internal_dataset + vector_instance = Mock() + mock_vector_class.return_value = vector_instance + all_documents: list[Document] = [] + exceptions: list[str] = [] + + RetrievalService.embedding_search( + flask_app=internal_flask_app, + dataset_id=internal_dataset.id, + query="file-1", + top_k=4, + score_threshold=0.5, + reranking_model=None, + all_documents=all_documents, + retrieval_method=RetrievalMethod.SEMANTIC_SEARCH, + exceptions=exceptions, + query_type=QueryType.IMAGE_QUERY, + ) + + assert all_documents == [] + assert exceptions == [] + vector_instance.search_by_file.assert_not_called() + + @patch("core.rag.datasource.retrieval_service.ModelManager.for_tenant") + @patch("core.rag.datasource.retrieval_service.DataPostProcessor") + @patch("core.rag.datasource.retrieval_service.Vector") + @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset") + def test_embedding_search_image_multimodal_with_vision_reranking( + self, + mock_get_dataset, + mock_vector_class, + mock_processor_class, + mock_model_manager_class, + internal_dataset, + internal_flask_app, + ): + internal_dataset.is_multimodal = True + mock_get_dataset.return_value = internal_dataset + original_docs = [create_mock_document("image-content", "img-doc", 0.73)] + reranked_docs = [create_mock_document("image-content-reranked", "img-doc", 0.97)] + + vector_instance = Mock() + vector_instance.search_by_file.return_value = original_docs + mock_vector_class.return_value = vector_instance + + processor_instance = Mock() + processor_instance.invoke.return_value = reranked_docs + mock_processor_class.return_value = processor_instance + + model_manager = Mock() + model_manager.check_model_support_vision.return_value = True + mock_model_manager_class.return_value = model_manager + + all_documents: list[Document] = [] + exceptions: list[str] = [] + + RetrievalService.embedding_search( + flask_app=internal_flask_app, + dataset_id=internal_dataset.id, + query="file-id", + top_k=4, + score_threshold=0.5, + reranking_model={ + "reranking_provider_name": "provider", + "reranking_model_name": "model", + }, + all_documents=all_documents, + retrieval_method=RetrievalMethod.SEMANTIC_SEARCH, + exceptions=exceptions, + query_type=QueryType.IMAGE_QUERY, + ) + + assert all_documents == reranked_docs + assert exceptions == [] + processor_instance.invoke.assert_called_once() + mock_model_manager_class.assert_called_once_with(tenant_id=internal_dataset.tenant_id) + model_manager.check_model_support_vision.assert_called_once() + + @patch("core.rag.datasource.retrieval_service.ModelManager.for_tenant") + @patch("core.rag.datasource.retrieval_service.DataPostProcessor") + @patch("core.rag.datasource.retrieval_service.Vector") + @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset") + def test_embedding_search_image_multimodal_without_vision_support( + self, + mock_get_dataset, + mock_vector_class, + mock_processor_class, + mock_model_manager_class, + internal_dataset, + internal_flask_app, + ): + internal_dataset.is_multimodal = True + mock_get_dataset.return_value = internal_dataset + original_docs = [create_mock_document("image-content", "img-doc", 0.73)] + + vector_instance = Mock() + vector_instance.search_by_file.return_value = original_docs + mock_vector_class.return_value = vector_instance + + processor_instance = Mock() + processor_instance.invoke.return_value = [create_mock_document("unused", "unused", 0.1)] + mock_processor_class.return_value = processor_instance + + model_manager = Mock() + model_manager.check_model_support_vision.return_value = False + mock_model_manager_class.return_value = model_manager + + all_documents: list[Document] = [] + exceptions: list[str] = [] + + RetrievalService.embedding_search( + flask_app=internal_flask_app, + dataset_id=internal_dataset.id, + query="file-id", + top_k=4, + score_threshold=0.5, + reranking_model={ + "reranking_provider_name": "provider", + "reranking_model_name": "model", + }, + all_documents=all_documents, + retrieval_method=RetrievalMethod.SEMANTIC_SEARCH, + exceptions=exceptions, + query_type=QueryType.IMAGE_QUERY, + ) + + assert all_documents == original_docs + assert exceptions == [] + mock_model_manager_class.assert_called_once_with(tenant_id=internal_dataset.tenant_id) + processor_instance.invoke.assert_not_called() + + @patch("core.rag.datasource.retrieval_service.DataPostProcessor") + @patch("core.rag.datasource.retrieval_service.Vector") + @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset") + def test_embedding_search_text_with_reranking_non_multimodal( + self, mock_get_dataset, mock_vector_class, mock_processor_class, internal_dataset, internal_flask_app + ): + internal_dataset.is_multimodal = False + mock_get_dataset.return_value = internal_dataset + original_docs = [create_mock_document("vector-content", "vec-doc", 0.62)] + reranked_docs = [create_mock_document("vector-content-reranked", "vec-doc", 0.89)] + + vector_instance = Mock() + vector_instance.search_by_vector.return_value = original_docs + mock_vector_class.return_value = vector_instance + + processor_instance = Mock() + processor_instance.invoke.return_value = reranked_docs + mock_processor_class.return_value = processor_instance + + all_documents: list[Document] = [] + exceptions: list[str] = [] + + RetrievalService.embedding_search( + flask_app=internal_flask_app, + dataset_id=internal_dataset.id, + query="query", + top_k=4, + score_threshold=0.5, + reranking_model={ + "reranking_provider_name": "provider", + "reranking_model_name": "model", + }, + all_documents=all_documents, + retrieval_method=RetrievalMethod.SEMANTIC_SEARCH, + exceptions=exceptions, + query_type=QueryType.TEXT_QUERY, + ) + + assert all_documents == reranked_docs + assert exceptions == [] + processor_instance.invoke.assert_called_once() + + @patch("core.rag.datasource.retrieval_service.Vector") + @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset") + def test_embedding_search_appends_exception_when_vector_fails( + self, mock_get_dataset, mock_vector_class, internal_dataset, internal_flask_app + ): + mock_get_dataset.return_value = internal_dataset + vector_instance = Mock() + vector_instance.search_by_vector.side_effect = RuntimeError("vector failed") + mock_vector_class.return_value = vector_instance + all_documents: list[Document] = [] + exceptions: list[str] = [] + + RetrievalService.embedding_search( + flask_app=internal_flask_app, + dataset_id=internal_dataset.id, + query="query", + top_k=4, + score_threshold=0.5, + reranking_model=None, + all_documents=all_documents, + retrieval_method=RetrievalMethod.SEMANTIC_SEARCH, + exceptions=exceptions, + query_type=QueryType.TEXT_QUERY, + ) + + assert all_documents == [] + assert exceptions == ["vector failed"] + + @patch("core.rag.datasource.retrieval_service.Vector") + @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset") + def test_full_text_index_search_without_reranking( + self, mock_get_dataset, mock_vector_class, internal_dataset, internal_flask_app + ): + mock_get_dataset.return_value = internal_dataset + vector_instance = Mock() + vector_instance.search_by_full_text.return_value = [create_mock_document("fulltext", "ft-1", 0.68)] + mock_vector_class.return_value = vector_instance + all_documents: list[Document] = [] + exceptions: list[str] = [] + + RetrievalService.full_text_index_search( + flask_app=internal_flask_app, + dataset_id=internal_dataset.id, + query='query "x"', + top_k=4, + score_threshold=0.4, + reranking_model=None, + all_documents=all_documents, + retrieval_method=RetrievalMethod.FULL_TEXT_SEARCH, + exceptions=exceptions, + ) + + assert len(all_documents) == 1 + assert exceptions == [] + vector_instance.search_by_full_text.assert_called_once() + + @patch("core.rag.datasource.retrieval_service.DataPostProcessor") + @patch("core.rag.datasource.retrieval_service.Vector") + @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset") + def test_full_text_index_search_with_reranking( + self, mock_get_dataset, mock_vector_class, mock_processor_class, internal_dataset, internal_flask_app + ): + mock_get_dataset.return_value = internal_dataset + original_docs = [create_mock_document("fulltext", "ft-1", 0.68)] + reranked_docs = [create_mock_document("fulltext-reranked", "ft-1", 0.9)] + + vector_instance = Mock() + vector_instance.search_by_full_text.return_value = original_docs + mock_vector_class.return_value = vector_instance + + processor_instance = Mock() + processor_instance.invoke.return_value = reranked_docs + mock_processor_class.return_value = processor_instance + + all_documents: list[Document] = [] + exceptions: list[str] = [] + + RetrievalService.full_text_index_search( + flask_app=internal_flask_app, + dataset_id=internal_dataset.id, + query="query", + top_k=4, + score_threshold=0.4, + reranking_model={ + "reranking_provider_name": "provider", + "reranking_model_name": "model", + }, + all_documents=all_documents, + retrieval_method=RetrievalMethod.FULL_TEXT_SEARCH, + exceptions=exceptions, + ) + + assert all_documents == reranked_docs + assert exceptions == [] + processor_instance.invoke.assert_called_once() + + @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset") + def test_full_text_index_search_dataset_not_found(self, mock_get_dataset, internal_flask_app): + mock_get_dataset.return_value = None + all_documents: list[Document] = [] + exceptions: list[str] = [] + + RetrievalService.full_text_index_search( + flask_app=internal_flask_app, + dataset_id="dataset-id", + query="query", + top_k=4, + score_threshold=0.4, + reranking_model=None, + all_documents=all_documents, + retrieval_method=RetrievalMethod.FULL_TEXT_SEARCH, + exceptions=exceptions, + ) + + assert all_documents == [] + assert exceptions == ["dataset not found"] + + @patch("core.rag.datasource.retrieval_service.Vector") + @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset") + def test_full_text_index_search_appends_exception_when_search_fails( + self, mock_get_dataset, mock_vector_class, internal_dataset, internal_flask_app + ): + mock_get_dataset.return_value = internal_dataset + vector_instance = Mock() + vector_instance.search_by_full_text.side_effect = RuntimeError("fulltext failed") + mock_vector_class.return_value = vector_instance + all_documents: list[Document] = [] + exceptions: list[str] = [] + + RetrievalService.full_text_index_search( + flask_app=internal_flask_app, + dataset_id=internal_dataset.id, + query="query", + top_k=4, + score_threshold=0.4, + reranking_model=None, + all_documents=all_documents, + retrieval_method=RetrievalMethod.FULL_TEXT_SEARCH, + exceptions=exceptions, + ) + + assert all_documents == [] + assert exceptions == ["fulltext failed"] + + def test_format_retrieval_documents_with_empty_input_returns_empty_list(self): + assert RetrievalService.format_retrieval_documents([]) == [] + + def test_format_retrieval_documents_without_document_id_returns_empty_list(self): + documents = [Document(page_content="content", metadata={"doc_id": "doc-1", "score": 0.4}, provider="dify")] + + assert RetrievalService.format_retrieval_documents(documents) == [] + + def test_format_retrieval_documents_with_parent_child_summary_and_attachments(self, monkeypatch): + dataset_doc_parent = SimpleNamespace( + id="doc-parent", + doc_form=IndexStructureType.PARENT_CHILD_INDEX, + dataset_id="dataset-id", + ) + dataset_doc_text = SimpleNamespace(id="doc-text", doc_form="paragraph", dataset_id="dataset-id") + dataset_doc_parent_summary = SimpleNamespace( + id="doc-parent-summary", + doc_form=IndexStructureType.PARENT_CHILD_INDEX, + dataset_id="dataset-id", + ) + + dataset_query = Mock() + dataset_query.where.return_value.options.return_value.all.return_value = [ + dataset_doc_parent, + dataset_doc_text, + dataset_doc_parent_summary, + ] + monkeypatch.setattr(retrieval_service_module.db.session, "query", Mock(return_value=dataset_query)) + monkeypatch.setattr(retrieval_service_module, "RetrievalChildChunk", _SimpleRetrievalChildChunk) + monkeypatch.setattr(retrieval_service_module, "RetrievalSegments", _SimpleRetrievalSegment) + + input_documents = [ + Document( + page_content="child node content", + metadata={"document_id": "doc-parent", "doc_id": "child-node-1", "score": 0.7}, + provider="dify", + ), + Document( + page_content="parent image", + metadata={ + "document_id": "doc-parent", + "doc_id": "attach-node-1", + "doc_type": DocType.IMAGE, + "score": 0.8, + }, + provider="dify", + ), + Document( + page_content="text index node", + metadata={"document_id": "doc-text", "doc_id": "index-node-1", "score": 0.6}, + provider="dify", + ), + Document( + page_content="text image node", + metadata={ + "document_id": "doc-text", + "doc_id": "attach-text-1", + "doc_type": DocType.IMAGE, + "score": 0.65, + }, + provider="dify", + ), + Document( + page_content="summary candidate 1", + metadata={ + "document_id": "doc-text", + "doc_id": "summary-node-1", + "is_summary": True, + "original_chunk_id": "segment-summary", + "score": "0.9", + }, + provider="dify", + ), + Document( + page_content="summary candidate 2", + metadata={ + "document_id": "doc-text", + "doc_id": "summary-node-2", + "is_summary": True, + "original_chunk_id": "segment-summary", + "score": "0.95", + }, + provider="dify", + ), + Document( + page_content="invalid score summary", + metadata={ + "document_id": "doc-parent-summary", + "doc_id": "summary-parent-invalid", + "is_summary": True, + "original_chunk_id": "segment-parent-summary", + "score": "invalid", + }, + provider="dify", + ), + Document( + page_content="valid parent summary", + metadata={ + "document_id": "doc-parent-summary", + "doc_id": "summary-parent-valid", + "is_summary": True, + "original_chunk_id": "segment-parent-summary", + "score": "0.4", + }, + provider="dify", + ), + ] + + child_chunk = SimpleNamespace( + id="child-chunk-1", + segment_id="segment-parent", + index_node_id="child-node-1", + content="child details", + position=2, + ) + segment_parent = SimpleNamespace(id="segment-parent", document_id="doc-parent", index_node_id="parent-node") + segment_text = SimpleNamespace(id="segment-text", document_id="doc-text", index_node_id="index-node-1") + segment_summary = SimpleNamespace(id="segment-summary", document_id="doc-text", index_node_id="summary-node") + segment_parent_summary = SimpleNamespace( + id="segment-parent-summary", + document_id="doc-parent-summary", + index_node_id="summary-parent-node", + ) + + fake_session = _FakeSession( + execute_payloads=[ + [child_chunk], + [segment_text], + [segment_parent, segment_text], + [segment_summary, segment_parent_summary], + ], + summaries=[ + SimpleNamespace(chunk_id="segment-summary", summary_content="summary for text"), + SimpleNamespace(chunk_id="segment-parent-summary", summary_content="summary for parent"), + ], + ) + monkeypatch.setattr( + retrieval_service_module.session_factory, + "create_session", + lambda: _FakeSessionContext(fake_session), + ) + monkeypatch.setattr( + RetrievalService, + "get_segment_attachment_infos", + lambda attachment_ids, session: [ + { + "attachment_id": "attach-node-1", + "attachment_info": { + "id": "attach-node-1", + "name": "img-parent", + "extension": ".png", + "mime_type": "image/png", + "source_url": "signed://parent", + "size": 11, + }, + "segment_id": "segment-parent", + }, + { + "attachment_id": "attach-text-1", + "attachment_info": { + "id": "attach-text-1", + "name": "img-text", + "extension": ".png", + "mime_type": "image/png", + "source_url": "signed://text", + "size": 22, + }, + "segment_id": "segment-text", + }, + ], + ) + + result = RetrievalService.format_retrieval_documents(input_documents) + + assert len(result) == 4 + result_by_segment_id = {item.segment.id: item for item in result} + assert result_by_segment_id["segment-summary"].score == pytest.approx(0.95) + assert result_by_segment_id["segment-summary"].summary == "summary for text" + assert result_by_segment_id["segment-parent"].score == pytest.approx(0.8) + assert result_by_segment_id["segment-parent"].files is not None + assert len(result_by_segment_id["segment-parent"].child_chunks or []) == 1 + assert result_by_segment_id["segment-text"].score == pytest.approx(0.65) + assert result_by_segment_id["segment-parent-summary"].score == pytest.approx(0.4) + assert result_by_segment_id["segment-parent-summary"].summary == "summary for parent" + assert result_by_segment_id["segment-parent-summary"].child_chunks == [] + + def test_format_retrieval_documents_rolls_back_and_raises_when_db_fails(self, monkeypatch): + rollback = Mock() + monkeypatch.setattr(retrieval_service_module.db.session, "rollback", rollback) + monkeypatch.setattr(retrieval_service_module.db.session, "query", Mock(side_effect=RuntimeError("db error"))) + + documents = [Document(page_content="content", metadata={"document_id": "doc-1"}, provider="dify")] + + with pytest.raises(RuntimeError, match="db error"): + RetrievalService.format_retrieval_documents(documents) + + rollback.assert_called_once() + + def test_retrieve_internal_returns_early_without_query_or_attachment(self, internal_dataset, internal_flask_app): + all_documents: list[Document] = [] + exceptions: list[str] = [] + + RetrievalService()._retrieve( + flask_app=internal_flask_app, + retrieval_method=RetrievalMethod.SEMANTIC_SEARCH, + dataset=internal_dataset, + all_documents=all_documents, + exceptions=exceptions, + query=None, + attachment_id=None, + ) + + assert all_documents == [] + assert exceptions == [] + + def test_retrieve_internal_cancels_futures_when_future_has_exception(self, internal_dataset, internal_flask_app): + future_error = Mock() + future_error.exception.return_value = RuntimeError("future failed") + future_ok = Mock() + future_ok.exception.return_value = None + + with ( + patch("core.rag.datasource.retrieval_service.ThreadPoolExecutor") as mock_executor, + patch( + "core.rag.datasource.retrieval_service.concurrent.futures.as_completed", + return_value=[future_error, future_ok], + ), + ): + mock_executor_instance = Mock() + mock_executor_instance.submit.side_effect = [future_error, future_ok] + mock_executor.return_value.__enter__.return_value = mock_executor_instance + RetrievalService()._retrieve( + flask_app=internal_flask_app, + retrieval_method=RetrievalMethod.SEMANTIC_SEARCH, + dataset=internal_dataset, + all_documents=[], + exceptions=[], + query="query", + attachment_id="file-1", + ) + + future_error.cancel.assert_called() + future_ok.cancel.assert_called() + + def test_retrieve_internal_raises_value_error_when_exceptions_exist( + self, monkeypatch, internal_dataset, internal_flask_app + ): + executor = _ImmediateExecutor() + monkeypatch.setattr(retrieval_service_module, "ThreadPoolExecutor", lambda *args, **kwargs: executor) + monkeypatch.setattr( + retrieval_service_module.concurrent.futures, + "as_completed", + lambda futures, timeout=None: iter(futures), + ) + + with patch("core.rag.datasource.retrieval_service.RetrievalService.keyword_search") as mock_keyword_search: + mock_keyword_search.side_effect = lambda *args, **kwargs: None + with pytest.raises(ValueError, match="keyword error"): + RetrievalService()._retrieve( + flask_app=internal_flask_app, + retrieval_method=RetrievalMethod.KEYWORD_SEARCH, + dataset=internal_dataset, + all_documents=[], + exceptions=["keyword error"], + query="query", + ) + + def test_retrieve_internal_hybrid_weighted_attachment_flow(self, monkeypatch, internal_dataset, internal_flask_app): + executor = _ImmediateExecutor() + monkeypatch.setattr(retrieval_service_module, "ThreadPoolExecutor", lambda *args, **kwargs: executor) + monkeypatch.setattr( + retrieval_service_module.concurrent.futures, + "as_completed", + lambda futures, timeout=None: iter(futures), + ) + + text_doc = create_mock_document("text", "text-doc", 0.81) + image_doc = create_mock_document("image", "image-doc", 0.72) + fulltext_doc = create_mock_document("full", "full-doc", 0.65) + processed_doc = create_mock_document("processed", "processed-doc", 0.99) + + with ( + patch("core.rag.datasource.retrieval_service.RetrievalService.embedding_search") as mock_embedding_search, + patch("core.rag.datasource.retrieval_service.RetrievalService.full_text_index_search") as mock_fulltext, + patch("core.rag.datasource.retrieval_service.DataPostProcessor") as mock_processor_class, + ): + + def embedding_side_effect( + flask_app, + dataset_id, + query, + top_k, + score_threshold, + reranking_model, + all_documents, + retrieval_method, + exceptions, + document_ids_filter=None, + query_type=QueryType.TEXT_QUERY, + ): + if query_type == QueryType.IMAGE_QUERY: + all_documents.append(image_doc) + else: + all_documents.append(text_doc) + + mock_embedding_search.side_effect = embedding_side_effect + + def fulltext_side_effect( + flask_app, + dataset_id, + query, + top_k, + score_threshold, + reranking_model, + all_documents, + retrieval_method, + exceptions, + document_ids_filter=None, + ): + all_documents.append(fulltext_doc) + + mock_fulltext.side_effect = fulltext_side_effect + processor_instance = Mock() + processor_instance.invoke.return_value = [processed_doc] + mock_processor_class.return_value = processor_instance + + all_documents: list[Document] = [] + RetrievalService()._retrieve( + flask_app=internal_flask_app, + retrieval_method=RetrievalMethod.HYBRID_SEARCH, + dataset=internal_dataset, + all_documents=all_documents, + exceptions=[], + query="query", + attachment_id="file-1", + reranking_mode=RerankMode.WEIGHTED_SCORE, + top_k=3, + ) + + assert len(all_documents) == 4 + assert any(doc.metadata["doc_id"] == "processed-doc" for doc in all_documents) + processor_instance.invoke.assert_called_once() + + @patch("core.rag.datasource.retrieval_service.sign_upload_file", return_value="signed://file") + def test_get_segment_attachment_info_success(self, mock_sign): + upload_file = SimpleNamespace( + id="upload-1", + name="file-name", + extension="png", + mime_type="image/png", + size=42, + ) + binding = SimpleNamespace(segment_id="segment-1", attachment_id="upload-1") + upload_query = Mock() + upload_query.where.return_value.first.return_value = upload_file + binding_query = Mock() + binding_query.where.return_value.first.return_value = binding + session = Mock() + session.query.side_effect = [upload_query, binding_query] + + result = RetrievalService.get_segment_attachment_info("dataset-id", "tenant-id", "upload-1", session) + + assert result == { + "attachment_info": { + "id": "upload-1", + "name": "file-name", + "extension": ".png", + "mime_type": "image/png", + "source_url": "signed://file", + "size": 42, + }, + "segment_id": "segment-1", + } + mock_sign.assert_called_once_with("upload-1", "png") + + def test_get_segment_attachment_info_returns_none_when_binding_missing(self): + upload_file = SimpleNamespace( + id="upload-1", + name="file-name", + extension="png", + mime_type="image/png", + size=42, + ) + upload_query = Mock() + upload_query.where.return_value.first.return_value = upload_file + binding_query = Mock() + binding_query.where.return_value.first.return_value = None + session = Mock() + session.query.side_effect = [upload_query, binding_query] + + result = RetrievalService.get_segment_attachment_info("dataset-id", "tenant-id", "upload-1", session) + + assert result is None + + def test_get_segment_attachment_info_returns_none_when_upload_file_missing(self): + upload_query = Mock() + upload_query.where.return_value.first.return_value = None + session = Mock() + session.query.return_value = upload_query + + result = RetrievalService.get_segment_attachment_info("dataset-id", "tenant-id", "upload-1", session) + + assert result is None + + def test_get_segment_attachment_infos_returns_empty_when_upload_files_missing(self): + upload_query = Mock() + upload_query.where.return_value.all.return_value = [] + session = Mock() + session.query.return_value = upload_query + + result = RetrievalService.get_segment_attachment_infos(["upload-1"], session) + + assert result == [] + + def test_get_segment_attachment_infos_returns_empty_when_bindings_missing(self): + upload_file = SimpleNamespace( + id="upload-1", + name="file-name", + extension="png", + mime_type="image/png", + size=42, + ) + upload_query = Mock() + upload_query.where.return_value.all.return_value = [upload_file] + binding_query = Mock() + binding_query.where.return_value.all.return_value = [] + session = Mock() + session.query.side_effect = [upload_query, binding_query] + + result = RetrievalService.get_segment_attachment_infos(["upload-1"], session) + + assert result == [] + + @patch("core.rag.datasource.retrieval_service.sign_upload_file", return_value="signed://file") + def test_get_segment_attachment_infos_success(self, mock_sign): + upload_file_1 = SimpleNamespace( + id="upload-1", + name="file-1", + extension="png", + mime_type="image/png", + size=42, + ) + upload_file_2 = SimpleNamespace( + id="upload-2", + name="file-2", + extension="jpg", + mime_type="image/jpeg", + size=99, + ) + binding = SimpleNamespace(attachment_id="upload-1", segment_id="segment-1") + + upload_query = Mock() + upload_query.where.return_value.all.return_value = [upload_file_1, upload_file_2] + binding_query = Mock() + binding_query.where.return_value.all.return_value = [binding] + session = Mock() + session.query.side_effect = [upload_query, binding_query] + + result = RetrievalService.get_segment_attachment_infos(["upload-1", "upload-2"], session) + + assert result == [ + { + "attachment_id": "upload-1", + "attachment_info": { + "id": "upload-1", + "name": "file-1", + "extension": ".png", + "mime_type": "image/png", + "source_url": "signed://file", + "size": 42, + }, + "segment_id": "segment-1", + } + ] + mock_sign.assert_has_calls( + [ + call("upload-1", "png"), + call("upload-2", "jpg"), + ] + ) + assert mock_sign.call_count == 2 diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/alibabacloud_mysql/test_alibabacloud_mysql_factory.py b/api/tests/unit_tests/core/rag/datasource/vdb/alibabacloud_mysql/test_alibabacloud_mysql_factory.py new file mode 100644 index 0000000000..e063a49f22 --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/alibabacloud_mysql/test_alibabacloud_mysql_factory.py @@ -0,0 +1,74 @@ +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest + +import core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector as alibaba_module +from core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector import AlibabaCloudMySQLVectorFactory + + +def test_validate_distance_function_accepts_supported_values(): + factory = AlibabaCloudMySQLVectorFactory() + + assert factory._validate_distance_function("cosine") == "cosine" + assert factory._validate_distance_function("euclidean") == "euclidean" + + +def test_validate_distance_function_rejects_unsupported_values(): + factory = AlibabaCloudMySQLVectorFactory() + + with pytest.raises(ValueError, match="Invalid distance function"): + factory._validate_distance_function("dot_product") + + +def test_factory_init_vector_uses_existing_index_struct_class_prefix(monkeypatch): + factory = AlibabaCloudMySQLVectorFactory() + dataset = SimpleNamespace( + id="dataset-1", + index_struct_dict={"vector_store": {"class_prefix": "existing_collection"}}, + index_struct=None, + ) + + monkeypatch.setattr(alibaba_module.dify_config, "ALIBABACLOUD_MYSQL_HOST", "host") + monkeypatch.setattr(alibaba_module.dify_config, "ALIBABACLOUD_MYSQL_PORT", 3306) + monkeypatch.setattr(alibaba_module.dify_config, "ALIBABACLOUD_MYSQL_USER", "user") + monkeypatch.setattr(alibaba_module.dify_config, "ALIBABACLOUD_MYSQL_PASSWORD", "password") + monkeypatch.setattr(alibaba_module.dify_config, "ALIBABACLOUD_MYSQL_DATABASE", "db") + monkeypatch.setattr(alibaba_module.dify_config, "ALIBABACLOUD_MYSQL_MAX_CONNECTION", 5) + monkeypatch.setattr(alibaba_module.dify_config, "ALIBABACLOUD_MYSQL_CHARSET", "utf8mb4") + monkeypatch.setattr(alibaba_module.dify_config, "ALIBABACLOUD_MYSQL_DISTANCE_FUNCTION", "cosine") + monkeypatch.setattr(alibaba_module.dify_config, "ALIBABACLOUD_MYSQL_HNSW_M", 6) + + with patch.object(alibaba_module, "AlibabaCloudMySQLVector", return_value="vector") as vector_cls: + result = factory.init_vector(dataset, attributes=[], embeddings=MagicMock()) + + assert result == "vector" + assert vector_cls.call_args.kwargs["collection_name"] == "existing_collection" + + +def test_factory_init_vector_generates_collection_name_when_index_struct_is_missing(monkeypatch): + factory = AlibabaCloudMySQLVectorFactory() + dataset = SimpleNamespace( + id="dataset-2", + index_struct_dict=None, + index_struct=None, + ) + + monkeypatch.setattr(alibaba_module.Dataset, "gen_collection_name_by_id", lambda dataset_id: f"COL_{dataset_id}") + monkeypatch.setattr(alibaba_module.dify_config, "ALIBABACLOUD_MYSQL_HOST", "host") + monkeypatch.setattr(alibaba_module.dify_config, "ALIBABACLOUD_MYSQL_PORT", 3306) + monkeypatch.setattr(alibaba_module.dify_config, "ALIBABACLOUD_MYSQL_USER", "user") + monkeypatch.setattr(alibaba_module.dify_config, "ALIBABACLOUD_MYSQL_PASSWORD", "password") + monkeypatch.setattr(alibaba_module.dify_config, "ALIBABACLOUD_MYSQL_DATABASE", "db") + monkeypatch.setattr(alibaba_module.dify_config, "ALIBABACLOUD_MYSQL_MAX_CONNECTION", 5) + monkeypatch.setattr(alibaba_module.dify_config, "ALIBABACLOUD_MYSQL_CHARSET", "utf8mb4") + monkeypatch.setattr(alibaba_module.dify_config, "ALIBABACLOUD_MYSQL_DISTANCE_FUNCTION", "euclidean") + monkeypatch.setattr(alibaba_module.dify_config, "ALIBABACLOUD_MYSQL_HNSW_M", 12) + + with patch.object(alibaba_module, "AlibabaCloudMySQLVector", return_value="vector") as vector_cls: + result = factory.init_vector(dataset, attributes=[], embeddings=MagicMock()) + + assert result == "vector" + vector_cls.assert_called_once() + assert vector_cls.call_args.kwargs["collection_name"] == "COL_dataset-2" + assert dataset.index_struct is not None diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/analyticdb/test_analyticdb_vector.py b/api/tests/unit_tests/core/rag/datasource/vdb/analyticdb/test_analyticdb_vector.py new file mode 100644 index 0000000000..545565cdf4 --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/analyticdb/test_analyticdb_vector.py @@ -0,0 +1,133 @@ +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest + +import core.rag.datasource.vdb.analyticdb.analyticdb_vector as analyticdb_module +from core.rag.datasource.vdb.analyticdb.analyticdb_vector import AnalyticdbVector, AnalyticdbVectorFactory +from core.rag.datasource.vdb.analyticdb.analyticdb_vector_openapi import AnalyticdbVectorOpenAPIConfig +from core.rag.datasource.vdb.analyticdb.analyticdb_vector_sql import AnalyticdbVectorBySqlConfig +from core.rag.models.document import Document + + +def test_init_prefers_openapi_when_api_config_is_provided(): + api_config = AnalyticdbVectorOpenAPIConfig( + access_key_id="ak", + access_key_secret="sk", + region_id="cn-hangzhou", + instance_id="instance-1", + account="account", + account_password="password", + namespace="dify", + namespace_password="ns-password", + ) + + with patch.object(analyticdb_module, "AnalyticdbVectorOpenAPI", return_value="openapi_runner") as openapi_cls: + vector = AnalyticdbVector("COLLECTION", api_config=api_config, sql_config=None) + + assert vector.analyticdb_vector == "openapi_runner" + openapi_cls.assert_called_once_with("COLLECTION", api_config) + + +def test_init_uses_sql_implementation_when_api_config_is_missing(): + sql_config = AnalyticdbVectorBySqlConfig( + host="localhost", + port=5432, + account="account", + account_password="password", + min_connection=1, + max_connection=2, + namespace="dify", + ) + + with patch.object(analyticdb_module, "AnalyticdbVectorBySql", return_value="sql_runner") as sql_cls: + vector = AnalyticdbVector("COLLECTION", api_config=None, sql_config=sql_config) + + assert vector.analyticdb_vector == "sql_runner" + sql_cls.assert_called_once_with("COLLECTION", sql_config) + + +def test_init_raises_when_both_configs_are_missing(): + with pytest.raises(ValueError, match="Either api_config or sql_config must be provided"): + AnalyticdbVector("COLLECTION", api_config=None, sql_config=None) + + +def test_vector_methods_delegate_to_underlying_implementation(): + runner = MagicMock() + runner.search_by_vector.return_value = [Document(page_content="v", metadata={"doc_id": "1"})] + runner.search_by_full_text.return_value = [Document(page_content="t", metadata={"doc_id": "2"})] + runner.text_exists.return_value = True + + vector = AnalyticdbVector.__new__(AnalyticdbVector) + vector.analyticdb_vector = runner + + texts = [Document(page_content="hello", metadata={"doc_id": "d1"})] + vector.create(texts=texts, embeddings=[[0.1, 0.2]]) + vector.add_texts(documents=texts, embeddings=[[0.1, 0.2]]) + assert vector.text_exists("d1") is True + vector.delete_by_ids(["d1"]) + vector.delete_by_metadata_field("document_id", "doc-1") + assert vector.search_by_vector([0.1, 0.2], top_k=2) == runner.search_by_vector.return_value + assert vector.search_by_full_text("hello", top_k=2) == runner.search_by_full_text.return_value + vector.delete() + + runner._create_collection_if_not_exists.assert_called_once_with(2) + runner.add_texts.assert_any_call(texts, [[0.1, 0.2]]) + runner.delete_by_ids.assert_called_once_with(["d1"]) + runner.delete_by_metadata_field.assert_called_once_with("document_id", "doc-1") + runner.delete.assert_called_once() + + +def test_get_type_is_analyticdb(): + vector = AnalyticdbVector.__new__(AnalyticdbVector) + assert vector.get_type() == "analyticdb" + + +def test_factory_builds_openapi_config_when_host_is_missing(monkeypatch): + factory = AnalyticdbVectorFactory() + dataset = SimpleNamespace(id="dataset-1", index_struct_dict=None, index_struct=None) + + monkeypatch.setattr(analyticdb_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION") + monkeypatch.setattr(analyticdb_module.dify_config, "ANALYTICDB_HOST", None) + monkeypatch.setattr(analyticdb_module.dify_config, "ANALYTICDB_KEY_ID", "ak") + monkeypatch.setattr(analyticdb_module.dify_config, "ANALYTICDB_KEY_SECRET", "sk") + monkeypatch.setattr(analyticdb_module.dify_config, "ANALYTICDB_REGION_ID", "cn-hz") + monkeypatch.setattr(analyticdb_module.dify_config, "ANALYTICDB_INSTANCE_ID", "instance") + monkeypatch.setattr(analyticdb_module.dify_config, "ANALYTICDB_ACCOUNT", "account") + monkeypatch.setattr(analyticdb_module.dify_config, "ANALYTICDB_PASSWORD", "password") + monkeypatch.setattr(analyticdb_module.dify_config, "ANALYTICDB_NAMESPACE", "dify") + monkeypatch.setattr(analyticdb_module.dify_config, "ANALYTICDB_NAMESPACE_PASSWORD", "ns-password") + + with patch.object(analyticdb_module, "AnalyticdbVector", return_value="vector") as vector_cls: + result = factory.init_vector(dataset, attributes=[], embeddings=MagicMock()) + + assert result == "vector" + args = vector_cls.call_args.args + assert args[0] == "auto_collection" + assert isinstance(args[1], AnalyticdbVectorOpenAPIConfig) + assert args[2] is None + assert dataset.index_struct is not None + + +def test_factory_builds_sql_config_when_host_is_present(monkeypatch): + factory = AnalyticdbVectorFactory() + dataset = SimpleNamespace( + id="dataset-2", index_struct_dict={"vector_store": {"class_prefix": "EXISTING"}}, index_struct=None + ) + + monkeypatch.setattr(analyticdb_module.dify_config, "ANALYTICDB_HOST", "127.0.0.1") + monkeypatch.setattr(analyticdb_module.dify_config, "ANALYTICDB_PORT", 5432) + monkeypatch.setattr(analyticdb_module.dify_config, "ANALYTICDB_ACCOUNT", "account") + monkeypatch.setattr(analyticdb_module.dify_config, "ANALYTICDB_PASSWORD", "password") + monkeypatch.setattr(analyticdb_module.dify_config, "ANALYTICDB_MIN_CONNECTION", 1) + monkeypatch.setattr(analyticdb_module.dify_config, "ANALYTICDB_MAX_CONNECTION", 3) + monkeypatch.setattr(analyticdb_module.dify_config, "ANALYTICDB_NAMESPACE", "dify") + + with patch.object(analyticdb_module, "AnalyticdbVector", return_value="vector") as vector_cls: + result = factory.init_vector(dataset, attributes=[], embeddings=MagicMock()) + + assert result == "vector" + args = vector_cls.call_args.args + assert args[0] == "existing" + assert args[1] is None + assert isinstance(args[2], AnalyticdbVectorBySqlConfig) diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/analyticdb/test_analyticdb_vector_openapi.py b/api/tests/unit_tests/core/rag/datasource/vdb/analyticdb/test_analyticdb_vector_openapi.py new file mode 100644 index 0000000000..45777774d0 --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/analyticdb/test_analyticdb_vector_openapi.py @@ -0,0 +1,384 @@ +import json +import sys +import types +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest + +import core.rag.datasource.vdb.analyticdb.analyticdb_vector_openapi as openapi_module +from core.rag.datasource.vdb.analyticdb.analyticdb_vector_openapi import ( + AnalyticdbVectorOpenAPI, + AnalyticdbVectorOpenAPIConfig, +) +from core.rag.models.document import Document + + +def _request_class(name: str): + class _Request: + def __init__(self, **kwargs): + for key, value in kwargs.items(): + setattr(self, key, value) + + _Request.__name__ = name + return _Request + + +def _install_openapi_stubs(monkeypatch): + gpdb_package = types.ModuleType("alibabacloud_gpdb20160503") + gpdb_package.__path__ = [] + gpdb_models = types.ModuleType("alibabacloud_gpdb20160503.models") + for class_name in [ + "InitVectorDatabaseRequest", + "DescribeNamespaceRequest", + "CreateNamespaceRequest", + "DescribeCollectionRequest", + "CreateCollectionRequest", + "UpsertCollectionDataRequestRows", + "UpsertCollectionDataRequest", + "QueryCollectionDataRequest", + "DeleteCollectionDataRequest", + "DeleteCollectionRequest", + ]: + setattr(gpdb_models, class_name, _request_class(class_name)) + + class _Client: + def __init__(self, config): + self.config = config + + gpdb_client = types.ModuleType("alibabacloud_gpdb20160503.client") + gpdb_client.Client = _Client + gpdb_package.models = gpdb_models + + tea_openapi = types.ModuleType("alibabacloud_tea_openapi") + tea_openapi.__path__ = [] + tea_openapi_models = types.ModuleType("alibabacloud_tea_openapi.models") + + class OpenApiConfig: + def __init__(self, **kwargs): + for key, value in kwargs.items(): + setattr(self, key, value) + + tea_openapi_models.Config = OpenApiConfig + tea_openapi.models = tea_openapi_models + + tea_package = types.ModuleType("Tea") + tea_package.__path__ = [] + tea_exceptions = types.ModuleType("Tea.exceptions") + + class TeaError(Exception): + def __init__(self, status_code=None, **kwargs): + super().__init__("TeaException") + status_code = kwargs.get("statusCode", status_code) + self.statusCode = status_code + self.status_code = status_code + + tea_exceptions.TeaException = TeaError + tea_package.exceptions = tea_exceptions + + monkeypatch.setitem(sys.modules, "alibabacloud_gpdb20160503", gpdb_package) + monkeypatch.setitem(sys.modules, "alibabacloud_gpdb20160503.models", gpdb_models) + monkeypatch.setitem(sys.modules, "alibabacloud_gpdb20160503.client", gpdb_client) + monkeypatch.setitem(sys.modules, "alibabacloud_tea_openapi", tea_openapi) + monkeypatch.setitem(sys.modules, "alibabacloud_tea_openapi.models", tea_openapi_models) + monkeypatch.setitem(sys.modules, "Tea", tea_package) + monkeypatch.setitem(sys.modules, "Tea.exceptions", tea_exceptions) + + return SimpleNamespace(models=gpdb_models, TeaException=TeaError, OpenApiConfig=OpenApiConfig) + + +def _config() -> AnalyticdbVectorOpenAPIConfig: + return AnalyticdbVectorOpenAPIConfig( + access_key_id="ak", + access_key_secret="sk", + region_id="cn-hangzhou", + instance_id="instance-1", + account="account", + account_password="password", + namespace="dify", + namespace_password="ns-password", + ) + + +@pytest.mark.parametrize( + ("field", "value", "error_message"), + [ + ("access_key_id", "", "ANALYTICDB_KEY_ID"), + ("access_key_secret", "", "ANALYTICDB_KEY_SECRET"), + ("region_id", "", "ANALYTICDB_REGION_ID"), + ("instance_id", "", "ANALYTICDB_INSTANCE_ID"), + ("account", "", "ANALYTICDB_ACCOUNT"), + ("account_password", "", "ANALYTICDB_PASSWORD"), + ("namespace_password", "", "ANALYTICDB_NAMESPACE_PASSWORD"), + ], +) +def test_openapi_config_validation(field, value, error_message): + values = _config().model_dump() + values[field] = value + + with pytest.raises(ValueError, match=error_message): + AnalyticdbVectorOpenAPIConfig.model_validate(values) + + +def test_openapi_config_to_client_params(): + config = _config() + params = config.to_analyticdb_client_params() + + assert params["access_key_id"] == "ak" + assert params["access_key_secret"] == "sk" + assert params["region_id"] == "cn-hangzhou" + assert params["read_timeout"] == 60000 + + +def test_init_creates_openapi_client_and_runs_initialize(monkeypatch): + stubs = _install_openapi_stubs(monkeypatch) + initialize_mock = MagicMock() + monkeypatch.setattr(openapi_module.AnalyticdbVectorOpenAPI, "_initialize", initialize_mock) + + vector = AnalyticdbVectorOpenAPI("COLLECTION_1", _config()) + + assert vector._collection_name == "collection_1" + assert isinstance(vector._client_config, stubs.OpenApiConfig) + assert vector._client_config.user_agent == "dify" + assert vector._client_config.access_key_id == "ak" + assert vector._client.config is vector._client_config + initialize_mock.assert_called_once_with() + + +def test_initialize_skips_when_cached(monkeypatch): + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(openapi_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(openapi_module.redis_client, "get", MagicMock(return_value=1)) + monkeypatch.setattr(openapi_module.redis_client, "set", MagicMock()) + + vector = AnalyticdbVectorOpenAPI.__new__(AnalyticdbVectorOpenAPI) + vector.config = _config() + vector._initialize_vector_database = MagicMock() + vector._create_namespace_if_not_exists = MagicMock() + + vector._initialize() + + vector._initialize_vector_database.assert_not_called() + vector._create_namespace_if_not_exists.assert_not_called() + + +def test_initialize_runs_when_cache_is_missing(monkeypatch): + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(openapi_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(openapi_module.redis_client, "get", MagicMock(return_value=None)) + monkeypatch.setattr(openapi_module.redis_client, "set", MagicMock()) + + vector = AnalyticdbVectorOpenAPI.__new__(AnalyticdbVectorOpenAPI) + vector.config = _config() + vector._initialize_vector_database = MagicMock() + vector._create_namespace_if_not_exists = MagicMock() + + vector._initialize() + + vector._initialize_vector_database.assert_called_once() + vector._create_namespace_if_not_exists.assert_called_once() + openapi_module.redis_client.set.assert_called_once() + + +def test_initialize_vector_database_calls_openapi_client(monkeypatch): + _install_openapi_stubs(monkeypatch) + vector = AnalyticdbVectorOpenAPI.__new__(AnalyticdbVectorOpenAPI) + vector.config = _config() + vector._client = MagicMock() + + vector._initialize_vector_database() + + request = vector._client.init_vector_database.call_args.args[0] + assert request.dbinstance_id == "instance-1" + assert request.region_id == "cn-hangzhou" + assert request.manager_account == "account" + assert request.manager_account_password == "password" + + +def test_create_namespace_creates_when_namespace_not_found(monkeypatch): + stubs = _install_openapi_stubs(monkeypatch) + vector = AnalyticdbVectorOpenAPI.__new__(AnalyticdbVectorOpenAPI) + vector.config = _config() + vector._client = MagicMock() + vector._client.describe_namespace.side_effect = stubs.TeaException(statusCode=404) + + vector._create_namespace_if_not_exists() + + vector._client.create_namespace.assert_called_once() + + +def test_create_namespace_raises_on_unexpected_api_error(monkeypatch): + stubs = _install_openapi_stubs(monkeypatch) + vector = AnalyticdbVectorOpenAPI.__new__(AnalyticdbVectorOpenAPI) + vector.config = _config() + vector._client = MagicMock() + vector._client.describe_namespace.side_effect = stubs.TeaException(statusCode=500) + + with pytest.raises(ValueError, match="failed to create namespace"): + vector._create_namespace_if_not_exists() + + +def test_create_namespace_noop_when_namespace_exists(monkeypatch): + _install_openapi_stubs(monkeypatch) + vector = AnalyticdbVectorOpenAPI.__new__(AnalyticdbVectorOpenAPI) + vector.config = _config() + vector._client = MagicMock() + + vector._create_namespace_if_not_exists() + + vector._client.describe_namespace.assert_called_once() + vector._client.create_namespace.assert_not_called() + + +def test_create_collection_if_not_exists_creates_when_missing(monkeypatch): + stubs = _install_openapi_stubs(monkeypatch) + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(openapi_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(openapi_module.redis_client, "get", MagicMock(return_value=None)) + monkeypatch.setattr(openapi_module.redis_client, "set", MagicMock()) + + vector = AnalyticdbVectorOpenAPI.__new__(AnalyticdbVectorOpenAPI) + vector._collection_name = "collection_1" + vector.config = _config() + vector._client = MagicMock() + vector._client.describe_collection.side_effect = stubs.TeaException(statusCode=404) + + vector._create_collection_if_not_exists(embedding_dimension=1024) + + vector._client.create_collection.assert_called_once() + openapi_module.redis_client.set.assert_called_once() + + +def test_create_collection_if_not_exists_skips_when_cached(monkeypatch): + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(openapi_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(openapi_module.redis_client, "get", MagicMock(return_value=1)) + monkeypatch.setattr(openapi_module.redis_client, "set", MagicMock()) + + vector = AnalyticdbVectorOpenAPI.__new__(AnalyticdbVectorOpenAPI) + vector._collection_name = "collection_1" + vector.config = _config() + vector._client = MagicMock() + + vector._create_collection_if_not_exists(embedding_dimension=1024) + + vector._client.describe_collection.assert_not_called() + vector._client.create_collection.assert_not_called() + + +def test_create_collection_if_not_exists_raises_on_non_404_errors(monkeypatch): + stubs = _install_openapi_stubs(monkeypatch) + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(openapi_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(openapi_module.redis_client, "get", MagicMock(return_value=None)) + monkeypatch.setattr(openapi_module.redis_client, "set", MagicMock()) + + vector = AnalyticdbVectorOpenAPI.__new__(AnalyticdbVectorOpenAPI) + vector._collection_name = "collection_1" + vector.config = _config() + vector._client = MagicMock() + vector._client.describe_collection.side_effect = stubs.TeaException(statusCode=500) + + with pytest.raises(ValueError, match="failed to create collection collection_1"): + vector._create_collection_if_not_exists(embedding_dimension=512) + + +def test_openapi_add_delete_and_search_methods(monkeypatch): + _install_openapi_stubs(monkeypatch) + vector = AnalyticdbVectorOpenAPI.__new__(AnalyticdbVectorOpenAPI) + vector._collection_name = "collection_1" + vector.config = _config() + vector._client = MagicMock() + + documents = [ + Document(page_content="doc 1", metadata={"doc_id": "d1", "document_id": "doc-1"}), + SimpleNamespace(page_content="doc 2", metadata=None), + ] + embeddings = [[0.1, 0.2], [0.2, 0.3]] + vector.add_texts(documents, embeddings) + + upsert_request = vector._client.upsert_collection_data.call_args.args[0] + assert upsert_request.collection == "collection_1" + assert len(upsert_request.rows) == 1 + + vector._client.query_collection_data.return_value = SimpleNamespace( + body=SimpleNamespace(matches=SimpleNamespace(match=[SimpleNamespace()])) + ) + assert vector.text_exists("d1") is True + + vector.delete_by_ids(["d1", "d2"]) + request = vector._client.delete_collection_data.call_args.args[0] + assert request.collection_data_filter == "ref_doc_id IN ('d1','d2')" + + vector.delete_by_metadata_field("document_id", "doc-1") + request = vector._client.delete_collection_data.call_args.args[0] + assert request.collection_data_filter == "metadata_ ->> 'document_id' = 'doc-1'" + + match_high = SimpleNamespace( + score=0.9, + metadata={"metadata_": json.dumps({"document_id": "doc-1"}), "page_content": "high"}, + values=SimpleNamespace(value=[1.0, 2.0]), + ) + match_low = SimpleNamespace( + score=0.1, + metadata={"metadata_": json.dumps({"document_id": "doc-2"}), "page_content": "low"}, + values=SimpleNamespace(value=[3.0, 4.0]), + ) + vector._client.query_collection_data.return_value = SimpleNamespace( + body=SimpleNamespace(matches=SimpleNamespace(match=[match_low, match_high])) + ) + + docs_by_vector = vector.search_by_vector([0.1, 0.2], top_k=2, score_threshold=0.5, document_ids_filter=["doc-1"]) + assert len(docs_by_vector) == 1 + assert docs_by_vector[0].page_content == "high" + assert docs_by_vector[0].metadata["score"] == 0.9 + + docs_by_text = vector.search_by_full_text("hello", top_k=2, score_threshold=0.2) + assert len(docs_by_text) == 1 + assert docs_by_text[0].page_content == "high" + + +def test_text_exists_returns_false_when_matches_empty(monkeypatch): + _install_openapi_stubs(monkeypatch) + vector = AnalyticdbVectorOpenAPI.__new__(AnalyticdbVectorOpenAPI) + vector._collection_name = "collection_1" + vector.config = _config() + vector._client = MagicMock() + vector._client.query_collection_data.return_value = SimpleNamespace( + body=SimpleNamespace(matches=SimpleNamespace(match=[])) + ) + + assert vector.text_exists("missing-id") is False + + +def test_openapi_delete_success(monkeypatch): + _install_openapi_stubs(monkeypatch) + vector = AnalyticdbVectorOpenAPI.__new__(AnalyticdbVectorOpenAPI) + vector._collection_name = "collection_1" + vector.config = _config() + vector._client = MagicMock() + + vector.delete() + vector._client.delete_collection.assert_called_once() + + +def test_openapi_delete_propagates_errors(monkeypatch): + _install_openapi_stubs(monkeypatch) + vector = AnalyticdbVectorOpenAPI.__new__(AnalyticdbVectorOpenAPI) + vector._collection_name = "collection_1" + vector.config = _config() + vector._client = MagicMock() + vector._client.delete_collection.side_effect = RuntimeError("boom") + + with pytest.raises(RuntimeError, match="boom"): + vector.delete() diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/analyticdb/test_analyticdb_vector_sql.py b/api/tests/unit_tests/core/rag/datasource/vdb/analyticdb/test_analyticdb_vector_sql.py new file mode 100644 index 0000000000..8f1206696b --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/analyticdb/test_analyticdb_vector_sql.py @@ -0,0 +1,427 @@ +from contextlib import contextmanager +from types import SimpleNamespace +from unittest.mock import MagicMock + +import psycopg2.errors +import pytest + +import core.rag.datasource.vdb.analyticdb.analyticdb_vector_sql as sql_module +from core.rag.datasource.vdb.analyticdb.analyticdb_vector_sql import ( + AnalyticdbVectorBySql, + AnalyticdbVectorBySqlConfig, +) +from core.rag.models.document import Document + + +def _config_values() -> dict: + return { + "host": "localhost", + "port": 5432, + "account": "account", + "account_password": "password", + "min_connection": 1, + "max_connection": 2, + "namespace": "dify", + } + + +@pytest.mark.parametrize( + ("field", "value", "error_message"), + [ + ("host", "", "ANALYTICDB_HOST"), + ("port", 0, "ANALYTICDB_PORT"), + ("account", "", "ANALYTICDB_ACCOUNT"), + ("account_password", "", "ANALYTICDB_PASSWORD"), + ("min_connection", 0, "ANALYTICDB_MIN_CONNECTION"), + ("max_connection", 0, "ANALYTICDB_MAX_CONNECTION"), + ], +) +def test_sql_config_required_fields(field, value, error_message): + values = _config_values() + values[field] = value + + with pytest.raises(ValueError, match=error_message): + AnalyticdbVectorBySqlConfig.model_validate(values) + + +def test_sql_config_rejects_min_connection_greater_than_max_connection(): + values = _config_values() + values["min_connection"] = 10 + values["max_connection"] = 2 + + with pytest.raises(ValueError, match="ANALYTICDB_MIN_CONNECTION should less than ANALYTICDB_MAX_CONNECTION"): + AnalyticdbVectorBySqlConfig.model_validate(values) + + +def test_initialize_skips_when_cache_exists(monkeypatch): + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(sql_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(sql_module.redis_client, "get", MagicMock(return_value=1)) + monkeypatch.setattr(sql_module.redis_client, "set", MagicMock()) + + vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql) + vector.config = AnalyticdbVectorBySqlConfig(**_config_values()) + vector._initialize_vector_database = MagicMock() + + vector._initialize() + + vector._initialize_vector_database.assert_not_called() + + +def test_initialize_runs_when_cache_is_missing(monkeypatch): + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(sql_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(sql_module.redis_client, "get", MagicMock(return_value=None)) + monkeypatch.setattr(sql_module.redis_client, "set", MagicMock()) + + vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql) + vector.config = AnalyticdbVectorBySqlConfig(**_config_values()) + vector._initialize_vector_database = MagicMock() + + vector._initialize() + + vector._initialize_vector_database.assert_called_once() + sql_module.redis_client.set.assert_called_once() + + +def test_create_connection_pool_uses_psycopg2_pool(monkeypatch): + vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql) + vector.config = AnalyticdbVectorBySqlConfig(**_config_values()) + vector.databaseName = "knowledgebase" + + pool_instance = MagicMock() + monkeypatch.setattr(sql_module.psycopg2.pool, "SimpleConnectionPool", MagicMock(return_value=pool_instance)) + + pool = vector._create_connection_pool() + + assert pool is pool_instance + sql_module.psycopg2.pool.SimpleConnectionPool.assert_called_once() + + +def test_get_cursor_context_manager_handles_connection_lifecycle(): + vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql) + cursor = MagicMock() + connection = MagicMock() + connection.cursor.return_value = cursor + pool = MagicMock() + pool.getconn.return_value = connection + vector.pool = pool + + with vector._get_cursor() as cur: + assert cur is cursor + + cursor.close.assert_called_once() + connection.commit.assert_called_once() + pool.putconn.assert_called_once_with(connection) + + +def test_add_texts_inserts_only_documents_with_metadata(monkeypatch): + vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql) + vector.table_name = "dify.collection" + + cursor = MagicMock() + + @contextmanager + def _cursor_context(): + yield cursor + + vector._get_cursor = _cursor_context + + monkeypatch.setattr(sql_module.uuid, "uuid4", lambda: "prefix-id") + monkeypatch.setattr(sql_module.psycopg2.extras, "execute_batch", MagicMock()) + + docs = [ + Document(page_content="doc 1", metadata={"doc_id": "d1", "document_id": "doc-1"}), + SimpleNamespace(page_content="doc 2", metadata=None), + ] + vector.add_texts(docs, [[0.1, 0.2], [0.2, 0.3]]) + + execute_args = sql_module.psycopg2.extras.execute_batch.call_args.args + assert execute_args[0] is cursor + assert len(execute_args[2]) == 1 + + +def test_text_exists_returns_true_and_false_based_on_query_result(): + vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql) + vector.table_name = "dify.collection" + cursor = MagicMock() + + @contextmanager + def _cursor_context(): + yield cursor + + vector._get_cursor = _cursor_context + + cursor.fetchone.return_value = ("row",) + assert vector.text_exists("d1") is True + + cursor.fetchone.return_value = None + assert vector.text_exists("d1") is False + + +def test_delete_by_ids_handles_empty_input_and_missing_table_error(): + vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql) + vector.table_name = "dify.collection" + cursor = MagicMock() + + @contextmanager + def _cursor_context(): + yield cursor + + vector._get_cursor = _cursor_context + vector.delete_by_ids([]) + cursor.execute.assert_not_called() + + cursor.execute.side_effect = psycopg2.errors.UndefinedTable("relation does not exist") + vector.delete_by_ids(["d1"]) + + +def test_delete_by_metadata_field_handles_missing_table_error(): + vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql) + vector.table_name = "dify.collection" + cursor = MagicMock() + + @contextmanager + def _cursor_context(): + yield cursor + + vector._get_cursor = _cursor_context + cursor.execute.side_effect = psycopg2.errors.UndefinedTable("relation does not exist") + vector.delete_by_metadata_field("document_id", "doc-1") + + +@pytest.mark.parametrize("invalid_top_k", [0, "x", -1]) +def test_search_by_vector_validates_top_k(invalid_top_k): + vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql) + vector.table_name = "dify.collection" + + with pytest.raises(ValueError, match="top_k must be a positive integer"): + vector.search_by_vector([0.1, 0.2], top_k=invalid_top_k) + + +def test_search_by_vector_returns_documents_above_threshold(): + vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql) + vector.table_name = "dify.collection" + cursor = MagicMock() + cursor.__iter__.return_value = iter( + [ + ("id1", [1.0], 0.8, "content 1", {"doc_id": "id1", "document_id": "doc-1"}), + ("id2", [2.0], 0.3, "content 2", {"doc_id": "id2", "document_id": "doc-2"}), + ] + ) + + @contextmanager + def _cursor_context(): + yield cursor + + vector._get_cursor = _cursor_context + + docs = vector.search_by_vector([0.1, 0.2], top_k=2, score_threshold=0.5, document_ids_filter=["doc-1"]) + + assert len(docs) == 1 + assert docs[0].page_content == "content 1" + assert docs[0].metadata["score"] == 0.8 + + +@pytest.mark.parametrize("invalid_top_k", [0, "x", -1]) +def test_search_by_full_text_validates_top_k(invalid_top_k): + vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql) + vector.table_name = "dify.collection" + + with pytest.raises(ValueError, match="top_k must be a positive integer"): + vector.search_by_full_text("query", top_k=invalid_top_k) + + +def test_search_by_full_text_returns_documents(): + vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql) + vector.table_name = "dify.collection" + cursor = MagicMock() + cursor.__iter__.return_value = iter( + [ + ("id1", [1.0], "content 1", {"doc_id": "id1", "document_id": "doc-1"}, 0.9), + ] + ) + + @contextmanager + def _cursor_context(): + yield cursor + + vector._get_cursor = _cursor_context + docs = vector.search_by_full_text("query", top_k=1, document_ids_filter=["doc-1"]) + + assert len(docs) == 1 + assert docs[0].metadata["score"] == 0.9 + assert docs[0].page_content == "content 1" + + +def test_delete_drops_table(): + vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql) + vector.table_name = "dify.collection" + cursor = MagicMock() + + @contextmanager + def _cursor_context(): + yield cursor + + vector._get_cursor = _cursor_context + vector.delete() + + cursor.execute.assert_called_once() + + +def test_init_normalizes_collection_name_and_creates_pool_when_missing(monkeypatch): + config = AnalyticdbVectorBySqlConfig(**_config_values()) + created_pool = MagicMock() + + monkeypatch.setattr(AnalyticdbVectorBySql, "_initialize", MagicMock()) + monkeypatch.setattr(AnalyticdbVectorBySql, "_create_connection_pool", MagicMock(return_value=created_pool)) + + vector = AnalyticdbVectorBySql("My_Collection", config) + + assert vector._collection_name == "my_collection" + assert vector.table_name == "dify.my_collection" + assert vector.databaseName == "knowledgebase" + assert vector.pool is created_pool + + +def test_initialize_vector_database_handles_existing_database_and_search_config(monkeypatch): + vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql) + vector.config = AnalyticdbVectorBySqlConfig(**_config_values()) + vector.databaseName = "knowledgebase" + + bootstrap_cursor = MagicMock() + bootstrap_connection = MagicMock() + bootstrap_connection.cursor.return_value = bootstrap_cursor + bootstrap_cursor.execute.side_effect = RuntimeError("database already exists") + monkeypatch.setattr(sql_module.psycopg2, "connect", MagicMock(return_value=bootstrap_connection)) + + worker_cursor = MagicMock() + worker_connection = MagicMock() + worker_cursor.connection = worker_connection + + def _execute(sql, *args, **kwargs): + if "CREATE TEXT SEARCH CONFIGURATION zh_cn" in sql: + raise RuntimeError("already exists") + + worker_cursor.execute.side_effect = _execute + pooled_connection = MagicMock() + pooled_connection.cursor.return_value = worker_cursor + pool = MagicMock() + pool.getconn.return_value = pooled_connection + vector._create_connection_pool = MagicMock(return_value=pool) + + vector._initialize_vector_database() + + bootstrap_cursor.close.assert_called_once() + bootstrap_connection.close.assert_called_once() + vector._create_connection_pool.assert_called_once() + assert any( + "CREATE OR REPLACE FUNCTION public.to_tsquery_from_text" in call.args[0] + for call in worker_cursor.execute.call_args_list + ) + assert any("CREATE SCHEMA IF NOT EXISTS dify" in call.args[0] for call in worker_cursor.execute.call_args_list) + + +def test_initialize_vector_database_raises_runtime_error_when_zhparser_fails(monkeypatch): + vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql) + vector.config = AnalyticdbVectorBySqlConfig(**_config_values()) + vector.databaseName = "knowledgebase" + + bootstrap_cursor = MagicMock() + bootstrap_connection = MagicMock() + bootstrap_connection.cursor.return_value = bootstrap_cursor + monkeypatch.setattr(sql_module.psycopg2, "connect", MagicMock(return_value=bootstrap_connection)) + + worker_cursor = MagicMock() + worker_connection = MagicMock() + worker_cursor.connection = worker_connection + worker_cursor.execute.side_effect = RuntimeError("zhparser unavailable") + + pooled_connection = MagicMock() + pooled_connection.cursor.return_value = worker_cursor + pool = MagicMock() + pool.getconn.return_value = pooled_connection + vector._create_connection_pool = MagicMock(return_value=pool) + + with pytest.raises(RuntimeError, match="Failed to create zhparser extension"): + vector._initialize_vector_database() + + worker_connection.rollback.assert_called_once() + + +def test_create_collection_if_not_exists_creates_table_indexes_and_cache(monkeypatch): + vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql) + vector.config = AnalyticdbVectorBySqlConfig(**_config_values()) + vector._collection_name = "collection" + vector.table_name = "dify.collection" + + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(sql_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(sql_module.redis_client, "get", MagicMock(return_value=None)) + monkeypatch.setattr(sql_module.redis_client, "set", MagicMock()) + + cursor = MagicMock() + + @contextmanager + def _cursor_context(): + yield cursor + + vector._get_cursor = _cursor_context + + vector._create_collection_if_not_exists(embedding_dimension=3) + + assert any("CREATE TABLE IF NOT EXISTS dify.collection" in call.args[0] for call in cursor.execute.call_args_list) + assert any("CREATE INDEX collection_embedding_idx" in call.args[0] for call in cursor.execute.call_args_list) + sql_module.redis_client.set.assert_called_once() + + +def test_create_collection_if_not_exists_raises_for_non_existing_error(monkeypatch): + vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql) + vector.config = AnalyticdbVectorBySqlConfig(**_config_values()) + vector._collection_name = "collection" + vector.table_name = "dify.collection" + + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(sql_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(sql_module.redis_client, "get", MagicMock(return_value=None)) + monkeypatch.setattr(sql_module.redis_client, "set", MagicMock()) + + cursor = MagicMock() + cursor.execute.side_effect = RuntimeError("permission denied") + + @contextmanager + def _cursor_context(): + yield cursor + + vector._get_cursor = _cursor_context + + with pytest.raises(RuntimeError, match="permission denied"): + vector._create_collection_if_not_exists(embedding_dimension=3) + + +def test_delete_methods_raise_when_error_is_not_missing_table(): + vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql) + vector.table_name = "dify.collection" + cursor = MagicMock() + + @contextmanager + def _cursor_context(): + yield cursor + + vector._get_cursor = _cursor_context + + cursor.execute.side_effect = RuntimeError("unexpected delete failure") + with pytest.raises(RuntimeError, match="unexpected delete failure"): + vector.delete_by_ids(["doc-1"]) + + cursor.execute.side_effect = RuntimeError("unexpected metadata failure") + with pytest.raises(RuntimeError, match="unexpected metadata failure"): + vector.delete_by_metadata_field("document_id", "doc-1") diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/baidu/test_baidu_vector.py b/api/tests/unit_tests/core/rag/datasource/vdb/baidu/test_baidu_vector.py new file mode 100644 index 0000000000..c46c3d5e4b --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/baidu/test_baidu_vector.py @@ -0,0 +1,542 @@ +import importlib +import sys +import types +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest + +from core.rag.models.document import Document + + +def _build_fake_pymochow_modules(): + pymochow = types.ModuleType("pymochow") + pymochow.__path__ = [] + pymochow_auth = types.ModuleType("pymochow.auth") + pymochow_auth.__path__ = [] + pymochow_credentials = types.ModuleType("pymochow.auth.bce_credentials") + pymochow_configuration = types.ModuleType("pymochow.configuration") + pymochow_exception = types.ModuleType("pymochow.exception") + pymochow_model = types.ModuleType("pymochow.model") + pymochow_model.__path__ = [] + pymochow_model_database = types.ModuleType("pymochow.model.database") + pymochow_model_enum = types.ModuleType("pymochow.model.enum") + pymochow_model_schema = types.ModuleType("pymochow.model.schema") + pymochow_model_table = types.ModuleType("pymochow.model.table") + + class _SimpleObject: + def __init__(self, *args, **kwargs): + self.args = args + for key, value in kwargs.items(): + setattr(self, key, value) + + class ServerError(Exception): + def __init__(self, code): + super().__init__(f"server error {code}") + self.code = code + + class ServerErrCode: + TABLE_NOT_EXIST = 1001 + DB_ALREADY_EXIST = 1002 + + class IndexType: + __members__ = {"HNSW": "HNSW"} + + class MetricType: + __members__ = {"IP": "IP"} + + class IndexState: + NORMAL = "NORMAL" + + class TableState: + NORMAL = "NORMAL" + + class InvertedIndexAnalyzer: + DEFAULT_ANALYZER = "DEFAULT_ANALYZER" + + class InvertedIndexParseMode: + COARSE_MODE = "COARSE_MODE" + + class InvertedIndexFieldAttribute: + ANALYZED = "ANALYZED" + + class FieldType: + STRING = "STRING" + TEXT = "TEXT" + JSON = "JSON" + FLOAT_VECTOR = "FLOAT_VECTOR" + + pymochow.MochowClient = _SimpleObject + pymochow_credentials.BceCredentials = _SimpleObject + pymochow_configuration.Configuration = _SimpleObject + pymochow_exception.ServerError = ServerError + pymochow_model_database.Database = _SimpleObject + + pymochow_model_enum.FieldType = FieldType + pymochow_model_enum.IndexState = IndexState + pymochow_model_enum.IndexType = IndexType + pymochow_model_enum.MetricType = MetricType + pymochow_model_enum.ServerErrCode = ServerErrCode + pymochow_model_enum.TableState = TableState + + for cls_name in [ + "AutoBuildRowCountIncrement", + "Field", + "FilteringIndex", + "HNSWParams", + "InvertedIndex", + "InvertedIndexParams", + "Schema", + "VectorIndex", + ]: + setattr(pymochow_model_schema, cls_name, _SimpleObject) + pymochow_model_schema.InvertedIndexAnalyzer = InvertedIndexAnalyzer + pymochow_model_schema.InvertedIndexFieldAttribute = InvertedIndexFieldAttribute + pymochow_model_schema.InvertedIndexParseMode = InvertedIndexParseMode + + for cls_name in ["AnnSearch", "BM25SearchRequest", "HNSWSearchParams", "Partition", "Row"]: + setattr(pymochow_model_table, cls_name, _SimpleObject) + + pymochow.auth = pymochow_auth + pymochow.model = pymochow_model + pymochow_auth.bce_credentials = pymochow_credentials + pymochow_model.database = pymochow_model_database + pymochow_model.enum = pymochow_model_enum + pymochow_model.schema = pymochow_model_schema + pymochow_model.table = pymochow_model_table + + modules = { + "pymochow": pymochow, + "pymochow.auth": pymochow_auth, + "pymochow.auth.bce_credentials": pymochow_credentials, + "pymochow.configuration": pymochow_configuration, + "pymochow.exception": pymochow_exception, + "pymochow.model": pymochow_model, + "pymochow.model.database": pymochow_model_database, + "pymochow.model.enum": pymochow_model_enum, + "pymochow.model.schema": pymochow_model_schema, + "pymochow.model.table": pymochow_model_table, + } + return modules + + +@pytest.fixture +def baidu_module(monkeypatch): + for name, module in _build_fake_pymochow_modules().items(): + monkeypatch.setitem(sys.modules, name, module) + import core.rag.datasource.vdb.baidu.baidu_vector as module + + return importlib.reload(module) + + +def test_baidu_config_validation(baidu_module): + values = { + "endpoint": "https://example.com", + "account": "account", + "api_key": "key", + "database": "database", + } + config = baidu_module.BaiduConfig.model_validate(values) + assert config.endpoint == "https://example.com" + + for key, error_message in [ + ("endpoint", "BAIDU_VECTOR_DB_ENDPOINT"), + ("account", "BAIDU_VECTOR_DB_ACCOUNT"), + ("api_key", "BAIDU_VECTOR_DB_API_KEY"), + ("database", "BAIDU_VECTOR_DB_DATABASE"), + ]: + invalid = dict(values) + invalid[key] = "" + with pytest.raises(ValueError, match=error_message): + baidu_module.BaiduConfig.model_validate(invalid) + + +def test_get_search_result_handles_metadata_and_threshold(baidu_module): + vector = baidu_module.BaiduVector.__new__(baidu_module.BaiduVector) + response = SimpleNamespace( + rows=[ + {"row": {"page_content": "doc1", "metadata": '{"document_id":"d1"}'}, "score": 0.9}, + {"row": {"page_content": "doc2", "metadata": {"document_id": "d2"}}, "score": 0.4}, + {"row": {"page_content": "doc3", "metadata": 123}, "score": 0.95}, + ] + ) + + docs = vector._get_search_res(response, score_threshold=0.8) + + assert len(docs) == 2 + assert docs[0].page_content == "doc1" + assert docs[0].metadata["score"] == 0.9 + assert docs[1].page_content == "doc3" + + +def test_delete_by_ids_and_delete_by_metadata_field(baidu_module): + vector = baidu_module.BaiduVector.__new__(baidu_module.BaiduVector) + table = MagicMock() + vector._db = MagicMock() + vector._db.table.return_value = table + vector._collection_name = "collection_1" + + vector.delete_by_ids([]) + table.delete.assert_not_called() + + vector.delete_by_ids(["id1", "id2"]) + table.delete.assert_called_once() + + table.delete.reset_mock() + vector.delete_by_metadata_field("source", 'abc"def') + delete_filter = table.delete.call_args.kwargs["filter"] + assert delete_filter == 'metadata["source"] = "abc\\"def"' + + +def test_delete_handles_table_not_exist_error_and_raises_for_other_codes(baidu_module): + vector = baidu_module.BaiduVector.__new__(baidu_module.BaiduVector) + vector._collection_name = "collection_1" + vector._db = MagicMock() + + vector._db.drop_table.side_effect = baidu_module.ServerError(baidu_module.ServerErrCode.TABLE_NOT_EXIST) + vector.delete() + + vector._db.drop_table.side_effect = baidu_module.ServerError(9999) + with pytest.raises(baidu_module.ServerError): + vector.delete() + + +def test_init_database_uses_existing_or_creates_when_missing(baidu_module): + vector = baidu_module.BaiduVector.__new__(baidu_module.BaiduVector) + vector._client = MagicMock() + vector._client_config = SimpleNamespace(database="my_db") + + vector._client.list_databases.return_value = [SimpleNamespace(database_name="my_db")] + vector._client.database.return_value = "existing_db" + assert vector._init_database() == "existing_db" + + vector._client.list_databases.return_value = [] + vector._client.database.return_value = "created_db" + vector._client.create_database.side_effect = None + assert vector._init_database() == "created_db" + + vector._client.create_database.side_effect = baidu_module.ServerError(baidu_module.ServerErrCode.DB_ALREADY_EXIST) + assert vector._init_database() == "created_db" + + +def test_table_existed_checks_table_access(baidu_module): + vector = baidu_module.BaiduVector.__new__(baidu_module.BaiduVector) + vector._collection_name = "collection_1" + vector._db = MagicMock() + vector._db.table.return_value = MagicMock() + + assert vector._table_existed() is True + + vector._db.table.side_effect = baidu_module.ServerError(baidu_module.ServerErrCode.TABLE_NOT_EXIST) + assert vector._table_existed() is False + + vector._db.table.side_effect = baidu_module.ServerError(9999) + with pytest.raises(baidu_module.ServerError): + vector._table_existed() + + +def test_search_methods_delegate_to_database_table(baidu_module): + vector = baidu_module.BaiduVector.__new__(baidu_module.BaiduVector) + vector._collection_name = "collection_1" + vector._db = MagicMock() + vector._get_search_res = MagicMock(return_value=[Document(page_content="doc", metadata={"doc_id": "1"})]) + + table = MagicMock() + vector._db.table.return_value = table + table.search.return_value = "vector_result" + table.bm25_search.return_value = "bm25_result" + + result1 = vector.search_by_vector([0.1, 0.2], top_k=3, document_ids_filter=["doc-1"], score_threshold=0.2) + result2 = vector.search_by_full_text("query", top_k=3, document_ids_filter=["doc-1"], score_threshold=0.2) + + assert result1 == vector._get_search_res.return_value + assert result2 == vector._get_search_res.return_value + assert vector._get_search_res.call_count == 2 + + +def test_factory_initializes_collection_name_and_index_struct(baidu_module, monkeypatch): + factory = baidu_module.BaiduVectorFactory() + dataset = SimpleNamespace(id="dataset-1", index_struct_dict=None, index_struct=None) + monkeypatch.setattr(baidu_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION") + monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_ENDPOINT", "https://endpoint") + monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_CONNECTION_TIMEOUT_MS", 1000) + monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_ACCOUNT", "account") + monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_API_KEY", "key") + monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_DATABASE", "database") + monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_SHARD", 1) + monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_REPLICAS", 1) + monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_INVERTED_INDEX_ANALYZER", "DEFAULT_ANALYZER") + monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_INVERTED_INDEX_PARSER_MODE", "COARSE_MODE") + monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_AUTO_BUILD_ROW_COUNT_INCREMENT", 500) + monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_AUTO_BUILD_ROW_COUNT_INCREMENT_RATIO", 0.05) + monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_REBUILD_INDEX_TIMEOUT_IN_SECONDS", 300) + + with patch.object(baidu_module, "BaiduVector", return_value="vector") as vector_cls: + result = factory.init_vector(dataset, attributes=[], embeddings=MagicMock()) + + assert result == "vector" + assert vector_cls.call_args.kwargs["collection_name"] == "auto_collection" + assert dataset.index_struct is not None + + +def test_init_get_type_to_index_struct_and_create_delegate(baidu_module, monkeypatch): + init_client = MagicMock(return_value="client") + init_database = MagicMock(return_value="database") + monkeypatch.setattr(baidu_module.BaiduVector, "_init_client", init_client) + monkeypatch.setattr(baidu_module.BaiduVector, "_init_database", init_database) + + config = baidu_module.BaiduConfig( + endpoint="https://example.com", + account="account", + api_key="key", + database="db", + ) + vector = baidu_module.BaiduVector(collection_name="my_collection", config=config) + + assert vector.get_type() == baidu_module.VectorType.BAIDU + assert vector.to_index_struct()["vector_store"]["class_prefix"] == "my_collection" + assert vector._client == "client" + assert vector._db == "database" + + vector._create_table = MagicMock() + vector.add_texts = MagicMock() + docs = [Document(page_content="p1", metadata={"doc_id": "d1"})] + vector.create(docs, [[0.1, 0.2]]) + vector._create_table.assert_called_once_with(2) + vector.add_texts.assert_called_once_with(docs, [[0.1, 0.2]]) + + +def test_add_texts_batches_rows(baidu_module): + vector = baidu_module.BaiduVector.__new__(baidu_module.BaiduVector) + vector._collection_name = "collection_1" + table = MagicMock() + vector._db = MagicMock() + vector._db.table.return_value = table + + docs = [ + Document(page_content="doc-1", metadata={"doc_id": "id-1", "document_id": "doc-1"}), + Document(page_content="doc-2", metadata={"doc_id": "id-2", "document_id": "doc-2"}), + ] + vector.add_texts(docs, [[0.1, 0.2], [0.3, 0.4]]) + + assert table.upsert.call_count == 1 + inserted_rows = table.upsert.call_args.kwargs["rows"] + assert len(inserted_rows) == 2 + + +def test_add_texts_batches_more_than_batch_size(baidu_module): + vector = baidu_module.BaiduVector.__new__(baidu_module.BaiduVector) + vector._collection_name = "collection_1" + table = MagicMock() + vector._db = MagicMock() + vector._db.table.return_value = table + + docs = [ + Document(page_content=f"doc-{idx}", metadata={"doc_id": f"id-{idx}", "document_id": f"doc-{idx}"}) + for idx in range(1001) + ] + embeddings = [[0.1, 0.2] for _ in range(1001)] + + vector.add_texts(docs, embeddings) + + assert table.upsert.call_count == 2 + assert len(table.upsert.call_args_list[0].kwargs["rows"]) == 1000 + assert len(table.upsert.call_args_list[1].kwargs["rows"]) == 1 + + +def test_text_exists_returns_false_when_query_code_is_not_success(baidu_module): + vector = baidu_module.BaiduVector.__new__(baidu_module.BaiduVector) + vector._collection_name = "collection_1" + table = MagicMock() + vector._db = MagicMock() + vector._db.table.return_value = table + + table.query.return_value = SimpleNamespace(code=0) + assert vector.text_exists("id-1") is True + + table.query.return_value = SimpleNamespace(code=1) + assert vector.text_exists("id-1") is False + + table.query.return_value = None + assert vector.text_exists("id-1") is False + + +def test_get_search_result_handles_invalid_metadata_json(baidu_module): + vector = baidu_module.BaiduVector.__new__(baidu_module.BaiduVector) + response = SimpleNamespace(rows=[{"row": {"page_content": "doc1", "metadata": "{bad json"}, "score": 0.7}]) + + docs = vector._get_search_res(response, score_threshold=0.1) + + assert len(docs) == 1 + assert docs[0].metadata["score"] == 0.7 + assert "document_id" not in docs[0].metadata + + +def test_init_client_constructs_configuration_and_client(baidu_module, monkeypatch): + credentials = MagicMock(return_value="credentials") + configuration = MagicMock(return_value="configuration") + client_cls = MagicMock(return_value="client") + monkeypatch.setattr(baidu_module, "BceCredentials", credentials) + monkeypatch.setattr(baidu_module, "Configuration", configuration) + monkeypatch.setattr(baidu_module, "MochowClient", client_cls) + + vector = baidu_module.BaiduVector.__new__(baidu_module.BaiduVector) + config = SimpleNamespace(account="account", api_key="key", endpoint="https://endpoint") + + client = vector._init_client(config) + + assert client == "client" + credentials.assert_called_once_with("account", "key") + configuration.assert_called_once_with(credentials="credentials", endpoint="https://endpoint") + client_cls.assert_called_once_with("configuration") + + +def test_init_database_raises_for_unknown_create_database_error(baidu_module): + vector = baidu_module.BaiduVector.__new__(baidu_module.BaiduVector) + vector._client = MagicMock() + vector._client_config = SimpleNamespace(database="my_db") + vector._client.list_databases.return_value = [] + vector._client.create_database.side_effect = baidu_module.ServerError(9999) + + with pytest.raises(baidu_module.ServerError): + vector._init_database() + + +def test_create_table_handles_cache_and_validation_paths(baidu_module, monkeypatch): + vector = baidu_module.BaiduVector.__new__(baidu_module.BaiduVector) + vector._collection_name = "collection_1" + vector._client_config = SimpleNamespace( + index_type="HNSW", + metric_type="IP", + inverted_index_analyzer="DEFAULT_ANALYZER", + inverted_index_parser_mode="COARSE_MODE", + auto_build_row_count_increment=500, + auto_build_row_count_increment_ratio=0.05, + rebuild_index_timeout_in_seconds=300, + replicas=1, + shard=1, + ) + vector._db = MagicMock() + table = MagicMock() + table.state = baidu_module.TableState.NORMAL + vector._db.describe_table.return_value = table + vector._table_existed = MagicMock(return_value=False) + vector.delete = MagicMock() + + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(baidu_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(baidu_module.redis_client, "set", MagicMock()) + monkeypatch.setattr(baidu_module.time, "sleep", lambda _s: None) + monkeypatch.setattr(vector, "_wait_for_index_ready", MagicMock()) + + # Cached table skips all work. + monkeypatch.setattr(baidu_module.redis_client, "get", MagicMock(return_value=1)) + vector._create_table(3) + vector._db.create_table.assert_not_called() + + # Existing table also skips creation. + monkeypatch.setattr(baidu_module.redis_client, "get", MagicMock(return_value=None)) + vector._table_existed.return_value = True + vector._create_table(3) + vector._db.create_table.assert_not_called() + + # Create table when cache is empty and table does not exist. + vector._table_existed.return_value = False + vector._create_table(3) + vector._db.create_table.assert_called_once() + baidu_module.redis_client.set.assert_called_once_with("vector_indexing_collection_1", 1, ex=3600) + table.rebuild_index.assert_called_once_with(vector.vector_index) + vector._wait_for_index_ready.assert_called_once_with(table, 3600) + + +def test_create_table_raises_for_invalid_index_or_metric(baidu_module, monkeypatch): + vector = baidu_module.BaiduVector.__new__(baidu_module.BaiduVector) + vector._collection_name = "collection_1" + vector._db = MagicMock() + vector._table_existed = MagicMock(return_value=False) + vector.delete = MagicMock() + vector._client_config = SimpleNamespace( + index_type="INVALID", + metric_type="IP", + inverted_index_analyzer="DEFAULT_ANALYZER", + inverted_index_parser_mode="COARSE_MODE", + auto_build_row_count_increment=500, + auto_build_row_count_increment_ratio=0.05, + rebuild_index_timeout_in_seconds=300, + replicas=1, + shard=1, + ) + + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(baidu_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(baidu_module.redis_client, "get", MagicMock(return_value=None)) + + with pytest.raises(ValueError, match="unsupported index_type"): + vector._create_table(3) + + vector._client_config.index_type = "HNSW" + vector._client_config.metric_type = "INVALID" + with pytest.raises(ValueError, match="unsupported metric_type"): + vector._create_table(3) + + +def test_create_table_raises_timeout_if_table_never_becomes_normal(baidu_module, monkeypatch): + vector = baidu_module.BaiduVector.__new__(baidu_module.BaiduVector) + vector._collection_name = "collection_1" + vector._client_config = SimpleNamespace( + index_type="HNSW", + metric_type="IP", + inverted_index_analyzer="DEFAULT_ANALYZER", + inverted_index_parser_mode="COARSE_MODE", + auto_build_row_count_increment=500, + auto_build_row_count_increment_ratio=0.05, + rebuild_index_timeout_in_seconds=300, + replicas=1, + shard=1, + ) + vector._db = MagicMock() + vector._db.describe_table.return_value = SimpleNamespace(state="CREATING") + vector._table_existed = MagicMock(return_value=False) + vector.delete = MagicMock() + + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(baidu_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(baidu_module.redis_client, "get", MagicMock(return_value=None)) + monkeypatch.setattr(baidu_module.time, "sleep", lambda _s: None) + monkeypatch.setattr(baidu_module.time, "time", MagicMock(side_effect=[0, 301])) + + with pytest.raises(TimeoutError, match="Table creation timeout"): + vector._create_table(3) + + +def test_factory_uses_existing_collection_prefix_when_index_struct_exists(baidu_module, monkeypatch): + factory = baidu_module.BaiduVectorFactory() + dataset = SimpleNamespace( + id="dataset-1", + index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}}, + index_struct=None, + ) + monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_ENDPOINT", "https://endpoint") + monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_CONNECTION_TIMEOUT_MS", 1000) + monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_ACCOUNT", "account") + monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_API_KEY", "key") + monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_DATABASE", "database") + monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_SHARD", 1) + monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_REPLICAS", 1) + monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_INVERTED_INDEX_ANALYZER", "DEFAULT_ANALYZER") + monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_INVERTED_INDEX_PARSER_MODE", "COARSE_MODE") + monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_AUTO_BUILD_ROW_COUNT_INCREMENT", 500) + monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_AUTO_BUILD_ROW_COUNT_INCREMENT_RATIO", 0.05) + monkeypatch.setattr(baidu_module.dify_config, "BAIDU_VECTOR_DB_REBUILD_INDEX_TIMEOUT_IN_SECONDS", 300) + + with patch.object(baidu_module, "BaiduVector", return_value="vector") as vector_cls: + result = factory.init_vector(dataset, attributes=[], embeddings=MagicMock()) + + assert result == "vector" + assert vector_cls.call_args.kwargs["collection_name"] == "existing_collection" diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/chroma/test_chroma_vector.py b/api/tests/unit_tests/core/rag/datasource/vdb/chroma/test_chroma_vector.py new file mode 100644 index 0000000000..44427b7d87 --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/chroma/test_chroma_vector.py @@ -0,0 +1,199 @@ +import importlib +import sys +import types +from collections import UserDict +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest + +from core.rag.models.document import Document + + +def _build_fake_chroma_modules(): + chroma = types.ModuleType("chromadb") + chroma.DEFAULT_TENANT = "default_tenant" + chroma.DEFAULT_DATABASE = "default_database" + + class Settings: + def __init__(self, **kwargs): + for key, value in kwargs.items(): + setattr(self, key, value) + + class QueryResult(UserDict): + pass + + class _Collection: + def __init__(self): + self.upsert = MagicMock() + self.delete = MagicMock() + self.query = MagicMock() + self.get = MagicMock(return_value={}) + + class _Client: + def __init__(self, **kwargs): + self.kwargs = kwargs + self.collection = _Collection() + self.get_or_create_collection = MagicMock(return_value=self.collection) + self.delete_collection = MagicMock() + + chroma.Settings = Settings + chroma.QueryResult = QueryResult + chroma.HttpClient = _Client + return chroma + + +@pytest.fixture +def chroma_module(monkeypatch): + fake_chroma = _build_fake_chroma_modules() + monkeypatch.setitem(sys.modules, "chromadb", fake_chroma) + import core.rag.datasource.vdb.chroma.chroma_vector as module + + return importlib.reload(module) + + +def test_chroma_config_to_params_builds_expected_payload(chroma_module): + config = chroma_module.ChromaConfig( + host="localhost", + port=8000, + tenant="tenant-1", + database="db-1", + auth_provider="provider", + auth_credentials="credentials", + ) + + params = config.to_chroma_params() + + assert params["host"] == "localhost" + assert params["port"] == 8000 + assert params["tenant"] == "tenant-1" + assert params["database"] == "db-1" + assert params["ssl"] is False + assert params["settings"].chroma_client_auth_provider == "provider" + assert params["settings"].chroma_client_auth_credentials == "credentials" + + +def test_create_collection_uses_redis_lock_and_cache(chroma_module, monkeypatch): + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(chroma_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(chroma_module.redis_client, "get", MagicMock(return_value=None)) + monkeypatch.setattr(chroma_module.redis_client, "set", MagicMock()) + + vector = chroma_module.ChromaVector( + collection_name="collection_1", + config=chroma_module.ChromaConfig(host="localhost", port=8000, tenant="t", database="d"), + ) + vector.create_collection("collection_1") + + vector._client.get_or_create_collection.assert_called_once_with("collection_1") + chroma_module.redis_client.set.assert_called_once() + + +def test_create_with_empty_texts_is_noop(chroma_module): + vector = chroma_module.ChromaVector( + collection_name="collection_1", + config=chroma_module.ChromaConfig(host="localhost", port=8000, tenant="t", database="d"), + ) + vector.create([], []) + vector._client.get_or_create_collection.assert_not_called() + + +def test_create_with_texts_creates_collection_and_upserts(chroma_module): + vector = chroma_module.ChromaVector( + collection_name="collection_1", + config=chroma_module.ChromaConfig(host="localhost", port=8000, tenant="t", database="d"), + ) + docs = [Document(page_content="hello", metadata={"doc_id": "d1", "document_id": "doc-1"})] + vector.create(docs, [[0.1, 0.2]]) + + vector._client.get_or_create_collection.assert_called() + vector._client.collection.upsert.assert_called_once() + + +def test_delete_methods_and_text_exists(chroma_module): + vector = chroma_module.ChromaVector( + collection_name="collection_1", + config=chroma_module.ChromaConfig(host="localhost", port=8000, tenant="t", database="d"), + ) + + vector.delete_by_ids([]) + vector._client.collection.delete.assert_not_called() + + vector.delete_by_ids(["id-1"]) + vector._client.collection.delete.assert_called_with(ids=["id-1"]) + + vector.delete_by_metadata_field("document_id", "doc-1") + vector._client.collection.delete.assert_called_with(where={"document_id": {"$eq": "doc-1"}}) + + vector._client.collection.get.return_value = {"ids": ["id-1"]} + assert vector.text_exists("id-1") is True + vector._client.collection.get.return_value = {} + assert vector.text_exists("id-2") is False + + vector.delete() + vector._client.delete_collection.assert_called_once_with("collection_1") + + +def test_search_by_vector_handles_empty_results(chroma_module): + vector = chroma_module.ChromaVector( + collection_name="collection_1", + config=chroma_module.ChromaConfig(host="localhost", port=8000, tenant="t", database="d"), + ) + vector._client.collection.query.return_value = {"ids": [], "documents": [], "metadatas": [], "distances": []} + + assert vector.search_by_vector([0.1, 0.2], top_k=2) == [] + + +def test_search_by_vector_applies_score_threshold_and_sorting(chroma_module): + vector = chroma_module.ChromaVector( + collection_name="collection_1", + config=chroma_module.ChromaConfig(host="localhost", port=8000, tenant="t", database="d"), + ) + vector._client.collection.query.return_value = { + "ids": [["id-1", "id-2"]], + "documents": [["doc high", "doc low"]], + "metadatas": [[{"doc_id": "id-1"}, {"doc_id": "id-2"}]], + "distances": [[0.1, 0.8]], + } + + docs = vector.search_by_vector([0.1, 0.2], top_k=2, score_threshold=0.5, document_ids_filter=["doc-1"]) + + assert len(docs) == 1 + assert docs[0].page_content == "doc high" + assert docs[0].metadata["score"] == 0.9 + + +def test_search_by_full_text_returns_empty_list(chroma_module): + vector = chroma_module.ChromaVector( + collection_name="collection_1", + config=chroma_module.ChromaConfig(host="localhost", port=8000, tenant="t", database="d"), + ) + assert vector.search_by_full_text("query") == [] + + +def test_factory_init_vector_uses_existing_or_generated_collection(chroma_module, monkeypatch): + factory = chroma_module.ChromaVectorFactory() + dataset_with_index = SimpleNamespace( + id="dataset-1", index_struct_dict={"vector_store": {"class_prefix": "EXISTING"}}, index_struct=None + ) + dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None) + + monkeypatch.setattr(chroma_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION") + monkeypatch.setattr(chroma_module.dify_config, "CHROMA_HOST", "localhost") + monkeypatch.setattr(chroma_module.dify_config, "CHROMA_PORT", 8000) + monkeypatch.setattr(chroma_module.dify_config, "CHROMA_TENANT", None) + monkeypatch.setattr(chroma_module.dify_config, "CHROMA_DATABASE", None) + monkeypatch.setattr(chroma_module.dify_config, "CHROMA_AUTH_PROVIDER", None) + monkeypatch.setattr(chroma_module.dify_config, "CHROMA_AUTH_CREDENTIALS", None) + + with patch.object(chroma_module, "ChromaVector", return_value="vector") as vector_cls: + result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock()) + result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock()) + + assert result_1 == "vector" + assert result_2 == "vector" + assert vector_cls.call_args_list[0].kwargs["collection_name"] == "existing" + assert vector_cls.call_args_list[1].kwargs["collection_name"] == "auto_collection" + assert dataset_without_index.index_struct is not None diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/clickzetta/test_clickzetta_vector.py b/api/tests/unit_tests/core/rag/datasource/vdb/clickzetta/test_clickzetta_vector.py new file mode 100644 index 0000000000..0ce5c04dd6 --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/clickzetta/test_clickzetta_vector.py @@ -0,0 +1,927 @@ +import importlib +import queue +import sys +import types +from contextlib import contextmanager +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest + +from core.rag.models.document import Document + + +def _build_fake_clickzetta_module(): + clickzetta = types.ModuleType("clickzetta") + + class _FakeCursor: + def __init__(self): + self.execute = MagicMock() + self.executemany = MagicMock() + self.fetchall = MagicMock(return_value=[]) + self.fetchone = MagicMock(return_value=(0,)) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + class _FakeConnection: + def __init__(self): + self.cursor_obj = _FakeCursor() + + def cursor(self): + return self.cursor_obj + + def close(self): + return None + + def connect(**_kwargs): + return _FakeConnection() + + clickzetta.connect = connect + return clickzetta + + +@pytest.fixture +def clickzetta_module(monkeypatch): + monkeypatch.setitem(sys.modules, "clickzetta", _build_fake_clickzetta_module()) + import core.rag.datasource.vdb.clickzetta.clickzetta_vector as module + + return importlib.reload(module) + + +def _config(module): + return module.ClickzettaConfig( + username="username", + password="password", + instance="instance", + service="service", + workspace="workspace", + vcluster="cluster", + schema_name="dify", + ) + + +@pytest.mark.parametrize( + ("field", "error_message"), + [ + ("username", "CLICKZETTA_USERNAME"), + ("password", "CLICKZETTA_PASSWORD"), + ("instance", "CLICKZETTA_INSTANCE"), + ("service", "CLICKZETTA_SERVICE"), + ("workspace", "CLICKZETTA_WORKSPACE"), + ("vcluster", "CLICKZETTA_VCLUSTER"), + ("schema_name", "CLICKZETTA_SCHEMA"), + ], +) +def test_clickzetta_config_validation(clickzetta_module, field, error_message): + values = _config(clickzetta_module).model_dump() + values[field] = "" + with pytest.raises(ValueError, match=error_message): + clickzetta_module.ClickzettaConfig.model_validate(values) + + +def test_parse_metadata_handles_valid_double_encoded_and_invalid_json(clickzetta_module): + vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector) + + parsed = vector._parse_metadata('{"document_id":"doc-1"}', "row-1") + assert parsed["doc_id"] == "row-1" + assert parsed["document_id"] == "doc-1" + + parsed_double = vector._parse_metadata('"{\\"document_id\\": \\"doc-2\\"}"', "row-2") + assert parsed_double["doc_id"] == "row-2" + assert parsed_double["document_id"] == "doc-2" + + parsed_fallback = vector._parse_metadata("not-json", "row-3") + assert parsed_fallback["doc_id"] == "row-3" + assert parsed_fallback["document_id"] == "row-3" + + +def test_safe_doc_id_and_vector_format_helpers(clickzetta_module): + vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector) + + assert vector._format_vector_simple([0.1, 0.2, 0.3]) == "0.1,0.2,0.3" + assert vector._safe_doc_id("abc-123_DEF") == "abc-123_DEF" + assert vector._safe_doc_id("ab c;\n") == "abc" + assert len(vector._safe_doc_id("a" * 300)) == 255 + + +def test_table_exists_returns_false_for_not_found_and_other_exceptions(clickzetta_module): + vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector) + vector._config = _config(clickzetta_module) + vector._table_name = "table_1" + + @contextmanager + def _ctx_not_found(): + connection = MagicMock() + cursor = MagicMock() + cursor.__enter__.return_value = cursor + cursor.__exit__.return_value = None + cursor.execute.side_effect = RuntimeError("CZLH-42000 table or view not found") + connection.cursor.return_value = cursor + yield connection + + vector.get_connection_context = _ctx_not_found + assert vector._table_exists() is False + + @contextmanager + def _ctx_other_error(): + connection = MagicMock() + cursor = MagicMock() + cursor.__enter__.return_value = cursor + cursor.__exit__.return_value = None + cursor.execute.side_effect = RuntimeError("permission denied") + connection.cursor.return_value = cursor + yield connection + + vector.get_connection_context = _ctx_other_error + assert vector._table_exists() is False + + +def test_text_exists_handles_missing_table_and_existing_rows(clickzetta_module): + vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector) + vector._config = _config(clickzetta_module) + vector._table_name = "table_1" + + vector._table_exists = MagicMock(return_value=False) + assert vector.text_exists("doc-1") is False + + vector._table_exists = MagicMock(return_value=True) + + @contextmanager + def _ctx(): + connection = MagicMock() + cursor = MagicMock() + cursor.__enter__.return_value = cursor + cursor.__exit__.return_value = None + cursor.fetchone.return_value = (1,) + connection.cursor.return_value = cursor + yield connection + + vector.get_connection_context = _ctx + assert vector.text_exists("doc-1") is True + + +def test_delete_by_ids_and_delete_by_metadata_field_short_circuit(clickzetta_module): + vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector) + vector._config = _config(clickzetta_module) + vector._table_name = "table_1" + vector._execute_write = MagicMock() + + vector.delete_by_ids([]) + vector._execute_write.assert_not_called() + + vector._table_exists = MagicMock(return_value=False) + vector.delete_by_ids(["doc-1"]) + vector._execute_write.assert_not_called() + + vector.delete_by_metadata_field("document_id", "doc-1") + vector._execute_write.assert_not_called() + + +def test_search_short_circuit_behaviors(clickzetta_module): + vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector) + vector._config = _config(clickzetta_module) + vector._table_name = "table_1" + + vector._table_exists = MagicMock(return_value=False) + assert vector.search_by_vector([0.1, 0.2], top_k=2) == [] + + vector._config.enable_inverted_index = False + assert vector.search_by_full_text("query", top_k=2) == [] + + +def test_search_by_like_returns_documents_with_default_score(clickzetta_module): + vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector) + vector._config = _config(clickzetta_module) + vector._table_name = "table_1" + vector._table_exists = MagicMock(return_value=True) + vector._parse_metadata = MagicMock(return_value={"document_id": "doc-1", "doc_id": "seg-1"}) + + @contextmanager + def _ctx(): + connection = MagicMock() + cursor = MagicMock() + cursor.__enter__.return_value = cursor + cursor.__exit__.return_value = None + cursor.fetchall.return_value = [("seg-1", "content", '{"document_id":"doc-1"}')] + connection.cursor.return_value = cursor + yield connection + + vector.get_connection_context = _ctx + docs = vector._search_by_like("query", top_k=3, document_ids_filter=["doc-1"]) + + assert len(docs) == 1 + assert docs[0].page_content == "content" + assert docs[0].metadata["score"] == 0.5 + + +def test_factory_initializes_clickzetta_vector(clickzetta_module, monkeypatch): + factory = clickzetta_module.ClickzettaVectorFactory() + dataset = SimpleNamespace(id="dataset-1") + + monkeypatch.setattr(clickzetta_module.Dataset, "gen_collection_name_by_id", lambda _id: "COLLECTION") + monkeypatch.setattr(clickzetta_module.dify_config, "CLICKZETTA_USERNAME", "username") + monkeypatch.setattr(clickzetta_module.dify_config, "CLICKZETTA_PASSWORD", "password") + monkeypatch.setattr(clickzetta_module.dify_config, "CLICKZETTA_INSTANCE", "instance") + monkeypatch.setattr(clickzetta_module.dify_config, "CLICKZETTA_SERVICE", "service") + monkeypatch.setattr(clickzetta_module.dify_config, "CLICKZETTA_WORKSPACE", "workspace") + monkeypatch.setattr(clickzetta_module.dify_config, "CLICKZETTA_VCLUSTER", "cluster") + monkeypatch.setattr(clickzetta_module.dify_config, "CLICKZETTA_SCHEMA", "dify") + monkeypatch.setattr(clickzetta_module.dify_config, "CLICKZETTA_BATCH_SIZE", 10) + monkeypatch.setattr(clickzetta_module.dify_config, "CLICKZETTA_ENABLE_INVERTED_INDEX", True) + monkeypatch.setattr(clickzetta_module.dify_config, "CLICKZETTA_ANALYZER_TYPE", "chinese") + monkeypatch.setattr(clickzetta_module.dify_config, "CLICKZETTA_ANALYZER_MODE", "smart") + monkeypatch.setattr(clickzetta_module.dify_config, "CLICKZETTA_VECTOR_DISTANCE_FUNCTION", "cosine_distance") + + with patch.object(clickzetta_module, "ClickzettaVector", return_value="vector") as vector_cls: + result = factory.init_vector(dataset, attributes=[], embeddings=MagicMock()) + + assert result == "vector" + assert vector_cls.call_args.kwargs["collection_name"] == "collection" + + +def test_connection_pool_singleton_and_config_key(clickzetta_module, monkeypatch): + clickzetta_module.ClickzettaConnectionPool._instance = None + monkeypatch.setattr(clickzetta_module.ClickzettaConnectionPool, "_start_cleanup_thread", MagicMock()) + + pool_1 = clickzetta_module.ClickzettaConnectionPool.get_instance() + pool_2 = clickzetta_module.ClickzettaConnectionPool.get_instance() + key = pool_1._get_config_key(_config(clickzetta_module)) + + assert pool_1 is pool_2 + assert "username:instance:service:workspace:cluster:dify" in key + + +def test_connection_pool_create_connection_retries_and_configures(clickzetta_module, monkeypatch): + monkeypatch.setattr(clickzetta_module.ClickzettaConnectionPool, "_start_cleanup_thread", MagicMock()) + pool = clickzetta_module.ClickzettaConnectionPool() + config = _config(clickzetta_module) + connection = MagicMock() + + monkeypatch.setattr(clickzetta_module.time, "sleep", lambda _s: None) + monkeypatch.setattr( + clickzetta_module.clickzetta, "connect", MagicMock(side_effect=[RuntimeError("boom"), connection]) + ) + pool._configure_connection = MagicMock() + + created = pool._create_connection(config) + + assert created is connection + assert clickzetta_module.clickzetta.connect.call_count == 2 + pool._configure_connection.assert_called_once_with(connection) + + +def test_connection_pool_create_connection_raises_after_retries(clickzetta_module, monkeypatch): + monkeypatch.setattr(clickzetta_module.ClickzettaConnectionPool, "_start_cleanup_thread", MagicMock()) + pool = clickzetta_module.ClickzettaConnectionPool() + config = _config(clickzetta_module) + + monkeypatch.setattr(clickzetta_module.time, "sleep", lambda _s: None) + monkeypatch.setattr(clickzetta_module.clickzetta, "connect", MagicMock(side_effect=RuntimeError("boom"))) + + with pytest.raises(RuntimeError, match="boom"): + pool._create_connection(config) + + +def test_connection_pool_configure_and_validate_connection(clickzetta_module): + monkeypatch = pytest.MonkeyPatch() + monkeypatch.setattr(clickzetta_module.ClickzettaConnectionPool, "_start_cleanup_thread", MagicMock()) + pool = clickzetta_module.ClickzettaConnectionPool() + + cursor = MagicMock() + cursor.__enter__.return_value = cursor + cursor.__exit__.return_value = None + connection = MagicMock() + connection.cursor.return_value = cursor + + pool._configure_connection(connection) + assert cursor.execute.call_count >= 2 + assert pool._is_connection_valid(connection) is True + + bad_connection = MagicMock() + bad_connection.cursor.side_effect = RuntimeError("bad connection") + assert pool._is_connection_valid(bad_connection) is False + monkeypatch.undo() + + +def test_connection_pool_configure_connection_swallows_errors(clickzetta_module): + monkeypatch = pytest.MonkeyPatch() + monkeypatch.setattr(clickzetta_module.ClickzettaConnectionPool, "_start_cleanup_thread", MagicMock()) + pool = clickzetta_module.ClickzettaConnectionPool() + connection = MagicMock() + connection.cursor.side_effect = RuntimeError("cannot configure") + + pool._configure_connection(connection) + monkeypatch.undo() + + +def test_connection_pool_get_return_cleanup_and_shutdown(clickzetta_module, monkeypatch): + monkeypatch.setattr(clickzetta_module.ClickzettaConnectionPool, "_start_cleanup_thread", MagicMock()) + pool = clickzetta_module.ClickzettaConnectionPool() + config = _config(clickzetta_module) + key = pool._get_config_key(config) + + created_connection = MagicMock() + pool._create_connection = MagicMock(return_value=created_connection) + first = pool.get_connection(config) + assert first is created_connection + + reusable_connection = MagicMock() + pool._pools[key] = [(reusable_connection, clickzetta_module.time.time())] + pool._is_connection_valid = MagicMock(return_value=True) + reused = pool.get_connection(config) + assert reused is reusable_connection + + expired_connection = MagicMock() + pool._pools[key] = [(expired_connection, 0.0)] + pool._is_connection_valid = MagicMock(return_value=False) + monkeypatch.setattr(clickzetta_module.time, "time", MagicMock(return_value=1000.0)) + pool.get_connection(config) + expired_connection.close.assert_called_once() + + random_connection = MagicMock() + pool._is_connection_valid = MagicMock(return_value=True) + pool.return_connection(config, random_connection) + assert len(pool._pools[key]) == 1 + + pool._pools[key] = [(MagicMock(), 0.0), (MagicMock(), 1000.0)] + pool._connection_timeout = 10 + pool._cleanup_expired_connections() + assert len(pool._pools[key]) == 1 + + unknown_pool = MagicMock() + pool.return_connection(_config(clickzetta_module).model_copy(update={"workspace": "other"}), unknown_pool) + unknown_pool.close.assert_called_once() + + pool.shutdown() + assert pool._shutdown is True + + +def test_connection_pool_start_cleanup_thread_runs_worker_once(clickzetta_module, monkeypatch): + pool = clickzetta_module.ClickzettaConnectionPool.__new__(clickzetta_module.ClickzettaConnectionPool) + pool._shutdown = False + pool._cleanup_expired_connections = MagicMock(side_effect=lambda: setattr(pool, "_shutdown", True)) + + monkeypatch.setattr(clickzetta_module.time, "sleep", lambda _s: None) + + class _Thread: + def __init__(self, target, daemon): + self._target = target + self.daemon = daemon + self.started = False + + def start(self): + self.started = True + self._target() + + monkeypatch.setattr(clickzetta_module.threading, "Thread", _Thread) + pool._start_cleanup_thread() + + assert pool._cleanup_thread.started is True + pool._cleanup_expired_connections.assert_called_once() + + +def test_vector_init_connection_context_and_helpers(clickzetta_module, monkeypatch): + pool = MagicMock() + pool.get_connection.return_value = "conn" + monkeypatch.setattr(clickzetta_module.ClickzettaConnectionPool, "get_instance", MagicMock(return_value=pool)) + monkeypatch.setattr(clickzetta_module.ClickzettaVector, "_init_write_queue", MagicMock()) + + vector = clickzetta_module.ClickzettaVector("My-Collection", _config(clickzetta_module)) + assert vector._table_name == "my_collection" + + assert vector._get_connection() == "conn" + vector._return_connection("conn") + pool.return_connection.assert_called_with(vector._config, "conn") + + with vector.get_connection_context() as conn: + assert conn == "conn" + assert pool.return_connection.call_count >= 2 + + assert vector.get_type() == "clickzetta" + assert vector._ensure_connection() == "conn" + + +def test_write_queue_initialization_worker_and_execute_write(clickzetta_module, monkeypatch): + class _Thread: + def __init__(self, target, daemon): + self.target = target + self.daemon = daemon + self.started = 0 + + def start(self): + self.started += 1 + + monkeypatch.setattr(clickzetta_module.threading, "Thread", _Thread) + clickzetta_module.ClickzettaVector._write_queue = None + clickzetta_module.ClickzettaVector._write_thread = None + clickzetta_module.ClickzettaVector._shutdown = False + clickzetta_module.ClickzettaVector._init_write_queue() + clickzetta_module.ClickzettaVector._init_write_queue() + assert clickzetta_module.ClickzettaVector._write_thread.started == 1 + + result_queue_ok = queue.Queue() + result_queue_fail = queue.Queue() + clickzetta_module.ClickzettaVector._write_queue = queue.Queue() + clickzetta_module.ClickzettaVector._shutdown = False + clickzetta_module.ClickzettaVector._write_queue.put((lambda x: x + 1, (1,), {}, result_queue_ok)) + clickzetta_module.ClickzettaVector._write_queue.put( + (lambda: (_ for _ in ()).throw(RuntimeError("worker error")), (), {}, result_queue_fail) + ) + clickzetta_module.ClickzettaVector._write_queue.put(None) + clickzetta_module.ClickzettaVector._write_worker() + + assert result_queue_ok.get() == (True, 2) + failed = result_queue_fail.get() + assert failed[0] is False + assert isinstance(failed[1], RuntimeError) + + vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector) + clickzetta_module.ClickzettaVector._write_queue = None + with pytest.raises(RuntimeError, match="Write queue not initialized"): + vector._execute_write(lambda: None) + + class _ImmediateSuccessQueue: + def put(self, task): + func, args, kwargs, result_q = task + result_q.put((True, func(*args, **kwargs))) + + clickzetta_module.ClickzettaVector._write_queue = _ImmediateSuccessQueue() + assert vector._execute_write(lambda x: x * 2, 3) == 6 + + class _ImmediateFailQueue: + def put(self, task): + _, _, _, result_q = task + result_q.put((False, ValueError("write failed"))) + + clickzetta_module.ClickzettaVector._write_queue = _ImmediateFailQueue() + with pytest.raises(ValueError, match="write failed"): + vector._execute_write(lambda: None) + + +def test_table_exists_true_and_create_invokes_write_and_add_texts(clickzetta_module): + vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector) + vector._config = _config(clickzetta_module) + vector._table_name = "table_1" + + @contextmanager + def _ctx_exists(): + connection = MagicMock() + cursor = MagicMock() + cursor.__enter__.return_value = cursor + cursor.__exit__.return_value = None + connection.cursor.return_value = cursor + yield connection + + vector.get_connection_context = _ctx_exists + assert vector._table_exists() is True + + vector._execute_write = MagicMock() + vector.add_texts = MagicMock() + docs = [Document(page_content="content", metadata={"doc_id": "d1"})] + vector.create(docs, [[0.1, 0.2]]) + vector._execute_write.assert_called_once() + vector.add_texts.assert_called_once_with(docs, [[0.1, 0.2]]) + + +def test_create_table_and_indexes_paths(clickzetta_module): + vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector) + vector._config = _config(clickzetta_module) + vector._table_name = "table_1" + vector._create_vector_index = MagicMock() + vector._create_inverted_index = MagicMock() + + vector._table_exists = MagicMock(return_value=True) + vector._create_table_and_indexes([[0.1, 0.2]]) + vector._create_vector_index.assert_not_called() + + vector._table_exists = MagicMock(return_value=False) + + @contextmanager + def _ctx(): + connection = MagicMock() + cursor = MagicMock() + cursor.__enter__.return_value = cursor + cursor.__exit__.return_value = None + connection.cursor.return_value = cursor + yield connection + + vector.get_connection_context = _ctx + vector._create_table_and_indexes([[0.1, 0.2, 0.3]]) + vector._create_vector_index.assert_called_once() + vector._create_inverted_index.assert_called_once() + + vector._config.enable_inverted_index = False + vector._create_vector_index.reset_mock() + vector._create_inverted_index.reset_mock() + vector._create_table_and_indexes([]) + vector._create_vector_index.assert_called_once() + vector._create_inverted_index.assert_not_called() + + +def test_create_vector_index_branches(clickzetta_module): + vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector) + vector._config = _config(clickzetta_module) + vector._table_name = "table_1" + cursor = MagicMock() + + cursor.fetchall.return_value = [("idx_table_vector", "embedding_vector")] + vector._create_vector_index(cursor) + assert cursor.execute.call_count == 1 + + cursor.reset_mock() + cursor.execute.side_effect = [RuntimeError("show index failed"), None] + vector._create_vector_index(cursor) + assert cursor.execute.call_count == 2 + + cursor.reset_mock() + cursor.execute.side_effect = [None, RuntimeError("already exists")] + cursor.fetchall.return_value = [] + vector._create_vector_index(cursor) + + cursor.reset_mock() + cursor.execute.side_effect = [None, RuntimeError("unexpected")] + cursor.fetchall.return_value = [] + with pytest.raises(RuntimeError, match="unexpected"): + vector._create_vector_index(cursor) + + +def test_create_inverted_index_branches(clickzetta_module): + vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector) + vector._config = _config(clickzetta_module) + vector._table_name = "table_1" + cursor = MagicMock() + + cursor.fetchall.return_value = [("idx_table_1_text", "INVERTED", "page_content")] + vector._create_inverted_index(cursor) + assert cursor.execute.call_count == 1 + + cursor.reset_mock() + cursor.execute.side_effect = [RuntimeError("show failed"), None] + vector._create_inverted_index(cursor) + assert cursor.execute.call_count == 2 + + cursor.reset_mock() + cursor.execute.side_effect = [ + None, + RuntimeError("already has index"), + None, + ] + cursor.fetchall.return_value = [("idx_table_1_text", "INVERTED", "page_content")] + vector._create_inverted_index(cursor) + + cursor.reset_mock() + cursor.execute.side_effect = [None, RuntimeError("other create failure")] + cursor.fetchall.return_value = [] + vector._create_inverted_index(cursor) + + +def test_add_texts_batches_and_insert_batch_behaviors(clickzetta_module, monkeypatch): + vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector) + vector._config = _config(clickzetta_module) + vector._config.batch_size = 2 + vector._table_name = "table_1" + vector._execute_write = MagicMock() + vector._safe_doc_id = MagicMock(side_effect=lambda doc_id: str(doc_id)) + + docs = [ + Document(page_content="doc-1", metadata={"doc_id": "id-1"}), + Document(page_content="doc-2", metadata={"doc_id": "id-2"}), + Document(page_content="doc-3", metadata={"doc_id": "id-3"}), + ] + vectors = [[0.1], [0.2], [0.3]] + + vector.add_texts([], []) + vector._execute_write.assert_not_called() + + added_ids = vector.add_texts(docs, vectors) + assert added_ids == ["id-1", "id-2", "id-3"] + assert vector._execute_write.call_count == 2 + assert vector._execute_write.call_args_list[0].args == ( + vector._insert_batch, + docs[:2], + vectors[:2], + ["id-1", "id-2"], + 0, + 2, + 2, + ) + assert vector._execute_write.call_args_list[1].args == ( + vector._insert_batch, + docs[2:], + vectors[2:], + ["id-3"], + 2, + 2, + 2, + ) + + vector._insert_batch([], [], [], 0, 2, 1) + vector._insert_batch(docs[:1], vectors, ["id-1"], 0, 2, 1) + + bad_doc = Document(page_content="doc-bad", metadata={"doc_id": "id-bad", "bad": {1}}) + good_doc = Document(page_content="doc-good", metadata={"doc_id": "id-good"}) + + @contextmanager + def _ctx(): + connection = MagicMock() + cursor = MagicMock() + cursor.__enter__.return_value = cursor + cursor.__exit__.return_value = None + connection.cursor.return_value = cursor + yield connection + + vector.get_connection_context = _ctx + vector._insert_batch( + [bad_doc, good_doc], + [[0.1, 0.2], [0.3, 0.4]], + ["id-bad", "id-good"], + 0, + 2, + 1, + ) + + @contextmanager + def _ctx_error(): + connection = MagicMock() + cursor = MagicMock() + cursor.__enter__.return_value = cursor + cursor.__exit__.return_value = None + cursor.executemany.side_effect = RuntimeError("insert failed") + connection.cursor.return_value = cursor + yield connection + + vector.get_connection_context = _ctx_error + with pytest.raises(RuntimeError, match="insert failed"): + vector._insert_batch([good_doc], [[0.1, 0.2]], ["id-good"], 0, 1, 1) + + monkeypatch.setattr(clickzetta_module.uuid, "uuid4", lambda: "generated-id") + vector._safe_doc_id = clickzetta_module.ClickzettaVector._safe_doc_id.__get__(vector) + assert vector._safe_doc_id("") == "generated-id" + assert vector._safe_doc_id("!!!") == "generated-id" + + +def test_delete_by_ids_and_metadata_impl_paths(clickzetta_module): + vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector) + vector._config = _config(clickzetta_module) + vector._table_name = "table_1" + vector._execute_write = MagicMock() + vector._table_exists = MagicMock(return_value=True) + + vector.delete_by_ids(["id-1", "id-2"]) + vector._execute_write.assert_called_once() + assert vector._execute_write.call_args.args[0] == vector._delete_by_ids_impl + + vector._execute_write.reset_mock() + vector.delete_by_metadata_field("document_id", "doc-1") + vector._execute_write.assert_called_once() + assert vector._execute_write.call_args.args[0] == vector._delete_by_metadata_field_impl + + vector._safe_doc_id = MagicMock(side_effect=lambda x: x) + + @contextmanager + def _ctx(): + connection = MagicMock() + cursor = MagicMock() + cursor.__enter__.return_value = cursor + cursor.__exit__.return_value = None + connection.cursor.return_value = cursor + yield connection + + vector.get_connection_context = _ctx + vector._delete_by_ids_impl(["id-1", "id-2"]) + vector._delete_by_metadata_field_impl("document_id", "doc-1") + + +def test_search_by_vector_covers_cosine_and_l2_paths(clickzetta_module): + vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector) + vector._config = _config(clickzetta_module) + vector._config.vector_distance_function = "cosine_distance" + vector._table_name = "table_1" + vector._table_exists = MagicMock(return_value=True) + vector._parse_metadata = MagicMock(return_value={"document_id": "doc-1", "doc_id": "seg-1"}) + + @contextmanager + def _ctx(): + connection = MagicMock() + cursor = MagicMock() + cursor.__enter__.return_value = cursor + cursor.__exit__.return_value = None + cursor.fetchall.return_value = [("seg-1", "content", '{"document_id":"doc-1"}', 0.2)] + connection.cursor.return_value = cursor + yield connection + + vector.get_connection_context = _ctx + cosine_docs = vector.search_by_vector( + [0.1, 0.2], top_k=3, score_threshold=0.5, document_ids_filter=["doc-1"], filter={"k": "v"} + ) + assert cosine_docs[0].metadata["score"] == pytest.approx(0.9) + + vector._config.vector_distance_function = "l2_distance" + l2_docs = vector.search_by_vector([0.1, 0.2], top_k=3, score_threshold=0.5) + assert l2_docs[0].metadata["score"] == pytest.approx(1 / 1.2) + + +def test_search_by_full_text_success_and_fallback(clickzetta_module): + vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector) + vector._config = _config(clickzetta_module) + vector._table_name = "table_1" + vector._table_exists = MagicMock(return_value=True) + + @contextmanager + def _ctx_success(): + connection = MagicMock() + cursor = MagicMock() + cursor.__enter__.return_value = cursor + cursor.__exit__.return_value = None + cursor.fetchall.return_value = [ + ("seg-1", "content-1", '"{\\"document_id\\":\\"doc-1\\"}"'), + ("seg-2", "content-2", "invalid-json"), + ] + connection.cursor.return_value = cursor + yield connection + + vector.get_connection_context = _ctx_success + docs = vector.search_by_full_text("search'value", top_k=2, document_ids_filter=["doc-1"], filter={"a": 1}) + assert len(docs) == 2 + assert docs[0].metadata["score"] == 1.0 + assert docs[1].metadata["doc_id"] == "seg-2" + + @contextmanager + def _ctx_failure(): + connection = MagicMock() + cursor = MagicMock() + cursor.__enter__.return_value = cursor + cursor.__exit__.return_value = None + cursor.execute.side_effect = RuntimeError("full text failed") + connection.cursor.return_value = cursor + yield connection + + vector.get_connection_context = _ctx_failure + vector._search_by_like = MagicMock(return_value=[Document(page_content="fallback", metadata={"score": 0.5})]) + fallback_docs = vector.search_by_full_text("query", top_k=1) + assert fallback_docs == vector._search_by_like.return_value + + +def test_search_by_like_missing_table_and_delete_table(clickzetta_module): + vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector) + vector._config = _config(clickzetta_module) + vector._table_name = "table_1" + vector._table_exists = MagicMock(return_value=False) + assert vector._search_by_like("query", top_k=1) == [] + + @contextmanager + def _ctx(): + connection = MagicMock() + cursor = MagicMock() + cursor.__enter__.return_value = cursor + cursor.__exit__.return_value = None + connection.cursor.return_value = cursor + yield connection + + vector.get_connection_context = _ctx + vector.delete() + + +def test_clickzetta_pool_cleanup_and_shutdown_edge_paths(clickzetta_module): + pool = clickzetta_module.ClickzettaConnectionPool.__new__(clickzetta_module.ClickzettaConnectionPool) + pool._pools = {} + pool._pool_locks = {} + pool._max_pool_size = 1 + pool._connection_timeout = 10 + pool._lock = clickzetta_module.threading.Lock() + pool._shutdown = False + + config = _config(clickzetta_module) + key = pool._get_config_key(config) + pool._pools[key] = [(MagicMock(), 1.0)] + pool._pool_locks[key] = clickzetta_module.threading.Lock() + pool._is_connection_valid = MagicMock(return_value=False) + + conn = MagicMock() + pool.return_connection(config, conn) + conn.close.assert_called_once() + + pool._pools["missing-lock-key"] = [(MagicMock(), 0.0)] + pool._cleanup_expired_connections() + pool.shutdown() + assert pool._shutdown is True + + +def test_clickzetta_pool_cleanup_thread_and_worker_exception_paths(clickzetta_module, monkeypatch): + pool = clickzetta_module.ClickzettaConnectionPool.__new__(clickzetta_module.ClickzettaConnectionPool) + pool._shutdown = False + + def _cleanup_then_fail(): + pool._shutdown = True + raise RuntimeError("cleanup failed") + + pool._cleanup_expired_connections = MagicMock(side_effect=_cleanup_then_fail) + monkeypatch.setattr(clickzetta_module.time, "sleep", lambda _s: None) + + class _Thread: + def __init__(self, target, daemon): + self._target = target + self.daemon = daemon + + def start(self): + self._target() + + monkeypatch.setattr(clickzetta_module.threading, "Thread", _Thread) + pool._start_cleanup_thread() + pool._cleanup_expired_connections.assert_called_once() + + +def test_clickzetta_parse_metadata_and_write_worker_additional_branches(clickzetta_module): + vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector) + + parsed_non_dict = vector._parse_metadata("[1,2,3]", "row-1") + assert parsed_non_dict["doc_id"] == "row-1" + assert parsed_non_dict["document_id"] == "row-1" + + parsed_none = vector._parse_metadata(None, "row-2") + assert parsed_none["doc_id"] == "row-2" + assert parsed_none["document_id"] == "row-2" + + clickzetta_module.ClickzettaVector._shutdown = False + clickzetta_module.ClickzettaVector._write_queue = None + clickzetta_module.ClickzettaVector._write_worker() + + class _BadQueue: + def get(self, timeout): + clickzetta_module.ClickzettaVector._shutdown = True + raise RuntimeError("queue failed") + + clickzetta_module.ClickzettaVector._shutdown = False + clickzetta_module.ClickzettaVector._write_queue = _BadQueue() + clickzetta_module.ClickzettaVector._write_worker() + + +def test_clickzetta_inverted_index_existing_and_insert_non_dict_metadata(clickzetta_module): + vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector) + vector._config = _config(clickzetta_module) + vector._table_name = "table_1" + cursor = MagicMock() + cursor.fetchall.return_value = [("idx_table_1_text", "INVERTED", "page_content")] + cursor.execute.side_effect = [ + None, + RuntimeError("already has index with the same type cannot create inverted index"), + None, + ] + + vector._create_inverted_index(cursor) + + vector._safe_doc_id = MagicMock(side_effect=lambda value: str(value)) + + @contextmanager + def _ctx(): + connection = MagicMock() + cursor_obj = MagicMock() + cursor_obj.__enter__.return_value = cursor_obj + cursor_obj.__exit__.return_value = None + connection.cursor.return_value = cursor_obj + yield connection + + vector.get_connection_context = _ctx + vector._insert_batch( + [SimpleNamespace(page_content="content", metadata="not-a-dict")], + [[0.1, 0.2]], + ["doc-1"], + 0, + 1, + 1, + ) + + +def test_clickzetta_full_text_table_missing_and_non_dict_metadata(clickzetta_module): + vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector) + vector._config = _config(clickzetta_module) + vector._config.enable_inverted_index = True + vector._table_name = "table_1" + + vector._table_exists = MagicMock(return_value=False) + assert vector.search_by_full_text("query") == [] + + vector._table_exists = MagicMock(return_value=True) + + @contextmanager + def _ctx(): + connection = MagicMock() + cursor = MagicMock() + cursor.__enter__.return_value = cursor + cursor.__exit__.return_value = None + cursor.fetchall.return_value = [ + ("seg-1", "content-1", "[1,2,3]"), + ("seg-2", "content-2", None), + ] + connection.cursor.return_value = cursor + yield connection + + vector.get_connection_context = _ctx + docs = vector.search_by_full_text("query") + assert len(docs) == 2 + assert docs[0].metadata["doc_id"] == "seg-1" + assert docs[1].metadata["doc_id"] == "seg-2" diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/couchbase/test_couchbase_vector.py b/api/tests/unit_tests/core/rag/datasource/vdb/couchbase/test_couchbase_vector.py new file mode 100644 index 0000000000..9fea187615 --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/couchbase/test_couchbase_vector.py @@ -0,0 +1,364 @@ +import importlib +import sys +import types +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from pydantic import ValidationError + +from core.rag.models.document import Document + + +def _build_fake_couchbase_modules(): + couchbase = types.ModuleType("couchbase") + couchbase_auth = types.ModuleType("couchbase.auth") + couchbase_cluster = types.ModuleType("couchbase.cluster") + couchbase_management = types.ModuleType("couchbase.management") + couchbase_management_search = types.ModuleType("couchbase.management.search") + couchbase_options = types.ModuleType("couchbase.options") + couchbase_vector = types.ModuleType("couchbase.vector_search") + couchbase_search = types.ModuleType("couchbase.search") + + class PasswordAuthenticator: + def __init__(self, user, password): + self.user = user + self.password = password + + class ClusterOptions: + def __init__(self, auth): + self.auth = auth + + class SearchOptions: + def __init__(self, **kwargs): + self.kwargs = kwargs + + class VectorQuery: + def __init__(self, field, vector, top_k): + self.field = field + self.vector = vector + self.top_k = top_k + + class VectorSearch: + @staticmethod + def from_vector_query(vector_query): + return {"vector_query": vector_query} + + class QueryStringQuery: + def __init__(self, query): + self.query = query + + class SearchRequest: + @staticmethod + def create(payload): + return {"payload": payload} + + class SearchIndex: + def __init__(self, name, params, source_name): + self.name = name + self.params = params + self.source_name = source_name + + class _QueryResult: + def __init__(self, rows=None): + self._rows = rows or [] + + def execute(self): + return self + + def __iter__(self): + return iter(self._rows) + + class _SearchIter: + def __init__(self, rows=None): + self._rows = rows or [] + + def rows(self): + return self._rows + + class _Collection: + def __init__(self): + self.upsert = MagicMock(return_value=True) + + class _SearchIndexManager: + def __init__(self): + self.upsert_index = MagicMock() + + class _Scope: + def __init__(self): + self._collection = _Collection() + self._search_index_manager = _SearchIndexManager() + self.search = MagicMock(return_value=_SearchIter()) + + def collection(self, _name): + return self._collection + + def search_indexes(self): + return self._search_index_manager + + class _CollectionManager: + def __init__(self): + self.create_collection = MagicMock() + self.drop_collection = MagicMock() + self.get_all_scopes = MagicMock(return_value=[]) + + class _Bucket: + def __init__(self): + self._scope = _Scope() + self._collections = _CollectionManager() + + def scope(self, _scope_name): + return self._scope + + def collections(self): + return self._collections + + class Cluster: + def __init__(self, connection_string, options): + self.connection_string = connection_string + self.options = options + self._bucket = _Bucket() + self.wait_until_ready = MagicMock() + self.query = MagicMock(return_value=_QueryResult()) + + def bucket(self, _name): + return self._bucket + + couchbase_auth.PasswordAuthenticator = PasswordAuthenticator + couchbase_cluster.Cluster = Cluster + couchbase_management_search.SearchIndex = SearchIndex + couchbase_options.ClusterOptions = ClusterOptions + couchbase_options.SearchOptions = SearchOptions + couchbase_vector.VectorQuery = VectorQuery + couchbase_vector.VectorSearch = VectorSearch + couchbase_search.QueryStringQuery = QueryStringQuery + couchbase_search.SearchRequest = SearchRequest + + couchbase.search = couchbase_search + couchbase.management = couchbase_management + + return { + "couchbase": couchbase, + "couchbase.auth": couchbase_auth, + "couchbase.cluster": couchbase_cluster, + "couchbase.management": couchbase_management, + "couchbase.management.search": couchbase_management_search, + "couchbase.options": couchbase_options, + "couchbase.vector_search": couchbase_vector, + "couchbase.search": couchbase_search, + } + + +@pytest.fixture +def couchbase_module(monkeypatch): + for name, module in _build_fake_couchbase_modules().items(): + monkeypatch.setitem(sys.modules, name, module) + + import core.rag.datasource.vdb.couchbase.couchbase_vector as module + + return importlib.reload(module) + + +def _config(module): + return module.CouchbaseConfig( + connection_string="couchbase://localhost", + user="user", + password="pass", + bucket_name="bucket", + scope_name="scope", + ) + + +@pytest.mark.parametrize( + ("field", "value", "message"), + [ + ("connection_string", "", "CONNECTION_STRING is required"), + ("user", "", "COUCHBASE_USER is required"), + ("password", "", "COUCHBASE_PASSWORD is required"), + ("bucket_name", "", "COUCHBASE_PASSWORD is required"), + ("scope_name", "", "COUCHBASE_SCOPE_NAME is required"), + ], +) +def test_couchbase_config_validation(couchbase_module, field, value, message): + values = _config(couchbase_module).model_dump() + values[field] = value + with pytest.raises(ValidationError, match=message): + couchbase_module.CouchbaseConfig.model_validate(values) + + +def test_init_sets_cluster_handles(couchbase_module): + vector = couchbase_module.CouchbaseVector("collection_1", _config(couchbase_module)) + + assert vector._bucket_name == "bucket" + assert vector._scope_name == "scope" + vector._cluster.wait_until_ready.assert_called_once() + + +def test_create_and_create_collection_branches(couchbase_module, monkeypatch): + vector = couchbase_module.CouchbaseVector.__new__(couchbase_module.CouchbaseVector) + vector._collection_name = "collection_1" + vector._client_config = _config(couchbase_module) + vector._scope_name = "scope" + vector._bucket_name = "bucket" + vector._bucket = MagicMock() + vector._scope = MagicMock() + vector._collection_exists = MagicMock(return_value=False) + vector.add_texts = MagicMock() + + monkeypatch.setattr(couchbase_module.uuid, "uuid4", lambda: "a-b-c") + vector._create_collection = MagicMock() + docs = [Document(page_content="text", metadata={"doc_id": "id-1"})] + vector.create(docs, [[0.1, 0.2]]) + + vector._create_collection.assert_called_once_with(uuid="abc", vector_length=2) + vector.add_texts.assert_called_once_with(docs, [[0.1, 0.2]]) + + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(couchbase_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(couchbase_module.redis_client, "set", MagicMock()) + + vector = couchbase_module.CouchbaseVector("collection_1", _config(couchbase_module)) + monkeypatch.setattr(couchbase_module.redis_client, "get", MagicMock(return_value=1)) + vector._create_collection(vector_length=2, uuid="uuid-1") + vector._bucket.collections().create_collection.assert_not_called() + + monkeypatch.setattr(couchbase_module.redis_client, "get", MagicMock(return_value=None)) + vector._collection_exists = MagicMock(return_value=True) + vector._create_collection(vector_length=2, uuid="uuid-2") + vector._bucket.collections().create_collection.assert_not_called() + + vector._collection_exists = MagicMock(return_value=False) + vector._create_collection(vector_length=3, uuid="uuid-3") + + vector._bucket.collections().create_collection.assert_called_once_with("scope", "collection_1") + vector._scope.search_indexes().upsert_index.assert_called_once() + search_index = vector._scope.search_indexes().upsert_index.call_args.args[0] + assert search_index.name == "collection_1_search" + assert ( + search_index.params["mapping"]["types"]["scope.collection_1"]["properties"]["embedding"]["fields"][0]["dims"] + == 3 + ) + couchbase_module.redis_client.set.assert_called_once() + + +def test_collection_exists_get_type_and_add_texts(couchbase_module): + vector = couchbase_module.CouchbaseVector("collection_1", _config(couchbase_module)) + + scope_obj = SimpleNamespace(name="scope", collections=[SimpleNamespace(name="collection_1")]) + vector._bucket.collections().get_all_scopes.return_value = [scope_obj] + assert vector._collection_exists("collection_1") is True + + scope_obj = SimpleNamespace(name="scope", collections=[SimpleNamespace(name="other")]) + vector._bucket.collections().get_all_scopes.return_value = [scope_obj] + assert vector._collection_exists("collection_1") is False + + vector._get_uuids = MagicMock(return_value=["id-1", "id-2"]) + docs = [ + Document(page_content="a", metadata={"doc_id": "id-1"}), + Document(page_content="b", metadata={"doc_id": "id-2"}), + ] + ids = vector.add_texts(docs, [[0.1], [0.2]]) + + assert ids == ["id-1", "id-2"] + assert vector._scope.collection("collection_1").upsert.call_count == 2 + assert vector.get_type() == couchbase_module.VectorType.COUCHBASE + + +def test_query_delete_helpers(couchbase_module): + vector = couchbase_module.CouchbaseVector("collection_1", _config(couchbase_module)) + + vector._cluster.query.return_value = SimpleNamespace(execute=lambda: iter([{"count": 2}])) + assert vector.text_exists("id-1") is True + + vector._cluster.query.return_value = SimpleNamespace(execute=lambda: iter([])) + assert vector.text_exists("id-2") is False + + query_result = MagicMock() + query_result.execute.return_value = None + vector._cluster.query.return_value = query_result + + vector.delete_by_ids(["id-1", "id-2"]) + vector.delete_by_document_id("id-1") + vector.delete_by_metadata_field("document_id", "doc-1") + assert vector._cluster.query.call_count >= 3 + + vector._cluster.query.side_effect = RuntimeError("delete failed") + vector.delete_by_ids(["id-3"]) + + +def test_search_methods_and_format_metadata(couchbase_module): + vector = couchbase_module.CouchbaseVector("collection_1", _config(couchbase_module)) + + row_1 = SimpleNamespace(fields={"text": "doc-a", "metadata.document_id": "d-1"}, score=0.9) + row_2 = SimpleNamespace(fields={"text": "doc-b", "metadata.document_id": "d-2"}, score=0.3) + vector._scope.search.return_value = SimpleNamespace(rows=lambda: [row_1, row_2]) + + docs = vector.search_by_vector([0.1, 0.2], top_k=2, score_threshold=0.5) + assert len(docs) == 1 + assert docs[0].page_content == "doc-a" + assert docs[0].metadata["document_id"] == "d-1" + assert docs[0].metadata["score"] == pytest.approx(0.9) + + vector._scope.search.side_effect = RuntimeError("search error") + with pytest.raises(ValueError, match="Search failed"): + vector.search_by_vector([0.1], top_k=1) + + vector._scope.search.side_effect = None + row_3 = SimpleNamespace(fields={"text": "full-text", "metadata.doc_id": "x"}, score=0.7) + vector._scope.search.return_value = SimpleNamespace(rows=lambda: [row_3]) + docs = vector.search_by_full_text("hello", top_k=1) + assert len(docs) == 1 + assert docs[0].metadata["doc_id"] == "x" + + vector._scope.search.side_effect = RuntimeError("full text failed") + with pytest.raises(ValueError, match="Search failed"): + vector.search_by_full_text("hello", top_k=1) + + assert vector._format_metadata({"metadata.a": 1, "plain": 2}) == {"a": 1, "plain": 2} + + +def test_delete_collection_and_factory(couchbase_module, monkeypatch): + vector = couchbase_module.CouchbaseVector("collection_1", _config(couchbase_module)) + scopes = [ + SimpleNamespace(collections=[SimpleNamespace(name="other")]), + SimpleNamespace(collections=[SimpleNamespace(name="collection_1")]), + ] + vector._bucket.collections().get_all_scopes.return_value = scopes + + vector.delete() + vector._bucket.collections().drop_collection.assert_called_once_with("_default", "collection_1") + + factory = couchbase_module.CouchbaseVectorFactory() + dataset_with_index = SimpleNamespace( + id="dataset-1", + index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}}, + index_struct=None, + ) + dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None) + + monkeypatch.setattr(couchbase_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION") + monkeypatch.setattr( + couchbase_module, + "current_app", + SimpleNamespace( + config={ + "COUCHBASE_CONNECTION_STRING": "couchbase://localhost", + "COUCHBASE_USER": "user", + "COUCHBASE_PASSWORD": "pass", + "COUCHBASE_BUCKET_NAME": "bucket", + "COUCHBASE_SCOPE_NAME": "scope", + } + ), + ) + + with patch.object(couchbase_module, "CouchbaseVector", return_value="vector") as vector_cls: + result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock()) + result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock()) + + assert result_1 == "vector" + assert result_2 == "vector" + assert vector_cls.call_args_list[0].kwargs["collection_name"] == "EXISTING_COLLECTION" + assert vector_cls.call_args_list[1].kwargs["collection_name"] == "AUTO_COLLECTION" + assert dataset_without_index.index_struct is not None diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/elasticsearch/test_elasticsearch_ja_vector.py b/api/tests/unit_tests/core/rag/datasource/vdb/elasticsearch/test_elasticsearch_ja_vector.py new file mode 100644 index 0000000000..edd29a4649 --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/elasticsearch/test_elasticsearch_ja_vector.py @@ -0,0 +1,121 @@ +import importlib +import sys +import types +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest + + +def _build_fake_elasticsearch_modules(): + elasticsearch = types.ModuleType("elasticsearch") + + class ConnectionError(Exception): + pass + + class Elasticsearch: + def __init__(self, **kwargs): + self.kwargs = kwargs + self.ping = MagicMock(return_value=True) + self.info = MagicMock(return_value={"version": {"number": "8.12.0"}}) + self.indices = SimpleNamespace( + refresh=MagicMock(), delete=MagicMock(), exists=MagicMock(return_value=False), create=MagicMock() + ) + + elasticsearch.Elasticsearch = Elasticsearch + elasticsearch.ConnectionError = ConnectionError + return {"elasticsearch": elasticsearch} + + +@pytest.fixture +def elasticsearch_ja_module(monkeypatch): + for name, module in _build_fake_elasticsearch_modules().items(): + monkeypatch.setitem(sys.modules, name, module) + + import core.rag.datasource.vdb.elasticsearch.elasticsearch_ja_vector as ja_module + import core.rag.datasource.vdb.elasticsearch.elasticsearch_vector as base_module + + importlib.reload(base_module) + return importlib.reload(ja_module) + + +def test_create_collection_cache_hit(elasticsearch_ja_module, monkeypatch): + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(elasticsearch_ja_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(elasticsearch_ja_module.redis_client, "get", MagicMock(return_value=1)) + monkeypatch.setattr(elasticsearch_ja_module.redis_client, "set", MagicMock()) + + vector = elasticsearch_ja_module.ElasticSearchJaVector.__new__(elasticsearch_ja_module.ElasticSearchJaVector) + vector._collection_name = "test" + vector._client = MagicMock() + + vector.create_collection([[0.1, 0.2]], [{}]) + + vector._client.indices.create.assert_not_called() + elasticsearch_ja_module.redis_client.set.assert_not_called() + + +def test_create_collection_create_and_exists_paths(elasticsearch_ja_module, monkeypatch): + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(elasticsearch_ja_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(elasticsearch_ja_module.redis_client, "get", MagicMock(return_value=None)) + monkeypatch.setattr(elasticsearch_ja_module.redis_client, "set", MagicMock()) + + vector = elasticsearch_ja_module.ElasticSearchJaVector.__new__(elasticsearch_ja_module.ElasticSearchJaVector) + vector._collection_name = "test" + vector._client = MagicMock() + + vector._client.indices.exists.return_value = False + vector.create_collection([[0.1, 0.2, 0.3]], [{}]) + + vector._client.indices.create.assert_called_once() + kwargs = vector._client.indices.create.call_args.kwargs + assert kwargs["index"] == "test" + assert kwargs["mappings"]["properties"][elasticsearch_ja_module.Field.VECTOR]["dims"] == 3 + elasticsearch_ja_module.redis_client.set.assert_called_once() + + vector._client.indices.create.reset_mock() + elasticsearch_ja_module.redis_client.set.reset_mock() + vector._client.indices.exists.return_value = True + vector.create_collection([[0.1, 0.2]], [{}]) + + vector._client.indices.create.assert_not_called() + elasticsearch_ja_module.redis_client.set.assert_called_once() + + +def test_ja_factory_uses_existing_or_generated_collection(elasticsearch_ja_module, monkeypatch): + factory = elasticsearch_ja_module.ElasticSearchJaVectorFactory() + dataset_with_index = SimpleNamespace( + id="dataset-1", + index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}}, + index_struct=None, + ) + dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None) + + monkeypatch.setattr(elasticsearch_ja_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION") + monkeypatch.setattr( + elasticsearch_ja_module, + "current_app", + SimpleNamespace( + config={ + "ELASTICSEARCH_HOST": "localhost", + "ELASTICSEARCH_PORT": 9200, + "ELASTICSEARCH_USERNAME": "elastic", + "ELASTICSEARCH_PASSWORD": "secret", + } + ), + ) + + with patch.object(elasticsearch_ja_module, "ElasticSearchJaVector", return_value="vector") as vector_cls: + result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock()) + result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock()) + + assert result_1 == "vector" + assert result_2 == "vector" + assert vector_cls.call_args_list[0].kwargs["index_name"] == "EXISTING_COLLECTION" + assert vector_cls.call_args_list[1].kwargs["index_name"] == "AUTO_COLLECTION" + assert dataset_without_index.index_struct is not None diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/elasticsearch/test_elasticsearch_vector.py b/api/tests/unit_tests/core/rag/datasource/vdb/elasticsearch/test_elasticsearch_vector.py new file mode 100644 index 0000000000..9ecf0caa24 --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/elasticsearch/test_elasticsearch_vector.py @@ -0,0 +1,405 @@ +import importlib +import sys +import types +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from pydantic import ValidationError + +from core.rag.models.document import Document + + +def _build_fake_elasticsearch_modules(): + elasticsearch = types.ModuleType("elasticsearch") + + class ConnectionError(Exception): + pass + + class Elasticsearch: + def __init__(self, **kwargs): + self.kwargs = kwargs + self.ping = MagicMock(return_value=True) + self.info = MagicMock(return_value={"version": {"number": "8.12.0-SNAPSHOT"}}) + self.index = MagicMock() + self.exists = MagicMock(return_value=False) + self.delete = MagicMock() + self.search = MagicMock(return_value={"hits": {"hits": []}}) + self.indices = SimpleNamespace( + refresh=MagicMock(), + delete=MagicMock(), + exists=MagicMock(return_value=False), + create=MagicMock(), + ) + + elasticsearch.Elasticsearch = Elasticsearch + elasticsearch.ConnectionError = ConnectionError + return {"elasticsearch": elasticsearch} + + +@pytest.fixture +def elasticsearch_module(monkeypatch): + for name, module in _build_fake_elasticsearch_modules().items(): + monkeypatch.setitem(sys.modules, name, module) + + import core.rag.datasource.vdb.elasticsearch.elasticsearch_vector as module + + return importlib.reload(module) + + +def _regular_config(module, **overrides): + values = { + "host": "localhost", + "port": 9200, + "username": "elastic", + "password": "secret", + "verify_certs": False, + "request_timeout": 10, + "retry_on_timeout": True, + "max_retries": 3, + } + values.update(overrides) + return module.ElasticSearchConfig.model_validate(values) + + +def _cloud_config(module, **overrides): + values = { + "use_cloud": True, + "cloud_url": "https://cloud.example:9243", + "api_key": "api-key", + "verify_certs": True, + "ca_certs": "/tmp/ca.pem", + "request_timeout": 10, + "retry_on_timeout": True, + "max_retries": 3, + } + values.update(overrides) + return module.ElasticSearchConfig.model_validate(values) + + +@pytest.mark.parametrize( + ("values", "message"), + [ + ({"use_cloud": True, "cloud_url": None, "api_key": "x"}, "cloud_url is required"), + ({"use_cloud": True, "cloud_url": "https://cloud", "api_key": None}, "api_key is required"), + ({"host": None, "port": 9200, "username": "u", "password": "p"}, "HOST is required"), + ({"host": "h", "port": None, "username": "u", "password": "p"}, "PORT is required"), + ({"host": "h", "port": 9200, "username": None, "password": "p"}, "USERNAME is required"), + ({"host": "h", "port": 9200, "username": "u", "password": None}, "PASSWORD is required"), + ], +) +def test_elasticsearch_config_validation(elasticsearch_module, values, message): + with pytest.raises(ValidationError, match=message): + elasticsearch_module.ElasticSearchConfig.model_validate(values) + + +def test_init_client_cloud_configuration(elasticsearch_module): + vector = elasticsearch_module.ElasticSearchVector.__new__(elasticsearch_module.ElasticSearchVector) + client = MagicMock() + client.ping.return_value = True + + with patch.object(elasticsearch_module, "Elasticsearch", return_value=client) as es_cls: + result = vector._init_client(_cloud_config(elasticsearch_module)) + + assert result is client + kwargs = es_cls.call_args.kwargs + assert kwargs["hosts"] == ["https://cloud.example:9243"] + assert kwargs["api_key"] == "api-key" + assert kwargs["verify_certs"] is True + assert kwargs["ca_certs"] == "/tmp/ca.pem" + + +def test_init_client_regular_https_and_http_fallback(elasticsearch_module): + vector = elasticsearch_module.ElasticSearchVector.__new__(elasticsearch_module.ElasticSearchVector) + client = MagicMock() + client.ping.return_value = True + + with patch.object(elasticsearch_module, "Elasticsearch", return_value=client) as es_cls: + vector._init_client( + _regular_config( + elasticsearch_module, + host="https://es.example", + port=9443, + verify_certs=True, + ca_certs="/tmp/ca.pem", + ) + ) + kwargs = es_cls.call_args.kwargs + assert kwargs["hosts"] == ["https://es.example:9443"] + assert kwargs["verify_certs"] is True + assert kwargs["ca_certs"] == "/tmp/ca.pem" + + with patch.object(elasticsearch_module, "Elasticsearch", return_value=client) as es_cls: + vector._init_client(_regular_config(elasticsearch_module, host="es.internal", port=9200)) + kwargs = es_cls.call_args.kwargs + assert kwargs["hosts"] == ["http://es.internal:9200"] + assert "verify_certs" not in kwargs + + +def test_init_client_connection_failures(elasticsearch_module): + vector = elasticsearch_module.ElasticSearchVector.__new__(elasticsearch_module.ElasticSearchVector) + + client = MagicMock() + client.ping.return_value = False + with patch.object(elasticsearch_module, "Elasticsearch", return_value=client): + with pytest.raises(ConnectionError, match="Failed to connect"): + vector._init_client(_regular_config(elasticsearch_module)) + + with patch.object( + elasticsearch_module, + "Elasticsearch", + side_effect=elasticsearch_module.ElasticsearchConnectionError("boom"), + ): + with pytest.raises(ConnectionError, match="Vector database connection error"): + vector._init_client(_regular_config(elasticsearch_module)) + + with patch.object(elasticsearch_module, "Elasticsearch", side_effect=RuntimeError("oops")): + with pytest.raises(ConnectionError, match="initialization failed"): + vector._init_client(_regular_config(elasticsearch_module)) + + +def test_init_get_version_and_check_version(elasticsearch_module): + with ( + patch.object(elasticsearch_module.ElasticSearchVector, "_init_client", return_value=MagicMock()) as init_client, + patch.object(elasticsearch_module.ElasticSearchVector, "_get_version", return_value="8.10.0") as get_version, + patch.object(elasticsearch_module.ElasticSearchVector, "_check_version") as check_version, + ): + vector = elasticsearch_module.ElasticSearchVector( + "collection_1", _regular_config(elasticsearch_module), attributes=["doc_id"] + ) + + init_client.assert_called_once() + get_version.assert_called_once() + check_version.assert_called_once() + assert vector._attributes == ["doc_id"] + + vector = elasticsearch_module.ElasticSearchVector.__new__(elasticsearch_module.ElasticSearchVector) + vector._client = MagicMock() + vector._client.info.return_value = {"version": {"number": "8.13.2-SNAPSHOT"}} + assert vector._get_version() == "8.13.2" + + vector._version = "7.17.0" + with pytest.raises(ValueError, match="greater than 8.0.0"): + vector._check_version() + + vector._version = "8.0.0" + vector._check_version() + + +def test_crud_methods_and_get_type(elasticsearch_module): + vector = elasticsearch_module.ElasticSearchVector.__new__(elasticsearch_module.ElasticSearchVector) + vector._collection_name = "collection_1" + vector._client = MagicMock() + vector._client.indices = SimpleNamespace(refresh=MagicMock(), delete=MagicMock()) + vector._get_uuids = MagicMock(return_value=["id-1", "id-2"]) + + docs = [ + Document(page_content="a", metadata={"doc_id": "id-1"}), + Document(page_content="b", metadata={"doc_id": "id-2"}), + ] + + ids = vector.add_texts(docs, [[0.1], [0.2]]) + assert ids == ["id-1", "id-2"] + assert vector._client.index.call_count == 2 + vector._client.indices.refresh.assert_called_once_with(index="collection_1") + + vector._client.exists.return_value = True + assert vector.text_exists("id-1") is True + + vector.delete_by_ids([]) + vector._client.delete.assert_not_called() + vector.delete_by_ids(["id-1", "id-2"]) + assert vector._client.delete.call_count == 2 + + vector._client.search.return_value = {"hits": {"hits": [{"_id": "id-1"}]}} + vector.delete_by_ids = MagicMock() + vector.delete_by_metadata_field("doc_id", "d1") + vector.delete_by_ids.assert_called_once_with(["id-1"]) + + vector.delete_by_ids.reset_mock() + vector._client.search.return_value = {"hits": {"hits": []}} + vector.delete_by_metadata_field("doc_id", "d2") + vector.delete_by_ids.assert_not_called() + + vector.delete() + vector._client.indices.delete.assert_called_once_with(index="collection_1") + assert vector.get_type() == elasticsearch_module.VectorType.ELASTICSEARCH + + +def test_search_by_vector_and_full_text(elasticsearch_module): + vector = elasticsearch_module.ElasticSearchVector.__new__(elasticsearch_module.ElasticSearchVector) + vector._collection_name = "collection_1" + vector._client = MagicMock() + + vector._client.search.return_value = { + "hits": { + "hits": [ + { + "_score": 0.8, + "_source": { + elasticsearch_module.Field.CONTENT_KEY: "doc-a", + elasticsearch_module.Field.VECTOR: [0.1], + elasticsearch_module.Field.METADATA_KEY: {"doc_id": "1", "document_id": "d-1"}, + }, + }, + { + "_score": 0.2, + "_source": { + elasticsearch_module.Field.CONTENT_KEY: "doc-b", + elasticsearch_module.Field.VECTOR: [0.2], + elasticsearch_module.Field.METADATA_KEY: {"doc_id": "2", "document_id": "d-2"}, + }, + }, + ] + } + } + + docs = vector.search_by_vector( + [0.1, 0.2], + top_k=2, + score_threshold=0.5, + document_ids_filter=["d-1", "d-2"], + ) + assert len(docs) == 1 + assert docs[0].metadata["score"] == pytest.approx(0.8) + knn = vector._client.search.call_args.kwargs["knn"] + assert knn["k"] == 2 + assert knn["num_candidates"] == 3 + assert "filter" in knn + + vector._client.search.return_value = { + "hits": { + "hits": [ + { + "_source": { + elasticsearch_module.Field.CONTENT_KEY: "text-hit", + elasticsearch_module.Field.VECTOR: [0.3], + elasticsearch_module.Field.METADATA_KEY: {"doc_id": "3"}, + } + } + ] + } + } + docs = vector.search_by_full_text("hello", top_k=3, document_ids_filter=["d-3"]) + assert len(docs) == 1 + assert docs[0].page_content == "text-hit" + query = vector._client.search.call_args.kwargs["query"] + assert "bool" in query + + +def test_create_and_create_collection_paths(elasticsearch_module, monkeypatch): + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(elasticsearch_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(elasticsearch_module.redis_client, "set", MagicMock()) + + vector = elasticsearch_module.ElasticSearchVector.__new__(elasticsearch_module.ElasticSearchVector) + vector._collection_name = "collection_1" + vector._client = MagicMock() + vector._client.indices = SimpleNamespace(exists=MagicMock(return_value=False), create=MagicMock()) + + vector.create_collection = MagicMock() + vector.add_texts = MagicMock() + docs = [Document(page_content="a", metadata={"doc_id": "1"})] + vector.create(docs, [[0.1]]) + vector.create_collection.assert_called_once() + vector.add_texts.assert_called_once_with(docs, [[0.1]]) + + vector = elasticsearch_module.ElasticSearchVector.__new__(elasticsearch_module.ElasticSearchVector) + vector._collection_name = "collection_1" + vector._client = MagicMock() + vector._client.indices = SimpleNamespace(exists=MagicMock(return_value=False), create=MagicMock()) + + monkeypatch.setattr(elasticsearch_module.redis_client, "get", MagicMock(return_value=1)) + vector.create_collection([[0.1, 0.2]], [{}]) + vector._client.indices.create.assert_not_called() + + monkeypatch.setattr(elasticsearch_module.redis_client, "get", MagicMock(return_value=None)) + vector._client.indices.exists.return_value = False + vector.create_collection([[0.1, 0.2]], [{}]) + vector._client.indices.create.assert_called_once() + mappings = vector._client.indices.create.call_args.kwargs["mappings"] + assert mappings["properties"][elasticsearch_module.Field.VECTOR]["dims"] == 2 + elasticsearch_module.redis_client.set.assert_called_once() + + vector._client.indices.create.reset_mock() + elasticsearch_module.redis_client.set.reset_mock() + vector._client.indices.exists.return_value = True + vector.create_collection([[0.1, 0.2]], [{}]) + vector._client.indices.create.assert_not_called() + elasticsearch_module.redis_client.set.assert_called_once() + + +def test_elasticsearch_factory_branches(elasticsearch_module, monkeypatch): + factory = elasticsearch_module.ElasticSearchVectorFactory() + dataset_with_index = SimpleNamespace( + id="dataset-1", + index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}}, + index_struct=None, + ) + dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None) + + monkeypatch.setattr(elasticsearch_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION") + + monkeypatch.setattr( + elasticsearch_module, + "current_app", + SimpleNamespace( + config={ + "ELASTICSEARCH_USE_CLOUD": False, + "ELASTICSEARCH_HOST": "es-host", + "ELASTICSEARCH_PORT": 9200, + "ELASTICSEARCH_USERNAME": "elastic", + "ELASTICSEARCH_PASSWORD": "secret", + "ELASTICSEARCH_VERIFY_CERTS": False, + } + ), + ) + + with patch.object(elasticsearch_module, "ElasticSearchVector", return_value="vector") as vector_cls: + result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock()) + assert result_1 == "vector" + cfg = vector_cls.call_args.kwargs["config"] + assert cfg.use_cloud is False + assert vector_cls.call_args.kwargs["index_name"] == "EXISTING_COLLECTION" + + monkeypatch.setattr( + elasticsearch_module, + "current_app", + SimpleNamespace( + config={ + "ELASTICSEARCH_USE_CLOUD": True, + "ELASTICSEARCH_CLOUD_URL": "https://cloud.elastic", + "ELASTICSEARCH_API_KEY": "api-key", + "ELASTICSEARCH_VERIFY_CERTS": True, + } + ), + ) + with patch.object(elasticsearch_module, "ElasticSearchVector", return_value="vector") as vector_cls: + result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock()) + assert result_2 == "vector" + cfg = vector_cls.call_args.kwargs["config"] + assert cfg.use_cloud is True + assert cfg.cloud_url == "https://cloud.elastic" + assert dataset_without_index.index_struct is not None + + monkeypatch.setattr( + elasticsearch_module, + "current_app", + SimpleNamespace( + config={ + "ELASTICSEARCH_USE_CLOUD": True, + "ELASTICSEARCH_CLOUD_URL": None, + "ELASTICSEARCH_HOST": "fallback-host", + "ELASTICSEARCH_PORT": 9201, + "ELASTICSEARCH_USERNAME": "elastic", + "ELASTICSEARCH_PASSWORD": "secret", + } + ), + ) + with patch.object(elasticsearch_module, "ElasticSearchVector", return_value="vector") as vector_cls: + factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock()) + cfg = vector_cls.call_args.kwargs["config"] + assert cfg.use_cloud is False + assert cfg.host == "fallback-host" diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/hologres/test_hologres_vector.py b/api/tests/unit_tests/core/rag/datasource/vdb/hologres/test_hologres_vector.py new file mode 100644 index 0000000000..5d9e744ded --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/hologres/test_hologres_vector.py @@ -0,0 +1,371 @@ +import importlib +import json +import sys +import types +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from pydantic import ValidationError + +from core.rag.models.document import Document + + +def _build_fake_hologres_modules(): + holo_module = types.ModuleType("holo_search_sdk") + holo_types_module = types.ModuleType("holo_search_sdk.types") + + holo_types_module.BaseQuantizationType = str + holo_types_module.DistanceType = str + holo_types_module.TokenizerType = str + + def _connect(**kwargs): + client = MagicMock() + client.kwargs = kwargs + client.connect = MagicMock() + client.check_table_exist = MagicMock(return_value=False) + client.open_table = MagicMock(return_value=MagicMock()) + client.execute = MagicMock(return_value=[]) + client.drop_table = MagicMock() + return client + + holo_module.connect = MagicMock(side_effect=_connect) + + return { + "holo_search_sdk": holo_module, + "holo_search_sdk.types": holo_types_module, + } + + +@pytest.fixture +def hologres_module(monkeypatch): + for name, module in _build_fake_hologres_modules().items(): + monkeypatch.setitem(sys.modules, name, module) + + import core.rag.datasource.vdb.hologres.hologres_vector as module + + return importlib.reload(module) + + +def _valid_config(module): + return module.HologresVectorConfig( + host="localhost", + port=80, + database="dify", + access_key_id="ak", + access_key_secret="sk", + schema_name="public", + tokenizer="jieba", + distance_method="Cosine", + base_quantization_type="rabitq", + max_degree=64, + ef_construction=400, + ) + + +@pytest.mark.parametrize( + ("field", "value", "message"), + [ + ("host", "", "config HOLOGRES_HOST is required"), + ("database", "", "config HOLOGRES_DATABASE is required"), + ("access_key_id", "", "config HOLOGRES_ACCESS_KEY_ID is required"), + ("access_key_secret", "", "config HOLOGRES_ACCESS_KEY_SECRET is required"), + ], +) +def test_hologres_config_validation(hologres_module, field, value, message): + values = _valid_config(hologres_module).model_dump() + values[field] = value + + with pytest.raises(ValidationError, match=message): + hologres_module.HologresVectorConfig.model_validate(values) + + +def test_init_client_and_get_type(hologres_module): + vector = hologres_module.HologresVector("Collection_One", _valid_config(hologres_module)) + + hologres_module.holo.connect.assert_called_once_with( + host="localhost", + port=80, + database="dify", + access_key_id="ak", + access_key_secret="sk", + schema="public", + ) + vector._client.connect.assert_called_once() + assert vector.table_name == "embedding_collection_one" + assert vector.get_type() == hologres_module.VectorType.HOLOGRES + + +def test_create_delegates_collection_creation_and_upsert(hologres_module): + vector = hologres_module.HologresVector("collection_one", _valid_config(hologres_module)) + vector._create_collection = MagicMock() + vector.add_texts = MagicMock() + docs = [Document(page_content="hello", metadata={"doc_id": "seg-1"})] + + result = vector.create(docs, [[0.1, 0.2]]) + + assert result is None + vector._create_collection.assert_called_once_with(2) + vector.add_texts.assert_called_once_with(docs, [[0.1, 0.2]]) + + +def test_add_texts_returns_empty_for_empty_documents(hologres_module): + vector = hologres_module.HologresVector("collection_one", _valid_config(hologres_module)) + + assert vector.add_texts([], []) == [] + vector._client.open_table.assert_not_called() + + +def test_add_texts_batches_and_serializes_metadata(hologres_module): + vector = hologres_module.HologresVector("collection_one", _valid_config(hologres_module)) + table = vector._client.open_table.return_value + documents = [ + Document(page_content=f"doc-{i}", metadata={"doc_id": f"id-{i}", "document_id": f"document-{i}"}) + for i in range(100) + ] + documents.append(SimpleNamespace(page_content="doc-100", metadata=None)) + embeddings = [[float(i)] for i in range(len(documents))] + + ids = vector.add_texts(documents, embeddings) + + assert ids[:2] == ["id-0", "id-1"] + assert ids[-1] == "" + assert len(ids) == 101 + assert vector._client.open_table.call_count == 2 + assert table.upsert_multi.call_count == 2 + first_call = table.upsert_multi.call_args_list[0].kwargs + second_call = table.upsert_multi.call_args_list[1].kwargs + assert first_call["index_column"] == "id" + assert first_call["column_names"] == ["id", "text", "meta", "embedding"] + assert first_call["update_columns"] == ["text", "meta", "embedding"] + assert len(first_call["values"]) == 100 + assert json.loads(first_call["values"][0][2]) == {"doc_id": "id-0", "document_id": "document-0"} + assert second_call["values"][0][0] == "" + assert second_call["values"][0][2] == "{}" + + +def test_text_exists_handles_missing_and_present_tables(hologres_module): + vector = hologres_module.HologresVector("collection_one", _valid_config(hologres_module)) + vector._client.check_table_exist.side_effect = [False, True] + vector._client.execute.return_value = [(1,)] + + assert vector.text_exists("seg-1") is False + assert vector.text_exists("seg-1") is True + vector._client.execute.assert_called_once() + + +def test_get_ids_by_metadata_field_returns_ids_or_none(hologres_module): + vector = hologres_module.HologresVector("collection_one", _valid_config(hologres_module)) + vector._client.execute.side_effect = [[("id-1",), ("id-2",)], []] + + assert vector.get_ids_by_metadata_field("document_id", "doc-1") == ["id-1", "id-2"] + assert vector.get_ids_by_metadata_field("document_id", "doc-1") is None + + +def test_delete_by_ids_branches(hologres_module): + vector = hologres_module.HologresVector("collection_one", _valid_config(hologres_module)) + + vector.delete_by_ids([]) + vector._client.check_table_exist.assert_not_called() + + vector._client.check_table_exist.return_value = False + vector.delete_by_ids(["id-1"]) + vector._client.execute.assert_not_called() + + vector._client.check_table_exist.return_value = True + vector.delete_by_ids(["id-1", "id-2"]) + vector._client.execute.assert_called_once() + + +def test_delete_by_metadata_field_branches(hologres_module): + vector = hologres_module.HologresVector("collection_one", _valid_config(hologres_module)) + vector._client.check_table_exist.return_value = False + + vector.delete_by_metadata_field("document_id", "doc-1") + vector._client.execute.assert_not_called() + + vector._client.check_table_exist.return_value = True + vector.delete_by_metadata_field("document_id", "doc-1") + vector._client.execute.assert_called_once() + + +def test_search_by_vector_returns_empty_when_table_missing(hologres_module): + vector = hologres_module.HologresVector("collection_one", _valid_config(hologres_module)) + vector._client.check_table_exist.return_value = False + + assert vector.search_by_vector([0.1, 0.2]) == [] + + +def test_search_by_vector_applies_filter_and_processes_results(hologres_module): + vector = hologres_module.HologresVector("collection_one", _valid_config(hologres_module)) + vector._client.check_table_exist.return_value = True + table = vector._client.open_table.return_value + query = MagicMock() + table.search_vector.return_value = query + query.select.return_value = query + query.limit.return_value = query + query.where.return_value = query + query.fetchall.return_value = [ + (0.2, "seg-1", "doc-1", '{"doc_id":"seg-1","document_id":"doc-1"}'), + (0.9, "seg-2", "doc-2", {"doc_id": "seg-2", "document_id": "doc-2"}), + ] + + docs = vector.search_by_vector( + [0.1, 0.2], + top_k=2, + score_threshold=0.5, + document_ids_filter=["doc-1"], + ) + + assert len(docs) == 1 + assert docs[0].page_content == "doc-1" + assert docs[0].metadata["doc_id"] == "seg-1" + assert docs[0].metadata["score"] == pytest.approx(0.8) + table.search_vector.assert_called_once() + query.where.assert_called_once() + + +def test_search_by_full_text_returns_empty_when_table_missing(hologres_module): + vector = hologres_module.HologresVector("collection_one", _valid_config(hologres_module)) + vector._client.check_table_exist.return_value = False + + assert vector.search_by_full_text("query") == [] + + +def test_search_by_full_text_applies_filter_and_processes_results(hologres_module): + vector = hologres_module.HologresVector("collection_one", _valid_config(hologres_module)) + vector._client.check_table_exist.return_value = True + table = vector._client.open_table.return_value + search_query = MagicMock() + table.search_text.return_value = search_query + search_query.limit.return_value = search_query + search_query.where.return_value = search_query + search_query.fetchall.return_value = [ + ("seg-1", "doc-1", '{"doc_id":"seg-1"}', [0.1], 0.95), + ("seg-2", "doc-2", {"doc_id": "seg-2"}, [0.2], 0.7), + ] + + docs = vector.search_by_full_text("query", top_k=2, document_ids_filter=["doc-1"]) + + assert len(docs) == 2 + assert docs[0].metadata["doc_id"] == "seg-1" + assert docs[0].metadata["score"] == pytest.approx(0.95) + assert docs[1].metadata["score"] == pytest.approx(0.7) + table.search_text.assert_called_once() + search_query.where.assert_called_once() + + +def test_delete_handles_existing_and_missing_tables(hologres_module): + vector = hologres_module.HologresVector("collection_one", _valid_config(hologres_module)) + vector._client.check_table_exist.side_effect = [False, True] + + vector.delete() + vector._client.drop_table.assert_not_called() + + vector.delete() + vector._client.drop_table.assert_called_once_with(vector.table_name) + + +def test_create_collection_returns_early_when_cache_hits(hologres_module, monkeypatch): + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = False + monkeypatch.setattr(hologres_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(hologres_module.redis_client, "get", MagicMock(return_value=1)) + monkeypatch.setattr(hologres_module.redis_client, "set", MagicMock()) + + vector = hologres_module.HologresVector("collection_one", _valid_config(hologres_module)) + vector._create_collection(3) + + vector._client.check_table_exist.assert_not_called() + hologres_module.redis_client.set.assert_not_called() + + +def test_create_collection_creates_table_and_indexes(hologres_module, monkeypatch): + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = False + monkeypatch.setattr(hologres_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(hologres_module.redis_client, "get", MagicMock(return_value=None)) + monkeypatch.setattr(hologres_module.redis_client, "set", MagicMock()) + monkeypatch.setattr(hologres_module.time, "sleep", MagicMock()) + + vector = hologres_module.HologresVector("collection_one", _valid_config(hologres_module)) + vector._client.check_table_exist.side_effect = [False, False, True] + table = vector._client.open_table.return_value + + vector._create_collection(3) + + vector._client.execute.assert_called_once() + table.set_vector_index.assert_called_once_with( + column="embedding", + distance_method="Cosine", + base_quantization_type="rabitq", + max_degree=64, + ef_construction=400, + use_reorder=True, + ) + table.create_text_index.assert_called_once_with( + index_name="ft_idx_collection_one", + column="text", + tokenizer="jieba", + ) + hologres_module.redis_client.set.assert_called_once() + + +def test_create_collection_raises_when_table_never_becomes_ready(hologres_module, monkeypatch): + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = False + monkeypatch.setattr(hologres_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(hologres_module.redis_client, "get", MagicMock(return_value=None)) + monkeypatch.setattr(hologres_module.redis_client, "set", MagicMock()) + monkeypatch.setattr(hologres_module.time, "sleep", MagicMock()) + + vector = hologres_module.HologresVector("collection_one", _valid_config(hologres_module)) + vector._client.check_table_exist.side_effect = [False] + [False] * 15 + + with pytest.raises(RuntimeError, match="was not ready after 30s"): + vector._create_collection(3) + + hologres_module.redis_client.set.assert_not_called() + + +def test_hologres_factory_uses_existing_or_generated_collection(hologres_module, monkeypatch): + factory = hologres_module.HologresVectorFactory() + dataset_with_index = SimpleNamespace( + id="dataset-1", + index_struct_dict={"vector_store": {"class_prefix": "existing_collection"}}, + index_struct=None, + ) + dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None) + + monkeypatch.setattr(hologres_module.Dataset, "gen_collection_name_by_id", lambda _id: "generated_collection") + monkeypatch.setattr(hologres_module.dify_config, "HOLOGRES_HOST", "127.0.0.1") + monkeypatch.setattr(hologres_module.dify_config, "HOLOGRES_PORT", 80) + monkeypatch.setattr(hologres_module.dify_config, "HOLOGRES_DATABASE", "dify") + monkeypatch.setattr(hologres_module.dify_config, "HOLOGRES_ACCESS_KEY_ID", "ak") + monkeypatch.setattr(hologres_module.dify_config, "HOLOGRES_ACCESS_KEY_SECRET", "sk") + monkeypatch.setattr(hologres_module.dify_config, "HOLOGRES_SCHEMA", "public") + monkeypatch.setattr(hologres_module.dify_config, "HOLOGRES_TOKENIZER", "jieba") + monkeypatch.setattr(hologres_module.dify_config, "HOLOGRES_DISTANCE_METHOD", "Cosine") + monkeypatch.setattr(hologres_module.dify_config, "HOLOGRES_BASE_QUANTIZATION_TYPE", "rabitq") + monkeypatch.setattr(hologres_module.dify_config, "HOLOGRES_MAX_DEGREE", 64) + monkeypatch.setattr(hologres_module.dify_config, "HOLOGRES_EF_CONSTRUCTION", 400) + + with patch.object(hologres_module, "HologresVector", return_value="vector") as vector_cls: + result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock()) + result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock()) + + assert result_1 == "vector" + assert result_2 == "vector" + assert vector_cls.call_args_list[0].kwargs["collection_name"] == "existing_collection" + assert vector_cls.call_args_list[1].kwargs["collection_name"] == "generated_collection" + generated_config = vector_cls.call_args_list[1].kwargs["config"] + assert generated_config.host == "127.0.0.1" + assert generated_config.database == "dify" + assert generated_config.access_key_id == "ak" + assert json.loads(dataset_without_index.index_struct) == { + "type": hologres_module.VectorType.HOLOGRES, + "vector_store": {"class_prefix": "generated_collection"}, + } diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/huawei/test_huawei_cloud_vector.py b/api/tests/unit_tests/core/rag/datasource/vdb/huawei/test_huawei_cloud_vector.py new file mode 100644 index 0000000000..9d23dfcf63 --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/huawei/test_huawei_cloud_vector.py @@ -0,0 +1,243 @@ +import importlib +import sys +import types +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from pydantic import ValidationError + +from core.rag.models.document import Document + + +def _build_fake_elasticsearch_modules(): + elasticsearch = types.ModuleType("elasticsearch") + + class Elasticsearch: + def __init__(self, **kwargs): + self.kwargs = kwargs + self.index = MagicMock() + self.exists = MagicMock(return_value=False) + self.delete = MagicMock() + self.search = MagicMock(return_value={"hits": {"hits": []}}) + self.indices = SimpleNamespace( + refresh=MagicMock(), delete=MagicMock(), exists=MagicMock(return_value=False), create=MagicMock() + ) + + elasticsearch.Elasticsearch = Elasticsearch + return {"elasticsearch": elasticsearch} + + +@pytest.fixture +def huawei_module(monkeypatch): + for name, module in _build_fake_elasticsearch_modules().items(): + monkeypatch.setitem(sys.modules, name, module) + + import core.rag.datasource.vdb.huawei.huawei_cloud_vector as module + + return importlib.reload(module) + + +def _config(module): + return module.HuaweiCloudVectorConfig(hosts="http://localhost:9200", username="user", password="pass") + + +def test_create_ssl_context(huawei_module): + ctx = huawei_module.create_ssl_context() + assert ctx.check_hostname is False + assert ctx.verify_mode == huawei_module.ssl.CERT_NONE + + +def test_huawei_config_validation_and_params(huawei_module): + with pytest.raises(ValidationError, match="HOSTS is required"): + huawei_module.HuaweiCloudVectorConfig.model_validate({"hosts": ""}) + + config = _config(huawei_module) + params = config.to_elasticsearch_params() + assert params["hosts"] == ["http://localhost:9200"] + assert params["basic_auth"] == ("user", "pass") + + config = huawei_module.HuaweiCloudVectorConfig(hosts="host1,host2", username=None, password=None) + params = config.to_elasticsearch_params() + assert "basic_auth" not in params + + +def test_init_get_type_and_add_texts(huawei_module): + vector = huawei_module.HuaweiCloudVector("COLLECTION", _config(huawei_module)) + + assert vector._collection_name == "collection" + assert vector.get_type() == huawei_module.VectorType.HUAWEI_CLOUD + + vector._get_uuids = MagicMock(return_value=["id-1", "id-2"]) + docs = [ + Document(page_content="a", metadata={"doc_id": "id-1"}), + Document(page_content="b", metadata={"doc_id": "id-2"}), + ] + + ids = vector.add_texts(docs, [[0.1], [0.2]]) + assert ids == ["id-1", "id-2"] + assert vector._client.index.call_count == 2 + vector._client.indices.refresh.assert_called_once_with(index="collection") + + +def test_crud_methods(huawei_module): + vector = huawei_module.HuaweiCloudVector("collection", _config(huawei_module)) + + vector._client.exists.return_value = True + assert vector.text_exists("id-1") is True + + vector.delete_by_ids([]) + vector._client.delete.assert_not_called() + vector.delete_by_ids(["id-1"]) + vector._client.delete.assert_called_once_with(index="collection", id="id-1") + + vector._client.search.return_value = {"hits": {"hits": [{"_id": "id-1"}]}} + vector.delete_by_ids = MagicMock() + vector.delete_by_metadata_field("doc_id", "x") + vector.delete_by_ids.assert_called_once_with(["id-1"]) + + vector.delete_by_ids.reset_mock() + vector._client.search.return_value = {"hits": {"hits": []}} + vector.delete_by_metadata_field("doc_id", "x") + vector.delete_by_ids.assert_not_called() + + vector.delete() + vector._client.indices.delete.assert_called_once_with(index="collection") + + +def test_search_by_vector_and_full_text(huawei_module): + vector = huawei_module.HuaweiCloudVector("collection", _config(huawei_module)) + vector._client.search.return_value = { + "hits": { + "hits": [ + { + "_score": 0.9, + "_source": { + huawei_module.Field.CONTENT_KEY: "doc-a", + huawei_module.Field.VECTOR: [0.1], + huawei_module.Field.METADATA_KEY: {"doc_id": "1"}, + }, + }, + { + "_score": 0.1, + "_source": { + huawei_module.Field.CONTENT_KEY: "doc-b", + huawei_module.Field.VECTOR: [0.2], + huawei_module.Field.METADATA_KEY: {"doc_id": "2"}, + }, + }, + ] + } + } + + docs = vector.search_by_vector([0.1, 0.2], top_k=2, score_threshold=0.5) + assert len(docs) == 1 + assert docs[0].metadata["score"] == pytest.approx(0.9) + + query_body = vector._client.search.call_args.kwargs["body"] + assert query_body["query"]["vector"][huawei_module.Field.VECTOR]["topk"] == 2 + + vector._client.search.return_value = { + "hits": { + "hits": [ + { + "_source": { + huawei_module.Field.CONTENT_KEY: "text-hit", + huawei_module.Field.VECTOR: [0.3], + huawei_module.Field.METADATA_KEY: {"doc_id": "3"}, + } + } + ] + } + } + docs = vector.search_by_full_text("hello", top_k=3) + assert len(docs) == 1 + assert docs[0].page_content == "text-hit" + + +def test_search_by_vector_skips_hits_without_metadata(huawei_module, monkeypatch): + class FakeDocument: + def __init__(self, page_content, vector, metadata): + self.page_content = page_content + self.vector = vector + self.metadata = None + + monkeypatch.setattr(huawei_module, "Document", FakeDocument) + + vector = huawei_module.HuaweiCloudVector("collection", _config(huawei_module)) + vector._client.search.return_value = { + "hits": { + "hits": [ + { + "_score": 0.9, + "_source": { + huawei_module.Field.CONTENT_KEY: "doc-a", + huawei_module.Field.VECTOR: [0.1], + huawei_module.Field.METADATA_KEY: {"doc_id": "1"}, + }, + } + ] + } + } + + docs = vector.search_by_vector([0.1, 0.2], top_k=1, score_threshold=0.5) + + assert docs == [] + + +def test_create_and_create_collection_paths(huawei_module, monkeypatch): + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(huawei_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(huawei_module.redis_client, "set", MagicMock()) + + vector = huawei_module.HuaweiCloudVector("collection", _config(huawei_module)) + vector.create_collection = MagicMock() + vector.add_texts = MagicMock() + + docs = [Document(page_content="a", metadata={"doc_id": "1"})] + vector.create(docs, [[0.1]]) + vector.create_collection.assert_called_once() + vector.add_texts.assert_called_once_with(docs, [[0.1]]) + + vector = huawei_module.HuaweiCloudVector("collection", _config(huawei_module)) + monkeypatch.setattr(huawei_module.redis_client, "get", MagicMock(return_value=1)) + vector.create_collection([[0.1, 0.2]], [{}]) + vector._client.indices.create.assert_not_called() + + monkeypatch.setattr(huawei_module.redis_client, "get", MagicMock(return_value=None)) + vector._client.indices.exists.return_value = False + vector.create_collection([[0.1, 0.2]], [{}]) + vector._client.indices.create.assert_called_once() + + kwargs = vector._client.indices.create.call_args.kwargs + mappings = kwargs["mappings"] + assert mappings["properties"][huawei_module.Field.VECTOR]["dimension"] == 2 + assert kwargs["settings"] == {"index.vector": True} + huawei_module.redis_client.set.assert_called_once() + + +def test_huawei_factory_branches(huawei_module, monkeypatch): + factory = huawei_module.HuaweiCloudVectorFactory() + dataset_with_index = SimpleNamespace( + id="dataset-1", + index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}}, + index_struct=None, + ) + dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None) + + monkeypatch.setattr(huawei_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION") + monkeypatch.setattr(huawei_module.dify_config, "HUAWEI_CLOUD_HOSTS", "http://huawei-es:9200") + monkeypatch.setattr(huawei_module.dify_config, "HUAWEI_CLOUD_USER", "user") + monkeypatch.setattr(huawei_module.dify_config, "HUAWEI_CLOUD_PASSWORD", "pass") + + with patch.object(huawei_module, "HuaweiCloudVector", return_value="vector") as vector_cls: + result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock()) + result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock()) + + assert result_1 == "vector" + assert result_2 == "vector" + assert vector_cls.call_args_list[0].kwargs["index_name"] == "existing_collection" + assert vector_cls.call_args_list[1].kwargs["index_name"] == "auto_collection" + assert dataset_without_index.index_struct is not None diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/iris/test_iris_vector.py b/api/tests/unit_tests/core/rag/datasource/vdb/iris/test_iris_vector.py new file mode 100644 index 0000000000..63338ca809 --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/iris/test_iris_vector.py @@ -0,0 +1,412 @@ +import importlib +import sys +import types +from contextlib import contextmanager +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest + +from core.rag.models.document import Document + + +def _build_fake_iris_module(): + iris = types.ModuleType("iris") + + def connect(**_kwargs): + conn = MagicMock() + conn.cursor.return_value = MagicMock() + return conn + + iris.connect = MagicMock(side_effect=connect) + return iris + + +@pytest.fixture +def iris_module(monkeypatch): + monkeypatch.setitem(sys.modules, "iris", _build_fake_iris_module()) + + import core.rag.datasource.vdb.iris.iris_vector as module + + reloaded = importlib.reload(module) + reloaded._pool_instance = None + return reloaded + + +def _config(module, **overrides): + values = { + "IRIS_HOST": "localhost", + "IRIS_SUPER_SERVER_PORT": 1972, + "IRIS_USER": "user", + "IRIS_PASSWORD": "pass", + "IRIS_DATABASE": "db", + "IRIS_SCHEMA": "schema", + "IRIS_CONNECTION_URL": "url", + "IRIS_MIN_CONNECTION": 1, + "IRIS_MAX_CONNECTION": 2, + "IRIS_TEXT_INDEX": True, + "IRIS_TEXT_INDEX_LANGUAGE": "en", + } + values.update(overrides) + return module.IrisVectorConfig.model_validate(values) + + +def test_get_iris_pool_singleton(iris_module): + iris_module._pool_instance = None + cfg = _config(iris_module) + + with patch.object(iris_module, "IrisConnectionPool", return_value="pool") as pool_cls: + pool_1 = iris_module.get_iris_pool(cfg) + pool_2 = iris_module.get_iris_pool(cfg) + + assert pool_1 == "pool" + assert pool_2 == "pool" + pool_cls.assert_called_once_with(cfg) + + +@pytest.fixture +def pool_with_min_max(iris_module): + cfg = _config(iris_module, IRIS_MIN_CONNECTION=2, IRIS_MAX_CONNECTION=3) + with patch.object(iris_module.IrisConnectionPool, "_create_connection", return_value=MagicMock()) as create_conn: + pool = iris_module.IrisConnectionPool(cfg) + yield pool, create_conn + + +def test_pool_initialization_respects_min_max(pool_with_min_max): + pool, create_conn = pool_with_min_max + assert len(pool._pool) == 2 + assert create_conn.call_count == 2 + + +@pytest.fixture +def pool_for_get_connection(iris_module): + cfg = _config(iris_module, IRIS_MIN_CONNECTION=2, IRIS_MAX_CONNECTION=3) + pool = iris_module.IrisConnectionPool(cfg) + return pool + + +def test_get_connection_returns_existing_and_increments(pool_for_get_connection): + pool = pool_for_get_connection + conn = MagicMock() + pool._pool = [conn] + pool._in_use = 0 + assert pool.get_connection() is conn + assert pool._in_use == 1 + + +def test_get_connection_creates_new_when_empty(pool_for_get_connection): + pool = pool_for_get_connection + pool._pool = [] + pool._in_use = 0 + pool._create_connection = MagicMock(return_value="new-conn") + assert pool.get_connection() == "new-conn" + + +def test_get_connection_raises_when_exhausted(pool_for_get_connection): + pool = pool_for_get_connection + pool._pool = [] + pool._in_use = pool._max_size + with pytest.raises(RuntimeError, match="exhausted"): + pool.get_connection() + + +@pytest.fixture +def pool_for_return_connection(iris_module): + cfg = _config(iris_module) + with patch.object(iris_module.IrisConnectionPool, "_initialize_pool", return_value=None): + pool = iris_module.IrisConnectionPool(cfg) + return pool + + +def test_return_connection_adds_healthy(pool_for_return_connection): + pool = pool_for_return_connection + pool._in_use = 1 + conn = MagicMock() + cursor = MagicMock() + conn.cursor.return_value = cursor + pool.return_connection(conn) + assert pool._pool[-1] is conn + assert pool._in_use == 0 + + +def test_return_connection_replaces_bad(pool_for_return_connection): + pool = pool_for_return_connection + pool._in_use = 1 + bad_conn = MagicMock() + bad_cursor = MagicMock() + bad_cursor.execute.side_effect = OSError("bad") + bad_conn.cursor.return_value = bad_cursor + replacement = MagicMock() + pool._create_connection = MagicMock(return_value=replacement) + pool.return_connection(bad_conn) + bad_conn.close.assert_called_once() + assert pool._pool[-1] is replacement + assert pool._in_use == 0 + + +def test_return_connection_ignores_none(pool_for_return_connection): + pool = pool_for_return_connection + before = len(pool._pool) + pool.return_connection(None) + assert len(pool._pool) == before + + +@pytest.fixture +def pool_for_schema_and_close(iris_module): + cfg = _config(iris_module) + with patch.object(iris_module.IrisConnectionPool, "_initialize_pool", return_value=None): + pool = iris_module.IrisConnectionPool(cfg) + conn = MagicMock() + cursor = MagicMock() + conn.cursor.return_value = cursor + pool._pool = [conn] + return pool, conn, cursor + + +def test_ensure_schema_exists_cached_noop(pool_for_schema_and_close): + pool, conn, cursor = pool_for_schema_and_close + pool._schemas_initialized = {"cached_schema"} + pool.ensure_schema_exists("cached_schema") + cursor.execute.assert_not_called() + + +def test_ensure_schema_exists_creates_new(pool_for_schema_and_close): + pool, conn, cursor = pool_for_schema_and_close + pool._schemas_initialized = set() + cursor.fetchone.return_value = (0,) + pool.ensure_schema_exists("new_schema") + assert "new_schema" in pool._schemas_initialized + assert any("CREATE SCHEMA" in call.args[0] for call in cursor.execute.call_args_list) + conn.commit.assert_called_once() + + +def test_ensure_schema_exists_existing_no_commit(pool_for_schema_and_close): + pool, conn, cursor = pool_for_schema_and_close + pool._schemas_initialized = set() + cursor.fetchone.return_value = (1,) + pool.ensure_schema_exists("existing_schema") + conn.commit.assert_not_called() + + +def test_ensure_schema_exists_rollback_on_error(pool_for_schema_and_close): + pool, conn, cursor = pool_for_schema_and_close + pool._schemas_initialized = set() + cursor.execute.side_effect = RuntimeError("schema failure") + with pytest.raises(RuntimeError, match="schema failure"): + pool.ensure_schema_exists("broken_schema") + conn.rollback.assert_called() + + +def test_close_all_closes_and_resets(iris_module): + cfg = _config(iris_module) + with patch.object(iris_module.IrisConnectionPool, "_initialize_pool", return_value=None): + pool = iris_module.IrisConnectionPool(cfg) + conn = MagicMock() + conn_2 = MagicMock() + conn_2.close.side_effect = OSError("close fail") + pool._pool = [conn, conn_2] + pool._schemas_initialized = {"x"} + pool.close_all() + assert pool._pool == [] + assert pool._in_use == 0 + assert pool._schemas_initialized == set() + + +def test_iris_vector_init_get_cursor_and_create(iris_module): + pool = MagicMock() + pool.get_connection.return_value = MagicMock() + + with patch.object(iris_module, "get_iris_pool", return_value=pool): + vector = iris_module.IrisVector("collection", _config(iris_module)) + + assert vector.table_name == "EMBEDDING_COLLECTION" + assert vector.schema == "schema" + assert vector.get_type() == iris_module.VectorType.IRIS + + conn = MagicMock() + cursor = MagicMock() + conn.cursor.return_value = cursor + vector.pool.get_connection.return_value = conn + + with vector._get_cursor() as got_cursor: + assert got_cursor is cursor + conn.commit.assert_called_once() + vector.pool.return_connection.assert_called_with(conn) + + conn = MagicMock() + cursor = MagicMock() + conn.cursor.return_value = cursor + vector.pool.get_connection.return_value = conn + with pytest.raises(RuntimeError, match="boom"): + with vector._get_cursor(): + raise RuntimeError("boom") + conn.rollback.assert_called_once() + + vector._create_collection = MagicMock() + vector.add_texts = MagicMock(return_value=["id-1"]) + docs = [Document(page_content="a", metadata={"doc_id": "id-1"})] + assert vector.create(docs, [[0.1, 0.2]]) == ["id-1"] + vector._create_collection.assert_called_once_with(2) + + +def test_iris_vector_crud_and_vector_search(iris_module, monkeypatch): + with patch.object(iris_module, "get_iris_pool", return_value=MagicMock()): + vector = iris_module.IrisVector("collection", _config(iris_module)) + + cursor = MagicMock() + + @contextmanager + def _cursor_ctx(): + yield cursor + + vector._get_cursor = _cursor_ctx + monkeypatch.setattr(iris_module.uuid, "uuid4", lambda: "generated-id") + + docs = [ + Document(page_content="a", metadata={"doc_id": "id-1"}), + SimpleNamespace(page_content="b", metadata=None), + ] + ids = vector.add_texts(docs, [[0.1], [0.2]]) + assert ids == ["id-1", "generated-id"] + assert cursor.execute.call_count == 2 + + cursor.fetchone.return_value = (1,) + assert vector.text_exists("id-1") is True + cursor.fetchone.return_value = None + assert vector.text_exists("id-2") is False + + vector._get_cursor = MagicMock(side_effect=RuntimeError("db down")) + assert vector.text_exists("id-3") is False + + vector._get_cursor = _cursor_ctx + vector.delete_by_ids([]) + before = cursor.execute.call_count + vector.delete_by_ids(["id-1", "id-2"]) + assert cursor.execute.call_count == before + 1 + + vector.delete_by_metadata_field("document_id", "doc-1") + assert "meta LIKE" in cursor.execute.call_args.args[0] + + cursor.fetchall.return_value = [ + ("id-1", "text-1", '{"document_id":"d-1"}', 0.9), + ("id-2", "text-2", '{"document_id":"d-2"}', 0.2), + ("id-x",), + ] + docs = vector.search_by_vector([0.1, 0.2], top_k=3, score_threshold=0.5) + assert len(docs) == 1 + assert docs[0].metadata["score"] == pytest.approx(0.9) + + +def test_iris_vector_full_text_search_paths(iris_module, monkeypatch): + cfg = _config(iris_module, IRIS_TEXT_INDEX=True) + with patch.object(iris_module, "get_iris_pool", return_value=MagicMock()): + vector = iris_module.IrisVector("collection", cfg) + + cursor = MagicMock() + + @contextmanager + def _cursor_ctx(): + yield cursor + + vector._get_cursor = _cursor_ctx + + cursor.execute.side_effect = None + cursor.fetchall.return_value = [ + ("id-1", "text-1", '{"document_id":"d-1"}', 0.7), + ("id-2", "text-2", "{}", None), + ] + docs = vector.search_by_full_text("query", top_k=2, document_ids_filter=["d-1"]) + assert len(docs) == 2 + assert docs[0].metadata["score"] == pytest.approx(0.7) + assert docs[1].metadata["score"] == pytest.approx(0.0) + + cursor.reset_mock() + cursor.execute.side_effect = [RuntimeError("rank failed"), None] + cursor.fetchall.return_value = [("id-3", "text-3", "{}", 0.5)] + docs = vector.search_by_full_text("query", top_k=1) + assert len(docs) == 1 + assert cursor.execute.call_count == 2 + + cfg_like = _config(iris_module, IRIS_TEXT_INDEX=False) + with patch.object(iris_module, "get_iris_pool", return_value=MagicMock()): + vector_like = iris_module.IrisVector("collection", cfg_like) + vector_like._get_cursor = _cursor_ctx + + fake_libs = types.ModuleType("libs") + fake_helper = types.ModuleType("libs.helper") + fake_helper.escape_like_pattern = lambda value: value.replace("%", "\\%") + monkeypatch.setitem(sys.modules, "libs", fake_libs) + monkeypatch.setitem(sys.modules, "libs.helper", fake_helper) + + cursor.reset_mock() + cursor.execute.side_effect = None + cursor.fetchall.return_value = [] + assert vector_like.search_by_full_text("100%", top_k=1) == [] + + +def test_iris_vector_delete_create_collection_and_factory(iris_module, monkeypatch): + with patch.object(iris_module, "get_iris_pool", return_value=MagicMock()): + vector = iris_module.IrisVector("collection", _config(iris_module, IRIS_TEXT_INDEX=True)) + + cursor = MagicMock() + + @contextmanager + def _cursor_ctx(): + yield cursor + + vector._get_cursor = _cursor_ctx + vector.delete() + assert "DROP TABLE" in cursor.execute.call_args.args[0] + + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(iris_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(iris_module.redis_client, "set", MagicMock()) + + monkeypatch.setattr(iris_module.redis_client, "get", MagicMock(return_value=1)) + vector._create_collection(2) + cursor.execute.assert_called_once() + + cursor.reset_mock() + monkeypatch.setattr(iris_module.redis_client, "get", MagicMock(return_value=None)) + vector.pool.ensure_schema_exists = MagicMock() + vector._create_collection(3) + assert cursor.execute.call_count == 3 + iris_module.redis_client.set.assert_called_once() + + cursor.reset_mock() + vector.config.IRIS_TEXT_INDEX = False + vector._create_collection(3) + assert cursor.execute.call_count == 2 + + factory = iris_module.IrisVectorFactory() + dataset_with_index = SimpleNamespace( + id="dataset-1", + index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}}, + index_struct=None, + ) + dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None) + + monkeypatch.setattr(iris_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION") + monkeypatch.setattr(iris_module.dify_config, "IRIS_HOST", "localhost") + monkeypatch.setattr(iris_module.dify_config, "IRIS_SUPER_SERVER_PORT", 1972) + monkeypatch.setattr(iris_module.dify_config, "IRIS_USER", "user") + monkeypatch.setattr(iris_module.dify_config, "IRIS_PASSWORD", "pass") + monkeypatch.setattr(iris_module.dify_config, "IRIS_DATABASE", "db") + monkeypatch.setattr(iris_module.dify_config, "IRIS_SCHEMA", "schema") + monkeypatch.setattr(iris_module.dify_config, "IRIS_CONNECTION_URL", "url") + monkeypatch.setattr(iris_module.dify_config, "IRIS_MIN_CONNECTION", 1) + monkeypatch.setattr(iris_module.dify_config, "IRIS_MAX_CONNECTION", 2) + monkeypatch.setattr(iris_module.dify_config, "IRIS_TEXT_INDEX", True) + monkeypatch.setattr(iris_module.dify_config, "IRIS_TEXT_INDEX_LANGUAGE", "en") + + with patch.object(iris_module, "IrisVector", return_value="vector") as vector_cls: + result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock()) + result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock()) + + assert result_1 == "vector" + assert result_2 == "vector" + assert vector_cls.call_args_list[0].kwargs["collection_name"] == "EXISTING_COLLECTION" + assert vector_cls.call_args_list[1].kwargs["collection_name"] == "AUTO_COLLECTION" + assert dataset_without_index.index_struct is not None diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/lindorm/test_lindorm_vector.py b/api/tests/unit_tests/core/rag/datasource/vdb/lindorm/test_lindorm_vector.py new file mode 100644 index 0000000000..34357d5907 --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/lindorm/test_lindorm_vector.py @@ -0,0 +1,394 @@ +import importlib +import sys +import types +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from pydantic import ValidationError + +from core.rag.models.document import Document + + +def _build_fake_opensearch_modules(): + opensearchpy = types.ModuleType("opensearchpy") + opensearch_helpers = types.ModuleType("opensearchpy.helpers") + + class BulkIndexError(Exception): + def __init__(self, errors): + super().__init__("bulk error") + self.errors = errors + + class OpenSearch: + def __init__(self, **kwargs): + self.kwargs = kwargs + self.indices = SimpleNamespace( + refresh=MagicMock(), + exists=MagicMock(return_value=False), + delete=MagicMock(), + create=MagicMock(), + ) + self.bulk = MagicMock(return_value={"errors": False, "items": []}) + self.search = MagicMock(return_value={"hits": {"hits": []}}) + self.delete_by_query = MagicMock() + self.get = MagicMock(return_value={"_id": "id"}) + self.exists = MagicMock(return_value=True) + + opensearch_helpers.BulkIndexError = BulkIndexError + opensearch_helpers.bulk = MagicMock() + + opensearchpy.OpenSearch = OpenSearch + opensearchpy.helpers = opensearch_helpers + + return { + "opensearchpy": opensearchpy, + "opensearchpy.helpers": opensearch_helpers, + } + + +@pytest.fixture +def lindorm_module(monkeypatch): + for name, module in _build_fake_opensearch_modules().items(): + monkeypatch.setitem(sys.modules, name, module) + + import core.rag.datasource.vdb.lindorm.lindorm_vector as module + + return importlib.reload(module) + + +def _config(module): + return module.LindormVectorStoreConfig( + hosts="http://localhost:9200", + username="user", + password="pass", + using_ugc=False, + request_timeout=3.0, + ) + + +@pytest.mark.parametrize( + ("field", "value", "message"), + [ + ("hosts", None, "config URL is required"), + ("username", None, "config USERNAME is required"), + ("password", None, "config PASSWORD is required"), + ], +) +def test_lindorm_config_validation(lindorm_module, field, value, message): + values = _config(lindorm_module).model_dump() + values[field] = value + + with pytest.raises(ValidationError, match=message): + lindorm_module.LindormVectorStoreConfig.model_validate(values) + + +def test_to_opensearch_params_and_init(lindorm_module): + cfg = _config(lindorm_module) + params = cfg.to_opensearch_params() + + assert params["hosts"] == "http://localhost:9200" + assert params["http_auth"] == ("user", "pass") + + vector = lindorm_module.LindormVectorStore("Collection", cfg, using_ugc=False) + assert vector._collection_name == "collection" + assert vector.get_type() == lindorm_module.VectorType.LINDORM + + with pytest.raises(ValueError, match="routing_value"): + lindorm_module.LindormVectorStore("c", cfg, using_ugc=True) + + vector_ugc = lindorm_module.LindormVectorStore("c", cfg, using_ugc=True, routing_value="ROUTE") + assert vector_ugc._routing == "route" + + +def test_create_refresh_and_add_texts_success(lindorm_module, monkeypatch): + vector = lindorm_module.LindormVectorStore( + "collection", _config(lindorm_module), using_ugc=True, routing_value="route" + ) + vector.create_collection = MagicMock() + vector.add_texts = MagicMock() + + docs = [Document(page_content="a", metadata={"doc_id": "id-1"})] + vector.create(docs, [[0.1]]) + vector.create_collection.assert_called_once_with([[0.1]], [{"doc_id": "id-1"}]) + vector.add_texts.assert_called_once_with(docs, [[0.1]]) + + vector = lindorm_module.LindormVectorStore( + "collection", _config(lindorm_module), using_ugc=True, routing_value="route" + ) + monkeypatch.setattr(lindorm_module.time, "sleep", MagicMock()) + + docs = [ + Document(page_content="a", metadata={"doc_id": "id-1"}), + Document(page_content="b", metadata={"doc_id": "id-2"}), + Document(page_content="c", metadata={"doc_id": "id-3"}), + ] + embeddings = [[0.1], [0.2], [0.3]] + + vector.add_texts(docs, embeddings, batch_size=2, timeout=9) + + assert vector._client.bulk.call_count == 2 + actions = vector._client.bulk.call_args_list[0].args[0] + assert actions[0]["index"]["routing"] == "route" + assert actions[1][lindorm_module.ROUTING_FIELD] == "route" + vector.refresh() + vector._client.indices.refresh.assert_called_once_with(index="collection") + + +def test_add_texts_error_paths(lindorm_module): + vector = lindorm_module.LindormVectorStore("collection", _config(lindorm_module), using_ugc=False) + vector._client.bulk.return_value = {"errors": True, "items": [{"index": {"error": "boom"}}]} + + docs = [Document(page_content="a", metadata={"doc_id": "id-1"})] + with pytest.raises(Exception, match="RetryError"): + vector.add_texts(docs, [[0.1]], batch_size=1) + + vector._client.bulk.side_effect = RuntimeError("bulk failed") + with pytest.raises(Exception, match="RetryError"): + vector.add_texts(docs, [[0.1]], batch_size=1) + + +def test_metadata_lookup_and_delete_by_metadata(lindorm_module): + vector = lindorm_module.LindormVectorStore( + "collection", _config(lindorm_module), using_ugc=True, routing_value="route" + ) + vector._client.search.return_value = {"hits": {"hits": [{"_id": "id-1"}, {"_id": "id-2"}]}} + + ids = vector.get_ids_by_metadata_field("document_id", "doc-1") + assert ids == ["id-1", "id-2"] + query = vector._client.search.call_args.kwargs["body"] + must_conditions = query["query"]["bool"]["must"] + assert any("routing_field.keyword" in cond.get("term", {}) for cond in must_conditions) + + vector.delete_by_ids = MagicMock() + vector.delete_by_metadata_field("document_id", "doc-1") + vector.delete_by_ids.assert_called_once_with(["id-1", "id-2"]) + + vector._client.search.return_value = {"hits": {"hits": []}} + vector.delete_by_ids.reset_mock() + vector.delete_by_metadata_field("document_id", "doc-2") + vector.delete_by_ids.assert_not_called() + + +def test_delete_by_ids_paths(lindorm_module): + vector = lindorm_module.LindormVectorStore( + "collection", _config(lindorm_module), using_ugc=True, routing_value="route" + ) + + vector.delete_by_ids([]) + vector._client.indices.exists.assert_not_called() + + vector._client.indices.exists.return_value = False + vector.delete_by_ids(["id-1"]) + + vector._client.indices.exists.return_value = True + vector._client.exists.side_effect = [True, False] + lindorm_module.helpers.bulk.reset_mock() + vector.delete_by_ids(["id-1", "id-2"]) + lindorm_module.helpers.bulk.assert_called_once() + actions = lindorm_module.helpers.bulk.call_args.args[1] + assert len(actions) == 1 + assert actions[0]["routing"] == "route" + + lindorm_module.helpers.bulk.reset_mock() + lindorm_module.helpers.bulk.side_effect = lindorm_module.BulkIndexError( + errors=[ + {"delete": {"status": 404, "_id": "id-404"}}, + {"delete": {"status": 500, "_id": "id-500"}}, + ] + ) + vector._client.exists.side_effect = [True] + vector.delete_by_ids(["id-1"]) + + +def test_delete_and_text_exists(lindorm_module): + vector = lindorm_module.LindormVectorStore( + "collection", _config(lindorm_module), using_ugc=True, routing_value="route" + ) + vector.delete() + vector._client.delete_by_query.assert_called_once() + vector._client.indices.refresh.assert_called_once_with(index="collection") + + vector = lindorm_module.LindormVectorStore("collection", _config(lindorm_module), using_ugc=False) + vector._client.indices.exists.return_value = True + vector.delete() + vector._client.indices.delete.assert_called_once_with(index="collection", params={"timeout": 60}) + + vector._client.indices.delete.reset_mock() + vector._client.indices.exists.return_value = False + vector.delete() + vector._client.indices.delete.assert_not_called() + + assert vector.text_exists("id-1") is True + vector._client.get.side_effect = RuntimeError("missing") + assert vector.text_exists("id-1") is False + + +def test_search_by_vector_validation_and_success(lindorm_module): + vector = lindorm_module.LindormVectorStore( + "collection", _config(lindorm_module), using_ugc=True, routing_value="route" + ) + + with pytest.raises(ValueError, match="should be a list"): + vector.search_by_vector("bad") + + with pytest.raises(ValueError, match="should be floats"): + vector.search_by_vector([0.1, "bad"]) + + vector._client.search.return_value = { + "hits": { + "hits": [ + { + "_score": 0.9, + "_source": { + lindorm_module.Field.CONTENT_KEY: "doc-a", + lindorm_module.Field.VECTOR: [0.1], + lindorm_module.Field.METADATA_KEY: {"doc_id": "1", "document_id": "d-1"}, + }, + }, + { + "_score": 0.2, + "_source": { + lindorm_module.Field.CONTENT_KEY: "doc-b", + lindorm_module.Field.VECTOR: [0.2], + lindorm_module.Field.METADATA_KEY: {"doc_id": "2", "document_id": "d-2"}, + }, + }, + ] + } + } + docs = vector.search_by_vector([0.1, 0.2], top_k=2, score_threshold=0.5, document_ids_filter=["d-1"]) + assert len(docs) == 1 + assert docs[0].metadata["score"] == pytest.approx(0.9) + + call_kwargs = vector._client.search.call_args.kwargs + query = call_kwargs["body"] + assert "ext" in query + assert query["query"]["knn"][lindorm_module.Field.VECTOR]["filter"]["bool"]["must"] + assert call_kwargs["params"]["routing"] == "route" + + vector._client.search.side_effect = RuntimeError("search failed") + with pytest.raises(RuntimeError, match="search failed"): + vector.search_by_vector([0.1]) + + +def test_search_by_full_text_success_and_error(lindorm_module): + vector = lindorm_module.LindormVectorStore( + "collection", _config(lindorm_module), using_ugc=True, routing_value="route" + ) + vector._client.search.return_value = { + "hits": { + "hits": [ + { + "_source": { + lindorm_module.Field.CONTENT_KEY: "doc-a", + lindorm_module.Field.VECTOR: [0.1], + lindorm_module.Field.METADATA_KEY: {"doc_id": "1"}, + } + } + ] + } + } + + docs = vector.search_by_full_text("hello", top_k=2, document_ids_filter=["d-1"]) + assert len(docs) == 1 + assert docs[0].page_content == "doc-a" + + query = vector._client.search.call_args.kwargs["body"] + assert query["query"]["bool"]["filter"] + + vector._client.search.side_effect = RuntimeError("full text failed") + with pytest.raises(RuntimeError, match="full text failed"): + vector.search_by_full_text("hello") + + +def test_create_collection_paths(lindorm_module, monkeypatch): + vector = lindorm_module.LindormVectorStore("collection", _config(lindorm_module), using_ugc=False) + + with pytest.raises(ValueError, match="cannot be empty"): + vector.create_collection([]) + + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(lindorm_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(lindorm_module.redis_client, "set", MagicMock()) + + monkeypatch.setattr(lindorm_module.redis_client, "get", MagicMock(return_value=1)) + vector.create_collection([[0.1, 0.2]]) + vector._client.indices.create.assert_not_called() + + monkeypatch.setattr(lindorm_module.redis_client, "get", MagicMock(return_value=None)) + vector._client.indices.exists.return_value = False + vector.create_collection([[0.1, 0.2]], index_params={"index_type": "ivf", "space_type": "cosine"}) + vector._client.indices.create.assert_called_once() + body = vector._client.indices.create.call_args.kwargs["body"] + assert body["mappings"]["properties"][lindorm_module.Field.VECTOR]["method"]["name"] == "ivf" + assert body["mappings"]["properties"][lindorm_module.Field.VECTOR]["method"]["space_type"] == "cosine" + + vector._client.indices.create.reset_mock() + vector._client.indices.exists.return_value = True + vector.create_collection([[0.1, 0.2]]) + vector._client.indices.create.assert_not_called() + + +def test_lindorm_factory_branches(lindorm_module, monkeypatch): + factory = lindorm_module.LindormVectorStoreFactory() + + monkeypatch.setattr(lindorm_module.dify_config, "LINDORM_URL", "http://localhost:9200") + monkeypatch.setattr(lindorm_module.dify_config, "LINDORM_USERNAME", "user") + monkeypatch.setattr(lindorm_module.dify_config, "LINDORM_PASSWORD", "pass") + monkeypatch.setattr(lindorm_module.dify_config, "LINDORM_QUERY_TIMEOUT", 3.0) + monkeypatch.setattr(lindorm_module.dify_config, "LINDORM_INDEX_TYPE", "hnsw") + monkeypatch.setattr(lindorm_module.dify_config, "LINDORM_DISTANCE_TYPE", "l2") + monkeypatch.setattr(lindorm_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION") + + dataset = SimpleNamespace(id="dataset-1", index_struct=None, index_struct_dict={}) + embeddings = SimpleNamespace(embed_query=lambda _q: [0.1, 0.2, 0.3]) + + monkeypatch.setattr(lindorm_module.dify_config, "LINDORM_USING_UGC", None) + with pytest.raises(ValueError, match="LINDORM_USING_UGC is not set"): + factory.init_vector(dataset, attributes=[], embeddings=embeddings) + + monkeypatch.setattr(lindorm_module.dify_config, "LINDORM_USING_UGC", False) + + dataset_existing_plain = SimpleNamespace( + id="dataset-1", + index_struct="{}", + index_struct_dict={"vector_store": {"class_prefix": "EXISTING"}, "using_ugc": False}, + ) + with patch.object(lindorm_module, "LindormVectorStore", return_value="vector") as store_cls: + result = factory.init_vector(dataset_existing_plain, attributes=[], embeddings=embeddings) + assert result == "vector" + assert store_cls.call_args.args[0] == "existing" + + dataset_existing_ugc = SimpleNamespace( + id="dataset-1", + index_struct="{}", + index_struct_dict={ + "vector_store": {"class_prefix": "ROUTING"}, + "using_ugc": True, + "dimension": 1536, + "index_type": "hnsw", + "distance_type": "l2", + }, + ) + with patch.object(lindorm_module, "LindormVectorStore", return_value="vector") as store_cls: + factory.init_vector(dataset_existing_ugc, attributes=[], embeddings=embeddings) + assert store_cls.call_args.args[0] == "ugc_index_1536_hnsw_l2" + assert store_cls.call_args.kwargs["routing_value"] == "ROUTING" + + dataset_new = SimpleNamespace(id="dataset-2", index_struct=None, index_struct_dict={}) + + monkeypatch.setattr(lindorm_module.dify_config, "LINDORM_USING_UGC", True) + with patch.object(lindorm_module, "LindormVectorStore", return_value="vector") as store_cls: + factory.init_vector(dataset_new, attributes=[], embeddings=embeddings) + assert store_cls.call_args.args[0] == "ugc_index_3_hnsw_l2" + assert store_cls.call_args.kwargs["routing_value"] == "auto_collection" + assert dataset_new.index_struct is not None + + dataset_new_plain = SimpleNamespace(id="dataset-3", index_struct=None, index_struct_dict={}) + monkeypatch.setattr(lindorm_module.dify_config, "LINDORM_USING_UGC", False) + with patch.object(lindorm_module, "LindormVectorStore", return_value="vector") as store_cls: + factory.init_vector(dataset_new_plain, attributes=[], embeddings=embeddings) + assert store_cls.call_args.args[0] == "auto_collection" + assert store_cls.call_args.kwargs["routing_value"] is None diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/matrixone/test_matrixone_vector.py b/api/tests/unit_tests/core/rag/datasource/vdb/matrixone/test_matrixone_vector.py new file mode 100644 index 0000000000..55e7b9112e --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/matrixone/test_matrixone_vector.py @@ -0,0 +1,252 @@ +import importlib +import sys +import types +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from pydantic import ValidationError + +from core.rag.models.document import Document + + +def _build_fake_mo_vector_modules(): + mo_vector = types.ModuleType("mo_vector") + mo_vector.__path__ = [] + mo_vector_client = types.ModuleType("mo_vector.client") + + class MoVectorClient: + def __init__(self, **kwargs): + self.kwargs = kwargs + self.create_full_text_index = MagicMock() + self.insert = MagicMock() + self.get = MagicMock(return_value=[]) + self.delete = MagicMock() + self.query_by_metadata = MagicMock(return_value=[]) + self.query = MagicMock(return_value=[]) + self.full_text_query = MagicMock(return_value=[]) + + mo_vector_client.MoVectorClient = MoVectorClient + mo_vector.client = mo_vector_client + return {"mo_vector": mo_vector, "mo_vector.client": mo_vector_client} + + +@pytest.fixture +def matrixone_module(monkeypatch): + for name, module in _build_fake_mo_vector_modules().items(): + monkeypatch.setitem(sys.modules, name, module) + + import core.rag.datasource.vdb.matrixone.matrixone_vector as module + + return importlib.reload(module) + + +def _valid_config(module): + return module.MatrixoneConfig( + host="localhost", + port=6001, + user="dump", + password="111", + database="dify", + metric="l2", + ) + + +@pytest.mark.parametrize( + ("field", "value", "message"), + [ + ("host", "", "config host is required"), + ("port", 0, "config port is required"), + ("user", "", "config user is required"), + ("password", "", "config password is required"), + ("database", "", "config database is required"), + ], +) +def test_matrixone_config_validation(matrixone_module, field, value, message): + values = _valid_config(matrixone_module).model_dump() + values[field] = value + + with pytest.raises(ValidationError, match=message): + matrixone_module.MatrixoneConfig.model_validate(values) + + +def test_get_client_creates_full_text_index_when_cache_misses(matrixone_module, monkeypatch): + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(matrixone_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(matrixone_module.redis_client, "get", MagicMock(return_value=None)) + monkeypatch.setattr(matrixone_module.redis_client, "set", MagicMock()) + + vector = matrixone_module.MatrixoneVector("Collection_1", _valid_config(matrixone_module)) + client = vector._get_client(dimension=3, create_table=True) + + assert client.kwargs["table_name"] == "collection_1" + client.create_full_text_index.assert_called_once() + matrixone_module.redis_client.set.assert_called_once() + + +def test_get_client_skips_index_creation_when_cache_hits(matrixone_module, monkeypatch): + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(matrixone_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(matrixone_module.redis_client, "get", MagicMock(return_value=1)) + monkeypatch.setattr(matrixone_module.redis_client, "set", MagicMock()) + + vector = matrixone_module.MatrixoneVector("Collection_1", _valid_config(matrixone_module)) + client = vector._get_client(dimension=3, create_table=True) + + client.create_full_text_index.assert_not_called() + matrixone_module.redis_client.set.assert_not_called() + + +def test_ensure_client_initializes_client_for_decorated_methods(matrixone_module): + vector = matrixone_module.MatrixoneVector("collection_1", _valid_config(matrixone_module)) + vector.client = None + fake_client = MagicMock() + fake_client.get.return_value = [{"id": "seg-1"}] + vector._get_client = MagicMock(return_value=fake_client) + + exists = vector.text_exists("seg-1") + + assert exists is True + vector._get_client.assert_called_once_with(None, False) + + +def test_search_by_full_text_parses_metadata_and_applies_threshold(matrixone_module): + vector = matrixone_module.MatrixoneVector("collection_1", _valid_config(matrixone_module)) + vector.client = MagicMock() + vector.client.full_text_query.return_value = [ + SimpleNamespace(document="doc-a", metadata='{"doc_id":"1"}', distance=0.1), + SimpleNamespace(document="doc-b", metadata={"doc_id": "2"}, distance=0.7), + ] + + docs = vector.search_by_full_text("query", top_k=2, score_threshold=0.5, document_ids_filter=["doc-1"]) + + assert len(docs) == 1 + assert docs[0].page_content == "doc-a" + assert docs[0].metadata["doc_id"] == "1" + assert docs[0].metadata["score"] == pytest.approx(0.9) + assert vector.client.full_text_query.call_args.kwargs["filter"] == {"document_id": {"$in": ["doc-1"]}} + + +def test_get_type_and_create_delegate_to_add_texts(matrixone_module): + vector = matrixone_module.MatrixoneVector("collection_1", _valid_config(matrixone_module)) + fake_client = MagicMock() + vector._get_client = MagicMock(return_value=fake_client) + vector.add_texts = MagicMock(return_value=["seg-1"]) + docs = [Document(page_content="hello", metadata={"doc_id": "seg-1"})] + + result = vector.create(docs, [[0.1, 0.2]]) + + assert vector.get_type() == "matrixone" + assert result == ["seg-1"] + vector._get_client.assert_called_once_with(2, True) + vector.add_texts.assert_called_once_with(docs, [[0.1, 0.2]]) + + +def test_get_client_handles_full_text_index_creation_error(matrixone_module, monkeypatch): + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(matrixone_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(matrixone_module.redis_client, "get", MagicMock(return_value=None)) + monkeypatch.setattr(matrixone_module.redis_client, "set", MagicMock()) + + failing_client = MagicMock() + failing_client.create_full_text_index.side_effect = RuntimeError("boom") + monkeypatch.setattr(matrixone_module, "MoVectorClient", MagicMock(return_value=failing_client)) + + vector = matrixone_module.MatrixoneVector("collection_1", _valid_config(matrixone_module)) + client = vector._get_client(dimension=3, create_table=True) + + assert client is failing_client + matrixone_module.redis_client.set.assert_not_called() + + +def test_add_texts_generates_ids_and_inserts(matrixone_module, monkeypatch): + vector = matrixone_module.MatrixoneVector("collection_1", _valid_config(matrixone_module)) + vector.client = MagicMock() + monkeypatch.setattr(matrixone_module.uuid, "uuid4", lambda: "generated-uuid") + docs = [ + Document(page_content="a", metadata={"doc_id": "doc-a", "document_id": "d-1"}), + Document(page_content="b", metadata={"document_id": "d-2"}), + SimpleNamespace(page_content="c", metadata=None), + ] + + ids = vector.add_texts(docs, [[0.1], [0.2], [0.3]]) + + # For current prod code, only docs with metadata get ids, so only two ids + assert ids == ["doc-a", "generated-uuid"] + vector.client.insert.assert_called_once() + insert_kwargs = vector.client.insert.call_args.kwargs + # All lists passed to insert should be the same length + texts = insert_kwargs["texts"] + embeddings = insert_kwargs["embeddings"] + metadatas = insert_kwargs["metadatas"] + ids_insert = insert_kwargs["ids"] + assert len(texts) == len(embeddings) == len(metadatas) == len(docs) + # ids may be shorter than docs for current prod code, but should match number of docs with metadata + assert ids_insert == ["doc-a", "generated-uuid"] + + +def test_delete_and_metadata_methods(matrixone_module): + vector = matrixone_module.MatrixoneVector("collection_1", _valid_config(matrixone_module)) + vector.client = MagicMock() + vector.client.query_by_metadata.return_value = [SimpleNamespace(id="seg-1"), SimpleNamespace(id="seg-2")] + + vector.delete_by_ids([]) + vector.client.delete.assert_not_called() + + vector.delete_by_ids(["seg-1"]) + vector.delete_by_metadata_field("document_id", "doc-1") + ids = vector.get_ids_by_metadata_field("document_id", "doc-1") + vector.delete() + + assert ids == ["seg-1", "seg-2"] + assert vector.client.delete.call_count == 3 + + +def test_search_by_vector_builds_documents(matrixone_module): + vector = matrixone_module.MatrixoneVector("collection_1", _valid_config(matrixone_module)) + vector.client = MagicMock() + vector.client.query.return_value = [ + SimpleNamespace(document="doc-a", metadata={"doc_id": "1"}), + SimpleNamespace(document="doc-b", metadata={"doc_id": "2"}), + ] + + docs = vector.search_by_vector([0.1, 0.2], top_k=2, document_ids_filter=["d-1"]) + + assert len(docs) == 2 + assert docs[0].page_content == "doc-a" + assert docs[1].metadata["doc_id"] == "2" + assert vector.client.query.call_args.kwargs["filter"] == {"document_id": {"$in": ["d-1"]}} + + +def test_matrixone_factory_uses_existing_or_generated_collection(matrixone_module, monkeypatch): + factory = matrixone_module.MatrixoneVectorFactory() + dataset_with_index = SimpleNamespace( + id="dataset-1", + index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}}, + index_struct=None, + ) + dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None) + + monkeypatch.setattr(matrixone_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION") + monkeypatch.setattr(matrixone_module.dify_config, "MATRIXONE_HOST", "127.0.0.1") + monkeypatch.setattr(matrixone_module.dify_config, "MATRIXONE_PORT", 6001) + monkeypatch.setattr(matrixone_module.dify_config, "MATRIXONE_USER", "dump") + monkeypatch.setattr(matrixone_module.dify_config, "MATRIXONE_PASSWORD", "111") + monkeypatch.setattr(matrixone_module.dify_config, "MATRIXONE_DATABASE", "dify") + monkeypatch.setattr(matrixone_module.dify_config, "MATRIXONE_METRIC", "l2") + + with patch.object(matrixone_module, "MatrixoneVector", return_value="vector") as vector_cls: + result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock()) + result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock()) + + assert result_1 == "vector" + assert result_2 == "vector" + assert vector_cls.call_args_list[0].kwargs["collection_name"] == "EXISTING_COLLECTION" + assert vector_cls.call_args_list[1].kwargs["collection_name"] == "AUTO_COLLECTION" + assert dataset_without_index.index_struct is not None diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/milvus/test_milvus.py b/api/tests/unit_tests/core/rag/datasource/vdb/milvus/test_milvus.py index fb2ddfe162..2ac2c40d38 100644 --- a/api/tests/unit_tests/core/rag/datasource/vdb/milvus/test_milvus.py +++ b/api/tests/unit_tests/core/rag/datasource/vdb/milvus/test_milvus.py @@ -1,18 +1,414 @@ +import importlib +import sys +import types +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + import pytest from pydantic import ValidationError -from core.rag.datasource.vdb.milvus.milvus_vector import MilvusConfig +from core.rag.models.document import Document -def test_default_value(): +def _build_fake_pymilvus_modules(): + pymilvus = types.ModuleType("pymilvus") + pymilvus.__path__ = [] + pymilvus_milvus_client = types.ModuleType("pymilvus.milvus_client") + pymilvus_orm = types.ModuleType("pymilvus.orm") + pymilvus_orm.__path__ = [] + pymilvus_orm_types = types.ModuleType("pymilvus.orm.types") + + class MilvusError(Exception): + pass + + class MilvusClient: + def __init__(self, **kwargs): + self.init_kwargs = kwargs + self.has_collection = MagicMock(return_value=False) + self.describe_collection = MagicMock( + return_value={"fields": [{"name": "id"}, {"name": "content"}, {"name": "metadata"}]} + ) + self.get_server_version = MagicMock(return_value="2.5.0") + self.insert = MagicMock(return_value=[1]) + self.query = MagicMock(return_value=[]) + self.delete = MagicMock() + self.drop_collection = MagicMock() + self.search = MagicMock(return_value=[[]]) + self.create_collection = MagicMock() + + class IndexParams: + def __init__(self): + self.indexes = [] + + def add_index(self, **kwargs): + self.indexes.append(kwargs) + + class DataType: + JSON = "JSON" + VARCHAR = "VARCHAR" + INT64 = "INT64" + SPARSE_FLOAT_VECTOR = "SPARSE_FLOAT_VECTOR" + FLOAT_VECTOR = "FLOAT_VECTOR" + + class FieldSchema: + def __init__(self, name, dtype, **kwargs): + self.name = name + self.dtype = dtype + self.kwargs = kwargs + + class CollectionSchema: + def __init__(self, fields): + self.fields = fields + self.functions = [] + + def add_function(self, func): + self.functions.append(func) + + class FunctionType: + BM25 = "BM25" + + class Function: + def __init__(self, **kwargs): + self.kwargs = kwargs + + def infer_dtype_bydata(_value): + return DataType.FLOAT_VECTOR + + pymilvus.MilvusException = MilvusError + pymilvus.MilvusClient = MilvusClient + pymilvus.IndexParams = IndexParams + pymilvus.CollectionSchema = CollectionSchema + pymilvus.DataType = DataType + pymilvus.FieldSchema = FieldSchema + pymilvus.Function = Function + pymilvus.FunctionType = FunctionType + pymilvus_milvus_client.IndexParams = IndexParams + pymilvus_orm.types = pymilvus_orm_types + pymilvus_orm_types.infer_dtype_bydata = infer_dtype_bydata + + # Attach submodules for dotted imports + pymilvus.milvus_client = pymilvus_milvus_client + pymilvus.orm = pymilvus_orm + + return { + "pymilvus": pymilvus, + "pymilvus.milvus_client": pymilvus_milvus_client, + "pymilvus.orm": pymilvus_orm, + "pymilvus.orm.types": pymilvus_orm_types, + } + + +@pytest.fixture +def milvus_module(monkeypatch): + for name, module in _build_fake_pymilvus_modules().items(): + monkeypatch.setitem(sys.modules, name, module) + + import core.rag.datasource.vdb.milvus.milvus_vector as module + + return importlib.reload(module) + + +def _config(module, **overrides): + values = { + "uri": "http://localhost:19530", + "user": "root", + "password": "Milvus", + "database": "default", + "enable_hybrid_search": False, + "analyzer_params": None, + } + values.update(overrides) + return module.MilvusConfig.model_validate(values) + + +def test_config_validation_and_defaults(milvus_module): valid_config = {"uri": "http://localhost:19530", "user": "root", "password": "Milvus"} for key in valid_config: config = valid_config.copy() del config[key] with pytest.raises(ValidationError) as e: - MilvusConfig.model_validate(config) + milvus_module.MilvusConfig.model_validate(config) assert e.value.errors()[0]["msg"] == f"Value error, config MILVUS_{key.upper()} is required" - config = MilvusConfig.model_validate(valid_config) + config = milvus_module.MilvusConfig.model_validate(valid_config) assert config.database == "default" + + token_config = milvus_module.MilvusConfig.model_validate( + {"uri": "http://localhost:19530", "token": "token-value", "database": "db-1"} + ) + assert token_config.token == "token-value" + + +def test_config_to_milvus_params(milvus_module): + config = _config(milvus_module, analyzer_params='{"tokenizer":"standard"}') + + params = config.to_milvus_params() + + assert params["uri"] == "http://localhost:19530" + assert params["db_name"] == "default" + assert params["analyzer_params"] == '{"tokenizer":"standard"}' + + +def test_init_client_supports_token_and_user_password(milvus_module): + vector = milvus_module.MilvusVector.__new__(milvus_module.MilvusVector) + token_client = vector._init_client( + milvus_module.MilvusConfig.model_validate({"uri": "http://localhost:19530", "token": "abc", "database": "db"}) + ) + assert token_client.init_kwargs == {"uri": "http://localhost:19530", "token": "abc", "db_name": "db"} + + user_client = vector._init_client(_config(milvus_module)) + assert user_client.init_kwargs["uri"] == "http://localhost:19530" + assert user_client.init_kwargs["user"] == "root" + assert user_client.init_kwargs["password"] == "Milvus" + + +def test_init_loads_fields_when_collection_exists(milvus_module): + client = milvus_module.MilvusClient(uri="http://localhost:19530") + client.has_collection.return_value = True + client.describe_collection.return_value = { + "fields": [{"name": "id"}, {"name": "content"}, {"name": "metadata"}, {"name": "sparse_vector"}] + } + + with patch.object(milvus_module.MilvusVector, "_init_client", return_value=client): + with patch.object(milvus_module.MilvusVector, "_check_hybrid_search_support", return_value=False): + vector = milvus_module.MilvusVector("collection_1", _config(milvus_module)) + + assert "id" not in vector._fields + assert "content" in vector._fields + + +def test_load_collection_fields_from_argument_and_remote(milvus_module): + vector = milvus_module.MilvusVector.__new__(milvus_module.MilvusVector) + vector._client = MagicMock() + vector._collection_name = "collection_1" + vector._client.describe_collection.return_value = {"fields": [{"name": "id"}, {"name": "content"}]} + + vector._load_collection_fields(["id", "metadata"]) + assert vector._fields == ["metadata"] + + vector._load_collection_fields() + assert vector._fields == ["content"] + + +def test_check_hybrid_search_support_branches(milvus_module): + vector = milvus_module.MilvusVector.__new__(milvus_module.MilvusVector) + vector._client = MagicMock() + + vector._client_config = SimpleNamespace(enable_hybrid_search=False) + assert vector._check_hybrid_search_support() is False + + vector._client_config = SimpleNamespace(enable_hybrid_search=True) + vector._client.get_server_version.return_value = "Zilliz Cloud 2.4" + assert vector._check_hybrid_search_support() is True + + vector._client.get_server_version.return_value = "2.5.1" + assert vector._check_hybrid_search_support() is True + + vector._client.get_server_version.return_value = "2.4.9" + assert vector._check_hybrid_search_support() is False + + vector._client.get_server_version.side_effect = RuntimeError("boom") + assert vector._check_hybrid_search_support() is False + + +def test_get_type_and_create_delegate(milvus_module): + vector = milvus_module.MilvusVector.__new__(milvus_module.MilvusVector) + vector.create_collection = MagicMock() + vector.add_texts = MagicMock() + docs = [SimpleNamespace(page_content="hello", metadata=None)] + + vector.create(docs, [[0.1, 0.2]]) + + assert vector.get_type() == "milvus" + vector.create_collection.assert_called_once() + create_args = vector.create_collection.call_args.args + assert create_args[0] == [[0.1, 0.2]] + assert create_args[1] == [{}] + vector.add_texts.assert_called_once_with(docs, [[0.1, 0.2]]) + + +def test_add_texts_batches_and_raises_milvus_exception(milvus_module): + vector = milvus_module.MilvusVector.__new__(milvus_module.MilvusVector) + vector._collection_name = "collection_1" + vector._client = MagicMock() + vector._client.insert.side_effect = [["id-1"], ["id-2"]] + docs = [Document(page_content=f"text-{i}", metadata={"doc_id": f"d-{i}"}) for i in range(1001)] + embeddings = [[0.1, 0.2] for _ in range(1001)] + + ids = vector.add_texts(docs, embeddings) + assert ids == ["id-1", "id-2"] + assert vector._client.insert.call_count == 2 + + vector._client.insert.side_effect = milvus_module.MilvusException("insert failed") + with pytest.raises(milvus_module.MilvusException): + vector.add_texts([Document(page_content="x", metadata={})], [[0.1]]) + + +def test_get_ids_and_delete_methods(milvus_module): + vector = milvus_module.MilvusVector.__new__(milvus_module.MilvusVector) + vector._collection_name = "collection_1" + vector._client = MagicMock() + vector._client.query.return_value = [{"id": 1}, {"id": 2}] + + assert vector.get_ids_by_metadata_field("document_id", "doc-1") == [1, 2] + vector._client.query.return_value = [] + assert vector.get_ids_by_metadata_field("document_id", "doc-1") is None + + vector._client.has_collection.return_value = True + vector.get_ids_by_metadata_field = MagicMock(return_value=[101, 102]) + vector.delete_by_metadata_field("document_id", "doc-1") + vector._client.delete.assert_called_with(collection_name="collection_1", pks=[101, 102]) + + vector._client.delete.reset_mock() + vector._client.query.return_value = [{"id": 11}, {"id": 12}] + vector.delete_by_ids(["doc-a", "doc-b"]) + vector._client.delete.assert_called_with(collection_name="collection_1", pks=[11, 12]) + + vector._client.has_collection.return_value = True + vector.delete() + vector._client.drop_collection.assert_called_once_with("collection_1", None) + + +def test_text_exists_and_field_exists(milvus_module): + vector = milvus_module.MilvusVector.__new__(milvus_module.MilvusVector) + vector._collection_name = "collection_1" + vector._fields = ["content", "metadata"] + vector._client = MagicMock() + vector._client.has_collection.return_value = False + assert vector.text_exists("doc-1") is False + + vector._client.has_collection.return_value = True + vector._client.query.return_value = [{"id": 1}] + assert vector.text_exists("doc-1") is True + vector._client.query.return_value = [] + assert vector.text_exists("doc-1") is False + assert vector.field_exists("content") is True + assert vector.field_exists("unknown") is False + + +def test_process_search_results_and_search_methods(milvus_module): + vector = milvus_module.MilvusVector.__new__(milvus_module.MilvusVector) + vector._collection_name = "collection_1" + vector._client = MagicMock() + vector._fields = ["content", "metadata", "sparse_vector"] + + processed = vector._process_search_results( + [ + [ + {"entity": {"content": "doc-1", "metadata": {"doc_id": "1"}}, "distance": 0.9}, + {"entity": {"content": "doc-2", "metadata": {"doc_id": "2"}}, "distance": 0.2}, + ] + ], + [milvus_module.Field.CONTENT_KEY, milvus_module.Field.METADATA_KEY], + score_threshold=0.5, + ) + assert len(processed) == 1 + assert processed[0].metadata["score"] == 0.9 + + vector._client.search.return_value = [[{"entity": {"content": "doc"}, "distance": 0.8}]] + vector._process_search_results = MagicMock(return_value=["doc"]) + + docs = vector.search_by_vector([0.1, 0.2], top_k=3, document_ids_filter=["a", "b"], score_threshold=0.1) + assert docs == ["doc"] + assert vector._client.search.call_args.kwargs["filter"] == 'metadata["document_id"] in ["a", "b"]' + + vector._hybrid_search_enabled = False + assert vector.search_by_full_text("query") == [] + + vector._hybrid_search_enabled = True + vector._fields = [] + assert vector.search_by_full_text("query") == [] + + vector._fields = [milvus_module.Field.SPARSE_VECTOR] + vector._process_search_results = MagicMock(return_value=["full-text-doc"]) + full_text_docs = vector.search_by_full_text("query", top_k=2, document_ids_filter=["d-1"], score_threshold=0.2) + assert full_text_docs == ["full-text-doc"] + assert "document_id" in vector._client.search.call_args.kwargs["filter"] + + +def test_create_collection_cache_and_existing_collection(milvus_module, monkeypatch): + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(milvus_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(milvus_module.redis_client, "set", MagicMock()) + + vector = milvus_module.MilvusVector.__new__(milvus_module.MilvusVector) + vector._collection_name = "collection_1" + vector._consistency_level = "Session" + vector._client_config = _config(milvus_module) + vector._hybrid_search_enabled = False + vector._client = MagicMock() + + monkeypatch.setattr(milvus_module.redis_client, "get", MagicMock(return_value=1)) + vector.create_collection([[0.1, 0.2]], metadatas=[{"doc_id": "1"}], index_params={"index_type": "HNSW"}) + vector._client.create_collection.assert_not_called() + + monkeypatch.setattr(milvus_module.redis_client, "get", MagicMock(return_value=None)) + vector._client.has_collection.return_value = True + vector.create_collection([[0.1, 0.2]], metadatas=[{"doc_id": "1"}], index_params={"index_type": "HNSW"}) + milvus_module.redis_client.set.assert_called() + + +def test_create_collection_builds_schema_and_indexes(milvus_module, monkeypatch): + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(milvus_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(milvus_module.redis_client, "get", MagicMock(return_value=None)) + monkeypatch.setattr(milvus_module.redis_client, "set", MagicMock()) + + vector = milvus_module.MilvusVector.__new__(milvus_module.MilvusVector) + vector._collection_name = "collection_1" + vector._consistency_level = "Session" + vector._client = MagicMock() + vector._client.has_collection.return_value = False + vector._load_collection_fields = MagicMock() + + vector._client_config = _config(milvus_module, analyzer_params='{"tokenizer":"standard"}') + vector._hybrid_search_enabled = True + vector.create_collection( + embeddings=[[0.1, 0.2]], + metadatas=[{"doc_id": "1"}], + index_params={"metric_type": "IP", "index_type": "HNSW", "params": {"M": 8}}, + ) + + call_kwargs = vector._client.create_collection.call_args.kwargs + schema = call_kwargs["schema"] + index_params_obj = call_kwargs["index_params"] + field_names = [f.name for f in schema.fields] + + assert milvus_module.Field.SPARSE_VECTOR in field_names + assert len(schema.functions) == 1 + assert len(index_params_obj.indexes) == 2 + assert call_kwargs["consistency_level"] == "Session" + + +def test_factory_initializes_milvus_vector(milvus_module, monkeypatch): + factory = milvus_module.MilvusVectorFactory() + dataset_with_index = SimpleNamespace( + id="dataset-1", + index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}}, + index_struct=None, + ) + dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None) + + monkeypatch.setattr(milvus_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION") + monkeypatch.setattr(milvus_module.dify_config, "MILVUS_URI", "http://localhost:19530") + monkeypatch.setattr(milvus_module.dify_config, "MILVUS_TOKEN", "") + monkeypatch.setattr(milvus_module.dify_config, "MILVUS_USER", "root") + monkeypatch.setattr(milvus_module.dify_config, "MILVUS_PASSWORD", "Milvus") + monkeypatch.setattr(milvus_module.dify_config, "MILVUS_DATABASE", "default") + monkeypatch.setattr(milvus_module.dify_config, "MILVUS_ENABLE_HYBRID_SEARCH", True) + monkeypatch.setattr(milvus_module.dify_config, "MILVUS_ANALYZER_PARAMS", '{"tokenizer":"standard"}') + + with patch.object(milvus_module, "MilvusVector", return_value="vector") as vector_cls: + result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock()) + result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock()) + + assert result_1 == "vector" + assert result_2 == "vector" + assert vector_cls.call_args_list[0].kwargs["collection_name"] == "EXISTING_COLLECTION" + assert vector_cls.call_args_list[1].kwargs["collection_name"] == "AUTO_COLLECTION" + assert dataset_without_index.index_struct is not None diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/myscale/test_myscale_vector.py b/api/tests/unit_tests/core/rag/datasource/vdb/myscale/test_myscale_vector.py new file mode 100644 index 0000000000..a75ba82238 --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/myscale/test_myscale_vector.py @@ -0,0 +1,230 @@ +import importlib +import sys +import types +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest + +from core.rag.models.document import Document + + +def _build_fake_clickhouse_connect_module(): + clickhouse_connect = types.ModuleType("clickhouse_connect") + + class QueryResult: + def __init__(self, rows=None, named_rows=None): + self.row_count = len(rows or []) + self.result_rows = rows or [] + self._named_rows = named_rows or [] + + def named_results(self): + return self._named_rows + + class Client: + def __init__(self): + self.command = MagicMock() + self.query = MagicMock(return_value=QueryResult()) + + client = Client() + + def get_client(**_kwargs): + return client + + clickhouse_connect.get_client = get_client + clickhouse_connect.QueryResult = QueryResult + clickhouse_connect._fake_client = client + return clickhouse_connect + + +@pytest.fixture +def myscale_module(monkeypatch): + fake_module = _build_fake_clickhouse_connect_module() + monkeypatch.setitem(sys.modules, "clickhouse_connect", fake_module) + + import core.rag.datasource.vdb.myscale.myscale_vector as module + + return importlib.reload(module) + + +def _config(module): + return module.MyScaleConfig( + host="localhost", + port=8123, + user="default", + password="", + database="dify", + fts_params="", + ) + + +def test_escape_str_replaces_backslash_and_quote(myscale_module): + escaped = myscale_module.MyScaleVector.escape_str(r"text\with'special") + assert escaped == "text with special" + + +def test_search_raises_for_invalid_top_k(myscale_module): + vector = myscale_module.MyScaleVector("collection_1", _config(myscale_module)) + with pytest.raises(ValueError, match="top_k must be a positive integer"): + vector._search("distance(vector, [0.1, 0.2])", myscale_module.SortOrder.ASC, top_k=0) + + +def test_search_builds_where_clause_for_cosine_threshold(myscale_module): + vector = myscale_module.MyScaleVector("collection_1", _config(myscale_module)) + vector._client.query.return_value = myscale_module.get_client().query.return_value.__class__( + named_rows=[{"text": "doc-1", "vector": [0.1, 0.2], "metadata": {"doc_id": "seg-1"}}] + ) + + docs = vector._search("distance(vector, [0.1, 0.2])", myscale_module.SortOrder.ASC, top_k=1, score_threshold=0.2) + + assert len(docs) == 1 + sql = vector._client.query.call_args.args[0] + assert "WHERE dist < 0.8" in sql + + +def test_delete_by_ids_short_circuits_on_empty_list(myscale_module): + vector = myscale_module.MyScaleVector("collection_1", _config(myscale_module)) + vector._client.command.reset_mock() + + vector.delete_by_ids([]) + vector._client.command.assert_not_called() + + +def test_factory_initializes_lower_case_collection_name(myscale_module, monkeypatch): + factory = myscale_module.MyScaleVectorFactory() + dataset_with_index = SimpleNamespace( + id="dataset-1", + index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}}, + index_struct=None, + ) + dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None) + + monkeypatch.setattr(myscale_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION") + monkeypatch.setattr(myscale_module.dify_config, "MYSCALE_HOST", "localhost") + monkeypatch.setattr(myscale_module.dify_config, "MYSCALE_PORT", 8123) + monkeypatch.setattr(myscale_module.dify_config, "MYSCALE_USER", "default") + monkeypatch.setattr(myscale_module.dify_config, "MYSCALE_PASSWORD", "") + monkeypatch.setattr(myscale_module.dify_config, "MYSCALE_DATABASE", "dify") + monkeypatch.setattr(myscale_module.dify_config, "MYSCALE_FTS_PARAMS", "") + + with patch.object(myscale_module, "MyScaleVector", return_value="vector") as vector_cls: + result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock()) + result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock()) + + assert result_1 == "vector" + assert result_2 == "vector" + assert vector_cls.call_args_list[0].kwargs["collection_name"] == "existing_collection" + assert vector_cls.call_args_list[1].kwargs["collection_name"] == "auto_collection" + assert dataset_without_index.index_struct is not None + + +def test_init_and_get_type_set_expected_defaults(myscale_module): + vector = myscale_module.MyScaleVector("collection_1", _config(myscale_module)) + + assert vector.get_type() == "myscale" + assert vector._vec_order == myscale_module.SortOrder.ASC + vector._client.command.assert_called_with("SET allow_experimental_object_type=1") + + +def test_create_calls_create_collection_and_add_texts(myscale_module): + vector = myscale_module.MyScaleVector("collection_1", _config(myscale_module)) + vector._create_collection = MagicMock() + vector.add_texts = MagicMock(return_value=["seg-1"]) + docs = [Document(page_content="hello", metadata={"doc_id": "seg-1"})] + + result = vector.create(docs, [[0.1, 0.2]]) + + assert result == ["seg-1"] + vector._create_collection.assert_called_once_with(2) + vector.add_texts.assert_called_once() + + +def test_create_collection_builds_expected_sql(myscale_module): + config = myscale_module.MyScaleConfig( + host="localhost", + port=8123, + user="default", + password="", + database="dify", + fts_params="tokenizer=unicode", + ) + vector = myscale_module.MyScaleVector("collection_1", config) + vector._client.command.reset_mock() + + vector._create_collection(3) + + assert vector._client.command.call_count == 2 + sql = vector._client.command.call_args_list[1].args[0] + assert "CREATE TABLE IF NOT EXISTS dify.collection_1" in sql + assert "CONSTRAINT cons_vec_len CHECK length(vector) = 3" in sql + assert "INDEX text_idx text TYPE fts('tokenizer=unicode')" in sql + + +def test_add_texts_inserts_rows_and_returns_ids(myscale_module, monkeypatch): + vector = myscale_module.MyScaleVector("collection_1", _config(myscale_module)) + monkeypatch.setattr(myscale_module.uuid, "uuid4", lambda: "generated-uuid") + docs = [ + Document(page_content=r"te'xt\1", metadata={"doc_id": "doc-a", "document_id": "d-1"}), + Document(page_content="text-2", metadata={"document_id": "d-2"}), + SimpleNamespace(page_content="text-3", metadata=None), + ] + + ids = vector.add_texts(docs, [[0.1], [0.2], [0.3]]) + + assert ids == ["doc-a", "generated-uuid"] + sql = vector._client.command.call_args.args[0] + assert "INSERT INTO dify.collection_1" in sql + assert "te xt 1" in sql + + +def test_text_exists_and_metadata_operations(myscale_module): + vector = myscale_module.MyScaleVector("collection_1", _config(myscale_module)) + vector._client.query.return_value = SimpleNamespace(row_count=1, result_rows=[("id-1",), ("id-2",)]) + + assert vector.text_exists("id-1") is True + assert vector.get_ids_by_metadata_field("document_id", "doc-1") == ["id-1", "id-2"] + + vector.delete_by_ids(["id-1", "id-2"]) + vector.delete_by_metadata_field("document_id", "doc-1") + assert vector._client.command.call_count >= 2 + + +def test_search_delegation_methods(myscale_module): + vector = myscale_module.MyScaleVector("collection_1", _config(myscale_module)) + vector._search = MagicMock(return_value=["result"]) + + result_vector = vector.search_by_vector([0.1, 0.2], top_k=2) + result_text = vector.search_by_full_text("hello", top_k=2) + + assert result_vector == ["result"] + assert result_text == ["result"] + assert vector._search.call_count == 2 + + +def test_search_with_document_filter_and_exception(myscale_module): + vector = myscale_module.MyScaleVector("collection_1", _config(myscale_module)) + vector._client.query.return_value = SimpleNamespace( + named_results=lambda: [{"text": "doc", "vector": [0.1], "metadata": {"doc_id": "1"}}] + ) + + docs = vector._search( + "distance(vector, [0.1])", + myscale_module.SortOrder.ASC, + top_k=2, + document_ids_filter=["doc-1", "doc-2"], + ) + assert len(docs) == 1 + sql = vector._client.query.call_args.args[0] + assert "metadata['document_id'] in ('doc-1', 'doc-2')" in sql + + vector._client.query.side_effect = RuntimeError("boom") + assert vector._search("distance(vector, [0.1])", myscale_module.SortOrder.ASC, top_k=1) == [] + + +def test_delete_drops_table(myscale_module): + vector = myscale_module.MyScaleVector("collection_1", _config(myscale_module)) + vector._client.command.reset_mock() + + vector.delete() + + vector._client.command.assert_called_once_with("DROP TABLE IF EXISTS dify.collection_1") diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/oceanbase/test_oceanbase_vector.py b/api/tests/unit_tests/core/rag/datasource/vdb/oceanbase/test_oceanbase_vector.py new file mode 100644 index 0000000000..27d8198ec0 --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/oceanbase/test_oceanbase_vector.py @@ -0,0 +1,553 @@ +import importlib +import sys +import types +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from pydantic import ValidationError +from sqlalchemy.exc import SQLAlchemyError + +from core.rag.models.document import Document + + +def _build_fake_pyobvector_module(): + pyobvector = types.ModuleType("pyobvector") + + class VECTOR: + def __init__(self, dim): + self.dim = dim + + def l2_distance(*_args, **_kwargs): + return "l2" + + def cosine_distance(*_args, **_kwargs): + return "cosine" + + def inner_product(*_args, **_kwargs): + return "inner_product" + + class ObVecClient: + def __init__(self, **_kwargs): + self.metadata_obj = SimpleNamespace(tables={}) + self.engine = MagicMock() + self.check_table_exists = MagicMock(return_value=False) + self.perform_raw_text_sql = MagicMock() + self.prepare_index_params = MagicMock() + self.create_table_with_index_params = MagicMock() + self.refresh_metadata = MagicMock() + self.insert = MagicMock() + self.refresh_index = MagicMock() + self.get = MagicMock() + self.delete = MagicMock() + self.set_ob_hnsw_ef_search = MagicMock() + self.ann_search = MagicMock(return_value=[]) + self.drop_table_if_exist = MagicMock() + + pyobvector.VECTOR = VECTOR + pyobvector.ObVecClient = ObVecClient + pyobvector.l2_distance = l2_distance + pyobvector.cosine_distance = cosine_distance + pyobvector.inner_product = inner_product + return pyobvector + + +@pytest.fixture +def oceanbase_module(monkeypatch): + monkeypatch.setitem(sys.modules, "pyobvector", _build_fake_pyobvector_module()) + + import core.rag.datasource.vdb.oceanbase.oceanbase_vector as module + + return importlib.reload(module) + + +def _config(module): + return module.OceanBaseVectorConfig( + host="127.0.0.1", + port=2881, + user="root", + password="secret", + database="test", + enable_hybrid_search=True, + batch_size=10, + ) + + +@pytest.mark.parametrize( + ("field", "value", "message"), + [ + ("host", "", "config OCEANBASE_VECTOR_HOST is required"), + ("port", 0, "config OCEANBASE_VECTOR_PORT is required"), + ("user", "", "config OCEANBASE_VECTOR_USER is required"), + ("database", "", "config OCEANBASE_VECTOR_DATABASE is required"), + ], +) +def test_oceanbase_config_validation(oceanbase_module, field, value, message): + values = _config(oceanbase_module).model_dump() + values[field] = value + + with pytest.raises(ValidationError, match=message): + oceanbase_module.OceanBaseVectorConfig.model_validate(values) + + +def test_init_rejects_invalid_collection_name(oceanbase_module): + with pytest.raises(ValueError, match="Invalid collection name"): + oceanbase_module.OceanBaseVector("invalid-name", _config(oceanbase_module)) + + +def test_distance_to_score_for_supported_metrics(oceanbase_module): + vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector) + vector._config = SimpleNamespace(metric_type="l2") + assert vector._distance_to_score(3.0) == pytest.approx(0.25) + + vector._config = SimpleNamespace(metric_type="cosine") + assert vector._distance_to_score(0.2) == pytest.approx(0.8) + + vector._config = SimpleNamespace(metric_type="inner_product") + assert vector._distance_to_score(-0.2) == pytest.approx(0.2) + + +def test_get_distance_func_raises_for_unknown_metric(oceanbase_module): + vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector) + vector._config = SimpleNamespace(metric_type="manhattan") + + with pytest.raises(ValueError, match="Unsupported metric_type"): + vector._get_distance_func() + + +def test_process_search_results_handles_json_and_score_threshold(oceanbase_module): + vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector) + rows = [ + ("doc-1", '{"doc_id":"1"}', 0.9), + ("doc-2", "not-json", 0.8), + ("doc-3", {"doc_id": "3"}, 0.3), + ] + + docs = vector._process_search_results(rows, score_threshold=0.5, score_key="rank") + + assert len(docs) == 2 + assert docs[0].metadata["doc_id"] == "1" + assert docs[0].metadata["rank"] == 0.9 + assert docs[1].metadata["rank"] == 0.8 + + +def test_search_by_vector_validates_document_id_format(oceanbase_module): + vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector) + vector._collection_name = "collection_1" + vector._hnsw_ef_search = -1 + vector._config = SimpleNamespace(metric_type="cosine") + vector._client = MagicMock() + + with pytest.raises(ValueError, match="Invalid document ID format"): + vector.search_by_vector([0.1, 0.2], document_ids_filter=["bad id"]) + + +def test_search_by_full_text_returns_empty_when_disabled(oceanbase_module): + vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector) + vector._hybrid_search_enabled = False + vector._collection_name = "collection_1" + + assert vector.search_by_full_text("query") == [] + + +def test_check_hybrid_search_support_uses_version_comment(oceanbase_module): + vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector) + vector._config = SimpleNamespace(enable_hybrid_search=True) + vector._client = MagicMock() + cursor = MagicMock() + cursor.fetchone.return_value = ("OceanBase_CE 4.3.5.1 (rxxxxxxxxx) (Built Mar 18 2025)",) + vector._client.perform_raw_text_sql.return_value = cursor + + assert vector._check_hybrid_search_support() is True + + cursor.fetchone.return_value = ("OceanBase_CE 4.3.4.0 (rxxxxxxxxx) (Built Mar 18 2025)",) + assert vector._check_hybrid_search_support() is False + + +def test_init_get_type_and_field_loading(oceanbase_module): + config = _config(oceanbase_module) + config.enable_hybrid_search = False + + table = SimpleNamespace(columns=[SimpleNamespace(name="id"), SimpleNamespace(name="text")]) + fake_client = oceanbase_module.ObVecClient() + fake_client.check_table_exists.return_value = True + fake_client.metadata_obj.tables = {"collection_1": table} + + with patch.object(oceanbase_module, "ObVecClient", return_value=fake_client): + vector = oceanbase_module.OceanBaseVector("collection_1", config) + + assert vector.get_type() == "oceanbase" + assert vector.field_exists("text") is True + + +def test_load_collection_fields_handles_missing_table_and_exception(oceanbase_module): + vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector) + vector._collection_name = "collection_1" + vector._fields = [] + vector._client = MagicMock() + vector._client.metadata_obj.tables = {} + + vector._load_collection_fields() + assert vector._fields == [] + + vector._client.metadata_obj.tables = {"collection_1": MagicMock(columns=MagicMock(side_effect=RuntimeError("x")))} + vector._load_collection_fields() + assert vector._fields == [] + + +def test_create_delegates_to_collection_and_insert(oceanbase_module): + vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector) + vector._create_collection = MagicMock() + vector.add_texts = MagicMock() + docs = [Document(page_content="text", metadata={"doc_id": "1"})] + + vector.create(docs, [[0.1, 0.2]]) + + assert vector._vec_dim == 2 + vector._create_collection.assert_called_once() + vector.add_texts.assert_called_once_with(docs, [[0.1, 0.2]]) + + +def test_create_collection_cache_and_existing_table_short_circuits(oceanbase_module, monkeypatch): + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(oceanbase_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(oceanbase_module.redis_client, "set", MagicMock()) + + vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector) + vector._collection_name = "collection_1" + vector._vec_dim = 2 + vector._hybrid_search_enabled = False + vector._config = SimpleNamespace(metric_type="cosine", hnsw_m=16, hnsw_ef_construction=64) + vector._client = MagicMock() + vector.delete = MagicMock() + vector._load_collection_fields = MagicMock() + + monkeypatch.setattr(oceanbase_module.redis_client, "get", MagicMock(return_value=1)) + vector._create_collection() + vector._client.check_table_exists.assert_not_called() + + monkeypatch.setattr(oceanbase_module.redis_client, "get", MagicMock(return_value=None)) + vector._client.check_table_exists.return_value = True + vector._create_collection() + vector.delete.assert_not_called() + + +def test_create_collection_happy_path_with_hybrid_and_index(oceanbase_module, monkeypatch): + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(oceanbase_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(oceanbase_module.redis_client, "get", MagicMock(return_value=None)) + monkeypatch.setattr(oceanbase_module.redis_client, "set", MagicMock()) + monkeypatch.setattr(oceanbase_module.dify_config, "OCEANBASE_FULLTEXT_PARSER", "ik") + monkeypatch.setattr(oceanbase_module, "Column", lambda *args, **kwargs: SimpleNamespace(args=args, kwargs=kwargs)) + monkeypatch.setattr(oceanbase_module, "VECTOR", lambda dim: SimpleNamespace(dim=dim)) + + vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector) + vector._collection_name = "collection_1" + vector._vec_dim = 3 + vector._hybrid_search_enabled = True + vector._config = SimpleNamespace(metric_type="cosine", hnsw_m=16, hnsw_ef_construction=64) + vector._client = MagicMock() + vector._client.check_table_exists.return_value = False + vector._client.perform_raw_text_sql.side_effect = [ + [[None, None, None, None, None, None, "30"]], + None, + None, + ] + index_params = MagicMock() + vector._client.prepare_index_params.return_value = index_params + vector.delete = MagicMock() + vector._load_collection_fields = MagicMock() + + vector._create_collection() + + vector.delete.assert_called_once() + vector._client.create_table_with_index_params.assert_called_once() + index_params.add_index.assert_called_once() + vector._client.refresh_metadata.assert_called_once_with(["collection_1"]) + oceanbase_module.redis_client.set.assert_called_once() + + +def test_create_collection_error_paths(oceanbase_module, monkeypatch): + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(oceanbase_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(oceanbase_module.redis_client, "get", MagicMock(return_value=None)) + monkeypatch.setattr(oceanbase_module, "Column", lambda *args, **kwargs: SimpleNamespace(args=args, kwargs=kwargs)) + monkeypatch.setattr(oceanbase_module, "VECTOR", lambda dim: SimpleNamespace(dim=dim)) + + vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector) + vector._collection_name = "collection_1" + vector._vec_dim = 2 + vector._hybrid_search_enabled = True + vector._config = SimpleNamespace(metric_type="cosine", hnsw_m=16, hnsw_ef_construction=64) + vector._client = MagicMock() + vector._client.check_table_exists.return_value = False + vector._client.prepare_index_params.return_value = MagicMock() + vector.delete = MagicMock() + vector._load_collection_fields = MagicMock() + + vector._client.perform_raw_text_sql.return_value = [] + with pytest.raises(ValueError, match="ob_vector_memory_limit_percentage not found"): + vector._create_collection() + + vector._client.perform_raw_text_sql.side_effect = [ + [[None, None, None, None, None, None, "0"]], + RuntimeError("no privilege"), + ] + with pytest.raises(Exception, match="Failed to set ob_vector_memory_limit_percentage"): + vector._create_collection() + + vector._client.perform_raw_text_sql.side_effect = [[[None, None, None, None, None, None, "30"]]] + monkeypatch.setattr(oceanbase_module.dify_config, "OCEANBASE_FULLTEXT_PARSER", "not-valid") + with pytest.raises(ValueError, match="Invalid OceanBase full-text parser"): + vector._create_collection() + + +def test_create_collection_fulltext_and_metadata_index_exceptions(oceanbase_module, monkeypatch): + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(oceanbase_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(oceanbase_module.redis_client, "get", MagicMock(return_value=None)) + monkeypatch.setattr(oceanbase_module.redis_client, "set", MagicMock()) + monkeypatch.setattr(oceanbase_module.dify_config, "OCEANBASE_FULLTEXT_PARSER", "ik") + monkeypatch.setattr(oceanbase_module, "Column", lambda *args, **kwargs: SimpleNamespace(args=args, kwargs=kwargs)) + monkeypatch.setattr(oceanbase_module, "VECTOR", lambda dim: SimpleNamespace(dim=dim)) + + vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector) + vector._collection_name = "collection_1" + vector._vec_dim = 2 + vector._hybrid_search_enabled = True + vector._config = SimpleNamespace(metric_type="cosine", hnsw_m=16, hnsw_ef_construction=64) + vector._client = MagicMock() + vector._client.check_table_exists.return_value = False + vector._client.prepare_index_params.return_value = MagicMock() + vector.delete = MagicMock() + vector._load_collection_fields = MagicMock() + + vector._client.perform_raw_text_sql.side_effect = [ + [[None, None, None, None, None, None, "30"]], + RuntimeError("fulltext failed"), + ] + with pytest.raises(Exception, match="Failed to add fulltext index"): + vector._create_collection() + + vector._hybrid_search_enabled = False + vector._client.perform_raw_text_sql.side_effect = [ + [[None, None, None, None, None, None, "30"]], + SQLAlchemyError("metadata index failed"), + ] + vector._create_collection() + vector._client.refresh_metadata.assert_called_once_with(["collection_1"]) + + +def test_check_hybrid_search_support_false_and_exception(oceanbase_module): + vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector) + vector._config = SimpleNamespace(enable_hybrid_search=False) + vector._client = MagicMock() + assert vector._check_hybrid_search_support() is False + + vector._config = SimpleNamespace(enable_hybrid_search=True) + vector._client.perform_raw_text_sql.side_effect = RuntimeError("boom") + assert vector._check_hybrid_search_support() is False + + +def test_add_texts_batches_refresh_and_exceptions(oceanbase_module): + vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector) + vector._collection_name = "collection_1" + vector._config = SimpleNamespace(batch_size=2, hnsw_refresh_threshold=2) + vector._client = MagicMock() + vector._get_uuids = MagicMock(return_value=["id-1", "id-2", "id-3"]) + docs = [ + Document(page_content="a", metadata={"doc_id": "id-1"}), + Document(page_content="b", metadata={"doc_id": "id-2"}), + Document(page_content="c", metadata={"doc_id": "id-3"}), + ] + + vector.add_texts(docs, [[0.1], [0.2], [0.3]]) + assert vector._client.insert.call_count == 2 + vector._client.refresh_index.assert_called_once() + + vector._client.insert.reset_mock() + vector._client.refresh_index.reset_mock() + vector._client.insert.side_effect = RuntimeError("insert failed") + with pytest.raises(Exception, match="Failed to insert batch"): + vector.add_texts([docs[0]], [[0.1]]) + + vector._client.insert.side_effect = None + vector._client.insert.return_value = None + vector._client.refresh_index.side_effect = SQLAlchemyError("refresh failed") + vector._config = SimpleNamespace(batch_size=10, hnsw_refresh_threshold=1) + vector._get_uuids.return_value = ["id-1"] + vector.add_texts([docs[0]], [[0.1]]) + + +def test_text_exists_and_delete_by_ids(oceanbase_module): + vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector) + vector._collection_name = "collection_1" + vector._client = MagicMock() + vector._client.get.return_value = SimpleNamespace(rowcount=1) + assert vector.text_exists("id-1") is True + + vector._client.get.side_effect = RuntimeError("boom") + with pytest.raises(Exception, match="Failed to check text existence"): + vector.text_exists("id-1") + + vector.delete_by_ids([]) + vector._client.delete.assert_not_called() + + vector._client.delete.side_effect = None + vector.delete_by_ids(["id-1"]) + vector._client.delete.assert_called_once() + + vector._client.delete.side_effect = RuntimeError("boom") + with pytest.raises(Exception, match="Failed to delete documents"): + vector.delete_by_ids(["id-1"]) + + +def test_get_ids_and_delete_by_metadata_field(oceanbase_module): + vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector) + vector._collection_name = "collection_1" + vector._client = MagicMock() + execute_result = [("id-1",), ("id-2",)] + + conn = MagicMock() + conn.__enter__.return_value = conn + conn.__exit__.return_value = None + conn.execute.return_value = execute_result + vector._client.engine.connect.return_value = conn + + ids = vector.get_ids_by_metadata_field("document_id", "doc-1") + assert ids == ["id-1", "id-2"] + + with pytest.raises(Exception, match="Failed to query documents by metadata field"): + vector.get_ids_by_metadata_field("bad key!", "doc-1") + + vector.get_ids_by_metadata_field = MagicMock(return_value=["id-1"]) + vector.delete_by_ids = MagicMock() + vector.delete_by_metadata_field("document_id", "doc-1") + vector.delete_by_ids.assert_called_once_with(["id-1"]) + + vector.get_ids_by_metadata_field = MagicMock(return_value=[]) + vector.delete_by_ids.reset_mock() + vector.delete_by_metadata_field("document_id", "doc-1") + vector.delete_by_ids.assert_not_called() + + +def test_search_by_full_text_paths(oceanbase_module): + vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector) + vector._collection_name = "collection_1" + vector._hybrid_search_enabled = True + vector.field_exists = MagicMock(return_value=False) + + assert vector.search_by_full_text("query") == [] + + vector.field_exists.return_value = True + vector._client = MagicMock() + conn = MagicMock() + tx = MagicMock() + tx.__enter__.return_value = tx + tx.__exit__.return_value = None + conn.begin.return_value = tx + conn.__enter__.return_value = conn + conn.__exit__.return_value = None + conn.execute.return_value.fetchall.return_value = [("text-1", '{"doc_id":"1"}', 0.9)] + vector._client.engine.connect.return_value = conn + + docs = vector.search_by_full_text("query", top_k=2, document_ids_filter=["d-1"], score_threshold=0.5) + assert len(docs) == 1 + assert docs[0].metadata["score"] == 0.9 + + with pytest.raises(Exception, match="Full-text search failed"): + vector.search_by_full_text("query", top_k=0) + + +def test_search_by_vector_paths(oceanbase_module): + vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector) + vector._collection_name = "collection_1" + vector._hnsw_ef_search = -1 + vector._config = SimpleNamespace(metric_type="cosine") + vector._client = MagicMock() + vector._client.ann_search.return_value = [("doc-1", '{"doc_id":"1"}', 0.2)] + vector._process_search_results = MagicMock(return_value=["doc"]) + + docs = vector.search_by_vector( + [0.1, 0.2], + ef_search=10, + top_k=3, + score_threshold=0.1, + document_ids_filter=["good_id"], + ) + assert docs == ["doc"] + vector._client.set_ob_hnsw_ef_search.assert_called_once_with(10) + + with pytest.raises(ValueError, match="Invalid score_threshold parameter"): + vector.search_by_vector([0.1], score_threshold="x") + + vector._client.ann_search.side_effect = RuntimeError("boom") + with pytest.raises(Exception, match="Vector search failed"): + vector.search_by_vector([0.1], score_threshold=0.1) + + +def test_get_distance_func_and_distance_to_score_errors(oceanbase_module): + vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector) + vector._config = SimpleNamespace(metric_type="cosine") + assert vector._get_distance_func() is oceanbase_module.cosine_distance + + vector._config = SimpleNamespace(metric_type="unknown") + with pytest.raises(ValueError, match="Unsupported metric_type"): + vector._distance_to_score(0.1) + + +def test_delete_success_and_exception(oceanbase_module): + vector = oceanbase_module.OceanBaseVector.__new__(oceanbase_module.OceanBaseVector) + vector._collection_name = "collection_1" + vector._client = MagicMock() + + vector.delete() + vector._client.drop_table_if_exist.assert_called_once_with("collection_1") + + vector._client.drop_table_if_exist.side_effect = RuntimeError("boom") + with pytest.raises(Exception, match="Failed to delete collection"): + vector.delete() + + +def test_oceanbase_factory_uses_existing_or_generated_collection(oceanbase_module, monkeypatch): + factory = oceanbase_module.OceanBaseVectorFactory() + dataset_with_index = SimpleNamespace( + id="dataset-1", + index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}}, + index_struct=None, + ) + dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None) + + monkeypatch.setattr(oceanbase_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION") + monkeypatch.setattr(oceanbase_module.dify_config, "OCEANBASE_VECTOR_HOST", "127.0.0.1") + monkeypatch.setattr(oceanbase_module.dify_config, "OCEANBASE_VECTOR_PORT", 2881) + monkeypatch.setattr(oceanbase_module.dify_config, "OCEANBASE_VECTOR_USER", "root") + monkeypatch.setattr(oceanbase_module.dify_config, "OCEANBASE_VECTOR_PASSWORD", "password") + monkeypatch.setattr(oceanbase_module.dify_config, "OCEANBASE_VECTOR_DATABASE", "test") + monkeypatch.setattr(oceanbase_module.dify_config, "OCEANBASE_ENABLE_HYBRID_SEARCH", True) + monkeypatch.setattr(oceanbase_module.dify_config, "OCEANBASE_VECTOR_BATCH_SIZE", 10) + monkeypatch.setattr(oceanbase_module.dify_config, "OCEANBASE_VECTOR_METRIC_TYPE", "cosine") + monkeypatch.setattr(oceanbase_module.dify_config, "OCEANBASE_HNSW_M", 16) + monkeypatch.setattr(oceanbase_module.dify_config, "OCEANBASE_HNSW_EF_CONSTRUCTION", 64) + monkeypatch.setattr(oceanbase_module.dify_config, "OCEANBASE_HNSW_EF_SEARCH", -1) + monkeypatch.setattr(oceanbase_module.dify_config, "OCEANBASE_VECTOR_POOL_SIZE", 5) + monkeypatch.setattr(oceanbase_module.dify_config, "OCEANBASE_VECTOR_MAX_OVERFLOW", 10) + monkeypatch.setattr(oceanbase_module.dify_config, "OCEANBASE_HNSW_REFRESH_THRESHOLD", 1000) + + with patch.object(oceanbase_module, "OceanBaseVector", return_value="vector") as vector_cls: + result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock()) + result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock()) + + assert result_1 == "vector" + assert result_2 == "vector" + assert vector_cls.call_args_list[0].args[0] == "existing_collection" + assert vector_cls.call_args_list[1].args[0] == "auto_collection" + assert dataset_without_index.index_struct is not None diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/opengauss/test_opengauss.py b/api/tests/unit_tests/core/rag/datasource/vdb/opengauss/test_opengauss.py new file mode 100644 index 0000000000..6641dbe4a0 --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/opengauss/test_opengauss.py @@ -0,0 +1,400 @@ +import importlib +import sys +import types +from contextlib import contextmanager +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from pydantic import ValidationError + +from core.rag.models.document import Document + + +def _build_fake_psycopg2_modules(): + psycopg2 = types.ModuleType("psycopg2") + psycopg2.__path__ = [] + psycopg2_extras = types.ModuleType("psycopg2.extras") + psycopg2_pool = types.ModuleType("psycopg2.pool") + + class SimpleConnectionPool: + def __init__(self, *args, **kwargs): + self.args = args + self.kwargs = kwargs + self.getconn = MagicMock() + self.putconn = MagicMock() + + psycopg2_pool.SimpleConnectionPool = SimpleConnectionPool + psycopg2_extras.execute_values = MagicMock() + + psycopg2.pool = psycopg2_pool + psycopg2.extras = psycopg2_extras + return { + "psycopg2": psycopg2, + "psycopg2.pool": psycopg2_pool, + "psycopg2.extras": psycopg2_extras, + } + + +@pytest.fixture +def opengauss_module(monkeypatch): + for name, module in _build_fake_psycopg2_modules().items(): + monkeypatch.setitem(sys.modules, name, module) + + import core.rag.datasource.vdb.opengauss.opengauss as module + + return importlib.reload(module) + + +def _config(module, *, enable_pq=False): + return module.OpenGaussConfig( + host="localhost", + port=6600, + user="postgres", + password="password", + database="dify", + min_connection=1, + max_connection=5, + enable_pq=enable_pq, + ) + + +@pytest.mark.parametrize( + ("field", "value", "message"), + [ + ("host", "", "config OPENGAUSS_HOST is required"), + ("port", 0, "config OPENGAUSS_PORT is required"), + ("user", "", "config OPENGAUSS_USER is required"), + ("password", "", "config OPENGAUSS_PASSWORD is required"), + ("database", "", "config OPENGAUSS_DATABASE is required"), + ("min_connection", 0, "config OPENGAUSS_MIN_CONNECTION is required"), + ("max_connection", 0, "config OPENGAUSS_MAX_CONNECTION is required"), + ], +) +def test_opengauss_config_validation(opengauss_module, field, value, message): + values = _config(opengauss_module).model_dump() + values[field] = value + + with pytest.raises(ValidationError, match=message): + opengauss_module.OpenGaussConfig.model_validate(values) + + +def test_opengauss_config_validation_rejects_min_greater_than_max(opengauss_module): + values = _config(opengauss_module).model_dump() + values["min_connection"] = 6 + values["max_connection"] = 5 + + with pytest.raises(ValidationError, match="OPENGAUSS_MIN_CONNECTION should less than OPENGAUSS_MAX_CONNECTION"): + opengauss_module.OpenGaussConfig.model_validate(values) + + +def test_init_sets_table_name_and_vector_type(opengauss_module, monkeypatch): + pool = MagicMock() + monkeypatch.setattr(opengauss_module.psycopg2.pool, "SimpleConnectionPool", MagicMock(return_value=pool)) + + vector = opengauss_module.OpenGauss("collection_1", _config(opengauss_module)) + + assert vector.table_name == "embedding_collection_1" + assert vector.get_type() == "opengauss" + assert vector.pool is pool + + +def test_create_index_with_pq_executes_pq_sql(opengauss_module, monkeypatch): + pool = MagicMock() + monkeypatch.setattr(opengauss_module.psycopg2.pool, "SimpleConnectionPool", MagicMock(return_value=pool)) + + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(opengauss_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(opengauss_module.redis_client, "get", MagicMock(return_value=None)) + monkeypatch.setattr(opengauss_module.redis_client, "set", MagicMock()) + + vector = opengauss_module.OpenGauss("collection_1", _config(opengauss_module, enable_pq=True)) + cursor = MagicMock() + + @contextmanager + def _cursor_ctx(): + yield cursor + + vector._get_cursor = _cursor_ctx + vector._create_index(1536) + + executed_sql = [call.args[0] for call in cursor.execute.call_args_list] + assert any("enable_pq=on" in sql for sql in executed_sql) + assert any("SET hnsw_earlystop_threshold = 320" in sql for sql in executed_sql) + opengauss_module.redis_client.set.assert_called_once() + + +def test_create_index_skips_index_sql_for_large_dimension(opengauss_module, monkeypatch): + pool = MagicMock() + monkeypatch.setattr(opengauss_module.psycopg2.pool, "SimpleConnectionPool", MagicMock(return_value=pool)) + + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(opengauss_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(opengauss_module.redis_client, "get", MagicMock(return_value=None)) + monkeypatch.setattr(opengauss_module.redis_client, "set", MagicMock()) + + vector = opengauss_module.OpenGauss("collection_1", _config(opengauss_module, enable_pq=False)) + cursor = MagicMock() + + @contextmanager + def _cursor_ctx(): + yield cursor + + vector._get_cursor = _cursor_ctx + vector._create_index(3072) + + cursor.execute.assert_not_called() + opengauss_module.redis_client.set.assert_called_once() + + +def test_search_by_vector_validates_top_k(opengauss_module): + vector = opengauss_module.OpenGauss.__new__(opengauss_module.OpenGauss) + + with pytest.raises(ValueError, match="top_k must be a positive integer"): + vector.search_by_vector([0.1, 0.2], top_k=0) + + +def test_delete_by_ids_short_circuits_with_empty_input(opengauss_module, monkeypatch): + pool = MagicMock() + monkeypatch.setattr(opengauss_module.psycopg2.pool, "SimpleConnectionPool", MagicMock(return_value=pool)) + vector = opengauss_module.OpenGauss("collection_1", _config(opengauss_module)) + vector._get_cursor = MagicMock() + + vector.delete_by_ids([]) + + vector._get_cursor.assert_not_called() + + +def test_get_cursor_closes_commits_and_returns_connection(opengauss_module): + vector = opengauss_module.OpenGauss.__new__(opengauss_module.OpenGauss) + pool = MagicMock() + conn = MagicMock() + cur = MagicMock() + pool.getconn.return_value = conn + conn.cursor.return_value = cur + vector.pool = pool + + with vector._get_cursor() as got_cur: + assert got_cur is cur + + cur.close.assert_called_once() + conn.commit.assert_called_once() + pool.putconn.assert_called_once_with(conn) + + +def test_create_calls_collection_insert_and_index(opengauss_module): + vector = opengauss_module.OpenGauss.__new__(opengauss_module.OpenGauss) + vector._create_collection = MagicMock() + vector.add_texts = MagicMock() + vector._create_index = MagicMock() + docs = [Document(page_content="text", metadata={"doc_id": "seg-1"})] + + vector.create(docs, [[0.1, 0.2]]) + + vector._create_collection.assert_called_once_with(2) + vector.add_texts.assert_called_once_with(docs, [[0.1, 0.2]]) + vector._create_index.assert_called_once_with(2) + + +def test_create_index_returns_early_on_cache_hit(opengauss_module, monkeypatch): + pool = MagicMock() + monkeypatch.setattr(opengauss_module.psycopg2.pool, "SimpleConnectionPool", MagicMock(return_value=pool)) + + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(opengauss_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(opengauss_module.redis_client, "get", MagicMock(return_value=1)) + monkeypatch.setattr(opengauss_module.redis_client, "set", MagicMock()) + + vector = opengauss_module.OpenGauss("collection_1", _config(opengauss_module)) + vector._get_cursor = MagicMock() + + vector._create_index(1536) + + vector._get_cursor.assert_not_called() + opengauss_module.redis_client.set.assert_not_called() + + +def test_create_index_without_pq_executes_standard_index_sql(opengauss_module, monkeypatch): + pool = MagicMock() + monkeypatch.setattr(opengauss_module.psycopg2.pool, "SimpleConnectionPool", MagicMock(return_value=pool)) + + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(opengauss_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(opengauss_module.redis_client, "get", MagicMock(return_value=None)) + monkeypatch.setattr(opengauss_module.redis_client, "set", MagicMock()) + + vector = opengauss_module.OpenGauss("collection_1", _config(opengauss_module, enable_pq=False)) + cursor = MagicMock() + + @contextmanager + def _cursor_ctx(): + yield cursor + + vector._get_cursor = _cursor_ctx + vector._create_index(1536) + + sql = [call.args[0] for call in cursor.execute.call_args_list] + assert any("embedding_cosine_embedding_collection_1_idx" in query for query in sql) + + +def test_add_texts_uses_execute_values(opengauss_module, monkeypatch): + pool = MagicMock() + monkeypatch.setattr(opengauss_module.psycopg2.pool, "SimpleConnectionPool", MagicMock(return_value=pool)) + vector = opengauss_module.OpenGauss("collection_1", _config(opengauss_module)) + cursor = MagicMock() + opengauss_module.psycopg2.extras.execute_values.reset_mock() + + @contextmanager + def _cursor_ctx(): + yield cursor + + vector._get_cursor = _cursor_ctx + docs = [ + Document(page_content="text-1", metadata={"doc_id": "seg-1", "document_id": "d-1"}), + SimpleNamespace(page_content="text-2", metadata=None), + ] + monkeypatch.setattr(opengauss_module.uuid, "uuid4", lambda: "generated-uuid") + + ids = vector.add_texts(docs, [[0.1], [0.2]]) + + assert ids == ["seg-1"] + opengauss_module.psycopg2.extras.execute_values.assert_called_once() + + +def test_text_exists_and_get_by_ids(opengauss_module): + vector = opengauss_module.OpenGauss.__new__(opengauss_module.OpenGauss) + vector.table_name = "embedding_collection_1" + cursor = MagicMock() + cursor.fetchone.return_value = ("seg-1",) + cursor.__iter__.return_value = iter([({"doc_id": "1"}, "text-1"), ({"doc_id": "2"}, "text-2")]) + + @contextmanager + def _cursor_ctx(): + yield cursor + + vector._get_cursor = _cursor_ctx + + assert vector.text_exists("seg-1") is True + docs = vector.get_by_ids(["seg-1", "seg-2"]) + assert len(docs) == 2 + assert docs[0].page_content == "text-1" + + +def test_delete_and_metadata_field_queries(opengauss_module): + vector = opengauss_module.OpenGauss.__new__(opengauss_module.OpenGauss) + vector.table_name = "embedding_collection_1" + cursor = MagicMock() + + @contextmanager + def _cursor_ctx(): + yield cursor + + vector._get_cursor = _cursor_ctx + + vector.delete_by_ids(["seg-1", "seg-2"]) + vector.delete_by_metadata_field("document_id", "doc-1") + vector.delete() + + sql = [call.args[0] for call in cursor.execute.call_args_list] + assert any("DELETE FROM embedding_collection_1 WHERE id IN %s" in query for query in sql) + assert any("meta->>%s = %s" in query for query in sql) + assert any("DROP TABLE IF EXISTS embedding_collection_1" in query for query in sql) + + +def test_search_by_vector_and_full_text(opengauss_module): + vector = opengauss_module.OpenGauss.__new__(opengauss_module.OpenGauss) + vector.table_name = "embedding_collection_1" + cursor = MagicMock() + cursor.__iter__.return_value = iter( + [ + ({"doc_id": "1"}, "text-1", 0.1), + ({"doc_id": "2"}, "text-2", 0.6), + ] + ) + + @contextmanager + def _cursor_ctx(): + yield cursor + + vector._get_cursor = _cursor_ctx + + docs = vector.search_by_vector([0.1, 0.2], top_k=2, score_threshold=0.5) + assert len(docs) == 1 + assert docs[0].metadata["score"] == pytest.approx(0.9) + + cursor.__iter__.return_value = iter([({"doc_id": "3"}, "full-text", 0.8)]) + full_docs = vector.search_by_full_text("hello world", top_k=2) + assert len(full_docs) == 1 + assert full_docs[0].page_content == "full-text" + + +def test_search_by_full_text_validates_top_k(opengauss_module): + vector = opengauss_module.OpenGauss.__new__(opengauss_module.OpenGauss) + with pytest.raises(ValueError, match="top_k must be a positive integer"): + vector.search_by_full_text("query", top_k=0) + + +def test_create_collection_cache_and_create_path(opengauss_module, monkeypatch): + pool = MagicMock() + monkeypatch.setattr(opengauss_module.psycopg2.pool, "SimpleConnectionPool", MagicMock(return_value=pool)) + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(opengauss_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(opengauss_module.redis_client, "set", MagicMock()) + + vector = opengauss_module.OpenGauss("collection_1", _config(opengauss_module)) + cursor = MagicMock() + + @contextmanager + def _cursor_ctx(): + yield cursor + + vector._get_cursor = _cursor_ctx + + monkeypatch.setattr(opengauss_module.redis_client, "get", MagicMock(return_value=1)) + vector._create_collection(1536) + cursor.execute.assert_not_called() + + monkeypatch.setattr(opengauss_module.redis_client, "get", MagicMock(return_value=None)) + vector._create_collection(1536) + cursor.execute.assert_called_once() + opengauss_module.redis_client.set.assert_called_once() + + +def test_opengauss_factory_uses_existing_or_generated_collection(opengauss_module, monkeypatch): + factory = opengauss_module.OpenGaussFactory() + dataset_with_index = SimpleNamespace( + id="dataset-1", + index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}}, + index_struct=None, + ) + dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None) + + monkeypatch.setattr(opengauss_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION") + monkeypatch.setattr(opengauss_module.dify_config, "OPENGAUSS_HOST", "localhost") + monkeypatch.setattr(opengauss_module.dify_config, "OPENGAUSS_PORT", 6600) + monkeypatch.setattr(opengauss_module.dify_config, "OPENGAUSS_USER", "postgres") + monkeypatch.setattr(opengauss_module.dify_config, "OPENGAUSS_PASSWORD", "password") + monkeypatch.setattr(opengauss_module.dify_config, "OPENGAUSS_DATABASE", "dify") + monkeypatch.setattr(opengauss_module.dify_config, "OPENGAUSS_MIN_CONNECTION", 1) + monkeypatch.setattr(opengauss_module.dify_config, "OPENGAUSS_MAX_CONNECTION", 5) + monkeypatch.setattr(opengauss_module.dify_config, "OPENGAUSS_ENABLE_PQ", False) + + with patch.object(opengauss_module, "OpenGauss", return_value="vector") as vector_cls: + result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock()) + result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock()) + + assert result_1 == "vector" + assert result_2 == "vector" + assert vector_cls.call_args_list[0].kwargs["collection_name"] == "EXISTING_COLLECTION" + assert vector_cls.call_args_list[1].kwargs["collection_name"] == "AUTO_COLLECTION" + assert dataset_without_index.index_struct is not None diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/opensearch/test_opensearch_vector.py b/api/tests/unit_tests/core/rag/datasource/vdb/opensearch/test_opensearch_vector.py new file mode 100644 index 0000000000..1030158dd1 --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/opensearch/test_opensearch_vector.py @@ -0,0 +1,360 @@ +import importlib +import sys +import types +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from pydantic import ValidationError + +from core.rag.models.document import Document + + +def _build_fake_opensearch_modules(): + opensearchpy = types.ModuleType("opensearchpy") + opensearchpy_helpers = types.ModuleType("opensearchpy.helpers") + + class BulkIndexError(Exception): + def __init__(self, errors): + super().__init__("bulk error") + self.errors = errors + + class Urllib3AWSV4SignerAuth: + def __init__(self, credentials, region, service): + self.credentials = credentials + self.region = region + self.service = service + + class Urllib3HttpConnection: + pass + + class _IndicesClient: + def __init__(self): + self.exists = MagicMock(return_value=False) + self.create = MagicMock() + self.delete = MagicMock() + + class OpenSearch: + def __init__(self, **kwargs): + self.kwargs = kwargs + self.indices = _IndicesClient() + self.search = MagicMock(return_value={"hits": {"hits": []}}) + self.get = MagicMock() + + helpers = SimpleNamespace(bulk=MagicMock()) + + opensearchpy.OpenSearch = OpenSearch + opensearchpy.Urllib3AWSV4SignerAuth = Urllib3AWSV4SignerAuth + opensearchpy.Urllib3HttpConnection = Urllib3HttpConnection + opensearchpy.helpers = helpers + opensearchpy_helpers.BulkIndexError = BulkIndexError + + return { + "opensearchpy": opensearchpy, + "opensearchpy.helpers": opensearchpy_helpers, + } + + +@pytest.fixture +def opensearch_module(monkeypatch): + for name, module in _build_fake_opensearch_modules().items(): + monkeypatch.setitem(sys.modules, name, module) + + import core.rag.datasource.vdb.opensearch.opensearch_vector as module + + return importlib.reload(module) + + +def _config(module, **overrides): + values = { + "host": "localhost", + "port": 9200, + "secure": True, + "verify_certs": True, + "auth_method": "basic", + "user": "admin", + "password": "secret", + } + values.update(overrides) + return module.OpenSearchConfig.model_validate(values) + + +@pytest.mark.parametrize( + ("field", "value", "message"), + [ + ("host", "", "config OPENSEARCH_HOST is required"), + ("port", 0, "config OPENSEARCH_PORT is required"), + ], +) +def test_config_validation_required_fields(opensearch_module, field, value, message): + values = _config(opensearch_module).model_dump() + values[field] = value + + with pytest.raises(ValidationError, match=message): + opensearch_module.OpenSearchConfig.model_validate(values) + + +def test_config_validation_for_aws_auth_and_https_fields(opensearch_module): + values = { + "host": "localhost", + "port": 9200, + "secure": True, + "verify_certs": True, + "auth_method": "aws_managed_iam", + "user": "admin", + "password": "secret", + } + with pytest.raises(ValidationError, match="OPENSEARCH_AWS_REGION"): + opensearch_module.OpenSearchConfig.model_validate(values) + + values = _config(opensearch_module).model_dump() + values["OPENSEARCH_SECURE"] = False + values["OPENSEARCH_VERIFY_CERTS"] = True + with pytest.raises(ValidationError, match="verify_certs=True requires secure"): + opensearch_module.OpenSearchConfig.model_validate(values) + + +def test_create_aws_managed_iam_auth(opensearch_module, monkeypatch): + class _Session: + def get_credentials(self): + return "creds" + + boto3 = types.ModuleType("boto3") + boto3.Session = _Session + monkeypatch.setitem(sys.modules, "boto3", boto3) + + config = _config( + opensearch_module, + auth_method="aws_managed_iam", + aws_region="us-east-1", + aws_service="es", + ) + auth = config.create_aws_managed_iam_auth() + + assert auth.credentials == "creds" + assert auth.region == "us-east-1" + assert auth.service == "es" + + +def test_to_opensearch_params_supports_basic_and_aws(opensearch_module): + basic_params = _config(opensearch_module).to_opensearch_params() + assert basic_params["http_auth"] == ("admin", "secret") + + aws_config = _config( + opensearch_module, + auth_method="aws_managed_iam", + aws_region="us-west-2", + aws_service="es", + ) + with patch.object(opensearch_module.OpenSearchConfig, "create_aws_managed_iam_auth", return_value="iam-auth"): + aws_params = aws_config.to_opensearch_params() + + assert aws_params["http_auth"] == "iam-auth" + + +def test_init_and_create_delegate_calls(opensearch_module): + vector = opensearch_module.OpenSearchVector("Collection_1", _config(opensearch_module)) + vector.create_collection = MagicMock() + vector.add_texts = MagicMock() + docs = [Document(page_content="hello", metadata={"doc_id": "seg-1"})] + + vector.create(docs, [[0.1, 0.2]]) + + assert vector.get_type() == "opensearch" + vector.create_collection.assert_called_once_with([[0.1, 0.2]], [{"doc_id": "seg-1"}]) + vector.add_texts.assert_called_once_with(docs, [[0.1, 0.2]]) + + +def test_add_texts_supports_regular_and_aoss_clients(opensearch_module, monkeypatch): + vector = opensearch_module.OpenSearchVector("Collection_1", _config(opensearch_module, aws_service="es")) + docs = [ + Document(page_content="a", metadata={"doc_id": "1"}), + Document(page_content="b", metadata={"doc_id": "2"}), + ] + + monkeypatch.setattr(opensearch_module, "uuid4", lambda: SimpleNamespace(hex="generated-id")) + opensearch_module.helpers.bulk.reset_mock() + vector.add_texts(docs, [[0.1], [0.2]]) + actions = opensearch_module.helpers.bulk.call_args.kwargs["actions"] + assert len(actions) == 2 + assert all("_id" in action for action in actions) + + vector._client_config.aws_service = "aoss" + opensearch_module.helpers.bulk.reset_mock() + vector.add_texts(docs, [[0.3], [0.4]]) + aoss_actions = opensearch_module.helpers.bulk.call_args.kwargs["actions"] + assert all("_id" not in action for action in aoss_actions) + + +def test_metadata_lookup_and_delete_by_metadata_field(opensearch_module): + vector = opensearch_module.OpenSearchVector("collection_1", _config(opensearch_module)) + vector._client.search.return_value = {"hits": {"hits": [{"_id": "id-1"}, {"_id": "id-2"}]}} + + assert vector.get_ids_by_metadata_field("document_id", "doc-1") == ["id-1", "id-2"] + + vector._client.search.return_value = {"hits": {"hits": []}} + assert vector.get_ids_by_metadata_field("document_id", "doc-1") is None + + vector.get_ids_by_metadata_field = MagicMock(return_value=["id-1"]) + vector.delete_by_ids = MagicMock() + vector.delete_by_metadata_field("document_id", "doc-1") + vector.delete_by_ids.assert_called_once_with(["id-1"]) + + +def test_delete_by_ids_branches_and_bulk_error_handling(opensearch_module): + vector = opensearch_module.OpenSearchVector("collection_1", _config(opensearch_module)) + opensearch_module.helpers.bulk.reset_mock() + vector._client.indices.exists.return_value = False + vector.delete_by_ids(["doc-1"]) + opensearch_module.helpers.bulk.assert_not_called() + + vector._client.indices.exists.return_value = True + vector.get_ids_by_metadata_field = MagicMock(side_effect=[["es-1"], None]) + vector.delete_by_ids(["doc-1", "doc-2"]) + opensearch_module.helpers.bulk.assert_called_once() + + opensearch_module.helpers.bulk.reset_mock() + vector.get_ids_by_metadata_field = MagicMock(return_value=["es-404"]) + opensearch_module.helpers.bulk.side_effect = opensearch_module.BulkIndexError( + [{"delete": {"status": 404, "_id": "es-404"}}] + ) + vector.delete_by_ids(["doc-404"]) + assert opensearch_module.helpers.bulk.call_count == 1 + + opensearch_module.helpers.bulk.side_effect = None + + +def test_delete_and_text_exists(opensearch_module): + vector = opensearch_module.OpenSearchVector("collection_1", _config(opensearch_module)) + vector.delete() + vector._client.indices.delete.assert_called_once_with(index="collection_1", ignore_unavailable=True) + + vector._client.get.return_value = {"_id": "id-1"} + assert vector.text_exists("id-1") is True + vector._client.get.side_effect = RuntimeError("not found") + assert vector.text_exists("id-1") is False + + +def test_search_by_vector_validates_and_builds_documents(opensearch_module): + vector = opensearch_module.OpenSearchVector("collection_1", _config(opensearch_module)) + + with pytest.raises(ValueError, match="query_vector should be a list"): + vector.search_by_vector("not-a-list") + + with pytest.raises(ValueError, match="should be floats"): + vector.search_by_vector([0.1, 1]) + + vector._client.search.return_value = { + "hits": { + "hits": [ + { + "_source": { + opensearch_module.Field.CONTENT_KEY: "doc-1", + opensearch_module.Field.METADATA_KEY: None, + }, + "_score": 0.9, + }, + { + "_source": { + opensearch_module.Field.CONTENT_KEY: "doc-2", + opensearch_module.Field.METADATA_KEY: {"doc_id": "2"}, + }, + "_score": 0.1, + }, + ] + } + } + docs = vector.search_by_vector([0.1, 0.2], top_k=2, score_threshold=0.5) + assert len(docs) == 1 + assert docs[0].page_content == "doc-1" + assert docs[0].metadata["score"] == pytest.approx(0.9) + + vector.search_by_vector([0.1, 0.2], top_k=3, document_ids_filter=["doc-a", "doc-b"]) + query = vector._client.search.call_args.kwargs["body"] + assert "script_score" in query["query"] + + +def test_search_by_vector_reraises_client_error(opensearch_module): + vector = opensearch_module.OpenSearchVector("collection_1", _config(opensearch_module)) + vector._client.search.side_effect = RuntimeError("boom") + + with pytest.raises(RuntimeError, match="boom"): + vector.search_by_vector([0.1, 0.2]) + + +def test_search_by_full_text_and_filters(opensearch_module): + vector = opensearch_module.OpenSearchVector("collection_1", _config(opensearch_module)) + vector._client.search.return_value = { + "hits": { + "hits": [ + { + "_source": { + opensearch_module.Field.METADATA_KEY: {"doc_id": "1"}, + opensearch_module.Field.VECTOR: [0.1], + opensearch_module.Field.CONTENT_KEY: "matched text", + } + }, + ] + } + } + + docs = vector.search_by_full_text("hello", document_ids_filter=["d-1"]) + + assert len(docs) == 1 + assert docs[0].page_content == "matched text" + query = vector._client.search.call_args.kwargs["body"] + assert query["query"]["bool"]["filter"] == [{"terms": {"metadata.document_id": ["d-1"]}}] + + +def test_create_collection_cache_and_create_path(opensearch_module, monkeypatch): + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(opensearch_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(opensearch_module.redis_client, "set", MagicMock()) + + vector = opensearch_module.OpenSearchVector("Collection_1", _config(opensearch_module)) + + monkeypatch.setattr(opensearch_module.redis_client, "get", MagicMock(return_value=1)) + vector._client.indices.create.reset_mock() + vector.create_collection([[0.1, 0.2]]) + vector._client.indices.create.assert_not_called() + + monkeypatch.setattr(opensearch_module.redis_client, "get", MagicMock(return_value=None)) + vector._client.indices.exists.return_value = False + vector.create_collection([[0.1, 0.2]]) + vector._client.indices.create.assert_called_once() + index_body = vector._client.indices.create.call_args.kwargs["body"] + assert index_body["mappings"]["properties"]["vector"]["dimension"] == 2 + opensearch_module.redis_client.set.assert_called() + + +def test_opensearch_factory_initializes_expected_collection_name(opensearch_module, monkeypatch): + factory = opensearch_module.OpenSearchVectorFactory() + dataset_with_index = SimpleNamespace( + id="dataset-1", + index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}}, + index_struct=None, + ) + dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None) + + monkeypatch.setattr(opensearch_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION") + monkeypatch.setattr(opensearch_module.dify_config, "OPENSEARCH_HOST", "localhost") + monkeypatch.setattr(opensearch_module.dify_config, "OPENSEARCH_PORT", 9200) + monkeypatch.setattr(opensearch_module.dify_config, "OPENSEARCH_SECURE", True) + monkeypatch.setattr(opensearch_module.dify_config, "OPENSEARCH_VERIFY_CERTS", True) + monkeypatch.setattr(opensearch_module.dify_config, "OPENSEARCH_AUTH_METHOD", "basic") + monkeypatch.setattr(opensearch_module.dify_config, "OPENSEARCH_USER", "admin") + monkeypatch.setattr(opensearch_module.dify_config, "OPENSEARCH_PASSWORD", "secret") + monkeypatch.setattr(opensearch_module.dify_config, "OPENSEARCH_AWS_REGION", None) + monkeypatch.setattr(opensearch_module.dify_config, "OPENSEARCH_AWS_SERVICE", None) + + with patch.object(opensearch_module, "OpenSearchVector", return_value="vector") as vector_cls: + result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock()) + result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock()) + + assert result_1 == "vector" + assert result_2 == "vector" + assert vector_cls.call_args_list[0].kwargs["collection_name"] == "existing_collection" + assert vector_cls.call_args_list[1].kwargs["collection_name"] == "auto_collection" + assert dataset_without_index.index_struct is not None diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/oracle/test_oraclevector.py b/api/tests/unit_tests/core/rag/datasource/vdb/oracle/test_oraclevector.py new file mode 100644 index 0000000000..817a7d342b --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/oracle/test_oraclevector.py @@ -0,0 +1,375 @@ +import array +import importlib +import sys +import types +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import numpy +import pytest +from pydantic import ValidationError + +from core.rag.models.document import Document + + +def _build_fake_oracle_modules(): + jieba = types.ModuleType("jieba") + jieba_posseg = types.ModuleType("jieba.posseg") + jieba_posseg.cut = MagicMock(return_value=[]) + jieba.posseg = jieba_posseg + + oracledb = types.ModuleType("oracledb") + oracledb_connection = types.ModuleType("oracledb.connection") + + class Connection: + pass + + oracledb_connection.Connection = Connection + oracledb.defaults = SimpleNamespace(fetch_lobs=True) + oracledb.DB_TYPE_VECTOR = object() + oracledb.create_pool = MagicMock(return_value=MagicMock(release=MagicMock())) + oracledb.connect = MagicMock() + + return { + "jieba": jieba, + "jieba.posseg": jieba_posseg, + "oracledb": oracledb, + "oracledb.connection": oracledb_connection, + } + + +def _connection_with_cursor(cursor): + cursor_ctx = MagicMock() + cursor_ctx.__enter__.return_value = cursor + cursor_ctx.__exit__.return_value = None + + connection = MagicMock() + connection.__enter__.return_value = connection + connection.__exit__.return_value = None + connection.cursor.return_value = cursor_ctx + return connection + + +@pytest.fixture +def oracle_module(monkeypatch): + for name, module in _build_fake_oracle_modules().items(): + monkeypatch.setitem(sys.modules, name, module) + + import core.rag.datasource.vdb.oracle.oraclevector as module + + return importlib.reload(module) + + +def _config(module, **overrides): + values = { + "user": "system", + "password": "oracle", + "dsn": "oracle:1521/freepdb1", + "is_autonomous": False, + } + values.update(overrides) + return module.OracleVectorConfig.model_validate(values) + + +@pytest.mark.parametrize( + ("field", "value", "message"), + [ + ("user", "", "config ORACLE_USER is required"), + ("password", "", "config ORACLE_PASSWORD is required"), + ("dsn", "", "config ORACLE_DSN is required"), + ], +) +def test_oracle_config_validation_required_fields(oracle_module, field, value, message): + values = _config(oracle_module).model_dump() + values[field] = value + + with pytest.raises(ValidationError, match=message): + oracle_module.OracleVectorConfig.model_validate(values) + + +def test_oracle_config_validation_autonomous_requirements(oracle_module): + with pytest.raises(ValidationError, match="config_dir is required"): + oracle_module.OracleVectorConfig.model_validate( + {"user": "u", "password": "p", "dsn": "d", "is_autonomous": True} + ) + + +def test_init_and_get_type(oracle_module, monkeypatch): + pool = MagicMock() + monkeypatch.setattr(oracle_module.oracledb, "create_pool", MagicMock(return_value=pool)) + vector = oracle_module.OracleVector("collection_1", _config(oracle_module)) + + assert vector.get_type() == "oracle" + assert vector.table_name == "embedding_collection_1" + assert vector.pool is pool + + +def test_numpy_converters_and_type_handlers(oracle_module): + vector = oracle_module.OracleVector.__new__(oracle_module.OracleVector) + + in_float64 = vector.numpy_converter_in(numpy.array([0.1], dtype=numpy.float64)) + in_float32 = vector.numpy_converter_in(numpy.array([0.1], dtype=numpy.float32)) + in_int8 = vector.numpy_converter_in(numpy.array([1], dtype=numpy.int8)) + assert in_float64.typecode == "d" + assert in_float32.typecode == "f" + assert in_int8.typecode == "b" + + cursor = MagicMock() + vector.input_type_handler(cursor, numpy.array([0.1], dtype=numpy.float32), 2) + cursor.var.assert_called_with( + oracle_module.oracledb.DB_TYPE_VECTOR, + arraysize=2, + inconverter=vector.numpy_converter_in, + ) + + metadata = SimpleNamespace(type_code=oracle_module.oracledb.DB_TYPE_VECTOR) + cursor.arraysize = 3 + vector.output_type_handler(cursor, metadata) + cursor.var.assert_called_with( + metadata.type_code, + arraysize=3, + outconverter=vector.numpy_converter_out, + ) + + out_int8 = vector.numpy_converter_out(array.array("b", [1])) + assert out_int8.dtype == numpy.int8 + out_float32 = vector.numpy_converter_out(array.array("f", [1.0])) + assert out_float32.dtype == numpy.float32 + out_float64 = vector.numpy_converter_out(array.array("d", [1.0])) + assert out_float64.dtype == numpy.float64 + + +def test_get_connection_supports_standard_and_autonomous_paths(oracle_module, monkeypatch): + connect = MagicMock(return_value="connection") + monkeypatch.setattr(oracle_module.oracledb, "connect", connect) + + vector = oracle_module.OracleVector.__new__(oracle_module.OracleVector) + vector.config = _config(oracle_module) + assert vector._get_connection() == "connection" + connect.assert_called_with(user="system", password="oracle", dsn="oracle:1521/freepdb1") + + vector.config = _config( + oracle_module, + is_autonomous=True, + config_dir="/wallet", + wallet_location="/wallet", + wallet_password="pw", + ) + vector._get_connection() + assert connect.call_args.kwargs["config_dir"] == "/wallet" + assert connect.call_args.kwargs["wallet_location"] == "/wallet" + + +def test_create_delegates_collection_and_insert(oracle_module): + vector = oracle_module.OracleVector.__new__(oracle_module.OracleVector) + vector._create_collection = MagicMock() + vector.add_texts = MagicMock(return_value=["seg-1"]) + docs = [Document(page_content="doc", metadata={"doc_id": "seg-1"})] + + result = vector.create(docs, [[0.1, 0.2]]) + + assert result == ["seg-1"] + vector._create_collection.assert_called_once_with(2) + vector.add_texts.assert_called_once_with(docs, [[0.1, 0.2]]) + + +def test_add_texts_inserts_and_logs_on_failures(oracle_module, monkeypatch): + vector = oracle_module.OracleVector.__new__(oracle_module.OracleVector) + vector.table_name = "embedding_collection_1" + vector.input_type_handler = MagicMock() + vector.output_type_handler = MagicMock() + + cursor = MagicMock() + cursor.execute.side_effect = [None, RuntimeError("insert failed")] + connection = _connection_with_cursor(cursor) + vector._get_connection = MagicMock(return_value=connection) + + monkeypatch.setattr(oracle_module.uuid, "uuid4", lambda: "generated-uuid") + docs = [ + Document(page_content="a", metadata={"doc_id": "doc-a"}), + Document(page_content="b", metadata={"document_id": "doc-b"}), + SimpleNamespace(page_content="c", metadata=None), + ] + + ids = vector.add_texts(docs, [[0.1], [0.2], [0.3]]) + + assert ids == ["doc-a", "generated-uuid"] + assert cursor.execute.call_count == 2 + assert connection.commit.call_count >= 1 + connection.close.assert_called() + + +def test_text_exists_and_get_by_ids(oracle_module): + vector = oracle_module.OracleVector.__new__(oracle_module.OracleVector) + vector.table_name = "embedding_collection_1" + vector.pool = MagicMock() + + cursor = MagicMock() + cursor.fetchone.return_value = ("id-1",) + cursor.__iter__.return_value = iter([({"doc_id": "1"}, "text-1"), ({"doc_id": "2"}, "text-2")]) + vector._get_connection = MagicMock(return_value=_connection_with_cursor(cursor)) + + assert vector.text_exists("id-1") is True + docs = vector.get_by_ids(["id-1", "id-2"]) + assert len(docs) == 2 + assert docs[0].page_content == "text-1" + vector.pool.release.assert_called_once() + assert vector.get_by_ids([]) == [] + + +def test_delete_methods(oracle_module): + vector = oracle_module.OracleVector.__new__(oracle_module.OracleVector) + vector.table_name = "embedding_collection_1" + + cursor = MagicMock() + vector._get_connection = MagicMock(return_value=_connection_with_cursor(cursor)) + + vector.delete_by_ids([]) + vector._get_connection.assert_not_called() + + vector.delete_by_ids(["id-1", "id-2"]) + vector.delete_by_metadata_field("document_id", "doc-1") + vector.delete() + + executed_sql = [call.args[0] for call in cursor.execute.call_args_list] + assert any("DELETE FROM embedding_collection_1 WHERE id IN" in sql for sql in executed_sql) + assert any("JSON_VALUE(meta" in sql for sql in executed_sql) + assert any("DROP TABLE IF EXISTS embedding_collection_1" in sql for sql in executed_sql) + + +def test_search_by_vector_with_threshold_and_filter(oracle_module): + vector = oracle_module.OracleVector.__new__(oracle_module.OracleVector) + vector.table_name = "embedding_collection_1" + vector.input_type_handler = MagicMock() + vector.output_type_handler = MagicMock() + + cursor = MagicMock() + cursor.__iter__.return_value = iter([({"doc_id": "1"}, "doc-1", 0.1), ({"doc_id": "2"}, "doc-2", 0.8)]) + connection = _connection_with_cursor(cursor) + vector._get_connection = MagicMock(return_value=connection) + + docs = vector.search_by_vector( + [0.1, 0.2], + top_k=0, + score_threshold=0.5, + document_ids_filter=["d-1", "d-2"], + ) + + assert len(docs) == 1 + assert docs[0].metadata["score"] == pytest.approx(0.9) + sql = cursor.execute.call_args.args[0] + assert "fetch first 4 rows only" in sql + assert "JSON_VALUE(meta, '$.document_id') IN (:2, :3)" in sql + + +def _fake_nltk_module(*, missing_data=False): + nltk = types.ModuleType("nltk") + nltk_corpus = types.ModuleType("nltk.corpus") + + class _Data: + @staticmethod + def find(_path): + if missing_data: + raise LookupError("missing") + return True + + nltk.data = _Data() + nltk.word_tokenize = lambda text: text.split() + nltk_corpus.stopwords = SimpleNamespace(words=lambda _lang: ["and", "the"]) + return nltk, nltk_corpus + + +def test_search_by_full_text_chinese_and_english_paths(oracle_module, monkeypatch): + vector = oracle_module.OracleVector.__new__(oracle_module.OracleVector) + vector.table_name = "embedding_collection_1" + + cursor = MagicMock() + cursor.__iter__.return_value = iter([({"doc_id": "1"}, "text-1", [0.1, 0.2])]) + vector._get_connection = MagicMock(return_value=_connection_with_cursor(cursor)) + + monkeypatch.setattr(oracle_module.pseg, "cut", MagicMock(return_value=[("张", "nr"), ("三", "nr"), ("。", "x")])) + zh_docs = vector.search_by_full_text("张三", top_k=2) + assert len(zh_docs) == 1 + zh_params = cursor.execute.call_args.args[1] + assert zh_params["kk"] == "张三" + + nltk, nltk_corpus = _fake_nltk_module(missing_data=False) + monkeypatch.setitem(sys.modules, "nltk", nltk) + monkeypatch.setitem(sys.modules, "nltk.corpus", nltk_corpus) + cursor.__iter__.return_value = iter([({"doc_id": "2"}, "text-2", [0.3, 0.4])]) + en_docs = vector.search_by_full_text("alice and bob", top_k=-1, document_ids_filter=["d-1"]) + assert len(en_docs) == 1 + en_sql = cursor.execute.call_args.args[0] + en_params = cursor.execute.call_args.args[1] + assert "fetch first 5 rows only" in en_sql + assert "doc_id_0" in en_params + + +def test_search_by_full_text_empty_query_and_missing_nltk(oracle_module, monkeypatch): + vector = oracle_module.OracleVector.__new__(oracle_module.OracleVector) + vector.table_name = "embedding_collection_1" + vector._get_connection = MagicMock() + + empty_result = vector.search_by_full_text("") + assert empty_result[0].page_content == "" + + nltk, nltk_corpus = _fake_nltk_module(missing_data=True) + monkeypatch.setitem(sys.modules, "nltk", nltk) + monkeypatch.setitem(sys.modules, "nltk.corpus", nltk_corpus) + with pytest.raises(LookupError, match="required NLTK data package"): + vector.search_by_full_text("english query") + + +def test_create_collection_cache_and_execute_path(oracle_module, monkeypatch): + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(oracle_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(oracle_module.redis_client, "set", MagicMock()) + + vector = oracle_module.OracleVector.__new__(oracle_module.OracleVector) + vector._collection_name = "collection_1" + vector.table_name = "embedding_collection_1" + + cursor = MagicMock() + vector._get_connection = MagicMock(return_value=_connection_with_cursor(cursor)) + + monkeypatch.setattr(oracle_module.redis_client, "get", MagicMock(return_value=1)) + vector._create_collection(2) + cursor.execute.assert_not_called() + + monkeypatch.setattr(oracle_module.redis_client, "get", MagicMock(return_value=None)) + vector._create_collection(2) + executed_sql = [call.args[0] for call in cursor.execute.call_args_list] + assert any("CREATE TABLE IF NOT EXISTS embedding_collection_1" in sql for sql in executed_sql) + assert any("CREATE INDEX IF NOT EXISTS idx_docs_embedding_collection_1" in sql for sql in executed_sql) + oracle_module.redis_client.set.assert_called_once() + + +def test_oracle_factory_init_vector_uses_existing_or_generated_collection(oracle_module, monkeypatch): + factory = oracle_module.OracleVectorFactory() + dataset_with_index = SimpleNamespace( + id="dataset-1", + index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}}, + index_struct=None, + ) + dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None) + + monkeypatch.setattr(oracle_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION") + monkeypatch.setattr(oracle_module.dify_config, "ORACLE_USER", "system") + monkeypatch.setattr(oracle_module.dify_config, "ORACLE_PASSWORD", "oracle") + monkeypatch.setattr(oracle_module.dify_config, "ORACLE_DSN", "oracle:1521/freepdb1") + monkeypatch.setattr(oracle_module.dify_config, "ORACLE_CONFIG_DIR", None) + monkeypatch.setattr(oracle_module.dify_config, "ORACLE_WALLET_LOCATION", None) + monkeypatch.setattr(oracle_module.dify_config, "ORACLE_WALLET_PASSWORD", None) + monkeypatch.setattr(oracle_module.dify_config, "ORACLE_IS_AUTONOMOUS", False) + + with patch.object(oracle_module, "OracleVector", return_value="vector") as vector_cls: + result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock()) + result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock()) + + assert result_1 == "vector" + assert result_2 == "vector" + assert vector_cls.call_args_list[0].kwargs["collection_name"] == "EXISTING_COLLECTION" + assert vector_cls.call_args_list[1].kwargs["collection_name"] == "AUTO_COLLECTION" + assert dataset_without_index.index_struct is not None diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/pgvecto_rs/test_pgvecto_rs.py b/api/tests/unit_tests/core/rag/datasource/vdb/pgvecto_rs/test_pgvecto_rs.py new file mode 100644 index 0000000000..1aec81b8ac --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/pgvecto_rs/test_pgvecto_rs.py @@ -0,0 +1,317 @@ +import importlib +import sys +import types +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from pydantic import ValidationError +from sqlalchemy.types import UserDefinedType + +from core.rag.models.document import Document + + +def _build_fake_pgvecto_modules(): + pgvecto_rs = types.ModuleType("pgvecto_rs") + pgvecto_rs_sqlalchemy = types.ModuleType("pgvecto_rs.sqlalchemy") + + class VECTOR(UserDefinedType): + def __init__(self, dim): + self.dim = dim + + pgvecto_rs_sqlalchemy.VECTOR = VECTOR + return { + "pgvecto_rs": pgvecto_rs, + "pgvecto_rs.sqlalchemy": pgvecto_rs_sqlalchemy, + } + + +class _FakeSessionContext: + def __init__(self, calls, execute_results=None): + self.calls = calls + self.execute_results = execute_results or [] + self.execute = MagicMock(side_effect=self._execute_side_effect) + self.commit = MagicMock() + + def _execute_side_effect(self, *args, **kwargs): + self.calls.append((args, kwargs)) + if self.execute_results: + return self.execute_results.pop(0) + return MagicMock() + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return None + + +def _session_factory(calls, execute_results=None): + def _session(_client): + return _FakeSessionContext(calls=calls, execute_results=execute_results) + + return _session + + +@pytest.fixture +def pgvecto_module(monkeypatch): + for name, module in _build_fake_pgvecto_modules().items(): + monkeypatch.setitem(sys.modules, name, module) + + import core.rag.datasource.vdb.pgvecto_rs.collection as collection_module + import core.rag.datasource.vdb.pgvecto_rs.pgvecto_rs as module + + return importlib.reload(module), importlib.reload(collection_module) + + +def _config(module, **overrides): + values = { + "host": "localhost", + "port": 5432, + "user": "postgres", + "password": "secret", + "database": "postgres", + } + values.update(overrides) + return module.PgvectoRSConfig.model_validate(values) + + +@pytest.mark.parametrize( + ("field", "value", "message"), + [ + ("host", "", "config PGVECTO_RS_HOST is required"), + ("port", 0, "config PGVECTO_RS_PORT is required"), + ("user", "", "config PGVECTO_RS_USER is required"), + ("password", "", "config PGVECTO_RS_PASSWORD is required"), + ("database", "", "config PGVECTO_RS_DATABASE is required"), + ], +) +def test_pgvecto_config_validation(pgvecto_module, field, value, message): + module, _ = pgvecto_module + values = _config(module).model_dump() + values[field] = value + + with pytest.raises(ValidationError, match=message): + module.PgvectoRSConfig.model_validate(values) + + +def test_collection_base_has_expected_annotations(pgvecto_module): + _, collection_module = pgvecto_module + annotations = collection_module.CollectionORM.__annotations__ + assert {"id", "text", "meta", "vector"} <= set(annotations) + + +def test_init_get_type_and_create_delegate(pgvecto_module, monkeypatch): + module, _ = pgvecto_module + session_calls = [] + monkeypatch.setattr(module, "create_engine", MagicMock(return_value="engine")) + monkeypatch.setattr(module, "Session", _session_factory(session_calls)) + + vector = module.PGVectoRS("collection_1", _config(module), dim=3) + vector.create_collection = MagicMock() + vector.add_texts = MagicMock() + docs = [Document(page_content="hello", metadata={"doc_id": "1"})] + vector.create(docs, [[0.1, 0.2]]) + + assert vector.get_type() == module.VectorType.PGVECTO_RS + module.create_engine.assert_called_once_with("postgresql+psycopg2://postgres:secret@localhost:5432/postgres") + assert any("CREATE EXTENSION IF NOT EXISTS vectors" in str(args[0]) for args, _ in session_calls) + vector.create_collection.assert_called_once_with(2) + vector.add_texts.assert_called_once_with(docs, [[0.1, 0.2]]) + + +def test_create_collection_cache_and_sql_execution(pgvecto_module, monkeypatch): + module, _ = pgvecto_module + session_calls = [] + monkeypatch.setattr(module, "create_engine", MagicMock(return_value="engine")) + monkeypatch.setattr(module, "Session", _session_factory(session_calls)) + + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(module.redis_client, "set", MagicMock()) + + vector = module.PGVectoRS("collection_1", _config(module), dim=3) + monkeypatch.setattr(module.redis_client, "get", MagicMock(return_value=1)) + vector.create_collection(3) + assert not any("CREATE TABLE IF NOT EXISTS collection_1" in str(args[0]) for args, _ in session_calls) + + monkeypatch.setattr(module.redis_client, "get", MagicMock(return_value=None)) + vector.create_collection(3) + assert any("CREATE TABLE IF NOT EXISTS collection_1" in str(args[0]) for args, _ in session_calls) + assert any("CREATE INDEX IF NOT EXISTS collection_1_embedding_index" in str(args[0]) for args, _ in session_calls) + module.redis_client.set.assert_called() + + +def test_add_texts_get_ids_and_delete_methods(pgvecto_module, monkeypatch): + module, _ = pgvecto_module + init_calls = [] + runtime_calls = [] + execute_results = [SimpleNamespace(fetchall=lambda: [("id-1",), ("id-2",)]), SimpleNamespace(fetchall=lambda: [])] + + monkeypatch.setattr(module, "create_engine", MagicMock(return_value="engine")) + monkeypatch.setattr(module, "Session", _session_factory(init_calls)) + vector = module.PGVectoRS("collection_1", _config(module), dim=3) + + monkeypatch.setattr(module, "Session", _session_factory(runtime_calls, execute_results=list(execute_results))) + + class _InsertBuilder: + def __init__(self, table): + self.table = table + + def values(self, **kwargs): + return ("insert", kwargs) + + monkeypatch.setattr(module, "insert", lambda table: _InsertBuilder(table)) + monkeypatch.setattr(module, "uuid4", MagicMock(side_effect=["uuid-1", "uuid-2"])) + docs = [ + Document(page_content="a", metadata={"doc_id": "1"}), + Document(page_content="b", metadata={"doc_id": "2"}), + ] + + ids = vector.add_texts(docs, [[0.1], [0.2]]) + assert ids == ["uuid-1", "uuid-2"] + assert any(call[0][0][0] == "insert" for call in runtime_calls if call[0]) + + monkeypatch.setattr( + module, + "Session", + _session_factory(runtime_calls, execute_results=[SimpleNamespace(fetchall=lambda: [("id-1",), ("id-2",)])]), + ) + assert vector.get_ids_by_metadata_field("document_id", "doc-1") == ["id-1", "id-2"] + + monkeypatch.setattr( + module, + "Session", + _session_factory(runtime_calls, execute_results=[SimpleNamespace(fetchall=lambda: [])]), + ) + assert vector.get_ids_by_metadata_field("document_id", "doc-1") is None + + vector.get_ids_by_metadata_field = MagicMock(return_value=["id-1"]) + vector.delete_by_metadata_field("document_id", "doc-1") + assert any("DELETE FROM collection_1 WHERE id = ANY(:ids)" in str(args[0]) for args, _ in runtime_calls) + + runtime_calls.clear() + monkeypatch.setattr( + module, + "Session", + _session_factory( + runtime_calls, + execute_results=[ + SimpleNamespace(fetchall=lambda: [("row-id-1",)]), + MagicMock(), + ], + ), + ) + vector.delete_by_ids(["doc-1"]) + assert any("meta->>'doc_id' = ANY (:doc_ids)" in str(args[0]) for args, _ in runtime_calls) + assert any("DELETE FROM collection_1 WHERE id = ANY(:ids)" in str(args[0]) for args, _ in runtime_calls) + + runtime_calls.clear() + monkeypatch.setattr(module, "Session", _session_factory(runtime_calls, execute_results=[MagicMock()])) + vector.delete() + assert any("DROP TABLE IF EXISTS collection_1" in str(args[0]) for args, _ in runtime_calls) + + +def test_text_exists_search_and_full_text(pgvecto_module, monkeypatch): + module, _ = pgvecto_module + init_calls = [] + monkeypatch.setattr(module, "create_engine", MagicMock(return_value="engine")) + monkeypatch.setattr(module, "Session", _session_factory(init_calls)) + vector = module.PGVectoRS("collection_1", _config(module), dim=3) + + runtime_calls = [] + monkeypatch.setattr( + module, + "Session", + _session_factory( + runtime_calls, + execute_results=[ + SimpleNamespace(fetchall=lambda: [("id-1",)]), + SimpleNamespace(fetchall=lambda: []), + ], + ), + ) + assert vector.text_exists("doc-1") is True + assert vector.text_exists("doc-1") is False + + class _DistanceExpr: + def label(self, _name): + return self + + class _VectorColumn: + def op(self, _operator, return_type=None): + def _call(_query_vector): + return _DistanceExpr() + + return _call + + class _MetaFilter: + def in_(self, values): + return ("in", values) + + class _MetaColumn: + def __getitem__(self, _item): + return _MetaFilter() + + class _Stmt: + def __init__(self): + self.where_called = False + + def limit(self, _value): + return self + + def order_by(self, _value): + return self + + def where(self, _value): + self.where_called = True + return self + + stmt = _Stmt() + monkeypatch.setattr(module, "select", lambda *_args: stmt) + + vector._table = SimpleNamespace(vector=_VectorColumn(), meta=_MetaColumn()) + rows = [ + (SimpleNamespace(meta={"doc_id": "1"}, text="text-1"), 0.1), + (SimpleNamespace(meta={"doc_id": "2"}, text="text-2"), 0.8), + ] + monkeypatch.setattr(module, "Session", _session_factory(runtime_calls, execute_results=[rows])) + + docs = vector.search_by_vector([0.1, 0.2], top_k=2, score_threshold=0.5, document_ids_filter=["d-1"]) + assert len(docs) == 1 + assert docs[0].metadata["score"] == pytest.approx(0.9) + assert stmt.where_called is True + assert vector.search_by_full_text("hello") == [] + + +def test_factory_uses_existing_or_generated_collection(pgvecto_module, monkeypatch): + module, _ = pgvecto_module + factory = module.PGVectoRSFactory() + dataset_with_index = SimpleNamespace( + id="dataset-1", + index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}}, + index_struct=None, + ) + dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None) + + monkeypatch.setattr(module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION") + monkeypatch.setattr(module.dify_config, "PGVECTO_RS_HOST", "localhost") + monkeypatch.setattr(module.dify_config, "PGVECTO_RS_PORT", 5432) + monkeypatch.setattr(module.dify_config, "PGVECTO_RS_USER", "postgres") + monkeypatch.setattr(module.dify_config, "PGVECTO_RS_PASSWORD", "secret") + monkeypatch.setattr(module.dify_config, "PGVECTO_RS_DATABASE", "postgres") + + embeddings = MagicMock() + embeddings.embed_query.return_value = [0.1, 0.2, 0.3] + + with patch.object(module, "PGVectoRS", return_value="vector") as vector_cls: + result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=embeddings) + result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=embeddings) + + assert result_1 == "vector" + assert result_2 == "vector" + assert vector_cls.call_args_list[0].kwargs["collection_name"] == "existing_collection" + assert vector_cls.call_args_list[1].kwargs["collection_name"] == "auto_collection" + assert dataset_without_index.index_struct is not None diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/pgvector/test_pgvector.py b/api/tests/unit_tests/core/rag/datasource/vdb/pgvector/test_pgvector.py index 4998a9858f..7505262eb7 100644 --- a/api/tests/unit_tests/core/rag/datasource/vdb/pgvector/test_pgvector.py +++ b/api/tests/unit_tests/core/rag/datasource/vdb/pgvector/test_pgvector.py @@ -1,16 +1,19 @@ -import unittest +from contextlib import contextmanager +from types import SimpleNamespace from unittest.mock import MagicMock, patch import pytest +import core.rag.datasource.vdb.pgvector.pgvector as pgvector_module from core.rag.datasource.vdb.pgvector.pgvector import ( PGVector, PGVectorConfig, ) +from core.rag.models.document import Document -class TestPGVector(unittest.TestCase): - def setUp(self): +class TestPGVector: + def setup_method(self, method): self.config = PGVectorConfig( host="localhost", port=5432, @@ -323,5 +326,172 @@ def test_config_validation_parametrized(invalid_config_override): PGVectorConfig(**config) -if __name__ == "__main__": - unittest.main() +def test_create_delegates_collection_creation_and_insert(): + vector = PGVector.__new__(PGVector) + vector._create_collection = MagicMock() + vector.add_texts = MagicMock(return_value=["doc-a"]) + docs = [Document(page_content="hello", metadata={"doc_id": "doc-a"})] + + result = vector.create(docs, [[0.1, 0.2]]) + + assert result == ["doc-a"] + vector._create_collection.assert_called_once_with(2) + vector.add_texts.assert_called_once_with(docs, [[0.1, 0.2]]) + + +def test_add_texts_uses_execute_values_and_returns_ids(monkeypatch): + vector = PGVector.__new__(PGVector) + vector.table_name = "embedding_collection_1" + + cursor = MagicMock() + + @contextmanager + def _cursor_ctx(): + yield cursor + + vector._get_cursor = _cursor_ctx + monkeypatch.setattr(pgvector_module.uuid, "uuid4", lambda: "generated-uuid") + execute_values = MagicMock() + monkeypatch.setattr(pgvector_module.psycopg2.extras, "execute_values", execute_values) + + docs = [ + Document(page_content="a", metadata={"doc_id": "doc-a"}), + Document(page_content="b", metadata={"document_id": "doc-b"}), + SimpleNamespace(page_content="c", metadata=None), + ] + ids = vector.add_texts(docs, [[0.1], [0.2], [0.3]]) + + assert ids == ["doc-a", "generated-uuid"] + execute_values.assert_called_once() + + +def test_text_get_and_delete_methods(): + vector = PGVector.__new__(PGVector) + vector.table_name = "embedding_collection_1" + cursor = MagicMock() + cursor.fetchone.return_value = ("id-1",) + cursor.__iter__.return_value = iter([({"doc_id": "1"}, "text-1"), ({"doc_id": "2"}, "text-2")]) + + @contextmanager + def _cursor_ctx(): + yield cursor + + vector._get_cursor = _cursor_ctx + + assert vector.text_exists("id-1") is True + docs = vector.get_by_ids(["id-1", "id-2"]) + assert len(docs) == 2 + assert docs[0].page_content == "text-1" + + vector.delete_by_metadata_field("document_id", "doc-1") + vector.delete() + executed_sql = [call.args[0] for call in cursor.execute.call_args_list] + assert any("meta->>%s = %s" in sql for sql in executed_sql) + assert any("DROP TABLE IF EXISTS embedding_collection_1" in sql for sql in executed_sql) + + +def test_delete_by_ids_handles_empty_undefined_table_and_generic_exception(monkeypatch): + vector = PGVector.__new__(PGVector) + vector.table_name = "embedding_collection_1" + cursor = MagicMock() + + @contextmanager + def _cursor_ctx(): + yield cursor + + vector._get_cursor = _cursor_ctx + vector.delete_by_ids([]) + cursor.execute.assert_not_called() + + class _UndefinedTableError(Exception): + pass + + monkeypatch.setattr(pgvector_module.psycopg2.errors, "UndefinedTable", _UndefinedTableError) + cursor.execute.side_effect = _UndefinedTableError("missing") + vector.delete_by_ids(["doc-1"]) + + cursor.execute.side_effect = RuntimeError("boom") + with pytest.raises(RuntimeError, match="boom"): + vector.delete_by_ids(["doc-1"]) + + +def test_search_by_vector_supports_filter_and_threshold(): + vector = PGVector.__new__(PGVector) + vector.table_name = "embedding_collection_1" + cursor = MagicMock() + cursor.__iter__.return_value = iter([({"doc_id": "1"}, "text-1", 0.1), ({"doc_id": "2"}, "text-2", 0.8)]) + + @contextmanager + def _cursor_ctx(): + yield cursor + + vector._get_cursor = _cursor_ctx + + with pytest.raises(ValueError, match="top_k must be a positive integer"): + vector.search_by_vector([0.1], top_k=0) + + docs = vector.search_by_vector([0.1, 0.2], top_k=2, score_threshold=0.5, document_ids_filter=["d-1"]) + assert len(docs) == 1 + assert docs[0].metadata["score"] == pytest.approx(0.9) + sql = cursor.execute.call_args.args[0] + assert "meta->>'document_id' in ('d-1')" in sql + + +def test_search_by_full_text_branches_for_bigm_and_standard(): + vector = PGVector.__new__(PGVector) + vector.table_name = "embedding_collection_1" + cursor = MagicMock() + cursor.__iter__.return_value = iter([({"doc_id": "1"}, "text-1", 0.7)]) + + @contextmanager + def _cursor_ctx(): + yield cursor + + vector._get_cursor = _cursor_ctx + + with pytest.raises(ValueError, match="top_k must be a positive integer"): + vector.search_by_full_text("hello", top_k=0) + + vector.pg_bigm = False + docs = vector.search_by_full_text("hello world", top_k=2, document_ids_filter=["d-1"]) + assert len(docs) == 1 + assert docs[0].metadata["score"] == pytest.approx(0.7) + standard_sql = cursor.execute.call_args.args[0] + assert "to_tsvector(text) @@ plainto_tsquery(%s)" in standard_sql + + cursor.execute.reset_mock() + cursor.__iter__.return_value = iter([({"doc_id": "2"}, "text-2", 0.6)]) + vector.pg_bigm = True + vector.search_by_full_text("hello world", top_k=2, document_ids_filter=["d-2"]) + assert "SET pg_bigm.similarity_limit TO 0.000001" in cursor.execute.call_args_list[0].args[0] + assert "bigm_similarity" in cursor.execute.call_args_list[1].args[0] + + +def test_pgvector_factory_initializes_expected_collection_name(monkeypatch): + factory = pgvector_module.PGVectorFactory() + dataset_with_index = SimpleNamespace( + id="dataset-1", + index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}}, + index_struct=None, + ) + dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None) + + monkeypatch.setattr(pgvector_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION") + monkeypatch.setattr(pgvector_module.dify_config, "PGVECTOR_HOST", "localhost") + monkeypatch.setattr(pgvector_module.dify_config, "PGVECTOR_PORT", 5432) + monkeypatch.setattr(pgvector_module.dify_config, "PGVECTOR_USER", "postgres") + monkeypatch.setattr(pgvector_module.dify_config, "PGVECTOR_PASSWORD", "secret") + monkeypatch.setattr(pgvector_module.dify_config, "PGVECTOR_DATABASE", "postgres") + monkeypatch.setattr(pgvector_module.dify_config, "PGVECTOR_MIN_CONNECTION", 1) + monkeypatch.setattr(pgvector_module.dify_config, "PGVECTOR_MAX_CONNECTION", 5) + monkeypatch.setattr(pgvector_module.dify_config, "PGVECTOR_PG_BIGM", False) + + with patch.object(pgvector_module, "PGVector", return_value="vector") as vector_cls: + result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock()) + result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock()) + + assert result_1 == "vector" + assert result_2 == "vector" + assert vector_cls.call_args_list[0].kwargs["collection_name"] == "EXISTING_COLLECTION" + assert vector_cls.call_args_list[1].kwargs["collection_name"] == "AUTO_COLLECTION" + assert dataset_without_index.index_struct is not None diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/pyvastbase/test_vastbase_vector.py b/api/tests/unit_tests/core/rag/datasource/vdb/pyvastbase/test_vastbase_vector.py new file mode 100644 index 0000000000..bd8df520ba --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/pyvastbase/test_vastbase_vector.py @@ -0,0 +1,269 @@ +import importlib +import sys +import types +from contextlib import contextmanager +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from pydantic import ValidationError + +from core.rag.models.document import Document + + +def _build_fake_psycopg2_modules(): + psycopg2 = types.ModuleType("psycopg2") + psycopg2.__path__ = [] + psycopg2_extras = types.ModuleType("psycopg2.extras") + psycopg2_pool = types.ModuleType("psycopg2.pool") + + class SimpleConnectionPool: + def __init__(self, *args, **kwargs): + self.args = args + self.kwargs = kwargs + self.getconn = MagicMock() + self.putconn = MagicMock() + + psycopg2_pool.SimpleConnectionPool = SimpleConnectionPool + psycopg2_extras.execute_values = MagicMock() + psycopg2.pool = psycopg2_pool + psycopg2.extras = psycopg2_extras + + return { + "psycopg2": psycopg2, + "psycopg2.pool": psycopg2_pool, + "psycopg2.extras": psycopg2_extras, + } + + +@pytest.fixture +def vastbase_module(monkeypatch): + for name, module in _build_fake_psycopg2_modules().items(): + monkeypatch.setitem(sys.modules, name, module) + + import core.rag.datasource.vdb.pyvastbase.vastbase_vector as module + + return importlib.reload(module) + + +def _config(module): + return module.VastbaseVectorConfig( + host="localhost", + port=5432, + user="dify", + password="secret", + database="dify", + min_connection=1, + max_connection=5, + ) + + +@pytest.mark.parametrize( + ("field", "value", "message"), + [ + ("host", "", "config VASTBASE_HOST is required"), + ("port", 0, "config VASTBASE_PORT is required"), + ("user", "", "config VASTBASE_USER is required"), + ("password", "", "config VASTBASE_PASSWORD is required"), + ("database", "", "config VASTBASE_DATABASE is required"), + ("min_connection", 0, "config VASTBASE_MIN_CONNECTION is required"), + ("max_connection", 0, "config VASTBASE_MAX_CONNECTION is required"), + ], +) +def test_vastbase_config_validation(vastbase_module, field, value, message): + values = _config(vastbase_module).model_dump() + values[field] = value + + with pytest.raises(ValidationError, match=message): + vastbase_module.VastbaseVectorConfig.model_validate(values) + + +def test_vastbase_config_rejects_invalid_connection_window(vastbase_module): + with pytest.raises(ValidationError, match="VASTBASE_MIN_CONNECTION should less than VASTBASE_MAX_CONNECTION"): + vastbase_module.VastbaseVectorConfig.model_validate( + { + "host": "localhost", + "port": 5432, + "user": "dify", + "password": "secret", + "database": "dify", + "min_connection": 6, + "max_connection": 5, + } + ) + + +def test_init_and_get_cursor_context_manager(vastbase_module, monkeypatch): + pool = MagicMock() + monkeypatch.setattr(vastbase_module.psycopg2.pool, "SimpleConnectionPool", MagicMock(return_value=pool)) + + conn = MagicMock() + cur = MagicMock() + pool.getconn.return_value = conn + conn.cursor.return_value = cur + + vector = vastbase_module.VastbaseVector("collection_1", _config(vastbase_module)) + assert vector.get_type() == "vastbase" + assert vector.table_name == "embedding_collection_1" + + with vector._get_cursor() as got_cur: + assert got_cur is cur + + cur.close.assert_called_once() + conn.commit.assert_called_once() + pool.putconn.assert_called_once_with(conn) + + +def test_create_and_add_texts(vastbase_module, monkeypatch): + vector = vastbase_module.VastbaseVector.__new__(vastbase_module.VastbaseVector) + vector.table_name = "embedding_collection_1" + vector._create_collection = MagicMock() + + cursor = MagicMock() + + @contextmanager + def _cursor_ctx(): + yield cursor + + vector._get_cursor = _cursor_ctx + monkeypatch.setattr(vastbase_module.uuid, "uuid4", lambda: "generated-uuid") + + docs = [ + Document(page_content="a", metadata={"doc_id": "doc-a"}), + Document(page_content="b", metadata={"document_id": "doc-b"}), + SimpleNamespace(page_content="c", metadata=None), + ] + + ids = vector.add_texts(docs, [[0.1], [0.2], [0.3]]) + assert ids == ["doc-a", "generated-uuid"] + vastbase_module.psycopg2.extras.execute_values.assert_called_once() + + vector.add_texts = MagicMock(return_value=["doc-a"]) + result = vector.create(docs, [[0.1], [0.2], [0.3]]) + vector._create_collection.assert_called_once_with(1) + assert result == ["doc-a"] + + +def test_text_get_delete_and_metadata_methods(vastbase_module): + vector = vastbase_module.VastbaseVector.__new__(vastbase_module.VastbaseVector) + vector.table_name = "embedding_collection_1" + cursor = MagicMock() + cursor.fetchone.return_value = ("id-1",) + cursor.__iter__.return_value = iter([({"doc_id": "1"}, "text-1"), ({"doc_id": "2"}, "text-2")]) + + @contextmanager + def _cursor_ctx(): + yield cursor + + vector._get_cursor = _cursor_ctx + + assert vector.text_exists("id-1") is True + docs = vector.get_by_ids(["id-1", "id-2"]) + assert len(docs) == 2 + assert docs[0].page_content == "text-1" + + vector.delete_by_ids([]) + vector.delete_by_ids(["id-1"]) + vector.delete_by_metadata_field("document_id", "doc-1") + vector.delete() + executed_sql = [call.args[0] for call in cursor.execute.call_args_list] + assert any("DELETE FROM embedding_collection_1 WHERE id IN %s" in sql for sql in executed_sql) + assert any("meta->>%s = %s" in sql for sql in executed_sql) + assert any("DROP TABLE IF EXISTS embedding_collection_1" in sql for sql in executed_sql) + + +def test_search_by_vector_and_full_text(vastbase_module): + vector = vastbase_module.VastbaseVector.__new__(vastbase_module.VastbaseVector) + vector.table_name = "embedding_collection_1" + cursor = MagicMock() + cursor.__iter__.return_value = iter( + [ + ({"doc_id": "1"}, "text-1", 0.1), + ({"doc_id": "2"}, "text-2", 0.8), + ] + ) + + @contextmanager + def _cursor_ctx(): + yield cursor + + vector._get_cursor = _cursor_ctx + + with pytest.raises(ValueError, match="top_k must be a positive integer"): + vector.search_by_vector([0.1, 0.2], top_k=0) + + docs = vector.search_by_vector([0.1, 0.2], top_k=2, score_threshold=0.5) + assert len(docs) == 1 + assert docs[0].metadata["score"] == pytest.approx(0.9) + + with pytest.raises(ValueError, match="top_k must be a positive integer"): + vector.search_by_full_text("hello", top_k=0) + + cursor.__iter__.return_value = iter([({"doc_id": "3"}, "full-text", 0.7)]) + full_docs = vector.search_by_full_text("hello world", top_k=2) + assert len(full_docs) == 1 + assert full_docs[0].page_content == "full-text" + + +def test_create_collection_cache_and_dimension_branches(vastbase_module, monkeypatch): + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(vastbase_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(vastbase_module.redis_client, "set", MagicMock()) + + vector = vastbase_module.VastbaseVector.__new__(vastbase_module.VastbaseVector) + vector._collection_name = "collection_1" + vector.table_name = "embedding_collection_1" + cursor = MagicMock() + + @contextmanager + def _cursor_ctx(): + yield cursor + + vector._get_cursor = _cursor_ctx + + monkeypatch.setattr(vastbase_module.redis_client, "get", MagicMock(return_value=1)) + vector._create_collection(3) + cursor.execute.assert_not_called() + + monkeypatch.setattr(vastbase_module.redis_client, "get", MagicMock(return_value=None)) + vector._create_collection(17000) + executed_sql = [call.args[0] for call in cursor.execute.call_args_list] + assert any("CREATE TABLE IF NOT EXISTS embedding_collection_1" in sql for sql in executed_sql) + assert all("embedding_cosine_v1_idx" not in sql for sql in executed_sql) + + cursor.execute.reset_mock() + vector._create_collection(3) + executed_sql = [call.args[0] for call in cursor.execute.call_args_list] + assert any("embedding_cosine_v1_idx" in sql for sql in executed_sql) + vastbase_module.redis_client.set.assert_called() + + +def test_vastbase_factory_uses_existing_or_generated_collection(vastbase_module, monkeypatch): + factory = vastbase_module.VastbaseVectorFactory() + dataset_with_index = SimpleNamespace( + id="dataset-1", + index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}}, + index_struct=None, + ) + dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None) + + monkeypatch.setattr(vastbase_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION") + monkeypatch.setattr(vastbase_module.dify_config, "VASTBASE_HOST", "localhost") + monkeypatch.setattr(vastbase_module.dify_config, "VASTBASE_PORT", 5432) + monkeypatch.setattr(vastbase_module.dify_config, "VASTBASE_USER", "dify") + monkeypatch.setattr(vastbase_module.dify_config, "VASTBASE_PASSWORD", "secret") + monkeypatch.setattr(vastbase_module.dify_config, "VASTBASE_DATABASE", "dify") + monkeypatch.setattr(vastbase_module.dify_config, "VASTBASE_MIN_CONNECTION", 1) + monkeypatch.setattr(vastbase_module.dify_config, "VASTBASE_MAX_CONNECTION", 5) + + with patch.object(vastbase_module, "VastbaseVector", return_value="vector") as vector_cls: + result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock()) + result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock()) + + assert result_1 == "vector" + assert result_2 == "vector" + assert vector_cls.call_args_list[0].kwargs["collection_name"] == "EXISTING_COLLECTION" + assert vector_cls.call_args_list[1].kwargs["collection_name"] == "AUTO_COLLECTION" + assert dataset_without_index.index_struct is not None diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/qdrant/test_qdrant_vector.py b/api/tests/unit_tests/core/rag/datasource/vdb/qdrant/test_qdrant_vector.py new file mode 100644 index 0000000000..0408506563 --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/qdrant/test_qdrant_vector.py @@ -0,0 +1,328 @@ +import importlib +import os +import sys +import types +from collections import UserDict +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest + +from core.rag.models.document import Document + + +def _build_fake_qdrant_modules(): + qdrant_client = types.ModuleType("qdrant_client") + qdrant_http = types.ModuleType("qdrant_client.http") + qdrant_http_models = types.ModuleType("qdrant_client.http.models") + qdrant_http_exceptions = types.ModuleType("qdrant_client.http.exceptions") + qdrant_local_pkg = types.ModuleType("qdrant_client.local") + qdrant_local_mod = types.ModuleType("qdrant_client.local.qdrant_local") + + class UnexpectedResponseError(Exception): + def __init__(self, status_code): + super().__init__(f"status={status_code}") + self.status_code = status_code + + class FilterSelector: + def __init__(self, filter): + self.filter = filter + + class HnswConfigDiff: + def __init__(self, **kwargs): + self.kwargs = kwargs + + class TextIndexParams: + def __init__(self, **kwargs): + self.kwargs = kwargs + + class VectorParams: + def __init__(self, **kwargs): + self.kwargs = kwargs + + class PointStruct: + def __init__(self, **kwargs): + self.id = kwargs["id"] + self.vector = kwargs["vector"] + self.payload = kwargs["payload"] + + class Filter: + def __init__(self, must=None): + self.must = must or [] + + class FieldCondition: + def __init__(self, key, match): + self.key = key + self.match = match + + class MatchValue: + def __init__(self, value): + self.value = value + + class MatchAny: + def __init__(self, any): + self.any = any + + class MatchText: + def __init__(self, text): + self.text = text + + class _Distance(UserDict): + def __getitem__(self, key): + return key + + class QdrantClient: + def __init__(self, **kwargs): + self.kwargs = kwargs + self.get_collections = MagicMock(return_value=SimpleNamespace(collections=[])) + self.create_collection = MagicMock() + self.create_payload_index = MagicMock() + self.upsert = MagicMock() + self.delete = MagicMock() + self.delete_collection = MagicMock() + self.retrieve = MagicMock(return_value=[]) + self.search = MagicMock(return_value=[]) + self.scroll = MagicMock(return_value=([], None)) + + class QdrantLocal(QdrantClient): + def __init__(self, **kwargs): + super().__init__(**kwargs) + self._load = MagicMock() + + qdrant_client.QdrantClient = QdrantClient + qdrant_http_models.FilterSelector = FilterSelector + qdrant_http_models.HnswConfigDiff = HnswConfigDiff + qdrant_http_models.PayloadSchemaType = SimpleNamespace(KEYWORD="KEYWORD") + qdrant_http_models.TextIndexParams = TextIndexParams + qdrant_http_models.TextIndexType = SimpleNamespace(TEXT="TEXT") + qdrant_http_models.TokenizerType = SimpleNamespace(MULTILINGUAL="MULTILINGUAL") + qdrant_http_models.VectorParams = VectorParams + qdrant_http_models.Distance = _Distance() + qdrant_http_models.PointStruct = PointStruct + qdrant_http_models.Filter = Filter + qdrant_http_models.FieldCondition = FieldCondition + qdrant_http_models.MatchValue = MatchValue + qdrant_http_models.MatchAny = MatchAny + qdrant_http_models.MatchText = MatchText + qdrant_http_exceptions.UnexpectedResponse = UnexpectedResponseError + + qdrant_http.models = qdrant_http_models + qdrant_local_mod.QdrantLocal = QdrantLocal + qdrant_local_pkg.qdrant_local = qdrant_local_mod + + return { + "qdrant_client": qdrant_client, + "qdrant_client.http": qdrant_http, + "qdrant_client.http.models": qdrant_http_models, + "qdrant_client.http.exceptions": qdrant_http_exceptions, + "qdrant_client.local": qdrant_local_pkg, + "qdrant_client.local.qdrant_local": qdrant_local_mod, + } + + +@pytest.fixture +def qdrant_module(monkeypatch): + for name, module in _build_fake_qdrant_modules().items(): + monkeypatch.setitem(sys.modules, name, module) + + import core.rag.datasource.vdb.qdrant.qdrant_vector as module + + return importlib.reload(module) + + +def _config(module, **overrides): + values = { + "endpoint": "http://localhost:6333", + "api_key": "api-key", + "timeout": 20, + "root_path": "/tmp", + "grpc_port": 6334, + "prefer_grpc": False, + "replication_factor": 1, + "write_consistency_factor": 1, + } + values.update(overrides) + return module.QdrantConfig.model_validate(values) + + +def test_qdrant_config_to_params(qdrant_module): + url_params = _config(qdrant_module).to_qdrant_params().model_dump() + assert url_params["url"] == "http://localhost:6333" + assert url_params["verify"] is False + + path_config = _config(qdrant_module, endpoint="path:storage") + assert path_config.to_qdrant_params().path == os.path.join("/tmp", "storage") + + with pytest.raises(ValueError, match="Root path is not set"): + _config(qdrant_module, endpoint="path:storage", root_path=None).to_qdrant_params() + + +def test_init_and_basic_behaviour(qdrant_module): + vector = qdrant_module.QdrantVector("collection_1", "group-1", _config(qdrant_module)) + assert vector.get_type() == qdrant_module.VectorType.QDRANT + assert vector.to_index_struct()["vector_store"]["class_prefix"] == "collection_1" + + docs = [Document(page_content="a", metadata={"doc_id": "a"})] + vector.create_collection = MagicMock() + vector.add_texts = MagicMock() + vector.create(docs, [[0.1]]) + vector.create_collection.assert_called_once_with("collection_1", 1) + vector.add_texts.assert_called_once() + + +def test_create_collection_and_add_texts(qdrant_module, monkeypatch): + vector = qdrant_module.QdrantVector("collection_1", "group-1", _config(qdrant_module)) + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(qdrant_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(qdrant_module.redis_client, "set", MagicMock()) + + monkeypatch.setattr(qdrant_module.redis_client, "get", MagicMock(return_value=1)) + vector.create_collection("collection_1", 3) + vector._client.create_collection.assert_not_called() + + monkeypatch.setattr(qdrant_module.redis_client, "get", MagicMock(return_value=None)) + vector._client.get_collections.return_value = SimpleNamespace(collections=[]) + vector.create_collection("collection_1", 3) + vector._client.create_collection.assert_called_once() + assert vector._client.create_payload_index.call_count == 4 + qdrant_module.redis_client.set.assert_called_once() + + # add_texts and generated batches + docs = [ + Document(page_content="a", metadata={"doc_id": "id-1"}), + Document(page_content="b", metadata={"doc_id": "id-2"}), + ] + ids = vector.add_texts(docs, [[0.1], [0.2]]) + assert ids == ["id-1", "id-2"] + assert vector._client.upsert.call_count == 1 + + payloads = qdrant_module.QdrantVector._build_payloads( + ["a"], [{"doc_id": "id-1"}], "content", "metadata", "g1", "group_id" + ) + assert payloads[0]["group_id"] == "g1" + with pytest.raises(ValueError, match="At least one of the texts is None"): + qdrant_module.QdrantVector._build_payloads( + [None], [{"doc_id": "id-1"}], "content", "metadata", "g1", "group_id" + ) + + +def test_delete_and_exists_paths(qdrant_module): + vector = qdrant_module.QdrantVector("collection_1", "group-1", _config(qdrant_module)) + unexpected = sys.modules["qdrant_client.http.exceptions"].UnexpectedResponse + + vector._client.delete.side_effect = unexpected(404) + vector.delete_by_metadata_field("document_id", "doc-1") + vector._client.delete.side_effect = None + + vector._client.delete.side_effect = unexpected(500) + with pytest.raises(unexpected): + vector.delete_by_metadata_field("document_id", "doc-1") + vector._client.delete.side_effect = None + + vector._client.delete.side_effect = unexpected(404) + vector.delete() + vector._client.delete.side_effect = unexpected(500) + with pytest.raises(unexpected): + vector.delete() + vector._client.delete.side_effect = None + + vector._client.delete.side_effect = unexpected(404) + vector.delete_by_ids(["doc-1"]) + vector._client.delete.side_effect = unexpected(500) + with pytest.raises(unexpected): + vector.delete_by_ids(["doc-1"]) + vector._client.delete.side_effect = None + + vector._client.get_collections.return_value = SimpleNamespace(collections=[SimpleNamespace(name="other")]) + assert vector.text_exists("id-1") is False + vector._client.get_collections.return_value = SimpleNamespace(collections=[SimpleNamespace(name="collection_1")]) + vector._client.retrieve.return_value = [{"id": "id-1"}] + assert vector.text_exists("id-1") is True + + +def test_search_and_helper_methods(qdrant_module): + vector = qdrant_module.QdrantVector("collection_1", "group-1", _config(qdrant_module)) + assert vector.search_by_vector([0.1], score_threshold=1.0) == [] + + vector._client.search.return_value = [ + SimpleNamespace(payload=None, score=0.9, vector=[0.1]), + SimpleNamespace(payload={"metadata": {"doc_id": "1"}, "page_content": "doc-a"}, score=0.8, vector=[0.1]), + ] + docs = vector.search_by_vector([0.1], top_k=2, score_threshold=0.5, document_ids_filter=["d-1"]) + assert len(docs) == 1 + assert docs[0].metadata["score"] == pytest.approx(0.8) + + # full text search: keyword split, dedup and top_k limit + scroll_results = [ + ( + [ + SimpleNamespace(id="p1", payload={"page_content": "doc-1", "metadata": {"doc_id": "1"}}, vector=[0.1]), + SimpleNamespace(id="p2", payload={"page_content": "doc-2", "metadata": {"doc_id": "2"}}, vector=[0.2]), + ], + None, + ), + ( + [ + SimpleNamespace(id="p2", payload={"page_content": "doc-2", "metadata": {"doc_id": "2"}}, vector=[0.2]), + ], + None, + ), + ] + vector._client.scroll.side_effect = scroll_results + docs = vector.search_by_full_text("hello world", top_k=2, document_ids_filter=["d-1"]) + assert len(docs) == 2 + assert vector.search_by_full_text(" ", top_k=2) == [] + + local_client = qdrant_module.QdrantLocal() + vector._client = local_client + vector._reload_if_needed() + local_client._load.assert_called_once() + + doc = vector._document_from_scored_point( + SimpleNamespace(payload={"page_content": "doc", "metadata": {"doc_id": "1"}}, vector=[0.1]), + "page_content", + "metadata", + ) + assert doc.page_content == "doc" + + +def test_qdrant_factory_paths(qdrant_module, monkeypatch): + factory = qdrant_module.QdrantVectorFactory() + dataset = SimpleNamespace( + id="dataset-1", + tenant_id="tenant-1", + collection_binding_id=None, + index_struct_dict=None, + index_struct=None, + ) + monkeypatch.setattr(qdrant_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION") + monkeypatch.setattr(qdrant_module, "current_app", SimpleNamespace(config=SimpleNamespace(root_path="/root"))) + monkeypatch.setattr(qdrant_module.dify_config, "QDRANT_URL", "http://localhost:6333") + monkeypatch.setattr(qdrant_module.dify_config, "QDRANT_API_KEY", "api-key") + monkeypatch.setattr(qdrant_module.dify_config, "QDRANT_CLIENT_TIMEOUT", 20) + monkeypatch.setattr(qdrant_module.dify_config, "QDRANT_GRPC_PORT", 6334) + monkeypatch.setattr(qdrant_module.dify_config, "QDRANT_GRPC_ENABLED", False) + monkeypatch.setattr(qdrant_module.dify_config, "QDRANT_REPLICATION_FACTOR", 1) + + with patch.object(qdrant_module, "QdrantVector", return_value="vector") as vector_cls: + result = factory.init_vector(dataset, attributes=[], embeddings=MagicMock()) + assert result == "vector" + assert vector_cls.call_args.kwargs["collection_name"] == "AUTO_COLLECTION" + assert dataset.index_struct is not None + + # collection binding lookup path + dataset.collection_binding_id = "binding-1" + dataset.index_struct_dict = {"vector_store": {"class_prefix": "existing"}} + monkeypatch.setattr(qdrant_module, "select", lambda _model: SimpleNamespace(where=lambda *_args: "stmt")) + qdrant_module.db.session.scalars = MagicMock( + return_value=SimpleNamespace(one_or_none=lambda: SimpleNamespace(collection_name="BOUND_COLLECTION")) + ) + with patch.object(qdrant_module, "QdrantVector", return_value="vector") as vector_cls: + factory.init_vector(dataset, attributes=[], embeddings=MagicMock()) + assert vector_cls.call_args.kwargs["collection_name"] == "BOUND_COLLECTION" + + qdrant_module.db.session.scalars = MagicMock(return_value=SimpleNamespace(one_or_none=lambda: None)) + with pytest.raises(ValueError, match="Dataset Collection Bindings does not exist"): + factory.init_vector(dataset, attributes=[], embeddings=MagicMock()) diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/relyt/test_relyt_vector.py b/api/tests/unit_tests/core/rag/datasource/vdb/relyt/test_relyt_vector.py new file mode 100644 index 0000000000..ca8cd5e514 --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/relyt/test_relyt_vector.py @@ -0,0 +1,303 @@ +import importlib +import sys +import types +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from pydantic import ValidationError +from sqlalchemy.types import UserDefinedType + +from core.rag.models.document import Document + + +def _build_fake_relyt_modules(): + pgvecto_rs = types.ModuleType("pgvecto_rs") + pgvecto_rs_sqlalchemy = types.ModuleType("pgvecto_rs.sqlalchemy") + + class VECTOR(UserDefinedType): + def __init__(self, dim): + self.dim = dim + + pgvecto_rs_sqlalchemy.VECTOR = VECTOR + return { + "pgvecto_rs": pgvecto_rs, + "pgvecto_rs.sqlalchemy": pgvecto_rs_sqlalchemy, + } + + +class _FakeSession: + def __init__(self, execute_result=None): + self.execute_result = execute_result or MagicMock(fetchall=lambda: []) + self.execute = MagicMock(return_value=self.execute_result) + self.commit = MagicMock() + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return None + + +@pytest.fixture +def relyt_module(monkeypatch): + for name, module in _build_fake_relyt_modules().items(): + monkeypatch.setitem(sys.modules, name, module) + + import core.rag.datasource.vdb.relyt.relyt_vector as module + + return importlib.reload(module) + + +def _config(module, **overrides): + values = { + "host": "localhost", + "port": 5432, + "user": "postgres", + "password": "secret", + "database": "relyt", + } + values.update(overrides) + return module.RelytConfig.model_validate(values) + + +@pytest.mark.parametrize( + ("field", "value", "message"), + [ + ("host", "", "config RELYT_HOST is required"), + ("port", 0, "config RELYT_PORT is required"), + ("user", "", "config RELYT_USER is required"), + ("password", "", "config RELYT_PASSWORD is required"), + ("database", "", "config RELYT_DATABASE is required"), + ], +) +def test_relyt_config_validation(relyt_module, field, value, message): + values = _config(relyt_module).model_dump() + values[field] = value + with pytest.raises(ValidationError, match=message): + relyt_module.RelytConfig.model_validate(values) + + +def test_init_get_type_and_create_delegate(relyt_module, monkeypatch): + engine = MagicMock() + monkeypatch.setattr(relyt_module, "create_engine", MagicMock(return_value=engine)) + vector = relyt_module.RelytVector("collection_1", _config(relyt_module), group_id="group-1") + vector.create_collection = MagicMock() + vector.add_texts = MagicMock() + docs = [Document(page_content="hello", metadata={"doc_id": "seg-1"})] + + vector.create(docs, [[0.1, 0.2]]) + + assert vector.get_type() == relyt_module.VectorType.RELYT + assert vector._url == "postgresql+psycopg2://postgres:secret@localhost:5432/relyt" + assert vector.embedding_dimension == 2 + vector.create_collection.assert_called_once_with(2) + vector.add_texts.assert_called_once_with(docs, [[0.1, 0.2]]) + + +def test_create_collection_cache_and_sql_execution(relyt_module, monkeypatch): + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(relyt_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(relyt_module.redis_client, "set", MagicMock()) + + vector = relyt_module.RelytVector.__new__(relyt_module.RelytVector) + vector._collection_name = "collection_1" + vector.client = MagicMock() + + monkeypatch.setattr(relyt_module.redis_client, "get", MagicMock(return_value=1)) + session = _FakeSession() + monkeypatch.setattr(relyt_module, "Session", lambda _client: session) + vector.create_collection(3) + session.execute.assert_not_called() + + monkeypatch.setattr(relyt_module.redis_client, "get", MagicMock(return_value=None)) + session = _FakeSession() + monkeypatch.setattr(relyt_module, "Session", lambda _client: session) + vector.create_collection(3) + executed_sql = [str(call.args[0]) for call in session.execute.call_args_list] + assert any("DROP TABLE IF EXISTS" in sql for sql in executed_sql) + assert any("CREATE TABLE IF NOT EXISTS" in sql for sql in executed_sql) + assert any("CREATE INDEX" in sql for sql in executed_sql) + relyt_module.redis_client.set.assert_called_once() + + +def test_add_texts_and_metadata_queries(relyt_module, monkeypatch): + vector = relyt_module.RelytVector.__new__(relyt_module.RelytVector) + vector._collection_name = "collection_1" + vector._group_id = "group-1" + vector.client = MagicMock() + + begin_ctx = MagicMock() + begin_ctx.__enter__.return_value = None + begin_ctx.__exit__.return_value = None + conn = MagicMock() + conn.__enter__.return_value = conn + conn.__exit__.return_value = None + conn.begin.return_value = begin_ctx + vector.client.connect.return_value = conn + + monkeypatch.setattr(relyt_module.uuid, "uuid1", MagicMock(side_effect=["id-1", "id-2"])) + docs = [ + Document(page_content="a", metadata={"doc_id": "d-1"}), + Document(page_content="b", metadata={"doc_id": "d-2"}), + ] + ids = vector.add_texts(docs, [[0.1], [0.2]]) + + assert ids == ["id-1", "id-2"] + assert conn.execute.call_count >= 1 + first_insert_values = conn.execute.call_args.args[0].compile().params + assert "group_id" in str(first_insert_values) + + session = _FakeSession(execute_result=MagicMock(fetchall=lambda: [("id-a",), ("id-b",)])) + monkeypatch.setattr(relyt_module, "Session", lambda _client: session) + assert vector.get_ids_by_metadata_field("document_id", "doc-1") == ["id-a", "id-b"] + + session = _FakeSession(execute_result=MagicMock(fetchall=lambda: [])) + monkeypatch.setattr(relyt_module, "Session", lambda _client: session) + assert vector.get_ids_by_metadata_field("document_id", "doc-1") is None + + +# 1. delete_by_uuids: success and connect error +def test_delete_by_uuids_success_and_connect_error(relyt_module): + vector = relyt_module.RelytVector.__new__(relyt_module.RelytVector) + vector._collection_name = "collection_1" + vector.client = MagicMock() + vector.embedding_dimension = 3 + with pytest.raises(ValueError, match="No ids provided"): + vector.delete_by_uuids(None) + conn = MagicMock() + conn.__enter__.return_value = conn + conn.__exit__.return_value = None + begin_ctx = MagicMock() + begin_ctx.__enter__.return_value = None + begin_ctx.__exit__.return_value = None + conn.begin.return_value = begin_ctx + vector.client.connect.return_value = conn + assert vector.delete_by_uuids(["id-1"]) is True + vector.client.connect.side_effect = RuntimeError("boom") + assert vector.delete_by_uuids(["id-1"]) is False + + +# 2. delete_by_metadata_field calls delete_by_uuids +def test_delete_by_metadata_field_calls_delete_by_uuids(relyt_module): + vector = relyt_module.RelytVector.__new__(relyt_module.RelytVector) + vector._collection_name = "collection_1" + vector.client = MagicMock() + vector.embedding_dimension = 3 + vector.get_ids_by_metadata_field = MagicMock(return_value=["id-1"]) + vector.delete_by_uuids = MagicMock(return_value=True) + vector.delete_by_metadata_field("document_id", "doc-1") + vector.delete_by_uuids.assert_called_once_with(["id-1"]) + + +# 3. delete_by_ids translates to uuids +def test_delete_by_ids_translates_to_uuids(relyt_module, monkeypatch): + vector = relyt_module.RelytVector.__new__(relyt_module.RelytVector) + vector._collection_name = "collection_1" + vector.client = MagicMock() + vector.embedding_dimension = 3 + session = _FakeSession(execute_result=MagicMock(fetchall=lambda: [("uuid-1",), ("uuid-2",)])) + monkeypatch.setattr(relyt_module, "Session", lambda _client: session) + vector.delete_by_uuids = MagicMock(return_value=True) + vector.delete_by_ids(["doc-1", "doc-2"]) + vector.delete_by_uuids.assert_called_once_with(["uuid-1", "uuid-2"]) + + +# 4. text_exists True +def test_text_exists_true(relyt_module, monkeypatch): + vector = relyt_module.RelytVector.__new__(relyt_module.RelytVector) + vector._collection_name = "collection_1" + vector.client = MagicMock() + vector.embedding_dimension = 3 + session = _FakeSession(execute_result=MagicMock(fetchall=lambda: [("id-1",)])) + monkeypatch.setattr(relyt_module, "Session", lambda _client: session) + assert vector.text_exists("doc-1") is True + + +# 5. text_exists False +def test_text_exists_false(relyt_module, monkeypatch): + vector = relyt_module.RelytVector.__new__(relyt_module.RelytVector) + vector._collection_name = "collection_1" + vector.client = MagicMock() + vector.embedding_dimension = 3 + session = _FakeSession(execute_result=MagicMock(fetchall=lambda: [])) + monkeypatch.setattr(relyt_module, "Session", lambda _client: session) + assert vector.text_exists("doc-1") is False + + +# 6. similarity_search_with_score_by_vector returns Documents and scores +def test_similarity_search_with_score_by_vector(relyt_module): + vector = relyt_module.RelytVector.__new__(relyt_module.RelytVector) + vector._collection_name = "collection_1" + vector.client = MagicMock() + vector.embedding_dimension = 3 + result_rows = [ + SimpleNamespace(document="doc-a", metadata={"doc_id": "1"}, distance=0.1), + SimpleNamespace(document="doc-b", metadata={"doc_id": "2"}, distance=0.8), + ] + conn = MagicMock() + conn.__enter__.return_value = conn + conn.__exit__.return_value = None + conn.execute.return_value.fetchall.return_value = result_rows + vector.client.connect.return_value = conn + similarities = vector.similarity_search_with_score_by_vector([0.1, 0.2], k=2, filter={"document_id": ["d-1"]}) + assert len(similarities) == 2 + assert similarities[0][0].page_content == "doc-a" + + +# 7. search_by_vector filters by score and ids +def test_search_by_vector_filters_by_score_and_ids(relyt_module): + vector = relyt_module.RelytVector.__new__(relyt_module.RelytVector) + vector._collection_name = "collection_1" + vector.client = MagicMock() + vector.embedding_dimension = 3 + vector.similarity_search_with_score_by_vector = MagicMock( + return_value=[ + (Document(page_content="a", metadata={"doc_id": "1"}), 0.1), + (Document(page_content="b", metadata={}), 0.9), + ] + ) + docs = vector.search_by_vector([0.1], top_k=2, score_threshold=0.5, document_ids_filter=["d-1"]) + assert len(docs) == 1 + assert vector.search_by_full_text("query") == [] + + +# 8. delete commits session +def test_delete_commits_session(relyt_module, monkeypatch): + vector = relyt_module.RelytVector.__new__(relyt_module.RelytVector) + vector._collection_name = "collection_1" + vector.client = MagicMock() + vector.embedding_dimension = 3 + session = _FakeSession() + monkeypatch.setattr(relyt_module, "Session", lambda _client: session) + vector.delete() + session.commit.assert_called_once() + + +def test_relyt_factory_existing_and_generated_collection(relyt_module, monkeypatch): + factory = relyt_module.RelytVectorFactory() + dataset_with_index = SimpleNamespace( + id="dataset-1", + index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}}, + index_struct=None, + ) + dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None) + + monkeypatch.setattr(relyt_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION") + monkeypatch.setattr(relyt_module.dify_config, "RELYT_HOST", "localhost") + monkeypatch.setattr(relyt_module.dify_config, "RELYT_PORT", 5432) + monkeypatch.setattr(relyt_module.dify_config, "RELYT_USER", "postgres") + monkeypatch.setattr(relyt_module.dify_config, "RELYT_PASSWORD", "secret") + monkeypatch.setattr(relyt_module.dify_config, "RELYT_DATABASE", "relyt") + + with patch.object(relyt_module, "RelytVector", return_value="vector") as vector_cls: + result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock()) + result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock()) + + assert result_1 == "vector" + assert result_2 == "vector" + assert vector_cls.call_args_list[0].kwargs["collection_name"] == "EXISTING_COLLECTION" + assert vector_cls.call_args_list[1].kwargs["collection_name"] == "AUTO_COLLECTION" + assert dataset_without_index.index_struct is not None diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/tablestore/test_tablestore_vector.py b/api/tests/unit_tests/core/rag/datasource/vdb/tablestore/test_tablestore_vector.py new file mode 100644 index 0000000000..e3b6676d9b --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/tablestore/test_tablestore_vector.py @@ -0,0 +1,316 @@ +import importlib +import json +import sys +import types +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from pydantic import ValidationError + +from core.rag.models.document import Document + + +def _build_fake_tablestore_module(): + tablestore = types.ModuleType("tablestore") + + class _BatchGetRowRequest: + def __init__(self): + self.items = [] + + def add(self, item): + self.items.append(item) + + class _TableInBatchGetRowItem: + def __init__(self, table_name, rows_to_get, columns_to_get, _unused, _ver): + self.table_name = table_name + self.rows_to_get = rows_to_get + self.columns_to_get = columns_to_get + + class _Row: + def __init__(self, primary_key, attribute_columns=None): + self.primary_key = primary_key + self.attribute_columns = attribute_columns or [] + + class _Client: + def __init__(self, *_args): + self.list_table = MagicMock(return_value=[]) + self.create_table = MagicMock() + self.list_search_index = MagicMock(return_value=[]) + self.create_search_index = MagicMock() + self.delete_search_index = MagicMock() + self.delete_table = MagicMock() + self.put_row = MagicMock() + self.delete_row = MagicMock() + self.get_row = MagicMock(return_value=(None, None, None)) + self.batch_get_row = MagicMock() + self.search = MagicMock() + + tablestore.OTSClient = _Client + tablestore.BatchGetRowRequest = _BatchGetRowRequest + tablestore.TableInBatchGetRowItem = _TableInBatchGetRowItem + tablestore.Row = _Row + tablestore.TableMeta = lambda name, schema: ("table_meta", name, schema) + tablestore.TableOptions = lambda: ("table_options",) + tablestore.CapacityUnit = lambda read, write: ("capacity", read, write) + tablestore.ReservedThroughput = lambda cap: ("reserved", cap) + tablestore.FieldSchema = lambda *args, **kwargs: ("field", args, kwargs) + tablestore.VectorOptions = lambda **kwargs: ("vector_options", kwargs) + tablestore.SearchIndexMeta = lambda field_schemas: ("search_index_meta", field_schemas) + tablestore.SearchQuery = lambda query, **kwargs: SimpleNamespace(query=query, **kwargs) + tablestore.TermQuery = lambda key, value: ("term_query", key, value) + tablestore.ColumnsToGet = lambda **kwargs: ("columns_to_get", kwargs) + tablestore.KnnVectorQuery = lambda **kwargs: SimpleNamespace(**kwargs) + tablestore.TermsQuery = lambda key, values: ("terms_query", key, values) + tablestore.Sort = lambda **kwargs: ("sort", kwargs) + tablestore.ScoreSort = lambda **kwargs: ("score_sort", kwargs) + tablestore.BoolQuery = lambda **kwargs: SimpleNamespace(**kwargs) + tablestore.MatchQuery = lambda **kwargs: ("match_query", kwargs) + + tablestore.FieldType = SimpleNamespace(TEXT="TEXT", VECTOR="VECTOR", KEYWORD="KEYWORD") + tablestore.AnalyzerType = SimpleNamespace(MAXWORD="MAXWORD") + tablestore.VectorDataType = SimpleNamespace(VD_FLOAT_32="VD_FLOAT_32") + tablestore.VectorMetricType = SimpleNamespace(VM_COSINE="VM_COSINE") + tablestore.ColumnReturnType = SimpleNamespace(SPECIFIED="SPECIFIED", ALL_FROM_INDEX="ALL_FROM_INDEX") + tablestore.SortOrder = SimpleNamespace(DESC="DESC") + return tablestore + + +@pytest.fixture +def tablestore_module(monkeypatch): + fake_module = _build_fake_tablestore_module() + monkeypatch.setitem(sys.modules, "tablestore", fake_module) + + import core.rag.datasource.vdb.tablestore.tablestore_vector as module + + return importlib.reload(module) + + +def _config(module, **overrides): + values = { + "access_key_id": "ak", + "access_key_secret": "sk", + "instance_name": "instance", + "endpoint": "endpoint", + "normalize_full_text_bm25_score": False, + } + values.update(overrides) + return module.TableStoreConfig.model_validate(values) + + +@pytest.mark.parametrize( + ("field", "value", "message"), + [ + ("access_key_id", "", "config ACCESS_KEY_ID is required"), + ("access_key_secret", "", "config ACCESS_KEY_SECRET is required"), + ("instance_name", "", "config INSTANCE_NAME is required"), + ("endpoint", "", "config ENDPOINT is required"), + ], +) +def test_tablestore_config_validation(tablestore_module, field, value, message): + values = _config(tablestore_module).model_dump() + values[field] = value + with pytest.raises(ValidationError, match=message): + tablestore_module.TableStoreConfig.model_validate(values) + + +def test_init_and_basic_delegation(tablestore_module): + vector = tablestore_module.TableStoreVector("collection_1", _config(tablestore_module)) + assert vector.get_type() == tablestore_module.VectorType.TABLESTORE + assert vector._table_name == "collection_1" + assert vector._index_name == "collection_1_idx" + + vector._create_collection = MagicMock() + vector.add_texts = MagicMock() + docs = [Document(page_content="hello", metadata={"doc_id": "d-1"})] + vector.create(docs, [[0.1, 0.2]]) + vector._create_collection.assert_called_once_with(2) + vector.add_texts.assert_called_once_with(documents=docs, embeddings=[[0.1, 0.2]]) + + vector.create_collection([[0.1, 0.2]]) + assert vector._create_collection.call_count == 2 + + +def test_get_by_ids_text_exists_delete_and_wrappers(tablestore_module): + vector = tablestore_module.TableStoreVector("collection_1", _config(tablestore_module)) + + # get_by_ids + ok_item = SimpleNamespace( + is_ok=True, + row=SimpleNamespace( + attribute_columns=[("metadata", json.dumps({"doc_id": "1"}), None), ("page_content", "text-1", None)] + ), + ) + fail_item = SimpleNamespace(is_ok=False, row=None) + batch_resp = SimpleNamespace(get_result_by_table=lambda _table: [ok_item, fail_item]) + vector._tablestore_client.batch_get_row.return_value = batch_resp + docs = vector.get_by_ids(["id-1"]) + assert len(docs) == 1 + assert docs[0].page_content == "text-1" + + # text_exists + vector._tablestore_client.get_row.return_value = (None, object(), None) + assert vector.text_exists("id-1") is True + vector._tablestore_client.get_row.return_value = (None, None, None) + assert vector.text_exists("id-1") is False + + # delete wrappers + vector._delete_row = MagicMock() + vector.delete_by_ids([]) + vector._delete_row.assert_not_called() + vector.delete_by_ids(["id-1", "id-2"]) + assert vector._delete_row.call_count == 2 + + vector._search_by_metadata = MagicMock(return_value=["id-a"]) + assert vector.get_ids_by_metadata_field("document_id", "doc-1") == ["id-a"] + vector.delete_by_ids = MagicMock() + vector.delete_by_metadata_field("document_id", "doc-1") + vector.delete_by_ids.assert_called_once_with(["id-a"]) + + vector._search_by_vector = MagicMock(return_value=["vec-doc"]) + vector._search_by_full_text = MagicMock(return_value=["fts-doc"]) + assert vector.search_by_vector([0.1], top_k=2, score_threshold=0.5, document_ids_filter=["d-1"]) == ["vec-doc"] + assert vector.search_by_full_text("query", top_k=2, score_threshold=0.3, document_ids_filter=["d-1"]) == ["fts-doc"] + + vector._delete_table_if_exist = MagicMock() + vector.delete() + vector._delete_table_if_exist.assert_called_once() + + +def test_create_collection_and_table_index_lifecycle(tablestore_module, monkeypatch): + vector = tablestore_module.TableStoreVector("collection_1", _config(tablestore_module)) + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(tablestore_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(tablestore_module.redis_client, "set", MagicMock()) + + monkeypatch.setattr(tablestore_module.redis_client, "get", MagicMock(return_value=1)) + vector._create_table_if_not_exist = MagicMock() + vector._create_search_index_if_not_exist = MagicMock() + vector._create_collection(3) + vector._create_table_if_not_exist.assert_not_called() + + monkeypatch.setattr(tablestore_module.redis_client, "get", MagicMock(return_value=None)) + vector._create_collection(3) + vector._create_table_if_not_exist.assert_called_once() + vector._create_search_index_if_not_exist.assert_called_once_with(3) + tablestore_module.redis_client.set.assert_called_once() + + vector = tablestore_module.TableStoreVector("collection_2", _config(tablestore_module)) + vector._tablestore_client.list_table.return_value = ["collection_2"] + assert vector._create_table_if_not_exist() is None + vector._tablestore_client.list_table.return_value = [] + vector._create_table_if_not_exist() + vector._tablestore_client.create_table.assert_called_once() + + vector._tablestore_client.list_search_index.return_value = [("collection_2", "collection_2_idx")] + assert vector._create_search_index_if_not_exist(3) is None + vector._tablestore_client.list_search_index.return_value = [] + vector._create_search_index_if_not_exist(3) + vector._tablestore_client.create_search_index.assert_called_once() + + vector._tablestore_client.list_search_index.return_value = [("collection_2", "idx_a"), ("collection_2", "idx_b")] + vector._delete_table_if_exist() + assert vector._tablestore_client.delete_search_index.call_count == 2 + vector._tablestore_client.delete_table.assert_called_once_with("collection_2") + + vector._delete_search_index() + vector._tablestore_client.delete_search_index.assert_called_with("collection_2", "collection_2_idx") + + +def test_write_row_and_search_helpers(tablestore_module): + vector = tablestore_module.TableStoreVector("collection_1", _config(tablestore_module)) + + vector._write_row( + "id-1", + { + "page_content": "hello", + "vector": [0.1, 0.2], + "metadata": {"doc_id": "d-1", "document_id": "doc-1"}, + }, + ) + put_row_call = vector._tablestore_client.put_row.call_args + assert put_row_call.args[0] == "collection_1" + attrs = put_row_call.args[1].attribute_columns + assert any(item[0] == "metadata_tags" for item in attrs) + + vector._delete_row("id-1") + vector._tablestore_client.delete_row.assert_called_once() + + # metadata search pagination + first_page = SimpleNamespace(rows=[[(("id", "row-1"),)]], next_token=b"next") + second_page = SimpleNamespace(rows=[[(("id", "row-2"),)]], next_token=b"") + vector._tablestore_client.search.side_effect = [first_page, second_page] + ids = vector._search_by_metadata("document_id", "doc-1") + assert ids == ["row-1", "row-2"] + vector._tablestore_client.search.side_effect = None + + # vector search + hit1 = SimpleNamespace( + score=0.9, + row=( + None, + [("page_content", "doc-a"), ("metadata", json.dumps({"doc_id": "1"})), ("vector", json.dumps([0.1]))], + ), + ) + hit2 = SimpleNamespace( + score=0.2, + row=( + None, + [("page_content", "doc-b"), ("metadata", json.dumps({"doc_id": "2"})), ("vector", json.dumps([0.2]))], + ), + ) + vector._tablestore_client.search.return_value = SimpleNamespace(search_hits=[hit1, hit2]) + docs = vector._search_by_vector([0.1], document_ids_filter=["document_id=doc-1"], top_k=2, score_threshold=0.5) + assert len(docs) == 1 + assert docs[0].metadata["score"] == pytest.approx(0.9) + + assert tablestore_module.TableStoreVector._normalize_score_exp_decay(0) == pytest.approx(0.0) + assert tablestore_module.TableStoreVector._normalize_score_exp_decay(100) <= 1.0 + + # full text search with and without normalized score filter + vector._normalize_full_text_bm25_score = True + hit3 = SimpleNamespace( + score=10.0, row=(None, [("page_content", "doc-c"), ("metadata", json.dumps({"doc_id": "3"}))]) + ) + hit4 = SimpleNamespace( + score=0.1, row=(None, [("page_content", "doc-d"), ("metadata", json.dumps({"doc_id": "4"}))]) + ) + vector._tablestore_client.search.return_value = SimpleNamespace(search_hits=[hit3, hit4]) + docs = vector._search_by_full_text("query", document_ids_filter=["document_id=doc-1"], top_k=2, score_threshold=0.2) + assert len(docs) == 1 + assert "score" in docs[0].metadata + + vector._normalize_full_text_bm25_score = False + vector._tablestore_client.search.return_value = SimpleNamespace(search_hits=[hit3]) + docs = vector._search_by_full_text("query", document_ids_filter=None, top_k=2, score_threshold=0.0) + assert len(docs) == 1 + assert "score" not in docs[0].metadata + + +def test_tablestore_factory_uses_existing_or_generated_collection(tablestore_module, monkeypatch): + factory = tablestore_module.TableStoreVectorFactory() + dataset_with_index = SimpleNamespace( + id="dataset-1", + index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}}, + index_struct=None, + ) + dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None) + + monkeypatch.setattr(tablestore_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION") + monkeypatch.setattr(tablestore_module.dify_config, "TABLESTORE_ENDPOINT", "endpoint") + monkeypatch.setattr(tablestore_module.dify_config, "TABLESTORE_INSTANCE_NAME", "instance") + monkeypatch.setattr(tablestore_module.dify_config, "TABLESTORE_ACCESS_KEY_ID", "ak") + monkeypatch.setattr(tablestore_module.dify_config, "TABLESTORE_ACCESS_KEY_SECRET", "sk") + monkeypatch.setattr(tablestore_module.dify_config, "TABLESTORE_NORMALIZE_FULLTEXT_BM25_SCORE", True) + + with patch.object(tablestore_module, "TableStoreVector", return_value="vector") as vector_cls: + result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock()) + result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock()) + + assert result_1 == "vector" + assert result_2 == "vector" + assert vector_cls.call_args_list[0].kwargs["collection_name"] == "EXISTING_COLLECTION" + assert vector_cls.call_args_list[1].kwargs["collection_name"] == "AUTO_COLLECTION" + assert dataset_without_index.index_struct is not None diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/tencent/test_tencent_vector.py b/api/tests/unit_tests/core/rag/datasource/vdb/tencent/test_tencent_vector.py new file mode 100644 index 0000000000..d8f35a6019 --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/tencent/test_tencent_vector.py @@ -0,0 +1,309 @@ +import importlib +import sys +import types +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest + +from core.rag.models.document import Document + + +def _build_fake_tencent_modules(): + tcvdb_text = types.ModuleType("tcvdb_text") + tcvdb_text_encoder = types.ModuleType("tcvdb_text.encoder") + tcvectordb = types.ModuleType("tcvectordb") + tcvectordb_model = types.ModuleType("tcvectordb.model") + tcvectordb_document = types.ModuleType("tcvectordb.model.document") + tcvectordb_index = types.ModuleType("tcvectordb.model.index") + tcvectordb_enum = types.ModuleType("tcvectordb.model.enum") + + class _BM25Encoder: + def encode_texts(self, text): + return {"encoded_text": text} + + def encode_queries(self, query): + return {"encoded_query": query} + + @classmethod + def default(cls, _lang): + return cls() + + class VectorDBError(Exception): + def __init__(self, message): + super().__init__(message) + self.message = message + + class RPCVectorDBClient: + def __init__(self, **kwargs): + self.kwargs = kwargs + self.create_database_if_not_exists = MagicMock() + self.exists_collection = MagicMock(return_value=False) + self.describe_collection = MagicMock(return_value=SimpleNamespace(indexes=[])) + self.create_collection = MagicMock() + self.upsert = MagicMock() + self.query = MagicMock(return_value=[]) + self.delete = MagicMock() + self.search = MagicMock(return_value=[]) + self.hybrid_search = MagicMock(return_value=[]) + self.drop_collection = MagicMock() + + class _Document: + def __init__(self, **kwargs): + self.__dict__.update(kwargs) + + class _HNSWSearchParams: + def __init__(self, ef): + self.ef = ef + + class _AnnSearch: + def __init__(self, **kwargs): + self.kwargs = kwargs + + class _KeywordSearch: + def __init__(self, **kwargs): + self.kwargs = kwargs + + class _WeightedRerank: + def __init__(self, **kwargs): + self.kwargs = kwargs + + class _Filter: + @staticmethod + def in_(field, values): + return ("in", field, values) + + def __init__(self, condition): + self.condition = condition + + _Filter.In = staticmethod(_Filter.in_) + + class _HNSWParams: + def __init__(self, **kwargs): + self.kwargs = kwargs + + class _FilterIndex: + def __init__(self, *args): + self.args = args + + class _VectorIndex: + def __init__(self, *args): + self.args = args + + class _SparseIndex: + def __init__(self, **kwargs): + self.kwargs = kwargs + + tcvectordb_enum.IndexType = SimpleNamespace( + __members__={"HNSW": "HNSW", "PRIMARY_KEY": "PRIMARY_KEY", "FILTER": "FILTER", "SPARSE_INVERTED": "SPARSE"}, + PRIMARY_KEY="PRIMARY_KEY", + FILTER="FILTER", + SPARSE_INVERTED="SPARSE", + ) + tcvectordb_enum.MetricType = SimpleNamespace(__members__={"IP": "IP"}, IP="IP") + tcvectordb_enum.FieldType = SimpleNamespace(String="String", Json="Json", SparseVector="SparseVector") + + tcvectordb_document.Document = _Document + tcvectordb_document.HNSWSearchParams = _HNSWSearchParams + tcvectordb_document.AnnSearch = _AnnSearch + tcvectordb_document.Filter = _Filter + tcvectordb_document.KeywordSearch = _KeywordSearch + tcvectordb_document.WeightedRerank = _WeightedRerank + + tcvectordb_index.HNSWParams = _HNSWParams + tcvectordb_index.FilterIndex = _FilterIndex + tcvectordb_index.VectorIndex = _VectorIndex + tcvectordb_index.SparseIndex = _SparseIndex + + tcvdb_text_encoder.BM25Encoder = _BM25Encoder + + tcvectordb_model.document = tcvectordb_document + tcvectordb_model.enum = tcvectordb_enum + tcvectordb_model.index = tcvectordb_index + + tcvectordb.RPCVectorDBClient = RPCVectorDBClient + tcvectordb.VectorDBException = VectorDBError + + return { + "tcvdb_text": tcvdb_text, + "tcvdb_text.encoder": tcvdb_text_encoder, + "tcvectordb": tcvectordb, + "tcvectordb.model": tcvectordb_model, + "tcvectordb.model.document": tcvectordb_document, + "tcvectordb.model.index": tcvectordb_index, + "tcvectordb.model.enum": tcvectordb_enum, + } + + +@pytest.fixture +def tencent_module(monkeypatch): + for name, module in _build_fake_tencent_modules().items(): + monkeypatch.setitem(sys.modules, name, module) + + import core.rag.datasource.vdb.tencent.tencent_vector as module + + return importlib.reload(module) + + +def _config(module, **overrides): + values = { + "url": "http://vdb.local", + "api_key": "api-key", + "timeout": 30, + "username": "user", + "database": "db", + "index_type": "HNSW", + "metric_type": "IP", + "shard": 1, + "replicas": 2, + "max_upsert_batch_size": 2, + "enable_hybrid_search": False, + } + values.update(overrides) + return module.TencentConfig.model_validate(values) + + +def test_config_and_init_paths(tencent_module): + config = _config(tencent_module) + assert config.to_tencent_params()["url"] == "http://vdb.local" + + vector = tencent_module.TencentVector("collection_1", config) + assert vector.get_type() == tencent_module.VectorType.TENCENT + assert vector._client.kwargs["key"] == "api-key" + + vector._client.exists_collection.return_value = True + vector._client.describe_collection.return_value = SimpleNamespace( + indexes=[SimpleNamespace(name="vector", dimension=768), SimpleNamespace(name="sparse_vector", dimension=0)] + ) + vector._client_config.enable_hybrid_search = True + vector._load_collection() + assert vector._enable_hybrid_search is True + assert vector._dimension == 768 + + vector._client.describe_collection.return_value = SimpleNamespace( + indexes=[SimpleNamespace(name="vector", dimension=512)] + ) + vector._load_collection() + assert vector._enable_hybrid_search is False + + +def test_create_collection_branches(tencent_module, monkeypatch): + vector = tencent_module.TencentVector("collection_1", _config(tencent_module)) + + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(tencent_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(tencent_module.redis_client, "set", MagicMock()) + + monkeypatch.setattr(tencent_module.redis_client, "get", MagicMock(return_value=1)) + vector._create_collection(3) + vector._client.create_collection.assert_not_called() + + monkeypatch.setattr(tencent_module.redis_client, "get", MagicMock(return_value=None)) + vector._client.exists_collection.return_value = True + vector._create_collection(3) + vector._client.create_collection.assert_not_called() + + vector._client.exists_collection.return_value = False + vector._client_config.index_type = "UNKNOWN" + with pytest.raises(ValueError, match="unsupported index_type"): + vector._create_collection(3) + + vector._client_config.index_type = "HNSW" + vector._client_config.metric_type = "UNKNOWN" + with pytest.raises(ValueError, match="unsupported metric_type"): + vector._create_collection(3) + + vector._client_config.metric_type = "IP" + vector._client.create_collection.side_effect = [ + tencent_module.VectorDBException("fieldType:json unsupported"), + None, + ] + vector._enable_hybrid_search = True + vector._create_collection(3) + assert vector._client.create_collection.call_count == 2 + tencent_module.redis_client.set.assert_called_once() + vector._client.create_collection.side_effect = None + + +def test_create_add_delete_and_search_behaviour(tencent_module): + vector = tencent_module.TencentVector("collection_1", _config(tencent_module, enable_hybrid_search=True)) + vector._create_collection = MagicMock() + docs = [ + Document(page_content="text-a", metadata={"doc_id": "a", "document_id": "doc-a"}), + Document(page_content="text-b", metadata={"doc_id": "b", "document_id": "doc-b"}), + Document(page_content="text-c", metadata={"doc_id": "c", "document_id": "doc-c"}), + ] + embeddings = [[0.1], [0.2], [0.3]] + vector.create(docs, embeddings) + vector._create_collection.assert_called_once_with(1) + + vector._client.upsert.reset_mock() + vector.add_texts(docs, embeddings) + assert vector._client.upsert.call_count == 2 + first_docs = vector._client.upsert.call_args_list[0].kwargs["documents"] + assert "sparse_vector" in first_docs[0].__dict__ + + vector._client.query.return_value = [{"id": "a"}] + assert vector.text_exists("a") is True + vector._client.query.return_value = [] + assert vector.text_exists("a") is False + + vector.delete_by_ids([]) + vector._client.delete.assert_not_called() + vector.delete_by_ids(["a", "b", "c"]) + assert vector._client.delete.call_count == 2 + vector.delete_by_metadata_field("document_id", "doc-a") + assert vector._client.delete.call_count >= 3 + + vector._client.search.return_value = [[{"metadata": {"doc_id": "1"}, "text": "vec-doc", "score": 0.9}]] + vec_docs = vector.search_by_vector([0.1], top_k=2, score_threshold=0.5, document_ids_filter=["doc-a"]) + assert len(vec_docs) == 1 + assert vec_docs[0].metadata["score"] == pytest.approx(0.9) + + vector._enable_hybrid_search = False + assert vector.search_by_full_text("query") == [] + vector._enable_hybrid_search = True + vector._client.hybrid_search.return_value = [[{"metadata": {"doc_id": "2"}, "text": "fts-doc", "score": 0.8}]] + fts_docs = vector.search_by_full_text("query", top_k=2, score_threshold=0.5, document_ids_filter=["doc-a"]) + assert len(fts_docs) == 1 + + # _get_search_res handles old string metadata format + compat_docs = vector._get_search_res([[{"metadata": '{"doc_id": "3"}', "text": "compat", "score": 0.2}]], 0.5) + assert len(compat_docs) == 1 + assert compat_docs[0].metadata["score"] == pytest.approx(0.8) + + vector._has_collection = MagicMock(return_value=True) + vector.delete() + vector._client.drop_collection.assert_called_once() + + +def test_tencent_factory_existing_and_generated_collection(tencent_module, monkeypatch): + factory = tencent_module.TencentVectorFactory() + dataset_with_index = SimpleNamespace( + id="dataset-1", + index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}}, + index_struct=None, + ) + dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None) + + monkeypatch.setattr(tencent_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION") + monkeypatch.setattr(tencent_module.dify_config, "TENCENT_VECTOR_DB_URL", "http://vdb.local") + monkeypatch.setattr(tencent_module.dify_config, "TENCENT_VECTOR_DB_API_KEY", "api-key") + monkeypatch.setattr(tencent_module.dify_config, "TENCENT_VECTOR_DB_TIMEOUT", 30) + monkeypatch.setattr(tencent_module.dify_config, "TENCENT_VECTOR_DB_USERNAME", "user") + monkeypatch.setattr(tencent_module.dify_config, "TENCENT_VECTOR_DB_DATABASE", "db") + monkeypatch.setattr(tencent_module.dify_config, "TENCENT_VECTOR_DB_SHARD", 1) + monkeypatch.setattr(tencent_module.dify_config, "TENCENT_VECTOR_DB_REPLICAS", 2) + monkeypatch.setattr(tencent_module.dify_config, "TENCENT_VECTOR_DB_ENABLE_HYBRID_SEARCH", True) + + with patch.object(tencent_module, "TencentVector", return_value="vector") as vector_cls: + result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock()) + result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock()) + + assert result_1 == "vector" + assert result_2 == "vector" + assert vector_cls.call_args_list[0].kwargs["collection_name"] == "existing_collection" + assert vector_cls.call_args_list[1].kwargs["collection_name"] == "auto_collection" + assert dataset_without_index.index_struct is not None diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/test_vector_base.py b/api/tests/unit_tests/core/rag/datasource/vdb/test_vector_base.py new file mode 100644 index 0000000000..369cda39bf --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/test_vector_base.py @@ -0,0 +1,88 @@ +from types import SimpleNamespace + +import pytest + +from core.rag.datasource.vdb.vector_base import BaseVector +from core.rag.models.document import Document + + +class _DummyVector(BaseVector): + def __init__(self, collection_name: str, existing_ids: set[str] | None = None): + super().__init__(collection_name) + self._existing_ids = existing_ids or set() + + def get_type(self) -> str: + return "dummy" + + def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): + return None + + def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): + return None + + def text_exists(self, id: str) -> bool: + return id in self._existing_ids + + def delete_by_ids(self, ids: list[str]): + return None + + def delete_by_metadata_field(self, key: str, value: str): + return None + + def search_by_vector(self, query_vector: list[float], **kwargs): + return [] + + def search_by_full_text(self, query: str, **kwargs): + return [] + + def delete(self): + return None + + +@pytest.mark.parametrize( + ("base_method", "args"), + [ + (BaseVector.get_type, ()), + (BaseVector.create, ([], [])), + (BaseVector.add_texts, ([], [])), + (BaseVector.text_exists, ("doc-1",)), + (BaseVector.delete_by_ids, ([],)), + (BaseVector.get_ids_by_metadata_field, ("doc_id", "doc-1")), + (BaseVector.delete_by_metadata_field, ("doc_id", "doc-1")), + (BaseVector.search_by_vector, ([0.1],)), + (BaseVector.search_by_full_text, ("query",)), + (BaseVector.delete, ()), + ], +) +def test_base_vector_default_methods_raise_not_implemented(base_method, args): + vector = _DummyVector("collection_1") + + with pytest.raises(NotImplementedError): + base_method(vector, *args) + + +def test_filter_duplicate_texts_removes_existing_docs(): + vector = _DummyVector("collection_1", existing_ids={"dup"}) + docs = [ + SimpleNamespace(page_content="keep-no-meta", metadata=None), + Document(page_content="keep-no-doc-id", metadata={"document_id": "d1"}), + Document(page_content="remove-dup", metadata={"doc_id": "dup"}), + Document(page_content="keep-unique", metadata={"doc_id": "unique"}), + ] + + filtered = vector._filter_duplicate_texts(docs) + + assert [d.page_content for d in filtered] == ["keep-no-meta", "keep-no-doc-id", "keep-unique"] + + +def test_get_uuids_and_collection_name_property(): + vector = _DummyVector("collection_1") + docs = [ + Document(page_content="a", metadata={"doc_id": "id-1"}), + SimpleNamespace(page_content="b", metadata=None), + Document(page_content="c", metadata={"document_id": "d-1"}), + Document(page_content="d", metadata={"doc_id": "id-2"}), + ] + + assert vector._get_uuids(docs) == ["id-1", "id-2"] + assert vector.collection_name == "collection_1" diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/test_vector_factory.py b/api/tests/unit_tests/core/rag/datasource/vdb/test_vector_factory.py new file mode 100644 index 0000000000..54ad6d330b --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/test_vector_factory.py @@ -0,0 +1,436 @@ +import base64 +import sys +import types +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest + +from core.rag.models.document import Document + + +def _register_fake_factory_module(monkeypatch, module_path: str, class_name: str): + fake_module = types.ModuleType(module_path) + fake_cls = type(class_name, (), {}) + setattr(fake_module, class_name, fake_cls) + monkeypatch.setitem(sys.modules, module_path, fake_module) + return fake_cls + + +@pytest.fixture +def vector_factory_module(): + import importlib + + import core.rag.datasource.vdb.vector_factory as module + + return importlib.reload(module) + + +def test_gen_index_struct_dict(vector_factory_module): + result = vector_factory_module.AbstractVectorFactory.gen_index_struct_dict( + vector_factory_module.VectorType.WEAVIATE, + "collection_1", + ) + + assert result == { + "type": vector_factory_module.VectorType.WEAVIATE, + "vector_store": {"class_prefix": "collection_1"}, + } + + +@pytest.mark.parametrize( + ("vector_type", "module_path", "class_name"), + [ + ("CHROMA", "core.rag.datasource.vdb.chroma.chroma_vector", "ChromaVectorFactory"), + ("MILVUS", "core.rag.datasource.vdb.milvus.milvus_vector", "MilvusVectorFactory"), + ( + "ALIBABACLOUD_MYSQL", + "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector", + "AlibabaCloudMySQLVectorFactory", + ), + ("MYSCALE", "core.rag.datasource.vdb.myscale.myscale_vector", "MyScaleVectorFactory"), + ("PGVECTOR", "core.rag.datasource.vdb.pgvector.pgvector", "PGVectorFactory"), + ("VASTBASE", "core.rag.datasource.vdb.pyvastbase.vastbase_vector", "VastbaseVectorFactory"), + ("PGVECTO_RS", "core.rag.datasource.vdb.pgvecto_rs.pgvecto_rs", "PGVectoRSFactory"), + ("QDRANT", "core.rag.datasource.vdb.qdrant.qdrant_vector", "QdrantVectorFactory"), + ("RELYT", "core.rag.datasource.vdb.relyt.relyt_vector", "RelytVectorFactory"), + ( + "ELASTICSEARCH", + "core.rag.datasource.vdb.elasticsearch.elasticsearch_vector", + "ElasticSearchVectorFactory", + ), + ( + "ELASTICSEARCH_JA", + "core.rag.datasource.vdb.elasticsearch.elasticsearch_ja_vector", + "ElasticSearchJaVectorFactory", + ), + ("TIDB_VECTOR", "core.rag.datasource.vdb.tidb_vector.tidb_vector", "TiDBVectorFactory"), + ("WEAVIATE", "core.rag.datasource.vdb.weaviate.weaviate_vector", "WeaviateVectorFactory"), + ("TENCENT", "core.rag.datasource.vdb.tencent.tencent_vector", "TencentVectorFactory"), + ("ORACLE", "core.rag.datasource.vdb.oracle.oraclevector", "OracleVectorFactory"), + ( + "OPENSEARCH", + "core.rag.datasource.vdb.opensearch.opensearch_vector", + "OpenSearchVectorFactory", + ), + ("ANALYTICDB", "core.rag.datasource.vdb.analyticdb.analyticdb_vector", "AnalyticdbVectorFactory"), + ("COUCHBASE", "core.rag.datasource.vdb.couchbase.couchbase_vector", "CouchbaseVectorFactory"), + ("BAIDU", "core.rag.datasource.vdb.baidu.baidu_vector", "BaiduVectorFactory"), + ("VIKINGDB", "core.rag.datasource.vdb.vikingdb.vikingdb_vector", "VikingDBVectorFactory"), + ("UPSTASH", "core.rag.datasource.vdb.upstash.upstash_vector", "UpstashVectorFactory"), + ( + "TIDB_ON_QDRANT", + "core.rag.datasource.vdb.tidb_on_qdrant.tidb_on_qdrant_vector", + "TidbOnQdrantVectorFactory", + ), + ("LINDORM", "core.rag.datasource.vdb.lindorm.lindorm_vector", "LindormVectorStoreFactory"), + ("OCEANBASE", "core.rag.datasource.vdb.oceanbase.oceanbase_vector", "OceanBaseVectorFactory"), + ("SEEKDB", "core.rag.datasource.vdb.oceanbase.oceanbase_vector", "OceanBaseVectorFactory"), + ("OPENGAUSS", "core.rag.datasource.vdb.opengauss.opengauss", "OpenGaussFactory"), + ("TABLESTORE", "core.rag.datasource.vdb.tablestore.tablestore_vector", "TableStoreVectorFactory"), + ( + "HUAWEI_CLOUD", + "core.rag.datasource.vdb.huawei.huawei_cloud_vector", + "HuaweiCloudVectorFactory", + ), + ("MATRIXONE", "core.rag.datasource.vdb.matrixone.matrixone_vector", "MatrixoneVectorFactory"), + ("CLICKZETTA", "core.rag.datasource.vdb.clickzetta.clickzetta_vector", "ClickzettaVectorFactory"), + ("IRIS", "core.rag.datasource.vdb.iris.iris_vector", "IrisVectorFactory"), + ], +) +def test_get_vector_factory_supported(vector_factory_module, monkeypatch, vector_type, module_path, class_name): + expected_cls = _register_fake_factory_module(monkeypatch, module_path, class_name) + + result_cls = vector_factory_module.Vector.get_vector_factory(getattr(vector_factory_module.VectorType, vector_type)) + + assert result_cls is expected_cls + + +def test_get_vector_factory_unsupported(vector_factory_module): + with pytest.raises(ValueError, match="not supported"): + vector_factory_module.Vector.get_vector_factory("unknown") + + +def test_vector_init_uses_default_and_custom_attributes(vector_factory_module): + dataset = SimpleNamespace(id="dataset-1") + + with ( + patch.object(vector_factory_module.Vector, "_get_embeddings", return_value="embeddings"), + patch.object(vector_factory_module.Vector, "_init_vector", return_value="processor"), + ): + default_vector = vector_factory_module.Vector(dataset) + custom_vector = vector_factory_module.Vector(dataset, attributes=["doc_id"]) + + assert default_vector._attributes == ["doc_id", "dataset_id", "document_id", "doc_hash", "doc_type"] + assert custom_vector._attributes == ["doc_id"] + assert default_vector._embeddings == "embeddings" + assert default_vector._vector_processor == "processor" + + +def test_init_vector_prefers_dataset_index_struct(vector_factory_module, monkeypatch): + calls = {"vector_type": None, "init_args": None} + + class _Factory: + def init_vector(self, dataset, attributes, embeddings): + calls["init_args"] = (dataset, attributes, embeddings) + return "vector-processor" + + monkeypatch.setattr( + vector_factory_module.Vector, + "get_vector_factory", + staticmethod(lambda vector_type: calls.update(vector_type=vector_type) or _Factory), + ) + + vector = vector_factory_module.Vector.__new__(vector_factory_module.Vector) + vector._dataset = SimpleNamespace( + index_struct_dict={"type": vector_factory_module.VectorType.UPSTASH}, tenant_id="tenant-1" + ) + vector._attributes = ["doc_id"] + vector._embeddings = "embeddings" + + result = vector._init_vector() + + assert result == "vector-processor" + assert calls["vector_type"] == vector_factory_module.VectorType.UPSTASH + assert calls["init_args"] == (vector._dataset, ["doc_id"], "embeddings") + + +def test_init_vector_uses_whitelist_override(vector_factory_module, monkeypatch): + class _Expr: + def __eq__(self, _other): + return "expr" + + calls = {"vector_type": None} + + class _Factory: + def init_vector(self, dataset, attributes, embeddings): + return "vector-processor" + + monkeypatch.setattr(vector_factory_module, "Whitelist", SimpleNamespace(tenant_id=_Expr(), category=_Expr())) + monkeypatch.setattr(vector_factory_module, "select", lambda _model: SimpleNamespace(where=lambda *_args: "stmt")) + monkeypatch.setattr( + vector_factory_module, + "db", + SimpleNamespace(session=SimpleNamespace(scalars=lambda _stmt: SimpleNamespace(one_or_none=lambda: object()))), + ) + monkeypatch.setattr(vector_factory_module.dify_config, "VECTOR_STORE", vector_factory_module.VectorType.CHROMA) + monkeypatch.setattr(vector_factory_module.dify_config, "VECTOR_STORE_WHITELIST_ENABLE", True) + monkeypatch.setattr( + vector_factory_module.Vector, + "get_vector_factory", + staticmethod(lambda vector_type: calls.update(vector_type=vector_type) or _Factory), + ) + + vector = vector_factory_module.Vector.__new__(vector_factory_module.Vector) + vector._dataset = SimpleNamespace(index_struct_dict=None, tenant_id="tenant-1") + vector._attributes = ["doc_id"] + vector._embeddings = "embeddings" + + result = vector._init_vector() + + assert result == "vector-processor" + assert calls["vector_type"] == vector_factory_module.VectorType.TIDB_ON_QDRANT + + +def test_init_vector_raises_when_vector_store_missing(vector_factory_module, monkeypatch): + monkeypatch.setattr(vector_factory_module.dify_config, "VECTOR_STORE", None) + monkeypatch.setattr(vector_factory_module.dify_config, "VECTOR_STORE_WHITELIST_ENABLE", False) + + vector = vector_factory_module.Vector.__new__(vector_factory_module.Vector) + vector._dataset = SimpleNamespace(index_struct_dict=None, tenant_id="tenant-1") + vector._attributes = [] + vector._embeddings = "embeddings" + + with pytest.raises(ValueError, match="Vector store must be specified"): + vector._init_vector() + + +def test_create_batches_texts_and_skips_empty_input(vector_factory_module): + vector = vector_factory_module.Vector.__new__(vector_factory_module.Vector) + vector._embeddings = MagicMock() + vector._vector_processor = MagicMock() + + docs = [Document(page_content=f"doc-{i}", metadata={"doc_id": f"id-{i}"}) for i in range(1001)] + vector._embeddings.embed_documents.side_effect = [ + [[0.1] for _ in range(1000)], + [[0.2]], + ] + + vector.create(texts=docs, trace_id="trace-1") + + assert vector._embeddings.embed_documents.call_count == 2 + assert vector._vector_processor.create.call_count == 2 + assert vector._vector_processor.create.call_args_list[0].kwargs["trace_id"] == "trace-1" + + vector._embeddings.embed_documents.reset_mock() + vector._vector_processor.create.reset_mock() + vector.create(texts=None) + vector._embeddings.embed_documents.assert_not_called() + vector._vector_processor.create.assert_not_called() + + +def test_create_multimodal_filters_missing_uploads(vector_factory_module, monkeypatch): + class _Field: + def in_(self, value): + return value + + def __eq__(self, value): + return value + + vector = vector_factory_module.Vector.__new__(vector_factory_module.Vector) + vector._embeddings = MagicMock() + vector._embeddings.embed_multimodal_documents.return_value = [[0.1, 0.2]] + vector._vector_processor = MagicMock() + + monkeypatch.setattr(vector_factory_module, "UploadFile", SimpleNamespace(id=_Field())) + monkeypatch.setattr(vector_factory_module, "select", lambda _model: SimpleNamespace(where=lambda *_args: "stmt")) + monkeypatch.setattr( + vector_factory_module, + "db", + SimpleNamespace( + session=SimpleNamespace( + scalars=lambda _stmt: SimpleNamespace(all=lambda: [SimpleNamespace(id="f-1", key="k-1")]) + ) + ), + ) + monkeypatch.setattr(vector_factory_module.storage, "load_once", MagicMock(return_value=b"abc")) + + docs = [ + Document(page_content="file-1", metadata={"doc_id": "f-1", "doc_type": "image"}), + Document(page_content="file-2", metadata={"doc_id": "f-2", "doc_type": "image"}), + ] + + vector.create_multimodal(file_documents=docs, request_id="r-1") + + file_base64 = base64.b64encode(b"abc").decode() + vector._embeddings.embed_multimodal_documents.assert_called_once_with( + [{"content": file_base64, "content_type": "image", "file_id": "f-1"}] + ) + vector._vector_processor.create.assert_called_once_with( + texts=[docs[0]], + embeddings=[[0.1, 0.2]], + request_id="r-1", + ) + + vector._embeddings.embed_multimodal_documents.reset_mock() + vector._vector_processor.create.reset_mock() + vector.create_multimodal(file_documents=None) + vector._embeddings.embed_multimodal_documents.assert_not_called() + vector._vector_processor.create.assert_not_called() + + +def test_add_texts_with_optional_duplicate_check(vector_factory_module): + vector = vector_factory_module.Vector.__new__(vector_factory_module.Vector) + vector._embeddings = MagicMock() + vector._vector_processor = MagicMock() + vector._filter_duplicate_texts = MagicMock() + + docs = [ + Document(page_content="a", metadata={"doc_id": "id-1"}), + Document(page_content="b", metadata={"doc_id": "id-2"}), + ] + vector._filter_duplicate_texts.return_value = [docs[0]] + vector._embeddings.embed_documents.return_value = [[0.1]] + + vector.add_texts(docs, duplicate_check=True, flag=True) + + vector._filter_duplicate_texts.assert_called_once_with(docs) + vector._vector_processor.create.assert_called_once_with( + texts=[docs[0]], embeddings=[[0.1]], duplicate_check=True, flag=True + ) + + vector._filter_duplicate_texts.reset_mock() + vector._vector_processor.create.reset_mock() + vector._embeddings.embed_documents.return_value = [[0.2], [0.3]] + + vector.add_texts(docs, duplicate_check=False) + + vector._filter_duplicate_texts.assert_not_called() + vector._vector_processor.create.assert_called_once() + + +def test_vector_delegation_methods(vector_factory_module): + vector = vector_factory_module.Vector.__new__(vector_factory_module.Vector) + vector._embeddings = MagicMock() + vector._embeddings.embed_query.return_value = [0.1, 0.2] + vector._vector_processor = MagicMock() + vector._vector_processor.text_exists.return_value = True + vector._vector_processor.search_by_vector.return_value = ["vector-doc"] + vector._vector_processor.search_by_full_text.return_value = ["text-doc"] + + assert vector.text_exists("doc-1") is True + vector.delete_by_ids(["doc-1"]) + vector.delete_by_metadata_field("doc_id", "doc-1") + assert vector.search_by_vector("hello", top_k=3) == ["vector-doc"] + assert vector.search_by_full_text("hello", top_k=3) == ["text-doc"] + + vector._vector_processor.delete_by_ids.assert_called_once_with(["doc-1"]) + vector._vector_processor.delete_by_metadata_field.assert_called_once_with("doc_id", "doc-1") + + +def test_search_by_file_handles_missing_and_existing_upload(vector_factory_module, monkeypatch): + class _Field: + def __eq__(self, value): + return value + + upload_query = MagicMock() + upload_query.where.return_value = upload_query + + vector = vector_factory_module.Vector.__new__(vector_factory_module.Vector) + vector._embeddings = MagicMock() + vector._vector_processor = MagicMock() + + monkeypatch.setattr(vector_factory_module, "UploadFile", SimpleNamespace(id=_Field())) + monkeypatch.setattr( + vector_factory_module, "db", SimpleNamespace(session=SimpleNamespace(query=lambda _model: upload_query)) + ) + + upload_query.first.return_value = None + assert vector.search_by_file("file-1") == [] + + upload_query.first.return_value = SimpleNamespace(key="blob-key") + monkeypatch.setattr(vector_factory_module.storage, "load_once", MagicMock(return_value=b"file-bytes")) + vector._embeddings.embed_multimodal_query.return_value = [0.3, 0.4] + vector._vector_processor.search_by_vector.return_value = ["hit"] + + result = vector.search_by_file("file-2", top_k=2) + + assert result == ["hit"] + payload = vector._embeddings.embed_multimodal_query.call_args.args[0] + assert payload["content_type"] == vector_factory_module.DocType.IMAGE + assert payload["file_id"] == "file-2" + + +def test_delete_clears_redis_cache_when_collection_exists(vector_factory_module, monkeypatch): + delete_mock = MagicMock() + redis_delete = MagicMock() + monkeypatch.setattr(vector_factory_module.redis_client, "delete", redis_delete) + + vector = vector_factory_module.Vector.__new__(vector_factory_module.Vector) + vector._vector_processor = SimpleNamespace(delete=delete_mock, collection_name="collection_1") + + vector.delete() + + delete_mock.assert_called_once() + redis_delete.assert_called_once_with("vector_indexing_collection_1") + + vector._vector_processor = SimpleNamespace(delete=delete_mock, collection_name="") + redis_delete.reset_mock() + vector.delete() + redis_delete.assert_not_called() + + +def test_get_embeddings_builds_cache_embedding(vector_factory_module, monkeypatch): + model_manager = MagicMock() + model_manager.get_model_instance.return_value = "model-instance" + + for_tenant_mock = MagicMock(return_value=model_manager) + monkeypatch.setattr(vector_factory_module.ModelManager, "for_tenant", for_tenant_mock) + monkeypatch.setattr(vector_factory_module, "CacheEmbedding", MagicMock(return_value="cached-embedding")) + + vector = vector_factory_module.Vector.__new__(vector_factory_module.Vector) + vector._dataset = SimpleNamespace( + tenant_id="tenant-1", + embedding_model_provider="openai", + embedding_model="text-embedding-3-small", + ) + + result = vector._get_embeddings() + + assert result == "cached-embedding" + for_tenant_mock.assert_called_once_with(tenant_id="tenant-1") + model_manager.get_model_instance.assert_called_once_with( + tenant_id="tenant-1", + provider="openai", + model_type=vector_factory_module.ModelType.TEXT_EMBEDDING, + model="text-embedding-3-small", + ) + + +def test_filter_duplicate_texts_and_getattr(vector_factory_module): + vector = vector_factory_module.Vector.__new__(vector_factory_module.Vector) + vector.text_exists = MagicMock(side_effect=lambda doc_id: doc_id == "dup") + + docs = [ + SimpleNamespace(page_content="no-meta", metadata=None), + Document(page_content="empty-doc-id", metadata={"doc_id": ""}), + Document(page_content="duplicate", metadata={"doc_id": "dup"}), + Document(page_content="unique", metadata={"doc_id": "ok"}), + ] + + filtered = vector._filter_duplicate_texts(docs) + assert [doc.page_content for doc in filtered] == ["no-meta", "empty-doc-id", "unique"] + + class _Processor: + def ping(self): + return "pong" + + vector._vector_processor = _Processor() + assert vector.ping() == "pong" + + with pytest.raises(AttributeError): + _ = vector.unknown_method + + vector._vector_processor = None + with pytest.raises(AttributeError, match="vector_processor"): + _ = vector.another_missing diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/tidb_vector/test_tidb_vector.py b/api/tests/unit_tests/core/rag/datasource/vdb/tidb_vector/test_tidb_vector.py new file mode 100644 index 0000000000..951a920f3b --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/tidb_vector/test_tidb_vector.py @@ -0,0 +1,443 @@ +import importlib +import sys +import types +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from pydantic import ValidationError + +from core.rag.models.document import Document + + +@pytest.fixture +def tidb_module(): + import core.rag.datasource.vdb.tidb_vector.tidb_vector as module + + return importlib.reload(module) + + +def _config(tidb_module): + return tidb_module.TiDBVectorConfig( + host="localhost", + port=4000, + user="root", + password="secret", + database="dify", + program_name="dify-app", + ) + + +@pytest.mark.parametrize( + ("field", "value", "message"), + [ + ("host", "", "config TIDB_VECTOR_HOST is required"), + ("port", 0, "config TIDB_VECTOR_PORT is required"), + ("user", "", "config TIDB_VECTOR_USER is required"), + ("database", "", "config TIDB_VECTOR_DATABASE is required"), + ("program_name", "", "config APPLICATION_NAME is required"), + ], +) +def test_tidb_config_validation(tidb_module, field, value, message): + values = _config(tidb_module).model_dump() + values[field] = value + + with pytest.raises(ValidationError, match=message): + tidb_module.TiDBVectorConfig.model_validate(values) + + +def test_init_get_type_and_distance_func(tidb_module, monkeypatch): + monkeypatch.setattr(tidb_module, "create_engine", MagicMock(return_value="engine")) + + vector = tidb_module.TiDBVector("collection_1", _config(tidb_module), distance_func="L2") + + assert vector.get_type() == tidb_module.VectorType.TIDB_VECTOR + assert vector._url.startswith("mysql+pymysql://root:secret@localhost:4000/dify") + assert vector._dimension == 1536 + assert vector._get_distance_func() == "VEC_L2_DISTANCE" + + vector._distance_func = "cosine" + assert vector._get_distance_func() == "VEC_COSINE_DISTANCE" + + vector._distance_func = "other" + assert vector._get_distance_func() == "VEC_COSINE_DISTANCE" + + +def test_table_builds_columns_with_tidb_vector_type(tidb_module, monkeypatch): + fake_tidb_vector = types.ModuleType("tidb_vector") + fake_tidb_sqlalchemy = types.ModuleType("tidb_vector.sqlalchemy") + + class _VectorType: + def __init__(self, dim): + self.dim = dim + + fake_tidb_sqlalchemy.VectorType = _VectorType + + monkeypatch.setitem(sys.modules, "tidb_vector", fake_tidb_vector) + monkeypatch.setitem(sys.modules, "tidb_vector.sqlalchemy", fake_tidb_sqlalchemy) + monkeypatch.setattr(tidb_module, "create_engine", MagicMock(return_value=MagicMock())) + monkeypatch.setattr(tidb_module, "Column", lambda *args, **kwargs: SimpleNamespace(args=args, kwargs=kwargs)) + monkeypatch.setattr( + tidb_module, + "Table", + lambda name, _metadata, *columns, **_kwargs: SimpleNamespace(name=name, columns=columns), + ) + + vector = tidb_module.TiDBVector("collection_1", _config(tidb_module)) + table = vector._table(3) + + assert table.name == "collection_1" + column_names = [column.args[0] for column in table.columns] + assert tidb_module.Field.PRIMARY_KEY in column_names + assert tidb_module.Field.VECTOR in column_names + assert tidb_module.Field.TEXT_KEY in column_names + + +def test_create_calls_collection_and_add_texts(tidb_module): + vector = tidb_module.TiDBVector.__new__(tidb_module.TiDBVector) + vector._collection_name = "collection_1" + vector._create_collection = MagicMock() + vector.add_texts = MagicMock() + + docs = [Document(page_content="a", metadata={"doc_id": "id-1"})] + vector.create(docs, [[0.1, 0.2]]) + + vector._create_collection.assert_called_once_with(2) + vector.add_texts.assert_called_once_with(docs, [[0.1, 0.2]]) + assert vector._dimension == 2 + + +def test_create_collection_skips_when_cache_hit(tidb_module, monkeypatch): + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(tidb_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(tidb_module.redis_client, "get", MagicMock(return_value=1)) + monkeypatch.setattr(tidb_module.redis_client, "set", MagicMock()) + + vector = tidb_module.TiDBVector.__new__(tidb_module.TiDBVector) + vector._collection_name = "collection_1" + vector._engine = MagicMock() + + tidb_module.Session = MagicMock() + + vector._create_collection(3) + + tidb_module.Session.assert_not_called() + tidb_module.redis_client.set.assert_not_called() + + +def test_create_collection_executes_create_sql_and_sets_cache(tidb_module, monkeypatch): + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(tidb_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(tidb_module.redis_client, "get", MagicMock(return_value=None)) + monkeypatch.setattr(tidb_module.redis_client, "set", MagicMock()) + + session = MagicMock() + + class _SessionCtx: + def __enter__(self): + return session + + def __exit__(self, exc_type, exc, tb): + return False + + monkeypatch.setattr(tidb_module, "Session", lambda _engine: _SessionCtx()) + + vector = tidb_module.TiDBVector.__new__(tidb_module.TiDBVector) + vector._collection_name = "collection_1" + vector._engine = MagicMock() + vector._distance_func = "l2" + + vector._create_collection(3) + + session.begin.assert_called_once() + sql = str(session.execute.call_args.args[0]) + assert "VECTOR(3)" in sql + assert "VEC_L2_DISTANCE" in sql + session.commit.assert_called_once() + tidb_module.redis_client.set.assert_called_once() + + +def test_add_texts_batches_inserts_and_returns_ids(tidb_module, monkeypatch): + class _InsertStmt: + def __init__(self, table): + self.table = table + + def values(self, rows): + return {"table": self.table, "rows": rows} + + monkeypatch.setattr(tidb_module, "insert", lambda table: _InsertStmt(table)) + + conn = MagicMock() + transaction = MagicMock() + transaction.__enter__.return_value = None + transaction.__exit__.return_value = None + conn.begin.return_value = transaction + + connection_ctx = MagicMock() + connection_ctx.__enter__.return_value = conn + connection_ctx.__exit__.return_value = None + + engine = MagicMock() + engine.connect.return_value = connection_ctx + + vector = tidb_module.TiDBVector.__new__(tidb_module.TiDBVector) + vector._engine = engine + vector._table = MagicMock(return_value="table") + + docs = [Document(page_content=f"text-{i}", metadata={"doc_id": f"id-{i}"}) for i in range(501)] + embeddings = [[float(i)] for i in range(501)] + + ids = vector.add_texts(docs, embeddings) + + assert ids[0] == "id-0" + assert len(ids) == 501 + assert conn.execute.call_count == 2 + + +@pytest.fixture +def tidb_vector_with_session(tidb_module, monkeypatch): + vector = tidb_module.TiDBVector.__new__(tidb_module.TiDBVector) + vector._collection_name = "collection_1" + vector._engine = MagicMock() + session = MagicMock() + + class _SessionCtx: + def __enter__(self): + return session + + def __exit__(self, exc_type, exc, tb): + return False + + monkeypatch.setattr(tidb_module, "Session", lambda _engine: _SessionCtx()) + return vector, session, tidb_module + + +# 1. search_by_full_text returns empty +def test_search_by_full_text_returns_empty(tidb_vector_with_session): + vector, _, _ = tidb_vector_with_session + assert vector.search_by_full_text("query") == [] + + +# 2. text_exists returns True when ids found +def test_text_exists_returns_true_when_ids_found(tidb_vector_with_session): + vector, _, _ = tidb_vector_with_session + vector.get_ids_by_metadata_field = MagicMock(return_value=["id-1"]) + assert vector.text_exists("doc-1") is True + + +# 3. text_exists returns False when no ids +def test_text_exists_returns_false_when_no_ids(tidb_vector_with_session): + vector, _, _ = tidb_vector_with_session + vector.get_ids_by_metadata_field = MagicMock(return_value=None) + assert vector.text_exists("doc-1") is False + + +# 4. delete_by_ids delegates to _delete_by_ids when ids found +def test_delete_by_ids_delegates_to_internal_delete(tidb_vector_with_session): + vector, session, tidb_module = tidb_vector_with_session + session.execute.return_value.fetchall.return_value = [("id-a",), ("id-b",)] + vector._delete_by_ids = MagicMock() + # Use real get_ids_by_metadata_field + vector.get_ids_by_metadata_field = tidb_module.TiDBVector.get_ids_by_metadata_field.__get__( + vector, tidb_module.TiDBVector + ) + vector.delete_by_ids(["doc-a", "doc-b"]) + vector._delete_by_ids.assert_called_once_with(["id-a", "id-b"]) + + +# 5. delete_by_ids skips when no ids found +def test_delete_by_ids_skips_when_no_ids_found(tidb_vector_with_session): + vector, session, tidb_module = tidb_vector_with_session + session.execute.return_value.fetchall.return_value = [] + vector._delete_by_ids = MagicMock() + # Use real get_ids_by_metadata_field + vector.get_ids_by_metadata_field = tidb_module.TiDBVector.get_ids_by_metadata_field.__get__( + vector, tidb_module.TiDBVector + ) + vector.delete_by_ids(["doc-c"]) + vector._delete_by_ids.assert_not_called() + + +# 6. get_ids_by_metadata_field returns ids and returns None +def test_get_ids_by_metadata_field_returns_ids_and_returns_none(tidb_vector_with_session): + vector, session, tidb_module = tidb_vector_with_session + # Returns ids + session.execute.return_value.fetchall.return_value = [("id-1",)] + assert vector.get_ids_by_metadata_field("doc_id", "doc-1") == ["id-1"] + # Returns None + session.execute.return_value.fetchall.return_value = [] + assert vector.get_ids_by_metadata_field("doc_id", "doc-1") is None + + +# 1. _delete_by_ids raises on None +def test__delete_by_ids_raises_on_none(tidb_module): + vector = tidb_module.TiDBVector.__new__(tidb_module.TiDBVector) + with pytest.raises(ValueError, match="No ids provided"): + vector._delete_by_ids(None) + + +# 2. _delete_by_ids returns True and calls execute +def test__delete_by_ids_returns_true_and_calls_execute(tidb_module): + class _IDColumn: + def in_(self, ids): + return ids + + class _Delete: + def where(self, condition): + return condition + + table = SimpleNamespace(c=SimpleNamespace(id=_IDColumn()), delete=lambda: _Delete()) + conn = MagicMock() + tx = MagicMock() + tx.__enter__.return_value = None + tx.__exit__.return_value = None + conn.begin.return_value = tx + conn_ctx = MagicMock() + conn_ctx.__enter__.return_value = conn + conn_ctx.__exit__.return_value = None + vector = tidb_module.TiDBVector.__new__(tidb_module.TiDBVector) + vector._collection_name = "collection_1" + vector._dimension = 2 + vector._engine = SimpleNamespace(connect=MagicMock(return_value=conn_ctx)) + vector._table = MagicMock(return_value=table) + assert vector._delete_by_ids(["id-1"]) is True + conn.execute.assert_called_once() + + +# 3. _delete_by_ids returns False on RuntimeError +def test__delete_by_ids_returns_false_on_runtime_error(tidb_module): + class _IDColumn: + def in_(self, ids): + return ids + + class _Delete: + def where(self, condition): + return condition + + table = SimpleNamespace(c=SimpleNamespace(id=_IDColumn()), delete=lambda: _Delete()) + conn = MagicMock() + tx = MagicMock() + tx.__enter__.return_value = None + tx.__exit__.return_value = None + conn.begin.return_value = tx + conn_ctx = MagicMock() + conn_ctx.__enter__.return_value = conn + conn_ctx.__exit__.return_value = None + conn.execute.side_effect = RuntimeError("delete failed") + vector = tidb_module.TiDBVector.__new__(tidb_module.TiDBVector) + vector._collection_name = "collection_1" + vector._dimension = 2 + vector._engine = SimpleNamespace(connect=MagicMock(return_value=conn_ctx)) + vector._table = MagicMock(return_value=table) + assert vector._delete_by_ids(["id-2"]) is False + + +# 4. delete_by_metadata_field calls _delete_by_ids when ids found +def test_delete_by_metadata_field_calls__delete_by_ids_when_ids_found(tidb_module): + vector = tidb_module.TiDBVector.__new__(tidb_module.TiDBVector) + vector.get_ids_by_metadata_field = MagicMock(return_value=["id-3"]) + vector._delete_by_ids = MagicMock() + vector.delete_by_metadata_field("doc_id", "doc-3") + vector._delete_by_ids.assert_called_once_with(["id-3"]) + + +# 5. delete_by_metadata_field does nothing when no ids +def test_delete_by_metadata_field_does_nothing_when_no_ids(tidb_module): + vector = tidb_module.TiDBVector.__new__(tidb_module.TiDBVector) + vector.get_ids_by_metadata_field = MagicMock(return_value=[]) + vector._delete_by_ids = MagicMock() + vector.delete_by_metadata_field("doc_id", "doc-4") + vector._delete_by_ids.assert_not_called() + + +# Test search_by_vector filters and scores +def test_search_by_vector_filters_and_scores(tidb_module, monkeypatch): + session = MagicMock() + session.execute.return_value = [ + ('{"doc_id":"id-1","document_id":"d-1"}', "text-1", 0.2), + ('{"doc_id":"id-2","document_id":"d-2"}', "text-2", 0.4), + ] + session.commit = MagicMock() + + class _SessionCtx: + def __enter__(self): + return session + + def __exit__(self, exc_type, exc, tb): + return False + + monkeypatch.setattr(tidb_module, "Session", lambda _engine: _SessionCtx()) + vector = tidb_module.TiDBVector.__new__(tidb_module.TiDBVector) + vector._collection_name = "collection_1" + vector._engine = MagicMock() + vector._distance_func = "cosine" + docs = vector.search_by_vector( + [0.1, 0.2], + top_k=2, + score_threshold=0.5, + document_ids_filter=["d-1", "d-2"], + ) + assert len(docs) == 2 + assert docs[0].metadata["score"] == pytest.approx(0.8) + assert docs[1].metadata["score"] == pytest.approx(0.6) + sql = str(session.execute.call_args.args[0]) + params = session.execute.call_args.kwargs["params"] + assert "meta->>'$.document_id' in ('d-1', 'd-2')" in sql + assert params["distance"] == pytest.approx(0.5) + assert params["top_k"] == 2 + session.commit.assert_not_called() + + +# Test delete drops table +def test_delete_drops_table(tidb_module, monkeypatch): + session = MagicMock() + session.execute.return_value = None + session.commit = MagicMock() + + class _SessionCtx: + def __enter__(self): + return session + + def __exit__(self, exc_type, exc, tb): + return False + + monkeypatch.setattr(tidb_module, "Session", lambda _engine: _SessionCtx()) + vector = tidb_module.TiDBVector.__new__(tidb_module.TiDBVector) + vector._collection_name = "collection_1" + vector._engine = MagicMock() + vector.delete() + drop_sql = str(session.execute.call_args.args[0]) + assert "DROP TABLE IF EXISTS collection_1" in drop_sql + session.commit.assert_called_once() + + +def test_tidb_factory_uses_existing_or_generated_collection(tidb_module, monkeypatch): + factory = tidb_module.TiDBVectorFactory() + dataset_with_index = SimpleNamespace( + id="dataset-1", + index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}}, + index_struct=None, + ) + dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None) + + monkeypatch.setattr(tidb_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION") + monkeypatch.setattr(tidb_module.dify_config, "TIDB_VECTOR_HOST", "localhost") + monkeypatch.setattr(tidb_module.dify_config, "TIDB_VECTOR_PORT", 4000) + monkeypatch.setattr(tidb_module.dify_config, "TIDB_VECTOR_USER", "root") + monkeypatch.setattr(tidb_module.dify_config, "TIDB_VECTOR_PASSWORD", "secret") + monkeypatch.setattr(tidb_module.dify_config, "TIDB_VECTOR_DATABASE", "dify") + monkeypatch.setattr(tidb_module.dify_config, "APPLICATION_NAME", "dify-app") + + with patch.object(tidb_module, "TiDBVector", return_value="vector") as vector_cls: + result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock()) + result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock()) + + assert result_1 == "vector" + assert result_2 == "vector" + assert vector_cls.call_args_list[0].kwargs["collection_name"] == "existing_collection" + assert vector_cls.call_args_list[1].kwargs["collection_name"] == "auto_collection" + assert dataset_without_index.index_struct is not None diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/upstash/test_upstash_vector.py b/api/tests/unit_tests/core/rag/datasource/vdb/upstash/test_upstash_vector.py new file mode 100644 index 0000000000..ac8a63a44b --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/upstash/test_upstash_vector.py @@ -0,0 +1,186 @@ +import importlib +import sys +import types +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from pydantic import ValidationError + +from core.rag.models.document import Document + + +def _build_fake_upstash_module(): + upstash_module = types.ModuleType("upstash_vector") + + class Vector: + def __init__(self, id, vector, metadata, data): + self.id = id + self.vector = vector + self.metadata = metadata + self.data = data + + class Index: + def __init__(self, url, token): + self.url = url + self.token = token + self.info = MagicMock(return_value=SimpleNamespace(dimension=8)) + self.upsert = MagicMock() + self.query = MagicMock(return_value=[]) + self.delete = MagicMock() + self.reset = MagicMock() + + upstash_module.Vector = Vector + upstash_module.Index = Index + return upstash_module + + +@pytest.fixture +def upstash_module(monkeypatch): + # Remove patched modules if present + for modname in ["upstash_vector", "core.rag.datasource.vdb.upstash.upstash_vector"]: + if modname in sys.modules: + monkeypatch.delitem(sys.modules, modname, raising=False) + monkeypatch.setitem(sys.modules, "upstash_vector", _build_fake_upstash_module()) + module = importlib.import_module("core.rag.datasource.vdb.upstash.upstash_vector") + return module + + +def _config(module): + return module.UpstashVectorConfig(url="https://upstash.example", token="token-123") + + +@pytest.mark.parametrize( + ("field", "value", "message"), + [ + ("url", "", "Upstash URL is required"), + ("token", "", "Upstash Token is required"), + ], +) +def test_upstash_config_validation(upstash_module, field, value, message): + values = _config(upstash_module).model_dump() + values[field] = value + + with pytest.raises(ValidationError, match=message): + upstash_module.UpstashVectorConfig.model_validate(values) + + +def test_init_get_type_and_dimension(upstash_module, monkeypatch): + vector = upstash_module.UpstashVector("collection_1", _config(upstash_module)) + + assert vector.get_type() == upstash_module.VectorType.UPSTASH + assert vector._table_name == "collection_1" + assert vector._get_index_dimension() == 8 + + vector.index.info.return_value = SimpleNamespace(dimension=None) + assert vector._get_index_dimension() == 1536 + + vector.index.info.return_value = None + assert vector._get_index_dimension() == 1536 + + monkeypatch.setattr(upstash_module, "uuid4", lambda: "generated-uuid") + docs = [Document(page_content="hello", metadata={"doc_id": "id-1"})] + vector.add_texts(docs, [[0.1, 0.2]]) + + vector.index.upsert.assert_called_once() + upsert_vectors = vector.index.upsert.call_args.kwargs["vectors"] + assert upsert_vectors[0].id == "generated-uuid" + + +def test_create_text_exists_and_delete_by_ids(upstash_module): + vector = upstash_module.UpstashVector("collection_1", _config(upstash_module)) + vector.add_texts = MagicMock() + + docs = [Document(page_content="hello", metadata={"doc_id": "id-1"})] + vector.create(docs, [[0.1]]) + vector.add_texts.assert_called_once_with(docs, [[0.1]]) + + vector.get_ids_by_metadata_field = MagicMock(return_value=["id-1"]) + assert vector.text_exists("doc-1") is True + vector.get_ids_by_metadata_field.return_value = [] + assert vector.text_exists("doc-1") is False + + vector.get_ids_by_metadata_field = MagicMock(side_effect=[["item-1"], [], ["item-2"]]) + vector._delete_by_ids = MagicMock() + vector.delete_by_ids(["doc-1", "doc-2", "doc-3"]) + vector._delete_by_ids.assert_called_once_with(ids=["item-1", "item-2"]) + + +def test_delete_helpers_and_search(upstash_module): + vector = upstash_module.UpstashVector("collection_1", _config(upstash_module)) + + vector._delete_by_ids([]) + vector.index.delete.assert_not_called() + vector._delete_by_ids(["a", "b"]) + vector.index.delete.assert_called_once_with(ids=["a", "b"]) + + vector.index.query.return_value = [SimpleNamespace(id="x-1"), SimpleNamespace(id="x-2")] + ids = vector.get_ids_by_metadata_field("doc_id", "doc-1") + assert ids == ["x-1", "x-2"] + query_kwargs = vector.index.query.call_args.kwargs + assert query_kwargs["top_k"] == 1000 + assert query_kwargs["filter"] == "doc_id = 'doc-1'" + + vector._delete_by_ids = MagicMock() + vector.get_ids_by_metadata_field = MagicMock(return_value=["x-1"]) + vector.delete_by_metadata_field("doc_id", "doc-1") + vector._delete_by_ids.assert_called_once_with(["x-1"]) + + vector._delete_by_ids.reset_mock() + vector.get_ids_by_metadata_field.return_value = [] + vector.delete_by_metadata_field("doc_id", "doc-2") + vector._delete_by_ids.assert_not_called() + + +def test_search_by_vector_filter_threshold_and_delete(upstash_module): + vector = upstash_module.UpstashVector("collection_1", _config(upstash_module)) + vector.index.query.return_value = [ + SimpleNamespace(metadata={"document_id": "d-1"}, data="text-1", score=0.9), + SimpleNamespace(metadata={"document_id": "d-2"}, data="text-2", score=0.3), + SimpleNamespace(metadata=None, data="text-3", score=0.99), + SimpleNamespace(metadata={"document_id": "d-4"}, data=None, score=0.99), + ] + + docs = vector.search_by_vector( + [0.1, 0.2], + top_k=3, + score_threshold=0.5, + document_ids_filter=["d-1", "d-2"], + ) + + assert len(docs) == 1 + assert docs[0].page_content == "text-1" + assert docs[0].metadata["score"] == pytest.approx(0.9) + + search_kwargs = vector.index.query.call_args.kwargs + assert search_kwargs["top_k"] == 3 + assert search_kwargs["filter"] == "document_id in ('d-1', 'd-2')" + + assert vector.search_by_full_text("query") == [] + + vector.delete() + vector.index.reset.assert_called_once() + + +def test_upstash_factory_uses_existing_or_generated_collection(upstash_module, monkeypatch): + factory = upstash_module.UpstashVectorFactory() + dataset_with_index = SimpleNamespace( + id="dataset-1", + index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}}, + index_struct=None, + ) + dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None) + + monkeypatch.setattr(upstash_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION") + monkeypatch.setattr(upstash_module.dify_config, "UPSTASH_VECTOR_URL", "https://upstash.example") + monkeypatch.setattr(upstash_module.dify_config, "UPSTASH_VECTOR_TOKEN", "token-123") + + with patch.object(upstash_module, "UpstashVector", return_value="vector") as vector_cls: + result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock()) + result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock()) + + assert result_1 == "vector" + assert result_2 == "vector" + assert vector_cls.call_args_list[0].kwargs["collection_name"] == "existing_collection" + assert vector_cls.call_args_list[1].kwargs["collection_name"] == "auto_collection" + assert dataset_without_index.index_struct is not None diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/vikingdb/test_vikingdb_vector.py b/api/tests/unit_tests/core/rag/datasource/vdb/vikingdb/test_vikingdb_vector.py new file mode 100644 index 0000000000..9da92af2d0 --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/vikingdb/test_vikingdb_vector.py @@ -0,0 +1,310 @@ +import importlib +import json +import sys +import types +from collections import UserDict +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest + +from core.rag.models.document import Document + + +def _build_fake_vikingdb_modules(): + volcengine = types.ModuleType("volcengine") + volcengine.__path__ = [] + viking_db = types.ModuleType("volcengine.viking_db") + + class Data(UserDict): + def __init__(self, payload): + super().__init__(payload) + self.fields = payload + + class DistanceType: + L2 = "L2" + + class IndexType: + HNSW = "HNSW" + + class QuantType: + Float = "Float" + + class FieldType: + String = "string" + Text = "text" + Vector = "vector" + + class Field: + def __init__(self, **kwargs): + self.kwargs = kwargs + + class VectorIndexParams: + def __init__(self, **kwargs): + self.kwargs = kwargs + + class _Collection: + def __init__(self): + self.upsert_data = MagicMock() + self.fetch_data = MagicMock(return_value=None) + self.delete_data = MagicMock() + + class _Index: + def __init__(self): + self.search = MagicMock(return_value=[]) + self.search_by_vector = MagicMock(return_value=[]) + + class VikingDBService: + def __init__(self, **kwargs): + self.kwargs = kwargs + self.create_collection = MagicMock() + self.create_index = MagicMock() + self.drop_index = MagicMock() + self.drop_collection = MagicMock() + self._collection = _Collection() + self._index = _Index() + self.get_collection = MagicMock(return_value=self._collection) + self.get_index = MagicMock(return_value=self._index) + + viking_db.Data = Data + viking_db.DistanceType = DistanceType + viking_db.Field = Field + viking_db.FieldType = FieldType + viking_db.IndexType = IndexType + viking_db.QuantType = QuantType + viking_db.VectorIndexParams = VectorIndexParams + viking_db.VikingDBService = VikingDBService + + return {"volcengine": volcengine, "volcengine.viking_db": viking_db} + + +@pytest.fixture +def vikingdb_module(monkeypatch): + for name, module in _build_fake_vikingdb_modules().items(): + monkeypatch.setitem(sys.modules, name, module) + + import core.rag.datasource.vdb.vikingdb.vikingdb_vector as module + + return importlib.reload(module) + + +def _config(module): + return module.VikingDBConfig( + access_key="ak", + secret_key="sk", + host="host", + region="region", + scheme="https", + connection_timeout=10, + socket_timeout=20, + ) + + +def test_init_get_type_and_has_checks(vikingdb_module): + vector = vikingdb_module.VikingDBVector("collection_1", "group-1", _config(vikingdb_module)) + + assert vector.get_type() == vikingdb_module.VectorType.VIKINGDB + assert vector._index_name == "collection_1_idx" + + assert vector._has_collection() is True + assert vector._has_index() is True + + vector._client.get_collection.side_effect = RuntimeError("missing") + assert vector._has_collection() is False + vector._client.get_collection.side_effect = None + + vector._client.get_index.side_effect = RuntimeError("missing") + assert vector._has_index() is False + + +def test_create_collection_cache_and_creation_paths(vikingdb_module, monkeypatch): + lock = MagicMock() + lock.__enter__.return_value = None + lock.__exit__.return_value = None + monkeypatch.setattr(vikingdb_module.redis_client, "lock", MagicMock(return_value=lock)) + monkeypatch.setattr(vikingdb_module.redis_client, "set", MagicMock()) + + vector = vikingdb_module.VikingDBVector("collection_1", "group-1", _config(vikingdb_module)) + + monkeypatch.setattr(vikingdb_module.redis_client, "get", MagicMock(return_value=1)) + vector._create_collection(3) + vector._client.create_collection.assert_not_called() + vector._client.create_index.assert_not_called() + + monkeypatch.setattr(vikingdb_module.redis_client, "get", MagicMock(return_value=None)) + vector._has_collection = MagicMock(return_value=False) + vector._has_index = MagicMock(return_value=False) + vector._create_collection(4) + + vector._client.create_collection.assert_called_once() + vector._client.create_index.assert_called_once() + vikingdb_module.redis_client.set.assert_called_once() + + +def test_create_and_add_texts(vikingdb_module): + vector = vikingdb_module.VikingDBVector("collection_1", "group-1", _config(vikingdb_module)) + vector._create_collection = MagicMock() + vector.add_texts = MagicMock() + + docs = [Document(page_content="hello", metadata={"doc_id": "id-1"})] + vector.create(docs, [[0.1, 0.2]]) + + vector._create_collection.assert_called_once_with(2) + vector.add_texts.assert_called_once_with(docs, [[0.1, 0.2]]) + + vector = vikingdb_module.VikingDBVector("collection_2", "group-2", _config(vikingdb_module)) + docs = [ + Document(page_content="a", metadata={"doc_id": "id-a", "document_id": "d-1"}), + Document(page_content="b", metadata={"doc_id": "id-b", "document_id": "d-2"}), + ] + vector.add_texts(docs, [[0.1], [0.2]]) + + vector._client.get_collection.assert_called() + upsert_docs = vector._client.get_collection.return_value.upsert_data.call_args.args[0] + assert upsert_docs[0][vikingdb_module.vdb_Field.PRIMARY_KEY] == "id-a" + assert upsert_docs[0][vikingdb_module.vdb_Field.GROUP_KEY] == "group-2" + + +def test_text_exists_and_delete_operations(vikingdb_module): + vector = vikingdb_module.VikingDBVector("collection_1", "group-1", _config(vikingdb_module)) + + vector._client.get_collection.return_value.fetch_data.return_value = SimpleNamespace(fields={"message": "ok"}) + assert vector.text_exists("id-1") is True + + vector._client.get_collection.return_value.fetch_data.return_value = SimpleNamespace( + fields={"message": "data does not exist"} + ) + assert vector.text_exists("id-1") is False + + vector._client.get_collection.return_value.fetch_data.return_value = None + assert vector.text_exists("id-1") is False + + vector.delete_by_ids(["id-1"]) + vector._client.get_collection.return_value.delete_data.assert_called_once_with(["id-1"]) + + vector.get_ids_by_metadata_field = MagicMock(return_value=["id-2"]) + vector.delete_by_ids = MagicMock() + vector.delete_by_metadata_field("doc_id", "doc-1") + vector.delete_by_ids.assert_called_once_with(["id-2"]) + + +def test_get_ids_and_search_helpers(vikingdb_module): + vector = vikingdb_module.VikingDBVector("collection_1", "group-1", _config(vikingdb_module)) + + vector._client.get_index.return_value.search.return_value = [] + assert vector.get_ids_by_metadata_field("doc_id", "x") == [] + + vector._client.get_index.return_value.search.return_value = [ + SimpleNamespace(id="a", fields={vikingdb_module.vdb_Field.METADATA_KEY: json.dumps({"doc_id": "x"})}), + SimpleNamespace(id="b", fields={vikingdb_module.vdb_Field.METADATA_KEY: json.dumps({"doc_id": "y"})}), + SimpleNamespace(id="c", fields={}), + ] + assert vector.get_ids_by_metadata_field("doc_id", "x") == ["a"] + + empty_docs = vector._get_search_res([], score_threshold=0.1) + assert empty_docs == [] + + results = [ + SimpleNamespace( + id="a", + score=0.3, + fields={ + vikingdb_module.vdb_Field.CONTENT_KEY: "doc-a", + vikingdb_module.vdb_Field.METADATA_KEY: json.dumps({"document_id": "d-1"}), + }, + ), + SimpleNamespace( + id="b", + score=0.9, + fields={ + vikingdb_module.vdb_Field.CONTENT_KEY: "doc-b", + vikingdb_module.vdb_Field.METADATA_KEY: json.dumps({"document_id": "d-2"}), + }, + ), + ] + + docs = vector._get_search_res(results, score_threshold=0.2) + assert [doc.page_content for doc in docs] == ["doc-b", "doc-a"] + + vector._client.get_index.return_value.search_by_vector.return_value = results + filtered_docs = vector.search_by_vector([0.1], top_k=2, score_threshold=0.2, document_ids_filter=["d-2"]) + assert len(filtered_docs) == 1 + assert filtered_docs[0].page_content == "doc-b" + assert vector.search_by_full_text("query") == [] + + +def test_delete_drops_index_and_collection_when_present(vikingdb_module): + vector = vikingdb_module.VikingDBVector("collection_1", "group-1", _config(vikingdb_module)) + vector._has_index = MagicMock(return_value=True) + vector._has_collection = MagicMock(return_value=True) + + vector.delete() + + vector._client.drop_index.assert_called_once_with("collection_1", "collection_1_idx") + vector._client.drop_collection.assert_called_once_with("collection_1") + + vector._client.drop_index.reset_mock() + vector._client.drop_collection.reset_mock() + vector._has_index.return_value = False + vector._has_collection.return_value = False + vector.delete() + + vector._client.drop_index.assert_not_called() + vector._client.drop_collection.assert_not_called() + + +def test_vikingdb_factory_validates_config_and_builds_vector(vikingdb_module, monkeypatch): + factory = vikingdb_module.VikingDBVectorFactory() + dataset_with_index = SimpleNamespace( + id="dataset-1", + index_struct_dict={"vector_store": {"class_prefix": "EXISTING_COLLECTION"}}, + index_struct=None, + ) + dataset_without_index = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None) + + monkeypatch.setattr(vikingdb_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION") + + with patch.object(vikingdb_module, "VikingDBVector", return_value="vector") as vector_cls: + monkeypatch.setattr(vikingdb_module.dify_config, "VIKINGDB_ACCESS_KEY", "ak") + monkeypatch.setattr(vikingdb_module.dify_config, "VIKINGDB_SECRET_KEY", "sk") + monkeypatch.setattr(vikingdb_module.dify_config, "VIKINGDB_HOST", "host") + monkeypatch.setattr(vikingdb_module.dify_config, "VIKINGDB_REGION", "region") + monkeypatch.setattr(vikingdb_module.dify_config, "VIKINGDB_SCHEME", "https") + monkeypatch.setattr(vikingdb_module.dify_config, "VIKINGDB_CONNECTION_TIMEOUT", 10) + monkeypatch.setattr(vikingdb_module.dify_config, "VIKINGDB_SOCKET_TIMEOUT", 20) + + result_1 = factory.init_vector(dataset_with_index, attributes=[], embeddings=MagicMock()) + result_2 = factory.init_vector(dataset_without_index, attributes=[], embeddings=MagicMock()) + + assert result_1 == "vector" + assert result_2 == "vector" + assert vector_cls.call_args_list[0].kwargs["collection_name"] == "existing_collection" + assert vector_cls.call_args_list[1].kwargs["collection_name"] == "auto_collection" + assert dataset_without_index.index_struct is not None + + +@pytest.mark.parametrize( + ("field", "message"), + [ + ("VIKINGDB_ACCESS_KEY", "VIKINGDB_ACCESS_KEY should not be None"), + ("VIKINGDB_SECRET_KEY", "VIKINGDB_SECRET_KEY should not be None"), + ("VIKINGDB_HOST", "VIKINGDB_HOST should not be None"), + ("VIKINGDB_REGION", "VIKINGDB_REGION should not be None"), + ("VIKINGDB_SCHEME", "VIKINGDB_SCHEME should not be None"), + ], +) +def test_vikingdb_factory_raises_when_required_config_missing(vikingdb_module, monkeypatch, field, message): + factory = vikingdb_module.VikingDBVectorFactory() + dataset = SimpleNamespace( + id="dataset-1", index_struct_dict={"vector_store": {"class_prefix": "existing"}}, index_struct=None + ) + + monkeypatch.setattr(vikingdb_module.dify_config, "VIKINGDB_ACCESS_KEY", "ak") + monkeypatch.setattr(vikingdb_module.dify_config, "VIKINGDB_SECRET_KEY", "sk") + monkeypatch.setattr(vikingdb_module.dify_config, "VIKINGDB_HOST", "host") + monkeypatch.setattr(vikingdb_module.dify_config, "VIKINGDB_REGION", "region") + monkeypatch.setattr(vikingdb_module.dify_config, "VIKINGDB_SCHEME", "https") + monkeypatch.setattr(vikingdb_module.dify_config, field, None) + + with pytest.raises(ValueError, match=message): + factory.init_vector(dataset, attributes=[], embeddings=MagicMock()) diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/weaviate/test_weaviate_vector.py b/api/tests/unit_tests/core/rag/datasource/vdb/weaviate/test_weaviate_vector.py index 3bd656ba84..69d1833001 100644 --- a/api/tests/unit_tests/core/rag/datasource/vdb/weaviate/test_weaviate_vector.py +++ b/api/tests/unit_tests/core/rag/datasource/vdb/weaviate/test_weaviate_vector.py @@ -7,10 +7,14 @@ Focuses on verifying that doc_type is properly handled in: - Full-text search result metadata (search_by_full_text) """ +import datetime +import json import unittest from types import SimpleNamespace from unittest.mock import MagicMock, patch +import pytest + from core.rag.datasource.vdb.weaviate import weaviate_vector as weaviate_vector_module from core.rag.datasource.vdb.weaviate.weaviate_vector import WeaviateConfig, WeaviateVector from core.rag.models.document import Document @@ -32,6 +36,10 @@ class TestWeaviateVector(unittest.TestCase): def tearDown(self): weaviate_vector_module._weaviate_client = None + def test_config_requires_endpoint(self): + with pytest.raises(ValueError, match="config WEAVIATE_ENDPOINT is required"): + WeaviateConfig(endpoint="") + @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate") def _create_weaviate_vector(self, mock_weaviate_module): """Helper to create a WeaviateVector instance with mocked client.""" @@ -46,6 +54,85 @@ class TestWeaviateVector(unittest.TestCase): ) return wv, mock_client + def test_shutdown_client_logs_debug_when_close_fails(self): + mock_client = MagicMock() + mock_client.close.side_effect = RuntimeError("close failed") + weaviate_vector_module._weaviate_client = mock_client + + with patch.object(weaviate_vector_module.logger, "debug") as mock_debug: + weaviate_vector_module._shutdown_weaviate_client() + + assert weaviate_vector_module._weaviate_client is None + mock_client.close.assert_called_once() + mock_debug.assert_called_once() + + @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate.connect_to_custom") + def test_init_client_reuses_cached_client_without_reconnect(self, mock_connect): + cached_client = MagicMock() + cached_client.is_ready.return_value = True + weaviate_vector_module._weaviate_client = cached_client + + wv = WeaviateVector.__new__(WeaviateVector) + + client = wv._init_client(self.config) + + assert client is cached_client + mock_connect.assert_not_called() + + @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate.connect_to_custom") + def test_init_client_reuses_cached_client_after_lock_recheck(self, mock_connect): + cached_client = MagicMock() + cached_client.is_ready.side_effect = [False, True] + weaviate_vector_module._weaviate_client = cached_client + + wv = WeaviateVector.__new__(WeaviateVector) + + client = wv._init_client(self.config) + + assert client is cached_client + mock_connect.assert_not_called() + + @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.Auth.api_key", return_value="auth-token") + @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate.connect_to_custom") + def test_init_client_parses_custom_grpc_endpoint_without_scheme(self, mock_connect, mock_api_key): + mock_client = MagicMock() + mock_client.is_ready.return_value = True + mock_connect.return_value = mock_client + + wv = WeaviateVector.__new__(WeaviateVector) + config = WeaviateConfig( + endpoint="https://weaviate.example.com", + grpc_endpoint="grpc.example.com:6000", + api_key="test-key", + batch_size=50, + ) + + client = wv._init_client(config) + + assert client is mock_client + assert mock_connect.call_args.kwargs == { + "http_host": "weaviate.example.com", + "http_port": 443, + "http_secure": True, + "grpc_host": "grpc.example.com", + "grpc_port": 6000, + "grpc_secure": False, + "auth_credentials": "auth-token", + "skip_init_checks": True, + } + mock_api_key.assert_called_once_with("test-key") + + @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate.connect_to_custom") + def test_init_client_raises_when_database_not_ready(self, mock_connect): + mock_client = MagicMock() + mock_client.is_ready.return_value = False + mock_connect.return_value = mock_client + + wv = WeaviateVector.__new__(WeaviateVector) + + with pytest.raises(ConnectionError, match="Vector database is not ready"): + wv._init_client(self.config) + @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate") def test_init(self, mock_weaviate_module): """Test WeaviateVector initialization stores attributes including doc_type.""" @@ -62,6 +149,40 @@ class TestWeaviateVector(unittest.TestCase): assert wv._collection_name == self.collection_name assert "doc_type" in wv._attributes + def test_get_type_and_to_index_struct(self): + wv = WeaviateVector.__new__(WeaviateVector) + wv._collection_name = self.collection_name + + assert wv.get_type() == weaviate_vector_module.VectorType.WEAVIATE + assert wv.to_index_struct() == { + "type": weaviate_vector_module.VectorType.WEAVIATE, + "vector_store": {"class_prefix": self.collection_name}, + } + + def test_get_collection_name_uses_existing_class_prefix_and_appends_suffix(self): + dataset = SimpleNamespace(index_struct_dict={"vector_store": {"class_prefix": "ExistingCollection"}}, id="ds-1") + wv = WeaviateVector.__new__(WeaviateVector) + + assert wv.get_collection_name(dataset) == "ExistingCollection_Node" + + def test_get_collection_name_generates_name_from_dataset_id(self): + dataset = SimpleNamespace(index_struct_dict=None, id="ds-2") + wv = WeaviateVector.__new__(WeaviateVector) + + with patch.object(weaviate_vector_module.Dataset, "gen_collection_name_by_id", return_value="Generated_Node"): + assert wv.get_collection_name(dataset) == "Generated_Node" + + def test_create_calls_collection_setup_then_add_texts(self): + doc = Document(page_content="hello", metadata={}) + wv = WeaviateVector.__new__(WeaviateVector) + wv._create_collection = MagicMock() + wv.add_texts = MagicMock() + + wv.create([doc], [[0.1, 0.2]]) + + wv._create_collection.assert_called_once() + wv.add_texts.assert_called_once_with([doc], [[0.1, 0.2]]) + @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.redis_client") @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.dify_config") @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate") @@ -111,6 +232,44 @@ class TestWeaviateVector(unittest.TestCase): f"doc_type should be in collection schema properties, got: {property_names}" ) + @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.redis_client") + def test_create_collection_returns_early_when_cache_key_exists(self, mock_redis): + mock_lock = MagicMock() + mock_lock.__enter__ = MagicMock() + mock_lock.__exit__ = MagicMock() + mock_redis.lock.return_value = mock_lock + mock_redis.get.return_value = 1 + + wv = WeaviateVector.__new__(WeaviateVector) + wv._collection_name = self.collection_name + wv._client = MagicMock() + wv._ensure_properties = MagicMock() + + wv._create_collection() + + wv._client.collections.exists.assert_not_called() + wv._ensure_properties.assert_not_called() + mock_redis.set.assert_not_called() + + @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.redis_client") + def test_create_collection_logs_and_reraises_errors(self, mock_redis): + mock_lock = MagicMock() + mock_lock.__enter__ = MagicMock() + mock_lock.__exit__ = MagicMock(return_value=False) + mock_redis.lock.return_value = mock_lock + mock_redis.get.return_value = None + + wv = WeaviateVector.__new__(WeaviateVector) + wv._collection_name = self.collection_name + wv._client = MagicMock() + wv._client.collections.exists.side_effect = RuntimeError("create failed") + + with patch.object(weaviate_vector_module.logger, "exception") as mock_exception: + with pytest.raises(RuntimeError, match="create failed"): + wv._create_collection() + + mock_exception.assert_called_once() + @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate") def test_ensure_properties_adds_missing_doc_type(self, mock_weaviate_module): """Test that _ensure_properties adds doc_type when it's missing from existing schema.""" @@ -146,6 +305,29 @@ class TestWeaviateVector(unittest.TestCase): added_names = [call.args[0].name for call in add_calls] assert "doc_type" in added_names, f"doc_type should be added to existing collection, added: {added_names}" + @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate") + def test_ensure_properties_adds_all_missing_core_properties(self, mock_weaviate_module): + mock_client = MagicMock() + mock_client.is_ready.return_value = True + mock_weaviate_module.connect_to_custom.return_value = mock_client + mock_client.collections.exists.return_value = True + mock_col = MagicMock() + mock_client.collections.use.return_value = mock_col + mock_cfg = MagicMock() + mock_cfg.properties = [SimpleNamespace(name="text")] + mock_col.config.get.return_value = mock_cfg + + wv = WeaviateVector( + collection_name=self.collection_name, + config=self.config, + attributes=self.attributes, + ) + wv._ensure_properties() + + add_calls = mock_col.config.add_property.call_args_list + added_names = [call.args[0].name for call in add_calls] + assert added_names == ["document_id", "doc_id", "doc_type", "chunk_index"] + @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate") def test_ensure_properties_skips_existing_doc_type(self, mock_weaviate_module): """Test that _ensure_properties does not add doc_type when it already exists.""" @@ -179,6 +361,30 @@ class TestWeaviateVector(unittest.TestCase): # No properties should be added mock_col.config.add_property.assert_not_called() + @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate") + def test_ensure_properties_logs_warning_when_property_addition_fails(self, mock_weaviate_module): + mock_client = MagicMock() + mock_client.is_ready.return_value = True + mock_weaviate_module.connect_to_custom.return_value = mock_client + mock_client.collections.exists.return_value = True + mock_col = MagicMock() + mock_client.collections.use.return_value = mock_col + mock_cfg = MagicMock() + mock_cfg.properties = [] + mock_col.config.get.return_value = mock_cfg + mock_col.config.add_property.side_effect = RuntimeError("cannot add") + + wv = WeaviateVector( + collection_name=self.collection_name, + config=self.config, + attributes=self.attributes, + ) + + with patch.object(weaviate_vector_module.logger, "warning") as mock_warning: + wv._ensure_properties() + + assert mock_warning.call_count == 4 + @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate") def test_search_by_vector_returns_doc_type_in_metadata(self, mock_weaviate_module): """Test that search_by_vector returns doc_type in document metadata. @@ -226,6 +432,58 @@ class TestWeaviateVector(unittest.TestCase): assert len(docs) == 1 assert docs[0].metadata.get("doc_type") == "image" + @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate") + def test_search_by_vector_uses_document_filter_and_default_distance(self, mock_weaviate_module): + mock_client = MagicMock() + mock_client.is_ready.return_value = True + mock_weaviate_module.connect_to_custom.return_value = mock_client + mock_client.collections.exists.return_value = True + mock_col = MagicMock() + mock_client.collections.use.return_value = mock_col + + mock_obj = MagicMock() + mock_obj.properties = { + "text": "fallback distance result", + "document_id": "doc-1", + "doc_id": "segment-1", + } + mock_obj.metadata = None + + mock_result = MagicMock() + mock_result.objects = [mock_obj] + mock_col.query.near_vector.return_value = mock_result + + wv = WeaviateVector( + collection_name=self.collection_name, + config=self.config, + attributes=self.attributes, + ) + docs = wv.search_by_vector( + query_vector=[0.2] * 3, + document_ids_filter=["doc-1"], + top_k=2, + score_threshold=-1, + ) + + assert len(docs) == 1 + assert docs[0].metadata["score"] == 0.0 + assert mock_col.query.near_vector.call_args.kwargs["filters"] is not None + + @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate") + def test_search_by_vector_returns_empty_when_collection_is_missing(self, mock_weaviate_module): + mock_client = MagicMock() + mock_client.is_ready.return_value = True + mock_weaviate_module.connect_to_custom.return_value = mock_client + mock_client.collections.exists.return_value = False + + wv = WeaviateVector( + collection_name=self.collection_name, + config=self.config, + attributes=self.attributes, + ) + + assert wv.search_by_vector(query_vector=[0.1] * 3) == [] + @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate") def test_search_by_full_text_returns_doc_type_in_metadata(self, mock_weaviate_module): """Test that search_by_full_text also returns doc_type in document metadata.""" @@ -268,6 +526,49 @@ class TestWeaviateVector(unittest.TestCase): assert len(docs) == 1 assert docs[0].metadata.get("doc_type") == "image" + @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate") + def test_search_by_full_text_uses_document_filter(self, mock_weaviate_module): + mock_client = MagicMock() + mock_client.is_ready.return_value = True + mock_weaviate_module.connect_to_custom.return_value = mock_client + mock_client.collections.exists.return_value = True + mock_col = MagicMock() + mock_client.collections.use.return_value = mock_col + + mock_obj = MagicMock() + mock_obj.properties = {"text": "bm25 result", "doc_id": "segment-1"} + mock_obj.vector = [0.3, 0.4] + + mock_result = MagicMock() + mock_result.objects = [mock_obj] + mock_col.query.bm25.return_value = mock_result + + wv = WeaviateVector( + collection_name=self.collection_name, + config=self.config, + attributes=self.attributes, + ) + docs = wv.search_by_full_text(query="bm25", document_ids_filter=["doc-1"]) + + assert len(docs) == 1 + assert docs[0].vector == [0.3, 0.4] + assert mock_col.query.bm25.call_args.kwargs["filters"] is not None + + @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate") + def test_search_by_full_text_returns_empty_when_collection_is_missing(self, mock_weaviate_module): + mock_client = MagicMock() + mock_client.is_ready.return_value = True + mock_weaviate_module.connect_to_custom.return_value = mock_client + mock_client.collections.exists.return_value = False + + wv = WeaviateVector( + collection_name=self.collection_name, + config=self.config, + attributes=self.attributes, + ) + + assert wv.search_by_full_text(query="missing") == [] + @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate") def test_add_texts_stores_doc_type_in_properties(self, mock_weaviate_module): """Test that add_texts includes doc_type from document metadata in stored properties.""" @@ -310,6 +611,135 @@ class TestWeaviateVector(unittest.TestCase): stored_props = call_kwargs.kwargs.get("properties") assert stored_props.get("doc_type") == "image", f"doc_type should be stored in properties, got: {stored_props}" + @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate") + def test_add_texts_falls_back_to_random_uuid_and_serializes_datetime_metadata(self, mock_weaviate_module): + mock_client = MagicMock() + mock_client.is_ready.return_value = True + mock_weaviate_module.connect_to_custom.return_value = mock_client + mock_col = MagicMock() + mock_client.collections.use.return_value = mock_col + + mock_batch = MagicMock() + mock_batch.__enter__ = MagicMock(return_value=mock_batch) + mock_batch.__exit__ = MagicMock(return_value=False) + mock_col.batch.dynamic.return_value = mock_batch + + created_at = datetime.datetime(2024, 1, 2, 3, 4, 5, tzinfo=datetime.UTC) + doc = Document(page_content="text", metadata={"created_at": created_at}) + + wv = WeaviateVector( + collection_name=self.collection_name, + config=self.config, + attributes=self.attributes, + ) + + with ( + patch.object(wv, "_get_uuids", return_value=["not-a-uuid"]), + patch("core.rag.datasource.vdb.weaviate.weaviate_vector._uuid.uuid4", return_value="fallback-uuid"), + ): + ids = wv.add_texts(documents=[doc], embeddings=[[]]) + + assert ids == ["fallback-uuid"] + call_kwargs = mock_batch.add_object.call_args + assert call_kwargs.kwargs["uuid"] == "fallback-uuid" + assert call_kwargs.kwargs["vector"] is None + assert call_kwargs.kwargs["properties"]["created_at"] == created_at.isoformat() + + def test_is_uuid_handles_invalid_values(self): + wv = WeaviateVector.__new__(WeaviateVector) + + assert wv._is_uuid("123e4567-e89b-12d3-a456-426614174000") is True + assert wv._is_uuid("not-a-uuid") is False + + def test_delete_by_metadata_field_returns_when_collection_is_missing(self): + wv = WeaviateVector.__new__(WeaviateVector) + wv._collection_name = self.collection_name + wv._client = MagicMock() + wv._client.collections.exists.return_value = False + + wv.delete_by_metadata_field("doc_id", "segment-1") + + wv._client.collections.use.assert_not_called() + + def test_delete_by_metadata_field_deletes_matching_objects(self): + wv = WeaviateVector.__new__(WeaviateVector) + wv._collection_name = self.collection_name + wv._client = MagicMock() + wv._client.collections.exists.return_value = True + mock_col = MagicMock() + wv._client.collections.use.return_value = mock_col + + wv.delete_by_metadata_field("doc_id", "segment-1") + + mock_col.data.delete_many.assert_called_once() + + def test_delete_removes_collection_when_present(self): + wv = WeaviateVector.__new__(WeaviateVector) + wv._collection_name = self.collection_name + wv._client = MagicMock() + wv._client.collections.exists.return_value = True + + wv.delete() + + wv._client.collections.delete.assert_called_once_with(self.collection_name) + + def test_text_exists_handles_missing_and_present_documents(self): + wv = WeaviateVector.__new__(WeaviateVector) + wv._collection_name = self.collection_name + wv._client = MagicMock() + wv._client.collections.exists.side_effect = [False, True] + mock_col = MagicMock() + wv._client.collections.use.return_value = mock_col + mock_col.query.fetch_objects.return_value = SimpleNamespace(objects=[SimpleNamespace()]) + + assert wv.text_exists("segment-1") is False + assert wv.text_exists("segment-1") is True + + def test_delete_by_ids_handles_missing_collections_and_404s(self): + class FakeUnexpectedStatusCodeError(Exception): + def __init__(self, status_code): + super().__init__(f"status={status_code}") + self.status_code = status_code + + wv = WeaviateVector.__new__(WeaviateVector) + wv._collection_name = self.collection_name + wv._client = MagicMock() + wv._client.collections.exists.side_effect = [False, True] + mock_col = MagicMock() + wv._client.collections.use.return_value = mock_col + mock_col.data.delete_by_id.side_effect = [FakeUnexpectedStatusCodeError(404), None] + + with patch.object(weaviate_vector_module, "UnexpectedStatusCodeError", FakeUnexpectedStatusCodeError): + wv.delete_by_ids(["ignored"]) + wv.delete_by_ids(["missing-id", "ok-id"]) + + assert mock_col.data.delete_by_id.call_count == 2 + + def test_delete_by_ids_reraises_non_404_errors(self): + class FakeUnexpectedStatusCodeError(Exception): + def __init__(self, status_code): + super().__init__(f"status={status_code}") + self.status_code = status_code + + wv = WeaviateVector.__new__(WeaviateVector) + wv._collection_name = self.collection_name + wv._client = MagicMock() + wv._client.collections.exists.return_value = True + mock_col = MagicMock() + wv._client.collections.use.return_value = mock_col + mock_col.data.delete_by_id.side_effect = FakeUnexpectedStatusCodeError(500) + + with patch.object(weaviate_vector_module, "UnexpectedStatusCodeError", FakeUnexpectedStatusCodeError): + with pytest.raises(FakeUnexpectedStatusCodeError, match="status=500"): + wv.delete_by_ids(["bad-id"]) + + def test_json_serializable_converts_datetime(self): + wv = WeaviateVector.__new__(WeaviateVector) + created_at = datetime.datetime(2024, 1, 2, 3, 4, 5, tzinfo=datetime.UTC) + + assert wv._json_serializable(created_at) == created_at.isoformat() + assert wv._json_serializable("plain") == "plain" + class TestVectorDefaultAttributes(unittest.TestCase): """Tests for Vector class default attributes list.""" @@ -331,5 +761,65 @@ class TestVectorDefaultAttributes(unittest.TestCase): assert "doc_type" in vector._attributes, f"doc_type should be in default attributes, got: {vector._attributes}" +class TestWeaviateVectorFactory(unittest.TestCase): + def test_init_vector_uses_existing_dataset_index_struct(self): + dataset = SimpleNamespace( + id="dataset-1", + index_struct_dict={"vector_store": {"class_prefix": "ExistingCollection_Node"}}, + index_struct=None, + ) + attributes = ["doc_id"] + + with ( + patch.object(weaviate_vector_module.dify_config, "WEAVIATE_ENDPOINT", "http://localhost:8080"), + patch.object(weaviate_vector_module.dify_config, "WEAVIATE_GRPC_ENDPOINT", "localhost:50051"), + patch.object(weaviate_vector_module.dify_config, "WEAVIATE_API_KEY", "api-key"), + patch.object(weaviate_vector_module.dify_config, "WEAVIATE_BATCH_SIZE", 88), + patch( + "core.rag.datasource.vdb.weaviate.weaviate_vector.WeaviateVector", return_value="vector" + ) as mock_vector, + ): + factory = weaviate_vector_module.WeaviateVectorFactory() + result = factory.init_vector(dataset, attributes, MagicMock()) + + assert result == "vector" + config = mock_vector.call_args.kwargs["config"] + assert mock_vector.call_args.kwargs["collection_name"] == "ExistingCollection_Node" + assert mock_vector.call_args.kwargs["attributes"] == attributes + assert config.endpoint == "http://localhost:8080" + assert config.grpc_endpoint == "localhost:50051" + assert config.api_key == "api-key" + assert config.batch_size == 88 + assert dataset.index_struct is None + + def test_init_vector_generates_collection_and_updates_index_struct(self): + dataset = SimpleNamespace(id="dataset-2", index_struct_dict=None, index_struct=None) + attributes = ["doc_id", "doc_type"] + + with ( + patch.object(weaviate_vector_module.dify_config, "WEAVIATE_ENDPOINT", "http://localhost:8080"), + patch.object(weaviate_vector_module.dify_config, "WEAVIATE_GRPC_ENDPOINT", ""), + patch.object(weaviate_vector_module.dify_config, "WEAVIATE_API_KEY", None), + patch.object(weaviate_vector_module.dify_config, "WEAVIATE_BATCH_SIZE", 100), + patch.object( + weaviate_vector_module.Dataset, + "gen_collection_name_by_id", + return_value="GeneratedCollection_Node", + ), + patch( + "core.rag.datasource.vdb.weaviate.weaviate_vector.WeaviateVector", return_value="vector" + ) as mock_vector, + ): + factory = weaviate_vector_module.WeaviateVectorFactory() + result = factory.init_vector(dataset, attributes, MagicMock()) + + assert result == "vector" + assert mock_vector.call_args.kwargs["collection_name"] == "GeneratedCollection_Node" + assert json.loads(dataset.index_struct) == { + "type": weaviate_vector_module.VectorType.WEAVIATE, + "vector_store": {"class_prefix": "GeneratedCollection_Node"}, + } + + if __name__ == "__main__": unittest.main() diff --git a/api/tests/unit_tests/core/rag/docstore/test_dataset_docstore.py b/api/tests/unit_tests/core/rag/docstore/test_dataset_docstore.py index 13285cdad0..3ba0628fe2 100644 --- a/api/tests/unit_tests/core/rag/docstore/test_dataset_docstore.py +++ b/api/tests/unit_tests/core/rag/docstore/test_dataset_docstore.py @@ -163,7 +163,7 @@ class TestDatasetDocumentStoreAddDocuments: with ( patch("core.rag.docstore.dataset_docstore.db") as mock_db, - patch("core.rag.docstore.dataset_docstore.ModelManager") as mock_manager_class, + patch("core.rag.docstore.dataset_docstore.ModelManager.for_tenant") as mock_manager_class, ): mock_session = MagicMock() mock_db.session = mock_session diff --git a/api/tests/unit_tests/core/rag/embedding/test_cached_embedding.py b/api/tests/unit_tests/core/rag/embedding/test_cached_embedding.py index a0db25174d..bfa78fe565 100644 --- a/api/tests/unit_tests/core/rag/embedding/test_cached_embedding.py +++ b/api/tests/unit_tests/core/rag/embedding/test_cached_embedding.py @@ -15,8 +15,8 @@ import pytest from sqlalchemy.exc import IntegrityError from core.rag.embedding.cached_embedding import CacheEmbedding -from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey -from dify_graph.model_runtime.entities.text_embedding_entities import EmbeddingResult, EmbeddingUsage +from graphon.model_runtime.entities.model_entities import ModelPropertyKey +from graphon.model_runtime.entities.text_embedding_entities import EmbeddingResult, EmbeddingUsage from models.dataset import Embedding @@ -28,6 +28,7 @@ class TestCacheEmbeddingMultimodalDocuments: """Create a mock ModelInstance for testing.""" model_instance = Mock() model_instance.model = "vision-embedding-model" + model_instance.model_name = "vision-embedding-model" model_instance.provider = "openai" model_instance.credentials = {"api_key": "test-key"} @@ -64,7 +65,7 @@ class TestCacheEmbeddingMultimodalDocuments: def test_embed_single_multimodal_document_cache_miss(self, mock_model_instance, sample_multimodal_result): """Test embedding a single multimodal document when cache is empty.""" - cache_embedding = CacheEmbedding(mock_model_instance, user="test-user") + cache_embedding = CacheEmbedding(mock_model_instance) documents = [{"file_id": "file123", "content": "test content"}] with patch("core.rag.embedding.cached_embedding.db.session") as mock_session: @@ -316,13 +317,14 @@ class TestCacheEmbeddingMultimodalQuery: """Create a mock ModelInstance for testing.""" model_instance = Mock() model_instance.model = "vision-embedding-model" + model_instance.model_name = "vision-embedding-model" model_instance.provider = "openai" model_instance.credentials = {"api_key": "test-key"} return model_instance def test_embed_multimodal_query_cache_miss(self, mock_model_instance): """Test embedding multimodal query when Redis cache is empty.""" - cache_embedding = CacheEmbedding(mock_model_instance, user="test-user") + cache_embedding = CacheEmbedding(mock_model_instance) document = {"file_id": "file123"} vector = np.random.randn(1536) @@ -467,6 +469,7 @@ class TestCacheEmbeddingQueryErrors: """Create a mock ModelInstance for testing.""" model_instance = Mock() model_instance.model = "text-embedding-ada-002" + model_instance.model_name = "text-embedding-ada-002" model_instance.provider = "openai" model_instance.credentials = {"api_key": "test-key"} return model_instance @@ -532,24 +535,13 @@ class TestCacheEmbeddingQueryErrors: class TestCacheEmbeddingInitialization: """Test suite for CacheEmbedding initialization.""" - def test_initialization_with_user(self): - """Test CacheEmbedding initialization with user parameter.""" - model_instance = Mock() - model_instance.model = "test-model" - model_instance.provider = "test-provider" - - cache_embedding = CacheEmbedding(model_instance, user="test-user") - - assert cache_embedding._model_instance == model_instance - assert cache_embedding._user == "test-user" - - def test_initialization_without_user(self): - """Test CacheEmbedding initialization without user parameter.""" + def test_initialization_sets_model_instance(self): + """Test CacheEmbedding initialization stores the provided model instance.""" model_instance = Mock() model_instance.model = "test-model" + model_instance.model_name = "test-model" model_instance.provider = "test-provider" cache_embedding = CacheEmbedding(model_instance) assert cache_embedding._model_instance == model_instance - assert cache_embedding._user is None diff --git a/api/tests/unit_tests/core/rag/embedding/test_embedding_service.py b/api/tests/unit_tests/core/rag/embedding/test_embedding_service.py index 6e71f0c61f..392f0b458b 100644 --- a/api/tests/unit_tests/core/rag/embedding/test_embedding_service.py +++ b/api/tests/unit_tests/core/rag/embedding/test_embedding_service.py @@ -53,9 +53,9 @@ from sqlalchemy.exc import IntegrityError from core.entities.embedding_type import EmbeddingInputType from core.rag.embedding.cached_embedding import CacheEmbedding -from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey -from dify_graph.model_runtime.entities.text_embedding_entities import EmbeddingResult, EmbeddingUsage -from dify_graph.model_runtime.errors.invoke import ( +from graphon.model_runtime.entities.model_entities import ModelPropertyKey +from graphon.model_runtime.entities.text_embedding_entities import EmbeddingResult, EmbeddingUsage +from graphon.model_runtime.errors.invoke import ( InvokeAuthorizationError, InvokeConnectionError, InvokeRateLimitError, @@ -134,7 +134,7 @@ class TestCacheEmbeddingDocuments: - Correct return value """ # Arrange - cache_embedding = CacheEmbedding(mock_model_instance, user="test-user") + cache_embedding = CacheEmbedding(mock_model_instance) texts = ["Python is a programming language"] # Mock database query to return no cached embedding (cache miss) @@ -156,7 +156,6 @@ class TestCacheEmbeddingDocuments: # Verify model was invoked with correct parameters mock_model_instance.invoke_text_embedding.assert_called_once_with( texts=texts, - user="test-user", input_type=EmbeddingInputType.DOCUMENT, ) @@ -612,7 +611,7 @@ class TestCacheEmbeddingQuery: - Correct return value """ # Arrange - cache_embedding = CacheEmbedding(mock_model_instance, user="test-user") + cache_embedding = CacheEmbedding(mock_model_instance) query = "What is Python?" # Create embedding result @@ -651,7 +650,6 @@ class TestCacheEmbeddingQuery: # Verify model was invoked with QUERY input type mock_model_instance.invoke_text_embedding.assert_called_once_with( texts=[query], - user="test-user", input_type=EmbeddingInputType.QUERY, ) @@ -1568,25 +1566,16 @@ class TestEmbeddingEdgeCases: norm = np.linalg.norm(emb) assert abs(norm - 1.0) < 0.01 - def test_embed_query_with_user_context(self, mock_model_instance): - """Test query embedding with user context parameter. + def test_embed_query_uses_bound_model_instance(self, mock_model_instance): + """Test query embedding using the provided model instance. Verifies: - - User parameter is passed correctly to model - - User context is used for tracking/logging - - Embedding generation works with user context - - Context: - -------- - The user parameter is important for: - 1. Usage tracking per user - 2. Rate limiting per user - 3. Audit logging - 4. Personalization (in some models) + - Embedding generation works with the injected model instance + - Query input type is preserved + - No extra binding step is required at call time """ # Arrange - user_id = "user-12345" - cache_embedding = CacheEmbedding(mock_model_instance, user=user_id) + cache_embedding = CacheEmbedding(mock_model_instance) query = "What is machine learning?" # Create embedding @@ -1620,24 +1609,20 @@ class TestEmbeddingEdgeCases: assert isinstance(result, list) assert len(result) == 1536 - # Verify user parameter was passed to model mock_model_instance.invoke_text_embedding.assert_called_once_with( texts=[query], - user=user_id, input_type=EmbeddingInputType.QUERY, ) - def test_embed_documents_with_user_context(self, mock_model_instance): - """Test document embedding with user context parameter. + def test_embed_documents_uses_bound_model_instance(self, mock_model_instance): + """Test document embedding using the provided model instance. Verifies: - - User parameter is passed correctly for document embeddings - - Batch processing maintains user context - - User tracking works across batches + - Batch processing uses the injected model instance + - Document input type is preserved """ # Arrange - user_id = "user-67890" - cache_embedding = CacheEmbedding(mock_model_instance, user=user_id) + cache_embedding = CacheEmbedding(mock_model_instance) texts = ["Document 1", "Document 2"] # Create embeddings @@ -1673,10 +1658,8 @@ class TestEmbeddingEdgeCases: # Assert assert len(result) == 2 - # Verify user parameter was passed mock_model_instance.invoke_text_embedding.assert_called_once() call_args = mock_model_instance.invoke_text_embedding.call_args - assert call_args.kwargs["user"] == user_id assert call_args.kwargs["input_type"] == EmbeddingInputType.DOCUMENT diff --git a/api/tests/unit_tests/core/rag/extractor/firecrawl/test_firecrawl.py b/api/tests/unit_tests/core/rag/extractor/firecrawl/test_firecrawl.py index 2add12fd09..db49221583 100644 --- a/api/tests/unit_tests/core/rag/extractor/firecrawl/test_firecrawl.py +++ b/api/tests/unit_tests/core/rag/extractor/firecrawl/test_firecrawl.py @@ -164,6 +164,13 @@ class TestFirecrawlApp: with pytest.raises(Exception, match="No page found"): app.check_crawl_status("job-1") + def test_check_crawl_status_completed_with_null_total_raises(self, mocker: MockerFixture): + app = FirecrawlApp(api_key="fc-key", base_url="https://custom.firecrawl.dev") + mocker.patch("httpx.get", return_value=_response(200, {"status": "completed", "total": None, "data": []})) + + with pytest.raises(Exception, match="No page found"): + app.check_crawl_status("job-1") + def test_check_crawl_status_non_completed(self, mocker: MockerFixture): app = FirecrawlApp(api_key="fc-key", base_url="https://custom.firecrawl.dev") payload = {"status": "processing", "total": 5, "completed": 1, "data": []} @@ -203,6 +210,77 @@ class TestFirecrawlApp: with pytest.raises(Exception, match="Error saving crawl data"): app.check_crawl_status("job-err") + def test_check_crawl_status_follows_pagination(self, mocker: MockerFixture): + """When status is completed and next is present, follow pagination to collect all pages.""" + app = FirecrawlApp(api_key="fc-key", base_url="https://custom.firecrawl.dev") + page1 = { + "status": "completed", + "total": 3, + "completed": 3, + "next": "https://custom.firecrawl.dev/v2/crawl/job-42?skip=1", + "data": [{"metadata": {"title": "p1", "description": "", "sourceURL": "https://p1"}, "markdown": "m1"}], + } + page2 = { + "status": "completed", + "total": 3, + "completed": 3, + "next": "https://custom.firecrawl.dev/v2/crawl/job-42?skip=2", + "data": [{"metadata": {"title": "p2", "description": "", "sourceURL": "https://p2"}, "markdown": "m2"}], + } + page3 = { + "status": "completed", + "total": 3, + "completed": 3, + "data": [{"metadata": {"title": "p3", "description": "", "sourceURL": "https://p3"}, "markdown": "m3"}], + } + mocker.patch("httpx.get", side_effect=[_response(200, page1), _response(200, page2), _response(200, page3)]) + mock_storage = MagicMock() + mock_storage.exists.return_value = False + mocker.patch.object(firecrawl_module, "storage", mock_storage) + + result = app.check_crawl_status("job-42") + + assert result["status"] == "completed" + assert result["total"] == 3 + assert len(result["data"]) == 3 + assert [d["title"] for d in result["data"]] == ["p1", "p2", "p3"] + + def test_check_crawl_status_pagination_error_raises(self, mocker: MockerFixture): + """An error while fetching a paginated page raises an exception; no partial data is returned.""" + app = FirecrawlApp(api_key="fc-key", base_url="https://custom.firecrawl.dev") + page1 = { + "status": "completed", + "total": 2, + "completed": 2, + "next": "https://custom.firecrawl.dev/v2/crawl/job-99?skip=1", + "data": [{"metadata": {"title": "p1", "description": "", "sourceURL": "https://p1"}, "markdown": "m1"}], + } + mocker.patch("httpx.get", side_effect=[_response(200, page1), _response(500, {"error": "server error"})]) + + with pytest.raises(Exception, match="fetch next crawl page"): + app.check_crawl_status("job-99") + + def test_check_crawl_status_pagination_capped_at_total(self, mocker: MockerFixture): + """Pagination stops once pages_processed reaches total, even if next is present.""" + app = FirecrawlApp(api_key="fc-key", base_url="https://custom.firecrawl.dev") + # total=1: only the first page should be processed; next must not be followed + page1 = { + "status": "completed", + "total": 1, + "completed": 1, + "next": "https://custom.firecrawl.dev/v2/crawl/job-cap?skip=1", + "data": [{"metadata": {"title": "p1", "description": "", "sourceURL": "https://p1"}, "markdown": "m1"}], + } + mock_get = mocker.patch("httpx.get", return_value=_response(200, page1)) + mock_storage = MagicMock() + mock_storage.exists.return_value = False + mocker.patch.object(firecrawl_module, "storage", mock_storage) + + result = app.check_crawl_status("job-cap") + + assert len(result["data"]) == 1 + mock_get.assert_called_once() # initial fetch only; next URL is not followed due to cap + def test_extract_common_fields_and_status_formatter(self): app = FirecrawlApp(api_key="fc-key", base_url="https://custom.firecrawl.dev") diff --git a/api/tests/unit_tests/core/rag/indexing/processor/test_paragraph_index_processor.py b/api/tests/unit_tests/core/rag/indexing/processor/test_paragraph_index_processor.py index e6cc582398..c861871f02 100644 --- a/api/tests/unit_tests/core/rag/indexing/processor/test_paragraph_index_processor.py +++ b/api/tests/unit_tests/core/rag/indexing/processor/test_paragraph_index_processor.py @@ -4,11 +4,12 @@ from unittest.mock import Mock, patch import pytest from core.entities.knowledge_entities import PreviewDetail +from core.rag.index_processor.constant.index_type import IndexTechniqueType from core.rag.index_processor.processor.paragraph_index_processor import ParagraphIndexProcessor from core.rag.models.document import AttachmentDocument, Document -from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMUsage -from dify_graph.model_runtime.entities.message_entities import AssistantPromptMessage, ImagePromptMessageContent -from dify_graph.model_runtime.entities.model_entities import ModelFeature +from graphon.model_runtime.entities.llm_entities import LLMResult, LLMUsage +from graphon.model_runtime.entities.message_entities import AssistantPromptMessage, ImagePromptMessageContent +from graphon.model_runtime.entities.model_entities import ModelFeature class TestParagraphIndexProcessor: @@ -21,7 +22,7 @@ class TestParagraphIndexProcessor: dataset = Mock() dataset.id = "dataset-1" dataset.tenant_id = "tenant-1" - dataset.indexing_technique = "high_quality" + dataset.indexing_technique = IndexTechniqueType.HIGH_QUALITY dataset.is_multimodal = True return dataset @@ -167,7 +168,7 @@ class TestParagraphIndexProcessor: def test_load_uses_keyword_add_texts_with_keywords_when_economy( self, processor: ParagraphIndexProcessor, dataset: Mock ) -> None: - dataset.indexing_technique = "economy" + dataset.indexing_technique = IndexTechniqueType.ECONOMY docs = [Document(page_content="chunk", metadata={})] with patch("core.rag.index_processor.processor.paragraph_index_processor.Keyword") as mock_keyword_cls: @@ -178,7 +179,7 @@ class TestParagraphIndexProcessor: def test_load_uses_keyword_add_texts_without_keywords_when_economy( self, processor: ParagraphIndexProcessor, dataset: Mock ) -> None: - dataset.indexing_technique = "economy" + dataset.indexing_technique = IndexTechniqueType.ECONOMY docs = [Document(page_content="chunk", metadata={})] with patch("core.rag.index_processor.processor.paragraph_index_processor.Keyword") as mock_keyword_cls: @@ -208,7 +209,7 @@ class TestParagraphIndexProcessor: def test_clean_economy_deletes_summaries_and_keywords( self, processor: ParagraphIndexProcessor, dataset: Mock ) -> None: - dataset.indexing_technique = "economy" + dataset.indexing_technique = IndexTechniqueType.ECONOMY with ( patch( @@ -222,7 +223,7 @@ class TestParagraphIndexProcessor: mock_keyword_cls.return_value.delete.assert_called_once() def test_clean_deletes_keywords_by_ids(self, processor: ParagraphIndexProcessor, dataset: Mock) -> None: - dataset.indexing_technique = "economy" + dataset.indexing_technique = IndexTechniqueType.ECONOMY with patch("core.rag.index_processor.processor.paragraph_index_processor.Keyword") as mock_keyword_cls: processor.clean(dataset, ["node-2"], with_keywords=True) @@ -267,7 +268,7 @@ class TestParagraphIndexProcessor: def test_index_list_chunks_economy( self, processor: ParagraphIndexProcessor, dataset: Mock, dataset_document: Mock ) -> None: - dataset.indexing_technique = "economy" + dataset.indexing_technique = IndexTechniqueType.ECONOMY with ( patch( "core.rag.index_processor.processor.paragraph_index_processor.helper.generate_text_hash", @@ -399,7 +400,9 @@ class TestParagraphIndexProcessor: model_instance.invoke_llm.return_value = self._llm_result("text summary") with ( - patch("core.rag.index_processor.processor.paragraph_index_processor.ProviderManager") as mock_pm_cls, + patch( + "core.rag.index_processor.processor.paragraph_index_processor.create_plugin_provider_manager" + ) as mock_provider_manager, patch( "core.rag.index_processor.processor.paragraph_index_processor.ModelInstance", return_value=model_instance, @@ -410,7 +413,7 @@ class TestParagraphIndexProcessor: ), patch("core.rag.index_processor.processor.paragraph_index_processor.logger") as mock_logger, ): - mock_pm_cls.return_value.get_provider_model_bundle.return_value = Mock() + mock_provider_manager.return_value.get_provider_model_bundle.return_value = Mock() summary, usage = ParagraphIndexProcessor.generate_summary( "tenant-1", "text content", @@ -433,7 +436,9 @@ class TestParagraphIndexProcessor: image_content = ImagePromptMessageContent(format="url", mime_type="image/png", url="http://example.com/a.png") with ( - patch("core.rag.index_processor.processor.paragraph_index_processor.ProviderManager") as mock_pm_cls, + patch( + "core.rag.index_processor.processor.paragraph_index_processor.create_plugin_provider_manager" + ) as mock_provider_manager, patch( "core.rag.index_processor.processor.paragraph_index_processor.ModelInstance", return_value=model_instance, @@ -448,7 +453,7 @@ class TestParagraphIndexProcessor: ), patch("core.rag.index_processor.processor.paragraph_index_processor.deduct_llm_quota"), ): - mock_pm_cls.return_value.get_provider_model_bundle.return_value = Mock() + mock_provider_manager.return_value.get_provider_model_bundle.return_value = Mock() summary, _ = ParagraphIndexProcessor.generate_summary( "tenant-1", "text content", @@ -469,7 +474,9 @@ class TestParagraphIndexProcessor: image_file = SimpleNamespace() with ( - patch("core.rag.index_processor.processor.paragraph_index_processor.ProviderManager") as mock_pm_cls, + patch( + "core.rag.index_processor.processor.paragraph_index_processor.create_plugin_provider_manager" + ) as mock_provider_manager, patch( "core.rag.index_processor.processor.paragraph_index_processor.ModelInstance", return_value=model_instance, @@ -486,7 +493,7 @@ class TestParagraphIndexProcessor: ), patch("core.rag.index_processor.processor.paragraph_index_processor.logger") as mock_logger, ): - mock_pm_cls.return_value.get_provider_model_bundle.return_value = Mock() + mock_provider_manager.return_value.get_provider_model_bundle.return_value = Mock() with pytest.raises(ValueError, match="Expected LLMResult"): ParagraphIndexProcessor.generate_summary( "tenant-1", diff --git a/api/tests/unit_tests/core/rag/indexing/processor/test_parent_child_index_processor.py b/api/tests/unit_tests/core/rag/indexing/processor/test_parent_child_index_processor.py index 5c78cae7c1..b1ed735ee7 100644 --- a/api/tests/unit_tests/core/rag/indexing/processor/test_parent_child_index_processor.py +++ b/api/tests/unit_tests/core/rag/indexing/processor/test_parent_child_index_processor.py @@ -4,6 +4,7 @@ from unittest.mock import MagicMock, Mock, patch import pytest from core.entities.knowledge_entities import PreviewDetail +from core.rag.index_processor.constant.index_type import IndexTechniqueType from core.rag.index_processor.processor.parent_child_index_processor import ParentChildIndexProcessor from core.rag.models.document import AttachmentDocument, ChildDocument, Document from services.entities.knowledge_entities.knowledge_entities import ParentMode @@ -19,7 +20,7 @@ class TestParentChildIndexProcessor: dataset = Mock() dataset.id = "dataset-1" dataset.tenant_id = "tenant-1" - dataset.indexing_technique = "high_quality" + dataset.indexing_technique = IndexTechniqueType.HIGH_QUALITY dataset.is_multimodal = True return dataset diff --git a/api/tests/unit_tests/core/rag/indexing/processor/test_qa_index_processor.py b/api/tests/unit_tests/core/rag/indexing/processor/test_qa_index_processor.py index 99323eeec9..98c47bec8f 100644 --- a/api/tests/unit_tests/core/rag/indexing/processor/test_qa_index_processor.py +++ b/api/tests/unit_tests/core/rag/indexing/processor/test_qa_index_processor.py @@ -6,6 +6,7 @@ import pytest from werkzeug.datastructures import FileStorage from core.entities.knowledge_entities import PreviewDetail +from core.rag.index_processor.constant.index_type import IndexTechniqueType from core.rag.index_processor.processor.qa_index_processor import QAIndexProcessor from core.rag.models.document import AttachmentDocument, Document @@ -33,7 +34,7 @@ class TestQAIndexProcessor: dataset = Mock() dataset.id = "dataset-1" dataset.tenant_id = "tenant-1" - dataset.indexing_technique = "high_quality" + dataset.indexing_technique = IndexTechniqueType.HIGH_QUALITY dataset.is_multimodal = True return dataset @@ -207,7 +208,7 @@ class TestQAIndexProcessor: vector.create_multimodal.assert_called_once_with(multimodal_docs) def test_load_skips_vector_for_non_high_quality(self, processor: QAIndexProcessor, dataset: Mock) -> None: - dataset.indexing_technique = "economy" + dataset.indexing_technique = IndexTechniqueType.ECONOMY docs = [Document(page_content="Q1", metadata={"answer": "A1"})] with patch("core.rag.index_processor.processor.qa_index_processor.Vector") as mock_vector_cls: @@ -298,7 +299,7 @@ class TestQAIndexProcessor: def test_index_requires_high_quality( self, processor: QAIndexProcessor, dataset: Mock, dataset_document: Mock ) -> None: - dataset.indexing_technique = "economy" + dataset.indexing_technique = IndexTechniqueType.ECONOMY qa_chunks = SimpleNamespace(qa_chunks=[SimpleNamespace(question="Q1", answer="A1")]) with ( diff --git a/api/tests/unit_tests/core/rag/indexing/test_indexing_runner.py b/api/tests/unit_tests/core/rag/indexing/test_indexing_runner.py index b011ade884..059876d410 100644 --- a/api/tests/unit_tests/core/rag/indexing/test_indexing_runner.py +++ b/api/tests/unit_tests/core/rag/indexing/test_indexing_runner.py @@ -61,9 +61,9 @@ from core.indexing_runner import ( DocumentIsPausedError, IndexingRunner, ) -from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from core.rag.models.document import ChildDocument, Document -from dify_graph.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.entities.model_entities import ModelType from libs.datetime_utils import naive_utc_now from models.dataset import Dataset, DatasetProcessRule from models.dataset import Document as DatasetDocument @@ -76,7 +76,7 @@ from models.dataset import Document as DatasetDocument def create_mock_dataset( dataset_id: str | None = None, tenant_id: str | None = None, - indexing_technique: str = "high_quality", + indexing_technique: str = IndexTechniqueType.HIGH_QUALITY, embedding_provider: str = "openai", embedding_model: str = "text-embedding-ada-002", ) -> Mock: @@ -445,7 +445,7 @@ class TestIndexingRunnerTransform: """Mock all external dependencies for transform tests.""" with ( patch("core.indexing_runner.db") as mock_db, - patch("core.indexing_runner.ModelManager") as mock_model_manager, + patch("core.indexing_runner.ModelManager.for_tenant") as mock_model_manager, ): yield { "db": mock_db, @@ -458,7 +458,7 @@ class TestIndexingRunnerTransform: dataset = Mock(spec=Dataset) dataset.id = str(uuid.uuid4()) dataset.tenant_id = str(uuid.uuid4()) - dataset.indexing_technique = "high_quality" + dataset.indexing_technique = IndexTechniqueType.HIGH_QUALITY dataset.embedding_model_provider = "openai" dataset.embedding_model = "text-embedding-ada-002" return dataset @@ -482,7 +482,8 @@ class TestIndexingRunnerTransform: # Arrange runner = IndexingRunner() mock_embedding_instance = MagicMock() - runner.model_manager.get_model_instance.return_value = mock_embedding_instance + model_manager = mock_dependencies["model_manager"].return_value + model_manager.get_model_instance.return_value = mock_embedding_instance mock_processor = MagicMock() transformed_docs = [ @@ -509,7 +510,7 @@ class TestIndexingRunnerTransform: assert len(result) == 2 assert result[0].page_content == "Chunk 1" assert result[1].page_content == "Chunk 2" - runner.model_manager.get_model_instance.assert_called_once_with( + model_manager.get_model_instance.assert_called_once_with( tenant_id=sample_dataset.tenant_id, provider=sample_dataset.embedding_model_provider, model_type=ModelType.TEXT_EMBEDDING, @@ -521,7 +522,8 @@ class TestIndexingRunnerTransform: """Test transformation with economy indexing (no embeddings).""" # Arrange runner = IndexingRunner() - sample_dataset.indexing_technique = "economy" + model_manager = mock_dependencies["model_manager"].return_value + sample_dataset.indexing_technique = IndexTechniqueType.ECONOMY mock_processor = MagicMock() transformed_docs = [ @@ -539,14 +541,15 @@ class TestIndexingRunnerTransform: # Assert assert len(result) == 1 - runner.model_manager.get_model_instance.assert_not_called() + model_manager.get_model_instance.assert_not_called() def test_transform_with_custom_segmentation(self, mock_dependencies, sample_dataset, sample_text_docs): """Test transformation with custom segmentation rules.""" # Arrange runner = IndexingRunner() mock_embedding_instance = MagicMock() - runner.model_manager.get_model_instance.return_value = mock_embedding_instance + model_manager = mock_dependencies["model_manager"].return_value + model_manager.get_model_instance.return_value = mock_embedding_instance mock_processor = MagicMock() transformed_docs = [Document(page_content="Custom chunk", metadata={"doc_id": "custom1", "doc_hash": "hash1"})] @@ -586,7 +589,7 @@ class TestIndexingRunnerLoad: """Mock all external dependencies for load tests.""" with ( patch("core.indexing_runner.db") as mock_db, - patch("core.indexing_runner.ModelManager") as mock_model_manager, + patch("core.indexing_runner.ModelManager.for_tenant") as mock_model_manager, patch("core.indexing_runner.current_app") as mock_app, patch("core.indexing_runner.threading.Thread") as mock_thread, patch("core.indexing_runner.concurrent.futures.ThreadPoolExecutor") as mock_executor, @@ -605,7 +608,7 @@ class TestIndexingRunnerLoad: dataset = Mock(spec=Dataset) dataset.id = str(uuid.uuid4()) dataset.tenant_id = str(uuid.uuid4()) - dataset.indexing_technique = "high_quality" + dataset.indexing_technique = IndexTechniqueType.HIGH_QUALITY dataset.embedding_model_provider = "openai" dataset.embedding_model = "text-embedding-ada-002" return dataset @@ -645,7 +648,8 @@ class TestIndexingRunnerLoad: runner = IndexingRunner() mock_embedding_instance = MagicMock() mock_embedding_instance.get_text_embedding_num_tokens.return_value = 100 - runner.model_manager.get_model_instance.return_value = mock_embedding_instance + model_manager = mock_dependencies["model_manager"].return_value + model_manager.get_model_instance.return_value = mock_embedding_instance mock_processor = MagicMock() @@ -664,7 +668,7 @@ class TestIndexingRunnerLoad: runner._load(mock_processor, sample_dataset, sample_dataset_document, sample_documents) # Assert - runner.model_manager.get_model_instance.assert_called_once() + model_manager.get_model_instance.assert_called_once() # Verify executor was used for parallel processing assert mock_executor_instance.submit.called @@ -674,7 +678,7 @@ class TestIndexingRunnerLoad: """Test loading with economy indexing (keyword only).""" # Arrange runner = IndexingRunner() - sample_dataset.indexing_technique = "economy" + sample_dataset.indexing_technique = IndexTechniqueType.ECONOMY mock_processor = MagicMock() @@ -701,7 +705,7 @@ class TestIndexingRunnerLoad: # Arrange runner = IndexingRunner() sample_dataset_document.doc_form = IndexStructureType.PARENT_CHILD_INDEX - sample_dataset.indexing_technique = "high_quality" + sample_dataset.indexing_technique = IndexTechniqueType.HIGH_QUALITY # Add child documents for doc in sample_documents: @@ -714,7 +718,8 @@ class TestIndexingRunnerLoad: mock_embedding_instance = MagicMock() mock_embedding_instance.get_text_embedding_num_tokens.return_value = 50 - runner.model_manager.get_model_instance.return_value = mock_embedding_instance + model_manager = mock_dependencies["model_manager"].return_value + model_manager.get_model_instance.return_value = mock_embedding_instance mock_processor = MagicMock() @@ -754,7 +759,7 @@ class TestIndexingRunnerRun: with ( patch("core.indexing_runner.db") as mock_db, patch("core.indexing_runner.IndexProcessorFactory") as mock_factory, - patch("core.indexing_runner.ModelManager") as mock_model_manager, + patch("core.indexing_runner.ModelManager.for_tenant") as mock_model_manager, patch("core.indexing_runner.storage") as mock_storage, patch("core.indexing_runner.threading.Thread") as mock_thread, ): @@ -795,7 +800,7 @@ class TestIndexingRunnerRun: mock_dataset = Mock(spec=Dataset) mock_dataset.id = doc.dataset_id mock_dataset.tenant_id = doc.tenant_id - mock_dataset.indexing_technique = "economy" + mock_dataset.indexing_technique = IndexTechniqueType.ECONOMY mock_dependencies["db"].session.query.return_value.filter_by.return_value.first.return_value = mock_dataset mock_process_rule = Mock(spec=DatasetProcessRule) @@ -949,7 +954,7 @@ class TestIndexingRunnerRun: mock_dependencies["db"].session.get.side_effect = get_side_effect mock_dataset = Mock(spec=Dataset) - mock_dataset.indexing_technique = "economy" + mock_dataset.indexing_technique = IndexTechniqueType.ECONOMY mock_dependencies["db"].session.query.return_value.filter_by.return_value.first.return_value = mock_dataset mock_process_rule = Mock(spec=DatasetProcessRule) diff --git a/api/tests/unit_tests/core/rag/rerank/test_reranker.py b/api/tests/unit_tests/core/rag/rerank/test_reranker.py index b150d677f1..415597f336 100644 --- a/api/tests/unit_tests/core/rag/rerank/test_reranker.py +++ b/api/tests/unit_tests/core/rag/rerank/test_reranker.py @@ -28,7 +28,7 @@ from core.rag.rerank.rerank_factory import RerankRunnerFactory from core.rag.rerank.rerank_model import RerankModelRunner from core.rag.rerank.rerank_type import RerankMode from core.rag.rerank.weight_rerank import WeightRerankRunner -from dify_graph.model_runtime.entities.rerank_entities import RerankDocument, RerankResult +from graphon.model_runtime.entities.rerank_entities import RerankDocument, RerankResult def create_mock_model_instance() -> ModelInstance: @@ -57,7 +57,7 @@ class TestRerankModelRunner: @pytest.fixture(autouse=True) def mock_model_manager(self): """Auto-use fixture to patch ModelManager for all tests in this class.""" - with patch("core.rag.rerank.rerank_model.ModelManager", autospec=True) as mock_mm: + with patch("core.rag.rerank.rerank_model.ModelManager.for_tenant", autospec=True) as mock_mm: mock_mm.return_value.check_model_support_vision.return_value = False yield mock_mm @@ -352,12 +352,14 @@ class TestRerankModelRunner: # Assert: Empty result is returned assert len(result) == 0 - def test_user_parameter_passed_to_model(self, rerank_runner, mock_model_instance, sample_documents): - """Test that user parameter is passed to model invocation. + def test_run_uses_bound_model_instance( + self, rerank_runner, mock_model_instance, sample_documents, mock_model_manager + ): + """Test that rerank uses the bound model instance directly. Verifies: - - User ID is correctly forwarded to the model - - Model receives all expected parameters + - The injected model instance is used for invocation + - No late rebinding occurs through ModelManager.get_model_instance """ # Arrange: Mock rerank result mock_rerank_result = RerankResult( @@ -368,16 +370,18 @@ class TestRerankModelRunner: ) mock_model_instance.invoke_rerank.return_value = mock_rerank_result - # Act: Run reranking with user parameter + # Act: Run reranking result = rerank_runner.run( query="test", documents=sample_documents, - user="user123", ) - # Assert: User parameter is passed to model + # Assert: The injected model instance is invoked directly. + assert len(result) == 1 + mock_model_manager.return_value.get_model_instance.assert_not_called() call_kwargs = mock_model_instance.invoke_rerank.call_args.kwargs - assert call_kwargs["user"] == "user123" + assert call_kwargs["query"] == "test" + assert "user" not in call_kwargs class _ForwardingBaseRerankRunner(BaseRerankRunner): @@ -387,7 +391,6 @@ class _ForwardingBaseRerankRunner(BaseRerankRunner): documents: list[Document], score_threshold: float | None = None, top_n: int | None = None, - user: str | None = None, query_type: QueryType = QueryType.TEXT_QUERY, ) -> list[Document]: return super().run( @@ -395,7 +398,6 @@ class _ForwardingBaseRerankRunner(BaseRerankRunner): documents=documents, score_threshold=score_threshold, top_n=top_n, - user=user, query_type=query_type, ) @@ -424,7 +426,7 @@ class TestRerankModelRunnerMultimodal: Document(page_content="doc", metadata={"doc_id": "doc1"}, provider="dify"), ] - with patch("core.rag.rerank.rerank_model.ModelManager") as mock_mm: + with patch("core.rag.rerank.rerank_model.ModelManager.for_tenant") as mock_mm: mock_mm.return_value.check_model_support_vision.return_value = False result = rerank_runner.run(query="image-file-id", documents=documents, query_type=QueryType.IMAGE_QUERY) @@ -441,7 +443,7 @@ class TestRerankModelRunnerMultimodal: ) with ( - patch("core.rag.rerank.rerank_model.ModelManager") as mock_mm, + patch("core.rag.rerank.rerank_model.ModelManager.for_tenant") as mock_mm, patch.object( rerank_runner, "fetch_multimodal_rerank", @@ -539,8 +541,10 @@ class TestRerankModelRunnerMultimodal: ) mock_model_instance.invoke_multimodal_rerank.return_value = rerank_result + session = MagicMock() + session.query.return_value = query_chain with ( - patch("core.rag.rerank.rerank_model.db.session.query", return_value=query_chain), + patch("core.rag.rerank.rerank_model.db.session", session), patch("core.rag.rerank.rerank_model.storage.load_once", return_value=b"query-image-bytes"), ): result, unique_documents = rerank_runner.fetch_multimodal_rerank( @@ -548,7 +552,6 @@ class TestRerankModelRunnerMultimodal: documents=[text_doc], score_threshold=0.2, top_n=2, - user="user-1", query_type=QueryType.IMAGE_QUERY, ) @@ -557,7 +560,7 @@ class TestRerankModelRunnerMultimodal: invoke_kwargs = mock_model_instance.invoke_multimodal_rerank.call_args.kwargs assert invoke_kwargs["query"]["content_type"] == DocType.IMAGE assert invoke_kwargs["docs"][0]["content"] == "text-content" - assert invoke_kwargs["user"] == "user-1" + assert "user" not in invoke_kwargs def test_fetch_multimodal_rerank_raises_when_query_image_not_found(self, rerank_runner): query_chain = Mock() @@ -595,7 +598,7 @@ class TestWeightRerankRunner: @pytest.fixture def mock_model_manager(self): """Mock ModelManager for embedding model.""" - with patch("core.rag.rerank.weight_rerank.ModelManager", autospec=True) as mock_manager: + with patch("core.rag.rerank.weight_rerank.ModelManager.for_tenant") as mock_manager: yield mock_manager @pytest.fixture @@ -1145,7 +1148,7 @@ class TestRerankIntegration: @pytest.fixture(autouse=True) def mock_model_manager(self): """Auto-use fixture to patch ModelManager for all tests in this class.""" - with patch("core.rag.rerank.rerank_model.ModelManager", autospec=True) as mock_mm: + with patch("core.rag.rerank.rerank_model.ModelManager.for_tenant", autospec=True) as mock_mm: mock_mm.return_value.check_model_support_vision.return_value = False yield mock_mm @@ -1257,7 +1260,7 @@ class TestRerankEdgeCases: @pytest.fixture(autouse=True) def mock_model_manager(self): """Auto-use fixture to patch ModelManager for all tests in this class.""" - with patch("core.rag.rerank.rerank_model.ModelManager", autospec=True) as mock_mm: + with patch("core.rag.rerank.rerank_model.ModelManager.for_tenant", autospec=True) as mock_mm: mock_mm.return_value.check_model_support_vision.return_value = False yield mock_mm @@ -1527,7 +1530,7 @@ class TestRerankEdgeCases: # Mock dependencies with ( patch("core.rag.rerank.weight_rerank.JiebaKeywordTableHandler", autospec=True) as mock_jieba, - patch("core.rag.rerank.weight_rerank.ModelManager", autospec=True) as mock_manager, + patch("core.rag.rerank.weight_rerank.ModelManager.for_tenant") as mock_manager, patch("core.rag.rerank.weight_rerank.CacheEmbedding", autospec=True) as mock_cache, ): mock_handler = MagicMock() @@ -1598,7 +1601,7 @@ class TestRerankPerformance: @pytest.fixture(autouse=True) def mock_model_manager(self): """Auto-use fixture to patch ModelManager for all tests in this class.""" - with patch("core.rag.rerank.rerank_model.ModelManager", autospec=True) as mock_mm: + with patch("core.rag.rerank.rerank_model.ModelManager.for_tenant", autospec=True) as mock_mm: mock_mm.return_value.check_model_support_vision.return_value = False yield mock_mm @@ -1673,7 +1676,7 @@ class TestRerankPerformance: with ( patch("core.rag.rerank.weight_rerank.JiebaKeywordTableHandler", autospec=True) as mock_jieba, - patch("core.rag.rerank.weight_rerank.ModelManager", autospec=True) as mock_manager, + patch("core.rag.rerank.weight_rerank.ModelManager.for_tenant") as mock_manager, patch("core.rag.rerank.weight_rerank.CacheEmbedding", autospec=True) as mock_cache, ): mock_handler = MagicMock() @@ -1715,7 +1718,7 @@ class TestRerankErrorHandling: @pytest.fixture(autouse=True) def mock_model_manager(self): """Auto-use fixture to patch ModelManager for all tests in this class.""" - with patch("core.rag.rerank.rerank_model.ModelManager", autospec=True) as mock_mm: + with patch("core.rag.rerank.rerank_model.ModelManager.for_tenant", autospec=True) as mock_mm: mock_mm.return_value.check_model_support_vision.return_value = False yield mock_mm @@ -1824,7 +1827,7 @@ class TestRerankErrorHandling: with ( patch("core.rag.rerank.weight_rerank.JiebaKeywordTableHandler", autospec=True) as mock_jieba, - patch("core.rag.rerank.weight_rerank.ModelManager", autospec=True) as mock_manager, + patch("core.rag.rerank.weight_rerank.ModelManager.for_tenant") as mock_manager, patch("core.rag.rerank.weight_rerank.CacheEmbedding", autospec=True) as mock_cache, ): mock_handler = MagicMock() diff --git a/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py b/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py index 665e98bd9c..a7e62e7b0a 100644 --- a/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py +++ b/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py @@ -35,9 +35,10 @@ from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.workflow.nodes.knowledge_retrieval import exc from core.workflow.nodes.knowledge_retrieval.retrieval import KnowledgeRetrievalRequest -from dify_graph.model_runtime.entities.llm_entities import LLMUsage -from dify_graph.model_runtime.entities.model_entities import ModelFeature +from graphon.model_runtime.entities.llm_entities import LLMUsage +from graphon.model_runtime.entities.model_entities import ModelFeature from models.dataset import Dataset +from models.enums import CreatorUserRole # ==================== Helper Functions ==================== @@ -3747,6 +3748,24 @@ class TestDatasetRetrievalAdditionalHelpers: mock_session.add_all.assert_called() mock_session.commit.assert_called() + def test_on_query_normalizes_workflow_end_user_role(self, retrieval: DatasetRetrieval) -> None: + with patch("core.rag.retrieval.dataset_retrieval.db.session") as mock_session: + retrieval._on_query( + query="python", + attachment_ids=None, + dataset_ids=["d1"], + app_id="a1", + user_from="end-user", + user_id="u1", + ) + + mock_session.add_all.assert_called_once() + added_queries = mock_session.add_all.call_args.args[0] + + assert len(added_queries) == 1 + assert added_queries[0].created_by_role == CreatorUserRole.END_USER + mock_session.commit.assert_called_once() + def test_handle_invoke_result(self, retrieval: DatasetRetrieval) -> None: usage = LLMUsage.empty_usage() chunk_1 = SimpleNamespace( @@ -3836,7 +3855,7 @@ class TestDatasetRetrievalAdditionalHelpers: model_instance.model_type_instance.get_model_schema.return_value = Mock() with ( - patch("core.rag.retrieval.dataset_retrieval.ModelManager") as mock_manager, + patch("core.rag.retrieval.dataset_retrieval.ModelManager.for_tenant") as mock_manager, patch("core.rag.retrieval.dataset_retrieval.ModelConfigWithCredentialsEntity") as mock_cfg_entity, ): mock_manager.return_value.get_model_instance.return_value = model_instance @@ -4222,11 +4241,12 @@ class TestKnowledgeRetrievalCoverage: with ( patch.object(retrieval, "_check_knowledge_rate_limit"), patch.object(retrieval, "_get_available_datasets", return_value=[SimpleNamespace(id="dataset-1")]), - patch("core.rag.retrieval.dataset_retrieval.ModelManager") as mock_model_manager, + patch("core.rag.retrieval.dataset_retrieval.ModelManager.for_tenant") as mock_model_manager, ): mock_model_manager.return_value.get_model_instance.return_value = model_instance with pytest.raises(Exception) as exc_info: retrieval.knowledge_retrieval(request) + mock_model_manager.assert_called_once_with(tenant_id="tenant-1", user_id="user-1") assert error_cls in type(exc_info.value).__name__ @@ -4279,9 +4299,13 @@ class TestRetrieveCoverage: ), ) model_config = self._build_model_config() - model_config.provider_model_bundle.model_type_instance.get_model_schema.return_value = None - with patch("core.rag.retrieval.dataset_retrieval.ModelManager") as mock_model_manager: - mock_model_manager.return_value.get_model_instance.return_value = Mock() + model_instance = Mock() + model_instance.model_name = "gpt-4" + model_instance.credentials = {"api_key": "secret"} + model_instance.provider_model_bundle = Mock() + model_instance.model_type_instance.get_model_schema.return_value = None + with patch("core.rag.retrieval.dataset_retrieval.ModelManager.for_tenant") as mock_model_manager: + mock_model_manager.return_value.get_model_instance.return_value = model_instance result = retrieval.retrieve( app_id="app-1", user_id="user-1", @@ -4294,8 +4318,58 @@ class TestRetrieveCoverage: hit_callback=Mock(), message_id="m1", ) + mock_model_manager.assert_called_once_with(tenant_id="tenant-1", user_id="user-1") assert result == (None, []) + def test_retrieve_uses_bound_model_instance_schema_and_updates_model_config( + self, retrieval: DatasetRetrieval + ) -> None: + config = DatasetEntity( + dataset_ids=["d1"], + retrieve_config=DatasetRetrieveConfigEntity( + retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE, + metadata_filtering_mode="disabled", + ), + ) + model_config = self._build_model_config(features=[]) + model_config.provider_model_bundle.model_type_instance.get_model_schema.return_value = None + bound_schema = SimpleNamespace(features=[ModelFeature.TOOL_CALL]) + bound_bundle = Mock() + bound_model_instance = Mock() + bound_model_instance.model_name = "gpt-4" + bound_model_instance.credentials = {"api_key": "secret"} + bound_model_instance.provider_model_bundle = bound_bundle + bound_model_instance.model_type_instance.get_model_schema.return_value = bound_schema + + with ( + patch("core.rag.retrieval.dataset_retrieval.ModelManager.for_tenant") as mock_model_manager, + patch.object(retrieval, "_get_available_datasets", return_value=[SimpleNamespace(id="d1")]), + patch.object(retrieval, "get_metadata_filter_condition", return_value=(None, None)), + patch.object(retrieval, "single_retrieve", return_value=[]) as mock_single_retrieve, + ): + mock_model_manager.return_value.get_model_instance.return_value = bound_model_instance + context, files = retrieval.retrieve( + app_id="app-1", + user_id="user-1", + tenant_id="tenant-1", + model_config=model_config, + config=config, + query="python", + invoke_from=InvokeFrom.WEB_APP, + show_retrieve_source=False, + hit_callback=Mock(), + message_id="m1", + ) + + mock_model_manager.assert_called_once_with(tenant_id="tenant-1", user_id="user-1") + mock_single_retrieve.assert_called_once() + assert mock_single_retrieve.call_args.args[8] == PlanningStrategy.ROUTER + assert model_config.provider_model_bundle is bound_bundle + assert model_config.credentials == {"api_key": "secret"} + assert model_config.model_schema is bound_schema + assert context == "" + assert files == [] + def test_single_strategy_with_external_documents(self, retrieval: DatasetRetrieval) -> None: retrieve_config = DatasetRetrieveConfigEntity( retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE, @@ -4312,12 +4386,17 @@ class TestRetrieveCoverage: extra={"title": "External", "dataset_name": "External DS"}, ) with ( - patch("core.rag.retrieval.dataset_retrieval.ModelManager") as mock_model_manager, + patch("core.rag.retrieval.dataset_retrieval.ModelManager.for_tenant") as mock_model_manager, patch.object(retrieval, "_get_available_datasets", return_value=[SimpleNamespace(id="d1")]), patch.object(retrieval, "get_metadata_filter_condition", return_value=(None, None)), patch.object(retrieval, "single_retrieve", return_value=[external_doc]), ): - mock_model_manager.return_value.get_model_instance.return_value = Mock() + bound_model_instance = Mock() + bound_model_instance.model_name = "gpt-4" + bound_model_instance.credentials = {} + bound_model_instance.provider_model_bundle = Mock() + bound_model_instance.model_type_instance.get_model_schema.return_value = SimpleNamespace(features=[]) + mock_model_manager.return_value.get_model_instance.return_value = bound_model_instance context, files = retrieval.retrieve( app_id="app-1", user_id="user-1", @@ -4402,7 +4481,7 @@ class TestRetrieveCoverage: hit_callback = Mock() with ( - patch("core.rag.retrieval.dataset_retrieval.ModelManager") as mock_model_manager, + patch("core.rag.retrieval.dataset_retrieval.ModelManager.for_tenant") as mock_model_manager, patch.object(retrieval, "_get_available_datasets", return_value=[SimpleNamespace(id="d1")]), patch.object(retrieval, "get_metadata_filter_condition", return_value=(None, None)), patch.object(retrieval, "multiple_retrieve", return_value=[external_doc, dify_doc]), @@ -4413,7 +4492,14 @@ class TestRetrieveCoverage: patch("core.rag.retrieval.dataset_retrieval.sign_upload_file", return_value="https://signed"), patch("core.rag.retrieval.dataset_retrieval.db.session.execute") as mock_execute, ): - mock_model_manager.return_value.get_model_instance.return_value = Mock() + bound_model_instance = Mock() + bound_model_instance.model_name = "gpt-4" + bound_model_instance.credentials = {} + bound_model_instance.provider_model_bundle = Mock() + bound_model_instance.model_type_instance.get_model_schema.return_value = SimpleNamespace( + features=[ModelFeature.TOOL_CALL] + ) + mock_model_manager.return_value.get_model_instance.return_value = bound_model_instance mock_execute.side_effect = [execute_attachments, execute_docs, execute_datasets] context, files = retrieval.retrieve( app_id="app-1", @@ -4800,8 +4886,8 @@ class TestInternalHooksCoverage: dataset_docs = [ SimpleNamespace(id="doc-a", doc_form=IndexStructureType.PARENT_CHILD_INDEX), SimpleNamespace(id="doc-b", doc_form=IndexStructureType.PARENT_CHILD_INDEX), - SimpleNamespace(id="doc-c", doc_form="qa_model"), - SimpleNamespace(id="doc-d", doc_form="qa_model"), + SimpleNamespace(id="doc-c", doc_form=IndexStructureType.QA_INDEX), + SimpleNamespace(id="doc-d", doc_form=IndexStructureType.QA_INDEX), ] child_chunks = [SimpleNamespace(index_node_id="idx-a", segment_id="seg-a")] segments = [SimpleNamespace(index_node_id="idx-c", id="seg-c")] diff --git a/api/tests/unit_tests/core/rag/retrieval/test_multi_dataset_function_call_router.py b/api/tests/unit_tests/core/rag/retrieval/test_multi_dataset_function_call_router.py index cfa9094e12..43c521dcfd 100644 --- a/api/tests/unit_tests/core/rag/retrieval/test_multi_dataset_function_call_router.py +++ b/api/tests/unit_tests/core/rag/retrieval/test_multi_dataset_function_call_router.py @@ -1,7 +1,7 @@ from unittest.mock import Mock from core.rag.retrieval.router.multi_dataset_function_call_router import FunctionCallMultiDatasetRouter -from dify_graph.model_runtime.entities.llm_entities import LLMUsage +from graphon.model_runtime.entities.llm_entities import LLMUsage class TestFunctionCallMultiDatasetRouter: diff --git a/api/tests/unit_tests/core/rag/retrieval/test_multi_dataset_react_route.py b/api/tests/unit_tests/core/rag/retrieval/test_multi_dataset_react_route.py index e429563739..c56528cf55 100644 --- a/api/tests/unit_tests/core/rag/retrieval/test_multi_dataset_react_route.py +++ b/api/tests/unit_tests/core/rag/retrieval/test_multi_dataset_react_route.py @@ -3,8 +3,9 @@ from unittest.mock import Mock, patch from core.rag.retrieval.output_parser.react_output import ReactAction, ReactFinish from core.rag.retrieval.router.multi_dataset_react_route import ReactMultiDatasetRouter -from dify_graph.model_runtime.entities.llm_entities import LLMUsage -from dify_graph.model_runtime.entities.message_entities import PromptMessageRole +from graphon.model_runtime.entities.llm_entities import LLMUsage +from graphon.model_runtime.entities.message_entities import PromptMessageRole +from graphon.model_runtime.entities.model_entities import ModelType class TestReactMultiDatasetRouter: @@ -87,6 +88,7 @@ class TestReactMultiDatasetRouter: model_config = Mock() model_config.mode = "chat" model_config.parameters = {"temperature": 0.1} + model_instance = Mock() usage = LLMUsage.empty_usage() tools = [Mock(name="dataset-1"), Mock(name="dataset-2")] tools[0].name = "dataset-1" @@ -108,13 +110,14 @@ class TestReactMultiDatasetRouter: dataset_id, returned_usage = router._react_invoke( query="python", model_config=model_config, - model_instance=Mock(), + model_instance=model_instance, tools=tools, user_id="u1", tenant_id="t1", ) mock_chat_prompt.assert_called_once() + assert mock_prompt_transform.return_value.get_prompt.call_args.kwargs["model_instance"] is model_instance assert dataset_id == "dataset-2" assert returned_usage == usage @@ -162,7 +165,11 @@ class TestReactMultiDatasetRouter: model_instance = Mock() model_instance.invoke_llm.return_value = iter([chunk]) - with patch("core.rag.retrieval.router.multi_dataset_react_route.deduct_llm_quota") as mock_deduct: + with ( + patch("core.rag.retrieval.router.multi_dataset_react_route.ModelManager.for_tenant") as mock_manager, + patch("core.rag.retrieval.router.multi_dataset_react_route.deduct_llm_quota") as mock_deduct, + ): + mock_manager.return_value.get_model_instance.return_value = model_instance text, returned_usage = router._invoke_llm( completion_param={"temperature": 0.1}, model_instance=model_instance, @@ -174,6 +181,13 @@ class TestReactMultiDatasetRouter: assert text == "part" assert returned_usage == usage + mock_manager.assert_called_once_with(tenant_id="t1", user_id="u1") + mock_manager.return_value.get_model_instance.assert_called_once_with( + tenant_id="t1", + provider=model_instance.provider, + model_type=ModelType.LLM, + model=model_instance.model_name, + ) mock_deduct.assert_called_once() def test_handle_invoke_result_with_empty_usage(self) -> None: diff --git a/api/tests/unit_tests/core/repositories/test_celery_workflow_execution_repository.py b/api/tests/unit_tests/core/repositories/test_celery_workflow_execution_repository.py index e7eecfa297..2735ec512f 100644 --- a/api/tests/unit_tests/core/repositories/test_celery_workflow_execution_repository.py +++ b/api/tests/unit_tests/core/repositories/test_celery_workflow_execution_repository.py @@ -11,7 +11,7 @@ from uuid import uuid4 import pytest from core.repositories.celery_workflow_execution_repository import CeleryWorkflowExecutionRepository -from dify_graph.entities.workflow_execution import WorkflowExecution, WorkflowType +from graphon.entities.workflow_execution import WorkflowExecution, WorkflowType from libs.datetime_utils import naive_utc_now from models import Account, EndUser from models.enums import WorkflowRunTriggeredFrom diff --git a/api/tests/unit_tests/core/repositories/test_celery_workflow_node_execution_repository.py b/api/tests/unit_tests/core/repositories/test_celery_workflow_node_execution_repository.py index 2a83a4e802..05b4f3a053 100644 --- a/api/tests/unit_tests/core/repositories/test_celery_workflow_node_execution_repository.py +++ b/api/tests/unit_tests/core/repositories/test_celery_workflow_node_execution_repository.py @@ -11,12 +11,12 @@ from uuid import uuid4 import pytest from core.repositories.celery_workflow_node_execution_repository import CeleryWorkflowNodeExecutionRepository -from dify_graph.entities.workflow_node_execution import ( +from core.repositories.factory import OrderConfig +from graphon.entities.workflow_node_execution import ( WorkflowNodeExecution, WorkflowNodeExecutionStatus, ) -from dify_graph.enums import BuiltinNodeTypes -from dify_graph.repositories.workflow_node_execution_repository import OrderConfig +from graphon.enums import BuiltinNodeTypes from libs.datetime_utils import naive_utc_now from models import Account, EndUser from models.workflow import WorkflowNodeExecutionTriggeredFrom @@ -181,10 +181,10 @@ class TestCeleryWorkflowNodeExecutionRepository: repo.save(sample_workflow_node_execution) @patch("core.repositories.celery_workflow_node_execution_repository.save_workflow_node_execution_task") - def test_get_by_workflow_run_from_cache( + def test_get_by_workflow_execution_from_cache( self, mock_task, mock_session_factory, mock_account, sample_workflow_node_execution ): - """Test that get_by_workflow_run retrieves executions from cache.""" + """Test that get_by_workflow_execution retrieves executions from cache.""" repo = CeleryWorkflowNodeExecutionRepository( session_factory=mock_session_factory, user=mock_account, @@ -195,18 +195,18 @@ class TestCeleryWorkflowNodeExecutionRepository: # Save execution to cache first repo.save(sample_workflow_node_execution) - workflow_run_id = sample_workflow_node_execution.workflow_execution_id + workflow_execution_id = sample_workflow_node_execution.workflow_execution_id order_config = OrderConfig(order_by=["index"], order_direction="asc") - result = repo.get_by_workflow_run(workflow_run_id, order_config) + result = repo.get_by_workflow_execution(workflow_execution_id, order_config) # Verify results were retrieved from cache assert len(result) == 1 assert result[0].id == sample_workflow_node_execution.id assert result[0] is sample_workflow_node_execution - def test_get_by_workflow_run_without_order_config(self, mock_session_factory, mock_account): - """Test get_by_workflow_run without order configuration.""" + def test_get_by_workflow_execution_without_order_config(self, mock_session_factory, mock_account): + """Test get_by_workflow_execution without order configuration.""" repo = CeleryWorkflowNodeExecutionRepository( session_factory=mock_session_factory, user=mock_account, @@ -214,7 +214,7 @@ class TestCeleryWorkflowNodeExecutionRepository: triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, ) - result = repo.get_by_workflow_run("workflow-run-id") + result = repo.get_by_workflow_execution("workflow-run-id") # Should return empty list since nothing in cache assert len(result) == 0 @@ -236,7 +236,7 @@ class TestCeleryWorkflowNodeExecutionRepository: assert sample_workflow_node_execution.id in repo._execution_cache # Test retrieving from cache - result = repo.get_by_workflow_run(sample_workflow_node_execution.workflow_execution_id) + result = repo.get_by_workflow_execution(sample_workflow_node_execution.workflow_execution_id) assert len(result) == 1 assert result[0].id == sample_workflow_node_execution.id @@ -251,12 +251,12 @@ class TestCeleryWorkflowNodeExecutionRepository: ) # Create multiple executions for the same workflow - workflow_run_id = str(uuid4()) + workflow_execution_id = str(uuid4()) exec1 = WorkflowNodeExecution( id=str(uuid4()), node_execution_id=str(uuid4()), workflow_id=str(uuid4()), - workflow_execution_id=workflow_run_id, + workflow_execution_id=workflow_execution_id, index=1, node_id="node1", node_type=BuiltinNodeTypes.START, @@ -269,7 +269,7 @@ class TestCeleryWorkflowNodeExecutionRepository: id=str(uuid4()), node_execution_id=str(uuid4()), workflow_id=str(uuid4()), - workflow_execution_id=workflow_run_id, + workflow_execution_id=workflow_execution_id, index=2, node_id="node2", node_type=BuiltinNodeTypes.LLM, @@ -285,10 +285,10 @@ class TestCeleryWorkflowNodeExecutionRepository: # Verify both are cached and mapped assert len(repo._execution_cache) == 2 - assert len(repo._workflow_execution_mapping[workflow_run_id]) == 2 + assert len(repo._workflow_execution_mapping[workflow_execution_id]) == 2 # Test retrieval - result = repo.get_by_workflow_run(workflow_run_id) + result = repo.get_by_workflow_execution(workflow_execution_id) assert len(result) == 2 @patch("core.repositories.celery_workflow_node_execution_repository.save_workflow_node_execution_task") @@ -302,12 +302,12 @@ class TestCeleryWorkflowNodeExecutionRepository: ) # Create executions with different indices - workflow_run_id = str(uuid4()) + workflow_execution_id = str(uuid4()) exec1 = WorkflowNodeExecution( id=str(uuid4()), node_execution_id=str(uuid4()), workflow_id=str(uuid4()), - workflow_execution_id=workflow_run_id, + workflow_execution_id=workflow_execution_id, index=2, node_id="node2", node_type=BuiltinNodeTypes.START, @@ -320,7 +320,7 @@ class TestCeleryWorkflowNodeExecutionRepository: id=str(uuid4()), node_execution_id=str(uuid4()), workflow_id=str(uuid4()), - workflow_execution_id=workflow_run_id, + workflow_execution_id=workflow_execution_id, index=1, node_id="node1", node_type=BuiltinNodeTypes.LLM, @@ -336,14 +336,14 @@ class TestCeleryWorkflowNodeExecutionRepository: # Test ascending order order_config = OrderConfig(order_by=["index"], order_direction="asc") - result = repo.get_by_workflow_run(workflow_run_id, order_config) + result = repo.get_by_workflow_execution(workflow_execution_id, order_config) assert len(result) == 2 assert result[0].index == 1 assert result[1].index == 2 # Test descending order order_config = OrderConfig(order_by=["index"], order_direction="desc") - result = repo.get_by_workflow_run(workflow_run_id, order_config) + result = repo.get_by_workflow_execution(workflow_execution_id, order_config) assert len(result) == 2 assert result[0].index == 2 assert result[1].index == 1 diff --git a/api/tests/unit_tests/core/repositories/test_factory.py b/api/tests/unit_tests/core/repositories/test_factory.py index fe9eed0307..48327c3913 100644 --- a/api/tests/unit_tests/core/repositories/test_factory.py +++ b/api/tests/unit_tests/core/repositories/test_factory.py @@ -11,9 +11,12 @@ import pytest from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker -from core.repositories.factory import DifyCoreRepositoryFactory, RepositoryImportError -from dify_graph.repositories.workflow_execution_repository import WorkflowExecutionRepository -from dify_graph.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository +from core.repositories.factory import ( + DifyCoreRepositoryFactory, + RepositoryImportError, + WorkflowExecutionRepository, + WorkflowNodeExecutionRepository, +) from libs.module_loading import import_string from models import Account, EndUser from models.enums import WorkflowRunTriggeredFrom diff --git a/api/tests/unit_tests/core/repositories/test_human_input_form_repository_impl.py b/api/tests/unit_tests/core/repositories/test_human_input_form_repository_impl.py index 9af4d12664..18805bac59 100644 --- a/api/tests/unit_tests/core/repositories/test_human_input_form_repository_impl.py +++ b/api/tests/unit_tests/core/repositories/test_human_input_form_repository_impl.py @@ -14,16 +14,18 @@ from core.repositories.human_input_repository import ( HumanInputFormSubmissionRepository, _WorkspaceMemberInfo, ) -from dify_graph.nodes.human_input.entities import ( +from core.workflow.human_input_compat import ( EmailDeliveryConfig, EmailDeliveryMethod, EmailRecipients, ExternalRecipient, - FormDefinition, MemberRecipient, +) +from graphon.nodes.human_input.entities import ( + FormDefinition, UserAction, ) -from dify_graph.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus +from graphon.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus from libs.datetime_utils import naive_utc_now from models.human_input import ( EmailExternalRecipientPayload, @@ -89,9 +91,9 @@ class TestHumanInputFormRepositoryImplHelpers: form_id="form-id", delivery_id="delivery-id", recipients_config=EmailRecipients( - whole_workspace=False, + include_bound_group=False, items=[ - MemberRecipient(user_id="member-1"), + MemberRecipient(reference_id="member-1"), ExternalRecipient(email="external@example.com"), ], ), @@ -125,9 +127,9 @@ class TestHumanInputFormRepositoryImplHelpers: form_id="form-id", delivery_id="delivery-id", recipients_config=EmailRecipients( - whole_workspace=False, + include_bound_group=False, items=[ - MemberRecipient(user_id="missing-member"), + MemberRecipient(reference_id="missing-member"), ExternalRecipient(email="external@example.com"), ], ), @@ -156,7 +158,7 @@ class TestHumanInputFormRepositoryImplHelpers: form_id="form-id", delivery_id="delivery-id", recipients_config=EmailRecipients( - whole_workspace=True, + include_bound_group=True, items=[], ), ) @@ -182,7 +184,7 @@ class TestHumanInputFormRepositoryImplHelpers: form_id="form-id", delivery_id="delivery-id", recipients_config=EmailRecipients( - whole_workspace=False, + include_bound_group=False, items=[ ExternalRecipient(email="external@example.com"), ExternalRecipient(email="external@example.com"), @@ -212,9 +214,9 @@ class TestHumanInputFormRepositoryImplHelpers: form_id="form-id", delivery_id="delivery-id", recipients_config=EmailRecipients( - whole_workspace=False, + include_bound_group=False, items=[ - MemberRecipient(user_id="member-1"), + MemberRecipient(reference_id="member-1"), ExternalRecipient(email="shared@example.com"), ], ), @@ -243,7 +245,7 @@ class TestHumanInputFormRepositoryImplHelpers: method = EmailDeliveryMethod( config=EmailDeliveryConfig( recipients=EmailRecipients( - whole_workspace=True, + include_bound_group=True, items=[ExternalRecipient(email="external@example.com")], ), subject="subject", @@ -421,22 +423,22 @@ class TestHumanInputFormRepositoryImplPublicMethods: ) session = _FakeSession(scalars_results=[form, [recipient]]) _patch_repo_session_factory(monkeypatch, session) - repo = HumanInputFormRepositoryImpl(tenant_id="tenant-id") + repo = HumanInputFormRepositoryImpl(tenant_id="tenant-id", workflow_execution_id=form.workflow_run_id) - entity = repo.get_form(form.workflow_run_id, form.node_id) + entity = repo.get_form(form.node_id) assert entity is not None assert entity.id == form.id - assert entity.web_app_token == "token-123" + assert entity.submission_token == "token-123" assert len(entity.recipients) == 1 assert entity.recipients[0].token == "token-123" def test_get_form_returns_none_when_missing(self, monkeypatch: pytest.MonkeyPatch): session = _FakeSession(scalars_results=[None]) _patch_repo_session_factory(monkeypatch, session) - repo = HumanInputFormRepositoryImpl(tenant_id="tenant-id") + repo = HumanInputFormRepositoryImpl(tenant_id="tenant-id", workflow_execution_id="run-1") - assert repo.get_form("run-1", "node-1") is None + assert repo.get_form("node-1") is None def test_get_form_returns_unsubmitted_state(self, monkeypatch: pytest.MonkeyPatch): form = _DummyForm( @@ -451,9 +453,9 @@ class TestHumanInputFormRepositoryImplPublicMethods: ) session = _FakeSession(scalars_results=[form, []]) _patch_repo_session_factory(monkeypatch, session) - repo = HumanInputFormRepositoryImpl(tenant_id="tenant-id") + repo = HumanInputFormRepositoryImpl(tenant_id="tenant-id", workflow_execution_id=form.workflow_run_id) - entity = repo.get_form(form.workflow_run_id, form.node_id) + entity = repo.get_form(form.node_id) assert entity is not None assert entity.submitted is False @@ -476,9 +478,9 @@ class TestHumanInputFormRepositoryImplPublicMethods: ) session = _FakeSession(scalars_results=[form, []]) _patch_repo_session_factory(monkeypatch, session) - repo = HumanInputFormRepositoryImpl(tenant_id="tenant-id") + repo = HumanInputFormRepositoryImpl(tenant_id="tenant-id", workflow_execution_id=form.workflow_run_id) - entity = repo.get_form(form.workflow_run_id, form.node_id) + entity = repo.get_form(form.node_id) assert entity is not None assert entity.submitted is True diff --git a/api/tests/unit_tests/core/repositories/test_human_input_repository.py b/api/tests/unit_tests/core/repositories/test_human_input_repository.py index 4116e8b4a5..1297a95df1 100644 --- a/api/tests/unit_tests/core/repositories/test_human_input_repository.py +++ b/api/tests/unit_tests/core/repositories/test_human_input_repository.py @@ -11,6 +11,8 @@ from unittest.mock import MagicMock import pytest from core.repositories.human_input_repository import ( + FormCreateParams, + FormNotFoundError, HumanInputFormRecord, HumanInputFormRepositoryImpl, HumanInputFormSubmissionRepository, @@ -19,18 +21,16 @@ from core.repositories.human_input_repository import ( _InvalidTimeoutStatusError, _WorkspaceMemberInfo, ) -from dify_graph.nodes.human_input.entities import ( +from core.workflow.human_input_compat import ( EmailDeliveryConfig, EmailDeliveryMethod, EmailRecipients, ExternalRecipient, - HumanInputNodeData, MemberRecipient, - UserAction, WebAppDeliveryMethod, ) -from dify_graph.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus -from dify_graph.repositories.human_input_form_repository import FormCreateParams, FormNotFoundError +from graphon.nodes.human_input.entities import HumanInputNodeData, UserAction +from graphon.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus from libs.datetime_utils import naive_utc_now from models.human_input import HumanInputFormRecipient, RecipientType @@ -212,7 +212,7 @@ def test_recipient_entity_id_and_token_success() -> None: assert entity.token == "tok" -def test_form_entity_web_app_token_prefers_console_then_webapp_then_none() -> None: +def test_form_entity_submission_token_prefers_console_then_webapp_then_none() -> None: form = _DummyForm( id="f1", workflow_run_id="run", @@ -229,13 +229,13 @@ def test_form_entity_web_app_token_prefers_console_then_webapp_then_none() -> No ) entity = _HumanInputFormEntityImpl(form_model=form, recipient_models=[webapp, console]) # type: ignore[arg-type] - assert entity.web_app_token == "ctok" + assert entity.submission_token == "ctok" entity = _HumanInputFormEntityImpl(form_model=form, recipient_models=[webapp]) # type: ignore[arg-type] - assert entity.web_app_token == "wtok" + assert entity.submission_token == "wtok" entity = _HumanInputFormEntityImpl(form_model=form, recipient_models=[]) # type: ignore[arg-type] - assert entity.web_app_token is None + assert entity.submission_token is None def test_form_entity_submitted_data_parsed() -> None: @@ -364,8 +364,8 @@ def test_delivery_method_to_model_email_uses_build_email_recipients(monkeypatch: method = EmailDeliveryMethod( config=EmailDeliveryConfig( recipients=EmailRecipients( - whole_workspace=False, - items=[MemberRecipient(user_id="u1"), ExternalRecipient(email="e@example.com")], + include_bound_group=False, + items=[MemberRecipient(reference_id="u1"), ExternalRecipient(email="e@example.com")], ), subject="s", body="b", @@ -388,7 +388,7 @@ def test_build_email_recipients_uses_all_members_when_whole_workspace(monkeypatc session=MagicMock(), form_id="f", delivery_id="d", - recipients_config=EmailRecipients(whole_workspace=True, items=[ExternalRecipient(email="e@example.com")]), + recipients_config=EmailRecipients(include_bound_group=True, items=[ExternalRecipient(email="e@example.com")]), ) assert recipients == ["ok"] @@ -407,8 +407,8 @@ def test_build_email_recipients_uses_selected_members_when_not_whole_workspace(m form_id="f", delivery_id="d", recipients_config=EmailRecipients( - whole_workspace=False, - items=[MemberRecipient(user_id="u1"), ExternalRecipient(email="e@example.com")], + include_bound_group=False, + items=[MemberRecipient(reference_id="u1"), ExternalRecipient(email="e@example.com")], ), ) assert recipients == ["ok"] @@ -416,8 +416,8 @@ def test_build_email_recipients_uses_selected_members_when_not_whole_workspace(m def test_get_form_returns_entity_and_none_when_missing(monkeypatch: pytest.MonkeyPatch) -> None: _patch_session_factory(monkeypatch, _FakeSession(scalars_results=[None])) - repo = HumanInputFormRepositoryImpl(tenant_id="tenant") - assert repo.get_form("run", "node") is None + repo = HumanInputFormRepositoryImpl(tenant_id="tenant", workflow_execution_id="run") + assert repo.get_form("node") is None form = _DummyForm( id="f1", @@ -437,8 +437,8 @@ def test_get_form_returns_entity_and_none_when_missing(monkeypatch: pytest.Monke ) session = _FakeSession(scalars_results=[form, [recipient]]) _patch_session_factory(monkeypatch, session) - repo = HumanInputFormRepositoryImpl(tenant_id="tenant") - entity = repo.get_form("run", "node") + repo = HumanInputFormRepositoryImpl(tenant_id="tenant", workflow_execution_id="run") + entity = repo.get_form("node") assert entity is not None assert entity.id == "f1" assert entity.recipients[0].id == "r1" @@ -454,7 +454,13 @@ def test_create_form_adds_console_and_backstage_recipients(monkeypatch: pytest.M session = _FakeSession() _patch_session_factory(monkeypatch, session) - repo = HumanInputFormRepositoryImpl(tenant_id="tenant") + repo = HumanInputFormRepositoryImpl( + tenant_id="tenant", + app_id="app", + workflow_execution_id="run", + invoke_source="debugger", + submission_actor_id="acc-1", + ) form_config = HumanInputNodeData( title="Title", @@ -464,8 +470,7 @@ def test_create_form_adds_console_and_backstage_recipients(monkeypatch: pytest.M user_actions=[UserAction(id="submit", title="Submit")], ) params = FormCreateParams( - app_id="app", - workflow_execution_id="run", + workflow_execution_id=None, node_id="node", form_config=form_config, rendered_content="

hello

", @@ -473,16 +478,13 @@ def test_create_form_adds_console_and_backstage_recipients(monkeypatch: pytest.M display_in_ui=True, resolved_default_values={}, form_kind=HumanInputFormKind.RUNTIME, - console_recipient_required=True, - console_creator_account_id="acc-1", - backstage_recipient_required=True, ) entity = repo.create_form(params) assert entity.id == "form-id" assert entity.expiration_time == fixed_now + timedelta(hours=form_config.timeout) # Console token should take precedence when console recipient is present. - assert entity.web_app_token == "token-console" + assert entity.submission_token == "token-console" assert len(entity.recipients) == 3 diff --git a/api/tests/unit_tests/core/repositories/test_sqlalchemy_workflow_execution_repository.py b/api/tests/unit_tests/core/repositories/test_sqlalchemy_workflow_execution_repository.py index 232ab07882..6cb3c3c6ac 100644 --- a/api/tests/unit_tests/core/repositories/test_sqlalchemy_workflow_execution_repository.py +++ b/api/tests/unit_tests/core/repositories/test_sqlalchemy_workflow_execution_repository.py @@ -7,7 +7,7 @@ from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository -from dify_graph.entities.workflow_execution import WorkflowExecution, WorkflowExecutionStatus, WorkflowType +from graphon.entities.workflow_execution import WorkflowExecution, WorkflowExecutionStatus, WorkflowType from models import Account, CreatorUserRole, EndUser, WorkflowRun from models.enums import WorkflowRunTriggeredFrom diff --git a/api/tests/unit_tests/core/repositories/test_sqlalchemy_workflow_node_execution_repository.py b/api/tests/unit_tests/core/repositories/test_sqlalchemy_workflow_node_execution_repository.py index 73de15e2cf..6af7b02d4c 100644 --- a/api/tests/unit_tests/core/repositories/test_sqlalchemy_workflow_node_execution_repository.py +++ b/api/tests/unit_tests/core/repositories/test_sqlalchemy_workflow_node_execution_repository.py @@ -15,6 +15,7 @@ from sqlalchemy.exc import IntegrityError from sqlalchemy.orm import sessionmaker from configs import dify_config +from core.repositories.factory import OrderConfig from core.repositories.sqlalchemy_workflow_node_execution_repository import ( SQLAlchemyWorkflowNodeExecutionRepository, _deterministic_json_dump, @@ -22,13 +23,12 @@ from core.repositories.sqlalchemy_workflow_node_execution_repository import ( _find_first, _replace_or_append_offload, ) -from dify_graph.entities import WorkflowNodeExecution -from dify_graph.enums import ( +from graphon.entities import WorkflowNodeExecution +from graphon.enums import ( BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus, ) -from dify_graph.repositories.workflow_node_execution_repository import OrderConfig from models import Account, EndUser from models.enums import ExecutionOffLoadType from models.workflow import WorkflowNodeExecutionModel, WorkflowNodeExecutionOffload, WorkflowNodeExecutionTriggeredFrom @@ -768,5 +768,5 @@ def test_get_by_workflow_run_maps_to_domain(monkeypatch: pytest.MonkeyPatch) -> lambda max_workers: FakeExecutor(), ) - result = repo.get_by_workflow_run("run", order_config=None) + result = repo.get_by_workflow_execution("run", order_config=None) assert result == ["domain:db1", "domain:db2"] diff --git a/api/tests/unit_tests/core/repositories/test_workflow_node_execution_conflict_handling.py b/api/tests/unit_tests/core/repositories/test_workflow_node_execution_conflict_handling.py index 456c3dde12..abdbc72085 100644 --- a/api/tests/unit_tests/core/repositories/test_workflow_node_execution_conflict_handling.py +++ b/api/tests/unit_tests/core/repositories/test_workflow_node_execution_conflict_handling.py @@ -10,11 +10,11 @@ from sqlalchemy.orm import sessionmaker from core.repositories.sqlalchemy_workflow_node_execution_repository import ( SQLAlchemyWorkflowNodeExecutionRepository, ) -from dify_graph.entities.workflow_node_execution import ( +from graphon.entities.workflow_node_execution import ( WorkflowNodeExecution, WorkflowNodeExecutionStatus, ) -from dify_graph.enums import BuiltinNodeTypes +from graphon.enums import BuiltinNodeTypes from libs.datetime_utils import naive_utc_now from models import Account, WorkflowNodeExecutionTriggeredFrom diff --git a/api/tests/unit_tests/core/repositories/test_workflow_node_execution_truncation.py b/api/tests/unit_tests/core/repositories/test_workflow_node_execution_truncation.py index eeab81a178..5af1376a0a 100644 --- a/api/tests/unit_tests/core/repositories/test_workflow_node_execution_truncation.py +++ b/api/tests/unit_tests/core/repositories/test_workflow_node_execution_truncation.py @@ -17,11 +17,11 @@ from configs import dify_config from core.repositories.sqlalchemy_workflow_node_execution_repository import ( SQLAlchemyWorkflowNodeExecutionRepository, ) -from dify_graph.entities.workflow_node_execution import ( +from graphon.entities.workflow_node_execution import ( WorkflowNodeExecution, WorkflowNodeExecutionStatus, ) -from dify_graph.enums import BuiltinNodeTypes +from graphon.enums import BuiltinNodeTypes from models import Account, WorkflowNodeExecutionTriggeredFrom from models.enums import ExecutionOffLoadType from models.workflow import WorkflowNodeExecutionModel, WorkflowNodeExecutionOffload diff --git a/api/tests/unit_tests/core/test_file.py b/api/tests/unit_tests/core/test_file.py index 251d6fd25e..f17927f16b 100644 --- a/api/tests/unit_tests/core/test_file.py +++ b/api/tests/unit_tests/core/test_file.py @@ -1,6 +1,6 @@ import json -from dify_graph.file import File, FileTransferMethod, FileType, FileUploadConfig +from graphon.file import File, FileTransferMethod, FileType, FileUploadConfig from models.workflow import Workflow diff --git a/api/tests/unit_tests/core/test_model_manager.py b/api/tests/unit_tests/core/test_model_manager.py index 92e4b58473..afea9144c0 100644 --- a/api/tests/unit_tests/core/test_model_manager.py +++ b/api/tests/unit_tests/core/test_model_manager.py @@ -6,8 +6,8 @@ from pytest_mock import MockerFixture from core.entities.provider_entities import ModelLoadBalancingConfiguration from core.model_manager import LBModelManager -from dify_graph.model_runtime.entities.model_entities import ModelType from extensions.ext_redis import redis_client +from graphon.model_runtime.entities.model_entities import ModelType @pytest.fixture diff --git a/api/tests/unit_tests/core/test_provider_configuration.py b/api/tests/unit_tests/core/test_provider_configuration.py index 90ed1647aa..b19a21d7f4 100644 --- a/api/tests/unit_tests/core/test_provider_configuration.py +++ b/api/tests/unit_tests/core/test_provider_configuration.py @@ -12,9 +12,9 @@ from core.entities.provider_entities import ( RestrictModel, SystemConfiguration, ) -from dify_graph.model_runtime.entities.common_entities import I18nObject -from dify_graph.model_runtime.entities.model_entities import ModelType -from dify_graph.model_runtime.entities.provider_entities import ( +from graphon.model_runtime.entities.common_entities import I18nObject +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.entities.provider_entities import ( ConfigurateMethod, CredentialFormSchema, FormOption, diff --git a/api/tests/unit_tests/core/test_provider_manager.py b/api/tests/unit_tests/core/test_provider_manager.py index 69567c54eb..7f6a50af99 100644 --- a/api/tests/unit_tests/core/test_provider_manager.py +++ b/api/tests/unit_tests/core/test_provider_manager.py @@ -1,12 +1,26 @@ -from unittest.mock import Mock, PropertyMock, patch +from types import SimpleNamespace +from unittest.mock import MagicMock, Mock, PropertyMock, patch import pytest +from pytest_mock import MockerFixture from core.entities.provider_entities import ModelSettings from core.provider_manager import ProviderManager -from dify_graph.model_runtime.entities.common_entities import I18nObject -from dify_graph.model_runtime.entities.model_entities import ModelType -from models.provider import LoadBalancingModelConfig, ProviderModelSetting +from graphon.model_runtime.entities.common_entities import I18nObject +from graphon.model_runtime.entities.model_entities import ModelType +from models.provider import LoadBalancingModelConfig, ProviderModelSetting, TenantDefaultModel +from models.provider_ids import ModelProviderID + + +def _build_provider_manager(mocker: MockerFixture) -> ProviderManager: + return ProviderManager(model_runtime=mocker.Mock()) + + +def _build_session_context(session: Mock) -> MagicMock: + session_cm = MagicMock() + session_cm.__enter__.return_value = session + session_cm.__exit__.return_value = False + return session_cm @pytest.fixture @@ -28,7 +42,7 @@ def mock_provider_entity(): return mock_entity -def test__to_model_settings(mock_provider_entity): +def test__to_model_settings(mocker: MockerFixture, mock_provider_entity): # Mocking the inputs ps = ProviderModelSetting( tenant_id="tenant_id", @@ -69,7 +83,7 @@ def test__to_model_settings(mock_provider_entity): "core.helper.model_provider_cache.ProviderCredentialsCache.get", return_value={"openai_api_key": "fake_key"}, ): - provider_manager = ProviderManager() + provider_manager = _build_provider_manager(mocker) # Running the method result = provider_manager._to_model_settings( @@ -89,7 +103,7 @@ def test__to_model_settings(mock_provider_entity): assert result[0].load_balancing_configs[1].name == "first" -def test__to_model_settings_only_one_lb(mock_provider_entity): +def test__to_model_settings_only_one_lb(mocker: MockerFixture, mock_provider_entity): # Mocking the inputs ps = ProviderModelSetting( @@ -119,7 +133,7 @@ def test__to_model_settings_only_one_lb(mock_provider_entity): "core.helper.model_provider_cache.ProviderCredentialsCache.get", return_value={"openai_api_key": "fake_key"}, ): - provider_manager = ProviderManager() + provider_manager = _build_provider_manager(mocker) # Running the method result = provider_manager._to_model_settings( @@ -137,7 +151,7 @@ def test__to_model_settings_only_one_lb(mock_provider_entity): assert len(result[0].load_balancing_configs) == 0 -def test__to_model_settings_lb_disabled(mock_provider_entity): +def test__to_model_settings_lb_disabled(mocker: MockerFixture, mock_provider_entity): # Mocking the inputs ps = ProviderModelSetting( tenant_id="tenant_id", @@ -176,7 +190,7 @@ def test__to_model_settings_lb_disabled(mock_provider_entity): "core.helper.model_provider_cache.ProviderCredentialsCache.get", return_value={"openai_api_key": "fake_key"}, ): - provider_manager = ProviderManager() + provider_manager = _build_provider_manager(mocker) # Running the method result = provider_manager._to_model_settings( @@ -194,7 +208,7 @@ def test__to_model_settings_lb_disabled(mock_provider_entity): assert len(result[0].load_balancing_configs) == 0 -def test_get_default_model_uses_first_available_active_model(): +def test_get_default_model_uses_first_available_active_model(mocker: MockerFixture): mock_session = Mock() mock_session.scalar.return_value = None @@ -204,7 +218,7 @@ def test_get_default_model_uses_first_available_active_model(): Mock(model="gpt-4", provider=Mock(provider="openai")), ] - manager = ProviderManager() + manager = _build_provider_manager(mocker) with ( patch("core.provider_manager.db.session", mock_session), patch.object(manager, "get_configurations", return_value=provider_configurations), @@ -228,3 +242,345 @@ def test_get_default_model_uses_first_available_active_model(): assert saved_default_model.model_name == "gpt-3.5-turbo" assert saved_default_model.provider_name == "openai" mock_session.commit.assert_called_once() + + +def test_get_default_model_returns_none_when_no_default_or_active_models(mocker: MockerFixture): + mock_session = Mock() + mock_session.scalar.return_value = None + provider_configurations = Mock() + provider_configurations.get_models.return_value = [] + manager = _build_provider_manager(mocker) + + with ( + patch("core.provider_manager.db.session", mock_session), + patch.object(manager, "get_configurations", return_value=provider_configurations), + patch("core.provider_manager.ModelProviderFactory") as mock_factory_cls, + ): + result = manager.get_default_model("tenant-id", ModelType.LLM) + + assert result is None + provider_configurations.get_models.assert_called_once_with(model_type=ModelType.LLM, only_active=True) + mock_factory_cls.assert_not_called() + mock_session.add.assert_not_called() + mock_session.commit.assert_not_called() + + +def test_get_default_model_uses_injected_runtime_for_existing_default_record(mocker: MockerFixture): + existing_default_model = TenantDefaultModel( + tenant_id="tenant-id", + provider_name="openai", + model_name="gpt-4", + model_type=ModelType.LLM.to_origin_model_type(), + ) + mock_session = Mock() + mock_session.scalar.return_value = existing_default_model + manager = _build_provider_manager(mocker) + + with ( + patch("core.provider_manager.db.session", mock_session), + patch("core.provider_manager.ModelProviderFactory") as mock_factory_cls, + ): + mock_factory_cls.return_value.get_provider_schema.return_value = Mock( + provider="openai", + label=I18nObject(en_US="OpenAI", zh_Hans="OpenAI"), + icon_small=I18nObject(en_US="icon_small.png", zh_Hans="icon_small.png"), + supported_model_types=[ModelType.LLM], + ) + + result = manager.get_default_model("tenant-id", ModelType.LLM) + + mock_factory_cls.assert_called_once_with(model_runtime=manager._model_runtime) + assert result is not None + assert result.model == "gpt-4" + assert result.provider.provider == "openai" + + +def test_get_configurations_uses_injected_runtime_and_adds_provider_aliases(mocker: MockerFixture): + manager = _build_provider_manager(mocker) + provider_records = {"openai": [SimpleNamespace(provider_name="openai")]} + provider_model_records = {"openai": [SimpleNamespace(provider_name="openai")]} + preferred_provider_records = {"openai": SimpleNamespace(preferred_provider_type="system")} + + with ( + patch.object(manager, "_get_all_providers", return_value=provider_records), + patch.object(manager, "_init_trial_provider_records", return_value=provider_records), + patch.object(manager, "_get_all_provider_models", return_value=provider_model_records), + patch.object(manager, "_get_all_preferred_model_providers", return_value=preferred_provider_records), + patch.object(manager, "_get_all_provider_model_settings", return_value={}), + patch.object(manager, "_get_all_provider_load_balancing_configs", return_value={}), + patch.object(manager, "_get_all_provider_model_credentials", return_value={}), + patch("core.provider_manager.ModelProviderFactory") as mock_factory_cls, + ): + mock_factory_cls.return_value.get_providers.return_value = [] + + result = manager.get_configurations("tenant-id") + + expected_alias = str(ModelProviderID("openai")) + mock_factory_cls.assert_called_once_with(model_runtime=manager._model_runtime) + assert result.tenant_id == "tenant-id" + assert expected_alias in provider_records + assert expected_alias in provider_model_records + assert expected_alias in preferred_provider_records + + +@pytest.mark.parametrize( + ("provider_name", "expected_provider_names"), + [ + ("openai", ["openai", "langgenius/openai/openai"]), + ("langgenius/openai/openai", ["langgenius/openai/openai", "openai"]), + ("langgenius/gemini/google", ["langgenius/gemini/google", "google"]), + ], +) +def test_get_provider_names_returns_short_and_full_aliases(provider_name: str, expected_provider_names: list[str]): + assert ProviderManager._get_provider_names(provider_name) == expected_provider_names + + +def test_get_provider_model_bundle_raises_for_unknown_provider(mocker: MockerFixture): + manager = _build_provider_manager(mocker) + + with patch.object(manager, "get_configurations", return_value={}): + with pytest.raises(ValueError, match="Provider openai does not exist."): + manager.get_provider_model_bundle("tenant-id", "openai", ModelType.LLM) + + +def test_get_configurations_binds_manager_runtime_to_provider_configuration( + mocker: MockerFixture, mock_provider_entity +): + manager = _build_provider_manager(mocker) + provider_configuration = Mock() + provider_factory = Mock() + provider_factory.get_providers.return_value = [mock_provider_entity] + custom_configuration = SimpleNamespace(provider=None, models=[]) + system_configuration = SimpleNamespace(enabled=False, quota_configurations=[], current_quota_type=None) + + with ( + patch.object(manager, "_get_all_providers", return_value={"openai": []}), + patch.object(manager, "_init_trial_provider_records", return_value={"openai": []}), + patch.object(manager, "_get_all_provider_models", return_value={"openai": []}), + patch.object(manager, "_get_all_preferred_model_providers", return_value={}), + patch.object(manager, "_get_all_provider_model_settings", return_value={}), + patch.object(manager, "_get_all_provider_load_balancing_configs", return_value={}), + patch.object(manager, "_get_all_provider_model_credentials", return_value={}), + patch.object(manager, "_to_custom_configuration", return_value=custom_configuration), + patch.object(manager, "_to_system_configuration", return_value=system_configuration), + patch.object(manager, "_to_model_settings", return_value=[]), + patch("core.provider_manager.ModelProviderFactory", return_value=provider_factory), + patch("core.provider_manager.ProviderConfiguration", return_value=provider_configuration), + ): + manager.get_configurations("tenant-id") + + provider_configuration.bind_model_runtime.assert_called_once_with(manager._model_runtime) + + +def test_get_provider_model_bundle_returns_selected_model_type_instance(mocker: MockerFixture): + manager = _build_provider_manager(mocker) + provider_configuration = Mock() + model_type_instance = Mock() + provider_configuration.get_model_type_instance.return_value = model_type_instance + expected_bundle = Mock() + + with ( + patch.object(manager, "get_configurations", return_value={"openai": provider_configuration}), + patch("core.provider_manager.ProviderModelBundle", return_value=expected_bundle) as mock_bundle, + ): + result = manager.get_provider_model_bundle("tenant-id", "openai", ModelType.LLM) + + provider_configuration.get_model_type_instance.assert_called_once_with(ModelType.LLM) + mock_bundle.assert_called_once_with( + configuration=provider_configuration, + model_type_instance=model_type_instance, + ) + assert result is expected_bundle + + +def test_get_first_provider_first_model_returns_none_when_no_models(mocker: MockerFixture): + manager = _build_provider_manager(mocker) + provider_configurations = Mock() + provider_configurations.get_models.return_value = [] + + with patch.object(manager, "get_configurations", return_value=provider_configurations): + result = manager.get_first_provider_first_model("tenant-id", ModelType.LLM) + + assert result == (None, None) + provider_configurations.get_models.assert_called_once_with(model_type=ModelType.LLM, only_active=False) + + +def test_get_first_provider_first_model_returns_first_model_and_provider(mocker: MockerFixture): + manager = _build_provider_manager(mocker) + provider_configurations = Mock() + provider_configurations.get_models.return_value = [ + Mock(model="gpt-4", provider=Mock(provider="openai")), + Mock(model="gpt-4o", provider=Mock(provider="openai")), + ] + + with patch.object(manager, "get_configurations", return_value=provider_configurations): + result = manager.get_first_provider_first_model("tenant-id", ModelType.LLM) + + assert result == ("openai", "gpt-4") + + +def test_update_default_model_record_raises_for_unknown_provider(mocker: MockerFixture): + manager = _build_provider_manager(mocker) + + with patch.object(manager, "get_configurations", return_value={}): + with pytest.raises(ValueError, match="Provider openai does not exist."): + manager.update_default_model_record("tenant-id", ModelType.LLM, "openai", "gpt-4") + + +def test_update_default_model_record_raises_for_unknown_model(mocker: MockerFixture): + manager = _build_provider_manager(mocker) + provider_configurations = MagicMock() + provider_configurations.__contains__.return_value = True + provider_configurations.get_models.return_value = [Mock(model="gpt-4")] + + with patch.object(manager, "get_configurations", return_value=provider_configurations): + with pytest.raises(ValueError, match="Model gpt-3.5-turbo does not exist."): + manager.update_default_model_record("tenant-id", ModelType.LLM, "openai", "gpt-3.5-turbo") + + provider_configurations.get_models.assert_called_once_with(model_type=ModelType.LLM, only_active=True) + + +def test_update_default_model_record_updates_existing_record(mocker: MockerFixture): + manager = _build_provider_manager(mocker) + provider_configurations = MagicMock() + provider_configurations.__contains__.return_value = True + provider_configurations.get_models.return_value = [Mock(model="gpt-3.5-turbo")] + existing_default_model = TenantDefaultModel( + tenant_id="tenant-id", + provider_name="anthropic", + model_name="claude-3-sonnet", + model_type=ModelType.LLM.to_origin_model_type(), + ) + mock_session = Mock() + mock_session.scalar.return_value = existing_default_model + + with ( + patch.object(manager, "get_configurations", return_value=provider_configurations), + patch("core.provider_manager.db.session", mock_session), + ): + result = manager.update_default_model_record("tenant-id", ModelType.LLM, "openai", "gpt-3.5-turbo") + + assert result is existing_default_model + assert existing_default_model.provider_name == "openai" + assert existing_default_model.model_name == "gpt-3.5-turbo" + mock_session.commit.assert_called_once() + mock_session.add.assert_not_called() + + +def test_update_default_model_record_creates_record_with_origin_model_type(mocker: MockerFixture): + manager = _build_provider_manager(mocker) + provider_configurations = MagicMock() + provider_configurations.__contains__.return_value = True + provider_configurations.get_models.return_value = [Mock(model="gpt-4")] + mock_session = Mock() + mock_session.scalar.return_value = None + + with ( + patch.object(manager, "get_configurations", return_value=provider_configurations), + patch("core.provider_manager.db.session", mock_session), + ): + result = manager.update_default_model_record("tenant-id", ModelType.LLM, "openai", "gpt-4") + + mock_session.add.assert_called_once() + created_default_model = mock_session.add.call_args.args[0] + assert result is created_default_model + assert created_default_model.tenant_id == "tenant-id" + assert created_default_model.provider_name == "openai" + assert created_default_model.model_name == "gpt-4" + assert created_default_model.model_type == ModelType.LLM.to_origin_model_type() + mock_session.commit.assert_called_once() + + +def test_get_all_providers_normalizes_provider_names_with_model_provider_id() -> None: + session = Mock() + openai_provider = SimpleNamespace(provider_name="openai") + gemini_provider = SimpleNamespace(provider_name="langgenius/gemini/google") + session.scalars.return_value = [openai_provider, gemini_provider] + + with ( + patch("core.provider_manager.db", SimpleNamespace(engine=object())), + patch("core.provider_manager.Session", return_value=_build_session_context(session)), + ): + result = ProviderManager._get_all_providers("tenant-id") + + assert list(result[str(ModelProviderID("openai"))]) == [openai_provider] + assert list(result[str(ModelProviderID("langgenius/gemini/google"))]) == [gemini_provider] + + +@pytest.mark.parametrize( + "method_name", + [ + "_get_all_provider_models", + "_get_all_provider_model_settings", + "_get_all_provider_model_credentials", + ], +) +def test_provider_grouping_helpers_group_records_by_provider_name(method_name: str) -> None: + session = Mock() + openai_primary = SimpleNamespace(provider_name="openai") + openai_secondary = SimpleNamespace(provider_name="openai") + anthropic_record = SimpleNamespace(provider_name="anthropic") + session.scalars.return_value = [openai_primary, openai_secondary, anthropic_record] + + with ( + patch("core.provider_manager.db", SimpleNamespace(engine=object())), + patch("core.provider_manager.Session", return_value=_build_session_context(session)), + ): + result = getattr(ProviderManager, method_name)("tenant-id") + + assert list(result["openai"]) == [openai_primary, openai_secondary] + assert list(result["anthropic"]) == [anthropic_record] + + +def test_get_all_preferred_model_providers_returns_mapping_by_provider_name() -> None: + session = Mock() + openai_preference = SimpleNamespace(provider_name="openai") + anthropic_preference = SimpleNamespace(provider_name="anthropic") + session.scalars.return_value = [openai_preference, anthropic_preference] + + with ( + patch("core.provider_manager.db", SimpleNamespace(engine=object())), + patch("core.provider_manager.Session", return_value=_build_session_context(session)), + ): + result = ProviderManager._get_all_preferred_model_providers("tenant-id") + + assert result == { + "openai": openai_preference, + "anthropic": anthropic_preference, + } + + +def test_get_all_provider_load_balancing_configs_returns_empty_when_cached_flag_is_disabled() -> None: + with ( + patch("core.provider_manager.redis_client.get", return_value=b"False"), + patch("core.provider_manager.FeatureService.get_features") as mock_get_features, + patch("core.provider_manager.Session") as mock_session_cls, + ): + result = ProviderManager._get_all_provider_load_balancing_configs("tenant-id") + + assert result == {} + mock_get_features.assert_not_called() + mock_session_cls.assert_not_called() + + +def test_get_all_provider_load_balancing_configs_populates_cache_and_groups_configs() -> None: + session = Mock() + openai_config = SimpleNamespace(provider_name="openai") + anthropic_config = SimpleNamespace(provider_name="anthropic") + session.scalars.return_value = [openai_config, anthropic_config] + + with ( + patch("core.provider_manager.db", SimpleNamespace(engine=object())), + patch("core.provider_manager.redis_client.get", return_value=None), + patch("core.provider_manager.redis_client.setex") as mock_setex, + patch( + "core.provider_manager.FeatureService.get_features", + return_value=SimpleNamespace(model_load_balancing_enabled=True), + ), + patch("core.provider_manager.Session", return_value=_build_session_context(session)), + ): + result = ProviderManager._get_all_provider_load_balancing_configs("tenant-id") + + mock_setex.assert_called_once_with("tenant:tenant-id:model_load_balancing_enabled", 120, "True") + assert list(result["openai"]) == [openai_config] + assert list(result["anthropic"]) == [anthropic_config] diff --git a/api/tests/unit_tests/core/tools/test_builtin_tool_base.py b/api/tests/unit_tests/core/tools/test_builtin_tool_base.py index f123f60a34..1ff81f6120 100644 --- a/api/tests/unit_tests/core/tools/test_builtin_tool_base.py +++ b/api/tests/unit_tests/core/tools/test_builtin_tool_base.py @@ -12,7 +12,7 @@ from core.tools.__base.tool_runtime import ToolRuntime from core.tools.builtin_tool.tool import BuiltinTool from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ToolEntity, ToolIdentity, ToolInvokeMessage, ToolProviderType -from dify_graph.model_runtime.entities.message_entities import UserPromptMessage +from graphon.model_runtime.entities.message_entities import UserPromptMessage class _BuiltinDummyTool(BuiltinTool): @@ -27,12 +27,12 @@ class _BuiltinDummyTool(BuiltinTool): yield self.create_text_message("ok") -def _build_tool() -> _BuiltinDummyTool: +def _build_tool(user_id: str | None = None) -> _BuiltinDummyTool: entity = ToolEntity( identity=ToolIdentity(author="author", name="tool-a", label=I18nObject(en_US="tool-a"), provider="provider-a"), parameters=[], ) - runtime = ToolRuntime(tenant_id="tenant-1", invoke_from=InvokeFrom.DEBUGGER) + runtime = ToolRuntime(tenant_id="tenant-1", user_id=user_id, invoke_from=InvokeFrom.DEBUGGER) return _BuiltinDummyTool(provider="provider-a", entity=entity, runtime=runtime) @@ -45,7 +45,7 @@ def test_builtin_tool_fork_and_provider_type(): def test_invoke_model_calls_model_invocation_utils_invoke(): - tool = _build_tool() + tool = _build_tool(user_id="runtime-user") with patch("core.tools.builtin_tool.tool.ModelInvocationUtils.invoke", return_value="result") as mock_invoke: assert ( tool.invoke_model( @@ -55,19 +55,47 @@ def test_invoke_model_calls_model_invocation_utils_invoke(): ) == "result" ) - mock_invoke.assert_called_once() + mock_invoke.assert_called_once_with( + user_id="u1", + tenant_id="tenant-1", + tool_type=ToolProviderType.BUILT_IN, + tool_name="tool-a", + prompt_messages=[UserPromptMessage(content="hello")], + caller_user_id="runtime-user", + ) def test_get_max_tokens_returns_value(): - tool = _build_tool() - with patch("core.tools.builtin_tool.tool.ModelInvocationUtils.get_max_llm_context_tokens", return_value=4096): + tool = _build_tool(user_id="runtime-user") + with patch( + "core.tools.builtin_tool.tool.ModelInvocationUtils.get_max_llm_context_tokens", return_value=4096 + ) as mock_get: assert tool.get_max_tokens() == 4096 + mock_get.assert_called_once_with(tenant_id="tenant-1", user_id="runtime-user") def test_get_prompt_tokens_returns_value(): - tool = _build_tool() - with patch("core.tools.builtin_tool.tool.ModelInvocationUtils.calculate_tokens", return_value=7): + tool = _build_tool(user_id="runtime-user") + with patch("core.tools.builtin_tool.tool.ModelInvocationUtils.calculate_tokens", return_value=7) as mock_calculate: assert tool.get_prompt_tokens([UserPromptMessage(content="hello")]) == 7 + mock_calculate.assert_called_once_with( + tenant_id="tenant-1", + prompt_messages=[UserPromptMessage(content="hello")], + user_id="runtime-user", + ) + + +def test_get_prompt_tokens_falls_back_to_tenant_scope_when_runtime_user_id_missing(): + tool = _build_tool() + + with patch("core.tools.builtin_tool.tool.ModelInvocationUtils.calculate_tokens", return_value=7) as mock_calculate: + assert tool.get_prompt_tokens([UserPromptMessage(content="hello")]) == 7 + + mock_calculate.assert_called_once_with( + tenant_id="tenant-1", + prompt_messages=[UserPromptMessage(content="hello")], + user_id=None, + ) def test_runtime_none_raises(): diff --git a/api/tests/unit_tests/core/tools/test_builtin_tools_extra.py b/api/tests/unit_tests/core/tools/test_builtin_tools_extra.py index 62cfb6ce5b..9ac280e31a 100644 --- a/api/tests/unit_tests/core/tools/test_builtin_tools_extra.py +++ b/api/tests/unit_tests/core/tools/test_builtin_tools_extra.py @@ -1,6 +1,8 @@ from __future__ import annotations +import calendar import math +from datetime import date from types import SimpleNamespace import pytest @@ -25,8 +27,8 @@ from core.tools.builtin_tool.tool import BuiltinTool from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ToolEntity, ToolIdentity, ToolInvokeMessage from core.tools.errors import ToolInvokeError -from dify_graph.file.enums import FileType -from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey +from graphon.file.enums import FileType +from graphon.model_runtime.entities.model_entities import ModelPropertyKey def _build_builtin_tool(tool_cls: type[BuiltinTool]) -> BuiltinTool: @@ -98,7 +100,13 @@ def test_timezone_conversion_tool(): def test_weekday_tool(): weekday_tool = _build_builtin_tool(WeekdayTool) valid = list(weekday_tool.invoke(user_id="u", tool_parameters={"year": 2024, "month": 1, "day": 1}))[0].message.text - assert "January 1, 2024" in valid + expected_date = date(2024, 1, 1) + expected_message = ( + f"{calendar.month_name[expected_date.month]} " + f"{expected_date.day}, {expected_date.year} " + f"is {calendar.day_name[expected_date.weekday()]}." + ) + assert valid == expected_message invalid = list(weekday_tool.invoke(user_id="u", tool_parameters={"year": 2024, "month": 2, "day": 31}))[ 0 ].message.text @@ -186,13 +194,19 @@ def test_asr_invalid_file(): def test_asr_valid_file_invocation(monkeypatch): asr = _build_builtin_tool(ASRTool) - model_instance = type("M", (), {"invoke_speech2text": lambda self, file, user: "transcript"})() + model_instance = type("M", (), {"invoke_speech2text": lambda self, file: "transcript"})() model_manager = type("Mgr", (), {"get_model_instance": lambda *a, **k: model_instance})() monkeypatch.setattr("core.tools.builtin_tool.providers.audio.tools.asr.download", lambda file: b"audio-bytes") - monkeypatch.setattr("core.tools.builtin_tool.providers.audio.tools.asr.ModelManager", lambda: model_manager) + captured_manager_kwargs = {} + + monkeypatch.setattr( + "core.tools.builtin_tool.providers.audio.tools.asr.ModelManager.for_tenant", + lambda **kwargs: captured_manager_kwargs.update(kwargs) or model_manager, + ) audio_file = SimpleNamespace(type=FileType.AUDIO) ok = list(asr.invoke(user_id="u", tool_parameters={"audio_file": audio_file, "model": "p#m"}))[0].message.text assert ok == "transcript" + assert captured_manager_kwargs == {"tenant_id": "tenant-1", "user_id": "u"} def test_asr_available_models_and_runtime_parameters(monkeypatch): @@ -208,6 +222,7 @@ def test_asr_available_models_and_runtime_parameters(monkeypatch): def test_tts_invoke_returns_messages(monkeypatch): tts = _build_builtin_tool(TTSTool) + captured_manager_kwargs = {} voices_model_instance = type( "TTSM", (), @@ -217,11 +232,15 @@ def test_tts_invoke_returns_messages(monkeypatch): }, )() monkeypatch.setattr( - "core.tools.builtin_tool.providers.audio.tools.tts.ModelManager", - lambda: type("M", (), {"get_model_instance": lambda *a, **k: voices_model_instance})(), + "core.tools.builtin_tool.providers.audio.tools.tts.ModelManager.for_tenant", + lambda **kwargs: ( + captured_manager_kwargs.update(kwargs) + or type("M", (), {"get_model_instance": lambda *a, **k: voices_model_instance})() + ), ) messages = list(tts.invoke(user_id="u", tool_parameters={"model": "p#m", "text": "hello"})) assert [m.type for m in messages] == [ToolInvokeMessage.MessageType.TEXT, ToolInvokeMessage.MessageType.BLOB] + assert captured_manager_kwargs == {"tenant_id": "tenant-1", "user_id": "u"} def test_tts_get_available_models_requires_runtime(): @@ -254,8 +273,8 @@ def test_tts_tool_raises_when_voice_unavailable(monkeypatch, voices): }, )() monkeypatch.setattr( - "core.tools.builtin_tool.providers.audio.tools.tts.ModelManager", - lambda: type("Manager", (), {"get_model_instance": lambda *args, **kwargs: model_without_voice})(), + "core.tools.builtin_tool.providers.audio.tools.tts.ModelManager.for_tenant", + lambda **_: type("Manager", (), {"get_model_instance": lambda *args, **kwargs: model_without_voice})(), ) with pytest.raises(ValueError, match="no voice available"): list(tts.invoke(user_id="u", tool_parameters={"model": "p#m", "text": "hello"})) diff --git a/api/tests/unit_tests/core/tools/test_signature.py b/api/tests/unit_tests/core/tools/test_signature.py index a5242a78c5..353988d7a6 100644 --- a/api/tests/unit_tests/core/tools/test_signature.py +++ b/api/tests/unit_tests/core/tools/test_signature.py @@ -6,7 +6,13 @@ from urllib.parse import parse_qs, urlparse import pytest -from core.tools.signature import sign_tool_file, sign_upload_file, verify_tool_file_signature +from core.tools.signature import ( + get_signed_file_url_for_plugin, + sign_tool_file, + sign_upload_file, + verify_plugin_file_signature, + verify_tool_file_signature, +) def test_sign_tool_file_and_verify_roundtrip(monkeypatch: pytest.MonkeyPatch) -> None: @@ -117,3 +123,82 @@ def test_sign_upload_file_uses_files_url_fallback(monkeypatch: pytest.MonkeyPatc assert query["timestamp"][0] assert query["nonce"][0] assert query["sign"][0] + + +def test_get_signed_file_url_for_plugin_and_verify_roundtrip(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr("core.tools.signature.time.time", lambda: 1700000000) + monkeypatch.setattr("core.tools.signature.os.urandom", lambda _: b"\x06" * 16) + monkeypatch.setattr("core.tools.signature.dify_config.SECRET_KEY", "unit-secret") + monkeypatch.setattr("core.tools.signature.dify_config.FILES_URL", "https://files.example.com") + monkeypatch.setattr("core.tools.signature.dify_config.INTERNAL_FILES_URL", "https://internal.example.com") + monkeypatch.setattr("core.tools.signature.dify_config.FILES_ACCESS_TIMEOUT", 60) + + url = get_signed_file_url_for_plugin( + filename="report.pdf", + mimetype="application/pdf", + tenant_id="tenant-id", + user_id="user-id", + ) + parsed = urlparse(url) + query = parse_qs(parsed.query) + + assert parsed.netloc == "internal.example.com" + assert parsed.path == "/files/upload/for-plugin" + assert query["tenant_id"] == ["tenant-id"] + assert query["user_id"] == ["user-id"] + assert ( + verify_plugin_file_signature( + filename="report.pdf", + mimetype="application/pdf", + tenant_id="tenant-id", + user_id="user-id", + timestamp=query["timestamp"][0], + nonce=query["nonce"][0], + sign=query["sign"][0], + ) + is True + ) + + +def test_verify_plugin_file_signature_rejects_invalid_signatures(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr("core.tools.signature.time.time", lambda: 1700000000) + monkeypatch.setattr("core.tools.signature.os.urandom", lambda _: b"\x07" * 16) + monkeypatch.setattr("core.tools.signature.dify_config.SECRET_KEY", "unit-secret") + monkeypatch.setattr("core.tools.signature.dify_config.FILES_URL", "https://files.example.com") + monkeypatch.setattr("core.tools.signature.dify_config.INTERNAL_FILES_URL", "") + monkeypatch.setattr("core.tools.signature.dify_config.FILES_ACCESS_TIMEOUT", 30) + + url = get_signed_file_url_for_plugin( + filename="report.pdf", + mimetype="application/pdf", + tenant_id="tenant-id", + user_id="user-id", + ) + query = parse_qs(urlparse(url).query) + + assert ( + verify_plugin_file_signature( + filename="report.pdf", + mimetype="application/pdf", + tenant_id="tenant-id", + user_id="user-id", + timestamp=query["timestamp"][0], + nonce=query["nonce"][0], + sign="bad-signature", + ) + is False + ) + + monkeypatch.setattr("core.tools.signature.time.time", lambda: 1700000100) + assert ( + verify_plugin_file_signature( + filename="report.pdf", + mimetype="application/pdf", + tenant_id="tenant-id", + user_id="user-id", + timestamp=query["timestamp"][0], + nonce=query["nonce"][0], + sign=query["sign"][0], + ) + is False + ) diff --git a/api/tests/unit_tests/core/tools/test_tool_file_manager.py b/api/tests/unit_tests/core/tools/test_tool_file_manager.py index cca8254dd6..b3442636b7 100644 --- a/api/tests/unit_tests/core/tools/test_tool_file_manager.py +++ b/api/tests/unit_tests/core/tools/test_tool_file_manager.py @@ -14,6 +14,7 @@ import httpx import pytest from core.tools.tool_file_manager import ToolFileManager +from graphon.file import FileTransferMethod def _setup_tool_file_signing(monkeypatch: pytest.MonkeyPatch) -> dict[str, str]: @@ -232,7 +233,14 @@ def test_get_file_generator_returns_none_when_toolfile_missing() -> None: def test_get_file_generator_returns_stream_when_found() -> None: # Arrange manager = ToolFileManager() - tool_file = SimpleNamespace(file_key="k2", mimetype="image/png") + tool_file = SimpleNamespace( + id="tool123", + file_key="k2", + mimetype="image/png", + original_url=None, + name="image.png", + size=12, + ) session = Mock() session.query.return_value.where.return_value.first.return_value = tool_file @@ -240,10 +248,10 @@ def test_get_file_generator_returns_stream_when_found() -> None: with patch("core.tools.tool_file_manager.storage") as storage: stream = iter([b"a", b"b"]) storage.load_stream.return_value = stream - with ( - _patch_session_factory(session), - patch("core.tools.tool_file_manager.ToolFilePydanticModel.model_validate", return_value="validated-file"), - ): + with _patch_session_factory(session): result_stream, result_file = manager.get_file_generator_by_tool_file_id("tool123") assert list(result_stream) == [b"a", b"b"] - assert result_file == "validated-file" + assert result_file is not None + assert result_file.related_id == "tool123" + assert result_file.mime_type == "image/png" + assert result_file.transfer_method == FileTransferMethod.TOOL_FILE diff --git a/api/tests/unit_tests/core/tools/test_tool_manager.py b/api/tests/unit_tests/core/tools/test_tool_manager.py index 0f73e22654..844bc01e29 100644 --- a/api/tests/unit_tests/core/tools/test_tool_manager.py +++ b/api/tests/unit_tests/core/tools/test_tool_manager.py @@ -15,6 +15,7 @@ from core.plugin.entities.plugin_daemon import CredentialType from core.tools.__base.tool_runtime import ToolRuntime from core.tools.entities.tool_entities import ( ApiProviderAuthType, + ToolInvokeFrom, ToolParameter, ToolProviderType, ) @@ -421,7 +422,7 @@ def test_get_agent_runtime_apply_runtime_parameters(): tool_runtime = SimpleNamespace(runtime=ToolRuntime(tenant_id="tenant-1", runtime_parameters={})) tool_runtime.get_merged_runtime_parameters = Mock(return_value=[parameter]) - with patch.object(ToolManager, "get_tool_runtime", return_value=tool_runtime): + with patch.object(ToolManager, "get_tool_runtime", return_value=tool_runtime) as mock_get_tool_runtime: with patch.object(ToolManager, "_convert_tool_parameters_type", return_value={"query": "hello"}): manager = Mock() manager.decrypt_tool_parameters.return_value = {"query": "decrypted"} @@ -437,12 +438,23 @@ def test_get_agent_runtime_apply_runtime_parameters(): tenant_id="tenant-1", app_id="app-1", agent_tool=agent_tool, + user_id="user-1", invoke_from=InvokeFrom.DEBUGGER, variable_pool=None, ) assert result is tool_runtime assert tool_runtime.runtime.runtime_parameters["query"] == "decrypted" + mock_get_tool_runtime.assert_called_once_with( + provider_type=ToolProviderType.API, + provider_id="api-1", + tool_name="search", + tenant_id="tenant-1", + user_id="user-1", + invoke_from=InvokeFrom.DEBUGGER, + tool_invoke_from=ToolInvokeFrom.AGENT, + credential_id=None, + ) def test_get_workflow_runtime_apply_runtime_parameters(): @@ -463,7 +475,7 @@ def test_get_workflow_runtime_apply_runtime_parameters(): ) tool_runtime2 = SimpleNamespace(runtime=ToolRuntime(tenant_id="tenant-1", runtime_parameters={})) tool_runtime2.get_merged_runtime_parameters = Mock(return_value=[parameter]) - with patch.object(ToolManager, "get_tool_runtime", return_value=tool_runtime2): + with patch.object(ToolManager, "get_tool_runtime", return_value=tool_runtime2) as mock_get_tool_runtime: with patch.object(ToolManager, "_convert_tool_parameters_type", return_value={"query": "workflow"}): manager = Mock() manager.decrypt_tool_parameters.return_value = {"query": "workflow-dec"} @@ -473,12 +485,23 @@ def test_get_workflow_runtime_apply_runtime_parameters(): app_id="app-1", node_id="node-1", workflow_tool=workflow_tool, + user_id="user-1", invoke_from=InvokeFrom.DEBUGGER, variable_pool=None, ) assert workflow_result is tool_runtime2 assert tool_runtime2.runtime.runtime_parameters["query"] == "workflow-dec" + mock_get_tool_runtime.assert_called_once_with( + provider_type=ToolProviderType.API, + provider_id="api-1", + tool_name="search", + tenant_id="tenant-1", + user_id="user-1", + invoke_from=InvokeFrom.DEBUGGER, + tool_invoke_from=ToolInvokeFrom.WORKFLOW, + credential_id=None, + ) def test_get_agent_runtime_raises_when_runtime_missing(): @@ -520,17 +543,28 @@ def test_get_tool_runtime_from_plugin_only_uses_form_parameters(): tool_entity = SimpleNamespace(runtime=ToolRuntime(tenant_id="tenant-1", runtime_parameters={})) tool_entity.get_merged_runtime_parameters = Mock(return_value=[form_param, llm_param]) - with patch.object(ToolManager, "get_tool_runtime", return_value=tool_entity): + with patch.object(ToolManager, "get_tool_runtime", return_value=tool_entity) as mock_get_tool_runtime: result = ToolManager.get_tool_runtime_from_plugin( tool_type=ToolProviderType.API, tenant_id="tenant-1", provider="api-1", tool_name="search", tool_parameters={"q": "hello", "llm": "ignore"}, + user_id="user-1", ) assert result is tool_entity assert tool_entity.runtime.runtime_parameters == {"q": "hello"} + mock_get_tool_runtime.assert_called_once_with( + provider_type=ToolProviderType.API, + provider_id="api-1", + tool_name="search", + tenant_id="tenant-1", + user_id="user-1", + invoke_from=InvokeFrom.SERVICE_API, + tool_invoke_from=ToolInvokeFrom.PLUGIN, + credential_id=None, + ) def test_hardcoded_provider_icon_success(): diff --git a/api/tests/unit_tests/core/tools/utils/test_configuration.py b/api/tests/unit_tests/core/tools/utils/test_configuration.py index 5ceaa08893..ae5638784c 100644 --- a/api/tests/unit_tests/core/tools/utils/test_configuration.py +++ b/api/tests/unit_tests/core/tools/utils/test_configuration.py @@ -110,7 +110,7 @@ def test_encrypt_tool_parameters(): assert encrypted["plain"] == "x" -def test_decrypt_tool_parameters_cache_hit_and_miss(): +def test_decrypt_tool_parameters_cache_hit_and_miss(monkeypatch): manager = _build_manager() with ( @@ -139,7 +139,7 @@ def test_delete_tool_parameters_cache(): mock_delete.assert_called_once() -def test_configuration_manager_decrypt_suppresses_errors(): +def test_configuration_manager_decrypt_suppresses_errors(monkeypatch): manager = _build_manager() with ( patch.object(ToolParameterCache, "get", return_value=None), diff --git a/api/tests/unit_tests/core/tools/utils/test_message_transformer.py b/api/tests/unit_tests/core/tools/utils/test_message_transformer.py index af3cdddd5f..6454a5bcd1 100644 --- a/api/tests/unit_tests/core/tools/utils/test_message_transformer.py +++ b/api/tests/unit_tests/core/tools/utils/test_message_transformer.py @@ -84,3 +84,24 @@ def test_transform_tool_invoke_messages_mimetype_key_present_but_none(): # meta is preserved (still contains mime_type: None) assert "mime_type" in (o.meta or {}) assert o.meta["mime_type"] is None + assert o.meta["tool_file_id"] == "fake-tool-file-id" + + +def test_transform_tool_invoke_messages_parses_existing_tool_file_link_meta(): + msg = ToolInvokeMessage( + type=ToolInvokeMessage.MessageType.IMAGE_LINK, + message=ToolInvokeMessage.TextMessage(text="/files/tools/existing-tool-file.png"), + meta={}, + ) + + out = list( + mt.ToolFileMessageTransformer.transform_tool_invoke_messages( + messages=_gen([msg]), + user_id="u1", + tenant_id="t1", + conversation_id="c1", + ) + ) + + assert len(out) == 1 + assert out[0].meta["tool_file_id"] == "existing-tool-file" diff --git a/api/tests/unit_tests/core/tools/utils/test_model_invocation_utils.py b/api/tests/unit_tests/core/tools/utils/test_model_invocation_utils.py index 2acae889b2..a4a563a4a1 100644 --- a/api/tests/unit_tests/core/tools/utils/test_model_invocation_utils.py +++ b/api/tests/unit_tests/core/tools/utils/test_model_invocation_utils.py @@ -15,8 +15,8 @@ from unittest.mock import Mock, patch import pytest from core.tools.utils.model_invocation_utils import InvokeModelError, ModelInvocationUtils -from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey -from dify_graph.model_runtime.errors.invoke import ( +from graphon.model_runtime.entities.model_entities import ModelPropertyKey +from graphon.model_runtime.errors.invoke import ( InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError, @@ -60,20 +60,23 @@ def test_get_max_llm_context_tokens_branches(model_instance, expected, error_mat manager = Mock() manager.get_default_model_instance.return_value = model_instance - with patch("core.tools.utils.model_invocation_utils.ModelManager", return_value=manager): + with patch("core.tools.utils.model_invocation_utils.ModelManager.for_tenant", return_value=manager) as mock_factory: if error_match: with pytest.raises(InvokeModelError, match=error_match): - ModelInvocationUtils.get_max_llm_context_tokens("tenant") + ModelInvocationUtils.get_max_llm_context_tokens("tenant", user_id="user-1") else: - assert ModelInvocationUtils.get_max_llm_context_tokens("tenant") == expected + assert ModelInvocationUtils.get_max_llm_context_tokens("tenant", user_id="user-1") == expected + + mock_factory.assert_called_once_with(tenant_id="tenant", user_id="user-1") def test_calculate_tokens_handles_missing_model(): manager = Mock() manager.get_default_model_instance.return_value = None - with patch("core.tools.utils.model_invocation_utils.ModelManager", return_value=manager): + with patch("core.tools.utils.model_invocation_utils.ModelManager.for_tenant", return_value=manager) as mock_factory: with pytest.raises(InvokeModelError, match="Model not found"): ModelInvocationUtils.calculate_tokens("tenant", []) + mock_factory.assert_called_once_with(tenant_id="tenant", user_id=None) def test_invoke_success_and_error_mappings(): @@ -98,7 +101,7 @@ def test_invoke_success_and_error_mappings(): db_mock = SimpleNamespace(session=Mock()) - with patch("core.tools.utils.model_invocation_utils.ModelManager", return_value=manager): + with patch("core.tools.utils.model_invocation_utils.ModelManager.for_tenant", return_value=manager) as mock_factory: with patch("core.tools.utils.model_invocation_utils.ToolModelInvoke", _ToolModelInvoke): with patch("core.tools.utils.model_invocation_utils.db", db_mock): response = ModelInvocationUtils.invoke( @@ -107,11 +110,13 @@ def test_invoke_success_and_error_mappings(): tool_type="builtin", tool_name="tool-a", prompt_messages=[], + caller_user_id="caller-1", ) assert response.message.content == "ok" assert db_mock.session.add.call_count == 1 assert db_mock.session.commit.call_count == 2 + mock_factory.assert_called_once_with(tenant_id="tenant", user_id="caller-1") @pytest.mark.parametrize( @@ -145,7 +150,7 @@ def test_invoke_error_mappings(exc, expected): db_mock = SimpleNamespace(session=Mock()) - with patch("core.tools.utils.model_invocation_utils.ModelManager", return_value=manager): + with patch("core.tools.utils.model_invocation_utils.ModelManager.for_tenant", return_value=manager) as mock_factory: with patch("core.tools.utils.model_invocation_utils.ToolModelInvoke", _ToolModelInvoke): with patch("core.tools.utils.model_invocation_utils.db", db_mock): with pytest.raises(InvokeModelError, match=expected): @@ -156,3 +161,4 @@ def test_invoke_error_mappings(exc, expected): tool_name="tool-a", prompt_messages=[], ) + mock_factory.assert_called_once_with(tenant_id="tenant", user_id="u1") diff --git a/api/tests/unit_tests/core/tools/utils/test_workflow_configuration_sync.py b/api/tests/unit_tests/core/tools/utils/test_workflow_configuration_sync.py index dd79b79718..43f3fbd5c9 100644 --- a/api/tests/unit_tests/core/tools/utils/test_workflow_configuration_sync.py +++ b/api/tests/unit_tests/core/tools/utils/test_workflow_configuration_sync.py @@ -3,7 +3,7 @@ import pytest from core.tools.entities.tool_entities import ToolParameter, WorkflowToolParameterConfiguration from core.tools.errors import WorkflowToolHumanInputNotSupportedError from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurationUtils -from dify_graph.variables.input_entities import VariableEntity, VariableEntityType +from graphon.variables.input_entities import VariableEntity, VariableEntityType def test_ensure_no_human_input_nodes_passes_for_non_human_input(): diff --git a/api/tests/unit_tests/core/tools/workflow_as_tool/test_provider.py b/api/tests/unit_tests/core/tools/workflow_as_tool/test_provider.py index dd140cbb27..b147d7fcdb 100644 --- a/api/tests/unit_tests/core/tools/workflow_as_tool/test_provider.py +++ b/api/tests/unit_tests/core/tools/workflow_as_tool/test_provider.py @@ -13,7 +13,7 @@ from core.tools.entities.tool_entities import ( ToolProviderType, ) from core.tools.workflow_as_tool.provider import WorkflowToolProviderController -from dify_graph.variables.input_entities import VariableEntity, VariableEntityType +from graphon.variables.input_entities import VariableEntity, VariableEntityType def _controller() -> WorkflowToolProviderController: diff --git a/api/tests/unit_tests/core/tools/workflow_as_tool/test_tool.py b/api/tests/unit_tests/core/tools/workflow_as_tool/test_tool.py index cc00f79698..72a73dd936 100644 --- a/api/tests/unit_tests/core/tools/workflow_as_tool/test_tool.py +++ b/api/tests/unit_tests/core/tools/workflow_as_tool/test_tool.py @@ -24,7 +24,7 @@ from core.tools.entities.tool_entities import ( ) from core.tools.errors import ToolInvokeError from core.tools.workflow_as_tool.tool import WorkflowTool -from dify_graph.file import FILE_MODEL_IDENTITY +from graphon.file import FILE_MODEL_IDENTITY, FileTransferMethod, FileType class StubScalars: @@ -439,6 +439,32 @@ def _setup_transform_args_tool(monkeypatch: pytest.MonkeyPatch) -> WorkflowTool: def test_transform_args_valid_files(monkeypatch: pytest.MonkeyPatch): """Transform args into parameters and files payloads.""" tool = _setup_transform_args_tool(monkeypatch) + build_file_from_stored_mapping = MagicMock( + side_effect=[ + SimpleNamespace( + transfer_method=FileTransferMethod.TOOL_FILE, + type=FileType.IMAGE, + reference="tool-1", + generate_url=lambda: None, + ), + SimpleNamespace( + transfer_method=FileTransferMethod.LOCAL_FILE, + type=FileType.DOCUMENT, + reference="upload-1", + generate_url=lambda: None, + ), + SimpleNamespace( + transfer_method=FileTransferMethod.REMOTE_URL, + type=FileType.DOCUMENT, + reference=None, + generate_url=lambda: "https://example.com/a.pdf", + ), + ] + ) + monkeypatch.setattr( + "core.tools.workflow_as_tool.tool.build_file_from_stored_mapping", + build_file_from_stored_mapping, + ) params, files = tool._transform_args( { @@ -470,6 +496,8 @@ def test_transform_args_valid_files(monkeypatch: pytest.MonkeyPatch): assert any(file_item.get("tool_file_id") == "tool-1" for file_item in files) assert any(file_item.get("upload_file_id") == "upload-1" for file_item in files) assert any(file_item.get("url") == "https://example.com/a.pdf" for file_item in files) + assert build_file_from_stored_mapping.call_count == 3 + assert all(call.kwargs["tenant_id"] == "test_tool" for call in build_file_from_stored_mapping.call_args_list) def test_transform_args_invalid_files(monkeypatch: pytest.MonkeyPatch): diff --git a/api/tests/unit_tests/core/trigger/debug/test_debug_event_selectors.py b/api/tests/unit_tests/core/trigger/debug/test_debug_event_selectors.py index bcb1d745e3..ee7a3d9c96 100644 --- a/api/tests/unit_tests/core/trigger/debug/test_debug_event_selectors.py +++ b/api/tests/unit_tests/core/trigger/debug/test_debug_event_selectors.py @@ -26,7 +26,7 @@ from core.trigger.debug.event_selectors import ( select_trigger_debug_events, ) from core.trigger.debug.events import PluginTriggerDebugEvent, WebhookDebugEvent -from dify_graph.enums import BuiltinNodeTypes, NodeType +from graphon.enums import BuiltinNodeTypes, NodeType from tests.unit_tests.core.trigger.conftest import VALID_PROVIDER_ID diff --git a/api/tests/unit_tests/core/variables/test_segment.py b/api/tests/unit_tests/core/variables/test_segment.py index 91259c9a45..72052c8c05 100644 --- a/api/tests/unit_tests/core/variables/test_segment.py +++ b/api/tests/unit_tests/core/variables/test_segment.py @@ -5,11 +5,12 @@ import pytest from pydantic import BaseModel from core.helper import encrypter -from dify_graph.file import File, FileTransferMethod, FileType -from dify_graph.runtime import VariablePool -from dify_graph.system_variable import SystemVariable -from dify_graph.variables.segment_group import SegmentGroup -from dify_graph.variables.segments import ( +from core.workflow.system_variables import build_bootstrap_variables, build_system_variables +from core.workflow.variable_pool_initializer import add_variables_to_pool +from graphon.file import File, FileTransferMethod, FileType +from graphon.runtime import VariablePool +from graphon.variables.segment_group import SegmentGroup +from graphon.variables.segments import ( ArrayAnySegment, ArrayFileSegment, ArrayNumberSegment, @@ -25,13 +26,13 @@ from dify_graph.variables.segments import ( StringSegment, get_segment_discriminator, ) -from dify_graph.variables.types import SegmentType -from dify_graph.variables.utils import ( +from graphon.variables.types import SegmentType +from graphon.variables.utils import ( dumps_with_segments, segment_orjson_default, to_selector, ) -from dify_graph.variables.variables import ( +from graphon.variables.variables import ( ArrayAnyVariable, ArrayFileVariable, ArrayNumberVariable, @@ -48,14 +49,28 @@ from dify_graph.variables.variables import ( ) +def _build_variable_pool( + *, + system_variables: list[Variable] | None = None, + environment_variables: list[Variable] | None = None, +) -> VariablePool: + variable_pool = VariablePool() + add_variables_to_pool( + variable_pool, + build_bootstrap_variables( + system_variables=system_variables or [], + environment_variables=environment_variables or [], + ), + ) + return variable_pool + + def test_segment_group_to_text(): - variable_pool = VariablePool( - system_variables=SystemVariable(user_id="fake-user-id"), - user_inputs={}, + variable_pool = _build_variable_pool( + system_variables=build_system_variables(user_id="fake-user-id"), environment_variables=[ SecretVariable(name="secret_key", value="fake-secret-key"), ], - conversation_variables=[], ) variable_pool.add(("node_id", "custom_query"), "fake-user-query") template = ( @@ -71,11 +86,8 @@ def test_segment_group_to_text(): def test_convert_constant_to_segment_group(): - variable_pool = VariablePool( - system_variables=SystemVariable(user_id="1", app_id="1", workflow_id="1"), - user_inputs={}, - environment_variables=[], - conversation_variables=[], + variable_pool = _build_variable_pool( + system_variables=build_system_variables(user_id="1", app_id="1", workflow_id="1"), ) template = "Hello, world!" segments_group = variable_pool.convert_template(template) @@ -84,12 +96,7 @@ def test_convert_constant_to_segment_group(): def test_convert_variable_to_segment_group(): - variable_pool = VariablePool( - system_variables=SystemVariable(user_id="fake-user-id"), - user_inputs={}, - environment_variables=[], - conversation_variables=[], - ) + variable_pool = _build_variable_pool(system_variables=build_system_variables(user_id="fake-user-id")) template = "{{#sys.user_id#}}" segments_group = variable_pool.convert_template(template) assert segments_group.text == "fake-user-id" @@ -116,7 +123,6 @@ def create_test_file( ) -> File: """Factory function to create File objects for testing""" return File( - tenant_id="test-tenant", type=file_type, transfer_method=transfer_method, filename=filename, @@ -190,7 +196,6 @@ class TestSegmentDumpAndLoad: loaded_file = loaded_segment.value assert isinstance(orig_file, File) assert isinstance(loaded_file, File) - assert loaded_file.tenant_id == orig_file.tenant_id assert loaded_file.type == orig_file.type assert loaded_file.filename == orig_file.filename else: @@ -234,7 +239,6 @@ class TestSegmentDumpAndLoad: loaded_file = loaded_variable.value assert isinstance(orig_file, File) assert isinstance(loaded_file, File) - assert loaded_file.tenant_id == orig_file.tenant_id assert loaded_file.type == orig_file.type assert loaded_file.filename == orig_file.filename else: diff --git a/api/tests/unit_tests/core/variables/test_segment_type.py b/api/tests/unit_tests/core/variables/test_segment_type.py index bb234d9bbd..d4e862220a 100644 --- a/api/tests/unit_tests/core/variables/test_segment_type.py +++ b/api/tests/unit_tests/core/variables/test_segment_type.py @@ -1,8 +1,8 @@ import pytest -from dify_graph.variables.segment_group import SegmentGroup -from dify_graph.variables.segments import StringSegment -from dify_graph.variables.types import ArrayValidation, SegmentType +from graphon.variables.segment_group import SegmentGroup +from graphon.variables.segments import StringSegment +from graphon.variables.types import ArrayValidation, SegmentType class TestSegmentTypeIsArrayType: diff --git a/api/tests/unit_tests/core/variables/test_segment_type_validation.py b/api/tests/unit_tests/core/variables/test_segment_type_validation.py index 41ce483447..14f9b2991d 100644 --- a/api/tests/unit_tests/core/variables/test_segment_type_validation.py +++ b/api/tests/unit_tests/core/variables/test_segment_type_validation.py @@ -10,10 +10,10 @@ from typing import Any import pytest -from dify_graph.file.enums import FileTransferMethod, FileType -from dify_graph.file.models import File -from dify_graph.variables.segment_group import SegmentGroup -from dify_graph.variables.segments import ( +from graphon.file.enums import FileTransferMethod, FileType +from graphon.file.models import File +from graphon.variables.segment_group import SegmentGroup +from graphon.variables.segments import ( ArrayFileSegment, BooleanSegment, FileSegment, @@ -22,7 +22,7 @@ from dify_graph.variables.segments import ( ObjectSegment, StringSegment, ) -from dify_graph.variables.types import ArrayValidation, SegmentType +from graphon.variables.types import ArrayValidation, SegmentType def create_test_file( diff --git a/api/tests/unit_tests/core/variables/test_variables.py b/api/tests/unit_tests/core/variables/test_variables.py index dd0fe2e65a..dae5e1ce98 100644 --- a/api/tests/unit_tests/core/variables/test_variables.py +++ b/api/tests/unit_tests/core/variables/test_variables.py @@ -1,7 +1,7 @@ import pytest from pydantic import ValidationError -from dify_graph.variables import ( +from graphon.variables import ( ArrayFileVariable, ArrayVariable, FloatVariable, @@ -11,7 +11,7 @@ from dify_graph.variables import ( SegmentType, StringVariable, ) -from dify_graph.variables.variables import VariableBase +from graphon.variables.variables import VariableBase def test_frozen_variables(): diff --git a/api/tests/unit_tests/core/workflow/context/test_execution_context.py b/api/tests/unit_tests/core/workflow/context/test_execution_context.py index d09b8397c3..3ce4bb753b 100644 --- a/api/tests/unit_tests/core/workflow/context/test_execution_context.py +++ b/api/tests/unit_tests/core/workflow/context/test_execution_context.py @@ -9,7 +9,7 @@ from unittest.mock import MagicMock import pytest from pydantic import BaseModel -from dify_graph.context.execution_context import ( +from context.execution_context import ( AppContext, ExecutionContext, ExecutionContextBuilder, @@ -286,7 +286,7 @@ class TestCaptureCurrentContext: def test_capture_current_context_returns_context(self): """Test that capture_current_context returns a valid context.""" - from dify_graph.context.execution_context import capture_current_context + from context.execution_context import capture_current_context result = capture_current_context() @@ -303,7 +303,7 @@ class TestCaptureCurrentContext: test_var = contextvars.ContextVar("capture_test_var") test_var.set("test_value_123") - from dify_graph.context.execution_context import capture_current_context + from context.execution_context import capture_current_context result = capture_current_context() @@ -313,12 +313,12 @@ class TestCaptureCurrentContext: class TestTenantScopedContextRegistry: def setup_method(self): - from dify_graph.context import reset_context_provider + from context import reset_context_provider reset_context_provider() def teardown_method(self): - from dify_graph.context import reset_context_provider + from context import reset_context_provider reset_context_provider() @@ -333,7 +333,7 @@ class TestTenantScopedContextRegistry: assert read_context("workflow.sandbox", tenant_id="t2").base_url == "http://t2" def test_missing_provider_raises_keyerror(self): - from dify_graph.context import ContextProviderNotFoundError + from context import ContextProviderNotFoundError with pytest.raises(ContextProviderNotFoundError): read_context("missing", tenant_id="unknown") diff --git a/api/tests/unit_tests/core/workflow/entities/test_graph_runtime_state.py b/api/tests/unit_tests/core/workflow/entities/test_graph_runtime_state.py index 22792eb5b3..ef5500b72f 100644 --- a/api/tests/unit_tests/core/workflow/entities/test_graph_runtime_state.py +++ b/api/tests/unit_tests/core/workflow/entities/test_graph_runtime_state.py @@ -4,10 +4,10 @@ from unittest.mock import MagicMock, patch import pytest -from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID -from dify_graph.model_runtime.entities.llm_entities import LLMUsage -from dify_graph.runtime import GraphRuntimeState, ReadOnlyGraphRuntimeStateWrapper, VariablePool -from dify_graph.variables.variables import StringVariable +from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID +from graphon.model_runtime.entities.llm_entities import LLMUsage +from graphon.runtime import GraphRuntimeState, ReadOnlyGraphRuntimeStateWrapper, VariablePool +from graphon.variables.variables import StringVariable class StubCoordinator: @@ -23,6 +23,17 @@ class StubCoordinator: class TestGraphRuntimeState: + def test_execution_context_defaults_to_empty_context(self): + state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time()) + + with state.execution_context: + assert state.execution_context is not None + + state.execution_context = None + + with state.execution_context: + assert state.execution_context is not None + def test_property_getters_and_setters(self): # FIXME(-LAN-): Mock VariablePool if needed variable_pool = VariablePool() @@ -117,7 +128,7 @@ class TestGraphRuntimeState: queue = state.ready_queue - from dify_graph.graph_engine.ready_queue import InMemoryReadyQueue + from graphon.graph_engine.ready_queue import InMemoryReadyQueue assert isinstance(queue, InMemoryReadyQueue) @@ -126,7 +137,7 @@ class TestGraphRuntimeState: execution = state.graph_execution - from dify_graph.graph_engine.domain.graph_execution import GraphExecution + from graphon.graph_engine.domain.graph_execution import GraphExecution assert isinstance(execution, GraphExecution) assert execution.workflow_id == "" @@ -141,7 +152,7 @@ class TestGraphRuntimeState: mock_graph = MagicMock() with patch( - "dify_graph.graph_engine.response_coordinator.ResponseStreamCoordinator", autospec=True + "graphon.graph_engine.response_coordinator.ResponseStreamCoordinator", autospec=True ) as coordinator_cls: coordinator_instance = coordinator_cls.return_value state.configure(graph=mock_graph) diff --git a/api/tests/unit_tests/core/workflow/entities/test_pause_reason.py b/api/tests/unit_tests/core/workflow/entities/test_pause_reason.py index 158f7018b5..856ec959b7 100644 --- a/api/tests/unit_tests/core/workflow/entities/test_pause_reason.py +++ b/api/tests/unit_tests/core/workflow/entities/test_pause_reason.py @@ -5,7 +5,7 @@ Tests for PauseReason discriminated union serialization/deserialization. import pytest from pydantic import BaseModel, ValidationError -from dify_graph.entities.pause_reason import ( +from graphon.entities.pause_reason import ( HumanInputRequired, PauseReason, SchedulingPause, diff --git a/api/tests/unit_tests/core/workflow/entities/test_template.py b/api/tests/unit_tests/core/workflow/entities/test_template.py index 2d4c7f7b77..e8304b9bcd 100644 --- a/api/tests/unit_tests/core/workflow/entities/test_template.py +++ b/api/tests/unit_tests/core/workflow/entities/test_template.py @@ -1,6 +1,6 @@ """Tests for template module.""" -from dify_graph.nodes.base.template import Template, TextSegment, VariableSegment +from graphon.nodes.base.template import Template, TextSegment, VariableSegment class TestTemplate: diff --git a/api/tests/unit_tests/core/workflow/entities/test_variable_pool.py b/api/tests/unit_tests/core/workflow/entities/test_variable_pool.py index 6100ebede5..7e08751683 100644 --- a/api/tests/unit_tests/core/workflow/entities/test_variable_pool.py +++ b/api/tests/unit_tests/core/workflow/entities/test_variable_pool.py @@ -1,5 +1,5 @@ -from dify_graph.runtime import VariablePool -from dify_graph.variables.segments import ( +from graphon.runtime import VariablePool +from graphon.variables.segments import ( BooleanSegment, IntegerSegment, NoneSegment, @@ -126,7 +126,7 @@ class TestVariablePoolGetNotModifyVariableDictionary: def test_get_should_not_modify_variable_dictionary(self): pool = VariablePool.empty() pool.get([self._NODE_ID, self._VAR_NAME]) - assert len(pool.variable_dictionary) == 1 # only contains `sys` node id + assert len(pool.variable_dictionary) == 0 assert "start" not in pool.variable_dictionary pool = VariablePool.empty() diff --git a/api/tests/unit_tests/core/workflow/entities/test_workflow_node_execution.py b/api/tests/unit_tests/core/workflow/entities/test_workflow_node_execution.py index 216e64db8d..5e697f22f3 100644 --- a/api/tests/unit_tests/core/workflow/entities/test_workflow_node_execution.py +++ b/api/tests/unit_tests/core/workflow/entities/test_workflow_node_execution.py @@ -8,8 +8,8 @@ from typing import Any import pytest -from dify_graph.entities.workflow_node_execution import WorkflowNodeExecution -from dify_graph.enums import BuiltinNodeTypes +from graphon.entities.workflow_node_execution import WorkflowNodeExecution +from graphon.enums import BuiltinNodeTypes class TestWorkflowNodeExecutionProcessDataTruncation: diff --git a/api/tests/unit_tests/core/workflow/graph/test_graph.py b/api/tests/unit_tests/core/workflow/graph/test_graph.py index 24bd9ccbed..b138a7dfdc 100644 --- a/api/tests/unit_tests/core/workflow/graph/test_graph.py +++ b/api/tests/unit_tests/core/workflow/graph/test_graph.py @@ -2,10 +2,10 @@ from unittest.mock import Mock -from dify_graph.enums import BuiltinNodeTypes, NodeExecutionType, NodeState -from dify_graph.graph.edge import Edge -from dify_graph.graph.graph import Graph -from dify_graph.nodes.base.node import Node +from graphon.enums import BuiltinNodeTypes, NodeExecutionType, NodeState +from graphon.graph.edge import Edge +from graphon.graph.graph import Graph +from graphon.nodes.base.node import Node def create_mock_node(node_id: str, execution_type: NodeExecutionType, state: NodeState = NodeState.UNKNOWN) -> Node: diff --git a/api/tests/unit_tests/core/workflow/graph/test_graph_builder.py b/api/tests/unit_tests/core/workflow/graph/test_graph_builder.py index 64c2eee776..f3eaa1d686 100644 --- a/api/tests/unit_tests/core/workflow/graph/test_graph_builder.py +++ b/api/tests/unit_tests/core/workflow/graph/test_graph_builder.py @@ -2,9 +2,9 @@ from unittest.mock import MagicMock import pytest -from dify_graph.enums import BuiltinNodeTypes, NodeType -from dify_graph.graph import Graph -from dify_graph.nodes.base.node import Node +from graphon.enums import BuiltinNodeTypes, NodeType +from graphon.graph import Graph +from graphon.nodes.base.node import Node def _make_node(node_id: str, node_type: NodeType = BuiltinNodeTypes.START) -> Node: diff --git a/api/tests/unit_tests/core/workflow/graph/test_graph_skip_validation.py b/api/tests/unit_tests/core/workflow/graph/test_graph_skip_validation.py index 75de07bd8b..3620a20e56 100644 --- a/api/tests/unit_tests/core/workflow/graph/test_graph_skip_validation.py +++ b/api/tests/unit_tests/core/workflow/graph/test_graph_skip_validation.py @@ -5,11 +5,11 @@ from typing import Any import pytest from core.workflow.node_factory import DifyNodeFactory -from dify_graph.graph import Graph -from dify_graph.graph.validation import GraphValidationError -from dify_graph.nodes import BuiltinNodeTypes -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable +from core.workflow.system_variables import default_system_variables +from graphon.graph import Graph +from graphon.graph.validation import GraphValidationError +from graphon.nodes import BuiltinNodeTypes +from graphon.runtime import GraphRuntimeState, VariablePool from tests.workflow_test_utils import build_test_graph_init_params @@ -63,7 +63,7 @@ def _make_factory(graph_config: dict[str, Any]) -> DifyNodeFactory: ) graph_runtime_state = GraphRuntimeState( variable_pool=VariablePool( - system_variables=SystemVariable.default(), + system_variables=default_system_variables(), user_inputs={}, environment_variables=[], ), diff --git a/api/tests/unit_tests/core/workflow/graph/test_graph_validation.py b/api/tests/unit_tests/core/workflow/graph/test_graph_validation.py index e94ad74eb0..bfd0b48392 100644 --- a/api/tests/unit_tests/core/workflow/graph/test_graph_validation.py +++ b/api/tests/unit_tests/core/workflow/graph/test_graph_validation.py @@ -6,14 +6,14 @@ from dataclasses import dataclass import pytest -from dify_graph.entities import GraphInitParams -from dify_graph.entities.base_node_data import BaseNodeData -from dify_graph.enums import BuiltinNodeTypes, ErrorStrategy, NodeExecutionType, NodeType -from dify_graph.graph import Graph -from dify_graph.graph.validation import GraphValidationError -from dify_graph.nodes.base.node import Node -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable +from core.workflow.system_variables import build_system_variables +from graphon.entities import GraphInitParams +from graphon.entities.base_node_data import BaseNodeData +from graphon.enums import BuiltinNodeTypes, ErrorStrategy, NodeExecutionType, NodeType +from graphon.graph import Graph +from graphon.graph.validation import GraphValidationError +from graphon.nodes.base.node import Node +from graphon.runtime import GraphRuntimeState, VariablePool from tests.workflow_test_utils import build_test_graph_init_params @@ -96,7 +96,7 @@ def graph_init_dependencies() -> tuple[_SimpleNodeFactory, dict[str, object]]: invoke_from="service-api", call_depth=0, ) - variable_pool = VariablePool(system_variables=SystemVariable(user_id="user", files=[]), user_inputs={}) + variable_pool = VariablePool(system_variables=build_system_variables(user_id="user", files=[]), user_inputs={}) runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) factory = _SimpleNodeFactory(graph_init_params=init_params, graph_runtime_state=runtime_state) return factory, graph_config diff --git a/api/tests/unit_tests/core/workflow/graph_engine/README.md b/api/tests/unit_tests/core/workflow/graph_engine/README.md index 40ed61eb02..960fef7d43 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/README.md +++ b/api/tests/unit_tests/core/workflow/graph_engine/README.md @@ -68,7 +68,7 @@ print(f"Success rate: {suite_result.success_rate:.1f}%") #### Event Sequence Validation ```python -from dify_graph.graph_events import ( +from graphon.graph_events import ( GraphRunStartedEvent, NodeRunStartedEvent, NodeRunSucceededEvent, @@ -376,39 +376,39 @@ See `test_mock_example.py` for comprehensive examples including: ```bash # Run graph engine tests (includes property-based tests) -uv run pytest api/tests/unit_tests/dify_graph/graph_engine/test_graph_engine.py +uv run pytest api/tests/unit_tests/graphon/graph_engine/test_graph_engine.py # Run with specific test patterns -uv run pytest api/tests/unit_tests/dify_graph/graph_engine/test_graph_engine.py -k "test_echo" +uv run pytest api/tests/unit_tests/graphon/graph_engine/test_graph_engine.py -k "test_echo" # Run with verbose output -uv run pytest api/tests/unit_tests/dify_graph/graph_engine/test_graph_engine.py -v +uv run pytest api/tests/unit_tests/graphon/graph_engine/test_graph_engine.py -v ``` ### Mock System Tests ```bash # Run auto-mock system tests -uv run pytest api/tests/unit_tests/dify_graph/graph_engine/test_auto_mock_system.py +uv run pytest api/tests/unit_tests/graphon/graph_engine/test_auto_mock_system.py # Run examples -uv run python api/tests/unit_tests/dify_graph/graph_engine/test_mock_example.py +uv run python api/tests/unit_tests/graphon/graph_engine/test_mock_example.py # Run simple validation -uv run python api/tests/unit_tests/dify_graph/graph_engine/test_mock_simple.py +uv run python api/tests/unit_tests/graphon/graph_engine/test_mock_simple.py ``` ### All Tests ```bash # Run all graph engine tests -uv run pytest api/tests/unit_tests/dify_graph/graph_engine/ +uv run pytest api/tests/unit_tests/graphon/graph_engine/ # Run with coverage -uv run pytest api/tests/unit_tests/dify_graph/graph_engine/ --cov=dify_graph.graph_engine +uv run pytest api/tests/unit_tests/graphon/graph_engine/ --cov=graphon.graph_engine # Run in parallel -uv run pytest api/tests/unit_tests/dify_graph/graph_engine/ -n auto +uv run pytest api/tests/unit_tests/graphon/graph_engine/ -n auto ``` ## Troubleshooting diff --git a/api/tests/unit_tests/core/workflow/graph_engine/command_channels/test_redis_channel.py b/api/tests/unit_tests/core/workflow/graph_engine/command_channels/test_redis_channel.py index 4dec618e49..795362b158 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/command_channels/test_redis_channel.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/command_channels/test_redis_channel.py @@ -3,15 +3,15 @@ import json from unittest.mock import MagicMock -from dify_graph.graph_engine.command_channels.redis_channel import RedisChannel -from dify_graph.graph_engine.entities.commands import ( +from graphon.graph_engine.command_channels.redis_channel import RedisChannel +from graphon.graph_engine.entities.commands import ( AbortCommand, CommandType, GraphEngineCommand, UpdateVariablesCommand, VariableUpdate, ) -from dify_graph.variables import IntegerVariable, StringVariable +from graphon.variables import IntegerVariable, StringVariable class TestRedisChannel: diff --git a/api/tests/unit_tests/core/workflow/graph_engine/event_management/test_event_handlers.py b/api/tests/unit_tests/core/workflow/graph_engine/event_management/test_event_handlers.py index 6f821ba799..cacbe9ba4e 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/event_management/test_event_handlers.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/event_management/test_event_handlers.py @@ -2,18 +2,18 @@ from __future__ import annotations -from dify_graph.entities.base_node_data import RetryConfig -from dify_graph.enums import BuiltinNodeTypes, NodeExecutionType, NodeState, WorkflowNodeExecutionStatus -from dify_graph.graph import Graph -from dify_graph.graph_engine.domain.graph_execution import GraphExecution -from dify_graph.graph_engine.event_management.event_handlers import EventHandler -from dify_graph.graph_engine.event_management.event_manager import EventManager -from dify_graph.graph_engine.graph_state_manager import GraphStateManager -from dify_graph.graph_engine.ready_queue.in_memory import InMemoryReadyQueue -from dify_graph.graph_engine.response_coordinator.coordinator import ResponseStreamCoordinator -from dify_graph.graph_events import NodeRunRetryEvent, NodeRunStartedEvent -from dify_graph.node_events import NodeRunResult -from dify_graph.runtime import GraphRuntimeState, VariablePool +from graphon.entities.base_node_data import RetryConfig +from graphon.enums import BuiltinNodeTypes, NodeExecutionType, NodeState, WorkflowNodeExecutionStatus +from graphon.graph import Graph +from graphon.graph_engine.domain.graph_execution import GraphExecution +from graphon.graph_engine.event_management.event_handlers import EventHandler +from graphon.graph_engine.event_management.event_manager import EventManager +from graphon.graph_engine.graph_state_manager import GraphStateManager +from graphon.graph_engine.ready_queue.in_memory import InMemoryReadyQueue +from graphon.graph_engine.response_coordinator.coordinator import ResponseStreamCoordinator +from graphon.graph_events import NodeRunRetryEvent, NodeRunStartedEvent +from graphon.node_events import NodeRunResult +from graphon.runtime import GraphRuntimeState, VariablePool from libs.datetime_utils import naive_utc_now diff --git a/api/tests/unit_tests/core/workflow/graph_engine/event_management/test_event_manager.py b/api/tests/unit_tests/core/workflow/graph_engine/event_management/test_event_manager.py index 25494dc647..dc0998caf1 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/event_management/test_event_manager.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/event_management/test_event_manager.py @@ -4,9 +4,9 @@ from __future__ import annotations import logging -from dify_graph.graph_engine.event_management.event_manager import EventManager -from dify_graph.graph_engine.layers.base import GraphEngineLayer -from dify_graph.graph_events import GraphEngineEvent +from graphon.graph_engine.event_management.event_manager import EventManager +from graphon.graph_engine.layers.base import GraphEngineLayer +from graphon.graph_events import GraphEngineEvent class _FaultyLayer(GraphEngineLayer): diff --git a/api/tests/unit_tests/core/workflow/graph_engine/graph_traversal/test_skip_propagator.py b/api/tests/unit_tests/core/workflow/graph_engine/graph_traversal/test_skip_propagator.py index 73d59ea4e9..b030496eb1 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/graph_traversal/test_skip_propagator.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/graph_traversal/test_skip_propagator.py @@ -2,9 +2,9 @@ from unittest.mock import MagicMock, create_autospec -from dify_graph.graph import Edge, Graph -from dify_graph.graph_engine.graph_state_manager import GraphStateManager -from dify_graph.graph_engine.graph_traversal.skip_propagator import SkipPropagator +from graphon.graph import Edge, Graph +from graphon.graph_engine.graph_state_manager import GraphStateManager +from graphon.graph_engine.graph_traversal.skip_propagator import SkipPropagator class TestSkipPropagator: diff --git a/api/tests/unit_tests/core/workflow/graph_engine/human_input_test_utils.py b/api/tests/unit_tests/core/workflow/graph_engine/human_input_test_utils.py index fc8133f5e1..2fead1d719 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/human_input_test_utils.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/human_input_test_utils.py @@ -7,13 +7,13 @@ from dataclasses import dataclass from datetime import datetime, timedelta from typing import Any -from dify_graph.nodes.human_input.enums import HumanInputFormStatus -from dify_graph.repositories.human_input_form_repository import ( +from core.repositories.human_input_repository import ( FormCreateParams, HumanInputFormEntity, HumanInputFormRecipientEntity, HumanInputFormRepository, ) +from graphon.nodes.human_input.enums import HumanInputFormStatus from libs.datetime_utils import naive_utc_now @@ -49,7 +49,7 @@ class _InMemoryFormEntity(HumanInputFormEntity): return self.form_id @property - def web_app_token(self) -> str | None: + def submission_token(self) -> str | None: return self.token @property @@ -88,24 +88,24 @@ class InMemoryHumanInputFormRepository(HumanInputFormRepository): self._form_counter = 0 self.created_params: list[FormCreateParams] = [] self.created_forms: list[_InMemoryFormEntity] = [] - self._forms_by_key: dict[tuple[str, str], _InMemoryFormEntity] = {} + self._forms_by_node_id: dict[str, _InMemoryFormEntity] = {} def create_form(self, params: FormCreateParams) -> HumanInputFormEntity: self.created_params.append(params) self._form_counter += 1 form_id = f"form-{self._form_counter}" - token = f"console-{form_id}" if params.console_recipient_required else f"token-{form_id}" + token = f"token-{form_id}" entity = _InMemoryFormEntity( form_id=form_id, rendered=params.rendered_content, token=token, ) self.created_forms.append(entity) - self._forms_by_key[(params.workflow_execution_id, params.node_id)] = entity + self._forms_by_node_id[params.node_id] = entity return entity - def get_form(self, workflow_execution_id: str, node_id: str) -> HumanInputFormEntity | None: - return self._forms_by_key.get((workflow_execution_id, node_id)) + def get_form(self, node_id: str) -> HumanInputFormEntity | None: + return self._forms_by_node_id.get(node_id) # Convenience helpers for tests ------------------------------------- diff --git a/api/tests/unit_tests/core/workflow/graph_engine/layers/conftest.py b/api/tests/unit_tests/core/workflow/graph_engine/layers/conftest.py index 9e7b3654b7..b642dc82fe 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/layers/conftest.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/layers/conftest.py @@ -10,7 +10,7 @@ from opentelemetry.sdk.trace.export import SimpleSpanProcessor from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter from opentelemetry.trace import set_tracer_provider -from dify_graph.enums import BuiltinNodeTypes +from graphon.enums import BuiltinNodeTypes @pytest.fixture @@ -63,7 +63,7 @@ def mock_llm_node(): def mock_tool_node(): """Create a mock Tool Node with tool-specific attributes.""" from core.tools.entities.tool_entities import ToolProviderType - from dify_graph.nodes.tool.entities import ToolNodeData + from graphon.nodes.tool.entities import ToolNodeData node = MagicMock() node.id = "test-tool-node-id" @@ -117,8 +117,8 @@ def mock_result_event(): """Create a mock result event with NodeRunResult.""" from datetime import datetime - from dify_graph.graph_events.node import NodeRunSucceededEvent - from dify_graph.node_events.base import NodeRunResult + from graphon.graph_events.node import NodeRunSucceededEvent + from graphon.node_events.base import NodeRunResult node_run_result = NodeRunResult( inputs={"query": "test query"}, diff --git a/api/tests/unit_tests/core/workflow/graph_engine/layers/test_layer_initialization.py b/api/tests/unit_tests/core/workflow/graph_engine/layers/test_layer_initialization.py index db32527849..7ff77c19c1 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/layers/test_layer_initialization.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/layers/test_layer_initialization.py @@ -2,13 +2,13 @@ from __future__ import annotations import pytest -from dify_graph.graph_engine import GraphEngine, GraphEngineConfig -from dify_graph.graph_engine.command_channels import InMemoryChannel -from dify_graph.graph_engine.layers.base import ( +from graphon.graph_engine import GraphEngine, GraphEngineConfig +from graphon.graph_engine.command_channels import InMemoryChannel +from graphon.graph_engine.layers.base import ( GraphEngineLayer, GraphEngineLayerNotInitializedError, ) -from dify_graph.graph_events import GraphEngineEvent +from graphon.graph_events import GraphEngineEvent from ..test_table_runner import WorkflowRunner diff --git a/api/tests/unit_tests/core/workflow/graph_engine/layers/test_llm_quota.py b/api/tests/unit_tests/core/workflow/graph_engine/layers/test_llm_quota.py index 2a36f712fd..80874e768a 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/layers/test_llm_quota.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/layers/test_llm_quota.py @@ -1,14 +1,27 @@ import threading from datetime import datetime +from types import SimpleNamespace from unittest.mock import MagicMock, patch +from core.app.entities.app_invoke_entities import DifyRunContext, InvokeFrom, UserFrom from core.app.workflow.layers.llm_quota import LLMQuotaLayer from core.errors.error import QuotaExceededError -from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus -from dify_graph.graph_engine.entities.commands import CommandType -from dify_graph.graph_events.node import NodeRunSucceededEvent -from dify_graph.model_runtime.entities.llm_entities import LLMUsage -from dify_graph.node_events import NodeRunResult +from core.model_manager import ModelInstance +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus +from graphon.graph_engine.entities.commands import CommandType +from graphon.graph_events.node import NodeRunSucceededEvent +from graphon.model_runtime.entities.llm_entities import LLMUsage +from graphon.node_events import NodeRunResult + + +def _build_dify_context() -> DifyRunContext: + return DifyRunContext( + tenant_id="tenant-id", + app_id="app-id", + user_id="user-id", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + ) def _build_succeeded_event() -> NodeRunSucceededEvent: @@ -25,6 +38,11 @@ def _build_succeeded_event() -> NodeRunSucceededEvent: ) +def _build_wrapped_model_instance() -> tuple[SimpleNamespace, ModelInstance]: + raw_model_instance = ModelInstance.__new__(ModelInstance) + return SimpleNamespace(_model_instance=raw_model_instance), raw_model_instance + + def test_deduct_quota_called_for_successful_llm_node() -> None: layer = LLMQuotaLayer() node = MagicMock() @@ -32,8 +50,8 @@ def test_deduct_quota_called_for_successful_llm_node() -> None: node.execution_id = "execution-id" node.node_type = BuiltinNodeTypes.LLM node.tenant_id = "tenant-id" - node.require_dify_context.return_value.tenant_id = "tenant-id" - node.model_instance = object() + node.require_run_context_value.return_value = _build_dify_context() + node.model_instance, raw_model_instance = _build_wrapped_model_instance() result_event = _build_succeeded_event() with patch("core.app.workflow.layers.llm_quota.deduct_llm_quota", autospec=True) as mock_deduct: @@ -41,7 +59,7 @@ def test_deduct_quota_called_for_successful_llm_node() -> None: mock_deduct.assert_called_once_with( tenant_id="tenant-id", - model_instance=node.model_instance, + model_instance=raw_model_instance, usage=result_event.node_run_result.llm_usage, ) @@ -53,8 +71,8 @@ def test_deduct_quota_called_for_question_classifier_node() -> None: node.execution_id = "execution-id" node.node_type = BuiltinNodeTypes.QUESTION_CLASSIFIER node.tenant_id = "tenant-id" - node.require_dify_context.return_value.tenant_id = "tenant-id" - node.model_instance = object() + node.require_run_context_value.return_value = _build_dify_context() + node.model_instance, raw_model_instance = _build_wrapped_model_instance() result_event = _build_succeeded_event() with patch("core.app.workflow.layers.llm_quota.deduct_llm_quota", autospec=True) as mock_deduct: @@ -62,7 +80,7 @@ def test_deduct_quota_called_for_question_classifier_node() -> None: mock_deduct.assert_called_once_with( tenant_id="tenant-id", - model_instance=node.model_instance, + model_instance=raw_model_instance, usage=result_event.node_run_result.llm_usage, ) @@ -74,7 +92,7 @@ def test_non_llm_node_is_ignored() -> None: node.execution_id = "execution-id" node.node_type = BuiltinNodeTypes.START node.tenant_id = "tenant-id" - node.require_dify_context.return_value.tenant_id = "tenant-id" + node.require_run_context_value.return_value = _build_dify_context() node._model_instance = object() result_event = _build_succeeded_event() @@ -91,7 +109,7 @@ def test_quota_error_is_handled_in_layer() -> None: node.execution_id = "execution-id" node.node_type = BuiltinNodeTypes.LLM node.tenant_id = "tenant-id" - node.require_dify_context.return_value.tenant_id = "tenant-id" + node.require_run_context_value.return_value = _build_dify_context() node.model_instance = object() result_event = _build_succeeded_event() @@ -113,8 +131,8 @@ def test_quota_deduction_exceeded_aborts_workflow_immediately() -> None: node.execution_id = "execution-id" node.node_type = BuiltinNodeTypes.LLM node.tenant_id = "tenant-id" - node.require_dify_context.return_value.tenant_id = "tenant-id" - node.model_instance = object() + node.require_run_context_value.return_value = _build_dify_context() + node.model_instance, _ = _build_wrapped_model_instance() node.graph_runtime_state = MagicMock() node.graph_runtime_state.stop_event = stop_event @@ -141,7 +159,7 @@ def test_quota_precheck_failure_aborts_workflow_immediately() -> None: node = MagicMock() node.id = "llm-node-id" node.node_type = BuiltinNodeTypes.LLM - node.model_instance = object() + node.model_instance, _ = _build_wrapped_model_instance() node.graph_runtime_state = MagicMock() node.graph_runtime_state.stop_event = stop_event @@ -167,7 +185,7 @@ def test_quota_precheck_passes_without_abort() -> None: node = MagicMock() node.id = "llm-node-id" node.node_type = BuiltinNodeTypes.LLM - node.model_instance = object() + node.model_instance, raw_model_instance = _build_wrapped_model_instance() node.graph_runtime_state = MagicMock() node.graph_runtime_state.stop_event = stop_event @@ -175,5 +193,5 @@ def test_quota_precheck_passes_without_abort() -> None: layer.on_node_run_start(node) assert not stop_event.is_set() - mock_check.assert_called_once_with(model_instance=node.model_instance) + mock_check.assert_called_once_with(model_instance=raw_model_instance) layer.command_channel.send_command.assert_not_called() diff --git a/api/tests/unit_tests/core/workflow/graph_engine/layers/test_observability.py b/api/tests/unit_tests/core/workflow/graph_engine/layers/test_observability.py index 478a2b592e..14ce55938d 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/layers/test_observability.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/layers/test_observability.py @@ -16,7 +16,7 @@ import pytest from opentelemetry.trace import StatusCode from core.app.workflow.layers.observability import ObservabilityLayer -from dify_graph.enums import BuiltinNodeTypes +from graphon.enums import BuiltinNodeTypes class TestObservabilityLayerInitialization: @@ -144,7 +144,7 @@ class TestObservabilityLayerParserIntegration: self, tracer_provider_with_memory_exporter, memory_span_exporter, mock_llm_node, mock_result_event ): """Test that LLM parser is used for LLM nodes and extracts LLM-specific attributes.""" - from dify_graph.node_events.base import NodeRunResult + from graphon.node_events.base import NodeRunResult mock_result_event.node_run_result = NodeRunResult( inputs={}, @@ -182,7 +182,7 @@ class TestObservabilityLayerParserIntegration: self, tracer_provider_with_memory_exporter, memory_span_exporter, mock_retrieval_node, mock_result_event ): """Test that retrieval parser is used for retrieval nodes and extracts retrieval-specific attributes.""" - from dify_graph.node_events.base import NodeRunResult + from graphon.node_events.base import NodeRunResult mock_result_event.node_run_result = NodeRunResult( inputs={"query": "test query"}, @@ -210,7 +210,7 @@ class TestObservabilityLayerParserIntegration: self, tracer_provider_with_memory_exporter, memory_span_exporter, mock_start_node, mock_result_event ): """Test that result_event parameter allows parsers to extract inputs and outputs.""" - from dify_graph.node_events.base import NodeRunResult + from graphon.node_events.base import NodeRunResult mock_result_event.node_run_result = NodeRunResult( inputs={"input_key": "input_value"}, diff --git a/api/tests/unit_tests/core/workflow/graph_engine/orchestration/test_dispatcher.py b/api/tests/unit_tests/core/workflow/graph_engine/orchestration/test_dispatcher.py index 548c10ce8d..ab3a31f673 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/orchestration/test_dispatcher.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/orchestration/test_dispatcher.py @@ -5,18 +5,18 @@ from __future__ import annotations import queue from unittest import mock -from dify_graph.entities.pause_reason import SchedulingPause -from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus -from dify_graph.graph_engine.event_management.event_handlers import EventHandler -from dify_graph.graph_engine.orchestration.dispatcher import Dispatcher -from dify_graph.graph_engine.orchestration.execution_coordinator import ExecutionCoordinator -from dify_graph.graph_events import ( +from graphon.entities.pause_reason import SchedulingPause +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus +from graphon.graph_engine.event_management.event_handlers import EventHandler +from graphon.graph_engine.orchestration.dispatcher import Dispatcher +from graphon.graph_engine.orchestration.execution_coordinator import ExecutionCoordinator +from graphon.graph_events import ( GraphNodeEventBase, NodeRunPauseRequestedEvent, NodeRunStartedEvent, NodeRunSucceededEvent, ) -from dify_graph.node_events import NodeRunResult +from graphon.node_events import NodeRunResult from libs.datetime_utils import naive_utc_now diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_answer_end_with_text.py b/api/tests/unit_tests/core/workflow/graph_engine/test_answer_end_with_text.py index 7af6b26d87..1510c8e595 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_answer_end_with_text.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_answer_end_with_text.py @@ -1,4 +1,4 @@ -from dify_graph.graph_events import ( +from graphon.graph_events import ( GraphRunStartedEvent, GraphRunSucceededEvent, NodeRunStartedEvent, diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_auto_mock_system.py b/api/tests/unit_tests/core/workflow/graph_engine/test_auto_mock_system.py index fc0d22f739..5d0b37acc5 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_auto_mock_system.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_auto_mock_system.py @@ -7,7 +7,7 @@ for workflows containing nodes that require third-party services. import pytest -from dify_graph.enums import BuiltinNodeTypes +from graphon.enums import BuiltinNodeTypes from tests.workflow_test_utils import build_test_graph_init_params from .test_mock_config import MockConfig, MockConfigBuilder, NodeMockConfig @@ -201,7 +201,7 @@ def test_mock_config_builder(): def test_mock_factory_node_type_detection(): """Test that MockNodeFactory correctly identifies nodes to mock.""" from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom - from dify_graph.runtime import GraphRuntimeState, VariablePool + from graphon.runtime import GraphRuntimeState, VariablePool from .test_mock_factory import MockNodeFactory @@ -308,8 +308,8 @@ def test_workflow_without_auto_mock(): def test_register_custom_mock_node(): """Test registering a custom mock implementation for a node type.""" from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom - from dify_graph.nodes.template_transform import TemplateTransformNode - from dify_graph.runtime import GraphRuntimeState, VariablePool + from graphon.nodes.template_transform import TemplateTransformNode + from graphon.runtime import GraphRuntimeState, VariablePool from .test_mock_factory import MockNodeFactory diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_basic_chatflow.py b/api/tests/unit_tests/core/workflow/graph_engine/test_basic_chatflow.py index 30acbdaf3d..cefe3b8ac8 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_basic_chatflow.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_basic_chatflow.py @@ -1,4 +1,4 @@ -from dify_graph.graph_events import ( +from graphon.graph_events import ( GraphRunStartedEvent, GraphRunSucceededEvent, NodeRunStartedEvent, diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_command_system.py b/api/tests/unit_tests/core/workflow/graph_engine/test_command_system.py index 765c4deba3..01ac2d7a96 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_command_system.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_command_system.py @@ -3,23 +3,23 @@ import time from unittest.mock import MagicMock -from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom -from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY, GraphInitParams -from dify_graph.entities.pause_reason import SchedulingPause -from dify_graph.graph import Graph -from dify_graph.graph_engine import GraphEngine, GraphEngineConfig -from dify_graph.graph_engine.command_channels import InMemoryChannel -from dify_graph.graph_engine.entities.commands import ( +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, InvokeFrom, UserFrom +from graphon.entities.graph_init_params import GraphInitParams +from graphon.entities.pause_reason import SchedulingPause +from graphon.graph import Graph +from graphon.graph_engine import GraphEngine, GraphEngineConfig +from graphon.graph_engine.command_channels import InMemoryChannel +from graphon.graph_engine.entities.commands import ( AbortCommand, CommandType, PauseCommand, UpdateVariablesCommand, VariableUpdate, ) -from dify_graph.graph_events import GraphRunAbortedEvent, GraphRunPausedEvent, GraphRunStartedEvent -from dify_graph.nodes.start.start_node import StartNode -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.variables import IntegerVariable, StringVariable +from graphon.graph_events import GraphRunAbortedEvent, GraphRunPausedEvent, GraphRunStartedEvent +from graphon.nodes.start.start_node import StartNode +from graphon.runtime import GraphRuntimeState, VariablePool +from graphon.variables import IntegerVariable, StringVariable def test_abort_command(): @@ -73,9 +73,8 @@ def test_abort_command(): config=GraphEngineConfig(), ) - # Send abort command before starting - abort_command = AbortCommand(reason="Test abort") - command_channel.send_command(abort_command) + # Queue an abort request before starting. + engine.request_abort("Test abort") # Run engine and collect events events = list(engine.run()) @@ -102,7 +101,7 @@ def test_redis_channel_serialization(): mock_redis.pipeline.return_value.__enter__ = MagicMock(return_value=mock_pipeline) mock_redis.pipeline.return_value.__exit__ = MagicMock(return_value=None) - from dify_graph.graph_engine.command_channels.redis_channel import RedisChannel + from graphon.graph_engine.command_channels.redis_channel import RedisChannel # Create channel with a specific key channel = RedisChannel(mock_redis, channel_key="workflow:123:commands") diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_complex_branch_workflow.py b/api/tests/unit_tests/core/workflow/graph_engine/test_complex_branch_workflow.py index 3a9a0b18bc..ba9c502452 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_complex_branch_workflow.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_complex_branch_workflow.py @@ -7,7 +7,7 @@ This test suite validates the behavior of a workflow that: 3. Handles multiple answer nodes with different outputs """ -from dify_graph.graph_events import ( +from graphon.graph_events import ( GraphRunStartedEvent, GraphRunSucceededEvent, NodeRunStartedEvent, diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_conditional_streaming_vs_template_workflow.py b/api/tests/unit_tests/core/workflow/graph_engine/test_conditional_streaming_vs_template_workflow.py index 76bf179f33..3851480731 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_conditional_streaming_vs_template_workflow.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_conditional_streaming_vs_template_workflow.py @@ -6,10 +6,10 @@ This test validates that: - When blocking != 1: NodeRunStreamChunkEvent present (direct LLM to End output) """ -from dify_graph.enums import BuiltinNodeTypes -from dify_graph.graph_engine import GraphEngine, GraphEngineConfig -from dify_graph.graph_engine.command_channels import InMemoryChannel -from dify_graph.graph_events import ( +from graphon.enums import BuiltinNodeTypes +from graphon.graph_engine import GraphEngine, GraphEngineConfig +from graphon.graph_engine.command_channels import InMemoryChannel +from graphon.graph_events import ( GraphRunSucceededEvent, NodeRunStartedEvent, NodeRunStreamChunkEvent, diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_dispatcher_pause_drain.py b/api/tests/unit_tests/core/workflow/graph_engine/test_dispatcher_pause_drain.py index 778dad5952..3ee34e86c6 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_dispatcher_pause_drain.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_dispatcher_pause_drain.py @@ -1,10 +1,10 @@ import queue from datetime import datetime -from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus -from dify_graph.graph_engine.orchestration.dispatcher import Dispatcher -from dify_graph.graph_events import NodeRunSucceededEvent -from dify_graph.node_events import NodeRunResult +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus +from graphon.graph_engine.orchestration.dispatcher import Dispatcher +from graphon.graph_events import NodeRunSucceededEvent +from graphon.node_events import NodeRunResult class StubExecutionCoordinator: diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_end_node_without_value_type.py b/api/tests/unit_tests/core/workflow/graph_engine/test_end_node_without_value_type.py index c87dc75b95..ada55f3dc5 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_end_node_without_value_type.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_end_node_without_value_type.py @@ -6,7 +6,7 @@ field is missing from the output configuration, ensuring backward compatibility with older workflow definitions. """ -from dify_graph.graph_events import ( +from graphon.graph_events import ( GraphRunStartedEvent, GraphRunSucceededEvent, NodeRunStartedEvent, diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_execution_coordinator.py b/api/tests/unit_tests/core/workflow/graph_engine/test_execution_coordinator.py index 35406997ed..95a94110d2 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_execution_coordinator.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_execution_coordinator.py @@ -4,11 +4,11 @@ from unittest.mock import MagicMock import pytest -from dify_graph.graph_engine.command_processing.command_processor import CommandProcessor -from dify_graph.graph_engine.domain.graph_execution import GraphExecution -from dify_graph.graph_engine.graph_state_manager import GraphStateManager -from dify_graph.graph_engine.orchestration.execution_coordinator import ExecutionCoordinator -from dify_graph.graph_engine.worker_management.worker_pool import WorkerPool +from graphon.graph_engine.command_processing.command_processor import CommandProcessor +from graphon.graph_engine.domain.graph_execution import GraphExecution +from graphon.graph_engine.graph_state_manager import GraphStateManager +from graphon.graph_engine.orchestration.execution_coordinator import ExecutionCoordinator +from graphon.graph_engine.worker_management.worker_pool import WorkerPool def _build_coordinator(graph_execution: GraphExecution) -> tuple[ExecutionCoordinator, MagicMock, MagicMock]: diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py index 4e13177d2b..51ece26d49 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py @@ -10,11 +10,11 @@ import time from hypothesis import HealthCheck, given, settings from hypothesis import strategies as st -from dify_graph.entities.base_node_data import DefaultValue, DefaultValueType -from dify_graph.enums import ErrorStrategy -from dify_graph.graph_engine import GraphEngine, GraphEngineConfig -from dify_graph.graph_engine.command_channels import InMemoryChannel -from dify_graph.graph_events import ( +from graphon.entities.base_node_data import DefaultValue, DefaultValueType +from graphon.enums import ErrorStrategy +from graphon.graph_engine import GraphEngine, GraphEngineConfig +from graphon.graph_engine.command_channels import InMemoryChannel +from graphon.graph_events import ( GraphRunPartialSucceededEvent, GraphRunStartedEvent, GraphRunSucceededEvent, @@ -455,7 +455,7 @@ def test_if_else_workflow_property_diverse_inputs(query_input): # Tests for the Layer system def test_layer_system_basic(): """Test basic layer functionality with DebugLoggingLayer.""" - from dify_graph.graph_engine.layers import DebugLoggingLayer + from graphon.graph_engine.layers import DebugLoggingLayer runner = WorkflowRunner() @@ -495,7 +495,7 @@ def test_layer_system_basic(): def test_layer_chaining(): """Test chaining multiple layers.""" - from dify_graph.graph_engine.layers import DebugLoggingLayer, GraphEngineLayer + from graphon.graph_engine.layers import DebugLoggingLayer, GraphEngineLayer # Create a custom test layer class TestLayer(GraphEngineLayer): @@ -549,7 +549,7 @@ def test_layer_chaining(): def test_layer_error_handling(): """Test that layer errors don't crash the engine.""" - from dify_graph.graph_engine.layers import GraphEngineLayer + from graphon.graph_engine.layers import GraphEngineLayer # Create a layer that throws errors class FaultyLayer(GraphEngineLayer): @@ -591,7 +591,7 @@ def test_layer_error_handling(): def test_event_sequence_validation(): """Test the new event sequence validation feature.""" - from dify_graph.graph_events import NodeRunStartedEvent, NodeRunStreamChunkEvent, NodeRunSucceededEvent + from graphon.graph_events import NodeRunStartedEvent, NodeRunStreamChunkEvent, NodeRunSucceededEvent runner = TableTestRunner() @@ -678,7 +678,7 @@ def test_event_sequence_validation(): def test_event_sequence_validation_with_table_tests(): """Test event sequence validation with table-driven tests.""" - from dify_graph.graph_events import NodeRunStartedEvent, NodeRunStreamChunkEvent, NodeRunSucceededEvent + from graphon.graph_events import NodeRunStartedEvent, NodeRunStreamChunkEvent, NodeRunSucceededEvent runner = TableTestRunner() diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_execution_serialization.py b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_execution_serialization.py index 255784b77d..348ceb6788 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_execution_serialization.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_execution_serialization.py @@ -6,13 +6,13 @@ import json from collections import deque from unittest.mock import MagicMock -from dify_graph.enums import BuiltinNodeTypes, NodeExecutionType, NodeState -from dify_graph.graph_engine.domain import GraphExecution -from dify_graph.graph_engine.response_coordinator import ResponseStreamCoordinator -from dify_graph.graph_engine.response_coordinator.path import Path -from dify_graph.graph_engine.response_coordinator.session import ResponseSession -from dify_graph.graph_events import NodeRunStreamChunkEvent -from dify_graph.nodes.base.template import Template, TextSegment, VariableSegment +from graphon.enums import BuiltinNodeTypes, NodeExecutionType, NodeState +from graphon.graph_engine.domain import GraphExecution +from graphon.graph_engine.response_coordinator import ResponseStreamCoordinator +from graphon.graph_engine.response_coordinator.path import Path +from graphon.graph_engine.response_coordinator.session import ResponseSession +from graphon.graph_events import NodeRunStreamChunkEvent +from graphon.nodes.base.template import Template, TextSegment, VariableSegment class CustomGraphExecutionError(Exception): diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_state_snapshot.py b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_state_snapshot.py index d54f0be190..a6417822d2 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_state_snapshot.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_state_snapshot.py @@ -1,26 +1,26 @@ import time from collections.abc import Mapping -from dify_graph.entities import GraphInitParams -from dify_graph.enums import NodeState -from dify_graph.graph import Graph -from dify_graph.graph_engine.graph_state_manager import GraphStateManager -from dify_graph.graph_engine.ready_queue import InMemoryReadyQueue -from dify_graph.model_runtime.entities.llm_entities import LLMMode -from dify_graph.model_runtime.entities.message_entities import PromptMessageRole -from dify_graph.nodes.end.end_node import EndNode -from dify_graph.nodes.end.entities import EndNodeData -from dify_graph.nodes.llm.entities import ( +from core.workflow.system_variables import build_system_variables +from graphon.entities import GraphInitParams +from graphon.enums import NodeState +from graphon.graph import Graph +from graphon.graph_engine.graph_state_manager import GraphStateManager +from graphon.graph_engine.ready_queue import InMemoryReadyQueue +from graphon.model_runtime.entities.llm_entities import LLMMode +from graphon.model_runtime.entities.message_entities import PromptMessageRole +from graphon.nodes.end.end_node import EndNode +from graphon.nodes.end.entities import EndNodeData +from graphon.nodes.llm.entities import ( ContextConfig, LLMNodeChatModelMessage, LLMNodeData, ModelConfig, VisionConfig, ) -from dify_graph.nodes.start.entities import StartNodeData -from dify_graph.nodes.start.start_node import StartNode -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable +from graphon.nodes.start.entities import StartNodeData +from graphon.nodes.start.start_node import StartNode +from graphon.runtime import GraphRuntimeState, VariablePool from tests.workflow_test_utils import build_test_graph_init_params from .test_mock_config import MockConfig @@ -29,7 +29,7 @@ from .test_mock_nodes import MockLLMNode def _build_runtime_state() -> GraphRuntimeState: variable_pool = VariablePool( - system_variables=SystemVariable( + system_variables=build_system_variables( user_id="user", app_id="app", workflow_id="workflow", diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_multi_branch.py b/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_multi_branch.py index 538f53c603..ca9a929591 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_multi_branch.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_multi_branch.py @@ -4,8 +4,11 @@ from collections.abc import Iterable from unittest import mock from unittest.mock import MagicMock -from dify_graph.graph import Graph -from dify_graph.graph_events import ( +from core.repositories.human_input_repository import HumanInputFormEntity, HumanInputFormRepository +from core.workflow.node_runtime import DifyHumanInputNodeRuntime +from core.workflow.system_variables import build_system_variables +from graphon.graph import Graph +from graphon.graph_events import ( GraphRunPausedEvent, GraphRunStartedEvent, GraphRunSucceededEvent, @@ -14,25 +17,23 @@ from dify_graph.graph_events import ( NodeRunStreamChunkEvent, NodeRunSucceededEvent, ) -from dify_graph.graph_events.node import NodeRunHumanInputFormFilledEvent -from dify_graph.model_runtime.entities.message_entities import PromptMessageRole -from dify_graph.nodes.base.entities import OutputVariableEntity, OutputVariableType -from dify_graph.nodes.end.end_node import EndNode -from dify_graph.nodes.end.entities import EndNodeData -from dify_graph.nodes.human_input.entities import HumanInputNodeData, UserAction -from dify_graph.nodes.human_input.human_input_node import HumanInputNode -from dify_graph.nodes.llm.entities import ( +from graphon.graph_events.node import NodeRunHumanInputFormFilledEvent +from graphon.model_runtime.entities.message_entities import PromptMessageRole +from graphon.nodes.base.entities import OutputVariableEntity, OutputVariableType +from graphon.nodes.end.end_node import EndNode +from graphon.nodes.end.entities import EndNodeData +from graphon.nodes.human_input.entities import HumanInputNodeData, UserAction +from graphon.nodes.human_input.human_input_node import HumanInputNode +from graphon.nodes.llm.entities import ( ContextConfig, LLMNodeChatModelMessage, LLMNodeData, ModelConfig, VisionConfig, ) -from dify_graph.nodes.start.entities import StartNodeData -from dify_graph.nodes.start.start_node import StartNode -from dify_graph.repositories.human_input_form_repository import HumanInputFormEntity, HumanInputFormRepository -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable +from graphon.nodes.start.entities import StartNodeData +from graphon.nodes.start.start_node import StartNode +from graphon.runtime import GraphRuntimeState, VariablePool from libs.datetime_utils import naive_utc_now from tests.workflow_test_utils import build_test_graph_init_params @@ -60,7 +61,7 @@ def _build_branching_graph( if graph_runtime_state is None: variable_pool = VariablePool( - system_variables=SystemVariable( + system_variables=build_system_variables( user_id="user", app_id="app", workflow_id="workflow", @@ -125,6 +126,7 @@ def _build_branching_graph( graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, form_repository=form_repository, + runtime=DifyHumanInputNodeRuntime(graph_init_params.run_context), ) llm_primary = _create_llm_node("llm_primary", "Primary LLM", "Primary stream output") @@ -246,7 +248,7 @@ def test_human_input_llm_streaming_across_multiple_branches() -> None: mock_create_repo.get_form.return_value = None mock_form_entity = MagicMock(spec=HumanInputFormEntity) mock_form_entity.id = "test_form_id" - mock_form_entity.web_app_token = "test_web_app_token" + mock_form_entity.submission_token = "test_web_app_token" mock_form_entity.recipients = [] mock_form_entity.rendered_content = "rendered" mock_form_entity.submitted = False @@ -302,7 +304,7 @@ def test_human_input_llm_streaming_across_multiple_branches() -> None: mock_get_repo = MagicMock(spec=HumanInputFormRepository) submitted_form = MagicMock(spec=HumanInputFormEntity) submitted_form.id = mock_form_entity.id - submitted_form.web_app_token = mock_form_entity.web_app_token + submitted_form.submission_token = mock_form_entity.submission_token submitted_form.recipients = [] submitted_form.rendered_content = mock_form_entity.rendered_content submitted_form.submitted = True diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_single_branch.py b/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_single_branch.py index 36bba6deb6..c50aaafe2c 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_single_branch.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_single_branch.py @@ -3,8 +3,11 @@ import time from unittest import mock from unittest.mock import MagicMock -from dify_graph.graph import Graph -from dify_graph.graph_events import ( +from core.repositories.human_input_repository import HumanInputFormEntity, HumanInputFormRepository +from core.workflow.node_runtime import DifyHumanInputNodeRuntime +from core.workflow.system_variables import build_system_variables +from graphon.graph import Graph +from graphon.graph_events import ( GraphRunPausedEvent, GraphRunStartedEvent, GraphRunSucceededEvent, @@ -13,25 +16,23 @@ from dify_graph.graph_events import ( NodeRunStreamChunkEvent, NodeRunSucceededEvent, ) -from dify_graph.graph_events.node import NodeRunHumanInputFormFilledEvent -from dify_graph.model_runtime.entities.message_entities import PromptMessageRole -from dify_graph.nodes.base.entities import OutputVariableEntity, OutputVariableType -from dify_graph.nodes.end.end_node import EndNode -from dify_graph.nodes.end.entities import EndNodeData -from dify_graph.nodes.human_input.entities import HumanInputNodeData, UserAction -from dify_graph.nodes.human_input.human_input_node import HumanInputNode -from dify_graph.nodes.llm.entities import ( +from graphon.graph_events.node import NodeRunHumanInputFormFilledEvent +from graphon.model_runtime.entities.message_entities import PromptMessageRole +from graphon.nodes.base.entities import OutputVariableEntity, OutputVariableType +from graphon.nodes.end.end_node import EndNode +from graphon.nodes.end.entities import EndNodeData +from graphon.nodes.human_input.entities import HumanInputNodeData, UserAction +from graphon.nodes.human_input.human_input_node import HumanInputNode +from graphon.nodes.llm.entities import ( ContextConfig, LLMNodeChatModelMessage, LLMNodeData, ModelConfig, VisionConfig, ) -from dify_graph.nodes.start.entities import StartNodeData -from dify_graph.nodes.start.start_node import StartNode -from dify_graph.repositories.human_input_form_repository import HumanInputFormEntity, HumanInputFormRepository -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable +from graphon.nodes.start.entities import StartNodeData +from graphon.nodes.start.start_node import StartNode +from graphon.runtime import GraphRuntimeState, VariablePool from libs.datetime_utils import naive_utc_now from tests.workflow_test_utils import build_test_graph_init_params @@ -59,7 +60,7 @@ def _build_llm_human_llm_graph( if graph_runtime_state is None: variable_pool = VariablePool( - system_variables=SystemVariable( + system_variables=build_system_variables( user_id="user", app_id="app", workflow_id="workflow", workflow_execution_id="test-execution-id," ), user_inputs={}, @@ -121,6 +122,7 @@ def _build_llm_human_llm_graph( graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, form_repository=form_repository, + runtime=DifyHumanInputNodeRuntime(graph_init_params.run_context), ) llm_second = _create_llm_node("llm_resume", "Follow-up LLM", "Follow-up prompt") @@ -191,7 +193,7 @@ def test_human_input_llm_streaming_order_across_pause() -> None: mock_create_repo.get_form.return_value = None mock_form_entity = MagicMock(spec=HumanInputFormEntity) mock_form_entity.id = "test_form_id" - mock_form_entity.web_app_token = "test_web_app_token" + mock_form_entity.submission_token = "test_web_app_token" mock_form_entity.recipients = [] mock_form_entity.rendered_content = "rendered" mock_form_entity.submitted = False @@ -260,7 +262,7 @@ def test_human_input_llm_streaming_order_across_pause() -> None: mock_get_repo = MagicMock(spec=HumanInputFormRepository) submitted_form = MagicMock(spec=HumanInputFormEntity) submitted_form.id = mock_form_entity.id - submitted_form.web_app_token = mock_form_entity.web_app_token + submitted_form.submission_token = mock_form_entity.submission_token submitted_form.recipients = [] submitted_form.rendered_content = mock_form_entity.rendered_content submitted_form.submitted = True diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_if_else_streaming.py b/api/tests/unit_tests/core/workflow/graph_engine/test_if_else_streaming.py index 8da179c15e..246df45d5f 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_if_else_streaming.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_if_else_streaming.py @@ -1,33 +1,33 @@ import time from unittest import mock -from dify_graph.graph import Graph -from dify_graph.graph_events import ( +from core.workflow.system_variables import build_system_variables +from graphon.graph import Graph +from graphon.graph_events import ( GraphRunStartedEvent, GraphRunSucceededEvent, NodeRunStartedEvent, NodeRunStreamChunkEvent, NodeRunSucceededEvent, ) -from dify_graph.model_runtime.entities.llm_entities import LLMMode -from dify_graph.model_runtime.entities.message_entities import PromptMessageRole -from dify_graph.nodes.base.entities import OutputVariableEntity, OutputVariableType -from dify_graph.nodes.end.end_node import EndNode -from dify_graph.nodes.end.entities import EndNodeData -from dify_graph.nodes.if_else.entities import IfElseNodeData -from dify_graph.nodes.if_else.if_else_node import IfElseNode -from dify_graph.nodes.llm.entities import ( +from graphon.model_runtime.entities.llm_entities import LLMMode +from graphon.model_runtime.entities.message_entities import PromptMessageRole +from graphon.nodes.base.entities import OutputVariableEntity, OutputVariableType +from graphon.nodes.end.end_node import EndNode +from graphon.nodes.end.entities import EndNodeData +from graphon.nodes.if_else.entities import IfElseNodeData +from graphon.nodes.if_else.if_else_node import IfElseNode +from graphon.nodes.llm.entities import ( ContextConfig, LLMNodeChatModelMessage, LLMNodeData, ModelConfig, VisionConfig, ) -from dify_graph.nodes.start.entities import StartNodeData -from dify_graph.nodes.start.start_node import StartNode -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable -from dify_graph.utils.condition.entities import Condition +from graphon.nodes.start.entities import StartNodeData +from graphon.nodes.start.start_node import StartNode +from graphon.runtime import GraphRuntimeState, VariablePool +from graphon.utils.condition.entities import Condition from tests.workflow_test_utils import build_test_graph_init_params from .test_mock_config import MockConfig @@ -44,7 +44,7 @@ def _build_if_else_graph(branch_value: str, mock_config: MockConfig) -> tuple[Gr ) variable_pool = VariablePool( - system_variables=SystemVariable(user_id="user", app_id="app", workflow_id="workflow"), + system_variables=build_system_variables(user_id="user", app_id="app", workflow_id="workflow"), user_inputs={}, conversation_variables=[], ) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_loop_contains_answer.py b/api/tests/unit_tests/core/workflow/graph_engine/test_loop_contains_answer.py index 733fd53bc8..821da46b76 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_loop_contains_answer.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_loop_contains_answer.py @@ -5,7 +5,7 @@ This test validates the behavior of a loop containing an answer node inside the loop that may produce output errors. """ -from dify_graph.graph_events import ( +from graphon.graph_events import ( GraphRunStartedEvent, GraphRunSucceededEvent, NodeRunLoopNextEvent, @@ -14,6 +14,7 @@ from dify_graph.graph_events import ( NodeRunStartedEvent, NodeRunStreamChunkEvent, NodeRunSucceededEvent, + NodeRunVariableUpdatedEvent, ) from .test_mock_config import MockConfigBuilder @@ -50,6 +51,7 @@ def test_loop_contains_answer(): NodeRunLoopStartedEvent, # Variable assigner NodeRunStartedEvent, + NodeRunVariableUpdatedEvent, NodeRunStreamChunkEvent, # 1 NodeRunStreamChunkEvent, # \n NodeRunSucceededEvent, @@ -60,6 +62,7 @@ def test_loop_contains_answer(): NodeRunLoopNextEvent, # Variable assigner NodeRunStartedEvent, + NodeRunVariableUpdatedEvent, NodeRunStreamChunkEvent, # 2 NodeRunStreamChunkEvent, # \n NodeRunSucceededEvent, diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_loop_with_tool.py b/api/tests/unit_tests/core/workflow/graph_engine/test_loop_with_tool.py index 6ff2722f78..4a60c7769c 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_loop_with_tool.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_loop_with_tool.py @@ -1,4 +1,4 @@ -from dify_graph.graph_events import ( +from graphon.graph_events import ( GraphRunStartedEvent, GraphRunSucceededEvent, NodeRunLoopNextEvent, @@ -7,6 +7,7 @@ from dify_graph.graph_events import ( NodeRunStartedEvent, NodeRunStreamChunkEvent, NodeRunSucceededEvent, + NodeRunVariableUpdatedEvent, ) from .test_mock_config import MockConfigBuilder @@ -44,12 +45,16 @@ def test_loop_with_tool(): NodeRunStartedEvent, NodeRunSucceededEvent, NodeRunStartedEvent, + NodeRunVariableUpdatedEvent, + NodeRunVariableUpdatedEvent, NodeRunSucceededEvent, NodeRunLoopNextEvent, # 2024 NodeRunStartedEvent, NodeRunSucceededEvent, NodeRunStartedEvent, + NodeRunVariableUpdatedEvent, + NodeRunVariableUpdatedEvent, NodeRunSucceededEvent, # LOOP END NodeRunLoopSucceededEvent, diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py index 93010eea54..76b2984a4b 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py @@ -8,9 +8,9 @@ requiring external services (LLM, Agent, Tool, Knowledge Retrieval, HTTP Request from typing import TYPE_CHECKING, Any from core.workflow.node_factory import DifyNodeFactory -from dify_graph.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter -from dify_graph.enums import BuiltinNodeTypes, NodeType -from dify_graph.nodes.base.node import Node +from graphon.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter +from graphon.enums import BuiltinNodeTypes, NodeType +from graphon.nodes.base.node import Node from .test_mock_nodes import ( MockAgentNode, @@ -28,8 +28,8 @@ from .test_mock_nodes import ( ) if TYPE_CHECKING: - from dify_graph.entities import GraphInitParams - from dify_graph.runtime import GraphRuntimeState + from graphon.entities import GraphInitParams + from graphon.runtime import GraphRuntimeState from .test_mock_config import MockConfig @@ -111,7 +111,7 @@ class MockNodeFactory(DifyNodeFactory): mock_config=self.mock_config, http_request_config=self._http_request_config, http_client=self._http_request_http_client, - tool_file_manager_factory=self._http_request_tool_file_manager_factory, + tool_file_manager_factory=self._bound_tool_file_manager_factory, file_manager=self._http_request_file_manager, ) elif node_type in { diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_iteration_simple.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_iteration_simple.py index 3e4247f33f..aff479104f 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_iteration_simple.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_iteration_simple.py @@ -2,8 +2,8 @@ Simple test to verify MockNodeFactory works with iteration nodes. """ -from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY -from dify_graph.enums import BuiltinNodeTypes +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY +from graphon.enums import BuiltinNodeTypes from tests.unit_tests.core.workflow.graph_engine.test_mock_config import MockConfigBuilder from tests.unit_tests.core.workflow.graph_engine.test_mock_factory import MockNodeFactory @@ -11,8 +11,8 @@ from tests.unit_tests.core.workflow.graph_engine.test_mock_factory import MockNo def test_mock_factory_registers_iteration_node(): """Test that MockNodeFactory has iteration node registered.""" from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom - from dify_graph.entities import GraphInitParams - from dify_graph.runtime import GraphRuntimeState, VariablePool + from graphon.entities import GraphInitParams + from graphon.runtime import GraphRuntimeState, VariablePool # Create a MockNodeFactory instance graph_init_params = GraphInitParams( @@ -63,8 +63,8 @@ def test_mock_iteration_node_preserves_config(): """Test that MockIterationNode preserves mock configuration.""" from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom - from dify_graph.entities import GraphInitParams - from dify_graph.runtime import GraphRuntimeState, VariablePool + from graphon.entities import GraphInitParams + from graphon.runtime import GraphRuntimeState, VariablePool from tests.unit_tests.core.workflow.graph_engine.test_mock_nodes import MockIterationNode # Create mock config @@ -128,8 +128,8 @@ def test_mock_loop_node_preserves_config(): """Test that MockLoopNode preserves mock configuration.""" from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom - from dify_graph.entities import GraphInitParams - from dify_graph.runtime import GraphRuntimeState, VariablePool + from graphon.entities import GraphInitParams + from graphon.runtime import GraphRuntimeState, VariablePool from tests.unit_tests.core.workflow.graph_engine.test_mock_nodes import MockLoopNode # Create mock config diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py index 454263bef9..971b9b2bbf 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py @@ -11,29 +11,29 @@ from typing import TYPE_CHECKING, Any, Optional from unittest.mock import MagicMock from core.model_manager import ModelInstance +from core.workflow.node_runtime import DifyToolNodeRuntime from core.workflow.nodes.agent import AgentNode from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode -from dify_graph.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus -from dify_graph.model_runtime.entities.llm_entities import LLMUsage -from dify_graph.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent -from dify_graph.nodes.code import CodeNode -from dify_graph.nodes.document_extractor import DocumentExtractorNode -from dify_graph.nodes.http_request import HttpRequestNode -from dify_graph.nodes.llm import LLMNode -from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory, TemplateRenderer -from dify_graph.nodes.parameter_extractor import ParameterExtractorNode -from dify_graph.nodes.protocols import HttpClientProtocol, ToolFileManagerProtocol -from dify_graph.nodes.question_classifier import QuestionClassifierNode -from dify_graph.nodes.template_transform import TemplateTransformNode -from dify_graph.nodes.template_transform.template_renderer import ( - Jinja2TemplateRenderer, - TemplateRenderError, -) -from dify_graph.nodes.tool import ToolNode +from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus +from graphon.model_runtime.entities.llm_entities import LLMUsage +from graphon.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent +from graphon.nodes.code import CodeNode +from graphon.nodes.document_extractor import DocumentExtractorNode +from graphon.nodes.http_request import HttpRequestNode +from graphon.nodes.llm import LLMNode +from graphon.nodes.llm.file_saver import LLMFileSaver +from graphon.nodes.llm.protocols import CredentialsProvider, ModelFactory +from graphon.nodes.llm.runtime_protocols import PromptMessageSerializerProtocol +from graphon.nodes.parameter_extractor import ParameterExtractorNode +from graphon.nodes.protocols import FileReferenceFactoryProtocol, HttpClientProtocol, ToolFileManagerProtocol +from graphon.nodes.question_classifier import QuestionClassifierNode +from graphon.nodes.template_transform import TemplateTransformNode +from graphon.nodes.tool import ToolNode +from graphon.template_rendering import Jinja2TemplateRenderer, TemplateRenderError if TYPE_CHECKING: - from dify_graph.entities import GraphInitParams - from dify_graph.runtime import GraphRuntimeState + from graphon.entities import GraphInitParams + from graphon.runtime import GraphRuntimeState from .test_mock_config import MockConfig @@ -66,20 +66,26 @@ class MockNodeMixin: kwargs.setdefault("credentials_provider", MagicMock(spec=CredentialsProvider)) kwargs.setdefault("model_factory", MagicMock(spec=ModelFactory)) kwargs.setdefault("model_instance", MagicMock(spec=ModelInstance)) + kwargs.setdefault("prompt_message_serializer", MagicMock(spec=PromptMessageSerializerProtocol)) # LLM-like nodes now require an http_client; provide a mock by default for tests. kwargs.setdefault("http_client", MagicMock(spec=HttpClientProtocol)) - if isinstance(self, (LLMNode, QuestionClassifierNode)): - kwargs.setdefault("template_renderer", MagicMock(spec=TemplateRenderer)) + + if isinstance(self, (LLMNode, QuestionClassifierNode)): + kwargs.setdefault("llm_file_saver", MagicMock(spec=LLMFileSaver)) + + if isinstance(self, HttpRequestNode): + kwargs.setdefault("file_reference_factory", MagicMock(spec=FileReferenceFactoryProtocol)) # Ensure TemplateTransformNode receives a renderer now required by constructor if isinstance(self, TemplateTransformNode): - kwargs.setdefault("template_renderer", _TestJinja2Renderer()) + kwargs.setdefault("jinja2_template_renderer", _TestJinja2Renderer()) # Provide default tool_file_manager_factory for ToolNode subclasses - from dify_graph.nodes.tool import ToolNode as _ToolNode # local import to avoid cycles + from graphon.nodes.tool import ToolNode as _ToolNode # local import to avoid cycles if isinstance(self, _ToolNode): kwargs.setdefault("tool_file_manager_factory", MagicMock(spec=ToolFileManagerProtocol)) + kwargs.setdefault("runtime", DifyToolNodeRuntime(graph_init_params.run_context)) if isinstance(self, AgentNode): presentation_provider = MagicMock() @@ -596,8 +602,8 @@ class MockDocumentExtractorNode(MockNodeMixin, DocumentExtractorNode): ) -from dify_graph.nodes.iteration import IterationNode -from dify_graph.nodes.loop import LoopNode +from graphon.nodes.iteration import IterationNode +from graphon.nodes.loop import LoopNode class MockIterationNode(MockNodeMixin, IterationNode): @@ -611,11 +617,11 @@ class MockIterationNode(MockNodeMixin, IterationNode): def _create_graph_engine(self, index: int, item: Any): """Create a graph engine with MockNodeFactory instead of DifyNodeFactory.""" # Import dependencies - from dify_graph.entities import GraphInitParams - from dify_graph.graph import Graph - from dify_graph.graph_engine import GraphEngine, GraphEngineConfig - from dify_graph.graph_engine.command_channels import InMemoryChannel - from dify_graph.runtime import GraphRuntimeState + from graphon.entities import GraphInitParams + from graphon.graph import Graph + from graphon.graph_engine import GraphEngine, GraphEngineConfig + from graphon.graph_engine.command_channels import InMemoryChannel + from graphon.runtime import GraphRuntimeState # Import our MockNodeFactory instead of DifyNodeFactory from .test_mock_factory import MockNodeFactory @@ -656,7 +662,7 @@ class MockIterationNode(MockNodeMixin, IterationNode): ) if not iteration_graph: - from dify_graph.nodes.iteration.exc import IterationGraphNotFoundError + from graphon.nodes.iteration.exc import IterationGraphNotFoundError raise IterationGraphNotFoundError("iteration graph not found") @@ -683,11 +689,11 @@ class MockLoopNode(MockNodeMixin, LoopNode): def _create_graph_engine(self, start_at, root_node_id: str): """Create a graph engine with MockNodeFactory instead of DifyNodeFactory.""" # Import dependencies - from dify_graph.entities import GraphInitParams - from dify_graph.graph import Graph - from dify_graph.graph_engine import GraphEngine, GraphEngineConfig - from dify_graph.graph_engine.command_channels import InMemoryChannel - from dify_graph.runtime import GraphRuntimeState + from graphon.entities import GraphInitParams + from graphon.graph import Graph + from graphon.graph_engine import GraphEngine, GraphEngineConfig + from graphon.graph_engine.command_channels import InMemoryChannel + from graphon.runtime import GraphRuntimeState # Import our MockNodeFactory instead of DifyNodeFactory from .test_mock_factory import MockNodeFactory diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes_template_code.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes_template_code.py index a8398e8f79..15f6f51398 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes_template_code.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes_template_code.py @@ -6,9 +6,9 @@ to ensure they work correctly with the TableTestRunner. """ from configs import dify_config -from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY -from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus -from dify_graph.nodes.code.limits import CodeNodeLimits +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus +from graphon.nodes.code.limits import CodeNodeLimits from tests.unit_tests.core.workflow.graph_engine.test_mock_config import MockConfig, MockConfigBuilder, NodeMockConfig from tests.unit_tests.core.workflow.graph_engine.test_mock_factory import MockNodeFactory from tests.unit_tests.core.workflow.graph_engine.test_mock_nodes import MockCodeNode, MockTemplateTransformNode @@ -40,8 +40,8 @@ class TestMockTemplateTransformNode: def test_mock_template_transform_node_default_output(self): """Test that MockTemplateTransformNode processes templates with Jinja2.""" - from dify_graph.entities import GraphInitParams - from dify_graph.runtime import GraphRuntimeState, VariablePool + from graphon.entities import GraphInitParams + from graphon.runtime import GraphRuntimeState, VariablePool # Create test parameters graph_init_params = GraphInitParams( @@ -60,7 +60,7 @@ class TestMockTemplateTransformNode: ) variable_pool = VariablePool( - system_variables={}, + system_variables=[], user_inputs={}, ) @@ -103,8 +103,8 @@ class TestMockTemplateTransformNode: def test_mock_template_transform_node_custom_output(self): """Test that MockTemplateTransformNode returns custom configured output.""" - from dify_graph.entities import GraphInitParams - from dify_graph.runtime import GraphRuntimeState, VariablePool + from graphon.entities import GraphInitParams + from graphon.runtime import GraphRuntimeState, VariablePool # Create test parameters graph_init_params = GraphInitParams( @@ -123,7 +123,7 @@ class TestMockTemplateTransformNode: ) variable_pool = VariablePool( - system_variables={}, + system_variables=[], user_inputs={}, ) @@ -167,8 +167,8 @@ class TestMockTemplateTransformNode: def test_mock_template_transform_node_error_simulation(self): """Test that MockTemplateTransformNode can simulate errors.""" - from dify_graph.entities import GraphInitParams - from dify_graph.runtime import GraphRuntimeState, VariablePool + from graphon.entities import GraphInitParams + from graphon.runtime import GraphRuntimeState, VariablePool # Create test parameters graph_init_params = GraphInitParams( @@ -187,7 +187,7 @@ class TestMockTemplateTransformNode: ) variable_pool = VariablePool( - system_variables={}, + system_variables=[], user_inputs={}, ) @@ -228,9 +228,9 @@ class TestMockTemplateTransformNode: def test_mock_template_transform_node_with_variables(self): """Test that MockTemplateTransformNode processes templates with variables.""" - from dify_graph.entities import GraphInitParams - from dify_graph.runtime import GraphRuntimeState, VariablePool - from dify_graph.variables import StringVariable + from graphon.entities import GraphInitParams + from graphon.runtime import GraphRuntimeState, VariablePool + from graphon.variables import StringVariable # Create test parameters graph_init_params = GraphInitParams( @@ -249,7 +249,7 @@ class TestMockTemplateTransformNode: ) variable_pool = VariablePool( - system_variables={}, + system_variables=[], user_inputs={}, ) @@ -298,8 +298,8 @@ class TestMockCodeNode: def test_mock_code_node_default_output(self): """Test that MockCodeNode returns default output.""" - from dify_graph.entities import GraphInitParams - from dify_graph.runtime import GraphRuntimeState, VariablePool + from graphon.entities import GraphInitParams + from graphon.runtime import GraphRuntimeState, VariablePool # Create test parameters graph_init_params = GraphInitParams( @@ -318,7 +318,7 @@ class TestMockCodeNode: ) variable_pool = VariablePool( - system_variables={}, + system_variables=[], user_inputs={}, ) @@ -364,8 +364,8 @@ class TestMockCodeNode: def test_mock_code_node_with_output_schema(self): """Test that MockCodeNode generates outputs based on schema.""" - from dify_graph.entities import GraphInitParams - from dify_graph.runtime import GraphRuntimeState, VariablePool + from graphon.entities import GraphInitParams + from graphon.runtime import GraphRuntimeState, VariablePool # Create test parameters graph_init_params = GraphInitParams( @@ -384,7 +384,7 @@ class TestMockCodeNode: ) variable_pool = VariablePool( - system_variables={}, + system_variables=[], user_inputs={}, ) @@ -438,8 +438,8 @@ class TestMockCodeNode: def test_mock_code_node_custom_output(self): """Test that MockCodeNode returns custom configured output.""" - from dify_graph.entities import GraphInitParams - from dify_graph.runtime import GraphRuntimeState, VariablePool + from graphon.entities import GraphInitParams + from graphon.runtime import GraphRuntimeState, VariablePool # Create test parameters graph_init_params = GraphInitParams( @@ -458,7 +458,7 @@ class TestMockCodeNode: ) variable_pool = VariablePool( - system_variables={}, + system_variables=[], user_inputs={}, ) @@ -514,8 +514,8 @@ class TestMockNodeFactory: def test_code_and_template_nodes_mocked_by_default(self): """Test that CODE and TEMPLATE_TRANSFORM nodes are mocked by default (they require SSRF proxy).""" - from dify_graph.entities import GraphInitParams - from dify_graph.runtime import GraphRuntimeState, VariablePool + from graphon.entities import GraphInitParams + from graphon.runtime import GraphRuntimeState, VariablePool # Create test parameters graph_init_params = GraphInitParams( @@ -534,7 +534,7 @@ class TestMockNodeFactory: ) variable_pool = VariablePool( - system_variables={}, + system_variables=[], user_inputs={}, ) @@ -559,8 +559,8 @@ class TestMockNodeFactory: def test_factory_creates_mock_template_transform_node(self): """Test that MockNodeFactory creates MockTemplateTransformNode for template-transform type.""" - from dify_graph.entities import GraphInitParams - from dify_graph.runtime import GraphRuntimeState, VariablePool + from graphon.entities import GraphInitParams + from graphon.runtime import GraphRuntimeState, VariablePool # Create test parameters graph_init_params = GraphInitParams( @@ -579,7 +579,7 @@ class TestMockNodeFactory: ) variable_pool = VariablePool( - system_variables={}, + system_variables=[], user_inputs={}, ) @@ -614,8 +614,8 @@ class TestMockNodeFactory: def test_factory_creates_mock_code_node(self): """Test that MockNodeFactory creates MockCodeNode for code type.""" - from dify_graph.entities import GraphInitParams - from dify_graph.runtime import GraphRuntimeState, VariablePool + from graphon.entities import GraphInitParams + from graphon.runtime import GraphRuntimeState, VariablePool # Create test parameters graph_init_params = GraphInitParams( @@ -634,7 +634,7 @@ class TestMockNodeFactory: ) variable_pool = VariablePool( - system_variables={}, + system_variables=[], user_inputs={}, ) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_simple.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_simple.py index 5b35b3310a..cb5200f8dc 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_simple.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_simple.py @@ -4,8 +4,8 @@ Simple test to validate the auto-mock system without external dependencies. import sys -from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY -from dify_graph.enums import BuiltinNodeTypes +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY +from graphon.enums import BuiltinNodeTypes from tests.unit_tests.core.workflow.graph_engine.test_mock_config import MockConfig, MockConfigBuilder, NodeMockConfig from tests.unit_tests.core.workflow.graph_engine.test_mock_factory import MockNodeFactory @@ -98,8 +98,8 @@ def test_node_mock_config(): def test_mock_factory_detection(): """Test MockNodeFactory node type detection.""" from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom - from dify_graph.entities import GraphInitParams - from dify_graph.runtime import GraphRuntimeState, VariablePool + from graphon.entities import GraphInitParams + from graphon.runtime import GraphRuntimeState, VariablePool print("Testing MockNodeFactory detection...") @@ -154,8 +154,8 @@ def test_mock_factory_detection(): def test_mock_factory_registration(): """Test registering and unregistering mock node types.""" from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom - from dify_graph.entities import GraphInitParams - from dify_graph.runtime import GraphRuntimeState, VariablePool + from graphon.entities import GraphInitParams + from graphon.runtime import GraphRuntimeState, VariablePool print("Testing MockNodeFactory registration...") diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_join_resume.py b/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_join_resume.py index e681b39cc7..37b43bd374 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_join_resume.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_join_resume.py @@ -4,32 +4,33 @@ from dataclasses import dataclass from datetime import datetime, timedelta from typing import Any, Protocol -from dify_graph.entities.workflow_start_reason import WorkflowStartReason -from dify_graph.graph import Graph -from dify_graph.graph_engine.command_channels.in_memory_channel import InMemoryChannel -from dify_graph.graph_engine.config import GraphEngineConfig -from dify_graph.graph_engine.graph_engine import GraphEngine -from dify_graph.graph_events import ( +from core.repositories.human_input_repository import ( + FormCreateParams, + HumanInputFormEntity, + HumanInputFormRepository, +) +from core.workflow.node_runtime import DifyHumanInputNodeRuntime +from core.workflow.system_variables import build_system_variables +from graphon.entities.workflow_start_reason import WorkflowStartReason +from graphon.graph import Graph +from graphon.graph_engine.command_channels.in_memory_channel import InMemoryChannel +from graphon.graph_engine.config import GraphEngineConfig +from graphon.graph_engine.graph_engine import GraphEngine +from graphon.graph_events import ( GraphRunPausedEvent, GraphRunStartedEvent, GraphRunSucceededEvent, NodeRunSucceededEvent, ) -from dify_graph.nodes.base.entities import OutputVariableEntity -from dify_graph.nodes.end.end_node import EndNode -from dify_graph.nodes.end.entities import EndNodeData -from dify_graph.nodes.human_input.entities import HumanInputNodeData, UserAction -from dify_graph.nodes.human_input.enums import HumanInputFormStatus -from dify_graph.nodes.human_input.human_input_node import HumanInputNode -from dify_graph.nodes.start.entities import StartNodeData -from dify_graph.nodes.start.start_node import StartNode -from dify_graph.repositories.human_input_form_repository import ( - FormCreateParams, - HumanInputFormEntity, - HumanInputFormRepository, -) -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable +from graphon.nodes.base.entities import OutputVariableEntity +from graphon.nodes.end.end_node import EndNode +from graphon.nodes.end.entities import EndNodeData +from graphon.nodes.human_input.entities import HumanInputNodeData, UserAction +from graphon.nodes.human_input.enums import HumanInputFormStatus +from graphon.nodes.human_input.human_input_node import HumanInputNode +from graphon.nodes.start.entities import StartNodeData +from graphon.nodes.start.start_node import StartNode +from graphon.runtime import GraphRuntimeState, VariablePool from libs.datetime_utils import naive_utc_now from tests.workflow_test_utils import build_test_graph_init_params @@ -67,7 +68,7 @@ class StaticForm(HumanInputFormEntity): return self.form_id @property - def web_app_token(self) -> str | None: + def submission_token(self) -> str | None: return "token" @property @@ -103,7 +104,7 @@ class StaticRepo(HumanInputFormRepository): def __init__(self, forms_by_node_id: Mapping[str, HumanInputFormEntity]) -> None: self._forms_by_node_id = dict(forms_by_node_id) - def get_form(self, workflow_execution_id: str, node_id: str) -> HumanInputFormEntity | None: + def get_form(self, node_id: str) -> HumanInputFormEntity | None: return self._forms_by_node_id.get(node_id) def create_form(self, params: FormCreateParams) -> HumanInputFormEntity: @@ -112,7 +113,7 @@ class StaticRepo(HumanInputFormRepository): def _build_runtime_state() -> GraphRuntimeState: variable_pool = VariablePool( - system_variables=SystemVariable( + system_variables=build_system_variables( user_id="user", app_id="app", workflow_id="workflow", @@ -159,6 +160,7 @@ def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepositor graph_init_params=graph_init_params, graph_runtime_state=runtime_state, form_repository=repo, + runtime=DifyHumanInputNodeRuntime(graph_init_params.run_context), ) human_b_config = {"id": "human_b", "data": human_data.model_dump()} @@ -168,6 +170,7 @@ def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepositor graph_init_params=graph_init_params, graph_runtime_state=runtime_state, form_repository=repo, + runtime=DifyHumanInputNodeRuntime(graph_init_params.run_context), ) end_data = EndNodeData( diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_pause_missing_finish.py b/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_pause_missing_finish.py index 60167c0441..59e54bd39a 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_pause_missing_finish.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_pause_missing_finish.py @@ -4,39 +4,40 @@ from dataclasses import dataclass from datetime import datetime, timedelta from typing import Any -from dify_graph.entities.workflow_start_reason import WorkflowStartReason -from dify_graph.graph import Graph -from dify_graph.graph_engine.command_channels.in_memory_channel import InMemoryChannel -from dify_graph.graph_engine.config import GraphEngineConfig -from dify_graph.graph_engine.graph_engine import GraphEngine -from dify_graph.graph_events import ( +from core.repositories.human_input_repository import ( + FormCreateParams, + HumanInputFormEntity, + HumanInputFormRepository, +) +from core.workflow.node_runtime import DifyHumanInputNodeRuntime +from core.workflow.system_variables import build_system_variables +from graphon.entities.workflow_start_reason import WorkflowStartReason +from graphon.graph import Graph +from graphon.graph_engine.command_channels.in_memory_channel import InMemoryChannel +from graphon.graph_engine.config import GraphEngineConfig +from graphon.graph_engine.graph_engine import GraphEngine +from graphon.graph_events import ( GraphRunPausedEvent, GraphRunStartedEvent, NodeRunPauseRequestedEvent, NodeRunStartedEvent, NodeRunSucceededEvent, ) -from dify_graph.model_runtime.entities.llm_entities import LLMMode -from dify_graph.model_runtime.entities.message_entities import PromptMessageRole -from dify_graph.nodes.human_input.entities import HumanInputNodeData, UserAction -from dify_graph.nodes.human_input.enums import HumanInputFormStatus -from dify_graph.nodes.human_input.human_input_node import HumanInputNode -from dify_graph.nodes.llm.entities import ( +from graphon.model_runtime.entities.llm_entities import LLMMode +from graphon.model_runtime.entities.message_entities import PromptMessageRole +from graphon.nodes.human_input.entities import HumanInputNodeData, UserAction +from graphon.nodes.human_input.enums import HumanInputFormStatus +from graphon.nodes.human_input.human_input_node import HumanInputNode +from graphon.nodes.llm.entities import ( ContextConfig, LLMNodeChatModelMessage, LLMNodeData, ModelConfig, VisionConfig, ) -from dify_graph.nodes.start.entities import StartNodeData -from dify_graph.nodes.start.start_node import StartNode -from dify_graph.repositories.human_input_form_repository import ( - FormCreateParams, - HumanInputFormEntity, - HumanInputFormRepository, -) -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable +from graphon.nodes.start.entities import StartNodeData +from graphon.nodes.start.start_node import StartNode +from graphon.runtime import GraphRuntimeState, VariablePool from libs.datetime_utils import naive_utc_now from tests.workflow_test_utils import build_test_graph_init_params @@ -59,7 +60,7 @@ class StaticForm(HumanInputFormEntity): return self.form_id @property - def web_app_token(self) -> str | None: + def submission_token(self) -> str | None: return "token" @property @@ -95,7 +96,7 @@ class StaticRepo(HumanInputFormRepository): def __init__(self, forms_by_node_id: Mapping[str, HumanInputFormEntity]) -> None: self._forms_by_node_id = dict(forms_by_node_id) - def get_form(self, workflow_execution_id: str, node_id: str) -> HumanInputFormEntity | None: + def get_form(self, node_id: str) -> HumanInputFormEntity | None: return self._forms_by_node_id.get(node_id) def create_form(self, params: FormCreateParams) -> HumanInputFormEntity: @@ -115,7 +116,7 @@ class DelayedHumanInputNode(HumanInputNode): def _build_runtime_state() -> GraphRuntimeState: variable_pool = VariablePool( - system_variables=SystemVariable( + system_variables=build_system_variables( user_id="user", app_id="app", workflow_id="workflow", @@ -162,6 +163,7 @@ def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepositor graph_init_params=graph_init_params, graph_runtime_state=runtime_state, form_repository=repo, + runtime=DifyHumanInputNodeRuntime(graph_init_params.run_context), ) human_b_config = {"id": "human_b", "data": human_data.model_dump()} @@ -171,6 +173,7 @@ def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepositor graph_init_params=graph_init_params, graph_runtime_state=runtime_state, form_repository=repo, + runtime=DifyHumanInputNodeRuntime(graph_init_params.run_context), delay_seconds=0.2, ) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_streaming_workflow.py b/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_streaming_workflow.py index b954a4faac..1a43734462 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_streaming_workflow.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_streaming_workflow.py @@ -15,20 +15,20 @@ from uuid import uuid4 from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.model_manager import ModelInstance from core.workflow.node_factory import DifyNodeFactory, get_default_root_node_id -from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus -from dify_graph.graph import Graph -from dify_graph.graph_engine import GraphEngine, GraphEngineConfig -from dify_graph.graph_engine.command_channels import InMemoryChannel -from dify_graph.graph_events import ( +from core.workflow.system_variables import build_system_variables +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus +from graphon.graph import Graph +from graphon.graph_engine import GraphEngine, GraphEngineConfig +from graphon.graph_engine.command_channels import InMemoryChannel +from graphon.graph_events import ( GraphRunSucceededEvent, NodeRunStartedEvent, NodeRunStreamChunkEvent, NodeRunSucceededEvent, ) -from dify_graph.node_events import NodeRunResult, StreamCompletedEvent -from dify_graph.nodes.llm.node import LLMNode -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable +from graphon.node_events import NodeRunResult, StreamCompletedEvent +from graphon.nodes.llm.node import LLMNode +from graphon.runtime import GraphRuntimeState, VariablePool from tests.workflow_test_utils import build_test_graph_init_params from .test_table_runner import TableTestRunner @@ -98,7 +98,7 @@ def test_parallel_streaming_workflow(): ) # Create variable pool with system variables - system_variables = SystemVariable( + system_variables = build_system_variables( user_id="test_user", app_id="test_app", workflow_id=init_params.workflow_id, diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_pause_deferred_ready_nodes.py b/api/tests/unit_tests/core/workflow/graph_engine/test_pause_deferred_ready_nodes.py index 7328ce443f..bcf123ee80 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_pause_deferred_ready_nodes.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_pause_deferred_ready_nodes.py @@ -4,40 +4,41 @@ from dataclasses import dataclass from datetime import datetime, timedelta from typing import Any -from dify_graph.entities.workflow_start_reason import WorkflowStartReason -from dify_graph.graph import Graph -from dify_graph.graph_engine.command_channels.in_memory_channel import InMemoryChannel -from dify_graph.graph_engine.config import GraphEngineConfig -from dify_graph.graph_engine.graph_engine import GraphEngine -from dify_graph.graph_events import ( +from core.repositories.human_input_repository import ( + FormCreateParams, + HumanInputFormEntity, + HumanInputFormRepository, +) +from core.workflow.node_runtime import DifyHumanInputNodeRuntime +from core.workflow.system_variables import build_system_variables +from graphon.entities.workflow_start_reason import WorkflowStartReason +from graphon.graph import Graph +from graphon.graph_engine.command_channels.in_memory_channel import InMemoryChannel +from graphon.graph_engine.config import GraphEngineConfig +from graphon.graph_engine.graph_engine import GraphEngine +from graphon.graph_events import ( GraphRunPausedEvent, GraphRunStartedEvent, NodeRunStartedEvent, NodeRunSucceededEvent, ) -from dify_graph.model_runtime.entities.llm_entities import LLMMode -from dify_graph.model_runtime.entities.message_entities import PromptMessageRole -from dify_graph.nodes.end.end_node import EndNode -from dify_graph.nodes.end.entities import EndNodeData -from dify_graph.nodes.human_input.entities import HumanInputNodeData, UserAction -from dify_graph.nodes.human_input.enums import HumanInputFormStatus -from dify_graph.nodes.human_input.human_input_node import HumanInputNode -from dify_graph.nodes.llm.entities import ( +from graphon.model_runtime.entities.llm_entities import LLMMode +from graphon.model_runtime.entities.message_entities import PromptMessageRole +from graphon.nodes.end.end_node import EndNode +from graphon.nodes.end.entities import EndNodeData +from graphon.nodes.human_input.entities import HumanInputNodeData, UserAction +from graphon.nodes.human_input.enums import HumanInputFormStatus +from graphon.nodes.human_input.human_input_node import HumanInputNode +from graphon.nodes.llm.entities import ( ContextConfig, LLMNodeChatModelMessage, LLMNodeData, ModelConfig, VisionConfig, ) -from dify_graph.nodes.start.entities import StartNodeData -from dify_graph.nodes.start.start_node import StartNode -from dify_graph.repositories.human_input_form_repository import ( - FormCreateParams, - HumanInputFormEntity, - HumanInputFormRepository, -) -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable +from graphon.nodes.start.entities import StartNodeData +from graphon.nodes.start.start_node import StartNode +from graphon.runtime import GraphRuntimeState, VariablePool from libs.datetime_utils import naive_utc_now from tests.workflow_test_utils import build_test_graph_init_params @@ -60,7 +61,7 @@ class StaticForm(HumanInputFormEntity): return self.form_id @property - def web_app_token(self) -> str | None: + def submission_token(self) -> str | None: return "token" @property @@ -96,7 +97,7 @@ class StaticRepo(HumanInputFormRepository): def __init__(self, form: HumanInputFormEntity) -> None: self._form = form - def get_form(self, workflow_execution_id: str, node_id: str) -> HumanInputFormEntity | None: + def get_form(self, node_id: str) -> HumanInputFormEntity | None: if node_id != "human_pause": return None return self._form @@ -107,7 +108,7 @@ class StaticRepo(HumanInputFormRepository): def _build_runtime_state() -> GraphRuntimeState: variable_pool = VariablePool( - system_variables=SystemVariable( + system_variables=build_system_variables( user_id="user", app_id="app", workflow_id="workflow", @@ -201,6 +202,7 @@ def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepositor graph_init_params=graph_init_params, graph_runtime_state=runtime_state, form_repository=repo, + runtime=DifyHumanInputNodeRuntime(graph_init_params.run_context), ) end_human_data = EndNodeData(title="End Human", outputs=[], desc=None) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_pause_resume_state.py b/api/tests/unit_tests/core/workflow/graph_engine/test_pause_resume_state.py index 15a7de3c52..79d3d5bcfe 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_pause_resume_state.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_pause_resume_state.py @@ -3,38 +3,39 @@ import time from typing import Any from unittest.mock import MagicMock -from dify_graph.entities.workflow_start_reason import WorkflowStartReason -from dify_graph.graph import Graph -from dify_graph.graph_engine.command_channels.in_memory_channel import InMemoryChannel -from dify_graph.graph_engine.graph_engine import GraphEngine -from dify_graph.graph_events import ( +from core.repositories.human_input_repository import ( + HumanInputFormEntity, + HumanInputFormRepository, +) +from core.workflow.node_runtime import DifyHumanInputNodeRuntime +from core.workflow.system_variables import build_system_variables +from graphon.entities.workflow_start_reason import WorkflowStartReason +from graphon.graph import Graph +from graphon.graph_engine.command_channels.in_memory_channel import InMemoryChannel +from graphon.graph_engine.graph_engine import GraphEngine +from graphon.graph_events import ( GraphEngineEvent, GraphRunPausedEvent, GraphRunSucceededEvent, NodeRunStartedEvent, NodeRunSucceededEvent, ) -from dify_graph.graph_events.graph import GraphRunStartedEvent -from dify_graph.nodes.base.entities import OutputVariableEntity -from dify_graph.nodes.end.end_node import EndNode -from dify_graph.nodes.end.entities import EndNodeData -from dify_graph.nodes.human_input.entities import HumanInputNodeData, UserAction -from dify_graph.nodes.human_input.human_input_node import HumanInputNode -from dify_graph.nodes.start.entities import StartNodeData -from dify_graph.nodes.start.start_node import StartNode -from dify_graph.repositories.human_input_form_repository import ( - HumanInputFormEntity, - HumanInputFormRepository, -) -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable +from graphon.graph_events.graph import GraphRunStartedEvent +from graphon.nodes.base.entities import OutputVariableEntity +from graphon.nodes.end.end_node import EndNode +from graphon.nodes.end.entities import EndNodeData +from graphon.nodes.human_input.entities import HumanInputNodeData, UserAction +from graphon.nodes.human_input.human_input_node import HumanInputNode +from graphon.nodes.start.entities import StartNodeData +from graphon.nodes.start.start_node import StartNode +from graphon.runtime import GraphRuntimeState, VariablePool from libs.datetime_utils import naive_utc_now from tests.workflow_test_utils import build_test_graph_init_params def _build_runtime_state() -> GraphRuntimeState: variable_pool = VariablePool( - system_variables=SystemVariable( + system_variables=build_system_variables( user_id="user", app_id="app", workflow_id="workflow", @@ -50,7 +51,7 @@ def _mock_form_repository_with_submission(action_id: str) -> HumanInputFormRepos repo = MagicMock(spec=HumanInputFormRepository) form_entity = MagicMock(spec=HumanInputFormEntity) form_entity.id = "test-form-id" - form_entity.web_app_token = "test-form-token" + form_entity.submission_token = "test-form-token" form_entity.recipients = [] form_entity.rendered_content = "rendered" form_entity.submitted = True @@ -65,7 +66,7 @@ def _mock_form_repository_without_submission() -> HumanInputFormRepository: repo = MagicMock(spec=HumanInputFormRepository) form_entity = MagicMock(spec=HumanInputFormEntity) form_entity.id = "test-form-id" - form_entity.web_app_token = "test-form-token" + form_entity.submission_token = "test-form-token" form_entity.recipients = [] form_entity.rendered_content = "rendered" form_entity.submitted = False @@ -112,6 +113,7 @@ def _build_human_input_graph( graph_init_params=params, graph_runtime_state=runtime_state, form_repository=form_repository, + runtime=DifyHumanInputNodeRuntime(params.run_context), ) end_data = EndNodeData( diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_redis_stop_integration.py b/api/tests/unit_tests/core/workflow/graph_engine/test_redis_stop_integration.py index 9c84f42db6..146b728dc2 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_redis_stop_integration.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_redis_stop_integration.py @@ -12,9 +12,9 @@ import pytest import redis from core.app.apps.base_app_queue_manager import AppQueueManager -from dify_graph.graph_engine.command_channels.redis_channel import RedisChannel -from dify_graph.graph_engine.entities.commands import AbortCommand, CommandType, PauseCommand -from dify_graph.graph_engine.manager import GraphEngineManager +from graphon.graph_engine.command_channels.redis_channel import RedisChannel +from graphon.graph_engine.entities.commands import AbortCommand, CommandType, PauseCommand +from graphon.graph_engine.manager import GraphEngineManager class TestRedisStopIntegration: diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_response_session.py b/api/tests/unit_tests/core/workflow/graph_engine/test_response_session.py index cd9d56f683..62ca7a630e 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_response_session.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_response_session.py @@ -4,9 +4,9 @@ from __future__ import annotations import pytest -from dify_graph.enums import BuiltinNodeTypes, NodeExecutionType, NodeState, NodeType -from dify_graph.graph_engine.response_coordinator.session import ResponseSession -from dify_graph.nodes.base.template import Template, TextSegment +from graphon.enums import BuiltinNodeTypes, NodeExecutionType, NodeState, NodeType +from graphon.graph_engine.response_coordinator.session import ResponseSession +from graphon.nodes.base.template import Template, TextSegment class DummyResponseNode: diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_streaming_conversation_variables.py b/api/tests/unit_tests/core/workflow/graph_engine/test_streaming_conversation_variables.py index 4f1741d4fb..a359a5fef9 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_streaming_conversation_variables.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_streaming_conversation_variables.py @@ -1,9 +1,10 @@ -from dify_graph.graph_events import ( +from graphon.graph_events import ( GraphRunStartedEvent, GraphRunSucceededEvent, NodeRunStartedEvent, NodeRunStreamChunkEvent, NodeRunSucceededEvent, + NodeRunVariableUpdatedEvent, ) from .test_mock_config import MockConfigBuilder @@ -33,6 +34,7 @@ def test_streaming_conversation_variables(): NodeRunSucceededEvent, # Variable Assigner node NodeRunStartedEvent, + NodeRunVariableUpdatedEvent, NodeRunStreamChunkEvent, NodeRunSucceededEvent, # ANSWER node diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_table_runner.py b/api/tests/unit_tests/core/workflow/graph_engine/test_table_runner.py index ab8fb346b8..81d68ba2aa 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_table_runner.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_table_runner.py @@ -12,29 +12,29 @@ This module provides a robust table-driven testing framework with support for: import logging import time -from collections.abc import Callable, Mapping, Sequence +from collections.abc import Callable, Sequence from concurrent.futures import ThreadPoolExecutor, as_completed from dataclasses import dataclass, field from functools import lru_cache from pathlib import Path -from typing import Any, cast +from typing import Any -from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, InvokeFrom, UserFrom from core.tools.utils.yaml_utils import _load_yaml_file from core.workflow.node_factory import DifyNodeFactory, get_default_root_node_id -from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY, GraphInitParams -from dify_graph.graph import Graph -from dify_graph.graph_engine import GraphEngine, GraphEngineConfig -from dify_graph.graph_engine.command_channels import InMemoryChannel -from dify_graph.graph_engine.layers.base import GraphEngineLayer -from dify_graph.graph_events import ( +from core.workflow.system_variables import build_bootstrap_variables, build_system_variables +from core.workflow.variable_pool_initializer import add_node_inputs_to_pool, add_variables_to_pool +from graphon.entities.graph_init_params import GraphInitParams +from graphon.graph import Graph +from graphon.graph_engine import GraphEngine, GraphEngineConfig +from graphon.graph_engine.command_channels import InMemoryChannel +from graphon.graph_events import ( GraphEngineEvent, GraphRunStartedEvent, GraphRunSucceededEvent, ) -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable -from dify_graph.variables import ( +from graphon.runtime import GraphRuntimeState, VariablePool +from graphon.variables import ( ArrayNumberVariable, ArrayObjectVariable, ArrayStringVariable, @@ -60,20 +60,28 @@ class _TableTestChildEngineBuilder: *, workflow_id: str, graph_init_params: GraphInitParams, - graph_runtime_state: GraphRuntimeState, - graph_config: Mapping[str, Any], + parent_graph_runtime_state: GraphRuntimeState, root_node_id: str, - layers: Sequence[object] = (), + variable_pool: VariablePool | None = None, ) -> GraphEngine: + child_graph_runtime_state = GraphRuntimeState( + variable_pool=variable_pool if variable_pool is not None else parent_graph_runtime_state.variable_pool, + start_at=time.perf_counter(), + execution_context=parent_graph_runtime_state.execution_context, + ) if self._use_mock_factory: node_factory = MockNodeFactory( graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, + graph_runtime_state=child_graph_runtime_state, mock_config=self._mock_config, ) else: - node_factory = DifyNodeFactory(graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state) + node_factory = DifyNodeFactory( + graph_init_params=graph_init_params, + graph_runtime_state=child_graph_runtime_state, + ) + graph_config = graph_init_params.graph_config child_graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id=root_node_id) if not child_graph: raise ValueError("child graph not found") @@ -81,13 +89,11 @@ class _TableTestChildEngineBuilder: child_engine = GraphEngine( workflow_id=workflow_id, graph=child_graph, - graph_runtime_state=graph_runtime_state, + graph_runtime_state=child_graph_runtime_state, command_channel=InMemoryChannel(), config=GraphEngineConfig(), child_engine_builder=self, ) - for layer in layers: - child_engine.layer(cast(GraphEngineLayer, layer)) return child_engine @@ -206,14 +212,15 @@ class WorkflowRunner: call_depth=0, ) - system_variables = SystemVariable( + system_variables = build_system_variables( user_id="test_user", app_id="test_app", workflow_id=graph_init_params.workflow_id, files=[], query=query, ) - user_inputs = inputs if inputs is not None else {} + root_node_inputs = dict(inputs or {}) + root_node_inputs.setdefault("query", query) # Extract conversation variables from workflow config conversation_variables = [] @@ -242,11 +249,16 @@ class WorkflowRunner: ) conversation_variables.append(var) - variable_pool = VariablePool( - system_variables=system_variables, - user_inputs=user_inputs, - conversation_variables=conversation_variables, + root_node_id = get_default_root_node_id(graph_config) + variable_pool = VariablePool() + add_variables_to_pool( + variable_pool, + build_bootstrap_variables( + system_variables=system_variables, + conversation_variables=conversation_variables, + ), ) + add_node_inputs_to_pool(variable_pool, node_id=root_node_id, inputs=root_node_inputs) graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) @@ -260,7 +272,7 @@ class WorkflowRunner: graph = Graph.init( graph_config=graph_config, node_factory=node_factory, - root_node_id=get_default_root_node_id(graph_config), + root_node_id=root_node_id, ) return graph, graph_runtime_state diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_tool_in_chatflow.py b/api/tests/unit_tests/core/workflow/graph_engine/test_tool_in_chatflow.py index 7f26bc11a7..12aec6edf2 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_tool_in_chatflow.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_tool_in_chatflow.py @@ -1,6 +1,6 @@ -from dify_graph.graph_engine import GraphEngine, GraphEngineConfig -from dify_graph.graph_engine.command_channels import InMemoryChannel -from dify_graph.graph_events import ( +from graphon.graph_engine import GraphEngine, GraphEngineConfig +from graphon.graph_engine.command_channels import InMemoryChannel +from graphon.graph_events import ( GraphRunSucceededEvent, NodeRunStreamChunkEvent, ) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_variable_aggregator.py b/api/tests/unit_tests/core/workflow/graph_engine/test_variable_aggregator.py index f63e8ff4ce..2ad41037a9 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_variable_aggregator.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_variable_aggregator.py @@ -2,9 +2,9 @@ from unittest.mock import patch import pytest -from dify_graph.enums import WorkflowNodeExecutionStatus -from dify_graph.node_events import NodeRunResult -from dify_graph.nodes.template_transform.template_transform_node import TemplateTransformNode +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.node_events import NodeRunResult +from graphon.nodes.template_transform.template_transform_node import TemplateTransformNode from .test_table_runner import TableTestRunner, WorkflowTestCase diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_variable_update_events.py b/api/tests/unit_tests/core/workflow/graph_engine/test_variable_update_events.py new file mode 100644 index 0000000000..60cab77c0a --- /dev/null +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_variable_update_events.py @@ -0,0 +1,129 @@ +import time +import uuid +from uuid import uuid4 + +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, InvokeFrom, UserFrom +from core.workflow.node_factory import DifyNodeFactory +from core.workflow.system_variables import build_bootstrap_variables, build_system_variables +from core.workflow.variable_pool_initializer import add_variables_to_pool +from graphon.entities import GraphInitParams +from graphon.graph import Graph +from graphon.graph_engine import GraphEngine, GraphEngineConfig +from graphon.graph_engine.command_channels import InMemoryChannel +from graphon.graph_engine.layers.base import GraphEngineLayer +from graphon.graph_events import NodeRunVariableUpdatedEvent +from graphon.runtime import GraphRuntimeState, VariablePool +from graphon.variables import StringVariable + +DEFAULT_NODE_ID = "node_id" + + +class CaptureVariableUpdateLayer(GraphEngineLayer): + def __init__(self) -> None: + super().__init__() + self.events: list[NodeRunVariableUpdatedEvent] = [] + self.observed_values: list[object | None] = [] + + def on_graph_start(self) -> None: + pass + + def on_event(self, event) -> None: + if not isinstance(event, NodeRunVariableUpdatedEvent): + return + + current_value = self.graph_runtime_state.variable_pool.get(event.variable.selector) + self.events.append(event) + self.observed_values.append(None if current_value is None else current_value.value) + + def on_graph_end(self, error: Exception | None) -> None: + pass + + +def test_graph_engine_applies_variable_updates_before_notifying_layers(): + graph_config = { + "edges": [ + { + "id": "start-source-assigner-target", + "source": "start", + "target": "assigner", + }, + ], + "nodes": [ + {"data": {"type": "start", "title": "Start"}, "id": "start"}, + { + "data": { + "type": "assigner", + "title": "Variable Assigner", + "assigned_variable_selector": ["conversation", "test_conversation_variable"], + "write_mode": "over-write", + "input_variable_selector": ["node_id", "test_string_variable"], + }, + "id": "assigner", + }, + ], + } + + init_params = GraphInitParams( + workflow_id="1", + graph_config=graph_config, + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "1", + "app_id": "1", + "user_id": "1", + "user_from": UserFrom.ACCOUNT, + "invoke_from": InvokeFrom.DEBUGGER, + } + }, + call_depth=0, + ) + + variable_pool = VariablePool() + add_variables_to_pool( + variable_pool, + build_bootstrap_variables( + system_variables=build_system_variables(conversation_id=str(uuid.uuid4())), + conversation_variables=[ + StringVariable( + id=str(uuid4()), + name="test_conversation_variable", + value="the first value", + ) + ], + ), + ) + variable_pool.add( + [DEFAULT_NODE_ID, "test_string_variable"], + StringVariable( + id=str(uuid4()), + name="test_string_variable", + value="the second value", + ), + ) + + graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) + node_factory = DifyNodeFactory(graph_init_params=init_params, graph_runtime_state=graph_runtime_state) + graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start") + + engine = GraphEngine( + workflow_id="workflow-id", + graph=graph, + graph_runtime_state=graph_runtime_state, + command_channel=InMemoryChannel(), + config=GraphEngineConfig(), + ) + capture_layer = CaptureVariableUpdateLayer() + engine.layer(capture_layer) + + events = list(engine.run()) + + update_events = [event for event in events if isinstance(event, NodeRunVariableUpdatedEvent)] + assert len(update_events) == 1 + assert update_events[0].variable.value == "the second value" + + current_value = graph_runtime_state.variable_pool.get(["conversation", "test_conversation_variable"]) + assert current_value is not None + assert current_value.value == "the second value" + + assert len(capture_layer.events) == 1 + assert capture_layer.observed_values == ["the second value"] diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_worker.py b/api/tests/unit_tests/core/workflow/graph_engine/test_worker.py index bc00b49fba..85132674b8 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_worker.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_worker.py @@ -4,15 +4,16 @@ from datetime import UTC, datetime, timedelta from types import SimpleNamespace from unittest.mock import MagicMock, patch -from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus -from dify_graph.graph_engine.ready_queue import InMemoryReadyQueue -from dify_graph.graph_engine.worker import Worker -from dify_graph.graph_events import NodeRunFailedEvent, NodeRunStartedEvent +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus +from graphon.graph_engine.ready_queue import InMemoryReadyQueue +from graphon.graph_engine.worker import Worker +from graphon.graph_events import NodeRunFailedEvent, NodeRunStartedEvent def test_build_fallback_failure_event_uses_naive_utc_and_failed_node_run_result(mocker) -> None: fixed_time = datetime(2024, 1, 1, 12, 0, 0, tzinfo=UTC).replace(tzinfo=None) - mocker.patch("dify_graph.graph_engine.worker.naive_utc_now", return_value=fixed_time) + mock_datetime = mocker.patch("graphon.graph_engine.worker.datetime") + mock_datetime.now.return_value = fixed_time.replace(tzinfo=UTC) worker = Worker( ready_queue=InMemoryReadyQueue(), @@ -75,7 +76,8 @@ def test_worker_fallback_failure_event_reuses_observed_start_time() -> None: worker._event_queue.put.side_effect = put_side_effect - with patch("dify_graph.graph_engine.worker.naive_utc_now", return_value=failure_time): + with patch("graphon.graph_engine.worker.datetime") as mock_datetime: + mock_datetime.now.return_value = failure_time.replace(tzinfo=UTC) worker.run() fallback_event = captured_events[-1] @@ -135,7 +137,8 @@ def test_worker_fallback_failure_event_ignores_nested_iteration_child_start_time worker._event_queue.put.side_effect = put_side_effect - with patch("dify_graph.graph_engine.worker.naive_utc_now", return_value=failure_time): + with patch("graphon.graph_engine.worker.datetime") as mock_datetime: + mock_datetime.now.return_value = failure_time.replace(tzinfo=UTC) worker.run() fallback_event = captured_events[-1] diff --git a/api/tests/unit_tests/core/workflow/nodes/agent/test_message_transformer.py b/api/tests/unit_tests/core/workflow/nodes/agent/test_message_transformer.py new file mode 100644 index 0000000000..1f4509af9a --- /dev/null +++ b/api/tests/unit_tests/core/workflow/nodes/agent/test_message_transformer.py @@ -0,0 +1,33 @@ +from unittest.mock import patch + +from core.tools.utils.message_transformer import ToolFileMessageTransformer +from core.workflow.nodes.agent.message_transformer import AgentMessageTransformer +from graphon.enums import BuiltinNodeTypes + + +def test_transform_passes_conversation_id_to_tool_file_message_transformer() -> None: + messages = iter(()) + transformer = AgentMessageTransformer() + + with patch.object(ToolFileMessageTransformer, "transform_tool_invoke_messages", return_value=iter(())) as transform: + result = list( + transformer.transform( + messages=messages, + tool_info={}, + parameters_for_log={}, + user_id="user-id", + tenant_id="tenant-id", + conversation_id="conversation-id", + node_type=BuiltinNodeTypes.AGENT, + node_id="node-id", + node_execution_id="execution-id", + ) + ) + + assert len(result) == 2 + transform.assert_called_once_with( + messages=messages, + user_id="user-id", + tenant_id="tenant-id", + conversation_id="conversation-id", + ) diff --git a/api/tests/unit_tests/core/workflow/nodes/agent/test_runtime_support.py b/api/tests/unit_tests/core/workflow/nodes/agent/test_runtime_support.py new file mode 100644 index 0000000000..c86de7f6e6 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/nodes/agent/test_runtime_support.py @@ -0,0 +1,49 @@ +from types import SimpleNamespace +from unittest.mock import Mock, patch + +from core.workflow.nodes.agent.runtime_support import AgentRuntimeSupport +from graphon.model_runtime.entities.model_entities import ModelType + + +def test_fetch_model_reuses_single_model_assembly(): + provider_configuration = SimpleNamespace( + get_current_credentials=Mock(return_value={"api_key": "x"}), + provider=SimpleNamespace(provider="openai"), + ) + model_type_instance = SimpleNamespace(get_model_schema=Mock(return_value="schema")) + provider_model_bundle = SimpleNamespace( + configuration=provider_configuration, + model_type_instance=model_type_instance, + ) + model_instance = Mock() + assembly = SimpleNamespace( + provider_manager=Mock(), + model_manager=Mock(), + ) + assembly.provider_manager.get_provider_model_bundle.return_value = provider_model_bundle + assembly.model_manager.get_model_instance.return_value = model_instance + + with patch( + "core.workflow.nodes.agent.runtime_support.create_plugin_model_assembly", + return_value=assembly, + ) as mock_assembly: + resolved_instance, resolved_schema = AgentRuntimeSupport().fetch_model( + tenant_id="tenant-1", + user_id="user-1", + value={"provider": "openai", "model": "gpt-4o-mini", "model_type": "llm"}, + ) + + assert resolved_instance is model_instance + assert resolved_schema == "schema" + mock_assembly.assert_called_once_with(tenant_id="tenant-1", user_id="user-1") + assembly.provider_manager.get_provider_model_bundle.assert_called_once_with( + tenant_id="tenant-1", + provider="openai", + model_type=ModelType.LLM, + ) + assembly.model_manager.get_model_instance.assert_called_once_with( + tenant_id="tenant-1", + provider="openai", + model_type=ModelType.LLM, + model="gpt-4o-mini", + ) diff --git a/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py index fd563d1be2..9c0ad25b58 100644 --- a/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py +++ b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py @@ -4,12 +4,12 @@ from unittest.mock import MagicMock from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.workflow.node_factory import DifyNodeFactory -from dify_graph.enums import WorkflowNodeExecutionStatus -from dify_graph.graph import Graph -from dify_graph.nodes.answer.answer_node import AnswerNode -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable +from core.workflow.system_variables import build_system_variables from extensions.ext_database import db +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.graph import Graph +from graphon.nodes.answer.answer_node import AnswerNode +from graphon.runtime import GraphRuntimeState, VariablePool from tests.workflow_test_utils import build_test_graph_init_params @@ -48,7 +48,7 @@ def test_execute_answer(): # construct variable pool variable_pool = VariablePool( - system_variables=SystemVariable(user_id="aaa", files=[]), + system_variables=build_system_variables(user_id="aaa", files=[]), user_inputs={}, environment_variables=[], conversation_variables=[], diff --git a/api/tests/unit_tests/core/workflow/nodes/base/test_base_node.py b/api/tests/unit_tests/core/workflow/nodes/base/test_base_node.py index 81d3f5be9c..ec4cef1955 100644 --- a/api/tests/unit_tests/core/workflow/nodes/base/test_base_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/base/test_base_node.py @@ -1,9 +1,9 @@ import pytest from core.workflow.node_factory import get_node_type_classes_mapping -from dify_graph.entities.base_node_data import BaseNodeData -from dify_graph.enums import BuiltinNodeTypes, NodeType -from dify_graph.nodes.base.node import Node +from graphon.entities.base_node_data import BaseNodeData +from graphon.enums import BuiltinNodeTypes, NodeType +from graphon.nodes.base.node import Node # Ensures that all production node classes are imported and registered. _ = get_node_type_classes_mapping() diff --git a/api/tests/unit_tests/core/workflow/nodes/base/test_get_node_type_classes_mapping.py b/api/tests/unit_tests/core/workflow/nodes/base/test_get_node_type_classes_mapping.py index 972a945ca0..ef0df55995 100644 --- a/api/tests/unit_tests/core/workflow/nodes/base/test_get_node_type_classes_mapping.py +++ b/api/tests/unit_tests/core/workflow/nodes/base/test_get_node_type_classes_mapping.py @@ -2,15 +2,15 @@ import types from collections.abc import Mapping from core.workflow.node_factory import get_node_type_classes_mapping -from dify_graph.entities.base_node_data import BaseNodeData -from dify_graph.enums import BuiltinNodeTypes, NodeType -from dify_graph.nodes.base.node import Node +from graphon.entities.base_node_data import BaseNodeData +from graphon.enums import BuiltinNodeTypes, NodeType +from graphon.nodes.base.node import Node # Import concrete nodes we will assert on (numeric version path) -from dify_graph.nodes.variable_assigner.v1.node import ( +from graphon.nodes.variable_assigner.v1.node import ( VariableAssignerNode as VariableAssignerV1, ) -from dify_graph.nodes.variable_assigner.v2.node import ( +from graphon.nodes.variable_assigner.v2.node import ( VariableAssignerNode as VariableAssignerV2, ) diff --git a/api/tests/unit_tests/core/workflow/nodes/code/code_node_spec.py b/api/tests/unit_tests/core/workflow/nodes/code/code_node_spec.py index 784e08edd2..ce0c9b79c6 100644 --- a/api/tests/unit_tests/core/workflow/nodes/code/code_node_spec.py +++ b/api/tests/unit_tests/core/workflow/nodes/code/code_node_spec.py @@ -1,13 +1,13 @@ from configs import dify_config -from dify_graph.nodes.code.code_node import CodeNode -from dify_graph.nodes.code.entities import CodeLanguage, CodeNodeData -from dify_graph.nodes.code.exc import ( +from graphon.nodes.code.code_node import CodeNode +from graphon.nodes.code.entities import CodeLanguage, CodeNodeData +from graphon.nodes.code.exc import ( CodeNodeError, DepthLimitError, OutputValidationError, ) -from dify_graph.nodes.code.limits import CodeNodeLimits -from dify_graph.variables.types import SegmentType +from graphon.nodes.code.limits import CodeNodeLimits +from graphon.variables.types import SegmentType CodeNode._limits = CodeNodeLimits( max_string_length=dify_config.CODE_MAX_STRING_LENGTH, diff --git a/api/tests/unit_tests/core/workflow/nodes/code/entities_spec.py b/api/tests/unit_tests/core/workflow/nodes/code/entities_spec.py index de7ed0815e..20fe2c1a74 100644 --- a/api/tests/unit_tests/core/workflow/nodes/code/entities_spec.py +++ b/api/tests/unit_tests/core/workflow/nodes/code/entities_spec.py @@ -1,8 +1,8 @@ import pytest from pydantic import ValidationError -from dify_graph.nodes.code.entities import CodeLanguage, CodeNodeData -from dify_graph.variables.types import SegmentType +from graphon.nodes.code.entities import CodeLanguage, CodeNodeData +from graphon.variables.types import SegmentType class TestCodeNodeDataOutput: diff --git a/api/tests/unit_tests/core/workflow/nodes/datasource/test_datasource_node.py b/api/tests/unit_tests/core/workflow/nodes/datasource/test_datasource_node.py index 859115ceb3..1d76067ec2 100644 --- a/api/tests/unit_tests/core/workflow/nodes/datasource/test_datasource_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/datasource/test_datasource_node.py @@ -1,7 +1,7 @@ +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY from core.workflow.nodes.datasource.datasource_node import DatasourceNode -from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY -from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus -from dify_graph.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent +from graphon.entities.workflow_node_execution import WorkflowNodeExecutionStatus +from graphon.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent class _VarSeg: diff --git a/api/tests/unit_tests/core/workflow/nodes/http_request/test_config.py b/api/tests/unit_tests/core/workflow/nodes/http_request/test_config.py index cd822a6f89..f1a48f49b9 100644 --- a/api/tests/unit_tests/core/workflow/nodes/http_request/test_config.py +++ b/api/tests/unit_tests/core/workflow/nodes/http_request/test_config.py @@ -1,4 +1,4 @@ -from dify_graph.nodes.http_request import build_http_request_config +from graphon.nodes.http_request import build_http_request_config def test_build_http_request_config_uses_literal_defaults(): diff --git a/api/tests/unit_tests/core/workflow/nodes/http_request/test_entities.py b/api/tests/unit_tests/core/workflow/nodes/http_request/test_entities.py index fec6ad90eb..88895608d9 100644 --- a/api/tests/unit_tests/core/workflow/nodes/http_request/test_entities.py +++ b/api/tests/unit_tests/core/workflow/nodes/http_request/test_entities.py @@ -4,7 +4,7 @@ from unittest.mock import Mock, PropertyMock, patch import httpx import pytest -from dify_graph.nodes.http_request.entities import Response +from graphon.nodes.http_request.entities import Response @pytest.fixture @@ -104,7 +104,7 @@ def test_mimetype_based_detection(mock_response, content_type, expected_main_typ mock_response.headers = {"content-type": content_type} type(mock_response).content = PropertyMock(return_value=bytes([0x00])) # Dummy content - with patch("dify_graph.nodes.http_request.entities.mimetypes.guess_type") as mock_guess_type: + with patch("graphon.nodes.http_request.entities.mimetypes.guess_type") as mock_guess_type: # Mock the return value based on expected_main_type if expected_main_type: mock_guess_type.return_value = (f"{expected_main_type}/subtype", None) diff --git a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py index cea7195417..be7cc073db 100644 --- a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py +++ b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py @@ -2,19 +2,19 @@ import pytest from configs import dify_config from core.helper.ssrf_proxy import ssrf_proxy -from dify_graph.file.file_manager import file_manager -from dify_graph.nodes.http_request import ( +from core.workflow.system_variables import default_system_variables +from graphon.file.file_manager import file_manager +from graphon.nodes.http_request import ( BodyData, HttpRequestNodeAuthorization, HttpRequestNodeBody, HttpRequestNodeConfig, HttpRequestNodeData, ) -from dify_graph.nodes.http_request.entities import HttpRequestNodeTimeout -from dify_graph.nodes.http_request.exc import AuthorizationConfigError -from dify_graph.nodes.http_request.executor import Executor -from dify_graph.runtime import VariablePool -from dify_graph.system_variable import SystemVariable +from graphon.nodes.http_request.entities import HttpRequestNodeTimeout +from graphon.nodes.http_request.exc import AuthorizationConfigError +from graphon.nodes.http_request.executor import Executor +from graphon.runtime import VariablePool HTTP_REQUEST_CONFIG = HttpRequestNodeConfig( max_connect_timeout=dify_config.HTTP_REQUEST_MAX_CONNECT_TIMEOUT, @@ -30,7 +30,7 @@ HTTP_REQUEST_CONFIG = HttpRequestNodeConfig( def test_executor_with_json_body_and_number_variable(): # Prepare the variable pool variable_pool = VariablePool( - system_variables=SystemVariable.default(), + system_variables=default_system_variables(), user_inputs={}, ) variable_pool.add(["pre_node_id", "number"], 42) @@ -86,7 +86,7 @@ def test_executor_with_json_body_and_number_variable(): def test_executor_with_json_body_and_object_variable(): # Prepare the variable pool variable_pool = VariablePool( - system_variables=SystemVariable.default(), + system_variables=default_system_variables(), user_inputs={}, ) variable_pool.add(["pre_node_id", "object"], {"name": "John Doe", "age": 30, "email": "john@example.com"}) @@ -144,7 +144,7 @@ def test_executor_with_json_body_and_object_variable(): def test_executor_with_json_body_and_nested_object_variable(): # Prepare the variable pool variable_pool = VariablePool( - system_variables=SystemVariable.default(), + system_variables=default_system_variables(), user_inputs={}, ) variable_pool.add(["pre_node_id", "object"], {"name": "John Doe", "age": 30, "email": "john@example.com"}) @@ -201,7 +201,7 @@ def test_executor_with_json_body_and_nested_object_variable(): def test_extract_selectors_from_template_with_newline(): - variable_pool = VariablePool(system_variables=SystemVariable.default()) + variable_pool = VariablePool(system_variables=default_system_variables()) variable_pool.add(("node_id", "custom_query"), "line1\nline2") node_data = HttpRequestNodeData( title="Test JSON Body with Nested Object Variable", @@ -231,7 +231,7 @@ def test_extract_selectors_from_template_with_newline(): def test_executor_with_form_data(): # Prepare the variable pool variable_pool = VariablePool( - system_variables=SystemVariable.default(), + system_variables=default_system_variables(), user_inputs={}, ) variable_pool.add(["pre_node_id", "text_field"], "Hello, World!") @@ -320,7 +320,7 @@ def test_init_headers(): node_data=node_data, timeout=timeout, http_request_config=HTTP_REQUEST_CONFIG, - variable_pool=VariablePool(system_variables=SystemVariable.default()), + variable_pool=VariablePool(system_variables=default_system_variables()), http_client=ssrf_proxy, file_manager=file_manager, ) @@ -357,7 +357,7 @@ def test_init_params(): node_data=node_data, timeout=timeout, http_request_config=HTTP_REQUEST_CONFIG, - variable_pool=VariablePool(system_variables=SystemVariable.default()), + variable_pool=VariablePool(system_variables=default_system_variables()), http_client=ssrf_proxy, file_manager=file_manager, ) @@ -390,7 +390,7 @@ def test_init_params(): def test_empty_api_key_raises_error_bearer(): """Test that empty API key raises AuthorizationConfigError for bearer auth.""" - variable_pool = VariablePool(system_variables=SystemVariable.default()) + variable_pool = VariablePool(system_variables=default_system_variables()) node_data = HttpRequestNodeData( title="test", method="get", @@ -417,7 +417,7 @@ def test_empty_api_key_raises_error_bearer(): def test_empty_api_key_raises_error_basic(): """Test that empty API key raises AuthorizationConfigError for basic auth.""" - variable_pool = VariablePool(system_variables=SystemVariable.default()) + variable_pool = VariablePool(system_variables=default_system_variables()) node_data = HttpRequestNodeData( title="test", method="get", @@ -444,7 +444,7 @@ def test_empty_api_key_raises_error_basic(): def test_empty_api_key_raises_error_custom(): """Test that empty API key raises AuthorizationConfigError for custom auth.""" - variable_pool = VariablePool(system_variables=SystemVariable.default()) + variable_pool = VariablePool(system_variables=default_system_variables()) node_data = HttpRequestNodeData( title="test", method="get", @@ -471,7 +471,7 @@ def test_empty_api_key_raises_error_custom(): def test_whitespace_only_api_key_raises_error(): """Test that whitespace-only API key raises AuthorizationConfigError.""" - variable_pool = VariablePool(system_variables=SystemVariable.default()) + variable_pool = VariablePool(system_variables=default_system_variables()) node_data = HttpRequestNodeData( title="test", method="get", @@ -498,7 +498,7 @@ def test_whitespace_only_api_key_raises_error(): def test_valid_api_key_works(): """Test that valid API key works correctly for bearer auth.""" - variable_pool = VariablePool(system_variables=SystemVariable.default()) + variable_pool = VariablePool(system_variables=default_system_variables()) node_data = HttpRequestNodeData( title="test", method="get", @@ -537,7 +537,7 @@ def test_executor_with_json_body_and_unquoted_uuid_variable(): test_uuid = "57eeeeb1-450b-482c-81b9-4be77e95dee2" variable_pool = VariablePool( - system_variables=SystemVariable.default(), + system_variables=default_system_variables(), user_inputs={}, ) variable_pool.add(["pre_node_id", "uuid"], test_uuid) @@ -584,7 +584,7 @@ def test_executor_with_json_body_and_unquoted_uuid_with_newlines(): test_uuid = "57eeeeb1-450b-482c-81b9-4be77e95dee2" variable_pool = VariablePool( - system_variables=SystemVariable.default(), + system_variables=default_system_variables(), user_inputs={}, ) variable_pool.add(["pre_node_id", "uuid"], test_uuid) @@ -625,7 +625,7 @@ def test_executor_with_json_body_and_unquoted_uuid_with_newlines(): def test_executor_with_json_body_preserves_numbers_and_strings(): """Test that numbers are preserved and string values are properly quoted.""" variable_pool = VariablePool( - system_variables=SystemVariable.default(), + system_variables=default_system_variables(), user_inputs={}, ) variable_pool.add(["node", "count"], 42) diff --git a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py index 5e34bf1d94..a3cadc0681 100644 --- a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py @@ -7,12 +7,13 @@ import pytest from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.helper.ssrf_proxy import ssrf_proxy from core.tools.tool_file_manager import ToolFileManager -from dify_graph.enums import WorkflowNodeExecutionStatus -from dify_graph.file.file_manager import file_manager -from dify_graph.nodes.http_request import HTTP_REQUEST_CONFIG_FILTER_KEY, HttpRequestNode, HttpRequestNodeConfig -from dify_graph.nodes.http_request.entities import HttpRequestNodeTimeout, Response -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable +from core.workflow.node_runtime import DifyFileReferenceFactory +from core.workflow.system_variables import build_system_variables +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.file.file_manager import file_manager +from graphon.nodes.http_request import HTTP_REQUEST_CONFIG_FILTER_KEY, HttpRequestNode, HttpRequestNodeConfig +from graphon.nodes.http_request.entities import HttpRequestNodeTimeout, Response +from graphon.runtime import GraphRuntimeState, VariablePool from tests.workflow_test_utils import build_test_graph_init_params HTTP_REQUEST_CONFIG = HttpRequestNodeConfig( @@ -109,7 +110,7 @@ def _build_http_node( call_depth=0, ) graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool(system_variables=SystemVariable(user_id="user", files=[]), user_inputs={}), + variable_pool=VariablePool(system_variables=build_system_variables(user_id="user", files=[]), user_inputs={}), start_at=time.perf_counter(), ) return HttpRequestNode( @@ -121,6 +122,7 @@ def _build_http_node( http_client=ssrf_proxy, tool_file_manager_factory=ToolFileManager, file_manager=file_manager, + file_reference_factory=DifyFileReferenceFactory(graph_init_params.run_context), ) @@ -161,7 +163,7 @@ def test_run_passes_node_data_ssl_verify_to_executor(monkeypatch: pytest.MonkeyP ) ) - monkeypatch.setattr("dify_graph.nodes.http_request.node.Executor", FakeExecutor) + monkeypatch.setattr("graphon.nodes.http_request.node.Executor", FakeExecutor) result = node._run() diff --git a/api/tests/unit_tests/core/workflow/nodes/human_input/test_email_delivery_config.py b/api/tests/unit_tests/core/workflow/nodes/human_input/test_email_delivery_config.py index d52dfa2a65..1d6a4da7c4 100644 --- a/api/tests/unit_tests/core/workflow/nodes/human_input/test_email_delivery_config.py +++ b/api/tests/unit_tests/core/workflow/nodes/human_input/test_email_delivery_config.py @@ -1,5 +1,5 @@ -from dify_graph.nodes.human_input.entities import EmailDeliveryConfig, EmailRecipients -from dify_graph.runtime import VariablePool +from core.workflow.human_input_compat import EmailDeliveryConfig, EmailRecipients +from graphon.runtime import VariablePool def test_render_body_template_replaces_variable_values(): diff --git a/api/tests/unit_tests/core/workflow/nodes/human_input/test_entities.py b/api/tests/unit_tests/core/workflow/nodes/human_input/test_entities.py index 55aa62a1c0..5f28a07606 100644 --- a/api/tests/unit_tests/core/workflow/nodes/human_input/test_entities.py +++ b/api/tests/unit_tests/core/workflow/nodes/human_input/test_entities.py @@ -8,35 +8,38 @@ from unittest.mock import MagicMock import pytest from pydantic import ValidationError -from dify_graph.entities import GraphInitParams -from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY -from dify_graph.node_events import PauseRequestedEvent -from dify_graph.node_events.node import StreamCompletedEvent -from dify_graph.nodes.human_input.entities import ( +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY +from core.repositories.human_input_repository import HumanInputFormRepository +from core.workflow.human_input_compat import ( + DeliveryMethodType, EmailDeliveryConfig, EmailDeliveryMethod, EmailRecipients, + EmailRecipientType, ExternalRecipient, - FormInput, - FormInputDefault, - HumanInputNodeData, MemberRecipient, - UserAction, WebAppDeliveryMethod, _WebAppDeliveryConfig, ) -from dify_graph.nodes.human_input.enums import ( +from core.workflow.node_runtime import DifyHumanInputNodeRuntime +from core.workflow.system_variables import build_system_variables +from graphon.entities import GraphInitParams +from graphon.node_events import PauseRequestedEvent +from graphon.node_events.node import StreamCompletedEvent +from graphon.nodes.human_input.entities import ( + FormInput, + FormInputDefault, + HumanInputNodeData, + UserAction, +) +from graphon.nodes.human_input.enums import ( ButtonStyle, - DeliveryMethodType, - EmailRecipientType, FormInputType, PlaceholderType, TimeoutUnit, ) -from dify_graph.nodes.human_input.human_input_node import HumanInputNode -from dify_graph.repositories.human_input_form_repository import HumanInputFormRepository -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable +from graphon.nodes.human_input.human_input_node import HumanInputNode +from graphon.runtime import GraphRuntimeState, VariablePool from tests.unit_tests.core.workflow.graph_engine.human_input_test_utils import InMemoryHumanInputFormRepository @@ -54,9 +57,9 @@ class TestDeliveryMethod: def test_email_delivery_method(self): """Test email delivery method creation.""" recipients = EmailRecipients( - whole_workspace=False, + include_bound_group=False, items=[ - MemberRecipient(type=EmailRecipientType.MEMBER, user_id="test-user-123"), + MemberRecipient(type=EmailRecipientType.MEMBER, reference_id="test-user-123"), ExternalRecipient(type=EmailRecipientType.EXTERNAL, email="test@example.com"), ], ) @@ -193,7 +196,7 @@ class TestHumanInputNodeData: EmailDeliveryMethod( enabled=False, # Disabled method should be fine config=EmailDeliveryConfig( - subject="Hi there", body="", recipients=EmailRecipients(whole_workspace=True) + subject="Hi there", body="", recipients=EmailRecipients(include_bound_group=True) ), ), ] @@ -212,7 +215,7 @@ class TestHumanInputNodeData: assert node_data.title == "Test Node" assert node_data.desc is None - assert node_data.delivery_methods == [] + assert node_data.model_dump().get("delivery_methods") is None assert node_data.form_content == "" assert node_data.inputs == [] assert node_data.user_actions == [] @@ -261,10 +264,10 @@ class TestRecipients: def test_member_recipient(self): """Test member recipient creation.""" - recipient = MemberRecipient(type=EmailRecipientType.MEMBER, user_id="user-123") + recipient = MemberRecipient(type=EmailRecipientType.MEMBER, reference_id="user-123") assert recipient.type == EmailRecipientType.MEMBER - assert recipient.user_id == "user-123" + assert recipient.reference_id == "user-123" def test_external_recipient(self): """Test external recipient creation.""" @@ -273,37 +276,46 @@ class TestRecipients: assert recipient.type == EmailRecipientType.EXTERNAL assert recipient.email == "test@example.com" - def test_email_recipients_whole_workspace(self): - """Test email recipients with whole workspace enabled.""" + def test_email_recipients_bound_group(self): + """Test email recipients with the bound group enabled.""" recipients = EmailRecipients( - whole_workspace=True, items=[MemberRecipient(type=EmailRecipientType.MEMBER, user_id="user-123")] + include_bound_group=True, + items=[MemberRecipient(type=EmailRecipientType.MEMBER, reference_id="user-123")], ) - assert recipients.whole_workspace is True - assert len(recipients.items) == 1 # Items are preserved even when whole_workspace is True + assert recipients.include_bound_group is True + assert len(recipients.items) == 1 # Items are preserved even when include_bound_group is True def test_email_recipients_specific_users(self): """Test email recipients with specific users.""" recipients = EmailRecipients( - whole_workspace=False, + include_bound_group=False, items=[ - MemberRecipient(type=EmailRecipientType.MEMBER, user_id="user-123"), + MemberRecipient(type=EmailRecipientType.MEMBER, reference_id="user-123"), ExternalRecipient(type=EmailRecipientType.EXTERNAL, email="external@example.com"), ], ) - assert recipients.whole_workspace is False + assert recipients.include_bound_group is False assert len(recipients.items) == 2 - assert recipients.items[0].user_id == "user-123" + assert recipients.items[0].reference_id == "user-123" assert recipients.items[1].email == "external@example.com" + def test_legacy_recipient_keys_are_rejected(self): + with pytest.raises(ValidationError): + MemberRecipient(type=EmailRecipientType.MEMBER, user_id="user-123") + + recipients = EmailRecipients(whole_workspace=True, items=[]) + assert recipients.include_bound_group is True + assert recipients.items == [] + class TestHumanInputNodeVariableResolution: """Tests for resolving variable-based defaults in HumanInputNode.""" def test_resolves_variable_defaults(self): variable_pool = VariablePool( - system_variables=SystemVariable( + system_variables=build_system_variables( user_id="user", app_id="app", workflow_id="workflow", @@ -353,17 +365,19 @@ class TestHumanInputNodeVariableResolution: mock_repo.create_form.return_value = SimpleNamespace( id="form-1", rendered_content="Provide your name", - web_app_token="token", + submission_token="token", recipients=[], submitted=False, ) + runtime = DifyHumanInputNodeRuntime(graph_init_params.run_context) + runtime._build_form_repository = MagicMock(return_value=mock_repo) # type: ignore[attr-defined] node = HumanInputNode( id=config["id"], config=config, graph_init_params=graph_init_params, graph_runtime_state=runtime_state, - form_repository=mock_repo, + runtime=runtime, ) run_result = node._run() @@ -378,7 +392,7 @@ class TestHumanInputNodeVariableResolution: def test_debugger_falls_back_to_recipient_token_when_webapp_disabled(self): variable_pool = VariablePool( - system_variables=SystemVariable( + system_variables=build_system_variables( user_id="user", app_id="app", workflow_id="workflow", @@ -416,28 +430,96 @@ class TestHumanInputNodeVariableResolution: mock_repo.create_form.return_value = SimpleNamespace( id="form-2", rendered_content="Provide your name", - web_app_token="console-token", + submission_token="console-token", recipients=[SimpleNamespace(token="recipient-token")], submitted=False, ) + runtime = DifyHumanInputNodeRuntime(graph_init_params.run_context) + runtime._build_form_repository = MagicMock(return_value=mock_repo) # type: ignore[attr-defined] node = HumanInputNode( id=config["id"], config=config, graph_init_params=graph_init_params, graph_runtime_state=runtime_state, - form_repository=mock_repo, + runtime=runtime, ) run_result = node._run() pause_event = next(run_result) assert isinstance(pause_event, PauseRequestedEvent) - assert pause_event.reason.form_token == "console-token" + assert not hasattr(pause_event.reason, "form_token") + + def test_webapp_runtime_keeps_form_visible_in_ui_when_webapp_delivery_is_enabled(self): + variable_pool = VariablePool( + system_variables=build_system_variables( + user_id="user", + app_id="app", + workflow_id="workflow", + workflow_execution_id="exec-4", + ), + user_inputs={}, + conversation_variables=[], + ) + runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=0.0) + graph_init_params = GraphInitParams( + workflow_id="workflow", + graph_config={"nodes": [], "edges": []}, + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "tenant", + "app_id": "app", + "user_id": "end-user-1", + "user_from": "end-user", + "invoke_from": "web-app", + } + }, + call_depth=0, + ) + + config = { + "id": "human", + "data": { + "type": "human-input", + "title": "Human Input", + "form_content": "Provide your name", + "inputs": [], + "user_actions": [{"id": "submit", "title": "Submit"}], + "delivery_methods": [{"enabled": True, "type": "webapp", "config": {}}], + }, + } + + mock_repo = MagicMock(spec=HumanInputFormRepository) + mock_repo.get_form.return_value = None + mock_repo.create_form.return_value = SimpleNamespace( + id="form-4", + rendered_content="Provide your name", + submission_token="token", + recipients=[], + submitted=False, + ) + + runtime = DifyHumanInputNodeRuntime(graph_init_params.run_context) + runtime._build_form_repository = MagicMock(return_value=mock_repo) # type: ignore[attr-defined] + node = HumanInputNode( + id=config["id"], + config=config, + graph_init_params=graph_init_params, + graph_runtime_state=runtime_state, + runtime=runtime, + ) + + run_result = node._run() + pause_event = next(run_result) + + assert isinstance(pause_event, PauseRequestedEvent) + params = mock_repo.create_form.call_args.args[0] + assert params.display_in_ui is True def test_debugger_debug_mode_overrides_email_recipients(self): variable_pool = VariablePool( - system_variables=SystemVariable( + system_variables=build_system_variables( user_id="user-123", app_id="app", workflow_id="workflow", @@ -472,7 +554,7 @@ class TestHumanInputNodeVariableResolution: enabled=True, config=EmailDeliveryConfig( recipients=EmailRecipients( - whole_workspace=False, + include_bound_group=False, items=[ExternalRecipient(type=EmailRecipientType.EXTERNAL, email="target@example.com")], ), subject="Subject", @@ -489,17 +571,19 @@ class TestHumanInputNodeVariableResolution: mock_repo.create_form.return_value = SimpleNamespace( id="form-3", rendered_content="Provide your name", - web_app_token="token", + submission_token="token", recipients=[], submitted=False, ) + runtime = DifyHumanInputNodeRuntime(graph_init_params.run_context) + runtime._build_form_repository = MagicMock(return_value=mock_repo) # type: ignore[attr-defined] node = HumanInputNode( id=config["id"], config=config, graph_init_params=graph_init_params, graph_runtime_state=runtime_state, - form_repository=mock_repo, + runtime=runtime, ) run_result = node._run() @@ -511,11 +595,11 @@ class TestHumanInputNodeVariableResolution: method = params.delivery_methods[0] assert isinstance(method, EmailDeliveryMethod) assert method.config.debug_mode is True - assert method.config.recipients.whole_workspace is False + assert method.config.recipients.include_bound_group is False assert len(method.config.recipients.items) == 1 recipient = method.config.recipients.items[0] assert isinstance(recipient, MemberRecipient) - assert recipient.user_id == "user-123" + assert recipient.reference_id == "user-123" class TestValidation: @@ -552,7 +636,7 @@ class TestHumanInputNodeRenderedContent: def test_replaces_outputs_placeholders_after_submission(self): variable_pool = VariablePool( - system_variables=SystemVariable( + system_variables=build_system_variables( user_id="user", app_id="app", workflow_id="workflow", @@ -591,12 +675,14 @@ class TestHumanInputNodeRenderedContent: config = {"id": "human", "data": node_data.model_dump()} form_repository = InMemoryHumanInputFormRepository() + runtime = DifyHumanInputNodeRuntime(graph_init_params.run_context) + runtime._build_form_repository = MagicMock(return_value=form_repository) # type: ignore[attr-defined] node = HumanInputNode( id=config["id"], config=config, graph_init_params=graph_init_params, graph_runtime_state=runtime_state, - form_repository=form_repository, + runtime=runtime, ) pause_gen = node._run() diff --git a/api/tests/unit_tests/core/workflow/nodes/human_input/test_human_input_form_filled_event.py b/api/tests/unit_tests/core/workflow/nodes/human_input/test_human_input_form_filled_event.py index b0ed47158d..fc4497f010 100644 --- a/api/tests/unit_tests/core/workflow/nodes/human_input/test_human_input_form_filled_event.py +++ b/api/tests/unit_tests/core/workflow/nodes/human_input/test_human_input_form_filled_event.py @@ -1,18 +1,19 @@ import datetime from types import SimpleNamespace -from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom -from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY, GraphInitParams -from dify_graph.enums import BuiltinNodeTypes -from dify_graph.graph_events import ( +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, InvokeFrom, UserFrom +from core.workflow.node_runtime import DifyHumanInputNodeRuntime +from core.workflow.system_variables import default_system_variables +from graphon.entities.graph_init_params import GraphInitParams +from graphon.enums import BuiltinNodeTypes +from graphon.graph_events import ( NodeRunHumanInputFormFilledEvent, NodeRunHumanInputFormTimeoutEvent, NodeRunStartedEvent, ) -from dify_graph.nodes.human_input.enums import HumanInputFormStatus -from dify_graph.nodes.human_input.human_input_node import HumanInputNode -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable +from graphon.nodes.human_input.enums import HumanInputFormStatus +from graphon.nodes.human_input.human_input_node import HumanInputNode +from graphon.runtime import GraphRuntimeState, VariablePool from libs.datetime_utils import naive_utc_now @@ -25,7 +26,7 @@ class _FakeFormRepository: def _build_node(form_content: str = "Please enter your name:\n\n{{#$output.name#}}") -> HumanInputNode: - system_variables = SystemVariable.default() + system_variables = default_system_variables() graph_runtime_state = GraphRuntimeState( variable_pool=VariablePool(system_variables=system_variables, user_inputs={}, environment_variables=[]), start_at=0.0, @@ -85,11 +86,12 @@ def _build_node(form_content: str = "Please enter your name:\n\n{{#$output.name# graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, form_repository=repo, + runtime=DifyHumanInputNodeRuntime(graph_init_params.run_context), ) def _build_timeout_node() -> HumanInputNode: - system_variables = SystemVariable.default() + system_variables = default_system_variables() graph_runtime_state = GraphRuntimeState( variable_pool=VariablePool(system_variables=system_variables, user_inputs={}, environment_variables=[]), start_at=0.0, @@ -149,6 +151,7 @@ def _build_timeout_node() -> HumanInputNode: graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, form_repository=repo, + runtime=DifyHumanInputNodeRuntime(graph_init_params.run_context), ) diff --git a/api/tests/unit_tests/core/workflow/nodes/iteration/entities_spec.py b/api/tests/unit_tests/core/workflow/nodes/iteration/entities_spec.py index 93c199514e..8cc91bdb54 100644 --- a/api/tests/unit_tests/core/workflow/nodes/iteration/entities_spec.py +++ b/api/tests/unit_tests/core/workflow/nodes/iteration/entities_spec.py @@ -1,4 +1,4 @@ -from dify_graph.nodes.iteration.entities import ( +from graphon.nodes.iteration.entities import ( ErrorHandleMode, IterationNodeData, IterationStartNodeData, diff --git a/api/tests/unit_tests/core/workflow/nodes/iteration/iteration_node_spec.py b/api/tests/unit_tests/core/workflow/nodes/iteration/iteration_node_spec.py index fdf5f4d1f8..58b82aa893 100644 --- a/api/tests/unit_tests/core/workflow/nodes/iteration/iteration_node_spec.py +++ b/api/tests/unit_tests/core/workflow/nodes/iteration/iteration_node_spec.py @@ -1,7 +1,7 @@ -from dify_graph.entities.graph_config import NodeConfigDictAdapter -from dify_graph.enums import BuiltinNodeTypes -from dify_graph.nodes.iteration.entities import ErrorHandleMode, IterationNodeData -from dify_graph.nodes.iteration.exc import ( +from graphon.entities.graph_config import NodeConfigDictAdapter +from graphon.enums import BuiltinNodeTypes +from graphon.nodes.iteration.entities import ErrorHandleMode, IterationNodeData +from graphon.nodes.iteration.exc import ( InvalidIteratorValueError, IterationGraphNotFoundError, IterationIndexNotFoundError, @@ -9,7 +9,7 @@ from dify_graph.nodes.iteration.exc import ( IteratorVariableNotFoundError, StartNodeIdNotFoundError, ) -from dify_graph.nodes.iteration.iteration_node import IterationNode +from graphon.nodes.iteration.iteration_node import IterationNode class TestIterationNodeExceptions: diff --git a/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration_abort_propagation.py b/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration_abort_propagation.py new file mode 100644 index 0000000000..4c3ad85fcd --- /dev/null +++ b/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration_abort_propagation.py @@ -0,0 +1,201 @@ +from threading import Event +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest + +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.graph_events import GraphRunAbortedEvent +from graphon.model_runtime.entities.llm_entities import LLMUsage +from graphon.node_events import IterationFailedEvent, IterationStartedEvent, StreamCompletedEvent +from graphon.nodes.iteration.entities import ErrorHandleMode, IterationNodeData +from graphon.nodes.iteration.exc import ChildGraphAbortedError +from graphon.nodes.iteration.iteration_node import IterationNode +from tests.workflow_test_utils import build_test_variable_pool + + +def _usage_with_tokens(total_tokens: int) -> LLMUsage: + usage = LLMUsage.empty_usage() + usage.total_tokens = total_tokens + return usage + + +class _AbortOnRequestGraphEngine: + def __init__(self, *, index: int, total_tokens: int) -> None: + variable_pool = build_test_variable_pool() + variable_pool.add(["iteration-node", "index"], index) + + self.started = Event() + self.abort_requested = Event() + self.finished = Event() + self.abort_reason: str | None = None + self.graph_runtime_state = SimpleNamespace( + variable_pool=variable_pool, + llm_usage=_usage_with_tokens(total_tokens), + ) + + def request_abort(self, reason: str | None = None) -> None: + self.abort_reason = reason + self.abort_requested.set() + + def run(self): + self.started.set() + assert self.abort_requested.wait(1), "parallel sibling never received an abort request" + self.finished.set() + yield GraphRunAbortedEvent(reason=self.abort_reason) + + +def _build_immediate_abort_graph_engine( + *, + index: int, + total_tokens: int, + wait_before_abort: Event | None = None, +) -> SimpleNamespace: + variable_pool = build_test_variable_pool() + variable_pool.add(["iteration-node", "index"], index) + + started = Event() + finished = Event() + + def run(): + started.set() + if wait_before_abort is not None: + assert wait_before_abort.wait(1), "parallel sibling never started" + finished.set() + yield GraphRunAbortedEvent(reason="quota exceeded") + + return SimpleNamespace( + graph_runtime_state=SimpleNamespace( + variable_pool=variable_pool, + llm_usage=_usage_with_tokens(total_tokens), + ), + run=run, + request_abort=lambda reason=None: None, + started=started, + finished=finished, + ) + + +def _build_iteration_node( + *, + error_handle_mode: ErrorHandleMode = ErrorHandleMode.TERMINATED, + is_parallel: bool = False, +) -> IterationNode: + node = IterationNode.__new__(IterationNode) + node._node_id = "iteration-node" + node._node_data = IterationNodeData( + title="Iteration", + iterator_selector=["start", "items"], + output_selector=["iteration-node", "output"], + start_node_id="child-start", + is_parallel=is_parallel, + parallel_nums=2, + error_handle_mode=error_handle_mode, + ) + + variable_pool = build_test_variable_pool() + variable_pool.add(["start", "items"], ["first", "second"]) + node.graph_runtime_state = SimpleNamespace( + variable_pool=variable_pool, + llm_usage=LLMUsage.empty_usage(), + ) + return node + + +def test_run_single_iter_raises_child_graph_aborted_error_on_abort_event() -> None: + node = _build_iteration_node() + variable_pool = build_test_variable_pool() + variable_pool.add(["iteration-node", "index"], 0) + graph_engine = SimpleNamespace( + run=lambda: iter([GraphRunAbortedEvent(reason="quota exceeded")]), + ) + + with pytest.raises(ChildGraphAbortedError, match="quota exceeded"): + list( + node._run_single_iter( + variable_pool=variable_pool, + outputs=[], + graph_engine=graph_engine, + ) + ) + + +def test_iteration_run_fails_on_sequential_child_abort() -> None: + node = _build_iteration_node(error_handle_mode=ErrorHandleMode.CONTINUE_ON_ERROR) + graph_engine = SimpleNamespace( + graph_runtime_state=SimpleNamespace( + variable_pool=build_test_variable_pool(), + llm_usage=LLMUsage.empty_usage(), + ) + ) + node._create_graph_engine = MagicMock(return_value=graph_engine) + node._run_single_iter = MagicMock(side_effect=ChildGraphAbortedError("quota exceeded")) + + events = list(node._run()) + + assert isinstance(events[0], IterationStartedEvent) + assert isinstance(events[-2], IterationFailedEvent) + assert events[-2].error == "quota exceeded" + assert isinstance(events[-1], StreamCompletedEvent) + assert events[-1].node_run_result.status == WorkflowNodeExecutionStatus.FAILED + assert events[-1].node_run_result.error == "quota exceeded" + node._create_graph_engine.assert_called_once() + node._run_single_iter.assert_called_once() + + +def test_iteration_run_merges_child_usage_before_failing_on_sequential_child_abort() -> None: + node = _build_iteration_node(error_handle_mode=ErrorHandleMode.CONTINUE_ON_ERROR) + graph_engine = SimpleNamespace( + graph_runtime_state=SimpleNamespace( + variable_pool=build_test_variable_pool(), + llm_usage=_usage_with_tokens(7), + ) + ) + node._create_graph_engine = MagicMock(return_value=graph_engine) + node._run_single_iter = MagicMock(side_effect=ChildGraphAbortedError("quota exceeded")) + + events = list(node._run()) + + assert isinstance(events[-1], StreamCompletedEvent) + assert events[-1].node_run_result.llm_usage.total_tokens == 7 + assert node.graph_runtime_state.llm_usage.total_tokens == 7 + + +@pytest.mark.parametrize( + "error_handle_mode", + [ + ErrorHandleMode.CONTINUE_ON_ERROR, + ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT, + ], +) +def test_iteration_run_fails_on_parallel_child_abort_regardless_of_error_mode( + error_handle_mode: ErrorHandleMode, +) -> None: + node = _build_iteration_node( + error_handle_mode=error_handle_mode, + is_parallel=True, + ) + blocking_engine = _AbortOnRequestGraphEngine(index=1, total_tokens=5) + aborting_engine = _build_immediate_abort_graph_engine( + index=0, + total_tokens=3, + wait_before_abort=blocking_engine.started, + ) + node._create_graph_engine = MagicMock( + side_effect=lambda index, item: {0: aborting_engine, 1: blocking_engine}[index] + ) + + events = list(node._run()) + + assert isinstance(events[0], IterationStartedEvent) + assert isinstance(events[-2], IterationFailedEvent) + assert events[-2].error == "quota exceeded" + assert isinstance(events[-1], StreamCompletedEvent) + assert events[-1].node_run_result.status == WorkflowNodeExecutionStatus.FAILED + assert events[-1].node_run_result.error == "quota exceeded" + assert events[-1].node_run_result.llm_usage.total_tokens == 8 + assert node.graph_runtime_state.llm_usage.total_tokens == 8 + assert blocking_engine.started.is_set() + assert blocking_engine.abort_requested.is_set() + assert blocking_engine.finished.is_set() + assert blocking_engine.abort_reason == "quota exceeded" diff --git a/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration_child_engine_errors.py b/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration_child_engine_errors.py index 2eb4feef5f..82cc734274 100644 --- a/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration_child_engine_errors.py +++ b/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration_child_engine_errors.py @@ -1,18 +1,18 @@ -from collections.abc import Mapping, Sequence +from collections.abc import Mapping from typing import Any import pytest -from dify_graph.entities import GraphInitParams -from dify_graph.nodes.iteration.exc import IterationGraphNotFoundError -from dify_graph.nodes.iteration.iteration_node import IterationNode -from dify_graph.runtime import ( +from core.workflow.system_variables import default_system_variables +from graphon.entities import GraphInitParams +from graphon.nodes.iteration.exc import IterationGraphNotFoundError +from graphon.nodes.iteration.iteration_node import IterationNode +from graphon.runtime import ( ChildEngineBuilderNotConfiguredError, ChildGraphNotFoundError, GraphRuntimeState, VariablePool, ) -from dify_graph.system_variable import SystemVariable from tests.workflow_test_utils import build_test_graph_init_params @@ -22,17 +22,16 @@ class _MissingGraphBuilder: *, workflow_id: str, graph_init_params: GraphInitParams, - graph_runtime_state: GraphRuntimeState, - graph_config: Mapping[str, Any], + parent_graph_runtime_state: GraphRuntimeState, root_node_id: str, - layers: Sequence[object] = (), + variable_pool: VariablePool | None = None, ) -> object: raise ChildGraphNotFoundError(f"child graph root node '{root_node_id}' not found") def _build_runtime_state() -> GraphRuntimeState: return GraphRuntimeState( - variable_pool=VariablePool(system_variables=SystemVariable.default(), user_inputs={}), + variable_pool=VariablePool(system_variables=default_system_variables(), user_inputs={}), start_at=0.0, ) @@ -69,8 +68,6 @@ def test_graph_runtime_state_raises_specific_error_when_child_builder_is_missing runtime_state.create_child_engine( workflow_id="workflow", graph_init_params=graph_init_params, - graph_runtime_state=_build_runtime_state(), - graph_config={}, root_node_id="root", ) diff --git a/api/tests/unit_tests/core/workflow/nodes/iteration/test_parallel_iteration_duration.py b/api/tests/unit_tests/core/workflow/nodes/iteration/test_parallel_iteration_duration.py index 8660449032..41d7c3193d 100644 --- a/api/tests/unit_tests/core/workflow/nodes/iteration/test_parallel_iteration_duration.py +++ b/api/tests/unit_tests/core/workflow/nodes/iteration/test_parallel_iteration_duration.py @@ -1,14 +1,13 @@ import time -from contextlib import nullcontext from datetime import UTC, datetime import pytest -from dify_graph.enums import BuiltinNodeTypes -from dify_graph.graph_events import NodeRunSucceededEvent -from dify_graph.model_runtime.entities.llm_entities import LLMUsage -from dify_graph.nodes.iteration.entities import ErrorHandleMode, IterationNodeData -from dify_graph.nodes.iteration.iteration_node import IterationNode +from graphon.enums import BuiltinNodeTypes +from graphon.graph_events import NodeRunSucceededEvent +from graphon.model_runtime.entities.llm_entities import LLMUsage +from graphon.nodes.iteration.entities import ErrorHandleMode, IterationNodeData +from graphon.nodes.iteration.iteration_node import IterationNode def test_parallel_iteration_duration_map_uses_worker_measured_time() -> None: @@ -21,11 +20,17 @@ def test_parallel_iteration_duration_map_uses_worker_measured_time() -> None: parallel_nums=2, error_handle_mode=ErrorHandleMode.TERMINATED, ) - node._capture_execution_context = lambda: nullcontext() - node._sync_conversation_variables_from_snapshot = lambda snapshot: None node._merge_usage = lambda current, new: new if current.total_tokens == 0 else current.plus(new) - def fake_execute_single_iteration_parallel(*, index: int, item: object, execution_context: object): + def fake_execute_tracked_iteration_parallel( + *, + index: int, + item: object, + started_child_engines: dict[int, object], + started_child_engines_lock: object, + ): + _ = started_child_engines + _ = started_child_engines_lock return ( 0.1 + (index * 0.1), [ @@ -37,11 +42,10 @@ def test_parallel_iteration_duration_map_uses_worker_measured_time() -> None: ), ], f"output-{item}", - {}, LLMUsage.empty_usage(), ) - node._execute_single_iteration_parallel = fake_execute_single_iteration_parallel + node._execute_tracked_iteration_parallel = fake_execute_tracked_iteration_parallel outputs: list[object] = [] iter_run_map: dict[str, float] = {} diff --git a/api/tests/unit_tests/core/workflow/nodes/knowledge_index/test_knowledge_index_node.py b/api/tests/unit_tests/core/workflow/nodes/knowledge_index/test_knowledge_index_node.py index 33f7ace5ab..a6fca1bfb4 100644 --- a/api/tests/unit_tests/core/workflow/nodes/knowledge_index/test_knowledge_index_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/knowledge_index/test_knowledge_index_node.py @@ -5,6 +5,7 @@ from unittest.mock import Mock import pytest from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom +from core.rag.index_processor.constant.index_type import IndexTechniqueType from core.workflow.nodes.knowledge_index.entities import KnowledgeIndexNodeData from core.workflow.nodes.knowledge_index.exc import KnowledgeIndexNodeError from core.workflow.nodes.knowledge_index.knowledge_index_node import KnowledgeIndexNode @@ -14,10 +15,10 @@ from core.workflow.nodes.knowledge_index.protocols import ( PreviewItem, SummaryIndexServiceProtocol, ) -from dify_graph.enums import SystemVariableKey, WorkflowNodeExecutionStatus -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable -from dify_graph.variables.segments import StringSegment +from core.workflow.system_variables import SystemVariableKey, build_system_variables +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.runtime import GraphRuntimeState, VariablePool +from graphon.variables.segments import StringSegment from tests.workflow_test_utils import build_test_graph_init_params @@ -40,7 +41,7 @@ def mock_graph_init_params(): def mock_graph_runtime_state(): """Create mock GraphRuntimeState.""" variable_pool = VariablePool( - system_variables=SystemVariable(user_id=str(uuid.uuid4()), files=[]), + system_variables=build_system_variables(user_id=str(uuid.uuid4()), files=[]), user_inputs={}, environment_variables=[], conversation_variables=[], @@ -78,7 +79,7 @@ def sample_node_data(): type="knowledge-index", chunk_structure="general_structure", index_chunk_variable_selector=["start", "chunks"], - indexing_technique="high_quality", + indexing_technique=IndexTechniqueType.HIGH_QUALITY, summary_index_setting=None, ) diff --git a/api/tests/unit_tests/core/workflow/nodes/knowledge_retrieval/test_knowledge_retrieval_node.py b/api/tests/unit_tests/core/workflow/nodes/knowledge_retrieval/test_knowledge_retrieval_node.py index 99997db6b2..45e8ae7d20 100644 --- a/api/tests/unit_tests/core/workflow/nodes/knowledge_retrieval/test_knowledge_retrieval_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/knowledge_retrieval/test_knowledge_retrieval_node.py @@ -16,11 +16,11 @@ from core.workflow.nodes.knowledge_retrieval.entities import ( from core.workflow.nodes.knowledge_retrieval.exc import RateLimitExceededError from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode from core.workflow.nodes.knowledge_retrieval.retrieval import RAGRetrievalProtocol, Source -from dify_graph.enums import WorkflowNodeExecutionStatus -from dify_graph.model_runtime.entities.llm_entities import LLMUsage -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable -from dify_graph.variables import StringSegment +from core.workflow.system_variables import build_system_variables +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.model_runtime.entities.llm_entities import LLMUsage +from graphon.runtime import GraphRuntimeState, VariablePool +from graphon.variables import StringSegment from tests.workflow_test_utils import build_test_graph_init_params @@ -43,7 +43,7 @@ def mock_graph_init_params(): def mock_graph_runtime_state(): """Create mock GraphRuntimeState.""" variable_pool = VariablePool( - system_variables=SystemVariable(user_id=str(uuid.uuid4()), files=[]), + system_variables=build_system_variables(user_id=str(uuid.uuid4()), files=[]), user_inputs={}, environment_variables=[], conversation_variables=[], @@ -157,7 +157,7 @@ class TestKnowledgeRetrievalNode: ): """Test _run with query variable in single mode.""" # Arrange - from dify_graph.nodes.llm.entities import ModelConfig + from graphon.nodes.llm.entities import ModelConfig query = "What is Python?" query_selector = ["start", "query"] @@ -441,7 +441,7 @@ class TestFetchDatasetRetriever: ): """Test _fetch_dataset_retriever in single mode.""" # Arrange - from dify_graph.nodes.llm.entities import ModelConfig + from graphon.nodes.llm.entities import ModelConfig query = "What is Python?" variables = {"query": query} diff --git a/api/tests/unit_tests/core/workflow/nodes/list_operator/node_spec.py b/api/tests/unit_tests/core/workflow/nodes/list_operator/node_spec.py index d71e0921c1..eca34f05be 100644 --- a/api/tests/unit_tests/core/workflow/nodes/list_operator/node_spec.py +++ b/api/tests/unit_tests/core/workflow/nodes/list_operator/node_spec.py @@ -2,12 +2,12 @@ from unittest.mock import MagicMock import pytest -from dify_graph.entities import GraphInitParams -from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY -from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus -from dify_graph.nodes.list_operator.node import ListOperatorNode -from dify_graph.runtime import GraphRuntimeState -from dify_graph.variables import ArrayNumberSegment, ArrayStringSegment +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY +from graphon.entities import GraphInitParams +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus +from graphon.nodes.list_operator.node import ListOperatorNode +from graphon.runtime import GraphRuntimeState +from graphon.variables import ArrayNumberSegment, ArrayStringSegment class TestListOperatorNode: diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_file_saver.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_file_saver.py index b0f0fd428b..4f9ba0194a 100644 --- a/api/tests/unit_tests/core/workflow/nodes/llm/test_file_saver.py +++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_file_saver.py @@ -6,17 +6,14 @@ from unittest.mock import MagicMock import httpx import pytest -from core.helper import ssrf_proxy -from core.tools import signature -from core.tools.tool_file_manager import ToolFileManager -from dify_graph.file import FileTransferMethod, FileType, models -from dify_graph.nodes.llm.file_saver import ( +from graphon.file import FileTransferMethod, FileType +from graphon.nodes.llm.file_saver import ( FileSaverImpl, _extract_content_type_and_extension, _get_extension, _validate_extension_override, ) -from models import ToolFile +from graphon.nodes.protocols import ToolFileManagerProtocol _PNG_DATA = b"\x89PNG\r\n\x1a\n" @@ -27,58 +24,45 @@ def _gen_id(): class TestFileSaverImpl: def test_save_binary_string(self, monkeypatch: pytest.MonkeyPatch): - user_id = _gen_id() - tenant_id = _gen_id() file_type = FileType.IMAGE mime_type = "image/png" - mock_signed_url = "https://example.com/image.png" - mock_tool_file = ToolFile( - user_id=user_id, - tenant_id=tenant_id, - conversation_id=None, - file_key="test-file-key", - mimetype=mime_type, - original_url=None, - name=f"{_gen_id()}.png", - size=len(_PNG_DATA), - ) + mock_tool_file = MagicMock() mock_tool_file.id = _gen_id() - mocked_tool_file_manager = mock.MagicMock(spec=ToolFileManager) - + mock_tool_file.name = f"{_gen_id()}.png" + mock_tool_file.file_key = "test-file-key" + mocked_tool_file_manager = mock.MagicMock(spec=ToolFileManagerProtocol) mocked_tool_file_manager.create_file_by_raw.return_value = mock_tool_file - monkeypatch.setattr(FileSaverImpl, "_get_tool_file_manager", lambda _: mocked_tool_file_manager) - # Since `File.generate_url` used `ToolFileManager.sign_file` directly, we also need to patch it here. - mocked_sign_file = mock.MagicMock(spec=signature.sign_tool_file) - # Since `File.generate_url` used `signature.sign_tool_file` directly, we also need to patch it here. - monkeypatch.setattr(models, "sign_tool_file", mocked_sign_file) - mocked_sign_file.return_value = mock_signed_url + file_reference = MagicMock() + file_reference_factory = MagicMock() + file_reference_factory.build_from_mapping.return_value = file_reference http_client = MagicMock() - storage_file_manager = FileSaverImpl( - user_id=user_id, - tenant_id=tenant_id, + file_saver = FileSaverImpl( + tool_file_manager=mocked_tool_file_manager, + file_reference_factory=file_reference_factory, http_client=http_client, ) - file = storage_file_manager.save_binary_string(_PNG_DATA, mime_type, file_type) - assert file.tenant_id == tenant_id - assert file.type == file_type - assert file.transfer_method == FileTransferMethod.TOOL_FILE - assert file.extension == ".png" - assert file.mime_type == mime_type - assert file.size == len(_PNG_DATA) - assert file.related_id == mock_tool_file.id - - assert file.generate_url() == mock_signed_url + file = file_saver.save_binary_string(_PNG_DATA, mime_type, file_type) + assert file is file_reference mocked_tool_file_manager.create_file_by_raw.assert_called_once_with( - user_id=user_id, - tenant_id=tenant_id, - conversation_id=None, file_binary=_PNG_DATA, mimetype=mime_type, ) - mocked_sign_file.assert_called_once_with(tool_file_id=mock_tool_file.id, extension=".png", for_external=True) + file_reference_factory.build_from_mapping.assert_called_once_with( + mapping={ + "type": file_type, + "transfer_method": FileTransferMethod.TOOL_FILE, + "filename": mock_tool_file.name, + "extension": ".png", + "mime_type": mime_type, + "size": len(_PNG_DATA), + "tool_file_id": mock_tool_file.id, + "related_id": mock_tool_file.id, + "storage_key": mock_tool_file.file_key, + } + ) def test_save_remote_url_request_failed(self, monkeypatch: pytest.MonkeyPatch): _TEST_URL = "https://example.com/image.png" @@ -91,8 +75,8 @@ class TestFileSaverImpl: http_client.get.return_value = mock_response file_saver = FileSaverImpl( - user_id=_gen_id(), - tenant_id=_gen_id(), + tool_file_manager=MagicMock(), + file_reference_factory=MagicMock(), http_client=http_client, ) @@ -104,8 +88,6 @@ class TestFileSaverImpl: def test_save_remote_url_success(self, monkeypatch: pytest.MonkeyPatch): _TEST_URL = "https://example.com/image.png" mime_type = "image/png" - user_id = _gen_id() - tenant_id = _gen_id() mock_request = httpx.Request("GET", _TEST_URL) mock_response = httpx.Response( @@ -117,21 +99,13 @@ class TestFileSaverImpl: http_client = MagicMock() http_client.get.return_value = mock_response - file_saver = FileSaverImpl(user_id=user_id, tenant_id=tenant_id, http_client=http_client) - mock_tool_file = ToolFile( - user_id=user_id, - tenant_id=tenant_id, - conversation_id=None, - file_key="test-file-key", - mimetype=mime_type, - original_url=None, - name=f"{_gen_id()}.png", - size=len(_PNG_DATA), + file_saver = FileSaverImpl( + tool_file_manager=MagicMock(), + file_reference_factory=MagicMock(), + http_client=http_client, ) - mock_tool_file.id = _gen_id() - mock_get = mock.MagicMock(spec=ssrf_proxy.get, return_value=mock_response) - monkeypatch.setattr(ssrf_proxy, "get", mock_get) - mock_save_binary_string = mock.MagicMock(spec=file_saver.save_binary_string, return_value=mock_tool_file) + expected_file = MagicMock() + mock_save_binary_string = mock.MagicMock(spec=file_saver.save_binary_string, return_value=expected_file) monkeypatch.setattr(file_saver, "save_binary_string", mock_save_binary_string) file = file_saver.save_remote_url(_TEST_URL, FileType.IMAGE) @@ -141,7 +115,7 @@ class TestFileSaverImpl: FileType.IMAGE, extension_override=".png", ) - assert file == mock_tool_file + assert file is expected_file def test_validate_extension_override(): diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_llm_utils.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_llm_utils.py index 618a498659..dfc982f49c 100644 --- a/api/tests/unit_tests/core/workflow/nodes/llm/test_llm_utils.py +++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_llm_utils.py @@ -3,12 +3,94 @@ from unittest import mock import pytest from core.model_manager import ModelInstance -from dify_graph.model_runtime.entities import ImagePromptMessageContent, PromptMessageRole, TextPromptMessageContent -from dify_graph.model_runtime.entities.message_entities import SystemPromptMessage -from dify_graph.nodes.llm import llm_utils -from dify_graph.nodes.llm.entities import LLMNodeChatModelMessage -from dify_graph.nodes.llm.exc import NoPromptFoundError -from dify_graph.runtime import VariablePool +from graphon.file import FileTransferMethod, FileType +from graphon.file.models import File +from graphon.model_runtime.entities import ( + ImagePromptMessageContent, + PromptMessageRole, + TextPromptMessageContent, +) +from graphon.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + SystemPromptMessage, + UserPromptMessage, +) +from graphon.model_runtime.entities.model_entities import ( + AIModelEntity, + FetchFrom, + ModelFeature, + ModelPropertyKey, + ModelType, + ParameterRule, + ParameterType, +) +from graphon.nodes.base.entities import VariableSelector +from graphon.nodes.llm import llm_utils +from graphon.nodes.llm.entities import LLMNodeChatModelMessage, LLMNodeCompletionModelPromptTemplate, MemoryConfig +from graphon.nodes.llm.exc import ( + InvalidVariableTypeError, + MemoryRolePrefixRequiredError, + NoPromptFoundError, + TemplateTypeNotSupportError, +) +from graphon.runtime import VariablePool +from graphon.variables import ArrayAnySegment, ArrayFileSegment, NoneSegment + + +def _build_model_schema( + *, + features: list[ModelFeature] | None = None, + model_properties: dict[ModelPropertyKey, object] | None = None, + parameter_rules: list[ParameterRule] | None = None, +) -> AIModelEntity: + return AIModelEntity( + model="gpt-3.5-turbo", + label={"en_US": "GPT-3.5 Turbo"}, + model_type=ModelType.LLM, + fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, + features=features, + model_properties=model_properties or {}, + parameter_rules=parameter_rules or [], + ) + + +def _build_model_instance(*, model_schema: AIModelEntity | None = None) -> mock.MagicMock: + model_instance = mock.MagicMock(spec=ModelInstance) + model_instance.model_name = "gpt-3.5-turbo" + model_instance.parameters = {} + model_instance.get_model_schema.return_value = model_schema or _build_model_schema(features=[]) + model_instance.get_llm_num_tokens.return_value = 0 + return model_instance + + +def _build_image_file( + *, + file_id: str, + related_id: str, + remote_url: str, + extension: str = ".png", + mime_type: str = "image/png", +) -> File: + return File( + id=file_id, + type=FileType.IMAGE, + filename=f"{file_id}{extension}", + transfer_method=FileTransferMethod.REMOTE_URL, + remote_url=remote_url, + related_id=related_id, + extension=extension, + mime_type=mime_type, + storage_key="", + ) + + +@pytest.fixture +def variable_pool() -> VariablePool: + pool = VariablePool.empty() + pool.add(["node1", "output"], "resolved_value") + pool.add(["node2", "text"], "hello world") + pool.add(["start", "user_input"], "dynamic_param") + return pool def _fetch_prompt_messages_with_mocked_content(content): @@ -24,15 +106,15 @@ def _fetch_prompt_messages_with_mocked_content(content): with ( mock.patch( - "dify_graph.nodes.llm.llm_utils.fetch_model_schema", + "graphon.nodes.llm.llm_utils.fetch_model_schema", return_value=mock.MagicMock(features=[]), ), mock.patch( - "dify_graph.nodes.llm.llm_utils.handle_list_messages", + "graphon.nodes.llm.llm_utils.handle_list_messages", return_value=[SystemPromptMessage(content=content)], ), mock.patch( - "dify_graph.nodes.llm.llm_utils.handle_memory_chat_mode", + "graphon.nodes.llm.llm_utils.handle_memory_chat_mode", return_value=[], ), ): @@ -53,6 +135,159 @@ def _fetch_prompt_messages_with_mocked_content(content): ) +class TestTypeCoercionViaResolve: + """Type coercion is tested through the public resolve_completion_params_variables API.""" + + def test_numeric_string_coerced_to_float(self): + pool = VariablePool.empty() + pool.add(["n", "v"], "0.7") + result = llm_utils.resolve_completion_params_variables({"p": "{{#n.v#}}"}, pool) + assert result["p"] == 0.7 + + def test_integer_string_coerced_to_int(self): + pool = VariablePool.empty() + pool.add(["n", "v"], "1024") + result = llm_utils.resolve_completion_params_variables({"p": "{{#n.v#}}"}, pool) + assert result["p"] == 1024 + + def test_boolean_string_coerced_to_bool(self): + pool = VariablePool.empty() + pool.add(["n", "v"], "true") + result = llm_utils.resolve_completion_params_variables({"p": "{{#n.v#}}"}, pool) + assert result["p"] is True + + def test_plain_string_stays_string(self): + pool = VariablePool.empty() + pool.add(["n", "v"], "json_object") + result = llm_utils.resolve_completion_params_variables({"p": "{{#n.v#}}"}, pool) + assert result["p"] == "json_object" + + def test_json_object_string_stays_string(self): + pool = VariablePool.empty() + pool.add(["n", "v"], '{"key": "val"}') + result = llm_utils.resolve_completion_params_variables({"p": "{{#n.v#}}"}, pool) + assert result["p"] == '{"key": "val"}' + + def test_mixed_text_and_variable_stays_string(self): + pool = VariablePool.empty() + pool.add(["n", "v"], "0.7") + result = llm_utils.resolve_completion_params_variables({"p": "val={{#n.v#}}"}, pool) + assert result["p"] == "val=0.7" + + +class TestResolveCompletionParamsVariables: + def test_plain_string_values_unchanged(self, variable_pool: VariablePool): + params = {"response_format": "json", "custom_param": "static_value"} + + result = llm_utils.resolve_completion_params_variables(params, variable_pool) + + assert result == {"response_format": "json", "custom_param": "static_value"} + + def test_numeric_values_unchanged(self, variable_pool: VariablePool): + params = {"temperature": 0.7, "top_p": 0.9, "max_tokens": 1024} + + result = llm_utils.resolve_completion_params_variables(params, variable_pool) + + assert result == {"temperature": 0.7, "top_p": 0.9, "max_tokens": 1024} + + def test_boolean_values_unchanged(self, variable_pool: VariablePool): + params = {"stream": True, "echo": False} + + result = llm_utils.resolve_completion_params_variables(params, variable_pool) + + assert result == {"stream": True, "echo": False} + + def test_list_values_unchanged(self, variable_pool: VariablePool): + params = {"stop": ["Human:", "Assistant:"]} + + result = llm_utils.resolve_completion_params_variables(params, variable_pool) + + assert result == {"stop": ["Human:", "Assistant:"]} + + def test_single_variable_reference_resolved(self, variable_pool: VariablePool): + params = {"response_format": "{{#node1.output#}}"} + + result = llm_utils.resolve_completion_params_variables(params, variable_pool) + + assert result == {"response_format": "resolved_value"} + + def test_multiple_variable_references_resolved(self, variable_pool: VariablePool): + params = { + "param_a": "{{#node1.output#}}", + "param_b": "{{#node2.text#}}", + } + + result = llm_utils.resolve_completion_params_variables(params, variable_pool) + + assert result == {"param_a": "resolved_value", "param_b": "hello world"} + + def test_mixed_text_and_variable_resolved(self, variable_pool: VariablePool): + params = {"prompt_prefix": "prefix_{{#node1.output#}}_suffix"} + + result = llm_utils.resolve_completion_params_variables(params, variable_pool) + + assert result == {"prompt_prefix": "prefix_resolved_value_suffix"} + + def test_mixed_params_types(self, variable_pool: VariablePool): + """Non-string params pass through; string params with variables get resolved.""" + params = { + "temperature": 0.7, + "response_format": "{{#node1.output#}}", + "custom_string": "no_vars_here", + "max_tokens": 512, + "stop": ["\n"], + } + + result = llm_utils.resolve_completion_params_variables(params, variable_pool) + + assert result == { + "temperature": 0.7, + "response_format": "resolved_value", + "custom_string": "no_vars_here", + "max_tokens": 512, + "stop": ["\n"], + } + + def test_empty_params(self, variable_pool: VariablePool): + result = llm_utils.resolve_completion_params_variables({}, variable_pool) + + assert result == {} + + def test_unresolvable_variable_keeps_selector_text(self): + """When a referenced variable doesn't exist in the pool, convert_template + falls back to the raw selector path (e.g. 'nonexistent.var').""" + pool = VariablePool.empty() + params = {"format": "{{#nonexistent.var#}}"} + + result = llm_utils.resolve_completion_params_variables(params, pool) + + assert result["format"] == "nonexistent.var" + + def test_multiple_variables_in_single_value(self, variable_pool: VariablePool): + params = {"combined": "{{#node1.output#}} and {{#node2.text#}}"} + + result = llm_utils.resolve_completion_params_variables(params, variable_pool) + + assert result == {"combined": "resolved_value and hello world"} + + def test_original_params_not_mutated(self, variable_pool: VariablePool): + original = {"response_format": "{{#node1.output#}}", "temperature": 0.5} + original_copy = dict(original) + + _ = llm_utils.resolve_completion_params_variables(original, variable_pool) + + assert original == original_copy + + def test_long_value_truncated(self): + pool = VariablePool.empty() + pool.add(["node1", "big"], "x" * 2000) + params = {"param": "{{#node1.big#}}"} + + result = llm_utils.resolve_completion_params_variables(params, pool) + + assert len(result["param"]) == llm_utils.MAX_RESOLVED_VALUE_LENGTH + + def test_fetch_prompt_messages_skips_messages_when_all_contents_are_filtered_out(): with pytest.raises(NoPromptFoundError): _fetch_prompt_messages_with_mocked_content( @@ -104,3 +339,700 @@ def test_fetch_prompt_messages_keeps_list_content_when_multiple_supported_items_ ] ) ] + + +def test_fetch_model_schema_raises_when_model_schema_is_missing(): + model_instance = _build_model_instance() + model_instance.get_model_schema.return_value = None + + with pytest.raises(ValueError, match="Model schema not found for gpt-3.5-turbo"): + llm_utils.fetch_model_schema(model_instance=model_instance) + + +def test_fetch_files_supports_known_segments_and_rejects_invalid_types(): + file = _build_image_file(file_id="image", related_id="image-related", remote_url="https://example.com/image.png") + variable_pool = VariablePool.empty() + variable_pool.add(["input", "file"], file) + variable_pool.add(["input", "files"], ArrayFileSegment(value=[file])) + variable_pool.add(["input", "none"], NoneSegment()) + variable_pool.add(["input", "empty"], ArrayAnySegment(value=[])) + variable_pool.add(["input", "invalid"], {"a": 1}) + + assert llm_utils.fetch_files(variable_pool, ["input", "file"]) == [file] + assert llm_utils.fetch_files(variable_pool, ["input", "files"]) == [file] + assert llm_utils.fetch_files(variable_pool, ["input", "none"]) == [] + assert llm_utils.fetch_files(variable_pool, ["input", "empty"]) == [] + + with pytest.raises(InvalidVariableTypeError, match="Invalid variable type"): + llm_utils.fetch_files(variable_pool, ["input", "invalid"]) + + +def test_fetch_files_returns_empty_for_missing_variable(): + assert llm_utils.fetch_files(VariablePool.empty(), ["input", "missing"]) == [] + + +def test_convert_history_messages_to_text_skips_system_messages_and_formats_images(): + history_text = llm_utils.convert_history_messages_to_text( + history_messages=[ + SystemPromptMessage(content="skip"), + UserPromptMessage( + content=[ + TextPromptMessageContent(data="Question"), + ImagePromptMessageContent( + format="png", + url="https://example.com/image.png", + mime_type="image/png", + ), + ] + ), + AssistantPromptMessage(content="Answer"), + ], + human_prefix="Human", + ai_prefix="Assistant", + ) + + assert history_text == "Human: Question\n[image]\nAssistant: Answer" + + +def test_fetch_memory_text_uses_prompt_memory_interface(): + memory = mock.MagicMock() + memory.get_history_prompt_messages.return_value = [UserPromptMessage(content="Question")] + + memory_text = llm_utils.fetch_memory_text( + memory=memory, + max_token_limit=321, + message_limit=2, + human_prefix="Human", + ai_prefix="Assistant", + ) + + assert memory_text == "Human: Question" + memory.get_history_prompt_messages.assert_called_once_with(max_token_limit=321, message_limit=2) + + +def test_handle_list_messages_renders_jinja2_messages(): + variable_pool = VariablePool.empty() + variable_pool.add(["input", "name"], "Dify") + renderer = mock.MagicMock() + renderer.render_template.return_value = "Hello Dify" + + prompt_messages = llm_utils.handle_list_messages( + messages=[ + LLMNodeChatModelMessage( + text="ignored", + jinja2_text="Hello {{ name }}", + role=PromptMessageRole.SYSTEM, + edition_type="jinja2", + ) + ], + context="", + jinja2_variables=[VariableSelector(variable="name", value_selector=["input", "name"])], + variable_pool=variable_pool, + vision_detail_config=ImagePromptMessageContent.DETAIL.HIGH, + template_renderer=renderer, + ) + + assert prompt_messages == [SystemPromptMessage(content=[TextPromptMessageContent(data="Hello Dify")])] + renderer.render_template.assert_called_once_with("Hello {{ name }}", {"name": "Dify"}) + + +def test_handle_list_messages_splits_text_and_file_content(): + variable_pool = VariablePool.empty() + image_file = _build_image_file( + file_id="image-file", + related_id="image-related", + remote_url="https://example.com/file.png", + ) + variable_pool.add(["input", "image"], image_file) + + with mock.patch( + "graphon.nodes.llm.llm_utils.file_manager.to_prompt_message_content", + return_value=ImagePromptMessageContent( + format="png", + url="https://example.com/file.png", + mime_type="image/png", + detail=ImagePromptMessageContent.DETAIL.HIGH, + ), + ) as mock_to_prompt: + prompt_messages = llm_utils.handle_list_messages( + messages=[ + LLMNodeChatModelMessage( + text="Analyze {{#input.image#}}", + role=PromptMessageRole.USER, + edition_type="basic", + ) + ], + context="", + jinja2_variables=[], + variable_pool=variable_pool, + vision_detail_config=ImagePromptMessageContent.DETAIL.HIGH, + ) + + assert prompt_messages == [ + UserPromptMessage(content=[TextPromptMessageContent(data="Analyze ")]), + UserPromptMessage( + content=[ + ImagePromptMessageContent( + format="png", + url="https://example.com/file.png", + mime_type="image/png", + detail=ImagePromptMessageContent.DETAIL.HIGH, + ) + ] + ), + ] + mock_to_prompt.assert_called_once() + + +def test_handle_list_messages_supports_array_file_segments(): + variable_pool = VariablePool.empty() + first_file = _build_image_file(file_id="first", related_id="first-related", remote_url="https://example.com/1.png") + second_file = _build_image_file( + file_id="second", + related_id="second-related", + remote_url="https://example.com/2.png", + ) + variable_pool.add(["input", "images"], ArrayFileSegment(value=[first_file, second_file])) + + first_prompt = ImagePromptMessageContent( + format="png", + url="https://example.com/1.png", + mime_type="image/png", + detail=ImagePromptMessageContent.DETAIL.HIGH, + ) + second_prompt = ImagePromptMessageContent( + format="png", + url="https://example.com/2.png", + mime_type="image/png", + detail=ImagePromptMessageContent.DETAIL.HIGH, + ) + + with mock.patch( + "graphon.nodes.llm.llm_utils.file_manager.to_prompt_message_content", + side_effect=[first_prompt, second_prompt], + ): + prompt_messages = llm_utils.handle_list_messages( + messages=[ + LLMNodeChatModelMessage( + text="{{#input.images#}}", + role=PromptMessageRole.USER, + edition_type="basic", + ) + ], + context="", + jinja2_variables=[], + variable_pool=variable_pool, + vision_detail_config=ImagePromptMessageContent.DETAIL.HIGH, + ) + + assert prompt_messages == [UserPromptMessage(content=[first_prompt, second_prompt])] + + +def test_render_jinja2_message_handles_empty_template_success_and_missing_renderer(): + variable_pool = VariablePool.empty() + variable_pool.add(["input", "name"], "Dify") + variables = [VariableSelector(variable="name", value_selector=["input", "name"])] + + assert ( + llm_utils.render_jinja2_message( + template="", + jinja2_variables=variables, + variable_pool=variable_pool, + template_renderer=None, + ) + == "" + ) + + with pytest.raises(ValueError, match="template_renderer is required"): + llm_utils.render_jinja2_message( + template="Hello {{ name }}", + jinja2_variables=variables, + variable_pool=variable_pool, + template_renderer=None, + ) + + renderer = mock.MagicMock() + renderer.render_template.return_value = "Hello Dify" + assert ( + llm_utils.render_jinja2_message( + template="Hello {{ name }}", + jinja2_variables=variables, + variable_pool=variable_pool, + template_renderer=renderer, + ) + == "Hello Dify" + ) + + +def test_handle_completion_template_supports_basic_and_jinja2_templates(): + variable_pool = VariablePool.empty() + variable_pool.add(["input", "name"], "Dify") + renderer = mock.MagicMock() + renderer.render_template.return_value = "Hello Dify" + + basic_messages = llm_utils.handle_completion_template( + template=LLMNodeCompletionModelPromptTemplate( + text="Summarize {{#context#}}", + edition_type="basic", + ), + context="the docs", + jinja2_variables=[], + variable_pool=variable_pool, + ) + jinja_messages = llm_utils.handle_completion_template( + template=LLMNodeCompletionModelPromptTemplate( + text="ignored", + jinja2_text="Hello {{ name }}", + edition_type="jinja2", + ), + context="", + jinja2_variables=[VariableSelector(variable="name", value_selector=["input", "name"])], + variable_pool=variable_pool, + template_renderer=renderer, + ) + + assert basic_messages == [ + UserPromptMessage(content=[TextPromptMessageContent(data="Summarize the docs")]), + ] + assert jinja_messages == [ + UserPromptMessage(content=[TextPromptMessageContent(data="Hello Dify")]), + ] + + +def test_combine_message_content_with_role_handles_all_supported_roles(): + contents = [TextPromptMessageContent(data="hello")] + + assert llm_utils.combine_message_content_with_role(contents=contents, role=PromptMessageRole.USER) == ( + UserPromptMessage(content=contents) + ) + assert llm_utils.combine_message_content_with_role(contents=contents, role=PromptMessageRole.ASSISTANT) == ( + AssistantPromptMessage(content=contents) + ) + assert llm_utils.combine_message_content_with_role(contents=contents, role=PromptMessageRole.SYSTEM) == ( + SystemPromptMessage(content=contents) + ) + + with pytest.raises(NotImplementedError, match="Role custom is not supported"): + llm_utils.combine_message_content_with_role(contents=contents, role="custom") # type: ignore[arg-type] + + +def test_calculate_rest_token_uses_context_size_and_template_alias(): + model_instance = _build_model_instance( + model_schema=_build_model_schema( + model_properties={ModelPropertyKey.CONTEXT_SIZE: 4096}, + parameter_rules=[ + ParameterRule( + name="output_limit", + use_template="max_tokens", + label={"en_US": "Output Limit"}, + type=ParameterType.INT, + ) + ], + ) + ) + model_instance.parameters = {"max_tokens": 512} + model_instance.get_llm_num_tokens.return_value = 256 + + assert ( + llm_utils.calculate_rest_token( + prompt_messages=[UserPromptMessage(content="hello")], + model_instance=model_instance, + ) + == 3328 + ) + + +def test_handle_memory_chat_mode_returns_empty_without_memory_and_uses_window_when_present(): + model_instance = _build_model_instance() + memory = mock.MagicMock() + memory.get_history_prompt_messages.return_value = [UserPromptMessage(content="Question")] + + assert ( + llm_utils.handle_memory_chat_mode( + memory=None, + memory_config=None, + model_instance=model_instance, + ) + == [] + ) + + with mock.patch("graphon.nodes.llm.llm_utils.calculate_rest_token", return_value=123) as mock_rest: + messages = llm_utils.handle_memory_chat_mode( + memory=memory, + memory_config=MemoryConfig(window=MemoryConfig.WindowConfig(enabled=True, size=2)), + model_instance=model_instance, + ) + + assert messages == [UserPromptMessage(content="Question")] + mock_rest.assert_called_once() + memory.get_history_prompt_messages.assert_called_once_with(max_token_limit=123, message_limit=2) + + +def test_handle_memory_completion_mode_validates_role_prefix_and_formats_history(): + model_instance = _build_model_instance() + memory = mock.MagicMock() + memory.get_history_prompt_messages.return_value = [ + UserPromptMessage(content="Question"), + AssistantPromptMessage(content="Answer"), + ] + + assert ( + llm_utils.handle_memory_completion_mode( + memory=None, + memory_config=None, + model_instance=model_instance, + ) + == "" + ) + + with ( + mock.patch("graphon.nodes.llm.llm_utils.calculate_rest_token", return_value=456), + pytest.raises(MemoryRolePrefixRequiredError, match="Memory role prefix is required"), + ): + llm_utils.handle_memory_completion_mode( + memory=memory, + memory_config=MemoryConfig(window=MemoryConfig.WindowConfig(enabled=True, size=2)), + model_instance=model_instance, + ) + + with mock.patch("graphon.nodes.llm.llm_utils.calculate_rest_token", return_value=456): + history_text = llm_utils.handle_memory_completion_mode( + memory=memory, + memory_config=MemoryConfig( + role_prefix=MemoryConfig.RolePrefix(user="Human", assistant="Assistant"), + window=MemoryConfig.WindowConfig(enabled=False), + ), + model_instance=model_instance, + ) + + assert history_text == "Human: Question\nAssistant: Answer" + memory.get_history_prompt_messages.assert_called_with(max_token_limit=456, message_limit=None) + + +def test_append_file_prompts_merges_with_existing_user_content_or_appends_new_message(): + file = _build_image_file(file_id="image", related_id="image-related", remote_url="https://example.com/image.png") + file_prompt = ImagePromptMessageContent( + format="png", + url="https://example.com/image.png", + mime_type="image/png", + detail=ImagePromptMessageContent.DETAIL.HIGH, + ) + prompt_messages = [UserPromptMessage(content=[TextPromptMessageContent(data="Question")])] + + with mock.patch( + "graphon.nodes.llm.llm_utils.file_manager.to_prompt_message_content", + return_value=file_prompt, + ): + llm_utils._append_file_prompts( + prompt_messages=prompt_messages, + files=[file], + vision_enabled=True, + vision_detail=ImagePromptMessageContent.DETAIL.HIGH, + ) + + assert prompt_messages == [ + UserPromptMessage(content=[file_prompt, TextPromptMessageContent(data="Question")]), + ] + + prompt_messages = [SystemPromptMessage(content="System prompt")] + with mock.patch( + "graphon.nodes.llm.llm_utils.file_manager.to_prompt_message_content", + return_value=file_prompt, + ): + llm_utils._append_file_prompts( + prompt_messages=prompt_messages, + files=[file], + vision_enabled=True, + vision_detail=ImagePromptMessageContent.DETAIL.HIGH, + ) + + assert prompt_messages[-1] == UserPromptMessage(content=[file_prompt]) + + +def test_fetch_prompt_messages_chat_mode_includes_query_memory_and_supported_files(): + model_instance = _build_model_instance(model_schema=_build_model_schema(features=[ModelFeature.VISION])) + memory = mock.MagicMock() + memory.get_history_prompt_messages.return_value = [AssistantPromptMessage(content="history")] + sys_file = _build_image_file(file_id="sys", related_id="sys-related", remote_url="https://example.com/sys.png") + context_file = _build_image_file( + file_id="context", + related_id="context-related", + remote_url="https://example.com/context.png", + ) + file_prompts = [ + ImagePromptMessageContent( + format="png", + url="https://example.com/sys.png", + mime_type="image/png", + detail=ImagePromptMessageContent.DETAIL.HIGH, + ), + ImagePromptMessageContent( + format="png", + url="https://example.com/context.png", + mime_type="image/png", + detail=ImagePromptMessageContent.DETAIL.HIGH, + ), + ] + + with mock.patch( + "graphon.nodes.llm.llm_utils.file_manager.to_prompt_message_content", + side_effect=file_prompts, + ): + prompt_messages, stop = llm_utils.fetch_prompt_messages( + sys_query="current question", + sys_files=[sys_file], + context="", + memory=memory, + model_instance=model_instance, + prompt_template=[ + LLMNodeChatModelMessage( + text="Before query", + role=PromptMessageRole.USER, + edition_type="basic", + ) + ], + stop=("STOP",), + memory_config=MemoryConfig(window=MemoryConfig.WindowConfig(enabled=False)), + vision_enabled=True, + vision_detail=ImagePromptMessageContent.DETAIL.HIGH, + variable_pool=VariablePool.empty(), + jinja2_variables=[], + context_files=[context_file], + ) + + assert stop == ("STOP",) + assert prompt_messages[0] == UserPromptMessage(content="Before query") + assert prompt_messages[1] == AssistantPromptMessage(content="history") + assert prompt_messages[2] == UserPromptMessage( + content=[ + file_prompts[1], + file_prompts[0], + TextPromptMessageContent(data="current question"), + ] + ) + + +def test_fetch_prompt_messages_completion_mode_updates_list_content_with_histories_and_query(): + model_instance = _build_model_instance(model_schema=_build_model_schema(features=[])) + memory = mock.MagicMock() + memory.get_history_prompt_messages.return_value = [ + UserPromptMessage(content="previous question"), + AssistantPromptMessage(content="previous answer"), + ] + + prompt_messages, stop = llm_utils.fetch_prompt_messages( + sys_query="latest question", + sys_files=[], + context="", + memory=memory, + model_instance=model_instance, + prompt_template=LLMNodeCompletionModelPromptTemplate( + text="Prompt header\n#histories#", + edition_type="basic", + ), + stop=("HALT",), + memory_config=MemoryConfig( + role_prefix=MemoryConfig.RolePrefix(user="Human", assistant="Assistant"), + window=MemoryConfig.WindowConfig(enabled=True, size=2), + ), + vision_enabled=False, + vision_detail=ImagePromptMessageContent.DETAIL.HIGH, + variable_pool=VariablePool.empty(), + jinja2_variables=[], + ) + + assert stop == ("HALT",) + assert prompt_messages == [ + UserPromptMessage( + content="latest question\nPrompt header\nHuman: previous question\nAssistant: previous answer" + ) + ] + + memory.get_history_prompt_messages.return_value = [ + UserPromptMessage(content="another question"), + AssistantPromptMessage(content="another answer"), + ] + + prompt_messages, _ = llm_utils.fetch_prompt_messages( + sys_query="latest question", + sys_files=[], + context="", + memory=memory, + model_instance=model_instance, + prompt_template=LLMNodeCompletionModelPromptTemplate( + text="Prompt header", + edition_type="basic", + ), + stop=None, + memory_config=MemoryConfig( + role_prefix=MemoryConfig.RolePrefix(user="Human", assistant="Assistant"), + window=MemoryConfig.WindowConfig(enabled=True, size=2), + ), + vision_enabled=False, + vision_detail=ImagePromptMessageContent.DETAIL.HIGH, + variable_pool=VariablePool.empty(), + jinja2_variables=[], + ) + + assert prompt_messages == [ + UserPromptMessage(content="latest question\nHuman: another question\nAssistant: another answer\nPrompt header") + ] + + +def test_fetch_prompt_messages_filters_content_unsupported_by_model_features(): + model_instance = _build_model_instance(model_schema=_build_model_schema(features=[ModelFeature.DOCUMENT])) + prompt_template = [ + LLMNodeChatModelMessage( + text="You are a classifier.", + role=PromptMessageRole.SYSTEM, + edition_type="basic", + ) + ] + + with ( + mock.patch( + "graphon.nodes.llm.llm_utils.handle_list_messages", + return_value=[ + SystemPromptMessage( + content=[ + TextPromptMessageContent(data="You are a classifier."), + ImagePromptMessageContent( + format="png", + url="https://example.com/image.png", + mime_type="image/png", + ), + ] + ) + ], + ), + mock.patch("graphon.nodes.llm.llm_utils.handle_memory_chat_mode", return_value=[]), + ): + prompt_messages, stop = llm_utils.fetch_prompt_messages( + sys_query=None, + sys_files=[], + context="", + memory=None, + model_instance=model_instance, + prompt_template=prompt_template, + stop=("END",), + memory_config=None, + vision_enabled=False, + vision_detail=ImagePromptMessageContent.DETAIL.HIGH, + variable_pool=VariablePool.empty(), + jinja2_variables=[], + ) + + assert stop == ("END",) + assert prompt_messages == [SystemPromptMessage(content="You are a classifier.")] + + +def test_fetch_prompt_messages_completion_mode_supports_string_content_and_invalid_template_type(): + model_instance = _build_model_instance(model_schema=_build_model_schema(features=[])) + + with ( + mock.patch( + "graphon.nodes.llm.llm_utils.handle_completion_template", + return_value=[UserPromptMessage(content="Prefix #histories# and #sys.query#")], + ), + mock.patch( + "graphon.nodes.llm.llm_utils.handle_memory_completion_mode", + return_value="history text", + ), + ): + prompt_messages, stop = llm_utils.fetch_prompt_messages( + sys_query="latest question", + sys_files=[], + context="", + memory=None, + model_instance=model_instance, + prompt_template=LLMNodeCompletionModelPromptTemplate( + text="ignored", + edition_type="basic", + ), + stop=("HALT",), + memory_config=None, + vision_enabled=False, + vision_detail=ImagePromptMessageContent.DETAIL.HIGH, + variable_pool=VariablePool.empty(), + jinja2_variables=[], + ) + + assert stop == ("HALT",) + assert prompt_messages == [UserPromptMessage(content="Prefix history text and latest question")] + + with pytest.raises(TemplateTypeNotSupportError): + llm_utils.fetch_prompt_messages( + sys_query=None, + sys_files=[], + context="", + memory=None, + model_instance=model_instance, + prompt_template=object(), # type: ignore[arg-type] + stop=None, + memory_config=None, + vision_enabled=False, + vision_detail=ImagePromptMessageContent.DETAIL.HIGH, + variable_pool=VariablePool.empty(), + jinja2_variables=[], + ) + + invalid_prompt = mock.MagicMock() + invalid_prompt.content = object() + with ( + mock.patch( + "graphon.nodes.llm.llm_utils.handle_completion_template", + return_value=[invalid_prompt], + ), + mock.patch( + "graphon.nodes.llm.llm_utils.handle_memory_completion_mode", + return_value="history text", + ), + pytest.raises(ValueError, match="Invalid prompt content type"), + ): + llm_utils.fetch_prompt_messages( + sys_query="latest question", + sys_files=[], + context="", + memory=None, + model_instance=model_instance, + prompt_template=LLMNodeCompletionModelPromptTemplate( + text="ignored", + edition_type="basic", + ), + stop=None, + memory_config=None, + vision_enabled=False, + vision_detail=ImagePromptMessageContent.DETAIL.HIGH, + variable_pool=VariablePool.empty(), + jinja2_variables=[], + ) + + with ( + mock.patch( + "graphon.nodes.llm.llm_utils.handle_completion_template", + return_value=[UserPromptMessage(content="Prefix only")], + ), + mock.patch( + "graphon.nodes.llm.llm_utils.handle_memory_completion_mode", + return_value="history text", + ), + ): + prompt_messages, _ = llm_utils.fetch_prompt_messages( + sys_query=None, + sys_files=[], + context="", + memory=None, + model_instance=model_instance, + prompt_template=LLMNodeCompletionModelPromptTemplate( + text="ignored", + edition_type="basic", + ), + stop=None, + memory_config=None, + vision_enabled=False, + vision_detail=ImagePromptMessageContent.DETAIL.HIGH, + variable_pool=VariablePool.empty(), + jinja2_variables=[], + ) + + assert prompt_messages == [UserPromptMessage(content="history text\nPrefix only")] diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py index fc96088af1..a2fbc50392 100644 --- a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py @@ -5,40 +5,80 @@ from unittest import mock import pytest -from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity, UserFrom -from core.app.llm.model_access import DifyCredentialsProvider, DifyModelFactory, fetch_model_config +from core.app.entities.app_invoke_entities import DifyRunContext, InvokeFrom, ModelConfigWithCredentialsEntity, UserFrom +from core.app.llm.model_access import ( + DifyCredentialsProvider, + DifyModelFactory, + build_dify_model_access, + fetch_model_config, +) from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle from core.entities.provider_entities import CustomConfiguration, SystemConfiguration -from core.model_manager import ModelInstance +from core.plugin.impl.model_runtime_factory import create_plugin_model_runtime from core.prompt.entities.advanced_prompt_entities import MemoryConfig -from dify_graph.entities import GraphInitParams -from dify_graph.file import File, FileTransferMethod, FileType -from dify_graph.model_runtime.entities.common_entities import I18nObject -from dify_graph.model_runtime.entities.message_entities import ( +from core.workflow.system_variables import default_system_variables +from graphon.entities import GraphInitParams +from graphon.file import File, FileTransferMethod, FileType +from graphon.model_runtime.entities.common_entities import I18nObject +from graphon.model_runtime.entities.llm_entities import ( + LLMResultChunk, + LLMResultChunkDelta, + LLMResultChunkWithStructuredOutput, + LLMResultWithStructuredOutput, + LLMUsage, +) +from graphon.model_runtime.entities.message_entities import ( AssistantPromptMessage, ImagePromptMessageContent, PromptMessage, PromptMessageRole, + SystemPromptMessage, TextPromptMessageContent, UserPromptMessage, ) -from dify_graph.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType -from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory -from dify_graph.nodes.llm import llm_utils -from dify_graph.nodes.llm.entities import ( +from graphon.model_runtime.entities.model_entities import ( + AIModelEntity, + FetchFrom, + ModelFeature, + ModelPropertyKey, + ModelType, + ParameterRule, + ParameterType, +) +from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory +from graphon.node_events import ModelInvokeCompletedEvent, RunRetrieverResourceEvent, StreamChunkEvent +from graphon.nodes.base.entities import VariableSelector +from graphon.nodes.llm import llm_utils +from graphon.nodes.llm.entities import ( ContextConfig, LLMNodeChatModelMessage, + LLMNodeCompletionModelPromptTemplate, LLMNodeData, ModelConfig, + PromptConfig, VisionConfig, VisionConfigOptions, ) -from dify_graph.nodes.llm.file_saver import LLMFileSaver -from dify_graph.nodes.llm.node import LLMNode -from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory, TemplateRenderer -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable -from dify_graph.variables import ArrayAnySegment, ArrayFileSegment, NoneSegment +from graphon.nodes.llm.exc import ( + InvalidContextStructureError, + LLMNodeError, + NoPromptFoundError, + VariableNotFoundError, +) +from graphon.nodes.llm.file_saver import LLMFileSaver +from graphon.nodes.llm.node import ( + LLMNode, + _calculate_rest_token, + _handle_completion_template, + _handle_memory_chat_mode, + _handle_memory_completion_mode, + _render_jinja2_message, +) +from graphon.nodes.llm.protocols import CredentialsProvider, ModelFactory +from graphon.nodes.llm.runtime_protocols import PromptMessageSerializerProtocol +from graphon.runtime import GraphRuntimeState, VariablePool +from graphon.template_rendering import TemplateRenderError +from graphon.variables import ArrayAnySegment, ArrayFileSegment, NoneSegment from models.provider import ProviderType from tests.workflow_test_utils import build_test_graph_init_params @@ -55,6 +95,62 @@ class MockTokenBufferMemory: return self.history_messages +def _build_prepared_llm_mock() -> mock.MagicMock: + model_instance = mock.MagicMock() + model_instance.provider = "openai" + model_instance.model_name = "gpt-3.5-turbo" + model_instance.parameters = {} + model_instance.stop = () + model_instance.get_llm_num_tokens.return_value = 0 + model_instance.get_model_schema.return_value = AIModelEntity( + model="gpt-3.5-turbo", + label=I18nObject(en_US="GPT-3.5 Turbo"), + model_type=ModelType.LLM, + fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, + model_properties={}, + ) + model_instance.is_structured_output_parse_error.return_value = False + return model_instance + + +def _build_model_schema( + *, + features: list[ModelFeature] | None = None, + model_properties: dict[ModelPropertyKey, object] | None = None, + parameter_rules: list[ParameterRule] | None = None, +) -> AIModelEntity: + return AIModelEntity( + model="gpt-3.5-turbo", + label=I18nObject(en_US="GPT-3.5 Turbo"), + model_type=ModelType.LLM, + fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, + features=features, + model_properties=model_properties or {}, + parameter_rules=parameter_rules or [], + ) + + +def _build_image_file( + *, + file_id: str, + related_id: str, + remote_url: str, + extension: str = ".png", + mime_type: str = "image/png", +) -> File: + return File( + id=file_id, + type=FileType.IMAGE, + filename=f"{file_id}{extension}", + transfer_method=FileTransferMethod.REMOTE_URL, + remote_url=remote_url, + related_id=related_id, + extension=extension, + mime_type=mime_type, + storage_key="", + ) + + @pytest.fixture def llm_node_data() -> LLMNodeData: return LLMNodeData( @@ -91,7 +187,7 @@ def graph_init_params() -> GraphInitParams: @pytest.fixture def graph_runtime_state() -> GraphRuntimeState: variable_pool = VariablePool( - system_variables=SystemVariable.default(), + system_variables=default_system_variables(), user_inputs={}, ) return GraphRuntimeState( @@ -107,7 +203,7 @@ def llm_node( mock_file_saver = mock.MagicMock(spec=LLMFileSaver) mock_credentials_provider = mock.MagicMock(spec=CredentialsProvider) mock_model_factory = mock.MagicMock(spec=ModelFactory) - mock_template_renderer = mock.MagicMock(spec=TemplateRenderer) + mock_prompt_message_serializer = mock.MagicMock(spec=PromptMessageSerializerProtocol) node_config = { "id": "1", "data": llm_node_data.model_dump(), @@ -120,9 +216,9 @@ def llm_node( graph_runtime_state=graph_runtime_state, credentials_provider=mock_credentials_provider, model_factory=mock_model_factory, - model_instance=mock.MagicMock(spec=ModelInstance), + model_instance=_build_prepared_llm_mock(), llm_file_saver=mock_file_saver, - template_renderer=mock_template_renderer, + prompt_message_serializer=mock_prompt_message_serializer, http_client=http_client, ) return node @@ -132,28 +228,31 @@ def llm_node( def model_config(monkeypatch): from tests.integration_tests.model_runtime.__mock.plugin_model import MockModelClass - def mock_plugin_model_providers(_self): - providers = MockModelClass().fetch_model_providers("test") - for provider in providers: - provider.declaration.provider = f"{provider.plugin_id}/{provider.declaration.provider}" + def mock_model_providers(_self): + providers = [] + for provider in MockModelClass().fetch_model_providers("test"): + provider_schema = provider.declaration.model_copy(deep=True) + provider_schema.provider = f"{provider.plugin_id}/{provider.provider}" + provider_schema.provider_name = provider.provider + providers.append(provider_schema) return providers monkeypatch.setattr( ModelProviderFactory, - "get_plugin_model_providers", - mock_plugin_model_providers, + "get_model_providers", + mock_model_providers, ) # Create actual provider and model type instances - model_provider_factory = ModelProviderFactory(tenant_id="test") - provider_instance = model_provider_factory.get_plugin_model_provider("openai") + model_provider_factory = ModelProviderFactory(model_runtime=create_plugin_model_runtime(tenant_id="test")) + provider_instance = model_provider_factory.get_model_provider("openai") model_type_instance = model_provider_factory.get_model_type_instance("openai", ModelType.LLM) # Create a ProviderModelBundle provider_model_bundle = ProviderModelBundle( configuration=ProviderConfiguration( tenant_id="1", - provider=provider_instance.declaration, + provider=provider_instance, preferred_provider_type=ProviderType.CUSTOM, using_provider_type=ProviderType.CUSTOM, system_configuration=SystemConfiguration(enabled=False), @@ -181,13 +280,18 @@ def model_config(monkeypatch): ) -def test_fetch_model_config_uses_ports(model_config: ModelConfigWithCredentialsEntity): +def test_fetch_model_config_hydrates_model_instance_runtime_settings(model_config: ModelConfigWithCredentialsEntity): mock_credentials_provider = mock.MagicMock(spec=CredentialsProvider) - mock_model_factory = mock.MagicMock(spec=ModelFactory) + mock_model_factory = mock.MagicMock(spec=DifyModelFactory) provider_model_bundle = model_config.provider_model_bundle model_type_instance = provider_model_bundle.model_type_instance provider_model = mock.MagicMock() + completion_params = { + "temperature": 0.7, + "max_tokens": 256, + "stop": ["Observation:", "Human:"], + } model_instance = mock.MagicMock( model_type_instance=model_type_instance, @@ -208,12 +312,36 @@ def test_fetch_model_config_uses_ports(model_config: ModelConfigWithCredentialsE model_type_instance.__class__, "get_model_schema", return_value=model_config.model_schema, autospec=True ), ): - fetch_model_config( - node_data_model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode="chat", completion_params={}), + hydrated_model_instance, model_config_with_credentials = fetch_model_config( + node_data_model=ModelConfig( + provider="openai", + name="gpt-3.5-turbo", + mode="chat", + completion_params=completion_params, + ), credentials_provider=mock_credentials_provider, model_factory=mock_model_factory, ) + assert hydrated_model_instance is model_instance + assert hydrated_model_instance.provider == "openai" + assert hydrated_model_instance.model_name == "gpt-3.5-turbo" + assert hydrated_model_instance.credentials == {"api_key": "test"} + assert hydrated_model_instance.parameters == { + "temperature": 0.7, + "max_tokens": 256, + } + assert hydrated_model_instance.stop == ("Observation:", "Human:") + assert model_config_with_credentials.parameters == { + "temperature": 0.7, + "max_tokens": 256, + } + assert model_config_with_credentials.stop == ["Observation:", "Human:"] + assert completion_params == { + "temperature": 0.7, + "max_tokens": 256, + "stop": ["Observation:", "Human:"], + } mock_credentials_provider.fetch.assert_called_once_with("openai", "gpt-3.5-turbo") mock_model_factory.init_model_instance.assert_called_once_with("openai", "gpt-3.5-turbo") provider_model.raise_for_status.assert_called_once() @@ -230,12 +358,20 @@ def test_dify_model_access_adapters_call_managers(): mock_provider_configuration.get_provider_model.return_value = mock_provider_model mock_provider_configuration.get_current_credentials.return_value = {"api_key": "test"} - credentials_provider = DifyCredentialsProvider( + run_context = DifyRunContext( tenant_id="tenant", + app_id="app", + user_id="user", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + ) + + credentials_provider = DifyCredentialsProvider( + run_context=run_context, provider_manager=mock_provider_manager, ) model_factory = DifyModelFactory( - tenant_id="tenant", + run_context=run_context, model_manager=mock_model_manager, ) @@ -255,18 +391,18 @@ def test_dify_model_access_adapters_call_managers(): model="gpt-3.5-turbo", ) mock_provider_model.raise_for_status.assert_called_once() - mock_model_manager.get_model_instance.assert_called_once_with( - tenant_id="tenant", - provider="openai", - model_type=ModelType.LLM, - model="gpt-3.5-turbo", - ) + mock_model_manager.get_model_instance.assert_called_once() + assert mock_model_manager.get_model_instance.call_args.kwargs == { + "tenant_id": "tenant", + "provider": "openai", + "model_type": ModelType.LLM, + "model": "gpt-3.5-turbo", + } def test_fetch_files_with_file_segment(): file = File( id="1", - tenant_id="test", type=FileType.IMAGE, filename="test.jpg", transfer_method=FileTransferMethod.LOCAL_FILE, @@ -284,7 +420,6 @@ def test_fetch_files_with_array_file_segment(): files = [ File( id="1", - tenant_id="test", type=FileType.IMAGE, filename="test1.jpg", transfer_method=FileTransferMethod.LOCAL_FILE, @@ -293,7 +428,6 @@ def test_fetch_files_with_array_file_segment(): ), File( id="2", - tenant_id="test", type=FileType.IMAGE, filename="test2.jpg", transfer_method=FileTransferMethod.LOCAL_FILE, @@ -343,7 +477,6 @@ def test_fetch_files_with_non_existent_variable(): # files = [ # File( # id="1", -# tenant_id="test", # type=FileType.IMAGE, # filename="test1.jpg", # transfer_method=FileTransferMethod.REMOTE_URL, @@ -448,7 +581,6 @@ def test_fetch_files_with_non_existent_variable(): # sys_query=fake_query, # sys_files=[ # File( -# tenant_id="test", # type=FileType.IMAGE, # filename="test1.jpg", # transfer_method=FileTransferMethod.REMOTE_URL, @@ -524,7 +656,6 @@ def test_fetch_files_with_non_existent_variable(): # + [UserPromptMessage(content=fake_query)], # file_variables={ # "input.image": File( -# tenant_id="test", # type=FileType.IMAGE, # filename="test1.jpg", # transfer_method=FileTransferMethod.REMOTE_URL, @@ -569,7 +700,7 @@ def test_fetch_files_with_non_existent_variable(): def test_handle_list_messages_basic(llm_node): messages = [ LLMNodeChatModelMessage( - text="Hello, {#context#}", + text="Hello, {{#context#}}", role=PromptMessageRole.USER, edition_type="basic", ) @@ -592,32 +723,414 @@ def test_handle_list_messages_basic(llm_node): assert result[0].content == [TextPromptMessageContent(data="Hello, world")] -def test_handle_list_messages_jinja2_uses_template_renderer(llm_node): - llm_node._template_renderer.render_jinja2.return_value = "Hello, world" +def test_handle_list_messages_replaces_double_brace_context_placeholder(llm_node): messages = [ LLMNodeChatModelMessage( - text="", - jinja2_text="Hello, {{ name }}", - role=PromptMessageRole.USER, - edition_type="jinja2", + text="Answer user's question with the following context:\n\n{{#context#}}", + role=PromptMessageRole.SYSTEM, + edition_type="basic", ) ] + context = "## Overview\nSends a JSON request." result = llm_node.handle_list_messages( messages=messages, - context=None, + context=context, jinja2_variables=[], variable_pool=llm_node.graph_runtime_state.variable_pool, vision_detail_config=ImagePromptMessageContent.DETAIL.HIGH, - template_renderer=llm_node._template_renderer, ) - assert result == [UserPromptMessage(content=[TextPromptMessageContent(data="Hello, world")])] - llm_node._template_renderer.render_jinja2.assert_called_once_with( - template="Hello, {{ name }}", - inputs={}, + assert len(result) == 1 + assert isinstance(result[0].content, list) + assert result[0].content == [ + TextPromptMessageContent( + data="Answer user's question with the following context:\n\n## Overview\nSends a JSON request." + ) + ] + + +def test_handle_list_messages_renders_jinja2_messages(llm_node): + llm_node.graph_runtime_state.variable_pool.add(["input", "name"], "Dify") + renderer = mock.MagicMock() + renderer.render_template.return_value = "Hello Dify" + + prompt_messages = llm_node.handle_list_messages( + messages=[ + LLMNodeChatModelMessage( + text="ignored", + jinja2_text="Hello {{ name }}", + role=PromptMessageRole.SYSTEM, + edition_type="jinja2", + ) + ], + context="", + jinja2_variables=[VariableSelector(variable="name", value_selector=["input", "name"])], + variable_pool=llm_node.graph_runtime_state.variable_pool, + vision_detail_config=ImagePromptMessageContent.DETAIL.HIGH, + jinja2_template_renderer=renderer, ) + assert prompt_messages == [ + SystemPromptMessage(content=[TextPromptMessageContent(data="Hello Dify")]), + ] + renderer.render_template.assert_called_once_with("Hello {{ name }}", {"name": "Dify"}) + + +def test_transform_chat_messages_prefers_jinja2_text(llm_node): + completion_template = LLMNodeCompletionModelPromptTemplate( + text="ignored", + jinja2_text="completion prompt", + edition_type="jinja2", + ) + chat_messages = [ + LLMNodeChatModelMessage( + text="ignored", + jinja2_text="chat prompt", + role=PromptMessageRole.USER, + edition_type="jinja2", + ), + LLMNodeChatModelMessage( + text="keep original", + role=PromptMessageRole.SYSTEM, + edition_type="basic", + ), + ] + + transformed_completion = llm_node._transform_chat_messages(completion_template) + transformed_messages = llm_node._transform_chat_messages(chat_messages) + + assert transformed_completion.text == "completion prompt" + assert transformed_messages[0].text == "chat prompt" + assert transformed_messages[1].text == "keep original" + + +def test_fetch_jinja_inputs_serializes_supported_segment_types(llm_node): + llm_node.graph_runtime_state.variable_pool.add( + ["input", "items"], + ["alpha", {"metadata": {"_source": "knowledge"}, "content": "beta"}, 3], + ) + llm_node.graph_runtime_state.variable_pool.add( + ["input", "context_doc"], + {"metadata": {"_source": "knowledge"}, "content": "context body"}, + ) + llm_node.graph_runtime_state.variable_pool.add(["input", "payload"], {"a": 1}) + + node_data = llm_node.node_data.model_copy( + update={ + "prompt_config": PromptConfig( + jinja2_variables=[ + VariableSelector(variable="items", value_selector=["input", "items"]), + VariableSelector(variable="context_doc", value_selector=["input", "context_doc"]), + VariableSelector(variable="payload", value_selector=["input", "payload"]), + ] + ) + } + ) + + assert llm_node._fetch_jinja_inputs(node_data) == { + "items": "alpha\nbeta\n3", + "context_doc": "context body", + "payload": '{"a": 1}', + } + + +def test_fetch_jinja_inputs_raises_for_missing_variable(llm_node): + node_data = llm_node.node_data.model_copy( + update={ + "prompt_config": PromptConfig( + jinja2_variables=[VariableSelector(variable="missing", value_selector=["input", "missing"])] + ) + } + ) + + with pytest.raises(VariableNotFoundError, match="Variable missing not found"): + llm_node._fetch_jinja_inputs(node_data) + + +def test_fetch_inputs_collects_prompt_and_memory_variables(llm_node): + llm_node.graph_runtime_state.variable_pool.add(["input", "name"], "Dify") + llm_node.graph_runtime_state.variable_pool.add(["input", "payload"], {"active": True}) + + node_data = llm_node.node_data.model_copy( + update={ + "prompt_template": [ + LLMNodeChatModelMessage( + text="Hello {{#input.name#}} with {{#input.payload#}}", + role=PromptMessageRole.USER, + edition_type="basic", + ) + ], + "memory": MemoryConfig( + role_prefix=MemoryConfig.RolePrefix(user="Human", assistant="Assistant"), + window=MemoryConfig.WindowConfig(enabled=True, size=1), + query_prompt_template="Repeat {{#input.name#}}", + ), + } + ) + + assert llm_node._fetch_inputs(node_data) == { + "#input.name#": "Dify", + "#input.payload#": {"active": True}, + } + + +def test_fetch_context_emits_string_context_event(llm_node): + llm_node.graph_runtime_state.variable_pool.add(["context", "value"], "retrieved context") + node_data = llm_node.node_data.model_copy( + update={"context": ContextConfig(enabled=True, variable_selector=["context", "value"])} + ) + + events = list(llm_node._fetch_context(node_data)) + + assert events == [ + RunRetrieverResourceEvent(retriever_resources=[], context="retrieved context", context_files=[]), + ] + + +def test_fetch_context_collects_retriever_resources_and_attachments(llm_node): + attachment = _build_image_file( + file_id="attachment", + related_id="attachment-related", + remote_url="https://example.com/attachment.png", + ) + llm_node._retriever_attachment_loader = mock.MagicMock() + llm_node._retriever_attachment_loader.load.return_value = [attachment] + + llm_node.graph_runtime_state.variable_pool.add( + ["context", "value"], + [ + { + "content": "chunk body", + "summary": "chunk summary", + "files": [{"id": "file-1"}], + "metadata": { + "_source": "knowledge", + "dataset_id": "dataset-1", + "segment_id": "segment-1", + "segment_word_count": 12, + }, + }, + "tail text", + ], + ) + node_data = llm_node.node_data.model_copy( + update={"context": ContextConfig(enabled=True, variable_selector=["context", "value"])} + ) + + events = list(llm_node._fetch_context(node_data)) + + assert len(events) == 1 + event = events[0] + assert event.context == "chunk summary\nchunk body\ntail text" + assert event.context_files == [attachment] + assert event.retriever_resources == [ + { + "position": None, + "dataset_id": "dataset-1", + "dataset_name": None, + "document_id": None, + "document_name": None, + "data_source_type": None, + "segment_id": "segment-1", + "retriever_from": None, + "score": None, + "hit_count": None, + "word_count": 12, + "segment_position": None, + "index_node_hash": None, + "content": "chunk body", + "page": None, + "doc_metadata": None, + "files": [{"id": "file-1"}], + "summary": "chunk summary", + } + ] + llm_node._retriever_attachment_loader.load.assert_called_once_with(segment_id="segment-1") + + +def test_fetch_context_rejects_invalid_context_structure(llm_node): + llm_node.graph_runtime_state.variable_pool.add(["context", "value"], [{"summary": "missing content"}]) + node_data = llm_node.node_data.model_copy( + update={"context": ContextConfig(enabled=True, variable_selector=["context", "value"])} + ) + + with pytest.raises(InvalidContextStructureError, match="Invalid context structure"): + list(llm_node._fetch_context(node_data)) + + +def test_fetch_prompt_messages_chat_mode_appends_memory_query_and_files(): + model_instance = _build_prepared_llm_mock() + model_instance.get_model_schema.return_value = _build_model_schema(features=[ModelFeature.VISION]) + + memory = mock.MagicMock(spec=MockTokenBufferMemory) + memory.get_history_prompt_messages.return_value = [AssistantPromptMessage(content="history answer")] + + sys_file = _build_image_file(file_id="sys-file", related_id="sys-related", remote_url="https://example.com/sys.png") + context_file = _build_image_file( + file_id="context-file", + related_id="context-related", + remote_url="https://example.com/context.png", + ) + + prompt_content_side_effect = [ + ImagePromptMessageContent( + url="https://example.com/sys.png", + format="png", + mime_type="image/png", + detail=ImagePromptMessageContent.DETAIL.HIGH, + ), + ImagePromptMessageContent( + url="https://example.com/context.png", + format="png", + mime_type="image/png", + detail=ImagePromptMessageContent.DETAIL.HIGH, + ), + ] + + with mock.patch("graphon.nodes.llm.node.file_manager.to_prompt_message_content") as mock_to_prompt: + mock_to_prompt.side_effect = prompt_content_side_effect + prompt_messages, stop = LLMNode.fetch_prompt_messages( + sys_query="current question", + sys_files=[sys_file], + context="", + memory=memory, + model_instance=model_instance, + prompt_template=[ + LLMNodeChatModelMessage( + text="Before query", + role=PromptMessageRole.USER, + edition_type="basic", + ) + ], + stop=("STOP",), + memory_config=MemoryConfig( + window=MemoryConfig.WindowConfig(enabled=False), + ), + vision_enabled=True, + vision_detail=ImagePromptMessageContent.DETAIL.HIGH, + variable_pool=VariablePool.empty(), + jinja2_variables=[], + context_files=[context_file], + ) + + assert stop == ("STOP",) + assert prompt_messages[0] == UserPromptMessage(content="Before query") + assert prompt_messages[1] == AssistantPromptMessage(content="history answer") + assert isinstance(prompt_messages[2], UserPromptMessage) + assert isinstance(prompt_messages[2].content, list) + assert isinstance(prompt_messages[2].content[0], ImagePromptMessageContent) + assert isinstance(prompt_messages[2].content[1], ImagePromptMessageContent) + assert isinstance(prompt_messages[2].content[2], TextPromptMessageContent) + assert prompt_messages[2].content[0].url == "https://example.com/context.png" + assert prompt_messages[2].content[1].url == "https://example.com/sys.png" + assert prompt_messages[2].content[2].data == "current question" + memory.get_history_prompt_messages.assert_called_once_with(max_token_limit=2000, message_limit=None) + + +def test_fetch_prompt_messages_completion_mode_injects_histories_and_query(): + model_instance = _build_prepared_llm_mock() + model_instance.get_model_schema.return_value = _build_model_schema(features=[]) + + memory = mock.MagicMock(spec=MockTokenBufferMemory) + memory.get_history_prompt_messages.return_value = [ + UserPromptMessage(content="previous question"), + AssistantPromptMessage(content="previous answer"), + ] + + prompt_messages, stop = LLMNode.fetch_prompt_messages( + sys_query="latest question", + sys_files=[], + context="", + memory=memory, + model_instance=model_instance, + prompt_template=LLMNodeCompletionModelPromptTemplate( + text="Prompt header\n#histories#", + edition_type="basic", + ), + stop=("HALT",), + memory_config=MemoryConfig( + role_prefix=MemoryConfig.RolePrefix(user="Human", assistant="Assistant"), + window=MemoryConfig.WindowConfig(enabled=True, size=2), + ), + vision_enabled=False, + vision_detail=ImagePromptMessageContent.DETAIL.HIGH, + variable_pool=VariablePool.empty(), + jinja2_variables=[], + ) + + assert stop == ("HALT",) + assert prompt_messages == [ + UserPromptMessage( + content="latest question\nPrompt header\nHuman: previous question\nAssistant: previous answer" + ) + ] + + +def test_fetch_prompt_messages_raises_when_only_unsupported_content_remains(): + model_instance = _build_prepared_llm_mock() + model_instance.get_model_schema.return_value = _build_model_schema(features=[]) + + variable_pool = VariablePool.empty() + variable_pool.add( + ["input", "image"], + _build_image_file(file_id="image-file", related_id="image-related", remote_url="https://example.com/file.png"), + ) + + with ( + mock.patch( + "graphon.nodes.llm.node.file_manager.to_prompt_message_content", + return_value=ImagePromptMessageContent( + url="https://example.com/file.png", + format="png", + mime_type="image/png", + detail=ImagePromptMessageContent.DETAIL.HIGH, + ), + ), + pytest.raises(NoPromptFoundError, match="No prompt found"), + ): + LLMNode.fetch_prompt_messages( + sys_query=None, + sys_files=[], + context="", + memory=None, + model_instance=model_instance, + prompt_template=[ + LLMNodeChatModelMessage( + text="{{#input.image#}}", + role=PromptMessageRole.USER, + edition_type="basic", + ) + ], + stop=None, + memory_config=None, + vision_enabled=False, + vision_detail=ImagePromptMessageContent.DETAIL.HIGH, + variable_pool=variable_pool, + jinja2_variables=[], + ) + + +def test_handle_completion_template_replaces_double_brace_context_placeholder(llm_node): + prompt_messages = _handle_completion_template( + template=LLMNodeCompletionModelPromptTemplate( + text="Summarize the following context:\n{{#context#}}", + edition_type="basic", + ), + context="## Overview\nSends a JSON request.", + jinja2_variables=[], + variable_pool=llm_node.graph_runtime_state.variable_pool, + jinja2_template_renderer=None, + ) + + assert prompt_messages == [ + UserPromptMessage( + content=[ + TextPromptMessageContent(data="Summarize the following context:\n## Overview\nSends a JSON request.") + ] + ) + ] + def test_handle_memory_completion_mode_uses_prompt_message_interface(): memory = mock.MagicMock(spec=MockTokenBufferMemory) @@ -635,15 +1148,15 @@ def test_handle_memory_completion_mode_uses_prompt_message_interface(): AssistantPromptMessage(content="first answer"), ] - model_instance = mock.MagicMock(spec=ModelInstance) + model_instance = _build_prepared_llm_mock() memory_config = MemoryConfig( role_prefix=MemoryConfig.RolePrefix(user="Human", assistant="Assistant"), window=MemoryConfig.WindowConfig(enabled=True, size=3), ) - with mock.patch("dify_graph.nodes.llm.llm_utils.calculate_rest_token", return_value=2000) as mock_rest_token: - memory_text = llm_utils.handle_memory_completion_mode( + with mock.patch("graphon.nodes.llm.node._calculate_rest_token", return_value=2000) as mock_rest_token: + memory_text = _handle_memory_completion_mode( memory=memory, memory_config=memory_config, model_instance=model_instance, @@ -659,7 +1172,7 @@ def llm_node_for_multimodal(llm_node_data, graph_init_params, graph_runtime_stat mock_file_saver: LLMFileSaver = mock.MagicMock(spec=LLMFileSaver) mock_credentials_provider = mock.MagicMock(spec=CredentialsProvider) mock_model_factory = mock.MagicMock(spec=ModelFactory) - mock_template_renderer = mock.MagicMock(spec=TemplateRenderer) + mock_prompt_message_serializer = mock.MagicMock(spec=PromptMessageSerializerProtocol) node_config = { "id": "1", "data": llm_node_data.model_dump(), @@ -672,9 +1185,9 @@ def llm_node_for_multimodal(llm_node_data, graph_init_params, graph_runtime_stat graph_runtime_state=graph_runtime_state, credentials_provider=mock_credentials_provider, model_factory=mock_model_factory, - model_instance=mock.MagicMock(spec=ModelInstance), + model_instance=_build_prepared_llm_mock(), llm_file_saver=mock_file_saver, - template_renderer=mock_template_renderer, + prompt_message_serializer=mock_prompt_message_serializer, http_client=http_client, ) return node, mock_file_saver @@ -690,7 +1203,6 @@ class TestLLMNodeSaveMultiModalImageOutput: ) mock_file = File( id=str(uuid.uuid4()), - tenant_id="1", type=FileType.IMAGE, transfer_method=FileTransferMethod.TOOL_FILE, related_id=str(uuid.uuid4()), @@ -721,7 +1233,6 @@ class TestLLMNodeSaveMultiModalImageOutput: ) mock_file = File( id=str(uuid.uuid4()), - tenant_id="1", type=FileType.IMAGE, transfer_method=FileTransferMethod.TOOL_FILE, related_id=str(uuid.uuid4()), @@ -776,7 +1287,6 @@ class TestSaveMultimodalOutputAndConvertResultToMarkdown: mock_saved_file = File( id=str(uuid.uuid4()), - tenant_id="1", type=FileType.IMAGE, transfer_method=FileTransferMethod.TOOL_FILE, filename="test.png", @@ -906,3 +1416,322 @@ class TestReasoningFormat: assert clean_text == text_with_think assert reasoning_content == "" + + +@pytest.mark.parametrize( + ("structured_output_enabled", "structured_output"), + [ + (False, None), + (True, {"schema": {"type": "object", "properties": {"answer": {"type": "string"}}}}), + ], +) +def test_invoke_llm_dispatches_to_expected_model_method(structured_output_enabled, structured_output): + model_instance = _build_prepared_llm_mock() + prompt_messages = [UserPromptMessage(content="hello")] + file_saver = mock.MagicMock(spec=LLMFileSaver) + + model_instance.invoke_llm.return_value = iter([]) + model_instance.invoke_llm_with_structured_output.return_value = iter([]) + + with ( + mock.patch.object(LLMNode, "handle_invoke_result", return_value=iter(["handled"])) as mock_handle, + mock.patch("graphon.nodes.llm.node.time.perf_counter", return_value=10.0), + ): + result = list( + LLMNode.invoke_llm( + model_instance=model_instance, + prompt_messages=prompt_messages, + stop=("STOP",), + structured_output_enabled=structured_output_enabled, + structured_output=structured_output, + file_saver=file_saver, + file_outputs=[], + node_id="node-1", + node_type=LLMNode.node_type, + reasoning_format="separated", + ) + ) + + assert result == ["handled"] + if structured_output_enabled: + model_instance.invoke_llm_with_structured_output.assert_called_once_with( + prompt_messages=prompt_messages, + json_schema={"type": "object", "properties": {"answer": {"type": "string"}}}, + model_parameters={}, + stop=("STOP",), + stream=True, + ) + model_instance.invoke_llm.assert_not_called() + else: + model_instance.invoke_llm.assert_called_once_with( + prompt_messages=prompt_messages, + model_parameters={}, + tools=None, + stop=("STOP",), + stream=True, + ) + model_instance.invoke_llm_with_structured_output.assert_not_called() + + assert mock_handle.call_args.kwargs["request_start_time"] == 10.0 + + +def test_handle_invoke_result_streaming_collects_text_metrics_and_structured_output(): + usage = LLMUsage.from_metadata({"prompt_tokens": 12, "completion_tokens": 4, "total_tokens": 16}) + first_chunk = LLMResultChunkWithStructuredOutput( + model="gpt-3.5-turbo", + prompt_messages=[], + delta=LLMResultChunkDelta( + index=0, + message=AssistantPromptMessage(content=[TextPromptMessageContent(data="plan")]), + ), + structured_output={"draft": True}, + ) + final_chunk = LLMResultChunk( + model="gpt-3.5-turbo", + prompt_messages=[], + delta=LLMResultChunkDelta( + index=1, + message=AssistantPromptMessage(content=[TextPromptMessageContent(data="answer")]), + usage=usage, + finish_reason="stop", + ), + ) + + with mock.patch("graphon.nodes.llm.node.time.perf_counter", side_effect=[2.0, 5.0]): + events = list( + LLMNode.handle_invoke_result( + invoke_result=iter([first_chunk, final_chunk]), + file_saver=mock.MagicMock(spec=LLMFileSaver), + file_outputs=[], + node_id="node-1", + node_type=LLMNode.node_type, + model_instance=_build_prepared_llm_mock(), + reasoning_format="separated", + request_start_time=1.0, + ) + ) + + assert events[0] == first_chunk + assert events[1] == StreamChunkEvent(selector=["node-1", "text"], chunk="plan", is_final=False) + assert events[2] == StreamChunkEvent(selector=["node-1", "text"], chunk="answer", is_final=False) + + completed = events[3] + assert isinstance(completed, ModelInvokeCompletedEvent) + assert completed.text == "answer" + assert completed.reasoning_content == "plan" + assert completed.structured_output == {"draft": True} + assert completed.finish_reason == "stop" + assert completed.usage.total_tokens == 16 + assert completed.usage.latency == 4.0 + assert completed.usage.time_to_first_token == 1.0 + assert completed.usage.time_to_generate == 3.0 + + +def test_handle_invoke_result_wraps_structured_output_parse_errors(): + model_instance = _build_prepared_llm_mock() + model_instance.is_structured_output_parse_error.return_value = True + + def broken_stream(): + raise ValueError("bad json") + yield + + with pytest.raises(LLMNodeError, match="Failed to parse structured output: bad json"): + list( + LLMNode.handle_invoke_result( + invoke_result=broken_stream(), + file_saver=mock.MagicMock(spec=LLMFileSaver), + file_outputs=[], + node_id="node-1", + node_type=LLMNode.node_type, + model_instance=model_instance, + ) + ) + + +def test_handle_blocking_result_extracts_reasoning_and_structured_output(): + invoke_result = LLMResultWithStructuredOutput( + model="gpt-3.5-turbo", + prompt_messages=[], + message=AssistantPromptMessage(content="reasoningfinal answer"), + usage=LLMUsage.empty_usage(), + structured_output={"answer": "final answer"}, + ) + + event = LLMNode.handle_blocking_result( + invoke_result=invoke_result, + saver=mock.MagicMock(spec=LLMFileSaver), + file_outputs=[], + reasoning_format="separated", + request_latency=1.2345, + ) + + assert event.text == "final answer" + assert event.reasoning_content == "reasoning" + assert event.structured_output == {"answer": "final answer"} + assert event.usage.latency == 1.234 + + +def test_fetch_structured_output_schema_validates_payload(): + assert LLMNode.fetch_structured_output_schema(structured_output={"schema": {"type": "object"}}) == { + "type": "object" + } + + with pytest.raises(LLMNodeError, match="Please provide a valid structured output schema"): + LLMNode.fetch_structured_output_schema(structured_output={}) + + with pytest.raises(LLMNodeError, match="structured_output_schema must be a JSON object"): + LLMNode.fetch_structured_output_schema(structured_output={"schema": ["not", "an", "object"]}) + + +def test_extract_variable_selector_to_variable_mapping_includes_runtime_selectors(): + node_data = LLMNodeData( + title="Test LLM", + model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode="chat", completion_params={}), + prompt_template=[ + LLMNodeChatModelMessage( + text="Hello {{#input.name#}}", + role=PromptMessageRole.USER, + edition_type="basic", + ), + LLMNodeChatModelMessage( + text="ignored", + jinja2_text="Hello {{ name }}", + role=PromptMessageRole.SYSTEM, + edition_type="jinja2", + ), + ], + prompt_config=PromptConfig( + jinja2_variables=[VariableSelector(variable="name", value_selector=["input", "name"])] + ), + memory=MemoryConfig( + window=MemoryConfig.WindowConfig(enabled=True, size=1), + query_prompt_template="Repeat {{#sys.query#}}", + ), + context=ContextConfig(enabled=True, variable_selector=["context", "value"]), + vision=VisionConfig(enabled=True), + ) + + mapping = LLMNode._extract_variable_selector_to_variable_mapping( + graph_config={}, + node_id="llm-1", + node_data=node_data, + ) + + assert mapping == { + "llm-1.#input.name#": ["input", "name"], + "llm-1.#sys.query#": ["sys", "query"], + "llm-1.#context#": ["context", "value"], + "llm-1.#files#": ["sys", "files"], + "llm-1.name": ["input", "name"], + } + + +def test_render_jinja2_message_requires_renderer_and_passes_inputs(): + variable_pool = VariablePool.empty() + variable_pool.add(["input", "name"], "Dify") + variables = [VariableSelector(variable="name", value_selector=["input", "name"])] + + with pytest.raises( + TemplateRenderError, + match="LLMNode requires an injected jinja2_template_renderer for jinja2 prompts", + ): + _render_jinja2_message( + template="Hello {{ name }}", + jinja2_variables=variables, + variable_pool=variable_pool, + jinja2_template_renderer=None, + ) + + renderer = mock.MagicMock() + renderer.render_template.return_value = "Hello Dify" + + assert ( + _render_jinja2_message( + template="Hello {{ name }}", + jinja2_variables=variables, + variable_pool=variable_pool, + jinja2_template_renderer=renderer, + ) + == "Hello Dify" + ) + renderer.render_template.assert_called_once_with("Hello {{ name }}", {"name": "Dify"}) + + +def test_calculate_rest_token_uses_context_size_and_max_tokens(): + model_instance = _build_prepared_llm_mock() + model_instance.parameters = {"max_tokens": 512} + model_instance.get_model_schema.return_value = _build_model_schema( + model_properties={ModelPropertyKey.CONTEXT_SIZE: 4096}, + parameter_rules=[ + ParameterRule( + name="max_tokens", + label=I18nObject(en_US="Max Tokens"), + type=ParameterType.INT, + ) + ], + ) + model_instance.get_llm_num_tokens.return_value = 1000 + + assert ( + _calculate_rest_token( + prompt_messages=[UserPromptMessage(content="hello")], + model_instance=model_instance, + ) + == 2584 + ) + + +def test_handle_memory_chat_mode_uses_calculated_token_budget(): + memory = mock.MagicMock(spec=MockTokenBufferMemory) + history = [UserPromptMessage(content="question")] + memory.get_history_prompt_messages.return_value = history + + with mock.patch("graphon.nodes.llm.node._calculate_rest_token", return_value=321) as mock_rest_token: + result = _handle_memory_chat_mode( + memory=memory, + memory_config=MemoryConfig(window=MemoryConfig.WindowConfig(enabled=True, size=2)), + model_instance=_build_prepared_llm_mock(), + ) + + assert result == history + mock_rest_token.assert_called_once() + memory.get_history_prompt_messages.assert_called_once_with(max_token_limit=321, message_limit=2) + + +def test_dify_model_access_adapters_skip_runtime_build_when_managers_are_injected(): + run_context = DifyRunContext( + tenant_id="tenant", + app_id="app", + user_id="user", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + ) + + with mock.patch("core.app.llm.model_access.create_plugin_provider_manager") as mock_provider_manager_factory: + DifyCredentialsProvider(run_context=run_context, provider_manager=mock.MagicMock()) + DifyModelFactory(run_context=run_context, model_manager=mock.MagicMock()) + + mock_provider_manager_factory.assert_not_called() + + +def test_build_dify_model_access_binds_run_context_user_id_once(): + run_context = DifyRunContext( + tenant_id="tenant", + app_id="app", + user_id="user", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + ) + + with mock.patch("core.app.llm.model_access.create_plugin_provider_manager") as mock_provider_manager: + build_dify_model_access(run_context) + + mock_provider_manager.assert_called_once_with(tenant_id="tenant", user_id="user") + + +def test_dify_model_access_requires_run_context_argument(): + with pytest.raises(TypeError): + DifyCredentialsProvider() + + with pytest.raises(TypeError): + DifyModelFactory() diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_scenarios.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_scenarios.py index e40d565ef5..af1cff4e81 100644 --- a/api/tests/unit_tests/core/workflow/nodes/llm/test_scenarios.py +++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_scenarios.py @@ -2,10 +2,10 @@ from collections.abc import Mapping, Sequence from pydantic import BaseModel, Field -from dify_graph.file import File -from dify_graph.model_runtime.entities.message_entities import PromptMessage -from dify_graph.model_runtime.entities.model_entities import ModelFeature -from dify_graph.nodes.llm.entities import LLMNodeChatModelMessage +from graphon.file import File +from graphon.model_runtime.entities.message_entities import PromptMessage +from graphon.model_runtime.entities.model_entities import ModelFeature +from graphon.nodes.llm.entities import LLMNodeChatModelMessage class LLMNodeTestScenario(BaseModel): diff --git a/api/tests/unit_tests/core/workflow/nodes/parameter_extractor/test_entities.py b/api/tests/unit_tests/core/workflow/nodes/parameter_extractor/test_entities.py index fd48edc58c..ccf1077838 100644 --- a/api/tests/unit_tests/core/workflow/nodes/parameter_extractor/test_entities.py +++ b/api/tests/unit_tests/core/workflow/nodes/parameter_extractor/test_entities.py @@ -1,5 +1,5 @@ -from dify_graph.nodes.parameter_extractor.entities import ParameterConfig -from dify_graph.variables.types import SegmentType +from graphon.nodes.parameter_extractor.entities import ParameterConfig +from graphon.variables.types import SegmentType class TestParameterConfig: diff --git a/api/tests/unit_tests/core/workflow/nodes/parameter_extractor/test_parameter_extractor_node.py b/api/tests/unit_tests/core/workflow/nodes/parameter_extractor/test_parameter_extractor_node.py index 7eca531b62..8f8ec49f14 100644 --- a/api/tests/unit_tests/core/workflow/nodes/parameter_extractor/test_parameter_extractor_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/parameter_extractor/test_parameter_extractor_node.py @@ -7,18 +7,18 @@ from typing import Any import pytest -from dify_graph.model_runtime.entities import LLMMode -from dify_graph.nodes.llm import ModelConfig, VisionConfig -from dify_graph.nodes.parameter_extractor.entities import ParameterConfig, ParameterExtractorNodeData -from dify_graph.nodes.parameter_extractor.exc import ( +from factories.variable_factory import build_segment_with_type +from graphon.model_runtime.entities import LLMMode +from graphon.nodes.llm import ModelConfig, VisionConfig +from graphon.nodes.parameter_extractor.entities import ParameterConfig, ParameterExtractorNodeData +from graphon.nodes.parameter_extractor.exc import ( InvalidNumberOfParametersError, InvalidSelectValueError, InvalidValueTypeError, RequiredParameterMissingError, ) -from dify_graph.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode -from dify_graph.variables.types import SegmentType -from factories.variable_factory import build_segment_with_type +from graphon.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode +from graphon.variables.types import SegmentType @dataclass diff --git a/api/tests/unit_tests/core/workflow/nodes/template_transform/entities_spec.py b/api/tests/unit_tests/core/workflow/nodes/template_transform/entities_spec.py index e57ebbd83e..01878ed692 100644 --- a/api/tests/unit_tests/core/workflow/nodes/template_transform/entities_spec.py +++ b/api/tests/unit_tests/core/workflow/nodes/template_transform/entities_spec.py @@ -1,8 +1,8 @@ import pytest from pydantic import ValidationError -from dify_graph.enums import ErrorStrategy -from dify_graph.nodes.template_transform.entities import TemplateTransformNodeData +from graphon.enums import ErrorStrategy +from graphon.nodes.template_transform.entities import TemplateTransformNodeData class TestTemplateTransformNodeData: diff --git a/api/tests/unit_tests/core/workflow/nodes/template_transform/template_transform_node_spec.py b/api/tests/unit_tests/core/workflow/nodes/template_transform/template_transform_node_spec.py index 332a8761f9..bc44ececd8 100644 --- a/api/tests/unit_tests/core/workflow/nodes/template_transform/template_transform_node_spec.py +++ b/api/tests/unit_tests/core/workflow/nodes/template_transform/template_transform_node_spec.py @@ -3,11 +3,13 @@ from unittest.mock import MagicMock import pytest from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom -from dify_graph.enums import BuiltinNodeTypes, ErrorStrategy, WorkflowNodeExecutionStatus -from dify_graph.graph import Graph -from dify_graph.nodes.template_transform.template_renderer import TemplateRenderError -from dify_graph.nodes.template_transform.template_transform_node import TemplateTransformNode -from dify_graph.runtime import GraphRuntimeState +from graphon.enums import BuiltinNodeTypes, ErrorStrategy, WorkflowNodeExecutionStatus +from graphon.graph import Graph +from graphon.nodes.base.entities import VariableSelector +from graphon.nodes.template_transform.entities import TemplateTransformNodeData +from graphon.nodes.template_transform.template_transform_node import TemplateTransformNode +from graphon.runtime import GraphRuntimeState +from graphon.template_rendering import TemplateRenderError from tests.workflow_test_utils import build_test_graph_init_params @@ -62,7 +64,7 @@ class TestTemplateTransformNode: config={"id": "test_node", "data": basic_node_data}, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, - template_renderer=mock_renderer, + jinja2_template_renderer=mock_renderer, ) assert node.node_type == BuiltinNodeTypes.TEMPLATE_TRANSFORM @@ -78,7 +80,7 @@ class TestTemplateTransformNode: config={"id": "test_node", "data": basic_node_data}, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, - template_renderer=mock_renderer, + jinja2_template_renderer=mock_renderer, ) assert node._get_title() == "Template Transform" @@ -91,7 +93,7 @@ class TestTemplateTransformNode: config={"id": "test_node", "data": basic_node_data}, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, - template_renderer=mock_renderer, + jinja2_template_renderer=mock_renderer, ) assert node._get_description() == "Transform data using template" @@ -111,7 +113,7 @@ class TestTemplateTransformNode: config={"id": "test_node", "data": node_data}, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, - template_renderer=mock_renderer, + jinja2_template_renderer=mock_renderer, ) assert node._get_error_strategy() == ErrorStrategy.FAIL_BRANCH @@ -130,6 +132,26 @@ class TestTemplateTransformNode: """Test version class method.""" assert TemplateTransformNode.version() == "1" + @pytest.mark.parametrize("max_output_length", [0, -1]) + def test_node_initialization_rejects_non_positive_max_output_length( + self, + basic_node_data, + mock_graph_runtime_state, + graph_init_params, + max_output_length, + ): + mock_renderer = MagicMock() + + with pytest.raises(ValueError, match="max_output_length must be a positive integer"): + TemplateTransformNode( + id="test_node", + config={"id": "test_node", "data": basic_node_data}, + graph_init_params=graph_init_params, + graph_runtime_state=mock_graph_runtime_state, + jinja2_template_renderer=mock_renderer, + max_output_length=max_output_length, + ) + def test_run_simple_template(self, basic_node_data, mock_graph_runtime_state, graph_init_params): """Test _run with simple template transformation using injected renderer.""" # Setup mock variable pool @@ -153,7 +175,7 @@ class TestTemplateTransformNode: config={"id": "test_node", "data": basic_node_data}, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, - template_renderer=mock_renderer, + jinja2_template_renderer=mock_renderer, ) result = node._run() @@ -181,7 +203,7 @@ class TestTemplateTransformNode: config={"id": "test_node", "data": node_data}, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, - template_renderer=mock_renderer, + jinja2_template_renderer=mock_renderer, ) result = node._run() @@ -201,7 +223,7 @@ class TestTemplateTransformNode: config={"id": "test_node", "data": basic_node_data}, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, - template_renderer=mock_renderer, + jinja2_template_renderer=mock_renderer, ) result = node._run() @@ -221,7 +243,7 @@ class TestTemplateTransformNode: config={"id": "test_node", "data": basic_node_data}, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, - template_renderer=mock_renderer, + jinja2_template_renderer=mock_renderer, max_output_length=10, ) @@ -230,6 +252,28 @@ class TestTemplateTransformNode: assert result.status == WorkflowNodeExecutionStatus.FAILED assert "Output length exceeds" in result.error + def test_run_output_length_equal_to_limit_succeeds( + self, basic_node_data, mock_graph_runtime_state, graph_init_params + ): + mock_graph_runtime_state.variable_pool.get.return_value = MagicMock() + + mock_renderer = MagicMock() + mock_renderer.render_template.return_value = "1234567890" + + node = TemplateTransformNode( + id="test_node", + config={"id": "test_node", "data": basic_node_data}, + graph_init_params=graph_init_params, + graph_runtime_state=mock_graph_runtime_state, + jinja2_template_renderer=mock_renderer, + max_output_length=10, + ) + + result = node._run() + + assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert result.outputs["output"] == "1234567890" + def test_run_with_complex_jinja2_template(self, mock_graph_runtime_state, graph_init_params): """Test _run with complex Jinja2 template including loops and conditions.""" node_data = { @@ -263,7 +307,7 @@ class TestTemplateTransformNode: config={"id": "test_node", "data": node_data}, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, - template_renderer=mock_renderer, + jinja2_template_renderer=mock_renderer, ) result = node._run() @@ -291,6 +335,69 @@ class TestTemplateTransformNode: assert mapping["node_123.var1"] == ["sys", "input1"] assert mapping["node_123.var2"] == ["sys", "input2"] + def test_extract_variable_selector_to_variable_mapping_accepts_validated_node_data(self): + node_data = TemplateTransformNodeData( + title="Test", + variables=[VariableSelector(variable="var1", value_selector=["sys", "input1"])], + template="{{ var1 }}", + ) + + mapping = TemplateTransformNode._extract_variable_selector_to_variable_mapping( + graph_config={}, node_id="node_123", node_data=node_data + ) + + assert mapping == {"node_123.var1": ["sys", "input1"]} + + def test_extract_variable_selector_to_variable_mapping_returns_empty_mapping_without_variables(self): + node_data = { + "title": "Test", + "template": "{{ missing }}", + } + + mapping = TemplateTransformNode._extract_variable_selector_to_variable_mapping( + graph_config={}, node_id="node_123", node_data=node_data + ) + + assert mapping == {} + + def test_extract_variable_selector_to_variable_mapping_accepts_sequence_value_selectors(self): + node_data = { + "title": "Test", + "variables": [ + {"variable": "var1", "value_selector": ("sys", "input1")}, + {"variable": "empty_selector", "value_selector": ()}, + ], + "template": "{{ var1 }}", + } + + mapping = TemplateTransformNode._extract_variable_selector_to_variable_mapping( + graph_config={}, node_id="node_123", node_data=node_data + ) + + assert mapping == { + "node_123.var1": ["sys", "input1"], + "node_123.empty_selector": [], + } + + def test_extract_variable_selector_to_variable_mapping_ignores_invalid_entries(self): + node_data = { + "title": "Test", + "variables": [ + {"variable": "var1", "value_selector": ["sys", "input1"]}, + {"variable": "missing_selector"}, + ["not", "a", "mapping"], + {"variable": 1, "value_selector": ["sys", "input2"]}, + {"variable": "invalid_selector", "value_selector": ["sys", 2]}, + ], + "template": "{{ var1 }}", + } + + mapping = TemplateTransformNode._extract_variable_selector_to_variable_mapping( + graph_config={}, node_id="node_123", node_data=node_data + ) + + assert mapping == {"node_123.var1": ["sys", "input1"]} + def test_run_with_empty_variables(self, mock_graph_runtime_state, graph_init_params): """Test _run with no variables (static template).""" node_data = { @@ -307,7 +414,7 @@ class TestTemplateTransformNode: config={"id": "test_node", "data": node_data}, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, - template_renderer=mock_renderer, + jinja2_template_renderer=mock_renderer, ) result = node._run() @@ -346,7 +453,7 @@ class TestTemplateTransformNode: config={"id": "test_node", "data": node_data}, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, - template_renderer=mock_renderer, + jinja2_template_renderer=mock_renderer, ) result = node._run() @@ -375,7 +482,7 @@ class TestTemplateTransformNode: config={"id": "test_node", "data": node_data}, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, - template_renderer=mock_renderer, + jinja2_template_renderer=mock_renderer, ) result = node._run() @@ -405,7 +512,7 @@ class TestTemplateTransformNode: config={"id": "test_node", "data": node_data}, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, - template_renderer=mock_renderer, + jinja2_template_renderer=mock_renderer, ) result = node._run() diff --git a/api/tests/unit_tests/core/workflow/nodes/template_transform/test_template_transform_node.py b/api/tests/unit_tests/core/workflow/nodes/template_transform/test_template_transform_node.py new file mode 100644 index 0000000000..636237e56e --- /dev/null +++ b/api/tests/unit_tests/core/workflow/nodes/template_transform/test_template_transform_node.py @@ -0,0 +1,74 @@ +from unittest.mock import MagicMock + +import pytest + +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom +from graphon.nodes.base.entities import VariableSelector +from graphon.nodes.template_transform.template_transform_node import ( + DEFAULT_TEMPLATE_TRANSFORM_MAX_OUTPUT_LENGTH, + TemplateTransformNode, +) +from graphon.runtime import GraphRuntimeState +from tests.workflow_test_utils import build_test_graph_init_params + +from .template_transform_node_spec import TestTemplateTransformNode # noqa: F401 + + +@pytest.fixture +def graph_init_params(): + return build_test_graph_init_params( + workflow_id="test_workflow", + graph_config={}, + tenant_id="test_tenant", + app_id="test_app", + user_id="test_user", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + call_depth=0, + ) + + +@pytest.fixture +def mock_graph_runtime_state(): + mock_state = MagicMock(spec=GraphRuntimeState) + mock_state.variable_pool = MagicMock() + return mock_state + + +def test_node_uses_default_max_output_length_when_not_overridden(graph_init_params, mock_graph_runtime_state): + node = TemplateTransformNode( + id="test_node", + config={ + "id": "test_node", + "data": { + "title": "Template Transform", + "variables": [], + "template": "hello", + }, + }, + graph_init_params=graph_init_params, + graph_runtime_state=mock_graph_runtime_state, + jinja2_template_renderer=MagicMock(), + ) + + assert node._max_output_length == DEFAULT_TEMPLATE_TRANSFORM_MAX_OUTPUT_LENGTH + + +def test_extract_variable_selector_to_variable_mapping_accepts_mixed_valid_entries(): + mapping = TemplateTransformNode._extract_variable_selector_to_variable_mapping( + graph_config={"ignored": True}, + node_id="node_123", + node_data={ + "variables": [ + VariableSelector(variable="validated", value_selector=["sys", "input1"]), + {"variable": "raw", "value_selector": ("sys", "input2")}, + {"variable": "invalid_selector", "value_selector": ["sys", 3]}, + ["not", "a", "mapping"], + ] + }, + ) + + assert mapping == { + "node_123.validated": ["sys", "input1"], + "node_123.raw": ["sys", "input2"], + } diff --git a/api/tests/unit_tests/core/workflow/nodes/test_base_node.py b/api/tests/unit_tests/core/workflow/nodes/test_base_node.py index 2b0205fb7b..0522dd9d14 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_base_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_base_node.py @@ -3,13 +3,14 @@ from collections.abc import Mapping import pytest from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom -from dify_graph.entities import GraphInitParams -from dify_graph.entities.base_node_data import BaseNodeData -from dify_graph.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter -from dify_graph.enums import BuiltinNodeTypes -from dify_graph.nodes.base.node import Node -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable +from core.workflow.node_runtime import resolve_dify_run_context +from core.workflow.system_variables import build_system_variables +from graphon.entities import GraphInitParams +from graphon.entities.base_node_data import BaseNodeData +from graphon.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter +from graphon.enums import BuiltinNodeTypes +from graphon.nodes.base.node import Node +from graphon.runtime import GraphRuntimeState, VariablePool from tests.workflow_test_utils import build_test_graph_init_params @@ -35,7 +36,7 @@ def _build_context(graph_config: Mapping[str, object]) -> tuple[GraphInitParams, invoke_from="debugger", ) runtime_state = GraphRuntimeState( - variable_pool=VariablePool(system_variables=SystemVariable(user_id="user", files=[]), user_inputs={}), + variable_pool=VariablePool(system_variables=build_system_variables(user_id="user", files=[]), user_inputs={}), start_at=0.0, ) return init_params, runtime_state @@ -67,7 +68,7 @@ def test_node_hydrates_data_during_initialization(): assert node.node_data.foo == "bar" assert node.title == "Sample" - dify_ctx = node.require_dify_context() + dify_ctx = resolve_dify_run_context(node.run_context) assert dify_ctx.user_from == "account" assert dify_ctx.invoke_from == "debugger" @@ -80,7 +81,7 @@ def test_node_accepts_invoke_from_enum(): invoke_from=InvokeFrom.DEBUGGER, ) runtime_state = GraphRuntimeState( - variable_pool=VariablePool(system_variables=SystemVariable(user_id="user", files=[]), user_inputs={}), + variable_pool=VariablePool(system_variables=build_system_variables(user_id="user", files=[]), user_inputs={}), start_at=0.0, ) @@ -91,7 +92,7 @@ def test_node_accepts_invoke_from_enum(): graph_runtime_state=runtime_state, ) - dify_ctx = node.require_dify_context() + dify_ctx = resolve_dify_run_context(node.run_context) assert dify_ctx.user_from == UserFrom.ACCOUNT assert dify_ctx.invoke_from == InvokeFrom.DEBUGGER assert node.get_run_context_value("missing") is None @@ -127,3 +128,29 @@ def test_base_node_data_keeps_dict_style_access_compatibility(): assert node_data["foo"] == "bar" assert node_data.get("foo") == "bar" assert node_data.get("missing", "fallback") == "fallback" + + +def test_node_hydration_preserves_compatibility_extra_fields(): + graph_config: dict[str, object] = {} + init_params, runtime_state = _build_context(graph_config) + node_config = NodeConfigDictAdapter.validate_python( + { + "id": "node-1", + "data": { + "type": BuiltinNodeTypes.ANSWER, + "title": "Sample", + "foo": "bar", + "compat_flag": True, + }, + } + ) + + node = _SampleNode( + id="node-1", + config=node_config, + graph_init_params=init_params, + graph_runtime_state=runtime_state, + ) + + assert node.node_data.foo == "bar" + assert node.node_data.get("compat_flag") is True diff --git a/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py b/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py index 40754974c1..87ec2d5bce 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py @@ -6,21 +6,21 @@ import pytest from docx.oxml.text.paragraph import CT_P from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom -from dify_graph.entities import GraphInitParams -from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus -from dify_graph.file import File, FileTransferMethod -from dify_graph.node_events import NodeRunResult -from dify_graph.nodes.document_extractor import DocumentExtractorNode, DocumentExtractorNodeData -from dify_graph.nodes.document_extractor.node import ( +from graphon.entities import GraphInitParams +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus +from graphon.file import File, FileTransferMethod +from graphon.node_events import NodeRunResult +from graphon.nodes.document_extractor import DocumentExtractorNode, DocumentExtractorNodeData +from graphon.nodes.document_extractor.node import ( _extract_text_from_docx, _extract_text_from_excel, _extract_text_from_pdf, _extract_text_from_plain_text, _normalize_docx_zip, ) -from dify_graph.variables import ArrayFileSegment -from dify_graph.variables.segments import ArrayStringSegment -from dify_graph.variables.variables import StringVariable +from graphon.variables import ArrayFileSegment +from graphon.variables.segments import ArrayStringSegment +from graphon.variables.variables import StringVariable from tests.workflow_test_utils import build_test_graph_init_params @@ -183,14 +183,14 @@ def test_run_extract_text( mock_response.raise_for_status = Mock() document_extractor_node._http_client.get = Mock(return_value=mock_response) - monkeypatch.setattr("dify_graph.file.file_manager.download", mock_download) + monkeypatch.setattr("graphon.file.file_manager.download", mock_download) if mime_type == "application/pdf": mock_pdf_extract = Mock(return_value=expected_text[0]) - monkeypatch.setattr("dify_graph.nodes.document_extractor.node._extract_text_from_pdf", mock_pdf_extract) + monkeypatch.setattr("graphon.nodes.document_extractor.node._extract_text_from_pdf", mock_pdf_extract) elif mime_type.startswith("application/vnd.openxmlformats"): mock_docx_extract = Mock(return_value=expected_text[0]) - monkeypatch.setattr("dify_graph.nodes.document_extractor.node._extract_text_from_docx", mock_docx_extract) + monkeypatch.setattr("graphon.nodes.document_extractor.node._extract_text_from_docx", mock_docx_extract) result = document_extractor_node._run() diff --git a/api/tests/unit_tests/core/workflow/nodes/test_if_else.py b/api/tests/unit_tests/core/workflow/nodes/test_if_else.py index c746a945fe..782750e02e 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_if_else.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_if_else.py @@ -4,19 +4,18 @@ from unittest.mock import MagicMock, Mock import pytest -from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, InvokeFrom, UserFrom from core.workflow.node_factory import DifyNodeFactory -from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY -from dify_graph.enums import WorkflowNodeExecutionStatus -from dify_graph.file import File, FileTransferMethod, FileType -from dify_graph.graph import Graph -from dify_graph.nodes.if_else.entities import IfElseNodeData -from dify_graph.nodes.if_else.if_else_node import IfElseNode -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable -from dify_graph.utils.condition.entities import Condition, SubCondition, SubVariableCondition -from dify_graph.variables import ArrayFileSegment +from core.workflow.system_variables import build_system_variables from extensions.ext_database import db +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.file import File, FileTransferMethod, FileType +from graphon.graph import Graph +from graphon.nodes.if_else.entities import IfElseNodeData +from graphon.nodes.if_else.if_else_node import IfElseNode +from graphon.runtime import GraphRuntimeState, VariablePool +from graphon.utils.condition.entities import Condition, SubCondition, SubVariableCondition +from graphon.variables import ArrayFileSegment from tests.workflow_test_utils import build_test_graph_init_params @@ -35,7 +34,7 @@ def test_execute_if_else_result_true(): ) # construct variable pool - pool = VariablePool(system_variables=SystemVariable(user_id="aaa", files=[]), user_inputs={}) + pool = VariablePool(system_variables=build_system_variables(user_id="aaa", files=[]), user_inputs={}) pool.add(["start", "array_contains"], ["ab", "def"]) pool.add(["start", "array_not_contains"], ["ac", "def"]) pool.add(["start", "contains"], "cabcde") @@ -142,7 +141,7 @@ def test_execute_if_else_result_false(): # construct variable pool pool = VariablePool( - system_variables=SystemVariable(user_id="aaa", files=[]), + system_variables=build_system_variables(user_id="aaa", files=[]), user_inputs={}, environment_variables=[], ) @@ -253,7 +252,6 @@ def test_array_file_contains_file_name(): node.graph_runtime_state.variable_pool.get.return_value = ArrayFileSegment( value=[ File( - tenant_id="1", type=FileType.IMAGE, transfer_method=FileTransferMethod.LOCAL_FILE, related_id="1", @@ -316,7 +314,7 @@ def test_execute_if_else_boolean_conditions(condition: Condition): # construct variable pool with boolean values pool = VariablePool( - system_variables=SystemVariable(files=[], user_id="aaa"), + system_variables=build_system_variables(files=[], user_id="aaa"), ) pool.add(["start", "bool_true"], True) pool.add(["start", "bool_false"], False) @@ -371,7 +369,7 @@ def test_execute_if_else_boolean_false_conditions(): # construct variable pool with boolean values pool = VariablePool( - system_variables=SystemVariable(files=[], user_id="aaa"), + system_variables=build_system_variables(files=[], user_id="aaa"), ) pool.add(["start", "bool_true"], True) pool.add(["start", "bool_false"], False) @@ -440,7 +438,7 @@ def test_execute_if_else_boolean_cases_structure(): # construct variable pool with boolean values pool = VariablePool( - system_variables=SystemVariable(files=[], user_id="aaa"), + system_variables=build_system_variables(files=[], user_id="aaa"), ) pool.add(["start", "bool_true"], True) pool.add(["start", "bool_false"], False) diff --git a/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py b/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py index 6ca72b64b2..b217e4e8e7 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py @@ -2,11 +2,10 @@ from unittest.mock import MagicMock import pytest -from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom -from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY -from dify_graph.enums import WorkflowNodeExecutionStatus -from dify_graph.file import File, FileTransferMethod, FileType -from dify_graph.nodes.list_operator.entities import ( +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, InvokeFrom, UserFrom +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.file import File, FileTransferMethod, FileType +from graphon.nodes.list_operator.entities import ( ExtractConfig, FilterBy, FilterCondition, @@ -15,9 +14,9 @@ from dify_graph.nodes.list_operator.entities import ( Order, OrderByConfig, ) -from dify_graph.nodes.list_operator.exc import InvalidKeyError -from dify_graph.nodes.list_operator.node import ListOperatorNode, _get_file_extract_string_func -from dify_graph.variables import ArrayFileSegment +from graphon.nodes.list_operator.exc import InvalidKeyError +from graphon.nodes.list_operator.node import ListOperatorNode, _get_file_extract_string_func +from graphon.variables import ArrayFileSegment @pytest.fixture @@ -72,7 +71,6 @@ def test_filter_files_by_type(list_operator_node): File( filename="image1.jpg", type=FileType.IMAGE, - tenant_id="tenant1", transfer_method=FileTransferMethod.LOCAL_FILE, related_id="related1", storage_key="", @@ -80,7 +78,6 @@ def test_filter_files_by_type(list_operator_node): File( filename="document1.pdf", type=FileType.DOCUMENT, - tenant_id="tenant1", transfer_method=FileTransferMethod.LOCAL_FILE, related_id="related2", storage_key="", @@ -88,7 +85,6 @@ def test_filter_files_by_type(list_operator_node): File( filename="image2.png", type=FileType.IMAGE, - tenant_id="tenant1", transfer_method=FileTransferMethod.LOCAL_FILE, related_id="related3", storage_key="", @@ -96,7 +92,6 @@ def test_filter_files_by_type(list_operator_node): File( filename="audio1.mp3", type=FileType.AUDIO, - tenant_id="tenant1", transfer_method=FileTransferMethod.LOCAL_FILE, related_id="related4", storage_key="", @@ -120,14 +115,12 @@ def test_filter_files_by_type(list_operator_node): { "filename": "document1.pdf", "type": FileType.DOCUMENT, - "tenant_id": "tenant1", "transfer_method": FileTransferMethod.LOCAL_FILE, "related_id": "related2", }, { "filename": "image2.png", "type": FileType.IMAGE, - "tenant_id": "tenant1", "transfer_method": FileTransferMethod.LOCAL_FILE, "related_id": "related3", }, @@ -136,7 +129,6 @@ def test_filter_files_by_type(list_operator_node): for expected_file, result_file in zip(expected_files, result.outputs["result"].value): assert expected_file["filename"] == result_file.filename assert expected_file["type"] == result_file.type - assert expected_file["tenant_id"] == result_file.tenant_id assert expected_file["transfer_method"] == result_file.transfer_method assert expected_file["related_id"] == result_file.related_id @@ -144,7 +136,6 @@ def test_filter_files_by_type(list_operator_node): def test_get_file_extract_string_func(): # Create a File object file = File( - tenant_id="test_tenant", type=FileType.DOCUMENT, transfer_method=FileTransferMethod.LOCAL_FILE, filename="test_file.txt", @@ -165,7 +156,6 @@ def test_get_file_extract_string_func(): # Test with empty values empty_file = File( - tenant_id="test_tenant", type=FileType.DOCUMENT, transfer_method=FileTransferMethod.LOCAL_FILE, filename=None, diff --git a/api/tests/unit_tests/core/workflow/nodes/test_loop_node.py b/api/tests/unit_tests/core/workflow/nodes/test_loop_node.py index 6372583839..d613ba154a 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_loop_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_loop_node.py @@ -1,6 +1,22 @@ -from dify_graph.entities.graph_config import NodeConfigDictAdapter -from dify_graph.nodes.loop.entities import LoopNodeData -from dify_graph.nodes.loop.loop_node import LoopNode +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest + +from graphon.entities.graph_config import NodeConfigDictAdapter +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.graph_events import GraphRunAbortedEvent +from graphon.model_runtime.entities.llm_entities import LLMUsage +from graphon.node_events import LoopFailedEvent, LoopStartedEvent, StreamCompletedEvent +from graphon.nodes.loop.entities import LoopNodeData +from graphon.nodes.loop.loop_node import LoopNode +from tests.workflow_test_utils import build_test_variable_pool + + +def _usage_with_tokens(total_tokens: int) -> LLMUsage: + usage = LLMUsage.empty_usage() + usage.total_tokens = total_tokens + return usage def test_extract_variable_selector_to_variable_mapping_validates_child_node_configs(monkeypatch) -> None: @@ -50,3 +66,85 @@ def test_extract_variable_selector_to_variable_mapping_validates_child_node_conf ) assert seen_configs == [child_node_config] + + +def test_run_single_loop_raises_on_child_abort_event() -> None: + node = LoopNode.__new__(LoopNode) + node._node_id = "loop-node" + node._node_data = LoopNodeData( + title="Loop", + loop_count=1, + break_conditions=[], + logical_operator="and", + start_node_id="child-start", + ) + + graph_engine = SimpleNamespace( + run=lambda: iter([GraphRunAbortedEvent(reason="quota exceeded")]), + ) + + with pytest.raises(RuntimeError, match="quota exceeded"): + list(node._run_single_loop(graph_engine=graph_engine, current_index=0)) + + +def test_loop_run_fails_on_child_abort_and_stops_subsequent_rounds() -> None: + node = LoopNode.__new__(LoopNode) + node._node_id = "loop-node" + node._node_data = LoopNodeData( + title="Loop", + loop_count=2, + break_conditions=[], + logical_operator="and", + start_node_id="child-start", + ) + node.graph_config = {"nodes": [], "edges": []} + node.graph_runtime_state = SimpleNamespace( + variable_pool=build_test_variable_pool(), + llm_usage=LLMUsage.empty_usage(), + ) + + aborting_engine = SimpleNamespace( + graph_runtime_state=SimpleNamespace(outputs={}, llm_usage=LLMUsage.empty_usage()), + ) + create_graph_engine = MagicMock(return_value=aborting_engine) + node._create_graph_engine = create_graph_engine + node._run_single_loop = lambda *, graph_engine, current_index: (_ for _ in ()).throw(RuntimeError("quota exceeded")) + + events = list(node._run()) + + assert isinstance(events[0], LoopStartedEvent) + assert isinstance(events[1], LoopFailedEvent) + assert events[1].error == "quota exceeded" + assert isinstance(events[2], StreamCompletedEvent) + assert events[2].node_run_result.status == WorkflowNodeExecutionStatus.FAILED + assert events[2].node_run_result.error == "quota exceeded" + create_graph_engine.assert_called_once() + + +def test_loop_run_merges_child_usage_before_failing_on_child_abort() -> None: + node = LoopNode.__new__(LoopNode) + node._node_id = "loop-node" + node._node_data = LoopNodeData( + title="Loop", + loop_count=1, + break_conditions=[], + logical_operator="and", + start_node_id="child-start", + ) + node.graph_config = {"nodes": [], "edges": []} + node.graph_runtime_state = SimpleNamespace( + variable_pool=build_test_variable_pool(), + llm_usage=LLMUsage.empty_usage(), + ) + + aborting_engine = SimpleNamespace( + graph_runtime_state=SimpleNamespace(outputs={}, llm_usage=_usage_with_tokens(7)), + ) + node._create_graph_engine = MagicMock(return_value=aborting_engine) + node._run_single_loop = lambda *, graph_engine, current_index: (_ for _ in ()).throw(RuntimeError("quota exceeded")) + + events = list(node._run()) + + assert isinstance(events[-1], StreamCompletedEvent) + assert events[-1].node_run_result.llm_usage.total_tokens == 7 + assert node.graph_runtime_state.llm_usage.total_tokens == 7 diff --git a/api/tests/unit_tests/core/workflow/nodes/test_question_classifier_node.py b/api/tests/unit_tests/core/workflow/nodes/test_question_classifier_node.py index c5a02e87e4..efbf786a55 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_question_classifier_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_question_classifier_node.py @@ -1,13 +1,14 @@ from types import SimpleNamespace from unittest.mock import MagicMock -from dify_graph.model_runtime.entities import ImagePromptMessageContent -from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory, TemplateRenderer -from dify_graph.nodes.protocols import HttpClientProtocol -from dify_graph.nodes.question_classifier import ( +from graphon.model_runtime.entities import ImagePromptMessageContent +from graphon.nodes.llm.protocols import CredentialsProvider, ModelFactory +from graphon.nodes.protocols import HttpClientProtocol +from graphon.nodes.question_classifier import ( QuestionClassifierNode, QuestionClassifierNodeData, ) +from graphon.template_rendering import Jinja2TemplateRenderer from tests.workflow_test_utils import build_test_graph_init_params @@ -86,7 +87,7 @@ def test_question_classifier_calculate_rest_token_uses_shared_prompt_builder(mon "instruction": "This is a test instruction", } ) - template_renderer = MagicMock(spec=TemplateRenderer) + template_renderer = MagicMock(spec=Jinja2TemplateRenderer) node = QuestionClassifierNode( id="node-id", config={"id": "node-id", "data": node_data.model_dump(mode="json")}, @@ -107,11 +108,11 @@ def test_question_classifier_calculate_rest_token_uses_shared_prompt_builder(mon ) fetch_prompt_messages = MagicMock(return_value=([], None)) monkeypatch.setattr( - "dify_graph.nodes.question_classifier.question_classifier_node.llm_utils.fetch_prompt_messages", + "graphon.nodes.question_classifier.question_classifier_node.llm_utils.fetch_prompt_messages", fetch_prompt_messages, ) monkeypatch.setattr( - "dify_graph.nodes.question_classifier.question_classifier_node.llm_utils.fetch_model_schema", + "graphon.nodes.question_classifier.question_classifier_node.llm_utils.fetch_model_schema", MagicMock(return_value=SimpleNamespace(model_properties={}, parameter_rules=[])), ) diff --git a/api/tests/unit_tests/core/workflow/nodes/test_start_node_json_object.py b/api/tests/unit_tests/core/workflow/nodes/test_start_node_json_object.py index b8f0e25e91..543f9878de 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_start_node_json_object.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_start_node_json_object.py @@ -4,19 +4,22 @@ import time import pytest from pydantic import ValidationError as PydanticValidationError -from dify_graph.nodes.start.entities import StartNodeData -from dify_graph.nodes.start.start_node import StartNode -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable -from dify_graph.variables.input_entities import VariableEntity, VariableEntityType -from tests.workflow_test_utils import build_test_graph_init_params +from core.workflow.system_variables import build_system_variables +from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID +from graphon.nodes.start.entities import StartNodeData +from graphon.nodes.start.start_node import StartNode +from graphon.runtime import GraphRuntimeState +from graphon.variables import build_segment, segment_to_variable +from graphon.variables.input_entities import VariableEntity, VariableEntityType +from graphon.variables.variables import Variable +from tests.workflow_test_utils import build_test_graph_init_params, build_test_variable_pool def make_start_node(user_inputs, variables): - variable_pool = VariablePool( - system_variables=SystemVariable(), - user_inputs=user_inputs, - conversation_variables=[], + variable_pool = build_test_variable_pool( + variables=build_system_variables(), + node_id="start", + inputs=user_inputs, ) config = { @@ -232,3 +235,64 @@ def test_json_object_optional_variable_not_provided(): # Current implementation raises a validation error even when the variable is optional with pytest.raises(ValueError, match="profile is required in input form"): node._run() + + +def test_start_node_outputs_full_variable_pool_snapshot(): + variable_pool = build_test_variable_pool( + variables=[ + *build_system_variables(query="hello", workflow_run_id="run-123"), + _build_prefixed_variable(ENVIRONMENT_VARIABLE_NODE_ID, "API_KEY", "secret"), + _build_prefixed_variable(CONVERSATION_VARIABLE_NODE_ID, "session_id", "conversation-1"), + ], + node_id="start", + inputs={"profile": {"age": 20, "name": "Tom"}}, + ) + + config = { + "id": "start", + "data": StartNodeData( + title="Start", + variables=[ + VariableEntity( + variable="profile", + label="profile", + type=VariableEntityType.JSON_OBJECT, + required=True, + ) + ], + ).model_dump(), + } + + graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) + node = StartNode( + id="start", + config=config, + graph_init_params=build_test_graph_init_params( + workflow_id="wf", + graph_config={}, + tenant_id="tenant", + app_id="app", + user_id="u", + user_from="account", + invoke_from="debugger", + call_depth=0, + ), + graph_runtime_state=graph_runtime_state, + ) + + result = node._run() + + assert result.inputs == {"profile": {"age": 20, "name": "Tom"}} + assert result.outputs["profile"] == {"age": 20, "name": "Tom"} + assert result.outputs["sys.query"] == "hello" + assert result.outputs["sys.workflow_run_id"] == "run-123" + assert result.outputs["env.API_KEY"] == "secret" + assert result.outputs["conversation.session_id"] == "conversation-1" + + +def _build_prefixed_variable(node_id: str, name: str, value: object) -> Variable: + return segment_to_variable( + segment=build_segment(value), + selector=(node_id, name), + name=name, + ) diff --git a/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py b/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py index 3cbd96dfef..c806181340 100644 --- a/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py @@ -3,23 +3,55 @@ from __future__ import annotations import sys import types from collections.abc import Generator +from types import SimpleNamespace from typing import TYPE_CHECKING, Any -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock import pytest -from core.tools.entities.tool_entities import ToolInvokeMessage -from core.tools.utils.message_transformer import ToolFileMessageTransformer -from dify_graph.file import File, FileTransferMethod, FileType -from dify_graph.model_runtime.entities.llm_entities import LLMUsage -from dify_graph.node_events import StreamChunkEvent, StreamCompletedEvent -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable -from dify_graph.variables.segments import ArrayFileSegment +from core.workflow.system_variables import build_system_variables +from graphon.file import File, FileTransferMethod, FileType +from graphon.model_runtime.entities.llm_entities import LLMUsage +from graphon.node_events import StreamChunkEvent, StreamCompletedEvent +from graphon.nodes.tool_runtime_entities import ToolRuntimeHandle, ToolRuntimeMessage +from graphon.runtime import GraphRuntimeState, VariablePool +from graphon.variables.segments import ArrayFileSegment from tests.workflow_test_utils import build_test_graph_init_params if TYPE_CHECKING: # pragma: no cover - imported for type checking only - from dify_graph.nodes.tool.tool_node import ToolNode + from graphon.nodes.tool.tool_node import ToolNode + + +class _StubToolRuntime: + def get_runtime(self, *, node_id: str, node_data: Any, variable_pool: Any) -> ToolRuntimeHandle: + raise NotImplementedError + + def get_runtime_parameters(self, *, tool_runtime: ToolRuntimeHandle) -> list[Any]: + return [] + + def invoke( + self, + *, + tool_runtime: ToolRuntimeHandle, + tool_parameters: dict[str, Any], + workflow_call_depth: int, + provider_name: str, + ) -> Generator[ToolRuntimeMessage, None, None]: + yield from () + + def get_usage(self, *, tool_runtime: ToolRuntimeHandle) -> LLMUsage: + return LLMUsage.empty_usage() + + def build_file_reference(self, *, mapping: dict[str, Any]) -> Any: + return mapping + + def resolve_provider_icons( + self, + *, + provider_name: str, + default_icon: str | None = None, + ) -> tuple[str | None, str | None]: + return default_icon, None @pytest.fixture @@ -31,8 +63,8 @@ def tool_node(monkeypatch) -> ToolNode: ops_stub.TraceTask = object # pragma: no cover - stub attribute monkeypatch.setitem(sys.modules, module_name, ops_stub) - from dify_graph.nodes.protocols import ToolFileManagerProtocol - from dify_graph.nodes.tool.tool_node import ToolNode + from graphon.nodes.protocols import ToolFileManagerProtocol + from graphon.nodes.tool.tool_node import ToolNode graph_config: dict[str, Any] = { "nodes": [ @@ -66,13 +98,14 @@ def tool_node(monkeypatch) -> ToolNode: call_depth=0, ) - variable_pool = VariablePool(system_variables=SystemVariable(user_id="user-id")) + variable_pool = VariablePool(system_variables=build_system_variables(user_id="user-id")) graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=0.0) config = graph_config["nodes"][0] # Provide a stub ToolFileManager to satisfy the updated ToolNode constructor tool_file_manager_factory = MagicMock(spec=ToolFileManagerProtocol) + runtime = _StubToolRuntime() node = ToolNode( id="node-instance", @@ -80,6 +113,7 @@ def tool_node(monkeypatch) -> ToolNode: graph_init_params=init_params, graph_runtime_state=graph_runtime_state, tool_file_manager_factory=tool_file_manager_factory, + runtime=runtime, ) return node @@ -93,29 +127,19 @@ def _collect_events(generator: Generator) -> tuple[list[Any], LLMUsage]: return events, stop.value -def _run_transform(tool_node: ToolNode, message: ToolInvokeMessage) -> tuple[list[Any], LLMUsage]: - def _identity_transform(messages, *_args, **_kwargs): - return messages - - tool_runtime = MagicMock() - with patch.object( - ToolFileMessageTransformer, "transform_tool_invoke_messages", side_effect=_identity_transform, autospec=True - ): - generator = tool_node._transform_message( - messages=iter([message]), - tool_info={"provider_type": "builtin", "provider_id": "provider"}, - parameters_for_log={}, - user_id="user-id", - tenant_id="tenant-id", - node_id=tool_node._node_id, - tool_runtime=tool_runtime, - ) - return _collect_events(generator) +def _run_transform(tool_node: ToolNode, message: ToolRuntimeMessage) -> tuple[list[Any], LLMUsage]: + generator = tool_node._transform_message( + messages=iter([message]), + tool_info={"provider_type": "builtin", "provider_id": "provider"}, + parameters_for_log={}, + node_id=tool_node._node_id, + tool_runtime=ToolRuntimeHandle(raw=object()), + ) + return _collect_events(generator) def test_link_messages_with_file_populate_files_output(tool_node: ToolNode): file_obj = File( - tenant_id="tenant-id", type=FileType.DOCUMENT, transfer_method=FileTransferMethod.TOOL_FILE, related_id="file-id", @@ -125,9 +149,9 @@ def test_link_messages_with_file_populate_files_output(tool_node: ToolNode): size=123, storage_key="file-key", ) - message = ToolInvokeMessage( - type=ToolInvokeMessage.MessageType.LINK, - message=ToolInvokeMessage.TextMessage(text="/files/tools/file-id.pdf"), + message = ToolRuntimeMessage( + type=ToolRuntimeMessage.MessageType.LINK, + message=ToolRuntimeMessage.TextMessage(text="/files/tools/file-id.pdf"), meta={"file": file_obj}, ) @@ -150,9 +174,9 @@ def test_link_messages_with_file_populate_files_output(tool_node: ToolNode): def test_plain_link_messages_remain_links(tool_node: ToolNode): - message = ToolInvokeMessage( - type=ToolInvokeMessage.MessageType.LINK, - message=ToolInvokeMessage.TextMessage(text="https://dify.ai"), + message = ToolRuntimeMessage( + type=ToolRuntimeMessage.MessageType.LINK, + message=ToolRuntimeMessage.TextMessage(text="https://dify.ai"), meta=None, ) @@ -167,3 +191,35 @@ def test_plain_link_messages_remain_links(tool_node: ToolNode): files_segment = completed_events[0].node_run_result.outputs["files"] assert isinstance(files_segment, ArrayFileSegment) assert files_segment.value == [] + + +def test_image_link_messages_use_tool_file_id_metadata(tool_node: ToolNode): + file_obj = File( + type=FileType.DOCUMENT, + transfer_method=FileTransferMethod.TOOL_FILE, + related_id="file-id", + filename="demo.pdf", + extension=".pdf", + mime_type="application/pdf", + size=123, + storage_key="file-key", + ) + tool_node._tool_file_manager_factory.get_file_generator_by_tool_file_id.return_value = ( + None, + SimpleNamespace(mime_type="application/pdf"), + ) + tool_node._runtime.build_file_reference = MagicMock(return_value=file_obj) + message = ToolRuntimeMessage( + type=ToolRuntimeMessage.MessageType.IMAGE_LINK, + message=ToolRuntimeMessage.TextMessage(text="/files/tools/file-id.pdf"), + meta={"tool_file_id": "file-id"}, + ) + + events, _ = _run_transform(tool_node, message) + + tool_node._tool_file_manager_factory.get_file_generator_by_tool_file_id.assert_called_once_with("file-id") + completed_events = [event for event in events if isinstance(event, StreamCompletedEvent)] + assert len(completed_events) == 1 + files_segment = completed_events[0].node_run_result.outputs["files"] + assert isinstance(files_segment, ArrayFileSegment) + assert files_segment.value == [file_obj] diff --git a/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node_runtime.py b/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node_runtime.py new file mode 100644 index 0000000000..438af211f3 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node_runtime.py @@ -0,0 +1,276 @@ +from __future__ import annotations + +import sys +import types +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest + +from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler +from core.plugin.impl.exc import PluginDaemonClientSideError, PluginInvokeError +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.entities.tool_entities import ToolProviderType as CoreToolProviderType +from core.tools.errors import ToolInvokeError +from core.tools.tool_engine import ToolEngine +from core.tools.tool_manager import ToolManager +from core.tools.utils.message_transformer import ToolFileMessageTransformer +from core.workflow.node_runtime import DifyToolNodeRuntime +from core.workflow.system_variables import build_system_variables +from graphon.model_runtime.entities.llm_entities import LLMUsage +from graphon.nodes.tool.entities import ToolNodeData, ToolProviderType +from graphon.nodes.tool.exc import ToolRuntimeInvocationError +from graphon.nodes.tool_runtime_entities import ToolRuntimeHandle, ToolRuntimeMessage +from graphon.runtime import VariablePool +from tests.workflow_test_utils import build_test_graph_init_params, build_test_variable_pool + + +@pytest.fixture +def runtime(monkeypatch) -> DifyToolNodeRuntime: + module_name = "core.ops.ops_trace_manager" + if module_name not in sys.modules: + ops_stub = types.ModuleType(module_name) + ops_stub.TraceQueueManager = object # pragma: no cover - stub attribute + ops_stub.TraceTask = object # pragma: no cover - stub attribute + monkeypatch.setitem(sys.modules, module_name, ops_stub) + + init_params = build_test_graph_init_params( + workflow_id="workflow-id", + graph_config={"nodes": [], "edges": []}, + tenant_id="tenant-id", + app_id="app-id", + user_id="user-id", + user_from="account", + invoke_from="debugger", + call_depth=0, + ) + return DifyToolNodeRuntime(init_params.run_context) + + +def _build_tool_node_data() -> ToolNodeData: + return ToolNodeData.model_validate( + { + "type": "tool", + "title": "Tool", + "provider_id": "provider", + "provider_type": ToolProviderType.BUILT_IN, + "provider_name": "provider", + "tool_name": "lookup", + "tool_label": "Lookup", + "tool_configurations": {}, + "tool_parameters": {}, + } + ) + + +def test_invoke_creates_callback_and_converts_messages(runtime: DifyToolNodeRuntime) -> None: + core_message = ToolInvokeMessage( + type=ToolInvokeMessage.MessageType.LINK, + message=ToolInvokeMessage.TextMessage(text="https://dify.ai"), + meta=None, + ) + variable_pool: VariablePool = build_test_variable_pool( + variables=build_system_variables(conversation_id="conversation-id") + ) + workflow_tool = MagicMock() + + with ( + patch.object(ToolManager, "get_workflow_tool_runtime", return_value=workflow_tool), + patch.object(ToolEngine, "generic_invoke", return_value=iter([core_message])) as generic_invoke_mock, + patch.object( + ToolFileMessageTransformer, + "transform_tool_invoke_messages", + side_effect=lambda *, messages, **_: messages, + ) as transform_tool_messages, + ): + tool_runtime = runtime.get_runtime( + node_id="node-id", + node_data=_build_tool_node_data(), + variable_pool=variable_pool, + ) + messages = list( + runtime.invoke( + tool_runtime=tool_runtime, + tool_parameters={}, + workflow_call_depth=0, + provider_name="provider", + ) + ) + + assert not hasattr(tool_runtime, "conversation_id") + assert len(messages) == 1 + graph_message = messages[0] + assert graph_message.type == ToolRuntimeMessage.MessageType.LINK + assert isinstance(graph_message.message, ToolRuntimeMessage.TextMessage) + assert graph_message.message.text == "https://dify.ai" + + callback = generic_invoke_mock.call_args.kwargs["workflow_tool_callback"] + assert isinstance(callback, DifyWorkflowCallbackHandler) + assert generic_invoke_mock.call_args.kwargs["conversation_id"] == "conversation-id" + + transform_kwargs = transform_tool_messages.call_args.kwargs + assert transform_kwargs["conversation_id"] == "conversation-id" + + +def test_invoke_maps_plugin_errors_to_graph_errors(runtime: DifyToolNodeRuntime) -> None: + invoke_error = PluginInvokeError('{"error_type":"RateLimit","message":"too many"}') + + with patch.object(ToolEngine, "generic_invoke", side_effect=invoke_error): + with pytest.raises(ToolRuntimeInvocationError, match="An error occurred in the provider"): + runtime.invoke( + tool_runtime=ToolRuntimeHandle(raw=MagicMock()), + tool_parameters={}, + workflow_call_depth=0, + provider_name="provider", + ) + + +def test_get_usage_normalizes_dict_payload(runtime: DifyToolNodeRuntime) -> None: + usage_payload = LLMUsage.empty_usage().model_dump() + usage_payload["total_tokens"] = 42 + + usage = runtime.get_usage( + tool_runtime=ToolRuntimeHandle(raw=SimpleNamespace(latest_usage=usage_payload)), + ) + + assert usage.total_tokens == 42 + + +def test_get_runtime_converts_graph_provider_type_for_tool_manager(runtime: DifyToolNodeRuntime) -> None: + node_data = _build_tool_node_data() + + with patch.object(ToolManager, "get_workflow_tool_runtime", return_value=MagicMock()) as runtime_mock: + tool_runtime = runtime.get_runtime(node_id="node-id", node_data=node_data, variable_pool=None) + + assert not hasattr(tool_runtime, "conversation_id") + workflow_tool = runtime_mock.call_args.args[3] + assert workflow_tool.provider_type == CoreToolProviderType.BUILT_IN + + +def test_get_runtime_parameters_reads_required_flags(runtime: DifyToolNodeRuntime) -> None: + tool_runtime = ToolRuntimeHandle( + raw=SimpleNamespace( + get_merged_runtime_parameters=MagicMock( + return_value=[ + SimpleNamespace(name="city", required=True), + SimpleNamespace(name="country", required=False), + ] + ) + ) + ) + + parameters = runtime.get_runtime_parameters(tool_runtime=tool_runtime) + + assert [(parameter.name, parameter.required) for parameter in parameters] == [ + ("city", True), + ("country", False), + ] + + +def test_get_usage_returns_empty_usage_when_tool_has_no_usage(runtime: DifyToolNodeRuntime) -> None: + usage = runtime.get_usage(tool_runtime=ToolRuntimeHandle(raw=SimpleNamespace(latest_usage=None))) + + assert usage == LLMUsage.empty_usage() + + +@pytest.mark.parametrize( + ("payload", "expected_type"), + [ + (ToolInvokeMessage.JsonMessage(json_object={"ok": True}, suppress_output=True), ToolRuntimeMessage.JsonMessage), + (ToolInvokeMessage.BlobMessage(blob=b"bytes"), ToolRuntimeMessage.BlobMessage), + ( + ToolInvokeMessage.BlobChunkMessage( + id="blob-id", + sequence=1, + total_length=5, + blob=b"hello", + end=True, + ), + ToolRuntimeMessage.BlobChunkMessage, + ), + (ToolInvokeMessage.FileMessage(file_marker="marker"), ToolRuntimeMessage.FileMessage), + ( + ToolInvokeMessage.VariableMessage(variable_name="city", variable_value="Tokyo", stream=True), + ToolRuntimeMessage.VariableMessage, + ), + ( + ToolInvokeMessage.LogMessage( + id="log-id", + label="lookup", + status=ToolInvokeMessage.LogMessage.LogStatus.SUCCESS, + data={"count": 1}, + metadata={"source": "tool"}, + ), + ToolRuntimeMessage.LogMessage, + ), + ], +) +def test_convert_message_payload_supports_runtime_message_types( + runtime: DifyToolNodeRuntime, + payload: object, + expected_type: type[object], +) -> None: + message = runtime._convert_message_payload(payload) + + assert isinstance(message, expected_type) + + +def test_convert_message_payload_rejects_unknown_types(runtime: DifyToolNodeRuntime) -> None: + with pytest.raises(TypeError, match="unsupported tool message payload"): + runtime._convert_message_payload(object()) + + +def test_resolve_provider_icons_prefers_builtin_tool_icons(runtime: DifyToolNodeRuntime) -> None: + plugin = SimpleNamespace( + plugin_id="langgenius/tools", + name="search", + declaration=SimpleNamespace(icon={"plugin": "icon"}), + ) + builtin_tool = SimpleNamespace( + name="langgenius/tools/search", + icon={"builtin": "icon"}, + icon_dark={"builtin": "dark"}, + ) + + with ( + patch("core.workflow.node_runtime.PluginInstaller") as installer_cls, + patch("core.workflow.node_runtime.BuiltinToolManageService.list_builtin_tools", return_value=[builtin_tool]), + ): + installer_cls.return_value.list_plugins.return_value = [plugin] + + icon, icon_dark = runtime.resolve_provider_icons(provider_name="langgenius/tools/search") + + assert icon == {"builtin": "icon"} + assert icon_dark == {"builtin": "dark"} + + +def test_resolve_provider_icons_returns_default_when_provider_is_unknown(runtime: DifyToolNodeRuntime) -> None: + with ( + patch("core.workflow.node_runtime.PluginInstaller") as installer_cls, + patch("core.workflow.node_runtime.BuiltinToolManageService.list_builtin_tools", return_value=[]), + ): + installer_cls.return_value.list_plugins.return_value = [] + + icon, icon_dark = runtime.resolve_provider_icons(provider_name="unknown", default_icon="fallback") + + assert icon == "fallback" + assert icon_dark is None + + +@pytest.mark.parametrize( + ("exc", "message"), + [ + (PluginDaemonClientSideError("bad request"), "Failed to invoke tool, error: bad request"), + (ToolInvokeError("broken"), "Failed to invoke tool provider: broken"), + (RuntimeError("unexpected"), "unexpected"), + ], +) +def test_map_invocation_exception_normalizes_runtime_errors( + runtime: DifyToolNodeRuntime, + exc: Exception, + message: str, +) -> None: + error = runtime._map_invocation_exception(exc, provider_name="provider") + + assert isinstance(error, ToolRuntimeInvocationError) + assert str(error) == message diff --git a/api/tests/unit_tests/core/workflow/nodes/trigger_plugin/test_trigger_event_node.py b/api/tests/unit_tests/core/workflow/nodes/trigger_plugin/test_trigger_event_node.py index 9aeab0409e..c8ddc53284 100644 --- a/api/tests/unit_tests/core/workflow/nodes/trigger_plugin/test_trigger_event_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/trigger_plugin/test_trigger_event_node.py @@ -2,12 +2,12 @@ from collections.abc import Mapping from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE from core.workflow.nodes.trigger_plugin.trigger_event_node import TriggerEventNode -from dify_graph.entities import GraphInitParams -from dify_graph.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter -from dify_graph.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable -from tests.workflow_test_utils import build_test_graph_init_params +from core.workflow.system_variables import build_system_variables +from graphon.entities import GraphInitParams +from graphon.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter +from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus +from graphon.runtime import GraphRuntimeState +from tests.workflow_test_utils import build_test_graph_init_params, build_test_variable_pool def _build_context(graph_config: Mapping[str, object]) -> tuple[GraphInitParams, GraphRuntimeState]: @@ -17,9 +17,10 @@ def _build_context(graph_config: Mapping[str, object]) -> tuple[GraphInitParams, invoke_from="debugger", ) runtime_state = GraphRuntimeState( - variable_pool=VariablePool( - system_variables=SystemVariable(user_id="user", files=[]), - user_inputs={"payload": "value"}, + variable_pool=build_test_variable_pool( + variables=build_system_variables(user_id="user", files=[]), + node_id="node-1", + inputs={"payload": "value"}, ), start_at=0.0, ) diff --git a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py index e69c05dc0b..fabc8df73e 100644 --- a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py +++ b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py @@ -2,22 +2,38 @@ import time import uuid from uuid import uuid4 -from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, InvokeFrom, UserFrom from core.workflow.node_factory import DifyNodeFactory -from dify_graph.entities import GraphInitParams -from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY -from dify_graph.graph import Graph -from dify_graph.graph_events.node import NodeRunSucceededEvent -from dify_graph.nodes.variable_assigner.common import helpers as common_helpers -from dify_graph.nodes.variable_assigner.v1 import VariableAssignerNode -from dify_graph.nodes.variable_assigner.v1.node_data import WriteMode -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable -from dify_graph.variables import ArrayStringVariable, StringVariable +from core.workflow.system_variables import build_bootstrap_variables, build_system_variables +from core.workflow.variable_pool_initializer import add_variables_to_pool +from graphon.entities import GraphInitParams +from graphon.graph import Graph +from graphon.graph_events.node import NodeRunSucceededEvent, NodeRunVariableUpdatedEvent +from graphon.nodes.variable_assigner.common import helpers as common_helpers +from graphon.nodes.variable_assigner.v1 import VariableAssignerNode +from graphon.nodes.variable_assigner.v1.node_data import WriteMode +from graphon.runtime import GraphRuntimeState, VariablePool +from graphon.variables import ArrayStringVariable, StringVariable DEFAULT_NODE_ID = "node_id" +def _build_variable_pool( + *, + conversation_id: str, + conversation_variables: list[StringVariable | ArrayStringVariable], +) -> VariablePool: + variable_pool = VariablePool() + add_variables_to_pool( + variable_pool, + build_bootstrap_variables( + system_variables=build_system_variables(conversation_id=conversation_id), + conversation_variables=conversation_variables, + ), + ) + return variable_pool + + def test_overwrite_string_variable(): graph_config = { "edges": [ @@ -71,10 +87,8 @@ def test_overwrite_string_variable(): conversation_id = str(uuid.uuid4()) # construct variable pool - variable_pool = VariablePool( - system_variables=SystemVariable(conversation_id=conversation_id), - user_inputs={}, - environment_variables=[], + variable_pool = _build_variable_pool( + conversation_id=conversation_id, conversation_variables=[conversation_variable], ) @@ -108,16 +122,14 @@ def test_overwrite_string_variable(): ) events = list(node.run()) + updated_event = next(event for event in events if isinstance(event, NodeRunVariableUpdatedEvent)) succeeded_event = next(event for event in events if isinstance(event, NodeRunSucceededEvent)) updated_variables = common_helpers.get_updated_variables(succeeded_event.node_run_result.process_data) assert updated_variables is not None assert updated_variables[0].name == conversation_variable.name assert updated_variables[0].new_value == input_variable.value - - got = variable_pool.get(["conversation", conversation_variable.name]) - assert got is not None - assert got.value == "the second value" - assert got.to_object() == "the second value" + assert updated_event.variable.value == "the second value" + assert tuple(updated_event.variable.selector) == ("conversation", conversation_variable.name) def test_append_variable_to_array(): @@ -172,10 +184,8 @@ def test_append_variable_to_array(): ) conversation_id = str(uuid.uuid4()) - variable_pool = VariablePool( - system_variables=SystemVariable(conversation_id=conversation_id), - user_inputs={}, - environment_variables=[], + variable_pool = _build_variable_pool( + conversation_id=conversation_id, conversation_variables=[conversation_variable], ) variable_pool.add( @@ -208,15 +218,13 @@ def test_append_variable_to_array(): ) events = list(node.run()) + updated_event = next(event for event in events if isinstance(event, NodeRunVariableUpdatedEvent)) succeeded_event = next(event for event in events if isinstance(event, NodeRunSucceededEvent)) updated_variables = common_helpers.get_updated_variables(succeeded_event.node_run_result.process_data) assert updated_variables is not None assert updated_variables[0].name == conversation_variable.name assert updated_variables[0].new_value == ["the first value", "the second value"] - - got = variable_pool.get(["conversation", conversation_variable.name]) - assert got is not None - assert got.to_object() == ["the first value", "the second value"] + assert updated_event.variable.value == ["the first value", "the second value"] def test_clear_array(): @@ -265,10 +273,8 @@ def test_clear_array(): ) conversation_id = str(uuid.uuid4()) - variable_pool = VariablePool( - system_variables=SystemVariable(conversation_id=conversation_id), - user_inputs={}, - environment_variables=[], + variable_pool = _build_variable_pool( + conversation_id=conversation_id, conversation_variables=[conversation_variable], ) @@ -297,12 +303,10 @@ def test_clear_array(): ) events = list(node.run()) + updated_event = next(event for event in events if isinstance(event, NodeRunVariableUpdatedEvent)) succeeded_event = next(event for event in events if isinstance(event, NodeRunSucceededEvent)) updated_variables = common_helpers.get_updated_variables(succeeded_event.node_run_result.process_data) assert updated_variables is not None assert updated_variables[0].name == conversation_variable.name assert updated_variables[0].new_value == [] - - got = variable_pool.get(["conversation", conversation_variable.name]) - assert got is not None - assert got.to_object() == [] + assert updated_event.variable.value == [] diff --git a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_helpers.py b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_helpers.py index a7673c5a14..9ac8bbe9c2 100644 --- a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_helpers.py +++ b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_helpers.py @@ -1,6 +1,6 @@ -from dify_graph.nodes.variable_assigner.v2.enums import Operation -from dify_graph.nodes.variable_assigner.v2.helpers import is_input_value_valid -from dify_graph.variables import SegmentType +from graphon.nodes.variable_assigner.v2.enums import Operation +from graphon.nodes.variable_assigner.v2.helpers import is_input_value_valid +from graphon.variables import SegmentType def test_is_input_value_valid_overwrite_array_string(): diff --git a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_variable_assigner_v2.py b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_variable_assigner_v2.py index 6874f3fef1..53346c4a90 100644 --- a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_variable_assigner_v2.py +++ b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_variable_assigner_v2.py @@ -2,20 +2,33 @@ import time import uuid from uuid import uuid4 -from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, InvokeFrom, UserFrom from core.workflow.node_factory import DifyNodeFactory -from dify_graph.entities import GraphInitParams -from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY -from dify_graph.graph import Graph -from dify_graph.nodes.variable_assigner.v2 import VariableAssignerNode -from dify_graph.nodes.variable_assigner.v2.enums import InputType, Operation -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable -from dify_graph.variables import ArrayStringVariable +from core.workflow.system_variables import build_bootstrap_variables, build_system_variables +from core.workflow.variable_pool_initializer import add_variables_to_pool +from graphon.entities import GraphInitParams +from graphon.graph import Graph +from graphon.graph_events import NodeRunVariableUpdatedEvent +from graphon.nodes.variable_assigner.v2 import VariableAssignerNode +from graphon.nodes.variable_assigner.v2.enums import InputType, Operation +from graphon.runtime import GraphRuntimeState, VariablePool +from graphon.variables import ArrayStringVariable DEFAULT_NODE_ID = "node_id" +def _build_variable_pool(*, conversation_variables: list[ArrayStringVariable]) -> VariablePool: + variable_pool = VariablePool() + add_variables_to_pool( + variable_pool, + build_bootstrap_variables( + system_variables=build_system_variables(conversation_id="conversation_id"), + conversation_variables=conversation_variables, + ), + ) + return variable_pool + + def test_handle_item_directly(): """Test the _handle_item method directly for remove operations.""" # Create variables @@ -106,12 +119,7 @@ def test_remove_first_from_array(): selector=["conversation", "test_conversation_variable"], ) - variable_pool = VariablePool( - system_variables=SystemVariable(conversation_id="conversation_id"), - user_inputs={}, - environment_variables=[], - conversation_variables=[conversation_variable], - ) + variable_pool = _build_variable_pool(conversation_variables=[conversation_variable]) graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) node_factory = DifyNodeFactory( @@ -146,11 +154,8 @@ def test_remove_first_from_array(): # Run the node result = list(node.run()) - # Completed run - - got = variable_pool.get(["conversation", conversation_variable.name]) - assert got is not None - assert got.to_object() == ["second", "third"] + updated_event = next(event for event in result if isinstance(event, NodeRunVariableUpdatedEvent)) + assert updated_event.variable.value == ["second", "third"] def test_remove_last_from_array(): @@ -194,12 +199,7 @@ def test_remove_last_from_array(): selector=["conversation", "test_conversation_variable"], ) - variable_pool = VariablePool( - system_variables=SystemVariable(conversation_id="conversation_id"), - user_inputs={}, - environment_variables=[], - conversation_variables=[conversation_variable], - ) + variable_pool = _build_variable_pool(conversation_variables=[conversation_variable]) graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) node_factory = DifyNodeFactory( @@ -231,11 +231,9 @@ def test_remove_last_from_array(): config=node_config, ) - list(node.run()) - - got = variable_pool.get(["conversation", conversation_variable.name]) - assert got is not None - assert got.to_object() == ["first", "second"] + result = list(node.run()) + updated_event = next(event for event in result if isinstance(event, NodeRunVariableUpdatedEvent)) + assert updated_event.variable.value == ["first", "second"] def test_remove_first_from_empty_array(): @@ -279,12 +277,7 @@ def test_remove_first_from_empty_array(): selector=["conversation", "test_conversation_variable"], ) - variable_pool = VariablePool( - system_variables=SystemVariable(conversation_id="conversation_id"), - user_inputs={}, - environment_variables=[], - conversation_variables=[conversation_variable], - ) + variable_pool = _build_variable_pool(conversation_variables=[conversation_variable]) graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) node_factory = DifyNodeFactory( @@ -316,11 +309,9 @@ def test_remove_first_from_empty_array(): config=node_config, ) - list(node.run()) - - got = variable_pool.get(["conversation", conversation_variable.name]) - assert got is not None - assert got.to_object() == [] + result = list(node.run()) + updated_event = next(event for event in result if isinstance(event, NodeRunVariableUpdatedEvent)) + assert updated_event.variable.value == [] def test_remove_last_from_empty_array(): @@ -364,12 +355,7 @@ def test_remove_last_from_empty_array(): selector=["conversation", "test_conversation_variable"], ) - variable_pool = VariablePool( - system_variables=SystemVariable(conversation_id="conversation_id"), - user_inputs={}, - environment_variables=[], - conversation_variables=[conversation_variable], - ) + variable_pool = _build_variable_pool(conversation_variables=[conversation_variable]) graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) node_factory = DifyNodeFactory( @@ -401,11 +387,9 @@ def test_remove_last_from_empty_array(): config=node_config, ) - list(node.run()) - - got = variable_pool.get(["conversation", conversation_variable.name]) - assert got is not None - assert got.to_object() == [] + result = list(node.run()) + updated_event = next(event for event in result if isinstance(event, NodeRunVariableUpdatedEvent)) + assert updated_event.variable.value == [] def test_node_factory_creates_variable_assigner_node(): @@ -433,12 +417,7 @@ def test_node_factory_creates_variable_assigner_node(): }, call_depth=0, ) - variable_pool = VariablePool( - system_variables=SystemVariable(conversation_id="conversation_id"), - user_inputs={}, - environment_variables=[], - conversation_variables=[], - ) + variable_pool = _build_variable_pool(conversation_variables=[]) graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) node_factory = DifyNodeFactory( diff --git a/api/tests/unit_tests/core/workflow/nodes/webhook/test_entities.py b/api/tests/unit_tests/core/workflow/nodes/webhook/test_entities.py index 6be5bb23e8..be18391b2c 100644 --- a/api/tests/unit_tests/core/workflow/nodes/webhook/test_entities.py +++ b/api/tests/unit_tests/core/workflow/nodes/webhook/test_entities.py @@ -324,7 +324,7 @@ def test_webhook_body_parameter_edge_cases(): def test_webhook_data_inheritance(): """Test WebhookData inherits from BaseNodeData correctly.""" - from dify_graph.entities.base_node_data import BaseNodeData + from graphon.entities.base_node_data import BaseNodeData # Test that WebhookData is a subclass of BaseNodeData assert issubclass(WebhookData, BaseNodeData) diff --git a/api/tests/unit_tests/core/workflow/nodes/webhook/test_exceptions.py b/api/tests/unit_tests/core/workflow/nodes/webhook/test_exceptions.py index ddf1af5a59..617554ee17 100644 --- a/api/tests/unit_tests/core/workflow/nodes/webhook/test_exceptions.py +++ b/api/tests/unit_tests/core/workflow/nodes/webhook/test_exceptions.py @@ -6,7 +6,7 @@ from core.workflow.nodes.trigger_webhook.exc import ( WebhookNotFoundError, WebhookTimeoutError, ) -from dify_graph.entities.exc import BaseNodeError +from graphon.entities.exc import BaseNodeError def test_webhook_node_error_inheritance(): diff --git a/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_file_conversion.py b/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_file_conversion.py index 78dd7ce0f3..6fbd26131d 100644 --- a/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_file_conversion.py +++ b/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_file_conversion.py @@ -8,7 +8,7 @@ when passing files to downstream LLM nodes. from unittest.mock import Mock, patch -from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, InvokeFrom, UserFrom from core.workflow.nodes.trigger_webhook.entities import ( ContentType, Method, @@ -16,11 +16,12 @@ from core.workflow.nodes.trigger_webhook.entities import ( WebhookData, ) from core.workflow.nodes.trigger_webhook.node import TriggerWebhookNode -from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY, GraphInitParams -from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus -from dify_graph.runtime.graph_runtime_state import GraphRuntimeState -from dify_graph.runtime.variable_pool import VariablePool -from dify_graph.system_variable import SystemVariable +from core.workflow.system_variables import default_system_variables +from graphon.entities.graph_init_params import GraphInitParams +from graphon.entities.workflow_node_execution import WorkflowNodeExecutionStatus +from graphon.runtime.graph_runtime_state import GraphRuntimeState +from graphon.runtime.variable_pool import VariablePool +from tests.workflow_test_utils import build_test_variable_pool def create_webhook_node( @@ -96,6 +97,18 @@ def create_test_file_dict( } +def build_webhook_variable_pool(inputs: dict) -> VariablePool: + return build_test_variable_pool( + variables=default_system_variables(), + node_id="webhook-node-1", + inputs=inputs, + ) + + +def expected_factory_mapping(file_dict: dict) -> dict: + return {**file_dict, "upload_file_id": file_dict["related_id"]} + + def test_webhook_node_file_conversion_to_file_variable(): """Test that webhook node converts file dictionaries to FileVariable objects.""" # Create test file dictionary (as it comes from webhook service) @@ -111,9 +124,8 @@ def test_webhook_node_file_conversion_to_file_variable(): ], ) - variable_pool = VariablePool( - system_variables=SystemVariable.default(), - user_inputs={ + variable_pool = build_webhook_variable_pool( + { "webhook_data": { "headers": {}, "query_params": {}, @@ -122,14 +134,14 @@ def test_webhook_node_file_conversion_to_file_variable(): "image_upload": file_dict, }, } - }, + } ) node = create_webhook_node(data, variable_pool) - # Mock the file factory and variable factory + # Mock the file reference boundary and variable factory with ( - patch("factories.file_factory.build_from_mapping") as mock_file_factory, + patch.object(node._file_reference_factory, "build_from_mapping") as mock_file_factory, patch("core.workflow.nodes.trigger_webhook.node.build_segment_with_type") as mock_segment_factory, patch("core.workflow.nodes.trigger_webhook.node.FileVariable") as mock_file_variable, ): @@ -153,8 +165,7 @@ def test_webhook_node_file_conversion_to_file_variable(): # Verify file factory was called with correct parameters mock_file_factory.assert_called_once_with( - mapping=file_dict, - tenant_id="test-tenant", + mapping=expected_factory_mapping(file_dict), ) # Verify segment factory was called to create FileSegment @@ -184,16 +195,15 @@ def test_webhook_node_file_conversion_with_missing_files(): ], ) - variable_pool = VariablePool( - system_variables=SystemVariable.default(), - user_inputs={ + variable_pool = build_webhook_variable_pool( + { "webhook_data": { "headers": {}, "query_params": {}, "body": {}, "files": {}, # No files } - }, + } ) node = create_webhook_node(data, variable_pool) @@ -219,9 +229,8 @@ def test_webhook_node_file_conversion_with_none_file(): ], ) - variable_pool = VariablePool( - system_variables=SystemVariable.default(), - user_inputs={ + variable_pool = build_webhook_variable_pool( + { "webhook_data": { "headers": {}, "query_params": {}, @@ -230,7 +239,7 @@ def test_webhook_node_file_conversion_with_none_file(): "file": None, }, } - }, + } ) node = create_webhook_node(data, variable_pool) @@ -256,9 +265,8 @@ def test_webhook_node_file_conversion_with_non_dict_file(): ], ) - variable_pool = VariablePool( - system_variables=SystemVariable.default(), - user_inputs={ + variable_pool = build_webhook_variable_pool( + { "webhook_data": { "headers": {}, "query_params": {}, @@ -267,7 +275,7 @@ def test_webhook_node_file_conversion_with_non_dict_file(): "file": "not_a_dict", # Wrapped to match node expectation }, } - }, + } ) node = create_webhook_node(data, variable_pool) @@ -300,9 +308,8 @@ def test_webhook_node_file_conversion_mixed_parameters(): ], ) - variable_pool = VariablePool( - system_variables=SystemVariable.default(), - user_inputs={ + variable_pool = build_webhook_variable_pool( + { "webhook_data": { "headers": {}, "query_params": {}, @@ -315,13 +322,13 @@ def test_webhook_node_file_conversion_mixed_parameters(): "file_param": file_dict, }, } - }, + } ) node = create_webhook_node(data, variable_pool) with ( - patch("factories.file_factory.build_from_mapping") as mock_file_factory, + patch.object(node._file_reference_factory, "build_from_mapping") as mock_file_factory, patch("core.workflow.nodes.trigger_webhook.node.build_segment_with_type") as mock_segment_factory, patch("core.workflow.nodes.trigger_webhook.node.FileVariable") as mock_file_variable, ): @@ -350,8 +357,7 @@ def test_webhook_node_file_conversion_mixed_parameters(): # Verify file conversion was called mock_file_factory.assert_called_once_with( - mapping=file_dict, - tenant_id="test-tenant", + mapping=expected_factory_mapping(file_dict), ) @@ -370,9 +376,8 @@ def test_webhook_node_different_file_types(): ], ) - variable_pool = VariablePool( - system_variables=SystemVariable.default(), - user_inputs={ + variable_pool = build_webhook_variable_pool( + { "webhook_data": { "headers": {}, "query_params": {}, @@ -383,13 +388,13 @@ def test_webhook_node_different_file_types(): "video": create_test_file_dict("video.mp4", "video"), }, } - }, + } ) node = create_webhook_node(data, variable_pool) with ( - patch("factories.file_factory.build_from_mapping") as mock_file_factory, + patch.object(node._file_reference_factory, "build_from_mapping") as mock_file_factory, patch("core.workflow.nodes.trigger_webhook.node.build_segment_with_type") as mock_segment_factory, patch("core.workflow.nodes.trigger_webhook.node.FileVariable") as mock_file_variable, ): @@ -430,9 +435,8 @@ def test_webhook_node_file_conversion_with_non_dict_wrapper(): ], ) - variable_pool = VariablePool( - system_variables=SystemVariable.default(), - user_inputs={ + variable_pool = build_webhook_variable_pool( + { "webhook_data": { "headers": {}, "query_params": {}, @@ -441,7 +445,7 @@ def test_webhook_node_file_conversion_with_non_dict_wrapper(): "file": "just a string", }, } - }, + } ) node = create_webhook_node(data, variable_pool) diff --git a/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py b/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py index 139f65d6c3..9f954b2090 100644 --- a/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py @@ -2,7 +2,7 @@ from unittest.mock import patch import pytest -from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, InvokeFrom, UserFrom from core.trigger.constants import TRIGGER_WEBHOOK_NODE_TYPE from core.workflow.nodes.trigger_webhook.entities import ( ContentType, @@ -12,13 +12,14 @@ from core.workflow.nodes.trigger_webhook.entities import ( WebhookParameter, ) from core.workflow.nodes.trigger_webhook.node import TriggerWebhookNode -from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY, GraphInitParams -from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus -from dify_graph.file import File, FileTransferMethod, FileType -from dify_graph.runtime.graph_runtime_state import GraphRuntimeState -from dify_graph.runtime.variable_pool import VariablePool -from dify_graph.system_variable import SystemVariable -from dify_graph.variables import FileVariable, StringVariable +from core.workflow.system_variables import default_system_variables +from graphon.entities.graph_init_params import GraphInitParams +from graphon.entities.workflow_node_execution import WorkflowNodeExecutionStatus +from graphon.file import File, FileTransferMethod, FileType +from graphon.runtime.graph_runtime_state import GraphRuntimeState +from graphon.runtime.variable_pool import VariablePool +from graphon.variables import FileVariable, StringVariable +from tests.workflow_test_utils import build_test_variable_pool def create_webhook_node(webhook_data: WebhookData, variable_pool: VariablePool) -> TriggerWebhookNode: @@ -62,6 +63,14 @@ def create_webhook_node(webhook_data: WebhookData, variable_pool: VariablePool) return node +def build_webhook_variable_pool(inputs: dict) -> VariablePool: + return build_test_variable_pool( + variables=default_system_variables(), + node_id="1", + inputs=inputs, + ) + + def test_webhook_node_basic_initialization(): """Test basic webhook node initialization and configuration.""" data = WebhookData( @@ -76,10 +85,7 @@ def test_webhook_node_basic_initialization(): timeout=30, ) - variable_pool = VariablePool( - system_variables=SystemVariable.default(), - user_inputs={}, - ) + variable_pool = build_webhook_variable_pool({}) node = create_webhook_node(data, variable_pool) @@ -119,9 +125,8 @@ def test_webhook_node_run_with_headers(): ], ) - variable_pool = VariablePool( - system_variables=SystemVariable.default(), - user_inputs={ + variable_pool = build_webhook_variable_pool( + { "webhook_data": { "headers": { "Authorization": "Bearer token123", @@ -132,7 +137,7 @@ def test_webhook_node_run_with_headers(): "body": {}, "files": {}, } - }, + } ) node = create_webhook_node(data, variable_pool) @@ -155,9 +160,8 @@ def test_webhook_node_run_with_query_params(): ], ) - variable_pool = VariablePool( - system_variables=SystemVariable.default(), - user_inputs={ + variable_pool = build_webhook_variable_pool( + { "webhook_data": { "headers": {}, "query_params": { @@ -167,7 +171,7 @@ def test_webhook_node_run_with_query_params(): "body": {}, "files": {}, } - }, + } ) node = create_webhook_node(data, variable_pool) @@ -191,9 +195,8 @@ def test_webhook_node_run_with_body_params(): ], ) - variable_pool = VariablePool( - system_variables=SystemVariable.default(), - user_inputs={ + variable_pool = build_webhook_variable_pool( + { "webhook_data": { "headers": {}, "query_params": {}, @@ -205,7 +208,7 @@ def test_webhook_node_run_with_body_params(): }, "files": {}, } - }, + } ) node = create_webhook_node(data, variable_pool) @@ -222,7 +225,6 @@ def test_webhook_node_run_with_file_params(): """Test webhook node execution with file parameter extraction.""" # Create mock file objects file1 = File( - tenant_id="1", type=FileType.IMAGE, transfer_method=FileTransferMethod.LOCAL_FILE, related_id="file1", @@ -232,7 +234,6 @@ def test_webhook_node_run_with_file_params(): ) file2 = File( - tenant_id="1", type=FileType.DOCUMENT, transfer_method=FileTransferMethod.LOCAL_FILE, related_id="file2", @@ -250,9 +251,8 @@ def test_webhook_node_run_with_file_params(): ], ) - variable_pool = VariablePool( - system_variables=SystemVariable.default(), - user_inputs={ + variable_pool = build_webhook_variable_pool( + { "webhook_data": { "headers": {}, "query_params": {}, @@ -262,14 +262,14 @@ def test_webhook_node_run_with_file_params(): "document": file2.to_dict(), }, } - }, + } ) node = create_webhook_node(data, variable_pool) - # Mock the file factory to avoid DB-dependent validation on upload_file_id - with patch("factories.file_factory.build_from_mapping") as mock_file_factory: + # Mock the node's file reference boundary to avoid DB-dependent validation on upload_file_id + with patch.object(node._file_reference_factory, "build_from_mapping") as mock_file_factory: - def _to_file(mapping, tenant_id, config=None, strict_type_validation=False): + def _to_file(*, mapping): return File.model_validate(mapping) mock_file_factory.side_effect = _to_file @@ -284,7 +284,6 @@ def test_webhook_node_run_with_file_params(): def test_webhook_node_run_mixed_parameters(): """Test webhook node execution with mixed parameter types.""" file_obj = File( - tenant_id="1", type=FileType.IMAGE, transfer_method=FileTransferMethod.LOCAL_FILE, related_id="file1", @@ -303,23 +302,22 @@ def test_webhook_node_run_mixed_parameters(): ], ) - variable_pool = VariablePool( - system_variables=SystemVariable.default(), - user_inputs={ + variable_pool = build_webhook_variable_pool( + { "webhook_data": { "headers": {"Authorization": "Bearer token"}, "query_params": {"version": "v1"}, "body": {"message": "Test message"}, "files": {"upload": file_obj.to_dict()}, } - }, + } ) node = create_webhook_node(data, variable_pool) - # Mock the file factory to avoid DB-dependent validation on upload_file_id - with patch("factories.file_factory.build_from_mapping") as mock_file_factory: + # Mock the node's file reference boundary to avoid DB-dependent validation on upload_file_id + with patch.object(node._file_reference_factory, "build_from_mapping") as mock_file_factory: - def _to_file(mapping, tenant_id, config=None, strict_type_validation=False): + def _to_file(*, mapping): return File.model_validate(mapping) mock_file_factory.side_effect = _to_file @@ -343,10 +341,7 @@ def test_webhook_node_run_empty_webhook_data(): body=[WebhookBodyParameter(name="message", type="string", required=False)], ) - variable_pool = VariablePool( - system_variables=SystemVariable.default(), - user_inputs={}, # No webhook_data - ) + variable_pool = build_webhook_variable_pool({}) # No webhook_data node = create_webhook_node(data, variable_pool) result = node._run() @@ -369,9 +364,8 @@ def test_webhook_node_run_case_insensitive_headers(): ], ) - variable_pool = VariablePool( - system_variables=SystemVariable.default(), - user_inputs={ + variable_pool = build_webhook_variable_pool( + { "webhook_data": { "headers": { "content-type": "application/json", # lowercase @@ -382,7 +376,7 @@ def test_webhook_node_run_case_insensitive_headers(): "body": {}, "files": {}, } - }, + } ) node = create_webhook_node(data, variable_pool) @@ -399,12 +393,11 @@ def test_webhook_node_variable_pool_user_inputs(): data = WebhookData(title="Test Webhook") # Add some additional variables to the pool - variable_pool = VariablePool( - system_variables=SystemVariable.default(), - user_inputs={ + variable_pool = build_webhook_variable_pool( + { "webhook_data": {"headers": {}, "query_params": {}, "body": {}, "files": {}}, "other_var": "should_be_included", - }, + } ) variable_pool.add(["node1", "extra"], StringVariable(name="extra", value="extra_value")) @@ -430,16 +423,15 @@ def test_webhook_node_different_methods(method): method=method, ) - variable_pool = VariablePool( - system_variables=SystemVariable.default(), - user_inputs={ + variable_pool = build_webhook_variable_pool( + { "webhook_data": { "headers": {}, "query_params": {}, "body": {}, "files": {}, } - }, + } ) node = create_webhook_node(data, variable_pool) diff --git a/api/tests/unit_tests/core/workflow/test_enums.py b/api/tests/unit_tests/core/workflow/test_enums.py index e8ce6f60f7..453e0a8502 100644 --- a/api/tests/unit_tests/core/workflow/test_enums.py +++ b/api/tests/unit_tests/core/workflow/test_enums.py @@ -1,6 +1,6 @@ """Tests for workflow pause related enums and constants.""" -from dify_graph.enums import ( +from graphon.enums import ( WorkflowExecutionStatus, ) diff --git a/api/tests/unit_tests/core/workflow/test_human_input_compat.py b/api/tests/unit_tests/core/workflow/test_human_input_compat.py new file mode 100644 index 0000000000..0623800b30 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/test_human_input_compat.py @@ -0,0 +1,184 @@ +from types import SimpleNamespace + +from pydantic import BaseModel + +from core.workflow.human_input_compat import ( + DeliveryMethodType, + EmailDeliveryConfig, + EmailDeliveryMethod, + EmailRecipients, + WebAppDeliveryMethod, + _WebAppDeliveryConfig, + is_human_input_webapp_enabled, + normalize_human_input_node_data_for_graph, + normalize_node_config_for_graph, + normalize_node_data_for_graph, + parse_human_input_delivery_methods, +) +from graphon.enums import BuiltinNodeTypes + + +def test_email_delivery_config_helpers_render_and_sanitize_text() -> None: + variable_pool = SimpleNamespace( + convert_template=lambda body: SimpleNamespace(text=body.replace("{{#node.value#}}", "42")) + ) + + rendered = EmailDeliveryConfig.render_body_template( + body="Open {{#url#}} and use {{#node.value#}}", + url="https://example.com", + variable_pool=variable_pool, + ) + sanitized = EmailDeliveryConfig.sanitize_subject("Hello\r\n Team") + html = EmailDeliveryConfig.render_markdown_body( + "**Hello** [mail](mailto:test@example.com)" + ) + + assert rendered == "Open https://example.com and use 42" + assert sanitized == "Hello alert(1) Team" + assert "Hello" in html + assert "