mirror of
https://github.com/langgenius/dify.git
synced 2026-05-30 07:00:37 -04:00
feat: extract model runtime
Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
@@ -6,13 +6,12 @@ from abc import abstractmethod
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from functools import singledispatchmethod
|
||||
from types import MappingProxyType
|
||||
from typing import Any, ClassVar, Generic, Protocol, TypeVar, cast, get_args, get_origin
|
||||
from typing import Any, ClassVar, Generic, TypeVar, cast, get_args, get_origin
|
||||
from uuid import uuid4
|
||||
|
||||
from dify_graph.entities import GraphInitParams
|
||||
from dify_graph.entities.base_node_data import BaseNodeData, RetryConfig
|
||||
from dify_graph.entities.graph_config import NodeConfigDict
|
||||
from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY
|
||||
from dify_graph.enums import (
|
||||
ErrorStrategy,
|
||||
NodeExecutionType,
|
||||
@@ -60,7 +59,7 @@ from dify_graph.node_events import (
|
||||
StreamCompletedEvent,
|
||||
)
|
||||
from dify_graph.runtime import GraphRuntimeState
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from dify_graph.utils.datetime_utils import naive_utc_now
|
||||
|
||||
NodeDataT = TypeVar("NodeDataT", bound=BaseNodeData)
|
||||
_MISSING_RUN_CONTEXT_VALUE = object()
|
||||
@@ -68,23 +67,6 @@ _MISSING_RUN_CONTEXT_VALUE = object()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DifyRunContextProtocol(Protocol):
|
||||
tenant_id: str
|
||||
app_id: str
|
||||
user_id: str
|
||||
user_from: Any
|
||||
invoke_from: Any
|
||||
|
||||
|
||||
class _MappingDifyRunContext:
|
||||
def __init__(self, mapping: Mapping[str, Any]) -> None:
|
||||
self.tenant_id = str(mapping["tenant_id"])
|
||||
self.app_id = str(mapping["app_id"])
|
||||
self.user_id = str(mapping["user_id"])
|
||||
self.user_from = mapping["user_from"]
|
||||
self.invoke_from = mapping["invoke_from"]
|
||||
|
||||
|
||||
class Node(Generic[NodeDataT]):
|
||||
"""BaseNode serves as the foundational class for all node implementations.
|
||||
|
||||
@@ -177,8 +159,9 @@ class Node(Generic[NodeDataT]):
|
||||
# Skip base class itself
|
||||
if cls is Node:
|
||||
return
|
||||
# Only register production node implementations defined under the
|
||||
# canonical workflow namespaces.
|
||||
# Only treat nodes from the base dify_graph package as production
|
||||
# registrations. Higher-layer packages may still register subclasses,
|
||||
# but dify_graph itself should not know their module identities.
|
||||
# This prevents test helper subclasses from polluting the global registry and
|
||||
# accidentally overriding real node types (e.g., a test Answer node).
|
||||
module_name = getattr(cls, "__module__", "")
|
||||
@@ -186,7 +169,7 @@ class Node(Generic[NodeDataT]):
|
||||
node_type = cls.node_type
|
||||
version = cls.version()
|
||||
bucket = Node._registry.setdefault(node_type, {})
|
||||
if module_name.startswith(("dify_graph.nodes.", "core.workflow.nodes.")):
|
||||
if module_name.startswith("dify_graph.nodes."):
|
||||
# Production node definitions take precedence and may override
|
||||
bucket[version] = cls # type: ignore[index]
|
||||
else:
|
||||
@@ -299,25 +282,6 @@ class Node(Generic[NodeDataT]):
|
||||
raise ValueError(f"run_context missing required key: {key}")
|
||||
return value
|
||||
|
||||
def require_dify_context(self) -> DifyRunContextProtocol:
|
||||
raw_ctx = self.require_run_context_value(DIFY_RUN_CONTEXT_KEY)
|
||||
if raw_ctx is None:
|
||||
raise ValueError(f"run_context missing required key: {DIFY_RUN_CONTEXT_KEY}")
|
||||
|
||||
if isinstance(raw_ctx, Mapping):
|
||||
missing_keys = [
|
||||
key for key in ("tenant_id", "app_id", "user_id", "user_from", "invoke_from") if key not in raw_ctx
|
||||
]
|
||||
if missing_keys:
|
||||
raise ValueError(f"dify context missing required keys: {', '.join(missing_keys)}")
|
||||
return _MappingDifyRunContext(raw_ctx)
|
||||
|
||||
for attr in ("tenant_id", "app_id", "user_id", "user_from", "invoke_from"):
|
||||
if not hasattr(raw_ctx, attr):
|
||||
raise TypeError(f"invalid dify context object, missing attribute: {attr}")
|
||||
|
||||
return cast(DifyRunContextProtocol, raw_ctx)
|
||||
|
||||
@property
|
||||
def execution_id(self) -> str:
|
||||
return self._node_execution_id
|
||||
@@ -793,16 +757,11 @@ class Node(Generic[NodeDataT]):
|
||||
|
||||
@_dispatch.register
|
||||
def _(self, event: RunRetrieverResourceEvent) -> NodeRunRetrieverResourceEvent:
|
||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||
|
||||
retriever_resources = [
|
||||
RetrievalSourceMetadata.model_validate(resource) for resource in event.retriever_resources
|
||||
]
|
||||
return NodeRunRetrieverResourceEvent(
|
||||
id=self.execution_id,
|
||||
node_id=self._node_id,
|
||||
node_type=self.node_type,
|
||||
retriever_resources=retriever_resources,
|
||||
retriever_resources=event.retriever_resources,
|
||||
context=event.context,
|
||||
node_version=self.version(),
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user