mirror of
https://github.com/langgenius/dify.git
synced 2025-12-25 01:00:42 -05:00
Add workflow graph validation checks (#27106)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
@@ -1,3 +1,5 @@
|
||||
from ..runtime.graph_runtime_state import GraphRuntimeState
|
||||
from ..runtime.variable_pool import VariablePool
|
||||
from .agent import AgentNodeStrategyInit
|
||||
from .graph_init_params import GraphInitParams
|
||||
from .workflow_execution import WorkflowExecution
|
||||
@@ -6,6 +8,8 @@ from .workflow_node_execution import WorkflowNodeExecution
|
||||
__all__ = [
|
||||
"AgentNodeStrategyInit",
|
||||
"GraphInitParams",
|
||||
"GraphRuntimeState",
|
||||
"VariablePool",
|
||||
"WorkflowExecution",
|
||||
"WorkflowNodeExecution",
|
||||
]
|
||||
|
||||
@@ -3,11 +3,12 @@ from collections import defaultdict
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Protocol, cast, final
|
||||
|
||||
from core.workflow.enums import NodeExecutionType, NodeState, NodeType
|
||||
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeState, NodeType
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from libs.typing import is_str, is_str_dict
|
||||
|
||||
from .edge import Edge
|
||||
from .validation import get_graph_validator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -201,6 +202,17 @@ class Graph:
|
||||
|
||||
return GraphBuilder(graph_cls=cls)
|
||||
|
||||
@classmethod
|
||||
def _promote_fail_branch_nodes(cls, nodes: dict[str, Node]) -> None:
|
||||
"""
|
||||
Promote nodes configured with FAIL_BRANCH error strategy to branch execution type.
|
||||
|
||||
:param nodes: mapping of node ID to node instance
|
||||
"""
|
||||
for node in nodes.values():
|
||||
if node.error_strategy == ErrorStrategy.FAIL_BRANCH:
|
||||
node.execution_type = NodeExecutionType.BRANCH
|
||||
|
||||
@classmethod
|
||||
def _mark_inactive_root_branches(
|
||||
cls,
|
||||
@@ -307,6 +319,9 @@ class Graph:
|
||||
# Create node instances
|
||||
nodes = cls._create_node_instances(node_configs_map, node_factory)
|
||||
|
||||
# Promote fail-branch nodes to branch execution type at graph level
|
||||
cls._promote_fail_branch_nodes(nodes)
|
||||
|
||||
# Get root node instance
|
||||
root_node = nodes[root_node_id]
|
||||
|
||||
@@ -314,7 +329,7 @@ class Graph:
|
||||
cls._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, root_node_id)
|
||||
|
||||
# Create and return the graph
|
||||
return cls(
|
||||
graph = cls(
|
||||
nodes=nodes,
|
||||
edges=edges,
|
||||
in_edges=in_edges,
|
||||
@@ -322,6 +337,11 @@ class Graph:
|
||||
root_node=root_node,
|
||||
)
|
||||
|
||||
# Validate the graph structure using built-in validators
|
||||
get_graph_validator().validate(graph)
|
||||
|
||||
return graph
|
||||
|
||||
@property
|
||||
def node_ids(self) -> list[str]:
|
||||
"""
|
||||
|
||||
125
api/core/workflow/graph/validation.py
Normal file
125
api/core/workflow/graph/validation.py
Normal file
@@ -0,0 +1,125 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Protocol
|
||||
|
||||
from core.workflow.enums import NodeExecutionType, NodeType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .graph import Graph
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class GraphValidationIssue:
|
||||
"""Immutable value object describing a single validation issue."""
|
||||
|
||||
code: str
|
||||
message: str
|
||||
node_id: str | None = None
|
||||
|
||||
|
||||
class GraphValidationError(ValueError):
|
||||
"""Raised when graph validation fails."""
|
||||
|
||||
def __init__(self, issues: Sequence[GraphValidationIssue]) -> None:
|
||||
if not issues:
|
||||
raise ValueError("GraphValidationError requires at least one issue.")
|
||||
self.issues: tuple[GraphValidationIssue, ...] = tuple(issues)
|
||||
message = "; ".join(f"[{issue.code}] {issue.message}" for issue in self.issues)
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class GraphValidationRule(Protocol):
|
||||
"""Protocol that individual validation rules must satisfy."""
|
||||
|
||||
def validate(self, graph: Graph) -> Sequence[GraphValidationIssue]:
|
||||
"""Validate the provided graph and return any discovered issues."""
|
||||
...
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class _EdgeEndpointValidator:
|
||||
"""Ensures all edges reference existing nodes."""
|
||||
|
||||
missing_node_code: str = "MISSING_NODE"
|
||||
|
||||
def validate(self, graph: Graph) -> Sequence[GraphValidationIssue]:
|
||||
issues: list[GraphValidationIssue] = []
|
||||
for edge in graph.edges.values():
|
||||
if edge.tail not in graph.nodes:
|
||||
issues.append(
|
||||
GraphValidationIssue(
|
||||
code=self.missing_node_code,
|
||||
message=f"Edge {edge.id} references unknown source node '{edge.tail}'.",
|
||||
node_id=edge.tail,
|
||||
)
|
||||
)
|
||||
if edge.head not in graph.nodes:
|
||||
issues.append(
|
||||
GraphValidationIssue(
|
||||
code=self.missing_node_code,
|
||||
message=f"Edge {edge.id} references unknown target node '{edge.head}'.",
|
||||
node_id=edge.head,
|
||||
)
|
||||
)
|
||||
return issues
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class _RootNodeValidator:
|
||||
"""Validates root node invariants."""
|
||||
|
||||
invalid_root_code: str = "INVALID_ROOT"
|
||||
container_entry_types: tuple[NodeType, ...] = (NodeType.ITERATION_START, NodeType.LOOP_START)
|
||||
|
||||
def validate(self, graph: Graph) -> Sequence[GraphValidationIssue]:
|
||||
root_node = graph.root_node
|
||||
issues: list[GraphValidationIssue] = []
|
||||
if root_node.id not in graph.nodes:
|
||||
issues.append(
|
||||
GraphValidationIssue(
|
||||
code=self.invalid_root_code,
|
||||
message=f"Root node '{root_node.id}' is missing from the node registry.",
|
||||
node_id=root_node.id,
|
||||
)
|
||||
)
|
||||
return issues
|
||||
|
||||
node_type = getattr(root_node, "node_type", None)
|
||||
if root_node.execution_type != NodeExecutionType.ROOT and node_type not in self.container_entry_types:
|
||||
issues.append(
|
||||
GraphValidationIssue(
|
||||
code=self.invalid_root_code,
|
||||
message=f"Root node '{root_node.id}' must declare execution type 'root'.",
|
||||
node_id=root_node.id,
|
||||
)
|
||||
)
|
||||
return issues
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class GraphValidator:
|
||||
"""Coordinates execution of graph validation rules."""
|
||||
|
||||
rules: tuple[GraphValidationRule, ...]
|
||||
|
||||
def validate(self, graph: Graph) -> None:
|
||||
"""Validate the graph against all configured rules."""
|
||||
issues: list[GraphValidationIssue] = []
|
||||
for rule in self.rules:
|
||||
issues.extend(rule.validate(graph))
|
||||
|
||||
if issues:
|
||||
raise GraphValidationError(issues)
|
||||
|
||||
|
||||
_DEFAULT_RULES: tuple[GraphValidationRule, ...] = (
|
||||
_EdgeEndpointValidator(),
|
||||
_RootNodeValidator(),
|
||||
)
|
||||
|
||||
|
||||
def get_graph_validator() -> GraphValidator:
|
||||
"""Construct the validator composed of default rules."""
|
||||
return GraphValidator(_DEFAULT_RULES)
|
||||
@@ -2,7 +2,7 @@ from typing import TYPE_CHECKING, final
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType
|
||||
from core.workflow.enums import NodeType
|
||||
from core.workflow.graph import NodeFactory
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from libs.typing import is_str, is_str_dict
|
||||
@@ -82,8 +82,4 @@ class DifyNodeFactory(NodeFactory):
|
||||
raise ValueError(f"Node {node_id} missing data information")
|
||||
node_instance.init_node_data(node_data)
|
||||
|
||||
# If node has fail branch, change execution type to branch
|
||||
if node_instance.error_strategy == ErrorStrategy.FAIL_BRANCH:
|
||||
node_instance.execution_type = NodeExecutionType.BRANCH
|
||||
|
||||
return node_instance
|
||||
|
||||
Reference in New Issue
Block a user