Files
dify/api/dify_graph/nodes/llm/node.py
Novice a28f22e59d fix: resolve import errors and test failures after segment 4 merge
- Update BaseNodeData import path to dify_graph.entities.base_node_data
- Change NodeType.COMMAND/FILE_UPLOAD to BuiltinNodeTypes constants
- Fix system_oauth_encryption -> system_encryption rename in commands
- Remove tests for deleted agent runner modules
- Fix Avatar: named import + string size API in collaboration files
- Add missing skill feature deps: @monaco-editor/react, react-arborist,
  @tanstack/react-virtual
- Fix frontend test mocks: add useUserProfile, useLeaderRestoreListener,
  next/navigation mock, and nodeOutputVars to expected payload

Made-with: Cursor
2026-03-23 13:59:09 +08:00

2901 lines
121 KiB
Python

from __future__ import annotations
import base64
import io
import json
import logging
import mimetypes
import os
import re
import time
from collections.abc import Generator, Mapping, Sequence
from functools import reduce
from pathlib import PurePosixPath
from typing import TYPE_CHECKING, Any, Literal, cast
from sqlalchemy import select
from core.agent.entities import AgentEntity, AgentLog, AgentResult, AgentToolEntity, ExecutionContext
from core.agent.patterns import StrategyFactory
from core.helper.code_executor import CodeExecutor, CodeLanguage
from core.llm_generator.output_parser.errors import OutputParserError
from core.llm_generator.output_parser.file_ref import (
adapt_schema_for_sandbox_file_paths,
convert_sandbox_file_paths_in_output,
detect_file_path_fields,
)
from core.llm_generator.output_parser.structured_output import (
invoke_llm_with_structured_output,
)
from core.memory.base import BaseMemory
from core.model_manager import ModelInstance
from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.sandbox import Sandbox
from core.sandbox.bash.session import MAX_OUTPUT_FILE_SIZE, MAX_OUTPUT_FILES, SandboxBashSession
from core.sandbox.entities.config import AppAssets
from core.skill.assembler import SkillDocumentAssembler
from core.skill.constants import SkillAttrs
from core.skill.entities.skill_bundle import SkillBundle
from core.skill.entities.skill_document import SkillDocument
from core.skill.entities.tool_dependencies import ToolDependencies, ToolDependency
from core.tools.__base.tool import Tool
from core.tools.signature import sign_tool_file, sign_upload_file
from core.tools.tool_file_manager import ToolFileManager
from core.tools.tool_manager import ToolManager
from dify_graph.constants import SYSTEM_VARIABLE_NODE_ID
from dify_graph.entities import GraphInitParams, ToolCall, ToolResult, ToolResultStatus
from dify_graph.entities.graph_config import NodeConfigDict
from dify_graph.entities.tool_entities import ToolCallResult
from dify_graph.enums import (
BuiltinNodeTypes,
NodeType,
SystemVariableKey,
WorkflowNodeExecutionMetadataKey,
WorkflowNodeExecutionStatus,
)
from dify_graph.file import File, FileTransferMethod, FileType, file_manager
from dify_graph.model_runtime.entities import (
ImagePromptMessageContent,
PromptMessage,
PromptMessageContentType,
TextPromptMessageContent,
)
from dify_graph.model_runtime.entities.llm_entities import (
LLMResult,
LLMResultChunk,
LLMResultChunkWithStructuredOutput,
LLMResultWithStructuredOutput,
LLMStructuredOutput,
LLMUsage,
)
from dify_graph.model_runtime.entities.message_entities import (
AssistantPromptMessage,
PromptMessageContentUnionTypes,
PromptMessageRole,
SystemPromptMessage,
UserPromptMessage,
)
from dify_graph.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey
from dify_graph.model_runtime.memory import PromptMessageMemory
from dify_graph.model_runtime.utils.encoders import jsonable_encoder
from dify_graph.node_events import (
AgentLogEvent,
ModelInvokeCompletedEvent,
NodeEventBase,
NodeRunResult,
RunRetrieverResourceEvent,
StreamChunkEvent,
StreamCompletedEvent,
ThoughtChunkEvent,
ToolCallChunkEvent,
ToolResultChunkEvent,
)
from dify_graph.node_events.node import ChunkType, ThoughtEndChunkEvent, ThoughtStartChunkEvent
from dify_graph.nodes.base.entities import VariableSelector
from dify_graph.nodes.base.node import Node
from dify_graph.nodes.base.variable_template_parser import VariableTemplateParser
from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory
from dify_graph.nodes.protocols import HttpClientProtocol
from dify_graph.runtime import VariablePool
from dify_graph.variables import (
ArrayFileSegment,
ArrayPromptMessageSegment,
ArraySegment,
FileSegment,
NoneSegment,
ObjectSegment,
StringSegment,
)
from extensions.ext_database import db
from models.dataset import SegmentAttachmentBinding
from models.model import UploadFile
from . import llm_utils
from .entities import (
AgentContext,
AggregatedResult,
LLMGenerationData,
LLMNodeChatModelMessage,
LLMNodeCompletionModelPromptTemplate,
LLMNodeData,
LLMTraceSegment,
ModelTraceSegment,
PromptMessageContext,
StreamBuffers,
ThinkTagStreamParser,
ToolLogPayload,
ToolOutputState,
ToolTraceSegment,
TraceState,
)
from .exc import (
InvalidContextStructureError,
InvalidVariableTypeError,
LLMNodeError,
MemoryRolePrefixRequiredError,
NoPromptFoundError,
TemplateTypeNotSupportError,
VariableNotFoundError,
)
from .file_saver import FileSaverImpl, LLMFileSaver
if TYPE_CHECKING:
from dify_graph.file.models import File
from dify_graph.runtime import GraphRuntimeState
# Module-level logger; LLMNode._run uses it to log unexpected exceptions with a traceback.
logger = logging.getLogger(__name__)
class LLMNode(Node[LLMNodeData]):
node_type = BuiltinNodeTypes.LLM
# Compiled regex for extracting <think> blocks (with compatibility for attributes)
_THINK_PATTERN = re.compile(r"<think[^>]*>(.*?)</think>", re.IGNORECASE | re.DOTALL)
# Instance attributes specific to LLMNode.
# Output variable for file
_file_outputs: list[File]
_llm_file_saver: LLMFileSaver
_credentials_provider: CredentialsProvider
_model_factory: ModelFactory
_model_instance: ModelInstance
_memory: PromptMessageMemory | None
def __init__(
    self,
    id: str,
    config: NodeConfigDict,
    graph_init_params: GraphInitParams,
    graph_runtime_state: GraphRuntimeState,
    *,
    credentials_provider: CredentialsProvider,
    model_factory: ModelFactory,
    model_instance: ModelInstance,
    http_client: HttpClientProtocol,
    memory: PromptMessageMemory | None = None,
    llm_file_saver: LLMFileSaver | None = None,
):
    """Create an LLM node.

    Args:
        id: Node id, forwarded to ``Node.__init__``.
        config: Node configuration, forwarded to ``Node.__init__``.
        graph_init_params: Forwarded to ``Node.__init__``.
        graph_runtime_state: Forwarded to ``Node.__init__``.
        credentials_provider: Stored on the node.
        model_factory: Stored on the node.
        model_instance: The model this node invokes; stored on the node.
        http_client: Only used to build the default ``FileSaverImpl`` when
            ``llm_file_saver`` is not supplied.
        memory: Optional prompt-message memory; stored on the node.
        llm_file_saver: Saver for multimodal model output. When omitted, a
            ``FileSaverImpl`` bound to the Dify context's user and tenant is built.
    """
    super().__init__(
        id=id,
        config=config,
        graph_init_params=graph_init_params,
        graph_runtime_state=graph_runtime_state,
    )
    # Files produced by multimodal model output are accumulated here.
    self._file_outputs = []
    self._credentials_provider = credentials_provider
    self._model_factory = model_factory
    self._model_instance = model_instance
    self._memory = memory

    saver = llm_file_saver
    if saver is None:
        # The default saver needs the run context to know which user/tenant owns the files.
        run_context = self.require_dify_context()
        saver = FileSaverImpl(
            user_id=run_context.user_id,
            tenant_id=run_context.tenant_id,
            http_client=http_client,
        )
    self._llm_file_saver = saver
@classmethod
def version(cls) -> str:
    """Return the version identifier of this node implementation (always ``"1"``)."""
    current_version = "1"
    return current_version
