Files
dify/api/graphon/graph_engine/graph_traversal/skip_propagator.py

97 lines
2.9 KiB
Python

"""
Skip state propagation through the graph.
"""
from collections.abc import Sequence
from typing import final
from graphon.graph import Edge, Graph
from ..graph_state_manager import GraphStateManager
@final
class SkipPropagator:
"""
Propagates skip states through the graph.
When a node is skipped, this ensures all downstream nodes
that depend solely on it are also skipped.
"""
def __init__(
self,
graph: Graph,
state_manager: GraphStateManager,
) -> None:
"""
Initialize the skip propagator.
Args:
graph: The workflow graph
state_manager: Unified state manager
"""
self._graph = graph
self._state_manager = state_manager
def propagate_skip_from_edge(self, edge_id: str) -> None:
"""
Recursively propagate skip state from a skipped edge.
Rules:
- If a node has any UNKNOWN incoming edges, stop processing
- If all incoming edges are SKIPPED, skip the node and its edges
- If any incoming edge is TAKEN, the node may still execute
Args:
edge_id: The ID of the skipped edge to start from
"""
downstream_node_id = self._graph.edges[edge_id].head
incoming_edges = self._graph.get_incoming_edges(downstream_node_id)
# Analyze edge states
edge_states = self._state_manager.analyze_edge_states(incoming_edges)
# Stop if there are unknown edges (not yet processed)
if edge_states["has_unknown"]:
return
# If any edge is taken, node may still execute
if edge_states["has_taken"]:
# Enqueue node
self._state_manager.enqueue_node(downstream_node_id)
self._state_manager.start_execution(downstream_node_id)
return
# All edges are skipped, propagate skip to this node
if edge_states["all_skipped"]:
self._propagate_skip_to_node(downstream_node_id)
def _propagate_skip_to_node(self, node_id: str) -> None:
"""
Mark a node and all its outgoing edges as skipped.
Args:
node_id: The ID of the node to skip
"""
# Mark node as skipped
self._state_manager.mark_node_skipped(node_id)
# Mark all outgoing edges as skipped and propagate
outgoing_edges = self._graph.get_outgoing_edges(node_id)
for edge in outgoing_edges:
self._state_manager.mark_edge_skipped(edge.id)
# Recursively propagate skip
self.propagate_skip_from_edge(edge.id)
def skip_branch_paths(self, unselected_edges: Sequence[Edge]) -> None:
"""
Skip all paths from unselected branch edges.
Args:
unselected_edges: List of edges not taken by the branch
"""
for edge in unselected_edges:
self._state_manager.mark_edge_skipped(edge.id)
self.propagate_skip_from_edge(edge.id)