Mirror of https://github.com/langgenius/dify.git
feat(graph_engine): Support pausing workflow graph executions (#26585)
Signed-off-by: -LAN- <laipz8200@outlook.com>
@@ -447,6 +447,8 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
                 "message_id": message.id,
                 "context": context,
                 "variable_loader": variable_loader,
+                "workflow_execution_repository": workflow_execution_repository,
+                "workflow_node_execution_repository": workflow_node_execution_repository,
             },
         )
@@ -466,8 +468,6 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
             conversation=conversation,
             message=message,
             user=user,
-            workflow_execution_repository=workflow_execution_repository,
-            workflow_node_execution_repository=workflow_node_execution_repository,
             stream=stream,
             draft_var_saver_factory=self._get_draft_var_saver_factory(invoke_from, account=user),
         )
@@ -483,6 +483,8 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
         message_id: str,
         context: contextvars.Context,
         variable_loader: VariableLoader,
+        workflow_execution_repository: WorkflowExecutionRepository,
+        workflow_node_execution_repository: WorkflowNodeExecutionRepository,
     ):
         """
         Generate worker in a new thread.
@@ -538,6 +540,8 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
             workflow=workflow,
             system_user_id=system_user_id,
             app=app,
+            workflow_execution_repository=workflow_execution_repository,
+            workflow_node_execution_repository=workflow_node_execution_repository,
         )

         try:
@@ -570,8 +574,6 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
         conversation: Conversation,
         message: Message,
         user: Union[Account, EndUser],
-        workflow_execution_repository: WorkflowExecutionRepository,
-        workflow_node_execution_repository: WorkflowNodeExecutionRepository,
         draft_var_saver_factory: DraftVariableSaverFactory,
         stream: bool = False,
     ) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]:
@@ -584,7 +586,6 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
         :param message: message
         :param user: account or end user
         :param stream: is stream
-        :param workflow_node_execution_repository: optional repository for workflow node execution
         :return:
         """
         # init generate task pipeline
@@ -596,8 +597,6 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
             message=message,
             user=user,
             dialogue_count=self._dialogue_count,
-            workflow_execution_repository=workflow_execution_repository,
-            workflow_node_execution_repository=workflow_node_execution_repository,
             stream=stream,
             draft_var_saver_factory=draft_var_saver_factory,
         )
@@ -23,8 +23,12 @@ from core.app.features.annotation_reply.annotation_reply import AnnotationReplyF
 from core.moderation.base import ModerationError
 from core.moderation.input_moderation import InputModeration
 from core.variables.variables import VariableUnion
-from core.workflow.entities import GraphRuntimeState, VariablePool
 from core.workflow.enums import WorkflowType
+from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
+from core.workflow.graph_engine.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer
+from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
+from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
+from core.workflow.runtime import GraphRuntimeState, VariablePool
 from core.workflow.system_variable import SystemVariable
 from core.workflow.variable_loader import VariableLoader
 from core.workflow.workflow_entry import WorkflowEntry
@@ -55,6 +59,8 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
         workflow: Workflow,
         system_user_id: str,
         app: App,
+        workflow_execution_repository: WorkflowExecutionRepository,
+        workflow_node_execution_repository: WorkflowNodeExecutionRepository,
     ):
         super().__init__(
             queue_manager=queue_manager,
@@ -68,11 +74,24 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
         self._workflow = workflow
         self.system_user_id = system_user_id
         self._app = app
+        self._workflow_execution_repository = workflow_execution_repository
+        self._workflow_node_execution_repository = workflow_node_execution_repository

     def run(self):
         app_config = self.application_generate_entity.app_config
         app_config = cast(AdvancedChatAppConfig, app_config)

+        system_inputs = SystemVariable(
+            query=self.application_generate_entity.query,
+            files=self.application_generate_entity.files,
+            conversation_id=self.conversation.id,
+            user_id=self.system_user_id,
+            dialogue_count=self._dialogue_count,
+            app_id=app_config.app_id,
+            workflow_id=app_config.workflow_id,
+            workflow_execution_id=self.application_generate_entity.workflow_run_id,
+        )
+
         with Session(db.engine, expire_on_commit=False) as session:
             app_record = session.scalar(select(App).where(App.id == app_config.app_id))

@@ -89,7 +108,6 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
         else:
             inputs = self.application_generate_entity.inputs
             query = self.application_generate_entity.query
-            files = self.application_generate_entity.files

         # moderation
         if self.handle_input_moderation(
@@ -114,17 +132,6 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
         conversation_variables = self._initialize_conversation_variables()

         # Create a variable pool.
-        system_inputs = SystemVariable(
-            query=query,
-            files=files,
-            conversation_id=self.conversation.id,
-            user_id=self.system_user_id,
-            dialogue_count=self._dialogue_count,
-            app_id=app_config.app_id,
-            workflow_id=app_config.workflow_id,
-            workflow_execution_id=self.application_generate_entity.workflow_run_id,
-        )
-
         # init variable pool
         variable_pool = VariablePool(
             system_variables=system_inputs,
@@ -172,6 +179,23 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
             command_channel=command_channel,
         )

+        self._queue_manager.graph_runtime_state = graph_runtime_state
+
+        persistence_layer = WorkflowPersistenceLayer(
+            application_generate_entity=self.application_generate_entity,
+            workflow_info=PersistenceWorkflowInfo(
+                workflow_id=self._workflow.id,
+                workflow_type=WorkflowType(self._workflow.type),
+                version=self._workflow.version,
+                graph_data=self._workflow.graph_dict,
+            ),
+            workflow_execution_repository=self._workflow_execution_repository,
+            workflow_node_execution_repository=self._workflow_node_execution_repository,
+            trace_manager=self.application_generate_entity.trace_manager,
+        )
+
+        workflow_entry.graph_engine.layer(persistence_layer)
+
         generator = workflow_entry.run()

         for event in generator:
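
Note: the hunk above is where pausability gets wired in. The runner publishes the live GraphRuntimeState to the queue manager, attaches a persistence layer so progress is durably recorded, and (per the imports added earlier) a RedisChannel command channel is passed to the engine as command_channel. A minimal sketch of how an outside caller could then pause a run follows; PauseCommand and the RedisChannel constructor/send arguments are illustrative assumptions, not APIs confirmed by this diff:

# Hypothetical sketch -- only RedisChannel's import path appears in this diff;
# the constructor signature and PauseCommand are illustrative assumptions.
from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
from extensions.ext_redis import redis_client

# One channel per workflow run, shared by the runner and any API handler.
channel = RedisChannel(redis_client, channel_key=f"workflow:{task_id}:commands")

# From another process: ask the engine to pause. The engine polls the channel
# between node executions and can checkpoint via the persistence layer above.
channel.send_command(PauseCommand())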
@@ -11,6 +11,7 @@ from sqlalchemy.orm import Session

 from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
 from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
+from core.app.apps.common.graph_runtime_state_support import GraphRuntimeStateSupport
 from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter
 from core.app.entities.app_invoke_entities import (
     AdvancedChatAppGenerateEntity,
@@ -60,14 +61,11 @@ from core.app.task_pipeline.message_cycle_manager import MessageCycleManager
 from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
 from core.model_runtime.entities.llm_entities import LLMUsage
 from core.ops.ops_trace_manager import TraceQueueManager
-from core.workflow.entities import GraphRuntimeState
-from core.workflow.enums import WorkflowExecutionStatus, WorkflowType
+from core.workflow.enums import WorkflowExecutionStatus
 from core.workflow.nodes import NodeType
 from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory
-from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
-from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
+from core.workflow.runtime import GraphRuntimeState
 from core.workflow.system_variable import SystemVariable
-from core.workflow.workflow_cycle_manager import CycleManagerWorkflowInfo, WorkflowCycleManager
 from extensions.ext_database import db
 from libs.datetime_utils import naive_utc_now
 from models import Account, Conversation, EndUser, Message, MessageFile
@@ -77,7 +75,7 @@ from models.workflow import Workflow
 logger = logging.getLogger(__name__)


-class AdvancedChatAppGenerateTaskPipeline:
+class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
     """
     AdvancedChatAppGenerateTaskPipeline is a class that generate stream output and state management for Application.
     """
@@ -92,8 +90,6 @@ class AdvancedChatAppGenerateTaskPipeline:
         user: Union[Account, EndUser],
         stream: bool,
         dialogue_count: int,
-        workflow_execution_repository: WorkflowExecutionRepository,
-        workflow_node_execution_repository: WorkflowNodeExecutionRepository,
         draft_var_saver_factory: DraftVariableSaverFactory,
     ):
         self._base_task_pipeline = BasedGenerateTaskPipeline(
@@ -113,31 +109,20 @@ class AdvancedChatAppGenerateTaskPipeline:
         else:
             raise NotImplementedError(f"User type not supported: {type(user)}")

-        self._workflow_cycle_manager = WorkflowCycleManager(
-            application_generate_entity=application_generate_entity,
-            workflow_system_variables=SystemVariable(
-                query=message.query,
-                files=application_generate_entity.files,
-                conversation_id=conversation.id,
-                user_id=user_session_id,
-                dialogue_count=dialogue_count,
-                app_id=application_generate_entity.app_config.app_id,
-                workflow_id=workflow.id,
-                workflow_execution_id=application_generate_entity.workflow_run_id,
-            ),
-            workflow_info=CycleManagerWorkflowInfo(
-                workflow_id=workflow.id,
-                workflow_type=WorkflowType(workflow.type),
-                version=workflow.version,
-                graph_data=workflow.graph_dict,
-            ),
-            workflow_execution_repository=workflow_execution_repository,
-            workflow_node_execution_repository=workflow_node_execution_repository,
-        )
+        self._workflow_system_variables = SystemVariable(
+            query=message.query,
+            files=application_generate_entity.files,
+            conversation_id=conversation.id,
+            user_id=user_session_id,
+            dialogue_count=dialogue_count,
+            app_id=application_generate_entity.app_config.app_id,
+            workflow_id=workflow.id,
+            workflow_execution_id=application_generate_entity.workflow_run_id,
+        )

         self._workflow_response_converter = WorkflowResponseConverter(
             application_generate_entity=application_generate_entity,
             user=user,
+            system_variables=self._workflow_system_variables,
         )

         self._task_state = WorkflowTaskState()
@@ -156,6 +141,8 @@ class AdvancedChatAppGenerateTaskPipeline:
         self._recorded_files: list[Mapping[str, Any]] = []
         self._workflow_run_id: str = ""
         self._draft_var_saver_factory = draft_var_saver_factory
+        self._graph_runtime_state: GraphRuntimeState | None = None
+        self._seed_graph_runtime_state_from_queue_manager()

     def process(self) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]:
         """
@@ -288,12 +275,6 @@ class AdvancedChatAppGenerateTaskPipeline:
         if not self._workflow_run_id:
             raise ValueError("workflow run not initialized.")

-    def _ensure_graph_runtime_initialized(self, graph_runtime_state: GraphRuntimeState | None) -> GraphRuntimeState:
-        """Fluent validation for graph runtime state."""
-        if not graph_runtime_state:
-            raise ValueError("graph runtime state not initialized.")
-        return graph_runtime_state
-
     def _handle_ping_event(self, event: QueuePingEvent, **kwargs) -> Generator[PingStreamResponse, None, None]:
         """Handle ping events."""
         yield self._base_task_pipeline.ping_stream_response()
@@ -304,21 +285,28 @@ class AdvancedChatAppGenerateTaskPipeline:
             err = self._base_task_pipeline.handle_error(event=event, session=session, message_id=self._message_id)
             yield self._base_task_pipeline.error_to_stream_response(err)

-    def _handle_workflow_started_event(self, *args, **kwargs) -> Generator[StreamResponse, None, None]:
+    def _handle_workflow_started_event(
+        self,
+        event: QueueWorkflowStartedEvent,
+        **kwargs,
+    ) -> Generator[StreamResponse, None, None]:
         """Handle workflow started events."""
-        with self._database_session() as session:
-            workflow_execution = self._workflow_cycle_manager.handle_workflow_run_start()
-            self._workflow_run_id = workflow_execution.id_
+        runtime_state = self._resolve_graph_runtime_state()
+        run_id = self._extract_workflow_run_id(runtime_state)
+        self._workflow_run_id = run_id
+
+        with self._database_session() as session:
             message = self._get_message(session=session)
             if not message:
                 raise ValueError(f"Message not found: {self._message_id}")

-            message.workflow_run_id = workflow_execution.id_
-        workflow_start_resp = self._workflow_response_converter.workflow_start_to_stream_response(
-            task_id=self._application_generate_entity.task_id,
-            workflow_execution=workflow_execution,
-        )
+            message.workflow_run_id = run_id
+
+        workflow_start_resp = self._workflow_response_converter.workflow_start_to_stream_response(
+            task_id=self._application_generate_entity.task_id,
+            workflow_run_id=run_id,
+            workflow_id=self._workflow_id,
+        )

         yield workflow_start_resp
@@ -326,13 +314,9 @@ class AdvancedChatAppGenerateTaskPipeline:
         """Handle node retry events."""
         self._ensure_workflow_initialized()

-        workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_retried(
-            workflow_execution_id=self._workflow_run_id, event=event
-        )
         node_retry_resp = self._workflow_response_converter.workflow_node_retry_to_stream_response(
             event=event,
             task_id=self._application_generate_entity.task_id,
-            workflow_node_execution=workflow_node_execution,
         )

         if node_retry_resp:
@@ -344,14 +328,9 @@ class AdvancedChatAppGenerateTaskPipeline:
         """Handle node started events."""
         self._ensure_workflow_initialized()

-        workflow_node_execution = self._workflow_cycle_manager.handle_node_execution_start(
-            workflow_execution_id=self._workflow_run_id, event=event
-        )
-
         node_start_resp = self._workflow_response_converter.workflow_node_start_to_stream_response(
             event=event,
             task_id=self._application_generate_entity.task_id,
-            workflow_node_execution=workflow_node_execution,
         )

         if node_start_resp:
@@ -367,14 +346,12 @@ class AdvancedChatAppGenerateTaskPipeline:
             self._workflow_response_converter.fetch_files_from_node_outputs(event.outputs or {})
         )

-        workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_success(event=event)
         node_finish_resp = self._workflow_response_converter.workflow_node_finish_to_stream_response(
             event=event,
             task_id=self._application_generate_entity.task_id,
-            workflow_node_execution=workflow_node_execution,
         )

-        self._save_output_for_event(event, workflow_node_execution.id)
+        self._save_output_for_event(event, event.node_execution_id)

         if node_finish_resp:
             yield node_finish_resp
@@ -385,16 +362,13 @@ class AdvancedChatAppGenerateTaskPipeline:
         **kwargs,
     ) -> Generator[StreamResponse, None, None]:
         """Handle various node failure events."""
-        workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_failed(event=event)
-
         node_finish_resp = self._workflow_response_converter.workflow_node_finish_to_stream_response(
             event=event,
             task_id=self._application_generate_entity.task_id,
-            workflow_node_execution=workflow_node_execution,
         )

         if isinstance(event, QueueNodeExceptionEvent):
-            self._save_output_for_event(event, workflow_node_execution.id)
+            self._save_output_for_event(event, event.node_execution_id)

         if node_finish_resp:
             yield node_finish_resp
@@ -504,29 +478,19 @@ class AdvancedChatAppGenerateTaskPipeline:
         self,
         event: QueueWorkflowSucceededEvent,
         *,
-        graph_runtime_state: GraphRuntimeState | None = None,
         trace_manager: TraceQueueManager | None = None,
         **kwargs,
     ) -> Generator[StreamResponse, None, None]:
         """Handle workflow succeeded events."""
+        _ = trace_manager
         self._ensure_workflow_initialized()
-        validated_state = self._ensure_graph_runtime_initialized(graph_runtime_state)
-
-        with self._database_session() as session:
-            workflow_execution = self._workflow_cycle_manager.handle_workflow_run_success(
-                workflow_run_id=self._workflow_run_id,
-                total_tokens=validated_state.total_tokens,
-                total_steps=validated_state.node_run_steps,
-                outputs=event.outputs,
-                conversation_id=self._conversation_id,
-                trace_manager=trace_manager,
-                external_trace_id=self._application_generate_entity.extras.get("external_trace_id"),
-            )
-            workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response(
-                session=session,
-                task_id=self._application_generate_entity.task_id,
-                workflow_execution=workflow_execution,
-            )
+        validated_state = self._ensure_graph_runtime_initialized()
+        workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response(
+            task_id=self._application_generate_entity.task_id,
+            workflow_id=self._workflow_id,
+            status=WorkflowExecutionStatus.SUCCEEDED,
+            graph_runtime_state=validated_state,
+        )

         yield workflow_finish_resp
         self._base_task_pipeline.queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE)
@@ -535,30 +499,20 @@ class AdvancedChatAppGenerateTaskPipeline:
         self,
         event: QueueWorkflowPartialSuccessEvent,
         *,
-        graph_runtime_state: GraphRuntimeState | None = None,
         trace_manager: TraceQueueManager | None = None,
         **kwargs,
     ) -> Generator[StreamResponse, None, None]:
         """Handle workflow partial success events."""
+        _ = trace_manager
         self._ensure_workflow_initialized()
-        validated_state = self._ensure_graph_runtime_initialized(graph_runtime_state)
-
-        with self._database_session() as session:
-            workflow_execution = self._workflow_cycle_manager.handle_workflow_run_partial_success(
-                workflow_run_id=self._workflow_run_id,
-                total_tokens=validated_state.total_tokens,
-                total_steps=validated_state.node_run_steps,
-                outputs=event.outputs,
-                exceptions_count=event.exceptions_count,
-                conversation_id=self._conversation_id,
-                trace_manager=trace_manager,
-                external_trace_id=self._application_generate_entity.extras.get("external_trace_id"),
-            )
-            workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response(
-                session=session,
-                task_id=self._application_generate_entity.task_id,
-                workflow_execution=workflow_execution,
-            )
+        validated_state = self._ensure_graph_runtime_initialized()
+        workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response(
+            task_id=self._application_generate_entity.task_id,
+            workflow_id=self._workflow_id,
+            status=WorkflowExecutionStatus.PARTIAL_SUCCEEDED,
+            graph_runtime_state=validated_state,
+            exceptions_count=event.exceptions_count,
+        )

         yield workflow_finish_resp
         self._base_task_pipeline.queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE)
@@ -567,32 +521,25 @@ class AdvancedChatAppGenerateTaskPipeline:
         self,
         event: QueueWorkflowFailedEvent,
         *,
-        graph_runtime_state: GraphRuntimeState | None = None,
         trace_manager: TraceQueueManager | None = None,
         **kwargs,
     ) -> Generator[StreamResponse, None, None]:
         """Handle workflow failed events."""
+        _ = trace_manager
         self._ensure_workflow_initialized()
-        validated_state = self._ensure_graph_runtime_initialized(graph_runtime_state)
+        validated_state = self._ensure_graph_runtime_initialized()

+        workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response(
+            task_id=self._application_generate_entity.task_id,
+            workflow_id=self._workflow_id,
+            status=WorkflowExecutionStatus.FAILED,
+            graph_runtime_state=validated_state,
+            error=event.error,
+            exceptions_count=event.exceptions_count,
+        )
+
         with self._database_session() as session:
-            workflow_execution = self._workflow_cycle_manager.handle_workflow_run_failed(
-                workflow_run_id=self._workflow_run_id,
-                total_tokens=validated_state.total_tokens,
-                total_steps=validated_state.node_run_steps,
-                status=WorkflowExecutionStatus.FAILED,
-                error_message=event.error,
-                conversation_id=self._conversation_id,
-                trace_manager=trace_manager,
-                exceptions_count=event.exceptions_count,
-                external_trace_id=self._application_generate_entity.extras.get("external_trace_id"),
-            )
-            workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response(
-                session=session,
-                task_id=self._application_generate_entity.task_id,
-                workflow_execution=workflow_execution,
-            )
-            err_event = QueueErrorEvent(error=ValueError(f"Run failed: {workflow_execution.error_message}"))
+            err_event = QueueErrorEvent(error=ValueError(f"Run failed: {event.error}"))
             err = self._base_task_pipeline.handle_error(event=err_event, session=session, message_id=self._message_id)

         yield workflow_finish_resp
@@ -607,25 +554,23 @@ class AdvancedChatAppGenerateTaskPipeline:
         **kwargs,
     ) -> Generator[StreamResponse, None, None]:
         """Handle stop events."""
-        if self._workflow_run_id and graph_runtime_state:
+        _ = trace_manager
+        resolved_state = None
+        if self._workflow_run_id:
+            resolved_state = self._resolve_graph_runtime_state(graph_runtime_state)
+
+        if self._workflow_run_id and resolved_state:
+            workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response(
+                task_id=self._application_generate_entity.task_id,
+                workflow_id=self._workflow_id,
+                status=WorkflowExecutionStatus.STOPPED,
+                graph_runtime_state=resolved_state,
+                error=event.get_stop_reason(),
+            )
+
             with self._database_session() as session:
-                workflow_execution = self._workflow_cycle_manager.handle_workflow_run_failed(
-                    workflow_run_id=self._workflow_run_id,
-                    total_tokens=graph_runtime_state.total_tokens,
-                    total_steps=graph_runtime_state.node_run_steps,
-                    status=WorkflowExecutionStatus.STOPPED,
-                    error_message=event.get_stop_reason(),
-                    conversation_id=self._conversation_id,
-                    trace_manager=trace_manager,
-                    external_trace_id=self._application_generate_entity.extras.get("external_trace_id"),
-                )
-                workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response(
-                    session=session,
-                    task_id=self._application_generate_entity.task_id,
-                    workflow_execution=workflow_execution,
-                )
                 # Save message
-                self._save_message(session=session, graph_runtime_state=graph_runtime_state)
+                self._save_message(session=session, graph_runtime_state=resolved_state)

             yield workflow_finish_resp
         elif event.stopped_by in (
@@ -647,7 +592,7 @@ class AdvancedChatAppGenerateTaskPipeline:
         **kwargs,
     ) -> Generator[StreamResponse, None, None]:
         """Handle advanced chat message end events."""
-        self._ensure_graph_runtime_initialized(graph_runtime_state)
+        resolved_state = self._ensure_graph_runtime_initialized(graph_runtime_state)

         output_moderation_answer = self._base_task_pipeline.handle_output_moderation_when_task_finished(
             self._task_state.answer
@@ -661,7 +606,7 @@ class AdvancedChatAppGenerateTaskPipeline:

         # Save message
         with self._database_session() as session:
-            self._save_message(session=session, graph_runtime_state=graph_runtime_state)
+            self._save_message(session=session, graph_runtime_state=resolved_state)

         yield self._message_end_to_stream_response()
@@ -670,10 +615,6 @@ class AdvancedChatAppGenerateTaskPipeline:
     ) -> Generator[StreamResponse, None, None]:
         """Handle retriever resources events."""
         self._message_cycle_manager.handle_retriever_resources(event)

-        with self._database_session() as session:
-            message = self._get_message(session=session)
-            message.message_metadata = self._task_state.metadata.model_dump_json()
-
         return
         yield  # Make this a generator
@@ -682,10 +623,6 @@ class AdvancedChatAppGenerateTaskPipeline:
     ) -> Generator[StreamResponse, None, None]:
         """Handle annotation reply events."""
         self._message_cycle_manager.handle_annotation_reply(event)

-        with self._database_session() as session:
-            message = self._get_message(session=session)
-            message.message_metadata = self._task_state.metadata.model_dump_json()
-
         return
         yield  # Make this a generator
@@ -739,7 +676,6 @@ class AdvancedChatAppGenerateTaskPipeline:
         self,
         event: Any,
         *,
-        graph_runtime_state: GraphRuntimeState | None = None,
         tts_publisher: AppGeneratorTTSPublisher | None = None,
         trace_manager: TraceQueueManager | None = None,
         queue_message: Union[WorkflowQueueMessage, MessageQueueMessage] | None = None,
@@ -752,7 +688,6 @@ class AdvancedChatAppGenerateTaskPipeline:
         if handler := handlers.get(event_type):
             yield from handler(
                 event,
-                graph_runtime_state=graph_runtime_state,
                 tts_publisher=tts_publisher,
                 trace_manager=trace_manager,
                 queue_message=queue_message,
@@ -769,7 +704,6 @@ class AdvancedChatAppGenerateTaskPipeline:
         ):
             yield from self._handle_node_failed_events(
                 event,
-                graph_runtime_state=graph_runtime_state,
                 tts_publisher=tts_publisher,
                 trace_manager=trace_manager,
                 queue_message=queue_message,
@@ -788,15 +722,12 @@ class AdvancedChatAppGenerateTaskPipeline:
         Process stream response using elegant Fluent Python patterns.
         Maintains exact same functionality as original 57-if-statement version.
         """
-        # Initialize graph runtime state
-        graph_runtime_state: GraphRuntimeState | None = None
-
         for queue_message in self._base_task_pipeline.queue_manager.listen():
             event = queue_message.event

             match event:
                 case QueueWorkflowStartedEvent():
-                    graph_runtime_state = event.graph_runtime_state
+                    self._resolve_graph_runtime_state()
                     yield from self._handle_workflow_started_event(event)

                 case QueueErrorEvent():
@@ -804,15 +735,11 @@ class AdvancedChatAppGenerateTaskPipeline:
                     break

                 case QueueWorkflowFailedEvent():
-                    yield from self._handle_workflow_failed_event(
-                        event, graph_runtime_state=graph_runtime_state, trace_manager=trace_manager
-                    )
+                    yield from self._handle_workflow_failed_event(event, trace_manager=trace_manager)
                     break

                 case QueueStopEvent():
-                    yield from self._handle_stop_event(
-                        event, graph_runtime_state=graph_runtime_state, trace_manager=trace_manager
-                    )
+                    yield from self._handle_stop_event(event, graph_runtime_state=None, trace_manager=trace_manager)
                     break

                 # Handle all other events through elegant dispatch
@@ -820,7 +747,6 @@ class AdvancedChatAppGenerateTaskPipeline:
                 if responses := list(
                     self._dispatch_event(
                         event,
-                        graph_runtime_state=graph_runtime_state,
                         tts_publisher=tts_publisher,
                         trace_manager=trace_manager,
                         queue_message=queue_message,
@@ -878,6 +804,12 @@ class AdvancedChatAppGenerateTaskPipeline:
         else:
             self._task_state.metadata.usage = LLMUsage.empty_usage()

+    def _seed_graph_runtime_state_from_queue_manager(self) -> None:
+        """Bootstrap the cached runtime state from the queue manager when present."""
+        candidate = self._base_task_pipeline.queue_manager.graph_runtime_state
+        if candidate is not None:
+            self._graph_runtime_state = candidate
+
     def _message_end_to_stream_response(self) -> MessageEndStreamResponse:
         """
         Message end to stream response.
@@ -20,6 +20,7 @@ from core.app.entities.queue_entities import (
     QueueStopEvent,
     WorkflowQueueMessage,
 )
+from core.workflow.runtime import GraphRuntimeState
 from extensions.ext_redis import redis_client

 logger = logging.getLogger(__name__)
@@ -47,6 +48,7 @@ class AppQueueManager:
         q: queue.Queue[WorkflowQueueMessage | MessageQueueMessage | None] = queue.Queue()

         self._q = q
+        self._graph_runtime_state: GraphRuntimeState | None = None
         self._stopped_cache: TTLCache[tuple, bool] = TTLCache(maxsize=1, ttl=1)
         self._cache_lock = threading.Lock()
@@ -109,6 +111,16 @@ class AppQueueManager:
         """
         self.publish(QueueErrorEvent(error=e), pub_from)

+    @property
+    def graph_runtime_state(self) -> GraphRuntimeState | None:
+        """Retrieve the attached graph runtime state, if available."""
+        return self._graph_runtime_state
+
+    @graph_runtime_state.setter
+    def graph_runtime_state(self, graph_runtime_state: GraphRuntimeState | None) -> None:
+        """Attach the live graph runtime state reference for downstream consumers."""
+        self._graph_runtime_state = graph_runtime_state
+
     def publish(self, event: AppQueueEvent, pub_from: PublishFrom):
         """
         Publish event to queue
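
These accessors are the handoff between the worker thread driving the graph and the task pipeline streaming responses. A condensed sketch of both sides, with queue_manager standing in for the shared AppQueueManager instance (variable names are illustrative):

# Producer side -- the runner thread, as in the AdvancedChatAppRunner hunk above:
queue_manager.graph_runtime_state = graph_runtime_state  # attach live state

# Consumer side -- the task pipeline, mirroring _seed_graph_runtime_state_from_queue_manager:
candidate = queue_manager.graph_runtime_state
if candidate is not None:
    cached_state = candidate  # pipeline can now read total_tokens, node_run_steps, outputs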
api/core/app/apps/common/graph_runtime_state_support.py (new file, 55 lines)
@@ -0,0 +1,55 @@
+"""Shared helpers for managing GraphRuntimeState across task pipelines."""
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+from core.workflow.runtime import GraphRuntimeState
+
+if TYPE_CHECKING:
+    from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
+
+
+class GraphRuntimeStateSupport:
+    """
+    Mixin that centralises common GraphRuntimeState access patterns used by task pipelines.
+
+    Subclasses are expected to provide:
+    * `_base_task_pipeline` – exposing the queue manager with an optional cached runtime state.
+    * `_graph_runtime_state` attribute used as the local cache for the runtime state.
+    """
+
+    _base_task_pipeline: BasedGenerateTaskPipeline
+    _graph_runtime_state: GraphRuntimeState | None = None
+
+    def _ensure_graph_runtime_initialized(
+        self,
+        graph_runtime_state: GraphRuntimeState | None = None,
+    ) -> GraphRuntimeState:
+        """Validate and return the active graph runtime state."""
+        return self._resolve_graph_runtime_state(graph_runtime_state)
+
+    def _extract_workflow_run_id(self, graph_runtime_state: GraphRuntimeState) -> str:
+        system_variables = graph_runtime_state.variable_pool.system_variables
+        if not system_variables or not system_variables.workflow_execution_id:
+            raise ValueError("workflow_execution_id missing from runtime state")
+        return str(system_variables.workflow_execution_id)
+
+    def _resolve_graph_runtime_state(
+        self,
+        graph_runtime_state: GraphRuntimeState | None = None,
+    ) -> GraphRuntimeState:
+        """Return the cached runtime state or bootstrap it from the queue manager."""
+        if graph_runtime_state is not None:
+            self._graph_runtime_state = graph_runtime_state
+            return graph_runtime_state
+
+        if self._graph_runtime_state is None:
+            candidate = self._base_task_pipeline.queue_manager.graph_runtime_state
+            if candidate is not None:
+                self._graph_runtime_state = candidate
+
+        if self._graph_runtime_state is None:
+            raise ValueError("graph runtime state not initialized.")
+
+        return self._graph_runtime_state
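
For orientation, a minimal sketch of how a pipeline consumes this mixin; DemoPipeline is a made-up subclass (the real consumer in this commit is AdvancedChatAppGenerateTaskPipeline):

# Hypothetical subclass; mirrors how the pipeline in this commit uses the mixin.
class DemoPipeline(GraphRuntimeStateSupport):
    def __init__(self, base_task_pipeline):
        self._base_task_pipeline = base_task_pipeline
        self._graph_runtime_state = None

    def on_workflow_started(self):
        # Lazily resolve the state (falling back to the queue manager) and cache it.
        state = self._resolve_graph_runtime_state()
        # The workflow run id now comes from the state's system variables.
        return self._extract_workflow_run_id(state)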
@@ -1,9 +1,8 @@
 import time
 from collections.abc import Mapping, Sequence
-from datetime import UTC, datetime
-from typing import Any, Union
-
-from sqlalchemy.orm import Session
+from dataclasses import dataclass
+from datetime import datetime
+from typing import Any, NewType, Union

 from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity
 from core.app.entities.queue_entities import (
@@ -39,16 +38,36 @@ from core.plugin.impl.datasource import PluginDatasourceManager
 from core.tools.entities.tool_entities import ToolProviderType
 from core.tools.tool_manager import ToolManager
 from core.variables.segments import ArrayFileSegment, FileSegment, Segment
-from core.workflow.entities import WorkflowExecution, WorkflowNodeExecution
-from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
+from core.workflow.enums import (
+    NodeType,
+    SystemVariableKey,
+    WorkflowExecutionStatus,
+    WorkflowNodeExecutionMetadataKey,
+    WorkflowNodeExecutionStatus,
+)
+from core.workflow.runtime import GraphRuntimeState
+from core.workflow.system_variable import SystemVariable
 from core.workflow.workflow_entry import WorkflowEntry
 from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
-from models import (
-    Account,
-    EndUser,
-)
+from libs.datetime_utils import naive_utc_now
+from models import Account, EndUser
 from services.variable_truncator import VariableTruncator

+NodeExecutionId = NewType("NodeExecutionId", str)
+
+
+@dataclass(slots=True)
+class _NodeSnapshot:
+    """In-memory cache for node metadata between start and completion events."""
+
+    title: str
+    index: int
+    start_at: datetime
+    iteration_id: str = ""
+    """Empty string means the node is not executing inside an iteration."""
+    loop_id: str = ""
+    """Empty string means the node is not executing inside a loop."""
+
+
 class WorkflowResponseConverter:
     def __init__(
@@ -56,37 +75,151 @@ class WorkflowResponseConverter:
         *,
         application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity],
         user: Union[Account, EndUser],
+        system_variables: SystemVariable,
     ):
         self._application_generate_entity = application_generate_entity
         self._user = user
+        self._system_variables = system_variables
+        self._workflow_inputs = self._prepare_workflow_inputs()
+        self._truncator = VariableTruncator.default()
+        self._node_snapshots: dict[NodeExecutionId, _NodeSnapshot] = {}
+        self._workflow_execution_id: str | None = None
+        self._workflow_started_at: datetime | None = None
+
+    # ------------------------------------------------------------------
+    # Workflow lifecycle helpers
+    # ------------------------------------------------------------------
+    def _prepare_workflow_inputs(self) -> Mapping[str, Any]:
+        inputs = dict(self._application_generate_entity.inputs)
+        for field_name, value in self._system_variables.to_dict().items():
+            # TODO(@future-refactor): store system variables separately from user inputs so we don't
+            # need to flatten `sys.*` entries into the input payload just for rerun/export tooling.
+            if field_name == SystemVariableKey.CONVERSATION_ID:
+                # Conversation IDs are session-scoped; omitting them keeps workflow inputs
+                # reusable without pinning new runs to a prior conversation.
+                continue
+            inputs[f"sys.{field_name}"] = value
+        handled = WorkflowEntry.handle_special_values(inputs)
+        return dict(handled or {})
+
+    def _ensure_workflow_run_id(self, workflow_run_id: str | None = None) -> str:
+        """Return the memoized workflow run id, optionally seeding it during start events."""
+        if workflow_run_id is not None:
+            self._workflow_execution_id = workflow_run_id
+        if not self._workflow_execution_id:
+            raise ValueError("workflow_run_id missing before streaming workflow events")
+        return self._workflow_execution_id
+
+    # ------------------------------------------------------------------
+    # Node snapshot helpers
+    # ------------------------------------------------------------------
+    def _store_snapshot(self, event: QueueNodeStartedEvent) -> _NodeSnapshot:
+        snapshot = _NodeSnapshot(
+            title=event.node_title,
+            index=event.node_run_index,
+            start_at=event.start_at,
+            iteration_id=event.in_iteration_id or "",
+            loop_id=event.in_loop_id or "",
+        )
+        node_execution_id = NodeExecutionId(event.node_execution_id)
+        self._node_snapshots[node_execution_id] = snapshot
+        return snapshot
+
+    def _get_snapshot(self, node_execution_id: str) -> _NodeSnapshot | None:
+        return self._node_snapshots.get(NodeExecutionId(node_execution_id))
+
+    def _pop_snapshot(self, node_execution_id: str) -> _NodeSnapshot | None:
+        return self._node_snapshots.pop(NodeExecutionId(node_execution_id), None)
+
+    @staticmethod
+    def _merge_metadata(
+        base_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None,
+        snapshot: _NodeSnapshot | None,
+    ) -> Mapping[WorkflowNodeExecutionMetadataKey, Any] | None:
+        if not base_metadata and not snapshot:
+            return base_metadata
+
+        merged: dict[WorkflowNodeExecutionMetadataKey, Any] = {}
+        if base_metadata:
+            merged.update(base_metadata)
+
+        if snapshot:
+            if snapshot.iteration_id:
+                merged[WorkflowNodeExecutionMetadataKey.ITERATION_ID] = snapshot.iteration_id
+            if snapshot.loop_id:
+                merged[WorkflowNodeExecutionMetadataKey.LOOP_ID] = snapshot.loop_id
+
+        return merged or None
+
+    def _truncate_mapping(
+        self,
+        mapping: Mapping[str, Any] | None,
+    ) -> tuple[Mapping[str, Any] | None, bool]:
+        if mapping is None:
+            return None, False
+        if not mapping:
+            return {}, False
+
+        normalized = WorkflowEntry.handle_special_values(dict(mapping))
+        if normalized is None:
+            return None, False
+
+        truncated, is_truncated = self._truncator.truncate_variable_mapping(dict(normalized))
+        return truncated, is_truncated
+
+    @staticmethod
+    def _encode_outputs(outputs: Mapping[str, Any] | None) -> Mapping[str, Any] | None:
+        if outputs is None:
+            return None
+        converter = WorkflowRuntimeTypeConverter()
+        return converter.to_json_encodable(outputs)
+
     def workflow_start_to_stream_response(
         self,
         *,
         task_id: str,
-        workflow_execution: WorkflowExecution,
+        workflow_run_id: str,
+        workflow_id: str,
     ) -> WorkflowStartStreamResponse:
+        run_id = self._ensure_workflow_run_id(workflow_run_id)
+        started_at = naive_utc_now()
+        self._workflow_started_at = started_at
+
         return WorkflowStartStreamResponse(
             task_id=task_id,
-            workflow_run_id=workflow_execution.id_,
+            workflow_run_id=run_id,
             data=WorkflowStartStreamResponse.Data(
-                id=workflow_execution.id_,
-                workflow_id=workflow_execution.workflow_id,
-                inputs=workflow_execution.inputs,
-                created_at=int(workflow_execution.started_at.timestamp()),
+                id=run_id,
+                workflow_id=workflow_id,
+                inputs=self._workflow_inputs,
+                created_at=int(started_at.timestamp()),
             ),
         )

     def workflow_finish_to_stream_response(
         self,
         *,
-        session: Session,
         task_id: str,
-        workflow_execution: WorkflowExecution,
+        workflow_id: str,
+        status: WorkflowExecutionStatus,
+        graph_runtime_state: GraphRuntimeState,
+        error: str | None = None,
+        exceptions_count: int = 0,
     ) -> WorkflowFinishStreamResponse:
-        created_by = None
+        run_id = self._ensure_workflow_run_id()
+        started_at = self._workflow_started_at
+        if started_at is None:
+            raise ValueError(
+                "workflow_finish_to_stream_response called before workflow_start_to_stream_response",
+            )
+
+        finished_at = naive_utc_now()
+        elapsed_time = (finished_at - started_at).total_seconds()
+
+        outputs_mapping = graph_runtime_state.outputs or {}
+        encoded_outputs = WorkflowRuntimeTypeConverter().to_json_encodable(outputs_mapping)
+
+        created_by: Mapping[str, object] | None
         user = self._user
         if isinstance(user, Account):
             created_by = {
@@ -94,38 +227,29 @@ class WorkflowResponseConverter:
                 "name": user.name,
                 "email": user.email,
             }
-        elif isinstance(user, EndUser):
+        else:
             created_by = {
                 "id": user.id,
                 "user": user.session_id,
             }
-        else:
-            raise NotImplementedError(f"User type not supported: {type(user)}")
-
-        # Handle the case where finished_at is None by using current time as default
-        finished_at_timestamp = (
-            int(workflow_execution.finished_at.timestamp())
-            if workflow_execution.finished_at
-            else int(datetime.now(UTC).timestamp())
-        )

         return WorkflowFinishStreamResponse(
             task_id=task_id,
-            workflow_run_id=workflow_execution.id_,
+            workflow_run_id=run_id,
             data=WorkflowFinishStreamResponse.Data(
-                id=workflow_execution.id_,
-                workflow_id=workflow_execution.workflow_id,
-                status=workflow_execution.status,
-                outputs=WorkflowRuntimeTypeConverter().to_json_encodable(workflow_execution.outputs),
-                error=workflow_execution.error_message,
-                elapsed_time=workflow_execution.elapsed_time,
-                total_tokens=workflow_execution.total_tokens,
-                total_steps=workflow_execution.total_steps,
+                id=run_id,
+                workflow_id=workflow_id,
+                status=status.value,
+                outputs=encoded_outputs,
+                error=error,
+                elapsed_time=elapsed_time,
+                total_tokens=graph_runtime_state.total_tokens,
+                total_steps=graph_runtime_state.node_run_steps,
                 created_by=created_by,
-                created_at=int(workflow_execution.started_at.timestamp()),
-                finished_at=finished_at_timestamp,
-                files=self.fetch_files_from_node_outputs(workflow_execution.outputs),
-                exceptions_count=workflow_execution.exceptions_count,
+                created_at=int(started_at.timestamp()),
+                finished_at=int(finished_at.timestamp()),
+                files=self.fetch_files_from_node_outputs(outputs_mapping),
+                exceptions_count=exceptions_count,
             ),
         )
@@ -134,38 +258,28 @@ class WorkflowResponseConverter:
         *,
         event: QueueNodeStartedEvent,
         task_id: str,
-        workflow_node_execution: WorkflowNodeExecution,
     ) -> NodeStartStreamResponse | None:
-        if workflow_node_execution.node_type in {NodeType.ITERATION, NodeType.LOOP}:
-            return None
-        if not workflow_node_execution.workflow_execution_id:
+        if event.node_type in {NodeType.ITERATION, NodeType.LOOP}:
             return None
+        run_id = self._ensure_workflow_run_id()
+        snapshot = self._store_snapshot(event)

         response = NodeStartStreamResponse(
             task_id=task_id,
-            workflow_run_id=workflow_node_execution.workflow_execution_id,
+            workflow_run_id=run_id,
             data=NodeStartStreamResponse.Data(
-                id=workflow_node_execution.id,
-                node_id=workflow_node_execution.node_id,
-                node_type=workflow_node_execution.node_type,
-                title=workflow_node_execution.title,
-                index=workflow_node_execution.index,
-                predecessor_node_id=workflow_node_execution.predecessor_node_id,
-                inputs=workflow_node_execution.get_response_inputs(),
-                inputs_truncated=workflow_node_execution.inputs_truncated,
-                created_at=int(workflow_node_execution.created_at.timestamp()),
-                parallel_id=event.parallel_id,
-                parallel_start_node_id=event.parallel_start_node_id,
-                parent_parallel_id=event.parent_parallel_id,
-                parent_parallel_start_node_id=event.parent_parallel_start_node_id,
+                id=event.node_execution_id,
+                node_id=event.node_id,
+                node_type=event.node_type,
+                title=snapshot.title,
+                index=snapshot.index,
+                created_at=int(snapshot.start_at.timestamp()),
                 iteration_id=event.in_iteration_id,
                 loop_id=event.in_loop_id,
                 parallel_run_id=event.parallel_mode_run_id,
                 agent_strategy=event.agent_strategy,
             ),
         )

         # extras logic
         if event.node_type == NodeType.TOOL:
             response.data.extras["icon"] = ToolManager.get_tool_icon(
                 tenant_id=self._application_generate_entity.app_config.tenant_id,
@@ -189,41 +303,54 @@ class WorkflowResponseConverter:
         *,
         event: QueueNodeSucceededEvent | QueueNodeFailedEvent | QueueNodeExceptionEvent,
         task_id: str,
-        workflow_node_execution: WorkflowNodeExecution,
     ) -> NodeFinishStreamResponse | None:
-        if workflow_node_execution.node_type in {NodeType.ITERATION, NodeType.LOOP}:
-            return None
-        if not workflow_node_execution.workflow_execution_id:
-            return None
-        if not workflow_node_execution.finished_at:
+        if event.node_type in {NodeType.ITERATION, NodeType.LOOP}:
             return None
+        run_id = self._ensure_workflow_run_id()
+        snapshot = self._pop_snapshot(event.node_execution_id)

-        json_converter = WorkflowRuntimeTypeConverter()
+        start_at = snapshot.start_at if snapshot else event.start_at
+        finished_at = naive_utc_now()
+        elapsed_time = (finished_at - start_at).total_seconds()
+
+        inputs, inputs_truncated = self._truncate_mapping(event.inputs)
+        process_data, process_data_truncated = self._truncate_mapping(event.process_data)
+        encoded_outputs = self._encode_outputs(event.outputs)
+        outputs, outputs_truncated = self._truncate_mapping(encoded_outputs)
+        metadata = self._merge_metadata(event.execution_metadata, snapshot)
+
+        if isinstance(event, QueueNodeSucceededEvent):
+            status = WorkflowNodeExecutionStatus.SUCCEEDED.value
+            error_message = event.error
+        elif isinstance(event, QueueNodeFailedEvent):
+            status = WorkflowNodeExecutionStatus.FAILED.value
+            error_message = event.error
+        else:
+            status = WorkflowNodeExecutionStatus.EXCEPTION.value
+            error_message = event.error

         return NodeFinishStreamResponse(
             task_id=task_id,
-            workflow_run_id=workflow_node_execution.workflow_execution_id,
+            workflow_run_id=run_id,
             data=NodeFinishStreamResponse.Data(
-                id=workflow_node_execution.id,
-                node_id=workflow_node_execution.node_id,
-                node_type=workflow_node_execution.node_type,
-                index=workflow_node_execution.index,
-                title=workflow_node_execution.title,
-                predecessor_node_id=workflow_node_execution.predecessor_node_id,
-                inputs=workflow_node_execution.get_response_inputs(),
-                inputs_truncated=workflow_node_execution.inputs_truncated,
-                process_data=workflow_node_execution.get_response_process_data(),
-                process_data_truncated=workflow_node_execution.process_data_truncated,
-                outputs=json_converter.to_json_encodable(workflow_node_execution.get_response_outputs()),
-                outputs_truncated=workflow_node_execution.outputs_truncated,
-                status=workflow_node_execution.status,
-                error=workflow_node_execution.error,
-                elapsed_time=workflow_node_execution.elapsed_time,
-                execution_metadata=workflow_node_execution.metadata,
-                created_at=int(workflow_node_execution.created_at.timestamp()),
-                finished_at=int(workflow_node_execution.finished_at.timestamp()),
-                files=self.fetch_files_from_node_outputs(workflow_node_execution.outputs or {}),
-                parallel_id=event.parallel_id,
+                id=event.node_execution_id,
+                node_id=event.node_id,
+                node_type=event.node_type,
+                index=snapshot.index if snapshot else 0,
+                title=snapshot.title if snapshot else "",
+                inputs=inputs,
+                inputs_truncated=inputs_truncated,
+                process_data=process_data,
+                process_data_truncated=process_data_truncated,
+                outputs=outputs,
+                outputs_truncated=outputs_truncated,
+                status=status,
+                error=error_message,
+                elapsed_time=elapsed_time,
+                execution_metadata=metadata,
+                created_at=int(start_at.timestamp()),
+                finished_at=int(finished_at.timestamp()),
+                files=self.fetch_files_from_node_outputs(event.outputs or {}),
                 iteration_id=event.in_iteration_id,
                 loop_id=event.in_loop_id,
             ),
@@ -234,44 +361,45 @@ class WorkflowResponseConverter:
         *,
         event: QueueNodeRetryEvent,
         task_id: str,
-        workflow_node_execution: WorkflowNodeExecution,
-    ) -> Union[NodeRetryStreamResponse, NodeFinishStreamResponse] | None:
-        if workflow_node_execution.node_type in {NodeType.ITERATION, NodeType.LOOP}:
-            return None
-        if not workflow_node_execution.workflow_execution_id:
-            return None
-        if not workflow_node_execution.finished_at:
+    ) -> NodeRetryStreamResponse | None:
+        if event.node_type in {NodeType.ITERATION, NodeType.LOOP}:
             return None
+        run_id = self._ensure_workflow_run_id()

-        json_converter = WorkflowRuntimeTypeConverter()
+        snapshot = self._get_snapshot(event.node_execution_id)
+        if snapshot is None:
+            raise AssertionError("node retry event arrived without a stored snapshot")
+        finished_at = naive_utc_now()
+        elapsed_time = (finished_at - event.start_at).total_seconds()
+
+        inputs, inputs_truncated = self._truncate_mapping(event.inputs)
+        process_data, process_data_truncated = self._truncate_mapping(event.process_data)
+        encoded_outputs = self._encode_outputs(event.outputs)
+        outputs, outputs_truncated = self._truncate_mapping(encoded_outputs)
+        metadata = self._merge_metadata(event.execution_metadata, snapshot)

         return NodeRetryStreamResponse(
             task_id=task_id,
-            workflow_run_id=workflow_node_execution.workflow_execution_id,
+            workflow_run_id=run_id,
             data=NodeRetryStreamResponse.Data(
-                id=workflow_node_execution.id,
-                node_id=workflow_node_execution.node_id,
-                node_type=workflow_node_execution.node_type,
-                index=workflow_node_execution.index,
-                title=workflow_node_execution.title,
-                predecessor_node_id=workflow_node_execution.predecessor_node_id,
-                inputs=workflow_node_execution.get_response_inputs(),
-                inputs_truncated=workflow_node_execution.inputs_truncated,
-                process_data=workflow_node_execution.get_response_process_data(),
-                process_data_truncated=workflow_node_execution.process_data_truncated,
-                outputs=json_converter.to_json_encodable(workflow_node_execution.get_response_outputs()),
-                outputs_truncated=workflow_node_execution.outputs_truncated,
-                status=workflow_node_execution.status,
-                error=workflow_node_execution.error,
-                elapsed_time=workflow_node_execution.elapsed_time,
-                execution_metadata=workflow_node_execution.metadata,
-                created_at=int(workflow_node_execution.created_at.timestamp()),
-                finished_at=int(workflow_node_execution.finished_at.timestamp()),
-                files=self.fetch_files_from_node_outputs(workflow_node_execution.outputs or {}),
-                parallel_id=event.parallel_id,
-                parallel_start_node_id=event.parallel_start_node_id,
-                parent_parallel_id=event.parent_parallel_id,
-                parent_parallel_start_node_id=event.parent_parallel_start_node_id,
+                id=event.node_execution_id,
+                node_id=event.node_id,
+                node_type=event.node_type,
+                index=snapshot.index,
+                title=snapshot.title,
+                inputs=inputs,
+                inputs_truncated=inputs_truncated,
+                process_data=process_data,
+                process_data_truncated=process_data_truncated,
+                outputs=outputs,
+                outputs_truncated=outputs_truncated,
+                status=WorkflowNodeExecutionStatus.RETRY.value,
+                error=event.error,
+                elapsed_time=elapsed_time,
+                execution_metadata=metadata,
+                created_at=int(snapshot.start_at.timestamp()),
+                finished_at=int(finished_at.timestamp()),
+                files=self.fetch_files_from_node_outputs(event.outputs or {}),
                 iteration_id=event.in_iteration_id,
                 loop_id=event.in_loop_id,
                 retry_index=event.retry_index,
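
Taken together, the start/finish/retry converters above form a handshake keyed by node_execution_id. A condensed sketch of the lifecycle (converter is a WorkflowResponseConverter instance; the event variables stand in for the queue events shown above):

# start: _store_snapshot() caches title/index/start_at under the node_execution_id
converter.workflow_node_start_to_stream_response(event=start_event, task_id=task_id)

# retry: _get_snapshot() reads the cache but leaves it in place (more attempts may follow)
converter.workflow_node_retry_to_stream_response(event=retry_event, task_id=task_id)

# finish: _pop_snapshot() removes the entry; elapsed_time = finished_at - cached start_at
converter.workflow_node_finish_to_stream_response(event=finish_event, task_id=task_id)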
@@ -379,8 +507,6 @@ class WorkflowResponseConverter:
                 inputs=new_inputs,
                 inputs_truncated=truncated,
                 metadata=event.metadata or {},
-                parallel_id=event.parallel_id,
-                parallel_start_node_id=event.parallel_start_node_id,
             ),
         )
@@ -405,9 +531,6 @@ class WorkflowResponseConverter:
                 pre_loop_output={},
                 created_at=int(time.time()),
                 extras={},
-                parallel_id=event.parallel_id,
-                parallel_start_node_id=event.parallel_start_node_id,
-                parallel_mode_run_id=event.parallel_mode_run_id,
             ),
         )
@@ -446,8 +569,6 @@ class WorkflowResponseConverter:
                 execution_metadata=event.metadata,
                 finished_at=int(time.time()),
                 steps=event.steps,
-                parallel_id=event.parallel_id,
-                parallel_start_node_id=event.parallel_start_node_id,
             ),
         )
@@ -352,6 +352,8 @@ class PipelineGenerator(BaseAppGenerator):
                 "application_generate_entity": application_generate_entity,
                 "workflow_thread_pool_id": workflow_thread_pool_id,
                 "variable_loader": variable_loader,
+                "workflow_execution_repository": workflow_execution_repository,
+                "workflow_node_execution_repository": workflow_node_execution_repository,
             },
         )
@@ -367,8 +369,6 @@ class PipelineGenerator(BaseAppGenerator):
             workflow=workflow,
             queue_manager=queue_manager,
             user=user,
-            workflow_execution_repository=workflow_execution_repository,
-            workflow_node_execution_repository=workflow_node_execution_repository,
             stream=streaming,
             draft_var_saver_factory=draft_var_saver_factory,
         )
@@ -573,6 +573,8 @@ class PipelineGenerator(BaseAppGenerator):
         queue_manager: AppQueueManager,
         context: contextvars.Context,
         variable_loader: VariableLoader,
+        workflow_execution_repository: WorkflowExecutionRepository,
+        workflow_node_execution_repository: WorkflowNodeExecutionRepository,
         workflow_thread_pool_id: str | None = None,
     ) -> None:
         """
@@ -620,6 +622,8 @@ class PipelineGenerator(BaseAppGenerator):
|
||||
variable_loader=variable_loader,
|
||||
workflow=workflow,
|
||||
system_user_id=system_user_id,
|
||||
workflow_execution_repository=workflow_execution_repository,
|
||||
workflow_node_execution_repository=workflow_node_execution_repository,
|
||||
)
|
||||
|
||||
runner.run()
|
||||
@@ -648,8 +652,6 @@ class PipelineGenerator(BaseAppGenerator):
|
||||
workflow: Workflow,
|
||||
queue_manager: AppQueueManager,
|
||||
user: Union[Account, EndUser],
|
||||
workflow_execution_repository: WorkflowExecutionRepository,
|
||||
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
|
||||
draft_var_saver_factory: DraftVariableSaverFactory,
|
||||
stream: bool = False,
|
||||
) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
|
||||
@@ -660,7 +662,6 @@ class PipelineGenerator(BaseAppGenerator):
|
||||
:param queue_manager: queue manager
|
||||
:param user: account or end user
|
||||
:param stream: is stream
|
||||
:param workflow_node_execution_repository: optional repository for workflow node execution
|
||||
:return:
|
||||
"""
|
||||
# init generate task pipeline
|
||||
@@ -670,8 +671,6 @@ class PipelineGenerator(BaseAppGenerator):
|
||||
queue_manager=queue_manager,
|
||||
user=user,
|
||||
stream=stream,
|
||||
workflow_node_execution_repository=workflow_node_execution_repository,
|
||||
workflow_execution_repository=workflow_execution_repository,
|
||||
draft_var_saver_factory=draft_var_saver_factory,
|
||||
)
|
||||
|
||||
|
||||
@@ -11,11 +11,14 @@ from core.app.entities.app_invoke_entities import (
)
from core.variables.variables import RAGPipelineVariable, RAGPipelineVariableInput
from core.workflow.entities.graph_init_params import GraphInitParams
from core.workflow.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import WorkflowType
from core.workflow.graph import Graph
from core.workflow.graph_engine.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer
from core.workflow.graph_events import GraphEngineEvent, GraphRunFailedEvent
from core.workflow.nodes.node_factory import DifyNodeFactory
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
from core.workflow.variable_loader import VariableLoader
from core.workflow.workflow_entry import WorkflowEntry
@@ -40,6 +43,8 @@ class PipelineRunner(WorkflowBasedAppRunner):
variable_loader: VariableLoader,
workflow: Workflow,
system_user_id: str,
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
workflow_thread_pool_id: str | None = None,
) -> None:
"""
@@ -56,6 +61,8 @@ class PipelineRunner(WorkflowBasedAppRunner):
self.workflow_thread_pool_id = workflow_thread_pool_id
self._workflow = workflow
self._sys_user_id = system_user_id
self._workflow_execution_repository = workflow_execution_repository
self._workflow_node_execution_repository = workflow_node_execution_repository

def _get_app_id(self) -> str:
return self.application_generate_entity.app_config.app_id
@@ -163,6 +170,23 @@ class PipelineRunner(WorkflowBasedAppRunner):
variable_pool=variable_pool,
)

self._queue_manager.graph_runtime_state = graph_runtime_state

persistence_layer = WorkflowPersistenceLayer(
application_generate_entity=self.application_generate_entity,
workflow_info=PersistenceWorkflowInfo(
workflow_id=workflow.id,
workflow_type=WorkflowType(workflow.type),
version=workflow.version,
graph_data=workflow.graph_dict,
),
workflow_execution_repository=self._workflow_execution_repository,
workflow_node_execution_repository=self._workflow_node_execution_repository,
trace_manager=self.application_generate_entity.trace_manager,
)

workflow_entry.graph_engine.layer(persistence_layer)

generator = workflow_entry.run()

for event in generator:

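Pulling the pieces of this hunk together: the runner seeds the shared queue manager with the GraphRuntimeState, builds a WorkflowPersistenceLayer from the repositories it now receives directly, attaches the layer to the graph engine, and only then runs. A minimal sketch of that sequence, using only names visible in this hunk (the surrounding objects are assumed to be constructed earlier in the runner):

    # Sketch: attach persistence to the engine, then stream events as before.
    persistence_layer = WorkflowPersistenceLayer(
        application_generate_entity=self.application_generate_entity,
        workflow_info=PersistenceWorkflowInfo(
            workflow_id=workflow.id,
            workflow_type=WorkflowType(workflow.type),
            version=workflow.version,
            graph_data=workflow.graph_dict,
        ),
        workflow_execution_repository=self._workflow_execution_repository,
        workflow_node_execution_repository=self._workflow_node_execution_repository,
        trace_manager=self.application_generate_entity.trace_manager,
    )
    workflow_entry.graph_engine.layer(persistence_layer)
    for event in workflow_entry.run():
        ...  # events still flow to the queue manager; persistence now happens in the layer
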
@@ -231,6 +231,8 @@ class WorkflowAppGenerator(BaseAppGenerator):
"queue_manager": queue_manager,
"context": context,
"variable_loader": variable_loader,
"workflow_execution_repository": workflow_execution_repository,
"workflow_node_execution_repository": workflow_node_execution_repository,
},
)

@@ -244,8 +246,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
workflow=workflow,
queue_manager=queue_manager,
user=user,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
draft_var_saver_factory=draft_var_saver_factory,
stream=streaming,
)
@@ -424,6 +424,8 @@ class WorkflowAppGenerator(BaseAppGenerator):
queue_manager: AppQueueManager,
context: contextvars.Context,
variable_loader: VariableLoader,
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
) -> None:
"""
Generate worker in a new thread.
@@ -465,6 +467,8 @@ class WorkflowAppGenerator(BaseAppGenerator):
variable_loader=variable_loader,
workflow=workflow,
system_user_id=system_user_id,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
)

try:
@@ -493,8 +497,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
workflow: Workflow,
queue_manager: AppQueueManager,
user: Union[Account, EndUser],
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
draft_var_saver_factory: DraftVariableSaverFactory,
stream: bool = False,
) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
@@ -514,8 +516,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
workflow=workflow,
queue_manager=queue_manager,
user=user,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
draft_var_saver_factory=draft_var_saver_factory,
stream=stream,
)

@@ -5,12 +5,13 @@ from typing import cast
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.apps.workflow.app_config_manager import WorkflowAppConfig
from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
from core.app.entities.app_invoke_entities import (
InvokeFrom,
WorkflowAppGenerateEntity,
)
from core.workflow.entities import GraphRuntimeState, VariablePool
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
from core.workflow.enums import WorkflowType
from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
from core.workflow.graph_engine.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
from core.workflow.variable_loader import VariableLoader
from core.workflow.workflow_entry import WorkflowEntry
@@ -34,6 +35,8 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
variable_loader: VariableLoader,
workflow: Workflow,
system_user_id: str,
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
):
super().__init__(
queue_manager=queue_manager,
@@ -43,6 +46,8 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
self.application_generate_entity = application_generate_entity
self._workflow = workflow
self._sys_user_id = system_user_id
self._workflow_execution_repository = workflow_execution_repository
self._workflow_node_execution_repository = workflow_node_execution_repository

def run(self):
"""
@@ -51,6 +56,14 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
app_config = self.application_generate_entity.app_config
app_config = cast(WorkflowAppConfig, app_config)

system_inputs = SystemVariable(
files=self.application_generate_entity.files,
user_id=self._sys_user_id,
app_id=app_config.app_id,
workflow_id=app_config.workflow_id,
workflow_execution_id=self.application_generate_entity.workflow_execution_id,
)

# if only single iteration or single loop run is requested
if self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run:
graph, variable_pool, graph_runtime_state = self._prepare_single_node_execution(
@@ -60,18 +73,9 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
)
else:
inputs = self.application_generate_entity.inputs
files = self.application_generate_entity.files

# Create a variable pool.

system_inputs = SystemVariable(
files=files,
user_id=self._sys_user_id,
app_id=app_config.app_id,
workflow_id=app_config.workflow_id,
workflow_execution_id=self.application_generate_entity.workflow_execution_id,
)

variable_pool = VariablePool(
system_variables=system_inputs,
user_inputs=inputs,
@@ -96,6 +100,8 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
channel_key = f"workflow:{task_id}:commands"
command_channel = RedisChannel(redis_client, channel_key)

self._queue_manager.graph_runtime_state = graph_runtime_state

workflow_entry = WorkflowEntry(
tenant_id=self._workflow.tenant_id,
app_id=self._workflow.app_id,
@@ -115,6 +121,21 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
command_channel=command_channel,
)

persistence_layer = WorkflowPersistenceLayer(
application_generate_entity=self.application_generate_entity,
workflow_info=PersistenceWorkflowInfo(
workflow_id=self._workflow.id,
workflow_type=WorkflowType(self._workflow.type),
version=self._workflow.version,
graph_data=self._workflow.graph_dict,
),
workflow_execution_repository=self._workflow_execution_repository,
workflow_node_execution_repository=self._workflow_node_execution_repository,
trace_manager=self.application_generate_entity.trace_manager,
)

workflow_entry.graph_engine.layer(persistence_layer)

generator = workflow_entry.run()

for event in generator:

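The RedisChannel keyed by workflow:{task_id}:commands is the control path that makes pausing possible: any process that can reach the same Redis instance can enqueue a command for a running graph, and the engine drains the channel between steps. A hedged sketch of the sending side; only the key format and the RedisChannel constructor come from this diff, while the command type, its send method, and the redis client import path are assumptions:

    # Sketch: issue a pause command to a running workflow from another process.
    from extensions.ext_redis import redis_client  # assumed import path for the shared client

    def pause_workflow(task_id: str) -> None:
        channel = RedisChannel(redis_client, f"workflow:{task_id}:commands")
        channel.send_command(PauseCommand())  # PauseCommand and send_command are assumed names
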
@@ -8,11 +8,9 @@ from sqlalchemy.orm import Session

from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.apps.common.graph_runtime_state_support import GraphRuntimeStateSupport
from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter
from core.app.entities.app_invoke_entities import (
InvokeFrom,
WorkflowAppGenerateEntity,
)
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
from core.app.entities.queue_entities import (
AppQueueEvent,
MessageQueueMessage,
@@ -53,27 +51,20 @@ from core.app.entities.task_entities import (
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
from core.ops.ops_trace_manager import TraceQueueManager
from core.workflow.entities import GraphRuntimeState, WorkflowExecution
from core.workflow.enums import WorkflowExecutionStatus, WorkflowType
from core.workflow.enums import WorkflowExecutionStatus
from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from core.workflow.runtime import GraphRuntimeState
from core.workflow.system_variable import SystemVariable
from core.workflow.workflow_cycle_manager import CycleManagerWorkflowInfo, WorkflowCycleManager
from extensions.ext_database import db
from models import Account
from models.enums import CreatorUserRole
from models.model import EndUser
from models.workflow import (
Workflow,
WorkflowAppLog,
WorkflowAppLogCreatedFrom,
)
from models.workflow import Workflow, WorkflowAppLog, WorkflowAppLogCreatedFrom

logger = logging.getLogger(__name__)


class WorkflowAppGenerateTaskPipeline:
class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
"""
WorkflowAppGenerateTaskPipeline generates stream output and manages state for the application.
"""
@@ -85,8 +76,6 @@ class WorkflowAppGenerateTaskPipeline:
queue_manager: AppQueueManager,
user: Union[Account, EndUser],
stream: bool,
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
draft_var_saver_factory: DraftVariableSaverFactory,
):
self._base_task_pipeline = BasedGenerateTaskPipeline(
@@ -99,42 +88,30 @@ class WorkflowAppGenerateTaskPipeline:
self._user_id = user.id
user_session_id = user.session_id
self._created_by_role = CreatorUserRole.END_USER
elif isinstance(user, Account):
else:
self._user_id = user.id
user_session_id = user.id
self._created_by_role = CreatorUserRole.ACCOUNT
else:
raise ValueError(f"Invalid user type: {type(user)}")

self._workflow_cycle_manager = WorkflowCycleManager(
application_generate_entity=application_generate_entity,
workflow_system_variables=SystemVariable(
files=application_generate_entity.files,
user_id=user_session_id,
app_id=application_generate_entity.app_config.app_id,
workflow_id=workflow.id,
workflow_execution_id=application_generate_entity.workflow_execution_id,
),
workflow_info=CycleManagerWorkflowInfo(
workflow_id=workflow.id,
workflow_type=WorkflowType(workflow.type),
version=workflow.version,
graph_data=workflow.graph_dict,
),
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
)

self._workflow_response_converter = WorkflowResponseConverter(
application_generate_entity=application_generate_entity,
user=user,
)

self._application_generate_entity = application_generate_entity
self._workflow_features_dict = workflow.features_dict
self._workflow_run_id = ""
self._workflow_execution_id = ""
self._invoke_from = queue_manager.invoke_from
self._draft_var_saver_factory = draft_var_saver_factory
self._workflow = workflow
self._workflow_system_variables = SystemVariable(
files=application_generate_entity.files,
user_id=user_session_id,
app_id=application_generate_entity.app_config.app_id,
workflow_id=workflow.id,
workflow_execution_id=application_generate_entity.workflow_execution_id,
)
self._workflow_response_converter = WorkflowResponseConverter(
application_generate_entity=application_generate_entity,
user=user,
system_variables=self._workflow_system_variables,
)
self._graph_runtime_state: GraphRuntimeState | None = self._base_task_pipeline.queue_manager.graph_runtime_state

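The constructor now seeds self._graph_runtime_state straight from the queue manager instead of waiting for the state to arrive on an event. The handoff between the worker thread and the pipeline reduces to a plain attribute on the shared AppQueueManager, roughly:

    # Sketch of the handoff; both sides hold the same AppQueueManager instance.
    # Runner thread, before starting the graph:
    queue_manager.graph_runtime_state = graph_runtime_state
    # Pipeline side, at construction time (may still be None if the run has
    # not started yet; handlers re-validate later before using it):
    state: GraphRuntimeState | None = queue_manager.graph_runtime_state
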
def process(self) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
"""
@@ -261,15 +238,9 @@ class WorkflowAppGenerateTaskPipeline:

def _ensure_workflow_initialized(self):
"""Fluent validation for workflow state."""
if not self._workflow_run_id:
if not self._workflow_execution_id:
raise ValueError("workflow run not initialized.")

def _ensure_graph_runtime_initialized(self, graph_runtime_state: GraphRuntimeState | None) -> GraphRuntimeState:
"""Fluent validation for graph runtime state."""
if not graph_runtime_state:
raise ValueError("graph runtime state not initialized.")
return graph_runtime_state

def _handle_ping_event(self, event: QueuePingEvent, **kwargs) -> Generator[PingStreamResponse, None, None]:
"""Handle ping events."""
yield self._base_task_pipeline.ping_stream_response()
@@ -283,12 +254,14 @@ class WorkflowAppGenerateTaskPipeline:
self, event: QueueWorkflowStartedEvent, **kwargs
) -> Generator[StreamResponse, None, None]:
"""Handle workflow started events."""
# init workflow run
workflow_execution = self._workflow_cycle_manager.handle_workflow_run_start()
self._workflow_run_id = workflow_execution.id_
runtime_state = self._resolve_graph_runtime_state()

run_id = self._extract_workflow_run_id(runtime_state)
self._workflow_execution_id = run_id
start_resp = self._workflow_response_converter.workflow_start_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_execution=workflow_execution,
workflow_run_id=run_id,
workflow_id=self._workflow.id,
)
yield start_resp

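_resolve_graph_runtime_state() and _extract_workflow_run_id() come from the new GraphRuntimeStateSupport mixin; their bodies are not part of this diff. A plausible reading, given that the runners seed workflow_execution_id into the system variables, is that the run id is read back out of the variable pool. A sketch under that assumption:

    # Assumed shape of the mixin helper; the selector mirrors the
    # SystemVariable(workflow_execution_id=...) seeded by the runner.
    def _extract_workflow_run_id(self, state: GraphRuntimeState) -> str:
        segment = state.variable_pool.get(("sys", "workflow_execution_id"))
        if segment is None:
            raise ValueError("workflow_execution_id missing from system variables")
        return str(segment.value)
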
@@ -296,14 +269,9 @@ class WorkflowAppGenerateTaskPipeline:
"""Handle node retry events."""
self._ensure_workflow_initialized()

workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_retried(
workflow_execution_id=self._workflow_run_id,
event=event,
)
response = self._workflow_response_converter.workflow_node_retry_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)

if response:
@@ -315,13 +283,9 @@ class WorkflowAppGenerateTaskPipeline:
"""Handle node started events."""
self._ensure_workflow_initialized()

workflow_node_execution = self._workflow_cycle_manager.handle_node_execution_start(
workflow_execution_id=self._workflow_run_id, event=event
)
node_start_response = self._workflow_response_converter.workflow_node_start_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)

if node_start_response:
@@ -331,14 +295,12 @@ class WorkflowAppGenerateTaskPipeline:
self, event: QueueNodeSucceededEvent, **kwargs
) -> Generator[StreamResponse, None, None]:
"""Handle node succeeded events."""
workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_success(event=event)
node_success_response = self._workflow_response_converter.workflow_node_finish_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)

self._save_output_for_event(event, workflow_node_execution.id)
self._save_output_for_event(event, event.node_execution_id)

if node_success_response:
yield node_success_response
@@ -349,17 +311,13 @@ class WorkflowAppGenerateTaskPipeline:
**kwargs,
) -> Generator[StreamResponse, None, None]:
"""Handle various node failure events."""
workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_failed(
event=event,
)
node_failed_response = self._workflow_response_converter.workflow_node_finish_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)

if isinstance(event, QueueNodeExceptionEvent):
self._save_output_for_event(event, workflow_node_execution.id)
self._save_output_for_event(event, event.node_execution_id)

if node_failed_response:
yield node_failed_response
@@ -372,7 +330,7 @@ class WorkflowAppGenerateTaskPipeline:

iter_start_resp = self._workflow_response_converter.workflow_iteration_start_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_execution_id=self._workflow_run_id,
workflow_execution_id=self._workflow_execution_id,
event=event,
)
yield iter_start_resp
@@ -385,7 +343,7 @@ class WorkflowAppGenerateTaskPipeline:

iter_next_resp = self._workflow_response_converter.workflow_iteration_next_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_execution_id=self._workflow_run_id,
workflow_execution_id=self._workflow_execution_id,
event=event,
)
yield iter_next_resp
@@ -398,7 +356,7 @@ class WorkflowAppGenerateTaskPipeline:

iter_finish_resp = self._workflow_response_converter.workflow_iteration_completed_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_execution_id=self._workflow_run_id,
workflow_execution_id=self._workflow_execution_id,
event=event,
)
yield iter_finish_resp
@@ -409,7 +367,7 @@ class WorkflowAppGenerateTaskPipeline:

loop_start_resp = self._workflow_response_converter.workflow_loop_start_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_execution_id=self._workflow_run_id,
workflow_execution_id=self._workflow_execution_id,
event=event,
)
yield loop_start_resp
@@ -420,7 +378,7 @@ class WorkflowAppGenerateTaskPipeline:

loop_next_resp = self._workflow_response_converter.workflow_loop_next_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_execution_id=self._workflow_run_id,
workflow_execution_id=self._workflow_execution_id,
event=event,
)
yield loop_next_resp
@@ -433,7 +391,7 @@ class WorkflowAppGenerateTaskPipeline:

loop_finish_resp = self._workflow_response_converter.workflow_loop_completed_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_execution_id=self._workflow_run_id,
workflow_execution_id=self._workflow_execution_id,
event=event,
)
yield loop_finish_resp
@@ -442,33 +400,22 @@ class WorkflowAppGenerateTaskPipeline:
self,
event: QueueWorkflowSucceededEvent,
*,
graph_runtime_state: GraphRuntimeState | None = None,
trace_manager: TraceQueueManager | None = None,
**kwargs,
) -> Generator[StreamResponse, None, None]:
"""Handle workflow succeeded events."""
_ = trace_manager
self._ensure_workflow_initialized()
validated_state = self._ensure_graph_runtime_initialized(graph_runtime_state)
validated_state = self._ensure_graph_runtime_initialized()
workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_id=self._workflow.id,
status=WorkflowExecutionStatus.SUCCEEDED,
graph_runtime_state=validated_state,
)

with self._database_session() as session:
workflow_execution = self._workflow_cycle_manager.handle_workflow_run_success(
workflow_run_id=self._workflow_run_id,
total_tokens=validated_state.total_tokens,
total_steps=validated_state.node_run_steps,
outputs=event.outputs,
conversation_id=None,
trace_manager=trace_manager,
external_trace_id=self._application_generate_entity.extras.get("external_trace_id"),
)

# save workflow app log
self._save_workflow_app_log(session=session, workflow_execution=workflow_execution)

workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_execution=workflow_execution,
)
self._save_workflow_app_log(session=session, workflow_run_id=self._workflow_execution_id)

yield workflow_finish_resp

@@ -476,34 +423,23 @@ class WorkflowAppGenerateTaskPipeline:
self,
event: QueueWorkflowPartialSuccessEvent,
*,
graph_runtime_state: GraphRuntimeState | None = None,
trace_manager: TraceQueueManager | None = None,
**kwargs,
) -> Generator[StreamResponse, None, None]:
"""Handle workflow partial success events."""
_ = trace_manager
self._ensure_workflow_initialized()
validated_state = self._ensure_graph_runtime_initialized(graph_runtime_state)
validated_state = self._ensure_graph_runtime_initialized()
workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_id=self._workflow.id,
status=WorkflowExecutionStatus.PARTIAL_SUCCEEDED,
graph_runtime_state=validated_state,
exceptions_count=event.exceptions_count,
)

with self._database_session() as session:
workflow_execution = self._workflow_cycle_manager.handle_workflow_run_partial_success(
workflow_run_id=self._workflow_run_id,
total_tokens=validated_state.total_tokens,
total_steps=validated_state.node_run_steps,
outputs=event.outputs,
exceptions_count=event.exceptions_count,
conversation_id=None,
trace_manager=trace_manager,
external_trace_id=self._application_generate_entity.extras.get("external_trace_id"),
)

# save workflow app log
self._save_workflow_app_log(session=session, workflow_execution=workflow_execution)

workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_execution=workflow_execution,
)
self._save_workflow_app_log(session=session, workflow_run_id=self._workflow_execution_id)

yield workflow_finish_resp

@@ -511,37 +447,33 @@ class WorkflowAppGenerateTaskPipeline:
self,
event: Union[QueueWorkflowFailedEvent, QueueStopEvent],
*,
graph_runtime_state: GraphRuntimeState | None = None,
trace_manager: TraceQueueManager | None = None,
**kwargs,
) -> Generator[StreamResponse, None, None]:
"""Handle workflow failed and stop events."""
_ = trace_manager
self._ensure_workflow_initialized()
validated_state = self._ensure_graph_runtime_initialized(graph_runtime_state)
validated_state = self._ensure_graph_runtime_initialized()

if isinstance(event, QueueWorkflowFailedEvent):
status = WorkflowExecutionStatus.FAILED
error = event.error
exceptions_count = event.exceptions_count
else:
status = WorkflowExecutionStatus.STOPPED
error = event.get_stop_reason()
exceptions_count = 0
workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_id=self._workflow.id,
status=status,
graph_runtime_state=validated_state,
error=error,
exceptions_count=exceptions_count,
)

with self._database_session() as session:
workflow_execution = self._workflow_cycle_manager.handle_workflow_run_failed(
workflow_run_id=self._workflow_run_id,
total_tokens=validated_state.total_tokens,
total_steps=validated_state.node_run_steps,
status=WorkflowExecutionStatus.FAILED
if isinstance(event, QueueWorkflowFailedEvent)
else WorkflowExecutionStatus.STOPPED,
error_message=event.error if isinstance(event, QueueWorkflowFailedEvent) else event.get_stop_reason(),
conversation_id=None,
trace_manager=trace_manager,
exceptions_count=event.exceptions_count if isinstance(event, QueueWorkflowFailedEvent) else 0,
external_trace_id=self._application_generate_entity.extras.get("external_trace_id"),
)

# save workflow app log
self._save_workflow_app_log(session=session, workflow_execution=workflow_execution)

workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_execution=workflow_execution,
)
self._save_workflow_app_log(session=session, workflow_run_id=self._workflow_execution_id)

yield workflow_finish_resp

@@ -601,7 +533,6 @@ class WorkflowAppGenerateTaskPipeline:
self,
event: AppQueueEvent,
*,
graph_runtime_state: GraphRuntimeState | None = None,
tts_publisher: AppGeneratorTTSPublisher | None = None,
trace_manager: TraceQueueManager | None = None,
queue_message: Union[WorkflowQueueMessage, MessageQueueMessage] | None = None,
@@ -614,7 +545,6 @@ class WorkflowAppGenerateTaskPipeline:
if handler := handlers.get(event_type):
yield from handler(
event,
graph_runtime_state=graph_runtime_state,
tts_publisher=tts_publisher,
trace_manager=trace_manager,
queue_message=queue_message,
@@ -631,7 +561,6 @@ class WorkflowAppGenerateTaskPipeline:
):
yield from self._handle_node_failed_events(
event,
graph_runtime_state=graph_runtime_state,
tts_publisher=tts_publisher,
trace_manager=trace_manager,
queue_message=queue_message,
@@ -642,7 +571,6 @@ class WorkflowAppGenerateTaskPipeline:
if isinstance(event, (QueueWorkflowFailedEvent, QueueStopEvent)):
yield from self._handle_workflow_failed_and_stop_events(
event,
graph_runtime_state=graph_runtime_state,
tts_publisher=tts_publisher,
trace_manager=trace_manager,
queue_message=queue_message,
@@ -661,15 +589,12 @@ class WorkflowAppGenerateTaskPipeline:
Process stream responses using elegant Fluent Python patterns.
Maintains the exact same functionality as the original 44-if-statement version.
"""
# Initialize graph runtime state
graph_runtime_state = None

for queue_message in self._base_task_pipeline.queue_manager.listen():
event = queue_message.event

match event:
case QueueWorkflowStartedEvent():
graph_runtime_state = event.graph_runtime_state
self._resolve_graph_runtime_state()
yield from self._handle_workflow_started_event(event)

case QueueTextChunkEvent():
@@ -681,12 +606,19 @@ class WorkflowAppGenerateTaskPipeline:
yield from self._handle_error_event(event)
break

case QueueWorkflowFailedEvent():
yield from self._handle_workflow_failed_and_stop_events(event)
break

case QueueStopEvent():
yield from self._handle_workflow_failed_and_stop_events(event)
break

# Handle all other events through elegant dispatch
case _:
if responses := list(
self._dispatch_event(
event,
graph_runtime_state=graph_runtime_state,
tts_publisher=tts_publisher,
trace_manager=trace_manager,
queue_message=queue_message,
@@ -697,7 +629,7 @@ class WorkflowAppGenerateTaskPipeline:
if tts_publisher:
tts_publisher.publish(None)

def _save_workflow_app_log(self, *, session: Session, workflow_execution: WorkflowExecution):
def _save_workflow_app_log(self, *, session: Session, workflow_run_id: str | None):
invoke_from = self._application_generate_entity.invoke_from
if invoke_from == InvokeFrom.SERVICE_API:
created_from = WorkflowAppLogCreatedFrom.SERVICE_API
@@ -709,11 +641,14 @@ class WorkflowAppGenerateTaskPipeline:
# not save log for debugging
return

if not workflow_run_id:
return

workflow_app_log = WorkflowAppLog()
workflow_app_log.tenant_id = self._application_generate_entity.app_config.tenant_id
workflow_app_log.app_id = self._application_generate_entity.app_config.app_id
workflow_app_log.workflow_id = workflow_execution.workflow_id
workflow_app_log.workflow_run_id = workflow_execution.id_
workflow_app_log.workflow_id = self._workflow.id
workflow_app_log.workflow_run_id = workflow_run_id
workflow_app_log.created_from = created_from.value
workflow_app_log.created_by_role = self._created_by_role
workflow_app_log.created_by = self._user_id

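Taken together, the match statement and _dispatch_event give the pipeline a two-level dispatch: lifecycle-terminal events are matched explicitly so they can break out of the listen loop, and everything else is looked up in a registry keyed by event type. The registry itself is not shown in these hunks; its shape is roughly the following sketch (inside _dispatch_event, with some handler names assumed):

    # Illustrative dispatch table; the real table maps every remaining
    # queue event type to its handler method.
    handlers = {
        QueuePingEvent: self._handle_ping_event,
        QueueNodeSucceededEvent: self._handle_node_succeeded_events,  # name assumed
    }
    event_type = type(event)
    if handler := handlers.get(event_type):
        yield from handler(event, tts_publisher=tts_publisher, queue_message=queue_message)
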
@@ -25,7 +25,7 @@ from core.app.entities.queue_entities import (
QueueWorkflowStartedEvent,
QueueWorkflowSucceededEvent,
)
from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool
from core.workflow.entities import GraphInitParams
from core.workflow.graph import Graph
from core.workflow.graph_events import (
GraphEngineEvent,
@@ -54,6 +54,7 @@ from core.workflow.graph_events.graph import GraphRunAbortedEvent
from core.workflow.nodes import NodeType
from core.workflow.nodes.node_factory import DifyNodeFactory
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool
from core.workflow.workflow_entry import WorkflowEntry
@@ -346,9 +347,7 @@ class WorkflowBasedAppRunner:
:param event: event
"""
if isinstance(event, GraphRunStartedEvent):
self._publish_event(
QueueWorkflowStartedEvent(graph_runtime_state=workflow_entry.graph_engine.graph_runtime_state)
)
self._publish_event(QueueWorkflowStartedEvent())
elif isinstance(event, GraphRunSucceededEvent):
self._publish_event(QueueWorkflowSucceededEvent(outputs=event.outputs))
elif isinstance(event, GraphRunPartialSucceededEvent):
@@ -372,7 +371,6 @@ class WorkflowBasedAppRunner:
node_title=event.node_title,
node_type=event.node_type,
start_at=event.start_at,
predecessor_node_id=event.predecessor_node_id,
in_iteration_id=event.in_iteration_id,
in_loop_id=event.in_loop_id,
inputs=inputs,
@@ -393,7 +391,6 @@ class WorkflowBasedAppRunner:
node_title=event.node_title,
node_type=event.node_type,
start_at=event.start_at,
predecessor_node_id=event.predecessor_node_id,
in_iteration_id=event.in_iteration_id,
in_loop_id=event.in_loop_id,
agent_strategy=event.agent_strategy,
@@ -494,7 +491,6 @@ class WorkflowBasedAppRunner:
start_at=event.start_at,
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
inputs=event.inputs,
predecessor_node_id=event.predecessor_node_id,
metadata=event.metadata,
)
)
@@ -536,7 +532,6 @@ class WorkflowBasedAppRunner:
start_at=event.start_at,
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
inputs=event.inputs,
predecessor_node_id=event.predecessor_node_id,
metadata=event.metadata,
)
)

@@ -3,11 +3,11 @@ from datetime import datetime
from enum import StrEnum, auto
from typing import Any

from pydantic import BaseModel, Field
from pydantic import BaseModel, ConfigDict, Field

from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.workflow.entities import AgentNodeStrategyInit, GraphRuntimeState
from core.workflow.entities import AgentNodeStrategyInit
from core.workflow.enums import WorkflowNodeExecutionMetadataKey
from core.workflow.nodes import NodeType

@@ -54,6 +54,7 @@ class AppQueueEvent(BaseModel):
"""

event: QueueEvent
model_config = ConfigDict(arbitrary_types_allowed=True)


class QueueLLMChunkEvent(AppQueueEvent):
@@ -80,7 +81,6 @@ class QueueIterationStartEvent(AppQueueEvent):

node_run_index: int
inputs: Mapping[str, object] = Field(default_factory=dict)
predecessor_node_id: str | None = None
metadata: Mapping[str, object] = Field(default_factory=dict)


@@ -132,19 +132,10 @@ class QueueLoopStartEvent(AppQueueEvent):
node_id: str
node_type: NodeType
node_title: str
parallel_id: str | None = None
"""parallel id if node is in parallel"""
parallel_start_node_id: str | None = None
"""parallel start node id if node is in parallel"""
parent_parallel_id: str | None = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: str | None = None
"""parent parallel start node id if node is in parallel"""
start_at: datetime

node_run_index: int
inputs: Mapping[str, object] = Field(default_factory=dict)
predecessor_node_id: str | None = None
metadata: Mapping[str, object] = Field(default_factory=dict)


@@ -160,16 +151,6 @@ class QueueLoopNextEvent(AppQueueEvent):
node_id: str
node_type: NodeType
node_title: str
parallel_id: str | None = None
"""parallel id if node is in parallel"""
parallel_start_node_id: str | None = None
"""parallel start node id if node is in parallel"""
parent_parallel_id: str | None = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: str | None = None
"""parent parallel start node id if node is in parallel"""
parallel_mode_run_id: str | None = None
"""iteration run in parallel mode run id"""
node_run_index: int
output: Any = None  # output for the current loop

@@ -185,14 +166,6 @@ class QueueLoopCompletedEvent(AppQueueEvent):
node_id: str
node_type: NodeType
node_title: str
parallel_id: str | None = None
"""parallel id if node is in parallel"""
parallel_start_node_id: str | None = None
"""parallel start node id if node is in parallel"""
parent_parallel_id: str | None = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: str | None = None
"""parent parallel start node id if node is in parallel"""
start_at: datetime

node_run_index: int
@@ -285,12 +258,9 @@ class QueueAdvancedChatMessageEndEvent(AppQueueEvent):


class QueueWorkflowStartedEvent(AppQueueEvent):
"""
QueueWorkflowStartedEvent entity
"""
"""QueueWorkflowStartedEvent entity."""

event: QueueEvent = QueueEvent.WORKFLOW_STARTED
graph_runtime_state: GraphRuntimeState


class QueueWorkflowSucceededEvent(AppQueueEvent):
@@ -334,15 +304,9 @@ class QueueNodeStartedEvent(AppQueueEvent):
node_title: str
node_type: NodeType
node_run_index: int = 1  # FIXME(-LAN-): may not be used
predecessor_node_id: str | None = None
parallel_id: str | None = None
parallel_start_node_id: str | None = None
parent_parallel_id: str | None = None
parent_parallel_start_node_id: str | None = None
in_iteration_id: str | None = None
in_loop_id: str | None = None
start_at: datetime
parallel_mode_run_id: str | None = None
agent_strategy: AgentNodeStrategyInit | None = None

# FIXME(-LAN-): only for ToolNode, need to refactor
@@ -360,14 +324,6 @@ class QueueNodeSucceededEvent(AppQueueEvent):
node_execution_id: str
node_id: str
node_type: NodeType
parallel_id: str | None = None
"""parallel id if node is in parallel"""
parallel_start_node_id: str | None = None
"""parallel start node id if node is in parallel"""
parent_parallel_id: str | None = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: str | None = None
"""parent parallel start node id if node is in parallel"""
in_iteration_id: str | None = None
"""iteration id if node is in iteration"""
in_loop_id: str | None = None
@@ -423,14 +379,6 @@ class QueueNodeExceptionEvent(AppQueueEvent):
node_execution_id: str
node_id: str
node_type: NodeType
parallel_id: str | None = None
"""parallel id if node is in parallel"""
parallel_start_node_id: str | None = None
"""parallel start node id if node is in parallel"""
parent_parallel_id: str | None = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: str | None = None
"""parent parallel start node id if node is in parallel"""
in_iteration_id: str | None = None
"""iteration id if node is in iteration"""
in_loop_id: str | None = None
@@ -455,7 +403,6 @@ class QueueNodeFailedEvent(AppQueueEvent):
node_execution_id: str
node_id: str
node_type: NodeType
parallel_id: str | None = None
in_iteration_id: str | None = None
"""iteration id if node is in iteration"""
in_loop_id: str | None = None

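The one functional addition to the base event is model_config = ConfigDict(arbitrary_types_allowed=True), which lets subclasses declare fields whose types Pydantic has no schema for; validation falls back to an isinstance check. A self-contained illustration of the flag's effect, independent of Dify:

    from pydantic import BaseModel, ConfigDict

    class Handle:
        """A plain class with no Pydantic schema."""

    class Event(BaseModel):
        model_config = ConfigDict(arbitrary_types_allowed=True)
        handle: Handle  # rejected at class-definition time without the flag

    Event(handle=Handle())  # passes: validated by isinstance only
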
@@ -257,13 +257,8 @@ class NodeStartStreamResponse(StreamResponse):
inputs_truncated: bool = False
created_at: int
extras: dict[str, object] = Field(default_factory=dict)
parallel_id: str | None = None
parallel_start_node_id: str | None = None
parent_parallel_id: str | None = None
parent_parallel_start_node_id: str | None = None
iteration_id: str | None = None
loop_id: str | None = None
parallel_run_id: str | None = None
agent_strategy: AgentNodeStrategyInit | None = None

event: StreamEvent = StreamEvent.NODE_STARTED
@@ -285,10 +280,6 @@ class NodeStartStreamResponse(StreamResponse):
"inputs": None,
"created_at": self.data.created_at,
"extras": {},
"parallel_id": self.data.parallel_id,
"parallel_start_node_id": self.data.parallel_start_node_id,
"parent_parallel_id": self.data.parent_parallel_id,
"parent_parallel_start_node_id": self.data.parent_parallel_start_node_id,
"iteration_id": self.data.iteration_id,
"loop_id": self.data.loop_id,
},
@@ -324,10 +315,6 @@ class NodeFinishStreamResponse(StreamResponse):
created_at: int
finished_at: int
files: Sequence[Mapping[str, Any]] | None = []
parallel_id: str | None = None
parallel_start_node_id: str | None = None
parent_parallel_id: str | None = None
parent_parallel_start_node_id: str | None = None
iteration_id: str | None = None
loop_id: str | None = None

@@ -357,10 +344,6 @@ class NodeFinishStreamResponse(StreamResponse):
"created_at": self.data.created_at,
"finished_at": self.data.finished_at,
"files": [],
"parallel_id": self.data.parallel_id,
"parallel_start_node_id": self.data.parallel_start_node_id,
"parent_parallel_id": self.data.parent_parallel_id,
"parent_parallel_start_node_id": self.data.parent_parallel_start_node_id,
"iteration_id": self.data.iteration_id,
"loop_id": self.data.loop_id,
},
@@ -396,10 +379,6 @@ class NodeRetryStreamResponse(StreamResponse):
created_at: int
finished_at: int
files: Sequence[Mapping[str, Any]] | None = []
parallel_id: str | None = None
parallel_start_node_id: str | None = None
parent_parallel_id: str | None = None
parent_parallel_start_node_id: str | None = None
iteration_id: str | None = None
loop_id: str | None = None
retry_index: int = 0
@@ -430,10 +409,6 @@ class NodeRetryStreamResponse(StreamResponse):
"created_at": self.data.created_at,
"finished_at": self.data.finished_at,
"files": [],
"parallel_id": self.data.parallel_id,
"parallel_start_node_id": self.data.parallel_start_node_id,
"parent_parallel_id": self.data.parent_parallel_id,
"parent_parallel_start_node_id": self.data.parent_parallel_start_node_id,
"iteration_id": self.data.iteration_id,
"loop_id": self.data.loop_id,
"retry_index": self.data.retry_index,
@@ -541,8 +516,6 @@ class LoopNodeStartStreamResponse(StreamResponse):
metadata: Mapping = {}
inputs: Mapping = {}
inputs_truncated: bool = False
parallel_id: str | None = None
parallel_start_node_id: str | None = None

event: StreamEvent = StreamEvent.LOOP_STARTED
workflow_run_id: str
@@ -567,9 +540,6 @@ class LoopNodeNextStreamResponse(StreamResponse):
created_at: int
pre_loop_output: Any = None
extras: Mapping[str, object] = Field(default_factory=dict)
parallel_id: str | None = None
parallel_start_node_id: str | None = None
parallel_mode_run_id: str | None = None

event: StreamEvent = StreamEvent.LOOP_NEXT
workflow_run_id: str
@@ -603,8 +573,6 @@ class LoopNodeCompletedStreamResponse(StreamResponse):
execution_metadata: Mapping[str, object] = Field(default_factory=dict)
finished_at: int
steps: int
parallel_id: str | None = None
parallel_start_node_id: str | None = None

event: StreamEvent = StreamEvent.LOOP_COMPLETED
workflow_run_id: str

@@ -18,7 +18,7 @@ from core.model_runtime.entities.message_entities import ImagePromptMessageConte
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
from core.prompt.prompt_transform import PromptTransform
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.runtime import VariablePool


class AdvancedPromptTransform(PromptTransform):

@@ -104,7 +104,6 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
self._creator_user_role = CreatorUserRole.ACCOUNT if isinstance(user, Account) else CreatorUserRole.END_USER

# Initialize in-memory cache for node executions
# Key: node_execution_id, Value: WorkflowNodeExecution (DB model)
self._node_execution_cache: dict[str, WorkflowNodeExecutionModel] = {}

# Initialize FileService for handling offloaded data
@@ -332,17 +331,10 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
Args:
execution: The NodeExecution domain entity to persist
"""
# NOTE: As per the implementation of `WorkflowCycleManager`,
# the `save` method is invoked multiple times during the node's execution lifecycle, including:
#
# - When the node starts execution
# - When the node retries execution
# - When the node completes execution (either successfully or with failure)
#
# Only the final invocation will have `inputs` and `outputs` populated.
#
# This simplifies the logic for saving offloaded variables but introduces a tight coupling
# between this module and `WorkflowCycleManager`.
# NOTE: The workflow engine triggers `save` multiple times for a single node execution:
# when the node starts, any time it retries, and once more when it reaches a terminal state.
# Only the final call contains the complete inputs and outputs payloads, so earlier invocations
# must tolerate missing data without attempting to offload variables.

# Convert domain model to database model using tenant context and other attributes
db_model = self._to_db_model(execution)

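The rewritten NOTE pins the contract down: save fires at node start, on every retry, and once at the terminal state, and only the terminal call carries complete inputs and outputs. A guard of roughly this shape satisfies that contract (a sketch; the offloading and persistence helpers are hypothetical stand-ins for code outside this hunk):

    # Sketch of a lifecycle-tolerant save(), per the NOTE above.
    def save(self, execution: WorkflowNodeExecution) -> None:
        db_model = self._to_db_model(execution)
        # Start/retry invocations may carry no payloads yet; only the final
        # call should trigger offloading of large variables.
        if execution.inputs is not None and execution.outputs is not None:
            self._offload_large_payloads(db_model)  # hypothetical helper
        self._persist(db_model)  # hypothetical session write
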
@@ -58,8 +58,8 @@ from services.tools.mcp_tools_manage_service import MCPToolManageService
from services.tools.tools_transform_service import ToolTransformService

if TYPE_CHECKING:
from core.workflow.entities import VariablePool
from core.workflow.nodes.tool.entities import ToolEntity
from core.workflow.runtime import VariablePool

logger = logging.getLogger(__name__)


@@ -1,18 +1,11 @@
from .agent import AgentNodeStrategyInit
from .graph_init_params import GraphInitParams
from .graph_runtime_state import GraphRuntimeState
from .run_condition import RunCondition
from .variable_pool import VariablePool, VariableValue
from .workflow_execution import WorkflowExecution
from .workflow_node_execution import WorkflowNodeExecution

__all__ = [
"AgentNodeStrategyInit",
"GraphInitParams",
"GraphRuntimeState",
"RunCondition",
"VariablePool",
"VariableValue",
"WorkflowExecution",
"WorkflowNodeExecution",
]

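With GraphRuntimeState, VariablePool, VariableValue, and RunCondition removed from this package, core.workflow.entities keeps only pure data entities, and the runtime-state types live in core.workflow.runtime. Call sites migrate exactly as the earlier hunks show:

    # Before (removed):
    from core.workflow.entities import GraphRuntimeState, VariablePool
    # After (added):
    from core.workflow.runtime import GraphRuntimeState, VariablePool
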
@@ -1,160 +0,0 @@
|
||||
from copy import deepcopy
|
||||
|
||||
from pydantic import BaseModel, PrivateAttr
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
|
||||
from .variable_pool import VariablePool
|
||||
|
||||
|
||||
class GraphRuntimeState(BaseModel):
|
||||
# Private attributes to prevent direct modification
|
||||
_variable_pool: VariablePool = PrivateAttr()
|
||||
_start_at: float = PrivateAttr()
|
||||
_total_tokens: int = PrivateAttr(default=0)
|
||||
_llm_usage: LLMUsage = PrivateAttr(default_factory=LLMUsage.empty_usage)
|
||||
_outputs: dict[str, object] = PrivateAttr(default_factory=dict[str, object])
|
||||
_node_run_steps: int = PrivateAttr(default=0)
|
||||
_ready_queue_json: str = PrivateAttr()
|
||||
_graph_execution_json: str = PrivateAttr()
|
||||
_response_coordinator_json: str = PrivateAttr()
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
variable_pool: VariablePool,
|
||||
start_at: float,
|
||||
total_tokens: int = 0,
|
||||
llm_usage: LLMUsage | None = None,
|
||||
outputs: dict[str, object] | None = None,
|
||||
node_run_steps: int = 0,
|
||||
ready_queue_json: str = "",
|
||||
graph_execution_json: str = "",
|
||||
response_coordinator_json: str = "",
|
||||
**kwargs: object,
|
||||
):
|
||||
"""Initialize the GraphRuntimeState with validation."""
|
||||
super().__init__(**kwargs)
|
||||
|
        # Initialize private attributes with validation
        self._variable_pool = variable_pool

        self._start_at = start_at

        if total_tokens < 0:
            raise ValueError("total_tokens must be non-negative")
        self._total_tokens = total_tokens

        if llm_usage is None:
            llm_usage = LLMUsage.empty_usage()
        self._llm_usage = llm_usage

        if outputs is None:
            outputs = {}
        self._outputs = deepcopy(outputs)

        if node_run_steps < 0:
            raise ValueError("node_run_steps must be non-negative")
        self._node_run_steps = node_run_steps

        self._ready_queue_json = ready_queue_json
        self._graph_execution_json = graph_execution_json
        self._response_coordinator_json = response_coordinator_json

    @property
    def variable_pool(self) -> VariablePool:
        """Get the variable pool."""
        return self._variable_pool

    @property
    def start_at(self) -> float:
        """Get the start time."""
        return self._start_at

    @start_at.setter
    def start_at(self, value: float) -> None:
        """Set the start time."""
        self._start_at = value

    @property
    def total_tokens(self) -> int:
        """Get the total tokens count."""
        return self._total_tokens

    @total_tokens.setter
    def total_tokens(self, value: int):
        """Set the total tokens count."""
        if value < 0:
            raise ValueError("total_tokens must be non-negative")
        self._total_tokens = value

    @property
    def llm_usage(self) -> LLMUsage:
        """Get the LLM usage info."""
        # Return a copy to prevent external modification
        return self._llm_usage.model_copy()

    @llm_usage.setter
    def llm_usage(self, value: LLMUsage):
        """Set the LLM usage info."""
        self._llm_usage = value.model_copy()

    @property
    def outputs(self) -> dict[str, object]:
        """Get a copy of the outputs dictionary."""
        return deepcopy(self._outputs)

    @outputs.setter
    def outputs(self, value: dict[str, object]) -> None:
        """Set the outputs dictionary."""
        self._outputs = deepcopy(value)

    def set_output(self, key: str, value: object) -> None:
        """Set a single output value."""
        self._outputs[key] = deepcopy(value)

    def get_output(self, key: str, default: object = None) -> object:
        """Get a single output value."""
        return deepcopy(self._outputs.get(key, default))

    def update_outputs(self, updates: dict[str, object]) -> None:
        """Update multiple output values."""
        for key, value in updates.items():
            self._outputs[key] = deepcopy(value)

    @property
    def node_run_steps(self) -> int:
        """Get the node run steps count."""
        return self._node_run_steps

    @node_run_steps.setter
    def node_run_steps(self, value: int) -> None:
        """Set the node run steps count."""
        if value < 0:
            raise ValueError("node_run_steps must be non-negative")
        self._node_run_steps = value

    def increment_node_run_steps(self) -> None:
        """Increment the node run steps by 1."""
        self._node_run_steps += 1

    def add_tokens(self, tokens: int) -> None:
        """Add tokens to the total count."""
        if tokens < 0:
            raise ValueError("tokens must be non-negative")
        self._total_tokens += tokens

    @property
    def ready_queue_json(self) -> str:
        """Get the serialized ready queue state."""
        return self._ready_queue_json

    @property
    def graph_execution_json(self) -> str:
        """Get the serialized graph execution state."""
        return self._graph_execution_json

    @property
    def response_coordinator_json(self) -> str:
        """Get the serialized response coordinator state."""
        return self._response_coordinator_json
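Note: the three *_json properties are what make pause/resume possible. They expose serialized snapshots of the ready queue, the graph execution aggregate, and the response coordinator, so a caller can persist them while a run is paused and feed them back into a fresh runtime state. A minimal sketch of that round trip follows; the constructor keyword names are assumptions inferred from the attributes above, not confirmed by this hunk:

    import time

    # Sketch only: persist a paused run's serialized state, rebuild it later.
    def snapshot_state(state) -> dict[str, str]:
        # Capture the serialized subsystem state for storage.
        return {
            "ready_queue_json": state.ready_queue_json,
            "graph_execution_json": state.graph_execution_json,
            "response_coordinator_json": state.response_coordinator_json,
        }

    def restore_state(variable_pool, saved: dict[str, str]):
        # Keyword names assumed from the private attributes initialized above.
        return GraphRuntimeState(
            variable_pool=variable_pool,
            start_at=time.perf_counter(),  # assumed semantics of start_at
            **saved,
        )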
@@ -1,21 +0,0 @@
import hashlib
from typing import Literal

from pydantic import BaseModel

from core.workflow.utils.condition.entities import Condition


class RunCondition(BaseModel):
    type: Literal["branch_identify", "condition"]
    """condition type"""

    branch_identify: str | None = None
    """branch identify like: sourceHandle, required when type is branch_identify"""

    conditions: list[Condition] | None = None
    """conditions to run the node, required when type is condition"""

    @property
    def hash(self) -> str:
        return hashlib.sha256(self.model_dump_json().encode()).hexdigest()
@@ -58,6 +58,7 @@ class NodeType(StrEnum):
    DOCUMENT_EXTRACTOR = "document-extractor"
    LIST_OPERATOR = "list-operator"
    AGENT = "agent"
    HUMAN_INPUT = "human-input"


class NodeExecutionType(StrEnum):

@@ -96,6 +97,7 @@ class WorkflowExecutionStatus(StrEnum):
    FAILED = "failed"
    STOPPED = "stopped"
    PARTIAL_SUCCEEDED = "partial-succeeded"
    PAUSED = "paused"


class WorkflowNodeExecutionMetadataKey(StrEnum):
@@ -1,16 +1,11 @@
from .edge import Edge
from .graph import Graph, NodeFactory
from .graph_runtime_state_protocol import ReadOnlyGraphRuntimeState, ReadOnlyVariablePool
from .graph import Graph, GraphBuilder, NodeFactory
from .graph_template import GraphTemplate
from .read_only_state_wrapper import ReadOnlyGraphRuntimeStateWrapper, ReadOnlyVariablePoolWrapper

__all__ = [
    "Edge",
    "Graph",
    "GraphBuilder",
    "GraphTemplate",
    "NodeFactory",
    "ReadOnlyGraphRuntimeState",
    "ReadOnlyGraphRuntimeStateWrapper",
    "ReadOnlyVariablePool",
    "ReadOnlyVariablePoolWrapper",
]
@@ -195,6 +195,12 @@ class Graph:

        return nodes

    @classmethod
    def new(cls) -> "GraphBuilder":
        """Create a fluent builder for assembling a graph programmatically."""
        return GraphBuilder(graph_cls=cls)

    @classmethod
    def _mark_inactive_root_branches(
        cls,

@@ -344,3 +350,96 @@ class Graph:
        """
        edge_ids = self.in_edges.get(node_id, [])
        return [self.edges[eid] for eid in edge_ids if eid in self.edges]


@final
class GraphBuilder:
    """Fluent helper for constructing simple graphs, primarily for tests."""

    def __init__(self, *, graph_cls: type[Graph]):
        self._graph_cls = graph_cls
        self._nodes: list[Node] = []
        self._nodes_by_id: dict[str, Node] = {}
        self._edges: list[Edge] = []
        self._edge_counter = 0

    def add_root(self, node: Node) -> "GraphBuilder":
        """Register the root node. Must be called exactly once."""
        if self._nodes:
            raise ValueError("Root node has already been added")
        self._register_node(node)
        self._nodes.append(node)
        return self

    def add_node(
        self,
        node: Node,
        *,
        from_node_id: str | None = None,
        source_handle: str = "source",
    ) -> "GraphBuilder":
        """Append a node and connect it from the specified predecessor."""
        if not self._nodes:
            raise ValueError("Root node must be added before adding other nodes")

        predecessor_id = from_node_id or self._nodes[-1].id
        if predecessor_id not in self._nodes_by_id:
            raise ValueError(f"Predecessor node '{predecessor_id}' not found")

        predecessor = self._nodes_by_id[predecessor_id]
        self._register_node(node)
        self._nodes.append(node)

        edge_id = f"edge_{self._edge_counter}"
        self._edge_counter += 1
        edge = Edge(id=edge_id, tail=predecessor.id, head=node.id, source_handle=source_handle)
        self._edges.append(edge)

        return self

    def connect(self, *, tail: str, head: str, source_handle: str = "source") -> "GraphBuilder":
        """Connect two existing nodes without adding a new node."""
        if tail not in self._nodes_by_id:
            raise ValueError(f"Tail node '{tail}' not found")
        if head not in self._nodes_by_id:
            raise ValueError(f"Head node '{head}' not found")

        edge_id = f"edge_{self._edge_counter}"
        self._edge_counter += 1
        edge = Edge(id=edge_id, tail=tail, head=head, source_handle=source_handle)
        self._edges.append(edge)

        return self

    def build(self) -> Graph:
        """Materialize the graph instance from the accumulated nodes and edges."""
        if not self._nodes:
            raise ValueError("Cannot build an empty graph")

        nodes = {node.id: node for node in self._nodes}
        edges = {edge.id: edge for edge in self._edges}
        in_edges: dict[str, list[str]] = defaultdict(list)
        out_edges: dict[str, list[str]] = defaultdict(list)

        for edge in self._edges:
            out_edges[edge.tail].append(edge.id)
            in_edges[edge.head].append(edge.id)

        return self._graph_cls(
            nodes=nodes,
            edges=edges,
            in_edges=dict(in_edges),
            out_edges=dict(out_edges),
            root_node=self._nodes[0],
        )

    def _register_node(self, node: Node) -> None:
        if not node.id:
            raise ValueError("Node must have a non-empty id")
        if node.id in self._nodes_by_id:
            raise ValueError(f"Duplicate node id detected: {node.id}")
        self._nodes_by_id[node.id] = node
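Note: the builder reads naturally as a chain. A hypothetical usage sketch, with `start_node`, `llm_node`, and `end_node` standing in for real Node instances (node construction is elided):

    graph = (
        Graph.new()
        .add_root(start_node)
        .add_node(llm_node)  # implicitly connects from the previous node
        .add_node(end_node, from_node_id=llm_node.id)
        .build()
    )

When no `from_node_id` is given, the edge is drawn from the most recently added node, which keeps linear test graphs to one call per node.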
@@ -9,7 +9,7 @@ Each instance uses a unique key for its command queue.
import json
from typing import TYPE_CHECKING, Any, final

from ..entities.commands import AbortCommand, CommandType, GraphEngineCommand
from ..entities.commands import AbortCommand, CommandType, GraphEngineCommand, PauseCommand

if TYPE_CHECKING:
    from extensions.ext_redis import RedisClientWrapper

@@ -111,9 +111,11 @@ class RedisChannel:

            if command_type == CommandType.ABORT:
                return AbortCommand.model_validate(data)
            else:
                # For other command types, use base class
                return GraphEngineCommand.model_validate(data)
            if command_type == CommandType.PAUSE:
                return PauseCommand.model_validate(data)

            # For other command types, use base class
            return GraphEngineCommand.model_validate(data)

        except (ValueError, TypeError):
            return None
@@ -5,10 +5,11 @@ This package handles external commands sent to the engine
during execution.
"""

from .command_handlers import AbortCommandHandler
from .command_handlers import AbortCommandHandler, PauseCommandHandler
from .command_processor import CommandProcessor

__all__ = [
    "AbortCommandHandler",
    "CommandProcessor",
    "PauseCommandHandler",
]
@@ -1,14 +1,10 @@
"""
Command handler implementations.
"""

import logging
from typing import final

from typing_extensions import override

from ..domain.graph_execution import GraphExecution
from ..entities.commands import AbortCommand, GraphEngineCommand
from ..entities.commands import AbortCommand, GraphEngineCommand, PauseCommand
from .command_processor import CommandHandler

logger = logging.getLogger(__name__)

@@ -16,17 +12,17 @@ logger = logging.getLogger(__name__)

@final
class AbortCommandHandler(CommandHandler):
    """Handles abort commands."""

    @override
    def handle(self, command: GraphEngineCommand, execution: GraphExecution) -> None:
        """
        Handle an abort command.

        Args:
            command: The abort command
            execution: Graph execution to abort
        """
        assert isinstance(command, AbortCommand)
        logger.debug("Aborting workflow %s: %s", execution.workflow_id, command.reason)
        execution.abort(command.reason or "User requested abort")


@final
class PauseCommandHandler(CommandHandler):
    @override
    def handle(self, command: GraphEngineCommand, execution: GraphExecution) -> None:
        assert isinstance(command, PauseCommand)
        logger.debug("Pausing workflow %s: %s", execution.workflow_id, command.reason)
        execution.pause(command.reason)
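Note: each handler is bound to one command class. A wiring sketch, mirroring the registration the engine constructor performs later in this commit (the CommandProcessor constructor keywords are assumed; only `graph_execution` is visible in the hunk below):

    processor = CommandProcessor(
        command_channel=channel,        # assumed keyword
        graph_execution=execution,
    )
    processor.register_handler(AbortCommand, AbortCommandHandler())
    processor.register_handler(PauseCommand, PauseCommandHandler())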
@@ -40,6 +40,8 @@ class GraphExecutionState(BaseModel):
    started: bool = Field(default=False)
    completed: bool = Field(default=False)
    aborted: bool = Field(default=False)
    paused: bool = Field(default=False)
    pause_reason: str | None = Field(default=None)
    error: GraphExecutionErrorState | None = Field(default=None)
    exceptions_count: int = Field(default=0)
    node_executions: list[NodeExecutionState] = Field(default_factory=list[NodeExecutionState])

@@ -103,6 +105,8 @@ class GraphExecution:
    started: bool = False
    completed: bool = False
    aborted: bool = False
    paused: bool = False
    pause_reason: str | None = None
    error: Exception | None = None
    node_executions: dict[str, NodeExecution] = field(default_factory=dict[str, NodeExecution])
    exceptions_count: int = 0

@@ -126,6 +130,17 @@ class GraphExecution:
        self.aborted = True
        self.error = RuntimeError(f"Aborted: {reason}")

    def pause(self, reason: str | None = None) -> None:
        """Pause the graph execution without marking it complete."""
        if self.completed:
            raise RuntimeError("Cannot pause execution that has completed")
        if self.aborted:
            raise RuntimeError("Cannot pause execution that has been aborted")
        if self.paused:
            return
        self.paused = True
        self.pause_reason = reason

    def fail(self, error: Exception) -> None:
        """Mark the graph execution as failed."""
        self.error = error

@@ -140,7 +155,12 @@ class GraphExecution:
    @property
    def is_running(self) -> bool:
        """Check if the execution is currently running."""
        return self.started and not self.completed and not self.aborted
        return self.started and not self.completed and not self.aborted and not self.paused

    @property
    def is_paused(self) -> bool:
        """Check if the execution is currently paused."""
        return self.paused

    @property
    def has_error(self) -> bool:

@@ -173,6 +193,8 @@ class GraphExecution:
            started=self.started,
            completed=self.completed,
            aborted=self.aborted,
            paused=self.paused,
            pause_reason=self.pause_reason,
            error=_serialize_error(self.error),
            exceptions_count=self.exceptions_count,
            node_executions=node_states,

@@ -197,6 +219,8 @@ class GraphExecution:
        self.started = state.started
        self.completed = state.completed
        self.aborted = state.aborted
        self.paused = state.paused
        self.pause_reason = state.pause_reason
        self.error = _deserialize_error(state.error)
        self.exceptions_count = state.exceptions_count
        self.node_executions = {
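Note: unlike abort, pause is a resumable, non-terminal state, and the guards above make the transitions explicit. A rough illustration of the state machine this encodes:

    execution = GraphExecution(workflow_id="wf-1")
    execution.start()
    assert execution.is_running

    execution.pause("Awaiting human input")
    assert execution.is_paused
    assert not execution.is_running  # paused runs no longer count as running

    # The paused flag and reason round-trip through GraphExecutionState, so a
    # paused run can be rehydrated in a later process via loads() (the dumps()
    # counterpart is assumed; only loads() appears in this diff).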
@@ -16,7 +16,6 @@ class CommandType(StrEnum):

    ABORT = "abort"
    PAUSE = "pause"
    RESUME = "resume"


class GraphEngineCommand(BaseModel):

@@ -31,3 +30,10 @@ class AbortCommand(GraphEngineCommand):

    command_type: CommandType = Field(default=CommandType.ABORT, description="Type of command")
    reason: str | None = Field(default=None, description="Optional reason for abort")


class PauseCommand(GraphEngineCommand):
    """Command to pause a running workflow execution."""

    command_type: CommandType = Field(default=CommandType.PAUSE, description="Type of command")
    reason: str | None = Field(default=None, description="Optional reason for pause")
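Note: the command travels over Redis as JSON, and the PAUSE branch in RedisChannel above rebuilds it with PauseCommand.model_validate. The wire format in miniature, using standard pydantic v2 APIs:

    payload = PauseCommand(reason="User requested pause").model_dump_json()
    # e.g. '{"command_type": "pause", "reason": "User requested pause"}'

    restored = PauseCommand.model_validate_json(payload)
    assert restored.command_type == CommandType.PAUSE

Note also that RESUME is dropped from CommandType: resuming is not a channel command in this design, it happens by re-running the engine on rehydrated state.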
@@ -8,8 +8,7 @@ from functools import singledispatchmethod
from typing import TYPE_CHECKING, final

from core.model_runtime.entities.llm_entities import LLMUsage
from core.workflow.entities import GraphRuntimeState
from core.workflow.enums import ErrorStrategy, NodeExecutionType
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeState
from core.workflow.graph import Graph
from core.workflow.graph_events import (
    GraphNodeEventBase,

@@ -24,11 +23,13 @@ from core.workflow.graph_events import (
    NodeRunLoopNextEvent,
    NodeRunLoopStartedEvent,
    NodeRunLoopSucceededEvent,
    NodeRunPauseRequestedEvent,
    NodeRunRetryEvent,
    NodeRunStartedEvent,
    NodeRunStreamChunkEvent,
    NodeRunSucceededEvent,
)
from core.workflow.runtime import GraphRuntimeState

from ..domain.graph_execution import GraphExecution
from ..response_coordinator import ResponseStreamCoordinator

@@ -203,6 +204,18 @@ class EventHandler:
        # Collect the event
        self._event_collector.collect(event)

    @_dispatch.register
    def _(self, event: NodeRunPauseRequestedEvent) -> None:
        """Handle pause requests emitted by nodes."""
        pause_reason = event.reason or "Awaiting human input"
        self._graph_execution.pause(pause_reason)
        self._state_manager.finish_execution(event.node_id)
        if event.node_id in self._graph.nodes:
            self._graph.nodes[event.node_id].state = NodeState.UNKNOWN
        self._graph_runtime_state.register_paused_node(event.node_id)
        self._event_collector.collect(event)

    @_dispatch.register
    def _(self, event: NodeRunFailedEvent) -> None:
        """
@@ -97,6 +97,10 @@ class EventManager:
        """
        self._layers = layers

    def notify_layers(self, event: GraphEngineEvent) -> None:
        """Notify registered layers about an event without buffering it."""
        self._notify_layers(event)

    def collect(self, event: GraphEngineEvent) -> None:
        """
        Thread-safe method to collect an event.
@@ -9,28 +9,29 @@ import contextvars
import logging
import queue
from collections.abc import Generator
from typing import final
from typing import TYPE_CHECKING, cast, final

from flask import Flask, current_app

from core.workflow.entities import GraphRuntimeState
from core.workflow.enums import NodeExecutionType
from core.workflow.graph import Graph
from core.workflow.graph.read_only_state_wrapper import ReadOnlyGraphRuntimeStateWrapper
from core.workflow.graph_engine.ready_queue import InMemoryReadyQueue
from core.workflow.graph_events import (
    GraphEngineEvent,
    GraphNodeEventBase,
    GraphRunAbortedEvent,
    GraphRunFailedEvent,
    GraphRunPartialSucceededEvent,
    GraphRunPausedEvent,
    GraphRunStartedEvent,
    GraphRunSucceededEvent,
)
from core.workflow.runtime import GraphRuntimeState, ReadOnlyGraphRuntimeStateWrapper

from .command_processing import AbortCommandHandler, CommandProcessor
from .domain import GraphExecution
from .entities.commands import AbortCommand
if TYPE_CHECKING:  # pragma: no cover - used only for static analysis
    from core.workflow.runtime.graph_runtime_state import GraphProtocol

from .command_processing import AbortCommandHandler, CommandProcessor, PauseCommandHandler
from .entities.commands import AbortCommand, PauseCommand
from .error_handler import ErrorHandler
from .event_management import EventHandler, EventManager
from .graph_state_manager import GraphStateManager

@@ -38,10 +39,13 @@ from .graph_traversal import EdgeProcessor, SkipPropagator
from .layers.base import GraphEngineLayer
from .orchestration import Dispatcher, ExecutionCoordinator
from .protocols.command_channel import CommandChannel
from .ready_queue import ReadyQueue, ReadyQueueState, create_ready_queue_from_state
from .response_coordinator import ResponseStreamCoordinator
from .ready_queue import ReadyQueue
from .worker_management import WorkerPool

if TYPE_CHECKING:
    from core.workflow.graph_engine.domain.graph_execution import GraphExecution
    from core.workflow.graph_engine.response_coordinator import ResponseStreamCoordinator

logger = logging.getLogger(__name__)


@@ -67,17 +71,16 @@ class GraphEngine:
    ) -> None:
        """Initialize the graph engine with all subsystems and dependencies."""

        # Graph execution tracks the overall execution state
        self._graph_execution = GraphExecution(workflow_id=workflow_id)
        if graph_runtime_state.graph_execution_json != "":
            self._graph_execution.loads(graph_runtime_state.graph_execution_json)

        # === Core Dependencies ===
        # Graph structure and configuration
        # Bind runtime state to current workflow context
        self._graph = graph
        self._graph_runtime_state = graph_runtime_state
        self._graph_runtime_state.configure(graph=cast("GraphProtocol", graph))
        self._command_channel = command_channel

        # Graph execution tracks the overall execution state
        self._graph_execution = cast("GraphExecution", self._graph_runtime_state.graph_execution)
        self._graph_execution.workflow_id = workflow_id

        # === Worker Management Parameters ===
        # Parameters for dynamic worker pool scaling
        self._min_workers = min_workers

@@ -86,13 +89,7 @@
        self._scale_down_idle_time = scale_down_idle_time

        # === Execution Queues ===
        # Create ready queue from saved state or initialize new one
        self._ready_queue: ReadyQueue
        if self._graph_runtime_state.ready_queue_json == "":
            self._ready_queue = InMemoryReadyQueue()
        else:
            ready_queue_state = ReadyQueueState.model_validate_json(self._graph_runtime_state.ready_queue_json)
            self._ready_queue = create_ready_queue_from_state(ready_queue_state)
        self._ready_queue = cast(ReadyQueue, self._graph_runtime_state.ready_queue)

        # Queue for events generated during execution
        self._event_queue: queue.Queue[GraphNodeEventBase] = queue.Queue()

@@ -103,11 +100,7 @@

        # === Response Coordination ===
        # Coordinates response streaming from response nodes
        self._response_coordinator = ResponseStreamCoordinator(
            variable_pool=self._graph_runtime_state.variable_pool, graph=self._graph
        )
        if graph_runtime_state.response_coordinator_json != "":
            self._response_coordinator.loads(graph_runtime_state.response_coordinator_json)
        self._response_coordinator = cast("ResponseStreamCoordinator", self._graph_runtime_state.response_coordinator)

        # === Event Management ===
        # Event manager handles both collection and emission of events

@@ -133,19 +126,6 @@
            skip_propagator=self._skip_propagator,
        )

        # === Event Handler Registry ===
        # Central registry for handling all node execution events
        self._event_handler_registry = EventHandler(
            graph=self._graph,
            graph_runtime_state=self._graph_runtime_state,
            graph_execution=self._graph_execution,
            response_coordinator=self._response_coordinator,
            event_collector=self._event_manager,
            edge_processor=self._edge_processor,
            state_manager=self._state_manager,
            error_handler=self._error_handler,
        )

        # === Command Processing ===
        # Processes external commands (e.g., abort requests)
        self._command_processor = CommandProcessor(

@@ -153,12 +133,12 @@
            graph_execution=self._graph_execution,
        )

        # Register abort command handler
        # Register command handlers
        abort_handler = AbortCommandHandler()
        self._command_processor.register_handler(
            AbortCommand,
            abort_handler,
        )
        self._command_processor.register_handler(AbortCommand, abort_handler)

        pause_handler = PauseCommandHandler()
        self._command_processor.register_handler(PauseCommand, pause_handler)

        # === Worker Pool Setup ===
        # Capture Flask app context for worker threads

@@ -191,12 +171,23 @@
        self._execution_coordinator = ExecutionCoordinator(
            graph_execution=self._graph_execution,
            state_manager=self._state_manager,
            event_handler=self._event_handler_registry,
            event_collector=self._event_manager,
            command_processor=self._command_processor,
            worker_pool=self._worker_pool,
        )

        # === Event Handler Registry ===
        # Central registry for handling all node execution events
        self._event_handler_registry = EventHandler(
            graph=self._graph,
            graph_runtime_state=self._graph_runtime_state,
            graph_execution=self._graph_execution,
            response_coordinator=self._response_coordinator,
            event_collector=self._event_manager,
            edge_processor=self._edge_processor,
            state_manager=self._state_manager,
            error_handler=self._error_handler,
        )

        # Dispatches events and manages execution flow
        self._dispatcher = Dispatcher(
            event_queue=self._event_queue,

@@ -237,26 +228,41 @@
        # Initialize layers
        self._initialize_layers()

        # Start execution
        self._graph_execution.start()
        is_resume = self._graph_execution.started
        if not is_resume:
            self._graph_execution.start()
        else:
            self._graph_execution.paused = False
            self._graph_execution.pause_reason = None

        start_event = GraphRunStartedEvent()
        self._event_manager.notify_layers(start_event)
        yield start_event

        # Start subsystems
        self._start_execution()
        self._start_execution(resume=is_resume)

        # Yield events as they occur
        yield from self._event_manager.emit_events()

        # Handle completion
        if self._graph_execution.aborted:
        if self._graph_execution.is_paused:
            paused_event = GraphRunPausedEvent(
                reason=self._graph_execution.pause_reason,
                outputs=self._graph_runtime_state.outputs,
            )
            self._event_manager.notify_layers(paused_event)
            yield paused_event
        elif self._graph_execution.aborted:
            abort_reason = "Workflow execution aborted by user command"
            if self._graph_execution.error:
                abort_reason = str(self._graph_execution.error)
            yield GraphRunAbortedEvent(
            aborted_event = GraphRunAbortedEvent(
                reason=abort_reason,
                outputs=self._graph_runtime_state.outputs,
            )
            self._event_manager.notify_layers(aborted_event)
            yield aborted_event
        elif self._graph_execution.has_error:
            if self._graph_execution.error:
                raise self._graph_execution.error

@@ -264,20 +270,26 @@
            outputs = self._graph_runtime_state.outputs
            exceptions_count = self._graph_execution.exceptions_count
            if exceptions_count > 0:
                yield GraphRunPartialSucceededEvent(
                partial_event = GraphRunPartialSucceededEvent(
                    exceptions_count=exceptions_count,
                    outputs=outputs,
                )
                self._event_manager.notify_layers(partial_event)
                yield partial_event
            else:
                yield GraphRunSucceededEvent(
                succeeded_event = GraphRunSucceededEvent(
                    outputs=outputs,
                )
                self._event_manager.notify_layers(succeeded_event)
                yield succeeded_event

        except Exception as e:
            yield GraphRunFailedEvent(
            failed_event = GraphRunFailedEvent(
                error=str(e),
                exceptions_count=self._graph_execution.exceptions_count,
            )
            self._event_manager.notify_layers(failed_event)
            yield failed_event
            raise

        finally:

@@ -299,8 +311,12 @@
        except Exception as e:
            logger.warning("Layer %s failed on_graph_start: %s", layer.__class__.__name__, e)

    def _start_execution(self) -> None:
    def _start_execution(self, *, resume: bool = False) -> None:
        """Start execution subsystems."""
        paused_nodes: list[str] = []
        if resume:
            paused_nodes = self._graph_runtime_state.consume_paused_nodes()

        # Start worker pool (it calculates initial workers internally)
        self._worker_pool.start()

@@ -309,10 +325,15 @@
            if node.execution_type == NodeExecutionType.RESPONSE:
                self._response_coordinator.register(node.id)

        # Enqueue root node
        root_node = self._graph.root_node
        self._state_manager.enqueue_node(root_node.id)
        self._state_manager.start_execution(root_node.id)
        if not resume:
            # Enqueue root node
            root_node = self._graph.root_node
            self._state_manager.enqueue_node(root_node.id)
            self._state_manager.start_execution(root_node.id)
        else:
            for node_id in paused_nodes:
                self._state_manager.enqueue_node(node_id)
                self._state_manager.start_execution(node_id)

        # Start dispatcher
        self._dispatcher.start()
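Note: resume detection is simply "the rehydrated execution was already started", in which case the paused flags are cleared and the previously paused nodes, rather than the root, are re-enqueued. A caller-side sketch; `run()` as the generator entry point is assumed from the yields above, and the constructor arguments are abbreviated:

    engine = GraphEngine(
        workflow_id=workflow_id,
        graph=graph,
        graph_runtime_state=runtime_state,
        command_channel=command_channel,
    )
    for event in engine.run():
        if isinstance(event, GraphRunPausedEvent):
            # Persist runtime_state here; building a new engine around the
            # rehydrated state later re-enters the run with is_resume True.
            break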
@@ -7,9 +9,9 @@ intercept and respond to GraphEngine events.

from abc import ABC, abstractmethod

from core.workflow.graph.graph_runtime_state_protocol import ReadOnlyGraphRuntimeState
from core.workflow.graph_engine.protocols.command_channel import CommandChannel
from core.workflow.graph_events import GraphEngineEvent
from core.workflow.runtime import ReadOnlyGraphRuntimeState


class GraphEngineLayer(ABC):


api/core/workflow/graph_engine/layers/persistence.py (new file, 410 lines)
@@ -0,0 +1,410 @@
"""Workflow persistence layer for GraphEngine.

This layer mirrors the former ``WorkflowCycleManager`` responsibilities by
listening to ``GraphEngineEvent`` instances directly and persisting workflow
and node execution state via the injected repositories.

The design keeps domain persistence concerns inside the engine thread, while
allowing presentation layers to remain read-only observers of repository
state.
"""

from collections.abc import Mapping
from dataclasses import dataclass
from datetime import datetime
from typing import Any, Union

from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity
from core.ops.entities.trace_entity import TraceTaskName
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
from core.workflow.entities import WorkflowExecution, WorkflowNodeExecution
from core.workflow.enums import (
    SystemVariableKey,
    WorkflowExecutionStatus,
    WorkflowNodeExecutionMetadataKey,
    WorkflowNodeExecutionStatus,
    WorkflowType,
)
from core.workflow.graph_engine.layers.base import GraphEngineLayer
from core.workflow.graph_events import (
    GraphEngineEvent,
    GraphRunAbortedEvent,
    GraphRunFailedEvent,
    GraphRunPartialSucceededEvent,
    GraphRunPausedEvent,
    GraphRunStartedEvent,
    GraphRunSucceededEvent,
    NodeRunExceptionEvent,
    NodeRunFailedEvent,
    NodeRunPauseRequestedEvent,
    NodeRunRetryEvent,
    NodeRunStartedEvent,
    NodeRunSucceededEvent,
)
from core.workflow.node_events import NodeRunResult
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from core.workflow.workflow_entry import WorkflowEntry
from libs.datetime_utils import naive_utc_now


@dataclass(slots=True)
class PersistenceWorkflowInfo:
    """Static workflow metadata required for persistence."""

    workflow_id: str
    workflow_type: WorkflowType
    version: str
    graph_data: Mapping[str, Any]


@dataclass(slots=True)
class _NodeRuntimeSnapshot:
    """Lightweight cache to keep node metadata across event phases."""

    node_id: str
    title: str
    predecessor_node_id: str | None
    iteration_id: str | None
    loop_id: str | None
    created_at: datetime


class WorkflowPersistenceLayer(GraphEngineLayer):
    """GraphEngine layer that persists workflow and node execution state."""

    def __init__(
        self,
        *,
        application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity],
        workflow_info: PersistenceWorkflowInfo,
        workflow_execution_repository: WorkflowExecutionRepository,
        workflow_node_execution_repository: WorkflowNodeExecutionRepository,
        trace_manager: TraceQueueManager | None = None,
    ) -> None:
        super().__init__()
        self._application_generate_entity = application_generate_entity
        self._workflow_info = workflow_info
        self._workflow_execution_repository = workflow_execution_repository
        self._workflow_node_execution_repository = workflow_node_execution_repository
        self._trace_manager = trace_manager

        self._workflow_execution: WorkflowExecution | None = None
        self._node_execution_cache: dict[str, WorkflowNodeExecution] = {}
        self._node_snapshots: dict[str, _NodeRuntimeSnapshot] = {}
        self._node_sequence: int = 0

    # ------------------------------------------------------------------
    # GraphEngineLayer lifecycle
    # ------------------------------------------------------------------
    def on_graph_start(self) -> None:
        self._workflow_execution = None
        self._node_execution_cache.clear()
        self._node_snapshots.clear()
        self._node_sequence = 0

    def on_event(self, event: GraphEngineEvent) -> None:
        if isinstance(event, GraphRunStartedEvent):
            self._handle_graph_run_started()
            return

        if isinstance(event, GraphRunSucceededEvent):
            self._handle_graph_run_succeeded(event)
            return

        if isinstance(event, GraphRunPartialSucceededEvent):
            self._handle_graph_run_partial_succeeded(event)
            return

        if isinstance(event, GraphRunFailedEvent):
            self._handle_graph_run_failed(event)
            return

        if isinstance(event, GraphRunAbortedEvent):
            self._handle_graph_run_aborted(event)
            return

        if isinstance(event, GraphRunPausedEvent):
            self._handle_graph_run_paused(event)
            return

        if isinstance(event, NodeRunStartedEvent):
            self._handle_node_started(event)
            return

        if isinstance(event, NodeRunRetryEvent):
            self._handle_node_retry(event)
            return

        if isinstance(event, NodeRunSucceededEvent):
            self._handle_node_succeeded(event)
            return

        if isinstance(event, NodeRunFailedEvent):
            self._handle_node_failed(event)
            return

        if isinstance(event, NodeRunExceptionEvent):
            self._handle_node_exception(event)
            return

        if isinstance(event, NodeRunPauseRequestedEvent):
            self._handle_node_pause_requested(event)

    def on_graph_end(self, error: Exception | None) -> None:
        return

    # ------------------------------------------------------------------
    # Graph-level handlers
    # ------------------------------------------------------------------
    def _handle_graph_run_started(self) -> None:
        execution_id = self._get_execution_id()
        workflow_execution = WorkflowExecution.new(
            id_=execution_id,
            workflow_id=self._workflow_info.workflow_id,
            workflow_type=self._workflow_info.workflow_type,
            workflow_version=self._workflow_info.version,
            graph=self._workflow_info.graph_data,
            inputs=self._prepare_workflow_inputs(),
            started_at=naive_utc_now(),
        )

        self._workflow_execution_repository.save(workflow_execution)
        self._workflow_execution = workflow_execution

    def _handle_graph_run_succeeded(self, event: GraphRunSucceededEvent) -> None:
        execution = self._get_workflow_execution()
        execution.outputs = event.outputs
        execution.status = WorkflowExecutionStatus.SUCCEEDED
        self._populate_completion_statistics(execution)

        self._workflow_execution_repository.save(execution)
        self._enqueue_trace_task(execution)

    def _handle_graph_run_partial_succeeded(self, event: GraphRunPartialSucceededEvent) -> None:
        execution = self._get_workflow_execution()
        execution.outputs = event.outputs
        execution.status = WorkflowExecutionStatus.PARTIAL_SUCCEEDED
        execution.exceptions_count = event.exceptions_count
        self._populate_completion_statistics(execution)

        self._workflow_execution_repository.save(execution)
        self._enqueue_trace_task(execution)

    def _handle_graph_run_failed(self, event: GraphRunFailedEvent) -> None:
        execution = self._get_workflow_execution()
        execution.status = WorkflowExecutionStatus.FAILED
        execution.error_message = event.error
        execution.exceptions_count = event.exceptions_count
        self._populate_completion_statistics(execution)

        self._fail_running_node_executions(error_message=event.error)
        self._workflow_execution_repository.save(execution)
        self._enqueue_trace_task(execution)

    def _handle_graph_run_aborted(self, event: GraphRunAbortedEvent) -> None:
        execution = self._get_workflow_execution()
        execution.status = WorkflowExecutionStatus.STOPPED
        execution.error_message = event.reason or "Workflow execution aborted"
        self._populate_completion_statistics(execution)

        self._fail_running_node_executions(error_message=execution.error_message or "")
        self._workflow_execution_repository.save(execution)
        self._enqueue_trace_task(execution)

    def _handle_graph_run_paused(self, event: GraphRunPausedEvent) -> None:
        execution = self._get_workflow_execution()
        execution.status = WorkflowExecutionStatus.PAUSED
        execution.error_message = event.reason or "Workflow execution paused"
        execution.outputs = event.outputs
        self._populate_completion_statistics(execution, update_finished=False)

        self._workflow_execution_repository.save(execution)

    # ------------------------------------------------------------------
    # Node-level handlers
    # ------------------------------------------------------------------
    def _handle_node_started(self, event: NodeRunStartedEvent) -> None:
        execution = self._get_workflow_execution()

        metadata = {
            WorkflowNodeExecutionMetadataKey.ITERATION_ID: event.in_iteration_id,
            WorkflowNodeExecutionMetadataKey.LOOP_ID: event.in_loop_id,
        }

        domain_execution = WorkflowNodeExecution(
            id=event.id,
            node_execution_id=event.id,
            workflow_id=execution.workflow_id,
            workflow_execution_id=execution.id_,
            predecessor_node_id=event.predecessor_node_id,
            index=self._next_node_sequence(),
            node_id=event.node_id,
            node_type=event.node_type,
            title=event.node_title,
            status=WorkflowNodeExecutionStatus.RUNNING,
            metadata=metadata,
            created_at=event.start_at,
        )

        self._node_execution_cache[event.id] = domain_execution
        self._workflow_node_execution_repository.save(domain_execution)

        snapshot = _NodeRuntimeSnapshot(
            node_id=event.node_id,
            title=event.node_title,
            predecessor_node_id=event.predecessor_node_id,
            iteration_id=event.in_iteration_id,
            loop_id=event.in_loop_id,
            created_at=event.start_at,
        )
        self._node_snapshots[event.id] = snapshot

    def _handle_node_retry(self, event: NodeRunRetryEvent) -> None:
        domain_execution = self._get_node_execution(event.id)
        domain_execution.status = WorkflowNodeExecutionStatus.RETRY
        domain_execution.error = event.error
        self._workflow_node_execution_repository.save(domain_execution)
        self._workflow_node_execution_repository.save_execution_data(domain_execution)

    def _handle_node_succeeded(self, event: NodeRunSucceededEvent) -> None:
        domain_execution = self._get_node_execution(event.id)
        self._update_node_execution(domain_execution, event.node_run_result, WorkflowNodeExecutionStatus.SUCCEEDED)

    def _handle_node_failed(self, event: NodeRunFailedEvent) -> None:
        domain_execution = self._get_node_execution(event.id)
        self._update_node_execution(
            domain_execution,
            event.node_run_result,
            WorkflowNodeExecutionStatus.FAILED,
            error=event.error,
        )

    def _handle_node_exception(self, event: NodeRunExceptionEvent) -> None:
        domain_execution = self._get_node_execution(event.id)
        self._update_node_execution(
            domain_execution,
            event.node_run_result,
            WorkflowNodeExecutionStatus.EXCEPTION,
            error=event.error,
        )

    def _handle_node_pause_requested(self, event: NodeRunPauseRequestedEvent) -> None:
        domain_execution = self._get_node_execution(event.id)
        self._update_node_execution(
            domain_execution,
            event.node_run_result,
            WorkflowNodeExecutionStatus.PAUSED,
            error=event.reason,
            update_outputs=False,
        )

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------
    def _get_execution_id(self) -> str:
        workflow_execution_id = self._system_variables().get(SystemVariableKey.WORKFLOW_EXECUTION_ID)
        if not workflow_execution_id:
            raise ValueError("workflow_execution_id must be provided in system variables for pause/resume flows")
        return str(workflow_execution_id)

    def _prepare_workflow_inputs(self) -> Mapping[str, Any]:
        inputs = {**self._application_generate_entity.inputs}
        for field_name, value in self._system_variables().items():
            if field_name == SystemVariableKey.CONVERSATION_ID.value:
                # Conversation IDs are tied to the current session; omit them so persisted
                # workflow inputs stay reusable without binding future runs to this conversation.
                continue
            inputs[f"sys.{field_name}"] = value
        handled = WorkflowEntry.handle_special_values(inputs)
        return handled or {}

    def _get_workflow_execution(self) -> WorkflowExecution:
        if self._workflow_execution is None:
            raise ValueError("workflow execution not initialized")
        return self._workflow_execution

    def _get_node_execution(self, node_execution_id: str) -> WorkflowNodeExecution:
        if node_execution_id not in self._node_execution_cache:
            raise ValueError(f"Node execution not found for id={node_execution_id}")
        return self._node_execution_cache[node_execution_id]

    def _next_node_sequence(self) -> int:
        self._node_sequence += 1
        return self._node_sequence

    def _populate_completion_statistics(self, execution: WorkflowExecution, *, update_finished: bool = True) -> None:
        if update_finished:
            execution.finished_at = naive_utc_now()
        runtime_state = self.graph_runtime_state
        if runtime_state is None:
            return
        execution.total_tokens = runtime_state.total_tokens
        execution.total_steps = runtime_state.node_run_steps
        execution.outputs = execution.outputs or runtime_state.outputs
        execution.exceptions_count = runtime_state.exceptions_count

    def _update_node_execution(
        self,
        domain_execution: WorkflowNodeExecution,
        node_result: NodeRunResult,
        status: WorkflowNodeExecutionStatus,
        *,
        error: str | None = None,
        update_outputs: bool = True,
    ) -> None:
        finished_at = naive_utc_now()
        snapshot = self._node_snapshots.get(domain_execution.id)
        start_at = snapshot.created_at if snapshot else domain_execution.created_at
        domain_execution.status = status
        domain_execution.finished_at = finished_at
        domain_execution.elapsed_time = max((finished_at - start_at).total_seconds(), 0.0)

        if error:
            domain_execution.error = error

        if update_outputs:
            domain_execution.update_from_mapping(
                inputs=node_result.inputs,
                process_data=node_result.process_data,
                outputs=node_result.outputs,
                metadata=node_result.metadata,
            )

        self._workflow_node_execution_repository.save(domain_execution)
        self._workflow_node_execution_repository.save_execution_data(domain_execution)

    def _fail_running_node_executions(self, *, error_message: str) -> None:
        now = naive_utc_now()
        for execution in self._node_execution_cache.values():
            if execution.status == WorkflowNodeExecutionStatus.RUNNING:
                execution.status = WorkflowNodeExecutionStatus.FAILED
                execution.error = error_message
                execution.finished_at = now
                execution.elapsed_time = max((now - execution.created_at).total_seconds(), 0.0)
                self._workflow_node_execution_repository.save(execution)

    def _enqueue_trace_task(self, execution: WorkflowExecution) -> None:
        if not self._trace_manager:
            return

        conversation_id = self._system_variables().get(SystemVariableKey.CONVERSATION_ID.value)
        external_trace_id = None
        if isinstance(self._application_generate_entity, (WorkflowAppGenerateEntity, AdvancedChatAppGenerateEntity)):
            external_trace_id = self._application_generate_entity.extras.get("external_trace_id")

        trace_task = TraceTask(
            TraceTaskName.WORKFLOW_TRACE,
            workflow_execution=execution,
            conversation_id=conversation_id,
            user_id=self._trace_manager.user_id,
            external_trace_id=external_trace_id,
        )
        self._trace_manager.add_trace_task(trace_task)

    def _system_variables(self) -> Mapping[str, Any]:
        runtime_state = self.graph_runtime_state
        if runtime_state is None:
            return {}
        return runtime_state.variable_pool.get_by_prefix(SYSTEM_VARIABLE_NODE_ID)
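Note: the layer is constructed once per run and handed the same repositories the generator previously threaded through the task pipeline. A wiring sketch; only the constructor is shown in this file, so the attachment mechanism and the `workflow` attributes below are assumptions:

    persistence_layer = WorkflowPersistenceLayer(
        application_generate_entity=application_generate_entity,
        workflow_info=PersistenceWorkflowInfo(
            workflow_id=workflow.id,
            workflow_type=WorkflowType.CHAT,   # assumed value
            version=workflow.version,
            graph_data=workflow.graph_dict,    # assumed attribute
        ),
        workflow_execution_repository=workflow_execution_repository,
        workflow_node_execution_repository=workflow_node_execution_repository,
        trace_manager=trace_manager,
    )

Since pause is not terminal, _handle_graph_run_paused deliberately skips finished_at (update_finished=False) and leaves node executions untouched, unlike the failed and aborted paths which fail any still-running node records.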
@@ -9,7 +9,7 @@ Supports stop, pause, and resume operations.
from typing import final

from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
from core.workflow.graph_engine.entities.commands import AbortCommand
from core.workflow.graph_engine.entities.commands import AbortCommand, GraphEngineCommand, PauseCommand
from extensions.ext_redis import redis_client

@@ -20,7 +20,7 @@ class GraphEngineManager:

    This class provides a simple interface for controlling workflow executions
    by sending commands through Redis channels, without user validation.
    Supports stop, pause, and resume operations.
    Supports stop and pause operations.
    """

    @staticmethod

@@ -32,19 +32,29 @@ class GraphEngineManager:
            task_id: The task ID of the workflow to stop
            reason: Optional reason for stopping (defaults to "User requested stop")
        """
        abort_command = AbortCommand(reason=reason or "User requested stop")
        GraphEngineManager._send_command(task_id, abort_command)

    @staticmethod
    def send_pause_command(task_id: str, reason: str | None = None) -> None:
        """Send a pause command to a running workflow."""
        pause_command = PauseCommand(reason=reason or "User requested pause")
        GraphEngineManager._send_command(task_id, pause_command)

    @staticmethod
    def _send_command(task_id: str, command: GraphEngineCommand) -> None:
        """Send a command to the workflow-specific Redis channel."""
        if not task_id:
            return

        # Create Redis channel for this task
        channel_key = f"workflow:{task_id}:commands"
        channel = RedisChannel(redis_client, channel_key)

        # Create and send abort command
        abort_command = AbortCommand(reason=reason or "User requested stop")

        try:
            channel.send_command(abort_command)
            channel.send_command(command)
        except Exception:
            # Silently fail if Redis is unavailable
            # The legacy stop flag mechanism will still work
            # The legacy control mechanisms will still work
            pass
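Note: with the shared _send_command helper, pausing a run from outside the engine process is a one-liner:

    GraphEngineManager.send_pause_command(task_id, reason="User requested pause")

Both stop and pause now funnel through the same "workflow:{task_id}:commands" Redis channel shown above, so adding further command types only requires a new command class and a matching public method.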
@@ -33,6 +33,12 @@ class Dispatcher:
    with timeout and completion detection.
    """

    _COMMAND_TRIGGER_EVENTS = (
        NodeRunSucceededEvent,
        NodeRunFailedEvent,
        NodeRunExceptionEvent,
    )

    def __init__(
        self,
        event_queue: queue.Queue[GraphNodeEventBase],

@@ -77,33 +83,41 @@ class Dispatcher:
        if self._thread and self._thread.is_alive():
            self._thread.join(timeout=10.0)

    _COMMAND_TRIGGER_EVENTS = (
        NodeRunSucceededEvent,
        NodeRunFailedEvent,
        NodeRunExceptionEvent,
    )

    def _dispatcher_loop(self) -> None:
        """Main dispatcher loop."""
        try:
            while not self._stop_event.is_set():
                # Check for scaling
                self._execution_coordinator.check_scaling()
                commands_checked = False
                should_check_commands = False
                should_break = False

                # Process events
                try:
                    event = self._event_queue.get(timeout=0.1)
                    # Route to the event handler
                    self._event_handler.dispatch(event)
                    if self._should_check_commands(event):
                        self._execution_coordinator.check_commands()
                    self._event_queue.task_done()
                except queue.Empty:
                    # Process commands even when no new events arrive so abort requests are not missed
                if self._execution_coordinator.is_execution_complete():
                    should_check_commands = True
                    should_break = True
                else:
                    # Check for scaling
                    self._execution_coordinator.check_scaling()

                    # Process events
                    try:
                        event = self._event_queue.get(timeout=0.1)
                        # Route to the event handler
                        self._event_handler.dispatch(event)
                        should_check_commands = self._should_check_commands(event)
                        self._event_queue.task_done()
                    except queue.Empty:
                        # Process commands even when no new events arrive so abort requests are not missed
                        should_check_commands = True
                        time.sleep(0.1)

                if should_check_commands and not commands_checked:
                    self._execution_coordinator.check_commands()
                    # Check if execution is complete
                    if self._execution_coordinator.is_execution_complete():
                        break
                    commands_checked = True

                if should_break:
                    if not commands_checked:
                        self._execution_coordinator.check_commands()
                    break

        except Exception as e:
            logger.exception("Dispatcher error")
@@ -2,17 +2,13 @@
Execution coordinator for managing overall workflow execution.
"""

from typing import TYPE_CHECKING, final
from typing import final

from ..command_processing import CommandProcessor
from ..domain import GraphExecution
from ..event_management import EventManager
from ..graph_state_manager import GraphStateManager
from ..worker_management import WorkerPool

if TYPE_CHECKING:
    from ..event_management import EventHandler


@final
class ExecutionCoordinator:

@@ -27,8 +23,6 @@ class ExecutionCoordinator:
        self,
        graph_execution: GraphExecution,
        state_manager: GraphStateManager,
        event_handler: "EventHandler",
        event_collector: EventManager,
        command_processor: CommandProcessor,
        worker_pool: WorkerPool,
    ) -> None:

@@ -38,15 +32,11 @@ class ExecutionCoordinator:
        Args:
            graph_execution: Graph execution aggregate
            state_manager: Unified state manager
            event_handler: Event handler registry for processing events
            event_collector: Event manager for collecting events
            command_processor: Processor for commands
            worker_pool: Pool of workers
        """
        self._graph_execution = graph_execution
        self._state_manager = state_manager
        self._event_handler = event_handler
        self._event_collector = event_collector
        self._command_processor = command_processor
        self._worker_pool = worker_pool

@@ -65,15 +55,24 @@ class ExecutionCoordinator:
        Returns:
            True if execution is complete
        """
        # Check if aborted or failed
        # Treat paused, aborted, or failed executions as terminal states
        if self._graph_execution.is_paused:
            return True

        if self._graph_execution.aborted or self._graph_execution.has_error:
            return True

        # Complete if no work remains
        return self._state_manager.is_execution_complete()

    @property
    def is_paused(self) -> bool:
        """Expose whether the underlying graph execution is paused."""
        return self._graph_execution.is_paused

    def mark_complete(self) -> None:
        """Mark execution as complete."""
        if self._graph_execution.is_paused:
            return
        if not self._graph_execution.completed:
            self._graph_execution.complete()

@@ -85,3 +84,21 @@ class ExecutionCoordinator:
            error: The error that caused failure
        """
        self._graph_execution.fail(error)

    def handle_pause_if_needed(self) -> None:
        """If the execution has been paused, stop workers immediately."""
        if not self._graph_execution.is_paused:
            return

        self._worker_pool.stop()
        self._state_manager.clear_executing()

    def handle_abort_if_needed(self) -> None:
        """If the execution has been aborted, stop workers immediately."""
        if not self._graph_execution.aborted:
            return

        self._worker_pool.stop()
        self._state_manager.clear_executing()
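Note: a pause ends the dispatch loop (is_execution_complete returns True) but, crucially, mark_complete refuses to flip a paused run to completed, which is what keeps the run resumable. The presumed call pattern around the new helper, since the actual call site is not part of this diff:

    if execution_coordinator.is_paused:
        # Stops workers and clears in-flight executions without completing the run.
        execution_coordinator.handle_pause_if_needed()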
@@ -14,11 +14,11 @@ from uuid import uuid4

from pydantic import BaseModel, Field

from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import NodeExecutionType, NodeState
from core.workflow.graph import Graph
from core.workflow.graph_events import NodeRunStreamChunkEvent, NodeRunSucceededEvent
from core.workflow.nodes.base.template import TextSegment, VariableSegment
from core.workflow.runtime import VariablePool

from .path import Path
from .session import ResponseSession

@@ -13,6 +13,7 @@ from .graph import (
    GraphRunAbortedEvent,
    GraphRunFailedEvent,
    GraphRunPartialSucceededEvent,
    GraphRunPausedEvent,
    GraphRunStartedEvent,
    GraphRunSucceededEvent,
)

@@ -37,6 +38,7 @@ from .loop import (
from .node import (
    NodeRunExceptionEvent,
    NodeRunFailedEvent,
    NodeRunPauseRequestedEvent,
    NodeRunRetrieverResourceEvent,
    NodeRunRetryEvent,
    NodeRunStartedEvent,

@@ -51,6 +53,7 @@ __all__ = [
    "GraphRunAbortedEvent",
    "GraphRunFailedEvent",
    "GraphRunPartialSucceededEvent",
    "GraphRunPausedEvent",
    "GraphRunStartedEvent",
    "GraphRunSucceededEvent",
    "NodeRunAgentLogEvent",

@@ -64,6 +67,7 @@ __all__ = [
    "NodeRunLoopNextEvent",
    "NodeRunLoopStartedEvent",
    "NodeRunLoopSucceededEvent",
    "NodeRunPauseRequestedEvent",
    "NodeRunRetrieverResourceEvent",
    "NodeRunRetryEvent",
    "NodeRunStartedEvent",

@@ -8,7 +8,12 @@ class GraphRunStartedEvent(BaseGraphEvent):


class GraphRunSucceededEvent(BaseGraphEvent):
    outputs: dict[str, object] = Field(default_factory=dict)
    """Event emitted when a run completes successfully with final outputs."""

    outputs: dict[str, object] = Field(
        default_factory=dict,
        description="Final workflow outputs keyed by output selector.",
    )


class GraphRunFailedEvent(BaseGraphEvent):

@@ -17,12 +22,30 @@ class GraphRunFailedEvent(BaseGraphEvent):


class GraphRunPartialSucceededEvent(BaseGraphEvent):
    """Event emitted when a run finishes with partial success and failures."""

    exceptions_count: int = Field(..., description="exception count")
    outputs: dict[str, object] = Field(default_factory=dict)
    outputs: dict[str, object] = Field(
        default_factory=dict,
        description="Outputs that were materialised before failures occurred.",
    )


class GraphRunAbortedEvent(BaseGraphEvent):
    """Event emitted when a graph run is aborted by user command."""

    reason: str | None = Field(default=None, description="reason for abort")
    outputs: dict[str, object] = Field(default_factory=dict, description="partial outputs if any")
    outputs: dict[str, object] = Field(
        default_factory=dict,
        description="Outputs produced before the abort was requested.",
    )


class GraphRunPausedEvent(BaseGraphEvent):
    """Event emitted when a graph run is paused by user command."""

    reason: str | None = Field(default=None, description="reason for pause")
    outputs: dict[str, object] = Field(
        default_factory=dict,
        description="Outputs available to the client while the run is paused.",
    )
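Note: GraphRunPausedEvent carries both the reason and the outputs produced so far, so consumers can distinguish a pause from a failure purely by event type. A consumer-side sketch; `run()` is assumed as in the engine changes above, and the two helpers are hypothetical hooks:

    for event in engine.run():
        if isinstance(event, GraphRunPausedEvent):
            save_partial_outputs(event.outputs)  # hypothetical persistence hook
            notify_user(event.reason)            # hypothetical UI hook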
||||
|
||||
@@ -51,3 +51,7 @@ class NodeRunExceptionEvent(GraphNodeEventBase):
|
||||
class NodeRunRetryEvent(NodeRunStartedEvent):
|
||||
error: str = Field(..., description="error")
|
||||
retry_index: int = Field(..., description="which retry attempt is about to be performed")
|
||||
|
||||
|
||||
class NodeRunPauseRequestedEvent(GraphNodeEventBase):
|
||||
reason: str | None = Field(default=None, description="Optional pause reason")
|
||||
|
||||
@@ -14,6 +14,7 @@ from .loop import (
|
||||
)
|
||||
from .node import (
|
||||
ModelInvokeCompletedEvent,
|
||||
PauseRequestedEvent,
|
||||
RunRetrieverResourceEvent,
|
||||
RunRetryEvent,
|
||||
StreamChunkEvent,
|
||||
@@ -33,6 +34,7 @@ __all__ = [
|
||||
"ModelInvokeCompletedEvent",
|
||||
"NodeEventBase",
|
||||
"NodeRunResult",
|
||||
"PauseRequestedEvent",
|
||||
"RunRetrieverResourceEvent",
|
||||
"RunRetryEvent",
|
||||
"StreamChunkEvent",
|
||||
|
||||
@@ -40,3 +40,7 @@ class StreamChunkEvent(NodeEventBase):

class StreamCompletedEvent(NodeEventBase):
    node_run_result: NodeRunResult = Field(..., description="run result")


class PauseRequestedEvent(NodeEventBase):
    reason: str | None = Field(default=None, description="Optional pause reason")
@@ -25,7 +25,6 @@ from core.tools.entities.tool_entities import (
from core.tools.tool_manager import ToolManager
from core.tools.utils.message_transformer import ToolFileMessageTransformer
from core.variables.segments import ArrayFileSegment, StringSegment
from core.workflow.entities import VariablePool
from core.workflow.enums import (
    ErrorStrategy,
    NodeType,
@@ -44,6 +43,7 @@ from core.workflow.nodes.agent.entities import AgentNodeData, AgentOldVersionMod
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
from core.workflow.runtime import VariablePool
from extensions.ext_database import db
from factories import file_factory
from factories.agent_factory import get_plugin_agent_strategy
@@ -6,7 +6,7 @@ from typing import Any, ClassVar
from uuid import uuid4

from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities import AgentNodeStrategyInit, GraphInitParams, GraphRuntimeState
from core.workflow.entities import AgentNodeStrategyInit, GraphInitParams
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeState, NodeType, WorkflowNodeExecutionStatus
from core.workflow.graph_events import (
    GraphNodeEventBase,
@@ -20,6 +20,7 @@ from core.workflow.graph_events import (
    NodeRunLoopNextEvent,
    NodeRunLoopStartedEvent,
    NodeRunLoopSucceededEvent,
    NodeRunPauseRequestedEvent,
    NodeRunRetrieverResourceEvent,
    NodeRunStartedEvent,
    NodeRunStreamChunkEvent,
@@ -37,10 +38,12 @@ from core.workflow.node_events import (
    LoopSucceededEvent,
    NodeEventBase,
    NodeRunResult,
    PauseRequestedEvent,
    RunRetrieverResourceEvent,
    StreamChunkEvent,
    StreamCompletedEvent,
)
from core.workflow.runtime import GraphRuntimeState
from libs.datetime_utils import naive_utc_now
from models.enums import UserFrom
@@ -385,6 +388,16 @@ class Node:
                f"Node {self._node_id} does not support status {event.node_run_result.status}"
            )

    @_dispatch.register
    def _(self, event: PauseRequestedEvent) -> NodeRunPauseRequestedEvent:
        return NodeRunPauseRequestedEvent(
            id=self._node_execution_id,
            node_id=self._node_id,
            node_type=self.node_type,
            node_run_result=NodeRunResult(status=WorkflowNodeExecutionStatus.PAUSED),
            reason=event.reason,
        )

    @_dispatch.register
    def _(self, event: AgentLogEvent) -> NodeRunAgentLogEvent:
        return NodeRunAgentLogEvent(
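The `@_dispatch.register` handlers above follow the standard library's single-dispatch pattern: one converter per node-event type, selected by the runtime type of `event`. A standalone sketch of that mechanism, assuming `_dispatch` is a `functools.singledispatchmethod` (the diff does not show its definition, so this illustrates the pattern rather than the project's exact wiring):

from functools import singledispatchmethod

class PauseRequested: ...
class AgentLog: ...

class Dispatcher:
    @singledispatchmethod
    def handle(self, event: object) -> str:
        raise NotImplementedError(type(event).__name__)

    @handle.register
    def _(self, event: PauseRequested) -> str:
        # Mirrors the PauseRequestedEvent -> NodeRunPauseRequestedEvent mapping.
        return "pause"

    @handle.register
    def _(self, event: AgentLog) -> str:
        return "agent-log"

assert Dispatcher().handle(PauseRequested()) == "pause"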
@@ -19,7 +19,6 @@ from core.file.enums import FileTransferMethod, FileType
from core.plugin.impl.exc import PluginDaemonClientSideError
from core.variables.segments import ArrayAnySegment
from core.variables.variables import ArrayAnyVariable
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType, SystemVariableKey
from core.workflow.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent
@@ -27,6 +26,7 @@ from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
from core.workflow.nodes.tool.exc import ToolFileError
from core.workflow.runtime import VariablePool
from extensions.ext_database import db
from factories import file_factory
from models.model import UploadFile
@@ -15,7 +15,7 @@ from core.file import file_manager
from core.file.enums import FileTransferMethod
from core.helper import ssrf_proxy
from core.variables.segments import ArrayFileSegment, FileSegment
from core.workflow.entities import VariablePool
from core.workflow.runtime import VariablePool

from .entities import (
    HttpRequestNodeAuthorization,
3 api/core/workflow/nodes/human_input/__init__.py Normal file
@@ -0,0 +1,3 @@
from .human_input_node import HumanInputNode

__all__ = ["HumanInputNode"]
10 api/core/workflow/nodes/human_input/entities.py Normal file
@@ -0,0 +1,10 @@
from pydantic import Field

from core.workflow.nodes.base import BaseNodeData


class HumanInputNodeData(BaseNodeData):
    """Configuration schema for the HumanInput node."""

    required_variables: list[str] = Field(default_factory=list)
    pause_reason: str | None = Field(default=None)
132 api/core/workflow/nodes/human_input/human_input_node.py Normal file
@@ -0,0 +1,132 @@
from collections.abc import Mapping
from typing import Any

from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult, PauseRequestedEvent
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node

from .entities import HumanInputNodeData


class HumanInputNode(Node):
    node_type = NodeType.HUMAN_INPUT
    execution_type = NodeExecutionType.BRANCH

    _BRANCH_SELECTION_KEYS: tuple[str, ...] = (
        "edge_source_handle",
        "edgeSourceHandle",
        "source_handle",
        "selected_branch",
        "selectedBranch",
        "branch",
        "branch_id",
        "branchId",
        "handle",
    )

    _node_data: HumanInputNodeData

    def init_node_data(self, data: Mapping[str, Any]) -> None:
        self._node_data = HumanInputNodeData(**data)

    def get_base_node_data(self) -> BaseNodeData:
        return self._node_data

    @classmethod
    def version(cls) -> str:
        return "1"

    def _get_error_strategy(self) -> ErrorStrategy | None:
        return self._node_data.error_strategy

    def _get_retry_config(self) -> RetryConfig:
        return self._node_data.retry_config

    def _get_title(self) -> str:
        return self._node_data.title

    def _get_description(self) -> str | None:
        return self._node_data.desc

    def _get_default_value_dict(self) -> dict[str, Any]:
        return self._node_data.default_value_dict

    def _run(self):  # type: ignore[override]
        if self._is_completion_ready():
            branch_handle = self._resolve_branch_selection()
            return NodeRunResult(
                status=WorkflowNodeExecutionStatus.SUCCEEDED,
                outputs={},
                edge_source_handle=branch_handle or "source",
            )

        return self._pause_generator()

    def _pause_generator(self):
        yield PauseRequestedEvent(reason=self._node_data.pause_reason)

    def _is_completion_ready(self) -> bool:
        """Determine whether all required inputs are satisfied."""

        if not self._node_data.required_variables:
            return False

        variable_pool = self.graph_runtime_state.variable_pool

        for selector_str in self._node_data.required_variables:
            parts = selector_str.split(".")
            if len(parts) != 2:
                return False
            segment = variable_pool.get(parts)
            if segment is None:
                return False

        return True

    def _resolve_branch_selection(self) -> str | None:
        """Determine the branch handle selected by human input if available."""

        variable_pool = self.graph_runtime_state.variable_pool

        for key in self._BRANCH_SELECTION_KEYS:
            handle = self._extract_branch_handle(variable_pool.get((self.id, key)))
            if handle:
                return handle

        default_values = self._node_data.default_value_dict
        for key in self._BRANCH_SELECTION_KEYS:
            handle = self._normalize_branch_value(default_values.get(key))
            if handle:
                return handle

        return None

    @staticmethod
    def _extract_branch_handle(segment: Any) -> str | None:
        if segment is None:
            return None

        candidate = getattr(segment, "to_object", None)
        raw_value = candidate() if callable(candidate) else getattr(segment, "value", None)
        if raw_value is None:
            return None

        return HumanInputNode._normalize_branch_value(raw_value)

    @staticmethod
    def _normalize_branch_value(value: Any) -> str | None:
        if value is None:
            return None

        if isinstance(value, str):
            stripped = value.strip()
            return stripped or None

        if isinstance(value, Mapping):
            for key in ("handle", "edge_source_handle", "edgeSourceHandle", "branch", "id", "value"):
                candidate = value.get(key)
                if isinstance(candidate, str) and candidate:
                    return candidate

        return None
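The branch-selection logic above accepts several spellings of the same intent: snake_case or camelCase keys, bare string handles, or a mapping wrapping the handle. A standalone restatement of those normalization rules, for illustration only; the authoritative logic is `HumanInputNode._normalize_branch_value`:

from collections.abc import Mapping
from typing import Any

def normalize_branch_value(value: Any) -> str | None:
    """Restatement of the node's normalization rules, kept in sync by hand."""
    if value is None:
        return None
    if isinstance(value, str):
        stripped = value.strip()
        return stripped or None  # empty or whitespace-only means no selection
    if isinstance(value, Mapping):
        for key in ("handle", "edge_source_handle", "edgeSourceHandle", "branch", "id", "value"):
            candidate = value.get(key)
            if isinstance(candidate, str) and candidate:
                return candidate
    return None

assert normalize_branch_value("  approve ") == "approve"
assert normalize_branch_value({"handle": "reject"}) == "reject"
assert normalize_branch_value("   ") is None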
@@ -3,12 +3,12 @@ from typing import Any, Literal

from typing_extensions import deprecated

from core.workflow.entities import VariablePool
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.if_else.entities import IfElseNodeData
from core.workflow.runtime import VariablePool
from core.workflow.utils.condition.entities import Condition
from core.workflow.utils.condition.processor import ConditionProcessor
@@ -12,7 +12,6 @@ from core.variables import IntegerVariable, NoneSegment
from core.variables.segments import ArrayAnySegment, ArraySegment
from core.variables.variables import VariableUnion
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
from core.workflow.entities import VariablePool
from core.workflow.enums import (
    ErrorStrategy,
    NodeExecutionType,
@@ -38,6 +37,7 @@ from core.workflow.node_events import (
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.iteration.entities import ErrorHandleMode, IterationNodeData
from core.workflow.runtime import VariablePool
from libs.datetime_utils import naive_utc_now
from libs.flask_utils import preserve_flask_contexts

@@ -557,11 +557,12 @@ class IterationNode(Node):

    def _create_graph_engine(self, index: int, item: object):
        # Import dependencies
        from core.workflow.entities import GraphInitParams, GraphRuntimeState
        from core.workflow.entities import GraphInitParams
        from core.workflow.graph import Graph
        from core.workflow.graph_engine import GraphEngine
        from core.workflow.graph_engine.command_channels import InMemoryChannel
        from core.workflow.nodes.node_factory import DifyNodeFactory
        from core.workflow.runtime import GraphRuntimeState

        # Create GraphInitParams from node attributes
        graph_init_params = GraphInitParams(
@@ -9,13 +9,13 @@ from sqlalchemy import func, select
from core.app.entities.app_invoke_entities import InvokeFrom
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType, SystemVariableKey
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.base.template import Template
from core.workflow.runtime import VariablePool
from extensions.ext_database import db
from models.dataset import Dataset, Document, DocumentSegment
@@ -67,7 +67,7 @@ from .exc import (

if TYPE_CHECKING:
    from core.file.models import File
    from core.workflow.entities import GraphRuntimeState
    from core.workflow.runtime import GraphRuntimeState

logger = logging.getLogger(__name__)
@@ -15,9 +15,9 @@ from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.variables.segments import ArrayAnySegment, ArrayFileSegment, FileSegment, NoneSegment, StringSegment
from core.workflow.entities import VariablePool
from core.workflow.enums import SystemVariableKey
from core.workflow.nodes.llm.entities import ModelConfig
from core.workflow.runtime import VariablePool
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models.model import Conversation
@@ -52,7 +52,7 @@ from core.variables import (
    StringSegment,
)
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
from core.workflow.entities import GraphInitParams, VariablePool
from core.workflow.entities import GraphInitParams
from core.workflow.enums import (
    ErrorStrategy,
    NodeType,
@@ -71,6 +71,7 @@ from core.workflow.node_events import (
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig, VariableSelector
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
from core.workflow.runtime import VariablePool

from . import llm_utils
from .entities import (
@@ -93,7 +94,7 @@ from .file_saver import FileSaverImpl, LLMFileSaver

if TYPE_CHECKING:
    from core.file.models import File
    from core.workflow.entities import GraphRuntimeState
    from core.workflow.runtime import GraphRuntimeState

logger = logging.getLogger(__name__)
@@ -406,11 +406,12 @@ class LoopNode(Node):

    def _create_graph_engine(self, start_at: datetime, root_node_id: str):
        # Import dependencies
        from core.workflow.entities import GraphInitParams, GraphRuntimeState
        from core.workflow.entities import GraphInitParams
        from core.workflow.graph import Graph
        from core.workflow.graph_engine import GraphEngine
        from core.workflow.graph_engine.command_channels import InMemoryChannel
        from core.workflow.nodes.node_factory import DifyNodeFactory
        from core.workflow.runtime import GraphRuntimeState

        # Create GraphInitParams from node attributes
        graph_init_params = GraphInitParams(
@@ -10,7 +10,8 @@ from libs.typing import is_str, is_str_dict
from .node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING

if TYPE_CHECKING:
    from core.workflow.entities import GraphInitParams, GraphRuntimeState
    from core.workflow.entities import GraphInitParams
    from core.workflow.runtime import GraphRuntimeState


@final
@@ -9,6 +9,7 @@ from core.workflow.nodes.datasource.datasource_node import DatasourceNode
from core.workflow.nodes.document_extractor import DocumentExtractorNode
from core.workflow.nodes.end.end_node import EndNode
from core.workflow.nodes.http_request import HttpRequestNode
from core.workflow.nodes.human_input import HumanInputNode
from core.workflow.nodes.if_else import IfElseNode
from core.workflow.nodes.iteration import IterationNode, IterationStartNode
from core.workflow.nodes.knowledge_index import KnowledgeIndexNode
@@ -134,6 +135,10 @@ NODE_TYPE_CLASSES_MAPPING: Mapping[NodeType, Mapping[str, type[Node]]] = {
        "2": AgentNode,
        "1": AgentNode,
    },
    NodeType.HUMAN_INPUT: {
        LATEST_VERSION: HumanInputNode,
        "1": HumanInputNode,
    },
    NodeType.DATASOURCE: {
        LATEST_VERSION: DatasourceNode,
        "1": DatasourceNode,
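The mapping registers each node type under explicit version strings plus a LATEST_VERSION alias, so old graph definitions keep resolving while new ones pick up the newest class. A minimal sketch of that lookup pattern (hypothetical helper and stub class names; the real resolution lives in the node factory):

LATEST_VERSION = "latest"  # assumption: mirrors the constant imported from node_mapping

class HumanInputNodeStub: ...

NODE_CLASSES = {
    "human-input": {LATEST_VERSION: HumanInputNodeStub, "1": HumanInputNodeStub},
}

def resolve_node_class(node_type: str, version: str | None = None):
    versions = NODE_CLASSES[node_type]
    # Exact version first, otherwise fall back to the latest registration.
    return versions.get(version or LATEST_VERSION, versions[LATEST_VERSION])

assert resolve_node_class("human-input", "1") is HumanInputNodeStub
assert resolve_node_class("human-input") is HumanInputNodeStub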
@@ -27,13 +27,13 @@ from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, Comp
from core.prompt.simple_prompt_transform import ModelMode
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.variables.types import ArrayValidation, SegmentType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base import variable_template_parser
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.llm import ModelConfig, llm_utils
from core.workflow.runtime import VariablePool
from factories.variable_factory import build_segment_with_type

from .entities import ParameterExtractorNodeData

@@ -41,7 +41,7 @@ from .template_prompts import (

if TYPE_CHECKING:
    from core.file.models import File
    from core.workflow.entities import GraphRuntimeState
    from core.workflow.runtime import GraphRuntimeState


class QuestionClassifierNode(Node):
@@ -36,7 +36,7 @@ from .exc import (
)

if TYPE_CHECKING:
    from core.workflow.entities import VariablePool
    from core.workflow.runtime import VariablePool


class ToolNode(Node):
@@ -18,7 +18,7 @@ from ..common.impl import conversation_variable_updater_factory
from .node_data import VariableAssignerData, WriteMode

if TYPE_CHECKING:
    from core.workflow.entities import GraphRuntimeState
    from core.workflow.runtime import GraphRuntimeState


_CONV_VAR_UPDATER_FACTORY: TypeAlias = Callable[[], ConversationVariableUpdater]
14 api/core/workflow/runtime/__init__.py Normal file
@@ -0,0 +1,14 @@
from .graph_runtime_state import GraphRuntimeState
from .graph_runtime_state_protocol import ReadOnlyGraphRuntimeState, ReadOnlyVariablePool
from .read_only_wrappers import ReadOnlyGraphRuntimeStateWrapper, ReadOnlyVariablePoolWrapper
from .variable_pool import VariablePool, VariableValue

__all__ = [
    "GraphRuntimeState",
    "ReadOnlyGraphRuntimeState",
    "ReadOnlyGraphRuntimeStateWrapper",
    "ReadOnlyVariablePool",
    "ReadOnlyVariablePoolWrapper",
    "VariablePool",
    "VariableValue",
]
393 api/core/workflow/runtime/graph_runtime_state.py Normal file
@@ -0,0 +1,393 @@
from __future__ import annotations

import importlib
import json
from collections.abc import Mapping, Sequence
from collections.abc import Mapping as TypingMapping
from copy import deepcopy
from typing import Any, Protocol

from pydantic.json import pydantic_encoder

from core.model_runtime.entities.llm_entities import LLMUsage
from core.workflow.runtime.variable_pool import VariablePool


class ReadyQueueProtocol(Protocol):
    """Structural interface required from ready queue implementations."""

    def put(self, item: str) -> None:
        """Enqueue the identifier of a node that is ready to run."""
        ...

    def get(self, timeout: float | None = None) -> str:
        """Return the next node identifier, blocking until available or timeout expires."""
        ...

    def task_done(self) -> None:
        """Signal that the most recently dequeued node has completed processing."""
        ...

    def empty(self) -> bool:
        """Return True when the queue contains no pending nodes."""
        ...

    def qsize(self) -> int:
        """Approximate the number of pending nodes awaiting execution."""
        ...

    def dumps(self) -> str:
        """Serialize the queue contents for persistence."""
        ...

    def loads(self, data: str) -> None:
        """Restore the queue contents from a serialized payload."""
        ...


class GraphExecutionProtocol(Protocol):
    """Structural interface for graph execution aggregate."""

    workflow_id: str
    started: bool
    completed: bool
    aborted: bool
    error: Exception | None
    exceptions_count: int

    def start(self) -> None:
        """Transition execution into the running state."""
        ...

    def complete(self) -> None:
        """Mark execution as successfully completed."""
        ...

    def abort(self, reason: str) -> None:
        """Abort execution in response to an external stop request."""
        ...

    def fail(self, error: Exception) -> None:
        """Record an unrecoverable error and end execution."""
        ...

    def dumps(self) -> str:
        """Serialize execution state into a JSON payload."""
        ...

    def loads(self, data: str) -> None:
        """Restore execution state from a previously serialized payload."""
        ...


class ResponseStreamCoordinatorProtocol(Protocol):
    """Structural interface for response stream coordinator."""

    def register(self, response_node_id: str) -> None:
        """Register a response node so its outputs can be streamed."""
        ...

    def loads(self, data: str) -> None:
        """Restore coordinator state from a serialized payload."""
        ...

    def dumps(self) -> str:
        """Serialize coordinator state for persistence."""
        ...


class GraphProtocol(Protocol):
    """Structural interface required from graph instances attached to the runtime state."""

    nodes: TypingMapping[str, object]
    edges: TypingMapping[str, object]
    root_node: object

    def get_outgoing_edges(self, node_id: str) -> Sequence[object]: ...


class GraphRuntimeState:
    """Mutable runtime state shared across graph execution components."""

    def __init__(
        self,
        *,
        variable_pool: VariablePool,
        start_at: float,
        total_tokens: int = 0,
        llm_usage: LLMUsage | None = None,
        outputs: dict[str, object] | None = None,
        node_run_steps: int = 0,
        ready_queue: ReadyQueueProtocol | None = None,
        graph_execution: GraphExecutionProtocol | None = None,
        response_coordinator: ResponseStreamCoordinatorProtocol | None = None,
        graph: GraphProtocol | None = None,
    ) -> None:
        self._variable_pool = variable_pool
        self._start_at = start_at

        if total_tokens < 0:
            raise ValueError("total_tokens must be non-negative")
        self._total_tokens = total_tokens

        self._llm_usage = (llm_usage or LLMUsage.empty_usage()).model_copy()
        self._outputs = deepcopy(outputs) if outputs is not None else {}

        if node_run_steps < 0:
            raise ValueError("node_run_steps must be non-negative")
        self._node_run_steps = node_run_steps

        self._graph: GraphProtocol | None = None

        self._ready_queue = ready_queue
        self._graph_execution = graph_execution
        self._response_coordinator = response_coordinator
        self._pending_response_coordinator_dump: str | None = None
        self._pending_graph_execution_workflow_id: str | None = None
        self._paused_nodes: set[str] = set()

        if graph is not None:
            self.attach_graph(graph)

    # ------------------------------------------------------------------
    # Context binding helpers
    # ------------------------------------------------------------------
    def attach_graph(self, graph: GraphProtocol) -> None:
        """Attach the materialized graph to the runtime state."""
        if self._graph is not None and self._graph is not graph:
            raise ValueError("GraphRuntimeState already attached to a different graph instance")

        self._graph = graph

        if self._response_coordinator is None:
            self._response_coordinator = self._build_response_coordinator(graph)

        if self._pending_response_coordinator_dump is not None and self._response_coordinator is not None:
            self._response_coordinator.loads(self._pending_response_coordinator_dump)
            self._pending_response_coordinator_dump = None

    def configure(self, *, graph: GraphProtocol | None = None) -> None:
        """Ensure core collaborators are initialized with the provided context."""
        if graph is not None:
            self.attach_graph(graph)

        # Ensure collaborators are instantiated
        _ = self.ready_queue
        _ = self.graph_execution
        if self._graph is not None:
            _ = self.response_coordinator

    # ------------------------------------------------------------------
    # Primary collaborators
    # ------------------------------------------------------------------
    @property
    def variable_pool(self) -> VariablePool:
        return self._variable_pool

    @property
    def ready_queue(self) -> ReadyQueueProtocol:
        if self._ready_queue is None:
            self._ready_queue = self._build_ready_queue()
        return self._ready_queue

    @property
    def graph_execution(self) -> GraphExecutionProtocol:
        if self._graph_execution is None:
            self._graph_execution = self._build_graph_execution()
        return self._graph_execution

    @property
    def response_coordinator(self) -> ResponseStreamCoordinatorProtocol:
        if self._response_coordinator is None:
            if self._graph is None:
                raise ValueError("Graph must be attached before accessing response coordinator")
            self._response_coordinator = self._build_response_coordinator(self._graph)
        return self._response_coordinator

    # ------------------------------------------------------------------
    # Scalar state
    # ------------------------------------------------------------------
    @property
    def start_at(self) -> float:
        return self._start_at

    @start_at.setter
    def start_at(self, value: float) -> None:
        self._start_at = value

    @property
    def total_tokens(self) -> int:
        return self._total_tokens

    @total_tokens.setter
    def total_tokens(self, value: int) -> None:
        if value < 0:
            raise ValueError("total_tokens must be non-negative")
        self._total_tokens = value

    @property
    def llm_usage(self) -> LLMUsage:
        return self._llm_usage.model_copy()

    @llm_usage.setter
    def llm_usage(self, value: LLMUsage) -> None:
        self._llm_usage = value.model_copy()

    @property
    def outputs(self) -> dict[str, Any]:
        return deepcopy(self._outputs)

    @outputs.setter
    def outputs(self, value: dict[str, Any]) -> None:
        self._outputs = deepcopy(value)

    def set_output(self, key: str, value: object) -> None:
        self._outputs[key] = deepcopy(value)

    def get_output(self, key: str, default: object = None) -> object:
        return deepcopy(self._outputs.get(key, default))

    def update_outputs(self, updates: dict[str, object]) -> None:
        for key, value in updates.items():
            self._outputs[key] = deepcopy(value)

    @property
    def node_run_steps(self) -> int:
        return self._node_run_steps

    @node_run_steps.setter
    def node_run_steps(self, value: int) -> None:
        if value < 0:
            raise ValueError("node_run_steps must be non-negative")
        self._node_run_steps = value

    def increment_node_run_steps(self) -> None:
        self._node_run_steps += 1

    def add_tokens(self, tokens: int) -> None:
        if tokens < 0:
            raise ValueError("tokens must be non-negative")
        self._total_tokens += tokens

    # ------------------------------------------------------------------
    # Serialization
    # ------------------------------------------------------------------
    def dumps(self) -> str:
        """Serialize runtime state into a JSON string."""

        snapshot: dict[str, Any] = {
            "version": "1.0",
            "start_at": self._start_at,
            "total_tokens": self._total_tokens,
            "node_run_steps": self._node_run_steps,
            "llm_usage": self._llm_usage.model_dump(mode="json"),
            "outputs": self.outputs,
            "variable_pool": self.variable_pool.model_dump(mode="json"),
            "ready_queue": self.ready_queue.dumps(),
            "graph_execution": self.graph_execution.dumps(),
            "paused_nodes": list(self._paused_nodes),
        }

        if self._response_coordinator is not None and self._graph is not None:
            snapshot["response_coordinator"] = self._response_coordinator.dumps()

        return json.dumps(snapshot, default=pydantic_encoder)

    def loads(self, data: str | Mapping[str, Any]) -> None:
        """Restore runtime state from a serialized snapshot."""

        payload: dict[str, Any]
        if isinstance(data, str):
            payload = json.loads(data)
        else:
            payload = dict(data)

        version = payload.get("version")
        if version != "1.0":
            raise ValueError(f"Unsupported GraphRuntimeState snapshot version: {version}")

        self._start_at = float(payload.get("start_at", 0.0))
        total_tokens = int(payload.get("total_tokens", 0))
        if total_tokens < 0:
            raise ValueError("total_tokens must be non-negative")
        self._total_tokens = total_tokens

        node_run_steps = int(payload.get("node_run_steps", 0))
        if node_run_steps < 0:
            raise ValueError("node_run_steps must be non-negative")
        self._node_run_steps = node_run_steps

        llm_usage_payload = payload.get("llm_usage", {})
        self._llm_usage = LLMUsage.model_validate(llm_usage_payload)

        self._outputs = deepcopy(payload.get("outputs", {}))

        variable_pool_payload = payload.get("variable_pool")
        if variable_pool_payload is not None:
            self._variable_pool = VariablePool.model_validate(variable_pool_payload)

        ready_queue_payload = payload.get("ready_queue")
        if ready_queue_payload is not None:
            self._ready_queue = self._build_ready_queue()
            self._ready_queue.loads(ready_queue_payload)
        else:
            self._ready_queue = None

        graph_execution_payload = payload.get("graph_execution")
        self._graph_execution = None
        self._pending_graph_execution_workflow_id = None
        if graph_execution_payload is not None:
            try:
                execution_payload = json.loads(graph_execution_payload)
                self._pending_graph_execution_workflow_id = execution_payload.get("workflow_id")
            except (json.JSONDecodeError, TypeError, AttributeError):
                self._pending_graph_execution_workflow_id = None
            self.graph_execution.loads(graph_execution_payload)

        response_payload = payload.get("response_coordinator")
        if response_payload is not None:
            if self._graph is not None:
                self.response_coordinator.loads(response_payload)
            else:
                self._pending_response_coordinator_dump = response_payload
        else:
            self._pending_response_coordinator_dump = None
            self._response_coordinator = None

        paused_nodes_payload = payload.get("paused_nodes", [])
        self._paused_nodes = set(map(str, paused_nodes_payload))

    def register_paused_node(self, node_id: str) -> None:
        """Record a node that should resume when execution is continued."""

        self._paused_nodes.add(node_id)

    def consume_paused_nodes(self) -> list[str]:
        """Retrieve and clear the list of paused nodes awaiting resume."""

        nodes = list(self._paused_nodes)
        self._paused_nodes.clear()
        return nodes

    # ------------------------------------------------------------------
    # Builders
    # ------------------------------------------------------------------
    def _build_ready_queue(self) -> ReadyQueueProtocol:
        # Import lazily to avoid breaching architecture boundaries enforced by import-linter.
        module = importlib.import_module("core.workflow.graph_engine.ready_queue")
        in_memory_cls = module.InMemoryReadyQueue
        return in_memory_cls()

    def _build_graph_execution(self) -> GraphExecutionProtocol:
        # Lazily import to keep the runtime domain decoupled from graph_engine modules.
        module = importlib.import_module("core.workflow.graph_engine.domain.graph_execution")
        graph_execution_cls = module.GraphExecution
        workflow_id = self._pending_graph_execution_workflow_id or ""
        self._pending_graph_execution_workflow_id = None
        return graph_execution_cls(workflow_id=workflow_id)

    def _build_response_coordinator(self, graph: GraphProtocol) -> ResponseStreamCoordinatorProtocol:
        # Lazily import to keep the runtime domain decoupled from graph_engine modules.
        module = importlib.import_module("core.workflow.graph_engine.response_coordinator")
        coordinator_cls = module.ResponseStreamCoordinator
        return coordinator_cls(variable_pool=self.variable_pool, graph=graph)
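The dumps()/loads() pair is what pausing persists and resuming restores: scalars, outputs, the variable pool, the ready queue, execution state, and the set of paused nodes, all behind a version check. A simplified, self-contained sketch of the same versioned-snapshot round trip (a toy stand-in, not the real GraphRuntimeState):

import json

class TinyState:
    """Toy stand-in showing the versioned JSON snapshot pattern used above."""

    def __init__(self, total_tokens: int = 0, paused_nodes: set[str] | None = None) -> None:
        self.total_tokens = total_tokens
        self.paused_nodes = paused_nodes or set()

    def dumps(self) -> str:
        return json.dumps(
            {
                "version": "1.0",
                "total_tokens": self.total_tokens,
                "paused_nodes": sorted(self.paused_nodes),
            }
        )

    def loads(self, data: str) -> None:
        payload = json.loads(data)
        # Refuse snapshots written by an incompatible schema, as the real class does.
        if payload.get("version") != "1.0":
            raise ValueError(f"Unsupported snapshot version: {payload.get('version')}")
        self.total_tokens = int(payload.get("total_tokens", 0))
        self.paused_nodes = set(payload.get("paused_nodes", []))

# Pause: snapshot the running state; resume: rebuild a fresh state from it.
snapshot = TinyState(total_tokens=42, paused_nodes={"human_input_1"}).dumps()
restored = TinyState()
restored.loads(snapshot)
assert restored.paused_nodes == {"human_input_1"}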
@@ -16,6 +16,10 @@ class ReadOnlyVariablePool(Protocol):
        """Get all variables for a node (read-only)."""
        ...

    def get_by_prefix(self, prefix: str) -> Mapping[str, object]:
        """Get all variables stored under a given node prefix (read-only)."""
        ...


class ReadOnlyGraphRuntimeState(Protocol):
    """
@@ -56,6 +60,20 @@ class ReadOnlyGraphRuntimeState(Protocol):
        """Get the node run steps count (read-only)."""
        ...

    @property
    def ready_queue_size(self) -> int:
        """Get the number of nodes currently in the ready queue."""
        ...

    @property
    def exceptions_count(self) -> int:
        """Get the number of node execution exceptions recorded."""
        ...

    def get_output(self, key: str, default: Any = None) -> Any:
        """Get a single output value (returns a copy)."""
        ...

    def dumps(self) -> str:
        """Serialize the runtime state into a JSON snapshot (read-only)."""
        ...
@@ -1,77 +1,82 @@
from __future__ import annotations

from collections.abc import Mapping
from copy import deepcopy
from typing import Any

from core.model_runtime.entities.llm_entities import LLMUsage
from core.variables.segments import Segment
from core.workflow.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.entities.variable_pool import VariablePool

from .graph_runtime_state import GraphRuntimeState
from .variable_pool import VariablePool


class ReadOnlyVariablePoolWrapper:
    """Wrapper that provides read-only access to VariablePool."""
    """Provide defensive, read-only access to ``VariablePool``."""

    def __init__(self, variable_pool: VariablePool):
    def __init__(self, variable_pool: VariablePool) -> None:
        self._variable_pool = variable_pool

    def get(self, node_id: str, variable_key: str) -> Segment | None:
        """Get a variable value (returns a defensive copy)."""
        """Return a copy of a variable value if present."""
        value = self._variable_pool.get([node_id, variable_key])
        return deepcopy(value) if value is not None else None

    def get_all_by_node(self, node_id: str) -> Mapping[str, object]:
        """Get all variables for a node (returns defensive copies)."""
        """Return a copy of all variables for the specified node."""
        variables: dict[str, object] = {}
        if node_id in self._variable_pool.variable_dictionary:
            for key, var in self._variable_pool.variable_dictionary[node_id].items():
                # Variables have a value property that contains the actual data
                variables[key] = deepcopy(var.value)
            for key, variable in self._variable_pool.variable_dictionary[node_id].items():
                variables[key] = deepcopy(variable.value)
        return variables

    def get_by_prefix(self, prefix: str) -> Mapping[str, object]:
        """Return a copy of all variables stored under the given prefix."""
        return self._variable_pool.get_by_prefix(prefix)


class ReadOnlyGraphRuntimeStateWrapper:
    """
    Wrapper that provides read-only access to GraphRuntimeState.
    """Expose a defensive, read-only view of ``GraphRuntimeState``."""

    This wrapper ensures that layers can observe the state without
    modifying it. All returned values are defensive copies.
    """

    def __init__(self, state: GraphRuntimeState):
    def __init__(self, state: GraphRuntimeState) -> None:
        self._state = state
        self._variable_pool_wrapper = ReadOnlyVariablePoolWrapper(state.variable_pool)

    @property
    def variable_pool(self) -> ReadOnlyVariablePoolWrapper:
        """Get read-only access to the variable pool."""
        return self._variable_pool_wrapper

    @property
    def start_at(self) -> float:
        """Get the start time (read-only)."""
        return self._state.start_at

    @property
    def total_tokens(self) -> int:
        """Get the total tokens count (read-only)."""
        return self._state.total_tokens

    @property
    def llm_usage(self) -> LLMUsage:
        """Get a copy of LLM usage info (read-only)."""
        # Return a copy to prevent modification
        return self._state.llm_usage.model_copy()

    @property
    def outputs(self) -> dict[str, Any]:
        """Get a defensive copy of outputs (read-only)."""
        return deepcopy(self._state.outputs)

    @property
    def node_run_steps(self) -> int:
        """Get the node run steps count (read-only)."""
        return self._state.node_run_steps

    @property
    def ready_queue_size(self) -> int:
        return self._state.ready_queue.qsize()

    @property
    def exceptions_count(self) -> int:
        return self._state.graph_execution.exceptions_count

    def get_output(self, key: str, default: Any = None) -> Any:
        """Get a single output value (returns a copy)."""
        return self._state.get_output(key, default)

    def dumps(self) -> str:
        """Serialize the underlying runtime state for external persistence."""
        return self._state.dumps()
@@ -1,6 +1,7 @@
import re
from collections import defaultdict
from collections.abc import Mapping, Sequence
from copy import deepcopy
from typing import Annotated, Any, Union, cast

from pydantic import BaseModel, Field
@@ -235,6 +236,20 @@ class VariablePool(BaseModel):
                return segment
        return None

    def get_by_prefix(self, prefix: str, /) -> Mapping[str, object]:
        """Return a copy of all variables stored under the given node prefix."""

        nodes = self.variable_dictionary.get(prefix)
        if not nodes:
            return {}

        result: dict[str, object] = {}
        for key, variable in nodes.items():
            value = variable.value
            result[key] = deepcopy(value)

        return result

    def _add_system_variables(self, system_variable: SystemVariable):
        sys_var_mapping = system_variable.to_dict()
        for key, value in sys_var_mapping.items():
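get_by_prefix() deep-copies each value so callers cannot mutate the pool through the returned mapping. A minimal illustration of why that copy matters (plain-dict stand-in for the pool, not the real pydantic model):

from copy import deepcopy

pool = {"node_1": {"answer": {"text": "hi"}}}

def get_by_prefix(prefix: str) -> dict[str, object]:
    nodes = pool.get(prefix)
    if not nodes:
        return {}
    return {key: deepcopy(value) for key, value in nodes.items()}

view = get_by_prefix("node_1")
view["answer"]["text"] = "mutated"  # only the returned copy changes
assert pool["node_1"]["answer"]["text"] == "hi"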
@@ -5,7 +5,7 @@ from typing import Literal, NamedTuple
from core.file import FileAttribute, file_manager
from core.variables import ArrayFileSegment
from core.variables.segments import ArrayBooleanSegment, BooleanSegment
from core.workflow.entities import VariablePool
from core.workflow.runtime import VariablePool

from .entities import Condition, SubCondition, SupportedComparisonOperator
@@ -4,7 +4,7 @@ from typing import Any, Protocol

from core.variables import Variable
from core.variables.consts import SELECTORS_LENGTH
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.runtime import VariablePool


class VariableLoader(Protocol):
@@ -1,459 +0,0 @@
|
||||
from collections.abc import Mapping
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import Any, Union
|
||||
|
||||
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity
|
||||
from core.app.entities.queue_entities import (
|
||||
QueueNodeExceptionEvent,
|
||||
QueueNodeFailedEvent,
|
||||
QueueNodeRetryEvent,
|
||||
QueueNodeStartedEvent,
|
||||
QueueNodeSucceededEvent,
|
||||
)
|
||||
from core.app.task_pipeline.exc import WorkflowRunNotFoundError
|
||||
from core.ops.entities.trace_entity import TraceTaskName
|
||||
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
|
||||
from core.workflow.entities import (
|
||||
WorkflowExecution,
|
||||
WorkflowNodeExecution,
|
||||
)
|
||||
from core.workflow.enums import (
|
||||
SystemVariableKey,
|
||||
WorkflowExecutionStatus,
|
||||
WorkflowNodeExecutionMetadataKey,
|
||||
WorkflowNodeExecutionStatus,
|
||||
WorkflowType,
|
||||
)
|
||||
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
|
||||
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from core.workflow.workflow_entry import WorkflowEntry
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.uuid_utils import uuidv7
|
||||
|
||||
|
||||
@dataclass
|
||||
class CycleManagerWorkflowInfo:
|
||||
workflow_id: str
|
||||
workflow_type: WorkflowType
|
||||
version: str
|
||||
graph_data: Mapping[str, Any]
|
||||
|
||||
|
||||
class WorkflowCycleManager:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity],
|
||||
workflow_system_variables: SystemVariable,
|
||||
workflow_info: CycleManagerWorkflowInfo,
|
||||
workflow_execution_repository: WorkflowExecutionRepository,
|
||||
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
|
||||
):
|
||||
self._application_generate_entity = application_generate_entity
|
||||
self._workflow_system_variables = workflow_system_variables
|
||||
self._workflow_info = workflow_info
|
||||
self._workflow_execution_repository = workflow_execution_repository
|
||||
self._workflow_node_execution_repository = workflow_node_execution_repository
|
||||
|
||||
# Initialize caches for workflow execution cycle
|
||||
# These caches avoid redundant repository calls during a single workflow execution
|
||||
self._workflow_execution_cache: dict[str, WorkflowExecution] = {}
|
||||
self._node_execution_cache: dict[str, WorkflowNodeExecution] = {}
|
||||
|
||||
def handle_workflow_run_start(self) -> WorkflowExecution:
|
||||
inputs = self._prepare_workflow_inputs()
|
||||
execution_id = self._get_or_generate_execution_id()
|
||||
|
||||
execution = WorkflowExecution.new(
|
||||
id_=execution_id,
|
||||
workflow_id=self._workflow_info.workflow_id,
|
||||
workflow_type=self._workflow_info.workflow_type,
|
||||
workflow_version=self._workflow_info.version,
|
||||
graph=self._workflow_info.graph_data,
|
||||
inputs=inputs,
|
||||
started_at=naive_utc_now(),
|
||||
)
|
||||
|
||||
return self._save_and_cache_workflow_execution(execution)
|
||||
|
||||
def handle_workflow_run_success(
|
||||
self,
|
||||
*,
|
||||
workflow_run_id: str,
|
||||
total_tokens: int,
|
||||
total_steps: int,
|
||||
outputs: Mapping[str, Any] | None = None,
|
||||
conversation_id: str | None = None,
|
||||
trace_manager: TraceQueueManager | None = None,
|
||||
external_trace_id: str | None = None,
|
||||
) -> WorkflowExecution:
|
||||
workflow_execution = self._get_workflow_execution_or_raise_error(workflow_run_id)
|
||||
|
||||
self._update_workflow_execution_completion(
|
||||
workflow_execution,
|
||||
status=WorkflowExecutionStatus.SUCCEEDED,
|
||||
outputs=outputs,
|
||||
total_tokens=total_tokens,
|
||||
total_steps=total_steps,
|
||||
)
|
||||
|
||||
self._add_trace_task_if_needed(trace_manager, workflow_execution, conversation_id, external_trace_id)
|
||||
|
||||
self._workflow_execution_repository.save(workflow_execution)
|
||||
return workflow_execution
|
||||
|
||||
def handle_workflow_run_partial_success(
|
||||
self,
|
||||
*,
|
||||
workflow_run_id: str,
|
||||
total_tokens: int,
|
||||
total_steps: int,
|
||||
outputs: Mapping[str, Any] | None = None,
|
||||
exceptions_count: int = 0,
|
||||
conversation_id: str | None = None,
|
||||
trace_manager: TraceQueueManager | None = None,
|
||||
external_trace_id: str | None = None,
|
||||
) -> WorkflowExecution:
|
||||
execution = self._get_workflow_execution_or_raise_error(workflow_run_id)
|
||||
|
||||
self._update_workflow_execution_completion(
|
||||
execution,
|
||||
status=WorkflowExecutionStatus.PARTIAL_SUCCEEDED,
|
||||
outputs=outputs,
|
||||
total_tokens=total_tokens,
|
||||
total_steps=total_steps,
|
||||
exceptions_count=exceptions_count,
|
||||
)
|
||||
|
||||
self._add_trace_task_if_needed(trace_manager, execution, conversation_id, external_trace_id)
|
||||
|
||||
self._workflow_execution_repository.save(execution)
|
||||
return execution
|
||||
|
||||
def handle_workflow_run_failed(
|
||||
self,
|
||||
*,
|
||||
workflow_run_id: str,
|
||||
total_tokens: int,
|
||||
total_steps: int,
|
||||
status: WorkflowExecutionStatus,
|
||||
error_message: str,
|
||||
conversation_id: str | None = None,
|
||||
trace_manager: TraceQueueManager | None = None,
|
||||
exceptions_count: int = 0,
|
||||
external_trace_id: str | None = None,
|
||||
) -> WorkflowExecution:
|
||||
workflow_execution = self._get_workflow_execution_or_raise_error(workflow_run_id)
|
||||
now = naive_utc_now()
|
||||
|
||||
self._update_workflow_execution_completion(
|
||||
workflow_execution,
|
||||
status=status,
|
||||
total_tokens=total_tokens,
|
||||
total_steps=total_steps,
|
||||
error_message=error_message,
|
||||
exceptions_count=exceptions_count,
|
||||
finished_at=now,
|
||||
)
|
||||
|
||||
self._fail_running_node_executions(workflow_execution.id_, error_message, now)
|
||||
self._add_trace_task_if_needed(trace_manager, workflow_execution, conversation_id, external_trace_id)
|
||||
|
||||
self._workflow_execution_repository.save(workflow_execution)
|
||||
return workflow_execution
|
||||
|
||||
def handle_node_execution_start(
|
||||
self,
|
||||
*,
|
||||
workflow_execution_id: str,
|
||||
event: QueueNodeStartedEvent,
|
||||
) -> WorkflowNodeExecution:
|
||||
workflow_execution = self._get_workflow_execution_or_raise_error(workflow_execution_id)
|
||||
|
||||
domain_execution = self._create_node_execution_from_event(
|
||||
workflow_execution=workflow_execution,
|
||||
event=event,
|
||||
status=WorkflowNodeExecutionStatus.RUNNING,
|
||||
)
|
||||
|
||||
return self._save_and_cache_node_execution(domain_execution)
|
||||
|
||||
def handle_workflow_node_execution_success(self, *, event: QueueNodeSucceededEvent) -> WorkflowNodeExecution:
|
||||
domain_execution = self._get_node_execution_from_cache(event.node_execution_id)
|
||||
|
||||
self._update_node_execution_completion(
|
||||
domain_execution,
|
||||
event=event,
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
)
|
||||
|
||||
self._workflow_node_execution_repository.save(domain_execution)
|
||||
self._workflow_node_execution_repository.save_execution_data(domain_execution)
|
||||
return domain_execution
|
||||
|
||||
def handle_workflow_node_execution_failed(
|
||||
self,
|
||||
*,
|
||||
event: QueueNodeFailedEvent | QueueNodeExceptionEvent,
|
||||
) -> WorkflowNodeExecution:
|
||||
"""
|
||||
Workflow node execution failed
|
||||
:param event: queue node failed event
|
||||
:return:
|
||||
"""
|
||||
domain_execution = self._get_node_execution_from_cache(event.node_execution_id)
|
||||
|
||||
status = (
|
||||
WorkflowNodeExecutionStatus.EXCEPTION
|
||||
if isinstance(event, QueueNodeExceptionEvent)
|
||||
else WorkflowNodeExecutionStatus.FAILED
|
||||
)
|
||||
|
||||
self._update_node_execution_completion(
|
||||
domain_execution,
|
||||
event=event,
|
||||
status=status,
|
||||
error=event.error,
|
||||
handle_special_values=True,
|
||||
)
|
||||
|
||||
self._workflow_node_execution_repository.save(domain_execution)
|
||||
self._workflow_node_execution_repository.save_execution_data(domain_execution)
|
||||
return domain_execution
|
||||
|
||||
def handle_workflow_node_execution_retried(
|
||||
self, *, workflow_execution_id: str, event: QueueNodeRetryEvent
|
||||
) -> WorkflowNodeExecution:
|
||||
workflow_execution = self._get_workflow_execution_or_raise_error(workflow_execution_id)
|
||||
|
||||
domain_execution = self._create_node_execution_from_event(
|
||||
workflow_execution=workflow_execution,
|
||||
event=event,
|
||||
status=WorkflowNodeExecutionStatus.RETRY,
|
||||
error=event.error,
|
||||
created_at=event.start_at,
|
||||
)
|
||||
|
||||
# Handle inputs and outputs
|
||||
inputs = WorkflowEntry.handle_special_values(event.inputs)
|
||||
outputs = event.outputs
|
||||
metadata = self._merge_event_metadata(event)
|
||||
|
||||
domain_execution.update_from_mapping(inputs=inputs, outputs=outputs, metadata=metadata)
|
||||
|
||||
execution = self._save_and_cache_node_execution(domain_execution)
|
||||
self._workflow_node_execution_repository.save_execution_data(execution)
|
||||
return execution
|
||||
|
||||
def _get_workflow_execution_or_raise_error(self, id: str, /) -> WorkflowExecution:
|
||||
# Check cache first
|
||||
if id in self._workflow_execution_cache:
|
||||
return self._workflow_execution_cache[id]
|
||||
|
||||
raise WorkflowRunNotFoundError(id)
|
||||
|
||||
def _prepare_workflow_inputs(self) -> dict[str, Any]:
|
||||
"""Prepare workflow inputs by merging application inputs with system variables."""
|
||||
inputs = {**self._application_generate_entity.inputs}
|
||||
|
||||
if self._workflow_system_variables:
|
||||
for field_name, value in self._workflow_system_variables.to_dict().items():
|
||||
if field_name != SystemVariableKey.CONVERSATION_ID:
|
||||
inputs[f"sys.{field_name}"] = value
|
||||
|
||||
return dict(WorkflowEntry.handle_special_values(inputs) or {})
|
||||
|
||||
def _get_or_generate_execution_id(self) -> str:
|
||||
"""Get execution ID from system variables or generate a new one."""
|
||||
if self._workflow_system_variables and self._workflow_system_variables.workflow_execution_id:
|
||||
return str(self._workflow_system_variables.workflow_execution_id)
|
||||
return str(uuidv7())
|
||||
|
||||
def _save_and_cache_workflow_execution(self, execution: WorkflowExecution) -> WorkflowExecution:
|
||||
"""Save workflow execution to repository and cache it."""
|
||||
self._workflow_execution_repository.save(execution)
|
||||
self._workflow_execution_cache[execution.id_] = execution
|
||||
return execution
|
||||
|
||||
    def _save_and_cache_node_execution(self, execution: WorkflowNodeExecution) -> WorkflowNodeExecution:
        """Save node execution to repository and cache it if it has an ID.

        This does not persist the `inputs` / `process_data` / `outputs` fields of the execution model.
        """
        self._workflow_node_execution_repository.save(execution)
        if execution.node_execution_id:
            self._node_execution_cache[execution.node_execution_id] = execution
        return execution

    def _get_node_execution_from_cache(self, node_execution_id: str) -> WorkflowNodeExecution:
        """Get node execution from cache or raise error if not found."""
        domain_execution = self._node_execution_cache.get(node_execution_id)
        if not domain_execution:
            raise ValueError(f"Domain node execution not found: {node_execution_id}")
        return domain_execution

    def _update_workflow_execution_completion(
        self,
        execution: WorkflowExecution,
        *,
        status: WorkflowExecutionStatus,
        total_tokens: int,
        total_steps: int,
        outputs: Mapping[str, Any] | None = None,
        error_message: str | None = None,
        exceptions_count: int = 0,
        finished_at: datetime | None = None,
    ):
        """Update workflow execution with completion data."""
        execution.status = status
        execution.outputs = outputs or {}
        execution.total_tokens = total_tokens
        execution.total_steps = total_steps
        execution.finished_at = finished_at or naive_utc_now()
        execution.exceptions_count = exceptions_count
        if error_message:
            execution.error_message = error_message

    def _add_trace_task_if_needed(
        self,
        trace_manager: TraceQueueManager | None,
        workflow_execution: WorkflowExecution,
        conversation_id: str | None,
        external_trace_id: str | None,
    ):
        """Add trace task if trace manager is provided."""
        if trace_manager:
            trace_manager.add_trace_task(
                TraceTask(
                    TraceTaskName.WORKFLOW_TRACE,
                    workflow_execution=workflow_execution,
                    conversation_id=conversation_id,
                    user_id=trace_manager.user_id,
                    external_trace_id=external_trace_id,
                )
            )

    def _fail_running_node_executions(
        self,
        workflow_execution_id: str,
        error_message: str,
        now: datetime,
    ):
        """Fail all running node executions for a workflow."""
        running_node_executions = [
            node_exec
            for node_exec in self._node_execution_cache.values()
            if node_exec.workflow_execution_id == workflow_execution_id
            and node_exec.status == WorkflowNodeExecutionStatus.RUNNING
        ]

        for node_execution in running_node_executions:
            if node_execution.node_execution_id:
                node_execution.status = WorkflowNodeExecutionStatus.FAILED
                node_execution.error = error_message
                node_execution.finished_at = now
                node_execution.elapsed_time = (now - node_execution.created_at).total_seconds()
                self._workflow_node_execution_repository.save(node_execution)

    def _create_node_execution_from_event(
        self,
        *,
        workflow_execution: WorkflowExecution,
        event: QueueNodeStartedEvent,
        status: WorkflowNodeExecutionStatus,
        error: str | None = None,
        created_at: datetime | None = None,
    ) -> WorkflowNodeExecution:
        """Create a node execution from an event."""
        now = naive_utc_now()
        created_at = created_at or now

        metadata = {
            WorkflowNodeExecutionMetadataKey.PARALLEL_MODE_RUN_ID: event.parallel_mode_run_id,
            WorkflowNodeExecutionMetadataKey.ITERATION_ID: event.in_iteration_id,
            WorkflowNodeExecutionMetadataKey.LOOP_ID: event.in_loop_id,
        }

        domain_execution = WorkflowNodeExecution(
            id=event.node_execution_id,
            workflow_id=workflow_execution.workflow_id,
            workflow_execution_id=workflow_execution.id_,
            predecessor_node_id=event.predecessor_node_id,
            index=event.node_run_index,
            node_execution_id=event.node_execution_id,
            node_id=event.node_id,
            node_type=event.node_type,
            title=event.node_title,
            status=status,
            metadata=metadata,
            created_at=created_at,
            error=error,
        )

        if status == WorkflowNodeExecutionStatus.RETRY:
            domain_execution.finished_at = now
            domain_execution.elapsed_time = (now - created_at).total_seconds()

        return domain_execution

    def _update_node_execution_completion(
        self,
        domain_execution: WorkflowNodeExecution,
        *,
        event: Union[
            QueueNodeSucceededEvent,
            QueueNodeFailedEvent,
            QueueNodeExceptionEvent,
        ],
        status: WorkflowNodeExecutionStatus,
        error: str | None = None,
        handle_special_values: bool = False,
    ):
        """Update node execution with completion data."""
        finished_at = naive_utc_now()
        elapsed_time = (finished_at - event.start_at).total_seconds()

        # Process data
        if handle_special_values:
            inputs = WorkflowEntry.handle_special_values(event.inputs)
            process_data = WorkflowEntry.handle_special_values(event.process_data)
        else:
            inputs = event.inputs
            process_data = event.process_data

        outputs = event.outputs

        # Convert metadata
        execution_metadata_dict: dict[WorkflowNodeExecutionMetadataKey, Any] = {}
        if event.execution_metadata:
            execution_metadata_dict.update(event.execution_metadata)

        # Update domain model
        domain_execution.status = status
        domain_execution.update_from_mapping(
            inputs=inputs,
            process_data=process_data,
            outputs=outputs,
            metadata=execution_metadata_dict,
        )
        domain_execution.finished_at = finished_at
        domain_execution.elapsed_time = elapsed_time

        if error:
            domain_execution.error = error

    def _merge_event_metadata(self, event: QueueNodeRetryEvent) -> dict[WorkflowNodeExecutionMetadataKey, str | None]:
        """Merge event metadata with origin metadata."""
        origin_metadata = {
            WorkflowNodeExecutionMetadataKey.ITERATION_ID: event.in_iteration_id,
            WorkflowNodeExecutionMetadataKey.PARALLEL_MODE_RUN_ID: event.parallel_mode_run_id,
            WorkflowNodeExecutionMetadataKey.LOOP_ID: event.in_loop_id,
        }

        execution_metadata_dict: dict[WorkflowNodeExecutionMetadataKey, str | None] = {}
        if event.execution_metadata:
            execution_metadata_dict.update(event.execution_metadata)

        return {**execution_metadata_dict, **origin_metadata} if execution_metadata_dict else origin_metadata
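
The merge above gives the origin identifiers (iteration, loop, parallel-mode run) precedence over anything the retry event carried. A minimal sketch of that precedence rule, using plain string keys as a simplification of the `WorkflowNodeExecutionMetadataKey` members:

```python
# Hypothetical values for illustration; real keys are WorkflowNodeExecutionMetadataKey members.
event_metadata = {"total_tokens": 42, "iteration_id": "stale-id"}
origin_metadata = {"iteration_id": "iter-1", "loop_id": None}

# Origin keys win on conflict, mirroring {**execution_metadata_dict, **origin_metadata} above.
merged = {**event_metadata, **origin_metadata}
assert merged == {"total_tokens": 42, "iteration_id": "iter-1", "loop_id": None}
```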
@@ -9,7 +9,7 @@ from core.app.apps.exc import GenerateTaskStoppedError
from core.app.entities.app_invoke_entities import InvokeFrom
from core.file.models import File
from core.workflow.constants import ENVIRONMENT_VARIABLE_NODE_ID
from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool
from core.workflow.entities import GraphInitParams
from core.workflow.errors import WorkflowNodeRunFailedError
from core.workflow.graph import Graph
from core.workflow.graph_engine import GraphEngine
@@ -20,6 +20,7 @@ from core.workflow.graph_events import GraphEngineEvent, GraphNodeEventBase, Gra
from core.workflow.nodes import NodeType
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool
from factories import file_factory

@@ -37,7 +37,6 @@ from core.rag.entities.event import (
from core.repositories.factory import DifyCoreRepositoryFactory
from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository
from core.variables.variables import Variable
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import (
    WorkflowNodeExecution,
    WorkflowNodeExecutionStatus,
@@ -50,6 +49,7 @@ from core.workflow.node_events.base import NodeRunResult
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
from core.workflow.repositories.workflow_node_execution_repository import OrderConfig
from core.workflow.runtime import VariablePool
from core.workflow.system_variable import SystemVariable
from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db

@@ -14,7 +14,7 @@ from core.file import File
from core.repositories import DifyCoreRepositoryFactory
from core.variables import Variable
from core.variables.variables import VariableUnion
from core.workflow.entities import VariablePool, WorkflowNodeExecution
from core.workflow.entities import WorkflowNodeExecution
from core.workflow.enums import ErrorStrategy, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from core.workflow.errors import WorkflowNodeRunFailedError
from core.workflow.graph_events import GraphNodeEventBase, NodeRunFailedEvent, NodeRunSucceededEvent
@@ -23,6 +23,7 @@ from core.workflow.nodes import NodeType
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
from core.workflow.nodes.start.entities import StartNodeData
from core.workflow.runtime import VariablePool
from core.workflow.system_variable import SystemVariable
from core.workflow.workflow_entry import WorkflowEntry
from events.app_event import app_draft_workflow_was_synced, app_published_workflow_was_updated

@@ -5,12 +5,13 @@ import pytest

from configs import dify_config
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool
from core.workflow.entities import GraphInitParams
from core.workflow.enums import WorkflowNodeExecutionStatus
from core.workflow.graph import Graph
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.code.code_node import CodeNode
from core.workflow.nodes.node_factory import DifyNodeFactory
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
from models.enums import UserFrom
from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock

@@ -5,10 +5,11 @@ from urllib.parse import urlencode
import pytest

from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool
from core.workflow.entities import GraphInitParams
from core.workflow.graph import Graph
from core.workflow.nodes.http_request.node import HttpRequestNode
from core.workflow.nodes.node_factory import DifyNodeFactory
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
from models.enums import UserFrom
from tests.integration_tests.workflow.nodes.__mock.http import setup_http_mock
@@ -174,13 +175,13 @@ def test_custom_authorization_header(setup_http_mock):
@pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True)
def test_custom_auth_with_empty_api_key_does_not_set_header(setup_http_mock):
    """Test: In custom authentication mode, when the api_key is empty, no header should be set."""
    from core.workflow.entities.variable_pool import VariablePool
    from core.workflow.nodes.http_request.entities import (
        HttpRequestNodeAuthorization,
        HttpRequestNodeData,
        HttpRequestNodeTimeout,
    )
    from core.workflow.nodes.http_request.executor import Executor
    from core.workflow.runtime import VariablePool
    from core.workflow.system_variable import SystemVariable

    # Create variable pool

@@ -6,12 +6,13 @@ from unittest.mock import MagicMock, patch

from core.app.entities.app_invoke_entities import InvokeFrom
from core.llm_generator.output_parser.structured_output import _parse_structured_output
from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool
from core.workflow.entities import GraphInitParams
from core.workflow.enums import WorkflowNodeExecutionStatus
from core.workflow.graph import Graph
from core.workflow.node_events import StreamCompletedEvent
from core.workflow.nodes.llm.node import LLMNode
from core.workflow.nodes.node_factory import DifyNodeFactory
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
from extensions.ext_database import db
from models.enums import UserFrom

@@ -5,11 +5,12 @@ from unittest.mock import MagicMock

from core.app.entities.app_invoke_entities import InvokeFrom
from core.model_runtime.entities import AssistantPromptMessage
from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool
from core.workflow.entities import GraphInitParams
from core.workflow.enums import WorkflowNodeExecutionStatus
from core.workflow.graph import Graph
from core.workflow.nodes.node_factory import DifyNodeFactory
from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
from extensions.ext_database import db
from models.enums import UserFrom

@@ -4,11 +4,12 @@ import uuid
import pytest

from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool
from core.workflow.entities import GraphInitParams
from core.workflow.enums import WorkflowNodeExecutionStatus
from core.workflow.graph import Graph
from core.workflow.nodes.node_factory import DifyNodeFactory
from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
from models.enums import UserFrom
from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock

@@ -4,12 +4,13 @@ from unittest.mock import MagicMock

from core.app.entities.app_invoke_entities import InvokeFrom
from core.tools.utils.configuration import ToolParameterConfigurationManager
from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool
from core.workflow.entities import GraphInitParams
from core.workflow.enums import WorkflowNodeExecutionStatus
from core.workflow.graph import Graph
from core.workflow.node_events import StreamCompletedEvent
from core.workflow.nodes.node_factory import DifyNodeFactory
from core.workflow.nodes.tool.tool_node import ToolNode
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
from models.enums import UserFrom

@@ -99,6 +99,8 @@ class TestAdvancedChatAppRunnerConversationVariables:
            workflow=mock_workflow,
            system_user_id=str(uuid4()),
            app=MagicMock(),
            workflow_execution_repository=MagicMock(),
            workflow_node_execution_repository=MagicMock(),
        )

        # Mock database session
@@ -237,6 +239,8 @@ class TestAdvancedChatAppRunnerConversationVariables:
            workflow=mock_workflow,
            system_user_id=str(uuid4()),
            app=MagicMock(),
            workflow_execution_repository=MagicMock(),
            workflow_node_execution_repository=MagicMock(),
        )

        # Mock database session
@@ -390,6 +394,8 @@ class TestAdvancedChatAppRunnerConversationVariables:
            workflow=mock_workflow,
            system_user_id=str(uuid4()),
            app=MagicMock(),
            workflow_execution_repository=MagicMock(),
            workflow_node_execution_repository=MagicMock(),
        )

        # Mock database session

@@ -0,0 +1,63 @@
from types import SimpleNamespace

import pytest

from core.app.apps.common.graph_runtime_state_support import GraphRuntimeStateSupport
from core.workflow.runtime import GraphRuntimeState
from core.workflow.runtime.variable_pool import VariablePool
from core.workflow.system_variable import SystemVariable


def _make_state(workflow_run_id: str | None) -> GraphRuntimeState:
    variable_pool = VariablePool(system_variables=SystemVariable(workflow_execution_id=workflow_run_id))
    return GraphRuntimeState(variable_pool=variable_pool, start_at=0.0)


class _StubPipeline(GraphRuntimeStateSupport):
    def __init__(self, *, cached_state: GraphRuntimeState | None, queue_state: GraphRuntimeState | None):
        self._graph_runtime_state = cached_state
        self._base_task_pipeline = SimpleNamespace(queue_manager=SimpleNamespace(graph_runtime_state=queue_state))


def test_ensure_graph_runtime_initialized_caches_explicit_state():
    explicit_state = _make_state("run-explicit")
    pipeline = _StubPipeline(cached_state=None, queue_state=None)

    resolved = pipeline._ensure_graph_runtime_initialized(explicit_state)

    assert resolved is explicit_state
    assert pipeline._graph_runtime_state is explicit_state


def test_resolve_graph_runtime_state_reads_from_queue_when_cache_empty():
    queued_state = _make_state("run-queue")
    pipeline = _StubPipeline(cached_state=None, queue_state=queued_state)

    resolved = pipeline._resolve_graph_runtime_state()

    assert resolved is queued_state
    assert pipeline._graph_runtime_state is queued_state


def test_resolve_graph_runtime_state_raises_when_no_state_available():
    pipeline = _StubPipeline(cached_state=None, queue_state=None)

    with pytest.raises(ValueError):
        pipeline._resolve_graph_runtime_state()


def test_extract_workflow_run_id_returns_value():
    state = _make_state("run-identifier")
    pipeline = _StubPipeline(cached_state=state, queue_state=None)

    run_id = pipeline._extract_workflow_run_id(state)

    assert run_id == "run-identifier"


def test_extract_workflow_run_id_raises_when_missing():
    state = _make_state(None)
    pipeline = _StubPipeline(cached_state=state, queue_state=None)

    with pytest.raises(ValueError):
        pipeline._extract_workflow_run_id(state)
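
These tests pin down a resolution order for the runtime state: an explicit argument wins, then the cached attribute, then the queue manager, and finally an error. A hedged sketch of the behavior they imply; the actual `GraphRuntimeStateSupport` implementation may differ in detail:

```python
# A minimal sketch, assuming the mixin resolves state in the order the tests assert.
def _resolve_graph_runtime_state_sketch(pipeline) -> GraphRuntimeState:
    if pipeline._graph_runtime_state is not None:
        return pipeline._graph_runtime_state
    queued = pipeline._base_task_pipeline.queue_manager.graph_runtime_state
    if queued is None:
        raise ValueError("graph runtime state is not initialized")
    pipeline._graph_runtime_state = queued  # cache for subsequent lookups
    return queued
```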
@@ -3,8 +3,7 @@ Unit tests for WorkflowResponseConverter focusing on process_data truncation fun
"""

import uuid
from dataclasses import dataclass
from datetime import datetime
from collections.abc import Mapping
from typing import Any
from unittest.mock import Mock

@@ -12,24 +11,17 @@ import pytest

from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter
from core.app.entities.app_invoke_entities import WorkflowAppGenerateEntity
from core.app.entities.queue_entities import QueueNodeRetryEvent, QueueNodeSucceededEvent
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecution, WorkflowNodeExecutionStatus
from core.app.entities.queue_entities import (
    QueueNodeRetryEvent,
    QueueNodeStartedEvent,
    QueueNodeSucceededEvent,
)
from core.workflow.enums import NodeType
from core.workflow.system_variable import SystemVariable
from libs.datetime_utils import naive_utc_now
from models import Account


@dataclass
class ProcessDataResponseScenario:
    """Test scenario for process_data in responses."""

    name: str
    original_process_data: dict[str, Any] | None
    truncated_process_data: dict[str, Any] | None
    expected_response_data: dict[str, Any] | None
    expected_truncated_flag: bool


class TestWorkflowResponseConverterCenarios:
    """Test process_data truncation in WorkflowResponseConverter."""

@@ -39,6 +31,7 @@ class TestWorkflowResponseConverterCenarios:
        mock_app_config = Mock()
        mock_app_config.tenant_id = "test-tenant-id"
        mock_entity.app_config = mock_app_config
        mock_entity.inputs = {}
        return mock_entity

    def create_workflow_response_converter(self) -> WorkflowResponseConverter:
@@ -50,54 +43,59 @@ class TestWorkflowResponseConverterCenarios:
        mock_user.name = "Test User"
        mock_user.email = "test@example.com"

        return WorkflowResponseConverter(application_generate_entity=mock_entity, user=mock_user)

    def create_workflow_node_execution(
        self,
        process_data: dict[str, Any] | None = None,
        truncated_process_data: dict[str, Any] | None = None,
        execution_id: str = "test-execution-id",
    ) -> WorkflowNodeExecution:
        """Create a WorkflowNodeExecution for testing."""
        execution = WorkflowNodeExecution(
            id=execution_id,
            workflow_id="test-workflow-id",
            workflow_execution_id="test-run-id",
            index=1,
            node_id="test-node-id",
            node_type=NodeType.LLM,
            title="Test Node",
            process_data=process_data,
            status=WorkflowNodeExecutionStatus.SUCCEEDED,
            created_at=datetime.now(),
            finished_at=datetime.now(),
        system_variables = SystemVariable(workflow_id="wf-id", workflow_execution_id="initial-run-id")
        return WorkflowResponseConverter(
            application_generate_entity=mock_entity,
            user=mock_user,
            system_variables=system_variables,
        )

        if truncated_process_data is not None:
            execution.set_truncated_process_data(truncated_process_data)
    def create_node_started_event(self, *, node_execution_id: str | None = None) -> QueueNodeStartedEvent:
        """Create a QueueNodeStartedEvent for testing."""
        return QueueNodeStartedEvent(
            node_execution_id=node_execution_id or str(uuid.uuid4()),
            node_id="test-node-id",
            node_title="Test Node",
            node_type=NodeType.CODE,
            start_at=naive_utc_now(),
            predecessor_node_id=None,
            in_iteration_id=None,
            in_loop_id=None,
            provider_type="built-in",
            provider_id="code",
        )

        return execution

    def create_node_succeeded_event(self) -> QueueNodeSucceededEvent:
    def create_node_succeeded_event(
        self,
        *,
        node_execution_id: str,
        process_data: Mapping[str, Any] | None = None,
    ) -> QueueNodeSucceededEvent:
        """Create a QueueNodeSucceededEvent for testing."""
        return QueueNodeSucceededEvent(
            node_id="test-node-id",
            node_type=NodeType.CODE,
            node_execution_id=str(uuid.uuid4()),
            node_execution_id=node_execution_id,
            start_at=naive_utc_now(),
            parallel_id=None,
            parallel_start_node_id=None,
            parent_parallel_id=None,
            parent_parallel_start_node_id=None,
            in_iteration_id=None,
            in_loop_id=None,
            inputs={},
            process_data=process_data or {},
            outputs={},
            execution_metadata={},
        )

    def create_node_retry_event(self) -> QueueNodeRetryEvent:
    def create_node_retry_event(
        self,
        *,
        node_execution_id: str,
        process_data: Mapping[str, Any] | None = None,
    ) -> QueueNodeRetryEvent:
        """Create a QueueNodeRetryEvent for testing."""
        return QueueNodeRetryEvent(
            inputs={"data": "inputs"},
            outputs={"data": "outputs"},
            process_data=process_data or {},
            error="oops",
            retry_index=1,
            node_id="test-node-id",
@@ -105,12 +103,8 @@ class TestWorkflowResponseConverterCenarios:
            node_title="test code",
            provider_type="built-in",
            provider_id="code",
            node_execution_id=str(uuid.uuid4()),
            node_execution_id=node_execution_id,
            start_at=naive_utc_now(),
            parallel_id=None,
            parallel_start_node_id=None,
            parent_parallel_id=None,
            parent_parallel_start_node_id=None,
            in_iteration_id=None,
            in_loop_id=None,
        )
@@ -122,15 +116,28 @@ class TestWorkflowResponseConverterCenarios:
        original_data = {"large_field": "x" * 10000, "metadata": "info"}
        truncated_data = {"large_field": "[TRUNCATED]", "metadata": "info"}

        execution = self.create_workflow_node_execution(
            process_data=original_data, truncated_process_data=truncated_data
        converter.workflow_start_to_stream_response(task_id="bootstrap", workflow_run_id="run-id", workflow_id="wf-id")
        start_event = self.create_node_started_event()
        converter.workflow_node_start_to_stream_response(
            event=start_event,
            task_id="test-task-id",
        )
        event = self.create_node_succeeded_event()

        event = self.create_node_succeeded_event(
            node_execution_id=start_event.node_execution_id,
            process_data=original_data,
        )

        def fake_truncate(mapping):
            if mapping == dict(original_data):
                return truncated_data, True
            return mapping, False

        converter._truncator.truncate_variable_mapping = fake_truncate  # type: ignore[assignment]

        response = converter.workflow_node_finish_to_stream_response(
            event=event,
            task_id="test-task-id",
            workflow_node_execution=execution,
        )

        # Response should use truncated data, not original
@@ -145,13 +152,26 @@ class TestWorkflowResponseConverterCenarios:

        original_data = {"small": "data"}

        execution = self.create_workflow_node_execution(process_data=original_data)
        event = self.create_node_succeeded_event()
        converter.workflow_start_to_stream_response(task_id="bootstrap", workflow_run_id="run-id", workflow_id="wf-id")
        start_event = self.create_node_started_event()
        converter.workflow_node_start_to_stream_response(
            event=start_event,
            task_id="test-task-id",
        )

        event = self.create_node_succeeded_event(
            node_execution_id=start_event.node_execution_id,
            process_data=original_data,
        )

        def fake_truncate(mapping):
            return mapping, False

        converter._truncator.truncate_variable_mapping = fake_truncate  # type: ignore[assignment]

        response = converter.workflow_node_finish_to_stream_response(
            event=event,
            task_id="test-task-id",
            workflow_node_execution=execution,
        )

        # Response should use original data
@@ -163,18 +183,31 @@ class TestWorkflowResponseConverterCenarios:
        """Test node finish response when process_data is None."""
        converter = self.create_workflow_response_converter()

        execution = self.create_workflow_node_execution(process_data=None)
        event = self.create_node_succeeded_event()
        converter.workflow_start_to_stream_response(task_id="bootstrap", workflow_run_id="run-id", workflow_id="wf-id")
        start_event = self.create_node_started_event()
        converter.workflow_node_start_to_stream_response(
            event=start_event,
            task_id="test-task-id",
        )

        event = self.create_node_succeeded_event(
            node_execution_id=start_event.node_execution_id,
            process_data=None,
        )

        def fake_truncate(mapping):
            return mapping, False

        converter._truncator.truncate_variable_mapping = fake_truncate  # type: ignore[assignment]

        response = converter.workflow_node_finish_to_stream_response(
            event=event,
            task_id="test-task-id",
            workflow_node_execution=execution,
        )

        # Response should have None process_data
        # Response should normalize missing process_data to an empty mapping
        assert response is not None
        assert response.data.process_data is None
        assert response.data.process_data == {}
        assert response.data.process_data_truncated is False

    def test_workflow_node_retry_response_uses_truncated_process_data(self):
@@ -184,15 +217,28 @@ class TestWorkflowResponseConverterCenarios:
        original_data = {"large_field": "x" * 10000, "metadata": "info"}
        truncated_data = {"large_field": "[TRUNCATED]", "metadata": "info"}

        execution = self.create_workflow_node_execution(
            process_data=original_data, truncated_process_data=truncated_data
        converter.workflow_start_to_stream_response(task_id="bootstrap", workflow_run_id="run-id", workflow_id="wf-id")
        start_event = self.create_node_started_event()
        converter.workflow_node_start_to_stream_response(
            event=start_event,
            task_id="test-task-id",
        )
        event = self.create_node_retry_event()

        event = self.create_node_retry_event(
            node_execution_id=start_event.node_execution_id,
            process_data=original_data,
        )

        def fake_truncate(mapping):
            if mapping == dict(original_data):
                return truncated_data, True
            return mapping, False

        converter._truncator.truncate_variable_mapping = fake_truncate  # type: ignore[assignment]

        response = converter.workflow_node_retry_to_stream_response(
            event=event,
            task_id="test-task-id",
            workflow_node_execution=execution,
        )

        # Response should use truncated data, not original
@@ -207,224 +253,72 @@ class TestWorkflowResponseConverterCenarios:

        original_data = {"small": "data"}

        execution = self.create_workflow_node_execution(process_data=original_data)
        event = self.create_node_retry_event()
        converter.workflow_start_to_stream_response(task_id="bootstrap", workflow_run_id="run-id", workflow_id="wf-id")
        start_event = self.create_node_started_event()
        converter.workflow_node_start_to_stream_response(
            event=start_event,
            task_id="test-task-id",
        )

        event = self.create_node_retry_event(
            node_execution_id=start_event.node_execution_id,
            process_data=original_data,
        )

        def fake_truncate(mapping):
            return mapping, False

        converter._truncator.truncate_variable_mapping = fake_truncate  # type: ignore[assignment]

        response = converter.workflow_node_retry_to_stream_response(
            event=event,
            task_id="test-task-id",
            workflow_node_execution=execution,
        )

        # Response should use original data
        assert response is not None
        assert response.data.process_data == original_data
        assert response.data.process_data_truncated is False

    def test_iteration_and_loop_nodes_return_none(self):
        """Test that iteration and loop nodes return None (no change from existing behavior)."""
        """Test that iteration and loop nodes return None (no streaming events)."""
        converter = self.create_workflow_response_converter()

        # Test iteration node
        iteration_execution = self.create_workflow_node_execution(process_data={"test": "data"})
        iteration_execution.node_type = NodeType.ITERATION

        event = self.create_node_succeeded_event()

        response = converter.workflow_node_finish_to_stream_response(
            event=event,
            task_id="test-task-id",
            workflow_node_execution=iteration_execution,
        )

        # Should return None for iteration nodes
        assert response is None

        # Test loop node
        loop_execution = self.create_workflow_node_execution(process_data={"test": "data"})
        loop_execution.node_type = NodeType.LOOP

        response = converter.workflow_node_finish_to_stream_response(
            event=event,
            task_id="test-task-id",
            workflow_node_execution=loop_execution,
        )

        # Should return None for loop nodes
        assert response is None

    def test_execution_without_workflow_execution_id_returns_none(self):
        """Test that executions without workflow_execution_id return None."""
        converter = self.create_workflow_response_converter()

        execution = self.create_workflow_node_execution(process_data={"test": "data"})
        execution.workflow_execution_id = None  # Single-step debugging

        event = self.create_node_succeeded_event()

        response = converter.workflow_node_finish_to_stream_response(
            event=event,
            task_id="test-task-id",
            workflow_node_execution=execution,
        )

        # Should return None for single-step debugging
        assert response is None

    @staticmethod
    def get_process_data_response_scenarios() -> list[ProcessDataResponseScenario]:
        """Create test scenarios for process_data responses."""
        return [
            ProcessDataResponseScenario(
                name="none_process_data",
                original_process_data=None,
                truncated_process_data=None,
                expected_response_data=None,
                expected_truncated_flag=False,
            ),
            ProcessDataResponseScenario(
                name="small_process_data_no_truncation",
                original_process_data={"small": "data"},
                truncated_process_data=None,
                expected_response_data={"small": "data"},
                expected_truncated_flag=False,
            ),
            ProcessDataResponseScenario(
                name="large_process_data_with_truncation",
                original_process_data={"large": "x" * 10000, "metadata": "info"},
                truncated_process_data={"large": "[TRUNCATED]", "metadata": "info"},
                expected_response_data={"large": "[TRUNCATED]", "metadata": "info"},
                expected_truncated_flag=True,
            ),
            ProcessDataResponseScenario(
                name="empty_process_data",
                original_process_data={},
                truncated_process_data=None,
                expected_response_data={},
                expected_truncated_flag=False,
            ),
            ProcessDataResponseScenario(
                name="complex_data_with_truncation",
                original_process_data={
                    "logs": ["entry"] * 1000,  # Large array
                    "config": {"setting": "value"},
                    "status": "processing",
                },
                truncated_process_data={
                    "logs": "[TRUNCATED: 1000 items]",
                    "config": {"setting": "value"},
                    "status": "processing",
                },
                expected_response_data={
                    "logs": "[TRUNCATED: 1000 items]",
                    "config": {"setting": "value"},
                    "status": "processing",
                },
                expected_truncated_flag=True,
            ),
        ]

    @pytest.mark.parametrize(
        "scenario",
        get_process_data_response_scenarios(),
        ids=[scenario.name for scenario in get_process_data_response_scenarios()],
    )
    def test_node_finish_response_scenarios(self, scenario: ProcessDataResponseScenario):
        """Test various scenarios for node finish responses."""

        mock_user = Mock(spec=Account)
        mock_user.id = "test-user-id"
        mock_user.name = "Test User"
        mock_user.email = "test@example.com"

        converter = WorkflowResponseConverter(
            application_generate_entity=Mock(spec=WorkflowAppGenerateEntity, app_config=Mock(tenant_id="test-tenant")),
            user=mock_user,
        )

        execution = WorkflowNodeExecution(
            id="test-execution-id",
            workflow_id="test-workflow-id",
            workflow_execution_id="test-run-id",
            index=1,
            node_id="test-node-id",
            node_type=NodeType.LLM,
            title="Test Node",
            process_data=scenario.original_process_data,
            status=WorkflowNodeExecutionStatus.SUCCEEDED,
            created_at=datetime.now(),
            finished_at=datetime.now(),
        )

        if scenario.truncated_process_data is not None:
            execution.set_truncated_process_data(scenario.truncated_process_data)

        event = QueueNodeSucceededEvent(
            node_id="test-node-id",
            node_type=NodeType.CODE,
        iteration_event = QueueNodeSucceededEvent(
            node_id="iteration-node",
            node_type=NodeType.ITERATION,
            node_execution_id=str(uuid.uuid4()),
            start_at=naive_utc_now(),
            parallel_id=None,
            parallel_start_node_id=None,
            parent_parallel_id=None,
            parent_parallel_start_node_id=None,
            in_iteration_id=None,
            in_loop_id=None,
            inputs={},
            process_data={},
            outputs={},
            execution_metadata={},
        )

        response = converter.workflow_node_finish_to_stream_response(
            event=event,
            event=iteration_event,
            task_id="test-task-id",
            workflow_node_execution=execution,
        )
        assert response is None

        assert response is not None
        assert response.data.process_data == scenario.expected_response_data
        assert response.data.process_data_truncated == scenario.expected_truncated_flag

    @pytest.mark.parametrize(
        "scenario",
        get_process_data_response_scenarios(),
        ids=[scenario.name for scenario in get_process_data_response_scenarios()],
    )
    def test_node_retry_response_scenarios(self, scenario: ProcessDataResponseScenario):
        """Test various scenarios for node retry responses."""

        mock_user = Mock(spec=Account)
        mock_user.id = "test-user-id"
        mock_user.name = "Test User"
        mock_user.email = "test@example.com"

        converter = WorkflowResponseConverter(
            application_generate_entity=Mock(spec=WorkflowAppGenerateEntity, app_config=Mock(tenant_id="test-tenant")),
            user=mock_user,
        )

        execution = WorkflowNodeExecution(
            id="test-execution-id",
            workflow_id="test-workflow-id",
            workflow_execution_id="test-run-id",
            index=1,
            node_id="test-node-id",
            node_type=NodeType.LLM,
            title="Test Node",
            process_data=scenario.original_process_data,
            status=WorkflowNodeExecutionStatus.FAILED,  # Retry scenario
            created_at=datetime.now(),
            finished_at=datetime.now(),
        )

        if scenario.truncated_process_data is not None:
            execution.set_truncated_process_data(scenario.truncated_process_data)

        event = self.create_node_retry_event()

        response = converter.workflow_node_retry_to_stream_response(
            event=event,
        loop_event = iteration_event.model_copy(update={"node_type": NodeType.LOOP})
        response = converter.workflow_node_finish_to_stream_response(
            event=loop_event,
            task_id="test-task-id",
            workflow_node_execution=execution,
        )
        assert response is None

    def test_finish_without_start_raises(self):
        """Ensure finish responses require a prior workflow start."""
        converter = self.create_workflow_response_converter()
        event = self.create_node_succeeded_event(
            node_execution_id=str(uuid.uuid4()),
            process_data={},
        )

        assert response is not None
        assert response.data.process_data == scenario.expected_response_data
        assert response.data.process_data_truncated == scenario.expected_truncated_flag
        with pytest.raises(ValueError):
            converter.workflow_node_finish_to_stream_response(
                event=event,
                task_id="test-task-id",
            )

@@ -37,7 +37,7 @@ from core.variables.variables import (
    Variable,
    VariableUnion,
)
from core.workflow.entities import VariablePool
from core.workflow.runtime import VariablePool
from core.workflow.system_variable import SystemVariable


@@ -1,9 +1,11 @@
import json
from time import time
from unittest.mock import MagicMock, patch

import pytest

from core.workflow.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.entities.variable_pool import VariablePool
from core.model_runtime.entities.llm_entities import LLMUsage
from core.workflow.runtime import GraphRuntimeState, ReadOnlyGraphRuntimeStateWrapper, VariablePool


class TestGraphRuntimeState:
@@ -95,3 +97,141 @@ class TestGraphRuntimeState:
        # Test add_tokens validation
        with pytest.raises(ValueError):
            state.add_tokens(-1)

    def test_ready_queue_default_instantiation(self):
        state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time())

        queue = state.ready_queue

        from core.workflow.graph_engine.ready_queue import InMemoryReadyQueue

        assert isinstance(queue, InMemoryReadyQueue)
        assert state.ready_queue is queue

    def test_graph_execution_lazy_instantiation(self):
        state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time())

        execution = state.graph_execution

        from core.workflow.graph_engine.domain.graph_execution import GraphExecution

        assert isinstance(execution, GraphExecution)
        assert execution.workflow_id == ""
        assert state.graph_execution is execution

    def test_response_coordinator_configuration(self):
        variable_pool = VariablePool()
        state = GraphRuntimeState(variable_pool=variable_pool, start_at=time())

        with pytest.raises(ValueError):
            _ = state.response_coordinator

        mock_graph = MagicMock()
        with patch("core.workflow.graph_engine.response_coordinator.ResponseStreamCoordinator") as coordinator_cls:
            coordinator_instance = MagicMock()
            coordinator_cls.return_value = coordinator_instance

            state.configure(graph=mock_graph)

            assert state.response_coordinator is coordinator_instance
            coordinator_cls.assert_called_once_with(variable_pool=variable_pool, graph=mock_graph)

            # Configure again with same graph should be idempotent
            state.configure(graph=mock_graph)

        other_graph = MagicMock()
        with pytest.raises(ValueError):
            state.attach_graph(other_graph)

    def test_read_only_wrapper_exposes_additional_state(self):
        state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time())
        state.configure()

        wrapper = ReadOnlyGraphRuntimeStateWrapper(state)

        assert wrapper.ready_queue_size == 0
        assert wrapper.exceptions_count == 0

    def test_read_only_wrapper_serializes_runtime_state(self):
        state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time())
        state.total_tokens = 5
        state.set_output("result", {"success": True})
        state.ready_queue.put("node-1")

        wrapper = ReadOnlyGraphRuntimeStateWrapper(state)

        wrapper_snapshot = json.loads(wrapper.dumps())
        state_snapshot = json.loads(state.dumps())

        assert wrapper_snapshot == state_snapshot

    def test_dumps_and_loads_roundtrip_with_response_coordinator(self):
        variable_pool = VariablePool()
        variable_pool.add(("node1", "value"), "payload")

        state = GraphRuntimeState(variable_pool=variable_pool, start_at=time())
        state.total_tokens = 10
        state.node_run_steps = 3
        state.set_output("final", {"result": True})
        usage = LLMUsage.from_metadata(
            {
                "prompt_tokens": 2,
                "completion_tokens": 3,
                "total_tokens": 5,
                "total_price": "1.23",
                "currency": "USD",
                "latency": 0.5,
            }
        )
        state.llm_usage = usage
        state.ready_queue.put("node-A")

        graph_execution = state.graph_execution
        graph_execution.workflow_id = "wf-123"
        graph_execution.exceptions_count = 4
        graph_execution.started = True

        class StubCoordinator:
            def __init__(self) -> None:
                self.state = "initial"

            def dumps(self) -> str:
                return json.dumps({"state": self.state})

            def loads(self, data: str) -> None:
                payload = json.loads(data)
                self.state = payload["state"]

        mock_graph = MagicMock()
        stub = StubCoordinator()
        with patch.object(GraphRuntimeState, "_build_response_coordinator", return_value=stub):
            state.attach_graph(mock_graph)

        stub.state = "configured"

        snapshot = state.dumps()

        restored = GraphRuntimeState(variable_pool=VariablePool(), start_at=0.0)
        restored.loads(snapshot)

        assert restored.total_tokens == 10
        assert restored.node_run_steps == 3
        assert restored.get_output("final") == {"result": True}
        assert restored.llm_usage.total_tokens == usage.total_tokens
        assert restored.ready_queue.qsize() == 1
        assert restored.ready_queue.get(timeout=0.01) == "node-A"

        restored_segment = restored.variable_pool.get(("node1", "value"))
        assert restored_segment is not None
        assert restored_segment.value == "payload"

        restored_execution = restored.graph_execution
        assert restored_execution.workflow_id == "wf-123"
        assert restored_execution.exceptions_count == 4
        assert restored_execution.started is True

        new_stub = StubCoordinator()
        with patch.object(GraphRuntimeState, "_build_response_coordinator", return_value=new_stub):
            restored.attach_graph(mock_graph)

        assert new_stub.state == "configured"
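
This round trip is what makes pausing durable: everything a paused run needs survives `dumps()` and comes back through `loads()`. A hedged sketch of how a pause/resume flow could use it, built only from the calls the test exercises; the storage here is illustrative, not Dify's actual persistence layer:

```python
from time import time

from core.workflow.runtime import GraphRuntimeState, VariablePool

# Illustrative in-memory store; a real deployment would persist the snapshot.
_snapshots: dict[str, str] = {}


def pause_to_snapshot(run_id: str, state: GraphRuntimeState) -> None:
    # Serializes tokens, steps, outputs, LLM usage, ready queue, and graph execution.
    _snapshots[run_id] = state.dumps()


def resume_from_snapshot(run_id: str) -> GraphRuntimeState:
    restored = GraphRuntimeState(variable_pool=VariablePool(), start_at=time())
    restored.loads(_snapshots[run_id])
    return restored  # call attach_graph(...) afterwards to rebuild the coordinator
```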
@@ -4,7 +4,7 @@ from core.variables.segments import (
    NoneSegment,
    StringSegment,
)
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.runtime import VariablePool


class TestVariablePoolGetAndNestedAttribute:

@@ -0,0 +1,59 @@
from unittest.mock import MagicMock

import pytest

from core.workflow.enums import NodeType
from core.workflow.graph import Graph
from core.workflow.nodes.base.node import Node


def _make_node(node_id: str, node_type: NodeType = NodeType.START) -> Node:
    node = MagicMock(spec=Node)
    node.id = node_id
    node.node_type = node_type
    node.execution_type = None  # attribute not used in builder path
    return node


def test_graph_builder_creates_linear_graph():
    builder = Graph.new()
    root = _make_node("root", NodeType.START)
    mid = _make_node("mid", NodeType.LLM)
    end = _make_node("end", NodeType.END)

    graph = builder.add_root(root).add_node(mid).add_node(end).build()

    assert graph.root_node is root
    assert graph.nodes == {"root": root, "mid": mid, "end": end}
    assert len(graph.edges) == 2
    first_edge = next(iter(graph.edges.values()))
    assert first_edge.tail == "root"
    assert first_edge.head == "mid"
    assert graph.out_edges["mid"] == [edge_id for edge_id, edge in graph.edges.items() if edge.tail == "mid"]


def test_graph_builder_supports_custom_predecessor():
    builder = Graph.new()
    root = _make_node("root")
    branch = _make_node("branch")
    other = _make_node("other")

    graph = builder.add_root(root).add_node(branch).add_node(other, from_node_id="root").build()

    outgoing_root = graph.out_edges["root"]
    assert len(outgoing_root) == 2
    edge_targets = {graph.edges[eid].head for eid in outgoing_root}
    assert edge_targets == {"branch", "other"}


def test_graph_builder_validates_usage():
    builder = Graph.new()
    node = _make_node("node")

    with pytest.raises(ValueError, match="Root node"):
        builder.add_node(node)

    builder.add_root(node)
    duplicate = _make_node("node")
    with pytest.raises(ValueError, match="Duplicate"):
        builder.add_node(duplicate)
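
Taken together, the builder contract these tests encode reads like this hedged usage sketch; the node variables are placeholders, not part of the tested API:

```python
# A usage sketch of the fluent builder, assuming the behavior asserted above.
graph = (
    Graph.new()
    .add_root(start_node)                      # exactly one root, added first
    .add_node(llm_node)                        # linked from the previously added node
    .add_node(end_node, from_node_id="start")  # or from an explicit predecessor
    .build()
)
```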
@@ -20,9 +20,6 @@ The TableTestRunner (`test_table_runner.py`) provides a robust table-driven test
- **Mock configuration** - Seamless integration with the auto-mock system
- **Performance metrics** - Track execution times and bottlenecks
- **Detailed error reporting** - Comprehensive failure diagnostics
- **Test tagging** - Organize and filter tests by tags
- **Retry mechanism** - Handle flaky tests gracefully
- **Custom validators** - Define custom validation logic

### Basic Usage

@@ -68,49 +65,6 @@ suite_result = runner.run_table_tests(
print(f"Success rate: {suite_result.success_rate:.1f}%")
```

#### Test Tagging and Filtering

```python
test_case = WorkflowTestCase(
    fixture_path="workflow",
    inputs={},
    expected_outputs={},
    tags=["smoke", "critical"],
)

# Run only tests with specific tags
suite_result = runner.run_table_tests(
    test_cases,
    tags_filter=["smoke"]
)
```

#### Retry Mechanism

```python
test_case = WorkflowTestCase(
    fixture_path="flaky_workflow",
    inputs={},
    expected_outputs={},
    retry_count=2,  # Retry up to 2 times on failure
)
```

#### Custom Validators

```python
def custom_validator(outputs: dict) -> bool:
    # Custom validation logic
    return "error" not in outputs.get("status", "")

test_case = WorkflowTestCase(
    fixture_path="workflow",
    inputs={},
    expected_outputs={"status": "success"},
    custom_validator=custom_validator,
)
```

#### Event Sequence Validation

```python

@@ -4,7 +4,6 @@ from __future__ import annotations

from datetime import datetime

from core.workflow.entities import GraphRuntimeState, VariablePool
from core.workflow.enums import NodeExecutionType, NodeState, NodeType, WorkflowNodeExecutionStatus
from core.workflow.graph import Graph
from core.workflow.graph_engine.domain.graph_execution import GraphExecution
@@ -16,6 +15,7 @@ from core.workflow.graph_engine.response_coordinator.coordinator import Response
from core.workflow.graph_events import NodeRunRetryEvent, NodeRunStartedEvent
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.entities import RetryConfig
from core.workflow.runtime import GraphRuntimeState, VariablePool


class _StubEdgeProcessor:

@@ -3,12 +3,12 @@
import time
from unittest.mock import MagicMock

from core.workflow.entities import GraphRuntimeState, VariablePool
from core.workflow.graph import Graph
from core.workflow.graph_engine import GraphEngine
from core.workflow.graph_engine.command_channels import InMemoryChannel
from core.workflow.graph_engine.entities.commands import AbortCommand
from core.workflow.graph_events import GraphRunAbortedEvent, GraphRunStartedEvent
from core.workflow.graph_engine.entities.commands import AbortCommand, CommandType, PauseCommand
from core.workflow.graph_events import GraphRunAbortedEvent, GraphRunPausedEvent, GraphRunStartedEvent
from core.workflow.runtime import GraphRuntimeState, VariablePool


def test_abort_command():
@@ -100,8 +100,57 @@ def test_redis_channel_serialization():
    assert command_data["command_type"] == "abort"
    assert command_data["reason"] == "Test abort"

    # Test pause command serialization
    pause_command = PauseCommand(reason="User requested pause")
    channel.send_command(pause_command)

if __name__ == "__main__":
    test_abort_command()
    test_redis_channel_serialization()
    print("All tests passed!")
    assert len(mock_pipeline.rpush.call_args_list) == 2
    second_call_args = mock_pipeline.rpush.call_args_list[1]
    pause_command_json = second_call_args[0][1]
    pause_command_data = json.loads(pause_command_json)
    assert pause_command_data["command_type"] == CommandType.PAUSE.value
    assert pause_command_data["reason"] == "User requested pause"


def test_pause_command():
    """Test that GraphEngine properly handles pause commands."""

    shared_runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())

    mock_graph = MagicMock(spec=Graph)
    mock_graph.nodes = {}
    mock_graph.edges = {}
    mock_graph.root_node = MagicMock()
    mock_graph.root_node.id = "start"

    mock_start_node = MagicMock()
    mock_start_node.state = None
    mock_start_node.id = "start"
    mock_start_node.graph_runtime_state = shared_runtime_state
    mock_graph.nodes["start"] = mock_start_node

    mock_graph.get_outgoing_edges = MagicMock(return_value=[])
    mock_graph.get_incoming_edges = MagicMock(return_value=[])

    command_channel = InMemoryChannel()

    engine = GraphEngine(
        workflow_id="test_workflow",
        graph=mock_graph,
        graph_runtime_state=shared_runtime_state,
        command_channel=command_channel,
    )

    pause_command = PauseCommand(reason="User requested pause")
    command_channel.send_command(pause_command)

    events = list(engine.run())

    assert any(isinstance(e, GraphRunStartedEvent) for e in events)
    pause_events = [e for e in events if isinstance(e, GraphRunPausedEvent)]
    assert len(pause_events) == 1
    assert pause_events[0].reason == "User requested pause"

    graph_execution = engine.graph_runtime_state.graph_execution
    assert graph_execution.is_paused
    assert graph_execution.pause_reason == "User requested pause"
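
The same pattern works outside tests: any holder of the engine's command channel can request a pause and then persist state once the paused event arrives. A minimal sketch reusing the objects above; the snapshot step assumes the `dumps()` serialization shown in the runtime-state tests:

```python
# A minimal sketch, assuming an engine wired to a shared command channel as above.
command_channel.send_command(PauseCommand(reason="Awaiting human input"))

for event in engine.run():
    if isinstance(event, GraphRunPausedEvent):
        snapshot = engine.graph_runtime_state.dumps()  # persist and resume later
        break
```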
@@ -21,6 +21,7 @@ class _StubExecutionCoordinator:
        self._execution_complete = False
        self.mark_complete_called = False
        self.failed = False
        self._paused = False

    def check_commands(self) -> None:
        self.command_checks += 1
@@ -28,6 +29,10 @@ class _StubExecutionCoordinator:
    def check_scaling(self) -> None:
        self.scaling_checks += 1

    @property
    def is_paused(self) -> bool:
        return self._paused

    def is_execution_complete(self) -> bool:
        return self._execution_complete

@@ -96,7 +101,7 @@ def _make_succeeded_event() -> NodeRunSucceededEvent:


def test_dispatcher_checks_commands_during_idle_and_on_completion() -> None:
    """Dispatcher polls commands when idle and re-checks after completion events."""
    """Dispatcher polls commands when idle and after completion events."""
    started_checks = _run_dispatcher_for_event(_make_started_event())
    succeeded_checks = _run_dispatcher_for_event(_make_succeeded_event())


@@ -0,0 +1,62 @@
"""Unit tests for the execution coordinator orchestration logic."""

from unittest.mock import MagicMock

from core.workflow.graph_engine.command_processing.command_processor import CommandProcessor
from core.workflow.graph_engine.domain.graph_execution import GraphExecution
from core.workflow.graph_engine.graph_state_manager import GraphStateManager
from core.workflow.graph_engine.orchestration.execution_coordinator import ExecutionCoordinator
from core.workflow.graph_engine.worker_management.worker_pool import WorkerPool


def _build_coordinator(graph_execution: GraphExecution) -> tuple[ExecutionCoordinator, MagicMock, MagicMock]:
    command_processor = MagicMock(spec=CommandProcessor)
    state_manager = MagicMock(spec=GraphStateManager)
    worker_pool = MagicMock(spec=WorkerPool)

    coordinator = ExecutionCoordinator(
        graph_execution=graph_execution,
        state_manager=state_manager,
        command_processor=command_processor,
        worker_pool=worker_pool,
    )
    return coordinator, state_manager, worker_pool


def test_handle_pause_stops_workers_and_clears_state() -> None:
    """Paused execution should stop workers and clear executing state."""
    graph_execution = GraphExecution(workflow_id="workflow")
    graph_execution.start()
    graph_execution.pause("Awaiting human input")

    coordinator, state_manager, worker_pool = _build_coordinator(graph_execution)

    coordinator.handle_pause_if_needed()

    worker_pool.stop.assert_called_once_with()
    state_manager.clear_executing.assert_called_once_with()


def test_handle_pause_noop_when_execution_running() -> None:
    """Running execution should not trigger pause handling."""
    graph_execution = GraphExecution(workflow_id="workflow")
    graph_execution.start()

    coordinator, state_manager, worker_pool = _build_coordinator(graph_execution)

    coordinator.handle_pause_if_needed()

    worker_pool.stop.assert_not_called()
    state_manager.clear_executing.assert_not_called()


def test_is_execution_complete_when_paused() -> None:
    """Paused execution should be treated as complete."""
    graph_execution = GraphExecution(workflow_id="workflow")
    graph_execution.start()
    graph_execution.pause("Awaiting input")

    coordinator, state_manager, _worker_pool = _build_coordinator(graph_execution)
    state_manager.is_execution_complete.return_value = False

    assert coordinator.is_execution_complete()
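
A hedged reading of what these assertions imply about the coordinator: a pause short-circuits the completion check, so the dispatch loop can wind down without draining the remaining nodes. The assumed shape, inferred from the tests rather than taken from the implementation:

```python
# Assumed shape of ExecutionCoordinator.is_execution_complete, inferred from
# the tests above; the real method may differ.
def is_execution_complete(self) -> bool:
    if self._graph_execution.is_paused:
        return True  # paused runs end the dispatch loop immediately
    return self._state_manager.is_execution_complete()
```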
|
||||
@@ -0,0 +1,341 @@
import time
from collections.abc import Iterable

from core.model_runtime.entities.llm_entities import LLMMode
from core.model_runtime.entities.message_entities import PromptMessageRole
from core.workflow.entities import GraphInitParams
from core.workflow.graph import Graph
from core.workflow.graph_events import (
    GraphRunPausedEvent,
    GraphRunStartedEvent,
    GraphRunSucceededEvent,
    NodeRunPauseRequestedEvent,
    NodeRunStartedEvent,
    NodeRunStreamChunkEvent,
    NodeRunSucceededEvent,
)
from core.workflow.nodes.base.entities import VariableSelector
from core.workflow.nodes.end.end_node import EndNode
from core.workflow.nodes.end.entities import EndNodeData
from core.workflow.nodes.human_input import HumanInputNode
from core.workflow.nodes.human_input.entities import HumanInputNodeData
from core.workflow.nodes.llm.entities import (
    ContextConfig,
    LLMNodeChatModelMessage,
    LLMNodeData,
    ModelConfig,
    VisionConfig,
)
from core.workflow.nodes.start.entities import StartNodeData
from core.workflow.nodes.start.start_node import StartNode
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable

from .test_mock_config import MockConfig
from .test_mock_nodes import MockLLMNode
from .test_table_runner import TableTestRunner, WorkflowTestCase


def _build_branching_graph(mock_config: MockConfig) -> tuple[Graph, GraphRuntimeState]:
    graph_config: dict[str, object] = {"nodes": [], "edges": []}
    graph_init_params = GraphInitParams(
        tenant_id="tenant",
        app_id="app",
        workflow_id="workflow",
        graph_config=graph_config,
        user_id="user",
        user_from="account",
        invoke_from="debugger",
        call_depth=0,
    )

    variable_pool = VariablePool(
        system_variables=SystemVariable(user_id="user", app_id="app", workflow_id="workflow"),
        user_inputs={},
        conversation_variables=[],
    )
    graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())

    start_config = {"id": "start", "data": StartNodeData(title="Start", variables=[]).model_dump()}
    start_node = StartNode(
        id=start_config["id"],
        config=start_config,
        graph_init_params=graph_init_params,
        graph_runtime_state=graph_runtime_state,
    )
    start_node.init_node_data(start_config["data"])

    def _create_llm_node(node_id: str, title: str, prompt_text: str) -> MockLLMNode:
        llm_data = LLMNodeData(
            title=title,
            model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode=LLMMode.CHAT, completion_params={}),
            prompt_template=[
                LLMNodeChatModelMessage(
                    text=prompt_text,
                    role=PromptMessageRole.USER,
                    edition_type="basic",
                )
            ],
            context=ContextConfig(enabled=False, variable_selector=None),
            vision=VisionConfig(enabled=False),
            reasoning_format="tagged",
        )
        llm_config = {"id": node_id, "data": llm_data.model_dump()}
        llm_node = MockLLMNode(
            id=node_id,
            config=llm_config,
            graph_init_params=graph_init_params,
            graph_runtime_state=graph_runtime_state,
            mock_config=mock_config,
        )
        llm_node.init_node_data(llm_config["data"])
        return llm_node

    llm_initial = _create_llm_node("llm_initial", "Initial LLM", "Initial stream")

    human_data = HumanInputNodeData(
        title="Human Input",
        required_variables=["human.input_ready"],
        pause_reason="Awaiting human input",
    )
    human_config = {"id": "human", "data": human_data.model_dump()}
    human_node = HumanInputNode(
        id=human_config["id"],
        config=human_config,
        graph_init_params=graph_init_params,
        graph_runtime_state=graph_runtime_state,
    )
    human_node.init_node_data(human_config["data"])

    llm_primary = _create_llm_node("llm_primary", "Primary LLM", "Primary stream output")
    llm_secondary = _create_llm_node("llm_secondary", "Secondary LLM", "Secondary")

    end_primary_data = EndNodeData(
        title="End Primary",
        outputs=[
            VariableSelector(variable="initial_text", value_selector=["llm_initial", "text"]),
            VariableSelector(variable="primary_text", value_selector=["llm_primary", "text"]),
        ],
        desc=None,
    )
    end_primary_config = {"id": "end_primary", "data": end_primary_data.model_dump()}
    end_primary = EndNode(
        id=end_primary_config["id"],
        config=end_primary_config,
        graph_init_params=graph_init_params,
        graph_runtime_state=graph_runtime_state,
    )
    end_primary.init_node_data(end_primary_config["data"])

    end_secondary_data = EndNodeData(
        title="End Secondary",
        outputs=[
            VariableSelector(variable="initial_text", value_selector=["llm_initial", "text"]),
            VariableSelector(variable="secondary_text", value_selector=["llm_secondary", "text"]),
        ],
        desc=None,
    )
    end_secondary_config = {"id": "end_secondary", "data": end_secondary_data.model_dump()}
    end_secondary = EndNode(
        id=end_secondary_config["id"],
        config=end_secondary_config,
        graph_init_params=graph_init_params,
        graph_runtime_state=graph_runtime_state,
    )
    end_secondary.init_node_data(end_secondary_config["data"])

    graph = (
        Graph.new()
        .add_root(start_node)
        .add_node(llm_initial)
        .add_node(human_node)
        .add_node(llm_primary, from_node_id="human", source_handle="primary")
        .add_node(end_primary, from_node_id="llm_primary")
        .add_node(llm_secondary, from_node_id="human", source_handle="secondary")
        .add_node(end_secondary, from_node_id="llm_secondary")
        .build()
    )
    return graph, graph_runtime_state


def _expected_mock_llm_chunks(text: str) -> list[str]:
    chunks: list[str] = []
    for index, word in enumerate(text.split(" ")):
        chunk = word if index == 0 else f" {word}"
        chunks.append(chunk)
    chunks.append("")
    return chunks


def _assert_stream_chunk_sequence(
    chunk_events: Iterable[NodeRunStreamChunkEvent],
    expected_nodes: list[str],
    expected_chunks: list[str],
) -> None:
    actual_nodes = [event.node_id for event in chunk_events]
    actual_chunks = [event.chunk for event in chunk_events]
    assert actual_nodes == expected_nodes
    assert actual_chunks == expected_chunks


def test_human_input_llm_streaming_across_multiple_branches() -> None:
    mock_config = MockConfig()
    mock_config.set_node_outputs("llm_initial", {"text": "Initial stream"})
    mock_config.set_node_outputs("llm_primary", {"text": "Primary stream output"})
    mock_config.set_node_outputs("llm_secondary", {"text": "Secondary"})

    branch_scenarios = [
        {
            "handle": "primary",
            "resume_llm": "llm_primary",
            "end_node": "end_primary",
            "expected_pre_chunks": [
                ("llm_initial", _expected_mock_llm_chunks("Initial stream")),  # cached output before branch completes
                ("end_primary", ["\n"]),  # literal segment emitted when end_primary session activates
            ],
            "expected_post_chunks": [
                ("llm_primary", _expected_mock_llm_chunks("Primary stream output")),  # live stream from chosen branch
            ],
        },
        {
            "handle": "secondary",
            "resume_llm": "llm_secondary",
            "end_node": "end_secondary",
            "expected_pre_chunks": [
                ("llm_initial", _expected_mock_llm_chunks("Initial stream")),  # cached output before branch completes
                ("end_secondary", ["\n"]),  # literal segment emitted when end_secondary session activates
            ],
            "expected_post_chunks": [
                ("llm_secondary", _expected_mock_llm_chunks("Secondary")),  # live stream from chosen branch
            ],
        },
    ]

    for scenario in branch_scenarios:
        runner = TableTestRunner()

        def initial_graph_factory() -> tuple[Graph, GraphRuntimeState]:
            return _build_branching_graph(mock_config)

        initial_case = WorkflowTestCase(
            description="HumanInput pause before branching decision",
            graph_factory=initial_graph_factory,
            expected_event_sequence=[
                GraphRunStartedEvent,  # initial run: graph execution starts
                NodeRunStartedEvent,  # start node begins execution
                NodeRunSucceededEvent,  # start node completes
                NodeRunStartedEvent,  # llm_initial starts streaming
                NodeRunSucceededEvent,  # llm_initial completes streaming
                NodeRunStartedEvent,  # human node begins and issues pause
                NodeRunPauseRequestedEvent,  # human node requests pause awaiting input
                GraphRunPausedEvent,  # graph run pauses awaiting resume
            ],
        )

        initial_result = runner.run_test_case(initial_case)

        assert initial_result.success, initial_result.event_mismatch_details
        assert not any(isinstance(event, NodeRunStreamChunkEvent) for event in initial_result.events)

        graph_runtime_state = initial_result.graph_runtime_state
        graph = initial_result.graph
        assert graph_runtime_state is not None
        assert graph is not None

        graph_runtime_state.variable_pool.add(("human", "input_ready"), True)
        graph_runtime_state.variable_pool.add(("human", "edge_source_handle"), scenario["handle"])
        graph_runtime_state.graph_execution.pause_reason = None

        pre_chunk_count = sum(len(chunks) for _, chunks in scenario["expected_pre_chunks"])
        post_chunk_count = sum(len(chunks) for _, chunks in scenario["expected_post_chunks"])

        expected_resume_sequence: list[type] = (
            [
                GraphRunStartedEvent,
                NodeRunStartedEvent,
            ]
            + [NodeRunStreamChunkEvent] * pre_chunk_count
            + [
                NodeRunSucceededEvent,
                NodeRunStartedEvent,
            ]
            + [NodeRunStreamChunkEvent] * post_chunk_count
            + [
                NodeRunSucceededEvent,
                NodeRunStartedEvent,
                NodeRunSucceededEvent,
                GraphRunSucceededEvent,
            ]
        )

        def resume_graph_factory(
            graph_snapshot: Graph = graph,
            state_snapshot: GraphRuntimeState = graph_runtime_state,
        ) -> tuple[Graph, GraphRuntimeState]:
            return graph_snapshot, state_snapshot

        resume_case = WorkflowTestCase(
            description=f"HumanInput resumes via {scenario['handle']} branch",
            graph_factory=resume_graph_factory,
            expected_event_sequence=expected_resume_sequence,
        )

        resume_result = runner.run_test_case(resume_case)

        assert resume_result.success, resume_result.event_mismatch_details

        resume_events = resume_result.events

        chunk_events = [event for event in resume_events if isinstance(event, NodeRunStreamChunkEvent)]
        assert len(chunk_events) == pre_chunk_count + post_chunk_count

        pre_chunk_events = chunk_events[:pre_chunk_count]
        post_chunk_events = chunk_events[pre_chunk_count:]

        expected_pre_nodes: list[str] = []
        expected_pre_chunks: list[str] = []
        for node_id, chunks in scenario["expected_pre_chunks"]:
            expected_pre_nodes.extend([node_id] * len(chunks))
            expected_pre_chunks.extend(chunks)
        _assert_stream_chunk_sequence(pre_chunk_events, expected_pre_nodes, expected_pre_chunks)

        expected_post_nodes: list[str] = []
        expected_post_chunks: list[str] = []
        for node_id, chunks in scenario["expected_post_chunks"]:
            expected_post_nodes.extend([node_id] * len(chunks))
            expected_post_chunks.extend(chunks)
        _assert_stream_chunk_sequence(post_chunk_events, expected_post_nodes, expected_post_chunks)

        human_success_index = next(
            index
            for index, event in enumerate(resume_events)
            if isinstance(event, NodeRunSucceededEvent) and event.node_id == "human"
        )
        pre_indices = [
            index
            for index, event in enumerate(resume_events)
            if isinstance(event, NodeRunStreamChunkEvent) and index < human_success_index
        ]
        assert pre_indices == list(range(2, 2 + pre_chunk_count))

        resume_chunk_indices = [
            index
            for index, event in enumerate(resume_events)
            if isinstance(event, NodeRunStreamChunkEvent) and event.node_id == scenario["resume_llm"]
        ]
        assert resume_chunk_indices, "Expected streaming output from the selected branch"
        resume_start_index = next(
            index
            for index, event in enumerate(resume_events)
            if isinstance(event, NodeRunStartedEvent) and event.node_id == scenario["resume_llm"]
        )
        resume_success_index = next(
            index
            for index, event in enumerate(resume_events)
            if isinstance(event, NodeRunSucceededEvent) and event.node_id == scenario["resume_llm"]
        )
        assert resume_start_index < min(resume_chunk_indices)
        assert max(resume_chunk_indices) < resume_success_index

        started_nodes = [event.node_id for event in resume_events if isinstance(event, NodeRunStartedEvent)]
        assert started_nodes == ["human", scenario["resume_llm"], scenario["end_node"]]
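As a quick sanity check on the _expected_mock_llm_chunks helper above (plain arithmetic on its loop, not new behavior): the first word is emitted bare, every later word carries the space that separated it, and a final empty chunk marks the end of the stream.

assert _expected_mock_llm_chunks("Primary stream output") == ["Primary", " stream", " output", ""]
assert _expected_mock_llm_chunks("Secondary") == ["Secondary", ""]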
@@ -0,0 +1,297 @@
import time

from core.model_runtime.entities.llm_entities import LLMMode
from core.model_runtime.entities.message_entities import PromptMessageRole
from core.workflow.entities import GraphInitParams
from core.workflow.graph import Graph
from core.workflow.graph_events import (
    GraphRunPausedEvent,
    GraphRunStartedEvent,
    GraphRunSucceededEvent,
    NodeRunPauseRequestedEvent,
    NodeRunStartedEvent,
    NodeRunStreamChunkEvent,
    NodeRunSucceededEvent,
)
from core.workflow.nodes.base.entities import VariableSelector
from core.workflow.nodes.end.end_node import EndNode
from core.workflow.nodes.end.entities import EndNodeData
from core.workflow.nodes.human_input import HumanInputNode
from core.workflow.nodes.human_input.entities import HumanInputNodeData
from core.workflow.nodes.llm.entities import (
    ContextConfig,
    LLMNodeChatModelMessage,
    LLMNodeData,
    ModelConfig,
    VisionConfig,
)
from core.workflow.nodes.start.entities import StartNodeData
from core.workflow.nodes.start.start_node import StartNode
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable

from .test_mock_config import MockConfig
from .test_mock_nodes import MockLLMNode
from .test_table_runner import TableTestRunner, WorkflowTestCase


def _build_llm_human_llm_graph(mock_config: MockConfig) -> tuple[Graph, GraphRuntimeState]:
    graph_config: dict[str, object] = {"nodes": [], "edges": []}
    graph_init_params = GraphInitParams(
        tenant_id="tenant",
        app_id="app",
        workflow_id="workflow",
        graph_config=graph_config,
        user_id="user",
        user_from="account",
        invoke_from="debugger",
        call_depth=0,
    )

    variable_pool = VariablePool(
        system_variables=SystemVariable(user_id="user", app_id="app", workflow_id="workflow"),
        user_inputs={},
        conversation_variables=[],
    )
    graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())

    start_config = {"id": "start", "data": StartNodeData(title="Start", variables=[]).model_dump()}
    start_node = StartNode(
        id=start_config["id"],
        config=start_config,
        graph_init_params=graph_init_params,
        graph_runtime_state=graph_runtime_state,
    )
    start_node.init_node_data(start_config["data"])

    def _create_llm_node(node_id: str, title: str, prompt_text: str) -> MockLLMNode:
        llm_data = LLMNodeData(
            title=title,
            model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode=LLMMode.CHAT, completion_params={}),
            prompt_template=[
                LLMNodeChatModelMessage(
                    text=prompt_text,
                    role=PromptMessageRole.USER,
                    edition_type="basic",
                )
            ],
            context=ContextConfig(enabled=False, variable_selector=None),
            vision=VisionConfig(enabled=False),
            reasoning_format="tagged",
        )
        llm_config = {"id": node_id, "data": llm_data.model_dump()}
        llm_node = MockLLMNode(
            id=node_id,
            config=llm_config,
            graph_init_params=graph_init_params,
            graph_runtime_state=graph_runtime_state,
            mock_config=mock_config,
        )
        llm_node.init_node_data(llm_config["data"])
        return llm_node

    llm_first = _create_llm_node("llm_initial", "Initial LLM", "Initial prompt")

    human_data = HumanInputNodeData(
        title="Human Input",
        required_variables=["human.input_ready"],
        pause_reason="Awaiting human input",
    )
    human_config = {"id": "human", "data": human_data.model_dump()}
    human_node = HumanInputNode(
        id=human_config["id"],
        config=human_config,
        graph_init_params=graph_init_params,
        graph_runtime_state=graph_runtime_state,
    )
    human_node.init_node_data(human_config["data"])

    llm_second = _create_llm_node("llm_resume", "Follow-up LLM", "Follow-up prompt")

    end_data = EndNodeData(
        title="End",
        outputs=[
            VariableSelector(variable="initial_text", value_selector=["llm_initial", "text"]),
            VariableSelector(variable="resume_text", value_selector=["llm_resume", "text"]),
        ],
        desc=None,
    )
    end_config = {"id": "end", "data": end_data.model_dump()}
    end_node = EndNode(
        id=end_config["id"],
        config=end_config,
        graph_init_params=graph_init_params,
        graph_runtime_state=graph_runtime_state,
    )
    end_node.init_node_data(end_config["data"])

    graph = (
        Graph.new()
        .add_root(start_node)
        .add_node(llm_first)
        .add_node(human_node)
        .add_node(llm_second)
        .add_node(end_node)
        .build()
    )
    return graph, graph_runtime_state


def _expected_mock_llm_chunks(text: str) -> list[str]:
    chunks: list[str] = []
    for index, word in enumerate(text.split(" ")):
        chunk = word if index == 0 else f" {word}"
        chunks.append(chunk)
    chunks.append("")
    return chunks


def test_human_input_llm_streaming_order_across_pause() -> None:
    runner = TableTestRunner()

    initial_text = "Hello, pause"
    resume_text = "Welcome back!"

    mock_config = MockConfig()
    mock_config.set_node_outputs("llm_initial", {"text": initial_text})
    mock_config.set_node_outputs("llm_resume", {"text": resume_text})

    expected_initial_sequence: list[type] = [
        GraphRunStartedEvent,  # graph run begins
        NodeRunStartedEvent,  # start node begins
        NodeRunSucceededEvent,  # start node completes
        NodeRunStartedEvent,  # llm_initial begins streaming
        NodeRunSucceededEvent,  # llm_initial completes streaming
        NodeRunStartedEvent,  # human node begins and requests pause
        NodeRunPauseRequestedEvent,  # human node pause requested
        GraphRunPausedEvent,  # graph run pauses awaiting resume
    ]

    def graph_factory() -> tuple[Graph, GraphRuntimeState]:
        return _build_llm_human_llm_graph(mock_config)

    initial_case = WorkflowTestCase(
        description="HumanInput pause preserves LLM streaming order",
        graph_factory=graph_factory,
        expected_event_sequence=expected_initial_sequence,
    )

    initial_result = runner.run_test_case(initial_case)

    assert initial_result.success, initial_result.event_mismatch_details

    initial_events = initial_result.events
    initial_chunks = _expected_mock_llm_chunks(initial_text)

    initial_stream_chunk_events = [event for event in initial_events if isinstance(event, NodeRunStreamChunkEvent)]
    assert initial_stream_chunk_events == []

    pause_index = next(i for i, event in enumerate(initial_events) if isinstance(event, GraphRunPausedEvent))
    llm_succeeded_index = next(
        i
        for i, event in enumerate(initial_events)
        if isinstance(event, NodeRunSucceededEvent) and event.node_id == "llm_initial"
    )
    assert llm_succeeded_index < pause_index

    graph_runtime_state = initial_result.graph_runtime_state
    graph = initial_result.graph
    assert graph_runtime_state is not None
    assert graph is not None

    coordinator = graph_runtime_state.response_coordinator
    stream_buffers = coordinator._stream_buffers  # Tests may access internals for assertions
    assert ("llm_initial", "text") in stream_buffers
    initial_stream_chunks = [event.chunk for event in stream_buffers[("llm_initial", "text")]]
    assert initial_stream_chunks == initial_chunks
    assert ("llm_resume", "text") not in stream_buffers

    resume_chunks = _expected_mock_llm_chunks(resume_text)
    expected_resume_sequence: list[type] = [
        GraphRunStartedEvent,  # resumed graph run begins
        NodeRunStartedEvent,  # human node restarts
        NodeRunStreamChunkEvent,  # cached llm_initial chunk 1
        NodeRunStreamChunkEvent,  # cached llm_initial chunk 2
        NodeRunStreamChunkEvent,  # cached llm_initial final chunk
        NodeRunStreamChunkEvent,  # end node emits combined template separator
        NodeRunSucceededEvent,  # human node finishes instantly after input
        NodeRunStartedEvent,  # llm_resume begins streaming
        NodeRunStreamChunkEvent,  # llm_resume chunk 1
        NodeRunStreamChunkEvent,  # llm_resume chunk 2
        NodeRunStreamChunkEvent,  # llm_resume final chunk
        NodeRunSucceededEvent,  # llm_resume completes streaming
        NodeRunStartedEvent,  # end node starts
        NodeRunSucceededEvent,  # end node finishes
        GraphRunSucceededEvent,  # graph run succeeds after resume
    ]

    def resume_graph_factory() -> tuple[Graph, GraphRuntimeState]:
        assert graph_runtime_state is not None
        assert graph is not None
        graph_runtime_state.variable_pool.add(("human", "input_ready"), True)
        graph_runtime_state.graph_execution.pause_reason = None
        return graph, graph_runtime_state

    resume_case = WorkflowTestCase(
        description="HumanInput resume continues LLM streaming order",
        graph_factory=resume_graph_factory,
        expected_event_sequence=expected_resume_sequence,
    )

    resume_result = runner.run_test_case(resume_case)

    assert resume_result.success, resume_result.event_mismatch_details

    resume_events = resume_result.events

    success_index = next(i for i, event in enumerate(resume_events) if isinstance(event, GraphRunSucceededEvent))
    llm_resume_succeeded_index = next(
        i
        for i, event in enumerate(resume_events)
        if isinstance(event, NodeRunSucceededEvent) and event.node_id == "llm_resume"
    )
    assert llm_resume_succeeded_index < success_index

    resume_chunk_events = [event for event in resume_events if isinstance(event, NodeRunStreamChunkEvent)]
    assert [event.node_id for event in resume_chunk_events[:3]] == ["llm_initial"] * 3
    assert [event.chunk for event in resume_chunk_events[:3]] == initial_chunks
    assert resume_chunk_events[3].node_id == "end"
    assert resume_chunk_events[3].chunk == "\n"
    assert [event.node_id for event in resume_chunk_events[4:]] == ["llm_resume"] * 3
    assert [event.chunk for event in resume_chunk_events[4:]] == resume_chunks

    human_success_index = next(
        i
        for i, event in enumerate(resume_events)
        if isinstance(event, NodeRunSucceededEvent) and event.node_id == "human"
    )
    cached_chunk_indices = [
        i
        for i, event in enumerate(resume_events)
        if isinstance(event, NodeRunStreamChunkEvent) and event.node_id in {"llm_initial", "end"}
    ]
    assert all(index < human_success_index for index in cached_chunk_indices)

    llm_resume_start_index = next(
        i
        for i, event in enumerate(resume_events)
        if isinstance(event, NodeRunStartedEvent) and event.node_id == "llm_resume"
    )
    llm_resume_success_index = next(
        i
        for i, event in enumerate(resume_events)
        if isinstance(event, NodeRunSucceededEvent) and event.node_id == "llm_resume"
    )
    llm_resume_chunk_indices = [
        i
        for i, event in enumerate(resume_events)
        if isinstance(event, NodeRunStreamChunkEvent) and event.node_id == "llm_resume"
    ]
    assert llm_resume_chunk_indices
    first_resume_chunk_index = min(llm_resume_chunk_indices)
    last_resume_chunk_index = max(llm_resume_chunk_indices)
    assert llm_resume_start_index < first_resume_chunk_index
    assert last_resume_chunk_index < llm_resume_success_index

    started_nodes = [event.node_id for event in resume_events if isinstance(event, NodeRunStartedEvent)]
    assert started_nodes == ["human", "llm_resume", "end"]
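The resume_graph_factory above is the whole resume protocol in miniature: reuse the paused run's Graph and GraphRuntimeState, satisfy the variable the HumanInputNode was waiting for, and clear pause_reason before re-running. A small sketch that distills the pattern into a reusable helper (an illustration, not engine API):

def make_resume_factory(graph: Graph, state: GraphRuntimeState):
    def factory() -> tuple[Graph, GraphRuntimeState]:
        state.variable_pool.add(("human", "input_ready"), True)  # satisfy the awaited input
        state.graph_execution.pause_reason = None  # mark the pause as resolved
        return graph, state  # same objects, so buffered stream state survives the pause
    return factory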
@@ -0,0 +1,321 @@
import time

from core.model_runtime.entities.llm_entities import LLMMode
from core.model_runtime.entities.message_entities import PromptMessageRole
from core.workflow.entities import GraphInitParams
from core.workflow.graph import Graph
from core.workflow.graph_events import (
    GraphRunStartedEvent,
    GraphRunSucceededEvent,
    NodeRunStartedEvent,
    NodeRunStreamChunkEvent,
    NodeRunSucceededEvent,
)
from core.workflow.nodes.base.entities import VariableSelector
from core.workflow.nodes.end.end_node import EndNode
from core.workflow.nodes.end.entities import EndNodeData
from core.workflow.nodes.if_else.entities import IfElseNodeData
from core.workflow.nodes.if_else.if_else_node import IfElseNode
from core.workflow.nodes.llm.entities import (
    ContextConfig,
    LLMNodeChatModelMessage,
    LLMNodeData,
    ModelConfig,
    VisionConfig,
)
from core.workflow.nodes.start.entities import StartNodeData
from core.workflow.nodes.start.start_node import StartNode
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
from core.workflow.utils.condition.entities import Condition

from .test_mock_config import MockConfig
from .test_mock_nodes import MockLLMNode
from .test_table_runner import TableTestRunner, WorkflowTestCase


def _build_if_else_graph(branch_value: str, mock_config: MockConfig) -> tuple[Graph, GraphRuntimeState]:
    graph_config: dict[str, object] = {"nodes": [], "edges": []}
    graph_init_params = GraphInitParams(
        tenant_id="tenant",
        app_id="app",
        workflow_id="workflow",
        graph_config=graph_config,
        user_id="user",
        user_from="account",
        invoke_from="debugger",
        call_depth=0,
    )

    variable_pool = VariablePool(
        system_variables=SystemVariable(user_id="user", app_id="app", workflow_id="workflow"),
        user_inputs={},
        conversation_variables=[],
    )
    variable_pool.add(("branch", "value"), branch_value)
    graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())

    start_config = {"id": "start", "data": StartNodeData(title="Start", variables=[]).model_dump()}
    start_node = StartNode(
        id=start_config["id"],
        config=start_config,
        graph_init_params=graph_init_params,
        graph_runtime_state=graph_runtime_state,
    )
    start_node.init_node_data(start_config["data"])

    def _create_llm_node(node_id: str, title: str, prompt_text: str) -> MockLLMNode:
        llm_data = LLMNodeData(
            title=title,
            model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode=LLMMode.CHAT, completion_params={}),
            prompt_template=[
                LLMNodeChatModelMessage(
                    text=prompt_text,
                    role=PromptMessageRole.USER,
                    edition_type="basic",
                )
            ],
            context=ContextConfig(enabled=False, variable_selector=None),
            vision=VisionConfig(enabled=False),
            reasoning_format="tagged",
        )
        llm_config = {"id": node_id, "data": llm_data.model_dump()}
        llm_node = MockLLMNode(
            id=node_id,
            config=llm_config,
            graph_init_params=graph_init_params,
            graph_runtime_state=graph_runtime_state,
            mock_config=mock_config,
        )
        llm_node.init_node_data(llm_config["data"])
        return llm_node

    llm_initial = _create_llm_node("llm_initial", "Initial LLM", "Initial stream")

    if_else_data = IfElseNodeData(
        title="IfElse",
        cases=[
            IfElseNodeData.Case(
                case_id="primary",
                logical_operator="and",
                conditions=[
                    Condition(variable_selector=["branch", "value"], comparison_operator="is", value="primary")
                ],
            ),
            IfElseNodeData.Case(
                case_id="secondary",
                logical_operator="and",
                conditions=[
                    Condition(variable_selector=["branch", "value"], comparison_operator="is", value="secondary")
                ],
            ),
        ],
    )
    if_else_config = {"id": "if_else", "data": if_else_data.model_dump()}
    if_else_node = IfElseNode(
        id=if_else_config["id"],
        config=if_else_config,
        graph_init_params=graph_init_params,
        graph_runtime_state=graph_runtime_state,
    )
    if_else_node.init_node_data(if_else_config["data"])

    llm_primary = _create_llm_node("llm_primary", "Primary LLM", "Primary stream output")
    llm_secondary = _create_llm_node("llm_secondary", "Secondary LLM", "Secondary")

    end_primary_data = EndNodeData(
        title="End Primary",
        outputs=[
            VariableSelector(variable="initial_text", value_selector=["llm_initial", "text"]),
            VariableSelector(variable="primary_text", value_selector=["llm_primary", "text"]),
        ],
        desc=None,
    )
    end_primary_config = {"id": "end_primary", "data": end_primary_data.model_dump()}
    end_primary = EndNode(
        id=end_primary_config["id"],
        config=end_primary_config,
        graph_init_params=graph_init_params,
        graph_runtime_state=graph_runtime_state,
    )
    end_primary.init_node_data(end_primary_config["data"])

    end_secondary_data = EndNodeData(
        title="End Secondary",
        outputs=[
            VariableSelector(variable="initial_text", value_selector=["llm_initial", "text"]),
            VariableSelector(variable="secondary_text", value_selector=["llm_secondary", "text"]),
        ],
        desc=None,
    )
    end_secondary_config = {"id": "end_secondary", "data": end_secondary_data.model_dump()}
    end_secondary = EndNode(
        id=end_secondary_config["id"],
        config=end_secondary_config,
        graph_init_params=graph_init_params,
        graph_runtime_state=graph_runtime_state,
    )
    end_secondary.init_node_data(end_secondary_config["data"])

    graph = (
        Graph.new()
        .add_root(start_node)
        .add_node(llm_initial)
        .add_node(if_else_node)
        .add_node(llm_primary, from_node_id="if_else", source_handle="primary")
        .add_node(end_primary, from_node_id="llm_primary")
        .add_node(llm_secondary, from_node_id="if_else", source_handle="secondary")
        .add_node(end_secondary, from_node_id="llm_secondary")
        .build()
    )
    return graph, graph_runtime_state


def _expected_mock_llm_chunks(text: str) -> list[str]:
    chunks: list[str] = []
    for index, word in enumerate(text.split(" ")):
        chunk = word if index == 0 else f" {word}"
        chunks.append(chunk)
    chunks.append("")
    return chunks


def test_if_else_llm_streaming_order() -> None:
    mock_config = MockConfig()
    mock_config.set_node_outputs("llm_initial", {"text": "Initial stream"})
    mock_config.set_node_outputs("llm_primary", {"text": "Primary stream output"})
    mock_config.set_node_outputs("llm_secondary", {"text": "Secondary"})

    scenarios = [
        {
            "branch": "primary",
            "resume_llm": "llm_primary",
            "end_node": "end_primary",
            "expected_sequence": [
                GraphRunStartedEvent,  # graph run begins
                NodeRunStartedEvent,  # start node begins execution
                NodeRunSucceededEvent,  # start node completes
                NodeRunStartedEvent,  # llm_initial starts and streams
                NodeRunSucceededEvent,  # llm_initial completes streaming
                NodeRunStartedEvent,  # if_else evaluates conditions
                NodeRunStreamChunkEvent,  # cached llm_initial chunk 1 flushed
                NodeRunStreamChunkEvent,  # cached llm_initial chunk 2 flushed
                NodeRunStreamChunkEvent,  # cached llm_initial final chunk flushed
                NodeRunStreamChunkEvent,  # template literal newline emitted
                NodeRunSucceededEvent,  # if_else completes branch selection
                NodeRunStartedEvent,  # llm_primary begins streaming
                NodeRunStreamChunkEvent,  # llm_primary chunk 1
                NodeRunStreamChunkEvent,  # llm_primary chunk 2
                NodeRunStreamChunkEvent,  # llm_primary chunk 3
                NodeRunStreamChunkEvent,  # llm_primary final chunk
                NodeRunSucceededEvent,  # llm_primary completes streaming
                NodeRunStartedEvent,  # end_primary node starts
                NodeRunSucceededEvent,  # end_primary finishes aggregation
                GraphRunSucceededEvent,  # graph run succeeds
            ],
            "expected_chunks": [
                ("llm_initial", _expected_mock_llm_chunks("Initial stream")),
                ("end_primary", ["\n"]),
                ("llm_primary", _expected_mock_llm_chunks("Primary stream output")),
            ],
        },
        {
            "branch": "secondary",
            "resume_llm": "llm_secondary",
            "end_node": "end_secondary",
            "expected_sequence": [
                GraphRunStartedEvent,  # graph run begins
                NodeRunStartedEvent,  # start node begins execution
                NodeRunSucceededEvent,  # start node completes
                NodeRunStartedEvent,  # llm_initial starts and streams
                NodeRunSucceededEvent,  # llm_initial completes streaming
                NodeRunStartedEvent,  # if_else evaluates conditions
                NodeRunStreamChunkEvent,  # cached llm_initial chunk 1 flushed
                NodeRunStreamChunkEvent,  # cached llm_initial chunk 2 flushed
                NodeRunStreamChunkEvent,  # cached llm_initial final chunk flushed
                NodeRunStreamChunkEvent,  # template literal newline emitted
                NodeRunSucceededEvent,  # if_else completes branch selection
                NodeRunStartedEvent,  # llm_secondary begins streaming
                NodeRunStreamChunkEvent,  # llm_secondary chunk 1
                NodeRunStreamChunkEvent,  # llm_secondary final chunk
                NodeRunSucceededEvent,  # llm_secondary completes
                NodeRunStartedEvent,  # end_secondary node starts
                NodeRunSucceededEvent,  # end_secondary finishes aggregation
                GraphRunSucceededEvent,  # graph run succeeds
            ],
            "expected_chunks": [
                ("llm_initial", _expected_mock_llm_chunks("Initial stream")),
                ("end_secondary", ["\n"]),
                ("llm_secondary", _expected_mock_llm_chunks("Secondary")),
            ],
        },
    ]

    for scenario in scenarios:
        runner = TableTestRunner()

        def graph_factory(
            branch_value: str = scenario["branch"],
            cfg: MockConfig = mock_config,
        ) -> tuple[Graph, GraphRuntimeState]:
            return _build_if_else_graph(branch_value, cfg)

        test_case = WorkflowTestCase(
            description=f"IfElse streaming via {scenario['branch']} branch",
            graph_factory=graph_factory,
            expected_event_sequence=scenario["expected_sequence"],
        )

        result = runner.run_test_case(test_case)

        assert result.success, result.event_mismatch_details

        chunk_events = [event for event in result.events if isinstance(event, NodeRunStreamChunkEvent)]
        expected_nodes: list[str] = []
        expected_chunks: list[str] = []
        for node_id, chunks in scenario["expected_chunks"]:
            expected_nodes.extend([node_id] * len(chunks))
            expected_chunks.extend(chunks)
        assert [event.node_id for event in chunk_events] == expected_nodes
        assert [event.chunk for event in chunk_events] == expected_chunks

        branch_node_index = next(
            index
            for index, event in enumerate(result.events)
            if isinstance(event, NodeRunStartedEvent) and event.node_id == "if_else"
        )
        branch_success_index = next(
            index
            for index, event in enumerate(result.events)
            if isinstance(event, NodeRunSucceededEvent) and event.node_id == "if_else"
        )
        pre_branch_chunk_indices = [
            index
            for index, event in enumerate(result.events)
            if isinstance(event, NodeRunStreamChunkEvent) and index < branch_success_index
        ]
        assert len(pre_branch_chunk_indices) == len(_expected_mock_llm_chunks("Initial stream")) + 1
        assert min(pre_branch_chunk_indices) == branch_node_index + 1
        assert max(pre_branch_chunk_indices) < branch_success_index

        resume_chunk_indices = [
            index
            for index, event in enumerate(result.events)
            if isinstance(event, NodeRunStreamChunkEvent) and event.node_id == scenario["resume_llm"]
        ]
        assert resume_chunk_indices
        resume_start_index = next(
            index
            for index, event in enumerate(result.events)
            if isinstance(event, NodeRunStartedEvent) and event.node_id == scenario["resume_llm"]
        )
        resume_success_index = next(
            index
            for index, event in enumerate(result.events)
            if isinstance(event, NodeRunSucceededEvent) and event.node_id == scenario["resume_llm"]
        )
        assert resume_start_index < min(resume_chunk_indices)
        assert max(resume_chunk_indices) < resume_success_index

        started_nodes = [event.node_id for event in result.events if isinstance(event, NodeRunStartedEvent)]
        assert started_nodes == ["start", "llm_initial", "if_else", scenario["resume_llm"], scenario["end_node"]]
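One detail worth calling out in the loop above: graph_factory binds the scenario values as default arguments because Python closures capture variables by reference, so without the defaults every iteration's factory would see the last scenario. A two-line illustration of the same idiom:

callbacks = [lambda v=branch: v for branch in ("primary", "secondary")]
assert [cb() for cb in callbacks] == ["primary", "secondary"]  # defaults freeze each loop value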
@@ -27,7 +27,8 @@ from .test_mock_nodes import (
)

if TYPE_CHECKING:
    from core.workflow.entities import GraphInitParams, GraphRuntimeState
    from core.workflow.entities import GraphInitParams
    from core.workflow.runtime import GraphRuntimeState

    from .test_mock_config import MockConfig

@@ -42,7 +42,8 @@ def test_mock_iteration_node_preserves_config():
    """Test that MockIterationNode preserves mock configuration."""

    from core.app.entities.app_invoke_entities import InvokeFrom
    from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool
    from core.workflow.entities import GraphInitParams
    from core.workflow.runtime import GraphRuntimeState, VariablePool
    from models.enums import UserFrom
    from tests.unit_tests.core.workflow.graph_engine.test_mock_nodes import MockIterationNode

@@ -103,7 +104,8 @@ def test_mock_loop_node_preserves_config():
    """Test that MockLoopNode preserves mock configuration."""

    from core.app.entities.app_invoke_entities import InvokeFrom
    from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool
    from core.workflow.entities import GraphInitParams
    from core.workflow.runtime import GraphRuntimeState, VariablePool
    from models.enums import UserFrom
    from tests.unit_tests.core.workflow.graph_engine.test_mock_nodes import MockLoopNode

@@ -24,7 +24,8 @@ from core.workflow.nodes.template_transform import TemplateTransformNode
from core.workflow.nodes.tool import ToolNode

if TYPE_CHECKING:
    from core.workflow.entities import GraphInitParams, GraphRuntimeState
    from core.workflow.entities import GraphInitParams
    from core.workflow.runtime import GraphRuntimeState

    from .test_mock_config import MockConfig

@@ -561,10 +562,11 @@ class MockIterationNode(MockNodeMixin, IterationNode):
    def _create_graph_engine(self, index: int, item: Any):
        """Create a graph engine with MockNodeFactory instead of DifyNodeFactory."""
        # Import dependencies
        from core.workflow.entities import GraphInitParams, GraphRuntimeState
        from core.workflow.entities import GraphInitParams
        from core.workflow.graph import Graph
        from core.workflow.graph_engine import GraphEngine
        from core.workflow.graph_engine.command_channels import InMemoryChannel
        from core.workflow.runtime import GraphRuntimeState

        # Import our MockNodeFactory instead of DifyNodeFactory
        from .test_mock_factory import MockNodeFactory

@@ -635,10 +637,11 @@ class MockLoopNode(MockNodeMixin, LoopNode):
    def _create_graph_engine(self, start_at, root_node_id: str):
        """Create a graph engine with MockNodeFactory instead of DifyNodeFactory."""
        # Import dependencies
        from core.workflow.entities import GraphInitParams, GraphRuntimeState
        from core.workflow.entities import GraphInitParams
        from core.workflow.graph import Graph
        from core.workflow.graph_engine import GraphEngine
        from core.workflow.graph_engine.command_channels import InMemoryChannel
        from core.workflow.runtime import GraphRuntimeState

        # Import our MockNodeFactory instead of DifyNodeFactory
        from .test_mock_factory import MockNodeFactory

@@ -16,8 +16,8 @@ class TestMockTemplateTransformNode:

    def test_mock_template_transform_node_default_output(self):
        """Test that MockTemplateTransformNode processes templates with Jinja2."""
        from core.workflow.entities import GraphInitParams, GraphRuntimeState
        from core.workflow.entities.variable_pool import VariablePool
        from core.workflow.entities import GraphInitParams
        from core.workflow.runtime import GraphRuntimeState, VariablePool

        # Create test parameters
        graph_init_params = GraphInitParams(

@@ -76,8 +76,8 @@ class TestMockTemplateTransformNode:

    def test_mock_template_transform_node_custom_output(self):
        """Test that MockTemplateTransformNode returns custom configured output."""
        from core.workflow.entities import GraphInitParams, GraphRuntimeState
        from core.workflow.entities.variable_pool import VariablePool
        from core.workflow.entities import GraphInitParams
        from core.workflow.runtime import GraphRuntimeState, VariablePool

        # Create test parameters
        graph_init_params = GraphInitParams(

@@ -137,8 +137,8 @@ class TestMockTemplateTransformNode:

    def test_mock_template_transform_node_error_simulation(self):
        """Test that MockTemplateTransformNode can simulate errors."""
        from core.workflow.entities import GraphInitParams, GraphRuntimeState
        from core.workflow.entities.variable_pool import VariablePool
        from core.workflow.entities import GraphInitParams
        from core.workflow.runtime import GraphRuntimeState, VariablePool

        # Create test parameters
        graph_init_params = GraphInitParams(

@@ -196,8 +196,8 @@ class TestMockTemplateTransformNode:
    def test_mock_template_transform_node_with_variables(self):
        """Test that MockTemplateTransformNode processes templates with variables."""
        from core.variables import StringVariable
        from core.workflow.entities import GraphInitParams, GraphRuntimeState
        from core.workflow.entities.variable_pool import VariablePool
        from core.workflow.entities import GraphInitParams
        from core.workflow.runtime import GraphRuntimeState, VariablePool

        # Create test parameters
        graph_init_params = GraphInitParams(

@@ -262,8 +262,8 @@ class TestMockCodeNode:

    def test_mock_code_node_default_output(self):
        """Test that MockCodeNode returns default output."""
        from core.workflow.entities import GraphInitParams, GraphRuntimeState
        from core.workflow.entities.variable_pool import VariablePool
        from core.workflow.entities import GraphInitParams
        from core.workflow.runtime import GraphRuntimeState, VariablePool

        # Create test parameters
        graph_init_params = GraphInitParams(

@@ -323,8 +323,8 @@ class TestMockCodeNode:

    def test_mock_code_node_with_output_schema(self):
        """Test that MockCodeNode generates outputs based on schema."""
        from core.workflow.entities import GraphInitParams, GraphRuntimeState
        from core.workflow.entities.variable_pool import VariablePool
        from core.workflow.entities import GraphInitParams
        from core.workflow.runtime import GraphRuntimeState, VariablePool

        # Create test parameters
        graph_init_params = GraphInitParams(

@@ -392,8 +392,8 @@ class TestMockCodeNode:

    def test_mock_code_node_custom_output(self):
        """Test that MockCodeNode returns custom configured output."""
        from core.workflow.entities import GraphInitParams, GraphRuntimeState
        from core.workflow.entities.variable_pool import VariablePool
        from core.workflow.entities import GraphInitParams
        from core.workflow.runtime import GraphRuntimeState, VariablePool

        # Create test parameters
        graph_init_params = GraphInitParams(

@@ -463,8 +463,8 @@ class TestMockNodeFactory:

    def test_code_and_template_nodes_mocked_by_default(self):
        """Test that CODE and TEMPLATE_TRANSFORM nodes are mocked by default (they require SSRF proxy)."""
        from core.workflow.entities import GraphInitParams, GraphRuntimeState
        from core.workflow.entities.variable_pool import VariablePool
        from core.workflow.entities import GraphInitParams
        from core.workflow.runtime import GraphRuntimeState, VariablePool

        # Create test parameters
        graph_init_params = GraphInitParams(

@@ -504,8 +504,8 @@ class TestMockNodeFactory:

    def test_factory_creates_mock_template_transform_node(self):
        """Test that MockNodeFactory creates MockTemplateTransformNode for template-transform type."""
        from core.workflow.entities import GraphInitParams, GraphRuntimeState
        from core.workflow.entities.variable_pool import VariablePool
        from core.workflow.entities import GraphInitParams
        from core.workflow.runtime import GraphRuntimeState, VariablePool

        # Create test parameters
        graph_init_params = GraphInitParams(

@@ -555,8 +555,8 @@ class TestMockNodeFactory:

    def test_factory_creates_mock_code_node(self):
        """Test that MockNodeFactory creates MockCodeNode for code type."""
        from core.workflow.entities import GraphInitParams, GraphRuntimeState
        from core.workflow.entities.variable_pool import VariablePool
        from core.workflow.entities import GraphInitParams
        from core.workflow.runtime import GraphRuntimeState, VariablePool

        # Create test parameters
        graph_init_params = GraphInitParams(

@@ -13,7 +13,7 @@ from unittest.mock import patch
from uuid import uuid4

from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool
from core.workflow.entities import GraphInitParams
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
from core.workflow.graph import Graph
from core.workflow.graph_engine import GraphEngine

@@ -27,6 +27,7 @@ from core.workflow.graph_events import (
from core.workflow.node_events import NodeRunResult, StreamCompletedEvent
from core.workflow.nodes.llm.node import LLMNode
from core.workflow.nodes.node_factory import DifyNodeFactory
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
from models.enums import UserFrom

@@ -13,7 +13,7 @@ import redis

from core.app.apps.base_app_queue_manager import AppQueueManager
from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
from core.workflow.graph_engine.entities.commands import AbortCommand, CommandType
from core.workflow.graph_engine.entities.commands import AbortCommand, CommandType, PauseCommand
from core.workflow.graph_engine.manager import GraphEngineManager

@@ -52,6 +52,29 @@ class TestRedisStopIntegration:
        assert command_data["command_type"] == CommandType.ABORT
        assert command_data["reason"] == "Test stop"

    def test_graph_engine_manager_sends_pause_command(self):
        """Test that GraphEngineManager correctly sends pause command through Redis."""
        task_id = "test-task-pause-123"
        expected_channel_key = f"workflow:{task_id}:commands"

        mock_redis = MagicMock()
        mock_pipeline = MagicMock()
        mock_redis.pipeline.return_value.__enter__ = Mock(return_value=mock_pipeline)
        mock_redis.pipeline.return_value.__exit__ = Mock(return_value=None)

        with patch("core.workflow.graph_engine.manager.redis_client", mock_redis):
            GraphEngineManager.send_pause_command(task_id, reason="Awaiting resources")

        mock_redis.pipeline.assert_called_once()
        calls = mock_pipeline.rpush.call_args_list
        assert len(calls) == 1
        assert calls[0][0][0] == expected_channel_key

        command_json = calls[0][0][1]
        command_data = json.loads(command_json)
        assert command_data["command_type"] == CommandType.PAUSE.value
        assert command_data["reason"] == "Awaiting resources"

    def test_graph_engine_manager_handles_redis_failure_gracefully(self):
        """Test that GraphEngineManager handles Redis failures without raising exceptions."""
        task_id = "test-task-456"

@@ -105,28 +128,37 @@ class TestRedisStopIntegration:
        channel_key = "workflow:test:commands"
        channel = RedisChannel(mock_redis, channel_key)

        # Create abort command
        # Create commands
        abort_command = AbortCommand(reason="User requested stop")
        pause_command = PauseCommand(reason="User requested pause")

        # Execute
        channel.send_command(abort_command)
        channel.send_command(pause_command)

        # Verify
        mock_redis.pipeline.assert_called_once()
        mock_redis.pipeline.assert_called()

        # Check rpush was called
        calls = mock_pipeline.rpush.call_args_list
        assert len(calls) == 1
        assert len(calls) == 2
        assert calls[0][0][0] == channel_key
        assert calls[1][0][0] == channel_key

        # Verify serialized command
        command_json = calls[0][0][1]
        command_data = json.loads(command_json)
        assert command_data["command_type"] == CommandType.ABORT
        assert command_data["reason"] == "User requested stop"
        # Verify serialized commands
        abort_command_json = calls[0][0][1]
        abort_command_data = json.loads(abort_command_json)
        assert abort_command_data["command_type"] == CommandType.ABORT.value
        assert abort_command_data["reason"] == "User requested stop"

        # Check expire was set
        mock_pipeline.expire.assert_called_once_with(channel_key, 3600)
        pause_command_json = calls[1][0][1]
        pause_command_data = json.loads(pause_command_json)
        assert pause_command_data["command_type"] == CommandType.PAUSE.value
        assert pause_command_data["reason"] == "User requested pause"

        # Check expire was set for each
        assert mock_pipeline.expire.call_count == 2
        mock_pipeline.expire.assert_any_call(channel_key, 3600)

    def test_redis_channel_fetch_commands(self):
        """Test RedisChannel correctly fetches and deserializes commands."""

@@ -143,12 +175,17 @@ class TestRedisStopIntegration:
        mock_redis.pipeline.side_effect = [pending_context, fetch_context]

        # Mock command data
        abort_command_json = json.dumps({"command_type": CommandType.ABORT, "reason": "Test abort", "payload": None})
        abort_command_json = json.dumps(
            {"command_type": CommandType.ABORT.value, "reason": "Test abort", "payload": None}
        )
        pause_command_json = json.dumps(
            {"command_type": CommandType.PAUSE.value, "reason": "Pause requested", "payload": None}
        )

        # Mock pipeline execute to return commands
        pending_pipe.execute.return_value = [b"1", 1]
        fetch_pipe.execute.return_value = [
            [abort_command_json.encode()],  # lrange result
            [abort_command_json.encode(), pause_command_json.encode()],  # lrange result
            True,  # delete result
        ]

@@ -159,10 +196,13 @@ class TestRedisStopIntegration:
        commands = channel.fetch_commands()

        # Verify
        assert len(commands) == 1
        assert len(commands) == 2
        assert isinstance(commands[0], AbortCommand)
        assert commands[0].command_type == CommandType.ABORT
        assert commands[0].reason == "Test abort"
        assert isinstance(commands[1], PauseCommand)
        assert commands[1].command_type == CommandType.PAUSE
        assert commands[1].reason == "Pause requested"

        # Verify Redis operations
        pending_pipe.get.assert_called_once_with(f"{channel_key}:pending")
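Taken together, these tests describe the pause command path end to end: GraphEngineManager.send_pause_command pushes a serialized PauseCommand onto the per-task Redis list, and RedisChannel.fetch_commands drains and deserializes it on the engine side. A minimal sketch of the same round trip against a real connection, assuming a configured redis_client (a hypothetical name here, mocked out in the tests above):

from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
from core.workflow.graph_engine.entities.commands import PauseCommand

channel = RedisChannel(redis_client, "workflow:my-task:commands")  # redis_client is assumed
channel.send_command(PauseCommand(reason="Awaiting human input"))  # producer side
commands = channel.fetch_commands()  # engine side: drains the list and deserializes
assert any(isinstance(command, PauseCommand) for command in commands)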