revert: add tools for output in agent mode

feat: hide output tools and improve JSON formatting for structured output
feat: hide output tools and improve JSON formatting for structured output
fix: handle prompt template correctly to extract selectors for step run
fix: emit StreamChunkEvent correctly for sandbox agent
chore: better debug message
fix: incorrect output tool runtime selection
fix: type issues
fix: align parameter list
fix: align parameter list
fix: hide internal builtin providers from tool list
vibe: implement file structured output
vibe: implement file structured output
fix: refix parameter for tool
fix: crash
fix: crash
refactor: remove union types
fix: type check
Merge branch 'feat/structured-output-with-sandbox' into feat/support-agent-sandbox
fix: provide json as text
fix: provide json as text
fix: get AgentResult correctly
fix: provides correct prompts, tools and terminal predicates
fix: provides correct prompts, tools and terminal predicates
fix: circular import
feat: support structured output in sandbox and tool mode
This commit is contained in:
Stream
2026-02-04 21:13:07 +08:00
parent 25065a4f2f
commit e0082dbf18
41 changed files with 1014 additions and 1358 deletions

View File

@@ -1,4 +1,3 @@
import json
import logging
from collections.abc import Generator
from copy import deepcopy
@@ -24,7 +23,7 @@ from core.model_runtime.entities import (
from core.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes
from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
from core.tools.__base.tool import Tool
from core.tools.entities.tool_entities import ToolInvokeFrom, ToolInvokeMeta
from core.tools.entities.tool_entities import ToolInvokeMeta
from core.tools.tool_engine import ToolEngine
from models.model import Message
@@ -103,13 +102,11 @@ class AgentAppRunner(BaseAgentRunner):
agent_strategy=self.config.strategy,
tool_invoke_hook=tool_invoke_hook,
instruction=instruction,
tenant_id=self.tenant_id,
invoke_from=self.application_generate_entity.invoke_from,
tool_invoke_from=ToolInvokeFrom.WORKFLOW,
)
# Initialize state variables
current_agent_thought_id = None
has_published_thought = False
current_tool_name: str | None = None
self._current_message_file_ids: list[str] = []
@@ -121,6 +118,7 @@ class AgentAppRunner(BaseAgentRunner):
prompt_messages=prompt_messages,
model_parameters=app_generate_entity.model_conf.parameters,
stop=app_generate_entity.model_conf.stop,
stream=True,
)
# Consume generator and collect result
@@ -135,10 +133,17 @@ class AgentAppRunner(BaseAgentRunner):
break
if isinstance(output, LLMResultChunk):
# Streaming data is no longer expected here
continue
# Handle LLM chunk
if current_agent_thought_id and not has_published_thought:
self.queue_manager.publish(
QueueAgentThoughtEvent(agent_thought_id=current_agent_thought_id),
PublishFrom.APPLICATION_MANAGER,
)
has_published_thought = True
else:
yield output
elif isinstance(output, AgentLog):
# Handle Agent Log using log_type for type-safe dispatch
if output.status == AgentLog.LogStatus.START:
if output.log_type == AgentLog.LogType.ROUND:
@@ -151,6 +156,7 @@ class AgentAppRunner(BaseAgentRunner):
tool_input="",
messages_ids=message_file_ids,
)
has_published_thought = False
elif output.log_type == AgentLog.LogType.TOOL_CALL:
if current_agent_thought_id is None:
@@ -258,30 +264,23 @@ class AgentAppRunner(BaseAgentRunner):
raise
# Process final result
if not isinstance(result, AgentResult):
raise ValueError("Agent did not return AgentResult")
output_payload = result.output
if isinstance(output_payload, dict):
final_answer = json.dumps(output_payload, ensure_ascii=False)
elif isinstance(output_payload, str):
final_answer = output_payload
else:
raise ValueError("Final output is not a string or structured data.")
usage = result.usage or LLMUsage.empty_usage()
if isinstance(result, AgentResult):
final_answer = result.text
usage = result.usage or LLMUsage.empty_usage()
# Publish end event
self.queue_manager.publish(
QueueMessageEndEvent(
llm_result=LLMResult(
model=self.model_instance.model,
prompt_messages=prompt_messages,
message=AssistantPromptMessage(content=final_answer),
usage=usage,
system_fingerprint="",
)
),
PublishFrom.APPLICATION_MANAGER,
)
# Publish end event
self.queue_manager.publish(
QueueMessageEndEvent(
llm_result=LLMResult(
model=self.model_instance.model,
prompt_messages=prompt_messages,
message=AssistantPromptMessage(content=final_answer),
usage=usage,
system_fingerprint="",
)
),
PublishFrom.APPLICATION_MANAGER,
)
def _init_system_message(self, prompt_template: str, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
"""

View File

@@ -7,7 +7,6 @@ from typing import Union, cast
from sqlalchemy import select
from core.agent.entities import AgentEntity, AgentToolEntity, ExecutionContext
from core.agent.output_tools import build_agent_output_tools
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfig
from core.app.apps.base_app_queue_manager import AppQueueManager
@@ -37,7 +36,6 @@ from core.model_runtime.model_providers.__base.large_language_model import Large
from core.prompt.utils.extract_thread_messages import extract_thread_messages
from core.tools.__base.tool import Tool
from core.tools.entities.tool_entities import (
ToolInvokeFrom,
ToolParameter,
)
from core.tools.tool_manager import ToolManager
@@ -253,14 +251,6 @@ class BaseAgentRunner(AppRunner):
# save tool entity
tool_instances[dataset_tool.entity.identity.name] = dataset_tool
output_tools = build_agent_output_tools(
tenant_id=self.tenant_id,
invoke_from=self.application_generate_entity.invoke_from,
tool_invoke_from=ToolInvokeFrom.AGENT,
)
for tool in output_tools:
tool_instances[tool.entity.identity.name] = tool
return tool_instances, prompt_messages_tools
def update_prompt_message_tool(self, tool: Tool, prompt_tool: PromptMessageTool) -> PromptMessageTool:

View File

@@ -1,11 +1,10 @@
import uuid
from collections.abc import Mapping
from enum import StrEnum
from typing import Any
from typing import Any, Union
from pydantic import BaseModel, Field
from core.agent.output_tools import FINAL_OUTPUT_TOOL
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolProviderType
@@ -42,9 +41,9 @@ class AgentScratchpadUnit(BaseModel):
"""
action_name: str
action_input: dict[str, Any] | str
action_input: Union[dict, str]
def to_dict(self) -> dict[str, Any]:
def to_dict(self):
"""
Convert to dictionary.
"""
@@ -63,9 +62,9 @@ class AgentScratchpadUnit(BaseModel):
"""
Check if the scratchpad unit is final.
"""
if self.action is None:
return False
return self.action.action_name.lower() == FINAL_OUTPUT_TOOL
return self.action is None or (
"final" in self.action.action_name.lower() and "answer" in self.action.action_name.lower()
)
class AgentEntity(BaseModel):
@@ -126,7 +125,7 @@ class ExecutionContext(BaseModel):
"tenant_id": self.tenant_id,
}
def with_updates(self, **kwargs: Any) -> "ExecutionContext":
def with_updates(self, **kwargs) -> "ExecutionContext":
"""Create a new context with updated fields."""
data = self.to_dict()
data.update(kwargs)
@@ -179,23 +178,12 @@ class AgentLog(BaseModel):
metadata: Mapping[LogMetadata, Any] = Field(default={}, description="The metadata of the log")
class AgentOutputKind(StrEnum):
"""
Agent output kind.
"""
OUTPUT_TEXT = "output_text"
FINAL_OUTPUT_ANSWER = "final_output_answer"
FINAL_STRUCTURED_OUTPUT = "final_structured_output"
ILLEGAL_OUTPUT = "illegal_output"
class AgentResult(BaseModel):
"""
Agent execution result.
"""
output: str | dict = Field(default="", description="The generated output")
text: str = Field(default="", description="The generated text")
files: list[Any] = Field(default_factory=list, description="Files produced during execution")
usage: Any | None = Field(default=None, description="LLM usage statistics")
finish_reason: str | None = Field(default=None, description="Reason for completion")

View File

@@ -1,7 +1,7 @@
import json
import re
from collections.abc import Generator
from typing import Any, Union, cast
from typing import Union
from core.agent.entities import AgentScratchpadUnit
from core.model_runtime.entities.llm_entities import LLMResultChunk
@@ -10,52 +10,46 @@ from core.model_runtime.entities.llm_entities import LLMResultChunk
class CotAgentOutputParser:
@classmethod
def handle_react_stream_output(
cls, llm_response: Generator[LLMResultChunk, None, None], usage_dict: dict[str, Any]
cls, llm_response: Generator[LLMResultChunk, None, None], usage_dict: dict
) -> Generator[Union[str, AgentScratchpadUnit.Action], None, None]:
def parse_action(action: Any) -> Union[str, AgentScratchpadUnit.Action]:
action_name: str | None = None
action_input: Any | None = None
parsed_action: Any = action
if isinstance(parsed_action, str):
def parse_action(action) -> Union[str, AgentScratchpadUnit.Action]:
action_name = None
action_input = None
if isinstance(action, str):
try:
parsed_action = json.loads(parsed_action, strict=False)
action = json.loads(action, strict=False)
except json.JSONDecodeError:
return parsed_action or ""
return action or ""
# Cohere always returns a list
if isinstance(parsed_action, list):
action_list: list[Any] = cast(list[Any], parsed_action)
if len(action_list) == 1:
parsed_action = action_list[0]
if isinstance(action, list) and len(action) == 1:
action = action[0]
if isinstance(parsed_action, dict):
action_dict: dict[str, Any] = cast(dict[str, Any], parsed_action)
for key, value in action_dict.items():
if "input" in key.lower():
action_input = value
elif isinstance(value, str):
action_name = value
else:
return json.dumps(parsed_action)
for key, value in action.items():
if "input" in key.lower():
action_input = value
else:
action_name = value
if action_name is not None and action_input is not None:
return AgentScratchpadUnit.Action(
action_name=action_name,
action_input=action_input,
)
return json.dumps(parsed_action)
else:
return json.dumps(action)
def extra_json_from_code_block(code_block: str) -> list[dict[str, Any] | list[Any]]:
def extra_json_from_code_block(code_block) -> list[Union[list, dict]]:
blocks = re.findall(r"```[json]*\s*([\[{].*[]}])\s*```", code_block, re.DOTALL | re.IGNORECASE)
if not blocks:
return []
try:
json_blocks: list[dict[str, Any] | list[Any]] = []
json_blocks = []
for block in blocks:
json_text = re.sub(r"^[a-zA-Z]+\n", "", block.strip(), flags=re.MULTILINE)
json_blocks.append(json.loads(json_text, strict=False))
return json_blocks
except Exception:
except:
return []
code_block_cache = ""
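Not part of the diff above — a standalone sketch of the fenced-JSON extraction that extra_json_from_code_block performs, reusing the same regular expression shown in the hunk; the sample model output is invented.

import json
import re

fence = "`" * 3  # triple backtick, built here so the sketch stays readable as plain text
sample = (
    "Thought: I know what to respond\n"
    "Action:\n"
    f"{fence}json\n"
    '{"action": "search", "action_input": {"query": "weather"}}\n'
    f"{fence}\n"
)

# Same pattern as in extra_json_from_code_block: capture a JSON object or array
# wrapped in triple-backtick fences, with an optional "json" language tag.
pattern = fence + r"[json]*\s*([\[{].*[]}])\s*" + fence
for block in re.findall(pattern, sample, re.DOTALL | re.IGNORECASE):
    # Drop a stray leading language-tag line if it ended up inside the capture.
    json_text = re.sub(r"^[a-zA-Z]+\n", "", block.strip(), flags=re.MULTILINE)
    print(json.loads(json_text, strict=False))  # {'action': 'search', 'action_input': {...}}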

View File

@@ -1,108 +0,0 @@
from __future__ import annotations
from collections.abc import Mapping, Sequence
from typing import Any
from core.app.entities.app_invoke_entities import InvokeFrom
from core.tools.__base.tool import Tool
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolInvokeFrom, ToolInvokeMessage, ToolParameter, ToolProviderType
from core.tools.tool_manager import ToolManager
OUTPUT_TOOL_PROVIDER = "agent_output"
OUTPUT_TEXT_TOOL = "output_text"
FINAL_OUTPUT_TOOL = "final_output_answer"
FINAL_STRUCTURED_OUTPUT_TOOL = "final_structured_output"
ILLEGAL_OUTPUT_TOOL = "illegal_output"
OUTPUT_TOOL_NAMES: Sequence[str] = (
OUTPUT_TEXT_TOOL,
FINAL_OUTPUT_TOOL,
FINAL_STRUCTURED_OUTPUT_TOOL,
ILLEGAL_OUTPUT_TOOL,
)
TERMINAL_OUTPUT_TOOL_NAMES: Sequence[str] = (FINAL_OUTPUT_TOOL, FINAL_STRUCTURED_OUTPUT_TOOL)
TERMINAL_OUTPUT_MESSAGE = "Final output received. This ends the current session."
def is_terminal_output_tool(tool_name: str) -> bool:
return tool_name in TERMINAL_OUTPUT_TOOL_NAMES
def get_terminal_tool_name(structured_output_enabled: bool) -> str:
return FINAL_STRUCTURED_OUTPUT_TOOL if structured_output_enabled else FINAL_OUTPUT_TOOL
OUTPUT_TOOL_NAME_SET = set(OUTPUT_TOOL_NAMES)
def build_agent_output_tools(
tenant_id: str,
invoke_from: InvokeFrom,
tool_invoke_from: ToolInvokeFrom,
structured_output_schema: Mapping[str, Any] | None = None,
) -> list[Tool]:
def get_tool_runtime(_tool_name: str) -> Tool:
return ToolManager.get_tool_runtime(
provider_type=ToolProviderType.BUILT_IN,
provider_id=OUTPUT_TOOL_PROVIDER,
tool_name=_tool_name,
tenant_id=tenant_id,
invoke_from=invoke_from,
tool_invoke_from=tool_invoke_from,
)
tools: list[Tool] = [
get_tool_runtime(OUTPUT_TEXT_TOOL),
get_tool_runtime(ILLEGAL_OUTPUT_TOOL),
]
if structured_output_schema:
raw_tool = get_tool_runtime(FINAL_STRUCTURED_OUTPUT_TOOL)
raw_tool.entity = raw_tool.entity.model_copy(deep=True)
data_parameter = ToolParameter(
name="data",
type=ToolParameter.ToolParameterType.OBJECT,
form=ToolParameter.ToolParameterForm.LLM,
required=True,
input_schema=dict(structured_output_schema),
label=I18nObject(en_US="__Data", zh_Hans="__Data"),
)
raw_tool.entity.parameters = [data_parameter]
def invoke_tool(
user_id: str,
tool_parameters: dict[str, Any],
conversation_id: str | None = None,
app_id: str | None = None,
message_id: str | None = None,
) -> ToolInvokeMessage:
data = tool_parameters["data"]
if not data:
return ToolInvokeMessage(message=ToolInvokeMessage.TextMessage(text="`data` field is required"))
if not isinstance(data, dict):
return ToolInvokeMessage(message=ToolInvokeMessage.TextMessage(text="`data` must be a dict"))
return ToolInvokeMessage(message=ToolInvokeMessage.TextMessage(text=TERMINAL_OUTPUT_MESSAGE))
raw_tool._invoke = invoke_tool # pyright: ignore[reportPrivateUsage]
tools.append(raw_tool)
else:
raw_tool = get_tool_runtime(FINAL_OUTPUT_TOOL)
def invoke_tool(
user_id: str,
tool_parameters: dict[str, Any],
conversation_id: str | None = None,
app_id: str | None = None,
message_id: str | None = None,
) -> ToolInvokeMessage:
return ToolInvokeMessage(message=ToolInvokeMessage.TextMessage(text=TERMINAL_OUTPUT_MESSAGE))
raw_tool._invoke = invoke_tool # pyright: ignore[reportPrivateUsage]
tools.append(raw_tool)
return tools

View File

@@ -10,7 +10,6 @@ from collections.abc import Callable, Generator
from typing import TYPE_CHECKING, Any
from core.agent.entities import AgentLog, AgentResult, ExecutionContext
from core.agent.output_tools import ILLEGAL_OUTPUT_TOOL
from core.file import File
from core.model_manager import ModelInstance
from core.model_runtime.entities import (
@@ -57,7 +56,8 @@ class AgentPattern(ABC):
@abstractmethod
def run(
self, prompt_messages: list[PromptMessage], model_parameters: dict[str, Any], stop: list[str]
self, prompt_messages: list[PromptMessage], model_parameters: dict[str, Any], stop: list[str] = [],
stream: bool = True,
) -> Generator[LLMResultChunk | AgentLog, None, AgentResult]:
"""Execute the agent strategy."""
pass
@@ -462,8 +462,6 @@ class AgentPattern(ABC):
"""Convert tools to prompt message format."""
prompt_tools: list[PromptMessageTool] = []
for tool in self.tools:
if tool.entity.identity.name == ILLEGAL_OUTPUT_TOOL:
continue
prompt_tools.append(tool.to_prompt_message_tool())
return prompt_tools

View File

@@ -1,15 +1,10 @@
"""Function Call strategy implementation."""
import json
import uuid
from collections.abc import Generator
from typing import Any, Literal, Protocol, Union, cast
from typing import Any, Union
from core.agent.entities import AgentLog, AgentResult
from core.agent.output_tools import (
ILLEGAL_OUTPUT_TOOL,
TERMINAL_OUTPUT_MESSAGE,
)
from core.file import File
from core.model_runtime.entities import (
AssistantPromptMessage,
@@ -30,7 +25,8 @@ class FunctionCallStrategy(AgentPattern):
"""Function Call strategy using model's native tool calling capability."""
def run(
self, prompt_messages: list[PromptMessage], model_parameters: dict[str, Any], stop: list[str]
self, prompt_messages: list[PromptMessage], model_parameters: dict[str, Any], stop: list[str] = [],
stream: bool = True,
) -> Generator[LLMResultChunk | AgentLog, None, AgentResult]:
"""Execute the function call agent strategy."""
# Convert tools to prompt format
@@ -43,23 +39,9 @@ class FunctionCallStrategy(AgentPattern):
total_usage: dict[str, LLMUsage | None] = {"usage": None}
messages: list[PromptMessage] = list(prompt_messages) # Create mutable copy
final_text: str = ""
final_tool_args: dict[str, Any] = {"!!!": "!!!"}
finish_reason: str | None = None
output_files: list[File] = [] # Track files produced by tools
class _LLMInvoker(Protocol):
def invoke_llm(
self,
*,
prompt_messages: list[PromptMessage],
model_parameters: dict[str, Any],
tools: list[PromptMessageTool],
stop: list[str],
stream: Literal[False],
user: str | None,
callbacks: list[Any],
) -> LLMResult: ...
while function_call_state and iteration_step <= max_iterations:
function_call_state = False
round_log = self._create_log(
@@ -69,7 +51,8 @@ class FunctionCallStrategy(AgentPattern):
data={},
)
yield round_log
# On last iteration, remove tools to force final answer
current_tools: list[PromptMessageTool] = [] if iteration_step == max_iterations else prompt_tools
model_log = self._create_log(
label=f"{self.model_instance.model} Thought",
log_type=AgentLog.LogType.THOUGHT,
@@ -86,63 +69,47 @@ class FunctionCallStrategy(AgentPattern):
round_usage: dict[str, LLMUsage | None] = {"usage": None}
# Invoke model
invoker = cast(_LLMInvoker, self.model_instance)
chunks = invoker.invoke_llm(
chunks: Union[Generator[LLMResultChunk, None, None], LLMResult] = self.model_instance.invoke_llm(
prompt_messages=messages,
model_parameters=model_parameters,
tools=prompt_tools,
tools=current_tools,
stop=stop,
stream=False,
stream=stream,
user=self.context.user_id,
callbacks=[],
)
# Process response
tool_calls, response_content, chunk_finish_reason = yield from self._handle_chunks(
chunks, round_usage, model_log, emit_chunks=False
chunks, round_usage, model_log
)
if response_content:
replaced_tool_call = (
str(uuid.uuid4()),
ILLEGAL_OUTPUT_TOOL,
{
"raw": response_content,
},
)
tool_calls.append(replaced_tool_call)
messages.append(self._create_assistant_message("", tool_calls))
messages.append(self._create_assistant_message(response_content, tool_calls))
# Accumulate to total usage
round_usage_value = round_usage.get("usage")
if round_usage_value:
self._accumulate_usage(total_usage, round_usage_value)
# Update final text if no tool calls (this is likely the final answer)
if not tool_calls:
final_text = response_content
# Update finish reason
if chunk_finish_reason:
finish_reason = chunk_finish_reason
assert len(tool_calls) > 0
# Process tool calls
tool_outputs: dict[str, str] = {}
function_call_state = True
# Execute tools
for tool_call_id, tool_name, tool_args in tool_calls:
tool_response, tool_files, _ = yield from self._handle_tool_call(
tool_name, tool_args, tool_call_id, messages, round_log
)
tool_entity = self._find_tool_by_name(tool_name)
if not tool_entity:
raise ValueError(f"Tool {tool_name} not found")
tool_outputs[tool_name] = tool_response
# Track files produced by tools
output_files.extend(tool_files)
if tool_response == TERMINAL_OUTPUT_MESSAGE:
function_call_state = False
final_tool_args = tool_entity.transform_tool_parameters_type(tool_args)
if tool_calls:
function_call_state = True
# Execute tools
for tool_call_id, tool_name, tool_args in tool_calls:
tool_response, tool_files, _ = yield from self._handle_tool_call(
tool_name, tool_args, tool_call_id, messages, round_log
)
tool_outputs[tool_name] = tool_response
# Track files produced by tools
output_files.extend(tool_files)
yield self._finish_log(
round_log,
data={
@@ -161,19 +128,8 @@ class FunctionCallStrategy(AgentPattern):
# Return final result
from core.agent.entities import AgentResult
output_payload: str | dict
output_text = final_tool_args.get("text")
output_structured_payload = final_tool_args.get("data")
if isinstance(output_structured_payload, dict):
output_payload = output_structured_payload
elif isinstance(output_text, str):
output_payload = output_text
else:
raise ValueError(f"Final output ({final_tool_args}) is not a string or structured data.")
return AgentResult(
output=output_payload,
text=final_text,
files=output_files,
usage=total_usage.get("usage") or LLMUsage.empty_usage(),
finish_reason=finish_reason,
@@ -184,8 +140,6 @@ class FunctionCallStrategy(AgentPattern):
chunks: Union[Generator[LLMResultChunk, None, None], LLMResult],
llm_usage: dict[str, LLMUsage | None],
start_log: AgentLog,
*,
emit_chunks: bool,
) -> Generator[
LLMResultChunk | AgentLog,
None,
@@ -217,8 +171,7 @@ class FunctionCallStrategy(AgentPattern):
if chunk.delta.finish_reason:
finish_reason = chunk.delta.finish_reason
if emit_chunks:
yield chunk
yield chunk
else:
# Non-streaming response
result: LLMResult = chunks
@@ -233,12 +186,11 @@ class FunctionCallStrategy(AgentPattern):
self._accumulate_usage(llm_usage, result.usage)
# Convert to streaming format
if emit_chunks:
yield LLMResultChunk(
model=result.model,
prompt_messages=result.prompt_messages,
delta=LLMResultChunkDelta(index=0, message=result.message, usage=result.usage),
)
yield LLMResultChunk(
model=result.model,
prompt_messages=result.prompt_messages,
delta=LLMResultChunkDelta(index=0, message=result.message, usage=result.usage),
)
yield self._finish_log(
start_log,
data={
@@ -248,14 +200,6 @@ class FunctionCallStrategy(AgentPattern):
)
return tool_calls, response_content, finish_reason
@staticmethod
def _format_output_text(value: Any) -> str:
if value is None:
return ""
if isinstance(value, str):
return value
return json.dumps(value, ensure_ascii=False)
def _create_assistant_message(
self, content: str, tool_calls: list[tuple[str, str, dict[str, Any]]] | None = None
) -> AssistantPromptMessage:

View File

@@ -4,17 +4,10 @@ from __future__ import annotations
import json
from collections.abc import Generator
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Union
from core.agent.entities import AgentLog, AgentResult, AgentScratchpadUnit, ExecutionContext
from core.agent.output_parser.cot_output_parser import CotAgentOutputParser
from core.agent.output_tools import (
FINAL_OUTPUT_TOOL,
FINAL_STRUCTURED_OUTPUT_TOOL,
ILLEGAL_OUTPUT_TOOL,
OUTPUT_TEXT_TOOL,
OUTPUT_TOOL_NAME_SET,
)
from core.file import File
from core.model_manager import ModelInstance
from core.model_runtime.entities import (
@@ -59,7 +52,8 @@ class ReActStrategy(AgentPattern):
self.instruction = instruction
def run(
self, prompt_messages: list[PromptMessage], model_parameters: dict[str, Any], stop: list[str]
self, prompt_messages: list[PromptMessage], model_parameters: dict[str, Any], stop: list[str] = [],
stream: bool = True,
) -> Generator[LLMResultChunk | AgentLog, None, AgentResult]:
"""Execute the ReAct agent strategy."""
# Initialize tracking
@@ -71,19 +65,6 @@ class ReActStrategy(AgentPattern):
output_files: list[File] = [] # Track files produced by tools
final_text: str = ""
finish_reason: str | None = None
tool_instance_names = {tool.entity.identity.name for tool in self.tools}
available_output_tool_names = {
tool_name
for tool_name in tool_instance_names
if tool_name in OUTPUT_TOOL_NAME_SET and tool_name != ILLEGAL_OUTPUT_TOOL
}
if FINAL_STRUCTURED_OUTPUT_TOOL in available_output_tool_names:
terminal_tool_name = FINAL_STRUCTURED_OUTPUT_TOOL
elif FINAL_OUTPUT_TOOL in available_output_tool_names:
terminal_tool_name = FINAL_OUTPUT_TOOL
else:
raise ValueError("No terminal output tool configured")
allow_illegal_output = ILLEGAL_OUTPUT_TOOL in tool_instance_names
# Add "Observation" to stop sequences
if "Observation" not in stop:
@@ -100,15 +81,10 @@ class ReActStrategy(AgentPattern):
)
yield round_log
# Build prompt with tool restrictions on last iteration
if iteration_step == max_iterations:
tools_for_prompt = [
tool for tool in self.tools if tool.entity.identity.name in available_output_tool_names
]
else:
tools_for_prompt = [tool for tool in self.tools if tool.entity.identity.name != ILLEGAL_OUTPUT_TOOL]
# Build prompt with/without tools based on iteration
include_tools = iteration_step < max_iterations
current_messages = self._build_prompt_with_react_format(
prompt_messages, agent_scratchpad, tools_for_prompt, self.instruction
prompt_messages, agent_scratchpad, include_tools, self.instruction
)
model_log = self._create_log(
@@ -130,18 +106,18 @@ class ReActStrategy(AgentPattern):
messages_to_use = current_messages
# Invoke model
chunks = self.model_instance.invoke_llm(
chunks: Union[Generator[LLMResultChunk, None, None], LLMResult] = self.model_instance.invoke_llm(
prompt_messages=messages_to_use,
model_parameters=model_parameters,
stop=stop,
stream=False,
stream=stream,
user=self.context.user_id or "",
callbacks=[],
)
# Process response
scratchpad, chunk_finish_reason = yield from self._handle_chunks(
chunks, round_usage, model_log, current_messages, emit_chunks=False
chunks, round_usage, model_log, current_messages
)
agent_scratchpad.append(scratchpad)
@@ -155,44 +131,28 @@ class ReActStrategy(AgentPattern):
finish_reason = chunk_finish_reason
# Check if we have an action to execute
if scratchpad.action is None:
if not allow_illegal_output:
raise ValueError("Model did not call any tools")
illegal_action = AgentScratchpadUnit.Action(
action_name=ILLEGAL_OUTPUT_TOOL,
action_input={"raw": scratchpad.thought or ""},
)
scratchpad.action = illegal_action
scratchpad.action_str = illegal_action.model_dump_json()
if scratchpad.action and scratchpad.action.action_name.lower() != "final answer":
react_state = True
observation, tool_files = yield from self._handle_tool_call(illegal_action, current_messages, round_log)
scratchpad.observation = observation
output_files.extend(tool_files)
else:
action_name = scratchpad.action.action_name
if action_name == OUTPUT_TEXT_TOOL and isinstance(scratchpad.action.action_input, dict):
pass # output_text_payload = scratchpad.action.action_input.get("text")
elif action_name == FINAL_STRUCTURED_OUTPUT_TOOL and isinstance(scratchpad.action.action_input, dict):
data = scratchpad.action.action_input.get("data")
if isinstance(data, dict):
pass # structured_output_payload = data
elif action_name == FINAL_OUTPUT_TOOL:
if isinstance(scratchpad.action.action_input, dict):
final_text = self._format_output_text(scratchpad.action.action_input.get("text"))
else:
final_text = self._format_output_text(scratchpad.action.action_input)
# Execute tool
observation, tool_files = yield from self._handle_tool_call(
scratchpad.action, current_messages, round_log
)
scratchpad.observation = observation
# Track files produced by tools
output_files.extend(tool_files)
if action_name == terminal_tool_name:
pass # terminal_output_seen = True
react_state = False
else:
react_state = True
# Add observation to scratchpad for display
yield self._create_text_chunk(f"\nObservation: {observation}\n", current_messages)
else:
# Extract final answer
if scratchpad.action and scratchpad.action.action_input:
final_answer = scratchpad.action.action_input
if isinstance(final_answer, dict):
final_answer = json.dumps(final_answer, ensure_ascii=False)
final_text = str(final_answer)
elif scratchpad.thought:
# If no action but we have thought, use thought as final answer
final_text = scratchpad.thought
yield self._finish_log(
round_log,
@@ -208,22 +168,17 @@ class ReActStrategy(AgentPattern):
# Return final result
output_payload: str | dict
# TODO
from core.agent.entities import AgentResult
return AgentResult(
output=output_payload,
files=output_files,
usage=total_usage.get("usage"),
finish_reason=finish_reason,
text=final_text, files=output_files, usage=total_usage.get("usage"), finish_reason=finish_reason
)
def _build_prompt_with_react_format(
self,
original_messages: list[PromptMessage],
agent_scratchpad: list[AgentScratchpadUnit],
tools: list[Tool] | None,
include_tools: bool = True,
instruction: str = "",
) -> list[PromptMessage]:
"""Build prompt messages with ReAct format."""
@@ -240,13 +195,9 @@ class ReActStrategy(AgentPattern):
# Format tools
tools_str = ""
tool_names = []
if tools:
if include_tools and self.tools:
# Convert tools to prompt message tools format
prompt_tools = [
tool.to_prompt_message_tool()
for tool in tools
if tool.entity.identity.name != ILLEGAL_OUTPUT_TOOL
]
prompt_tools = [tool.to_prompt_message_tool() for tool in self.tools]
tool_names = [tool.name for tool in prompt_tools]
# Format tools as JSON for comprehensive information
@@ -258,19 +209,12 @@ class ReActStrategy(AgentPattern):
tools_str = "No tools available"
tool_names_str = ""
final_tool_name = FINAL_OUTPUT_TOOL
if FINAL_STRUCTURED_OUTPUT_TOOL in tool_names:
final_tool_name = FINAL_STRUCTURED_OUTPUT_TOOL
if final_tool_name not in tool_names:
raise ValueError("No terminal output tool available for prompt")
# Replace placeholders in the existing system prompt
updated_content = msg.content
assert isinstance(updated_content, str)
updated_content = updated_content.replace("{{instruction}}", instruction)
updated_content = updated_content.replace("{{tools}}", tools_str)
updated_content = updated_content.replace("{{tool_names}}", tool_names_str)
updated_content = updated_content.replace("{{final_tool_name}}", final_tool_name)
# Create new SystemPromptMessage with updated content
messages[i] = SystemPromptMessage(content=updated_content)
@@ -302,12 +246,10 @@ class ReActStrategy(AgentPattern):
def _handle_chunks(
self,
chunks: LLMResult,
chunks: Union[Generator[LLMResultChunk, None, None], LLMResult],
llm_usage: dict[str, Any],
model_log: AgentLog,
current_messages: list[PromptMessage],
*,
emit_chunks: bool,
) -> Generator[
LLMResultChunk | AgentLog,
None,
@@ -319,20 +261,25 @@ class ReActStrategy(AgentPattern):
"""
usage_dict: dict[str, Any] = {}
def result_to_chunks() -> Generator[LLMResultChunk, None, None]:
yield LLMResultChunk(
model=chunks.model,
prompt_messages=chunks.prompt_messages,
delta=LLMResultChunkDelta(
index=0,
message=chunks.message,
usage=chunks.usage,
finish_reason=None, # LLMResult doesn't have finish_reason, only streaming chunks do
),
system_fingerprint=chunks.system_fingerprint or "",
)
# Convert non-streaming to streaming format if needed
if isinstance(chunks, LLMResult):
# Create a generator from the LLMResult
def result_to_chunks() -> Generator[LLMResultChunk, None, None]:
yield LLMResultChunk(
model=chunks.model,
prompt_messages=chunks.prompt_messages,
delta=LLMResultChunkDelta(
index=0,
message=chunks.message,
usage=chunks.usage,
finish_reason=None, # LLMResult doesn't have finish_reason, only streaming chunks do
),
system_fingerprint=chunks.system_fingerprint or "",
)
streaming_chunks = result_to_chunks()
streaming_chunks = result_to_chunks()
else:
streaming_chunks = chunks
react_chunks = CotAgentOutputParser.handle_react_stream_output(streaming_chunks, usage_dict)
@@ -356,18 +303,14 @@ class ReActStrategy(AgentPattern):
scratchpad.action_str = action_str
scratchpad.action = chunk
if emit_chunks:
yield self._create_text_chunk(json.dumps(chunk.model_dump()), current_messages)
elif isinstance(chunk, str):
yield self._create_text_chunk(json.dumps(chunk.model_dump()), current_messages)
else:
# Text chunk
chunk_text = str(chunk)
scratchpad.agent_response = (scratchpad.agent_response or "") + chunk_text
scratchpad.thought = (scratchpad.thought or "") + chunk_text
if emit_chunks:
yield self._create_text_chunk(chunk_text, current_messages)
else:
raise ValueError(f"Unexpected chunk type: {type(chunk)}")
yield self._create_text_chunk(chunk_text, current_messages)
# Update usage
if usage_dict.get("usage"):
@@ -391,14 +334,6 @@ class ReActStrategy(AgentPattern):
return scratchpad, finish_reason
@staticmethod
def _format_output_text(value: Any) -> str:
if value is None:
return ""
if isinstance(value, str):
return value
return json.dumps(value, ensure_ascii=False)
def _handle_tool_call(
self,
action: AgentScratchpadUnit.Action,

View File

@@ -2,17 +2,13 @@
from __future__ import annotations
from collections.abc import Mapping
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING
from core.agent.entities import AgentEntity, ExecutionContext
from core.file.models import File
from core.model_manager import ModelInstance
from core.model_runtime.entities.model_entities import ModelFeature
from ...app.entities.app_invoke_entities import InvokeFrom
from ...tools.entities.tool_entities import ToolInvokeFrom
from ..output_tools import build_agent_output_tools
from .base import AgentPattern, ToolInvokeHook
from .function_call import FunctionCallStrategy
from .react import ReActStrategy
@@ -29,10 +25,6 @@ class StrategyFactory:
@staticmethod
def create_strategy(
*,
tenant_id: str,
invoke_from: InvokeFrom,
tool_invoke_from: ToolInvokeFrom,
model_features: list[ModelFeature],
model_instance: ModelInstance,
context: ExecutionContext,
@@ -43,16 +35,11 @@ class StrategyFactory:
agent_strategy: AgentEntity.Strategy | None = None,
tool_invoke_hook: ToolInvokeHook | None = None,
instruction: str = "",
structured_output_schema: Mapping[str, Any] | None = None,
) -> AgentPattern:
"""
Create an appropriate strategy based on model features.
Args:
tenant_id:
invoke_from:
tool_invoke_from:
structured_output_schema:
model_features: List of model features/capabilities
model_instance: Model instance to use
context: Execution context containing trace/audit information
@@ -67,14 +54,6 @@ class StrategyFactory:
Returns:
AgentStrategy instance
"""
output_tools = build_agent_output_tools(
tenant_id=tenant_id,
invoke_from=invoke_from,
tool_invoke_from=tool_invoke_from,
structured_output_schema=structured_output_schema,
)
tools.extend(output_tools)
# If explicit strategy is provided and it's Function Calling, try to use it if supported
if agent_strategy == AgentEntity.Strategy.FUNCTION_CALLING:

View File

@@ -7,7 +7,7 @@ You have access to the following tools:
{{tools}}
Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
Valid "action" values: {{tool_names}}. You must call "{{final_tool_name}}" to finish.
Valid "action" values: "Final Answer" or {{tool_names}}
Provide only ONE action per $JSON_BLOB, as shown:
@@ -32,14 +32,12 @@ Thought: I know what to respond
Action:
```
{
"action": "{{final_tool_name}}",
"action_input": {
"text": "Final response to human"
}
"action": "Final Answer",
"action_input": "Final response to human"
}
```
Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Do not respond with plain text. Format is Action:```$JSON_BLOB```then Observation:.
Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary. Respond directly if appropriate. Format is Action:```$JSON_BLOB```then Observation:.
{{historic_messages}}
Question: {{query}}
{{agent_scratchpad}}
@@ -58,7 +56,7 @@ You have access to the following tools:
{{tools}}
Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
Valid "action" values: {{tool_names}}. You must call "{{final_tool_name}}" to finish.
Valid "action" values: "Final Answer" or {{tool_names}}
Provide only ONE action per $JSON_BLOB, as shown:
@@ -83,14 +81,12 @@ Thought: I know what to respond
Action:
```
{
"action": "{{final_tool_name}}",
"action_input": {
"text": "Final response to human"
}
"action": "Final Answer",
"action_input": "Final response to human"
}
```
Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Do not respond with plain text. Format is Action:```$JSON_BLOB```then Observation:.
Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary. Respond directly if appropriate. Format is Action:```$JSON_BLOB```then Observation:.
""" # noqa: E501

View File

@@ -477,6 +477,7 @@ class LLMGenerator:
prompt_messages=complete_messages,
output_model=CodeNodeStructuredOutput,
model_parameters=model_parameters,
tenant_id=tenant_id,
)
return {
@@ -557,14 +558,10 @@ class LLMGenerator:
completion_params = model_config.get("completion_params", {}) if model_config else {}
try:
response = invoke_llm_with_pydantic_model(
provider=model_instance.provider,
model_schema=model_schema,
model_instance=model_instance,
prompt_messages=prompt_messages,
output_model=SuggestedQuestionsOutput,
model_parameters=completion_params,
)
response = invoke_llm_with_pydantic_model(provider=model_instance.provider, model_schema=model_schema,
model_instance=model_instance, prompt_messages=prompt_messages,
output_model=SuggestedQuestionsOutput,
model_parameters=completion_params, tenant_id=tenant_id)
return {"questions": response.questions, "error": ""}

View File

@@ -1,190 +1,188 @@
from collections.abc import Callable, Mapping, Sequence
from typing import Any, cast
"""
File reference detection and conversion for structured output.
This module provides utilities to:
1. Detect file reference fields in JSON Schema (format: "dify-file-ref")
2. Convert file ID strings to File objects after LLM returns
"""
import uuid
from collections.abc import Mapping
from typing import Any
from core.file import File
from core.variables.segments import ArrayFileSegment, FileSegment
from factories.file_factory import build_from_mapping
FILE_PATH_SCHEMA_TYPE = "file"
FILE_PATH_SCHEMA_FORMATS = {"file", "file-ref", "dify-file-ref"}
FILE_PATH_DESCRIPTION_SUFFIX = "Sandbox file path (relative paths supported)."
FILE_REF_FORMAT = "dify-file-ref"
def is_file_path_property(schema: Mapping[str, Any]) -> bool:
if schema.get("type") == FILE_PATH_SCHEMA_TYPE:
return True
format_value = schema.get("format")
if not isinstance(format_value, str):
return False
normalized_format = format_value.lower().replace("_", "-")
return normalized_format in FILE_PATH_SCHEMA_FORMATS
def is_file_ref_property(schema: dict) -> bool:
"""Check if a schema property is a file reference."""
return schema.get("type") == "string" and schema.get("format") == FILE_REF_FORMAT
def detect_file_path_fields(schema: Mapping[str, Any], path: str = "") -> list[str]:
file_path_fields: list[str] = []
def detect_file_ref_fields(schema: Mapping[str, Any], path: str = "") -> list[str]:
"""
Recursively detect file reference fields in schema.
Args:
schema: JSON Schema to analyze
path: Current path in the schema (used for recursion)
Returns:
List of JSON paths containing file refs, e.g., ["image_id", "files[*]"]
"""
file_ref_paths: list[str] = []
schema_type = schema.get("type")
if schema_type == "object":
properties = schema.get("properties")
if isinstance(properties, Mapping):
properties_mapping = cast(Mapping[str, Any], properties)
for prop_name, prop_schema in properties_mapping.items():
if not isinstance(prop_schema, Mapping):
continue
prop_schema_mapping = cast(Mapping[str, Any], prop_schema)
current_path = f"{path}.{prop_name}" if path else prop_name
for prop_name, prop_schema in schema.get("properties", {}).items():
current_path = f"{path}.{prop_name}" if path else prop_name
if is_file_path_property(prop_schema_mapping):
file_path_fields.append(current_path)
else:
file_path_fields.extend(detect_file_path_fields(prop_schema_mapping, current_path))
if is_file_ref_property(prop_schema):
file_ref_paths.append(current_path)
elif isinstance(prop_schema, dict):
file_ref_paths.extend(detect_file_ref_fields(prop_schema, current_path))
elif schema_type == "array":
items_schema = schema.get("items")
if not isinstance(items_schema, Mapping):
return file_path_fields
items_schema_mapping = cast(Mapping[str, Any], items_schema)
items_schema = schema.get("items", {})
array_path = f"{path}[*]" if path else "[*]"
if is_file_path_property(items_schema_mapping):
file_path_fields.append(array_path)
else:
file_path_fields.extend(detect_file_path_fields(items_schema_mapping, array_path))
if is_file_ref_property(items_schema):
file_ref_paths.append(array_path)
elif isinstance(items_schema, dict):
file_ref_paths.extend(detect_file_ref_fields(items_schema, array_path))
return file_path_fields
return file_ref_paths
def adapt_schema_for_sandbox_file_paths(schema: Mapping[str, Any]) -> tuple[dict[str, Any], list[str]]:
result = _deep_copy_value(schema)
if not isinstance(result, dict):
raise ValueError("structured_output_schema must be a JSON object")
result_dict = cast(dict[str, Any], result)
file_path_fields: list[str] = []
_adapt_schema_in_place(result_dict, path="", file_path_fields=file_path_fields)
return result_dict, file_path_fields
def convert_sandbox_file_paths_in_output(
def convert_file_refs_in_output(
output: Mapping[str, Any],
file_path_fields: Sequence[str],
file_resolver: Callable[[str], File],
) -> tuple[dict[str, Any], list[File]]:
if not file_path_fields:
return dict(output), []
json_schema: Mapping[str, Any],
tenant_id: str,
) -> dict[str, Any]:
"""
Convert file ID strings to File objects based on schema.
result = _deep_copy_value(output)
if not isinstance(result, dict):
raise ValueError("Structured output must be a JSON object")
result_dict = cast(dict[str, Any], result)
Args:
output: The structured_output from LLM result
json_schema: The original JSON schema (to detect file ref fields)
tenant_id: Tenant ID for file lookup
files: list[File] = []
for path in file_path_fields:
_convert_path_in_place(result_dict, path.split("."), file_resolver, files)
Returns:
Output with file references converted to File objects
"""
file_ref_paths = detect_file_ref_fields(json_schema)
if not file_ref_paths:
return dict(output)
return result_dict, files
result = _deep_copy_dict(output)
for path in file_ref_paths:
_convert_path_in_place(result, path.split("."), tenant_id)
return result
def _adapt_schema_in_place(schema: dict[str, Any], path: str, file_path_fields: list[str]) -> None:
schema_type = schema.get("type")
if schema_type == "object":
properties = schema.get("properties")
if isinstance(properties, Mapping):
properties_mapping = cast(Mapping[str, Any], properties)
for prop_name, prop_schema in properties_mapping.items():
if not isinstance(prop_schema, dict):
continue
prop_schema_dict = cast(dict[str, Any], prop_schema)
current_path = f"{path}.{prop_name}" if path else prop_name
if is_file_path_property(prop_schema_dict):
_normalize_file_path_schema(prop_schema_dict)
file_path_fields.append(current_path)
else:
_adapt_schema_in_place(prop_schema_dict, current_path, file_path_fields)
elif schema_type == "array":
items_schema = schema.get("items")
if not isinstance(items_schema, dict):
return
items_schema_dict = cast(dict[str, Any], items_schema)
array_path = f"{path}[*]" if path else "[*]"
if is_file_path_property(items_schema_dict):
_normalize_file_path_schema(items_schema_dict)
file_path_fields.append(array_path)
def _deep_copy_dict(obj: Mapping[str, Any]) -> dict[str, Any]:
"""Deep copy a mapping to a mutable dict."""
result: dict[str, Any] = {}
for key, value in obj.items():
if isinstance(value, Mapping):
result[key] = _deep_copy_dict(value)
elif isinstance(value, list):
result[key] = [_deep_copy_dict(item) if isinstance(item, Mapping) else item for item in value]
else:
_adapt_schema_in_place(items_schema_dict, array_path, file_path_fields)
result[key] = value
return result
def _normalize_file_path_schema(schema: dict[str, Any]) -> None:
schema["type"] = "string"
schema.pop("format", None)
description = schema.get("description", "")
if description:
schema["description"] = f"{description}\n{FILE_PATH_DESCRIPTION_SUFFIX}"
else:
schema["description"] = FILE_PATH_DESCRIPTION_SUFFIX
def _deep_copy_value(value: Any) -> Any:
if isinstance(value, Mapping):
mapping = cast(Mapping[str, Any], value)
return {key: _deep_copy_value(item) for key, item in mapping.items()}
if isinstance(value, list):
list_value = cast(list[Any], value)
return [_deep_copy_value(item) for item in list_value]
return value
def _convert_path_in_place(
obj: dict[str, Any],
path_parts: list[str],
file_resolver: Callable[[str], File],
files: list[File],
) -> None:
def _convert_path_in_place(obj: dict, path_parts: list[str], tenant_id: str) -> None:
"""Convert file refs at the given path in place, wrapping in Segment types."""
if not path_parts:
return
current = path_parts[0]
remaining = path_parts[1:]
# Handle array notation like "files[*]"
if current.endswith("[*]"):
key = current[:-3] if current != "[*]" else ""
target_value = obj.get(key) if key else obj
key = current[:-3] if current != "[*]" else None
target = obj.get(key) if key else obj
if isinstance(target_value, list):
target_list = cast(list[Any], target_value)
if isinstance(target, list):
if remaining:
for item in target_list:
# Nested array with remaining path - recurse into each item
for item in target:
if isinstance(item, dict):
item_dict = cast(dict[str, Any], item)
_convert_path_in_place(item_dict, remaining, file_resolver, files)
_convert_path_in_place(item, remaining, tenant_id)
else:
resolved_files: list[File] = []
for item in target_list:
if not isinstance(item, str):
raise ValueError("File path must be a string")
file = file_resolver(item)
files.append(file)
resolved_files.append(file)
# Array of file IDs - convert all and wrap in ArrayFileSegment
files: list[File] = []
for item in target:
file = _convert_file_id(item, tenant_id)
if file is not None:
files.append(file)
# Replace the array with ArrayFileSegment
if key:
obj[key] = ArrayFileSegment(value=resolved_files)
obj[key] = ArrayFileSegment(value=files)
return
if not remaining:
if current not in obj:
return
value = obj[current]
if value is None:
obj[current] = None
return
if not isinstance(value, str):
raise ValueError("File path must be a string")
file = file_resolver(value)
files.append(file)
obj[current] = FileSegment(value=file)
return
# Leaf node - convert the value and wrap in FileSegment
if current in obj:
file = _convert_file_id(obj[current], tenant_id)
if file is not None:
obj[current] = FileSegment(value=file)
else:
obj[current] = None
else:
# Recurse into nested object
if current in obj and isinstance(obj[current], dict):
_convert_path_in_place(obj[current], remaining, tenant_id)
if current in obj and isinstance(obj[current], dict):
_convert_path_in_place(obj[current], remaining, file_resolver, files)
def _convert_file_id(file_id: Any, tenant_id: str) -> File | None:
"""
Convert a file ID string to a File object.
Tries multiple file sources in order:
1. ToolFile (files generated by tools/workflows)
2. UploadFile (files uploaded by users)
"""
if not isinstance(file_id, str):
return None
# Validate UUID format
try:
uuid.UUID(file_id)
except ValueError:
return None
# Try ToolFile first (files generated by tools/workflows)
try:
return build_from_mapping(
mapping={
"transfer_method": "tool_file",
"tool_file_id": file_id,
},
tenant_id=tenant_id,
)
except ValueError:
pass
# Try UploadFile (files uploaded by users)
try:
return build_from_mapping(
mapping={
"transfer_method": "local_file",
"upload_file_id": file_id,
},
tenant_id=tenant_id,
)
except ValueError:
pass
# File not found in any source
return None
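Not part of the diff — a compact, self-contained sketch of the schema walk that detect_file_ref_fields (re-added above) performs, showing the path notation it produces; the sample schema is invented.

from collections.abc import Mapping
from typing import Any

FILE_REF_FORMAT = "dify-file-ref"

def detect_file_ref_fields(schema: Mapping[str, Any], path: str = "") -> list[str]:
    """Return JSON paths (e.g. 'image_id', 'files[*]') whose schema marks a file reference."""
    paths: list[str] = []
    if schema.get("type") == "object":
        for name, prop in schema.get("properties", {}).items():
            if not isinstance(prop, Mapping):
                continue
            current = f"{path}.{name}" if path else name
            if prop.get("type") == "string" and prop.get("format") == FILE_REF_FORMAT:
                paths.append(current)
            else:
                paths.extend(detect_file_ref_fields(prop, current))
    elif schema.get("type") == "array":
        items = schema.get("items", {})
        if isinstance(items, Mapping):
            current = f"{path}[*]" if path else "[*]"
            if items.get("type") == "string" and items.get("format") == FILE_REF_FORMAT:
                paths.append(current)
            else:
                paths.extend(detect_file_ref_fields(items, current))
    return paths

schema = {
    "type": "object",
    "properties": {
        "report": {"type": "string"},
        "cover_image": {"type": "string", "format": "dify-file-ref"},
        "attachments": {"type": "array", "items": {"type": "string", "format": "dify-file-ref"}},
    },
}
print(detect_file_ref_fields(schema))  # ['cover_image', 'attachments[*]']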

View File

@@ -8,7 +8,7 @@ import json_repair
from pydantic import BaseModel, TypeAdapter, ValidationError
from core.llm_generator.output_parser.errors import OutputParserError
from core.llm_generator.output_parser.file_ref import detect_file_path_fields
from core.llm_generator.output_parser.file_ref import convert_file_refs_in_output
from core.llm_generator.prompts import STRUCTURED_OUTPUT_PROMPT
from core.model_manager import ModelInstance
from core.model_runtime.callbacks.base_callback import Callback
@@ -55,11 +55,12 @@ def invoke_llm_with_structured_output(
model_instance: ModelInstance,
prompt_messages: Sequence[PromptMessage],
json_schema: Mapping[str, Any],
model_parameters: Mapping[str, Any] | None = None,
model_parameters: Mapping | None = None,
tools: Sequence[PromptMessageTool] | None = None,
stop: list[str] | None = None,
user: str | None = None,
callbacks: list[Callback] | None = None,
tenant_id: str | None = None,
) -> LLMResultWithStructuredOutput:
"""
Invoke large language model with structured output.
@@ -77,12 +78,14 @@ def invoke_llm_with_structured_output(
:param stop: stop words
:param user: unique user id
:param callbacks: callbacks
:return: response with structured output
:param tenant_id: tenant ID for file reference conversion. When provided and
json_schema contains file reference fields (format: "dify-file-ref"),
file IDs in the output will be automatically converted to File objects.
:return: full response or stream response chunk generator result
"""
model_parameters_with_json_schema: dict[str, Any] = dict(model_parameters or {})
if detect_file_path_fields(json_schema):
raise OutputParserError("Structured output file paths are only supported in sandbox mode.")
model_parameters_with_json_schema: dict[str, Any] = {
**(model_parameters or {}),
}
# Determine structured output strategy
@@ -119,6 +122,14 @@ def invoke_llm_with_structured_output(
# Fill missing fields with default values
structured_output = fill_defaults_from_schema(structured_output, json_schema)
# Convert file references if tenant_id is provided
if tenant_id is not None:
structured_output = convert_file_refs_in_output(
output=structured_output,
json_schema=json_schema,
tenant_id=tenant_id,
)
return LLMResultWithStructuredOutput(
structured_output=structured_output,
model=llm_result.model,
@@ -136,11 +147,12 @@ def invoke_llm_with_pydantic_model(
model_instance: ModelInstance,
prompt_messages: Sequence[PromptMessage],
output_model: type[T],
model_parameters: Mapping[str, Any] | None = None,
model_parameters: Mapping | None = None,
tools: Sequence[PromptMessageTool] | None = None,
stop: list[str] | None = None,
user: str | None = None,
callbacks: list[Callback] | None = None,
tenant_id: str | None = None,
) -> T:
"""
Invoke large language model with a Pydantic output model.
@@ -148,8 +160,11 @@ def invoke_llm_with_pydantic_model(
This helper generates a JSON schema from the Pydantic model, invokes the
structured-output LLM path, and validates the result.
The helper performs a non-streaming invocation and returns the validated
Pydantic model directly.
The stream parameter controls the underlying LLM invocation mode:
- stream=True (default): Uses streaming LLM call, consumes the generator internally
- stream=False: Uses non-streaming LLM call
In both cases, the function returns the validated Pydantic model directly.
"""
json_schema = _schema_from_pydantic(output_model)
@@ -164,6 +179,7 @@ def invoke_llm_with_pydantic_model(
stop=stop,
user=user,
callbacks=callbacks,
tenant_id=tenant_id,
)
structured_output = result.structured_output
@@ -220,27 +236,25 @@ def _extract_structured_output(llm_result: LLMResult) -> Mapping[str, Any]:
return _parse_structured_output(content)
def _parse_tool_call_arguments(arguments: str) -> dict[str, Any]:
def _parse_tool_call_arguments(arguments: str) -> Mapping[str, Any]:
"""Parse JSON from tool call arguments."""
if not arguments:
raise OutputParserError("Tool call arguments is empty")
try:
parsed_any = json.loads(arguments)
if not isinstance(parsed_any, dict):
parsed = json.loads(arguments)
if not isinstance(parsed, dict):
raise OutputParserError(f"Tool call arguments is not a dict: {arguments}")
parsed = cast(dict[str, Any], parsed_any)
return parsed
except json.JSONDecodeError:
# Try to repair malformed JSON
repaired_any = json_repair.loads(arguments)
if not isinstance(repaired_any, dict):
repaired = json_repair.loads(arguments)
if not isinstance(repaired, dict):
raise OutputParserError(f"Failed to parse tool call arguments: {arguments}")
repaired: dict[str, Any] = repaired_any
return repaired
def get_default_value_for_type(type_name: str | list[str] | None) -> Any:
def _get_default_value_for_type(type_name: str | list[str] | None) -> Any:
"""Get default empty value for a JSON schema type."""
# Handle array of types (e.g., ["string", "null"])
if isinstance(type_name, list):
@@ -297,7 +311,7 @@ def fill_defaults_from_schema(
# Create empty object and recursively fill its required fields
result[prop_name] = fill_defaults_from_schema({}, prop_schema)
else:
result[prop_name] = get_default_value_for_type(prop_type)
result[prop_name] = _get_default_value_for_type(prop_type)
elif isinstance(result[prop_name], dict) and prop_type == "object" and "properties" in prop_schema:
# Field exists and is an object, recursively fill nested required fields
result[prop_name] = fill_defaults_from_schema(result[prop_name], prop_schema)
@@ -308,10 +322,10 @@ def fill_defaults_from_schema(
def _handle_native_json_schema(
provider: str,
model_schema: AIModelEntity,
structured_output_schema: Mapping[str, Any],
model_parameters: dict[str, Any],
structured_output_schema: Mapping,
model_parameters: dict,
rules: list[ParameterRule],
) -> dict[str, Any]:
):
"""
Handle structured output for models with native JSON schema support.
@@ -333,7 +347,7 @@ def _handle_native_json_schema(
return model_parameters
def _set_response_format(model_parameters: dict[str, Any], rules: list[ParameterRule]) -> None:
def _set_response_format(model_parameters: dict, rules: list):
"""
Set the appropriate response format parameter based on model rules.
@@ -349,7 +363,7 @@ def _set_response_format(model_parameters: dict[str, Any], rules: list[Parameter
def _handle_prompt_based_schema(
prompt_messages: Sequence[PromptMessage], structured_output_schema: Mapping[str, Any]
prompt_messages: Sequence[PromptMessage], structured_output_schema: Mapping
) -> list[PromptMessage]:
"""
Handle structured output for models without native JSON schema support.
@@ -386,27 +400,28 @@ def _handle_prompt_based_schema(
return updated_prompt
def _parse_structured_output(result_text: str) -> dict[str, Any]:
def _parse_structured_output(result_text: str) -> Mapping[str, Any]:
structured_output: Mapping[str, Any] = {}
parsed: Mapping[str, Any] = {}
try:
parsed = TypeAdapter(dict[str, Any]).validate_json(result_text)
return parsed
parsed = TypeAdapter(Mapping).validate_json(result_text)
if not isinstance(parsed, dict):
raise OutputParserError(f"Failed to parse structured output: {result_text}")
structured_output = parsed
except ValidationError:
# if the result_text is not a valid json, try to repair it
temp_parsed: Any = json_repair.loads(result_text)
if isinstance(temp_parsed, list):
temp_parsed_list = cast(list[Any], temp_parsed)
dict_items: list[dict[str, Any]] = []
for item in temp_parsed_list:
if isinstance(item, dict):
dict_items.append(cast(dict[str, Any], item))
temp_parsed = dict_items[0] if dict_items else {}
temp_parsed = json_repair.loads(result_text)
if not isinstance(temp_parsed, dict):
raise OutputParserError(f"Failed to parse structured output: {result_text}")
temp_parsed_dict = cast(dict[str, Any], temp_parsed)
return temp_parsed_dict
# handle reasoning models like deepseek-r1 that emit a '<think>\n\n</think>\n' prefix
if isinstance(temp_parsed, list):
temp_parsed = next((item for item in temp_parsed if isinstance(item, dict)), {})
else:
raise OutputParserError(f"Failed to parse structured output: {result_text}")
structured_output = cast(dict, temp_parsed)
return structured_output
def _prepare_schema_for_model(provider: str, model_schema: AIModelEntity, schema: Mapping[str, Any]) -> dict[str, Any]:
def _prepare_schema_for_model(provider: str, model_schema: AIModelEntity, schema: Mapping):
"""
Prepare JSON schema based on model requirements.
@@ -418,49 +433,54 @@ def _prepare_schema_for_model(provider: str, model_schema: AIModelEntity, schema
"""
# Deep copy to avoid modifying the original schema
processed_schema = deepcopy(schema)
processed_schema_dict = dict(processed_schema)
processed_schema = dict(deepcopy(schema))
# Convert boolean types to string types (common requirement)
convert_boolean_to_string(processed_schema_dict)
convert_boolean_to_string(processed_schema)
# Apply model-specific transformations
if SpecialModelType.GEMINI in model_schema.model:
remove_additional_properties(processed_schema_dict)
return processed_schema_dict
if SpecialModelType.OLLAMA in provider:
return processed_schema_dict
# Default format with name field
return {"schema": processed_schema_dict, "name": "llm_response"}
remove_additional_properties(processed_schema)
return processed_schema
elif SpecialModelType.OLLAMA in provider:
return processed_schema
else:
# Default format with name field
return {"schema": processed_schema, "name": "llm_response"}
def remove_additional_properties(schema: dict[str, Any]) -> None:
def remove_additional_properties(schema: dict):
"""
Remove additionalProperties fields from JSON schema.
Used for models like Gemini that don't support this property.
:param schema: JSON schema to modify in-place
"""
if not isinstance(schema, dict):
return
# Remove additionalProperties at current level
schema.pop("additionalProperties", None)
# Process nested structures recursively
for value in schema.values():
if isinstance(value, dict):
remove_additional_properties(cast(dict[str, Any], value))
remove_additional_properties(value)
elif isinstance(value, list):
for item in cast(list[Any], value):
for item in value:
if isinstance(item, dict):
remove_additional_properties(cast(dict[str, Any], item))
remove_additional_properties(item)
def convert_boolean_to_string(schema: dict[str, Any]) -> None:
def convert_boolean_to_string(schema: dict):
"""
Convert boolean type specifications to string in JSON schema.
:param schema: JSON schema to modify in-place
"""
if not isinstance(schema, dict):
return
# Check for boolean type at current level
if schema.get("type") == "boolean":
schema["type"] = "string"
@@ -468,8 +488,8 @@ def convert_boolean_to_string(schema: dict[str, Any]) -> None:
# Process nested dictionaries and lists recursively
for value in schema.values():
if isinstance(value, dict):
convert_boolean_to_string(cast(dict[str, Any], value))
convert_boolean_to_string(value)
elif isinstance(value, list):
for item in cast(list[Any], value):
for item in value:
if isinstance(item, dict):
convert_boolean_to_string(cast(dict[str, Any], item))
convert_boolean_to_string(item)
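# --- Editor's illustrative sketch, not part of this change ---
# Rough expectation of how the two helpers above reshape a schema in place; the
# sample schema and the helper name are assumptions added for illustration only.
def _example_schema_cleanup() -> dict[str, Any]:
    sample: dict[str, Any] = {
        "type": "object",
        "additionalProperties": False,
        "properties": {"flag": {"type": "boolean"}},
    }
    convert_boolean_to_string(sample)     # nested "boolean" type becomes "string"
    remove_additional_properties(sample)  # additionalProperties stripped recursively
    # sample == {"type": "object", "properties": {"flag": {"type": "string"}}}
    return sample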

View File

@@ -1,13 +1,11 @@
from collections.abc import Generator
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, TypeVar
from typing import TypeVar, Union
from core.agent.entities import AgentInvokeMessage
from core.tools.entities.tool_entities import ToolInvokeMessage
if TYPE_CHECKING:
pass
MessageType = TypeVar("MessageType", bound=ToolInvokeMessage)
MessageType = TypeVar("MessageType", bound=Union[ToolInvokeMessage, AgentInvokeMessage])
@dataclass
@@ -89,7 +87,7 @@ def merge_blob_chunks(
),
meta=resp.meta,
)
assert isinstance(merged_message, ToolInvokeMessage)
assert isinstance(merged_message, (ToolInvokeMessage, AgentInvokeMessage))
yield merged_message # type: ignore
# Clean up the buffer
del files[chunk_id]

View File

@@ -14,8 +14,7 @@ from core.skill.entities import ToolAccessPolicy
from core.skill.entities.tool_dependencies import ToolDependencies
from core.tools.signature import sign_tool_file
from core.tools.tool_file_manager import ToolFileManager
from core.virtual_environment.__base.exec import CommandExecutionError
from core.virtual_environment.__base.helpers import execute, pipeline
from core.virtual_environment.__base.helpers import pipeline
from ..bash.dify_cli import DifyCliConfig
from ..entities import DifyCli
@@ -120,6 +119,21 @@ class SandboxBashSession:
return self._bash_tool
def collect_output_files(self, output_dir: str = SANDBOX_OUTPUT_DIR) -> list[File]:
"""
Collect files from the sandbox output directory and save them as ToolFiles.
Scans the specified output directory in the sandbox, downloads each file,
saves it as a ToolFile, and returns a list of File objects. The File
objects carry a valid tool_file_id that subsequent nodes can reference
via structured output.
Args:
output_dir: Directory path in the sandbox to scan for output files.
Defaults to "output" (relative to the workspace).
Returns:
List of File objects representing the collected files.
"""
vm = self._sandbox.vm
collected_files: list[File] = []
@@ -130,6 +144,8 @@ class SandboxBashSession:
logger.debug("Failed to list sandbox output files in %s: %s", output_dir, exc)
return collected_files
tool_file_manager = ToolFileManager()
for file_state in file_states:
# Skip files that are too large
if file_state.size > MAX_OUTPUT_FILE_SIZE:
@@ -146,14 +162,47 @@ class SandboxBashSession:
file_content = vm.download_file(file_state.path)
file_binary = file_content.getvalue()
# Determine mime type from extension
filename = os.path.basename(file_state.path)
file_obj = self._create_tool_file(filename=filename, file_binary=file_binary)
mime_type, _ = mimetypes.guess_type(filename)
if not mime_type:
mime_type = "application/octet-stream"
# Save as ToolFile
tool_file = tool_file_manager.create_file_by_raw(
user_id=self._user_id,
tenant_id=self._tenant_id,
conversation_id=None,
file_binary=file_binary,
mimetype=mime_type,
filename=filename,
)
# Determine file type from mime type
file_type = _get_file_type_from_mime(mime_type)
extension = os.path.splitext(filename)[1] if "." in filename else ".bin"
url = sign_tool_file(tool_file.id, extension)
# Create File object with tool_file_id as related_id
file_obj = File(
id=tool_file.id, # Use tool_file_id as the File id for easy reference
tenant_id=self._tenant_id,
type=file_type,
transfer_method=FileTransferMethod.TOOL_FILE,
filename=filename,
extension=extension,
mime_type=mime_type,
size=len(file_binary),
related_id=tool_file.id,
url=url,
storage_key=tool_file.file_key,
)
collected_files.append(file_obj)
logger.info(
"Collected sandbox output file: %s -> tool_file_id=%s",
file_state.path,
file_obj.id,
tool_file.id,
)
except Exception as exc:
@@ -167,85 +216,6 @@ class SandboxBashSession:
)
return collected_files
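# --- Editor's illustrative sketch, not part of this change ---
# Expected usage of collect_output_files(); the variable names below are
# assumptions added for illustration only.
#
#     with SandboxBashSession(sandbox=sandbox, node_id=node_id, tools=tool_dependencies) as session:
#         ...  # run the bash tool / agent turns
#         output_files = session.collect_output_files()
#         for file_obj in output_files:
#             # related_id carries the ToolFile id used by downstream references
#             logger.debug("collected %s (tool_file_id=%s)", file_obj.filename, file_obj.related_id)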
def download_file(self, path: str) -> File:
path_kind = self._detect_path_kind(path)
if path_kind == "dir":
raise ValueError("Directory outputs are not supported")
if path_kind != "file":
raise ValueError(f"Sandbox file not found: {path}")
file_content = self._sandbox.vm.download_file(path)
file_binary = file_content.getvalue()
if len(file_binary) > MAX_OUTPUT_FILE_SIZE:
raise ValueError(f"Sandbox file exceeds size limit: {path}")
filename = os.path.basename(path) or "file"
return self._create_tool_file(filename=filename, file_binary=file_binary)
def _detect_path_kind(self, path: str) -> str:
script = r"""
import os
import sys
p = sys.argv[1]
if os.path.isdir(p):
print("dir")
raise SystemExit(0)
if os.path.isfile(p):
print("file")
raise SystemExit(0)
print("none")
raise SystemExit(2)
"""
try:
result = execute(
self._sandbox.vm,
[
"sh",
"-c",
'if command -v python3 >/dev/null 2>&1; then py=python3; else py=python; fi; "$py" -c "$0" "$@"',
script,
path,
],
timeout=10,
error_message="Failed to inspect sandbox path",
)
except CommandExecutionError as exc:
raise ValueError(str(exc)) from exc
return result.stdout.decode("utf-8", errors="replace").strip()
def _create_tool_file(self, *, filename: str, file_binary: bytes) -> File:
mime_type, _ = mimetypes.guess_type(filename)
if not mime_type:
mime_type = "application/octet-stream"
tool_file = ToolFileManager().create_file_by_raw(
user_id=self._user_id,
tenant_id=self._tenant_id,
conversation_id=None,
file_binary=file_binary,
mimetype=mime_type,
filename=filename,
)
file_type = _get_file_type_from_mime(mime_type)
extension = os.path.splitext(filename)[1] if "." in filename else ".bin"
url = sign_tool_file(tool_file.id, extension)
return File(
id=tool_file.id,
tenant_id=self._tenant_id,
type=file_type,
transfer_method=FileTransferMethod.TOOL_FILE,
filename=filename,
extension=extension,
mime_type=mime_type,
size=len(file_binary),
related_id=tool_file.id,
url=url,
storage_key=tool_file.file_key,
)
def _get_file_type_from_mime(mime_type: str) -> FileType:
"""Determine FileType from mime type."""

View File

@@ -57,7 +57,7 @@ class Tool(ABC):
tool_parameters.update(self.runtime.runtime_parameters)
# try to parse tool parameters into the correct type
tool_parameters = self.transform_tool_parameters_type(tool_parameters)
tool_parameters = self._transform_tool_parameters_type(tool_parameters)
result = self._invoke(
user_id=user_id,
@@ -82,7 +82,7 @@ class Tool(ABC):
else:
return result
def transform_tool_parameters_type(self, tool_parameters: dict[str, Any]) -> dict[str, Any]:
def _transform_tool_parameters_type(self, tool_parameters: dict[str, Any]) -> dict[str, Any]:
"""
Transform tool parameter values into their declared parameter types
"""

View File

@@ -1,5 +0,0 @@
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="none">
<rect x="3" y="3" width="18" height="18" rx="4" stroke="#2F2F2F" stroke-width="1.5"/>
<path d="M7 12h10" stroke="#2F2F2F" stroke-width="1.5" stroke-linecap="round"/>
<path d="M12 7v10" stroke="#2F2F2F" stroke-width="1.5" stroke-linecap="round"/>
</svg>


View File

@@ -1,8 +0,0 @@
from typing import Any
from core.tools.builtin_tool.provider import BuiltinToolProviderController
class AgentOutputProvider(BuiltinToolProviderController):
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]):
pass

View File

@@ -1,10 +0,0 @@
identity:
author: Dify
name: agent_output
label:
en_US: Agent Output
description:
en_US: Internal tools for agent output control.
icon: icon.svg
tags:
- utilities

View File

@@ -1,17 +0,0 @@
from collections.abc import Generator
from typing import Any
from core.tools.builtin_tool.tool import BuiltinTool
from core.tools.entities.tool_entities import ToolInvokeMessage
class FinalOutputAnswerTool(BuiltinTool):
def _invoke(
self,
user_id: str,
tool_parameters: dict[str, Any],
conversation_id: str | None = None,
app_id: str | None = None,
message_id: str | None = None,
) -> Generator[ToolInvokeMessage, None, None]:
yield self.create_text_message("Final answer recorded.")

View File

@@ -1,18 +0,0 @@
identity:
name: final_output_answer
author: Dify
label:
en_US: Final Output Answer
description:
human:
en_US: Internal tool to deliver the final answer.
llm: Use this tool when you are ready to provide the final answer.
parameters:
- name: text
type: string
required: true
label:
en_US: Text
human_description:
en_US: Final answer text.
form: llm

View File

@@ -1,17 +0,0 @@
from collections.abc import Generator
from typing import Any
from core.tools.builtin_tool.tool import BuiltinTool
from core.tools.entities.tool_entities import ToolInvokeMessage
class FinalStructuredOutputTool(BuiltinTool):
def _invoke(
self,
user_id: str,
tool_parameters: dict[str, Any],
conversation_id: str | None = None,
app_id: str | None = None,
message_id: str | None = None,
) -> Generator[ToolInvokeMessage, None, None]:
yield self.create_text_message("Structured output recorded.")

View File

@@ -1,18 +0,0 @@
identity:
name: final_structured_output
author: Dify
label:
en_US: Final Structured Output
description:
human:
en_US: Internal tool to deliver structured output.
llm: Use this tool to provide structured output data.
parameters:
- name: data
type: object
required: true
label:
en_US: Data
human_description:
en_US: Structured output data.
form: llm

View File

@@ -1,21 +0,0 @@
from collections.abc import Generator
from typing import Any
from core.tools.builtin_tool.tool import BuiltinTool
from core.tools.entities.tool_entities import ToolInvokeMessage
class IllegalOutputTool(BuiltinTool):
def _invoke(
self,
user_id: str,
tool_parameters: dict[str, Any],
conversation_id: str | None = None,
app_id: str | None = None,
message_id: str | None = None,
) -> Generator[ToolInvokeMessage, None, None]:
message = (
"Protocol violation: do not output plain text. "
"Call an output tool and finish with the configured terminal tool."
)
yield self.create_text_message(message)

View File

@@ -1,18 +0,0 @@
identity:
name: illegal_output
author: Dify
label:
en_US: Illegal Output
description:
human:
en_US: Internal tool for output protocol violations.
llm: Use this tool to correct output protocol violations.
parameters:
- name: raw
type: string
required: true
label:
en_US: Raw Output
human_description:
en_US: Raw model output that violated the protocol.
form: llm

View File

@@ -1,17 +0,0 @@
from collections.abc import Generator
from typing import Any
from core.tools.builtin_tool.tool import BuiltinTool
from core.tools.entities.tool_entities import ToolInvokeMessage
class OutputTextTool(BuiltinTool):
def _invoke(
self,
user_id: str,
tool_parameters: dict[str, Any],
conversation_id: str | None = None,
app_id: str | None = None,
message_id: str | None = None,
) -> Generator[ToolInvokeMessage, None, None]:
yield self.create_text_message("Output recorded.")

View File

@@ -1,18 +0,0 @@
identity:
name: output_text
author: Dify
label:
en_US: Output Text
description:
human:
en_US: Internal tool to store intermediate text output.
llm: Use this tool to emit non-final text output.
parameters:
- name: text
type: string
required: true
label:
en_US: Text
human_description:
en_US: Output text.
form: llm

View File

@@ -1,5 +1,3 @@
from __future__ import annotations
import json
import logging
import mimetypes
@@ -33,8 +31,9 @@ from services.enterprise.plugin_manager_service import PluginCredentialType
from services.tools.mcp_tools_manage_service import MCPToolManageService
if TYPE_CHECKING:
from core.agent.entities import AgentToolEntity
from core.workflow.nodes.tool.entities import ToolEntity
from core.agent.entities import AgentToolEntity
from core.app.entities.app_invoke_entities import InvokeFrom
from core.helper.module_import_helper import load_single_subclass_from_source
from core.helper.position_helper import is_filtered
@@ -67,8 +66,6 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
INTERNAL_BUILTIN_TOOL_PROVIDERS = {"agent_output"}
class ApiProviderControllerItem(TypedDict):
provider: ApiToolProvider
@@ -364,7 +361,7 @@ class ToolManager:
app_id: str,
agent_tool: AgentToolEntity,
invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
variable_pool: Optional[VariablePool] = None,
variable_pool: Optional["VariablePool"] = None,
) -> Tool:
"""
get the agent tool runtime
@@ -404,9 +401,9 @@ class ToolManager:
tenant_id: str,
app_id: str,
node_id: str,
workflow_tool: ToolEntity,
workflow_tool: "ToolEntity",
invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
variable_pool: Optional[VariablePool] = None,
variable_pool: Optional["VariablePool"] = None,
) -> Tool:
"""
get the workflow tool runtime
@@ -594,10 +591,6 @@ class ToolManager:
cls._hardcoded_providers = {}
cls._builtin_providers_loaded = False
@classmethod
def is_internal_builtin_provider(cls, provider_name: str) -> bool:
return provider_name in INTERNAL_BUILTIN_TOOL_PROVIDERS
@classmethod
def get_tool_label(cls, tool_name: str) -> Union[I18nObject, None]:
"""
@@ -635,9 +628,9 @@ class ToolManager:
# MySQL: Use window function to achieve same result
sql = """
SELECT id FROM (
SELECT id,
ROW_NUMBER() OVER (
PARTITION BY tenant_id, provider
ORDER BY is_default DESC, created_at DESC
) as rn
FROM tool_builtin_providers
@@ -674,8 +667,6 @@ class ToolManager:
# append builtin providers
for provider in builtin_providers:
if cls.is_internal_builtin_provider(provider.entity.identity.name):
continue
# handle include, exclude
if is_filtered(
include_set=dify_config.POSITION_TOOL_INCLUDES_SET,
@@ -1019,7 +1010,7 @@ class ToolManager:
def _convert_tool_parameters_type(
cls,
parameters: list[ToolParameter],
variable_pool: Optional[VariablePool],
variable_pool: Optional["VariablePool"],
tool_configurations: dict[str, Any],
typ: Literal["agent", "workflow", "tool"] = "workflow",
) -> dict[str, Any]:

View File

@@ -577,12 +577,12 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
prompt_messages=prompt_messages,
stop=stop,
user_id=self.user_id,
structured_output_schema=None,
structured_output_enabled=self.node_data.structured_output_enabled,
structured_output=None,
file_saver=self._llm_file_saver,
file_outputs=self._file_outputs,
node_id=self._node_id,
node_type=self.node_type,
tenant_id=self.tenant_id,
)
for event in generator:

View File

@@ -7,7 +7,7 @@ from pydantic import BaseModel, ConfigDict, Field, field_serializer, field_valid
from core.agent.entities import AgentLog, AgentResult
from core.file import File
from core.model_runtime.entities import ImagePromptMessageContent, LLMMode
from core.model_runtime.entities.llm_entities import LLMStructuredOutput, LLMUsage
from core.model_runtime.entities.llm_entities import LLMUsage
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
from core.tools.entities.tool_entities import ToolProviderType
from core.workflow.entities import ToolCall, ToolCallResult
@@ -156,9 +156,6 @@ class LLMGenerationData(BaseModel):
finish_reason: str | None = Field(None, description="Finish reason from LLM")
files: list[File] = Field(default_factory=list, description="Generated files")
trace: list[LLMTraceSegment] = Field(default_factory=list, description="Streaming trace in emitted order")
structured_output: LLMStructuredOutput | None = Field(
default=None, description="Structured output from tool-only agent runs"
)
class ThinkTagStreamParser:
@@ -287,7 +284,6 @@ class AggregatedResult(BaseModel):
files: list[File] = Field(default_factory=list)
usage: LLMUsage = Field(default_factory=LLMUsage.empty_usage)
finish_reason: str | None = None
structured_output: LLMStructuredOutput | None = None
class AgentContext(BaseModel):
@@ -387,7 +383,7 @@ class LLMNodeData(BaseNodeData):
Strategy for handling model reasoning output.
separated: Return clean text (without <think> tags) + reasoning_content field.
Recommended for new workflows. Enables safe downstream parsing and
workflow variable access: {{#node_id.reasoning_content#}}
tagged : Return original text (with <think> tags) + reasoning_content field.

View File

@@ -257,8 +257,8 @@ def _build_file_descriptions(files: Sequence[Any]) -> str:
"""
Build a text description of generated files for inclusion in context.
The description includes file_id for context; structured output file paths
are only supported in sandbox mode.
The description includes file_id which can be used by subsequent nodes
to reference the files via structured output.
"""
if not files:
return ""

View File

@@ -13,24 +13,13 @@ from typing import TYPE_CHECKING, Any, Literal, cast
from sqlalchemy import select
from core.agent.entities import AgentEntity, AgentLog, AgentResult, AgentToolEntity, ExecutionContext
from core.agent.output_tools import (
FINAL_OUTPUT_TOOL,
FINAL_STRUCTURED_OUTPUT_TOOL,
OUTPUT_TEXT_TOOL,
build_agent_output_tools,
)
from core.agent.patterns import StrategyFactory
from core.app.entities.app_asset_entities import AppAssetFileTree
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.app_assets.constants import AppAssetsAttrs
from core.file import FileTransferMethod, FileType, file_manager
from core.file import File, FileTransferMethod, FileType, file_manager
from core.helper.code_executor import CodeExecutor, CodeLanguage
from core.llm_generator.output_parser.errors import OutputParserError
from core.llm_generator.output_parser.file_ref import (
adapt_schema_for_sandbox_file_paths,
convert_sandbox_file_paths_in_output,
detect_file_path_fields,
)
from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output
from core.memory.base import BaseMemory
from core.model_manager import ModelInstance, ModelManager
@@ -73,7 +62,6 @@ from core.skill.entities.skill_document import SkillDocument
from core.skill.entities.tool_dependencies import ToolDependencies, ToolDependency
from core.skill.skill_compiler import SkillCompiler
from core.tools.__base.tool import Tool
from core.tools.entities.tool_entities import ToolInvokeFrom
from core.tools.signature import sign_upload_file
from core.tools.tool_manager import ToolManager
from core.variables import (
@@ -197,11 +185,12 @@ class LLMNode(Node[LLMNodeData]):
def _run(self) -> Generator:
node_inputs: dict[str, Any] = {}
process_data: dict[str, Any] = {}
usage: LLMUsage = LLMUsage.empty_usage()
finish_reason: str | None = None
reasoning_content: str = "" # Initialize as empty string for consistency
clean_text = ""
usage = LLMUsage.empty_usage()
finish_reason = None
reasoning_content = "" # Initialize as empty string for consistency
clean_text = "" # Initialize clean_text to avoid UnboundLocalError
variable_pool: VariablePool = self.graph_runtime_state.variable_pool
variable_pool = self.graph_runtime_state.variable_pool
try:
# Parse prompt template to separate static messages and context references
@@ -261,9 +250,8 @@ class LLMNode(Node[LLMNodeData]):
)
query: str | None = None
memory_config = self.node_data.memory
if memory_config:
query = memory_config.query_prompt_template
if self.node_data.memory:
query = self.node_data.memory.query_prompt_template
if not query and (
query_variable := variable_pool.get((SYSTEM_VARIABLE_NODE_ID, SystemVariableKey.QUERY))
):
@@ -305,23 +293,9 @@ class LLMNode(Node[LLMNodeData]):
sandbox=self.graph_runtime_state.sandbox,
)
structured_output_schema: Mapping[str, Any] | None
structured_output_file_paths: list[str] = []
if self.node_data.structured_output_enabled:
if not self.node_data.structured_output:
raise ValueError("structured_output_enabled is True but structured_output is not set")
raw_schema = LLMNode.fetch_structured_output_schema(structured_output=self.node_data.structured_output)
if self.node_data.computer_use:
structured_output_schema, structured_output_file_paths = adapt_schema_for_sandbox_file_paths(
raw_schema
)
else:
if detect_file_path_fields(raw_schema):
raise LLMNodeError("Structured output file paths are only supported in sandbox mode.")
structured_output_schema = raw_schema
else:
structured_output_schema = None
# Variables for outputs
generation_data: LLMGenerationData | None = None
structured_output: LLMStructuredOutput | None = None
if self.node_data.computer_use:
sandbox = self.graph_runtime_state.sandbox
@@ -335,10 +309,7 @@ class LLMNode(Node[LLMNodeData]):
stop=stop,
variable_pool=variable_pool,
tool_dependencies=tool_dependencies,
structured_output_schema=structured_output_schema,
structured_output_file_paths=structured_output_file_paths,
)
elif self.tool_call_enabled:
generator = self._invoke_llm_with_tools(
model_instance=model_instance,
@@ -348,7 +319,6 @@ class LLMNode(Node[LLMNodeData]):
variable_pool=variable_pool,
node_inputs=node_inputs,
process_data=process_data,
structured_output_schema=structured_output_schema,
)
else:
# Use traditional LLM invocation
@@ -358,7 +328,8 @@ class LLMNode(Node[LLMNodeData]):
prompt_messages=prompt_messages,
stop=stop,
user_id=self.user_id,
structured_output_schema=structured_output_schema,
structured_output_enabled=self._node_data.structured_output_enabled,
structured_output=self._node_data.structured_output,
file_saver=self._llm_file_saver,
file_outputs=self._file_outputs,
node_id=self._node_id,
@@ -384,8 +355,6 @@ class LLMNode(Node[LLMNodeData]):
reasoning_content = ""
usage = generation_data.usage
finish_reason = generation_data.finish_reason
if generation_data.structured_output:
structured_output = generation_data.structured_output
# Unified process_data building
process_data = {
@@ -409,7 +378,16 @@ class LLMNode(Node[LLMNodeData]):
if tool.enabled
]
# Build generation field and determine files_to_output first
# Unified outputs building
outputs = {
"text": clean_text,
"reasoning_content": reasoning_content,
"usage": jsonable_encoder(usage),
"finish_reason": finish_reason,
"context": llm_utils.build_context(prompt_messages, clean_text, generation_data),
}
# Build generation field
if generation_data:
# Use generation_data from tool invocation (supports multi-turn)
generation = {
@@ -437,15 +415,6 @@ class LLMNode(Node[LLMNodeData]):
}
files_to_output = self._file_outputs
# Unified outputs building (files passed to context for subsequent node reference)
outputs = {
"text": clean_text,
"reasoning_content": reasoning_content,
"usage": jsonable_encoder(usage),
"finish_reason": finish_reason,
"context": llm_utils.build_context(prompt_messages, clean_text, generation_data, files=files_to_output),
}
outputs["generation"] = generation
if files_to_output:
outputs["files"] = ArrayFileSegment(value=files_to_output)
@@ -524,7 +493,8 @@ class LLMNode(Node[LLMNodeData]):
prompt_messages: Sequence[PromptMessage],
stop: Sequence[str] | None = None,
user_id: str,
structured_output_schema: Mapping[str, Any] | None,
structured_output_enabled: bool,
structured_output: Mapping[str, Any] | None = None,
file_saver: LLMFileSaver,
file_outputs: list[File],
node_id: str,
@@ -538,7 +508,10 @@ class LLMNode(Node[LLMNodeData]):
if not model_schema:
raise ValueError(f"Model schema not found for {node_data_model.name}")
if structured_output_schema:
if structured_output_enabled:
output_schema = LLMNode.fetch_structured_output_schema(
structured_output=structured_output or {},
)
request_start_time = time.perf_counter()
invoke_result = invoke_llm_with_structured_output(
@@ -546,12 +519,12 @@ class LLMNode(Node[LLMNodeData]):
model_schema=model_schema,
model_instance=model_instance,
prompt_messages=prompt_messages,
json_schema=structured_output_schema,
json_schema=output_schema,
model_parameters=node_data_model.completion_params,
stop=list(stop or []),
user=user_id,
tenant_id=tenant_id,
)
else:
request_start_time = time.perf_counter()
@@ -1288,16 +1261,18 @@ class LLMNode(Node[LLMNodeData]):
# Insert histories into the prompt
prompt_content = prompt_messages[0].content
# For issue #11247 - Check if prompt content is a string or a list
if isinstance(prompt_content, str):
prompt_content_type = type(prompt_content)
if prompt_content_type == str:
prompt_content = str(prompt_content)
if "#histories#" in prompt_content:
prompt_content = prompt_content.replace("#histories#", memory_text)
else:
prompt_content = memory_text + "\n" + prompt_content
prompt_messages[0].content = prompt_content
elif isinstance(prompt_content, list):
elif prompt_content_type == list:
prompt_content = prompt_content if isinstance(prompt_content, list) else []
for content_item in prompt_content:
if isinstance(content_item, TextPromptMessageContent):
if content_item.type == PromptMessageContentType.TEXT:
if "#histories#" in content_item.data:
content_item.data = content_item.data.replace("#histories#", memory_text)
else:
@@ -1307,12 +1282,13 @@ class LLMNode(Node[LLMNodeData]):
# Add current query to the prompt message
if sys_query:
if isinstance(prompt_content, str):
if prompt_content_type == str:
prompt_content = str(prompt_messages[0].content).replace("#sys.query#", sys_query)
prompt_messages[0].content = prompt_content
elif isinstance(prompt_content, list):
elif prompt_content_type == list:
prompt_content = prompt_content if isinstance(prompt_content, list) else []
for content_item in prompt_content:
if isinstance(content_item, TextPromptMessageContent):
if content_item.type == PromptMessageContentType.TEXT:
content_item.data = sys_query + "\n" + content_item.data
else:
raise ValueError("Invalid prompt content type")
@@ -1438,11 +1414,9 @@ class LLMNode(Node[LLMNodeData]):
if isinstance(item, PromptMessageContext):
if len(item.value_selector) >= 2:
prompt_context_selectors.append(item.value_selector)
elif isinstance(item, LLMNodeChatModelMessage):
elif isinstance(item, LLMNodeChatModelMessage) and item.edition_type == "jinja2":
variable_template_parser = VariableTemplateParser(template=item.text)
variable_selectors.extend(variable_template_parser.extract_variable_selectors())
else:
raise InvalidVariableTypeError(f"Invalid prompt template type: {type(prompt_template)}")
elif isinstance(prompt_template, LLMNodeCompletionModelPromptTemplate):
if prompt_template.edition_type != "jinja2":
variable_template_parser = VariableTemplateParser(template=prompt_template.text)
@@ -1478,14 +1452,13 @@ class LLMNode(Node[LLMNodeData]):
if typed_node_data.prompt_config:
enable_jinja = False
if isinstance(prompt_template, LLMNodeCompletionModelPromptTemplate):
if prompt_template.edition_type == "jinja2":
enable_jinja = True
else:
for prompt in prompt_template:
if prompt.edition_type == "jinja2":
if isinstance(prompt_template, list):
for item in prompt_template:
if isinstance(item, LLMNodeChatModelMessage) and item.edition_type == "jinja2":
enable_jinja = True
break
else:
enable_jinja = True
if enable_jinja:
for variable_selector in typed_node_data.prompt_config.jinja2_variables or []:
@@ -1898,7 +1871,6 @@ class LLMNode(Node[LLMNodeData]):
variable_pool: VariablePool,
node_inputs: dict[str, Any],
process_data: dict[str, Any],
structured_output_schema: Mapping[str, Any] | None,
) -> Generator[NodeEventBase, None, LLMGenerationData]:
"""Invoke LLM with tools support (from Agent V2).
@@ -1915,16 +1887,12 @@ class LLMNode(Node[LLMNodeData]):
# Use factory to create appropriate strategy
strategy = StrategyFactory.create_strategy(
tenant_id=self.tenant_id,
invoke_from=self.invoke_from,
tool_invoke_from=ToolInvokeFrom.WORKFLOW,
model_features=model_features,
model_instance=model_instance,
tools=tool_instances,
files=prompt_files,
max_iterations=self._node_data.max_iterations or 10,
context=ExecutionContext(user_id=self.user_id, app_id=self.app_id, tenant_id=self.tenant_id),
structured_output_schema=structured_output_schema,
)
# Run strategy
@@ -1932,6 +1900,7 @@ class LLMNode(Node[LLMNodeData]):
prompt_messages=list(prompt_messages),
model_parameters=self._node_data.model.completion_params,
stop=list(stop or []),
stream=True,
)
result = yield from self._process_tool_outputs(outputs)
@@ -1945,12 +1914,8 @@ class LLMNode(Node[LLMNodeData]):
stop: Sequence[str] | None,
variable_pool: VariablePool,
tool_dependencies: ToolDependencies | None,
structured_output_schema: Mapping[str, Any] | None,
structured_output_file_paths: Sequence[str] | None,
) -> Generator[NodeEventBase | LLMStructuredOutput, None, LLMGenerationData]:
) -> Generator[NodeEventBase, None, LLMGenerationData]:
result: LLMGenerationData | None = None
sandbox_output_files: list[File] = []
structured_output_files: list[File] = []
# FIXME(Mairuis): Async processing for bash session.
with SandboxBashSession(sandbox=sandbox, node_id=self.id, tools=tool_dependencies) as session:
@@ -1958,9 +1923,6 @@ class LLMNode(Node[LLMNodeData]):
model_features = self._get_model_features(model_instance)
strategy = StrategyFactory.create_strategy(
tenant_id=self.tenant_id,
invoke_from=self.invoke_from,
tool_invoke_from=ToolInvokeFrom.WORKFLOW,
model_features=model_features,
model_instance=model_instance,
tools=[session.bash_tool],
@@ -1968,55 +1930,20 @@ class LLMNode(Node[LLMNodeData]):
max_iterations=self._node_data.max_iterations or 100,
agent_strategy=AgentEntity.Strategy.FUNCTION_CALLING,
context=ExecutionContext(user_id=self.user_id, app_id=self.app_id, tenant_id=self.tenant_id),
structured_output_schema=structured_output_schema,
)
outputs = strategy.run(
prompt_messages=list(prompt_messages),
model_parameters=self._node_data.model.completion_params,
stop=list(stop or []),
stream=True,
)
result = yield from self._process_tool_outputs(outputs)
if result and result.structured_output and structured_output_file_paths:
structured_output_payload = result.structured_output.structured_output or {}
try:
converted_output, structured_output_files = convert_sandbox_file_paths_in_output(
output=structured_output_payload,
file_path_fields=structured_output_file_paths,
file_resolver=session.download_file,
)
except ValueError as exc:
raise LLMNodeError(str(exc)) from exc
result = result.model_copy(
update={"structured_output": LLMStructuredOutput(structured_output=converted_output)}
)
# Collect output files from the sandbox before the session ends
# Files are saved as ToolFiles with a valid tool_file_id for later reference
sandbox_output_files = session.collect_output_files()
if result is None:
raise LLMNodeError("SandboxSession exited unexpectedly")
structured_output = result.structured_output
if structured_output is not None:
yield structured_output
# Merge sandbox output files into result
if sandbox_output_files or structured_output_files:
result = LLMGenerationData(
text=result.text,
reasoning_contents=result.reasoning_contents,
tool_calls=result.tool_calls,
sequence=result.sequence,
usage=result.usage,
finish_reason=result.finish_reason,
files=result.files + sandbox_output_files + structured_output_files,
trace=result.trace,
)
return result
def _get_model_features(self, model_instance: ModelInstance) -> list[ModelFeature]:
@@ -2084,20 +2011,6 @@ class LLMNode(Node[LLMNodeData]):
logger.warning("Failed to load tool %s: %s", tool, str(e))
continue
structured_output_schema = None
if self._node_data.structured_output_enabled:
structured_output_schema = LLMNode.fetch_structured_output_schema(
structured_output=self._node_data.structured_output or {},
)
tool_instances.extend(
build_agent_output_tools(
tenant_id=self.tenant_id,
invoke_from=self.invoke_from,
tool_invoke_from=ToolInvokeFrom.WORKFLOW,
structured_output_schema=structured_output_schema,
)
)
return tool_instances
def _extract_prompt_files(self, variable_pool: VariablePool) -> list[File]:
@@ -2261,45 +2174,18 @@ class LLMNode(Node[LLMNodeData]):
# Add tool call to pending list for model segment
buffers.pending_tool_calls.append(ToolCall(id=tool_call_id, name=tool_name, arguments=tool_arguments))
output_tool_names = {OUTPUT_TEXT_TOOL, FINAL_OUTPUT_TOOL, FINAL_STRUCTURED_OUTPUT_TOOL}
if tool_name not in output_tool_names:
yield ToolCallChunkEvent(
selector=[self._node_id, "generation", "tool_calls"],
chunk=tool_arguments,
tool_call=ToolCall(
id=tool_call_id,
name=tool_name,
arguments=tool_arguments,
icon=tool_icon,
icon_dark=tool_icon_dark,
),
is_final=False,
)
if tool_name in output_tool_names:
content = ""
if tool_name in (OUTPUT_TEXT_TOOL, FINAL_OUTPUT_TOOL):
content = payload.tool_args["text"]
elif tool_name == FINAL_STRUCTURED_OUTPUT_TOOL:
raw_content = json.dumps(
payload.tool_args["data"],
ensure_ascii=False,
indent=2
)
content = f"```json\n{raw_content}\n```"
if content:
yield StreamChunkEvent(
selector=[self._node_id, "text"],
chunk=content,
is_final=False,
)
yield StreamChunkEvent(
selector=[self._node_id, "generation", "content"],
chunk=content,
is_final=False,
)
yield ToolCallChunkEvent(
selector=[self._node_id, "generation", "tool_calls"],
chunk=tool_arguments,
tool_call=ToolCall(
id=tool_call_id,
name=tool_name,
arguments=tool_arguments,
icon=tool_icon,
icon_dark=tool_icon_dark,
),
is_final=False,
)
if output.log_type == AgentLog.LogType.TOOL_CALL and output.status != AgentLog.LogStatus.START:
tool_name = payload.tool_name
@@ -2543,7 +2429,6 @@ class LLMNode(Node[LLMNodeData]):
content_position = 0
tool_call_seen_index: dict[str, int] = {}
for trace_segment in trace_state.trace_segments:
# FIXME: These if-branches will never be hit
if trace_segment.type == "thought":
sequence.append({"type": "reasoning", "index": reasoning_index})
reasoning_index += 1
@@ -2586,15 +2471,8 @@ class LLMNode(Node[LLMNodeData]):
key=lambda item: trace_state.tool_call_index_map.get(item.id or "", len(trace_state.tool_call_index_map))
)
text_content: str
if aggregate.text:
text_content = aggregate.text
elif aggregate.structured_output:
text_content = json.dumps(aggregate.structured_output.structured_output)
else:
raise ValueError("Aggregate must have either text or structured output.")
return LLMGenerationData(
text=text_content,
text=aggregate.text,
reasoning_contents=buffers.reasoning_per_turn,
tool_calls=tool_calls_for_generation,
sequence=sequence,
@@ -2602,7 +2480,6 @@ class LLMNode(Node[LLMNodeData]):
finish_reason=aggregate.finish_reason,
files=aggregate.files,
trace=trace_state.trace_segments,
structured_output=aggregate.structured_output,
)
def _process_tool_outputs(
@@ -2613,33 +2490,22 @@ class LLMNode(Node[LLMNodeData]):
state = ToolOutputState()
try:
while True:
output = next(outputs)
for output in outputs:
if isinstance(output, AgentLog):
yield from self._handle_agent_log_output(output, state.stream, state.trace, state.agent)
else:
continue
yield from self._handle_llm_chunk_output(output, state.stream, state.trace, state.aggregate)
except StopIteration as exception:
if not isinstance(exception.value, AgentResult):
raise ValueError(f"Unexpected output type: {type(exception.value)}") from exception
state.agent.agent_result = exception.value
agent_result = state.agent.agent_result
if not agent_result:
raise ValueError("No agent result found in tool outputs")
output_payload = agent_result.output
if isinstance(output_payload, dict):
state.aggregate.structured_output = LLMStructuredOutput(structured_output=output_payload)
state.aggregate.text = json.dumps(output_payload)
elif isinstance(output_payload, str):
state.aggregate.text = output_payload
else:
raise ValueError(f"Unexpected output type: {type(output_payload)}")
if isinstance(getattr(exception, "value", None), AgentResult):
state.agent.agent_result = exception.value
state.aggregate.files = state.agent.agent_result.files
if state.agent.agent_result.usage:
state.aggregate.usage = state.agent.agent_result.usage
if state.agent.agent_result.finish_reason:
state.aggregate.finish_reason = state.agent.agent_result.finish_reason
if state.agent.agent_result:
state.aggregate.text = state.agent.agent_result.text or state.aggregate.text
state.aggregate.files = state.agent.agent_result.files
if state.agent.agent_result.usage:
state.aggregate.usage = state.agent.agent_result.usage
if state.agent.agent_result.finish_reason:
state.aggregate.finish_reason = state.agent.agent_result.finish_reason
yield from self._flush_remaining_stream(state.stream, state.trace, state.aggregate)
yield from self._close_streams()

View File

@@ -156,7 +156,8 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
prompt_messages=prompt_messages,
stop=stop,
user_id=self.user_id,
structured_output_schema=None,
structured_output_enabled=False,
structured_output=None,
file_saver=self._llm_file_saver,
file_outputs=self._file_outputs,
node_id=self._node_id,

View File

@@ -91,8 +91,6 @@ class BuiltinToolManageService:
:return: the list of tools
"""
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
if ToolManager.is_internal_builtin_provider(provider_controller.entity.identity.name):
return []
tools = provider_controller.get_tools()
result: list[ToolApiEntity] = []
@@ -543,8 +541,6 @@ class BuiltinToolManageService:
for provider_controller in provider_controllers:
try:
if ToolManager.is_internal_builtin_provider(provider_controller.entity.identity.name):
continue
# handle include, exclude
if is_filtered(
include_set=dify_config.POSITION_TOOL_INCLUDES_SET,

View File

@@ -126,8 +126,9 @@ workflow:
additionalProperties: false
properties:
image:
description: Sandbox file path of the selected image
type: file
description: File ID (UUID) of the selected image
format: dify-file-ref
type: string
required:
- image
type: object

View File

@@ -1,3 +1,4 @@
import json
from collections.abc import Generator
from core.agent.entities import AgentScratchpadUnit
@@ -63,7 +64,7 @@ def test_cot_output_parser():
output += result
elif isinstance(result, AgentScratchpadUnit.Action):
if test_case["action"]:
assert result.model_dump() == test_case["action"]
output += result.model_dump_json()
assert result.to_dict() == test_case["action"]
output += json.dumps(result.to_dict())
if test_case["output"]:
assert output == test_case["output"]

View File

@@ -13,7 +13,7 @@ from core.model_runtime.entities.llm_entities import LLMUsage
class ConcreteAgentPattern(AgentPattern):
"""Concrete implementation of AgentPattern for testing."""
def run(self, prompt_messages, model_parameters, stop=[]):
def run(self, prompt_messages, model_parameters, stop=[], stream=True):
"""Minimal implementation for testing."""
yield from []

View File

@@ -329,15 +329,13 @@ class TestAgentLogProcessing:
)
result = AgentResult(
output="Final answer",
text="Final answer",
files=[],
usage=usage,
finish_reason="stop",
)
output_payload = result.output
assert isinstance(output_payload, str)
assert output_payload == "Final answer"
assert result.text == "Final answer"
assert result.files == []
assert result.usage == usage
assert result.finish_reason == "stop"

View File

@@ -153,7 +153,7 @@ class TestAgentScratchpadUnit:
action_input={"query": "test"},
)
result = action.model_dump()
result = action.to_dict()
assert result == {
"action": "search",

View File

@@ -1,93 +1,269 @@
"""
Unit tests for sandbox file path detection and conversion.
Unit tests for file reference detection and conversion.
"""
import uuid
from unittest.mock import MagicMock, patch
import pytest
from core.file import File, FileTransferMethod, FileType
from core.llm_generator.output_parser.file_ref import (
FILE_PATH_DESCRIPTION_SUFFIX,
adapt_schema_for_sandbox_file_paths,
convert_sandbox_file_paths_in_output,
detect_file_path_fields,
is_file_path_property,
FILE_REF_FORMAT,
convert_file_refs_in_output,
detect_file_ref_fields,
is_file_ref_property,
)
from core.variables.segments import ArrayFileSegment, FileSegment
def _build_file(file_id: str) -> File:
return File(
id=file_id,
tenant_id="tenant_123",
type=FileType.IMAGE,
transfer_method=FileTransferMethod.TOOL_FILE,
filename="test.png",
extension=".png",
mime_type="image/png",
size=128,
related_id=file_id,
storage_key="sandbox/path",
)
class TestIsFileRefProperty:
"""Tests for is_file_ref_property function."""
def test_valid_file_ref(self):
schema = {"type": "string", "format": FILE_REF_FORMAT}
assert is_file_ref_property(schema) is True
def test_invalid_type(self):
schema = {"type": "number", "format": FILE_REF_FORMAT}
assert is_file_ref_property(schema) is False
def test_missing_format(self):
schema = {"type": "string"}
assert is_file_ref_property(schema) is False
def test_wrong_format(self):
schema = {"type": "string", "format": "uuid"}
assert is_file_ref_property(schema) is False
class TestFilePathSchema:
def test_is_file_path_property(self):
assert is_file_path_property({"type": "file"}) is True
assert is_file_path_property({"type": "string", "format": "dify-file-ref"}) is True
assert is_file_path_property({"type": "string"}) is False
class TestDetectFileRefFields:
"""Tests for detect_file_ref_fields function."""
def test_detect_file_path_fields(self):
def test_simple_file_ref(self):
schema = {
"type": "object",
"properties": {
"image": {"type": "string", "format": "dify-file-ref"},
"files": {"type": "array", "items": {"type": "string", "format": "dify-file-ref"}},
"meta": {"type": "object", "properties": {"doc": {"type": "file"}}},
"image": {"type": "string", "format": FILE_REF_FORMAT},
},
}
assert set(detect_file_path_fields(schema)) == {"image", "files[*]", "meta.doc"}
paths = detect_file_ref_fields(schema)
assert paths == ["image"]
def test_adapt_schema_for_sandbox_file_paths(self):
def test_multiple_file_refs(self):
schema = {
"type": "object",
"properties": {
"image": {"type": "string", "format": "dify-file-ref"},
"image": {"type": "string", "format": FILE_REF_FORMAT},
"document": {"type": "string", "format": FILE_REF_FORMAT},
"name": {"type": "string"},
},
}
adapted, fields = adapt_schema_for_sandbox_file_paths(schema)
paths = detect_file_ref_fields(schema)
assert set(paths) == {"image", "document"}
assert set(fields) == {"image"}
adapted_image = adapted["properties"]["image"]
assert adapted_image["type"] == "string"
assert "format" not in adapted_image
assert FILE_PATH_DESCRIPTION_SUFFIX in adapted_image["description"]
def test_array_of_file_refs(self):
schema = {
"type": "object",
"properties": {
"files": {
"type": "array",
"items": {"type": "string", "format": FILE_REF_FORMAT},
},
},
}
paths = detect_file_ref_fields(schema)
assert paths == ["files[*]"]
def test_nested_file_ref(self):
schema = {
"type": "object",
"properties": {
"data": {
"type": "object",
"properties": {
"image": {"type": "string", "format": FILE_REF_FORMAT},
},
},
},
}
paths = detect_file_ref_fields(schema)
assert paths == ["data.image"]
def test_no_file_refs(self):
schema = {
"type": "object",
"properties": {
"name": {"type": "string"},
"count": {"type": "number"},
},
}
paths = detect_file_ref_fields(schema)
assert paths == []
def test_empty_schema(self):
schema = {}
paths = detect_file_ref_fields(schema)
assert paths == []
def test_mixed_schema(self):
schema = {
"type": "object",
"properties": {
"query": {"type": "string"},
"image": {"type": "string", "format": FILE_REF_FORMAT},
"documents": {
"type": "array",
"items": {"type": "string", "format": FILE_REF_FORMAT},
},
},
}
paths = detect_file_ref_fields(schema)
assert set(paths) == {"image", "documents[*]"}
class TestConvertSandboxFilePaths:
def test_convert_sandbox_file_paths(self):
output = {"image": "a.png", "files": ["b.png", "c.png"], "name": "demo"}
class TestConvertFileRefsInOutput:
"""Tests for convert_file_refs_in_output function."""
def resolver(path: str) -> File:
return _build_file(path)
@pytest.fixture
def mock_file(self):
"""Create a mock File object with all required attributes."""
file = MagicMock(spec=File)
file.type = FileType.IMAGE
file.transfer_method = FileTransferMethod.TOOL_FILE
file.related_id = "test-related-id"
file.remote_url = None
file.tenant_id = "tenant_123"
file.id = None
file.filename = "test.png"
file.extension = ".png"
file.mime_type = "image/png"
file.size = 1024
file.dify_model_identity = "__dify__file__"
return file
converted, files = convert_sandbox_file_paths_in_output(output, ["image", "files[*]"], resolver)
@pytest.fixture
def mock_build_from_mapping(self, mock_file):
"""Mock the build_from_mapping function."""
with patch("core.llm_generator.output_parser.file_ref.build_from_mapping") as mock:
mock.return_value = mock_file
yield mock
assert isinstance(converted["image"], FileSegment)
assert isinstance(converted["files"], ArrayFileSegment)
assert converted["name"] == "demo"
assert [file.id for file in files] == ["a.png", "b.png", "c.png"]
def test_convert_simple_file_ref(self, mock_build_from_mapping, mock_file):
file_id = str(uuid.uuid4())
output = {"image": file_id}
schema = {
"type": "object",
"properties": {
"image": {"type": "string", "format": FILE_REF_FORMAT},
},
}
def test_invalid_path_value_raises(self):
def resolver(path: str) -> File:
return _build_file(path)
result = convert_file_refs_in_output(output, schema, "tenant_123")
with pytest.raises(ValueError):
convert_sandbox_file_paths_in_output({"image": 123}, ["image"], resolver)
# Result should be wrapped in FileSegment
assert isinstance(result["image"], FileSegment)
assert result["image"].value == mock_file
mock_build_from_mapping.assert_called_once_with(
mapping={"transfer_method": "tool_file", "tool_file_id": file_id},
tenant_id="tenant_123",
)
def test_no_file_paths_returns_output(self):
output = {"name": "demo"}
converted, files = convert_sandbox_file_paths_in_output(output, [], _build_file)
def test_convert_array_of_file_refs(self, mock_build_from_mapping, mock_file):
file_id1 = str(uuid.uuid4())
file_id2 = str(uuid.uuid4())
output = {"files": [file_id1, file_id2]}
schema = {
"type": "object",
"properties": {
"files": {
"type": "array",
"items": {"type": "string", "format": FILE_REF_FORMAT},
},
},
}
assert converted == output
assert files == []
result = convert_file_refs_in_output(output, schema, "tenant_123")
# Result should be wrapped in ArrayFileSegment
assert isinstance(result["files"], ArrayFileSegment)
assert list(result["files"].value) == [mock_file, mock_file]
assert mock_build_from_mapping.call_count == 2
def test_no_conversion_without_file_refs(self):
output = {"name": "test", "count": 5}
schema = {
"type": "object",
"properties": {
"name": {"type": "string"},
"count": {"type": "number"},
},
}
result = convert_file_refs_in_output(output, schema, "tenant_123")
assert result == {"name": "test", "count": 5}
def test_invalid_uuid_returns_none(self):
output = {"image": "not-a-valid-uuid"}
schema = {
"type": "object",
"properties": {
"image": {"type": "string", "format": FILE_REF_FORMAT},
},
}
result = convert_file_refs_in_output(output, schema, "tenant_123")
assert result["image"] is None
def test_file_not_found_returns_none(self):
file_id = str(uuid.uuid4())
output = {"image": file_id}
schema = {
"type": "object",
"properties": {
"image": {"type": "string", "format": FILE_REF_FORMAT},
},
}
with patch("core.llm_generator.output_parser.file_ref.build_from_mapping") as mock:
mock.side_effect = ValueError("File not found")
result = convert_file_refs_in_output(output, schema, "tenant_123")
assert result["image"] is None
def test_preserves_non_file_fields(self, mock_build_from_mapping, mock_file):
file_id = str(uuid.uuid4())
output = {"query": "search term", "image": file_id, "count": 10}
schema = {
"type": "object",
"properties": {
"query": {"type": "string"},
"image": {"type": "string", "format": FILE_REF_FORMAT},
"count": {"type": "number"},
},
}
result = convert_file_refs_in_output(output, schema, "tenant_123")
assert result["query"] == "search term"
assert isinstance(result["image"], FileSegment)
assert result["image"].value == mock_file
assert result["count"] == 10
def test_does_not_modify_original_output(self, mock_build_from_mapping, mock_file):
file_id = str(uuid.uuid4())
original = {"image": file_id}
output = dict(original)
schema = {
"type": "object",
"properties": {
"image": {"type": "string", "format": FILE_REF_FORMAT},
},
}
convert_file_refs_in_output(output, schema, "tenant_123")
# Original should still contain the string ID
assert original["image"] == file_id

View File

@@ -1,6 +1,4 @@
from collections.abc import Mapping
from decimal import Decimal
from typing import Any, NotRequired, TypedDict
from unittest.mock import MagicMock, patch
import pytest
@@ -8,14 +6,16 @@ from pydantic import BaseModel, ConfigDict
from core.llm_generator.output_parser.errors import OutputParserError
from core.llm_generator.output_parser.structured_output import (
_get_default_value_for_type,
fill_defaults_from_schema,
get_default_value_for_type,
invoke_llm_with_pydantic_model,
invoke_llm_with_structured_output,
)
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.llm_entities import (
LLMResult,
LLMResultChunk,
LLMResultChunkDelta,
LLMResultChunkWithStructuredOutput,
LLMResultWithStructuredOutput,
LLMUsage,
)
@@ -25,29 +25,7 @@ from core.model_runtime.entities.message_entities import (
TextPromptMessageContent,
UserPromptMessage,
)
from core.model_runtime.entities.model_entities import (
AIModelEntity,
ModelType,
ParameterRule,
ParameterType,
)
class StructuredOutputTestCase(TypedDict):
name: str
provider: str
model_name: str
support_structure_output: bool
stream: bool
json_schema: Mapping[str, Any]
expected_llm_response: LLMResult
expected_result_type: type[LLMResultWithStructuredOutput] | None
should_raise: bool
expected_error: NotRequired[type[OutputParserError]]
parameter_rules: NotRequired[list[ParameterRule]]
SchemaData = dict[str, Any]
from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
def create_mock_usage(prompt_tokens: int = 10, completion_tokens: int = 5) -> LLMUsage:
@@ -93,7 +71,7 @@ def get_model_instance() -> MagicMock:
def test_structured_output_parser():
"""Test cases for invoke_llm_with_structured_output function"""
testcases: list[StructuredOutputTestCase] = [
testcases = [
# Test case 1: Model with native structured output support, non-streaming
{
"name": "native_structured_output_non_streaming",
@@ -110,6 +88,39 @@ def test_structured_output_parser():
"expected_result_type": LLMResultWithStructuredOutput,
"should_raise": False,
},
# Test case 2: Model with native structured output support, streaming
{
"name": "native_structured_output_streaming",
"provider": "openai",
"model_name": "gpt-4o",
"support_structure_output": True,
"stream": True,
"json_schema": {"type": "object", "properties": {"name": {"type": "string"}}},
"expected_llm_response": [
LLMResultChunk(
model="gpt-4o",
prompt_messages=[UserPromptMessage(content="test")],
system_fingerprint="test",
delta=LLMResultChunkDelta(
index=0,
message=AssistantPromptMessage(content='{"name":'),
usage=create_mock_usage(prompt_tokens=10, completion_tokens=2),
),
),
LLMResultChunk(
model="gpt-4o",
prompt_messages=[UserPromptMessage(content="test")],
system_fingerprint="test",
delta=LLMResultChunkDelta(
index=0,
message=AssistantPromptMessage(content=' "test"}'),
usage=create_mock_usage(prompt_tokens=10, completion_tokens=3),
),
),
],
"expected_result_type": "generator",
"should_raise": False,
},
# Test case 3: Model without native structured output support, non-streaming
{
"name": "prompt_based_structured_output_non_streaming",
@@ -126,24 +137,78 @@ def test_structured_output_parser():
"expected_result_type": LLMResultWithStructuredOutput,
"should_raise": False,
},
# Test case 4: Model without native structured output support, streaming
{
"name": "non_streaming_with_list_content",
"name": "prompt_based_structured_output_streaming",
"provider": "anthropic",
"model_name": "claude-3-sonnet",
"support_structure_output": False,
"stream": True,
"json_schema": {"type": "object", "properties": {"answer": {"type": "string"}}},
"expected_llm_response": [
LLMResultChunk(
model="claude-3-sonnet",
prompt_messages=[UserPromptMessage(content="test")],
system_fingerprint="test",
delta=LLMResultChunkDelta(
index=0,
message=AssistantPromptMessage(content='{"answer": "test'),
usage=create_mock_usage(prompt_tokens=15, completion_tokens=3),
),
),
LLMResultChunk(
model="claude-3-sonnet",
prompt_messages=[UserPromptMessage(content="test")],
system_fingerprint="test",
delta=LLMResultChunkDelta(
index=0,
message=AssistantPromptMessage(content=' response"}'),
usage=create_mock_usage(prompt_tokens=15, completion_tokens=5),
),
),
],
"expected_result_type": "generator",
"should_raise": False,
},
# Test case 5: Streaming with list content
{
"name": "streaming_with_list_content",
"provider": "openai",
"model_name": "gpt-4o",
"support_structure_output": True,
"stream": False,
"stream": True,
"json_schema": {"type": "object", "properties": {"data": {"type": "string"}}},
"expected_llm_response": LLMResult(
model="gpt-4o",
message=AssistantPromptMessage(
content=[
TextPromptMessageContent(data='{"data":'),
TextPromptMessageContent(data=' "value"}'),
]
"expected_llm_response": [
LLMResultChunk(
model="gpt-4o",
prompt_messages=[UserPromptMessage(content="test")],
system_fingerprint="test",
delta=LLMResultChunkDelta(
index=0,
message=AssistantPromptMessage(
content=[
TextPromptMessageContent(data='{"data":'),
]
),
usage=create_mock_usage(prompt_tokens=10, completion_tokens=2),
),
),
usage=create_mock_usage(prompt_tokens=10, completion_tokens=5),
),
"expected_result_type": LLMResultWithStructuredOutput,
LLMResultChunk(
model="gpt-4o",
prompt_messages=[UserPromptMessage(content="test")],
system_fingerprint="test",
delta=LLMResultChunkDelta(
index=0,
message=AssistantPromptMessage(
content=[
TextPromptMessageContent(data=' "value"}'),
]
),
usage=create_mock_usage(prompt_tokens=10, completion_tokens=3),
),
),
],
"expected_result_type": "generator",
"should_raise": False,
},
# Test case 6: Error case - non-string LLM response content (non-streaming)
@@ -188,13 +253,7 @@ def test_structured_output_parser():
"stream": False,
"json_schema": {"type": "object", "properties": {"result": {"type": "string"}}},
"parameter_rules": [
ParameterRule(
name="response_format",
label=I18nObject(en_US="response_format"),
type=ParameterType.STRING,
required=False,
options=["json_schema"],
),
MagicMock(name="response_format", options=["json_schema"], required=False),
],
"expected_llm_response": LLMResult(
model="gpt-4o",
@@ -213,13 +272,7 @@ def test_structured_output_parser():
"stream": False,
"json_schema": {"type": "object", "properties": {"output": {"type": "string"}}},
"parameter_rules": [
ParameterRule(
name="response_format",
label=I18nObject(en_US="response_format"),
type=ParameterType.STRING,
required=False,
options=["JSON"],
),
MagicMock(name="response_format", options=["JSON"], required=False),
],
"expected_llm_response": LLMResult(
model="claude-3-sonnet",
@@ -232,72 +285,89 @@ def test_structured_output_parser():
]
for case in testcases:
provider = case["provider"]
model_name = case["model_name"]
support_structure_output = case["support_structure_output"]
json_schema = case["json_schema"]
stream = case["stream"]
# Setup model entity
model_schema = get_model_entity(case["provider"], case["model_name"], case["support_structure_output"])
model_schema = get_model_entity(provider, model_name, support_structure_output)
parameter_rules = case.get("parameter_rules")
if parameter_rules is not None:
model_schema.parameter_rules = parameter_rules
# Add parameter rules if specified
if "parameter_rules" in case:
model_schema.parameter_rules = case["parameter_rules"]
# Setup model instance
model_instance = get_model_instance()
model_instance.invoke_llm.return_value = case["expected_llm_response"]
# Setup prompt messages
prompt_messages = [
SystemPromptMessage(content="You are a helpful assistant."),
UserPromptMessage(content="Generate a response according to the schema."),
]
if case["should_raise"]:
expected_error = case.get("expected_error", OutputParserError)
with pytest.raises(expected_error): # noqa: PT012
if stream:
# Test error cases
with pytest.raises(case["expected_error"]): # noqa: PT012
if case["stream"]:
result_generator = invoke_llm_with_structured_output(
provider=provider,
provider=case["provider"],
model_schema=model_schema,
model_instance=model_instance,
prompt_messages=prompt_messages,
json_schema=json_schema,
json_schema=case["json_schema"],
)
# Consume the generator to trigger the error
list(result_generator)
else:
invoke_llm_with_structured_output(
provider=provider,
provider=case["provider"],
model_schema=model_schema,
model_instance=model_instance,
prompt_messages=prompt_messages,
json_schema=json_schema,
json_schema=case["json_schema"],
)
else:
# Test successful cases
with patch("core.llm_generator.output_parser.structured_output.json_repair.loads") as mock_json_repair:
# Configure json_repair mock for cases that need it
if case["name"] == "json_repair_scenario":
mock_json_repair.return_value = {"name": "test"}
result = invoke_llm_with_structured_output(
provider=provider,
provider=case["provider"],
model_schema=model_schema,
model_instance=model_instance,
prompt_messages=prompt_messages,
json_schema=json_schema,
json_schema=case["json_schema"],
model_parameters={"temperature": 0.7, "max_tokens": 100},
user="test_user",
)
expected_result_type = case["expected_result_type"]
assert expected_result_type is not None
assert isinstance(result, expected_result_type)
assert result.model == model_name
assert result.structured_output is not None
assert isinstance(result.structured_output, dict)
if case["expected_result_type"] == "generator":
# Test streaming results
assert hasattr(result, "__iter__")
chunks = list(result)
assert len(chunks) > 0
# Verify all chunks are LLMResultChunkWithStructuredOutput
for chunk in chunks[:-1]: # All except last
assert isinstance(chunk, LLMResultChunkWithStructuredOutput)
assert chunk.model == case["model_name"]
# Last chunk should have structured output
last_chunk = chunks[-1]
assert isinstance(last_chunk, LLMResultChunkWithStructuredOutput)
assert last_chunk.structured_output is not None
assert isinstance(last_chunk.structured_output, dict)
else:
# Test non-streaming results
assert isinstance(result, case["expected_result_type"])
assert result.model == case["model_name"]
assert result.structured_output is not None
assert isinstance(result.structured_output, dict)
# Verify model_instance.invoke_llm was called with correct parameters
model_instance.invoke_llm.assert_called_once()
call_args = model_instance.invoke_llm.call_args
assert call_args.kwargs["stream"] == stream
assert call_args.kwargs["stream"] == case["stream"]
assert call_args.kwargs["user"] == "test_user"
assert "temperature" in call_args.kwargs["model_parameters"]
assert "max_tokens" in call_args.kwargs["model_parameters"]
@@ -306,32 +376,45 @@ def test_structured_output_parser():
def test_parse_structured_output_edge_cases():
"""Test edge cases for structured output parsing"""
provider = "deepseek"
model_name = "deepseek-r1"
support_structure_output = False
json_schema: SchemaData = {"type": "object", "properties": {"thought": {"type": "string"}}}
expected_llm_response = LLMResult(
model="deepseek-r1",
message=AssistantPromptMessage(content='[{"thought": "reasoning process"}, "other content"]'),
usage=create_mock_usage(prompt_tokens=10, completion_tokens=5),
# Test case with a list that contains a dict (reasoning-model scenario)
testcase_list_with_dict = {
"name": "list_with_dict_parsing",
"provider": "deepseek",
"model_name": "deepseek-r1",
"support_structure_output": False,
"stream": False,
"json_schema": {"type": "object", "properties": {"thought": {"type": "string"}}},
"expected_llm_response": LLMResult(
model="deepseek-r1",
message=AssistantPromptMessage(content='[{"thought": "reasoning process"}, "other content"]'),
usage=create_mock_usage(prompt_tokens=10, completion_tokens=5),
),
"expected_result_type": LLMResultWithStructuredOutput,
"should_raise": False,
}
# Set up the list-parsing test
model_schema = get_model_entity(
testcase_list_with_dict["provider"],
testcase_list_with_dict["model_name"],
testcase_list_with_dict["support_structure_output"],
)
model_schema = get_model_entity(provider, model_name, support_structure_output)
model_instance = get_model_instance()
model_instance.invoke_llm.return_value = expected_llm_response
model_instance.invoke_llm.return_value = testcase_list_with_dict["expected_llm_response"]
prompt_messages = [UserPromptMessage(content="Test reasoning")]
with patch("core.llm_generator.output_parser.structured_output.json_repair.loads") as mock_json_repair:
# Mock json_repair to return a list containing a dict
mock_json_repair.return_value = [{"thought": "reasoning process"}, "other content"]
result = invoke_llm_with_structured_output(
provider=provider,
provider=testcase_list_with_dict["provider"],
model_schema=model_schema,
model_instance=model_instance,
prompt_messages=prompt_messages,
json_schema=json_schema,
json_schema=testcase_list_with_dict["json_schema"],
)
assert isinstance(result, LLMResultWithStructuredOutput)
@@ -341,16 +424,18 @@ def test_parse_structured_output_edge_cases():
def test_model_specific_schema_preparation():
"""Test schema preparation for different model types"""
provider = "google"
model_name = "gemini-pro"
support_structure_output = True
json_schema: SchemaData = {
"type": "object",
"properties": {"result": {"type": "boolean"}},
"additionalProperties": False,
# Test Gemini model
gemini_case = {
"provider": "google",
"model_name": "gemini-pro",
"support_structure_output": True,
"stream": False,
"json_schema": {"type": "object", "properties": {"result": {"type": "boolean"}}, "additionalProperties": False},
}
model_schema = get_model_entity(provider, model_name, support_structure_output)
model_schema = get_model_entity(
gemini_case["provider"], gemini_case["model_name"], gemini_case["support_structure_output"]
)
model_instance = get_model_instance()
model_instance.invoke_llm.return_value = LLMResult(
@@ -362,11 +447,11 @@ def test_model_specific_schema_preparation():
prompt_messages = [UserPromptMessage(content="Test")]
result = invoke_llm_with_structured_output(
provider=provider,
provider=gemini_case["provider"],
model_schema=model_schema,
model_instance=model_instance,
prompt_messages=prompt_messages,
json_schema=json_schema,
json_schema=gemini_case["json_schema"],
)
assert isinstance(result, LLMResultWithStructuredOutput)
@@ -408,26 +493,40 @@ def test_structured_output_with_pydantic_model_non_streaming():
assert result.name == "test"
def test_structured_output_with_pydantic_model_list_content():
def test_structured_output_with_pydantic_model_streaming():
model_schema = get_model_entity("openai", "gpt-4o", support_structure_output=True)
model_instance = get_model_instance()
model_instance.invoke_llm.return_value = LLMResult(
model="gpt-4o",
message=AssistantPromptMessage(
content=[
TextPromptMessageContent(data='{"name":'),
TextPromptMessageContent(data=' "test"}'),
]
),
usage=create_mock_usage(prompt_tokens=8, completion_tokens=4),
)
def mock_streaming_response():
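# Simulated stream: the JSON payload arrives split across two chunk deltas.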
yield LLMResultChunk(
model="gpt-4o",
prompt_messages=[UserPromptMessage(content="test")],
system_fingerprint="test",
delta=LLMResultChunkDelta(
index=0,
message=AssistantPromptMessage(content='{"name":'),
usage=create_mock_usage(prompt_tokens=8, completion_tokens=2),
),
)
yield LLMResultChunk(
model="gpt-4o",
prompt_messages=[UserPromptMessage(content="test")],
system_fingerprint="test",
delta=LLMResultChunkDelta(
index=0,
message=AssistantPromptMessage(content=' "test"}'),
usage=create_mock_usage(prompt_tokens=8, completion_tokens=4),
),
)
model_instance.invoke_llm.return_value = mock_streaming_response()
result = invoke_llm_with_pydantic_model(
provider="openai",
model_schema=model_schema,
model_instance=model_instance,
prompt_messages=[UserPromptMessage(content="Return a JSON object with name.")],
output_model=ExampleOutput,
output_model=ExampleOutput
)
assert isinstance(result, ExampleOutput)
@@ -449,51 +548,51 @@ def test_structured_output_with_pydantic_model_validation_error():
model_schema=model_schema,
model_instance=model_instance,
prompt_messages=[UserPromptMessage(content="test")],
output_model=ExampleOutput,
output_model=ExampleOutput
)
class TestGetDefaultValueForType:
"""Test cases for get_default_value_for_type function"""
"""Test cases for _get_default_value_for_type function"""
def test_string_type(self):
assert get_default_value_for_type("string") == ""
assert _get_default_value_for_type("string") == ""
def test_object_type(self):
assert get_default_value_for_type("object") == {}
assert _get_default_value_for_type("object") == {}
def test_array_type(self):
assert get_default_value_for_type("array") == []
assert _get_default_value_for_type("array") == []
def test_number_type(self):
assert get_default_value_for_type("number") == 0
assert _get_default_value_for_type("number") == 0
def test_integer_type(self):
assert get_default_value_for_type("integer") == 0
assert _get_default_value_for_type("integer") == 0
def test_boolean_type(self):
assert get_default_value_for_type("boolean") is False
assert _get_default_value_for_type("boolean") is False
def test_null_type(self):
assert get_default_value_for_type("null") is None
assert _get_default_value_for_type("null") is None
def test_none_type(self):
assert get_default_value_for_type(None) is None
assert _get_default_value_for_type(None) is None
def test_unknown_type(self):
assert get_default_value_for_type("unknown") is None
assert _get_default_value_for_type("unknown") is None
def test_union_type_string_null(self):
# ["string", "null"] should return "" (first non-null type)
assert get_default_value_for_type(["string", "null"]) == ""
assert _get_default_value_for_type(["string", "null"]) == ""
def test_union_type_null_first(self):
# ["null", "integer"] should return 0 (first non-null type)
assert get_default_value_for_type(["null", "integer"]) == 0
assert _get_default_value_for_type(["null", "integer"]) == 0
def test_union_type_only_null(self):
# ["null"] should return None
assert get_default_value_for_type(["null"]) is None
assert _get_default_value_for_type(["null"]) is None
class TestFillDefaultsFromSchema:
@@ -501,7 +600,7 @@ class TestFillDefaultsFromSchema:
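# fill_defaults_from_schema backfills required keys missing from the output with type-appropriate
# defaults and recurses into required nested objects; optional keys and existing values are left untouched.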
def test_simple_required_fields(self):
"""Test filling simple required fields"""
schema: SchemaData = {
schema = {
"type": "object",
"properties": {
"name": {"type": "string"},
@@ -510,7 +609,7 @@ class TestFillDefaultsFromSchema:
},
"required": ["name", "age"],
}
output: SchemaData = {"name": "Alice"}
output = {"name": "Alice"}
result = fill_defaults_from_schema(output, schema)
@@ -520,7 +619,7 @@ class TestFillDefaultsFromSchema:
def test_non_required_fields_not_filled(self):
"""Test that non-required fields are not filled"""
schema: SchemaData = {
schema = {
"type": "object",
"properties": {
"required_field": {"type": "string"},
@@ -528,7 +627,7 @@ class TestFillDefaultsFromSchema:
},
"required": ["required_field"],
}
output: SchemaData = {}
output = {}
result = fill_defaults_from_schema(output, schema)
@@ -537,7 +636,7 @@ class TestFillDefaultsFromSchema:
def test_nested_object_required_fields(self):
"""Test filling nested object required fields"""
schema: SchemaData = {
schema = {
"type": "object",
"properties": {
"user": {
@@ -560,7 +659,7 @@ class TestFillDefaultsFromSchema:
},
"required": ["user"],
}
output: SchemaData = {
output = {
"user": {
"name": "Alice",
"address": {
@@ -585,7 +684,7 @@ class TestFillDefaultsFromSchema:
def test_missing_nested_object_created(self):
"""Test that missing required nested objects are created"""
schema: SchemaData = {
schema = {
"type": "object",
"properties": {
"metadata": {
@@ -599,7 +698,7 @@ class TestFillDefaultsFromSchema:
},
"required": ["metadata"],
}
output: SchemaData = {}
output = {}
result = fill_defaults_from_schema(output, schema)
@@ -611,7 +710,7 @@ class TestFillDefaultsFromSchema:
def test_all_types_default_values(self):
"""Test default values for all types"""
schema: SchemaData = {
schema = {
"type": "object",
"properties": {
"str_field": {"type": "string"},
@@ -623,7 +722,7 @@ class TestFillDefaultsFromSchema:
},
"required": ["str_field", "int_field", "num_field", "bool_field", "arr_field", "obj_field"],
}
output: SchemaData = {}
output = {}
result = fill_defaults_from_schema(output, schema)
@@ -638,7 +737,7 @@ class TestFillDefaultsFromSchema:
def test_existing_values_preserved(self):
"""Test that existing values are not overwritten"""
schema: SchemaData = {
schema = {
"type": "object",
"properties": {
"name": {"type": "string"},
@@ -646,7 +745,7 @@ class TestFillDefaultsFromSchema:
},
"required": ["name", "count"],
}
output: SchemaData = {"name": "Bob", "count": 42}
output = {"name": "Bob", "count": 42}
result = fill_defaults_from_schema(output, schema)
@@ -654,7 +753,7 @@ class TestFillDefaultsFromSchema:
def test_complex_nested_structure(self):
"""Test complex nested structure with multiple levels"""
schema: SchemaData = {
schema = {
"type": "object",
"properties": {
"user": {
@@ -690,7 +789,7 @@ class TestFillDefaultsFromSchema:
},
"required": ["user", "tags", "metadata", "is_active"],
}
output: SchemaData = {
output = {
"user": {
"name": "Alice",
"age": 25,
@@ -730,8 +829,8 @@ class TestFillDefaultsFromSchema:
def test_empty_schema(self):
"""Test with empty schema"""
schema: SchemaData = {}
output: SchemaData = {"any": "value"}
schema = {}
output = {"any": "value"}
result = fill_defaults_from_schema(output, schema)
@@ -739,14 +838,14 @@ class TestFillDefaultsFromSchema:
def test_schema_without_required(self):
"""Test schema without required field"""
schema: SchemaData = {
schema = {
"type": "object",
"properties": {
"optional1": {"type": "string"},
"optional2": {"type": "integer"},
},
}
output: SchemaData = {}
output = {}
result = fill_defaults_from_schema(output, schema)