feat(graph_engine): Support pausing workflow graph executions (#26585)

Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
-LAN-
2025-10-19 21:33:41 +08:00
committed by GitHub
parent 9a5f214623
commit 578247ffbc
112 changed files with 3766 additions and 2415 deletions

View File

@@ -447,6 +447,8 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
"message_id": message.id,
"context": context,
"variable_loader": variable_loader,
"workflow_execution_repository": workflow_execution_repository,
"workflow_node_execution_repository": workflow_node_execution_repository,
},
)
@@ -466,8 +468,6 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
conversation=conversation,
message=message,
user=user,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
stream=stream,
draft_var_saver_factory=self._get_draft_var_saver_factory(invoke_from, account=user),
)
@@ -483,6 +483,8 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
message_id: str,
context: contextvars.Context,
variable_loader: VariableLoader,
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
):
"""
Generate worker in a new thread.
@@ -538,6 +540,8 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
workflow=workflow,
system_user_id=system_user_id,
app=app,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
)
try:
@@ -570,8 +574,6 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
conversation: Conversation,
message: Message,
user: Union[Account, EndUser],
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
draft_var_saver_factory: DraftVariableSaverFactory,
stream: bool = False,
) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]:
@@ -584,7 +586,6 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
:param message: message
:param user: account or end user
:param stream: is stream
:param workflow_node_execution_repository: optional repository for workflow node execution
:return:
"""
# init generate task pipeline
@@ -596,8 +597,6 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
message=message,
user=user,
dialogue_count=self._dialogue_count,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
stream=stream,
draft_var_saver_factory=draft_var_saver_factory,
)

View File

