From d3486cab315b5f13744d49d52ebe98f11f976b73 Mon Sep 17 00:00:00 2001 From: Novice Date: Wed, 17 Dec 2025 10:30:21 +0800 Subject: [PATCH] refactor(llm node): introduce structured tool call/result entities --- .../advanced_chat/generate_task_pipeline.py | 25 ++-- .../apps/workflow/generate_task_pipeline.py | 17 ++- api/core/app/apps/workflow_app_runner.py | 7 +- api/core/app/entities/queue_entities.py | 19 +-- api/core/workflow/entities/__init__.py | 5 + api/core/workflow/entities/tool_entities.py | 33 ++++++ .../response_coordinator/coordinator.py | 31 +++-- api/core/workflow/graph_events/__init__.py | 4 + api/core/workflow/graph_events/node.py | 15 ++- api/core/workflow/node_events/node.py | 12 +- api/core/workflow/nodes/base/node.py | 25 ++-- api/core/workflow/nodes/llm/entities.py | 13 +- api/core/workflow/nodes/llm/node.py | 107 +++++++++++------ api/models/workflow.py | 5 +- api/services/workflow_run_service.py | 4 +- .../graph_engine/test_response_coordinator.py | 35 ++++-- .../node_events/test_stream_chunk_events.py | 112 ++++++++++-------- 17 files changed, 300 insertions(+), 169 deletions(-) create mode 100644 api/core/workflow/entities/tool_entities.py diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index f3e9439938..0d5e1e5dfd 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -516,6 +516,14 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport): if tts_publisher and queue_message: tts_publisher.publish(queue_message) + tool_call = event.tool_call + tool_result = event.tool_result + tool_payload = tool_call or tool_result + tool_call_id = tool_payload.id if tool_payload and tool_payload.id else "" + tool_name = tool_payload.name if tool_payload and tool_payload.name else "" + tool_arguments = tool_call.arguments if tool_call and tool_call.arguments else "" + tool_files = tool_result.files if tool_result else [] + # Record stream event based on chunk type chunk_type = event.chunk_type or ChunkType.TEXT match chunk_type: @@ -525,13 +533,13 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport): self._stream_buffer.record_thought_chunk(delta_text) case ChunkType.TOOL_CALL: self._stream_buffer.record_tool_call( - tool_call_id=event.tool_call_id or "", - tool_name=event.tool_name or "", - tool_arguments=event.tool_arguments or "", + tool_call_id=tool_call_id, + tool_name=tool_name, + tool_arguments=tool_arguments, ) case ChunkType.TOOL_RESULT: self._stream_buffer.record_tool_result( - tool_call_id=event.tool_call_id or "", + tool_call_id=tool_call_id, result=delta_text, ) @@ -541,11 +549,10 @@ message_id=self._message_id, from_variable_selector=event.from_variable_selector, chunk_type=event.chunk_type.value if event.chunk_type else None, - tool_call_id=event.tool_call_id, - tool_name=event.tool_name, - tool_arguments=event.tool_arguments, - tool_files=event.tool_files, - tool_error=event.tool_error, + tool_call_id=tool_call_id or None, + tool_name=tool_name or None, + tool_arguments=tool_arguments or None, + tool_files=tool_files, ) def _handle_iteration_start_event( diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py index a6c7067ccd..1a9d09f5e7 100644 --- a/api/core/app/apps/workflow/generate_task_pipeline.py +++ 
b/api/core/app/apps/workflow/generate_task_pipeline.py @@ -484,6 +484,14 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport): if delta_text is None: return + tool_call = event.tool_call + tool_result = event.tool_result + tool_payload = tool_call or tool_result + tool_call_id = tool_payload.id if tool_payload and tool_payload.id else None + tool_name = tool_payload.name if tool_payload and tool_payload.name else None + tool_arguments = tool_call.arguments if tool_call else None + tool_files = tool_result.files if tool_result else [] + # only publish tts message at text chunk streaming if tts_publisher and queue_message: tts_publisher.publish(queue_message) @@ -492,11 +500,10 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport): text=delta_text, from_variable_selector=event.from_variable_selector, chunk_type=event.chunk_type, - tool_call_id=event.tool_call_id, - tool_name=event.tool_name, - tool_arguments=event.tool_arguments, - tool_files=event.tool_files, - tool_error=event.tool_error, + tool_call_id=tool_call_id, + tool_name=tool_name, + tool_arguments=tool_arguments, + tool_files=tool_files, ) def _handle_agent_log_event(self, event: QueueAgentLogEvent, **kwargs) -> Generator[StreamResponse, None, None]: diff --git a/api/core/app/apps/workflow_app_runner.py b/api/core/app/apps/workflow_app_runner.py index 23624cb934..6ce33c98ee 100644 --- a/api/core/app/apps/workflow_app_runner.py +++ b/api/core/app/apps/workflow_app_runner.py @@ -464,11 +464,8 @@ class WorkflowBasedAppRunner: in_iteration_id=event.in_iteration_id, in_loop_id=event.in_loop_id, chunk_type=QueueChunkType(event.chunk_type.value), - tool_call_id=event.tool_call_id, - tool_name=event.tool_name, - tool_arguments=event.tool_arguments, - tool_files=event.tool_files, - tool_error=event.tool_error, + tool_call=event.tool_call, + tool_result=event.tool_result, ) ) elif isinstance(event, NodeRunRetrieverResourceEvent): diff --git a/api/core/app/entities/queue_entities.py b/api/core/app/entities/queue_entities.py index edb2c8a1f3..e07efbc38c 100644 --- a/api/core/app/entities/queue_entities.py +++ b/api/core/app/entities/queue_entities.py @@ -9,6 +9,7 @@ from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk from core.rag.entities.citation_metadata import RetrievalSourceMetadata from core.workflow.entities import AgentNodeStrategyInit from core.workflow.enums import WorkflowNodeExecutionMetadataKey +from core.workflow.graph_events import ToolCall, ToolResult from core.workflow.nodes import NodeType @@ -204,19 +205,11 @@ class QueueTextChunkEvent(AppQueueEvent): chunk_type: ChunkType = ChunkType.TEXT """type of the chunk""" - # Tool call fields (when chunk_type == TOOL_CALL) - tool_call_id: str | None = None - """unique identifier for this tool call""" - tool_name: str | None = None - """name of the tool being called""" - tool_arguments: str | None = None - """accumulated tool arguments JSON""" - - # Tool result fields (when chunk_type == TOOL_RESULT) - tool_files: list[str] = Field(default_factory=list) - """file IDs produced by tool""" - tool_error: str | None = None - """error message if tool failed""" + # Tool streaming payloads + tool_call: ToolCall | None = None + """structured tool call info""" + tool_result: ToolResult | None = None + """structured tool result info""" class QueueAgentMessageEvent(AppQueueEvent): diff --git a/api/core/workflow/entities/__init__.py b/api/core/workflow/entities/__init__.py index be70e467a0..0f3b9a5239 100644 --- 
a/api/core/workflow/entities/__init__.py +++ b/api/core/workflow/entities/__init__.py @@ -1,11 +1,16 @@ from .agent import AgentNodeStrategyInit from .graph_init_params import GraphInitParams +from .tool_entities import ToolCall, ToolCallResult, ToolResult, ToolResultStatus from .workflow_execution import WorkflowExecution from .workflow_node_execution import WorkflowNodeExecution __all__ = [ "AgentNodeStrategyInit", "GraphInitParams", + "ToolCall", + "ToolCallResult", + "ToolResult", + "ToolResultStatus", "WorkflowExecution", "WorkflowNodeExecution", ] diff --git a/api/core/workflow/entities/tool_entities.py b/api/core/workflow/entities/tool_entities.py new file mode 100644 index 0000000000..f4833218c7 --- /dev/null +++ b/api/core/workflow/entities/tool_entities.py @@ -0,0 +1,33 @@ +from enum import StrEnum + +from pydantic import BaseModel, Field + +from core.file import File + + +class ToolResultStatus(StrEnum): + SUCCESS = "success" + ERROR = "error" + + +class ToolCall(BaseModel): + id: str | None = Field(default=None, description="Unique identifier for this tool call") + name: str | None = Field(default=None, description="Name of the tool being called") + arguments: str | None = Field(default=None, description="Accumulated tool arguments JSON") + + +class ToolResult(BaseModel): + id: str | None = Field(default=None, description="Identifier of the tool call this result belongs to") + name: str | None = Field(default=None, description="Name of the tool") + output: str | None = Field(default=None, description="Tool output text (success or error message)") + files: list[str] = Field(default_factory=list, description="Files produced by the tool") + status: ToolResultStatus | None = Field(default=ToolResultStatus.SUCCESS, description="Tool execution status") + + +class ToolCallResult(BaseModel): + id: str | None = Field(default=None, description="Identifier for the tool call") + name: str | None = Field(default=None, description="Name of the tool") + arguments: str | None = Field(default=None, description="Accumulated tool arguments JSON") + output: str | None = Field(default=None, description="Tool output text (success or error message)") + files: list[File] = Field(default_factory=list, description="Files produced by the tool") + status: ToolResultStatus = Field(default=ToolResultStatus.SUCCESS, description="Tool execution status") diff --git a/api/core/workflow/graph_engine/response_coordinator/coordinator.py b/api/core/workflow/graph_engine/response_coordinator/coordinator.py index 1396c3a7ff..631440c6c1 100644 --- a/api/core/workflow/graph_engine/response_coordinator/coordinator.py +++ b/api/core/workflow/graph_engine/response_coordinator/coordinator.py @@ -16,7 +16,13 @@ from pydantic import BaseModel, Field from core.workflow.enums import NodeExecutionType, NodeState from core.workflow.graph import Graph -from core.workflow.graph_events import NodeRunStreamChunkEvent, NodeRunSucceededEvent +from core.workflow.graph_events import ( + ChunkType, + NodeRunStreamChunkEvent, + NodeRunSucceededEvent, + ToolCall, + ToolResult, +) from core.workflow.nodes.base.template import TextSegment, VariableSegment from core.workflow.runtime import VariablePool @@ -321,7 +327,9 @@ class ResponseStreamCoordinator: selector: Sequence[str], chunk: str, is_final: bool = False, - **extra_fields, + chunk_type: ChunkType = ChunkType.TEXT, + tool_call: ToolCall | None = None, + tool_result: ToolResult | None = None, ) -> NodeRunStreamChunkEvent: """Create a stream chunk event with consistent structure. 
@@ -334,7 +342,9 @@ class ResponseStreamCoordinator: selector: The variable selector chunk: The chunk content is_final: Whether this is the final chunk - **extra_fields: Additional fields for specialized events (chunk_type, tool_call_id, etc.) + chunk_type: The semantic type of the chunk being streamed + tool_call: Structured data for tool_call chunks + tool_result: Structured data for tool_result chunks """ # Check if this is a special selector that doesn't correspond to a node if selector and selector[0] not in self._graph.nodes and self._active_session: @@ -347,7 +357,9 @@ class ResponseStreamCoordinator: selector=selector, chunk=chunk, is_final=is_final, - **extra_fields, + chunk_type=chunk_type, + tool_call=tool_call, + tool_result=tool_result, ) # Standard case: selector refers to an actual node @@ -359,7 +371,9 @@ class ResponseStreamCoordinator: selector=selector, chunk=chunk, is_final=is_final, - **extra_fields, + chunk_type=chunk_type, + tool_call=tool_call, + tool_result=tool_result, ) def _process_variable_segment(self, segment: VariableSegment) -> tuple[Sequence[NodeRunStreamChunkEvent], bool]: @@ -436,11 +450,8 @@ class ResponseStreamCoordinator: chunk=event.chunk, is_final=event.is_final, chunk_type=event.chunk_type, - tool_call_id=event.tool_call_id, - tool_name=event.tool_name, - tool_arguments=event.tool_arguments, - tool_files=event.tool_files, - tool_error=event.tool_error, + tool_call=event.tool_call, + tool_result=event.tool_result, ) events.append(updated_event) else: diff --git a/api/core/workflow/graph_events/__init__.py b/api/core/workflow/graph_events/__init__.py index 6c37fa1bc6..4ee0ec94d2 100644 --- a/api/core/workflow/graph_events/__init__.py +++ b/api/core/workflow/graph_events/__init__.py @@ -45,6 +45,8 @@ from .node import ( NodeRunStartedEvent, NodeRunStreamChunkEvent, NodeRunSucceededEvent, + ToolCall, + ToolResult, ) __all__ = [ @@ -75,4 +77,6 @@ __all__ = [ "NodeRunStartedEvent", "NodeRunStreamChunkEvent", "NodeRunSucceededEvent", + "ToolCall", + "ToolResult", ] diff --git a/api/core/workflow/graph_events/node.py b/api/core/workflow/graph_events/node.py index 3351f028b1..01bc27d3e4 100644 --- a/api/core/workflow/graph_events/node.py +++ b/api/core/workflow/graph_events/node.py @@ -5,7 +5,7 @@ from enum import StrEnum from pydantic import Field from core.rag.entities.citation_metadata import RetrievalSourceMetadata -from core.workflow.entities import AgentNodeStrategyInit +from core.workflow.entities import AgentNodeStrategyInit, ToolCall, ToolResult from core.workflow.entities.pause_reason import PauseReason from .base import GraphNodeEventBase @@ -43,13 +43,16 @@ class NodeRunStreamChunkEvent(GraphNodeEventBase): chunk_type: ChunkType = Field(default=ChunkType.TEXT, description="type of the chunk") # Tool call fields (when chunk_type == TOOL_CALL) - tool_call_id: str | None = Field(default=None, description="unique identifier for this tool call") - tool_name: str | None = Field(default=None, description="name of the tool being called") - tool_arguments: str | None = Field(default=None, description="accumulated tool arguments JSON") + tool_call: ToolCall | None = Field( + default=None, + description="structured payload for tool_call chunks", + ) # Tool result fields (when chunk_type == TOOL_RESULT) - tool_files: list[str] = Field(default_factory=list, description="file IDs produced by tool") - tool_error: str | None = Field(default=None, description="error message if tool failed") + tool_result: ToolResult | None = Field( + default=None, + 
description="structured payload for tool_result chunks", + ) class NodeRunRetrieverResourceEvent(GraphNodeEventBase): diff --git a/api/core/workflow/node_events/node.py b/api/core/workflow/node_events/node.py index 43e00ae7d1..39f09d02a5 100644 --- a/api/core/workflow/node_events/node.py +++ b/api/core/workflow/node_events/node.py @@ -7,6 +7,7 @@ from pydantic import Field from core.file import File from core.model_runtime.entities.llm_entities import LLMUsage from core.rag.entities.citation_metadata import RetrievalSourceMetadata +from core.workflow.entities import ToolCall, ToolResult from core.workflow.entities.pause_reason import PauseReason from core.workflow.node_events import NodeRunResult @@ -51,25 +52,22 @@ class StreamChunkEvent(NodeEventBase): chunk: str = Field(..., description="the actual chunk content") is_final: bool = Field(default=False, description="indicates if this is the last chunk") chunk_type: ChunkType = Field(default=ChunkType.TEXT, description="type of the chunk") + tool_call: ToolCall | None = Field(default=None, description="structured payload for tool_call chunks") + tool_result: ToolResult | None = Field(default=None, description="structured payload for tool_result chunks") class ToolCallChunkEvent(StreamChunkEvent): """Tool call streaming event - tool call arguments streaming output.""" chunk_type: ChunkType = Field(default=ChunkType.TOOL_CALL, frozen=True) - tool_call_id: str = Field(..., description="unique identifier for this tool call") - tool_name: str = Field(..., description="name of the tool being called") - tool_arguments: str = Field(default="", description="accumulated tool arguments JSON") + tool_call: ToolCall | None = Field(default=None, description="structured tool call payload") class ToolResultChunkEvent(StreamChunkEvent): """Tool result event - tool execution result.""" chunk_type: ChunkType = Field(default=ChunkType.TOOL_RESULT, frozen=True) - tool_call_id: str = Field(..., description="identifier of the tool call this result belongs to") - tool_name: str = Field(..., description="name of the tool") - tool_files: list[str] = Field(default_factory=list, description="file IDs produced by tool") - tool_error: str | None = Field(default=None, description="error message if tool failed") + tool_result: ToolResult | None = Field(default=None, description="structured tool result payload") class ThoughtChunkEvent(StreamChunkEvent): diff --git a/api/core/workflow/nodes/base/node.py b/api/core/workflow/nodes/base/node.py index e71335d3b4..40feda8b57 100644 --- a/api/core/workflow/nodes/base/node.py +++ b/api/core/workflow/nodes/base/node.py @@ -556,6 +556,8 @@ class Node(Generic[NodeDataT]): chunk=event.chunk, is_final=event.is_final, chunk_type=ChunkType(event.chunk_type.value), + tool_call=event.tool_call, + tool_result=event.tool_result, ) @_dispatch.register @@ -570,14 +572,18 @@ class Node(Generic[NodeDataT]): chunk=event.chunk, is_final=event.is_final, chunk_type=ChunkType.TOOL_CALL, - tool_call_id=event.tool_call_id, - tool_name=event.tool_name, - tool_arguments=event.tool_arguments, + tool_call=event.tool_call, ) @_dispatch.register def _(self, event: ToolResultChunkEvent) -> NodeRunStreamChunkEvent: - from core.workflow.graph_events import ChunkType + from core.workflow.entities import ToolResult + from core.workflow.graph_events import ChunkType, ToolResultStatus + + tool_result = event.tool_result + status: ToolResultStatus = ( + tool_result.status if tool_result and tool_result.status is not None else ToolResultStatus.SUCCESS + ) return 
NodeRunStreamChunkEvent( id=self._node_execution_id, @@ -587,10 +593,13 @@ class Node(Generic[NodeDataT]): chunk=event.chunk, is_final=event.is_final, chunk_type=ChunkType.TOOL_RESULT, - tool_call_id=event.tool_call_id, - tool_name=event.tool_name, - tool_files=event.tool_files, - tool_error=event.tool_error, + tool_result=ToolResult( + id=tool_result.id if tool_result else None, + name=tool_result.name if tool_result else None, + output=tool_result.output if tool_result else None, + files=tool_result.files if tool_result else [], + status=status, + ), ) @_dispatch.register diff --git a/api/core/workflow/nodes/llm/entities.py b/api/core/workflow/nodes/llm/entities.py index 82f146414b..f57e251cdf 100644 --- a/api/core/workflow/nodes/llm/entities.py +++ b/api/core/workflow/nodes/llm/entities.py @@ -8,6 +8,7 @@ from core.model_runtime.entities import ImagePromptMessageContent, LLMMode from core.model_runtime.entities.llm_entities import LLMUsage from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig from core.tools.entities.tool_entities import ToolProviderType +from core.workflow.entities import ToolCallResult from core.workflow.nodes.base import BaseNodeData from core.workflow.nodes.base.entities import VariableSelector @@ -33,12 +34,10 @@ class LLMTraceSegment(BaseModel): text: str | None = Field(None, description="Text chunk for thought/content") # Tool call fields (combined start + result) - tool_call_id: str | None = None - tool_name: str | None = None - tool_arguments: str | None = None - tool_output: str | None = None - tool_error: str | None = None - files: list[str] = Field(default_factory=list, description="File IDs from tool result if any") + tool_call: ToolCallResult | None = Field( + default=None, + description="Combined tool call arguments and result for this segment", + ) class LLMGenerationData(BaseModel): @@ -51,7 +50,7 @@ class LLMGenerationData(BaseModel): text: str = Field(..., description="Accumulated text content from all turns") reasoning_contents: list[str] = Field(default_factory=list, description="Reasoning content per turn") - tool_calls: list[dict[str, Any]] = Field(default_factory=list, description="Tool calls with results") + tool_calls: list[ToolCallResult] = Field(default_factory=list, description="Tool calls with results") sequence: list[dict[str, Any]] = Field(default_factory=list, description="Ordered segments for rendering") usage: LLMUsage = Field(..., description="LLM usage statistics") finish_reason: str | None = Field(None, description="Finish reason from LLM") diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py index 3b82ec4a05..94b616bd34 100644 --- a/api/core/workflow/nodes/llm/node.py +++ b/api/core/workflow/nodes/llm/node.py @@ -60,7 +60,8 @@ from core.variables import ( StringSegment, ) from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID -from core.workflow.entities import GraphInitParams +from core.workflow.entities import GraphInitParams, ToolCall, ToolResult, ToolResultStatus +from core.workflow.entities.tool_entities import ToolCallResult from core.workflow.enums import ( NodeType, SystemVariableKey, @@ -1671,9 +1672,11 @@ class LLMNode(Node[LLMNodeData]): tool_call_segment = LLMTraceSegment( type="tool_call", text=None, - tool_call_id=tool_call_id, - tool_name=tool_name, - tool_arguments=tool_arguments, + tool_call=ToolCallResult( + id=tool_call_id, + name=tool_name, + arguments=tool_arguments, + ), ) 
trace_segments.append(tool_call_segment) if tool_call_id: @@ -1682,9 +1685,11 @@ class LLMNode(Node[LLMNodeData]): yield ToolCallChunkEvent( selector=[self._node_id, "generation", "tool_calls"], chunk=tool_arguments, - tool_call_id=tool_call_id, - tool_name=tool_name, - tool_arguments=tool_arguments, + tool_call=ToolCall( + id=tool_call_id, + name=tool_name, + arguments=tool_arguments, + ), is_final=False, ) @@ -1724,27 +1729,50 @@ class LLMNode(Node[LLMNodeData]): tool_call_segment = existing_tool_segment or LLMTraceSegment( type="tool_call", text=None, - tool_call_id=tool_call_id, - tool_name=tool_name, - tool_arguments=None, + tool_call=ToolCallResult( + id=tool_call_id, + name=tool_name, + arguments=None, + ), ) if existing_tool_segment is None: trace_segments.append(tool_call_segment) if tool_call_id: tool_trace_map[tool_call_id] = tool_call_segment - tool_call_segment.tool_output = str(tool_output) if tool_output is not None else None - tool_call_segment.tool_error = str(tool_error) if tool_error is not None else None - tool_call_segment.files = [str(f) for f in tool_files] if tool_files else [] + if tool_call_segment.tool_call is None: + tool_call_segment.tool_call = ToolCallResult( + id=tool_call_id, + name=tool_name, + arguments=None, + ) + tool_call_segment.tool_call.output = ( + str(tool_output) + if tool_output is not None + else str(tool_error) + if tool_error is not None + else None + ) + tool_call_segment.tool_call.files = [] + tool_call_segment.tool_call.status = ( + ToolResultStatus.ERROR if tool_error else ToolResultStatus.SUCCESS + ) current_turn += 1 + result_output = ( + str(tool_output) if tool_output is not None else str(tool_error) if tool_error else None + ) + yield ToolResultChunkEvent( selector=[self._node_id, "generation", "tool_results"], - chunk=str(tool_output) if tool_output else "", - tool_call_id=tool_call_id, - tool_name=tool_name, - tool_files=tool_files, - tool_error=tool_error, + chunk=result_output or "", + tool_result=ToolResult( + id=tool_call_id, + name=tool_name, + output=result_output, + files=tool_files, + status=ToolResultStatus.ERROR if tool_error else ToolResultStatus.SUCCESS, + ), is_final=False, ) @@ -1865,7 +1893,7 @@ class LLMNode(Node[LLMNodeData]): sequence.append({"type": "content", "start": start, "end": end}) content_position = end elif trace_segment.type == "tool_call": - tool_id = trace_segment.tool_call_id or "" + tool_id = trace_segment.tool_call.id if trace_segment.tool_call and trace_segment.tool_call.id else "" if tool_id not in tool_call_seen_index: tool_call_seen_index[tool_id] = len(tool_call_seen_index) sequence.append({"type": "tool_call", "index": tool_call_seen_index[tool_id]}) @@ -1893,9 +1921,11 @@ class LLMNode(Node[LLMNodeData]): yield ToolCallChunkEvent( selector=[self._node_id, "generation", "tool_calls"], chunk="", - tool_call_id="", - tool_name="", - tool_arguments="", + tool_call=ToolCall( + id="", + name="", + arguments="", + ), is_final=True, ) @@ -1903,33 +1933,40 @@ class LLMNode(Node[LLMNodeData]): yield ToolResultChunkEvent( selector=[self._node_id, "generation", "tool_results"], chunk="", - tool_call_id="", - tool_name="", - tool_files=[], - tool_error=None, + tool_result=ToolResult( + id="", + name="", + output="", + files=[], + status=ToolResultStatus.SUCCESS, + ), is_final=True, ) # Build tool_calls from agent_logs (with results) - tool_calls_for_generation = [] + tool_calls_for_generation: list[ToolCallResult] = [] for log in agent_logs: tool_call_id = log.data.get("tool_call_id") if not tool_call_id 
or log.status == AgentLog.LogStatus.START.value: continue tool_args = log.data.get("tool_args") or {} + log_error = log.data.get("error") + log_output = log.data.get("output") + result_text = log_output or log_error or "" + status = ToolResultStatus.ERROR if log_error else ToolResultStatus.SUCCESS tool_calls_for_generation.append( - { - "id": tool_call_id, - "name": log.data.get("tool_name", ""), - "arguments": json.dumps(tool_args) if tool_args else "", - # Prefer output, fall back to error text if present - "result": log.data.get("output") or log.data.get("error") or "", - } + ToolCallResult( + id=tool_call_id, + name=log.data.get("tool_name", ""), + arguments=json.dumps(tool_args) if tool_args else "", + output=result_text, + status=status, + ) ) tool_calls_for_generation.sort( - key=lambda item: tool_call_index_map.get(item.get("id", ""), len(tool_call_index_map)) + key=lambda item: tool_call_index_map.get(item.id or "", len(tool_call_index_map)) ) # Return generation data for caller diff --git a/api/models/workflow.py b/api/models/workflow.py index 89ec0352df..bc229fb4e4 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -683,10 +683,6 @@ class WorkflowRun(Base): def outputs_dict(self) -> Mapping[str, Any]: return json.loads(self.outputs) if self.outputs else {} - @property - def outputs_as_generation(self) -> bool: - return is_generation_outputs(self.outputs_dict) - @property def message(self): from .model import Message @@ -712,6 +708,7 @@ class WorkflowRun(Base): "inputs": self.inputs_dict, "status": self.status, "outputs": self.outputs_dict, + "outputs_as_generation": is_generation_outputs(self.outputs_dict), "error": self.error, "elapsed_time": self.elapsed_time, "total_tokens": self.total_tokens, diff --git a/api/services/workflow_run_service.py b/api/services/workflow_run_service.py index 1bc821c43d..b903d8df5f 100644 --- a/api/services/workflow_run_service.py +++ b/api/services/workflow_run_service.py @@ -102,14 +102,12 @@ class WorkflowRunService: :param app_model: app model :param run_id: workflow run id """ - workflow_run = self._workflow_run_repo.get_workflow_run_by_id( + return self._workflow_run_repo.get_workflow_run_by_id( tenant_id=app_model.tenant_id, app_id=app_model.id, run_id=run_id, ) - return workflow_run - def get_workflow_runs_count( self, app_model: App, diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_response_coordinator.py b/api/tests/unit_tests/core/workflow/graph_engine/test_response_coordinator.py index 388496ce1d..8e0eba71cc 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_response_coordinator.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_response_coordinator.py @@ -2,10 +2,16 @@ from unittest.mock import MagicMock +from core.workflow.entities import ToolResultStatus from core.workflow.enums import NodeType from core.workflow.graph import Graph from core.workflow.graph_engine.response_coordinator.coordinator import ResponseStreamCoordinator -from core.workflow.graph_events import ChunkType, NodeRunStreamChunkEvent +from core.workflow.graph_events import ( + ChunkType, + NodeRunStreamChunkEvent, + ToolCall, + ToolResult, +) from core.workflow.nodes.base.entities import BaseNodeData from core.workflow.runtime import VariablePool @@ -80,9 +86,11 @@ class TestResponseCoordinatorObjectStreaming: chunk='{"query": "test"}', is_final=True, chunk_type=ChunkType.TOOL_CALL, - tool_call_id="call_123", - tool_name="search", - tool_arguments='{"query": "test"}', + tool_call=ToolCall( + id="call_123", 
+ name="search", + arguments='{"query": "test"}', + ), ) # 3. Tool result stream @@ -94,10 +102,13 @@ class TestResponseCoordinatorObjectStreaming: chunk="Found 10 results", is_final=True, chunk_type=ChunkType.TOOL_RESULT, - tool_call_id="call_123", - tool_name="search", - tool_files=[], - tool_error=None, + tool_result=ToolResult( + id="call_123", + name="search", + output="Found 10 results", + files=[], + status=ToolResultStatus.SUCCESS, + ), ) # Intercept these events @@ -111,6 +122,14 @@ class TestResponseCoordinatorObjectStreaming: assert ("llm_node", "generation", "tool_calls") in coordinator._stream_buffers assert ("llm_node", "generation", "tool_results") in coordinator._stream_buffers + # Verify payloads are preserved in buffered events + buffered_call = coordinator._stream_buffers[("llm_node", "generation", "tool_calls")][0] + assert buffered_call.tool_call is not None + assert buffered_call.tool_call.id == "call_123" + buffered_result = coordinator._stream_buffers[("llm_node", "generation", "tool_results")][0] + assert buffered_result.tool_result is not None + assert buffered_result.tool_result.status == "success" + # Verify we can find child streams child_streams = coordinator._find_child_streams(["llm_node", "generation"]) assert len(child_streams) == 3 diff --git a/api/tests/unit_tests/core/workflow/node_events/test_stream_chunk_events.py b/api/tests/unit_tests/core/workflow/node_events/test_stream_chunk_events.py index f6e0834b1e..951149e933 100644 --- a/api/tests/unit_tests/core/workflow/node_events/test_stream_chunk_events.py +++ b/api/tests/unit_tests/core/workflow/node_events/test_stream_chunk_events.py @@ -1,5 +1,6 @@ """Tests for StreamChunkEvent and its subclasses.""" +from core.workflow.entities import ToolCall, ToolResult, ToolResultStatus from core.workflow.node_events import ( ChunkType, StreamChunkEvent, @@ -87,14 +88,13 @@ class TestToolCallChunkEvent: event = ToolCallChunkEvent( selector=["node1", "tool_calls"], chunk='{"city": "Beijing"}', - tool_call_id="call_123", - tool_name="weather", + tool_call=ToolCall(id="call_123", name="weather", arguments=None), ) assert event.selector == ["node1", "tool_calls"] assert event.chunk == '{"city": "Beijing"}' - assert event.tool_call_id == "call_123" - assert event.tool_name == "weather" + assert event.tool_call.id == "call_123" + assert event.tool_call.name == "weather" assert event.chunk_type == ChunkType.TOOL_CALL def test_chunk_type_is_tool_call(self): @@ -102,8 +102,7 @@ class TestToolCallChunkEvent: event = ToolCallChunkEvent( selector=["node1", "tool_calls"], chunk="", - tool_call_id="call_123", - tool_name="test_tool", + tool_call=ToolCall(id="call_123", name="test_tool", arguments=None), ) assert event.chunk_type == ChunkType.TOOL_CALL @@ -113,30 +112,34 @@ class TestToolCallChunkEvent: event = ToolCallChunkEvent( selector=["node1", "tool_calls"], chunk='{"param": "value"}', - tool_call_id="call_123", - tool_name="test_tool", - tool_arguments='{"param": "value"}', + tool_call=ToolCall( + id="call_123", + name="test_tool", + arguments='{"param": "value"}', + ), ) - assert event.tool_arguments == '{"param": "value"}' + assert event.tool_call.arguments == '{"param": "value"}' def test_serialization(self): """Test that event can be serialized to dict.""" event = ToolCallChunkEvent( selector=["node1", "tool_calls"], chunk='{"city": "Beijing"}', - tool_call_id="call_123", - tool_name="weather", - tool_arguments='{"city": "Beijing"}', + tool_call=ToolCall( + id="call_123", + name="weather", + arguments='{"city": 
"Beijing"}', + ), is_final=True, ) data = event.model_dump() assert data["chunk_type"] == "tool_call" - assert data["tool_call_id"] == "call_123" - assert data["tool_name"] == "weather" - assert data["tool_arguments"] == '{"city": "Beijing"}' + assert data["tool_call"]["id"] == "call_123" + assert data["tool_call"]["name"] == "weather" + assert data["tool_call"]["arguments"] == '{"city": "Beijing"}' assert data["is_final"] is True @@ -148,14 +151,13 @@ class TestToolResultChunkEvent: event = ToolResultChunkEvent( selector=["node1", "tool_results"], chunk="Weather: Sunny, 25°C", - tool_call_id="call_123", - tool_name="weather", + tool_result=ToolResult(id="call_123", name="weather", output="Weather: Sunny, 25°C"), ) assert event.selector == ["node1", "tool_results"] assert event.chunk == "Weather: Sunny, 25°C" - assert event.tool_call_id == "call_123" - assert event.tool_name == "weather" + assert event.tool_result.id == "call_123" + assert event.tool_result.name == "weather" assert event.chunk_type == ChunkType.TOOL_RESULT def test_chunk_type_is_tool_result(self): @@ -163,8 +165,7 @@ class TestToolResultChunkEvent: event = ToolResultChunkEvent( selector=["node1", "tool_results"], chunk="result", - tool_call_id="call_123", - tool_name="test_tool", + tool_result=ToolResult(id="call_123", name="test_tool"), ) assert event.chunk_type == ChunkType.TOOL_RESULT @@ -174,55 +175,62 @@ class TestToolResultChunkEvent: event = ToolResultChunkEvent( selector=["node1", "tool_results"], chunk="result", - tool_call_id="call_123", - tool_name="test_tool", + tool_result=ToolResult(id="call_123", name="test_tool"), ) - assert event.tool_files == [] + assert event.tool_result.files == [] def test_tool_files_with_values(self): """Test tool_files with file IDs.""" event = ToolResultChunkEvent( selector=["node1", "tool_results"], chunk="result", - tool_call_id="call_123", - tool_name="test_tool", - tool_files=["file_1", "file_2"], + tool_result=ToolResult( + id="call_123", + name="test_tool", + files=["file_1", "file_2"], + ), ) - assert event.tool_files == ["file_1", "file_2"] + assert event.tool_result.files == ["file_1", "file_2"] - def test_tool_error_field(self): - """Test tool_error field.""" + def test_tool_error_output(self): + """Test error output captured in tool_result.""" event = ToolResultChunkEvent( selector=["node1", "tool_results"], chunk="", - tool_call_id="call_123", - tool_name="test_tool", - tool_error="Tool execution failed", + tool_result=ToolResult( + id="call_123", + name="test_tool", + output="Tool execution failed", + status=ToolResultStatus.ERROR, + ), ) - assert event.tool_error == "Tool execution failed" + assert event.tool_result.output == "Tool execution failed" + assert event.tool_result.status == ToolResultStatus.ERROR def test_serialization(self): """Test that event can be serialized to dict.""" event = ToolResultChunkEvent( selector=["node1", "tool_results"], chunk="Weather: Sunny", - tool_call_id="call_123", - tool_name="weather", - tool_files=["file_1"], - tool_error=None, + tool_result=ToolResult( + id="call_123", + name="weather", + output="Weather: Sunny", + files=["file_1"], + status=ToolResultStatus.SUCCESS, + ), is_final=True, ) data = event.model_dump() assert data["chunk_type"] == "tool_result" - assert data["tool_call_id"] == "call_123" - assert data["tool_name"] == "weather" - assert data["tool_files"] == ["file_1"] - assert data["tool_error"] is None + assert data["tool_result"]["id"] == "call_123" + assert data["tool_result"]["name"] == "weather" + assert 
data["tool_result"]["files"] == ["file_1"] assert data["is_final"] is True @@ -272,8 +280,7 @@ class TestEventInheritance: event = ToolCallChunkEvent( selector=["node1", "tool_calls"], chunk="", - tool_call_id="call_123", - tool_name="test", + tool_call=ToolCall(id="call_123", name="test", arguments=None), ) assert isinstance(event, StreamChunkEvent) @@ -283,8 +290,7 @@ class TestEventInheritance: event = ToolResultChunkEvent( selector=["node1", "tool_results"], chunk="result", - tool_call_id="call_123", - tool_name="test", + tool_result=ToolResult(id="call_123", name="test"), ) assert isinstance(event, StreamChunkEvent) @@ -302,8 +308,16 @@ class TestEventInheritance: """Test that all events have common StreamChunkEvent fields.""" events = [ StreamChunkEvent(selector=["n", "t"], chunk="a"), - ToolCallChunkEvent(selector=["n", "t"], chunk="b", tool_call_id="1", tool_name="t"), - ToolResultChunkEvent(selector=["n", "t"], chunk="c", tool_call_id="1", tool_name="t"), + ToolCallChunkEvent( + selector=["n", "t"], + chunk="b", + tool_call=ToolCall(id="1", name="t", arguments=None), + ), + ToolResultChunkEvent( + selector=["n", "t"], + chunk="c", + tool_result=ToolResult(id="1", name="t"), + ), ThoughtChunkEvent(selector=["n", "t"], chunk="d"), ]