From 5cad0caaaead2ef3509f50db28639d4ef3338dde Mon Sep 17 00:00:00 2001 From: -LAN- Date: Wed, 18 Mar 2026 17:50:18 +0800 Subject: [PATCH] refactor: make code simpler Signed-off-by: -LAN- --- api/core/workflow/workflow_entry.py | 9 ++------- api/dify_graph/graph_engine/graph_engine.py | 4 ---- api/dify_graph/nodes/iteration/iteration_node.py | 7 ++----- api/dify_graph/runtime/graph_runtime_state.py | 8 ++++---- .../workflow/entities/test_graph_runtime_state.py | 11 +++++++++++ .../core/workflow/test_workflow_entry_helpers.py | 7 ++++--- 6 files changed, 23 insertions(+), 23 deletions(-) diff --git a/api/core/workflow/workflow_entry.py b/api/core/workflow/workflow_entry.py index 36d021eb45..22c855e7bf 100644 --- a/api/core/workflow/workflow_entry.py +++ b/api/core/workflow/workflow_entry.py @@ -1,7 +1,6 @@ import logging import time from collections.abc import Generator, Mapping, Sequence -from contextlib import AbstractContextManager from typing import Any, cast from configs import dify_config @@ -40,9 +39,6 @@ logger = logging.getLogger(__name__) class _WorkflowChildEngineBuilder: - def __init__(self, execution_context: AbstractContextManager[object] | None = None) -> None: - self._execution_context = execution_context - @staticmethod def _has_node_id(graph_config: Mapping[str, Any], node_id: str) -> bool | None: """ @@ -97,7 +93,6 @@ class _WorkflowChildEngineBuilder: command_channel=command_channel, config=config, child_engine_builder=self, - execution_context=self._execution_context, ) child_engine.layer(LLMQuotaLayer()) for layer in layers: @@ -149,7 +144,8 @@ class WorkflowEntry: self.command_channel = command_channel execution_context = capture_current_context() - self._child_engine_builder = _WorkflowChildEngineBuilder(execution_context=execution_context) + graph_runtime_state.execution_context = execution_context + self._child_engine_builder = _WorkflowChildEngineBuilder() self.graph_engine = GraphEngine( workflow_id=workflow_id, graph=graph, @@ -162,7 +158,6 @@ class WorkflowEntry: scale_down_idle_time=dify_config.GRAPH_ENGINE_SCALE_DOWN_IDLE_TIME, ), child_engine_builder=self._child_engine_builder, - execution_context=execution_context, ) # Add debug logging layer when in debug mode diff --git a/api/dify_graph/graph_engine/graph_engine.py b/api/dify_graph/graph_engine/graph_engine.py index 5cdffed3b0..dfe8ad71fb 100644 --- a/api/dify_graph/graph_engine/graph_engine.py +++ b/api/dify_graph/graph_engine/graph_engine.py @@ -10,7 +10,6 @@ from __future__ import annotations import logging import queue from collections.abc import Generator, Mapping -from contextlib import AbstractContextManager from typing import TYPE_CHECKING, cast, final from dify_graph.entities.workflow_start_reason import WorkflowStartReason @@ -77,7 +76,6 @@ class GraphEngine: command_channel: CommandChannel, config: GraphEngineConfig = _DEFAULT_CONFIG, child_engine_builder: ChildGraphEngineBuilderProtocol | None = None, - execution_context: AbstractContextManager[object] | None = None, ) -> None: """Initialize the graph engine with all subsystems and dependencies.""" @@ -90,8 +88,6 @@ class GraphEngine: self._child_engine_builder = child_engine_builder if child_engine_builder is not None: self._graph_runtime_state.bind_child_engine_builder(child_engine_builder) - if execution_context is not None: - self._graph_runtime_state.execution_context = execution_context # Graph execution tracks the overall execution state self._graph_execution = cast("GraphExecution", self._graph_runtime_state.graph_execution) diff --git a/api/dify_graph/nodes/iteration/iteration_node.py b/api/dify_graph/nodes/iteration/iteration_node.py index 376579d6fd..c1c25aa0f9 100644 --- a/api/dify_graph/nodes/iteration/iteration_node.py +++ b/api/dify_graph/nodes/iteration/iteration_node.py @@ -1,7 +1,7 @@ import logging from collections.abc import Generator, Mapping, Sequence from concurrent.futures import Future, ThreadPoolExecutor, as_completed -from contextlib import AbstractContextManager, nullcontext +from contextlib import AbstractContextManager from datetime import UTC, datetime from typing import TYPE_CHECKING, Any, NewType, cast @@ -336,10 +336,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]): def _capture_execution_context(self) -> AbstractContextManager[object]: """Return the application-supplied execution context for parallel iterations.""" - execution_context = self.graph_runtime_state.execution_context - if execution_context is not None: - return execution_context - return nullcontext() + return self.graph_runtime_state.execution_context def _handle_iteration_success( self, diff --git a/api/dify_graph/runtime/graph_runtime_state.py b/api/dify_graph/runtime/graph_runtime_state.py index 984dcc3526..becacf4642 100644 --- a/api/dify_graph/runtime/graph_runtime_state.py +++ b/api/dify_graph/runtime/graph_runtime_state.py @@ -3,7 +3,7 @@ from __future__ import annotations import importlib import json from collections.abc import Mapping, Sequence -from contextlib import AbstractContextManager +from contextlib import AbstractContextManager, nullcontext from copy import deepcopy from dataclasses import dataclass from typing import TYPE_CHECKING, Any, ClassVar, Protocol @@ -235,7 +235,7 @@ class GraphRuntimeState: self._response_coordinator = response_coordinator # Application code injects this when worker threads must restore request # or framework-local state. It is intentionally excluded from snapshots. - self._execution_context = execution_context + self._execution_context = execution_context if execution_context is not None else nullcontext() self._pending_response_coordinator_dump: str | None = None self._pending_graph_execution_workflow_id: str | None = None self._paused_nodes: set[str] = set() @@ -335,12 +335,12 @@ class GraphRuntimeState: return self._response_coordinator @property - def execution_context(self) -> AbstractContextManager[object] | None: + def execution_context(self) -> AbstractContextManager[object]: return self._execution_context @execution_context.setter def execution_context(self, value: AbstractContextManager[object] | None) -> None: - self._execution_context = value + self._execution_context = value if value is not None else nullcontext() # ------------------------------------------------------------------ # Scalar state diff --git a/api/tests/unit_tests/core/workflow/entities/test_graph_runtime_state.py b/api/tests/unit_tests/core/workflow/entities/test_graph_runtime_state.py index 491326916b..6331edc819 100644 --- a/api/tests/unit_tests/core/workflow/entities/test_graph_runtime_state.py +++ b/api/tests/unit_tests/core/workflow/entities/test_graph_runtime_state.py @@ -23,6 +23,17 @@ class StubCoordinator: class TestGraphRuntimeState: + def test_execution_context_defaults_to_empty_context(self): + state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time()) + + with state.execution_context: + assert state.execution_context is not None + + state.execution_context = None + + with state.execution_context: + assert state.execution_context is not None + def test_property_getters_and_setters(self): # FIXME(-LAN-): Mock VariablePool if needed variable_pool = VariablePool() diff --git a/api/tests/unit_tests/core/workflow/test_workflow_entry_helpers.py b/api/tests/unit_tests/core/workflow/test_workflow_entry_helpers.py index 26b419b330..5960230dd3 100644 --- a/api/tests/unit_tests/core/workflow/test_workflow_entry_helpers.py +++ b/api/tests/unit_tests/core/workflow/test_workflow_entry_helpers.py @@ -120,6 +120,7 @@ class TestWorkflowEntryInit: def test_applies_debug_and_observability_layers(self): graph_engine = MagicMock() + graph_runtime_state = SimpleNamespace(execution_context=None) debug_layer = sentinel.debug_layer execution_limits_layer = sentinel.execution_limits_layer llm_quota_layer = sentinel.llm_quota_layer @@ -153,7 +154,7 @@ class TestWorkflowEntryInit: invoke_from=InvokeFrom.DEBUGGER, call_depth=0, variable_pool=sentinel.variable_pool, - graph_runtime_state=sentinel.graph_runtime_state, + graph_runtime_state=graph_runtime_state, command_channel=None, ) @@ -161,12 +162,12 @@ class TestWorkflowEntryInit: graph_engine_cls.assert_called_once_with( workflow_id="workflow-id-123456", graph=sentinel.graph, - graph_runtime_state=sentinel.graph_runtime_state, + graph_runtime_state=graph_runtime_state, command_channel=sentinel.command_channel, config=sentinel.graph_engine_config, child_engine_builder=entry._child_engine_builder, - execution_context=sentinel.execution_context, ) + assert graph_runtime_state.execution_context is sentinel.execution_context debug_logging_layer.assert_called_once_with( level="DEBUG", include_inputs=True,