feat: extract model runtime

Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
-LAN-
2026-03-15 15:34:47 +08:00
parent 3d5a29462e
commit fbb74a4af9
178 changed files with 4343 additions and 2134 deletions

View File

@@ -1,12 +1,6 @@
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
from core.tools.__base.tool import Tool
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
from core.tools.errors import ToolInvokeError
from core.tools.tool_engine import ToolEngine
from core.tools.utils.message_transformer import ToolFileMessageTransformer
from dify_graph.entities.graph_config import NodeConfigDict
from dify_graph.enums import (
BuiltinNodeTypes,
@@ -20,10 +14,14 @@ from dify_graph.node_events import NodeEventBase, NodeRunResult, StreamChunkEven
from dify_graph.nodes.base.node import Node
from dify_graph.nodes.base.variable_template_parser import VariableTemplateParser
from dify_graph.nodes.protocols import ToolFileManagerProtocol
from dify_graph.nodes.runtime import ToolNodeRuntimeProtocol
from dify_graph.nodes.tool_runtime_entities import (
ToolRuntimeHandle,
ToolRuntimeMessage,
ToolRuntimeParameter,
)
from dify_graph.variables.segments import ArrayAnySegment, ArrayFileSegment
from dify_graph.variables.variables import ArrayAnyVariable
from factories import file_factory
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
from .entities import ToolNodeData
from .exc import (
@@ -52,6 +50,7 @@ class ToolNode(Node[ToolNodeData]):
graph_runtime_state: "GraphRuntimeState",
*,
tool_file_manager_factory: ToolFileManagerProtocol,
runtime: ToolNodeRuntimeProtocol | None = None,
):
super().__init__(
id=id,
@@ -60,6 +59,9 @@ class ToolNode(Node[ToolNodeData]):
graph_runtime_state=graph_runtime_state,
)
self._tool_file_manager_factory = tool_file_manager_factory
if runtime is None:
raise ValueError("runtime is required")
self._runtime = runtime
@classmethod
def version(cls) -> str:
@@ -73,10 +75,6 @@ class ToolNode(Node[ToolNodeData]):
"""
Run the tool node
"""
from core.plugin.impl.exc import PluginDaemonClientSideError, PluginInvokeError
dify_ctx = self.require_dify_context()
# fetch tool icon
tool_info = {
"provider_type": self.node_data.provider_type.value,
@@ -86,8 +84,6 @@ class ToolNode(Node[ToolNodeData]):
# get tool runtime
try:
from core.tools.tool_manager import ToolManager
# This is an issue that caused problems before.
# Logically, we shouldn't use the node_data.version field for judgment
# But for backward compatibility with historical data
@@ -95,13 +91,10 @@ class ToolNode(Node[ToolNodeData]):
variable_pool: VariablePool | None = None
if self.node_data.version != "1" or self.node_data.tool_node_version is not None:
variable_pool = self.graph_runtime_state.variable_pool
tool_runtime = ToolManager.get_workflow_tool_runtime(
dify_ctx.tenant_id,
dify_ctx.app_id,
self._node_id,
self.node_data,
dify_ctx.invoke_from,
variable_pool,
tool_runtime = self._runtime.get_runtime(
node_id=self._node_id,
node_data=self.node_data,
variable_pool=variable_pool,
)
except ToolNodeError as e:
yield StreamCompletedEvent(
@@ -116,7 +109,7 @@ class ToolNode(Node[ToolNodeData]):
return
# get parameters
tool_parameters = tool_runtime.get_merged_runtime_parameters() or []
tool_parameters = self._runtime.get_runtime_parameters(tool_runtime=tool_runtime)
parameters = self._generate_parameters(
tool_parameters=tool_parameters,
variable_pool=self.graph_runtime_state.variable_pool,
@@ -132,14 +125,12 @@ class ToolNode(Node[ToolNodeData]):
conversation_id = self.graph_runtime_state.variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID])
try:
message_stream = ToolEngine.generic_invoke(
tool=tool_runtime,
message_stream = self._runtime.invoke(
tool_runtime=tool_runtime,
tool_parameters=parameters,
user_id=dify_ctx.user_id,
workflow_tool_callback=DifyWorkflowCallbackHandler(),
workflow_call_depth=self.workflow_call_depth,
app_id=dify_ctx.app_id,
conversation_id=conversation_id.text if conversation_id else None,
provider_name=self.node_data.provider_name,
)
except ToolNodeError as e:
yield StreamCompletedEvent(
@@ -159,38 +150,16 @@ class ToolNode(Node[ToolNodeData]):
messages=message_stream,
tool_info=tool_info,
parameters_for_log=parameters_for_log,
user_id=dify_ctx.user_id,
tenant_id=dify_ctx.tenant_id,
node_id=self._node_id,
tool_runtime=tool_runtime,
)
except ToolInvokeError as e:
except ToolNodeError as e:
yield StreamCompletedEvent(
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=parameters_for_log,
metadata={WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info},
error=f"Failed to invoke tool {self.node_data.provider_name}: {str(e)}",
error_type=type(e).__name__,
)
)
except PluginInvokeError as e:
yield StreamCompletedEvent(
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=parameters_for_log,
metadata={WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info},
error=e.to_user_friendly_error(plugin_name=self.node_data.provider_name),
error_type=type(e).__name__,
)
)
except PluginDaemonClientSideError as e:
yield StreamCompletedEvent(
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=parameters_for_log,
metadata={WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info},
error=f"Failed to invoke tool, error: {e.description}",
error=str(e),
error_type=type(e).__name__,
)
)
@@ -198,7 +167,7 @@ class ToolNode(Node[ToolNodeData]):
def _generate_parameters(
self,
*,
tool_parameters: Sequence[ToolParameter],
tool_parameters: Sequence[ToolRuntimeParameter],
variable_pool: "VariablePool",
node_data: ToolNodeData,
for_log: bool = False,
@@ -207,7 +176,7 @@ class ToolNode(Node[ToolNodeData]):
Generate parameters based on the given tool parameters, variable pool, and node data.
Args:
tool_parameters (Sequence[ToolParameter]): The list of tool parameters.
tool_parameters (Sequence[ToolRuntimeParameter]): The list of tool parameters.
variable_pool (VariablePool): The variable pool containing the variables.
node_data (ToolNodeData): The data associated with the tool node.
@@ -247,40 +216,29 @@ class ToolNode(Node[ToolNodeData]):
def _transform_message(
self,
messages: Generator[ToolInvokeMessage, None, None],
messages: Generator[ToolRuntimeMessage, None, None],
tool_info: Mapping[str, Any],
parameters_for_log: dict[str, Any],
user_id: str,
tenant_id: str,
node_id: str,
tool_runtime: Tool,
tool_runtime: ToolRuntimeHandle,
**_: Any,
) -> Generator[NodeEventBase, None, LLMUsage]:
"""
Convert ToolInvokeMessages into tuple[plain_text, files]
Convert graph-owned tool runtime messages into node outputs.
"""
# transform message and handle file storage
from core.plugin.impl.plugin import PluginInstaller
message_stream = ToolFileMessageTransformer.transform_tool_invoke_messages(
messages=messages,
user_id=user_id,
tenant_id=tenant_id,
conversation_id=None,
)
text = ""
files: list[File] = []
json: list[dict | list] = []
variables: dict[str, Any] = {}
for message in message_stream:
for message in messages:
if message.type in {
ToolInvokeMessage.MessageType.IMAGE_LINK,
ToolInvokeMessage.MessageType.BINARY_LINK,
ToolInvokeMessage.MessageType.IMAGE,
ToolRuntimeMessage.MessageType.IMAGE_LINK,
ToolRuntimeMessage.MessageType.BINARY_LINK,
ToolRuntimeMessage.MessageType.IMAGE,
}:
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
assert isinstance(message.message, ToolRuntimeMessage.TextMessage)
url = message.message.text
if message.meta:
@@ -300,14 +258,11 @@ class ToolNode(Node[ToolNodeData]):
"transfer_method": transfer_method,
"url": url,
}
file = file_factory.build_from_mapping(
mapping=mapping,
tenant_id=tenant_id,
)
file = self._runtime.build_file_reference(mapping=mapping)
files.append(file)
elif message.type == ToolInvokeMessage.MessageType.BLOB:
elif message.type == ToolRuntimeMessage.MessageType.BLOB:
# get tool file id
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
assert isinstance(message.message, ToolRuntimeMessage.TextMessage)
assert message.meta
tool_file_id = message.message.text.split("/")[-1].split(".")[0]
@@ -320,27 +275,22 @@ class ToolNode(Node[ToolNodeData]):
"transfer_method": FileTransferMethod.TOOL_FILE,
}
files.append(
file_factory.build_from_mapping(
mapping=mapping,
tenant_id=tenant_id,
)
)
elif message.type == ToolInvokeMessage.MessageType.TEXT:
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
files.append(self._runtime.build_file_reference(mapping=mapping))
elif message.type == ToolRuntimeMessage.MessageType.TEXT:
assert isinstance(message.message, ToolRuntimeMessage.TextMessage)
text += message.message.text
yield StreamChunkEvent(
selector=[node_id, "text"],
chunk=message.message.text,
is_final=False,
)
elif message.type == ToolInvokeMessage.MessageType.JSON:
assert isinstance(message.message, ToolInvokeMessage.JsonMessage)
elif message.type == ToolRuntimeMessage.MessageType.JSON:
assert isinstance(message.message, ToolRuntimeMessage.JsonMessage)
# JSON message handling for tool node
if message.message.json_object:
json.append(message.message.json_object)
elif message.type == ToolInvokeMessage.MessageType.LINK:
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
elif message.type == ToolRuntimeMessage.MessageType.LINK:
assert isinstance(message.message, ToolRuntimeMessage.TextMessage)
# Check if this LINK message is a file link
file_obj = (message.meta or {}).get("file")
@@ -356,8 +306,8 @@ class ToolNode(Node[ToolNodeData]):
chunk=stream_text,
is_final=False,
)
elif message.type == ToolInvokeMessage.MessageType.VARIABLE:
assert isinstance(message.message, ToolInvokeMessage.VariableMessage)
elif message.type == ToolRuntimeMessage.MessageType.VARIABLE:
assert isinstance(message.message, ToolRuntimeMessage.VariableMessage)
variable_name = message.message.variable_name
variable_value = message.message.variable_value
if message.message.stream:
@@ -374,7 +324,7 @@ class ToolNode(Node[ToolNodeData]):
)
else:
variables[variable_name] = variable_value
elif message.type == ToolInvokeMessage.MessageType.FILE:
elif message.type == ToolRuntimeMessage.MessageType.FILE:
assert message.meta is not None
assert isinstance(message.meta, dict)
# Validate that meta contains a 'file' key
@@ -385,38 +335,16 @@ class ToolNode(Node[ToolNodeData]):
if not isinstance(message.meta["file"], File):
raise ToolNodeError(f"Expected File object but got {type(message.meta['file']).__name__}")
files.append(message.meta["file"])
elif message.type == ToolInvokeMessage.MessageType.LOG:
assert isinstance(message.message, ToolInvokeMessage.LogMessage)
elif message.type == ToolRuntimeMessage.MessageType.LOG:
assert isinstance(message.message, ToolRuntimeMessage.LogMessage)
if message.message.metadata:
icon = tool_info.get("icon", "")
dict_metadata = dict(message.message.metadata)
if dict_metadata.get("provider"):
manager = PluginInstaller()
plugins = manager.list_plugins(tenant_id)
try:
current_plugin = next(
plugin
for plugin in plugins
if f"{plugin.plugin_id}/{plugin.name}" == dict_metadata["provider"]
)
icon = current_plugin.declaration.icon
except StopIteration:
pass
icon_dark = None
try:
builtin_tool = next(
provider
for provider in BuiltinToolManageService.list_builtin_tools(
user_id,
tenant_id,
)
if provider.name == dict_metadata["provider"]
)
icon = builtin_tool.icon
icon_dark = builtin_tool.icon_dark
except StopIteration:
pass
icon, icon_dark = self._runtime.resolve_provider_icons(
provider_name=dict_metadata["provider"],
default_icon=icon,
)
dict_metadata["icon"] = icon
dict_metadata["icon_dark"] = icon_dark
message.message.metadata = dict_metadata
@@ -446,7 +374,7 @@ class ToolNode(Node[ToolNodeData]):
is_final=True,
)
usage = self._extract_tool_usage(tool_runtime)
usage = self._runtime.get_usage(tool_runtime=tool_runtime)
metadata: dict[WorkflowNodeExecutionMetadataKey, Any] = {
WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info,
@@ -468,21 +396,6 @@ class ToolNode(Node[ToolNodeData]):
return usage
@staticmethod
def _extract_tool_usage(tool_runtime: Tool) -> LLMUsage:
# Avoid importing WorkflowTool at module import time; rely on duck typing
# Some runtimes expose `latest_usage`; mocks may synthesize arbitrary attributes.
latest = getattr(tool_runtime, "latest_usage", None)
# Normalize into a concrete LLMUsage. MagicMock returns truthy attribute objects
# for any name, so we must type-check here.
if isinstance(latest, LLMUsage):
return latest
if isinstance(latest, dict):
# Allow dict payloads from external runtimes
return LLMUsage.model_validate(latest)
# Fallback to empty usage when attribute is missing or not a valid payload
return LLMUsage.empty_usage()
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls,