refactor(llm node): tool call tool result entity

This commit is contained in:
Novice
2025-12-17 10:30:21 +08:00
parent dd0a870969
commit d3486cab31
17 changed files with 300 additions and 169 deletions

View File

@@ -516,6 +516,14 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
if tts_publisher and queue_message:
tts_publisher.publish(queue_message)
tool_call = event.tool_call
tool_result = event.tool_result
tool_payload = tool_call or tool_result
tool_call_id = tool_payload.id if tool_payload and tool_payload.id else ""
tool_name = tool_payload.name if tool_payload and tool_payload.name else ""
tool_arguments = tool_call.arguments if tool_call and tool_call.arguments else ""
tool_files = tool_result.files if tool_result else []
# Record stream event based on chunk type
chunk_type = event.chunk_type or ChunkType.TEXT
match chunk_type:
@@ -525,13 +533,13 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
self._stream_buffer.record_thought_chunk(delta_text)
case ChunkType.TOOL_CALL:
self._stream_buffer.record_tool_call(
tool_call_id=event.tool_call_id or "",
tool_name=event.tool_name or "",
tool_arguments=event.tool_arguments or "",
tool_call_id=tool_call_id,
tool_name=tool_name,
tool_arguments=tool_arguments,
)
case ChunkType.TOOL_RESULT:
self._stream_buffer.record_tool_result(
tool_call_id=event.tool_call_id or "",
tool_call_id=tool_call_id,
result=delta_text,
)
@@ -541,11 +549,10 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
message_id=self._message_id,
from_variable_selector=event.from_variable_selector,
chunk_type=event.chunk_type.value if event.chunk_type else None,
tool_call_id=event.tool_call_id,
tool_name=event.tool_name,
tool_arguments=event.tool_arguments,
tool_files=event.tool_files,
tool_error=event.tool_error,
tool_call_id=tool_call_id or None,
tool_name=tool_name or None,
tool_arguments=tool_arguments or None,
tool_files=tool_files,
)
def _handle_iteration_start_event(

View File

@@ -484,6 +484,14 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
if delta_text is None:
return
tool_call = event.tool_call
tool_result = event.tool_result
tool_payload = tool_call or tool_result
tool_call_id = tool_payload.id if tool_payload and tool_payload.id else None
tool_name = tool_payload.name if tool_payload and tool_payload.name else None
tool_arguments = tool_call.arguments if tool_call else None
tool_files = tool_result.files if tool_result else []
# only publish tts message at text chunk streaming
if tts_publisher and queue_message:
tts_publisher.publish(queue_message)
@@ -492,11 +500,10 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
text=delta_text,
from_variable_selector=event.from_variable_selector,
chunk_type=event.chunk_type,
tool_call_id=event.tool_call_id,
tool_name=event.tool_name,
tool_arguments=event.tool_arguments,
tool_files=event.tool_files,
tool_error=event.tool_error,
tool_call_id=tool_call_id,
tool_name=tool_name,
tool_arguments=tool_arguments,
tool_files=tool_files,
)
def _handle_agent_log_event(self, event: QueueAgentLogEvent, **kwargs) -> Generator[StreamResponse, None, None]:

View File

@@ -464,11 +464,8 @@ class WorkflowBasedAppRunner:
in_iteration_id=event.in_iteration_id,
in_loop_id=event.in_loop_id,
chunk_type=QueueChunkType(event.chunk_type.value),
tool_call_id=event.tool_call_id,
tool_name=event.tool_name,
tool_arguments=event.tool_arguments,
tool_files=event.tool_files,
tool_error=event.tool_error,
tool_call=event.tool_call,
tool_result=event.tool_result,
)
)
elif isinstance(event, NodeRunRetrieverResourceEvent):

View File