def _run(self) -> Generator:
    """Execute the LLM node and stream its events.

    Stages:

    1. Resolve inputs:
       - the prompt template (static messages plus ``PromptMessageContext`` references);
       - vision files;
       - retrieval context;
       - memory and the optional memory query.
    2. Build prompt messages. Context references are interleaved with static messages
       via ``_build_prompt_messages_with_context`` when present; otherwise
       ``fetch_prompt_messages`` is used.
    3. Resolve the structured-output schema. When the schema declares file-path fields,
       it is adapted for sandbox file paths (sandbox required).
    4. Invoke the model through exactly one path:
       - sandbox (``computer_use``);
       - tool calling (``tool_call_enabled``);
       - a plain ``invoke_llm`` call.

       Events yielded by ``_stream_llm_events`` are forwarded.
    5. Convert sandbox file paths in the structured output into ``File`` objects. Then
       build outputs and process data, and yield a terminal ``StreamCompletedEvent``.

    Any ``Exception`` raised along the way is reported as a FAILED
    ``StreamCompletedEvent``. Only non-``ValueError`` exceptions are logged with a
    traceback.
    """
    node_inputs: dict[str, Any] = {}
    process_data: dict[str, Any] = {}
    usage = LLMUsage.empty_usage()
    finish_reason = None
    # Defined up front so they exist on every path through the method.
    clean_text = ""
    reasoning_content = ""
    variable_pool = self.graph_runtime_state.variable_pool

    try:
        # Read the template *before* parsing: _parse_prompt_template() may replace
        # node_data.prompt_template with only the static (non-context) messages.
        prompt_template = self.node_data.prompt_template
        static_messages, context_refs, template_order = self._parse_prompt_template()

        # Fetch variable values and jinja2 inputs from the variable pool.
        # NOTE(review): ``inputs`` is not read after this point; presumably these calls are
        # kept for their validation side effects — confirm before removing them.
        inputs = self._fetch_inputs(node_data=self.node_data)
        jinja_inputs = self._fetch_jinja_inputs(node_data=self.node_data)
        inputs.update(jinja_inputs)

        # Vision files are only collected when vision is enabled.
        files = (
            llm_utils.fetch_files(
                variable_pool=variable_pool,
                selector=self.node_data.vision.configs.variable_selector,
            )
            if self.node_data.vision.enabled
            else []
        )
        if files:
            node_inputs["#files#"] = [file.to_dict() for file in files]

        # Retrieval context: forward every event, keeping the values from the last one.
        context = None
        context_files: list[File] = []
        for event in self._fetch_context(node_data=self.node_data):
            context = event.context
            context_files = event.context_files or []
            yield event
        if context:
            node_inputs["#context#"] = context
        if context_files:
            node_inputs["#context_files#"] = [file.model_dump() for file in context_files]

        model_instance = self._model_instance
        model_name = model_instance.model_name
        model_provider = model_instance.provider
        model_stop = model_instance.stop

        # NOTE(review): __init__ stores an injected ``memory`` on self._memory, but this
        # method fetches memory through llm_utils instead — confirm which is intended.
        memory = llm_utils.fetch_memory(
            variable_pool=variable_pool,
            app_id=self.app_id,
            tenant_id=self.tenant_id,
            node_data_memory=self.node_data.memory,
            model_instance=model_instance,
            node_id=self._node_id,
        )

        # With memory configured, the memory query template wins; fall back to sys.query.
        query: str | None = None
        if self.node_data.memory:
            query = self.node_data.memory.query_prompt_template
            if not query and (
                query_variable := variable_pool.get((SYSTEM_VARIABLE_NODE_ID, SystemVariableKey.QUERY))
            ):
                query = query_variable.text

        prompt_messages: Sequence[PromptMessage]
        stop: Sequence[str] | None
        if isinstance(prompt_template, list) and context_refs:
            # NOTE(review): _build_prompt_messages_with_context is declared with a
            # ``model_config`` keyword rather than ``model_instance``; one side of this call
            # must be reconciled with the helpers that consume it.
            prompt_messages, stop = self._build_prompt_messages_with_context(
                context_refs=context_refs,
                template_order=template_order,
                static_messages=static_messages,
                query=query,
                files=files,
                context=context,
                memory=memory,
                model_instance=model_instance,
                context_files=context_files,
            )
        else:
            prompt_messages, stop = LLMNode.fetch_prompt_messages(
                sys_query=query,
                sys_files=files,
                context=context,
                memory=memory,
                model_instance=model_instance,
                stop=model_stop,
                prompt_template=cast(
                    Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate,
                    self.node_data.prompt_template,
                ),
                memory_config=self.node_data.memory,
                vision_enabled=self.node_data.vision.enabled,
                vision_detail=self.node_data.vision.configs.detail,
                variable_pool=variable_pool,
                jinja2_variables=self.node_data.prompt_config.jinja2_variables,
                context_files=context_files,
                sandbox=self.graph_runtime_state.sandbox,
            )

        # Resolve the structured-output schema (None when structured output is disabled).
        structured_output_schema: Mapping[str, Any] | None = None
        structured_output_file_paths: list[str] = []
        if self.node_data.structured_output_enabled:
            if not self.node_data.structured_output:
                raise LLMNodeError("structured_output_enabled is True but structured_output is not set")
            raw_schema = LLMNode.fetch_structured_output_schema(structured_output=self.node_data.structured_output)
            if self.node_data.computer_use:
                raise LLMNodeError("Structured output is not supported in computer use mode.")
            if detect_file_path_fields(raw_schema):
                # File-path fields can only be resolved against a sandbox filesystem.
                if not self.graph_runtime_state.sandbox:
                    raise LLMNodeError("Structured output file paths are only supported in sandbox mode.")
                structured_output_schema, structured_output_file_paths = adapt_schema_for_sandbox_file_paths(
                    raw_schema
                )
            else:
                structured_output_schema = raw_schema

        generation_data: LLMGenerationData | None = None
        structured_output: LLMStructuredOutput | None = None
        # Exactly one invocation path runs; each produces the generator consumed below.
        if self.node_data.computer_use:
            sandbox = self.graph_runtime_state.sandbox
            if not sandbox:
                raise LLMNodeError("computer use is enabled but no sandbox found")
            tool_dependencies: ToolDependencies | None = self._extract_tool_dependencies()
            generator = self._invoke_llm_with_sandbox(
                sandbox=sandbox,
                model_instance=model_instance,
                prompt_messages=prompt_messages,
                stop=stop,
                variable_pool=variable_pool,
                tool_dependencies=tool_dependencies,
            )
        elif self.tool_call_enabled:
            generator = self._invoke_llm_with_tools(
                model_instance=model_instance,
                prompt_messages=prompt_messages,
                stop=stop,
                files=files,
                variable_pool=variable_pool,
                node_inputs=node_inputs,
                process_data=process_data,
            )
        else:
            # Traditional single-call LLM invocation.
            generator = LLMNode.invoke_llm(
                model_instance=model_instance,
                prompt_messages=prompt_messages,
                stop=stop,
                user_id=self.require_dify_context().user_id,
                structured_output_schema=structured_output_schema,
                allow_file_path=bool(structured_output_file_paths),
                file_saver=self._llm_file_saver,
                file_outputs=self._file_outputs,
                node_id=self._node_id,
                node_type=self.node_type,
                reasoning_format=self.node_data.reasoning_format,
            )

        (
            clean_text,
            reasoning_content,
            generation_reasoning_content,
            generation_clean_content,
            usage,
            finish_reason,
            structured_output,
            generation_data,
        ) = yield from self._stream_llm_events(generator, model_instance=model_instance)

        # Replace sandbox file paths inside the structured output with File objects.
        if structured_output and structured_output_file_paths:
            sandbox = self.graph_runtime_state.sandbox
            if not sandbox:
                raise LLMNodeError("Structured output file paths are only supported in sandbox mode.")
            structured_output_value = structured_output.structured_output
            if structured_output_value is None:
                raise LLMNodeError("Structured output is empty")

            resolved_count = 0

            def resolve_file(path: str) -> File:
                # Refuse to pull more than MAX_OUTPUT_FILES files out of the sandbox.
                nonlocal resolved_count
                if resolved_count >= MAX_OUTPUT_FILES:
                    raise LLMNodeError("Structured output files exceed the sandbox output limit")
                resolved_count += 1
                return self._resolve_sandbox_file_path(sandbox=sandbox, path=path)

            converted_output, structured_output_files = convert_sandbox_file_paths_in_output(
                output=structured_output_value,
                file_path_fields=structured_output_file_paths,
                file_resolver=resolve_file,
            )
            structured_output = LLMStructuredOutput(structured_output=converted_output)
            if structured_output_files:
                self._file_outputs.extend(structured_output_files)

        # Multi-turn (tool/sandbox) runs report their own aggregated text, usage and finish reason.
        if generation_data:
            clean_text = generation_data.text
            reasoning_content = ""
            usage = generation_data.usage
            finish_reason = generation_data.finish_reason

        # NOTE(review): this rebinds ``process_data`` to a new dict, so anything the tool
        # path above may have written into the dict it received is discarded — confirm.
        process_data = {
            "model_mode": self.node_data.model.mode,
            "prompts": PromptMessageUtil.prompt_messages_to_prompt_for_saving(
                model_mode=self.node_data.model.mode, prompt_messages=prompt_messages
            ),
            "usage": jsonable_encoder(usage),
            "finish_reason": finish_reason,
            "model_provider": model_provider,
            "model_name": model_name,
        }
        if self.tool_call_enabled and self.node_data.tools:
            process_data["tools"] = [
                {
                    "type": tool.type.value if hasattr(tool.type, "value") else tool.type,
                    "provider_name": tool.provider_name,
                    "tool_name": tool.tool_name,
                }
                for tool in self.node_data.tools
                if tool.enabled
            ]

        is_sandbox = self.graph_runtime_state.sandbox is not None
        outputs = self._build_outputs(
            is_sandbox=is_sandbox,
            clean_text=clean_text,
            reasoning_content=reasoning_content,
            generation_reasoning_content=generation_reasoning_content,
            generation_clean_content=generation_clean_content,
            usage=usage,
            finish_reason=finish_reason,
            prompt_messages=prompt_messages,
            generation_data=generation_data,
            structured_output=structured_output,
        )

        # Send final chunk events to mark streaming as complete.
        # For tool calls and sandbox, final events are already sent in _process_tool_outputs.
        if not self.tool_call_enabled and not self.node_data.computer_use:
            yield StreamChunkEvent(
                selector=[self._node_id, "text"],
                chunk="",
                is_final=True,
            )
            yield StreamChunkEvent(
                selector=[self._node_id, "generation", "content"],
                chunk="",
                is_final=True,
            )
            yield ThoughtChunkEvent(
                selector=[self._node_id, "generation", "thought"],
                chunk="",
                is_final=True,
            )

        metadata: dict[WorkflowNodeExecutionMetadataKey, Any] = {
            WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens,
            WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price,
            WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency,
        }
        if generation_data and generation_data.trace:
            metadata[WorkflowNodeExecutionMetadataKey.LLM_TRACE] = [
                segment.model_dump() for segment in generation_data.trace
            ]

        yield StreamCompletedEvent(
            node_run_result=NodeRunResult(
                status=WorkflowNodeExecutionStatus.SUCCEEDED,
                inputs=node_inputs,
                process_data=process_data,
                outputs=outputs,
                metadata=metadata,
                llm_usage=usage,
            )
        )
    except ValueError as e:
        yield StreamCompletedEvent(
            node_run_result=NodeRunResult(
                status=WorkflowNodeExecutionStatus.FAILED,
                error=str(e),
                inputs=node_inputs,
                process_data=process_data,
                error_type=type(e).__name__,
                llm_usage=usage,
            )
        )
    except Exception as e:
        logger.exception("error while executing llm node")
        yield StreamCompletedEvent(
            node_run_result=NodeRunResult(
                status=WorkflowNodeExecutionStatus.FAILED,
                error=str(e),
                inputs=node_inputs,
                process_data=process_data,
                error_type=type(e).__name__,
                llm_usage=usage,
            )
        )
def _build_outputs(
    self,
    *,
    is_sandbox: bool,
    clean_text: str,
    reasoning_content: str,
    generation_reasoning_content: str,
    generation_clean_content: str,
    usage: LLMUsage,
    finish_reason: str | None,
    prompt_messages: Sequence[PromptMessage],
    generation_data: LLMGenerationData | None,
    structured_output: LLMStructuredOutput | None,
) -> dict[str, Any]:
    """Assemble the node's output mapping.

    Output shape depends on the runtime:

    - Classical runtime (``is_sandbox`` false): also exposes top-level ``text`` and
      ``reasoning_content`` for backward compatibility.
    - Sandbox runtime: omits those two keys, since ``generation`` already carries
      the same information.

    ``usage``, ``finish_reason``, ``context`` and ``generation`` are always present.
    ``files`` and ``structured_output`` are added only when non-empty/truthy.

    Args:
        is_sandbox: Whether the current runtime is sandbox mode.
        clean_text: Text for ``outputs["text"]``; may still contain ``<think>`` tags in
            "tagged" reasoning format.
        reasoning_content: Native model reasoning from the API response.
        generation_reasoning_content: Tag-free reasoning for ``generation``; falls back
            to ``reasoning_content`` when empty.
        generation_clean_content: Tag-free text for ``generation``; falls back to
            ``clean_text`` when empty.
        usage: LLM usage statistics.
        finish_reason: Finish reason reported by the model.
        prompt_messages: The prompt sent to the model (used to build ``context``).
        generation_data: Multi-turn data from tool/sandbox invocation, or None.
        structured_output: Structured output, when enabled.

    Returns:
        The outputs dictionary.
    """
    result: dict[str, Any] = {
        "usage": jsonable_encoder(usage),
        "finish_reason": finish_reason,
        "context": llm_utils.build_context(prompt_messages, clean_text, generation_data),
    }
    # Only the classical runtime keeps the legacy top-level text fields.
    if not is_sandbox:
        result["text"] = clean_text
        result["reasoning_content"] = reasoning_content

    if not generation_data:
        # Single-call run: derive the generation view from the pre-split text pair,
        # falling back to native reasoning / raw text when nothing was split out.
        thought = generation_reasoning_content or reasoning_content
        answer = generation_clean_content or clean_text
        generation: dict[str, Any] = {
            "content": answer,
            "reasoning_content": [thought] if thought else [],
            "tool_calls": [],
            "sequence": (
                [
                    {"type": "reasoning", "index": 0},
                    {"type": "content", "start": 0, "end": len(answer)},
                ]
                if thought
                else []
            ),
        }
        files_out = self._file_outputs
    else:
        # Tool/sandbox run: generation_data already captured every turn.
        generation = {
            "content": generation_data.text,
            "reasoning_content": generation_data.reasoning_contents,
            "tool_calls": [self._serialize_tool_call(call) for call in generation_data.tool_calls],
            "sequence": generation_data.sequence,
        }
        files_out = [*generation_data.files]
        if self._file_outputs:
            # Merge files collected on the node (e.g. from structured output), skipping duplicate ids.
            known_ids = {item.id for item in files_out}
            files_out.extend(item for item in self._file_outputs if item.id not in known_ids)

    result["generation"] = generation
    if files_out:
        result["files"] = ArrayFileSegment(value=files_out)
    if structured_output:
        result["structured_output"] = structured_output.structured_output
    return result
@staticmethod
def invoke_llm(
    *,
    model_instance: ModelInstance,
    prompt_messages: Sequence[PromptMessage],
    stop: Sequence[str] | None = None,
    user_id: str,
    structured_output_schema: Mapping[str, Any] | None,
    allow_file_path: bool = False,
    file_saver: LLMFileSaver,
    file_outputs: list[File],
    node_id: str,
    node_type: NodeType,
    reasoning_format: Literal["separated", "tagged"] = "tagged",
) -> Generator[NodeEventBase | LLMStructuredOutput, None, None]:
    """Call the model and return a generator of node events for the result.

    Routing depends on ``structured_output_schema``:

    - Truthy: the request goes through ``invoke_llm_with_structured_output``, with
      ``allow_file_path`` forwarded.
    - Otherwise: ``model_instance.invoke_llm`` is called with ``stream=True``.

    This function is not itself a generator, so the model request is issued
    immediately. The raw result is wrapped by ``handle_invoke_result``, which receives
    the ``time.perf_counter()`` timestamp taken just before the request for latency
    metrics.
    """
    invoke_params = dict(model_instance.parameters)
    # Looked up on every call, although only the structured-output path consumes it.
    schema = llm_utils.fetch_model_schema(model_instance=model_instance)

    invoke_result: LLMResult | Generator[LLMResultChunk | LLMStructuredOutput, None, None]
    request_start_time = time.perf_counter()
    if not structured_output_schema:
        invoke_result = model_instance.invoke_llm(
            prompt_messages=list(prompt_messages),
            model_parameters=invoke_params,
            stop=list(stop or []),
            stream=True,
            user=user_id,
        )
    else:
        invoke_result = invoke_llm_with_structured_output(
            provider=model_instance.provider,
            model_schema=schema,
            model_instance=model_instance,
            prompt_messages=prompt_messages,
            json_schema=structured_output_schema,
            model_parameters=invoke_params,
            stop=list(stop or []),
            user=user_id,
            allow_file_path=allow_file_path,
        )

    return LLMNode.handle_invoke_result(
        invoke_result=invoke_result,
        file_saver=file_saver,
        file_outputs=file_outputs,
        node_id=node_id,
        node_type=node_type,
        reasoning_format=reasoning_format,
        request_start_time=request_start_time,
    )
