mirror of
https://github.com/langgenius/dify.git
synced 2025-12-25 01:00:42 -05:00
feat(api): automatically NODE_TYPE_CLASSES_MAPPING generation from node class definitions (#28525)
This commit is contained in:
@@ -51,6 +51,7 @@ def initialize_extensions(app: DifyApp):
|
||||
ext_commands,
|
||||
ext_compress,
|
||||
ext_database,
|
||||
ext_forward_refs,
|
||||
ext_hosting_provider,
|
||||
ext_import_modules,
|
||||
ext_logging,
|
||||
@@ -75,6 +76,7 @@ def initialize_extensions(app: DifyApp):
|
||||
ext_warnings,
|
||||
ext_import_modules,
|
||||
ext_orjson,
|
||||
ext_forward_refs,
|
||||
ext_set_secretkey,
|
||||
ext_compress,
|
||||
ext_code_based_extension,
|
||||
|
||||
@@ -130,7 +130,7 @@ class AppGenerateEntity(BaseModel):
|
||||
# extra parameters, like: auto_generate_conversation_name
|
||||
extras: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
# tracing instance
|
||||
# tracing instance; use forward ref to avoid circular import at import time
|
||||
trace_manager: Optional["TraceQueueManager"] = None
|
||||
|
||||
|
||||
@@ -275,16 +275,23 @@ class RagPipelineGenerateEntity(WorkflowAppGenerateEntity):
|
||||
start_node_id: str | None = None
|
||||
|
||||
|
||||
# Import TraceQueueManager at runtime to resolve forward references
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
# NOTE: Avoid importing heavy tracing modules at import time to prevent circular imports.
|
||||
# Forward reference to TraceQueueManager is kept as a string; we rebuild with a stub now to
|
||||
# avoid Pydantic forward-ref errors in test contexts, and with the real class at app startup.
|
||||
|
||||
# Rebuild models that use forward references
|
||||
AppGenerateEntity.model_rebuild()
|
||||
EasyUIBasedAppGenerateEntity.model_rebuild()
|
||||
ConversationAppGenerateEntity.model_rebuild()
|
||||
ChatAppGenerateEntity.model_rebuild()
|
||||
CompletionAppGenerateEntity.model_rebuild()
|
||||
AgentChatAppGenerateEntity.model_rebuild()
|
||||
AdvancedChatAppGenerateEntity.model_rebuild()
|
||||
WorkflowAppGenerateEntity.model_rebuild()
|
||||
RagPipelineGenerateEntity.model_rebuild()
|
||||
|
||||
# Minimal stub to satisfy Pydantic model_rebuild in environments where the real type is not importable yet.
|
||||
class _TraceQueueManagerStub:
|
||||
pass
|
||||
|
||||
|
||||
_ns = {"TraceQueueManager": _TraceQueueManagerStub}
|
||||
AppGenerateEntity.model_rebuild(_types_namespace=_ns)
|
||||
EasyUIBasedAppGenerateEntity.model_rebuild(_types_namespace=_ns)
|
||||
ConversationAppGenerateEntity.model_rebuild(_types_namespace=_ns)
|
||||
ChatAppGenerateEntity.model_rebuild(_types_namespace=_ns)
|
||||
CompletionAppGenerateEntity.model_rebuild(_types_namespace=_ns)
|
||||
AgentChatAppGenerateEntity.model_rebuild(_types_namespace=_ns)
|
||||
AdvancedChatAppGenerateEntity.model_rebuild(_types_namespace=_ns)
|
||||
WorkflowAppGenerateEntity.model_rebuild(_types_namespace=_ns)
|
||||
RagPipelineGenerateEntity.model_rebuild(_types_namespace=_ns)
|
||||
|
||||
@@ -1,7 +1,11 @@
|
||||
import importlib
|
||||
import logging
|
||||
import operator
|
||||
import pkgutil
|
||||
from abc import abstractmethod
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from functools import singledispatchmethod
|
||||
from types import MappingProxyType
|
||||
from typing import Any, ClassVar, Generic, TypeVar, cast, get_args, get_origin
|
||||
from uuid import uuid4
|
||||
|
||||
@@ -134,6 +138,34 @@ class Node(Generic[NodeDataT]):
|
||||
|
||||
cls._node_data_type = node_data_type
|
||||
|
||||
# Skip base class itself
|
||||
if cls is Node:
|
||||
return
|
||||
# Only register production node implementations defined under core.workflow.nodes.*
|
||||
# This prevents test helper subclasses from polluting the global registry and
|
||||
# accidentally overriding real node types (e.g., a test Answer node).
|
||||
module_name = getattr(cls, "__module__", "")
|
||||
# Only register concrete subclasses that define node_type and version()
|
||||
node_type = cls.node_type
|
||||
version = cls.version()
|
||||
bucket = Node._registry.setdefault(node_type, {})
|
||||
if module_name.startswith("core.workflow.nodes."):
|
||||
# Production node definitions take precedence and may override
|
||||
bucket[version] = cls # type: ignore[index]
|
||||
else:
|
||||
# External/test subclasses may register but must not override production
|
||||
bucket.setdefault(version, cls) # type: ignore[index]
|
||||
# Maintain a "latest" pointer preferring numeric versions; fallback to lexicographic
|
||||
version_keys = [v for v in bucket if v != "latest"]
|
||||
numeric_pairs: list[tuple[str, int]] = []
|
||||
for v in version_keys:
|
||||
numeric_pairs.append((v, int(v)))
|
||||
if numeric_pairs:
|
||||
latest_key = max(numeric_pairs, key=operator.itemgetter(1))[0]
|
||||
else:
|
||||
latest_key = max(version_keys) if version_keys else version
|
||||
bucket["latest"] = bucket[latest_key]
|
||||
|
||||
@classmethod
|
||||
def _extract_node_data_type_from_generic(cls) -> type[BaseNodeData] | None:
|
||||
"""
|
||||
@@ -165,6 +197,9 @@ class Node(Generic[NodeDataT]):
|
||||
|
||||
return None
|
||||
|
||||
# Global registry populated via __init_subclass__
|
||||
_registry: ClassVar[dict["NodeType", dict[str, type["Node"]]]] = {}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
id: str,
|
||||
@@ -395,6 +430,29 @@ class Node(Generic[NodeDataT]):
|
||||
# in `api/core/workflow/nodes/__init__.py`.
|
||||
raise NotImplementedError("subclasses of BaseNode must implement `version` method.")
|
||||
|
||||
@classmethod
|
||||
def get_node_type_classes_mapping(cls) -> Mapping["NodeType", Mapping[str, type["Node"]]]:
|
||||
"""Return mapping of NodeType -> {version -> Node subclass} using __init_subclass__ registry.
|
||||
|
||||
Import all modules under core.workflow.nodes so subclasses register themselves on import.
|
||||
Then we return a readonly view of the registry to avoid accidental mutation.
|
||||
"""
|
||||
# Import all node modules to ensure they are loaded (thus registered)
|
||||
import core.workflow.nodes as _nodes_pkg
|
||||
|
||||
for _, _modname, _ in pkgutil.walk_packages(_nodes_pkg.__path__, _nodes_pkg.__name__ + "."):
|
||||
# Avoid importing modules that depend on the registry to prevent circular imports
|
||||
# e.g. node_factory imports node_mapping which builds the mapping here.
|
||||
if _modname in {
|
||||
"core.workflow.nodes.node_factory",
|
||||
"core.workflow.nodes.node_mapping",
|
||||
}:
|
||||
continue
|
||||
importlib.import_module(_modname)
|
||||
|
||||
# Return a readonly view so callers can't mutate the registry by accident
|
||||
return {nt: MappingProxyType(ver_map) for nt, ver_map in cls._registry.items()}
|
||||
|
||||
@property
|
||||
def retry(self) -> bool:
|
||||
return False
|
||||
|
||||
@@ -1,165 +1,9 @@
|
||||
from collections.abc import Mapping
|
||||
|
||||
from core.workflow.enums import NodeType
|
||||
from core.workflow.nodes.agent.agent_node import AgentNode
|
||||
from core.workflow.nodes.answer.answer_node import AnswerNode
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.code import CodeNode
|
||||
from core.workflow.nodes.datasource.datasource_node import DatasourceNode
|
||||
from core.workflow.nodes.document_extractor import DocumentExtractorNode
|
||||
from core.workflow.nodes.end.end_node import EndNode
|
||||
from core.workflow.nodes.http_request import HttpRequestNode
|
||||
from core.workflow.nodes.human_input import HumanInputNode
|
||||
from core.workflow.nodes.if_else import IfElseNode
|
||||
from core.workflow.nodes.iteration import IterationNode, IterationStartNode
|
||||
from core.workflow.nodes.knowledge_index import KnowledgeIndexNode
|
||||
from core.workflow.nodes.knowledge_retrieval import KnowledgeRetrievalNode
|
||||
from core.workflow.nodes.list_operator import ListOperatorNode
|
||||
from core.workflow.nodes.llm import LLMNode
|
||||
from core.workflow.nodes.loop import LoopEndNode, LoopNode, LoopStartNode
|
||||
from core.workflow.nodes.parameter_extractor import ParameterExtractorNode
|
||||
from core.workflow.nodes.question_classifier import QuestionClassifierNode
|
||||
from core.workflow.nodes.start import StartNode
|
||||
from core.workflow.nodes.template_transform import TemplateTransformNode
|
||||
from core.workflow.nodes.tool import ToolNode
|
||||
from core.workflow.nodes.trigger_plugin import TriggerEventNode
|
||||
from core.workflow.nodes.trigger_schedule import TriggerScheduleNode
|
||||
from core.workflow.nodes.trigger_webhook import TriggerWebhookNode
|
||||
from core.workflow.nodes.variable_aggregator import VariableAggregatorNode
|
||||
from core.workflow.nodes.variable_assigner.v1 import VariableAssignerNode as VariableAssignerNodeV1
|
||||
from core.workflow.nodes.variable_assigner.v2 import VariableAssignerNode as VariableAssignerNodeV2
|
||||
|
||||
LATEST_VERSION = "latest"
|
||||
|
||||
# NOTE(QuantumGhost): This should be in sync with subclasses of BaseNode.
|
||||
# Specifically, if you have introduced new node types, you should add them here.
|
||||
#
|
||||
# TODO(QuantumGhost): This could be automated with either metaclass or `__init_subclass__`
|
||||
# hook. Try to avoid duplication of node information.
|
||||
NODE_TYPE_CLASSES_MAPPING: Mapping[NodeType, Mapping[str, type[Node]]] = {
|
||||
NodeType.START: {
|
||||
LATEST_VERSION: StartNode,
|
||||
"1": StartNode,
|
||||
},
|
||||
NodeType.END: {
|
||||
LATEST_VERSION: EndNode,
|
||||
"1": EndNode,
|
||||
},
|
||||
NodeType.ANSWER: {
|
||||
LATEST_VERSION: AnswerNode,
|
||||
"1": AnswerNode,
|
||||
},
|
||||
NodeType.LLM: {
|
||||
LATEST_VERSION: LLMNode,
|
||||
"1": LLMNode,
|
||||
},
|
||||
NodeType.KNOWLEDGE_RETRIEVAL: {
|
||||
LATEST_VERSION: KnowledgeRetrievalNode,
|
||||
"1": KnowledgeRetrievalNode,
|
||||
},
|
||||
NodeType.IF_ELSE: {
|
||||
LATEST_VERSION: IfElseNode,
|
||||
"1": IfElseNode,
|
||||
},
|
||||
NodeType.CODE: {
|
||||
LATEST_VERSION: CodeNode,
|
||||
"1": CodeNode,
|
||||
},
|
||||
NodeType.TEMPLATE_TRANSFORM: {
|
||||
LATEST_VERSION: TemplateTransformNode,
|
||||
"1": TemplateTransformNode,
|
||||
},
|
||||
NodeType.QUESTION_CLASSIFIER: {
|
||||
LATEST_VERSION: QuestionClassifierNode,
|
||||
"1": QuestionClassifierNode,
|
||||
},
|
||||
NodeType.HTTP_REQUEST: {
|
||||
LATEST_VERSION: HttpRequestNode,
|
||||
"1": HttpRequestNode,
|
||||
},
|
||||
NodeType.TOOL: {
|
||||
LATEST_VERSION: ToolNode,
|
||||
# This is an issue that caused problems before.
|
||||
# Logically, we shouldn't use two different versions to point to the same class here,
|
||||
# but in order to maintain compatibility with historical data, this approach has been retained.
|
||||
"2": ToolNode,
|
||||
"1": ToolNode,
|
||||
},
|
||||
NodeType.VARIABLE_AGGREGATOR: {
|
||||
LATEST_VERSION: VariableAggregatorNode,
|
||||
"1": VariableAggregatorNode,
|
||||
},
|
||||
NodeType.LEGACY_VARIABLE_AGGREGATOR: {
|
||||
LATEST_VERSION: VariableAggregatorNode,
|
||||
"1": VariableAggregatorNode,
|
||||
}, # original name of VARIABLE_AGGREGATOR
|
||||
NodeType.ITERATION: {
|
||||
LATEST_VERSION: IterationNode,
|
||||
"1": IterationNode,
|
||||
},
|
||||
NodeType.ITERATION_START: {
|
||||
LATEST_VERSION: IterationStartNode,
|
||||
"1": IterationStartNode,
|
||||
},
|
||||
NodeType.LOOP: {
|
||||
LATEST_VERSION: LoopNode,
|
||||
"1": LoopNode,
|
||||
},
|
||||
NodeType.LOOP_START: {
|
||||
LATEST_VERSION: LoopStartNode,
|
||||
"1": LoopStartNode,
|
||||
},
|
||||
NodeType.LOOP_END: {
|
||||
LATEST_VERSION: LoopEndNode,
|
||||
"1": LoopEndNode,
|
||||
},
|
||||
NodeType.PARAMETER_EXTRACTOR: {
|
||||
LATEST_VERSION: ParameterExtractorNode,
|
||||
"1": ParameterExtractorNode,
|
||||
},
|
||||
NodeType.VARIABLE_ASSIGNER: {
|
||||
LATEST_VERSION: VariableAssignerNodeV2,
|
||||
"1": VariableAssignerNodeV1,
|
||||
"2": VariableAssignerNodeV2,
|
||||
},
|
||||
NodeType.DOCUMENT_EXTRACTOR: {
|
||||
LATEST_VERSION: DocumentExtractorNode,
|
||||
"1": DocumentExtractorNode,
|
||||
},
|
||||
NodeType.LIST_OPERATOR: {
|
||||
LATEST_VERSION: ListOperatorNode,
|
||||
"1": ListOperatorNode,
|
||||
},
|
||||
NodeType.AGENT: {
|
||||
LATEST_VERSION: AgentNode,
|
||||
# This is an issue that caused problems before.
|
||||
# Logically, we shouldn't use two different versions to point to the same class here,
|
||||
# but in order to maintain compatibility with historical data, this approach has been retained.
|
||||
"2": AgentNode,
|
||||
"1": AgentNode,
|
||||
},
|
||||
NodeType.HUMAN_INPUT: {
|
||||
LATEST_VERSION: HumanInputNode,
|
||||
"1": HumanInputNode,
|
||||
},
|
||||
NodeType.DATASOURCE: {
|
||||
LATEST_VERSION: DatasourceNode,
|
||||
"1": DatasourceNode,
|
||||
},
|
||||
NodeType.KNOWLEDGE_INDEX: {
|
||||
LATEST_VERSION: KnowledgeIndexNode,
|
||||
"1": KnowledgeIndexNode,
|
||||
},
|
||||
NodeType.TRIGGER_WEBHOOK: {
|
||||
LATEST_VERSION: TriggerWebhookNode,
|
||||
"1": TriggerWebhookNode,
|
||||
},
|
||||
NodeType.TRIGGER_PLUGIN: {
|
||||
LATEST_VERSION: TriggerEventNode,
|
||||
"1": TriggerEventNode,
|
||||
},
|
||||
NodeType.TRIGGER_SCHEDULE: {
|
||||
LATEST_VERSION: TriggerScheduleNode,
|
||||
"1": TriggerScheduleNode,
|
||||
},
|
||||
}
|
||||
# Mapping is built by Node.get_node_type_classes_mapping(), which imports and walks core.workflow.nodes
|
||||
NODE_TYPE_CLASSES_MAPPING: Mapping[NodeType, Mapping[str, type[Node]]] = Node.get_node_type_classes_mapping()
|
||||
|
||||
@@ -12,7 +12,6 @@ from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
|
||||
from core.tools.errors import ToolInvokeError
|
||||
from core.tools.tool_engine import ToolEngine
|
||||
from core.tools.utils.message_transformer import ToolFileMessageTransformer
|
||||
from core.tools.workflow_as_tool.tool import WorkflowTool
|
||||
from core.variables.segments import ArrayAnySegment, ArrayFileSegment
|
||||
from core.variables.variables import ArrayAnyVariable
|
||||
from core.workflow.enums import (
|
||||
@@ -430,7 +429,7 @@ class ToolNode(Node[ToolNodeData]):
|
||||
metadata: dict[WorkflowNodeExecutionMetadataKey, Any] = {
|
||||
WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info,
|
||||
}
|
||||
if usage.total_tokens > 0:
|
||||
if isinstance(usage.total_tokens, int) and usage.total_tokens > 0:
|
||||
metadata[WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS] = usage.total_tokens
|
||||
metadata[WorkflowNodeExecutionMetadataKey.TOTAL_PRICE] = usage.total_price
|
||||
metadata[WorkflowNodeExecutionMetadataKey.CURRENCY] = usage.currency
|
||||
@@ -449,8 +448,17 @@ class ToolNode(Node[ToolNodeData]):
|
||||
|
||||
@staticmethod
|
||||
def _extract_tool_usage(tool_runtime: Tool) -> LLMUsage:
|
||||
if isinstance(tool_runtime, WorkflowTool):
|
||||
return tool_runtime.latest_usage
|
||||
# Avoid importing WorkflowTool at module import time; rely on duck typing
|
||||
# Some runtimes expose `latest_usage`; mocks may synthesize arbitrary attributes.
|
||||
latest = getattr(tool_runtime, "latest_usage", None)
|
||||
# Normalize into a concrete LLMUsage. MagicMock returns truthy attribute objects
|
||||
# for any name, so we must type-check here.
|
||||
if isinstance(latest, LLMUsage):
|
||||
return latest
|
||||
if isinstance(latest, dict):
|
||||
# Allow dict payloads from external runtimes
|
||||
return LLMUsage.model_validate(latest)
|
||||
# Fallback to empty usage when attribute is missing or not a valid payload
|
||||
return LLMUsage.empty_usage()
|
||||
|
||||
@classmethod
|
||||
|
||||
49
api/extensions/ext_forward_refs.py
Normal file
49
api/extensions/ext_forward_refs.py
Normal file
@@ -0,0 +1,49 @@
|
||||
import logging
|
||||
|
||||
from dify_app import DifyApp
|
||||
|
||||
|
||||
def is_enabled() -> bool:
|
||||
return True
|
||||
|
||||
|
||||
def init_app(app: DifyApp):
|
||||
"""Resolve Pydantic forward refs that would otherwise cause circular imports.
|
||||
|
||||
Rebuilds models in core.app.entities.app_invoke_entities with the real TraceQueueManager type.
|
||||
Safe to run multiple times.
|
||||
"""
|
||||
logger = logging.getLogger(__name__)
|
||||
try:
|
||||
from core.app.entities.app_invoke_entities import (
|
||||
AdvancedChatAppGenerateEntity,
|
||||
AgentChatAppGenerateEntity,
|
||||
AppGenerateEntity,
|
||||
ChatAppGenerateEntity,
|
||||
CompletionAppGenerateEntity,
|
||||
ConversationAppGenerateEntity,
|
||||
EasyUIBasedAppGenerateEntity,
|
||||
RagPipelineGenerateEntity,
|
||||
WorkflowAppGenerateEntity,
|
||||
)
|
||||
from core.ops.ops_trace_manager import TraceQueueManager # heavy import, do it at startup only
|
||||
|
||||
ns = {"TraceQueueManager": TraceQueueManager}
|
||||
for Model in (
|
||||
AppGenerateEntity,
|
||||
EasyUIBasedAppGenerateEntity,
|
||||
ConversationAppGenerateEntity,
|
||||
ChatAppGenerateEntity,
|
||||
CompletionAppGenerateEntity,
|
||||
AgentChatAppGenerateEntity,
|
||||
AdvancedChatAppGenerateEntity,
|
||||
WorkflowAppGenerateEntity,
|
||||
RagPipelineGenerateEntity,
|
||||
):
|
||||
try:
|
||||
Model.model_rebuild(_types_namespace=ns)
|
||||
except Exception as e:
|
||||
logger.debug("model_rebuild skipped for %s: %s", Model.__name__, e)
|
||||
except Exception as e:
|
||||
# Don't block app startup; just log at debug level.
|
||||
logger.debug("ext_forward_refs init skipped: %s", e)
|
||||
@@ -29,7 +29,7 @@ class _TestNode(Node[_TestNodeData]):
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "test"
|
||||
return "1"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@@ -92,7 +92,7 @@ class MockLLMNode(MockNodeMixin, LLMNode):
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
"""Return the version of this mock node."""
|
||||
return "mock-1"
|
||||
return "1"
|
||||
|
||||
def _run(self) -> Generator:
|
||||
"""Execute mock LLM node."""
|
||||
@@ -189,7 +189,7 @@ class MockAgentNode(MockNodeMixin, AgentNode):
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
"""Return the version of this mock node."""
|
||||
return "mock-1"
|
||||
return "1"
|
||||
|
||||
def _run(self) -> Generator:
|
||||
"""Execute mock agent node."""
|
||||
@@ -241,7 +241,7 @@ class MockToolNode(MockNodeMixin, ToolNode):
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
"""Return the version of this mock node."""
|
||||
return "mock-1"
|
||||
return "1"
|
||||
|
||||
def _run(self) -> Generator:
|
||||
"""Execute mock tool node."""
|
||||
@@ -294,7 +294,7 @@ class MockKnowledgeRetrievalNode(MockNodeMixin, KnowledgeRetrievalNode):
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
"""Return the version of this mock node."""
|
||||
return "mock-1"
|
||||
return "1"
|
||||
|
||||
def _run(self) -> Generator:
|
||||
"""Execute mock knowledge retrieval node."""
|
||||
@@ -351,7 +351,7 @@ class MockHttpRequestNode(MockNodeMixin, HttpRequestNode):
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
"""Return the version of this mock node."""
|
||||
return "mock-1"
|
||||
return "1"
|
||||
|
||||
def _run(self) -> Generator:
|
||||
"""Execute mock HTTP request node."""
|
||||
@@ -404,7 +404,7 @@ class MockQuestionClassifierNode(MockNodeMixin, QuestionClassifierNode):
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
"""Return the version of this mock node."""
|
||||
return "mock-1"
|
||||
return "1"
|
||||
|
||||
def _run(self) -> Generator:
|
||||
"""Execute mock question classifier node."""
|
||||
@@ -452,7 +452,7 @@ class MockParameterExtractorNode(MockNodeMixin, ParameterExtractorNode):
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
"""Return the version of this mock node."""
|
||||
return "mock-1"
|
||||
return "1"
|
||||
|
||||
def _run(self) -> Generator:
|
||||
"""Execute mock parameter extractor node."""
|
||||
@@ -502,7 +502,7 @@ class MockDocumentExtractorNode(MockNodeMixin, DocumentExtractorNode):
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
"""Return the version of this mock node."""
|
||||
return "mock-1"
|
||||
return "1"
|
||||
|
||||
def _run(self) -> Generator:
|
||||
"""Execute mock document extractor node."""
|
||||
@@ -557,7 +557,7 @@ class MockIterationNode(MockNodeMixin, IterationNode):
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
"""Return the version of this mock node."""
|
||||
return "mock-1"
|
||||
return "1"
|
||||
|
||||
def _create_graph_engine(self, index: int, item: Any):
|
||||
"""Create a graph engine with MockNodeFactory instead of DifyNodeFactory."""
|
||||
@@ -632,7 +632,7 @@ class MockLoopNode(MockNodeMixin, LoopNode):
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
"""Return the version of this mock node."""
|
||||
return "mock-1"
|
||||
return "1"
|
||||
|
||||
def _create_graph_engine(self, start_at, root_node_id: str):
|
||||
"""Create a graph engine with MockNodeFactory instead of DifyNodeFactory."""
|
||||
@@ -694,7 +694,7 @@ class MockTemplateTransformNode(MockNodeMixin, TemplateTransformNode):
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
"""Return the version of this mock node."""
|
||||
return "mock-1"
|
||||
return "1"
|
||||
|
||||
def _run(self) -> NodeRunResult:
|
||||
"""Execute mock template transform node."""
|
||||
@@ -780,7 +780,7 @@ class MockCodeNode(MockNodeMixin, CodeNode):
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
"""Return the version of this mock node."""
|
||||
return "mock-1"
|
||||
return "1"
|
||||
|
||||
def _run(self) -> NodeRunResult:
|
||||
"""Execute mock code node."""
|
||||
|
||||
@@ -33,6 +33,10 @@ def test_ensure_subclasses_of_base_node_has_node_type_and_version_method_defined
|
||||
type_version_set: set[tuple[NodeType, str]] = set()
|
||||
|
||||
for cls in classes:
|
||||
# Only validate production node classes; skip test-defined subclasses and external helpers
|
||||
module_name = getattr(cls, "__module__", "")
|
||||
if not module_name.startswith("core."):
|
||||
continue
|
||||
# Validate that 'version' is directly defined in the class (not inherited) by checking the class's __dict__
|
||||
assert "version" in cls.__dict__, f"class {cls} should have version method defined (NOT INHERITED.)"
|
||||
node_type = cls.node_type
|
||||
|
||||
@@ -0,0 +1,84 @@
|
||||
import types
|
||||
from collections.abc import Mapping
|
||||
|
||||
from core.workflow.enums import NodeType
|
||||
from core.workflow.nodes.base.entities import BaseNodeData
|
||||
from core.workflow.nodes.base.node import Node
|
||||
|
||||
# Import concrete nodes we will assert on (numeric version path)
|
||||
from core.workflow.nodes.variable_assigner.v1.node import (
|
||||
VariableAssignerNode as VariableAssignerV1,
|
||||
)
|
||||
from core.workflow.nodes.variable_assigner.v2.node import (
|
||||
VariableAssignerNode as VariableAssignerV2,
|
||||
)
|
||||
|
||||
|
||||
def test_variable_assigner_latest_prefers_highest_numeric_version():
|
||||
# Act
|
||||
mapping: Mapping[NodeType, Mapping[str, type[Node]]] = Node.get_node_type_classes_mapping()
|
||||
|
||||
# Assert basic presence
|
||||
assert NodeType.VARIABLE_ASSIGNER in mapping
|
||||
va_versions = mapping[NodeType.VARIABLE_ASSIGNER]
|
||||
|
||||
# Both concrete versions must be present
|
||||
assert va_versions.get("1") is VariableAssignerV1
|
||||
assert va_versions.get("2") is VariableAssignerV2
|
||||
|
||||
# And latest should point to numerically-highest version ("2")
|
||||
assert va_versions.get("latest") is VariableAssignerV2
|
||||
|
||||
|
||||
def test_latest_prefers_highest_numeric_version():
|
||||
# Arrange: define two ephemeral subclasses with numeric versions under a NodeType
|
||||
# that has no concrete implementations in production to avoid interference.
|
||||
class _Version1(Node[BaseNodeData]): # type: ignore[misc]
|
||||
node_type = NodeType.LEGACY_VARIABLE_AGGREGATOR
|
||||
|
||||
def init_node_data(self, data):
|
||||
pass
|
||||
|
||||
def _run(self):
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
||||
def _get_error_strategy(self):
|
||||
return None
|
||||
|
||||
def _get_retry_config(self):
|
||||
return types.SimpleNamespace() # not used
|
||||
|
||||
def _get_title(self) -> str:
|
||||
return "version1"
|
||||
|
||||
def _get_description(self):
|
||||
return None
|
||||
|
||||
def _get_default_value_dict(self):
|
||||
return {}
|
||||
|
||||
def get_base_node_data(self):
|
||||
return types.SimpleNamespace(title="version1")
|
||||
|
||||
class _Version2(_Version1): # type: ignore[misc]
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "2"
|
||||
|
||||
def _get_title(self) -> str:
|
||||
return "version2"
|
||||
|
||||
# Act: build a fresh mapping (it should now see our ephemeral subclasses)
|
||||
mapping: Mapping[NodeType, Mapping[str, type[Node]]] = Node.get_node_type_classes_mapping()
|
||||
|
||||
# Assert: both numeric versions exist for this NodeType; 'latest' points to the higher numeric version
|
||||
assert NodeType.LEGACY_VARIABLE_AGGREGATOR in mapping
|
||||
legacy_versions = mapping[NodeType.LEGACY_VARIABLE_AGGREGATOR]
|
||||
|
||||
assert legacy_versions.get("1") is _Version1
|
||||
assert legacy_versions.get("2") is _Version2
|
||||
assert legacy_versions.get("latest") is _Version2
|
||||
@@ -19,7 +19,7 @@ class _SampleNode(Node[_SampleNodeData]):
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "sample-test"
|
||||
return "1"
|
||||
|
||||
def _run(self):
|
||||
raise NotImplementedError
|
||||
|
||||
Reference in New Issue
Block a user