@@ -9,6 +9,7 @@ from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.workflow.entities import AgentNodeStrategyInit
from core.workflow.enums import WorkflowNodeExecutionMetadataKey
from core.workflow.graph_events import ToolCall, ToolResult
from core.workflow.nodes import NodeType
@@ -204,19 +205,11 @@ class QueueTextChunkEvent(AppQueueEvent):
chunk_type: ChunkType = ChunkType.TEXT
"""type of the chunk"""
# Tool call fields (when chunk_type == TOOL_CALL)
tool_call_id: str | None = None
"""unique identifier for this tool call"""
tool_name: str | None = None
"""name of the tool being called"""
tool_arguments: str | None = None
"""accumulated tool arguments JSON"""
# Tool result fields (when chunk_type == TOOL_RESULT)
tool_files: list[str] = Field(default_factory=list)
"""file IDs produced by tool"""
tool_error: str | None = None
"""error message if tool failed"""
# Tool streaming payloads
tool_call: ToolCall | None = None
"""structured tool call info"""
tool_result: ToolResult | None = None
"""structured tool result info"""
class QueueAgentMessageEvent(AppQueueEvent):

View File

@@ -1,11 +1,16 @@
from .agent import AgentNodeStrategyInit
from .graph_init_params import GraphInitParams
from .tool_entities import ToolCall, ToolCallResult, ToolResult, ToolResultStatus
from .workflow_execution import WorkflowExecution
from .workflow_node_execution import WorkflowNodeExecution
__all__ = [
"AgentNodeStrategyInit",
"GraphInitParams",
"ToolCall",
"ToolCallResult",
"ToolResult",
"ToolResultStatus",
"WorkflowExecution",
"WorkflowNodeExecution",
]

View File

@@ -0,0 +1,33 @@
from enum import StrEnum
from pydantic import BaseModel, Field
from core.file import File
class ToolResultStatus(StrEnum):
SUCCESS = "success"
ERROR = "error"
class ToolCall(BaseModel):
id: str | None = Field(default=None, description="Unique identifier for this tool call")
name: str | None = Field(default=None, description="Name of the tool being called")
arguments: str | None = Field(default=None, description="Accumulated tool arguments JSON")
class ToolResult(BaseModel):
id: str | None = Field(default=None, description="Identifier of the tool call this result belongs to")
name: str | None = Field(default=None, description="Name of the tool")
output: str | None = Field(default=None, description="Tool output text, error or success message")
files: list[str] = Field(default_factory=list, description="File produced by tool")
status: ToolResultStatus | None = Field(default=ToolResultStatus.SUCCESS, description="Tool execution status")
class ToolCallResult(BaseModel):
id: str | None = Field(default=None, description="Identifier for the tool call")
name: str | None = Field(default=None, description="Name of the tool")
arguments: str | None = Field(default=None, description="Accumulated tool arguments JSON")
output: str | None = Field(default=None, description="Tool output text, error or success message")
files: list[File] = Field(default_factory=list, description="File produced by tool")
status: ToolResultStatus = Field(default=ToolResultStatus.SUCCESS, description="Tool execution status")

View File

@@ -16,7 +16,13 @@ from pydantic import BaseModel, Field
from core.workflow.enums import NodeExecutionType, NodeState
from core.workflow.graph import Graph
from core.workflow.graph_events import NodeRunStreamChunkEvent, NodeRunSucceededEvent
from core.workflow.graph_events import (
ChunkType,
NodeRunStreamChunkEvent,
NodeRunSucceededEvent,
ToolCall,
ToolResult,
)
from core.workflow.nodes.base.template import TextSegment, VariableSegment
from core.workflow.runtime import VariablePool
@@ -321,7 +327,9 @@ class ResponseStreamCoordinator:
selector: Sequence[str],
chunk: str,
is_final: bool = False,
**extra_fields,
chunk_type: ChunkType = ChunkType.TEXT,
tool_call: ToolCall | None = None,
tool_result: ToolResult | None = None,
) -> NodeRunStreamChunkEvent:
"""Create a stream chunk event with consistent structure.
@@ -334,7 +342,9 @@ class ResponseStreamCoordinator:
selector: The variable selector
chunk: The chunk content
is_final: Whether this is the final chunk
**extra_fields: Additional fields for specialized events (chunk_type, tool_call_id, etc.)
chunk_type: The semantic type of the chunk being streamed
tool_call: Structured data for tool_call chunks
tool_result: Structured data for tool_result chunks
"""
# Check if this is a special selector that doesn't correspond to a node
if selector and selector[0] not in self._graph.nodes and self._active_session:
@@ -347,7 +357,9 @@ class ResponseStreamCoordinator:
selector=selector,
chunk=chunk,
is_final=is_final,
**extra_fields,
chunk_type=chunk_type,
tool_call=tool_call,
tool_result=tool_result,
)
# Standard case: selector refers to an actual node
@@ -359,7 +371,9 @@ class ResponseStreamCoordinator:
selector=selector,
chunk=chunk,
is_final=is_final,
**extra_fields,
chunk_type=chunk_type,
tool_call=tool_call,
tool_result=tool_result,
)
def _process_variable_segment(self, segment: VariableSegment) -> tuple[Sequence[NodeRunStreamChunkEvent], bool]:
@@ -436,11 +450,8 @@ class ResponseStreamCoordinator:
chunk=event.chunk,
is_final=event.is_final,
chunk_type=event.chunk_type,
tool_call_id=event.tool_call_id,
tool_name=event.tool_name,
tool_arguments=event.tool_arguments,
tool_files=event.tool_files,
tool_error=event.tool_error,
tool_call=event.tool_call,
tool_result=event.tool_result,
)
events.append(updated_event)
else:

View File

@@ -45,6 +45,8 @@ from .node import (
NodeRunStartedEvent,
NodeRunStreamChunkEvent,
NodeRunSucceededEvent,
ToolCall,
ToolResult,
)
__all__ = [
@@ -75,4 +77,6 @@ __all__ = [
"NodeRunStartedEvent",
"NodeRunStreamChunkEvent",
"NodeRunSucceededEvent",
"ToolCall",
"ToolResult",
]

View File

@@ -5,7 +5,7 @@ from enum import StrEnum
from pydantic import Field
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.workflow.entities import AgentNodeStrategyInit
from core.workflow.entities import AgentNodeStrategyInit, ToolCall, ToolResult
from core.workflow.entities.pause_reason import PauseReason
from .base import GraphNodeEventBase
@@ -43,13 +43,16 @@ class NodeRunStreamChunkEvent(GraphNodeEventBase):
chunk_type: ChunkType = Field(default=ChunkType.TEXT, description="type of the chunk")
# Tool call fields (when chunk_type == TOOL_CALL)
tool_call_id: str | None = Field(default=None, description="unique identifier for this tool call")
tool_name: str | None = Field(default=None, description="name of the tool being called")
tool_arguments: str | None = Field(default=None, description="accumulated tool arguments JSON")
tool_call: ToolCall | None = Field(
default=None,
description="structured payload for tool_call chunks",
)
# Tool result fields (when chunk_type == TOOL_RESULT)
tool_files: list[str] = Field(default_factory=list, description="file IDs produced by tool")
tool_error: str | None = Field(default=None, description="error message if tool failed")
tool_result: ToolResult | None = Field(
default=None,
description="structured payload for tool_result chunks",
)
class NodeRunRetrieverResourceEvent(GraphNodeEventBase):

View File

@@ -7,6 +7,7 @@ from pydantic import Field
from core.file import File
from core.model_runtime.entities.llm_entities import LLMUsage
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.workflow.entities import ToolCall, ToolResult
from core.workflow.entities.pause_reason import PauseReason
from core.workflow.node_events import NodeRunResult
@@ -51,25 +52,22 @@ class StreamChunkEvent(NodeEventBase):
chunk: str = Field(..., description="the actual chunk content")
is_final: bool = Field(default=False, description="indicates if this is the last chunk")
chunk_type: ChunkType = Field(default=ChunkType.TEXT, description="type of the chunk")
tool_call: ToolCall | None = Field(default=None, description="structured payload for tool_call chunks")
tool_result: ToolResult | None = Field(default=None, description="structured payload for tool_result chunks")
class ToolCallChunkEvent(StreamChunkEvent):
"""Tool call streaming event - tool call arguments streaming output."""
chunk_type: ChunkType = Field(default=ChunkType.TOOL_CALL, frozen=True)
tool_call_id: str = Field(..., description="unique identifier for this tool call")
tool_name: str = Field(..., description="name of the tool being called")
tool_arguments: str = Field(default="", description="accumulated tool arguments JSON")
tool_call: ToolCall | None = Field(default=None, description="structured tool call payload")
class ToolResultChunkEvent(StreamChunkEvent):
"""Tool result event - tool execution result."""
chunk_type: ChunkType = Field(default=ChunkType.TOOL_RESULT, frozen=True)
tool_call_id: str = Field(..., description="identifier of the tool call this result belongs to")
tool_name: str = Field(..., description="name of the tool")
tool_files: list[str] = Field(default_factory=list, description="file IDs produced by tool")
tool_error: str | None = Field(default=None, description="error message if tool failed")
tool_result: ToolResult | None = Field(default=None, description="structured tool result payload")
class ThoughtChunkEvent(StreamChunkEvent):

View File

@@ -556,6 +556,8 @@ class Node(Generic[NodeDataT]):
chunk=event.chunk,
is_final=event.is_final,
chunk_type=ChunkType(event.chunk_type.value),
tool_call=event.tool_call,
tool_result=event.tool_result,
)
@_dispatch.register
@@ -570,14 +572,18 @@ class Node(Generic[NodeDataT]):
chunk=event.chunk,
is_final=event.is_final,
chunk_type=ChunkType.TOOL_CALL,
tool_call_id=event.tool_call_id,
tool_name=event.tool_name,
tool_arguments=event.tool_arguments,
tool_call=event.tool_call,
)
@_dispatch.register
def _(self, event: ToolResultChunkEvent) -> NodeRunStreamChunkEvent:
from core.workflow.graph_events import ChunkType
from core.workflow.entities import ToolResult
from core.workflow.graph_events import ChunkType, ToolResultStatus
tool_result = event.tool_result
status: ToolResultStatus = (
tool_result.status if tool_result and tool_result.status is not None else ToolResultStatus.SUCCESS
)
return NodeRunStreamChunkEvent(
id=self._node_execution_id,
@@ -587,10 +593,13 @@ class Node(Generic[NodeDataT]):
chunk=event.chunk,
is_final=event.is_final,
chunk_type=ChunkType.TOOL_RESULT,
tool_call_id=event.tool_call_id,
tool_name=event.tool_name,
tool_files=event.tool_files,
tool_error=event.tool_error,
tool_result=ToolResult(
id=tool_result.id if tool_result else None,
name=tool_result.name if tool_result else None,
output=tool_result.output if tool_result else None,
files=tool_result.files if tool_result else [],
status=status,
),
)
@_dispatch.register