@staticmethod
def handle_invoke_result(
    *,
    invoke_result: LLMResult | Generator[LLMResultChunk | LLMStructuredOutput, None, None],
    file_saver: LLMFileSaver,
    file_outputs: list[File],
    node_id: str,
    node_type: NodeType,
    reasoning_format: Literal["separated", "tagged"] = "tagged",
    request_start_time: float | None = None,
) -> Generator[NodeEventBase | LLMStructuredOutput, None, None]:
    """Translate a model invocation result into node events.

    A blocking ``LLMResult`` is handed to ``handle_blocking_result`` and yields a
    single event. When ``request_start_time`` is given, the result's usage latency is
    set first.

    A streaming result is consumed chunk by chunk:

    * Chunks of type ``LLMResultChunkWithStructuredOutput`` are forwarded unchanged,
      and their non-None structured output is remembered.
    * Each text part is streamed verbatim on ``[node_id, "text"]``, with ``<think>``
      tags intact.
    * ``ThinkTagStreamParser`` then splits each text part into thought markers/text on
      ``[node_id, "generation", "thought"]`` and answer text on
      ``[node_id, "generation", "content"]``.
    * A final ``ModelInvokeCompletedEvent`` carries:
      - the text;
      - usage, with latency plus TTFT/generation time once any text arrived;
      - the finish reason;
      - reasoning content;
      - the collected structured output.

    Args:
        invoke_result: Blocking result or stream of chunks from the model.
        file_saver: Saver for multimodal (non-text) output parts.
        file_outputs: List that receives files saved from multimodal output.
        node_id: Id used to build event selectors.
        node_type: Node type (not read by this method).
        reasoning_format: "tagged" keeps ``<think>`` tags in the final text;
            "separated" strips them via ``_split_reasoning``.
        request_start_time: ``time.perf_counter()`` value taken before the request; used
            as the start of the latency window when given.

    Raises:
        LLMNodeError: If the stream raises ``OutputParserError`` (chained as the cause).
    """
    # ---- Blocking mode ----
    if isinstance(invoke_result, LLMResult):
        duration = None
        if request_start_time is not None:
            duration = time.perf_counter() - request_start_time
            invoke_result.usage.latency = round(duration, 3)
        yield LLMNode.handle_blocking_result(
            invoke_result=invoke_result,
            saver=file_saver,
            file_outputs=file_outputs,
            reasoning_format=reasoning_format,
            request_latency=duration,
        )
        return

    # ---- Streaming mode ----
    usage = LLMUsage.empty_usage()
    finish_reason = None
    full_text_buffer = io.StringIO()
    think_parser = ThinkTagStreamParser()
    reasoning_chunks: list[str] = []
    # Streaming metrics: prefer the caller's pre-request timestamp as the window start.
    start_time = request_start_time if request_start_time is not None else time.perf_counter()
    first_token_time: float | None = None
    has_content = False
    collected_structured_output = None

    def generation_events(kind: str, segment: str) -> Generator[NodeEventBase, None, None]:
        # Map one ThinkTagStreamParser item to generation-stream events. Empty segments are
        # dropped except for the thought start/end markers; thought text is also kept in
        # reasoning_chunks so it can become the final reasoning_content.
        if not segment and kind not in {"thought_start", "thought_end"}:
            return
        if kind == "thought_start":
            yield ThoughtStartChunkEvent(
                selector=[node_id, "generation", "thought"],
                chunk="",
                is_final=False,
            )
        elif kind == "thought":
            reasoning_chunks.append(segment)
            yield ThoughtChunkEvent(
                selector=[node_id, "generation", "thought"],
                chunk=segment,
                is_final=False,
            )
        elif kind == "thought_end":
            yield ThoughtEndChunkEvent(
                selector=[node_id, "generation", "thought"],
                chunk="",
                is_final=False,
            )
        else:
            yield StreamChunkEvent(
                selector=[node_id, "generation", "content"],
                chunk=segment,
                is_final=False,
            )

    try:
        for result in invoke_result:
            if isinstance(result, LLMResultChunkWithStructuredOutput):
                if result.structured_output is not None:
                    collected_structured_output = dict(result.structured_output)
                # Forward the raw chunk unchanged; if it is also an LLMResultChunk, its
                # text is processed below as well.
                yield result
            if not isinstance(result, LLMResultChunk):
                continue

            for text_part in LLMNode._save_multimodal_output_and_convert_result_to_markdown(
                contents=result.delta.message.content,
                file_saver=file_saver,
                file_outputs=file_outputs,
            ):
                # The first non-empty part marks time-to-first-token.
                if text_part and not has_content:
                    first_token_time = time.perf_counter()
                    has_content = True
                full_text_buffer.write(text_part)
                # The "text" output always receives the raw chunk, <think> tags included.
                yield StreamChunkEvent(
                    selector=[node_id, "text"],
                    chunk=text_part,
                    is_final=False,
                )
                for kind, segment in think_parser.process(text_part):
                    yield from generation_events(kind, segment)

            # Keep the first non-trivial usage and finish reason reported by the stream.
            if usage.prompt_tokens == 0 and result.delta.usage:
                usage = result.delta.usage
            if finish_reason is None and result.delta.finish_reason:
                finish_reason = result.delta.finish_reason
    except OutputParserError as e:
        raise LLMNodeError(f"Failed to parse structured output: {e}") from e

    # Emit whatever the parser was still buffering (e.g. an unterminated tag).
    for kind, segment in think_parser.flush():
        yield from generation_events(kind, segment)

    full_text = full_text_buffer.getvalue()
    if reasoning_format == "tagged":
        # Keep <think> tags in the text for backward compatibility.
        clean_text = full_text
        reasoning_content = "".join(reasoning_chunks)
    else:
        # Strip <think> blocks from the text and use their contents as reasoning.
        clean_text, reasoning_content = LLMNode._split_reasoning(full_text, reasoning_format)
        if reasoning_chunks and not reasoning_content:
            reasoning_content = "".join(reasoning_chunks)

    end_time = time.perf_counter()
    usage.latency = round(end_time - start_time, 3)
    if has_content and first_token_time:
        usage.time_to_first_token = round(first_token_time - start_time, 3)
        usage.time_to_generate = round(end_time - first_token_time, 3)

    yield ModelInvokeCompletedEvent(
        # Clean text for separated mode, the untouched full text for tagged mode.
        text=clean_text if reasoning_format == "separated" else full_text,
        usage=usage,
        finish_reason=finish_reason,
        reasoning_content=reasoning_content,
        structured_output=collected_structured_output,
    )
@staticmethod
def _image_file_to_markdown(file: File, /):
    """Render ``file`` as an inline Markdown image pointing at ``file.generate_url()``."""
    url = file.generate_url()
    return f"![]({url})"
@classmethod
def _split_reasoning(
    cls, text: str, reasoning_format: Literal["separated", "tagged"] = "tagged"
) -> tuple[str, str]:
    """Separate ``<think>`` reasoning blocks from the visible answer text.

    Args:
        text: Model output that may contain ``<think ...>...</think>`` blocks. Tags are
            matched case-insensitively, and attributes on the opening tag are allowed.
        reasoning_format: Controls the split:
            - ``"tagged"``: return ``text`` untouched with empty reasoning;
            - ``"separated"``: remove every closed think block.

    Returns:
        ``(clean_text, reasoning_content)``. In separated mode:
        - the reasoning is the stripped body of each block, joined with newlines;
        - the clean text has blank-line runs collapsed to one blank line and is trimmed.
    """
    if reasoning_format == "tagged":
        return text, ""

    thoughts = [block.strip() for block in cls._THINK_PATTERN.findall(text)]
    answer = cls._THINK_PATTERN.sub("", text)
    # Removing blocks can leave stacked blank lines behind; squeeze them.
    answer = re.sub(r"\n\s*\n", "\n\n", answer).strip()
    return answer, "\n".join(thoughts)
def _transform_chat_messages(
    self, messages: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate, /
) -> Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate:
    """Copy ``jinja2_text`` into ``text`` for jinja2-edition entries.

    Works on a single completion template or a sequence of chat messages. An entry is
    updated when its ``edition_type`` is ``"jinja2"`` and ``jinja2_text`` is non-empty.
    Entries are mutated in place, and the same ``messages`` object is returned.
    """
    targets = [messages] if isinstance(messages, LLMNodeCompletionModelPromptTemplate) else messages
    for entry in targets:
        if entry.edition_type == "jinja2" and entry.jinja2_text:
            entry.text = entry.jinja2_text
    return messages
def _parse_prompt_template(
    self,
) -> tuple[list[LLMNodeChatModelMessage], list[PromptMessageContext], list[tuple[int, str]]]:
    """Split ``node_data.prompt_template`` into static messages and context references.

    Returns:
        A tuple ``(static_messages, context_refs, template_order)``:
        - ``static_messages``: non-context entries, in order;
        - ``context_refs``: ``PromptMessageContext`` entries, in order;
        - ``template_order``: one ``(index, "static" | "context")`` pair per template
          entry, in original order.
        A non-list (completion) template yields three empty lists.

    Side effect: when at least one static message exists, the jinja2 transform is
    applied to the static messages. ``node_data.prompt_template`` is then replaced by
    the list of static messages only.

    NOTE(review): context references are therefore no longer present in
    ``node_data.prompt_template`` after this call. If the same node instance runs
    again (e.g. a retry), confirm the context references are not lost.
    """
    template = self.node_data.prompt_template
    static_messages: list[LLMNodeChatModelMessage] = []
    context_refs: list[PromptMessageContext] = []
    template_order: list[tuple[int, str]] = []

    if not isinstance(template, list):
        return static_messages, context_refs, template_order

    for position, entry in enumerate(template):
        if not isinstance(entry, PromptMessageContext):
            static_messages.append(entry)
            template_order.append((position, "static"))
            continue
        context_refs.append(entry)
        template_order.append((position, "context"))

    if static_messages:
        self.node_data.prompt_template = self._transform_chat_messages(static_messages)
    return static_messages, context_refs, template_order
def _build_prompt_messages_with_context(
    self,
    *,
    context_refs: list[PromptMessageContext],
    template_order: list[tuple[int, str]],
    static_messages: list[LLMNodeChatModelMessage],
    query: str | None,
    files: Sequence[File],
    context: str | None,
    memory: BaseMemory | None,
    model_config: ModelConfigWithCredentialsEntity,
    context_files: list[File],
) -> tuple[list[PromptMessage], Sequence[str] | None]:
    """Assemble prompt messages, interleaving static messages and context refs in template order.

    Message order:

    1. Template entries in ``template_order``:
       - each ``"context"`` slot expands to the restored messages of its
         ``array[message]`` variable;
       - every other slot renders the next static message through
         ``handle_list_messages``.
    2. Memory messages.
    3. The optional user ``query``.

    Sys/context files are then attached via ``_append_files_to_messages``, and the
    result is filtered via ``_filter_messages``.

    Returns:
        ``(prompt_messages, stop_sequences)``.

    Raises:
        VariableNotFoundError: A context reference points at a missing variable.
        InvalidVariableTypeError: A context reference is not ``array[message]``.

    NOTE(review): ``LLMNode._run`` calls this method with ``model_instance=`` instead of
    ``model_config=``, which does not match this signature. Reconcile it together with
    the helpers that consume ``model_config`` (``_handle_memory_chat_mode``,
    ``_filter_messages``, ``_get_stop_sequences``).
    """
    pool = self.graph_runtime_state.variable_pool

    assembled: list[PromptMessage] = []
    next_context = 0
    next_static = 0
    for _, slot in template_order:
        if slot != "context":
            rendered = LLMNode.handle_list_messages(
                messages=[static_messages[next_static]],
                context=context,
                jinja2_variables=self.node_data.prompt_config.jinja2_variables or [],
                variable_pool=pool,
                vision_detail_config=self.node_data.vision.configs.detail,
                sandbox=self.graph_runtime_state.sandbox,
            )
            assembled.extend(rendered)
            next_static += 1
            continue

        reference = context_refs[next_context]
        segment = pool.get(reference.value_selector)
        if segment is None:
            raise VariableNotFoundError(f"Variable {'.'.join(reference.value_selector)} not found")
        if not isinstance(segment, ArrayPromptMessageSegment):
            raise InvalidVariableTypeError(f"Variable {'.'.join(reference.value_selector)} is not array[message]")
        # Multimodal payloads (base64/url) were truncated when the context was saved; restore them.
        assembled.extend(llm_utils.restore_multimodal_content_in_messages(segment.value))
        next_context += 1

    assembled.extend(
        _handle_memory_chat_mode(
            memory=memory,
            memory_config=self.node_data.memory,
            model_config=model_config,
        )
    )

    if query:
        query_message = LLMNodeChatModelMessage(
            text=query,
            role=PromptMessageRole.USER,
            edition_type="basic",
        )
        assembled.extend(
            LLMNode.handle_list_messages(
                messages=[query_message],
                context="",
                jinja2_variables=[],
                variable_pool=pool,
                vision_detail_config=self.node_data.vision.configs.detail,
            )
        )

    assembled = self._append_files_to_messages(
        messages=assembled,
        sys_files=files,
        context_files=context_files,
        model_config=model_config,
    )
    assembled = self._filter_messages(assembled, model_config)
    return assembled, self._get_stop_sequences(model_config)
