mirror of
https://github.com/langgenius/dify.git
synced 2026-03-28 05:00:59 -04:00
97 lines
2.9 KiB
Python
97 lines
2.9 KiB
Python
"""
|
|
Skip state propagation through the graph.
|
|
"""
|
|
|
|
from collections.abc import Sequence
|
|
from typing import final
|
|
|
|
from graphon.graph import Edge, Graph
|
|
|
|
from ..graph_state_manager import GraphStateManager
|
|
|
|
|
|
@final
class SkipPropagator:
    """
    Propagates skip states through the graph.

    When a node is skipped, this ensures all downstream nodes
    that depend solely on it are also skipped.

    Propagation uses an explicit stack instead of recursion, so long chains
    of skipped nodes cannot exhaust Python's recursion limit. The walk is
    depth-first and visits edges in the same order a recursive walk would.
    Each outgoing edge is marked skipped and fully propagated before the
    next sibling edge is marked. That ordering matters: when two edges lead
    to the same node, the node is only skipped once, after its last
    incoming edge is decided.
    """

    def __init__(
        self,
        graph: Graph,
        state_manager: GraphStateManager,
    ) -> None:
        """
        Initialize the skip propagator.

        Args:
            graph: The workflow graph
            state_manager: Unified state manager
        """
        self._graph = graph
        self._state_manager = state_manager

    def propagate_skip_from_edge(self, edge_id: str) -> None:
        """
        Propagate skip state from a skipped edge to everything downstream.

        The edge's head (downstream) node is evaluated against all of its
        incoming edges:

        - If any incoming edge is UNKNOWN, stop processing.
        - Otherwise, if any incoming edge is TAKEN, enqueue the node and
          mark its execution as started.
        - Otherwise, if all incoming edges are SKIPPED, skip the node and all
          of its outgoing edges, and keep propagating past them.

        Args:
            edge_id: The ID of the skipped edge to start from
        """
        node_id = self._resolve_downstream_node(edge_id)
        if node_id is not None:
            self._propagate_skip_to_node(node_id)

    def _resolve_downstream_node(self, edge_id: str) -> str | None:
        """
        Evaluate the head node of ``edge_id`` after that edge was skipped.

        If the node has a TAKEN incoming edge (and no UNKNOWN ones), it is
        enqueued and its execution is started here.

        Args:
            edge_id: The ID of the edge whose head node should be evaluated

        Returns:
            The head node's ID when every incoming edge is SKIPPED (the
            caller is responsible for skipping it), otherwise ``None``.
        """
        downstream_node_id = self._graph.edges[edge_id].head
        incoming_edges = self._graph.get_incoming_edges(downstream_node_id)
        edge_states = self._state_manager.analyze_edge_states(incoming_edges)

        # Not every incoming edge is decided yet; leave the node alone.
        if edge_states["has_unknown"]:
            return None

        # At least one incoming path was taken, so the node still runs.
        if edge_states["has_taken"]:
            self._state_manager.enqueue_node(downstream_node_id)
            self._state_manager.start_execution(downstream_node_id)
            return None

        if edge_states["all_skipped"]:
            return downstream_node_id
        return None

    def _propagate_skip_to_node(self, node_id: str) -> None:
        """
        Mark a node and all its outgoing edges as skipped, transitively.

        Args:
            node_id: The ID of the node to skip
        """
        self._state_manager.mark_node_skipped(node_id)

        # Each stack entry is an iterator over one skipped node's outgoing
        # edges. It resumes where it left off, exactly like the loop in a
        # suspended recursive frame would.
        pending = [iter(self._graph.get_outgoing_edges(node_id))]
        while pending:
            edge = next(pending[-1], None)
            if edge is None:
                # All outgoing edges of this node are handled; pop back up.
                pending.pop()
                continue

            # Mark this single edge, then fully evaluate its head before the
            # next sibling edge is marked (see class docstring for why).
            self._state_manager.mark_edge_skipped(edge.id)
            next_node_id = self._resolve_downstream_node(edge.id)
            if next_node_id is not None:
                self._state_manager.mark_node_skipped(next_node_id)
                pending.append(iter(self._graph.get_outgoing_edges(next_node_id)))

    def skip_branch_paths(self, unselected_edges: Sequence[Edge]) -> None:
        """
        Skip all paths from unselected branch edges.

        Each edge is marked skipped and propagated before the next one is
        marked.

        Args:
            unselected_edges: List of edges not taken by the branch
        """
        for edge in unselected_edges:
            self._state_manager.mark_edge_skipped(edge.id)
            self.propagate_skip_from_edge(edge.id)
|