View File

@@ -8,6 +8,7 @@ from core.model_runtime.entities import ImagePromptMessageContent, LLMMode
from core.model_runtime.entities.llm_entities import LLMUsage
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
from core.tools.entities.tool_entities import ToolProviderType
from core.workflow.entities import ToolCallResult
from core.workflow.nodes.base import BaseNodeData
from core.workflow.nodes.base.entities import VariableSelector
@@ -33,12 +34,10 @@ class LLMTraceSegment(BaseModel):
text: str | None = Field(None, description="Text chunk for thought/content")
# Tool call fields (combined start + result)
tool_call_id: str | None = None
tool_name: str | None = None
tool_arguments: str | None = None
tool_output: str | None = None
tool_error: str | None = None
files: list[str] = Field(default_factory=list, description="File IDs from tool result if any")
tool_call: ToolCallResult | None = Field(
default=None,
description="Combined tool call arguments and result for this segment",
)
class LLMGenerationData(BaseModel):
@@ -51,7 +50,7 @@ class LLMGenerationData(BaseModel):
text: str = Field(..., description="Accumulated text content from all turns")
reasoning_contents: list[str] = Field(default_factory=list, description="Reasoning content per turn")
tool_calls: list[dict[str, Any]] = Field(default_factory=list, description="Tool calls with results")
tool_calls: list[ToolCallResult] = Field(default_factory=list, description="Tool calls with results")
sequence: list[dict[str, Any]] = Field(default_factory=list, description="Ordered segments for rendering")
usage: LLMUsage = Field(..., description="LLM usage statistics")
finish_reason: str | None = Field(None, description="Finish reason from LLM")

View File

