mirror of
https://github.com/langgenius/dify.git
synced 2026-05-30 16:00:32 -04:00
feat: extract model runtime
Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
@@ -1,12 +1,6 @@
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
|
||||
from core.tools.errors import ToolInvokeError
|
||||
from core.tools.tool_engine import ToolEngine
|
||||
from core.tools.utils.message_transformer import ToolFileMessageTransformer
|
||||
from dify_graph.entities.graph_config import NodeConfigDict
|
||||
from dify_graph.enums import (
|
||||
BuiltinNodeTypes,
|
||||
@@ -20,10 +14,14 @@ from dify_graph.node_events import NodeEventBase, NodeRunResult, StreamChunkEven
|
||||
from dify_graph.nodes.base.node import Node
|
||||
from dify_graph.nodes.base.variable_template_parser import VariableTemplateParser
|
||||
from dify_graph.nodes.protocols import ToolFileManagerProtocol
|
||||
from dify_graph.nodes.runtime import ToolNodeRuntimeProtocol
|
||||
from dify_graph.nodes.tool_runtime_entities import (
|
||||
ToolRuntimeHandle,
|
||||
ToolRuntimeMessage,
|
||||
ToolRuntimeParameter,
|
||||
)
|
||||
from dify_graph.variables.segments import ArrayAnySegment, ArrayFileSegment
|
||||
from dify_graph.variables.variables import ArrayAnyVariable
|
||||
from factories import file_factory
|
||||
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
|
||||
|
||||
from .entities import ToolNodeData
|
||||
from .exc import (
|
||||
@@ -52,6 +50,7 @@ class ToolNode(Node[ToolNodeData]):
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
*,
|
||||
tool_file_manager_factory: ToolFileManagerProtocol,
|
||||
runtime: ToolNodeRuntimeProtocol | None = None,
|
||||
):
|
||||
super().__init__(
|
||||
id=id,
|
||||
@@ -60,6 +59,9 @@ class ToolNode(Node[ToolNodeData]):
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
self._tool_file_manager_factory = tool_file_manager_factory
|
||||
if runtime is None:
|
||||
raise ValueError("runtime is required")
|
||||
self._runtime = runtime
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
@@ -73,10 +75,6 @@ class ToolNode(Node[ToolNodeData]):
|
||||
"""
|
||||
Run the tool node
|
||||
"""
|
||||
from core.plugin.impl.exc import PluginDaemonClientSideError, PluginInvokeError
|
||||
|
||||
dify_ctx = self.require_dify_context()
|
||||
|
||||
# fetch tool icon
|
||||
tool_info = {
|
||||
"provider_type": self.node_data.provider_type.value,
|
||||
@@ -86,8 +84,6 @@ class ToolNode(Node[ToolNodeData]):
|
||||
|
||||
# get tool runtime
|
||||
try:
|
||||
from core.tools.tool_manager import ToolManager
|
||||
|
||||
# This is an issue that caused problems before.
|
||||
# Logically, we shouldn't use the node_data.version field for judgment
|
||||
# But for backward compatibility with historical data
|
||||
@@ -95,13 +91,10 @@ class ToolNode(Node[ToolNodeData]):
|
||||
variable_pool: VariablePool | None = None
|
||||
if self.node_data.version != "1" or self.node_data.tool_node_version is not None:
|
||||
variable_pool = self.graph_runtime_state.variable_pool
|
||||
tool_runtime = ToolManager.get_workflow_tool_runtime(
|
||||
dify_ctx.tenant_id,
|
||||
dify_ctx.app_id,
|
||||
self._node_id,
|
||||
self.node_data,
|
||||
dify_ctx.invoke_from,
|
||||
variable_pool,
|
||||
tool_runtime = self._runtime.get_runtime(
|
||||
node_id=self._node_id,
|
||||
node_data=self.node_data,
|
||||
variable_pool=variable_pool,
|
||||
)
|
||||
except ToolNodeError as e:
|
||||
yield StreamCompletedEvent(
|
||||
@@ -116,7 +109,7 @@ class ToolNode(Node[ToolNodeData]):
|
||||
return
|
||||
|
||||
# get parameters
|
||||
tool_parameters = tool_runtime.get_merged_runtime_parameters() or []
|
||||
tool_parameters = self._runtime.get_runtime_parameters(tool_runtime=tool_runtime)
|
||||
parameters = self._generate_parameters(
|
||||
tool_parameters=tool_parameters,
|
||||
variable_pool=self.graph_runtime_state.variable_pool,
|
||||
@@ -132,14 +125,12 @@ class ToolNode(Node[ToolNodeData]):
|
||||
conversation_id = self.graph_runtime_state.variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID])
|
||||
|
||||
try:
|
||||
message_stream = ToolEngine.generic_invoke(
|
||||
tool=tool_runtime,
|
||||
message_stream = self._runtime.invoke(
|
||||
tool_runtime=tool_runtime,
|
||||
tool_parameters=parameters,
|
||||
user_id=dify_ctx.user_id,
|
||||
workflow_tool_callback=DifyWorkflowCallbackHandler(),
|
||||
workflow_call_depth=self.workflow_call_depth,
|
||||
app_id=dify_ctx.app_id,
|
||||
conversation_id=conversation_id.text if conversation_id else None,
|
||||
provider_name=self.node_data.provider_name,
|
||||
)
|
||||
except ToolNodeError as e:
|
||||
yield StreamCompletedEvent(
|
||||
@@ -159,38 +150,16 @@ class ToolNode(Node[ToolNodeData]):
|
||||
messages=message_stream,
|
||||
tool_info=tool_info,
|
||||
parameters_for_log=parameters_for_log,
|
||||
user_id=dify_ctx.user_id,
|
||||
tenant_id=dify_ctx.tenant_id,
|
||||
node_id=self._node_id,
|
||||
tool_runtime=tool_runtime,
|
||||
)
|
||||
except ToolInvokeError as e:
|
||||
except ToolNodeError as e:
|
||||
yield StreamCompletedEvent(
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs=parameters_for_log,
|
||||
metadata={WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info},
|
||||
error=f"Failed to invoke tool {self.node_data.provider_name}: {str(e)}",
|
||||
error_type=type(e).__name__,
|
||||
)
|
||||
)
|
||||
except PluginInvokeError as e:
|
||||
yield StreamCompletedEvent(
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs=parameters_for_log,
|
||||
metadata={WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info},
|
||||
error=e.to_user_friendly_error(plugin_name=self.node_data.provider_name),
|
||||
error_type=type(e).__name__,
|
||||
)
|
||||
)
|
||||
except PluginDaemonClientSideError as e:
|
||||
yield StreamCompletedEvent(
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs=parameters_for_log,
|
||||
metadata={WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info},
|
||||
error=f"Failed to invoke tool, error: {e.description}",
|
||||
error=str(e),
|
||||
error_type=type(e).__name__,
|
||||
)
|
||||
)
|
||||
@@ -198,7 +167,7 @@ class ToolNode(Node[ToolNodeData]):
|
||||
def _generate_parameters(
|
||||
self,
|
||||
*,
|
||||
tool_parameters: Sequence[ToolParameter],
|
||||
tool_parameters: Sequence[ToolRuntimeParameter],
|
||||
variable_pool: "VariablePool",
|
||||
node_data: ToolNodeData,
|
||||
for_log: bool = False,
|
||||
@@ -207,7 +176,7 @@ class ToolNode(Node[ToolNodeData]):
|
||||
Generate parameters based on the given tool parameters, variable pool, and node data.
|
||||
|
||||
Args:
|
||||
tool_parameters (Sequence[ToolParameter]): The list of tool parameters.
|
||||
tool_parameters (Sequence[ToolRuntimeParameter]): The list of tool parameters.
|
||||
variable_pool (VariablePool): The variable pool containing the variables.
|
||||
node_data (ToolNodeData): The data associated with the tool node.
|
||||
|
||||
@@ -247,40 +216,29 @@ class ToolNode(Node[ToolNodeData]):
|
||||
|
||||
def _transform_message(
|
||||
self,
|
||||
messages: Generator[ToolInvokeMessage, None, None],
|
||||
messages: Generator[ToolRuntimeMessage, None, None],
|
||||
tool_info: Mapping[str, Any],
|
||||
parameters_for_log: dict[str, Any],
|
||||
user_id: str,
|
||||
tenant_id: str,
|
||||
node_id: str,
|
||||
tool_runtime: Tool,
|
||||
tool_runtime: ToolRuntimeHandle,
|
||||
**_: Any,
|
||||
) -> Generator[NodeEventBase, None, LLMUsage]:
|
||||
"""
|
||||
Convert ToolInvokeMessages into tuple[plain_text, files]
|
||||
Convert graph-owned tool runtime messages into node outputs.
|
||||
"""
|
||||
# transform message and handle file storage
|
||||
from core.plugin.impl.plugin import PluginInstaller
|
||||
|
||||
message_stream = ToolFileMessageTransformer.transform_tool_invoke_messages(
|
||||
messages=messages,
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
conversation_id=None,
|
||||
)
|
||||
|
||||
text = ""
|
||||
files: list[File] = []
|
||||
json: list[dict | list] = []
|
||||
|
||||
variables: dict[str, Any] = {}
|
||||
|
||||
for message in message_stream:
|
||||
for message in messages:
|
||||
if message.type in {
|
||||
ToolInvokeMessage.MessageType.IMAGE_LINK,
|
||||
ToolInvokeMessage.MessageType.BINARY_LINK,
|
||||
ToolInvokeMessage.MessageType.IMAGE,
|
||||
ToolRuntimeMessage.MessageType.IMAGE_LINK,
|
||||
ToolRuntimeMessage.MessageType.BINARY_LINK,
|
||||
ToolRuntimeMessage.MessageType.IMAGE,
|
||||
}:
|
||||
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
|
||||
assert isinstance(message.message, ToolRuntimeMessage.TextMessage)
|
||||
|
||||
url = message.message.text
|
||||
if message.meta:
|
||||
@@ -300,14 +258,11 @@ class ToolNode(Node[ToolNodeData]):
|
||||
"transfer_method": transfer_method,
|
||||
"url": url,
|
||||
}
|
||||
file = file_factory.build_from_mapping(
|
||||
mapping=mapping,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
file = self._runtime.build_file_reference(mapping=mapping)
|
||||
files.append(file)
|
||||
elif message.type == ToolInvokeMessage.MessageType.BLOB:
|
||||
elif message.type == ToolRuntimeMessage.MessageType.BLOB:
|
||||
# get tool file id
|
||||
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
|
||||
assert isinstance(message.message, ToolRuntimeMessage.TextMessage)
|
||||
assert message.meta
|
||||
|
||||
tool_file_id = message.message.text.split("/")[-1].split(".")[0]
|
||||
@@ -320,27 +275,22 @@ class ToolNode(Node[ToolNodeData]):
|
||||
"transfer_method": FileTransferMethod.TOOL_FILE,
|
||||
}
|
||||
|
||||
files.append(
|
||||
file_factory.build_from_mapping(
|
||||
mapping=mapping,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
)
|
||||
elif message.type == ToolInvokeMessage.MessageType.TEXT:
|
||||
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
|
||||
files.append(self._runtime.build_file_reference(mapping=mapping))
|
||||
elif message.type == ToolRuntimeMessage.MessageType.TEXT:
|
||||
assert isinstance(message.message, ToolRuntimeMessage.TextMessage)
|
||||
text += message.message.text
|
||||
yield StreamChunkEvent(
|
||||
selector=[node_id, "text"],
|
||||
chunk=message.message.text,
|
||||
is_final=False,
|
||||
)
|
||||
elif message.type == ToolInvokeMessage.MessageType.JSON:
|
||||
assert isinstance(message.message, ToolInvokeMessage.JsonMessage)
|
||||
elif message.type == ToolRuntimeMessage.MessageType.JSON:
|
||||
assert isinstance(message.message, ToolRuntimeMessage.JsonMessage)
|
||||
# JSON message handling for tool node
|
||||
if message.message.json_object:
|
||||
json.append(message.message.json_object)
|
||||
elif message.type == ToolInvokeMessage.MessageType.LINK:
|
||||
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
|
||||
elif message.type == ToolRuntimeMessage.MessageType.LINK:
|
||||
assert isinstance(message.message, ToolRuntimeMessage.TextMessage)
|
||||
|
||||
# Check if this LINK message is a file link
|
||||
file_obj = (message.meta or {}).get("file")
|
||||
@@ -356,8 +306,8 @@ class ToolNode(Node[ToolNodeData]):
|
||||
chunk=stream_text,
|
||||
is_final=False,
|
||||
)
|
||||
elif message.type == ToolInvokeMessage.MessageType.VARIABLE:
|
||||
assert isinstance(message.message, ToolInvokeMessage.VariableMessage)
|
||||
elif message.type == ToolRuntimeMessage.MessageType.VARIABLE:
|
||||
assert isinstance(message.message, ToolRuntimeMessage.VariableMessage)
|
||||
variable_name = message.message.variable_name
|
||||
variable_value = message.message.variable_value
|
||||
if message.message.stream:
|
||||
@@ -374,7 +324,7 @@ class ToolNode(Node[ToolNodeData]):
|
||||
)
|
||||
else:
|
||||
variables[variable_name] = variable_value
|
||||
elif message.type == ToolInvokeMessage.MessageType.FILE:
|
||||
elif message.type == ToolRuntimeMessage.MessageType.FILE:
|
||||
assert message.meta is not None
|
||||
assert isinstance(message.meta, dict)
|
||||
# Validate that meta contains a 'file' key
|
||||
@@ -385,38 +335,16 @@ class ToolNode(Node[ToolNodeData]):
|
||||
if not isinstance(message.meta["file"], File):
|
||||
raise ToolNodeError(f"Expected File object but got {type(message.meta['file']).__name__}")
|
||||
files.append(message.meta["file"])
|
||||
elif message.type == ToolInvokeMessage.MessageType.LOG:
|
||||
assert isinstance(message.message, ToolInvokeMessage.LogMessage)
|
||||
elif message.type == ToolRuntimeMessage.MessageType.LOG:
|
||||
assert isinstance(message.message, ToolRuntimeMessage.LogMessage)
|
||||
if message.message.metadata:
|
||||
icon = tool_info.get("icon", "")
|
||||
dict_metadata = dict(message.message.metadata)
|
||||
if dict_metadata.get("provider"):
|
||||
manager = PluginInstaller()
|
||||
plugins = manager.list_plugins(tenant_id)
|
||||
try:
|
||||
current_plugin = next(
|
||||
plugin
|
||||
for plugin in plugins
|
||||
if f"{plugin.plugin_id}/{plugin.name}" == dict_metadata["provider"]
|
||||
)
|
||||
icon = current_plugin.declaration.icon
|
||||
except StopIteration:
|
||||
pass
|
||||
icon_dark = None
|
||||
try:
|
||||
builtin_tool = next(
|
||||
provider
|
||||
for provider in BuiltinToolManageService.list_builtin_tools(
|
||||
user_id,
|
||||
tenant_id,
|
||||
)
|
||||
if provider.name == dict_metadata["provider"]
|
||||
)
|
||||
icon = builtin_tool.icon
|
||||
icon_dark = builtin_tool.icon_dark
|
||||
except StopIteration:
|
||||
pass
|
||||
|
||||
icon, icon_dark = self._runtime.resolve_provider_icons(
|
||||
provider_name=dict_metadata["provider"],
|
||||
default_icon=icon,
|
||||
)
|
||||
dict_metadata["icon"] = icon
|
||||
dict_metadata["icon_dark"] = icon_dark
|
||||
message.message.metadata = dict_metadata
|
||||
@@ -446,7 +374,7 @@ class ToolNode(Node[ToolNodeData]):
|
||||
is_final=True,
|
||||
)
|
||||
|
||||
usage = self._extract_tool_usage(tool_runtime)
|
||||
usage = self._runtime.get_usage(tool_runtime=tool_runtime)
|
||||
|
||||
metadata: dict[WorkflowNodeExecutionMetadataKey, Any] = {
|
||||
WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info,
|
||||
@@ -468,21 +396,6 @@ class ToolNode(Node[ToolNodeData]):
|
||||
|
||||
return usage
|
||||
|
||||
@staticmethod
|
||||
def _extract_tool_usage(tool_runtime: Tool) -> LLMUsage:
|
||||
# Avoid importing WorkflowTool at module import time; rely on duck typing
|
||||
# Some runtimes expose `latest_usage`; mocks may synthesize arbitrary attributes.
|
||||
latest = getattr(tool_runtime, "latest_usage", None)
|
||||
# Normalize into a concrete LLMUsage. MagicMock returns truthy attribute objects
|
||||
# for any name, so we must type-check here.
|
||||
if isinstance(latest, LLMUsage):
|
||||
return latest
|
||||
if isinstance(latest, dict):
|
||||
# Allow dict payloads from external runtimes
|
||||
return LLMUsage.model_validate(latest)
|
||||
# Fallback to empty usage when attribute is missing or not a valid payload
|
||||
return LLMUsage.empty_usage()
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls,
|
||||
|
||||
Reference in New Issue
Block a user