mirror of
https://github.com/langgenius/dify.git
synced 2026-02-15 04:01:30 -05:00
feat: Human Input Node (#32060)
The frontend and backend implementation for the human input node. Co-authored-by: twwu <twwu@dify.ai> Co-authored-by: JzoNg <jzongcode@gmail.com> Co-authored-by: yyh <92089059+lyzno1@users.noreply.github.com> Co-authored-by: zhsama <torvalds@linux.do>
This commit is contained in:
@@ -6,15 +6,18 @@ import threading
|
||||
from collections.abc import Mapping, Sequence
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, ClassVar, Protocol
|
||||
from typing import TYPE_CHECKING, Any, ClassVar, Protocol
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic.json import pydantic_encoder
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.workflow.entities.pause_reason import PauseReason
|
||||
from core.workflow.enums import NodeExecutionType, NodeState, NodeType
|
||||
from core.workflow.runtime.variable_pool import VariablePool
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.workflow.entities.pause_reason import PauseReason
|
||||
|
||||
|
||||
class ReadyQueueProtocol(Protocol):
|
||||
"""Structural interface required from ready queue implementations."""
|
||||
@@ -133,6 +136,13 @@ class GraphProtocol(Protocol):
|
||||
def get_outgoing_edges(self, node_id: str) -> Sequence[EdgeProtocol]: ...
|
||||
|
||||
|
||||
class _GraphStateSnapshot(BaseModel):
|
||||
"""Serializable graph state snapshot for node/edge states."""
|
||||
|
||||
nodes: dict[str, NodeState] = Field(default_factory=dict)
|
||||
edges: dict[str, NodeState] = Field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class _GraphRuntimeStateSnapshot:
|
||||
"""Immutable view of a serialized runtime state snapshot."""
|
||||
@@ -148,10 +158,20 @@ class _GraphRuntimeStateSnapshot:
|
||||
graph_execution_dump: str | None
|
||||
response_coordinator_dump: str | None
|
||||
paused_nodes: tuple[str, ...]
|
||||
deferred_nodes: tuple[str, ...]
|
||||
graph_node_states: dict[str, NodeState]
|
||||
graph_edge_states: dict[str, NodeState]
|
||||
|
||||
|
||||
class GraphRuntimeState:
|
||||
"""Mutable runtime state shared across graph execution components."""
|
||||
"""Mutable runtime state shared across graph execution components.
|
||||
|
||||
`GraphRuntimeState` encapsulates the runtime state of workflow execution,
|
||||
including scheduling details, variable values, and timing information.
|
||||
|
||||
Values that are initialized prior to workflow execution and remain constant
|
||||
throughout the execution should be part of `GraphInitParams` instead.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -189,6 +209,16 @@ class GraphRuntimeState:
|
||||
self._pending_response_coordinator_dump: str | None = None
|
||||
self._pending_graph_execution_workflow_id: str | None = None
|
||||
self._paused_nodes: set[str] = set()
|
||||
self._deferred_nodes: set[str] = set()
|
||||
|
||||
# Node and edges states needed to be restored into
|
||||
# graph object.
|
||||
#
|
||||
# These two fields are non-None only when resuming from a snapshot.
|
||||
# Once the graph is attached, these two fields will be set to None.
|
||||
self._pending_graph_node_states: dict[str, NodeState] | None = None
|
||||
self._pending_graph_edge_states: dict[str, NodeState] | None = None
|
||||
|
||||
self.stop_event: threading.Event = threading.Event()
|
||||
|
||||
if graph is not None:
|
||||
@@ -210,6 +240,7 @@ class GraphRuntimeState:
|
||||
if self._pending_response_coordinator_dump is not None and self._response_coordinator is not None:
|
||||
self._response_coordinator.loads(self._pending_response_coordinator_dump)
|
||||
self._pending_response_coordinator_dump = None
|
||||
self._apply_pending_graph_state()
|
||||
|
||||
def configure(self, *, graph: GraphProtocol | None = None) -> None:
|
||||
"""Ensure core collaborators are initialized with the provided context."""
|
||||
@@ -331,8 +362,13 @@ class GraphRuntimeState:
|
||||
"ready_queue": self.ready_queue.dumps(),
|
||||
"graph_execution": self.graph_execution.dumps(),
|
||||
"paused_nodes": list(self._paused_nodes),
|
||||
"deferred_nodes": list(self._deferred_nodes),
|
||||
}
|
||||
|
||||
graph_state = self._snapshot_graph_state()
|
||||
if graph_state is not None:
|
||||
snapshot["graph_state"] = graph_state
|
||||
|
||||
if self._response_coordinator is not None and self._graph is not None:
|
||||
snapshot["response_coordinator"] = self._response_coordinator.dumps()
|
||||
|
||||
@@ -366,6 +402,11 @@ class GraphRuntimeState:
|
||||
|
||||
self._paused_nodes.add(node_id)
|
||||
|
||||
def get_paused_nodes(self) -> list[str]:
|
||||
"""Retrieve the list of paused nodes without mutating internal state."""
|
||||
|
||||
return list(self._paused_nodes)
|
||||
|
||||
def consume_paused_nodes(self) -> list[str]:
|
||||
"""Retrieve and clear the list of paused nodes awaiting resume."""
|
||||
|
||||
@@ -373,6 +414,23 @@ class GraphRuntimeState:
|
||||
self._paused_nodes.clear()
|
||||
return nodes
|
||||
|
||||
def register_deferred_node(self, node_id: str) -> None:
|
||||
"""Record a node that became ready during pause and should resume later."""
|
||||
|
||||
self._deferred_nodes.add(node_id)
|
||||
|
||||
def get_deferred_nodes(self) -> list[str]:
|
||||
"""Retrieve deferred nodes without mutating internal state."""
|
||||
|
||||
return list(self._deferred_nodes)
|
||||
|
||||
def consume_deferred_nodes(self) -> list[str]:
|
||||
"""Retrieve and clear deferred nodes awaiting resume."""
|
||||
|
||||
nodes = list(self._deferred_nodes)
|
||||
self._deferred_nodes.clear()
|
||||
return nodes
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Builders
|
||||
# ------------------------------------------------------------------
|
||||
@@ -388,7 +446,7 @@ class GraphRuntimeState:
|
||||
graph_execution_cls = module.GraphExecution
|
||||
workflow_id = self._pending_graph_execution_workflow_id or ""
|
||||
self._pending_graph_execution_workflow_id = None
|
||||
return graph_execution_cls(workflow_id=workflow_id)
|
||||
return graph_execution_cls(workflow_id=workflow_id) # type: ignore[invalid-return-type]
|
||||
|
||||
def _build_response_coordinator(self, graph: GraphProtocol) -> ResponseStreamCoordinatorProtocol:
|
||||
# Lazily import to keep the runtime domain decoupled from graph_engine modules.
|
||||
@@ -434,6 +492,10 @@ class GraphRuntimeState:
|
||||
graph_execution_payload = payload.get("graph_execution")
|
||||
response_payload = payload.get("response_coordinator")
|
||||
paused_nodes_payload = payload.get("paused_nodes", [])
|
||||
deferred_nodes_payload = payload.get("deferred_nodes", [])
|
||||
graph_state_payload = payload.get("graph_state", {}) or {}
|
||||
graph_node_states = _coerce_graph_state_map(graph_state_payload, "nodes")
|
||||
graph_edge_states = _coerce_graph_state_map(graph_state_payload, "edges")
|
||||
|
||||
return _GraphRuntimeStateSnapshot(
|
||||
start_at=start_at,
|
||||
@@ -447,6 +509,9 @@ class GraphRuntimeState:
|
||||
graph_execution_dump=graph_execution_payload,
|
||||
response_coordinator_dump=response_payload,
|
||||
paused_nodes=tuple(map(str, paused_nodes_payload)),
|
||||
deferred_nodes=tuple(map(str, deferred_nodes_payload)),
|
||||
graph_node_states=graph_node_states,
|
||||
graph_edge_states=graph_edge_states,
|
||||
)
|
||||
|
||||
def _apply_snapshot(self, snapshot: _GraphRuntimeStateSnapshot) -> None:
|
||||
@@ -462,6 +527,10 @@ class GraphRuntimeState:
|
||||
self._restore_graph_execution(snapshot.graph_execution_dump)
|
||||
self._restore_response_coordinator(snapshot.response_coordinator_dump)
|
||||
self._paused_nodes = set(snapshot.paused_nodes)
|
||||
self._deferred_nodes = set(snapshot.deferred_nodes)
|
||||
self._pending_graph_node_states = snapshot.graph_node_states or None
|
||||
self._pending_graph_edge_states = snapshot.graph_edge_states or None
|
||||
self._apply_pending_graph_state()
|
||||
|
||||
def _restore_ready_queue(self, payload: str | None) -> None:
|
||||
if payload is not None:
|
||||
@@ -498,3 +567,68 @@ class GraphRuntimeState:
|
||||
|
||||
self._pending_response_coordinator_dump = payload
|
||||
self._response_coordinator = None
|
||||
|
||||
def _snapshot_graph_state(self) -> _GraphStateSnapshot:
|
||||
graph = self._graph
|
||||
if graph is None:
|
||||
if self._pending_graph_node_states is None and self._pending_graph_edge_states is None:
|
||||
return _GraphStateSnapshot()
|
||||
return _GraphStateSnapshot(
|
||||
nodes=self._pending_graph_node_states or {},
|
||||
edges=self._pending_graph_edge_states or {},
|
||||
)
|
||||
|
||||
nodes = graph.nodes
|
||||
edges = graph.edges
|
||||
if not isinstance(nodes, Mapping) or not isinstance(edges, Mapping):
|
||||
return _GraphStateSnapshot()
|
||||
|
||||
node_states = {}
|
||||
for node_id, node in nodes.items():
|
||||
if not isinstance(node_id, str):
|
||||
continue
|
||||
node_states[node_id] = node.state
|
||||
|
||||
edge_states = {}
|
||||
for edge_id, edge in edges.items():
|
||||
if not isinstance(edge_id, str):
|
||||
continue
|
||||
edge_states[edge_id] = edge.state
|
||||
|
||||
return _GraphStateSnapshot(nodes=node_states, edges=edge_states)
|
||||
|
||||
def _apply_pending_graph_state(self) -> None:
|
||||
if self._graph is None:
|
||||
return
|
||||
if self._pending_graph_node_states:
|
||||
for node_id, state in self._pending_graph_node_states.items():
|
||||
node = self._graph.nodes.get(node_id)
|
||||
if node is None:
|
||||
continue
|
||||
node.state = state
|
||||
if self._pending_graph_edge_states:
|
||||
for edge_id, state in self._pending_graph_edge_states.items():
|
||||
edge = self._graph.edges.get(edge_id)
|
||||
if edge is None:
|
||||
continue
|
||||
edge.state = state
|
||||
|
||||
self._pending_graph_node_states = None
|
||||
self._pending_graph_edge_states = None
|
||||
|
||||
|
||||
def _coerce_graph_state_map(payload: Any, key: str) -> dict[str, NodeState]:
|
||||
if not isinstance(payload, Mapping):
|
||||
return {}
|
||||
raw_map = payload.get(key, {})
|
||||
if not isinstance(raw_map, Mapping):
|
||||
return {}
|
||||
result: dict[str, NodeState] = {}
|
||||
for node_id, raw_state in raw_map.items():
|
||||
if not isinstance(node_id, str):
|
||||
continue
|
||||
try:
|
||||
result[node_id] = NodeState(str(raw_state))
|
||||
except ValueError:
|
||||
continue
|
||||
return result
|
||||
|
||||
Reference in New Issue
Block a user