From 13bf6547ee4127e2cfa9afcddc12ffff4f720bdd Mon Sep 17 00:00:00 2001 From: -LAN- Date: Thu, 27 Nov 2025 15:41:56 +0800 Subject: [PATCH] Refactor: centralize node data hydration (#27771) --- .../helper/code_executor/code_executor.py | 7 +- api/core/workflow/nodes/agent/agent_node.py | 25 +-- api/core/workflow/nodes/answer/answer_node.py | 26 +-- api/core/workflow/nodes/base/node.py | 160 +++++++++++++++--- api/core/workflow/nodes/code/code_node.py | 26 +-- .../nodes/datasource/datasource_node.py | 26 +-- .../workflow/nodes/document_extractor/node.py | 26 +-- api/core/workflow/nodes/end/end_node.py | 29 +--- api/core/workflow/nodes/http_request/node.py | 27 +-- .../nodes/human_input/human_input_node.py | 26 +-- .../workflow/nodes/if_else/if_else_node.py | 26 +-- .../nodes/iteration/iteration_node.py | 25 +-- .../nodes/iteration/iteration_start_node.py | 29 +--- .../knowledge_index/knowledge_index_node.py | 26 +-- .../knowledge_retrieval_node.py | 25 +-- api/core/workflow/nodes/list_operator/node.py | 28 +-- api/core/workflow/nodes/llm/node.py | 26 +-- api/core/workflow/nodes/loop/loop_end_node.py | 29 +--- api/core/workflow/nodes/loop/loop_node.py | 25 +-- .../workflow/nodes/loop/loop_start_node.py | 29 +--- api/core/workflow/nodes/node_factory.py | 10 +- .../parameter_extractor_node.py | 26 +-- .../question_classifier_node.py | 26 +-- api/core/workflow/nodes/start/start_node.py | 29 +--- .../template_transform_node.py | 26 +-- api/core/workflow/nodes/tool/tool_node.py | 25 +-- .../trigger_plugin/trigger_event_node.py | 29 +--- .../trigger_schedule/trigger_schedule_node.py | 29 +--- .../workflow/nodes/trigger_webhook/node.py | 28 +-- .../variable_aggregator_node.py | 27 +-- .../nodes/variable_assigner/v1/node.py | 26 +-- .../nodes/variable_assigner/v2/node.py | 26 +-- api/core/workflow/workflow_entry.py | 2 - .../workflow/nodes/test_code.py | 4 - .../workflow/nodes/test_http.py | 8 - .../workflow/nodes/test_llm.py | 4 - .../nodes/test_parameter_extractor.py | 1 - .../workflow/nodes/test_template_transform.py | 1 - .../workflow/nodes/test_tool.py | 1 - .../workflow/graph/test_graph_validation.py | 64 +++---- .../graph_engine/test_command_system.py | 6 +- .../test_human_input_pause_multi_branch.py | 5 - .../test_human_input_pause_single_branch.py | 4 - .../graph_engine/test_if_else_streaming.py | 5 - .../graph_engine/test_mock_factory.py | 3 - .../test_mock_iteration_simple.py | 2 + .../test_mock_nodes_template_code.py | 7 - .../core/workflow/nodes/answer/test_answer.py | 3 - .../workflow/nodes/base/test_base_node.py | 85 ++++++++++ .../core/workflow/nodes/llm/test_node.py | 4 - .../core/workflow/nodes/test_base_node.py | 74 ++++++++ .../nodes/test_document_extractor_node.py | 2 - .../core/workflow/nodes/test_if_else.py | 12 -- .../core/workflow/nodes/test_list_operator.py | 2 - .../workflow/nodes/tool/test_tool_node.py | 1 - .../v1/test_variable_assigner_v1.py | 9 - .../v2/test_variable_assigner_v2.py | 17 -- .../nodes/webhook/test_webhook_node.py | 1 - 58 files changed, 381 insertions(+), 899 deletions(-) create mode 100644 api/tests/unit_tests/core/workflow/nodes/test_base_node.py diff --git a/api/core/helper/code_executor/code_executor.py b/api/core/helper/code_executor/code_executor.py index f92278f9e2..73174ed28d 100644 --- a/api/core/helper/code_executor/code_executor.py +++ b/api/core/helper/code_executor/code_executor.py @@ -152,10 +152,5 @@ class CodeExecutor: raise CodeExecutionError(f"Unsupported language {language}") runner, preload = template_transformer.transform_caller(code, inputs) - - try: - response = cls.execute_code(language, preload, runner) - except CodeExecutionError as e: - raise e - + response = cls.execute_code(language, preload, runner) return template_transformer.transform_response(response) diff --git a/api/core/workflow/nodes/agent/agent_node.py b/api/core/workflow/nodes/agent/agent_node.py index 626ef1df7b..7248f9b1d5 100644 --- a/api/core/workflow/nodes/agent/agent_node.py +++ b/api/core/workflow/nodes/agent/agent_node.py @@ -26,7 +26,6 @@ from core.tools.tool_manager import ToolManager from core.tools.utils.message_transformer import ToolFileMessageTransformer from core.variables.segments import ArrayFileSegment, StringSegment from core.workflow.enums import ( - ErrorStrategy, NodeType, SystemVariableKey, WorkflowNodeExecutionMetadataKey, @@ -40,7 +39,6 @@ from core.workflow.node_events import ( StreamCompletedEvent, ) from core.workflow.nodes.agent.entities import AgentNodeData, AgentOldVersionModelFeatures, ParamsAutoGenerated -from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig from core.workflow.nodes.base.node import Node from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser from core.workflow.runtime import VariablePool @@ -66,7 +64,7 @@ if TYPE_CHECKING: from core.plugin.entities.request import InvokeCredentials -class AgentNode(Node): +class AgentNode(Node[AgentNodeData]): """ Agent Node """ @@ -74,27 +72,6 @@ class AgentNode(Node): node_type = NodeType.AGENT _node_data: AgentNodeData - def init_node_data(self, data: Mapping[str, Any]): - self._node_data = AgentNodeData.model_validate(data) - - def _get_error_strategy(self) -> ErrorStrategy | None: - return self._node_data.error_strategy - - def _get_retry_config(self) -> RetryConfig: - return self._node_data.retry_config - - def _get_title(self) -> str: - return self._node_data.title - - def _get_description(self) -> str | None: - return self._node_data.desc - - def _get_default_value_dict(self) -> dict[str, Any]: - return self._node_data.default_value_dict - - def get_base_node_data(self) -> BaseNodeData: - return self._node_data - @classmethod def version(cls) -> str: return "1" diff --git a/api/core/workflow/nodes/answer/answer_node.py b/api/core/workflow/nodes/answer/answer_node.py index 86174c7ea6..0fe40db786 100644 --- a/api/core/workflow/nodes/answer/answer_node.py +++ b/api/core/workflow/nodes/answer/answer_node.py @@ -2,42 +2,20 @@ from collections.abc import Mapping, Sequence from typing import Any from core.variables import ArrayFileSegment, FileSegment, Segment -from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType, WorkflowNodeExecutionStatus +from core.workflow.enums import NodeExecutionType, NodeType, WorkflowNodeExecutionStatus from core.workflow.node_events import NodeRunResult from core.workflow.nodes.answer.entities import AnswerNodeData -from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig from core.workflow.nodes.base.node import Node from core.workflow.nodes.base.template import Template from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser -class AnswerNode(Node): +class AnswerNode(Node[AnswerNodeData]): node_type = NodeType.ANSWER execution_type = NodeExecutionType.RESPONSE _node_data: AnswerNodeData - def init_node_data(self, data: Mapping[str, Any]): - self._node_data = AnswerNodeData.model_validate(data) - - def _get_error_strategy(self) -> ErrorStrategy | None: - return self._node_data.error_strategy - - def _get_retry_config(self) -> RetryConfig: - return self._node_data.retry_config - - def _get_title(self) -> str: - return self._node_data.title - - def _get_description(self) -> str | None: - return self._node_data.desc - - def _get_default_value_dict(self) -> dict[str, Any]: - return self._node_data.default_value_dict - - def get_base_node_data(self) -> BaseNodeData: - return self._node_data - @classmethod def version(cls) -> str: return "1" diff --git a/api/core/workflow/nodes/base/node.py b/api/core/workflow/nodes/base/node.py index eda030699a..bbdd3099da 100644 --- a/api/core/workflow/nodes/base/node.py +++ b/api/core/workflow/nodes/base/node.py @@ -2,7 +2,7 @@ import logging from abc import abstractmethod from collections.abc import Generator, Mapping, Sequence from functools import singledispatchmethod -from typing import Any, ClassVar +from typing import Any, ClassVar, Generic, TypeVar, cast, get_args, get_origin from uuid import uuid4 from core.app.entities.app_invoke_entities import InvokeFrom @@ -49,12 +49,121 @@ from models.enums import UserFrom from .entities import BaseNodeData, RetryConfig +NodeDataT = TypeVar("NodeDataT", bound=BaseNodeData) + logger = logging.getLogger(__name__) -class Node: +class Node(Generic[NodeDataT]): node_type: ClassVar["NodeType"] execution_type: NodeExecutionType = NodeExecutionType.EXECUTABLE + _node_data_type: ClassVar[type[BaseNodeData]] = BaseNodeData + + def __init_subclass__(cls, **kwargs: Any) -> None: + """ + Automatically extract and validate the node data type from the generic parameter. + + When a subclass is defined as `class MyNode(Node[MyNodeData])`, this method: + 1. Inspects `__orig_bases__` to find the `Node[T]` parameterization + 2. Extracts `T` (e.g., `MyNodeData`) from the generic argument + 3. Validates that `T` is a proper `BaseNodeData` subclass + 4. Stores it in `_node_data_type` for automatic hydration in `__init__` + + This eliminates the need for subclasses to manually implement boilerplate + accessor methods like `_get_title()`, `_get_error_strategy()`, etc. + + How it works: + :: + + class CodeNode(Node[CodeNodeData]): + │ │ + │ └─────────────────────────────────┐ + │ │ + ▼ ▼ + ┌─────────────────────────────┐ ┌─────────────────────────────────┐ + │ __orig_bases__ = ( │ │ CodeNodeData(BaseNodeData) │ + │ Node[CodeNodeData], │ │ title: str │ + │ ) │ │ desc: str | None │ + └──────────────┬──────────────┘ │ ... │ + │ └─────────────────────────────────┘ + ▼ ▲ + ┌─────────────────────────────┐ │ + │ get_origin(base) -> Node │ │ + │ get_args(base) -> ( │ │ + │ CodeNodeData, │ ──────────────────────┘ + │ ) │ + └──────────────┬──────────────┘ + │ + ▼ + ┌─────────────────────────────┐ + │ Validate: │ + │ - Is it a type? │ + │ - Is it a BaseNodeData │ + │ subclass? │ + └──────────────┬──────────────┘ + │ + ▼ + ┌─────────────────────────────┐ + │ cls._node_data_type = │ + │ CodeNodeData │ + └─────────────────────────────┘ + + Later, in __init__: + :: + + config["data"] ──► _hydrate_node_data() ──► _node_data_type.model_validate() + │ + ▼ + CodeNodeData instance + (stored in self._node_data) + + Example: + class CodeNode(Node[CodeNodeData]): # CodeNodeData is auto-extracted + node_type = NodeType.CODE + # No need to implement _get_title, _get_error_strategy, etc. + """ + super().__init_subclass__(**kwargs) + + if cls is Node: + return + + node_data_type = cls._extract_node_data_type_from_generic() + + if node_data_type is None: + raise TypeError(f"{cls.__name__} must inherit from Node[T] with a BaseNodeData subtype") + + cls._node_data_type = node_data_type + + @classmethod + def _extract_node_data_type_from_generic(cls) -> type[BaseNodeData] | None: + """ + Extract the node data type from the generic parameter `Node[T]`. + + Inspects `__orig_bases__` to find the `Node[T]` parameterization and extracts `T`. + + Returns: + The extracted BaseNodeData subtype, or None if not found. + + Raises: + TypeError: If the generic argument is invalid (not exactly one argument, + or not a BaseNodeData subtype). + """ + # __orig_bases__ contains the original generic bases before type erasure. + # For `class CodeNode(Node[CodeNodeData])`, this would be `(Node[CodeNodeData],)`. + for base in getattr(cls, "__orig_bases__", ()): # type: ignore[attr-defined] + origin = get_origin(base) # Returns `Node` for `Node[CodeNodeData]` + if origin is Node: + args = get_args(base) # Returns `(CodeNodeData,)` for `Node[CodeNodeData]` + if len(args) != 1: + raise TypeError(f"{cls.__name__} must specify exactly one node data generic argument") + + candidate = args[0] + if not isinstance(candidate, type) or not issubclass(candidate, BaseNodeData): + raise TypeError(f"{cls.__name__} must parameterize Node with a BaseNodeData subtype") + + return candidate + + return None def __init__( self, @@ -63,6 +172,7 @@ class Node: graph_init_params: "GraphInitParams", graph_runtime_state: "GraphRuntimeState", ) -> None: + self._graph_init_params = graph_init_params self.id = id self.tenant_id = graph_init_params.tenant_id self.app_id = graph_init_params.app_id @@ -83,8 +193,24 @@ class Node: self._node_execution_id: str = "" self._start_at = naive_utc_now() - @abstractmethod - def init_node_data(self, data: Mapping[str, Any]) -> None: ... + raw_node_data = config.get("data") or {} + if not isinstance(raw_node_data, Mapping): + raise ValueError("Node config data must be a mapping.") + + self._node_data: NodeDataT = self._hydrate_node_data(raw_node_data) + + self.post_init() + + def post_init(self) -> None: + """Optional hook for subclasses requiring extra initialization.""" + return + + @property + def graph_init_params(self) -> "GraphInitParams": + return self._graph_init_params + + def _hydrate_node_data(self, data: Mapping[str, Any]) -> NodeDataT: + return cast(NodeDataT, self._node_data_type.model_validate(data)) @abstractmethod def _run(self) -> NodeRunResult | Generator[NodeEventBase, None, None]: @@ -273,38 +399,29 @@ class Node: def retry(self) -> bool: return False - # Abstract methods that subclasses must implement to provide access - # to BaseNodeData properties in a type-safe way - - @abstractmethod def _get_error_strategy(self) -> ErrorStrategy | None: """Get the error strategy for this node.""" - ... + return self._node_data.error_strategy - @abstractmethod def _get_retry_config(self) -> RetryConfig: """Get the retry configuration for this node.""" - ... + return self._node_data.retry_config - @abstractmethod def _get_title(self) -> str: """Get the node title.""" - ... + return self._node_data.title - @abstractmethod def _get_description(self) -> str | None: """Get the node description.""" - ... + return self._node_data.desc - @abstractmethod def _get_default_value_dict(self) -> dict[str, Any]: """Get the default values dictionary for this node.""" - ... + return self._node_data.default_value_dict - @abstractmethod def get_base_node_data(self) -> BaseNodeData: """Get the BaseNodeData object for this node.""" - ... + return self._node_data # Public interface properties that delegate to abstract methods @property @@ -332,6 +449,11 @@ class Node: """Get the default values dictionary for this node.""" return self._get_default_value_dict() + @property + def node_data(self) -> NodeDataT: + """Typed access to this node's configuration data.""" + return self._node_data + def _convert_node_run_result_to_graph_node_event(self, result: NodeRunResult) -> GraphNodeEventBase: match result.status: case WorkflowNodeExecutionStatus.FAILED: diff --git a/api/core/workflow/nodes/code/code_node.py b/api/core/workflow/nodes/code/code_node.py index c87cbf9628..4c64f45f04 100644 --- a/api/core/workflow/nodes/code/code_node.py +++ b/api/core/workflow/nodes/code/code_node.py @@ -9,9 +9,8 @@ from core.helper.code_executor.javascript.javascript_code_provider import Javasc from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider from core.variables.segments import ArrayFileSegment from core.variables.types import SegmentType -from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus +from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus from core.workflow.node_events import NodeRunResult -from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig from core.workflow.nodes.base.node import Node from core.workflow.nodes.code.entities import CodeNodeData @@ -22,32 +21,11 @@ from .exc import ( ) -class CodeNode(Node): +class CodeNode(Node[CodeNodeData]): node_type = NodeType.CODE _node_data: CodeNodeData - def init_node_data(self, data: Mapping[str, Any]): - self._node_data = CodeNodeData.model_validate(data) - - def _get_error_strategy(self) -> ErrorStrategy | None: - return self._node_data.error_strategy - - def _get_retry_config(self) -> RetryConfig: - return self._node_data.retry_config - - def _get_title(self) -> str: - return self._node_data.title - - def _get_description(self) -> str | None: - return self._node_data.desc - - def _get_default_value_dict(self) -> dict[str, Any]: - return self._node_data.default_value_dict - - def get_base_node_data(self) -> BaseNodeData: - return self._node_data - @classmethod def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]: """ diff --git a/api/core/workflow/nodes/datasource/datasource_node.py b/api/core/workflow/nodes/datasource/datasource_node.py index 34c1db9468..d8718222f8 100644 --- a/api/core/workflow/nodes/datasource/datasource_node.py +++ b/api/core/workflow/nodes/datasource/datasource_node.py @@ -20,9 +20,8 @@ from core.plugin.impl.exc import PluginDaemonClientSideError from core.variables.segments import ArrayAnySegment from core.variables.variables import ArrayAnyVariable from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus -from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType, SystemVariableKey +from core.workflow.enums import NodeExecutionType, NodeType, SystemVariableKey from core.workflow.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent -from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig from core.workflow.nodes.base.node import Node from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser from core.workflow.nodes.tool.exc import ToolFileError @@ -38,7 +37,7 @@ from .entities import DatasourceNodeData from .exc import DatasourceNodeError, DatasourceParameterError -class DatasourceNode(Node): +class DatasourceNode(Node[DatasourceNodeData]): """ Datasource Node """ @@ -47,27 +46,6 @@ class DatasourceNode(Node): node_type = NodeType.DATASOURCE execution_type = NodeExecutionType.ROOT - def init_node_data(self, data: Mapping[str, Any]) -> None: - self._node_data = DatasourceNodeData.model_validate(data) - - def _get_error_strategy(self) -> ErrorStrategy | None: - return self._node_data.error_strategy - - def _get_retry_config(self) -> RetryConfig: - return self._node_data.retry_config - - def _get_title(self) -> str: - return self._node_data.title - - def _get_description(self) -> str | None: - return self._node_data.desc - - def _get_default_value_dict(self) -> dict[str, Any]: - return self._node_data.default_value_dict - - def get_base_node_data(self) -> BaseNodeData: - return self._node_data - def _run(self) -> Generator: """ Run the datasource node diff --git a/api/core/workflow/nodes/document_extractor/node.py b/api/core/workflow/nodes/document_extractor/node.py index 12cd7e2bd9..17f09e69a2 100644 --- a/api/core/workflow/nodes/document_extractor/node.py +++ b/api/core/workflow/nodes/document_extractor/node.py @@ -25,9 +25,8 @@ from core.file import File, FileTransferMethod, file_manager from core.helper import ssrf_proxy from core.variables import ArrayFileSegment from core.variables.segments import ArrayStringSegment, FileSegment -from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus +from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus from core.workflow.node_events import NodeRunResult -from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig from core.workflow.nodes.base.node import Node from .entities import DocumentExtractorNodeData @@ -36,7 +35,7 @@ from .exc import DocumentExtractorError, FileDownloadError, TextExtractionError, logger = logging.getLogger(__name__) -class DocumentExtractorNode(Node): +class DocumentExtractorNode(Node[DocumentExtractorNodeData]): """ Extracts text content from various file types. Supports plain text, PDF, and DOC/DOCX files. @@ -46,27 +45,6 @@ class DocumentExtractorNode(Node): _node_data: DocumentExtractorNodeData - def init_node_data(self, data: Mapping[str, Any]): - self._node_data = DocumentExtractorNodeData.model_validate(data) - - def _get_error_strategy(self) -> ErrorStrategy | None: - return self._node_data.error_strategy - - def _get_retry_config(self) -> RetryConfig: - return self._node_data.retry_config - - def _get_title(self) -> str: - return self._node_data.title - - def _get_description(self) -> str | None: - return self._node_data.desc - - def _get_default_value_dict(self) -> dict[str, Any]: - return self._node_data.default_value_dict - - def get_base_node_data(self) -> BaseNodeData: - return self._node_data - @classmethod def version(cls) -> str: return "1" diff --git a/api/core/workflow/nodes/end/end_node.py b/api/core/workflow/nodes/end/end_node.py index 7ec74084d0..e188a5616b 100644 --- a/api/core/workflow/nodes/end/end_node.py +++ b/api/core/workflow/nodes/end/end_node.py @@ -1,41 +1,16 @@ -from collections.abc import Mapping -from typing import Any - -from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType, WorkflowNodeExecutionStatus +from core.workflow.enums import NodeExecutionType, NodeType, WorkflowNodeExecutionStatus from core.workflow.node_events import NodeRunResult -from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig from core.workflow.nodes.base.node import Node from core.workflow.nodes.base.template import Template from core.workflow.nodes.end.entities import EndNodeData -class EndNode(Node): +class EndNode(Node[EndNodeData]): node_type = NodeType.END execution_type = NodeExecutionType.RESPONSE _node_data: EndNodeData - def init_node_data(self, data: Mapping[str, Any]): - self._node_data = EndNodeData.model_validate(data) - - def _get_error_strategy(self) -> ErrorStrategy | None: - return self._node_data.error_strategy - - def _get_retry_config(self) -> RetryConfig: - return self._node_data.retry_config - - def _get_title(self) -> str: - return self._node_data.title - - def _get_description(self) -> str | None: - return self._node_data.desc - - def _get_default_value_dict(self) -> dict[str, Any]: - return self._node_data.default_value_dict - - def get_base_node_data(self) -> BaseNodeData: - return self._node_data - @classmethod def version(cls) -> str: return "1" diff --git a/api/core/workflow/nodes/http_request/node.py b/api/core/workflow/nodes/http_request/node.py index 152d3cc562..3114bc3758 100644 --- a/api/core/workflow/nodes/http_request/node.py +++ b/api/core/workflow/nodes/http_request/node.py @@ -7,10 +7,10 @@ from configs import dify_config from core.file import File, FileTransferMethod from core.tools.tool_file_manager import ToolFileManager from core.variables.segments import ArrayFileSegment -from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus +from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus from core.workflow.node_events import NodeRunResult from core.workflow.nodes.base import variable_template_parser -from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig, VariableSelector +from core.workflow.nodes.base.entities import VariableSelector from core.workflow.nodes.base.node import Node from core.workflow.nodes.http_request.executor import Executor from factories import file_factory @@ -31,32 +31,11 @@ HTTP_REQUEST_DEFAULT_TIMEOUT = HttpRequestNodeTimeout( logger = logging.getLogger(__name__) -class HttpRequestNode(Node): +class HttpRequestNode(Node[HttpRequestNodeData]): node_type = NodeType.HTTP_REQUEST _node_data: HttpRequestNodeData - def init_node_data(self, data: Mapping[str, Any]): - self._node_data = HttpRequestNodeData.model_validate(data) - - def _get_error_strategy(self) -> ErrorStrategy | None: - return self._node_data.error_strategy - - def _get_retry_config(self) -> RetryConfig: - return self._node_data.retry_config - - def _get_title(self) -> str: - return self._node_data.title - - def _get_description(self) -> str | None: - return self._node_data.desc - - def _get_default_value_dict(self) -> dict[str, Any]: - return self._node_data.default_value_dict - - def get_base_node_data(self) -> BaseNodeData: - return self._node_data - @classmethod def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]: return { diff --git a/api/core/workflow/nodes/human_input/human_input_node.py b/api/core/workflow/nodes/human_input/human_input_node.py index c0d64a060a..db2df68f46 100644 --- a/api/core/workflow/nodes/human_input/human_input_node.py +++ b/api/core/workflow/nodes/human_input/human_input_node.py @@ -2,15 +2,14 @@ from collections.abc import Mapping from typing import Any from core.workflow.entities.pause_reason import HumanInputRequired -from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType, WorkflowNodeExecutionStatus +from core.workflow.enums import NodeExecutionType, NodeType, WorkflowNodeExecutionStatus from core.workflow.node_events import NodeRunResult, PauseRequestedEvent -from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig from core.workflow.nodes.base.node import Node from .entities import HumanInputNodeData -class HumanInputNode(Node): +class HumanInputNode(Node[HumanInputNodeData]): node_type = NodeType.HUMAN_INPUT execution_type = NodeExecutionType.BRANCH @@ -28,31 +27,10 @@ class HumanInputNode(Node): _node_data: HumanInputNodeData - def init_node_data(self, data: Mapping[str, Any]) -> None: - self._node_data = HumanInputNodeData(**data) - - def get_base_node_data(self) -> BaseNodeData: - return self._node_data - @classmethod def version(cls) -> str: return "1" - def _get_error_strategy(self) -> ErrorStrategy | None: - return self._node_data.error_strategy - - def _get_retry_config(self) -> RetryConfig: - return self._node_data.retry_config - - def _get_title(self) -> str: - return self._node_data.title - - def _get_description(self) -> str | None: - return self._node_data.desc - - def _get_default_value_dict(self) -> dict[str, Any]: - return self._node_data.default_value_dict - def _run(self): # type: ignore[override] if self._is_completion_ready(): branch_handle = self._resolve_branch_selection() diff --git a/api/core/workflow/nodes/if_else/if_else_node.py b/api/core/workflow/nodes/if_else/if_else_node.py index 165e529714..f4c6e1e190 100644 --- a/api/core/workflow/nodes/if_else/if_else_node.py +++ b/api/core/workflow/nodes/if_else/if_else_node.py @@ -3,9 +3,8 @@ from typing import Any, Literal from typing_extensions import deprecated -from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType, WorkflowNodeExecutionStatus +from core.workflow.enums import NodeExecutionType, NodeType, WorkflowNodeExecutionStatus from core.workflow.node_events import NodeRunResult -from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig from core.workflow.nodes.base.node import Node from core.workflow.nodes.if_else.entities import IfElseNodeData from core.workflow.runtime import VariablePool @@ -13,33 +12,12 @@ from core.workflow.utils.condition.entities import Condition from core.workflow.utils.condition.processor import ConditionProcessor -class IfElseNode(Node): +class IfElseNode(Node[IfElseNodeData]): node_type = NodeType.IF_ELSE execution_type = NodeExecutionType.BRANCH _node_data: IfElseNodeData - def init_node_data(self, data: Mapping[str, Any]): - self._node_data = IfElseNodeData.model_validate(data) - - def _get_error_strategy(self) -> ErrorStrategy | None: - return self._node_data.error_strategy - - def _get_retry_config(self) -> RetryConfig: - return self._node_data.retry_config - - def _get_title(self) -> str: - return self._node_data.title - - def _get_description(self) -> str | None: - return self._node_data.desc - - def _get_default_value_dict(self) -> dict[str, Any]: - return self._node_data.default_value_dict - - def get_base_node_data(self) -> BaseNodeData: - return self._node_data - @classmethod def version(cls) -> str: return "1" diff --git a/api/core/workflow/nodes/iteration/iteration_node.py b/api/core/workflow/nodes/iteration/iteration_node.py index 63e0932a98..9d0a9d48f7 100644 --- a/api/core/workflow/nodes/iteration/iteration_node.py +++ b/api/core/workflow/nodes/iteration/iteration_node.py @@ -14,7 +14,6 @@ from core.variables.segments import ArrayAnySegment, ArraySegment from core.variables.variables import VariableUnion from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID from core.workflow.enums import ( - ErrorStrategy, NodeExecutionType, NodeType, WorkflowNodeExecutionMetadataKey, @@ -36,7 +35,6 @@ from core.workflow.node_events import ( StreamCompletedEvent, ) from core.workflow.nodes.base import LLMUsageTrackingMixin -from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig from core.workflow.nodes.base.node import Node from core.workflow.nodes.iteration.entities import ErrorHandleMode, IterationNodeData from core.workflow.runtime import VariablePool @@ -60,7 +58,7 @@ logger = logging.getLogger(__name__) EmptyArraySegment = NewType("EmptyArraySegment", ArraySegment) -class IterationNode(LLMUsageTrackingMixin, Node): +class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]): """ Iteration Node. """ @@ -69,27 +67,6 @@ class IterationNode(LLMUsageTrackingMixin, Node): execution_type = NodeExecutionType.CONTAINER _node_data: IterationNodeData - def init_node_data(self, data: Mapping[str, Any]): - self._node_data = IterationNodeData.model_validate(data) - - def _get_error_strategy(self) -> ErrorStrategy | None: - return self._node_data.error_strategy - - def _get_retry_config(self) -> RetryConfig: - return self._node_data.retry_config - - def _get_title(self) -> str: - return self._node_data.title - - def _get_description(self) -> str | None: - return self._node_data.desc - - def _get_default_value_dict(self) -> dict[str, Any]: - return self._node_data.default_value_dict - - def get_base_node_data(self) -> BaseNodeData: - return self._node_data - @classmethod def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]: return { diff --git a/api/core/workflow/nodes/iteration/iteration_start_node.py b/api/core/workflow/nodes/iteration/iteration_start_node.py index 90b7f4539b..9767bd8d59 100644 --- a/api/core/workflow/nodes/iteration/iteration_start_node.py +++ b/api/core/workflow/nodes/iteration/iteration_start_node.py @@ -1,14 +1,10 @@ -from collections.abc import Mapping -from typing import Any - -from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus +from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus from core.workflow.node_events import NodeRunResult -from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig from core.workflow.nodes.base.node import Node from core.workflow.nodes.iteration.entities import IterationStartNodeData -class IterationStartNode(Node): +class IterationStartNode(Node[IterationStartNodeData]): """ Iteration Start Node. """ @@ -17,27 +13,6 @@ class IterationStartNode(Node): _node_data: IterationStartNodeData - def init_node_data(self, data: Mapping[str, Any]): - self._node_data = IterationStartNodeData.model_validate(data) - - def _get_error_strategy(self) -> ErrorStrategy | None: - return self._node_data.error_strategy - - def _get_retry_config(self) -> RetryConfig: - return self._node_data.retry_config - - def _get_title(self) -> str: - return self._node_data.title - - def _get_description(self) -> str | None: - return self._node_data.desc - - def _get_default_value_dict(self) -> dict[str, Any]: - return self._node_data.default_value_dict - - def get_base_node_data(self) -> BaseNodeData: - return self._node_data - @classmethod def version(cls) -> str: return "1" diff --git a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py index 2ba1e5e1c5..c222bd9712 100644 --- a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py +++ b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py @@ -10,9 +10,8 @@ from core.app.entities.app_invoke_entities import InvokeFrom from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus -from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType, SystemVariableKey +from core.workflow.enums import NodeExecutionType, NodeType, SystemVariableKey from core.workflow.node_events import NodeRunResult -from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig from core.workflow.nodes.base.node import Node from core.workflow.nodes.base.template import Template from core.workflow.runtime import VariablePool @@ -35,32 +34,11 @@ default_retrieval_model = { } -class KnowledgeIndexNode(Node): +class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]): _node_data: KnowledgeIndexNodeData node_type = NodeType.KNOWLEDGE_INDEX execution_type = NodeExecutionType.RESPONSE - def init_node_data(self, data: Mapping[str, Any]) -> None: - self._node_data = KnowledgeIndexNodeData.model_validate(data) - - def _get_error_strategy(self) -> ErrorStrategy | None: - return self._node_data.error_strategy - - def _get_retry_config(self) -> RetryConfig: - return self._node_data.retry_config - - def _get_title(self) -> str: - return self._node_data.title - - def _get_description(self) -> str | None: - return self._node_data.desc - - def _get_default_value_dict(self) -> dict[str, Any]: - return self._node_data.default_value_dict - - def get_base_node_data(self) -> BaseNodeData: - return self._node_data - def _run(self) -> NodeRunResult: # type: ignore node_data = self._node_data variable_pool = self.graph_runtime_state.variable_pool diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index e8ee44d5a9..99bb058c4b 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -30,14 +30,12 @@ from core.variables import ( from core.variables.segments import ArrayObjectSegment from core.workflow.entities import GraphInitParams from core.workflow.enums import ( - ErrorStrategy, NodeType, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus, ) from core.workflow.node_events import ModelInvokeCompletedEvent, NodeRunResult from core.workflow.nodes.base import LLMUsageTrackingMixin -from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig from core.workflow.nodes.base.node import Node from core.workflow.nodes.knowledge_retrieval.template_prompts import ( METADATA_FILTER_ASSISTANT_PROMPT_1, @@ -82,7 +80,7 @@ default_retrieval_model = { } -class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node): +class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeData]): node_type = NodeType.KNOWLEDGE_RETRIEVAL _node_data: KnowledgeRetrievalNodeData @@ -118,27 +116,6 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node): ) self._llm_file_saver = llm_file_saver - def init_node_data(self, data: Mapping[str, Any]): - self._node_data = KnowledgeRetrievalNodeData.model_validate(data) - - def _get_error_strategy(self) -> ErrorStrategy | None: - return self._node_data.error_strategy - - def _get_retry_config(self) -> RetryConfig: - return self._node_data.retry_config - - def _get_title(self) -> str: - return self._node_data.title - - def _get_description(self) -> str | None: - return self._node_data.desc - - def _get_default_value_dict(self) -> dict[str, Any]: - return self._node_data.default_value_dict - - def get_base_node_data(self) -> BaseNodeData: - return self._node_data - @classmethod def version(cls): return "1" diff --git a/api/core/workflow/nodes/list_operator/node.py b/api/core/workflow/nodes/list_operator/node.py index 54f3ef8a54..ab63951082 100644 --- a/api/core/workflow/nodes/list_operator/node.py +++ b/api/core/workflow/nodes/list_operator/node.py @@ -1,12 +1,11 @@ -from collections.abc import Callable, Mapping, Sequence +from collections.abc import Callable, Sequence from typing import Any, TypeAlias, TypeVar from core.file import File from core.variables import ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment from core.variables.segments import ArrayAnySegment, ArrayBooleanSegment, ArraySegment -from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus +from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus from core.workflow.node_events import NodeRunResult -from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig from core.workflow.nodes.base.node import Node from .entities import FilterOperator, ListOperatorNodeData, Order @@ -35,32 +34,11 @@ def _negation(filter_: Callable[[_T], bool]) -> Callable[[_T], bool]: return wrapper -class ListOperatorNode(Node): +class ListOperatorNode(Node[ListOperatorNodeData]): node_type = NodeType.LIST_OPERATOR _node_data: ListOperatorNodeData - def init_node_data(self, data: Mapping[str, Any]): - self._node_data = ListOperatorNodeData.model_validate(data) - - def _get_error_strategy(self) -> ErrorStrategy | None: - return self._node_data.error_strategy - - def _get_retry_config(self) -> RetryConfig: - return self._node_data.retry_config - - def _get_title(self) -> str: - return self._node_data.title - - def _get_description(self) -> str | None: - return self._node_data.desc - - def _get_default_value_dict(self) -> dict[str, Any]: - return self._node_data.default_value_dict - - def get_base_node_data(self) -> BaseNodeData: - return self._node_data - @classmethod def version(cls) -> str: return "1" diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py index 06c9beaed2..44a9ed95d9 100644 --- a/api/core/workflow/nodes/llm/node.py +++ b/api/core/workflow/nodes/llm/node.py @@ -55,7 +55,6 @@ from core.variables import ( from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID from core.workflow.entities import GraphInitParams from core.workflow.enums import ( - ErrorStrategy, NodeType, SystemVariableKey, WorkflowNodeExecutionMetadataKey, @@ -69,7 +68,7 @@ from core.workflow.node_events import ( StreamChunkEvent, StreamCompletedEvent, ) -from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig, VariableSelector +from core.workflow.nodes.base.entities import VariableSelector from core.workflow.nodes.base.node import Node from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser from core.workflow.runtime import VariablePool @@ -100,7 +99,7 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -class LLMNode(Node): +class LLMNode(Node[LLMNodeData]): node_type = NodeType.LLM _node_data: LLMNodeData @@ -139,27 +138,6 @@ class LLMNode(Node): ) self._llm_file_saver = llm_file_saver - def init_node_data(self, data: Mapping[str, Any]): - self._node_data = LLMNodeData.model_validate(data) - - def _get_error_strategy(self) -> ErrorStrategy | None: - return self._node_data.error_strategy - - def _get_retry_config(self) -> RetryConfig: - return self._node_data.retry_config - - def _get_title(self) -> str: - return self._node_data.title - - def _get_description(self) -> str | None: - return self._node_data.desc - - def _get_default_value_dict(self) -> dict[str, Any]: - return self._node_data.default_value_dict - - def get_base_node_data(self) -> BaseNodeData: - return self._node_data - @classmethod def version(cls) -> str: return "1" diff --git a/api/core/workflow/nodes/loop/loop_end_node.py b/api/core/workflow/nodes/loop/loop_end_node.py index e5bce1230c..bdcae5c6fb 100644 --- a/api/core/workflow/nodes/loop/loop_end_node.py +++ b/api/core/workflow/nodes/loop/loop_end_node.py @@ -1,14 +1,10 @@ -from collections.abc import Mapping -from typing import Any - -from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus +from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus from core.workflow.node_events import NodeRunResult -from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig from core.workflow.nodes.base.node import Node from core.workflow.nodes.loop.entities import LoopEndNodeData -class LoopEndNode(Node): +class LoopEndNode(Node[LoopEndNodeData]): """ Loop End Node. """ @@ -17,27 +13,6 @@ class LoopEndNode(Node): _node_data: LoopEndNodeData - def init_node_data(self, data: Mapping[str, Any]): - self._node_data = LoopEndNodeData.model_validate(data) - - def _get_error_strategy(self) -> ErrorStrategy | None: - return self._node_data.error_strategy - - def _get_retry_config(self) -> RetryConfig: - return self._node_data.retry_config - - def _get_title(self) -> str: - return self._node_data.title - - def _get_description(self) -> str | None: - return self._node_data.desc - - def _get_default_value_dict(self) -> dict[str, Any]: - return self._node_data.default_value_dict - - def get_base_node_data(self) -> BaseNodeData: - return self._node_data - @classmethod def version(cls) -> str: return "1" diff --git a/api/core/workflow/nodes/loop/loop_node.py b/api/core/workflow/nodes/loop/loop_node.py index 60baed1ed5..ce7245952c 100644 --- a/api/core/workflow/nodes/loop/loop_node.py +++ b/api/core/workflow/nodes/loop/loop_node.py @@ -8,7 +8,6 @@ from typing import TYPE_CHECKING, Any, Literal, cast from core.model_runtime.entities.llm_entities import LLMUsage from core.variables import Segment, SegmentType from core.workflow.enums import ( - ErrorStrategy, NodeExecutionType, NodeType, WorkflowNodeExecutionMetadataKey, @@ -29,7 +28,6 @@ from core.workflow.node_events import ( StreamCompletedEvent, ) from core.workflow.nodes.base import LLMUsageTrackingMixin -from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig from core.workflow.nodes.base.node import Node from core.workflow.nodes.loop.entities import LoopNodeData, LoopVariableData from core.workflow.utils.condition.processor import ConditionProcessor @@ -42,7 +40,7 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -class LoopNode(LLMUsageTrackingMixin, Node): +class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]): """ Loop Node. """ @@ -51,27 +49,6 @@ class LoopNode(LLMUsageTrackingMixin, Node): _node_data: LoopNodeData execution_type = NodeExecutionType.CONTAINER - def init_node_data(self, data: Mapping[str, Any]): - self._node_data = LoopNodeData.model_validate(data) - - def _get_error_strategy(self) -> ErrorStrategy | None: - return self._node_data.error_strategy - - def _get_retry_config(self) -> RetryConfig: - return self._node_data.retry_config - - def _get_title(self) -> str: - return self._node_data.title - - def _get_description(self) -> str | None: - return self._node_data.desc - - def _get_default_value_dict(self) -> dict[str, Any]: - return self._node_data.default_value_dict - - def get_base_node_data(self) -> BaseNodeData: - return self._node_data - @classmethod def version(cls) -> str: return "1" diff --git a/api/core/workflow/nodes/loop/loop_start_node.py b/api/core/workflow/nodes/loop/loop_start_node.py index e065dc90a0..f9df4fa3a6 100644 --- a/api/core/workflow/nodes/loop/loop_start_node.py +++ b/api/core/workflow/nodes/loop/loop_start_node.py @@ -1,14 +1,10 @@ -from collections.abc import Mapping -from typing import Any - -from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus +from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus from core.workflow.node_events import NodeRunResult -from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig from core.workflow.nodes.base.node import Node from core.workflow.nodes.loop.entities import LoopStartNodeData -class LoopStartNode(Node): +class LoopStartNode(Node[LoopStartNodeData]): """ Loop Start Node. """ @@ -17,27 +13,6 @@ class LoopStartNode(Node): _node_data: LoopStartNodeData - def init_node_data(self, data: Mapping[str, Any]): - self._node_data = LoopStartNodeData.model_validate(data) - - def _get_error_strategy(self) -> ErrorStrategy | None: - return self._node_data.error_strategy - - def _get_retry_config(self) -> RetryConfig: - return self._node_data.retry_config - - def _get_title(self) -> str: - return self._node_data.title - - def _get_description(self) -> str | None: - return self._node_data.desc - - def _get_default_value_dict(self) -> dict[str, Any]: - return self._node_data.default_value_dict - - def get_base_node_data(self) -> BaseNodeData: - return self._node_data - @classmethod def version(cls) -> str: return "1" diff --git a/api/core/workflow/nodes/node_factory.py b/api/core/workflow/nodes/node_factory.py index 84f63d57eb..5fc363257b 100644 --- a/api/core/workflow/nodes/node_factory.py +++ b/api/core/workflow/nodes/node_factory.py @@ -69,17 +69,9 @@ class DifyNodeFactory(NodeFactory): raise ValueError(f"No latest version class found for node type: {node_type}") # Create node instance - node_instance = node_class( + return node_class( id=node_id, config=node_config, graph_init_params=self.graph_init_params, graph_runtime_state=self.graph_runtime_state, ) - - # Initialize node with provided data - node_data = node_config.get("data", {}) - if not is_str_dict(node_data): - raise ValueError(f"Node {node_id} missing data information") - node_instance.init_node_data(node_data) - - return node_instance diff --git a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py index e250650fef..e053e6c4a3 100644 --- a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py +++ b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py @@ -27,10 +27,9 @@ from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, Comp from core.prompt.simple_prompt_transform import ModelMode from core.prompt.utils.prompt_message_util import PromptMessageUtil from core.variables.types import ArrayValidation, SegmentType -from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus +from core.workflow.enums import NodeType, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus from core.workflow.node_events import NodeRunResult from core.workflow.nodes.base import variable_template_parser -from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig from core.workflow.nodes.base.node import Node from core.workflow.nodes.llm import ModelConfig, llm_utils from core.workflow.runtime import VariablePool @@ -84,7 +83,7 @@ def extract_json(text): return None -class ParameterExtractorNode(Node): +class ParameterExtractorNode(Node[ParameterExtractorNodeData]): """ Parameter Extractor Node. """ @@ -93,27 +92,6 @@ class ParameterExtractorNode(Node): _node_data: ParameterExtractorNodeData - def init_node_data(self, data: Mapping[str, Any]): - self._node_data = ParameterExtractorNodeData.model_validate(data) - - def _get_error_strategy(self) -> ErrorStrategy | None: - return self._node_data.error_strategy - - def _get_retry_config(self) -> RetryConfig: - return self._node_data.retry_config - - def _get_title(self) -> str: - return self._node_data.title - - def _get_description(self) -> str | None: - return self._node_data.desc - - def _get_default_value_dict(self) -> dict[str, Any]: - return self._node_data.default_value_dict - - def get_base_node_data(self) -> BaseNodeData: - return self._node_data - _model_instance: ModelInstance | None = None _model_config: ModelConfigWithCredentialsEntity | None = None diff --git a/api/core/workflow/nodes/question_classifier/question_classifier_node.py b/api/core/workflow/nodes/question_classifier/question_classifier_node.py index 948a1cead7..36a692d109 100644 --- a/api/core/workflow/nodes/question_classifier/question_classifier_node.py +++ b/api/core/workflow/nodes/question_classifier/question_classifier_node.py @@ -13,14 +13,13 @@ from core.prompt.simple_prompt_transform import ModelMode from core.prompt.utils.prompt_message_util import PromptMessageUtil from core.workflow.entities import GraphInitParams from core.workflow.enums import ( - ErrorStrategy, NodeExecutionType, NodeType, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus, ) from core.workflow.node_events import ModelInvokeCompletedEvent, NodeRunResult -from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig, VariableSelector +from core.workflow.nodes.base.entities import VariableSelector from core.workflow.nodes.base.node import Node from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser from core.workflow.nodes.llm import LLMNode, LLMNodeChatModelMessage, LLMNodeCompletionModelPromptTemplate, llm_utils @@ -44,7 +43,7 @@ if TYPE_CHECKING: from core.workflow.runtime import GraphRuntimeState -class QuestionClassifierNode(Node): +class QuestionClassifierNode(Node[QuestionClassifierNodeData]): node_type = NodeType.QUESTION_CLASSIFIER execution_type = NodeExecutionType.BRANCH @@ -78,27 +77,6 @@ class QuestionClassifierNode(Node): ) self._llm_file_saver = llm_file_saver - def init_node_data(self, data: Mapping[str, Any]): - self._node_data = QuestionClassifierNodeData.model_validate(data) - - def _get_error_strategy(self) -> ErrorStrategy | None: - return self._node_data.error_strategy - - def _get_retry_config(self) -> RetryConfig: - return self._node_data.retry_config - - def _get_title(self) -> str: - return self._node_data.title - - def _get_description(self) -> str | None: - return self._node_data.desc - - def _get_default_value_dict(self) -> dict[str, Any]: - return self._node_data.default_value_dict - - def get_base_node_data(self) -> BaseNodeData: - return self._node_data - @classmethod def version(cls): return "1" diff --git a/api/core/workflow/nodes/start/start_node.py b/api/core/workflow/nodes/start/start_node.py index 3b134be1a1..634d6abd09 100644 --- a/api/core/workflow/nodes/start/start_node.py +++ b/api/core/workflow/nodes/start/start_node.py @@ -1,41 +1,16 @@ -from collections.abc import Mapping -from typing import Any - from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID -from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType, WorkflowNodeExecutionStatus +from core.workflow.enums import NodeExecutionType, NodeType, WorkflowNodeExecutionStatus from core.workflow.node_events import NodeRunResult -from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig from core.workflow.nodes.base.node import Node from core.workflow.nodes.start.entities import StartNodeData -class StartNode(Node): +class StartNode(Node[StartNodeData]): node_type = NodeType.START execution_type = NodeExecutionType.ROOT _node_data: StartNodeData - def init_node_data(self, data: Mapping[str, Any]): - self._node_data = StartNodeData.model_validate(data) - - def _get_error_strategy(self) -> ErrorStrategy | None: - return self._node_data.error_strategy - - def _get_retry_config(self) -> RetryConfig: - return self._node_data.retry_config - - def _get_title(self) -> str: - return self._node_data.title - - def _get_description(self) -> str | None: - return self._node_data.desc - - def _get_default_value_dict(self) -> dict[str, Any]: - return self._node_data.default_value_dict - - def get_base_node_data(self) -> BaseNodeData: - return self._node_data - @classmethod def version(cls) -> str: return "1" diff --git a/api/core/workflow/nodes/template_transform/template_transform_node.py b/api/core/workflow/nodes/template_transform/template_transform_node.py index 254a8318b5..917680c428 100644 --- a/api/core/workflow/nodes/template_transform/template_transform_node.py +++ b/api/core/workflow/nodes/template_transform/template_transform_node.py @@ -3,41 +3,19 @@ from typing import Any from configs import dify_config from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage -from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus +from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus from core.workflow.node_events import NodeRunResult -from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig from core.workflow.nodes.base.node import Node from core.workflow.nodes.template_transform.entities import TemplateTransformNodeData MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH = dify_config.TEMPLATE_TRANSFORM_MAX_LENGTH -class TemplateTransformNode(Node): +class TemplateTransformNode(Node[TemplateTransformNodeData]): node_type = NodeType.TEMPLATE_TRANSFORM _node_data: TemplateTransformNodeData - def init_node_data(self, data: Mapping[str, Any]): - self._node_data = TemplateTransformNodeData.model_validate(data) - - def _get_error_strategy(self) -> ErrorStrategy | None: - return self._node_data.error_strategy - - def _get_retry_config(self) -> RetryConfig: - return self._node_data.retry_config - - def _get_title(self) -> str: - return self._node_data.title - - def _get_description(self) -> str | None: - return self._node_data.desc - - def _get_default_value_dict(self) -> dict[str, Any]: - return self._node_data.default_value_dict - - def get_base_node_data(self) -> BaseNodeData: - return self._node_data - @classmethod def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]: """ diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py index 4f8dcb92ba..2a92292781 100644 --- a/api/core/workflow/nodes/tool/tool_node.py +++ b/api/core/workflow/nodes/tool/tool_node.py @@ -16,14 +16,12 @@ from core.tools.workflow_as_tool.tool import WorkflowTool from core.variables.segments import ArrayAnySegment, ArrayFileSegment from core.variables.variables import ArrayAnyVariable from core.workflow.enums import ( - ErrorStrategy, NodeType, SystemVariableKey, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus, ) from core.workflow.node_events import NodeEventBase, NodeRunResult, StreamChunkEvent, StreamCompletedEvent -from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig from core.workflow.nodes.base.node import Node from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser from extensions.ext_database import db @@ -42,7 +40,7 @@ if TYPE_CHECKING: from core.workflow.runtime import VariablePool -class ToolNode(Node): +class ToolNode(Node[ToolNodeData]): """ Tool Node """ @@ -51,9 +49,6 @@ class ToolNode(Node): _node_data: ToolNodeData - def init_node_data(self, data: Mapping[str, Any]): - self._node_data = ToolNodeData.model_validate(data) - @classmethod def version(cls) -> str: return "1" @@ -498,24 +493,6 @@ class ToolNode(Node): return result - def _get_error_strategy(self) -> ErrorStrategy | None: - return self._node_data.error_strategy - - def _get_retry_config(self) -> RetryConfig: - return self._node_data.retry_config - - def _get_title(self) -> str: - return self._node_data.title - - def _get_description(self) -> str | None: - return self._node_data.desc - - def _get_default_value_dict(self) -> dict[str, Any]: - return self._node_data.default_value_dict - - def get_base_node_data(self) -> BaseNodeData: - return self._node_data - @property def retry(self) -> bool: return self._node_data.retry_config.retry_enabled diff --git a/api/core/workflow/nodes/trigger_plugin/trigger_event_node.py b/api/core/workflow/nodes/trigger_plugin/trigger_event_node.py index c4c2ff87db..d745c06522 100644 --- a/api/core/workflow/nodes/trigger_plugin/trigger_event_node.py +++ b/api/core/workflow/nodes/trigger_plugin/trigger_event_node.py @@ -1,43 +1,18 @@ from collections.abc import Mapping -from typing import Any from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus -from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType +from core.workflow.enums import NodeExecutionType, NodeType from core.workflow.node_events import NodeRunResult -from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig from core.workflow.nodes.base.node import Node from .entities import TriggerEventNodeData -class TriggerEventNode(Node): +class TriggerEventNode(Node[TriggerEventNodeData]): node_type = NodeType.TRIGGER_PLUGIN execution_type = NodeExecutionType.ROOT - _node_data: TriggerEventNodeData - - def init_node_data(self, data: Mapping[str, Any]) -> None: - self._node_data = TriggerEventNodeData.model_validate(data) - - def _get_error_strategy(self) -> ErrorStrategy | None: - return self._node_data.error_strategy - - def _get_retry_config(self) -> RetryConfig: - return self._node_data.retry_config - - def _get_title(self) -> str: - return self._node_data.title - - def _get_description(self) -> str | None: - return self._node_data.desc - - def _get_default_value_dict(self) -> dict[str, Any]: - return self._node_data.default_value_dict - - def get_base_node_data(self) -> BaseNodeData: - return self._node_data - @classmethod def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]: return { diff --git a/api/core/workflow/nodes/trigger_schedule/trigger_schedule_node.py b/api/core/workflow/nodes/trigger_schedule/trigger_schedule_node.py index 98a841d1be..fb5c8a4dce 100644 --- a/api/core/workflow/nodes/trigger_schedule/trigger_schedule_node.py +++ b/api/core/workflow/nodes/trigger_schedule/trigger_schedule_node.py @@ -1,42 +1,17 @@ from collections.abc import Mapping -from typing import Any from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus -from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType +from core.workflow.enums import NodeExecutionType, NodeType from core.workflow.node_events import NodeRunResult -from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig from core.workflow.nodes.base.node import Node from core.workflow.nodes.trigger_schedule.entities import TriggerScheduleNodeData -class TriggerScheduleNode(Node): +class TriggerScheduleNode(Node[TriggerScheduleNodeData]): node_type = NodeType.TRIGGER_SCHEDULE execution_type = NodeExecutionType.ROOT - _node_data: TriggerScheduleNodeData - - def init_node_data(self, data: Mapping[str, Any]) -> None: - self._node_data = TriggerScheduleNodeData(**data) - - def _get_error_strategy(self) -> ErrorStrategy | None: - return self._node_data.error_strategy - - def _get_retry_config(self) -> RetryConfig: - return self._node_data.retry_config - - def _get_title(self) -> str: - return self._node_data.title - - def _get_description(self) -> str | None: - return self._node_data.desc - - def _get_default_value_dict(self) -> dict[str, Any]: - return self._node_data.default_value_dict - - def get_base_node_data(self) -> BaseNodeData: - return self._node_data - @classmethod def version(cls) -> str: return "1" diff --git a/api/core/workflow/nodes/trigger_webhook/node.py b/api/core/workflow/nodes/trigger_webhook/node.py index 15009f90d0..4bc6a82349 100644 --- a/api/core/workflow/nodes/trigger_webhook/node.py +++ b/api/core/workflow/nodes/trigger_webhook/node.py @@ -3,41 +3,17 @@ from typing import Any from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus -from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType +from core.workflow.enums import NodeExecutionType, NodeType from core.workflow.node_events import NodeRunResult -from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig from core.workflow.nodes.base.node import Node from .entities import ContentType, WebhookData -class TriggerWebhookNode(Node): +class TriggerWebhookNode(Node[WebhookData]): node_type = NodeType.TRIGGER_WEBHOOK execution_type = NodeExecutionType.ROOT - _node_data: WebhookData - - def init_node_data(self, data: Mapping[str, Any]) -> None: - self._node_data = WebhookData.model_validate(data) - - def _get_error_strategy(self) -> ErrorStrategy | None: - return self._node_data.error_strategy - - def _get_retry_config(self) -> RetryConfig: - return self._node_data.retry_config - - def _get_title(self) -> str: - return self._node_data.title - - def _get_description(self) -> str | None: - return self._node_data.desc - - def _get_default_value_dict(self) -> dict[str, Any]: - return self._node_data.default_value_dict - - def get_base_node_data(self) -> BaseNodeData: - return self._node_data - @classmethod def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]: return { diff --git a/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py b/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py index 0ac0d3d858..679e001e79 100644 --- a/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py +++ b/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py @@ -1,40 +1,17 @@ from collections.abc import Mapping -from typing import Any from core.variables.segments import Segment -from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus +from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus from core.workflow.node_events import NodeRunResult -from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig from core.workflow.nodes.base.node import Node from core.workflow.nodes.variable_aggregator.entities import VariableAssignerNodeData -class VariableAggregatorNode(Node): +class VariableAggregatorNode(Node[VariableAssignerNodeData]): node_type = NodeType.VARIABLE_AGGREGATOR _node_data: VariableAssignerNodeData - def init_node_data(self, data: Mapping[str, Any]): - self._node_data = VariableAssignerNodeData.model_validate(data) - - def _get_error_strategy(self) -> ErrorStrategy | None: - return self._node_data.error_strategy - - def _get_retry_config(self) -> RetryConfig: - return self._node_data.retry_config - - def _get_title(self) -> str: - return self._node_data.title - - def _get_description(self) -> str | None: - return self._node_data.desc - - def _get_default_value_dict(self) -> dict[str, Any]: - return self._node_data.default_value_dict - - def get_base_node_data(self) -> BaseNodeData: - return self._node_data - @classmethod def version(cls) -> str: return "1" diff --git a/api/core/workflow/nodes/variable_assigner/v1/node.py b/api/core/workflow/nodes/variable_assigner/v1/node.py index 3a0793f092..f07b5760fd 100644 --- a/api/core/workflow/nodes/variable_assigner/v1/node.py +++ b/api/core/workflow/nodes/variable_assigner/v1/node.py @@ -5,9 +5,8 @@ from core.variables import SegmentType, Variable from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID from core.workflow.conversation_variable_updater import ConversationVariableUpdater from core.workflow.entities import GraphInitParams -from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus +from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus from core.workflow.node_events import NodeRunResult -from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig from core.workflow.nodes.base.node import Node from core.workflow.nodes.variable_assigner.common import helpers as common_helpers from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError @@ -22,33 +21,12 @@ if TYPE_CHECKING: _CONV_VAR_UPDATER_FACTORY: TypeAlias = Callable[[], ConversationVariableUpdater] -class VariableAssignerNode(Node): +class VariableAssignerNode(Node[VariableAssignerData]): node_type = NodeType.VARIABLE_ASSIGNER _conv_var_updater_factory: _CONV_VAR_UPDATER_FACTORY _node_data: VariableAssignerData - def init_node_data(self, data: Mapping[str, Any]): - self._node_data = VariableAssignerData.model_validate(data) - - def _get_error_strategy(self) -> ErrorStrategy | None: - return self._node_data.error_strategy - - def _get_retry_config(self) -> RetryConfig: - return self._node_data.retry_config - - def _get_title(self) -> str: - return self._node_data.title - - def _get_description(self) -> str | None: - return self._node_data.desc - - def _get_default_value_dict(self) -> dict[str, Any]: - return self._node_data.default_value_dict - - def get_base_node_data(self) -> BaseNodeData: - return self._node_data - def __init__( self, id: str, diff --git a/api/core/workflow/nodes/variable_assigner/v2/node.py b/api/core/workflow/nodes/variable_assigner/v2/node.py index f15924d78f..e7150393d5 100644 --- a/api/core/workflow/nodes/variable_assigner/v2/node.py +++ b/api/core/workflow/nodes/variable_assigner/v2/node.py @@ -7,9 +7,8 @@ from core.variables import SegmentType, Variable from core.variables.consts import SELECTORS_LENGTH from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID from core.workflow.conversation_variable_updater import ConversationVariableUpdater -from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus +from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus from core.workflow.node_events import NodeRunResult -from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig from core.workflow.nodes.base.node import Node from core.workflow.nodes.variable_assigner.common import helpers as common_helpers from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError @@ -51,32 +50,11 @@ def _source_mapping_from_item(mapping: MutableMapping[str, Sequence[str]], node_ mapping[key] = selector -class VariableAssignerNode(Node): +class VariableAssignerNode(Node[VariableAssignerNodeData]): node_type = NodeType.VARIABLE_ASSIGNER _node_data: VariableAssignerNodeData - def init_node_data(self, data: Mapping[str, Any]): - self._node_data = VariableAssignerNodeData.model_validate(data) - - def _get_error_strategy(self) -> ErrorStrategy | None: - return self._node_data.error_strategy - - def _get_retry_config(self) -> RetryConfig: - return self._node_data.retry_config - - def _get_title(self) -> str: - return self._node_data.title - - def _get_description(self) -> str | None: - return self._node_data.desc - - def _get_default_value_dict(self) -> dict[str, Any]: - return self._node_data.default_value_dict - - def get_base_node_data(self) -> BaseNodeData: - return self._node_data - def blocks_variable_output(self, variable_selectors: set[tuple[str, ...]]) -> bool: """ Check if this Variable Assigner node blocks the output of specific variables. diff --git a/api/core/workflow/workflow_entry.py b/api/core/workflow/workflow_entry.py index a6c6784e39..d4ec29518a 100644 --- a/api/core/workflow/workflow_entry.py +++ b/api/core/workflow/workflow_entry.py @@ -159,7 +159,6 @@ class WorkflowEntry: graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, ) - node.init_node_data(node_config_data) try: # variable selector to variable mapping @@ -303,7 +302,6 @@ class WorkflowEntry: graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, ) - node.init_node_data(node_data) try: # variable selector to variable mapping diff --git a/api/tests/integration_tests/workflow/nodes/test_code.py b/api/tests/integration_tests/workflow/nodes/test_code.py index 78878cdeef..e421e4ff36 100644 --- a/api/tests/integration_tests/workflow/nodes/test_code.py +++ b/api/tests/integration_tests/workflow/nodes/test_code.py @@ -69,10 +69,6 @@ def init_code_node(code_config: dict): graph_runtime_state=graph_runtime_state, ) - # Initialize node data - if "data" in code_config: - node.init_node_data(code_config["data"]) - return node diff --git a/api/tests/integration_tests/workflow/nodes/test_http.py b/api/tests/integration_tests/workflow/nodes/test_http.py index 2367990d3e..e75258a2a2 100644 --- a/api/tests/integration_tests/workflow/nodes/test_http.py +++ b/api/tests/integration_tests/workflow/nodes/test_http.py @@ -65,10 +65,6 @@ def init_http_node(config: dict): graph_runtime_state=graph_runtime_state, ) - # Initialize node data - if "data" in config: - node.init_node_data(config["data"]) - return node @@ -709,10 +705,6 @@ def test_nested_object_variable_selector(setup_http_mock): graph_runtime_state=graph_runtime_state, ) - # Initialize node data - if "data" in graph_config["nodes"][1]: - node.init_node_data(graph_config["nodes"][1]["data"]) - result = node._run() assert result.process_data is not None data = result.process_data.get("request", "") diff --git a/api/tests/integration_tests/workflow/nodes/test_llm.py b/api/tests/integration_tests/workflow/nodes/test_llm.py index 3b16c3920b..d268c5da22 100644 --- a/api/tests/integration_tests/workflow/nodes/test_llm.py +++ b/api/tests/integration_tests/workflow/nodes/test_llm.py @@ -82,10 +82,6 @@ def init_llm_node(config: dict) -> LLMNode: graph_runtime_state=graph_runtime_state, ) - # Initialize node data - if "data" in config: - node.init_node_data(config["data"]) - return node diff --git a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py index 9d9102cee2..654db59bec 100644 --- a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py +++ b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py @@ -85,7 +85,6 @@ def init_parameter_extractor_node(config: dict): graph_init_params=init_params, graph_runtime_state=graph_runtime_state, ) - node.init_node_data(config.get("data", {})) return node diff --git a/api/tests/integration_tests/workflow/nodes/test_template_transform.py b/api/tests/integration_tests/workflow/nodes/test_template_transform.py index 285387b817..3bcb9a3a34 100644 --- a/api/tests/integration_tests/workflow/nodes/test_template_transform.py +++ b/api/tests/integration_tests/workflow/nodes/test_template_transform.py @@ -82,7 +82,6 @@ def test_execute_code(setup_code_executor_mock): graph_init_params=init_params, graph_runtime_state=graph_runtime_state, ) - node.init_node_data(config.get("data", {})) # execute node result = node._run() diff --git a/api/tests/integration_tests/workflow/nodes/test_tool.py b/api/tests/integration_tests/workflow/nodes/test_tool.py index 8dd8150b1c..d666f0ebe2 100644 --- a/api/tests/integration_tests/workflow/nodes/test_tool.py +++ b/api/tests/integration_tests/workflow/nodes/test_tool.py @@ -62,7 +62,6 @@ def init_tool_node(config: dict): graph_init_params=init_params, graph_runtime_state=graph_runtime_state, ) - node.init_node_data(config.get("data", {})) return node diff --git a/api/tests/unit_tests/core/workflow/graph/test_graph_validation.py b/api/tests/unit_tests/core/workflow/graph/test_graph_validation.py index 0f62a11684..2597a3d65a 100644 --- a/api/tests/unit_tests/core/workflow/graph/test_graph_validation.py +++ b/api/tests/unit_tests/core/workflow/graph/test_graph_validation.py @@ -3,7 +3,6 @@ from __future__ import annotations import time from collections.abc import Mapping from dataclasses import dataclass -from typing import Any import pytest @@ -12,14 +11,19 @@ from core.workflow.entities import GraphInitParams from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType from core.workflow.graph import Graph from core.workflow.graph.validation import GraphValidationError -from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig +from core.workflow.nodes.base.entities import BaseNodeData from core.workflow.nodes.base.node import Node from core.workflow.runtime import GraphRuntimeState, VariablePool from core.workflow.system_variable import SystemVariable from models.enums import UserFrom -class _TestNode(Node): +class _TestNodeData(BaseNodeData): + type: NodeType | str | None = None + execution_type: NodeExecutionType | str | None = None + + +class _TestNode(Node[_TestNodeData]): node_type = NodeType.ANSWER execution_type = NodeExecutionType.EXECUTABLE @@ -41,31 +45,8 @@ class _TestNode(Node): graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, ) - data = config.get("data", {}) - if isinstance(data, Mapping): - execution_type = data.get("execution_type") - if isinstance(execution_type, str): - self.execution_type = NodeExecutionType(execution_type) - self._base_node_data = BaseNodeData(title=str(data.get("title", self.id))) - self.data: dict[str, object] = {} - def init_node_data(self, data: Mapping[str, object]) -> None: - title = str(data.get("title", self.id)) - desc = data.get("description") - error_strategy_value = data.get("error_strategy") - error_strategy: ErrorStrategy | None = None - if isinstance(error_strategy_value, ErrorStrategy): - error_strategy = error_strategy_value - elif isinstance(error_strategy_value, str): - error_strategy = ErrorStrategy(error_strategy_value) - self._base_node_data = BaseNodeData( - title=title, - desc=str(desc) if desc is not None else None, - error_strategy=error_strategy, - ) - self.data = dict(data) - - node_type_value = data.get("type") + node_type_value = self.data.get("type") if isinstance(node_type_value, NodeType): self.node_type = node_type_value elif isinstance(node_type_value, str): @@ -77,23 +58,19 @@ class _TestNode(Node): def _run(self): raise NotImplementedError - def _get_error_strategy(self) -> ErrorStrategy | None: - return self._base_node_data.error_strategy + def post_init(self) -> None: + super().post_init() + self._maybe_override_execution_type() + self.data = dict(self.node_data.model_dump()) - def _get_retry_config(self) -> RetryConfig: - return self._base_node_data.retry_config - - def _get_title(self) -> str: - return self._base_node_data.title - - def _get_description(self) -> str | None: - return self._base_node_data.desc - - def _get_default_value_dict(self) -> dict[str, Any]: - return self._base_node_data.default_value_dict - - def get_base_node_data(self) -> BaseNodeData: - return self._base_node_data + def _maybe_override_execution_type(self) -> None: + execution_type_value = self.node_data.execution_type + if execution_type_value is None: + return + if isinstance(execution_type_value, NodeExecutionType): + self.execution_type = execution_type_value + else: + self.execution_type = NodeExecutionType(execution_type_value) @dataclass(slots=True) @@ -109,7 +86,6 @@ class _SimpleNodeFactory: graph_init_params=self.graph_init_params, graph_runtime_state=self.graph_runtime_state, ) - node.init_node_data(node_config.get("data", {})) return node diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_command_system.py b/api/tests/unit_tests/core/workflow/graph_engine/test_command_system.py index 5d958803bc..b074a11be9 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_command_system.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_command_system.py @@ -32,7 +32,7 @@ def test_abort_command(): # Create mock nodes with required attributes - using shared runtime state start_node = StartNode( id="start", - config={"id": "start"}, + config={"id": "start", "data": {"title": "start", "variables": []}}, graph_init_params=GraphInitParams( tenant_id="test_tenant", app_id="test_app", @@ -45,7 +45,6 @@ def test_abort_command(): ), graph_runtime_state=shared_runtime_state, ) - start_node.init_node_data({"title": "start", "variables": []}) mock_graph.nodes["start"] = start_node # Mock graph methods @@ -142,7 +141,7 @@ def test_pause_command(): start_node = StartNode( id="start", - config={"id": "start"}, + config={"id": "start", "data": {"title": "start", "variables": []}}, graph_init_params=GraphInitParams( tenant_id="test_tenant", app_id="test_app", @@ -155,7 +154,6 @@ def test_pause_command(): ), graph_runtime_state=shared_runtime_state, ) - start_node.init_node_data({"title": "start", "variables": []}) mock_graph.nodes["start"] = start_node mock_graph.get_outgoing_edges = MagicMock(return_value=[]) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_multi_branch.py b/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_multi_branch.py index c9e7e31e52..1c50318af6 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_multi_branch.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_multi_branch.py @@ -63,7 +63,6 @@ def _build_branching_graph(mock_config: MockConfig) -> tuple[Graph, GraphRuntime graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, ) - start_node.init_node_data(start_config["data"]) def _create_llm_node(node_id: str, title: str, prompt_text: str) -> MockLLMNode: llm_data = LLMNodeData( @@ -88,7 +87,6 @@ def _build_branching_graph(mock_config: MockConfig) -> tuple[Graph, GraphRuntime graph_runtime_state=graph_runtime_state, mock_config=mock_config, ) - llm_node.init_node_data(llm_config["data"]) return llm_node llm_initial = _create_llm_node("llm_initial", "Initial LLM", "Initial stream") @@ -105,7 +103,6 @@ def _build_branching_graph(mock_config: MockConfig) -> tuple[Graph, GraphRuntime graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, ) - human_node.init_node_data(human_config["data"]) llm_primary = _create_llm_node("llm_primary", "Primary LLM", "Primary stream output") llm_secondary = _create_llm_node("llm_secondary", "Secondary LLM", "Secondary") @@ -125,7 +122,6 @@ def _build_branching_graph(mock_config: MockConfig) -> tuple[Graph, GraphRuntime graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, ) - end_primary.init_node_data(end_primary_config["data"]) end_secondary_data = EndNodeData( title="End Secondary", @@ -142,7 +138,6 @@ def _build_branching_graph(mock_config: MockConfig) -> tuple[Graph, GraphRuntime graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, ) - end_secondary.init_node_data(end_secondary_config["data"]) graph = ( Graph.new() diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_single_branch.py b/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_single_branch.py index 27d264365d..d7de18172b 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_single_branch.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_single_branch.py @@ -62,7 +62,6 @@ def _build_llm_human_llm_graph(mock_config: MockConfig) -> tuple[Graph, GraphRun graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, ) - start_node.init_node_data(start_config["data"]) def _create_llm_node(node_id: str, title: str, prompt_text: str) -> MockLLMNode: llm_data = LLMNodeData( @@ -87,7 +86,6 @@ def _build_llm_human_llm_graph(mock_config: MockConfig) -> tuple[Graph, GraphRun graph_runtime_state=graph_runtime_state, mock_config=mock_config, ) - llm_node.init_node_data(llm_config["data"]) return llm_node llm_first = _create_llm_node("llm_initial", "Initial LLM", "Initial prompt") @@ -104,7 +102,6 @@ def _build_llm_human_llm_graph(mock_config: MockConfig) -> tuple[Graph, GraphRun graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, ) - human_node.init_node_data(human_config["data"]) llm_second = _create_llm_node("llm_resume", "Follow-up LLM", "Follow-up prompt") @@ -123,7 +120,6 @@ def _build_llm_human_llm_graph(mock_config: MockConfig) -> tuple[Graph, GraphRun graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, ) - end_node.init_node_data(end_config["data"]) graph = ( Graph.new() diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_if_else_streaming.py b/api/tests/unit_tests/core/workflow/graph_engine/test_if_else_streaming.py index dfd33f135f..5d2c17b9b4 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_if_else_streaming.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_if_else_streaming.py @@ -62,7 +62,6 @@ def _build_if_else_graph(branch_value: str, mock_config: MockConfig) -> tuple[Gr graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, ) - start_node.init_node_data(start_config["data"]) def _create_llm_node(node_id: str, title: str, prompt_text: str) -> MockLLMNode: llm_data = LLMNodeData( @@ -87,7 +86,6 @@ def _build_if_else_graph(branch_value: str, mock_config: MockConfig) -> tuple[Gr graph_runtime_state=graph_runtime_state, mock_config=mock_config, ) - llm_node.init_node_data(llm_config["data"]) return llm_node llm_initial = _create_llm_node("llm_initial", "Initial LLM", "Initial stream") @@ -118,7 +116,6 @@ def _build_if_else_graph(branch_value: str, mock_config: MockConfig) -> tuple[Gr graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, ) - if_else_node.init_node_data(if_else_config["data"]) llm_primary = _create_llm_node("llm_primary", "Primary LLM", "Primary stream output") llm_secondary = _create_llm_node("llm_secondary", "Secondary LLM", "Secondary") @@ -138,7 +135,6 @@ def _build_if_else_graph(branch_value: str, mock_config: MockConfig) -> tuple[Gr graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, ) - end_primary.init_node_data(end_primary_config["data"]) end_secondary_data = EndNodeData( title="End Secondary", @@ -155,7 +151,6 @@ def _build_if_else_graph(branch_value: str, mock_config: MockConfig) -> tuple[Gr graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, ) - end_secondary.init_node_data(end_secondary_config["data"]) graph = ( Graph.new() diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py index 03de984bd1..eeffdd27fe 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py @@ -111,9 +111,6 @@ class MockNodeFactory(DifyNodeFactory): mock_config=self.mock_config, ) - # Initialize node with provided data - mock_instance.init_node_data(node_data) - return mock_instance # For non-mocked node types, use parent implementation diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_iteration_simple.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_iteration_simple.py index 48fa00f105..1cda6ced31 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_iteration_simple.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_iteration_simple.py @@ -142,6 +142,8 @@ def test_mock_loop_node_preserves_config(): "start_node_id": "node1", "loop_variables": [], "outputs": {}, + "break_conditions": [], + "logical_operator": "and", }, } diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes_template_code.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes_template_code.py index 23274f5981..4fb693a5c2 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes_template_code.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes_template_code.py @@ -63,7 +63,6 @@ class TestMockTemplateTransformNode: graph_runtime_state=graph_runtime_state, mock_config=mock_config, ) - mock_node.init_node_data(node_config["data"]) # Run the node result = mock_node._run() @@ -125,7 +124,6 @@ class TestMockTemplateTransformNode: graph_runtime_state=graph_runtime_state, mock_config=mock_config, ) - mock_node.init_node_data(node_config["data"]) # Run the node result = mock_node._run() @@ -184,7 +182,6 @@ class TestMockTemplateTransformNode: graph_runtime_state=graph_runtime_state, mock_config=mock_config, ) - mock_node.init_node_data(node_config["data"]) # Run the node result = mock_node._run() @@ -246,7 +243,6 @@ class TestMockTemplateTransformNode: graph_runtime_state=graph_runtime_state, mock_config=mock_config, ) - mock_node.init_node_data(node_config["data"]) # Run the node result = mock_node._run() @@ -311,7 +307,6 @@ class TestMockCodeNode: graph_runtime_state=graph_runtime_state, mock_config=mock_config, ) - mock_node.init_node_data(node_config["data"]) # Run the node result = mock_node._run() @@ -376,7 +371,6 @@ class TestMockCodeNode: graph_runtime_state=graph_runtime_state, mock_config=mock_config, ) - mock_node.init_node_data(node_config["data"]) # Run the node result = mock_node._run() @@ -445,7 +439,6 @@ class TestMockCodeNode: graph_runtime_state=graph_runtime_state, mock_config=mock_config, ) - mock_node.init_node_data(node_config["data"]) # Run the node result = mock_node._run() diff --git a/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py index d151bbe015..98d9560e64 100644 --- a/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py +++ b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py @@ -83,9 +83,6 @@ def test_execute_answer(): config=node_config, ) - # Initialize node data - node.init_node_data(node_config["data"]) - # Mock db.session.close() db.session.close = MagicMock() diff --git a/api/tests/unit_tests/core/workflow/nodes/base/test_base_node.py b/api/tests/unit_tests/core/workflow/nodes/base/test_base_node.py index 4b1f224e67..6eead80ac9 100644 --- a/api/tests/unit_tests/core/workflow/nodes/base/test_base_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/base/test_base_node.py @@ -1,4 +1,7 @@ +import pytest + from core.workflow.enums import NodeType +from core.workflow.nodes.base.entities import BaseNodeData from core.workflow.nodes.base.node import Node # Ensures that all node classes are imported. @@ -7,6 +10,12 @@ from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING _ = NODE_TYPE_CLASSES_MAPPING +class _TestNodeData(BaseNodeData): + """Test node data for unit tests.""" + + pass + + def _get_all_subclasses(root: type[Node]) -> list[type[Node]]: subclasses = [] queue = [root] @@ -34,3 +43,79 @@ def test_ensure_subclasses_of_base_node_has_node_type_and_version_method_defined node_type_and_version = (node_type, node_version) assert node_type_and_version not in type_version_set type_version_set.add(node_type_and_version) + + +def test_extract_node_data_type_from_generic_extracts_type(): + """When a class inherits from Node[T], it should extract T.""" + + class _ConcreteNode(Node[_TestNodeData]): + node_type = NodeType.CODE + + @staticmethod + def version() -> str: + return "1" + + result = _ConcreteNode._extract_node_data_type_from_generic() + + assert result is _TestNodeData + + +def test_extract_node_data_type_from_generic_returns_none_for_base_node(): + """The base Node class itself should return None (no generic parameter).""" + result = Node._extract_node_data_type_from_generic() + + assert result is None + + +def test_extract_node_data_type_from_generic_raises_for_non_base_node_data(): + """When generic parameter is not a BaseNodeData subtype, should raise TypeError.""" + with pytest.raises(TypeError, match="must parameterize Node with a BaseNodeData subtype"): + + class _InvalidNode(Node[str]): # type: ignore[type-arg] + pass + + +def test_extract_node_data_type_from_generic_raises_for_non_type(): + """When generic parameter is not a concrete type, should raise TypeError.""" + from typing import TypeVar + + T = TypeVar("T") + + with pytest.raises(TypeError, match="must parameterize Node with a BaseNodeData subtype"): + + class _InvalidNode(Node[T]): # type: ignore[type-arg] + pass + + +def test_init_subclass_raises_without_generic_or_explicit_type(): + """A subclass must either use Node[T] or explicitly set _node_data_type.""" + with pytest.raises(TypeError, match="must inherit from Node\\[T\\] with a BaseNodeData subtype"): + + class _InvalidNode(Node): + pass + + +def test_init_subclass_rejects_explicit_node_data_type_without_generic(): + """Setting _node_data_type explicitly cannot bypass the Node[T] requirement.""" + with pytest.raises(TypeError, match="must inherit from Node\\[T\\] with a BaseNodeData subtype"): + + class _ExplicitNode(Node): + _node_data_type = _TestNodeData + node_type = NodeType.CODE + + @staticmethod + def version() -> str: + return "1" + + +def test_init_subclass_sets_node_data_type_from_generic(): + """Verify that __init_subclass__ sets _node_data_type from the generic parameter.""" + + class _AutoNode(Node[_TestNodeData]): + node_type = NodeType.CODE + + @staticmethod + def version() -> str: + return "1" + + assert _AutoNode._node_data_type is _TestNodeData diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py index 3ffb5c0fdf..77264022bc 100644 --- a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py @@ -111,8 +111,6 @@ def llm_node( graph_runtime_state=graph_runtime_state, llm_file_saver=mock_file_saver, ) - # Initialize node data - node.init_node_data(node_config["data"]) return node @@ -498,8 +496,6 @@ def llm_node_for_multimodal(llm_node_data, graph_init_params, graph_runtime_stat graph_runtime_state=graph_runtime_state, llm_file_saver=mock_file_saver, ) - # Initialize node data - node.init_node_data(node_config["data"]) return node, mock_file_saver diff --git a/api/tests/unit_tests/core/workflow/nodes/test_base_node.py b/api/tests/unit_tests/core/workflow/nodes/test_base_node.py new file mode 100644 index 0000000000..4a57ab2b89 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/nodes/test_base_node.py @@ -0,0 +1,74 @@ +from collections.abc import Mapping + +import pytest + +from core.workflow.entities import GraphInitParams +from core.workflow.enums import NodeType +from core.workflow.nodes.base.entities import BaseNodeData +from core.workflow.nodes.base.node import Node +from core.workflow.runtime import GraphRuntimeState, VariablePool +from core.workflow.system_variable import SystemVariable + + +class _SampleNodeData(BaseNodeData): + foo: str + + +class _SampleNode(Node[_SampleNodeData]): + node_type = NodeType.ANSWER + + @classmethod + def version(cls) -> str: + return "sample-test" + + def _run(self): + raise NotImplementedError + + +def _build_context(graph_config: Mapping[str, object]) -> tuple[GraphInitParams, GraphRuntimeState]: + init_params = GraphInitParams( + tenant_id="tenant", + app_id="app", + workflow_id="workflow", + graph_config=graph_config, + user_id="user", + user_from="account", + invoke_from="debugger", + call_depth=0, + ) + runtime_state = GraphRuntimeState( + variable_pool=VariablePool(system_variables=SystemVariable(user_id="user", files=[]), user_inputs={}), + start_at=0.0, + ) + return init_params, runtime_state + + +def test_node_hydrates_data_during_initialization(): + graph_config: dict[str, object] = {} + init_params, runtime_state = _build_context(graph_config) + + node = _SampleNode( + id="node-1", + config={"id": "node-1", "data": {"title": "Sample", "foo": "bar"}}, + graph_init_params=init_params, + graph_runtime_state=runtime_state, + ) + + assert node.node_data.foo == "bar" + assert node.title == "Sample" + + +def test_missing_generic_argument_raises_type_error(): + graph_config: dict[str, object] = {} + + with pytest.raises(TypeError): + + class _InvalidNode(Node): # type: ignore[type-abstract] + node_type = NodeType.ANSWER + + @classmethod + def version(cls) -> str: + return "1" + + def _run(self): + raise NotImplementedError diff --git a/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py b/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py index 315c50d946..088c60a337 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py @@ -50,8 +50,6 @@ def document_extractor_node(graph_init_params): graph_init_params=graph_init_params, graph_runtime_state=Mock(), ) - # Initialize node data - node.init_node_data(node_config["data"]) return node diff --git a/api/tests/unit_tests/core/workflow/nodes/test_if_else.py b/api/tests/unit_tests/core/workflow/nodes/test_if_else.py index 962e43a897..dc7175f964 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_if_else.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_if_else.py @@ -114,9 +114,6 @@ def test_execute_if_else_result_true(): config=node_config, ) - # Initialize node data - node.init_node_data(node_config["data"]) - # Mock db.session.close() db.session.close = MagicMock() @@ -187,9 +184,6 @@ def test_execute_if_else_result_false(): config=node_config, ) - # Initialize node data - node.init_node_data(node_config["data"]) - # Mock db.session.close() db.session.close = MagicMock() @@ -252,9 +246,6 @@ def test_array_file_contains_file_name(): config=node_config, ) - # Initialize node data - node.init_node_data(node_config["data"]) - node.graph_runtime_state.variable_pool.get.return_value = ArrayFileSegment( value=[ File( @@ -347,7 +338,6 @@ def test_execute_if_else_boolean_conditions(condition: Condition): graph_runtime_state=graph_runtime_state, config={"id": "if-else", "data": node_data}, ) - node.init_node_data(node_data) # Mock db.session.close() db.session.close = MagicMock() @@ -417,7 +407,6 @@ def test_execute_if_else_boolean_false_conditions(): "data": node_data, }, ) - node.init_node_data(node_data) # Mock db.session.close() db.session.close = MagicMock() @@ -487,7 +476,6 @@ def test_execute_if_else_boolean_cases_structure(): graph_runtime_state=graph_runtime_state, config={"id": "if-else", "data": node_data}, ) - node.init_node_data(node_data) # Mock db.session.close() db.session.close = MagicMock() diff --git a/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py b/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py index 55fe62ca43..ff3eec0608 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py @@ -57,8 +57,6 @@ def list_operator_node(): graph_init_params=graph_init_params, graph_runtime_state=MagicMock(), ) - # Initialize node data - node.init_node_data(node_config["data"]) node.graph_runtime_state = MagicMock() node.graph_runtime_state.variable_pool = MagicMock() return node diff --git a/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py b/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py index 1f35c0faed..09b8191870 100644 --- a/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py @@ -73,7 +73,6 @@ def tool_node(monkeypatch) -> "ToolNode": graph_init_params=init_params, graph_runtime_state=graph_runtime_state, ) - node.init_node_data(config["data"]) return node diff --git a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py index 6af4777e0e..ef23a8f565 100644 --- a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py +++ b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py @@ -101,9 +101,6 @@ def test_overwrite_string_variable(): conv_var_updater_factory=mock_conv_var_updater_factory, ) - # Initialize node data - node.init_node_data(node_config["data"]) - list(node.run()) expected_var = StringVariable( id=conversation_variable.id, @@ -203,9 +200,6 @@ def test_append_variable_to_array(): conv_var_updater_factory=mock_conv_var_updater_factory, ) - # Initialize node data - node.init_node_data(node_config["data"]) - list(node.run()) expected_value = list(conversation_variable.value) expected_value.append(input_variable.value) @@ -296,9 +290,6 @@ def test_clear_array(): conv_var_updater_factory=mock_conv_var_updater_factory, ) - # Initialize node data - node.init_node_data(node_config["data"]) - list(node.run()) expected_var = ArrayStringVariable( id=conversation_variable.id, diff --git a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_variable_assigner_v2.py b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_variable_assigner_v2.py index 80071c8616..f793341e73 100644 --- a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_variable_assigner_v2.py +++ b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_variable_assigner_v2.py @@ -139,11 +139,6 @@ def test_remove_first_from_array(): config=node_config, ) - # Initialize node data - node.init_node_data(node_config["data"]) - - # Skip the mock assertion since we're in a test environment - # Run the node result = list(node.run()) @@ -228,10 +223,6 @@ def test_remove_last_from_array(): config=node_config, ) - # Initialize node data - node.init_node_data(node_config["data"]) - - # Skip the mock assertion since we're in a test environment list(node.run()) got = variable_pool.get(["conversation", conversation_variable.name]) @@ -313,10 +304,6 @@ def test_remove_first_from_empty_array(): config=node_config, ) - # Initialize node data - node.init_node_data(node_config["data"]) - - # Skip the mock assertion since we're in a test environment list(node.run()) got = variable_pool.get(["conversation", conversation_variable.name]) @@ -398,10 +385,6 @@ def test_remove_last_from_empty_array(): config=node_config, ) - # Initialize node data - node.init_node_data(node_config["data"]) - - # Skip the mock assertion since we're in a test environment list(node.run()) got = variable_pool.get(["conversation", conversation_variable.name]) diff --git a/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py b/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py index d7094ae5f2..a599d4f831 100644 --- a/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py @@ -47,7 +47,6 @@ def create_webhook_node(webhook_data: WebhookData, variable_pool: VariablePool) ), ) - node.init_node_data(node_config["data"]) return node