@@ -60,7 +60,8 @@ from core.variables import (
StringSegment,
)
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
from core.workflow.entities import GraphInitParams
from core.workflow.entities import GraphInitParams, ToolCall, ToolResult, ToolResultStatus
from core.workflow.entities.tool_entities import ToolCallResult
from core.workflow.enums import (
NodeType,
SystemVariableKey,
@@ -1671,9 +1672,11 @@ class LLMNode(Node[LLMNodeData]):
tool_call_segment = LLMTraceSegment(
type="tool_call",
text=None,
tool_call_id=tool_call_id,
tool_name=tool_name,
tool_arguments=tool_arguments,
tool_call=ToolCallResult(
id=tool_call_id,
name=tool_name,
arguments=tool_arguments,
),
)
trace_segments.append(tool_call_segment)
if tool_call_id:
@@ -1682,9 +1685,11 @@ class LLMNode(Node[LLMNodeData]):
yield ToolCallChunkEvent(
selector=[self._node_id, "generation", "tool_calls"],
chunk=tool_arguments,
tool_call_id=tool_call_id,
tool_name=tool_name,
tool_arguments=tool_arguments,
tool_call=ToolCall(
id=tool_call_id,
name=tool_name,
arguments=tool_arguments,
),
is_final=False,
)
@@ -1724,27 +1729,50 @@ class LLMNode(Node[LLMNodeData]):
tool_call_segment = existing_tool_segment or LLMTraceSegment(
type="tool_call",
text=None,
tool_call_id=tool_call_id,
tool_name=tool_name,
tool_arguments=None,
tool_call=ToolCallResult(
id=tool_call_id,
name=tool_name,
arguments=None,
),
)
if existing_tool_segment is None:
trace_segments.append(tool_call_segment)
if tool_call_id:
tool_trace_map[tool_call_id] = tool_call_segment
tool_call_segment.tool_output = str(tool_output) if tool_output is not None else None
tool_call_segment.tool_error = str(tool_error) if tool_error is not None else None
tool_call_segment.files = [str(f) for f in tool_files] if tool_files else []
if tool_call_segment.tool_call is None:
tool_call_segment.tool_call = ToolCallResult(
id=tool_call_id,
name=tool_name,
arguments=None,
)
tool_call_segment.tool_call.output = (
str(tool_output)
if tool_output is not None
else str(tool_error)
if tool_error is not None
else None
)
tool_call_segment.tool_call.files = []
tool_call_segment.tool_call.status = (
ToolResultStatus.ERROR if tool_error else ToolResultStatus.SUCCESS
)
current_turn += 1
result_output = (
str(tool_output) if tool_output is not None else str(tool_error) if tool_error else None
)
yield ToolResultChunkEvent(
selector=[self._node_id, "generation", "tool_results"],
chunk=str(tool_output) if tool_output else "",
tool_call_id=tool_call_id,
tool_name=tool_name,
tool_files=tool_files,
tool_error=tool_error,
chunk=result_output or "",
tool_result=ToolResult(
id=tool_call_id,
name=tool_name,
output=result_output,
files=tool_files,
status=ToolResultStatus.ERROR if tool_error else ToolResultStatus.SUCCESS,
),
is_final=False,
)
@@ -1865,7 +1893,7 @@ class LLMNode(Node[LLMNodeData]):
sequence.append({"type": "content", "start": start, "end": end})
content_position = end
elif trace_segment.type == "tool_call":
tool_id = trace_segment.tool_call_id or ""
tool_id = trace_segment.tool_call.id if trace_segment.tool_call and trace_segment.tool_call.id else ""
if tool_id not in tool_call_seen_index:
tool_call_seen_index[tool_id] = len(tool_call_seen_index)
sequence.append({"type": "tool_call", "index": tool_call_seen_index[tool_id]})
@@ -1893,9 +1921,11 @@ class LLMNode(Node[LLMNodeData]):
yield ToolCallChunkEvent(
selector=[self._node_id, "generation", "tool_calls"],
chunk="",
tool_call_id="",
tool_name="",
tool_arguments="",
tool_call=ToolCall(
id="",
name="",
arguments="",
),
is_final=True,
)
@@ -1903,33 +1933,40 @@ class LLMNode(Node[LLMNodeData]):
yield ToolResultChunkEvent(
selector=[self._node_id, "generation", "tool_results"],
chunk="",
tool_call_id="",
tool_name="",
tool_files=[],
tool_error=None,
tool_result=ToolResult(
id="",
name="",
output="",
files=[],
status=ToolResultStatus.SUCCESS,
),
is_final=True,
)
# Build tool_calls from agent_logs (with results)
tool_calls_for_generation = []
tool_calls_for_generation: list[ToolCallResult] = []
for log in agent_logs:
tool_call_id = log.data.get("tool_call_id")
if not tool_call_id or log.status == AgentLog.LogStatus.START.value:
continue
tool_args = log.data.get("tool_args") or {}
log_error = log.data.get("error")
log_output = log.data.get("output")
result_text = log_output or log_error or ""
status = ToolResultStatus.ERROR if log_error else ToolResultStatus.SUCCESS
tool_calls_for_generation.append(
{
"id": tool_call_id,
"name": log.data.get("tool_name", ""),
"arguments": json.dumps(tool_args) if tool_args else "",
# Prefer output, fall back to error text if present
"result": log.data.get("output") or log.data.get("error") or "",
}
ToolCallResult(
id=tool_call_id,
name=log.data.get("tool_name", ""),
arguments=json.dumps(tool_args) if tool_args else "",
output=result_text,
status=status,
)
)
tool_calls_for_generation.sort(
key=lambda item: tool_call_index_map.get(item.get("id", ""), len(tool_call_index_map))
key=lambda item: tool_call_index_map.get(item.id or "", len(tool_call_index_map))
)
# Return generation data for caller

View File

@@ -683,10 +683,6 @@ class WorkflowRun(Base):
def outputs_dict(self) -> Mapping[str, Any]:
return json.loads(self.outputs) if self.outputs else {}
@property
def outputs_as_generation(self) -> bool:
return is_generation_outputs(self.outputs_dict)
@property
def message(self):
from .model import Message
@@ -712,6 +708,7 @@ class WorkflowRun(Base):
"inputs": self.inputs_dict,
"status": self.status,
"outputs": self.outputs_dict,
"outputs_as_generation": is_generation_outputs(self.outputs_dict),
"error": self.error,
"elapsed_time": self.elapsed_time,
"total_tokens": self.total_tokens,

View File

@@ -102,14 +102,12 @@ class WorkflowRunService:
:param app_model: app model
:param run_id: workflow run id
"""
workflow_run = self._workflow_run_repo.get_workflow_run_by_id(
return self._workflow_run_repo.get_workflow_run_by_id(
tenant_id=app_model.tenant_id,
app_id=app_model.id,
run_id=run_id,
)
return workflow_run
def get_workflow_runs_count(
self,
app_model: App,

View File

@@ -2,10 +2,16 @@
from unittest.mock import MagicMock
from core.workflow.entities import ToolResultStatus
from core.workflow.enums import NodeType
from core.workflow.graph import Graph
from core.workflow.graph_engine.response_coordinator.coordinator import ResponseStreamCoordinator
from core.workflow.graph_events import ChunkType, NodeRunStreamChunkEvent
from core.workflow.graph_events import (
ChunkType,
NodeRunStreamChunkEvent,
ToolCall,
ToolResult,
)
from core.workflow.nodes.base.entities import BaseNodeData
from core.workflow.runtime import VariablePool
@@ -80,9 +86,11 @@ class TestResponseCoordinatorObjectStreaming:
chunk='{"query": "test"}',
is_final=True,
chunk_type=ChunkType.TOOL_CALL,
tool_call_id="call_123",
tool_name="search",
tool_arguments='{"query": "test"}',
tool_call=ToolCall(
id="call_123",
name="search",
arguments='{"query": "test"}',
),
)
# 3. Tool result stream
@@ -94,10 +102,13 @@ class TestResponseCoordinatorObjectStreaming:
chunk="Found 10 results",
is_final=True,
chunk_type=ChunkType.TOOL_RESULT,
tool_call_id="call_123",
tool_name="search",
tool_files=[],
tool_error=None,
tool_result=ToolResult(
id="call_123",
name="search",
output="Found 10 results",
files=[],
status=ToolResultStatus.SUCCESS,
),
)
# Intercept these events
@@ -111,6 +122,14 @@ class TestResponseCoordinatorObjectStreaming:
assert ("llm_node", "generation", "tool_calls") in coordinator._stream_buffers
assert ("llm_node", "generation", "tool_results") in coordinator._stream_buffers
# Verify payloads are preserved in buffered events
buffered_call = coordinator._stream_buffers[("llm_node", "generation", "tool_calls")][0]
assert buffered_call.tool_call is not None
assert buffered_call.tool_call.id == "call_123"
buffered_result = coordinator._stream_buffers[("llm_node", "generation", "tool_results")][0]
assert buffered_result.tool_result is not None
assert buffered_result.tool_result.status == "success"
# Verify we can find child streams
child_streams = coordinator._find_child_streams(["llm_node", "generation"])
assert len(child_streams) == 3

View File

@@ -1,5 +1,6 @@
"""Tests for StreamChunkEvent and its subclasses."""
from core.workflow.entities import ToolCall, ToolResult, ToolResultStatus
from core.workflow.node_events import (
ChunkType,
StreamChunkEvent,
@@ -87,14 +88,13 @@ class TestToolCallChunkEvent:
event = ToolCallChunkEvent(
selector=["node1", "tool_calls"],
chunk='{"city": "Beijing"}',
tool_call_id="call_123",
tool_name="weather",
tool_call=ToolCall(id="call_123", name="weather", arguments=None),
)
assert event.selector == ["node1", "tool_calls"]
assert event.chunk == '{"city": "Beijing"}'
assert event.tool_call_id == "call_123"
assert event.tool_name == "weather"
assert event.tool_call.id == "call_123"
assert event.tool_call.name == "weather"
assert event.chunk_type == ChunkType.TOOL_CALL
def test_chunk_type_is_tool_call(self):
@@ -102,8 +102,7 @@ class TestToolCallChunkEvent:
event = ToolCallChunkEvent(
selector=["node1", "tool_calls"],
chunk="",
tool_call_id="call_123",
tool_name="test_tool",
tool_call=ToolCall(id="call_123", name="test_tool", arguments=None),
)
assert event.chunk_type == ChunkType.TOOL_CALL
@@ -113,30 +112,34 @@ class TestToolCallChunkEvent:
event = ToolCallChunkEvent(
selector=["node1", "tool_calls"],
chunk='{"param": "value"}',
tool_call_id="call_123",
tool_name="test_tool",
tool_arguments='{"param": "value"}',
tool_call=ToolCall(
id="call_123",
name="test_tool",
arguments='{"param": "value"}',
),
)
assert event.tool_arguments == '{"param": "value"}'
assert event.tool_call.arguments == '{"param": "value"}'
def test_serialization(self):
"""Test that event can be serialized to dict."""
event = ToolCallChunkEvent(
selector=["node1", "tool_calls"],
chunk='{"city": "Beijing"}',
tool_call_id="call_123",
tool_name="weather",
tool_arguments='{"city": "Beijing"}',
tool_call=ToolCall(
id="call_123",
name="weather",
arguments='{"city": "Beijing"}',
),
is_final=True,
)
data = event.model_dump()
assert data["chunk_type"] == "tool_call"
assert data["tool_call_id"] == "call_123"
assert data["tool_name"] == "weather"
assert data["tool_arguments"] == '{"city": "Beijing"}'
assert data["tool_call"]["id"] == "call_123"
assert data["tool_call"]["name"] == "weather"
assert data["tool_call"]["arguments"] == '{"city": "Beijing"}'
assert data["is_final"] is True
@@ -148,14 +151,13 @@ class TestToolResultChunkEvent:
event = ToolResultChunkEvent(
selector=["node1", "tool_results"],
chunk="Weather: Sunny, 25°C",
tool_call_id="call_123",
tool_name="weather",
tool_result=ToolResult(id="call_123", name="weather", output="Weather: Sunny, 25°C"),
)
assert event.selector == ["node1", "tool_results"]
assert event.chunk == "Weather: Sunny, 25°C"
assert event.tool_call_id == "call_123"
assert event.tool_name == "weather"
assert event.tool_result.id == "call_123"
assert event.tool_result.name == "weather"
assert event.chunk_type == ChunkType.TOOL_RESULT
def test_chunk_type_is_tool_result(self):
@@ -163,8 +165,7 @@ class TestToolResultChunkEvent:
event = ToolResultChunkEvent(
selector=["node1", "tool_results"],
chunk="result",
tool_call_id="call_123",
tool_name="test_tool",
tool_result=ToolResult(id="call_123", name="test_tool"),
)
assert event.chunk_type == ChunkType.TOOL_RESULT
@@ -174,55 +175,62 @@ class TestToolResultChunkEvent:
event = ToolResultChunkEvent(
selector=["node1", "tool_results"],
chunk="result",
tool_call_id="call_123",
tool_name="test_tool",
tool_result=ToolResult(id="call_123", name="test_tool"),
)
assert event.tool_files == []
assert event.tool_result.files == []
def test_tool_files_with_values(self):
"""Test tool_files with file IDs."""
event = ToolResultChunkEvent(
selector=["node1", "tool_results"],
chunk="result",
tool_call_id="call_123",
tool_name="test_tool",
tool_files=["file_1", "file_2"],
tool_result=ToolResult(
id="call_123",
name="test_tool",
files=["file_1", "file_2"],
),
)
assert event.tool_files == ["file_1", "file_2"]
assert event.tool_result.files == ["file_1", "file_2"]
def test_tool_error_field(self):
"""Test tool_error field."""
def test_tool_error_output(self):
"""Test error output captured in tool_result."""
event = ToolResultChunkEvent(
selector=["node1", "tool_results"],
chunk="",
tool_call_id="call_123",
tool_name="test_tool",
tool_error="Tool execution failed",
tool_result=ToolResult(
id="call_123",
name="test_tool",
output="Tool execution failed",
status=ToolResultStatus.ERROR,
),
)
assert event.tool_error == "Tool execution failed"
assert event.tool_result.output == "Tool execution failed"
assert event.tool_result.status == ToolResultStatus.ERROR
def test_serialization(self):
"""Test that event can be serialized to dict."""
event = ToolResultChunkEvent(
selector=["node1", "tool_results"],
chunk="Weather: Sunny",
tool_call_id="call_123",
tool_name="weather",
tool_files=["file_1"],
tool_error=None,
tool_result=ToolResult(
id="call_123",
name="weather",
output="Weather: Sunny",
files=["file_1"],
status=ToolResultStatus.SUCCESS,
),
is_final=True,
)
data = event.model_dump()
assert data["chunk_type"] == "tool_result"
assert data["tool_call_id"] == "call_123"
assert data["tool_name"] == "weather"
assert data["tool_files"] == ["file_1"]
assert data["tool_error"] is None
assert data["tool_result"]["id"] == "call_123"
assert data["tool_result"]["name"] == "weather"
assert data["tool_result"]["files"] == ["file_1"]
assert data["is_final"] is True
@@ -272,8 +280,7 @@ class TestEventInheritance:
event = ToolCallChunkEvent(
selector=["node1", "tool_calls"],
chunk="",
tool_call_id="call_123",
tool_name="test",
tool_call=ToolCall(id="call_123", name="test", arguments=None),
)
assert isinstance(event, StreamChunkEvent)
@@ -283,8 +290,7 @@ class TestEventInheritance:
event = ToolResultChunkEvent(
selector=["node1", "tool_results"],
chunk="result",
tool_call_id="call_123",
tool_name="test",
tool_result=ToolResult(id="call_123", name="test"),
)
assert isinstance(event, StreamChunkEvent)
@@ -302,8 +308,16 @@ class TestEventInheritance:
"""Test that all events have common StreamChunkEvent fields."""
events = [
StreamChunkEvent(selector=["n", "t"], chunk="a"),
ToolCallChunkEvent(selector=["n", "t"], chunk="b", tool_call_id="1", tool_name="t"),
ToolResultChunkEvent(selector=["n", "t"], chunk="c", tool_call_id="1", tool_name="t"),
ToolCallChunkEvent(
selector=["n", "t"],
chunk="b",
tool_call=ToolCall(id="1", name="t", arguments=None),
),
ToolResultChunkEvent(
selector=["n", "t"],
chunk="c",
tool_result=ToolResult(id="1", name="t"),
),
ThoughtChunkEvent(selector=["n", "t"], chunk="d"),
]