refactor(llm node): remove unused args

Author: Novice
Date: 2025-12-17 15:42:23 +08:00
Parent: d3486cab31
Commit: 92fa7271ed
3 changed files with 502 additions and 454 deletions


@@ -1,14 +1,17 @@
+import re
from collections.abc import Mapping, Sequence
from typing import Any, Literal
-from pydantic import BaseModel, Field, field_validator
+from pydantic import BaseModel, ConfigDict, Field, field_validator
+from core.agent.entities import AgentLog, AgentResult
from core.file import File
from core.model_runtime.entities import ImagePromptMessageContent, LLMMode
from core.model_runtime.entities.llm_entities import LLMUsage
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
from core.tools.entities.tool_entities import ToolProviderType
from core.workflow.entities import ToolCallResult
+from core.workflow.node_events import AgentLogEvent
from core.workflow.nodes.base import BaseNodeData
from core.workflow.nodes.base.entities import VariableSelector
@@ -20,44 +23,6 @@ class ModelConfig(BaseModel):
completion_params: dict[str, Any] = Field(default_factory=dict)
-class LLMTraceSegment(BaseModel):
-"""
-Streaming trace segment for LLM tool-enabled runs.
-Order is preserved for replay. Tool calls are single entries containing both
-arguments and results.
-"""
-type: Literal["thought", "content", "tool_call"]
-# Common optional fields
-text: str | None = Field(None, description="Text chunk for thought/content")
-# Tool call fields (combined start + result)
-tool_call: ToolCallResult | None = Field(
-default=None,
-description="Combined tool call arguments and result for this segment",
-)
-class LLMGenerationData(BaseModel):
-"""Generation data from LLM invocation with tools.
-For multi-turn tool calls like: thought1 -> text1 -> tool_call1 -> thought2 -> text2 -> tool_call2
-- reasoning_contents: [thought1, thought2, ...] - one element per turn
-- tool_calls: [{id, name, arguments, result}, ...] - all tool calls with results
-"""
-text: str = Field(..., description="Accumulated text content from all turns")
-reasoning_contents: list[str] = Field(default_factory=list, description="Reasoning content per turn")
-tool_calls: list[ToolCallResult] = Field(default_factory=list, description="Tool calls with results")
-sequence: list[dict[str, Any]] = Field(default_factory=list, description="Ordered segments for rendering")
-usage: LLMUsage = Field(..., description="LLM usage statistics")
-finish_reason: str | None = Field(None, description="Finish reason from LLM")
-files: list[File] = Field(default_factory=list, description="Generated files")
-trace: list[LLMTraceSegment] = Field(default_factory=list, description="Streaming trace in emitted order")
class ContextConfig(BaseModel):
enabled: bool
variable_selector: list[str] | None = None
@@ -124,6 +89,211 @@ class ToolMetadata(BaseModel):
extra: dict[str, Any] = Field(default_factory=dict, description="Extra tool configuration like custom description")
+class LLMTraceSegment(BaseModel):
+"""
+Streaming trace segment for LLM tool-enabled runs.
+Order is preserved for replay. Tool calls are single entries containing both
+arguments and results.
+"""
+type: Literal["thought", "content", "tool_call"]
+# Common optional fields
+text: str | None = Field(None, description="Text chunk for thought/content")
+# Tool call fields (combined start + result)
+tool_call: ToolCallResult | None = Field(
+default=None,
+description="Combined tool call arguments and result for this segment",
+)
+class LLMGenerationData(BaseModel):
+"""Generation data from LLM invocation with tools.
+For multi-turn tool calls like: thought1 -> text1 -> tool_call1 -> thought2 -> text2 -> tool_call2
+- reasoning_contents: [thought1, thought2, ...] - one element per turn
+- tool_calls: [{id, name, arguments, result}, ...] - all tool calls with results
+"""
+text: str = Field(..., description="Accumulated text content from all turns")
+reasoning_contents: list[str] = Field(default_factory=list, description="Reasoning content per turn")
+tool_calls: list[ToolCallResult] = Field(default_factory=list, description="Tool calls with results")
+sequence: list[dict[str, Any]] = Field(default_factory=list, description="Ordered segments for rendering")
+usage: LLMUsage = Field(..., description="LLM usage statistics")
+finish_reason: str | None = Field(None, description="Finish reason from LLM")
+files: list[File] = Field(default_factory=list, description="Generated files")
+trace: list[LLMTraceSegment] = Field(default_factory=list, description="Streaming trace in emitted order")
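
As a rough sketch of the shape this model captures (illustrative values only; ToolResultStatus is assumed to live alongside ToolCallResult in the workflow entities), a two-turn run with one tool call might aggregate to:

    generation = LLMGenerationData(
        text="It is sunny in Paris.",
        reasoning_contents=["Need the weather tool.", "Now I can answer."],  # one entry per turn
        tool_calls=[
            ToolCallResult(
                id="call_1", name="get_weather",  # invented example values
                arguments='{"city": "Paris"}', output="sunny",
                status=ToolResultStatus.SUCCESS,
            ),
        ],
        # Ordered rendering hints: each index points into the lists above
        sequence=[
            {"type": "reasoning", "index": 0},
            {"type": "tool_call", "index": 0},
            {"type": "reasoning", "index": 1},
            {"type": "content", "index": 0},
        ],
        usage=LLMUsage.empty_usage(),
        finish_reason="stop",
    )
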
+class ThinkTagStreamParser:
+"""Lightweight state machine to split streaming chunks by <think> tags."""
+_START_PATTERN = re.compile(r"<think(?:\s[^>]*)?>", re.IGNORECASE)
+_END_PATTERN = re.compile(r"</think>", re.IGNORECASE)
+_START_PREFIX = "<think"
+_END_PREFIX = "</think"
+def __init__(self):
+self._buffer = ""
+self._in_think = False
+@staticmethod
+def _suffix_prefix_len(text: str, prefix: str) -> int:
+"""Return length of the longest suffix of `text` that is a prefix of `prefix`."""
+max_len = min(len(text), len(prefix) - 1)
+for i in range(max_len, 0, -1):
+if text[-i:].lower() == prefix[:i].lower():
+return i
+return 0
+def process(self, chunk: str) -> list[tuple[str, str]]:
+"""
+Split incoming chunk into ('thought' | 'text', content) tuples.
+Content excludes the <think> tags themselves and handles split tags across chunks.
+"""
+parts: list[tuple[str, str]] = []
+self._buffer += chunk
+while self._buffer:
+if self._in_think:
+end_match = self._END_PATTERN.search(self._buffer)
+if end_match:
+thought_text = self._buffer[: end_match.start()]
+if thought_text:
+parts.append(("thought", thought_text))
+self._buffer = self._buffer[end_match.end() :]
+self._in_think = False
+continue
+hold_len = self._suffix_prefix_len(self._buffer, self._END_PREFIX)
+emit = self._buffer[: len(self._buffer) - hold_len]
+if emit:
+parts.append(("thought", emit))
+self._buffer = self._buffer[-hold_len:] if hold_len > 0 else ""
+break
+start_match = self._START_PATTERN.search(self._buffer)
+if start_match:
+prefix = self._buffer[: start_match.start()]
+if prefix:
+parts.append(("text", prefix))
+self._buffer = self._buffer[start_match.end() :]
+self._in_think = True
+continue
+hold_len = self._suffix_prefix_len(self._buffer, self._START_PREFIX)
+emit = self._buffer[: len(self._buffer) - hold_len]
+if emit:
+parts.append(("text", emit))
+self._buffer = self._buffer[-hold_len:] if hold_len > 0 else ""
+break
+cleaned_parts: list[tuple[str, str]] = []
+for kind, content in parts:
+# Extra safeguard: strip any stray tags that slipped through.
+content = self._START_PATTERN.sub("", content)
+content = self._END_PATTERN.sub("", content)
+if content:
+cleaned_parts.append((kind, content))
+return cleaned_parts
+def flush(self) -> list[tuple[str, str]]:
+"""Flush remaining buffer when the stream ends."""
+if not self._buffer:
+return []
+kind = "thought" if self._in_think else "text"
+content = self._buffer
+# Drop dangling partial tags instead of emitting them
+if content.lower().startswith(self._START_PREFIX) or content.lower().startswith(self._END_PREFIX):
+content = ""
+self._buffer = ""
+if not content:
+return []
+# Strip any complete tags that might still be present.
+content = self._START_PATTERN.sub("", content)
+content = self._END_PATTERN.sub("", content)
+return [(kind, content)] if content else []
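
A quick usage sketch (expected return values shown in comments, traced through the logic above): a tag split across chunks is held back by _suffix_prefix_len until it can be classified, then excluded from the emitted content.

    parser = ThinkTagStreamParser()
    parser.process("Hello <thi")
    # -> [("text", "Hello ")]   ("<thi" is withheld as a possible tag prefix)
    parser.process("nk>reasoning</think> world")
    # -> [("thought", "reasoning"), ("text", " world")]
    parser.flush()
    # -> []   (nothing left buffered)
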
+class StreamBuffers(BaseModel):
+model_config = ConfigDict(arbitrary_types_allowed=True)
+think_parser: ThinkTagStreamParser = Field(default_factory=ThinkTagStreamParser)
+pending_thought: list[str] = Field(default_factory=list)
+pending_content: list[str] = Field(default_factory=list)
+current_turn_reasoning: list[str] = Field(default_factory=list)
+reasoning_per_turn: list[str] = Field(default_factory=list)
+class TraceState(BaseModel):
+trace_segments: list[LLMTraceSegment] = Field(default_factory=list)
+tool_trace_map: dict[str, LLMTraceSegment] = Field(default_factory=dict)
+tool_call_index_map: dict[str, int] = Field(default_factory=dict)
+class AggregatedResult(BaseModel):
+model_config = ConfigDict(arbitrary_types_allowed=True)
+text: str = ""
+files: list[File] = Field(default_factory=list)
+usage: LLMUsage = Field(default_factory=LLMUsage.empty_usage)
+finish_reason: str | None = None
+class AgentContext(BaseModel):
+agent_logs: list[AgentLogEvent] = Field(default_factory=list)
+agent_result: AgentResult | None = None
+class ToolOutputState(BaseModel):
+model_config = ConfigDict(arbitrary_types_allowed=True)
+stream: StreamBuffers = Field(default_factory=StreamBuffers)
+trace: TraceState = Field(default_factory=TraceState)
+aggregate: AggregatedResult = Field(default_factory=AggregatedResult)
+agent: AgentContext = Field(default_factory=AgentContext)
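
ToolOutputState simply composes the four state groups so they can be threaded through the streaming handlers as one object; a fresh instance starts empty. A minimal sketch:

    state = ToolOutputState()
    assert state.aggregate.text == ""
    assert state.trace.trace_segments == []
    assert state.agent.agent_result is None
    state.stream.pending_content.append("chunk")  # buffers are plain mutable lists
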
+class ToolLogPayload(BaseModel):
+model_config = ConfigDict(arbitrary_types_allowed=True)
+tool_name: str = ""
+tool_call_id: str = ""
+tool_args: dict[str, Any] = Field(default_factory=dict)
+tool_output: Any = None
+tool_error: Any = None
+files: list[Any] = Field(default_factory=list)
+meta: dict[str, Any] = Field(default_factory=dict)
+@classmethod
+def from_log(cls, log: AgentLog) -> "ToolLogPayload":
+data = log.data or {}
+return cls(
+tool_name=data.get("tool_name", ""),
+tool_call_id=data.get("tool_call_id", ""),
+tool_args=data.get("tool_args") or {},
+tool_output=data.get("output"),
+tool_error=data.get("error"),
+files=data.get("files") or [],
+meta=data.get("meta") or {},
+)
+@classmethod
+def from_mapping(cls, data: Mapping[str, Any]) -> "ToolLogPayload":
+return cls(
+tool_name=data.get("tool_name", ""),
+tool_call_id=data.get("tool_call_id", ""),
+tool_args=data.get("tool_args") or {},
+tool_output=data.get("output"),
+tool_error=data.get("error"),
+files=data.get("files") or [],
+meta=data.get("meta") or {},
+)
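
For instance, feeding from_mapping a raw log payload (hypothetical values) normalizes absent keys to safe defaults instead of None:

    payload = ToolLogPayload.from_mapping({
        "tool_name": "get_weather",
        "tool_call_id": "call_1",
        "tool_args": {"city": "Paris"},
        "output": "sunny",
    })
    assert payload.tool_error is None
    assert payload.files == [] and payload.meta == {}
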
class LLMNodeData(BaseNodeData):
model: ModelConfig
prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate


@@ -1,4 +1,3 @@
-import re
from collections.abc import Sequence
from typing import cast
@@ -155,94 +154,3 @@ def deduct_llm_quota(tenant_id: str, model_instance: ModelInstance, usage: LLMUs
)
session.execute(stmt)
session.commit()
-class ThinkTagStreamParser:
-"""Lightweight state machine to split streaming chunks by <think> tags."""
-_START_PATTERN = re.compile(r"<think(?:\s[^>]*)?>", re.IGNORECASE)
-_END_PATTERN = re.compile(r"</think>", re.IGNORECASE)
-_START_PREFIX = "<think"
-_END_PREFIX = "</think"
-def __init__(self):
-self._buffer = ""
-self._in_think = False
-@staticmethod
-def _suffix_prefix_len(text: str, prefix: str) -> int:
-"""Return length of the longest suffix of `text` that is a prefix of `prefix`."""
-max_len = min(len(text), len(prefix) - 1)
-for i in range(max_len, 0, -1):
-if text[-i:].lower() == prefix[:i].lower():
-return i
-return 0
-def process(self, chunk: str) -> list[tuple[str, str]]:
-"""
-Split incoming chunk into ('thought' | 'text', content) tuples.
-Content excludes the <think> tags themselves and handles split tags across chunks.
-"""
-parts: list[tuple[str, str]] = []
-self._buffer += chunk
-while self._buffer:
-if self._in_think:
-end_match = self._END_PATTERN.search(self._buffer)
-if end_match:
-thought_text = self._buffer[: end_match.start()]
-if thought_text:
-parts.append(("thought", thought_text))
-self._buffer = self._buffer[end_match.end() :]
-self._in_think = False
-continue
-hold_len = self._suffix_prefix_len(self._buffer, self._END_PREFIX)
-emit = self._buffer[: len(self._buffer) - hold_len]
-if emit:
-parts.append(("thought", emit))
-self._buffer = self._buffer[-hold_len:] if hold_len > 0 else ""
-break
-start_match = self._START_PATTERN.search(self._buffer)
-if start_match:
-prefix = self._buffer[: start_match.start()]
-if prefix:
-parts.append(("text", prefix))
-self._buffer = self._buffer[start_match.end() :]
-self._in_think = True
-continue
-hold_len = self._suffix_prefix_len(self._buffer, self._START_PREFIX)
-emit = self._buffer[: len(self._buffer) - hold_len]
-if emit:
-parts.append(("text", emit))
-self._buffer = self._buffer[-hold_len:] if hold_len > 0 else ""
-break
-cleaned_parts: list[tuple[str, str]] = []
-for kind, content in parts:
-# Extra safeguard: strip any stray tags that slipped through.
-content = self._START_PATTERN.sub("", content)
-content = self._END_PATTERN.sub("", content)
-if content:
-cleaned_parts.append((kind, content))
-return cleaned_parts
-def flush(self) -> list[tuple[str, str]]:
-"""Flush remaining buffer when the stream ends."""
-if not self._buffer:
-return []
-kind = "thought" if self._in_think else "text"
-content = self._buffer
-# Drop dangling partial tags instead of emitting them
-if content.lower().startswith(self._START_PREFIX) or content.lower().startswith(self._END_PREFIX):
-content = ""
-self._buffer = ""
-if not content:
-return []
-# Strip any complete tags that might still be present.
-content = self._START_PATTERN.sub("", content)
-content = self._END_PATTERN.sub("", content)
-return [(kind, content)] if content else []


@@ -90,12 +90,19 @@ from models.model import UploadFile
from . import llm_utils
from .entities import (
+AgentContext,
+AggregatedResult,
LLMGenerationData,
LLMNodeChatModelMessage,
LLMNodeCompletionModelPromptTemplate,
LLMNodeData,
LLMTraceSegment,
ModelConfig,
+StreamBuffers,
+ThinkTagStreamParser,
+ToolLogPayload,
+ToolOutputState,
+TraceState,
)
from .exc import (
InvalidContextStructureError,
@@ -582,7 +589,7 @@ class LLMNode(Node[LLMNodeData]):
usage = LLMUsage.empty_usage()
finish_reason = None
full_text_buffer = io.StringIO()
-think_parser = llm_utils.ThinkTagStreamParser()
+think_parser = ThinkTagStreamParser()
reasoning_chunks: list[str] = []
# Initialize streaming metrics tracking
@@ -1495,7 +1502,7 @@ class LLMNode(Node[LLMNodeData]):
)
# Process outputs and return generation result
-result = yield from self._process_tool_outputs(outputs, strategy, node_inputs, process_data)
+result = yield from self._process_tool_outputs(outputs)
return result
def _get_model_features(self, model_instance: ModelInstance) -> list[ModelFeature]:
@@ -1587,278 +1594,213 @@ class LLMNode(Node[LLMNodeData]):
return files
-def _process_tool_outputs(
-self,
-outputs: Generator[LLMResultChunk | AgentLog, None, AgentResult],
-strategy: Any,
-node_inputs: dict[str, Any],
-process_data: dict[str, Any],
-) -> Generator[NodeEventBase, None, LLMGenerationData]:
-"""Process strategy outputs and convert to node events.
+def _flush_thought_segment(self, buffers: StreamBuffers, trace_state: TraceState) -> None:
+if not buffers.pending_thought:
+return
+trace_state.trace_segments.append(LLMTraceSegment(type="thought", text="".join(buffers.pending_thought)))
+buffers.pending_thought.clear()
-Returns LLMGenerationData with text, reasoning_contents, tool_calls, usage, finish_reason, files
-"""
-text = ""
-files: list[File] = []
-usage = LLMUsage.empty_usage()
-agent_logs: list[AgentLogEvent] = []
-finish_reason = None
-agent_result: AgentResult | None = None
+def _flush_content_segment(self, buffers: StreamBuffers, trace_state: TraceState) -> None:
+if not buffers.pending_content:
+return
+trace_state.trace_segments.append(LLMTraceSegment(type="content", text="".join(buffers.pending_content)))
+buffers.pending_content.clear()
-think_parser = llm_utils.ThinkTagStreamParser()
-# Track reasoning per turn: each tool_call completion marks end of a turn
-current_turn_reasoning: list[str] = [] # Buffer for current turn's thought chunks
-reasoning_per_turn: list[str] = [] # Final list: one element per turn
-tool_call_index_map: dict[str, int] = {} # tool_call_id -> index
-trace_segments: list[LLMTraceSegment] = [] # Ordered trace for replay
-tool_trace_map: dict[str, LLMTraceSegment] = {}
-current_turn = 0
-pending_thought: list[str] = []
-pending_content: list[str] = []
+def _handle_agent_log_output(
+self, output: AgentLog, buffers: StreamBuffers, trace_state: TraceState, agent_context: AgentContext
+) -> Generator[NodeEventBase, None, None]:
+payload = ToolLogPayload.from_log(output)
+agent_log_event = AgentLogEvent(
+message_id=output.id,
+label=output.label,
+node_execution_id=self.id,
+parent_id=output.parent_id,
+error=output.error,
+status=output.status.value,
+data=output.data,
+metadata={k.value: v for k, v in output.metadata.items()},
+node_id=self._node_id,
+)
+for log in agent_context.agent_logs:
+if log.message_id == agent_log_event.message_id:
+log.data = agent_log_event.data
+log.status = agent_log_event.status
+log.error = agent_log_event.error
+log.label = agent_log_event.label
+log.metadata = agent_log_event.metadata
+break
+else:
+agent_context.agent_logs.append(agent_log_event)
-def _flush_thought() -> None:
-if not pending_thought:
-return
-trace_segments.append(LLMTraceSegment(type="thought", text="".join(pending_thought)))
-pending_thought.clear()
+if output.log_type == AgentLog.LogType.TOOL_CALL and output.status == AgentLog.LogStatus.START:
+tool_name = payload.tool_name
+tool_call_id = payload.tool_call_id
+tool_arguments = json.dumps(payload.tool_args) if payload.tool_args else ""
-def _flush_content() -> None:
-if not pending_content:
-return
-trace_segments.append(LLMTraceSegment(type="content", text="".join(pending_content)))
-pending_content.clear()
+if tool_call_id and tool_call_id not in trace_state.tool_call_index_map:
+trace_state.tool_call_index_map[tool_call_id] = len(trace_state.tool_call_index_map)
-# Process each output from strategy
-try:
-for output in outputs:
-if isinstance(output, AgentLog):
-# Store agent log event for metadata (no longer yielded, StreamChunkEvent contains the info)
-agent_log_event = AgentLogEvent(
-message_id=output.id,
-label=output.label,
-node_execution_id=self.id,
-parent_id=output.parent_id,
-error=output.error,
-status=output.status.value,
-data=output.data,
-metadata={k.value: v for k, v in output.metadata.items()},
-node_id=self._node_id,
-)
+self._flush_thought_segment(buffers, trace_state)
+self._flush_content_segment(buffers, trace_state)
+tool_call_segment = LLMTraceSegment(
+type="tool_call",
+text=None,
+tool_call=ToolCallResult(
+id=tool_call_id,
+name=tool_name,
+arguments=tool_arguments,
+),
+)
+trace_state.trace_segments.append(tool_call_segment)
+if tool_call_id:
+trace_state.tool_trace_map[tool_call_id] = tool_call_segment
+yield ToolCallChunkEvent(
+selector=[self._node_id, "generation", "tool_calls"],
+chunk=tool_arguments,
+tool_call=ToolCall(
+id=tool_call_id,
+name=tool_name,
+arguments=tool_arguments,
+),
+is_final=False,
+)
+if output.log_type == AgentLog.LogType.TOOL_CALL and output.status != AgentLog.LogStatus.START:
+tool_name = payload.tool_name
+tool_output = payload.tool_output
+tool_call_id = payload.tool_call_id
+tool_files = payload.files if isinstance(payload.files, list) else []
+tool_error = payload.tool_error
+if tool_call_id and tool_call_id not in trace_state.tool_call_index_map:
+trace_state.tool_call_index_map[tool_call_id] = len(trace_state.tool_call_index_map)
+self._flush_thought_segment(buffers, trace_state)
+self._flush_content_segment(buffers, trace_state)
+if output.status == AgentLog.LogStatus.ERROR:
+tool_error = output.error or payload.tool_error
+if not tool_error and payload.meta:
+tool_error = payload.meta.get("error")
+else:
+if payload.meta:
+meta_error = payload.meta.get("error")
+if meta_error:
+tool_error = meta_error
+existing_tool_segment = trace_state.tool_trace_map.get(tool_call_id)
+tool_call_segment = existing_tool_segment or LLMTraceSegment(
+type="tool_call",
+text=None,
+tool_call=ToolCallResult(
+id=tool_call_id,
+name=tool_name,
+arguments=None,
+),
+)
+if existing_tool_segment is None:
+trace_state.trace_segments.append(tool_call_segment)
+if tool_call_id:
+trace_state.tool_trace_map[tool_call_id] = tool_call_segment
+if tool_call_segment.tool_call is None:
+tool_call_segment.tool_call = ToolCallResult(
+id=tool_call_id,
+name=tool_name,
+arguments=None,
+)
+tool_call_segment.tool_call.output = (
+str(tool_output) if tool_output is not None else str(tool_error) if tool_error is not None else None
+)
+tool_call_segment.tool_call.files = []
+tool_call_segment.tool_call.status = ToolResultStatus.ERROR if tool_error else ToolResultStatus.SUCCESS
+result_output = str(tool_output) if tool_output is not None else str(tool_error) if tool_error else None
+yield ToolResultChunkEvent(
+selector=[self._node_id, "generation", "tool_results"],
+chunk=result_output or "",
+tool_result=ToolResult(
+id=tool_call_id,
+name=tool_name,
+output=result_output,
+files=tool_files,
+status=ToolResultStatus.ERROR if tool_error else ToolResultStatus.SUCCESS,
+),
+is_final=False,
+)
+if buffers.current_turn_reasoning:
+buffers.reasoning_per_turn.append("".join(buffers.current_turn_reasoning))
+buffers.current_turn_reasoning.clear()
+def _handle_llm_chunk_output(
+self, output: LLMResultChunk, buffers: StreamBuffers, trace_state: TraceState, aggregate: AggregatedResult
+) -> Generator[NodeEventBase, None, None]:
+message = output.delta.message
+if message and message.content:
+chunk_text = message.content
+if isinstance(chunk_text, list):
+chunk_text = "".join(getattr(content, "data", str(content)) for content in chunk_text)
+else:
+chunk_text = str(chunk_text)
+for kind, segment in buffers.think_parser.process(chunk_text):
+if not segment:
+continue
+if kind == "thought":
+self._flush_content_segment(buffers, trace_state)
+buffers.current_turn_reasoning.append(segment)
+buffers.pending_thought.append(segment)
+yield ThoughtChunkEvent(
+selector=[self._node_id, "generation", "thought"],
+chunk=segment,
+is_final=False,
+)
+else:
+self._flush_thought_segment(buffers, trace_state)
+aggregate.text += segment
+buffers.pending_content.append(segment)
+yield StreamChunkEvent(
+selector=[self._node_id, "text"],
+chunk=segment,
+is_final=False,
+)
+yield StreamChunkEvent(
+selector=[self._node_id, "generation", "content"],
+chunk=segment,
+is_final=False,
+)
-for log in agent_logs:
-if log.message_id == agent_log_event.message_id:
-# update the log
-log.data = agent_log_event.data
-log.status = agent_log_event.status
-log.error = agent_log_event.error
-log.label = agent_log_event.label
-log.metadata = agent_log_event.metadata
-break
-else:
-agent_logs.append(agent_log_event)
-# Emit tool call events when tool call starts
-if output.log_type == AgentLog.LogType.TOOL_CALL and output.status == AgentLog.LogStatus.START:
-tool_name = output.data.get("tool_name", "")
-tool_call_id = output.data.get("tool_call_id", "")
-tool_args = output.data.get("tool_args", {})
-tool_arguments = json.dumps(tool_args) if tool_args else ""
+if output.delta.usage:
+self._accumulate_usage(aggregate.usage, output.delta.usage)
-if tool_call_id and tool_call_id not in tool_call_index_map:
-tool_call_index_map[tool_call_id] = len(tool_call_index_map)
+if output.delta.finish_reason:
+aggregate.finish_reason = output.delta.finish_reason
-_flush_thought()
-_flush_content()
-tool_call_segment = LLMTraceSegment(
-type="tool_call",
-text=None,
-tool_call=ToolCallResult(
-id=tool_call_id,
-name=tool_name,
-arguments=tool_arguments,
-),
-)
-trace_segments.append(tool_call_segment)
-if tool_call_id:
-tool_trace_map[tool_call_id] = tool_call_segment
-yield ToolCallChunkEvent(
-selector=[self._node_id, "generation", "tool_calls"],
-chunk=tool_arguments,
-tool_call=ToolCall(
-id=tool_call_id,
-name=tool_name,
-arguments=tool_arguments,
-),
-is_final=False,
-)
-# Emit tool result events when tool call completes (both success and error)
-if output.log_type == AgentLog.LogType.TOOL_CALL and output.status != AgentLog.LogStatus.START:
-tool_name = output.data.get("tool_name", "")
-tool_output = output.data.get("output", "")
-tool_call_id = output.data.get("tool_call_id", "")
-tool_files = []
-tool_error = None
-if tool_call_id and tool_call_id not in tool_call_index_map:
-tool_call_index_map[tool_call_id] = len(tool_call_index_map)
-_flush_thought()
-_flush_content()
-# Extract file IDs if present (only for success case)
-files_data = output.data.get("files")
-if files_data and isinstance(files_data, list):
-tool_files = files_data
-# Check for error from multiple sources
-if output.status == AgentLog.LogStatus.ERROR:
-# Priority: output.error > data.error > meta.error
-tool_error = output.error or output.data.get("error")
-meta = output.data.get("meta")
-if not tool_error and meta and isinstance(meta, dict):
-tool_error = meta.get("error")
-else:
-# For success case, check meta for potential errors
-meta = output.data.get("meta")
-if meta and isinstance(meta, dict) and meta.get("error"):
-tool_error = meta.get("error")
-existing_tool_segment = tool_trace_map.get(tool_call_id)
-tool_call_segment = existing_tool_segment or LLMTraceSegment(
-type="tool_call",
-text=None,
-tool_call=ToolCallResult(
-id=tool_call_id,
-name=tool_name,
-arguments=None,
-),
-)
-if existing_tool_segment is None:
-trace_segments.append(tool_call_segment)
-if tool_call_id:
-tool_trace_map[tool_call_id] = tool_call_segment
-if tool_call_segment.tool_call is None:
-tool_call_segment.tool_call = ToolCallResult(
-id=tool_call_id,
-name=tool_name,
-arguments=None,
-)
-tool_call_segment.tool_call.output = (
-str(tool_output)
-if tool_output is not None
-else str(tool_error)
-if tool_error is not None
-else None
-)
-tool_call_segment.tool_call.files = []
-tool_call_segment.tool_call.status = (
-ToolResultStatus.ERROR if tool_error else ToolResultStatus.SUCCESS
-)
-current_turn += 1
-result_output = (
-str(tool_output) if tool_output is not None else str(tool_error) if tool_error else None
-)
-yield ToolResultChunkEvent(
-selector=[self._node_id, "generation", "tool_results"],
-chunk=result_output or "",
-tool_result=ToolResult(
-id=tool_call_id,
-name=tool_name,
-output=result_output,
-files=tool_files,
-status=ToolResultStatus.ERROR if tool_error else ToolResultStatus.SUCCESS,
-),
-is_final=False,
-)
-# End of current turn: save accumulated thought as one element
-if current_turn_reasoning:
-reasoning_per_turn.append("".join(current_turn_reasoning))
-current_turn_reasoning.clear()
-elif isinstance(output, LLMResultChunk):
-# Handle LLM result chunks - only process text content
-message = output.delta.message
-# Handle text content
-if message and message.content:
-chunk_text = message.content
-if isinstance(chunk_text, list):
-# Extract text from content list
-chunk_text = "".join(getattr(c, "data", str(c)) for c in chunk_text)
-else:
-chunk_text = str(chunk_text)
-for kind, segment in think_parser.process(chunk_text):
-if not segment:
-continue
-if kind == "thought":
-_flush_content()
-current_turn_reasoning.append(segment)
-pending_thought.append(segment)
-yield ThoughtChunkEvent(
-selector=[self._node_id, "generation", "thought"],
-chunk=segment,
-is_final=False,
-)
-else:
-_flush_thought()
-text += segment
-pending_content.append(segment)
-yield StreamChunkEvent(
-selector=[self._node_id, "text"],
-chunk=segment,
-is_final=False,
-)
-yield StreamChunkEvent(
-selector=[self._node_id, "generation", "content"],
-chunk=segment,
-is_final=False,
-)
-if output.delta.usage:
-self._accumulate_usage(usage, output.delta.usage)
-# Capture finish reason
-if output.delta.finish_reason:
-finish_reason = output.delta.finish_reason
-except StopIteration as e:
-# Get the return value from generator
-if isinstance(getattr(e, "value", None), AgentResult):
-agent_result = e.value
-# Use result from generator if available
-if agent_result:
-text = agent_result.text or text
-files = agent_result.files
-if agent_result.usage:
-usage = agent_result.usage
-if agent_result.finish_reason:
-finish_reason = agent_result.finish_reason
-# Flush any remaining buffered content after streaming ends
-for kind, segment in think_parser.flush():
+def _flush_remaining_stream(
+self, buffers: StreamBuffers, trace_state: TraceState, aggregate: AggregatedResult
+) -> Generator[NodeEventBase, None, None]:
+for kind, segment in buffers.think_parser.flush():
if not segment:
continue
if kind == "thought":
-_flush_content()
-current_turn_reasoning.append(segment)
-pending_thought.append(segment)
+self._flush_content_segment(buffers, trace_state)
+buffers.current_turn_reasoning.append(segment)
+buffers.pending_thought.append(segment)
yield ThoughtChunkEvent(
selector=[self._node_id, "generation", "thought"],
chunk=segment,
is_final=False,
)
else:
-_flush_thought()
-text += segment
-pending_content.append(segment)
+self._flush_thought_segment(buffers, trace_state)
+aggregate.text += segment
+buffers.pending_content.append(segment)
yield StreamChunkEvent(
selector=[self._node_id, "text"],
chunk=segment,
@@ -1870,19 +1812,63 @@ class LLMNode(Node[LLMNodeData]):
is_final=False,
)
-# Save the last turn's thought if any
-if current_turn_reasoning:
-reasoning_per_turn.append("".join(current_turn_reasoning))
+if buffers.current_turn_reasoning:
+buffers.reasoning_per_turn.append("".join(buffers.current_turn_reasoning))
-_flush_thought()
-_flush_content()
+self._flush_thought_segment(buffers, trace_state)
+self._flush_content_segment(buffers, trace_state)
-# Build sequence from trace_segments for rendering
+def _close_streams(self) -> Generator[NodeEventBase, None, None]:
+yield StreamChunkEvent(
+selector=[self._node_id, "text"],
+chunk="",
+is_final=True,
+)
+yield StreamChunkEvent(
+selector=[self._node_id, "generation", "content"],
+chunk="",
+is_final=True,
+)
+yield ThoughtChunkEvent(
+selector=[self._node_id, "generation", "thought"],
+chunk="",
+is_final=True,
+)
+yield ToolCallChunkEvent(
+selector=[self._node_id, "generation", "tool_calls"],
+chunk="",
+tool_call=ToolCall(
+id="",
+name="",
+arguments="",
+),
+is_final=True,
+)
+yield ToolResultChunkEvent(
+selector=[self._node_id, "generation", "tool_results"],
+chunk="",
+tool_result=ToolResult(
+id="",
+name="",
+output="",
+files=[],
+status=ToolResultStatus.SUCCESS,
+),
+is_final=True,
+)
+def _build_generation_data(
+self,
+trace_state: TraceState,
+agent_context: AgentContext,
+aggregate: AggregatedResult,
+buffers: StreamBuffers,
+) -> LLMGenerationData:
sequence: list[dict[str, Any]] = []
reasoning_index = 0
content_position = 0
tool_call_seen_index: dict[str, int] = {}
-for trace_segment in trace_segments:
+for trace_segment in trace_state.trace_segments:
if trace_segment.type == "thought":
sequence.append({"type": "reasoning", "index": reasoning_index})
reasoning_index += 1
@@ -1898,67 +1884,22 @@ class LLMNode(Node[LLMNodeData]):
tool_call_seen_index[tool_id] = len(tool_call_seen_index)
sequence.append({"type": "tool_call", "index": tool_call_seen_index[tool_id]})
-# Send final events for all streams
-yield StreamChunkEvent(
-selector=[self._node_id, "text"],
-chunk="",
-is_final=True,
-)
-# Close generation sub-field streams
-yield StreamChunkEvent(
-selector=[self._node_id, "generation", "content"],
-chunk="",
-is_final=True,
-)
-yield ThoughtChunkEvent(
-selector=[self._node_id, "generation", "thought"],
-chunk="",
-is_final=True,
-)
-# Close tool_calls stream (already sent via ToolCallChunkEvent)
-yield ToolCallChunkEvent(
-selector=[self._node_id, "generation", "tool_calls"],
-chunk="",
-tool_call=ToolCall(
-id="",
-name="",
-arguments="",
-),
-is_final=True,
-)
-# Close tool_results stream (already sent via ToolResultChunkEvent)
-yield ToolResultChunkEvent(
-selector=[self._node_id, "generation", "tool_results"],
-chunk="",
-tool_result=ToolResult(
-id="",
-name="",
-output="",
-files=[],
-status=ToolResultStatus.SUCCESS,
-),
-is_final=True,
-)
# Build tool_calls from agent_logs (with results)
tool_calls_for_generation: list[ToolCallResult] = []
-for log in agent_logs:
-tool_call_id = log.data.get("tool_call_id")
+for log in agent_context.agent_logs:
+payload = ToolLogPayload.from_mapping(log.data or {})
+tool_call_id = payload.tool_call_id
if not tool_call_id or log.status == AgentLog.LogStatus.START.value:
continue
-tool_args = log.data.get("tool_args") or {}
-log_error = log.data.get("error")
-log_output = log.data.get("output")
+tool_args = payload.tool_args
+log_error = payload.tool_error
+log_output = payload.tool_output
result_text = log_output or log_error or ""
status = ToolResultStatus.ERROR if log_error else ToolResultStatus.SUCCESS
tool_calls_for_generation.append(
ToolCallResult(
id=tool_call_id,
-name=log.data.get("tool_name", ""),
+name=payload.tool_name,
arguments=json.dumps(tool_args) if tool_args else "",
output=result_text,
status=status,
@@ -1966,21 +1907,50 @@ class LLMNode(Node[LLMNodeData]):
)
tool_calls_for_generation.sort(
-key=lambda item: tool_call_index_map.get(item.id or "", len(tool_call_index_map))
+key=lambda item: trace_state.tool_call_index_map.get(item.id or "", len(trace_state.tool_call_index_map))
)
-# Return generation data for caller
return LLMGenerationData(
-text=text,
-reasoning_contents=reasoning_per_turn,
+text=aggregate.text,
+reasoning_contents=buffers.reasoning_per_turn,
tool_calls=tool_calls_for_generation,
sequence=sequence,
-usage=usage,
-finish_reason=finish_reason,
-files=files,
-trace=trace_segments,
+usage=aggregate.usage,
+finish_reason=aggregate.finish_reason,
+files=aggregate.files,
+trace=trace_state.trace_segments,
)
+def _process_tool_outputs(
+self,
+outputs: Generator[LLMResultChunk | AgentLog, None, AgentResult],
+) -> Generator[NodeEventBase, None, LLMGenerationData]:
+"""Process strategy outputs and convert to node events."""
+state = ToolOutputState()
+try:
+for output in outputs:
+if isinstance(output, AgentLog):
+yield from self._handle_agent_log_output(output, state.stream, state.trace, state.agent)
+else:
+yield from self._handle_llm_chunk_output(output, state.stream, state.trace, state.aggregate)
+except StopIteration as exception:
+if isinstance(getattr(exception, "value", None), AgentResult):
+state.agent.agent_result = exception.value
+if state.agent.agent_result:
+state.aggregate.text = state.agent.agent_result.text or state.aggregate.text
+state.aggregate.files = state.agent.agent_result.files
+if state.agent.agent_result.usage:
+state.aggregate.usage = state.agent.agent_result.usage
+if state.agent.agent_result.finish_reason:
+state.aggregate.finish_reason = state.agent.agent_result.finish_reason
+yield from self._flush_remaining_stream(state.stream, state.trace, state.aggregate)
+yield from self._close_streams()
+return self._build_generation_data(state.trace, state.agent, state.aggregate, state.stream)
def _accumulate_usage(self, total_usage: LLMUsage, delta_usage: LLMUsage) -> None:
"""Accumulate LLM usage statistics."""
total_usage.prompt_tokens += delta_usage.prompt_tokens
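
An aside on the StopIteration handling in the new _process_tool_outputs (a plain-Python sketch, independent of this repo's types): a generator's return value travels on StopIteration.value, which yield from forwards automatically; a bare for loop consumes the exception internally, so retrieving the value by hand means driving the generator with next().

    from collections.abc import Generator

    def produce() -> Generator[int, None, str]:
        yield 1
        return "final"  # surfaces as StopIteration.value

    gen = produce()
    try:
        while True:
            item = next(gen)  # drive manually so StopIteration stays visible
    except StopIteration as exc:
        result = exc.value  # "final"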