From 7b98463916b3260b8ae78f5b77c0e93f12f4728d Mon Sep 17 00:00:00 2001 From: -LAN- Date: Tue, 17 Mar 2026 04:16:01 +0800 Subject: [PATCH] chore: save changes Signed-off-by: -LAN- --- api/.importlinter | 7 +- api/context/__init__.py | 82 ++--- .../context/execution_context.py | 125 +++----- api/context/flask_app_context.py | 6 +- api/{dify_graph => }/context/models.py | 2 +- .../console/app/workflow_draft_variable.py | 2 +- api/controllers/console/app/workflow_run.py | 30 +- .../rag_pipeline_draft_variable.py | 2 +- api/controllers/files/tool_files.py | 13 +- api/controllers/files/upload.py | 2 +- api/controllers/inner_api/plugin/plugin.py | 2 +- .../app/apps/advanced_chat/app_generator.py | 7 +- api/core/app/apps/advanced_chat/app_runner.py | 36 ++- .../advanced_chat/generate_task_pipeline.py | 45 ++- api/core/app/apps/base_app_generator.py | 51 +++- .../common/graph_runtime_state_support.py | 7 +- .../common/workflow_response_converter.py | 47 ++- .../app/apps/draft_variable_saver.py} | 15 +- .../app/apps/message_based_app_generator.py | 10 +- .../app/apps/pipeline/pipeline_generator.py | 10 +- api/core/app/apps/pipeline/pipeline_runner.py | 30 +- api/core/app/apps/workflow/app_generator.py | 5 +- api/core/app/apps/workflow/app_runner.py | 25 +- .../apps/workflow/generate_task_pipeline.py | 24 +- api/core/app/apps/workflow_app_runner.py | 27 +- .../conversation_variable_persist_layer.py | 54 ++-- .../app/layers/pause_state_persist_layer.py | 6 +- api/core/app/layers/trigger_post_layer.py | 6 +- api/core/app/workflow/file_runtime.py | 133 +++++++-- api/core/app/workflow/layers/persistence.py | 7 +- api/core/datasource/datasource_manager.py | 5 +- .../datasource/utils/message_transformer.py | 10 +- api/core/ops/aliyun_trace/aliyun_trace.py | 4 +- .../arize_phoenix_trace.py | 4 +- api/core/ops/langfuse_trace/langfuse_trace.py | 4 +- .../ops/langsmith_trace/langsmith_trace.py | 4 +- api/core/ops/opik_trace/opik_trace.py | 4 +- api/core/ops/tencent_trace/tencent_trace.py | 2 +- api/core/ops/weave_trace/weave_trace.py | 4 +- .../processor/paragraph_index_processor.py | 8 +- api/core/rag/retrieval/dataset_retrieval.py | 8 +- api/core/repositories/__init__.py | 11 +- .../celery_workflow_execution_repository.py | 2 +- ...lery_workflow_node_execution_repository.py | 27 +- api/core/repositories/factory.py | 31 +- .../repositories/human_input_repository.py | 150 +++++++--- ...qlalchemy_workflow_execution_repository.py | 2 +- ...hemy_workflow_node_execution_repository.py | 15 +- api/core/tools/signature.py | 41 +++ api/core/tools/tool_file_manager.py | 23 +- api/core/tools/utils/message_transformer.py | 60 +++- api/core/tools/workflow_as_tool/tool.py | 17 +- api/core/workflow/file_reference.py | 52 ++++ api/core/workflow/human_input_compat.py | 247 +++++++++++++++ api/core/workflow/node_factory.py | 62 ++-- api/core/workflow/node_runtime.py | 166 +++++++++-- api/core/workflow/nodes/agent/agent_node.py | 8 +- .../nodes/agent/message_transformer.py | 13 +- .../workflow/nodes/agent/runtime_support.py | 8 +- .../nodes/datasource/datasource_node.py | 16 +- .../knowledge_index/knowledge_index_node.py | 14 +- .../knowledge_retrieval_node.py | 9 +- .../trigger_plugin/trigger_event_node.py | 10 +- .../trigger_schedule/trigger_schedule_node.py | 12 +- .../workflow/nodes/trigger_webhook/node.py | 21 +- api/core/workflow/system_variables.py | 177 +++++++++++ api/core/workflow/template_rendering.py | 29 ++ .../workflow/variable_pool_initializer.py | 15 + .../workflow/variable_prefixes.py} | 2 + api/core/workflow/workflow_entry.py | 56 +++- api/dify_graph/context/__init__.py | 34 --- .../conversation_variable_updater.py | 39 --- api/dify_graph/entities/graph_init_params.py | 4 + api/dify_graph/entities/pause_reason.py | 8 - api/dify_graph/entities/workflow_execution.py | 11 +- .../entities/workflow_node_execution.py | 16 +- api/dify_graph/enums.py | 24 -- api/dify_graph/file/file_manager.py | 32 +- api/dify_graph/file/helpers.py | 108 ++----- api/dify_graph/file/models.py | 93 +++--- api/dify_graph/file/protocols.py | 41 ++- api/dify_graph/file/runtime.py | 49 +-- .../event_management/event_handlers.py | 12 + api/dify_graph/graph_engine/graph_engine.py | 10 +- api/dify_graph/graph_engine/worker.py | 4 +- .../worker_management/worker_pool.py | 4 +- api/dify_graph/graph_events/__init__.py | 2 + api/dify_graph/graph_events/node.py | 7 + .../model_runtime/callbacks/base_callback.py | 18 +- .../callbacks/logging_callback.py | 25 +- .../__base/large_language_model.py | 32 +- .../model_runtime/utils/encoders.py | 54 ++-- api/dify_graph/node_events/__init__.py | 2 + api/dify_graph/node_events/node.py | 7 + api/dify_graph/nodes/base/node.py | 17 +- api/dify_graph/nodes/http_request/executor.py | 2 +- api/dify_graph/nodes/http_request/node.py | 1 - api/dify_graph/nodes/human_input/entities.py | 230 +------------- api/dify_graph/nodes/human_input/enums.py | 17 -- .../nodes/human_input/human_input_node.py | 90 +----- api/dify_graph/nodes/if_else/if_else_node.py | 4 +- .../nodes/iteration/iteration_node.py | 67 +++-- api/dify_graph/nodes/llm/file_saver.py | 6 +- api/dify_graph/nodes/llm/llm_utils.py | 57 ++-- api/dify_graph/nodes/llm/node.py | 14 +- api/dify_graph/nodes/llm/protocols.py | 9 - api/dify_graph/nodes/loop/loop_node.py | 10 +- api/dify_graph/nodes/protocols.py | 4 +- .../question_classifier_node.py | 6 +- api/dify_graph/nodes/runtime.py | 44 ++- api/dify_graph/nodes/start/start_node.py | 9 +- api/dify_graph/nodes/tool/tool_node.py | 27 +- api/dify_graph/nodes/tool_runtime_entities.py | 6 +- .../nodes/variable_assigner/v1/node.py | 41 ++- .../nodes/variable_assigner/v2/node.py | 77 +++-- api/dify_graph/repositories/__init__.py | 14 - .../human_input_form_repository.py | 152 ---------- .../workflow_execution_repository.py | 30 -- .../workflow_node_execution_repository.py | 73 ----- api/dify_graph/runtime/graph_runtime_state.py | 13 + .../runtime/graph_runtime_state_protocol.py | 4 - api/dify_graph/runtime/read_only_wrappers.py | 5 - api/dify_graph/runtime/variable_pool.py | 71 +---- api/dify_graph/system_variable.py | 217 -------------- api/dify_graph/template_rendering.py | 35 +-- api/dify_graph/utils/datetime_utils.py | 20 -- api/dify_graph/variable_loader.py | 8 - .../logstore_workflow_execution_repository.py | 2 +- ...tore_workflow_node_execution_repository.py | 27 +- api/factories/file_factory.py | 67 +++-- api/factories/variable_factory.py | 2 +- api/models/human_input.py | 7 +- api/models/model.py | 36 ++- api/models/workflow.py | 17 +- .../api_workflow_node_execution_repository.py | 2 +- .../api_workflow_run_repository.py | 2 +- .../sqlalchemy_api_workflow_run_repository.py | 31 +- .../human_input_delivery_test_service.py | 16 +- api/services/rag_pipeline/rag_pipeline.py | 47 +-- .../workflow_draft_variable_service.py | 16 +- api/services/workflow_service.py | 112 ++++--- api/tasks/mail_human_input_delivery_task.py | 2 +- .../factories/test_storage_key_loader.py | 1 - .../test_workflow_draft_variable_service.py | 2 +- .../workflow/nodes/test_code.py | 4 +- .../workflow/nodes/test_http.py | 10 +- .../workflow/nodes/test_llm.py | 4 +- .../nodes/test_parameter_extractor.py | 4 +- .../workflow/nodes/test_template_transform.py | 4 +- .../workflow/nodes/test_tool.py | 4 +- .../layers/test_pause_state_persist_layer.py | 5 +- .../test_human_input_form_repository_impl.py | 24 +- .../test_human_input_resume_node_execution.py | 10 +- .../factories/test_storage_key_loader.py | 1 - .../test_human_input_delivery_test.py | 8 +- .../test_workflow_draft_variable_service.py | 2 +- .../test_mail_human_input_delivery_task.py | 13 +- .../controllers/console/app/test_workflow.py | 1 - .../app/workflow_draft_variables_test.py | 4 +- .../test_rag_pipeline_draft_variable.py | 2 +- .../controllers/files/test_tool_files.py | 10 +- .../test_app_runner_conversation_variables.py | 19 +- .../test_app_runner_input_moderation.py | 15 +- .../test_generate_task_pipeline_core.py | 11 +- .../test_graph_runtime_state_support.py | 6 +- .../test_workflow_response_converter.py | 3 +- ...workflow_response_converter_human_input.py | 4 +- ..._workflow_response_converter_resumption.py | 4 +- ..._workflow_response_converter_truncation.py | 8 +- .../app/apps/pipeline/test_pipeline_runner.py | 7 +- .../core/app/apps/test_base_app_generator.py | 3 - .../core/app/apps/test_pause_resume.py | 6 +- .../app/apps/test_workflow_app_runner_core.py | 10 +- .../test_workflow_app_runner_single_node.py | 4 +- .../app/apps/test_workflow_pause_events.py | 4 +- .../workflow/test_generate_task_pipeline.py | 11 +- .../test_generate_task_pipeline_core.py | 41 +-- ...est_conversation_variable_persist_layer.py | 90 ++---- .../layers/test_pause_state_persist_layer.py | 29 +- .../datasource/test_datasource_manager.py | 2 - api/tests/unit_tests/core/file/test_models.py | 16 +- .../ops/aliyun_trace/test_aliyun_trace.py | 4 +- .../test_arize_phoenix_trace.py | 2 +- .../ops/langfuse_trace/test_langfuse_trace.py | 6 +- .../langsmith_trace/test_langsmith_trace.py | 6 +- .../core/ops/opik_trace/test_opik_trace.py | 6 +- .../ops/tencent_trace/test_tencent_trace.py | 2 +- .../unit_tests/core/ops/test_opik_trace.py | 4 +- .../core/ops/weave_trace/test_weave_trace.py | 2 +- .../core/plugin/utils/test_chunk_merger.py | 3 - .../prompt/test_advanced_prompt_transform.py | 5 - ...lery_workflow_node_execution_repository.py | 38 +-- .../core/repositories/test_factory.py | 9 +- .../test_human_input_form_repository_impl.py | 42 +-- api/tests/unit_tests/core/test_file.py | 1 - .../core/tools/test_tool_file_manager.py | 20 +- .../tools/utils/test_message_transformer.py | 21 ++ .../unit_tests/core/variables/test_segment.py | 42 +-- .../variables/test_segment_type_validation.py | 1 - .../context/test_execution_context.py | 12 +- .../entities/test_graph_runtime_state.py | 2 +- .../workflow/entities/test_variable_pool.py | 2 +- .../graph/test_graph_skip_validation.py | 4 +- .../workflow/graph/test_graph_validation.py | 4 +- .../graph_engine/human_input_test_utils.py | 16 +- .../graph_engine/test_graph_state_snapshot.py | 4 +- .../test_human_input_pause_multi_branch.py | 10 +- .../test_human_input_pause_single_branch.py | 10 +- .../graph_engine/test_if_else_streaming.py | 4 +- .../graph_engine/test_loop_contains_answer.py | 3 + .../test_mock_nodes_template_code.py | 20 +- .../test_parallel_human_input_join_resume.py | 18 +- ...rallel_human_input_pause_missing_finish.py | 18 +- .../test_parallel_streaming_workflow.py | 4 +- .../test_pause_deferred_ready_nodes.py | 18 +- .../graph_engine/test_pause_resume_state.py | 16 +- .../test_streaming_conversation_variables.py | 2 + .../graph_engine/test_table_runner.py | 25 +- .../test_variable_update_events.py | 129 ++++++++ .../nodes/agent/test_message_transformer.py | 33 ++ .../core/workflow/nodes/answer/test_answer.py | 4 +- .../test_http_request_executor.py | 32 +- .../http_request/test_http_request_node.py | 4 +- .../human_input/test_email_delivery_config.py | 2 +- .../nodes/human_input/test_entities.py | 104 ++++--- .../test_human_input_form_filled_event.py | 6 +- .../test_iteration_child_engine_errors.py | 4 +- .../test_knowledge_index_node.py | 6 +- .../test_knowledge_retrieval_node.py | 4 +- .../workflow/nodes/llm/test_file_saver.py | 5 +- .../core/workflow/nodes/llm/test_node.py | 38 ++- .../core/workflow/nodes/test_base_node.py | 6 +- .../core/workflow/nodes/test_if_else.py | 13 +- .../core/workflow/nodes/test_list_operator.py | 9 - .../nodes/test_question_classifier_node.py | 5 +- .../nodes/test_start_node_json_object.py | 14 +- .../workflow/nodes/tool/test_tool_node.py | 39 ++- .../nodes/tool/test_tool_node_runtime.py | 58 ++-- .../v1/test_variable_assigner_v1.py | 59 ++-- .../v2/test_variable_assigner_v2.py | 82 ++--- .../webhook/test_webhook_file_conversion.py | 68 +++-- .../nodes/webhook/test_webhook_node.py | 80 +++-- .../core/workflow/test_node_factory.py | 67 ++++- .../core/workflow/test_node_runtime.py | 41 +++ .../core/workflow/test_system_variable.py | 281 +++--------------- .../test_system_variable_read_only_view.py | 202 ------------- .../core/workflow/test_variable_pool.py | 140 ++++----- .../core/workflow/test_workflow_entry.py | 28 +- .../workflow/test_workflow_entry_helpers.py | 29 +- .../model_runtime/utils/test_encoders.py | 6 +- .../factories/test_build_from_mapping.py | 33 +- .../factories/test_variable_factory.py | 13 - api/tests/unit_tests/models/test_workflow.py | 3 - .../test_sqlalchemy_repository.py | 11 +- .../test_human_input_delivery_test_service.py | 20 +- .../services/test_variable_truncator.py | 1 - .../test_workflow_draft_variable_service.py | 5 +- .../test_workflow_human_input_delivery.py | 30 +- api/tests/workflow_test_utils.py | 18 ++ 259 files changed, 3580 insertions(+), 3537 deletions(-) rename api/{dify_graph => }/context/execution_context.py (60%) rename api/{dify_graph => }/context/models.py (78%) rename api/{dify_graph/repositories/draft_variable_repository.py => core/app/apps/draft_variable_saver.py} (69%) create mode 100644 api/core/workflow/file_reference.py create mode 100644 api/core/workflow/human_input_compat.py create mode 100644 api/core/workflow/system_variables.py create mode 100644 api/core/workflow/template_rendering.py create mode 100644 api/core/workflow/variable_pool_initializer.py rename api/{dify_graph/constants.py => core/workflow/variable_prefixes.py} (66%) delete mode 100644 api/dify_graph/context/__init__.py delete mode 100644 api/dify_graph/conversation_variable_updater.py delete mode 100644 api/dify_graph/repositories/__init__.py delete mode 100644 api/dify_graph/repositories/human_input_form_repository.py delete mode 100644 api/dify_graph/repositories/workflow_execution_repository.py delete mode 100644 api/dify_graph/repositories/workflow_node_execution_repository.py delete mode 100644 api/dify_graph/system_variable.py delete mode 100644 api/dify_graph/utils/datetime_utils.py create mode 100644 api/tests/unit_tests/core/workflow/graph_engine/test_variable_update_events.py create mode 100644 api/tests/unit_tests/core/workflow/nodes/agent/test_message_transformer.py create mode 100644 api/tests/unit_tests/core/workflow/test_node_runtime.py delete mode 100644 api/tests/unit_tests/core/workflow/test_system_variable_read_only_view.py diff --git a/api/.importlinter b/api/.importlinter index 024827c8cf..81a7b01c1b 100644 --- a/api/.importlinter +++ b/api/.importlinter @@ -2,6 +2,7 @@ root_packages = core constants + context dify_graph configs controllers @@ -44,6 +45,7 @@ source_modules = forbidden_modules = constants configs + context controllers extensions factories @@ -80,8 +82,6 @@ forbidden_modules = core.tools core.trigger core.variables -ignore_imports = - dify_graph.nodes.llm.llm_utils -> dify_graph.model_runtime.model_providers.__base.large_language_model [importlinter:contract:workflow-third-party-imports] name = Workflow Third-Party Imports @@ -90,9 +90,6 @@ source_modules = dify_graph forbidden_modules = sqlalchemy -ignore_imports = - dify_graph.repositories.draft_variable_repository -> sqlalchemy - dify_graph.nodes.llm.llm_utils -> core.model_manager [importlinter:contract:rsc] name = RSC diff --git a/api/context/__init__.py b/api/context/__init__.py index 969e5f583d..7957eb0076 100644 --- a/api/context/__init__.py +++ b/api/context/__init__.py @@ -1,74 +1,36 @@ """ -Core Context - Framework-agnostic context management. +Application-layer context adapters. -This module provides context management that is independent of any specific -web framework. Framework-specific implementations register their context -capture functions at application initialization time. - -This ensures the workflow layer remains completely decoupled from Flask -or any other web framework. +Concrete execution-context implementations live here so `dify_graph` only +depends on injected context managers rather than framework state capture. """ -import contextvars -from collections.abc import Callable - -from dify_graph.context.execution_context import ( +from context.execution_context import ( + AppContext, + ContextProviderNotFoundError, ExecutionContext, + ExecutionContextBuilder, IExecutionContext, NullAppContext, + capture_current_context, + read_context, + register_context, + register_context_capturer, + reset_context_provider, ) - -# Global capturer function - set by framework-specific modules -_capturer: Callable[[], IExecutionContext] | None = None - - -def register_context_capturer(capturer: Callable[[], IExecutionContext]) -> None: - """ - Register a context capture function. - - This should be called by framework-specific modules (e.g., Flask) - during application initialization. - - Args: - capturer: Function that captures current context and returns IExecutionContext - """ - global _capturer - _capturer = capturer - - -def capture_current_context() -> IExecutionContext: - """ - Capture current execution context. - - This function uses the registered context capturer. If no capturer - is registered, it returns a minimal context with only contextvars - (suitable for non-framework environments like tests or standalone scripts). - - Returns: - IExecutionContext with captured context - """ - if _capturer is None: - # No framework registered - return minimal context - return ExecutionContext( - app_context=NullAppContext(), - context_vars=contextvars.copy_context(), - ) - - return _capturer() - - -def reset_context_provider() -> None: - """ - Reset the context capturer. - - This is primarily useful for testing to ensure a clean state. - """ - global _capturer - _capturer = None - +from context.models import SandboxContext __all__ = [ + "AppContext", + "ContextProviderNotFoundError", + "ExecutionContext", + "ExecutionContextBuilder", + "IExecutionContext", + "NullAppContext", + "SandboxContext", "capture_current_context", + "read_context", + "register_context", "register_context_capturer", "reset_context_provider", ] diff --git a/api/dify_graph/context/execution_context.py b/api/context/execution_context.py similarity index 60% rename from api/dify_graph/context/execution_context.py rename to api/context/execution_context.py index e3007530f0..dd825c2f91 100644 --- a/api/dify_graph/context/execution_context.py +++ b/api/context/execution_context.py @@ -1,5 +1,8 @@ """ -Execution Context - Abstracted context management for workflow execution. +Application-layer execution context adapters. + +Concrete context capture lives outside `dify_graph` so the graph package only +consumes injected context managers when it needs to preserve thread-local state. """ import contextvars @@ -16,33 +19,33 @@ class AppContext(ABC): """ Abstract application context interface. - This abstraction allows workflow execution to work with or without Flask - by providing a common interface for application context management. + Application adapters can implement this to restore framework-specific state + such as Flask app context around worker execution. """ @abstractmethod def get_config(self, key: str, default: Any = None) -> Any: """Get configuration value by key.""" - pass + raise NotImplementedError @abstractmethod def get_extension(self, name: str) -> Any: - """Get Flask extension by name (e.g., 'db', 'cache').""" - pass + """Get application extension by name.""" + raise NotImplementedError @abstractmethod def enter(self) -> AbstractContextManager[None]: """Enter the application context.""" - pass + raise NotImplementedError @runtime_checkable class IExecutionContext(Protocol): """ - Protocol for execution context. + Protocol for enterable execution context objects. - This protocol defines the interface that all execution contexts must implement, - allowing both ExecutionContext and FlaskExecutionContext to be used interchangeably. + Concrete implementations may carry extra framework state, but callers only + depend on standard context-manager behavior plus optional user metadata. """ def __enter__(self) -> "IExecutionContext": @@ -62,14 +65,10 @@ class IExecutionContext(Protocol): @final class ExecutionContext: """ - Execution context for workflow execution in worker threads. + Generic execution context used by application-layer adapters. - This class encapsulates all context needed for workflow execution: - - Application context (Flask app or standalone) - - Context variables for Python contextvars - - User information (optional) - - It is designed to be serializable and passable to worker threads. + It restores captured `contextvars` and optionally enters an application + context before the worker executes graph logic. """ def __init__( @@ -78,14 +77,6 @@ class ExecutionContext: context_vars: contextvars.Context | None = None, user: Any = None, ) -> None: - """ - Initialize execution context. - - Args: - app_context: Application context (Flask or standalone) - context_vars: Python contextvars to preserve - user: User object (optional) - """ self._app_context = app_context self._context_vars = context_vars self._user = user @@ -98,27 +89,21 @@ class ExecutionContext: @property def context_vars(self) -> contextvars.Context | None: - """Get context variables.""" + """Get captured context variables.""" return self._context_vars @property def user(self) -> Any: - """Get user object.""" + """Get captured user object.""" return self._user @contextmanager def enter(self) -> Generator[None, None, None]: - """ - Enter this execution context. - - This is a convenience method that creates a context manager. - """ - # Restore context variables if provided + """Enter this execution context.""" if self._context_vars: for var, val in self._context_vars.items(): var.set(val) - # Enter app context if available if self._app_context is not None: with self._app_context.enter(): yield @@ -141,18 +126,10 @@ class ExecutionContext: class NullAppContext(AppContext): """ - Null implementation of AppContext for non-Flask environments. - - This is used when running without Flask (e.g., in tests or standalone mode). + Null application context for non-framework environments. """ def __init__(self, config: dict[str, Any] | None = None) -> None: - """ - Initialize null app context. - - Args: - config: Optional configuration dictionary - """ self._config = config or {} self._extensions: dict[str, Any] = {} @@ -165,7 +142,7 @@ class NullAppContext(AppContext): return self._extensions.get(name) def set_extension(self, name: str, extension: Any) -> None: - """Set extension by name.""" + """Register an extension for tests or standalone execution.""" self._extensions[name] = extension @contextmanager @@ -176,9 +153,7 @@ class NullAppContext(AppContext): class ExecutionContextBuilder: """ - Builder for creating ExecutionContext instances. - - This provides a fluent API for building execution contexts. + Builder for creating `ExecutionContext` instances. """ def __init__(self) -> None: @@ -211,63 +186,42 @@ class ExecutionContextBuilder: _capturer: Callable[[], IExecutionContext] | None = None - -# Tenant-scoped providers using tuple keys for clarity and constant-time lookup. -# Key mapping: -# (name, tenant_id) -> provider -# - name: namespaced identifier (recommend prefixing, e.g. "workflow.sandbox") -# - tenant_id: tenant identifier string -# Value: -# provider: Callable[[], BaseModel] returning the typed context value -# Type-safety note: -# - This registry cannot enforce that all providers for a given name return the same BaseModel type. -# - Implementors SHOULD provide typed wrappers around register/read (like Go's context best practice), -# e.g. def register_sandbox_ctx(tenant_id: str, p: Callable[[], SandboxContext]) and -# def read_sandbox_ctx(tenant_id: str) -> SandboxContext. _tenant_context_providers: dict[tuple[str, str], Callable[[], BaseModel]] = {} T = TypeVar("T", bound=BaseModel) class ContextProviderNotFoundError(KeyError): - """Raised when a tenant-scoped context provider is missing for a given (name, tenant_id).""" + """Raised when a tenant-scoped context provider is missing.""" pass def register_context_capturer(capturer: Callable[[], IExecutionContext]) -> None: - """Register a single enterable execution context capturer (e.g., Flask).""" + """Register an enterable execution context capturer.""" global _capturer _capturer = capturer def register_context(name: str, tenant_id: str, provider: Callable[[], BaseModel]) -> None: - """Register a tenant-specific provider for a named context. - - Tip: use a namespaced "name" (e.g., "workflow.sandbox") to avoid key collisions. - Consider adding a typed wrapper for this registration in your feature module. - """ + """Register a tenant-specific provider for a named context.""" _tenant_context_providers[(name, tenant_id)] = provider def read_context(name: str, *, tenant_id: str) -> BaseModel: - """ - Read a context value for a specific tenant. - - Raises KeyError if the provider for (name, tenant_id) is not registered. - """ - prov = _tenant_context_providers.get((name, tenant_id)) - if prov is None: + """Read a context value for a specific tenant.""" + provider = _tenant_context_providers.get((name, tenant_id)) + if provider is None: raise ContextProviderNotFoundError(f"Context provider '{name}' not registered for tenant '{tenant_id}'") - return prov() + return provider() def capture_current_context() -> IExecutionContext: """ Capture current execution context from the calling environment. - If a capturer is registered (e.g., Flask), use it. Otherwise, return a minimal - context with NullAppContext + copy of current contextvars. + If no framework adapter is registered, return a minimal context that only + restores `contextvars`. """ if _capturer is None: return ExecutionContext( @@ -278,7 +232,22 @@ def capture_current_context() -> IExecutionContext: def reset_context_provider() -> None: - """Reset the capturer and all tenant-scoped context providers (primarily for tests).""" + """Reset the capturer and tenant-scoped providers.""" global _capturer _capturer = None _tenant_context_providers.clear() + + +__all__ = [ + "AppContext", + "ContextProviderNotFoundError", + "ExecutionContext", + "ExecutionContextBuilder", + "IExecutionContext", + "NullAppContext", + "capture_current_context", + "read_context", + "register_context", + "register_context_capturer", + "reset_context_provider", +] diff --git a/api/context/flask_app_context.py b/api/context/flask_app_context.py index 324a9ee8b4..eddd6448d8 100644 --- a/api/context/flask_app_context.py +++ b/api/context/flask_app_context.py @@ -10,11 +10,7 @@ from typing import Any, final from flask import Flask, current_app, g -from dify_graph.context import register_context_capturer -from dify_graph.context.execution_context import ( - AppContext, - IExecutionContext, -) +from context.execution_context import AppContext, IExecutionContext, register_context_capturer @final diff --git a/api/dify_graph/context/models.py b/api/context/models.py similarity index 78% rename from api/dify_graph/context/models.py rename to api/context/models.py index af5a4b2614..e86f8d50c9 100644 --- a/api/dify_graph/context/models.py +++ b/api/context/models.py @@ -7,7 +7,7 @@ class SandboxContext(BaseModel): """Typed context for sandbox integration. All fields optional by design.""" sandbox_url: AnyHttpUrl | None = None - sandbox_token: str | None = None # optional, if later needed for auth + sandbox_token: str | None = None __all__ = ["SandboxContext"] diff --git a/api/controllers/console/app/workflow_draft_variable.py b/api/controllers/console/app/workflow_draft_variable.py index b78d97a382..07267a6aff 100644 --- a/api/controllers/console/app/workflow_draft_variable.py +++ b/api/controllers/console/app/workflow_draft_variable.py @@ -15,7 +15,7 @@ from controllers.console.app.error import ( from controllers.console.app.wraps import get_app_model from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required from controllers.web.error import InvalidArgumentError, NotFoundError -from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID +from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID from dify_graph.file import helpers as file_helpers from dify_graph.variables.segment_group import SegmentGroup from dify_graph.variables.segments import ArrayFileSegment, FileSegment, Segment diff --git a/api/controllers/console/app/workflow_run.py b/api/controllers/console/app/workflow_run.py index 7ac653395e..35041e0fd8 100644 --- a/api/controllers/console/app/workflow_run.py +++ b/api/controllers/console/app/workflow_run.py @@ -32,6 +32,7 @@ from libs.custom_inputs import time_duration from libs.helper import uuid_value from libs.login import current_user, login_required from models import Account, App, AppMode, EndUser, WorkflowArchiveLog, WorkflowRunTriggeredFrom +from models.human_input import HumanInputFormRecipient, RecipientType from models.workflow import WorkflowRun from repositories.factory import DifyAPIRepositoryFactory from services.retention.workflow_run.constants import ARCHIVE_BUNDLE_NAME @@ -47,6 +48,28 @@ def _build_backstage_input_url(form_token: str | None) -> str | None: return f"{base_url.rstrip('/')}/form/{form_token}" +def _load_form_tokens_by_form_id(form_ids: list[str]) -> dict[str, str]: + if not form_ids: + return {} + + priority = { + RecipientType.BACKSTAGE: 0, + RecipientType.CONSOLE: 1, + RecipientType.STANDALONE_WEB_APP: 2, + } + tokens_by_form_id: dict[str, tuple[int, str]] = {} + with sessionmaker(bind=db.engine, expire_on_commit=False)() as session: + stmt = select(HumanInputFormRecipient).where(HumanInputFormRecipient.form_id.in_(form_ids)) + for recipient in session.scalars(stmt): + if recipient.recipient_type not in priority or not recipient.access_token: + continue + candidate = (priority[recipient.recipient_type], recipient.access_token) + current = tokens_by_form_id.get(recipient.form_id) + if current is None or candidate[0] < current[0]: + tokens_by_form_id[recipient.form_id] = candidate + return {form_id: token for form_id, (_, token) in tokens_by_form_id.items()} + + # Workflow run status choices for filtering WORKFLOW_RUN_STATUS_CHOICES = ["running", "succeeded", "failed", "stopped", "partial-succeeded"] EXPORT_SIGNED_URL_EXPIRE_SECONDS = 3600 @@ -496,6 +519,9 @@ class ConsoleWorkflowPauseDetailsApi(Resource): pause_entity = workflow_run_repo.get_workflow_pause(workflow_run_id) pause_reasons = pause_entity.get_pause_reasons() if pause_entity else [] + form_tokens_by_form_id = _load_form_tokens_by_form_id( + [reason.form_id for reason in pause_reasons if isinstance(reason, HumanInputRequired)] + ) # Build response paused_at = pause_entity.paused_at if pause_entity else None @@ -514,7 +540,9 @@ class ConsoleWorkflowPauseDetailsApi(Resource): "pause_type": { "type": "human_input", "form_id": reason.form_id, - "backstage_input_url": _build_backstage_input_url(reason.form_token), + "backstage_input_url": _build_backstage_input_url( + form_tokens_by_form_id.get(reason.form_id) + ), }, } ) diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py index c5dadb75f5..2768e1ded4 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py @@ -21,7 +21,7 @@ from controllers.console.app.workflow_draft_variable import ( from controllers.console.datasets.wraps import get_rag_pipeline from controllers.console.wraps import account_initialization_required, setup_required from controllers.web.error import InvalidArgumentError, NotFoundError -from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID +from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID from dify_graph.variables.types import SegmentType from extensions.ext_database import db from factories.file_factory import build_from_mapping, build_from_mappings diff --git a/api/controllers/files/tool_files.py b/api/controllers/files/tool_files.py index 9e3fb3a90b..2f1e2f28bd 100644 --- a/api/controllers/files/tool_files.py +++ b/api/controllers/files/tool_files.py @@ -70,22 +70,25 @@ class ToolFileApi(Resource): except Exception: raise UnsupportedFileTypeError() + mime_type = tool_file.mime_type + filename = tool_file.filename + response = Response( stream, - mimetype=tool_file.mimetype, + mimetype=mime_type, direct_passthrough=True, headers={}, ) if tool_file.size > 0: response.headers["Content-Length"] = str(tool_file.size) - if args.as_attachment: - encoded_filename = quote(tool_file.name) + if args.as_attachment and filename: + encoded_filename = quote(filename) response.headers["Content-Disposition"] = f"attachment; filename*=UTF-8''{encoded_filename}" enforce_download_for_html( response, - mime_type=tool_file.mimetype, - filename=tool_file.name, + mime_type=mime_type, + filename=filename, extension=extension, ) diff --git a/api/controllers/files/upload.py b/api/controllers/files/upload.py index 52690a12e1..ed3278a28b 100644 --- a/api/controllers/files/upload.py +++ b/api/controllers/files/upload.py @@ -7,8 +7,8 @@ from pydantic import BaseModel, Field from werkzeug.exceptions import Forbidden import services +from core.tools.signature import verify_plugin_file_signature from core.tools.tool_file_manager import ToolFileManager -from dify_graph.file.helpers import verify_plugin_file_signature from fields.file_fields import FileResponse from ..common.errors import ( diff --git a/api/controllers/inner_api/plugin/plugin.py b/api/controllers/inner_api/plugin/plugin.py index 9b8b3950e6..4339ec0513 100644 --- a/api/controllers/inner_api/plugin/plugin.py +++ b/api/controllers/inner_api/plugin/plugin.py @@ -28,7 +28,7 @@ from core.plugin.entities.request import ( RequestRequestUploadFile, ) from core.tools.entities.tool_entities import ToolProviderType -from dify_graph.file.helpers import get_signed_file_url_for_plugin +from core.tools.signature import get_signed_file_url_for_plugin from dify_graph.model_runtime.utils.encoders import jsonable_encoder from libs.helper import length_prefixed_response from models import Account, Tenant diff --git a/api/core/app/apps/advanced_chat/app_generator.py b/api/core/app/apps/advanced_chat/app_generator.py index 5d974335ff..52e6dc8b7c 100644 --- a/api/core/app/apps/advanced_chat/app_generator.py +++ b/api/core/app/apps/advanced_chat/app_generator.py @@ -24,6 +24,7 @@ from core.app.apps.advanced_chat.app_runner import AdvancedChatAppRunner from core.app.apps.advanced_chat.generate_response_converter import AdvancedChatAppGenerateResponseConverter from core.app.apps.advanced_chat.generate_task_pipeline import AdvancedChatAppGenerateTaskPipeline from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom +from core.app.apps.draft_variable_saver import DraftVariableSaverFactory from core.app.apps.exc import GenerateTaskStoppedError from core.app.apps.message_based_app_generator import MessageBasedAppGenerator from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager @@ -34,13 +35,9 @@ from core.helper.trace_id_helper import extract_external_trace_id_from_args from core.ops.ops_trace_manager import TraceQueueManager from core.prompt.utils.get_thread_messages_length import get_thread_messages_length from core.repositories import DifyCoreRepositoryFactory +from core.repositories.factory import WorkflowExecutionRepository, WorkflowNodeExecutionRepository from dify_graph.graph_engine.layers.base import GraphEngineLayer from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError -from dify_graph.repositories.draft_variable_repository import ( - DraftVariableSaverFactory, -) -from dify_graph.repositories.workflow_execution_repository import WorkflowExecutionRepository -from dify_graph.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository from dify_graph.runtime import GraphRuntimeState from dify_graph.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader from extensions.ext_database import db diff --git a/api/core/app/apps/advanced_chat/app_runner.py b/api/core/app/apps/advanced_chat/app_runner.py index 66037696af..6164001324 100644 --- a/api/core/app/apps/advanced_chat/app_runner.py +++ b/api/core/app/apps/advanced_chat/app_runner.py @@ -25,14 +25,19 @@ from core.app.workflow.layers.persistence import PersistenceWorkflowInfo, Workfl from core.db.session_factory import session_factory from core.moderation.base import ModerationError from core.moderation.input_moderation import InputModeration +from core.repositories.factory import WorkflowExecutionRepository, WorkflowNodeExecutionRepository +from core.workflow.node_factory import get_default_root_node_id +from core.workflow.system_variables import ( + build_bootstrap_variables, + build_system_variables, + system_variables_to_mapping, +) +from core.workflow.variable_pool_initializer import add_node_inputs_to_pool, add_variables_to_pool from core.workflow.workflow_entry import WorkflowEntry from dify_graph.enums import WorkflowType from dify_graph.graph_engine.command_channels.redis_channel import RedisChannel from dify_graph.graph_engine.layers.base import GraphEngineLayer -from dify_graph.repositories.workflow_execution_repository import WorkflowExecutionRepository -from dify_graph.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable from dify_graph.variable_loader import VariableLoader from dify_graph.variables.variables import Variable from extensions.ext_database import db @@ -90,7 +95,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): app_config = self.application_generate_entity.app_config app_config = cast(AdvancedChatAppConfig, app_config) - system_inputs = SystemVariable( + system_inputs = build_system_variables( query=self.application_generate_entity.query, files=self.application_generate_entity.files, conversation_id=self.conversation.id, @@ -150,7 +155,10 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): self.application_generate_entity.inputs = new_inputs self.application_generate_entity.query = new_query - system_inputs.query = new_query + system_inputs = build_system_variables( + system_variables_to_mapping(system_inputs), + query=new_query, + ) # annotation reply if self.handle_annotation_reply( @@ -166,14 +174,17 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): # Create a variable pool. # init variable pool - variable_pool = VariablePool( - system_variables=system_inputs, - user_inputs=new_inputs, - environment_variables=self._workflow.environment_variables, - # Based on the definition of `Variable`, - # `VariableBase` instances can be safely used as `Variable` since they are compatible. - conversation_variables=conversation_variables, + variable_pool = VariablePool() + add_variables_to_pool( + variable_pool, + build_bootstrap_variables( + system_variables=system_inputs, + environment_variables=self._workflow.environment_variables, + conversation_variables=conversation_variables, + ), ) + root_node_id = get_default_root_node_id(self._workflow.graph_dict) + add_node_inputs_to_pool(variable_pool, node_id=root_node_id, inputs=new_inputs) # init graph graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.time()) @@ -185,6 +196,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): user_id=self.application_generate_entity.user_id, user_from=user_from, invoke_from=invoke_from, + root_node_id=root_node_id, ) db.session.close() diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index f7b5030d33..fd8d6d1992 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -14,6 +14,7 @@ from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom from core.app.apps.common.graph_runtime_state_support import GraphRuntimeStateSupport from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter +from core.app.apps.draft_variable_saver import DraftVariableSaverFactory from core.app.entities.app_invoke_entities import ( AdvancedChatAppGenerateEntity, InvokeFrom, @@ -65,14 +66,14 @@ from core.app.task_pipeline.message_cycle_manager import MessageCycleManager from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk from core.ops.ops_trace_manager import TraceQueueManager from core.repositories.human_input_repository import HumanInputFormRepositoryImpl +from core.workflow.file_reference import parse_file_reference +from core.workflow.system_variables import build_system_variables from dify_graph.entities.pause_reason import HumanInputRequired from dify_graph.enums import WorkflowExecutionStatus from dify_graph.model_runtime.entities.llm_entities import LLMUsage from dify_graph.model_runtime.utils.encoders import jsonable_encoder from dify_graph.nodes import BuiltinNodeTypes -from dify_graph.repositories.draft_variable_repository import DraftVariableSaverFactory from dify_graph.runtime import GraphRuntimeState -from dify_graph.system_variable import SystemVariable from extensions.ext_database import db from libs.datetime_utils import naive_utc_now from models import Account, Conversation, EndUser, Message, MessageFile @@ -83,6 +84,21 @@ from models.workflow import Workflow logger = logging.getLogger(__name__) +def _resolve_file_record_id(file_mapping: Mapping[str, Any]) -> str | None: + reference = file_mapping.get("reference") + if isinstance(reference, str) and reference: + parsed_reference = parse_file_reference(reference) + if parsed_reference is not None: + return parsed_reference.record_id + + related_id = file_mapping.get("related_id") + if isinstance(related_id, str) and related_id: + parsed_reference = parse_file_reference(related_id) + if parsed_reference is not None: + return parsed_reference.record_id + return None + + class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport): """ AdvancedChatAppGenerateTaskPipeline is a class that generate stream output and state management for Application. @@ -117,7 +133,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport): else: raise NotImplementedError(f"User type not supported: {type(user)}") - self._workflow_system_variables = SystemVariable( + self._workflow_system_variables = build_system_variables( query=message.query, files=application_generate_entity.files, conversation_id=conversation.id, @@ -741,8 +757,9 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport): def _load_human_input_form_id(self, *, node_id: str) -> str | None: form_repository = HumanInputFormRepositoryImpl( tenant_id=self._workflow_tenant_id, + workflow_execution_id=self._workflow_run_id, ) - form = form_repository.get_form(self._workflow_run_id, node_id) + form = form_repository.get_form(node_id) if form is None: return None return form.id @@ -940,7 +957,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport): transfer_method=file["transfer_method"], url=file["remote_url"], belongs_to=MessageFileBelongsTo.ASSISTANT, - upload_file_id=file["related_id"], + upload_file_id=_resolve_file_record_id(file), created_by_role=CreatorUserRole.ACCOUNT if message.invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else CreatorUserRole.END_USER, @@ -1003,13 +1020,11 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport): return message def _save_output_for_event(self, event: QueueNodeSucceededEvent | QueueNodeExceptionEvent, node_execution_id: str): - with Session(db.engine) as session, session.begin(): - saver = self._draft_var_saver_factory( - session=session, - app_id=self._application_generate_entity.app_config.app_id, - node_id=event.node_id, - node_type=event.node_type, - node_execution_id=node_execution_id, - enclosing_node_id=event.in_loop_id or event.in_iteration_id, - ) - saver.save(event.process_data, event.outputs) + saver = self._draft_var_saver_factory( + app_id=self._application_generate_entity.app_config.app_id, + node_id=event.node_id, + node_type=event.node_type, + node_execution_id=node_execution_id, + enclosing_node_id=event.in_loop_id or event.in_iteration_id, + ) + saver.save(event.process_data, event.outputs) diff --git a/api/core/app/apps/base_app_generator.py b/api/core/app/apps/base_app_generator.py index 20e6ac98ea..4ce52822ea 100644 --- a/api/core/app/apps/base_app_generator.py +++ b/api/core/app/apps/base_app_generator.py @@ -3,15 +3,16 @@ from typing import TYPE_CHECKING, Any, Union, final from sqlalchemy.orm import Session -from core.app.entities.app_invoke_entities import InvokeFrom -from dify_graph.enums import NodeType -from dify_graph.file import File, FileUploadConfig -from dify_graph.repositories.draft_variable_repository import ( +from core.app.apps.draft_variable_saver import ( DraftVariableSaver, DraftVariableSaverFactory, NoopDraftVariableSaver, ) +from core.app.entities.app_invoke_entities import InvokeFrom +from dify_graph.enums import NodeType +from dify_graph.file import File, FileUploadConfig from dify_graph.variables.input_entities import VariableEntityType +from extensions.ext_database import db from factories import file_factory from libs.orjson import orjson_dumps from models import Account, EndUser @@ -21,6 +22,40 @@ if TYPE_CHECKING: from dify_graph.variables.input_entities import VariableEntity +@final +class _DebuggerDraftVariableSaver: + """Adapter that binds SQLAlchemy session setup outside the saver port.""" + + def __init__( + self, + *, + account: Account, + app_id: str, + node_id: str, + node_type: NodeType, + node_execution_id: str, + enclosing_node_id: str | None = None, + ) -> None: + self._account = account + self._app_id = app_id + self._node_id = node_id + self._node_type = node_type + self._node_execution_id = node_execution_id + self._enclosing_node_id = enclosing_node_id + + def save(self, process_data: Mapping[str, Any] | None, outputs: Mapping[str, Any] | None) -> None: + with Session(db.engine) as session, session.begin(): + DraftVariableSaverImpl( + session=session, + app_id=self._app_id, + node_id=self._node_id, + node_type=self._node_type, + node_execution_id=self._node_execution_id, + enclosing_node_id=self._enclosing_node_id, + user=self._account, + ).save(process_data, outputs) + + class BaseAppGenerator: def _prepare_user_inputs( self, @@ -226,32 +261,30 @@ class BaseAppGenerator: assert isinstance(account, Account) def draft_var_saver_factory( - session: Session, app_id: str, node_id: str, node_type: NodeType, node_execution_id: str, enclosing_node_id: str | None = None, ) -> DraftVariableSaver: - return DraftVariableSaverImpl( - session=session, + return _DebuggerDraftVariableSaver( + account=account, app_id=app_id, node_id=node_id, node_type=node_type, node_execution_id=node_execution_id, enclosing_node_id=enclosing_node_id, - user=account, ) else: def draft_var_saver_factory( - session: Session, app_id: str, node_id: str, node_type: NodeType, node_execution_id: str, enclosing_node_id: str | None = None, ) -> DraftVariableSaver: + _ = app_id, node_id, node_type, node_execution_id, enclosing_node_id return NoopDraftVariableSaver() return draft_var_saver_factory diff --git a/api/core/app/apps/common/graph_runtime_state_support.py b/api/core/app/apps/common/graph_runtime_state_support.py index 6a8e436163..c2343c2108 100644 --- a/api/core/app/apps/common/graph_runtime_state_support.py +++ b/api/core/app/apps/common/graph_runtime_state_support.py @@ -4,6 +4,7 @@ from __future__ import annotations from typing import TYPE_CHECKING +from core.workflow.system_variables import SystemVariableKey, get_system_text from dify_graph.runtime import GraphRuntimeState if TYPE_CHECKING: @@ -30,10 +31,10 @@ class GraphRuntimeStateSupport: return self._resolve_graph_runtime_state(graph_runtime_state) def _extract_workflow_run_id(self, graph_runtime_state: GraphRuntimeState) -> str: - system_variables = graph_runtime_state.variable_pool.system_variables - if not system_variables or not system_variables.workflow_execution_id: + workflow_run_id = get_system_text(graph_runtime_state.variable_pool, SystemVariableKey.WORKFLOW_EXECUTION_ID) + if not workflow_run_id: raise ValueError("workflow_execution_id missing from runtime state") - return str(system_variables.workflow_execution_id) + return workflow_run_id def _resolve_graph_runtime_state( self, diff --git a/api/core/app/apps/common/workflow_response_converter.py b/api/core/app/apps/common/workflow_response_converter.py index 621b0d8cf3..da2e6e32a3 100644 --- a/api/core/app/apps/common/workflow_response_converter.py +++ b/api/core/app/apps/common/workflow_response_converter.py @@ -1,3 +1,4 @@ +import json import logging import time from collections.abc import Mapping, Sequence @@ -50,25 +51,25 @@ from core.tools.entities.tool_entities import ToolProviderType from core.tools.tool_manager import ToolManager from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE from core.trigger.trigger_manager import TriggerManager +from core.workflow.system_variables import SystemVariableKey, system_variables_to_mapping from core.workflow.workflow_entry import WorkflowEntry from dify_graph.entities.pause_reason import HumanInputRequired from dify_graph.entities.workflow_start_reason import WorkflowStartReason from dify_graph.enums import ( BuiltinNodeTypes, - SystemVariableKey, WorkflowExecutionStatus, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus, ) from dify_graph.file import FILE_MODEL_IDENTITY, File from dify_graph.runtime import GraphRuntimeState -from dify_graph.system_variable import SystemVariable from dify_graph.variables.segments import ArrayFileSegment, FileSegment, Segment +from dify_graph.variables.variables import Variable from dify_graph.workflow_type_encoder import WorkflowRuntimeTypeConverter from extensions.ext_database import db from libs.datetime_utils import naive_utc_now from models import Account, EndUser -from models.human_input import HumanInputForm +from models.human_input import HumanInputForm, HumanInputFormRecipient, RecipientType from models.workflow import WorkflowRun from services.variable_truncator import BaseTruncator, DummyVariableTruncator, VariableTruncator @@ -111,11 +112,11 @@ class WorkflowResponseConverter: *, application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity], user: Union[Account, EndUser], - system_variables: SystemVariable, + system_variables: Sequence[Variable], ): self._application_generate_entity = application_generate_entity self._user = user - self._system_variables = system_variables + self._system_variables = system_variables_to_mapping(system_variables) self._workflow_inputs = self._prepare_workflow_inputs() # Disable truncation for SERVICE_API calls to keep backward compatibility. @@ -133,7 +134,7 @@ class WorkflowResponseConverter: # ------------------------------------------------------------------ def _prepare_workflow_inputs(self) -> Mapping[str, Any]: inputs = dict(self._application_generate_entity.inputs) - for field_name, value in self._system_variables.to_dict().items(): + for field_name, value in self._system_variables.items(): # TODO(@future-refactor): store system variables separately from user inputs so we don't # need to flatten `sys.*` entries into the input payload just for rerun/export tooling. if field_name == SystemVariableKey.CONVERSATION_ID: @@ -318,13 +319,37 @@ class WorkflowResponseConverter: pause_reasons = [reason.model_dump(mode="json") for reason in event.reasons] human_input_form_ids = [reason.form_id for reason in event.reasons if isinstance(reason, HumanInputRequired)] expiration_times_by_form_id: dict[str, datetime] = {} + display_in_ui_by_form_id: dict[str, bool] = {} + form_token_by_form_id: dict[str, tuple[int, str]] = {} if human_input_form_ids: - stmt = select(HumanInputForm.id, HumanInputForm.expiration_time).where( - HumanInputForm.id.in_(human_input_form_ids) + stmt = select( + HumanInputForm.id, + HumanInputForm.expiration_time, + HumanInputForm.form_definition, + ).where(HumanInputForm.id.in_(human_input_form_ids)) + recipient_stmt = select(HumanInputFormRecipient).where( + HumanInputFormRecipient.form_id.in_(human_input_form_ids) ) + token_priority = { + RecipientType.BACKSTAGE: 0, + RecipientType.CONSOLE: 1, + RecipientType.STANDALONE_WEB_APP: 2, + } with Session(bind=db.engine) as session: - for form_id, expiration_time in session.execute(stmt): + for form_id, expiration_time, form_definition in session.execute(stmt): expiration_times_by_form_id[str(form_id)] = expiration_time + try: + definition_payload = json.loads(form_definition) if form_definition else {} + except (TypeError, json.JSONDecodeError): + definition_payload = {} + display_in_ui_by_form_id[str(form_id)] = bool(definition_payload.get("display_in_ui")) + for recipient in session.scalars(recipient_stmt): + if recipient.recipient_type not in token_priority or not recipient.access_token: + continue + candidate = (token_priority[recipient.recipient_type], recipient.access_token) + current = form_token_by_form_id.get(recipient.form_id) + if current is None or candidate[0] < current[0]: + form_token_by_form_id[recipient.form_id] = candidate responses: list[StreamResponse] = [] @@ -344,8 +369,8 @@ class WorkflowResponseConverter: form_content=reason.form_content, inputs=reason.inputs, actions=reason.actions, - display_in_ui=reason.display_in_ui, - form_token=reason.form_token, + display_in_ui=display_in_ui_by_form_id.get(reason.form_id, False), + form_token=form_token_by_form_id.get(reason.form_id, (99, None))[1], resolved_default_values=reason.resolved_default_values, expiration_time=int(expiration_time.timestamp()), ), diff --git a/api/dify_graph/repositories/draft_variable_repository.py b/api/core/app/apps/draft_variable_saver.py similarity index 69% rename from api/dify_graph/repositories/draft_variable_repository.py rename to api/core/app/apps/draft_variable_saver.py index b2ebfacffd..4963300e94 100644 --- a/api/dify_graph/repositories/draft_variable_repository.py +++ b/api/core/app/apps/draft_variable_saver.py @@ -4,31 +4,30 @@ import abc from collections.abc import Mapping from typing import Any, Protocol -from sqlalchemy.orm import Session - from dify_graph.enums import NodeType class DraftVariableSaver(Protocol): @abc.abstractmethod - def save(self, process_data: Mapping[str, Any] | None, outputs: Mapping[str, Any] | None): - pass + def save(self, process_data: Mapping[str, Any] | None, outputs: Mapping[str, Any] | None) -> None: + """Persist node draft variables for a completed execution.""" + raise NotImplementedError class DraftVariableSaverFactory(Protocol): @abc.abstractmethod def __call__( self, - session: Session, app_id: str, node_id: str, node_type: NodeType, node_execution_id: str, enclosing_node_id: str | None = None, ) -> DraftVariableSaver: - pass + """Build a saver bound to a concrete node execution.""" + raise NotImplementedError class NoopDraftVariableSaver(DraftVariableSaver): - def save(self, process_data: Mapping[str, Any] | None, outputs: Mapping[str, Any] | None): - pass + def save(self, process_data: Mapping[str, Any] | None, outputs: Mapping[str, Any] | None) -> None: + return None diff --git a/api/core/app/apps/message_based_app_generator.py b/api/core/app/apps/message_based_app_generator.py index 64c28ca60f..9825cf9bf5 100644 --- a/api/core/app/apps/message_based_app_generator.py +++ b/api/core/app/apps/message_based_app_generator.py @@ -28,6 +28,7 @@ from core.app.entities.task_entities import ( ) from core.app.task_pipeline.easy_ui_based_generate_task_pipeline import EasyUIBasedGenerateTaskPipeline from core.prompt.utils.prompt_template_parser import PromptTemplateParser +from core.workflow.file_reference import parse_file_reference from extensions.ext_database import db from extensions.ext_redis import get_pubsub_broadcast_channel from libs.broadcast_channel.channel import Topic @@ -42,6 +43,13 @@ from services.errors.message import MessageNotExistsError logger = logging.getLogger(__name__) +def _resolve_file_record_id(reference: str | None) -> str | None: + parsed_reference = parse_file_reference(reference) + if parsed_reference is None: + return reference + return parsed_reference.record_id + + class MessageBasedAppGenerator(BaseAppGenerator): def _handle_response( self, @@ -227,7 +235,7 @@ class MessageBasedAppGenerator(BaseAppGenerator): transfer_method=file.transfer_method, belongs_to=MessageFileBelongsTo.USER, url=file.remote_url, - upload_file_id=file.related_id, + upload_file_id=_resolve_file_record_id(file.reference), created_by_role=(CreatorUserRole.ACCOUNT if account_id else CreatorUserRole.END_USER), created_by=account_id or end_user_id or "", ) diff --git a/api/core/app/apps/pipeline/pipeline_generator.py b/api/core/app/apps/pipeline/pipeline_generator.py index 19d67eb108..be4f1b5841 100644 --- a/api/core/app/apps/pipeline/pipeline_generator.py +++ b/api/core/app/apps/pipeline/pipeline_generator.py @@ -18,6 +18,7 @@ import contexts from configs import dify_config from core.app.apps.base_app_generator import BaseAppGenerator from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom +from core.app.apps.draft_variable_saver import DraftVariableSaverFactory from core.app.apps.exc import GenerateTaskStoppedError from core.app.apps.pipeline.pipeline_config_manager import PipelineConfigManager from core.app.apps.pipeline.pipeline_queue_manager import PipelineQueueManager @@ -34,11 +35,12 @@ from core.datasource.entities.datasource_entities import ( from core.datasource.online_drive.online_drive_plugin import OnlineDriveDatasourcePlugin from core.entities.knowledge_entities import PipelineDataset, PipelineDocument from core.rag.index_processor.constant.built_in_field import BuiltInField -from core.repositories.factory import DifyCoreRepositoryFactory +from core.repositories.factory import ( + DifyCoreRepositoryFactory, + WorkflowExecutionRepository, + WorkflowNodeExecutionRepository, +) from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError -from dify_graph.repositories.draft_variable_repository import DraftVariableSaverFactory -from dify_graph.repositories.workflow_execution_repository import WorkflowExecutionRepository -from dify_graph.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository from dify_graph.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader from extensions.ext_database import db from libs.flask_utils import preserve_flask_contexts diff --git a/api/core/app/apps/pipeline/pipeline_runner.py b/api/core/app/apps/pipeline/pipeline_runner.py index e767766bdb..1249303285 100644 --- a/api/core/app/apps/pipeline/pipeline_runner.py +++ b/api/core/app/apps/pipeline/pipeline_runner.py @@ -12,16 +12,17 @@ from core.app.entities.app_invoke_entities import ( build_dify_run_context, ) from core.app.workflow.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer +from core.repositories.factory import WorkflowExecutionRepository, WorkflowNodeExecutionRepository from core.workflow.node_factory import DifyNodeFactory, get_default_root_node_id +from core.workflow.system_variables import build_bootstrap_variables, build_system_variables +from core.workflow.variable_pool_initializer import add_node_inputs_to_pool, add_variables_to_pool +from core.workflow.variable_prefixes import CHILD_SYNC_VARIABLE_NODE_IDS from core.workflow.workflow_entry import WorkflowEntry from dify_graph.entities.graph_init_params import GraphInitParams from dify_graph.enums import WorkflowType from dify_graph.graph import Graph from dify_graph.graph_events import GraphEngineEvent, GraphRunFailedEvent -from dify_graph.repositories.workflow_execution_repository import WorkflowExecutionRepository -from dify_graph.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable from dify_graph.variable_loader import VariableLoader from dify_graph.variables.variables import RAGPipelineVariable, RAGPipelineVariableInput from extensions.ext_database import db @@ -112,7 +113,7 @@ class PipelineRunner(WorkflowBasedAppRunner): files = self.application_generate_entity.files # Create a variable pool. - system_inputs = SystemVariable( + system_inputs = build_system_variables( files=files, user_id=user_id, app_id=app_config.app_id, @@ -142,19 +143,25 @@ class PipelineRunner(WorkflowBasedAppRunner): ) ) - variable_pool = VariablePool( - system_variables=system_inputs, - user_inputs=inputs, - environment_variables=workflow.environment_variables, - conversation_variables=[], - rag_pipeline_variables=rag_pipeline_variables, + variable_pool = VariablePool() + add_variables_to_pool( + variable_pool, + build_bootstrap_variables( + system_variables=system_inputs, + environment_variables=workflow.environment_variables, + rag_pipeline_variables=rag_pipeline_variables, + ), ) + root_node_id = self.application_generate_entity.start_node_id or get_default_root_node_id( + workflow.graph_dict + ) + add_node_inputs_to_pool(variable_pool, node_id=root_node_id, inputs=inputs) graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) # init graph graph = self._init_rag_pipeline_graph( graph_runtime_state=graph_runtime_state, - start_node_id=self.application_generate_entity.start_node_id, + start_node_id=root_node_id, workflow=workflow, user_from=user_from, invoke_from=invoke_from, @@ -268,6 +275,7 @@ class PipelineRunner(WorkflowBasedAppRunner): invoke_from=invoke_from, ), call_depth=0, + child_sync_variable_node_ids=CHILD_SYNC_VARIABLE_NODE_IDS, ) node_factory = DifyNodeFactory( diff --git a/api/core/app/apps/workflow/app_generator.py b/api/core/app/apps/workflow/app_generator.py index 6fbe19a3b2..6fd8efcc32 100644 --- a/api/core/app/apps/workflow/app_generator.py +++ b/api/core/app/apps/workflow/app_generator.py @@ -17,6 +17,7 @@ from configs import dify_config from core.app.app_config.features.file_upload.manager import FileUploadConfigManager from core.app.apps.base_app_generator import BaseAppGenerator from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom +from core.app.apps.draft_variable_saver import DraftVariableSaverFactory from core.app.apps.exc import GenerateTaskStoppedError from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager from core.app.apps.workflow.app_queue_manager import WorkflowAppQueueManager @@ -30,11 +31,9 @@ from core.db.session_factory import session_factory from core.helper.trace_id_helper import extract_external_trace_id_from_args from core.ops.ops_trace_manager import TraceQueueManager from core.repositories import DifyCoreRepositoryFactory +from core.repositories.factory import WorkflowExecutionRepository, WorkflowNodeExecutionRepository from dify_graph.graph_engine.layers.base import GraphEngineLayer from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError -from dify_graph.repositories.draft_variable_repository import DraftVariableSaverFactory -from dify_graph.repositories.workflow_execution_repository import WorkflowExecutionRepository -from dify_graph.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository from dify_graph.runtime import GraphRuntimeState from dify_graph.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader from extensions.ext_database import db diff --git a/api/core/app/apps/workflow/app_runner.py b/api/core/app/apps/workflow/app_runner.py index caea8b6b95..0cae506a4b 100644 --- a/api/core/app/apps/workflow/app_runner.py +++ b/api/core/app/apps/workflow/app_runner.py @@ -8,14 +8,15 @@ from core.app.apps.workflow.app_config_manager import WorkflowAppConfig from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity from core.app.workflow.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer +from core.repositories.factory import WorkflowExecutionRepository, WorkflowNodeExecutionRepository +from core.workflow.node_factory import get_default_root_node_id +from core.workflow.system_variables import build_bootstrap_variables, build_system_variables +from core.workflow.variable_pool_initializer import add_node_inputs_to_pool, add_variables_to_pool from core.workflow.workflow_entry import WorkflowEntry from dify_graph.enums import WorkflowType from dify_graph.graph_engine.command_channels.redis_channel import RedisChannel from dify_graph.graph_engine.layers.base import GraphEngineLayer -from dify_graph.repositories.workflow_execution_repository import WorkflowExecutionRepository -from dify_graph.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable from dify_graph.variable_loader import VariableLoader from extensions.ext_redis import redis_client from extensions.otel import WorkflowAppRunnerHandler, trace_span @@ -96,7 +97,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner): inputs = self.application_generate_entity.inputs # Create a variable pool. - system_inputs = SystemVariable( + system_inputs = build_system_variables( files=self.application_generate_entity.files, user_id=self._sys_user_id, app_id=app_config.app_id, @@ -104,12 +105,16 @@ class WorkflowAppRunner(WorkflowBasedAppRunner): workflow_id=app_config.workflow_id, workflow_execution_id=self.application_generate_entity.workflow_execution_id, ) - variable_pool = VariablePool( - system_variables=system_inputs, - user_inputs=inputs, - environment_variables=self._workflow.environment_variables, - conversation_variables=[], + variable_pool = VariablePool() + add_variables_to_pool( + variable_pool, + build_bootstrap_variables( + system_variables=system_inputs, + environment_variables=self._workflow.environment_variables, + ), ) + root_node_id = self._root_node_id or get_default_root_node_id(self._workflow.graph_dict) + add_node_inputs_to_pool(variable_pool, node_id=root_node_id, inputs=inputs) graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) graph = self._init_graph( @@ -120,7 +125,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner): user_id=self.application_generate_entity.user_id, user_from=user_from, invoke_from=invoke_from, - root_node_id=self._root_node_id, + root_node_id=root_node_id, ) # RUN WORKFLOW diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py index 96dd8c5445..6baa9cfffb 100644 --- a/api/core/app/apps/workflow/generate_task_pipeline.py +++ b/api/core/app/apps/workflow/generate_task_pipeline.py @@ -10,6 +10,7 @@ from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.apps.common.graph_runtime_state_support import GraphRuntimeStateSupport from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter +from core.app.apps.draft_variable_saver import DraftVariableSaverFactory from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity from core.app.entities.queue_entities import ( AppQueueEvent, @@ -55,11 +56,10 @@ from core.app.entities.task_entities import ( from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk from core.ops.ops_trace_manager import TraceQueueManager +from core.workflow.system_variables import build_system_variables from dify_graph.entities.workflow_start_reason import WorkflowStartReason from dify_graph.enums import WorkflowExecutionStatus -from dify_graph.repositories.draft_variable_repository import DraftVariableSaverFactory from dify_graph.runtime import GraphRuntimeState -from dify_graph.system_variable import SystemVariable from extensions.ext_database import db from models import Account from models.enums import CreatorUserRole @@ -104,7 +104,7 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport): self._invoke_from = queue_manager.invoke_from self._draft_var_saver_factory = draft_var_saver_factory self._workflow = workflow - self._workflow_system_variables = SystemVariable( + self._workflow_system_variables = build_system_variables( files=application_generate_entity.files, user_id=user_session_id, app_id=application_generate_entity.app_config.app_id, @@ -728,13 +728,11 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport): return response def _save_output_for_event(self, event: QueueNodeSucceededEvent | QueueNodeExceptionEvent, node_execution_id: str): - with Session(db.engine) as session, session.begin(): - saver = self._draft_var_saver_factory( - session=session, - app_id=self._application_generate_entity.app_config.app_id, - node_id=event.node_id, - node_type=event.node_type, - node_execution_id=node_execution_id, - enclosing_node_id=event.in_loop_id or event.in_iteration_id, - ) - saver.save(event.process_data, event.outputs) + saver = self._draft_var_saver_factory( + app_id=self._application_generate_entity.app_config.app_id, + node_id=event.node_id, + node_type=event.node_type, + node_execution_id=node_execution_id, + enclosing_node_id=event.in_loop_id or event.in_iteration_id, + ) + saver.save(event.process_data, event.outputs) diff --git a/api/core/app/apps/workflow_app_runner.py b/api/core/app/apps/workflow_app_runner.py index adc6cce9af..ded8fb0035 100644 --- a/api/core/app/apps/workflow_app_runner.py +++ b/api/core/app/apps/workflow_app_runner.py @@ -34,6 +34,13 @@ from core.app.entities.queue_entities import ( ) from core.rag.entities.citation_metadata import RetrievalSourceMetadata from core.workflow.node_factory import DifyNodeFactory, get_default_root_node_id, resolve_workflow_node_class +from core.workflow.system_variables import ( + build_bootstrap_variables, + default_system_variables, + inject_default_system_variable_mappings, +) +from core.workflow.variable_pool_initializer import add_variables_to_pool +from core.workflow.variable_prefixes import CHILD_SYNC_VARIABLE_NODE_IDS from core.workflow.workflow_entry import WorkflowEntry from dify_graph.entities import GraphInitParams from dify_graph.entities.graph_config import NodeConfigDictAdapter @@ -68,7 +75,6 @@ from dify_graph.graph_events import ( ) from dify_graph.graph_events.graph import GraphRunAbortedEvent from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable from dify_graph.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool from models.workflow import Workflow from tasks.mail_human_input_delivery_task import dispatch_human_input_email_task @@ -131,6 +137,7 @@ class WorkflowBasedAppRunner: invoke_from=invoke_from, ), call_depth=0, + child_sync_variable_node_ids=CHILD_SYNC_VARIABLE_NODE_IDS, ) # Use the provided graph_runtime_state for consistent state management @@ -173,14 +180,15 @@ class WorkflowBasedAppRunner: ValueError: If neither single_iteration_run nor single_loop_run is specified """ # Create initial runtime state with variable pool containing environment variables - graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool( - system_variables=SystemVariable.default(), - user_inputs={}, + variable_pool = VariablePool() + add_variables_to_pool( + variable_pool, + build_bootstrap_variables( + system_variables=default_system_variables(), environment_variables=workflow.environment_variables, ), - start_at=time.time(), ) + graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.time()) # Determine which type of single node execution and get graph/variable_pool if single_iteration_run: @@ -284,6 +292,7 @@ class WorkflowBasedAppRunner: invoke_from=InvokeFrom.DEBUGGER, ), call_depth=0, + child_sync_variable_node_ids=CHILD_SYNC_VARIABLE_NODE_IDS, ) node_factory = DifyNodeFactory( @@ -325,6 +334,12 @@ class WorkflowBasedAppRunner: ) except NotImplementedError: variable_mapping = {} + variable_mapping = inject_default_system_variable_mappings( + node_id=target_node_config["id"], + node_type=node_type, + node_data=target_node_config["data"], + variable_mapping=variable_mapping, + ) load_into_variable_pool( variable_loader=self._variable_loader, diff --git a/api/core/app/layers/conversation_variable_persist_layer.py b/api/core/app/layers/conversation_variable_persist_layer.py index d227e4e904..cedb87490b 100644 --- a/api/core/app/layers/conversation_variable_persist_layer.py +++ b/api/core/app/layers/conversation_variable_persist_layer.py @@ -1,12 +1,19 @@ +""" +Persist conversation-scoped variable updates emitted by the graph engine. + +The graph package emits generic variable update events and stays unaware of +conversation identity or storage concerns. This layer lives in the application +core, listens to those generic events, and persists only the `conversation.*` +scope updates that matter to chat applications. +""" + import logging -from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID -from dify_graph.conversation_variable_updater import ConversationVariableUpdater -from dify_graph.enums import BuiltinNodeTypes +from core.workflow.system_variables import SystemVariableKey, get_system_text +from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID from dify_graph.graph_engine.layers.base import GraphEngineLayer -from dify_graph.graph_events import GraphEngineEvent, NodeRunSucceededEvent -from dify_graph.nodes.variable_assigner.common import helpers as common_helpers -from dify_graph.variables import VariableBase +from dify_graph.graph_events import GraphEngineEvent, NodeRunVariableUpdatedEvent +from services.conversation_variable_updater import ConversationVariableUpdater logger = logging.getLogger(__name__) @@ -20,41 +27,22 @@ class ConversationVariablePersistenceLayer(GraphEngineLayer): pass def on_event(self, event: GraphEngineEvent) -> None: - if not isinstance(event, NodeRunSucceededEvent): - return - if event.node_type != BuiltinNodeTypes.VARIABLE_ASSIGNER: - return - if self.graph_runtime_state is None: + if not isinstance(event, NodeRunVariableUpdatedEvent): return - updated_variables = common_helpers.get_updated_variables(event.node_run_result.process_data) or [] - if not updated_variables: + selector = event.variable.selector + if len(selector) < 2: + logger.warning("Conversation variable selector invalid. selector=%s", selector) return - conversation_id = self.graph_runtime_state.system_variable.conversation_id + conversation_id = get_system_text(self.graph_runtime_state.variable_pool, SystemVariableKey.CONVERSATION_ID) if conversation_id is None: return - updated_any = False - for item in updated_variables: - selector = item.selector - if len(selector) < 2: - logger.warning("Conversation variable selector invalid. selector=%s", selector) - continue - if selector[0] != CONVERSATION_VARIABLE_NODE_ID: - continue - variable = self.graph_runtime_state.variable_pool.get(selector) - if not isinstance(variable, VariableBase): - logger.warning( - "Conversation variable not found in variable pool. selector=%s", - selector, - ) - continue - self._conversation_variable_updater.update(conversation_id=conversation_id, variable=variable) - updated_any = True + if selector[0] != CONVERSATION_VARIABLE_NODE_ID: + return - if updated_any: - self._conversation_variable_updater.flush() + self._conversation_variable_updater.update(conversation_id=conversation_id, variable=event.variable) def on_graph_end(self, error: Exception | None) -> None: pass diff --git a/api/core/app/layers/pause_state_persist_layer.py b/api/core/app/layers/pause_state_persist_layer.py index 4370c01a0b..93ab7ae9ce 100644 --- a/api/core/app/layers/pause_state_persist_layer.py +++ b/api/core/app/layers/pause_state_persist_layer.py @@ -6,6 +6,7 @@ from sqlalchemy import Engine from sqlalchemy.orm import Session, sessionmaker from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity +from core.workflow.system_variables import SystemVariableKey, get_system_text from dify_graph.graph_engine.layers.base import GraphEngineLayer from dify_graph.graph_events.base import GraphEngineEvent from dify_graph.graph_events.graph import GraphRunPausedEvent @@ -119,7 +120,10 @@ class PauseStatePersistenceLayer(GraphEngineLayer): generate_entity=entity_wrapper, ) - workflow_run_id: str | None = self.graph_runtime_state.system_variable.workflow_execution_id + workflow_run_id = get_system_text( + self.graph_runtime_state.variable_pool, + SystemVariableKey.WORKFLOW_EXECUTION_ID, + ) assert workflow_run_id is not None repo = self._get_repo() repo.create_workflow_pause( diff --git a/api/core/app/layers/trigger_post_layer.py b/api/core/app/layers/trigger_post_layer.py index a4019a83e1..da9c6bada9 100644 --- a/api/core/app/layers/trigger_post_layer.py +++ b/api/core/app/layers/trigger_post_layer.py @@ -5,6 +5,7 @@ from typing import Any, ClassVar from pydantic import TypeAdapter from core.db.session_factory import session_factory +from core.workflow.system_variables import SystemVariableKey, get_system_text from dify_graph.graph_engine.layers.base import GraphEngineLayer from dify_graph.graph_events.base import GraphEngineEvent from dify_graph.graph_events.graph import GraphRunFailedEvent, GraphRunPausedEvent, GraphRunSucceededEvent @@ -59,7 +60,10 @@ class TriggerPostLayer(GraphEngineLayer): outputs = self.graph_runtime_state.outputs # BASICLY, workflow_execution_id is the same as workflow_run_id - workflow_run_id = self.graph_runtime_state.system_variable.workflow_execution_id + workflow_run_id = get_system_text( + self.graph_runtime_state.variable_pool, + SystemVariableKey.WORKFLOW_EXECUTION_ID, + ) assert workflow_run_id, "Workflow run id is not set" total_tokens = self.graph_runtime_state.total_tokens diff --git a/api/core/app/workflow/file_runtime.py b/api/core/app/workflow/file_runtime.py index e0f8d27111..f674216d0c 100644 --- a/api/core/app/workflow/file_runtime.py +++ b/api/core/app/workflow/file_runtime.py @@ -1,34 +1,32 @@ from __future__ import annotations +import base64 +import hashlib +import hmac +import os +import time +import urllib.parse from collections.abc import Generator +from typing import TYPE_CHECKING, Literal from configs import dify_config +from core.db.session_factory import session_factory from core.helper.ssrf_proxy import ssrf_proxy from core.tools.signature import sign_tool_file +from core.workflow.file_reference import parse_file_reference +from dify_graph.file.enums import FileTransferMethod from dify_graph.file.protocols import HttpResponseProtocol, WorkflowFileRuntimeProtocol from dify_graph.file.runtime import set_workflow_file_runtime from extensions.ext_storage import storage +from models import ToolFile, UploadFile + +if TYPE_CHECKING: + from dify_graph.file.models import File class DifyWorkflowFileRuntime(WorkflowFileRuntimeProtocol): """Production runtime wiring for ``dify_graph.file``.""" - @property - def files_url(self) -> str: - return dify_config.FILES_URL - - @property - def internal_files_url(self) -> str | None: - return dify_config.INTERNAL_FILES_URL - - @property - def secret_key(self) -> str: - return dify_config.SECRET_KEY - - @property - def files_access_timeout(self) -> int: - return dify_config.FILES_ACCESS_TIMEOUT - @property def multimodal_send_format(self) -> str: return dify_config.MULTIMODAL_SEND_FORMAT @@ -39,9 +37,110 @@ class DifyWorkflowFileRuntime(WorkflowFileRuntimeProtocol): def storage_load(self, path: str, *, stream: bool = False) -> bytes | Generator: return storage.load(path, stream=stream) - def sign_tool_file(self, *, tool_file_id: str, extension: str, for_external: bool = True) -> str: + def load_file_bytes(self, *, file: File) -> bytes: + storage_key = self._resolve_storage_key(file=file) + data = storage.load(storage_key, stream=False) + if not isinstance(data, bytes): + raise ValueError(f"file {storage_key} is not a bytes object") + return data + + def resolve_file_url(self, *, file: File, for_external: bool = True) -> str | None: + if file.transfer_method == FileTransferMethod.REMOTE_URL: + return file.remote_url + parsed_reference = parse_file_reference(file.reference) + if parsed_reference is None: + raise ValueError("Missing file reference") + if file.transfer_method == FileTransferMethod.LOCAL_FILE: + return self.resolve_upload_file_url( + upload_file_id=parsed_reference.record_id, + for_external=for_external, + ) + if file.transfer_method in {FileTransferMethod.TOOL_FILE, FileTransferMethod.DATASOURCE_FILE}: + if file.extension is None: + raise ValueError("Missing file extension") + return self.resolve_tool_file_url( + tool_file_id=parsed_reference.record_id, + extension=file.extension, + for_external=for_external, + ) + return None + + def resolve_upload_file_url( + self, + *, + upload_file_id: str, + as_attachment: bool = False, + for_external: bool = True, + ) -> str: + base_url = self._base_url(for_external=for_external) + url = f"{base_url}/files/{upload_file_id}/file-preview" + query = self._sign_query(payload=f"file-preview|{upload_file_id}") + if as_attachment: + query["as_attachment"] = "true" + return f"{url}?{urllib.parse.urlencode(query)}" + + def resolve_tool_file_url(self, *, tool_file_id: str, extension: str, for_external: bool = True) -> str: return sign_tool_file(tool_file_id=tool_file_id, extension=extension, for_external=for_external) + def verify_preview_signature( + self, + *, + preview_kind: Literal["image", "file"], + file_id: str, + timestamp: str, + nonce: str, + sign: str, + ) -> bool: + payload = f"{preview_kind}-preview|{file_id}|{timestamp}|{nonce}" + recalculated = hmac.new(self._secret_key(), payload.encode(), hashlib.sha256).digest() + if sign != base64.urlsafe_b64encode(recalculated).decode(): + return False + return int(time.time()) - int(timestamp) <= dify_config.FILES_ACCESS_TIMEOUT + + @staticmethod + def _base_url(*, for_external: bool) -> str: + if for_external: + return dify_config.FILES_URL + return dify_config.INTERNAL_FILES_URL or dify_config.FILES_URL + + @staticmethod + def _secret_key() -> bytes: + return dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b"" + + def _sign_query(self, *, payload: str) -> dict[str, str]: + timestamp = str(int(time.time())) + nonce = os.urandom(16).hex() + sign = hmac.new(self._secret_key(), f"{payload}|{timestamp}|{nonce}".encode(), hashlib.sha256).digest() + return { + "timestamp": timestamp, + "nonce": nonce, + "sign": base64.urlsafe_b64encode(sign).decode(), + } + + def _resolve_storage_key(self, *, file: File) -> str: + parsed_reference = parse_file_reference(file.reference) + if parsed_reference is None: + raise ValueError("Missing file reference") + if parsed_reference.storage_key: + return parsed_reference.storage_key + + record_id = parsed_reference.record_id + with session_factory.create_session() as session: + if file.transfer_method in { + FileTransferMethod.LOCAL_FILE, + FileTransferMethod.REMOTE_URL, + FileTransferMethod.DATASOURCE_FILE, + }: + upload_file = session.get(UploadFile, record_id) + if upload_file is None: + raise ValueError(f"Upload file {record_id} not found") + return upload_file.key + + tool_file = session.get(ToolFile, record_id) + if tool_file is None: + raise ValueError(f"Tool file {record_id} not found") + return tool_file.file_key + def bind_dify_workflow_file_runtime() -> None: set_workflow_file_runtime(DifyWorkflowFileRuntime()) diff --git a/api/core/app/workflow/layers/persistence.py b/api/core/app/workflow/layers/persistence.py index 99b64b3ab5..62324fee35 100644 --- a/api/core/app/workflow/layers/persistence.py +++ b/api/core/app/workflow/layers/persistence.py @@ -17,10 +17,11 @@ from typing import Any, Union from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity from core.ops.entities.trace_entity import TraceTaskName from core.ops.ops_trace_manager import TraceQueueManager, TraceTask -from dify_graph.constants import SYSTEM_VARIABLE_NODE_ID +from core.repositories.factory import WorkflowExecutionRepository, WorkflowNodeExecutionRepository +from core.workflow.system_variables import SystemVariableKey +from core.workflow.variable_prefixes import SYSTEM_VARIABLE_NODE_ID from dify_graph.entities import WorkflowExecution, WorkflowNodeExecution from dify_graph.enums import ( - SystemVariableKey, WorkflowExecutionStatus, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus, @@ -43,8 +44,6 @@ from dify_graph.graph_events import ( NodeRunSucceededEvent, ) from dify_graph.node_events import NodeRunResult -from dify_graph.repositories.workflow_execution_repository import WorkflowExecutionRepository -from dify_graph.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository from libs.datetime_utils import naive_utc_now diff --git a/api/core/datasource/datasource_manager.py b/api/core/datasource/datasource_manager.py index 6b9751a5a4..28edbc25f4 100644 --- a/api/core/datasource/datasource_manager.py +++ b/api/core/datasource/datasource_manager.py @@ -24,6 +24,7 @@ from core.datasource.utils.message_transformer import DatasourceFileMessageTrans from core.datasource.website_crawl.website_crawl_provider import WebsiteCrawlDatasourcePluginProviderController from core.db.session_factory import session_factory from core.plugin.impl.datasource import PluginDatasourceManager +from core.workflow.file_reference import build_file_reference from core.workflow.nodes.datasource.entities import DatasourceParameter, OnlineDriveDownloadFileParam from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus from dify_graph.enums import WorkflowNodeExecutionMetadataKey @@ -351,13 +352,11 @@ class DatasourceManager: filename=upload_file.name, extension="." + upload_file.extension, mime_type=upload_file.mime_type, - tenant_id=tenant_id, type=FileType.CUSTOM, transfer_method=FileTransferMethod.LOCAL_FILE, remote_url=upload_file.source_url, - related_id=upload_file.id, + reference=build_file_reference(record_id=str(upload_file.id), storage_key=upload_file.key), size=upload_file.size, - storage_key=upload_file.key, url=upload_file.source_url, ) return file_info diff --git a/api/core/datasource/utils/message_transformer.py b/api/core/datasource/utils/message_transformer.py index 2881888e27..455656b396 100644 --- a/api/core/datasource/utils/message_transformer.py +++ b/api/core/datasource/utils/message_transformer.py @@ -4,6 +4,7 @@ from mimetypes import guess_extension, guess_type from core.datasource.entities.datasource_entities import DatasourceMessage from core.tools.tool_file_manager import ToolFileManager +from core.workflow.file_reference import parse_file_reference from dify_graph.file import File, FileTransferMethod, FileType from models.tools import ToolFile @@ -103,8 +104,13 @@ class DatasourceFileMessageTransformer: file: File | None = meta.get("file") if isinstance(file, File): if file.transfer_method == FileTransferMethod.TOOL_FILE: - assert file.related_id is not None - url = cls.get_datasource_file_url(datasource_file_id=file.related_id, extension=file.extension) + parsed_reference = parse_file_reference(file.reference) + if parsed_reference is None: + raise ValueError("datasource file is missing reference") + url = cls.get_datasource_file_url( + datasource_file_id=parsed_reference.record_id, + extension=file.extension, + ) if file.type == FileType.IMAGE: yield DatasourceMessage( type=DatasourceMessage.MessageType.IMAGE_LINK, diff --git a/api/core/ops/aliyun_trace/aliyun_trace.py b/api/core/ops/aliyun_trace/aliyun_trace.py index 18f35b5b9c..2a4527a349 100644 --- a/api/core/ops/aliyun_trace/aliyun_trace.py +++ b/api/core/ops/aliyun_trace/aliyun_trace.py @@ -296,7 +296,9 @@ class AliyunDataTrace(BaseTraceInstance): triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, ) - return workflow_node_execution_repository.get_by_workflow_run(workflow_run_id=trace_info.workflow_run_id) + return workflow_node_execution_repository.get_by_workflow_execution( + workflow_execution_id=trace_info.workflow_run_id + ) def build_workflow_node_span( self, node_execution: WorkflowNodeExecution, trace_info: WorkflowTraceInfo, trace_metadata: TraceMetadata diff --git a/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py b/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py index 7cb54b2c88..7cf599f4ac 100644 --- a/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py +++ b/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py @@ -275,8 +275,8 @@ class ArizePhoenixDataTrace(BaseTraceInstance): ) # Get all executions for this workflow run - workflow_node_executions = workflow_node_execution_repository.get_by_workflow_run( - workflow_run_id=trace_info.workflow_run_id + workflow_node_executions = workflow_node_execution_repository.get_by_workflow_execution( + workflow_execution_id=trace_info.workflow_run_id ) try: diff --git a/api/core/ops/langfuse_trace/langfuse_trace.py b/api/core/ops/langfuse_trace/langfuse_trace.py index 6e62387a1f..394b4e682c 100644 --- a/api/core/ops/langfuse_trace/langfuse_trace.py +++ b/api/core/ops/langfuse_trace/langfuse_trace.py @@ -130,8 +130,8 @@ class LangFuseDataTrace(BaseTraceInstance): ) # Get all executions for this workflow run - workflow_node_executions = workflow_node_execution_repository.get_by_workflow_run( - workflow_run_id=trace_info.workflow_run_id + workflow_node_executions = workflow_node_execution_repository.get_by_workflow_execution( + workflow_execution_id=trace_info.workflow_run_id ) for node_execution in workflow_node_executions: diff --git a/api/core/ops/langsmith_trace/langsmith_trace.py b/api/core/ops/langsmith_trace/langsmith_trace.py index 32a0c77fe2..c65440935e 100644 --- a/api/core/ops/langsmith_trace/langsmith_trace.py +++ b/api/core/ops/langsmith_trace/langsmith_trace.py @@ -152,8 +152,8 @@ class LangSmithDataTrace(BaseTraceInstance): ) # Get all executions for this workflow run - workflow_node_executions = workflow_node_execution_repository.get_by_workflow_run( - workflow_run_id=trace_info.workflow_run_id + workflow_node_executions = workflow_node_execution_repository.get_by_workflow_execution( + workflow_execution_id=trace_info.workflow_run_id ) for node_execution in workflow_node_executions: diff --git a/api/core/ops/opik_trace/opik_trace.py b/api/core/ops/opik_trace/opik_trace.py index fb72bc2381..70b1fb67a8 100644 --- a/api/core/ops/opik_trace/opik_trace.py +++ b/api/core/ops/opik_trace/opik_trace.py @@ -176,8 +176,8 @@ class OpikDataTrace(BaseTraceInstance): ) # Get all executions for this workflow run - workflow_node_executions = workflow_node_execution_repository.get_by_workflow_run( - workflow_run_id=trace_info.workflow_run_id + workflow_node_executions = workflow_node_execution_repository.get_by_workflow_execution( + workflow_execution_id=trace_info.workflow_run_id ) for node_execution in workflow_node_executions: diff --git a/api/core/ops/tencent_trace/tencent_trace.py b/api/core/ops/tencent_trace/tencent_trace.py index 7e56b1effa..d7e8ee964b 100644 --- a/api/core/ops/tencent_trace/tencent_trace.py +++ b/api/core/ops/tencent_trace/tencent_trace.py @@ -256,7 +256,7 @@ class TencentDataTrace(BaseTraceInstance): triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, ) - executions = repository.get_by_workflow_run(workflow_run_id=trace_info.workflow_run_id) + executions = repository.get_by_workflow_execution(workflow_execution_id=trace_info.workflow_run_id) return list(executions) except Exception: diff --git a/api/core/ops/weave_trace/weave_trace.py b/api/core/ops/weave_trace/weave_trace.py index 2a657b672c..98b232410a 100644 --- a/api/core/ops/weave_trace/weave_trace.py +++ b/api/core/ops/weave_trace/weave_trace.py @@ -161,8 +161,8 @@ class WeaveDataTrace(BaseTraceInstance): ) # Get all executions for this workflow run - workflow_node_executions = workflow_node_execution_repository.get_by_workflow_run( - workflow_run_id=trace_info.workflow_run_id + workflow_node_executions = workflow_node_execution_repository.get_by_workflow_execution( + workflow_execution_id=trace_info.workflow_run_id ) # rearrange workflow_node_executions by starting time diff --git a/api/core/rag/index_processor/processor/paragraph_index_processor.py b/api/core/rag/index_processor/processor/paragraph_index_processor.py index 80b0e20784..d49e4c13a3 100644 --- a/api/core/rag/index_processor/processor/paragraph_index_processor.py +++ b/api/core/rag/index_processor/processor/paragraph_index_processor.py @@ -27,6 +27,7 @@ from core.rag.index_processor.index_processor_base import BaseIndexProcessor, Su from core.rag.models.document import AttachmentDocument, Document, MultimodalGeneralStructureChunk from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.tools.utils.text_processing_utils import remove_leading_symbols +from core.workflow.file_reference import build_file_reference from dify_graph.file import File, FileTransferMethod, FileType, file_manager from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMUsage from dify_graph.model_runtime.entities.message_entities import ( @@ -604,13 +605,14 @@ class ParagraphIndexProcessor(BaseIndexProcessor): filename=upload_file.name, extension="." + upload_file.extension, mime_type=upload_file.mime_type, - tenant_id=tenant_id, type=FileType.IMAGE, transfer_method=FileTransferMethod.LOCAL_FILE, remote_url=upload_file.source_url, - related_id=upload_file.id, + reference=build_file_reference( + record_id=str(upload_file.id), + storage_key=upload_file.key, + ), size=upload_file.size, - storage_key=upload_file.key, ) file_objects.append(file_obj) except Exception as e: diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py index 3b808b36dc..d2c7300b42 100644 --- a/api/core/rag/retrieval/dataset_retrieval.py +++ b/api/core/rag/retrieval/dataset_retrieval.py @@ -56,6 +56,7 @@ from core.rag.retrieval.template_prompts import ( ) from core.tools.signature import sign_upload_file from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool +from core.workflow.file_reference import build_file_reference from core.workflow.nodes.knowledge_retrieval import exc from core.workflow.nodes.knowledge_retrieval.retrieval import ( KnowledgeRetrievalRequest, @@ -517,13 +518,14 @@ class DatasetRetrieval: filename=upload_file.name, extension="." + upload_file.extension, mime_type=upload_file.mime_type, - tenant_id=segment.tenant_id, type=FileType.IMAGE, transfer_method=FileTransferMethod.LOCAL_FILE, remote_url=upload_file.source_url, - related_id=upload_file.id, + reference=build_file_reference( + record_id=str(upload_file.id), + storage_key=upload_file.key, + ), size=upload_file.size, - storage_key=upload_file.key, url=sign_upload_file(upload_file.id, upload_file.extension), ) context_files.append(attachment_info) diff --git a/api/core/repositories/__init__.py b/api/core/repositories/__init__.py index 6f2826f634..cfa9962ea8 100644 --- a/api/core/repositories/__init__.py +++ b/api/core/repositories/__init__.py @@ -4,7 +4,13 @@ from __future__ import annotations from .celery_workflow_execution_repository import CeleryWorkflowExecutionRepository from .celery_workflow_node_execution_repository import CeleryWorkflowNodeExecutionRepository -from .factory import DifyCoreRepositoryFactory, RepositoryImportError +from .factory import ( + DifyCoreRepositoryFactory, + OrderConfig, + RepositoryImportError, + WorkflowExecutionRepository, + WorkflowNodeExecutionRepository, +) from .sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository from .sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository @@ -12,7 +18,10 @@ __all__ = [ "CeleryWorkflowExecutionRepository", "CeleryWorkflowNodeExecutionRepository", "DifyCoreRepositoryFactory", + "OrderConfig", "RepositoryImportError", "SQLAlchemyWorkflowExecutionRepository", "SQLAlchemyWorkflowNodeExecutionRepository", + "WorkflowExecutionRepository", + "WorkflowNodeExecutionRepository", ] diff --git a/api/core/repositories/celery_workflow_execution_repository.py b/api/core/repositories/celery_workflow_execution_repository.py index 57764574d7..a93727db2a 100644 --- a/api/core/repositories/celery_workflow_execution_repository.py +++ b/api/core/repositories/celery_workflow_execution_repository.py @@ -11,8 +11,8 @@ from typing import Union from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker +from core.repositories.factory import WorkflowExecutionRepository from dify_graph.entities.workflow_execution import WorkflowExecution -from dify_graph.repositories.workflow_execution_repository import WorkflowExecutionRepository from libs.helper import extract_tenant_id from models import Account, CreatorUserRole, EndUser from models.enums import WorkflowRunTriggeredFrom diff --git a/api/core/repositories/celery_workflow_node_execution_repository.py b/api/core/repositories/celery_workflow_node_execution_repository.py index 650cf79550..b0db967153 100644 --- a/api/core/repositories/celery_workflow_node_execution_repository.py +++ b/api/core/repositories/celery_workflow_node_execution_repository.py @@ -12,11 +12,11 @@ from typing import Union from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker -from dify_graph.entities.workflow_node_execution import WorkflowNodeExecution -from dify_graph.repositories.workflow_node_execution_repository import ( +from core.repositories.factory import ( OrderConfig, WorkflowNodeExecutionRepository, ) +from dify_graph.entities.workflow_node_execution import WorkflowNodeExecution from libs.helper import extract_tenant_id from models import Account, CreatorUserRole, EndUser from models.workflow import WorkflowNodeExecutionTriggeredFrom @@ -148,24 +148,24 @@ class CeleryWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository): # For now, we'll re-raise the exception raise - def get_by_workflow_run( + def get_by_workflow_execution( self, - workflow_run_id: str, + workflow_execution_id: str, order_config: OrderConfig | None = None, ) -> Sequence[WorkflowNodeExecution]: """ - Retrieve all WorkflowNodeExecution instances for a specific workflow run from cache. + Retrieve all workflow node executions for a workflow execution from cache. Args: - workflow_run_id: The workflow run ID + workflow_execution_id: The workflow execution identifier order_config: Optional configuration for ordering results Returns: A sequence of WorkflowNodeExecution instances """ try: - # Get execution IDs for this workflow run from cache - execution_ids = self._workflow_execution_mapping.get(workflow_run_id, []) + # Get execution IDs for this workflow execution from cache + execution_ids = self._workflow_execution_mapping.get(workflow_execution_id, []) # Retrieve executions from cache result = [] @@ -182,9 +182,16 @@ class CeleryWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository): for field_name in reversed(order_config.order_by): result.sort(key=lambda x: getattr(x, field_name, 0), reverse=reverse) - logger.debug("Retrieved %d workflow node executions for run %s from cache", len(result), workflow_run_id) + logger.debug( + "Retrieved %d workflow node executions for execution %s from cache", + len(result), + workflow_execution_id, + ) return result except Exception: - logger.exception("Failed to get workflow node executions for run %s from cache", workflow_run_id) + logger.exception( + "Failed to get workflow node executions for execution %s from cache", + workflow_execution_id, + ) return [] diff --git a/api/core/repositories/factory.py b/api/core/repositories/factory.py index dc9f8c96bf..76caf527db 100644 --- a/api/core/repositories/factory.py +++ b/api/core/repositories/factory.py @@ -5,20 +5,45 @@ This module provides a Django-like settings system for repository implementation allowing users to configure different repository backends through string paths. """ -from typing import Union +from collections.abc import Sequence +from dataclasses import dataclass +from typing import Literal, Protocol, Union from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker from configs import dify_config -from dify_graph.repositories.workflow_execution_repository import WorkflowExecutionRepository -from dify_graph.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository +from dify_graph.entities import WorkflowExecution, WorkflowNodeExecution from libs.module_loading import import_string from models import Account, EndUser from models.enums import WorkflowRunTriggeredFrom from models.workflow import WorkflowNodeExecutionTriggeredFrom +@dataclass +class OrderConfig: + """Configuration for ordering node execution instances.""" + + order_by: list[str] + order_direction: Literal["asc", "desc"] | None = None + + +class WorkflowExecutionRepository(Protocol): + def save(self, execution: WorkflowExecution): ... + + +class WorkflowNodeExecutionRepository(Protocol): + def save(self, execution: WorkflowNodeExecution): ... + + def save_execution_data(self, execution: WorkflowNodeExecution): ... + + def get_by_workflow_execution( + self, + workflow_execution_id: str, + order_config: OrderConfig | None = None, + ) -> Sequence[WorkflowNodeExecution]: ... + + class RepositoryImportError(Exception): """Raised when a repository implementation cannot be imported or instantiated.""" diff --git a/api/core/repositories/human_input_repository.py b/api/core/repositories/human_input_repository.py index 8966e4edca..2412fbdb5b 100644 --- a/api/core/repositories/human_input_repository.py +++ b/api/core/repositories/human_input_repository.py @@ -2,33 +2,22 @@ import dataclasses import json from collections.abc import Mapping, Sequence from datetime import datetime -from typing import Any +from typing import Any, Protocol from sqlalchemy import select from sqlalchemy.orm import Session, selectinload from core.db.session_factory import session_factory -from dify_graph.nodes.human_input.entities import ( +from core.workflow.human_input_compat import ( + BoundRecipient, DeliveryChannelConfig, EmailDeliveryMethod, EmailRecipients, ExternalRecipient, - FormDefinition, - HumanInputNodeData, - MemberRecipient, - WebAppDeliveryMethod, -) -from dify_graph.nodes.human_input.enums import ( - DeliveryMethodType, - HumanInputFormKind, - HumanInputFormStatus, -) -from dify_graph.repositories.human_input_form_repository import ( - FormCreateParams, - FormNotFoundError, - HumanInputFormEntity, - HumanInputFormRecipientEntity, + InteractiveSurfaceDeliveryMethod, ) +from dify_graph.nodes.human_input.entities import FormDefinition, HumanInputNodeData +from dify_graph.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus from libs.datetime_utils import naive_utc_now from libs.uuid_utils import uuidv7 from models.account import Account, TenantAccountJoin @@ -36,6 +25,7 @@ from models.human_input import ( BackstageRecipientPayload, ConsoleDeliveryPayload, ConsoleRecipientPayload, + DeliveryMethodType, EmailExternalRecipientPayload, EmailMemberRecipientPayload, HumanInputDelivery, @@ -58,6 +48,65 @@ class _WorkspaceMemberInfo: email: str +class FormNotFoundError(Exception): + pass + + +@dataclasses.dataclass +class FormCreateParams: + workflow_execution_id: str | None + node_id: str + form_config: HumanInputNodeData + rendered_content: str + delivery_methods: Sequence[DeliveryChannelConfig] + display_in_ui: bool + resolved_default_values: Mapping[str, Any] + form_kind: HumanInputFormKind = HumanInputFormKind.RUNTIME + + +class HumanInputFormRecipientEntity(Protocol): + @property + def id(self) -> str: ... + + @property + def token(self) -> str: ... + + +class HumanInputFormEntity(Protocol): + @property + def id(self) -> str: ... + + @property + def submission_token(self) -> str | None: ... + + @property + def recipients(self) -> list[HumanInputFormRecipientEntity]: ... + + @property + def rendered_content(self) -> str: ... + + @property + def selected_action_id(self) -> str | None: ... + + @property + def submitted_data(self) -> Mapping[str, Any] | None: ... + + @property + def submitted(self) -> bool: ... + + @property + def status(self) -> HumanInputFormStatus: ... + + @property + def expiration_time(self) -> datetime: ... + + +class HumanInputFormRepository(Protocol): + def get_form(self, node_id: str) -> HumanInputFormEntity | None: ... + + def create_form(self, params: FormCreateParams) -> HumanInputFormEntity: ... + + class _HumanInputFormRecipientEntityImpl(HumanInputFormRecipientEntity): def __init__(self, recipient_model: HumanInputFormRecipient): self._recipient_model = recipient_model @@ -77,7 +126,7 @@ class _HumanInputFormEntityImpl(HumanInputFormEntity): def __init__(self, form_model: HumanInputForm, recipient_models: Sequence[HumanInputFormRecipient]): self._form_model = form_model self._recipients = [_HumanInputFormRecipientEntityImpl(recipient) for recipient in recipient_models] - self._web_app_recipient = next( + self._interactive_surface_recipient = next( ( recipient for recipient in recipient_models @@ -98,12 +147,12 @@ class _HumanInputFormEntityImpl(HumanInputFormEntity): return self._form_model.id @property - def web_app_token(self): + def submission_token(self) -> str | None: if self._console_recipient is not None: return self._console_recipient.access_token - if self._web_app_recipient is None: + if self._interactive_surface_recipient is None: return None - return self._web_app_recipient.access_token + return self._interactive_surface_recipient.access_token @property def recipients(self) -> list[HumanInputFormRecipientEntity]: @@ -202,9 +251,15 @@ class HumanInputFormRepositoryImpl: *, tenant_id: str, app_id: str | None = None, - ): + workflow_execution_id: str | None = None, + invoke_source: str | None = None, + submission_actor_id: str | None = None, + ) -> None: self._tenant_id = tenant_id self._app_id = app_id + self._workflow_execution_id = workflow_execution_id + self._invoke_source = invoke_source + self._submission_actor_id = submission_actor_id def _delivery_method_to_model( self, @@ -221,7 +276,7 @@ class HumanInputFormRepositoryImpl: channel_payload=delivery_method.model_dump_json(), ) recipients: list[HumanInputFormRecipient] = [] - if isinstance(delivery_method, WebAppDeliveryMethod): + if isinstance(delivery_method, InteractiveSurfaceDeliveryMethod): recipient_model = HumanInputFormRecipient( form_id=form_id, delivery_id=delivery_id, @@ -249,16 +304,16 @@ class HumanInputFormRepositoryImpl: delivery_id: str, recipients_config: EmailRecipients, ) -> list[HumanInputFormRecipient]: - member_user_ids = [ - recipient.user_id for recipient in recipients_config.items if isinstance(recipient, MemberRecipient) + bound_reference_ids = [ + recipient.reference_id for recipient in recipients_config.items if isinstance(recipient, BoundRecipient) ] external_emails = [ recipient.email for recipient in recipients_config.items if isinstance(recipient, ExternalRecipient) ] - if recipients_config.whole_workspace: + if recipients_config.include_bound_group: members = self._query_all_workspace_members(session=session) else: - members = self._query_workspace_members_by_ids(session=session, restrict_to_user_ids=member_user_ids) + members = self._query_workspace_members_by_ids(session=session, restrict_to_user_ids=bound_reference_ids) return self._create_email_recipients_from_resolved( form_id=form_id, @@ -340,11 +395,33 @@ class HumanInputFormRepositoryImpl: rows = session.execute(stmt).all() return [_WorkspaceMemberInfo(user_id=account_id, email=email) for account_id, email in rows] + def _should_create_console_recipient( + self, + *, + form_config: HumanInputNodeData, + form_kind: HumanInputFormKind, + ) -> bool: + if form_kind != HumanInputFormKind.RUNTIME: + return False + if self._invoke_source == "debugger": + return True + if self._invoke_source == "explore": + return form_config.is_webapp_enabled() + return False + + def _should_create_backstage_recipient(self, *, form_kind: HumanInputFormKind) -> bool: + return form_kind == HumanInputFormKind.RUNTIME and ( + self._invoke_source is not None or self._submission_actor_id is not None + ) + def create_form(self, params: FormCreateParams) -> HumanInputFormEntity: form_config: HumanInputNodeData = params.form_config - app_id = params.app_id or self._app_id + app_id = self._app_id if not app_id: raise ValueError("app_id is required to create a human input form") + workflow_execution_id = params.workflow_execution_id or self._workflow_execution_id + if params.form_kind == HumanInputFormKind.RUNTIME and workflow_execution_id is None: + raise ValueError("workflow_execution_id is required for runtime human input forms") with session_factory.create_session() as session, session.begin(): # Generate unique form ID @@ -365,7 +442,7 @@ class HumanInputFormRepositoryImpl: id=form_id, tenant_id=self._tenant_id, app_id=app_id, - workflow_run_id=params.workflow_execution_id, + workflow_run_id=workflow_execution_id, form_kind=params.form_kind, node_id=params.node_id, form_definition=form_definition.model_dump_json(), @@ -384,7 +461,7 @@ class HumanInputFormRepositoryImpl: session.add(delivery_and_recipients.delivery) session.add_all(delivery_and_recipients.recipients) recipient_models.extend(delivery_and_recipients.recipients) - if params.console_recipient_required and not any( + if self._should_create_console_recipient(form_config=form_config, form_kind=params.form_kind) and not any( recipient.recipient_type == RecipientType.CONSOLE for recipient in recipient_models ): console_delivery_id = str(uuidv7()) @@ -400,13 +477,13 @@ class HumanInputFormRepositoryImpl: delivery_id=console_delivery_id, recipient_type=RecipientType.CONSOLE, recipient_payload=ConsoleRecipientPayload( - account_id=params.console_creator_account_id, + account_id=self._submission_actor_id, ).model_dump_json(), ) session.add(console_delivery) session.add(console_recipient) recipient_models.append(console_recipient) - if params.backstage_recipient_required and not any( + if self._should_create_backstage_recipient(form_kind=params.form_kind) and not any( recipient.recipient_type == RecipientType.BACKSTAGE for recipient in recipient_models ): backstage_delivery_id = str(uuidv7()) @@ -422,7 +499,7 @@ class HumanInputFormRepositoryImpl: delivery_id=backstage_delivery_id, recipient_type=RecipientType.BACKSTAGE, recipient_payload=BackstageRecipientPayload( - account_id=params.console_creator_account_id, + account_id=self._submission_actor_id, ).model_dump_json(), ) session.add(backstage_delivery) @@ -432,9 +509,12 @@ class HumanInputFormRepositoryImpl: return _HumanInputFormEntityImpl(form_model=form_model, recipient_models=recipient_models) - def get_form(self, workflow_execution_id: str, node_id: str) -> HumanInputFormEntity | None: + def get_form(self, node_id: str) -> HumanInputFormEntity | None: + if self._workflow_execution_id is None: + raise ValueError("workflow_execution_id is required to load runtime human input forms") + form_query = select(HumanInputForm).where( - HumanInputForm.workflow_run_id == workflow_execution_id, + HumanInputForm.workflow_run_id == self._workflow_execution_id, HumanInputForm.node_id == node_id, HumanInputForm.tenant_id == self._tenant_id, ) diff --git a/api/core/repositories/sqlalchemy_workflow_execution_repository.py b/api/core/repositories/sqlalchemy_workflow_execution_repository.py index 55e96515ac..011a79b07b 100644 --- a/api/core/repositories/sqlalchemy_workflow_execution_repository.py +++ b/api/core/repositories/sqlalchemy_workflow_execution_repository.py @@ -9,9 +9,9 @@ from typing import Union from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker +from core.repositories.factory import WorkflowExecutionRepository from dify_graph.entities import WorkflowExecution from dify_graph.enums import WorkflowExecutionStatus, WorkflowType -from dify_graph.repositories.workflow_execution_repository import WorkflowExecutionRepository from dify_graph.workflow_type_encoder import WorkflowRuntimeTypeConverter from libs.helper import extract_tenant_id from models import ( diff --git a/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py b/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py index 7373ebc7cc..c9d0bd8597 100644 --- a/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py +++ b/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py @@ -17,10 +17,10 @@ from sqlalchemy.orm import sessionmaker from tenacity import before_sleep_log, retry, retry_if_exception, stop_after_attempt from configs import dify_config +from core.repositories.factory import OrderConfig, WorkflowNodeExecutionRepository from dify_graph.entities import WorkflowNodeExecution from dify_graph.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus from dify_graph.model_runtime.utils.encoders import jsonable_encoder -from dify_graph.repositories.workflow_node_execution_repository import OrderConfig, WorkflowNodeExecutionRepository from dify_graph.workflow_type_encoder import WorkflowRuntimeTypeConverter from extensions.ext_storage import storage from libs.helper import extract_tenant_id @@ -518,29 +518,28 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository) return db_models - def get_by_workflow_run( + def get_by_workflow_execution( self, - workflow_run_id: str, + workflow_execution_id: str, order_config: OrderConfig | None = None, triggered_from: WorkflowNodeExecutionTriggeredFrom = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, ) -> Sequence[WorkflowNodeExecution]: """ - Retrieve all NodeExecution instances for a specific workflow run. + Retrieve all node executions for a workflow execution. This method always queries the database to ensure complete and ordered results, but updates the cache with any retrieved executions. Args: - workflow_run_id: The workflow run ID + workflow_execution_id: The workflow execution identifier order_config: Optional configuration for ordering results order_config.order_by: List of fields to order by (e.g., ["index", "created_at"]) order_config.order_direction: Direction to order ("asc" or "desc") Returns: - A list of NodeExecution instances + A list of node execution instances """ - # Get the database models using the new method - db_models = self.get_db_models_by_workflow_run(workflow_run_id, order_config, triggered_from) + db_models = self.get_db_models_by_workflow_run(workflow_execution_id, order_config, triggered_from) with ThreadPoolExecutor(max_workers=10) as executor: domain_models = executor.map(self._to_domain_model, db_models, timeout=30) diff --git a/api/core/tools/signature.py b/api/core/tools/signature.py index 22e099deba..1807226924 100644 --- a/api/core/tools/signature.py +++ b/api/core/tools/signature.py @@ -3,6 +3,7 @@ import hashlib import hmac import os import time +import urllib.parse from configs import dify_config @@ -58,3 +59,43 @@ def verify_tool_file_signature(file_id: str, timestamp: str, nonce: str, sign: s current_time = int(time.time()) return current_time - int(timestamp) <= dify_config.FILES_ACCESS_TIMEOUT + + +def get_signed_file_url_for_plugin(filename: str, mimetype: str, tenant_id: str, user_id: str) -> str: + """Build the signed upload URL used by the plugin-facing file upload endpoint.""" + + base_url = dify_config.INTERNAL_FILES_URL or dify_config.FILES_URL + upload_url = f"{base_url}/files/upload/for-plugin" + timestamp = str(int(time.time())) + nonce = os.urandom(16).hex() + data_to_sign = f"upload|{filename}|{mimetype}|{tenant_id}|{user_id}|{timestamp}|{nonce}" + secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b"" + sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest() + encoded_sign = base64.urlsafe_b64encode(sign).decode() + query = urllib.parse.urlencode( + { + "timestamp": timestamp, + "nonce": nonce, + "sign": encoded_sign, + "user_id": user_id, + "tenant_id": tenant_id, + } + ) + return f"{upload_url}?{query}" + + +def verify_plugin_file_signature( + *, filename: str, mimetype: str, tenant_id: str, user_id: str, timestamp: str, nonce: str, sign: str +) -> bool: + """Verify the signature used by the plugin-facing file upload endpoint.""" + + data_to_sign = f"upload|{filename}|{mimetype}|{tenant_id}|{user_id}|{timestamp}|{nonce}" + secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b"" + recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest() + recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode() + + if sign != recalculated_encoded_sign: + return False + + current_time = int(time.time()) + return current_time - int(timestamp) <= dify_config.FILES_ACCESS_TIMEOUT diff --git a/api/core/tools/tool_file_manager.py b/api/core/tools/tool_file_manager.py index 210f488afc..fef899bb76 100644 --- a/api/core/tools/tool_file_manager.py +++ b/api/core/tools/tool_file_manager.py @@ -14,7 +14,8 @@ import httpx from configs import dify_config from core.db.session_factory import session_factory from core.helper import ssrf_proxy -from dify_graph.file.models import ToolFile as ToolFilePydanticModel +from core.workflow.file_reference import build_file_reference +from dify_graph.file import File, FileTransferMethod, get_file_type_by_mime_type from extensions.ext_storage import storage from models.model import MessageFile from models.tools import ToolFile @@ -23,6 +24,20 @@ logger = logging.getLogger(__name__) class ToolFileManager: + @staticmethod + def _build_graph_file_reference(tool_file: ToolFile) -> File: + extension = guess_extension(tool_file.mimetype) or ".bin" + return File( + type=get_file_type_by_mime_type(tool_file.mimetype), + transfer_method=FileTransferMethod.TOOL_FILE, + remote_url=tool_file.original_url, + reference=build_file_reference(record_id=str(tool_file.id), storage_key=tool_file.file_key), + filename=tool_file.name, + extension=extension, + mime_type=tool_file.mimetype, + size=tool_file.size, + ) + @staticmethod def sign_file(tool_file_id: str, extension: str) -> str: """ @@ -209,9 +224,7 @@ class ToolFileManager: return blob, tool_file.mimetype - def get_file_generator_by_tool_file_id( - self, tool_file_id: str - ) -> tuple[Generator | None, ToolFilePydanticModel | None]: + def get_file_generator_by_tool_file_id(self, tool_file_id: str) -> tuple[Generator | None, File | None]: """ get file binary @@ -233,7 +246,7 @@ class ToolFileManager: stream = storage.load_stream(tool_file.file_key) - return stream, ToolFilePydanticModel.model_validate(tool_file) + return stream, self._build_graph_file_reference(tool_file) # init tool_file_parser diff --git a/api/core/tools/utils/message_transformer.py b/api/core/tools/utils/message_transformer.py index 6fc5fead2d..1a3c491216 100644 --- a/api/core/tools/utils/message_transformer.py +++ b/api/core/tools/utils/message_transformer.py @@ -1,4 +1,5 @@ import logging +import re from collections.abc import Generator from datetime import date, datetime from decimal import Decimal @@ -10,12 +11,15 @@ import pytz from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.tool_file_manager import ToolFileManager +from core.workflow.file_reference import parse_file_reference from dify_graph.file import File, FileTransferMethod, FileType from libs.login import current_user from models import Account logger = logging.getLogger(__name__) +_TOOL_FILE_URL_PATTERN = re.compile(r"(?:^|/+)files/tools/(?P[^/?#.]+)") + def safe_json_value(v): if isinstance(v, datetime): @@ -82,11 +86,15 @@ class ToolFileMessageTransformer: ) url = f"/files/tools/{tool_file.id}{guess_extension(tool_file.mimetype) or '.png'}" + meta = cls._with_tool_file_meta( + message.meta, + tool_file_id=str(tool_file.id), + ) yield ToolInvokeMessage( type=ToolInvokeMessage.MessageType.IMAGE_LINK, message=ToolInvokeMessage.TextMessage(text=url), - meta=message.meta.copy() if message.meta is not None else {}, + meta=meta, ) except Exception as e: yield ToolInvokeMessage( @@ -122,38 +130,45 @@ class ToolFileMessageTransformer: ) url = cls.get_tool_file_url(tool_file_id=tool_file.id, extension=guess_extension(tool_file.mimetype)) + meta = cls._with_tool_file_meta(meta, tool_file_id=str(tool_file.id)) # check if file is image if "image" in mimetype: yield ToolInvokeMessage( type=ToolInvokeMessage.MessageType.IMAGE_LINK, message=ToolInvokeMessage.TextMessage(text=url), - meta=meta.copy() if meta is not None else {}, + meta=meta, ) else: yield ToolInvokeMessage( type=ToolInvokeMessage.MessageType.BINARY_LINK, message=ToolInvokeMessage.TextMessage(text=url), - meta=meta.copy() if meta is not None else {}, + meta=meta, ) elif message.type == ToolInvokeMessage.MessageType.FILE: meta = message.meta or {} file = meta.get("file", None) if isinstance(file, File): if file.transfer_method == FileTransferMethod.TOOL_FILE: - assert file.related_id is not None - url = cls.get_tool_file_url(tool_file_id=file.related_id, extension=file.extension) + parsed_reference = parse_file_reference(file.reference) + if parsed_reference is None: + raise ValueError("tool file is missing reference") + url = cls.get_tool_file_url( + tool_file_id=parsed_reference.record_id, + extension=file.extension, + ) + tool_file_meta = cls._with_tool_file_meta(meta, tool_file_id=parsed_reference.record_id) if file.type == FileType.IMAGE: yield ToolInvokeMessage( type=ToolInvokeMessage.MessageType.IMAGE_LINK, message=ToolInvokeMessage.TextMessage(text=url), - meta=meta.copy() if meta is not None else {}, + meta=tool_file_meta, ) else: yield ToolInvokeMessage( type=ToolInvokeMessage.MessageType.LINK, message=ToolInvokeMessage.TextMessage(text=url), - meta=meta.copy() if meta is not None else {}, + meta=tool_file_meta, ) else: yield message @@ -162,9 +177,40 @@ class ToolFileMessageTransformer: if isinstance(message.message, ToolInvokeMessage.JsonMessage): message.message.json_object = safe_json_value(message.message.json_object) yield message + elif message.type in { + ToolInvokeMessage.MessageType.IMAGE_LINK, + ToolInvokeMessage.MessageType.BINARY_LINK, + } and isinstance(message.message, ToolInvokeMessage.TextMessage): + yield ToolInvokeMessage( + type=message.type, + message=message.message, + meta=cls._with_tool_file_meta(message.meta, url=message.message.text), + ) else: yield message @classmethod def get_tool_file_url(cls, tool_file_id: str, extension: str | None) -> str: return f"/files/tools/{tool_file_id}{extension or '.bin'}" + + @staticmethod + def _with_tool_file_meta( + meta: dict | None, + *, + tool_file_id: str | None = None, + url: str | None = None, + ) -> dict: + normalized_meta = meta.copy() if meta is not None else {} + resolved_tool_file_id = tool_file_id or ToolFileMessageTransformer._extract_tool_file_id(url) + if resolved_tool_file_id and "tool_file_id" not in normalized_meta: + normalized_meta["tool_file_id"] = resolved_tool_file_id + return normalized_meta + + @staticmethod + def _extract_tool_file_id(url: str | None) -> str | None: + if not url: + return None + match = _TOOL_FILE_URL_PATTERN.search(url) + if match is None: + return None + return match.group("tool_file_id") diff --git a/api/core/tools/workflow_as_tool/tool.py b/api/core/tools/workflow_as_tool/tool.py index 9b9aa7a741..c6e4aa3739 100644 --- a/api/core/tools/workflow_as_tool/tool.py +++ b/api/core/tools/workflow_as_tool/tool.py @@ -17,6 +17,7 @@ from core.tools.entities.tool_entities import ( ToolProviderType, ) from core.tools.errors import ToolInvokeError +from core.workflow.file_reference import resolve_file_record_id from dify_graph.file import FILE_MODEL_IDENTITY, File, FileTransferMethod from dify_graph.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata from factories.file_factory import build_from_mapping @@ -288,16 +289,21 @@ class WorkflowTool(Tool): file = tool_parameters.get(parameter.name) if file: try: - file_var_list = [File.model_validate(f) for f in file] + file_var_list = [ + # Workflow-as-tool can receive stored file payloads + # from older runs that still include tenant_id. + File.model_validate({key: item for key, item in f.items() if key != "tenant_id"}) + for f in file + ] for file in file_var_list: file_dict: dict[str, str | None] = { "transfer_method": file.transfer_method.value, "type": file.type.value, } if file.transfer_method == FileTransferMethod.TOOL_FILE: - file_dict["tool_file_id"] = file.related_id + file_dict["tool_file_id"] = resolve_file_record_id(file.reference) elif file.transfer_method == FileTransferMethod.LOCAL_FILE: - file_dict["upload_file_id"] = file.related_id + file_dict["upload_file_id"] = resolve_file_record_id(file.reference) elif file.transfer_method == FileTransferMethod.REMOTE_URL: file_dict["url"] = file.generate_url() @@ -340,9 +346,10 @@ class WorkflowTool(Tool): return result, files def _update_file_mapping(self, file_dict: dict): + file_id = resolve_file_record_id(file_dict.get("reference") or file_dict.get("related_id")) transfer_method = FileTransferMethod.value_of(file_dict.get("transfer_method")) if transfer_method == FileTransferMethod.TOOL_FILE: - file_dict["tool_file_id"] = file_dict.get("related_id") + file_dict["tool_file_id"] = file_id elif transfer_method == FileTransferMethod.LOCAL_FILE: - file_dict["upload_file_id"] = file_dict.get("related_id") + file_dict["upload_file_id"] = file_id return file_dict diff --git a/api/core/workflow/file_reference.py b/api/core/workflow/file_reference.py new file mode 100644 index 0000000000..c80acb3783 --- /dev/null +++ b/api/core/workflow/file_reference.py @@ -0,0 +1,52 @@ +from __future__ import annotations + +import base64 +import json +from dataclasses import dataclass + +_FILE_REFERENCE_PREFIX = "dify-file-ref:" + + +@dataclass(frozen=True) +class FileReference: + record_id: str + storage_key: str | None = None + + +def build_file_reference(*, record_id: str, storage_key: str | None = None) -> str: + payload = {"record_id": record_id} + if storage_key is not None: + payload["storage_key"] = storage_key + encoded_payload = base64.urlsafe_b64encode(json.dumps(payload, separators=(",", ":")).encode()).decode() + return f"{_FILE_REFERENCE_PREFIX}{encoded_payload}" + + +def parse_file_reference(reference: str | None) -> FileReference | None: + if not reference: + return None + + if not reference.startswith(_FILE_REFERENCE_PREFIX): + return FileReference(record_id=reference) + + encoded_payload = reference.removeprefix(_FILE_REFERENCE_PREFIX) + try: + payload = json.loads(base64.urlsafe_b64decode(encoded_payload.encode())) + except (ValueError, json.JSONDecodeError): + return FileReference(record_id=reference) + + record_id = payload.get("record_id") + if not isinstance(record_id, str) or not record_id: + return FileReference(record_id=reference) + + storage_key = payload.get("storage_key") + if storage_key is not None and not isinstance(storage_key, str): + storage_key = None + + return FileReference(record_id=record_id, storage_key=storage_key) + + +def resolve_file_record_id(reference: str | None) -> str | None: + parsed_reference = parse_file_reference(reference) + if parsed_reference is None: + return None + return parsed_reference.record_id diff --git a/api/core/workflow/human_input_compat.py b/api/core/workflow/human_input_compat.py new file mode 100644 index 0000000000..112304e0ad --- /dev/null +++ b/api/core/workflow/human_input_compat.py @@ -0,0 +1,247 @@ +"""Workflow-layer adapters for legacy human-input payload keys. + +Stored workflow graphs and editor payloads may still use Dify-specific human +input recipient keys. Normalize them here before handing configs to +`dify_graph` so graph-owned models only see graph-neutral field names. +""" + +from __future__ import annotations + +import enum +import uuid +from collections.abc import Mapping, Sequence +from typing import Annotated, Any, ClassVar, Literal + +from pydantic import BaseModel, ConfigDict, Field, TypeAdapter + +from dify_graph.enums import BuiltinNodeTypes +from dify_graph.nodes.base.variable_template_parser import VariableTemplateParser +from dify_graph.runtime import VariablePool +from dify_graph.variables.consts import SELECTORS_LENGTH + + +class DeliveryMethodType(enum.StrEnum): + WEBAPP = enum.auto() + EMAIL = enum.auto() + + +class EmailRecipientType(enum.StrEnum): + BOUND = "member" + MEMBER = BOUND + EXTERNAL = "external" + + +class _InteractiveSurfaceDeliveryConfig(BaseModel): + pass + + +class BoundRecipient(BaseModel): + model_config = ConfigDict(extra="forbid") + + type: Literal[EmailRecipientType.BOUND] = EmailRecipientType.BOUND + reference_id: str + + +class ExternalRecipient(BaseModel): + model_config = ConfigDict(extra="forbid") + + type: Literal[EmailRecipientType.EXTERNAL] = EmailRecipientType.EXTERNAL + email: str + + +MemberRecipient = BoundRecipient +EmailRecipient = Annotated[BoundRecipient | ExternalRecipient, Field(discriminator="type")] + + +class EmailRecipients(BaseModel): + model_config = ConfigDict(extra="forbid") + + include_bound_group: bool = False + items: list[EmailRecipient] = Field(default_factory=list) + + +class EmailDeliveryConfig(BaseModel): + URL_PLACEHOLDER: ClassVar[str] = "{{#url#}}" + + recipients: EmailRecipients + subject: str + body: str + debug_mode: bool = False + + def with_recipients(self, recipients: EmailRecipients) -> EmailDeliveryConfig: + return self.model_copy(update={"recipients": recipients}) + + @classmethod + def replace_url_placeholder(cls, body: str, url: str | None) -> str: + return body.replace(cls.URL_PLACEHOLDER, url or "") + + @classmethod + def render_body_template( + cls, + *, + body: str, + url: str | None, + variable_pool: VariablePool | None = None, + ) -> str: + templated_body = cls.replace_url_placeholder(body, url) + if variable_pool is None: + return templated_body + return variable_pool.convert_template(templated_body).text + + +class _DeliveryMethodBase(BaseModel): + enabled: bool = True + id: uuid.UUID = Field(default_factory=uuid.uuid4) + + def extract_variable_selectors(self) -> Sequence[Sequence[str]]: + return () + + +class InteractiveSurfaceDeliveryMethod(_DeliveryMethodBase): + type: Literal[DeliveryMethodType.WEBAPP] = DeliveryMethodType.WEBAPP + config: _InteractiveSurfaceDeliveryConfig = Field(default_factory=_InteractiveSurfaceDeliveryConfig) + + +class EmailDeliveryMethod(_DeliveryMethodBase): + type: Literal[DeliveryMethodType.EMAIL] = DeliveryMethodType.EMAIL + config: EmailDeliveryConfig + + def extract_variable_selectors(self) -> Sequence[Sequence[str]]: + variable_template_parser = VariableTemplateParser(template=self.config.body) + selectors: list[Sequence[str]] = [] + for variable_selector in variable_template_parser.extract_variable_selectors(): + value_selector = list(variable_selector.value_selector) + if len(value_selector) < SELECTORS_LENGTH: + continue + selectors.append(value_selector[:SELECTORS_LENGTH]) + return selectors + + +WebAppDeliveryMethod = InteractiveSurfaceDeliveryMethod +_WebAppDeliveryConfig = _InteractiveSurfaceDeliveryConfig + +DeliveryChannelConfig = Annotated[InteractiveSurfaceDeliveryMethod | EmailDeliveryMethod, Field(discriminator="type")] + +_DELIVERY_METHODS_ADAPTER = TypeAdapter(list[DeliveryChannelConfig]) + + +def _copy_mapping(value: object) -> dict[str, Any] | None: + if isinstance(value, BaseModel): + return value.model_dump(mode="python") + if isinstance(value, Mapping): + return dict(value) + return None + + +def normalize_human_input_node_data_for_graph(node_data: Mapping[str, Any] | BaseModel) -> dict[str, Any]: + normalized = _copy_mapping(node_data) + if normalized is None: + raise TypeError(f"human-input node data must be a mapping, got {type(node_data).__name__}") + + delivery_methods = normalized.get("delivery_methods") + if not isinstance(delivery_methods, list): + return normalized + + normalized_methods: list[Any] = [] + for method in delivery_methods: + method_mapping = _copy_mapping(method) + if method_mapping is None: + normalized_methods.append(method) + continue + + config_mapping = _copy_mapping(method_mapping.get("config")) + if config_mapping is not None: + recipients_mapping = _copy_mapping(config_mapping.get("recipients")) + if recipients_mapping is not None: + config_mapping["recipients"] = _normalize_email_recipients(recipients_mapping) + method_mapping["config"] = config_mapping + + normalized_methods.append(method_mapping) + + normalized["delivery_methods"] = normalized_methods + return normalized + + +def parse_human_input_delivery_methods(node_data: Mapping[str, Any] | BaseModel) -> list[DeliveryChannelConfig]: + normalized = normalize_human_input_node_data_for_graph(node_data) + raw_delivery_methods = normalized.get("delivery_methods") + if not isinstance(raw_delivery_methods, list): + return [] + return list(_DELIVERY_METHODS_ADAPTER.validate_python(raw_delivery_methods)) + + +def is_human_input_webapp_enabled(node_data: Mapping[str, Any] | BaseModel) -> bool: + for method in parse_human_input_delivery_methods(node_data): + if method.enabled and method.type == DeliveryMethodType.WEBAPP: + return True + return False + + +def normalize_node_data_for_graph(node_data: Mapping[str, Any] | BaseModel) -> dict[str, Any]: + normalized = _copy_mapping(node_data) + if normalized is None: + raise TypeError(f"node data must be a mapping, got {type(node_data).__name__}") + + if normalized.get("type") != BuiltinNodeTypes.HUMAN_INPUT: + return normalized + return normalize_human_input_node_data_for_graph(normalized) + + +def normalize_node_config_for_graph(node_config: Mapping[str, Any] | BaseModel) -> dict[str, Any]: + normalized = _copy_mapping(node_config) + if normalized is None: + raise TypeError(f"node config must be a mapping, got {type(node_config).__name__}") + + data_mapping = _copy_mapping(normalized.get("data")) + if data_mapping is None: + return normalized + + normalized["data"] = normalize_node_data_for_graph(data_mapping) + return normalized + + +def _normalize_email_recipients(recipients: Mapping[str, Any]) -> dict[str, Any]: + normalized = dict(recipients) + + legacy_include_bound_group = normalized.pop("whole_workspace", None) + if "include_bound_group" not in normalized and legacy_include_bound_group is not None: + normalized["include_bound_group"] = legacy_include_bound_group + + items = normalized.get("items") + if not isinstance(items, list): + return normalized + + normalized_items: list[Any] = [] + for item in items: + item_mapping = _copy_mapping(item) + if item_mapping is None: + normalized_items.append(item) + continue + + legacy_reference_id = item_mapping.pop("user_id", None) + if "reference_id" not in item_mapping and legacy_reference_id is not None: + item_mapping["reference_id"] = legacy_reference_id + normalized_items.append(item_mapping) + + normalized["items"] = normalized_items + return normalized + + +__all__ = [ + "BoundRecipient", + "DeliveryChannelConfig", + "DeliveryMethodType", + "EmailDeliveryConfig", + "EmailDeliveryMethod", + "EmailRecipientType", + "EmailRecipients", + "ExternalRecipient", + "MemberRecipient", + "WebAppDeliveryMethod", + "_WebAppDeliveryConfig", + "is_human_input_webapp_enabled", + "normalize_human_input_node_data_for_graph", + "normalize_node_config_for_graph", + "normalize_node_data_for_graph", + "parse_human_input_delivery_methods", +] diff --git a/api/core/workflow/node_factory.py b/api/core/workflow/node_factory.py index 2afe9aebd6..608fa157ed 100644 --- a/api/core/workflow/node_factory.py +++ b/api/core/workflow/node_factory.py @@ -19,8 +19,8 @@ from core.helper.ssrf_proxy import ssrf_proxy from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance from core.prompt.entities.advanced_prompt_entities import MemoryConfig -from core.repositories.human_input_repository import HumanInputFormRepositoryImpl from core.trigger.constants import TRIGGER_NODE_TYPES +from core.workflow.human_input_compat import normalize_node_config_for_graph from core.workflow.node_runtime import ( DifyFileReferenceFactory, DifyHumanInputNodeRuntime, @@ -37,9 +37,11 @@ from core.workflow.nodes.agent.plugin_strategy_adapter import ( PluginAgentStrategyResolver, ) from core.workflow.nodes.agent.runtime_support import AgentRuntimeSupport +from core.workflow.system_variables import SystemVariableKey, get_system_text, system_variable_selector +from core.workflow.template_rendering import CodeExecutorJinja2TemplateRenderer from dify_graph.entities.base_node_data import BaseNodeData from dify_graph.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter -from dify_graph.enums import BuiltinNodeTypes, NodeType, SystemVariableKey +from dify_graph.enums import BuiltinNodeTypes, NodeType from dify_graph.file.file_manager import file_manager from dify_graph.graph.graph import NodeFactory from dify_graph.model_runtime.memory import PromptMessageMemory @@ -51,11 +53,8 @@ from dify_graph.nodes.code.limits import CodeNodeLimits from dify_graph.nodes.document_extractor import UnstructuredApiConfig from dify_graph.nodes.http_request import build_http_request_config from dify_graph.nodes.llm.entities import LLMNodeData -from dify_graph.nodes.llm.protocols import TemplateRenderer from dify_graph.nodes.parameter_extractor.entities import ParameterExtractorNodeData from dify_graph.nodes.question_classifier.entities import QuestionClassifierNodeData -from dify_graph.template_rendering import CodeExecutorJinja2TemplateRenderer -from dify_graph.variables.segments import StringSegment from extensions.ext_database import db from models.model import Conversation @@ -233,16 +232,6 @@ class DefaultWorkflowCodeExecutor: return isinstance(error, CodeExecutionError) -class DefaultLLMTemplateRenderer(TemplateRenderer): - def render_jinja2(self, *, template: str, inputs: Mapping[str, Any]) -> str: - result = CodeExecutor.execute_workflow_code_template( - language=CodeLanguage.JINJA2, - code=template, - inputs=inputs, - ) - return str(result.get("result", "")) - - @final class DifyNodeFactory(NodeFactory): """ @@ -268,11 +257,13 @@ class DifyNodeFactory(NodeFactory): max_string_array_length=dify_config.CODE_MAX_STRING_ARRAY_LENGTH, max_object_array_length=dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH, ) - self._jinja2_template_renderer = CodeExecutorJinja2TemplateRenderer(code_executor=self._code_executor) - self._llm_template_renderer: TemplateRenderer = DefaultLLMTemplateRenderer() + self._jinja2_template_renderer = CodeExecutorJinja2TemplateRenderer() self._template_transform_max_output_length = dify_config.TEMPLATE_TRANSFORM_MAX_LENGTH self._http_request_http_client = ssrf_proxy - self._bound_tool_file_manager_factory = lambda: DifyToolFileManager(self._dify_context) + self._bound_tool_file_manager_factory = lambda: DifyToolFileManager( + self._dify_context, + conversation_id_getter=self._conversation_id, + ) self._file_reference_factory = DifyFileReferenceFactory(self._dify_context) self._prompt_message_serializer = DifyPromptMessageSerializer() self._retriever_attachment_loader = DifyRetrieverAttachmentLoader( @@ -281,8 +272,15 @@ class DifyNodeFactory(NodeFactory): self._llm_file_saver = build_dify_llm_file_saver( run_context=self._dify_context, http_client=self._http_request_http_client, + conversation_id_getter=self._conversation_id, + ) + self._human_input_runtime = DifyHumanInputNodeRuntime( + self._dify_context, + workflow_execution_id_getter=lambda: get_system_text( + self.graph_runtime_state.variable_pool, + SystemVariableKey.WORKFLOW_EXECUTION_ID, + ), ) - self._human_input_runtime = DifyHumanInputNodeRuntime(self._dify_context) self._tool_runtime = DifyToolNodeRuntime(self._dify_context) self._http_request_file_manager = file_manager self._document_extractor_unstructured_api_config = UnstructuredApiConfig( @@ -314,6 +312,15 @@ class DifyNodeFactory(NodeFactory): return raw_ctx return DifyRunContext.model_validate(raw_ctx) + def _dify_invoke_source(self) -> str: + invoke_from = self._dify_context.invoke_from + if isinstance(invoke_from, str): + return invoke_from + return str(invoke_from.value) + + def _conversation_id(self) -> str | None: + return get_system_text(self.graph_runtime_state.variable_pool, SystemVariableKey.CONVERSATION_ID) + @override def create_node(self, node_config: dict[str, Any] | NodeConfigDict) -> Node: """ @@ -325,7 +332,7 @@ class DifyNodeFactory(NodeFactory): (including pydantic ValidationError, which subclasses ValueError), if node type is unknown, or if no implementation exists for the resolved version """ - typed_node_config = NodeConfigDictAdapter.validate_python(node_config) + typed_node_config = NodeConfigDictAdapter.validate_python(normalize_node_config_for_graph(node_config)) node_id = typed_node_config["id"] node_data = typed_node_config["data"] node_class = self._resolve_node_class(node_type=node_data.type, node_version=str(node_data.version)) @@ -347,10 +354,6 @@ class DifyNodeFactory(NodeFactory): "file_reference_factory": self._file_reference_factory, }, BuiltinNodeTypes.HUMAN_INPUT: lambda: { - "form_repository": HumanInputFormRepositoryImpl( - tenant_id=self._dify_context.tenant_id, - app_id=self._dify_context.app_id, - ), "runtime": self._human_input_runtime, }, BuiltinNodeTypes.LLM: lambda: self._build_llm_compatible_node_init_kwargs( @@ -445,7 +448,7 @@ class DifyNodeFactory(NodeFactory): ), } if validated_node_data.type == BuiltinNodeTypes.QUESTION_CLASSIFIER: - node_init_kwargs["template_renderer"] = self._llm_template_renderer + node_init_kwargs["template_renderer"] = self._jinja2_template_renderer if include_http_client: node_init_kwargs["http_client"] = self._http_request_http_client if include_llm_file_saver: @@ -456,6 +459,8 @@ class DifyNodeFactory(NodeFactory): node_init_kwargs["retriever_attachment_loader"] = self._retriever_attachment_loader if include_jinja2_template_renderer: node_init_kwargs["jinja2_template_renderer"] = self._jinja2_template_renderer + if validated_node_data.type == BuiltinNodeTypes.LLM: + node_init_kwargs["default_query_selector"] = system_variable_selector(SystemVariableKey.QUERY) return node_init_kwargs def _build_model_instance_for_llm_node(self, node_data: LLMCompatibleNodeData) -> ModelInstance: @@ -477,12 +482,7 @@ class DifyNodeFactory(NodeFactory): if node_data.memory is None: return None - conversation_id_variable = self.graph_runtime_state.variable_pool.get( - ["sys", SystemVariableKey.CONVERSATION_ID] - ) - conversation_id = ( - conversation_id_variable.value if isinstance(conversation_id_variable, StringSegment) else None - ) + conversation_id = get_system_text(self.graph_runtime_state.variable_pool, SystemVariableKey.CONVERSATION_ID) return fetch_memory( conversation_id=conversation_id, app_id=self._dify_context.app_id, diff --git a/api/core/workflow/node_runtime.py b/api/core/workflow/node_runtime.py index a07830f7e7..94a431a9a3 100644 --- a/api/core/workflow/node_runtime.py +++ b/api/core/workflow/node_runtime.py @@ -1,6 +1,6 @@ from __future__ import annotations -from collections.abc import Generator, Mapping, Sequence +from collections.abc import Callable, Generator, Mapping, Sequence from dataclasses import dataclass from typing import TYPE_CHECKING, Any, cast @@ -15,13 +15,14 @@ from core.model_manager import ModelInstance from core.plugin.impl.exc import PluginDaemonClientSideError, PluginInvokeError from core.plugin.impl.plugin import PluginInstaller from core.prompt.utils.prompt_message_util import PromptMessageUtil +from core.repositories.human_input_repository import FormCreateParams, HumanInputFormRepositoryImpl from core.tools.entities.tool_entities import ToolProviderType as CoreToolProviderType from core.tools.errors import ToolInvokeError -from core.tools.signature import sign_upload_file from core.tools.tool_engine import ToolEngine from core.tools.tool_file_manager import ToolFileManager from core.tools.tool_manager import ToolManager from core.tools.utils.message_transformer import ToolFileMessageTransformer +from core.workflow.file_reference import build_file_reference from dify_graph.file import FileTransferMethod, FileType from dify_graph.model_runtime.entities import LLMMode from dify_graph.model_runtime.entities.llm_entities import ( @@ -34,7 +35,7 @@ from dify_graph.model_runtime.entities.llm_entities import ( from dify_graph.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool from dify_graph.model_runtime.entities.model_entities import AIModelEntity from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from dify_graph.nodes.human_input.entities import DeliveryChannelConfig, apply_debug_email_recipient +from dify_graph.nodes.human_input.entities import HumanInputNodeData from dify_graph.nodes.llm.runtime_protocols import ( PreparedLLMProtocol, PromptMessageSerializerProtocol, @@ -57,6 +58,17 @@ from models.dataset import SegmentAttachmentBinding from models.model import UploadFile from services.tools.builtin_tools_manage_service import BuiltinToolManageService +from .human_input_compat import ( + BoundRecipient, + DeliveryChannelConfig, + DeliveryMethodType, + EmailDeliveryMethod, + EmailRecipients, + is_human_input_webapp_enabled, + parse_human_input_delivery_methods, +) +from .system_variables import SystemVariableKey, get_system_text + if TYPE_CHECKING: from core.tools.__base.tool import Tool from core.tools.entities.tool_entities import ToolInvokeMessage as CoreToolInvokeMessage @@ -77,6 +89,31 @@ def resolve_dify_run_context(run_context: Mapping[str, Any] | DifyRunContext) -> return DifyRunContext.model_validate(raw_ctx) +def apply_dify_debug_email_recipient( + method: DeliveryChannelConfig, + *, + enabled: bool, + actor_id: str | None, +) -> DeliveryChannelConfig: + """Apply the Dify debugger-specific email recipient override outside `dify_graph`.""" + if not enabled: + return method + if not isinstance(method, EmailDeliveryMethod): + return method + if not method.config.debug_mode: + return method + + if actor_id is None: + debug_recipients = EmailRecipients(include_bound_group=False, items=[]) + else: + debug_recipients = EmailRecipients( + include_bound_group=False, + items=[BoundRecipient(reference_id=actor_id)], + ) + debug_config = method.config.with_recipients(debug_recipients) + return method.model_copy(update={"config": debug_config}) + + class DifyFileReferenceFactory(FileReferenceFactoryProtocol): def __init__(self, run_context: Mapping[str, Any] | DifyRunContext) -> None: self._run_context = resolve_dify_run_context(run_context) @@ -200,11 +237,8 @@ class DifyRetrieverAttachmentLoader(RetrieverAttachmentLoaderProtocol): "type": FileType.IMAGE, "transfer_method": FileTransferMethod.LOCAL_FILE, "remote_url": upload_file.source_url, - "related_id": upload_file.id, - "upload_file_id": upload_file.id, + "reference": build_file_reference(record_id=str(upload_file.id), storage_key=upload_file.key), "size": upload_file.size, - "storage_key": upload_file.key, - "url": sign_upload_file(upload_file.id, upload_file.extension), } ) for _, upload_file in attachments_with_bindings @@ -212,18 +246,28 @@ class DifyRetrieverAttachmentLoader(RetrieverAttachmentLoaderProtocol): class DifyToolFileManager(ToolFileManagerProtocol): - def __init__(self, run_context: Mapping[str, Any] | DifyRunContext) -> None: + """Workflow adapter that resolves conversation scope outside `dify_graph`.""" + + _conversation_id_getter: Callable[[], str | None] | None + + def __init__( + self, + run_context: Mapping[str, Any] | DifyRunContext, + *, + conversation_id_getter: Callable[[], str | None] | None = None, + ) -> None: self._run_context = resolve_dify_run_context(run_context) self._manager = ToolFileManager() + self._conversation_id_getter = conversation_id_getter def create_file_by_raw( self, *, - conversation_id: str | None, file_binary: bytes, mimetype: str, filename: str | None = None, ) -> Any: + conversation_id = self._conversation_id_getter() if self._conversation_id_getter is not None else None return self._manager.create_file_by_raw( user_id=self._run_context.user_id, tenant_id=self._run_context.tenant_id, @@ -246,6 +290,18 @@ class _WorkflowToolRuntimeSpec: credential_id: str | None = None +@dataclass(frozen=True, slots=True) +class _WorkflowToolRuntimeBinding: + """Workflow-private runtime state stored inside the opaque graph handle. + + The binding keeps conversation scope in `core.workflow` while `dify_graph` + continues to treat the handle as an opaque token. + """ + + tool: Tool + conversation_id: str | None = None + + class DifyToolNodeRuntime(ToolNodeRuntimeProtocol): def __init__(self, run_context: Mapping[str, Any] | DifyRunContext) -> None: self._run_context = resolve_dify_run_context(run_context) @@ -279,7 +335,10 @@ class DifyToolNodeRuntime(ToolNodeRuntimeProtocol): except Exception as exc: raise ToolRuntimeResolutionError(str(exc)) from exc - return ToolRuntimeHandle(raw=tool_runtime) + conversation_id = ( + None if variable_pool is None else get_system_text(variable_pool, SystemVariableKey.CONVERSATION_ID) + ) + return ToolRuntimeHandle(raw=_WorkflowToolRuntimeBinding(tool=tool_runtime, conversation_id=conversation_id)) def get_runtime_parameters( self, @@ -298,10 +357,10 @@ class DifyToolNodeRuntime(ToolNodeRuntimeProtocol): tool_runtime: ToolRuntimeHandle, tool_parameters: Mapping[str, Any], workflow_call_depth: int, - conversation_id: str | None, provider_name: str, ) -> Generator[ToolRuntimeMessage, None, None]: - tool = self._tool_from_handle(tool_runtime) + runtime_binding = self._binding_from_handle(tool_runtime) + tool = cast("Tool", runtime_binding.tool) callback = DifyWorkflowCallbackHandler() try: @@ -312,7 +371,7 @@ class DifyToolNodeRuntime(ToolNodeRuntimeProtocol): workflow_tool_callback=callback, workflow_call_depth=workflow_call_depth, app_id=self._run_context.app_id, - conversation_id=conversation_id, + conversation_id=runtime_binding.conversation_id, ) except Exception as exc: raise self._map_invocation_exception(exc, provider_name=provider_name) from exc @@ -321,7 +380,7 @@ class DifyToolNodeRuntime(ToolNodeRuntimeProtocol): messages=messages, user_id=self._run_context.user_id, tenant_id=self._run_context.tenant_id, - conversation_id=None, + conversation_id=runtime_binding.conversation_id, ) return self._adapt_messages(transformed_messages, provider_name=provider_name) @@ -331,7 +390,7 @@ class DifyToolNodeRuntime(ToolNodeRuntimeProtocol): *, tool_runtime: ToolRuntimeHandle, ) -> LLMUsage: - latest = getattr(tool_runtime.raw, "latest_usage", None) + latest = getattr(self._binding_from_handle(tool_runtime).tool, "latest_usage", None) if isinstance(latest, LLMUsage): return latest if isinstance(latest, dict): @@ -373,7 +432,13 @@ class DifyToolNodeRuntime(ToolNodeRuntimeProtocol): @staticmethod def _tool_from_handle(tool_runtime: ToolRuntimeHandle) -> Tool: - return cast("Tool", tool_runtime.raw) + return cast("Tool", DifyToolNodeRuntime._binding_from_handle(tool_runtime).tool) + + @staticmethod + def _binding_from_handle(tool_runtime: ToolRuntimeHandle) -> _WorkflowToolRuntimeBinding: + if isinstance(tool_runtime.raw, _WorkflowToolRuntimeBinding): + return tool_runtime.raw + return _WorkflowToolRuntimeBinding(tool=cast("Tool", tool_runtime.raw)) @staticmethod def _build_tool_runtime_spec(node_data: ToolNodeData) -> _WorkflowToolRuntimeSpec: @@ -491,42 +556,85 @@ class DifyToolNodeRuntime(ToolNodeRuntimeProtocol): class DifyHumanInputNodeRuntime(HumanInputNodeRuntimeProtocol): - def __init__(self, run_context: Mapping[str, Any] | DifyRunContext) -> None: + def __init__( + self, + run_context: Mapping[str, Any] | DifyRunContext, + *, + workflow_execution_id_getter: Callable[[], str | None] | None = None, + ) -> None: self._run_context = resolve_dify_run_context(run_context) + self._workflow_execution_id_getter = workflow_execution_id_getter - def invoke_source(self) -> str: + def _invoke_source(self) -> str: invoke_from = self._run_context.invoke_from if isinstance(invoke_from, str): return invoke_from return str(getattr(invoke_from, "value", invoke_from)) - def apply_delivery_runtime( - self, - *, - methods: Sequence[DeliveryChannelConfig], - ) -> Sequence[DeliveryChannelConfig]: + def _resolve_delivery_methods(self, *, node_data: HumanInputNodeData) -> Sequence[DeliveryChannelConfig]: + invoke_source = self._invoke_source() + methods = [method for method in parse_human_input_delivery_methods(node_data) if method.enabled] + if invoke_source in {"debugger", "explore"}: + methods = [method for method in methods if method.type != DeliveryMethodType.WEBAPP] return [ - apply_debug_email_recipient( + apply_dify_debug_email_recipient( method, - enabled=self.invoke_source() == "debugger", - user_id=self._run_context.user_id, + enabled=invoke_source == "debugger", + actor_id=self._run_context.user_id, ) for method in methods ] - def console_actor_id(self) -> str | None: - return self._run_context.user_id + def _display_in_ui(self, *, node_data: HumanInputNodeData) -> bool: + if self._invoke_source() == "debugger": + return True + return is_human_input_webapp_enabled(node_data) + + def _build_form_repository(self) -> HumanInputFormRepositoryImpl: + invoke_source = self._invoke_source() + return HumanInputFormRepositoryImpl( + tenant_id=self._run_context.tenant_id, + app_id=self._run_context.app_id, + workflow_execution_id=self._workflow_execution_id_getter() if self._workflow_execution_id_getter else None, + invoke_source=invoke_source, + submission_actor_id=self._run_context.user_id if invoke_source in {"debugger", "explore"} else None, + ) + + def get_form(self, *, node_id: str): + repo = self._build_form_repository() + return repo.get_form(node_id) + + def create_form( + self, + *, + node_id: str, + node_data: HumanInputNodeData, + rendered_content: str, + resolved_default_values: Mapping[str, Any], + ): + repo = self._build_form_repository() + params = FormCreateParams( + workflow_execution_id=self._workflow_execution_id_getter() if self._workflow_execution_id_getter else None, + node_id=node_id, + form_config=node_data, + rendered_content=rendered_content, + delivery_methods=self._resolve_delivery_methods(node_data=node_data), + display_in_ui=self._display_in_ui(node_data=node_data), + resolved_default_values=resolved_default_values, + ) + return repo.create_form(params) def build_dify_llm_file_saver( *, run_context: Mapping[str, Any] | DifyRunContext, http_client: HttpClientProtocol, + conversation_id_getter: Callable[[], str | None] | None = None, ) -> LLMFileSaver: from dify_graph.nodes.llm.file_saver import FileSaverImpl return FileSaverImpl( - tool_file_manager=DifyToolFileManager(run_context), + tool_file_manager=DifyToolFileManager(run_context, conversation_id_getter=conversation_id_getter), file_reference_factory=DifyFileReferenceFactory(run_context), http_client=http_client, ) diff --git a/api/core/workflow/nodes/agent/agent_node.py b/api/core/workflow/nodes/agent/agent_node.py index 202f6d9921..9fd1365b39 100644 --- a/api/core/workflow/nodes/agent/agent_node.py +++ b/api/core/workflow/nodes/agent/agent_node.py @@ -4,8 +4,9 @@ from collections.abc import Generator, Mapping, Sequence from typing import TYPE_CHECKING, Any from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext +from core.workflow.system_variables import SystemVariableKey, get_system_text from dify_graph.entities.graph_config import NodeConfigDict -from dify_graph.enums import BuiltinNodeTypes, SystemVariableKey, WorkflowNodeExecutionStatus +from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus from dify_graph.node_events import NodeEventBase, NodeRunResult, StreamCompletedEvent from dify_graph.nodes.base.node import Node from dify_graph.nodes.base.variable_template_parser import VariableTemplateParser @@ -115,14 +116,14 @@ class AgentNode(Node[AgentNodeData]): ) credentials = self._runtime_support.build_credentials(parameters=parameters) - conversation_id = self.graph_runtime_state.variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID]) + conversation_id = get_system_text(self.graph_runtime_state.variable_pool, SystemVariableKey.CONVERSATION_ID) try: message_stream = strategy.invoke( params=parameters, user_id=dify_ctx.user_id, app_id=dify_ctx.app_id, - conversation_id=conversation_id.text if conversation_id else None, + conversation_id=conversation_id, credentials=credentials, ) except Exception as e: @@ -149,6 +150,7 @@ class AgentNode(Node[AgentNodeData]): parameters_for_log=parameters_for_log, user_id=dify_ctx.user_id, tenant_id=dify_ctx.tenant_id, + conversation_id=conversation_id, node_type=self.node_type, node_id=self._node_id, node_execution_id=self.id, diff --git a/api/core/workflow/nodes/agent/message_transformer.py b/api/core/workflow/nodes/agent/message_transformer.py index 08d69fc2bb..3dbd419309 100644 --- a/api/core/workflow/nodes/agent/message_transformer.py +++ b/api/core/workflow/nodes/agent/message_transformer.py @@ -37,6 +37,7 @@ class AgentMessageTransformer: parameters_for_log: dict[str, Any], user_id: str, tenant_id: str, + conversation_id: str | None, node_type: NodeType, node_id: str, node_execution_id: str, @@ -47,7 +48,7 @@ class AgentMessageTransformer: messages=messages, user_id=user_id, tenant_id=tenant_id, - conversation_id=None, + conversation_id=conversation_id, ) text = "" @@ -70,10 +71,12 @@ class AgentMessageTransformer: url = message.message.text if message.meta: transfer_method = message.meta.get("transfer_method", FileTransferMethod.TOOL_FILE) + tool_file_id = message.meta.get("tool_file_id") else: transfer_method = FileTransferMethod.TOOL_FILE - - tool_file_id = str(url).split("/")[-1].split(".")[0] + tool_file_id = None + if not isinstance(tool_file_id, str) or not tool_file_id: + raise ToolFileNotFoundError("missing tool_file_id metadata") with Session(db.engine) as session: stmt = select(ToolFile).where(ToolFile.id == tool_file_id) @@ -96,7 +99,9 @@ class AgentMessageTransformer: assert isinstance(message.message, ToolInvokeMessage.TextMessage) assert message.meta - tool_file_id = message.message.text.split("/")[-1].split(".")[0] + tool_file_id = message.meta.get("tool_file_id") + if not isinstance(tool_file_id, str) or not tool_file_id: + raise ToolFileNotFoundError("missing tool_file_id metadata") with Session(db.engine) as session: stmt = select(ToolFile).where(ToolFile.id == tool_file_id) tool_file = session.scalar(stmt) diff --git a/api/core/workflow/nodes/agent/runtime_support.py b/api/core/workflow/nodes/agent/runtime_support.py index 3c8f9bb2cd..06eeb32da1 100644 --- a/api/core/workflow/nodes/agent/runtime_support.py +++ b/api/core/workflow/nodes/agent/runtime_support.py @@ -17,10 +17,9 @@ from core.plugin.entities.request import InvokeCredentials from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager from core.tools.entities.tool_entities import ToolIdentity, ToolParameter, ToolProviderType from core.tools.tool_manager import ToolManager -from dify_graph.enums import SystemVariableKey +from core.workflow.system_variables import SystemVariableKey, get_system_text from dify_graph.model_runtime.entities.model_entities import AIModelEntity, ModelType from dify_graph.runtime import VariablePool -from dify_graph.variables.segments import StringSegment from extensions.ext_database import db from models.model import Conversation @@ -224,10 +223,9 @@ class AgentRuntimeSupport: app_id: str, model_instance: ModelInstance, ) -> TokenBufferMemory | None: - conversation_id_variable = variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID]) - if not isinstance(conversation_id_variable, StringSegment): + conversation_id = get_system_text(variable_pool, SystemVariableKey.CONVERSATION_ID) + if conversation_id is None: return None - conversation_id = conversation_id_variable.value with Session(db.engine, expire_on_commit=False) as session: stmt = select(Conversation).where(Conversation.app_id == app_id, Conversation.id == conversation_id) diff --git a/api/core/workflow/nodes/datasource/datasource_node.py b/api/core/workflow/nodes/datasource/datasource_node.py index 9a80dd9af6..cca2cb14b5 100644 --- a/api/core/workflow/nodes/datasource/datasource_node.py +++ b/api/core/workflow/nodes/datasource/datasource_node.py @@ -5,9 +5,11 @@ from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunC from core.datasource.datasource_manager import DatasourceManager from core.datasource.entities.datasource_entities import DatasourceProviderType from core.plugin.impl.exc import PluginDaemonClientSideError +from core.workflow.file_reference import resolve_file_record_id +from core.workflow.system_variables import SystemVariableKey, get_system_segment from dify_graph.entities.graph_config import NodeConfigDict from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus -from dify_graph.enums import BuiltinNodeTypes, NodeExecutionType, SystemVariableKey, WorkflowNodeExecutionMetadataKey +from dify_graph.enums import BuiltinNodeTypes, NodeExecutionType, WorkflowNodeExecutionMetadataKey from dify_graph.node_events import NodeRunResult, StreamCompletedEvent from dify_graph.nodes.base.node import Node from dify_graph.nodes.base.variable_template_parser import VariableTemplateParser @@ -54,11 +56,11 @@ class DatasourceNode(Node[DatasourceNodeData]): dify_ctx = DifyRunContext.model_validate(self.require_run_context_value(DIFY_RUN_CONTEXT_KEY)) node_data = self.node_data variable_pool = self.graph_runtime_state.variable_pool - datasource_type_segment = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_TYPE]) + datasource_type_segment = get_system_segment(variable_pool, SystemVariableKey.DATASOURCE_TYPE) if not datasource_type_segment: raise DatasourceNodeError("Datasource type is not set") datasource_type = str(datasource_type_segment.value) if datasource_type_segment.value else None - datasource_info_segment = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_INFO]) + datasource_info_segment = get_system_segment(variable_pool, SystemVariableKey.DATASOURCE_INFO) if not datasource_info_segment: raise DatasourceNodeError("Datasource info is not set") datasource_info_value = datasource_info_segment.value @@ -131,12 +133,14 @@ class DatasourceNode(Node[DatasourceNodeData]): ) ) case DatasourceProviderType.LOCAL_FILE: - related_id = datasource_info.get("related_id") - if not related_id: + file_id = resolve_file_record_id( + datasource_info.get("reference") or datasource_info.get("related_id") + ) + if not file_id: raise DatasourceNodeError("File is not exist") file_info = self.datasource_manager.get_upload_file_by_id( - file_id=related_id, tenant_id=dify_ctx.tenant_id + file_id=file_id, tenant_id=dify_ctx.tenant_id ) variable_pool.add([self._node_id, "file"], file_info) # variable_pool.add([self.node_id, "file"], file_info.to_dict()) diff --git a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py index 4ea9091c5b..f51abfdd8c 100644 --- a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py +++ b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py @@ -6,9 +6,10 @@ from core.rag.index_processor.index_processor import IndexProcessor from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict from core.rag.summary_index.summary_index import SummaryIndex from core.workflow.nodes.knowledge_index import KNOWLEDGE_INDEX_NODE_TYPE +from core.workflow.system_variables import SystemVariableKey, get_system_segment, get_system_text from dify_graph.entities.graph_config import NodeConfigDict from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus -from dify_graph.enums import NodeExecutionType, SystemVariableKey +from dify_graph.enums import NodeExecutionType from dify_graph.node_events import NodeRunResult from dify_graph.nodes.base.node import Node from dify_graph.nodes.base.template import Template @@ -46,21 +47,20 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]): variable_pool = self.graph_runtime_state.variable_pool # get dataset id as string - dataset_id_segment = variable_pool.get(["sys", SystemVariableKey.DATASET_ID]) + dataset_id_segment = get_system_segment(variable_pool, SystemVariableKey.DATASET_ID) if not dataset_id_segment: raise KnowledgeIndexNodeError("Dataset ID is required.") dataset_id: str = dataset_id_segment.value # get document id as string (may be empty when not provided) - document_id_segment = variable_pool.get(["sys", SystemVariableKey.DOCUMENT_ID]) + document_id_segment = get_system_segment(variable_pool, SystemVariableKey.DOCUMENT_ID) document_id: str = document_id_segment.value if document_id_segment else "" # extract variables variable = variable_pool.get(node_data.index_chunk_variable_selector) if not variable: raise KnowledgeIndexNodeError("Index chunk variable is required.") - invoke_from = variable_pool.get(["sys", SystemVariableKey.INVOKE_FROM]) - invoke_from_value = str(invoke_from.value) if invoke_from else None + invoke_from_value = get_system_text(variable_pool, SystemVariableKey.INVOKE_FROM) is_preview = invoke_from_value == _INVOKE_FROM_DEBUGGER chunks = variable.value @@ -87,8 +87,8 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]): outputs=outputs.model_dump(exclude_none=True), ) - original_document_id_segment = variable_pool.get(["sys", SystemVariableKey.ORIGINAL_DOCUMENT_ID]) - batch = variable_pool.get(["sys", SystemVariableKey.BATCH]) + original_document_id_segment = get_system_segment(variable_pool, SystemVariableKey.ORIGINAL_DOCUMENT_ID) + batch = get_system_segment(variable_pool, SystemVariableKey.BATCH) if not batch: raise KnowledgeIndexNodeError("Batch is required.") diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index 871bd2e968..bdffc6cdfe 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -12,6 +12,7 @@ from core.app.app_config.entities import DatasetRetrieveConfigEntity from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext from core.rag.data_post_processor.data_post_processor import RerankingModelDict, WeightsDict from core.rag.retrieval.dataset_retrieval import DatasetRetrieval +from core.workflow.file_reference import parse_file_reference from dify_graph.entities import GraphInitParams from dify_graph.entities.graph_config import NodeConfigDict from dify_graph.enums import ( @@ -255,7 +256,13 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD metadata_model_config=node_data.metadata_model_config, metadata_filtering_conditions=resolved_metadata_conditions, metadata_filtering_mode=metadata_filtering_mode, - attachment_ids=[attachment.related_id for attachment in attachments] if attachments else None, + attachment_ids=[ + parsed_reference.record_id + for attachment in attachments + if (parsed_reference := parse_file_reference(attachment.reference)) is not None + ] + if attachments + else None, ) ) diff --git a/api/core/workflow/nodes/trigger_plugin/trigger_event_node.py b/api/core/workflow/nodes/trigger_plugin/trigger_event_node.py index 118c2f2668..18444a5db4 100644 --- a/api/core/workflow/nodes/trigger_plugin/trigger_event_node.py +++ b/api/core/workflow/nodes/trigger_plugin/trigger_event_node.py @@ -53,13 +53,11 @@ class TriggerEventNode(Node[TriggerEventNodeData]): "plugin_unique_identifier": self.node_data.plugin_unique_identifier, }, } - node_inputs = dict(self.graph_runtime_state.variable_pool.user_inputs) - system_inputs = self.graph_runtime_state.variable_pool.system_variables.to_dict() + node_inputs = dict(self.graph_runtime_state.variable_pool.get_by_prefix(self.id)) + system_inputs = self.graph_runtime_state.variable_pool.get_by_prefix(SYSTEM_VARIABLE_NODE_ID) - # TODO: System variables should be directly accessible, no need for special handling - # Set system variables as node outputs. - for var in system_inputs: - node_inputs[SYSTEM_VARIABLE_NODE_ID + "." + var] = system_inputs[var] + for variable_name, value in system_inputs.items(): + node_inputs[f"{SYSTEM_VARIABLE_NODE_ID}.{variable_name}"] = value outputs = dict(node_inputs) return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, diff --git a/api/core/workflow/nodes/trigger_schedule/trigger_schedule_node.py b/api/core/workflow/nodes/trigger_schedule/trigger_schedule_node.py index b9580e6ab1..1d53c43eb5 100644 --- a/api/core/workflow/nodes/trigger_schedule/trigger_schedule_node.py +++ b/api/core/workflow/nodes/trigger_schedule/trigger_schedule_node.py @@ -1,7 +1,7 @@ from collections.abc import Mapping from core.trigger.constants import TRIGGER_SCHEDULE_NODE_TYPE -from dify_graph.constants import SYSTEM_VARIABLE_NODE_ID +from core.workflow.variable_prefixes import SYSTEM_VARIABLE_NODE_ID from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus from dify_graph.enums import NodeExecutionType from dify_graph.node_events import NodeRunResult @@ -31,13 +31,11 @@ class TriggerScheduleNode(Node[TriggerScheduleNodeData]): } def _run(self) -> NodeRunResult: - node_inputs = dict(self.graph_runtime_state.variable_pool.user_inputs) - system_inputs = self.graph_runtime_state.variable_pool.system_variables.to_dict() + node_inputs = dict(self.graph_runtime_state.variable_pool.get_by_prefix(self.id)) + system_inputs = self.graph_runtime_state.variable_pool.get_by_prefix(SYSTEM_VARIABLE_NODE_ID) - # TODO: System variables should be directly accessible, no need for special handling - # Set system variables as node outputs. - for var in system_inputs: - node_inputs[SYSTEM_VARIABLE_NODE_ID + "." + var] = system_inputs[var] + for variable_name, value in system_inputs.items(): + node_inputs[f"{SYSTEM_VARIABLE_NODE_ID}.{variable_name}"] = value outputs = dict(node_inputs) return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, diff --git a/api/core/workflow/nodes/trigger_webhook/node.py b/api/core/workflow/nodes/trigger_webhook/node.py index c4a98345c6..e2dee3656a 100644 --- a/api/core/workflow/nodes/trigger_webhook/node.py +++ b/api/core/workflow/nodes/trigger_webhook/node.py @@ -3,7 +3,8 @@ from collections.abc import Mapping from typing import Any from core.trigger.constants import TRIGGER_WEBHOOK_NODE_TYPE -from dify_graph.constants import SYSTEM_VARIABLE_NODE_ID +from core.workflow.file_reference import resolve_file_record_id +from core.workflow.variable_prefixes import SYSTEM_VARIABLE_NODE_ID from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus from dify_graph.enums import NodeExecutionType from dify_graph.file import FileTransferMethod @@ -60,16 +61,14 @@ class TriggerWebhookNode(Node[WebhookData]): happens in the trigger controller. """ # Get webhook data from variable pool (injected by Celery task) - webhook_inputs = dict(self.graph_runtime_state.variable_pool.user_inputs) + webhook_inputs = dict(self.graph_runtime_state.variable_pool.get_by_prefix(self.id)) # Extract webhook-specific outputs based on node configuration outputs = self._extract_configured_outputs(webhook_inputs) - system_inputs = self.graph_runtime_state.variable_pool.system_variables.to_dict() + system_inputs = self.graph_runtime_state.variable_pool.get_by_prefix(SYSTEM_VARIABLE_NODE_ID) - # TODO: System variables should be directly accessible, no need for special handling - # Set system variables as node outputs. - for var in system_inputs: - outputs[SYSTEM_VARIABLE_NODE_ID + "." + var] = system_inputs[var] + for variable_name, value in system_inputs.items(): + outputs[f"{SYSTEM_VARIABLE_NODE_ID}.{variable_name}"] = value return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=webhook_inputs, @@ -77,17 +76,17 @@ class TriggerWebhookNode(Node[WebhookData]): ) def generate_file_var(self, param_name: str, file: dict): - related_id = file.get("related_id") + file_id = resolve_file_record_id(file.get("reference") or file.get("related_id")) transfer_method_value = file.get("transfer_method") if transfer_method_value: transfer_method = FileTransferMethod.value_of(transfer_method_value) match transfer_method: case FileTransferMethod.LOCAL_FILE | FileTransferMethod.REMOTE_URL: - file["upload_file_id"] = related_id + file["upload_file_id"] = file_id case FileTransferMethod.TOOL_FILE: - file["tool_file_id"] = related_id + file["tool_file_id"] = file_id case FileTransferMethod.DATASOURCE_FILE: - file["datasource_file_id"] = related_id + file["datasource_file_id"] = file_id try: file_obj = self._file_reference_factory.build_from_mapping(mapping=file) diff --git a/api/core/workflow/system_variables.py b/api/core/workflow/system_variables.py new file mode 100644 index 0000000000..332cfd6941 --- /dev/null +++ b/api/core/workflow/system_variables.py @@ -0,0 +1,177 @@ +from __future__ import annotations + +from collections import defaultdict +from collections.abc import Mapping, Sequence +from enum import StrEnum +from typing import Any, Protocol, cast +from uuid import uuid4 + +from dify_graph.enums import BuiltinNodeTypes +from dify_graph.variables import build_segment, segment_to_variable +from dify_graph.variables.segments import Segment +from dify_graph.variables.variables import RAGPipelineVariableInput, Variable + +from .variable_prefixes import ( + CONVERSATION_VARIABLE_NODE_ID, + ENVIRONMENT_VARIABLE_NODE_ID, + RAG_PIPELINE_VARIABLE_NODE_ID, + SYSTEM_VARIABLE_NODE_ID, +) + + +class SystemVariableKey(StrEnum): + QUERY = "query" + FILES = "files" + CONVERSATION_ID = "conversation_id" + USER_ID = "user_id" + DIALOGUE_COUNT = "dialogue_count" + APP_ID = "app_id" + WORKFLOW_ID = "workflow_id" + WORKFLOW_EXECUTION_ID = "workflow_run_id" + TIMESTAMP = "timestamp" + DOCUMENT_ID = "document_id" + ORIGINAL_DOCUMENT_ID = "original_document_id" + BATCH = "batch" + DATASET_ID = "dataset_id" + DATASOURCE_TYPE = "datasource_type" + DATASOURCE_INFO = "datasource_info" + INVOKE_FROM = "invoke_from" + + +class _VariablePoolReader(Protocol): + def get(self, selector: Sequence[str], /) -> Segment | None: ... + + def get_by_prefix(self, prefix: str, /) -> Mapping[str, object]: ... + + +def system_variable_name(key: str | SystemVariableKey) -> str: + return key.value if isinstance(key, SystemVariableKey) else key + + +def system_variable_selector(key: str | SystemVariableKey) -> tuple[str, str]: + return SYSTEM_VARIABLE_NODE_ID, system_variable_name(key) + + +def _normalize_system_variable_values(values: Mapping[str, Any] | None = None, /, **kwargs: Any) -> dict[str, Any]: + raw_values = dict(values or {}) + raw_values.update(kwargs) + + workflow_execution_id = raw_values.pop("workflow_execution_id", None) + if workflow_execution_id is not None and SystemVariableKey.WORKFLOW_EXECUTION_ID.value not in raw_values: + raw_values[SystemVariableKey.WORKFLOW_EXECUTION_ID.value] = workflow_execution_id + + normalized: dict[str, Any] = {} + for key, value in raw_values.items(): + if value is None: + continue + normalized[system_variable_name(key)] = value + + normalized.setdefault(SystemVariableKey.FILES.value, []) + return normalized + + +def build_system_variables(values: Mapping[str, Any] | None = None, /, **kwargs: Any) -> list[Variable]: + normalized = _normalize_system_variable_values(values, **kwargs) + + return [ + cast( + Variable, + segment_to_variable( + segment=build_segment(value), + selector=system_variable_selector(key), + name=key, + ), + ) + for key, value in normalized.items() + ] + + +def default_system_variables() -> list[Variable]: + return build_system_variables(workflow_run_id=str(uuid4())) + + +def system_variables_to_mapping(system_variables: Sequence[Variable]) -> dict[str, Any]: + return {variable.name: variable.value for variable in system_variables} + + +def _with_selector(variable: Variable, node_id: str) -> Variable: + selector = [node_id, variable.name] + if list(variable.selector) == selector: + return variable + return variable.model_copy(update={"selector": selector}) + + +def build_bootstrap_variables( + *, + system_variables: Sequence[Variable] = (), + environment_variables: Sequence[Variable] = (), + conversation_variables: Sequence[Variable] = (), + rag_pipeline_variables: Sequence[RAGPipelineVariableInput] = (), +) -> list[Variable]: + variables = [ + *(_with_selector(variable, SYSTEM_VARIABLE_NODE_ID) for variable in system_variables), + *(_with_selector(variable, ENVIRONMENT_VARIABLE_NODE_ID) for variable in environment_variables), + *(_with_selector(variable, CONVERSATION_VARIABLE_NODE_ID) for variable in conversation_variables), + ] + + rag_pipeline_variables_map: defaultdict[str, dict[str, Any]] = defaultdict(dict) + for rag_var in rag_pipeline_variables: + node_id = rag_var.variable.belong_to_node_id + key = rag_var.variable.variable + rag_pipeline_variables_map[node_id][key] = rag_var.value + + for node_id, value in rag_pipeline_variables_map.items(): + variables.append( + cast( + Variable, + segment_to_variable( + segment=build_segment(value), + selector=(RAG_PIPELINE_VARIABLE_NODE_ID, node_id), + name=node_id, + ), + ) + ) + + return variables + + +def get_system_segment(variable_pool: _VariablePoolReader, key: str | SystemVariableKey) -> Segment | None: + return variable_pool.get(system_variable_selector(key)) + + +def get_system_value(variable_pool: _VariablePoolReader, key: str | SystemVariableKey) -> Any: + segment = get_system_segment(variable_pool, key) + return None if segment is None else segment.value + + +def get_system_text(variable_pool: _VariablePoolReader, key: str | SystemVariableKey) -> str | None: + segment = get_system_segment(variable_pool, key) + if segment is None: + return None + text = getattr(segment, "text", None) + return text if isinstance(text, str) else None + + +def get_all_system_variables(variable_pool: _VariablePoolReader) -> Mapping[str, object]: + return variable_pool.get_by_prefix(SYSTEM_VARIABLE_NODE_ID) + + +def inject_default_system_variable_mappings( + *, + node_id: str, + node_type: str, + node_data: object, + variable_mapping: Mapping[str, Sequence[str]], +) -> Mapping[str, Sequence[str]]: + """Add workflow-owned implicit sys mappings that `dify_graph` should not know about.""" + + if node_type != BuiltinNodeTypes.LLM or getattr(node_data, "memory", None) is None: + return variable_mapping + + query_mapping_key = f"{node_id}.#sys.query#" + if query_mapping_key in variable_mapping: + return variable_mapping + + augmented_mapping = dict(variable_mapping) + augmented_mapping[query_mapping_key] = system_variable_selector(SystemVariableKey.QUERY) + return augmented_mapping diff --git a/api/core/workflow/template_rendering.py b/api/core/workflow/template_rendering.py new file mode 100644 index 0000000000..fb7f9651ec --- /dev/null +++ b/api/core/workflow/template_rendering.py @@ -0,0 +1,29 @@ +from __future__ import annotations + +from collections.abc import Mapping +from typing import Any + +from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor +from dify_graph.nodes.code.entities import CodeLanguage +from dify_graph.template_rendering import Jinja2TemplateRenderer, TemplateRenderError + + +class CodeExecutorJinja2TemplateRenderer(Jinja2TemplateRenderer): + """Sandbox-backed Jinja2 renderer for workflow-owned node composition.""" + + def render_template(self, template: str, variables: Mapping[str, Any]) -> str: + try: + result = CodeExecutor.execute_workflow_code_template( + language=CodeLanguage.JINJA2, + code=template, + inputs=variables, + ) + except Exception as exc: + if isinstance(exc, CodeExecutionError): + raise TemplateRenderError(str(exc)) from exc + raise + + rendered = result.get("result") + if not isinstance(rendered, str): + raise TemplateRenderError("Template render result must be a string.") + return rendered diff --git a/api/core/workflow/variable_pool_initializer.py b/api/core/workflow/variable_pool_initializer.py new file mode 100644 index 0000000000..42151affdd --- /dev/null +++ b/api/core/workflow/variable_pool_initializer.py @@ -0,0 +1,15 @@ +from collections.abc import Mapping, Sequence +from typing import Any + +from dify_graph.runtime import VariablePool +from dify_graph.variables.variables import Variable + + +def add_variables_to_pool(variable_pool: VariablePool, variables: Sequence[Variable]) -> None: + for variable in variables: + variable_pool.add(variable.selector, variable) + + +def add_node_inputs_to_pool(variable_pool: VariablePool, *, node_id: str, inputs: Mapping[str, Any]) -> None: + for key, value in inputs.items(): + variable_pool.add((node_id, key), value) diff --git a/api/dify_graph/constants.py b/api/core/workflow/variable_prefixes.py similarity index 66% rename from api/dify_graph/constants.py rename to api/core/workflow/variable_prefixes.py index 7664be0983..9574bc8a42 100644 --- a/api/dify_graph/constants.py +++ b/api/core/workflow/variable_prefixes.py @@ -2,3 +2,5 @@ SYSTEM_VARIABLE_NODE_ID = "sys" ENVIRONMENT_VARIABLE_NODE_ID = "env" CONVERSATION_VARIABLE_NODE_ID = "conversation" RAG_PIPELINE_VARIABLE_NODE_ID = "rag" + +CHILD_SYNC_VARIABLE_NODE_IDS = frozenset((CONVERSATION_VARIABLE_NODE_ID,)) diff --git a/api/core/workflow/workflow_entry.py b/api/core/workflow/workflow_entry.py index 2e51a06bab..60b9ea92f1 100644 --- a/api/core/workflow/workflow_entry.py +++ b/api/core/workflow/workflow_entry.py @@ -1,15 +1,22 @@ import logging import time from collections.abc import Generator, Mapping, Sequence +from contextlib import AbstractContextManager from typing import Any, cast from configs import dify_config +from context import capture_current_context from core.app.apps.exc import GenerateTaskStoppedError from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom, build_dify_run_context from core.app.workflow.layers.llm_quota import LLMQuotaLayer from core.app.workflow.layers.observability import ObservabilityLayer -from core.workflow.node_factory import DifyNodeFactory, resolve_workflow_node_class -from dify_graph.constants import ENVIRONMENT_VARIABLE_NODE_ID +from core.workflow.node_factory import DifyNodeFactory, is_start_node_type, resolve_workflow_node_class +from core.workflow.system_variables import ( + default_system_variables, + inject_default_system_variable_mappings, +) +from core.workflow.variable_pool_initializer import add_node_inputs_to_pool, add_variables_to_pool +from core.workflow.variable_prefixes import CHILD_SYNC_VARIABLE_NODE_IDS, ENVIRONMENT_VARIABLE_NODE_ID from dify_graph.entities import GraphInitParams from dify_graph.entities.graph_config import NodeConfigDictAdapter from dify_graph.errors import WorkflowNodeRunFailedError @@ -24,7 +31,6 @@ from dify_graph.graph_events import GraphEngineEvent, GraphNodeEventBase, GraphR from dify_graph.nodes import BuiltinNodeTypes from dify_graph.nodes.base.node import Node from dify_graph.runtime import ChildGraphNotFoundError, GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable from dify_graph.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool from extensions.otel.runtime import is_instrument_flag_enabled from factories import file_factory @@ -34,6 +40,9 @@ logger = logging.getLogger(__name__) class _WorkflowChildEngineBuilder: + def __init__(self, execution_context: AbstractContextManager[object] | None) -> None: + self._execution_context = execution_context + @staticmethod def _has_node_id(graph_config: Mapping[str, Any], node_id: str) -> bool | None: """ @@ -86,6 +95,7 @@ class _WorkflowChildEngineBuilder: command_channel=InMemoryChannel(), config=GraphEngineConfig(), child_engine_builder=self, + execution_context=self._execution_context, ) child_engine.layer(LLMQuotaLayer()) for layer in layers: @@ -136,7 +146,8 @@ class WorkflowEntry: command_channel = InMemoryChannel() self.command_channel = command_channel - self._child_engine_builder = _WorkflowChildEngineBuilder() + execution_context = capture_current_context() + self._child_engine_builder = _WorkflowChildEngineBuilder(execution_context=execution_context) self.graph_engine = GraphEngine( workflow_id=workflow_id, graph=graph, @@ -149,6 +160,7 @@ class WorkflowEntry: scale_down_idle_time=dify_config.GRAPH_ENGINE_SCALE_DOWN_IDLE_TIME, ), child_engine_builder=self._child_engine_builder, + execution_context=execution_context, ) # Add debug logging layer when in debug mode @@ -225,8 +237,16 @@ class WorkflowEntry: invoke_from=InvokeFrom.DEBUGGER, ), call_depth=0, + child_sync_variable_node_ids=CHILD_SYNC_VARIABLE_NODE_IDS, ) - graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) + graph_runtime_state = GraphRuntimeState( + variable_pool=variable_pool, + start_at=time.perf_counter(), + execution_context=capture_current_context(), + ) + + if is_start_node_type(node_type): + add_node_inputs_to_pool(variable_pool, node_id=node_id, inputs=user_inputs) # init workflow run state node_factory = DifyNodeFactory( @@ -243,6 +263,12 @@ class WorkflowEntry: ) except NotImplementedError: variable_mapping = {} + variable_mapping = inject_default_system_variable_mappings( + node_id=node_id, + node_type=node_type, + node_data=node_config_data, + variable_mapping=variable_mapping, + ) # Loading missing variable from draft var here, and set it into # variable_pool. @@ -347,11 +373,8 @@ class WorkflowEntry: raise ValueError(f"Node class not found for node type {node_type}") # init variable pool - variable_pool = VariablePool( - system_variables=SystemVariable.default(), - user_inputs={}, - environment_variables=[], - ) + variable_pool = VariablePool() + add_variables_to_pool(variable_pool, default_system_variables()) # init graph init params and runtime state graph_init_params = GraphInitParams( @@ -365,8 +388,13 @@ class WorkflowEntry: invoke_from=InvokeFrom.DEBUGGER, ), call_depth=0, + child_sync_variable_node_ids=CHILD_SYNC_VARIABLE_NODE_IDS, + ) + graph_runtime_state = GraphRuntimeState( + variable_pool=variable_pool, + start_at=time.perf_counter(), + execution_context=capture_current_context(), ) - graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) # init workflow run state node_config = NodeConfigDictAdapter.validate_python({"id": node_id, "data": node_data}) @@ -384,6 +412,12 @@ class WorkflowEntry: ) except NotImplementedError: variable_mapping = {} + variable_mapping = inject_default_system_variable_mappings( + node_id=node_id, + node_type=node_type, + node_data=node_data, + variable_mapping=variable_mapping, + ) cls.mapping_user_inputs_to_variable_pool( variable_mapping=variable_mapping, diff --git a/api/dify_graph/context/__init__.py b/api/dify_graph/context/__init__.py deleted file mode 100644 index 103f526bec..0000000000 --- a/api/dify_graph/context/__init__.py +++ /dev/null @@ -1,34 +0,0 @@ -""" -Execution Context - Context management for workflow execution. - -This package provides Flask-independent context management for workflow -execution in multi-threaded environments. -""" - -from dify_graph.context.execution_context import ( - AppContext, - ContextProviderNotFoundError, - ExecutionContext, - IExecutionContext, - NullAppContext, - capture_current_context, - read_context, - register_context, - register_context_capturer, - reset_context_provider, -) -from dify_graph.context.models import SandboxContext - -__all__ = [ - "AppContext", - "ContextProviderNotFoundError", - "ExecutionContext", - "IExecutionContext", - "NullAppContext", - "SandboxContext", - "capture_current_context", - "read_context", - "register_context", - "register_context_capturer", - "reset_context_provider", -] diff --git a/api/dify_graph/conversation_variable_updater.py b/api/dify_graph/conversation_variable_updater.py deleted file mode 100644 index 17b19f2502..0000000000 --- a/api/dify_graph/conversation_variable_updater.py +++ /dev/null @@ -1,39 +0,0 @@ -import abc -from typing import Protocol - -from dify_graph.variables import VariableBase - - -class ConversationVariableUpdater(Protocol): - """ - ConversationVariableUpdater defines an abstraction for updating conversation variable values. - - It is intended for use by `v1.VariableAssignerNode` and `v2.VariableAssignerNode` when updating - conversation variables. - - Implementations may choose to batch updates. If batching is used, the `flush` method - should be implemented to persist buffered changes, and `update` - should handle buffering accordingly. - - Note: Since implementations may buffer updates, instances of ConversationVariableUpdater - are not thread-safe. Each VariableAssignerNode should create its own instance during execution. - """ - - @abc.abstractmethod - def update(self, conversation_id: str, variable: "VariableBase"): - """ - Updates the value of the specified conversation variable in the underlying storage. - - :param conversation_id: The ID of the conversation to update. Typically references `ConversationVariable.id`. - :param variable: The `VariableBase` instance containing the updated value. - """ - pass - - @abc.abstractmethod - def flush(self): - """ - Flushes all pending updates to the underlying storage system. - - If the implementation does not buffer updates, this method can be a no-op. - """ - pass diff --git a/api/dify_graph/entities/graph_init_params.py b/api/dify_graph/entities/graph_init_params.py index 160be9b7b2..66218517c0 100644 --- a/api/dify_graph/entities/graph_init_params.py +++ b/api/dify_graph/entities/graph_init_params.py @@ -20,3 +20,7 @@ class GraphInitParams(BaseModel): graph_config: Mapping[str, Any] = Field(..., description="graph config") run_context: Mapping[str, Any] = Field(..., description="runtime context") call_depth: int = Field(..., description="call depth") + child_sync_variable_node_ids: frozenset[str] = Field( + default_factory=frozenset, + description="Variable node IDs whose values must be synced back from child graph executions.", + ) diff --git a/api/dify_graph/entities/pause_reason.py b/api/dify_graph/entities/pause_reason.py index 86d8c8ca16..d70110c446 100644 --- a/api/dify_graph/entities/pause_reason.py +++ b/api/dify_graph/entities/pause_reason.py @@ -18,7 +18,6 @@ class HumanInputRequired(BaseModel): form_content: str inputs: list[FormInput] = Field(default_factory=list) actions: list[UserAction] = Field(default_factory=list) - display_in_ui: bool = False node_id: str node_title: str @@ -33,13 +32,6 @@ class HumanInputRequired(BaseModel): # Only form inputs with default value type `VARIABLE` will be resolved and stored in `resolved_default_values`. resolved_default_values: Mapping[str, Any] = Field(default_factory=dict) - # The `form_token` is the token used to submit the form via UI surfaces. It corresponds to - # `HumanInputFormRecipient.access_token`. - # - # This field is `None` if webapp delivery is not set and not - # in orchestrating mode. - form_token: str | None = None - class SchedulingPause(BaseModel): TYPE: Literal[PauseReasonType.SCHEDULED_PAUSE] = PauseReasonType.SCHEDULED_PAUSE diff --git a/api/dify_graph/entities/workflow_execution.py b/api/dify_graph/entities/workflow_execution.py index 293a66c56e..cf8ebf1e43 100644 --- a/api/dify_graph/entities/workflow_execution.py +++ b/api/dify_graph/entities/workflow_execution.py @@ -1,26 +1,23 @@ """ Domain entities for workflow execution. -Models are independent of the storage mechanism and don't contain -implementation details like tenant_id, app_id, etc. +Models describe graph runtime state and avoid infrastructure-specific details. """ from __future__ import annotations from collections.abc import Mapping -from datetime import datetime +from datetime import UTC, datetime from typing import Any from pydantic import BaseModel, Field from dify_graph.enums import WorkflowExecutionStatus, WorkflowType -from dify_graph.utils.datetime_utils import naive_utc_now class WorkflowExecution(BaseModel): """ - Domain model for workflow execution based on WorkflowRun but without - user, tenant, and app attributes. + Domain model for a workflow execution within the graph runtime. """ id_: str = Field(...) @@ -47,7 +44,7 @@ class WorkflowExecution(BaseModel): Calculate elapsed time in seconds. If workflow is not finished, use current time. """ - end_time = self.finished_at or naive_utc_now() + end_time = self.finished_at or datetime.now(UTC).replace(tzinfo=None) return (end_time - self.started_at).total_seconds() @classmethod diff --git a/api/dify_graph/entities/workflow_node_execution.py b/api/dify_graph/entities/workflow_node_execution.py index bc7e0d02e5..3a7b02b299 100644 --- a/api/dify_graph/entities/workflow_node_execution.py +++ b/api/dify_graph/entities/workflow_node_execution.py @@ -1,9 +1,8 @@ """ Domain entities for workflow node execution. -This module contains the domain model for workflow node execution, which is used -by the core workflow module. These models are independent of the storage mechanism -and don't contain implementation details like tenant_id, app_id, etc. +These models capture node-level execution state for the graph runtime without +describing storage or application-layer concerns. """ from collections.abc import Mapping @@ -19,13 +18,8 @@ class WorkflowNodeExecution(BaseModel): """ Domain model for workflow node execution. - This model represents the core business entity of a node execution, - without implementation details like tenant_id, app_id, etc. - - Note: User/context-specific fields (triggered_from, created_by, created_by_role) - have been moved to the repository implementation to keep the domain model clean. - These fields are still accepted in the constructor for backward compatibility, - but they are not stored in the model. + This model represents the graph-level record of a node execution and + contains only execution state relevant to the runtime. """ # --------- Core identification fields --------- @@ -41,7 +35,7 @@ class WorkflowNodeExecution(BaseModel): # In most scenarios, `id` should be used as the primary identifier. node_execution_id: str | None = None workflow_id: str # ID of the workflow this node belongs to - workflow_execution_id: str | None = None # ID of the specific workflow run (null for single-step debugging) + workflow_execution_id: str | None = None # ID of the workflow execution (null for single-step debugging) # --------- Core identification fields ends --------- # Execution positioning and flow diff --git a/api/dify_graph/enums.py b/api/dify_graph/enums.py index cfb135cbb0..94c7c57f18 100644 --- a/api/dify_graph/enums.py +++ b/api/dify_graph/enums.py @@ -10,30 +10,6 @@ class NodeState(StrEnum): SKIPPED = "skipped" -class SystemVariableKey(StrEnum): - """ - System Variables. - """ - - QUERY = "query" - FILES = "files" - CONVERSATION_ID = "conversation_id" - USER_ID = "user_id" - DIALOGUE_COUNT = "dialogue_count" - APP_ID = "app_id" - WORKFLOW_ID = "workflow_id" - WORKFLOW_EXECUTION_ID = "workflow_run_id" - TIMESTAMP = "timestamp" - # RAG Pipeline - DOCUMENT_ID = "document_id" - ORIGINAL_DOCUMENT_ID = "original_document_id" - BATCH = "batch" - DATASET_ID = "dataset_id" - DATASOURCE_TYPE = "datasource_type" - DATASOURCE_INFO = "datasource_info" - INVOKE_FROM = "invoke_from" - - NodeType: TypeAlias = str diff --git a/api/dify_graph/file/file_manager.py b/api/dify_graph/file/file_manager.py index 8d998054db..45b86d91cc 100644 --- a/api/dify_graph/file/file_manager.py +++ b/api/dify_graph/file/file_manager.py @@ -12,7 +12,6 @@ from dify_graph.model_runtime.entities import ( ) from dify_graph.model_runtime.entities.message_entities import PromptMessageContentUnionTypes -from . import helpers from .enums import FileAttribute from .models import File, FileTransferMethod, FileType from .runtime import get_workflow_file_runtime @@ -80,7 +79,7 @@ def download(f: File, /) -> bytes: FileTransferMethod.LOCAL_FILE, FileTransferMethod.DATASOURCE_FILE, ): - return _download_file_content(f.storage_key) + return _download_file_content(f) elif f.transfer_method == FileTransferMethod.REMOTE_URL: if f.remote_url is None: raise ValueError("Missing file remote_url") @@ -90,12 +89,9 @@ def download(f: File, /) -> bytes: raise ValueError(f"unsupported transfer method: {f.transfer_method}") -def _download_file_content(path: str, /) -> bytes: +def _download_file_content(file: File, /) -> bytes: """Download and return a file from storage as bytes.""" - data = get_workflow_file_runtime().storage_load(path, stream=False) - if not isinstance(data, bytes): - raise ValueError(f"file {path} is not a bytes object") - return data + return get_workflow_file_runtime().load_file_bytes(file=file) def _get_encoded_string(f: File, /) -> str: @@ -107,30 +103,20 @@ def _get_encoded_string(f: File, /) -> str: response.raise_for_status() data = response.content case FileTransferMethod.LOCAL_FILE: - data = _download_file_content(f.storage_key) + data = _download_file_content(f) case FileTransferMethod.TOOL_FILE: - data = _download_file_content(f.storage_key) + data = _download_file_content(f) case FileTransferMethod.DATASOURCE_FILE: - data = _download_file_content(f.storage_key) + data = _download_file_content(f) return base64.b64encode(data).decode("utf-8") def _to_url(f: File, /): - if f.transfer_method == FileTransferMethod.REMOTE_URL: - if f.remote_url is None: - raise ValueError("Missing file remote_url") - return f.remote_url - elif f.transfer_method == FileTransferMethod.LOCAL_FILE: - if f.related_id is None: - raise ValueError("Missing file related_id") - return f.remote_url or helpers.get_signed_file_url(upload_file_id=f.related_id) - elif f.transfer_method == FileTransferMethod.TOOL_FILE: - if f.related_id is None or f.extension is None: - raise ValueError("Missing file related_id or extension") - return helpers.get_signed_tool_file_url(tool_file_id=f.related_id, extension=f.extension) - else: + url = f.generate_url() + if url is None: raise ValueError(f"Unsupported transfer method: {f.transfer_method}") + return url class FileManager: diff --git a/api/dify_graph/file/helpers.py b/api/dify_graph/file/helpers.py index 310cb1310b..dade761227 100644 --- a/api/dify_graph/file/helpers.py +++ b/api/dify_graph/file/helpers.py @@ -1,92 +1,48 @@ from __future__ import annotations -import base64 -import hashlib -import hmac -import os -import time -import urllib.parse +from typing import TYPE_CHECKING from .runtime import get_workflow_file_runtime +if TYPE_CHECKING: + from .models import File + + +def resolve_file_url(file: File, /, *, for_external: bool = True) -> str | None: + return get_workflow_file_runtime().resolve_file_url(file=file, for_external=for_external) + def get_signed_file_url(upload_file_id: str, as_attachment: bool = False, for_external: bool = True) -> str: - runtime = get_workflow_file_runtime() - base_url = runtime.files_url if for_external else (runtime.internal_files_url or runtime.files_url) - url = f"{base_url}/files/{upload_file_id}/file-preview" - - timestamp = str(int(time.time())) - nonce = os.urandom(16).hex() - key = runtime.secret_key.encode() - msg = f"file-preview|{upload_file_id}|{timestamp}|{nonce}" - sign = hmac.new(key, msg.encode(), hashlib.sha256).digest() - encoded_sign = base64.urlsafe_b64encode(sign).decode() - query: dict[str, str] = {"timestamp": timestamp, "nonce": nonce, "sign": encoded_sign} - if as_attachment: - query["as_attachment"] = "true" - query_string = urllib.parse.urlencode(query) - - return f"{url}?{query_string}" - - -def get_signed_file_url_for_plugin(filename: str, mimetype: str, tenant_id: str, user_id: str) -> str: - runtime = get_workflow_file_runtime() - # Plugin access should use internal URL for Docker network communication. - base_url = runtime.internal_files_url or runtime.files_url - url = f"{base_url}/files/upload/for-plugin" - timestamp = str(int(time.time())) - nonce = os.urandom(16).hex() - key = runtime.secret_key.encode() - msg = f"upload|{filename}|{mimetype}|{tenant_id}|{user_id}|{timestamp}|{nonce}" - sign = hmac.new(key, msg.encode(), hashlib.sha256).digest() - encoded_sign = base64.urlsafe_b64encode(sign).decode() - return f"{url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}&user_id={user_id}&tenant_id={tenant_id}" + return get_workflow_file_runtime().resolve_upload_file_url( + upload_file_id=upload_file_id, + as_attachment=as_attachment, + for_external=for_external, + ) def get_signed_tool_file_url(tool_file_id: str, extension: str, for_external: bool = True) -> str: - runtime = get_workflow_file_runtime() - return runtime.sign_tool_file(tool_file_id=tool_file_id, extension=extension, for_external=for_external) - - -def verify_plugin_file_signature( - *, filename: str, mimetype: str, tenant_id: str, user_id: str, timestamp: str, nonce: str, sign: str -) -> bool: - runtime = get_workflow_file_runtime() - data_to_sign = f"upload|{filename}|{mimetype}|{tenant_id}|{user_id}|{timestamp}|{nonce}" - secret_key = runtime.secret_key.encode() - recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest() - recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode() - - if sign != recalculated_encoded_sign: - return False - - current_time = int(time.time()) - return current_time - int(timestamp) <= runtime.files_access_timeout + return get_workflow_file_runtime().resolve_tool_file_url( + tool_file_id=tool_file_id, + extension=extension, + for_external=for_external, + ) def verify_image_signature(*, upload_file_id: str, timestamp: str, nonce: str, sign: str) -> bool: - runtime = get_workflow_file_runtime() - data_to_sign = f"image-preview|{upload_file_id}|{timestamp}|{nonce}" - secret_key = runtime.secret_key.encode() - recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest() - recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode() - - if sign != recalculated_encoded_sign: - return False - - current_time = int(time.time()) - return current_time - int(timestamp) <= runtime.files_access_timeout + return get_workflow_file_runtime().verify_preview_signature( + preview_kind="image", + file_id=upload_file_id, + timestamp=timestamp, + nonce=nonce, + sign=sign, + ) def verify_file_signature(*, upload_file_id: str, timestamp: str, nonce: str, sign: str) -> bool: - runtime = get_workflow_file_runtime() - data_to_sign = f"file-preview|{upload_file_id}|{timestamp}|{nonce}" - secret_key = runtime.secret_key.encode() - recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest() - recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode() - - if sign != recalculated_encoded_sign: - return False - - current_time = int(time.time()) - return current_time - int(timestamp) <= runtime.files_access_timeout + return get_workflow_file_runtime().verify_preview_signature( + preview_kind="file", + file_id=upload_file_id, + timestamp=timestamp, + nonce=nonce, + sign=sign, + ) diff --git a/api/dify_graph/file/models.py b/api/dify_graph/file/models.py index dcba00978e..f8b1701cdb 100644 --- a/api/dify_graph/file/models.py +++ b/api/dify_graph/file/models.py @@ -2,7 +2,6 @@ from __future__ import annotations from collections.abc import Mapping, Sequence from typing import Any -from uuid import UUID, uuid4 from pydantic import BaseModel, Field, model_validator @@ -44,57 +43,39 @@ class FileUploadConfig(BaseModel): number_limits: int = 0 -class ToolFile(BaseModel): - id: UUID = Field(default_factory=uuid4, description="Unique identifier for the file") - user_id: UUID = Field(..., description="ID of the user who owns this file") - tenant_id: UUID = Field(..., description="ID of the tenant/organization") - conversation_id: UUID | None = Field(None, description="ID of the associated conversation") - file_key: str = Field(..., max_length=255, description="Storage key for the file") - mimetype: str = Field(..., max_length=255, description="MIME type of the file") - original_url: str | None = Field( - None, max_length=2048, description="Original URL if file was fetched from external source" - ) - name: str = Field(default="", max_length=255, description="Display name of the file") - size: int = Field(default=-1, ge=-1, description="File size in bytes (-1 if unknown)") - - class Config: - from_attributes = True # Enable ORM mode for SQLAlchemy compatibility - populate_by_name = True - - class File(BaseModel): + """Graph-owned file reference. + + The graph layer deliberately keeps only the metadata required to route, + serialize, and render files. Application ownership concerns such as + tenant/user/conversation identity stay in the workflow/storage layer. + """ + # NOTE: dify_model_identity is a special identifier used to distinguish between # new and old data formats during serialization and deserialization. dify_model_identity: str = FILE_MODEL_IDENTITY id: str | None = None # message file id - tenant_id: str type: FileType transfer_method: FileTransferMethod # If `transfer_method` is `FileTransferMethod.remote_url`, the # `remote_url` attribute must not be `None`. remote_url: str | None = None # remote url - # If `transfer_method` is `FileTransferMethod.local_file` or - # `FileTransferMethod.tool_file`, the `related_id` attribute must not be `None`. - # - # It should be set to `ToolFile.id` when `transfer_method` is `tool_file`. - related_id: str | None = None + # Opaque workflow-layer reference for files resolved outside ``dify_graph``. + reference: str | None = None filename: str | None = None extension: str | None = Field(default=None, description="File extension, should contain dot") mime_type: str | None = None size: int = -1 - # Those properties are private, should not be exposed to the outside. - _storage_key: str - def __init__( self, *, id: str | None = None, - tenant_id: str, type: FileType, transfer_method: FileTransferMethod, remote_url: str | None = None, + reference: str | None = None, related_id: str | None = None, filename: str | None = None, extension: str | None = None, @@ -103,18 +84,22 @@ class File(BaseModel): storage_key: str | None = None, dify_model_identity: str | None = FILE_MODEL_IDENTITY, url: str | None = None, - # Legacy compatibility fields - explicitly handle known extra fields + # Legacy compatibility fields - explicitly accept known extra fields tool_file_id: str | None = None, upload_file_id: str | None = None, datasource_file_id: str | None = None, ): + legacy_record_id = tool_file_id or upload_file_id or datasource_file_id or related_id + normalized_reference = reference + if normalized_reference is None and legacy_record_id is not None: + normalized_reference = str(legacy_record_id) + super().__init__( id=id, - tenant_id=tenant_id, type=type, transfer_method=transfer_method, remote_url=remote_url, - related_id=related_id, + reference=normalized_reference, filename=filename, extension=extension, mime_type=mime_type, @@ -122,12 +107,12 @@ class File(BaseModel): dify_model_identity=dify_model_identity, url=url, ) - self._storage_key = str(storage_key) def to_dict(self) -> Mapping[str, str | int | None]: data = self.model_dump(mode="json") return { **data, + "related_id": self.reference, "url": self.generate_url(), } @@ -142,21 +127,7 @@ class File(BaseModel): return text def generate_url(self, for_external: bool = True) -> str | None: - if self.transfer_method == FileTransferMethod.REMOTE_URL: - return self.remote_url - elif self.transfer_method == FileTransferMethod.LOCAL_FILE: - if self.related_id is None: - raise ValueError("Missing file related_id") - return helpers.get_signed_file_url(upload_file_id=self.related_id, for_external=for_external) - elif self.transfer_method in [FileTransferMethod.TOOL_FILE, FileTransferMethod.DATASOURCE_FILE]: - assert self.related_id is not None - assert self.extension is not None - return sign_tool_file( - tool_file_id=self.related_id, - extension=self.extension, - for_external=for_external, - ) - return None + return helpers.resolve_file_url(self, for_external=for_external) def to_plugin_parameter(self) -> dict[str, Any]: return { @@ -178,20 +149,24 @@ class File(BaseModel): if not isinstance(self.remote_url, str) or not self.remote_url.startswith("http"): raise ValueError("Invalid file url") case FileTransferMethod.LOCAL_FILE: - if not self.related_id: - raise ValueError("Missing file related_id") + if not self.reference: + raise ValueError("Missing file reference") case FileTransferMethod.TOOL_FILE: - if not self.related_id: - raise ValueError("Missing file related_id") + if not self.reference: + raise ValueError("Missing file reference") case FileTransferMethod.DATASOURCE_FILE: - if not self.related_id: - raise ValueError("Missing file related_id") + if not self.reference: + raise ValueError("Missing file reference") return self @property - def storage_key(self) -> str: - return self._storage_key + def related_id(self) -> str | None: + return self.reference - @storage_key.setter - def storage_key(self, value: str) -> None: - self._storage_key = value + @related_id.setter + def related_id(self, value: str | None) -> None: + self.reference = value + + @property + def storage_key(self) -> str: + return "" diff --git a/api/dify_graph/file/protocols.py b/api/dify_graph/file/protocols.py index 24cbb42735..4246d5d6ee 100644 --- a/api/dify_graph/file/protocols.py +++ b/api/dify_graph/file/protocols.py @@ -1,7 +1,10 @@ from __future__ import annotations from collections.abc import Generator -from typing import Protocol +from typing import TYPE_CHECKING, Literal, Protocol + +if TYPE_CHECKING: + from .models import File class HttpResponseProtocol(Protocol): @@ -21,18 +24,6 @@ class WorkflowFileRuntimeProtocol(Protocol): application infrastructure modules directly. """ - @property - def files_url(self) -> str: ... - - @property - def internal_files_url(self) -> str | None: ... - - @property - def secret_key(self) -> str: ... - - @property - def files_access_timeout(self) -> int: ... - @property def multimodal_send_format(self) -> str: ... @@ -40,4 +31,26 @@ class WorkflowFileRuntimeProtocol(Protocol): def storage_load(self, path: str, *, stream: bool = False) -> bytes | Generator: ... - def sign_tool_file(self, *, tool_file_id: str, extension: str, for_external: bool = True) -> str: ... + def load_file_bytes(self, *, file: File) -> bytes: ... + + def resolve_file_url(self, *, file: File, for_external: bool = True) -> str | None: ... + + def resolve_upload_file_url( + self, + *, + upload_file_id: str, + as_attachment: bool = False, + for_external: bool = True, + ) -> str: ... + + def resolve_tool_file_url(self, *, tool_file_id: str, extension: str, for_external: bool = True) -> str: ... + + def verify_preview_signature( + self, + *, + preview_kind: Literal["image", "file"], + file_id: str, + timestamp: str, + nonce: str, + sign: str, + ) -> bool: ... diff --git a/api/dify_graph/file/runtime.py b/api/dify_graph/file/runtime.py index 94253e0255..1c5d1c3ca4 100644 --- a/api/dify_graph/file/runtime.py +++ b/api/dify_graph/file/runtime.py @@ -1,10 +1,13 @@ from __future__ import annotations from collections.abc import Generator -from typing import NoReturn +from typing import TYPE_CHECKING, Literal, NoReturn from .protocols import HttpResponseProtocol, WorkflowFileRuntimeProtocol +if TYPE_CHECKING: + from .models import File + class WorkflowFileRuntimeNotConfiguredError(RuntimeError): """Raised when workflow file runtime dependencies were not configured.""" @@ -16,22 +19,6 @@ class _UnconfiguredWorkflowFileRuntime(WorkflowFileRuntimeProtocol): "workflow file runtime is not configured, call set_workflow_file_runtime(...) first" ) - @property - def files_url(self) -> str: - self._raise() - - @property - def internal_files_url(self) -> str | None: - self._raise() - - @property - def secret_key(self) -> str: - self._raise() - - @property - def files_access_timeout(self) -> int: - self._raise() - @property def multimodal_send_format(self) -> str: self._raise() @@ -42,7 +29,33 @@ class _UnconfiguredWorkflowFileRuntime(WorkflowFileRuntimeProtocol): def storage_load(self, path: str, *, stream: bool = False) -> bytes | Generator: self._raise() - def sign_tool_file(self, *, tool_file_id: str, extension: str, for_external: bool = True) -> str: + def load_file_bytes(self, *, file: File) -> bytes: + self._raise() + + def resolve_file_url(self, *, file: File, for_external: bool = True) -> str | None: + self._raise() + + def resolve_upload_file_url( + self, + *, + upload_file_id: str, + as_attachment: bool = False, + for_external: bool = True, + ) -> str: + self._raise() + + def resolve_tool_file_url(self, *, tool_file_id: str, extension: str, for_external: bool = True) -> str: + self._raise() + + def verify_preview_signature( + self, + *, + preview_kind: Literal["image", "file"], + file_id: str, + timestamp: str, + nonce: str, + sign: str, + ) -> bool: self._raise() diff --git a/api/dify_graph/graph_engine/event_management/event_handlers.py b/api/dify_graph/graph_engine/event_management/event_handlers.py index 7f5ad40e0e..f0c8eada17 100644 --- a/api/dify_graph/graph_engine/event_management/event_handlers.py +++ b/api/dify_graph/graph_engine/event_management/event_handlers.py @@ -28,6 +28,7 @@ from dify_graph.graph_events import ( NodeRunStartedEvent, NodeRunStreamChunkEvent, NodeRunSucceededEvent, + NodeRunVariableUpdatedEvent, ) from dify_graph.model_runtime.entities.llm_entities import LLMUsage from dify_graph.runtime import GraphRuntimeState @@ -153,6 +154,17 @@ class EventHandler: for stream_event in streaming_events: self._event_collector.collect(stream_event) + @_dispatch.register + def _(self, event: NodeRunVariableUpdatedEvent) -> None: + """ + Apply a node-requested variable mutation before downstream observers run. + + The event is collected like other node events so parent/container engines can + forward the updated payload to outer layers, including persistence listeners. + """ + self._graph_runtime_state.variable_pool.add(event.variable.selector, event.variable) + self._event_collector.collect(event) + @_dispatch.register def _(self, event: NodeRunSucceededEvent) -> None: """ diff --git a/api/dify_graph/graph_engine/graph_engine.py b/api/dify_graph/graph_engine/graph_engine.py index ea98a46b06..5cdffed3b0 100644 --- a/api/dify_graph/graph_engine/graph_engine.py +++ b/api/dify_graph/graph_engine/graph_engine.py @@ -10,9 +10,9 @@ from __future__ import annotations import logging import queue from collections.abc import Generator, Mapping +from contextlib import AbstractContextManager from typing import TYPE_CHECKING, cast, final -from dify_graph.context import capture_current_context from dify_graph.entities.workflow_start_reason import WorkflowStartReason from dify_graph.enums import NodeExecutionType from dify_graph.graph import Graph @@ -77,6 +77,7 @@ class GraphEngine: command_channel: CommandChannel, config: GraphEngineConfig = _DEFAULT_CONFIG, child_engine_builder: ChildGraphEngineBuilderProtocol | None = None, + execution_context: AbstractContextManager[object] | None = None, ) -> None: """Initialize the graph engine with all subsystems and dependencies.""" @@ -89,6 +90,8 @@ class GraphEngine: self._child_engine_builder = child_engine_builder if child_engine_builder is not None: self._graph_runtime_state.bind_child_engine_builder(child_engine_builder) + if execution_context is not None: + self._graph_runtime_state.execution_context = execution_context # Graph execution tracks the overall execution state self._graph_execution = cast("GraphExecution", self._graph_runtime_state.graph_execution) @@ -154,16 +157,13 @@ class GraphEngine: self._layers: list[GraphEngineLayer] = [] # === Worker Pool Setup === - # Capture execution context for worker threads - execution_context = capture_current_context() - # Create worker pool for parallel node execution self._worker_pool = WorkerPool( ready_queue=self._ready_queue, event_queue=self._event_queue, graph=self._graph, layers=self._layers, - execution_context=execution_context, + execution_context=self._graph_runtime_state.execution_context, config=self._config, ) diff --git a/api/dify_graph/graph_engine/worker.py b/api/dify_graph/graph_engine/worker.py index 988c20d72a..747a16453b 100644 --- a/api/dify_graph/graph_engine/worker.py +++ b/api/dify_graph/graph_engine/worker.py @@ -9,12 +9,12 @@ import queue import threading import time from collections.abc import Sequence +from contextlib import AbstractContextManager from datetime import datetime from typing import TYPE_CHECKING, final from typing_extensions import override -from dify_graph.context import IExecutionContext from dify_graph.enums import WorkflowNodeExecutionStatus from dify_graph.graph import Graph from dify_graph.graph_engine.layers.base import GraphEngineLayer @@ -46,7 +46,7 @@ class Worker(threading.Thread): graph: Graph, layers: Sequence[GraphEngineLayer], worker_id: int = 0, - execution_context: IExecutionContext | None = None, + execution_context: AbstractContextManager[object] | None = None, ) -> None: """ Initialize worker thread. diff --git a/api/dify_graph/graph_engine/worker_management/worker_pool.py b/api/dify_graph/graph_engine/worker_management/worker_pool.py index cc93087783..eb641165a0 100644 --- a/api/dify_graph/graph_engine/worker_management/worker_pool.py +++ b/api/dify_graph/graph_engine/worker_management/worker_pool.py @@ -8,9 +8,9 @@ DynamicScaler, and WorkerFactory into a single class. import logging import queue import threading +from contextlib import AbstractContextManager from typing import final -from dify_graph.context import IExecutionContext from dify_graph.graph import Graph from dify_graph.graph_events import GraphNodeEventBase @@ -38,7 +38,7 @@ class WorkerPool: graph: Graph, layers: list[GraphEngineLayer], config: GraphEngineConfig, - execution_context: IExecutionContext | None = None, + execution_context: AbstractContextManager[object] | None = None, ) -> None: """ Initialize the simple worker pool. diff --git a/api/dify_graph/graph_events/__init__.py b/api/dify_graph/graph_events/__init__.py index 56ea642092..7cec587a05 100644 --- a/api/dify_graph/graph_events/__init__.py +++ b/api/dify_graph/graph_events/__init__.py @@ -46,6 +46,7 @@ from .node import ( NodeRunStartedEvent, NodeRunStreamChunkEvent, NodeRunSucceededEvent, + NodeRunVariableUpdatedEvent, is_node_result_event, ) @@ -78,5 +79,6 @@ __all__ = [ "NodeRunStartedEvent", "NodeRunStreamChunkEvent", "NodeRunSucceededEvent", + "NodeRunVariableUpdatedEvent", "is_node_result_event", ] diff --git a/api/dify_graph/graph_events/node.py b/api/dify_graph/graph_events/node.py index f9a026b5e2..3b880f5f6f 100644 --- a/api/dify_graph/graph_events/node.py +++ b/api/dify_graph/graph_events/node.py @@ -5,6 +5,7 @@ from typing import Any from pydantic import Field from dify_graph.entities.pause_reason import PauseReason +from dify_graph.variables.variables import Variable from .base import GraphNodeEventBase @@ -39,6 +40,12 @@ class NodeRunSucceededEvent(GraphNodeEventBase): finished_at: datetime | None = Field(default=None, description="node finish time") +class NodeRunVariableUpdatedEvent(GraphNodeEventBase): + """Request that the engine apply a variable update before downstream observers continue.""" + + variable: Variable = Field(..., description="Updated variable payload to apply.") + + class NodeRunFailedEvent(GraphNodeEventBase): error: str = Field(..., description="error") start_at: datetime = Field(..., description="node start time") diff --git a/api/dify_graph/model_runtime/callbacks/base_callback.py b/api/dify_graph/model_runtime/callbacks/base_callback.py index 20faf3d6cd..cc31f98b8b 100644 --- a/api/dify_graph/model_runtime/callbacks/base_callback.py +++ b/api/dify_graph/model_runtime/callbacks/base_callback.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from collections.abc import Sequence +from collections.abc import Mapping, Sequence from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk from dify_graph.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool @@ -33,7 +33,7 @@ class Callback(ABC): tools: list[PromptMessageTool] | None = None, stop: Sequence[str] | None = None, stream: bool = True, - user: str | None = None, + invocation_context: Mapping[str, object] | None = None, ): """ Before invoke callback @@ -46,7 +46,7 @@ class Callback(ABC): :param tools: tools for tool calling :param stop: stop words :param stream: is stream response - :param user: unique user id + :param invocation_context: opaque request metadata for the current invocation """ raise NotImplementedError() @@ -62,7 +62,7 @@ class Callback(ABC): tools: list[PromptMessageTool] | None = None, stop: Sequence[str] | None = None, stream: bool = True, - user: str | None = None, + invocation_context: Mapping[str, object] | None = None, ): """ On new chunk callback @@ -76,7 +76,7 @@ class Callback(ABC): :param tools: tools for tool calling :param stop: stop words :param stream: is stream response - :param user: unique user id + :param invocation_context: opaque request metadata for the current invocation """ raise NotImplementedError() @@ -92,7 +92,7 @@ class Callback(ABC): tools: list[PromptMessageTool] | None = None, stop: Sequence[str] | None = None, stream: bool = True, - user: str | None = None, + invocation_context: Mapping[str, object] | None = None, ): """ After invoke callback @@ -106,7 +106,7 @@ class Callback(ABC): :param tools: tools for tool calling :param stop: stop words :param stream: is stream response - :param user: unique user id + :param invocation_context: opaque request metadata for the current invocation """ raise NotImplementedError() @@ -122,7 +122,7 @@ class Callback(ABC): tools: list[PromptMessageTool] | None = None, stop: Sequence[str] | None = None, stream: bool = True, - user: str | None = None, + invocation_context: Mapping[str, object] | None = None, ): """ Invoke error callback @@ -136,7 +136,7 @@ class Callback(ABC): :param tools: tools for tool calling :param stop: stop words :param stream: is stream response - :param user: unique user id + :param invocation_context: opaque request metadata for the current invocation """ raise NotImplementedError() diff --git a/api/dify_graph/model_runtime/callbacks/logging_callback.py b/api/dify_graph/model_runtime/callbacks/logging_callback.py index 49b9ab27eb..29c57ed3ab 100644 --- a/api/dify_graph/model_runtime/callbacks/logging_callback.py +++ b/api/dify_graph/model_runtime/callbacks/logging_callback.py @@ -1,7 +1,7 @@ import json import logging import sys -from collections.abc import Sequence +from collections.abc import Mapping, Sequence from typing import cast from dify_graph.model_runtime.callbacks.base_callback import Callback @@ -23,7 +23,7 @@ class LoggingCallback(Callback): tools: list[PromptMessageTool] | None = None, stop: Sequence[str] | None = None, stream: bool = True, - user: str | None = None, + invocation_context: Mapping[str, object] | None = None, ): """ Before invoke callback @@ -36,7 +36,7 @@ class LoggingCallback(Callback): :param tools: tools for tool calling :param stop: stop words :param stream: is stream response - :param user: unique user id + :param invocation_context: opaque request metadata for the current invocation """ self.print_text("\n[on_llm_before_invoke]\n", color="blue") self.print_text(f"Model: {model}\n", color="blue") @@ -54,8 +54,8 @@ class LoggingCallback(Callback): self.print_text(f"Stream: {stream}\n", color="blue") - if user: - self.print_text(f"User: {user}\n", color="blue") + if invocation_context: + self.print_text(f"Invocation context: {dict(invocation_context)}\n", color="blue") self.print_text("Prompt messages:\n", color="blue") for prompt_message in prompt_messages: @@ -79,7 +79,7 @@ class LoggingCallback(Callback): tools: list[PromptMessageTool] | None = None, stop: Sequence[str] | None = None, stream: bool = True, - user: str | None = None, + invocation_context: Mapping[str, object] | None = None, ): """ On new chunk callback @@ -93,8 +93,9 @@ class LoggingCallback(Callback): :param tools: tools for tool calling :param stop: stop words :param stream: is stream response - :param user: unique user id + :param invocation_context: opaque request metadata for the current invocation """ + _ = invocation_context sys.stdout.write(cast(str, chunk.delta.message.content)) sys.stdout.flush() @@ -109,7 +110,7 @@ class LoggingCallback(Callback): tools: list[PromptMessageTool] | None = None, stop: Sequence[str] | None = None, stream: bool = True, - user: str | None = None, + invocation_context: Mapping[str, object] | None = None, ): """ After invoke callback @@ -123,8 +124,9 @@ class LoggingCallback(Callback): :param tools: tools for tool calling :param stop: stop words :param stream: is stream response - :param user: unique user id + :param invocation_context: opaque request metadata for the current invocation """ + _ = invocation_context self.print_text("\n[on_llm_after_invoke]\n", color="yellow") self.print_text(f"Content: {result.message.content}\n", color="yellow") @@ -150,7 +152,7 @@ class LoggingCallback(Callback): tools: list[PromptMessageTool] | None = None, stop: Sequence[str] | None = None, stream: bool = True, - user: str | None = None, + invocation_context: Mapping[str, object] | None = None, ): """ Invoke error callback @@ -164,7 +166,8 @@ class LoggingCallback(Callback): :param tools: tools for tool calling :param stop: stop words :param stream: is stream response - :param user: unique user id + :param invocation_context: opaque request metadata for the current invocation """ + _ = invocation_context self.print_text("\n[on_llm_invoke_error]\n", color="red") logger.exception(ex) diff --git a/api/dify_graph/model_runtime/model_providers/__base/large_language_model.py b/api/dify_graph/model_runtime/model_providers/__base/large_language_model.py index 5caac13c15..21c07cf5bb 100644 --- a/api/dify_graph/model_runtime/model_providers/__base/large_language_model.py +++ b/api/dify_graph/model_runtime/model_providers/__base/large_language_model.py @@ -1,7 +1,7 @@ import logging import time import uuid -from collections.abc import Callable, Generator, Iterator, Sequence +from collections.abc import Callable, Generator, Iterator, Mapping, Sequence from typing import Union from dify_graph.model_runtime.callbacks.base_callback import Callback @@ -320,7 +320,7 @@ class LargeLanguageModel(AIModel): tools: list[PromptMessageTool] | None = None, stop: Sequence[str] | None = None, stream: bool = True, - user: str | None = None, + invocation_context: Mapping[str, object] | None = None, callbacks: list[Callback] | None = None, ) -> Generator[LLMResultChunk, None, None]: """ @@ -362,7 +362,7 @@ class LargeLanguageModel(AIModel): tools=tools, stop=stop, stream=stream, - user=user, + invocation_context=invocation_context, callbacks=callbacks, ) @@ -393,7 +393,7 @@ class LargeLanguageModel(AIModel): tools=tools, stop=stop, stream=stream, - user=user, + invocation_context=invocation_context, callbacks=callbacks, ) @@ -474,7 +474,7 @@ class LargeLanguageModel(AIModel): tools: list[PromptMessageTool] | None = None, stop: Sequence[str] | None = None, stream: bool = True, - user: str | None = None, + invocation_context: Mapping[str, object] | None = None, callbacks: list[Callback] | None = None, ): """ @@ -487,7 +487,7 @@ class LargeLanguageModel(AIModel): :param tools: tools for tool calling :param stop: stop words :param stream: is stream response - :param user: unique user id + :param invocation_context: opaque request metadata for the current invocation :param callbacks: callbacks """ _run_callbacks( @@ -502,7 +502,7 @@ class LargeLanguageModel(AIModel): tools=tools, stop=stop, stream=stream, - user=user, + invocation_context=invocation_context, ), ) @@ -516,7 +516,7 @@ class LargeLanguageModel(AIModel): tools: list[PromptMessageTool] | None = None, stop: Sequence[str] | None = None, stream: bool = True, - user: str | None = None, + invocation_context: Mapping[str, object] | None = None, callbacks: list[Callback] | None = None, ): """ @@ -530,7 +530,7 @@ class LargeLanguageModel(AIModel): :param tools: tools for tool calling :param stop: stop words :param stream: is stream response - :param user: unique user id + :param invocation_context: opaque request metadata for the current invocation """ _run_callbacks( callbacks, @@ -545,7 +545,7 @@ class LargeLanguageModel(AIModel): tools=tools, stop=stop, stream=stream, - user=user, + invocation_context=invocation_context, ), ) @@ -559,7 +559,7 @@ class LargeLanguageModel(AIModel): tools: list[PromptMessageTool] | None = None, stop: Sequence[str] | None = None, stream: bool = True, - user: str | None = None, + invocation_context: Mapping[str, object] | None = None, callbacks: list[Callback] | None = None, ): """ @@ -573,7 +573,7 @@ class LargeLanguageModel(AIModel): :param tools: tools for tool calling :param stop: stop words :param stream: is stream response - :param user: unique user id + :param invocation_context: opaque request metadata for the current invocation :param callbacks: callbacks """ _run_callbacks( @@ -589,7 +589,7 @@ class LargeLanguageModel(AIModel): tools=tools, stop=stop, stream=stream, - user=user, + invocation_context=invocation_context, ), ) @@ -603,7 +603,7 @@ class LargeLanguageModel(AIModel): tools: list[PromptMessageTool] | None = None, stop: Sequence[str] | None = None, stream: bool = True, - user: str | None = None, + invocation_context: Mapping[str, object] | None = None, callbacks: list[Callback] | None = None, ): """ @@ -617,7 +617,7 @@ class LargeLanguageModel(AIModel): :param tools: tools for tool calling :param stop: stop words :param stream: is stream response - :param user: unique user id + :param invocation_context: opaque request metadata for the current invocation :param callbacks: callbacks """ _run_callbacks( @@ -633,6 +633,6 @@ class LargeLanguageModel(AIModel): tools=tools, stop=stop, stream=stream, - user=user, + invocation_context=invocation_context, ), ) diff --git a/api/dify_graph/model_runtime/utils/encoders.py b/api/dify_graph/model_runtime/utils/encoders.py index c85152463e..13abf74767 100644 --- a/api/dify_graph/model_runtime/utils/encoders.py +++ b/api/dify_graph/model_runtime/utils/encoders.py @@ -1,7 +1,7 @@ import dataclasses import datetime from collections import defaultdict, deque -from collections.abc import Callable +from collections.abc import Callable, Sequence from decimal import Decimal from enum import Enum from ipaddress import IPv4Address, IPv4Interface, IPv4Network, IPv6Address, IPv6Interface, IPv6Network @@ -99,7 +99,7 @@ def jsonable_encoder( exclude_defaults: bool = False, exclude_none: bool = False, custom_encoder: dict[Any, Callable[[Any], Any]] | None = None, - sqlalchemy_safe: bool = True, + excluded_key_prefixes: Sequence[str] = (), ) -> Any: custom_encoder = custom_encoder or {} if custom_encoder: @@ -126,7 +126,7 @@ def jsonable_encoder( obj_dict, exclude_none=exclude_none, exclude_defaults=exclude_defaults, - sqlalchemy_safe=sqlalchemy_safe, + excluded_key_prefixes=excluded_key_prefixes, ) if dataclasses.is_dataclass(obj): # Ensure obj is a dataclass instance, not a dataclass type @@ -139,7 +139,7 @@ def jsonable_encoder( exclude_defaults=exclude_defaults, exclude_none=exclude_none, custom_encoder=custom_encoder, - sqlalchemy_safe=sqlalchemy_safe, + excluded_key_prefixes=excluded_key_prefixes, ) if isinstance(obj, Enum): return obj.value @@ -152,26 +152,28 @@ def jsonable_encoder( if isinstance(obj, dict): encoded_dict = {} for key, value in obj.items(): - if (not sqlalchemy_safe or (not isinstance(key, str)) or (not key.startswith("_sa"))) and ( - value is not None or not exclude_none - ): - encoded_key = jsonable_encoder( - key, - by_alias=by_alias, - exclude_unset=exclude_unset, - exclude_none=exclude_none, - custom_encoder=custom_encoder, - sqlalchemy_safe=sqlalchemy_safe, - ) - encoded_value = jsonable_encoder( - value, - by_alias=by_alias, - exclude_unset=exclude_unset, - exclude_none=exclude_none, - custom_encoder=custom_encoder, - sqlalchemy_safe=sqlalchemy_safe, - ) - encoded_dict[encoded_key] = encoded_value + if isinstance(key, str) and any(key.startswith(prefix) for prefix in excluded_key_prefixes): + continue + if value is None and exclude_none: + continue + + encoded_key = jsonable_encoder( + key, + by_alias=by_alias, + exclude_unset=exclude_unset, + exclude_none=exclude_none, + custom_encoder=custom_encoder, + excluded_key_prefixes=excluded_key_prefixes, + ) + encoded_value = jsonable_encoder( + value, + by_alias=by_alias, + exclude_unset=exclude_unset, + exclude_none=exclude_none, + custom_encoder=custom_encoder, + excluded_key_prefixes=excluded_key_prefixes, + ) + encoded_dict[encoded_key] = encoded_value return encoded_dict if isinstance(obj, list | set | frozenset | GeneratorType | tuple | deque): encoded_list = [] @@ -184,7 +186,7 @@ def jsonable_encoder( exclude_defaults=exclude_defaults, exclude_none=exclude_none, custom_encoder=custom_encoder, - sqlalchemy_safe=sqlalchemy_safe, + excluded_key_prefixes=excluded_key_prefixes, ) ) return encoded_list @@ -212,5 +214,5 @@ def jsonable_encoder( exclude_defaults=exclude_defaults, exclude_none=exclude_none, custom_encoder=custom_encoder, - sqlalchemy_safe=sqlalchemy_safe, + excluded_key_prefixes=excluded_key_prefixes, ) diff --git a/api/dify_graph/node_events/__init__.py b/api/dify_graph/node_events/__init__.py index a9bef8f9a2..a2bbf9f176 100644 --- a/api/dify_graph/node_events/__init__.py +++ b/api/dify_graph/node_events/__init__.py @@ -21,6 +21,7 @@ from .node import ( RunRetryEvent, StreamChunkEvent, StreamCompletedEvent, + VariableUpdatedEvent, ) __all__ = [ @@ -43,4 +44,5 @@ __all__ = [ "RunRetryEvent", "StreamChunkEvent", "StreamCompletedEvent", + "VariableUpdatedEvent", ] diff --git a/api/dify_graph/node_events/node.py b/api/dify_graph/node_events/node.py index 2e3973b8fa..70aecc58b0 100644 --- a/api/dify_graph/node_events/node.py +++ b/api/dify_graph/node_events/node.py @@ -8,6 +8,7 @@ from dify_graph.entities.pause_reason import PauseReason from dify_graph.file import File from dify_graph.model_runtime.entities.llm_entities import LLMUsage from dify_graph.node_events import NodeRunResult +from dify_graph.variables.variables import Variable from .base import NodeEventBase @@ -45,6 +46,12 @@ class StreamCompletedEvent(NodeEventBase): node_run_result: NodeRunResult = Field(..., description="run result") +class VariableUpdatedEvent(NodeEventBase): + """Notify the engine that a single variable should be applied to the shared pool.""" + + variable: Variable = Field(..., description="Updated variable payload to apply.") + + class PauseRequestedEvent(NodeEventBase): reason: PauseReason = Field(..., description="pause reason") diff --git a/api/dify_graph/nodes/base/node.py b/api/dify_graph/nodes/base/node.py index 7ecea71356..a6684319e5 100644 --- a/api/dify_graph/nodes/base/node.py +++ b/api/dify_graph/nodes/base/node.py @@ -4,6 +4,7 @@ import logging import operator from abc import abstractmethod from collections.abc import Generator, Mapping, Sequence +from datetime import UTC, datetime from functools import singledispatchmethod from types import MappingProxyType from typing import Any, ClassVar, Generic, TypeVar, cast, get_args, get_origin @@ -38,6 +39,7 @@ from dify_graph.graph_events import ( NodeRunStartedEvent, NodeRunStreamChunkEvent, NodeRunSucceededEvent, + NodeRunVariableUpdatedEvent, ) from dify_graph.node_events import ( AgentLogEvent, @@ -57,9 +59,9 @@ from dify_graph.node_events import ( RunRetrieverResourceEvent, StreamChunkEvent, StreamCompletedEvent, + VariableUpdatedEvent, ) from dify_graph.runtime import GraphRuntimeState -from dify_graph.utils.datetime_utils import naive_utc_now NodeDataT = TypeVar("NodeDataT", bound=BaseNodeData) _MISSING_RUN_CONTEXT_VALUE = object() @@ -246,7 +248,7 @@ class Node(Generic[NodeDataT]): self._node_id = node_id self._node_execution_id: str = "" - self._start_at = naive_utc_now() + self._start_at = datetime.now(UTC).replace(tzinfo=None) self._node_data = self.validate_node_data(config["data"]) @@ -328,7 +330,7 @@ class Node(Generic[NodeDataT]): def run(self) -> Generator[GraphNodeEventBase, None, None]: execution_id = self.ensure_execution_id() - self._start_at = naive_utc_now() + self._start_at = datetime.now(UTC).replace(tzinfo=None) # Create and push start event with required fields start_event = NodeRunStartedEvent( @@ -601,6 +603,15 @@ class Node(Generic[NodeDataT]): f"Node {self._node_id} does not support status {event.node_run_result.status}" ) + @_dispatch.register + def _(self, event: VariableUpdatedEvent) -> NodeRunVariableUpdatedEvent: + return NodeRunVariableUpdatedEvent( + id=self.execution_id, + node_id=self._node_id, + node_type=self.node_type, + variable=event.variable, + ) + @_dispatch.register def _(self, event: PauseRequestedEvent) -> NodeRunPauseRequestedEvent: return NodeRunPauseRequestedEvent( diff --git a/api/dify_graph/nodes/http_request/executor.py b/api/dify_graph/nodes/http_request/executor.py index 892b0fc688..5798e563d9 100644 --- a/api/dify_graph/nodes/http_request/executor.py +++ b/api/dify_graph/nodes/http_request/executor.py @@ -246,7 +246,7 @@ class Executor: files: dict[str, list[tuple[str | None, bytes, str]]] = {} for key, files_in_segment in files_list: for file in files_in_segment: - if file.related_id is not None or ( + if file.reference is not None or ( file.transfer_method == FileTransferMethod.REMOTE_URL and file.remote_url is not None ): file_tuple = ( diff --git a/api/dify_graph/nodes/http_request/node.py b/api/dify_graph/nodes/http_request/node.py index d1cfbb0b44..6f13085b36 100644 --- a/api/dify_graph/nodes/http_request/node.py +++ b/api/dify_graph/nodes/http_request/node.py @@ -242,7 +242,6 @@ class HttpRequestNode(Node[HttpRequestNodeData]): tool_file_manager = self._tool_file_manager_factory() tool_file = tool_file_manager.create_file_by_raw( - conversation_id=None, file_binary=content, mimetype=mime_type, ) diff --git a/api/dify_graph/nodes/human_input/entities.py b/api/dify_graph/nodes/human_input/entities.py index 2a33b4a0a8..675c1a0a3c 100644 --- a/api/dify_graph/nodes/human_input/entities.py +++ b/api/dify_graph/nodes/human_input/entities.py @@ -1,230 +1,27 @@ -""" -Human Input node entities. +"""Human Input node entities. + +The graph package owns the workflow-facing form schema and keeps it transportable +across runtimes. Dify-specific delivery surface and recipient translation stay +outside `dify_graph`. """ import re -import uuid from collections.abc import Mapping, Sequence from datetime import datetime, timedelta -from typing import Annotated, Any, ClassVar, Literal, Self +from typing import Any, Self -import bleach -import markdown from pydantic import BaseModel, Field, field_validator, model_validator from dify_graph.entities.base_node_data import BaseNodeData from dify_graph.enums import BuiltinNodeTypes, NodeType from dify_graph.nodes.base.variable_template_parser import VariableTemplateParser -from dify_graph.runtime import VariablePool from dify_graph.variables.consts import SELECTORS_LENGTH -from .enums import ButtonStyle, DeliveryMethodType, EmailRecipientType, FormInputType, PlaceholderType, TimeoutUnit +from .enums import ButtonStyle, FormInputType, PlaceholderType, TimeoutUnit _OUTPUT_VARIABLE_PATTERN = re.compile(r"\{\{#\$output\.(?P[a-zA-Z_][a-zA-Z0-9_]{0,29})#\}\}") -class _WebAppDeliveryConfig(BaseModel): - """Configuration for webapp delivery method.""" - - pass # Empty for webapp delivery - - -class MemberRecipient(BaseModel): - """Member recipient for email delivery.""" - - type: Literal[EmailRecipientType.MEMBER] = EmailRecipientType.MEMBER - user_id: str - - -class ExternalRecipient(BaseModel): - """External recipient for email delivery.""" - - type: Literal[EmailRecipientType.EXTERNAL] = EmailRecipientType.EXTERNAL - email: str - - -EmailRecipient = Annotated[MemberRecipient | ExternalRecipient, Field(discriminator="type")] - - -class EmailRecipients(BaseModel): - """Email recipients configuration.""" - - # When true, recipients are the union of all workspace members and external items. - # Member items are ignored because they are already covered by the workspace scope. - # De-duplication is applied by email, with member recipients taking precedence. - whole_workspace: bool = False - items: list[EmailRecipient] = Field(default_factory=list) - - -class EmailDeliveryConfig(BaseModel): - """Configuration for email delivery method.""" - - URL_PLACEHOLDER: ClassVar[str] = "{{#url#}}" - _SUBJECT_NEWLINE_PATTERN: ClassVar[re.Pattern[str]] = re.compile(r"[\r\n]+") - _ALLOWED_HTML_TAGS: ClassVar[list[str]] = [ - "a", - "blockquote", - "br", - "code", - "em", - "h1", - "h2", - "h3", - "h4", - "h5", - "h6", - "hr", - "li", - "ol", - "p", - "pre", - "strong", - "table", - "tbody", - "td", - "th", - "thead", - "tr", - "ul", - ] - _ALLOWED_HTML_ATTRIBUTES: ClassVar[dict[str, list[str]]] = { - "a": ["href", "title"], - "td": ["align"], - "th": ["align"], - } - _ALLOWED_PROTOCOLS: ClassVar[list[str]] = ["http", "https", "mailto"] - - recipients: EmailRecipients - - # the subject of email - subject: str - - # Body is the content of email.It may contain the speical placeholder `{{#url#}}`, which - # represent the url to submit the form. - # - # It may also reference the output variable of the previous node with the syntax - # `{{#.#}}`. - body: str - debug_mode: bool = False - - def with_debug_recipient(self, user_id: str | None) -> "EmailDeliveryConfig": - if user_id is None: - debug_recipients = EmailRecipients(whole_workspace=False, items=[]) - return self.model_copy(update={"recipients": debug_recipients}) - debug_recipients = EmailRecipients(whole_workspace=False, items=[MemberRecipient(user_id=user_id)]) - return self.model_copy(update={"recipients": debug_recipients}) - - @classmethod - def replace_url_placeholder(cls, body: str, url: str | None) -> str: - """Replace the url placeholder with provided value.""" - return body.replace(cls.URL_PLACEHOLDER, url or "") - - @classmethod - def render_body_template( - cls, - *, - body: str, - url: str | None, - variable_pool: VariablePool | None = None, - ) -> str: - """Render email body by replacing placeholders with runtime values.""" - templated_body = cls.replace_url_placeholder(body, url) - if variable_pool is None: - return templated_body - return variable_pool.convert_template(templated_body).text - - @classmethod - def render_markdown_body(cls, body: str) -> str: - """Render markdown to safe HTML for email delivery.""" - sanitized_markdown = bleach.clean( - body, - tags=[], - attributes={}, - strip=True, - strip_comments=True, - ) - rendered_html = markdown.markdown( - sanitized_markdown, - extensions=["nl2br", "tables"], - extension_configs={"tables": {"use_align_attribute": True}}, - ) - return bleach.clean( - rendered_html, - tags=cls._ALLOWED_HTML_TAGS, - attributes=cls._ALLOWED_HTML_ATTRIBUTES, - protocols=cls._ALLOWED_PROTOCOLS, - strip=True, - strip_comments=True, - ) - - @classmethod - def sanitize_subject(cls, subject: str) -> str: - """Sanitize email subject to plain text and prevent CRLF injection.""" - sanitized_subject = bleach.clean( - subject, - tags=[], - attributes={}, - strip=True, - strip_comments=True, - ) - sanitized_subject = cls._SUBJECT_NEWLINE_PATTERN.sub(" ", sanitized_subject) - return " ".join(sanitized_subject.split()) - - -class _DeliveryMethodBase(BaseModel): - """Base delivery method configuration.""" - - enabled: bool = True - id: uuid.UUID = Field(default_factory=uuid.uuid4) - - def extract_variable_selectors(self) -> Sequence[Sequence[str]]: - return () - - -class WebAppDeliveryMethod(_DeliveryMethodBase): - """Webapp delivery method configuration.""" - - type: Literal[DeliveryMethodType.WEBAPP] = DeliveryMethodType.WEBAPP - # The config field is not used currently. - config: _WebAppDeliveryConfig = Field(default_factory=_WebAppDeliveryConfig) - - -class EmailDeliveryMethod(_DeliveryMethodBase): - """Email delivery method configuration.""" - - type: Literal[DeliveryMethodType.EMAIL] = DeliveryMethodType.EMAIL - config: EmailDeliveryConfig - - def extract_variable_selectors(self) -> Sequence[Sequence[str]]: - variable_template_parser = VariableTemplateParser(template=self.config.body) - selectors: list[Sequence[str]] = [] - for variable_selector in variable_template_parser.extract_variable_selectors(): - value_selector = list(variable_selector.value_selector) - if len(value_selector) < SELECTORS_LENGTH: - continue - selectors.append(value_selector[:SELECTORS_LENGTH]) - return selectors - - -DeliveryChannelConfig = Annotated[WebAppDeliveryMethod | EmailDeliveryMethod, Field(discriminator="type")] - - -def apply_debug_email_recipient( - method: DeliveryChannelConfig, - *, - enabled: bool, - user_id: str | None, -) -> DeliveryChannelConfig: - if not enabled: - return method - if not isinstance(method, EmailDeliveryMethod): - return method - if not method.config.debug_mode: - return method - debug_config = method.config.with_debug_recipient(user_id) - return method.model_copy(update={"config": debug_config}) - - class FormInputDefault(BaseModel): """Default configuration for form inputs.""" @@ -288,7 +85,6 @@ class HumanInputNodeData(BaseNodeData): """Human Input node data.""" type: NodeType = BuiltinNodeTypes.HUMAN_INPUT - delivery_methods: list[DeliveryChannelConfig] = Field(default_factory=list) form_content: str = "" inputs: list[FormInput] = Field(default_factory=list) user_actions: list[UserAction] = Field(default_factory=list) @@ -317,14 +113,6 @@ class HumanInputNodeData(BaseNodeData): seen_ids.add(action_id) return user_actions - def is_webapp_enabled(self) -> bool: - for dm in self.delivery_methods: - if not dm.enabled: - continue - if dm.type == DeliveryMethodType.WEBAPP: - return True - return False - def expiration_time(self, start_time: datetime) -> datetime: if self.timeout_unit == TimeoutUnit.HOUR: return start_time + timedelta(hours=self.timeout) @@ -353,10 +141,6 @@ class HumanInputNodeData(BaseNodeData): _add_variable_selectors( [selector.value_selector for selector in form_template_parser.extract_variable_selectors()] ) - for delivery_method in self.delivery_methods: - if not delivery_method.enabled: - continue - _add_variable_selectors(delivery_method.extract_variable_selectors()) for input in self.inputs: default_value = input.default diff --git a/api/dify_graph/nodes/human_input/enums.py b/api/dify_graph/nodes/human_input/enums.py index da85728828..3fb0ab4499 100644 --- a/api/dify_graph/nodes/human_input/enums.py +++ b/api/dify_graph/nodes/human_input/enums.py @@ -25,16 +25,6 @@ class HumanInputFormKind(enum.StrEnum): DELIVERY_TEST = enum.auto() # Form created for delivery tests. -class DeliveryMethodType(enum.StrEnum): - """Delivery method types for human input forms.""" - - # WEBAPP controls whether the form is delivered to the web app. It not only controls - # the standalone web app, but also controls the installed apps in the console. - WEBAPP = enum.auto() - - EMAIL = enum.auto() - - class ButtonStyle(enum.StrEnum): """Button styles for user actions.""" @@ -63,10 +53,3 @@ class PlaceholderType(enum.StrEnum): VARIABLE = enum.auto() CONSTANT = enum.auto() - - -class EmailRecipientType(enum.StrEnum): - """Email recipient types.""" - - MEMBER = enum.auto() - EXTERNAL = enum.auto() diff --git a/api/dify_graph/nodes/human_input/human_input_node.py b/api/dify_graph/nodes/human_input/human_input_node.py index 575f918237..08449e5b20 100644 --- a/api/dify_graph/nodes/human_input/human_input_node.py +++ b/api/dify_graph/nodes/human_input/human_input_node.py @@ -1,6 +1,7 @@ import json import logging from collections.abc import Generator, Mapping, Sequence +from datetime import UTC, datetime from typing import TYPE_CHECKING, Any from dify_graph.entities.graph_config import NodeConfigDict @@ -15,17 +16,11 @@ from dify_graph.node_events import ( from dify_graph.node_events.base import NodeEventBase from dify_graph.node_events.node import StreamCompletedEvent from dify_graph.nodes.base.node import Node -from dify_graph.nodes.runtime import HumanInputNodeRuntimeProtocol -from dify_graph.repositories.human_input_form_repository import ( - FormCreateParams, - HumanInputFormEntity, - HumanInputFormRepository, -) -from dify_graph.utils.datetime_utils import naive_utc_now +from dify_graph.nodes.runtime import HumanInputFormStateProtocol, HumanInputNodeRuntimeProtocol from dify_graph.workflow_type_encoder import WorkflowRuntimeTypeConverter -from .entities import DeliveryChannelConfig, HumanInputNodeData -from .enums import DeliveryMethodType, HumanInputFormStatus, PlaceholderType +from .entities import HumanInputNodeData +from .enums import HumanInputFormStatus, PlaceholderType if TYPE_CHECKING: from dify_graph.entities.graph_init_params import GraphInitParams @@ -33,8 +28,6 @@ if TYPE_CHECKING: _SELECTED_BRANCH_KEY = "selected_branch" -_INVOKE_FROM_DEBUGGER = "debugger" -_INVOKE_FROM_EXPLORE = "explore" logger = logging.getLogger(__name__) @@ -57,7 +50,6 @@ class HumanInputNode(Node[HumanInputNodeData]): ) _node_data: HumanInputNodeData - _form_repository: HumanInputFormRepository _OUTPUT_FIELD_ACTION_ID = "__action_id" _OUTPUT_FIELD_RENDERED_CONTENT = "__rendered_content" _TIMEOUT_HANDLE = _TIMEOUT_ACTION_ID = "__timeout" @@ -68,7 +60,6 @@ class HumanInputNode(Node[HumanInputNodeData]): config: NodeConfigDict, graph_init_params: "GraphInitParams", graph_runtime_state: "GraphRuntimeState", - form_repository: HumanInputFormRepository, runtime: HumanInputNodeRuntimeProtocol | None = None, ) -> None: super().__init__( @@ -77,7 +68,6 @@ class HumanInputNode(Node[HumanInputNodeData]): graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, ) - self._form_repository = form_repository if runtime is None: raise ValueError("runtime is required") self._runtime = runtime @@ -133,13 +123,7 @@ class HumanInputNode(Node[HumanInputNodeData]): return None - @property - def _workflow_execution_id(self) -> str: - workflow_exec_id = self.graph_runtime_state.variable_pool.system_variables.workflow_execution_id - assert workflow_exec_id is not None - return workflow_exec_id - - def _form_to_pause_event(self, form_entity: HumanInputFormEntity): + def _form_to_pause_event(self, form_entity: HumanInputFormStateProtocol): required_event = self._human_input_required_event(form_entity) pause_requested_event = PauseRequestedEvent(reason=required_event) return pause_requested_event @@ -162,45 +146,16 @@ class HumanInputNode(Node[HumanInputNodeData]): return resolved_defaults - def _should_require_console_recipient(self) -> bool: - invoke_from = self._invoke_from_value() - if invoke_from == _INVOKE_FROM_DEBUGGER: - return True - if invoke_from == _INVOKE_FROM_EXPLORE: - return self._node_data.is_webapp_enabled() - return False - - def _display_in_ui(self) -> bool: - if self._invoke_from_value() == _INVOKE_FROM_DEBUGGER: - return True - return self._node_data.is_webapp_enabled() - - def _effective_delivery_methods(self) -> Sequence[DeliveryChannelConfig]: - invoke_from = self._invoke_from_value() - enabled_methods = [method for method in self._node_data.delivery_methods if method.enabled] - if invoke_from in {_INVOKE_FROM_DEBUGGER, _INVOKE_FROM_EXPLORE}: - enabled_methods = [method for method in enabled_methods if method.type != DeliveryMethodType.WEBAPP] - return self._runtime.apply_delivery_runtime(methods=enabled_methods) - - def _invoke_from_value(self) -> str: - return self._runtime.invoke_source() - - def _human_input_required_event(self, form_entity: HumanInputFormEntity) -> HumanInputRequired: + def _human_input_required_event(self, form_entity: HumanInputFormStateProtocol) -> HumanInputRequired: node_data = self._node_data resolved_default_values = self.resolve_default_values() - display_in_ui = self._display_in_ui() - form_token = form_entity.web_app_token - if display_in_ui and form_token is None: - raise AssertionError("Form token should be available for UI execution.") return HumanInputRequired( form_id=form_entity.id, form_content=form_entity.rendered_content, inputs=node_data.inputs, actions=node_data.user_actions, - display_in_ui=display_in_ui, node_id=self.id, node_title=node_data.title, - form_token=form_token, resolved_default_values=resolved_default_values, ) @@ -211,47 +166,32 @@ class HumanInputNode(Node[HumanInputNodeData]): This method will: 1. Generate a unique form ID 2. Create form content with variable substitution - 3. Create form in database + 3. Persist the form through the configured repository 4. Send form via configured delivery methods 5. Suspend workflow execution 6. Wait for form submission to resume """ - repo = self._form_repository - form = repo.get_form(self._workflow_execution_id, self.id) + form = self._runtime.get_form(node_id=self.id) if form is None: - display_in_ui = self._display_in_ui() - params = FormCreateParams( - workflow_execution_id=self._workflow_execution_id, + form_entity = self._runtime.create_form( node_id=self.id, - form_config=self._node_data, + node_data=self._node_data, rendered_content=self.render_form_content_before_submission(), - delivery_methods=self._effective_delivery_methods(), - display_in_ui=display_in_ui, resolved_default_values=self.resolve_default_values(), - console_recipient_required=self._should_require_console_recipient(), - console_creator_account_id=( - self._runtime.console_actor_id() - if self._invoke_from_value() in {_INVOKE_FROM_DEBUGGER, _INVOKE_FROM_EXPLORE} - else None - ), - backstage_recipient_required=True, ) - form_entity = self._form_repository.create_form(params) - # Create human input required event logger.info( - "Human Input node suspended workflow for form. workflow_run_id=%s, node_id=%s, form_id=%s", - self.graph_runtime_state.variable_pool.system_variables.workflow_execution_id, + "Human Input node suspended workflow for form. node_id=%s, form_id=%s", self.id, form_entity.id, ) yield self._form_to_pause_event(form_entity) return - if ( - form.status in {HumanInputFormStatus.TIMEOUT, HumanInputFormStatus.EXPIRED} - or form.expiration_time <= naive_utc_now() - ): + if form.status in { + HumanInputFormStatus.TIMEOUT, + HumanInputFormStatus.EXPIRED, + } or form.expiration_time <= datetime.now(UTC).replace(tzinfo=None): yield HumanInputFormTimeoutEvent( node_title=self._node_data.title, expiration_time=form.expiration_time, diff --git a/api/dify_graph/nodes/if_else/if_else_node.py b/api/dify_graph/nodes/if_else/if_else_node.py index 7c0370e48c..6ca5684666 100644 --- a/api/dify_graph/nodes/if_else/if_else_node.py +++ b/api/dify_graph/nodes/if_else/if_else_node.py @@ -57,8 +57,8 @@ class IfElseNode(Node[IfElseNodeData]): break else: - # TODO: Update database then remove this - # Fallback to old structure if cases are not defined + # TODO: Remove this once all graph definitions use the `cases` structure. + # Fallback to the legacy node shape when `cases` are not defined. input_conditions, group_result, final_result = _should_not_use_old_function( # pyright: ignore [reportDeprecated] condition_processor=condition_processor, variable_pool=self.graph_runtime_state.variable_pool, diff --git a/api/dify_graph/nodes/iteration/iteration_node.py b/api/dify_graph/nodes/iteration/iteration_node.py index 769ef11b62..376579d6fd 100644 --- a/api/dify_graph/nodes/iteration/iteration_node.py +++ b/api/dify_graph/nodes/iteration/iteration_node.py @@ -1,12 +1,12 @@ import logging from collections.abc import Generator, Mapping, Sequence from concurrent.futures import Future, ThreadPoolExecutor, as_completed +from contextlib import AbstractContextManager, nullcontext from datetime import UTC, datetime from typing import TYPE_CHECKING, Any, NewType, cast from typing_extensions import TypeIs -from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID from dify_graph.entities.graph_config import NodeConfigDictAdapter from dify_graph.enums import ( BuiltinNodeTypes, @@ -34,7 +34,6 @@ from dify_graph.nodes.base import LLMUsageTrackingMixin from dify_graph.nodes.base.node import Node from dify_graph.nodes.iteration.entities import ErrorHandleMode, IterationNodeData from dify_graph.runtime import VariablePool -from dify_graph.utils.datetime_utils import naive_utc_now from dify_graph.variables import IntegerVariable, NoneSegment from dify_graph.variables.segments import ArrayAnySegment, ArraySegment from dify_graph.variables.variables import Variable @@ -49,7 +48,6 @@ from .exc import ( ) if TYPE_CHECKING: - from dify_graph.context import IExecutionContext from dify_graph.graph_engine import GraphEngine logger = logging.getLogger(__name__) @@ -93,7 +91,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]): self._validate_start_node() - started_at = naive_utc_now() + started_at = datetime.now(UTC).replace(tzinfo=None) iter_run_map: dict[str, float] = {} outputs: list[object] = [] usage_accumulator = [LLMUsage.empty_usage()] @@ -206,10 +204,8 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]): ) # Sync conversation variables after each iteration completes - self._sync_conversation_variables_from_snapshot( - self._extract_conversation_variable_snapshot( - variable_pool=graph_engine.graph_runtime_state.variable_pool - ) + self._sync_child_variable_snapshot( + self._extract_child_variable_snapshot(variable_pool=graph_engine.graph_runtime_state.variable_pool) ) # Accumulate usage from this iteration @@ -239,7 +235,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]): float, list[GraphNodeEventBase], object | None, - dict[str, Variable], + dict[tuple[str, str], Variable], LLMUsage, ] ], @@ -264,7 +260,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]): iteration_duration, events, output_value, - conversation_snapshot, + child_variable_snapshot, iteration_usage, ) = result @@ -280,8 +276,8 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]): usage_accumulator[0] = self._merge_usage(usage_accumulator[0], iteration_usage) - # Sync conversation variables after iteration completion - self._sync_conversation_variables_from_snapshot(conversation_snapshot) + # Sync workflow-owned variable scopes after iteration completion + self._sync_child_variable_snapshot(child_variable_snapshot) except Exception as e: # Handle errors based on error_handle_mode @@ -305,8 +301,8 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]): self, index: int, item: object, - execution_context: "IExecutionContext", - ) -> tuple[float, list[GraphNodeEventBase], object | None, dict[str, Variable], LLMUsage]: + execution_context: AbstractContextManager[object], + ) -> tuple[float, list[GraphNodeEventBase], object | None, dict[tuple[str, str], Variable], LLMUsage]: """Execute a single iteration in parallel mode and return results.""" with execution_context: iter_start_at = datetime.now(UTC).replace(tzinfo=None) @@ -325,7 +321,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]): # Get the output value from the temporary outputs list output_value = outputs_temp[0] if outputs_temp else None - conversation_snapshot = self._extract_conversation_variable_snapshot( + child_variable_snapshot = self._extract_child_variable_snapshot( variable_pool=graph_engine.graph_runtime_state.variable_pool ) iteration_duration = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds() @@ -334,15 +330,16 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]): iteration_duration, events, output_value, - conversation_snapshot, + child_variable_snapshot, graph_engine.graph_runtime_state.llm_usage, ) - def _capture_execution_context(self) -> "IExecutionContext": - """Capture current execution context for parallel iterations.""" - from dify_graph.context import capture_current_context - - return capture_current_context() + def _capture_execution_context(self) -> AbstractContextManager[object]: + """Return the application-supplied execution context for parallel iterations.""" + execution_context = self.graph_runtime_state.execution_context + if execution_context is not None: + return execution_context + return nullcontext() def _handle_iteration_success( self, @@ -516,22 +513,25 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]): return variable_mapping - def _extract_conversation_variable_snapshot(self, *, variable_pool: VariablePool) -> dict[str, Variable]: - conversation_variables = variable_pool.variable_dictionary.get(CONVERSATION_VARIABLE_NODE_ID, {}) - return {name: variable.model_copy(deep=True) for name, variable in conversation_variables.items()} + def _extract_child_variable_snapshot(self, *, variable_pool: VariablePool) -> dict[tuple[str, str], Variable]: + snapshot: dict[tuple[str, str], Variable] = {} + for node_id in self.graph_init_params.child_sync_variable_node_ids: + variables = variable_pool.variable_dictionary.get(node_id, {}) + for name, variable in variables.items(): + snapshot[(node_id, name)] = variable.model_copy(deep=True) + return snapshot - def _sync_conversation_variables_from_snapshot(self, snapshot: dict[str, Variable]) -> None: + def _sync_child_variable_snapshot(self, snapshot: dict[tuple[str, str], Variable]) -> None: parent_pool = self.graph_runtime_state.variable_pool - parent_conversations = parent_pool.variable_dictionary.get(CONVERSATION_VARIABLE_NODE_ID, {}) + for node_id in self.graph_init_params.child_sync_variable_node_ids: + current_keys = set(parent_pool.variable_dictionary.get(node_id, {})) + snapshot_keys = {name for scope_node_id, name in snapshot if scope_node_id == node_id} - current_keys = set(parent_conversations.keys()) - snapshot_keys = set(snapshot.keys()) + for removed_key in current_keys - snapshot_keys: + parent_pool.remove((node_id, removed_key)) - for removed_key in current_keys - snapshot_keys: - parent_pool.remove((CONVERSATION_VARIABLE_NODE_ID, removed_key)) - - for name, variable in snapshot.items(): - parent_pool.add((CONVERSATION_VARIABLE_NODE_ID, name), variable) + for selector, variable in snapshot.items(): + parent_pool.add(selector, variable) def _append_iteration_info_to_event( self, @@ -595,6 +595,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]): graph_config=self.graph_config, run_context=self.run_context, call_depth=self.workflow_call_depth, + child_sync_variable_node_ids=self.graph_init_params.child_sync_variable_node_ids, ) # Create a deep copy of the variable pool for each iteration variable_pool_copy = self.graph_runtime_state.variable_pool.model_copy(deep=True) diff --git a/api/dify_graph/nodes/llm/file_saver.py b/api/dify_graph/nodes/llm/file_saver.py index dac8680af1..5ae2362621 100644 --- a/api/dify_graph/nodes/llm/file_saver.py +++ b/api/dify_graph/nodes/llm/file_saver.py @@ -85,8 +85,6 @@ class FileSaverImpl(LLMFileSaver): extension_override: str | None = None, ) -> File: tool_file = self._tool_file_manager.create_file_by_raw( - # TODO(QuantumGhost): what is conversation id? - conversation_id=None, file_binary=data, mimetype=mime_type, ) @@ -100,9 +98,7 @@ class FileSaverImpl(LLMFileSaver): "extension": extension, "mime_type": mime_type, "size": len(data), - "tool_file_id": tool_file.id, - "related_id": tool_file.id, - "storage_key": tool_file.file_key, + "tool_file_id": str(tool_file.id), } ) diff --git a/api/dify_graph/nodes/llm/llm_utils.py b/api/dify_graph/nodes/llm/llm_utils.py index 8cefa1d106..a600fb4c35 100644 --- a/api/dify_graph/nodes/llm/llm_utils.py +++ b/api/dify_graph/nodes/llm/llm_utils.py @@ -1,7 +1,7 @@ from __future__ import annotations -from collections.abc import Mapping, Sequence -from typing import Any, Protocol, TypeAlias, cast +from collections.abc import Sequence +from typing import Any from dify_graph.file import FileType, file_manager from dify_graph.file.models import File @@ -20,9 +20,9 @@ from dify_graph.model_runtime.entities.message_entities import ( ) from dify_graph.model_runtime.entities.model_entities import AIModelEntity, ModelFeature, ModelPropertyKey from dify_graph.model_runtime.memory import PromptMessageMemory -from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from dify_graph.nodes.base.entities import VariableSelector from dify_graph.runtime import VariablePool +from dify_graph.template_rendering import Jinja2TemplateRenderer from dify_graph.variables import ArrayFileSegment, FileSegment from dify_graph.variables.segments import ArrayAnySegment, NoneSegment @@ -33,36 +33,11 @@ from .exc import ( NoPromptFoundError, TemplateTypeNotSupportError, ) -from .protocols import TemplateRenderer from .runtime_protocols import PreparedLLMProtocol -class _LegacyModelInstance(Protocol): - model_type_instance: object - model_name: str - credentials: object - parameters: Mapping[str, Any] - - def get_llm_num_tokens(self, prompt_messages: Sequence[PromptMessage]) -> int: ... - - -PreparedModelInstance: TypeAlias = PreparedLLMProtocol | _LegacyModelInstance - - -def fetch_model_schema(*, model_instance: PreparedModelInstance) -> AIModelEntity: - model_schema: AIModelEntity | None - get_model_schema = getattr(model_instance, "get_model_schema", None) - if callable(get_model_schema): - model_schema = cast(PreparedLLMProtocol, model_instance).get_model_schema() - else: - legacy_model_instance = cast(_LegacyModelInstance, model_instance) - credentials = legacy_model_instance.credentials - if isinstance(credentials, Mapping): - credentials = dict(credentials) - model_schema = cast(LargeLanguageModel, legacy_model_instance.model_type_instance).get_model_schema( - legacy_model_instance.model_name, - credentials, - ) +def fetch_model_schema(*, model_instance: PreparedLLMProtocol) -> AIModelEntity: + model_schema = model_instance.get_model_schema() if not model_schema: raise ValueError(f"Model schema not found for {getattr(model_instance, 'model_name', 'unknown model')}") return model_schema @@ -137,7 +112,7 @@ def fetch_prompt_messages( sys_files: Sequence[File], context: str | None = None, memory: PromptMessageMemory | None = None, - model_instance: PreparedModelInstance, + model_instance: PreparedLLMProtocol, prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate, stop: Sequence[str] | None = None, memory_config: MemoryConfig | None = None, @@ -146,7 +121,7 @@ def fetch_prompt_messages( variable_pool: VariablePool, jinja2_variables: Sequence[VariableSelector], context_files: list[File] | None = None, - template_renderer: TemplateRenderer | None = None, + template_renderer: Jinja2TemplateRenderer | None = None, ) -> tuple[Sequence[PromptMessage], Sequence[str] | None]: prompt_messages: list[PromptMessage] = [] model_schema = fetch_model_schema(model_instance=model_instance) @@ -302,7 +277,7 @@ def handle_list_messages( jinja2_variables: Sequence[VariableSelector], variable_pool: VariablePool, vision_detail_config: ImagePromptMessageContent.DETAIL, - template_renderer: TemplateRenderer | None = None, + template_renderer: Jinja2TemplateRenderer | None = None, ) -> Sequence[PromptMessage]: prompt_messages: list[PromptMessage] = [] for message in messages: @@ -356,7 +331,7 @@ def render_jinja2_message( template: str, jinja2_variables: Sequence[VariableSelector], variable_pool: VariablePool, - template_renderer: TemplateRenderer | None = None, + template_renderer: Jinja2TemplateRenderer | None = None, ) -> str: if not template: return "" @@ -367,7 +342,7 @@ def render_jinja2_message( for jinja2_variable in jinja2_variables: variable = variable_pool.get(jinja2_variable.value_selector) jinja2_inputs[jinja2_variable.variable] = variable.to_object() if variable else "" - return template_renderer.render_jinja2(template=template, inputs=jinja2_inputs) + return template_renderer.render_template(template, jinja2_inputs) def handle_completion_template( @@ -376,7 +351,7 @@ def handle_completion_template( context: str | None, jinja2_variables: Sequence[VariableSelector], variable_pool: VariablePool, - template_renderer: TemplateRenderer | None = None, + template_renderer: Jinja2TemplateRenderer | None = None, ) -> Sequence[PromptMessage]: if template.edition_type == "jinja2": result_text = render_jinja2_message( @@ -412,7 +387,11 @@ def combine_message_content_with_role( raise NotImplementedError(f"Role {role} is not supported") -def calculate_rest_token(*, prompt_messages: list[PromptMessage], model_instance: PreparedModelInstance) -> int: +def calculate_rest_token( + *, + prompt_messages: list[PromptMessage], + model_instance: PreparedLLMProtocol, +) -> int: rest_tokens = 2000 runtime_model_schema = fetch_model_schema(model_instance=model_instance) runtime_model_parameters = model_instance.parameters @@ -442,7 +421,7 @@ def handle_memory_chat_mode( *, memory: PromptMessageMemory | None, memory_config: MemoryConfig | None, - model_instance: PreparedModelInstance, + model_instance: PreparedLLMProtocol, ) -> Sequence[PromptMessage]: if not memory or not memory_config: return [] @@ -457,7 +436,7 @@ def handle_memory_completion_mode( *, memory: PromptMessageMemory | None, memory_config: MemoryConfig | None, - model_instance: PreparedModelInstance, + model_instance: PreparedLLMProtocol, ) -> str: if not memory or not memory_config: return "" diff --git a/api/dify_graph/nodes/llm/node.py b/api/dify_graph/nodes/llm/node.py index 5c2bff1e3a..aa8e9e2f56 100644 --- a/api/dify_graph/nodes/llm/node.py +++ b/api/dify_graph/nodes/llm/node.py @@ -9,13 +9,11 @@ import time from collections.abc import Generator, Mapping, Sequence from typing import TYPE_CHECKING, Any, Literal, cast -from dify_graph.constants import SYSTEM_VARIABLE_NODE_ID from dify_graph.entities import GraphInitParams from dify_graph.entities.graph_config import NodeConfigDict from dify_graph.enums import ( BuiltinNodeTypes, NodeType, - SystemVariableKey, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus, ) @@ -113,6 +111,7 @@ class LLMNode(Node[LLMNodeData]): _jinja2_template_renderer: Jinja2TemplateRenderer | None _model_instance: PreparedLLMProtocol _memory: PromptMessageMemory | None + _default_query_selector: tuple[str, ...] | None def __init__( self, @@ -130,6 +129,7 @@ class LLMNode(Node[LLMNodeData]): prompt_message_serializer: PromptMessageSerializerProtocol, retriever_attachment_loader: RetrieverAttachmentLoaderProtocol | None = None, jinja2_template_renderer: Jinja2TemplateRenderer | None = None, + default_query_selector: Sequence[str] | None = None, ): super().__init__( id=id, @@ -148,6 +148,7 @@ class LLMNode(Node[LLMNodeData]): self._prompt_message_serializer = prompt_message_serializer self._retriever_attachment_loader = retriever_attachment_loader self._jinja2_template_renderer = jinja2_template_renderer + self._default_query_selector = tuple(default_query_selector) if default_query_selector is not None else None @classmethod def version(cls) -> str: @@ -214,8 +215,10 @@ class LLMNode(Node[LLMNodeData]): query: str | None = None if self.node_data.memory: query = self.node_data.memory.query_prompt_template - if not query and ( - query_variable := variable_pool.get((SYSTEM_VARIABLE_NODE_ID, SystemVariableKey.QUERY)) + if ( + not query + and self._default_query_selector + and (query_variable := variable_pool.get(self._default_query_selector)) ): query = query_variable.text @@ -974,9 +977,6 @@ class LLMNode(Node[LLMNodeData]): if node_data.vision.enabled: variable_mapping["#files#"] = node_data.vision.configs.variable_selector - if node_data.memory: - variable_mapping["#sys.query#"] = ["sys", SystemVariableKey.QUERY] - if node_data.prompt_config: enable_jinja = False diff --git a/api/dify_graph/nodes/llm/protocols.py b/api/dify_graph/nodes/llm/protocols.py index e4e5963812..e1b01b2b1e 100644 --- a/api/dify_graph/nodes/llm/protocols.py +++ b/api/dify_graph/nodes/llm/protocols.py @@ -1,6 +1,5 @@ from __future__ import annotations -from collections.abc import Mapping from typing import Any, Protocol from dify_graph.nodes.llm.runtime_protocols import PreparedLLMProtocol @@ -20,11 +19,3 @@ class ModelFactory(Protocol): def init_model_instance(self, provider_name: str, model_name: str) -> PreparedLLMProtocol: """Create a prepared LLM runtime that is ready for graph execution.""" ... - - -class TemplateRenderer(Protocol): - """Port for rendering prompt templates used by LLM-compatible nodes.""" - - def render_jinja2(self, *, template: str, inputs: Mapping[str, Any]) -> str: - """Render the given Jinja2 template into plain text.""" - ... diff --git a/api/dify_graph/nodes/loop/loop_node.py b/api/dify_graph/nodes/loop/loop_node.py index a9d4e4c450..34bfacec14 100644 --- a/api/dify_graph/nodes/loop/loop_node.py +++ b/api/dify_graph/nodes/loop/loop_node.py @@ -2,7 +2,7 @@ import contextlib import json import logging from collections.abc import Callable, Generator, Mapping, Sequence -from datetime import datetime +from datetime import UTC, datetime from typing import TYPE_CHECKING, Any, Literal, cast from dify_graph.entities.graph_config import NodeConfigDictAdapter @@ -31,7 +31,6 @@ from dify_graph.nodes.base import LLMUsageTrackingMixin from dify_graph.nodes.base.node import Node from dify_graph.nodes.loop.entities import LoopCompletedReason, LoopNodeData, LoopVariableData from dify_graph.utils.condition.processor import ConditionProcessor -from dify_graph.utils.datetime_utils import naive_utc_now from dify_graph.variables import Segment, SegmentType, TypeMismatchError, build_segment_with_type, segment_to_variable if TYPE_CHECKING: @@ -90,7 +89,7 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]): loop_variable_selectors[loop_variable.label] = variable_selector inputs[loop_variable.label] = processed_segment.value - start_at = naive_utc_now() + start_at = datetime.now(UTC).replace(tzinfo=None) condition_processor = ConditionProcessor() loop_duration_map: dict[str, float] = {} @@ -123,10 +122,10 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]): self._clear_loop_subgraph_variables(loop_node_ids) graph_engine = self._create_graph_engine(start_at=start_at, root_node_id=root_node_id) - loop_start_time = naive_utc_now() + loop_start_time = datetime.now(UTC).replace(tzinfo=None) reach_break_node = yield from self._run_single_loop(graph_engine=graph_engine, current_index=i) # Track loop duration - loop_duration_map[str(i)] = (naive_utc_now() - loop_start_time).total_seconds() + loop_duration_map[str(i)] = (datetime.now(UTC).replace(tzinfo=None) - loop_start_time).total_seconds() # Accumulate outputs from the sub-graph's response nodes for key, value in graph_engine.graph_runtime_state.outputs.items(): @@ -417,6 +416,7 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]): graph_config=self.graph_config, run_context=self.run_context, call_depth=self.workflow_call_depth, + child_sync_variable_node_ids=self.graph_init_params.child_sync_variable_node_ids, ) # Create a new GraphRuntimeState for this iteration diff --git a/api/dify_graph/nodes/protocols.py b/api/dify_graph/nodes/protocols.py index 16d1609742..b2046b8da8 100644 --- a/api/dify_graph/nodes/protocols.py +++ b/api/dify_graph/nodes/protocols.py @@ -4,7 +4,6 @@ from typing import Any, Protocol import httpx from dify_graph.file import File -from dify_graph.file.models import ToolFile class HttpClientProtocol(Protocol): @@ -35,13 +34,12 @@ class ToolFileManagerProtocol(Protocol): def create_file_by_raw( self, *, - conversation_id: str | None, file_binary: bytes, mimetype: str, filename: str | None = None, ) -> Any: ... - def get_file_generator_by_tool_file_id(self, tool_file_id: str) -> tuple[Generator | None, ToolFile | None]: ... + def get_file_generator_by_tool_file_id(self, tool_file_id: str) -> tuple[Generator | None, File | None]: ... class FileReferenceFactoryProtocol(Protocol): diff --git a/api/dify_graph/nodes/question_classifier/question_classifier_node.py b/api/dify_graph/nodes/question_classifier/question_classifier_node.py index fbc21d5548..b8a13dea28 100644 --- a/api/dify_graph/nodes/question_classifier/question_classifier_node.py +++ b/api/dify_graph/nodes/question_classifier/question_classifier_node.py @@ -25,9 +25,9 @@ from dify_graph.nodes.llm import ( llm_utils, ) from dify_graph.nodes.llm.file_saver import LLMFileSaver -from dify_graph.nodes.llm.protocols import TemplateRenderer from dify_graph.nodes.llm.runtime_protocols import PreparedLLMProtocol, PromptMessageSerializerProtocol from dify_graph.nodes.protocols import HttpClientProtocol +from dify_graph.template_rendering import Jinja2TemplateRenderer from dify_graph.utils.json_in_md_parser import parse_and_check_json_markdown from .entities import QuestionClassifierNodeData @@ -62,7 +62,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]): _prompt_message_serializer: PromptMessageSerializerProtocol _model_instance: PreparedLLMProtocol _memory: PromptMessageMemory | None - _template_renderer: TemplateRenderer + _template_renderer: Jinja2TemplateRenderer def __init__( self, @@ -75,7 +75,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]): model_factory: object | None = None, model_instance: PreparedLLMProtocol, http_client: HttpClientProtocol, - template_renderer: TemplateRenderer, + template_renderer: Jinja2TemplateRenderer, memory: PromptMessageMemory | None = None, llm_file_saver: LLMFileSaver, prompt_message_serializer: PromptMessageSerializerProtocol | None = None, diff --git a/api/dify_graph/nodes/runtime.py b/api/dify_graph/nodes/runtime.py index fc1ee84c6d..0c9b6493ba 100644 --- a/api/dify_graph/nodes/runtime.py +++ b/api/dify_graph/nodes/runtime.py @@ -11,7 +11,8 @@ from dify_graph.nodes.tool_runtime_entities import ( ) if TYPE_CHECKING: - from dify_graph.nodes.human_input.entities import DeliveryChannelConfig + from dify_graph.nodes.human_input.entities import HumanInputNodeData + from dify_graph.nodes.human_input.enums import HumanInputFormStatus from dify_graph.nodes.tool.entities import ToolNodeData from dify_graph.runtime import VariablePool @@ -43,7 +44,6 @@ class ToolNodeRuntimeProtocol(Protocol): tool_runtime: ToolRuntimeHandle, tool_parameters: Mapping[str, Any], workflow_call_depth: int, - conversation_id: str | None, provider_name: str, ) -> Generator[ToolRuntimeMessage, None, None]: ... @@ -64,12 +64,42 @@ class ToolNodeRuntimeProtocol(Protocol): class HumanInputNodeRuntimeProtocol(Protocol): - def invoke_source(self) -> str: ... + """Workflow-layer adapter for human-input runtime persistence and delivery.""" - def apply_delivery_runtime( + def get_form( self, *, - methods: Sequence[DeliveryChannelConfig], - ) -> Sequence[DeliveryChannelConfig]: ... + node_id: str, + ) -> HumanInputFormStateProtocol | None: ... - def console_actor_id(self) -> str | None: ... + def create_form( + self, + *, + node_id: str, + node_data: HumanInputNodeData, + rendered_content: str, + resolved_default_values: Mapping[str, Any], + ) -> HumanInputFormStateProtocol: ... + + +class HumanInputFormStateProtocol(Protocol): + @property + def id(self) -> str: ... + + @property + def rendered_content(self) -> str: ... + + @property + def selected_action_id(self) -> str | None: ... + + @property + def submitted_data(self) -> Mapping[str, Any] | None: ... + + @property + def submitted(self) -> bool: ... + + @property + def status(self) -> HumanInputFormStatus: ... + + @property + def expiration_time(self): ... diff --git a/api/dify_graph/nodes/start/start_node.py b/api/dify_graph/nodes/start/start_node.py index 5e6055ea34..3f18555d99 100644 --- a/api/dify_graph/nodes/start/start_node.py +++ b/api/dify_graph/nodes/start/start_node.py @@ -2,7 +2,6 @@ from typing import Any from jsonschema import Draft7Validator, ValidationError -from dify_graph.constants import SYSTEM_VARIABLE_NODE_ID from dify_graph.enums import BuiltinNodeTypes, NodeExecutionType, WorkflowNodeExecutionStatus from dify_graph.node_events import NodeRunResult from dify_graph.nodes.base.node import Node @@ -19,14 +18,8 @@ class StartNode(Node[StartNodeData]): return "1" def _run(self) -> NodeRunResult: - node_inputs = dict(self.graph_runtime_state.variable_pool.user_inputs) + node_inputs = dict(self.graph_runtime_state.variable_pool.get_by_prefix(self.id)) self._validate_and_normalize_json_object_inputs(node_inputs) - system_inputs = self.graph_runtime_state.variable_pool.system_variables.to_dict() - - # TODO: System variables should be directly accessible, no need for special handling - # Set system variables as node outputs. - for var in system_inputs: - node_inputs[SYSTEM_VARIABLE_NODE_ID + "." + var] = system_inputs[var] outputs = dict(node_inputs) return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=node_inputs, outputs=outputs) diff --git a/api/dify_graph/nodes/tool/tool_node.py b/api/dify_graph/nodes/tool/tool_node.py index ecb0540638..ce8ceb999c 100644 --- a/api/dify_graph/nodes/tool/tool_node.py +++ b/api/dify_graph/nodes/tool/tool_node.py @@ -4,7 +4,6 @@ from typing import TYPE_CHECKING, Any from dify_graph.entities.graph_config import NodeConfigDict from dify_graph.enums import ( BuiltinNodeTypes, - SystemVariableKey, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus, ) @@ -20,8 +19,7 @@ from dify_graph.nodes.tool_runtime_entities import ( ToolRuntimeMessage, ToolRuntimeParameter, ) -from dify_graph.variables.segments import ArrayAnySegment, ArrayFileSegment -from dify_graph.variables.variables import ArrayAnyVariable +from dify_graph.variables.segments import ArrayFileSegment from .entities import ToolNodeData from .exc import ( @@ -121,15 +119,11 @@ class ToolNode(Node[ToolNodeData]): node_data=self.node_data, for_log=True, ) - # get conversation id - conversation_id = self.graph_runtime_state.variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID]) - try: message_stream = self._runtime.invoke( tool_runtime=tool_runtime, tool_parameters=parameters, workflow_call_depth=self.workflow_call_depth, - conversation_id=conversation_id.text if conversation_id else None, provider_name=self.node_data.provider_name, ) except ToolNodeError as e: @@ -209,11 +203,6 @@ class ToolNode(Node[ToolNodeData]): return result - def _fetch_files(self, variable_pool: "VariablePool") -> list[File]: - variable = variable_pool.get(["sys", SystemVariableKey.FILES]) - assert isinstance(variable, ArrayAnyVariable | ArrayAnySegment) - return list(variable.value) if variable else [] - def _transform_message( self, messages: Generator[ToolRuntimeMessage, None, None], @@ -243,18 +232,22 @@ class ToolNode(Node[ToolNodeData]): url = message.message.text if message.meta: transfer_method = message.meta.get("transfer_method", FileTransferMethod.TOOL_FILE) + tool_file_id = message.meta.get("tool_file_id") else: transfer_method = FileTransferMethod.TOOL_FILE - - tool_file_id = str(url).split("/")[-1].split(".")[0] + tool_file_id = None + if not isinstance(tool_file_id, str) or not tool_file_id: + raise ToolFileError("tool message is missing tool_file_id metadata") _stream, tool_file = self._tool_file_manager_factory.get_file_generator_by_tool_file_id(tool_file_id) if not tool_file: raise ToolFileError(f"tool file {tool_file_id} not found") + if tool_file.mime_type is None: + raise ToolFileError(f"tool file {tool_file_id} is missing mime type") file_mapping: dict[str, Any] = { "tool_file_id": tool_file_id, - "type": get_file_type_by_mime_type(tool_file.mimetype), + "type": get_file_type_by_mime_type(tool_file.mime_type), "transfer_method": transfer_method, "url": url, } @@ -265,7 +258,9 @@ class ToolNode(Node[ToolNodeData]): assert isinstance(message.message, ToolRuntimeMessage.TextMessage) assert message.meta - tool_file_id = message.message.text.split("/")[-1].split(".")[0] + tool_file_id = message.meta.get("tool_file_id") + if not isinstance(tool_file_id, str) or not tool_file_id: + raise ToolFileError("tool blob message is missing tool_file_id metadata") _stream, tool_file = self._tool_file_manager_factory.get_file_generator_by_tool_file_id(tool_file_id) if not tool_file: raise ToolFileError(f"tool file {tool_file_id} not exists") diff --git a/api/dify_graph/nodes/tool_runtime_entities.py b/api/dify_graph/nodes/tool_runtime_entities.py index d1a069fc5d..5bb0c16573 100644 --- a/api/dify_graph/nodes/tool_runtime_entities.py +++ b/api/dify_graph/nodes/tool_runtime_entities.py @@ -13,7 +13,11 @@ class _ToolRuntimeModel(BaseModel): @dataclass(frozen=True, slots=True) class ToolRuntimeHandle: - """Opaque graph-owned handle for a workflow-layer tool runtime.""" + """Opaque graph-owned handle for a workflow-layer tool runtime. + + Workflow-specific execution context must stay behind `raw` so the graph + contract does not absorb application-owned concepts. + """ raw: object diff --git a/api/dify_graph/nodes/variable_assigner/v1/node.py b/api/dify_graph/nodes/variable_assigner/v1/node.py index f9b261b191..cb45dd7e08 100644 --- a/api/dify_graph/nodes/variable_assigner/v1/node.py +++ b/api/dify_graph/nodes/variable_assigner/v1/node.py @@ -1,11 +1,10 @@ -from collections.abc import Mapping, Sequence +from collections.abc import Generator, Mapping, Sequence from typing import TYPE_CHECKING, Any -from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID from dify_graph.entities import GraphInitParams from dify_graph.entities.graph_config import NodeConfigDict from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus -from dify_graph.node_events import NodeRunResult +from dify_graph.node_events import NodeEventBase, NodeRunResult, StreamCompletedEvent, VariableUpdatedEvent from dify_graph.nodes.base.node import Node from dify_graph.nodes.variable_assigner.common import helpers as common_helpers from dify_graph.nodes.variable_assigner.common.exc import VariableOperatorNodeError @@ -56,18 +55,16 @@ class VariableAssignerNode(Node[VariableAssignerData]): node_data: VariableAssignerData, ) -> Mapping[str, Sequence[str]]: mapping = {} - assigned_variable_node_id = node_data.assigned_variable_selector[0] - if assigned_variable_node_id == CONVERSATION_VARIABLE_NODE_ID: - selector_key = ".".join(node_data.assigned_variable_selector) - key = f"{node_id}.#{selector_key}#" - mapping[key] = node_data.assigned_variable_selector + selector_key = ".".join(node_data.assigned_variable_selector) + key = f"{node_id}.#{selector_key}#" + mapping[key] = node_data.assigned_variable_selector selector_key = ".".join(node_data.input_variable_selector) key = f"{node_id}.#{selector_key}#" mapping[key] = node_data.input_variable_selector return mapping - def _run(self) -> NodeRunResult: + def _run(self) -> Generator[NodeEventBase, None, None]: assigned_variable_selector = self.node_data.assigned_variable_selector # Should be String, Number, Object, ArrayString, ArrayNumber, ArrayObject original_variable = self.graph_runtime_state.variable_pool.get(assigned_variable_selector) @@ -92,18 +89,18 @@ class VariableAssignerNode(Node[VariableAssignerData]): income_value = SegmentType.get_zero_value(original_variable.value_type) updated_variable = original_variable.model_copy(update={"value": income_value.to_object()}) - # Over write the variable. - self.graph_runtime_state.variable_pool.add(assigned_variable_selector, updated_variable) - updated_variables = [common_helpers.variable_to_processed_data(assigned_variable_selector, updated_variable)] - return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - inputs={ - "value": income_value.to_object(), - }, - # NOTE(QuantumGhost): although only one variable is updated in `v1.VariableAssignerNode`, - # we still set `output_variables` as a list to ensure the schema of output is - # compatible with `v2.VariableAssignerNode`. - process_data=common_helpers.set_updated_variables({}, updated_variables), - outputs={}, + yield VariableUpdatedEvent(variable=updated_variable) + yield StreamCompletedEvent( + node_run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + inputs={ + "value": income_value.to_object(), + }, + # NOTE(QuantumGhost): although only one variable is updated in `v1.VariableAssignerNode`, + # we still set `output_variables` as a list to ensure the schema of output is + # compatible with `v2.VariableAssignerNode`. + process_data=common_helpers.set_updated_variables({}, updated_variables), + outputs={}, + ) ) diff --git a/api/dify_graph/nodes/variable_assigner/v2/node.py b/api/dify_graph/nodes/variable_assigner/v2/node.py index f04a6b3b80..dad7a391cd 100644 --- a/api/dify_graph/nodes/variable_assigner/v2/node.py +++ b/api/dify_graph/nodes/variable_assigner/v2/node.py @@ -1,11 +1,10 @@ import json -from collections.abc import Mapping, MutableMapping, Sequence +from collections.abc import Generator, Mapping, MutableMapping, Sequence from typing import TYPE_CHECKING, Any -from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID from dify_graph.entities.graph_config import NodeConfigDict from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus -from dify_graph.node_events import NodeRunResult +from dify_graph.node_events import NodeEventBase, NodeRunResult, StreamCompletedEvent, VariableUpdatedEvent from dify_graph.nodes.base.node import Node from dify_graph.nodes.variable_assigner.common import helpers as common_helpers from dify_graph.nodes.variable_assigner.common.exc import VariableOperatorNodeError @@ -29,9 +28,6 @@ if TYPE_CHECKING: def _target_mapping_from_item(mapping: MutableMapping[str, Sequence[str]], node_id: str, item: VariableOperationItem): - selector_node_id = item.variable_selector[0] - if selector_node_id != CONVERSATION_VARIABLE_NODE_ID: - return selector_str = ".".join(item.variable_selector) key = f"{node_id}.#{selector_str}#" mapping[key] = item.variable_selector @@ -103,15 +99,18 @@ class VariableAssignerNode(Node[VariableAssignerNodeData]): _source_mapping_from_item(var_mapping, node_id, item) return var_mapping - def _run(self) -> NodeRunResult: + def _run(self) -> Generator[NodeEventBase, None, None]: inputs = self.node_data.model_dump() process_data: dict[str, Any] = {} # NOTE: This node has no outputs updated_variable_selectors: list[Sequence[str]] = [] + # Preserve intra-node read-after-write behavior without mutating the shared pool + # until the engine processes the emitted VariableUpdatedEvent instances. + working_variable_pool = self.graph_runtime_state.variable_pool.model_copy(deep=True) try: for item in self.node_data.items: - variable = self.graph_runtime_state.variable_pool.get(item.variable_selector) + variable = working_variable_pool.get(item.variable_selector) # ==================== Validation Part @@ -136,60 +135,64 @@ class VariableAssignerNode(Node[VariableAssignerNodeData]): raise InputTypeNotSupportedError(input_type=InputType.CONSTANT, operation=item.operation) # Get value from variable pool + input_value = item.value if ( item.input_type == InputType.VARIABLE and item.operation not in {Operation.CLEAR, Operation.REMOVE_FIRST, Operation.REMOVE_LAST} and item.value is not None ): - value = self.graph_runtime_state.variable_pool.get(item.value) + value = working_variable_pool.get(item.value) if value is None: raise VariableNotFoundError(variable_selector=item.value) # Skip if value is NoneSegment if value.value_type == SegmentType.NONE: continue - item.value = value.value + input_value = value.value # If set string / bytes / bytearray to object, try convert string to object. if ( item.operation == Operation.SET and variable.value_type == SegmentType.OBJECT - and isinstance(item.value, str | bytes | bytearray) + and isinstance(input_value, str | bytes | bytearray) ): try: - item.value = json.loads(item.value) + input_value = json.loads(input_value) except json.JSONDecodeError: - raise InvalidInputValueError(value=item.value) + raise InvalidInputValueError(value=input_value) # Check if input value is valid if not helpers.is_input_value_valid( - variable_type=variable.value_type, operation=item.operation, value=item.value + variable_type=variable.value_type, operation=item.operation, value=input_value ): - raise InvalidInputValueError(value=item.value) + raise InvalidInputValueError(value=input_value) # ==================== Execution Part updated_value = self._handle_item( variable=variable, operation=item.operation, - value=item.value, + value=input_value, ) - variable = variable.model_copy(update={"value": updated_value}) - self.graph_runtime_state.variable_pool.add(variable.selector, variable) - updated_variable_selectors.append(variable.selector) + updated_variable = variable.model_copy(update={"value": updated_value}) + working_variable_pool.add(updated_variable.selector, updated_variable) + updated_variable_selectors.append(updated_variable.selector) except VariableOperatorNodeError as e: - return NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - inputs=inputs, - process_data=process_data, - error=str(e), + yield StreamCompletedEvent( + node_run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + inputs=inputs, + process_data=process_data, + error=str(e), + ) ) + return # The `updated_variable_selectors` is a list contains list[str] which not hashable, - # remove the duplicated items first. - updated_variable_selectors = list(set(map(tuple, updated_variable_selectors))) + # remove duplicated items while preserving the first update order. + updated_variable_selectors = list(dict.fromkeys(map(tuple, updated_variable_selectors))) for selector in updated_variable_selectors: - variable = self.graph_runtime_state.variable_pool.get(selector) + variable = working_variable_pool.get(selector) if not isinstance(variable, VariableBase): raise VariableNotFoundError(variable_selector=selector) process_data[variable.name] = variable.value @@ -197,15 +200,23 @@ class VariableAssignerNode(Node[VariableAssignerNodeData]): updated_variables = [ common_helpers.variable_to_processed_data(selector, seg) for selector in updated_variable_selectors - if (seg := self.graph_runtime_state.variable_pool.get(selector)) is not None + if (seg := working_variable_pool.get(selector)) is not None ] process_data = common_helpers.set_updated_variables(process_data, updated_variables) - return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - inputs=inputs, - process_data=process_data, - outputs={}, + for selector in updated_variable_selectors: + variable = working_variable_pool.get(selector) + if not isinstance(variable, VariableBase): + raise VariableNotFoundError(variable_selector=selector) + yield VariableUpdatedEvent(variable=variable) + + yield StreamCompletedEvent( + node_run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + inputs=inputs, + process_data=process_data, + outputs={}, + ) ) def _handle_item( diff --git a/api/dify_graph/repositories/__init__.py b/api/dify_graph/repositories/__init__.py deleted file mode 100644 index ef70eb09cc..0000000000 --- a/api/dify_graph/repositories/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -""" -Repository interfaces for data access. - -This package contains repository interfaces that define the contract -for accessing and manipulating data, regardless of the underlying -storage mechanism. -""" - -from dify_graph.repositories.workflow_node_execution_repository import OrderConfig, WorkflowNodeExecutionRepository - -__all__ = [ - "OrderConfig", - "WorkflowNodeExecutionRepository", -] diff --git a/api/dify_graph/repositories/human_input_form_repository.py b/api/dify_graph/repositories/human_input_form_repository.py deleted file mode 100644 index e43185e9d7..0000000000 --- a/api/dify_graph/repositories/human_input_form_repository.py +++ /dev/null @@ -1,152 +0,0 @@ -import abc -import dataclasses -from collections.abc import Mapping, Sequence -from datetime import datetime -from typing import Any, Protocol - -from dify_graph.nodes.human_input.entities import DeliveryChannelConfig, HumanInputNodeData -from dify_graph.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus - - -class HumanInputError(Exception): - pass - - -class FormNotFoundError(HumanInputError): - pass - - -@dataclasses.dataclass -class FormCreateParams: - # None when creating a delivery test form; set for runtime forms. - workflow_execution_id: str | None - - # node_id is the identifier for a specific - # node in the graph. - # - # TODO: for node inside loop / iteration, this would - # cause problems, as a single node may be executed multiple times. - node_id: str - - form_config: HumanInputNodeData - rendered_content: str - # Delivery methods already filtered by runtime context (invoke_from). - delivery_methods: Sequence[DeliveryChannelConfig] - # UI display flag computed by runtime context. - display_in_ui: bool - - # resolved_default_values saves the values for defaults with - # type = VARIABLE. - # - # For type = CONSTANT, the value is not stored inside `resolved_default_values` - resolved_default_values: Mapping[str, Any] - form_kind: HumanInputFormKind = HumanInputFormKind.RUNTIME - - # Optional application identifier. Implementations may bind this at construction time. - app_id: str | None = None - - # Force creating a console-only recipient for submission in Console. - console_recipient_required: bool = False - console_creator_account_id: str | None = None - # Force creating a backstage recipient for submission in Console. - backstage_recipient_required: bool = False - - -class HumanInputFormEntity(abc.ABC): - @property - @abc.abstractmethod - def id(self) -> str: - """id returns the identifer of the form.""" - pass - - @property - @abc.abstractmethod - def web_app_token(self) -> str | None: - """web_app_token returns the token for submission inside webapp. - - For console/debug execution, this may point to the console submission token - if the form is configured to require console delivery. - """ - - # TODO: what if the users are allowed to add multiple - # webapp delivery? - pass - - @property - @abc.abstractmethod - def recipients(self) -> list["HumanInputFormRecipientEntity"]: ... - - @property - @abc.abstractmethod - def rendered_content(self) -> str: - """Rendered markdown content associated with the form.""" - ... - - @property - @abc.abstractmethod - def selected_action_id(self) -> str | None: - """Identifier of the selected user action if the form has been submitted.""" - ... - - @property - @abc.abstractmethod - def submitted_data(self) -> Mapping[str, Any] | None: - """Submitted form data if available.""" - ... - - @property - @abc.abstractmethod - def submitted(self) -> bool: - """Whether the form has been submitted.""" - ... - - @property - @abc.abstractmethod - def status(self) -> HumanInputFormStatus: - """Current status of the form.""" - ... - - @property - @abc.abstractmethod - def expiration_time(self) -> datetime: - """When the form expires.""" - ... - - -class HumanInputFormRecipientEntity(abc.ABC): - @property - @abc.abstractmethod - def id(self) -> str: - """id returns the identifer of this recipient.""" - ... - - @property - @abc.abstractmethod - def token(self) -> str: - """token returns a random string used to submit form""" - ... - - -class HumanInputFormRepository(Protocol): - """ - Repository interface for HumanInputForm. - - This interface defines the contract for accessing and manipulating - HumanInputForm data, regardless of the underlying storage mechanism. - - Note: Domain-specific concepts like multi-tenancy (tenant_id), application context (app_id), - and other implementation details should be handled at the implementation level, not in - the core interface. This keeps the core domain model clean and independent of specific - application domains or deployment scenarios. - """ - - def get_form(self, workflow_execution_id: str, node_id: str) -> HumanInputFormEntity | None: - """Get the form created for a given human input node in a workflow execution. Returns - `None` if the form has not been created yet.""" - ... - - def create_form(self, params: FormCreateParams) -> HumanInputFormEntity: - """ - Create a human input form from form definition. - """ - ... diff --git a/api/dify_graph/repositories/workflow_execution_repository.py b/api/dify_graph/repositories/workflow_execution_repository.py deleted file mode 100644 index ef83f07649..0000000000 --- a/api/dify_graph/repositories/workflow_execution_repository.py +++ /dev/null @@ -1,30 +0,0 @@ -from typing import Protocol - -from dify_graph.entities import WorkflowExecution - - -class WorkflowExecutionRepository(Protocol): - """ - Repository interface for WorkflowExecution. - - This interface defines the contract for accessing and manipulating - WorkflowExecution data, regardless of the underlying storage mechanism. - - Note: Domain-specific concepts like multi-tenancy (tenant_id), application context (app_id), - and other implementation details should be handled at the implementation level, not in - the core interface. This keeps the core domain model clean and independent of specific - application domains or deployment scenarios. - """ - - def save(self, execution: WorkflowExecution): - """ - Save or update a WorkflowExecution instance. - - This method handles both creating new records and updating existing ones. - The implementation should determine whether to create or update based on - the execution's ID or other identifying fields. - - Args: - execution: The WorkflowExecution instance to save or update - """ - ... diff --git a/api/dify_graph/repositories/workflow_node_execution_repository.py b/api/dify_graph/repositories/workflow_node_execution_repository.py deleted file mode 100644 index e6c1c3e497..0000000000 --- a/api/dify_graph/repositories/workflow_node_execution_repository.py +++ /dev/null @@ -1,73 +0,0 @@ -from collections.abc import Sequence -from dataclasses import dataclass -from typing import Literal, Protocol - -from dify_graph.entities import WorkflowNodeExecution - - -@dataclass -class OrderConfig: - """Configuration for ordering NodeExecution instances.""" - - order_by: list[str] - order_direction: Literal["asc", "desc"] | None = None - - -class WorkflowNodeExecutionRepository(Protocol): - """ - Repository interface for NodeExecution. - - This interface defines the contract for accessing and manipulating - NodeExecution data, regardless of the underlying storage mechanism. - - Note: Domain-specific concepts like multi-tenancy (tenant_id), application context (app_id), - and trigger sources (triggered_from) should be handled at the implementation level, not in - the core interface. This keeps the core domain model clean and independent of specific - application domains or deployment scenarios. - """ - - def save(self, execution: WorkflowNodeExecution): - """ - Save or update a NodeExecution instance. - - This method saves all data on the `WorkflowNodeExecution` object, except for `inputs`, `process_data`, - and `outputs`. Its primary purpose is to persist the status and various metadata, such as execution time - and execution-related details. - - It's main purpose is to save the status and various metadata (execution time, execution metadata etc.) - - This method handles both creating new records and updating existing ones. - The implementation should determine whether to create or update based on - the execution's ID or other identifying fields. - - Args: - execution: The NodeExecution instance to save or update - """ - ... - - def save_execution_data(self, execution: WorkflowNodeExecution): - """Save or update the inputs, process_data, or outputs associated with a specific - node_execution record. - - If any of the inputs, process_data, or outputs are None, those fields will not be updated. - """ - ... - - def get_by_workflow_run( - self, - workflow_run_id: str, - order_config: OrderConfig | None = None, - ) -> Sequence[WorkflowNodeExecution]: - """ - Retrieve all NodeExecution instances for a specific workflow run. - - Args: - workflow_run_id: The workflow run ID - order_config: Optional configuration for ordering results - order_config.order_by: List of fields to order by (e.g., ["index", "created_at"]) - order_config.order_direction: Direction to order ("asc" or "desc") - - Returns: - A list of NodeExecution instances - """ - ... diff --git a/api/dify_graph/runtime/graph_runtime_state.py b/api/dify_graph/runtime/graph_runtime_state.py index 41acc6db35..984dcc3526 100644 --- a/api/dify_graph/runtime/graph_runtime_state.py +++ b/api/dify_graph/runtime/graph_runtime_state.py @@ -3,6 +3,7 @@ from __future__ import annotations import importlib import json from collections.abc import Mapping, Sequence +from contextlib import AbstractContextManager from copy import deepcopy from dataclasses import dataclass from typing import TYPE_CHECKING, Any, ClassVar, Protocol @@ -211,6 +212,7 @@ class GraphRuntimeState: graph_execution: GraphExecutionProtocol | None = None, response_coordinator: ResponseStreamCoordinatorProtocol | None = None, graph: GraphProtocol | None = None, + execution_context: AbstractContextManager[object] | None = None, ) -> None: self._variable_pool = variable_pool self._start_at = start_at @@ -231,6 +233,9 @@ class GraphRuntimeState: self._ready_queue = ready_queue self._graph_execution = graph_execution self._response_coordinator = response_coordinator + # Application code injects this when worker threads must restore request + # or framework-local state. It is intentionally excluded from snapshots. + self._execution_context = execution_context self._pending_response_coordinator_dump: str | None = None self._pending_graph_execution_workflow_id: str | None = None self._paused_nodes: set[str] = set() @@ -329,6 +334,14 @@ class GraphRuntimeState: self._response_coordinator = self._build_response_coordinator(self._graph) return self._response_coordinator + @property + def execution_context(self) -> AbstractContextManager[object] | None: + return self._execution_context + + @execution_context.setter + def execution_context(self, value: AbstractContextManager[object] | None) -> None: + self._execution_context = value + # ------------------------------------------------------------------ # Scalar state # ------------------------------------------------------------------ diff --git a/api/dify_graph/runtime/graph_runtime_state_protocol.py b/api/dify_graph/runtime/graph_runtime_state_protocol.py index 7e55ece3f1..b9674edfed 100644 --- a/api/dify_graph/runtime/graph_runtime_state_protocol.py +++ b/api/dify_graph/runtime/graph_runtime_state_protocol.py @@ -2,7 +2,6 @@ from collections.abc import Mapping, Sequence from typing import Any, Protocol from dify_graph.model_runtime.entities.llm_entities import LLMUsage -from dify_graph.system_variable import SystemVariableReadOnlyView from dify_graph.variables.segments import Segment @@ -31,9 +30,6 @@ class ReadOnlyGraphRuntimeState(Protocol): All methods return defensive copies to ensure immutability. """ - @property - def system_variable(self) -> SystemVariableReadOnlyView: ... - @property def variable_pool(self) -> ReadOnlyVariablePool: """Get read-only access to the variable pool.""" diff --git a/api/dify_graph/runtime/read_only_wrappers.py b/api/dify_graph/runtime/read_only_wrappers.py index ca06d88c3d..66252ca3fb 100644 --- a/api/dify_graph/runtime/read_only_wrappers.py +++ b/api/dify_graph/runtime/read_only_wrappers.py @@ -5,7 +5,6 @@ from copy import deepcopy from typing import Any from dify_graph.model_runtime.entities.llm_entities import LLMUsage -from dify_graph.system_variable import SystemVariableReadOnlyView from dify_graph.variables.segments import Segment from .graph_runtime_state import GraphRuntimeState @@ -43,10 +42,6 @@ class ReadOnlyGraphRuntimeStateWrapper: self._state = state self._variable_pool_wrapper = ReadOnlyVariablePoolWrapper(state.variable_pool) - @property - def system_variable(self) -> SystemVariableReadOnlyView: - return self._state.variable_pool.system_variables.as_view() - @property def variable_pool(self) -> ReadOnlyVariablePoolWrapper: return self._variable_pool_wrapper diff --git a/api/dify_graph/runtime/variable_pool.py b/api/dify_graph/runtime/variable_pool.py index 83416ef130..579d122bfe 100644 --- a/api/dify_graph/runtime/variable_pool.py +++ b/api/dify_graph/runtime/variable_pool.py @@ -8,18 +8,11 @@ from typing import Annotated, Any, Union, cast from pydantic import BaseModel, Field -from dify_graph.constants import ( - CONVERSATION_VARIABLE_NODE_ID, - ENVIRONMENT_VARIABLE_NODE_ID, - RAG_PIPELINE_VARIABLE_NODE_ID, - SYSTEM_VARIABLE_NODE_ID, -) from dify_graph.file import File, FileAttribute, file_manager -from dify_graph.system_variable import SystemVariable from dify_graph.variables import Segment, SegmentGroup, VariableBase, build_segment, segment_to_variable from dify_graph.variables.consts import SELECTORS_LENGTH from dify_graph.variables.segments import FileSegment, ObjectSegment -from dify_graph.variables.variables import RAGPipelineVariableInput, Variable +from dify_graph.variables.variables import Variable VariableValue = Union[str, int, float, dict[str, object], list[object], File] @@ -36,54 +29,6 @@ class VariablePool(BaseModel): default=defaultdict(dict), ) - # The `user_inputs` is used only when constructing the inputs for the `StartNode`. It's not used elsewhere. - user_inputs: Mapping[str, Any] = Field( - description="User inputs", - default_factory=dict, - ) - system_variables: SystemVariable = Field( - description="System variables", - default_factory=SystemVariable.default, - ) - environment_variables: Sequence[Variable] = Field( - description="Environment variables.", - default_factory=list[Variable], - ) - conversation_variables: Sequence[Variable] = Field( - description="Conversation variables.", - default_factory=list[Variable], - ) - rag_pipeline_variables: list[RAGPipelineVariableInput] = Field( - description="RAG pipeline variables.", - default_factory=list, - ) - - def model_post_init(self, context: Any, /): - # Create a mapping from field names to SystemVariableKey enum values - self._add_system_variables(self.system_variables) - # Add environment variables to the variable pool - for var in self.environment_variables: - self.add((ENVIRONMENT_VARIABLE_NODE_ID, var.name), var) - # Add conversation variables to the variable pool. When restoring from a serialized - # snapshot, `variable_dictionary` already carries the latest runtime values. - # In that case, keep existing entries instead of overwriting them with the - # bootstrap list. - for var in self.conversation_variables: - selector = (CONVERSATION_VARIABLE_NODE_ID, var.name) - if self._has(selector): - continue - self.add(selector, var) - # Add rag pipeline variables to the variable pool - if self.rag_pipeline_variables: - rag_pipeline_variables_map: defaultdict[Any, dict[Any, Any]] = defaultdict(dict) - for rag_var in self.rag_pipeline_variables: - node_id = rag_var.variable.belong_to_node_id - key = rag_var.variable.variable - value = rag_var.value - rag_pipeline_variables_map[node_id][key] = value - for key, value in rag_pipeline_variables_map.items(): - self.add((RAG_PIPELINE_VARIABLE_NODE_ID, key), value) - def add(self, selector: Sequence[str], value: Any, /): """ Add a variable to the variable pool. @@ -261,19 +206,7 @@ class VariablePool(BaseModel): return result - def _add_system_variables(self, system_variable: SystemVariable): - sys_var_mapping = system_variable.to_dict() - for key, value in sys_var_mapping.items(): - if value is None: - continue - selector = (SYSTEM_VARIABLE_NODE_ID, key) - # If the system variable already exists, do not add it again. - # This ensures that we can keep the id of the system variables intact. - if self._has(selector): - continue - self.add(selector, value) - @classmethod def empty(cls) -> VariablePool: """Create an empty variable pool.""" - return cls(system_variables=SystemVariable.default()) + return cls() diff --git a/api/dify_graph/system_variable.py b/api/dify_graph/system_variable.py deleted file mode 100644 index cc5deda892..0000000000 --- a/api/dify_graph/system_variable.py +++ /dev/null @@ -1,217 +0,0 @@ -from __future__ import annotations - -from collections.abc import Mapping, Sequence -from types import MappingProxyType -from typing import Any -from uuid import uuid4 - -from pydantic import AliasChoices, BaseModel, ConfigDict, Field, model_validator - -from dify_graph.enums import SystemVariableKey -from dify_graph.file.models import File - - -class SystemVariable(BaseModel): - """A model for managing system variables. - - Fields with a value of `None` are treated as absent and will not be included - in the variable pool. - """ - - model_config = ConfigDict( - extra="forbid", - serialize_by_alias=True, - validate_by_alias=True, - ) - - user_id: str | None = None - - # Ideally, `app_id` and `workflow_id` should be required and not `None`. - # However, there are scenarios in the codebase where these fields are not set. - # To maintain compatibility, they are marked as optional here. - app_id: str | None = None - workflow_id: str | None = None - - timestamp: int | None = None - - files: Sequence[File] = Field(default_factory=list) - - # NOTE: The `workflow_execution_id` field was previously named `workflow_run_id`. - # To maintain compatibility with existing workflows, it must be serialized - # as `workflow_run_id` in dictionaries or JSON objects, and also referenced - # as `workflow_run_id` in the variable pool. - workflow_execution_id: str | None = Field( - validation_alias=AliasChoices("workflow_execution_id", "workflow_run_id"), - serialization_alias="workflow_run_id", - default=None, - ) - # Chatflow related fields. - query: str | None = None - conversation_id: str | None = None - dialogue_count: int | None = None - document_id: str | None = None - original_document_id: str | None = None - dataset_id: str | None = None - batch: str | None = None - datasource_type: str | None = None - datasource_info: Mapping[str, Any] | None = None - invoke_from: str | None = None - - @model_validator(mode="before") - @classmethod - def validate_json_fields(cls, data): - if isinstance(data, dict): - # For JSON validation, only allow workflow_run_id - if "workflow_execution_id" in data and "workflow_run_id" not in data: - # This is likely from direct instantiation, allow it - return data - elif "workflow_execution_id" in data and "workflow_run_id" in data: - # Both present, remove workflow_execution_id - data = data.copy() - data.pop("workflow_execution_id") - return data - return data - - @classmethod - def default(cls) -> SystemVariable: - return cls(workflow_execution_id=str(uuid4())) - - def to_dict(self) -> dict[SystemVariableKey, Any]: - # NOTE: This method is provided for compatibility with legacy code. - # New code should use the `SystemVariable` object directly instead of converting - # it to a dictionary, as this conversion results in the loss of type information - # for each key, making static analysis more difficult. - - d: dict[SystemVariableKey, Any] = { - SystemVariableKey.FILES: self.files, - } - if self.user_id is not None: - d[SystemVariableKey.USER_ID] = self.user_id - if self.app_id is not None: - d[SystemVariableKey.APP_ID] = self.app_id - if self.workflow_id is not None: - d[SystemVariableKey.WORKFLOW_ID] = self.workflow_id - if self.workflow_execution_id is not None: - d[SystemVariableKey.WORKFLOW_EXECUTION_ID] = self.workflow_execution_id - if self.query is not None: - d[SystemVariableKey.QUERY] = self.query - if self.conversation_id is not None: - d[SystemVariableKey.CONVERSATION_ID] = self.conversation_id - if self.dialogue_count is not None: - d[SystemVariableKey.DIALOGUE_COUNT] = self.dialogue_count - if self.document_id is not None: - d[SystemVariableKey.DOCUMENT_ID] = self.document_id - if self.original_document_id is not None: - d[SystemVariableKey.ORIGINAL_DOCUMENT_ID] = self.original_document_id - if self.dataset_id is not None: - d[SystemVariableKey.DATASET_ID] = self.dataset_id - if self.batch is not None: - d[SystemVariableKey.BATCH] = self.batch - if self.datasource_type is not None: - d[SystemVariableKey.DATASOURCE_TYPE] = self.datasource_type - if self.datasource_info is not None: - d[SystemVariableKey.DATASOURCE_INFO] = self.datasource_info - if self.invoke_from is not None: - d[SystemVariableKey.INVOKE_FROM] = self.invoke_from - if self.timestamp is not None: - d[SystemVariableKey.TIMESTAMP] = self.timestamp - return d - - def as_view(self) -> SystemVariableReadOnlyView: - return SystemVariableReadOnlyView(self) - - -class SystemVariableReadOnlyView: - """ - A read-only view of a SystemVariable that implements the ReadOnlySystemVariable protocol. - - This class wraps a SystemVariable instance and provides read-only access to all its fields. - It always reads the latest data from the wrapped instance and prevents any write operations. - """ - - def __init__(self, system_variable: SystemVariable) -> None: - """ - Initialize the read-only view with a SystemVariable instance. - - Args: - system_variable: The SystemVariable instance to wrap - """ - self._system_variable = system_variable - - @property - def user_id(self) -> str | None: - return self._system_variable.user_id - - @property - def app_id(self) -> str | None: - return self._system_variable.app_id - - @property - def workflow_id(self) -> str | None: - return self._system_variable.workflow_id - - @property - def workflow_execution_id(self) -> str | None: - return self._system_variable.workflow_execution_id - - @property - def query(self) -> str | None: - return self._system_variable.query - - @property - def conversation_id(self) -> str | None: - return self._system_variable.conversation_id - - @property - def dialogue_count(self) -> int | None: - return self._system_variable.dialogue_count - - @property - def document_id(self) -> str | None: - return self._system_variable.document_id - - @property - def original_document_id(self) -> str | None: - return self._system_variable.original_document_id - - @property - def dataset_id(self) -> str | None: - return self._system_variable.dataset_id - - @property - def batch(self) -> str | None: - return self._system_variable.batch - - @property - def datasource_type(self) -> str | None: - return self._system_variable.datasource_type - - @property - def invoke_from(self) -> str | None: - return self._system_variable.invoke_from - - @property - def files(self) -> Sequence[File]: - """ - Get a copy of the files from the wrapped SystemVariable. - - Returns: - A defensive copy of the files sequence to prevent modification - """ - return tuple(self._system_variable.files) # Convert to immutable tuple - - @property - def datasource_info(self) -> Mapping[str, Any] | None: - """ - Get a copy of the datasource info from the wrapped SystemVariable. - - Returns: - A view of the datasource info mapping to prevent modification - """ - if self._system_variable.datasource_info is None: - return None - return MappingProxyType(self._system_variable.datasource_info) - - def __repr__(self) -> str: - """Return a string representation of the read-only view.""" - return f"SystemVariableReadOnlyView(system_variable={self._system_variable!r})" diff --git a/api/dify_graph/template_rendering.py b/api/dify_graph/template_rendering.py index 6fee97dbfc..0527e58f6d 100644 --- a/api/dify_graph/template_rendering.py +++ b/api/dify_graph/template_rendering.py @@ -1,39 +1,18 @@ from __future__ import annotations +from abc import ABC, abstractmethod from collections.abc import Mapping -from typing import Any, Protocol - -from dify_graph.nodes.code.code_node import WorkflowCodeExecutor -from dify_graph.nodes.code.entities import CodeLanguage +from typing import Any class TemplateRenderError(ValueError): """Raised when rendering a template fails.""" -class Jinja2TemplateRenderer(Protocol): - """Shared contract for rendering Jinja2 templates in graph nodes.""" - - def render_template(self, template: str, variables: Mapping[str, Any]) -> str: ... - - -class CodeExecutorJinja2TemplateRenderer(Jinja2TemplateRenderer): - """Adapter that renders Jinja2 templates via the workflow code executor.""" - - _code_executor: WorkflowCodeExecutor - - def __init__(self, code_executor: WorkflowCodeExecutor) -> None: - self._code_executor = code_executor +class Jinja2TemplateRenderer(ABC): + """Nominal renderer contract for Jinja2 template rendering in graph nodes.""" + @abstractmethod def render_template(self, template: str, variables: Mapping[str, Any]) -> str: - try: - result = self._code_executor.execute(language=CodeLanguage.JINJA2, code=template, inputs=variables) - except Exception as exc: - if self._code_executor.is_execution_error(exc): - raise TemplateRenderError(str(exc)) from exc - raise - - rendered = result.get("result") - if not isinstance(rendered, str): - raise TemplateRenderError("Template render result must be a string.") - return rendered + """Render the template into plain text.""" + raise NotImplementedError diff --git a/api/dify_graph/utils/datetime_utils.py b/api/dify_graph/utils/datetime_utils.py deleted file mode 100644 index 71af44ca92..0000000000 --- a/api/dify_graph/utils/datetime_utils.py +++ /dev/null @@ -1,20 +0,0 @@ -from __future__ import annotations - -import abc -import datetime -from typing import Protocol - - -class _NowFunction(Protocol): - @abc.abstractmethod - def __call__(self, tz: datetime.timezone | None) -> datetime.datetime: - """Return the current time for the requested timezone.""" - ... - - -_now_func: _NowFunction = datetime.datetime.now - - -def naive_utc_now() -> datetime.datetime: - """Return the current UTC time as a naive datetime.""" - return _now_func(datetime.UTC).replace(tzinfo=None) diff --git a/api/dify_graph/variable_loader.py b/api/dify_graph/variable_loader.py index d263450334..707fb9fb7d 100644 --- a/api/dify_graph/variable_loader.py +++ b/api/dify_graph/variable_loader.py @@ -13,14 +13,6 @@ class VariableLoader(Protocol): A `VariableLoader` is responsible for retrieving additional variables required during the execution of a single node, which are not provided as user inputs. - NOTE(QuantumGhost): Typically, all variables loaded by a `VariableLoader` should belong to the same - application and share the same `app_id`. However, this interface does not enforce that constraint, - and the `app_id` parameter is intentionally omitted from `load_variables` to achieve separation of - concern and allow for flexible implementations. - - Implementations of `VariableLoader` should almost always have an `app_id` parameter in - their constructor. - TODO(QuantumGhost): this is a temporally workaround. If we can move the creation of node instance into `WorkflowService.single_step_run`, we may get rid of this interface. """ diff --git a/api/extensions/logstore/repositories/logstore_workflow_execution_repository.py b/api/extensions/logstore/repositories/logstore_workflow_execution_repository.py index c58aa6adbb..c0dc3c1a0a 100644 --- a/api/extensions/logstore/repositories/logstore_workflow_execution_repository.py +++ b/api/extensions/logstore/repositories/logstore_workflow_execution_repository.py @@ -7,9 +7,9 @@ from typing import Union from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker +from core.repositories.factory import WorkflowExecutionRepository from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository from dify_graph.entities import WorkflowExecution -from dify_graph.repositories.workflow_execution_repository import WorkflowExecutionRepository from dify_graph.workflow_type_encoder import WorkflowRuntimeTypeConverter from extensions.logstore.aliyun_logstore import AliyunLogStore from libs.helper import extract_tenant_id diff --git a/api/extensions/logstore/repositories/logstore_workflow_node_execution_repository.py b/api/extensions/logstore/repositories/logstore_workflow_node_execution_repository.py index d84c0bc432..c38f9950f6 100644 --- a/api/extensions/logstore/repositories/logstore_workflow_node_execution_repository.py +++ b/api/extensions/logstore/repositories/logstore_workflow_node_execution_repository.py @@ -17,10 +17,10 @@ from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository +from core.repositories.factory import OrderConfig, WorkflowNodeExecutionRepository from dify_graph.entities import WorkflowNodeExecution from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus from dify_graph.model_runtime.utils.encoders import jsonable_encoder -from dify_graph.repositories.workflow_node_execution_repository import OrderConfig, WorkflowNodeExecutionRepository from dify_graph.workflow_type_encoder import WorkflowRuntimeTypeConverter from extensions.logstore.aliyun_logstore import AliyunLogStore from extensions.logstore.repositories import safe_float, safe_int @@ -304,35 +304,39 @@ class LogstoreWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository): logger.exception("Failed to dual-write node execution data to SQL database: id=%s", execution.id) # Don't raise - LogStore write succeeded, SQL is just a backup - def get_by_workflow_run( + def get_by_workflow_execution( self, - workflow_run_id: str, + workflow_execution_id: str, order_config: OrderConfig | None = None, ) -> Sequence[WorkflowNodeExecution]: """ - Retrieve all NodeExecution instances for a specific workflow run. + Retrieve all node executions for a workflow execution. Uses LogStore SQL query with window function to get the latest version of each node execution. This ensures we only get the most recent version of each node execution record. Args: - workflow_run_id: The workflow run ID + workflow_execution_id: The workflow execution identifier order_config: Optional configuration for ordering results order_config.order_by: List of fields to order by (e.g., ["index", "created_at"]) order_config.order_direction: Direction to order ("asc" or "desc") Returns: - A list of NodeExecution instances + A list of workflow node execution instances Note: This method uses ROW_NUMBER() window function partitioned by node_execution_id to get the latest version (highest log_version) of each node execution. """ - logger.debug("get_by_workflow_run: workflow_run_id=%s, order_config=%s", workflow_run_id, order_config) + logger.debug( + "get_by_workflow_execution: workflow_execution_id=%s, order_config=%s", + workflow_execution_id, + order_config, + ) # Build SQL query with deduplication using window function # ROW_NUMBER() OVER (PARTITION BY node_execution_id ORDER BY log_version DESC) # ensures we get the latest version of each node execution # Escape parameters to prevent SQL injection - escaped_workflow_run_id = escape_identifier(workflow_run_id) + escaped_workflow_execution_id = escape_identifier(workflow_execution_id) escaped_tenant_id = escape_identifier(self._tenant_id) # Build ORDER BY clause for outer query @@ -360,7 +364,7 @@ class LogstoreWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository): SELECT * FROM ( SELECT *, ROW_NUMBER() OVER (PARTITION BY node_execution_id ORDER BY log_version DESC) AS rn FROM {AliyunLogStore.workflow_node_execution_logstore} - WHERE workflow_run_id='{escaped_workflow_run_id}' + WHERE workflow_run_id='{escaped_workflow_execution_id}' AND tenant_id='{escaped_tenant_id}' {app_id_filter} ) t @@ -391,5 +395,8 @@ class LogstoreWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository): return executions except Exception: - logger.exception("Failed to retrieve node executions from LogStore: workflow_run_id=%s", workflow_run_id) + logger.exception( + "Failed to retrieve node executions from LogStore: workflow_execution_id=%s", + workflow_execution_id, + ) raise diff --git a/api/factories/file_factory.py b/api/factories/file_factory.py index 187d6d0bb7..16f6a0d3b7 100644 --- a/api/factories/file_factory.py +++ b/api/factories/file_factory.py @@ -13,6 +13,7 @@ from sqlalchemy.orm import Session from werkzeug.http import parse_options_header from core.helper import ssrf_proxy +from core.workflow.file_reference import build_file_reference, resolve_file_record_id from dify_graph.file import File, FileBelongsTo, FileTransferMethod, FileType, FileUploadConfig, helpers from dify_graph.file.file_factory import standardize_file_type from extensions.ext_database import db @@ -21,6 +22,16 @@ from models import MessageFile, ToolFile, UploadFile logger = logging.getLogger(__name__) +def _resolve_mapping_file_id(mapping: Mapping[str, Any], *keys: str) -> str | None: + for key in (*keys, "reference", "related_id"): + raw_value = mapping.get(key) + if isinstance(raw_value, str) and raw_value: + resolved_value = resolve_file_record_id(raw_value) + if resolved_value: + return resolved_value + return None + + def build_from_message_files( *, message_files: Sequence["MessageFile"], @@ -158,7 +169,7 @@ def _build_from_local_file( transfer_method: FileTransferMethod, strict_type_validation: bool = False, ) -> File: - upload_file_id = mapping.get("upload_file_id") + upload_file_id = _resolve_mapping_file_id(mapping, "upload_file_id") if not upload_file_id: raise ValueError("Invalid upload file id") # check if upload_file_id is a valid uuid @@ -191,13 +202,11 @@ def _build_from_local_file( filename=row.name, extension="." + row.extension, mime_type=row.mime_type, - tenant_id=tenant_id, type=file_type, transfer_method=transfer_method, remote_url=row.source_url, - related_id=mapping.get("upload_file_id"), + reference=build_file_reference(record_id=str(row.id), storage_key=row.key), size=row.size, - storage_key=row.key, ) @@ -208,7 +217,7 @@ def _build_from_remote_url( transfer_method: FileTransferMethod, strict_type_validation: bool = False, ) -> File: - upload_file_id = mapping.get("upload_file_id") + upload_file_id = _resolve_mapping_file_id(mapping, "upload_file_id") if upload_file_id: try: uuid.UUID(upload_file_id) @@ -242,13 +251,11 @@ def _build_from_remote_url( filename=upload_file.name, extension="." + upload_file.extension, mime_type=upload_file.mime_type, - tenant_id=tenant_id, type=file_type, transfer_method=transfer_method, remote_url=helpers.get_signed_file_url(upload_file_id=str(upload_file_id)), - related_id=mapping.get("upload_file_id"), + reference=build_file_reference(record_id=str(upload_file.id), storage_key=upload_file.key), size=upload_file.size, - storage_key=upload_file.key, ) url = mapping.get("url") or mapping.get("remote_url") if not url: @@ -271,14 +278,12 @@ def _build_from_remote_url( return File( id=mapping.get("id"), filename=filename, - tenant_id=tenant_id, type=file_type, transfer_method=transfer_method, remote_url=url, mime_type=mime_type, extension=extension, size=file_size, - storage_key="", ) @@ -370,8 +375,7 @@ def _build_from_tool_file( transfer_method: FileTransferMethod, strict_type_validation: bool = False, ) -> File: - # Backward/interop compatibility: allow tool_file_id to come from related_id or URL - tool_file_id = mapping.get("tool_file_id") + tool_file_id = _resolve_mapping_file_id(mapping, "tool_file_id") if not tool_file_id: raise ValueError(f"ToolFile {tool_file_id} not found") @@ -401,16 +405,14 @@ def _build_from_tool_file( return File( id=mapping.get("id"), - tenant_id=tenant_id, filename=tool_file.name, type=file_type, transfer_method=transfer_method, remote_url=tool_file.original_url, - related_id=tool_file.id, + reference=build_file_reference(record_id=str(tool_file.id), storage_key=tool_file.file_key), extension=extension, mime_type=tool_file.mimetype, size=tool_file.size, - storage_key=tool_file.file_key, ) @@ -421,7 +423,7 @@ def _build_from_datasource_file( transfer_method: FileTransferMethod, strict_type_validation: bool = False, ) -> File: - datasource_file_id = mapping.get("datasource_file_id") + datasource_file_id = _resolve_mapping_file_id(mapping, "datasource_file_id") if not datasource_file_id: raise ValueError(f"DatasourceFile {datasource_file_id} not found") datasource_file = db.session.scalar( @@ -450,16 +452,14 @@ def _build_from_datasource_file( return File( id=mapping.get("datasource_file_id"), - tenant_id=tenant_id, filename=datasource_file.name, type=file_type, transfer_method=FileTransferMethod.TOOL_FILE, remote_url=datasource_file.source_url, - related_id=datasource_file.id, + reference=build_file_reference(record_id=str(datasource_file.id), storage_key=datasource_file.key), extension=extension, mime_type=datasource_file.mime_type, size=datasource_file.size, - storage_key=datasource_file.key, url=datasource_file.source_url, ) @@ -533,6 +533,8 @@ class StorageKeyLoader: This method doesn't modify the input sequence structure but updates the `_storage_key` property of each file object by extracting the relevant key from its database record. + Tenant scoping is enforced by this loader's context rather than by embedding tenant + identity inside graph-layer ``File`` values. Performance note: This is a batched operation where database query count remains constant regardless of input size. However, for optimal performance, input sequences should contain @@ -543,16 +545,10 @@ class StorageKeyLoader: upload_file_ids: list[uuid.UUID] = [] tool_file_ids: list[uuid.UUID] = [] for file in files: - related_model_id = file.related_id - if file.related_id is None: + parsed_reference = parse_file_reference(file.reference) + if parsed_reference is None: raise ValueError("file id should not be None.") - if file.tenant_id != self._tenant_id: - err_msg = ( - f"invalid file, expected tenant_id={self._tenant_id}, " - f"got tenant_id={file.tenant_id}, file_id={file.id}, related_model_id={related_model_id}" - ) - raise ValueError(err_msg) - model_id = uuid.UUID(related_model_id) + model_id = uuid.UUID(parsed_reference.record_id) if file.transfer_method in (FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL): upload_file_ids.append(model_id) @@ -562,14 +558,23 @@ class StorageKeyLoader: tool_files = self._load_tool_files(tool_file_ids) upload_files = self._load_upload_files(upload_file_ids) for file in files: - model_id = uuid.UUID(file.related_id) + parsed_reference = parse_file_reference(file.reference) + if parsed_reference is None: + raise ValueError("file id should not be None.") + model_id = uuid.UUID(parsed_reference.record_id) if file.transfer_method in (FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL): upload_file_row = upload_files.get(model_id) if upload_file_row is None: raise ValueError(f"Upload file not found for id: {model_id}") - file.storage_key = upload_file_row.key + file.reference = build_file_reference( + record_id=str(upload_file_row.id), + storage_key=upload_file_row.key, + ) elif file.transfer_method == FileTransferMethod.TOOL_FILE: tool_file_row = tool_files.get(model_id) if tool_file_row is None: raise ValueError(f"Tool file not found for id: {model_id}") - file.storage_key = tool_file_row.file_key + file.reference = build_file_reference( + record_id=str(tool_file_row.id), + storage_key=tool_file_row.file_key, + ) diff --git a/api/factories/variable_factory.py b/api/factories/variable_factory.py index 407e6d0be6..7dad008df9 100644 --- a/api/factories/variable_factory.py +++ b/api/factories/variable_factory.py @@ -9,7 +9,7 @@ from collections.abc import Mapping, Sequence from typing import Any, cast from configs import dify_config -from dify_graph.constants import ( +from core.workflow.variable_prefixes import ( CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, ) diff --git a/api/models/human_input.py b/api/models/human_input.py index 48e7fbb9ea..93efd55e34 100644 --- a/api/models/human_input.py +++ b/api/models/human_input.py @@ -6,11 +6,8 @@ import sqlalchemy as sa from pydantic import BaseModel, Field from sqlalchemy.orm import Mapped, mapped_column, relationship -from dify_graph.nodes.human_input.enums import ( - DeliveryMethodType, - HumanInputFormKind, - HumanInputFormStatus, -) +from core.workflow.human_input_compat import DeliveryMethodType +from dify_graph.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus from libs.helper import generate_string from .base import Base, DefaultFieldsMixin diff --git a/api/models/model.py b/api/models/model.py index 3bd68d1d95..d8dccba687 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -20,6 +20,7 @@ from typing_extensions import TypedDict from configs import dify_config from constants import DEFAULT_FILE_NUMBER_LIMITS from core.tools.signature import sign_tool_file +from core.workflow.file_reference import parse_file_reference from dify_graph.enums import WorkflowExecutionStatus from dify_graph.file import FILE_MODEL_IDENTITY, File, FileTransferMethod from dify_graph.file import helpers as file_helpers @@ -52,6 +53,21 @@ if TYPE_CHECKING: # --- TypedDict definitions for structured dict return types --- +def _resolve_file_record_id(file_mapping: Mapping[str, Any]) -> str | None: + reference = file_mapping.get("reference") + if isinstance(reference, str) and reference: + parsed_reference = parse_file_reference(reference) + if parsed_reference is not None: + return parsed_reference.record_id + + related_id = file_mapping.get("related_id") + if isinstance(related_id, str) and related_id: + parsed_reference = parse_file_reference(related_id) + if parsed_reference is not None: + return parsed_reference.record_id + return None + + class EnabledConfig(TypedDict): enabled: bool @@ -1057,10 +1073,11 @@ class Conversation(Base): and cast(dict[str, Any], value).get("dify_model_identity") == FILE_MODEL_IDENTITY ): value_dict = cast(dict[str, Any], value) + record_id = _resolve_file_record_id(value_dict) if value_dict["transfer_method"] == FileTransferMethod.TOOL_FILE: - value_dict["tool_file_id"] = value_dict["related_id"] + value_dict["tool_file_id"] = record_id elif value_dict["transfer_method"] in [FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL]: - value_dict["upload_file_id"] = value_dict["related_id"] + value_dict["upload_file_id"] = record_id tenant_id = cast(str, value_dict.get("tenant_id", "")) inputs[key] = file_factory.build_from_mapping(mapping=value_dict, tenant_id=tenant_id) elif isinstance(value, list): @@ -1075,13 +1092,14 @@ class Conversation(Base): if not isinstance(item, dict): continue item_dict = cast(dict[str, Any], item) + record_id = _resolve_file_record_id(item_dict) if item_dict["transfer_method"] == FileTransferMethod.TOOL_FILE: - item_dict["tool_file_id"] = item_dict["related_id"] + item_dict["tool_file_id"] = record_id elif item_dict["transfer_method"] in [ FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL, ]: - item_dict["upload_file_id"] = item_dict["related_id"] + item_dict["upload_file_id"] = record_id tenant_id = cast(str, item_dict.get("tenant_id", "")) file_list.append(file_factory.build_from_mapping(mapping=item_dict, tenant_id=tenant_id)) inputs[key] = file_list @@ -1398,10 +1416,11 @@ class Message(Base): and cast(dict[str, Any], value).get("dify_model_identity") == FILE_MODEL_IDENTITY ): value_dict = cast(dict[str, Any], value) + record_id = _resolve_file_record_id(value_dict) if value_dict["transfer_method"] == FileTransferMethod.TOOL_FILE: - value_dict["tool_file_id"] = value_dict["related_id"] + value_dict["tool_file_id"] = record_id elif value_dict["transfer_method"] in [FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL]: - value_dict["upload_file_id"] = value_dict["related_id"] + value_dict["upload_file_id"] = record_id tenant_id = cast(str, value_dict.get("tenant_id", "")) inputs[key] = file_factory.build_from_mapping(mapping=value_dict, tenant_id=tenant_id) elif isinstance(value, list): @@ -1416,13 +1435,14 @@ class Message(Base): if not isinstance(item, dict): continue item_dict = cast(dict[str, Any], item) + record_id = _resolve_file_record_id(item_dict) if item_dict["transfer_method"] == FileTransferMethod.TOOL_FILE: - item_dict["tool_file_id"] = item_dict["related_id"] + item_dict["tool_file_id"] = record_id elif item_dict["transfer_method"] in [ FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL, ]: - item_dict["upload_file_id"] = item_dict["related_id"] + item_dict["upload_file_id"] = record_id tenant_id = cast(str, item_dict.get("tenant_id", "")) file_list.append(file_factory.build_from_mapping(mapping=item_dict, tenant_id=tenant_id)) inputs[key] = file_list diff --git a/api/models/workflow.py b/api/models/workflow.py index 6e8dda429d..242dc6814a 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -24,6 +24,7 @@ from sqlalchemy.orm import Mapped, mapped_column from typing_extensions import deprecated from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE +from core.workflow.human_input_compat import normalize_node_config_for_graph from dify_graph.constants import ( CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID, @@ -273,7 +274,7 @@ class Workflow(Base): # bug node_config: dict[str, Any] = next(filter(lambda node: node["id"] == node_id, nodes)) except StopIteration: raise NodeNotFoundError(node_id) - return NodeConfigDictAdapter.validate_python(node_config) + return NodeConfigDictAdapter.validate_python(normalize_node_config_for_graph(node_config)) @staticmethod def get_node_type_from_node_config(node_config: NodeConfigDict) -> NodeType: @@ -1576,13 +1577,23 @@ class WorkflowDraftVariable(Base): if isinstance(value, dict): if not maybe_file_object(value): return cast(Any, value) - return File.model_validate(value) + # Older serialized File payloads may still carry the removed + # graph-level tenant_id field. Strip it at this deserialization edge. + normalized_file = dict(value) + normalized_file.pop("tenant_id", None) + return File.model_validate(normalized_file) elif isinstance(value, list) and value: value_list = cast(list[Any], value) first: Any = value_list[0] if not maybe_file_object(first): return cast(Any, value) - file_list: list[File] = [File.model_validate(cast(dict[str, Any], i)) for i in value_list] + file_list: list[File] = [] + for item in value_list: + # Keep the compatibility handling local to the payload rebuild + # path instead of weakening the File model itself. + normalized_file = dict(cast(dict[str, Any], item)) + normalized_file.pop("tenant_id", None) + file_list.append(File.model_validate(normalized_file)) return cast(Any, file_list) else: return cast(Any, value) diff --git a/api/repositories/api_workflow_node_execution_repository.py b/api/repositories/api_workflow_node_execution_repository.py index 2fa065bcc8..3595ea33f0 100644 --- a/api/repositories/api_workflow_node_execution_repository.py +++ b/api/repositories/api_workflow_node_execution_repository.py @@ -16,7 +16,7 @@ from typing import Protocol from sqlalchemy.orm import Session -from dify_graph.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository +from core.repositories.factory import WorkflowNodeExecutionRepository from models.workflow import WorkflowNodeExecutionModel, WorkflowNodeExecutionOffload diff --git a/api/repositories/api_workflow_run_repository.py b/api/repositories/api_workflow_run_repository.py index a96c4acb31..a48e67b0e2 100644 --- a/api/repositories/api_workflow_run_repository.py +++ b/api/repositories/api_workflow_run_repository.py @@ -40,9 +40,9 @@ from typing import Protocol from sqlalchemy.orm import Session +from core.repositories.factory import WorkflowExecutionRepository from dify_graph.entities.pause_reason import PauseReason from dify_graph.enums import WorkflowType -from dify_graph.repositories.workflow_execution_repository import WorkflowExecutionRepository from libs.infinite_scroll_pagination import InfiniteScrollPagination from models.enums import WorkflowRunTriggeredFrom from models.workflow import WorkflowAppLog, WorkflowArchiveLog, WorkflowPause, WorkflowPauseReason, WorkflowRun diff --git a/api/repositories/sqlalchemy_api_workflow_run_repository.py b/api/repositories/sqlalchemy_api_workflow_run_repository.py index fdd3e123e4..75ea7bc102 100644 --- a/api/repositories/sqlalchemy_api_workflow_run_repository.py +++ b/api/repositories/sqlalchemy_api_workflow_run_repository.py @@ -43,7 +43,7 @@ from libs.infinite_scroll_pagination import InfiniteScrollPagination from libs.time_parser import get_time_threshold from libs.uuid_utils import uuidv7 from models.enums import WorkflowRunTriggeredFrom -from models.human_input import HumanInputForm, HumanInputFormRecipient, RecipientType +from models.human_input import HumanInputForm from models.workflow import WorkflowAppLog, WorkflowArchiveLog, WorkflowPause, WorkflowPauseReason, WorkflowRun from repositories.api_workflow_run_repository import APIWorkflowRunRepository from repositories.entities.workflow_pause import WorkflowPauseEntity @@ -61,25 +61,13 @@ class _WorkflowRunError(Exception): pass -def _select_recipient_token( - recipients: Sequence[HumanInputFormRecipient], - recipient_type: RecipientType, -) -> str | None: - for recipient in recipients: - if recipient.recipient_type == recipient_type and recipient.access_token: - return recipient.access_token - return None - - def _build_human_input_required_reason( reason_model: WorkflowPauseReason, form_model: HumanInputForm | None, - recipients: Sequence[HumanInputFormRecipient], ) -> HumanInputRequired: form_content = "" inputs = [] actions = [] - display_in_ui = False resolved_default_values: dict[str, Any] = {} node_title = "Human Input" form_id = reason_model.form_id @@ -99,25 +87,16 @@ def _build_human_input_required_reason( form_content = definition.form_content inputs = list(definition.inputs) actions = list(definition.user_actions) - display_in_ui = bool(definition.display_in_ui) resolved_default_values = dict(definition.default_values) node_title = definition.node_title or node_title - form_token = ( - _select_recipient_token(recipients, RecipientType.BACKSTAGE) - or _select_recipient_token(recipients, RecipientType.CONSOLE) - or _select_recipient_token(recipients, RecipientType.STANDALONE_WEB_APP) - ) - return HumanInputRequired( form_id=form_id, form_content=form_content, inputs=inputs, actions=actions, - display_in_ui=display_in_ui, node_id=node_id, node_title=node_title, - form_token=form_token, resolved_default_values=resolved_default_values, ) @@ -823,22 +802,16 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository): if reason.type_ == PauseReasonType.HUMAN_INPUT_REQUIRED and reason.form_id ] form_models: dict[str, HumanInputForm] = {} - recipient_models_by_form: dict[str, list[HumanInputFormRecipient]] = {} if form_ids: form_stmt = select(HumanInputForm).where(HumanInputForm.id.in_(form_ids)) for form in session.scalars(form_stmt).all(): form_models[form.id] = form - recipient_stmt = select(HumanInputFormRecipient).where(HumanInputFormRecipient.form_id.in_(form_ids)) - for recipient in session.scalars(recipient_stmt).all(): - recipient_models_by_form.setdefault(recipient.form_id, []).append(recipient) - pause_reasons: list[PauseReason] = [] for reason in pause_reason_models: if reason.type_ == PauseReasonType.HUMAN_INPUT_REQUIRED: form_model = form_models.get(reason.form_id) - recipients = recipient_models_by_form.get(reason.form_id, []) - pause_reasons.append(_build_human_input_required_reason(reason, form_model, recipients)) + pause_reasons.append(_build_human_input_required_reason(reason, form_model)) else: pause_reasons.append(reason.to_entity()) return pause_reasons diff --git a/api/services/human_input_delivery_test_service.py b/api/services/human_input_delivery_test_service.py index 229e6608da..9ba58ac6ec 100644 --- a/api/services/human_input_delivery_test_service.py +++ b/api/services/human_input_delivery_test_service.py @@ -8,7 +8,7 @@ from sqlalchemy import Engine, select from sqlalchemy.orm import sessionmaker from configs import dify_config -from dify_graph.nodes.human_input.entities import ( +from core.workflow.human_input_compat import ( DeliveryChannelConfig, EmailDeliveryConfig, EmailDeliveryMethod, @@ -177,21 +177,21 @@ class EmailDeliveryTestHandler: def _resolve_recipients(self, *, tenant_id: str, method: EmailDeliveryMethod) -> list[str]: recipients = method.config.recipients emails: list[str] = [] - member_user_ids: list[str] = [] + bound_reference_ids: list[str] = [] for recipient in recipients.items: if isinstance(recipient, MemberRecipient): - member_user_ids.append(recipient.user_id) + bound_reference_ids.append(recipient.reference_id) elif isinstance(recipient, ExternalRecipient): if recipient.email: emails.append(recipient.email) - if recipients.whole_workspace: - member_user_ids = [] + if recipients.include_bound_group: + bound_reference_ids = [] member_emails = self._query_workspace_member_emails(tenant_id=tenant_id, user_ids=None) emails.extend(member_emails.values()) - elif member_user_ids: - member_emails = self._query_workspace_member_emails(tenant_id=tenant_id, user_ids=member_user_ids) - for user_id in member_user_ids: + elif bound_reference_ids: + member_emails = self._query_workspace_member_emails(tenant_id=tenant_id, user_ids=bound_reference_ids) + for user_id in bound_reference_ids: email = member_emails.get(user_id) if email: emails.append(email) diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py index 296b9f0890..28047bcceb 100644 --- a/api/services/rag_pipeline/rag_pipeline.py +++ b/api/services/rag_pipeline/rag_pipeline.py @@ -34,25 +34,31 @@ from core.rag.entities.event import ( DatasourceErrorEvent, DatasourceProcessingEvent, ) -from core.repositories.factory import DifyCoreRepositoryFactory +from core.repositories.factory import DifyCoreRepositoryFactory, OrderConfig from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository from core.workflow.node_factory import LATEST_VERSION, get_node_type_classes_mapping +from core.workflow.system_variables import ( + SystemVariableKey, + build_bootstrap_variables, + build_system_variables, + default_system_variables, + get_system_segment, +) +from core.workflow.variable_pool_initializer import add_variables_to_pool from core.workflow.workflow_entry import WorkflowEntry from dify_graph.entities.workflow_node_execution import ( WorkflowNodeExecution, WorkflowNodeExecutionStatus, ) -from dify_graph.enums import BuiltinNodeTypes, ErrorStrategy, NodeType, SystemVariableKey +from dify_graph.enums import BuiltinNodeTypes, ErrorStrategy, NodeType from dify_graph.errors import WorkflowNodeRunFailedError from dify_graph.graph_events import NodeRunFailedEvent, NodeRunSucceededEvent from dify_graph.graph_events.base import GraphNodeEventBase from dify_graph.node_events.base import NodeRunResult from dify_graph.nodes.base.node import Node from dify_graph.nodes.http_request import HTTP_REQUEST_CONFIG_FILTER_KEY, build_http_request_config -from dify_graph.repositories.workflow_node_execution_repository import OrderConfig from dify_graph.runtime import VariablePool -from dify_graph.system_variable import SystemVariable -from dify_graph.variables.variables import VariableBase +from dify_graph.variables.variables import Variable, VariableBase from extensions.ext_database import db from libs.infinite_scroll_pagination import InfiniteScrollPagination from models import Account @@ -88,6 +94,12 @@ from services.workflow_restore import apply_published_workflow_snapshot_to_draft logger = logging.getLogger(__name__) +def _build_seeded_variable_pool(variables: Sequence[Variable]) -> VariablePool: + variable_pool = VariablePool() + add_variables_to_pool(variable_pool, variables) + return variable_pool + + class RagPipelineService: def __init__(self, session_maker: sessionmaker | None = None): """Initialize RagPipelineService with repository dependencies.""" @@ -521,13 +533,7 @@ class RagPipelineService: node_id=node_id, user_inputs=user_inputs, user_id=account.id, - variable_pool=VariablePool( - system_variables=SystemVariable.default(), - user_inputs=user_inputs, - environment_variables=[], - conversation_variables=[], - rag_pipeline_variables=[], - ), + variable_pool=_build_seeded_variable_pool(default_system_variables()), variable_loader=DraftVarLoader( engine=db.engine, app_id=pipeline.id, @@ -959,10 +965,10 @@ class RagPipelineService: workflow_node_execution.error = error # update document status variable_pool = node_instance.graph_runtime_state.variable_pool - invoke_from = variable_pool.get(["sys", SystemVariableKey.INVOKE_FROM]) + invoke_from = get_system_segment(variable_pool, SystemVariableKey.INVOKE_FROM) if invoke_from: if invoke_from.value == InvokeFrom.PUBLISHED_PIPELINE: - document_id = variable_pool.get(["sys", SystemVariableKey.DOCUMENT_ID]) + document_id = get_system_segment(variable_pool, SystemVariableKey.DOCUMENT_ID) if document_id: document = db.session.query(Document).where(Document.id == document_id.value).first() if document: @@ -1276,7 +1282,7 @@ class RagPipelineService: else: enclosing_node_id = None - system_inputs = SystemVariable( + system_inputs = build_system_variables( datasource_type=args.get("datasource_type", "online_document"), datasource_info=args.get("datasource_info", {}), ) @@ -1287,12 +1293,11 @@ class RagPipelineService: node_id=node_id, user_inputs={}, user_id=current_user.id, - variable_pool=VariablePool( - system_variables=system_inputs, - user_inputs={}, - environment_variables=[], - conversation_variables=[], - rag_pipeline_variables=[], + variable_pool=_build_seeded_variable_pool( + build_bootstrap_variables( + system_variables=system_inputs, + rag_pipeline_variables=(), + ) ), variable_loader=DraftVarLoader( engine=db.engine, diff --git a/api/services/workflow_draft_variable_service.py b/api/services/workflow_draft_variable_service.py index fb1a3f30c0..d98c36153e 100644 --- a/api/services/workflow_draft_variable_service.py +++ b/api/services/workflow_draft_variable_service.py @@ -15,8 +15,13 @@ from sqlalchemy.sql.expression import and_, or_ from configs import dify_config from core.app.entities.app_invoke_entities import InvokeFrom from core.trigger.constants import is_trigger_node_type -from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID -from dify_graph.enums import NodeType, SystemVariableKey +from core.workflow.system_variables import SystemVariableKey +from core.workflow.variable_prefixes import ( + CONVERSATION_VARIABLE_NODE_ID, + ENVIRONMENT_VARIABLE_NODE_ID, + SYSTEM_VARIABLE_NODE_ID, +) +from dify_graph.enums import NodeType from dify_graph.file.models import File from dify_graph.nodes import BuiltinNodeTypes from dify_graph.nodes.variable_assigner.common.helpers import get_updated_variables @@ -912,7 +917,12 @@ class DraftVariableSaver: if name == SystemVariableKey.FILES: # Here we know the type of variable must be `array[file]`, we # just build files from the value. - files = [File.model_validate(v) for v in value] + files = [ + # Draft variable payloads may predate the removal of + # File.tenant_id, so normalize the raw dict first. + File.model_validate({key: item for key, item in v.items() if key != "tenant_id"}) + for v in value + ] if files: value_seg = WorkflowDraftVariable.build_segment_with_type(SegmentType.ARRAY_FILE, files) else: diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 72dc26a29a..4abbf8cbc6 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -14,10 +14,17 @@ from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom, build_dify_run_context from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager from core.repositories import DifyCoreRepositoryFactory -from core.repositories.human_input_repository import HumanInputFormRepositoryImpl +from core.repositories.human_input_repository import FormCreateParams, HumanInputFormRepositoryImpl from core.trigger.constants import is_trigger_node_type +from core.workflow.human_input_compat import ( + DeliveryChannelConfig, + normalize_human_input_node_data_for_graph, + parse_human_input_delivery_methods, +) from core.workflow.node_factory import LATEST_VERSION, get_node_type_classes_mapping, is_start_node_type -from core.workflow.node_runtime import DifyHumanInputNodeRuntime +from core.workflow.node_runtime import DifyHumanInputNodeRuntime, apply_dify_debug_email_recipient +from core.workflow.system_variables import build_bootstrap_variables, build_system_variables, default_system_variables +from core.workflow.variable_pool_initializer import add_node_inputs_to_pool, add_variables_to_pool from core.workflow.workflow_entry import WorkflowEntry from dify_graph.entities import GraphInitParams, WorkflowNodeExecution from dify_graph.entities.graph_config import NodeConfigDict @@ -35,18 +42,11 @@ from dify_graph.node_events import NodeRunResult from dify_graph.nodes import BuiltinNodeTypes from dify_graph.nodes.base.node import Node from dify_graph.nodes.http_request import HTTP_REQUEST_CONFIG_FILTER_KEY, build_http_request_config -from dify_graph.nodes.human_input.entities import ( - DeliveryChannelConfig, - HumanInputNodeData, - apply_debug_email_recipient, - validate_human_input_submission, -) +from dify_graph.nodes.human_input.entities import HumanInputNodeData, validate_human_input_submission from dify_graph.nodes.human_input.enums import HumanInputFormKind from dify_graph.nodes.human_input.human_input_node import HumanInputNode from dify_graph.nodes.start.entities import StartNodeData -from dify_graph.repositories.human_input_form_repository import FormCreateParams from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable from dify_graph.variable_loader import load_into_variable_pool from dify_graph.variables import VariableBase from dify_graph.variables.input_entities import VariableEntityType @@ -765,6 +765,7 @@ class WorkflowService: user_id=account.id, user_inputs=user_inputs, workflow=draft_workflow, + node_id=node_id, # NOTE(QuantumGhost): We rely on `DraftVarLoader` to load conversation variables. conversation_variables=[], node_type=node_type, @@ -772,11 +773,13 @@ class WorkflowService: ) else: - variable_pool = VariablePool( - system_variables=SystemVariable.default(), - user_inputs=user_inputs, - environment_variables=draft_workflow.environment_variables, - conversation_variables=[], + variable_pool = VariablePool() + add_variables_to_pool( + variable_pool, + build_bootstrap_variables( + system_variables=default_system_variables(), + environment_variables=draft_workflow.environment_variables, + ), ) variable_loader = DraftVarLoader( @@ -995,17 +998,20 @@ class WorkflowService: if node_type != BuiltinNodeTypes.HUMAN_INPUT: raise ValueError("Node type must be human-input.") - node_data = HumanInputNodeData.model_validate(node_config["data"], from_attributes=True) + node_data = HumanInputNodeData.model_validate( + normalize_human_input_node_data_for_graph(node_config["data"]), + from_attributes=True, + ) delivery_method = self._resolve_human_input_delivery_method( node_data=node_data, delivery_method_id=delivery_method_id, ) if delivery_method is None: raise ValueError("Delivery method not found.") - delivery_method = apply_debug_email_recipient( + delivery_method = apply_dify_debug_email_recipient( delivery_method, enabled=True, - user_id=account.id, + actor_id=account.id, ) variable_pool = self._build_human_input_variable_pool( @@ -1055,7 +1061,7 @@ class WorkflowService: node_data: HumanInputNodeData, delivery_method_id: str, ) -> DeliveryChannelConfig | None: - for method in node_data.delivery_methods: + for method in parse_human_input_delivery_methods(node_data): if str(method.id) == delivery_method_id: return method return None @@ -1070,9 +1076,8 @@ class WorkflowService: rendered_content: str, resolved_default_values: Mapping[str, Any], ) -> tuple[str, list[DeliveryTestEmailRecipient]]: - repo = HumanInputFormRepositoryImpl(tenant_id=app_model.tenant_id) + repo = HumanInputFormRepositoryImpl(tenant_id=app_model.tenant_id, app_id=app_model.id) params = FormCreateParams( - app_id=app_model.id, workflow_execution_id=None, node_id=node_id, form_config=node_data, @@ -1138,7 +1143,6 @@ class WorkflowService: config=node_config, graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, - form_repository=HumanInputFormRepositoryImpl(tenant_id=workflow.tenant_id), runtime=DifyHumanInputNodeRuntime(graph_init_params.run_context), ) return node @@ -1156,11 +1160,13 @@ class WorkflowService: draft_var_srv = WorkflowDraftVariableService(session) draft_var_srv.prefill_conversation_variable_default_values(workflow, user_id=user_id) - variable_pool = VariablePool( - system_variables=SystemVariable.default(), - user_inputs={}, - environment_variables=workflow.environment_variables, - conversation_variables=[], + variable_pool = VariablePool() + add_variables_to_pool( + variable_pool, + build_bootstrap_variables( + system_variables=default_system_variables(), + environment_variables=workflow.environment_variables, + ), ) variable_loader = DraftVarLoader( @@ -1423,7 +1429,7 @@ class WorkflowService: from dify_graph.nodes.human_input.entities import HumanInputNodeData try: - HumanInputNodeData.model_validate(node_data) + HumanInputNodeData.model_validate(normalize_human_input_node_data_for_graph(node_data)) except Exception as e: raise ValueError(f"Invalid HumanInput node data: {str(e)}") @@ -1512,38 +1518,48 @@ def _setup_variable_pool( user_id: str, user_inputs: Mapping[str, Any], workflow: Workflow, + node_id: str, node_type: NodeType, conversation_id: str, conversation_variables: list[VariableBase], ): # Only inject system variables for START node type. if is_start_node_type(node_type): - system_variable = SystemVariable( - user_id=user_id, - app_id=workflow.app_id, - timestamp=int(naive_utc_now().timestamp()), - workflow_id=workflow.id, - files=files or [], - workflow_execution_id=str(uuid.uuid4()), - ) + system_variable_values: dict[str, Any] = { + "user_id": user_id, + "app_id": workflow.app_id, + "timestamp": int(naive_utc_now().timestamp()), + "workflow_id": workflow.id, + "files": files or [], + "workflow_execution_id": str(uuid.uuid4()), + } - # Only add chatflow-specific variables for non-workflow types + # Only add chatflow-specific variables for non-workflow types. if workflow.type != WorkflowType.WORKFLOW: - system_variable.query = query - system_variable.conversation_id = conversation_id - system_variable.dialogue_count = 1 + system_variable_values.update( + { + "query": query, + "conversation_id": conversation_id, + "dialogue_count": 1, + } + ) + + system_variable = build_system_variables(system_variable_values) else: - system_variable = SystemVariable.default() + system_variable = default_system_variables() # init variable pool - variable_pool = VariablePool( - system_variables=system_variable, - user_inputs=user_inputs, - environment_variables=workflow.environment_variables, - # Based on the definition of `Variable`, - # `VariableBase` instances can be safely used as `Variable` since they are compatible. - conversation_variables=cast(list[Variable], conversation_variables), # + variable_pool = VariablePool() + add_variables_to_pool( + variable_pool, + build_bootstrap_variables( + system_variables=system_variable, + environment_variables=workflow.environment_variables, + conversation_variables=cast(list[Variable], conversation_variables), + ), ) + if is_start_node_type(node_type): + add_node_inputs_to_pool(variable_pool, node_id=node_id, inputs=user_inputs) return variable_pool diff --git a/api/tasks/mail_human_input_delivery_task.py b/api/tasks/mail_human_input_delivery_task.py index d241783359..ac5f873e61 100644 --- a/api/tasks/mail_human_input_delivery_task.py +++ b/api/tasks/mail_human_input_delivery_task.py @@ -11,7 +11,7 @@ from sqlalchemy.orm import Session, sessionmaker from configs import dify_config from core.app.layers.pause_state_persist_layer import WorkflowResumptionContext -from dify_graph.nodes.human_input.entities import EmailDeliveryConfig, EmailDeliveryMethod +from core.workflow.human_input_compat import EmailDeliveryConfig, EmailDeliveryMethod from dify_graph.runtime import GraphRuntimeState, VariablePool from extensions.ext_database import db from extensions.ext_mail import mail diff --git a/api/tests/integration_tests/factories/test_storage_key_loader.py b/api/tests/integration_tests/factories/test_storage_key_loader.py index db4bbc1ca1..66f23bdc77 100644 --- a/api/tests/integration_tests/factories/test_storage_key_loader.py +++ b/api/tests/integration_tests/factories/test_storage_key_loader.py @@ -118,7 +118,6 @@ class TestStorageKeyLoader(unittest.TestCase): return File( id=str(uuid4()), # Generate new UUID for File.id - tenant_id=tenant_id, type=FileType.DOCUMENT, transfer_method=transfer_method, related_id=file_related_id, diff --git a/api/tests/integration_tests/services/test_workflow_draft_variable_service.py b/api/tests/integration_tests/services/test_workflow_draft_variable_service.py index 9d3a869691..c3b52adef9 100644 --- a/api/tests/integration_tests/services/test_workflow_draft_variable_service.py +++ b/api/tests/integration_tests/services/test_workflow_draft_variable_service.py @@ -6,7 +6,7 @@ import pytest from sqlalchemy import delete from sqlalchemy.orm import Session -from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID +from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID from dify_graph.nodes import BuiltinNodeTypes from dify_graph.variables.segments import StringSegment from dify_graph.variables.types import SegmentType diff --git a/api/tests/integration_tests/workflow/nodes/test_code.py b/api/tests/integration_tests/workflow/nodes/test_code.py index e3a2b6b866..989078bba6 100644 --- a/api/tests/integration_tests/workflow/nodes/test_code.py +++ b/api/tests/integration_tests/workflow/nodes/test_code.py @@ -6,13 +6,13 @@ import pytest from configs import dify_config from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.workflow.node_factory import DifyNodeFactory +from core.workflow.system_variables import build_system_variables from dify_graph.enums import WorkflowNodeExecutionStatus from dify_graph.graph import Graph from dify_graph.node_events import NodeRunResult from dify_graph.nodes.code.code_node import CodeNode from dify_graph.nodes.code.limits import CodeNodeLimits from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock from tests.workflow_test_utils import build_test_graph_init_params @@ -44,7 +44,7 @@ def init_code_node(code_config: dict): # construct variable pool variable_pool = VariablePool( - system_variables=SystemVariable(user_id="aaa", files=[]), + system_variables=build_system_variables(user_id="aaa", files=[]), user_inputs={}, environment_variables=[], conversation_variables=[], diff --git a/api/tests/integration_tests/workflow/nodes/test_http.py b/api/tests/integration_tests/workflow/nodes/test_http.py index f71c263914..e866b15bab 100644 --- a/api/tests/integration_tests/workflow/nodes/test_http.py +++ b/api/tests/integration_tests/workflow/nodes/test_http.py @@ -10,12 +10,12 @@ from core.helper.ssrf_proxy import ssrf_proxy from core.tools.tool_file_manager import ToolFileManager from core.workflow.node_factory import DifyNodeFactory from core.workflow.node_runtime import DifyFileReferenceFactory +from core.workflow.system_variables import build_system_variables from dify_graph.enums import WorkflowNodeExecutionStatus from dify_graph.file.file_manager import file_manager from dify_graph.graph import Graph from dify_graph.nodes.http_request import HttpRequestNode, HttpRequestNodeConfig from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable from tests.integration_tests.workflow.nodes.__mock.http import setup_http_mock from tests.workflow_test_utils import build_test_graph_init_params @@ -55,7 +55,7 @@ def init_http_node(config: dict): # construct variable pool variable_pool = VariablePool( - system_variables=SystemVariable(user_id="aaa", files=[]), + system_variables=build_system_variables(user_id="aaa", files=[]), user_inputs={}, environment_variables=[], conversation_variables=[], @@ -191,6 +191,7 @@ def test_custom_authorization_header(setup_http_mock): @pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True) def test_custom_auth_with_empty_api_key_raises_error(setup_http_mock): """Test: In custom authentication mode, when the api_key is empty, AuthorizationConfigError should be raised.""" + from core.workflow.system_variables import build_system_variables from dify_graph.enums import BuiltinNodeTypes from dify_graph.nodes.http_request.entities import ( HttpRequestNodeAuthorization, @@ -200,11 +201,10 @@ def test_custom_auth_with_empty_api_key_raises_error(setup_http_mock): from dify_graph.nodes.http_request.exc import AuthorizationConfigError from dify_graph.nodes.http_request.executor import Executor from dify_graph.runtime import VariablePool - from dify_graph.system_variable import SystemVariable # Create variable pool variable_pool = VariablePool( - system_variables=SystemVariable(user_id="test", files=[]), + system_variables=build_system_variables(user_id="test", files=[]), user_inputs={}, environment_variables=[], conversation_variables=[], @@ -702,7 +702,7 @@ def test_nested_object_variable_selector(setup_http_mock): # Create independent variable pool for this test only variable_pool = VariablePool( - system_variables=SystemVariable(user_id="aaa", files=[]), + system_variables=build_system_variables(user_id="aaa", files=[]), user_inputs={}, environment_variables=[], conversation_variables=[], diff --git a/api/tests/integration_tests/workflow/nodes/test_llm.py b/api/tests/integration_tests/workflow/nodes/test_llm.py index dedccb0fd3..f2a1e0b53a 100644 --- a/api/tests/integration_tests/workflow/nodes/test_llm.py +++ b/api/tests/integration_tests/workflow/nodes/test_llm.py @@ -7,6 +7,7 @@ from unittest.mock import MagicMock, patch from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.llm_generator.output_parser.structured_output import _parse_structured_output from core.model_manager import ModelInstance +from core.workflow.system_variables import build_system_variables from dify_graph.enums import WorkflowNodeExecutionStatus from dify_graph.node_events import StreamCompletedEvent from dify_graph.nodes.llm.file_saver import LLMFileSaver @@ -15,7 +16,6 @@ from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory from dify_graph.nodes.llm.runtime_protocols import PromptMessageSerializerProtocol from dify_graph.nodes.protocols import HttpClientProtocol from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable from extensions.ext_database import db from tests.workflow_test_utils import build_test_graph_init_params @@ -53,7 +53,7 @@ def init_llm_node(config: dict) -> LLMNode: # construct variable pool variable_pool = VariablePool( - system_variables=SystemVariable( + system_variables=build_system_variables( user_id="aaa", app_id=app_id, workflow_id=workflow_id, diff --git a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py index 4479a9f881..d8ce3fe98c 100644 --- a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py +++ b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py @@ -6,12 +6,12 @@ from unittest.mock import MagicMock from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.model_manager import ModelInstance from core.workflow.node_runtime import DifyPromptMessageSerializer +from core.workflow.system_variables import build_system_variables from dify_graph.enums import WorkflowNodeExecutionStatus from dify_graph.model_runtime.entities import AssistantPromptMessage, UserPromptMessage from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory from dify_graph.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable from extensions.ext_database import db from tests.integration_tests.workflow.nodes.__mock.model import get_mocked_fetch_model_instance from tests.workflow_test_utils import build_test_graph_init_params @@ -57,7 +57,7 @@ def init_parameter_extractor_node(config: dict, memory=None): # construct variable pool variable_pool = VariablePool( - system_variables=SystemVariable( + system_variables=build_system_variables( user_id="aaa", files=[], query="what's the weather in SF", conversation_id="abababa" ), user_inputs={}, diff --git a/api/tests/integration_tests/workflow/nodes/test_template_transform.py b/api/tests/integration_tests/workflow/nodes/test_template_transform.py index 8324a96c0c..bd41a5fbf9 100644 --- a/api/tests/integration_tests/workflow/nodes/test_template_transform.py +++ b/api/tests/integration_tests/workflow/nodes/test_template_transform.py @@ -3,11 +3,11 @@ import uuid from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.workflow.node_factory import DifyNodeFactory +from core.workflow.system_variables import build_system_variables from dify_graph.enums import WorkflowNodeExecutionStatus from dify_graph.graph import Graph from dify_graph.nodes.template_transform.template_transform_node import TemplateTransformNode from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable from dify_graph.template_rendering import TemplateRenderError from tests.workflow_test_utils import build_test_graph_init_params @@ -66,7 +66,7 @@ def test_execute_template_transform(): # construct variable pool variable_pool = VariablePool( - system_variables=SystemVariable(user_id="aaa", files=[]), + system_variables=build_system_variables(user_id="aaa", files=[]), user_inputs={}, environment_variables=[], conversation_variables=[], diff --git a/api/tests/integration_tests/workflow/nodes/test_tool.py b/api/tests/integration_tests/workflow/nodes/test_tool.py index 8b7251842a..dbd3797166 100644 --- a/api/tests/integration_tests/workflow/nodes/test_tool.py +++ b/api/tests/integration_tests/workflow/nodes/test_tool.py @@ -6,13 +6,13 @@ from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.tools.utils.configuration import ToolParameterConfigurationManager from core.workflow.node_factory import DifyNodeFactory from core.workflow.node_runtime import DifyToolNodeRuntime +from core.workflow.system_variables import build_system_variables from dify_graph.enums import WorkflowNodeExecutionStatus from dify_graph.graph import Graph from dify_graph.node_events import StreamCompletedEvent from dify_graph.nodes.protocols import ToolFileManagerProtocol from dify_graph.nodes.tool.tool_node import ToolNode from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable from tests.workflow_test_utils import build_test_graph_init_params @@ -41,7 +41,7 @@ def init_tool_node(config: dict): # construct variable pool variable_pool = VariablePool( - system_variables=SystemVariable(user_id="aaa", files=[]), + system_variables=build_system_variables(user_id="aaa", files=[]), user_inputs={}, environment_variables=[], conversation_variables=[], diff --git a/api/tests/test_containers_integration_tests/core/app/layers/test_pause_state_persist_layer.py b/api/tests/test_containers_integration_tests/core/app/layers/test_pause_state_persist_layer.py index 96fb7ea293..7107c3d56b 100644 --- a/api/tests/test_containers_integration_tests/core/app/layers/test_pause_state_persist_layer.py +++ b/api/tests/test_containers_integration_tests/core/app/layers/test_pause_state_persist_layer.py @@ -31,6 +31,7 @@ from core.app.layers.pause_state_persist_layer import ( PauseStatePersistenceLayer, WorkflowResumptionContext, ) +from core.workflow.system_variables import build_system_variables from dify_graph.entities.pause_reason import SchedulingPause from dify_graph.enums import WorkflowExecutionStatus from dify_graph.graph_engine.entities.commands import GraphEngineCommand @@ -40,7 +41,7 @@ from dify_graph.model_runtime.entities.llm_entities import LLMUsage from dify_graph.runtime.graph_runtime_state import GraphRuntimeState from dify_graph.runtime.graph_runtime_state_protocol import ReadOnlyGraphRuntimeState from dify_graph.runtime.read_only_wrappers import ReadOnlyGraphRuntimeStateWrapper -from dify_graph.runtime.variable_pool import SystemVariable, VariablePool +from dify_graph.runtime.variable_pool import VariablePool from extensions.ext_storage import storage from libs.datetime_utils import naive_utc_now from models import Account @@ -212,7 +213,7 @@ class TestPauseStatePersistenceLayerTestContainers: execution_id = workflow_run_id or getattr(self, "test_workflow_run_id", None) or str(uuid.uuid4()) # Create variable pool - variable_pool = VariablePool(system_variables=SystemVariable(workflow_execution_id=execution_id)) + variable_pool = VariablePool(system_variables=build_system_variables(workflow_execution_id=execution_id)) if variables: for (node_id, var_key), value in variables.items(): variable_pool.add([node_id, var_key], value) diff --git a/api/tests/test_containers_integration_tests/core/repositories/test_human_input_form_repository_impl.py b/api/tests/test_containers_integration_tests/core/repositories/test_human_input_form_repository_impl.py index 9d0fad4b12..fc25972a3b 100644 --- a/api/tests/test_containers_integration_tests/core/repositories/test_human_input_form_repository_impl.py +++ b/api/tests/test_containers_integration_tests/core/repositories/test_human_input_form_repository_impl.py @@ -7,20 +7,17 @@ from uuid import uuid4 from sqlalchemy import Engine, select from sqlalchemy.orm import Session -from core.repositories.human_input_repository import HumanInputFormRepositoryImpl -from dify_graph.nodes.human_input.entities import ( +from core.repositories.human_input_repository import FormCreateParams, HumanInputFormRepositoryImpl +from core.workflow.human_input_compat import ( DeliveryChannelConfig, EmailDeliveryConfig, EmailDeliveryMethod, EmailRecipients, ExternalRecipient, - FormDefinition, - HumanInputNodeData, MemberRecipient, - UserAction, WebAppDeliveryMethod, ) -from dify_graph.repositories.human_input_form_repository import FormCreateParams +from dify_graph.nodes.human_input.entities import FormDefinition, HumanInputNodeData, UserAction from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.human_input import ( EmailExternalRecipientPayload, @@ -68,7 +65,6 @@ def _build_form_params(delivery_methods: list[DeliveryChannelConfig]) -> FormCre user_actions=[UserAction(id="approve", title="Approve")], ) return FormCreateParams( - app_id=str(uuid4()), workflow_execution_id=str(uuid4()), node_id="human-input-node", form_config=form_config, @@ -84,7 +80,7 @@ def _build_email_delivery( ) -> EmailDeliveryMethod: return EmailDeliveryMethod( config=EmailDeliveryConfig( - recipients=EmailRecipients(whole_workspace=whole_workspace, items=recipients), + recipients=EmailRecipients(include_bound_group=whole_workspace, items=recipients), subject="Approval Needed", body="Please review", ) @@ -100,7 +96,7 @@ class TestHumanInputFormRepositoryImplWithContainers: member_emails=["member1@example.com", "member2@example.com"], ) - repository = HumanInputFormRepositoryImpl(tenant_id=tenant.id) + repository = HumanInputFormRepositoryImpl(tenant_id=tenant.id, app_id=str(uuid4())) params = _build_form_params( delivery_methods=[_build_email_delivery(whole_workspace=True, recipients=[])], ) @@ -129,13 +125,13 @@ class TestHumanInputFormRepositoryImplWithContainers: member_emails=["primary@example.com", "secondary@example.com"], ) - repository = HumanInputFormRepositoryImpl(tenant_id=tenant.id) + repository = HumanInputFormRepositoryImpl(tenant_id=tenant.id, app_id=str(uuid4())) params = _build_form_params( delivery_methods=[ _build_email_delivery( whole_workspace=False, recipients=[ - MemberRecipient(user_id=members[0].id), + MemberRecipient(reference_id=members[0].id), ExternalRecipient(email="external@example.com"), ], ) @@ -173,10 +169,9 @@ class TestHumanInputFormRepositoryImplWithContainers: member_emails=["prefill@example.com"], ) - repository = HumanInputFormRepositoryImpl(tenant_id=tenant.id) + repository = HumanInputFormRepositoryImpl(tenant_id=tenant.id, app_id=str(uuid4())) resolved_values = {"greeting": "Hello!"} params = FormCreateParams( - app_id=str(uuid4()), workflow_execution_id=str(uuid4()), node_id="human-input-node", form_config=HumanInputNodeData( @@ -210,9 +205,8 @@ class TestHumanInputFormRepositoryImplWithContainers: member_emails=["ui@example.com"], ) - repository = HumanInputFormRepositoryImpl(tenant_id=tenant.id) + repository = HumanInputFormRepositoryImpl(tenant_id=tenant.id, app_id=str(uuid4())) params = FormCreateParams( - app_id=str(uuid4()), workflow_execution_id=str(uuid4()), node_id="human-input-node", form_config=HumanInputNodeData( diff --git a/api/tests/test_containers_integration_tests/core/workflow/test_human_input_resume_node_execution.py b/api/tests/test_containers_integration_tests/core/workflow/test_human_input_resume_node_execution.py index a7ab3e2869..1efb0db9c2 100644 --- a/api/tests/test_containers_integration_tests/core/workflow/test_human_input_resume_node_execution.py +++ b/api/tests/test_containers_integration_tests/core/workflow/test_human_input_resume_node_execution.py @@ -10,9 +10,11 @@ from sqlalchemy.orm import Session from core.app.app_config.entities import WorkflowUIBasedAppConfig from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity from core.app.workflow.layers import PersistenceWorkflowInfo, WorkflowPersistenceLayer +from core.repositories.human_input_repository import HumanInputFormEntity, HumanInputFormRepository from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository from core.workflow.node_runtime import DifyHumanInputNodeRuntime +from core.workflow.system_variables import build_system_variables from dify_graph.enums import WorkflowType from dify_graph.graph import Graph from dify_graph.graph_engine.command_channels.in_memory_channel import InMemoryChannel @@ -24,9 +26,7 @@ from dify_graph.nodes.human_input.enums import HumanInputFormStatus from dify_graph.nodes.human_input.human_input_node import HumanInputNode from dify_graph.nodes.start.entities import StartNodeData from dify_graph.nodes.start.start_node import StartNode -from dify_graph.repositories.human_input_form_repository import HumanInputFormEntity, HumanInputFormRepository from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable from libs.datetime_utils import naive_utc_now from models import Account from models.account import Tenant, TenantAccountJoin, TenantAccountRole @@ -40,7 +40,7 @@ def _mock_form_repository_without_submission() -> HumanInputFormRepository: repo = MagicMock(spec=HumanInputFormRepository) form_entity = MagicMock(spec=HumanInputFormEntity) form_entity.id = "test-form-id" - form_entity.web_app_token = "test-form-token" + form_entity.submission_token = "test-form-token" form_entity.recipients = [] form_entity.rendered_content = "rendered" form_entity.submitted = False @@ -53,7 +53,7 @@ def _mock_form_repository_with_submission(action_id: str) -> HumanInputFormRepos repo = MagicMock(spec=HumanInputFormRepository) form_entity = MagicMock(spec=HumanInputFormEntity) form_entity.id = "test-form-id" - form_entity.web_app_token = "test-form-token" + form_entity.submission_token = "test-form-token" form_entity.recipients = [] form_entity.rendered_content = "rendered" form_entity.submitted = True @@ -67,7 +67,7 @@ def _mock_form_repository_with_submission(action_id: str) -> HumanInputFormRepos def _build_runtime_state(workflow_execution_id: str, app_id: str, workflow_id: str, user_id: str) -> GraphRuntimeState: variable_pool = VariablePool( - system_variables=SystemVariable( + system_variables=build_system_variables( workflow_execution_id=workflow_execution_id, app_id=app_id, workflow_id=workflow_id, diff --git a/api/tests/test_containers_integration_tests/factories/test_storage_key_loader.py b/api/tests/test_containers_integration_tests/factories/test_storage_key_loader.py index 8e70fc0bb0..842ca7f730 100644 --- a/api/tests/test_containers_integration_tests/factories/test_storage_key_loader.py +++ b/api/tests/test_containers_integration_tests/factories/test_storage_key_loader.py @@ -119,7 +119,6 @@ class TestStorageKeyLoader(unittest.TestCase): return File( id=str(uuid4()), # Generate new UUID for File.id - tenant_id=tenant_id, type=FileType.DOCUMENT, transfer_method=transfer_method, related_id=file_related_id, diff --git a/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test.py b/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test.py index 70d05792ce..4c6718f959 100644 --- a/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test.py +++ b/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test.py @@ -4,14 +4,14 @@ from unittest.mock import MagicMock import pytest -from dify_graph.enums import BuiltinNodeTypes -from dify_graph.nodes.human_input.entities import ( +from core.workflow.human_input_compat import ( EmailDeliveryConfig, EmailDeliveryMethod, EmailRecipients, ExternalRecipient, - HumanInputNodeData, ) +from dify_graph.enums import BuiltinNodeTypes +from dify_graph.nodes.human_input.entities import HumanInputNodeData from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.model import App, AppMode from models.workflow import Workflow, WorkflowType @@ -54,7 +54,7 @@ def _create_app_with_draft_workflow(session, *, delivery_method_id: uuid.UUID) - enabled=True, config=EmailDeliveryConfig( recipients=EmailRecipients( - whole_workspace=False, + include_bound_group=False, items=[ExternalRecipient(email="recipient@example.com")], ), subject="Test {{recipient_email}}", diff --git a/api/tests/test_containers_integration_tests/services/test_workflow_draft_variable_service.py b/api/tests/test_containers_integration_tests/services/test_workflow_draft_variable_service.py index 572cf72fa0..a22981ca93 100644 --- a/api/tests/test_containers_integration_tests/services/test_workflow_draft_variable_service.py +++ b/api/tests/test_containers_integration_tests/services/test_workflow_draft_variable_service.py @@ -2,7 +2,7 @@ import pytest from faker import Faker from sqlalchemy.orm import Session -from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID +from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID from dify_graph.variables.segments import StringSegment from models import App, Workflow from models.enums import DraftVariableType diff --git a/api/tests/test_containers_integration_tests/tasks/test_mail_human_input_delivery_task.py b/api/tests/test_containers_integration_tests/tasks/test_mail_human_input_delivery_task.py index 0876a39f82..241bd9f9f8 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_mail_human_input_delivery_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_mail_human_input_delivery_task.py @@ -9,15 +9,15 @@ from core.app.app_config.entities import WorkflowUIBasedAppConfig from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity from core.app.layers.pause_state_persist_layer import WorkflowResumptionContext from core.repositories.human_input_repository import FormCreateParams, HumanInputFormRepositoryImpl -from dify_graph.enums import WorkflowExecutionStatus -from dify_graph.nodes.human_input.entities import ( +from core.workflow.human_input_compat import ( EmailDeliveryConfig, EmailDeliveryMethod, EmailRecipients, ExternalRecipient, - HumanInputNodeData, MemberRecipient, ) +from dify_graph.enums import WorkflowExecutionStatus +from dify_graph.nodes.human_input.entities import HumanInputNodeData from dify_graph.runtime import GraphRuntimeState, VariablePool from extensions.ext_storage import storage from models.account import Account, AccountStatus, Tenant, TenantAccountJoin, TenantAccountRole @@ -79,9 +79,9 @@ def _build_form(db_session_with_containers, tenant, account, *, app_id: str, wor delivery_method = EmailDeliveryMethod( config=EmailDeliveryConfig( recipients=EmailRecipients( - whole_workspace=False, + include_bound_group=False, items=[ - MemberRecipient(user_id=account.id), + MemberRecipient(reference_id=account.id), ExternalRecipient(email="external@example.com"), ], ), @@ -96,9 +96,8 @@ def _build_form(db_session_with_containers, tenant, account, *, app_id: str, wor delivery_methods=[delivery_method], ) - repo = HumanInputFormRepositoryImpl(tenant_id=tenant.id) + repo = HumanInputFormRepositoryImpl(tenant_id=tenant.id, app_id=app_id) params = FormCreateParams( - app_id=app_id, workflow_execution_id=workflow_execution_id, node_id="node-1", form_config=node_data, diff --git a/api/tests/unit_tests/controllers/console/app/test_workflow.py b/api/tests/unit_tests/controllers/console/app/test_workflow.py index 0e22db9f9b..02190b0f47 100644 --- a/api/tests/unit_tests/controllers/console/app/test_workflow.py +++ b/api/tests/unit_tests/controllers/console/app/test_workflow.py @@ -30,7 +30,6 @@ def test_parse_file_with_config(monkeypatch: pytest.MonkeyPatch) -> None: config = object() file_list = [ File( - tenant_id="t1", type=FileType.IMAGE, transfer_method=FileTransferMethod.REMOTE_URL, remote_url="http://u", diff --git a/api/tests/unit_tests/controllers/console/app/workflow_draft_variables_test.py b/api/tests/unit_tests/controllers/console/app/workflow_draft_variables_test.py index f34702a257..530a4f5906 100644 --- a/api/tests/unit_tests/controllers/console/app/workflow_draft_variables_test.py +++ b/api/tests/unit_tests/controllers/console/app/workflow_draft_variables_test.py @@ -13,7 +13,7 @@ from controllers.console.app.workflow_draft_variable import ( _WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS, _serialize_full_content, ) -from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID +from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID from dify_graph.variables.types import SegmentType from factories.variable_factory import build_segment from libs.datetime_utils import naive_utc_now @@ -316,7 +316,6 @@ def test_workflow_file_variable_with_signed_url(): # Create a File object with LOCAL_FILE transfer method (which generates signed URLs) test_file = File( id="test_file_id", - tenant_id="test_tenant_id", type=FileType.IMAGE, transfer_method=FileTransferMethod.LOCAL_FILE, related_id="test_upload_file_id", @@ -374,7 +373,6 @@ def test_workflow_file_variable_remote_url(): # Create a File object with REMOTE_URL transfer method test_file = File( id="test_file_id", - tenant_id="test_tenant_id", type=FileType.IMAGE, transfer_method=FileTransferMethod.REMOTE_URL, remote_url="https://example.com/test.jpg", diff --git a/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_draft_variable.py b/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_draft_variable.py index b4c0903f63..1f5f0646be 100644 --- a/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_draft_variable.py +++ b/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_draft_variable.py @@ -14,7 +14,7 @@ from controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable impor RagPipelineVariableResetApi, ) from controllers.web.error import InvalidArgumentError, NotFoundError -from dify_graph.constants import SYSTEM_VARIABLE_NODE_ID +from core.workflow.variable_prefixes import SYSTEM_VARIABLE_NODE_ID from dify_graph.variables.types import SegmentType from models.account import Account diff --git a/api/tests/unit_tests/controllers/files/test_tool_files.py b/api/tests/unit_tests/controllers/files/test_tool_files.py index e5df7a1eea..edb91c3f26 100644 --- a/api/tests/unit_tests/controllers/files/test_tool_files.py +++ b/api/tests/unit_tests/controllers/files/test_tool_files.py @@ -18,10 +18,10 @@ def fake_request(args: dict): class DummyToolFile: - def __init__(self, mimetype="text/plain", size=10, name="tool.txt"): - self.mimetype = mimetype + def __init__(self, mime_type="text/plain", size=10, filename="tool.txt"): + self.mime_type = mime_type self.size = size - self.name = name + self.filename = filename @pytest.fixture(autouse=True) @@ -87,8 +87,8 @@ class TestToolFileApi: stream = iter([b"data"]) tool_file = DummyToolFile( - mimetype="application/pdf", - name="doc.pdf", + mime_type="application/pdf", + filename="doc.pdf", ) mock_tool_file_manager.return_value.get_file_generator_by_tool_file_id.return_value = ( diff --git a/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_conversation_variables.py b/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_conversation_variables.py index 15aceef2c7..5d241432a0 100644 --- a/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_conversation_variables.py +++ b/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_conversation_variables.py @@ -11,6 +11,19 @@ from dify_graph.variables import SegmentType from factories import variable_factory from models import ConversationVariable, Workflow +MINIMAL_GRAPH = { + "nodes": [ + { + "id": "start", + "data": { + "type": "start", + "title": "Start", + }, + } + ], + "edges": [], +} + class TestAdvancedChatAppRunnerConversationVariables: """Test that AdvancedChatAppRunner correctly handles conversation variables.""" @@ -49,7 +62,7 @@ class TestAdvancedChatAppRunnerConversationVariables: mock_workflow.app_id = app_id mock_workflow.id = workflow_id mock_workflow.type = "chat" - mock_workflow.graph_dict = {} + mock_workflow.graph_dict = MINIMAL_GRAPH mock_workflow.environment_variables = [] # Create existing conversation variable (only var1 exists in DB) @@ -200,7 +213,7 @@ class TestAdvancedChatAppRunnerConversationVariables: mock_workflow.app_id = app_id mock_workflow.id = workflow_id mock_workflow.type = "chat" - mock_workflow.graph_dict = {} + mock_workflow.graph_dict = MINIMAL_GRAPH mock_workflow.environment_variables = [] # Mock conversation and message @@ -349,7 +362,7 @@ class TestAdvancedChatAppRunnerConversationVariables: mock_workflow.app_id = app_id mock_workflow.id = workflow_id mock_workflow.type = "chat" - mock_workflow.graph_dict = {} + mock_workflow.graph_dict = MINIMAL_GRAPH mock_workflow.environment_variables = [] # Create existing conversation variables (both exist in DB) diff --git a/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_input_moderation.py b/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_input_moderation.py index 5792a2f1e2..079df0b4e6 100644 --- a/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_input_moderation.py +++ b/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_input_moderation.py @@ -8,6 +8,19 @@ from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, from core.app.entities.queue_entities import QueueStopEvent from core.moderation.base import ModerationError +MINIMAL_GRAPH = { + "nodes": [ + { + "id": "start", + "data": { + "type": "start", + "title": "Start", + }, + } + ], + "edges": [], +} + @pytest.fixture def build_runner(): @@ -30,7 +43,7 @@ def build_runner(): mock_workflow.tenant_id = str(uuid4()) mock_workflow.app_id = app_id mock_workflow.type = "chat" - mock_workflow.graph_dict = {} + mock_workflow.graph_dict = MINIMAL_GRAPH mock_workflow.environment_variables = [] mock_app_config = MagicMock() diff --git a/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline_core.py b/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline_core.py index 0a244b3fea..e10973f271 100644 --- a/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline_core.py +++ b/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline_core.py @@ -42,11 +42,12 @@ from core.app.entities.task_entities import ( PingStreamResponse, ) from core.base.tts.app_generator_tts_publisher import AudioTrunk +from core.workflow.system_variables import build_system_variables from dify_graph.enums import BuiltinNodeTypes from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable from models.enums import MessageStatus from models.model import AppMode, EndUser +from tests.workflow_test_utils import build_test_variable_pool def _make_pipeline(): @@ -166,7 +167,7 @@ class TestAdvancedChatGenerateTaskPipeline: def test_handle_workflow_started_event_sets_run_id(self, monkeypatch): pipeline = _make_pipeline() pipeline._graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool(system_variables=SystemVariable(workflow_execution_id="run-id")), + variable_pool=build_test_variable_pool(variables=build_system_variables(workflow_execution_id="run-id")), start_at=0.0, ) pipeline._workflow_response_converter.workflow_start_to_stream_response = lambda **kwargs: "started" @@ -311,7 +312,7 @@ class TestAdvancedChatGenerateTaskPipeline: pipeline = _make_pipeline() pipeline._workflow_run_id = "run-id" pipeline._graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool(system_variables=SystemVariable(workflow_execution_id="run-id")), + variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-id")), start_at=0.0, ) pipeline._workflow_response_converter.workflow_finish_to_stream_response = lambda **kwargs: "finish" @@ -522,7 +523,7 @@ class TestAdvancedChatGenerateTaskPipeline: self.items = items graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool(system_variables=SystemVariable(workflow_execution_id="run-id")), + variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-id")), start_at=0.0, ) @@ -556,7 +557,7 @@ class TestAdvancedChatGenerateTaskPipeline: def test_handle_message_end_event_applies_output_moderation(self, monkeypatch): pipeline = _make_pipeline() pipeline._graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool(system_variables=SystemVariable(workflow_execution_id="run-id")), + variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-id")), start_at=0.0, ) pipeline._base_task_pipeline.handle_output_moderation_when_task_finished = lambda answer: "safe" diff --git a/api/tests/unit_tests/core/app/apps/common/test_graph_runtime_state_support.py b/api/tests/unit_tests/core/app/apps/common/test_graph_runtime_state_support.py index b0789bbc1e..e796303583 100644 --- a/api/tests/unit_tests/core/app/apps/common/test_graph_runtime_state_support.py +++ b/api/tests/unit_tests/core/app/apps/common/test_graph_runtime_state_support.py @@ -3,13 +3,15 @@ from types import SimpleNamespace import pytest from core.app.apps.common.graph_runtime_state_support import GraphRuntimeStateSupport +from core.workflow.system_variables import build_system_variables +from core.workflow.variable_pool_initializer import add_variables_to_pool from dify_graph.runtime import GraphRuntimeState from dify_graph.runtime.variable_pool import VariablePool -from dify_graph.system_variable import SystemVariable def _make_state(workflow_run_id: str | None) -> GraphRuntimeState: - variable_pool = VariablePool(system_variables=SystemVariable(workflow_execution_id=workflow_run_id)) + variable_pool = VariablePool() + add_variables_to_pool(variable_pool, build_system_variables(workflow_execution_id=workflow_run_id)) return GraphRuntimeState(variable_pool=variable_pool, start_at=0.0) diff --git a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter.py b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter.py index 72430a3347..433e7c6b0f 100644 --- a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter.py +++ b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter.py @@ -12,7 +12,6 @@ class TestWorkflowResponseConverterFetchFilesFromVariableValue: """Create a test File object""" return File( id=file_id, - tenant_id="test_tenant", type=FileType.DOCUMENT, transfer_method=FileTransferMethod.LOCAL_FILE, related_id="related_123", @@ -223,7 +222,7 @@ class TestWorkflowResponseConverterFetchFilesFromVariableValue: assert len(result) == 1 file_dict = result[0] assert file_dict["id"] == "property_test" - assert file_dict["tenant_id"] == "test_tenant" + assert "tenant_id" not in file_dict assert file_dict["type"] == "document" assert file_dict["transfer_method"] == "local_file" assert file_dict["filename"] == "property_test.txt" diff --git a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_human_input.py b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_human_input.py index 4ed7d73cd0..643f863479 100644 --- a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_human_input.py +++ b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_human_input.py @@ -4,13 +4,13 @@ from types import SimpleNamespace from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.queue_entities import QueueHumanInputFormFilledEvent, QueueHumanInputFormTimeoutEvent +from core.workflow.system_variables import build_system_variables from dify_graph.entities.workflow_start_reason import WorkflowStartReason from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable def _build_converter(): - system_variables = SystemVariable( + system_variables = build_system_variables( files=[], user_id="user-1", app_id="app-1", diff --git a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_resumption.py b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_resumption.py index 5879e8fb9b..c1d7f4687e 100644 --- a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_resumption.py +++ b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_resumption.py @@ -2,14 +2,14 @@ from types import SimpleNamespace from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter from core.app.entities.app_invoke_entities import InvokeFrom +from core.workflow.system_variables import build_system_variables from dify_graph.entities.workflow_start_reason import WorkflowStartReason from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable def _build_converter() -> WorkflowResponseConverter: """Construct a minimal WorkflowResponseConverter for testing.""" - system_variables = SystemVariable( + system_variables = build_system_variables( files=[], user_id="user-1", app_id="app-1", diff --git a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_truncation.py b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_truncation.py index 374af5ddc4..4e406dc12a 100644 --- a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_truncation.py +++ b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_truncation.py @@ -24,9 +24,9 @@ from core.app.entities.queue_entities import ( QueueNodeStartedEvent, QueueNodeSucceededEvent, ) +from core.workflow.system_variables import build_system_variables from dify_graph.entities.workflow_start_reason import WorkflowStartReason from dify_graph.enums import BuiltinNodeTypes -from dify_graph.system_variable import SystemVariable from libs.datetime_utils import naive_utc_now from models import Account from models.model import AppMode @@ -54,7 +54,7 @@ class TestWorkflowResponseConverter: mock_user.name = "Test User" mock_user.email = "test@example.com" - system_variables = SystemVariable(workflow_id="wf-id", workflow_execution_id="initial-run-id") + system_variables = build_system_variables(workflow_id="wf-id", workflow_execution_id="initial-run-id") return WorkflowResponseConverter( application_generate_entity=mock_entity, user=mock_user, @@ -451,9 +451,9 @@ class TestWorkflowResponseConverterServiceApiTruncation: account.id = "test_user_id" return account - def create_test_system_variables(self) -> SystemVariable: + def create_test_system_variables(self): """Create test system variables.""" - return SystemVariable() + return build_system_variables() def create_test_converter(self, invoke_from: InvokeFrom) -> WorkflowResponseConverter: """Create WorkflowResponseConverter with specified invoke_from.""" diff --git a/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_runner.py b/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_runner.py index eec95b7f39..b2fba7a388 100644 --- a/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_runner.py +++ b/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_runner.py @@ -284,7 +284,12 @@ def test_run_normal_path_builds_graph(mocker): return_value=SimpleNamespace(belong_to_node_id="start", variable="input1"), ) mocker.patch.object(module, "RAGPipelineVariableInput", side_effect=lambda **kwargs: SimpleNamespace(**kwargs)) - mocker.patch.object(module, "VariablePool", side_effect=lambda **kwargs: SimpleNamespace(**kwargs)) + + class FakeVariablePool: + def add(self, selector, value): + return None + + mocker.patch.object(module, "VariablePool", return_value=FakeVariablePool()) workflow_entry = MagicMock() workflow_entry.graph_engine = MagicMock() diff --git a/api/tests/unit_tests/core/app/apps/test_base_app_generator.py b/api/tests/unit_tests/core/app/apps/test_base_app_generator.py index a3ced02394..d5f201bd12 100644 --- a/api/tests/unit_tests/core/app/apps/test_base_app_generator.py +++ b/api/tests/unit_tests/core/app/apps/test_base_app_generator.py @@ -1,5 +1,3 @@ -from unittest.mock import MagicMock - import pytest from core.app.apps.base_app_generator import BaseAppGenerator @@ -489,7 +487,6 @@ class TestBaseAppGeneratorExtras: factory = base_app_generator._get_draft_var_saver_factory(InvokeFrom.DEBUGGER, account) saver = factory( - session=MagicMock(), app_id="app-id", node_id="node-id", node_type=BuiltinNodeTypes.START, diff --git a/api/tests/unit_tests/core/app/apps/test_pause_resume.py b/api/tests/unit_tests/core/app/apps/test_pause_resume.py index 2f73a8cda8..410136728b 100644 --- a/api/tests/unit_tests/core/app/apps/test_pause_resume.py +++ b/api/tests/unit_tests/core/app/apps/test_pause_resume.py @@ -8,6 +8,7 @@ from core.app.apps.advanced_chat import app_generator as adv_app_gen_module from core.app.apps.workflow import app_generator as wf_app_gen_module from core.app.entities.app_invoke_entities import InvokeFrom from core.workflow.node_factory import DifyNodeFactory +from core.workflow.system_variables import build_system_variables from dify_graph.entities.base_node_data import BaseNodeData, RetryConfig from dify_graph.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter from dify_graph.entities.pause_reason import SchedulingPause @@ -29,7 +30,6 @@ from dify_graph.nodes.base.node import Node from dify_graph.nodes.end.entities import EndNodeData from dify_graph.nodes.start.entities import StartNodeData from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable from tests.workflow_test_utils import build_test_graph_init_params if "core.ops.ops_trace_manager" not in sys.modules: @@ -162,11 +162,11 @@ def _build_graph(runtime_state: GraphRuntimeState, *, pause_on: str | None) -> G def _build_runtime_state(run_id: str) -> GraphRuntimeState: variable_pool = VariablePool( - system_variables=SystemVariable(user_id="user", app_id="app", workflow_id="workflow"), + system_variables=build_system_variables(user_id="user", app_id="app", workflow_id="workflow"), user_inputs={}, conversation_variables=[], ) - variable_pool.system_variables.workflow_execution_id = run_id + variable_pool.add(("sys", "workflow_run_id"), run_id) return GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) diff --git a/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_core.py b/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_core.py index 3f1dd14569..992876d61a 100644 --- a/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_core.py +++ b/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_core.py @@ -16,6 +16,7 @@ from core.app.entities.queue_entities import ( QueueWorkflowStartedEvent, QueueWorkflowSucceededEvent, ) +from core.workflow.system_variables import default_system_variables from dify_graph.entities.pause_reason import HumanInputRequired from dify_graph.enums import BuiltinNodeTypes from dify_graph.graph_events import ( @@ -29,7 +30,6 @@ from dify_graph.graph_events import ( NodeRunStreamChunkEvent, ) from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable class TestWorkflowBasedAppRunner: @@ -44,7 +44,7 @@ class TestWorkflowBasedAppRunner: runner = WorkflowBasedAppRunner(queue_manager=SimpleNamespace(), app_id="app") runtime_state = GraphRuntimeState( - variable_pool=VariablePool(system_variables=SystemVariable.default()), + variable_pool=VariablePool(system_variables=default_system_variables()), start_at=0.0, ) @@ -83,7 +83,7 @@ class TestWorkflowBasedAppRunner: def test_get_graph_and_variable_pool_for_single_node_run(self, monkeypatch): runner = WorkflowBasedAppRunner(queue_manager=SimpleNamespace(), app_id="app") graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool(system_variables=SystemVariable.default()), + variable_pool=VariablePool(system_variables=default_system_variables()), start_at=0.0, ) @@ -140,7 +140,7 @@ class TestWorkflowBasedAppRunner: runner = WorkflowBasedAppRunner(queue_manager=_QueueManager(), app_id="app") graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool(system_variables=SystemVariable.default()), + variable_pool=VariablePool(system_variables=default_system_variables()), start_at=0.0, ) graph_runtime_state.register_paused_node("node-1") @@ -183,7 +183,7 @@ class TestWorkflowBasedAppRunner: runner = WorkflowBasedAppRunner(queue_manager=_QueueManager(), app_id="app") graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool(system_variables=SystemVariable.default()), + variable_pool=VariablePool(system_variables=default_system_variables()), start_at=0.0, ) workflow_entry = SimpleNamespace(graph_engine=SimpleNamespace(graph_runtime_state=graph_runtime_state)) diff --git a/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_single_node.py b/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_single_node.py index 178e26118e..1456262416 100644 --- a/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_single_node.py +++ b/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_single_node.py @@ -9,15 +9,15 @@ from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.apps.workflow.app_runner import WorkflowAppRunner from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity +from core.workflow.system_variables import default_system_variables from dify_graph.entities.graph_config import NodeConfigDictAdapter from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable from models.workflow import Workflow def _make_graph_state(): variable_pool = VariablePool( - system_variables=SystemVariable.default(), + system_variables=default_system_variables(), user_inputs={}, environment_variables=[], conversation_variables=[], diff --git a/api/tests/unit_tests/core/app/apps/test_workflow_pause_events.py b/api/tests/unit_tests/core/app/apps/test_workflow_pause_events.py index 65c6bd6654..7f26491e5b 100644 --- a/api/tests/unit_tests/core/app/apps/test_workflow_pause_events.py +++ b/api/tests/unit_tests/core/app/apps/test_workflow_pause_events.py @@ -10,12 +10,12 @@ from core.app.apps.workflow.app_runner import WorkflowAppRunner from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.queue_entities import QueueWorkflowPausedEvent from core.app.entities.task_entities import HumanInputRequiredResponse, WorkflowPauseStreamResponse +from core.workflow.system_variables import build_system_variables from dify_graph.entities.pause_reason import HumanInputRequired from dify_graph.entities.workflow_start_reason import WorkflowStartReason from dify_graph.graph_events.graph import GraphRunPausedEvent from dify_graph.nodes.human_input.entities import FormInput, UserAction from dify_graph.nodes.human_input.enums import FormInputType -from dify_graph.system_variable import SystemVariable from models.account import Account @@ -98,7 +98,7 @@ def _build_converter(): invoke_from=InvokeFrom.SERVICE_API, app_config=SimpleNamespace(app_id="app-id", tenant_id="tenant-id"), ) - system_variables = SystemVariable( + system_variables = build_system_variables( user_id="user", app_id="app-id", workflow_id="workflow-id", diff --git a/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline.py b/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline.py index 5b23e71035..51472eb236 100644 --- a/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline.py +++ b/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline.py @@ -7,11 +7,12 @@ from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity from core.app.entities.queue_entities import QueueWorkflowStartedEvent +from core.workflow.system_variables import build_system_variables from dify_graph.entities.workflow_start_reason import WorkflowStartReason -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable +from dify_graph.runtime import GraphRuntimeState from models.account import Account from models.model import AppMode +from tests.workflow_test_utils import build_test_variable_pool def _build_workflow_app_config() -> WorkflowUIBasedAppConfig: @@ -37,11 +38,7 @@ def _build_generate_entity(run_id: str) -> WorkflowAppGenerateEntity: def _build_runtime_state(run_id: str) -> GraphRuntimeState: - variable_pool = VariablePool( - system_variables=SystemVariable(workflow_execution_id=run_id), - user_inputs={}, - conversation_variables=[], - ) + variable_pool = build_test_variable_pool(variables=build_system_variables(workflow_execution_id=run_id)) return GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) diff --git a/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline_core.py b/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline_core.py index f35710d207..cb7eb1169c 100644 --- a/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline_core.py +++ b/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline_core.py @@ -44,11 +44,12 @@ from core.app.entities.task_entities import ( WorkflowStartStreamResponse, ) from core.base.tts.app_generator_tts_publisher import AudioTrunk +from core.workflow.system_variables import build_system_variables, system_variables_to_mapping from dify_graph.enums import BuiltinNodeTypes, WorkflowExecutionStatus from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable from models.enums import CreatorUserRole from models.model import AppMode, EndUser +from tests.workflow_test_utils import build_test_variable_pool def _make_pipeline(): @@ -164,7 +165,7 @@ class TestWorkflowGenerateTaskPipeline: def test_handle_workflow_started_event_sets_run_id(self, monkeypatch): pipeline = _make_pipeline() pipeline._graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool(system_variables=SystemVariable(workflow_execution_id="run-id")), + variable_pool=build_test_variable_pool(variables=build_system_variables(workflow_execution_id="run-id")), start_at=0.0, ) pipeline._workflow_response_converter.workflow_start_to_stream_response = lambda **kwargs: "started" @@ -205,7 +206,7 @@ class TestWorkflowGenerateTaskPipeline: pipeline = _make_pipeline() pipeline._workflow_execution_id = "run-id" pipeline._graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool(system_variables=SystemVariable(workflow_execution_id="run-id")), + variable_pool=build_test_variable_pool(variables=build_system_variables(workflow_execution_id="run-id")), start_at=0.0, ) pipeline._workflow_response_converter.workflow_finish_to_stream_response = lambda **kwargs: "finish" @@ -257,7 +258,7 @@ class TestWorkflowGenerateTaskPipeline: pipeline = _make_pipeline() pipeline._workflow_execution_id = "run-id" pipeline._graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool(system_variables=SystemVariable(workflow_execution_id="run-id")), + variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-id")), start_at=0.0, ) pipeline._workflow_response_converter.workflow_finish_to_stream_response = lambda **kwargs: "finish" @@ -451,7 +452,7 @@ class TestWorkflowGenerateTaskPipeline: ) assert pipeline._created_by_role == CreatorUserRole.END_USER - assert pipeline._workflow_system_variables.user_id == "session-id" + assert system_variables_to_mapping(pipeline._workflow_system_variables)["user_id"] == "session-id" def test_process_returns_stream_and_blocking_variants(self): pipeline = _make_pipeline() @@ -699,7 +700,7 @@ class TestWorkflowGenerateTaskPipeline: pipeline = _make_pipeline() pipeline._workflow_execution_id = "run-id" pipeline._graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool(system_variables=SystemVariable(workflow_execution_id="run-id")), + variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-id")), start_at=0.0, ) @@ -727,7 +728,7 @@ class TestWorkflowGenerateTaskPipeline: pipeline = _make_pipeline() pipeline._workflow_execution_id = "run-id" pipeline._graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool(system_variables=SystemVariable(workflow_execution_id="run-id")), + variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-id")), start_at=0.0, ) pipeline._handle_ping_event = lambda event, **kwargs: iter(["ping"]) @@ -743,7 +744,7 @@ class TestWorkflowGenerateTaskPipeline: def test_process_stream_response_main_match_paths_and_cleanup(self): pipeline = _make_pipeline() pipeline._graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool(system_variables=SystemVariable(workflow_execution_id="run-id")), + variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-id")), start_at=0.0, ) pipeline._base_task_pipeline.queue_manager.listen = lambda: iter( @@ -815,7 +816,7 @@ class TestWorkflowGenerateTaskPipeline: pipeline._save_workflow_app_log(session=_Session(), workflow_run_id=None) assert len(added) == count_before - def test_save_output_for_event_writes_draft_variables(self, monkeypatch): + def test_save_output_for_event_writes_draft_variables(self): pipeline = _make_pipeline() saver_calls: list[tuple[object, object]] = [] captured_factory_args: dict[str, object] = {} @@ -828,29 +829,7 @@ class TestWorkflowGenerateTaskPipeline: captured_factory_args.update(kwargs) return _Saver() - class _Begin: - def __enter__(self): - return None - - def __exit__(self, exc_type, exc, tb): - return False - - class _Session: - def __init__(self, *args, **kwargs): - _ = args, kwargs - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc, tb): - return False - - def begin(self): - return _Begin() - pipeline._draft_var_saver_factory = _factory - monkeypatch.setattr("core.app.apps.workflow.generate_task_pipeline.Session", _Session) - monkeypatch.setattr("core.app.apps.workflow.generate_task_pipeline.db", SimpleNamespace(engine=object())) event = QueueNodeSucceededEvent( node_execution_id="exec-id", diff --git a/api/tests/unit_tests/core/app/layers/test_conversation_variable_persist_layer.py b/api/tests/unit_tests/core/app/layers/test_conversation_variable_persist_layer.py index bdc889d941..ba009ece6b 100644 --- a/api/tests/unit_tests/core/app/layers/test_conversation_variable_persist_layer.py +++ b/api/tests/unit_tests/core/app/layers/test_conversation_variable_persist_layer.py @@ -3,16 +3,15 @@ from datetime import datetime from unittest.mock import Mock from core.app.layers.conversation_variable_persist_layer import ConversationVariablePersistenceLayer -from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID -from dify_graph.enums import BuiltinNodeTypes, NodeType, WorkflowNodeExecutionStatus +from core.workflow.system_variables import SystemVariableKey +from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID +from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus from dify_graph.graph_engine.protocols.command_channel import CommandChannel -from dify_graph.graph_events.node import NodeRunSucceededEvent +from dify_graph.graph_events.node import NodeRunSucceededEvent, NodeRunVariableUpdatedEvent from dify_graph.node_events import NodeRunResult -from dify_graph.nodes.variable_assigner.common import helpers as common_helpers from dify_graph.runtime.graph_runtime_state_protocol import ReadOnlyGraphRuntimeState -from dify_graph.system_variable import SystemVariable from dify_graph.variables import StringVariable -from dify_graph.variables.segments import Segment +from dify_graph.variables.segments import Segment, StringSegment class MockReadOnlyVariablePool: @@ -36,31 +35,38 @@ def _build_graph_runtime_state( conversation_id: str | None = None, ) -> ReadOnlyGraphRuntimeState: graph_runtime_state = Mock(spec=ReadOnlyGraphRuntimeState) + if conversation_id is not None: + variable_pool._variables[("sys", SystemVariableKey.CONVERSATION_ID.value)] = StringSegment( + value=conversation_id + ) graph_runtime_state.variable_pool = variable_pool - graph_runtime_state.system_variable = SystemVariable(conversation_id=conversation_id).as_view() return graph_runtime_state -def _build_node_run_succeeded_event( - *, - node_type: NodeType, - outputs: dict[str, object] | None = None, - process_data: dict[str, object] | None = None, -) -> NodeRunSucceededEvent: +def _build_node_run_succeeded_event() -> NodeRunSucceededEvent: return NodeRunSucceededEvent( id="node-exec-id", node_id="assigner", - node_type=node_type, + node_type=BuiltinNodeTypes.LLM, start_at=datetime.utcnow(), node_run_result=NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, - outputs=outputs or {}, - process_data=process_data or {}, + outputs={}, + process_data={}, ), ) -def test_persists_conversation_variables_from_assigner_output(): +def _build_variable_updated_event(variable: StringVariable) -> NodeRunVariableUpdatedEvent: + return NodeRunVariableUpdatedEvent( + id="node-exec-id", + node_id="assigner", + node_type=BuiltinNodeTypes.VARIABLE_ASSIGNER, + variable=variable, + ) + + +def test_persists_conversation_variables_from_variable_update_event(): conversation_id = "conv-123" variable = StringVariable( id="var-1", @@ -68,55 +74,26 @@ def test_persists_conversation_variables_from_assigner_output(): value="updated", selector=[CONVERSATION_VARIABLE_NODE_ID, "name"], ) - process_data = common_helpers.set_updated_variables( - {}, [common_helpers.variable_to_processed_data(variable.selector, variable)] - ) - - variable_pool = MockReadOnlyVariablePool({(CONVERSATION_VARIABLE_NODE_ID, "name"): variable}) - updater = Mock() layer = ConversationVariablePersistenceLayer(updater) - layer.initialize(_build_graph_runtime_state(variable_pool, conversation_id), Mock(spec=CommandChannel)) + layer.initialize(_build_graph_runtime_state(MockReadOnlyVariablePool(), conversation_id), Mock(spec=CommandChannel)) - event = _build_node_run_succeeded_event(node_type=BuiltinNodeTypes.VARIABLE_ASSIGNER, process_data=process_data) + event = _build_variable_updated_event(variable) layer.on_event(event) updater.update.assert_called_once_with(conversation_id=conversation_id, variable=variable) - updater.flush.assert_called_once() -def test_skips_when_outputs_missing(): +def test_skips_non_variable_update_events(): conversation_id = "conv-456" - variable = StringVariable( - id="var-2", - name="name", - value="updated", - selector=[CONVERSATION_VARIABLE_NODE_ID, "name"], - ) - - variable_pool = MockReadOnlyVariablePool({(CONVERSATION_VARIABLE_NODE_ID, "name"): variable}) - updater = Mock() layer = ConversationVariablePersistenceLayer(updater) - layer.initialize(_build_graph_runtime_state(variable_pool, conversation_id), Mock(spec=CommandChannel)) + layer.initialize(_build_graph_runtime_state(MockReadOnlyVariablePool(), conversation_id), Mock(spec=CommandChannel)) - event = _build_node_run_succeeded_event(node_type=BuiltinNodeTypes.VARIABLE_ASSIGNER) + event = _build_node_run_succeeded_event() layer.on_event(event) updater.update.assert_not_called() - updater.flush.assert_not_called() - - -def test_skips_non_assigner_nodes(): - updater = Mock() - layer = ConversationVariablePersistenceLayer(updater) - layer.initialize(_build_graph_runtime_state(MockReadOnlyVariablePool()), Mock(spec=CommandChannel)) - - event = _build_node_run_succeeded_event(node_type=BuiltinNodeTypes.LLM) - layer.on_event(event) - - updater.update.assert_not_called() - updater.flush.assert_not_called() def test_skips_non_conversation_variables(): @@ -127,18 +104,11 @@ def test_skips_non_conversation_variables(): value="updated", selector=["environment", "name"], ) - process_data = common_helpers.set_updated_variables( - {}, [common_helpers.variable_to_processed_data(non_conversation_variable.selector, non_conversation_variable)] - ) - - variable_pool = MockReadOnlyVariablePool() - updater = Mock() layer = ConversationVariablePersistenceLayer(updater) - layer.initialize(_build_graph_runtime_state(variable_pool, conversation_id), Mock(spec=CommandChannel)) + layer.initialize(_build_graph_runtime_state(MockReadOnlyVariablePool(), conversation_id), Mock(spec=CommandChannel)) - event = _build_node_run_succeeded_event(node_type=BuiltinNodeTypes.VARIABLE_ASSIGNER, process_data=process_data) + event = _build_variable_updated_event(non_conversation_variable) layer.on_event(event) updater.update.assert_not_called() - updater.flush.assert_not_called() diff --git a/api/tests/unit_tests/core/app/layers/test_pause_state_persist_layer.py b/api/tests/unit_tests/core/app/layers/test_pause_state_persist_layer.py index 035f0ee05c..efe952a54b 100644 --- a/api/tests/unit_tests/core/app/layers/test_pause_state_persist_layer.py +++ b/api/tests/unit_tests/core/app/layers/test_pause_state_persist_layer.py @@ -13,6 +13,7 @@ from core.app.layers.pause_state_persist_layer import ( _AdvancedChatAppGenerateEntityWrapper, _WorkflowGenerateEntityWrapper, ) +from core.workflow.system_variables import SystemVariableKey from dify_graph.entities.pause_reason import SchedulingPause from dify_graph.graph_engine.entities.commands import GraphEngineCommand from dify_graph.graph_engine.layers.base import GraphEngineLayerNotInitializedError @@ -51,17 +52,6 @@ class TestDataFactory: return GraphRunFailedEvent(error=error, exceptions_count=exceptions_count) -class MockSystemVariableReadOnlyView: - """Minimal read-only system variable view for testing.""" - - def __init__(self, workflow_execution_id: str | None = None) -> None: - self._workflow_execution_id = workflow_execution_id - - @property - def workflow_execution_id(self) -> str | None: - return self._workflow_execution_id - - class MockReadOnlyVariablePool: """Mock implementation of ReadOnlyVariablePool for testing.""" @@ -76,13 +66,14 @@ class MockReadOnlyVariablePool: return None mock_segment = Mock(spec=Segment) mock_segment.value = value + mock_segment.text = value if isinstance(value, str) else None return mock_segment def get_all_by_node(self, node_id: str) -> dict[str, object]: return {key: value for (nid, key), value in self._variables.items() if nid == node_id} def get_by_prefix(self, prefix: str) -> dict[str, object]: - return {f"{nid}.{key}": value for (nid, key), value in self._variables.items() if nid.startswith(prefix)} + return {key: value for (nid, key), value in self._variables.items() if nid == prefix} class MockReadOnlyGraphRuntimeState: @@ -105,12 +96,10 @@ class MockReadOnlyGraphRuntimeState: self._ready_queue_size = ready_queue_size self._exceptions_count = exceptions_count self._outputs = outputs or {} - self._variable_pool = MockReadOnlyVariablePool(variables) - self._system_variable = MockSystemVariableReadOnlyView(workflow_execution_id) - - @property - def system_variable(self) -> MockSystemVariableReadOnlyView: - return self._system_variable + resolved_variables = dict(variables or {}) + if workflow_execution_id is not None: + resolved_variables[("sys", SystemVariableKey.WORKFLOW_EXECUTION_ID.value)] = workflow_execution_id + self._variable_pool = MockReadOnlyVariablePool(resolved_variables) @property def variable_pool(self) -> ReadOnlyVariablePool: @@ -161,7 +150,9 @@ class MockReadOnlyGraphRuntimeState: "exceptions_count": self._exceptions_count, "outputs": self._outputs, "variables": {f"{k[0]}.{k[1]}": v for k, v in self._variable_pool._variables.items()}, - "workflow_execution_id": self._system_variable.workflow_execution_id, + "workflow_execution_id": self._variable_pool._variables.get( + ("sys", SystemVariableKey.WORKFLOW_EXECUTION_ID.value) + ), } ) diff --git a/api/tests/unit_tests/core/datasource/test_datasource_manager.py b/api/tests/unit_tests/core/datasource/test_datasource_manager.py index 72da2b4252..f0568c68e9 100644 --- a/api/tests/unit_tests/core/datasource/test_datasource_manager.py +++ b/api/tests/unit_tests/core/datasource/test_datasource_manager.py @@ -430,7 +430,6 @@ def test_stream_node_events_builds_file_and_variables_from_messages(mocker): mocker.patch("core.datasource.datasource_manager.session_factory.create_session", return_value=_Session()) mocker.patch("core.datasource.datasource_manager.get_file_type_by_mime_type", return_value=FileType.IMAGE) built = File( - tenant_id="t1", type=FileType.IMAGE, transfer_method=FileTransferMethod.TOOL_FILE, related_id="tool_file_1", @@ -531,7 +530,6 @@ def test_stream_node_events_online_drive_sets_variable_pool_file_and_outputs(moc mocker.patch.object(DatasourceManager, "stream_online_results", return_value=_gen_messages_text_only("ignored")) file_in = File( - tenant_id="t1", type=FileType.DOCUMENT, transfer_method=FileTransferMethod.TOOL_FILE, related_id="tf", diff --git a/api/tests/unit_tests/core/file/test_models.py b/api/tests/unit_tests/core/file/test_models.py index deebf41320..29f6ac6469 100644 --- a/api/tests/unit_tests/core/file/test_models.py +++ b/api/tests/unit_tests/core/file/test_models.py @@ -1,10 +1,11 @@ +import pytest + from dify_graph.file import File, FileTransferMethod, FileType def test_file(): file = File( id="test-file", - tenant_id="test-tenant-id", type=FileType.IMAGE, transfer_method=FileTransferMethod.TOOL_FILE, related_id="test-related-id", @@ -15,7 +16,6 @@ def test_file(): storage_key="test-storage-key", url="https://example.com/image.png", ) - assert file.tenant_id == "test-tenant-id" assert file.type == FileType.IMAGE assert file.transfer_method == FileTransferMethod.TOOL_FILE assert file.related_id == "test-related-id" @@ -25,8 +25,7 @@ def test_file(): assert file.size == 67 -def test_file_model_validate_with_legacy_fields(): - """Test `File` model can handle data containing compatibility fields.""" +def test_file_model_validate_rejects_removed_tenant_id(): data = { "id": "test-file", "tenant_id": "test-tenant-id", @@ -45,10 +44,5 @@ def test_file_model_validate_with_legacy_fields(): "datasource_file_id": "datasource-file-789", } - # Should be able to create `File` object without raising an exception - file = File.model_validate(data) - - # The File object does not have tool_file_id, upload_file_id, or datasource_file_id as attributes. - # Instead, check it does not expose unrecognized legacy fields (should raise on getattr). - for legacy_field in ("tool_file_id", "upload_file_id", "datasource_file_id"): - assert not hasattr(file, legacy_field) + with pytest.raises(TypeError, match="tenant_id"): + File.model_validate(data) diff --git a/api/tests/unit_tests/core/ops/aliyun_trace/test_aliyun_trace.py b/api/tests/unit_tests/core/ops/aliyun_trace/test_aliyun_trace.py index dfd61acfa7..98076c0b99 100644 --- a/api/tests/unit_tests/core/ops/aliyun_trace/test_aliyun_trace.py +++ b/api/tests/unit_tests/core/ops/aliyun_trace/test_aliyun_trace.py @@ -396,14 +396,14 @@ def test_get_workflow_node_executions_builds_repo_and_fetches( monkeypatch.setattr(aliyun_trace_module, "db", SimpleNamespace(engine="engine")) repo = MagicMock() - repo.get_by_workflow_run.return_value = ["node1"] + repo.get_by_workflow_execution.return_value = ["node1"] mock_factory = MagicMock() mock_factory.create_workflow_node_execution_repository.return_value = repo monkeypatch.setattr(aliyun_trace_module, "DifyCoreRepositoryFactory", mock_factory) result = trace_instance.get_workflow_node_executions(trace_info) assert result == ["node1"] - repo.get_by_workflow_run.assert_called_once_with(workflow_run_id=trace_info.workflow_run_id) + repo.get_by_workflow_execution.assert_called_once_with(workflow_execution_id=trace_info.workflow_run_id) def test_build_workflow_node_span_routes_llm_type(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch): diff --git a/api/tests/unit_tests/core/ops/arize_phoenix_trace/test_arize_phoenix_trace.py b/api/tests/unit_tests/core/ops/arize_phoenix_trace/test_arize_phoenix_trace.py index 1cee2f5b68..4ce9e22fd7 100644 --- a/api/tests/unit_tests/core/ops/arize_phoenix_trace/test_arize_phoenix_trace.py +++ b/api/tests/unit_tests/core/ops/arize_phoenix_trace/test_arize_phoenix_trace.py @@ -254,7 +254,7 @@ def test_workflow_trace_full(mock_db, mock_repo_factory, mock_sessionmaker, trac node1.id = "n1" node1.error = None - repo.get_by_workflow_run.return_value = [node1] + repo.get_by_workflow_execution.return_value = [node1] with patch.object(trace_instance, "get_service_account_with_tenant"): trace_instance.workflow_trace(info) diff --git a/api/tests/unit_tests/core/ops/langfuse_trace/test_langfuse_trace.py b/api/tests/unit_tests/core/ops/langfuse_trace/test_langfuse_trace.py index 0ff135562c..33adfc139b 100644 --- a/api/tests/unit_tests/core/ops/langfuse_trace/test_langfuse_trace.py +++ b/api/tests/unit_tests/core/ops/langfuse_trace/test_langfuse_trace.py @@ -174,7 +174,7 @@ def test_workflow_trace_with_message_id(trace_instance, monkeypatch): node_other.metadata = None repo = MagicMock() - repo.get_by_workflow_run.return_value = [node_llm, node_other] + repo.get_by_workflow_execution.return_value = [node_llm, node_other] mock_factory = MagicMock() mock_factory.create_workflow_node_execution_repository.return_value = repo @@ -244,7 +244,7 @@ def test_workflow_trace_no_message_id(trace_instance, monkeypatch): monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.sessionmaker", lambda bind: lambda: MagicMock()) monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.db", MagicMock(engine="engine")) repo = MagicMock() - repo.get_by_workflow_run.return_value = [] + repo.get_by_workflow_execution.return_value = [] mock_factory = MagicMock() mock_factory.create_workflow_node_execution_repository.return_value = repo monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.DifyCoreRepositoryFactory", mock_factory) @@ -680,7 +680,7 @@ def test_workflow_trace_handles_usage_extraction_error(trace_instance, monkeypat node.outputs = {} repo = MagicMock() - repo.get_by_workflow_run.return_value = [node] + repo.get_by_workflow_execution.return_value = [node] mock_factory = MagicMock() mock_factory.create_workflow_node_execution_repository.return_value = repo monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.DifyCoreRepositoryFactory", mock_factory) diff --git a/api/tests/unit_tests/core/ops/langsmith_trace/test_langsmith_trace.py b/api/tests/unit_tests/core/ops/langsmith_trace/test_langsmith_trace.py index f656f7435f..88706b88a4 100644 --- a/api/tests/unit_tests/core/ops/langsmith_trace/test_langsmith_trace.py +++ b/api/tests/unit_tests/core/ops/langsmith_trace/test_langsmith_trace.py @@ -184,7 +184,7 @@ def test_workflow_trace(trace_instance, monkeypatch): node_retrieval.metadata = {} repo = MagicMock() - repo.get_by_workflow_run.return_value = [node_llm, node_other, node_retrieval] + repo.get_by_workflow_execution.return_value = [node_llm, node_other, node_retrieval] mock_factory = MagicMock() mock_factory.create_workflow_node_execution_repository.return_value = repo @@ -255,7 +255,7 @@ def test_workflow_trace_no_start_time(trace_instance, monkeypatch): monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.sessionmaker", lambda bind: lambda: mock_session) monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.db", MagicMock(engine="engine")) repo = MagicMock() - repo.get_by_workflow_run.return_value = [] + repo.get_by_workflow_execution.return_value = [] mock_factory = MagicMock() mock_factory.create_workflow_node_execution_repository.return_value = repo monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.DifyCoreRepositoryFactory", mock_factory) @@ -565,7 +565,7 @@ def test_workflow_trace_usage_extraction_error(trace_instance, monkeypatch, capl node_llm.metadata = {} repo = MagicMock() - repo.get_by_workflow_run.return_value = [node_llm] + repo.get_by_workflow_execution.return_value = [node_llm] mock_factory = MagicMock() mock_factory.create_workflow_node_execution_repository.return_value = repo diff --git a/api/tests/unit_tests/core/ops/opik_trace/test_opik_trace.py b/api/tests/unit_tests/core/ops/opik_trace/test_opik_trace.py index b2cb7d5109..575f2b1109 100644 --- a/api/tests/unit_tests/core/ops/opik_trace/test_opik_trace.py +++ b/api/tests/unit_tests/core/ops/opik_trace/test_opik_trace.py @@ -199,7 +199,7 @@ def test_workflow_trace_with_message_id(trace_instance, monkeypatch): node_other.metadata = {WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS.value: 10} repo = MagicMock() - repo.get_by_workflow_run.return_value = [node_llm, node_other] + repo.get_by_workflow_execution.return_value = [node_llm, node_other] mock_factory = MagicMock() mock_factory.create_workflow_node_execution_repository.return_value = repo @@ -253,7 +253,7 @@ def test_workflow_trace_no_message_id(trace_instance, monkeypatch): monkeypatch.setattr("core.ops.opik_trace.opik_trace.sessionmaker", lambda bind: lambda: MagicMock()) monkeypatch.setattr("core.ops.opik_trace.opik_trace.db", MagicMock(engine="engine")) repo = MagicMock() - repo.get_by_workflow_run.return_value = [] + repo.get_by_workflow_execution.return_value = [] mock_factory = MagicMock() mock_factory.create_workflow_node_execution_repository.return_value = repo monkeypatch.setattr("core.ops.opik_trace.opik_trace.DifyCoreRepositoryFactory", mock_factory) @@ -657,7 +657,7 @@ def test_workflow_trace_usage_extraction_error_fixed(trace_instance, monkeypatch node.outputs = {} repo = MagicMock() - repo.get_by_workflow_run.return_value = [node] + repo.get_by_workflow_execution.return_value = [node] mock_factory = MagicMock() mock_factory.create_workflow_node_execution_repository.return_value = repo monkeypatch.setattr("core.ops.opik_trace.opik_trace.DifyCoreRepositoryFactory", mock_factory) diff --git a/api/tests/unit_tests/core/ops/tencent_trace/test_tencent_trace.py b/api/tests/unit_tests/core/ops/tencent_trace/test_tencent_trace.py index f259e4639f..7fdfdc592a 100644 --- a/api/tests/unit_tests/core/ops/tencent_trace/test_tencent_trace.py +++ b/api/tests/unit_tests/core/ops/tencent_trace/test_tencent_trace.py @@ -413,7 +413,7 @@ class TestTencentDataTrace: with patch( "core.ops.tencent_trace.tencent_trace.SQLAlchemyWorkflowNodeExecutionRepository" ) as mock_repo: - mock_repo.return_value.get_by_workflow_run.return_value = mock_executions + mock_repo.return_value.get_by_workflow_execution.return_value = mock_executions results = tencent_data_trace._get_workflow_node_executions(trace_info) diff --git a/api/tests/unit_tests/core/ops/test_opik_trace.py b/api/tests/unit_tests/core/ops/test_opik_trace.py index 7660967183..ad9d0846be 100644 --- a/api/tests/unit_tests/core/ops/test_opik_trace.py +++ b/api/tests/unit_tests/core/ops/test_opik_trace.py @@ -130,7 +130,7 @@ class TestWorkflowTraceWithoutMessageId: def _run(self, trace_info: WorkflowTraceInfo, node_executions: list | None = None): instance = _make_opik_trace_instance() fake_repo = MagicMock() - fake_repo.get_by_workflow_run.return_value = node_executions or [] + fake_repo.get_by_workflow_execution.return_value = node_executions or [] with ( patch("core.ops.opik_trace.opik_trace.db") as mock_db, @@ -262,7 +262,7 @@ class TestWorkflowTraceWithMessageId: def _run(self, trace_info: WorkflowTraceInfo, node_executions: list | None = None): instance = _make_opik_trace_instance() fake_repo = MagicMock() - fake_repo.get_by_workflow_run.return_value = node_executions or [] + fake_repo.get_by_workflow_execution.return_value = node_executions or [] with ( patch("core.ops.opik_trace.opik_trace.db") as mock_db, diff --git a/api/tests/unit_tests/core/ops/weave_trace/test_weave_trace.py b/api/tests/unit_tests/core/ops/weave_trace/test_weave_trace.py index 8057bbbad5..221ba49d3b 100644 --- a/api/tests/unit_tests/core/ops/weave_trace/test_weave_trace.py +++ b/api/tests/unit_tests/core/ops/weave_trace/test_weave_trace.py @@ -589,7 +589,7 @@ class TestWorkflowTrace: nodes = [] repo = MagicMock() - repo.get_by_workflow_run.return_value = nodes + repo.get_by_workflow_execution.return_value = nodes mock_factory = MagicMock() mock_factory.create_workflow_node_execution_repository.return_value = repo diff --git a/api/tests/unit_tests/core/plugin/utils/test_chunk_merger.py b/api/tests/unit_tests/core/plugin/utils/test_chunk_merger.py index c7e94aa4cf..fd1ba33890 100644 --- a/api/tests/unit_tests/core/plugin/utils/test_chunk_merger.py +++ b/api/tests/unit_tests/core/plugin/utils/test_chunk_merger.py @@ -466,7 +466,6 @@ class TestChunkMerger: class TestConverter: def test_convert_parameters_to_plugin_format_with_single_file_and_selector(self): file_param = File( - tenant_id="tenant-1", type=FileType.IMAGE, transfer_method=FileTransferMethod.REMOTE_URL, remote_url="https://example.com/file.png", @@ -499,14 +498,12 @@ class TestConverter: def test_convert_parameters_to_plugin_format_with_lists_and_passthrough_values(self): file_one = File( - tenant_id="tenant-1", type=FileType.DOCUMENT, transfer_method=FileTransferMethod.REMOTE_URL, remote_url="https://example.com/a.txt", storage_key="", ) file_two = File( - tenant_id="tenant-1", type=FileType.DOCUMENT, transfer_method=FileTransferMethod.REMOTE_URL, remote_url="https://example.com/b.txt", diff --git a/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py index 3d08525aba..3be62a444e 100644 --- a/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py +++ b/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py @@ -135,7 +135,6 @@ def test__get_chat_model_prompt_messages_with_files_no_memory(get_chat_model_arg files = [ File( id="file1", - tenant_id="tenant1", type=FileType.IMAGE, transfer_method=FileTransferMethod.REMOTE_URL, remote_url="https://example.com/image1.jpg", @@ -246,7 +245,6 @@ def test_completion_prompt_jinja2_with_files(): file = File( id="file1", - tenant_id="tenant1", type=FileType.IMAGE, transfer_method=FileTransferMethod.REMOTE_URL, remote_url="https://example.com/image.jpg", @@ -380,7 +378,6 @@ def test_chat_prompt_memory_with_files_and_query(): prompt_template = [ChatModelMessage(text="sys", role=PromptMessageRole.SYSTEM)] file = File( id="file1", - tenant_id="tenant1", type=FileType.IMAGE, transfer_method=FileTransferMethod.REMOTE_URL, remote_url="https://example.com/image.jpg", @@ -414,7 +411,6 @@ def test_chat_prompt_files_without_query_updates_last_user_or_appends_new(): model_config_mock = MagicMock(spec=ModelConfigEntity) file = File( id="file1", - tenant_id="tenant1", type=FileType.IMAGE, transfer_method=FileTransferMethod.REMOTE_URL, remote_url="https://example.com/image.jpg", @@ -464,7 +460,6 @@ def test_chat_prompt_files_with_query_branch(): model_config_mock = MagicMock(spec=ModelConfigEntity) file = File( id="file1", - tenant_id="tenant1", type=FileType.IMAGE, transfer_method=FileTransferMethod.REMOTE_URL, remote_url="https://example.com/image.jpg", diff --git a/api/tests/unit_tests/core/repositories/test_celery_workflow_node_execution_repository.py b/api/tests/unit_tests/core/repositories/test_celery_workflow_node_execution_repository.py index 2a83a4e802..00814e1bc5 100644 --- a/api/tests/unit_tests/core/repositories/test_celery_workflow_node_execution_repository.py +++ b/api/tests/unit_tests/core/repositories/test_celery_workflow_node_execution_repository.py @@ -11,12 +11,12 @@ from uuid import uuid4 import pytest from core.repositories.celery_workflow_node_execution_repository import CeleryWorkflowNodeExecutionRepository +from core.repositories.factory import OrderConfig from dify_graph.entities.workflow_node_execution import ( WorkflowNodeExecution, WorkflowNodeExecutionStatus, ) from dify_graph.enums import BuiltinNodeTypes -from dify_graph.repositories.workflow_node_execution_repository import OrderConfig from libs.datetime_utils import naive_utc_now from models import Account, EndUser from models.workflow import WorkflowNodeExecutionTriggeredFrom @@ -181,10 +181,10 @@ class TestCeleryWorkflowNodeExecutionRepository: repo.save(sample_workflow_node_execution) @patch("core.repositories.celery_workflow_node_execution_repository.save_workflow_node_execution_task") - def test_get_by_workflow_run_from_cache( + def test_get_by_workflow_execution_from_cache( self, mock_task, mock_session_factory, mock_account, sample_workflow_node_execution ): - """Test that get_by_workflow_run retrieves executions from cache.""" + """Test that get_by_workflow_execution retrieves executions from cache.""" repo = CeleryWorkflowNodeExecutionRepository( session_factory=mock_session_factory, user=mock_account, @@ -195,18 +195,18 @@ class TestCeleryWorkflowNodeExecutionRepository: # Save execution to cache first repo.save(sample_workflow_node_execution) - workflow_run_id = sample_workflow_node_execution.workflow_execution_id + workflow_execution_id = sample_workflow_node_execution.workflow_execution_id order_config = OrderConfig(order_by=["index"], order_direction="asc") - result = repo.get_by_workflow_run(workflow_run_id, order_config) + result = repo.get_by_workflow_execution(workflow_execution_id, order_config) # Verify results were retrieved from cache assert len(result) == 1 assert result[0].id == sample_workflow_node_execution.id assert result[0] is sample_workflow_node_execution - def test_get_by_workflow_run_without_order_config(self, mock_session_factory, mock_account): - """Test get_by_workflow_run without order configuration.""" + def test_get_by_workflow_execution_without_order_config(self, mock_session_factory, mock_account): + """Test get_by_workflow_execution without order configuration.""" repo = CeleryWorkflowNodeExecutionRepository( session_factory=mock_session_factory, user=mock_account, @@ -214,7 +214,7 @@ class TestCeleryWorkflowNodeExecutionRepository: triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, ) - result = repo.get_by_workflow_run("workflow-run-id") + result = repo.get_by_workflow_execution("workflow-run-id") # Should return empty list since nothing in cache assert len(result) == 0 @@ -236,7 +236,7 @@ class TestCeleryWorkflowNodeExecutionRepository: assert sample_workflow_node_execution.id in repo._execution_cache # Test retrieving from cache - result = repo.get_by_workflow_run(sample_workflow_node_execution.workflow_execution_id) + result = repo.get_by_workflow_execution(sample_workflow_node_execution.workflow_execution_id) assert len(result) == 1 assert result[0].id == sample_workflow_node_execution.id @@ -251,12 +251,12 @@ class TestCeleryWorkflowNodeExecutionRepository: ) # Create multiple executions for the same workflow - workflow_run_id = str(uuid4()) + workflow_execution_id = str(uuid4()) exec1 = WorkflowNodeExecution( id=str(uuid4()), node_execution_id=str(uuid4()), workflow_id=str(uuid4()), - workflow_execution_id=workflow_run_id, + workflow_execution_id=workflow_execution_id, index=1, node_id="node1", node_type=BuiltinNodeTypes.START, @@ -269,7 +269,7 @@ class TestCeleryWorkflowNodeExecutionRepository: id=str(uuid4()), node_execution_id=str(uuid4()), workflow_id=str(uuid4()), - workflow_execution_id=workflow_run_id, + workflow_execution_id=workflow_execution_id, index=2, node_id="node2", node_type=BuiltinNodeTypes.LLM, @@ -285,10 +285,10 @@ class TestCeleryWorkflowNodeExecutionRepository: # Verify both are cached and mapped assert len(repo._execution_cache) == 2 - assert len(repo._workflow_execution_mapping[workflow_run_id]) == 2 + assert len(repo._workflow_execution_mapping[workflow_execution_id]) == 2 # Test retrieval - result = repo.get_by_workflow_run(workflow_run_id) + result = repo.get_by_workflow_execution(workflow_execution_id) assert len(result) == 2 @patch("core.repositories.celery_workflow_node_execution_repository.save_workflow_node_execution_task") @@ -302,12 +302,12 @@ class TestCeleryWorkflowNodeExecutionRepository: ) # Create executions with different indices - workflow_run_id = str(uuid4()) + workflow_execution_id = str(uuid4()) exec1 = WorkflowNodeExecution( id=str(uuid4()), node_execution_id=str(uuid4()), workflow_id=str(uuid4()), - workflow_execution_id=workflow_run_id, + workflow_execution_id=workflow_execution_id, index=2, node_id="node2", node_type=BuiltinNodeTypes.START, @@ -320,7 +320,7 @@ class TestCeleryWorkflowNodeExecutionRepository: id=str(uuid4()), node_execution_id=str(uuid4()), workflow_id=str(uuid4()), - workflow_execution_id=workflow_run_id, + workflow_execution_id=workflow_execution_id, index=1, node_id="node1", node_type=BuiltinNodeTypes.LLM, @@ -336,14 +336,14 @@ class TestCeleryWorkflowNodeExecutionRepository: # Test ascending order order_config = OrderConfig(order_by=["index"], order_direction="asc") - result = repo.get_by_workflow_run(workflow_run_id, order_config) + result = repo.get_by_workflow_execution(workflow_execution_id, order_config) assert len(result) == 2 assert result[0].index == 1 assert result[1].index == 2 # Test descending order order_config = OrderConfig(order_by=["index"], order_direction="desc") - result = repo.get_by_workflow_run(workflow_run_id, order_config) + result = repo.get_by_workflow_execution(workflow_execution_id, order_config) assert len(result) == 2 assert result[0].index == 2 assert result[1].index == 1 diff --git a/api/tests/unit_tests/core/repositories/test_factory.py b/api/tests/unit_tests/core/repositories/test_factory.py index fe9eed0307..48327c3913 100644 --- a/api/tests/unit_tests/core/repositories/test_factory.py +++ b/api/tests/unit_tests/core/repositories/test_factory.py @@ -11,9 +11,12 @@ import pytest from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker -from core.repositories.factory import DifyCoreRepositoryFactory, RepositoryImportError -from dify_graph.repositories.workflow_execution_repository import WorkflowExecutionRepository -from dify_graph.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository +from core.repositories.factory import ( + DifyCoreRepositoryFactory, + RepositoryImportError, + WorkflowExecutionRepository, + WorkflowNodeExecutionRepository, +) from libs.module_loading import import_string from models import Account, EndUser from models.enums import WorkflowRunTriggeredFrom diff --git a/api/tests/unit_tests/core/repositories/test_human_input_form_repository_impl.py b/api/tests/unit_tests/core/repositories/test_human_input_form_repository_impl.py index 9af4d12664..7dde4c5c77 100644 --- a/api/tests/unit_tests/core/repositories/test_human_input_form_repository_impl.py +++ b/api/tests/unit_tests/core/repositories/test_human_input_form_repository_impl.py @@ -14,13 +14,15 @@ from core.repositories.human_input_repository import ( HumanInputFormSubmissionRepository, _WorkspaceMemberInfo, ) -from dify_graph.nodes.human_input.entities import ( +from core.workflow.human_input_compat import ( EmailDeliveryConfig, EmailDeliveryMethod, EmailRecipients, ExternalRecipient, - FormDefinition, MemberRecipient, +) +from dify_graph.nodes.human_input.entities import ( + FormDefinition, UserAction, ) from dify_graph.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus @@ -89,9 +91,9 @@ class TestHumanInputFormRepositoryImplHelpers: form_id="form-id", delivery_id="delivery-id", recipients_config=EmailRecipients( - whole_workspace=False, + include_bound_group=False, items=[ - MemberRecipient(user_id="member-1"), + MemberRecipient(reference_id="member-1"), ExternalRecipient(email="external@example.com"), ], ), @@ -125,9 +127,9 @@ class TestHumanInputFormRepositoryImplHelpers: form_id="form-id", delivery_id="delivery-id", recipients_config=EmailRecipients( - whole_workspace=False, + include_bound_group=False, items=[ - MemberRecipient(user_id="missing-member"), + MemberRecipient(reference_id="missing-member"), ExternalRecipient(email="external@example.com"), ], ), @@ -156,7 +158,7 @@ class TestHumanInputFormRepositoryImplHelpers: form_id="form-id", delivery_id="delivery-id", recipients_config=EmailRecipients( - whole_workspace=True, + include_bound_group=True, items=[], ), ) @@ -182,7 +184,7 @@ class TestHumanInputFormRepositoryImplHelpers: form_id="form-id", delivery_id="delivery-id", recipients_config=EmailRecipients( - whole_workspace=False, + include_bound_group=False, items=[ ExternalRecipient(email="external@example.com"), ExternalRecipient(email="external@example.com"), @@ -212,9 +214,9 @@ class TestHumanInputFormRepositoryImplHelpers: form_id="form-id", delivery_id="delivery-id", recipients_config=EmailRecipients( - whole_workspace=False, + include_bound_group=False, items=[ - MemberRecipient(user_id="member-1"), + MemberRecipient(reference_id="member-1"), ExternalRecipient(email="shared@example.com"), ], ), @@ -243,7 +245,7 @@ class TestHumanInputFormRepositoryImplHelpers: method = EmailDeliveryMethod( config=EmailDeliveryConfig( recipients=EmailRecipients( - whole_workspace=True, + include_bound_group=True, items=[ExternalRecipient(email="external@example.com")], ), subject="subject", @@ -421,22 +423,22 @@ class TestHumanInputFormRepositoryImplPublicMethods: ) session = _FakeSession(scalars_results=[form, [recipient]]) _patch_repo_session_factory(monkeypatch, session) - repo = HumanInputFormRepositoryImpl(tenant_id="tenant-id") + repo = HumanInputFormRepositoryImpl(tenant_id="tenant-id", workflow_execution_id=form.workflow_run_id) - entity = repo.get_form(form.workflow_run_id, form.node_id) + entity = repo.get_form(form.node_id) assert entity is not None assert entity.id == form.id - assert entity.web_app_token == "token-123" + assert entity.submission_token == "token-123" assert len(entity.recipients) == 1 assert entity.recipients[0].token == "token-123" def test_get_form_returns_none_when_missing(self, monkeypatch: pytest.MonkeyPatch): session = _FakeSession(scalars_results=[None]) _patch_repo_session_factory(monkeypatch, session) - repo = HumanInputFormRepositoryImpl(tenant_id="tenant-id") + repo = HumanInputFormRepositoryImpl(tenant_id="tenant-id", workflow_execution_id="run-1") - assert repo.get_form("run-1", "node-1") is None + assert repo.get_form("node-1") is None def test_get_form_returns_unsubmitted_state(self, monkeypatch: pytest.MonkeyPatch): form = _DummyForm( @@ -451,9 +453,9 @@ class TestHumanInputFormRepositoryImplPublicMethods: ) session = _FakeSession(scalars_results=[form, []]) _patch_repo_session_factory(monkeypatch, session) - repo = HumanInputFormRepositoryImpl(tenant_id="tenant-id") + repo = HumanInputFormRepositoryImpl(tenant_id="tenant-id", workflow_execution_id=form.workflow_run_id) - entity = repo.get_form(form.workflow_run_id, form.node_id) + entity = repo.get_form(form.node_id) assert entity is not None assert entity.submitted is False @@ -476,9 +478,9 @@ class TestHumanInputFormRepositoryImplPublicMethods: ) session = _FakeSession(scalars_results=[form, []]) _patch_repo_session_factory(monkeypatch, session) - repo = HumanInputFormRepositoryImpl(tenant_id="tenant-id") + repo = HumanInputFormRepositoryImpl(tenant_id="tenant-id", workflow_execution_id=form.workflow_run_id) - entity = repo.get_form(form.workflow_run_id, form.node_id) + entity = repo.get_form(form.node_id) assert entity is not None assert entity.submitted is True diff --git a/api/tests/unit_tests/core/test_file.py b/api/tests/unit_tests/core/test_file.py index 251d6fd25e..a5bbe374b1 100644 --- a/api/tests/unit_tests/core/test_file.py +++ b/api/tests/unit_tests/core/test_file.py @@ -7,7 +7,6 @@ from models.workflow import Workflow def test_file_to_dict(): file = File( id="file1", - tenant_id="tenant1", type=FileType.IMAGE, transfer_method=FileTransferMethod.REMOTE_URL, remote_url="https://example.com/image1.jpg", diff --git a/api/tests/unit_tests/core/tools/test_tool_file_manager.py b/api/tests/unit_tests/core/tools/test_tool_file_manager.py index cca8254dd6..b94834d8bc 100644 --- a/api/tests/unit_tests/core/tools/test_tool_file_manager.py +++ b/api/tests/unit_tests/core/tools/test_tool_file_manager.py @@ -14,6 +14,7 @@ import httpx import pytest from core.tools.tool_file_manager import ToolFileManager +from dify_graph.file import FileTransferMethod def _setup_tool_file_signing(monkeypatch: pytest.MonkeyPatch) -> dict[str, str]: @@ -232,7 +233,14 @@ def test_get_file_generator_returns_none_when_toolfile_missing() -> None: def test_get_file_generator_returns_stream_when_found() -> None: # Arrange manager = ToolFileManager() - tool_file = SimpleNamespace(file_key="k2", mimetype="image/png") + tool_file = SimpleNamespace( + id="tool123", + file_key="k2", + mimetype="image/png", + original_url=None, + name="image.png", + size=12, + ) session = Mock() session.query.return_value.where.return_value.first.return_value = tool_file @@ -240,10 +248,10 @@ def test_get_file_generator_returns_stream_when_found() -> None: with patch("core.tools.tool_file_manager.storage") as storage: stream = iter([b"a", b"b"]) storage.load_stream.return_value = stream - with ( - _patch_session_factory(session), - patch("core.tools.tool_file_manager.ToolFilePydanticModel.model_validate", return_value="validated-file"), - ): + with _patch_session_factory(session): result_stream, result_file = manager.get_file_generator_by_tool_file_id("tool123") assert list(result_stream) == [b"a", b"b"] - assert result_file == "validated-file" + assert result_file is not None + assert result_file.related_id == "tool123" + assert result_file.mime_type == "image/png" + assert result_file.transfer_method == FileTransferMethod.TOOL_FILE diff --git a/api/tests/unit_tests/core/tools/utils/test_message_transformer.py b/api/tests/unit_tests/core/tools/utils/test_message_transformer.py index af3cdddd5f..6454a5bcd1 100644 --- a/api/tests/unit_tests/core/tools/utils/test_message_transformer.py +++ b/api/tests/unit_tests/core/tools/utils/test_message_transformer.py @@ -84,3 +84,24 @@ def test_transform_tool_invoke_messages_mimetype_key_present_but_none(): # meta is preserved (still contains mime_type: None) assert "mime_type" in (o.meta or {}) assert o.meta["mime_type"] is None + assert o.meta["tool_file_id"] == "fake-tool-file-id" + + +def test_transform_tool_invoke_messages_parses_existing_tool_file_link_meta(): + msg = ToolInvokeMessage( + type=ToolInvokeMessage.MessageType.IMAGE_LINK, + message=ToolInvokeMessage.TextMessage(text="/files/tools/existing-tool-file.png"), + meta={}, + ) + + out = list( + mt.ToolFileMessageTransformer.transform_tool_invoke_messages( + messages=_gen([msg]), + user_id="u1", + tenant_id="t1", + conversation_id="c1", + ) + ) + + assert len(out) == 1 + assert out[0].meta["tool_file_id"] == "existing-tool-file" diff --git a/api/tests/unit_tests/core/variables/test_segment.py b/api/tests/unit_tests/core/variables/test_segment.py index 91259c9a45..0c7e5899de 100644 --- a/api/tests/unit_tests/core/variables/test_segment.py +++ b/api/tests/unit_tests/core/variables/test_segment.py @@ -5,9 +5,10 @@ import pytest from pydantic import BaseModel from core.helper import encrypter +from core.workflow.system_variables import build_bootstrap_variables, build_system_variables +from core.workflow.variable_pool_initializer import add_variables_to_pool from dify_graph.file import File, FileTransferMethod, FileType from dify_graph.runtime import VariablePool -from dify_graph.system_variable import SystemVariable from dify_graph.variables.segment_group import SegmentGroup from dify_graph.variables.segments import ( ArrayAnySegment, @@ -48,14 +49,28 @@ from dify_graph.variables.variables import ( ) +def _build_variable_pool( + *, + system_variables: list[Variable] | None = None, + environment_variables: list[Variable] | None = None, +) -> VariablePool: + variable_pool = VariablePool() + add_variables_to_pool( + variable_pool, + build_bootstrap_variables( + system_variables=system_variables or [], + environment_variables=environment_variables or [], + ), + ) + return variable_pool + + def test_segment_group_to_text(): - variable_pool = VariablePool( - system_variables=SystemVariable(user_id="fake-user-id"), - user_inputs={}, + variable_pool = _build_variable_pool( + system_variables=build_system_variables(user_id="fake-user-id"), environment_variables=[ SecretVariable(name="secret_key", value="fake-secret-key"), ], - conversation_variables=[], ) variable_pool.add(("node_id", "custom_query"), "fake-user-query") template = ( @@ -71,11 +86,8 @@ def test_segment_group_to_text(): def test_convert_constant_to_segment_group(): - variable_pool = VariablePool( - system_variables=SystemVariable(user_id="1", app_id="1", workflow_id="1"), - user_inputs={}, - environment_variables=[], - conversation_variables=[], + variable_pool = _build_variable_pool( + system_variables=build_system_variables(user_id="1", app_id="1", workflow_id="1"), ) template = "Hello, world!" segments_group = variable_pool.convert_template(template) @@ -84,12 +96,7 @@ def test_convert_constant_to_segment_group(): def test_convert_variable_to_segment_group(): - variable_pool = VariablePool( - system_variables=SystemVariable(user_id="fake-user-id"), - user_inputs={}, - environment_variables=[], - conversation_variables=[], - ) + variable_pool = _build_variable_pool(system_variables=build_system_variables(user_id="fake-user-id")) template = "{{#sys.user_id#}}" segments_group = variable_pool.convert_template(template) assert segments_group.text == "fake-user-id" @@ -116,7 +123,6 @@ def create_test_file( ) -> File: """Factory function to create File objects for testing""" return File( - tenant_id="test-tenant", type=file_type, transfer_method=transfer_method, filename=filename, @@ -190,7 +196,6 @@ class TestSegmentDumpAndLoad: loaded_file = loaded_segment.value assert isinstance(orig_file, File) assert isinstance(loaded_file, File) - assert loaded_file.tenant_id == orig_file.tenant_id assert loaded_file.type == orig_file.type assert loaded_file.filename == orig_file.filename else: @@ -234,7 +239,6 @@ class TestSegmentDumpAndLoad: loaded_file = loaded_variable.value assert isinstance(orig_file, File) assert isinstance(loaded_file, File) - assert loaded_file.tenant_id == orig_file.tenant_id assert loaded_file.type == orig_file.type assert loaded_file.filename == orig_file.filename else: diff --git a/api/tests/unit_tests/core/variables/test_segment_type_validation.py b/api/tests/unit_tests/core/variables/test_segment_type_validation.py index 41ce483447..19f7673f6b 100644 --- a/api/tests/unit_tests/core/variables/test_segment_type_validation.py +++ b/api/tests/unit_tests/core/variables/test_segment_type_validation.py @@ -35,7 +35,6 @@ def create_test_file( ) -> File: """Factory function to create File objects for testing.""" return File( - tenant_id="test-tenant", type=file_type, transfer_method=transfer_method, filename=filename, diff --git a/api/tests/unit_tests/core/workflow/context/test_execution_context.py b/api/tests/unit_tests/core/workflow/context/test_execution_context.py index d09b8397c3..3ce4bb753b 100644 --- a/api/tests/unit_tests/core/workflow/context/test_execution_context.py +++ b/api/tests/unit_tests/core/workflow/context/test_execution_context.py @@ -9,7 +9,7 @@ from unittest.mock import MagicMock import pytest from pydantic import BaseModel -from dify_graph.context.execution_context import ( +from context.execution_context import ( AppContext, ExecutionContext, ExecutionContextBuilder, @@ -286,7 +286,7 @@ class TestCaptureCurrentContext: def test_capture_current_context_returns_context(self): """Test that capture_current_context returns a valid context.""" - from dify_graph.context.execution_context import capture_current_context + from context.execution_context import capture_current_context result = capture_current_context() @@ -303,7 +303,7 @@ class TestCaptureCurrentContext: test_var = contextvars.ContextVar("capture_test_var") test_var.set("test_value_123") - from dify_graph.context.execution_context import capture_current_context + from context.execution_context import capture_current_context result = capture_current_context() @@ -313,12 +313,12 @@ class TestCaptureCurrentContext: class TestTenantScopedContextRegistry: def setup_method(self): - from dify_graph.context import reset_context_provider + from context import reset_context_provider reset_context_provider() def teardown_method(self): - from dify_graph.context import reset_context_provider + from context import reset_context_provider reset_context_provider() @@ -333,7 +333,7 @@ class TestTenantScopedContextRegistry: assert read_context("workflow.sandbox", tenant_id="t2").base_url == "http://t2" def test_missing_provider_raises_keyerror(self): - from dify_graph.context import ContextProviderNotFoundError + from context import ContextProviderNotFoundError with pytest.raises(ContextProviderNotFoundError): read_context("missing", tenant_id="unknown") diff --git a/api/tests/unit_tests/core/workflow/entities/test_graph_runtime_state.py b/api/tests/unit_tests/core/workflow/entities/test_graph_runtime_state.py index 22792eb5b3..491326916b 100644 --- a/api/tests/unit_tests/core/workflow/entities/test_graph_runtime_state.py +++ b/api/tests/unit_tests/core/workflow/entities/test_graph_runtime_state.py @@ -4,7 +4,7 @@ from unittest.mock import MagicMock, patch import pytest -from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID +from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID from dify_graph.model_runtime.entities.llm_entities import LLMUsage from dify_graph.runtime import GraphRuntimeState, ReadOnlyGraphRuntimeStateWrapper, VariablePool from dify_graph.variables.variables import StringVariable diff --git a/api/tests/unit_tests/core/workflow/entities/test_variable_pool.py b/api/tests/unit_tests/core/workflow/entities/test_variable_pool.py index 6100ebede5..0515b9d93a 100644 --- a/api/tests/unit_tests/core/workflow/entities/test_variable_pool.py +++ b/api/tests/unit_tests/core/workflow/entities/test_variable_pool.py @@ -126,7 +126,7 @@ class TestVariablePoolGetNotModifyVariableDictionary: def test_get_should_not_modify_variable_dictionary(self): pool = VariablePool.empty() pool.get([self._NODE_ID, self._VAR_NAME]) - assert len(pool.variable_dictionary) == 1 # only contains `sys` node id + assert len(pool.variable_dictionary) == 0 assert "start" not in pool.variable_dictionary pool = VariablePool.empty() diff --git a/api/tests/unit_tests/core/workflow/graph/test_graph_skip_validation.py b/api/tests/unit_tests/core/workflow/graph/test_graph_skip_validation.py index 75de07bd8b..b61b8d59c3 100644 --- a/api/tests/unit_tests/core/workflow/graph/test_graph_skip_validation.py +++ b/api/tests/unit_tests/core/workflow/graph/test_graph_skip_validation.py @@ -5,11 +5,11 @@ from typing import Any import pytest from core.workflow.node_factory import DifyNodeFactory +from core.workflow.system_variables import default_system_variables from dify_graph.graph import Graph from dify_graph.graph.validation import GraphValidationError from dify_graph.nodes import BuiltinNodeTypes from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable from tests.workflow_test_utils import build_test_graph_init_params @@ -63,7 +63,7 @@ def _make_factory(graph_config: dict[str, Any]) -> DifyNodeFactory: ) graph_runtime_state = GraphRuntimeState( variable_pool=VariablePool( - system_variables=SystemVariable.default(), + system_variables=default_system_variables(), user_inputs={}, environment_variables=[], ), diff --git a/api/tests/unit_tests/core/workflow/graph/test_graph_validation.py b/api/tests/unit_tests/core/workflow/graph/test_graph_validation.py index e94ad74eb0..a42b4e9c8e 100644 --- a/api/tests/unit_tests/core/workflow/graph/test_graph_validation.py +++ b/api/tests/unit_tests/core/workflow/graph/test_graph_validation.py @@ -6,6 +6,7 @@ from dataclasses import dataclass import pytest +from core.workflow.system_variables import build_system_variables from dify_graph.entities import GraphInitParams from dify_graph.entities.base_node_data import BaseNodeData from dify_graph.enums import BuiltinNodeTypes, ErrorStrategy, NodeExecutionType, NodeType @@ -13,7 +14,6 @@ from dify_graph.graph import Graph from dify_graph.graph.validation import GraphValidationError from dify_graph.nodes.base.node import Node from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable from tests.workflow_test_utils import build_test_graph_init_params @@ -96,7 +96,7 @@ def graph_init_dependencies() -> tuple[_SimpleNodeFactory, dict[str, object]]: invoke_from="service-api", call_depth=0, ) - variable_pool = VariablePool(system_variables=SystemVariable(user_id="user", files=[]), user_inputs={}) + variable_pool = VariablePool(system_variables=build_system_variables(user_id="user", files=[]), user_inputs={}) runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) factory = _SimpleNodeFactory(graph_init_params=init_params, graph_runtime_state=runtime_state) return factory, graph_config diff --git a/api/tests/unit_tests/core/workflow/graph_engine/human_input_test_utils.py b/api/tests/unit_tests/core/workflow/graph_engine/human_input_test_utils.py index fc8133f5e1..f3db1f5e0c 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/human_input_test_utils.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/human_input_test_utils.py @@ -7,13 +7,13 @@ from dataclasses import dataclass from datetime import datetime, timedelta from typing import Any -from dify_graph.nodes.human_input.enums import HumanInputFormStatus -from dify_graph.repositories.human_input_form_repository import ( +from core.repositories.human_input_repository import ( FormCreateParams, HumanInputFormEntity, HumanInputFormRecipientEntity, HumanInputFormRepository, ) +from dify_graph.nodes.human_input.enums import HumanInputFormStatus from libs.datetime_utils import naive_utc_now @@ -49,7 +49,7 @@ class _InMemoryFormEntity(HumanInputFormEntity): return self.form_id @property - def web_app_token(self) -> str | None: + def submission_token(self) -> str | None: return self.token @property @@ -88,24 +88,24 @@ class InMemoryHumanInputFormRepository(HumanInputFormRepository): self._form_counter = 0 self.created_params: list[FormCreateParams] = [] self.created_forms: list[_InMemoryFormEntity] = [] - self._forms_by_key: dict[tuple[str, str], _InMemoryFormEntity] = {} + self._forms_by_node_id: dict[str, _InMemoryFormEntity] = {} def create_form(self, params: FormCreateParams) -> HumanInputFormEntity: self.created_params.append(params) self._form_counter += 1 form_id = f"form-{self._form_counter}" - token = f"console-{form_id}" if params.console_recipient_required else f"token-{form_id}" + token = f"token-{form_id}" entity = _InMemoryFormEntity( form_id=form_id, rendered=params.rendered_content, token=token, ) self.created_forms.append(entity) - self._forms_by_key[(params.workflow_execution_id, params.node_id)] = entity + self._forms_by_node_id[params.node_id] = entity return entity - def get_form(self, workflow_execution_id: str, node_id: str) -> HumanInputFormEntity | None: - return self._forms_by_key.get((workflow_execution_id, node_id)) + def get_form(self, node_id: str) -> HumanInputFormEntity | None: + return self._forms_by_node_id.get(node_id) # Convenience helpers for tests ------------------------------------- diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_state_snapshot.py b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_state_snapshot.py index d54f0be190..bc5d6c2c45 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_state_snapshot.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_state_snapshot.py @@ -1,6 +1,7 @@ import time from collections.abc import Mapping +from core.workflow.system_variables import build_system_variables from dify_graph.entities import GraphInitParams from dify_graph.enums import NodeState from dify_graph.graph import Graph @@ -20,7 +21,6 @@ from dify_graph.nodes.llm.entities import ( from dify_graph.nodes.start.entities import StartNodeData from dify_graph.nodes.start.start_node import StartNode from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable from tests.workflow_test_utils import build_test_graph_init_params from .test_mock_config import MockConfig @@ -29,7 +29,7 @@ from .test_mock_nodes import MockLLMNode def _build_runtime_state() -> GraphRuntimeState: variable_pool = VariablePool( - system_variables=SystemVariable( + system_variables=build_system_variables( user_id="user", app_id="app", workflow_id="workflow", diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_multi_branch.py b/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_multi_branch.py index 99b8d8b34d..9ffd1c7168 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_multi_branch.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_multi_branch.py @@ -4,7 +4,9 @@ from collections.abc import Iterable from unittest import mock from unittest.mock import MagicMock +from core.repositories.human_input_repository import HumanInputFormEntity, HumanInputFormRepository from core.workflow.node_runtime import DifyHumanInputNodeRuntime +from core.workflow.system_variables import build_system_variables from dify_graph.graph import Graph from dify_graph.graph_events import ( GraphRunPausedEvent, @@ -31,9 +33,7 @@ from dify_graph.nodes.llm.entities import ( ) from dify_graph.nodes.start.entities import StartNodeData from dify_graph.nodes.start.start_node import StartNode -from dify_graph.repositories.human_input_form_repository import HumanInputFormEntity, HumanInputFormRepository from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable from libs.datetime_utils import naive_utc_now from tests.workflow_test_utils import build_test_graph_init_params @@ -61,7 +61,7 @@ def _build_branching_graph( if graph_runtime_state is None: variable_pool = VariablePool( - system_variables=SystemVariable( + system_variables=build_system_variables( user_id="user", app_id="app", workflow_id="workflow", @@ -248,7 +248,7 @@ def test_human_input_llm_streaming_across_multiple_branches() -> None: mock_create_repo.get_form.return_value = None mock_form_entity = MagicMock(spec=HumanInputFormEntity) mock_form_entity.id = "test_form_id" - mock_form_entity.web_app_token = "test_web_app_token" + mock_form_entity.submission_token = "test_web_app_token" mock_form_entity.recipients = [] mock_form_entity.rendered_content = "rendered" mock_form_entity.submitted = False @@ -304,7 +304,7 @@ def test_human_input_llm_streaming_across_multiple_branches() -> None: mock_get_repo = MagicMock(spec=HumanInputFormRepository) submitted_form = MagicMock(spec=HumanInputFormEntity) submitted_form.id = mock_form_entity.id - submitted_form.web_app_token = mock_form_entity.web_app_token + submitted_form.submission_token = mock_form_entity.submission_token submitted_form.recipients = [] submitted_form.rendered_content = mock_form_entity.rendered_content submitted_form.submitted = True diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_single_branch.py b/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_single_branch.py index 6b1c6ad936..dc9db6d725 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_single_branch.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_single_branch.py @@ -3,7 +3,9 @@ import time from unittest import mock from unittest.mock import MagicMock +from core.repositories.human_input_repository import HumanInputFormEntity, HumanInputFormRepository from core.workflow.node_runtime import DifyHumanInputNodeRuntime +from core.workflow.system_variables import build_system_variables from dify_graph.graph import Graph from dify_graph.graph_events import ( GraphRunPausedEvent, @@ -30,9 +32,7 @@ from dify_graph.nodes.llm.entities import ( ) from dify_graph.nodes.start.entities import StartNodeData from dify_graph.nodes.start.start_node import StartNode -from dify_graph.repositories.human_input_form_repository import HumanInputFormEntity, HumanInputFormRepository from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable from libs.datetime_utils import naive_utc_now from tests.workflow_test_utils import build_test_graph_init_params @@ -60,7 +60,7 @@ def _build_llm_human_llm_graph( if graph_runtime_state is None: variable_pool = VariablePool( - system_variables=SystemVariable( + system_variables=build_system_variables( user_id="user", app_id="app", workflow_id="workflow", workflow_execution_id="test-execution-id," ), user_inputs={}, @@ -193,7 +193,7 @@ def test_human_input_llm_streaming_order_across_pause() -> None: mock_create_repo.get_form.return_value = None mock_form_entity = MagicMock(spec=HumanInputFormEntity) mock_form_entity.id = "test_form_id" - mock_form_entity.web_app_token = "test_web_app_token" + mock_form_entity.submission_token = "test_web_app_token" mock_form_entity.recipients = [] mock_form_entity.rendered_content = "rendered" mock_form_entity.submitted = False @@ -262,7 +262,7 @@ def test_human_input_llm_streaming_order_across_pause() -> None: mock_get_repo = MagicMock(spec=HumanInputFormRepository) submitted_form = MagicMock(spec=HumanInputFormEntity) submitted_form.id = mock_form_entity.id - submitted_form.web_app_token = mock_form_entity.web_app_token + submitted_form.submission_token = mock_form_entity.submission_token submitted_form.recipients = [] submitted_form.rendered_content = mock_form_entity.rendered_content submitted_form.submitted = True diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_if_else_streaming.py b/api/tests/unit_tests/core/workflow/graph_engine/test_if_else_streaming.py index 8da179c15e..39759a7555 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_if_else_streaming.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_if_else_streaming.py @@ -1,6 +1,7 @@ import time from unittest import mock +from core.workflow.system_variables import build_system_variables from dify_graph.graph import Graph from dify_graph.graph_events import ( GraphRunStartedEvent, @@ -26,7 +27,6 @@ from dify_graph.nodes.llm.entities import ( from dify_graph.nodes.start.entities import StartNodeData from dify_graph.nodes.start.start_node import StartNode from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable from dify_graph.utils.condition.entities import Condition from tests.workflow_test_utils import build_test_graph_init_params @@ -44,7 +44,7 @@ def _build_if_else_graph(branch_value: str, mock_config: MockConfig) -> tuple[Gr ) variable_pool = VariablePool( - system_variables=SystemVariable(user_id="user", app_id="app", workflow_id="workflow"), + system_variables=build_system_variables(user_id="user", app_id="app", workflow_id="workflow"), user_inputs={}, conversation_variables=[], ) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_loop_contains_answer.py b/api/tests/unit_tests/core/workflow/graph_engine/test_loop_contains_answer.py index 733fd53bc8..a8bc96117f 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_loop_contains_answer.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_loop_contains_answer.py @@ -14,6 +14,7 @@ from dify_graph.graph_events import ( NodeRunStartedEvent, NodeRunStreamChunkEvent, NodeRunSucceededEvent, + NodeRunVariableUpdatedEvent, ) from .test_mock_config import MockConfigBuilder @@ -50,6 +51,7 @@ def test_loop_contains_answer(): NodeRunLoopStartedEvent, # Variable assigner NodeRunStartedEvent, + NodeRunVariableUpdatedEvent, NodeRunStreamChunkEvent, # 1 NodeRunStreamChunkEvent, # \n NodeRunSucceededEvent, @@ -60,6 +62,7 @@ def test_loop_contains_answer(): NodeRunLoopNextEvent, # Variable assigner NodeRunStartedEvent, + NodeRunVariableUpdatedEvent, NodeRunStreamChunkEvent, # 2 NodeRunStreamChunkEvent, # \n NodeRunSucceededEvent, diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes_template_code.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes_template_code.py index 482dbd4a91..87284c6740 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes_template_code.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes_template_code.py @@ -60,7 +60,7 @@ class TestMockTemplateTransformNode: ) variable_pool = VariablePool( - system_variables={}, + system_variables=[], user_inputs={}, ) @@ -123,7 +123,7 @@ class TestMockTemplateTransformNode: ) variable_pool = VariablePool( - system_variables={}, + system_variables=[], user_inputs={}, ) @@ -187,7 +187,7 @@ class TestMockTemplateTransformNode: ) variable_pool = VariablePool( - system_variables={}, + system_variables=[], user_inputs={}, ) @@ -249,7 +249,7 @@ class TestMockTemplateTransformNode: ) variable_pool = VariablePool( - system_variables={}, + system_variables=[], user_inputs={}, ) @@ -318,7 +318,7 @@ class TestMockCodeNode: ) variable_pool = VariablePool( - system_variables={}, + system_variables=[], user_inputs={}, ) @@ -384,7 +384,7 @@ class TestMockCodeNode: ) variable_pool = VariablePool( - system_variables={}, + system_variables=[], user_inputs={}, ) @@ -458,7 +458,7 @@ class TestMockCodeNode: ) variable_pool = VariablePool( - system_variables={}, + system_variables=[], user_inputs={}, ) @@ -534,7 +534,7 @@ class TestMockNodeFactory: ) variable_pool = VariablePool( - system_variables={}, + system_variables=[], user_inputs={}, ) @@ -579,7 +579,7 @@ class TestMockNodeFactory: ) variable_pool = VariablePool( - system_variables={}, + system_variables=[], user_inputs={}, ) @@ -634,7 +634,7 @@ class TestMockNodeFactory: ) variable_pool = VariablePool( - system_variables={}, + system_variables=[], user_inputs={}, ) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_join_resume.py b/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_join_resume.py index 6453785c20..08d30315ac 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_join_resume.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_join_resume.py @@ -4,7 +4,13 @@ from dataclasses import dataclass from datetime import datetime, timedelta from typing import Any, Protocol +from core.repositories.human_input_repository import ( + FormCreateParams, + HumanInputFormEntity, + HumanInputFormRepository, +) from core.workflow.node_runtime import DifyHumanInputNodeRuntime +from core.workflow.system_variables import build_system_variables from dify_graph.entities.workflow_start_reason import WorkflowStartReason from dify_graph.graph import Graph from dify_graph.graph_engine.command_channels.in_memory_channel import InMemoryChannel @@ -24,13 +30,7 @@ from dify_graph.nodes.human_input.enums import HumanInputFormStatus from dify_graph.nodes.human_input.human_input_node import HumanInputNode from dify_graph.nodes.start.entities import StartNodeData from dify_graph.nodes.start.start_node import StartNode -from dify_graph.repositories.human_input_form_repository import ( - FormCreateParams, - HumanInputFormEntity, - HumanInputFormRepository, -) from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable from libs.datetime_utils import naive_utc_now from tests.workflow_test_utils import build_test_graph_init_params @@ -68,7 +68,7 @@ class StaticForm(HumanInputFormEntity): return self.form_id @property - def web_app_token(self) -> str | None: + def submission_token(self) -> str | None: return "token" @property @@ -104,7 +104,7 @@ class StaticRepo(HumanInputFormRepository): def __init__(self, forms_by_node_id: Mapping[str, HumanInputFormEntity]) -> None: self._forms_by_node_id = dict(forms_by_node_id) - def get_form(self, workflow_execution_id: str, node_id: str) -> HumanInputFormEntity | None: + def get_form(self, node_id: str) -> HumanInputFormEntity | None: return self._forms_by_node_id.get(node_id) def create_form(self, params: FormCreateParams) -> HumanInputFormEntity: @@ -113,7 +113,7 @@ class StaticRepo(HumanInputFormRepository): def _build_runtime_state() -> GraphRuntimeState: variable_pool = VariablePool( - system_variables=SystemVariable( + system_variables=build_system_variables( user_id="user", app_id="app", workflow_id="workflow", diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_pause_missing_finish.py b/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_pause_missing_finish.py index 46ad6775d4..65317292ed 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_pause_missing_finish.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_pause_missing_finish.py @@ -4,7 +4,13 @@ from dataclasses import dataclass from datetime import datetime, timedelta from typing import Any +from core.repositories.human_input_repository import ( + FormCreateParams, + HumanInputFormEntity, + HumanInputFormRepository, +) from core.workflow.node_runtime import DifyHumanInputNodeRuntime +from core.workflow.system_variables import build_system_variables from dify_graph.entities.workflow_start_reason import WorkflowStartReason from dify_graph.graph import Graph from dify_graph.graph_engine.command_channels.in_memory_channel import InMemoryChannel @@ -31,13 +37,7 @@ from dify_graph.nodes.llm.entities import ( ) from dify_graph.nodes.start.entities import StartNodeData from dify_graph.nodes.start.start_node import StartNode -from dify_graph.repositories.human_input_form_repository import ( - FormCreateParams, - HumanInputFormEntity, - HumanInputFormRepository, -) from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable from libs.datetime_utils import naive_utc_now from tests.workflow_test_utils import build_test_graph_init_params @@ -60,7 +60,7 @@ class StaticForm(HumanInputFormEntity): return self.form_id @property - def web_app_token(self) -> str | None: + def submission_token(self) -> str | None: return "token" @property @@ -96,7 +96,7 @@ class StaticRepo(HumanInputFormRepository): def __init__(self, forms_by_node_id: Mapping[str, HumanInputFormEntity]) -> None: self._forms_by_node_id = dict(forms_by_node_id) - def get_form(self, workflow_execution_id: str, node_id: str) -> HumanInputFormEntity | None: + def get_form(self, node_id: str) -> HumanInputFormEntity | None: return self._forms_by_node_id.get(node_id) def create_form(self, params: FormCreateParams) -> HumanInputFormEntity: @@ -116,7 +116,7 @@ class DelayedHumanInputNode(HumanInputNode): def _build_runtime_state() -> GraphRuntimeState: variable_pool = VariablePool( - system_variables=SystemVariable( + system_variables=build_system_variables( user_id="user", app_id="app", workflow_id="workflow", diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_streaming_workflow.py b/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_streaming_workflow.py index b954a4faac..ba9460e08e 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_streaming_workflow.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_streaming_workflow.py @@ -15,6 +15,7 @@ from uuid import uuid4 from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.model_manager import ModelInstance from core.workflow.node_factory import DifyNodeFactory, get_default_root_node_id +from core.workflow.system_variables import build_system_variables from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus from dify_graph.graph import Graph from dify_graph.graph_engine import GraphEngine, GraphEngineConfig @@ -28,7 +29,6 @@ from dify_graph.graph_events import ( from dify_graph.node_events import NodeRunResult, StreamCompletedEvent from dify_graph.nodes.llm.node import LLMNode from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable from tests.workflow_test_utils import build_test_graph_init_params from .test_table_runner import TableTestRunner @@ -98,7 +98,7 @@ def test_parallel_streaming_workflow(): ) # Create variable pool with system variables - system_variables = SystemVariable( + system_variables = build_system_variables( user_id="test_user", app_id="test_app", workflow_id=init_params.workflow_id, diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_pause_deferred_ready_nodes.py b/api/tests/unit_tests/core/workflow/graph_engine/test_pause_deferred_ready_nodes.py index e2f92663b1..399c73b2ac 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_pause_deferred_ready_nodes.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_pause_deferred_ready_nodes.py @@ -4,7 +4,13 @@ from dataclasses import dataclass from datetime import datetime, timedelta from typing import Any +from core.repositories.human_input_repository import ( + FormCreateParams, + HumanInputFormEntity, + HumanInputFormRepository, +) from core.workflow.node_runtime import DifyHumanInputNodeRuntime +from core.workflow.system_variables import build_system_variables from dify_graph.entities.workflow_start_reason import WorkflowStartReason from dify_graph.graph import Graph from dify_graph.graph_engine.command_channels.in_memory_channel import InMemoryChannel @@ -32,13 +38,7 @@ from dify_graph.nodes.llm.entities import ( ) from dify_graph.nodes.start.entities import StartNodeData from dify_graph.nodes.start.start_node import StartNode -from dify_graph.repositories.human_input_form_repository import ( - FormCreateParams, - HumanInputFormEntity, - HumanInputFormRepository, -) from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable from libs.datetime_utils import naive_utc_now from tests.workflow_test_utils import build_test_graph_init_params @@ -61,7 +61,7 @@ class StaticForm(HumanInputFormEntity): return self.form_id @property - def web_app_token(self) -> str | None: + def submission_token(self) -> str | None: return "token" @property @@ -97,7 +97,7 @@ class StaticRepo(HumanInputFormRepository): def __init__(self, form: HumanInputFormEntity) -> None: self._form = form - def get_form(self, workflow_execution_id: str, node_id: str) -> HumanInputFormEntity | None: + def get_form(self, node_id: str) -> HumanInputFormEntity | None: if node_id != "human_pause": return None return self._form @@ -108,7 +108,7 @@ class StaticRepo(HumanInputFormRepository): def _build_runtime_state() -> GraphRuntimeState: variable_pool = VariablePool( - system_variables=SystemVariable( + system_variables=build_system_variables( user_id="user", app_id="app", workflow_id="workflow", diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_pause_resume_state.py b/api/tests/unit_tests/core/workflow/graph_engine/test_pause_resume_state.py index ddce34c663..a528412682 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_pause_resume_state.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_pause_resume_state.py @@ -3,7 +3,12 @@ import time from typing import Any from unittest.mock import MagicMock +from core.repositories.human_input_repository import ( + HumanInputFormEntity, + HumanInputFormRepository, +) from core.workflow.node_runtime import DifyHumanInputNodeRuntime +from core.workflow.system_variables import build_system_variables from dify_graph.entities.workflow_start_reason import WorkflowStartReason from dify_graph.graph import Graph from dify_graph.graph_engine.command_channels.in_memory_channel import InMemoryChannel @@ -23,19 +28,14 @@ from dify_graph.nodes.human_input.entities import HumanInputNodeData, UserAction from dify_graph.nodes.human_input.human_input_node import HumanInputNode from dify_graph.nodes.start.entities import StartNodeData from dify_graph.nodes.start.start_node import StartNode -from dify_graph.repositories.human_input_form_repository import ( - HumanInputFormEntity, - HumanInputFormRepository, -) from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable from libs.datetime_utils import naive_utc_now from tests.workflow_test_utils import build_test_graph_init_params def _build_runtime_state() -> GraphRuntimeState: variable_pool = VariablePool( - system_variables=SystemVariable( + system_variables=build_system_variables( user_id="user", app_id="app", workflow_id="workflow", @@ -51,7 +51,7 @@ def _mock_form_repository_with_submission(action_id: str) -> HumanInputFormRepos repo = MagicMock(spec=HumanInputFormRepository) form_entity = MagicMock(spec=HumanInputFormEntity) form_entity.id = "test-form-id" - form_entity.web_app_token = "test-form-token" + form_entity.submission_token = "test-form-token" form_entity.recipients = [] form_entity.rendered_content = "rendered" form_entity.submitted = True @@ -66,7 +66,7 @@ def _mock_form_repository_without_submission() -> HumanInputFormRepository: repo = MagicMock(spec=HumanInputFormRepository) form_entity = MagicMock(spec=HumanInputFormEntity) form_entity.id = "test-form-id" - form_entity.web_app_token = "test-form-token" + form_entity.submission_token = "test-form-token" form_entity.recipients = [] form_entity.rendered_content = "rendered" form_entity.submitted = False diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_streaming_conversation_variables.py b/api/tests/unit_tests/core/workflow/graph_engine/test_streaming_conversation_variables.py index 4f1741d4fb..f8aaea6424 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_streaming_conversation_variables.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_streaming_conversation_variables.py @@ -4,6 +4,7 @@ from dify_graph.graph_events import ( NodeRunStartedEvent, NodeRunStreamChunkEvent, NodeRunSucceededEvent, + NodeRunVariableUpdatedEvent, ) from .test_mock_config import MockConfigBuilder @@ -33,6 +34,7 @@ def test_streaming_conversation_variables(): NodeRunSucceededEvent, # Variable Assigner node NodeRunStartedEvent, + NodeRunVariableUpdatedEvent, NodeRunStreamChunkEvent, NodeRunSucceededEvent, # ANSWER node diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_table_runner.py b/api/tests/unit_tests/core/workflow/graph_engine/test_table_runner.py index 07190e2654..2dec3580c0 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_table_runner.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_table_runner.py @@ -22,6 +22,9 @@ from typing import Any, cast from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, InvokeFrom, UserFrom from core.tools.utils.yaml_utils import _load_yaml_file from core.workflow.node_factory import DifyNodeFactory, get_default_root_node_id +from core.workflow.system_variables import build_bootstrap_variables, build_system_variables +from core.workflow.variable_pool_initializer import add_node_inputs_to_pool, add_variables_to_pool +from core.workflow.variable_prefixes import CHILD_SYNC_VARIABLE_NODE_IDS from dify_graph.entities.graph_init_params import GraphInitParams from dify_graph.graph import Graph from dify_graph.graph_engine import GraphEngine, GraphEngineConfig @@ -33,7 +36,6 @@ from dify_graph.graph_events import ( GraphRunSucceededEvent, ) from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable from dify_graph.variables import ( ArrayNumberVariable, ArrayObjectVariable, @@ -204,16 +206,18 @@ class WorkflowRunner: } }, call_depth=0, + child_sync_variable_node_ids=CHILD_SYNC_VARIABLE_NODE_IDS, ) - system_variables = SystemVariable( + system_variables = build_system_variables( user_id="test_user", app_id="test_app", workflow_id=graph_init_params.workflow_id, files=[], query=query, ) - user_inputs = inputs if inputs is not None else {} + root_node_inputs = dict(inputs or {}) + root_node_inputs.setdefault("query", query) # Extract conversation variables from workflow config conversation_variables = [] @@ -242,11 +246,16 @@ class WorkflowRunner: ) conversation_variables.append(var) - variable_pool = VariablePool( - system_variables=system_variables, - user_inputs=user_inputs, - conversation_variables=conversation_variables, + root_node_id = get_default_root_node_id(graph_config) + variable_pool = VariablePool() + add_variables_to_pool( + variable_pool, + build_bootstrap_variables( + system_variables=system_variables, + conversation_variables=conversation_variables, + ), ) + add_node_inputs_to_pool(variable_pool, node_id=root_node_id, inputs=root_node_inputs) graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) @@ -260,7 +269,7 @@ class WorkflowRunner: graph = Graph.init( graph_config=graph_config, node_factory=node_factory, - root_node_id=get_default_root_node_id(graph_config), + root_node_id=root_node_id, ) return graph, graph_runtime_state diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_variable_update_events.py b/api/tests/unit_tests/core/workflow/graph_engine/test_variable_update_events.py new file mode 100644 index 0000000000..d8ec3c7037 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_variable_update_events.py @@ -0,0 +1,129 @@ +import time +import uuid +from uuid import uuid4 + +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, InvokeFrom, UserFrom +from core.workflow.node_factory import DifyNodeFactory +from core.workflow.system_variables import build_bootstrap_variables, build_system_variables +from core.workflow.variable_pool_initializer import add_variables_to_pool +from dify_graph.entities import GraphInitParams +from dify_graph.graph import Graph +from dify_graph.graph_engine import GraphEngine, GraphEngineConfig +from dify_graph.graph_engine.command_channels import InMemoryChannel +from dify_graph.graph_engine.layers.base import GraphEngineLayer +from dify_graph.graph_events import NodeRunVariableUpdatedEvent +from dify_graph.runtime import GraphRuntimeState, VariablePool +from dify_graph.variables import StringVariable + +DEFAULT_NODE_ID = "node_id" + + +class CaptureVariableUpdateLayer(GraphEngineLayer): + def __init__(self) -> None: + super().__init__() + self.events: list[NodeRunVariableUpdatedEvent] = [] + self.observed_values: list[object | None] = [] + + def on_graph_start(self) -> None: + pass + + def on_event(self, event) -> None: + if not isinstance(event, NodeRunVariableUpdatedEvent): + return + + current_value = self.graph_runtime_state.variable_pool.get(event.variable.selector) + self.events.append(event) + self.observed_values.append(None if current_value is None else current_value.value) + + def on_graph_end(self, error: Exception | None) -> None: + pass + + +def test_graph_engine_applies_variable_updates_before_notifying_layers(): + graph_config = { + "edges": [ + { + "id": "start-source-assigner-target", + "source": "start", + "target": "assigner", + }, + ], + "nodes": [ + {"data": {"type": "start", "title": "Start"}, "id": "start"}, + { + "data": { + "type": "assigner", + "title": "Variable Assigner", + "assigned_variable_selector": ["conversation", "test_conversation_variable"], + "write_mode": "over-write", + "input_variable_selector": ["node_id", "test_string_variable"], + }, + "id": "assigner", + }, + ], + } + + init_params = GraphInitParams( + workflow_id="1", + graph_config=graph_config, + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "1", + "app_id": "1", + "user_id": "1", + "user_from": UserFrom.ACCOUNT, + "invoke_from": InvokeFrom.DEBUGGER, + } + }, + call_depth=0, + ) + + variable_pool = VariablePool() + add_variables_to_pool( + variable_pool, + build_bootstrap_variables( + system_variables=build_system_variables(conversation_id=str(uuid.uuid4())), + conversation_variables=[ + StringVariable( + id=str(uuid4()), + name="test_conversation_variable", + value="the first value", + ) + ], + ), + ) + variable_pool.add( + [DEFAULT_NODE_ID, "test_string_variable"], + StringVariable( + id=str(uuid4()), + name="test_string_variable", + value="the second value", + ), + ) + + graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) + node_factory = DifyNodeFactory(graph_init_params=init_params, graph_runtime_state=graph_runtime_state) + graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start") + + engine = GraphEngine( + workflow_id="workflow-id", + graph=graph, + graph_runtime_state=graph_runtime_state, + command_channel=InMemoryChannel(), + config=GraphEngineConfig(), + ) + capture_layer = CaptureVariableUpdateLayer() + engine.layer(capture_layer) + + events = list(engine.run()) + + update_events = [event for event in events if isinstance(event, NodeRunVariableUpdatedEvent)] + assert len(update_events) == 1 + assert update_events[0].variable.value == "the second value" + + current_value = graph_runtime_state.variable_pool.get(["conversation", "test_conversation_variable"]) + assert current_value is not None + assert current_value.value == "the second value" + + assert len(capture_layer.events) == 1 + assert capture_layer.observed_values == ["the second value"] diff --git a/api/tests/unit_tests/core/workflow/nodes/agent/test_message_transformer.py b/api/tests/unit_tests/core/workflow/nodes/agent/test_message_transformer.py new file mode 100644 index 0000000000..4c2fdabd0b --- /dev/null +++ b/api/tests/unit_tests/core/workflow/nodes/agent/test_message_transformer.py @@ -0,0 +1,33 @@ +from unittest.mock import patch + +from core.tools.utils.message_transformer import ToolFileMessageTransformer +from core.workflow.nodes.agent.message_transformer import AgentMessageTransformer +from dify_graph.enums import BuiltinNodeTypes + + +def test_transform_passes_conversation_id_to_tool_file_message_transformer() -> None: + messages = iter(()) + transformer = AgentMessageTransformer() + + with patch.object(ToolFileMessageTransformer, "transform_tool_invoke_messages", return_value=iter(())) as transform: + result = list( + transformer.transform( + messages=messages, + tool_info={}, + parameters_for_log={}, + user_id="user-id", + tenant_id="tenant-id", + conversation_id="conversation-id", + node_type=BuiltinNodeTypes.AGENT, + node_id="node-id", + node_execution_id="execution-id", + ) + ) + + assert len(result) == 2 + transform.assert_called_once_with( + messages=messages, + user_id="user-id", + tenant_id="tenant-id", + conversation_id="conversation-id", + ) diff --git a/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py index fd563d1be2..3c7017cd54 100644 --- a/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py +++ b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py @@ -4,11 +4,11 @@ from unittest.mock import MagicMock from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.workflow.node_factory import DifyNodeFactory +from core.workflow.system_variables import build_system_variables from dify_graph.enums import WorkflowNodeExecutionStatus from dify_graph.graph import Graph from dify_graph.nodes.answer.answer_node import AnswerNode from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable from extensions.ext_database import db from tests.workflow_test_utils import build_test_graph_init_params @@ -48,7 +48,7 @@ def test_execute_answer(): # construct variable pool variable_pool = VariablePool( - system_variables=SystemVariable(user_id="aaa", files=[]), + system_variables=build_system_variables(user_id="aaa", files=[]), user_inputs={}, environment_variables=[], conversation_variables=[], diff --git a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py index cea7195417..93d876fb26 100644 --- a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py +++ b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py @@ -2,6 +2,7 @@ import pytest from configs import dify_config from core.helper.ssrf_proxy import ssrf_proxy +from core.workflow.system_variables import default_system_variables from dify_graph.file.file_manager import file_manager from dify_graph.nodes.http_request import ( BodyData, @@ -14,7 +15,6 @@ from dify_graph.nodes.http_request.entities import HttpRequestNodeTimeout from dify_graph.nodes.http_request.exc import AuthorizationConfigError from dify_graph.nodes.http_request.executor import Executor from dify_graph.runtime import VariablePool -from dify_graph.system_variable import SystemVariable HTTP_REQUEST_CONFIG = HttpRequestNodeConfig( max_connect_timeout=dify_config.HTTP_REQUEST_MAX_CONNECT_TIMEOUT, @@ -30,7 +30,7 @@ HTTP_REQUEST_CONFIG = HttpRequestNodeConfig( def test_executor_with_json_body_and_number_variable(): # Prepare the variable pool variable_pool = VariablePool( - system_variables=SystemVariable.default(), + system_variables=default_system_variables(), user_inputs={}, ) variable_pool.add(["pre_node_id", "number"], 42) @@ -86,7 +86,7 @@ def test_executor_with_json_body_and_number_variable(): def test_executor_with_json_body_and_object_variable(): # Prepare the variable pool variable_pool = VariablePool( - system_variables=SystemVariable.default(), + system_variables=default_system_variables(), user_inputs={}, ) variable_pool.add(["pre_node_id", "object"], {"name": "John Doe", "age": 30, "email": "john@example.com"}) @@ -144,7 +144,7 @@ def test_executor_with_json_body_and_object_variable(): def test_executor_with_json_body_and_nested_object_variable(): # Prepare the variable pool variable_pool = VariablePool( - system_variables=SystemVariable.default(), + system_variables=default_system_variables(), user_inputs={}, ) variable_pool.add(["pre_node_id", "object"], {"name": "John Doe", "age": 30, "email": "john@example.com"}) @@ -201,7 +201,7 @@ def test_executor_with_json_body_and_nested_object_variable(): def test_extract_selectors_from_template_with_newline(): - variable_pool = VariablePool(system_variables=SystemVariable.default()) + variable_pool = VariablePool(system_variables=default_system_variables()) variable_pool.add(("node_id", "custom_query"), "line1\nline2") node_data = HttpRequestNodeData( title="Test JSON Body with Nested Object Variable", @@ -231,7 +231,7 @@ def test_extract_selectors_from_template_with_newline(): def test_executor_with_form_data(): # Prepare the variable pool variable_pool = VariablePool( - system_variables=SystemVariable.default(), + system_variables=default_system_variables(), user_inputs={}, ) variable_pool.add(["pre_node_id", "text_field"], "Hello, World!") @@ -320,7 +320,7 @@ def test_init_headers(): node_data=node_data, timeout=timeout, http_request_config=HTTP_REQUEST_CONFIG, - variable_pool=VariablePool(system_variables=SystemVariable.default()), + variable_pool=VariablePool(system_variables=default_system_variables()), http_client=ssrf_proxy, file_manager=file_manager, ) @@ -357,7 +357,7 @@ def test_init_params(): node_data=node_data, timeout=timeout, http_request_config=HTTP_REQUEST_CONFIG, - variable_pool=VariablePool(system_variables=SystemVariable.default()), + variable_pool=VariablePool(system_variables=default_system_variables()), http_client=ssrf_proxy, file_manager=file_manager, ) @@ -390,7 +390,7 @@ def test_init_params(): def test_empty_api_key_raises_error_bearer(): """Test that empty API key raises AuthorizationConfigError for bearer auth.""" - variable_pool = VariablePool(system_variables=SystemVariable.default()) + variable_pool = VariablePool(system_variables=default_system_variables()) node_data = HttpRequestNodeData( title="test", method="get", @@ -417,7 +417,7 @@ def test_empty_api_key_raises_error_bearer(): def test_empty_api_key_raises_error_basic(): """Test that empty API key raises AuthorizationConfigError for basic auth.""" - variable_pool = VariablePool(system_variables=SystemVariable.default()) + variable_pool = VariablePool(system_variables=default_system_variables()) node_data = HttpRequestNodeData( title="test", method="get", @@ -444,7 +444,7 @@ def test_empty_api_key_raises_error_basic(): def test_empty_api_key_raises_error_custom(): """Test that empty API key raises AuthorizationConfigError for custom auth.""" - variable_pool = VariablePool(system_variables=SystemVariable.default()) + variable_pool = VariablePool(system_variables=default_system_variables()) node_data = HttpRequestNodeData( title="test", method="get", @@ -471,7 +471,7 @@ def test_empty_api_key_raises_error_custom(): def test_whitespace_only_api_key_raises_error(): """Test that whitespace-only API key raises AuthorizationConfigError.""" - variable_pool = VariablePool(system_variables=SystemVariable.default()) + variable_pool = VariablePool(system_variables=default_system_variables()) node_data = HttpRequestNodeData( title="test", method="get", @@ -498,7 +498,7 @@ def test_whitespace_only_api_key_raises_error(): def test_valid_api_key_works(): """Test that valid API key works correctly for bearer auth.""" - variable_pool = VariablePool(system_variables=SystemVariable.default()) + variable_pool = VariablePool(system_variables=default_system_variables()) node_data = HttpRequestNodeData( title="test", method="get", @@ -537,7 +537,7 @@ def test_executor_with_json_body_and_unquoted_uuid_variable(): test_uuid = "57eeeeb1-450b-482c-81b9-4be77e95dee2" variable_pool = VariablePool( - system_variables=SystemVariable.default(), + system_variables=default_system_variables(), user_inputs={}, ) variable_pool.add(["pre_node_id", "uuid"], test_uuid) @@ -584,7 +584,7 @@ def test_executor_with_json_body_and_unquoted_uuid_with_newlines(): test_uuid = "57eeeeb1-450b-482c-81b9-4be77e95dee2" variable_pool = VariablePool( - system_variables=SystemVariable.default(), + system_variables=default_system_variables(), user_inputs={}, ) variable_pool.add(["pre_node_id", "uuid"], test_uuid) @@ -625,7 +625,7 @@ def test_executor_with_json_body_and_unquoted_uuid_with_newlines(): def test_executor_with_json_body_preserves_numbers_and_strings(): """Test that numbers are preserved and string values are properly quoted.""" variable_pool = VariablePool( - system_variables=SystemVariable.default(), + system_variables=default_system_variables(), user_inputs={}, ) variable_pool.add(["node", "count"], 42) diff --git a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py index b05c932071..afeb78fb2c 100644 --- a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py @@ -8,12 +8,12 @@ from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.helper.ssrf_proxy import ssrf_proxy from core.tools.tool_file_manager import ToolFileManager from core.workflow.node_runtime import DifyFileReferenceFactory +from core.workflow.system_variables import build_system_variables from dify_graph.enums import WorkflowNodeExecutionStatus from dify_graph.file.file_manager import file_manager from dify_graph.nodes.http_request import HTTP_REQUEST_CONFIG_FILTER_KEY, HttpRequestNode, HttpRequestNodeConfig from dify_graph.nodes.http_request.entities import HttpRequestNodeTimeout, Response from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable from tests.workflow_test_utils import build_test_graph_init_params HTTP_REQUEST_CONFIG = HttpRequestNodeConfig( @@ -110,7 +110,7 @@ def _build_http_node( call_depth=0, ) graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool(system_variables=SystemVariable(user_id="user", files=[]), user_inputs={}), + variable_pool=VariablePool(system_variables=build_system_variables(user_id="user", files=[]), user_inputs={}), start_at=time.perf_counter(), ) return HttpRequestNode( diff --git a/api/tests/unit_tests/core/workflow/nodes/human_input/test_email_delivery_config.py b/api/tests/unit_tests/core/workflow/nodes/human_input/test_email_delivery_config.py index d52dfa2a65..1a731726b7 100644 --- a/api/tests/unit_tests/core/workflow/nodes/human_input/test_email_delivery_config.py +++ b/api/tests/unit_tests/core/workflow/nodes/human_input/test_email_delivery_config.py @@ -1,4 +1,4 @@ -from dify_graph.nodes.human_input.entities import EmailDeliveryConfig, EmailRecipients +from core.workflow.human_input_compat import EmailDeliveryConfig, EmailRecipients from dify_graph.runtime import VariablePool diff --git a/api/tests/unit_tests/core/workflow/nodes/human_input/test_entities.py b/api/tests/unit_tests/core/workflow/nodes/human_input/test_entities.py index a5ec36e1b4..f3d7c8900a 100644 --- a/api/tests/unit_tests/core/workflow/nodes/human_input/test_entities.py +++ b/api/tests/unit_tests/core/workflow/nodes/human_input/test_entities.py @@ -9,35 +9,37 @@ import pytest from pydantic import ValidationError from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY +from core.repositories.human_input_repository import HumanInputFormRepository +from core.workflow.human_input_compat import ( + DeliveryMethodType, + EmailDeliveryConfig, + EmailDeliveryMethod, + EmailRecipients, + EmailRecipientType, + ExternalRecipient, + MemberRecipient, + WebAppDeliveryMethod, + _WebAppDeliveryConfig, +) from core.workflow.node_runtime import DifyHumanInputNodeRuntime +from core.workflow.system_variables import build_system_variables from dify_graph.entities import GraphInitParams from dify_graph.node_events import PauseRequestedEvent from dify_graph.node_events.node import StreamCompletedEvent from dify_graph.nodes.human_input.entities import ( - EmailDeliveryConfig, - EmailDeliveryMethod, - EmailRecipients, - ExternalRecipient, FormInput, FormInputDefault, HumanInputNodeData, - MemberRecipient, UserAction, - WebAppDeliveryMethod, - _WebAppDeliveryConfig, ) from dify_graph.nodes.human_input.enums import ( ButtonStyle, - DeliveryMethodType, - EmailRecipientType, FormInputType, PlaceholderType, TimeoutUnit, ) from dify_graph.nodes.human_input.human_input_node import HumanInputNode -from dify_graph.repositories.human_input_form_repository import HumanInputFormRepository from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable from tests.unit_tests.core.workflow.graph_engine.human_input_test_utils import InMemoryHumanInputFormRepository @@ -55,9 +57,9 @@ class TestDeliveryMethod: def test_email_delivery_method(self): """Test email delivery method creation.""" recipients = EmailRecipients( - whole_workspace=False, + include_bound_group=False, items=[ - MemberRecipient(type=EmailRecipientType.MEMBER, user_id="test-user-123"), + MemberRecipient(type=EmailRecipientType.MEMBER, reference_id="test-user-123"), ExternalRecipient(type=EmailRecipientType.EXTERNAL, email="test@example.com"), ], ) @@ -194,7 +196,7 @@ class TestHumanInputNodeData: EmailDeliveryMethod( enabled=False, # Disabled method should be fine config=EmailDeliveryConfig( - subject="Hi there", body="", recipients=EmailRecipients(whole_workspace=True) + subject="Hi there", body="", recipients=EmailRecipients(include_bound_group=True) ), ), ] @@ -213,7 +215,7 @@ class TestHumanInputNodeData: assert node_data.title == "Test Node" assert node_data.desc is None - assert node_data.delivery_methods == [] + assert node_data.model_dump().get("delivery_methods") is None assert node_data.form_content == "" assert node_data.inputs == [] assert node_data.user_actions == [] @@ -262,10 +264,10 @@ class TestRecipients: def test_member_recipient(self): """Test member recipient creation.""" - recipient = MemberRecipient(type=EmailRecipientType.MEMBER, user_id="user-123") + recipient = MemberRecipient(type=EmailRecipientType.MEMBER, reference_id="user-123") assert recipient.type == EmailRecipientType.MEMBER - assert recipient.user_id == "user-123" + assert recipient.reference_id == "user-123" def test_external_recipient(self): """Test external recipient creation.""" @@ -274,37 +276,45 @@ class TestRecipients: assert recipient.type == EmailRecipientType.EXTERNAL assert recipient.email == "test@example.com" - def test_email_recipients_whole_workspace(self): - """Test email recipients with whole workspace enabled.""" + def test_email_recipients_bound_group(self): + """Test email recipients with the bound group enabled.""" recipients = EmailRecipients( - whole_workspace=True, items=[MemberRecipient(type=EmailRecipientType.MEMBER, user_id="user-123")] + include_bound_group=True, + items=[MemberRecipient(type=EmailRecipientType.MEMBER, reference_id="user-123")], ) - assert recipients.whole_workspace is True - assert len(recipients.items) == 1 # Items are preserved even when whole_workspace is True + assert recipients.include_bound_group is True + assert len(recipients.items) == 1 # Items are preserved even when include_bound_group is True def test_email_recipients_specific_users(self): """Test email recipients with specific users.""" recipients = EmailRecipients( - whole_workspace=False, + include_bound_group=False, items=[ - MemberRecipient(type=EmailRecipientType.MEMBER, user_id="user-123"), + MemberRecipient(type=EmailRecipientType.MEMBER, reference_id="user-123"), ExternalRecipient(type=EmailRecipientType.EXTERNAL, email="external@example.com"), ], ) - assert recipients.whole_workspace is False + assert recipients.include_bound_group is False assert len(recipients.items) == 2 - assert recipients.items[0].user_id == "user-123" + assert recipients.items[0].reference_id == "user-123" assert recipients.items[1].email == "external@example.com" + def test_legacy_recipient_keys_are_rejected(self): + with pytest.raises(ValidationError): + MemberRecipient(type=EmailRecipientType.MEMBER, user_id="user-123") + + with pytest.raises(ValidationError): + EmailRecipients(whole_workspace=True, items=[]) + class TestHumanInputNodeVariableResolution: """Tests for resolving variable-based defaults in HumanInputNode.""" def test_resolves_variable_defaults(self): variable_pool = VariablePool( - system_variables=SystemVariable( + system_variables=build_system_variables( user_id="user", app_id="app", workflow_id="workflow", @@ -354,18 +364,19 @@ class TestHumanInputNodeVariableResolution: mock_repo.create_form.return_value = SimpleNamespace( id="form-1", rendered_content="Provide your name", - web_app_token="token", + submission_token="token", recipients=[], submitted=False, ) + runtime = DifyHumanInputNodeRuntime(graph_init_params.run_context) + runtime._build_form_repository = MagicMock(return_value=mock_repo) # type: ignore[attr-defined] node = HumanInputNode( id=config["id"], config=config, graph_init_params=graph_init_params, graph_runtime_state=runtime_state, - form_repository=mock_repo, - runtime=DifyHumanInputNodeRuntime(graph_init_params.run_context), + runtime=runtime, ) run_result = node._run() @@ -380,7 +391,7 @@ class TestHumanInputNodeVariableResolution: def test_debugger_falls_back_to_recipient_token_when_webapp_disabled(self): variable_pool = VariablePool( - system_variables=SystemVariable( + system_variables=build_system_variables( user_id="user", app_id="app", workflow_id="workflow", @@ -418,29 +429,30 @@ class TestHumanInputNodeVariableResolution: mock_repo.create_form.return_value = SimpleNamespace( id="form-2", rendered_content="Provide your name", - web_app_token="console-token", + submission_token="console-token", recipients=[SimpleNamespace(token="recipient-token")], submitted=False, ) + runtime = DifyHumanInputNodeRuntime(graph_init_params.run_context) + runtime._build_form_repository = MagicMock(return_value=mock_repo) # type: ignore[attr-defined] node = HumanInputNode( id=config["id"], config=config, graph_init_params=graph_init_params, graph_runtime_state=runtime_state, - form_repository=mock_repo, - runtime=DifyHumanInputNodeRuntime(graph_init_params.run_context), + runtime=runtime, ) run_result = node._run() pause_event = next(run_result) assert isinstance(pause_event, PauseRequestedEvent) - assert pause_event.reason.form_token == "console-token" + assert not hasattr(pause_event.reason, "form_token") def test_debugger_debug_mode_overrides_email_recipients(self): variable_pool = VariablePool( - system_variables=SystemVariable( + system_variables=build_system_variables( user_id="user-123", app_id="app", workflow_id="workflow", @@ -475,7 +487,7 @@ class TestHumanInputNodeVariableResolution: enabled=True, config=EmailDeliveryConfig( recipients=EmailRecipients( - whole_workspace=False, + include_bound_group=False, items=[ExternalRecipient(type=EmailRecipientType.EXTERNAL, email="target@example.com")], ), subject="Subject", @@ -492,18 +504,19 @@ class TestHumanInputNodeVariableResolution: mock_repo.create_form.return_value = SimpleNamespace( id="form-3", rendered_content="Provide your name", - web_app_token="token", + submission_token="token", recipients=[], submitted=False, ) + runtime = DifyHumanInputNodeRuntime(graph_init_params.run_context) + runtime._build_form_repository = MagicMock(return_value=mock_repo) # type: ignore[attr-defined] node = HumanInputNode( id=config["id"], config=config, graph_init_params=graph_init_params, graph_runtime_state=runtime_state, - form_repository=mock_repo, - runtime=DifyHumanInputNodeRuntime(graph_init_params.run_context), + runtime=runtime, ) run_result = node._run() @@ -515,11 +528,11 @@ class TestHumanInputNodeVariableResolution: method = params.delivery_methods[0] assert isinstance(method, EmailDeliveryMethod) assert method.config.debug_mode is True - assert method.config.recipients.whole_workspace is False + assert method.config.recipients.include_bound_group is False assert len(method.config.recipients.items) == 1 recipient = method.config.recipients.items[0] assert isinstance(recipient, MemberRecipient) - assert recipient.user_id == "user-123" + assert recipient.reference_id == "user-123" class TestValidation: @@ -556,7 +569,7 @@ class TestHumanInputNodeRenderedContent: def test_replaces_outputs_placeholders_after_submission(self): variable_pool = VariablePool( - system_variables=SystemVariable( + system_variables=build_system_variables( user_id="user", app_id="app", workflow_id="workflow", @@ -595,13 +608,14 @@ class TestHumanInputNodeRenderedContent: config = {"id": "human", "data": node_data.model_dump()} form_repository = InMemoryHumanInputFormRepository() + runtime = DifyHumanInputNodeRuntime(graph_init_params.run_context) + runtime._build_form_repository = MagicMock(return_value=form_repository) # type: ignore[attr-defined] node = HumanInputNode( id=config["id"], config=config, graph_init_params=graph_init_params, graph_runtime_state=runtime_state, - form_repository=form_repository, - runtime=DifyHumanInputNodeRuntime(graph_init_params.run_context), + runtime=runtime, ) pause_gen = node._run() diff --git a/api/tests/unit_tests/core/workflow/nodes/human_input/test_human_input_form_filled_event.py b/api/tests/unit_tests/core/workflow/nodes/human_input/test_human_input_form_filled_event.py index ddb9328bf1..26de1833da 100644 --- a/api/tests/unit_tests/core/workflow/nodes/human_input/test_human_input_form_filled_event.py +++ b/api/tests/unit_tests/core/workflow/nodes/human_input/test_human_input_form_filled_event.py @@ -3,6 +3,7 @@ from types import SimpleNamespace from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, InvokeFrom, UserFrom from core.workflow.node_runtime import DifyHumanInputNodeRuntime +from core.workflow.system_variables import default_system_variables from dify_graph.entities.graph_init_params import GraphInitParams from dify_graph.enums import BuiltinNodeTypes from dify_graph.graph_events import ( @@ -13,7 +14,6 @@ from dify_graph.graph_events import ( from dify_graph.nodes.human_input.enums import HumanInputFormStatus from dify_graph.nodes.human_input.human_input_node import HumanInputNode from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable from libs.datetime_utils import naive_utc_now @@ -26,7 +26,7 @@ class _FakeFormRepository: def _build_node(form_content: str = "Please enter your name:\n\n{{#$output.name#}}") -> HumanInputNode: - system_variables = SystemVariable.default() + system_variables = default_system_variables() graph_runtime_state = GraphRuntimeState( variable_pool=VariablePool(system_variables=system_variables, user_inputs={}, environment_variables=[]), start_at=0.0, @@ -91,7 +91,7 @@ def _build_node(form_content: str = "Please enter your name:\n\n{{#$output.name# def _build_timeout_node() -> HumanInputNode: - system_variables = SystemVariable.default() + system_variables = default_system_variables() graph_runtime_state = GraphRuntimeState( variable_pool=VariablePool(system_variables=system_variables, user_inputs={}, environment_variables=[]), start_at=0.0, diff --git a/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration_child_engine_errors.py b/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration_child_engine_errors.py index 2eb4feef5f..bc4d3e4097 100644 --- a/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration_child_engine_errors.py +++ b/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration_child_engine_errors.py @@ -3,6 +3,7 @@ from typing import Any import pytest +from core.workflow.system_variables import default_system_variables from dify_graph.entities import GraphInitParams from dify_graph.nodes.iteration.exc import IterationGraphNotFoundError from dify_graph.nodes.iteration.iteration_node import IterationNode @@ -12,7 +13,6 @@ from dify_graph.runtime import ( GraphRuntimeState, VariablePool, ) -from dify_graph.system_variable import SystemVariable from tests.workflow_test_utils import build_test_graph_init_params @@ -32,7 +32,7 @@ class _MissingGraphBuilder: def _build_runtime_state() -> GraphRuntimeState: return GraphRuntimeState( - variable_pool=VariablePool(system_variables=SystemVariable.default(), user_inputs={}), + variable_pool=VariablePool(system_variables=default_system_variables(), user_inputs={}), start_at=0.0, ) diff --git a/api/tests/unit_tests/core/workflow/nodes/knowledge_index/test_knowledge_index_node.py b/api/tests/unit_tests/core/workflow/nodes/knowledge_index/test_knowledge_index_node.py index 33f7ace5ab..f6d17a857a 100644 --- a/api/tests/unit_tests/core/workflow/nodes/knowledge_index/test_knowledge_index_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/knowledge_index/test_knowledge_index_node.py @@ -14,9 +14,9 @@ from core.workflow.nodes.knowledge_index.protocols import ( PreviewItem, SummaryIndexServiceProtocol, ) -from dify_graph.enums import SystemVariableKey, WorkflowNodeExecutionStatus +from core.workflow.system_variables import SystemVariableKey, build_system_variables +from dify_graph.enums import WorkflowNodeExecutionStatus from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable from dify_graph.variables.segments import StringSegment from tests.workflow_test_utils import build_test_graph_init_params @@ -40,7 +40,7 @@ def mock_graph_init_params(): def mock_graph_runtime_state(): """Create mock GraphRuntimeState.""" variable_pool = VariablePool( - system_variables=SystemVariable(user_id=str(uuid.uuid4()), files=[]), + system_variables=build_system_variables(user_id=str(uuid.uuid4()), files=[]), user_inputs={}, environment_variables=[], conversation_variables=[], diff --git a/api/tests/unit_tests/core/workflow/nodes/knowledge_retrieval/test_knowledge_retrieval_node.py b/api/tests/unit_tests/core/workflow/nodes/knowledge_retrieval/test_knowledge_retrieval_node.py index 99997db6b2..d61cb222b9 100644 --- a/api/tests/unit_tests/core/workflow/nodes/knowledge_retrieval/test_knowledge_retrieval_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/knowledge_retrieval/test_knowledge_retrieval_node.py @@ -16,10 +16,10 @@ from core.workflow.nodes.knowledge_retrieval.entities import ( from core.workflow.nodes.knowledge_retrieval.exc import RateLimitExceededError from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode from core.workflow.nodes.knowledge_retrieval.retrieval import RAGRetrievalProtocol, Source +from core.workflow.system_variables import build_system_variables from dify_graph.enums import WorkflowNodeExecutionStatus from dify_graph.model_runtime.entities.llm_entities import LLMUsage from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable from dify_graph.variables import StringSegment from tests.workflow_test_utils import build_test_graph_init_params @@ -43,7 +43,7 @@ def mock_graph_init_params(): def mock_graph_runtime_state(): """Create mock GraphRuntimeState.""" variable_pool = VariablePool( - system_variables=SystemVariable(user_id=str(uuid.uuid4()), files=[]), + system_variables=build_system_variables(user_id=str(uuid.uuid4()), files=[]), user_inputs={}, environment_variables=[], conversation_variables=[], diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_file_saver.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_file_saver.py index aa17391ad5..ba433b27b5 100644 --- a/api/tests/unit_tests/core/workflow/nodes/llm/test_file_saver.py +++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_file_saver.py @@ -6,7 +6,6 @@ from unittest.mock import MagicMock import httpx import pytest -from core.tools.tool_file_manager import ToolFileManager from dify_graph.file import FileTransferMethod, FileType from dify_graph.nodes.llm.file_saver import ( FileSaverImpl, @@ -14,6 +13,7 @@ from dify_graph.nodes.llm.file_saver import ( _get_extension, _validate_extension_override, ) +from dify_graph.nodes.protocols import ToolFileManagerProtocol _PNG_DATA = b"\x89PNG\r\n\x1a\n" @@ -30,7 +30,7 @@ class TestFileSaverImpl: mock_tool_file.id = _gen_id() mock_tool_file.name = f"{_gen_id()}.png" mock_tool_file.file_key = "test-file-key" - mocked_tool_file_manager = mock.MagicMock(spec=ToolFileManager) + mocked_tool_file_manager = mock.MagicMock(spec=ToolFileManagerProtocol) mocked_tool_file_manager.create_file_by_raw.return_value = mock_tool_file file_reference = MagicMock() file_reference_factory = MagicMock() @@ -47,7 +47,6 @@ class TestFileSaverImpl: assert file is file_reference mocked_tool_file_manager.create_file_by_raw.assert_called_once_with( - conversation_id=None, file_binary=_PNG_DATA, mimetype=mime_type, ) diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py index 350ba55c5b..a4fcad9b09 100644 --- a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py @@ -14,9 +14,9 @@ from core.app.llm.model_access import ( ) from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle from core.entities.provider_entities import CustomConfiguration, SystemConfiguration -from core.model_manager import ModelInstance from core.plugin.impl.model_runtime_factory import create_plugin_model_runtime from core.prompt.entities.advanced_prompt_entities import MemoryConfig +from core.workflow.system_variables import default_system_variables from dify_graph.entities import GraphInitParams from dify_graph.file import File, FileTransferMethod, FileType from dify_graph.model_runtime.entities.common_entities import I18nObject @@ -44,7 +44,6 @@ from dify_graph.nodes.llm.node import LLMNode, _handle_memory_completion_mode from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory from dify_graph.nodes.llm.runtime_protocols import PromptMessageSerializerProtocol from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable from dify_graph.variables import ArrayAnySegment, ArrayFileSegment, NoneSegment from models.provider import ProviderType from tests.workflow_test_utils import build_test_graph_init_params @@ -62,6 +61,24 @@ class MockTokenBufferMemory: return self.history_messages +def _build_prepared_llm_mock() -> mock.MagicMock: + model_instance = mock.MagicMock() + model_instance.provider = "openai" + model_instance.model_name = "gpt-3.5-turbo" + model_instance.parameters = {} + model_instance.stop = () + model_instance.get_llm_num_tokens.return_value = 0 + model_instance.get_model_schema.return_value = AIModelEntity( + model="gpt-3.5-turbo", + label=I18nObject(en_US="GPT-3.5 Turbo"), + model_type=ModelType.LLM, + fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, + model_properties={}, + ) + model_instance.is_structured_output_parse_error.return_value = False + return model_instance + + @pytest.fixture def llm_node_data() -> LLMNodeData: return LLMNodeData( @@ -98,7 +115,7 @@ def graph_init_params() -> GraphInitParams: @pytest.fixture def graph_runtime_state() -> GraphRuntimeState: variable_pool = VariablePool( - system_variables=SystemVariable.default(), + system_variables=default_system_variables(), user_inputs={}, ) return GraphRuntimeState( @@ -127,7 +144,7 @@ def llm_node( graph_runtime_state=graph_runtime_state, credentials_provider=mock_credentials_provider, model_factory=mock_model_factory, - model_instance=mock.MagicMock(spec=ModelInstance), + model_instance=_build_prepared_llm_mock(), llm_file_saver=mock_file_saver, prompt_message_serializer=mock_prompt_message_serializer, http_client=http_client, @@ -314,7 +331,6 @@ def test_dify_model_access_adapters_call_managers(): def test_fetch_files_with_file_segment(): file = File( id="1", - tenant_id="test", type=FileType.IMAGE, filename="test.jpg", transfer_method=FileTransferMethod.LOCAL_FILE, @@ -332,7 +348,6 @@ def test_fetch_files_with_array_file_segment(): files = [ File( id="1", - tenant_id="test", type=FileType.IMAGE, filename="test1.jpg", transfer_method=FileTransferMethod.LOCAL_FILE, @@ -341,7 +356,6 @@ def test_fetch_files_with_array_file_segment(): ), File( id="2", - tenant_id="test", type=FileType.IMAGE, filename="test2.jpg", transfer_method=FileTransferMethod.LOCAL_FILE, @@ -391,7 +405,6 @@ def test_fetch_files_with_non_existent_variable(): # files = [ # File( # id="1", -# tenant_id="test", # type=FileType.IMAGE, # filename="test1.jpg", # transfer_method=FileTransferMethod.REMOTE_URL, @@ -496,7 +509,6 @@ def test_fetch_files_with_non_existent_variable(): # sys_query=fake_query, # sys_files=[ # File( -# tenant_id="test", # type=FileType.IMAGE, # filename="test1.jpg", # transfer_method=FileTransferMethod.REMOTE_URL, @@ -572,7 +584,6 @@ def test_fetch_files_with_non_existent_variable(): # + [UserPromptMessage(content=fake_query)], # file_variables={ # "input.image": File( -# tenant_id="test", # type=FileType.IMAGE, # filename="test1.jpg", # transfer_method=FileTransferMethod.REMOTE_URL, @@ -656,7 +667,7 @@ def test_handle_memory_completion_mode_uses_prompt_message_interface(): AssistantPromptMessage(content="first answer"), ] - model_instance = mock.MagicMock(spec=ModelInstance) + model_instance = _build_prepared_llm_mock() memory_config = MemoryConfig( role_prefix=MemoryConfig.RolePrefix(user="Human", assistant="Assistant"), @@ -693,7 +704,7 @@ def llm_node_for_multimodal(llm_node_data, graph_init_params, graph_runtime_stat graph_runtime_state=graph_runtime_state, credentials_provider=mock_credentials_provider, model_factory=mock_model_factory, - model_instance=mock.MagicMock(spec=ModelInstance), + model_instance=_build_prepared_llm_mock(), llm_file_saver=mock_file_saver, prompt_message_serializer=mock_prompt_message_serializer, http_client=http_client, @@ -711,7 +722,6 @@ class TestLLMNodeSaveMultiModalImageOutput: ) mock_file = File( id=str(uuid.uuid4()), - tenant_id="1", type=FileType.IMAGE, transfer_method=FileTransferMethod.TOOL_FILE, related_id=str(uuid.uuid4()), @@ -742,7 +752,6 @@ class TestLLMNodeSaveMultiModalImageOutput: ) mock_file = File( id=str(uuid.uuid4()), - tenant_id="1", type=FileType.IMAGE, transfer_method=FileTransferMethod.TOOL_FILE, related_id=str(uuid.uuid4()), @@ -797,7 +806,6 @@ class TestSaveMultimodalOutputAndConvertResultToMarkdown: mock_saved_file = File( id=str(uuid.uuid4()), - tenant_id="1", type=FileType.IMAGE, transfer_method=FileTransferMethod.TOOL_FILE, filename="test.png", diff --git a/api/tests/unit_tests/core/workflow/nodes/test_base_node.py b/api/tests/unit_tests/core/workflow/nodes/test_base_node.py index 6cb7e18ca5..9f8b7694d7 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_base_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_base_node.py @@ -4,13 +4,13 @@ import pytest from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.workflow.node_runtime import resolve_dify_run_context +from core.workflow.system_variables import build_system_variables from dify_graph.entities import GraphInitParams from dify_graph.entities.base_node_data import BaseNodeData from dify_graph.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter from dify_graph.enums import BuiltinNodeTypes from dify_graph.nodes.base.node import Node from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable from tests.workflow_test_utils import build_test_graph_init_params @@ -36,7 +36,7 @@ def _build_context(graph_config: Mapping[str, object]) -> tuple[GraphInitParams, invoke_from="debugger", ) runtime_state = GraphRuntimeState( - variable_pool=VariablePool(system_variables=SystemVariable(user_id="user", files=[]), user_inputs={}), + variable_pool=VariablePool(system_variables=build_system_variables(user_id="user", files=[]), user_inputs={}), start_at=0.0, ) return init_params, runtime_state @@ -81,7 +81,7 @@ def test_node_accepts_invoke_from_enum(): invoke_from=InvokeFrom.DEBUGGER, ) runtime_state = GraphRuntimeState( - variable_pool=VariablePool(system_variables=SystemVariable(user_id="user", files=[]), user_inputs={}), + variable_pool=VariablePool(system_variables=build_system_variables(user_id="user", files=[]), user_inputs={}), start_at=0.0, ) diff --git a/api/tests/unit_tests/core/workflow/nodes/test_if_else.py b/api/tests/unit_tests/core/workflow/nodes/test_if_else.py index f584fcda3e..1bc0bb8cb5 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_if_else.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_if_else.py @@ -6,13 +6,13 @@ import pytest from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, InvokeFrom, UserFrom from core.workflow.node_factory import DifyNodeFactory +from core.workflow.system_variables import build_system_variables from dify_graph.enums import WorkflowNodeExecutionStatus from dify_graph.file import File, FileTransferMethod, FileType from dify_graph.graph import Graph from dify_graph.nodes.if_else.entities import IfElseNodeData from dify_graph.nodes.if_else.if_else_node import IfElseNode from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable from dify_graph.utils.condition.entities import Condition, SubCondition, SubVariableCondition from dify_graph.variables import ArrayFileSegment from extensions.ext_database import db @@ -34,7 +34,7 @@ def test_execute_if_else_result_true(): ) # construct variable pool - pool = VariablePool(system_variables=SystemVariable(user_id="aaa", files=[]), user_inputs={}) + pool = VariablePool(system_variables=build_system_variables(user_id="aaa", files=[]), user_inputs={}) pool.add(["start", "array_contains"], ["ab", "def"]) pool.add(["start", "array_not_contains"], ["ac", "def"]) pool.add(["start", "contains"], "cabcde") @@ -141,7 +141,7 @@ def test_execute_if_else_result_false(): # construct variable pool pool = VariablePool( - system_variables=SystemVariable(user_id="aaa", files=[]), + system_variables=build_system_variables(user_id="aaa", files=[]), user_inputs={}, environment_variables=[], ) @@ -252,7 +252,6 @@ def test_array_file_contains_file_name(): node.graph_runtime_state.variable_pool.get.return_value = ArrayFileSegment( value=[ File( - tenant_id="1", type=FileType.IMAGE, transfer_method=FileTransferMethod.LOCAL_FILE, related_id="1", @@ -315,7 +314,7 @@ def test_execute_if_else_boolean_conditions(condition: Condition): # construct variable pool with boolean values pool = VariablePool( - system_variables=SystemVariable(files=[], user_id="aaa"), + system_variables=build_system_variables(files=[], user_id="aaa"), ) pool.add(["start", "bool_true"], True) pool.add(["start", "bool_false"], False) @@ -370,7 +369,7 @@ def test_execute_if_else_boolean_false_conditions(): # construct variable pool with boolean values pool = VariablePool( - system_variables=SystemVariable(files=[], user_id="aaa"), + system_variables=build_system_variables(files=[], user_id="aaa"), ) pool.add(["start", "bool_true"], True) pool.add(["start", "bool_false"], False) @@ -439,7 +438,7 @@ def test_execute_if_else_boolean_cases_structure(): # construct variable pool with boolean values pool = VariablePool( - system_variables=SystemVariable(files=[], user_id="aaa"), + system_variables=build_system_variables(files=[], user_id="aaa"), ) pool.add(["start", "bool_true"], True) pool.add(["start", "bool_false"], False) diff --git a/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py b/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py index 2eee22a678..d1c25da489 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py @@ -71,7 +71,6 @@ def test_filter_files_by_type(list_operator_node): File( filename="image1.jpg", type=FileType.IMAGE, - tenant_id="tenant1", transfer_method=FileTransferMethod.LOCAL_FILE, related_id="related1", storage_key="", @@ -79,7 +78,6 @@ def test_filter_files_by_type(list_operator_node): File( filename="document1.pdf", type=FileType.DOCUMENT, - tenant_id="tenant1", transfer_method=FileTransferMethod.LOCAL_FILE, related_id="related2", storage_key="", @@ -87,7 +85,6 @@ def test_filter_files_by_type(list_operator_node): File( filename="image2.png", type=FileType.IMAGE, - tenant_id="tenant1", transfer_method=FileTransferMethod.LOCAL_FILE, related_id="related3", storage_key="", @@ -95,7 +92,6 @@ def test_filter_files_by_type(list_operator_node): File( filename="audio1.mp3", type=FileType.AUDIO, - tenant_id="tenant1", transfer_method=FileTransferMethod.LOCAL_FILE, related_id="related4", storage_key="", @@ -119,14 +115,12 @@ def test_filter_files_by_type(list_operator_node): { "filename": "document1.pdf", "type": FileType.DOCUMENT, - "tenant_id": "tenant1", "transfer_method": FileTransferMethod.LOCAL_FILE, "related_id": "related2", }, { "filename": "image2.png", "type": FileType.IMAGE, - "tenant_id": "tenant1", "transfer_method": FileTransferMethod.LOCAL_FILE, "related_id": "related3", }, @@ -135,7 +129,6 @@ def test_filter_files_by_type(list_operator_node): for expected_file, result_file in zip(expected_files, result.outputs["result"].value): assert expected_file["filename"] == result_file.filename assert expected_file["type"] == result_file.type - assert expected_file["tenant_id"] == result_file.tenant_id assert expected_file["transfer_method"] == result_file.transfer_method assert expected_file["related_id"] == result_file.related_id @@ -143,7 +136,6 @@ def test_filter_files_by_type(list_operator_node): def test_get_file_extract_string_func(): # Create a File object file = File( - tenant_id="test_tenant", type=FileType.DOCUMENT, transfer_method=FileTransferMethod.LOCAL_FILE, filename="test_file.txt", @@ -164,7 +156,6 @@ def test_get_file_extract_string_func(): # Test with empty values empty_file = File( - tenant_id="test_tenant", type=FileType.DOCUMENT, transfer_method=FileTransferMethod.LOCAL_FILE, filename=None, diff --git a/api/tests/unit_tests/core/workflow/nodes/test_question_classifier_node.py b/api/tests/unit_tests/core/workflow/nodes/test_question_classifier_node.py index c5a02e87e4..fc7176de46 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_question_classifier_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_question_classifier_node.py @@ -2,12 +2,13 @@ from types import SimpleNamespace from unittest.mock import MagicMock from dify_graph.model_runtime.entities import ImagePromptMessageContent -from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory, TemplateRenderer +from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory from dify_graph.nodes.protocols import HttpClientProtocol from dify_graph.nodes.question_classifier import ( QuestionClassifierNode, QuestionClassifierNodeData, ) +from dify_graph.template_rendering import Jinja2TemplateRenderer from tests.workflow_test_utils import build_test_graph_init_params @@ -86,7 +87,7 @@ def test_question_classifier_calculate_rest_token_uses_shared_prompt_builder(mon "instruction": "This is a test instruction", } ) - template_renderer = MagicMock(spec=TemplateRenderer) + template_renderer = MagicMock(spec=Jinja2TemplateRenderer) node = QuestionClassifierNode( id="node-id", config={"id": "node-id", "data": node_data.model_dump(mode="json")}, diff --git a/api/tests/unit_tests/core/workflow/nodes/test_start_node_json_object.py b/api/tests/unit_tests/core/workflow/nodes/test_start_node_json_object.py index b8f0e25e91..4b1545b9e8 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_start_node_json_object.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_start_node_json_object.py @@ -4,19 +4,19 @@ import time import pytest from pydantic import ValidationError as PydanticValidationError +from core.workflow.system_variables import build_system_variables from dify_graph.nodes.start.entities import StartNodeData from dify_graph.nodes.start.start_node import StartNode -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable +from dify_graph.runtime import GraphRuntimeState from dify_graph.variables.input_entities import VariableEntity, VariableEntityType -from tests.workflow_test_utils import build_test_graph_init_params +from tests.workflow_test_utils import build_test_graph_init_params, build_test_variable_pool def make_start_node(user_inputs, variables): - variable_pool = VariablePool( - system_variables=SystemVariable(), - user_inputs=user_inputs, - conversation_variables=[], + variable_pool = build_test_variable_pool( + variables=build_system_variables(), + node_id="start", + inputs=user_inputs, ) config = { diff --git a/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py b/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py index 344ded7180..2a954fbaf8 100644 --- a/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py @@ -3,17 +3,18 @@ from __future__ import annotations import sys import types from collections.abc import Generator +from types import SimpleNamespace from typing import TYPE_CHECKING, Any from unittest.mock import MagicMock import pytest +from core.workflow.system_variables import build_system_variables from dify_graph.file import File, FileTransferMethod, FileType from dify_graph.model_runtime.entities.llm_entities import LLMUsage from dify_graph.node_events import StreamChunkEvent, StreamCompletedEvent from dify_graph.nodes.tool_runtime_entities import ToolRuntimeHandle, ToolRuntimeMessage from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable from dify_graph.variables.segments import ArrayFileSegment from tests.workflow_test_utils import build_test_graph_init_params @@ -34,7 +35,6 @@ class _StubToolRuntime: tool_runtime: ToolRuntimeHandle, tool_parameters: dict[str, Any], workflow_call_depth: int, - conversation_id: str | None, provider_name: str, ) -> Generator[ToolRuntimeMessage, None, None]: yield from () @@ -98,7 +98,7 @@ def tool_node(monkeypatch) -> ToolNode: call_depth=0, ) - variable_pool = VariablePool(system_variables=SystemVariable(user_id="user-id")) + variable_pool = VariablePool(system_variables=build_system_variables(user_id="user-id")) graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=0.0) config = graph_config["nodes"][0] @@ -140,7 +140,6 @@ def _run_transform(tool_node: ToolNode, message: ToolRuntimeMessage) -> tuple[li def test_link_messages_with_file_populate_files_output(tool_node: ToolNode): file_obj = File( - tenant_id="tenant-id", type=FileType.DOCUMENT, transfer_method=FileTransferMethod.TOOL_FILE, related_id="file-id", @@ -192,3 +191,35 @@ def test_plain_link_messages_remain_links(tool_node: ToolNode): files_segment = completed_events[0].node_run_result.outputs["files"] assert isinstance(files_segment, ArrayFileSegment) assert files_segment.value == [] + + +def test_image_link_messages_use_tool_file_id_metadata(tool_node: ToolNode): + file_obj = File( + type=FileType.DOCUMENT, + transfer_method=FileTransferMethod.TOOL_FILE, + related_id="file-id", + filename="demo.pdf", + extension=".pdf", + mime_type="application/pdf", + size=123, + storage_key="file-key", + ) + tool_node._tool_file_manager_factory.get_file_generator_by_tool_file_id.return_value = ( + None, + SimpleNamespace(mime_type="application/pdf"), + ) + tool_node._runtime.build_file_reference = MagicMock(return_value=file_obj) + message = ToolRuntimeMessage( + type=ToolRuntimeMessage.MessageType.IMAGE_LINK, + message=ToolRuntimeMessage.TextMessage(text="/files/tools/file-id.pdf"), + meta={"tool_file_id": "file-id"}, + ) + + events, _ = _run_transform(tool_node, message) + + tool_node._tool_file_manager_factory.get_file_generator_by_tool_file_id.assert_called_once_with("file-id") + completed_events = [event for event in events if isinstance(event, StreamCompletedEvent)] + assert len(completed_events) == 1 + files_segment = completed_events[0].node_run_result.outputs["files"] + assert isinstance(files_segment, ArrayFileSegment) + assert files_segment.value == [file_obj] diff --git a/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node_runtime.py b/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node_runtime.py index b822821f1b..c44a0a674e 100644 --- a/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node_runtime.py +++ b/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node_runtime.py @@ -15,11 +15,13 @@ from core.tools.tool_engine import ToolEngine from core.tools.tool_manager import ToolManager from core.tools.utils.message_transformer import ToolFileMessageTransformer from core.workflow.node_runtime import DifyToolNodeRuntime +from core.workflow.system_variables import build_system_variables from dify_graph.model_runtime.entities.llm_entities import LLMUsage from dify_graph.nodes.tool.entities import ToolNodeData, ToolProviderType from dify_graph.nodes.tool.exc import ToolRuntimeInvocationError from dify_graph.nodes.tool_runtime_entities import ToolRuntimeHandle, ToolRuntimeMessage -from tests.workflow_test_utils import build_test_graph_init_params +from dify_graph.runtime import VariablePool +from tests.workflow_test_utils import build_test_graph_init_params, build_test_variable_pool @pytest.fixture @@ -44,31 +46,57 @@ def runtime(monkeypatch) -> DifyToolNodeRuntime: return DifyToolNodeRuntime(init_params.run_context) +def _build_tool_node_data() -> ToolNodeData: + return ToolNodeData.model_validate( + { + "type": "tool", + "title": "Tool", + "provider_id": "provider", + "provider_type": ToolProviderType.BUILT_IN, + "provider_name": "provider", + "tool_name": "lookup", + "tool_label": "Lookup", + "tool_configurations": {}, + "tool_parameters": {}, + } + ) + + def test_invoke_creates_callback_and_converts_messages(runtime: DifyToolNodeRuntime) -> None: core_message = ToolInvokeMessage( type=ToolInvokeMessage.MessageType.LINK, message=ToolInvokeMessage.TextMessage(text="https://dify.ai"), meta=None, ) + variable_pool: VariablePool = build_test_variable_pool( + variables=build_system_variables(conversation_id="conversation-id") + ) + workflow_tool = MagicMock() with ( + patch.object(ToolManager, "get_workflow_tool_runtime", return_value=workflow_tool), patch.object(ToolEngine, "generic_invoke", return_value=iter([core_message])) as generic_invoke_mock, patch.object( ToolFileMessageTransformer, "transform_tool_invoke_messages", side_effect=lambda *, messages, **_: messages, - ), + ) as transform_tool_messages, ): + tool_runtime = runtime.get_runtime( + node_id="node-id", + node_data=_build_tool_node_data(), + variable_pool=variable_pool, + ) messages = list( runtime.invoke( - tool_runtime=ToolRuntimeHandle(raw=MagicMock()), + tool_runtime=tool_runtime, tool_parameters={}, workflow_call_depth=0, - conversation_id=None, provider_name="provider", ) ) + assert not hasattr(tool_runtime, "conversation_id") assert len(messages) == 1 graph_message = messages[0] assert graph_message.type == ToolRuntimeMessage.MessageType.LINK @@ -77,6 +105,10 @@ def test_invoke_creates_callback_and_converts_messages(runtime: DifyToolNodeRunt callback = generic_invoke_mock.call_args.kwargs["workflow_tool_callback"] assert isinstance(callback, DifyWorkflowCallbackHandler) + assert generic_invoke_mock.call_args.kwargs["conversation_id"] == "conversation-id" + + transform_kwargs = transform_tool_messages.call_args.kwargs + assert transform_kwargs["conversation_id"] == "conversation-id" def test_invoke_maps_plugin_errors_to_graph_errors(runtime: DifyToolNodeRuntime) -> None: @@ -88,7 +120,6 @@ def test_invoke_maps_plugin_errors_to_graph_errors(runtime: DifyToolNodeRuntime) tool_runtime=ToolRuntimeHandle(raw=MagicMock()), tool_parameters={}, workflow_call_depth=0, - conversation_id=None, provider_name="provider", ) @@ -105,22 +136,11 @@ def test_get_usage_normalizes_dict_payload(runtime: DifyToolNodeRuntime) -> None def test_get_runtime_converts_graph_provider_type_for_tool_manager(runtime: DifyToolNodeRuntime) -> None: - node_data = ToolNodeData.model_validate( - { - "type": "tool", - "title": "Tool", - "provider_id": "provider", - "provider_type": ToolProviderType.BUILT_IN, - "provider_name": "provider", - "tool_name": "lookup", - "tool_label": "Lookup", - "tool_configurations": {}, - "tool_parameters": {}, - } - ) + node_data = _build_tool_node_data() with patch.object(ToolManager, "get_workflow_tool_runtime", return_value=MagicMock()) as runtime_mock: - runtime.get_runtime(node_id="node-id", node_data=node_data, variable_pool=None) + tool_runtime = runtime.get_runtime(node_id="node-id", node_data=node_data, variable_pool=None) + assert not hasattr(tool_runtime, "conversation_id") workflow_tool = runtime_mock.call_args.args[3] assert workflow_tool.provider_type == CoreToolProviderType.BUILT_IN diff --git a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py index 6a0548fa1c..1d09324927 100644 --- a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py +++ b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py @@ -4,19 +4,36 @@ from uuid import uuid4 from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, InvokeFrom, UserFrom from core.workflow.node_factory import DifyNodeFactory +from core.workflow.system_variables import build_bootstrap_variables, build_system_variables +from core.workflow.variable_pool_initializer import add_variables_to_pool from dify_graph.entities import GraphInitParams from dify_graph.graph import Graph -from dify_graph.graph_events.node import NodeRunSucceededEvent +from dify_graph.graph_events.node import NodeRunSucceededEvent, NodeRunVariableUpdatedEvent from dify_graph.nodes.variable_assigner.common import helpers as common_helpers from dify_graph.nodes.variable_assigner.v1 import VariableAssignerNode from dify_graph.nodes.variable_assigner.v1.node_data import WriteMode from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable from dify_graph.variables import ArrayStringVariable, StringVariable DEFAULT_NODE_ID = "node_id" +def _build_variable_pool( + *, + conversation_id: str, + conversation_variables: list[StringVariable | ArrayStringVariable], +) -> VariablePool: + variable_pool = VariablePool() + add_variables_to_pool( + variable_pool, + build_bootstrap_variables( + system_variables=build_system_variables(conversation_id=conversation_id), + conversation_variables=conversation_variables, + ), + ) + return variable_pool + + def test_overwrite_string_variable(): graph_config = { "edges": [ @@ -70,10 +87,8 @@ def test_overwrite_string_variable(): conversation_id = str(uuid.uuid4()) # construct variable pool - variable_pool = VariablePool( - system_variables=SystemVariable(conversation_id=conversation_id), - user_inputs={}, - environment_variables=[], + variable_pool = _build_variable_pool( + conversation_id=conversation_id, conversation_variables=[conversation_variable], ) @@ -107,16 +122,14 @@ def test_overwrite_string_variable(): ) events = list(node.run()) + updated_event = next(event for event in events if isinstance(event, NodeRunVariableUpdatedEvent)) succeeded_event = next(event for event in events if isinstance(event, NodeRunSucceededEvent)) updated_variables = common_helpers.get_updated_variables(succeeded_event.node_run_result.process_data) assert updated_variables is not None assert updated_variables[0].name == conversation_variable.name assert updated_variables[0].new_value == input_variable.value - - got = variable_pool.get(["conversation", conversation_variable.name]) - assert got is not None - assert got.value == "the second value" - assert got.to_object() == "the second value" + assert updated_event.variable.value == "the second value" + assert tuple(updated_event.variable.selector) == ("conversation", conversation_variable.name) def test_append_variable_to_array(): @@ -171,10 +184,8 @@ def test_append_variable_to_array(): ) conversation_id = str(uuid.uuid4()) - variable_pool = VariablePool( - system_variables=SystemVariable(conversation_id=conversation_id), - user_inputs={}, - environment_variables=[], + variable_pool = _build_variable_pool( + conversation_id=conversation_id, conversation_variables=[conversation_variable], ) variable_pool.add( @@ -207,15 +218,13 @@ def test_append_variable_to_array(): ) events = list(node.run()) + updated_event = next(event for event in events if isinstance(event, NodeRunVariableUpdatedEvent)) succeeded_event = next(event for event in events if isinstance(event, NodeRunSucceededEvent)) updated_variables = common_helpers.get_updated_variables(succeeded_event.node_run_result.process_data) assert updated_variables is not None assert updated_variables[0].name == conversation_variable.name assert updated_variables[0].new_value == ["the first value", "the second value"] - - got = variable_pool.get(["conversation", conversation_variable.name]) - assert got is not None - assert got.to_object() == ["the first value", "the second value"] + assert updated_event.variable.value == ["the first value", "the second value"] def test_clear_array(): @@ -264,10 +273,8 @@ def test_clear_array(): ) conversation_id = str(uuid.uuid4()) - variable_pool = VariablePool( - system_variables=SystemVariable(conversation_id=conversation_id), - user_inputs={}, - environment_variables=[], + variable_pool = _build_variable_pool( + conversation_id=conversation_id, conversation_variables=[conversation_variable], ) @@ -296,12 +303,10 @@ def test_clear_array(): ) events = list(node.run()) + updated_event = next(event for event in events if isinstance(event, NodeRunVariableUpdatedEvent)) succeeded_event = next(event for event in events if isinstance(event, NodeRunSucceededEvent)) updated_variables = common_helpers.get_updated_variables(succeeded_event.node_run_result.process_data) assert updated_variables is not None assert updated_variables[0].name == conversation_variable.name assert updated_variables[0].new_value == [] - - got = variable_pool.get(["conversation", conversation_variable.name]) - assert got is not None - assert got.to_object() == [] + assert updated_event.variable.value == [] diff --git a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_variable_assigner_v2.py b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_variable_assigner_v2.py index dd338651b9..f61aea98bc 100644 --- a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_variable_assigner_v2.py +++ b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_variable_assigner_v2.py @@ -4,17 +4,31 @@ from uuid import uuid4 from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, InvokeFrom, UserFrom from core.workflow.node_factory import DifyNodeFactory +from core.workflow.system_variables import build_bootstrap_variables, build_system_variables +from core.workflow.variable_pool_initializer import add_variables_to_pool from dify_graph.entities import GraphInitParams from dify_graph.graph import Graph +from dify_graph.graph_events import NodeRunVariableUpdatedEvent from dify_graph.nodes.variable_assigner.v2 import VariableAssignerNode from dify_graph.nodes.variable_assigner.v2.enums import InputType, Operation from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable from dify_graph.variables import ArrayStringVariable DEFAULT_NODE_ID = "node_id" +def _build_variable_pool(*, conversation_variables: list[ArrayStringVariable]) -> VariablePool: + variable_pool = VariablePool() + add_variables_to_pool( + variable_pool, + build_bootstrap_variables( + system_variables=build_system_variables(conversation_id="conversation_id"), + conversation_variables=conversation_variables, + ), + ) + return variable_pool + + def test_handle_item_directly(): """Test the _handle_item method directly for remove operations.""" # Create variables @@ -105,12 +119,7 @@ def test_remove_first_from_array(): selector=["conversation", "test_conversation_variable"], ) - variable_pool = VariablePool( - system_variables=SystemVariable(conversation_id="conversation_id"), - user_inputs={}, - environment_variables=[], - conversation_variables=[conversation_variable], - ) + variable_pool = _build_variable_pool(conversation_variables=[conversation_variable]) graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) node_factory = DifyNodeFactory( @@ -145,11 +154,8 @@ def test_remove_first_from_array(): # Run the node result = list(node.run()) - # Completed run - - got = variable_pool.get(["conversation", conversation_variable.name]) - assert got is not None - assert got.to_object() == ["second", "third"] + updated_event = next(event for event in result if isinstance(event, NodeRunVariableUpdatedEvent)) + assert updated_event.variable.value == ["second", "third"] def test_remove_last_from_array(): @@ -193,12 +199,7 @@ def test_remove_last_from_array(): selector=["conversation", "test_conversation_variable"], ) - variable_pool = VariablePool( - system_variables=SystemVariable(conversation_id="conversation_id"), - user_inputs={}, - environment_variables=[], - conversation_variables=[conversation_variable], - ) + variable_pool = _build_variable_pool(conversation_variables=[conversation_variable]) graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) node_factory = DifyNodeFactory( @@ -230,11 +231,9 @@ def test_remove_last_from_array(): config=node_config, ) - list(node.run()) - - got = variable_pool.get(["conversation", conversation_variable.name]) - assert got is not None - assert got.to_object() == ["first", "second"] + result = list(node.run()) + updated_event = next(event for event in result if isinstance(event, NodeRunVariableUpdatedEvent)) + assert updated_event.variable.value == ["first", "second"] def test_remove_first_from_empty_array(): @@ -278,12 +277,7 @@ def test_remove_first_from_empty_array(): selector=["conversation", "test_conversation_variable"], ) - variable_pool = VariablePool( - system_variables=SystemVariable(conversation_id="conversation_id"), - user_inputs={}, - environment_variables=[], - conversation_variables=[conversation_variable], - ) + variable_pool = _build_variable_pool(conversation_variables=[conversation_variable]) graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) node_factory = DifyNodeFactory( @@ -315,11 +309,9 @@ def test_remove_first_from_empty_array(): config=node_config, ) - list(node.run()) - - got = variable_pool.get(["conversation", conversation_variable.name]) - assert got is not None - assert got.to_object() == [] + result = list(node.run()) + updated_event = next(event for event in result if isinstance(event, NodeRunVariableUpdatedEvent)) + assert updated_event.variable.value == [] def test_remove_last_from_empty_array(): @@ -363,12 +355,7 @@ def test_remove_last_from_empty_array(): selector=["conversation", "test_conversation_variable"], ) - variable_pool = VariablePool( - system_variables=SystemVariable(conversation_id="conversation_id"), - user_inputs={}, - environment_variables=[], - conversation_variables=[conversation_variable], - ) + variable_pool = _build_variable_pool(conversation_variables=[conversation_variable]) graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) node_factory = DifyNodeFactory( @@ -400,11 +387,9 @@ def test_remove_last_from_empty_array(): config=node_config, ) - list(node.run()) - - got = variable_pool.get(["conversation", conversation_variable.name]) - assert got is not None - assert got.to_object() == [] + result = list(node.run()) + updated_event = next(event for event in result if isinstance(event, NodeRunVariableUpdatedEvent)) + assert updated_event.variable.value == [] def test_node_factory_creates_variable_assigner_node(): @@ -432,12 +417,7 @@ def test_node_factory_creates_variable_assigner_node(): }, call_depth=0, ) - variable_pool = VariablePool( - system_variables=SystemVariable(conversation_id="conversation_id"), - user_inputs={}, - environment_variables=[], - conversation_variables=[], - ) + variable_pool = _build_variable_pool(conversation_variables=[]) graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) node_factory = DifyNodeFactory( diff --git a/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_file_conversion.py b/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_file_conversion.py index b8d243a66d..6a47ef7ff8 100644 --- a/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_file_conversion.py +++ b/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_file_conversion.py @@ -16,11 +16,12 @@ from core.workflow.nodes.trigger_webhook.entities import ( WebhookData, ) from core.workflow.nodes.trigger_webhook.node import TriggerWebhookNode +from core.workflow.system_variables import default_system_variables from dify_graph.entities.graph_init_params import GraphInitParams from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus from dify_graph.runtime.graph_runtime_state import GraphRuntimeState from dify_graph.runtime.variable_pool import VariablePool -from dify_graph.system_variable import SystemVariable +from tests.workflow_test_utils import build_test_variable_pool def create_webhook_node( @@ -96,6 +97,18 @@ def create_test_file_dict( } +def build_webhook_variable_pool(inputs: dict) -> VariablePool: + return build_test_variable_pool( + variables=default_system_variables(), + node_id="webhook-node-1", + inputs=inputs, + ) + + +def expected_factory_mapping(file_dict: dict) -> dict: + return {**file_dict, "upload_file_id": file_dict["related_id"]} + + def test_webhook_node_file_conversion_to_file_variable(): """Test that webhook node converts file dictionaries to FileVariable objects.""" # Create test file dictionary (as it comes from webhook service) @@ -111,9 +124,8 @@ def test_webhook_node_file_conversion_to_file_variable(): ], ) - variable_pool = VariablePool( - system_variables=SystemVariable.default(), - user_inputs={ + variable_pool = build_webhook_variable_pool( + { "webhook_data": { "headers": {}, "query_params": {}, @@ -122,7 +134,7 @@ def test_webhook_node_file_conversion_to_file_variable(): "image_upload": file_dict, }, } - }, + } ) node = create_webhook_node(data, variable_pool) @@ -153,7 +165,7 @@ def test_webhook_node_file_conversion_to_file_variable(): # Verify file factory was called with correct parameters mock_file_factory.assert_called_once_with( - mapping=file_dict, + mapping=expected_factory_mapping(file_dict), tenant_id="test-tenant", ) @@ -184,16 +196,15 @@ def test_webhook_node_file_conversion_with_missing_files(): ], ) - variable_pool = VariablePool( - system_variables=SystemVariable.default(), - user_inputs={ + variable_pool = build_webhook_variable_pool( + { "webhook_data": { "headers": {}, "query_params": {}, "body": {}, "files": {}, # No files } - }, + } ) node = create_webhook_node(data, variable_pool) @@ -219,9 +230,8 @@ def test_webhook_node_file_conversion_with_none_file(): ], ) - variable_pool = VariablePool( - system_variables=SystemVariable.default(), - user_inputs={ + variable_pool = build_webhook_variable_pool( + { "webhook_data": { "headers": {}, "query_params": {}, @@ -230,7 +240,7 @@ def test_webhook_node_file_conversion_with_none_file(): "file": None, }, } - }, + } ) node = create_webhook_node(data, variable_pool) @@ -256,9 +266,8 @@ def test_webhook_node_file_conversion_with_non_dict_file(): ], ) - variable_pool = VariablePool( - system_variables=SystemVariable.default(), - user_inputs={ + variable_pool = build_webhook_variable_pool( + { "webhook_data": { "headers": {}, "query_params": {}, @@ -267,7 +276,7 @@ def test_webhook_node_file_conversion_with_non_dict_file(): "file": "not_a_dict", # Wrapped to match node expectation }, } - }, + } ) node = create_webhook_node(data, variable_pool) @@ -300,9 +309,8 @@ def test_webhook_node_file_conversion_mixed_parameters(): ], ) - variable_pool = VariablePool( - system_variables=SystemVariable.default(), - user_inputs={ + variable_pool = build_webhook_variable_pool( + { "webhook_data": { "headers": {}, "query_params": {}, @@ -315,7 +323,7 @@ def test_webhook_node_file_conversion_mixed_parameters(): "file_param": file_dict, }, } - }, + } ) node = create_webhook_node(data, variable_pool) @@ -350,7 +358,7 @@ def test_webhook_node_file_conversion_mixed_parameters(): # Verify file conversion was called mock_file_factory.assert_called_once_with( - mapping=file_dict, + mapping=expected_factory_mapping(file_dict), tenant_id="test-tenant", ) @@ -370,9 +378,8 @@ def test_webhook_node_different_file_types(): ], ) - variable_pool = VariablePool( - system_variables=SystemVariable.default(), - user_inputs={ + variable_pool = build_webhook_variable_pool( + { "webhook_data": { "headers": {}, "query_params": {}, @@ -383,7 +390,7 @@ def test_webhook_node_different_file_types(): "video": create_test_file_dict("video.mp4", "video"), }, } - }, + } ) node = create_webhook_node(data, variable_pool) @@ -430,9 +437,8 @@ def test_webhook_node_file_conversion_with_non_dict_wrapper(): ], ) - variable_pool = VariablePool( - system_variables=SystemVariable.default(), - user_inputs={ + variable_pool = build_webhook_variable_pool( + { "webhook_data": { "headers": {}, "query_params": {}, @@ -441,7 +447,7 @@ def test_webhook_node_file_conversion_with_non_dict_wrapper(): "file": "just a string", }, } - }, + } ) node = create_webhook_node(data, variable_pool) diff --git a/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py b/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py index 15cc07fbeb..b8a6877e14 100644 --- a/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py @@ -12,13 +12,14 @@ from core.workflow.nodes.trigger_webhook.entities import ( WebhookParameter, ) from core.workflow.nodes.trigger_webhook.node import TriggerWebhookNode +from core.workflow.system_variables import default_system_variables from dify_graph.entities.graph_init_params import GraphInitParams from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus from dify_graph.file import File, FileTransferMethod, FileType from dify_graph.runtime.graph_runtime_state import GraphRuntimeState from dify_graph.runtime.variable_pool import VariablePool -from dify_graph.system_variable import SystemVariable from dify_graph.variables import FileVariable, StringVariable +from tests.workflow_test_utils import build_test_variable_pool def create_webhook_node(webhook_data: WebhookData, variable_pool: VariablePool) -> TriggerWebhookNode: @@ -62,6 +63,14 @@ def create_webhook_node(webhook_data: WebhookData, variable_pool: VariablePool) return node +def build_webhook_variable_pool(inputs: dict) -> VariablePool: + return build_test_variable_pool( + variables=default_system_variables(), + node_id="1", + inputs=inputs, + ) + + def test_webhook_node_basic_initialization(): """Test basic webhook node initialization and configuration.""" data = WebhookData( @@ -76,10 +85,7 @@ def test_webhook_node_basic_initialization(): timeout=30, ) - variable_pool = VariablePool( - system_variables=SystemVariable.default(), - user_inputs={}, - ) + variable_pool = build_webhook_variable_pool({}) node = create_webhook_node(data, variable_pool) @@ -119,9 +125,8 @@ def test_webhook_node_run_with_headers(): ], ) - variable_pool = VariablePool( - system_variables=SystemVariable.default(), - user_inputs={ + variable_pool = build_webhook_variable_pool( + { "webhook_data": { "headers": { "Authorization": "Bearer token123", @@ -132,7 +137,7 @@ def test_webhook_node_run_with_headers(): "body": {}, "files": {}, } - }, + } ) node = create_webhook_node(data, variable_pool) @@ -155,9 +160,8 @@ def test_webhook_node_run_with_query_params(): ], ) - variable_pool = VariablePool( - system_variables=SystemVariable.default(), - user_inputs={ + variable_pool = build_webhook_variable_pool( + { "webhook_data": { "headers": {}, "query_params": { @@ -167,7 +171,7 @@ def test_webhook_node_run_with_query_params(): "body": {}, "files": {}, } - }, + } ) node = create_webhook_node(data, variable_pool) @@ -191,9 +195,8 @@ def test_webhook_node_run_with_body_params(): ], ) - variable_pool = VariablePool( - system_variables=SystemVariable.default(), - user_inputs={ + variable_pool = build_webhook_variable_pool( + { "webhook_data": { "headers": {}, "query_params": {}, @@ -205,7 +208,7 @@ def test_webhook_node_run_with_body_params(): }, "files": {}, } - }, + } ) node = create_webhook_node(data, variable_pool) @@ -222,7 +225,6 @@ def test_webhook_node_run_with_file_params(): """Test webhook node execution with file parameter extraction.""" # Create mock file objects file1 = File( - tenant_id="1", type=FileType.IMAGE, transfer_method=FileTransferMethod.LOCAL_FILE, related_id="file1", @@ -232,7 +234,6 @@ def test_webhook_node_run_with_file_params(): ) file2 = File( - tenant_id="1", type=FileType.DOCUMENT, transfer_method=FileTransferMethod.LOCAL_FILE, related_id="file2", @@ -250,9 +251,8 @@ def test_webhook_node_run_with_file_params(): ], ) - variable_pool = VariablePool( - system_variables=SystemVariable.default(), - user_inputs={ + variable_pool = build_webhook_variable_pool( + { "webhook_data": { "headers": {}, "query_params": {}, @@ -262,7 +262,7 @@ def test_webhook_node_run_with_file_params(): "document": file2.to_dict(), }, } - }, + } ) node = create_webhook_node(data, variable_pool) @@ -284,7 +284,6 @@ def test_webhook_node_run_with_file_params(): def test_webhook_node_run_mixed_parameters(): """Test webhook node execution with mixed parameter types.""" file_obj = File( - tenant_id="1", type=FileType.IMAGE, transfer_method=FileTransferMethod.LOCAL_FILE, related_id="file1", @@ -303,16 +302,15 @@ def test_webhook_node_run_mixed_parameters(): ], ) - variable_pool = VariablePool( - system_variables=SystemVariable.default(), - user_inputs={ + variable_pool = build_webhook_variable_pool( + { "webhook_data": { "headers": {"Authorization": "Bearer token"}, "query_params": {"version": "v1"}, "body": {"message": "Test message"}, "files": {"upload": file_obj.to_dict()}, } - }, + } ) node = create_webhook_node(data, variable_pool) @@ -343,10 +341,7 @@ def test_webhook_node_run_empty_webhook_data(): body=[WebhookBodyParameter(name="message", type="string", required=False)], ) - variable_pool = VariablePool( - system_variables=SystemVariable.default(), - user_inputs={}, # No webhook_data - ) + variable_pool = build_webhook_variable_pool({}) # No webhook_data node = create_webhook_node(data, variable_pool) result = node._run() @@ -369,9 +364,8 @@ def test_webhook_node_run_case_insensitive_headers(): ], ) - variable_pool = VariablePool( - system_variables=SystemVariable.default(), - user_inputs={ + variable_pool = build_webhook_variable_pool( + { "webhook_data": { "headers": { "content-type": "application/json", # lowercase @@ -382,7 +376,7 @@ def test_webhook_node_run_case_insensitive_headers(): "body": {}, "files": {}, } - }, + } ) node = create_webhook_node(data, variable_pool) @@ -399,12 +393,11 @@ def test_webhook_node_variable_pool_user_inputs(): data = WebhookData(title="Test Webhook") # Add some additional variables to the pool - variable_pool = VariablePool( - system_variables=SystemVariable.default(), - user_inputs={ + variable_pool = build_webhook_variable_pool( + { "webhook_data": {"headers": {}, "query_params": {}, "body": {}, "files": {}}, "other_var": "should_be_included", - }, + } ) variable_pool.add(["node1", "extra"], StringVariable(name="extra", value="extra_value")) @@ -430,16 +423,15 @@ def test_webhook_node_different_methods(method): method=method, ) - variable_pool = VariablePool( - system_variables=SystemVariable.default(), - user_inputs={ + variable_pool = build_webhook_variable_pool( + { "webhook_data": { "headers": {}, "query_params": {}, "body": {}, "files": {}, } - }, + } ) node = create_webhook_node(data, variable_pool) diff --git a/api/tests/unit_tests/core/workflow/test_node_factory.py b/api/tests/unit_tests/core/workflow/test_node_factory.py index e427c4b90a..b103111119 100644 --- a/api/tests/unit_tests/core/workflow/test_node_factory.py +++ b/api/tests/unit_tests/core/workflow/test_node_factory.py @@ -5,9 +5,10 @@ import pytest from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext, InvokeFrom, UserFrom from core.workflow import node_factory +from core.workflow import template_rendering as workflow_template_rendering from core.workflow.nodes.knowledge_index import KNOWLEDGE_INDEX_NODE_TYPE from dify_graph.entities.base_node_data import BaseNodeData -from dify_graph.enums import BuiltinNodeTypes, NodeType, SystemVariableKey +from dify_graph.enums import BuiltinNodeTypes, NodeType from dify_graph.nodes.code.entities import CodeLanguage from dify_graph.variables.segments import StringSegment @@ -139,6 +140,37 @@ class TestDefaultWorkflowCodeExecutor: assert executor.is_execution_error(RuntimeError("boom")) is False +class TestCodeExecutorJinja2TemplateRenderer: + def test_render_template_delegates_to_code_executor(self, monkeypatch): + renderer = workflow_template_rendering.CodeExecutorJinja2TemplateRenderer() + execute_workflow_code_template = MagicMock(return_value={"result": "Hello workflow"}) + monkeypatch.setattr( + workflow_template_rendering.CodeExecutor, + "execute_workflow_code_template", + execute_workflow_code_template, + ) + + result = renderer.render_template("Hello {{ name }}", {"name": "workflow"}) + + assert result == "Hello workflow" + execute_workflow_code_template.assert_called_once_with( + language=CodeLanguage.JINJA2, + code="Hello {{ name }}", + inputs={"name": "workflow"}, + ) + + def test_render_template_wraps_code_execution_errors(self, monkeypatch): + renderer = workflow_template_rendering.CodeExecutorJinja2TemplateRenderer() + monkeypatch.setattr( + workflow_template_rendering.CodeExecutor, + "execute_workflow_code_template", + MagicMock(side_effect=workflow_template_rendering.CodeExecutionError("sandbox failed")), + ) + + with pytest.raises(workflow_template_rendering.TemplateRenderError, match="sandbox failed"): + renderer.render_template("{{ broken }}", {}) + + class TestDifyNodeFactoryInit: def test_init_builds_default_dependencies(self): graph_init_params = SimpleNamespace(run_context={"context": "value"}) @@ -220,8 +252,7 @@ class TestDifyNodeFactoryInit: resolve_dify_context.assert_called_once_with(graph_init_params.run_context) build_dify_model_access.assert_called_once_with(dify_context) - renderer_factory.assert_called_once() - assert renderer_factory.call_args.kwargs["code_executor"] is factory._code_executor + renderer_factory.assert_called_once_with() assert factory.graph_init_params is graph_init_params assert factory.graph_runtime_state is graph_runtime_state assert factory._dify_context is dify_context @@ -278,8 +309,13 @@ class TestDifyNodeFactoryCreateNode: def factory(self): factory = object.__new__(node_factory.DifyNodeFactory) factory.graph_init_params = sentinel.graph_init_params - factory.graph_runtime_state = sentinel.graph_runtime_state - factory._dify_context = SimpleNamespace(tenant_id="tenant-id", app_id="app-id") + factory.graph_runtime_state = SimpleNamespace(variable_pool=MagicMock()) + factory._dify_context = SimpleNamespace( + tenant_id="tenant-id", + app_id="app-id", + user_id="user-id", + invoke_from=InvokeFrom.DEBUGGER, + ) factory._code_executor = sentinel.code_executor factory._code_limits = sentinel.code_limits factory._jinja2_template_renderer = sentinel.jinja2_template_renderer @@ -341,7 +377,7 @@ class TestDifyNodeFactoryCreateNode: assert kwargs["id"] == "node-id" _assert_typed_node_config(kwargs["config"], node_id="node-id", node_type=BuiltinNodeTypes.START, version="9") assert kwargs["graph_init_params"] is sentinel.graph_init_params - assert kwargs["graph_runtime_state"] is sentinel.graph_runtime_state + assert kwargs["graph_runtime_state"] is factory.graph_runtime_state latest_node_class.assert_not_called() def test_falls_back_to_latest_class_when_version_specific_mapping_is_missing(self, monkeypatch, factory): @@ -361,7 +397,7 @@ class TestDifyNodeFactoryCreateNode: assert kwargs["id"] == "node-id" _assert_typed_node_config(kwargs["config"], node_id="node-id", node_type=BuiltinNodeTypes.START, version="9") assert kwargs["graph_init_params"] is sentinel.graph_init_params - assert kwargs["graph_runtime_state"] is sentinel.graph_runtime_state + assert kwargs["graph_runtime_state"] is factory.graph_runtime_state @pytest.mark.parametrize( ("node_type", "constructor_name"), @@ -402,7 +438,7 @@ class TestDifyNodeFactoryCreateNode: assert kwargs["id"] == "node-id" _assert_typed_node_config(kwargs["config"], node_id="node-id", node_type=node_type) assert kwargs["graph_init_params"] is sentinel.graph_init_params - assert kwargs["graph_runtime_state"] is sentinel.graph_runtime_state + assert kwargs["graph_runtime_state"] is factory.graph_runtime_state if constructor_name == "CodeNode": assert kwargs["code_executor"] is sentinel.code_executor @@ -419,7 +455,13 @@ class TestDifyNodeFactoryCreateNode: elif constructor_name == "HumanInputNode": assert kwargs["form_repository"] is form_repository assert kwargs["runtime"] is sentinel.human_input_runtime - form_repository_impl.assert_called_once_with(tenant_id="tenant-id", app_id="app-id") + form_repository_impl.assert_called_once_with( + tenant_id="tenant-id", + app_id="app-id", + workflow_execution_id=None, + invoke_source=InvokeFrom.DEBUGGER, + submission_actor_id="user-id", + ) elif constructor_name == "DocumentExtractorNode": assert kwargs["unstructured_api_config"] is sentinel.unstructured_api_config assert kwargs["http_client"] is sentinel.http_client @@ -445,6 +487,7 @@ class TestDifyNodeFactoryCreateNode: "http_client": sentinel.http_client, "llm_file_saver": sentinel.llm_file_saver, "prompt_message_serializer": sentinel.prompt_message_serializer, + "template_renderer": sentinel.jinja2_template_renderer, }, ), ( @@ -501,7 +544,7 @@ class TestDifyNodeFactoryCreateNode: assert constructor_kwargs["id"] == "node-id" _assert_typed_node_config(constructor_kwargs["config"], node_id="node-id", node_type=node_type) assert constructor_kwargs["graph_init_params"] is sentinel.graph_init_params - assert constructor_kwargs["graph_runtime_state"] is sentinel.graph_runtime_state + assert constructor_kwargs["graph_runtime_state"] is factory.graph_runtime_state assert constructor_kwargs["credentials_provider"] is sentinel.credentials_provider assert constructor_kwargs["model_factory"] is sentinel.model_factory assert constructor_kwargs["model_instance"] is sentinel.model_instance @@ -584,9 +627,7 @@ class TestDifyNodeFactoryMemory: ) assert result is sentinel.memory - factory.graph_runtime_state.variable_pool.get.assert_called_once_with( - ["sys", SystemVariableKey.CONVERSATION_ID] - ) + factory.graph_runtime_state.variable_pool.get.assert_called_once_with(("sys", "conversation_id")) fetch_memory.assert_called_once_with( conversation_id="conversation-id", app_id="app-id", diff --git a/api/tests/unit_tests/core/workflow/test_node_runtime.py b/api/tests/unit_tests/core/workflow/test_node_runtime.py new file mode 100644 index 0000000000..c8d7a77b58 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/test_node_runtime.py @@ -0,0 +1,41 @@ +from types import SimpleNamespace +from unittest.mock import MagicMock + +from core.app.entities.app_invoke_entities import DifyRunContext, InvokeFrom, UserFrom +from core.workflow import node_runtime +from core.workflow.node_runtime import DifyToolFileManager + + +def test_dify_tool_file_manager_resolves_conversation_id_for_tool_files(monkeypatch) -> None: + create_file_by_raw = MagicMock(return_value=SimpleNamespace(id="tool-file-id")) + manager_instance = SimpleNamespace(create_file_by_raw=create_file_by_raw) + monkeypatch.setattr(node_runtime, "ToolFileManager", MagicMock(return_value=manager_instance)) + conversation_id_getter = MagicMock(return_value="conversation-id") + + manager = DifyToolFileManager( + DifyRunContext( + tenant_id="tenant-id", + app_id="app-id", + user_id="user-id", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + ), + conversation_id_getter=conversation_id_getter, + ) + + tool_file = manager.create_file_by_raw( + file_binary=b"file-bytes", + mimetype="image/png", + filename="diagram", + ) + + assert tool_file.id == "tool-file-id" + conversation_id_getter.assert_called_once_with() + create_file_by_raw.assert_called_once_with( + user_id="user-id", + tenant_id="tenant-id", + conversation_id="conversation-id", + file_binary=b"file-bytes", + mimetype="image/png", + filename="diagram", + ) diff --git a/api/tests/unit_tests/core/workflow/test_system_variable.py b/api/tests/unit_tests/core/workflow/test_system_variable.py index 8023a0b594..1084729148 100644 --- a/api/tests/unit_tests/core/workflow/test_system_variable.py +++ b/api/tests/unit_tests/core/workflow/test_system_variable.py @@ -1,251 +1,58 @@ -import json -from typing import Any - -import pytest -from pydantic import ValidationError - +from core.workflow.system_variables import build_system_variables, default_system_variables, system_variables_to_mapping from dify_graph.file.enums import FileTransferMethod, FileType from dify_graph.file.models import File -from dify_graph.system_variable import SystemVariable - -# Test data constants for SystemVariable serialization tests -VALID_BASE_DATA: dict[str, Any] = { - "user_id": "a20f06b1-8703-45ab-937c-860a60072113", - "app_id": "661bed75-458d-49c9-b487-fda0762677b9", - "workflow_id": "d31f2136-b292-4ae0-96d4-1e77894a4f43", -} - -COMPLETE_VALID_DATA: dict[str, Any] = { - **VALID_BASE_DATA, - "query": "test query", - "files": [], - "conversation_id": "91f1eb7d-69f4-4d7b-b82f-4003d51744b9", - "dialogue_count": 5, - "workflow_run_id": "eb4704b5-2274-47f2-bfcd-0452daa82cb5", -} -def create_test_file() -> File: - """Create a test File object for serialization tests.""" - return File( - tenant_id="test-tenant-id", +def test_build_system_variables_normalizes_workflow_execution_id(): + system_variables = build_system_variables( + user_id="user-id", + workflow_execution_id="run-id", + query="hello", + ignored=None, + ) + system_values = system_variables_to_mapping(system_variables) + + assert system_values == { + "user_id": "user-id", + "workflow_run_id": "run-id", + "query": "hello", + "files": [], + } + + +def test_build_system_variables_preserves_explicit_workflow_run_id(): + system_variables = build_system_variables( + workflow_run_id="explicit-run-id", + workflow_execution_id="ignored-run-id", + ) + system_values = system_variables_to_mapping(system_variables) + + assert system_values["workflow_run_id"] == "explicit-run-id" + assert system_values["files"] == [] + + +def test_build_system_variables_preserves_file_values(): + file = File( type=FileType.DOCUMENT, transfer_method=FileTransferMethod.LOCAL_FILE, - related_id="test-file-id", + related_id="file-id", filename="test.txt", extension=".txt", mime_type="text/plain", - size=1024, - storage_key="test-storage-key", + size=1, + storage_key="storage-key", ) + system_variables = build_system_variables(files=[file]) + system_values = system_variables_to_mapping(system_variables) -class TestSystemVariableSerialization: - """Focused tests for SystemVariable serialization/deserialization logic.""" - - def test_basic_deserialization(self): - """Test successful deserialization from JSON structure with all fields correctly mapped.""" - # Test with complete data - system_var = SystemVariable.model_validate(COMPLETE_VALID_DATA) - - # Verify all fields are correctly mapped - assert system_var.user_id == COMPLETE_VALID_DATA["user_id"] - assert system_var.app_id == COMPLETE_VALID_DATA["app_id"] - assert system_var.workflow_id == COMPLETE_VALID_DATA["workflow_id"] - assert system_var.query == COMPLETE_VALID_DATA["query"] - assert system_var.conversation_id == COMPLETE_VALID_DATA["conversation_id"] - assert system_var.dialogue_count == COMPLETE_VALID_DATA["dialogue_count"] - assert system_var.workflow_execution_id == COMPLETE_VALID_DATA["workflow_run_id"] - assert system_var.files == [] - - # Test with minimal data (only required fields) - minimal_var = SystemVariable.model_validate(VALID_BASE_DATA) - assert minimal_var.user_id == VALID_BASE_DATA["user_id"] - assert minimal_var.app_id == VALID_BASE_DATA["app_id"] - assert minimal_var.workflow_id == VALID_BASE_DATA["workflow_id"] - assert minimal_var.query is None - assert minimal_var.conversation_id is None - assert minimal_var.dialogue_count is None - assert minimal_var.workflow_execution_id is None - assert minimal_var.files == [] - - def test_alias_handling(self): - """Test workflow_execution_id vs workflow_run_id alias resolution - core deserialization logic.""" - workflow_id = "eb4704b5-2274-47f2-bfcd-0452daa82cb5" - - # Test workflow_run_id only (preferred alias) - data_run_id = {**VALID_BASE_DATA, "workflow_run_id": workflow_id} - system_var1 = SystemVariable.model_validate(data_run_id) - assert system_var1.workflow_execution_id == workflow_id - - # Test workflow_execution_id only (direct field name) - data_execution_id = {**VALID_BASE_DATA, "workflow_execution_id": workflow_id} - system_var2 = SystemVariable.model_validate(data_execution_id) - assert system_var2.workflow_execution_id == workflow_id - - # Test both present - workflow_run_id should take precedence - data_both = { - **VALID_BASE_DATA, - "workflow_execution_id": "should-be-ignored", - "workflow_run_id": workflow_id, - } - system_var3 = SystemVariable.model_validate(data_both) - assert system_var3.workflow_execution_id == workflow_id - - # Test neither present - should be None - system_var4 = SystemVariable.model_validate(VALID_BASE_DATA) - assert system_var4.workflow_execution_id is None - - def test_serialization_round_trip(self): - """Test that serialize → deserialize produces the same result with alias handling.""" - # Create original SystemVariable - original = SystemVariable.model_validate(COMPLETE_VALID_DATA) - - # Serialize to dict - serialized = original.model_dump(mode="json") - - # Verify alias is used in serialization (workflow_run_id, not workflow_execution_id) - assert "workflow_run_id" in serialized - assert "workflow_execution_id" not in serialized - assert serialized["workflow_run_id"] == COMPLETE_VALID_DATA["workflow_run_id"] - - # Deserialize back - deserialized = SystemVariable.model_validate(serialized) - - # Verify all fields match after round-trip - assert deserialized.user_id == original.user_id - assert deserialized.app_id == original.app_id - assert deserialized.workflow_id == original.workflow_id - assert deserialized.query == original.query - assert deserialized.conversation_id == original.conversation_id - assert deserialized.dialogue_count == original.dialogue_count - assert deserialized.workflow_execution_id == original.workflow_execution_id - assert list(deserialized.files) == list(original.files) - - def test_json_round_trip(self): - """Test JSON serialization/deserialization consistency with proper structure.""" - # Create original SystemVariable - original = SystemVariable.model_validate(COMPLETE_VALID_DATA) - - # Serialize to JSON string - json_str = original.model_dump_json() - - # Parse JSON and verify structure - json_data = json.loads(json_str) - assert "workflow_run_id" in json_data - assert "workflow_execution_id" not in json_data - assert json_data["workflow_run_id"] == COMPLETE_VALID_DATA["workflow_run_id"] - - # Deserialize from JSON data - deserialized = SystemVariable.model_validate(json_data) - - # Verify key fields match after JSON round-trip - assert deserialized.workflow_execution_id == original.workflow_execution_id - assert deserialized.user_id == original.user_id - assert deserialized.app_id == original.app_id - assert deserialized.workflow_id == original.workflow_id - - def test_files_field_deserialization(self): - """Test deserialization with File objects in the files field - SystemVariable specific logic.""" - # Test with empty files list - data_empty = {**VALID_BASE_DATA, "files": []} - system_var_empty = SystemVariable.model_validate(data_empty) - assert system_var_empty.files == [] - - # Test with single File object - test_file = create_test_file() - data_single = {**VALID_BASE_DATA, "files": [test_file]} - system_var_single = SystemVariable.model_validate(data_single) - assert len(system_var_single.files) == 1 - assert system_var_single.files[0].filename == "test.txt" - assert system_var_single.files[0].tenant_id == "test-tenant-id" - - # Test with multiple File objects - file1 = File( - tenant_id="tenant1", - type=FileType.DOCUMENT, - transfer_method=FileTransferMethod.LOCAL_FILE, - related_id="file1", - filename="doc1.txt", - storage_key="key1", - ) - file2 = File( - tenant_id="tenant2", - type=FileType.IMAGE, - transfer_method=FileTransferMethod.REMOTE_URL, - remote_url="https://example.com/image.jpg", - filename="image.jpg", - storage_key="key2", - ) - - data_multiple = {**VALID_BASE_DATA, "files": [file1, file2]} - system_var_multiple = SystemVariable.model_validate(data_multiple) - assert len(system_var_multiple.files) == 2 - assert system_var_multiple.files[0].filename == "doc1.txt" - assert system_var_multiple.files[1].filename == "image.jpg" - - # Verify files field serialization/deserialization - serialized = system_var_multiple.model_dump(mode="json") - deserialized = SystemVariable.model_validate(serialized) - assert len(deserialized.files) == 2 - assert deserialized.files[0].filename == "doc1.txt" - assert deserialized.files[1].filename == "image.jpg" - - def test_alias_serialization_consistency(self): - """Test that alias handling works consistently in both serialization directions.""" - workflow_id = "test-workflow-id" - - # Create with workflow_run_id (alias) - data_with_alias = {**VALID_BASE_DATA, "workflow_run_id": workflow_id} - system_var = SystemVariable.model_validate(data_with_alias) - - # Serialize and verify alias is used - serialized = system_var.model_dump() - assert serialized["workflow_run_id"] == workflow_id - assert "workflow_execution_id" not in serialized - - # Deserialize and verify field mapping - deserialized = SystemVariable.model_validate(serialized) - assert deserialized.workflow_execution_id == workflow_id - - # Test JSON serialization path - json_serialized = json.loads(system_var.model_dump_json()) - assert json_serialized["workflow_run_id"] == workflow_id - assert "workflow_execution_id" not in json_serialized - - json_deserialized = SystemVariable.model_validate(json_serialized) - assert json_deserialized.workflow_execution_id == workflow_id - - def test_model_validator_serialization_logic(self): - """Test the custom model validator behavior for serialization scenarios.""" - workflow_id = "test-workflow-execution-id" - - # Test direct instantiation with workflow_execution_id (should work) - data1 = {**VALID_BASE_DATA, "workflow_execution_id": workflow_id} - system_var1 = SystemVariable.model_validate(data1) - assert system_var1.workflow_execution_id == workflow_id - - # Test serialization of the above (should use alias) - serialized1 = system_var1.model_dump() - assert "workflow_run_id" in serialized1 - assert serialized1["workflow_run_id"] == workflow_id - - # Test both present - workflow_run_id takes precedence (validator logic) - data2 = { - **VALID_BASE_DATA, - "workflow_execution_id": "should-be-removed", - "workflow_run_id": workflow_id, - } - system_var2 = SystemVariable.model_validate(data2) - assert system_var2.workflow_execution_id == workflow_id - - # Verify serialization consistency - serialized2 = system_var2.model_dump() - assert serialized2["workflow_run_id"] == workflow_id + assert system_values["files"] == [file] -def test_constructor_with_extra_key(): - # Test that SystemVariable should forbid extra keys - with pytest.raises(ValidationError): - # This should fail because there is an unexpected key. - SystemVariable(invalid_key=1) +def test_default_system_variables_generates_workflow_run_id(): + system_variables = default_system_variables() + system_values = system_variables_to_mapping(system_variables) + + assert isinstance(system_values["workflow_run_id"], str) + assert system_values["workflow_run_id"] + assert system_values["files"] == [] diff --git a/api/tests/unit_tests/core/workflow/test_system_variable_read_only_view.py b/api/tests/unit_tests/core/workflow/test_system_variable_read_only_view.py deleted file mode 100644 index b7a8f2551d..0000000000 --- a/api/tests/unit_tests/core/workflow/test_system_variable_read_only_view.py +++ /dev/null @@ -1,202 +0,0 @@ -from typing import cast - -import pytest - -from dify_graph.file.models import File, FileTransferMethod, FileType -from dify_graph.system_variable import SystemVariable, SystemVariableReadOnlyView - - -class TestSystemVariableReadOnlyView: - """Test cases for SystemVariableReadOnlyView class.""" - - def test_read_only_property_access(self): - """Test that all properties return correct values from wrapped instance.""" - # Create test data - test_file = File( - id="file-123", - tenant_id="tenant-123", - type=FileType.IMAGE, - transfer_method=FileTransferMethod.LOCAL_FILE, - related_id="related-123", - ) - - datasource_info = {"key": "value", "nested": {"data": 42}} - - # Create SystemVariable with all fields - system_var = SystemVariable( - user_id="user-123", - app_id="app-123", - workflow_id="workflow-123", - files=[test_file], - workflow_execution_id="exec-123", - query="test query", - conversation_id="conv-123", - dialogue_count=5, - document_id="doc-123", - original_document_id="orig-doc-123", - dataset_id="dataset-123", - batch="batch-123", - datasource_type="type-123", - datasource_info=datasource_info, - invoke_from="invoke-123", - ) - - # Create read-only view - read_only_view = SystemVariableReadOnlyView(system_var) - - # Test all properties - assert read_only_view.user_id == "user-123" - assert read_only_view.app_id == "app-123" - assert read_only_view.workflow_id == "workflow-123" - assert read_only_view.workflow_execution_id == "exec-123" - assert read_only_view.query == "test query" - assert read_only_view.conversation_id == "conv-123" - assert read_only_view.dialogue_count == 5 - assert read_only_view.document_id == "doc-123" - assert read_only_view.original_document_id == "orig-doc-123" - assert read_only_view.dataset_id == "dataset-123" - assert read_only_view.batch == "batch-123" - assert read_only_view.datasource_type == "type-123" - assert read_only_view.invoke_from == "invoke-123" - - def test_defensive_copying_of_mutable_objects(self): - """Test that mutable objects are defensively copied.""" - # Create test data - test_file = File( - id="file-123", - tenant_id="tenant-123", - type=FileType.IMAGE, - transfer_method=FileTransferMethod.LOCAL_FILE, - related_id="related-123", - ) - - datasource_info = {"key": "original_value"} - - # Create SystemVariable - system_var = SystemVariable( - files=[test_file], datasource_info=datasource_info, workflow_execution_id="exec-123" - ) - - # Create read-only view - read_only_view = SystemVariableReadOnlyView(system_var) - - # Test files defensive copying - files_copy = read_only_view.files - assert isinstance(files_copy, tuple) # Should be immutable tuple - assert len(files_copy) == 1 - assert files_copy[0].id == "file-123" - - # Verify it's a copy (can't modify original through view) - assert isinstance(files_copy, tuple) - # tuples don't have append method, so they're immutable - - # Test datasource_info defensive copying - datasource_copy = read_only_view.datasource_info - assert datasource_copy is not None - assert datasource_copy["key"] == "original_value" - - datasource_copy = cast(dict, datasource_copy) - with pytest.raises(TypeError): - datasource_copy["key"] = "modified value" - - # Verify original is unchanged - assert system_var.datasource_info is not None - assert system_var.datasource_info["key"] == "original_value" - assert read_only_view.datasource_info is not None - assert read_only_view.datasource_info["key"] == "original_value" - - def test_always_accesses_latest_data(self): - """Test that properties always return the latest data from wrapped instance.""" - # Create SystemVariable - system_var = SystemVariable(user_id="original-user", workflow_execution_id="exec-123") - - # Create read-only view - read_only_view = SystemVariableReadOnlyView(system_var) - - # Verify initial value - assert read_only_view.user_id == "original-user" - - # Modify the wrapped instance - system_var.user_id = "modified-user" - - # Verify view returns the new value - assert read_only_view.user_id == "modified-user" - - def test_repr_method(self): - """Test the __repr__ method.""" - # Create SystemVariable - system_var = SystemVariable(workflow_execution_id="exec-123") - - # Create read-only view - read_only_view = SystemVariableReadOnlyView(system_var) - - # Test repr - repr_str = repr(read_only_view) - assert "SystemVariableReadOnlyView" in repr_str - assert "system_variable=" in repr_str - - def test_none_value_handling(self): - """Test that None values are properly handled.""" - # Create SystemVariable with all None values except workflow_execution_id - system_var = SystemVariable( - user_id=None, - app_id=None, - workflow_id=None, - workflow_execution_id="exec-123", - query=None, - conversation_id=None, - dialogue_count=None, - document_id=None, - original_document_id=None, - dataset_id=None, - batch=None, - datasource_type=None, - datasource_info=None, - invoke_from=None, - ) - - # Create read-only view - read_only_view = SystemVariableReadOnlyView(system_var) - - # Test all None values - assert read_only_view.user_id is None - assert read_only_view.app_id is None - assert read_only_view.workflow_id is None - assert read_only_view.query is None - assert read_only_view.conversation_id is None - assert read_only_view.dialogue_count is None - assert read_only_view.document_id is None - assert read_only_view.original_document_id is None - assert read_only_view.dataset_id is None - assert read_only_view.batch is None - assert read_only_view.datasource_type is None - assert read_only_view.datasource_info is None - assert read_only_view.invoke_from is None - - # files should be empty tuple even when default list is empty - assert read_only_view.files == () - - def test_empty_files_handling(self): - """Test that empty files list is handled correctly.""" - # Create SystemVariable with empty files - system_var = SystemVariable(files=[], workflow_execution_id="exec-123") - - # Create read-only view - read_only_view = SystemVariableReadOnlyView(system_var) - - # Test files handling - assert read_only_view.files == () - assert isinstance(read_only_view.files, tuple) - - def test_empty_datasource_info_handling(self): - """Test that empty datasource_info is handled correctly.""" - # Create SystemVariable with empty datasource_info - system_var = SystemVariable(datasource_info={}, workflow_execution_id="exec-123") - - # Create read-only view - read_only_view = SystemVariableReadOnlyView(system_var) - - # Test datasource_info handling - assert read_only_view.datasource_info == {} - # Should be a copy, not the same object - assert read_only_view.datasource_info is not system_var.datasource_info diff --git a/api/tests/unit_tests/core/workflow/test_variable_pool.py b/api/tests/unit_tests/core/workflow/test_variable_pool.py index 0fa0d26114..a219068168 100644 --- a/api/tests/unit_tests/core/workflow/test_variable_pool.py +++ b/api/tests/unit_tests/core/workflow/test_variable_pool.py @@ -3,10 +3,15 @@ from collections import defaultdict import pytest -from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID +from core.workflow.system_variables import build_system_variables, system_variables_to_mapping +from core.workflow.variable_pool_initializer import add_variables_to_pool +from core.workflow.variable_prefixes import ( + CONVERSATION_VARIABLE_NODE_ID, + ENVIRONMENT_VARIABLE_NODE_ID, + SYSTEM_VARIABLE_NODE_ID, +) from dify_graph.file import File, FileTransferMethod, FileType from dify_graph.runtime import VariablePool -from dify_graph.system_variable import SystemVariable from dify_graph.variables import FileSegment, StringSegment from dify_graph.variables.segments import ( ArrayAnySegment, @@ -34,16 +39,21 @@ from factories.variable_factory import build_segment, segment_to_variable @pytest.fixture def pool(): - return VariablePool( - system_variables=SystemVariable(user_id="test_user_id", app_id="test_app_id", workflow_id="test_workflow_id"), - user_inputs={}, + variable_pool = VariablePool() + add_variables_to_pool( + variable_pool, + build_system_variables( + user_id="test_user_id", + app_id="test_app_id", + workflow_id="test_workflow_id", + ), ) + return variable_pool @pytest.fixture def file(): return File( - tenant_id="test_tenant_id", type=FileType.DOCUMENT, transfer_method=FileTransferMethod.LOCAL_FILE, related_id="test_related_id", @@ -70,46 +80,35 @@ def test_get_file_attribute(pool, file): class TestVariablePool: def test_constructor(self): - # Test with minimal required SystemVariable - minimal_system_vars = SystemVariable( + pool = VariablePool() + assert pool.variable_dictionary == defaultdict(dict) + + complex_system_vars = build_system_variables( user_id="test_user_id", app_id="test_app_id", workflow_id="test_workflow_id" ) - pool = VariablePool(system_variables=minimal_system_vars) - - # Test with all parameters - pool = VariablePool( - variable_dictionary={}, - user_inputs={}, - system_variables=minimal_system_vars, - environment_variables=[], - conversation_variables=[], - ) - - # Test with more complex SystemVariable - complex_system_vars = SystemVariable( - user_id="test_user_id", app_id="test_app_id", workflow_id="test_workflow_id" - ) - pool = VariablePool( - user_inputs={"key": "value"}, - system_variables=complex_system_vars, - environment_variables=[ + add_variables_to_pool(pool, complex_system_vars) + add_variables_to_pool( + pool, + [ segment_to_variable( segment=build_segment(1), selector=[ENVIRONMENT_VARIABLE_NODE_ID, "env_var_1"], name="env_var_1", - ) - ], - conversation_variables=[ + ), segment_to_variable( segment=build_segment("1"), selector=[CONVERSATION_VARIABLE_NODE_ID, "conv_var_1"], name="conv_var_1", - ) + ), ], ) + assert pool.get([SYSTEM_VARIABLE_NODE_ID, "user_id"]) is not None + assert pool.get([ENVIRONMENT_VARIABLE_NODE_ID, "env_var_1"]) is not None + assert pool.get([CONVERSATION_VARIABLE_NODE_ID, "conv_var_1"]) is not None + def test_get_system_variables(self): - sys_var = SystemVariable( + sys_var = build_system_variables( user_id="test_user_id", app_id="test_app_id", workflow_id="test_workflow_id", @@ -118,16 +117,18 @@ class TestVariablePool: conversation_id="test_conv_id", dialogue_count=5, ) - pool = VariablePool(system_variables=sys_var) + pool = VariablePool() + add_variables_to_pool(pool, sys_var) + system_values = system_variables_to_mapping(sys_var) kv = [ - ("user_id", sys_var.user_id), - ("app_id", sys_var.app_id), - ("workflow_id", sys_var.workflow_id), - ("workflow_run_id", sys_var.workflow_execution_id), - ("query", sys_var.query), - ("conversation_id", sys_var.conversation_id), - ("dialogue_count", sys_var.dialogue_count), + ("user_id", system_values["user_id"]), + ("app_id", system_values["app_id"]), + ("workflow_id", system_values["workflow_id"]), + ("workflow_run_id", system_values["workflow_run_id"]), + ("query", system_values["query"]), + ("conversation_id", system_values["conversation_id"]), + ("dialogue_count", system_values["dialogue_count"]), ] for key, expected_value in kv: segment = pool.get([SYSTEM_VARIABLE_NODE_ID, key]) @@ -149,7 +150,7 @@ class TestVariablePoolSerialization: def _create_pool_without_file(self): # Create comprehensive system variables - system_vars = SystemVariable( + system_vars = build_system_variables( user_id="test_user_id", app_id="test_app_id", workflow_id="test_workflow_id", @@ -236,17 +237,14 @@ class TestVariablePoolSerialization: } # Create VariablePool - pool = VariablePool( - system_variables=system_vars, - user_inputs=user_inputs, - environment_variables=env_vars, - conversation_variables=conv_vars, - ) + pool = VariablePool() + add_variables_to_pool(pool, system_vars) + add_variables_to_pool(pool, env_vars) + add_variables_to_pool(pool, conv_vars) return pool def _add_node_data_to_pool(self, pool: VariablePool, with_file=False): test_file = File( - tenant_id="test_tenant_id", type=FileType.DOCUMENT, transfer_method=FileTransferMethod.LOCAL_FILE, related_id="test_related_id", @@ -273,7 +271,7 @@ class TestVariablePoolSerialization: pool.add((self._NODE2_ID, "array_any"), ArrayAnySegment(value=["mixed", 123, {"key": "value"}])) def test_system_variables(self): - sys_vars = SystemVariable( + sys_vars = build_system_variables( user_id="test_user_id", app_id="test_app_id", workflow_id="test_workflow_id", @@ -282,24 +280,21 @@ class TestVariablePoolSerialization: conversation_id="test_conv_id", dialogue_count=5, ) - pool = VariablePool(system_variables=sys_vars) + pool = VariablePool() + add_variables_to_pool(pool, sys_vars) json = pool.model_dump_json() pool2 = VariablePool.model_validate_json(json) - assert pool2.system_variables == sys_vars + assert pool2.variable_dictionary == pool.variable_dictionary for mode in ["json", "python"]: dict_ = pool.model_dump(mode=mode) pool2 = VariablePool.model_validate(dict_) - assert pool2.system_variables == sys_vars + assert pool2.variable_dictionary == pool.variable_dictionary def test_pool_without_file_vars(self): pool = self._create_pool_without_file() json = pool.model_dump_json() pool2 = pool.model_validate_json(json) - assert pool2.system_variables == pool.system_variables - assert pool2.conversation_variables == pool.conversation_variables - assert pool2.environment_variables == pool.environment_variables - assert pool2.user_inputs == pool.user_inputs assert pool2.variable_dictionary == pool.variable_dictionary assert pool2 == pool @@ -314,10 +309,6 @@ class TestVariablePoolSerialization: # Verify serialized data structure assert isinstance(serialized_data, dict) - assert "system_variables" in serialized_data - assert "user_inputs" in serialized_data - assert "environment_variables" in serialized_data - assert "conversation_variables" in serialized_data assert "variable_dictionary" in serialized_data # Deserialize back using Pydantic's model_validate() @@ -365,17 +356,7 @@ class TestVariablePoolSerialization: def _assert_pools_equal(self, pool1: VariablePool, pool2: VariablePool): """Assert that two VariablePools contain equivalent data""" - # Compare system variables - assert pool1.system_variables == pool2.system_variables - - # Compare user inputs - assert dict(pool1.user_inputs) == dict(pool2.user_inputs) - - # Compare environment variables count - assert pool1.environment_variables == pool2.environment_variables - - # Compare conversation variables count - assert pool1.conversation_variables == pool2.conversation_variables + assert pool1.variable_dictionary == pool2.variable_dictionary # Test key variable retrievals to ensure functionality is preserved test_selectors = [ @@ -406,13 +387,18 @@ class TestVariablePoolSerialization: def test_variable_pool_deserialization_default_dict(self): variable_pool = VariablePool( - user_inputs={"a": 1, "b": "2"}, - system_variables=SystemVariable(workflow_id=str(uuid.uuid4())), - environment_variables=[ - StringVariable(name="str_var", value="a"), - ], - conversation_variables=[IntegerVariable(name="int_var", value=1)], + variable_dictionary=defaultdict(dict), ) + add_variables_to_pool(variable_pool, build_system_variables(workflow_id=str(uuid.uuid4()))) + add_variables_to_pool( + variable_pool, + [ + StringVariable(name="str_var", value="a", selector=[ENVIRONMENT_VARIABLE_NODE_ID, "str_var"]), + IntegerVariable(name="int_var", value=1, selector=[CONVERSATION_VARIABLE_NODE_ID, "int_var"]), + ], + ) + variable_pool.add(["start", "a"], 1) + variable_pool.add(["start", "b"], "2") assert isinstance(variable_pool.variable_dictionary, defaultdict) json = variable_pool.model_dump_json() loaded = VariablePool.model_validate_json(json) diff --git a/api/tests/unit_tests/core/workflow/test_workflow_entry.py b/api/tests/unit_tests/core/workflow/test_workflow_entry.py index 93ba7f3333..a04443e266 100644 --- a/api/tests/unit_tests/core/workflow/test_workflow_entry.py +++ b/api/tests/unit_tests/core/workflow/test_workflow_entry.py @@ -4,18 +4,18 @@ import pytest from configs import dify_config from core.helper.code_executor.code_executor import CodeLanguage -from core.workflow.workflow_entry import WorkflowEntry -from dify_graph.constants import ( +from core.workflow.system_variables import build_system_variables, default_system_variables +from core.workflow.variable_prefixes import ( CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, ) +from core.workflow.workflow_entry import WorkflowEntry from dify_graph.entities.graph_config import NodeConfigDictAdapter from dify_graph.file.enums import FileType from dify_graph.file.models import File, FileTransferMethod from dify_graph.nodes.code.code_node import CodeNode from dify_graph.nodes.code.limits import CodeNodeLimits from dify_graph.runtime import VariablePool -from dify_graph.system_variable import SystemVariable from dify_graph.variables.variables import StringVariable @@ -56,7 +56,7 @@ class TestWorkflowEntry: """Test mapping system variables from user inputs to variable pool.""" # Initialize variable pool with system variables variable_pool = VariablePool( - system_variables=SystemVariable( + system_variables=build_system_variables( user_id="test_user_id", app_id="test_app_id", workflow_id="test_workflow_id", @@ -128,7 +128,7 @@ class TestWorkflowEntry: return NodeConfigDictAdapter.validate_python(node_config) workflow = StubWorkflow() - variable_pool = VariablePool(system_variables=SystemVariable.default(), user_inputs={}) + variable_pool = VariablePool(system_variables=default_system_variables(), user_inputs={}) expected_limits = CodeNodeLimits( max_string_length=dify_config.CODE_MAX_STRING_LENGTH, max_number=dify_config.CODE_MAX_NUMBER, @@ -158,7 +158,7 @@ class TestWorkflowEntry: # Initialize variable pool with environment variables env_var = StringVariable(name="API_KEY", value="existing_key") variable_pool = VariablePool( - system_variables=SystemVariable.default(), + system_variables=default_system_variables(), environment_variables=[env_var], user_inputs={}, ) @@ -199,7 +199,7 @@ class TestWorkflowEntry: # Initialize variable pool with conversation variables conv_var = StringVariable(name="last_message", value="Hello") variable_pool = VariablePool( - system_variables=SystemVariable.default(), + system_variables=default_system_variables(), conversation_variables=[conv_var], user_inputs={}, ) @@ -240,7 +240,7 @@ class TestWorkflowEntry: """Test mapping regular node variables from user inputs to variable pool.""" # Initialize empty variable pool variable_pool = VariablePool( - system_variables=SystemVariable.default(), + system_variables=default_system_variables(), user_inputs={}, ) @@ -282,7 +282,7 @@ class TestWorkflowEntry: def test_mapping_user_inputs_with_file_handling(self): """Test mapping file inputs from user inputs to variable pool.""" variable_pool = VariablePool( - system_variables=SystemVariable.default(), + system_variables=default_system_variables(), user_inputs={}, ) @@ -341,7 +341,7 @@ class TestWorkflowEntry: def test_mapping_user_inputs_missing_variable_error(self): """Test that mapping raises error when required variable is missing.""" variable_pool = VariablePool( - system_variables=SystemVariable.default(), + system_variables=default_system_variables(), user_inputs={}, ) @@ -367,7 +367,7 @@ class TestWorkflowEntry: def test_mapping_user_inputs_with_alternative_key_format(self): """Test mapping with alternative key format (without node prefix).""" variable_pool = VariablePool( - system_variables=SystemVariable.default(), + system_variables=default_system_variables(), user_inputs={}, ) @@ -397,7 +397,7 @@ class TestWorkflowEntry: def test_mapping_user_inputs_with_complex_selectors(self): """Test mapping with complex node variable keys.""" variable_pool = VariablePool( - system_variables=SystemVariable.default(), + system_variables=default_system_variables(), user_inputs={}, ) @@ -433,7 +433,7 @@ class TestWorkflowEntry: def test_mapping_user_inputs_invalid_node_variable(self): """Test that mapping handles invalid node variable format.""" variable_pool = VariablePool( - system_variables=SystemVariable.default(), + system_variables=default_system_variables(), user_inputs={}, ) @@ -464,7 +464,7 @@ class TestWorkflowEntry: conv_var = StringVariable(name="session_id", value="session123") variable_pool = VariablePool( - system_variables=SystemVariable( + system_variables=build_system_variables( user_id="test_user", app_id="test_app", query="initial query", diff --git a/api/tests/unit_tests/core/workflow/test_workflow_entry_helpers.py b/api/tests/unit_tests/core/workflow/test_workflow_entry_helpers.py index dc4c7a00c5..4b3f604574 100644 --- a/api/tests/unit_tests/core/workflow/test_workflow_entry_helpers.py +++ b/api/tests/unit_tests/core/workflow/test_workflow_entry_helpers.py @@ -7,6 +7,7 @@ import pytest from core.app.apps.exc import GenerateTaskStoppedError from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.workflow import workflow_entry +from core.workflow.variable_prefixes import CHILD_SYNC_VARIABLE_NODE_IDS from dify_graph.entities.graph_config import NodeConfigDictAdapter from dify_graph.enums import NodeType from dify_graph.errors import WorkflowNodeRunFailedError @@ -224,6 +225,7 @@ class TestWorkflowEntrySingleStepRun: patch.object(workflow_entry, "build_dify_run_context", return_value={"_dify": "context"}), patch.object(workflow_entry.time, "perf_counter", return_value=123.0), patch.object(workflow_entry, "DifyNodeFactory") as dify_node_factory, + patch.object(workflow_entry, "add_node_inputs_to_pool") as add_node_inputs_to_pool, patch.object(workflow_entry, "load_into_variable_pool") as load_into_variable_pool, patch.object( workflow_entry.WorkflowEntry, @@ -260,6 +262,11 @@ class TestWorkflowEntrySingleStepRun: variable_mapping={}, user_inputs={"question": "hello"}, ) + add_node_inputs_to_pool.assert_called_once_with( + sentinel.variable_pool, + node_id="node-id", + inputs={"question": "hello"}, + ) mapping_user_inputs_to_variable_pool.assert_called_once_with( variable_mapping={}, user_inputs={"question": "hello"}, @@ -286,6 +293,7 @@ class TestWorkflowEntrySingleStepRun: patch.object(workflow_entry, "build_dify_run_context", return_value={"_dify": "context"}), patch.object(workflow_entry.time, "perf_counter", return_value=123.0), patch.object(workflow_entry, "DifyNodeFactory") as dify_node_factory, + patch.object(workflow_entry, "add_node_inputs_to_pool") as add_node_inputs_to_pool, patch.object(workflow_entry, "load_into_variable_pool") as load_into_variable_pool, patch.object( workflow_entry.WorkflowEntry, @@ -317,6 +325,11 @@ class TestWorkflowEntrySingleStepRun: assert node.id == "node-id" assert list(generator) == ["event"] load_into_variable_pool.assert_called_once() + add_node_inputs_to_pool.assert_called_once_with( + sentinel.variable_pool, + node_id="node-id", + inputs={"question": "hello"}, + ) mapping_user_inputs_to_variable_pool.assert_not_called() def test_wraps_traced_node_run_failures(self): @@ -339,6 +352,7 @@ class TestWorkflowEntrySingleStepRun: patch.object(workflow_entry, "build_dify_run_context", return_value={"_dify": "context"}), patch.object(workflow_entry.time, "perf_counter", return_value=123.0), patch.object(workflow_entry, "DifyNodeFactory") as dify_node_factory, + patch.object(workflow_entry, "add_node_inputs_to_pool"), patch.object(workflow_entry, "load_into_variable_pool"), patch.object(workflow_entry.WorkflowEntry, "mapping_user_inputs_to_variable_pool"), patch.object( @@ -438,8 +452,9 @@ class TestWorkflowEntryHelpers: ) with ( - patch.object(workflow_entry.SystemVariable, "default", return_value=sentinel.system_variables), + patch.object(workflow_entry, "default_system_variables", return_value=sentinel.system_variables), patch.object(workflow_entry, "VariablePool", return_value=sentinel.variable_pool) as variable_pool_cls, + patch.object(workflow_entry, "add_variables_to_pool") as add_variables_to_pool, patch.object( workflow_entry, "GraphInitParams", return_value=sentinel.graph_init_params ) as graph_init_params, @@ -469,11 +484,8 @@ class TestWorkflowEntryHelpers: assert node.id == "node-id" assert list(generator) == ["event"] - variable_pool_cls.assert_called_once_with( - system_variables=sentinel.system_variables, - user_inputs={}, - environment_variables=[], - ) + variable_pool_cls.assert_called_once_with() + add_variables_to_pool.assert_called_once_with(sentinel.variable_pool, sentinel.system_variables) build_dify_run_context.assert_called_once_with( tenant_id="tenant-id", app_id="", @@ -488,6 +500,7 @@ class TestWorkflowEntryHelpers: ), run_context={"_dify": "context"}, call_depth=0, + child_sync_variable_node_ids=CHILD_SYNC_VARIABLE_NODE_IDS, ) dify_node_factory_cls.assert_called_once_with( graph_init_params=sentinel.graph_init_params, @@ -524,8 +537,9 @@ class TestWorkflowEntryHelpers: ) with ( - patch.object(workflow_entry.SystemVariable, "default", return_value=sentinel.system_variables), + patch.object(workflow_entry, "default_system_variables", return_value=sentinel.system_variables), patch.object(workflow_entry, "VariablePool", return_value=sentinel.variable_pool), + patch.object(workflow_entry, "add_variables_to_pool"), patch.object(workflow_entry, "GraphInitParams", return_value=sentinel.graph_init_params), patch.object(workflow_entry, "GraphRuntimeState", return_value=sentinel.graph_runtime_state), patch.object(workflow_entry, "build_dify_run_context", return_value={"_dify": "context"}), @@ -548,7 +562,6 @@ class TestWorkflowEntryHelpers: def test_handle_special_values_serializes_nested_files(self): file = File( - tenant_id="tenant-id", type=FileType.IMAGE, transfer_method=FileTransferMethod.REMOTE_URL, remote_url="https://example.com/image.png", diff --git a/api/tests/unit_tests/dify_graph/model_runtime/utils/test_encoders.py b/api/tests/unit_tests/dify_graph/model_runtime/utils/test_encoders.py index 1ce8765a3b..b36296d257 100644 --- a/api/tests/unit_tests/dify_graph/model_runtime/utils/test_encoders.py +++ b/api/tests/unit_tests/dify_graph/model_runtime/utils/test_encoders.py @@ -121,9 +121,9 @@ class TestEncoders: assert jsonable_encoder(Decimal("1.000")) == "1.000" def test_jsonable_encoder_dict(self): - d = {"a": 1, "b": [2, 3], "_sa_instance": "hidden"} - assert jsonable_encoder(d) == {"a": 1, "b": [2, 3]} - assert jsonable_encoder(d, sqlalchemy_safe=False) == {"a": 1, "b": [2, 3], "_sa_instance": "hidden"} + d = {"a": 1, "b": [2, 3], "_private": "hidden"} + assert jsonable_encoder(d) == {"a": 1, "b": [2, 3], "_private": "hidden"} + assert jsonable_encoder(d, excluded_key_prefixes=("_",)) == {"a": 1, "b": [2, 3]} d_with_none = {"a": 1, "b": None} assert jsonable_encoder(d_with_none, exclude_none=True) == {"a": 1} diff --git a/api/tests/unit_tests/factories/test_build_from_mapping.py b/api/tests/unit_tests/factories/test_build_from_mapping.py index 601f2c5e3a..98d77da725 100644 --- a/api/tests/unit_tests/factories/test_build_from_mapping.py +++ b/api/tests/unit_tests/factories/test_build_from_mapping.py @@ -4,6 +4,7 @@ from unittest.mock import MagicMock, patch import pytest from httpx import Response +from core.workflow.file_reference import build_file_reference, resolve_file_record_id from factories.file_factory import ( File, FileTransferMethod, @@ -99,7 +100,37 @@ def test_build_from_mapping_backward_compatibility(mock_upload_file): assert isinstance(file, File) assert file.transfer_method == FileTransferMethod.LOCAL_FILE assert file.type == FileType.IMAGE - assert file.related_id == TEST_UPLOAD_FILE_ID + assert resolve_file_record_id(file.reference) == TEST_UPLOAD_FILE_ID + + +def test_build_from_mapping_accepts_opaque_reference_for_local_file(mock_upload_file): + mapping = { + "transfer_method": "local_file", + "reference": build_file_reference(record_id=TEST_UPLOAD_FILE_ID, storage_key="test_key"), + "type": "image", + } + + file = build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID) + + assert isinstance(file, File) + assert file.transfer_method == FileTransferMethod.LOCAL_FILE + assert file.type == FileType.IMAGE + assert resolve_file_record_id(file.reference) == TEST_UPLOAD_FILE_ID + + +def test_build_from_mapping_accepts_opaque_related_id_for_tool_file(mock_tool_file): + mapping = { + "transfer_method": "tool_file", + "related_id": build_file_reference(record_id=TEST_TOOL_FILE_ID, storage_key="tool_file.pdf"), + "type": "document", + } + + file = build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID) + + assert isinstance(file, File) + assert file.transfer_method == FileTransferMethod.TOOL_FILE + assert file.type == FileType.DOCUMENT + assert resolve_file_record_id(file.reference) == TEST_TOOL_FILE_ID @pytest.mark.parametrize( diff --git a/api/tests/unit_tests/factories/test_variable_factory.py b/api/tests/unit_tests/factories/test_variable_factory.py index ce6b9232ce..f63f523389 100644 --- a/api/tests/unit_tests/factories/test_variable_factory.py +++ b/api/tests/unit_tests/factories/test_variable_factory.py @@ -227,7 +227,6 @@ def test_build_segment_array_file_single_file(): """Test building ArrayFileSegment from list with single file.""" file = File( id="test_file_id", - tenant_id="test_tenant_id", type=FileType.IMAGE, transfer_method=FileTransferMethod.REMOTE_URL, remote_url="https://test.example.com/test-file.png", @@ -247,7 +246,6 @@ def test_build_segment_array_file_multiple_files(): """Test building ArrayFileSegment from list with multiple files.""" file1 = File( id="test_file_id_1", - tenant_id="test_tenant_id", type=FileType.IMAGE, transfer_method=FileTransferMethod.REMOTE_URL, remote_url="https://test.example.com/test-file1.png", @@ -258,7 +256,6 @@ def test_build_segment_array_file_multiple_files(): ) file2 = File( id="test_file_id_2", - tenant_id="test_tenant_id", type=FileType.DOCUMENT, transfer_method=FileTransferMethod.LOCAL_FILE, related_id="test_relation_id", @@ -305,7 +302,6 @@ def test_build_segment_array_any_mixed_with_files(): """Test building ArrayAnySegment from list with files and other types.""" file = File( id="test_file_id", - tenant_id="test_tenant_id", type=FileType.IMAGE, transfer_method=FileTransferMethod.REMOTE_URL, remote_url="https://test.example.com/test-file.png", @@ -334,7 +330,6 @@ def test_build_segment_array_file_properties(): """Test ArrayFileSegment properties and methods.""" file1 = File( id="test_file_id_1", - tenant_id="test_tenant_id", type=FileType.IMAGE, transfer_method=FileTransferMethod.REMOTE_URL, remote_url="https://test.example.com/test-file1.png", @@ -345,7 +340,6 @@ def test_build_segment_array_file_properties(): ) file2 = File( id="test_file_id_2", - tenant_id="test_tenant_id", type=FileType.DOCUMENT, transfer_method=FileTransferMethod.REMOTE_URL, remote_url="https://test.example.com/test-file2.txt", @@ -394,7 +388,6 @@ def test_build_segment_file_array_with_different_file_types(): """Test ArrayFileSegment with different file types.""" image_file = File( id="image_id", - tenant_id="test_tenant_id", type=FileType.IMAGE, transfer_method=FileTransferMethod.REMOTE_URL, remote_url="https://test.example.com/image.png", @@ -406,7 +399,6 @@ def test_build_segment_file_array_with_different_file_types(): video_file = File( id="video_id", - tenant_id="test_tenant_id", type=FileType.VIDEO, transfer_method=FileTransferMethod.LOCAL_FILE, related_id="video_relation_id", @@ -418,7 +410,6 @@ def test_build_segment_file_array_with_different_file_types(): audio_file = File( id="audio_id", - tenant_id="test_tenant_id", type=FileType.AUDIO, transfer_method=FileTransferMethod.LOCAL_FILE, related_id="audio_relation_id", @@ -456,7 +447,6 @@ def _generate_file(draw) -> File: url = "https://test.example.com/test-file" file = File( id="test_file_id", - tenant_id="test_tenant_id", type=file_type, transfer_method=transfer_method, remote_url=url, @@ -471,7 +461,6 @@ def _generate_file(draw) -> File: file = File( id="test_file_id", - tenant_id="test_tenant_id", type=file_type, transfer_method=transfer_method, related_id=str(relation_id), @@ -519,7 +508,6 @@ def test_build_segment_type_for_scalar(): file = File( id="test_file_id", - tenant_id="test_tenant_id", type=FileType.IMAGE, transfer_method=FileTransferMethod.REMOTE_URL, remote_url="https://test.example.com/test-file.png", @@ -576,7 +564,6 @@ class TestBuildSegmentWithType: """Test building a file segment with correct type.""" test_file = File( id="test_file_id", - tenant_id="test_tenant_id", type=FileType.IMAGE, transfer_method=FileTransferMethod.REMOTE_URL, remote_url="https://test.example.com/test-file.png", diff --git a/api/tests/unit_tests/models/test_workflow.py b/api/tests/unit_tests/models/test_workflow.py index ef29b26a7a..ab6b4dc3c3 100644 --- a/api/tests/unit_tests/models/test_workflow.py +++ b/api/tests/unit_tests/models/test_workflow.py @@ -219,7 +219,6 @@ class TestWorkflowDraftVariableGetValue: tenant_id = "test_tenant_id" test_file = File( - tenant_id=tenant_id, type=FileType.IMAGE, transfer_method=FileTransferMethod.REMOTE_URL, remote_url="https://example.com/example.jpg", @@ -292,7 +291,6 @@ class TestWorkflowDraftVariableGetValue: # Create a File with specific field values test_file = File( id="test_file_id", - tenant_id=tenant_id, type=FileType.IMAGE, transfer_method=FileTransferMethod.REMOTE_URL, remote_url="https://example.com/test.jpg", @@ -314,7 +312,6 @@ class TestWorkflowDraftVariableGetValue: # Verify all important fields are preserved assert retrieved_file.id == test_file.id - assert retrieved_file.tenant_id == test_file.tenant_id assert retrieved_file.type == test_file.type assert retrieved_file.transfer_method == test_file.transfer_method assert retrieved_file.remote_url == test_file.remote_url diff --git a/api/tests/unit_tests/repositories/workflow_node_execution/test_sqlalchemy_repository.py b/api/tests/unit_tests/repositories/workflow_node_execution/test_sqlalchemy_repository.py index 086d1ac52e..4c3a0daedf 100644 --- a/api/tests/unit_tests/repositories/workflow_node_execution/test_sqlalchemy_repository.py +++ b/api/tests/unit_tests/repositories/workflow_node_execution/test_sqlalchemy_repository.py @@ -13,6 +13,7 @@ from pytest_mock import MockerFixture from sqlalchemy.orm import Session, sessionmaker from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository +from core.repositories.factory import OrderConfig from dify_graph.entities import ( WorkflowNodeExecution, ) @@ -22,7 +23,6 @@ from dify_graph.enums import ( WorkflowNodeExecutionStatus, ) from dify_graph.model_runtime.utils.encoders import jsonable_encoder -from dify_graph.repositories.workflow_node_execution_repository import OrderConfig from models.account import Account, Tenant from models.workflow import WorkflowNodeExecutionModel, WorkflowNodeExecutionTriggeredFrom @@ -175,8 +175,8 @@ def test_save_with_existing_tenant_id(repository, session): session_obj.commit.assert_called_once() -def test_get_by_workflow_run(repository, session, mocker: MockerFixture): - """Test get_by_workflow_run method.""" +def test_get_by_workflow_execution(repository, session, mocker: MockerFixture): + """Test get_by_workflow_execution method.""" session_obj, _ = session # Set up mock mock_select = mocker.patch("core.repositories.sqlalchemy_workflow_node_execution_repository.select") @@ -206,7 +206,10 @@ def test_get_by_workflow_run(repository, session, mocker: MockerFixture): # Call method order_config = OrderConfig(order_by=["index"], order_direction="desc") - result = repository.get_by_workflow_run(workflow_run_id="test-workflow-run-id", order_config=order_config) + result = repository.get_by_workflow_execution( + workflow_execution_id="test-workflow-run-id", + order_config=order_config, + ) # Assert select was called with correct parameters mock_select.assert_called_once() diff --git a/api/tests/unit_tests/services/test_human_input_delivery_test_service.py b/api/tests/unit_tests/services/test_human_input_delivery_test_service.py index a23c44b26e..82a8deec68 100644 --- a/api/tests/unit_tests/services/test_human_input_delivery_test_service.py +++ b/api/tests/unit_tests/services/test_human_input_delivery_test_service.py @@ -5,7 +5,7 @@ import pytest from sqlalchemy.engine import Engine from configs import dify_config -from dify_graph.nodes.human_input.entities import ( +from core.workflow.human_input_compat import ( EmailDeliveryConfig, EmailDeliveryMethod, EmailRecipients, @@ -36,7 +36,11 @@ def mock_db(monkeypatch): def _make_valid_email_config(): - return EmailDeliveryConfig(recipients=EmailRecipients(whole_workspace=False, items=[]), subject="Subj", body="Body") + return EmailDeliveryConfig( + recipients=EmailRecipients(include_bound_group=False, items=[]), + subject="Subj", + body="Body", + ) def test_build_form_link(): @@ -252,7 +256,9 @@ class TestEmailDeliveryTestHandler: # Test Case 1: External Recipient method = EmailDeliveryMethod( config=EmailDeliveryConfig( - recipients=EmailRecipients(items=[ExternalRecipient(email="ext@example.com")], whole_workspace=False), + recipients=EmailRecipients( + items=[ExternalRecipient(email="ext@example.com")], include_bound_group=False + ), subject="", body="", ) @@ -262,7 +268,7 @@ class TestEmailDeliveryTestHandler: # Test Case 2: Member Recipient method = EmailDeliveryMethod( config=EmailDeliveryConfig( - recipients=EmailRecipients(items=[MemberRecipient(user_id="u1")], whole_workspace=False), + recipients=EmailRecipients(items=[MemberRecipient(reference_id="u1")], include_bound_group=False), subject="", body="", ) @@ -272,7 +278,11 @@ class TestEmailDeliveryTestHandler: # Test Case 3: Whole Workspace method = EmailDeliveryMethod( - config=EmailDeliveryConfig(recipients=EmailRecipients(items=[], whole_workspace=True), subject="", body="") + config=EmailDeliveryConfig( + recipients=EmailRecipients(items=[], include_bound_group=True), + subject="", + body="", + ) ) handler._query_workspace_member_emails = MagicMock( return_value={"u1": "u1@example.com", "u2": "u2@example.com"} diff --git a/api/tests/unit_tests/services/test_variable_truncator.py b/api/tests/unit_tests/services/test_variable_truncator.py index c703ab64d0..2564d0e0dd 100644 --- a/api/tests/unit_tests/services/test_variable_truncator.py +++ b/api/tests/unit_tests/services/test_variable_truncator.py @@ -43,7 +43,6 @@ from services.variable_truncator import ( def file() -> File: return File( id=str(uuid4()), # Generate new UUID for File.id - tenant_id=str(uuid.uuid4()), type=FileType.DOCUMENT, transfer_method=FileTransferMethod.LOCAL_FILE, related_id=str(uuid.uuid4()), diff --git a/api/tests/unit_tests/services/workflow/test_workflow_draft_variable_service.py b/api/tests/unit_tests/services/workflow/test_workflow_draft_variable_service.py index 0c2be9c79f..0bd9478557 100644 --- a/api/tests/unit_tests/services/workflow/test_workflow_draft_variable_service.py +++ b/api/tests/unit_tests/services/workflow/test_workflow_draft_variable_service.py @@ -7,8 +7,9 @@ import pytest from sqlalchemy import Engine from sqlalchemy.orm import Session -from dify_graph.constants import SYSTEM_VARIABLE_NODE_ID -from dify_graph.enums import BuiltinNodeTypes, SystemVariableKey +from core.workflow.system_variables import SystemVariableKey +from core.workflow.variable_prefixes import SYSTEM_VARIABLE_NODE_ID +from dify_graph.enums import BuiltinNodeTypes from dify_graph.variables.segments import StringSegment from dify_graph.variables.types import SegmentType from libs.uuid_utils import uuidv7 diff --git a/api/tests/unit_tests/services/workflow/test_workflow_human_input_delivery.py b/api/tests/unit_tests/services/workflow/test_workflow_human_input_delivery.py index c890ab6a65..294b8af59d 100644 --- a/api/tests/unit_tests/services/workflow/test_workflow_human_input_delivery.py +++ b/api/tests/unit_tests/services/workflow/test_workflow_human_input_delivery.py @@ -5,16 +5,16 @@ from unittest.mock import MagicMock import pytest from sqlalchemy.orm import sessionmaker -from dify_graph.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter -from dify_graph.enums import BuiltinNodeTypes -from dify_graph.nodes.human_input.entities import ( +from core.workflow.human_input_compat import ( EmailDeliveryConfig, EmailDeliveryMethod, EmailRecipients, ExternalRecipient, - HumanInputNodeData, MemberRecipient, ) +from dify_graph.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter +from dify_graph.enums import BuiltinNodeTypes +from dify_graph.nodes.human_input.entities import HumanInputNodeData from services import workflow_service as workflow_service_module from services.workflow_service import WorkflowService @@ -23,7 +23,7 @@ def _make_service() -> WorkflowService: return WorkflowService(session_maker=sessionmaker()) -def _build_node_config(delivery_methods: list[EmailDeliveryMethod]) -> NodeConfigDict: +def _build_node_config(delivery_methods: list[EmailDeliveryMethod], *, legacy: bool = False) -> NodeConfigDict: node_data = HumanInputNodeData( title="Human Input", delivery_methods=delivery_methods, @@ -31,6 +31,14 @@ def _build_node_config(delivery_methods: list[EmailDeliveryMethod]) -> NodeConfi inputs=[], user_actions=[], ).model_dump(mode="json") + if legacy: + for delivery_method in node_data["delivery_methods"]: + recipients = delivery_method.get("config", {}).get("recipients", {}) + if "include_bound_group" in recipients: + recipients["whole_workspace"] = recipients.pop("include_bound_group") + for recipient in recipients.get("items", []): + if "reference_id" in recipient: + recipient["user_id"] = recipient.pop("reference_id") node_data["type"] = BuiltinNodeTypes.HUMAN_INPUT return NodeConfigDictAdapter.validate_python({"id": "node-1", "data": node_data}) @@ -41,7 +49,7 @@ def _make_email_method(enabled: bool = True, debug_mode: bool = False) -> EmailD enabled=enabled, config=EmailDeliveryConfig( recipients=EmailRecipients( - whole_workspace=False, + include_bound_group=False, items=[ExternalRecipient(email="tester@example.com")], ), subject="Test subject", @@ -69,7 +77,7 @@ def test_human_input_delivery_requires_draft_workflow(): def test_human_input_delivery_allows_disabled_method(monkeypatch: pytest.MonkeyPatch): service = _make_service() delivery_method = _make_email_method(enabled=False) - node_config = _build_node_config([delivery_method]) + node_config = _build_node_config([delivery_method], legacy=True) workflow = MagicMock() workflow.get_node_config_by_id.return_value = node_config service.get_draft_workflow = MagicMock(return_value=workflow) # type: ignore[method-assign] @@ -105,7 +113,7 @@ def test_human_input_delivery_allows_disabled_method(monkeypatch: pytest.MonkeyP def test_human_input_delivery_dispatches_to_test_service(monkeypatch: pytest.MonkeyPatch): service = _make_service() delivery_method = _make_email_method(enabled=True) - node_config = _build_node_config([delivery_method]) + node_config = _build_node_config([delivery_method], legacy=True) workflow = MagicMock() workflow.get_node_config_by_id.return_value = node_config service.get_draft_workflow = MagicMock(return_value=workflow) # type: ignore[method-assign] @@ -144,7 +152,7 @@ def test_human_input_delivery_dispatches_to_test_service(monkeypatch: pytest.Mon def test_human_input_delivery_debug_mode_overrides_recipients(monkeypatch: pytest.MonkeyPatch): service = _make_service() delivery_method = _make_email_method(enabled=True, debug_mode=True) - node_config = _build_node_config([delivery_method]) + node_config = _build_node_config([delivery_method], legacy=True) workflow = MagicMock() workflow.get_node_config_by_id.return_value = node_config service.get_draft_workflow = MagicMock(return_value=workflow) # type: ignore[method-assign] @@ -178,8 +186,8 @@ def test_human_input_delivery_debug_mode_overrides_recipients(monkeypatch: pytes sent_method = test_service_instance.send_test.call_args.kwargs["method"] assert isinstance(sent_method, EmailDeliveryMethod) assert sent_method.config.debug_mode is True - assert sent_method.config.recipients.whole_workspace is False + assert sent_method.config.recipients.include_bound_group is False assert len(sent_method.config.recipients.items) == 1 recipient = sent_method.config.recipients.items[0] assert isinstance(recipient, MemberRecipient) - assert recipient.user_id == account.id + assert recipient.reference_id == account.id diff --git a/api/tests/workflow_test_utils.py b/api/tests/workflow_test_utils.py index 1f0bf8ef37..60513d17fb 100644 --- a/api/tests/workflow_test_utils.py +++ b/api/tests/workflow_test_utils.py @@ -2,7 +2,11 @@ from collections.abc import Mapping from typing import Any from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom, build_dify_run_context +from core.workflow.variable_pool_initializer import add_node_inputs_to_pool, add_variables_to_pool +from core.workflow.variable_prefixes import CHILD_SYNC_VARIABLE_NODE_IDS from dify_graph.entities.graph_init_params import GraphInitParams +from dify_graph.runtime import VariablePool +from dify_graph.variables.variables import Variable def build_test_run_context( @@ -50,4 +54,18 @@ def build_test_graph_init_params( extra_context=extra_context, ), call_depth=call_depth, + child_sync_variable_node_ids=CHILD_SYNC_VARIABLE_NODE_IDS, ) + + +def build_test_variable_pool( + *, + variables: list[Variable] | tuple[Variable, ...] = (), + node_id: str | None = None, + inputs: Mapping[str, Any] | None = None, +) -> VariablePool: + variable_pool = VariablePool() + add_variables_to_pool(variable_pool, variables) + if node_id is not None and inputs is not None: + add_node_inputs_to_pool(variable_pool, node_id=node_id, inputs=inputs) + return variable_pool