diff --git a/api/core/workflow/workflow_entry.py b/api/core/workflow/workflow_entry.py index c59fb3422b..716dfc2121 100644 --- a/api/core/workflow/workflow_entry.py +++ b/api/core/workflow/workflow_entry.py @@ -1,7 +1,7 @@ import logging import time from collections.abc import Generator, Mapping, Sequence -from typing import Any, cast +from typing import Any from configs import dify_config from context import capture_current_context @@ -24,7 +24,6 @@ from dify_graph.graph import Graph from dify_graph.graph_engine import GraphEngine, GraphEngineConfig from dify_graph.graph_engine.command_channels import InMemoryChannel from dify_graph.graph_engine.layers import DebugLoggingLayer, ExecutionLimitsLayer -from dify_graph.graph_engine.layers.base import GraphEngineLayer from dify_graph.graph_engine.protocols.command_channel import CommandChannel from dify_graph.graph_events import GraphEngineEvent, GraphNodeEventBase, GraphRunFailedEvent from dify_graph.nodes import BuiltinNodeTypes @@ -64,16 +63,22 @@ class _WorkflowChildEngineBuilder: *, workflow_id: str, graph_init_params: GraphInitParams, - graph_runtime_state: GraphRuntimeState, - graph_config: Mapping[str, Any], + parent_graph_runtime_state: GraphRuntimeState, root_node_id: str, - layers: Sequence[object] = (), + variable_pool: VariablePool | None = None, ) -> GraphEngine: + """Build a child engine with a fresh runtime state and no inherited layers.""" + child_graph_runtime_state = GraphRuntimeState( + variable_pool=variable_pool if variable_pool is not None else parent_graph_runtime_state.variable_pool, + start_at=time.perf_counter(), + execution_context=parent_graph_runtime_state.execution_context, + ) node_factory = DifyNodeFactory( graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, + graph_runtime_state=child_graph_runtime_state, ) + graph_config = graph_init_params.graph_config has_root_node = self._has_node_id(graph_config=graph_config, node_id=root_node_id) if has_root_node is False: raise ChildGraphNotFoundError(f"child graph root node '{root_node_id}' not found") @@ -89,14 +94,11 @@ class _WorkflowChildEngineBuilder: child_engine = GraphEngine( workflow_id=workflow_id, graph=child_graph, - graph_runtime_state=graph_runtime_state, + graph_runtime_state=child_graph_runtime_state, command_channel=command_channel, config=config, child_engine_builder=self, ) - child_engine.layer(LLMQuotaLayer()) - for layer in layers: - child_engine.layer(cast(GraphEngineLayer, layer)) return child_engine diff --git a/api/dify_graph/graph_engine/graph_engine.py b/api/dify_graph/graph_engine/graph_engine.py index dfe8ad71fb..c6d294f57d 100644 --- a/api/dify_graph/graph_engine/graph_engine.py +++ b/api/dify_graph/graph_engine/graph_engine.py @@ -9,7 +9,7 @@ from __future__ import annotations import logging import queue -from collections.abc import Generator, Mapping +from collections.abc import Generator from typing import TYPE_CHECKING, cast, final from dify_graph.entities.workflow_start_reason import WorkflowStartReason @@ -25,7 +25,7 @@ from dify_graph.graph_events import ( GraphRunStartedEvent, GraphRunSucceededEvent, ) -from dify_graph.runtime import GraphRuntimeState, ReadOnlyGraphRuntimeStateWrapper +from dify_graph.runtime import GraphRuntimeState, ReadOnlyGraphRuntimeStateWrapper, VariablePool from dify_graph.runtime.graph_runtime_state import ChildGraphEngineBuilderProtocol if TYPE_CHECKING: # pragma: no cover - used only for static analysis @@ -85,6 +85,7 @@ class GraphEngine: self._graph_runtime_state.configure(graph=cast("GraphProtocol", graph)) self._command_channel = command_channel self._config = config + self._layers: list[GraphEngineLayer] = [] self._child_engine_builder = child_engine_builder if child_engine_builder is not None: self._graph_runtime_state.bind_child_engine_builder(child_engine_builder) @@ -148,10 +149,6 @@ class GraphEngine: update_variables_handler = UpdateVariablesCommandHandler(self._graph_runtime_state.variable_pool) self._command_processor.register_handler(UpdateVariablesCommand, update_variables_handler) - # === Extensibility === - # Layers allow plugins to extend engine functionality - self._layers: list[GraphEngineLayer] = [] - # === Worker Pool Setup === # Create worker pool for parallel node execution self._worker_pool = WorkerPool( @@ -221,18 +218,14 @@ class GraphEngine: *, workflow_id: str, graph_init_params: GraphInitParams, - graph_runtime_state: GraphRuntimeState, - graph_config: dict[str, object] | Mapping[str, object], root_node_id: str, - layers: list[GraphEngineLayer] | tuple[GraphEngineLayer, ...] = (), + variable_pool: VariablePool | None = None, ) -> GraphEngine: return self._graph_runtime_state.create_child_engine( workflow_id=workflow_id, graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - graph_config=graph_config, root_node_id=root_node_id, - layers=layers, + variable_pool=variable_pool, ) def run(self) -> Generator[GraphEngineEvent, None, None]: diff --git a/api/dify_graph/nodes/iteration/iteration_node.py b/api/dify_graph/nodes/iteration/iteration_node.py index f7c0684819..7a87889b3e 100644 --- a/api/dify_graph/nodes/iteration/iteration_node.py +++ b/api/dify_graph/nodes/iteration/iteration_node.py @@ -548,7 +548,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]): def _create_graph_engine(self, index: int, item: object): from dify_graph.entities import GraphInitParams - from dify_graph.runtime import ChildGraphNotFoundError, GraphRuntimeState + from dify_graph.runtime import ChildGraphNotFoundError # Create GraphInitParams for child graph execution. graph_init_params = GraphInitParams( @@ -563,14 +563,6 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]): # append iteration variable (item, index) to variable pool variable_pool_copy.add([self._node_id, "index"], index) variable_pool_copy.add([self._node_id, "item"], item) - - # Create a new GraphRuntimeState for this iteration - graph_runtime_state_copy = GraphRuntimeState( - variable_pool=variable_pool_copy, - start_at=self.graph_runtime_state.start_at, - total_tokens=0, - node_run_steps=0, - ) root_node_id = self.node_data.start_node_id if root_node_id is None: raise StartNodeIdNotFoundError(f"field start_node_id in iteration {self._node_id} not found") @@ -579,9 +571,8 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]): return self.graph_runtime_state.create_child_engine( workflow_id=self.workflow_id, graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state_copy, - graph_config=self.graph_config, root_node_id=root_node_id, + variable_pool=variable_pool_copy, ) except ChildGraphNotFoundError as exc: raise IterationGraphNotFoundError("iteration graph not found") from exc diff --git a/api/dify_graph/nodes/loop/loop_node.py b/api/dify_graph/nodes/loop/loop_node.py index 64a0535ef4..fabdf8131e 100644 --- a/api/dify_graph/nodes/loop/loop_node.py +++ b/api/dify_graph/nodes/loop/loop_node.py @@ -408,7 +408,6 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]): def _create_graph_engine(self, start_at: datetime, root_node_id: str): from dify_graph.entities import GraphInitParams - from dify_graph.runtime import GraphRuntimeState # Create GraphInitParams for child graph execution. graph_init_params = GraphInitParams( @@ -418,16 +417,8 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]): call_depth=self.workflow_call_depth, ) - # Create a new GraphRuntimeState for this iteration - graph_runtime_state_copy = GraphRuntimeState( - variable_pool=self.graph_runtime_state.variable_pool, - start_at=start_at.timestamp(), - ) - return self.graph_runtime_state.create_child_engine( workflow_id=self.workflow_id, graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state_copy, - graph_config=self.graph_config, root_node_id=root_node_id, ) diff --git a/api/dify_graph/runtime/graph_runtime_state.py b/api/dify_graph/runtime/graph_runtime_state.py index 0a83a232ea..df65bbaeef 100644 --- a/api/dify_graph/runtime/graph_runtime_state.py +++ b/api/dify_graph/runtime/graph_runtime_state.py @@ -143,10 +143,9 @@ class ChildGraphEngineBuilderProtocol(Protocol): *, workflow_id: str, graph_init_params: GraphInitParams, - graph_runtime_state: GraphRuntimeState, - graph_config: Mapping[str, Any], + parent_graph_runtime_state: GraphRuntimeState, root_node_id: str, - layers: Sequence[object] = (), + variable_pool: VariablePool | None = None, ) -> Any: ... @@ -290,21 +289,19 @@ class GraphRuntimeState: *, workflow_id: str, graph_init_params: GraphInitParams, - graph_runtime_state: GraphRuntimeState, - graph_config: Mapping[str, Any], root_node_id: str, - layers: Sequence[object] = (), + variable_pool: VariablePool | None = None, ) -> Any: + """Create a child graph engine that derives its runtime state from the parent.""" if self._child_engine_builder is None: raise ChildEngineBuilderNotConfiguredError("Child engine builder is not configured.") return self._child_engine_builder.build_child_engine( workflow_id=workflow_id, graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - graph_config=graph_config, + parent_graph_runtime_state=self, root_node_id=root_node_id, - layers=layers, + variable_pool=variable_pool, ) # ------------------------------------------------------------------ diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_table_runner.py b/api/tests/unit_tests/core/workflow/graph_engine/test_table_runner.py index aef32208d7..75e0484225 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_table_runner.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_table_runner.py @@ -12,12 +12,12 @@ This module provides a robust table-driven testing framework with support for: import logging import time -from collections.abc import Callable, Mapping, Sequence +from collections.abc import Callable, Sequence from concurrent.futures import ThreadPoolExecutor, as_completed from dataclasses import dataclass, field from functools import lru_cache from pathlib import Path -from typing import Any, cast +from typing import Any from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, InvokeFrom, UserFrom from core.tools.utils.yaml_utils import _load_yaml_file @@ -28,7 +28,6 @@ from dify_graph.entities.graph_init_params import GraphInitParams from dify_graph.graph import Graph from dify_graph.graph_engine import GraphEngine, GraphEngineConfig from dify_graph.graph_engine.command_channels import InMemoryChannel -from dify_graph.graph_engine.layers.base import GraphEngineLayer from dify_graph.graph_events import ( GraphEngineEvent, GraphRunStartedEvent, @@ -61,20 +60,28 @@ class _TableTestChildEngineBuilder: *, workflow_id: str, graph_init_params: GraphInitParams, - graph_runtime_state: GraphRuntimeState, - graph_config: Mapping[str, Any], + parent_graph_runtime_state: GraphRuntimeState, root_node_id: str, - layers: Sequence[object] = (), + variable_pool: VariablePool | None = None, ) -> GraphEngine: + child_graph_runtime_state = GraphRuntimeState( + variable_pool=variable_pool if variable_pool is not None else parent_graph_runtime_state.variable_pool, + start_at=time.perf_counter(), + execution_context=parent_graph_runtime_state.execution_context, + ) if self._use_mock_factory: node_factory = MockNodeFactory( graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, + graph_runtime_state=child_graph_runtime_state, mock_config=self._mock_config, ) else: - node_factory = DifyNodeFactory(graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state) + node_factory = DifyNodeFactory( + graph_init_params=graph_init_params, + graph_runtime_state=child_graph_runtime_state, + ) + graph_config = graph_init_params.graph_config child_graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id=root_node_id) if not child_graph: raise ValueError("child graph not found") @@ -82,13 +89,11 @@ class _TableTestChildEngineBuilder: child_engine = GraphEngine( workflow_id=workflow_id, graph=child_graph, - graph_runtime_state=graph_runtime_state, + graph_runtime_state=child_graph_runtime_state, command_channel=InMemoryChannel(), config=GraphEngineConfig(), child_engine_builder=self, ) - for layer in layers: - child_engine.layer(cast(GraphEngineLayer, layer)) return child_engine diff --git a/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration_child_engine_errors.py b/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration_child_engine_errors.py index bc4d3e4097..1a6938b4c8 100644 --- a/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration_child_engine_errors.py +++ b/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration_child_engine_errors.py @@ -1,4 +1,4 @@ -from collections.abc import Mapping, Sequence +from collections.abc import Mapping from typing import Any import pytest @@ -22,10 +22,9 @@ class _MissingGraphBuilder: *, workflow_id: str, graph_init_params: GraphInitParams, - graph_runtime_state: GraphRuntimeState, - graph_config: Mapping[str, Any], + parent_graph_runtime_state: GraphRuntimeState, root_node_id: str, - layers: Sequence[object] = (), + variable_pool: VariablePool | None = None, ) -> object: raise ChildGraphNotFoundError(f"child graph root node '{root_node_id}' not found") @@ -69,8 +68,6 @@ def test_graph_runtime_state_raises_specific_error_when_child_builder_is_missing runtime_state.create_child_engine( workflow_id="workflow", graph_init_params=graph_init_params, - graph_runtime_state=_build_runtime_state(), - graph_config={}, root_node_id="root", ) diff --git a/api/tests/unit_tests/core/workflow/test_workflow_entry_helpers.py b/api/tests/unit_tests/core/workflow/test_workflow_entry_helpers.py index 9945fb992f..b07ac5cc2d 100644 --- a/api/tests/unit_tests/core/workflow/test_workflow_entry_helpers.py +++ b/api/tests/unit_tests/core/workflow/test_workflow_entry_helpers.py @@ -38,45 +38,62 @@ class TestWorkflowChildEngineBuilder: def test_build_child_engine_raises_when_root_node_is_missing(self): builder = workflow_entry._WorkflowChildEngineBuilder() + graph_init_params = SimpleNamespace(graph_config={"nodes": []}) + parent_graph_runtime_state = SimpleNamespace( + execution_context=sentinel.execution_context, + variable_pool=sentinel.variable_pool, + ) with patch.object(workflow_entry, "DifyNodeFactory", return_value=sentinel.factory): with pytest.raises(ChildGraphNotFoundError, match="child graph root node 'missing' not found"): builder.build_child_engine( workflow_id="workflow-id", - graph_init_params=sentinel.graph_init_params, - graph_runtime_state=sentinel.graph_runtime_state, - graph_config={"nodes": []}, + graph_init_params=graph_init_params, + parent_graph_runtime_state=parent_graph_runtime_state, root_node_id="missing", ) - def test_build_child_engine_constructs_graph_engine_and_layers(self): + def test_build_child_engine_constructs_graph_engine_without_layers(self): builder = workflow_entry._WorkflowChildEngineBuilder() + graph_init_params = SimpleNamespace(graph_config={"nodes": [{"id": "root"}]}) + parent_graph_runtime_state = SimpleNamespace( + execution_context=sentinel.execution_context, + variable_pool=sentinel.parent_variable_pool, + ) child_graph = sentinel.child_graph + child_graph_runtime_state = sentinel.child_graph_runtime_state child_engine = MagicMock() - quota_layer = sentinel.quota_layer - additional_layers = [sentinel.layer_one, sentinel.layer_two] with ( + patch.object(workflow_entry.time, "perf_counter", return_value=123.0), + patch.object( + workflow_entry, + "GraphRuntimeState", + return_value=child_graph_runtime_state, + ) as graph_runtime_state_cls, patch.object(workflow_entry, "DifyNodeFactory", return_value=sentinel.factory) as dify_node_factory, patch.object(workflow_entry.Graph, "init", return_value=child_graph) as graph_init, patch.object(workflow_entry, "GraphEngine", return_value=child_engine) as graph_engine_cls, patch.object(workflow_entry, "GraphEngineConfig", return_value=sentinel.graph_engine_config), patch.object(workflow_entry, "InMemoryChannel", return_value=sentinel.command_channel), - patch.object(workflow_entry, "LLMQuotaLayer", return_value=quota_layer), ): result = builder.build_child_engine( workflow_id="workflow-id", - graph_init_params=sentinel.graph_init_params, - graph_runtime_state=sentinel.graph_runtime_state, - graph_config={"nodes": [{"id": "root"}]}, + graph_init_params=graph_init_params, + parent_graph_runtime_state=parent_graph_runtime_state, root_node_id="root", - layers=additional_layers, + variable_pool=sentinel.child_variable_pool, ) assert result is child_engine + graph_runtime_state_cls.assert_called_once_with( + variable_pool=sentinel.child_variable_pool, + start_at=123.0, + execution_context=sentinel.execution_context, + ) dify_node_factory.assert_called_once_with( - graph_init_params=sentinel.graph_init_params, - graph_runtime_state=sentinel.graph_runtime_state, + graph_init_params=graph_init_params, + graph_runtime_state=child_graph_runtime_state, ) graph_init.assert_called_once_with( graph_config={"nodes": [{"id": "root"}]}, @@ -86,16 +103,12 @@ class TestWorkflowChildEngineBuilder: graph_engine_cls.assert_called_once_with( workflow_id="workflow-id", graph=child_graph, - graph_runtime_state=sentinel.graph_runtime_state, + graph_runtime_state=child_graph_runtime_state, command_channel=sentinel.command_channel, config=sentinel.graph_engine_config, child_engine_builder=builder, ) - assert child_engine.layer.call_args_list == [ - ((quota_layer,), {}), - ((sentinel.layer_one,), {}), - ((sentinel.layer_two,), {}), - ] + child_engine.layer.assert_not_called() class TestWorkflowEntryInit: