feat: extract model runtime

Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
-LAN-
2026-03-15 15:34:47 +08:00
parent 3d5a29462e
commit fbb74a4af9
178 changed files with 4343 additions and 2134 deletions

View File

@@ -6,13 +6,12 @@ from abc import abstractmethod
from collections.abc import Generator, Mapping, Sequence
from functools import singledispatchmethod
from types import MappingProxyType
from typing import Any, ClassVar, Generic, Protocol, TypeVar, cast, get_args, get_origin
from typing import Any, ClassVar, Generic, TypeVar, cast, get_args, get_origin
from uuid import uuid4
from dify_graph.entities import GraphInitParams
from dify_graph.entities.base_node_data import BaseNodeData, RetryConfig
from dify_graph.entities.graph_config import NodeConfigDict
from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY
from dify_graph.enums import (
ErrorStrategy,
NodeExecutionType,
@@ -60,7 +59,7 @@ from dify_graph.node_events import (
StreamCompletedEvent,
)
from dify_graph.runtime import GraphRuntimeState
from libs.datetime_utils import naive_utc_now
from dify_graph.utils.datetime_utils import naive_utc_now
NodeDataT = TypeVar("NodeDataT", bound=BaseNodeData)
_MISSING_RUN_CONTEXT_VALUE = object()
@@ -68,23 +67,6 @@ _MISSING_RUN_CONTEXT_VALUE = object()
logger = logging.getLogger(__name__)
class DifyRunContextProtocol(Protocol):
tenant_id: str
app_id: str
user_id: str
user_from: Any
invoke_from: Any
class _MappingDifyRunContext:
def __init__(self, mapping: Mapping[str, Any]) -> None:
self.tenant_id = str(mapping["tenant_id"])
self.app_id = str(mapping["app_id"])
self.user_id = str(mapping["user_id"])
self.user_from = mapping["user_from"]
self.invoke_from = mapping["invoke_from"]
class Node(Generic[NodeDataT]):
"""BaseNode serves as the foundational class for all node implementations.
@@ -177,8 +159,9 @@ class Node(Generic[NodeDataT]):
# Skip base class itself
if cls is Node:
return
# Only register production node implementations defined under the
# canonical workflow namespaces.
# Only treat nodes from the base dify_graph package as production
# registrations. Higher-layer packages may still register subclasses,
# but dify_graph itself should not know their module identities.
# This prevents test helper subclasses from polluting the global registry and
# accidentally overriding real node types (e.g., a test Answer node).
module_name = getattr(cls, "__module__", "")
@@ -186,7 +169,7 @@ class Node(Generic[NodeDataT]):
node_type = cls.node_type
version = cls.version()
bucket = Node._registry.setdefault(node_type, {})
if module_name.startswith(("dify_graph.nodes.", "core.workflow.nodes.")):
if module_name.startswith("dify_graph.nodes."):
# Production node definitions take precedence and may override
bucket[version] = cls # type: ignore[index]
else:
@@ -299,25 +282,6 @@ class Node(Generic[NodeDataT]):
raise ValueError(f"run_context missing required key: {key}")
return value
def require_dify_context(self) -> DifyRunContextProtocol:
raw_ctx = self.require_run_context_value(DIFY_RUN_CONTEXT_KEY)
if raw_ctx is None:
raise ValueError(f"run_context missing required key: {DIFY_RUN_CONTEXT_KEY}")
if isinstance(raw_ctx, Mapping):
missing_keys = [
key for key in ("tenant_id", "app_id", "user_id", "user_from", "invoke_from") if key not in raw_ctx
]
if missing_keys:
raise ValueError(f"dify context missing required keys: {', '.join(missing_keys)}")
return _MappingDifyRunContext(raw_ctx)
for attr in ("tenant_id", "app_id", "user_id", "user_from", "invoke_from"):
if not hasattr(raw_ctx, attr):
raise TypeError(f"invalid dify context object, missing attribute: {attr}")
return cast(DifyRunContextProtocol, raw_ctx)
@property
def execution_id(self) -> str:
return self._node_execution_id
@@ -793,16 +757,11 @@ class Node(Generic[NodeDataT]):
@_dispatch.register
def _(self, event: RunRetrieverResourceEvent) -> NodeRunRetrieverResourceEvent:
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
retriever_resources = [
RetrievalSourceMetadata.model_validate(resource) for resource in event.retriever_resources
]
return NodeRunRetrieverResourceEvent(
id=self.execution_id,
node_id=self._node_id,
node_type=self.node_type,
retriever_resources=retriever_resources,
retriever_resources=event.retriever_resources,
context=event.context,
node_version=self.version(),
)