feat: Human Input Node (#32060)

The frontend and backend implementation for the human input node.

Co-authored-by: twwu <twwu@dify.ai>
Co-authored-by: JzoNg <jzongcode@gmail.com>
Co-authored-by: yyh <92089059+lyzno1@users.noreply.github.com>
Co-authored-by: zhsama <torvalds@linux.do>
This commit is contained in:
QuantumGhost
2026-02-09 14:57:23 +08:00
committed by GitHub
parent 56e3a55023
commit a1fc280102
474 changed files with 32667 additions and 2050 deletions

View File

@@ -4,8 +4,8 @@ import contextvars
import logging
import threading
import uuid
from collections.abc import Generator, Mapping
from typing import TYPE_CHECKING, Any, Literal, Union, overload
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Literal, TypeVar, Union, overload
from flask import Flask, current_app
from pydantic import ValidationError
@@ -29,21 +29,25 @@ from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
from core.app.entities.task_entities import ChatbotAppBlockingResponse, ChatbotAppStreamResponse
from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig, PauseStatePersistenceLayer
from core.helper.trace_id_helper import extract_external_trace_id_from_args
from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.ops.ops_trace_manager import TraceQueueManager
from core.prompt.utils.get_thread_messages_length import get_thread_messages_length
from core.repositories import DifyCoreRepositoryFactory
from core.workflow.graph_engine.layers.base import GraphEngineLayer
from core.workflow.repositories.draft_variable_repository import (
DraftVariableSaverFactory,
)
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from core.workflow.runtime import GraphRuntimeState
from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader
from extensions.ext_database import db
from factories import file_factory
from libs.flask_utils import preserve_flask_contexts
from models import Account, App, Conversation, EndUser, Message, Workflow, WorkflowNodeExecutionTriggeredFrom
from models.base import Base
from models.enums import WorkflowRunTriggeredFrom
from services.conversation_service import ConversationService
from services.workflow_draft_variable_service import (
@@ -65,7 +69,9 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
workflow_run_id: str,
streaming: Literal[False],
pause_state_config: PauseStateLayerConfig | None = None,
) -> Mapping[str, Any]: ...
@overload
@@ -74,9 +80,11 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
args: Mapping,
args: Mapping[str, Any],
invoke_from: InvokeFrom,
workflow_run_id: str,
streaming: Literal[True],
pause_state_config: PauseStateLayerConfig | None = None,
) -> Generator[Mapping | str, None, None]: ...
@overload
@@ -85,9 +93,11 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
args: Mapping,
args: Mapping[str, Any],
invoke_from: InvokeFrom,
workflow_run_id: str,
streaming: bool,
pause_state_config: PauseStateLayerConfig | None = None,
) -> Mapping[str, Any] | Generator[str | Mapping, None, None]: ...
def generate(
@@ -95,9 +105,11 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
args: Mapping,
args: Mapping[str, Any],
invoke_from: InvokeFrom,
workflow_run_id: str,
streaming: bool = True,
pause_state_config: PauseStateLayerConfig | None = None,
) -> Mapping[str, Any] | Generator[str | Mapping, None, None]:
"""
Generate App response.
@@ -161,7 +173,6 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
# always enable retriever resource in debugger mode
app_config.additional_features.show_retrieve_source = True # type: ignore
workflow_run_id = str(uuid.uuid4())
# init application generate entity
application_generate_entity = AdvancedChatAppGenerateEntity(
task_id=str(uuid.uuid4()),
@@ -179,7 +190,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
invoke_from=invoke_from,
extras=extras,
trace_manager=trace_manager,
workflow_run_id=workflow_run_id,
workflow_run_id=str(workflow_run_id),
)
contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(threading.Lock())
@@ -216,6 +227,38 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
workflow_node_execution_repository=workflow_node_execution_repository,
conversation=conversation,
stream=streaming,
pause_state_config=pause_state_config,
)
def resume(
self,
*,
app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
conversation: Conversation,
message: Message,
application_generate_entity: AdvancedChatAppGenerateEntity,
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
graph_runtime_state: GraphRuntimeState,
pause_state_config: PauseStateLayerConfig | None = None,
):
"""
Resume a paused advanced chat execution.
"""
return self._generate(
workflow=workflow,
user=user,
invoke_from=application_generate_entity.invoke_from,
application_generate_entity=application_generate_entity,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
conversation=conversation,
message=message,
stream=application_generate_entity.stream,
pause_state_config=pause_state_config,
graph_runtime_state=graph_runtime_state,
)
def single_iteration_generate(
@@ -396,8 +439,12 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
conversation: Conversation | None = None,
message: Message | None = None,
stream: bool = True,
variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER,
pause_state_config: PauseStateLayerConfig | None = None,
graph_runtime_state: GraphRuntimeState | None = None,
graph_engine_layers: Sequence[GraphEngineLayer] = (),
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], Any, None]:
"""
Generate App response.
@@ -411,12 +458,12 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
:param conversation: conversation
:param stream: is stream
"""
is_first_conversation = False
if not conversation:
is_first_conversation = True
is_first_conversation = conversation is None
# init generate records
(conversation, message) = self._init_generate_records(application_generate_entity, conversation)
if conversation is not None and message is not None:
pass
else:
conversation, message = self._init_generate_records(application_generate_entity, conversation)
if is_first_conversation:
# update conversation features
@@ -439,6 +486,16 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
message_id=message.id,
)
graph_layers: list[GraphEngineLayer] = list(graph_engine_layers)
if pause_state_config is not None:
graph_layers.append(
PauseStatePersistenceLayer(
session_factory=pause_state_config.session_factory,
generate_entity=application_generate_entity,
state_owner_user_id=pause_state_config.state_owner_user_id,
)
)
# new thread with request context and contextvars
context = contextvars.copy_context()
@@ -454,14 +511,25 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
"variable_loader": variable_loader,
"workflow_execution_repository": workflow_execution_repository,
"workflow_node_execution_repository": workflow_node_execution_repository,
"graph_engine_layers": tuple(graph_layers),
"graph_runtime_state": graph_runtime_state,
},
)
worker_thread.start()
# release database connection, because the following new thread operations may take a long time
db.session.refresh(workflow)
db.session.refresh(message)
with Session(bind=db.engine, expire_on_commit=False) as session:
workflow = _refresh_model(session, workflow)
message = _refresh_model(session, message)
# workflow_ = session.get(Workflow, workflow.id)
# assert workflow_ is not None
# workflow = workflow_
# message_ = session.get(Message, message.id)
# assert message_ is not None
# message = message_
# db.session.refresh(workflow)
# db.session.refresh(message)
# db.session.refresh(user)
db.session.close()
@@ -490,6 +558,8 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
variable_loader: VariableLoader,
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
graph_engine_layers: Sequence[GraphEngineLayer] = (),
graph_runtime_state: GraphRuntimeState | None = None,
):
"""
Generate worker in a new thread.
@@ -547,6 +617,8 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
app=app,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
graph_engine_layers=graph_engine_layers,
graph_runtime_state=graph_runtime_state,
)
try:
@@ -614,3 +686,13 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
else:
logger.exception("Failed to process generate task pipeline, conversation_id: %s", conversation.id)
raise e
_T = TypeVar("_T", bound=Base)
def _refresh_model(session, model: _T) -> _T:
with Session(bind=db.engine, expire_on_commit=False) as session:
detach_model = session.get(type(model), model.id)
assert detach_model is not None
return detach_model

View File