def _append_files_to_messages(
self,
*,
messages: list[PromptMessage],
sys_files: Sequence[File],
context_files: list[File],
model_config: ModelConfigWithCredentialsEntity,
) -> list[PromptMessage]:
"""Append sys_files and context_files to messages."""
vision_enabled = self.node_data.vision.enabled
vision_detail = self.node_data.vision.configs.detail
# Handle sys_files (will be deprecated later)
if vision_enabled and sys_files:
file_prompts = [
file_manager.to_prompt_message_content(file, image_detail_config=vision_detail) for file in sys_files
]
if messages and isinstance(messages[-1], UserPromptMessage) and isinstance(messages[-1].content, list):
messages[-1] = UserPromptMessage(content=file_prompts + messages[-1].content)
else:
messages.append(UserPromptMessage(content=file_prompts))
# Handle context_files
if vision_enabled and context_files:
file_prompts = [
file_manager.to_prompt_message_content(file, image_detail_config=vision_detail)
for file in context_files
]
if messages and isinstance(messages[-1], UserPromptMessage) and isinstance(messages[-1].content, list):
messages[-1] = UserPromptMessage(content=file_prompts + messages[-1].content)
else:
messages.append(UserPromptMessage(content=file_prompts))
return messages
def _filter_messages(
self, messages: list[PromptMessage], model_config: ModelConfigWithCredentialsEntity
) -> list[PromptMessage]:
"""Filter empty messages and unsupported content types."""
filtered_messages: list[PromptMessage] = []
for message in messages:
if isinstance(message.content, list):
filtered_content: list[PromptMessageContentUnionTypes] = []
for content_item in message.content:
# Skip non-text content if features are not defined
if not model_config.model_schema.features:
if content_item.type != PromptMessageContentType.TEXT:
continue
filtered_content.append(content_item)
continue
# Skip content if corresponding feature is not supported
feature_map = {
PromptMessageContentType.IMAGE: ModelFeature.VISION,
PromptMessageContentType.DOCUMENT: ModelFeature.DOCUMENT,
PromptMessageContentType.VIDEO: ModelFeature.VIDEO,
PromptMessageContentType.AUDIO: ModelFeature.AUDIO,
}
required_feature = feature_map.get(content_item.type)
if required_feature and required_feature not in model_config.model_schema.features:
continue
filtered_content.append(content_item)
# Simplify single text content
if len(filtered_content) == 1 and filtered_content[0].type == PromptMessageContentType.TEXT:
message.content = filtered_content[0].data
else:
message.content = filtered_content
if not message.is_empty():
filtered_messages.append(message)
if not filtered_messages:
raise NoPromptFoundError(
"No prompt found in the LLM configuration. "
"Please ensure a prompt is properly configured before proceeding."
)
return filtered_messages
def _get_stop_sequences(self, model_config: ModelConfigWithCredentialsEntity) -> Sequence[str] | None:
"""Get stop sequences from model config."""
return model_config.stop
def _fetch_jinja_inputs(self, node_data: LLMNodeData) -> dict[str, str]:
variables: dict[str, Any] = {}
if not node_data.prompt_config:
return variables
for variable_selector in node_data.prompt_config.jinja2_variables or []:
variable_name = variable_selector.variable
variable = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector)
if variable is None:
raise VariableNotFoundError(f"Variable {variable_selector.variable} not found")
def parse_dict(input_dict: Mapping[str, Any]) -> str:
"""
Parse dict into string
"""
# check if it's a context structure
if "metadata" in input_dict and "_source" in input_dict["metadata"] and "content" in input_dict:
return str(input_dict["content"])
# else, parse the dict
try:
return json.dumps(input_dict, ensure_ascii=False)
except Exception:
return str(input_dict)
if isinstance(variable, ArraySegment):
result = ""
for item in variable.value:
if isinstance(item, dict):
result += parse_dict(item)
else:
result += str(item)
result += "\n"
value = result.strip()
elif isinstance(variable, ObjectSegment):
value = parse_dict(variable.value)
else:
value = variable.text
variables[variable_name] = value
return variables
def _fetch_inputs(self, node_data: LLMNodeData) -> dict[str, Any]:
inputs = {}
prompt_template = node_data.prompt_template
variable_selectors = []
if isinstance(prompt_template, list):
for prompt in prompt_template:
variable_template_parser = VariableTemplateParser(template=prompt.text)
variable_selectors.extend(variable_template_parser.extract_variable_selectors())
elif isinstance(prompt_template, CompletionModelPromptTemplate):
variable_template_parser = VariableTemplateParser(template=prompt_template.text)
variable_selectors = variable_template_parser.extract_variable_selectors()
for variable_selector in variable_selectors:
variable = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector)
if variable is None:
raise VariableNotFoundError(f"Variable {variable_selector.variable} not found")
if isinstance(variable, NoneSegment):
inputs[variable_selector.variable] = ""
inputs[variable_selector.variable] = variable.to_object()
memory = node_data.memory
if memory and memory.query_prompt_template:
query_variable_selectors = VariableTemplateParser(
template=memory.query_prompt_template
).extract_variable_selectors()
for variable_selector in query_variable_selectors:
variable = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector)
if variable is None:
raise VariableNotFoundError(f"Variable {variable_selector.variable} not found")
if isinstance(variable, NoneSegment):
continue
inputs[variable_selector.variable] = variable.to_object()
return inputs
def _fetch_context(self, node_data: LLMNodeData):
if not node_data.context.enabled:
return
if not node_data.context.variable_selector:
return
context_value_variable = self.graph_runtime_state.variable_pool.get(node_data.context.variable_selector)
if context_value_variable:
if isinstance(context_value_variable, StringSegment):
yield RunRetrieverResourceEvent(
retriever_resources=[], context=context_value_variable.value, context_files=[]
)
elif isinstance(context_value_variable, ArraySegment):
context_str = ""
original_retriever_resource: list[dict[str, Any]] = []
context_files: list[File] = []
for item in context_value_variable.value:
if isinstance(item, str):
context_str += item + "\n"
else:
if "content" not in item:
raise InvalidContextStructureError(f"Invalid context structure: {item}")
if item.get("summary"):
context_str += item["summary"] + "\n"
context_str += item["content"] + "\n"
retriever_resource = self._convert_to_original_retriever_resource(item)
if retriever_resource:
original_retriever_resource.append(retriever_resource)
segment_id = retriever_resource.get("segment_id")
if not segment_id:
continue
attachments_with_bindings = db.session.execute(
select(SegmentAttachmentBinding, UploadFile)
.join(UploadFile, UploadFile.id == SegmentAttachmentBinding.attachment_id)
.where(
SegmentAttachmentBinding.segment_id == segment_id,
)
).all()
if attachments_with_bindings:
for _, upload_file in attachments_with_bindings:
attachment_info = File(
id=upload_file.id,
filename=upload_file.name,
extension="." + upload_file.extension,
mime_type=upload_file.mime_type,
tenant_id=self.require_dify_context().tenant_id,
type=FileType.IMAGE,
transfer_method=FileTransferMethod.LOCAL_FILE,
remote_url=upload_file.source_url,
related_id=upload_file.id,
size=upload_file.size,
storage_key=upload_file.key,
url=sign_upload_file(upload_file.id, upload_file.extension),
)
context_files.append(attachment_info)
yield RunRetrieverResourceEvent(
retriever_resources=original_retriever_resource,
context=context_str.strip(),
context_files=context_files,
)
def _convert_to_original_retriever_resource(self, context_dict: dict) -> dict[str, Any] | None:
if (
"metadata" in context_dict
and "_source" in context_dict["metadata"]
and context_dict["metadata"]["_source"] == "knowledge"
):
metadata = context_dict.get("metadata", {})
return {
"position": metadata.get("position"),
"dataset_id": metadata.get("dataset_id"),
"dataset_name": metadata.get("dataset_name"),
"document_id": metadata.get("document_id"),
"document_name": metadata.get("document_name"),
"data_source_type": metadata.get("data_source_type"),
"segment_id": metadata.get("segment_id"),
"retriever_from": metadata.get("retriever_from"),
"score": metadata.get("score"),
"hit_count": metadata.get("segment_hit_count"),
"word_count": metadata.get("segment_word_count"),
"segment_position": metadata.get("segment_position"),
"index_node_hash": metadata.get("segment_index_node_hash"),
"content": context_dict.get("content"),
"page": metadata.get("page"),
"doc_metadata": metadata.get("doc_metadata"),
"files": context_dict.get("files"),
"summary": context_dict.get("summary"),
}
return None
    @staticmethod
    def fetch_prompt_messages(
        *,
        sys_query: str | None = None,
        sys_files: Sequence[File],
        context: str | None = None,
        memory: BaseMemory | None = None,
        model_instance: ModelInstance,
        prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate,
        stop: Sequence[str] | None = None,
        memory_config: MemoryConfig | None = None,
        vision_enabled: bool = False,
        vision_detail: ImagePromptMessageContent.DETAIL,
        variable_pool: VariablePool,
        jinja2_variables: Sequence[VariableSelector],
        context_files: list[File] | None = None,
        sandbox: Sandbox | None = None,
    ) -> tuple[Sequence[PromptMessage], Sequence[str] | None]:
        """Build the final prompt message list for an LLM invocation (no node instance required).

        Chat templates (a list) are rendered in this order:
        1. The template messages, via ``handle_list_messages``.
        2. Memory messages from ``_handle_memory_chat_mode``.
        3. When ``sys_query`` is truthy, a trailing user message containing the query.

        Completion templates are rendered via ``_handle_completion_template``. The memory text
        and ``sys_query`` are then injected into the first message's content.

        When ``vision_enabled`` is true, ``sys_files`` and then ``context_files`` are converted
        to prompt contents. Each group is prepended to the last message when that message is a
        ``UserPromptMessage`` with list content; otherwise it is appended as a new user message.

        Finally:
        - content parts whose modality the model schema does not declare are dropped;
        - a single remaining text part is collapsed to a plain string;
        - empty messages are removed.

        Returns:
            The filtered prompt messages, and ``stop`` passed through unchanged.

        Raises:
            TemplateTypeNotSupportError: if ``prompt_template`` is neither a list nor a
                ``LLMNodeCompletionModelPromptTemplate``.
            ValueError: if the first completion message content is neither ``str`` nor ``list``.
            NoPromptFoundError: if no message survives filtering.
        """
        prompt_messages: list[PromptMessage] = []
        model_schema = llm_utils.fetch_model_schema(model_instance=model_instance)
        if isinstance(prompt_template, list):
            # For chat model
            prompt_messages.extend(
                LLMNode.handle_list_messages(
                    messages=prompt_template,
                    context=context,
                    jinja2_variables=jinja2_variables,
                    variable_pool=variable_pool,
                    vision_detail_config=vision_detail,
                    sandbox=sandbox,
                )
            )
            # Get memory messages for chat mode
            memory_messages = _handle_memory_chat_mode(
                memory=memory,
                memory_config=memory_config,
                model_instance=model_instance,
            )
            # Extend prompt_messages with memory messages
            prompt_messages.extend(memory_messages)
            # Add current query to the prompt messages
            if sys_query:
                message = LLMNodeChatModelMessage(
                    text=sys_query,
                    role=PromptMessageRole.USER,
                    edition_type="basic",
                )
                # No sandbox is passed here, so skill assembly is not applied to the raw query.
                prompt_messages.extend(
                    LLMNode.handle_list_messages(
                        messages=[message],
                        context="",
                        jinja2_variables=[],
                        variable_pool=variable_pool,
                        vision_detail_config=vision_detail,
                    )
                )
        elif isinstance(prompt_template, LLMNodeCompletionModelPromptTemplate):
            # For completion model
            prompt_messages.extend(
                _handle_completion_template(
                    template=prompt_template,
                    context=context,
                    jinja2_variables=jinja2_variables,
                    variable_pool=variable_pool,
                )
            )
            # Get memory text for completion model
            memory_text = _handle_memory_completion_mode(
                memory=memory,
                memory_config=memory_config,
                model_instance=model_instance,
            )
            # Insert histories into the prompt
            # NOTE(review): assumes _handle_completion_template returned at least one message;
            # an empty result would raise IndexError here.
            prompt_content = prompt_messages[0].content
            # For issue #11247 - Check if prompt content is a string or a list
            prompt_content_type = type(prompt_content)
            if prompt_content_type == str:
                prompt_content = str(prompt_content)
                # Replace the #histories# placeholder, or prepend the memory text when it is absent.
                if "#histories#" in prompt_content:
                    prompt_content = prompt_content.replace("#histories#", memory_text)
                else:
                    prompt_content = memory_text + "\n" + prompt_content
                prompt_messages[0].content = prompt_content
            elif prompt_content_type == list:
                # prompt_content is the same list object as prompt_messages[0].content, so the
                # text parts below are edited in place.
                prompt_content = prompt_content if isinstance(prompt_content, list) else []
                for content_item in prompt_content:
                    if content_item.type == PromptMessageContentType.TEXT:
                        if "#histories#" in content_item.data:
                            content_item.data = content_item.data.replace("#histories#", memory_text)
                        else:
                            content_item.data = memory_text + "\n" + content_item.data
            else:
                raise ValueError("Invalid prompt content type")
            # Add current query to the prompt message
            if sys_query:
                if prompt_content_type == str:
                    # String content: the query only replaces the #sys.query# placeholder; if the
                    # placeholder is absent the query is not added.
                    prompt_content = str(prompt_messages[0].content).replace("#sys.query#", sys_query)
                    prompt_messages[0].content = prompt_content
                elif prompt_content_type == list:
                    # List content: the query is prepended to every text part (no placeholder
                    # substitution). NOTE(review): differs from the string branch; confirm intended.
                    prompt_content = prompt_content if isinstance(prompt_content, list) else []
                    for content_item in prompt_content:
                        if content_item.type == PromptMessageContentType.TEXT:
                            content_item.data = sys_query + "\n" + content_item.data
                else:
                    raise ValueError("Invalid prompt content type")
        else:
            raise TemplateTypeNotSupportError(type_name=str(type(prompt_template)))
        # The sys_files will be deprecated later
        if vision_enabled and sys_files:
            file_prompts = []
            for file in sys_files:
                file_prompt = file_manager.to_prompt_message_content(file, image_detail_config=vision_detail)
                file_prompts.append(file_prompt)
            # If last prompt is a user prompt, add files into its contents,
            # otherwise append a new user prompt
            if (
                len(prompt_messages) > 0
                and isinstance(prompt_messages[-1], UserPromptMessage)
                and isinstance(prompt_messages[-1].content, list)
            ):
                prompt_messages[-1] = UserPromptMessage(content=file_prompts + prompt_messages[-1].content)
            else:
                prompt_messages.append(UserPromptMessage(content=file_prompts))
        # The context_files
        if vision_enabled and context_files:
            file_prompts = []
            for file in context_files:
                file_prompt = file_manager.to_prompt_message_content(file, image_detail_config=vision_detail)
                file_prompts.append(file_prompt)
            # If last prompt is a user prompt, add files into its contents,
            # otherwise append a new user prompt
            if (
                len(prompt_messages) > 0
                and isinstance(prompt_messages[-1], UserPromptMessage)
                and isinstance(prompt_messages[-1].content, list)
            ):
                prompt_messages[-1] = UserPromptMessage(content=file_prompts + prompt_messages[-1].content)
            else:
                prompt_messages.append(UserPromptMessage(content=file_prompts))
        # Remove empty messages and filter unsupported content
        filtered_prompt_messages = []
        for prompt_message in prompt_messages:
            if isinstance(prompt_message.content, list):
                prompt_message_content: list[PromptMessageContentUnionTypes] = []
                for content_item in prompt_message.content:
                    # Skip content if features are not defined
                    if not model_schema.features:
                        if content_item.type != PromptMessageContentType.TEXT:
                            continue
                        prompt_message_content.append(content_item)
                        continue
                    # Skip content if corresponding feature is not supported
                    if (
                        (
                            content_item.type == PromptMessageContentType.IMAGE
                            and ModelFeature.VISION not in model_schema.features
                        )
                        or (
                            content_item.type == PromptMessageContentType.DOCUMENT
                            and ModelFeature.DOCUMENT not in model_schema.features
                        )
                        or (
                            content_item.type == PromptMessageContentType.VIDEO
                            and ModelFeature.VIDEO not in model_schema.features
                        )
                        or (
                            content_item.type == PromptMessageContentType.AUDIO
                            and ModelFeature.AUDIO not in model_schema.features
                        )
                    ):
                        continue
                    prompt_message_content.append(content_item)
                # Collapse a lone text part into a plain string content.
                if len(prompt_message_content) == 1 and prompt_message_content[0].type == PromptMessageContentType.TEXT:
                    prompt_message.content = prompt_message_content[0].data
                else:
                    prompt_message.content = prompt_message_content
            if prompt_message.is_empty():
                continue
            filtered_prompt_messages.append(prompt_message)
        if len(filtered_prompt_messages) == 0:
            raise NoPromptFoundError(
                "No prompt found in the LLM configuration. "
                "Please ensure a prompt is properly configured before proceeding."
            )
        return filtered_prompt_messages, stop
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls,
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: LLMNodeData,
) -> Mapping[str, Sequence[str]]:
# graph_config is not used in this node type
_ = graph_config # Explicitly mark as unused
prompt_template = node_data.prompt_template
variable_selectors = []
prompt_context_selectors: list[Sequence[str]] = []
if isinstance(prompt_template, list):
for item in prompt_template:
# Check PromptMessageContext first (same order as _parse_prompt_template)
# This extracts value_selector which is used by variable_pool.get(ctx_ref.value_selector)
if isinstance(item, PromptMessageContext):
if len(item.value_selector) >= 2:
prompt_context_selectors.append(item.value_selector)
elif isinstance(item, LLMNodeChatModelMessage):
variable_template_parser = VariableTemplateParser(template=item.text)
variable_selectors.extend(variable_template_parser.extract_variable_selectors())
elif isinstance(prompt_template, LLMNodeCompletionModelPromptTemplate):
if prompt_template.edition_type != "jinja2":
variable_template_parser = VariableTemplateParser(template=prompt_template.text)
variable_selectors = variable_template_parser.extract_variable_selectors()
else:
raise InvalidVariableTypeError(f"Invalid prompt template type: {type(prompt_template)}")
variable_mapping: dict[str, Any] = {}
for variable_selector in variable_selectors:
variable_mapping[variable_selector.variable] = variable_selector.value_selector
for context_selector in prompt_context_selectors:
variable_key = f"#{'.'.join(context_selector)}#"
variable_mapping[variable_key] = list(context_selector)
memory = node_data.memory
if memory and memory.query_prompt_template:
query_variable_selectors = VariableTemplateParser(
template=memory.query_prompt_template
).extract_variable_selectors()
for variable_selector in query_variable_selectors:
variable_mapping[variable_selector.variable] = variable_selector.value_selector
if node_data.context.enabled:
variable_mapping["#context#"] = node_data.context.variable_selector
if node_data.vision.enabled:
variable_mapping["#files#"] = node_data.vision.configs.variable_selector
if node_data.memory:
variable_mapping["#sys.query#"] = ["sys", SystemVariableKey.QUERY]
if node_data.prompt_config:
enable_jinja = False
if isinstance(prompt_template, list):
for item in prompt_template:
if isinstance(item, LLMNodeChatModelMessage) and item.edition_type == "jinja2":
enable_jinja = True
break
else:
enable_jinja = True
if enable_jinja:
for variable_selector in node_data.prompt_config.jinja2_variables or []:
variable_mapping[variable_selector.variable] = variable_selector.value_selector
variable_mapping = {node_id + "." + key: value for key, value in variable_mapping.items()}
return variable_mapping
@classmethod
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
return {
"type": "llm",
"config": {
"prompt_templates": {
"chat_model": {
"prompts": [
{"role": "system", "text": "You are a helpful AI assistant.", "edition_type": "basic"}
]
},
"completion_model": {
"conversation_histories_role": {"user_prefix": "Human", "assistant_prefix": "Assistant"},
"prompt": {
"text": "Here are the chat histories between human and assistant, inside "
"<histories></histories> XML tags.\n\n<histories>\n{{"
"#histories#}}\n</histories>\n\n\nHuman: {{#sys.query#}}\n\nAssistant:",
"edition_type": "basic",
},
"stop": ["Human:"],
},
}
},
}
@staticmethod
def handle_list_messages(
*,
messages: Sequence[LLMNodeChatModelMessage],
context: str | None,
jinja2_variables: Sequence[VariableSelector],
variable_pool: VariablePool,
vision_detail_config: ImagePromptMessageContent.DETAIL,
sandbox: Sandbox | None = None,
) -> Sequence[PromptMessage]:
prompt_messages: list[PromptMessage] = []
bundle: SkillBundle | None = None
if sandbox:
bundle = sandbox.attrs.get(SkillAttrs.BUNDLE)
for message in messages:
if message.edition_type == "jinja2":
result_text = _render_jinja2_message(
template=message.jinja2_text or "",
jinja2_variables=jinja2_variables,
variable_pool=variable_pool,
)
if bundle is not None:
skill_entry = SkillDocumentAssembler(bundle).assemble_document(
document=SkillDocument(
skill_id="anonymous", content=result_text, metadata=message.metadata or {}
),
base_path=AppAssets.PATH,
)
result_text = skill_entry.content
prompt_message = _combine_message_content_with_role(
contents=[TextPromptMessageContent(data=result_text)], role=message.role
)
prompt_messages.append(prompt_message)
else:
if context:
template = message.text.replace("{#context#}", context)
else:
template = message.text
segment_group = variable_pool.convert_template(template)
file_contents = []
for segment in segment_group.value:
if isinstance(segment, ArrayFileSegment):
for file in segment.value:
if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO, FileType.DOCUMENT}:
file_content = file_manager.to_prompt_message_content(
file, image_detail_config=vision_detail_config
)
file_contents.append(file_content)
elif isinstance(segment, FileSegment):
file = segment.value
if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO, FileType.DOCUMENT}:
file_content = file_manager.to_prompt_message_content(
file, image_detail_config=vision_detail_config
)
file_contents.append(file_content)
plain_text = segment_group.text
if plain_text and bundle is not None:
skill_entry = SkillDocumentAssembler(bundle).assemble_document(
document=SkillDocument(
skill_id="anonymous", content=plain_text, metadata=message.metadata or {}
),
base_path=AppAssets.PATH,
)
plain_text = skill_entry.content
if plain_text:
prompt_message = _combine_message_content_with_role(
contents=[TextPromptMessageContent(data=plain_text)], role=message.role
)
prompt_messages.append(prompt_message)
if file_contents:
# Create message with image contents
prompt_message = _combine_message_content_with_role(contents=file_contents, role=message.role)
prompt_messages.append(prompt_message)
return prompt_messages
@staticmethod
def handle_blocking_result(
*,
invoke_result: LLMResult | LLMResultWithStructuredOutput,
saver: LLMFileSaver,
file_outputs: list[File],
reasoning_format: Literal["separated", "tagged"] = "tagged",
request_latency: float | None = None,
) -> ModelInvokeCompletedEvent:
buffer = io.StringIO()
for text_part in LLMNode._save_multimodal_output_and_convert_result_to_markdown(
contents=invoke_result.message.content,
file_saver=saver,
file_outputs=file_outputs,
):
buffer.write(text_part)
# Extract reasoning content from <think> tags in the main text
full_text = buffer.getvalue()
if reasoning_format == "tagged":
# Keep <think> tags in text for backward compatibility
clean_text = full_text
reasoning_content = ""
else:
# Extract clean text and reasoning from <think> tags
clean_text, reasoning_content = LLMNode._split_reasoning(full_text, reasoning_format)
event = ModelInvokeCompletedEvent(
# Use clean_text for separated mode, full_text for tagged mode
text=clean_text if reasoning_format == "separated" else full_text,
usage=invoke_result.usage,
finish_reason=None,
# Reasoning content for workflow variables and downstream nodes
reasoning_content=reasoning_content,
# Pass structured output if enabled
structured_output=getattr(invoke_result, "structured_output", None),
)
if request_latency is not None:
event.usage.latency = round(request_latency, 3)
return event
@staticmethod
def save_multimodal_image_output(
*,
content: ImagePromptMessageContent,
file_saver: LLMFileSaver,
) -> File:
"""_save_multimodal_output saves multi-modal contents generated by LLM plugins.
There are two kinds of multimodal outputs:
- Inlined data encoded in base64, which would be saved to storage directly.
- Remote files referenced by an url, which would be downloaded and then saved to storage.
Currently, only image files are supported.
"""
if content.url != "":
saved_file = file_saver.save_remote_url(content.url, FileType.IMAGE)
else:
saved_file = file_saver.save_binary_string(
data=base64.b64decode(content.base64_data),
mime_type=content.mime_type,
file_type=FileType.IMAGE,
)
return saved_file
@staticmethod
def _normalize_sandbox_file_path(path: str) -> str:
raw = path.strip()
if not raw:
raise LLMNodeError("Sandbox file path must not be empty")
sandbox_path = PurePosixPath(raw)
if any(part == ".." for part in sandbox_path.parts):
raise LLMNodeError("Sandbox file path must not contain '..'")
normalized = str(sandbox_path)
if normalized in {".", ""}:
raise LLMNodeError("Sandbox file path is invalid")
return normalized
def _resolve_sandbox_file_path(self, *, sandbox: Sandbox, path: str) -> File:
normalized_path = self._normalize_sandbox_file_path(path)
filename = os.path.basename(normalized_path)
if not filename:
raise LLMNodeError("Sandbox file path must point to a file")
try:
file_content = sandbox.vm.download_file(normalized_path)
except Exception as exc:
raise LLMNodeError(f"Sandbox file not found: {normalized_path}") from exc
file_binary = file_content.getvalue()
if len(file_binary) > MAX_OUTPUT_FILE_SIZE:
raise LLMNodeError(f"Sandbox file exceeds size limit: {normalized_path}")
mime_type, _ = mimetypes.guess_type(filename)
if not mime_type:
mime_type = "application/octet-stream"
tool_file_manager = ToolFileManager()
tool_file = tool_file_manager.create_file_by_raw(
user_id=self.user_id,
tenant_id=self.tenant_id,
conversation_id=None,
file_binary=file_binary,
mimetype=mime_type,
filename=filename,
)
extension = os.path.splitext(filename)[1] if "." in filename else ".bin"
url = sign_tool_file(tool_file.id, extension)
file_type = self._get_file_type_from_mime(mime_type)
return File(
id=tool_file.id,
tenant_id=self.tenant_id,
type=file_type,
transfer_method=FileTransferMethod.TOOL_FILE,
filename=filename,
extension=extension,
mime_type=mime_type,
size=len(file_binary),
related_id=tool_file.id,
url=url,
storage_key=tool_file.file_key,
)
@staticmethod
def _get_file_type_from_mime(mime_type: str) -> FileType:
if mime_type.startswith("image/"):
return FileType.IMAGE
if mime_type.startswith("video/"):
return FileType.VIDEO
if mime_type.startswith("audio/"):
return FileType.AUDIO
if "text" in mime_type or "pdf" in mime_type:
return FileType.DOCUMENT
return FileType.CUSTOM
@staticmethod
def fetch_structured_output_schema(
*,
structured_output: Mapping[str, Any],
) -> dict[str, Any]:
"""
Fetch the structured output schema from the node data.
Returns:
dict[str, Any]: The structured output schema
"""
if not structured_output:
raise LLMNodeError("Please provide a valid structured output schema")
structured_output_schema = json.dumps(structured_output.get("schema", {}), ensure_ascii=False)
if not structured_output_schema:
raise LLMNodeError("Please provide a valid structured output schema")
try:
schema = json.loads(structured_output_schema)
if not isinstance(schema, dict):
raise LLMNodeError("structured_output_schema must be a JSON object")
return schema
except json.JSONDecodeError:
raise LLMNodeError("structured_output_schema is not valid JSON format")
@staticmethod
def _save_multimodal_output_and_convert_result_to_markdown(
*,
contents: str | list[PromptMessageContentUnionTypes] | None,
file_saver: LLMFileSaver,
file_outputs: list[File],
) -> Generator[str, None, None]:
"""Convert intermediate prompt messages into strings and yield them to the caller.
If the messages contain non-textual content (e.g., multimedia like images or videos),
it will be saved separately, and the corresponding Markdown representation will
be yielded to the caller.
"""
# NOTE(QuantumGhost): This function should yield results to the caller immediately
# whenever new content or partial content is available. Avoid any intermediate buffering
# of results. Additionally, do not yield empty strings; instead, yield from an empty list
# if necessary.
if contents is None:
yield from []
return
if isinstance(contents, str):
yield contents
else:
for item in contents:
if isinstance(item, TextPromptMessageContent):
yield item.data
elif isinstance(item, ImagePromptMessageContent):
file = LLMNode.save_multimodal_image_output(
content=item,
file_saver=file_saver,
)
file_outputs.append(file)
yield LLMNode._image_file_to_markdown(file)
else:
logger.warning("unknown item type encountered, type=%s", type(item))
yield str(item)
@property
def retry(self) -> bool:
return self.node_data.retry_config.retry_enabled
@property
def tool_call_enabled(self) -> bool:
return (
self.node_data.tools is not None
and len(self.node_data.tools) > 0
and all(tool.enabled for tool in self.node_data.tools)
)
    def _stream_llm_events(
        self,
        generator: Generator[NodeEventBase | LLMStructuredOutput, None, LLMGenerationData | None],
        *,
        model_instance: ModelInstance,
    ) -> Generator[
        NodeEventBase,
        None,
        tuple[
            str,  # clean_text: processed text for outputs["text"]
            str,  # reasoning_content: native model reasoning
            str,  # generation_reasoning_content: reasoning for generation field (from <think> tags)
            str,  # generation_clean_content: clean text for generation field (always tag-free)
            LLMUsage,
            str | None,
            LLMStructuredOutput | None,
            LLMGenerationData | None,
        ],
    ]:
        """Stream events and capture generator return value in one place.
        Uses generator delegation so _run stays concise while still emitting events.
        Returns two pairs of text fields because outputs["text"] and generation["content"]
        may differ when reasoning_format is "tagged":
        - clean_text / reasoning_content: for top-level outputs (may keep <think> tags)
        - generation_clean_content / generation_reasoning_content: for the generation field
          (always tag-free, extracted via _split_reasoning with "separated" mode)

        Event handling:
        - StreamChunkEvent / ThoughtChunkEvent are re-yielded as they arrive.
        - The first ModelInvokeCompletedEvent is recorded and LLM quota is deducted for its usage.
        - LLMStructuredOutput items are captured but not re-yielded.
        - Any other event type is dropped.
        After completion, the remaining events (including chunks) are drained without being
        forwarded, so the wrapped generator's return value can still be read from StopIteration.
        That value is kept only if it is an LLMGenerationData.
        """
        clean_text = ""
        reasoning_content = ""
        generation_reasoning_content = ""
        generation_clean_content = ""
        usage = LLMUsage.empty_usage()
        finish_reason: str | None = None
        structured_output: LLMStructuredOutput | None = None
        generation_data: LLMGenerationData | None = None
        completed = False
        while True:
            try:
                event = next(generator)
            except StopIteration as exc:
                if isinstance(exc.value, LLMGenerationData):
                    generation_data = exc.value
                break
            if completed:
                # After completion we still drain to reach StopIteration.value
                continue
            match event:
                case StreamChunkEvent() | ThoughtChunkEvent():
                    yield event
                case ModelInvokeCompletedEvent(
                    text=text,
                    usage=usage_event,
                    finish_reason=finish_reason_event,
                    reasoning_content=reasoning_event,
                    structured_output=structured_raw,
                ):
                    clean_text = text
                    usage = usage_event
                    finish_reason = finish_reason_event
                    reasoning_content = reasoning_event or ""
                    # NOTE(review): both defaults below are overwritten by each branch of the
                    # following if/else, so these two assignments appear to have no effect.
                    generation_reasoning_content = reasoning_content
                    generation_clean_content = clean_text
                    if self.node_data.reasoning_format == "tagged":
                        # Keep tagged text for output; also extract reasoning for generation field
                        generation_clean_content, generation_reasoning_content = LLMNode._split_reasoning(
                            clean_text, reasoning_format="separated"
                        )
                    else:
                        # Non-tagged formats: strip <think> tags from the output text itself. The
                        # generation reasoning comes only from those tags.
                        # NOTE(review): the native reasoning_event is not copied into
                        # generation_reasoning_content here; confirm intended.
                        clean_text, generation_reasoning_content = LLMNode._split_reasoning(
                            clean_text, self.node_data.reasoning_format
                        )
                        generation_clean_content = clean_text
                    # Replaces any LLMStructuredOutput captured earlier in the stream.
                    structured_output = (
                        LLMStructuredOutput(structured_output=structured_raw) if structured_raw else None
                    )
                    # Imported lazily; presumably to avoid an import cycle - TODO confirm.
                    from core.app.llm.quota import deduct_llm_quota
                    deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage)
                    completed = True
                case LLMStructuredOutput():
                    structured_output = event
                case _:
                    continue
        return (
            clean_text,
            reasoning_content,
            generation_reasoning_content,
            generation_clean_content,
            usage,
            finish_reason,
            structured_output,
            generation_data,
        )
def _extract_disabled_tools(self) -> dict[str, ToolDependency]:
tools = [
ToolDependency(type=tool.type, provider=tool.provider, tool_name=tool.tool_name)
for tool in self.node_data.tool_settings
if not tool.enabled
]
return {tool.tool_id(): tool for tool in tools}
def _extract_tool_dependencies(self) -> ToolDependencies | None:
"""Extract tool artifact from prompt template."""
sandbox = self.graph_runtime_state.sandbox
if not sandbox:
raise LLMNodeError("Sandbox not found")
bundle = sandbox.attrs.get(SkillAttrs.BUNDLE)
tool_deps_list: list[ToolDependencies] = []
for prompt in self.node_data.prompt_template:
if isinstance(prompt, LLMNodeChatModelMessage):
skill_entry = SkillDocumentAssembler(bundle).assemble_document(
document=SkillDocument(skill_id="anonymous", content=prompt.text, metadata=prompt.metadata or {}),
base_path=AppAssets.PATH,
)
tool_deps_list.append(skill_entry.tools)
if len(tool_deps_list) == 0:
return None
disabled_tools = self._extract_disabled_tools()
tool_dependencies = reduce(lambda x, y: x.merge(y), tool_deps_list)
for tool in tool_dependencies.dependencies:
if tool.tool_id() in disabled_tools:
tool.enabled = False
return tool_dependencies
def _invoke_llm_with_tools(
    self,
    model_instance: ModelInstance,
    prompt_messages: Sequence[PromptMessage],
    stop: Sequence[str] | None,
    files: Sequence[File],
    variable_pool: VariablePool,
    node_inputs: dict[str, Any],
    process_data: dict[str, Any],
) -> Generator[NodeEventBase, None, LLMGenerationData]:
    """Run the prompt through an agent strategy that may call the node's configured tools.

    ``StrategyFactory`` picks the strategy from the model's declared features.
    The configured tools and any files referenced by prompt variables (not vision
    files) are handed to it. The run is streamed (``stream=True``) with at most
    ``max_iterations`` rounds (default 10). Node events are yielded while the
    strategy runs.

    ``files``, ``node_inputs`` and ``process_data`` are accepted for signature
    compatibility but are not read here.

    Returns:
        The ``LLMGenerationData`` built by ``_process_tool_outputs`` (text,
        reasoning contents, tool calls, usage, finish reason, files).
    """
    node_data = self._node_data

    features = self._get_model_features(model_instance)
    tools = self._prepare_tool_instances(variable_pool)
    referenced_files = self._extract_prompt_files(variable_pool)

    execution_context = ExecutionContext(user_id=self.user_id, app_id=self.app_id, tenant_id=self.tenant_id)
    strategy = StrategyFactory.create_strategy(
        model_features=features,
        model_instance=model_instance,
        tools=tools,
        files=referenced_files,
        max_iterations=node_data.max_iterations or 10,
        context=execution_context,
    )

    strategy_outputs = strategy.run(
        prompt_messages=list(prompt_messages),
        model_parameters=node_data.model.completion_params,
        stop=list(stop or []),
        stream=True,
    )
    return (yield from self._process_tool_outputs(strategy_outputs))
def _invoke_llm_with_sandbox(
    self,
    sandbox: Sandbox,
    model_instance: ModelInstance,
    prompt_messages: Sequence[PromptMessage],
    stop: Sequence[str] | None,
    variable_pool: VariablePool,
    tool_dependencies: ToolDependencies | None,
) -> Generator[NodeEventBase, None, LLMGenerationData]:
    """Run the prompt through a function-calling agent whose only tool is a sandbox bash shell.

    A ``SandboxBashSession`` is opened for this node with the given tool
    dependencies. Its ``bash_tool`` is the sole tool of a FUNCTION_CALLING
    strategy, which runs with at most ``max_iterations`` rounds (default 100).
    After the run, files collected from the session's output are appended to
    ``self._file_outputs``, skipping ids already present there.

    Returns:
        The ``LLMGenerationData`` built by ``_process_tool_outputs``.

    Raises:
        LLMNodeError: if the session block exits without producing a result.
    """
    generation: LLMGenerationData | None = None
    # FIXME(Mairuis): Async processing for bash session.
    with SandboxBashSession(sandbox=sandbox, node_id=self.id, tools=tool_dependencies) as session:
        referenced_files = self._extract_prompt_files(variable_pool)
        features = self._get_model_features(model_instance)
        strategy = StrategyFactory.create_strategy(
            model_features=features,
            model_instance=model_instance,
            tools=[session.bash_tool],
            files=referenced_files,
            max_iterations=self._node_data.max_iterations or 100,
            agent_strategy=AgentEntity.Strategy.FUNCTION_CALLING,
            context=ExecutionContext(user_id=self.user_id, app_id=self.app_id, tenant_id=self.tenant_id),
        )
        strategy_outputs = strategy.run(
            prompt_messages=list(prompt_messages),
            model_parameters=self._node_data.model.completion_params,
            stop=list(stop or []),
            stream=True,
        )
        generation = yield from self._process_tool_outputs(strategy_outputs)

        # Pick up files the sandbox wrote to its output area, without re-adding known ids.
        produced = session.collect_output_files()
        if produced:
            known_ids = {existing.id for existing in self._file_outputs}
            self._file_outputs.extend([item for item in produced if item.id not in known_ids])

    if generation is None:
        raise LLMNodeError("SandboxSession exited unexpectedly")
    return generation
def _get_model_features(self, model_instance: ModelInstance) -> list[ModelFeature]:
"""Get model schema to determine features."""
try:
model_type_instance = model_instance.model_type_instance
model_schema = model_type_instance.get_model_schema(
model_instance.model,
model_instance.credentials,
)
return model_schema.features if model_schema and model_schema.features else []
except Exception:
logger.warning("Failed to get model schema, assuming no special features")
return []
def _prepare_tool_instances(self, variable_pool: VariablePool) -> list[Tool]:
"""Prepare tool instances from configuration."""
tool_instances = []
if self._node_data.tools:
for tool in self._node_data.tools:
try:
# Process settings to extract the correct structure
processed_settings = {}
for key, value in tool.settings.items():
if isinstance(value, dict) and "value" in value and isinstance(value["value"], dict):
# Extract the nested value if it has the ToolInput structure
if "type" in value["value"] and "value" in value["value"]:
processed_settings[key] = value["value"]
else:
processed_settings[key] = value
else:
processed_settings[key] = value
# Merge parameters with processed settings (similar to Agent Node logic)
merged_parameters = {**tool.parameters, **processed_settings}
# Create AgentToolEntity from ToolMetadata
agent_tool = AgentToolEntity(
provider_id=tool.provider_name,
provider_type=tool.type,
tool_name=tool.tool_name,
tool_parameters=merged_parameters,
plugin_unique_identifier=tool.plugin_unique_identifier,
credential_id=tool.credential_id,
)
# Get tool runtime from ToolManager
tool_runtime = ToolManager.get_agent_tool_runtime(
tenant_id=self.tenant_id,
app_id=self.app_id,
agent_tool=agent_tool,
invoke_from=self.invoke_from,
variable_pool=variable_pool,
)
# Apply custom description from extra field if available
if tool.extra.get("description") and tool_runtime.entity.description:
tool_runtime.entity.description.llm = (
tool.extra.get("description") or tool_runtime.entity.description.llm
)
tool_instances.append(tool_runtime)
except Exception as e:
logger.warning("Failed to load tool %s: %s", tool, str(e))
continue
return tool_instances
def _extract_prompt_files(self, variable_pool: VariablePool) -> list[File]:
    """Collect the files referenced by variables in the chat prompt messages.

    Only a list-shaped (chat) prompt template is scanned. For every message with
    text, its variable selectors are resolved in *variable_pool*. A
    ``FileVariable`` with a value contributes one file; an ``ArrayFileVariable``
    with a value contributes all of its files.

    Each distinct selector is resolved only once. Without this, a file variable
    referenced from several messages (e.g. system and user) would be handed to
    the strategy multiple times.

    Returns:
        Files in first-reference order.
    """
    from dify_graph.variables import ArrayFileVariable, FileVariable

    files: list[File] = []
    prompt_template = self._node_data.prompt_template
    if not isinstance(prompt_template, list):
        return files

    seen_selectors: set[tuple[str, ...]] = set()
    for message in prompt_template:
        if not message.text:
            continue
        parser = VariableTemplateParser(message.text)
        for variable_selector in parser.extract_variable_selectors():
            selector_key = tuple(variable_selector.value_selector)
            if selector_key in seen_selectors:
                continue
            seen_selectors.add(selector_key)
            variable = variable_pool.get(variable_selector.value_selector)
            if isinstance(variable, FileVariable) and variable.value:
                files.append(variable.value)
            elif isinstance(variable, ArrayFileVariable) and variable.value:
                files.extend(variable.value)
    return files
@staticmethod
def _serialize_tool_call(tool_call: ToolCallResult) -> dict[str, Any]:
"""Convert ToolCallResult into JSON-friendly dict."""
def _file_to_ref(file: File) -> str | None:
# Align with streamed tool result events which carry file IDs
return file.id or file.related_id
files = []
for file in tool_call.files or []:
ref = _file_to_ref(file)
if ref:
files.append(ref)
return {
"id": tool_call.id,
"name": tool_call.name,
"arguments": tool_call.arguments,
"output": tool_call.output,
"files": files,
"status": tool_call.status.value if hasattr(tool_call.status, "value") else tool_call.status,
"elapsed_time": tool_call.elapsed_time,
}
def _generate_model_provider_icon_url(self, provider: str, dark: bool = False) -> str | None:
    """Build the console API URL of a model provider's small icon.

    Path segments appended to ``CONSOLE_API_URL`` (or ``"/"`` when it is empty):
    ``console/api/workspaces/{tenant_id}/model-providers/{provider}/{variant}/en_US``,
    where *variant* is ``icon_small_dark`` when *dark* is true, else ``icon_small``.

    Returns:
        The URL string, or ``None`` if building it raises.
    """
    from yarl import URL

    from configs import dify_config

    variant = "icon_small_dark" if dark else "icon_small"
    try:
        url = URL(dify_config.CONSOLE_API_URL or "/")
        path_segments = ("console", "api", "workspaces", self.tenant_id, "model-providers", provider, variant, "en_US")
        for segment in path_segments:
            url = url / segment
        return str(url)
    except Exception:
        return None
def _emit_model_start(self, trace_state: TraceState) -> Generator[NodeEventBase, None, None]:
"""Yield a MODEL_START event with model identity info at the beginning of a model turn.
Idempotent: only emits once per turn (guarded by trace_state.model_start_emitted)."""
if trace_state.model_start_emitted:
return
trace_state.model_start_emitted = True
if trace_state.model_segment_start_time is None:
trace_state.model_segment_start_time = time.perf_counter()
provider = self._node_data.model.provider
yield StreamChunkEvent(
selector=[self._node_id, "generation", "model_start"],
chunk="",
chunk_type=ChunkType.MODEL_START,
is_final=False,
model_provider=provider,
model_name=self._node_data.model.name,
model_icon=self._generate_model_provider_icon_url(provider),
model_icon_dark=self._generate_model_provider_icon_url(provider, dark=True),
)
def _flush_model_segment(
self,
buffers: StreamBuffers,
trace_state: TraceState,
error: str | None = None,
) -> Generator[NodeEventBase, None, None]:
"""Flush pending thought/content buffers into a single model trace segment
and yield a MODEL_END chunk event with usage/duration metrics."""
if not buffers.pending_thought and not buffers.pending_content and not buffers.pending_tool_calls:
return
now = time.perf_counter()
duration = now - trace_state.model_segment_start_time if trace_state.model_segment_start_time else 0.0
usage = trace_state.pending_usage
provider = self._node_data.model.provider
model_name = self._node_data.model.name
model_icon = self._generate_model_provider_icon_url(provider)
model_icon_dark = self._generate_model_provider_icon_url(provider, dark=True)
trace_state.trace_segments.append(
LLMTraceSegment(
type="model",
duration=duration,
usage=usage,
output=ModelTraceSegment(
text="".join(buffers.pending_content) if buffers.pending_content else None,
reasoning="".join(buffers.pending_thought) if buffers.pending_thought else None,
tool_calls=list(buffers.pending_tool_calls),
),
provider=provider,
name=model_name,
icon=model_icon,
icon_dark=model_icon_dark,
error=error,
status="error" if error else "success",
)
)
yield StreamChunkEvent(
selector=[self._node_id, "generation", "model_end"],
chunk="",
chunk_type=ChunkType.MODEL_END,
is_final=False,
model_usage=usage,
model_duration=duration,
)
buffers.pending_thought.clear()
buffers.pending_content.clear()
buffers.pending_tool_calls.clear()
trace_state.model_segment_start_time = None
trace_state.model_start_emitted = False
trace_state.pending_usage = None
def _handle_agent_log_output(
    self, output: AgentLog, buffers: StreamBuffers, trace_state: TraceState, agent_context: AgentContext
) -> Generator[NodeEventBase, None, None]:
    """Translate one ``AgentLog`` from the strategy into node events and trace state.

    The log is first recorded in ``agent_context.agent_logs``: an entry with the
    same ``message_id`` is updated in place, otherwise the event is appended.
    Then, depending on the log:

    * a successful THOUGHT log stores its LLM usage (if any) as the pending usage
      of the current model segment;
    * a TOOL_CALL log with status START opens the model turn (MODEL_START),
      assigns the call an ordinal the first time its id is seen, buffers the call
      for the model trace segment and streams a ``ToolCallChunkEvent``;
    * any other TOOL_CALL log is the tool's outcome. The current model segment is
      flushed, a ``"tool"`` trace segment is recorded (and indexed by call id),
      timing for the next model segment restarts, a ``ToolResultChunkEvent`` is
      streamed, and the reasoning buffered for this turn is moved into
      ``buffers.reasoning_per_turn``.
    """
    payload = ToolLogPayload.from_log(output)
    agent_log_event = AgentLogEvent(
        message_id=output.id,
        label=output.label,
        node_execution_id=self.id,
        parent_id=output.parent_id,
        error=output.error,
        status=output.status.value,
        data=output.data,
        metadata={k.value: v for k, v in output.metadata.items()},
        node_id=self._node_id,
    )
    # Upsert by message_id. The for-else appends only when no existing entry matched.
    for log in agent_context.agent_logs:
        if log.message_id == agent_log_event.message_id:
            log.data = agent_log_event.data
            log.status = agent_log_event.status
            log.error = agent_log_event.error
            log.label = agent_log_event.label
            log.metadata = agent_log_event.metadata
            break
    else:
        agent_context.agent_logs.append(agent_log_event)
    # Handle THOUGHT log completion - capture usage for model segment
    if output.log_type == AgentLog.LogType.THOUGHT and output.status == AgentLog.LogStatus.SUCCESS:
        llm_usage = output.metadata.get(AgentLog.LogMetadata.LLM_USAGE) if output.metadata else None
        if llm_usage:
            trace_state.pending_usage = llm_usage
    # Tool call requested by the model: make sure the model turn is announced,
    # then buffer and stream the call itself.
    if output.log_type == AgentLog.LogType.TOOL_CALL and output.status == AgentLog.LogStatus.START:
        yield from self._emit_model_start(trace_state)
        tool_name = payload.tool_name
        tool_call_id = payload.tool_call_id
        tool_arguments = json.dumps(payload.tool_args or {})
        tool_icon = output.metadata.get(AgentLog.LogMetadata.ICON) if output.metadata else None
        tool_icon_dark = output.metadata.get(AgentLog.LogMetadata.ICON_DARK) if output.metadata else None
        # First sighting of this call id fixes its ordinal (used later to sort tool calls).
        if tool_call_id and tool_call_id not in trace_state.tool_call_index_map:
            trace_state.tool_call_index_map[tool_call_id] = len(trace_state.tool_call_index_map)
        buffers.pending_tool_calls.append(ToolCall(id=tool_call_id, name=tool_name, arguments=tool_arguments))
        yield ToolCallChunkEvent(
            selector=[self._node_id, "generation", "tool_calls"],
            chunk=tool_arguments,
            tool_call=ToolCall(
                id=tool_call_id,
                name=tool_name,
                arguments=tool_arguments,
                icon=tool_icon,
                icon_dark=tool_icon_dark,
            ),
            is_final=False,
        )
    # Any non-START TOOL_CALL log carries the tool's outcome (success or error).
    if output.log_type == AgentLog.LogType.TOOL_CALL and output.status != AgentLog.LogStatus.START:
        tool_name = payload.tool_name
        tool_output = payload.tool_output
        tool_call_id = payload.tool_call_id
        tool_files = payload.files if isinstance(payload.files, list) else []
        tool_error = payload.tool_error
        tool_arguments = json.dumps(payload.tool_args or {})
        if tool_call_id and tool_call_id not in trace_state.tool_call_index_map:
            trace_state.tool_call_index_map[tool_call_id] = len(trace_state.tool_call_index_map)
        # Flush model segment before tool result processing
        yield from self._flush_model_segment(buffers, trace_state)
        # Resolve the error text. On ERROR status: the log's own error, else the
        # payload's tool_error, else meta["error"]. On other statuses a non-empty
        # meta["error"] overrides the payload's tool_error.
        if output.status == AgentLog.LogStatus.ERROR:
            tool_error = output.error or payload.tool_error
            if not tool_error and payload.meta:
                tool_error = payload.meta.get("error")
        else:
            if payload.meta:
                meta_error = payload.meta.get("error")
                if meta_error:
                    tool_error = meta_error
        elapsed_time = output.metadata.get(AgentLog.LogMetadata.ELAPSED_TIME) if output.metadata else None
        tool_provider = output.metadata.get(AgentLog.LogMetadata.PROVIDER) if output.metadata else None
        tool_icon = output.metadata.get(AgentLog.LogMetadata.ICON) if output.metadata else None
        tool_icon_dark = output.metadata.get(AgentLog.LogMetadata.ICON_DARK) if output.metadata else None
        result_str = str(tool_output) if tool_output is not None else None
        tool_status: Literal["success", "error"] = "error" if tool_error else "success"
        tool_call_segment = LLMTraceSegment(
            type="tool",
            duration=elapsed_time or 0.0,
            usage=None,
            output=ToolTraceSegment(
                id=tool_call_id,
                name=tool_name,
                arguments=tool_arguments,
                output=result_str,
            ),
            provider=tool_provider,
            name=tool_name,
            icon=tool_icon,
            icon_dark=tool_icon_dark,
            error=str(tool_error) if tool_error else None,
            status=tool_status,
        )
        trace_state.trace_segments.append(tool_call_segment)
        if tool_call_id:
            trace_state.tool_trace_map[tool_call_id] = tool_call_segment
        # Start new model segment tracking
        trace_state.model_segment_start_time = time.perf_counter()
        yield ToolResultChunkEvent(
            selector=[self._node_id, "generation", "tool_results"],
            chunk=result_str or "",
            tool_result=ToolResult(
                id=tool_call_id,
                name=tool_name,
                output=result_str,
                files=tool_files,
                status=ToolResultStatus.ERROR if tool_error else ToolResultStatus.SUCCESS,
                elapsed_time=elapsed_time,
                icon=tool_icon,
                icon_dark=tool_icon_dark,
                provider=tool_provider,
            ),
            is_final=False,
        )
        # The tool result closes this turn: archive its reasoning as one entry.
        if buffers.current_turn_reasoning:
            buffers.reasoning_per_turn.append("".join(buffers.current_turn_reasoning))
            buffers.current_turn_reasoning.clear()
