mirror of
https://github.com/langgenius/dify.git
synced 2025-12-19 17:27:16 -05:00
refactor(llm node): tool call tool result entity
This commit is contained in:
@@ -516,6 +516,14 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
if tts_publisher and queue_message:
|
||||
tts_publisher.publish(queue_message)
|
||||
|
||||
tool_call = event.tool_call
|
||||
tool_result = event.tool_result
|
||||
tool_payload = tool_call or tool_result
|
||||
tool_call_id = tool_payload.id if tool_payload and tool_payload.id else ""
|
||||
tool_name = tool_payload.name if tool_payload and tool_payload.name else ""
|
||||
tool_arguments = tool_call.arguments if tool_call and tool_call.arguments else ""
|
||||
tool_files = tool_result.files if tool_result else []
|
||||
|
||||
# Record stream event based on chunk type
|
||||
chunk_type = event.chunk_type or ChunkType.TEXT
|
||||
match chunk_type:
|
||||
@@ -525,13 +533,13 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
self._stream_buffer.record_thought_chunk(delta_text)
|
||||
case ChunkType.TOOL_CALL:
|
||||
self._stream_buffer.record_tool_call(
|
||||
tool_call_id=event.tool_call_id or "",
|
||||
tool_name=event.tool_name or "",
|
||||
tool_arguments=event.tool_arguments or "",
|
||||
tool_call_id=tool_call_id,
|
||||
tool_name=tool_name,
|
||||
tool_arguments=tool_arguments,
|
||||
)
|
||||
case ChunkType.TOOL_RESULT:
|
||||
self._stream_buffer.record_tool_result(
|
||||
tool_call_id=event.tool_call_id or "",
|
||||
tool_call_id=tool_call_id,
|
||||
result=delta_text,
|
||||
)
|
||||
|
||||
@@ -541,11 +549,10 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
message_id=self._message_id,
|
||||
from_variable_selector=event.from_variable_selector,
|
||||
chunk_type=event.chunk_type.value if event.chunk_type else None,
|
||||
tool_call_id=event.tool_call_id,
|
||||
tool_name=event.tool_name,
|
||||
tool_arguments=event.tool_arguments,
|
||||
tool_files=event.tool_files,
|
||||
tool_error=event.tool_error,
|
||||
tool_call_id=tool_call_id or None,
|
||||
tool_name=tool_name or None,
|
||||
tool_arguments=tool_arguments or None,
|
||||
tool_files=tool_files,
|
||||
)
|
||||
|
||||
def _handle_iteration_start_event(
|
||||
|
||||
@@ -484,6 +484,14 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
if delta_text is None:
|
||||
return
|
||||
|
||||
tool_call = event.tool_call
|
||||
tool_result = event.tool_result
|
||||
tool_payload = tool_call or tool_result
|
||||
tool_call_id = tool_payload.id if tool_payload and tool_payload.id else None
|
||||
tool_name = tool_payload.name if tool_payload and tool_payload.name else None
|
||||
tool_arguments = tool_call.arguments if tool_call else None
|
||||
tool_files = tool_result.files if tool_result else []
|
||||
|
||||
# only publish tts message at text chunk streaming
|
||||
if tts_publisher and queue_message:
|
||||
tts_publisher.publish(queue_message)
|
||||
@@ -492,11 +500,10 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
text=delta_text,
|
||||
from_variable_selector=event.from_variable_selector,
|
||||
chunk_type=event.chunk_type,
|
||||
tool_call_id=event.tool_call_id,
|
||||
tool_name=event.tool_name,
|
||||
tool_arguments=event.tool_arguments,
|
||||
tool_files=event.tool_files,
|
||||
tool_error=event.tool_error,
|
||||
tool_call_id=tool_call_id,
|
||||
tool_name=tool_name,
|
||||
tool_arguments=tool_arguments,
|
||||
tool_files=tool_files,
|
||||
)
|
||||
|
||||
def _handle_agent_log_event(self, event: QueueAgentLogEvent, **kwargs) -> Generator[StreamResponse, None, None]:
|
||||
|
||||
@@ -464,11 +464,8 @@ class WorkflowBasedAppRunner:
|
||||
in_iteration_id=event.in_iteration_id,
|
||||
in_loop_id=event.in_loop_id,
|
||||
chunk_type=QueueChunkType(event.chunk_type.value),
|
||||
tool_call_id=event.tool_call_id,
|
||||
tool_name=event.tool_name,
|
||||
tool_arguments=event.tool_arguments,
|
||||
tool_files=event.tool_files,
|
||||
tool_error=event.tool_error,
|
||||
tool_call=event.tool_call,
|
||||
tool_result=event.tool_result,
|
||||
)
|
||||
)
|
||||
elif isinstance(event, NodeRunRetrieverResourceEvent):
|
||||
|
||||
@@ -9,6 +9,7 @@ from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
|
||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||
from core.workflow.entities import AgentNodeStrategyInit
|
||||
from core.workflow.enums import WorkflowNodeExecutionMetadataKey
|
||||
from core.workflow.graph_events import ToolCall, ToolResult
|
||||
from core.workflow.nodes import NodeType
|
||||
|
||||
|
||||
@@ -204,19 +205,11 @@ class QueueTextChunkEvent(AppQueueEvent):
|
||||
chunk_type: ChunkType = ChunkType.TEXT
|
||||
"""type of the chunk"""
|
||||
|
||||
# Tool call fields (when chunk_type == TOOL_CALL)
|
||||
tool_call_id: str | None = None
|
||||
"""unique identifier for this tool call"""
|
||||
tool_name: str | None = None
|
||||
"""name of the tool being called"""
|
||||
tool_arguments: str | None = None
|
||||
"""accumulated tool arguments JSON"""
|
||||
|
||||
# Tool result fields (when chunk_type == TOOL_RESULT)
|
||||
tool_files: list[str] = Field(default_factory=list)
|
||||
"""file IDs produced by tool"""
|
||||
tool_error: str | None = None
|
||||
"""error message if tool failed"""
|
||||
# Tool streaming payloads
|
||||
tool_call: ToolCall | None = None
|
||||
"""structured tool call info"""
|
||||
tool_result: ToolResult | None = None
|
||||
"""structured tool result info"""
|
||||
|
||||
|
||||
class QueueAgentMessageEvent(AppQueueEvent):
|
||||
|
||||
@@ -1,11 +1,16 @@
|
||||
from .agent import AgentNodeStrategyInit
|
||||
from .graph_init_params import GraphInitParams
|
||||
from .tool_entities import ToolCall, ToolCallResult, ToolResult, ToolResultStatus
|
||||
from .workflow_execution import WorkflowExecution
|
||||
from .workflow_node_execution import WorkflowNodeExecution
|
||||
|
||||
__all__ = [
|
||||
"AgentNodeStrategyInit",
|
||||
"GraphInitParams",
|
||||
"ToolCall",
|
||||
"ToolCallResult",
|
||||
"ToolResult",
|
||||
"ToolResultStatus",
|
||||
"WorkflowExecution",
|
||||
"WorkflowNodeExecution",
|
||||
]
|
||||
|
||||
33
api/core/workflow/entities/tool_entities.py
Normal file
33
api/core/workflow/entities/tool_entities.py
Normal file
@@ -0,0 +1,33 @@
|
||||
from enum import StrEnum
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.file import File
|
||||
|
||||
|
||||
class ToolResultStatus(StrEnum):
|
||||
SUCCESS = "success"
|
||||
ERROR = "error"
|
||||
|
||||
|
||||
class ToolCall(BaseModel):
|
||||
id: str | None = Field(default=None, description="Unique identifier for this tool call")
|
||||
name: str | None = Field(default=None, description="Name of the tool being called")
|
||||
arguments: str | None = Field(default=None, description="Accumulated tool arguments JSON")
|
||||
|
||||
|
||||
class ToolResult(BaseModel):
|
||||
id: str | None = Field(default=None, description="Identifier of the tool call this result belongs to")
|
||||
name: str | None = Field(default=None, description="Name of the tool")
|
||||
output: str | None = Field(default=None, description="Tool output text, error or success message")
|
||||
files: list[str] = Field(default_factory=list, description="File produced by tool")
|
||||
status: ToolResultStatus | None = Field(default=ToolResultStatus.SUCCESS, description="Tool execution status")
|
||||
|
||||
|
||||
class ToolCallResult(BaseModel):
|
||||
id: str | None = Field(default=None, description="Identifier for the tool call")
|
||||
name: str | None = Field(default=None, description="Name of the tool")
|
||||
arguments: str | None = Field(default=None, description="Accumulated tool arguments JSON")
|
||||
output: str | None = Field(default=None, description="Tool output text, error or success message")
|
||||
files: list[File] = Field(default_factory=list, description="File produced by tool")
|
||||
status: ToolResultStatus = Field(default=ToolResultStatus.SUCCESS, description="Tool execution status")
|
||||
@@ -16,7 +16,13 @@ from pydantic import BaseModel, Field
|
||||
|
||||
from core.workflow.enums import NodeExecutionType, NodeState
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_events import NodeRunStreamChunkEvent, NodeRunSucceededEvent
|
||||
from core.workflow.graph_events import (
|
||||
ChunkType,
|
||||
NodeRunStreamChunkEvent,
|
||||
NodeRunSucceededEvent,
|
||||
ToolCall,
|
||||
ToolResult,
|
||||
)
|
||||
from core.workflow.nodes.base.template import TextSegment, VariableSegment
|
||||
from core.workflow.runtime import VariablePool
|
||||
|
||||
@@ -321,7 +327,9 @@ class ResponseStreamCoordinator:
|
||||
selector: Sequence[str],
|
||||
chunk: str,
|
||||
is_final: bool = False,
|
||||
**extra_fields,
|
||||
chunk_type: ChunkType = ChunkType.TEXT,
|
||||
tool_call: ToolCall | None = None,
|
||||
tool_result: ToolResult | None = None,
|
||||
) -> NodeRunStreamChunkEvent:
|
||||
"""Create a stream chunk event with consistent structure.
|
||||
|
||||
@@ -334,7 +342,9 @@ class ResponseStreamCoordinator:
|
||||
selector: The variable selector
|
||||
chunk: The chunk content
|
||||
is_final: Whether this is the final chunk
|
||||
**extra_fields: Additional fields for specialized events (chunk_type, tool_call_id, etc.)
|
||||
chunk_type: The semantic type of the chunk being streamed
|
||||
tool_call: Structured data for tool_call chunks
|
||||
tool_result: Structured data for tool_result chunks
|
||||
"""
|
||||
# Check if this is a special selector that doesn't correspond to a node
|
||||
if selector and selector[0] not in self._graph.nodes and self._active_session:
|
||||
@@ -347,7 +357,9 @@ class ResponseStreamCoordinator:
|
||||
selector=selector,
|
||||
chunk=chunk,
|
||||
is_final=is_final,
|
||||
**extra_fields,
|
||||
chunk_type=chunk_type,
|
||||
tool_call=tool_call,
|
||||
tool_result=tool_result,
|
||||
)
|
||||
|
||||
# Standard case: selector refers to an actual node
|
||||
@@ -359,7 +371,9 @@ class ResponseStreamCoordinator:
|
||||
selector=selector,
|
||||
chunk=chunk,
|
||||
is_final=is_final,
|
||||
**extra_fields,
|
||||
chunk_type=chunk_type,
|
||||
tool_call=tool_call,
|
||||
tool_result=tool_result,
|
||||
)
|
||||
|
||||
def _process_variable_segment(self, segment: VariableSegment) -> tuple[Sequence[NodeRunStreamChunkEvent], bool]:
|
||||
@@ -436,11 +450,8 @@ class ResponseStreamCoordinator:
|
||||
chunk=event.chunk,
|
||||
is_final=event.is_final,
|
||||
chunk_type=event.chunk_type,
|
||||
tool_call_id=event.tool_call_id,
|
||||
tool_name=event.tool_name,
|
||||
tool_arguments=event.tool_arguments,
|
||||
tool_files=event.tool_files,
|
||||
tool_error=event.tool_error,
|
||||
tool_call=event.tool_call,
|
||||
tool_result=event.tool_result,
|
||||
)
|
||||
events.append(updated_event)
|
||||
else:
|
||||
|
||||
@@ -45,6 +45,8 @@ from .node import (
|
||||
NodeRunStartedEvent,
|
||||
NodeRunStreamChunkEvent,
|
||||
NodeRunSucceededEvent,
|
||||
ToolCall,
|
||||
ToolResult,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
@@ -75,4 +77,6 @@ __all__ = [
|
||||
"NodeRunStartedEvent",
|
||||
"NodeRunStreamChunkEvent",
|
||||
"NodeRunSucceededEvent",
|
||||
"ToolCall",
|
||||
"ToolResult",
|
||||
]
|
||||
|
||||
@@ -5,7 +5,7 @@ from enum import StrEnum
|
||||
from pydantic import Field
|
||||
|
||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||
from core.workflow.entities import AgentNodeStrategyInit
|
||||
from core.workflow.entities import AgentNodeStrategyInit, ToolCall, ToolResult
|
||||
from core.workflow.entities.pause_reason import PauseReason
|
||||
|
||||
from .base import GraphNodeEventBase
|
||||
@@ -43,13 +43,16 @@ class NodeRunStreamChunkEvent(GraphNodeEventBase):
|
||||
chunk_type: ChunkType = Field(default=ChunkType.TEXT, description="type of the chunk")
|
||||
|
||||
# Tool call fields (when chunk_type == TOOL_CALL)
|
||||
tool_call_id: str | None = Field(default=None, description="unique identifier for this tool call")
|
||||
tool_name: str | None = Field(default=None, description="name of the tool being called")
|
||||
tool_arguments: str | None = Field(default=None, description="accumulated tool arguments JSON")
|
||||
tool_call: ToolCall | None = Field(
|
||||
default=None,
|
||||
description="structured payload for tool_call chunks",
|
||||
)
|
||||
|
||||
# Tool result fields (when chunk_type == TOOL_RESULT)
|
||||
tool_files: list[str] = Field(default_factory=list, description="file IDs produced by tool")
|
||||
tool_error: str | None = Field(default=None, description="error message if tool failed")
|
||||
tool_result: ToolResult | None = Field(
|
||||
default=None,
|
||||
description="structured payload for tool_result chunks",
|
||||
)
|
||||
|
||||
|
||||
class NodeRunRetrieverResourceEvent(GraphNodeEventBase):
|
||||
|
||||
@@ -7,6 +7,7 @@ from pydantic import Field
|
||||
from core.file import File
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||
from core.workflow.entities import ToolCall, ToolResult
|
||||
from core.workflow.entities.pause_reason import PauseReason
|
||||
from core.workflow.node_events import NodeRunResult
|
||||
|
||||
@@ -51,25 +52,22 @@ class StreamChunkEvent(NodeEventBase):
|
||||
chunk: str = Field(..., description="the actual chunk content")
|
||||
is_final: bool = Field(default=False, description="indicates if this is the last chunk")
|
||||
chunk_type: ChunkType = Field(default=ChunkType.TEXT, description="type of the chunk")
|
||||
tool_call: ToolCall | None = Field(default=None, description="structured payload for tool_call chunks")
|
||||
tool_result: ToolResult | None = Field(default=None, description="structured payload for tool_result chunks")
|
||||
|
||||
|
||||
class ToolCallChunkEvent(StreamChunkEvent):
|
||||
"""Tool call streaming event - tool call arguments streaming output."""
|
||||
|
||||
chunk_type: ChunkType = Field(default=ChunkType.TOOL_CALL, frozen=True)
|
||||
tool_call_id: str = Field(..., description="unique identifier for this tool call")
|
||||
tool_name: str = Field(..., description="name of the tool being called")
|
||||
tool_arguments: str = Field(default="", description="accumulated tool arguments JSON")
|
||||
tool_call: ToolCall | None = Field(default=None, description="structured tool call payload")
|
||||
|
||||
|
||||
class ToolResultChunkEvent(StreamChunkEvent):
|
||||
"""Tool result event - tool execution result."""
|
||||
|
||||
chunk_type: ChunkType = Field(default=ChunkType.TOOL_RESULT, frozen=True)
|
||||
tool_call_id: str = Field(..., description="identifier of the tool call this result belongs to")
|
||||
tool_name: str = Field(..., description="name of the tool")
|
||||
tool_files: list[str] = Field(default_factory=list, description="file IDs produced by tool")
|
||||
tool_error: str | None = Field(default=None, description="error message if tool failed")
|
||||
tool_result: ToolResult | None = Field(default=None, description="structured tool result payload")
|
||||
|
||||
|
||||
class ThoughtChunkEvent(StreamChunkEvent):
|
||||
|
||||
@@ -556,6 +556,8 @@ class Node(Generic[NodeDataT]):
|
||||
chunk=event.chunk,
|
||||
is_final=event.is_final,
|
||||
chunk_type=ChunkType(event.chunk_type.value),
|
||||
tool_call=event.tool_call,
|
||||
tool_result=event.tool_result,
|
||||
)
|
||||
|
||||
@_dispatch.register
|
||||
@@ -570,14 +572,18 @@ class Node(Generic[NodeDataT]):
|
||||
chunk=event.chunk,
|
||||
is_final=event.is_final,
|
||||
chunk_type=ChunkType.TOOL_CALL,
|
||||
tool_call_id=event.tool_call_id,
|
||||
tool_name=event.tool_name,
|
||||
tool_arguments=event.tool_arguments,
|
||||
tool_call=event.tool_call,
|
||||
)
|
||||
|
||||
@_dispatch.register
|
||||
def _(self, event: ToolResultChunkEvent) -> NodeRunStreamChunkEvent:
|
||||
from core.workflow.graph_events import ChunkType
|
||||
from core.workflow.entities import ToolResult
|
||||
from core.workflow.graph_events import ChunkType, ToolResultStatus
|
||||
|
||||
tool_result = event.tool_result
|
||||
status: ToolResultStatus = (
|
||||
tool_result.status if tool_result and tool_result.status is not None else ToolResultStatus.SUCCESS
|
||||
)
|
||||
|
||||
return NodeRunStreamChunkEvent(
|
||||
id=self._node_execution_id,
|
||||
@@ -587,10 +593,13 @@ class Node(Generic[NodeDataT]):
|
||||
chunk=event.chunk,
|
||||
is_final=event.is_final,
|
||||
chunk_type=ChunkType.TOOL_RESULT,
|
||||
tool_call_id=event.tool_call_id,
|
||||
tool_name=event.tool_name,
|
||||
tool_files=event.tool_files,
|
||||
tool_error=event.tool_error,
|
||||
tool_result=ToolResult(
|
||||
id=tool_result.id if tool_result else None,
|
||||
name=tool_result.name if tool_result else None,
|
||||
output=tool_result.output if tool_result else None,
|
||||
files=tool_result.files if tool_result else [],
|
||||
status=status,
|
||||
),
|
||||
)
|
||||
|
||||
@_dispatch.register
|
||||
|
||||
@@ -8,6 +8,7 @@ from core.model_runtime.entities import ImagePromptMessageContent, LLMMode
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
|
||||
from core.tools.entities.tool_entities import ToolProviderType
|
||||
from core.workflow.entities import ToolCallResult
|
||||
from core.workflow.nodes.base import BaseNodeData
|
||||
from core.workflow.nodes.base.entities import VariableSelector
|
||||
|
||||
@@ -33,12 +34,10 @@ class LLMTraceSegment(BaseModel):
|
||||
text: str | None = Field(None, description="Text chunk for thought/content")
|
||||
|
||||
# Tool call fields (combined start + result)
|
||||
tool_call_id: str | None = None
|
||||
tool_name: str | None = None
|
||||
tool_arguments: str | None = None
|
||||
tool_output: str | None = None
|
||||
tool_error: str | None = None
|
||||
files: list[str] = Field(default_factory=list, description="File IDs from tool result if any")
|
||||
tool_call: ToolCallResult | None = Field(
|
||||
default=None,
|
||||
description="Combined tool call arguments and result for this segment",
|
||||
)
|
||||
|
||||
|
||||
class LLMGenerationData(BaseModel):
|
||||
@@ -51,7 +50,7 @@ class LLMGenerationData(BaseModel):
|
||||
|
||||
text: str = Field(..., description="Accumulated text content from all turns")
|
||||
reasoning_contents: list[str] = Field(default_factory=list, description="Reasoning content per turn")
|
||||
tool_calls: list[dict[str, Any]] = Field(default_factory=list, description="Tool calls with results")
|
||||
tool_calls: list[ToolCallResult] = Field(default_factory=list, description="Tool calls with results")
|
||||
sequence: list[dict[str, Any]] = Field(default_factory=list, description="Ordered segments for rendering")
|
||||
usage: LLMUsage = Field(..., description="LLM usage statistics")
|
||||
finish_reason: str | None = Field(None, description="Finish reason from LLM")
|
||||
|
||||
@@ -60,7 +60,8 @@ from core.variables import (
|
||||
StringSegment,
|
||||
)
|
||||
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.entities import GraphInitParams, ToolCall, ToolResult, ToolResultStatus
|
||||
from core.workflow.entities.tool_entities import ToolCallResult
|
||||
from core.workflow.enums import (
|
||||
NodeType,
|
||||
SystemVariableKey,
|
||||
@@ -1671,9 +1672,11 @@ class LLMNode(Node[LLMNodeData]):
|
||||
tool_call_segment = LLMTraceSegment(
|
||||
type="tool_call",
|
||||
text=None,
|
||||
tool_call_id=tool_call_id,
|
||||
tool_name=tool_name,
|
||||
tool_arguments=tool_arguments,
|
||||
tool_call=ToolCallResult(
|
||||
id=tool_call_id,
|
||||
name=tool_name,
|
||||
arguments=tool_arguments,
|
||||
),
|
||||
)
|
||||
trace_segments.append(tool_call_segment)
|
||||
if tool_call_id:
|
||||
@@ -1682,9 +1685,11 @@ class LLMNode(Node[LLMNodeData]):
|
||||
yield ToolCallChunkEvent(
|
||||
selector=[self._node_id, "generation", "tool_calls"],
|
||||
chunk=tool_arguments,
|
||||
tool_call_id=tool_call_id,
|
||||
tool_name=tool_name,
|
||||
tool_arguments=tool_arguments,
|
||||
tool_call=ToolCall(
|
||||
id=tool_call_id,
|
||||
name=tool_name,
|
||||
arguments=tool_arguments,
|
||||
),
|
||||
is_final=False,
|
||||
)
|
||||
|
||||
@@ -1724,27 +1729,50 @@ class LLMNode(Node[LLMNodeData]):
|
||||
tool_call_segment = existing_tool_segment or LLMTraceSegment(
|
||||
type="tool_call",
|
||||
text=None,
|
||||
tool_call_id=tool_call_id,
|
||||
tool_name=tool_name,
|
||||
tool_arguments=None,
|
||||
tool_call=ToolCallResult(
|
||||
id=tool_call_id,
|
||||
name=tool_name,
|
||||
arguments=None,
|
||||
),
|
||||
)
|
||||
if existing_tool_segment is None:
|
||||
trace_segments.append(tool_call_segment)
|
||||
if tool_call_id:
|
||||
tool_trace_map[tool_call_id] = tool_call_segment
|
||||
|
||||
tool_call_segment.tool_output = str(tool_output) if tool_output is not None else None
|
||||
tool_call_segment.tool_error = str(tool_error) if tool_error is not None else None
|
||||
tool_call_segment.files = [str(f) for f in tool_files] if tool_files else []
|
||||
if tool_call_segment.tool_call is None:
|
||||
tool_call_segment.tool_call = ToolCallResult(
|
||||
id=tool_call_id,
|
||||
name=tool_name,
|
||||
arguments=None,
|
||||
)
|
||||
tool_call_segment.tool_call.output = (
|
||||
str(tool_output)
|
||||
if tool_output is not None
|
||||
else str(tool_error)
|
||||
if tool_error is not None
|
||||
else None
|
||||
)
|
||||
tool_call_segment.tool_call.files = []
|
||||
tool_call_segment.tool_call.status = (
|
||||
ToolResultStatus.ERROR if tool_error else ToolResultStatus.SUCCESS
|
||||
)
|
||||
current_turn += 1
|
||||
|
||||
result_output = (
|
||||
str(tool_output) if tool_output is not None else str(tool_error) if tool_error else None
|
||||
)
|
||||
|
||||
yield ToolResultChunkEvent(
|
||||
selector=[self._node_id, "generation", "tool_results"],
|
||||
chunk=str(tool_output) if tool_output else "",
|
||||
tool_call_id=tool_call_id,
|
||||
tool_name=tool_name,
|
||||
tool_files=tool_files,
|
||||
tool_error=tool_error,
|
||||
chunk=result_output or "",
|
||||
tool_result=ToolResult(
|
||||
id=tool_call_id,
|
||||
name=tool_name,
|
||||
output=result_output,
|
||||
files=tool_files,
|
||||
status=ToolResultStatus.ERROR if tool_error else ToolResultStatus.SUCCESS,
|
||||
),
|
||||
is_final=False,
|
||||
)
|
||||
|
||||
@@ -1865,7 +1893,7 @@ class LLMNode(Node[LLMNodeData]):
|
||||
sequence.append({"type": "content", "start": start, "end": end})
|
||||
content_position = end
|
||||
elif trace_segment.type == "tool_call":
|
||||
tool_id = trace_segment.tool_call_id or ""
|
||||
tool_id = trace_segment.tool_call.id if trace_segment.tool_call and trace_segment.tool_call.id else ""
|
||||
if tool_id not in tool_call_seen_index:
|
||||
tool_call_seen_index[tool_id] = len(tool_call_seen_index)
|
||||
sequence.append({"type": "tool_call", "index": tool_call_seen_index[tool_id]})
|
||||
@@ -1893,9 +1921,11 @@ class LLMNode(Node[LLMNodeData]):
|
||||
yield ToolCallChunkEvent(
|
||||
selector=[self._node_id, "generation", "tool_calls"],
|
||||
chunk="",
|
||||
tool_call_id="",
|
||||
tool_name="",
|
||||
tool_arguments="",
|
||||
tool_call=ToolCall(
|
||||
id="",
|
||||
name="",
|
||||
arguments="",
|
||||
),
|
||||
is_final=True,
|
||||
)
|
||||
|
||||
@@ -1903,33 +1933,40 @@ class LLMNode(Node[LLMNodeData]):
|
||||
yield ToolResultChunkEvent(
|
||||
selector=[self._node_id, "generation", "tool_results"],
|
||||
chunk="",
|
||||
tool_call_id="",
|
||||
tool_name="",
|
||||
tool_files=[],
|
||||
tool_error=None,
|
||||
tool_result=ToolResult(
|
||||
id="",
|
||||
name="",
|
||||
output="",
|
||||
files=[],
|
||||
status=ToolResultStatus.SUCCESS,
|
||||
),
|
||||
is_final=True,
|
||||
)
|
||||
|
||||
# Build tool_calls from agent_logs (with results)
|
||||
tool_calls_for_generation = []
|
||||
tool_calls_for_generation: list[ToolCallResult] = []
|
||||
for log in agent_logs:
|
||||
tool_call_id = log.data.get("tool_call_id")
|
||||
if not tool_call_id or log.status == AgentLog.LogStatus.START.value:
|
||||
continue
|
||||
|
||||
tool_args = log.data.get("tool_args") or {}
|
||||
log_error = log.data.get("error")
|
||||
log_output = log.data.get("output")
|
||||
result_text = log_output or log_error or ""
|
||||
status = ToolResultStatus.ERROR if log_error else ToolResultStatus.SUCCESS
|
||||
tool_calls_for_generation.append(
|
||||
{
|
||||
"id": tool_call_id,
|
||||
"name": log.data.get("tool_name", ""),
|
||||
"arguments": json.dumps(tool_args) if tool_args else "",
|
||||
# Prefer output, fall back to error text if present
|
||||
"result": log.data.get("output") or log.data.get("error") or "",
|
||||
}
|
||||
ToolCallResult(
|
||||
id=tool_call_id,
|
||||
name=log.data.get("tool_name", ""),
|
||||
arguments=json.dumps(tool_args) if tool_args else "",
|
||||
output=result_text,
|
||||
status=status,
|
||||
)
|
||||
)
|
||||
|
||||
tool_calls_for_generation.sort(
|
||||
key=lambda item: tool_call_index_map.get(item.get("id", ""), len(tool_call_index_map))
|
||||
key=lambda item: tool_call_index_map.get(item.id or "", len(tool_call_index_map))
|
||||
)
|
||||
|
||||
# Return generation data for caller
|
||||
|
||||
@@ -683,10 +683,6 @@ class WorkflowRun(Base):
|
||||
def outputs_dict(self) -> Mapping[str, Any]:
|
||||
return json.loads(self.outputs) if self.outputs else {}
|
||||
|
||||
@property
|
||||
def outputs_as_generation(self) -> bool:
|
||||
return is_generation_outputs(self.outputs_dict)
|
||||
|
||||
@property
|
||||
def message(self):
|
||||
from .model import Message
|
||||
@@ -712,6 +708,7 @@ class WorkflowRun(Base):
|
||||
"inputs": self.inputs_dict,
|
||||
"status": self.status,
|
||||
"outputs": self.outputs_dict,
|
||||
"outputs_as_generation": is_generation_outputs(self.outputs_dict),
|
||||
"error": self.error,
|
||||
"elapsed_time": self.elapsed_time,
|
||||
"total_tokens": self.total_tokens,
|
||||
|
||||
@@ -102,14 +102,12 @@ class WorkflowRunService:
|
||||
:param app_model: app model
|
||||
:param run_id: workflow run id
|
||||
"""
|
||||
workflow_run = self._workflow_run_repo.get_workflow_run_by_id(
|
||||
return self._workflow_run_repo.get_workflow_run_by_id(
|
||||
tenant_id=app_model.tenant_id,
|
||||
app_id=app_model.id,
|
||||
run_id=run_id,
|
||||
)
|
||||
|
||||
return workflow_run
|
||||
|
||||
def get_workflow_runs_count(
|
||||
self,
|
||||
app_model: App,
|
||||
|
||||
@@ -2,10 +2,16 @@
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from core.workflow.entities import ToolResultStatus
|
||||
from core.workflow.enums import NodeType
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_engine.response_coordinator.coordinator import ResponseStreamCoordinator
|
||||
from core.workflow.graph_events import ChunkType, NodeRunStreamChunkEvent
|
||||
from core.workflow.graph_events import (
|
||||
ChunkType,
|
||||
NodeRunStreamChunkEvent,
|
||||
ToolCall,
|
||||
ToolResult,
|
||||
)
|
||||
from core.workflow.nodes.base.entities import BaseNodeData
|
||||
from core.workflow.runtime import VariablePool
|
||||
|
||||
@@ -80,9 +86,11 @@ class TestResponseCoordinatorObjectStreaming:
|
||||
chunk='{"query": "test"}',
|
||||
is_final=True,
|
||||
chunk_type=ChunkType.TOOL_CALL,
|
||||
tool_call_id="call_123",
|
||||
tool_name="search",
|
||||
tool_arguments='{"query": "test"}',
|
||||
tool_call=ToolCall(
|
||||
id="call_123",
|
||||
name="search",
|
||||
arguments='{"query": "test"}',
|
||||
),
|
||||
)
|
||||
|
||||
# 3. Tool result stream
|
||||
@@ -94,10 +102,13 @@ class TestResponseCoordinatorObjectStreaming:
|
||||
chunk="Found 10 results",
|
||||
is_final=True,
|
||||
chunk_type=ChunkType.TOOL_RESULT,
|
||||
tool_call_id="call_123",
|
||||
tool_name="search",
|
||||
tool_files=[],
|
||||
tool_error=None,
|
||||
tool_result=ToolResult(
|
||||
id="call_123",
|
||||
name="search",
|
||||
output="Found 10 results",
|
||||
files=[],
|
||||
status=ToolResultStatus.SUCCESS,
|
||||
),
|
||||
)
|
||||
|
||||
# Intercept these events
|
||||
@@ -111,6 +122,14 @@ class TestResponseCoordinatorObjectStreaming:
|
||||
assert ("llm_node", "generation", "tool_calls") in coordinator._stream_buffers
|
||||
assert ("llm_node", "generation", "tool_results") in coordinator._stream_buffers
|
||||
|
||||
# Verify payloads are preserved in buffered events
|
||||
buffered_call = coordinator._stream_buffers[("llm_node", "generation", "tool_calls")][0]
|
||||
assert buffered_call.tool_call is not None
|
||||
assert buffered_call.tool_call.id == "call_123"
|
||||
buffered_result = coordinator._stream_buffers[("llm_node", "generation", "tool_results")][0]
|
||||
assert buffered_result.tool_result is not None
|
||||
assert buffered_result.tool_result.status == "success"
|
||||
|
||||
# Verify we can find child streams
|
||||
child_streams = coordinator._find_child_streams(["llm_node", "generation"])
|
||||
assert len(child_streams) == 3
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
"""Tests for StreamChunkEvent and its subclasses."""
|
||||
|
||||
from core.workflow.entities import ToolCall, ToolResult, ToolResultStatus
|
||||
from core.workflow.node_events import (
|
||||
ChunkType,
|
||||
StreamChunkEvent,
|
||||
@@ -87,14 +88,13 @@ class TestToolCallChunkEvent:
|
||||
event = ToolCallChunkEvent(
|
||||
selector=["node1", "tool_calls"],
|
||||
chunk='{"city": "Beijing"}',
|
||||
tool_call_id="call_123",
|
||||
tool_name="weather",
|
||||
tool_call=ToolCall(id="call_123", name="weather", arguments=None),
|
||||
)
|
||||
|
||||
assert event.selector == ["node1", "tool_calls"]
|
||||
assert event.chunk == '{"city": "Beijing"}'
|
||||
assert event.tool_call_id == "call_123"
|
||||
assert event.tool_name == "weather"
|
||||
assert event.tool_call.id == "call_123"
|
||||
assert event.tool_call.name == "weather"
|
||||
assert event.chunk_type == ChunkType.TOOL_CALL
|
||||
|
||||
def test_chunk_type_is_tool_call(self):
|
||||
@@ -102,8 +102,7 @@ class TestToolCallChunkEvent:
|
||||
event = ToolCallChunkEvent(
|
||||
selector=["node1", "tool_calls"],
|
||||
chunk="",
|
||||
tool_call_id="call_123",
|
||||
tool_name="test_tool",
|
||||
tool_call=ToolCall(id="call_123", name="test_tool", arguments=None),
|
||||
)
|
||||
|
||||
assert event.chunk_type == ChunkType.TOOL_CALL
|
||||
@@ -113,30 +112,34 @@ class TestToolCallChunkEvent:
|
||||
event = ToolCallChunkEvent(
|
||||
selector=["node1", "tool_calls"],
|
||||
chunk='{"param": "value"}',
|
||||
tool_call_id="call_123",
|
||||
tool_name="test_tool",
|
||||
tool_arguments='{"param": "value"}',
|
||||
tool_call=ToolCall(
|
||||
id="call_123",
|
||||
name="test_tool",
|
||||
arguments='{"param": "value"}',
|
||||
),
|
||||
)
|
||||
|
||||
assert event.tool_arguments == '{"param": "value"}'
|
||||
assert event.tool_call.arguments == '{"param": "value"}'
|
||||
|
||||
def test_serialization(self):
|
||||
"""Test that event can be serialized to dict."""
|
||||
event = ToolCallChunkEvent(
|
||||
selector=["node1", "tool_calls"],
|
||||
chunk='{"city": "Beijing"}',
|
||||
tool_call_id="call_123",
|
||||
tool_name="weather",
|
||||
tool_arguments='{"city": "Beijing"}',
|
||||
tool_call=ToolCall(
|
||||
id="call_123",
|
||||
name="weather",
|
||||
arguments='{"city": "Beijing"}',
|
||||
),
|
||||
is_final=True,
|
||||
)
|
||||
|
||||
data = event.model_dump()
|
||||
|
||||
assert data["chunk_type"] == "tool_call"
|
||||
assert data["tool_call_id"] == "call_123"
|
||||
assert data["tool_name"] == "weather"
|
||||
assert data["tool_arguments"] == '{"city": "Beijing"}'
|
||||
assert data["tool_call"]["id"] == "call_123"
|
||||
assert data["tool_call"]["name"] == "weather"
|
||||
assert data["tool_call"]["arguments"] == '{"city": "Beijing"}'
|
||||
assert data["is_final"] is True
|
||||
|
||||
|
||||
@@ -148,14 +151,13 @@ class TestToolResultChunkEvent:
|
||||
event = ToolResultChunkEvent(
|
||||
selector=["node1", "tool_results"],
|
||||
chunk="Weather: Sunny, 25°C",
|
||||
tool_call_id="call_123",
|
||||
tool_name="weather",
|
||||
tool_result=ToolResult(id="call_123", name="weather", output="Weather: Sunny, 25°C"),
|
||||
)
|
||||
|
||||
assert event.selector == ["node1", "tool_results"]
|
||||
assert event.chunk == "Weather: Sunny, 25°C"
|
||||
assert event.tool_call_id == "call_123"
|
||||
assert event.tool_name == "weather"
|
||||
assert event.tool_result.id == "call_123"
|
||||
assert event.tool_result.name == "weather"
|
||||
assert event.chunk_type == ChunkType.TOOL_RESULT
|
||||
|
||||
def test_chunk_type_is_tool_result(self):
|
||||
@@ -163,8 +165,7 @@ class TestToolResultChunkEvent:
|
||||
event = ToolResultChunkEvent(
|
||||
selector=["node1", "tool_results"],
|
||||
chunk="result",
|
||||
tool_call_id="call_123",
|
||||
tool_name="test_tool",
|
||||
tool_result=ToolResult(id="call_123", name="test_tool"),
|
||||
)
|
||||
|
||||
assert event.chunk_type == ChunkType.TOOL_RESULT
|
||||
@@ -174,55 +175,62 @@ class TestToolResultChunkEvent:
|
||||
event = ToolResultChunkEvent(
|
||||
selector=["node1", "tool_results"],
|
||||
chunk="result",
|
||||
tool_call_id="call_123",
|
||||
tool_name="test_tool",
|
||||
tool_result=ToolResult(id="call_123", name="test_tool"),
|
||||
)
|
||||
|
||||
assert event.tool_files == []
|
||||
assert event.tool_result.files == []
|
||||
|
||||
def test_tool_files_with_values(self):
|
||||
"""Test tool_files with file IDs."""
|
||||
event = ToolResultChunkEvent(
|
||||
selector=["node1", "tool_results"],
|
||||
chunk="result",
|
||||
tool_call_id="call_123",
|
||||
tool_name="test_tool",
|
||||
tool_files=["file_1", "file_2"],
|
||||
tool_result=ToolResult(
|
||||
id="call_123",
|
||||
name="test_tool",
|
||||
files=["file_1", "file_2"],
|
||||
),
|
||||
)
|
||||
|
||||
assert event.tool_files == ["file_1", "file_2"]
|
||||
assert event.tool_result.files == ["file_1", "file_2"]
|
||||
|
||||
def test_tool_error_field(self):
|
||||
"""Test tool_error field."""
|
||||
def test_tool_error_output(self):
|
||||
"""Test error output captured in tool_result."""
|
||||
event = ToolResultChunkEvent(
|
||||
selector=["node1", "tool_results"],
|
||||
chunk="",
|
||||
tool_call_id="call_123",
|
||||
tool_name="test_tool",
|
||||
tool_error="Tool execution failed",
|
||||
tool_result=ToolResult(
|
||||
id="call_123",
|
||||
name="test_tool",
|
||||
output="Tool execution failed",
|
||||
status=ToolResultStatus.ERROR,
|
||||
),
|
||||
)
|
||||
|
||||
assert event.tool_error == "Tool execution failed"
|
||||
assert event.tool_result.output == "Tool execution failed"
|
||||
assert event.tool_result.status == ToolResultStatus.ERROR
|
||||
|
||||
def test_serialization(self):
|
||||
"""Test that event can be serialized to dict."""
|
||||
event = ToolResultChunkEvent(
|
||||
selector=["node1", "tool_results"],
|
||||
chunk="Weather: Sunny",
|
||||
tool_call_id="call_123",
|
||||
tool_name="weather",
|
||||
tool_files=["file_1"],
|
||||
tool_error=None,
|
||||
tool_result=ToolResult(
|
||||
id="call_123",
|
||||
name="weather",
|
||||
output="Weather: Sunny",
|
||||
files=["file_1"],
|
||||
status=ToolResultStatus.SUCCESS,
|
||||
),
|
||||
is_final=True,
|
||||
)
|
||||
|
||||
data = event.model_dump()
|
||||
|
||||
assert data["chunk_type"] == "tool_result"
|
||||
assert data["tool_call_id"] == "call_123"
|
||||
assert data["tool_name"] == "weather"
|
||||
assert data["tool_files"] == ["file_1"]
|
||||
assert data["tool_error"] is None
|
||||
assert data["tool_result"]["id"] == "call_123"
|
||||
assert data["tool_result"]["name"] == "weather"
|
||||
assert data["tool_result"]["files"] == ["file_1"]
|
||||
assert data["is_final"] is True
|
||||
|
||||
|
||||
@@ -272,8 +280,7 @@ class TestEventInheritance:
|
||||
event = ToolCallChunkEvent(
|
||||
selector=["node1", "tool_calls"],
|
||||
chunk="",
|
||||
tool_call_id="call_123",
|
||||
tool_name="test",
|
||||
tool_call=ToolCall(id="call_123", name="test", arguments=None),
|
||||
)
|
||||
|
||||
assert isinstance(event, StreamChunkEvent)
|
||||
@@ -283,8 +290,7 @@ class TestEventInheritance:
|
||||
event = ToolResultChunkEvent(
|
||||
selector=["node1", "tool_results"],
|
||||
chunk="result",
|
||||
tool_call_id="call_123",
|
||||
tool_name="test",
|
||||
tool_result=ToolResult(id="call_123", name="test"),
|
||||
)
|
||||
|
||||
assert isinstance(event, StreamChunkEvent)
|
||||
@@ -302,8 +308,16 @@ class TestEventInheritance:
|
||||
"""Test that all events have common StreamChunkEvent fields."""
|
||||
events = [
|
||||
StreamChunkEvent(selector=["n", "t"], chunk="a"),
|
||||
ToolCallChunkEvent(selector=["n", "t"], chunk="b", tool_call_id="1", tool_name="t"),
|
||||
ToolResultChunkEvent(selector=["n", "t"], chunk="c", tool_call_id="1", tool_name="t"),
|
||||
ToolCallChunkEvent(
|
||||
selector=["n", "t"],
|
||||
chunk="b",
|
||||
tool_call=ToolCall(id="1", name="t", arguments=None),
|
||||
),
|
||||
ToolResultChunkEvent(
|
||||
selector=["n", "t"],
|
||||
chunk="c",
|
||||
tool_result=ToolResult(id="1", name="t"),
|
||||
),
|
||||
ThoughtChunkEvent(selector=["n", "t"], chunk="d"),
|
||||
]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user