@@ -66,6 +66,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
graph_engine_layers: Sequence[GraphEngineLayer] = (),
graph_runtime_state: GraphRuntimeState | None = None,
):
super().__init__(
queue_manager=queue_manager,
@@ -82,6 +83,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
self._app = app
self._workflow_execution_repository = workflow_execution_repository
self._workflow_node_execution_repository = workflow_node_execution_repository
self._resume_graph_runtime_state = graph_runtime_state
@trace_span(WorkflowAppRunnerHandler)
def run(self):
@@ -110,7 +112,21 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
invoke_from = InvokeFrom.DEBUGGER
user_from = self._resolve_user_from(invoke_from)
if self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run:
resume_state = self._resume_graph_runtime_state
if resume_state is not None:
graph_runtime_state = resume_state
variable_pool = graph_runtime_state.variable_pool
graph = self._init_graph(
graph_config=self._workflow.graph_dict,
graph_runtime_state=graph_runtime_state,
workflow_id=self._workflow.id,
tenant_id=self._workflow.tenant_id,
user_id=self.application_generate_entity.user_id,
invoke_from=invoke_from,
user_from=user_from,
)
elif self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run:
# Handle single iteration or single loop run
graph, variable_pool, graph_runtime_state = self._prepare_single_node_execution(
workflow=self._workflow,

View File

@@ -24,6 +24,8 @@ from core.app.entities.queue_entities import (
QueueAgentLogEvent,
QueueAnnotationReplyEvent,
QueueErrorEvent,
QueueHumanInputFormFilledEvent,
QueueHumanInputFormTimeoutEvent,
QueueIterationCompletedEvent,
QueueIterationNextEvent,
QueueIterationStartEvent,
@@ -42,6 +44,7 @@ from core.app.entities.queue_entities import (
QueueTextChunkEvent,
QueueWorkflowFailedEvent,
QueueWorkflowPartialSuccessEvent,
QueueWorkflowPausedEvent,
QueueWorkflowStartedEvent,
QueueWorkflowSucceededEvent,
WorkflowQueueMessage,
@@ -63,6 +66,8 @@ from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.utils.encoders import jsonable_encoder
from core.ops.ops_trace_manager import TraceQueueManager
from core.repositories.human_input_repository import HumanInputFormRepositoryImpl
from core.workflow.entities.pause_reason import HumanInputRequired
from core.workflow.enums import WorkflowExecutionStatus
from core.workflow.nodes import NodeType
from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory
@@ -71,7 +76,8 @@ from core.workflow.system_variable import SystemVariable
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models import Account, Conversation, EndUser, Message, MessageFile
from models.enums import CreatorUserRole
from models.enums import CreatorUserRole, MessageStatus
from models.execution_extra_content import HumanInputContent
from models.workflow import Workflow
logger = logging.getLogger(__name__)
@@ -128,6 +134,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
)
self._task_state = WorkflowTaskState()
self._seed_task_state_from_message(message)
self._message_cycle_manager = MessageCycleManager(
application_generate_entity=application_generate_entity, task_state=self._task_state
)
@@ -135,6 +142,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
self._application_generate_entity = application_generate_entity
self._workflow_id = workflow.id
self._workflow_features_dict = workflow.features_dict
self._workflow_tenant_id = workflow.tenant_id
self._conversation_id = conversation.id
self._conversation_mode = conversation.mode
self._message_id = message.id
@@ -144,8 +152,13 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
self._workflow_run_id: str = ""
self._draft_var_saver_factory = draft_var_saver_factory
self._graph_runtime_state: GraphRuntimeState | None = None
self._message_saved_on_pause = False
self._seed_graph_runtime_state_from_queue_manager()
def _seed_task_state_from_message(self, message: Message) -> None:
if message.status == MessageStatus.PAUSED and message.answer:
self._task_state.answer = message.answer
def process(self) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]:
"""
Process generate task pipeline.
@@ -308,6 +321,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
task_id=self._application_generate_entity.task_id,
workflow_run_id=run_id,
workflow_id=self._workflow_id,
reason=event.reason,
)
yield workflow_start_resp
@@ -525,6 +539,35 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
)
yield workflow_finish_resp
def _handle_workflow_paused_event(
self,
event: QueueWorkflowPausedEvent,
**kwargs,
) -> Generator[StreamResponse, None, None]:
"""Handle workflow paused events."""
validated_state = self._ensure_graph_runtime_initialized()
responses = self._workflow_response_converter.workflow_pause_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
graph_runtime_state=validated_state,
)
for reason in event.reasons:
if isinstance(reason, HumanInputRequired):
self._persist_human_input_extra_content(form_id=reason.form_id, node_id=reason.node_id)
yield from responses
resolved_state: GraphRuntimeState | None = None
try:
resolved_state = self._ensure_graph_runtime_initialized()
except ValueError:
resolved_state = None
with self._database_session() as session:
self._save_message(session=session, graph_runtime_state=resolved_state)
message = self._get_message(session=session)
if message is not None:
message.status = MessageStatus.PAUSED
self._message_saved_on_pause = True
self._base_task_pipeline.queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE)
def _handle_workflow_failed_event(
@@ -614,9 +657,10 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
reason=QueueMessageReplaceEvent.MessageReplaceReason.OUTPUT_MODERATION,
)
# Save message
with self._database_session() as session:
self._save_message(session=session, graph_runtime_state=resolved_state)
# Save message unless it has already been persisted on pause.
if not self._message_saved_on_pause:
with self._database_session() as session:
self._save_message(session=session, graph_runtime_state=resolved_state)
yield self._message_end_to_stream_response()
@@ -642,6 +686,65 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
"""Handle message replace events."""
yield self._message_cycle_manager.message_replace_to_stream_response(answer=event.text, reason=event.reason)
def _handle_human_input_form_filled_event(
self, event: QueueHumanInputFormFilledEvent, **kwargs
) -> Generator[StreamResponse, None, None]:
"""Handle human input form filled events."""
self._persist_human_input_extra_content(node_id=event.node_id)
yield self._workflow_response_converter.human_input_form_filled_to_stream_response(
event=event, task_id=self._application_generate_entity.task_id
)
def _handle_human_input_form_timeout_event(
self, event: QueueHumanInputFormTimeoutEvent, **kwargs
) -> Generator[StreamResponse, None, None]:
"""Handle human input form timeout events."""
yield self._workflow_response_converter.human_input_form_timeout_to_stream_response(
event=event, task_id=self._application_generate_entity.task_id
)
def _persist_human_input_extra_content(self, *, node_id: str | None = None, form_id: str | None = None) -> None:
if not self._workflow_run_id or not self._message_id:
return
if form_id is None:
if node_id is None:
return
form_id = self._load_human_input_form_id(node_id=node_id)
if form_id is None:
logger.warning(
"HumanInput form not found for workflow run %s node %s",
self._workflow_run_id,
node_id,
)
return
with self._database_session() as session:
exists_stmt = select(HumanInputContent).where(
HumanInputContent.workflow_run_id == self._workflow_run_id,
HumanInputContent.message_id == self._message_id,
HumanInputContent.form_id == form_id,
)
if session.scalar(exists_stmt) is not None:
return
content = HumanInputContent(
workflow_run_id=self._workflow_run_id,
message_id=self._message_id,
form_id=form_id,
)
session.add(content)
def _load_human_input_form_id(self, *, node_id: str) -> str | None:
form_repository = HumanInputFormRepositoryImpl(
session_factory=db.engine,
tenant_id=self._workflow_tenant_id,
)
form = form_repository.get_form(self._workflow_run_id, node_id)
if form is None:
return None
return form.id
def _handle_agent_log_event(self, event: QueueAgentLogEvent, **kwargs) -> Generator[StreamResponse, None, None]:
"""Handle agent log events."""
yield self._workflow_response_converter.handle_agent_log(
@@ -659,6 +762,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
QueueWorkflowStartedEvent: self._handle_workflow_started_event,
QueueWorkflowSucceededEvent: self._handle_workflow_succeeded_event,
QueueWorkflowPartialSuccessEvent: self._handle_workflow_partial_success_event,
QueueWorkflowPausedEvent: self._handle_workflow_paused_event,
QueueWorkflowFailedEvent: self._handle_workflow_failed_event,
# Node events
QueueNodeRetryEvent: self._handle_node_retry_event,
@@ -680,6 +784,8 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
QueueMessageReplaceEvent: self._handle_message_replace_event,
QueueAdvancedChatMessageEndEvent: self._handle_advanced_chat_message_end_event,
QueueAgentLogEvent: self._handle_agent_log_event,
QueueHumanInputFormFilledEvent: self._handle_human_input_form_filled_event,
QueueHumanInputFormTimeoutEvent: self._handle_human_input_form_timeout_event,
}
def _dispatch_event(
@@ -747,6 +853,9 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
case QueueWorkflowFailedEvent():
yield from self._handle_workflow_failed_event(event, trace_manager=trace_manager)
break
case QueueWorkflowPausedEvent():
yield from self._handle_workflow_paused_event(event)
break
case QueueStopEvent():
yield from self._handle_stop_event(event, graph_runtime_state=None, trace_manager=trace_manager)
@@ -772,6 +881,11 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
def _save_message(self, *, session: Session, graph_runtime_state: GraphRuntimeState | None = None):
message = self._get_message(session=session)
if message is None:
return
if message.status == MessageStatus.PAUSED:
message.status = MessageStatus.NORMAL
# If there are assistant files, remove markdown image links from answer
answer_text = self._task_state.answer

View File

@@ -5,9 +5,14 @@ from dataclasses import dataclass
from datetime import datetime
from typing import Any, NewType, Union
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity
from core.app.entities.queue_entities import (
QueueAgentLogEvent,
QueueHumanInputFormFilledEvent,
QueueHumanInputFormTimeoutEvent,
QueueIterationCompletedEvent,
QueueIterationNextEvent,
QueueIterationStartEvent,
@@ -19,9 +24,13 @@ from core.app.entities.queue_entities import (
QueueNodeRetryEvent,
QueueNodeStartedEvent,
QueueNodeSucceededEvent,
QueueWorkflowPausedEvent,
)
from core.app.entities.task_entities import (
AgentLogStreamResponse,
HumanInputFormFilledResponse,
HumanInputFormTimeoutResponse,
HumanInputRequiredResponse,
IterationNodeCompletedStreamResponse,
IterationNodeNextStreamResponse,
IterationNodeStartStreamResponse,
@@ -31,7 +40,9 @@ from core.app.entities.task_entities import (
NodeFinishStreamResponse,
NodeRetryStreamResponse,
NodeStartStreamResponse,
StreamResponse,
WorkflowFinishStreamResponse,
WorkflowPauseStreamResponse,
WorkflowStartStreamResponse,
)
from core.file import FILE_MODEL_IDENTITY, File
@@ -40,6 +51,8 @@ from core.tools.entities.tool_entities import ToolProviderType
from core.tools.tool_manager import ToolManager
from core.trigger.trigger_manager import TriggerManager
from core.variables.segments import ArrayFileSegment, FileSegment, Segment
from core.workflow.entities.pause_reason import HumanInputRequired
from core.workflow.entities.workflow_start_reason import WorkflowStartReason
from core.workflow.enums import (
NodeType,
SystemVariableKey,
@@ -51,8 +64,11 @@ from core.workflow.runtime import GraphRuntimeState
from core.workflow.system_variable import SystemVariable
from core.workflow.workflow_entry import WorkflowEntry
from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models import Account, EndUser
from models.human_input import HumanInputForm
from models.workflow import WorkflowRun
from services.variable_truncator import BaseTruncator, DummyVariableTruncator, VariableTruncator
NodeExecutionId = NewType("NodeExecutionId", str)
@@ -191,6 +207,7 @@ class WorkflowResponseConverter:
task_id: str,
workflow_run_id: str,
workflow_id: str,
reason: WorkflowStartReason,
) -> WorkflowStartStreamResponse:
run_id = self._ensure_workflow_run_id(workflow_run_id)
started_at = naive_utc_now()
@@ -204,6 +221,7 @@ class WorkflowResponseConverter:
workflow_id=workflow_id,
inputs=self._workflow_inputs,
created_at=int(started_at.timestamp()),
reason=reason,
),
)
@@ -264,6 +282,160 @@ class WorkflowResponseConverter:
),
)
def workflow_pause_to_stream_response(
self,
*,
event: QueueWorkflowPausedEvent,
task_id: str,
graph_runtime_state: GraphRuntimeState,
) -> list[StreamResponse]:
run_id = self._ensure_workflow_run_id()
started_at = self._workflow_started_at
if started_at is None:
raise ValueError(
"workflow_pause_to_stream_response called before workflow_start_to_stream_response",
)
paused_at = naive_utc_now()
elapsed_time = (paused_at - started_at).total_seconds()
encoded_outputs = self._encode_outputs(event.outputs) or {}
if self._application_generate_entity.invoke_from == InvokeFrom.SERVICE_API:
encoded_outputs = {}
pause_reasons = [reason.model_dump(mode="json") for reason in event.reasons]
human_input_form_ids = [reason.form_id for reason in event.reasons if isinstance(reason, HumanInputRequired)]
expiration_times_by_form_id: dict[str, datetime] = {}
if human_input_form_ids:
stmt = select(HumanInputForm.id, HumanInputForm.expiration_time).where(
HumanInputForm.id.in_(human_input_form_ids)
)
with Session(bind=db.engine) as session:
for form_id, expiration_time in session.execute(stmt):
expiration_times_by_form_id[str(form_id)] = expiration_time
responses: list[StreamResponse] = []
for reason in event.reasons:
if isinstance(reason, HumanInputRequired):
expiration_time = expiration_times_by_form_id.get(reason.form_id)
if expiration_time is None:
raise ValueError(f"HumanInputForm not found for pause reason, form_id={reason.form_id}")
responses.append(
HumanInputRequiredResponse(
task_id=task_id,
workflow_run_id=run_id,
data=HumanInputRequiredResponse.Data(
form_id=reason.form_id,
node_id=reason.node_id,
node_title=reason.node_title,
form_content=reason.form_content,
inputs=reason.inputs,
actions=reason.actions,
display_in_ui=reason.display_in_ui,
form_token=reason.form_token,
resolved_default_values=reason.resolved_default_values,
expiration_time=int(expiration_time.timestamp()),
),
)
)
responses.append(
WorkflowPauseStreamResponse(
task_id=task_id,
workflow_run_id=run_id,
data=WorkflowPauseStreamResponse.Data(
workflow_run_id=run_id,
paused_nodes=list(event.paused_nodes),
outputs=encoded_outputs,
reasons=pause_reasons,
status=WorkflowExecutionStatus.PAUSED,
created_at=int(started_at.timestamp()),
elapsed_time=elapsed_time,
total_tokens=graph_runtime_state.total_tokens,
total_steps=graph_runtime_state.node_run_steps,
),
)
)
return responses
def human_input_form_filled_to_stream_response(
self, *, event: QueueHumanInputFormFilledEvent, task_id: str
) -> HumanInputFormFilledResponse:
run_id = self._ensure_workflow_run_id()
return HumanInputFormFilledResponse(
task_id=task_id,
workflow_run_id=run_id,
data=HumanInputFormFilledResponse.Data(
node_id=event.node_id,
node_title=event.node_title,
rendered_content=event.rendered_content,
action_id=event.action_id,
action_text=event.action_text,
),
)
def human_input_form_timeout_to_stream_response(
self, *, event: QueueHumanInputFormTimeoutEvent, task_id: str
) -> HumanInputFormTimeoutResponse:
run_id = self._ensure_workflow_run_id()
return HumanInputFormTimeoutResponse(
task_id=task_id,
workflow_run_id=run_id,
data=HumanInputFormTimeoutResponse.Data(
node_id=event.node_id,
node_title=event.node_title,
expiration_time=int(event.expiration_time.timestamp()),
),
)
@classmethod
def workflow_run_result_to_finish_response(
cls,
*,
task_id: str,
workflow_run: WorkflowRun,
creator_user: Account | EndUser,
) -> WorkflowFinishStreamResponse:
run_id = workflow_run.id
elapsed_time = workflow_run.elapsed_time
encoded_outputs = workflow_run.outputs_dict
finished_at = workflow_run.finished_at
assert finished_at is not None
created_by: Mapping[str, object]
user = creator_user
if isinstance(user, Account):
created_by = {
"id": user.id,
"name": user.name,
"email": user.email,
}
else:
created_by = {
"id": user.id,
"user": user.session_id,
}
return WorkflowFinishStreamResponse(
task_id=task_id,
workflow_run_id=run_id,
data=WorkflowFinishStreamResponse.Data(
id=run_id,
workflow_id=workflow_run.workflow_id,
status=workflow_run.status,
outputs=encoded_outputs,
error=workflow_run.error,
elapsed_time=elapsed_time,
total_tokens=workflow_run.total_tokens,
total_steps=workflow_run.total_steps,
created_by=created_by,
created_at=int(workflow_run.created_at.timestamp()),
finished_at=int(finished_at.timestamp()),
files=cls.fetch_files_from_node_outputs(encoded_outputs),
exceptions_count=workflow_run.exceptions_count,
),
)
def workflow_node_start_to_stream_response(
self,
*,
@@ -592,7 +764,8 @@ class WorkflowResponseConverter:
),
)
def fetch_files_from_node_outputs(self, outputs_dict: Mapping[str, Any] | None) -> Sequence[Mapping[str, Any]]:
@classmethod
def fetch_files_from_node_outputs(cls, outputs_dict: Mapping[str, Any] | None) -> Sequence[Mapping[str, Any]]:
"""
Fetch files from node outputs
:param outputs_dict: node outputs dict
@@ -601,7 +774,7 @@ class WorkflowResponseConverter:
if not outputs_dict:
return []
files = [self._fetch_files_from_variable_value(output_value) for output_value in outputs_dict.values()]
files = [cls._fetch_files_from_variable_value(output_value) for output_value in outputs_dict.values()]
# Remove None
files = [file for file in files if file]
# Flatten list

View File

@@ -1,6 +1,6 @@
import json
import logging
from collections.abc import Generator
from collections.abc import Callable, Generator, Mapping
from typing import Union, cast
from sqlalchemy import select
@@ -10,12 +10,14 @@ from core.app.app_config.entities import EasyUIBasedAppConfig, EasyUIBasedAppMod
from core.app.apps.base_app_generator import BaseAppGenerator
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.apps.exc import GenerateTaskStoppedError
from core.app.apps.streaming_utils import stream_topic_events
from core.app.entities.app_invoke_entities import (
AdvancedChatAppGenerateEntity,
AgentChatAppGenerateEntity,
AppGenerateEntity,
ChatAppGenerateEntity,
CompletionAppGenerateEntity,
ConversationAppGenerateEntity,
InvokeFrom,
)
from core.app.entities.task_entities import (
@@ -27,6 +29,8 @@ from core.app.entities.task_entities import (
from core.app.task_pipeline.easy_ui_based_generate_task_pipeline import EasyUIBasedGenerateTaskPipeline
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from extensions.ext_database import db
from extensions.ext_redis import get_pubsub_broadcast_channel
from libs.broadcast_channel.channel import Topic
from libs.datetime_utils import naive_utc_now
from models import Account
from models.enums import CreatorUserRole
@@ -156,6 +160,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
query = application_generate_entity.query or "New conversation"
conversation_name = (query[:20] + "") if len(query) > 20 else query
created_new_conversation = conversation is None
try:
if not conversation:
conversation = Conversation(
@@ -232,6 +237,10 @@ class MessageBasedAppGenerator(BaseAppGenerator):
db.session.add_all(message_files)
db.session.commit()
if isinstance(application_generate_entity, ConversationAppGenerateEntity):
application_generate_entity.conversation_id = conversation.id
application_generate_entity.is_new_conversation = created_new_conversation
return conversation, message
except Exception:
db.session.rollback()
@@ -284,3 +293,29 @@ class MessageBasedAppGenerator(BaseAppGenerator):
raise MessageNotExistsError("Message not exists")
return message
@staticmethod
def _make_channel_key(app_mode: AppMode, workflow_run_id: str):
return f"channel:{app_mode}:{workflow_run_id}"
@classmethod
def get_response_topic(cls, app_mode: AppMode, workflow_run_id: str) -> Topic:
key = cls._make_channel_key(app_mode, workflow_run_id)
channel = get_pubsub_broadcast_channel()
topic = channel.topic(key)
return topic
@classmethod
def retrieve_events(
cls,
app_mode: AppMode,
workflow_run_id: str,
idle_timeout=300,
on_subscribe: Callable[[], None] | None = None,
) -> Generator[Mapping | str, None, None]:
topic = cls.get_response_topic(app_mode, workflow_run_id)
return stream_topic_events(
topic=topic,
idle_timeout=idle_timeout,
on_subscribe=on_subscribe,
)

View File

@@ -0,0 +1,36 @@
from collections.abc import Callable, Generator, Mapping
from core.app.apps.streaming_utils import stream_topic_events
from extensions.ext_redis import get_pubsub_broadcast_channel
from libs.broadcast_channel.channel import Topic
from models.model import AppMode
class MessageGenerator:
@staticmethod
def _make_channel_key(app_mode: AppMode, workflow_run_id: str):
return f"channel:{app_mode}:{str(workflow_run_id)}"
@classmethod
def get_response_topic(cls, app_mode: AppMode, workflow_run_id: str) -> Topic:
key = cls._make_channel_key(app_mode, workflow_run_id)
channel = get_pubsub_broadcast_channel()
topic = channel.topic(key)
return topic
@classmethod
def retrieve_events(
cls,
app_mode: AppMode,
workflow_run_id: str,
idle_timeout=300,
ping_interval: float = 10.0,
on_subscribe: Callable[[], None] | None = None,
) -> Generator[Mapping | str, None, None]:
topic = cls.get_response_topic(app_mode, workflow_run_id)
return stream_topic_events(
topic=topic,
idle_timeout=idle_timeout,
ping_interval=ping_interval,
on_subscribe=on_subscribe,
)

View File

@@ -0,0 +1,70 @@
from __future__ import annotations
import json
import time
from collections.abc import Callable, Generator, Iterable, Mapping
from typing import Any
from core.app.entities.task_entities import StreamEvent
from libs.broadcast_channel.channel import Topic
from libs.broadcast_channel.exc import SubscriptionClosedError
def stream_topic_events(
*,
topic: Topic,
idle_timeout: float,
ping_interval: float | None = None,
on_subscribe: Callable[[], None] | None = None,
terminal_events: Iterable[str | StreamEvent] | None = None,
) -> Generator[Mapping[str, Any] | str, None, None]:
# send a PING event immediately to prevent the connection staying in pending state for a long time.
#
# This simplify the debugging process as the DevTools in Chrome does not
# provide complete curl command for pending connections.
yield StreamEvent.PING.value
terminal_values = _normalize_terminal_events(terminal_events)
last_msg_time = time.time()
last_ping_time = last_msg_time
with topic.subscribe() as sub:
# on_subscribe fires only after the Redis subscription is active.
# This is used to gate task start and reduce pub/sub race for the first event.
if on_subscribe is not None:
on_subscribe()
while True:
try:
msg = sub.receive(timeout=0.1)
except SubscriptionClosedError:
return
if msg is None:
current_time = time.time()
if current_time - last_msg_time > idle_timeout:
return
if ping_interval is not None and current_time - last_ping_time >= ping_interval:
yield StreamEvent.PING.value
last_ping_time = current_time
continue
last_msg_time = time.time()
last_ping_time = last_msg_time
event = json.loads(msg)
yield event
if not isinstance(event, dict):
continue
event_type = event.get("event")
if event_type in terminal_values:
return
def _normalize_terminal_events(terminal_events: Iterable[str | StreamEvent] | None) -> set[str]:
if not terminal_events:
return {StreamEvent.WORKFLOW_FINISHED.value, StreamEvent.WORKFLOW_PAUSED.value}
values: set[str] = set()
for item in terminal_events:
if isinstance(item, StreamEvent):
values.add(item.value)
else:
values.add(str(item))
return values

View File

@@ -25,6 +25,7 @@ from core.app.apps.workflow.generate_response_converter import WorkflowAppGenera
from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse
from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig, PauseStatePersistenceLayer
from core.db.session_factory import session_factory
from core.helper.trace_id_helper import extract_external_trace_id_from_args
from core.model_runtime.errors.invoke import InvokeAuthorizationError
@@ -34,12 +35,15 @@ from core.workflow.graph_engine.layers.base import GraphEngineLayer
from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from core.workflow.runtime import GraphRuntimeState
from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader
from extensions.ext_database import db
from factories import file_factory
from libs.flask_utils import preserve_flask_contexts
from models import Account, App, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom
from models.account import Account
from models.enums import WorkflowRunTriggeredFrom
from models.model import App, EndUser
from models.workflow import Workflow, WorkflowNodeExecutionTriggeredFrom
from services.workflow_draft_variable_service import DraftVarLoader, WorkflowDraftVariableService
if TYPE_CHECKING:
@@ -66,9 +70,11 @@ class WorkflowAppGenerator(BaseAppGenerator):
invoke_from: InvokeFrom,
streaming: Literal[True],
call_depth: int,
workflow_run_id: str | uuid.UUID | None = None,
triggered_from: WorkflowRunTriggeredFrom | None = None,
root_node_id: str | None = None,
graph_engine_layers: Sequence[GraphEngineLayer] = (),
pause_state_config: PauseStateLayerConfig | None = None,
) -> Generator[Mapping[str, Any] | str, None, None]: ...
@overload
@@ -82,9 +88,11 @@ class WorkflowAppGenerator(BaseAppGenerator):
invoke_from: InvokeFrom,
streaming: Literal[False],
call_depth: int,
workflow_run_id: str | uuid.UUID | None = None,
triggered_from: WorkflowRunTriggeredFrom | None = None,
root_node_id: str | None = None,
graph_engine_layers: Sequence[GraphEngineLayer] = (),
pause_state_config: PauseStateLayerConfig | None = None,
) -> Mapping[str, Any]: ...
@overload
@@ -98,9 +106,11 @@ class WorkflowAppGenerator(BaseAppGenerator):
invoke_from: InvokeFrom,
streaming: bool,
call_depth: int,
workflow_run_id: str | uuid.UUID | None = None,
triggered_from: WorkflowRunTriggeredFrom | None = None,
root_node_id: str | None = None,
graph_engine_layers: Sequence[GraphEngineLayer] = (),
pause_state_config: PauseStateLayerConfig | None = None,
) -> Union[Mapping[str, Any], Generator[Mapping[str, Any] | str, None, None]]: ...
def generate(
@@ -113,9 +123,11 @@ class WorkflowAppGenerator(BaseAppGenerator):
invoke_from: InvokeFrom,
streaming: bool = True,
call_depth: int = 0,
workflow_run_id: str | uuid.UUID | None = None,
triggered_from: WorkflowRunTriggeredFrom | None = None,
root_node_id: str | None = None,
graph_engine_layers: Sequence[GraphEngineLayer] = (),
pause_state_config: PauseStateLayerConfig | None = None,
) -> Union[Mapping[str, Any], Generator[Mapping[str, Any] | str, None, None]]:
files: Sequence[Mapping[str, Any]] = args.get("files") or []
@@ -150,7 +162,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
extras = {
**extract_external_trace_id_from_args(args),
}
workflow_run_id = str(uuid.uuid4())
workflow_run_id = str(workflow_run_id or uuid.uuid4())
# FIXME (Yeuoly): we need to remove the SKIP_PREPARE_USER_INPUTS_KEY from the args
# trigger shouldn't prepare user inputs
if self._should_prepare_user_inputs(args):
@@ -216,13 +228,40 @@ class WorkflowAppGenerator(BaseAppGenerator):
streaming=streaming,
root_node_id=root_node_id,
graph_engine_layers=graph_engine_layers,
pause_state_config=pause_state_config,
)
def resume(self, *, workflow_run_id: str) -> None:
def resume(
self,
*,
app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
application_generate_entity: WorkflowAppGenerateEntity,
graph_runtime_state: GraphRuntimeState,
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
graph_engine_layers: Sequence[GraphEngineLayer] = (),
pause_state_config: PauseStateLayerConfig | None = None,
variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER,
) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]:
"""
@TBD
Resume a paused workflow execution using the persisted runtime state.
"""
pass
return self._generate(
app_model=app_model,
workflow=workflow,
user=user,
application_generate_entity=application_generate_entity,
invoke_from=application_generate_entity.invoke_from,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
streaming=application_generate_entity.stream,
variable_loader=variable_loader,
graph_engine_layers=graph_engine_layers,
graph_runtime_state=graph_runtime_state,
pause_state_config=pause_state_config,
)
def _generate(
self,
@@ -238,6 +277,8 @@ class WorkflowAppGenerator(BaseAppGenerator):
variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER,
root_node_id: str | None = None,
graph_engine_layers: Sequence[GraphEngineLayer] = (),
graph_runtime_state: GraphRuntimeState | None = None,
pause_state_config: PauseStateLayerConfig | None = None,
) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]:
"""
Generate App response.
@@ -251,6 +292,8 @@ class WorkflowAppGenerator(BaseAppGenerator):
:param workflow_node_execution_repository: repository for workflow node execution
:param streaming: is stream
"""
graph_layers: list[GraphEngineLayer] = list(graph_engine_layers)
# init queue manager
queue_manager = WorkflowAppQueueManager(
task_id=application_generate_entity.task_id,
@@ -259,6 +302,15 @@ class WorkflowAppGenerator(BaseAppGenerator):
app_mode=app_model.mode,
)
if pause_state_config is not None:
graph_layers.append(
PauseStatePersistenceLayer(
session_factory=pause_state_config.session_factory,
generate_entity=application_generate_entity,
state_owner_user_id=pause_state_config.state_owner_user_id,
)
)
# new thread with request context and contextvars
context = contextvars.copy_context()
@@ -276,7 +328,8 @@ class WorkflowAppGenerator(BaseAppGenerator):
"root_node_id": root_node_id,
"workflow_execution_repository": workflow_execution_repository,
"workflow_node_execution_repository": workflow_node_execution_repository,
"graph_engine_layers": graph_engine_layers,
"graph_engine_layers": tuple(graph_layers),
"graph_runtime_state": graph_runtime_state,
},
)
@@ -378,6 +431,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
workflow_node_execution_repository=workflow_node_execution_repository,
streaming=streaming,
variable_loader=var_loader,
pause_state_config=None,
)
def single_loop_generate(
@@ -459,6 +513,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
workflow_node_execution_repository=workflow_node_execution_repository,
streaming=streaming,
variable_loader=var_loader,
pause_state_config=None,
)
def _generate_worker(
@@ -472,6 +527,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
root_node_id: str | None = None,
graph_engine_layers: Sequence[GraphEngineLayer] = (),
graph_runtime_state: GraphRuntimeState | None = None,
) -> None:
"""
Generate worker in a new thread.
@@ -517,6 +573,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
workflow_node_execution_repository=workflow_node_execution_repository,
root_node_id=root_node_id,
graph_engine_layers=graph_engine_layers,
graph_runtime_state=graph_runtime_state,
)
try:

View File

@@ -42,6 +42,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
graph_engine_layers: Sequence[GraphEngineLayer] = (),
graph_runtime_state: GraphRuntimeState | None = None,
):
super().__init__(
queue_manager=queue_manager,
@@ -55,6 +56,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
self._root_node_id = root_node_id
self._workflow_execution_repository = workflow_execution_repository
self._workflow_node_execution_repository = workflow_node_execution_repository
self._resume_graph_runtime_state = graph_runtime_state
@trace_span(WorkflowAppRunnerHandler)
def run(self):
@@ -63,23 +65,28 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
"""
app_config = self.application_generate_entity.app_config
app_config = cast(WorkflowAppConfig, app_config)
system_inputs = SystemVariable(
files=self.application_generate_entity.files,
user_id=self._sys_user_id,
app_id=app_config.app_id,
timestamp=int(naive_utc_now().timestamp()),
workflow_id=app_config.workflow_id,
workflow_execution_id=self.application_generate_entity.workflow_execution_id,
)
invoke_from = self.application_generate_entity.invoke_from
# if only single iteration or single loop run is requested
if self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run:
invoke_from = InvokeFrom.DEBUGGER
user_from = self._resolve_user_from(invoke_from)
if self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run:
resume_state = self._resume_graph_runtime_state
if resume_state is not None:
graph_runtime_state = resume_state
variable_pool = graph_runtime_state.variable_pool
graph = self._init_graph(
graph_config=self._workflow.graph_dict,
graph_runtime_state=graph_runtime_state,
workflow_id=self._workflow.id,
tenant_id=self._workflow.tenant_id,
user_id=self.application_generate_entity.user_id,
user_from=user_from,
invoke_from=invoke_from,
root_node_id=self._root_node_id,
)
elif self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run:
graph, variable_pool, graph_runtime_state = self._prepare_single_node_execution(
workflow=self._workflow,
single_iteration_run=self.application_generate_entity.single_iteration_run,
@@ -89,7 +96,14 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
inputs = self.application_generate_entity.inputs
# Create a variable pool.
system_inputs = SystemVariable(
files=self.application_generate_entity.files,
user_id=self._sys_user_id,
app_id=app_config.app_id,
timestamp=int(naive_utc_now().timestamp()),
workflow_id=app_config.workflow_id,
workflow_execution_id=self.application_generate_entity.workflow_execution_id,
)
variable_pool = VariablePool(
system_variables=system_inputs,
user_inputs=inputs,
@@ -98,8 +112,6 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
)
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
# init graph
graph = self._init_graph(
graph_config=self._workflow.graph_dict,
graph_runtime_state=graph_runtime_state,

View File

@@ -0,0 +1,7 @@
from libs.exception import BaseHTTPException
class WorkflowPausedInBlockingModeError(BaseHTTPException):
error_code = "workflow_paused_in_blocking_mode"
description = "Workflow execution paused for human input; blocking response mode is not supported."
code = 400

View File

@@ -16,6 +16,8 @@ from core.app.entities.queue_entities import (
MessageQueueMessage,
QueueAgentLogEvent,
QueueErrorEvent,
QueueHumanInputFormFilledEvent,
QueueHumanInputFormTimeoutEvent,
QueueIterationCompletedEvent,
QueueIterationNextEvent,
QueueIterationStartEvent,
@@ -32,6 +34,7 @@ from core.app.entities.queue_entities import (
QueueTextChunkEvent,
QueueWorkflowFailedEvent,
QueueWorkflowPartialSuccessEvent,
QueueWorkflowPausedEvent,
QueueWorkflowStartedEvent,
QueueWorkflowSucceededEvent,
WorkflowQueueMessage,
@@ -46,11 +49,13 @@ from core.app.entities.task_entities import (
WorkflowAppBlockingResponse,
WorkflowAppStreamResponse,
WorkflowFinishStreamResponse,
WorkflowPauseStreamResponse,
WorkflowStartStreamResponse,
)
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
from core.ops.ops_trace_manager import TraceQueueManager
from core.workflow.entities.workflow_start_reason import WorkflowStartReason
from core.workflow.enums import WorkflowExecutionStatus
from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory
from core.workflow.runtime import GraphRuntimeState
@@ -132,6 +137,25 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
for stream_response in generator:
if isinstance(stream_response, ErrorStreamResponse):
raise stream_response.err
elif isinstance(stream_response, WorkflowPauseStreamResponse):
response = WorkflowAppBlockingResponse(
task_id=self._application_generate_entity.task_id,
workflow_run_id=stream_response.data.workflow_run_id,
data=WorkflowAppBlockingResponse.Data(
id=stream_response.data.workflow_run_id,
workflow_id=self._workflow.id,
status=stream_response.data.status,
outputs=stream_response.data.outputs or {},
error=None,
elapsed_time=stream_response.data.elapsed_time,
total_tokens=stream_response.data.total_tokens,
total_steps=stream_response.data.total_steps,
created_at=stream_response.data.created_at,
finished_at=None,
),
)
return response
elif isinstance(stream_response, WorkflowFinishStreamResponse):
response = WorkflowAppBlockingResponse(
task_id=self._application_generate_entity.task_id,
@@ -146,7 +170,7 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
total_tokens=stream_response.data.total_tokens,
total_steps=stream_response.data.total_steps,
created_at=int(stream_response.data.created_at),
finished_at=int(stream_response.data.finished_at),
finished_at=int(stream_response.data.finished_at) if stream_response.data.finished_at else None,
),
)
@@ -259,13 +283,15 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
run_id = self._extract_workflow_run_id(runtime_state)
self._workflow_execution_id = run_id
with self._database_session() as session:
self._save_workflow_app_log(session=session, workflow_run_id=self._workflow_execution_id)
if event.reason == WorkflowStartReason.INITIAL:
with self._database_session() as session:
self._save_workflow_app_log(session=session, workflow_run_id=self._workflow_execution_id)
start_resp = self._workflow_response_converter.workflow_start_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_run_id=run_id,
workflow_id=self._workflow.id,
reason=event.reason,
)
yield start_resp
@@ -440,6 +466,21 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
)
yield workflow_finish_resp
def _handle_workflow_paused_event(
self,
event: QueueWorkflowPausedEvent,
**kwargs,
) -> Generator[StreamResponse, None, None]:
"""Handle workflow paused events."""
self._ensure_workflow_initialized()
validated_state = self._ensure_graph_runtime_initialized()
responses = self._workflow_response_converter.workflow_pause_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
graph_runtime_state=validated_state,
)
yield from responses
def _handle_workflow_failed_and_stop_events(
self,
event: Union[QueueWorkflowFailedEvent, QueueStopEvent],
@@ -495,6 +536,22 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
task_id=self._application_generate_entity.task_id, event=event
)
def _handle_human_input_form_filled_event(
self, event: QueueHumanInputFormFilledEvent, **kwargs
) -> Generator[StreamResponse, None, None]:
"""Handle human input form filled events."""
yield self._workflow_response_converter.human_input_form_filled_to_stream_response(
event=event, task_id=self._application_generate_entity.task_id
)
def _handle_human_input_form_timeout_event(
self, event: QueueHumanInputFormTimeoutEvent, **kwargs
) -> Generator[StreamResponse, None, None]:
"""Handle human input form timeout events."""
yield self._workflow_response_converter.human_input_form_timeout_to_stream_response(
event=event, task_id=self._application_generate_entity.task_id
)
def _get_event_handlers(self) -> dict[type, Callable]:
"""Get mapping of event types to their handlers using fluent pattern."""
return {
@@ -506,6 +563,7 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
QueueWorkflowStartedEvent: self._handle_workflow_started_event,
QueueWorkflowSucceededEvent: self._handle_workflow_succeeded_event,
QueueWorkflowPartialSuccessEvent: self._handle_workflow_partial_success_event,
QueueWorkflowPausedEvent: self._handle_workflow_paused_event,
# Node events
QueueNodeRetryEvent: self._handle_node_retry_event,
QueueNodeStartedEvent: self._handle_node_started_event,
@@ -520,6 +578,8 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
QueueLoopCompletedEvent: self._handle_loop_completed_event,
# Agent events
QueueAgentLogEvent: self._handle_agent_log_event,
QueueHumanInputFormFilledEvent: self._handle_human_input_form_filled_event,
QueueHumanInputFormTimeoutEvent: self._handle_human_input_form_timeout_event,
}
def _dispatch_event(
@@ -602,6 +662,9 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
case QueueWorkflowFailedEvent():
yield from self._handle_workflow_failed_and_stop_events(event)
break
case QueueWorkflowPausedEvent():
yield from self._handle_workflow_paused_event(event)
break
case QueueStopEvent():
yield from self._handle_workflow_failed_and_stop_events(event)

View File

@@ -1,3 +1,4 @@
import logging
import time
from collections.abc import Mapping, Sequence
from typing import Any, cast
@@ -7,6 +8,8 @@ from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.queue_entities import (
AppQueueEvent,
QueueAgentLogEvent,
QueueHumanInputFormFilledEvent,
QueueHumanInputFormTimeoutEvent,
QueueIterationCompletedEvent,
QueueIterationNextEvent,
QueueIterationStartEvent,
@@ -22,22 +25,27 @@ from core.app.entities.queue_entities import (
QueueTextChunkEvent,
QueueWorkflowFailedEvent,
QueueWorkflowPartialSuccessEvent,
QueueWorkflowPausedEvent,
QueueWorkflowStartedEvent,
QueueWorkflowSucceededEvent,
)
from core.app.workflow.node_factory import DifyNodeFactory
from core.workflow.entities import GraphInitParams
from core.workflow.entities.pause_reason import HumanInputRequired
from core.workflow.graph import Graph
from core.workflow.graph_engine.layers.base import GraphEngineLayer
from core.workflow.graph_events import (
GraphEngineEvent,
GraphRunFailedEvent,
GraphRunPartialSucceededEvent,
GraphRunPausedEvent,
GraphRunStartedEvent,
GraphRunSucceededEvent,
NodeRunAgentLogEvent,
NodeRunExceptionEvent,
NodeRunFailedEvent,
NodeRunHumanInputFormFilledEvent,
NodeRunHumanInputFormTimeoutEvent,
NodeRunIterationFailedEvent,
NodeRunIterationNextEvent,
NodeRunIterationStartedEvent,
@@ -61,6 +69,9 @@ from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader,
from core.workflow.workflow_entry import WorkflowEntry
from models.enums import UserFrom
from models.workflow import Workflow
from tasks.mail_human_input_delivery_task import dispatch_human_input_email_task
logger = logging.getLogger(__name__)
class WorkflowBasedAppRunner:
@@ -327,7 +338,7 @@ class WorkflowBasedAppRunner:
:param event: event
"""
if isinstance(event, GraphRunStartedEvent):
self._publish_event(QueueWorkflowStartedEvent())
self._publish_event(QueueWorkflowStartedEvent(reason=event.reason))
elif isinstance(event, GraphRunSucceededEvent):
self._publish_event(QueueWorkflowSucceededEvent(outputs=event.outputs))
elif isinstance(event, GraphRunPartialSucceededEvent):
@@ -338,6 +349,38 @@ class WorkflowBasedAppRunner:
self._publish_event(QueueWorkflowFailedEvent(error=event.error, exceptions_count=event.exceptions_count))
elif isinstance(event, GraphRunAbortedEvent):
self._publish_event(QueueWorkflowFailedEvent(error=event.reason or "Unknown error", exceptions_count=0))
elif isinstance(event, GraphRunPausedEvent):
runtime_state = workflow_entry.graph_engine.graph_runtime_state
paused_nodes = runtime_state.get_paused_nodes()
self._enqueue_human_input_notifications(event.reasons)
self._publish_event(
QueueWorkflowPausedEvent(
reasons=event.reasons,
outputs=event.outputs,
paused_nodes=paused_nodes,
)
)
elif isinstance(event, NodeRunHumanInputFormFilledEvent):
self._publish_event(
QueueHumanInputFormFilledEvent(
node_execution_id=event.id,
node_id=event.node_id,
node_type=event.node_type,
node_title=event.node_title,
rendered_content=event.rendered_content,
action_id=event.action_id,
action_text=event.action_text,
)
)
elif isinstance(event, NodeRunHumanInputFormTimeoutEvent):
self._publish_event(
QueueHumanInputFormTimeoutEvent(
node_id=event.node_id,
node_type=event.node_type,
node_title=event.node_title,
expiration_time=event.expiration_time,
)
)
elif isinstance(event, NodeRunRetryEvent):
node_run_result = event.node_run_result
inputs = node_run_result.inputs
@@ -544,5 +587,19 @@ class WorkflowBasedAppRunner:
)
)
def _enqueue_human_input_notifications(self, reasons: Sequence[object]) -> None:
for reason in reasons:
if not isinstance(reason, HumanInputRequired):
continue
if not reason.form_id:
continue
try:
dispatch_human_input_email_task.apply_async(
kwargs={"form_id": reason.form_id, "node_title": reason.node_title},
queue="mail",
)
except Exception: # pragma: no cover - defensive logging
logger.exception("Failed to enqueue human input email task for form %s", reason.form_id)
def _publish_event(self, event: AppQueueEvent):
self._queue_manager.publish(event, PublishFrom.APPLICATION_MANAGER)

View File

@@ -132,7 +132,7 @@ class AppGenerateEntity(BaseModel):
extras: dict[str, Any] = Field(default_factory=dict)
# tracing instance
trace_manager: Optional["TraceQueueManager"] = None
trace_manager: Optional["TraceQueueManager"] = Field(default=None, exclude=True, repr=False)
class EasyUIBasedAppGenerateEntity(AppGenerateEntity):
@@ -156,6 +156,7 @@ class ConversationAppGenerateEntity(AppGenerateEntity):
"""
conversation_id: str | None = None
is_new_conversation: bool = False
parent_message_id: str | None = Field(
default=None,
description=(

View File

@@ -8,6 +8,8 @@ from pydantic import BaseModel, ConfigDict, Field
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.workflow.entities import AgentNodeStrategyInit
from core.workflow.entities.pause_reason import PauseReason
from core.workflow.entities.workflow_start_reason import WorkflowStartReason
from core.workflow.enums import WorkflowNodeExecutionMetadataKey
from core.workflow.nodes import NodeType
@@ -46,6 +48,9 @@ class QueueEvent(StrEnum):
PING = "ping"
STOP = "stop"
RETRY = "retry"
PAUSE = "pause"
HUMAN_INPUT_FORM_FILLED = "human_input_form_filled"
HUMAN_INPUT_FORM_TIMEOUT = "human_input_form_timeout"
class AppQueueEvent(BaseModel):
@@ -261,6 +266,8 @@ class QueueWorkflowStartedEvent(AppQueueEvent):
"""QueueWorkflowStartedEvent entity."""
event: QueueEvent = QueueEvent.WORKFLOW_STARTED
# Always present; mirrors GraphRunStartedEvent.reason for downstream consumers.
reason: WorkflowStartReason = WorkflowStartReason.INITIAL
class QueueWorkflowSucceededEvent(AppQueueEvent):
@@ -484,6 +491,35 @@ class QueueStopEvent(AppQueueEvent):
return reason_mapping.get(self.stopped_by, "Stopped by unknown reason.")
class QueueHumanInputFormFilledEvent(AppQueueEvent):
"""
QueueHumanInputFormFilledEvent entity
"""
event: QueueEvent = QueueEvent.HUMAN_INPUT_FORM_FILLED
node_execution_id: str
node_id: str
node_type: NodeType
node_title: str
rendered_content: str
action_id: str
action_text: str
class QueueHumanInputFormTimeoutEvent(AppQueueEvent):
"""
QueueHumanInputFormTimeoutEvent entity
"""
event: QueueEvent = QueueEvent.HUMAN_INPUT_FORM_TIMEOUT
node_id: str
node_type: NodeType
node_title: str
expiration_time: datetime
class QueueMessage(BaseModel):
"""
QueueMessage abstract entity
@@ -509,3 +545,14 @@ class WorkflowQueueMessage(QueueMessage):
"""
pass
class QueueWorkflowPausedEvent(AppQueueEvent):
"""
QueueWorkflowPausedEvent entity
"""
event: QueueEvent = QueueEvent.PAUSE
reasons: Sequence[PauseReason] = Field(default_factory=list)
outputs: Mapping[str, object] = Field(default_factory=dict)
paused_nodes: Sequence[str] = Field(default_factory=list)

View File

@@ -7,7 +7,9 @@ from pydantic import BaseModel, ConfigDict, Field
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.workflow.entities import AgentNodeStrategyInit
from core.workflow.entities.workflow_start_reason import WorkflowStartReason
from core.workflow.enums import WorkflowExecutionStatus, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from core.workflow.nodes.human_input.entities import FormInput, UserAction
class AnnotationReplyAccount(BaseModel):
@@ -69,6 +71,7 @@ class StreamEvent(StrEnum):
AGENT_THOUGHT = "agent_thought"
AGENT_MESSAGE = "agent_message"
WORKFLOW_STARTED = "workflow_started"
WORKFLOW_PAUSED = "workflow_paused"
WORKFLOW_FINISHED = "workflow_finished"
NODE_STARTED = "node_started"
NODE_FINISHED = "node_finished"
@@ -82,6 +85,9 @@ class StreamEvent(StrEnum):
TEXT_CHUNK = "text_chunk"
TEXT_REPLACE = "text_replace"
AGENT_LOG = "agent_log"
HUMAN_INPUT_REQUIRED = "human_input_required"
HUMAN_INPUT_FORM_FILLED = "human_input_form_filled"
HUMAN_INPUT_FORM_TIMEOUT = "human_input_form_timeout"
class StreamResponse(BaseModel):
@@ -205,6 +211,8 @@ class WorkflowStartStreamResponse(StreamResponse):
workflow_id: str
inputs: Mapping[str, Any]
created_at: int
# Always present; mirrors QueueWorkflowStartedEvent.reason for SSE clients.
reason: WorkflowStartReason = WorkflowStartReason.INITIAL
event: StreamEvent = StreamEvent.WORKFLOW_STARTED
workflow_run_id: str
@@ -231,7 +239,7 @@ class WorkflowFinishStreamResponse(StreamResponse):
total_steps: int
created_by: Mapping[str, object] = Field(default_factory=dict)
created_at: int
finished_at: int
finished_at: int | None
exceptions_count: int | None = 0
files: Sequence[Mapping[str, Any]] | None = []
@@ -240,6 +248,85 @@ class WorkflowFinishStreamResponse(StreamResponse):
data: Data
class WorkflowPauseStreamResponse(StreamResponse):
"""
WorkflowPauseStreamResponse entity
"""
class Data(BaseModel):
"""
Data entity
"""
workflow_run_id: str
paused_nodes: Sequence[str] = Field(default_factory=list)
outputs: Mapping[str, Any] = Field(default_factory=dict)
reasons: Sequence[Mapping[str, Any]] = Field(default_factory=list)
status: WorkflowExecutionStatus
created_at: int
elapsed_time: float
total_tokens: int
total_steps: int
event: StreamEvent = StreamEvent.WORKFLOW_PAUSED
workflow_run_id: str
data: Data
class HumanInputRequiredResponse(StreamResponse):
class Data(BaseModel):
"""
Data entity
"""
form_id: str
node_id: str
node_title: str
form_content: str
inputs: Sequence[FormInput] = Field(default_factory=list)
actions: Sequence[UserAction] = Field(default_factory=list)
display_in_ui: bool = False
form_token: str | None = None
resolved_default_values: Mapping[str, Any] = Field(default_factory=dict)
expiration_time: int = Field(..., description="Unix timestamp in seconds")
event: StreamEvent = StreamEvent.HUMAN_INPUT_REQUIRED
workflow_run_id: str
data: Data
class HumanInputFormFilledResponse(StreamResponse):
class Data(BaseModel):
"""
Data entity
"""
node_id: str
node_title: str
rendered_content: str
action_id: str
action_text: str
event: StreamEvent = StreamEvent.HUMAN_INPUT_FORM_FILLED
workflow_run_id: str
data: Data
class HumanInputFormTimeoutResponse(StreamResponse):
class Data(BaseModel):
"""
Data entity
"""
node_id: str
node_title: str
expiration_time: int
event: StreamEvent = StreamEvent.HUMAN_INPUT_FORM_TIMEOUT
workflow_run_id: str
data: Data
class NodeStartStreamResponse(StreamResponse):
"""
NodeStartStreamResponse entity
@@ -726,7 +813,7 @@ class WorkflowAppBlockingResponse(AppBlockingResponse):
total_tokens: int
total_steps: int
created_at: int
finished_at: int
finished_at: int | None
workflow_run_id: str
data: Data

View File

@@ -1,3 +1,4 @@
import contextlib
import logging
import time
import uuid
@@ -103,6 +104,14 @@ class RateLimit:
)
@contextlib.contextmanager
def rate_limit_context(rate_limit: RateLimit, request_id: str | None):
request_id = rate_limit.enter(request_id)
yield
if request_id is not None:
rate_limit.exit(request_id)
class RateLimitGenerator:
def __init__(self, rate_limit: RateLimit, generator: Generator[str, None, None], request_id: str):
self.rate_limit = rate_limit

View File

@@ -1,3 +1,4 @@
from dataclasses import dataclass
from typing import Annotated, Literal, Self, TypeAlias
from pydantic import BaseModel, Field
@@ -52,6 +53,14 @@ class WorkflowResumptionContext(BaseModel):
return self.generate_entity.entity
@dataclass(frozen=True)
class PauseStateLayerConfig:
"""Configuration container for instantiating pause persistence layers."""
session_factory: Engine | sessionmaker[Session]
state_owner_user_id: str
class PauseStatePersistenceLayer(GraphEngineLayer):
def __init__(
self,

View File

@@ -82,10 +82,11 @@ class MessageCycleManager:
if isinstance(self._application_generate_entity, CompletionAppGenerateEntity):
return None
is_first_message = self._application_generate_entity.conversation_id is None
is_first_message = self._application_generate_entity.is_new_conversation
extras = self._application_generate_entity.extras
auto_generate_conversation_name = extras.get("auto_generate_conversation_name", True)
thread: Thread | None = None
if auto_generate_conversation_name and is_first_message:
# start generate thread
# time.sleep not block other logic
@@ -101,9 +102,10 @@ class MessageCycleManager:
thread.daemon = True
thread.start()
return thread
if is_first_message:
self._application_generate_entity.is_new_conversation = False
return None
return thread
def _generate_conversation_name_worker(self, flask_app: Flask, conversation_id: str, query: str):
with flask_app.app_context():

View File

@@ -0,0 +1,54 @@
from __future__ import annotations
from collections.abc import Mapping, Sequence
from typing import Any, TypeAlias
from pydantic import BaseModel, ConfigDict, Field
from core.workflow.nodes.human_input.entities import FormInput, UserAction
from models.execution_extra_content import ExecutionContentType
class HumanInputFormDefinition(BaseModel):
model_config = ConfigDict(frozen=True)
form_id: str
node_id: str
node_title: str
form_content: str
inputs: Sequence[FormInput] = Field(default_factory=list)
actions: Sequence[UserAction] = Field(default_factory=list)
display_in_ui: bool = False
form_token: str | None = None
resolved_default_values: Mapping[str, Any] = Field(default_factory=dict)
expiration_time: int
class HumanInputFormSubmissionData(BaseModel):
model_config = ConfigDict(frozen=True)
node_id: str
node_title: str
rendered_content: str
action_id: str
action_text: str
class HumanInputContent(BaseModel):
model_config = ConfigDict(frozen=True)
workflow_run_id: str
submitted: bool
form_definition: HumanInputFormDefinition | None = None
form_submission_data: HumanInputFormSubmissionData | None = None
type: ExecutionContentType = Field(default=ExecutionContentType.HUMAN_INPUT)
ExecutionExtraContentDomainModel: TypeAlias = HumanInputContent
__all__ = [
"ExecutionExtraContentDomainModel",
"HumanInputContent",
"HumanInputFormDefinition",
"HumanInputFormSubmissionData",
]

View File

@@ -28,8 +28,8 @@ from core.model_runtime.entities.provider_entities import (
)
from core.model_runtime.model_providers.__base.ai_model import AIModel
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models.engine import db
from models.provider import (
LoadBalancingModelConfig,
Provider,

View File

@@ -15,10 +15,7 @@ from sqlalchemy import select
from sqlalchemy.orm import Session, sessionmaker
from core.helper.encrypter import batch_decrypt_token, encrypt_token, obfuscated_token
from core.ops.entities.config_entity import (
OPS_FILE_PATH,
TracingProviderEnum,
)
from core.ops.entities.config_entity import OPS_FILE_PATH, TracingProviderEnum
from core.ops.entities.trace_entity import (
DatasetRetrievalTraceInfo,
GenerateNameTraceInfo,
@@ -31,8 +28,8 @@ from core.ops.entities.trace_entity import (
WorkflowTraceInfo,
)
from core.ops.utils import get_message_data
from extensions.ext_database import db
from extensions.ext_storage import storage
from models.engine import db
from models.model import App, AppModelConfig, Conversation, Message, MessageFile, TraceAppConfig
from models.workflow import WorkflowAppLog
from tasks.ops_trace_task import process_trace_tasks
@@ -469,6 +466,8 @@ class TraceTask:
@classmethod
def _get_workflow_run_repo(cls):
from repositories.factory import DifyAPIRepositoryFactory
if cls._workflow_run_repo is None:
with cls._repo_lock:
if cls._workflow_run_repo is None:

View File

@@ -5,7 +5,7 @@ from urllib.parse import urlparse
from sqlalchemy import select
from extensions.ext_database import db
from models.engine import db
from models.model import Message

View File

@@ -1,3 +1,4 @@
import uuid
from collections.abc import Generator, Mapping
from typing import Union
@@ -11,6 +12,7 @@ from core.app.apps.chat.app_generator import ChatAppGenerator
from core.app.apps.completion.app_generator import CompletionAppGenerator
from core.app.apps.workflow.app_generator import WorkflowAppGenerator
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig
from core.plugin.backwards_invocation.base import BaseBackwardsInvocation
from extensions.ext_database import db
from models import Account
@@ -101,6 +103,11 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
if not workflow:
raise ValueError("unexpected app type")
pause_config = PauseStateLayerConfig(
session_factory=db.engine,
state_owner_user_id=workflow.created_by,
)
return AdvancedChatAppGenerator().generate(
app_model=app,
workflow=workflow,
@@ -112,7 +119,9 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
"conversation_id": conversation_id,
},
invoke_from=InvokeFrom.SERVICE_API,
workflow_run_id=str(uuid.uuid4()),
streaming=stream,
pause_state_config=pause_config,
)
elif app.mode == AppMode.AGENT_CHAT:
return AgentChatAppGenerator().generate(
@@ -159,6 +168,11 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
if not workflow:
raise ValueError("unexpected app type")
pause_config = PauseStateLayerConfig(
session_factory=db.engine,
state_owner_user_id=workflow.created_by,
)
return WorkflowAppGenerator().generate(
app_model=app,
workflow=workflow,
@@ -167,6 +181,7 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
invoke_from=InvokeFrom.SERVICE_API,
streaming=stream,
call_depth=1,
pause_state_config=pause_config,
)
@classmethod

View File

@@ -1,19 +1,18 @@
"""
Repository implementations for data access.
"""Repository implementations for data access."""
This package contains concrete implementations of the repository interfaces
defined in the core.workflow.repository package.
"""
from __future__ import annotations
from core.repositories.celery_workflow_execution_repository import CeleryWorkflowExecutionRepository
from core.repositories.celery_workflow_node_execution_repository import CeleryWorkflowNodeExecutionRepository
from core.repositories.factory import DifyCoreRepositoryFactory, RepositoryImportError
from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository
from .celery_workflow_execution_repository import CeleryWorkflowExecutionRepository
from .celery_workflow_node_execution_repository import CeleryWorkflowNodeExecutionRepository
from .factory import DifyCoreRepositoryFactory, RepositoryImportError
from .sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository
from .sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository
__all__ = [
"CeleryWorkflowExecutionRepository",
"CeleryWorkflowNodeExecutionRepository",
"DifyCoreRepositoryFactory",
"RepositoryImportError",
"SQLAlchemyWorkflowExecutionRepository",
"SQLAlchemyWorkflowNodeExecutionRepository",
]

View File

@@ -0,0 +1,553 @@
import dataclasses
import json
from collections.abc import Mapping, Sequence
from datetime import datetime
from typing import Any
from sqlalchemy import Engine, select
from sqlalchemy.orm import Session, selectinload, sessionmaker
from core.workflow.nodes.human_input.entities import (
DeliveryChannelConfig,
EmailDeliveryMethod,
EmailRecipients,
ExternalRecipient,
FormDefinition,
HumanInputNodeData,
MemberRecipient,
WebAppDeliveryMethod,
)
from core.workflow.nodes.human_input.enums import (
DeliveryMethodType,
HumanInputFormKind,
HumanInputFormStatus,
)
from core.workflow.repositories.human_input_form_repository import (
FormCreateParams,
FormNotFoundError,
HumanInputFormEntity,
HumanInputFormRecipientEntity,
)
from libs.datetime_utils import naive_utc_now
from libs.uuid_utils import uuidv7
from models.account import Account, TenantAccountJoin
from models.human_input import (
BackstageRecipientPayload,
ConsoleDeliveryPayload,
ConsoleRecipientPayload,
EmailExternalRecipientPayload,
EmailMemberRecipientPayload,
HumanInputDelivery,
HumanInputForm,
HumanInputFormRecipient,
RecipientType,
StandaloneWebAppRecipientPayload,
)
@dataclasses.dataclass(frozen=True)
class _DeliveryAndRecipients:
delivery: HumanInputDelivery
recipients: Sequence[HumanInputFormRecipient]
@dataclasses.dataclass(frozen=True)
class _WorkspaceMemberInfo:
user_id: str
email: str
class _HumanInputFormRecipientEntityImpl(HumanInputFormRecipientEntity):
def __init__(self, recipient_model: HumanInputFormRecipient):
self._recipient_model = recipient_model
@property
def id(self) -> str:
return self._recipient_model.id
@property
def token(self) -> str:
if self._recipient_model.access_token is None:
raise AssertionError(f"access_token should not be None for recipient {self._recipient_model.id}")
return self._recipient_model.access_token
class _HumanInputFormEntityImpl(HumanInputFormEntity):
def __init__(self, form_model: HumanInputForm, recipient_models: Sequence[HumanInputFormRecipient]):
self._form_model = form_model
self._recipients = [_HumanInputFormRecipientEntityImpl(recipient) for recipient in recipient_models]
self._web_app_recipient = next(
(
recipient
for recipient in recipient_models
if recipient.recipient_type == RecipientType.STANDALONE_WEB_APP
),
None,
)
self._console_recipient = next(
(recipient for recipient in recipient_models if recipient.recipient_type == RecipientType.CONSOLE),
None,
)
self._submitted_data: Mapping[str, Any] | None = (
json.loads(form_model.submitted_data) if form_model.submitted_data is not None else None
)
@property
def id(self) -> str:
return self._form_model.id
@property
def web_app_token(self):
if self._console_recipient is not None:
return self._console_recipient.access_token
if self._web_app_recipient is None:
return None
return self._web_app_recipient.access_token
@property
def recipients(self) -> list[HumanInputFormRecipientEntity]:
return list(self._recipients)
@property
def rendered_content(self) -> str:
return self._form_model.rendered_content
@property
def selected_action_id(self) -> str | None:
return self._form_model.selected_action_id
@property
def submitted_data(self) -> Mapping[str, Any] | None:
return self._submitted_data
@property
def submitted(self) -> bool:
return self._form_model.submitted_at is not None
@property
def status(self) -> HumanInputFormStatus:
return self._form_model.status
@property
def expiration_time(self) -> datetime:
return self._form_model.expiration_time
@dataclasses.dataclass(frozen=True)
class HumanInputFormRecord:
form_id: str
workflow_run_id: str | None
node_id: str
tenant_id: str
app_id: str
form_kind: HumanInputFormKind
definition: FormDefinition
rendered_content: str
created_at: datetime
expiration_time: datetime
status: HumanInputFormStatus
selected_action_id: str | None
submitted_data: Mapping[str, Any] | None
submitted_at: datetime | None
submission_user_id: str | None
submission_end_user_id: str | None
completed_by_recipient_id: str | None
recipient_id: str | None
recipient_type: RecipientType | None
access_token: str | None
@property
def submitted(self) -> bool:
return self.submitted_at is not None
@classmethod
def from_models(
cls, form_model: HumanInputForm, recipient_model: HumanInputFormRecipient | None
) -> "HumanInputFormRecord":
definition_payload = json.loads(form_model.form_definition)
if "expiration_time" not in definition_payload:
definition_payload["expiration_time"] = form_model.expiration_time
return cls(
form_id=form_model.id,
workflow_run_id=form_model.workflow_run_id,
node_id=form_model.node_id,
tenant_id=form_model.tenant_id,
app_id=form_model.app_id,
form_kind=form_model.form_kind,
definition=FormDefinition.model_validate(definition_payload),
rendered_content=form_model.rendered_content,
created_at=form_model.created_at,
expiration_time=form_model.expiration_time,
status=form_model.status,
selected_action_id=form_model.selected_action_id,
submitted_data=json.loads(form_model.submitted_data) if form_model.submitted_data else None,
submitted_at=form_model.submitted_at,
submission_user_id=form_model.submission_user_id,
submission_end_user_id=form_model.submission_end_user_id,
completed_by_recipient_id=form_model.completed_by_recipient_id,
recipient_id=recipient_model.id if recipient_model else None,
recipient_type=recipient_model.recipient_type if recipient_model else None,
access_token=recipient_model.access_token if recipient_model else None,
)
class _InvalidTimeoutStatusError(ValueError):
pass
class HumanInputFormRepositoryImpl:
def __init__(
self,
session_factory: sessionmaker | Engine,
tenant_id: str,
):
if isinstance(session_factory, Engine):
session_factory = sessionmaker(bind=session_factory)
self._session_factory = session_factory
self._tenant_id = tenant_id
def _delivery_method_to_model(
self,
session: Session,
form_id: str,
delivery_method: DeliveryChannelConfig,
) -> _DeliveryAndRecipients:
delivery_id = str(uuidv7())
delivery_model = HumanInputDelivery(
id=delivery_id,
form_id=form_id,
delivery_method_type=delivery_method.type,
delivery_config_id=delivery_method.id,
channel_payload=delivery_method.model_dump_json(),
)
recipients: list[HumanInputFormRecipient] = []
if isinstance(delivery_method, WebAppDeliveryMethod):
recipient_model = HumanInputFormRecipient(
form_id=form_id,
delivery_id=delivery_id,
recipient_type=RecipientType.STANDALONE_WEB_APP,
recipient_payload=StandaloneWebAppRecipientPayload().model_dump_json(),
)
recipients.append(recipient_model)
elif isinstance(delivery_method, EmailDeliveryMethod):
email_recipients_config = delivery_method.config.recipients
recipients.extend(
self._build_email_recipients(
session=session,
form_id=form_id,
delivery_id=delivery_id,
recipients_config=email_recipients_config,
)
)
return _DeliveryAndRecipients(delivery=delivery_model, recipients=recipients)
def _build_email_recipients(
self,
session: Session,
form_id: str,
delivery_id: str,
recipients_config: EmailRecipients,
) -> list[HumanInputFormRecipient]:
member_user_ids = [
recipient.user_id for recipient in recipients_config.items if isinstance(recipient, MemberRecipient)
]
external_emails = [
recipient.email for recipient in recipients_config.items if isinstance(recipient, ExternalRecipient)
]
if recipients_config.whole_workspace:
members = self._query_all_workspace_members(session=session)
else:
members = self._query_workspace_members_by_ids(session=session, restrict_to_user_ids=member_user_ids)
return self._create_email_recipients_from_resolved(
form_id=form_id,
delivery_id=delivery_id,
members=members,
external_emails=external_emails,
)
@staticmethod
def _create_email_recipients_from_resolved(
*,
form_id: str,
delivery_id: str,
members: Sequence[_WorkspaceMemberInfo],
external_emails: Sequence[str],
) -> list[HumanInputFormRecipient]:
recipient_models: list[HumanInputFormRecipient] = []
seen_emails: set[str] = set()
for member in members:
if not member.email:
continue
if member.email in seen_emails:
continue
seen_emails.add(member.email)
payload = EmailMemberRecipientPayload(user_id=member.user_id, email=member.email)
recipient_models.append(
HumanInputFormRecipient.new(
form_id=form_id,
delivery_id=delivery_id,
payload=payload,
)
)
for email in external_emails:
if not email:
continue
if email in seen_emails:
continue
seen_emails.add(email)
recipient_models.append(
HumanInputFormRecipient.new(
form_id=form_id,
delivery_id=delivery_id,
payload=EmailExternalRecipientPayload(email=email),
)
)
return recipient_models
def _query_all_workspace_members(
self,
session: Session,
) -> list[_WorkspaceMemberInfo]:
stmt = (
select(Account.id, Account.email)
.join(TenantAccountJoin, TenantAccountJoin.account_id == Account.id)
.where(TenantAccountJoin.tenant_id == self._tenant_id)
)
rows = session.execute(stmt).all()
return [_WorkspaceMemberInfo(user_id=account_id, email=email) for account_id, email in rows]
def _query_workspace_members_by_ids(
self,
session: Session,
restrict_to_user_ids: Sequence[str],
) -> list[_WorkspaceMemberInfo]:
unique_ids = {user_id for user_id in restrict_to_user_ids if user_id}
if not unique_ids:
return []
stmt = (
select(Account.id, Account.email)
.join(TenantAccountJoin, TenantAccountJoin.account_id == Account.id)
.where(TenantAccountJoin.tenant_id == self._tenant_id)
)
stmt = stmt.where(Account.id.in_(unique_ids))
rows = session.execute(stmt).all()
return [_WorkspaceMemberInfo(user_id=account_id, email=email) for account_id, email in rows]
def create_form(self, params: FormCreateParams) -> HumanInputFormEntity:
form_config: HumanInputNodeData = params.form_config
with self._session_factory(expire_on_commit=False) as session, session.begin():
# Generate unique form ID
form_id = str(uuidv7())
start_time = naive_utc_now()
node_expiration = form_config.expiration_time(start_time)
form_definition = FormDefinition(
form_content=form_config.form_content,
inputs=form_config.inputs,
user_actions=form_config.user_actions,
rendered_content=params.rendered_content,
expiration_time=node_expiration,
default_values=dict(params.resolved_default_values),
display_in_ui=params.display_in_ui,
node_title=form_config.title,
)
form_model = HumanInputForm(
id=form_id,
tenant_id=self._tenant_id,
app_id=params.app_id,
workflow_run_id=params.workflow_execution_id,
form_kind=params.form_kind,
node_id=params.node_id,
form_definition=form_definition.model_dump_json(),
rendered_content=params.rendered_content,
expiration_time=node_expiration,
created_at=start_time,
)
session.add(form_model)
recipient_models: list[HumanInputFormRecipient] = []
for delivery in params.delivery_methods:
delivery_and_recipients = self._delivery_method_to_model(
session=session,
form_id=form_id,
delivery_method=delivery,
)
session.add(delivery_and_recipients.delivery)
session.add_all(delivery_and_recipients.recipients)
recipient_models.extend(delivery_and_recipients.recipients)
if params.console_recipient_required and not any(
recipient.recipient_type == RecipientType.CONSOLE for recipient in recipient_models
):
console_delivery_id = str(uuidv7())
console_delivery = HumanInputDelivery(
id=console_delivery_id,
form_id=form_id,
delivery_method_type=DeliveryMethodType.WEBAPP,
delivery_config_id=None,
channel_payload=ConsoleDeliveryPayload().model_dump_json(),
)
console_recipient = HumanInputFormRecipient(
form_id=form_id,
delivery_id=console_delivery_id,
recipient_type=RecipientType.CONSOLE,
recipient_payload=ConsoleRecipientPayload(
account_id=params.console_creator_account_id,
).model_dump_json(),
)
session.add(console_delivery)
session.add(console_recipient)
recipient_models.append(console_recipient)
if params.backstage_recipient_required and not any(
recipient.recipient_type == RecipientType.BACKSTAGE for recipient in recipient_models
):
backstage_delivery_id = str(uuidv7())
backstage_delivery = HumanInputDelivery(
id=backstage_delivery_id,
form_id=form_id,
delivery_method_type=DeliveryMethodType.WEBAPP,
delivery_config_id=None,
channel_payload=ConsoleDeliveryPayload().model_dump_json(),
)
backstage_recipient = HumanInputFormRecipient(
form_id=form_id,
delivery_id=backstage_delivery_id,
recipient_type=RecipientType.BACKSTAGE,
recipient_payload=BackstageRecipientPayload(
account_id=params.console_creator_account_id,
).model_dump_json(),
)
session.add(backstage_delivery)
session.add(backstage_recipient)
recipient_models.append(backstage_recipient)
session.flush()
return _HumanInputFormEntityImpl(form_model=form_model, recipient_models=recipient_models)
def get_form(self, workflow_execution_id: str, node_id: str) -> HumanInputFormEntity | None:
form_query = select(HumanInputForm).where(
HumanInputForm.workflow_run_id == workflow_execution_id,
HumanInputForm.node_id == node_id,
HumanInputForm.tenant_id == self._tenant_id,
)
with self._session_factory(expire_on_commit=False) as session:
form_model: HumanInputForm | None = session.scalars(form_query).first()
if form_model is None:
return None
recipient_query = select(HumanInputFormRecipient).where(HumanInputFormRecipient.form_id == form_model.id)
recipient_models = session.scalars(recipient_query).all()
return _HumanInputFormEntityImpl(form_model=form_model, recipient_models=recipient_models)
class HumanInputFormSubmissionRepository:
"""Repository for fetching and submitting human input forms."""
def __init__(self, session_factory: sessionmaker | Engine):
if isinstance(session_factory, Engine):
session_factory = sessionmaker(bind=session_factory)
self._session_factory = session_factory
def get_by_token(self, form_token: str) -> HumanInputFormRecord | None:
query = (
select(HumanInputFormRecipient)
.options(selectinload(HumanInputFormRecipient.form))
.where(HumanInputFormRecipient.access_token == form_token)
)
with self._session_factory(expire_on_commit=False) as session:
recipient_model = session.scalars(query).first()
if recipient_model is None or recipient_model.form is None:
return None
return HumanInputFormRecord.from_models(recipient_model.form, recipient_model)
def get_by_form_id_and_recipient_type(
self,
form_id: str,
recipient_type: RecipientType,
) -> HumanInputFormRecord | None:
query = (
select(HumanInputFormRecipient)
.options(selectinload(HumanInputFormRecipient.form))
.where(
HumanInputFormRecipient.form_id == form_id,
HumanInputFormRecipient.recipient_type == recipient_type,
)
)
with self._session_factory(expire_on_commit=False) as session:
recipient_model = session.scalars(query).first()
if recipient_model is None or recipient_model.form is None:
return None
return HumanInputFormRecord.from_models(recipient_model.form, recipient_model)
def mark_submitted(
self,
*,
form_id: str,
recipient_id: str | None,
selected_action_id: str,
form_data: Mapping[str, Any],
submission_user_id: str | None,
submission_end_user_id: str | None,
) -> HumanInputFormRecord:
with self._session_factory(expire_on_commit=False) as session, session.begin():
form_model = session.get(HumanInputForm, form_id)
if form_model is None:
raise FormNotFoundError(f"form not found, id={form_id}")
recipient_model = session.get(HumanInputFormRecipient, recipient_id) if recipient_id else None
form_model.selected_action_id = selected_action_id
form_model.submitted_data = json.dumps(form_data)
form_model.submitted_at = naive_utc_now()
form_model.status = HumanInputFormStatus.SUBMITTED
form_model.submission_user_id = submission_user_id
form_model.submission_end_user_id = submission_end_user_id
form_model.completed_by_recipient_id = recipient_id
session.add(form_model)
session.flush()
session.refresh(form_model)
if recipient_model is not None:
session.refresh(recipient_model)
return HumanInputFormRecord.from_models(form_model, recipient_model)
def mark_timeout(
self,
*,
form_id: str,
timeout_status: HumanInputFormStatus,
reason: str | None = None,
) -> HumanInputFormRecord:
with self._session_factory(expire_on_commit=False) as session, session.begin():
form_model = session.get(HumanInputForm, form_id)
if form_model is None:
raise FormNotFoundError(f"form not found, id={form_id}")
if timeout_status not in {HumanInputFormStatus.TIMEOUT, HumanInputFormStatus.EXPIRED}:
raise _InvalidTimeoutStatusError(f"invalid timeout status: {timeout_status}")
# already handled or submitted
if form_model.status in {HumanInputFormStatus.TIMEOUT, HumanInputFormStatus.EXPIRED}:
return HumanInputFormRecord.from_models(form_model, None)
if form_model.submitted_at is not None or form_model.status == HumanInputFormStatus.SUBMITTED:
raise FormNotFoundError(f"form already submitted, id={form_id}")
form_model.status = timeout_status
form_model.selected_action_id = None
form_model.submitted_data = None
form_model.submission_user_id = None
form_model.submission_end_user_id = None
form_model.completed_by_recipient_id = None
# Reason is recorded in status/error downstream; not stored on form.
session.add(form_model)
session.flush()
session.refresh(form_model)
return HumanInputFormRecord.from_models(form_model, None)

View File

@@ -488,6 +488,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id,
WorkflowNodeExecutionModel.tenant_id == self._tenant_id,
WorkflowNodeExecutionModel.triggered_from == triggered_from,
WorkflowNodeExecutionModel.status != WorkflowNodeExecutionStatus.PAUSED,
)
if self._app_id:

View File

@@ -1,4 +1,5 @@
from core.tools.entities.tool_entities import ToolInvokeMeta
from libs.exception import BaseHTTPException
class ToolProviderNotFoundError(ValueError):
@@ -37,6 +38,12 @@ class ToolCredentialPolicyViolationError(ValueError):
pass
class WorkflowToolHumanInputNotSupportedError(BaseHTTPException):
error_code = "workflow_tool_human_input_not_supported"
description = "Workflow with Human Input nodes cannot be published as a workflow tool."
code = 400
class ToolEngineInvokeError(Exception):
meta: ToolInvokeMeta

View File

@@ -3,6 +3,8 @@ from typing import Any
from core.app.app_config.entities import VariableEntity
from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration
from core.tools.errors import WorkflowToolHumanInputNotSupportedError
from core.workflow.enums import NodeType
from core.workflow.nodes.base.entities import OutputVariableEntity
@@ -45,6 +47,13 @@ class WorkflowToolConfigurationUtils:
return [outputs_by_variable[variable] for variable in variable_order]
@classmethod
def ensure_no_human_input_nodes(cls, graph: Mapping[str, Any]) -> None:
nodes = graph.get("nodes", [])
for node in nodes:
if node.get("data", {}).get("type") == NodeType.HUMAN_INPUT:
raise WorkflowToolHumanInputNotSupportedError()
@classmethod
def check_is_synced(
cls, variables: list[VariableEntity], tool_configurations: list[WorkflowToolParameterConfiguration]

View File

@@ -98,6 +98,10 @@ class WorkflowTool(Tool):
invoke_from=self.runtime.invoke_from,
streaming=False,
call_depth=self.workflow_call_depth + 1,
# NOTE(QuantumGhost): We explicitly set `pause_state_config` to `None`
# because workflow pausing mechanisms (such as HumanInput) are not
# supported within WorkflowTool execution context.
pause_state_config=None,
)
assert isinstance(result, dict)
data = result.get("data", {})

View File

@@ -2,10 +2,12 @@ from .agent import AgentNodeStrategyInit
from .graph_init_params import GraphInitParams
from .workflow_execution import WorkflowExecution
from .workflow_node_execution import WorkflowNodeExecution
from .workflow_start_reason import WorkflowStartReason
__all__ = [
"AgentNodeStrategyInit",
"GraphInitParams",
"WorkflowExecution",
"WorkflowNodeExecution",
"WorkflowStartReason",
]

View File

@@ -5,6 +5,16 @@ from pydantic import BaseModel, Field
class GraphInitParams(BaseModel):
"""GraphInitParams encapsulates the configurations and contextual information
that remain constant throughout a single execution of the graph engine.
A single execution is defined as follows: as long as the execution has not reached
its conclusion, it is considered one execution. For instance, if a workflow is suspended
and later resumed, it is still regarded as a single execution, not two.
For the state diagram of workflow execution, refer to `WorkflowExecutionStatus`.
"""
# init params
tenant_id: str = Field(..., description="tenant / workspace id")
app_id: str = Field(..., description="app id")

View File

@@ -1,8 +1,11 @@
from collections.abc import Mapping
from enum import StrEnum, auto
from typing import Annotated, Literal, TypeAlias
from typing import Annotated, Any, Literal, TypeAlias
from pydantic import BaseModel, Field
from core.workflow.nodes.human_input.entities import FormInput, UserAction
class PauseReasonType(StrEnum):
HUMAN_INPUT_REQUIRED = auto()
@@ -11,10 +14,31 @@ class PauseReasonType(StrEnum):
class HumanInputRequired(BaseModel):
TYPE: Literal[PauseReasonType.HUMAN_INPUT_REQUIRED] = PauseReasonType.HUMAN_INPUT_REQUIRED
form_id: str
# The identifier of the human input node causing the pause.
form_content: str
inputs: list[FormInput] = Field(default_factory=list)
actions: list[UserAction] = Field(default_factory=list)
display_in_ui: bool = False
node_id: str
node_title: str
# The `resolved_default_values` stores the resolved values of variable defaults. It's a mapping from
# `output_variable_name` to their resolved values.
#
# For example, The form contains a input with output variable name `name` and placeholder type `VARIABLE`, its
# selector is ["start", "name"]. While the HumanInputNode is executed, the correspond value of variable
# `start.name` in variable pool is `John`. Thus, the resolved value of the output variable `name` is `John`. The
# `resolved_default_values` is `{"name": "John"}`.
#
# Only form inputs with default value type `VARIABLE` will be resolved and stored in `resolved_default_values`.
resolved_default_values: Mapping[str, Any] = Field(default_factory=dict)
# The `form_token` is the token used to submit the form via UI surfaces. It corresponds to
# `HumanInputFormRecipient.access_token`.
#
# This field is `None` if webapp delivery is not set and not
# in orchestrating mode.
form_token: str | None = None
class SchedulingPause(BaseModel):

View File

@@ -0,0 +1,8 @@
from enum import StrEnum
class WorkflowStartReason(StrEnum):
"""Reason for workflow start events across graph/queue/SSE layers."""
INITIAL = "initial" # First start of a workflow run.
RESUMPTION = "resumption" # Start triggered after resuming a paused run.

View File

@@ -0,0 +1,15 @@
import time
def get_timestamp() -> float:
"""Retrieve a timestamp as a float point numer representing the number of seconds
since the Unix epoch.
This function is primarily used to measure the execution time of the workflow engine.
Since workflow execution may be paused and resumed on a different machine,
`time.perf_counter` cannot be used as it is inconsistent across machines.
To address this, the function uses the wall clock as the time source.
However, it assumes that the clocks of all servers are properly synchronized.
"""
return round(time.time())

View File

@@ -2,12 +2,14 @@
GraphEngine configuration models.
"""
from pydantic import BaseModel
from pydantic import BaseModel, ConfigDict
class GraphEngineConfig(BaseModel):
"""Configuration for GraphEngine worker pool scaling."""
model_config = ConfigDict(frozen=True)
min_workers: int = 1
max_workers: int = 5
scale_up_threshold: int = 3

View File

@@ -10,6 +10,7 @@ from pydantic import BaseModel, Field
from core.workflow.entities.pause_reason import PauseReason
from core.workflow.enums import NodeState
from core.workflow.runtime.graph_runtime_state import GraphExecutionProtocol
from .node_execution import NodeExecution
@@ -236,3 +237,6 @@ class GraphExecution:
def record_node_failure(self) -> None:
"""Increment the count of node failures encountered during execution."""
self.exceptions_count += 1
_: GraphExecutionProtocol = GraphExecution(workflow_id="")

View File

@@ -192,9 +192,13 @@ class EventHandler:
self._event_collector.collect(edge_event)
# Enqueue ready nodes
for node_id in ready_nodes:
self._state_manager.enqueue_node(node_id)
self._state_manager.start_execution(node_id)
if self._graph_execution.is_paused:
for node_id in ready_nodes:
self._graph_runtime_state.register_deferred_node(node_id)
else:
for node_id in ready_nodes:
self._state_manager.enqueue_node(node_id)
self._state_manager.start_execution(node_id)
# Update execution tracking
self._state_manager.finish_execution(event.node_id)

View File

@@ -14,6 +14,7 @@ from collections.abc import Generator
from typing import TYPE_CHECKING, cast, final
from core.workflow.context import capture_current_context
from core.workflow.entities.workflow_start_reason import WorkflowStartReason
from core.workflow.enums import NodeExecutionType
from core.workflow.graph import Graph
from core.workflow.graph_events import (
@@ -55,6 +56,9 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
_DEFAULT_CONFIG = GraphEngineConfig()
@final
class GraphEngine:
"""
@@ -70,7 +74,7 @@ class GraphEngine:
graph: Graph,
graph_runtime_state: GraphRuntimeState,
command_channel: CommandChannel,
config: GraphEngineConfig,
config: GraphEngineConfig = _DEFAULT_CONFIG,
) -> None:
"""Initialize the graph engine with all subsystems and dependencies."""
# stop event
@@ -234,7 +238,9 @@ class GraphEngine:
self._graph_execution.paused = False
self._graph_execution.pause_reasons = []
start_event = GraphRunStartedEvent()
start_event = GraphRunStartedEvent(
reason=WorkflowStartReason.RESUMPTION if is_resume else WorkflowStartReason.INITIAL,
)
self._event_manager.notify_layers(start_event)
yield start_event
@@ -303,15 +309,17 @@ class GraphEngine:
for layer in self._layers:
try:
layer.on_graph_start()
except Exception as e:
logger.warning("Layer %s failed on_graph_start: %s", layer.__class__.__name__, e)
except Exception:
logger.exception("Layer %s failed on_graph_start", layer.__class__.__name__)
def _start_execution(self, *, resume: bool = False) -> None:
"""Start execution subsystems."""
self._stop_event.clear()
paused_nodes: list[str] = []
deferred_nodes: list[str] = []
if resume:
paused_nodes = self._graph_runtime_state.consume_paused_nodes()
deferred_nodes = self._graph_runtime_state.consume_deferred_nodes()
# Start worker pool (it calculates initial workers internally)
self._worker_pool.start()
@@ -327,7 +335,11 @@ class GraphEngine:
self._state_manager.enqueue_node(root_node.id)
self._state_manager.start_execution(root_node.id)
else:
for node_id in paused_nodes:
seen_nodes: set[str] = set()
for node_id in paused_nodes + deferred_nodes:
if node_id in seen_nodes:
continue
seen_nodes.add(node_id)
self._state_manager.enqueue_node(node_id)
self._state_manager.start_execution(node_id)
@@ -345,8 +357,8 @@ class GraphEngine:
for layer in self._layers:
try:
layer.on_graph_end(self._graph_execution.error)
except Exception as e:
logger.warning("Layer %s failed on_graph_end: %s", layer.__class__.__name__, e)
except Exception:
logger.exception("Layer %s failed on_graph_end", layer.__class__.__name__)
# Public property accessors for attributes that need external access
@property

View File

@@ -224,6 +224,8 @@ class GraphStateManager:
Returns:
Number of executing nodes
"""
# This count is a best-effort snapshot and can change concurrently.
# Only use it for pause-drain checks where scheduling is already frozen.
with self._lock:
return len(self._executing_nodes)

View File

@@ -83,12 +83,12 @@ class Dispatcher:
"""Main dispatcher loop."""
try:
self._process_commands()
paused = False
while not self._stop_event.is_set():
if (
self._execution_coordinator.aborted
or self._execution_coordinator.paused
or self._execution_coordinator.execution_complete
):
if self._execution_coordinator.aborted or self._execution_coordinator.execution_complete:
break
if self._execution_coordinator.paused:
paused = True
break
self._execution_coordinator.check_scaling()
@@ -101,13 +101,10 @@ class Dispatcher:
time.sleep(0.1)
self._process_commands()
while True:
try:
event = self._event_queue.get(block=False)
self._event_handler.dispatch(event)
self._event_queue.task_done()
except queue.Empty:
break
if paused:
self._drain_events_until_idle()
else:
self._drain_event_queue()
except Exception as e:
logger.exception("Dispatcher error")
@@ -122,3 +119,24 @@ class Dispatcher:
def _process_commands(self, event: GraphNodeEventBase | None = None):
if event is None or isinstance(event, self._COMMAND_TRIGGER_EVENTS):
self._execution_coordinator.process_commands()
def _drain_event_queue(self) -> None:
while True:
try:
event = self._event_queue.get(block=False)
self._event_handler.dispatch(event)
self._event_queue.task_done()
except queue.Empty:
break
def _drain_events_until_idle(self) -> None:
while not self._stop_event.is_set():
try:
event = self._event_queue.get(timeout=0.1)
self._event_handler.dispatch(event)
self._event_queue.task_done()
self._process_commands(event)
except queue.Empty:
if not self._execution_coordinator.has_executing_nodes():
break
self._drain_event_queue()

View File

@@ -94,3 +94,11 @@ class ExecutionCoordinator:
self._worker_pool.stop()
self._state_manager.clear_executing()
def has_executing_nodes(self) -> bool:
"""Return True if any nodes are currently marked as executing."""
# This check is only safe once execution has already paused.
# Before pause, executing state can change concurrently, which makes the result unreliable.
if not self._graph_execution.is_paused:
raise AssertionError("has_executing_nodes should only be called after execution is paused")
return self._state_manager.get_executing_count() > 0

View File

@@ -38,6 +38,8 @@ from .loop import (
from .node import (
NodeRunExceptionEvent,
NodeRunFailedEvent,
NodeRunHumanInputFormFilledEvent,
NodeRunHumanInputFormTimeoutEvent,
NodeRunPauseRequestedEvent,
NodeRunRetrieverResourceEvent,
NodeRunRetryEvent,
@@ -60,6 +62,8 @@ __all__ = [
"NodeRunAgentLogEvent",
"NodeRunExceptionEvent",
"NodeRunFailedEvent",
"NodeRunHumanInputFormFilledEvent",
"NodeRunHumanInputFormTimeoutEvent",
"NodeRunIterationFailedEvent",
"NodeRunIterationNextEvent",
"NodeRunIterationStartedEvent",

View File

@@ -1,11 +1,16 @@
from pydantic import Field
from core.workflow.entities.pause_reason import PauseReason
from core.workflow.entities.workflow_start_reason import WorkflowStartReason
from core.workflow.graph_events import BaseGraphEvent
class GraphRunStartedEvent(BaseGraphEvent):
pass
# Reason is emitted for workflow start events and is always set.
reason: WorkflowStartReason = Field(
default=WorkflowStartReason.INITIAL,
description="reason for workflow start",
)
class GraphRunSucceededEvent(BaseGraphEvent):

View File

@@ -54,6 +54,22 @@ class NodeRunRetryEvent(NodeRunStartedEvent):
retry_index: int = Field(..., description="which retry attempt is about to be performed")
class NodeRunHumanInputFormFilledEvent(GraphNodeEventBase):
"""Emitted when a HumanInput form is submitted and before the node finishes."""
node_title: str = Field(..., description="HumanInput node title")
rendered_content: str = Field(..., description="Markdown content rendered with user inputs.")
action_id: str = Field(..., description="User action identifier chosen in the form.")
action_text: str = Field(..., description="Display text of the chosen action button.")
class NodeRunHumanInputFormTimeoutEvent(GraphNodeEventBase):
"""Emitted when a HumanInput form times out."""
node_title: str = Field(..., description="HumanInput node title")
expiration_time: datetime = Field(..., description="Form expiration time")
class NodeRunPauseRequestedEvent(GraphNodeEventBase):
reason: PauseReason = Field(..., description="pause reason")

View File

@@ -13,6 +13,8 @@ from .loop import (
LoopSucceededEvent,
)
from .node import (
HumanInputFormFilledEvent,
HumanInputFormTimeoutEvent,
ModelInvokeCompletedEvent,
PauseRequestedEvent,
RunRetrieverResourceEvent,
@@ -23,6 +25,8 @@ from .node import (
__all__ = [
"AgentLogEvent",
"HumanInputFormFilledEvent",
"HumanInputFormTimeoutEvent",
"IterationFailedEvent",
"IterationNextEvent",
"IterationStartedEvent",

View File

@@ -47,3 +47,19 @@ class StreamCompletedEvent(NodeEventBase):
class PauseRequestedEvent(NodeEventBase):
reason: PauseReason = Field(..., description="pause reason")
class HumanInputFormFilledEvent(NodeEventBase):
"""Event emitted when a human input form is submitted."""
node_title: str
rendered_content: str
action_id: str
action_text: str
class HumanInputFormTimeoutEvent(NodeEventBase):
"""Event emitted when a human input form times out."""
node_title: str
expiration_time: datetime

View File

@@ -18,6 +18,8 @@ from core.workflow.graph_events import (
GraphNodeEventBase,
NodeRunAgentLogEvent,
NodeRunFailedEvent,
NodeRunHumanInputFormFilledEvent,
NodeRunHumanInputFormTimeoutEvent,
NodeRunIterationFailedEvent,
NodeRunIterationNextEvent,
NodeRunIterationStartedEvent,
@@ -34,6 +36,8 @@ from core.workflow.graph_events import (
)
from core.workflow.node_events import (
AgentLogEvent,
HumanInputFormFilledEvent,
HumanInputFormTimeoutEvent,
IterationFailedEvent,
IterationNextEvent,
IterationStartedEvent,
@@ -61,6 +65,15 @@ logger = logging.getLogger(__name__)
class Node(Generic[NodeDataT]):
"""BaseNode serves as the foundational class for all node implementations.
Nodes are allowed to maintain transient states (e.g., `LLMNode` uses the `_file_output`
attribute to track files generated by the LLM). However, these states are not persisted
when the workflow is suspended or resumed. If a node needs its state to be preserved
across workflow suspension and resumption, it should include the relevant state data
in its output.
"""
node_type: ClassVar[NodeType]
execution_type: NodeExecutionType = NodeExecutionType.EXECUTABLE
_node_data_type: ClassVar[type[BaseNodeData]] = BaseNodeData
@@ -251,10 +264,33 @@ class Node(Generic[NodeDataT]):
return self._node_execution_id
def ensure_execution_id(self) -> str:
if not self._node_execution_id:
self._node_execution_id = str(uuid4())
if self._node_execution_id:
return self._node_execution_id
resumed_execution_id = self._restore_execution_id_from_runtime_state()
if resumed_execution_id:
self._node_execution_id = resumed_execution_id
return self._node_execution_id
self._node_execution_id = str(uuid4())
return self._node_execution_id
def _restore_execution_id_from_runtime_state(self) -> str | None:
graph_execution = self.graph_runtime_state.graph_execution
try:
node_executions = graph_execution.node_executions
except AttributeError:
return None
if not isinstance(node_executions, dict):
return None
node_execution = node_executions.get(self._node_id)
if node_execution is None:
return None
execution_id = node_execution.execution_id
if not execution_id:
return None
return str(execution_id)
def _hydrate_node_data(self, data: Mapping[str, Any]) -> NodeDataT:
return cast(NodeDataT, self._node_data_type.model_validate(data))
@@ -620,6 +656,28 @@ class Node(Generic[NodeDataT]):
metadata=event.metadata,
)
@_dispatch.register
def _(self, event: HumanInputFormFilledEvent):
return NodeRunHumanInputFormFilledEvent(
id=self.execution_id,
node_id=self._node_id,
node_type=self.node_type,
node_title=event.node_title,
rendered_content=event.rendered_content,
action_id=event.action_id,
action_text=event.action_text,
)
@_dispatch.register
def _(self, event: HumanInputFormTimeoutEvent):
return NodeRunHumanInputFormTimeoutEvent(
id=self.execution_id,
node_id=self._node_id,
node_type=self.node_type,
node_title=event.node_title,
expiration_time=event.expiration_time,
)
@_dispatch.register
def _(self, event: LoopStartedEvent) -> NodeRunLoopStartedEvent:
return NodeRunLoopStartedEvent(

View File

@@ -1,3 +1,3 @@
from .human_input_node import HumanInputNode
__all__ = ["HumanInputNode"]
"""
Human Input node implementation.
"""

View File

@@ -1,10 +1,350 @@
from pydantic import Field
"""
Human Input node entities.
"""
import re
import uuid
from collections.abc import Mapping, Sequence
from datetime import datetime, timedelta
from typing import Annotated, Any, ClassVar, Literal, Self
from pydantic import BaseModel, Field, field_validator, model_validator
from core.variables.consts import SELECTORS_LENGTH
from core.workflow.nodes.base import BaseNodeData
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
from core.workflow.runtime import VariablePool
from .enums import ButtonStyle, DeliveryMethodType, EmailRecipientType, FormInputType, PlaceholderType, TimeoutUnit
_OUTPUT_VARIABLE_PATTERN = re.compile(r"\{\{#\$output\.(?P<field_name>[a-zA-Z_][a-zA-Z0-9_]{0,29})#\}\}")
class _WebAppDeliveryConfig(BaseModel):
"""Configuration for webapp delivery method."""
pass # Empty for webapp delivery
class MemberRecipient(BaseModel):
"""Member recipient for email delivery."""
type: Literal[EmailRecipientType.MEMBER] = EmailRecipientType.MEMBER
user_id: str
class ExternalRecipient(BaseModel):
"""External recipient for email delivery."""
type: Literal[EmailRecipientType.EXTERNAL] = EmailRecipientType.EXTERNAL
email: str
EmailRecipient = Annotated[MemberRecipient | ExternalRecipient, Field(discriminator="type")]
class EmailRecipients(BaseModel):
"""Email recipients configuration."""
# When true, recipients are the union of all workspace members and external items.
# Member items are ignored because they are already covered by the workspace scope.
# De-duplication is applied by email, with member recipients taking precedence.
whole_workspace: bool = False
items: list[EmailRecipient] = Field(default_factory=list)
class EmailDeliveryConfig(BaseModel):
"""Configuration for email delivery method."""
URL_PLACEHOLDER: ClassVar[str] = "{{#url#}}"
recipients: EmailRecipients
# the subject of email
subject: str
# Body is the content of email.It may contain the speical placeholder `{{#url#}}`, which
# represent the url to submit the form.
#
# It may also reference the output variable of the previous node with the syntax
# `{{#<node_id>.<field_name>#}}`.
body: str
debug_mode: bool = False
def with_debug_recipient(self, user_id: str) -> "EmailDeliveryConfig":
if not user_id:
debug_recipients = EmailRecipients(whole_workspace=False, items=[])
return self.model_copy(update={"recipients": debug_recipients})
debug_recipients = EmailRecipients(whole_workspace=False, items=[MemberRecipient(user_id=user_id)])
return self.model_copy(update={"recipients": debug_recipients})
@classmethod
def replace_url_placeholder(cls, body: str, url: str | None) -> str:
"""Replace the url placeholder with provided value."""
return body.replace(cls.URL_PLACEHOLDER, url or "")
@classmethod
def render_body_template(
cls,
*,
body: str,
url: str | None,
variable_pool: VariablePool | None = None,
) -> str:
"""Render email body by replacing placeholders with runtime values."""
templated_body = cls.replace_url_placeholder(body, url)
if variable_pool is None:
return templated_body
return variable_pool.convert_template(templated_body).text
class _DeliveryMethodBase(BaseModel):
"""Base delivery method configuration."""
enabled: bool = True
id: uuid.UUID = Field(default_factory=uuid.uuid4)
def extract_variable_selectors(self) -> Sequence[Sequence[str]]:
return ()
class WebAppDeliveryMethod(_DeliveryMethodBase):
"""Webapp delivery method configuration."""
type: Literal[DeliveryMethodType.WEBAPP] = DeliveryMethodType.WEBAPP
# The config field is not used currently.
config: _WebAppDeliveryConfig = Field(default_factory=_WebAppDeliveryConfig)
class EmailDeliveryMethod(_DeliveryMethodBase):
"""Email delivery method configuration."""
type: Literal[DeliveryMethodType.EMAIL] = DeliveryMethodType.EMAIL
config: EmailDeliveryConfig
def extract_variable_selectors(self) -> Sequence[Sequence[str]]:
variable_template_parser = VariableTemplateParser(template=self.config.body)
selectors: list[Sequence[str]] = []
for variable_selector in variable_template_parser.extract_variable_selectors():
value_selector = list(variable_selector.value_selector)
if len(value_selector) < SELECTORS_LENGTH:
continue
selectors.append(value_selector[:SELECTORS_LENGTH])
return selectors
DeliveryChannelConfig = Annotated[WebAppDeliveryMethod | EmailDeliveryMethod, Field(discriminator="type")]
def apply_debug_email_recipient(
method: DeliveryChannelConfig,
*,
enabled: bool,
user_id: str,
) -> DeliveryChannelConfig:
if not enabled:
return method
if not isinstance(method, EmailDeliveryMethod):
return method
if not method.config.debug_mode:
return method
debug_config = method.config.with_debug_recipient(user_id or "")
return method.model_copy(update={"config": debug_config})
class FormInputDefault(BaseModel):
"""Default configuration for form inputs."""
# NOTE: Ideally, a discriminated union would be used to model
# FormInputDefault. However, the UI requires preserving the previous
# value when switching between `VARIABLE` and `CONSTANT` types. This
# necessitates retaining all fields, making a discriminated union unsuitable.
type: PlaceholderType
# The selector of default variable, used when `type` is `VARIABLE`.
selector: Sequence[str] = Field(default_factory=tuple) #
# The value of the default, used when `type` is `CONSTANT`.
# TODO: How should we express JSON values?
value: str = ""
@model_validator(mode="after")
def _validate_selector(self) -> Self:
if self.type == PlaceholderType.CONSTANT:
return self
if len(self.selector) < SELECTORS_LENGTH:
raise ValueError(f"the length of selector should be at least {SELECTORS_LENGTH}, selector={self.selector}")
return self
class FormInput(BaseModel):
"""Form input definition."""
type: FormInputType
output_variable_name: str
default: FormInputDefault | None = None
_IDENTIFIER_PATTERN = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$")
class UserAction(BaseModel):
"""User action configuration."""
# id is the identifier for this action.
# It also serves as the identifiers of output handle.
#
# The id must be a valid identifier (satisfy the _IDENTIFIER_PATTERN above.)
id: str = Field(max_length=20)
title: str = Field(max_length=20)
button_style: ButtonStyle = ButtonStyle.DEFAULT
@field_validator("id")
@classmethod
def _validate_id(cls, value: str) -> str:
if not _IDENTIFIER_PATTERN.match(value):
raise ValueError(
f"'{value}' is not a valid identifier. It must start with a letter or underscore, "
f"and contain only letters, numbers, or underscores."
)
return value
class HumanInputNodeData(BaseNodeData):
"""Configuration schema for the HumanInput node."""
"""Human Input node data."""
required_variables: list[str] = Field(default_factory=list)
pause_reason: str | None = Field(default=None)
delivery_methods: list[DeliveryChannelConfig] = Field(default_factory=list)
form_content: str = ""
inputs: list[FormInput] = Field(default_factory=list)
user_actions: list[UserAction] = Field(default_factory=list)
timeout: int = 36
timeout_unit: TimeoutUnit = TimeoutUnit.HOUR
@field_validator("inputs")
@classmethod
def _validate_inputs(cls, inputs: list[FormInput]) -> list[FormInput]:
seen_names: set[str] = set()
for form_input in inputs:
name = form_input.output_variable_name
if name in seen_names:
raise ValueError(f"duplicated output_variable_name '{name}' in inputs")
seen_names.add(name)
return inputs
@field_validator("user_actions")
@classmethod
def _validate_user_actions(cls, user_actions: list[UserAction]) -> list[UserAction]:
seen_ids: set[str] = set()
for action in user_actions:
action_id = action.id
if action_id in seen_ids:
raise ValueError(f"duplicated user action id '{action_id}'")
seen_ids.add(action_id)
return user_actions
def is_webapp_enabled(self) -> bool:
for dm in self.delivery_methods:
if not dm.enabled:
continue
if dm.type == DeliveryMethodType.WEBAPP:
return True
return False
def expiration_time(self, start_time: datetime) -> datetime:
if self.timeout_unit == TimeoutUnit.HOUR:
return start_time + timedelta(hours=self.timeout)
elif self.timeout_unit == TimeoutUnit.DAY:
return start_time + timedelta(days=self.timeout)
else:
raise AssertionError("unknown timeout unit.")
def outputs_field_names(self) -> Sequence[str]:
field_names = []
for match in _OUTPUT_VARIABLE_PATTERN.finditer(self.form_content):
field_names.append(match.group("field_name"))
return field_names
def extract_variable_selector_to_variable_mapping(self, node_id: str) -> Mapping[str, Sequence[str]]:
variable_mappings: dict[str, Sequence[str]] = {}
def _add_variable_selectors(selectors: Sequence[Sequence[str]]) -> None:
for selector in selectors:
if len(selector) < SELECTORS_LENGTH:
continue
qualified_variable_mapping_key = f"{node_id}.#{'.'.join(selector[:SELECTORS_LENGTH])}#"
variable_mappings[qualified_variable_mapping_key] = list(selector[:SELECTORS_LENGTH])
form_template_parser = VariableTemplateParser(template=self.form_content)
_add_variable_selectors(
[selector.value_selector for selector in form_template_parser.extract_variable_selectors()]
)
for delivery_method in self.delivery_methods:
if not delivery_method.enabled:
continue
_add_variable_selectors(delivery_method.extract_variable_selectors())
for input in self.inputs:
default_value = input.default
if default_value is None:
continue
if default_value.type == PlaceholderType.CONSTANT:
continue
default_value_key = ".".join(default_value.selector)
qualified_variable_mapping_key = f"{node_id}.#{default_value_key}#"
variable_mappings[qualified_variable_mapping_key] = default_value.selector
return variable_mappings
def find_action_text(self, action_id: str) -> str:
"""
Resolve action display text by id.
"""
for action in self.user_actions:
if action.id == action_id:
return action.title
return action_id
class FormDefinition(BaseModel):
form_content: str
inputs: list[FormInput] = Field(default_factory=list)
user_actions: list[UserAction] = Field(default_factory=list)
rendered_content: str
expiration_time: datetime
# this is used to store the resolved default values
default_values: dict[str, Any] = Field(default_factory=dict)
# node_title records the title of the HumanInput node.
node_title: str | None = None
# display_in_ui controls whether the form should be displayed in UI surfaces.
display_in_ui: bool | None = None
class HumanInputSubmissionValidationError(ValueError):
pass
def validate_human_input_submission(
*,
inputs: Sequence[FormInput],
user_actions: Sequence[UserAction],
selected_action_id: str,
form_data: Mapping[str, Any],
) -> None:
available_actions = {action.id for action in user_actions}
if selected_action_id not in available_actions:
raise HumanInputSubmissionValidationError(f"Invalid action: {selected_action_id}")
provided_inputs = set(form_data.keys())
missing_inputs = [
form_input.output_variable_name
for form_input in inputs
if form_input.output_variable_name not in provided_inputs
]
if missing_inputs:
missing_list = ", ".join(missing_inputs)
raise HumanInputSubmissionValidationError(f"Missing required inputs: {missing_list}")

View File

@@ -0,0 +1,72 @@
import enum
class HumanInputFormStatus(enum.StrEnum):
"""Status of a human input form."""
# Awaiting submission from any recipient. Forms stay in this state until
# submitted or a timeout rule applies.
WAITING = enum.auto()
# Global timeout reached. The workflow run is stopped and will not resume.
# This is distinct from node-level timeout.
EXPIRED = enum.auto()
# Submitted by a recipient; form data is available and execution resumes
# along the selected action edge.
SUBMITTED = enum.auto()
# Node-level timeout reached. The human input node should emit a timeout
# event and the workflow should resume along the timeout edge.
TIMEOUT = enum.auto()
class HumanInputFormKind(enum.StrEnum):
"""Kind of a human input form."""
RUNTIME = enum.auto() # Form created during workflow execution.
DELIVERY_TEST = enum.auto() # Form created for delivery tests.
class DeliveryMethodType(enum.StrEnum):
"""Delivery method types for human input forms."""
# WEBAPP controls whether the form is delivered to the web app. It not only controls
# the standalone web app, but also controls the installed apps in the console.
WEBAPP = enum.auto()
EMAIL = enum.auto()
class ButtonStyle(enum.StrEnum):
"""Button styles for user actions."""
PRIMARY = enum.auto()
DEFAULT = enum.auto()
ACCENT = enum.auto()
GHOST = enum.auto()
class TimeoutUnit(enum.StrEnum):
"""Timeout unit for form expiration."""
HOUR = enum.auto()
DAY = enum.auto()
class FormInputType(enum.StrEnum):
"""Form input types."""
TEXT_INPUT = enum.auto()
PARAGRAPH = enum.auto()
class PlaceholderType(enum.StrEnum):
"""Default value types for form inputs."""
VARIABLE = enum.auto()
CONSTANT = enum.auto()
class EmailRecipientType(enum.StrEnum):
"""Email recipient types."""
MEMBER = enum.auto()
EXTERNAL = enum.auto()

View File

@@ -1,12 +1,42 @@
from collections.abc import Mapping
from typing import Any
import json
import logging
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any
from core.app.entities.app_invoke_entities import InvokeFrom
from core.repositories.human_input_repository import HumanInputFormRepositoryImpl
from core.workflow.entities.pause_reason import HumanInputRequired
from core.workflow.enums import NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult, PauseRequestedEvent
from core.workflow.node_events import (
HumanInputFormFilledEvent,
HumanInputFormTimeoutEvent,
NodeRunResult,
PauseRequestedEvent,
)
from core.workflow.node_events.base import NodeEventBase
from core.workflow.node_events.node import StreamCompletedEvent
from core.workflow.nodes.base.node import Node
from core.workflow.repositories.human_input_form_repository import (
FormCreateParams,
HumanInputFormEntity,
HumanInputFormRepository,
)
from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from .entities import HumanInputNodeData
from .entities import DeliveryChannelConfig, HumanInputNodeData, apply_debug_email_recipient
from .enums import DeliveryMethodType, HumanInputFormStatus, PlaceholderType
if TYPE_CHECKING:
from core.workflow.entities.graph_init_params import GraphInitParams
from core.workflow.runtime.graph_runtime_state import GraphRuntimeState
_SELECTED_BRANCH_KEY = "selected_branch"
logger = logging.getLogger(__name__)
class HumanInputNode(Node[HumanInputNodeData]):
@@ -17,7 +47,7 @@ class HumanInputNode(Node[HumanInputNodeData]):
"edge_source_handle",
"edgeSourceHandle",
"source_handle",
"selected_branch",
_SELECTED_BRANCH_KEY,
"selectedBranch",
"branch",
"branch_id",
@@ -25,43 +55,37 @@ class HumanInputNode(Node[HumanInputNodeData]):
"handle",
)
_node_data: HumanInputNodeData
_form_repository: HumanInputFormRepository
_OUTPUT_FIELD_ACTION_ID = "__action_id"
_OUTPUT_FIELD_RENDERED_CONTENT = "__rendered_content"
_TIMEOUT_HANDLE = _TIMEOUT_ACTION_ID = "__timeout"
def __init__(
self,
id: str,
config: Mapping[str, Any],
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
form_repository: HumanInputFormRepository | None = None,
) -> None:
super().__init__(
id=id,
config=config,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
if form_repository is None:
form_repository = HumanInputFormRepositoryImpl(
session_factory=db.engine,
tenant_id=self.tenant_id,
)
self._form_repository = form_repository
@classmethod
def version(cls) -> str:
return "1"
def _run(self): # type: ignore[override]
if self._is_completion_ready():
branch_handle = self._resolve_branch_selection()
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs={},
edge_source_handle=branch_handle or "source",
)
return self._pause_generator()
def _pause_generator(self):
# TODO(QuantumGhost): yield a real form id.
yield PauseRequestedEvent(reason=HumanInputRequired(form_id="test_form_id", node_id=self.id))
def _is_completion_ready(self) -> bool:
"""Determine whether all required inputs are satisfied."""
if not self.node_data.required_variables:
return False
variable_pool = self.graph_runtime_state.variable_pool
for selector_str in self.node_data.required_variables:
parts = selector_str.split(".")
if len(parts) != 2:
return False
segment = variable_pool.get(parts)
if segment is None:
return False
return True
def _resolve_branch_selection(self) -> str | None:
"""Determine the branch handle selected by human input if available."""
@@ -108,3 +132,224 @@ class HumanInputNode(Node[HumanInputNodeData]):
return candidate
return None
@property
def _workflow_execution_id(self) -> str:
workflow_exec_id = self.graph_runtime_state.variable_pool.system_variables.workflow_execution_id
assert workflow_exec_id is not None
return workflow_exec_id
def _form_to_pause_event(self, form_entity: HumanInputFormEntity):
required_event = self._human_input_required_event(form_entity)
pause_requested_event = PauseRequestedEvent(reason=required_event)
return pause_requested_event
def resolve_default_values(self) -> Mapping[str, Any]:
variable_pool = self.graph_runtime_state.variable_pool
resolved_defaults: dict[str, Any] = {}
for input in self._node_data.inputs:
if (default_value := input.default) is None:
continue
if default_value.type == PlaceholderType.CONSTANT:
continue
resolved_value = variable_pool.get(default_value.selector)
if resolved_value is None:
# TODO: How should we handle this?
continue
resolved_defaults[input.output_variable_name] = (
WorkflowRuntimeTypeConverter().value_to_json_encodable_recursive(resolved_value.value)
)
return resolved_defaults
def _should_require_console_recipient(self) -> bool:
if self.invoke_from == InvokeFrom.DEBUGGER:
return True
if self.invoke_from == InvokeFrom.EXPLORE:
return self._node_data.is_webapp_enabled()
return False
def _display_in_ui(self) -> bool:
if self.invoke_from == InvokeFrom.DEBUGGER:
return True
return self._node_data.is_webapp_enabled()
def _effective_delivery_methods(self) -> Sequence[DeliveryChannelConfig]:
enabled_methods = [method for method in self._node_data.delivery_methods if method.enabled]
if self.invoke_from in {InvokeFrom.DEBUGGER, InvokeFrom.EXPLORE}:
enabled_methods = [method for method in enabled_methods if method.type != DeliveryMethodType.WEBAPP]
return [
apply_debug_email_recipient(
method,
enabled=self.invoke_from == InvokeFrom.DEBUGGER,
user_id=self.user_id or "",
)
for method in enabled_methods
]
def _human_input_required_event(self, form_entity: HumanInputFormEntity) -> HumanInputRequired:
node_data = self._node_data
resolved_default_values = self.resolve_default_values()
display_in_ui = self._display_in_ui()
form_token = form_entity.web_app_token
if display_in_ui and form_token is None:
raise AssertionError("Form token should be available for UI execution.")
return HumanInputRequired(
form_id=form_entity.id,
form_content=form_entity.rendered_content,
inputs=node_data.inputs,
actions=node_data.user_actions,
display_in_ui=display_in_ui,
node_id=self.id,
node_title=node_data.title,
form_token=form_token,
resolved_default_values=resolved_default_values,
)
def _run(self) -> Generator[NodeEventBase, None, None]:
"""
Execute the human input node.
This method will:
1. Generate a unique form ID
2. Create form content with variable substitution
3. Create form in database
4. Send form via configured delivery methods
5. Suspend workflow execution
6. Wait for form submission to resume
"""
repo = self._form_repository
form = repo.get_form(self._workflow_execution_id, self.id)
if form is None:
display_in_ui = self._display_in_ui()
params = FormCreateParams(
app_id=self.app_id,
workflow_execution_id=self._workflow_execution_id,
node_id=self.id,
form_config=self._node_data,
rendered_content=self.render_form_content_before_submission(),
delivery_methods=self._effective_delivery_methods(),
display_in_ui=display_in_ui,
resolved_default_values=self.resolve_default_values(),
console_recipient_required=self._should_require_console_recipient(),
console_creator_account_id=(
self.user_id if self.invoke_from in {InvokeFrom.DEBUGGER, InvokeFrom.EXPLORE} else None
),
backstage_recipient_required=True,
)
form_entity = self._form_repository.create_form(params)
# Create human input required event
logger.info(
"Human Input node suspended workflow for form. workflow_run_id=%s, node_id=%s, form_id=%s",
self.graph_runtime_state.variable_pool.system_variables.workflow_execution_id,
self.id,
form_entity.id,
)
yield self._form_to_pause_event(form_entity)
return
if (
form.status in {HumanInputFormStatus.TIMEOUT, HumanInputFormStatus.EXPIRED}
or form.expiration_time <= naive_utc_now()
):
yield HumanInputFormTimeoutEvent(
node_title=self._node_data.title,
expiration_time=form.expiration_time,
)
yield StreamCompletedEvent(
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs={self._OUTPUT_FIELD_ACTION_ID: ""},
edge_source_handle=self._TIMEOUT_HANDLE,
)
)
return
if not form.submitted:
yield self._form_to_pause_event(form)
return
selected_action_id = form.selected_action_id
if selected_action_id is None:
raise AssertionError(f"selected_action_id should not be None when form submitted, form_id={form.id}")
submitted_data = form.submitted_data or {}
outputs: dict[str, Any] = dict(submitted_data)
outputs[self._OUTPUT_FIELD_ACTION_ID] = selected_action_id
rendered_content = self.render_form_content_with_outputs(
form.rendered_content,
outputs,
self._node_data.outputs_field_names(),
)
outputs[self._OUTPUT_FIELD_RENDERED_CONTENT] = rendered_content
action_text = self._node_data.find_action_text(selected_action_id)
yield HumanInputFormFilledEvent(
node_title=self._node_data.title,
rendered_content=rendered_content,
action_id=selected_action_id,
action_text=action_text,
)
yield StreamCompletedEvent(
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs=outputs,
edge_source_handle=selected_action_id,
)
)
def render_form_content_before_submission(self) -> str:
"""
Process form content by substituting variables.
This method should:
1. Parse the form_content markdown
2. Substitute {{#node_name.var_name#}} with actual values
3. Keep {{#$output.field_name#}} placeholders for form inputs
"""
rendered_form_content = self.graph_runtime_state.variable_pool.convert_template(
self._node_data.form_content,
)
return rendered_form_content.markdown
@staticmethod
def render_form_content_with_outputs(
form_content: str,
outputs: Mapping[str, Any],
field_names: Sequence[str],
) -> str:
"""
Replace {{#$output.xxx#}} placeholders with submitted values.
"""
rendered_content = form_content
for field_name in field_names:
placeholder = "{{#$output." + field_name + "#}}"
value = outputs.get(field_name)
if value is None:
replacement = ""
elif isinstance(value, (dict, list)):
replacement = json.dumps(value, ensure_ascii=False)
else:
replacement = str(value)
rendered_content = rendered_content.replace(placeholder, replacement)
return rendered_content
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls,
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selectors referenced in form content and input default values.
This method should parse:
1. Variables referenced in form_content ({{#node_name.var_name#}})
2. Variables referenced in input default values
"""
validated_node_data = HumanInputNodeData.model_validate(node_data)
return validated_node_data.extract_variable_selector_to_variable_mapping(node_id)

View File

@@ -0,0 +1,152 @@
import abc
import dataclasses
from collections.abc import Mapping, Sequence
from datetime import datetime
from typing import Any, Protocol
from core.workflow.nodes.human_input.entities import DeliveryChannelConfig, HumanInputNodeData
from core.workflow.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus
class HumanInputError(Exception):
pass
class FormNotFoundError(HumanInputError):
pass
@dataclasses.dataclass
class FormCreateParams:
# app_id is the identifier for the app that the form belongs to.
# It is a string with uuid format.
app_id: str
# None when creating a delivery test form; set for runtime forms.
workflow_execution_id: str | None
# node_id is the identifier for a specific
# node in the graph.
#
# TODO: for node inside loop / iteration, this would
# cause problems, as a single node may be executed multiple times.
node_id: str
form_config: HumanInputNodeData
rendered_content: str
# Delivery methods already filtered by runtime context (invoke_from).
delivery_methods: Sequence[DeliveryChannelConfig]
# UI display flag computed by runtime context.
display_in_ui: bool
# resolved_default_values saves the values for defaults with
# type = VARIABLE.
#
# For type = CONSTANT, the value is not stored inside `resolved_default_values`
resolved_default_values: Mapping[str, Any]
form_kind: HumanInputFormKind = HumanInputFormKind.RUNTIME
# Force creating a console-only recipient for submission in Console.
console_recipient_required: bool = False
console_creator_account_id: str | None = None
# Force creating a backstage recipient for submission in Console.
backstage_recipient_required: bool = False
class HumanInputFormEntity(abc.ABC):
@property
@abc.abstractmethod
def id(self) -> str:
"""id returns the identifer of the form."""
pass
@property
@abc.abstractmethod
def web_app_token(self) -> str | None:
"""web_app_token returns the token for submission inside webapp.
For console/debug execution, this may point to the console submission token
if the form is configured to require console delivery.
"""
# TODO: what if the users are allowed to add multiple
# webapp delivery?
pass
@property
@abc.abstractmethod
def recipients(self) -> list["HumanInputFormRecipientEntity"]: ...
@property
@abc.abstractmethod
def rendered_content(self) -> str:
"""Rendered markdown content associated with the form."""
...
@property
@abc.abstractmethod
def selected_action_id(self) -> str | None:
"""Identifier of the selected user action if the form has been submitted."""
...
@property
@abc.abstractmethod
def submitted_data(self) -> Mapping[str, Any] | None:
"""Submitted form data if available."""
...
@property
@abc.abstractmethod
def submitted(self) -> bool:
"""Whether the form has been submitted."""
...
@property
@abc.abstractmethod
def status(self) -> HumanInputFormStatus:
"""Current status of the form."""
...
@property
@abc.abstractmethod
def expiration_time(self) -> datetime:
"""When the form expires."""
...
class HumanInputFormRecipientEntity(abc.ABC):
@property
@abc.abstractmethod
def id(self) -> str:
"""id returns the identifer of this recipient."""
...
@property
@abc.abstractmethod
def token(self) -> str:
"""token returns a random string used to submit form"""
...
class HumanInputFormRepository(Protocol):
"""
Repository interface for HumanInputForm.
This interface defines the contract for accessing and manipulating
HumanInputForm data, regardless of the underlying storage mechanism.
Note: Domain-specific concepts like multi-tenancy (tenant_id), application context (app_id),
and other implementation details should be handled at the implementation level, not in
the core interface. This keeps the core domain model clean and independent of specific
application domains or deployment scenarios.
"""
def get_form(self, workflow_execution_id: str, node_id: str) -> HumanInputFormEntity | None:
"""Get the form created for a given human input node in a workflow execution. Returns
`None` if the form has not been created yet."""
...
def create_form(self, params: FormCreateParams) -> HumanInputFormEntity:
"""
Create a human input form from form definition.
"""
...

View File

@@ -6,15 +6,18 @@ import threading
from collections.abc import Mapping, Sequence
from copy import deepcopy
from dataclasses import dataclass
from typing import Any, ClassVar, Protocol
from typing import TYPE_CHECKING, Any, ClassVar, Protocol
from pydantic import BaseModel, Field
from pydantic.json import pydantic_encoder
from core.model_runtime.entities.llm_entities import LLMUsage
from core.workflow.entities.pause_reason import PauseReason
from core.workflow.enums import NodeExecutionType, NodeState, NodeType
from core.workflow.runtime.variable_pool import VariablePool
if TYPE_CHECKING:
from core.workflow.entities.pause_reason import PauseReason
class ReadyQueueProtocol(Protocol):
"""Structural interface required from ready queue implementations."""
@@ -133,6 +136,13 @@ class GraphProtocol(Protocol):
def get_outgoing_edges(self, node_id: str) -> Sequence[EdgeProtocol]: ...
class _GraphStateSnapshot(BaseModel):
"""Serializable graph state snapshot for node/edge states."""
nodes: dict[str, NodeState] = Field(default_factory=dict)
edges: dict[str, NodeState] = Field(default_factory=dict)
@dataclass(slots=True)
class _GraphRuntimeStateSnapshot:
"""Immutable view of a serialized runtime state snapshot."""
@@ -148,10 +158,20 @@ class _GraphRuntimeStateSnapshot:
graph_execution_dump: str | None
response_coordinator_dump: str | None
paused_nodes: tuple[str, ...]
deferred_nodes: tuple[str, ...]
graph_node_states: dict[str, NodeState]
graph_edge_states: dict[str, NodeState]
class GraphRuntimeState:
"""Mutable runtime state shared across graph execution components."""
"""Mutable runtime state shared across graph execution components.
`GraphRuntimeState` encapsulates the runtime state of workflow execution,
including scheduling details, variable values, and timing information.
Values that are initialized prior to workflow execution and remain constant
throughout the execution should be part of `GraphInitParams` instead.
"""
def __init__(
self,
@@ -189,6 +209,16 @@ class GraphRuntimeState:
self._pending_response_coordinator_dump: str | None = None
self._pending_graph_execution_workflow_id: str | None = None
self._paused_nodes: set[str] = set()
self._deferred_nodes: set[str] = set()
# Node and edges states needed to be restored into
# graph object.
#
# These two fields are non-None only when resuming from a snapshot.
# Once the graph is attached, these two fields will be set to None.
self._pending_graph_node_states: dict[str, NodeState] | None = None
self._pending_graph_edge_states: dict[str, NodeState] | None = None
self.stop_event: threading.Event = threading.Event()
if graph is not None:
@@ -210,6 +240,7 @@ class GraphRuntimeState:
if self._pending_response_coordinator_dump is not None and self._response_coordinator is not None:
self._response_coordinator.loads(self._pending_response_coordinator_dump)
self._pending_response_coordinator_dump = None
self._apply_pending_graph_state()
def configure(self, *, graph: GraphProtocol | None = None) -> None:
"""Ensure core collaborators are initialized with the provided context."""
@@ -331,8 +362,13 @@ class GraphRuntimeState:
"ready_queue": self.ready_queue.dumps(),
"graph_execution": self.graph_execution.dumps(),
"paused_nodes": list(self._paused_nodes),
"deferred_nodes": list(self._deferred_nodes),
}
graph_state = self._snapshot_graph_state()
if graph_state is not None:
snapshot["graph_state"] = graph_state
if self._response_coordinator is not None and self._graph is not None:
snapshot["response_coordinator"] = self._response_coordinator.dumps()
@@ -366,6 +402,11 @@ class GraphRuntimeState:
self._paused_nodes.add(node_id)
def get_paused_nodes(self) -> list[str]:
"""Retrieve the list of paused nodes without mutating internal state."""
return list(self._paused_nodes)
def consume_paused_nodes(self) -> list[str]:
"""Retrieve and clear the list of paused nodes awaiting resume."""
@@ -373,6 +414,23 @@ class GraphRuntimeState:
self._paused_nodes.clear()
return nodes
def register_deferred_node(self, node_id: str) -> None:
"""Record a node that became ready during pause and should resume later."""
self._deferred_nodes.add(node_id)
def get_deferred_nodes(self) -> list[str]:
"""Retrieve deferred nodes without mutating internal state."""
return list(self._deferred_nodes)
def consume_deferred_nodes(self) -> list[str]:
"""Retrieve and clear deferred nodes awaiting resume."""
nodes = list(self._deferred_nodes)
self._deferred_nodes.clear()
return nodes
# ------------------------------------------------------------------
# Builders
# ------------------------------------------------------------------
@@ -388,7 +446,7 @@ class GraphRuntimeState:
graph_execution_cls = module.GraphExecution
workflow_id = self._pending_graph_execution_workflow_id or ""
self._pending_graph_execution_workflow_id = None
return graph_execution_cls(workflow_id=workflow_id)
return graph_execution_cls(workflow_id=workflow_id) # type: ignore[invalid-return-type]
def _build_response_coordinator(self, graph: GraphProtocol) -> ResponseStreamCoordinatorProtocol:
# Lazily import to keep the runtime domain decoupled from graph_engine modules.
@@ -434,6 +492,10 @@ class GraphRuntimeState:
graph_execution_payload = payload.get("graph_execution")
response_payload = payload.get("response_coordinator")
paused_nodes_payload = payload.get("paused_nodes", [])
deferred_nodes_payload = payload.get("deferred_nodes", [])
graph_state_payload = payload.get("graph_state", {}) or {}
graph_node_states = _coerce_graph_state_map(graph_state_payload, "nodes")
graph_edge_states = _coerce_graph_state_map(graph_state_payload, "edges")
return _GraphRuntimeStateSnapshot(
start_at=start_at,
@@ -447,6 +509,9 @@ class GraphRuntimeState:
graph_execution_dump=graph_execution_payload,
response_coordinator_dump=response_payload,
paused_nodes=tuple(map(str, paused_nodes_payload)),
deferred_nodes=tuple(map(str, deferred_nodes_payload)),
graph_node_states=graph_node_states,
graph_edge_states=graph_edge_states,
)
def _apply_snapshot(self, snapshot: _GraphRuntimeStateSnapshot) -> None:
@@ -462,6 +527,10 @@ class GraphRuntimeState:
self._restore_graph_execution(snapshot.graph_execution_dump)
self._restore_response_coordinator(snapshot.response_coordinator_dump)
self._paused_nodes = set(snapshot.paused_nodes)
self._deferred_nodes = set(snapshot.deferred_nodes)
self._pending_graph_node_states = snapshot.graph_node_states or None
self._pending_graph_edge_states = snapshot.graph_edge_states or None
self._apply_pending_graph_state()
def _restore_ready_queue(self, payload: str | None) -> None:
if payload is not None:
@@ -498,3 +567,68 @@ class GraphRuntimeState:
self._pending_response_coordinator_dump = payload
self._response_coordinator = None
def _snapshot_graph_state(self) -> _GraphStateSnapshot:
graph = self._graph
if graph is None:
if self._pending_graph_node_states is None and self._pending_graph_edge_states is None:
return _GraphStateSnapshot()
return _GraphStateSnapshot(
nodes=self._pending_graph_node_states or {},
edges=self._pending_graph_edge_states or {},
)
nodes = graph.nodes
edges = graph.edges
if not isinstance(nodes, Mapping) or not isinstance(edges, Mapping):
return _GraphStateSnapshot()
node_states = {}
for node_id, node in nodes.items():
if not isinstance(node_id, str):
continue
node_states[node_id] = node.state
edge_states = {}
for edge_id, edge in edges.items():
if not isinstance(edge_id, str):
continue
edge_states[edge_id] = edge.state
return _GraphStateSnapshot(nodes=node_states, edges=edge_states)
def _apply_pending_graph_state(self) -> None:
if self._graph is None:
return
if self._pending_graph_node_states:
for node_id, state in self._pending_graph_node_states.items():
node = self._graph.nodes.get(node_id)
if node is None:
continue
node.state = state
if self._pending_graph_edge_states:
for edge_id, state in self._pending_graph_edge_states.items():
edge = self._graph.edges.get(edge_id)
if edge is None:
continue
edge.state = state
self._pending_graph_node_states = None
self._pending_graph_edge_states = None
def _coerce_graph_state_map(payload: Any, key: str) -> dict[str, NodeState]:
if not isinstance(payload, Mapping):
return {}
raw_map = payload.get(key, {})
if not isinstance(raw_map, Mapping):
return {}
result: dict[str, NodeState] = {}
for node_id, raw_state in raw_map.items():
if not isinstance(node_id, str):
continue
try:
result[node_id] = NodeState(str(raw_state))
except ValueError:
continue
return result

View File

@@ -15,12 +15,14 @@ class WorkflowRuntimeTypeConverter:
def to_json_encodable(self, value: None) -> None: ...
def to_json_encodable(self, value: Mapping[str, Any] | None) -> Mapping[str, Any] | None:
result = self._to_json_encodable_recursive(value)
"""Convert runtime values to JSON-serializable structures."""
result = self.value_to_json_encodable_recursive(value)
if isinstance(result, Mapping) or result is None:
return result
return {}
def _to_json_encodable_recursive(self, value: Any):
def value_to_json_encodable_recursive(self, value: Any):
if value is None:
return value
if isinstance(value, (bool, int, str, float)):
@@ -29,7 +31,7 @@ class WorkflowRuntimeTypeConverter:
# Convert Decimal to float for JSON serialization
return float(value)
if isinstance(value, Segment):
return self._to_json_encodable_recursive(value.value)
return self.value_to_json_encodable_recursive(value.value)
if isinstance(value, File):
return value.to_dict()
if isinstance(value, BaseModel):
@@ -37,11 +39,11 @@ class WorkflowRuntimeTypeConverter:
if isinstance(value, dict):
res = {}
for k, v in value.items():
res[k] = self._to_json_encodable_recursive(v)
res[k] = self.value_to_json_encodable_recursive(v)
return res
if isinstance(value, list):
res_list = []
for item in value:
res_list.append(self._to_json_encodable_recursive(item))
res_list.append(self.value_to_json_encodable_recursive(item))
return res_list
return value