def _handle_llm_chunk_output(
self, output: LLMResultChunk, buffers: StreamBuffers, trace_state: TraceState, aggregate: AggregatedResult
) -> Generator[NodeEventBase, None, None]:
message = output.delta.message
if message and message.content:
chunk_text = message.content
if isinstance(chunk_text, list):
chunk_text = "".join(getattr(content, "data", str(content)) for content in chunk_text)
else:
chunk_text = str(chunk_text)
for kind, segment in buffers.think_parser.process(chunk_text):
if not segment and kind not in {"thought_start", "thought_end"}:
continue
yield from self._emit_model_start(trace_state)
if kind == "thought_start":
yield ThoughtStartChunkEvent(
selector=[self._node_id, "generation", "thought"],
chunk="",
is_final=False,
)
elif kind == "thought":
buffers.current_turn_reasoning.append(segment)
buffers.pending_thought.append(segment)
yield ThoughtChunkEvent(
selector=[self._node_id, "generation", "thought"],
chunk=segment,
is_final=False,
)
elif kind == "thought_end":
yield ThoughtEndChunkEvent(
selector=[self._node_id, "generation", "thought"],
chunk="",
is_final=False,
)
else:
aggregate.text += segment
buffers.pending_content.append(segment)
yield StreamChunkEvent(
selector=[self._node_id, "text"],
chunk=segment,
is_final=False,
)
yield StreamChunkEvent(
selector=[self._node_id, "generation", "content"],
chunk=segment,
is_final=False,
)
if output.delta.usage:
self._accumulate_usage(aggregate.usage, output.delta.usage)
if output.delta.finish_reason:
aggregate.finish_reason = output.delta.finish_reason
def _flush_remaining_stream(
self, buffers: StreamBuffers, trace_state: TraceState, aggregate: AggregatedResult
) -> Generator[NodeEventBase, None, None]:
for kind, segment in buffers.think_parser.flush():
if not segment and kind not in {"thought_start", "thought_end"}:
continue
yield from self._emit_model_start(trace_state)
if kind == "thought_start":
yield ThoughtStartChunkEvent(
selector=[self._node_id, "generation", "thought"],
chunk="",
is_final=False,
)
elif kind == "thought":
buffers.current_turn_reasoning.append(segment)
buffers.pending_thought.append(segment)
yield ThoughtChunkEvent(
selector=[self._node_id, "generation", "thought"],
chunk=segment,
is_final=False,
)
elif kind == "thought_end":
yield ThoughtEndChunkEvent(
selector=[self._node_id, "generation", "thought"],
chunk="",
is_final=False,
)
else:
aggregate.text += segment
buffers.pending_content.append(segment)
yield StreamChunkEvent(
selector=[self._node_id, "text"],
chunk=segment,
is_final=False,
)
yield StreamChunkEvent(
selector=[self._node_id, "generation", "content"],
chunk=segment,
is_final=False,
)
if buffers.current_turn_reasoning:
buffers.reasoning_per_turn.append("".join(buffers.current_turn_reasoning))
# For final flush, use aggregate.usage if pending_usage is not set
# (e.g., for simple LLM calls without tool invocations)
if trace_state.pending_usage is None:
trace_state.pending_usage = aggregate.usage
# Flush final model segment
yield from self._flush_model_segment(buffers, trace_state)
def _close_streams(self) -> Generator[NodeEventBase, None, None]:
    """Send the terminating ``is_final=True`` chunk on every stream this node writes.

    Streams are closed in this order: ``text``, ``generation/content``,
    ``generation/thought``, ``generation/tool_calls``, ``generation/tool_results``,
    ``generation/model_start``, ``generation/model_end``. The tool call/result
    terminators carry empty placeholder payloads.
    """
    node_id = self._node_id

    def selector(*path: str) -> list[str]:
        return [node_id, *path]

    yield StreamChunkEvent(selector=selector("text"), chunk="", is_final=True)
    yield StreamChunkEvent(selector=selector("generation", "content"), chunk="", is_final=True)
    yield ThoughtChunkEvent(selector=selector("generation", "thought"), chunk="", is_final=True)
    yield ToolCallChunkEvent(
        selector=selector("generation", "tool_calls"),
        chunk="",
        tool_call=ToolCall(id="", name="", arguments=""),
        is_final=True,
    )
    yield ToolResultChunkEvent(
        selector=selector("generation", "tool_results"),
        chunk="",
        tool_result=ToolResult(id="", name="", output="", files=[], status=ToolResultStatus.SUCCESS),
        is_final=True,
    )
    yield StreamChunkEvent(selector=selector("generation", "model_start"), chunk="", is_final=True)
    yield StreamChunkEvent(selector=selector("generation", "model_end"), chunk="", is_final=True)
def _build_generation_data(
    self,
    trace_state: TraceState,
    agent_context: AgentContext,
    aggregate: AggregatedResult,
    buffers: StreamBuffers,
) -> LLMGenerationData:
    """Assemble the final ``LLMGenerationData`` from everything collected while streaming.

    * ``sequence`` holds ordering hints derived from the trace segments:
      reasoning entries index into the per-turn reasoning list, content entries
      are ``[start, end)`` character offsets into the streamed text, and
      tool_call entries number calls by first appearance.
    * ``tool_calls`` holds one ``ToolCallResult`` per recorded agent log that has a
      tool_call_id and is not a START log. Entries are sorted by the ordinal
      assigned when the call was first seen; unknown ids sort last.
    """
    # NOTE(review): the trace segments appended by _flush_model_segment and
    # _handle_agent_log_output use type "model" or "tool", so none of the
    # branches below appears to match them and ``sequence`` would stay empty.
    # Confirm whether any other code still produces "thought"/"content"/"tool_call" segments.
    sequence: list[dict[str, Any]] = []
    reasoning_count = 0
    text_cursor = 0
    tool_positions: dict[str, int] = {}
    for segment in trace_state.trace_segments:
        kind = segment.type
        if kind == "thought":
            sequence.append({"type": "reasoning", "index": reasoning_count})
            reasoning_count += 1
        elif kind == "content":
            piece = segment.text or ""
            piece_end = text_cursor + len(piece)
            sequence.append({"type": "content", "start": text_cursor, "end": piece_end})
            text_cursor = piece_end
        elif kind == "tool_call":
            call_id = segment.tool_call.id if segment.tool_call and segment.tool_call.id else ""
            position = tool_positions.setdefault(call_id, len(tool_positions))
            sequence.append({"type": "tool_call", "index": position})

    finished_calls: list[ToolCallResult] = []
    for log in agent_context.agent_logs:
        payload = ToolLogPayload.from_mapping(log.data or {})
        call_id = payload.tool_call_id
        # Only completed calls are reported: skip logs without a call id and START logs.
        if not call_id or log.status == AgentLog.LogStatus.START.value:
            continue
        call_args = payload.tool_args
        call_error = payload.tool_error
        call_output = payload.tool_output
        finished_calls.append(
            ToolCallResult(
                id=call_id,
                name=payload.tool_name,
                arguments=json.dumps(call_args) if call_args else "",
                output=call_output or call_error or "",
                status=ToolResultStatus.ERROR if call_error else ToolResultStatus.SUCCESS,
                elapsed_time=log.metadata.get(AgentLog.LogMetadata.ELAPSED_TIME) if log.metadata else None,
            )
        )

    ordinals = trace_state.tool_call_index_map
    unknown_ordinal = len(ordinals)
    finished_calls.sort(key=lambda call: ordinals.get(call.id or "", unknown_ordinal))

    return LLMGenerationData(
        text=aggregate.text,
        reasoning_contents=buffers.reasoning_per_turn,
        tool_calls=finished_calls,
        sequence=sequence,
        usage=aggregate.usage,
        finish_reason=aggregate.finish_reason,
        files=aggregate.files,
        trace=trace_state.trace_segments,
    )
def _process_tool_outputs(
    self,
    outputs: Generator[LLMResultChunk | AgentLog, None, AgentResult],
) -> Generator[NodeEventBase, None, LLMGenerationData]:
    """Translate an agent strategy's output stream into node events.

    Each ``AgentLog`` goes to ``_handle_agent_log_output``; every other item
    (LLM result chunks) goes to ``_handle_llm_chunk_output``. When the strategy
    generator finishes, its return value, if it is an ``AgentResult``, overrides
    the streamed aggregate:

    * its text, when non-empty;
    * its files;
    * its usage, when set;
    * its finish reason, when set.

    Afterwards the pending stream buffers are flushed, every stream receives its
    final chunk, and the assembled generation data is returned.

    Args:
        outputs: Generator returned by the strategy's ``run``.

    Returns:
        The ``LLMGenerationData`` (as this generator's return value).
    """
    state = ToolOutputState()

    # A ``for`` loop swallows the StopIteration that carries a generator's
    # return value, so the AgentResult would never be seen. Drive the iterator
    # by hand to capture it.
    iterator = iter(outputs)
    while True:
        try:
            output = next(iterator)
        except StopIteration as stop:
            if isinstance(stop.value, AgentResult):
                state.agent.agent_result = stop.value
            break
        if isinstance(output, AgentLog):
            yield from self._handle_agent_log_output(output, state.stream, state.trace, state.agent)
        else:
            yield from self._handle_llm_chunk_output(output, state.stream, state.trace, state.aggregate)

    agent_result = state.agent.agent_result
    if agent_result:
        state.aggregate.text = agent_result.text or state.aggregate.text
        state.aggregate.files = agent_result.files
        if agent_result.usage:
            state.aggregate.usage = agent_result.usage
        if agent_result.finish_reason:
            state.aggregate.finish_reason = agent_result.finish_reason

    yield from self._flush_remaining_stream(state.stream, state.trace, state.aggregate)
    yield from self._close_streams()
    return self._build_generation_data(state.trace, state.agent, state.aggregate, state.stream)
def _accumulate_usage(self, total_usage: LLMUsage, delta_usage: LLMUsage) -> None:
"""Accumulate LLM usage statistics."""
total_usage.prompt_tokens += delta_usage.prompt_tokens
total_usage.completion_tokens += delta_usage.completion_tokens
total_usage.total_tokens += delta_usage.total_tokens
total_usage.prompt_price += delta_usage.prompt_price
total_usage.completion_price += delta_usage.completion_price
total_usage.total_price += delta_usage.total_price
@property
def model_instance(self) -> ModelInstance:
    """Read-only accessor for the ``ModelInstance`` held in ``self._model_instance``."""
    instance = self._model_instance
    return instance
def _combine_message_content_with_role(
    *, contents: str | list[PromptMessageContentUnionTypes] | None = None, role: PromptMessageRole
):
    """Wrap *contents* in the prompt-message class matching *role*.

    USER -> ``UserPromptMessage``, ASSISTANT -> ``AssistantPromptMessage``,
    SYSTEM -> ``SystemPromptMessage``.

    Raises:
        NotImplementedError: for any other role.
    """
    if role == PromptMessageRole.USER:
        return UserPromptMessage(content=contents)
    if role == PromptMessageRole.ASSISTANT:
        return AssistantPromptMessage(content=contents)
    if role == PromptMessageRole.SYSTEM:
        return SystemPromptMessage(content=contents)
    raise NotImplementedError(f"Role {role} is not supported")
def _render_jinja2_message(
*,
template: str,
jinja2_variables: Sequence[VariableSelector],
variable_pool: VariablePool,
):
if not template:
return ""
jinja2_inputs = {}
for jinja2_variable in jinja2_variables:
variable = variable_pool.get(jinja2_variable.value_selector)
jinja2_inputs[jinja2_variable.variable] = variable.to_object() if variable else ""
code_execute_resp = CodeExecutor.execute_workflow_code_template(
language=CodeLanguage.JINJA2,
code=template,
inputs=jinja2_inputs,
)
result_text = code_execute_resp["result"]
return result_text
def _calculate_rest_token(
    *,
    prompt_messages: list[PromptMessage],
    model_instance: ModelInstance,
) -> int:
    """Estimate how many tokens of the context window remain for extra content.

    When the model schema declares no context size, a fixed budget of 2000 is
    returned. Otherwise the result is::

        context_size - reserved_output - tokens(prompt_messages)

    floored at 0. ``reserved_output`` is the configured value of the parameter
    rule named (or templated as) ``"max_tokens"``, taking the last such rule
    when several match, and 0 when unset.
    """
    schema = llm_utils.fetch_model_schema(model_instance=model_instance)
    parameters = model_instance.parameters
    context_size = schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
    if not context_size:
        return 2000

    used_tokens = model_instance.get_llm_num_tokens(prompt_messages)

    max_token_rules = [
        rule
        for rule in schema.parameter_rules
        if rule.name == "max_tokens" or (rule.use_template and rule.use_template == "max_tokens")
    ]
    reserved_output = 0
    if max_token_rules:
        last_rule = max_token_rules[-1]
        reserved_output = parameters.get(last_rule.name) or parameters.get(str(last_rule.use_template)) or 0

    return max(context_size - reserved_output - used_tokens, 0)
def _handle_memory_chat_mode(
*,
memory: BaseMemory | None,
memory_config: MemoryConfig | None,
model_instance: ModelInstance,
) -> Sequence[PromptMessage]:
memory_messages: Sequence[PromptMessage] = []
# Get messages from memory for chat model
if memory and memory_config:
rest_tokens = _calculate_rest_token(
prompt_messages=[],
model_instance=model_instance,
)
memory_messages = memory.get_history_prompt_messages(
max_token_limit=rest_tokens,
message_limit=memory_config.window.size if memory_config.window.enabled else None,
)
return memory_messages
def _handle_memory_completion_mode(
*,
memory: BaseMemory | None,
memory_config: MemoryConfig | None,
model_instance: ModelInstance,
) -> str:
memory_text = ""
# Get history text from memory for completion model
if memory and memory_config:
rest_tokens = _calculate_rest_token(
prompt_messages=[],
model_instance=model_instance,
)
if not memory_config.role_prefix:
raise MemoryRolePrefixRequiredError("Memory role prefix is required for completion model.")
memory_text = llm_utils.fetch_memory_text(
memory=memory,
max_token_limit=rest_tokens,
message_limit=memory_config.window.size if memory_config.window.enabled else None,
human_prefix=memory_config.role_prefix.user,
ai_prefix=memory_config.role_prefix.assistant,
)
return memory_text
def _handle_completion_template(
    *,
    template: LLMNodeCompletionModelPromptTemplate,
    context: str | None,
    jinja2_variables: Sequence[VariableSelector],
    variable_pool: VariablePool,
) -> Sequence[PromptMessage]:
    """Turn a completion-model prompt template into a single user prompt message.

    A ``jinja2`` template renders ``jinja2_text`` (or ``""``) via
    ``_render_jinja2_message``. Any other template has ``{#context#}``
    replaced by *context*, only when *context* is non-empty, and is then
    passed through ``variable_pool.convert_template``.

    Args:
        template: The completion model prompt template.
        context: Optional context string substituted for ``{#context#}``.
        jinja2_variables: Variables exposed to a jinja2 template.
        variable_pool: Pool used to resolve template variables.

    Returns:
        A one-element list holding a USER message with the rendered text.
    """
    if template.edition_type == "jinja2":
        rendered_text = _render_jinja2_message(
            template=template.jinja2_text or "",
            jinja2_variables=jinja2_variables,
            variable_pool=variable_pool,
        )
    else:
        source_text = template.text.replace("{#context#}", context) if context else template.text
        rendered_text = variable_pool.convert_template(source_text).text

    user_message = _combine_message_content_with_role(
        contents=[TextPromptMessageContent(data=rendered_text)], role=PromptMessageRole.USER
    )
    return [user_message]