@@ -23,8 +23,12 @@ from core.app.features.annotation_reply.annotation_reply import AnnotationReplyF
from core.moderation.base import ModerationError
from core.moderation.input_moderation import InputModeration
from core.variables.variables import VariableUnion
from core.workflow.entities import GraphRuntimeState, VariablePool
from core.workflow.enums import WorkflowType
from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
from core.workflow.graph_engine.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
from core.workflow.variable_loader import VariableLoader
from core.workflow.workflow_entry import WorkflowEntry
@@ -55,6 +59,8 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
workflow: Workflow,
system_user_id: str,
app: App,
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
):
super().__init__(
queue_manager=queue_manager,
@@ -68,11 +74,24 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
self._workflow = workflow
self.system_user_id = system_user_id
self._app = app
self._workflow_execution_repository = workflow_execution_repository
self._workflow_node_execution_repository = workflow_node_execution_repository
def run(self):
app_config = self.application_generate_entity.app_config
app_config = cast(AdvancedChatAppConfig, app_config)
system_inputs = SystemVariable(
query=self.application_generate_entity.query,
files=self.application_generate_entity.files,
conversation_id=self.conversation.id,
user_id=self.system_user_id,
dialogue_count=self._dialogue_count,
app_id=app_config.app_id,
workflow_id=app_config.workflow_id,
workflow_execution_id=self.application_generate_entity.workflow_run_id,
)
with Session(db.engine, expire_on_commit=False) as session:
app_record = session.scalar(select(App).where(App.id == app_config.app_id))
@@ -89,7 +108,6 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
else:
inputs = self.application_generate_entity.inputs
query = self.application_generate_entity.query
files = self.application_generate_entity.files
# moderation
if self.handle_input_moderation(
@@ -114,17 +132,6 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
conversation_variables = self._initialize_conversation_variables()
# Create a variable pool.
system_inputs = SystemVariable(
query=query,
files=files,
conversation_id=self.conversation.id,
user_id=self.system_user_id,
dialogue_count=self._dialogue_count,
app_id=app_config.app_id,
workflow_id=app_config.workflow_id,
workflow_execution_id=self.application_generate_entity.workflow_run_id,
)
# init variable pool
variable_pool = VariablePool(
system_variables=system_inputs,
@@ -172,6 +179,23 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
command_channel=command_channel,
)
self._queue_manager.graph_runtime_state = graph_runtime_state
persistence_layer = WorkflowPersistenceLayer(
application_generate_entity=self.application_generate_entity,
workflow_info=PersistenceWorkflowInfo(
workflow_id=self._workflow.id,
workflow_type=WorkflowType(self._workflow.type),
version=self._workflow.version,
graph_data=self._workflow.graph_dict,
),
workflow_execution_repository=self._workflow_execution_repository,
workflow_node_execution_repository=self._workflow_node_execution_repository,
trace_manager=self.application_generate_entity.trace_manager,
)
workflow_entry.graph_engine.layer(persistence_layer)
generator = workflow_entry.run()
for event in generator:

View File

@@ -11,6 +11,7 @@ from sqlalchemy.orm import Session
from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.common.graph_runtime_state_support import GraphRuntimeStateSupport
from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter
from core.app.entities.app_invoke_entities import (
AdvancedChatAppGenerateEntity,
@@ -60,14 +61,11 @@ from core.app.task_pipeline.message_cycle_manager import MessageCycleManager
from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
from core.model_runtime.entities.llm_entities import LLMUsage
from core.ops.ops_trace_manager import TraceQueueManager
from core.workflow.entities import GraphRuntimeState
from core.workflow.enums import WorkflowExecutionStatus, WorkflowType
from core.workflow.enums import WorkflowExecutionStatus
from core.workflow.nodes import NodeType
from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from core.workflow.runtime import GraphRuntimeState
from core.workflow.system_variable import SystemVariable
from core.workflow.workflow_cycle_manager import CycleManagerWorkflowInfo, WorkflowCycleManager
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models import Account, Conversation, EndUser, Message, MessageFile
@@ -77,7 +75,7 @@ from models.workflow import Workflow
logger = logging.getLogger(__name__)
class AdvancedChatAppGenerateTaskPipeline:
class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
"""
AdvancedChatAppGenerateTaskPipeline is a class that generate stream output and state management for Application.
"""
@@ -92,8 +90,6 @@ class AdvancedChatAppGenerateTaskPipeline:
user: Union[Account, EndUser],
stream: bool,
dialogue_count: int,
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
draft_var_saver_factory: DraftVariableSaverFactory,
):
self._base_task_pipeline = BasedGenerateTaskPipeline(
@@ -113,31 +109,20 @@ class AdvancedChatAppGenerateTaskPipeline:
else:
raise NotImplementedError(f"User type not supported: {type(user)}")
self._workflow_cycle_manager = WorkflowCycleManager(
application_generate_entity=application_generate_entity,
workflow_system_variables=SystemVariable(
query=message.query,
files=application_generate_entity.files,
conversation_id=conversation.id,
user_id=user_session_id,
dialogue_count=dialogue_count,
app_id=application_generate_entity.app_config.app_id,
workflow_id=workflow.id,
workflow_execution_id=application_generate_entity.workflow_run_id,
),
workflow_info=CycleManagerWorkflowInfo(
workflow_id=workflow.id,
workflow_type=WorkflowType(workflow.type),
version=workflow.version,
graph_data=workflow.graph_dict,
),
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
self._workflow_system_variables = SystemVariable(
query=message.query,
files=application_generate_entity.files,
conversation_id=conversation.id,
user_id=user_session_id,
dialogue_count=dialogue_count,
app_id=application_generate_entity.app_config.app_id,
workflow_id=workflow.id,
workflow_execution_id=application_generate_entity.workflow_run_id,
)
self._workflow_response_converter = WorkflowResponseConverter(
application_generate_entity=application_generate_entity,
user=user,
system_variables=self._workflow_system_variables,
)
self._task_state = WorkflowTaskState()
@@ -156,6 +141,8 @@ class AdvancedChatAppGenerateTaskPipeline:
self._recorded_files: list[Mapping[str, Any]] = []
self._workflow_run_id: str = ""
self._draft_var_saver_factory = draft_var_saver_factory
self._graph_runtime_state: GraphRuntimeState | None = None
self._seed_graph_runtime_state_from_queue_manager()
def process(self) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]:
"""
@@ -288,12 +275,6 @@ class AdvancedChatAppGenerateTaskPipeline:
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")
def _ensure_graph_runtime_initialized(self, graph_runtime_state: GraphRuntimeState | None) -> GraphRuntimeState:
"""Fluent validation for graph runtime state."""
if not graph_runtime_state:
raise ValueError("graph runtime state not initialized.")
return graph_runtime_state
def _handle_ping_event(self, event: QueuePingEvent, **kwargs) -> Generator[PingStreamResponse, None, None]:
"""Handle ping events."""
yield self._base_task_pipeline.ping_stream_response()
@@ -304,21 +285,28 @@ class AdvancedChatAppGenerateTaskPipeline:
err = self._base_task_pipeline.handle_error(event=event, session=session, message_id=self._message_id)
yield self._base_task_pipeline.error_to_stream_response(err)
def _handle_workflow_started_event(self, *args, **kwargs) -> Generator[StreamResponse, None, None]:
def _handle_workflow_started_event(
self,
event: QueueWorkflowStartedEvent,
**kwargs,
) -> Generator[StreamResponse, None, None]:
"""Handle workflow started events."""
with self._database_session() as session:
workflow_execution = self._workflow_cycle_manager.handle_workflow_run_start()
self._workflow_run_id = workflow_execution.id_
runtime_state = self._resolve_graph_runtime_state()
run_id = self._extract_workflow_run_id(runtime_state)
self._workflow_run_id = run_id
with self._database_session() as session:
message = self._get_message(session=session)
if not message:
raise ValueError(f"Message not found: {self._message_id}")
message.workflow_run_id = workflow_execution.id_
workflow_start_resp = self._workflow_response_converter.workflow_start_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_execution=workflow_execution,
)
message.workflow_run_id = run_id
workflow_start_resp = self._workflow_response_converter.workflow_start_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_run_id=run_id,
workflow_id=self._workflow_id,
)
yield workflow_start_resp
@@ -326,13 +314,9 @@ class AdvancedChatAppGenerateTaskPipeline:
"""Handle node retry events."""
self._ensure_workflow_initialized()
workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_retried(
workflow_execution_id=self._workflow_run_id, event=event
)
node_retry_resp = self._workflow_response_converter.workflow_node_retry_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)
if node_retry_resp:
@@ -344,14 +328,9 @@ class AdvancedChatAppGenerateTaskPipeline:
"""Handle node started events."""
self._ensure_workflow_initialized()
workflow_node_execution = self._workflow_cycle_manager.handle_node_execution_start(
workflow_execution_id=self._workflow_run_id, event=event
)
node_start_resp = self._workflow_response_converter.workflow_node_start_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)
if node_start_resp:
@@ -367,14 +346,12 @@ class AdvancedChatAppGenerateTaskPipeline:
self._workflow_response_converter.fetch_files_from_node_outputs(event.outputs or {})
)
workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_success(event=event)
node_finish_resp = self._workflow_response_converter.workflow_node_finish_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)
self._save_output_for_event(event, workflow_node_execution.id)
self._save_output_for_event(event, event.node_execution_id)
if node_finish_resp:
yield node_finish_resp
@@ -385,16 +362,13 @@ class AdvancedChatAppGenerateTaskPipeline:
**kwargs,
) -> Generator[StreamResponse, None, None]:
"""Handle various node failure events."""
workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_failed(event=event)
node_finish_resp = self._workflow_response_converter.workflow_node_finish_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)
if isinstance(event, QueueNodeExceptionEvent):
self._save_output_for_event(event, workflow_node_execution.id)
self._save_output_for_event(event, event.node_execution_id)
if node_finish_resp:
yield node_finish_resp
@@ -504,29 +478,19 @@ class AdvancedChatAppGenerateTaskPipeline:
self,
event: QueueWorkflowSucceededEvent,
*,
graph_runtime_state: GraphRuntimeState | None = None,
trace_manager: TraceQueueManager | None = None,
**kwargs,
) -> Generator[StreamResponse, None, None]:
"""Handle workflow succeeded events."""
_ = trace_manager
self._ensure_workflow_initialized()
validated_state = self._ensure_graph_runtime_initialized(graph_runtime_state)
with self._database_session() as session:
workflow_execution = self._workflow_cycle_manager.handle_workflow_run_success(
workflow_run_id=self._workflow_run_id,
total_tokens=validated_state.total_tokens,
total_steps=validated_state.node_run_steps,
outputs=event.outputs,
conversation_id=self._conversation_id,
trace_manager=trace_manager,
external_trace_id=self._application_generate_entity.extras.get("external_trace_id"),
)
workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_execution=workflow_execution,
)
validated_state = self._ensure_graph_runtime_initialized()
workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_id=self._workflow_id,
status=WorkflowExecutionStatus.SUCCEEDED,
graph_runtime_state=validated_state,
)
yield workflow_finish_resp
self._base_task_pipeline.queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE)
@@ -535,30 +499,20 @@ class AdvancedChatAppGenerateTaskPipeline:
self,
event: QueueWorkflowPartialSuccessEvent,
*,
graph_runtime_state: GraphRuntimeState | None = None,
trace_manager: TraceQueueManager | None = None,
**kwargs,
) -> Generator[StreamResponse, None, None]:
"""Handle workflow partial success events."""
_ = trace_manager
self._ensure_workflow_initialized()
validated_state = self._ensure_graph_runtime_initialized(graph_runtime_state)
with self._database_session() as session:
workflow_execution = self._workflow_cycle_manager.handle_workflow_run_partial_success(
workflow_run_id=self._workflow_run_id,
total_tokens=validated_state.total_tokens,
total_steps=validated_state.node_run_steps,
outputs=event.outputs,
exceptions_count=event.exceptions_count,
conversation_id=self._conversation_id,
trace_manager=trace_manager,
external_trace_id=self._application_generate_entity.extras.get("external_trace_id"),
)
workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_execution=workflow_execution,
)
validated_state = self._ensure_graph_runtime_initialized()
workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_id=self._workflow_id,
status=WorkflowExecutionStatus.PARTIAL_SUCCEEDED,
graph_runtime_state=validated_state,
exceptions_count=event.exceptions_count,
)
yield workflow_finish_resp
self._base_task_pipeline.queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE)
@@ -567,32 +521,25 @@ class AdvancedChatAppGenerateTaskPipeline:
self,
event: QueueWorkflowFailedEvent,
*,
graph_runtime_state: GraphRuntimeState | None = None,
trace_manager: TraceQueueManager | None = None,
**kwargs,
) -> Generator[StreamResponse, None, None]:
"""Handle workflow failed events."""
_ = trace_manager
self._ensure_workflow_initialized()
validated_state = self._ensure_graph_runtime_initialized(graph_runtime_state)
validated_state = self._ensure_graph_runtime_initialized()
workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_id=self._workflow_id,
status=WorkflowExecutionStatus.FAILED,
graph_runtime_state=validated_state,
error=event.error,
exceptions_count=event.exceptions_count,
)
with self._database_session() as session:
workflow_execution = self._workflow_cycle_manager.handle_workflow_run_failed(
workflow_run_id=self._workflow_run_id,
total_tokens=validated_state.total_tokens,
total_steps=validated_state.node_run_steps,
status=WorkflowExecutionStatus.FAILED,
error_message=event.error,
conversation_id=self._conversation_id,
trace_manager=trace_manager,
exceptions_count=event.exceptions_count,
external_trace_id=self._application_generate_entity.extras.get("external_trace_id"),
)
workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_execution=workflow_execution,
)
err_event = QueueErrorEvent(error=ValueError(f"Run failed: {workflow_execution.error_message}"))
err_event = QueueErrorEvent(error=ValueError(f"Run failed: {event.error}"))
err = self._base_task_pipeline.handle_error(event=err_event, session=session, message_id=self._message_id)
yield workflow_finish_resp
@@ -607,25 +554,23 @@ class AdvancedChatAppGenerateTaskPipeline:
**kwargs,
) -> Generator[StreamResponse, None, None]:
"""Handle stop events."""
if self._workflow_run_id and graph_runtime_state:
_ = trace_manager
resolved_state = None
if self._workflow_run_id:
resolved_state = self._resolve_graph_runtime_state(graph_runtime_state)
if self._workflow_run_id and resolved_state:
workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_id=self._workflow_id,
status=WorkflowExecutionStatus.STOPPED,
graph_runtime_state=resolved_state,
error=event.get_stop_reason(),
)
with self._database_session() as session:
workflow_execution = self._workflow_cycle_manager.handle_workflow_run_failed(
workflow_run_id=self._workflow_run_id,
total_tokens=graph_runtime_state.total_tokens,
total_steps=graph_runtime_state.node_run_steps,
status=WorkflowExecutionStatus.STOPPED,
error_message=event.get_stop_reason(),
conversation_id=self._conversation_id,
trace_manager=trace_manager,
external_trace_id=self._application_generate_entity.extras.get("external_trace_id"),
)
workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_execution=workflow_execution,
)
# Save message
self._save_message(session=session, graph_runtime_state=graph_runtime_state)
self._save_message(session=session, graph_runtime_state=resolved_state)
yield workflow_finish_resp
elif event.stopped_by in (
@@ -647,7 +592,7 @@ class AdvancedChatAppGenerateTaskPipeline:
**kwargs,
) -> Generator[StreamResponse, None, None]:
"""Handle advanced chat message end events."""
self._ensure_graph_runtime_initialized(graph_runtime_state)
resolved_state = self._ensure_graph_runtime_initialized(graph_runtime_state)
output_moderation_answer = self._base_task_pipeline.handle_output_moderation_when_task_finished(
self._task_state.answer
@@ -661,7 +606,7 @@ class AdvancedChatAppGenerateTaskPipeline:
# Save message
with self._database_session() as session:
self._save_message(session=session, graph_runtime_state=graph_runtime_state)
self._save_message(session=session, graph_runtime_state=resolved_state)
yield self._message_end_to_stream_response()
@@ -670,10 +615,6 @@ class AdvancedChatAppGenerateTaskPipeline:
) -> Generator[StreamResponse, None, None]:
"""Handle retriever resources events."""
self._message_cycle_manager.handle_retriever_resources(event)
with self._database_session() as session:
message = self._get_message(session=session)
message.message_metadata = self._task_state.metadata.model_dump_json()
return
yield # Make this a generator
@@ -682,10 +623,6 @@ class AdvancedChatAppGenerateTaskPipeline:
) -> Generator[StreamResponse, None, None]:
"""Handle annotation reply events."""
self._message_cycle_manager.handle_annotation_reply(event)
with self._database_session() as session:
message = self._get_message(session=session)
message.message_metadata = self._task_state.metadata.model_dump_json()
return
yield # Make this a generator
@@ -739,7 +676,6 @@ class AdvancedChatAppGenerateTaskPipeline:
self,
event: Any,
*,
graph_runtime_state: GraphRuntimeState | None = None,
tts_publisher: AppGeneratorTTSPublisher | None = None,
trace_manager: TraceQueueManager | None = None,
queue_message: Union[WorkflowQueueMessage, MessageQueueMessage] | None = None,
@@ -752,7 +688,6 @@ class AdvancedChatAppGenerateTaskPipeline:
if handler := handlers.get(event_type):
yield from handler(
event,
graph_runtime_state=graph_runtime_state,
tts_publisher=tts_publisher,
trace_manager=trace_manager,
queue_message=queue_message,
@@ -769,7 +704,6 @@ class AdvancedChatAppGenerateTaskPipeline:
):
yield from self._handle_node_failed_events(
event,
graph_runtime_state=graph_runtime_state,
tts_publisher=tts_publisher,
trace_manager=trace_manager,
queue_message=queue_message,
@@ -788,15 +722,12 @@ class AdvancedChatAppGenerateTaskPipeline:
Process stream response using elegant Fluent Python patterns.
Maintains exact same functionality as original 57-if-statement version.
"""
# Initialize graph runtime state
graph_runtime_state: GraphRuntimeState | None = None
for queue_message in self._base_task_pipeline.queue_manager.listen():
event = queue_message.event
match event:
case QueueWorkflowStartedEvent():
graph_runtime_state = event.graph_runtime_state
self._resolve_graph_runtime_state()
yield from self._handle_workflow_started_event(event)
case QueueErrorEvent():
@@ -804,15 +735,11 @@ class AdvancedChatAppGenerateTaskPipeline:
break
case QueueWorkflowFailedEvent():
yield from self._handle_workflow_failed_event(
event, graph_runtime_state=graph_runtime_state, trace_manager=trace_manager
)
yield from self._handle_workflow_failed_event(event, trace_manager=trace_manager)
break
case QueueStopEvent():
yield from self._handle_stop_event(
event, graph_runtime_state=graph_runtime_state, trace_manager=trace_manager
)
yield from self._handle_stop_event(event, graph_runtime_state=None, trace_manager=trace_manager)
break
# Handle all other events through elegant dispatch
@@ -820,7 +747,6 @@ class AdvancedChatAppGenerateTaskPipeline:
if responses := list(
self._dispatch_event(
event,
graph_runtime_state=graph_runtime_state,
tts_publisher=tts_publisher,
trace_manager=trace_manager,
queue_message=queue_message,
@@ -878,6 +804,12 @@ class AdvancedChatAppGenerateTaskPipeline:
else:
self._task_state.metadata.usage = LLMUsage.empty_usage()
def _seed_graph_runtime_state_from_queue_manager(self) -> None:
"""Bootstrap the cached runtime state from the queue manager when present."""
candidate = self._base_task_pipeline.queue_manager.graph_runtime_state
if candidate is not None:
self._graph_runtime_state = candidate
def _message_end_to_stream_response(self) -> MessageEndStreamResponse:
"""
Message end to stream response.

View File

@@ -20,6 +20,7 @@ from core.app.entities.queue_entities import (
QueueStopEvent,
WorkflowQueueMessage,
)
from core.workflow.runtime import GraphRuntimeState
from extensions.ext_redis import redis_client
logger = logging.getLogger(__name__)
@@ -47,6 +48,7 @@ class AppQueueManager:
q: queue.Queue[WorkflowQueueMessage | MessageQueueMessage | None] = queue.Queue()
self._q = q
self._graph_runtime_state: GraphRuntimeState | None = None
self._stopped_cache: TTLCache[tuple, bool] = TTLCache(maxsize=1, ttl=1)
self._cache_lock = threading.Lock()
@@ -109,6 +111,16 @@ class AppQueueManager:
"""
self.publish(QueueErrorEvent(error=e), pub_from)
@property
def graph_runtime_state(self) -> GraphRuntimeState | None:
"""Retrieve the attached graph runtime state, if available."""
return self._graph_runtime_state
@graph_runtime_state.setter
def graph_runtime_state(self, graph_runtime_state: GraphRuntimeState | None) -> None:
"""Attach the live graph runtime state reference for downstream consumers."""
self._graph_runtime_state = graph_runtime_state
def publish(self, event: AppQueueEvent, pub_from: PublishFrom):
"""
Publish event to queue

View File

@@ -0,0 +1,55 @@
"""Shared helpers for managing GraphRuntimeState across task pipelines."""
from __future__ import annotations
from typing import TYPE_CHECKING
from core.workflow.runtime import GraphRuntimeState
if TYPE_CHECKING:
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
class GraphRuntimeStateSupport:
"""
Mixin that centralises common GraphRuntimeState access patterns used by task pipelines.
Subclasses are expected to provide:
* `_base_task_pipeline` exposing the queue manager with an optional cached runtime state.
* `_graph_runtime_state` attribute used as the local cache for the runtime state.
"""
_base_task_pipeline: BasedGenerateTaskPipeline
_graph_runtime_state: GraphRuntimeState | None = None
def _ensure_graph_runtime_initialized(
self,
graph_runtime_state: GraphRuntimeState | None = None,
) -> GraphRuntimeState:
"""Validate and return the active graph runtime state."""
return self._resolve_graph_runtime_state(graph_runtime_state)
def _extract_workflow_run_id(self, graph_runtime_state: GraphRuntimeState) -> str:
system_variables = graph_runtime_state.variable_pool.system_variables
if not system_variables or not system_variables.workflow_execution_id:
raise ValueError("workflow_execution_id missing from runtime state")
return str(system_variables.workflow_execution_id)
def _resolve_graph_runtime_state(
self,
graph_runtime_state: GraphRuntimeState | None = None,
) -> GraphRuntimeState:
"""Return the cached runtime state or bootstrap it from the queue manager."""
if graph_runtime_state is not None:
self._graph_runtime_state = graph_runtime_state
return graph_runtime_state
if self._graph_runtime_state is None:
candidate = self._base_task_pipeline.queue_manager.graph_runtime_state
if candidate is not None:
self._graph_runtime_state = candidate
if self._graph_runtime_state is None:
raise ValueError("graph runtime state not initialized.")
return self._graph_runtime_state

View File

@@ -1,9 +1,8 @@
import time
from collections.abc import Mapping, Sequence
from datetime import UTC, datetime
from typing import Any, Union
from sqlalchemy.orm import Session
from dataclasses import dataclass
from datetime import datetime
from typing import Any, NewType, Union
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity
from core.app.entities.queue_entities import (
@@ -39,16 +38,36 @@ from core.plugin.impl.datasource import PluginDatasourceManager
from core.tools.entities.tool_entities import ToolProviderType
from core.tools.tool_manager import ToolManager
from core.variables.segments import ArrayFileSegment, FileSegment, Segment
from core.workflow.entities import WorkflowExecution, WorkflowNodeExecution
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
from core.workflow.enums import (
NodeType,
SystemVariableKey,
WorkflowExecutionStatus,
WorkflowNodeExecutionMetadataKey,
WorkflowNodeExecutionStatus,
)
from core.workflow.runtime import GraphRuntimeState
from core.workflow.system_variable import SystemVariable
from core.workflow.workflow_entry import WorkflowEntry
from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
from libs.datetime_utils import naive_utc_now
from models import (
Account,
EndUser,
)
from models import Account, EndUser
from services.variable_truncator import VariableTruncator
NodeExecutionId = NewType("NodeExecutionId", str)
@dataclass(slots=True)
class _NodeSnapshot:
"""In-memory cache for node metadata between start and completion events."""
title: str
index: int
start_at: datetime
iteration_id: str = ""
"""Empty string means the node is not executing inside an iteration."""
loop_id: str = ""
"""Empty string means the node is not executing inside a loop."""
class WorkflowResponseConverter:
def __init__(
@@ -56,37 +75,151 @@ class WorkflowResponseConverter:
*,
application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity],
user: Union[Account, EndUser],
system_variables: SystemVariable,
):
self._application_generate_entity = application_generate_entity
self._user = user
self._system_variables = system_variables
self._workflow_inputs = self._prepare_workflow_inputs()
self._truncator = VariableTruncator.default()
self._node_snapshots: dict[NodeExecutionId, _NodeSnapshot] = {}
self._workflow_execution_id: str | None = None
self._workflow_started_at: datetime | None = None
# ------------------------------------------------------------------
# Workflow lifecycle helpers
# ------------------------------------------------------------------
def _prepare_workflow_inputs(self) -> Mapping[str, Any]:
inputs = dict(self._application_generate_entity.inputs)
for field_name, value in self._system_variables.to_dict().items():
# TODO(@future-refactor): store system variables separately from user inputs so we don't
# need to flatten `sys.*` entries into the input payload just for rerun/export tooling.
if field_name == SystemVariableKey.CONVERSATION_ID:
# Conversation IDs are session-scoped; omitting them keeps workflow inputs
# reusable without pinning new runs to a prior conversation.
continue
inputs[f"sys.{field_name}"] = value
handled = WorkflowEntry.handle_special_values(inputs)
return dict(handled or {})
def _ensure_workflow_run_id(self, workflow_run_id: str | None = None) -> str:
"""Return the memoized workflow run id, optionally seeding it during start events."""
if workflow_run_id is not None:
self._workflow_execution_id = workflow_run_id
if not self._workflow_execution_id:
raise ValueError("workflow_run_id missing before streaming workflow events")
return self._workflow_execution_id
# ------------------------------------------------------------------
# Node snapshot helpers
# ------------------------------------------------------------------
def _store_snapshot(self, event: QueueNodeStartedEvent) -> _NodeSnapshot:
snapshot = _NodeSnapshot(
title=event.node_title,
index=event.node_run_index,
start_at=event.start_at,
iteration_id=event.in_iteration_id or "",
loop_id=event.in_loop_id or "",
)
node_execution_id = NodeExecutionId(event.node_execution_id)
self._node_snapshots[node_execution_id] = snapshot
return snapshot
def _get_snapshot(self, node_execution_id: str) -> _NodeSnapshot | None:
return self._node_snapshots.get(NodeExecutionId(node_execution_id))
def _pop_snapshot(self, node_execution_id: str) -> _NodeSnapshot | None:
return self._node_snapshots.pop(NodeExecutionId(node_execution_id), None)
@staticmethod
def _merge_metadata(
base_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None,
snapshot: _NodeSnapshot | None,
) -> Mapping[WorkflowNodeExecutionMetadataKey, Any] | None:
if not base_metadata and not snapshot:
return base_metadata
merged: dict[WorkflowNodeExecutionMetadataKey, Any] = {}
if base_metadata:
merged.update(base_metadata)
if snapshot:
if snapshot.iteration_id:
merged[WorkflowNodeExecutionMetadataKey.ITERATION_ID] = snapshot.iteration_id
if snapshot.loop_id:
merged[WorkflowNodeExecutionMetadataKey.LOOP_ID] = snapshot.loop_id
return merged or None
def _truncate_mapping(
self,
mapping: Mapping[str, Any] | None,
) -> tuple[Mapping[str, Any] | None, bool]:
if mapping is None:
return None, False
if not mapping:
return {}, False
normalized = WorkflowEntry.handle_special_values(dict(mapping))
if normalized is None:
return None, False
truncated, is_truncated = self._truncator.truncate_variable_mapping(dict(normalized))
return truncated, is_truncated
@staticmethod
def _encode_outputs(outputs: Mapping[str, Any] | None) -> Mapping[str, Any] | None:
if outputs is None:
return None
converter = WorkflowRuntimeTypeConverter()
return converter.to_json_encodable(outputs)
def workflow_start_to_stream_response(
self,
*,
task_id: str,
workflow_execution: WorkflowExecution,
workflow_run_id: str,
workflow_id: str,
) -> WorkflowStartStreamResponse:
run_id = self._ensure_workflow_run_id(workflow_run_id)
started_at = naive_utc_now()
self._workflow_started_at = started_at
return WorkflowStartStreamResponse(
task_id=task_id,
workflow_run_id=workflow_execution.id_,
workflow_run_id=run_id,
data=WorkflowStartStreamResponse.Data(
id=workflow_execution.id_,
workflow_id=workflow_execution.workflow_id,
inputs=workflow_execution.inputs,
created_at=int(workflow_execution.started_at.timestamp()),
id=run_id,
workflow_id=workflow_id,
inputs=self._workflow_inputs,
created_at=int(started_at.timestamp()),
),
)
def workflow_finish_to_stream_response(
self,
*,
session: Session,
task_id: str,
workflow_execution: WorkflowExecution,
workflow_id: str,
status: WorkflowExecutionStatus,
graph_runtime_state: GraphRuntimeState,
error: str | None = None,
exceptions_count: int = 0,
) -> WorkflowFinishStreamResponse:
created_by = None
run_id = self._ensure_workflow_run_id()
started_at = self._workflow_started_at
if started_at is None:
raise ValueError(
"workflow_finish_to_stream_response called before workflow_start_to_stream_response",
)
finished_at = naive_utc_now()
elapsed_time = (finished_at - started_at).total_seconds()
outputs_mapping = graph_runtime_state.outputs or {}
encoded_outputs = WorkflowRuntimeTypeConverter().to_json_encodable(outputs_mapping)
created_by: Mapping[str, object] | None
user = self._user
if isinstance(user, Account):
created_by = {
@@ -94,38 +227,29 @@ class WorkflowResponseConverter:
"name": user.name,
"email": user.email,
}
elif isinstance(user, EndUser):
else:
created_by = {
"id": user.id,
"user": user.session_id,
}
else:
raise NotImplementedError(f"User type not supported: {type(user)}")
# Handle the case where finished_at is None by using current time as default
finished_at_timestamp = (
int(workflow_execution.finished_at.timestamp())
if workflow_execution.finished_at
else int(datetime.now(UTC).timestamp())
)
return WorkflowFinishStreamResponse(
task_id=task_id,
workflow_run_id=workflow_execution.id_,
workflow_run_id=run_id,
data=WorkflowFinishStreamResponse.Data(
id=workflow_execution.id_,
workflow_id=workflow_execution.workflow_id,
status=workflow_execution.status,
outputs=WorkflowRuntimeTypeConverter().to_json_encodable(workflow_execution.outputs),
error=workflow_execution.error_message,
elapsed_time=workflow_execution.elapsed_time,
total_tokens=workflow_execution.total_tokens,
total_steps=workflow_execution.total_steps,
id=run_id,
workflow_id=workflow_id,
status=status.value,
outputs=encoded_outputs,
error=error,
elapsed_time=elapsed_time,
total_tokens=graph_runtime_state.total_tokens,
total_steps=graph_runtime_state.node_run_steps,
created_by=created_by,
created_at=int(workflow_execution.started_at.timestamp()),
finished_at=finished_at_timestamp,
files=self.fetch_files_from_node_outputs(workflow_execution.outputs),
exceptions_count=workflow_execution.exceptions_count,
created_at=int(started_at.timestamp()),
finished_at=int(finished_at.timestamp()),
files=self.fetch_files_from_node_outputs(outputs_mapping),
exceptions_count=exceptions_count,
),
)
@@ -134,38 +258,28 @@ class WorkflowResponseConverter:
*,
event: QueueNodeStartedEvent,
task_id: str,
workflow_node_execution: WorkflowNodeExecution,
) -> NodeStartStreamResponse | None:
if workflow_node_execution.node_type in {NodeType.ITERATION, NodeType.LOOP}:
return None
if not workflow_node_execution.workflow_execution_id:
if event.node_type in {NodeType.ITERATION, NodeType.LOOP}:
return None
run_id = self._ensure_workflow_run_id()
snapshot = self._store_snapshot(event)
response = NodeStartStreamResponse(
task_id=task_id,
workflow_run_id=workflow_node_execution.workflow_execution_id,
workflow_run_id=run_id,
data=NodeStartStreamResponse.Data(
id=workflow_node_execution.id,
node_id=workflow_node_execution.node_id,
node_type=workflow_node_execution.node_type,
title=workflow_node_execution.title,
index=workflow_node_execution.index,
predecessor_node_id=workflow_node_execution.predecessor_node_id,
inputs=workflow_node_execution.get_response_inputs(),
inputs_truncated=workflow_node_execution.inputs_truncated,
created_at=int(workflow_node_execution.created_at.timestamp()),
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
id=event.node_execution_id,
node_id=event.node_id,
node_type=event.node_type,
title=snapshot.title,
index=snapshot.index,
created_at=int(snapshot.start_at.timestamp()),
iteration_id=event.in_iteration_id,
loop_id=event.in_loop_id,
parallel_run_id=event.parallel_mode_run_id,
agent_strategy=event.agent_strategy,
),
)
# extras logic
if event.node_type == NodeType.TOOL:
response.data.extras["icon"] = ToolManager.get_tool_icon(
tenant_id=self._application_generate_entity.app_config.tenant_id,
@@ -189,41 +303,54 @@ class WorkflowResponseConverter:
*,
event: QueueNodeSucceededEvent | QueueNodeFailedEvent | QueueNodeExceptionEvent,
task_id: str,
workflow_node_execution: WorkflowNodeExecution,
) -> NodeFinishStreamResponse | None:
if workflow_node_execution.node_type in {NodeType.ITERATION, NodeType.LOOP}:
return None
if not workflow_node_execution.workflow_execution_id:
return None
if not workflow_node_execution.finished_at:
if event.node_type in {NodeType.ITERATION, NodeType.LOOP}:
return None
run_id = self._ensure_workflow_run_id()
snapshot = self._pop_snapshot(event.node_execution_id)
json_converter = WorkflowRuntimeTypeConverter()
start_at = snapshot.start_at if snapshot else event.start_at
finished_at = naive_utc_now()
elapsed_time = (finished_at - start_at).total_seconds()
inputs, inputs_truncated = self._truncate_mapping(event.inputs)
process_data, process_data_truncated = self._truncate_mapping(event.process_data)
encoded_outputs = self._encode_outputs(event.outputs)
outputs, outputs_truncated = self._truncate_mapping(encoded_outputs)
metadata = self._merge_metadata(event.execution_metadata, snapshot)
if isinstance(event, QueueNodeSucceededEvent):
status = WorkflowNodeExecutionStatus.SUCCEEDED.value
error_message = event.error
elif isinstance(event, QueueNodeFailedEvent):
status = WorkflowNodeExecutionStatus.FAILED.value
error_message = event.error
else:
status = WorkflowNodeExecutionStatus.EXCEPTION.value
error_message = event.error
return NodeFinishStreamResponse(
task_id=task_id,
workflow_run_id=workflow_node_execution.workflow_execution_id,
workflow_run_id=run_id,
data=NodeFinishStreamResponse.Data(
id=workflow_node_execution.id,
node_id=workflow_node_execution.node_id,
node_type=workflow_node_execution.node_type,
index=workflow_node_execution.index,
title=workflow_node_execution.title,
predecessor_node_id=workflow_node_execution.predecessor_node_id,
inputs=workflow_node_execution.get_response_inputs(),
inputs_truncated=workflow_node_execution.inputs_truncated,
process_data=workflow_node_execution.get_response_process_data(),
process_data_truncated=workflow_node_execution.process_data_truncated,
outputs=json_converter.to_json_encodable(workflow_node_execution.get_response_outputs()),
outputs_truncated=workflow_node_execution.outputs_truncated,
status=workflow_node_execution.status,
error=workflow_node_execution.error,
elapsed_time=workflow_node_execution.elapsed_time,
execution_metadata=workflow_node_execution.metadata,
created_at=int(workflow_node_execution.created_at.timestamp()),
finished_at=int(workflow_node_execution.finished_at.timestamp()),
files=self.fetch_files_from_node_outputs(workflow_node_execution.outputs or {}),
parallel_id=event.parallel_id,
id=event.node_execution_id,
node_id=event.node_id,
node_type=event.node_type,
index=snapshot.index if snapshot else 0,
title=snapshot.title if snapshot else "",
inputs=inputs,
inputs_truncated=inputs_truncated,
process_data=process_data,
process_data_truncated=process_data_truncated,
outputs=outputs,
outputs_truncated=outputs_truncated,
status=status,
error=error_message,
elapsed_time=elapsed_time,
execution_metadata=metadata,
created_at=int(start_at.timestamp()),
finished_at=int(finished_at.timestamp()),
files=self.fetch_files_from_node_outputs(event.outputs or {}),
iteration_id=event.in_iteration_id,
loop_id=event.in_loop_id,
),
@@ -234,44 +361,45 @@ class WorkflowResponseConverter:
*,
event: QueueNodeRetryEvent,
task_id: str,
workflow_node_execution: WorkflowNodeExecution,
) -> Union[NodeRetryStreamResponse, NodeFinishStreamResponse] | None:
if workflow_node_execution.node_type in {NodeType.ITERATION, NodeType.LOOP}:
return None
if not workflow_node_execution.workflow_execution_id:
return None
if not workflow_node_execution.finished_at:
) -> NodeRetryStreamResponse | None:
if event.node_type in {NodeType.ITERATION, NodeType.LOOP}:
return None
run_id = self._ensure_workflow_run_id()
json_converter = WorkflowRuntimeTypeConverter()
snapshot = self._get_snapshot(event.node_execution_id)
if snapshot is None:
raise AssertionError("node retry event arrived without a stored snapshot")
finished_at = naive_utc_now()
elapsed_time = (finished_at - event.start_at).total_seconds()
inputs, inputs_truncated = self._truncate_mapping(event.inputs)
process_data, process_data_truncated = self._truncate_mapping(event.process_data)
encoded_outputs = self._encode_outputs(event.outputs)
outputs, outputs_truncated = self._truncate_mapping(encoded_outputs)
metadata = self._merge_metadata(event.execution_metadata, snapshot)
return NodeRetryStreamResponse(
task_id=task_id,
workflow_run_id=workflow_node_execution.workflow_execution_id,
workflow_run_id=run_id,
data=NodeRetryStreamResponse.Data(
id=workflow_node_execution.id,
node_id=workflow_node_execution.node_id,
node_type=workflow_node_execution.node_type,
index=workflow_node_execution.index,
title=workflow_node_execution.title,
predecessor_node_id=workflow_node_execution.predecessor_node_id,
inputs=workflow_node_execution.get_response_inputs(),
inputs_truncated=workflow_node_execution.inputs_truncated,
process_data=workflow_node_execution.get_response_process_data(),
process_data_truncated=workflow_node_execution.process_data_truncated,
outputs=json_converter.to_json_encodable(workflow_node_execution.get_response_outputs()),
outputs_truncated=workflow_node_execution.outputs_truncated,
status=workflow_node_execution.status,
error=workflow_node_execution.error,
elapsed_time=workflow_node_execution.elapsed_time,
execution_metadata=workflow_node_execution.metadata,
created_at=int(workflow_node_execution.created_at.timestamp()),
finished_at=int(workflow_node_execution.finished_at.timestamp()),
files=self.fetch_files_from_node_outputs(workflow_node_execution.outputs or {}),
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
id=event.node_execution_id,
node_id=event.node_id,
node_type=event.node_type,
index=snapshot.index,
title=snapshot.title,
inputs=inputs,
inputs_truncated=inputs_truncated,
process_data=process_data,
process_data_truncated=process_data_truncated,
outputs=outputs,
outputs_truncated=outputs_truncated,
status=WorkflowNodeExecutionStatus.RETRY.value,
error=event.error,
elapsed_time=elapsed_time,
execution_metadata=metadata,
created_at=int(snapshot.start_at.timestamp()),
finished_at=int(finished_at.timestamp()),
files=self.fetch_files_from_node_outputs(event.outputs or {}),
iteration_id=event.in_iteration_id,
loop_id=event.in_loop_id,
retry_index=event.retry_index,
@@ -379,8 +507,6 @@ class WorkflowResponseConverter:
inputs=new_inputs,
inputs_truncated=truncated,
metadata=event.metadata or {},
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
),
)
@@ -405,9 +531,6 @@ class WorkflowResponseConverter:
pre_loop_output={},
created_at=int(time.time()),
extras={},
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parallel_mode_run_id=event.parallel_mode_run_id,
),
)
@@ -446,8 +569,6 @@ class WorkflowResponseConverter:
execution_metadata=event.metadata,
finished_at=int(time.time()),
steps=event.steps,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
),
)

View File

@@ -352,6 +352,8 @@ class PipelineGenerator(BaseAppGenerator):
"application_generate_entity": application_generate_entity,
"workflow_thread_pool_id": workflow_thread_pool_id,
"variable_loader": variable_loader,
"workflow_execution_repository": workflow_execution_repository,
"workflow_node_execution_repository": workflow_node_execution_repository,
},
)
@@ -367,8 +369,6 @@ class PipelineGenerator(BaseAppGenerator):
workflow=workflow,
queue_manager=queue_manager,
user=user,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
stream=streaming,
draft_var_saver_factory=draft_var_saver_factory,
)
@@ -573,6 +573,8 @@ class PipelineGenerator(BaseAppGenerator):
queue_manager: AppQueueManager,
context: contextvars.Context,
variable_loader: VariableLoader,
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
workflow_thread_pool_id: str | None = None,
) -> None:
"""
@@ -620,6 +622,8 @@ class PipelineGenerator(BaseAppGenerator):
variable_loader=variable_loader,
workflow=workflow,
system_user_id=system_user_id,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
)
runner.run()
@@ -648,8 +652,6 @@ class PipelineGenerator(BaseAppGenerator):
workflow: Workflow,
queue_manager: AppQueueManager,
user: Union[Account, EndUser],
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
draft_var_saver_factory: DraftVariableSaverFactory,
stream: bool = False,
) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
@@ -660,7 +662,6 @@ class PipelineGenerator(BaseAppGenerator):
:param queue_manager: queue manager
:param user: account or end user
:param stream: is stream
:param workflow_node_execution_repository: optional repository for workflow node execution
:return:
"""
# init generate task pipeline
@@ -670,8 +671,6 @@ class PipelineGenerator(BaseAppGenerator):
queue_manager=queue_manager,
user=user,
stream=stream,
workflow_node_execution_repository=workflow_node_execution_repository,
workflow_execution_repository=workflow_execution_repository,
draft_var_saver_factory=draft_var_saver_factory,
)

View File

@@ -11,11 +11,14 @@ from core.app.entities.app_invoke_entities import (
)
from core.variables.variables import RAGPipelineVariable, RAGPipelineVariableInput
from core.workflow.entities.graph_init_params import GraphInitParams
from core.workflow.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import WorkflowType
from core.workflow.graph import Graph
from core.workflow.graph_engine.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer
from core.workflow.graph_events import GraphEngineEvent, GraphRunFailedEvent
from core.workflow.nodes.node_factory import DifyNodeFactory
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
from core.workflow.variable_loader import VariableLoader
from core.workflow.workflow_entry import WorkflowEntry
@@ -40,6 +43,8 @@ class PipelineRunner(WorkflowBasedAppRunner):
variable_loader: VariableLoader,
workflow: Workflow,
system_user_id: str,
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
workflow_thread_pool_id: str | None = None,
) -> None:
"""
@@ -56,6 +61,8 @@ class PipelineRunner(WorkflowBasedAppRunner):
self.workflow_thread_pool_id = workflow_thread_pool_id
self._workflow = workflow
self._sys_user_id = system_user_id
self._workflow_execution_repository = workflow_execution_repository
self._workflow_node_execution_repository = workflow_node_execution_repository
def _get_app_id(self) -> str:
return self.application_generate_entity.app_config.app_id
@@ -163,6 +170,23 @@ class PipelineRunner(WorkflowBasedAppRunner):
variable_pool=variable_pool,
)
self._queue_manager.graph_runtime_state = graph_runtime_state
persistence_layer = WorkflowPersistenceLayer(
application_generate_entity=self.application_generate_entity,
workflow_info=PersistenceWorkflowInfo(
workflow_id=workflow.id,
workflow_type=WorkflowType(workflow.type),
version=workflow.version,
graph_data=workflow.graph_dict,
),
workflow_execution_repository=self._workflow_execution_repository,
workflow_node_execution_repository=self._workflow_node_execution_repository,
trace_manager=self.application_generate_entity.trace_manager,
)
workflow_entry.graph_engine.layer(persistence_layer)
generator = workflow_entry.run()
for event in generator:

View File

@@ -231,6 +231,8 @@ class WorkflowAppGenerator(BaseAppGenerator):
"queue_manager": queue_manager,
"context": context,
"variable_loader": variable_loader,
"workflow_execution_repository": workflow_execution_repository,
"workflow_node_execution_repository": workflow_node_execution_repository,
},
)
@@ -244,8 +246,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
workflow=workflow,
queue_manager=queue_manager,
user=user,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
draft_var_saver_factory=draft_var_saver_factory,
stream=streaming,
)
@@ -424,6 +424,8 @@ class WorkflowAppGenerator(BaseAppGenerator):
queue_manager: AppQueueManager,
context: contextvars.Context,
variable_loader: VariableLoader,
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
) -> None:
"""
Generate worker in a new thread.
@@ -465,6 +467,8 @@ class WorkflowAppGenerator(BaseAppGenerator):
variable_loader=variable_loader,
workflow=workflow,
system_user_id=system_user_id,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
)
try:
@@ -493,8 +497,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
workflow: Workflow,
queue_manager: AppQueueManager,
user: Union[Account, EndUser],
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
draft_var_saver_factory: DraftVariableSaverFactory,
stream: bool = False,
) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
@@ -514,8 +516,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
workflow=workflow,
queue_manager=queue_manager,
user=user,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
draft_var_saver_factory=draft_var_saver_factory,
stream=stream,
)

View File

@@ -5,12 +5,13 @@ from typing import cast
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.apps.workflow.app_config_manager import WorkflowAppConfig
from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
from core.app.entities.app_invoke_entities import (
InvokeFrom,
WorkflowAppGenerateEntity,
)
from core.workflow.entities import GraphRuntimeState, VariablePool
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
from core.workflow.enums import WorkflowType
from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
from core.workflow.graph_engine.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
from core.workflow.variable_loader import VariableLoader
from core.workflow.workflow_entry import WorkflowEntry
@@ -34,6 +35,8 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
variable_loader: VariableLoader,
workflow: Workflow,
system_user_id: str,
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
):
super().__init__(
queue_manager=queue_manager,
@@ -43,6 +46,8 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
self.application_generate_entity = application_generate_entity
self._workflow = workflow
self._sys_user_id = system_user_id
self._workflow_execution_repository = workflow_execution_repository
self._workflow_node_execution_repository = workflow_node_execution_repository
def run(self):
"""
@@ -51,6 +56,14 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
app_config = self.application_generate_entity.app_config
app_config = cast(WorkflowAppConfig, app_config)
system_inputs = SystemVariable(
files=self.application_generate_entity.files,
user_id=self._sys_user_id,
app_id=app_config.app_id,
workflow_id=app_config.workflow_id,
workflow_execution_id=self.application_generate_entity.workflow_execution_id,
)
# if only single iteration or single loop run is requested
if self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run:
graph, variable_pool, graph_runtime_state = self._prepare_single_node_execution(
@@ -60,18 +73,9 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
)
else:
inputs = self.application_generate_entity.inputs
files = self.application_generate_entity.files
# Create a variable pool.
system_inputs = SystemVariable(
files=files,
user_id=self._sys_user_id,
app_id=app_config.app_id,
workflow_id=app_config.workflow_id,
workflow_execution_id=self.application_generate_entity.workflow_execution_id,
)
variable_pool = VariablePool(
system_variables=system_inputs,
user_inputs=inputs,
@@ -96,6 +100,8 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
channel_key = f"workflow:{task_id}:commands"
command_channel = RedisChannel(redis_client, channel_key)
self._queue_manager.graph_runtime_state = graph_runtime_state
workflow_entry = WorkflowEntry(
tenant_id=self._workflow.tenant_id,
app_id=self._workflow.app_id,
@@ -115,6 +121,21 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
command_channel=command_channel,
)
persistence_layer = WorkflowPersistenceLayer(
application_generate_entity=self.application_generate_entity,
workflow_info=PersistenceWorkflowInfo(
workflow_id=self._workflow.id,
workflow_type=WorkflowType(self._workflow.type),
version=self._workflow.version,
graph_data=self._workflow.graph_dict,
),
workflow_execution_repository=self._workflow_execution_repository,
workflow_node_execution_repository=self._workflow_node_execution_repository,
trace_manager=self.application_generate_entity.trace_manager,
)
workflow_entry.graph_engine.layer(persistence_layer)
generator = workflow_entry.run()
for event in generator:

View File

@@ -8,11 +8,9 @@ from sqlalchemy.orm import Session
from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.apps.common.graph_runtime_state_support import GraphRuntimeStateSupport
from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter
from core.app.entities.app_invoke_entities import (
InvokeFrom,
WorkflowAppGenerateEntity,
)
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
from core.app.entities.queue_entities import (
AppQueueEvent,
MessageQueueMessage,
@@ -53,27 +51,20 @@ from core.app.entities.task_entities import (
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
from core.ops.ops_trace_manager import TraceQueueManager
from core.workflow.entities import GraphRuntimeState, WorkflowExecution
from core.workflow.enums import WorkflowExecutionStatus, WorkflowType
from core.workflow.enums import WorkflowExecutionStatus
from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from core.workflow.runtime import GraphRuntimeState
from core.workflow.system_variable import SystemVariable
from core.workflow.workflow_cycle_manager import CycleManagerWorkflowInfo, WorkflowCycleManager
from extensions.ext_database import db
from models import Account
from models.enums import CreatorUserRole
from models.model import EndUser
from models.workflow import (
Workflow,
WorkflowAppLog,
WorkflowAppLogCreatedFrom,
)
from models.workflow import Workflow, WorkflowAppLog, WorkflowAppLogCreatedFrom
logger = logging.getLogger(__name__)
class WorkflowAppGenerateTaskPipeline:
class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
"""
WorkflowAppGenerateTaskPipeline is a class that generate stream output and state management for Application.
"""
@@ -85,8 +76,6 @@ class WorkflowAppGenerateTaskPipeline:
queue_manager: AppQueueManager,
user: Union[Account, EndUser],
stream: bool,
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
draft_var_saver_factory: DraftVariableSaverFactory,
):
self._base_task_pipeline = BasedGenerateTaskPipeline(
@@ -99,42 +88,30 @@ class WorkflowAppGenerateTaskPipeline:
self._user_id = user.id
user_session_id = user.session_id
self._created_by_role = CreatorUserRole.END_USER
elif isinstance(user, Account):
else:
self._user_id = user.id
user_session_id = user.id
self._created_by_role = CreatorUserRole.ACCOUNT
else:
raise ValueError(f"Invalid user type: {type(user)}")
self._workflow_cycle_manager = WorkflowCycleManager(
application_generate_entity=application_generate_entity,
workflow_system_variables=SystemVariable(
files=application_generate_entity.files,
user_id=user_session_id,
app_id=application_generate_entity.app_config.app_id,
workflow_id=workflow.id,
workflow_execution_id=application_generate_entity.workflow_execution_id,
),
workflow_info=CycleManagerWorkflowInfo(
workflow_id=workflow.id,
workflow_type=WorkflowType(workflow.type),
version=workflow.version,
graph_data=workflow.graph_dict,
),
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
)
self._workflow_response_converter = WorkflowResponseConverter(
application_generate_entity=application_generate_entity,
user=user,
)
self._application_generate_entity = application_generate_entity
self._workflow_features_dict = workflow.features_dict
self._workflow_run_id = ""
self._workflow_execution_id = ""
self._invoke_from = queue_manager.invoke_from
self._draft_var_saver_factory = draft_var_saver_factory
self._workflow = workflow
self._workflow_system_variables = SystemVariable(
files=application_generate_entity.files,
user_id=user_session_id,
app_id=application_generate_entity.app_config.app_id,
workflow_id=workflow.id,
workflow_execution_id=application_generate_entity.workflow_execution_id,
)
self._workflow_response_converter = WorkflowResponseConverter(
application_generate_entity=application_generate_entity,
user=user,
system_variables=self._workflow_system_variables,
)
self._graph_runtime_state: GraphRuntimeState | None = self._base_task_pipeline.queue_manager.graph_runtime_state
def process(self) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
"""
@@ -261,15 +238,9 @@ class WorkflowAppGenerateTaskPipeline:
def _ensure_workflow_initialized(self):
"""Fluent validation for workflow state."""
if not self._workflow_run_id:
if not self._workflow_execution_id:
raise ValueError("workflow run not initialized.")
def _ensure_graph_runtime_initialized(self, graph_runtime_state: GraphRuntimeState | None) -> GraphRuntimeState:
"""Fluent validation for graph runtime state."""
if not graph_runtime_state:
raise ValueError("graph runtime state not initialized.")
return graph_runtime_state
def _handle_ping_event(self, event: QueuePingEvent, **kwargs) -> Generator[PingStreamResponse, None, None]:
"""Handle ping events."""
yield self._base_task_pipeline.ping_stream_response()
@@ -283,12 +254,14 @@ class WorkflowAppGenerateTaskPipeline:
self, event: QueueWorkflowStartedEvent, **kwargs
) -> Generator[StreamResponse, None, None]:
"""Handle workflow started events."""
# init workflow run
workflow_execution = self._workflow_cycle_manager.handle_workflow_run_start()
self._workflow_run_id = workflow_execution.id_
runtime_state = self._resolve_graph_runtime_state()
run_id = self._extract_workflow_run_id(runtime_state)
self._workflow_execution_id = run_id
start_resp = self._workflow_response_converter.workflow_start_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_execution=workflow_execution,
workflow_run_id=run_id,
workflow_id=self._workflow.id,
)
yield start_resp
@@ -296,14 +269,9 @@ class WorkflowAppGenerateTaskPipeline:
"""Handle node retry events."""
self._ensure_workflow_initialized()
workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_retried(
workflow_execution_id=self._workflow_run_id,
event=event,
)
response = self._workflow_response_converter.workflow_node_retry_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)
if response:
@@ -315,13 +283,9 @@ class WorkflowAppGenerateTaskPipeline:
"""Handle node started events."""
self._ensure_workflow_initialized()
workflow_node_execution = self._workflow_cycle_manager.handle_node_execution_start(
workflow_execution_id=self._workflow_run_id, event=event
)
node_start_response = self._workflow_response_converter.workflow_node_start_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)
if node_start_response:
@@ -331,14 +295,12 @@ class WorkflowAppGenerateTaskPipeline:
self, event: QueueNodeSucceededEvent, **kwargs
) -> Generator[StreamResponse, None, None]:
"""Handle node succeeded events."""
workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_success(event=event)
node_success_response = self._workflow_response_converter.workflow_node_finish_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)
self._save_output_for_event(event, workflow_node_execution.id)
self._save_output_for_event(event, event.node_execution_id)
if node_success_response:
yield node_success_response
@@ -349,17 +311,13 @@ class WorkflowAppGenerateTaskPipeline:
**kwargs,
) -> Generator[StreamResponse, None, None]:
"""Handle various node failure events."""
workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_failed(
event=event,
)
node_failed_response = self._workflow_response_converter.workflow_node_finish_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)
if isinstance(event, QueueNodeExceptionEvent):
self._save_output_for_event(event, workflow_node_execution.id)
self._save_output_for_event(event, event.node_execution_id)
if node_failed_response:
yield node_failed_response
@@ -372,7 +330,7 @@ class WorkflowAppGenerateTaskPipeline:
iter_start_resp = self._workflow_response_converter.workflow_iteration_start_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_execution_id=self._workflow_run_id,
workflow_execution_id=self._workflow_execution_id,
event=event,
)
yield iter_start_resp
@@ -385,7 +343,7 @@ class WorkflowAppGenerateTaskPipeline:
iter_next_resp = self._workflow_response_converter.workflow_iteration_next_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_execution_id=self._workflow_run_id,
workflow_execution_id=self._workflow_execution_id,
event=event,
)
yield iter_next_resp
@@ -398,7 +356,7 @@ class WorkflowAppGenerateTaskPipeline:
iter_finish_resp = self._workflow_response_converter.workflow_iteration_completed_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_execution_id=self._workflow_run_id,
workflow_execution_id=self._workflow_execution_id,
event=event,
)
yield iter_finish_resp
@@ -409,7 +367,7 @@ class WorkflowAppGenerateTaskPipeline:
loop_start_resp = self._workflow_response_converter.workflow_loop_start_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_execution_id=self._workflow_run_id,
workflow_execution_id=self._workflow_execution_id,
event=event,
)
yield loop_start_resp
@@ -420,7 +378,7 @@ class WorkflowAppGenerateTaskPipeline:
loop_next_resp = self._workflow_response_converter.workflow_loop_next_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_execution_id=self._workflow_run_id,
workflow_execution_id=self._workflow_execution_id,
event=event,
)
yield loop_next_resp
@@ -433,7 +391,7 @@ class WorkflowAppGenerateTaskPipeline:
loop_finish_resp = self._workflow_response_converter.workflow_loop_completed_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_execution_id=self._workflow_run_id,
workflow_execution_id=self._workflow_execution_id,
event=event,
)
yield loop_finish_resp
@@ -442,33 +400,22 @@ class WorkflowAppGenerateTaskPipeline:
self,
event: QueueWorkflowSucceededEvent,
*,
graph_runtime_state: GraphRuntimeState | None = None,
trace_manager: TraceQueueManager | None = None,
**kwargs,
) -> Generator[StreamResponse, None, None]:
"""Handle workflow succeeded events."""
_ = trace_manager
self._ensure_workflow_initialized()
validated_state = self._ensure_graph_runtime_initialized(graph_runtime_state)
validated_state = self._ensure_graph_runtime_initialized()
workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_id=self._workflow.id,
status=WorkflowExecutionStatus.SUCCEEDED,
graph_runtime_state=validated_state,
)
with self._database_session() as session:
workflow_execution = self._workflow_cycle_manager.handle_workflow_run_success(
workflow_run_id=self._workflow_run_id,
total_tokens=validated_state.total_tokens,
total_steps=validated_state.node_run_steps,
outputs=event.outputs,
conversation_id=None,
trace_manager=trace_manager,
external_trace_id=self._application_generate_entity.extras.get("external_trace_id"),
)
# save workflow app log
self._save_workflow_app_log(session=session, workflow_execution=workflow_execution)
workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_execution=workflow_execution,
)
self._save_workflow_app_log(session=session, workflow_run_id=self._workflow_execution_id)
yield workflow_finish_resp
@@ -476,34 +423,23 @@ class WorkflowAppGenerateTaskPipeline:
self,
event: QueueWorkflowPartialSuccessEvent,
*,
graph_runtime_state: GraphRuntimeState | None = None,
trace_manager: TraceQueueManager | None = None,
**kwargs,
) -> Generator[StreamResponse, None, None]:
"""Handle workflow partial success events."""
_ = trace_manager
self._ensure_workflow_initialized()
validated_state = self._ensure_graph_runtime_initialized(graph_runtime_state)
validated_state = self._ensure_graph_runtime_initialized()
workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_id=self._workflow.id,
status=WorkflowExecutionStatus.PARTIAL_SUCCEEDED,
graph_runtime_state=validated_state,
exceptions_count=event.exceptions_count,
)
with self._database_session() as session:
workflow_execution = self._workflow_cycle_manager.handle_workflow_run_partial_success(
workflow_run_id=self._workflow_run_id,
total_tokens=validated_state.total_tokens,
total_steps=validated_state.node_run_steps,
outputs=event.outputs,
exceptions_count=event.exceptions_count,
conversation_id=None,
trace_manager=trace_manager,
external_trace_id=self._application_generate_entity.extras.get("external_trace_id"),
)
# save workflow app log
self._save_workflow_app_log(session=session, workflow_execution=workflow_execution)
workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_execution=workflow_execution,
)
self._save_workflow_app_log(session=session, workflow_run_id=self._workflow_execution_id)
yield workflow_finish_resp
@@ -511,37 +447,33 @@ class WorkflowAppGenerateTaskPipeline:
self,
event: Union[QueueWorkflowFailedEvent, QueueStopEvent],
*,
graph_runtime_state: GraphRuntimeState | None = None,
trace_manager: TraceQueueManager | None = None,
**kwargs,
) -> Generator[StreamResponse, None, None]:
"""Handle workflow failed and stop events."""
_ = trace_manager
self._ensure_workflow_initialized()
validated_state = self._ensure_graph_runtime_initialized(graph_runtime_state)
validated_state = self._ensure_graph_runtime_initialized()
if isinstance(event, QueueWorkflowFailedEvent):
status = WorkflowExecutionStatus.FAILED
error = event.error
exceptions_count = event.exceptions_count
else:
status = WorkflowExecutionStatus.STOPPED
error = event.get_stop_reason()
exceptions_count = 0
workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_id=self._workflow.id,
status=status,
graph_runtime_state=validated_state,
error=error,
exceptions_count=exceptions_count,
)
with self._database_session() as session:
workflow_execution = self._workflow_cycle_manager.handle_workflow_run_failed(
workflow_run_id=self._workflow_run_id,
total_tokens=validated_state.total_tokens,
total_steps=validated_state.node_run_steps,
status=WorkflowExecutionStatus.FAILED
if isinstance(event, QueueWorkflowFailedEvent)
else WorkflowExecutionStatus.STOPPED,
error_message=event.error if isinstance(event, QueueWorkflowFailedEvent) else event.get_stop_reason(),
conversation_id=None,
trace_manager=trace_manager,
exceptions_count=event.exceptions_count if isinstance(event, QueueWorkflowFailedEvent) else 0,
external_trace_id=self._application_generate_entity.extras.get("external_trace_id"),
)
# save workflow app log
self._save_workflow_app_log(session=session, workflow_execution=workflow_execution)
workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_execution=workflow_execution,
)
self._save_workflow_app_log(session=session, workflow_run_id=self._workflow_execution_id)
yield workflow_finish_resp
@@ -601,7 +533,6 @@ class WorkflowAppGenerateTaskPipeline:
self,
event: AppQueueEvent,
*,
graph_runtime_state: GraphRuntimeState | None = None,
tts_publisher: AppGeneratorTTSPublisher | None = None,
trace_manager: TraceQueueManager | None = None,
queue_message: Union[WorkflowQueueMessage, MessageQueueMessage] | None = None,
@@ -614,7 +545,6 @@ class WorkflowAppGenerateTaskPipeline:
if handler := handlers.get(event_type):
yield from handler(
event,
graph_runtime_state=graph_runtime_state,
tts_publisher=tts_publisher,
trace_manager=trace_manager,
queue_message=queue_message,
@@ -631,7 +561,6 @@ class WorkflowAppGenerateTaskPipeline:
):
yield from self._handle_node_failed_events(
event,
graph_runtime_state=graph_runtime_state,
tts_publisher=tts_publisher,
trace_manager=trace_manager,
queue_message=queue_message,
@@ -642,7 +571,6 @@ class WorkflowAppGenerateTaskPipeline:
if isinstance(event, (QueueWorkflowFailedEvent, QueueStopEvent)):
yield from self._handle_workflow_failed_and_stop_events(
event,
graph_runtime_state=graph_runtime_state,
tts_publisher=tts_publisher,
trace_manager=trace_manager,
queue_message=queue_message,
@@ -661,15 +589,12 @@ class WorkflowAppGenerateTaskPipeline:
Process stream response using elegant Fluent Python patterns.
Maintains exact same functionality as original 44-if-statement version.
"""
# Initialize graph runtime state
graph_runtime_state = None
for queue_message in self._base_task_pipeline.queue_manager.listen():
event = queue_message.event
match event:
case QueueWorkflowStartedEvent():
graph_runtime_state = event.graph_runtime_state
self._resolve_graph_runtime_state()
yield from self._handle_workflow_started_event(event)
case QueueTextChunkEvent():
@@ -681,12 +606,19 @@ class WorkflowAppGenerateTaskPipeline:
yield from self._handle_error_event(event)
break
case QueueWorkflowFailedEvent():
yield from self._handle_workflow_failed_and_stop_events(event)
break
case QueueStopEvent():
yield from self._handle_workflow_failed_and_stop_events(event)
break
# Handle all other events through elegant dispatch
case _:
if responses := list(
self._dispatch_event(
event,
graph_runtime_state=graph_runtime_state,
tts_publisher=tts_publisher,
trace_manager=trace_manager,
queue_message=queue_message,
@@ -697,7 +629,7 @@ class WorkflowAppGenerateTaskPipeline:
if tts_publisher:
tts_publisher.publish(None)
def _save_workflow_app_log(self, *, session: Session, workflow_execution: WorkflowExecution):
def _save_workflow_app_log(self, *, session: Session, workflow_run_id: str | None):
invoke_from = self._application_generate_entity.invoke_from
if invoke_from == InvokeFrom.SERVICE_API:
created_from = WorkflowAppLogCreatedFrom.SERVICE_API
@@ -709,11 +641,14 @@ class WorkflowAppGenerateTaskPipeline:
# not save log for debugging
return
if not workflow_run_id:
return
workflow_app_log = WorkflowAppLog()
workflow_app_log.tenant_id = self._application_generate_entity.app_config.tenant_id
workflow_app_log.app_id = self._application_generate_entity.app_config.app_id
workflow_app_log.workflow_id = workflow_execution.workflow_id
workflow_app_log.workflow_run_id = workflow_execution.id_
workflow_app_log.workflow_id = self._workflow.id
workflow_app_log.workflow_run_id = workflow_run_id
workflow_app_log.created_from = created_from.value
workflow_app_log.created_by_role = self._created_by_role
workflow_app_log.created_by = self._user_id

View File

@@ -25,7 +25,7 @@ from core.app.entities.queue_entities import (
QueueWorkflowStartedEvent,
QueueWorkflowSucceededEvent,
)
from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool
from core.workflow.entities import GraphInitParams
from core.workflow.graph import Graph
from core.workflow.graph_events import (
GraphEngineEvent,
@@ -54,6 +54,7 @@ from core.workflow.graph_events.graph import GraphRunAbortedEvent
from core.workflow.nodes import NodeType
from core.workflow.nodes.node_factory import DifyNodeFactory
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool
from core.workflow.workflow_entry import WorkflowEntry
@@ -346,9 +347,7 @@ class WorkflowBasedAppRunner:
:param event: event
"""
if isinstance(event, GraphRunStartedEvent):
self._publish_event(
QueueWorkflowStartedEvent(graph_runtime_state=workflow_entry.graph_engine.graph_runtime_state)
)
self._publish_event(QueueWorkflowStartedEvent())
elif isinstance(event, GraphRunSucceededEvent):
self._publish_event(QueueWorkflowSucceededEvent(outputs=event.outputs))
elif isinstance(event, GraphRunPartialSucceededEvent):
@@ -372,7 +371,6 @@ class WorkflowBasedAppRunner:
node_title=event.node_title,
node_type=event.node_type,
start_at=event.start_at,
predecessor_node_id=event.predecessor_node_id,
in_iteration_id=event.in_iteration_id,
in_loop_id=event.in_loop_id,
inputs=inputs,
@@ -393,7 +391,6 @@ class WorkflowBasedAppRunner:
node_title=event.node_title,
node_type=event.node_type,
start_at=event.start_at,
predecessor_node_id=event.predecessor_node_id,
in_iteration_id=event.in_iteration_id,
in_loop_id=event.in_loop_id,
agent_strategy=event.agent_strategy,
@@ -494,7 +491,6 @@ class WorkflowBasedAppRunner:
start_at=event.start_at,
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
inputs=event.inputs,
predecessor_node_id=event.predecessor_node_id,
metadata=event.metadata,
)
)
@@ -536,7 +532,6 @@ class WorkflowBasedAppRunner:
start_at=event.start_at,
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
inputs=event.inputs,
predecessor_node_id=event.predecessor_node_id,
metadata=event.metadata,
)
)

View File

@@ -3,11 +3,11 @@ from datetime import datetime
from enum import StrEnum, auto
from typing import Any
from pydantic import BaseModel, Field
from pydantic import BaseModel, ConfigDict, Field
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.workflow.entities import AgentNodeStrategyInit, GraphRuntimeState
from core.workflow.entities import AgentNodeStrategyInit
from core.workflow.enums import WorkflowNodeExecutionMetadataKey
from core.workflow.nodes import NodeType
@@ -54,6 +54,7 @@ class AppQueueEvent(BaseModel):
"""
event: QueueEvent
model_config = ConfigDict(arbitrary_types_allowed=True)
class QueueLLMChunkEvent(AppQueueEvent):
@@ -80,7 +81,6 @@ class QueueIterationStartEvent(AppQueueEvent):
node_run_index: int
inputs: Mapping[str, object] = Field(default_factory=dict)
predecessor_node_id: str | None = None
metadata: Mapping[str, object] = Field(default_factory=dict)
@@ -132,19 +132,10 @@ class QueueLoopStartEvent(AppQueueEvent):
node_id: str
node_type: NodeType
node_title: str
parallel_id: str | None = None
"""parallel id if node is in parallel"""
parallel_start_node_id: str | None = None
"""parallel start node id if node is in parallel"""
parent_parallel_id: str | None = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: str | None = None
"""parent parallel start node id if node is in parallel"""
start_at: datetime
node_run_index: int
inputs: Mapping[str, object] = Field(default_factory=dict)
predecessor_node_id: str | None = None
metadata: Mapping[str, object] = Field(default_factory=dict)
@@ -160,16 +151,6 @@ class QueueLoopNextEvent(AppQueueEvent):
node_id: str
node_type: NodeType
node_title: str
parallel_id: str | None = None
"""parallel id if node is in parallel"""
parallel_start_node_id: str | None = None
"""parallel start node id if node is in parallel"""
parent_parallel_id: str | None = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: str | None = None
"""parent parallel start node id if node is in parallel"""
parallel_mode_run_id: str | None = None
"""iteration run in parallel mode run id"""
node_run_index: int
output: Any = None # output for the current loop
@@ -185,14 +166,6 @@ class QueueLoopCompletedEvent(AppQueueEvent):
node_id: str
node_type: NodeType
node_title: str
parallel_id: str | None = None
"""parallel id if node is in parallel"""
parallel_start_node_id: str | None = None
"""parallel start node id if node is in parallel"""
parent_parallel_id: str | None = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: str | None = None
"""parent parallel start node id if node is in parallel"""
start_at: datetime
node_run_index: int
@@ -285,12 +258,9 @@ class QueueAdvancedChatMessageEndEvent(AppQueueEvent):
class QueueWorkflowStartedEvent(AppQueueEvent):
"""
QueueWorkflowStartedEvent entity
"""
"""QueueWorkflowStartedEvent entity."""
event: QueueEvent = QueueEvent.WORKFLOW_STARTED
graph_runtime_state: GraphRuntimeState
class QueueWorkflowSucceededEvent(AppQueueEvent):
@@ -334,15 +304,9 @@ class QueueNodeStartedEvent(AppQueueEvent):
node_title: str
node_type: NodeType
node_run_index: int = 1 # FIXME(-LAN-): may not used
predecessor_node_id: str | None = None
parallel_id: str | None = None
parallel_start_node_id: str | None = None
parent_parallel_id: str | None = None
parent_parallel_start_node_id: str | None = None
in_iteration_id: str | None = None
in_loop_id: str | None = None
start_at: datetime
parallel_mode_run_id: str | None = None
agent_strategy: AgentNodeStrategyInit | None = None
# FIXME(-LAN-): only for ToolNode, need to refactor
@@ -360,14 +324,6 @@ class QueueNodeSucceededEvent(AppQueueEvent):
node_execution_id: str
node_id: str
node_type: NodeType
parallel_id: str | None = None
"""parallel id if node is in parallel"""
parallel_start_node_id: str | None = None
"""parallel start node id if node is in parallel"""
parent_parallel_id: str | None = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: str | None = None
"""parent parallel start node id if node is in parallel"""
in_iteration_id: str | None = None
"""iteration id if node is in iteration"""
in_loop_id: str | None = None
@@ -423,14 +379,6 @@ class QueueNodeExceptionEvent(AppQueueEvent):
node_execution_id: str
node_id: str
node_type: NodeType
parallel_id: str | None = None
"""parallel id if node is in parallel"""
parallel_start_node_id: str | None = None
"""parallel start node id if node is in parallel"""
parent_parallel_id: str | None = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: str | None = None
"""parent parallel start node id if node is in parallel"""
in_iteration_id: str | None = None
"""iteration id if node is in iteration"""
in_loop_id: str | None = None
@@ -455,7 +403,6 @@ class QueueNodeFailedEvent(AppQueueEvent):
node_execution_id: str
node_id: str
node_type: NodeType
parallel_id: str | None = None
in_iteration_id: str | None = None
"""iteration id if node is in iteration"""
in_loop_id: str | None = None

View File

@@ -257,13 +257,8 @@ class NodeStartStreamResponse(StreamResponse):
inputs_truncated: bool = False
created_at: int
extras: dict[str, object] = Field(default_factory=dict)
parallel_id: str | None = None
parallel_start_node_id: str | None = None
parent_parallel_id: str | None = None
parent_parallel_start_node_id: str | None = None
iteration_id: str | None = None
loop_id: str | None = None
parallel_run_id: str | None = None
agent_strategy: AgentNodeStrategyInit | None = None
event: StreamEvent = StreamEvent.NODE_STARTED
@@ -285,10 +280,6 @@ class NodeStartStreamResponse(StreamResponse):
"inputs": None,
"created_at": self.data.created_at,
"extras": {},
"parallel_id": self.data.parallel_id,
"parallel_start_node_id": self.data.parallel_start_node_id,
"parent_parallel_id": self.data.parent_parallel_id,
"parent_parallel_start_node_id": self.data.parent_parallel_start_node_id,
"iteration_id": self.data.iteration_id,
"loop_id": self.data.loop_id,
},
@@ -324,10 +315,6 @@ class NodeFinishStreamResponse(StreamResponse):
created_at: int
finished_at: int
files: Sequence[Mapping[str, Any]] | None = []
parallel_id: str | None = None
parallel_start_node_id: str | None = None
parent_parallel_id: str | None = None
parent_parallel_start_node_id: str | None = None
iteration_id: str | None = None
loop_id: str | None = None
@@ -357,10 +344,6 @@ class NodeFinishStreamResponse(StreamResponse):
"created_at": self.data.created_at,
"finished_at": self.data.finished_at,
"files": [],
"parallel_id": self.data.parallel_id,
"parallel_start_node_id": self.data.parallel_start_node_id,
"parent_parallel_id": self.data.parent_parallel_id,
"parent_parallel_start_node_id": self.data.parent_parallel_start_node_id,
"iteration_id": self.data.iteration_id,
"loop_id": self.data.loop_id,
},
@@ -396,10 +379,6 @@ class NodeRetryStreamResponse(StreamResponse):
created_at: int
finished_at: int
files: Sequence[Mapping[str, Any]] | None = []
parallel_id: str | None = None
parallel_start_node_id: str | None = None
parent_parallel_id: str | None = None
parent_parallel_start_node_id: str | None = None
iteration_id: str | None = None
loop_id: str | None = None
retry_index: int = 0
@@ -430,10 +409,6 @@ class NodeRetryStreamResponse(StreamResponse):
"created_at": self.data.created_at,
"finished_at": self.data.finished_at,
"files": [],
"parallel_id": self.data.parallel_id,
"parallel_start_node_id": self.data.parallel_start_node_id,
"parent_parallel_id": self.data.parent_parallel_id,
"parent_parallel_start_node_id": self.data.parent_parallel_start_node_id,
"iteration_id": self.data.iteration_id,
"loop_id": self.data.loop_id,
"retry_index": self.data.retry_index,
@@ -541,8 +516,6 @@ class LoopNodeStartStreamResponse(StreamResponse):
metadata: Mapping = {}
inputs: Mapping = {}
inputs_truncated: bool = False
parallel_id: str | None = None
parallel_start_node_id: str | None = None
event: StreamEvent = StreamEvent.LOOP_STARTED
workflow_run_id: str
@@ -567,9 +540,6 @@ class LoopNodeNextStreamResponse(StreamResponse):
created_at: int
pre_loop_output: Any = None
extras: Mapping[str, object] = Field(default_factory=dict)
parallel_id: str | None = None
parallel_start_node_id: str | None = None
parallel_mode_run_id: str | None = None
event: StreamEvent = StreamEvent.LOOP_NEXT
workflow_run_id: str
@@ -603,8 +573,6 @@ class LoopNodeCompletedStreamResponse(StreamResponse):
execution_metadata: Mapping[str, object] = Field(default_factory=dict)
finished_at: int
steps: int
parallel_id: str | None = None
parallel_start_node_id: str | None = None
event: StreamEvent = StreamEvent.LOOP_COMPLETED
workflow_run_id: str

View File

@@ -18,7 +18,7 @@ from core.model_runtime.entities.message_entities import ImagePromptMessageConte
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
from core.prompt.prompt_transform import PromptTransform
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.runtime import VariablePool
class AdvancedPromptTransform(PromptTransform):

View File

@@ -104,7 +104,6 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
self._creator_user_role = CreatorUserRole.ACCOUNT if isinstance(user, Account) else CreatorUserRole.END_USER
# Initialize in-memory cache for node executions
# Key: node_execution_id, Value: WorkflowNodeExecution (DB model)
self._node_execution_cache: dict[str, WorkflowNodeExecutionModel] = {}
# Initialize FileService for handling offloaded data
@@ -332,17 +331,10 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
Args:
execution: The NodeExecution domain entity to persist
"""
# NOTE: As per the implementation of `WorkflowCycleManager`,
# the `save` method is invoked multiple times during the node's execution lifecycle, including:
#
# - When the node starts execution
# - When the node retries execution
# - When the node completes execution (either successfully or with failure)
#
# Only the final invocation will have `inputs` and `outputs` populated.
#
# This simplifies the logic for saving offloaded variables but introduces a tight coupling
# between this module and `WorkflowCycleManager`.
# NOTE: The workflow engine triggers `save` multiple times for a single node execution:
# when the node starts, any time it retries, and once more when it reaches a terminal state.
# Only the final call contains the complete inputs and outputs payloads, so earlier invocations
# must tolerate missing data without attempting to offload variables.
# Convert domain model to database model using tenant context and other attributes
db_model = self._to_db_model(execution)

View File

@@ -58,8 +58,8 @@ from services.tools.mcp_tools_manage_service import MCPToolManageService
from services.tools.tools_transform_service import ToolTransformService
if TYPE_CHECKING:
from core.workflow.entities import VariablePool
from core.workflow.nodes.tool.entities import ToolEntity
from core.workflow.runtime import VariablePool
logger = logging.getLogger(__name__)

View File

@@ -1,18 +1,11 @@
from .agent import AgentNodeStrategyInit
from .graph_init_params import GraphInitParams
from .graph_runtime_state import GraphRuntimeState
from .run_condition import RunCondition
from .variable_pool import VariablePool, VariableValue
from .workflow_execution import WorkflowExecution
from .workflow_node_execution import WorkflowNodeExecution
__all__ = [
"AgentNodeStrategyInit",
"GraphInitParams",
"GraphRuntimeState",
"RunCondition",
"VariablePool",
"VariableValue",
"WorkflowExecution",
"WorkflowNodeExecution",
]

View File

@@ -1,160 +0,0 @@
from copy import deepcopy
from pydantic import BaseModel, PrivateAttr
from core.model_runtime.entities.llm_entities import LLMUsage
from .variable_pool import VariablePool
class GraphRuntimeState(BaseModel):
# Private attributes to prevent direct modification
_variable_pool: VariablePool = PrivateAttr()
_start_at: float = PrivateAttr()
_total_tokens: int = PrivateAttr(default=0)
_llm_usage: LLMUsage = PrivateAttr(default_factory=LLMUsage.empty_usage)
_outputs: dict[str, object] = PrivateAttr(default_factory=dict[str, object])
_node_run_steps: int = PrivateAttr(default=0)
_ready_queue_json: str = PrivateAttr()
_graph_execution_json: str = PrivateAttr()
_response_coordinator_json: str = PrivateAttr()
def __init__(
self,
*,
variable_pool: VariablePool,
start_at: float,
total_tokens: int = 0,
llm_usage: LLMUsage | None = None,
outputs: dict[str, object] | None = None,
node_run_steps: int = 0,
ready_queue_json: str = "",
graph_execution_json: str = "",
response_coordinator_json: str = "",
**kwargs: object,
):
"""Initialize the GraphRuntimeState with validation."""
super().__init__(**kwargs)
# Initialize private attributes with validation
self._variable_pool = variable_pool
self._start_at = start_at
if total_tokens < 0:
raise ValueError("total_tokens must be non-negative")
self._total_tokens = total_tokens
if llm_usage is None:
llm_usage = LLMUsage.empty_usage()
self._llm_usage = llm_usage
if outputs is None:
outputs = {}
self._outputs = deepcopy(outputs)
if node_run_steps < 0:
raise ValueError("node_run_steps must be non-negative")
self._node_run_steps = node_run_steps
self._ready_queue_json = ready_queue_json
self._graph_execution_json = graph_execution_json
self._response_coordinator_json = response_coordinator_json
@property
def variable_pool(self) -> VariablePool:
"""Get the variable pool."""
return self._variable_pool
@property
def start_at(self) -> float:
"""Get the start time."""
return self._start_at
@start_at.setter
def start_at(self, value: float) -> None:
"""Set the start time."""
self._start_at = value
@property
def total_tokens(self) -> int:
"""Get the total tokens count."""
return self._total_tokens
@total_tokens.setter
def total_tokens(self, value: int):
"""Set the total tokens count."""
if value < 0:
raise ValueError("total_tokens must be non-negative")
self._total_tokens = value
@property
def llm_usage(self) -> LLMUsage:
"""Get the LLM usage info."""
# Return a copy to prevent external modification
return self._llm_usage.model_copy()
@llm_usage.setter
def llm_usage(self, value: LLMUsage):
"""Set the LLM usage info."""
self._llm_usage = value.model_copy()
@property
def outputs(self) -> dict[str, object]:
"""Get a copy of the outputs dictionary."""
return deepcopy(self._outputs)
@outputs.setter
def outputs(self, value: dict[str, object]) -> None:
"""Set the outputs dictionary."""
self._outputs = deepcopy(value)
def set_output(self, key: str, value: object) -> None:
"""Set a single output value."""
self._outputs[key] = deepcopy(value)
def get_output(self, key: str, default: object = None) -> object:
"""Get a single output value."""
return deepcopy(self._outputs.get(key, default))
def update_outputs(self, updates: dict[str, object]) -> None:
"""Update multiple output values."""
for key, value in updates.items():
self._outputs[key] = deepcopy(value)
@property
def node_run_steps(self) -> int:
"""Get the node run steps count."""
return self._node_run_steps
@node_run_steps.setter
def node_run_steps(self, value: int) -> None:
"""Set the node run steps count."""
if value < 0:
raise ValueError("node_run_steps must be non-negative")
self._node_run_steps = value
def increment_node_run_steps(self) -> None:
"""Increment the node run steps by 1."""
self._node_run_steps += 1
def add_tokens(self, tokens: int) -> None:
"""Add tokens to the total count."""
if tokens < 0:
raise ValueError("tokens must be non-negative")
self._total_tokens += tokens
@property
def ready_queue_json(self) -> str:
"""Get a copy of the ready queue state."""
return self._ready_queue_json
@property
def graph_execution_json(self) -> str:
"""Get a copy of the serialized graph execution state."""
return self._graph_execution_json
@property
def response_coordinator_json(self) -> str:
"""Get a copy of the serialized response coordinator state."""
return self._response_coordinator_json

View File

@@ -1,21 +0,0 @@
import hashlib
from typing import Literal
from pydantic import BaseModel
from core.workflow.utils.condition.entities import Condition
class RunCondition(BaseModel):
type: Literal["branch_identify", "condition"]
"""condition type"""
branch_identify: str | None = None
"""branch identify like: sourceHandle, required when type is branch_identify"""
conditions: list[Condition] | None = None
"""conditions to run the node, required when type is condition"""
@property
def hash(self) -> str:
return hashlib.sha256(self.model_dump_json().encode()).hexdigest()

View File

@@ -58,6 +58,7 @@ class NodeType(StrEnum):
DOCUMENT_EXTRACTOR = "document-extractor"
LIST_OPERATOR = "list-operator"
AGENT = "agent"
HUMAN_INPUT = "human-input"
class NodeExecutionType(StrEnum):
@@ -96,6 +97,7 @@ class WorkflowExecutionStatus(StrEnum):
FAILED = "failed"
STOPPED = "stopped"
PARTIAL_SUCCEEDED = "partial-succeeded"
PAUSED = "paused"
class WorkflowNodeExecutionMetadataKey(StrEnum):

View File

@@ -1,16 +1,11 @@
from .edge import Edge
from .graph import Graph, NodeFactory
from .graph_runtime_state_protocol import ReadOnlyGraphRuntimeState, ReadOnlyVariablePool
from .graph import Graph, GraphBuilder, NodeFactory
from .graph_template import GraphTemplate
from .read_only_state_wrapper import ReadOnlyGraphRuntimeStateWrapper, ReadOnlyVariablePoolWrapper
__all__ = [
"Edge",
"Graph",
"GraphBuilder",
"GraphTemplate",
"NodeFactory",
"ReadOnlyGraphRuntimeState",
"ReadOnlyGraphRuntimeStateWrapper",
"ReadOnlyVariablePool",
"ReadOnlyVariablePoolWrapper",
]

View File

@@ -195,6 +195,12 @@ class Graph:
return nodes
@classmethod
def new(cls) -> "GraphBuilder":
"""Create a fluent builder for assembling a graph programmatically."""
return GraphBuilder(graph_cls=cls)
@classmethod
def _mark_inactive_root_branches(
cls,
@@ -344,3 +350,96 @@ class Graph:
"""
edge_ids = self.in_edges.get(node_id, [])
return [self.edges[eid] for eid in edge_ids if eid in self.edges]
@final
class GraphBuilder:
"""Fluent helper for constructing simple graphs, primarily for tests."""
def __init__(self, *, graph_cls: type[Graph]):
self._graph_cls = graph_cls
self._nodes: list[Node] = []
self._nodes_by_id: dict[str, Node] = {}
self._edges: list[Edge] = []
self._edge_counter = 0
def add_root(self, node: Node) -> "GraphBuilder":
"""Register the root node. Must be called exactly once."""
if self._nodes:
raise ValueError("Root node has already been added")
self._register_node(node)
self._nodes.append(node)
return self
def add_node(
self,
node: Node,
*,
from_node_id: str | None = None,
source_handle: str = "source",
) -> "GraphBuilder":
"""Append a node and connect it from the specified predecessor."""
if not self._nodes:
raise ValueError("Root node must be added before adding other nodes")
predecessor_id = from_node_id or self._nodes[-1].id
if predecessor_id not in self._nodes_by_id:
raise ValueError(f"Predecessor node '{predecessor_id}' not found")
predecessor = self._nodes_by_id[predecessor_id]
self._register_node(node)
self._nodes.append(node)
edge_id = f"edge_{self._edge_counter}"
self._edge_counter += 1
edge = Edge(id=edge_id, tail=predecessor.id, head=node.id, source_handle=source_handle)
self._edges.append(edge)
return self
def connect(self, *, tail: str, head: str, source_handle: str = "source") -> "GraphBuilder":
"""Connect two existing nodes without adding a new node."""
if tail not in self._nodes_by_id:
raise ValueError(f"Tail node '{tail}' not found")
if head not in self._nodes_by_id:
raise ValueError(f"Head node '{head}' not found")
edge_id = f"edge_{self._edge_counter}"
self._edge_counter += 1
edge = Edge(id=edge_id, tail=tail, head=head, source_handle=source_handle)
self._edges.append(edge)
return self
def build(self) -> Graph:
"""Materialize the graph instance from the accumulated nodes and edges."""
if not self._nodes:
raise ValueError("Cannot build an empty graph")
nodes = {node.id: node for node in self._nodes}
edges = {edge.id: edge for edge in self._edges}
in_edges: dict[str, list[str]] = defaultdict(list)
out_edges: dict[str, list[str]] = defaultdict(list)
for edge in self._edges:
out_edges[edge.tail].append(edge.id)
in_edges[edge.head].append(edge.id)
return self._graph_cls(
nodes=nodes,
edges=edges,
in_edges=dict(in_edges),
out_edges=dict(out_edges),
root_node=self._nodes[0],
)
def _register_node(self, node: Node) -> None:
if not node.id:
raise ValueError("Node must have a non-empty id")
if node.id in self._nodes_by_id:
raise ValueError(f"Duplicate node id detected: {node.id}")
self._nodes_by_id[node.id] = node

View File

@@ -9,7 +9,7 @@ Each instance uses a unique key for its command queue.
import json
from typing import TYPE_CHECKING, Any, final
from ..entities.commands import AbortCommand, CommandType, GraphEngineCommand
from ..entities.commands import AbortCommand, CommandType, GraphEngineCommand, PauseCommand
if TYPE_CHECKING:
from extensions.ext_redis import RedisClientWrapper
@@ -111,9 +111,11 @@ class RedisChannel:
if command_type == CommandType.ABORT:
return AbortCommand.model_validate(data)
else:
# For other command types, use base class
return GraphEngineCommand.model_validate(data)
if command_type == CommandType.PAUSE:
return PauseCommand.model_validate(data)
# For other command types, use base class
return GraphEngineCommand.model_validate(data)
except (ValueError, TypeError):
return None

View File

@@ -5,10 +5,11 @@ This package handles external commands sent to the engine
during execution.
"""
from .command_handlers import AbortCommandHandler
from .command_handlers import AbortCommandHandler, PauseCommandHandler
from .command_processor import CommandProcessor
__all__ = [
"AbortCommandHandler",
"CommandProcessor",
"PauseCommandHandler",
]

View File

@@ -1,14 +1,10 @@
"""
Command handler implementations.
"""
import logging
from typing import final
from typing_extensions import override
from ..domain.graph_execution import GraphExecution
from ..entities.commands import AbortCommand, GraphEngineCommand
from ..entities.commands import AbortCommand, GraphEngineCommand, PauseCommand
from .command_processor import CommandHandler
logger = logging.getLogger(__name__)
@@ -16,17 +12,17 @@ logger = logging.getLogger(__name__)
@final
class AbortCommandHandler(CommandHandler):
"""Handles abort commands."""
@override
def handle(self, command: GraphEngineCommand, execution: GraphExecution) -> None:
"""
Handle an abort command.
Args:
command: The abort command
execution: Graph execution to abort
"""
assert isinstance(command, AbortCommand)
logger.debug("Aborting workflow %s: %s", execution.workflow_id, command.reason)
execution.abort(command.reason or "User requested abort")
@final
class PauseCommandHandler(CommandHandler):
@override
def handle(self, command: GraphEngineCommand, execution: GraphExecution) -> None:
assert isinstance(command, PauseCommand)
logger.debug("Pausing workflow %s: %s", execution.workflow_id, command.reason)
execution.pause(command.reason)

View File

@@ -40,6 +40,8 @@ class GraphExecutionState(BaseModel):
started: bool = Field(default=False)
completed: bool = Field(default=False)
aborted: bool = Field(default=False)
paused: bool = Field(default=False)
pause_reason: str | None = Field(default=None)
error: GraphExecutionErrorState | None = Field(default=None)
exceptions_count: int = Field(default=0)
node_executions: list[NodeExecutionState] = Field(default_factory=list[NodeExecutionState])
@@ -103,6 +105,8 @@ class GraphExecution:
started: bool = False
completed: bool = False
aborted: bool = False
paused: bool = False
pause_reason: str | None = None
error: Exception | None = None
node_executions: dict[str, NodeExecution] = field(default_factory=dict[str, NodeExecution])
exceptions_count: int = 0
@@ -126,6 +130,17 @@ class GraphExecution:
self.aborted = True
self.error = RuntimeError(f"Aborted: {reason}")
def pause(self, reason: str | None = None) -> None:
"""Pause the graph execution without marking it complete."""
if self.completed:
raise RuntimeError("Cannot pause execution that has completed")
if self.aborted:
raise RuntimeError("Cannot pause execution that has been aborted")
if self.paused:
return
self.paused = True
self.pause_reason = reason
def fail(self, error: Exception) -> None:
"""Mark the graph execution as failed."""
self.error = error
@@ -140,7 +155,12 @@ class GraphExecution:
@property
def is_running(self) -> bool:
"""Check if the execution is currently running."""
return self.started and not self.completed and not self.aborted
return self.started and not self.completed and not self.aborted and not self.paused
@property
def is_paused(self) -> bool:
"""Check if the execution is currently paused."""
return self.paused
@property
def has_error(self) -> bool:
@@ -173,6 +193,8 @@ class GraphExecution:
started=self.started,
completed=self.completed,
aborted=self.aborted,
paused=self.paused,
pause_reason=self.pause_reason,
error=_serialize_error(self.error),
exceptions_count=self.exceptions_count,
node_executions=node_states,
@@ -197,6 +219,8 @@ class GraphExecution:
self.started = state.started
self.completed = state.completed
self.aborted = state.aborted
self.paused = state.paused
self.pause_reason = state.pause_reason
self.error = _deserialize_error(state.error)
self.exceptions_count = state.exceptions_count
self.node_executions = {

View File

@@ -16,7 +16,6 @@ class CommandType(StrEnum):
ABORT = "abort"
PAUSE = "pause"
RESUME = "resume"
class GraphEngineCommand(BaseModel):
@@ -31,3 +30,10 @@ class AbortCommand(GraphEngineCommand):
command_type: CommandType = Field(default=CommandType.ABORT, description="Type of command")
reason: str | None = Field(default=None, description="Optional reason for abort")
class PauseCommand(GraphEngineCommand):
"""Command to pause a running workflow execution."""
command_type: CommandType = Field(default=CommandType.PAUSE, description="Type of command")
reason: str | None = Field(default=None, description="Optional reason for pause")

View File

@@ -8,8 +8,7 @@ from functools import singledispatchmethod
from typing import TYPE_CHECKING, final
from core.model_runtime.entities.llm_entities import LLMUsage
from core.workflow.entities import GraphRuntimeState
from core.workflow.enums import ErrorStrategy, NodeExecutionType
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeState
from core.workflow.graph import Graph
from core.workflow.graph_events import (
GraphNodeEventBase,
@@ -24,11 +23,13 @@ from core.workflow.graph_events import (
NodeRunLoopNextEvent,
NodeRunLoopStartedEvent,
NodeRunLoopSucceededEvent,
NodeRunPauseRequestedEvent,
NodeRunRetryEvent,
NodeRunStartedEvent,
NodeRunStreamChunkEvent,
NodeRunSucceededEvent,
)
from core.workflow.runtime import GraphRuntimeState
from ..domain.graph_execution import GraphExecution
from ..response_coordinator import ResponseStreamCoordinator
@@ -203,6 +204,18 @@ class EventHandler:
# Collect the event
self._event_collector.collect(event)
@_dispatch.register
def _(self, event: NodeRunPauseRequestedEvent) -> None:
"""Handle pause requests emitted by nodes."""
pause_reason = event.reason or "Awaiting human input"
self._graph_execution.pause(pause_reason)
self._state_manager.finish_execution(event.node_id)
if event.node_id in self._graph.nodes:
self._graph.nodes[event.node_id].state = NodeState.UNKNOWN
self._graph_runtime_state.register_paused_node(event.node_id)
self._event_collector.collect(event)
@_dispatch.register
def _(self, event: NodeRunFailedEvent) -> None:
"""

View File

@@ -97,6 +97,10 @@ class EventManager:
"""
self._layers = layers
def notify_layers(self, event: GraphEngineEvent) -> None:
"""Notify registered layers about an event without buffering it."""
self._notify_layers(event)
def collect(self, event: GraphEngineEvent) -> None:
"""
Thread-safe method to collect an event.

View File

@@ -9,28 +9,29 @@ import contextvars
import logging
import queue
from collections.abc import Generator
from typing import final
from typing import TYPE_CHECKING, cast, final
from flask import Flask, current_app
from core.workflow.entities import GraphRuntimeState
from core.workflow.enums import NodeExecutionType
from core.workflow.graph import Graph
from core.workflow.graph.read_only_state_wrapper import ReadOnlyGraphRuntimeStateWrapper
from core.workflow.graph_engine.ready_queue import InMemoryReadyQueue
from core.workflow.graph_events import (
GraphEngineEvent,
GraphNodeEventBase,
GraphRunAbortedEvent,
GraphRunFailedEvent,
GraphRunPartialSucceededEvent,
GraphRunPausedEvent,
GraphRunStartedEvent,
GraphRunSucceededEvent,
)
from core.workflow.runtime import GraphRuntimeState, ReadOnlyGraphRuntimeStateWrapper
from .command_processing import AbortCommandHandler, CommandProcessor
from .domain import GraphExecution
from .entities.commands import AbortCommand
if TYPE_CHECKING: # pragma: no cover - used only for static analysis
from core.workflow.runtime.graph_runtime_state import GraphProtocol
from .command_processing import AbortCommandHandler, CommandProcessor, PauseCommandHandler
from .entities.commands import AbortCommand, PauseCommand
from .error_handler import ErrorHandler
from .event_management import EventHandler, EventManager
from .graph_state_manager import GraphStateManager
@@ -38,10 +39,13 @@ from .graph_traversal import EdgeProcessor, SkipPropagator
from .layers.base import GraphEngineLayer
from .orchestration import Dispatcher, ExecutionCoordinator
from .protocols.command_channel import CommandChannel
from .ready_queue import ReadyQueue, ReadyQueueState, create_ready_queue_from_state
from .response_coordinator import ResponseStreamCoordinator
from .ready_queue import ReadyQueue
from .worker_management import WorkerPool
if TYPE_CHECKING:
from core.workflow.graph_engine.domain.graph_execution import GraphExecution
from core.workflow.graph_engine.response_coordinator import ResponseStreamCoordinator
logger = logging.getLogger(__name__)
@@ -67,17 +71,16 @@ class GraphEngine:
) -> None:
"""Initialize the graph engine with all subsystems and dependencies."""
# Graph execution tracks the overall execution state
self._graph_execution = GraphExecution(workflow_id=workflow_id)
if graph_runtime_state.graph_execution_json != "":
self._graph_execution.loads(graph_runtime_state.graph_execution_json)
# === Core Dependencies ===
# Graph structure and configuration
# Bind runtime state to current workflow context
self._graph = graph
self._graph_runtime_state = graph_runtime_state
self._graph_runtime_state.configure(graph=cast("GraphProtocol", graph))
self._command_channel = command_channel
# Graph execution tracks the overall execution state
self._graph_execution = cast("GraphExecution", self._graph_runtime_state.graph_execution)
self._graph_execution.workflow_id = workflow_id
# === Worker Management Parameters ===
# Parameters for dynamic worker pool scaling
self._min_workers = min_workers
@@ -86,13 +89,7 @@ class GraphEngine:
self._scale_down_idle_time = scale_down_idle_time
# === Execution Queues ===
# Create ready queue from saved state or initialize new one
self._ready_queue: ReadyQueue
if self._graph_runtime_state.ready_queue_json == "":
self._ready_queue = InMemoryReadyQueue()
else:
ready_queue_state = ReadyQueueState.model_validate_json(self._graph_runtime_state.ready_queue_json)
self._ready_queue = create_ready_queue_from_state(ready_queue_state)
self._ready_queue = cast(ReadyQueue, self._graph_runtime_state.ready_queue)
# Queue for events generated during execution
self._event_queue: queue.Queue[GraphNodeEventBase] = queue.Queue()
@@ -103,11 +100,7 @@ class GraphEngine:
# === Response Coordination ===
# Coordinates response streaming from response nodes
self._response_coordinator = ResponseStreamCoordinator(
variable_pool=self._graph_runtime_state.variable_pool, graph=self._graph
)
if graph_runtime_state.response_coordinator_json != "":
self._response_coordinator.loads(graph_runtime_state.response_coordinator_json)
self._response_coordinator = cast("ResponseStreamCoordinator", self._graph_runtime_state.response_coordinator)
# === Event Management ===
# Event manager handles both collection and emission of events
@@ -133,19 +126,6 @@ class GraphEngine:
skip_propagator=self._skip_propagator,
)
# === Event Handler Registry ===
# Central registry for handling all node execution events
self._event_handler_registry = EventHandler(
graph=self._graph,
graph_runtime_state=self._graph_runtime_state,
graph_execution=self._graph_execution,
response_coordinator=self._response_coordinator,
event_collector=self._event_manager,
edge_processor=self._edge_processor,
state_manager=self._state_manager,
error_handler=self._error_handler,
)
# === Command Processing ===
# Processes external commands (e.g., abort requests)
self._command_processor = CommandProcessor(
@@ -153,12 +133,12 @@ class GraphEngine:
graph_execution=self._graph_execution,
)
# Register abort command handler
# Register command handlers
abort_handler = AbortCommandHandler()
self._command_processor.register_handler(
AbortCommand,
abort_handler,
)
self._command_processor.register_handler(AbortCommand, abort_handler)
pause_handler = PauseCommandHandler()
self._command_processor.register_handler(PauseCommand, pause_handler)
# === Worker Pool Setup ===
# Capture Flask app context for worker threads
@@ -191,12 +171,23 @@ class GraphEngine:
self._execution_coordinator = ExecutionCoordinator(
graph_execution=self._graph_execution,
state_manager=self._state_manager,
event_handler=self._event_handler_registry,
event_collector=self._event_manager,
command_processor=self._command_processor,
worker_pool=self._worker_pool,
)
# === Event Handler Registry ===
# Central registry for handling all node execution events
self._event_handler_registry = EventHandler(
graph=self._graph,
graph_runtime_state=self._graph_runtime_state,
graph_execution=self._graph_execution,
response_coordinator=self._response_coordinator,
event_collector=self._event_manager,
edge_processor=self._edge_processor,
state_manager=self._state_manager,
error_handler=self._error_handler,
)
# Dispatches events and manages execution flow
self._dispatcher = Dispatcher(
event_queue=self._event_queue,
@@ -237,26 +228,41 @@ class GraphEngine:
# Initialize layers
self._initialize_layers()
# Start execution
self._graph_execution.start()
is_resume = self._graph_execution.started
if not is_resume:
self._graph_execution.start()
else:
self._graph_execution.paused = False
self._graph_execution.pause_reason = None
start_event = GraphRunStartedEvent()
self._event_manager.notify_layers(start_event)
yield start_event
# Start subsystems
self._start_execution()
self._start_execution(resume=is_resume)
# Yield events as they occur
yield from self._event_manager.emit_events()
# Handle completion
if self._graph_execution.aborted:
if self._graph_execution.is_paused:
paused_event = GraphRunPausedEvent(
reason=self._graph_execution.pause_reason,
outputs=self._graph_runtime_state.outputs,
)
self._event_manager.notify_layers(paused_event)
yield paused_event
elif self._graph_execution.aborted:
abort_reason = "Workflow execution aborted by user command"
if self._graph_execution.error:
abort_reason = str(self._graph_execution.error)
yield GraphRunAbortedEvent(
aborted_event = GraphRunAbortedEvent(
reason=abort_reason,
outputs=self._graph_runtime_state.outputs,
)
self._event_manager.notify_layers(aborted_event)
yield aborted_event
elif self._graph_execution.has_error:
if self._graph_execution.error:
raise self._graph_execution.error
@@ -264,20 +270,26 @@ class GraphEngine:
outputs = self._graph_runtime_state.outputs
exceptions_count = self._graph_execution.exceptions_count
if exceptions_count > 0:
yield GraphRunPartialSucceededEvent(
partial_event = GraphRunPartialSucceededEvent(
exceptions_count=exceptions_count,
outputs=outputs,
)
self._event_manager.notify_layers(partial_event)
yield partial_event
else:
yield GraphRunSucceededEvent(
succeeded_event = GraphRunSucceededEvent(
outputs=outputs,
)
self._event_manager.notify_layers(succeeded_event)
yield succeeded_event
except Exception as e:
yield GraphRunFailedEvent(
failed_event = GraphRunFailedEvent(
error=str(e),
exceptions_count=self._graph_execution.exceptions_count,
)
self._event_manager.notify_layers(failed_event)
yield failed_event
raise
finally:
@@ -299,8 +311,12 @@ class GraphEngine:
except Exception as e:
logger.warning("Layer %s failed on_graph_start: %s", layer.__class__.__name__, e)
def _start_execution(self) -> None:
def _start_execution(self, *, resume: bool = False) -> None:
"""Start execution subsystems."""
paused_nodes: list[str] = []
if resume:
paused_nodes = self._graph_runtime_state.consume_paused_nodes()
# Start worker pool (it calculates initial workers internally)
self._worker_pool.start()
@@ -309,10 +325,15 @@ class GraphEngine:
if node.execution_type == NodeExecutionType.RESPONSE:
self._response_coordinator.register(node.id)
# Enqueue root node
root_node = self._graph.root_node
self._state_manager.enqueue_node(root_node.id)
self._state_manager.start_execution(root_node.id)
if not resume:
# Enqueue root node
root_node = self._graph.root_node
self._state_manager.enqueue_node(root_node.id)
self._state_manager.start_execution(root_node.id)
else:
for node_id in paused_nodes:
self._state_manager.enqueue_node(node_id)
self._state_manager.start_execution(node_id)
# Start dispatcher
self._dispatcher.start()

View File

@@ -7,9 +7,9 @@ intercept and respond to GraphEngine events.
from abc import ABC, abstractmethod
from core.workflow.graph.graph_runtime_state_protocol import ReadOnlyGraphRuntimeState
from core.workflow.graph_engine.protocols.command_channel import CommandChannel
from core.workflow.graph_events import GraphEngineEvent
from core.workflow.runtime import ReadOnlyGraphRuntimeState
class GraphEngineLayer(ABC):

View File

@@ -0,0 +1,410 @@
"""Workflow persistence layer for GraphEngine.
This layer mirrors the former ``WorkflowCycleManager`` responsibilities by
listening to ``GraphEngineEvent`` instances directly and persisting workflow
and node execution state via the injected repositories.
The design keeps domain persistence concerns inside the engine thread, while
allowing presentation layers to remain read-only observers of repository
state.
"""
from collections.abc import Mapping
from dataclasses import dataclass
from datetime import datetime
from typing import Any, Union
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity
from core.ops.entities.trace_entity import TraceTaskName
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
from core.workflow.entities import WorkflowExecution, WorkflowNodeExecution
from core.workflow.enums import (
SystemVariableKey,
WorkflowExecutionStatus,
WorkflowNodeExecutionMetadataKey,
WorkflowNodeExecutionStatus,
WorkflowType,
)
from core.workflow.graph_engine.layers.base import GraphEngineLayer
from core.workflow.graph_events import (
GraphEngineEvent,
GraphRunAbortedEvent,
GraphRunFailedEvent,
GraphRunPartialSucceededEvent,
GraphRunPausedEvent,
GraphRunStartedEvent,
GraphRunSucceededEvent,
NodeRunExceptionEvent,
NodeRunFailedEvent,
NodeRunPauseRequestedEvent,
NodeRunRetryEvent,
NodeRunStartedEvent,
NodeRunSucceededEvent,
)
from core.workflow.node_events import NodeRunResult
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from core.workflow.workflow_entry import WorkflowEntry
from libs.datetime_utils import naive_utc_now
@dataclass(slots=True)
class PersistenceWorkflowInfo:
"""Static workflow metadata required for persistence."""
workflow_id: str
workflow_type: WorkflowType
version: str
graph_data: Mapping[str, Any]
@dataclass(slots=True)
class _NodeRuntimeSnapshot:
"""Lightweight cache to keep node metadata across event phases."""
node_id: str
title: str
predecessor_node_id: str | None
iteration_id: str | None
loop_id: str | None
created_at: datetime
class WorkflowPersistenceLayer(GraphEngineLayer):
"""GraphEngine layer that persists workflow and node execution state."""
def __init__(
self,
*,
application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity],
workflow_info: PersistenceWorkflowInfo,
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
trace_manager: TraceQueueManager | None = None,
) -> None:
super().__init__()
self._application_generate_entity = application_generate_entity
self._workflow_info = workflow_info
self._workflow_execution_repository = workflow_execution_repository
self._workflow_node_execution_repository = workflow_node_execution_repository
self._trace_manager = trace_manager
self._workflow_execution: WorkflowExecution | None = None
self._node_execution_cache: dict[str, WorkflowNodeExecution] = {}
self._node_snapshots: dict[str, _NodeRuntimeSnapshot] = {}
self._node_sequence: int = 0
# ------------------------------------------------------------------
# GraphEngineLayer lifecycle
# ------------------------------------------------------------------
def on_graph_start(self) -> None:
self._workflow_execution = None
self._node_execution_cache.clear()
self._node_snapshots.clear()
self._node_sequence = 0
def on_event(self, event: GraphEngineEvent) -> None:
if isinstance(event, GraphRunStartedEvent):
self._handle_graph_run_started()
return
if isinstance(event, GraphRunSucceededEvent):
self._handle_graph_run_succeeded(event)
return
if isinstance(event, GraphRunPartialSucceededEvent):
self._handle_graph_run_partial_succeeded(event)
return
if isinstance(event, GraphRunFailedEvent):
self._handle_graph_run_failed(event)
return
if isinstance(event, GraphRunAbortedEvent):
self._handle_graph_run_aborted(event)
return
if isinstance(event, GraphRunPausedEvent):
self._handle_graph_run_paused(event)
return
if isinstance(event, NodeRunStartedEvent):
self._handle_node_started(event)
return
if isinstance(event, NodeRunRetryEvent):
self._handle_node_retry(event)
return
if isinstance(event, NodeRunSucceededEvent):
self._handle_node_succeeded(event)
return
if isinstance(event, NodeRunFailedEvent):
self._handle_node_failed(event)
return
if isinstance(event, NodeRunExceptionEvent):
self._handle_node_exception(event)
return
if isinstance(event, NodeRunPauseRequestedEvent):
self._handle_node_pause_requested(event)
def on_graph_end(self, error: Exception | None) -> None:
return
# ------------------------------------------------------------------
# Graph-level handlers
# ------------------------------------------------------------------
def _handle_graph_run_started(self) -> None:
execution_id = self._get_execution_id()
workflow_execution = WorkflowExecution.new(
id_=execution_id,
workflow_id=self._workflow_info.workflow_id,
workflow_type=self._workflow_info.workflow_type,
workflow_version=self._workflow_info.version,
graph=self._workflow_info.graph_data,
inputs=self._prepare_workflow_inputs(),
started_at=naive_utc_now(),
)
self._workflow_execution_repository.save(workflow_execution)
self._workflow_execution = workflow_execution
def _handle_graph_run_succeeded(self, event: GraphRunSucceededEvent) -> None:
execution = self._get_workflow_execution()
execution.outputs = event.outputs
execution.status = WorkflowExecutionStatus.SUCCEEDED
self._populate_completion_statistics(execution)
self._workflow_execution_repository.save(execution)
self._enqueue_trace_task(execution)
def _handle_graph_run_partial_succeeded(self, event: GraphRunPartialSucceededEvent) -> None:
execution = self._get_workflow_execution()
execution.outputs = event.outputs
execution.status = WorkflowExecutionStatus.PARTIAL_SUCCEEDED
execution.exceptions_count = event.exceptions_count
self._populate_completion_statistics(execution)
self._workflow_execution_repository.save(execution)
self._enqueue_trace_task(execution)
def _handle_graph_run_failed(self, event: GraphRunFailedEvent) -> None:
execution = self._get_workflow_execution()
execution.status = WorkflowExecutionStatus.FAILED
execution.error_message = event.error
execution.exceptions_count = event.exceptions_count
self._populate_completion_statistics(execution)
self._fail_running_node_executions(error_message=event.error)
self._workflow_execution_repository.save(execution)
self._enqueue_trace_task(execution)
def _handle_graph_run_aborted(self, event: GraphRunAbortedEvent) -> None:
execution = self._get_workflow_execution()
execution.status = WorkflowExecutionStatus.STOPPED
execution.error_message = event.reason or "Workflow execution aborted"
self._populate_completion_statistics(execution)
self._fail_running_node_executions(error_message=execution.error_message or "")
self._workflow_execution_repository.save(execution)
self._enqueue_trace_task(execution)
def _handle_graph_run_paused(self, event: GraphRunPausedEvent) -> None:
execution = self._get_workflow_execution()
execution.status = WorkflowExecutionStatus.PAUSED
execution.error_message = event.reason or "Workflow execution paused"
execution.outputs = event.outputs
self._populate_completion_statistics(execution, update_finished=False)
self._workflow_execution_repository.save(execution)
# ------------------------------------------------------------------
# Node-level handlers
# ------------------------------------------------------------------
def _handle_node_started(self, event: NodeRunStartedEvent) -> None:
execution = self._get_workflow_execution()
metadata = {
WorkflowNodeExecutionMetadataKey.ITERATION_ID: event.in_iteration_id,
WorkflowNodeExecutionMetadataKey.LOOP_ID: event.in_loop_id,
}
domain_execution = WorkflowNodeExecution(
id=event.id,
node_execution_id=event.id,
workflow_id=execution.workflow_id,
workflow_execution_id=execution.id_,
predecessor_node_id=event.predecessor_node_id,
index=self._next_node_sequence(),
node_id=event.node_id,
node_type=event.node_type,
title=event.node_title,
status=WorkflowNodeExecutionStatus.RUNNING,
metadata=metadata,
created_at=event.start_at,
)
self._node_execution_cache[event.id] = domain_execution
self._workflow_node_execution_repository.save(domain_execution)
snapshot = _NodeRuntimeSnapshot(
node_id=event.node_id,
title=event.node_title,
predecessor_node_id=event.predecessor_node_id,
iteration_id=event.in_iteration_id,
loop_id=event.in_loop_id,
created_at=event.start_at,
)
self._node_snapshots[event.id] = snapshot
def _handle_node_retry(self, event: NodeRunRetryEvent) -> None:
domain_execution = self._get_node_execution(event.id)
domain_execution.status = WorkflowNodeExecutionStatus.RETRY
domain_execution.error = event.error
self._workflow_node_execution_repository.save(domain_execution)
self._workflow_node_execution_repository.save_execution_data(domain_execution)
def _handle_node_succeeded(self, event: NodeRunSucceededEvent) -> None:
domain_execution = self._get_node_execution(event.id)
self._update_node_execution(domain_execution, event.node_run_result, WorkflowNodeExecutionStatus.SUCCEEDED)
def _handle_node_failed(self, event: NodeRunFailedEvent) -> None:
domain_execution = self._get_node_execution(event.id)
self._update_node_execution(
domain_execution,
event.node_run_result,
WorkflowNodeExecutionStatus.FAILED,
error=event.error,
)
def _handle_node_exception(self, event: NodeRunExceptionEvent) -> None:
domain_execution = self._get_node_execution(event.id)
self._update_node_execution(
domain_execution,
event.node_run_result,
WorkflowNodeExecutionStatus.EXCEPTION,
error=event.error,
)
def _handle_node_pause_requested(self, event: NodeRunPauseRequestedEvent) -> None:
domain_execution = self._get_node_execution(event.id)
self._update_node_execution(
domain_execution,
event.node_run_result,
WorkflowNodeExecutionStatus.PAUSED,
error=event.reason,
update_outputs=False,
)
# ------------------------------------------------------------------
# Helpers
# ------------------------------------------------------------------
def _get_execution_id(self) -> str:
workflow_execution_id = self._system_variables().get(SystemVariableKey.WORKFLOW_EXECUTION_ID)
if not workflow_execution_id:
raise ValueError("workflow_execution_id must be provided in system variables for pause/resume flows")
return str(workflow_execution_id)
def _prepare_workflow_inputs(self) -> Mapping[str, Any]:
inputs = {**self._application_generate_entity.inputs}
for field_name, value in self._system_variables().items():
if field_name == SystemVariableKey.CONVERSATION_ID.value:
# Conversation IDs are tied to the current session; omit them so persisted
# workflow inputs stay reusable without binding future runs to this conversation.
continue
inputs[f"sys.{field_name}"] = value
handled = WorkflowEntry.handle_special_values(inputs)
return handled or {}
def _get_workflow_execution(self) -> WorkflowExecution:
if self._workflow_execution is None:
raise ValueError("workflow execution not initialized")
return self._workflow_execution
def _get_node_execution(self, node_execution_id: str) -> WorkflowNodeExecution:
if node_execution_id not in self._node_execution_cache:
raise ValueError(f"Node execution not found for id={node_execution_id}")
return self._node_execution_cache[node_execution_id]
def _next_node_sequence(self) -> int:
self._node_sequence += 1
return self._node_sequence
def _populate_completion_statistics(self, execution: WorkflowExecution, *, update_finished: bool = True) -> None:
if update_finished:
execution.finished_at = naive_utc_now()
runtime_state = self.graph_runtime_state
if runtime_state is None:
return
execution.total_tokens = runtime_state.total_tokens
execution.total_steps = runtime_state.node_run_steps
execution.outputs = execution.outputs or runtime_state.outputs
execution.exceptions_count = runtime_state.exceptions_count
def _update_node_execution(
self,
domain_execution: WorkflowNodeExecution,
node_result: NodeRunResult,
status: WorkflowNodeExecutionStatus,
*,
error: str | None = None,
update_outputs: bool = True,
) -> None:
finished_at = naive_utc_now()
snapshot = self._node_snapshots.get(domain_execution.id)
start_at = snapshot.created_at if snapshot else domain_execution.created_at
domain_execution.status = status
domain_execution.finished_at = finished_at
domain_execution.elapsed_time = max((finished_at - start_at).total_seconds(), 0.0)
if error:
domain_execution.error = error
if update_outputs:
domain_execution.update_from_mapping(
inputs=node_result.inputs,
process_data=node_result.process_data,
outputs=node_result.outputs,
metadata=node_result.metadata,
)
self._workflow_node_execution_repository.save(domain_execution)
self._workflow_node_execution_repository.save_execution_data(domain_execution)
def _fail_running_node_executions(self, *, error_message: str) -> None:
now = naive_utc_now()
for execution in self._node_execution_cache.values():
if execution.status == WorkflowNodeExecutionStatus.RUNNING:
execution.status = WorkflowNodeExecutionStatus.FAILED
execution.error = error_message
execution.finished_at = now
execution.elapsed_time = max((now - execution.created_at).total_seconds(), 0.0)
self._workflow_node_execution_repository.save(execution)
def _enqueue_trace_task(self, execution: WorkflowExecution) -> None:
if not self._trace_manager:
return
conversation_id = self._system_variables().get(SystemVariableKey.CONVERSATION_ID.value)
external_trace_id = None
if isinstance(self._application_generate_entity, (WorkflowAppGenerateEntity, AdvancedChatAppGenerateEntity)):
external_trace_id = self._application_generate_entity.extras.get("external_trace_id")
trace_task = TraceTask(
TraceTaskName.WORKFLOW_TRACE,
workflow_execution=execution,
conversation_id=conversation_id,
user_id=self._trace_manager.user_id,
external_trace_id=external_trace_id,
)
self._trace_manager.add_trace_task(trace_task)
def _system_variables(self) -> Mapping[str, Any]:
runtime_state = self.graph_runtime_state
if runtime_state is None:
return {}
return runtime_state.variable_pool.get_by_prefix(SYSTEM_VARIABLE_NODE_ID)

View File

@@ -9,7 +9,7 @@ Supports stop, pause, and resume operations.
from typing import final
from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
from core.workflow.graph_engine.entities.commands import AbortCommand
from core.workflow.graph_engine.entities.commands import AbortCommand, GraphEngineCommand, PauseCommand
from extensions.ext_redis import redis_client
@@ -20,7 +20,7 @@ class GraphEngineManager:
This class provides a simple interface for controlling workflow executions
by sending commands through Redis channels, without user validation.
Supports stop, pause, and resume operations.
Supports stop and pause operations.
"""
@staticmethod
@@ -32,19 +32,29 @@ class GraphEngineManager:
task_id: The task ID of the workflow to stop
reason: Optional reason for stopping (defaults to "User requested stop")
"""
abort_command = AbortCommand(reason=reason or "User requested stop")
GraphEngineManager._send_command(task_id, abort_command)
@staticmethod
def send_pause_command(task_id: str, reason: str | None = None) -> None:
"""Send a pause command to a running workflow."""
pause_command = PauseCommand(reason=reason or "User requested pause")
GraphEngineManager._send_command(task_id, pause_command)
@staticmethod
def _send_command(task_id: str, command: GraphEngineCommand) -> None:
"""Send a command to the workflow-specific Redis channel."""
if not task_id:
return
# Create Redis channel for this task
channel_key = f"workflow:{task_id}:commands"
channel = RedisChannel(redis_client, channel_key)
# Create and send abort command
abort_command = AbortCommand(reason=reason or "User requested stop")
try:
channel.send_command(abort_command)
channel.send_command(command)
except Exception:
# Silently fail if Redis is unavailable
# The legacy stop flag mechanism will still work
# The legacy control mechanisms will still work
pass

View File

@@ -33,6 +33,12 @@ class Dispatcher:
with timeout and completion detection.
"""
_COMMAND_TRIGGER_EVENTS = (
NodeRunSucceededEvent,
NodeRunFailedEvent,
NodeRunExceptionEvent,
)
def __init__(
self,
event_queue: queue.Queue[GraphNodeEventBase],
@@ -77,33 +83,41 @@ class Dispatcher:
if self._thread and self._thread.is_alive():
self._thread.join(timeout=10.0)
_COMMAND_TRIGGER_EVENTS = (
NodeRunSucceededEvent,
NodeRunFailedEvent,
NodeRunExceptionEvent,
)
def _dispatcher_loop(self) -> None:
"""Main dispatcher loop."""
try:
while not self._stop_event.is_set():
# Check for scaling
self._execution_coordinator.check_scaling()
commands_checked = False
should_check_commands = False
should_break = False
# Process events
try:
event = self._event_queue.get(timeout=0.1)
# Route to the event handler
self._event_handler.dispatch(event)
if self._should_check_commands(event):
self._execution_coordinator.check_commands()
self._event_queue.task_done()
except queue.Empty:
# Process commands even when no new events arrive so abort requests are not missed
if self._execution_coordinator.is_execution_complete():
should_check_commands = True
should_break = True
else:
# Check for scaling
self._execution_coordinator.check_scaling()
# Process events
try:
event = self._event_queue.get(timeout=0.1)
# Route to the event handler
self._event_handler.dispatch(event)
should_check_commands = self._should_check_commands(event)
self._event_queue.task_done()
except queue.Empty:
# Process commands even when no new events arrive so abort requests are not missed
should_check_commands = True
time.sleep(0.1)
if should_check_commands and not commands_checked:
self._execution_coordinator.check_commands()
# Check if execution is complete
if self._execution_coordinator.is_execution_complete():
break
commands_checked = True
if should_break:
if not commands_checked:
self._execution_coordinator.check_commands()
break
except Exception as e:
logger.exception("Dispatcher error")

View File

@@ -2,17 +2,13 @@
Execution coordinator for managing overall workflow execution.
"""
from typing import TYPE_CHECKING, final
from typing import final
from ..command_processing import CommandProcessor
from ..domain import GraphExecution
from ..event_management import EventManager
from ..graph_state_manager import GraphStateManager
from ..worker_management import WorkerPool
if TYPE_CHECKING:
from ..event_management import EventHandler
@final
class ExecutionCoordinator:
@@ -27,8 +23,6 @@ class ExecutionCoordinator:
self,
graph_execution: GraphExecution,
state_manager: GraphStateManager,
event_handler: "EventHandler",
event_collector: EventManager,
command_processor: CommandProcessor,
worker_pool: WorkerPool,
) -> None:
@@ -38,15 +32,11 @@ class ExecutionCoordinator:
Args:
graph_execution: Graph execution aggregate
state_manager: Unified state manager
event_handler: Event handler registry for processing events
event_collector: Event manager for collecting events
command_processor: Processor for commands
worker_pool: Pool of workers
"""
self._graph_execution = graph_execution
self._state_manager = state_manager
self._event_handler = event_handler
self._event_collector = event_collector
self._command_processor = command_processor
self._worker_pool = worker_pool
@@ -65,15 +55,24 @@ class ExecutionCoordinator:
Returns:
True if execution is complete
"""
# Check if aborted or failed
# Treat paused, aborted, or failed executions as terminal states
if self._graph_execution.is_paused:
return True
if self._graph_execution.aborted or self._graph_execution.has_error:
return True
# Complete if no work remains
return self._state_manager.is_execution_complete()
@property
def is_paused(self) -> bool:
"""Expose whether the underlying graph execution is paused."""
return self._graph_execution.is_paused
def mark_complete(self) -> None:
"""Mark execution as complete."""
if self._graph_execution.is_paused:
return
if not self._graph_execution.completed:
self._graph_execution.complete()
@@ -85,3 +84,21 @@ class ExecutionCoordinator:
error: The error that caused failure
"""
self._graph_execution.fail(error)
def handle_pause_if_needed(self) -> None:
"""If the execution has been paused, stop workers immediately."""
if not self._graph_execution.is_paused:
return
self._worker_pool.stop()
self._state_manager.clear_executing()
def handle_abort_if_needed(self) -> None:
"""If the execution has been aborted, stop workers immediately."""
if not self._graph_execution.aborted:
return
self._worker_pool.stop()
self._state_manager.clear_executing()

View File

@@ -14,11 +14,11 @@ from uuid import uuid4
from pydantic import BaseModel, Field
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import NodeExecutionType, NodeState
from core.workflow.graph import Graph
from core.workflow.graph_events import NodeRunStreamChunkEvent, NodeRunSucceededEvent
from core.workflow.nodes.base.template import TextSegment, VariableSegment
from core.workflow.runtime import VariablePool
from .path import Path
from .session import ResponseSession

View File

@@ -13,6 +13,7 @@ from .graph import (
GraphRunAbortedEvent,
GraphRunFailedEvent,
GraphRunPartialSucceededEvent,
GraphRunPausedEvent,
GraphRunStartedEvent,
GraphRunSucceededEvent,
)
@@ -37,6 +38,7 @@ from .loop import (
from .node import (
NodeRunExceptionEvent,
NodeRunFailedEvent,
NodeRunPauseRequestedEvent,
NodeRunRetrieverResourceEvent,
NodeRunRetryEvent,
NodeRunStartedEvent,
@@ -51,6 +53,7 @@ __all__ = [
"GraphRunAbortedEvent",
"GraphRunFailedEvent",
"GraphRunPartialSucceededEvent",
"GraphRunPausedEvent",
"GraphRunStartedEvent",
"GraphRunSucceededEvent",
"NodeRunAgentLogEvent",
@@ -64,6 +67,7 @@ __all__ = [
"NodeRunLoopNextEvent",
"NodeRunLoopStartedEvent",
"NodeRunLoopSucceededEvent",
"NodeRunPauseRequestedEvent",
"NodeRunRetrieverResourceEvent",
"NodeRunRetryEvent",
"NodeRunStartedEvent",

View File

@@ -8,7 +8,12 @@ class GraphRunStartedEvent(BaseGraphEvent):
class GraphRunSucceededEvent(BaseGraphEvent):
outputs: dict[str, object] = Field(default_factory=dict)
"""Event emitted when a run completes successfully with final outputs."""
outputs: dict[str, object] = Field(
default_factory=dict,
description="Final workflow outputs keyed by output selector.",
)
class GraphRunFailedEvent(BaseGraphEvent):
@@ -17,12 +22,30 @@ class GraphRunFailedEvent(BaseGraphEvent):
class GraphRunPartialSucceededEvent(BaseGraphEvent):
"""Event emitted when a run finishes with partial success and failures."""
exceptions_count: int = Field(..., description="exception count")
outputs: dict[str, object] = Field(default_factory=dict)
outputs: dict[str, object] = Field(
default_factory=dict,
description="Outputs that were materialised before failures occurred.",
)
class GraphRunAbortedEvent(BaseGraphEvent):
"""Event emitted when a graph run is aborted by user command."""
reason: str | None = Field(default=None, description="reason for abort")
outputs: dict[str, object] = Field(default_factory=dict, description="partial outputs if any")
outputs: dict[str, object] = Field(
default_factory=dict,
description="Outputs produced before the abort was requested.",
)
class GraphRunPausedEvent(BaseGraphEvent):
"""Event emitted when a graph run is paused by user command."""
reason: str | None = Field(default=None, description="reason for pause")
outputs: dict[str, object] = Field(
default_factory=dict,
description="Outputs available to the client while the run is paused.",
)

View File

@@ -51,3 +51,7 @@ class NodeRunExceptionEvent(GraphNodeEventBase):
class NodeRunRetryEvent(NodeRunStartedEvent):
error: str = Field(..., description="error")
retry_index: int = Field(..., description="which retry attempt is about to be performed")
class NodeRunPauseRequestedEvent(GraphNodeEventBase):
reason: str | None = Field(default=None, description="Optional pause reason")

View File

@@ -14,6 +14,7 @@ from .loop import (
)
from .node import (
ModelInvokeCompletedEvent,
PauseRequestedEvent,
RunRetrieverResourceEvent,
RunRetryEvent,
StreamChunkEvent,
@@ -33,6 +34,7 @@ __all__ = [
"ModelInvokeCompletedEvent",
"NodeEventBase",
"NodeRunResult",
"PauseRequestedEvent",
"RunRetrieverResourceEvent",
"RunRetryEvent",
"StreamChunkEvent",

View File

@@ -40,3 +40,7 @@ class StreamChunkEvent(NodeEventBase):
class StreamCompletedEvent(NodeEventBase):
node_run_result: NodeRunResult = Field(..., description="run result")
class PauseRequestedEvent(NodeEventBase):
reason: str | None = Field(default=None, description="Optional pause reason")

View File

@@ -25,7 +25,6 @@ from core.tools.entities.tool_entities import (
from core.tools.tool_manager import ToolManager
from core.tools.utils.message_transformer import ToolFileMessageTransformer
from core.variables.segments import ArrayFileSegment, StringSegment
from core.workflow.entities import VariablePool
from core.workflow.enums import (
ErrorStrategy,
NodeType,
@@ -44,6 +43,7 @@ from core.workflow.nodes.agent.entities import AgentNodeData, AgentOldVersionMod
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
from core.workflow.runtime import VariablePool
from extensions.ext_database import db
from factories import file_factory
from factories.agent_factory import get_plugin_agent_strategy

View File

@@ -6,7 +6,7 @@ from typing import Any, ClassVar
from uuid import uuid4
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities import AgentNodeStrategyInit, GraphInitParams, GraphRuntimeState
from core.workflow.entities import AgentNodeStrategyInit, GraphInitParams
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeState, NodeType, WorkflowNodeExecutionStatus
from core.workflow.graph_events import (
GraphNodeEventBase,
@@ -20,6 +20,7 @@ from core.workflow.graph_events import (
NodeRunLoopNextEvent,
NodeRunLoopStartedEvent,
NodeRunLoopSucceededEvent,
NodeRunPauseRequestedEvent,
NodeRunRetrieverResourceEvent,
NodeRunStartedEvent,
NodeRunStreamChunkEvent,
@@ -37,10 +38,12 @@ from core.workflow.node_events import (
LoopSucceededEvent,
NodeEventBase,
NodeRunResult,
PauseRequestedEvent,
RunRetrieverResourceEvent,
StreamChunkEvent,
StreamCompletedEvent,
)
from core.workflow.runtime import GraphRuntimeState
from libs.datetime_utils import naive_utc_now
from models.enums import UserFrom
@@ -385,6 +388,16 @@ class Node:
f"Node {self._node_id} does not support status {event.node_run_result.status}"
)
@_dispatch.register
def _(self, event: PauseRequestedEvent) -> NodeRunPauseRequestedEvent:
return NodeRunPauseRequestedEvent(
id=self._node_execution_id,
node_id=self._node_id,
node_type=self.node_type,
node_run_result=NodeRunResult(status=WorkflowNodeExecutionStatus.PAUSED),
reason=event.reason,
)
@_dispatch.register
def _(self, event: AgentLogEvent) -> NodeRunAgentLogEvent:
return NodeRunAgentLogEvent(

View File

@@ -19,7 +19,6 @@ from core.file.enums import FileTransferMethod, FileType
from core.plugin.impl.exc import PluginDaemonClientSideError
from core.variables.segments import ArrayAnySegment
from core.variables.variables import ArrayAnyVariable
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType, SystemVariableKey
from core.workflow.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent
@@ -27,6 +26,7 @@ from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
from core.workflow.nodes.tool.exc import ToolFileError
from core.workflow.runtime import VariablePool
from extensions.ext_database import db
from factories import file_factory
from models.model import UploadFile

View File

@@ -15,7 +15,7 @@ from core.file import file_manager
from core.file.enums import FileTransferMethod
from core.helper import ssrf_proxy
from core.variables.segments import ArrayFileSegment, FileSegment
from core.workflow.entities import VariablePool
from core.workflow.runtime import VariablePool
from .entities import (
HttpRequestNodeAuthorization,

View File

@@ -0,0 +1,3 @@
from .human_input_node import HumanInputNode
__all__ = ["HumanInputNode"]

View File

@@ -0,0 +1,10 @@
from pydantic import Field
from core.workflow.nodes.base import BaseNodeData
class HumanInputNodeData(BaseNodeData):
"""Configuration schema for the HumanInput node."""
required_variables: list[str] = Field(default_factory=list)
pause_reason: str | None = Field(default=None)

View File

@@ -0,0 +1,132 @@
from collections.abc import Mapping
from typing import Any
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult, PauseRequestedEvent
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from .entities import HumanInputNodeData
class HumanInputNode(Node):
node_type = NodeType.HUMAN_INPUT
execution_type = NodeExecutionType.BRANCH
_BRANCH_SELECTION_KEYS: tuple[str, ...] = (
"edge_source_handle",
"edgeSourceHandle",
"source_handle",
"selected_branch",
"selectedBranch",
"branch",
"branch_id",
"branchId",
"handle",
)
_node_data: HumanInputNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None:
self._node_data = HumanInputNodeData(**data)
def get_base_node_data(self) -> BaseNodeData:
return self._node_data
@classmethod
def version(cls) -> str:
return "1"
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
return self._node_data.retry_config
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
return self._node_data.default_value_dict
def _run(self): # type: ignore[override]
if self._is_completion_ready():
branch_handle = self._resolve_branch_selection()
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs={},
edge_source_handle=branch_handle or "source",
)
return self._pause_generator()
def _pause_generator(self):
yield PauseRequestedEvent(reason=self._node_data.pause_reason)
def _is_completion_ready(self) -> bool:
"""Determine whether all required inputs are satisfied."""
if not self._node_data.required_variables:
return False
variable_pool = self.graph_runtime_state.variable_pool
for selector_str in self._node_data.required_variables:
parts = selector_str.split(".")
if len(parts) != 2:
return False
segment = variable_pool.get(parts)
if segment is None:
return False
return True
def _resolve_branch_selection(self) -> str | None:
"""Determine the branch handle selected by human input if available."""
variable_pool = self.graph_runtime_state.variable_pool
for key in self._BRANCH_SELECTION_KEYS:
handle = self._extract_branch_handle(variable_pool.get((self.id, key)))
if handle:
return handle
default_values = self._node_data.default_value_dict
for key in self._BRANCH_SELECTION_KEYS:
handle = self._normalize_branch_value(default_values.get(key))
if handle:
return handle
return None
@staticmethod
def _extract_branch_handle(segment: Any) -> str | None:
if segment is None:
return None
candidate = getattr(segment, "to_object", None)
raw_value = candidate() if callable(candidate) else getattr(segment, "value", None)
if raw_value is None:
return None
return HumanInputNode._normalize_branch_value(raw_value)
@staticmethod
def _normalize_branch_value(value: Any) -> str | None:
if value is None:
return None
if isinstance(value, str):
stripped = value.strip()
return stripped or None
if isinstance(value, Mapping):
for key in ("handle", "edge_source_handle", "edgeSourceHandle", "branch", "id", "value"):
candidate = value.get(key)
if isinstance(candidate, str) and candidate:
return candidate
return None

View File

@@ -3,12 +3,12 @@ from typing import Any, Literal
from typing_extensions import deprecated
from core.workflow.entities import VariablePool
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.if_else.entities import IfElseNodeData
from core.workflow.runtime import VariablePool
from core.workflow.utils.condition.entities import Condition
from core.workflow.utils.condition.processor import ConditionProcessor

View File

@@ -12,7 +12,6 @@ from core.variables import IntegerVariable, NoneSegment
from core.variables.segments import ArrayAnySegment, ArraySegment
from core.variables.variables import VariableUnion
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
from core.workflow.entities import VariablePool
from core.workflow.enums import (
ErrorStrategy,
NodeExecutionType,
@@ -38,6 +37,7 @@ from core.workflow.node_events import (
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.iteration.entities import ErrorHandleMode, IterationNodeData
from core.workflow.runtime import VariablePool
from libs.datetime_utils import naive_utc_now
from libs.flask_utils import preserve_flask_contexts
@@ -557,11 +557,12 @@ class IterationNode(Node):
def _create_graph_engine(self, index: int, item: object):
# Import dependencies
from core.workflow.entities import GraphInitParams, GraphRuntimeState
from core.workflow.entities import GraphInitParams
from core.workflow.graph import Graph
from core.workflow.graph_engine import GraphEngine
from core.workflow.graph_engine.command_channels import InMemoryChannel
from core.workflow.nodes.node_factory import DifyNodeFactory
from core.workflow.runtime import GraphRuntimeState
# Create GraphInitParams from node attributes
graph_init_params = GraphInitParams(

View File

@@ -9,13 +9,13 @@ from sqlalchemy import func, select
from core.app.entities.app_invoke_entities import InvokeFrom
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType, SystemVariableKey
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.base.template import Template
from core.workflow.runtime import VariablePool
from extensions.ext_database import db
from models.dataset import Dataset, Document, DocumentSegment

View File

@@ -67,7 +67,7 @@ from .exc import (
if TYPE_CHECKING:
from core.file.models import File
from core.workflow.entities import GraphRuntimeState
from core.workflow.runtime import GraphRuntimeState
logger = logging.getLogger(__name__)

View File

@@ -15,9 +15,9 @@ from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.variables.segments import ArrayAnySegment, ArrayFileSegment, FileSegment, NoneSegment, StringSegment
from core.workflow.entities import VariablePool
from core.workflow.enums import SystemVariableKey
from core.workflow.nodes.llm.entities import ModelConfig
from core.workflow.runtime import VariablePool
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models.model import Conversation

View File

@@ -52,7 +52,7 @@ from core.variables import (
StringSegment,
)
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
from core.workflow.entities import GraphInitParams, VariablePool
from core.workflow.entities import GraphInitParams
from core.workflow.enums import (
ErrorStrategy,
NodeType,
@@ -71,6 +71,7 @@ from core.workflow.node_events import (
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig, VariableSelector
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
from core.workflow.runtime import VariablePool
from . import llm_utils
from .entities import (
@@ -93,7 +94,7 @@ from .file_saver import FileSaverImpl, LLMFileSaver
if TYPE_CHECKING:
from core.file.models import File
from core.workflow.entities import GraphRuntimeState
from core.workflow.runtime import GraphRuntimeState
logger = logging.getLogger(__name__)

View File

@@ -406,11 +406,12 @@ class LoopNode(Node):
def _create_graph_engine(self, start_at: datetime, root_node_id: str):
# Import dependencies
from core.workflow.entities import GraphInitParams, GraphRuntimeState
from core.workflow.entities import GraphInitParams
from core.workflow.graph import Graph
from core.workflow.graph_engine import GraphEngine
from core.workflow.graph_engine.command_channels import InMemoryChannel
from core.workflow.nodes.node_factory import DifyNodeFactory
from core.workflow.runtime import GraphRuntimeState
# Create GraphInitParams from node attributes
graph_init_params = GraphInitParams(

View File

@@ -10,7 +10,8 @@ from libs.typing import is_str, is_str_dict
from .node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
if TYPE_CHECKING:
from core.workflow.entities import GraphInitParams, GraphRuntimeState
from core.workflow.entities import GraphInitParams
from core.workflow.runtime import GraphRuntimeState
@final

View File

@@ -9,6 +9,7 @@ from core.workflow.nodes.datasource.datasource_node import DatasourceNode
from core.workflow.nodes.document_extractor import DocumentExtractorNode
from core.workflow.nodes.end.end_node import EndNode
from core.workflow.nodes.http_request import HttpRequestNode
from core.workflow.nodes.human_input import HumanInputNode
from core.workflow.nodes.if_else import IfElseNode
from core.workflow.nodes.iteration import IterationNode, IterationStartNode
from core.workflow.nodes.knowledge_index import KnowledgeIndexNode
@@ -134,6 +135,10 @@ NODE_TYPE_CLASSES_MAPPING: Mapping[NodeType, Mapping[str, type[Node]]] = {
"2": AgentNode,
"1": AgentNode,
},
NodeType.HUMAN_INPUT: {
LATEST_VERSION: HumanInputNode,
"1": HumanInputNode,
},
NodeType.DATASOURCE: {
LATEST_VERSION: DatasourceNode,
"1": DatasourceNode,

View File

@@ -27,13 +27,13 @@ from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, Comp
from core.prompt.simple_prompt_transform import ModelMode
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.variables.types import ArrayValidation, SegmentType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base import variable_template_parser
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.llm import ModelConfig, llm_utils
from core.workflow.runtime import VariablePool
from factories.variable_factory import build_segment_with_type
from .entities import ParameterExtractorNodeData

View File

@@ -41,7 +41,7 @@ from .template_prompts import (
if TYPE_CHECKING:
from core.file.models import File
from core.workflow.entities import GraphRuntimeState
from core.workflow.runtime import GraphRuntimeState
class QuestionClassifierNode(Node):

View File

@@ -36,7 +36,7 @@ from .exc import (
)
if TYPE_CHECKING:
from core.workflow.entities import VariablePool
from core.workflow.runtime import VariablePool
class ToolNode(Node):

View File

@@ -18,7 +18,7 @@ from ..common.impl import conversation_variable_updater_factory
from .node_data import VariableAssignerData, WriteMode
if TYPE_CHECKING:
from core.workflow.entities import GraphRuntimeState
from core.workflow.runtime import GraphRuntimeState
_CONV_VAR_UPDATER_FACTORY: TypeAlias = Callable[[], ConversationVariableUpdater]

View File

@@ -0,0 +1,14 @@
from .graph_runtime_state import GraphRuntimeState
from .graph_runtime_state_protocol import ReadOnlyGraphRuntimeState, ReadOnlyVariablePool
from .read_only_wrappers import ReadOnlyGraphRuntimeStateWrapper, ReadOnlyVariablePoolWrapper
from .variable_pool import VariablePool, VariableValue
__all__ = [
"GraphRuntimeState",
"ReadOnlyGraphRuntimeState",
"ReadOnlyGraphRuntimeStateWrapper",
"ReadOnlyVariablePool",
"ReadOnlyVariablePoolWrapper",
"VariablePool",
"VariableValue",
]

View File

@@ -0,0 +1,393 @@
from __future__ import annotations
import importlib
import json
from collections.abc import Mapping, Sequence
from collections.abc import Mapping as TypingMapping
from copy import deepcopy
from typing import Any, Protocol
from pydantic.json import pydantic_encoder
from core.model_runtime.entities.llm_entities import LLMUsage
from core.workflow.runtime.variable_pool import VariablePool
class ReadyQueueProtocol(Protocol):
"""Structural interface required from ready queue implementations."""
def put(self, item: str) -> None:
"""Enqueue the identifier of a node that is ready to run."""
...
def get(self, timeout: float | None = None) -> str:
"""Return the next node identifier, blocking until available or timeout expires."""
...
def task_done(self) -> None:
"""Signal that the most recently dequeued node has completed processing."""
...
def empty(self) -> bool:
"""Return True when the queue contains no pending nodes."""
...
def qsize(self) -> int:
"""Approximate the number of pending nodes awaiting execution."""
...
def dumps(self) -> str:
"""Serialize the queue contents for persistence."""
...
def loads(self, data: str) -> None:
"""Restore the queue contents from a serialized payload."""
...
class GraphExecutionProtocol(Protocol):
"""Structural interface for graph execution aggregate."""
workflow_id: str
started: bool
completed: bool
aborted: bool
error: Exception | None
exceptions_count: int
def start(self) -> None:
"""Transition execution into the running state."""
...
def complete(self) -> None:
"""Mark execution as successfully completed."""
...
def abort(self, reason: str) -> None:
"""Abort execution in response to an external stop request."""
...
def fail(self, error: Exception) -> None:
"""Record an unrecoverable error and end execution."""
...
def dumps(self) -> str:
"""Serialize execution state into a JSON payload."""
...
def loads(self, data: str) -> None:
"""Restore execution state from a previously serialized payload."""
...
class ResponseStreamCoordinatorProtocol(Protocol):
"""Structural interface for response stream coordinator."""
def register(self, response_node_id: str) -> None:
"""Register a response node so its outputs can be streamed."""
...
def loads(self, data: str) -> None:
"""Restore coordinator state from a serialized payload."""
...
def dumps(self) -> str:
"""Serialize coordinator state for persistence."""
...
class GraphProtocol(Protocol):
"""Structural interface required from graph instances attached to the runtime state."""
nodes: TypingMapping[str, object]
edges: TypingMapping[str, object]
root_node: object
def get_outgoing_edges(self, node_id: str) -> Sequence[object]: ...
class GraphRuntimeState:
"""Mutable runtime state shared across graph execution components."""
def __init__(
self,
*,
variable_pool: VariablePool,
start_at: float,
total_tokens: int = 0,
llm_usage: LLMUsage | None = None,
outputs: dict[str, object] | None = None,
node_run_steps: int = 0,
ready_queue: ReadyQueueProtocol | None = None,
graph_execution: GraphExecutionProtocol | None = None,
response_coordinator: ResponseStreamCoordinatorProtocol | None = None,
graph: GraphProtocol | None = None,
) -> None:
self._variable_pool = variable_pool
self._start_at = start_at
if total_tokens < 0:
raise ValueError("total_tokens must be non-negative")
self._total_tokens = total_tokens
self._llm_usage = (llm_usage or LLMUsage.empty_usage()).model_copy()
self._outputs = deepcopy(outputs) if outputs is not None else {}
if node_run_steps < 0:
raise ValueError("node_run_steps must be non-negative")
self._node_run_steps = node_run_steps
self._graph: GraphProtocol | None = None
self._ready_queue = ready_queue
self._graph_execution = graph_execution
self._response_coordinator = response_coordinator
self._pending_response_coordinator_dump: str | None = None
self._pending_graph_execution_workflow_id: str | None = None
self._paused_nodes: set[str] = set()
if graph is not None:
self.attach_graph(graph)
# ------------------------------------------------------------------
# Context binding helpers
# ------------------------------------------------------------------
def attach_graph(self, graph: GraphProtocol) -> None:
"""Attach the materialized graph to the runtime state."""
if self._graph is not None and self._graph is not graph:
raise ValueError("GraphRuntimeState already attached to a different graph instance")
self._graph = graph
if self._response_coordinator is None:
self._response_coordinator = self._build_response_coordinator(graph)
if self._pending_response_coordinator_dump is not None and self._response_coordinator is not None:
self._response_coordinator.loads(self._pending_response_coordinator_dump)
self._pending_response_coordinator_dump = None
def configure(self, *, graph: GraphProtocol | None = None) -> None:
"""Ensure core collaborators are initialized with the provided context."""
if graph is not None:
self.attach_graph(graph)
# Ensure collaborators are instantiated
_ = self.ready_queue
_ = self.graph_execution
if self._graph is not None:
_ = self.response_coordinator
# ------------------------------------------------------------------
# Primary collaborators
# ------------------------------------------------------------------
@property
def variable_pool(self) -> VariablePool:
return self._variable_pool
@property
def ready_queue(self) -> ReadyQueueProtocol:
if self._ready_queue is None:
self._ready_queue = self._build_ready_queue()
return self._ready_queue
@property
def graph_execution(self) -> GraphExecutionProtocol:
if self._graph_execution is None:
self._graph_execution = self._build_graph_execution()
return self._graph_execution
@property
def response_coordinator(self) -> ResponseStreamCoordinatorProtocol:
if self._response_coordinator is None:
if self._graph is None:
raise ValueError("Graph must be attached before accessing response coordinator")
self._response_coordinator = self._build_response_coordinator(self._graph)
return self._response_coordinator
# ------------------------------------------------------------------
# Scalar state
# ------------------------------------------------------------------
@property
def start_at(self) -> float:
return self._start_at
@start_at.setter
def start_at(self, value: float) -> None:
self._start_at = value
@property
def total_tokens(self) -> int:
return self._total_tokens
@total_tokens.setter
def total_tokens(self, value: int) -> None:
if value < 0:
raise ValueError("total_tokens must be non-negative")
self._total_tokens = value
@property
def llm_usage(self) -> LLMUsage:
return self._llm_usage.model_copy()
@llm_usage.setter
def llm_usage(self, value: LLMUsage) -> None:
self._llm_usage = value.model_copy()
@property
def outputs(self) -> dict[str, Any]:
return deepcopy(self._outputs)
@outputs.setter
def outputs(self, value: dict[str, Any]) -> None:
self._outputs = deepcopy(value)
def set_output(self, key: str, value: object) -> None:
self._outputs[key] = deepcopy(value)
def get_output(self, key: str, default: object = None) -> object:
return deepcopy(self._outputs.get(key, default))
def update_outputs(self, updates: dict[str, object]) -> None:
for key, value in updates.items():
self._outputs[key] = deepcopy(value)
@property
def node_run_steps(self) -> int:
return self._node_run_steps
@node_run_steps.setter
def node_run_steps(self, value: int) -> None:
if value < 0:
raise ValueError("node_run_steps must be non-negative")
self._node_run_steps = value
def increment_node_run_steps(self) -> None:
self._node_run_steps += 1
def add_tokens(self, tokens: int) -> None:
if tokens < 0:
raise ValueError("tokens must be non-negative")
self._total_tokens += tokens
# ------------------------------------------------------------------
# Serialization
# ------------------------------------------------------------------
def dumps(self) -> str:
"""Serialize runtime state into a JSON string."""
snapshot: dict[str, Any] = {
"version": "1.0",
"start_at": self._start_at,
"total_tokens": self._total_tokens,
"node_run_steps": self._node_run_steps,
"llm_usage": self._llm_usage.model_dump(mode="json"),
"outputs": self.outputs,
"variable_pool": self.variable_pool.model_dump(mode="json"),
"ready_queue": self.ready_queue.dumps(),
"graph_execution": self.graph_execution.dumps(),
"paused_nodes": list(self._paused_nodes),
}
if self._response_coordinator is not None and self._graph is not None:
snapshot["response_coordinator"] = self._response_coordinator.dumps()
return json.dumps(snapshot, default=pydantic_encoder)
def loads(self, data: str | Mapping[str, Any]) -> None:
"""Restore runtime state from a serialized snapshot."""
payload: dict[str, Any]
if isinstance(data, str):
payload = json.loads(data)
else:
payload = dict(data)
version = payload.get("version")
if version != "1.0":
raise ValueError(f"Unsupported GraphRuntimeState snapshot version: {version}")
self._start_at = float(payload.get("start_at", 0.0))
total_tokens = int(payload.get("total_tokens", 0))
if total_tokens < 0:
raise ValueError("total_tokens must be non-negative")
self._total_tokens = total_tokens
node_run_steps = int(payload.get("node_run_steps", 0))
if node_run_steps < 0:
raise ValueError("node_run_steps must be non-negative")
self._node_run_steps = node_run_steps
llm_usage_payload = payload.get("llm_usage", {})
self._llm_usage = LLMUsage.model_validate(llm_usage_payload)
self._outputs = deepcopy(payload.get("outputs", {}))
variable_pool_payload = payload.get("variable_pool")
if variable_pool_payload is not None:
self._variable_pool = VariablePool.model_validate(variable_pool_payload)
ready_queue_payload = payload.get("ready_queue")
if ready_queue_payload is not None:
self._ready_queue = self._build_ready_queue()
self._ready_queue.loads(ready_queue_payload)
else:
self._ready_queue = None
graph_execution_payload = payload.get("graph_execution")
self._graph_execution = None
self._pending_graph_execution_workflow_id = None
if graph_execution_payload is not None:
try:
execution_payload = json.loads(graph_execution_payload)
self._pending_graph_execution_workflow_id = execution_payload.get("workflow_id")
except (json.JSONDecodeError, TypeError, AttributeError):
self._pending_graph_execution_workflow_id = None
self.graph_execution.loads(graph_execution_payload)
response_payload = payload.get("response_coordinator")
if response_payload is not None:
if self._graph is not None:
self.response_coordinator.loads(response_payload)
else:
self._pending_response_coordinator_dump = response_payload
else:
self._pending_response_coordinator_dump = None
self._response_coordinator = None
paused_nodes_payload = payload.get("paused_nodes", [])
self._paused_nodes = set(map(str, paused_nodes_payload))
def register_paused_node(self, node_id: str) -> None:
"""Record a node that should resume when execution is continued."""
self._paused_nodes.add(node_id)
def consume_paused_nodes(self) -> list[str]:
"""Retrieve and clear the list of paused nodes awaiting resume."""
nodes = list(self._paused_nodes)
self._paused_nodes.clear()
return nodes
# ------------------------------------------------------------------
# Builders
# ------------------------------------------------------------------
def _build_ready_queue(self) -> ReadyQueueProtocol:
# Import lazily to avoid breaching architecture boundaries enforced by import-linter.
module = importlib.import_module("core.workflow.graph_engine.ready_queue")
in_memory_cls = module.InMemoryReadyQueue
return in_memory_cls()
def _build_graph_execution(self) -> GraphExecutionProtocol:
# Lazily import to keep the runtime domain decoupled from graph_engine modules.
module = importlib.import_module("core.workflow.graph_engine.domain.graph_execution")
graph_execution_cls = module.GraphExecution
workflow_id = self._pending_graph_execution_workflow_id or ""
self._pending_graph_execution_workflow_id = None
return graph_execution_cls(workflow_id=workflow_id)
def _build_response_coordinator(self, graph: GraphProtocol) -> ResponseStreamCoordinatorProtocol:
# Lazily import to keep the runtime domain decoupled from graph_engine modules.
module = importlib.import_module("core.workflow.graph_engine.response_coordinator")
coordinator_cls = module.ResponseStreamCoordinator
return coordinator_cls(variable_pool=self.variable_pool, graph=graph)

View File

@@ -16,6 +16,10 @@ class ReadOnlyVariablePool(Protocol):
"""Get all variables for a node (read-only)."""
...
def get_by_prefix(self, prefix: str) -> Mapping[str, object]:
"""Get all variables stored under a given node prefix (read-only)."""
...
class ReadOnlyGraphRuntimeState(Protocol):
"""
@@ -56,6 +60,20 @@ class ReadOnlyGraphRuntimeState(Protocol):
"""Get the node run steps count (read-only)."""
...
@property
def ready_queue_size(self) -> int:
"""Get the number of nodes currently in the ready queue."""
...
@property
def exceptions_count(self) -> int:
"""Get the number of node execution exceptions recorded."""
...
def get_output(self, key: str, default: Any = None) -> Any:
"""Get a single output value (returns a copy)."""
...
def dumps(self) -> str:
"""Serialize the runtime state into a JSON snapshot (read-only)."""
...

View File

@@ -1,77 +1,82 @@
from __future__ import annotations
from collections.abc import Mapping
from copy import deepcopy
from typing import Any
from core.model_runtime.entities.llm_entities import LLMUsage
from core.variables.segments import Segment
from core.workflow.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.entities.variable_pool import VariablePool
from .graph_runtime_state import GraphRuntimeState
from .variable_pool import VariablePool
class ReadOnlyVariablePoolWrapper:
"""Wrapper that provides read-only access to VariablePool."""
"""Provide defensive, read-only access to ``VariablePool``."""
def __init__(self, variable_pool: VariablePool):
def __init__(self, variable_pool: VariablePool) -> None:
self._variable_pool = variable_pool
def get(self, node_id: str, variable_key: str) -> Segment | None:
"""Get a variable value (returns a defensive copy)."""
"""Return a copy of a variable value if present."""
value = self._variable_pool.get([node_id, variable_key])
return deepcopy(value) if value is not None else None
def get_all_by_node(self, node_id: str) -> Mapping[str, object]:
"""Get all variables for a node (returns defensive copies)."""
"""Return a copy of all variables for the specified node."""
variables: dict[str, object] = {}
if node_id in self._variable_pool.variable_dictionary:
for key, var in self._variable_pool.variable_dictionary[node_id].items():
# Variables have a value property that contains the actual data
variables[key] = deepcopy(var.value)
for key, variable in self._variable_pool.variable_dictionary[node_id].items():
variables[key] = deepcopy(variable.value)
return variables
def get_by_prefix(self, prefix: str) -> Mapping[str, object]:
"""Return a copy of all variables stored under the given prefix."""
return self._variable_pool.get_by_prefix(prefix)
class ReadOnlyGraphRuntimeStateWrapper:
"""
Wrapper that provides read-only access to GraphRuntimeState.
"""Expose a defensive, read-only view of ``GraphRuntimeState``."""
This wrapper ensures that layers can observe the state without
modifying it. All returned values are defensive copies.
"""
def __init__(self, state: GraphRuntimeState):
def __init__(self, state: GraphRuntimeState) -> None:
self._state = state
self._variable_pool_wrapper = ReadOnlyVariablePoolWrapper(state.variable_pool)
@property
def variable_pool(self) -> ReadOnlyVariablePoolWrapper:
"""Get read-only access to the variable pool."""
return self._variable_pool_wrapper
@property
def start_at(self) -> float:
"""Get the start time (read-only)."""
return self._state.start_at
@property
def total_tokens(self) -> int:
"""Get the total tokens count (read-only)."""
return self._state.total_tokens
@property
def llm_usage(self) -> LLMUsage:
"""Get a copy of LLM usage info (read-only)."""
# Return a copy to prevent modification
return self._state.llm_usage.model_copy()
@property
def outputs(self) -> dict[str, Any]:
"""Get a defensive copy of outputs (read-only)."""
return deepcopy(self._state.outputs)
@property
def node_run_steps(self) -> int:
"""Get the node run steps count (read-only)."""
return self._state.node_run_steps
@property
def ready_queue_size(self) -> int:
return self._state.ready_queue.qsize()
@property
def exceptions_count(self) -> int:
return self._state.graph_execution.exceptions_count
def get_output(self, key: str, default: Any = None) -> Any:
"""Get a single output value (returns a copy)."""
return self._state.get_output(key, default)
def dumps(self) -> str:
"""Serialize the underlying runtime state for external persistence."""
return self._state.dumps()

View File

@@ -1,6 +1,7 @@
import re
from collections import defaultdict
from collections.abc import Mapping, Sequence
from copy import deepcopy
from typing import Annotated, Any, Union, cast
from pydantic import BaseModel, Field
@@ -235,6 +236,20 @@ class VariablePool(BaseModel):
return segment
return None
def get_by_prefix(self, prefix: str, /) -> Mapping[str, object]:
"""Return a copy of all variables stored under the given node prefix."""
nodes = self.variable_dictionary.get(prefix)
if not nodes:
return {}
result: dict[str, object] = {}
for key, variable in nodes.items():
value = variable.value
result[key] = deepcopy(value)
return result
def _add_system_variables(self, system_variable: SystemVariable):
sys_var_mapping = system_variable.to_dict()
for key, value in sys_var_mapping.items():

View File

@@ -5,7 +5,7 @@ from typing import Literal, NamedTuple
from core.file import FileAttribute, file_manager
from core.variables import ArrayFileSegment
from core.variables.segments import ArrayBooleanSegment, BooleanSegment
from core.workflow.entities import VariablePool
from core.workflow.runtime import VariablePool
from .entities import Condition, SubCondition, SupportedComparisonOperator

View File

@@ -4,7 +4,7 @@ from typing import Any, Protocol
from core.variables import Variable
from core.variables.consts import SELECTORS_LENGTH
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.runtime import VariablePool
class VariableLoader(Protocol):

View File

@@ -1,459 +0,0 @@
from collections.abc import Mapping
from dataclasses import dataclass
from datetime import datetime
from typing import Any, Union
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity
from core.app.entities.queue_entities import (
QueueNodeExceptionEvent,
QueueNodeFailedEvent,
QueueNodeRetryEvent,
QueueNodeStartedEvent,
QueueNodeSucceededEvent,
)
from core.app.task_pipeline.exc import WorkflowRunNotFoundError
from core.ops.entities.trace_entity import TraceTaskName
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
from core.workflow.entities import (
WorkflowExecution,
WorkflowNodeExecution,
)
from core.workflow.enums import (
SystemVariableKey,
WorkflowExecutionStatus,
WorkflowNodeExecutionMetadataKey,
WorkflowNodeExecutionStatus,
WorkflowType,
)
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from core.workflow.system_variable import SystemVariable
from core.workflow.workflow_entry import WorkflowEntry
from libs.datetime_utils import naive_utc_now
from libs.uuid_utils import uuidv7
@dataclass
class CycleManagerWorkflowInfo:
workflow_id: str
workflow_type: WorkflowType
version: str
graph_data: Mapping[str, Any]
class WorkflowCycleManager:
def __init__(
self,
*,
application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity],
workflow_system_variables: SystemVariable,
workflow_info: CycleManagerWorkflowInfo,
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
):
self._application_generate_entity = application_generate_entity
self._workflow_system_variables = workflow_system_variables
self._workflow_info = workflow_info
self._workflow_execution_repository = workflow_execution_repository
self._workflow_node_execution_repository = workflow_node_execution_repository
# Initialize caches for workflow execution cycle
# These caches avoid redundant repository calls during a single workflow execution
self._workflow_execution_cache: dict[str, WorkflowExecution] = {}
self._node_execution_cache: dict[str, WorkflowNodeExecution] = {}
def handle_workflow_run_start(self) -> WorkflowExecution:
inputs = self._prepare_workflow_inputs()
execution_id = self._get_or_generate_execution_id()
execution = WorkflowExecution.new(
id_=execution_id,
workflow_id=self._workflow_info.workflow_id,
workflow_type=self._workflow_info.workflow_type,
workflow_version=self._workflow_info.version,
graph=self._workflow_info.graph_data,
inputs=inputs,
started_at=naive_utc_now(),
)
return self._save_and_cache_workflow_execution(execution)
def handle_workflow_run_success(
self,
*,
workflow_run_id: str,
total_tokens: int,
total_steps: int,
outputs: Mapping[str, Any] | None = None,
conversation_id: str | None = None,
trace_manager: TraceQueueManager | None = None,
external_trace_id: str | None = None,
) -> WorkflowExecution:
workflow_execution = self._get_workflow_execution_or_raise_error(workflow_run_id)
self._update_workflow_execution_completion(
workflow_execution,
status=WorkflowExecutionStatus.SUCCEEDED,
outputs=outputs,
total_tokens=total_tokens,
total_steps=total_steps,
)
self._add_trace_task_if_needed(trace_manager, workflow_execution, conversation_id, external_trace_id)
self._workflow_execution_repository.save(workflow_execution)
return workflow_execution
def handle_workflow_run_partial_success(
self,
*,
workflow_run_id: str,
total_tokens: int,
total_steps: int,
outputs: Mapping[str, Any] | None = None,
exceptions_count: int = 0,
conversation_id: str | None = None,
trace_manager: TraceQueueManager | None = None,
external_trace_id: str | None = None,
) -> WorkflowExecution:
execution = self._get_workflow_execution_or_raise_error(workflow_run_id)
self._update_workflow_execution_completion(
execution,
status=WorkflowExecutionStatus.PARTIAL_SUCCEEDED,
outputs=outputs,
total_tokens=total_tokens,
total_steps=total_steps,
exceptions_count=exceptions_count,
)
self._add_trace_task_if_needed(trace_manager, execution, conversation_id, external_trace_id)
self._workflow_execution_repository.save(execution)
return execution
def handle_workflow_run_failed(
self,
*,
workflow_run_id: str,
total_tokens: int,
total_steps: int,
status: WorkflowExecutionStatus,
error_message: str,
conversation_id: str | None = None,
trace_manager: TraceQueueManager | None = None,
exceptions_count: int = 0,
external_trace_id: str | None = None,
) -> WorkflowExecution:
workflow_execution = self._get_workflow_execution_or_raise_error(workflow_run_id)
now = naive_utc_now()
self._update_workflow_execution_completion(
workflow_execution,
status=status,
total_tokens=total_tokens,
total_steps=total_steps,
error_message=error_message,
exceptions_count=exceptions_count,
finished_at=now,
)
self._fail_running_node_executions(workflow_execution.id_, error_message, now)
self._add_trace_task_if_needed(trace_manager, workflow_execution, conversation_id, external_trace_id)
self._workflow_execution_repository.save(workflow_execution)
return workflow_execution
def handle_node_execution_start(
self,
*,
workflow_execution_id: str,
event: QueueNodeStartedEvent,
) -> WorkflowNodeExecution:
workflow_execution = self._get_workflow_execution_or_raise_error(workflow_execution_id)
domain_execution = self._create_node_execution_from_event(
workflow_execution=workflow_execution,
event=event,
status=WorkflowNodeExecutionStatus.RUNNING,
)
return self._save_and_cache_node_execution(domain_execution)
def handle_workflow_node_execution_success(self, *, event: QueueNodeSucceededEvent) -> WorkflowNodeExecution:
domain_execution = self._get_node_execution_from_cache(event.node_execution_id)
self._update_node_execution_completion(
domain_execution,
event=event,
status=WorkflowNodeExecutionStatus.SUCCEEDED,
)
self._workflow_node_execution_repository.save(domain_execution)
self._workflow_node_execution_repository.save_execution_data(domain_execution)
return domain_execution
def handle_workflow_node_execution_failed(
self,
*,
event: QueueNodeFailedEvent | QueueNodeExceptionEvent,
) -> WorkflowNodeExecution:
"""
Workflow node execution failed
:param event: queue node failed event
:return:
"""
domain_execution = self._get_node_execution_from_cache(event.node_execution_id)
status = (
WorkflowNodeExecutionStatus.EXCEPTION
if isinstance(event, QueueNodeExceptionEvent)
else WorkflowNodeExecutionStatus.FAILED
)
self._update_node_execution_completion(
domain_execution,
event=event,
status=status,
error=event.error,
handle_special_values=True,
)
self._workflow_node_execution_repository.save(domain_execution)
self._workflow_node_execution_repository.save_execution_data(domain_execution)
return domain_execution
def handle_workflow_node_execution_retried(
self, *, workflow_execution_id: str, event: QueueNodeRetryEvent
) -> WorkflowNodeExecution:
workflow_execution = self._get_workflow_execution_or_raise_error(workflow_execution_id)
domain_execution = self._create_node_execution_from_event(
workflow_execution=workflow_execution,
event=event,
status=WorkflowNodeExecutionStatus.RETRY,
error=event.error,
created_at=event.start_at,
)
# Handle inputs and outputs
inputs = WorkflowEntry.handle_special_values(event.inputs)
outputs = event.outputs
metadata = self._merge_event_metadata(event)
domain_execution.update_from_mapping(inputs=inputs, outputs=outputs, metadata=metadata)
execution = self._save_and_cache_node_execution(domain_execution)
self._workflow_node_execution_repository.save_execution_data(execution)
return execution
def _get_workflow_execution_or_raise_error(self, id: str, /) -> WorkflowExecution:
# Check cache first
if id in self._workflow_execution_cache:
return self._workflow_execution_cache[id]
raise WorkflowRunNotFoundError(id)
def _prepare_workflow_inputs(self) -> dict[str, Any]:
"""Prepare workflow inputs by merging application inputs with system variables."""
inputs = {**self._application_generate_entity.inputs}
if self._workflow_system_variables:
for field_name, value in self._workflow_system_variables.to_dict().items():
if field_name != SystemVariableKey.CONVERSATION_ID:
inputs[f"sys.{field_name}"] = value
return dict(WorkflowEntry.handle_special_values(inputs) or {})
def _get_or_generate_execution_id(self) -> str:
"""Get execution ID from system variables or generate a new one."""
if self._workflow_system_variables and self._workflow_system_variables.workflow_execution_id:
return str(self._workflow_system_variables.workflow_execution_id)
return str(uuidv7())
def _save_and_cache_workflow_execution(self, execution: WorkflowExecution) -> WorkflowExecution:
"""Save workflow execution to repository and cache it."""
self._workflow_execution_repository.save(execution)
self._workflow_execution_cache[execution.id_] = execution
return execution
def _save_and_cache_node_execution(self, execution: WorkflowNodeExecution) -> WorkflowNodeExecution:
"""Save node execution to repository and cache it if it has an ID.
This does not persist the `inputs` / `process_data` / `outputs` fields of the execution model.
"""
self._workflow_node_execution_repository.save(execution)
if execution.node_execution_id:
self._node_execution_cache[execution.node_execution_id] = execution
return execution
def _get_node_execution_from_cache(self, node_execution_id: str) -> WorkflowNodeExecution:
"""Get node execution from cache or raise error if not found."""
domain_execution = self._node_execution_cache.get(node_execution_id)
if not domain_execution:
raise ValueError(f"Domain node execution not found: {node_execution_id}")
return domain_execution
def _update_workflow_execution_completion(
self,
execution: WorkflowExecution,
*,
status: WorkflowExecutionStatus,
total_tokens: int,
total_steps: int,
outputs: Mapping[str, Any] | None = None,
error_message: str | None = None,
exceptions_count: int = 0,
finished_at: datetime | None = None,
):
"""Update workflow execution with completion data."""
execution.status = status
execution.outputs = outputs or {}
execution.total_tokens = total_tokens
execution.total_steps = total_steps
execution.finished_at = finished_at or naive_utc_now()
execution.exceptions_count = exceptions_count
if error_message:
execution.error_message = error_message
def _add_trace_task_if_needed(
self,
trace_manager: TraceQueueManager | None,
workflow_execution: WorkflowExecution,
conversation_id: str | None,
external_trace_id: str | None,
):
"""Add trace task if trace manager is provided."""
if trace_manager:
trace_manager.add_trace_task(
TraceTask(
TraceTaskName.WORKFLOW_TRACE,
workflow_execution=workflow_execution,
conversation_id=conversation_id,
user_id=trace_manager.user_id,
external_trace_id=external_trace_id,
)
)
def _fail_running_node_executions(
self,
workflow_execution_id: str,
error_message: str,
now: datetime,
):
"""Fail all running node executions for a workflow."""
running_node_executions = [
node_exec
for node_exec in self._node_execution_cache.values()
if node_exec.workflow_execution_id == workflow_execution_id
and node_exec.status == WorkflowNodeExecutionStatus.RUNNING
]
for node_execution in running_node_executions:
if node_execution.node_execution_id:
node_execution.status = WorkflowNodeExecutionStatus.FAILED
node_execution.error = error_message
node_execution.finished_at = now
node_execution.elapsed_time = (now - node_execution.created_at).total_seconds()
self._workflow_node_execution_repository.save(node_execution)
def _create_node_execution_from_event(
self,
*,
workflow_execution: WorkflowExecution,
event: QueueNodeStartedEvent,
status: WorkflowNodeExecutionStatus,
error: str | None = None,
created_at: datetime | None = None,
) -> WorkflowNodeExecution:
"""Create a node execution from an event."""
now = naive_utc_now()
created_at = created_at or now
metadata = {
WorkflowNodeExecutionMetadataKey.PARALLEL_MODE_RUN_ID: event.parallel_mode_run_id,
WorkflowNodeExecutionMetadataKey.ITERATION_ID: event.in_iteration_id,
WorkflowNodeExecutionMetadataKey.LOOP_ID: event.in_loop_id,
}
domain_execution = WorkflowNodeExecution(
id=event.node_execution_id,
workflow_id=workflow_execution.workflow_id,
workflow_execution_id=workflow_execution.id_,
predecessor_node_id=event.predecessor_node_id,
index=event.node_run_index,
node_execution_id=event.node_execution_id,
node_id=event.node_id,
node_type=event.node_type,
title=event.node_title,
status=status,
metadata=metadata,
created_at=created_at,
error=error,
)
if status == WorkflowNodeExecutionStatus.RETRY:
domain_execution.finished_at = now
domain_execution.elapsed_time = (now - created_at).total_seconds()
return domain_execution
def _update_node_execution_completion(
self,
domain_execution: WorkflowNodeExecution,
*,
event: Union[
QueueNodeSucceededEvent,
QueueNodeFailedEvent,
QueueNodeExceptionEvent,
],
status: WorkflowNodeExecutionStatus,
error: str | None = None,
handle_special_values: bool = False,
):
"""Update node execution with completion data."""
finished_at = naive_utc_now()
elapsed_time = (finished_at - event.start_at).total_seconds()
# Process data
if handle_special_values:
inputs = WorkflowEntry.handle_special_values(event.inputs)
process_data = WorkflowEntry.handle_special_values(event.process_data)
else:
inputs = event.inputs
process_data = event.process_data
outputs = event.outputs
# Convert metadata
execution_metadata_dict: dict[WorkflowNodeExecutionMetadataKey, Any] = {}
if event.execution_metadata:
execution_metadata_dict.update(event.execution_metadata)
# Update domain model
domain_execution.status = status
domain_execution.update_from_mapping(
inputs=inputs,
process_data=process_data,
outputs=outputs,
metadata=execution_metadata_dict,
)
domain_execution.finished_at = finished_at
domain_execution.elapsed_time = elapsed_time
if error:
domain_execution.error = error
def _merge_event_metadata(self, event: QueueNodeRetryEvent) -> dict[WorkflowNodeExecutionMetadataKey, str | None]:
"""Merge event metadata with origin metadata."""
origin_metadata = {
WorkflowNodeExecutionMetadataKey.ITERATION_ID: event.in_iteration_id,
WorkflowNodeExecutionMetadataKey.PARALLEL_MODE_RUN_ID: event.parallel_mode_run_id,
WorkflowNodeExecutionMetadataKey.LOOP_ID: event.in_loop_id,
}
execution_metadata_dict: dict[WorkflowNodeExecutionMetadataKey, str | None] = {}
if event.execution_metadata:
execution_metadata_dict.update(event.execution_metadata)
return {**execution_metadata_dict, **origin_metadata} if execution_metadata_dict else origin_metadata

View File

@@ -9,7 +9,7 @@ from core.app.apps.exc import GenerateTaskStoppedError
from core.app.entities.app_invoke_entities import InvokeFrom
from core.file.models import File
from core.workflow.constants import ENVIRONMENT_VARIABLE_NODE_ID
from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool
from core.workflow.entities import GraphInitParams
from core.workflow.errors import WorkflowNodeRunFailedError
from core.workflow.graph import Graph
from core.workflow.graph_engine import GraphEngine
@@ -20,6 +20,7 @@ from core.workflow.graph_events import GraphEngineEvent, GraphNodeEventBase, Gra
from core.workflow.nodes import NodeType
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool
from factories import file_factory

View File

@@ -37,7 +37,6 @@ from core.rag.entities.event import (
from core.repositories.factory import DifyCoreRepositoryFactory
from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository
from core.variables.variables import Variable
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import (
WorkflowNodeExecution,
WorkflowNodeExecutionStatus,
@@ -50,6 +49,7 @@ from core.workflow.node_events.base import NodeRunResult
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
from core.workflow.repositories.workflow_node_execution_repository import OrderConfig
from core.workflow.runtime import VariablePool
from core.workflow.system_variable import SystemVariable
from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db

View File

@@ -14,7 +14,7 @@ from core.file import File
from core.repositories import DifyCoreRepositoryFactory
from core.variables import Variable
from core.variables.variables import VariableUnion
from core.workflow.entities import VariablePool, WorkflowNodeExecution
from core.workflow.entities import WorkflowNodeExecution
from core.workflow.enums import ErrorStrategy, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from core.workflow.errors import WorkflowNodeRunFailedError
from core.workflow.graph_events import GraphNodeEventBase, NodeRunFailedEvent, NodeRunSucceededEvent
@@ -23,6 +23,7 @@ from core.workflow.nodes import NodeType
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
from core.workflow.nodes.start.entities import StartNodeData
from core.workflow.runtime import VariablePool
from core.workflow.system_variable import SystemVariable
from core.workflow.workflow_entry import WorkflowEntry
from events.app_event import app_draft_workflow_was_synced, app_published_workflow_was_updated

View File

@@ -5,12 +5,13 @@ import pytest
from configs import dify_config
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool
from core.workflow.entities import GraphInitParams
from core.workflow.enums import WorkflowNodeExecutionStatus
from core.workflow.graph import Graph
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.code.code_node import CodeNode
from core.workflow.nodes.node_factory import DifyNodeFactory
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
from models.enums import UserFrom
from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock

View File

@@ -5,10 +5,11 @@ from urllib.parse import urlencode
import pytest
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool
from core.workflow.entities import GraphInitParams
from core.workflow.graph import Graph
from core.workflow.nodes.http_request.node import HttpRequestNode
from core.workflow.nodes.node_factory import DifyNodeFactory
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
from models.enums import UserFrom
from tests.integration_tests.workflow.nodes.__mock.http import setup_http_mock
@@ -174,13 +175,13 @@ def test_custom_authorization_header(setup_http_mock):
@pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True)
def test_custom_auth_with_empty_api_key_does_not_set_header(setup_http_mock):
"""Test: In custom authentication mode, when the api_key is empty, no header should be set."""
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.http_request.entities import (
HttpRequestNodeAuthorization,
HttpRequestNodeData,
HttpRequestNodeTimeout,
)
from core.workflow.nodes.http_request.executor import Executor
from core.workflow.runtime import VariablePool
from core.workflow.system_variable import SystemVariable
# Create variable pool

View File

@@ -6,12 +6,13 @@ from unittest.mock import MagicMock, patch
from core.app.entities.app_invoke_entities import InvokeFrom
from core.llm_generator.output_parser.structured_output import _parse_structured_output
from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool
from core.workflow.entities import GraphInitParams
from core.workflow.enums import WorkflowNodeExecutionStatus
from core.workflow.graph import Graph
from core.workflow.node_events import StreamCompletedEvent
from core.workflow.nodes.llm.node import LLMNode
from core.workflow.nodes.node_factory import DifyNodeFactory
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
from extensions.ext_database import db
from models.enums import UserFrom

View File

@@ -5,11 +5,12 @@ from unittest.mock import MagicMock
from core.app.entities.app_invoke_entities import InvokeFrom
from core.model_runtime.entities import AssistantPromptMessage
from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool
from core.workflow.entities import GraphInitParams
from core.workflow.enums import WorkflowNodeExecutionStatus
from core.workflow.graph import Graph
from core.workflow.nodes.node_factory import DifyNodeFactory
from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
from extensions.ext_database import db
from models.enums import UserFrom

View File

@@ -4,11 +4,12 @@ import uuid
import pytest
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool
from core.workflow.entities import GraphInitParams
from core.workflow.enums import WorkflowNodeExecutionStatus
from core.workflow.graph import Graph
from core.workflow.nodes.node_factory import DifyNodeFactory
from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
from models.enums import UserFrom
from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock

View File

@@ -4,12 +4,13 @@ from unittest.mock import MagicMock
from core.app.entities.app_invoke_entities import InvokeFrom
from core.tools.utils.configuration import ToolParameterConfigurationManager
from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool
from core.workflow.entities import GraphInitParams
from core.workflow.enums import WorkflowNodeExecutionStatus
from core.workflow.graph import Graph
from core.workflow.node_events import StreamCompletedEvent
from core.workflow.nodes.node_factory import DifyNodeFactory
from core.workflow.nodes.tool.tool_node import ToolNode
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
from models.enums import UserFrom

View File

@@ -99,6 +99,8 @@ class TestAdvancedChatAppRunnerConversationVariables:
workflow=mock_workflow,
system_user_id=str(uuid4()),
app=MagicMock(),
workflow_execution_repository=MagicMock(),
workflow_node_execution_repository=MagicMock(),
)
# Mock database session
@@ -237,6 +239,8 @@ class TestAdvancedChatAppRunnerConversationVariables:
workflow=mock_workflow,
system_user_id=str(uuid4()),
app=MagicMock(),
workflow_execution_repository=MagicMock(),
workflow_node_execution_repository=MagicMock(),
)
# Mock database session
@@ -390,6 +394,8 @@ class TestAdvancedChatAppRunnerConversationVariables:
workflow=mock_workflow,
system_user_id=str(uuid4()),
app=MagicMock(),
workflow_execution_repository=MagicMock(),
workflow_node_execution_repository=MagicMock(),
)
# Mock database session

View File

@@ -0,0 +1,63 @@
from types import SimpleNamespace
import pytest
from core.app.apps.common.graph_runtime_state_support import GraphRuntimeStateSupport
from core.workflow.runtime import GraphRuntimeState
from core.workflow.runtime.variable_pool import VariablePool
from core.workflow.system_variable import SystemVariable
def _make_state(workflow_run_id: str | None) -> GraphRuntimeState:
variable_pool = VariablePool(system_variables=SystemVariable(workflow_execution_id=workflow_run_id))
return GraphRuntimeState(variable_pool=variable_pool, start_at=0.0)
class _StubPipeline(GraphRuntimeStateSupport):
def __init__(self, *, cached_state: GraphRuntimeState | None, queue_state: GraphRuntimeState | None):
self._graph_runtime_state = cached_state
self._base_task_pipeline = SimpleNamespace(queue_manager=SimpleNamespace(graph_runtime_state=queue_state))
def test_ensure_graph_runtime_initialized_caches_explicit_state():
explicit_state = _make_state("run-explicit")
pipeline = _StubPipeline(cached_state=None, queue_state=None)
resolved = pipeline._ensure_graph_runtime_initialized(explicit_state)
assert resolved is explicit_state
assert pipeline._graph_runtime_state is explicit_state
def test_resolve_graph_runtime_state_reads_from_queue_when_cache_empty():
queued_state = _make_state("run-queue")
pipeline = _StubPipeline(cached_state=None, queue_state=queued_state)
resolved = pipeline._resolve_graph_runtime_state()
assert resolved is queued_state
assert pipeline._graph_runtime_state is queued_state
def test_resolve_graph_runtime_state_raises_when_no_state_available():
pipeline = _StubPipeline(cached_state=None, queue_state=None)
with pytest.raises(ValueError):
pipeline._resolve_graph_runtime_state()
def test_extract_workflow_run_id_returns_value():
state = _make_state("run-identifier")
pipeline = _StubPipeline(cached_state=state, queue_state=None)
run_id = pipeline._extract_workflow_run_id(state)
assert run_id == "run-identifier"
def test_extract_workflow_run_id_raises_when_missing():
state = _make_state(None)
pipeline = _StubPipeline(cached_state=state, queue_state=None)
with pytest.raises(ValueError):
pipeline._extract_workflow_run_id(state)

View File

@@ -3,8 +3,7 @@ Unit tests for WorkflowResponseConverter focusing on process_data truncation fun
"""
import uuid
from dataclasses import dataclass
from datetime import datetime
from collections.abc import Mapping
from typing import Any
from unittest.mock import Mock
@@ -12,24 +11,17 @@ import pytest
from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter
from core.app.entities.app_invoke_entities import WorkflowAppGenerateEntity
from core.app.entities.queue_entities import QueueNodeRetryEvent, QueueNodeSucceededEvent
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecution, WorkflowNodeExecutionStatus
from core.app.entities.queue_entities import (
QueueNodeRetryEvent,
QueueNodeStartedEvent,
QueueNodeSucceededEvent,
)
from core.workflow.enums import NodeType
from core.workflow.system_variable import SystemVariable
from libs.datetime_utils import naive_utc_now
from models import Account
@dataclass
class ProcessDataResponseScenario:
"""Test scenario for process_data in responses."""
name: str
original_process_data: dict[str, Any] | None
truncated_process_data: dict[str, Any] | None
expected_response_data: dict[str, Any] | None
expected_truncated_flag: bool
class TestWorkflowResponseConverterCenarios:
"""Test process_data truncation in WorkflowResponseConverter."""
@@ -39,6 +31,7 @@ class TestWorkflowResponseConverterCenarios:
mock_app_config = Mock()
mock_app_config.tenant_id = "test-tenant-id"
mock_entity.app_config = mock_app_config
mock_entity.inputs = {}
return mock_entity
def create_workflow_response_converter(self) -> WorkflowResponseConverter:
@@ -50,54 +43,59 @@ class TestWorkflowResponseConverterCenarios:
mock_user.name = "Test User"
mock_user.email = "test@example.com"
return WorkflowResponseConverter(application_generate_entity=mock_entity, user=mock_user)
def create_workflow_node_execution(
self,
process_data: dict[str, Any] | None = None,
truncated_process_data: dict[str, Any] | None = None,
execution_id: str = "test-execution-id",
) -> WorkflowNodeExecution:
"""Create a WorkflowNodeExecution for testing."""
execution = WorkflowNodeExecution(
id=execution_id,
workflow_id="test-workflow-id",
workflow_execution_id="test-run-id",
index=1,
node_id="test-node-id",
node_type=NodeType.LLM,
title="Test Node",
process_data=process_data,
status=WorkflowNodeExecutionStatus.SUCCEEDED,
created_at=datetime.now(),
finished_at=datetime.now(),
system_variables = SystemVariable(workflow_id="wf-id", workflow_execution_id="initial-run-id")
return WorkflowResponseConverter(
application_generate_entity=mock_entity,
user=mock_user,
system_variables=system_variables,
)
if truncated_process_data is not None:
execution.set_truncated_process_data(truncated_process_data)
def create_node_started_event(self, *, node_execution_id: str | None = None) -> QueueNodeStartedEvent:
"""Create a QueueNodeStartedEvent for testing."""
return QueueNodeStartedEvent(
node_execution_id=node_execution_id or str(uuid.uuid4()),
node_id="test-node-id",
node_title="Test Node",
node_type=NodeType.CODE,
start_at=naive_utc_now(),
predecessor_node_id=None,
in_iteration_id=None,
in_loop_id=None,
provider_type="built-in",
provider_id="code",
)
return execution
def create_node_succeeded_event(self) -> QueueNodeSucceededEvent:
def create_node_succeeded_event(
self,
*,
node_execution_id: str,
process_data: Mapping[str, Any] | None = None,
) -> QueueNodeSucceededEvent:
"""Create a QueueNodeSucceededEvent for testing."""
return QueueNodeSucceededEvent(
node_id="test-node-id",
node_type=NodeType.CODE,
node_execution_id=str(uuid.uuid4()),
node_execution_id=node_execution_id,
start_at=naive_utc_now(),
parallel_id=None,
parallel_start_node_id=None,
parent_parallel_id=None,
parent_parallel_start_node_id=None,
in_iteration_id=None,
in_loop_id=None,
inputs={},
process_data=process_data or {},
outputs={},
execution_metadata={},
)
def create_node_retry_event(self) -> QueueNodeRetryEvent:
def create_node_retry_event(
self,
*,
node_execution_id: str,
process_data: Mapping[str, Any] | None = None,
) -> QueueNodeRetryEvent:
"""Create a QueueNodeRetryEvent for testing."""
return QueueNodeRetryEvent(
inputs={"data": "inputs"},
outputs={"data": "outputs"},
process_data=process_data or {},
error="oops",
retry_index=1,
node_id="test-node-id",
@@ -105,12 +103,8 @@ class TestWorkflowResponseConverterCenarios:
node_title="test code",
provider_type="built-in",
provider_id="code",
node_execution_id=str(uuid.uuid4()),
node_execution_id=node_execution_id,
start_at=naive_utc_now(),
parallel_id=None,
parallel_start_node_id=None,
parent_parallel_id=None,
parent_parallel_start_node_id=None,
in_iteration_id=None,
in_loop_id=None,
)
@@ -122,15 +116,28 @@ class TestWorkflowResponseConverterCenarios:
original_data = {"large_field": "x" * 10000, "metadata": "info"}
truncated_data = {"large_field": "[TRUNCATED]", "metadata": "info"}
execution = self.create_workflow_node_execution(
process_data=original_data, truncated_process_data=truncated_data
converter.workflow_start_to_stream_response(task_id="bootstrap", workflow_run_id="run-id", workflow_id="wf-id")
start_event = self.create_node_started_event()
converter.workflow_node_start_to_stream_response(
event=start_event,
task_id="test-task-id",
)
event = self.create_node_succeeded_event()
event = self.create_node_succeeded_event(
node_execution_id=start_event.node_execution_id,
process_data=original_data,
)
def fake_truncate(mapping):
if mapping == dict(original_data):
return truncated_data, True
return mapping, False
converter._truncator.truncate_variable_mapping = fake_truncate # type: ignore[assignment]
response = converter.workflow_node_finish_to_stream_response(
event=event,
task_id="test-task-id",
workflow_node_execution=execution,
)
# Response should use truncated data, not original
@@ -145,13 +152,26 @@ class TestWorkflowResponseConverterCenarios:
original_data = {"small": "data"}
execution = self.create_workflow_node_execution(process_data=original_data)
event = self.create_node_succeeded_event()
converter.workflow_start_to_stream_response(task_id="bootstrap", workflow_run_id="run-id", workflow_id="wf-id")
start_event = self.create_node_started_event()
converter.workflow_node_start_to_stream_response(
event=start_event,
task_id="test-task-id",
)
event = self.create_node_succeeded_event(
node_execution_id=start_event.node_execution_id,
process_data=original_data,
)
def fake_truncate(mapping):
return mapping, False
converter._truncator.truncate_variable_mapping = fake_truncate # type: ignore[assignment]
response = converter.workflow_node_finish_to_stream_response(
event=event,
task_id="test-task-id",
workflow_node_execution=execution,
)
# Response should use original data
@@ -163,18 +183,31 @@ class TestWorkflowResponseConverterCenarios:
"""Test node finish response when process_data is None."""
converter = self.create_workflow_response_converter()
execution = self.create_workflow_node_execution(process_data=None)
event = self.create_node_succeeded_event()
converter.workflow_start_to_stream_response(task_id="bootstrap", workflow_run_id="run-id", workflow_id="wf-id")
start_event = self.create_node_started_event()
converter.workflow_node_start_to_stream_response(
event=start_event,
task_id="test-task-id",
)
event = self.create_node_succeeded_event(
node_execution_id=start_event.node_execution_id,
process_data=None,
)
def fake_truncate(mapping):
return mapping, False
converter._truncator.truncate_variable_mapping = fake_truncate # type: ignore[assignment]
response = converter.workflow_node_finish_to_stream_response(
event=event,
task_id="test-task-id",
workflow_node_execution=execution,
)
# Response should have None process_data
# Response should normalize missing process_data to an empty mapping
assert response is not None
assert response.data.process_data is None
assert response.data.process_data == {}
assert response.data.process_data_truncated is False
def test_workflow_node_retry_response_uses_truncated_process_data(self):
@@ -184,15 +217,28 @@ class TestWorkflowResponseConverterCenarios:
original_data = {"large_field": "x" * 10000, "metadata": "info"}
truncated_data = {"large_field": "[TRUNCATED]", "metadata": "info"}
execution = self.create_workflow_node_execution(
process_data=original_data, truncated_process_data=truncated_data
converter.workflow_start_to_stream_response(task_id="bootstrap", workflow_run_id="run-id", workflow_id="wf-id")
start_event = self.create_node_started_event()
converter.workflow_node_start_to_stream_response(
event=start_event,
task_id="test-task-id",
)
event = self.create_node_retry_event()
event = self.create_node_retry_event(
node_execution_id=start_event.node_execution_id,
process_data=original_data,
)
def fake_truncate(mapping):
if mapping == dict(original_data):
return truncated_data, True
return mapping, False
converter._truncator.truncate_variable_mapping = fake_truncate # type: ignore[assignment]
response = converter.workflow_node_retry_to_stream_response(
event=event,
task_id="test-task-id",
workflow_node_execution=execution,
)
# Response should use truncated data, not original
@@ -207,224 +253,72 @@ class TestWorkflowResponseConverterCenarios:
original_data = {"small": "data"}
execution = self.create_workflow_node_execution(process_data=original_data)
event = self.create_node_retry_event()
converter.workflow_start_to_stream_response(task_id="bootstrap", workflow_run_id="run-id", workflow_id="wf-id")
start_event = self.create_node_started_event()
converter.workflow_node_start_to_stream_response(
event=start_event,
task_id="test-task-id",
)
event = self.create_node_retry_event(
node_execution_id=start_event.node_execution_id,
process_data=original_data,
)
def fake_truncate(mapping):
return mapping, False
converter._truncator.truncate_variable_mapping = fake_truncate # type: ignore[assignment]
response = converter.workflow_node_retry_to_stream_response(
event=event,
task_id="test-task-id",
workflow_node_execution=execution,
)
# Response should use original data
assert response is not None
assert response.data.process_data == original_data
assert response.data.process_data_truncated is False
def test_iteration_and_loop_nodes_return_none(self):
"""Test that iteration and loop nodes return None (no change from existing behavior)."""
"""Test that iteration and loop nodes return None (no streaming events)."""
converter = self.create_workflow_response_converter()
# Test iteration node
iteration_execution = self.create_workflow_node_execution(process_data={"test": "data"})
iteration_execution.node_type = NodeType.ITERATION
event = self.create_node_succeeded_event()
response = converter.workflow_node_finish_to_stream_response(
event=event,
task_id="test-task-id",
workflow_node_execution=iteration_execution,
)
# Should return None for iteration nodes
assert response is None
# Test loop node
loop_execution = self.create_workflow_node_execution(process_data={"test": "data"})
loop_execution.node_type = NodeType.LOOP
response = converter.workflow_node_finish_to_stream_response(
event=event,
task_id="test-task-id",
workflow_node_execution=loop_execution,
)
# Should return None for loop nodes
assert response is None
def test_execution_without_workflow_execution_id_returns_none(self):
"""Test that executions without workflow_execution_id return None."""
converter = self.create_workflow_response_converter()
execution = self.create_workflow_node_execution(process_data={"test": "data"})
execution.workflow_execution_id = None # Single-step debugging
event = self.create_node_succeeded_event()
response = converter.workflow_node_finish_to_stream_response(
event=event,
task_id="test-task-id",
workflow_node_execution=execution,
)
# Should return None for single-step debugging
assert response is None
@staticmethod
def get_process_data_response_scenarios() -> list[ProcessDataResponseScenario]:
"""Create test scenarios for process_data responses."""
return [
ProcessDataResponseScenario(
name="none_process_data",
original_process_data=None,
truncated_process_data=None,
expected_response_data=None,
expected_truncated_flag=False,
),
ProcessDataResponseScenario(
name="small_process_data_no_truncation",
original_process_data={"small": "data"},
truncated_process_data=None,
expected_response_data={"small": "data"},
expected_truncated_flag=False,
),
ProcessDataResponseScenario(
name="large_process_data_with_truncation",
original_process_data={"large": "x" * 10000, "metadata": "info"},
truncated_process_data={"large": "[TRUNCATED]", "metadata": "info"},
expected_response_data={"large": "[TRUNCATED]", "metadata": "info"},
expected_truncated_flag=True,
),
ProcessDataResponseScenario(
name="empty_process_data",
original_process_data={},
truncated_process_data=None,
expected_response_data={},
expected_truncated_flag=False,
),
ProcessDataResponseScenario(
name="complex_data_with_truncation",
original_process_data={
"logs": ["entry"] * 1000, # Large array
"config": {"setting": "value"},
"status": "processing",
},
truncated_process_data={
"logs": "[TRUNCATED: 1000 items]",
"config": {"setting": "value"},
"status": "processing",
},
expected_response_data={
"logs": "[TRUNCATED: 1000 items]",
"config": {"setting": "value"},
"status": "processing",
},
expected_truncated_flag=True,
),
]
@pytest.mark.parametrize(
"scenario",
get_process_data_response_scenarios(),
ids=[scenario.name for scenario in get_process_data_response_scenarios()],
)
def test_node_finish_response_scenarios(self, scenario: ProcessDataResponseScenario):
"""Test various scenarios for node finish responses."""
mock_user = Mock(spec=Account)
mock_user.id = "test-user-id"
mock_user.name = "Test User"
mock_user.email = "test@example.com"
converter = WorkflowResponseConverter(
application_generate_entity=Mock(spec=WorkflowAppGenerateEntity, app_config=Mock(tenant_id="test-tenant")),
user=mock_user,
)
execution = WorkflowNodeExecution(
id="test-execution-id",
workflow_id="test-workflow-id",
workflow_execution_id="test-run-id",
index=1,
node_id="test-node-id",
node_type=NodeType.LLM,
title="Test Node",
process_data=scenario.original_process_data,
status=WorkflowNodeExecutionStatus.SUCCEEDED,
created_at=datetime.now(),
finished_at=datetime.now(),
)
if scenario.truncated_process_data is not None:
execution.set_truncated_process_data(scenario.truncated_process_data)
event = QueueNodeSucceededEvent(
node_id="test-node-id",
node_type=NodeType.CODE,
iteration_event = QueueNodeSucceededEvent(
node_id="iteration-node",
node_type=NodeType.ITERATION,
node_execution_id=str(uuid.uuid4()),
start_at=naive_utc_now(),
parallel_id=None,
parallel_start_node_id=None,
parent_parallel_id=None,
parent_parallel_start_node_id=None,
in_iteration_id=None,
in_loop_id=None,
inputs={},
process_data={},
outputs={},
execution_metadata={},
)
response = converter.workflow_node_finish_to_stream_response(
event=event,
event=iteration_event,
task_id="test-task-id",
workflow_node_execution=execution,
)
assert response is None
assert response is not None
assert response.data.process_data == scenario.expected_response_data
assert response.data.process_data_truncated == scenario.expected_truncated_flag
@pytest.mark.parametrize(
"scenario",
get_process_data_response_scenarios(),
ids=[scenario.name for scenario in get_process_data_response_scenarios()],
)
def test_node_retry_response_scenarios(self, scenario: ProcessDataResponseScenario):
"""Test various scenarios for node retry responses."""
mock_user = Mock(spec=Account)
mock_user.id = "test-user-id"
mock_user.name = "Test User"
mock_user.email = "test@example.com"
converter = WorkflowResponseConverter(
application_generate_entity=Mock(spec=WorkflowAppGenerateEntity, app_config=Mock(tenant_id="test-tenant")),
user=mock_user,
)
execution = WorkflowNodeExecution(
id="test-execution-id",
workflow_id="test-workflow-id",
workflow_execution_id="test-run-id",
index=1,
node_id="test-node-id",
node_type=NodeType.LLM,
title="Test Node",
process_data=scenario.original_process_data,
status=WorkflowNodeExecutionStatus.FAILED, # Retry scenario
created_at=datetime.now(),
finished_at=datetime.now(),
)
if scenario.truncated_process_data is not None:
execution.set_truncated_process_data(scenario.truncated_process_data)
event = self.create_node_retry_event()
response = converter.workflow_node_retry_to_stream_response(
event=event,
loop_event = iteration_event.model_copy(update={"node_type": NodeType.LOOP})
response = converter.workflow_node_finish_to_stream_response(
event=loop_event,
task_id="test-task-id",
workflow_node_execution=execution,
)
assert response is None
def test_finish_without_start_raises(self):
"""Ensure finish responses require a prior workflow start."""
converter = self.create_workflow_response_converter()
event = self.create_node_succeeded_event(
node_execution_id=str(uuid.uuid4()),
process_data={},
)
assert response is not None
assert response.data.process_data == scenario.expected_response_data
assert response.data.process_data_truncated == scenario.expected_truncated_flag
with pytest.raises(ValueError):
converter.workflow_node_finish_to_stream_response(
event=event,
task_id="test-task-id",
)

View File

@@ -37,7 +37,7 @@ from core.variables.variables import (
Variable,
VariableUnion,
)
from core.workflow.entities import VariablePool
from core.workflow.runtime import VariablePool
from core.workflow.system_variable import SystemVariable

View File

@@ -1,9 +1,11 @@
import json
from time import time
from unittest.mock import MagicMock, patch
import pytest
from core.workflow.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.entities.variable_pool import VariablePool
from core.model_runtime.entities.llm_entities import LLMUsage
from core.workflow.runtime import GraphRuntimeState, ReadOnlyGraphRuntimeStateWrapper, VariablePool
class TestGraphRuntimeState:
@@ -95,3 +97,141 @@ class TestGraphRuntimeState:
# Test add_tokens validation
with pytest.raises(ValueError):
state.add_tokens(-1)
def test_ready_queue_default_instantiation(self):
state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time())
queue = state.ready_queue
from core.workflow.graph_engine.ready_queue import InMemoryReadyQueue
assert isinstance(queue, InMemoryReadyQueue)
assert state.ready_queue is queue
def test_graph_execution_lazy_instantiation(self):
state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time())
execution = state.graph_execution
from core.workflow.graph_engine.domain.graph_execution import GraphExecution
assert isinstance(execution, GraphExecution)
assert execution.workflow_id == ""
assert state.graph_execution is execution
def test_response_coordinator_configuration(self):
variable_pool = VariablePool()
state = GraphRuntimeState(variable_pool=variable_pool, start_at=time())
with pytest.raises(ValueError):
_ = state.response_coordinator
mock_graph = MagicMock()
with patch("core.workflow.graph_engine.response_coordinator.ResponseStreamCoordinator") as coordinator_cls:
coordinator_instance = MagicMock()
coordinator_cls.return_value = coordinator_instance
state.configure(graph=mock_graph)
assert state.response_coordinator is coordinator_instance
coordinator_cls.assert_called_once_with(variable_pool=variable_pool, graph=mock_graph)
# Configure again with same graph should be idempotent
state.configure(graph=mock_graph)
other_graph = MagicMock()
with pytest.raises(ValueError):
state.attach_graph(other_graph)
def test_read_only_wrapper_exposes_additional_state(self):
state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time())
state.configure()
wrapper = ReadOnlyGraphRuntimeStateWrapper(state)
assert wrapper.ready_queue_size == 0
assert wrapper.exceptions_count == 0
def test_read_only_wrapper_serializes_runtime_state(self):
state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time())
state.total_tokens = 5
state.set_output("result", {"success": True})
state.ready_queue.put("node-1")
wrapper = ReadOnlyGraphRuntimeStateWrapper(state)
wrapper_snapshot = json.loads(wrapper.dumps())
state_snapshot = json.loads(state.dumps())
assert wrapper_snapshot == state_snapshot
def test_dumps_and_loads_roundtrip_with_response_coordinator(self):
variable_pool = VariablePool()
variable_pool.add(("node1", "value"), "payload")
state = GraphRuntimeState(variable_pool=variable_pool, start_at=time())
state.total_tokens = 10
state.node_run_steps = 3
state.set_output("final", {"result": True})
usage = LLMUsage.from_metadata(
{
"prompt_tokens": 2,
"completion_tokens": 3,
"total_tokens": 5,
"total_price": "1.23",
"currency": "USD",
"latency": 0.5,
}
)
state.llm_usage = usage
state.ready_queue.put("node-A")
graph_execution = state.graph_execution
graph_execution.workflow_id = "wf-123"
graph_execution.exceptions_count = 4
graph_execution.started = True
class StubCoordinator:
def __init__(self) -> None:
self.state = "initial"
def dumps(self) -> str:
return json.dumps({"state": self.state})
def loads(self, data: str) -> None:
payload = json.loads(data)
self.state = payload["state"]
mock_graph = MagicMock()
stub = StubCoordinator()
with patch.object(GraphRuntimeState, "_build_response_coordinator", return_value=stub):
state.attach_graph(mock_graph)
stub.state = "configured"
snapshot = state.dumps()
restored = GraphRuntimeState(variable_pool=VariablePool(), start_at=0.0)
restored.loads(snapshot)
assert restored.total_tokens == 10
assert restored.node_run_steps == 3
assert restored.get_output("final") == {"result": True}
assert restored.llm_usage.total_tokens == usage.total_tokens
assert restored.ready_queue.qsize() == 1
assert restored.ready_queue.get(timeout=0.01) == "node-A"
restored_segment = restored.variable_pool.get(("node1", "value"))
assert restored_segment is not None
assert restored_segment.value == "payload"
restored_execution = restored.graph_execution
assert restored_execution.workflow_id == "wf-123"
assert restored_execution.exceptions_count == 4
assert restored_execution.started is True
new_stub = StubCoordinator()
with patch.object(GraphRuntimeState, "_build_response_coordinator", return_value=new_stub):
restored.attach_graph(mock_graph)
assert new_stub.state == "configured"

View File

@@ -4,7 +4,7 @@ from core.variables.segments import (
NoneSegment,
StringSegment,
)
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.runtime import VariablePool
class TestVariablePoolGetAndNestedAttribute:

View File

@@ -0,0 +1,59 @@
from unittest.mock import MagicMock
import pytest
from core.workflow.enums import NodeType
from core.workflow.graph import Graph
from core.workflow.nodes.base.node import Node
def _make_node(node_id: str, node_type: NodeType = NodeType.START) -> Node:
node = MagicMock(spec=Node)
node.id = node_id
node.node_type = node_type
node.execution_type = None # attribute not used in builder path
return node
def test_graph_builder_creates_linear_graph():
builder = Graph.new()
root = _make_node("root", NodeType.START)
mid = _make_node("mid", NodeType.LLM)
end = _make_node("end", NodeType.END)
graph = builder.add_root(root).add_node(mid).add_node(end).build()
assert graph.root_node is root
assert graph.nodes == {"root": root, "mid": mid, "end": end}
assert len(graph.edges) == 2
first_edge = next(iter(graph.edges.values()))
assert first_edge.tail == "root"
assert first_edge.head == "mid"
assert graph.out_edges["mid"] == [edge_id for edge_id, edge in graph.edges.items() if edge.tail == "mid"]
def test_graph_builder_supports_custom_predecessor():
builder = Graph.new()
root = _make_node("root")
branch = _make_node("branch")
other = _make_node("other")
graph = builder.add_root(root).add_node(branch).add_node(other, from_node_id="root").build()
outgoing_root = graph.out_edges["root"]
assert len(outgoing_root) == 2
edge_targets = {graph.edges[eid].head for eid in outgoing_root}
assert edge_targets == {"branch", "other"}
def test_graph_builder_validates_usage():
builder = Graph.new()
node = _make_node("node")
with pytest.raises(ValueError, match="Root node"):
builder.add_node(node)
builder.add_root(node)
duplicate = _make_node("node")
with pytest.raises(ValueError, match="Duplicate"):
builder.add_node(duplicate)

View File

@@ -20,9 +20,6 @@ The TableTestRunner (`test_table_runner.py`) provides a robust table-driven test
- **Mock configuration** - Seamless integration with the auto-mock system
- **Performance metrics** - Track execution times and bottlenecks
- **Detailed error reporting** - Comprehensive failure diagnostics
- **Test tagging** - Organize and filter tests by tags
- **Retry mechanism** - Handle flaky tests gracefully
- **Custom validators** - Define custom validation logic
### Basic Usage
@@ -68,49 +65,6 @@ suite_result = runner.run_table_tests(
print(f"Success rate: {suite_result.success_rate:.1f}%")
```
#### Test Tagging and Filtering
```python
test_case = WorkflowTestCase(
fixture_path="workflow",
inputs={},
expected_outputs={},
tags=["smoke", "critical"],
)
# Run only tests with specific tags
suite_result = runner.run_table_tests(
test_cases,
tags_filter=["smoke"]
)
```
#### Retry Mechanism
```python
test_case = WorkflowTestCase(
fixture_path="flaky_workflow",
inputs={},
expected_outputs={},
retry_count=2, # Retry up to 2 times on failure
)
```
#### Custom Validators
```python
def custom_validator(outputs: dict) -> bool:
# Custom validation logic
return "error" not in outputs.get("status", "")
test_case = WorkflowTestCase(
fixture_path="workflow",
inputs={},
expected_outputs={"status": "success"},
custom_validator=custom_validator,
)
```
#### Event Sequence Validation
```python

View File

@@ -4,7 +4,6 @@ from __future__ import annotations
from datetime import datetime
from core.workflow.entities import GraphRuntimeState, VariablePool
from core.workflow.enums import NodeExecutionType, NodeState, NodeType, WorkflowNodeExecutionStatus
from core.workflow.graph import Graph
from core.workflow.graph_engine.domain.graph_execution import GraphExecution
@@ -16,6 +15,7 @@ from core.workflow.graph_engine.response_coordinator.coordinator import Response
from core.workflow.graph_events import NodeRunRetryEvent, NodeRunStartedEvent
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.entities import RetryConfig
from core.workflow.runtime import GraphRuntimeState, VariablePool
class _StubEdgeProcessor:

View File

@@ -3,12 +3,12 @@
import time
from unittest.mock import MagicMock
from core.workflow.entities import GraphRuntimeState, VariablePool
from core.workflow.graph import Graph
from core.workflow.graph_engine import GraphEngine
from core.workflow.graph_engine.command_channels import InMemoryChannel
from core.workflow.graph_engine.entities.commands import AbortCommand
from core.workflow.graph_events import GraphRunAbortedEvent, GraphRunStartedEvent
from core.workflow.graph_engine.entities.commands import AbortCommand, CommandType, PauseCommand
from core.workflow.graph_events import GraphRunAbortedEvent, GraphRunPausedEvent, GraphRunStartedEvent
from core.workflow.runtime import GraphRuntimeState, VariablePool
def test_abort_command():
@@ -100,8 +100,57 @@ def test_redis_channel_serialization():
assert command_data["command_type"] == "abort"
assert command_data["reason"] == "Test abort"
# Test pause command serialization
pause_command = PauseCommand(reason="User requested pause")
channel.send_command(pause_command)
if __name__ == "__main__":
test_abort_command()
test_redis_channel_serialization()
print("All tests passed!")
assert len(mock_pipeline.rpush.call_args_list) == 2
second_call_args = mock_pipeline.rpush.call_args_list[1]
pause_command_json = second_call_args[0][1]
pause_command_data = json.loads(pause_command_json)
assert pause_command_data["command_type"] == CommandType.PAUSE.value
assert pause_command_data["reason"] == "User requested pause"
def test_pause_command():
"""Test that GraphEngine properly handles pause commands."""
shared_runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())
mock_graph = MagicMock(spec=Graph)
mock_graph.nodes = {}
mock_graph.edges = {}
mock_graph.root_node = MagicMock()
mock_graph.root_node.id = "start"
mock_start_node = MagicMock()
mock_start_node.state = None
mock_start_node.id = "start"
mock_start_node.graph_runtime_state = shared_runtime_state
mock_graph.nodes["start"] = mock_start_node
mock_graph.get_outgoing_edges = MagicMock(return_value=[])
mock_graph.get_incoming_edges = MagicMock(return_value=[])
command_channel = InMemoryChannel()
engine = GraphEngine(
workflow_id="test_workflow",
graph=mock_graph,
graph_runtime_state=shared_runtime_state,
command_channel=command_channel,
)
pause_command = PauseCommand(reason="User requested pause")
command_channel.send_command(pause_command)
events = list(engine.run())
assert any(isinstance(e, GraphRunStartedEvent) for e in events)
pause_events = [e for e in events if isinstance(e, GraphRunPausedEvent)]
assert len(pause_events) == 1
assert pause_events[0].reason == "User requested pause"
graph_execution = engine.graph_runtime_state.graph_execution
assert graph_execution.is_paused
assert graph_execution.pause_reason == "User requested pause"

View File

@@ -21,6 +21,7 @@ class _StubExecutionCoordinator:
self._execution_complete = False
self.mark_complete_called = False
self.failed = False
self._paused = False
def check_commands(self) -> None:
self.command_checks += 1
@@ -28,6 +29,10 @@ class _StubExecutionCoordinator:
def check_scaling(self) -> None:
self.scaling_checks += 1
@property
def is_paused(self) -> bool:
return self._paused
def is_execution_complete(self) -> bool:
return self._execution_complete
@@ -96,7 +101,7 @@ def _make_succeeded_event() -> NodeRunSucceededEvent:
def test_dispatcher_checks_commands_during_idle_and_on_completion() -> None:
"""Dispatcher polls commands when idle and re-checks after completion events."""
"""Dispatcher polls commands when idle and after completion events."""
started_checks = _run_dispatcher_for_event(_make_started_event())
succeeded_checks = _run_dispatcher_for_event(_make_succeeded_event())

View File

@@ -0,0 +1,62 @@
"""Unit tests for the execution coordinator orchestration logic."""
from unittest.mock import MagicMock
from core.workflow.graph_engine.command_processing.command_processor import CommandProcessor
from core.workflow.graph_engine.domain.graph_execution import GraphExecution
from core.workflow.graph_engine.graph_state_manager import GraphStateManager
from core.workflow.graph_engine.orchestration.execution_coordinator import ExecutionCoordinator
from core.workflow.graph_engine.worker_management.worker_pool import WorkerPool
def _build_coordinator(graph_execution: GraphExecution) -> tuple[ExecutionCoordinator, MagicMock, MagicMock]:
command_processor = MagicMock(spec=CommandProcessor)
state_manager = MagicMock(spec=GraphStateManager)
worker_pool = MagicMock(spec=WorkerPool)
coordinator = ExecutionCoordinator(
graph_execution=graph_execution,
state_manager=state_manager,
command_processor=command_processor,
worker_pool=worker_pool,
)
return coordinator, state_manager, worker_pool
def test_handle_pause_stops_workers_and_clears_state() -> None:
"""Paused execution should stop workers and clear executing state."""
graph_execution = GraphExecution(workflow_id="workflow")
graph_execution.start()
graph_execution.pause("Awaiting human input")
coordinator, state_manager, worker_pool = _build_coordinator(graph_execution)
coordinator.handle_pause_if_needed()
worker_pool.stop.assert_called_once_with()
state_manager.clear_executing.assert_called_once_with()
def test_handle_pause_noop_when_execution_running() -> None:
"""Running execution should not trigger pause handling."""
graph_execution = GraphExecution(workflow_id="workflow")
graph_execution.start()
coordinator, state_manager, worker_pool = _build_coordinator(graph_execution)
coordinator.handle_pause_if_needed()
worker_pool.stop.assert_not_called()
state_manager.clear_executing.assert_not_called()
def test_is_execution_complete_when_paused() -> None:
"""Paused execution should be treated as complete."""
graph_execution = GraphExecution(workflow_id="workflow")
graph_execution.start()
graph_execution.pause("Awaiting input")
coordinator, state_manager, _worker_pool = _build_coordinator(graph_execution)
state_manager.is_execution_complete.return_value = False
assert coordinator.is_execution_complete()

View File

@@ -0,0 +1,341 @@
import time
from collections.abc import Iterable
from core.model_runtime.entities.llm_entities import LLMMode
from core.model_runtime.entities.message_entities import PromptMessageRole
from core.workflow.entities import GraphInitParams
from core.workflow.graph import Graph
from core.workflow.graph_events import (
GraphRunPausedEvent,
GraphRunStartedEvent,
GraphRunSucceededEvent,
NodeRunPauseRequestedEvent,
NodeRunStartedEvent,
NodeRunStreamChunkEvent,
NodeRunSucceededEvent,
)
from core.workflow.nodes.base.entities import VariableSelector
from core.workflow.nodes.end.end_node import EndNode
from core.workflow.nodes.end.entities import EndNodeData
from core.workflow.nodes.human_input import HumanInputNode
from core.workflow.nodes.human_input.entities import HumanInputNodeData
from core.workflow.nodes.llm.entities import (
ContextConfig,
LLMNodeChatModelMessage,
LLMNodeData,
ModelConfig,
VisionConfig,
)
from core.workflow.nodes.start.entities import StartNodeData
from core.workflow.nodes.start.start_node import StartNode
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
from .test_mock_config import MockConfig
from .test_mock_nodes import MockLLMNode
from .test_table_runner import TableTestRunner, WorkflowTestCase
def _build_branching_graph(mock_config: MockConfig) -> tuple[Graph, GraphRuntimeState]:
graph_config: dict[str, object] = {"nodes": [], "edges": []}
graph_init_params = GraphInitParams(
tenant_id="tenant",
app_id="app",
workflow_id="workflow",
graph_config=graph_config,
user_id="user",
user_from="account",
invoke_from="debugger",
call_depth=0,
)
variable_pool = VariablePool(
system_variables=SystemVariable(user_id="user", app_id="app", workflow_id="workflow"),
user_inputs={},
conversation_variables=[],
)
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
start_config = {"id": "start", "data": StartNodeData(title="Start", variables=[]).model_dump()}
start_node = StartNode(
id=start_config["id"],
config=start_config,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
start_node.init_node_data(start_config["data"])
def _create_llm_node(node_id: str, title: str, prompt_text: str) -> MockLLMNode:
llm_data = LLMNodeData(
title=title,
model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode=LLMMode.CHAT, completion_params={}),
prompt_template=[
LLMNodeChatModelMessage(
text=prompt_text,
role=PromptMessageRole.USER,
edition_type="basic",
)
],
context=ContextConfig(enabled=False, variable_selector=None),
vision=VisionConfig(enabled=False),
reasoning_format="tagged",
)
llm_config = {"id": node_id, "data": llm_data.model_dump()}
llm_node = MockLLMNode(
id=node_id,
config=llm_config,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
mock_config=mock_config,
)
llm_node.init_node_data(llm_config["data"])
return llm_node
llm_initial = _create_llm_node("llm_initial", "Initial LLM", "Initial stream")
human_data = HumanInputNodeData(
title="Human Input",
required_variables=["human.input_ready"],
pause_reason="Awaiting human input",
)
human_config = {"id": "human", "data": human_data.model_dump()}
human_node = HumanInputNode(
id=human_config["id"],
config=human_config,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
human_node.init_node_data(human_config["data"])
llm_primary = _create_llm_node("llm_primary", "Primary LLM", "Primary stream output")
llm_secondary = _create_llm_node("llm_secondary", "Secondary LLM", "Secondary")
end_primary_data = EndNodeData(
title="End Primary",
outputs=[
VariableSelector(variable="initial_text", value_selector=["llm_initial", "text"]),
VariableSelector(variable="primary_text", value_selector=["llm_primary", "text"]),
],
desc=None,
)
end_primary_config = {"id": "end_primary", "data": end_primary_data.model_dump()}
end_primary = EndNode(
id=end_primary_config["id"],
config=end_primary_config,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
end_primary.init_node_data(end_primary_config["data"])
end_secondary_data = EndNodeData(
title="End Secondary",
outputs=[
VariableSelector(variable="initial_text", value_selector=["llm_initial", "text"]),
VariableSelector(variable="secondary_text", value_selector=["llm_secondary", "text"]),
],
desc=None,
)
end_secondary_config = {"id": "end_secondary", "data": end_secondary_data.model_dump()}
end_secondary = EndNode(
id=end_secondary_config["id"],
config=end_secondary_config,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
end_secondary.init_node_data(end_secondary_config["data"])
graph = (
Graph.new()
.add_root(start_node)
.add_node(llm_initial)
.add_node(human_node)
.add_node(llm_primary, from_node_id="human", source_handle="primary")
.add_node(end_primary, from_node_id="llm_primary")
.add_node(llm_secondary, from_node_id="human", source_handle="secondary")
.add_node(end_secondary, from_node_id="llm_secondary")
.build()
)
return graph, graph_runtime_state
def _expected_mock_llm_chunks(text: str) -> list[str]:
chunks: list[str] = []
for index, word in enumerate(text.split(" ")):
chunk = word if index == 0 else f" {word}"
chunks.append(chunk)
chunks.append("")
return chunks
def _assert_stream_chunk_sequence(
chunk_events: Iterable[NodeRunStreamChunkEvent],
expected_nodes: list[str],
expected_chunks: list[str],
) -> None:
actual_nodes = [event.node_id for event in chunk_events]
actual_chunks = [event.chunk for event in chunk_events]
assert actual_nodes == expected_nodes
assert actual_chunks == expected_chunks
def test_human_input_llm_streaming_across_multiple_branches() -> None:
mock_config = MockConfig()
mock_config.set_node_outputs("llm_initial", {"text": "Initial stream"})
mock_config.set_node_outputs("llm_primary", {"text": "Primary stream output"})
mock_config.set_node_outputs("llm_secondary", {"text": "Secondary"})
branch_scenarios = [
{
"handle": "primary",
"resume_llm": "llm_primary",
"end_node": "end_primary",
"expected_pre_chunks": [
("llm_initial", _expected_mock_llm_chunks("Initial stream")), # cached output before branch completes
("end_primary", ["\n"]), # literal segment emitted when end_primary session activates
],
"expected_post_chunks": [
("llm_primary", _expected_mock_llm_chunks("Primary stream output")), # live stream from chosen branch
],
},
{
"handle": "secondary",
"resume_llm": "llm_secondary",
"end_node": "end_secondary",
"expected_pre_chunks": [
("llm_initial", _expected_mock_llm_chunks("Initial stream")), # cached output before branch completes
("end_secondary", ["\n"]), # literal segment emitted when end_secondary session activates
],
"expected_post_chunks": [
("llm_secondary", _expected_mock_llm_chunks("Secondary")), # live stream from chosen branch
],
},
]
for scenario in branch_scenarios:
runner = TableTestRunner()
def initial_graph_factory() -> tuple[Graph, GraphRuntimeState]:
return _build_branching_graph(mock_config)
initial_case = WorkflowTestCase(
description="HumanInput pause before branching decision",
graph_factory=initial_graph_factory,
expected_event_sequence=[
GraphRunStartedEvent, # initial run: graph execution starts
NodeRunStartedEvent, # start node begins execution
NodeRunSucceededEvent, # start node completes
NodeRunStartedEvent, # llm_initial starts streaming
NodeRunSucceededEvent, # llm_initial completes streaming
NodeRunStartedEvent, # human node begins and issues pause
NodeRunPauseRequestedEvent, # human node requests pause awaiting input
GraphRunPausedEvent, # graph run pauses awaiting resume
],
)
initial_result = runner.run_test_case(initial_case)
assert initial_result.success, initial_result.event_mismatch_details
assert not any(isinstance(event, NodeRunStreamChunkEvent) for event in initial_result.events)
graph_runtime_state = initial_result.graph_runtime_state
graph = initial_result.graph
assert graph_runtime_state is not None
assert graph is not None
graph_runtime_state.variable_pool.add(("human", "input_ready"), True)
graph_runtime_state.variable_pool.add(("human", "edge_source_handle"), scenario["handle"])
graph_runtime_state.graph_execution.pause_reason = None
pre_chunk_count = sum(len(chunks) for _, chunks in scenario["expected_pre_chunks"])
post_chunk_count = sum(len(chunks) for _, chunks in scenario["expected_post_chunks"])
expected_resume_sequence: list[type] = (
[
GraphRunStartedEvent,
NodeRunStartedEvent,
]
+ [NodeRunStreamChunkEvent] * pre_chunk_count
+ [
NodeRunSucceededEvent,
NodeRunStartedEvent,
]
+ [NodeRunStreamChunkEvent] * post_chunk_count
+ [
NodeRunSucceededEvent,
NodeRunStartedEvent,
NodeRunSucceededEvent,
GraphRunSucceededEvent,
]
)
def resume_graph_factory(
graph_snapshot: Graph = graph,
state_snapshot: GraphRuntimeState = graph_runtime_state,
) -> tuple[Graph, GraphRuntimeState]:
return graph_snapshot, state_snapshot
resume_case = WorkflowTestCase(
description=f"HumanInput resumes via {scenario['handle']} branch",
graph_factory=resume_graph_factory,
expected_event_sequence=expected_resume_sequence,
)
resume_result = runner.run_test_case(resume_case)
assert resume_result.success, resume_result.event_mismatch_details
resume_events = resume_result.events
chunk_events = [event for event in resume_events if isinstance(event, NodeRunStreamChunkEvent)]
assert len(chunk_events) == pre_chunk_count + post_chunk_count
pre_chunk_events = chunk_events[:pre_chunk_count]
post_chunk_events = chunk_events[pre_chunk_count:]
expected_pre_nodes: list[str] = []
expected_pre_chunks: list[str] = []
for node_id, chunks in scenario["expected_pre_chunks"]:
expected_pre_nodes.extend([node_id] * len(chunks))
expected_pre_chunks.extend(chunks)
_assert_stream_chunk_sequence(pre_chunk_events, expected_pre_nodes, expected_pre_chunks)
expected_post_nodes: list[str] = []
expected_post_chunks: list[str] = []
for node_id, chunks in scenario["expected_post_chunks"]:
expected_post_nodes.extend([node_id] * len(chunks))
expected_post_chunks.extend(chunks)
_assert_stream_chunk_sequence(post_chunk_events, expected_post_nodes, expected_post_chunks)
human_success_index = next(
index
for index, event in enumerate(resume_events)
if isinstance(event, NodeRunSucceededEvent) and event.node_id == "human"
)
pre_indices = [
index
for index, event in enumerate(resume_events)
if isinstance(event, NodeRunStreamChunkEvent) and index < human_success_index
]
assert pre_indices == list(range(2, 2 + pre_chunk_count))
resume_chunk_indices = [
index
for index, event in enumerate(resume_events)
if isinstance(event, NodeRunStreamChunkEvent) and event.node_id == scenario["resume_llm"]
]
assert resume_chunk_indices, "Expected streaming output from the selected branch"
resume_start_index = next(
index
for index, event in enumerate(resume_events)
if isinstance(event, NodeRunStartedEvent) and event.node_id == scenario["resume_llm"]
)
resume_success_index = next(
index
for index, event in enumerate(resume_events)
if isinstance(event, NodeRunSucceededEvent) and event.node_id == scenario["resume_llm"]
)
assert resume_start_index < min(resume_chunk_indices)
assert max(resume_chunk_indices) < resume_success_index
started_nodes = [event.node_id for event in resume_events if isinstance(event, NodeRunStartedEvent)]
assert started_nodes == ["human", scenario["resume_llm"], scenario["end_node"]]

View File

@@ -0,0 +1,297 @@
import time
from core.model_runtime.entities.llm_entities import LLMMode
from core.model_runtime.entities.message_entities import PromptMessageRole
from core.workflow.entities import GraphInitParams
from core.workflow.graph import Graph
from core.workflow.graph_events import (
GraphRunPausedEvent,
GraphRunStartedEvent,
GraphRunSucceededEvent,
NodeRunPauseRequestedEvent,
NodeRunStartedEvent,
NodeRunStreamChunkEvent,
NodeRunSucceededEvent,
)
from core.workflow.nodes.base.entities import VariableSelector
from core.workflow.nodes.end.end_node import EndNode
from core.workflow.nodes.end.entities import EndNodeData
from core.workflow.nodes.human_input import HumanInputNode
from core.workflow.nodes.human_input.entities import HumanInputNodeData
from core.workflow.nodes.llm.entities import (
ContextConfig,
LLMNodeChatModelMessage,
LLMNodeData,
ModelConfig,
VisionConfig,
)
from core.workflow.nodes.start.entities import StartNodeData
from core.workflow.nodes.start.start_node import StartNode
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
from .test_mock_config import MockConfig
from .test_mock_nodes import MockLLMNode
from .test_table_runner import TableTestRunner, WorkflowTestCase
def _build_llm_human_llm_graph(mock_config: MockConfig) -> tuple[Graph, GraphRuntimeState]:
graph_config: dict[str, object] = {"nodes": [], "edges": []}
graph_init_params = GraphInitParams(
tenant_id="tenant",
app_id="app",
workflow_id="workflow",
graph_config=graph_config,
user_id="user",
user_from="account",
invoke_from="debugger",
call_depth=0,
)
variable_pool = VariablePool(
system_variables=SystemVariable(user_id="user", app_id="app", workflow_id="workflow"),
user_inputs={},
conversation_variables=[],
)
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
start_config = {"id": "start", "data": StartNodeData(title="Start", variables=[]).model_dump()}
start_node = StartNode(
id=start_config["id"],
config=start_config,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
start_node.init_node_data(start_config["data"])
def _create_llm_node(node_id: str, title: str, prompt_text: str) -> MockLLMNode:
llm_data = LLMNodeData(
title=title,
model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode=LLMMode.CHAT, completion_params={}),
prompt_template=[
LLMNodeChatModelMessage(
text=prompt_text,
role=PromptMessageRole.USER,
edition_type="basic",
)
],
context=ContextConfig(enabled=False, variable_selector=None),
vision=VisionConfig(enabled=False),
reasoning_format="tagged",
)
llm_config = {"id": node_id, "data": llm_data.model_dump()}
llm_node = MockLLMNode(
id=node_id,
config=llm_config,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
mock_config=mock_config,
)
llm_node.init_node_data(llm_config["data"])
return llm_node
llm_first = _create_llm_node("llm_initial", "Initial LLM", "Initial prompt")
human_data = HumanInputNodeData(
title="Human Input",
required_variables=["human.input_ready"],
pause_reason="Awaiting human input",
)
human_config = {"id": "human", "data": human_data.model_dump()}
human_node = HumanInputNode(
id=human_config["id"],
config=human_config,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
human_node.init_node_data(human_config["data"])
llm_second = _create_llm_node("llm_resume", "Follow-up LLM", "Follow-up prompt")
end_data = EndNodeData(
title="End",
outputs=[
VariableSelector(variable="initial_text", value_selector=["llm_initial", "text"]),
VariableSelector(variable="resume_text", value_selector=["llm_resume", "text"]),
],
desc=None,
)
end_config = {"id": "end", "data": end_data.model_dump()}
end_node = EndNode(
id=end_config["id"],
config=end_config,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
end_node.init_node_data(end_config["data"])
graph = (
Graph.new()
.add_root(start_node)
.add_node(llm_first)
.add_node(human_node)
.add_node(llm_second)
.add_node(end_node)
.build()
)
return graph, graph_runtime_state
def _expected_mock_llm_chunks(text: str) -> list[str]:
chunks: list[str] = []
for index, word in enumerate(text.split(" ")):
chunk = word if index == 0 else f" {word}"
chunks.append(chunk)
chunks.append("")
return chunks
def test_human_input_llm_streaming_order_across_pause() -> None:
runner = TableTestRunner()
initial_text = "Hello, pause"
resume_text = "Welcome back!"
mock_config = MockConfig()
mock_config.set_node_outputs("llm_initial", {"text": initial_text})
mock_config.set_node_outputs("llm_resume", {"text": resume_text})
expected_initial_sequence: list[type] = [
GraphRunStartedEvent, # graph run begins
NodeRunStartedEvent, # start node begins
NodeRunSucceededEvent, # start node completes
NodeRunStartedEvent, # llm_initial begins streaming
NodeRunSucceededEvent, # llm_initial completes streaming
NodeRunStartedEvent, # human node begins and requests pause
NodeRunPauseRequestedEvent, # human node pause requested
GraphRunPausedEvent, # graph run pauses awaiting resume
]
def graph_factory() -> tuple[Graph, GraphRuntimeState]:
return _build_llm_human_llm_graph(mock_config)
initial_case = WorkflowTestCase(
description="HumanInput pause preserves LLM streaming order",
graph_factory=graph_factory,
expected_event_sequence=expected_initial_sequence,
)
initial_result = runner.run_test_case(initial_case)
assert initial_result.success, initial_result.event_mismatch_details
initial_events = initial_result.events
initial_chunks = _expected_mock_llm_chunks(initial_text)
initial_stream_chunk_events = [event for event in initial_events if isinstance(event, NodeRunStreamChunkEvent)]
assert initial_stream_chunk_events == []
pause_index = next(i for i, event in enumerate(initial_events) if isinstance(event, GraphRunPausedEvent))
llm_succeeded_index = next(
i
for i, event in enumerate(initial_events)
if isinstance(event, NodeRunSucceededEvent) and event.node_id == "llm_initial"
)
assert llm_succeeded_index < pause_index
graph_runtime_state = initial_result.graph_runtime_state
graph = initial_result.graph
assert graph_runtime_state is not None
assert graph is not None
coordinator = graph_runtime_state.response_coordinator
stream_buffers = coordinator._stream_buffers # Tests may access internals for assertions
assert ("llm_initial", "text") in stream_buffers
initial_stream_chunks = [event.chunk for event in stream_buffers[("llm_initial", "text")]]
assert initial_stream_chunks == initial_chunks
assert ("llm_resume", "text") not in stream_buffers
resume_chunks = _expected_mock_llm_chunks(resume_text)
expected_resume_sequence: list[type] = [
GraphRunStartedEvent, # resumed graph run begins
NodeRunStartedEvent, # human node restarts
NodeRunStreamChunkEvent, # cached llm_initial chunk 1
NodeRunStreamChunkEvent, # cached llm_initial chunk 2
NodeRunStreamChunkEvent, # cached llm_initial final chunk
NodeRunStreamChunkEvent, # end node emits combined template separator
NodeRunSucceededEvent, # human node finishes instantly after input
NodeRunStartedEvent, # llm_resume begins streaming
NodeRunStreamChunkEvent, # llm_resume chunk 1
NodeRunStreamChunkEvent, # llm_resume chunk 2
NodeRunStreamChunkEvent, # llm_resume final chunk
NodeRunSucceededEvent, # llm_resume completes streaming
NodeRunStartedEvent, # end node starts
NodeRunSucceededEvent, # end node finishes
GraphRunSucceededEvent, # graph run succeeds after resume
]
def resume_graph_factory() -> tuple[Graph, GraphRuntimeState]:
assert graph_runtime_state is not None
assert graph is not None
graph_runtime_state.variable_pool.add(("human", "input_ready"), True)
graph_runtime_state.graph_execution.pause_reason = None
return graph, graph_runtime_state
resume_case = WorkflowTestCase(
description="HumanInput resume continues LLM streaming order",
graph_factory=resume_graph_factory,
expected_event_sequence=expected_resume_sequence,
)
resume_result = runner.run_test_case(resume_case)
assert resume_result.success, resume_result.event_mismatch_details
resume_events = resume_result.events
success_index = next(i for i, event in enumerate(resume_events) if isinstance(event, GraphRunSucceededEvent))
llm_resume_succeeded_index = next(
i
for i, event in enumerate(resume_events)
if isinstance(event, NodeRunSucceededEvent) and event.node_id == "llm_resume"
)
assert llm_resume_succeeded_index < success_index
resume_chunk_events = [event for event in resume_events if isinstance(event, NodeRunStreamChunkEvent)]
assert [event.node_id for event in resume_chunk_events[:3]] == ["llm_initial"] * 3
assert [event.chunk for event in resume_chunk_events[:3]] == initial_chunks
assert resume_chunk_events[3].node_id == "end"
assert resume_chunk_events[3].chunk == "\n"
assert [event.node_id for event in resume_chunk_events[4:]] == ["llm_resume"] * 3
assert [event.chunk for event in resume_chunk_events[4:]] == resume_chunks
human_success_index = next(
i
for i, event in enumerate(resume_events)
if isinstance(event, NodeRunSucceededEvent) and event.node_id == "human"
)
cached_chunk_indices = [
i
for i, event in enumerate(resume_events)
if isinstance(event, NodeRunStreamChunkEvent) and event.node_id in {"llm_initial", "end"}
]
assert all(index < human_success_index for index in cached_chunk_indices)
llm_resume_start_index = next(
i
for i, event in enumerate(resume_events)
if isinstance(event, NodeRunStartedEvent) and event.node_id == "llm_resume"
)
llm_resume_success_index = next(
i
for i, event in enumerate(resume_events)
if isinstance(event, NodeRunSucceededEvent) and event.node_id == "llm_resume"
)
llm_resume_chunk_indices = [
i
for i, event in enumerate(resume_events)
if isinstance(event, NodeRunStreamChunkEvent) and event.node_id == "llm_resume"
]
assert llm_resume_chunk_indices
first_resume_chunk_index = min(llm_resume_chunk_indices)
last_resume_chunk_index = max(llm_resume_chunk_indices)
assert llm_resume_start_index < first_resume_chunk_index
assert last_resume_chunk_index < llm_resume_success_index
started_nodes = [event.node_id for event in resume_events if isinstance(event, NodeRunStartedEvent)]
assert started_nodes == ["human", "llm_resume", "end"]

View File

@@ -0,0 +1,321 @@
import time
from core.model_runtime.entities.llm_entities import LLMMode
from core.model_runtime.entities.message_entities import PromptMessageRole
from core.workflow.entities import GraphInitParams
from core.workflow.graph import Graph
from core.workflow.graph_events import (
GraphRunStartedEvent,
GraphRunSucceededEvent,
NodeRunStartedEvent,
NodeRunStreamChunkEvent,
NodeRunSucceededEvent,
)
from core.workflow.nodes.base.entities import VariableSelector
from core.workflow.nodes.end.end_node import EndNode
from core.workflow.nodes.end.entities import EndNodeData
from core.workflow.nodes.if_else.entities import IfElseNodeData
from core.workflow.nodes.if_else.if_else_node import IfElseNode
from core.workflow.nodes.llm.entities import (
ContextConfig,
LLMNodeChatModelMessage,
LLMNodeData,
ModelConfig,
VisionConfig,
)
from core.workflow.nodes.start.entities import StartNodeData
from core.workflow.nodes.start.start_node import StartNode
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
from core.workflow.utils.condition.entities import Condition
from .test_mock_config import MockConfig
from .test_mock_nodes import MockLLMNode
from .test_table_runner import TableTestRunner, WorkflowTestCase
def _build_if_else_graph(branch_value: str, mock_config: MockConfig) -> tuple[Graph, GraphRuntimeState]:
graph_config: dict[str, object] = {"nodes": [], "edges": []}
graph_init_params = GraphInitParams(
tenant_id="tenant",
app_id="app",
workflow_id="workflow",
graph_config=graph_config,
user_id="user",
user_from="account",
invoke_from="debugger",
call_depth=0,
)
variable_pool = VariablePool(
system_variables=SystemVariable(user_id="user", app_id="app", workflow_id="workflow"),
user_inputs={},
conversation_variables=[],
)
variable_pool.add(("branch", "value"), branch_value)
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
start_config = {"id": "start", "data": StartNodeData(title="Start", variables=[]).model_dump()}
start_node = StartNode(
id=start_config["id"],
config=start_config,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
start_node.init_node_data(start_config["data"])
def _create_llm_node(node_id: str, title: str, prompt_text: str) -> MockLLMNode:
llm_data = LLMNodeData(
title=title,
model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode=LLMMode.CHAT, completion_params={}),
prompt_template=[
LLMNodeChatModelMessage(
text=prompt_text,
role=PromptMessageRole.USER,
edition_type="basic",
)
],
context=ContextConfig(enabled=False, variable_selector=None),
vision=VisionConfig(enabled=False),
reasoning_format="tagged",
)
llm_config = {"id": node_id, "data": llm_data.model_dump()}
llm_node = MockLLMNode(
id=node_id,
config=llm_config,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
mock_config=mock_config,
)
llm_node.init_node_data(llm_config["data"])
return llm_node
llm_initial = _create_llm_node("llm_initial", "Initial LLM", "Initial stream")
if_else_data = IfElseNodeData(
title="IfElse",
cases=[
IfElseNodeData.Case(
case_id="primary",
logical_operator="and",
conditions=[
Condition(variable_selector=["branch", "value"], comparison_operator="is", value="primary")
],
),
IfElseNodeData.Case(
case_id="secondary",
logical_operator="and",
conditions=[
Condition(variable_selector=["branch", "value"], comparison_operator="is", value="secondary")
],
),
],
)
if_else_config = {"id": "if_else", "data": if_else_data.model_dump()}
if_else_node = IfElseNode(
id=if_else_config["id"],
config=if_else_config,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
if_else_node.init_node_data(if_else_config["data"])
llm_primary = _create_llm_node("llm_primary", "Primary LLM", "Primary stream output")
llm_secondary = _create_llm_node("llm_secondary", "Secondary LLM", "Secondary")
end_primary_data = EndNodeData(
title="End Primary",
outputs=[
VariableSelector(variable="initial_text", value_selector=["llm_initial", "text"]),
VariableSelector(variable="primary_text", value_selector=["llm_primary", "text"]),
],
desc=None,
)
end_primary_config = {"id": "end_primary", "data": end_primary_data.model_dump()}
end_primary = EndNode(
id=end_primary_config["id"],
config=end_primary_config,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
end_primary.init_node_data(end_primary_config["data"])
end_secondary_data = EndNodeData(
title="End Secondary",
outputs=[
VariableSelector(variable="initial_text", value_selector=["llm_initial", "text"]),
VariableSelector(variable="secondary_text", value_selector=["llm_secondary", "text"]),
],
desc=None,
)
end_secondary_config = {"id": "end_secondary", "data": end_secondary_data.model_dump()}
end_secondary = EndNode(
id=end_secondary_config["id"],
config=end_secondary_config,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
end_secondary.init_node_data(end_secondary_config["data"])
graph = (
Graph.new()
.add_root(start_node)
.add_node(llm_initial)
.add_node(if_else_node)
.add_node(llm_primary, from_node_id="if_else", source_handle="primary")
.add_node(end_primary, from_node_id="llm_primary")
.add_node(llm_secondary, from_node_id="if_else", source_handle="secondary")
.add_node(end_secondary, from_node_id="llm_secondary")
.build()
)
return graph, graph_runtime_state
def _expected_mock_llm_chunks(text: str) -> list[str]:
chunks: list[str] = []
for index, word in enumerate(text.split(" ")):
chunk = word if index == 0 else f" {word}"
chunks.append(chunk)
chunks.append("")
return chunks
def test_if_else_llm_streaming_order() -> None:
mock_config = MockConfig()
mock_config.set_node_outputs("llm_initial", {"text": "Initial stream"})
mock_config.set_node_outputs("llm_primary", {"text": "Primary stream output"})
mock_config.set_node_outputs("llm_secondary", {"text": "Secondary"})
scenarios = [
{
"branch": "primary",
"resume_llm": "llm_primary",
"end_node": "end_primary",
"expected_sequence": [
GraphRunStartedEvent, # graph run begins
NodeRunStartedEvent, # start node begins execution
NodeRunSucceededEvent, # start node completes
NodeRunStartedEvent, # llm_initial starts and streams
NodeRunSucceededEvent, # llm_initial completes streaming
NodeRunStartedEvent, # if_else evaluates conditions
NodeRunStreamChunkEvent, # cached llm_initial chunk 1 flushed
NodeRunStreamChunkEvent, # cached llm_initial chunk 2 flushed
NodeRunStreamChunkEvent, # cached llm_initial final chunk flushed
NodeRunStreamChunkEvent, # template literal newline emitted
NodeRunSucceededEvent, # if_else completes branch selection
NodeRunStartedEvent, # llm_primary begins streaming
NodeRunStreamChunkEvent, # llm_primary chunk 1
NodeRunStreamChunkEvent, # llm_primary chunk 2
NodeRunStreamChunkEvent, # llm_primary chunk 3
NodeRunStreamChunkEvent, # llm_primary final chunk
NodeRunSucceededEvent, # llm_primary completes streaming
NodeRunStartedEvent, # end_primary node starts
NodeRunSucceededEvent, # end_primary finishes aggregation
GraphRunSucceededEvent, # graph run succeeds
],
"expected_chunks": [
("llm_initial", _expected_mock_llm_chunks("Initial stream")),
("end_primary", ["\n"]),
("llm_primary", _expected_mock_llm_chunks("Primary stream output")),
],
},
{
"branch": "secondary",
"resume_llm": "llm_secondary",
"end_node": "end_secondary",
"expected_sequence": [
GraphRunStartedEvent, # graph run begins
NodeRunStartedEvent, # start node begins execution
NodeRunSucceededEvent, # start node completes
NodeRunStartedEvent, # llm_initial starts and streams
NodeRunSucceededEvent, # llm_initial completes streaming
NodeRunStartedEvent, # if_else evaluates conditions
NodeRunStreamChunkEvent, # cached llm_initial chunk 1 flushed
NodeRunStreamChunkEvent, # cached llm_initial chunk 2 flushed
NodeRunStreamChunkEvent, # cached llm_initial final chunk flushed
NodeRunStreamChunkEvent, # template literal newline emitted
NodeRunSucceededEvent, # if_else completes branch selection
NodeRunStartedEvent, # llm_secondary begins streaming
NodeRunStreamChunkEvent, # llm_secondary chunk 1
NodeRunStreamChunkEvent, # llm_secondary final chunk
NodeRunSucceededEvent, # llm_secondary completes
NodeRunStartedEvent, # end_secondary node starts
NodeRunSucceededEvent, # end_secondary finishes aggregation
GraphRunSucceededEvent, # graph run succeeds
],
"expected_chunks": [
("llm_initial", _expected_mock_llm_chunks("Initial stream")),
("end_secondary", ["\n"]),
("llm_secondary", _expected_mock_llm_chunks("Secondary")),
],
},
]
for scenario in scenarios:
runner = TableTestRunner()
def graph_factory(
branch_value: str = scenario["branch"],
cfg: MockConfig = mock_config,
) -> tuple[Graph, GraphRuntimeState]:
return _build_if_else_graph(branch_value, cfg)
test_case = WorkflowTestCase(
description=f"IfElse streaming via {scenario['branch']} branch",
graph_factory=graph_factory,
expected_event_sequence=scenario["expected_sequence"],
)
result = runner.run_test_case(test_case)
assert result.success, result.event_mismatch_details
chunk_events = [event for event in result.events if isinstance(event, NodeRunStreamChunkEvent)]
expected_nodes: list[str] = []
expected_chunks: list[str] = []
for node_id, chunks in scenario["expected_chunks"]:
expected_nodes.extend([node_id] * len(chunks))
expected_chunks.extend(chunks)
assert [event.node_id for event in chunk_events] == expected_nodes
assert [event.chunk for event in chunk_events] == expected_chunks
branch_node_index = next(
index
for index, event in enumerate(result.events)
if isinstance(event, NodeRunStartedEvent) and event.node_id == "if_else"
)
branch_success_index = next(
index
for index, event in enumerate(result.events)
if isinstance(event, NodeRunSucceededEvent) and event.node_id == "if_else"
)
pre_branch_chunk_indices = [
index
for index, event in enumerate(result.events)
if isinstance(event, NodeRunStreamChunkEvent) and index < branch_success_index
]
assert len(pre_branch_chunk_indices) == len(_expected_mock_llm_chunks("Initial stream")) + 1
assert min(pre_branch_chunk_indices) == branch_node_index + 1
assert max(pre_branch_chunk_indices) < branch_success_index
resume_chunk_indices = [
index
for index, event in enumerate(result.events)
if isinstance(event, NodeRunStreamChunkEvent) and event.node_id == scenario["resume_llm"]
]
assert resume_chunk_indices
resume_start_index = next(
index
for index, event in enumerate(result.events)
if isinstance(event, NodeRunStartedEvent) and event.node_id == scenario["resume_llm"]
)
resume_success_index = next(
index
for index, event in enumerate(result.events)
if isinstance(event, NodeRunSucceededEvent) and event.node_id == scenario["resume_llm"]
)
assert resume_start_index < min(resume_chunk_indices)
assert max(resume_chunk_indices) < resume_success_index
started_nodes = [event.node_id for event in result.events if isinstance(event, NodeRunStartedEvent)]
assert started_nodes == ["start", "llm_initial", "if_else", scenario["resume_llm"], scenario["end_node"]]

View File

@@ -27,7 +27,8 @@ from .test_mock_nodes import (
)
if TYPE_CHECKING:
from core.workflow.entities import GraphInitParams, GraphRuntimeState
from core.workflow.entities import GraphInitParams
from core.workflow.runtime import GraphRuntimeState
from .test_mock_config import MockConfig

View File

@@ -42,7 +42,8 @@ def test_mock_iteration_node_preserves_config():
"""Test that MockIterationNode preserves mock configuration."""
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool
from core.workflow.entities import GraphInitParams
from core.workflow.runtime import GraphRuntimeState, VariablePool
from models.enums import UserFrom
from tests.unit_tests.core.workflow.graph_engine.test_mock_nodes import MockIterationNode
@@ -103,7 +104,8 @@ def test_mock_loop_node_preserves_config():
"""Test that MockLoopNode preserves mock configuration."""
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool
from core.workflow.entities import GraphInitParams
from core.workflow.runtime import GraphRuntimeState, VariablePool
from models.enums import UserFrom
from tests.unit_tests.core.workflow.graph_engine.test_mock_nodes import MockLoopNode

View File

@@ -24,7 +24,8 @@ from core.workflow.nodes.template_transform import TemplateTransformNode
from core.workflow.nodes.tool import ToolNode
if TYPE_CHECKING:
from core.workflow.entities import GraphInitParams, GraphRuntimeState
from core.workflow.entities import GraphInitParams
from core.workflow.runtime import GraphRuntimeState
from .test_mock_config import MockConfig
@@ -561,10 +562,11 @@ class MockIterationNode(MockNodeMixin, IterationNode):
def _create_graph_engine(self, index: int, item: Any):
"""Create a graph engine with MockNodeFactory instead of DifyNodeFactory."""
# Import dependencies
from core.workflow.entities import GraphInitParams, GraphRuntimeState
from core.workflow.entities import GraphInitParams
from core.workflow.graph import Graph
from core.workflow.graph_engine import GraphEngine
from core.workflow.graph_engine.command_channels import InMemoryChannel
from core.workflow.runtime import GraphRuntimeState
# Import our MockNodeFactory instead of DifyNodeFactory
from .test_mock_factory import MockNodeFactory
@@ -635,10 +637,11 @@ class MockLoopNode(MockNodeMixin, LoopNode):
def _create_graph_engine(self, start_at, root_node_id: str):
"""Create a graph engine with MockNodeFactory instead of DifyNodeFactory."""
# Import dependencies
from core.workflow.entities import GraphInitParams, GraphRuntimeState
from core.workflow.entities import GraphInitParams
from core.workflow.graph import Graph
from core.workflow.graph_engine import GraphEngine
from core.workflow.graph_engine.command_channels import InMemoryChannel
from core.workflow.runtime import GraphRuntimeState
# Import our MockNodeFactory instead of DifyNodeFactory
from .test_mock_factory import MockNodeFactory

View File

@@ -16,8 +16,8 @@ class TestMockTemplateTransformNode:
def test_mock_template_transform_node_default_output(self):
"""Test that MockTemplateTransformNode processes templates with Jinja2."""
from core.workflow.entities import GraphInitParams, GraphRuntimeState
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities import GraphInitParams
from core.workflow.runtime import GraphRuntimeState, VariablePool
# Create test parameters
graph_init_params = GraphInitParams(
@@ -76,8 +76,8 @@ class TestMockTemplateTransformNode:
def test_mock_template_transform_node_custom_output(self):
"""Test that MockTemplateTransformNode returns custom configured output."""
from core.workflow.entities import GraphInitParams, GraphRuntimeState
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities import GraphInitParams
from core.workflow.runtime import GraphRuntimeState, VariablePool
# Create test parameters
graph_init_params = GraphInitParams(
@@ -137,8 +137,8 @@ class TestMockTemplateTransformNode:
def test_mock_template_transform_node_error_simulation(self):
"""Test that MockTemplateTransformNode can simulate errors."""
from core.workflow.entities import GraphInitParams, GraphRuntimeState
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities import GraphInitParams
from core.workflow.runtime import GraphRuntimeState, VariablePool
# Create test parameters
graph_init_params = GraphInitParams(
@@ -196,8 +196,8 @@ class TestMockTemplateTransformNode:
def test_mock_template_transform_node_with_variables(self):
"""Test that MockTemplateTransformNode processes templates with variables."""
from core.variables import StringVariable
from core.workflow.entities import GraphInitParams, GraphRuntimeState
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities import GraphInitParams
from core.workflow.runtime import GraphRuntimeState, VariablePool
# Create test parameters
graph_init_params = GraphInitParams(
@@ -262,8 +262,8 @@ class TestMockCodeNode:
def test_mock_code_node_default_output(self):
"""Test that MockCodeNode returns default output."""
from core.workflow.entities import GraphInitParams, GraphRuntimeState
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities import GraphInitParams
from core.workflow.runtime import GraphRuntimeState, VariablePool
# Create test parameters
graph_init_params = GraphInitParams(
@@ -323,8 +323,8 @@ class TestMockCodeNode:
def test_mock_code_node_with_output_schema(self):
"""Test that MockCodeNode generates outputs based on schema."""
from core.workflow.entities import GraphInitParams, GraphRuntimeState
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities import GraphInitParams
from core.workflow.runtime import GraphRuntimeState, VariablePool
# Create test parameters
graph_init_params = GraphInitParams(
@@ -392,8 +392,8 @@ class TestMockCodeNode:
def test_mock_code_node_custom_output(self):
"""Test that MockCodeNode returns custom configured output."""
from core.workflow.entities import GraphInitParams, GraphRuntimeState
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities import GraphInitParams
from core.workflow.runtime import GraphRuntimeState, VariablePool
# Create test parameters
graph_init_params = GraphInitParams(
@@ -463,8 +463,8 @@ class TestMockNodeFactory:
def test_code_and_template_nodes_mocked_by_default(self):
"""Test that CODE and TEMPLATE_TRANSFORM nodes are mocked by default (they require SSRF proxy)."""
from core.workflow.entities import GraphInitParams, GraphRuntimeState
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities import GraphInitParams
from core.workflow.runtime import GraphRuntimeState, VariablePool
# Create test parameters
graph_init_params = GraphInitParams(
@@ -504,8 +504,8 @@ class TestMockNodeFactory:
def test_factory_creates_mock_template_transform_node(self):
"""Test that MockNodeFactory creates MockTemplateTransformNode for template-transform type."""
from core.workflow.entities import GraphInitParams, GraphRuntimeState
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities import GraphInitParams
from core.workflow.runtime import GraphRuntimeState, VariablePool
# Create test parameters
graph_init_params = GraphInitParams(
@@ -555,8 +555,8 @@ class TestMockNodeFactory:
def test_factory_creates_mock_code_node(self):
"""Test that MockNodeFactory creates MockCodeNode for code type."""
from core.workflow.entities import GraphInitParams, GraphRuntimeState
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities import GraphInitParams
from core.workflow.runtime import GraphRuntimeState, VariablePool
# Create test parameters
graph_init_params = GraphInitParams(

View File

@@ -13,7 +13,7 @@ from unittest.mock import patch
from uuid import uuid4
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool
from core.workflow.entities import GraphInitParams
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
from core.workflow.graph import Graph
from core.workflow.graph_engine import GraphEngine
@@ -27,6 +27,7 @@ from core.workflow.graph_events import (
from core.workflow.node_events import NodeRunResult, StreamCompletedEvent
from core.workflow.nodes.llm.node import LLMNode
from core.workflow.nodes.node_factory import DifyNodeFactory
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
from models.enums import UserFrom

View File

@@ -13,7 +13,7 @@ import redis
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
from core.workflow.graph_engine.entities.commands import AbortCommand, CommandType
from core.workflow.graph_engine.entities.commands import AbortCommand, CommandType, PauseCommand
from core.workflow.graph_engine.manager import GraphEngineManager
@@ -52,6 +52,29 @@ class TestRedisStopIntegration:
assert command_data["command_type"] == CommandType.ABORT
assert command_data["reason"] == "Test stop"
def test_graph_engine_manager_sends_pause_command(self):
"""Test that GraphEngineManager correctly sends pause command through Redis."""
task_id = "test-task-pause-123"
expected_channel_key = f"workflow:{task_id}:commands"
mock_redis = MagicMock()
mock_pipeline = MagicMock()
mock_redis.pipeline.return_value.__enter__ = Mock(return_value=mock_pipeline)
mock_redis.pipeline.return_value.__exit__ = Mock(return_value=None)
with patch("core.workflow.graph_engine.manager.redis_client", mock_redis):
GraphEngineManager.send_pause_command(task_id, reason="Awaiting resources")
mock_redis.pipeline.assert_called_once()
calls = mock_pipeline.rpush.call_args_list
assert len(calls) == 1
assert calls[0][0][0] == expected_channel_key
command_json = calls[0][0][1]
command_data = json.loads(command_json)
assert command_data["command_type"] == CommandType.PAUSE.value
assert command_data["reason"] == "Awaiting resources"
def test_graph_engine_manager_handles_redis_failure_gracefully(self):
"""Test that GraphEngineManager handles Redis failures without raising exceptions."""
task_id = "test-task-456"
@@ -105,28 +128,37 @@ class TestRedisStopIntegration:
channel_key = "workflow:test:commands"
channel = RedisChannel(mock_redis, channel_key)
# Create abort command
# Create commands
abort_command = AbortCommand(reason="User requested stop")
pause_command = PauseCommand(reason="User requested pause")
# Execute
channel.send_command(abort_command)
channel.send_command(pause_command)
# Verify
mock_redis.pipeline.assert_called_once()
mock_redis.pipeline.assert_called()
# Check rpush was called
calls = mock_pipeline.rpush.call_args_list
assert len(calls) == 1
assert len(calls) == 2
assert calls[0][0][0] == channel_key
assert calls[1][0][0] == channel_key
# Verify serialized command
command_json = calls[0][0][1]
command_data = json.loads(command_json)
assert command_data["command_type"] == CommandType.ABORT
assert command_data["reason"] == "User requested stop"
# Verify serialized commands
abort_command_json = calls[0][0][1]
abort_command_data = json.loads(abort_command_json)
assert abort_command_data["command_type"] == CommandType.ABORT.value
assert abort_command_data["reason"] == "User requested stop"
# Check expire was set
mock_pipeline.expire.assert_called_once_with(channel_key, 3600)
pause_command_json = calls[1][0][1]
pause_command_data = json.loads(pause_command_json)
assert pause_command_data["command_type"] == CommandType.PAUSE.value
assert pause_command_data["reason"] == "User requested pause"
# Check expire was set for each
assert mock_pipeline.expire.call_count == 2
mock_pipeline.expire.assert_any_call(channel_key, 3600)
def test_redis_channel_fetch_commands(self):
"""Test RedisChannel correctly fetches and deserializes commands."""
@@ -143,12 +175,17 @@ class TestRedisStopIntegration:
mock_redis.pipeline.side_effect = [pending_context, fetch_context]
# Mock command data
abort_command_json = json.dumps({"command_type": CommandType.ABORT, "reason": "Test abort", "payload": None})
abort_command_json = json.dumps(
{"command_type": CommandType.ABORT.value, "reason": "Test abort", "payload": None}
)
pause_command_json = json.dumps(
{"command_type": CommandType.PAUSE.value, "reason": "Pause requested", "payload": None}
)
# Mock pipeline execute to return commands
pending_pipe.execute.return_value = [b"1", 1]
fetch_pipe.execute.return_value = [
[abort_command_json.encode()], # lrange result
[abort_command_json.encode(), pause_command_json.encode()], # lrange result
True, # delete result
]
@@ -159,10 +196,13 @@ class TestRedisStopIntegration:
commands = channel.fetch_commands()
# Verify
assert len(commands) == 1
assert len(commands) == 2
assert isinstance(commands[0], AbortCommand)
assert commands[0].command_type == CommandType.ABORT
assert commands[0].reason == "Test abort"
assert isinstance(commands[1], PauseCommand)
assert commands[1].command_type == CommandType.PAUSE
assert commands[1].reason == "Pause requested"
# Verify Redis operations
pending_pipe.get.assert_called_once_with(f"{channel_key}:pending")

Some files were not shown because too many files have changed in this diff Show More