mirror of
https://github.com/langgenius/dify.git
synced 2026-03-29 08:00:43 -04:00
feat(graph_engine): Support pausing workflow graph executions (#26585)
Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
@@ -9,7 +9,7 @@ Each instance uses a unique key for its command queue.
|
||||
import json
|
||||
from typing import TYPE_CHECKING, Any, final
|
||||
|
||||
from ..entities.commands import AbortCommand, CommandType, GraphEngineCommand
|
||||
from ..entities.commands import AbortCommand, CommandType, GraphEngineCommand, PauseCommand
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from extensions.ext_redis import RedisClientWrapper
|
||||
@@ -111,9 +111,11 @@ class RedisChannel:
|
||||
|
||||
if command_type == CommandType.ABORT:
|
||||
return AbortCommand.model_validate(data)
|
||||
else:
|
||||
# For other command types, use base class
|
||||
return GraphEngineCommand.model_validate(data)
|
||||
if command_type == CommandType.PAUSE:
|
||||
return PauseCommand.model_validate(data)
|
||||
|
||||
# For other command types, use base class
|
||||
return GraphEngineCommand.model_validate(data)
|
||||
|
||||
except (ValueError, TypeError):
|
||||
return None
|
||||
|
||||
@@ -5,10 +5,11 @@ This package handles external commands sent to the engine
|
||||
during execution.
|
||||
"""
|
||||
|
||||
from .command_handlers import AbortCommandHandler
|
||||
from .command_handlers import AbortCommandHandler, PauseCommandHandler
|
||||
from .command_processor import CommandProcessor
|
||||
|
||||
__all__ = [
|
||||
"AbortCommandHandler",
|
||||
"CommandProcessor",
|
||||
"PauseCommandHandler",
|
||||
]
|
||||
|
||||
@@ -1,14 +1,10 @@
|
||||
"""
|
||||
Command handler implementations.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import final
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
from ..domain.graph_execution import GraphExecution
|
||||
from ..entities.commands import AbortCommand, GraphEngineCommand
|
||||
from ..entities.commands import AbortCommand, GraphEngineCommand, PauseCommand
|
||||
from .command_processor import CommandHandler
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -16,17 +12,17 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
@final
|
||||
class AbortCommandHandler(CommandHandler):
|
||||
"""Handles abort commands."""
|
||||
|
||||
@override
|
||||
def handle(self, command: GraphEngineCommand, execution: GraphExecution) -> None:
|
||||
"""
|
||||
Handle an abort command.
|
||||
|
||||
Args:
|
||||
command: The abort command
|
||||
execution: Graph execution to abort
|
||||
"""
|
||||
assert isinstance(command, AbortCommand)
|
||||
logger.debug("Aborting workflow %s: %s", execution.workflow_id, command.reason)
|
||||
execution.abort(command.reason or "User requested abort")
|
||||
|
||||
|
||||
@final
|
||||
class PauseCommandHandler(CommandHandler):
|
||||
@override
|
||||
def handle(self, command: GraphEngineCommand, execution: GraphExecution) -> None:
|
||||
assert isinstance(command, PauseCommand)
|
||||
logger.debug("Pausing workflow %s: %s", execution.workflow_id, command.reason)
|
||||
execution.pause(command.reason)
|
||||
|
||||
@@ -40,6 +40,8 @@ class GraphExecutionState(BaseModel):
|
||||
started: bool = Field(default=False)
|
||||
completed: bool = Field(default=False)
|
||||
aborted: bool = Field(default=False)
|
||||
paused: bool = Field(default=False)
|
||||
pause_reason: str | None = Field(default=None)
|
||||
error: GraphExecutionErrorState | None = Field(default=None)
|
||||
exceptions_count: int = Field(default=0)
|
||||
node_executions: list[NodeExecutionState] = Field(default_factory=list[NodeExecutionState])
|
||||
@@ -103,6 +105,8 @@ class GraphExecution:
|
||||
started: bool = False
|
||||
completed: bool = False
|
||||
aborted: bool = False
|
||||
paused: bool = False
|
||||
pause_reason: str | None = None
|
||||
error: Exception | None = None
|
||||
node_executions: dict[str, NodeExecution] = field(default_factory=dict[str, NodeExecution])
|
||||
exceptions_count: int = 0
|
||||
@@ -126,6 +130,17 @@ class GraphExecution:
|
||||
self.aborted = True
|
||||
self.error = RuntimeError(f"Aborted: {reason}")
|
||||
|
||||
def pause(self, reason: str | None = None) -> None:
|
||||
"""Pause the graph execution without marking it complete."""
|
||||
if self.completed:
|
||||
raise RuntimeError("Cannot pause execution that has completed")
|
||||
if self.aborted:
|
||||
raise RuntimeError("Cannot pause execution that has been aborted")
|
||||
if self.paused:
|
||||
return
|
||||
self.paused = True
|
||||
self.pause_reason = reason
|
||||
|
||||
def fail(self, error: Exception) -> None:
|
||||
"""Mark the graph execution as failed."""
|
||||
self.error = error
|
||||
@@ -140,7 +155,12 @@ class GraphExecution:
|
||||
@property
|
||||
def is_running(self) -> bool:
|
||||
"""Check if the execution is currently running."""
|
||||
return self.started and not self.completed and not self.aborted
|
||||
return self.started and not self.completed and not self.aborted and not self.paused
|
||||
|
||||
@property
|
||||
def is_paused(self) -> bool:
|
||||
"""Check if the execution is currently paused."""
|
||||
return self.paused
|
||||
|
||||
@property
|
||||
def has_error(self) -> bool:
|
||||
@@ -173,6 +193,8 @@ class GraphExecution:
|
||||
started=self.started,
|
||||
completed=self.completed,
|
||||
aborted=self.aborted,
|
||||
paused=self.paused,
|
||||
pause_reason=self.pause_reason,
|
||||
error=_serialize_error(self.error),
|
||||
exceptions_count=self.exceptions_count,
|
||||
node_executions=node_states,
|
||||
@@ -197,6 +219,8 @@ class GraphExecution:
|
||||
self.started = state.started
|
||||
self.completed = state.completed
|
||||
self.aborted = state.aborted
|
||||
self.paused = state.paused
|
||||
self.pause_reason = state.pause_reason
|
||||
self.error = _deserialize_error(state.error)
|
||||
self.exceptions_count = state.exceptions_count
|
||||
self.node_executions = {
|
||||
|
||||
@@ -16,7 +16,6 @@ class CommandType(StrEnum):
|
||||
|
||||
ABORT = "abort"
|
||||
PAUSE = "pause"
|
||||
RESUME = "resume"
|
||||
|
||||
|
||||
class GraphEngineCommand(BaseModel):
|
||||
@@ -31,3 +30,10 @@ class AbortCommand(GraphEngineCommand):
|
||||
|
||||
command_type: CommandType = Field(default=CommandType.ABORT, description="Type of command")
|
||||
reason: str | None = Field(default=None, description="Optional reason for abort")
|
||||
|
||||
|
||||
class PauseCommand(GraphEngineCommand):
|
||||
"""Command to pause a running workflow execution."""
|
||||
|
||||
command_type: CommandType = Field(default=CommandType.PAUSE, description="Type of command")
|
||||
reason: str | None = Field(default=None, description="Optional reason for pause")
|
||||
|
||||
@@ -8,8 +8,7 @@ from functools import singledispatchmethod
|
||||
from typing import TYPE_CHECKING, final
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.workflow.entities import GraphRuntimeState
|
||||
from core.workflow.enums import ErrorStrategy, NodeExecutionType
|
||||
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeState
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_events import (
|
||||
GraphNodeEventBase,
|
||||
@@ -24,11 +23,13 @@ from core.workflow.graph_events import (
|
||||
NodeRunLoopNextEvent,
|
||||
NodeRunLoopStartedEvent,
|
||||
NodeRunLoopSucceededEvent,
|
||||
NodeRunPauseRequestedEvent,
|
||||
NodeRunRetryEvent,
|
||||
NodeRunStartedEvent,
|
||||
NodeRunStreamChunkEvent,
|
||||
NodeRunSucceededEvent,
|
||||
)
|
||||
from core.workflow.runtime import GraphRuntimeState
|
||||
|
||||
from ..domain.graph_execution import GraphExecution
|
||||
from ..response_coordinator import ResponseStreamCoordinator
|
||||
@@ -203,6 +204,18 @@ class EventHandler:
|
||||
# Collect the event
|
||||
self._event_collector.collect(event)
|
||||
|
||||
@_dispatch.register
|
||||
def _(self, event: NodeRunPauseRequestedEvent) -> None:
|
||||
"""Handle pause requests emitted by nodes."""
|
||||
|
||||
pause_reason = event.reason or "Awaiting human input"
|
||||
self._graph_execution.pause(pause_reason)
|
||||
self._state_manager.finish_execution(event.node_id)
|
||||
if event.node_id in self._graph.nodes:
|
||||
self._graph.nodes[event.node_id].state = NodeState.UNKNOWN
|
||||
self._graph_runtime_state.register_paused_node(event.node_id)
|
||||
self._event_collector.collect(event)
|
||||
|
||||
@_dispatch.register
|
||||
def _(self, event: NodeRunFailedEvent) -> None:
|
||||
"""
|
||||
|
||||
@@ -97,6 +97,10 @@ class EventManager:
|
||||
"""
|
||||
self._layers = layers
|
||||
|
||||
def notify_layers(self, event: GraphEngineEvent) -> None:
|
||||
"""Notify registered layers about an event without buffering it."""
|
||||
self._notify_layers(event)
|
||||
|
||||
def collect(self, event: GraphEngineEvent) -> None:
|
||||
"""
|
||||
Thread-safe method to collect an event.
|
||||
|
||||
@@ -9,28 +9,29 @@ import contextvars
|
||||
import logging
|
||||
import queue
|
||||
from collections.abc import Generator
|
||||
from typing import final
|
||||
from typing import TYPE_CHECKING, cast, final
|
||||
|
||||
from flask import Flask, current_app
|
||||
|
||||
from core.workflow.entities import GraphRuntimeState
|
||||
from core.workflow.enums import NodeExecutionType
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph.read_only_state_wrapper import ReadOnlyGraphRuntimeStateWrapper
|
||||
from core.workflow.graph_engine.ready_queue import InMemoryReadyQueue
|
||||
from core.workflow.graph_events import (
|
||||
GraphEngineEvent,
|
||||
GraphNodeEventBase,
|
||||
GraphRunAbortedEvent,
|
||||
GraphRunFailedEvent,
|
||||
GraphRunPartialSucceededEvent,
|
||||
GraphRunPausedEvent,
|
||||
GraphRunStartedEvent,
|
||||
GraphRunSucceededEvent,
|
||||
)
|
||||
from core.workflow.runtime import GraphRuntimeState, ReadOnlyGraphRuntimeStateWrapper
|
||||
|
||||
from .command_processing import AbortCommandHandler, CommandProcessor
|
||||
from .domain import GraphExecution
|
||||
from .entities.commands import AbortCommand
|
||||
if TYPE_CHECKING: # pragma: no cover - used only for static analysis
|
||||
from core.workflow.runtime.graph_runtime_state import GraphProtocol
|
||||
|
||||
from .command_processing import AbortCommandHandler, CommandProcessor, PauseCommandHandler
|
||||
from .entities.commands import AbortCommand, PauseCommand
|
||||
from .error_handler import ErrorHandler
|
||||
from .event_management import EventHandler, EventManager
|
||||
from .graph_state_manager import GraphStateManager
|
||||
@@ -38,10 +39,13 @@ from .graph_traversal import EdgeProcessor, SkipPropagator
|
||||
from .layers.base import GraphEngineLayer
|
||||
from .orchestration import Dispatcher, ExecutionCoordinator
|
||||
from .protocols.command_channel import CommandChannel
|
||||
from .ready_queue import ReadyQueue, ReadyQueueState, create_ready_queue_from_state
|
||||
from .response_coordinator import ResponseStreamCoordinator
|
||||
from .ready_queue import ReadyQueue
|
||||
from .worker_management import WorkerPool
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.workflow.graph_engine.domain.graph_execution import GraphExecution
|
||||
from core.workflow.graph_engine.response_coordinator import ResponseStreamCoordinator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -67,17 +71,16 @@ class GraphEngine:
|
||||
) -> None:
|
||||
"""Initialize the graph engine with all subsystems and dependencies."""
|
||||
|
||||
# Graph execution tracks the overall execution state
|
||||
self._graph_execution = GraphExecution(workflow_id=workflow_id)
|
||||
if graph_runtime_state.graph_execution_json != "":
|
||||
self._graph_execution.loads(graph_runtime_state.graph_execution_json)
|
||||
|
||||
# === Core Dependencies ===
|
||||
# Graph structure and configuration
|
||||
# Bind runtime state to current workflow context
|
||||
self._graph = graph
|
||||
self._graph_runtime_state = graph_runtime_state
|
||||
self._graph_runtime_state.configure(graph=cast("GraphProtocol", graph))
|
||||
self._command_channel = command_channel
|
||||
|
||||
# Graph execution tracks the overall execution state
|
||||
self._graph_execution = cast("GraphExecution", self._graph_runtime_state.graph_execution)
|
||||
self._graph_execution.workflow_id = workflow_id
|
||||
|
||||
# === Worker Management Parameters ===
|
||||
# Parameters for dynamic worker pool scaling
|
||||
self._min_workers = min_workers
|
||||
@@ -86,13 +89,7 @@ class GraphEngine:
|
||||
self._scale_down_idle_time = scale_down_idle_time
|
||||
|
||||
# === Execution Queues ===
|
||||
# Create ready queue from saved state or initialize new one
|
||||
self._ready_queue: ReadyQueue
|
||||
if self._graph_runtime_state.ready_queue_json == "":
|
||||
self._ready_queue = InMemoryReadyQueue()
|
||||
else:
|
||||
ready_queue_state = ReadyQueueState.model_validate_json(self._graph_runtime_state.ready_queue_json)
|
||||
self._ready_queue = create_ready_queue_from_state(ready_queue_state)
|
||||
self._ready_queue = cast(ReadyQueue, self._graph_runtime_state.ready_queue)
|
||||
|
||||
# Queue for events generated during execution
|
||||
self._event_queue: queue.Queue[GraphNodeEventBase] = queue.Queue()
|
||||
@@ -103,11 +100,7 @@ class GraphEngine:
|
||||
|
||||
# === Response Coordination ===
|
||||
# Coordinates response streaming from response nodes
|
||||
self._response_coordinator = ResponseStreamCoordinator(
|
||||
variable_pool=self._graph_runtime_state.variable_pool, graph=self._graph
|
||||
)
|
||||
if graph_runtime_state.response_coordinator_json != "":
|
||||
self._response_coordinator.loads(graph_runtime_state.response_coordinator_json)
|
||||
self._response_coordinator = cast("ResponseStreamCoordinator", self._graph_runtime_state.response_coordinator)
|
||||
|
||||
# === Event Management ===
|
||||
# Event manager handles both collection and emission of events
|
||||
@@ -133,19 +126,6 @@ class GraphEngine:
|
||||
skip_propagator=self._skip_propagator,
|
||||
)
|
||||
|
||||
# === Event Handler Registry ===
|
||||
# Central registry for handling all node execution events
|
||||
self._event_handler_registry = EventHandler(
|
||||
graph=self._graph,
|
||||
graph_runtime_state=self._graph_runtime_state,
|
||||
graph_execution=self._graph_execution,
|
||||
response_coordinator=self._response_coordinator,
|
||||
event_collector=self._event_manager,
|
||||
edge_processor=self._edge_processor,
|
||||
state_manager=self._state_manager,
|
||||
error_handler=self._error_handler,
|
||||
)
|
||||
|
||||
# === Command Processing ===
|
||||
# Processes external commands (e.g., abort requests)
|
||||
self._command_processor = CommandProcessor(
|
||||
@@ -153,12 +133,12 @@ class GraphEngine:
|
||||
graph_execution=self._graph_execution,
|
||||
)
|
||||
|
||||
# Register abort command handler
|
||||
# Register command handlers
|
||||
abort_handler = AbortCommandHandler()
|
||||
self._command_processor.register_handler(
|
||||
AbortCommand,
|
||||
abort_handler,
|
||||
)
|
||||
self._command_processor.register_handler(AbortCommand, abort_handler)
|
||||
|
||||
pause_handler = PauseCommandHandler()
|
||||
self._command_processor.register_handler(PauseCommand, pause_handler)
|
||||
|
||||
# === Worker Pool Setup ===
|
||||
# Capture Flask app context for worker threads
|
||||
@@ -191,12 +171,23 @@ class GraphEngine:
|
||||
self._execution_coordinator = ExecutionCoordinator(
|
||||
graph_execution=self._graph_execution,
|
||||
state_manager=self._state_manager,
|
||||
event_handler=self._event_handler_registry,
|
||||
event_collector=self._event_manager,
|
||||
command_processor=self._command_processor,
|
||||
worker_pool=self._worker_pool,
|
||||
)
|
||||
|
||||
# === Event Handler Registry ===
|
||||
# Central registry for handling all node execution events
|
||||
self._event_handler_registry = EventHandler(
|
||||
graph=self._graph,
|
||||
graph_runtime_state=self._graph_runtime_state,
|
||||
graph_execution=self._graph_execution,
|
||||
response_coordinator=self._response_coordinator,
|
||||
event_collector=self._event_manager,
|
||||
edge_processor=self._edge_processor,
|
||||
state_manager=self._state_manager,
|
||||
error_handler=self._error_handler,
|
||||
)
|
||||
|
||||
# Dispatches events and manages execution flow
|
||||
self._dispatcher = Dispatcher(
|
||||
event_queue=self._event_queue,
|
||||
@@ -237,26 +228,41 @@ class GraphEngine:
|
||||
# Initialize layers
|
||||
self._initialize_layers()
|
||||
|
||||
# Start execution
|
||||
self._graph_execution.start()
|
||||
is_resume = self._graph_execution.started
|
||||
if not is_resume:
|
||||
self._graph_execution.start()
|
||||
else:
|
||||
self._graph_execution.paused = False
|
||||
self._graph_execution.pause_reason = None
|
||||
|
||||
start_event = GraphRunStartedEvent()
|
||||
self._event_manager.notify_layers(start_event)
|
||||
yield start_event
|
||||
|
||||
# Start subsystems
|
||||
self._start_execution()
|
||||
self._start_execution(resume=is_resume)
|
||||
|
||||
# Yield events as they occur
|
||||
yield from self._event_manager.emit_events()
|
||||
|
||||
# Handle completion
|
||||
if self._graph_execution.aborted:
|
||||
if self._graph_execution.is_paused:
|
||||
paused_event = GraphRunPausedEvent(
|
||||
reason=self._graph_execution.pause_reason,
|
||||
outputs=self._graph_runtime_state.outputs,
|
||||
)
|
||||
self._event_manager.notify_layers(paused_event)
|
||||
yield paused_event
|
||||
elif self._graph_execution.aborted:
|
||||
abort_reason = "Workflow execution aborted by user command"
|
||||
if self._graph_execution.error:
|
||||
abort_reason = str(self._graph_execution.error)
|
||||
yield GraphRunAbortedEvent(
|
||||
aborted_event = GraphRunAbortedEvent(
|
||||
reason=abort_reason,
|
||||
outputs=self._graph_runtime_state.outputs,
|
||||
)
|
||||
self._event_manager.notify_layers(aborted_event)
|
||||
yield aborted_event
|
||||
elif self._graph_execution.has_error:
|
||||
if self._graph_execution.error:
|
||||
raise self._graph_execution.error
|
||||
@@ -264,20 +270,26 @@ class GraphEngine:
|
||||
outputs = self._graph_runtime_state.outputs
|
||||
exceptions_count = self._graph_execution.exceptions_count
|
||||
if exceptions_count > 0:
|
||||
yield GraphRunPartialSucceededEvent(
|
||||
partial_event = GraphRunPartialSucceededEvent(
|
||||
exceptions_count=exceptions_count,
|
||||
outputs=outputs,
|
||||
)
|
||||
self._event_manager.notify_layers(partial_event)
|
||||
yield partial_event
|
||||
else:
|
||||
yield GraphRunSucceededEvent(
|
||||
succeeded_event = GraphRunSucceededEvent(
|
||||
outputs=outputs,
|
||||
)
|
||||
self._event_manager.notify_layers(succeeded_event)
|
||||
yield succeeded_event
|
||||
|
||||
except Exception as e:
|
||||
yield GraphRunFailedEvent(
|
||||
failed_event = GraphRunFailedEvent(
|
||||
error=str(e),
|
||||
exceptions_count=self._graph_execution.exceptions_count,
|
||||
)
|
||||
self._event_manager.notify_layers(failed_event)
|
||||
yield failed_event
|
||||
raise
|
||||
|
||||
finally:
|
||||
@@ -299,8 +311,12 @@ class GraphEngine:
|
||||
except Exception as e:
|
||||
logger.warning("Layer %s failed on_graph_start: %s", layer.__class__.__name__, e)
|
||||
|
||||
def _start_execution(self) -> None:
|
||||
def _start_execution(self, *, resume: bool = False) -> None:
|
||||
"""Start execution subsystems."""
|
||||
paused_nodes: list[str] = []
|
||||
if resume:
|
||||
paused_nodes = self._graph_runtime_state.consume_paused_nodes()
|
||||
|
||||
# Start worker pool (it calculates initial workers internally)
|
||||
self._worker_pool.start()
|
||||
|
||||
@@ -309,10 +325,15 @@ class GraphEngine:
|
||||
if node.execution_type == NodeExecutionType.RESPONSE:
|
||||
self._response_coordinator.register(node.id)
|
||||
|
||||
# Enqueue root node
|
||||
root_node = self._graph.root_node
|
||||
self._state_manager.enqueue_node(root_node.id)
|
||||
self._state_manager.start_execution(root_node.id)
|
||||
if not resume:
|
||||
# Enqueue root node
|
||||
root_node = self._graph.root_node
|
||||
self._state_manager.enqueue_node(root_node.id)
|
||||
self._state_manager.start_execution(root_node.id)
|
||||
else:
|
||||
for node_id in paused_nodes:
|
||||
self._state_manager.enqueue_node(node_id)
|
||||
self._state_manager.start_execution(node_id)
|
||||
|
||||
# Start dispatcher
|
||||
self._dispatcher.start()
|
||||
|
||||
@@ -7,9 +7,9 @@ intercept and respond to GraphEngine events.
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from core.workflow.graph.graph_runtime_state_protocol import ReadOnlyGraphRuntimeState
|
||||
from core.workflow.graph_engine.protocols.command_channel import CommandChannel
|
||||
from core.workflow.graph_events import GraphEngineEvent
|
||||
from core.workflow.runtime import ReadOnlyGraphRuntimeState
|
||||
|
||||
|
||||
class GraphEngineLayer(ABC):
|
||||
|
||||
410
api/core/workflow/graph_engine/layers/persistence.py
Normal file
410
api/core/workflow/graph_engine/layers/persistence.py
Normal file
@@ -0,0 +1,410 @@
|
||||
"""Workflow persistence layer for GraphEngine.
|
||||
|
||||
This layer mirrors the former ``WorkflowCycleManager`` responsibilities by
|
||||
listening to ``GraphEngineEvent`` instances directly and persisting workflow
|
||||
and node execution state via the injected repositories.
|
||||
|
||||
The design keeps domain persistence concerns inside the engine thread, while
|
||||
allowing presentation layers to remain read-only observers of repository
|
||||
state.
|
||||
"""
|
||||
|
||||
from collections.abc import Mapping
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import Any, Union
|
||||
|
||||
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity
|
||||
from core.ops.entities.trace_entity import TraceTaskName
|
||||
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
|
||||
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
|
||||
from core.workflow.entities import WorkflowExecution, WorkflowNodeExecution
|
||||
from core.workflow.enums import (
|
||||
SystemVariableKey,
|
||||
WorkflowExecutionStatus,
|
||||
WorkflowNodeExecutionMetadataKey,
|
||||
WorkflowNodeExecutionStatus,
|
||||
WorkflowType,
|
||||
)
|
||||
from core.workflow.graph_engine.layers.base import GraphEngineLayer
|
||||
from core.workflow.graph_events import (
|
||||
GraphEngineEvent,
|
||||
GraphRunAbortedEvent,
|
||||
GraphRunFailedEvent,
|
||||
GraphRunPartialSucceededEvent,
|
||||
GraphRunPausedEvent,
|
||||
GraphRunStartedEvent,
|
||||
GraphRunSucceededEvent,
|
||||
NodeRunExceptionEvent,
|
||||
NodeRunFailedEvent,
|
||||
NodeRunPauseRequestedEvent,
|
||||
NodeRunRetryEvent,
|
||||
NodeRunStartedEvent,
|
||||
NodeRunSucceededEvent,
|
||||
)
|
||||
from core.workflow.node_events import NodeRunResult
|
||||
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
|
||||
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||
from core.workflow.workflow_entry import WorkflowEntry
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class PersistenceWorkflowInfo:
|
||||
"""Static workflow metadata required for persistence."""
|
||||
|
||||
workflow_id: str
|
||||
workflow_type: WorkflowType
|
||||
version: str
|
||||
graph_data: Mapping[str, Any]
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class _NodeRuntimeSnapshot:
|
||||
"""Lightweight cache to keep node metadata across event phases."""
|
||||
|
||||
node_id: str
|
||||
title: str
|
||||
predecessor_node_id: str | None
|
||||
iteration_id: str | None
|
||||
loop_id: str | None
|
||||
created_at: datetime
|
||||
|
||||
|
||||
class WorkflowPersistenceLayer(GraphEngineLayer):
|
||||
"""GraphEngine layer that persists workflow and node execution state."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity],
|
||||
workflow_info: PersistenceWorkflowInfo,
|
||||
workflow_execution_repository: WorkflowExecutionRepository,
|
||||
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
|
||||
trace_manager: TraceQueueManager | None = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self._application_generate_entity = application_generate_entity
|
||||
self._workflow_info = workflow_info
|
||||
self._workflow_execution_repository = workflow_execution_repository
|
||||
self._workflow_node_execution_repository = workflow_node_execution_repository
|
||||
self._trace_manager = trace_manager
|
||||
|
||||
self._workflow_execution: WorkflowExecution | None = None
|
||||
self._node_execution_cache: dict[str, WorkflowNodeExecution] = {}
|
||||
self._node_snapshots: dict[str, _NodeRuntimeSnapshot] = {}
|
||||
self._node_sequence: int = 0
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# GraphEngineLayer lifecycle
|
||||
# ------------------------------------------------------------------
|
||||
def on_graph_start(self) -> None:
|
||||
self._workflow_execution = None
|
||||
self._node_execution_cache.clear()
|
||||
self._node_snapshots.clear()
|
||||
self._node_sequence = 0
|
||||
|
||||
def on_event(self, event: GraphEngineEvent) -> None:
|
||||
if isinstance(event, GraphRunStartedEvent):
|
||||
self._handle_graph_run_started()
|
||||
return
|
||||
|
||||
if isinstance(event, GraphRunSucceededEvent):
|
||||
self._handle_graph_run_succeeded(event)
|
||||
return
|
||||
|
||||
if isinstance(event, GraphRunPartialSucceededEvent):
|
||||
self._handle_graph_run_partial_succeeded(event)
|
||||
return
|
||||
|
||||
if isinstance(event, GraphRunFailedEvent):
|
||||
self._handle_graph_run_failed(event)
|
||||
return
|
||||
|
||||
if isinstance(event, GraphRunAbortedEvent):
|
||||
self._handle_graph_run_aborted(event)
|
||||
return
|
||||
|
||||
if isinstance(event, GraphRunPausedEvent):
|
||||
self._handle_graph_run_paused(event)
|
||||
return
|
||||
|
||||
if isinstance(event, NodeRunStartedEvent):
|
||||
self._handle_node_started(event)
|
||||
return
|
||||
|
||||
if isinstance(event, NodeRunRetryEvent):
|
||||
self._handle_node_retry(event)
|
||||
return
|
||||
|
||||
if isinstance(event, NodeRunSucceededEvent):
|
||||
self._handle_node_succeeded(event)
|
||||
return
|
||||
|
||||
if isinstance(event, NodeRunFailedEvent):
|
||||
self._handle_node_failed(event)
|
||||
return
|
||||
|
||||
if isinstance(event, NodeRunExceptionEvent):
|
||||
self._handle_node_exception(event)
|
||||
return
|
||||
|
||||
if isinstance(event, NodeRunPauseRequestedEvent):
|
||||
self._handle_node_pause_requested(event)
|
||||
|
||||
def on_graph_end(self, error: Exception | None) -> None:
|
||||
return
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Graph-level handlers
|
||||
# ------------------------------------------------------------------
|
||||
def _handle_graph_run_started(self) -> None:
|
||||
execution_id = self._get_execution_id()
|
||||
workflow_execution = WorkflowExecution.new(
|
||||
id_=execution_id,
|
||||
workflow_id=self._workflow_info.workflow_id,
|
||||
workflow_type=self._workflow_info.workflow_type,
|
||||
workflow_version=self._workflow_info.version,
|
||||
graph=self._workflow_info.graph_data,
|
||||
inputs=self._prepare_workflow_inputs(),
|
||||
started_at=naive_utc_now(),
|
||||
)
|
||||
|
||||
self._workflow_execution_repository.save(workflow_execution)
|
||||
self._workflow_execution = workflow_execution
|
||||
|
||||
def _handle_graph_run_succeeded(self, event: GraphRunSucceededEvent) -> None:
|
||||
execution = self._get_workflow_execution()
|
||||
execution.outputs = event.outputs
|
||||
execution.status = WorkflowExecutionStatus.SUCCEEDED
|
||||
self._populate_completion_statistics(execution)
|
||||
|
||||
self._workflow_execution_repository.save(execution)
|
||||
self._enqueue_trace_task(execution)
|
||||
|
||||
def _handle_graph_run_partial_succeeded(self, event: GraphRunPartialSucceededEvent) -> None:
|
||||
execution = self._get_workflow_execution()
|
||||
execution.outputs = event.outputs
|
||||
execution.status = WorkflowExecutionStatus.PARTIAL_SUCCEEDED
|
||||
execution.exceptions_count = event.exceptions_count
|
||||
self._populate_completion_statistics(execution)
|
||||
|
||||
self._workflow_execution_repository.save(execution)
|
||||
self._enqueue_trace_task(execution)
|
||||
|
||||
def _handle_graph_run_failed(self, event: GraphRunFailedEvent) -> None:
|
||||
execution = self._get_workflow_execution()
|
||||
execution.status = WorkflowExecutionStatus.FAILED
|
||||
execution.error_message = event.error
|
||||
execution.exceptions_count = event.exceptions_count
|
||||
self._populate_completion_statistics(execution)
|
||||
|
||||
self._fail_running_node_executions(error_message=event.error)
|
||||
self._workflow_execution_repository.save(execution)
|
||||
self._enqueue_trace_task(execution)
|
||||
|
||||
def _handle_graph_run_aborted(self, event: GraphRunAbortedEvent) -> None:
|
||||
execution = self._get_workflow_execution()
|
||||
execution.status = WorkflowExecutionStatus.STOPPED
|
||||
execution.error_message = event.reason or "Workflow execution aborted"
|
||||
self._populate_completion_statistics(execution)
|
||||
|
||||
self._fail_running_node_executions(error_message=execution.error_message or "")
|
||||
self._workflow_execution_repository.save(execution)
|
||||
self._enqueue_trace_task(execution)
|
||||
|
||||
def _handle_graph_run_paused(self, event: GraphRunPausedEvent) -> None:
|
||||
execution = self._get_workflow_execution()
|
||||
execution.status = WorkflowExecutionStatus.PAUSED
|
||||
execution.error_message = event.reason or "Workflow execution paused"
|
||||
execution.outputs = event.outputs
|
||||
self._populate_completion_statistics(execution, update_finished=False)
|
||||
|
||||
self._workflow_execution_repository.save(execution)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Node-level handlers
|
||||
# ------------------------------------------------------------------
|
||||
def _handle_node_started(self, event: NodeRunStartedEvent) -> None:
|
||||
execution = self._get_workflow_execution()
|
||||
|
||||
metadata = {
|
||||
WorkflowNodeExecutionMetadataKey.ITERATION_ID: event.in_iteration_id,
|
||||
WorkflowNodeExecutionMetadataKey.LOOP_ID: event.in_loop_id,
|
||||
}
|
||||
|
||||
domain_execution = WorkflowNodeExecution(
|
||||
id=event.id,
|
||||
node_execution_id=event.id,
|
||||
workflow_id=execution.workflow_id,
|
||||
workflow_execution_id=execution.id_,
|
||||
predecessor_node_id=event.predecessor_node_id,
|
||||
index=self._next_node_sequence(),
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type,
|
||||
title=event.node_title,
|
||||
status=WorkflowNodeExecutionStatus.RUNNING,
|
||||
metadata=metadata,
|
||||
created_at=event.start_at,
|
||||
)
|
||||
|
||||
self._node_execution_cache[event.id] = domain_execution
|
||||
self._workflow_node_execution_repository.save(domain_execution)
|
||||
|
||||
snapshot = _NodeRuntimeSnapshot(
|
||||
node_id=event.node_id,
|
||||
title=event.node_title,
|
||||
predecessor_node_id=event.predecessor_node_id,
|
||||
iteration_id=event.in_iteration_id,
|
||||
loop_id=event.in_loop_id,
|
||||
created_at=event.start_at,
|
||||
)
|
||||
self._node_snapshots[event.id] = snapshot
|
||||
|
||||
def _handle_node_retry(self, event: NodeRunRetryEvent) -> None:
|
||||
domain_execution = self._get_node_execution(event.id)
|
||||
domain_execution.status = WorkflowNodeExecutionStatus.RETRY
|
||||
domain_execution.error = event.error
|
||||
self._workflow_node_execution_repository.save(domain_execution)
|
||||
self._workflow_node_execution_repository.save_execution_data(domain_execution)
|
||||
|
||||
def _handle_node_succeeded(self, event: NodeRunSucceededEvent) -> None:
|
||||
domain_execution = self._get_node_execution(event.id)
|
||||
self._update_node_execution(domain_execution, event.node_run_result, WorkflowNodeExecutionStatus.SUCCEEDED)
|
||||
|
||||
def _handle_node_failed(self, event: NodeRunFailedEvent) -> None:
|
||||
domain_execution = self._get_node_execution(event.id)
|
||||
self._update_node_execution(
|
||||
domain_execution,
|
||||
event.node_run_result,
|
||||
WorkflowNodeExecutionStatus.FAILED,
|
||||
error=event.error,
|
||||
)
|
||||
|
||||
def _handle_node_exception(self, event: NodeRunExceptionEvent) -> None:
|
||||
domain_execution = self._get_node_execution(event.id)
|
||||
self._update_node_execution(
|
||||
domain_execution,
|
||||
event.node_run_result,
|
||||
WorkflowNodeExecutionStatus.EXCEPTION,
|
||||
error=event.error,
|
||||
)
|
||||
|
||||
def _handle_node_pause_requested(self, event: NodeRunPauseRequestedEvent) -> None:
|
||||
domain_execution = self._get_node_execution(event.id)
|
||||
self._update_node_execution(
|
||||
domain_execution,
|
||||
event.node_run_result,
|
||||
WorkflowNodeExecutionStatus.PAUSED,
|
||||
error=event.reason,
|
||||
update_outputs=False,
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ------------------------------------------------------------------
|
||||
def _get_execution_id(self) -> str:
|
||||
workflow_execution_id = self._system_variables().get(SystemVariableKey.WORKFLOW_EXECUTION_ID)
|
||||
if not workflow_execution_id:
|
||||
raise ValueError("workflow_execution_id must be provided in system variables for pause/resume flows")
|
||||
return str(workflow_execution_id)
|
||||
|
||||
def _prepare_workflow_inputs(self) -> Mapping[str, Any]:
|
||||
inputs = {**self._application_generate_entity.inputs}
|
||||
for field_name, value in self._system_variables().items():
|
||||
if field_name == SystemVariableKey.CONVERSATION_ID.value:
|
||||
# Conversation IDs are tied to the current session; omit them so persisted
|
||||
# workflow inputs stay reusable without binding future runs to this conversation.
|
||||
continue
|
||||
inputs[f"sys.{field_name}"] = value
|
||||
handled = WorkflowEntry.handle_special_values(inputs)
|
||||
return handled or {}
|
||||
|
||||
def _get_workflow_execution(self) -> WorkflowExecution:
|
||||
if self._workflow_execution is None:
|
||||
raise ValueError("workflow execution not initialized")
|
||||
return self._workflow_execution
|
||||
|
||||
def _get_node_execution(self, node_execution_id: str) -> WorkflowNodeExecution:
|
||||
if node_execution_id not in self._node_execution_cache:
|
||||
raise ValueError(f"Node execution not found for id={node_execution_id}")
|
||||
return self._node_execution_cache[node_execution_id]
|
||||
|
||||
def _next_node_sequence(self) -> int:
|
||||
self._node_sequence += 1
|
||||
return self._node_sequence
|
||||
|
||||
def _populate_completion_statistics(self, execution: WorkflowExecution, *, update_finished: bool = True) -> None:
|
||||
if update_finished:
|
||||
execution.finished_at = naive_utc_now()
|
||||
runtime_state = self.graph_runtime_state
|
||||
if runtime_state is None:
|
||||
return
|
||||
execution.total_tokens = runtime_state.total_tokens
|
||||
execution.total_steps = runtime_state.node_run_steps
|
||||
execution.outputs = execution.outputs or runtime_state.outputs
|
||||
execution.exceptions_count = runtime_state.exceptions_count
|
||||
|
||||
def _update_node_execution(
|
||||
self,
|
||||
domain_execution: WorkflowNodeExecution,
|
||||
node_result: NodeRunResult,
|
||||
status: WorkflowNodeExecutionStatus,
|
||||
*,
|
||||
error: str | None = None,
|
||||
update_outputs: bool = True,
|
||||
) -> None:
|
||||
finished_at = naive_utc_now()
|
||||
snapshot = self._node_snapshots.get(domain_execution.id)
|
||||
start_at = snapshot.created_at if snapshot else domain_execution.created_at
|
||||
domain_execution.status = status
|
||||
domain_execution.finished_at = finished_at
|
||||
domain_execution.elapsed_time = max((finished_at - start_at).total_seconds(), 0.0)
|
||||
|
||||
if error:
|
||||
domain_execution.error = error
|
||||
|
||||
if update_outputs:
|
||||
domain_execution.update_from_mapping(
|
||||
inputs=node_result.inputs,
|
||||
process_data=node_result.process_data,
|
||||
outputs=node_result.outputs,
|
||||
metadata=node_result.metadata,
|
||||
)
|
||||
|
||||
self._workflow_node_execution_repository.save(domain_execution)
|
||||
self._workflow_node_execution_repository.save_execution_data(domain_execution)
|
||||
|
||||
def _fail_running_node_executions(self, *, error_message: str) -> None:
|
||||
now = naive_utc_now()
|
||||
for execution in self._node_execution_cache.values():
|
||||
if execution.status == WorkflowNodeExecutionStatus.RUNNING:
|
||||
execution.status = WorkflowNodeExecutionStatus.FAILED
|
||||
execution.error = error_message
|
||||
execution.finished_at = now
|
||||
execution.elapsed_time = max((now - execution.created_at).total_seconds(), 0.0)
|
||||
self._workflow_node_execution_repository.save(execution)
|
||||
|
||||
def _enqueue_trace_task(self, execution: WorkflowExecution) -> None:
|
||||
if not self._trace_manager:
|
||||
return
|
||||
|
||||
conversation_id = self._system_variables().get(SystemVariableKey.CONVERSATION_ID.value)
|
||||
external_trace_id = None
|
||||
if isinstance(self._application_generate_entity, (WorkflowAppGenerateEntity, AdvancedChatAppGenerateEntity)):
|
||||
external_trace_id = self._application_generate_entity.extras.get("external_trace_id")
|
||||
|
||||
trace_task = TraceTask(
|
||||
TraceTaskName.WORKFLOW_TRACE,
|
||||
workflow_execution=execution,
|
||||
conversation_id=conversation_id,
|
||||
user_id=self._trace_manager.user_id,
|
||||
external_trace_id=external_trace_id,
|
||||
)
|
||||
self._trace_manager.add_trace_task(trace_task)
|
||||
|
||||
def _system_variables(self) -> Mapping[str, Any]:
|
||||
runtime_state = self.graph_runtime_state
|
||||
if runtime_state is None:
|
||||
return {}
|
||||
return runtime_state.variable_pool.get_by_prefix(SYSTEM_VARIABLE_NODE_ID)
|
||||
@@ -9,7 +9,7 @@ Supports stop, pause, and resume operations.
|
||||
from typing import final
|
||||
|
||||
from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
|
||||
from core.workflow.graph_engine.entities.commands import AbortCommand
|
||||
from core.workflow.graph_engine.entities.commands import AbortCommand, GraphEngineCommand, PauseCommand
|
||||
from extensions.ext_redis import redis_client
|
||||
|
||||
|
||||
@@ -20,7 +20,7 @@ class GraphEngineManager:
|
||||
|
||||
This class provides a simple interface for controlling workflow executions
|
||||
by sending commands through Redis channels, without user validation.
|
||||
Supports stop, pause, and resume operations.
|
||||
Supports stop and pause operations.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
@@ -32,19 +32,29 @@ class GraphEngineManager:
|
||||
task_id: The task ID of the workflow to stop
|
||||
reason: Optional reason for stopping (defaults to "User requested stop")
|
||||
"""
|
||||
abort_command = AbortCommand(reason=reason or "User requested stop")
|
||||
GraphEngineManager._send_command(task_id, abort_command)
|
||||
|
||||
@staticmethod
|
||||
def send_pause_command(task_id: str, reason: str | None = None) -> None:
|
||||
"""Send a pause command to a running workflow."""
|
||||
|
||||
pause_command = PauseCommand(reason=reason or "User requested pause")
|
||||
GraphEngineManager._send_command(task_id, pause_command)
|
||||
|
||||
@staticmethod
|
||||
def _send_command(task_id: str, command: GraphEngineCommand) -> None:
|
||||
"""Send a command to the workflow-specific Redis channel."""
|
||||
|
||||
if not task_id:
|
||||
return
|
||||
|
||||
# Create Redis channel for this task
|
||||
channel_key = f"workflow:{task_id}:commands"
|
||||
channel = RedisChannel(redis_client, channel_key)
|
||||
|
||||
# Create and send abort command
|
||||
abort_command = AbortCommand(reason=reason or "User requested stop")
|
||||
|
||||
try:
|
||||
channel.send_command(abort_command)
|
||||
channel.send_command(command)
|
||||
except Exception:
|
||||
# Silently fail if Redis is unavailable
|
||||
# The legacy stop flag mechanism will still work
|
||||
# The legacy control mechanisms will still work
|
||||
pass
|
||||
|
||||
@@ -33,6 +33,12 @@ class Dispatcher:
|
||||
with timeout and completion detection.
|
||||
"""
|
||||
|
||||
_COMMAND_TRIGGER_EVENTS = (
|
||||
NodeRunSucceededEvent,
|
||||
NodeRunFailedEvent,
|
||||
NodeRunExceptionEvent,
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
event_queue: queue.Queue[GraphNodeEventBase],
|
||||
@@ -77,33 +83,41 @@ class Dispatcher:
|
||||
if self._thread and self._thread.is_alive():
|
||||
self._thread.join(timeout=10.0)
|
||||
|
||||
_COMMAND_TRIGGER_EVENTS = (
|
||||
NodeRunSucceededEvent,
|
||||
NodeRunFailedEvent,
|
||||
NodeRunExceptionEvent,
|
||||
)
|
||||
|
||||
def _dispatcher_loop(self) -> None:
|
||||
"""Main dispatcher loop."""
|
||||
try:
|
||||
while not self._stop_event.is_set():
|
||||
# Check for scaling
|
||||
self._execution_coordinator.check_scaling()
|
||||
commands_checked = False
|
||||
should_check_commands = False
|
||||
should_break = False
|
||||
|
||||
# Process events
|
||||
try:
|
||||
event = self._event_queue.get(timeout=0.1)
|
||||
# Route to the event handler
|
||||
self._event_handler.dispatch(event)
|
||||
if self._should_check_commands(event):
|
||||
self._execution_coordinator.check_commands()
|
||||
self._event_queue.task_done()
|
||||
except queue.Empty:
|
||||
# Process commands even when no new events arrive so abort requests are not missed
|
||||
if self._execution_coordinator.is_execution_complete():
|
||||
should_check_commands = True
|
||||
should_break = True
|
||||
else:
|
||||
# Check for scaling
|
||||
self._execution_coordinator.check_scaling()
|
||||
|
||||
# Process events
|
||||
try:
|
||||
event = self._event_queue.get(timeout=0.1)
|
||||
# Route to the event handler
|
||||
self._event_handler.dispatch(event)
|
||||
should_check_commands = self._should_check_commands(event)
|
||||
self._event_queue.task_done()
|
||||
except queue.Empty:
|
||||
# Process commands even when no new events arrive so abort requests are not missed
|
||||
should_check_commands = True
|
||||
time.sleep(0.1)
|
||||
|
||||
if should_check_commands and not commands_checked:
|
||||
self._execution_coordinator.check_commands()
|
||||
# Check if execution is complete
|
||||
if self._execution_coordinator.is_execution_complete():
|
||||
break
|
||||
commands_checked = True
|
||||
|
||||
if should_break:
|
||||
if not commands_checked:
|
||||
self._execution_coordinator.check_commands()
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Dispatcher error")
|
||||
|
||||
@@ -2,17 +2,13 @@
|
||||
Execution coordinator for managing overall workflow execution.
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING, final
|
||||
from typing import final
|
||||
|
||||
from ..command_processing import CommandProcessor
|
||||
from ..domain import GraphExecution
|
||||
from ..event_management import EventManager
|
||||
from ..graph_state_manager import GraphStateManager
|
||||
from ..worker_management import WorkerPool
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..event_management import EventHandler
|
||||
|
||||
|
||||
@final
|
||||
class ExecutionCoordinator:
|
||||
@@ -27,8 +23,6 @@ class ExecutionCoordinator:
|
||||
self,
|
||||
graph_execution: GraphExecution,
|
||||
state_manager: GraphStateManager,
|
||||
event_handler: "EventHandler",
|
||||
event_collector: EventManager,
|
||||
command_processor: CommandProcessor,
|
||||
worker_pool: WorkerPool,
|
||||
) -> None:
|
||||
@@ -38,15 +32,11 @@ class ExecutionCoordinator:
|
||||
Args:
|
||||
graph_execution: Graph execution aggregate
|
||||
state_manager: Unified state manager
|
||||
event_handler: Event handler registry for processing events
|
||||
event_collector: Event manager for collecting events
|
||||
command_processor: Processor for commands
|
||||
worker_pool: Pool of workers
|
||||
"""
|
||||
self._graph_execution = graph_execution
|
||||
self._state_manager = state_manager
|
||||
self._event_handler = event_handler
|
||||
self._event_collector = event_collector
|
||||
self._command_processor = command_processor
|
||||
self._worker_pool = worker_pool
|
||||
|
||||
@@ -65,15 +55,24 @@ class ExecutionCoordinator:
|
||||
Returns:
|
||||
True if execution is complete
|
||||
"""
|
||||
# Check if aborted or failed
|
||||
# Treat paused, aborted, or failed executions as terminal states
|
||||
if self._graph_execution.is_paused:
|
||||
return True
|
||||
|
||||
if self._graph_execution.aborted or self._graph_execution.has_error:
|
||||
return True
|
||||
|
||||
# Complete if no work remains
|
||||
return self._state_manager.is_execution_complete()
|
||||
|
||||
@property
|
||||
def is_paused(self) -> bool:
|
||||
"""Expose whether the underlying graph execution is paused."""
|
||||
return self._graph_execution.is_paused
|
||||
|
||||
def mark_complete(self) -> None:
|
||||
"""Mark execution as complete."""
|
||||
if self._graph_execution.is_paused:
|
||||
return
|
||||
if not self._graph_execution.completed:
|
||||
self._graph_execution.complete()
|
||||
|
||||
@@ -85,3 +84,21 @@ class ExecutionCoordinator:
|
||||
error: The error that caused failure
|
||||
"""
|
||||
self._graph_execution.fail(error)
|
||||
|
||||
def handle_pause_if_needed(self) -> None:
|
||||
"""If the execution has been paused, stop workers immediately."""
|
||||
|
||||
if not self._graph_execution.is_paused:
|
||||
return
|
||||
|
||||
self._worker_pool.stop()
|
||||
self._state_manager.clear_executing()
|
||||
|
||||
def handle_abort_if_needed(self) -> None:
|
||||
"""If the execution has been aborted, stop workers immediately."""
|
||||
|
||||
if not self._graph_execution.aborted:
|
||||
return
|
||||
|
||||
self._worker_pool.stop()
|
||||
self._state_manager.clear_executing()
|
||||
|
||||
@@ -14,11 +14,11 @@ from uuid import uuid4
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.enums import NodeExecutionType, NodeState
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_events import NodeRunStreamChunkEvent, NodeRunSucceededEvent
|
||||
from core.workflow.nodes.base.template import TextSegment, VariableSegment
|
||||
from core.workflow.runtime import VariablePool
|
||||
|
||||
from .path import Path
|
||||
from .session import ResponseSession
|
||||
|
||||
Reference in New Issue
Block a user