feat(api): automatically NODE_TYPE_CLASSES_MAPPING generation from node class definitions (#28525)

This commit is contained in:
wangxiaolei
2025-12-01 14:14:19 +08:00
committed by GitHub
parent 2f8cb2a1af
commit d162f7e5ef
11 changed files with 245 additions and 189 deletions

View File

@@ -51,6 +51,7 @@ def initialize_extensions(app: DifyApp):
ext_commands,
ext_compress,
ext_database,
ext_forward_refs,
ext_hosting_provider,
ext_import_modules,
ext_logging,
@@ -75,6 +76,7 @@ def initialize_extensions(app: DifyApp):
ext_warnings,
ext_import_modules,
ext_orjson,
ext_forward_refs,
ext_set_secretkey,
ext_compress,
ext_code_based_extension,

View File

@@ -130,7 +130,7 @@ class AppGenerateEntity(BaseModel):
# extra parameters, like: auto_generate_conversation_name
extras: dict[str, Any] = Field(default_factory=dict)
# tracing instance
# tracing instance; use forward ref to avoid circular import at import time
trace_manager: Optional["TraceQueueManager"] = None
@@ -275,16 +275,23 @@ class RagPipelineGenerateEntity(WorkflowAppGenerateEntity):
start_node_id: str | None = None
# Import TraceQueueManager at runtime to resolve forward references
from core.ops.ops_trace_manager import TraceQueueManager
# NOTE: Avoid importing heavy tracing modules at import time to prevent circular imports.
# Forward reference to TraceQueueManager is kept as a string; we rebuild with a stub now to
# avoid Pydantic forward-ref errors in test contexts, and with the real class at app startup.
# Rebuild models that use forward references
AppGenerateEntity.model_rebuild()
EasyUIBasedAppGenerateEntity.model_rebuild()
ConversationAppGenerateEntity.model_rebuild()
ChatAppGenerateEntity.model_rebuild()
CompletionAppGenerateEntity.model_rebuild()
AgentChatAppGenerateEntity.model_rebuild()
AdvancedChatAppGenerateEntity.model_rebuild()
WorkflowAppGenerateEntity.model_rebuild()
RagPipelineGenerateEntity.model_rebuild()
# Minimal stub to satisfy Pydantic model_rebuild in environments where the real type is not importable yet.
class _TraceQueueManagerStub:
pass
_ns = {"TraceQueueManager": _TraceQueueManagerStub}
AppGenerateEntity.model_rebuild(_types_namespace=_ns)
EasyUIBasedAppGenerateEntity.model_rebuild(_types_namespace=_ns)
ConversationAppGenerateEntity.model_rebuild(_types_namespace=_ns)
ChatAppGenerateEntity.model_rebuild(_types_namespace=_ns)
CompletionAppGenerateEntity.model_rebuild(_types_namespace=_ns)
AgentChatAppGenerateEntity.model_rebuild(_types_namespace=_ns)
AdvancedChatAppGenerateEntity.model_rebuild(_types_namespace=_ns)
WorkflowAppGenerateEntity.model_rebuild(_types_namespace=_ns)
RagPipelineGenerateEntity.model_rebuild(_types_namespace=_ns)

View File

@@ -1,7 +1,11 @@
import importlib
import logging
import operator
import pkgutil
from abc import abstractmethod
from collections.abc import Generator, Mapping, Sequence
from functools import singledispatchmethod
from types import MappingProxyType
from typing import Any, ClassVar, Generic, TypeVar, cast, get_args, get_origin
from uuid import uuid4
@@ -134,6 +138,34 @@ class Node(Generic[NodeDataT]):
cls._node_data_type = node_data_type
# Skip base class itself
if cls is Node:
return
# Only register production node implementations defined under core.workflow.nodes.*
# This prevents test helper subclasses from polluting the global registry and
# accidentally overriding real node types (e.g., a test Answer node).
module_name = getattr(cls, "__module__", "")
# Only register concrete subclasses that define node_type and version()
node_type = cls.node_type
version = cls.version()
bucket = Node._registry.setdefault(node_type, {})
if module_name.startswith("core.workflow.nodes."):
# Production node definitions take precedence and may override
bucket[version] = cls # type: ignore[index]
else:
# External/test subclasses may register but must not override production
bucket.setdefault(version, cls) # type: ignore[index]
# Maintain a "latest" pointer preferring numeric versions; fallback to lexicographic
version_keys = [v for v in bucket if v != "latest"]
numeric_pairs: list[tuple[str, int]] = []
for v in version_keys:
numeric_pairs.append((v, int(v)))
if numeric_pairs:
latest_key = max(numeric_pairs, key=operator.itemgetter(1))[0]
else:
latest_key = max(version_keys) if version_keys else version
bucket["latest"] = bucket[latest_key]
@classmethod
def _extract_node_data_type_from_generic(cls) -> type[BaseNodeData] | None:
"""
@@ -165,6 +197,9 @@ class Node(Generic[NodeDataT]):
return None
# Global registry populated via __init_subclass__
_registry: ClassVar[dict["NodeType", dict[str, type["Node"]]]] = {}
def __init__(
self,
id: str,
@@ -395,6 +430,29 @@ class Node(Generic[NodeDataT]):
# in `api/core/workflow/nodes/__init__.py`.
raise NotImplementedError("subclasses of BaseNode must implement `version` method.")
@classmethod
def get_node_type_classes_mapping(cls) -> Mapping["NodeType", Mapping[str, type["Node"]]]:
"""Return mapping of NodeType -> {version -> Node subclass} using __init_subclass__ registry.
Import all modules under core.workflow.nodes so subclasses register themselves on import.
Then we return a readonly view of the registry to avoid accidental mutation.
"""
# Import all node modules to ensure they are loaded (thus registered)
import core.workflow.nodes as _nodes_pkg
for _, _modname, _ in pkgutil.walk_packages(_nodes_pkg.__path__, _nodes_pkg.__name__ + "."):
# Avoid importing modules that depend on the registry to prevent circular imports
# e.g. node_factory imports node_mapping which builds the mapping here.
if _modname in {
"core.workflow.nodes.node_factory",
"core.workflow.nodes.node_mapping",
}:
continue
importlib.import_module(_modname)
# Return a readonly view so callers can't mutate the registry by accident
return {nt: MappingProxyType(ver_map) for nt, ver_map in cls._registry.items()}
@property
def retry(self) -> bool:
return False

View File

@@ -1,165 +1,9 @@
from collections.abc import Mapping
from core.workflow.enums import NodeType
from core.workflow.nodes.agent.agent_node import AgentNode
from core.workflow.nodes.answer.answer_node import AnswerNode
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.code import CodeNode
from core.workflow.nodes.datasource.datasource_node import DatasourceNode
from core.workflow.nodes.document_extractor import DocumentExtractorNode
from core.workflow.nodes.end.end_node import EndNode
from core.workflow.nodes.http_request import HttpRequestNode
from core.workflow.nodes.human_input import HumanInputNode
from core.workflow.nodes.if_else import IfElseNode
from core.workflow.nodes.iteration import IterationNode, IterationStartNode
from core.workflow.nodes.knowledge_index import KnowledgeIndexNode
from core.workflow.nodes.knowledge_retrieval import KnowledgeRetrievalNode
from core.workflow.nodes.list_operator import ListOperatorNode
from core.workflow.nodes.llm import LLMNode
from core.workflow.nodes.loop import LoopEndNode, LoopNode, LoopStartNode
from core.workflow.nodes.parameter_extractor import ParameterExtractorNode
from core.workflow.nodes.question_classifier import QuestionClassifierNode
from core.workflow.nodes.start import StartNode
from core.workflow.nodes.template_transform import TemplateTransformNode
from core.workflow.nodes.tool import ToolNode
from core.workflow.nodes.trigger_plugin import TriggerEventNode
from core.workflow.nodes.trigger_schedule import TriggerScheduleNode
from core.workflow.nodes.trigger_webhook import TriggerWebhookNode
from core.workflow.nodes.variable_aggregator import VariableAggregatorNode
from core.workflow.nodes.variable_assigner.v1 import VariableAssignerNode as VariableAssignerNodeV1
from core.workflow.nodes.variable_assigner.v2 import VariableAssignerNode as VariableAssignerNodeV2
LATEST_VERSION = "latest"
# NOTE(QuantumGhost): This should be in sync with subclasses of BaseNode.
# Specifically, if you have introduced new node types, you should add them here.
#
# TODO(QuantumGhost): This could be automated with either metaclass or `__init_subclass__`
# hook. Try to avoid duplication of node information.
NODE_TYPE_CLASSES_MAPPING: Mapping[NodeType, Mapping[str, type[Node]]] = {
NodeType.START: {
LATEST_VERSION: StartNode,
"1": StartNode,
},
NodeType.END: {
LATEST_VERSION: EndNode,
"1": EndNode,
},
NodeType.ANSWER: {
LATEST_VERSION: AnswerNode,
"1": AnswerNode,
},
NodeType.LLM: {
LATEST_VERSION: LLMNode,
"1": LLMNode,
},
NodeType.KNOWLEDGE_RETRIEVAL: {
LATEST_VERSION: KnowledgeRetrievalNode,
"1": KnowledgeRetrievalNode,
},
NodeType.IF_ELSE: {
LATEST_VERSION: IfElseNode,
"1": IfElseNode,
},
NodeType.CODE: {
LATEST_VERSION: CodeNode,
"1": CodeNode,
},
NodeType.TEMPLATE_TRANSFORM: {
LATEST_VERSION: TemplateTransformNode,
"1": TemplateTransformNode,
},
NodeType.QUESTION_CLASSIFIER: {
LATEST_VERSION: QuestionClassifierNode,
"1": QuestionClassifierNode,
},
NodeType.HTTP_REQUEST: {
LATEST_VERSION: HttpRequestNode,
"1": HttpRequestNode,
},
NodeType.TOOL: {
LATEST_VERSION: ToolNode,
# This is an issue that caused problems before.
# Logically, we shouldn't use two different versions to point to the same class here,
# but in order to maintain compatibility with historical data, this approach has been retained.
"2": ToolNode,
"1": ToolNode,
},
NodeType.VARIABLE_AGGREGATOR: {
LATEST_VERSION: VariableAggregatorNode,
"1": VariableAggregatorNode,
},
NodeType.LEGACY_VARIABLE_AGGREGATOR: {
LATEST_VERSION: VariableAggregatorNode,
"1": VariableAggregatorNode,
}, # original name of VARIABLE_AGGREGATOR
NodeType.ITERATION: {
LATEST_VERSION: IterationNode,
"1": IterationNode,
},
NodeType.ITERATION_START: {
LATEST_VERSION: IterationStartNode,
"1": IterationStartNode,
},
NodeType.LOOP: {
LATEST_VERSION: LoopNode,
"1": LoopNode,
},
NodeType.LOOP_START: {
LATEST_VERSION: LoopStartNode,
"1": LoopStartNode,
},
NodeType.LOOP_END: {
LATEST_VERSION: LoopEndNode,
"1": LoopEndNode,
},
NodeType.PARAMETER_EXTRACTOR: {
LATEST_VERSION: ParameterExtractorNode,
"1": ParameterExtractorNode,
},
NodeType.VARIABLE_ASSIGNER: {
LATEST_VERSION: VariableAssignerNodeV2,
"1": VariableAssignerNodeV1,
"2": VariableAssignerNodeV2,
},
NodeType.DOCUMENT_EXTRACTOR: {
LATEST_VERSION: DocumentExtractorNode,
"1": DocumentExtractorNode,
},
NodeType.LIST_OPERATOR: {
LATEST_VERSION: ListOperatorNode,
"1": ListOperatorNode,
},
NodeType.AGENT: {
LATEST_VERSION: AgentNode,
# This is an issue that caused problems before.
# Logically, we shouldn't use two different versions to point to the same class here,
# but in order to maintain compatibility with historical data, this approach has been retained.
"2": AgentNode,
"1": AgentNode,
},
NodeType.HUMAN_INPUT: {
LATEST_VERSION: HumanInputNode,
"1": HumanInputNode,
},
NodeType.DATASOURCE: {
LATEST_VERSION: DatasourceNode,
"1": DatasourceNode,
},
NodeType.KNOWLEDGE_INDEX: {
LATEST_VERSION: KnowledgeIndexNode,
"1": KnowledgeIndexNode,
},
NodeType.TRIGGER_WEBHOOK: {
LATEST_VERSION: TriggerWebhookNode,
"1": TriggerWebhookNode,
},
NodeType.TRIGGER_PLUGIN: {
LATEST_VERSION: TriggerEventNode,
"1": TriggerEventNode,
},
NodeType.TRIGGER_SCHEDULE: {
LATEST_VERSION: TriggerScheduleNode,
"1": TriggerScheduleNode,
},
}
# Mapping is built by Node.get_node_type_classes_mapping(), which imports and walks core.workflow.nodes
NODE_TYPE_CLASSES_MAPPING: Mapping[NodeType, Mapping[str, type[Node]]] = Node.get_node_type_classes_mapping()

View File

@@ -12,7 +12,6 @@ from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
from core.tools.errors import ToolInvokeError
from core.tools.tool_engine import ToolEngine
from core.tools.utils.message_transformer import ToolFileMessageTransformer
from core.tools.workflow_as_tool.tool import WorkflowTool
from core.variables.segments import ArrayAnySegment, ArrayFileSegment
from core.variables.variables import ArrayAnyVariable
from core.workflow.enums import (
@@ -430,7 +429,7 @@ class ToolNode(Node[ToolNodeData]):
metadata: dict[WorkflowNodeExecutionMetadataKey, Any] = {
WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info,
}
if usage.total_tokens > 0:
if isinstance(usage.total_tokens, int) and usage.total_tokens > 0:
metadata[WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS] = usage.total_tokens
metadata[WorkflowNodeExecutionMetadataKey.TOTAL_PRICE] = usage.total_price
metadata[WorkflowNodeExecutionMetadataKey.CURRENCY] = usage.currency
@@ -449,8 +448,17 @@ class ToolNode(Node[ToolNodeData]):
@staticmethod
def _extract_tool_usage(tool_runtime: Tool) -> LLMUsage:
if isinstance(tool_runtime, WorkflowTool):
return tool_runtime.latest_usage
# Avoid importing WorkflowTool at module import time; rely on duck typing
# Some runtimes expose `latest_usage`; mocks may synthesize arbitrary attributes.
latest = getattr(tool_runtime, "latest_usage", None)
# Normalize into a concrete LLMUsage. MagicMock returns truthy attribute objects
# for any name, so we must type-check here.
if isinstance(latest, LLMUsage):
return latest
if isinstance(latest, dict):
# Allow dict payloads from external runtimes
return LLMUsage.model_validate(latest)
# Fallback to empty usage when attribute is missing or not a valid payload
return LLMUsage.empty_usage()
@classmethod

View File

@@ -0,0 +1,49 @@
import logging
from dify_app import DifyApp
def is_enabled() -> bool:
return True
def init_app(app: DifyApp):
"""Resolve Pydantic forward refs that would otherwise cause circular imports.
Rebuilds models in core.app.entities.app_invoke_entities with the real TraceQueueManager type.
Safe to run multiple times.
"""
logger = logging.getLogger(__name__)
try:
from core.app.entities.app_invoke_entities import (
AdvancedChatAppGenerateEntity,
AgentChatAppGenerateEntity,
AppGenerateEntity,
ChatAppGenerateEntity,
CompletionAppGenerateEntity,
ConversationAppGenerateEntity,
EasyUIBasedAppGenerateEntity,
RagPipelineGenerateEntity,
WorkflowAppGenerateEntity,
)
from core.ops.ops_trace_manager import TraceQueueManager # heavy import, do it at startup only
ns = {"TraceQueueManager": TraceQueueManager}
for Model in (
AppGenerateEntity,
EasyUIBasedAppGenerateEntity,
ConversationAppGenerateEntity,
ChatAppGenerateEntity,
CompletionAppGenerateEntity,
AgentChatAppGenerateEntity,
AdvancedChatAppGenerateEntity,
WorkflowAppGenerateEntity,
RagPipelineGenerateEntity,
):
try:
Model.model_rebuild(_types_namespace=ns)
except Exception as e:
logger.debug("model_rebuild skipped for %s: %s", Model.__name__, e)
except Exception as e:
# Don't block app startup; just log at debug level.
logger.debug("ext_forward_refs init skipped: %s", e)

View File

@@ -29,7 +29,7 @@ class _TestNode(Node[_TestNodeData]):
@classmethod
def version(cls) -> str:
return "test"
return "1"
def __init__(
self,

View File

@@ -92,7 +92,7 @@ class MockLLMNode(MockNodeMixin, LLMNode):
@classmethod
def version(cls) -> str:
"""Return the version of this mock node."""
return "mock-1"
return "1"
def _run(self) -> Generator:
"""Execute mock LLM node."""
@@ -189,7 +189,7 @@ class MockAgentNode(MockNodeMixin, AgentNode):
@classmethod
def version(cls) -> str:
"""Return the version of this mock node."""
return "mock-1"
return "1"
def _run(self) -> Generator:
"""Execute mock agent node."""
@@ -241,7 +241,7 @@ class MockToolNode(MockNodeMixin, ToolNode):
@classmethod
def version(cls) -> str:
"""Return the version of this mock node."""
return "mock-1"
return "1"
def _run(self) -> Generator:
"""Execute mock tool node."""
@@ -294,7 +294,7 @@ class MockKnowledgeRetrievalNode(MockNodeMixin, KnowledgeRetrievalNode):
@classmethod
def version(cls) -> str:
"""Return the version of this mock node."""
return "mock-1"
return "1"
def _run(self) -> Generator:
"""Execute mock knowledge retrieval node."""
@@ -351,7 +351,7 @@ class MockHttpRequestNode(MockNodeMixin, HttpRequestNode):
@classmethod
def version(cls) -> str:
"""Return the version of this mock node."""
return "mock-1"
return "1"
def _run(self) -> Generator:
"""Execute mock HTTP request node."""
@@ -404,7 +404,7 @@ class MockQuestionClassifierNode(MockNodeMixin, QuestionClassifierNode):
@classmethod
def version(cls) -> str:
"""Return the version of this mock node."""
return "mock-1"
return "1"
def _run(self) -> Generator:
"""Execute mock question classifier node."""
@@ -452,7 +452,7 @@ class MockParameterExtractorNode(MockNodeMixin, ParameterExtractorNode):
@classmethod
def version(cls) -> str:
"""Return the version of this mock node."""
return "mock-1"
return "1"
def _run(self) -> Generator:
"""Execute mock parameter extractor node."""
@@ -502,7 +502,7 @@ class MockDocumentExtractorNode(MockNodeMixin, DocumentExtractorNode):
@classmethod
def version(cls) -> str:
"""Return the version of this mock node."""
return "mock-1"
return "1"
def _run(self) -> Generator:
"""Execute mock document extractor node."""
@@ -557,7 +557,7 @@ class MockIterationNode(MockNodeMixin, IterationNode):
@classmethod
def version(cls) -> str:
"""Return the version of this mock node."""
return "mock-1"
return "1"
def _create_graph_engine(self, index: int, item: Any):
"""Create a graph engine with MockNodeFactory instead of DifyNodeFactory."""
@@ -632,7 +632,7 @@ class MockLoopNode(MockNodeMixin, LoopNode):
@classmethod
def version(cls) -> str:
"""Return the version of this mock node."""
return "mock-1"
return "1"
def _create_graph_engine(self, start_at, root_node_id: str):
"""Create a graph engine with MockNodeFactory instead of DifyNodeFactory."""
@@ -694,7 +694,7 @@ class MockTemplateTransformNode(MockNodeMixin, TemplateTransformNode):
@classmethod
def version(cls) -> str:
"""Return the version of this mock node."""
return "mock-1"
return "1"
def _run(self) -> NodeRunResult:
"""Execute mock template transform node."""
@@ -780,7 +780,7 @@ class MockCodeNode(MockNodeMixin, CodeNode):
@classmethod
def version(cls) -> str:
"""Return the version of this mock node."""
return "mock-1"
return "1"
def _run(self) -> NodeRunResult:
"""Execute mock code node."""

View File

@@ -33,6 +33,10 @@ def test_ensure_subclasses_of_base_node_has_node_type_and_version_method_defined
type_version_set: set[tuple[NodeType, str]] = set()
for cls in classes:
# Only validate production node classes; skip test-defined subclasses and external helpers
module_name = getattr(cls, "__module__", "")
if not module_name.startswith("core."):
continue
# Validate that 'version' is directly defined in the class (not inherited) by checking the class's __dict__
assert "version" in cls.__dict__, f"class {cls} should have version method defined (NOT INHERITED.)"
node_type = cls.node_type

View File

@@ -0,0 +1,84 @@
import types
from collections.abc import Mapping
from core.workflow.enums import NodeType
from core.workflow.nodes.base.entities import BaseNodeData
from core.workflow.nodes.base.node import Node
# Import concrete nodes we will assert on (numeric version path)
from core.workflow.nodes.variable_assigner.v1.node import (
VariableAssignerNode as VariableAssignerV1,
)
from core.workflow.nodes.variable_assigner.v2.node import (
VariableAssignerNode as VariableAssignerV2,
)
def test_variable_assigner_latest_prefers_highest_numeric_version():
# Act
mapping: Mapping[NodeType, Mapping[str, type[Node]]] = Node.get_node_type_classes_mapping()
# Assert basic presence
assert NodeType.VARIABLE_ASSIGNER in mapping
va_versions = mapping[NodeType.VARIABLE_ASSIGNER]
# Both concrete versions must be present
assert va_versions.get("1") is VariableAssignerV1
assert va_versions.get("2") is VariableAssignerV2
# And latest should point to numerically-highest version ("2")
assert va_versions.get("latest") is VariableAssignerV2
def test_latest_prefers_highest_numeric_version():
# Arrange: define two ephemeral subclasses with numeric versions under a NodeType
# that has no concrete implementations in production to avoid interference.
class _Version1(Node[BaseNodeData]): # type: ignore[misc]
node_type = NodeType.LEGACY_VARIABLE_AGGREGATOR
def init_node_data(self, data):
pass
def _run(self):
raise NotImplementedError
@classmethod
def version(cls) -> str:
return "1"
def _get_error_strategy(self):
return None
def _get_retry_config(self):
return types.SimpleNamespace() # not used
def _get_title(self) -> str:
return "version1"
def _get_description(self):
return None
def _get_default_value_dict(self):
return {}
def get_base_node_data(self):
return types.SimpleNamespace(title="version1")
class _Version2(_Version1): # type: ignore[misc]
@classmethod
def version(cls) -> str:
return "2"
def _get_title(self) -> str:
return "version2"
# Act: build a fresh mapping (it should now see our ephemeral subclasses)
mapping: Mapping[NodeType, Mapping[str, type[Node]]] = Node.get_node_type_classes_mapping()
# Assert: both numeric versions exist for this NodeType; 'latest' points to the higher numeric version
assert NodeType.LEGACY_VARIABLE_AGGREGATOR in mapping
legacy_versions = mapping[NodeType.LEGACY_VARIABLE_AGGREGATOR]
assert legacy_versions.get("1") is _Version1
assert legacy_versions.get("2") is _Version2
assert legacy_versions.get("latest") is _Version2

View File

@@ -19,7 +19,7 @@ class _SampleNode(Node[_SampleNodeData]):
@classmethod
def version(cls) -> str:
return "sample-test"
return "1"
def _run(self):
raise NotImplementedError