Mirror of https://github.com/langgenius/dify.git (synced 2025-12-19 17:27:16 -05:00)

refactor(llm node): remove unused args
@@ -1,14 +1,17 @@
import re
from collections.abc import Mapping, Sequence
from typing import Any, Literal

from pydantic import BaseModel, Field, field_validator
from pydantic import BaseModel, ConfigDict, Field, field_validator

from core.agent.entities import AgentLog, AgentResult
from core.file import File
from core.model_runtime.entities import ImagePromptMessageContent, LLMMode
from core.model_runtime.entities.llm_entities import LLMUsage
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
from core.tools.entities.tool_entities import ToolProviderType
from core.workflow.entities import ToolCallResult
from core.workflow.node_events import AgentLogEvent
from core.workflow.nodes.base import BaseNodeData
from core.workflow.nodes.base.entities import VariableSelector

@@ -20,44 +23,6 @@ class ModelConfig(BaseModel):
    completion_params: dict[str, Any] = Field(default_factory=dict)


class LLMTraceSegment(BaseModel):
    """
    Streaming trace segment for LLM tool-enabled runs.

    Order is preserved for replay. Tool calls are single entries containing both
    arguments and results.
    """

    type: Literal["thought", "content", "tool_call"]

    # Common optional fields
    text: str | None = Field(None, description="Text chunk for thought/content")

    # Tool call fields (combined start + result)
    tool_call: ToolCallResult | None = Field(
        default=None,
        description="Combined tool call arguments and result for this segment",
    )


class LLMGenerationData(BaseModel):
    """Generation data from LLM invocation with tools.

    For multi-turn tool calls like: thought1 -> text1 -> tool_call1 -> thought2 -> text2 -> tool_call2
    - reasoning_contents: [thought1, thought2, ...] - one element per turn
    - tool_calls: [{id, name, arguments, result}, ...] - all tool calls with results
    """

    text: str = Field(..., description="Accumulated text content from all turns")
    reasoning_contents: list[str] = Field(default_factory=list, description="Reasoning content per turn")
    tool_calls: list[ToolCallResult] = Field(default_factory=list, description="Tool calls with results")
    sequence: list[dict[str, Any]] = Field(default_factory=list, description="Ordered segments for rendering")
    usage: LLMUsage = Field(..., description="LLM usage statistics")
    finish_reason: str | None = Field(None, description="Finish reason from LLM")
    files: list[File] = Field(default_factory=list, description="Generated files")
    trace: list[LLMTraceSegment] = Field(default_factory=list, description="Streaming trace in emitted order")


class ContextConfig(BaseModel):
    enabled: bool
    variable_selector: list[str] | None = None
@@ -124,6 +89,211 @@ class ToolMetadata(BaseModel):
    extra: dict[str, Any] = Field(default_factory=dict, description="Extra tool configuration like custom description")


class LLMTraceSegment(BaseModel):
    """
    Streaming trace segment for LLM tool-enabled runs.

    Order is preserved for replay. Tool calls are single entries containing both
    arguments and results.
    """

    type: Literal["thought", "content", "tool_call"]

    # Common optional fields
    text: str | None = Field(None, description="Text chunk for thought/content")

    # Tool call fields (combined start + result)
    tool_call: ToolCallResult | None = Field(
        default=None,
        description="Combined tool call arguments and result for this segment",
    )


class LLMGenerationData(BaseModel):
    """Generation data from LLM invocation with tools.

    For multi-turn tool calls like: thought1 -> text1 -> tool_call1 -> thought2 -> text2 -> tool_call2
    - reasoning_contents: [thought1, thought2, ...] - one element per turn
    - tool_calls: [{id, name, arguments, result}, ...] - all tool calls with results
    """

    text: str = Field(..., description="Accumulated text content from all turns")
    reasoning_contents: list[str] = Field(default_factory=list, description="Reasoning content per turn")
    tool_calls: list[ToolCallResult] = Field(default_factory=list, description="Tool calls with results")
    sequence: list[dict[str, Any]] = Field(default_factory=list, description="Ordered segments for rendering")
    usage: LLMUsage = Field(..., description="LLM usage statistics")
    finish_reason: str | None = Field(None, description="Finish reason from LLM")
    files: list[File] = Field(default_factory=list, description="Generated files")
    trace: list[LLMTraceSegment] = Field(default_factory=list, description="Streaming trace in emitted order")
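
# Illustrative sketch (plain dicts, hypothetical values, not part of the commit) of how
# a two-turn tool run maps onto LLMGenerationData: `text` aggregates every content
# chunk, `reasoning_contents` holds one entry per turn, and `sequence`/`trace` preserve
# emission order for replay. (Content entries in `sequence` are omitted here.)
example_generation = {
    "text": "text1text2",
    "reasoning_contents": ["thought1", "thought2"],
    "tool_calls": [{"id": "call-1", "name": "search", "arguments": '{"q": "x"}', "output": "ok"}],
    "sequence": [
        {"type": "reasoning", "index": 0},
        {"type": "tool_call", "index": 0},
        {"type": "reasoning", "index": 1},
    ],
}
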
class ThinkTagStreamParser:
    """Lightweight state machine to split streaming chunks by <think> tags."""

    _START_PATTERN = re.compile(r"<think(?:\s[^>]*)?>", re.IGNORECASE)
    _END_PATTERN = re.compile(r"</think>", re.IGNORECASE)
    _START_PREFIX = "<think"
    _END_PREFIX = "</think"

    def __init__(self):
        self._buffer = ""
        self._in_think = False

    @staticmethod
    def _suffix_prefix_len(text: str, prefix: str) -> int:
        """Return length of the longest suffix of `text` that is a prefix of `prefix`."""
        max_len = min(len(text), len(prefix) - 1)
        for i in range(max_len, 0, -1):
            if text[-i:].lower() == prefix[:i].lower():
                return i
        return 0

    def process(self, chunk: str) -> list[tuple[str, str]]:
        """
        Split incoming chunk into ('thought' | 'text', content) tuples.
        Content excludes the <think> tags themselves and handles split tags across chunks.
        """
        parts: list[tuple[str, str]] = []
        self._buffer += chunk

        while self._buffer:
            if self._in_think:
                end_match = self._END_PATTERN.search(self._buffer)
                if end_match:
                    thought_text = self._buffer[: end_match.start()]
                    if thought_text:
                        parts.append(("thought", thought_text))
                    self._buffer = self._buffer[end_match.end() :]
                    self._in_think = False
                    continue

                hold_len = self._suffix_prefix_len(self._buffer, self._END_PREFIX)
                emit = self._buffer[: len(self._buffer) - hold_len]
                if emit:
                    parts.append(("thought", emit))
                self._buffer = self._buffer[-hold_len:] if hold_len > 0 else ""
                break

            start_match = self._START_PATTERN.search(self._buffer)
            if start_match:
                prefix = self._buffer[: start_match.start()]
                if prefix:
                    parts.append(("text", prefix))
                self._buffer = self._buffer[start_match.end() :]
                self._in_think = True
                continue

            hold_len = self._suffix_prefix_len(self._buffer, self._START_PREFIX)
            emit = self._buffer[: len(self._buffer) - hold_len]
            if emit:
                parts.append(("text", emit))
            self._buffer = self._buffer[-hold_len:] if hold_len > 0 else ""
            break

        cleaned_parts: list[tuple[str, str]] = []
        for kind, content in parts:
            # Extra safeguard: strip any stray tags that slipped through.
            content = self._START_PATTERN.sub("", content)
            content = self._END_PATTERN.sub("", content)
            if content:
                cleaned_parts.append((kind, content))

        return cleaned_parts

    def flush(self) -> list[tuple[str, str]]:
        """Flush remaining buffer when the stream ends."""
        if not self._buffer:
            return []
        kind = "thought" if self._in_think else "text"
        content = self._buffer
        # Drop dangling partial tags instead of emitting them
        if content.lower().startswith(self._START_PREFIX) or content.lower().startswith(self._END_PREFIX):
            content = ""
        self._buffer = ""
        if not content:
            return []
        # Strip any complete tags that might still be present.
        content = self._START_PATTERN.sub("", content)
        content = self._END_PATTERN.sub("", content)
        return [(kind, content)] if content else []
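
# Quick usage sketch for ThinkTagStreamParser (not part of the commit): tags may be
# split across chunks, and the emitted parts never contain the tags themselves.
_parser = ThinkTagStreamParser()
_parts: list[tuple[str, str]] = []
for _chunk in ["Hello <thi", "nk>pondering...</th", "ink> world"]:
    _parts.extend(_parser.process(_chunk))
_parts.extend(_parser.flush())
assert _parts == [("text", "Hello "), ("thought", "pondering..."), ("text", " world")]
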
class StreamBuffers(BaseModel):
    model_config = ConfigDict(arbitrary_types_allowed=True)

    think_parser: ThinkTagStreamParser = Field(default_factory=ThinkTagStreamParser)
    pending_thought: list[str] = Field(default_factory=list)
    pending_content: list[str] = Field(default_factory=list)
    current_turn_reasoning: list[str] = Field(default_factory=list)
    reasoning_per_turn: list[str] = Field(default_factory=list)


class TraceState(BaseModel):
    trace_segments: list[LLMTraceSegment] = Field(default_factory=list)
    tool_trace_map: dict[str, LLMTraceSegment] = Field(default_factory=dict)
    tool_call_index_map: dict[str, int] = Field(default_factory=dict)


class AggregatedResult(BaseModel):
    model_config = ConfigDict(arbitrary_types_allowed=True)

    text: str = ""
    files: list[File] = Field(default_factory=list)
    usage: LLMUsage = Field(default_factory=LLMUsage.empty_usage)
    finish_reason: str | None = None


class AgentContext(BaseModel):
    agent_logs: list[AgentLogEvent] = Field(default_factory=list)
    agent_result: AgentResult | None = None


class ToolOutputState(BaseModel):
    model_config = ConfigDict(arbitrary_types_allowed=True)

    stream: StreamBuffers = Field(default_factory=StreamBuffers)
    trace: TraceState = Field(default_factory=TraceState)
    aggregate: AggregatedResult = Field(default_factory=AggregatedResult)
    agent: AgentContext = Field(default_factory=AgentContext)
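
# Illustrative sketch (plain lists, no pydantic, not part of the commit) of the
# per-turn bookkeeping StreamBuffers supports: thought chunks accumulate in
# current_turn_reasoning and fold into reasoning_per_turn when a tool call ends a turn.
current_turn_reasoning: list[str] = []
reasoning_per_turn: list[str] = []
for _kind, _chunk in [("thought", "I should "), ("thought", "search."), ("tool_end", ""), ("thought", "Done.")]:
    if _kind == "thought":
        current_turn_reasoning.append(_chunk)
    elif _kind == "tool_end" and current_turn_reasoning:
        reasoning_per_turn.append("".join(current_turn_reasoning))
        current_turn_reasoning.clear()
if current_turn_reasoning:  # flush the final turn when the stream ends
    reasoning_per_turn.append("".join(current_turn_reasoning))
assert reasoning_per_turn == ["I should search.", "Done."]
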
class ToolLogPayload(BaseModel):
    model_config = ConfigDict(arbitrary_types_allowed=True)

    tool_name: str = ""
    tool_call_id: str = ""
    tool_args: dict[str, Any] = Field(default_factory=dict)
    tool_output: Any = None
    tool_error: Any = None
    files: list[Any] = Field(default_factory=list)
    meta: dict[str, Any] = Field(default_factory=dict)

    @classmethod
    def from_log(cls, log: AgentLog) -> "ToolLogPayload":
        data = log.data or {}
        return cls(
            tool_name=data.get("tool_name", ""),
            tool_call_id=data.get("tool_call_id", ""),
            tool_args=data.get("tool_args") or {},
            tool_output=data.get("output"),
            tool_error=data.get("error"),
            files=data.get("files") or [],
            meta=data.get("meta") or {},
        )

    @classmethod
    def from_mapping(cls, data: Mapping[str, Any]) -> "ToolLogPayload":
        return cls(
            tool_name=data.get("tool_name", ""),
            tool_call_id=data.get("tool_call_id", ""),
            tool_args=data.get("tool_args") or {},
            tool_output=data.get("output"),
            tool_error=data.get("error"),
            files=data.get("files") or [],
            meta=data.get("meta") or {},
        )
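
# Usage sketch (hypothetical log data, not part of the commit): from_mapping normalizes
# falsy tool_args / files / meta to empty containers, so callers can index them
# without None checks.
_payload = ToolLogPayload.from_mapping(
    {"tool_name": "search", "tool_call_id": "call-1", "tool_args": None, "output": "ok"}
)
assert _payload.tool_args == {} and _payload.files == [] and _payload.meta == {}
assert _payload.tool_output == "ok" and _payload.tool_error is None
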
class LLMNodeData(BaseNodeData):
    model: ModelConfig
    prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate
@@ -1,4 +1,3 @@
import re
from collections.abc import Sequence
from typing import cast
@@ -155,94 +154,3 @@ def deduct_llm_quota(tenant_id: str, model_instance: ModelInstance, usage: LLMUs
        )
        session.execute(stmt)
        session.commit()


class ThinkTagStreamParser:
    """Lightweight state machine to split streaming chunks by <think> tags."""

    _START_PATTERN = re.compile(r"<think(?:\s[^>]*)?>", re.IGNORECASE)
    _END_PATTERN = re.compile(r"</think>", re.IGNORECASE)
    _START_PREFIX = "<think"
    _END_PREFIX = "</think"

    def __init__(self):
        self._buffer = ""
        self._in_think = False

    @staticmethod
    def _suffix_prefix_len(text: str, prefix: str) -> int:
        """Return length of the longest suffix of `text` that is a prefix of `prefix`."""
        max_len = min(len(text), len(prefix) - 1)
        for i in range(max_len, 0, -1):
            if text[-i:].lower() == prefix[:i].lower():
                return i
        return 0

    def process(self, chunk: str) -> list[tuple[str, str]]:
        """
        Split incoming chunk into ('thought' | 'text', content) tuples.
        Content excludes the <think> tags themselves and handles split tags across chunks.
        """
        parts: list[tuple[str, str]] = []
        self._buffer += chunk

        while self._buffer:
            if self._in_think:
                end_match = self._END_PATTERN.search(self._buffer)
                if end_match:
                    thought_text = self._buffer[: end_match.start()]
                    if thought_text:
                        parts.append(("thought", thought_text))
                    self._buffer = self._buffer[end_match.end() :]
                    self._in_think = False
                    continue

                hold_len = self._suffix_prefix_len(self._buffer, self._END_PREFIX)
                emit = self._buffer[: len(self._buffer) - hold_len]
                if emit:
                    parts.append(("thought", emit))
                self._buffer = self._buffer[-hold_len:] if hold_len > 0 else ""
                break

            start_match = self._START_PATTERN.search(self._buffer)
            if start_match:
                prefix = self._buffer[: start_match.start()]
                if prefix:
                    parts.append(("text", prefix))
                self._buffer = self._buffer[start_match.end() :]
                self._in_think = True
                continue

            hold_len = self._suffix_prefix_len(self._buffer, self._START_PREFIX)
            emit = self._buffer[: len(self._buffer) - hold_len]
            if emit:
                parts.append(("text", emit))
            self._buffer = self._buffer[-hold_len:] if hold_len > 0 else ""
            break

        cleaned_parts: list[tuple[str, str]] = []
        for kind, content in parts:
            # Extra safeguard: strip any stray tags that slipped through.
            content = self._START_PATTERN.sub("", content)
            content = self._END_PATTERN.sub("", content)
            if content:
                cleaned_parts.append((kind, content))

        return cleaned_parts

    def flush(self) -> list[tuple[str, str]]:
        """Flush remaining buffer when the stream ends."""
        if not self._buffer:
            return []
        kind = "thought" if self._in_think else "text"
        content = self._buffer
        # Drop dangling partial tags instead of emitting them
        if content.lower().startswith(self._START_PREFIX) or content.lower().startswith(self._END_PREFIX):
            content = ""
        self._buffer = ""
        if not content:
            return []
        # Strip any complete tags that might still be present.
        content = self._START_PATTERN.sub("", content)
        content = self._END_PATTERN.sub("", content)
        return [(kind, content)] if content else []
@@ -90,12 +90,19 @@ from models.model import UploadFile
from . import llm_utils
from .entities import (
    AgentContext,
    AggregatedResult,
    LLMGenerationData,
    LLMNodeChatModelMessage,
    LLMNodeCompletionModelPromptTemplate,
    LLMNodeData,
    LLMTraceSegment,
    ModelConfig,
    StreamBuffers,
    ThinkTagStreamParser,
    ToolLogPayload,
    ToolOutputState,
    TraceState,
)
from .exc import (
    InvalidContextStructureError,
@@ -582,7 +589,7 @@ class LLMNode(Node[LLMNodeData]):
        usage = LLMUsage.empty_usage()
        finish_reason = None
        full_text_buffer = io.StringIO()
        think_parser = llm_utils.ThinkTagStreamParser()
        think_parser = ThinkTagStreamParser()
        reasoning_chunks: list[str] = []

        # Initialize streaming metrics tracking
@@ -1495,7 +1502,7 @@ class LLMNode(Node[LLMNodeData]):
        )

        # Process outputs and return generation result
        result = yield from self._process_tool_outputs(outputs, strategy, node_inputs, process_data)
        result = yield from self._process_tool_outputs(outputs)
        return result
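
# Minimal sketch (not part of the commit) of the generator pattern used by the call
# above: `yield from` re-yields the subgenerator's events to the caller and binds the
# subgenerator's return value.
from collections.abc import Generator

def _subgen() -> Generator[str, None, int]:
    yield "event-1"
    yield "event-2"
    return 42  # carried on StopIteration.value, captured by `yield from`

def _outer() -> Generator[str, None, None]:
    result = yield from _subgen()  # result == 42 after events are re-yielded
    yield f"result={result}"

assert list(_outer()) == ["event-1", "event-2", "result=42"]
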

    def _get_model_features(self, model_instance: ModelInstance) -> list[ModelFeature]:
@@ -1587,278 +1594,213 @@ class LLMNode(Node[LLMNodeData]):
        return files

    def _process_tool_outputs(
        self,
        outputs: Generator[LLMResultChunk | AgentLog, None, AgentResult],
        strategy: Any,
        node_inputs: dict[str, Any],
        process_data: dict[str, Any],
    ) -> Generator[NodeEventBase, None, LLMGenerationData]:
        """Process strategy outputs and convert to node events.
    def _flush_thought_segment(self, buffers: StreamBuffers, trace_state: TraceState) -> None:
        if not buffers.pending_thought:
            return
        trace_state.trace_segments.append(LLMTraceSegment(type="thought", text="".join(buffers.pending_thought)))
        buffers.pending_thought.clear()

        Returns LLMGenerationData with text, reasoning_contents, tool_calls, usage, finish_reason, files
        """
        text = ""
        files: list[File] = []
        usage = LLMUsage.empty_usage()
        agent_logs: list[AgentLogEvent] = []
        finish_reason = None
        agent_result: AgentResult | None = None
    def _flush_content_segment(self, buffers: StreamBuffers, trace_state: TraceState) -> None:
        if not buffers.pending_content:
            return
        trace_state.trace_segments.append(LLMTraceSegment(type="content", text="".join(buffers.pending_content)))
        buffers.pending_content.clear()

        think_parser = llm_utils.ThinkTagStreamParser()
        # Track reasoning per turn: each tool_call completion marks end of a turn
        current_turn_reasoning: list[str] = []  # Buffer for current turn's thought chunks
        reasoning_per_turn: list[str] = []  # Final list: one element per turn
        tool_call_index_map: dict[str, int] = {}  # tool_call_id -> index
        trace_segments: list[LLMTraceSegment] = []  # Ordered trace for replay
        tool_trace_map: dict[str, LLMTraceSegment] = {}
        current_turn = 0
        pending_thought: list[str] = []
        pending_content: list[str] = []
    def _handle_agent_log_output(
        self, output: AgentLog, buffers: StreamBuffers, trace_state: TraceState, agent_context: AgentContext
    ) -> Generator[NodeEventBase, None, None]:
        payload = ToolLogPayload.from_log(output)
        agent_log_event = AgentLogEvent(
            message_id=output.id,
            label=output.label,
            node_execution_id=self.id,
            parent_id=output.parent_id,
            error=output.error,
            status=output.status.value,
            data=output.data,
            metadata={k.value: v for k, v in output.metadata.items()},
            node_id=self._node_id,
        )
        for log in agent_context.agent_logs:
            if log.message_id == agent_log_event.message_id:
                log.data = agent_log_event.data
                log.status = agent_log_event.status
                log.error = agent_log_event.error
                log.label = agent_log_event.label
                log.metadata = agent_log_event.metadata
                break
        else:
            agent_context.agent_logs.append(agent_log_event)

        def _flush_thought() -> None:
            if not pending_thought:
                return
            trace_segments.append(LLMTraceSegment(type="thought", text="".join(pending_thought)))
            pending_thought.clear()
        if output.log_type == AgentLog.LogType.TOOL_CALL and output.status == AgentLog.LogStatus.START:
            tool_name = payload.tool_name
            tool_call_id = payload.tool_call_id
            tool_arguments = json.dumps(payload.tool_args) if payload.tool_args else ""

        def _flush_content() -> None:
            if not pending_content:
                return
            trace_segments.append(LLMTraceSegment(type="content", text="".join(pending_content)))
            pending_content.clear()
            if tool_call_id and tool_call_id not in trace_state.tool_call_index_map:
                trace_state.tool_call_index_map[tool_call_id] = len(trace_state.tool_call_index_map)

        # Process each output from strategy
        try:
            for output in outputs:
                if isinstance(output, AgentLog):
                    # Store agent log event for metadata (no longer yielded, StreamChunkEvent contains the info)
                    agent_log_event = AgentLogEvent(
                        message_id=output.id,
                        label=output.label,
                        node_execution_id=self.id,
                        parent_id=output.parent_id,
                        error=output.error,
                        status=output.status.value,
                        data=output.data,
                        metadata={k.value: v for k, v in output.metadata.items()},
                        node_id=self._node_id,
            self._flush_thought_segment(buffers, trace_state)
            self._flush_content_segment(buffers, trace_state)

            tool_call_segment = LLMTraceSegment(
                type="tool_call",
                text=None,
                tool_call=ToolCallResult(
                    id=tool_call_id,
                    name=tool_name,
                    arguments=tool_arguments,
                ),
            )
            trace_state.trace_segments.append(tool_call_segment)
            if tool_call_id:
                trace_state.tool_trace_map[tool_call_id] = tool_call_segment

            yield ToolCallChunkEvent(
                selector=[self._node_id, "generation", "tool_calls"],
                chunk=tool_arguments,
                tool_call=ToolCall(
                    id=tool_call_id,
                    name=tool_name,
                    arguments=tool_arguments,
                ),
                is_final=False,
            )

        if output.log_type == AgentLog.LogType.TOOL_CALL and output.status != AgentLog.LogStatus.START:
            tool_name = payload.tool_name
            tool_output = payload.tool_output
            tool_call_id = payload.tool_call_id
            tool_files = payload.files if isinstance(payload.files, list) else []
            tool_error = payload.tool_error

            if tool_call_id and tool_call_id not in trace_state.tool_call_index_map:
                trace_state.tool_call_index_map[tool_call_id] = len(trace_state.tool_call_index_map)

            self._flush_thought_segment(buffers, trace_state)
            self._flush_content_segment(buffers, trace_state)

            if output.status == AgentLog.LogStatus.ERROR:
                tool_error = output.error or payload.tool_error
                if not tool_error and payload.meta:
                    tool_error = payload.meta.get("error")
            else:
                if payload.meta:
                    meta_error = payload.meta.get("error")
                    if meta_error:
                        tool_error = meta_error

            existing_tool_segment = trace_state.tool_trace_map.get(tool_call_id)
            tool_call_segment = existing_tool_segment or LLMTraceSegment(
                type="tool_call",
                text=None,
                tool_call=ToolCallResult(
                    id=tool_call_id,
                    name=tool_name,
                    arguments=None,
                ),
            )
            if existing_tool_segment is None:
                trace_state.trace_segments.append(tool_call_segment)
                if tool_call_id:
                    trace_state.tool_trace_map[tool_call_id] = tool_call_segment

            if tool_call_segment.tool_call is None:
                tool_call_segment.tool_call = ToolCallResult(
                    id=tool_call_id,
                    name=tool_name,
                    arguments=None,
                )
            tool_call_segment.tool_call.output = (
                str(tool_output) if tool_output is not None else str(tool_error) if tool_error is not None else None
            )
            tool_call_segment.tool_call.files = []
            tool_call_segment.tool_call.status = ToolResultStatus.ERROR if tool_error else ToolResultStatus.SUCCESS

            result_output = str(tool_output) if tool_output is not None else str(tool_error) if tool_error else None

            yield ToolResultChunkEvent(
                selector=[self._node_id, "generation", "tool_results"],
                chunk=result_output or "",
                tool_result=ToolResult(
                    id=tool_call_id,
                    name=tool_name,
                    output=result_output,
                    files=tool_files,
                    status=ToolResultStatus.ERROR if tool_error else ToolResultStatus.SUCCESS,
                ),
                is_final=False,
            )

            if buffers.current_turn_reasoning:
                buffers.reasoning_per_turn.append("".join(buffers.current_turn_reasoning))
                buffers.current_turn_reasoning.clear()

    def _handle_llm_chunk_output(
        self, output: LLMResultChunk, buffers: StreamBuffers, trace_state: TraceState, aggregate: AggregatedResult
    ) -> Generator[NodeEventBase, None, None]:
        message = output.delta.message

        if message and message.content:
            chunk_text = message.content
            if isinstance(chunk_text, list):
                chunk_text = "".join(getattr(content, "data", str(content)) for content in chunk_text)
            else:
                chunk_text = str(chunk_text)

            for kind, segment in buffers.think_parser.process(chunk_text):
                if not segment:
                    continue

                if kind == "thought":
                    self._flush_content_segment(buffers, trace_state)
                    buffers.current_turn_reasoning.append(segment)
                    buffers.pending_thought.append(segment)
                    yield ThoughtChunkEvent(
                        selector=[self._node_id, "generation", "thought"],
                        chunk=segment,
                        is_final=False,
                    )
                else:
                    self._flush_thought_segment(buffers, trace_state)
                    aggregate.text += segment
                    buffers.pending_content.append(segment)
                    yield StreamChunkEvent(
                        selector=[self._node_id, "text"],
                        chunk=segment,
                        is_final=False,
                    )
                    yield StreamChunkEvent(
                        selector=[self._node_id, "generation", "content"],
                        chunk=segment,
                        is_final=False,
                    )
                    for log in agent_logs:
                        if log.message_id == agent_log_event.message_id:
                            # update the log
                            log.data = agent_log_event.data
                            log.status = agent_log_event.status
                            log.error = agent_log_event.error
                            log.label = agent_log_event.label
                            log.metadata = agent_log_event.metadata
                            break
                    else:
                        agent_logs.append(agent_log_event)

                    # Emit tool call events when tool call starts
                    if output.log_type == AgentLog.LogType.TOOL_CALL and output.status == AgentLog.LogStatus.START:
                        tool_name = output.data.get("tool_name", "")
                        tool_call_id = output.data.get("tool_call_id", "")
                        tool_args = output.data.get("tool_args", {})
                        tool_arguments = json.dumps(tool_args) if tool_args else ""
        if output.delta.usage:
            self._accumulate_usage(aggregate.usage, output.delta.usage)

                        if tool_call_id and tool_call_id not in tool_call_index_map:
                            tool_call_index_map[tool_call_id] = len(tool_call_index_map)
        if output.delta.finish_reason:
            aggregate.finish_reason = output.delta.finish_reason

                        _flush_thought()
                        _flush_content()

                        tool_call_segment = LLMTraceSegment(
                            type="tool_call",
                            text=None,
                            tool_call=ToolCallResult(
                                id=tool_call_id,
                                name=tool_name,
                                arguments=tool_arguments,
                            ),
                        )
                        trace_segments.append(tool_call_segment)
                        if tool_call_id:
                            tool_trace_map[tool_call_id] = tool_call_segment

                        yield ToolCallChunkEvent(
                            selector=[self._node_id, "generation", "tool_calls"],
                            chunk=tool_arguments,
                            tool_call=ToolCall(
                                id=tool_call_id,
                                name=tool_name,
                                arguments=tool_arguments,
                            ),
                            is_final=False,
                        )

                    # Emit tool result events when tool call completes (both success and error)
                    if output.log_type == AgentLog.LogType.TOOL_CALL and output.status != AgentLog.LogStatus.START:
                        tool_name = output.data.get("tool_name", "")
                        tool_output = output.data.get("output", "")
                        tool_call_id = output.data.get("tool_call_id", "")
                        tool_files = []
                        tool_error = None

                        if tool_call_id and tool_call_id not in tool_call_index_map:
                            tool_call_index_map[tool_call_id] = len(tool_call_index_map)

                        _flush_thought()
                        _flush_content()

                        # Extract file IDs if present (only for success case)
                        files_data = output.data.get("files")
                        if files_data and isinstance(files_data, list):
                            tool_files = files_data

                        # Check for error from multiple sources
                        if output.status == AgentLog.LogStatus.ERROR:
                            # Priority: output.error > data.error > meta.error
                            tool_error = output.error or output.data.get("error")
                            meta = output.data.get("meta")
                            if not tool_error and meta and isinstance(meta, dict):
                                tool_error = meta.get("error")
                        else:
                            # For success case, check meta for potential errors
                            meta = output.data.get("meta")
                            if meta and isinstance(meta, dict) and meta.get("error"):
                                tool_error = meta.get("error")

                        existing_tool_segment = tool_trace_map.get(tool_call_id)
                        tool_call_segment = existing_tool_segment or LLMTraceSegment(
                            type="tool_call",
                            text=None,
                            tool_call=ToolCallResult(
                                id=tool_call_id,
                                name=tool_name,
                                arguments=None,
                            ),
                        )
                        if existing_tool_segment is None:
                            trace_segments.append(tool_call_segment)
                            if tool_call_id:
                                tool_trace_map[tool_call_id] = tool_call_segment

                        if tool_call_segment.tool_call is None:
                            tool_call_segment.tool_call = ToolCallResult(
                                id=tool_call_id,
                                name=tool_name,
                                arguments=None,
                            )
                        tool_call_segment.tool_call.output = (
                            str(tool_output)
                            if tool_output is not None
                            else str(tool_error)
                            if tool_error is not None
                            else None
                        )
                        tool_call_segment.tool_call.files = []
                        tool_call_segment.tool_call.status = (
                            ToolResultStatus.ERROR if tool_error else ToolResultStatus.SUCCESS
                        )
                        current_turn += 1

                        result_output = (
                            str(tool_output) if tool_output is not None else str(tool_error) if tool_error else None
                        )

                        yield ToolResultChunkEvent(
                            selector=[self._node_id, "generation", "tool_results"],
                            chunk=result_output or "",
                            tool_result=ToolResult(
                                id=tool_call_id,
                                name=tool_name,
                                output=result_output,
                                files=tool_files,
                                status=ToolResultStatus.ERROR if tool_error else ToolResultStatus.SUCCESS,
                            ),
                            is_final=False,
                        )

                        # End of current turn: save accumulated thought as one element
                        if current_turn_reasoning:
                            reasoning_per_turn.append("".join(current_turn_reasoning))
                            current_turn_reasoning.clear()

                elif isinstance(output, LLMResultChunk):
                    # Handle LLM result chunks - only process text content
                    message = output.delta.message

                    # Handle text content
                    if message and message.content:
                        chunk_text = message.content
                        if isinstance(chunk_text, list):
                            # Extract text from content list
                            chunk_text = "".join(getattr(c, "data", str(c)) for c in chunk_text)
                        else:
                            chunk_text = str(chunk_text)
                        for kind, segment in think_parser.process(chunk_text):
                            if not segment:
                                continue

                            if kind == "thought":
                                _flush_content()
                                current_turn_reasoning.append(segment)
                                pending_thought.append(segment)
                                yield ThoughtChunkEvent(
                                    selector=[self._node_id, "generation", "thought"],
                                    chunk=segment,
                                    is_final=False,
                                )
                            else:
                                _flush_thought()
                                text += segment
                                pending_content.append(segment)
                                yield StreamChunkEvent(
                                    selector=[self._node_id, "text"],
                                    chunk=segment,
                                    is_final=False,
                                )
                                yield StreamChunkEvent(
                                    selector=[self._node_id, "generation", "content"],
                                    chunk=segment,
                                    is_final=False,
                                )

                    if output.delta.usage:
                        self._accumulate_usage(usage, output.delta.usage)

                    # Capture finish reason
                    if output.delta.finish_reason:
                        finish_reason = output.delta.finish_reason

        except StopIteration as e:
            # Get the return value from generator
            if isinstance(getattr(e, "value", None), AgentResult):
                agent_result = e.value

        # Use result from generator if available
        if agent_result:
            text = agent_result.text or text
            files = agent_result.files
            if agent_result.usage:
                usage = agent_result.usage
            if agent_result.finish_reason:
                finish_reason = agent_result.finish_reason

        # Flush any remaining buffered content after streaming ends
        for kind, segment in think_parser.flush():
    def _flush_remaining_stream(
        self, buffers: StreamBuffers, trace_state: TraceState, aggregate: AggregatedResult
    ) -> Generator[NodeEventBase, None, None]:
        for kind, segment in buffers.think_parser.flush():
            if not segment:
                continue
            if kind == "thought":
                _flush_content()
                current_turn_reasoning.append(segment)
                pending_thought.append(segment)
                self._flush_content_segment(buffers, trace_state)
                buffers.current_turn_reasoning.append(segment)
                buffers.pending_thought.append(segment)
                yield ThoughtChunkEvent(
                    selector=[self._node_id, "generation", "thought"],
                    chunk=segment,
                    is_final=False,
                )
            else:
                _flush_thought()
                text += segment
                pending_content.append(segment)
                self._flush_thought_segment(buffers, trace_state)
                aggregate.text += segment
                buffers.pending_content.append(segment)
                yield StreamChunkEvent(
                    selector=[self._node_id, "text"],
                    chunk=segment,
@@ -1870,19 +1812,63 @@ class LLMNode(Node[LLMNodeData]):
                    is_final=False,
                )

        # Save the last turn's thought if any
        if current_turn_reasoning:
            reasoning_per_turn.append("".join(current_turn_reasoning))
        if buffers.current_turn_reasoning:
            buffers.reasoning_per_turn.append("".join(buffers.current_turn_reasoning))

        _flush_thought()
        _flush_content()
        self._flush_thought_segment(buffers, trace_state)
        self._flush_content_segment(buffers, trace_state)

        # Build sequence from trace_segments for rendering
    def _close_streams(self) -> Generator[NodeEventBase, None, None]:
        yield StreamChunkEvent(
            selector=[self._node_id, "text"],
            chunk="",
            is_final=True,
        )
        yield StreamChunkEvent(
            selector=[self._node_id, "generation", "content"],
            chunk="",
            is_final=True,
        )
        yield ThoughtChunkEvent(
            selector=[self._node_id, "generation", "thought"],
            chunk="",
            is_final=True,
        )
        yield ToolCallChunkEvent(
            selector=[self._node_id, "generation", "tool_calls"],
            chunk="",
            tool_call=ToolCall(
                id="",
                name="",
                arguments="",
            ),
            is_final=True,
        )
        yield ToolResultChunkEvent(
            selector=[self._node_id, "generation", "tool_results"],
            chunk="",
            tool_result=ToolResult(
                id="",
                name="",
                output="",
                files=[],
                status=ToolResultStatus.SUCCESS,
            ),
            is_final=True,
        )

    def _build_generation_data(
        self,
        trace_state: TraceState,
        agent_context: AgentContext,
        aggregate: AggregatedResult,
        buffers: StreamBuffers,
    ) -> LLMGenerationData:
        sequence: list[dict[str, Any]] = []
        reasoning_index = 0
        content_position = 0
        tool_call_seen_index: dict[str, int] = {}
        for trace_segment in trace_segments:
        for trace_segment in trace_state.trace_segments:
            if trace_segment.type == "thought":
                sequence.append({"type": "reasoning", "index": reasoning_index})
                reasoning_index += 1
@@ -1898,67 +1884,22 @@ class LLMNode(Node[LLMNodeData]):
                    tool_call_seen_index[tool_id] = len(tool_call_seen_index)
                sequence.append({"type": "tool_call", "index": tool_call_seen_index[tool_id]})

        # Send final events for all streams
        yield StreamChunkEvent(
            selector=[self._node_id, "text"],
            chunk="",
            is_final=True,
        )

        # Close generation sub-field streams
        yield StreamChunkEvent(
            selector=[self._node_id, "generation", "content"],
            chunk="",
            is_final=True,
        )
        yield ThoughtChunkEvent(
            selector=[self._node_id, "generation", "thought"],
            chunk="",
            is_final=True,
        )

        # Close tool_calls stream (already sent via ToolCallChunkEvent)
        yield ToolCallChunkEvent(
            selector=[self._node_id, "generation", "tool_calls"],
            chunk="",
            tool_call=ToolCall(
                id="",
                name="",
                arguments="",
            ),
            is_final=True,
        )

        # Close tool_results stream (already sent via ToolResultChunkEvent)
        yield ToolResultChunkEvent(
            selector=[self._node_id, "generation", "tool_results"],
            chunk="",
            tool_result=ToolResult(
                id="",
                name="",
                output="",
                files=[],
                status=ToolResultStatus.SUCCESS,
            ),
            is_final=True,
        )

        # Build tool_calls from agent_logs (with results)
        tool_calls_for_generation: list[ToolCallResult] = []
        for log in agent_logs:
            tool_call_id = log.data.get("tool_call_id")
        for log in agent_context.agent_logs:
            payload = ToolLogPayload.from_mapping(log.data or {})
            tool_call_id = payload.tool_call_id
            if not tool_call_id or log.status == AgentLog.LogStatus.START.value:
                continue

            tool_args = log.data.get("tool_args") or {}
            log_error = log.data.get("error")
            log_output = log.data.get("output")
            tool_args = payload.tool_args
            log_error = payload.tool_error
            log_output = payload.tool_output
            result_text = log_output or log_error or ""
            status = ToolResultStatus.ERROR if log_error else ToolResultStatus.SUCCESS
            tool_calls_for_generation.append(
                ToolCallResult(
                    id=tool_call_id,
                    name=log.data.get("tool_name", ""),
                    name=payload.tool_name,
                    arguments=json.dumps(tool_args) if tool_args else "",
                    output=result_text,
                    status=status,
@@ -1966,21 +1907,50 @@ class LLMNode(Node[LLMNodeData]):
                )

        tool_calls_for_generation.sort(
            key=lambda item: tool_call_index_map.get(item.id or "", len(tool_call_index_map))
            key=lambda item: trace_state.tool_call_index_map.get(item.id or "", len(trace_state.tool_call_index_map))
        )

        # Return generation data for caller
        return LLMGenerationData(
            text=text,
            reasoning_contents=reasoning_per_turn,
            text=aggregate.text,
            reasoning_contents=buffers.reasoning_per_turn,
            tool_calls=tool_calls_for_generation,
            sequence=sequence,
            usage=usage,
            finish_reason=finish_reason,
            files=files,
            trace=trace_segments,
            usage=aggregate.usage,
            finish_reason=aggregate.finish_reason,
            files=aggregate.files,
            trace=trace_state.trace_segments,
        )

    def _process_tool_outputs(
        self,
        outputs: Generator[LLMResultChunk | AgentLog, None, AgentResult],
    ) -> Generator[NodeEventBase, None, LLMGenerationData]:
        """Process strategy outputs and convert to node events."""
        state = ToolOutputState()

        try:
            for output in outputs:
                if isinstance(output, AgentLog):
                    yield from self._handle_agent_log_output(output, state.stream, state.trace, state.agent)
                else:
                    yield from self._handle_llm_chunk_output(output, state.stream, state.trace, state.aggregate)
        except StopIteration as exception:
            if isinstance(getattr(exception, "value", None), AgentResult):
                state.agent.agent_result = exception.value

        if state.agent.agent_result:
            state.aggregate.text = state.agent.agent_result.text or state.aggregate.text
            state.aggregate.files = state.agent.agent_result.files
            if state.agent.agent_result.usage:
                state.aggregate.usage = state.agent.agent_result.usage
            if state.agent.agent_result.finish_reason:
                state.aggregate.finish_reason = state.agent.agent_result.finish_reason

        yield from self._flush_remaining_stream(state.stream, state.trace, state.aggregate)
        yield from self._close_streams()

        return self._build_generation_data(state.trace, state.agent, state.aggregate, state.stream)

    def _accumulate_usage(self, total_usage: LLMUsage, delta_usage: LLMUsage) -> None:
        """Accumulate LLM usage statistics."""
        total_usage.prompt_tokens += delta_usage.prompt_tokens