diff --git a/api/core/workflow/workflow_entry.py b/api/core/workflow/workflow_entry.py index c03cd9da11..1c1434c401 100644 --- a/api/core/workflow/workflow_entry.py +++ b/api/core/workflow/workflow_entry.py @@ -69,7 +69,7 @@ class _WorkflowChildEngineBuilder: root_node_id: str, variable_pool: VariablePool | None = None, ) -> GraphEngine: - """Build a child engine with a fresh runtime state and no inherited layers.""" + """Build a child engine with a fresh runtime state and only child-safe layers.""" child_graph_runtime_state = GraphRuntimeState( variable_pool=variable_pool if variable_pool is not None else parent_graph_runtime_state.variable_pool, start_at=time.perf_counter(), @@ -101,6 +101,7 @@ class _WorkflowChildEngineBuilder: config=config, child_engine_builder=self, ) + child_engine.layer(LLMQuotaLayer()) return child_engine diff --git a/api/dify_graph/graph_engine/graph_engine.py b/api/dify_graph/graph_engine/graph_engine.py index c6d294f57d..bfacd8ed71 100644 --- a/api/dify_graph/graph_engine/graph_engine.py +++ b/api/dify_graph/graph_engine/graph_engine.py @@ -213,6 +213,10 @@ class GraphEngine: self._bind_layer_context(layer) return self + def request_abort(self, reason: str | None = None) -> None: + """Queue an abort command for this engine.""" + self._command_channel.send_command(AbortCommand(reason=reason or "User requested abort")) + def create_child_engine( self, *, diff --git a/api/dify_graph/nodes/iteration/exc.py b/api/dify_graph/nodes/iteration/exc.py index d9947e09bc..7b6af61b9d 100644 --- a/api/dify_graph/nodes/iteration/exc.py +++ b/api/dify_graph/nodes/iteration/exc.py @@ -20,3 +20,7 @@ class IterationGraphNotFoundError(IterationNodeError): class IterationIndexNotFoundError(IterationNodeError): """Raised when the iteration index is not found.""" + + +class ChildGraphAbortedError(IterationNodeError): + """Raised when a child graph aborts and the container must stop immediately.""" diff --git a/api/dify_graph/nodes/iteration/iteration_node.py b/api/dify_graph/nodes/iteration/iteration_node.py index 23da584743..a24e6c33e5 100644 --- a/api/dify_graph/nodes/iteration/iteration_node.py +++ b/api/dify_graph/nodes/iteration/iteration_node.py @@ -1,7 +1,9 @@ import logging from collections.abc import Generator, Mapping, Sequence from concurrent.futures import Future, ThreadPoolExecutor, as_completed +from contextlib import suppress from datetime import UTC, datetime +from threading import Lock from typing import TYPE_CHECKING, Any, NewType, cast from typing_extensions import TypeIs @@ -15,6 +17,7 @@ from dify_graph.enums import ( ) from dify_graph.graph_events import ( GraphNodeEventBase, + GraphRunAbortedEvent, GraphRunFailedEvent, GraphRunPartialSucceededEvent, GraphRunSucceededEvent, @@ -37,6 +40,7 @@ from dify_graph.variables import IntegerVariable, NoneSegment from dify_graph.variables.segments import ArrayAnySegment, ArraySegment from .exc import ( + ChildGraphAbortedError, InvalidIteratorValueError, IterationGraphNotFoundError, IterationIndexNotFoundError, @@ -49,6 +53,7 @@ if TYPE_CHECKING: from dify_graph.graph_engine import GraphEngine logger = logging.getLogger(__name__) +_DEFAULT_CHILD_ABORT_REASON = "child graph aborted" EmptyArraySegment = NewType("EmptyArraySegment", ArraySegment) @@ -195,16 +200,14 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]): graph_engine = self._create_graph_engine(index, item) # Run the iteration - yield from self._run_single_iter( - variable_pool=graph_engine.graph_runtime_state.variable_pool, - outputs=outputs, - graph_engine=graph_engine, - ) - - # Accumulate usage from this iteration - usage_accumulator[0] = self._merge_usage( - usage_accumulator[0], graph_engine.graph_runtime_state.llm_usage - ) + try: + yield from self._run_single_iter( + variable_pool=graph_engine.graph_runtime_state.variable_pool, + outputs=outputs, + graph_engine=graph_engine, + ) + finally: + self._merge_graph_engine_usage(usage_accumulator=usage_accumulator, graph_engine=graph_engine) iter_run_map[str(index)] = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds() def _execute_parallel_iterations( @@ -222,6 +225,9 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]): with ThreadPoolExecutor(max_workers=max_workers) as executor: # Submit all iteration tasks + started_child_engines: dict[int, GraphEngine] = {} + started_child_engines_lock = Lock() + merged_usage_indexes: set[int] = set() future_to_index: dict[ Future[ tuple[ @@ -236,9 +242,11 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]): for index, item in enumerate(iterator_list_value): yield IterationNextEvent(index=index) future = executor.submit( - self._execute_single_iteration_parallel, + self._execute_tracked_iteration_parallel, index=index, item=item, + started_child_engines=started_child_engines, + started_child_engines_lock=started_child_engines_lock, ) future_to_index[future] = index @@ -265,8 +273,31 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]): iter_run_map[str(index)] = iteration_duration usage_accumulator[0] = self._merge_usage(usage_accumulator[0], iteration_usage) + merged_usage_indexes.add(index) except Exception as e: + if index not in merged_usage_indexes: + self._merge_graph_engine_usage( + usage_accumulator=usage_accumulator, + graph_engine=started_child_engines.get(index), + ) + merged_usage_indexes.add(index) + if isinstance(e, ChildGraphAbortedError): + self._abort_parallel_siblings( + future_to_index=future_to_index, + current_future=future, + started_child_engines=started_child_engines, + reason=str(e) or _DEFAULT_CHILD_ABORT_REASON, + ) + self._drain_parallel_siblings( + future_to_index=future_to_index, + current_future=future, + started_child_engines=started_child_engines, + usage_accumulator=usage_accumulator, + merged_usage_indexes=merged_usage_indexes, + ) + raise e + # Handle errors based on error_handle_mode match self.node_data.error_handle_mode: case ErrorHandleMode.TERMINATED: @@ -284,18 +315,100 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]): if self.node_data.error_handle_mode == ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT: outputs[:] = [output for output in outputs if output is not None] + @staticmethod + def _merge_graph_engine_usage( + *, + usage_accumulator: list[LLMUsage], + graph_engine: "GraphEngine | None", + ) -> None: + if graph_engine is None: + return + usage_accumulator[0] = IterationNode._merge_usage( + usage_accumulator[0], graph_engine.graph_runtime_state.llm_usage + ) + + def _abort_parallel_siblings( + self, + *, + future_to_index: Mapping[Future[tuple[float, list[GraphNodeEventBase], object | None, LLMUsage]], int], + current_future: Future[tuple[float, list[GraphNodeEventBase], object | None, LLMUsage]], + started_child_engines: Mapping[int, "GraphEngine"], + reason: str, + ) -> None: + for future, index in future_to_index.items(): + if future == current_future: + continue + + graph_engine = started_child_engines.get(index) + if graph_engine is not None: + graph_engine.request_abort(reason) + + future.cancel() + + def _drain_parallel_siblings( + self, + *, + future_to_index: Mapping[Future[tuple[float, list[GraphNodeEventBase], object | None, LLMUsage]], int], + current_future: Future[tuple[float, list[GraphNodeEventBase], object | None, LLMUsage]], + started_child_engines: Mapping[int, "GraphEngine"], + usage_accumulator: list[LLMUsage], + merged_usage_indexes: set[int], + ) -> None: + for future, index in future_to_index.items(): + if future == current_future: + continue + if future.cancelled(): + continue + + with suppress(Exception): + future.result() + + if index in merged_usage_indexes: + continue + + self._merge_graph_engine_usage( + usage_accumulator=usage_accumulator, + graph_engine=started_child_engines.get(index), + ) + merged_usage_indexes.add(index) + + def _execute_tracked_iteration_parallel( + self, + *, + index: int, + item: object, + started_child_engines: dict[int, "GraphEngine"], + started_child_engines_lock: Lock, + ) -> tuple[float, list[GraphNodeEventBase], object | None, LLMUsage]: + graph_engine = self._create_graph_engine(index, item) + with started_child_engines_lock: + started_child_engines[index] = graph_engine + + return self._execute_parallel_iteration_with_graph_engine( + index=index, + graph_engine=graph_engine, + ) + def _execute_single_iteration_parallel( self, index: int, item: object, ) -> tuple[float, list[GraphNodeEventBase], object | None, LLMUsage]: """Execute a single iteration in parallel mode and return results.""" + graph_engine = self._create_graph_engine(index, item) + return self._execute_parallel_iteration_with_graph_engine(index=index, graph_engine=graph_engine) + + def _execute_parallel_iteration_with_graph_engine( + self, + *, + index: int, + graph_engine: "GraphEngine", + ) -> tuple[float, list[GraphNodeEventBase], object | None, LLMUsage]: + """Execute a prepared child engine in parallel mode and return results.""" iter_start_at = datetime.now(UTC).replace(tzinfo=None) events: list[GraphNodeEventBase] = [] outputs_temp: list[object] = [] - graph_engine = self._create_graph_engine(index, item) - # Collect events instead of yielding them directly for event in self._run_single_iter( variable_pool=graph_engine.graph_runtime_state.variable_pool, @@ -529,6 +642,8 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]): else: outputs.append(result.to_object()) return + elif isinstance(event, GraphRunAbortedEvent): + raise ChildGraphAbortedError(event.reason or _DEFAULT_CHILD_ABORT_REASON) elif isinstance(event, GraphRunFailedEvent): match self.node_data.error_handle_mode: case ErrorHandleMode.TERMINATED: diff --git a/api/dify_graph/nodes/loop/loop_node.py b/api/dify_graph/nodes/loop/loop_node.py index fabdf8131e..8542972a30 100644 --- a/api/dify_graph/nodes/loop/loop_node.py +++ b/api/dify_graph/nodes/loop/loop_node.py @@ -14,6 +14,7 @@ from dify_graph.enums import ( ) from dify_graph.graph_events import ( GraphNodeEventBase, + GraphRunAbortedEvent, GraphRunFailedEvent, NodeRunSucceededEvent, ) @@ -37,6 +38,7 @@ if TYPE_CHECKING: from dify_graph.graph_engine import GraphEngine logger = logging.getLogger(__name__) +_DEFAULT_CHILD_ABORT_REASON = "child graph aborted" class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]): @@ -123,7 +125,10 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]): graph_engine = self._create_graph_engine(start_at=start_at, root_node_id=root_node_id) loop_start_time = datetime.now(UTC).replace(tzinfo=None) - reach_break_node = yield from self._run_single_loop(graph_engine=graph_engine, current_index=i) + try: + reach_break_node = yield from self._run_single_loop(graph_engine=graph_engine, current_index=i) + finally: + loop_usage = self._merge_usage(loop_usage, graph_engine.graph_runtime_state.llm_usage) # Track loop duration loop_duration_map[str(i)] = (datetime.now(UTC).replace(tzinfo=None) - loop_start_time).total_seconds() @@ -140,9 +145,6 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]): # For other outputs, just update self.graph_runtime_state.set_output(key, value) - # Accumulate usage from the sub-graph execution - loop_usage = self._merge_usage(loop_usage, graph_engine.graph_runtime_state.llm_usage) - # Collect loop variable values after iteration single_loop_variable = {} for key, selector in loop_variable_selectors.items(): @@ -254,6 +256,8 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]): yield event if isinstance(event, NodeRunSucceededEvent) and event.node_type == BuiltinNodeTypes.LOOP_END: reach_break_node = True + if isinstance(event, GraphRunAbortedEvent): + raise RuntimeError(event.reason or _DEFAULT_CHILD_ABORT_REASON) if isinstance(event, GraphRunFailedEvent): raise Exception(event.error) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_command_system.py b/api/tests/unit_tests/core/workflow/graph_engine/test_command_system.py index 23cd5d7f8f..d1a7e2da51 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_command_system.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_command_system.py @@ -73,9 +73,8 @@ def test_abort_command(): config=GraphEngineConfig(), ) - # Send abort command before starting - abort_command = AbortCommand(reason="Test abort") - command_channel.send_command(abort_command) + # Queue an abort request before starting. + engine.request_abort("Test abort") # Run engine and collect events events = list(engine.run()) diff --git a/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration_abort_propagation.py b/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration_abort_propagation.py new file mode 100644 index 0000000000..969a21ed7f --- /dev/null +++ b/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration_abort_propagation.py @@ -0,0 +1,201 @@ +from threading import Event +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest + +from dify_graph.enums import WorkflowNodeExecutionStatus +from dify_graph.graph_events import GraphRunAbortedEvent +from dify_graph.model_runtime.entities.llm_entities import LLMUsage +from dify_graph.node_events import IterationFailedEvent, IterationStartedEvent, StreamCompletedEvent +from dify_graph.nodes.iteration.entities import ErrorHandleMode, IterationNodeData +from dify_graph.nodes.iteration.exc import ChildGraphAbortedError +from dify_graph.nodes.iteration.iteration_node import IterationNode +from tests.workflow_test_utils import build_test_variable_pool + + +def _usage_with_tokens(total_tokens: int) -> LLMUsage: + usage = LLMUsage.empty_usage() + usage.total_tokens = total_tokens + return usage + + +class _AbortOnRequestGraphEngine: + def __init__(self, *, index: int, total_tokens: int) -> None: + variable_pool = build_test_variable_pool() + variable_pool.add(["iteration-node", "index"], index) + + self.started = Event() + self.abort_requested = Event() + self.finished = Event() + self.abort_reason: str | None = None + self.graph_runtime_state = SimpleNamespace( + variable_pool=variable_pool, + llm_usage=_usage_with_tokens(total_tokens), + ) + + def request_abort(self, reason: str | None = None) -> None: + self.abort_reason = reason + self.abort_requested.set() + + def run(self): + self.started.set() + assert self.abort_requested.wait(1), "parallel sibling never received an abort request" + self.finished.set() + yield GraphRunAbortedEvent(reason=self.abort_reason) + + +def _build_immediate_abort_graph_engine( + *, + index: int, + total_tokens: int, + wait_before_abort: Event | None = None, +) -> SimpleNamespace: + variable_pool = build_test_variable_pool() + variable_pool.add(["iteration-node", "index"], index) + + started = Event() + finished = Event() + + def run(): + started.set() + if wait_before_abort is not None: + assert wait_before_abort.wait(1), "parallel sibling never started" + finished.set() + yield GraphRunAbortedEvent(reason="quota exceeded") + + return SimpleNamespace( + graph_runtime_state=SimpleNamespace( + variable_pool=variable_pool, + llm_usage=_usage_with_tokens(total_tokens), + ), + run=run, + request_abort=lambda reason=None: None, + started=started, + finished=finished, + ) + + +def _build_iteration_node( + *, + error_handle_mode: ErrorHandleMode = ErrorHandleMode.TERMINATED, + is_parallel: bool = False, +) -> IterationNode: + node = IterationNode.__new__(IterationNode) + node._node_id = "iteration-node" + node._node_data = IterationNodeData( + title="Iteration", + iterator_selector=["start", "items"], + output_selector=["iteration-node", "output"], + start_node_id="child-start", + is_parallel=is_parallel, + parallel_nums=2, + error_handle_mode=error_handle_mode, + ) + + variable_pool = build_test_variable_pool() + variable_pool.add(["start", "items"], ["first", "second"]) + node.graph_runtime_state = SimpleNamespace( + variable_pool=variable_pool, + llm_usage=LLMUsage.empty_usage(), + ) + return node + + +def test_run_single_iter_raises_child_graph_aborted_error_on_abort_event() -> None: + node = _build_iteration_node() + variable_pool = build_test_variable_pool() + variable_pool.add(["iteration-node", "index"], 0) + graph_engine = SimpleNamespace( + run=lambda: iter([GraphRunAbortedEvent(reason="quota exceeded")]), + ) + + with pytest.raises(ChildGraphAbortedError, match="quota exceeded"): + list( + node._run_single_iter( + variable_pool=variable_pool, + outputs=[], + graph_engine=graph_engine, + ) + ) + + +def test_iteration_run_fails_on_sequential_child_abort() -> None: + node = _build_iteration_node(error_handle_mode=ErrorHandleMode.CONTINUE_ON_ERROR) + graph_engine = SimpleNamespace( + graph_runtime_state=SimpleNamespace( + variable_pool=build_test_variable_pool(), + llm_usage=LLMUsage.empty_usage(), + ) + ) + node._create_graph_engine = MagicMock(return_value=graph_engine) + node._run_single_iter = MagicMock(side_effect=ChildGraphAbortedError("quota exceeded")) + + events = list(node._run()) + + assert isinstance(events[0], IterationStartedEvent) + assert isinstance(events[-2], IterationFailedEvent) + assert events[-2].error == "quota exceeded" + assert isinstance(events[-1], StreamCompletedEvent) + assert events[-1].node_run_result.status == WorkflowNodeExecutionStatus.FAILED + assert events[-1].node_run_result.error == "quota exceeded" + node._create_graph_engine.assert_called_once() + node._run_single_iter.assert_called_once() + + +def test_iteration_run_merges_child_usage_before_failing_on_sequential_child_abort() -> None: + node = _build_iteration_node(error_handle_mode=ErrorHandleMode.CONTINUE_ON_ERROR) + graph_engine = SimpleNamespace( + graph_runtime_state=SimpleNamespace( + variable_pool=build_test_variable_pool(), + llm_usage=_usage_with_tokens(7), + ) + ) + node._create_graph_engine = MagicMock(return_value=graph_engine) + node._run_single_iter = MagicMock(side_effect=ChildGraphAbortedError("quota exceeded")) + + events = list(node._run()) + + assert isinstance(events[-1], StreamCompletedEvent) + assert events[-1].node_run_result.llm_usage.total_tokens == 7 + assert node.graph_runtime_state.llm_usage.total_tokens == 7 + + +@pytest.mark.parametrize( + "error_handle_mode", + [ + ErrorHandleMode.CONTINUE_ON_ERROR, + ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT, + ], +) +def test_iteration_run_fails_on_parallel_child_abort_regardless_of_error_mode( + error_handle_mode: ErrorHandleMode, +) -> None: + node = _build_iteration_node( + error_handle_mode=error_handle_mode, + is_parallel=True, + ) + blocking_engine = _AbortOnRequestGraphEngine(index=1, total_tokens=5) + aborting_engine = _build_immediate_abort_graph_engine( + index=0, + total_tokens=3, + wait_before_abort=blocking_engine.started, + ) + node._create_graph_engine = MagicMock( + side_effect=lambda index, item: {0: aborting_engine, 1: blocking_engine}[index] + ) + + events = list(node._run()) + + assert isinstance(events[0], IterationStartedEvent) + assert isinstance(events[-2], IterationFailedEvent) + assert events[-2].error == "quota exceeded" + assert isinstance(events[-1], StreamCompletedEvent) + assert events[-1].node_run_result.status == WorkflowNodeExecutionStatus.FAILED + assert events[-1].node_run_result.error == "quota exceeded" + assert events[-1].node_run_result.llm_usage.total_tokens == 8 + assert node.graph_runtime_state.llm_usage.total_tokens == 8 + assert blocking_engine.started.is_set() + assert blocking_engine.abort_requested.is_set() + assert blocking_engine.finished.is_set() + assert blocking_engine.abort_reason == "quota exceeded" diff --git a/api/tests/unit_tests/core/workflow/nodes/iteration/test_parallel_iteration_duration.py b/api/tests/unit_tests/core/workflow/nodes/iteration/test_parallel_iteration_duration.py index 78ecb279a5..9931eaf39a 100644 --- a/api/tests/unit_tests/core/workflow/nodes/iteration/test_parallel_iteration_duration.py +++ b/api/tests/unit_tests/core/workflow/nodes/iteration/test_parallel_iteration_duration.py @@ -22,7 +22,15 @@ def test_parallel_iteration_duration_map_uses_worker_measured_time() -> None: ) node._merge_usage = lambda current, new: new if current.total_tokens == 0 else current.plus(new) - def fake_execute_single_iteration_parallel(*, index: int, item: object): + def fake_execute_tracked_iteration_parallel( + *, + index: int, + item: object, + started_child_engines: dict[int, object], + started_child_engines_lock: object, + ): + _ = started_child_engines + _ = started_child_engines_lock return ( 0.1 + (index * 0.1), [ @@ -37,7 +45,7 @@ def test_parallel_iteration_duration_map_uses_worker_measured_time() -> None: LLMUsage.empty_usage(), ) - node._execute_single_iteration_parallel = fake_execute_single_iteration_parallel + node._execute_tracked_iteration_parallel = fake_execute_tracked_iteration_parallel outputs: list[object] = [] iter_run_map: dict[str, float] = {} diff --git a/api/tests/unit_tests/core/workflow/nodes/test_loop_node.py b/api/tests/unit_tests/core/workflow/nodes/test_loop_node.py index 6372583839..77ec5ac128 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_loop_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_loop_node.py @@ -1,6 +1,22 @@ +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest + from dify_graph.entities.graph_config import NodeConfigDictAdapter +from dify_graph.enums import WorkflowNodeExecutionStatus +from dify_graph.graph_events import GraphRunAbortedEvent +from dify_graph.model_runtime.entities.llm_entities import LLMUsage +from dify_graph.node_events import LoopFailedEvent, LoopStartedEvent, StreamCompletedEvent from dify_graph.nodes.loop.entities import LoopNodeData from dify_graph.nodes.loop.loop_node import LoopNode +from tests.workflow_test_utils import build_test_variable_pool + + +def _usage_with_tokens(total_tokens: int) -> LLMUsage: + usage = LLMUsage.empty_usage() + usage.total_tokens = total_tokens + return usage def test_extract_variable_selector_to_variable_mapping_validates_child_node_configs(monkeypatch) -> None: @@ -50,3 +66,85 @@ def test_extract_variable_selector_to_variable_mapping_validates_child_node_conf ) assert seen_configs == [child_node_config] + + +def test_run_single_loop_raises_on_child_abort_event() -> None: + node = LoopNode.__new__(LoopNode) + node._node_id = "loop-node" + node._node_data = LoopNodeData( + title="Loop", + loop_count=1, + break_conditions=[], + logical_operator="and", + start_node_id="child-start", + ) + + graph_engine = SimpleNamespace( + run=lambda: iter([GraphRunAbortedEvent(reason="quota exceeded")]), + ) + + with pytest.raises(RuntimeError, match="quota exceeded"): + list(node._run_single_loop(graph_engine=graph_engine, current_index=0)) + + +def test_loop_run_fails_on_child_abort_and_stops_subsequent_rounds() -> None: + node = LoopNode.__new__(LoopNode) + node._node_id = "loop-node" + node._node_data = LoopNodeData( + title="Loop", + loop_count=2, + break_conditions=[], + logical_operator="and", + start_node_id="child-start", + ) + node.graph_config = {"nodes": [], "edges": []} + node.graph_runtime_state = SimpleNamespace( + variable_pool=build_test_variable_pool(), + llm_usage=LLMUsage.empty_usage(), + ) + + aborting_engine = SimpleNamespace( + graph_runtime_state=SimpleNamespace(outputs={}, llm_usage=LLMUsage.empty_usage()), + ) + create_graph_engine = MagicMock(return_value=aborting_engine) + node._create_graph_engine = create_graph_engine + node._run_single_loop = lambda *, graph_engine, current_index: (_ for _ in ()).throw(RuntimeError("quota exceeded")) + + events = list(node._run()) + + assert isinstance(events[0], LoopStartedEvent) + assert isinstance(events[1], LoopFailedEvent) + assert events[1].error == "quota exceeded" + assert isinstance(events[2], StreamCompletedEvent) + assert events[2].node_run_result.status == WorkflowNodeExecutionStatus.FAILED + assert events[2].node_run_result.error == "quota exceeded" + create_graph_engine.assert_called_once() + + +def test_loop_run_merges_child_usage_before_failing_on_child_abort() -> None: + node = LoopNode.__new__(LoopNode) + node._node_id = "loop-node" + node._node_data = LoopNodeData( + title="Loop", + loop_count=1, + break_conditions=[], + logical_operator="and", + start_node_id="child-start", + ) + node.graph_config = {"nodes": [], "edges": []} + node.graph_runtime_state = SimpleNamespace( + variable_pool=build_test_variable_pool(), + llm_usage=LLMUsage.empty_usage(), + ) + + aborting_engine = SimpleNamespace( + graph_runtime_state=SimpleNamespace(outputs={}, llm_usage=_usage_with_tokens(7)), + ) + node._create_graph_engine = MagicMock(return_value=aborting_engine) + node._run_single_loop = lambda *, graph_engine, current_index: (_ for _ in ()).throw(RuntimeError("quota exceeded")) + + events = list(node._run()) + + assert isinstance(events[-1], StreamCompletedEvent) + assert events[-1].node_run_result.llm_usage.total_tokens == 7 + assert node.graph_runtime_state.llm_usage.total_tokens == 7 diff --git a/api/tests/unit_tests/core/workflow/test_workflow_entry_helpers.py b/api/tests/unit_tests/core/workflow/test_workflow_entry_helpers.py index b07ac5cc2d..7df44ca49d 100644 --- a/api/tests/unit_tests/core/workflow/test_workflow_entry_helpers.py +++ b/api/tests/unit_tests/core/workflow/test_workflow_entry_helpers.py @@ -1,4 +1,5 @@ from collections import UserString +from contextlib import nullcontext from types import SimpleNamespace from unittest.mock import MagicMock, patch, sentinel @@ -6,21 +7,58 @@ import pytest from core.app.apps.exc import GenerateTaskStoppedError from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom +from core.model_manager import ModelInstance from core.workflow import workflow_entry +from dify_graph.entities.base_node_data import BaseNodeData from dify_graph.entities.graph_config import NodeConfigDictAdapter -from dify_graph.enums import NodeType +from dify_graph.enums import NodeType, WorkflowNodeExecutionStatus from dify_graph.errors import WorkflowNodeRunFailedError from dify_graph.file.enums import FileTransferMethod, FileType from dify_graph.file.models import File +from dify_graph.graph import Graph from dify_graph.graph_events import GraphRunFailedEvent +from dify_graph.model_runtime.entities.llm_entities import LLMUsage +from dify_graph.node_events import NodeRunResult from dify_graph.nodes import BuiltinNodeTypes +from dify_graph.nodes.base.node import Node from dify_graph.runtime import ChildGraphNotFoundError +from tests.workflow_test_utils import build_test_graph_init_params, build_test_variable_pool def _build_typed_node_config(node_type: NodeType): return NodeConfigDictAdapter.validate_python({"id": "node-id", "data": {"type": node_type}}) +def _build_wrapped_model_instance() -> tuple[SimpleNamespace, ModelInstance]: + raw_model_instance = ModelInstance.__new__(ModelInstance) + return SimpleNamespace(_model_instance=raw_model_instance), raw_model_instance + + +class _FakeModelNodeMixin: + @classmethod + def version(cls) -> str: + return "1" + + def post_init(self) -> None: + self.model_instance, self.raw_model_instance = _build_wrapped_model_instance() + self.usage_snapshot = LLMUsage.empty_usage() + self.usage_snapshot.total_tokens = 1 + + def _run(self) -> NodeRunResult: + return NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + llm_usage=self.usage_snapshot, + ) + + +class _FakeLLMNode(_FakeModelNodeMixin, Node[BaseNodeData]): + node_type = BuiltinNodeTypes.LLM + + +class _FakeQuestionClassifierNode(_FakeModelNodeMixin, Node[BaseNodeData]): + node_type = BuiltinNodeTypes.QUESTION_CLASSIFIER + + class TestWorkflowChildEngineBuilder: @pytest.mark.parametrize( ("graph_config", "node_id", "expected"), @@ -53,7 +91,7 @@ class TestWorkflowChildEngineBuilder: root_node_id="missing", ) - def test_build_child_engine_constructs_graph_engine_without_layers(self): + def test_build_child_engine_constructs_graph_engine_with_quota_layer_only(self): builder = workflow_entry._WorkflowChildEngineBuilder() graph_init_params = SimpleNamespace(graph_config={"nodes": [{"id": "root"}]}) parent_graph_runtime_state = SimpleNamespace( @@ -76,6 +114,7 @@ class TestWorkflowChildEngineBuilder: patch.object(workflow_entry, "GraphEngine", return_value=child_engine) as graph_engine_cls, patch.object(workflow_entry, "GraphEngineConfig", return_value=sentinel.graph_engine_config), patch.object(workflow_entry, "InMemoryChannel", return_value=sentinel.command_channel), + patch.object(workflow_entry, "LLMQuotaLayer", return_value=sentinel.llm_quota_layer), ): result = builder.build_child_engine( workflow_id="workflow-id", @@ -108,7 +147,71 @@ class TestWorkflowChildEngineBuilder: config=sentinel.graph_engine_config, child_engine_builder=builder, ) - child_engine.layer.assert_not_called() + assert child_engine.layer.call_args_list == [((sentinel.llm_quota_layer,), {})] + + @pytest.mark.parametrize("node_cls", [_FakeLLMNode, _FakeQuestionClassifierNode]) + def test_build_child_engine_runs_llm_quota_layer_for_child_model_nodes(self, node_cls): + builder = workflow_entry._WorkflowChildEngineBuilder() + graph_init_params = build_test_graph_init_params( + graph_config={"nodes": [{"id": "root"}], "edges": []}, + ) + parent_graph_runtime_state = SimpleNamespace( + execution_context=nullcontext(None), + variable_pool=build_test_variable_pool(), + ) + created_node: dict[str, _FakeLLMNode | _FakeQuestionClassifierNode] = {} + + def build_graph(*, graph_config, node_factory, root_node_id): + _ = graph_config + node = node_cls( + id=root_node_id, + config={ + "id": root_node_id, + "data": { + "type": node_cls.node_type, + "title": "Child Model", + }, + }, + graph_init_params=node_factory.graph_init_params, + graph_runtime_state=node_factory.graph_runtime_state, + ) + created_node["node"] = node + return Graph( + nodes={root_node_id: node}, + edges={}, + in_edges={}, + out_edges={}, + root_node=node, + ) + + with ( + patch.object( + workflow_entry, + "DifyNodeFactory", + side_effect=lambda graph_init_params, graph_runtime_state: SimpleNamespace( + graph_init_params=graph_init_params, + graph_runtime_state=graph_runtime_state, + ), + ), + patch.object(workflow_entry.Graph, "init", side_effect=build_graph), + patch("core.app.workflow.layers.llm_quota.ensure_llm_quota_available") as ensure_quota, + patch("core.app.workflow.layers.llm_quota.deduct_llm_quota") as deduct_quota, + ): + child_engine = builder.build_child_engine( + workflow_id="workflow-id", + graph_init_params=graph_init_params, + parent_graph_runtime_state=parent_graph_runtime_state, + root_node_id="root", + ) + list(child_engine.run()) + + node = created_node["node"] + ensure_quota.assert_called_once_with(model_instance=node.raw_model_instance) + deduct_quota.assert_called_once_with( + tenant_id="tenant", + model_instance=node.raw_model_instance, + usage=node.usage_snapshot, + ) class TestWorkflowEntryInit: