diff --git a/api/core/datasource/utils/message_transformer.py b/api/core/datasource/utils/message_transformer.py index 455656b396..96b3c2b86b 100644 --- a/api/core/datasource/utils/message_transformer.py +++ b/api/core/datasource/utils/message_transformer.py @@ -104,7 +104,8 @@ class DatasourceFileMessageTransformer: file: File | None = meta.get("file") if isinstance(file, File): if file.transfer_method == FileTransferMethod.TOOL_FILE: - parsed_reference = parse_file_reference(file.reference) + reference = getattr(file, "reference", None) or getattr(file, "related_id", None) + parsed_reference = parse_file_reference(reference) if isinstance(reference, str) else None if parsed_reference is None: raise ValueError("datasource file is missing reference") url = cls.get_datasource_file_url( diff --git a/api/core/model_manager.py b/api/core/model_manager.py index 06b35dffcd..ab47facdab 100644 --- a/api/core/model_manager.py +++ b/api/core/model_manager.py @@ -12,7 +12,7 @@ from core.provider_manager import ProviderManager from dify_graph.model_runtime.callbacks.base_callback import Callback from dify_graph.model_runtime.entities.llm_entities import LLMResult from dify_graph.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool -from dify_graph.model_runtime.entities.model_entities import ModelFeature, ModelType +from dify_graph.model_runtime.entities.model_entities import AIModelEntity, ModelFeature, ModelType from dify_graph.model_runtime.entities.rerank_entities import RerankResult from dify_graph.model_runtime.entities.text_embedding_entities import EmbeddingResult from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeConnectionError, InvokeRateLimitError @@ -50,6 +50,13 @@ class ModelInstance: credentials=self.credentials, ) + def get_model_schema(self) -> AIModelEntity: + """Return the resolved schema for the current model instance.""" + model_schema = self.model_type_instance.get_model_schema(self.model_name, self.credentials) + if model_schema is None: + raise ValueError(f"model schema not found for {self.model_name}") + return model_schema + @staticmethod def _fetch_credentials_from_bundle(provider_model_bundle: ProviderModelBundle, model: str): """ diff --git a/api/core/workflow/human_input_compat.py b/api/core/workflow/human_input_compat.py index 112304e0ad..41b24635d9 100644 --- a/api/core/workflow/human_input_compat.py +++ b/api/core/workflow/human_input_compat.py @@ -12,7 +12,10 @@ import uuid from collections.abc import Mapping, Sequence from typing import Annotated, Any, ClassVar, Literal -from pydantic import BaseModel, ConfigDict, Field, TypeAdapter +import bleach +import markdown +from markdown.extensions.tables import TableExtension +from pydantic import AliasChoices, BaseModel, ConfigDict, Field, TypeAdapter from dify_graph.enums import BuiltinNodeTypes from dify_graph.nodes.base.variable_template_parser import VariableTemplateParser @@ -56,12 +59,39 @@ EmailRecipient = Annotated[BoundRecipient | ExternalRecipient, Field(discriminat class EmailRecipients(BaseModel): model_config = ConfigDict(extra="forbid") - include_bound_group: bool = False + include_bound_group: bool = Field( + default=False, + validation_alias=AliasChoices("include_bound_group", "whole_workspace"), + ) items: list[EmailRecipient] = Field(default_factory=list) class EmailDeliveryConfig(BaseModel): URL_PLACEHOLDER: ClassVar[str] = "{{#url#}}" + _ALLOWED_HTML_TAGS: ClassVar[list[str]] = [ + "a", + "br", + "code", + "em", + "li", + "ol", + "p", + "pre", + "strong", + "table", + "tbody", + "td", + "th", + "thead", + "tr", + "ul", + ] + _ALLOWED_HTML_ATTRIBUTES: ClassVar[dict[str, list[str]]] = { + "a": ["href", "title"], + "td": ["align"], + "th": ["align"], + } + _ALLOWED_PROTOCOLS: ClassVar[set[str]] = set(bleach.sanitizer.ALLOWED_PROTOCOLS) | {"mailto"} recipients: EmailRecipients subject: str @@ -88,6 +118,28 @@ class EmailDeliveryConfig(BaseModel): return templated_body return variable_pool.convert_template(templated_body).text + @classmethod + def render_markdown_body(cls, body: str) -> str: + stripped_body = bleach.clean(body, tags=[], attributes={}, strip=True) + rendered = markdown.markdown( + stripped_body, + extensions=[TableExtension(use_align_attribute=True)], + output_format="html", + ) + return bleach.clean( + rendered, + tags=cls._ALLOWED_HTML_TAGS, + attributes=cls._ALLOWED_HTML_ATTRIBUTES, + protocols=cls._ALLOWED_PROTOCOLS, + strip=True, + ) + + @staticmethod + def sanitize_subject(subject: str) -> str: + sanitized = subject.replace("\r", " ").replace("\n", " ") + sanitized = bleach.clean(sanitized, tags=[], strip=True) + return " ".join(sanitized.split()) + class _DeliveryMethodBase(BaseModel): enabled: bool = True diff --git a/api/core/workflow/node_factory.py b/api/core/workflow/node_factory.py index 608fa157ed..db94ea0693 100644 --- a/api/core/workflow/node_factory.py +++ b/api/core/workflow/node_factory.py @@ -312,12 +312,6 @@ class DifyNodeFactory(NodeFactory): return raw_ctx return DifyRunContext.model_validate(raw_ctx) - def _dify_invoke_source(self) -> str: - invoke_from = self._dify_context.invoke_from - if isinstance(invoke_from, str): - return invoke_from - return str(invoke_from.value) - def _conversation_id(self) -> str | None: return get_system_text(self.graph_runtime_state.variable_pool, SystemVariableKey.CONVERSATION_ID) @@ -355,6 +349,7 @@ class DifyNodeFactory(NodeFactory): }, BuiltinNodeTypes.HUMAN_INPUT: lambda: { "runtime": self._human_input_runtime, + "form_repository": self._human_input_runtime.build_form_repository(), }, BuiltinNodeTypes.LLM: lambda: self._build_llm_compatible_node_init_kwargs( node_class=node_class, diff --git a/api/core/workflow/node_runtime.py b/api/core/workflow/node_runtime.py index 94a431a9a3..9828da6924 100644 --- a/api/core/workflow/node_runtime.py +++ b/api/core/workflow/node_runtime.py @@ -15,7 +15,11 @@ from core.model_manager import ModelInstance from core.plugin.impl.exc import PluginDaemonClientSideError, PluginInvokeError from core.plugin.impl.plugin import PluginInstaller from core.prompt.utils.prompt_message_util import PromptMessageUtil -from core.repositories.human_input_repository import FormCreateParams, HumanInputFormRepositoryImpl +from core.repositories.human_input_repository import ( + FormCreateParams, + HumanInputFormRepository, + HumanInputFormRepositoryImpl, +) from core.tools.entities.tool_entities import ToolProviderType as CoreToolProviderType from core.tools.errors import ToolInvokeError from core.tools.tool_engine import ToolEngine @@ -43,6 +47,7 @@ from dify_graph.nodes.llm.runtime_protocols import ( ) from dify_graph.nodes.protocols import FileReferenceFactoryProtocol, HttpClientProtocol, ToolFileManagerProtocol from dify_graph.nodes.runtime import ( + HumanInputFormStateProtocol, HumanInputNodeRuntimeProtocol, ToolNodeRuntimeProtocol, ) @@ -360,7 +365,7 @@ class DifyToolNodeRuntime(ToolNodeRuntimeProtocol): provider_name: str, ) -> Generator[ToolRuntimeMessage, None, None]: runtime_binding = self._binding_from_handle(tool_runtime) - tool = cast("Tool", runtime_binding.tool) + tool = runtime_binding.tool callback = DifyWorkflowCallbackHandler() try: @@ -432,7 +437,7 @@ class DifyToolNodeRuntime(ToolNodeRuntimeProtocol): @staticmethod def _tool_from_handle(tool_runtime: ToolRuntimeHandle) -> Tool: - return cast("Tool", DifyToolNodeRuntime._binding_from_handle(tool_runtime).tool) + return DifyToolNodeRuntime._binding_from_handle(tool_runtime).tool @staticmethod def _binding_from_handle(tool_runtime: ToolRuntimeHandle) -> _WorkflowToolRuntimeBinding: @@ -561,9 +566,11 @@ class DifyHumanInputNodeRuntime(HumanInputNodeRuntimeProtocol): run_context: Mapping[str, Any] | DifyRunContext, *, workflow_execution_id_getter: Callable[[], str | None] | None = None, + form_repository: HumanInputFormRepository | None = None, ) -> None: self._run_context = resolve_dify_run_context(run_context) self._workflow_execution_id_getter = workflow_execution_id_getter + self._form_repository = form_repository def _invoke_source(self) -> str: invoke_from = self._run_context.invoke_from @@ -590,7 +597,10 @@ class DifyHumanInputNodeRuntime(HumanInputNodeRuntimeProtocol): return True return is_human_input_webapp_enabled(node_data) - def _build_form_repository(self) -> HumanInputFormRepositoryImpl: + def build_form_repository(self) -> HumanInputFormRepository: + if self._form_repository is not None: + return self._form_repository + invoke_source = self._invoke_source() return HumanInputFormRepositoryImpl( tenant_id=self._run_context.tenant_id, @@ -600,8 +610,15 @@ class DifyHumanInputNodeRuntime(HumanInputNodeRuntimeProtocol): submission_actor_id=self._run_context.user_id if invoke_source in {"debugger", "explore"} else None, ) - def get_form(self, *, node_id: str): - repo = self._build_form_repository() + def with_form_repository(self, form_repository: HumanInputFormRepository) -> DifyHumanInputNodeRuntime: + return DifyHumanInputNodeRuntime( + self._run_context, + workflow_execution_id_getter=self._workflow_execution_id_getter, + form_repository=form_repository, + ) + + def get_form(self, *, node_id: str) -> HumanInputFormStateProtocol | None: + repo = self.build_form_repository() return repo.get_form(node_id) def create_form( @@ -611,8 +628,8 @@ class DifyHumanInputNodeRuntime(HumanInputNodeRuntimeProtocol): node_data: HumanInputNodeData, rendered_content: str, resolved_default_values: Mapping[str, Any], - ): - repo = self._build_form_repository() + ) -> HumanInputFormStateProtocol: + repo = self.build_form_repository() params = FormCreateParams( workflow_execution_id=self._workflow_execution_id_getter() if self._workflow_execution_id_getter else None, node_id=node_id, diff --git a/api/core/workflow/workflow_entry.py b/api/core/workflow/workflow_entry.py index 60b9ea92f1..01f77db4f0 100644 --- a/api/core/workflow/workflow_entry.py +++ b/api/core/workflow/workflow_entry.py @@ -40,7 +40,7 @@ logger = logging.getLogger(__name__) class _WorkflowChildEngineBuilder: - def __init__(self, execution_context: AbstractContextManager[object] | None) -> None: + def __init__(self, execution_context: AbstractContextManager[object] | None = None) -> None: self._execution_context = execution_context @staticmethod @@ -88,15 +88,27 @@ class _WorkflowChildEngineBuilder: root_node_id=root_node_id, ) - child_engine = GraphEngine( - workflow_id=workflow_id, - graph=child_graph, - graph_runtime_state=graph_runtime_state, - command_channel=InMemoryChannel(), - config=GraphEngineConfig(), - child_engine_builder=self, - execution_context=self._execution_context, - ) + command_channel = InMemoryChannel() + config = GraphEngineConfig() + if self._execution_context is None: + child_engine = GraphEngine( + workflow_id=workflow_id, + graph=child_graph, + graph_runtime_state=graph_runtime_state, + command_channel=command_channel, + config=config, + child_engine_builder=self, + ) + else: + child_engine = GraphEngine( + workflow_id=workflow_id, + graph=child_graph, + graph_runtime_state=graph_runtime_state, + command_channel=command_channel, + config=config, + child_engine_builder=self, + execution_context=self._execution_context, + ) child_engine.layer(LLMQuotaLayer()) for layer in layers: child_engine.layer(cast(GraphEngineLayer, layer)) diff --git a/api/dify_graph/model_runtime/callbacks/base_callback.py b/api/dify_graph/model_runtime/callbacks/base_callback.py index cc31f98b8b..76af06b67e 100644 --- a/api/dify_graph/model_runtime/callbacks/base_callback.py +++ b/api/dify_graph/model_runtime/callbacks/base_callback.py @@ -33,6 +33,7 @@ class Callback(ABC): tools: list[PromptMessageTool] | None = None, stop: Sequence[str] | None = None, stream: bool = True, + user: str | None = None, invocation_context: Mapping[str, object] | None = None, ): """ @@ -46,6 +47,7 @@ class Callback(ABC): :param tools: tools for tool calling :param stop: stop words :param stream: is stream response + :param user: optional end-user identifier for the invocation :param invocation_context: opaque request metadata for the current invocation """ raise NotImplementedError() @@ -62,6 +64,7 @@ class Callback(ABC): tools: list[PromptMessageTool] | None = None, stop: Sequence[str] | None = None, stream: bool = True, + user: str | None = None, invocation_context: Mapping[str, object] | None = None, ): """ @@ -76,6 +79,7 @@ class Callback(ABC): :param tools: tools for tool calling :param stop: stop words :param stream: is stream response + :param user: optional end-user identifier for the invocation :param invocation_context: opaque request metadata for the current invocation """ raise NotImplementedError() @@ -92,6 +96,7 @@ class Callback(ABC): tools: list[PromptMessageTool] | None = None, stop: Sequence[str] | None = None, stream: bool = True, + user: str | None = None, invocation_context: Mapping[str, object] | None = None, ): """ @@ -106,6 +111,7 @@ class Callback(ABC): :param tools: tools for tool calling :param stop: stop words :param stream: is stream response + :param user: optional end-user identifier for the invocation :param invocation_context: opaque request metadata for the current invocation """ raise NotImplementedError() @@ -122,6 +128,7 @@ class Callback(ABC): tools: list[PromptMessageTool] | None = None, stop: Sequence[str] | None = None, stream: bool = True, + user: str | None = None, invocation_context: Mapping[str, object] | None = None, ): """ @@ -136,6 +143,7 @@ class Callback(ABC): :param tools: tools for tool calling :param stop: stop words :param stream: is stream response + :param user: optional end-user identifier for the invocation :param invocation_context: opaque request metadata for the current invocation """ raise NotImplementedError() diff --git a/api/dify_graph/model_runtime/callbacks/logging_callback.py b/api/dify_graph/model_runtime/callbacks/logging_callback.py index 29c57ed3ab..03ab9534f7 100644 --- a/api/dify_graph/model_runtime/callbacks/logging_callback.py +++ b/api/dify_graph/model_runtime/callbacks/logging_callback.py @@ -23,6 +23,7 @@ class LoggingCallback(Callback): tools: list[PromptMessageTool] | None = None, stop: Sequence[str] | None = None, stream: bool = True, + user: str | None = None, invocation_context: Mapping[str, object] | None = None, ): """ @@ -36,6 +37,7 @@ class LoggingCallback(Callback): :param tools: tools for tool calling :param stop: stop words :param stream: is stream response + :param user: optional end-user identifier for the invocation :param invocation_context: opaque request metadata for the current invocation """ self.print_text("\n[on_llm_before_invoke]\n", color="blue") @@ -53,6 +55,8 @@ class LoggingCallback(Callback): self.print_text(f"\t\t{tool.name}\n", color="blue") self.print_text(f"Stream: {stream}\n", color="blue") + if user: + self.print_text(f"User: {user}\n", color="blue") if invocation_context: self.print_text(f"Invocation context: {dict(invocation_context)}\n", color="blue") @@ -79,6 +83,7 @@ class LoggingCallback(Callback): tools: list[PromptMessageTool] | None = None, stop: Sequence[str] | None = None, stream: bool = True, + user: str | None = None, invocation_context: Mapping[str, object] | None = None, ): """ @@ -95,7 +100,7 @@ class LoggingCallback(Callback): :param stream: is stream response :param invocation_context: opaque request metadata for the current invocation """ - _ = invocation_context + _ = user, invocation_context sys.stdout.write(cast(str, chunk.delta.message.content)) sys.stdout.flush() @@ -110,6 +115,7 @@ class LoggingCallback(Callback): tools: list[PromptMessageTool] | None = None, stop: Sequence[str] | None = None, stream: bool = True, + user: str | None = None, invocation_context: Mapping[str, object] | None = None, ): """ @@ -126,7 +132,7 @@ class LoggingCallback(Callback): :param stream: is stream response :param invocation_context: opaque request metadata for the current invocation """ - _ = invocation_context + _ = user, invocation_context self.print_text("\n[on_llm_after_invoke]\n", color="yellow") self.print_text(f"Content: {result.message.content}\n", color="yellow") @@ -152,6 +158,7 @@ class LoggingCallback(Callback): tools: list[PromptMessageTool] | None = None, stop: Sequence[str] | None = None, stream: bool = True, + user: str | None = None, invocation_context: Mapping[str, object] | None = None, ): """ @@ -168,6 +175,6 @@ class LoggingCallback(Callback): :param stream: is stream response :param invocation_context: opaque request metadata for the current invocation """ - _ = invocation_context + _ = user, invocation_context self.print_text("\n[on_llm_invoke_error]\n", color="red") logger.exception(ex) diff --git a/api/dify_graph/nodes/human_input/human_input_node.py b/api/dify_graph/nodes/human_input/human_input_node.py index 08449e5b20..d08ca181e3 100644 --- a/api/dify_graph/nodes/human_input/human_input_node.py +++ b/api/dify_graph/nodes/human_input/human_input_node.py @@ -2,7 +2,7 @@ import json import logging from collections.abc import Generator, Mapping, Sequence from datetime import UTC, datetime -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, cast from dify_graph.entities.graph_config import NodeConfigDict from dify_graph.entities.pause_reason import HumanInputRequired @@ -61,6 +61,7 @@ class HumanInputNode(Node[HumanInputNodeData]): graph_init_params: "GraphInitParams", graph_runtime_state: "GraphRuntimeState", runtime: HumanInputNodeRuntimeProtocol | None = None, + form_repository: object | None = None, ) -> None: super().__init__( id=id, @@ -68,9 +69,14 @@ class HumanInputNode(Node[HumanInputNodeData]): graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, ) - if runtime is None: + resolved_runtime = runtime + if resolved_runtime is None: raise ValueError("runtime is required") - self._runtime = runtime + if form_repository is not None: + with_form_repository = getattr(resolved_runtime, "with_form_repository", None) + if callable(with_form_repository): + resolved_runtime = cast(HumanInputNodeRuntimeProtocol, with_form_repository(form_repository)) + self._runtime: HumanInputNodeRuntimeProtocol = resolved_runtime @classmethod def version(cls) -> str: diff --git a/api/dify_graph/nodes/llm/file_saver.py b/api/dify_graph/nodes/llm/file_saver.py index 5ae2362621..e5df51c978 100644 --- a/api/dify_graph/nodes/llm/file_saver.py +++ b/api/dify_graph/nodes/llm/file_saver.py @@ -99,6 +99,8 @@ class FileSaverImpl(LLMFileSaver): "mime_type": mime_type, "size": len(data), "tool_file_id": str(tool_file.id), + "related_id": str(tool_file.id), + "storage_key": tool_file.file_key, } ) diff --git a/api/dify_graph/nodes/llm/node.py b/api/dify_graph/nodes/llm/node.py index 457d8ad9e4..08868c73d0 100644 --- a/api/dify_graph/nodes/llm/node.py +++ b/api/dify_graph/nodes/llm/node.py @@ -194,10 +194,11 @@ class LLMNode(Node[LLMNodeData]): generator = self._fetch_context(node_data=self.node_data) context = None context_files: list[File] = [] - for event in generator: - context = event.context - context_files = event.context_files or [] - yield event + if generator is not None: + for event in generator: + context = event.context + context_files = event.context_files or [] + yield event if context: node_inputs["#context#"] = context diff --git a/api/dify_graph/nodes/parameter_extractor/parameter_extractor_node.py b/api/dify_graph/nodes/parameter_extractor/parameter_extractor_node.py index 00f50ad87d..de7903d9d0 100644 --- a/api/dify_graph/nodes/parameter_extractor/parameter_extractor_node.py +++ b/api/dify_graph/nodes/parameter_extractor/parameter_extractor_node.py @@ -804,7 +804,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): prompt_messages, _ = LLMNode.fetch_prompt_messages( sys_query="", sys_files=files, - context=context, + context=context or "", memory=None, model_instance=model_instance, prompt_template=prompt_template, diff --git a/api/dify_graph/nodes/question_classifier/question_classifier_node.py b/api/dify_graph/nodes/question_classifier/question_classifier_node.py index b8a13dea28..2547a90d95 100644 --- a/api/dify_graph/nodes/question_classifier/question_classifier_node.py +++ b/api/dify_graph/nodes/question_classifier/question_classifier_node.py @@ -291,7 +291,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]): prompt_template=prompt_template, sys_query="", sys_files=[], - context=context, + context=context or "", memory=None, model_instance=model_instance, stop=model_instance.stop, diff --git a/api/dify_graph/nodes/runtime.py b/api/dify_graph/nodes/runtime.py index 0c9b6493ba..528592e025 100644 --- a/api/dify_graph/nodes/runtime.py +++ b/api/dify_graph/nodes/runtime.py @@ -1,6 +1,7 @@ from __future__ import annotations from collections.abc import Generator, Mapping, Sequence +from datetime import datetime from typing import TYPE_CHECKING, Any, Protocol from dify_graph.model_runtime.entities.llm_entities import LLMUsage @@ -102,4 +103,4 @@ class HumanInputFormStateProtocol(Protocol): def status(self) -> HumanInputFormStatus: ... @property - def expiration_time(self): ... + def expiration_time(self) -> datetime: ... diff --git a/api/dify_graph/nodes/variable_assigner/v1/node.py b/api/dify_graph/nodes/variable_assigner/v1/node.py index cb45dd7e08..813d82b61e 100644 --- a/api/dify_graph/nodes/variable_assigner/v1/node.py +++ b/api/dify_graph/nodes/variable_assigner/v1/node.py @@ -1,5 +1,5 @@ from collections.abc import Generator, Mapping, Sequence -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, cast from dify_graph.entities import GraphInitParams from dify_graph.entities.graph_config import NodeConfigDict @@ -8,7 +8,7 @@ from dify_graph.node_events import NodeEventBase, NodeRunResult, StreamCompleted from dify_graph.nodes.base.node import Node from dify_graph.nodes.variable_assigner.common import helpers as common_helpers from dify_graph.nodes.variable_assigner.common.exc import VariableOperatorNodeError -from dify_graph.variables import SegmentType, VariableBase +from dify_graph.variables import SegmentType, Variable, VariableBase from .node_data import VariableAssignerData, WriteMode @@ -90,7 +90,7 @@ class VariableAssignerNode(Node[VariableAssignerData]): updated_variable = original_variable.model_copy(update={"value": income_value.to_object()}) updated_variables = [common_helpers.variable_to_processed_data(assigned_variable_selector, updated_variable)] - yield VariableUpdatedEvent(variable=updated_variable) + yield VariableUpdatedEvent(variable=cast(Variable, updated_variable)) yield StreamCompletedEvent( node_run_result=NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, diff --git a/api/dify_graph/nodes/variable_assigner/v2/node.py b/api/dify_graph/nodes/variable_assigner/v2/node.py index dad7a391cd..08d86613e6 100644 --- a/api/dify_graph/nodes/variable_assigner/v2/node.py +++ b/api/dify_graph/nodes/variable_assigner/v2/node.py @@ -1,6 +1,6 @@ import json from collections.abc import Generator, Mapping, MutableMapping, Sequence -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, cast from dify_graph.entities.graph_config import NodeConfigDict from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus @@ -8,7 +8,7 @@ from dify_graph.node_events import NodeEventBase, NodeRunResult, StreamCompleted from dify_graph.nodes.base.node import Node from dify_graph.nodes.variable_assigner.common import helpers as common_helpers from dify_graph.nodes.variable_assigner.common.exc import VariableOperatorNodeError -from dify_graph.variables import SegmentType, VariableBase +from dify_graph.variables import SegmentType, Variable, VariableBase from dify_graph.variables.consts import SELECTORS_LENGTH from . import helpers @@ -208,7 +208,7 @@ class VariableAssignerNode(Node[VariableAssignerNodeData]): variable = working_variable_pool.get(selector) if not isinstance(variable, VariableBase): raise VariableNotFoundError(variable_selector=selector) - yield VariableUpdatedEvent(variable=variable) + yield VariableUpdatedEvent(variable=cast(Variable, variable)) yield StreamCompletedEvent( node_run_result=NodeRunResult( diff --git a/api/dify_graph/runtime/variable_pool.py b/api/dify_graph/runtime/variable_pool.py index b58359c005..7d1f72650f 100644 --- a/api/dify_graph/runtime/variable_pool.py +++ b/api/dify_graph/runtime/variable_pool.py @@ -6,13 +6,13 @@ from collections.abc import Mapping, Sequence from copy import deepcopy from typing import Annotated, Any, Union, cast -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, model_validator from dify_graph.file import File, FileAttribute, file_manager from dify_graph.variables import Segment, SegmentGroup, VariableBase, build_segment, segment_to_variable from dify_graph.variables.consts import SELECTORS_LENGTH from dify_graph.variables.segments import FileSegment, ObjectSegment -from dify_graph.variables.variables import Variable +from dify_graph.variables.variables import RAGPipelineVariableInput, Variable VariableValue = Union[str, int, float, dict[str, object], list[object], File] @@ -20,14 +20,66 @@ VARIABLE_PATTERN = re.compile(r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9 class VariablePool(BaseModel): + _SYSTEM_VARIABLE_NODE_ID = "sys" + _ENVIRONMENT_VARIABLE_NODE_ID = "env" + _CONVERSATION_VARIABLE_NODE_ID = "conversation" + _RAG_PIPELINE_VARIABLE_NODE_ID = "rag" + # Variable dictionary is a dictionary for looking up variables by their selector. # The first element of the selector is the node id, it's the first-level key in the dictionary. # Other elements of the selector are the keys in the second-level dictionary. To get the key, we hash the # elements of the selector except the first one. variable_dictionary: defaultdict[str, Annotated[dict[str, Variable], Field(default_factory=dict)]] = Field( description="Variables mapping", - default=defaultdict(dict), + default_factory=lambda: defaultdict(dict), ) + system_variables: Sequence[Variable] = Field(default_factory=tuple, exclude=True) + environment_variables: Sequence[Variable] = Field(default_factory=tuple, exclude=True) + conversation_variables: Sequence[Variable] = Field(default_factory=tuple, exclude=True) + rag_pipeline_variables: Sequence[RAGPipelineVariableInput] = Field(default_factory=tuple, exclude=True) + user_inputs: Mapping[str, Any] = Field(default_factory=dict, exclude=True) + + @model_validator(mode="after") + def _load_legacy_bootstrap_inputs(self) -> VariablePool: + """ + Accept legacy constructor kwargs that still appear throughout the workflow + layer while keeping serialized state focused on `variable_dictionary`. + """ + + self._ingest_legacy_variables(self.system_variables, node_id=self._SYSTEM_VARIABLE_NODE_ID) + self._ingest_legacy_variables(self.environment_variables, node_id=self._ENVIRONMENT_VARIABLE_NODE_ID) + self._ingest_legacy_variables(self.conversation_variables, node_id=self._CONVERSATION_VARIABLE_NODE_ID) + self._ingest_legacy_rag_variables(self.rag_pipeline_variables) + + # These kwargs are accepted for compatibility but should not affect the + # stable serialized form or model equality. + self.system_variables = () + self.environment_variables = () + self.conversation_variables = () + self.rag_pipeline_variables = () + self.user_inputs = {} + return self + + def _ingest_legacy_variables(self, variables: Sequence[Variable], *, node_id: str) -> None: + for variable in variables: + selector = [node_id, variable.name] + normalized_variable = variable + if list(variable.selector) != selector: + normalized_variable = variable.model_copy(update={"selector": selector}) + self.add(normalized_variable.selector, normalized_variable) + + def _ingest_legacy_rag_variables(self, rag_pipeline_variables: Sequence[RAGPipelineVariableInput]) -> None: + if not rag_pipeline_variables: + return + + values_by_node_id: defaultdict[str, dict[str, Any]] = defaultdict(dict) + for rag_variable_input in rag_pipeline_variables: + values_by_node_id[rag_variable_input.variable.belong_to_node_id][rag_variable_input.variable.variable] = ( + rag_variable_input.value + ) + + for node_id, value in values_by_node_id.items(): + self.add((self._RAG_PIPELINE_VARIABLE_NODE_ID, node_id), value) def add(self, selector: Sequence[str], value: Any, /): """ diff --git a/api/factories/file_factory.py b/api/factories/file_factory.py index 49b5886b07..8cdce0b5b2 100644 --- a/api/factories/file_factory.py +++ b/api/factories/file_factory.py @@ -13,7 +13,7 @@ from sqlalchemy.orm import Session from werkzeug.http import parse_options_header from core.helper import ssrf_proxy -from core.workflow.file_reference import build_file_reference, resolve_file_record_id +from core.workflow.file_reference import build_file_reference, parse_file_reference, resolve_file_record_id from dify_graph.file import File, FileBelongsTo, FileTransferMethod, FileType, FileUploadConfig, helpers from dify_graph.file.file_factory import standardize_file_type from extensions.ext_database import db diff --git a/api/models/model.py b/api/models/model.py index 62a2a76c81..d2490b2223 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -57,7 +57,7 @@ def _resolve_app_tenant_id(app_id: str) -> str: resolved_tenant_id = db.session.scalar(select(App.tenant_id).where(App.id == app_id)) if not resolved_tenant_id: raise ValueError(f"Unable to resolve tenant_id for app {app_id}") - return cast(str, resolved_tenant_id) + return resolved_tenant_id def _build_app_tenant_resolver(app_id: str, owner_tenant_id: str | None = None) -> Callable[[], str]: diff --git a/api/models/utils/file_input_compat.py b/api/models/utils/file_input_compat.py index c50f8c898a..d63820c8c8 100644 --- a/api/models/utils/file_input_compat.py +++ b/api/models/utils/file_input_compat.py @@ -1,7 +1,7 @@ from __future__ import annotations from collections.abc import Callable, Mapping -from typing import Any, cast +from typing import Any from core.workflow.file_reference import parse_file_reference from dify_graph.file import File, FileTransferMethod @@ -61,4 +61,4 @@ def build_file_from_input_mapping( mapping["upload_file_id"] = record_id tenant_id = resolve_file_mapping_tenant_id(file_mapping=mapping, tenant_resolver=tenant_resolver) - return cast(File, file_factory.build_from_mapping(mapping=mapping, tenant_id=tenant_id)) + return file_factory.build_from_mapping(mapping=mapping, tenant_id=tenant_id) diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 4abbf8cbc6..cf127c753c 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -898,7 +898,6 @@ class WorkflowService: node_id=node_id, node_title=node.title, resolved_default_values=resolved_default_values, - form_token=None, ) return human_input_required.model_dump(mode="json") diff --git a/api/tests/unit_tests/controllers/console/app/test_workflow_pause_details_api.py b/api/tests/unit_tests/controllers/console/app/test_workflow_pause_details_api.py index 83601dc1b9..1095e79da8 100644 --- a/api/tests/unit_tests/controllers/console/app/test_workflow_pause_details_api.py +++ b/api/tests/unit_tests/controllers/console/app/test_workflow_pause_details_api.py @@ -67,7 +67,6 @@ def test_pause_details_returns_backstage_input_url(app: Flask, monkeypatch: pyte actions=[UserAction(id="approve", title="Approve")], node_id="node-1", node_title="Ask Name", - form_token="backstage-token", ) pause_entity = _PauseEntity(paused_at=datetime(2024, 1, 1, 12, 0, 0), reasons=[reason]) @@ -78,6 +77,11 @@ def test_pause_details_returns_backstage_input_url(app: Flask, monkeypatch: pyte "create_api_workflow_run_repository", lambda *_, **__: repo, ) + monkeypatch.setattr( + workflow_run_module, + "_load_form_tokens_by_form_id", + lambda _form_ids: {"form-1": "backstage-token"}, + ) with app.test_request_context("/console/api/workflow/run-1/pause-details", method="GET"): response, status = workflow_run_module.ConsoleWorkflowPauseDetailsApi().get(workflow_run_id="run-1") diff --git a/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline.py b/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline.py index 83a6e0f231..c7cb5f142d 100644 --- a/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline.py +++ b/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline.py @@ -137,7 +137,6 @@ def test_handle_workflow_paused_event_persists_human_input_extra_content() -> No actions=[], node_id="node-1", node_title="Approval", - form_token="token-1", resolved_default_values={}, ) event = QueueWorkflowPausedEvent(reasons=[reason], outputs={}, paused_nodes=["node-1"]) diff --git a/api/tests/unit_tests/core/app/apps/test_workflow_pause_events.py b/api/tests/unit_tests/core/app/apps/test_workflow_pause_events.py index 206abd36a1..3c5304f728 100644 --- a/api/tests/unit_tests/core/app/apps/test_workflow_pause_events.py +++ b/api/tests/unit_tests/core/app/apps/test_workflow_pause_events.py @@ -74,7 +74,6 @@ def test_graph_run_paused_event_emits_queue_pause_event(): actions=[], node_id="node-human", node_title="Human Step", - form_token="tok", ) event = GraphRunPausedEvent(reasons=[reason], outputs={"foo": "bar"}) workflow_entry = SimpleNamespace( diff --git a/api/tests/unit_tests/core/workflow/test_node_factory.py b/api/tests/unit_tests/core/workflow/test_node_factory.py index b103111119..bf6dbd7e2d 100644 --- a/api/tests/unit_tests/core/workflow/test_node_factory.py +++ b/api/tests/unit_tests/core/workflow/test_node_factory.py @@ -423,12 +423,8 @@ class TestDifyNodeFactoryCreateNode: if constructor_name == "HumanInputNode": form_repository = sentinel.form_repository - form_repository_impl = MagicMock(return_value=form_repository) - monkeypatch.setattr( - node_factory, - "HumanInputFormRepositoryImpl", - form_repository_impl, - ) + factory._human_input_runtime = MagicMock() + factory._human_input_runtime.build_form_repository.return_value = form_repository node_config = {"id": "node-id", "data": {"type": node_type}} result = factory.create_node(node_config) @@ -454,14 +450,8 @@ class TestDifyNodeFactoryCreateNode: assert kwargs["file_reference_factory"] is sentinel.file_reference_factory elif constructor_name == "HumanInputNode": assert kwargs["form_repository"] is form_repository - assert kwargs["runtime"] is sentinel.human_input_runtime - form_repository_impl.assert_called_once_with( - tenant_id="tenant-id", - app_id="app-id", - workflow_execution_id=None, - invoke_source=InvokeFrom.DEBUGGER, - submission_actor_id="user-id", - ) + assert kwargs["runtime"] is factory._human_input_runtime + factory._human_input_runtime.build_form_repository.assert_called_once_with() elif constructor_name == "DocumentExtractorNode": assert kwargs["unstructured_api_config"] is sentinel.unstructured_api_config assert kwargs["http_client"] is sentinel.http_client diff --git a/api/tests/unit_tests/core/workflow/test_variable_pool.py b/api/tests/unit_tests/core/workflow/test_variable_pool.py index a219068168..61681204f8 100644 --- a/api/tests/unit_tests/core/workflow/test_variable_pool.py +++ b/api/tests/unit_tests/core/workflow/test_variable_pool.py @@ -107,6 +107,26 @@ class TestVariablePool: assert pool.get([ENVIRONMENT_VARIABLE_NODE_ID, "env_var_1"]) is not None assert pool.get([CONVERSATION_VARIABLE_NODE_ID, "conv_var_1"]) is not None + def test_constructor_loads_legacy_bootstrap_kwargs(self): + pool = VariablePool( + system_variables=build_system_variables(user_id="test_user_id"), + environment_variables=[StringVariable(name="env_var", value="env-value")], + conversation_variables=[StringVariable(name="conv_var", value="conv-value")], + user_inputs={"ignored": "value"}, + ) + + system_value = pool.get([SYSTEM_VARIABLE_NODE_ID, "user_id"]) + environment_value = pool.get([ENVIRONMENT_VARIABLE_NODE_ID, "env_var"]) + conversation_value = pool.get([CONVERSATION_VARIABLE_NODE_ID, "conv_var"]) + + assert system_value is not None + assert system_value.value == "test_user_id" + assert environment_value is not None + assert environment_value.value == "env-value" + assert conversation_value is not None + assert conversation_value.value == "conv-value" + assert "system_variables" not in pool.model_dump() + def test_get_system_variables(self): sys_var = build_system_variables( user_id="test_user_id", diff --git a/api/tests/unit_tests/core/workflow/test_workflow_entry_helpers.py b/api/tests/unit_tests/core/workflow/test_workflow_entry_helpers.py index 4b3f604574..26b419b330 100644 --- a/api/tests/unit_tests/core/workflow/test_workflow_entry_helpers.py +++ b/api/tests/unit_tests/core/workflow/test_workflow_entry_helpers.py @@ -129,6 +129,7 @@ class TestWorkflowEntryInit: patch.object(workflow_entry.dify_config, "DEBUG", True), patch.object(workflow_entry.dify_config, "ENABLE_OTEL", False), patch.object(workflow_entry, "is_instrument_flag_enabled", return_value=True), + patch.object(workflow_entry, "capture_current_context", return_value=sentinel.execution_context), patch.object(workflow_entry, "GraphEngine", return_value=graph_engine) as graph_engine_cls, patch.object(workflow_entry, "GraphEngineConfig", return_value=sentinel.graph_engine_config), patch.object(workflow_entry, "InMemoryChannel", return_value=sentinel.command_channel), @@ -164,6 +165,7 @@ class TestWorkflowEntryInit: command_channel=sentinel.command_channel, config=sentinel.graph_engine_config, child_engine_builder=entry._child_engine_builder, + execution_context=sentinel.execution_context, ) debug_logging_layer.assert_called_once_with( level="DEBUG", diff --git a/api/tests/unit_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py b/api/tests/unit_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py index 3707ed90be..7f0afa4daf 100644 --- a/api/tests/unit_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py +++ b/api/tests/unit_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py @@ -1,6 +1,5 @@ """Unit tests for non-SQL helper logic in workflow run repository.""" -import secrets from datetime import UTC, datetime from unittest.mock import Mock, patch @@ -9,7 +8,7 @@ import pytest from dify_graph.entities.pause_reason import HumanInputRequired, PauseReasonType from dify_graph.nodes.human_input.entities import FormDefinition, FormInput, UserAction from dify_graph.nodes.human_input.enums import FormInputType, HumanInputFormStatus -from models.human_input import BackstageRecipientPayload, HumanInputForm, HumanInputFormRecipient, RecipientType +from models.human_input import HumanInputForm from models.workflow import WorkflowPause as WorkflowPauseModel from models.workflow import WorkflowPauseReason from repositories.sqlalchemy_api_workflow_run_repository import ( @@ -82,8 +81,8 @@ class TestPrivateWorkflowPauseEntity: class TestBuildHumanInputRequiredReason: """Test helper that builds HumanInputRequired pause reasons.""" - def test_prefers_backstage_token_when_available(self) -> None: - """Use backstage token when multiple recipient types may exist.""" + def test_builds_reason_from_form_definition(self) -> None: + """Build the graph pause reason from the stored form definition.""" # Arrange expiration_time = datetime.now(UTC) form_definition = FormDefinition( @@ -114,22 +113,14 @@ class TestBuildHumanInputRequiredReason: node_id="node-1", message="", ) - access_token = secrets.token_urlsafe(8) - backstage_recipient = HumanInputFormRecipient( - form_id="form-1", - delivery_id="delivery-1", - recipient_type=RecipientType.BACKSTAGE, - recipient_payload=BackstageRecipientPayload().model_dump_json(), - access_token=access_token, - ) # Act - reason = _build_human_input_required_reason(reason_model, form_model, [backstage_recipient]) + reason = _build_human_input_required_reason(reason_model, form_model) # Assert assert isinstance(reason, HumanInputRequired) - assert reason.form_token == access_token assert reason.node_title == "Ask Name" assert reason.form_content == "content" assert reason.inputs[0].output_variable_name == "name" assert reason.actions[0].id == "approve" + assert reason.resolved_default_values == {"name": "Alice"}