diff --git a/api/core/agent/agent_app_runner.py b/api/core/agent/agent_app_runner.py
index c67d2d831d..2ee0a23aab 100644
--- a/api/core/agent/agent_app_runner.py
+++ b/api/core/agent/agent_app_runner.py
@@ -1,4 +1,3 @@
-import json
import logging
from collections.abc import Generator
from copy import deepcopy
@@ -24,7 +23,7 @@ from core.model_runtime.entities import (
from core.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes
from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
from core.tools.__base.tool import Tool
-from core.tools.entities.tool_entities import ToolInvokeFrom, ToolInvokeMeta
+from core.tools.entities.tool_entities import ToolInvokeMeta
from core.tools.tool_engine import ToolEngine
from models.model import Message
@@ -103,13 +102,11 @@ class AgentAppRunner(BaseAgentRunner):
agent_strategy=self.config.strategy,
tool_invoke_hook=tool_invoke_hook,
instruction=instruction,
- tenant_id=self.tenant_id,
- invoke_from=self.application_generate_entity.invoke_from,
- tool_invoke_from=ToolInvokeFrom.WORKFLOW,
)
# Initialize state variables
current_agent_thought_id = None
+ has_published_thought = False
current_tool_name: str | None = None
self._current_message_file_ids: list[str] = []
@@ -121,6 +118,7 @@ class AgentAppRunner(BaseAgentRunner):
prompt_messages=prompt_messages,
model_parameters=app_generate_entity.model_conf.parameters,
stop=app_generate_entity.model_conf.stop,
+ stream=True,
)
# Consume generator and collect result
@@ -135,10 +133,17 @@ class AgentAppRunner(BaseAgentRunner):
break
if isinstance(output, LLMResultChunk):
- # No more expect streaming data
- continue
+ # Handle LLM chunk
+ if current_agent_thought_id and not has_published_thought:
+ self.queue_manager.publish(
+ QueueAgentThoughtEvent(agent_thought_id=current_agent_thought_id),
+ PublishFrom.APPLICATION_MANAGER,
+ )
+ has_published_thought = True
- else:
+ yield output
+
+ elif isinstance(output, AgentLog):
# Handle Agent Log using log_type for type-safe dispatch
if output.status == AgentLog.LogStatus.START:
if output.log_type == AgentLog.LogType.ROUND:
@@ -151,6 +156,7 @@ class AgentAppRunner(BaseAgentRunner):
tool_input="",
messages_ids=message_file_ids,
)
+ has_published_thought = False
elif output.log_type == AgentLog.LogType.TOOL_CALL:
if current_agent_thought_id is None:
@@ -258,30 +264,23 @@ class AgentAppRunner(BaseAgentRunner):
raise
# Process final result
- if not isinstance(result, AgentResult):
- raise ValueError("Agent did not return AgentResult")
- output_payload = result.output
- if isinstance(output_payload, dict):
- final_answer = json.dumps(output_payload, ensure_ascii=False)
- elif isinstance(output_payload, str):
- final_answer = output_payload
- else:
- raise ValueError("Final output is not a string or structured data.")
- usage = result.usage or LLMUsage.empty_usage()
+ if isinstance(result, AgentResult):
+ final_answer = result.text
+ usage = result.usage or LLMUsage.empty_usage()
- # Publish end event
- self.queue_manager.publish(
- QueueMessageEndEvent(
- llm_result=LLMResult(
- model=self.model_instance.model,
- prompt_messages=prompt_messages,
- message=AssistantPromptMessage(content=final_answer),
- usage=usage,
- system_fingerprint="",
- )
- ),
- PublishFrom.APPLICATION_MANAGER,
- )
+ # Publish end event
+ self.queue_manager.publish(
+ QueueMessageEndEvent(
+ llm_result=LLMResult(
+ model=self.model_instance.model,
+ prompt_messages=prompt_messages,
+ message=AssistantPromptMessage(content=final_answer),
+ usage=usage,
+ system_fingerprint="",
+ )
+ ),
+ PublishFrom.APPLICATION_MANAGER,
+ )
def _init_system_message(self, prompt_template: str, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
"""
diff --git a/api/core/agent/base_agent_runner.py b/api/core/agent/base_agent_runner.py
index e5da2b3e12..b5459611b1 100644
--- a/api/core/agent/base_agent_runner.py
+++ b/api/core/agent/base_agent_runner.py
@@ -7,7 +7,6 @@ from typing import Union, cast
from sqlalchemy import select
from core.agent.entities import AgentEntity, AgentToolEntity, ExecutionContext
-from core.agent.output_tools import build_agent_output_tools
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfig
from core.app.apps.base_app_queue_manager import AppQueueManager
@@ -37,7 +36,6 @@ from core.model_runtime.model_providers.__base.large_language_model import Large
from core.prompt.utils.extract_thread_messages import extract_thread_messages
from core.tools.__base.tool import Tool
from core.tools.entities.tool_entities import (
- ToolInvokeFrom,
ToolParameter,
)
from core.tools.tool_manager import ToolManager
@@ -253,14 +251,6 @@ class BaseAgentRunner(AppRunner):
# save tool entity
tool_instances[dataset_tool.entity.identity.name] = dataset_tool
- output_tools = build_agent_output_tools(
- tenant_id=self.tenant_id,
- invoke_from=self.application_generate_entity.invoke_from,
- tool_invoke_from=ToolInvokeFrom.AGENT,
- )
- for tool in output_tools:
- tool_instances[tool.entity.identity.name] = tool
-
return tool_instances, prompt_messages_tools
def update_prompt_message_tool(self, tool: Tool, prompt_tool: PromptMessageTool) -> PromptMessageTool:
diff --git a/api/core/agent/entities.py b/api/core/agent/entities.py
index ad8bdb723a..46af4d2d72 100644
--- a/api/core/agent/entities.py
+++ b/api/core/agent/entities.py
@@ -1,11 +1,10 @@
import uuid
from collections.abc import Mapping
from enum import StrEnum
-from typing import Any
+from typing import Any, Union
from pydantic import BaseModel, Field
-from core.agent.output_tools import FINAL_OUTPUT_TOOL
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolProviderType
@@ -42,9 +41,9 @@ class AgentScratchpadUnit(BaseModel):
"""
action_name: str
- action_input: dict[str, Any] | str
+ action_input: Union[dict, str]
- def to_dict(self) -> dict[str, Any]:
+ def to_dict(self):
"""
Convert to dictionary.
"""
@@ -63,9 +62,9 @@ class AgentScratchpadUnit(BaseModel):
"""
Check if the scratchpad unit is final.
"""
- if self.action is None:
- return False
- return self.action.action_name.lower() == FINAL_OUTPUT_TOOL
+ return self.action is None or (
+ "final" in self.action.action_name.lower() and "answer" in self.action.action_name.lower()
+ )
class AgentEntity(BaseModel):
@@ -126,7 +125,7 @@ class ExecutionContext(BaseModel):
"tenant_id": self.tenant_id,
}
- def with_updates(self, **kwargs: Any) -> "ExecutionContext":
+ def with_updates(self, **kwargs) -> "ExecutionContext":
"""Create a new context with updated fields."""
data = self.to_dict()
data.update(kwargs)
@@ -179,23 +178,12 @@ class AgentLog(BaseModel):
metadata: Mapping[LogMetadata, Any] = Field(default={}, description="The metadata of the log")
-class AgentOutputKind(StrEnum):
- """
- Agent output kind.
- """
-
- OUTPUT_TEXT = "output_text"
- FINAL_OUTPUT_ANSWER = "final_output_answer"
- FINAL_STRUCTURED_OUTPUT = "final_structured_output"
- ILLEGAL_OUTPUT = "illegal_output"
-
-
class AgentResult(BaseModel):
"""
Agent execution result.
"""
- output: str | dict = Field(default="", description="The generated output")
+ text: str = Field(default="", description="The generated text")
files: list[Any] = Field(default_factory=list, description="Files produced during execution")
usage: Any | None = Field(default=None, description="LLM usage statistics")
finish_reason: str | None = Field(default=None, description="Reason for completion")
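The relaxed `is_final` check introduced above treats any action whose name contains both "final" and "answer" (or the absence of an action) as terminal. A standalone sketch of just that heuristic:

```python
# Sketch of the is_final() predicate (local copy, not the real
# AgentScratchpadUnit): terminal when there is no action, or when the action
# name contains both "final" and "answer".
def is_final(action_name: str | None) -> bool:
    return action_name is None or (
        "final" in action_name.lower() and "answer" in action_name.lower()
    )


assert is_final(None)
assert is_final("Final Answer")
assert is_final("final_answer")
assert not is_final("web_search")
```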
diff --git a/api/core/agent/output_parser/cot_output_parser.py b/api/core/agent/output_parser/cot_output_parser.py
index 2d4ce58b6a..7c8f09e6b9 100644
--- a/api/core/agent/output_parser/cot_output_parser.py
+++ b/api/core/agent/output_parser/cot_output_parser.py
@@ -1,7 +1,7 @@
import json
import re
from collections.abc import Generator
-from typing import Any, Union, cast
+from typing import Union
from core.agent.entities import AgentScratchpadUnit
from core.model_runtime.entities.llm_entities import LLMResultChunk
@@ -10,52 +10,46 @@ from core.model_runtime.entities.llm_entities import LLMResultChunk
class CotAgentOutputParser:
@classmethod
def handle_react_stream_output(
- cls, llm_response: Generator[LLMResultChunk, None, None], usage_dict: dict[str, Any]
+ cls, llm_response: Generator[LLMResultChunk, None, None], usage_dict: dict
) -> Generator[Union[str, AgentScratchpadUnit.Action], None, None]:
- def parse_action(action: Any) -> Union[str, AgentScratchpadUnit.Action]:
- action_name: str | None = None
- action_input: Any | None = None
- parsed_action: Any = action
- if isinstance(parsed_action, str):
+ def parse_action(action) -> Union[str, AgentScratchpadUnit.Action]:
+ action_name = None
+ action_input = None
+ if isinstance(action, str):
try:
- parsed_action = json.loads(parsed_action, strict=False)
+ action = json.loads(action, strict=False)
except json.JSONDecodeError:
- return parsed_action or ""
+ return action or ""
# cohere always returns a list
- if isinstance(parsed_action, list):
- action_list: list[Any] = cast(list[Any], parsed_action)
- if len(action_list) == 1:
- parsed_action = action_list[0]
+ if isinstance(action, list) and len(action) == 1:
+ action = action[0]
- if isinstance(parsed_action, dict):
- action_dict: dict[str, Any] = cast(dict[str, Any], parsed_action)
- for key, value in action_dict.items():
- if "input" in key.lower():
- action_input = value
- elif isinstance(value, str):
- action_name = value
- else:
- return json.dumps(parsed_action)
+ for key, value in action.items():
+ if "input" in key.lower():
+ action_input = value
+ else:
+ action_name = value
if action_name is not None and action_input is not None:
return AgentScratchpadUnit.Action(
action_name=action_name,
action_input=action_input,
)
- return json.dumps(parsed_action)
+ else:
+ return json.dumps(action)
- def extra_json_from_code_block(code_block: str) -> list[dict[str, Any] | list[Any]]:
+ def extra_json_from_code_block(code_block) -> list[Union[list, dict]]:
blocks = re.findall(r"```[json]*\s*([\[{].*[]}])\s*```", code_block, re.DOTALL | re.IGNORECASE)
if not blocks:
return []
try:
- json_blocks: list[dict[str, Any] | list[Any]] = []
+ json_blocks = []
for block in blocks:
json_text = re.sub(r"^[a-zA-Z]+\n", "", block.strip(), flags=re.MULTILINE)
json_blocks.append(json.loads(json_text, strict=False))
return json_blocks
- except Exception:
+            except Exception:
return []
code_block_cache = ""
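For reference, the code-block extraction restored above relies on a single regex to pull the JSON payload out of a fenced block before parsing it. A standalone usage sketch with the same pattern:

````python
# Sketch: the same regex used by extra_json_from_code_block, applied to a
# typical ReAct completion that wraps its action JSON in a fenced block.
import json
import re

completion = 'Thought: I know what to respond\n```json\n{"action": "Final Answer", "action_input": "Hello"}\n```'
blocks = re.findall(r"```[json]*\s*([\[{].*[]}])\s*```", completion, re.DOTALL | re.IGNORECASE)
for block in blocks:
    print(json.loads(block, strict=False))
    # -> {'action': 'Final Answer', 'action_input': 'Hello'}
````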
diff --git a/api/core/agent/output_tools.py b/api/core/agent/output_tools.py
deleted file mode 100644
index 253d138c3a..0000000000
--- a/api/core/agent/output_tools.py
+++ /dev/null
@@ -1,108 +0,0 @@
-from __future__ import annotations
-
-from collections.abc import Mapping, Sequence
-from typing import Any
-
-from core.app.entities.app_invoke_entities import InvokeFrom
-from core.tools.__base.tool import Tool
-from core.tools.entities.common_entities import I18nObject
-from core.tools.entities.tool_entities import ToolInvokeFrom, ToolInvokeMessage, ToolParameter, ToolProviderType
-from core.tools.tool_manager import ToolManager
-
-OUTPUT_TOOL_PROVIDER = "agent_output"
-
-OUTPUT_TEXT_TOOL = "output_text"
-FINAL_OUTPUT_TOOL = "final_output_answer"
-FINAL_STRUCTURED_OUTPUT_TOOL = "final_structured_output"
-ILLEGAL_OUTPUT_TOOL = "illegal_output"
-
-OUTPUT_TOOL_NAMES: Sequence[str] = (
- OUTPUT_TEXT_TOOL,
- FINAL_OUTPUT_TOOL,
- FINAL_STRUCTURED_OUTPUT_TOOL,
- ILLEGAL_OUTPUT_TOOL,
-)
-
-TERMINAL_OUTPUT_TOOL_NAMES: Sequence[str] = (FINAL_OUTPUT_TOOL, FINAL_STRUCTURED_OUTPUT_TOOL)
-
-
-TERMINAL_OUTPUT_MESSAGE = "Final output received. This ends the current session."
-
-
-def is_terminal_output_tool(tool_name: str) -> bool:
- return tool_name in TERMINAL_OUTPUT_TOOL_NAMES
-
-
-def get_terminal_tool_name(structured_output_enabled: bool) -> str:
- return FINAL_STRUCTURED_OUTPUT_TOOL if structured_output_enabled else FINAL_OUTPUT_TOOL
-
-
-OUTPUT_TOOL_NAME_SET = set(OUTPUT_TOOL_NAMES)
-
-
-def build_agent_output_tools(
- tenant_id: str,
- invoke_from: InvokeFrom,
- tool_invoke_from: ToolInvokeFrom,
- structured_output_schema: Mapping[str, Any] | None = None,
-) -> list[Tool]:
- def get_tool_runtime(_tool_name: str) -> Tool:
- return ToolManager.get_tool_runtime(
- provider_type=ToolProviderType.BUILT_IN,
- provider_id=OUTPUT_TOOL_PROVIDER,
- tool_name=_tool_name,
- tenant_id=tenant_id,
- invoke_from=invoke_from,
- tool_invoke_from=tool_invoke_from,
- )
-
- tools: list[Tool] = [
- get_tool_runtime(OUTPUT_TEXT_TOOL),
- get_tool_runtime(ILLEGAL_OUTPUT_TOOL),
- ]
-
- if structured_output_schema:
- raw_tool = get_tool_runtime(FINAL_STRUCTURED_OUTPUT_TOOL)
- raw_tool.entity = raw_tool.entity.model_copy(deep=True)
- data_parameter = ToolParameter(
- name="data",
- type=ToolParameter.ToolParameterType.OBJECT,
- form=ToolParameter.ToolParameterForm.LLM,
- required=True,
- input_schema=dict(structured_output_schema),
- label=I18nObject(en_US="__Data", zh_Hans="__Data"),
- )
- raw_tool.entity.parameters = [data_parameter]
-
- def invoke_tool(
- user_id: str,
- tool_parameters: dict[str, Any],
- conversation_id: str | None = None,
- app_id: str | None = None,
- message_id: str | None = None,
- ) -> ToolInvokeMessage:
- data = tool_parameters["data"]
- if not data:
- return ToolInvokeMessage(message=ToolInvokeMessage.TextMessage(text="`data` field is required"))
- if not isinstance(data, dict):
- return ToolInvokeMessage(message=ToolInvokeMessage.TextMessage(text="`data` must be a dict"))
- return ToolInvokeMessage(message=ToolInvokeMessage.TextMessage(text=TERMINAL_OUTPUT_MESSAGE))
-
- raw_tool._invoke = invoke_tool # pyright: ignore[reportPrivateUsage]
- tools.append(raw_tool)
- else:
- raw_tool = get_tool_runtime(FINAL_OUTPUT_TOOL)
-
- def invoke_tool(
- user_id: str,
- tool_parameters: dict[str, Any],
- conversation_id: str | None = None,
- app_id: str | None = None,
- message_id: str | None = None,
- ) -> ToolInvokeMessage:
- return ToolInvokeMessage(message=ToolInvokeMessage.TextMessage(text=TERMINAL_OUTPUT_MESSAGE))
-
- raw_tool._invoke = invoke_tool # pyright: ignore[reportPrivateUsage]
- tools.append(raw_tool)
-
- return tools
diff --git a/api/core/agent/patterns/base.py b/api/core/agent/patterns/base.py
index a97b796837..03c349a475 100644
--- a/api/core/agent/patterns/base.py
+++ b/api/core/agent/patterns/base.py
@@ -10,7 +10,6 @@ from collections.abc import Callable, Generator
from typing import TYPE_CHECKING, Any
from core.agent.entities import AgentLog, AgentResult, ExecutionContext
-from core.agent.output_tools import ILLEGAL_OUTPUT_TOOL
from core.file import File
from core.model_manager import ModelInstance
from core.model_runtime.entities import (
@@ -57,7 +56,8 @@ class AgentPattern(ABC):
@abstractmethod
def run(
- self, prompt_messages: list[PromptMessage], model_parameters: dict[str, Any], stop: list[str]
+        self,
+        prompt_messages: list[PromptMessage],
+        model_parameters: dict[str, Any],
+        stop: list[str] = [],
+        stream: bool = True,
) -> Generator[LLMResultChunk | AgentLog, None, AgentResult]:
"""Execute the agent strategy."""
pass
@@ -462,8 +462,6 @@ class AgentPattern(ABC):
"""Convert tools to prompt message format."""
prompt_tools: list[PromptMessageTool] = []
for tool in self.tools:
- if tool.entity.identity.name == ILLEGAL_OUTPUT_TOOL:
- continue
prompt_tools.append(tool.to_prompt_message_tool())
return prompt_tools
diff --git a/api/core/agent/patterns/function_call.py b/api/core/agent/patterns/function_call.py
index 9d58f4e00d..9c87c186d8 100644
--- a/api/core/agent/patterns/function_call.py
+++ b/api/core/agent/patterns/function_call.py
@@ -1,15 +1,10 @@
"""Function Call strategy implementation."""
import json
-import uuid
from collections.abc import Generator
-from typing import Any, Literal, Protocol, Union, cast
+from typing import Any, Union
from core.agent.entities import AgentLog, AgentResult
-from core.agent.output_tools import (
- ILLEGAL_OUTPUT_TOOL,
- TERMINAL_OUTPUT_MESSAGE,
-)
from core.file import File
from core.model_runtime.entities import (
AssistantPromptMessage,
@@ -30,7 +25,8 @@ class FunctionCallStrategy(AgentPattern):
"""Function Call strategy using model's native tool calling capability."""
def run(
- self, prompt_messages: list[PromptMessage], model_parameters: dict[str, Any], stop: list[str]
+        self,
+        prompt_messages: list[PromptMessage],
+        model_parameters: dict[str, Any],
+        stop: list[str] = [],
+        stream: bool = True,
) -> Generator[LLMResultChunk | AgentLog, None, AgentResult]:
"""Execute the function call agent strategy."""
# Convert tools to prompt format
@@ -43,23 +39,9 @@ class FunctionCallStrategy(AgentPattern):
total_usage: dict[str, LLMUsage | None] = {"usage": None}
messages: list[PromptMessage] = list(prompt_messages) # Create mutable copy
final_text: str = ""
- final_tool_args: dict[str, Any] = {"!!!": "!!!"}
finish_reason: str | None = None
output_files: list[File] = [] # Track files produced by tools
- class _LLMInvoker(Protocol):
- def invoke_llm(
- self,
- *,
- prompt_messages: list[PromptMessage],
- model_parameters: dict[str, Any],
- tools: list[PromptMessageTool],
- stop: list[str],
- stream: Literal[False],
- user: str | None,
- callbacks: list[Any],
- ) -> LLMResult: ...
-
while function_call_state and iteration_step <= max_iterations:
function_call_state = False
round_log = self._create_log(
@@ -69,7 +51,8 @@ class FunctionCallStrategy(AgentPattern):
data={},
)
yield round_log
-
+ # On last iteration, remove tools to force final answer
+ current_tools: list[PromptMessageTool] = [] if iteration_step == max_iterations else prompt_tools
model_log = self._create_log(
label=f"{self.model_instance.model} Thought",
log_type=AgentLog.LogType.THOUGHT,
@@ -86,63 +69,47 @@ class FunctionCallStrategy(AgentPattern):
round_usage: dict[str, LLMUsage | None] = {"usage": None}
# Invoke model
- invoker = cast(_LLMInvoker, self.model_instance)
- chunks = invoker.invoke_llm(
+ chunks: Union[Generator[LLMResultChunk, None, None], LLMResult] = self.model_instance.invoke_llm(
prompt_messages=messages,
model_parameters=model_parameters,
- tools=prompt_tools,
+ tools=current_tools,
stop=stop,
- stream=False,
+ stream=stream,
user=self.context.user_id,
callbacks=[],
)
# Process response
tool_calls, response_content, chunk_finish_reason = yield from self._handle_chunks(
- chunks, round_usage, model_log, emit_chunks=False
+ chunks, round_usage, model_log
)
-
- if response_content:
- replaced_tool_call = (
- str(uuid.uuid4()),
- ILLEGAL_OUTPUT_TOOL,
- {
- "raw": response_content,
- },
- )
- tool_calls.append(replaced_tool_call)
-
- messages.append(self._create_assistant_message("", tool_calls))
+ messages.append(self._create_assistant_message(response_content, tool_calls))
# Accumulate to total usage
round_usage_value = round_usage.get("usage")
if round_usage_value:
self._accumulate_usage(total_usage, round_usage_value)
+ # Update final text if no tool calls (this is likely the final answer)
+ if not tool_calls:
+ final_text = response_content
+
# Update finish reason
if chunk_finish_reason:
finish_reason = chunk_finish_reason
- assert len(tool_calls) > 0
-
# Process tool calls
tool_outputs: dict[str, str] = {}
- function_call_state = True
- # Execute tools
- for tool_call_id, tool_name, tool_args in tool_calls:
- tool_response, tool_files, _ = yield from self._handle_tool_call(
- tool_name, tool_args, tool_call_id, messages, round_log
- )
- tool_entity = self._find_tool_by_name(tool_name)
- if not tool_entity:
- raise ValueError(f"Tool {tool_name} not found")
- tool_outputs[tool_name] = tool_response
- # Track files produced by tools
- output_files.extend(tool_files)
- if tool_response == TERMINAL_OUTPUT_MESSAGE:
- function_call_state = False
- final_tool_args = tool_entity.transform_tool_parameters_type(tool_args)
-
+ if tool_calls:
+ function_call_state = True
+ # Execute tools
+ for tool_call_id, tool_name, tool_args in tool_calls:
+ tool_response, tool_files, _ = yield from self._handle_tool_call(
+ tool_name, tool_args, tool_call_id, messages, round_log
+ )
+ tool_outputs[tool_name] = tool_response
+ # Track files produced by tools
+ output_files.extend(tool_files)
yield self._finish_log(
round_log,
data={
@@ -161,19 +128,8 @@ class FunctionCallStrategy(AgentPattern):
# Return final result
from core.agent.entities import AgentResult
- output_payload: str | dict
- output_text = final_tool_args.get("text")
- output_structured_payload = final_tool_args.get("data")
-
- if isinstance(output_structured_payload, dict):
- output_payload = output_structured_payload
- elif isinstance(output_text, str):
- output_payload = output_text
- else:
- raise ValueError(f"Final output ({final_tool_args}) is not a string or structured data.")
-
return AgentResult(
- output=output_payload,
+ text=final_text,
files=output_files,
usage=total_usage.get("usage") or LLMUsage.empty_usage(),
finish_reason=finish_reason,
@@ -184,8 +140,6 @@ class FunctionCallStrategy(AgentPattern):
chunks: Union[Generator[LLMResultChunk, None, None], LLMResult],
llm_usage: dict[str, LLMUsage | None],
start_log: AgentLog,
- *,
- emit_chunks: bool,
) -> Generator[
LLMResultChunk | AgentLog,
None,
@@ -217,8 +171,7 @@ class FunctionCallStrategy(AgentPattern):
if chunk.delta.finish_reason:
finish_reason = chunk.delta.finish_reason
- if emit_chunks:
- yield chunk
+ yield chunk
else:
# Non-streaming response
result: LLMResult = chunks
@@ -233,12 +186,11 @@ class FunctionCallStrategy(AgentPattern):
self._accumulate_usage(llm_usage, result.usage)
# Convert to streaming format
- if emit_chunks:
- yield LLMResultChunk(
- model=result.model,
- prompt_messages=result.prompt_messages,
- delta=LLMResultChunkDelta(index=0, message=result.message, usage=result.usage),
- )
+ yield LLMResultChunk(
+ model=result.model,
+ prompt_messages=result.prompt_messages,
+ delta=LLMResultChunkDelta(index=0, message=result.message, usage=result.usage),
+ )
yield self._finish_log(
start_log,
data={
@@ -248,14 +200,6 @@ class FunctionCallStrategy(AgentPattern):
)
return tool_calls, response_content, finish_reason
- @staticmethod
- def _format_output_text(value: Any) -> str:
- if value is None:
- return ""
- if isinstance(value, str):
- return value
- return json.dumps(value, ensure_ascii=False)
-
def _create_assistant_message(
self, content: str, tool_calls: list[tuple[str, str, dict[str, Any]]] | None = None
) -> AssistantPromptMessage:
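One behavioral detail worth calling out from the loop above: tools are withheld on the last permitted iteration, so the model has to produce a plain final answer instead of another tool call. A minimal sketch of that control flow (hypothetical names):

```python
# Sketch of the "drop tools on the final round" rule used above.
MAX_ITERATIONS = 3
PROMPT_TOOLS = ["web_search", "calculator"]


def tools_for_round(iteration_step: int) -> list[str]:
    # On the last iteration, remove tools to force a final answer.
    return [] if iteration_step == MAX_ITERATIONS else PROMPT_TOOLS


for step in range(1, MAX_ITERATIONS + 1):
    print(step, tools_for_round(step))
# 1 ['web_search', 'calculator']
# 2 ['web_search', 'calculator']
# 3 []
```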
diff --git a/api/core/agent/patterns/react.py b/api/core/agent/patterns/react.py
index 88b4e8e930..b81d572aea 100644
--- a/api/core/agent/patterns/react.py
+++ b/api/core/agent/patterns/react.py
@@ -4,17 +4,10 @@ from __future__ import annotations
import json
from collections.abc import Generator
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any, Union
from core.agent.entities import AgentLog, AgentResult, AgentScratchpadUnit, ExecutionContext
from core.agent.output_parser.cot_output_parser import CotAgentOutputParser
-from core.agent.output_tools import (
- FINAL_OUTPUT_TOOL,
- FINAL_STRUCTURED_OUTPUT_TOOL,
- ILLEGAL_OUTPUT_TOOL,
- OUTPUT_TEXT_TOOL,
- OUTPUT_TOOL_NAME_SET,
-)
from core.file import File
from core.model_manager import ModelInstance
from core.model_runtime.entities import (
@@ -59,7 +52,8 @@ class ReActStrategy(AgentPattern):
self.instruction = instruction
def run(
- self, prompt_messages: list[PromptMessage], model_parameters: dict[str, Any], stop: list[str]
+        self,
+        prompt_messages: list[PromptMessage],
+        model_parameters: dict[str, Any],
+        stop: list[str] = [],
+        stream: bool = True,
) -> Generator[LLMResultChunk | AgentLog, None, AgentResult]:
"""Execute the ReAct agent strategy."""
# Initialize tracking
@@ -71,19 +65,6 @@ class ReActStrategy(AgentPattern):
output_files: list[File] = [] # Track files produced by tools
final_text: str = ""
finish_reason: str | None = None
- tool_instance_names = {tool.entity.identity.name for tool in self.tools}
- available_output_tool_names = {
- tool_name
- for tool_name in tool_instance_names
- if tool_name in OUTPUT_TOOL_NAME_SET and tool_name != ILLEGAL_OUTPUT_TOOL
- }
- if FINAL_STRUCTURED_OUTPUT_TOOL in available_output_tool_names:
- terminal_tool_name = FINAL_STRUCTURED_OUTPUT_TOOL
- elif FINAL_OUTPUT_TOOL in available_output_tool_names:
- terminal_tool_name = FINAL_OUTPUT_TOOL
- else:
- raise ValueError("No terminal output tool configured")
- allow_illegal_output = ILLEGAL_OUTPUT_TOOL in tool_instance_names
# Add "Observation" to stop sequences
if "Observation" not in stop:
@@ -100,15 +81,10 @@ class ReActStrategy(AgentPattern):
)
yield round_log
- # Build prompt with tool restrictions on last iteration
- if iteration_step == max_iterations:
- tools_for_prompt = [
- tool for tool in self.tools if tool.entity.identity.name in available_output_tool_names
- ]
- else:
- tools_for_prompt = [tool for tool in self.tools if tool.entity.identity.name != ILLEGAL_OUTPUT_TOOL]
+ # Build prompt with/without tools based on iteration
+ include_tools = iteration_step < max_iterations
current_messages = self._build_prompt_with_react_format(
- prompt_messages, agent_scratchpad, tools_for_prompt, self.instruction
+ prompt_messages, agent_scratchpad, include_tools, self.instruction
)
model_log = self._create_log(
@@ -130,18 +106,18 @@ class ReActStrategy(AgentPattern):
messages_to_use = current_messages
# Invoke model
- chunks = self.model_instance.invoke_llm(
+ chunks: Union[Generator[LLMResultChunk, None, None], LLMResult] = self.model_instance.invoke_llm(
prompt_messages=messages_to_use,
model_parameters=model_parameters,
stop=stop,
- stream=False,
+ stream=stream,
user=self.context.user_id or "",
callbacks=[],
)
# Process response
scratchpad, chunk_finish_reason = yield from self._handle_chunks(
- chunks, round_usage, model_log, current_messages, emit_chunks=False
+ chunks, round_usage, model_log, current_messages
)
agent_scratchpad.append(scratchpad)
@@ -155,44 +131,28 @@ class ReActStrategy(AgentPattern):
finish_reason = chunk_finish_reason
# Check if we have an action to execute
- if scratchpad.action is None:
- if not allow_illegal_output:
- raise ValueError("Model did not call any tools")
- illegal_action = AgentScratchpadUnit.Action(
- action_name=ILLEGAL_OUTPUT_TOOL,
- action_input={"raw": scratchpad.thought or ""},
- )
- scratchpad.action = illegal_action
- scratchpad.action_str = illegal_action.model_dump_json()
+ if scratchpad.action and scratchpad.action.action_name.lower() != "final answer":
react_state = True
- observation, tool_files = yield from self._handle_tool_call(illegal_action, current_messages, round_log)
- scratchpad.observation = observation
- output_files.extend(tool_files)
- else:
- action_name = scratchpad.action.action_name
- if action_name == OUTPUT_TEXT_TOOL and isinstance(scratchpad.action.action_input, dict):
- pass # output_text_payload = scratchpad.action.action_input.get("text")
- elif action_name == FINAL_STRUCTURED_OUTPUT_TOOL and isinstance(scratchpad.action.action_input, dict):
- data = scratchpad.action.action_input.get("data")
- if isinstance(data, dict):
- pass # structured_output_payload = data
- elif action_name == FINAL_OUTPUT_TOOL:
- if isinstance(scratchpad.action.action_input, dict):
- final_text = self._format_output_text(scratchpad.action.action_input.get("text"))
- else:
- final_text = self._format_output_text(scratchpad.action.action_input)
-
+ # Execute tool
observation, tool_files = yield from self._handle_tool_call(
scratchpad.action, current_messages, round_log
)
scratchpad.observation = observation
+ # Track files produced by tools
output_files.extend(tool_files)
- if action_name == terminal_tool_name:
- pass # terminal_output_seen = True
- react_state = False
- else:
- react_state = True
+                # Stream the observation back as a text chunk so it appears in the output
+ yield self._create_text_chunk(f"\nObservation: {observation}\n", current_messages)
+ else:
+ # Extract final answer
+ if scratchpad.action and scratchpad.action.action_input:
+ final_answer = scratchpad.action.action_input
+ if isinstance(final_answer, dict):
+ final_answer = json.dumps(final_answer, ensure_ascii=False)
+ final_text = str(final_answer)
+ elif scratchpad.thought:
+ # If no action but we have thought, use thought as final answer
+ final_text = scratchpad.thought
yield self._finish_log(
round_log,
@@ -208,22 +168,17 @@ class ReActStrategy(AgentPattern):
# Return final result
- output_payload: str | dict
-
- # TODO
+ from core.agent.entities import AgentResult
return AgentResult(
- output=output_payload,
- files=output_files,
- usage=total_usage.get("usage"),
- finish_reason=finish_reason,
+ text=final_text, files=output_files, usage=total_usage.get("usage"), finish_reason=finish_reason
)
def _build_prompt_with_react_format(
self,
original_messages: list[PromptMessage],
agent_scratchpad: list[AgentScratchpadUnit],
- tools: list[Tool] | None,
+ include_tools: bool = True,
instruction: str = "",
) -> list[PromptMessage]:
"""Build prompt messages with ReAct format."""
@@ -240,13 +195,9 @@ class ReActStrategy(AgentPattern):
# Format tools
tools_str = ""
tool_names = []
- if tools:
+ if include_tools and self.tools:
# Convert tools to prompt message tools format
- prompt_tools = [
- tool.to_prompt_message_tool()
- for tool in tools
- if tool.entity.identity.name != ILLEGAL_OUTPUT_TOOL
- ]
+ prompt_tools = [tool.to_prompt_message_tool() for tool in self.tools]
tool_names = [tool.name for tool in prompt_tools]
# Format tools as JSON for comprehensive information
@@ -258,19 +209,12 @@ class ReActStrategy(AgentPattern):
tools_str = "No tools available"
tool_names_str = ""
- final_tool_name = FINAL_OUTPUT_TOOL
- if FINAL_STRUCTURED_OUTPUT_TOOL in tool_names:
- final_tool_name = FINAL_STRUCTURED_OUTPUT_TOOL
- if final_tool_name not in tool_names:
- raise ValueError("No terminal output tool available for prompt")
-
# Replace placeholders in the existing system prompt
updated_content = msg.content
assert isinstance(updated_content, str)
updated_content = updated_content.replace("{{instruction}}", instruction)
updated_content = updated_content.replace("{{tools}}", tools_str)
updated_content = updated_content.replace("{{tool_names}}", tool_names_str)
- updated_content = updated_content.replace("{{final_tool_name}}", final_tool_name)
# Create new SystemPromptMessage with updated content
messages[i] = SystemPromptMessage(content=updated_content)
@@ -302,12 +246,10 @@ class ReActStrategy(AgentPattern):
def _handle_chunks(
self,
- chunks: LLMResult,
+ chunks: Union[Generator[LLMResultChunk, None, None], LLMResult],
llm_usage: dict[str, Any],
model_log: AgentLog,
current_messages: list[PromptMessage],
- *,
- emit_chunks: bool,
) -> Generator[
LLMResultChunk | AgentLog,
None,
@@ -319,20 +261,25 @@ class ReActStrategy(AgentPattern):
"""
usage_dict: dict[str, Any] = {}
- def result_to_chunks() -> Generator[LLMResultChunk, None, None]:
- yield LLMResultChunk(
- model=chunks.model,
- prompt_messages=chunks.prompt_messages,
- delta=LLMResultChunkDelta(
- index=0,
- message=chunks.message,
- usage=chunks.usage,
- finish_reason=None, # LLMResult doesn't have finish_reason, only streaming chunks do
- ),
- system_fingerprint=chunks.system_fingerprint or "",
- )
+ # Convert non-streaming to streaming format if needed
+ if isinstance(chunks, LLMResult):
+ # Create a generator from the LLMResult
+ def result_to_chunks() -> Generator[LLMResultChunk, None, None]:
+ yield LLMResultChunk(
+ model=chunks.model,
+ prompt_messages=chunks.prompt_messages,
+ delta=LLMResultChunkDelta(
+ index=0,
+ message=chunks.message,
+ usage=chunks.usage,
+ finish_reason=None, # LLMResult doesn't have finish_reason, only streaming chunks do
+ ),
+ system_fingerprint=chunks.system_fingerprint or "",
+ )
- streaming_chunks = result_to_chunks()
+ streaming_chunks = result_to_chunks()
+ else:
+ streaming_chunks = chunks
react_chunks = CotAgentOutputParser.handle_react_stream_output(streaming_chunks, usage_dict)
@@ -356,18 +303,14 @@ class ReActStrategy(AgentPattern):
scratchpad.action_str = action_str
scratchpad.action = chunk
- if emit_chunks:
- yield self._create_text_chunk(json.dumps(chunk.model_dump()), current_messages)
- elif isinstance(chunk, str):
+ yield self._create_text_chunk(json.dumps(chunk.model_dump()), current_messages)
+ else:
# Text chunk
chunk_text = str(chunk)
scratchpad.agent_response = (scratchpad.agent_response or "") + chunk_text
scratchpad.thought = (scratchpad.thought or "") + chunk_text
- if emit_chunks:
- yield self._create_text_chunk(chunk_text, current_messages)
- else:
- raise ValueError(f"Unexpected chunk type: {type(chunk)}")
+ yield self._create_text_chunk(chunk_text, current_messages)
# Update usage
if usage_dict.get("usage"):
@@ -391,14 +334,6 @@ class ReActStrategy(AgentPattern):
return scratchpad, finish_reason
- @staticmethod
- def _format_output_text(value: Any) -> str:
- if value is None:
- return ""
- if isinstance(value, str):
- return value
- return json.dumps(value, ensure_ascii=False)
-
def _handle_tool_call(
self,
action: AgentScratchpadUnit.Action,
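The final-answer handling restored above serializes a dict `action_input` to JSON, stringifies anything else, and falls back to the model's thought when no action was parsed. A simplified standalone sketch of that extraction:

```python
# Sketch only: mirrors the "Extract final answer" branch of the ReAct loop.
import json
from typing import Any


def extract_final_text(action_input: Any, thought: str | None) -> str:
    if action_input:
        if isinstance(action_input, dict):
            return json.dumps(action_input, ensure_ascii=False)
        return str(action_input)
    # No usable action input: fall back to the model's thought.
    return thought or ""


print(extract_final_text("Paris.", None))                 # Paris.
print(extract_final_text({"answer": 42}, None))           # {"answer": 42}
print(extract_final_text(None, "I can answer directly"))  # I can answer directly
```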
diff --git a/api/core/agent/patterns/strategy_factory.py b/api/core/agent/patterns/strategy_factory.py
index fbb550534d..2ec845c9b0 100644
--- a/api/core/agent/patterns/strategy_factory.py
+++ b/api/core/agent/patterns/strategy_factory.py
@@ -2,17 +2,13 @@
from __future__ import annotations
-from collections.abc import Mapping
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING
from core.agent.entities import AgentEntity, ExecutionContext
from core.file.models import File
from core.model_manager import ModelInstance
from core.model_runtime.entities.model_entities import ModelFeature
-from ...app.entities.app_invoke_entities import InvokeFrom
-from ...tools.entities.tool_entities import ToolInvokeFrom
-from ..output_tools import build_agent_output_tools
from .base import AgentPattern, ToolInvokeHook
from .function_call import FunctionCallStrategy
from .react import ReActStrategy
@@ -29,10 +25,6 @@ class StrategyFactory:
@staticmethod
def create_strategy(
- *,
- tenant_id: str,
- invoke_from: InvokeFrom,
- tool_invoke_from: ToolInvokeFrom,
model_features: list[ModelFeature],
model_instance: ModelInstance,
context: ExecutionContext,
@@ -43,16 +35,11 @@ class StrategyFactory:
agent_strategy: AgentEntity.Strategy | None = None,
tool_invoke_hook: ToolInvokeHook | None = None,
instruction: str = "",
- structured_output_schema: Mapping[str, Any] | None = None,
) -> AgentPattern:
"""
Create an appropriate strategy based on model features.
Args:
- tenant_id:
- invoke_from:
- tool_invoke_from:
- structured_output_schema:
model_features: List of model features/capabilities
model_instance: Model instance to use
context: Execution context containing trace/audit information
@@ -67,14 +54,6 @@ class StrategyFactory:
Returns:
AgentStrategy instance
"""
- output_tools = build_agent_output_tools(
- tenant_id=tenant_id,
- invoke_from=invoke_from,
- tool_invoke_from=tool_invoke_from,
- structured_output_schema=structured_output_schema,
- )
-
- tools.extend(output_tools)
# If explicit strategy is provided and it's Function Calling, try to use it if supported
if agent_strategy == AgentEntity.Strategy.FUNCTION_CALLING:
diff --git a/api/core/agent/prompt/template.py b/api/core/agent/prompt/template.py
index 0bbc9007b7..f5ba2119f4 100644
--- a/api/core/agent/prompt/template.py
+++ b/api/core/agent/prompt/template.py
@@ -7,7 +7,7 @@ You have access to the following tools:
{{tools}}
Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
-Valid "action" values: {{tool_names}}. You must call "{{final_tool_name}}" to finish.
+Valid "action" values: "Final Answer" or {{tool_names}}
Provide only ONE action per $JSON_BLOB, as shown:
@@ -32,14 +32,12 @@ Thought: I know what to respond
Action:
```
{
- "action": "{{final_tool_name}}",
- "action_input": {
- "text": "Final response to human"
- }
+ "action": "Final Answer",
+ "action_input": "Final response to human"
}
```
-Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Do not respond with plain text. Format is Action:```$JSON_BLOB```then Observation:.
+Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary. Respond directly if appropriate. Format is Action:```$JSON_BLOB```then Observation:.
{{historic_messages}}
Question: {{query}}
{{agent_scratchpad}}
@@ -58,7 +56,7 @@ You have access to the following tools:
{{tools}}
Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
-Valid "action" values: {{tool_names}}. You must call "{{final_tool_name}}" to finish.
+Valid "action" values: "Final Answer" or {{tool_names}}
Provide only ONE action per $JSON_BLOB, as shown:
@@ -83,14 +81,12 @@ Thought: I know what to respond
Action:
```
{
- "action": "{{final_tool_name}}",
- "action_input": {
- "text": "Final response to human"
- }
+ "action": "Final Answer",
+ "action_input": "Final response to human"
}
```
-Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Do not respond with plain text. Format is Action:```$JSON_BLOB```then Observation:.
+Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary. Respond directly if appropriate. Format is Action:```$JSON_BLOB```then Observation:.
""" # noqa: E501
diff --git a/api/core/llm_generator/llm_generator.py b/api/core/llm_generator/llm_generator.py
index 5510aacf02..6cb4f450eb 100644
--- a/api/core/llm_generator/llm_generator.py
+++ b/api/core/llm_generator/llm_generator.py
@@ -477,6 +477,7 @@ class LLMGenerator:
prompt_messages=complete_messages,
output_model=CodeNodeStructuredOutput,
model_parameters=model_parameters,
+ tenant_id=tenant_id,
)
return {
@@ -557,14 +558,10 @@ class LLMGenerator:
completion_params = model_config.get("completion_params", {}) if model_config else {}
try:
- response = invoke_llm_with_pydantic_model(
- provider=model_instance.provider,
- model_schema=model_schema,
- model_instance=model_instance,
- prompt_messages=prompt_messages,
- output_model=SuggestedQuestionsOutput,
- model_parameters=completion_params,
- )
+            response = invoke_llm_with_pydantic_model(
+                provider=model_instance.provider,
+                model_schema=model_schema,
+                model_instance=model_instance,
+                prompt_messages=prompt_messages,
+                output_model=SuggestedQuestionsOutput,
+                model_parameters=completion_params,
+                tenant_id=tenant_id,
+            )
return {"questions": response.questions, "error": ""}
diff --git a/api/core/llm_generator/output_parser/file_ref.py b/api/core/llm_generator/output_parser/file_ref.py
index 74e47570c6..83489e6a79 100644
--- a/api/core/llm_generator/output_parser/file_ref.py
+++ b/api/core/llm_generator/output_parser/file_ref.py
@@ -1,190 +1,188 @@
-from collections.abc import Callable, Mapping, Sequence
-from typing import Any, cast
+"""
+File reference detection and conversion for structured output.
+
+This module provides utilities to:
+1. Detect file reference fields in JSON Schema (format: "dify-file-ref")
+2. Convert file ID strings to File objects after LLM returns
+"""
+
+import uuid
+from collections.abc import Mapping
+from typing import Any
from core.file import File
from core.variables.segments import ArrayFileSegment, FileSegment
+from factories.file_factory import build_from_mapping
-FILE_PATH_SCHEMA_TYPE = "file"
-FILE_PATH_SCHEMA_FORMATS = {"file", "file-ref", "dify-file-ref"}
-FILE_PATH_DESCRIPTION_SUFFIX = "Sandbox file path (relative paths supported)."
+FILE_REF_FORMAT = "dify-file-ref"
-def is_file_path_property(schema: Mapping[str, Any]) -> bool:
- if schema.get("type") == FILE_PATH_SCHEMA_TYPE:
- return True
- format_value = schema.get("format")
- if not isinstance(format_value, str):
- return False
- normalized_format = format_value.lower().replace("_", "-")
- return normalized_format in FILE_PATH_SCHEMA_FORMATS
+def is_file_ref_property(schema: dict) -> bool:
+ """Check if a schema property is a file reference."""
+ return schema.get("type") == "string" and schema.get("format") == FILE_REF_FORMAT
-def detect_file_path_fields(schema: Mapping[str, Any], path: str = "") -> list[str]:
- file_path_fields: list[str] = []
+def detect_file_ref_fields(schema: Mapping[str, Any], path: str = "") -> list[str]:
+ """
+ Recursively detect file reference fields in schema.
+
+ Args:
+ schema: JSON Schema to analyze
+ path: Current path in the schema (used for recursion)
+
+ Returns:
+ List of JSON paths containing file refs, e.g., ["image_id", "files[*]"]
+ """
+ file_ref_paths: list[str] = []
schema_type = schema.get("type")
if schema_type == "object":
- properties = schema.get("properties")
- if isinstance(properties, Mapping):
- properties_mapping = cast(Mapping[str, Any], properties)
- for prop_name, prop_schema in properties_mapping.items():
- if not isinstance(prop_schema, Mapping):
- continue
- prop_schema_mapping = cast(Mapping[str, Any], prop_schema)
- current_path = f"{path}.{prop_name}" if path else prop_name
+ for prop_name, prop_schema in schema.get("properties", {}).items():
+ current_path = f"{path}.{prop_name}" if path else prop_name
- if is_file_path_property(prop_schema_mapping):
- file_path_fields.append(current_path)
- else:
- file_path_fields.extend(detect_file_path_fields(prop_schema_mapping, current_path))
+ if is_file_ref_property(prop_schema):
+ file_ref_paths.append(current_path)
+ elif isinstance(prop_schema, dict):
+ file_ref_paths.extend(detect_file_ref_fields(prop_schema, current_path))
elif schema_type == "array":
- items_schema = schema.get("items")
- if not isinstance(items_schema, Mapping):
- return file_path_fields
- items_schema_mapping = cast(Mapping[str, Any], items_schema)
+ items_schema = schema.get("items", {})
array_path = f"{path}[*]" if path else "[*]"
- if is_file_path_property(items_schema_mapping):
- file_path_fields.append(array_path)
- else:
- file_path_fields.extend(detect_file_path_fields(items_schema_mapping, array_path))
+ if is_file_ref_property(items_schema):
+ file_ref_paths.append(array_path)
+ elif isinstance(items_schema, dict):
+ file_ref_paths.extend(detect_file_ref_fields(items_schema, array_path))
- return file_path_fields
+ return file_ref_paths
-def adapt_schema_for_sandbox_file_paths(schema: Mapping[str, Any]) -> tuple[dict[str, Any], list[str]]:
- result = _deep_copy_value(schema)
- if not isinstance(result, dict):
- raise ValueError("structured_output_schema must be a JSON object")
- result_dict = cast(dict[str, Any], result)
-
- file_path_fields: list[str] = []
- _adapt_schema_in_place(result_dict, path="", file_path_fields=file_path_fields)
- return result_dict, file_path_fields
-
-
-def convert_sandbox_file_paths_in_output(
+def convert_file_refs_in_output(
output: Mapping[str, Any],
- file_path_fields: Sequence[str],
- file_resolver: Callable[[str], File],
-) -> tuple[dict[str, Any], list[File]]:
- if not file_path_fields:
- return dict(output), []
+ json_schema: Mapping[str, Any],
+ tenant_id: str,
+) -> dict[str, Any]:
+ """
+ Convert file ID strings to File objects based on schema.
- result = _deep_copy_value(output)
- if not isinstance(result, dict):
- raise ValueError("Structured output must be a JSON object")
- result_dict = cast(dict[str, Any], result)
+ Args:
+ output: The structured_output from LLM result
+ json_schema: The original JSON schema (to detect file ref fields)
+ tenant_id: Tenant ID for file lookup
- files: list[File] = []
- for path in file_path_fields:
- _convert_path_in_place(result_dict, path.split("."), file_resolver, files)
+ Returns:
+ Output with file references converted to File objects
+ """
+ file_ref_paths = detect_file_ref_fields(json_schema)
+ if not file_ref_paths:
+ return dict(output)
- return result_dict, files
+ result = _deep_copy_dict(output)
+
+ for path in file_ref_paths:
+ _convert_path_in_place(result, path.split("."), tenant_id)
+
+ return result
-def _adapt_schema_in_place(schema: dict[str, Any], path: str, file_path_fields: list[str]) -> None:
- schema_type = schema.get("type")
-
- if schema_type == "object":
- properties = schema.get("properties")
- if isinstance(properties, Mapping):
- properties_mapping = cast(Mapping[str, Any], properties)
- for prop_name, prop_schema in properties_mapping.items():
- if not isinstance(prop_schema, dict):
- continue
- prop_schema_dict = cast(dict[str, Any], prop_schema)
- current_path = f"{path}.{prop_name}" if path else prop_name
-
- if is_file_path_property(prop_schema_dict):
- _normalize_file_path_schema(prop_schema_dict)
- file_path_fields.append(current_path)
- else:
- _adapt_schema_in_place(prop_schema_dict, current_path, file_path_fields)
-
- elif schema_type == "array":
- items_schema = schema.get("items")
- if not isinstance(items_schema, dict):
- return
- items_schema_dict = cast(dict[str, Any], items_schema)
- array_path = f"{path}[*]" if path else "[*]"
-
- if is_file_path_property(items_schema_dict):
- _normalize_file_path_schema(items_schema_dict)
- file_path_fields.append(array_path)
+def _deep_copy_dict(obj: Mapping[str, Any]) -> dict[str, Any]:
+ """Deep copy a mapping to a mutable dict."""
+ result: dict[str, Any] = {}
+ for key, value in obj.items():
+ if isinstance(value, Mapping):
+ result[key] = _deep_copy_dict(value)
+ elif isinstance(value, list):
+ result[key] = [_deep_copy_dict(item) if isinstance(item, Mapping) else item for item in value]
else:
- _adapt_schema_in_place(items_schema_dict, array_path, file_path_fields)
+ result[key] = value
+ return result
-def _normalize_file_path_schema(schema: dict[str, Any]) -> None:
- schema["type"] = "string"
- schema.pop("format", None)
- description = schema.get("description", "")
- if description:
- schema["description"] = f"{description}\n{FILE_PATH_DESCRIPTION_SUFFIX}"
- else:
- schema["description"] = FILE_PATH_DESCRIPTION_SUFFIX
-
-
-def _deep_copy_value(value: Any) -> Any:
- if isinstance(value, Mapping):
- mapping = cast(Mapping[str, Any], value)
- return {key: _deep_copy_value(item) for key, item in mapping.items()}
- if isinstance(value, list):
- list_value = cast(list[Any], value)
- return [_deep_copy_value(item) for item in list_value]
- return value
-
-
-def _convert_path_in_place(
- obj: dict[str, Any],
- path_parts: list[str],
- file_resolver: Callable[[str], File],
- files: list[File],
-) -> None:
+def _convert_path_in_place(obj: dict, path_parts: list[str], tenant_id: str) -> None:
+ """Convert file refs at the given path in place, wrapping in Segment types."""
if not path_parts:
return
current = path_parts[0]
remaining = path_parts[1:]
+ # Handle array notation like "files[*]"
if current.endswith("[*]"):
- key = current[:-3] if current != "[*]" else ""
- target_value = obj.get(key) if key else obj
+ key = current[:-3] if current != "[*]" else None
+ target = obj.get(key) if key else obj
- if isinstance(target_value, list):
- target_list = cast(list[Any], target_value)
+ if isinstance(target, list):
if remaining:
- for item in target_list:
+ # Nested array with remaining path - recurse into each item
+ for item in target:
if isinstance(item, dict):
- item_dict = cast(dict[str, Any], item)
- _convert_path_in_place(item_dict, remaining, file_resolver, files)
+ _convert_path_in_place(item, remaining, tenant_id)
else:
- resolved_files: list[File] = []
- for item in target_list:
- if not isinstance(item, str):
- raise ValueError("File path must be a string")
- file = file_resolver(item)
- files.append(file)
- resolved_files.append(file)
+ # Array of file IDs - convert all and wrap in ArrayFileSegment
+ files: list[File] = []
+ for item in target:
+ file = _convert_file_id(item, tenant_id)
+ if file is not None:
+ files.append(file)
+ # Replace the array with ArrayFileSegment
if key:
- obj[key] = ArrayFileSegment(value=resolved_files)
+ obj[key] = ArrayFileSegment(value=files)
return
if not remaining:
- if current not in obj:
- return
- value = obj[current]
- if value is None:
- obj[current] = None
- return
- if not isinstance(value, str):
- raise ValueError("File path must be a string")
- file = file_resolver(value)
- files.append(file)
- obj[current] = FileSegment(value=file)
- return
+ # Leaf node - convert the value and wrap in FileSegment
+ if current in obj:
+ file = _convert_file_id(obj[current], tenant_id)
+ if file is not None:
+ obj[current] = FileSegment(value=file)
+ else:
+ obj[current] = None
+ else:
+ # Recurse into nested object
+ if current in obj and isinstance(obj[current], dict):
+ _convert_path_in_place(obj[current], remaining, tenant_id)
- if current in obj and isinstance(obj[current], dict):
- _convert_path_in_place(obj[current], remaining, file_resolver, files)
+
+def _convert_file_id(file_id: Any, tenant_id: str) -> File | None:
+ """
+ Convert a file ID string to a File object.
+
+ Tries multiple file sources in order:
+ 1. ToolFile (files generated by tools/workflows)
+ 2. UploadFile (files uploaded by users)
+ """
+ if not isinstance(file_id, str):
+ return None
+
+ # Validate UUID format
+ try:
+ uuid.UUID(file_id)
+ except ValueError:
+ return None
+
+ # Try ToolFile first (files generated by tools/workflows)
+ try:
+ return build_from_mapping(
+ mapping={
+ "transfer_method": "tool_file",
+ "tool_file_id": file_id,
+ },
+ tenant_id=tenant_id,
+ )
+ except ValueError:
+ pass
+
+ # Try UploadFile (files uploaded by users)
+ try:
+ return build_from_mapping(
+ mapping={
+ "transfer_method": "local_file",
+ "upload_file_id": file_id,
+ },
+ tenant_id=tenant_id,
+ )
+ except ValueError:
+ pass
+
+ # File not found in any source
+ return None
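To make the schema walk above concrete: it flags any property whose schema is `{"type": "string", "format": "dify-file-ref"}` and records its path, using `[*]` for array items. A self-contained sketch of that detection, re-implemented locally so it runs without the project imports:

```python
# Sketch: standalone copy of the file-reference detection logic, applied to a
# small schema; mirrors the docstring example ["image_id", "files[*]"].
from collections.abc import Mapping
from typing import Any

FILE_REF_FORMAT = "dify-file-ref"


def is_file_ref(prop: Mapping[str, Any]) -> bool:
    return prop.get("type") == "string" and prop.get("format") == FILE_REF_FORMAT


def detect(schema: Mapping[str, Any], path: str = "") -> list[str]:
    found: list[str] = []
    if schema.get("type") == "object":
        for name, prop in schema.get("properties", {}).items():
            sub = f"{path}.{name}" if path else name
            found += [sub] if is_file_ref(prop) else detect(prop, sub)
    elif schema.get("type") == "array":
        items = schema.get("items", {})
        sub = f"{path}[*]" if path else "[*]"
        found += [sub] if is_file_ref(items) else detect(items, sub)
    return found


schema = {
    "type": "object",
    "properties": {
        "image_id": {"type": "string", "format": "dify-file-ref"},
        "files": {"type": "array", "items": {"type": "string", "format": "dify-file-ref"}},
        "title": {"type": "string"},
    },
}
print(detect(schema))  # ['image_id', 'files[*]']
```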
diff --git a/api/core/llm_generator/output_parser/structured_output.py b/api/core/llm_generator/output_parser/structured_output.py
index c790feda53..a069f0409c 100644
--- a/api/core/llm_generator/output_parser/structured_output.py
+++ b/api/core/llm_generator/output_parser/structured_output.py
@@ -8,7 +8,7 @@ import json_repair
from pydantic import BaseModel, TypeAdapter, ValidationError
from core.llm_generator.output_parser.errors import OutputParserError
-from core.llm_generator.output_parser.file_ref import detect_file_path_fields
+from core.llm_generator.output_parser.file_ref import convert_file_refs_in_output
from core.llm_generator.prompts import STRUCTURED_OUTPUT_PROMPT
from core.model_manager import ModelInstance
from core.model_runtime.callbacks.base_callback import Callback
@@ -55,11 +55,12 @@ def invoke_llm_with_structured_output(
model_instance: ModelInstance,
prompt_messages: Sequence[PromptMessage],
json_schema: Mapping[str, Any],
- model_parameters: Mapping[str, Any] | None = None,
+ model_parameters: Mapping | None = None,
tools: Sequence[PromptMessageTool] | None = None,
stop: list[str] | None = None,
user: str | None = None,
callbacks: list[Callback] | None = None,
+ tenant_id: str | None = None,
) -> LLMResultWithStructuredOutput:
"""
Invoke large language model with structured output.
@@ -77,12 +78,14 @@ def invoke_llm_with_structured_output(
:param stop: stop words
:param user: unique user id
:param callbacks: callbacks
- :return: response with structured output
+ :param tenant_id: tenant ID for file reference conversion. When provided and
+ json_schema contains file reference fields (format: "dify-file-ref"),
+ file IDs in the output will be automatically converted to File objects.
+    :return: response with structured output
"""
- model_parameters_with_json_schema: dict[str, Any] = dict(model_parameters or {})
-
- if detect_file_path_fields(json_schema):
- raise OutputParserError("Structured output file paths are only supported in sandbox mode.")
+ model_parameters_with_json_schema: dict[str, Any] = {
+ **(model_parameters or {}),
+ }
# Determine structured output strategy
@@ -119,6 +122,14 @@ def invoke_llm_with_structured_output(
# Fill missing fields with default values
structured_output = fill_defaults_from_schema(structured_output, json_schema)
+ # Convert file references if tenant_id is provided
+ if tenant_id is not None:
+ structured_output = convert_file_refs_in_output(
+ output=structured_output,
+ json_schema=json_schema,
+ tenant_id=tenant_id,
+ )
+
return LLMResultWithStructuredOutput(
structured_output=structured_output,
model=llm_result.model,
@@ -136,11 +147,12 @@ def invoke_llm_with_pydantic_model(
model_instance: ModelInstance,
prompt_messages: Sequence[PromptMessage],
output_model: type[T],
- model_parameters: Mapping[str, Any] | None = None,
+ model_parameters: Mapping | None = None,
tools: Sequence[PromptMessageTool] | None = None,
stop: list[str] | None = None,
user: str | None = None,
callbacks: list[Callback] | None = None,
+ tenant_id: str | None = None,
) -> T:
"""
Invoke large language model with a Pydantic output model.
@@ -148,8 +160,11 @@ def invoke_llm_with_pydantic_model(
This helper generates a JSON schema from the Pydantic model, invokes the
structured-output LLM path, and validates the result.
- The helper performs a non-streaming invocation and returns the validated
- Pydantic model directly.
+ The stream parameter controls the underlying LLM invocation mode:
+ - stream=True (default): Uses streaming LLM call, consumes the generator internally
+ - stream=False: Uses non-streaming LLM call
+
+ In both cases, the function returns the validated Pydantic model directly.
"""
json_schema = _schema_from_pydantic(output_model)
@@ -164,6 +179,7 @@ def invoke_llm_with_pydantic_model(
stop=stop,
user=user,
callbacks=callbacks,
+ tenant_id=tenant_id,
)
structured_output = result.structured_output
@@ -220,27 +236,25 @@ def _extract_structured_output(llm_result: LLMResult) -> Mapping[str, Any]:
return _parse_structured_output(content)
-def _parse_tool_call_arguments(arguments: str) -> dict[str, Any]:
+def _parse_tool_call_arguments(arguments: str) -> Mapping[str, Any]:
"""Parse JSON from tool call arguments."""
if not arguments:
raise OutputParserError("Tool call arguments is empty")
try:
- parsed_any = json.loads(arguments)
- if not isinstance(parsed_any, dict):
+ parsed = json.loads(arguments)
+ if not isinstance(parsed, dict):
raise OutputParserError(f"Tool call arguments is not a dict: {arguments}")
- parsed = cast(dict[str, Any], parsed_any)
return parsed
except json.JSONDecodeError:
# Try to repair malformed JSON
- repaired_any = json_repair.loads(arguments)
- if not isinstance(repaired_any, dict):
+ repaired = json_repair.loads(arguments)
+ if not isinstance(repaired, dict):
raise OutputParserError(f"Failed to parse tool call arguments: {arguments}")
- repaired: dict[str, Any] = repaired_any
return repaired
-def get_default_value_for_type(type_name: str | list[str] | None) -> Any:
+def _get_default_value_for_type(type_name: str | list[str] | None) -> Any:
"""Get default empty value for a JSON schema type."""
# Handle array of types (e.g., ["string", "null"])
if isinstance(type_name, list):
@@ -297,7 +311,7 @@ def fill_defaults_from_schema(
# Create empty object and recursively fill its required fields
result[prop_name] = fill_defaults_from_schema({}, prop_schema)
else:
- result[prop_name] = get_default_value_for_type(prop_type)
+ result[prop_name] = _get_default_value_for_type(prop_type)
elif isinstance(result[prop_name], dict) and prop_type == "object" and "properties" in prop_schema:
# Field exists and is an object, recursively fill nested required fields
result[prop_name] = fill_defaults_from_schema(result[prop_name], prop_schema)
@@ -308,10 +322,10 @@ def fill_defaults_from_schema(
def _handle_native_json_schema(
provider: str,
model_schema: AIModelEntity,
- structured_output_schema: Mapping[str, Any],
- model_parameters: dict[str, Any],
+ structured_output_schema: Mapping,
+ model_parameters: dict,
rules: list[ParameterRule],
-) -> dict[str, Any]:
+):
"""
Handle structured output for models with native JSON schema support.
@@ -333,7 +347,7 @@ def _handle_native_json_schema(
return model_parameters
-def _set_response_format(model_parameters: dict[str, Any], rules: list[ParameterRule]) -> None:
+def _set_response_format(model_parameters: dict, rules: list):
"""
Set the appropriate response format parameter based on model rules.
@@ -349,7 +363,7 @@ def _set_response_format(model_parameters: dict[str, Any], rules: list[Parameter
def _handle_prompt_based_schema(
- prompt_messages: Sequence[PromptMessage], structured_output_schema: Mapping[str, Any]
+ prompt_messages: Sequence[PromptMessage], structured_output_schema: Mapping
) -> list[PromptMessage]:
"""
Handle structured output for models without native JSON schema support.
@@ -386,27 +400,28 @@ def _handle_prompt_based_schema(
return updated_prompt
-def _parse_structured_output(result_text: str) -> dict[str, Any]:
+def _parse_structured_output(result_text: str) -> Mapping[str, Any]:
+ structured_output: Mapping[str, Any] = {}
+ parsed: Mapping[str, Any] = {}
try:
- parsed = TypeAdapter(dict[str, Any]).validate_json(result_text)
- return parsed
+ parsed = TypeAdapter(Mapping).validate_json(result_text)
+ if not isinstance(parsed, dict):
+ raise OutputParserError(f"Failed to parse structured output: {result_text}")
+ structured_output = parsed
except ValidationError:
# if the result_text is not a valid json, try to repair it
- temp_parsed: Any = json_repair.loads(result_text)
- if isinstance(temp_parsed, list):
- temp_parsed_list = cast(list[Any], temp_parsed)
- dict_items: list[dict[str, Any]] = []
- for item in temp_parsed_list:
- if isinstance(item, dict):
- dict_items.append(cast(dict[str, Any], item))
- temp_parsed = dict_items[0] if dict_items else {}
+ temp_parsed = json_repair.loads(result_text)
if not isinstance(temp_parsed, dict):
- raise OutputParserError(f"Failed to parse structured output: {result_text}")
- temp_parsed_dict = cast(dict[str, Any], temp_parsed)
- return temp_parsed_dict
+            # handle reasoning models like deepseek-r1 that prefix output with '\n\n\n'
+ if isinstance(temp_parsed, list):
+ temp_parsed = next((item for item in temp_parsed if isinstance(item, dict)), {})
+ else:
+ raise OutputParserError(f"Failed to parse structured output: {result_text}")
+ structured_output = cast(dict, temp_parsed)
+ return structured_output
-def _prepare_schema_for_model(provider: str, model_schema: AIModelEntity, schema: Mapping[str, Any]) -> dict[str, Any]:
+def _prepare_schema_for_model(provider: str, model_schema: AIModelEntity, schema: Mapping):
"""
Prepare JSON schema based on model requirements.
@@ -418,49 +433,54 @@ def _prepare_schema_for_model(provider: str, model_schema: AIModelEntity, schema
"""
# Deep copy to avoid modifying the original schema
- processed_schema = deepcopy(schema)
- processed_schema_dict = dict(processed_schema)
+ processed_schema = dict(deepcopy(schema))
# Convert boolean types to string types (common requirement)
- convert_boolean_to_string(processed_schema_dict)
+ convert_boolean_to_string(processed_schema)
# Apply model-specific transformations
if SpecialModelType.GEMINI in model_schema.model:
- remove_additional_properties(processed_schema_dict)
- return processed_schema_dict
- if SpecialModelType.OLLAMA in provider:
- return processed_schema_dict
-
- # Default format with name field
- return {"schema": processed_schema_dict, "name": "llm_response"}
+ remove_additional_properties(processed_schema)
+ return processed_schema
+ elif SpecialModelType.OLLAMA in provider:
+ return processed_schema
+ else:
+ # Default format with name field
+ return {"schema": processed_schema, "name": "llm_response"}
-def remove_additional_properties(schema: dict[str, Any]) -> None:
+def remove_additional_properties(schema: dict):
"""
Remove additionalProperties fields from JSON schema.
Used for models like Gemini that don't support this property.
:param schema: JSON schema to modify in-place
"""
+ if not isinstance(schema, dict):
+ return
+
# Remove additionalProperties at current level
schema.pop("additionalProperties", None)
# Process nested structures recursively
for value in schema.values():
if isinstance(value, dict):
- remove_additional_properties(cast(dict[str, Any], value))
+ remove_additional_properties(value)
elif isinstance(value, list):
- for item in cast(list[Any], value):
+ for item in value:
if isinstance(item, dict):
- remove_additional_properties(cast(dict[str, Any], item))
+ remove_additional_properties(item)
-def convert_boolean_to_string(schema: dict[str, Any]) -> None:
+def convert_boolean_to_string(schema: dict):
"""
Convert boolean type specifications to string in JSON schema.
:param schema: JSON schema to modify in-place
"""
+ if not isinstance(schema, dict):
+ return
+
# Check for boolean type at current level
if schema.get("type") == "boolean":
schema["type"] = "string"
@@ -468,8 +488,8 @@ def convert_boolean_to_string(schema: dict[str, Any]) -> None:
# Process nested dictionaries and lists recursively
for value in schema.values():
if isinstance(value, dict):
- convert_boolean_to_string(cast(dict[str, Any], value))
+ convert_boolean_to_string(value)
elif isinstance(value, list):
- for item in cast(list[Any], value):
+ for item in value:
if isinstance(item, dict):
- convert_boolean_to_string(cast(dict[str, Any], item))
+ convert_boolean_to_string(item)
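
The fallback path of `_parse_structured_output` above is easiest to read as a minimal, self-contained sketch. This simplification uses plain `json.loads` where the real function validates with pydantic's `TypeAdapter`; the repair-and-unwrap behaviour for reasoning models is the same.

```python
import json

import json_repair  # same third-party dependency the parser already relies on


def parse_structured_output_sketch(result_text: str) -> dict:
    """Simplified sketch of the fallback strategy in _parse_structured_output."""
    try:
        # Happy path: the model returned valid JSON that is already an object.
        parsed = json.loads(result_text)
        if not isinstance(parsed, dict):
            raise ValueError(f"Failed to parse structured output: {result_text}")
        return parsed
    except json.JSONDecodeError:
        # Fallback: let json_repair salvage malformed JSON. Reasoning models
        # (e.g. deepseek-r1) sometimes wrap the object in a list, so take the
        # first dict element in that case.
        repaired = json_repair.loads(result_text)
        if isinstance(repaired, list):
            repaired = next((item for item in repaired if isinstance(item, dict)), {})
        if not isinstance(repaired, dict):
            raise ValueError(f"Failed to parse structured output: {result_text}")
        return repaired
```
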
diff --git a/api/core/plugin/utils/chunk_merger.py b/api/core/plugin/utils/chunk_merger.py
index 6a266d1d9d..28cb70f96a 100644
--- a/api/core/plugin/utils/chunk_merger.py
+++ b/api/core/plugin/utils/chunk_merger.py
@@ -1,13 +1,11 @@
from collections.abc import Generator
from dataclasses import dataclass, field
-from typing import TYPE_CHECKING, TypeVar
+from typing import TypeVar, Union
+from core.agent.entities import AgentInvokeMessage
from core.tools.entities.tool_entities import ToolInvokeMessage
-if TYPE_CHECKING:
- pass
-
-MessageType = TypeVar("MessageType", bound=ToolInvokeMessage)
+MessageType = TypeVar("MessageType", bound=Union[ToolInvokeMessage, AgentInvokeMessage])
@dataclass
@@ -89,7 +87,7 @@ def merge_blob_chunks(
),
meta=resp.meta,
)
- assert isinstance(merged_message, ToolInvokeMessage)
+ assert isinstance(merged_message, (ToolInvokeMessage, AgentInvokeMessage))
yield merged_message # type: ignore
# Clean up the buffer
del files[chunk_id]
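
The widened `MessageType` bound lets `merge_blob_chunks` reassemble blob chunks for both tool and agent messages. Stripped of the message classes, the buffering pattern looks roughly like this hypothetical sketch (`BlobChunk` and the byte payloads are stand-ins, not the real entities):

```python
from collections.abc import Generator, Iterable
from dataclasses import dataclass, field


@dataclass
class BlobChunk:
    # Hypothetical stand-in for the chunked blob messages merge_blob_chunks consumes.
    chunk_id: str
    data: bytes
    is_final: bool = False


@dataclass
class _Buffer:
    parts: list[bytes] = field(default_factory=list)


def merge_chunks(chunks: Iterable[BlobChunk]) -> Generator[bytes, None, None]:
    # Buffer chunks per chunk_id and emit the reassembled blob once the final
    # chunk arrives, then drop the buffer -- the same shape as merge_blob_chunks.
    buffers: dict[str, _Buffer] = {}
    for chunk in chunks:
        buffer = buffers.setdefault(chunk.chunk_id, _Buffer())
        buffer.parts.append(chunk.data)
        if chunk.is_final:
            yield b"".join(buffer.parts)
            del buffers[chunk.chunk_id]


# Example: two interleaved blobs are reassembled independently.
chunks = [
    BlobChunk("a", b"he"),
    BlobChunk("b", b"wo"),
    BlobChunk("a", b"llo", is_final=True),
    BlobChunk("b", b"rld", is_final=True),
]
assert list(merge_chunks(chunks)) == [b"hello", b"world"]
```
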
diff --git a/api/core/sandbox/bash/session.py b/api/core/sandbox/bash/session.py
index 60691790a2..57e790b2c2 100644
--- a/api/core/sandbox/bash/session.py
+++ b/api/core/sandbox/bash/session.py
@@ -14,8 +14,7 @@ from core.skill.entities import ToolAccessPolicy
from core.skill.entities.tool_dependencies import ToolDependencies
from core.tools.signature import sign_tool_file
from core.tools.tool_file_manager import ToolFileManager
-from core.virtual_environment.__base.exec import CommandExecutionError
-from core.virtual_environment.__base.helpers import execute, pipeline
+from core.virtual_environment.__base.helpers import pipeline
from ..bash.dify_cli import DifyCliConfig
from ..entities import DifyCli
@@ -120,6 +119,21 @@ class SandboxBashSession:
return self._bash_tool
def collect_output_files(self, output_dir: str = SANDBOX_OUTPUT_DIR) -> list[File]:
+ """
+ Collect files from sandbox output directory and save them as ToolFiles.
+
+        Scans the specified output directory in the sandbox, downloads each file,
+        saves it as a ToolFile, and returns a list of File objects. Each File
+        object carries a valid tool_file_id that can be referenced by subsequent
+        nodes via structured output.
+
+ Args:
+ output_dir: Directory path in sandbox to scan for output files.
+ Defaults to "output" (relative to workspace).
+
+ Returns:
+ List of File objects representing the collected files.
+ """
vm = self._sandbox.vm
collected_files: list[File] = []
@@ -130,6 +144,8 @@ class SandboxBashSession:
logger.debug("Failed to list sandbox output files in %s: %s", output_dir, exc)
return collected_files
+ tool_file_manager = ToolFileManager()
+
for file_state in file_states:
# Skip files that are too large
if file_state.size > MAX_OUTPUT_FILE_SIZE:
@@ -146,14 +162,47 @@ class SandboxBashSession:
file_content = vm.download_file(file_state.path)
file_binary = file_content.getvalue()
+ # Determine mime type from extension
filename = os.path.basename(file_state.path)
- file_obj = self._create_tool_file(filename=filename, file_binary=file_binary)
+ mime_type, _ = mimetypes.guess_type(filename)
+ if not mime_type:
+ mime_type = "application/octet-stream"
+
+ # Save as ToolFile
+ tool_file = tool_file_manager.create_file_by_raw(
+ user_id=self._user_id,
+ tenant_id=self._tenant_id,
+ conversation_id=None,
+ file_binary=file_binary,
+ mimetype=mime_type,
+ filename=filename,
+ )
+
+ # Determine file type from mime type
+ file_type = _get_file_type_from_mime(mime_type)
+ extension = os.path.splitext(filename)[1] if "." in filename else ".bin"
+ url = sign_tool_file(tool_file.id, extension)
+
+ # Create File object with tool_file_id as related_id
+ file_obj = File(
+ id=tool_file.id, # Use tool_file_id as the File id for easy reference
+ tenant_id=self._tenant_id,
+ type=file_type,
+ transfer_method=FileTransferMethod.TOOL_FILE,
+ filename=filename,
+ extension=extension,
+ mime_type=mime_type,
+ size=len(file_binary),
+ related_id=tool_file.id,
+ url=url,
+ storage_key=tool_file.file_key,
+ )
collected_files.append(file_obj)
logger.info(
"Collected sandbox output file: %s -> tool_file_id=%s",
file_state.path,
- file_obj.id,
+ tool_file.id,
)
except Exception as exc:
@@ -167,85 +216,6 @@ class SandboxBashSession:
)
return collected_files
- def download_file(self, path: str) -> File:
- path_kind = self._detect_path_kind(path)
- if path_kind == "dir":
- raise ValueError("Directory outputs are not supported")
- if path_kind != "file":
- raise ValueError(f"Sandbox file not found: {path}")
-
- file_content = self._sandbox.vm.download_file(path)
- file_binary = file_content.getvalue()
- if len(file_binary) > MAX_OUTPUT_FILE_SIZE:
- raise ValueError(f"Sandbox file exceeds size limit: {path}")
-
- filename = os.path.basename(path) or "file"
- return self._create_tool_file(filename=filename, file_binary=file_binary)
-
- def _detect_path_kind(self, path: str) -> str:
- script = r"""
-import os
-import sys
-
-p = sys.argv[1]
-if os.path.isdir(p):
- print("dir")
- raise SystemExit(0)
-if os.path.isfile(p):
- print("file")
- raise SystemExit(0)
-print("none")
-raise SystemExit(2)
-"""
- try:
- result = execute(
- self._sandbox.vm,
- [
- "sh",
- "-c",
- 'if command -v python3 >/dev/null 2>&1; then py=python3; else py=python; fi; "$py" -c "$0" "$@"',
- script,
- path,
- ],
- timeout=10,
- error_message="Failed to inspect sandbox path",
- )
- except CommandExecutionError as exc:
- raise ValueError(str(exc)) from exc
- return result.stdout.decode("utf-8", errors="replace").strip()
-
- def _create_tool_file(self, *, filename: str, file_binary: bytes) -> File:
- mime_type, _ = mimetypes.guess_type(filename)
- if not mime_type:
- mime_type = "application/octet-stream"
-
- tool_file = ToolFileManager().create_file_by_raw(
- user_id=self._user_id,
- tenant_id=self._tenant_id,
- conversation_id=None,
- file_binary=file_binary,
- mimetype=mime_type,
- filename=filename,
- )
-
- file_type = _get_file_type_from_mime(mime_type)
- extension = os.path.splitext(filename)[1] if "." in filename else ".bin"
- url = sign_tool_file(tool_file.id, extension)
-
- return File(
- id=tool_file.id,
- tenant_id=self._tenant_id,
- type=file_type,
- transfer_method=FileTransferMethod.TOOL_FILE,
- filename=filename,
- extension=extension,
- mime_type=mime_type,
- size=len(file_binary),
- related_id=tool_file.id,
- url=url,
- storage_key=tool_file.file_key,
- )
-
def _get_file_type_from_mime(mime_type: str) -> FileType:
"""Determine FileType from mime type."""
diff --git a/api/core/tools/__base/tool.py b/api/core/tools/__base/tool.py
index 9c0e8d0963..24fc11aefc 100644
--- a/api/core/tools/__base/tool.py
+++ b/api/core/tools/__base/tool.py
@@ -57,7 +57,7 @@ class Tool(ABC):
tool_parameters.update(self.runtime.runtime_parameters)
# try parse tool parameters into the correct type
- tool_parameters = self.transform_tool_parameters_type(tool_parameters)
+ tool_parameters = self._transform_tool_parameters_type(tool_parameters)
result = self._invoke(
user_id=user_id,
@@ -82,7 +82,7 @@ class Tool(ABC):
else:
return result
- def transform_tool_parameters_type(self, tool_parameters: dict[str, Any]) -> dict[str, Any]:
+ def _transform_tool_parameters_type(self, tool_parameters: dict[str, Any]) -> dict[str, Any]:
"""
Transform tool parameters type
"""
diff --git a/api/core/tools/builtin_tool/providers/agent_output/_assets/icon.svg b/api/core/tools/builtin_tool/providers/agent_output/_assets/icon.svg
deleted file mode 100644
index 82e3b54edd..0000000000
--- a/api/core/tools/builtin_tool/providers/agent_output/_assets/icon.svg
+++ /dev/null
@@ -1,5 +0,0 @@
-
diff --git a/api/core/tools/builtin_tool/providers/agent_output/agent_output.py b/api/core/tools/builtin_tool/providers/agent_output/agent_output.py
deleted file mode 100644
index 65c3eb5a1d..0000000000
--- a/api/core/tools/builtin_tool/providers/agent_output/agent_output.py
+++ /dev/null
@@ -1,8 +0,0 @@
-from typing import Any
-
-from core.tools.builtin_tool.provider import BuiltinToolProviderController
-
-
-class AgentOutputProvider(BuiltinToolProviderController):
- def _validate_credentials(self, user_id: str, credentials: dict[str, Any]):
- pass
diff --git a/api/core/tools/builtin_tool/providers/agent_output/agent_output.yaml b/api/core/tools/builtin_tool/providers/agent_output/agent_output.yaml
deleted file mode 100644
index 1e4f624104..0000000000
--- a/api/core/tools/builtin_tool/providers/agent_output/agent_output.yaml
+++ /dev/null
@@ -1,10 +0,0 @@
-identity:
- author: Dify
- name: agent_output
- label:
- en_US: Agent Output
- description:
- en_US: Internal tools for agent output control.
- icon: icon.svg
- tags:
- - utilities
diff --git a/api/core/tools/builtin_tool/providers/agent_output/tools/final_output_answer.py b/api/core/tools/builtin_tool/providers/agent_output/tools/final_output_answer.py
deleted file mode 100644
index ed5edc9fd3..0000000000
--- a/api/core/tools/builtin_tool/providers/agent_output/tools/final_output_answer.py
+++ /dev/null
@@ -1,17 +0,0 @@
-from collections.abc import Generator
-from typing import Any
-
-from core.tools.builtin_tool.tool import BuiltinTool
-from core.tools.entities.tool_entities import ToolInvokeMessage
-
-
-class FinalOutputAnswerTool(BuiltinTool):
- def _invoke(
- self,
- user_id: str,
- tool_parameters: dict[str, Any],
- conversation_id: str | None = None,
- app_id: str | None = None,
- message_id: str | None = None,
- ) -> Generator[ToolInvokeMessage, None, None]:
- yield self.create_text_message("Final answer recorded.")
diff --git a/api/core/tools/builtin_tool/providers/agent_output/tools/final_output_answer.yaml b/api/core/tools/builtin_tool/providers/agent_output/tools/final_output_answer.yaml
deleted file mode 100644
index b8e4f14d85..0000000000
--- a/api/core/tools/builtin_tool/providers/agent_output/tools/final_output_answer.yaml
+++ /dev/null
@@ -1,18 +0,0 @@
-identity:
- name: final_output_answer
- author: Dify
- label:
- en_US: Final Output Answer
-description:
- human:
- en_US: Internal tool to deliver the final answer.
- llm: Use this tool when you are ready to provide the final answer.
-parameters:
- - name: text
- type: string
- required: true
- label:
- en_US: Text
- human_description:
- en_US: Final answer text.
- form: llm
diff --git a/api/core/tools/builtin_tool/providers/agent_output/tools/final_structured_output.py b/api/core/tools/builtin_tool/providers/agent_output/tools/final_structured_output.py
deleted file mode 100644
index 916dc9f0bc..0000000000
--- a/api/core/tools/builtin_tool/providers/agent_output/tools/final_structured_output.py
+++ /dev/null
@@ -1,17 +0,0 @@
-from collections.abc import Generator
-from typing import Any
-
-from core.tools.builtin_tool.tool import BuiltinTool
-from core.tools.entities.tool_entities import ToolInvokeMessage
-
-
-class FinalStructuredOutputTool(BuiltinTool):
- def _invoke(
- self,
- user_id: str,
- tool_parameters: dict[str, Any],
- conversation_id: str | None = None,
- app_id: str | None = None,
- message_id: str | None = None,
- ) -> Generator[ToolInvokeMessage, None, None]:
- yield self.create_text_message("Structured output recorded.")
diff --git a/api/core/tools/builtin_tool/providers/agent_output/tools/final_structured_output.yaml b/api/core/tools/builtin_tool/providers/agent_output/tools/final_structured_output.yaml
deleted file mode 100644
index 79eb2d6412..0000000000
--- a/api/core/tools/builtin_tool/providers/agent_output/tools/final_structured_output.yaml
+++ /dev/null
@@ -1,18 +0,0 @@
-identity:
- name: final_structured_output
- author: Dify
- label:
- en_US: Final Structured Output
-description:
- human:
- en_US: Internal tool to deliver structured output.
- llm: Use this tool to provide structured output data.
-parameters:
- - name: data
- type: object
- required: true
- label:
- en_US: Data
- human_description:
- en_US: Structured output data.
- form: llm
diff --git a/api/core/tools/builtin_tool/providers/agent_output/tools/illegal_output.py b/api/core/tools/builtin_tool/providers/agent_output/tools/illegal_output.py
deleted file mode 100644
index 2276c527e9..0000000000
--- a/api/core/tools/builtin_tool/providers/agent_output/tools/illegal_output.py
+++ /dev/null
@@ -1,21 +0,0 @@
-from collections.abc import Generator
-from typing import Any
-
-from core.tools.builtin_tool.tool import BuiltinTool
-from core.tools.entities.tool_entities import ToolInvokeMessage
-
-
-class IllegalOutputTool(BuiltinTool):
- def _invoke(
- self,
- user_id: str,
- tool_parameters: dict[str, Any],
- conversation_id: str | None = None,
- app_id: str | None = None,
- message_id: str | None = None,
- ) -> Generator[ToolInvokeMessage, None, None]:
- message = (
- "Protocol violation: do not output plain text. "
- "Call an output tool and finish with the configured terminal tool."
- )
- yield self.create_text_message(message)
diff --git a/api/core/tools/builtin_tool/providers/agent_output/tools/illegal_output.yaml b/api/core/tools/builtin_tool/providers/agent_output/tools/illegal_output.yaml
deleted file mode 100644
index b293a41659..0000000000
--- a/api/core/tools/builtin_tool/providers/agent_output/tools/illegal_output.yaml
+++ /dev/null
@@ -1,18 +0,0 @@
-identity:
- name: illegal_output
- author: Dify
- label:
- en_US: Illegal Output
-description:
- human:
- en_US: Internal tool for output protocol violations.
- llm: Use this tool to correct output protocol violations.
-parameters:
- - name: raw
- type: string
- required: true
- label:
- en_US: Raw Output
- human_description:
- en_US: Raw model output that violated the protocol.
- form: llm
diff --git a/api/core/tools/builtin_tool/providers/agent_output/tools/output_text.py b/api/core/tools/builtin_tool/providers/agent_output/tools/output_text.py
deleted file mode 100644
index 942846f507..0000000000
--- a/api/core/tools/builtin_tool/providers/agent_output/tools/output_text.py
+++ /dev/null
@@ -1,17 +0,0 @@
-from collections.abc import Generator
-from typing import Any
-
-from core.tools.builtin_tool.tool import BuiltinTool
-from core.tools.entities.tool_entities import ToolInvokeMessage
-
-
-class OutputTextTool(BuiltinTool):
- def _invoke(
- self,
- user_id: str,
- tool_parameters: dict[str, Any],
- conversation_id: str | None = None,
- app_id: str | None = None,
- message_id: str | None = None,
- ) -> Generator[ToolInvokeMessage, None, None]:
- yield self.create_text_message("Output recorded.")
diff --git a/api/core/tools/builtin_tool/providers/agent_output/tools/output_text.yaml b/api/core/tools/builtin_tool/providers/agent_output/tools/output_text.yaml
deleted file mode 100644
index 48d237e521..0000000000
--- a/api/core/tools/builtin_tool/providers/agent_output/tools/output_text.yaml
+++ /dev/null
@@ -1,18 +0,0 @@
-identity:
- name: output_text
- author: Dify
- label:
- en_US: Output Text
-description:
- human:
- en_US: Internal tool to store intermediate text output.
- llm: Use this tool to emit non-final text output.
-parameters:
- - name: text
- type: string
- required: true
- label:
- en_US: Text
- human_description:
- en_US: Output text.
- form: llm
diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py
index b524de34e4..5dae773841 100644
--- a/api/core/tools/tool_manager.py
+++ b/api/core/tools/tool_manager.py
@@ -1,5 +1,3 @@
-from __future__ import annotations
-
import json
import logging
import mimetypes
@@ -33,8 +31,9 @@ from services.enterprise.plugin_manager_service import PluginCredentialType
from services.tools.mcp_tools_manage_service import MCPToolManageService
if TYPE_CHECKING:
- from core.agent.entities import AgentToolEntity
from core.workflow.nodes.tool.entities import ToolEntity
+
+from core.agent.entities import AgentToolEntity
from core.app.entities.app_invoke_entities import InvokeFrom
from core.helper.module_import_helper import load_single_subclass_from_source
from core.helper.position_helper import is_filtered
@@ -67,8 +66,6 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
-INTERNAL_BUILTIN_TOOL_PROVIDERS = {"agent_output"}
-
class ApiProviderControllerItem(TypedDict):
provider: ApiToolProvider
@@ -364,7 +361,7 @@ class ToolManager:
app_id: str,
agent_tool: AgentToolEntity,
invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
- variable_pool: Optional[VariablePool] = None,
+ variable_pool: Optional["VariablePool"] = None,
) -> Tool:
"""
get the agent tool runtime
@@ -404,9 +401,9 @@ class ToolManager:
tenant_id: str,
app_id: str,
node_id: str,
- workflow_tool: ToolEntity,
+ workflow_tool: "ToolEntity",
invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
- variable_pool: Optional[VariablePool] = None,
+ variable_pool: Optional["VariablePool"] = None,
) -> Tool:
"""
get the workflow tool runtime
@@ -594,10 +591,6 @@ class ToolManager:
cls._hardcoded_providers = {}
cls._builtin_providers_loaded = False
- @classmethod
- def is_internal_builtin_provider(cls, provider_name: str) -> bool:
- return provider_name in INTERNAL_BUILTIN_TOOL_PROVIDERS
-
@classmethod
def get_tool_label(cls, tool_name: str) -> Union[I18nObject, None]:
"""
@@ -635,9 +628,9 @@ class ToolManager:
# MySQL: Use window function to achieve same result
sql = """
SELECT id FROM (
- SELECT id,
+ SELECT id,
ROW_NUMBER() OVER (
- PARTITION BY tenant_id, provider
+ PARTITION BY tenant_id, provider
ORDER BY is_default DESC, created_at DESC
) as rn
FROM tool_builtin_providers
@@ -674,8 +667,6 @@ class ToolManager:
# append builtin providers
for provider in builtin_providers:
- if cls.is_internal_builtin_provider(provider.entity.identity.name):
- continue
# handle include, exclude
if is_filtered(
include_set=dify_config.POSITION_TOOL_INCLUDES_SET,
@@ -1019,7 +1010,7 @@ class ToolManager:
def _convert_tool_parameters_type(
cls,
parameters: list[ToolParameter],
- variable_pool: Optional[VariablePool],
+ variable_pool: Optional["VariablePool"],
tool_configurations: dict[str, Any],
typ: Literal["agent", "workflow", "tool"] = "workflow",
) -> dict[str, Any]:
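
Dropping `from __future__ import annotations` means `VariablePool` and `ToolEntity` must now be quoted wherever they remain under `TYPE_CHECKING`. A minimal sketch of that pattern (the import path is assumed for illustration):

```python
from typing import TYPE_CHECKING, Optional

if TYPE_CHECKING:
    # Seen only by type checkers; avoids importing the module at runtime
    # (and with it any circular-import risk).
    from core.workflow.entities.variable_pool import VariablePool  # assumed path


def get_tool_runtime(variable_pool: Optional["VariablePool"] = None) -> None:
    # The quoted annotation is resolved lazily, so this function can be
    # defined and called without the runtime import above.
    ...
```
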
diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py
index dded96eacf..0827494a48 100644
--- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py
+++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py
@@ -577,12 +577,12 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
prompt_messages=prompt_messages,
stop=stop,
user_id=self.user_id,
- structured_output_schema=None,
+ structured_output_enabled=self.node_data.structured_output_enabled,
+ structured_output=None,
file_saver=self._llm_file_saver,
file_outputs=self._file_outputs,
node_id=self._node_id,
node_type=self.node_type,
- tenant_id=self.tenant_id,
)
for event in generator:
diff --git a/api/core/workflow/nodes/llm/entities.py b/api/core/workflow/nodes/llm/entities.py
index b0eacd00ea..1a33d137d6 100644
--- a/api/core/workflow/nodes/llm/entities.py
+++ b/api/core/workflow/nodes/llm/entities.py
@@ -7,7 +7,7 @@ from pydantic import BaseModel, ConfigDict, Field, field_serializer, field_valid
from core.agent.entities import AgentLog, AgentResult
from core.file import File
from core.model_runtime.entities import ImagePromptMessageContent, LLMMode
-from core.model_runtime.entities.llm_entities import LLMStructuredOutput, LLMUsage
+from core.model_runtime.entities.llm_entities import LLMUsage
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
from core.tools.entities.tool_entities import ToolProviderType
from core.workflow.entities import ToolCall, ToolCallResult
@@ -156,9 +156,6 @@ class LLMGenerationData(BaseModel):
finish_reason: str | None = Field(None, description="Finish reason from LLM")
files: list[File] = Field(default_factory=list, description="Generated files")
trace: list[LLMTraceSegment] = Field(default_factory=list, description="Streaming trace in emitted order")
- structured_output: LLMStructuredOutput | None = Field(
- default=None, description="Structured output from tool-only agent runs"
- )
class ThinkTagStreamParser:
@@ -287,7 +284,6 @@ class AggregatedResult(BaseModel):
files: list[File] = Field(default_factory=list)
usage: LLMUsage = Field(default_factory=LLMUsage.empty_usage)
finish_reason: str | None = None
- structured_output: LLMStructuredOutput | None = None
class AgentContext(BaseModel):
@@ -387,7 +383,7 @@ class LLMNodeData(BaseNodeData):
Strategy for handling model reasoning output.
separated: Return clean text (without tags) + reasoning_content field.
- Recommended for new workflows. Enables safe downstream parsing and
+ Recommended for new workflows. Enables safe downstream parsing and
workflow variable access: {{#node_id.reasoning_content#}}
tagged : Return original text (with tags) + reasoning_content field.
diff --git a/api/core/workflow/nodes/llm/llm_utils.py b/api/core/workflow/nodes/llm/llm_utils.py
index b1460ae402..86bf8f473e 100644
--- a/api/core/workflow/nodes/llm/llm_utils.py
+++ b/api/core/workflow/nodes/llm/llm_utils.py
@@ -257,8 +257,8 @@ def _build_file_descriptions(files: Sequence[Any]) -> str:
"""
Build a text description of generated files for inclusion in context.
- The description includes file_id for context; structured output file paths
- are only supported in sandbox mode.
+    The description includes the file_id, which can be used by subsequent nodes
+    to reference the files via structured output.
"""
if not files:
return ""
diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py
index 82c2e1f1c1..feb76e6510 100644
--- a/api/core/workflow/nodes/llm/node.py
+++ b/api/core/workflow/nodes/llm/node.py
@@ -13,24 +13,13 @@ from typing import TYPE_CHECKING, Any, Literal, cast
from sqlalchemy import select
from core.agent.entities import AgentEntity, AgentLog, AgentResult, AgentToolEntity, ExecutionContext
-from core.agent.output_tools import (
- FINAL_OUTPUT_TOOL,
- FINAL_STRUCTURED_OUTPUT_TOOL,
- OUTPUT_TEXT_TOOL,
- build_agent_output_tools,
-)
from core.agent.patterns import StrategyFactory
from core.app.entities.app_asset_entities import AppAssetFileTree
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.app_assets.constants import AppAssetsAttrs
-from core.file import FileTransferMethod, FileType, file_manager
+from core.file import File, FileTransferMethod, FileType, file_manager
from core.helper.code_executor import CodeExecutor, CodeLanguage
from core.llm_generator.output_parser.errors import OutputParserError
-from core.llm_generator.output_parser.file_ref import (
- adapt_schema_for_sandbox_file_paths,
- convert_sandbox_file_paths_in_output,
- detect_file_path_fields,
-)
from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output
from core.memory.base import BaseMemory
from core.model_manager import ModelInstance, ModelManager
@@ -73,7 +62,6 @@ from core.skill.entities.skill_document import SkillDocument
from core.skill.entities.tool_dependencies import ToolDependencies, ToolDependency
from core.skill.skill_compiler import SkillCompiler
from core.tools.__base.tool import Tool
-from core.tools.entities.tool_entities import ToolInvokeFrom
from core.tools.signature import sign_upload_file
from core.tools.tool_manager import ToolManager
from core.variables import (
@@ -197,11 +185,12 @@ class LLMNode(Node[LLMNodeData]):
def _run(self) -> Generator:
node_inputs: dict[str, Any] = {}
process_data: dict[str, Any] = {}
- usage: LLMUsage = LLMUsage.empty_usage()
- finish_reason: str | None = None
- reasoning_content: str = "" # Initialize as empty string for consistency
+ clean_text = ""
+ usage = LLMUsage.empty_usage()
+ finish_reason = None
+ reasoning_content = "" # Initialize as empty string for consistency
clean_text = "" # Initialize clean_text to avoid UnboundLocalError
- variable_pool: VariablePool = self.graph_runtime_state.variable_pool
+ variable_pool = self.graph_runtime_state.variable_pool
try:
# Parse prompt template to separate static messages and context references
@@ -261,9 +250,8 @@ class LLMNode(Node[LLMNodeData]):
)
query: str | None = None
- memory_config = self.node_data.memory
- if memory_config:
- query = memory_config.query_prompt_template
+ if self.node_data.memory:
+ query = self.node_data.memory.query_prompt_template
if not query and (
query_variable := variable_pool.get((SYSTEM_VARIABLE_NODE_ID, SystemVariableKey.QUERY))
):
@@ -305,23 +293,9 @@ class LLMNode(Node[LLMNodeData]):
sandbox=self.graph_runtime_state.sandbox,
)
- structured_output_schema: Mapping[str, Any] | None
- structured_output_file_paths: list[str] = []
-
- if self.node_data.structured_output_enabled:
- if not self.node_data.structured_output:
- raise ValueError("structured_output_enabled is True but structured_output is not set")
- raw_schema = LLMNode.fetch_structured_output_schema(structured_output=self.node_data.structured_output)
- if self.node_data.computer_use:
- structured_output_schema, structured_output_file_paths = adapt_schema_for_sandbox_file_paths(
- raw_schema
- )
- else:
- if detect_file_path_fields(raw_schema):
- raise LLMNodeError("Structured output file paths are only supported in sandbox mode.")
- structured_output_schema = raw_schema
- else:
- structured_output_schema = None
+ # Variables for outputs
+ generation_data: LLMGenerationData | None = None
+ structured_output: LLMStructuredOutput | None = None
if self.node_data.computer_use:
sandbox = self.graph_runtime_state.sandbox
@@ -335,10 +309,7 @@ class LLMNode(Node[LLMNodeData]):
stop=stop,
variable_pool=variable_pool,
tool_dependencies=tool_dependencies,
- structured_output_schema=structured_output_schema,
- structured_output_file_paths=structured_output_file_paths,
)
-
elif self.tool_call_enabled:
generator = self._invoke_llm_with_tools(
model_instance=model_instance,
@@ -348,7 +319,6 @@ class LLMNode(Node[LLMNodeData]):
variable_pool=variable_pool,
node_inputs=node_inputs,
process_data=process_data,
- structured_output_schema=structured_output_schema,
)
else:
# Use traditional LLM invocation
@@ -358,7 +328,8 @@ class LLMNode(Node[LLMNodeData]):
prompt_messages=prompt_messages,
stop=stop,
user_id=self.user_id,
- structured_output_schema=structured_output_schema,
+ structured_output_enabled=self._node_data.structured_output_enabled,
+ structured_output=self._node_data.structured_output,
file_saver=self._llm_file_saver,
file_outputs=self._file_outputs,
node_id=self._node_id,
@@ -384,8 +355,6 @@ class LLMNode(Node[LLMNodeData]):
reasoning_content = ""
usage = generation_data.usage
finish_reason = generation_data.finish_reason
- if generation_data.structured_output:
- structured_output = generation_data.structured_output
# Unified process_data building
process_data = {
@@ -409,7 +378,16 @@ class LLMNode(Node[LLMNodeData]):
if tool.enabled
]
- # Build generation field and determine files_to_output first
+ # Unified outputs building
+ outputs = {
+ "text": clean_text,
+ "reasoning_content": reasoning_content,
+ "usage": jsonable_encoder(usage),
+ "finish_reason": finish_reason,
+ "context": llm_utils.build_context(prompt_messages, clean_text, generation_data),
+ }
+
+ # Build generation field
if generation_data:
# Use generation_data from tool invocation (supports multi-turn)
generation = {
@@ -437,15 +415,6 @@ class LLMNode(Node[LLMNodeData]):
}
files_to_output = self._file_outputs
- # Unified outputs building (files passed to context for subsequent node reference)
- outputs = {
- "text": clean_text,
- "reasoning_content": reasoning_content,
- "usage": jsonable_encoder(usage),
- "finish_reason": finish_reason,
- "context": llm_utils.build_context(prompt_messages, clean_text, generation_data, files=files_to_output),
- }
-
outputs["generation"] = generation
if files_to_output:
outputs["files"] = ArrayFileSegment(value=files_to_output)
@@ -524,7 +493,8 @@ class LLMNode(Node[LLMNodeData]):
prompt_messages: Sequence[PromptMessage],
stop: Sequence[str] | None = None,
user_id: str,
- structured_output_schema: Mapping[str, Any] | None,
+ structured_output_enabled: bool,
+ structured_output: Mapping[str, Any] | None = None,
file_saver: LLMFileSaver,
file_outputs: list[File],
node_id: str,
@@ -538,7 +508,10 @@ class LLMNode(Node[LLMNodeData]):
if not model_schema:
raise ValueError(f"Model schema not found for {node_data_model.name}")
- if structured_output_schema:
+ if structured_output_enabled:
+ output_schema = LLMNode.fetch_structured_output_schema(
+ structured_output=structured_output or {},
+ )
request_start_time = time.perf_counter()
invoke_result = invoke_llm_with_structured_output(
@@ -546,12 +519,12 @@ class LLMNode(Node[LLMNodeData]):
model_schema=model_schema,
model_instance=model_instance,
prompt_messages=prompt_messages,
- json_schema=structured_output_schema,
+ json_schema=output_schema,
model_parameters=node_data_model.completion_params,
stop=list(stop or []),
user=user_id,
+ tenant_id=tenant_id,
)
-
else:
request_start_time = time.perf_counter()
@@ -1288,16 +1261,18 @@ class LLMNode(Node[LLMNodeData]):
# Insert histories into the prompt
prompt_content = prompt_messages[0].content
# For issue #11247 - Check if prompt content is a string or a list
- if isinstance(prompt_content, str):
+ prompt_content_type = type(prompt_content)
+ if prompt_content_type == str:
prompt_content = str(prompt_content)
if "#histories#" in prompt_content:
prompt_content = prompt_content.replace("#histories#", memory_text)
else:
prompt_content = memory_text + "\n" + prompt_content
prompt_messages[0].content = prompt_content
- elif isinstance(prompt_content, list):
+ elif prompt_content_type == list:
+ prompt_content = prompt_content if isinstance(prompt_content, list) else []
for content_item in prompt_content:
- if isinstance(content_item, TextPromptMessageContent):
+ if content_item.type == PromptMessageContentType.TEXT:
if "#histories#" in content_item.data:
content_item.data = content_item.data.replace("#histories#", memory_text)
else:
@@ -1307,12 +1282,13 @@ class LLMNode(Node[LLMNodeData]):
# Add current query to the prompt message
if sys_query:
- if isinstance(prompt_content, str):
+ if prompt_content_type == str:
prompt_content = str(prompt_messages[0].content).replace("#sys.query#", sys_query)
prompt_messages[0].content = prompt_content
- elif isinstance(prompt_content, list):
+ elif prompt_content_type == list:
+ prompt_content = prompt_content if isinstance(prompt_content, list) else []
for content_item in prompt_content:
- if isinstance(content_item, TextPromptMessageContent):
+ if content_item.type == PromptMessageContentType.TEXT:
content_item.data = sys_query + "\n" + content_item.data
else:
raise ValueError("Invalid prompt content type")
@@ -1438,11 +1414,9 @@ class LLMNode(Node[LLMNodeData]):
if isinstance(item, PromptMessageContext):
if len(item.value_selector) >= 2:
prompt_context_selectors.append(item.value_selector)
- elif isinstance(item, LLMNodeChatModelMessage):
+ elif isinstance(item, LLMNodeChatModelMessage) and item.edition_type == "jinja2":
variable_template_parser = VariableTemplateParser(template=item.text)
variable_selectors.extend(variable_template_parser.extract_variable_selectors())
- else:
- raise InvalidVariableTypeError(f"Invalid prompt template type: {type(prompt_template)}")
elif isinstance(prompt_template, LLMNodeCompletionModelPromptTemplate):
if prompt_template.edition_type != "jinja2":
variable_template_parser = VariableTemplateParser(template=prompt_template.text)
@@ -1478,14 +1452,13 @@ class LLMNode(Node[LLMNodeData]):
if typed_node_data.prompt_config:
enable_jinja = False
- if isinstance(prompt_template, LLMNodeCompletionModelPromptTemplate):
- if prompt_template.edition_type == "jinja2":
- enable_jinja = True
- else:
- for prompt in prompt_template:
- if prompt.edition_type == "jinja2":
+ if isinstance(prompt_template, list):
+ for item in prompt_template:
+ if isinstance(item, LLMNodeChatModelMessage) and item.edition_type == "jinja2":
enable_jinja = True
break
+ else:
+ enable_jinja = True
if enable_jinja:
for variable_selector in typed_node_data.prompt_config.jinja2_variables or []:
@@ -1898,7 +1871,6 @@ class LLMNode(Node[LLMNodeData]):
variable_pool: VariablePool,
node_inputs: dict[str, Any],
process_data: dict[str, Any],
- structured_output_schema: Mapping[str, Any] | None,
) -> Generator[NodeEventBase, None, LLMGenerationData]:
"""Invoke LLM with tools support (from Agent V2).
@@ -1915,16 +1887,12 @@ class LLMNode(Node[LLMNodeData]):
# Use factory to create appropriate strategy
strategy = StrategyFactory.create_strategy(
- tenant_id=self.tenant_id,
- invoke_from=self.invoke_from,
- tool_invoke_from=ToolInvokeFrom.WORKFLOW,
model_features=model_features,
model_instance=model_instance,
tools=tool_instances,
files=prompt_files,
max_iterations=self._node_data.max_iterations or 10,
context=ExecutionContext(user_id=self.user_id, app_id=self.app_id, tenant_id=self.tenant_id),
- structured_output_schema=structured_output_schema,
)
# Run strategy
@@ -1932,6 +1900,7 @@ class LLMNode(Node[LLMNodeData]):
prompt_messages=list(prompt_messages),
model_parameters=self._node_data.model.completion_params,
stop=list(stop or []),
+ stream=True,
)
result = yield from self._process_tool_outputs(outputs)
@@ -1945,12 +1914,8 @@ class LLMNode(Node[LLMNodeData]):
stop: Sequence[str] | None,
variable_pool: VariablePool,
tool_dependencies: ToolDependencies | None,
- structured_output_schema: Mapping[str, Any] | None,
- structured_output_file_paths: Sequence[str] | None,
- ) -> Generator[NodeEventBase | LLMStructuredOutput, None, LLMGenerationData]:
+ ) -> Generator[NodeEventBase, None, LLMGenerationData]:
result: LLMGenerationData | None = None
- sandbox_output_files: list[File] = []
- structured_output_files: list[File] = []
# FIXME(Mairuis): Async processing for bash session.
with SandboxBashSession(sandbox=sandbox, node_id=self.id, tools=tool_dependencies) as session:
@@ -1958,9 +1923,6 @@ class LLMNode(Node[LLMNodeData]):
model_features = self._get_model_features(model_instance)
strategy = StrategyFactory.create_strategy(
- tenant_id=self.tenant_id,
- invoke_from=self.invoke_from,
- tool_invoke_from=ToolInvokeFrom.WORKFLOW,
model_features=model_features,
model_instance=model_instance,
tools=[session.bash_tool],
@@ -1968,55 +1930,20 @@ class LLMNode(Node[LLMNodeData]):
max_iterations=self._node_data.max_iterations or 100,
agent_strategy=AgentEntity.Strategy.FUNCTION_CALLING,
context=ExecutionContext(user_id=self.user_id, app_id=self.app_id, tenant_id=self.tenant_id),
- structured_output_schema=structured_output_schema,
)
outputs = strategy.run(
prompt_messages=list(prompt_messages),
model_parameters=self._node_data.model.completion_params,
stop=list(stop or []),
+ stream=True,
)
result = yield from self._process_tool_outputs(outputs)
- if result and result.structured_output and structured_output_file_paths:
- structured_output_payload = result.structured_output.structured_output or {}
- try:
- converted_output, structured_output_files = convert_sandbox_file_paths_in_output(
- output=structured_output_payload,
- file_path_fields=structured_output_file_paths,
- file_resolver=session.download_file,
- )
- except ValueError as exc:
- raise LLMNodeError(str(exc)) from exc
- result = result.model_copy(
- update={"structured_output": LLMStructuredOutput(structured_output=converted_output)}
- )
-
- # Collect output files from sandbox before session ends
- # Files are saved as ToolFiles with valid tool_file_id for later reference
- sandbox_output_files = session.collect_output_files()
-
if result is None:
raise LLMNodeError("SandboxSession exited unexpectedly")
- structured_output = result.structured_output
- if structured_output is not None:
- yield structured_output
-
- # Merge sandbox output files into result
- if sandbox_output_files or structured_output_files:
- result = LLMGenerationData(
- text=result.text,
- reasoning_contents=result.reasoning_contents,
- tool_calls=result.tool_calls,
- sequence=result.sequence,
- usage=result.usage,
- finish_reason=result.finish_reason,
- files=result.files + sandbox_output_files + structured_output_files,
- trace=result.trace,
- )
-
return result
def _get_model_features(self, model_instance: ModelInstance) -> list[ModelFeature]:
@@ -2084,20 +2011,6 @@ class LLMNode(Node[LLMNodeData]):
logger.warning("Failed to load tool %s: %s", tool, str(e))
continue
- structured_output_schema = None
- if self._node_data.structured_output_enabled:
- structured_output_schema = LLMNode.fetch_structured_output_schema(
- structured_output=self._node_data.structured_output or {},
- )
- tool_instances.extend(
- build_agent_output_tools(
- tenant_id=self.tenant_id,
- invoke_from=self.invoke_from,
- tool_invoke_from=ToolInvokeFrom.WORKFLOW,
- structured_output_schema=structured_output_schema,
- )
- )
-
return tool_instances
def _extract_prompt_files(self, variable_pool: VariablePool) -> list[File]:
@@ -2261,45 +2174,18 @@ class LLMNode(Node[LLMNodeData]):
# Add tool call to pending list for model segment
buffers.pending_tool_calls.append(ToolCall(id=tool_call_id, name=tool_name, arguments=tool_arguments))
- output_tool_names = {OUTPUT_TEXT_TOOL, FINAL_OUTPUT_TOOL, FINAL_STRUCTURED_OUTPUT_TOOL}
-
- if tool_name not in output_tool_names:
- yield ToolCallChunkEvent(
- selector=[self._node_id, "generation", "tool_calls"],
- chunk=tool_arguments,
- tool_call=ToolCall(
- id=tool_call_id,
- name=tool_name,
- arguments=tool_arguments,
- icon=tool_icon,
- icon_dark=tool_icon_dark,
- ),
- is_final=False,
- )
-
- if tool_name in output_tool_names:
- content = ""
- if tool_name in (OUTPUT_TEXT_TOOL, FINAL_OUTPUT_TOOL):
- content = payload.tool_args["text"]
- elif tool_name == FINAL_STRUCTURED_OUTPUT_TOOL:
- raw_content = json.dumps(
- payload.tool_args["data"],
- ensure_ascii=False,
- indent=2
- )
- content = f"```json\n{raw_content}\n```"
-
- if content:
- yield StreamChunkEvent(
- selector=[self._node_id, "text"],
- chunk=content,
- is_final=False,
- )
- yield StreamChunkEvent(
- selector=[self._node_id, "generation", "content"],
- chunk=content,
- is_final=False,
- )
+ yield ToolCallChunkEvent(
+ selector=[self._node_id, "generation", "tool_calls"],
+ chunk=tool_arguments,
+ tool_call=ToolCall(
+ id=tool_call_id,
+ name=tool_name,
+ arguments=tool_arguments,
+ icon=tool_icon,
+ icon_dark=tool_icon_dark,
+ ),
+ is_final=False,
+ )
if output.log_type == AgentLog.LogType.TOOL_CALL and output.status != AgentLog.LogStatus.START:
tool_name = payload.tool_name
@@ -2543,7 +2429,6 @@ class LLMNode(Node[LLMNodeData]):
content_position = 0
tool_call_seen_index: dict[str, int] = {}
for trace_segment in trace_state.trace_segments:
- # FIXME: These if will never happen
if trace_segment.type == "thought":
sequence.append({"type": "reasoning", "index": reasoning_index})
reasoning_index += 1
@@ -2586,15 +2471,8 @@ class LLMNode(Node[LLMNodeData]):
key=lambda item: trace_state.tool_call_index_map.get(item.id or "", len(trace_state.tool_call_index_map))
)
- text_content: str
- if aggregate.text:
- text_content = aggregate.text
- elif aggregate.structured_output:
- text_content = json.dumps(aggregate.structured_output.structured_output)
- else:
- raise ValueError("Aggregate must have either text or structured output.")
return LLMGenerationData(
- text=text_content,
+ text=aggregate.text,
reasoning_contents=buffers.reasoning_per_turn,
tool_calls=tool_calls_for_generation,
sequence=sequence,
@@ -2602,7 +2480,6 @@ class LLMNode(Node[LLMNodeData]):
finish_reason=aggregate.finish_reason,
files=aggregate.files,
trace=trace_state.trace_segments,
- structured_output=aggregate.structured_output,
)
def _process_tool_outputs(
@@ -2613,33 +2490,22 @@ class LLMNode(Node[LLMNodeData]):
state = ToolOutputState()
try:
- while True:
- output = next(outputs)
+ for output in outputs:
if isinstance(output, AgentLog):
yield from self._handle_agent_log_output(output, state.stream, state.trace, state.agent)
else:
- continue
+ yield from self._handle_llm_chunk_output(output, state.stream, state.trace, state.aggregate)
except StopIteration as exception:
- if not isinstance(exception.value, AgentResult):
- raise ValueError(f"Unexpected output type: {type(exception.value)}") from exception
- state.agent.agent_result = exception.value
- agent_result = state.agent.agent_result
- if not agent_result:
- raise ValueError("No agent result found in tool outputs")
- output_payload = agent_result.output
- if isinstance(output_payload, dict):
- state.aggregate.structured_output = LLMStructuredOutput(structured_output=output_payload)
- state.aggregate.text = json.dumps(output_payload)
- elif isinstance(output_payload, str):
- state.aggregate.text = output_payload
- else:
- raise ValueError(f"Unexpected output type: {type(output_payload)}")
+ if isinstance(getattr(exception, "value", None), AgentResult):
+ state.agent.agent_result = exception.value
- state.aggregate.files = state.agent.agent_result.files
- if state.agent.agent_result.usage:
- state.aggregate.usage = state.agent.agent_result.usage
- if state.agent.agent_result.finish_reason:
- state.aggregate.finish_reason = state.agent.agent_result.finish_reason
+ if state.agent.agent_result:
+ state.aggregate.text = state.agent.agent_result.text or state.aggregate.text
+ state.aggregate.files = state.agent.agent_result.files
+ if state.agent.agent_result.usage:
+ state.aggregate.usage = state.agent.agent_result.usage
+ if state.agent.agent_result.finish_reason:
+ state.aggregate.finish_reason = state.agent.agent_result.finish_reason
yield from self._flush_remaining_stream(state.stream, state.trace, state.aggregate)
yield from self._close_streams()
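
The reworked `_process_tool_outputs` recovers the strategy's `AgentResult` from the generator's return value. As a reminder of the underlying mechanism, here is a minimal, self-contained sketch: a generator's return value rides on `StopIteration.value`, which is observable when the generator is advanced with `next()` (a plain `for` loop handles `StopIteration` internally).

```python
from collections.abc import Generator


def agent_outputs() -> Generator[str, None, dict]:
    # Yields streamed chunks, then returns a final result object
    # (a stand-in for the AgentResult that strategy.run() returns).
    yield "chunk-1"
    yield "chunk-2"
    return {"text": "final answer"}


def drive(gen: Generator[str, None, dict]) -> dict:
    # Advancing the generator with next() lets StopIteration.value carry the
    # return value out once streaming ends.
    result: dict = {}
    try:
        while True:
            print(next(gen))
    except StopIteration as stop:
        if stop.value is not None:
            result = stop.value
    return result


assert drive(agent_outputs()) == {"text": "final answer"}
```
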
diff --git a/api/core/workflow/nodes/question_classifier/question_classifier_node.py b/api/core/workflow/nodes/question_classifier/question_classifier_node.py
index c8dfe7ccf9..564e548e9f 100644
--- a/api/core/workflow/nodes/question_classifier/question_classifier_node.py
+++ b/api/core/workflow/nodes/question_classifier/question_classifier_node.py
@@ -156,7 +156,8 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
prompt_messages=prompt_messages,
stop=stop,
user_id=self.user_id,
- structured_output_schema=None,
+ structured_output_enabled=False,
+ structured_output=None,
file_saver=self._llm_file_saver,
file_outputs=self._file_outputs,
node_id=self._node_id,
diff --git a/api/services/tools/builtin_tools_manage_service.py b/api/services/tools/builtin_tools_manage_service.py
index 57c52c1c86..4c37a867b1 100644
--- a/api/services/tools/builtin_tools_manage_service.py
+++ b/api/services/tools/builtin_tools_manage_service.py
@@ -91,8 +91,6 @@ class BuiltinToolManageService:
:return: the list of tools
"""
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
- if ToolManager.is_internal_builtin_provider(provider_controller.entity.identity.name):
- return []
tools = provider_controller.get_tools()
result: list[ToolApiEntity] = []
@@ -543,8 +541,6 @@ class BuiltinToolManageService:
for provider_controller in provider_controllers:
try:
- if ToolManager.is_internal_builtin_provider(provider_controller.entity.identity.name):
- continue
# handle include, exclude
if is_filtered(
include_set=dify_config.POSITION_TOOL_INCLUDES_SET,
diff --git a/api/tests/fixtures/file output schema.yml b/api/tests/fixtures/file output schema.yml
index 8ae3072572..37fc9c72c7 100644
--- a/api/tests/fixtures/file output schema.yml
+++ b/api/tests/fixtures/file output schema.yml
@@ -126,8 +126,9 @@ workflow:
additionalProperties: false
properties:
image:
- description: Sandbox file path of the selected image
- type: file
+ description: File ID (UUID) of the selected image
+ format: dify-file-ref
+ type: string
required:
- image
type: object
diff --git a/api/tests/unit_tests/core/agent/output_parser/test_cot_output_parser.py b/api/tests/unit_tests/core/agent/output_parser/test_cot_output_parser.py
index aec077124d..4a613e35b0 100644
--- a/api/tests/unit_tests/core/agent/output_parser/test_cot_output_parser.py
+++ b/api/tests/unit_tests/core/agent/output_parser/test_cot_output_parser.py
@@ -1,3 +1,4 @@
+import json
from collections.abc import Generator
from core.agent.entities import AgentScratchpadUnit
@@ -63,7 +64,7 @@ def test_cot_output_parser():
output += result
elif isinstance(result, AgentScratchpadUnit.Action):
if test_case["action"]:
- assert result.model_dump() == test_case["action"]
- output += result.model_dump_json()
+ assert result.to_dict() == test_case["action"]
+ output += json.dumps(result.to_dict())
if test_case["output"]:
assert output == test_case["output"]
diff --git a/api/tests/unit_tests/core/agent/patterns/test_base.py b/api/tests/unit_tests/core/agent/patterns/test_base.py
index dd42b31bad..b0e0d44940 100644
--- a/api/tests/unit_tests/core/agent/patterns/test_base.py
+++ b/api/tests/unit_tests/core/agent/patterns/test_base.py
@@ -13,7 +13,7 @@ from core.model_runtime.entities.llm_entities import LLMUsage
class ConcreteAgentPattern(AgentPattern):
"""Concrete implementation of AgentPattern for testing."""
- def run(self, prompt_messages, model_parameters, stop=[]):
+ def run(self, prompt_messages, model_parameters, stop=[], stream=True):
"""Minimal implementation for testing."""
yield from []
diff --git a/api/tests/unit_tests/core/agent/test_agent_app_runner.py b/api/tests/unit_tests/core/agent/test_agent_app_runner.py
index 43d28d39b7..d9301ccfe0 100644
--- a/api/tests/unit_tests/core/agent/test_agent_app_runner.py
+++ b/api/tests/unit_tests/core/agent/test_agent_app_runner.py
@@ -329,15 +329,13 @@ class TestAgentLogProcessing:
)
result = AgentResult(
- output="Final answer",
+ text="Final answer",
files=[],
usage=usage,
finish_reason="stop",
)
- output_payload = result.output
- assert isinstance(output_payload, str)
- assert output_payload == "Final answer"
+ assert result.text == "Final answer"
assert result.files == []
assert result.usage == usage
assert result.finish_reason == "stop"
diff --git a/api/tests/unit_tests/core/agent/test_entities.py b/api/tests/unit_tests/core/agent/test_entities.py
index 61b0ac0052..5136f48aab 100644
--- a/api/tests/unit_tests/core/agent/test_entities.py
+++ b/api/tests/unit_tests/core/agent/test_entities.py
@@ -153,7 +153,7 @@ class TestAgentScratchpadUnit:
action_input={"query": "test"},
)
- result = action.model_dump()
+ result = action.to_dict()
assert result == {
"action": "search",
diff --git a/api/tests/unit_tests/core/llm_generator/output_parser/test_file_ref.py b/api/tests/unit_tests/core/llm_generator/output_parser/test_file_ref.py
index 2c16de8f6f..6d18ac7fc9 100644
--- a/api/tests/unit_tests/core/llm_generator/output_parser/test_file_ref.py
+++ b/api/tests/unit_tests/core/llm_generator/output_parser/test_file_ref.py
@@ -1,93 +1,269 @@
"""
-Unit tests for sandbox file path detection and conversion.
+Unit tests for file reference detection and conversion.
"""
+import uuid
+from unittest.mock import MagicMock, patch
+
import pytest
from core.file import File, FileTransferMethod, FileType
from core.llm_generator.output_parser.file_ref import (
- FILE_PATH_DESCRIPTION_SUFFIX,
- adapt_schema_for_sandbox_file_paths,
- convert_sandbox_file_paths_in_output,
- detect_file_path_fields,
- is_file_path_property,
+ FILE_REF_FORMAT,
+ convert_file_refs_in_output,
+ detect_file_ref_fields,
+ is_file_ref_property,
)
from core.variables.segments import ArrayFileSegment, FileSegment
-def _build_file(file_id: str) -> File:
- return File(
- id=file_id,
- tenant_id="tenant_123",
- type=FileType.IMAGE,
- transfer_method=FileTransferMethod.TOOL_FILE,
- filename="test.png",
- extension=".png",
- mime_type="image/png",
- size=128,
- related_id=file_id,
- storage_key="sandbox/path",
- )
+class TestIsFileRefProperty:
+ """Tests for is_file_ref_property function."""
+
+ def test_valid_file_ref(self):
+ schema = {"type": "string", "format": FILE_REF_FORMAT}
+ assert is_file_ref_property(schema) is True
+
+ def test_invalid_type(self):
+ schema = {"type": "number", "format": FILE_REF_FORMAT}
+ assert is_file_ref_property(schema) is False
+
+ def test_missing_format(self):
+ schema = {"type": "string"}
+ assert is_file_ref_property(schema) is False
+
+ def test_wrong_format(self):
+ schema = {"type": "string", "format": "uuid"}
+ assert is_file_ref_property(schema) is False
-class TestFilePathSchema:
- def test_is_file_path_property(self):
- assert is_file_path_property({"type": "file"}) is True
- assert is_file_path_property({"type": "string", "format": "dify-file-ref"}) is True
- assert is_file_path_property({"type": "string"}) is False
+class TestDetectFileRefFields:
+ """Tests for detect_file_ref_fields function."""
- def test_detect_file_path_fields(self):
+ def test_simple_file_ref(self):
schema = {
"type": "object",
"properties": {
- "image": {"type": "string", "format": "dify-file-ref"},
- "files": {"type": "array", "items": {"type": "string", "format": "dify-file-ref"}},
- "meta": {"type": "object", "properties": {"doc": {"type": "file"}}},
+ "image": {"type": "string", "format": FILE_REF_FORMAT},
},
}
- assert set(detect_file_path_fields(schema)) == {"image", "files[*]", "meta.doc"}
+ paths = detect_file_ref_fields(schema)
+ assert paths == ["image"]
- def test_adapt_schema_for_sandbox_file_paths(self):
+ def test_multiple_file_refs(self):
schema = {
"type": "object",
"properties": {
- "image": {"type": "string", "format": "dify-file-ref"},
+ "image": {"type": "string", "format": FILE_REF_FORMAT},
+ "document": {"type": "string", "format": FILE_REF_FORMAT},
"name": {"type": "string"},
},
}
- adapted, fields = adapt_schema_for_sandbox_file_paths(schema)
+ paths = detect_file_ref_fields(schema)
+ assert set(paths) == {"image", "document"}
- assert set(fields) == {"image"}
- adapted_image = adapted["properties"]["image"]
- assert adapted_image["type"] == "string"
- assert "format" not in adapted_image
- assert FILE_PATH_DESCRIPTION_SUFFIX in adapted_image["description"]
+ def test_array_of_file_refs(self):
+ schema = {
+ "type": "object",
+ "properties": {
+ "files": {
+ "type": "array",
+ "items": {"type": "string", "format": FILE_REF_FORMAT},
+ },
+ },
+ }
+ paths = detect_file_ref_fields(schema)
+ assert paths == ["files[*]"]
+
+ def test_nested_file_ref(self):
+ schema = {
+ "type": "object",
+ "properties": {
+ "data": {
+ "type": "object",
+ "properties": {
+ "image": {"type": "string", "format": FILE_REF_FORMAT},
+ },
+ },
+ },
+ }
+ paths = detect_file_ref_fields(schema)
+ assert paths == ["data.image"]
+
+ def test_no_file_refs(self):
+ schema = {
+ "type": "object",
+ "properties": {
+ "name": {"type": "string"},
+ "count": {"type": "number"},
+ },
+ }
+ paths = detect_file_ref_fields(schema)
+ assert paths == []
+
+ def test_empty_schema(self):
+ schema = {}
+ paths = detect_file_ref_fields(schema)
+ assert paths == []
+
+ def test_mixed_schema(self):
+ schema = {
+ "type": "object",
+ "properties": {
+ "query": {"type": "string"},
+ "image": {"type": "string", "format": FILE_REF_FORMAT},
+ "documents": {
+ "type": "array",
+ "items": {"type": "string", "format": FILE_REF_FORMAT},
+ },
+ },
+ }
+ paths = detect_file_ref_fields(schema)
+ assert set(paths) == {"image", "documents[*]"}
-class TestConvertSandboxFilePaths:
- def test_convert_sandbox_file_paths(self):
- output = {"image": "a.png", "files": ["b.png", "c.png"], "name": "demo"}
+class TestConvertFileRefsInOutput:
+ """Tests for convert_file_refs_in_output function."""
- def resolver(path: str) -> File:
- return _build_file(path)
+ @pytest.fixture
+ def mock_file(self):
+ """Create a mock File object with all required attributes."""
+ file = MagicMock(spec=File)
+ file.type = FileType.IMAGE
+ file.transfer_method = FileTransferMethod.TOOL_FILE
+ file.related_id = "test-related-id"
+ file.remote_url = None
+ file.tenant_id = "tenant_123"
+ file.id = None
+ file.filename = "test.png"
+ file.extension = ".png"
+ file.mime_type = "image/png"
+ file.size = 1024
+ file.dify_model_identity = "__dify__file__"
+ return file
- converted, files = convert_sandbox_file_paths_in_output(output, ["image", "files[*]"], resolver)
+ @pytest.fixture
+ def mock_build_from_mapping(self, mock_file):
+ """Mock the build_from_mapping function."""
+ with patch("core.llm_generator.output_parser.file_ref.build_from_mapping") as mock:
+ mock.return_value = mock_file
+ yield mock
- assert isinstance(converted["image"], FileSegment)
- assert isinstance(converted["files"], ArrayFileSegment)
- assert converted["name"] == "demo"
- assert [file.id for file in files] == ["a.png", "b.png", "c.png"]
+ def test_convert_simple_file_ref(self, mock_build_from_mapping, mock_file):
+ file_id = str(uuid.uuid4())
+ output = {"image": file_id}
+ schema = {
+ "type": "object",
+ "properties": {
+ "image": {"type": "string", "format": FILE_REF_FORMAT},
+ },
+ }
- def test_invalid_path_value_raises(self):
- def resolver(path: str) -> File:
- return _build_file(path)
+        result = convert_file_refs_in_output(output, schema, "tenant_123")

-        with pytest.raises(ValueError):
- convert_sandbox_file_paths_in_output({"image": 123}, ["image"], resolver)
+ # Result should be wrapped in FileSegment
+ assert isinstance(result["image"], FileSegment)
+ assert result["image"].value == mock_file
+ mock_build_from_mapping.assert_called_once_with(
+ mapping={"transfer_method": "tool_file", "tool_file_id": file_id},
+ tenant_id="tenant_123",
+        )

-    def test_no_file_paths_returns_output(self):
- output = {"name": "demo"}
- converted, files = convert_sandbox_file_paths_in_output(output, [], _build_file)
+ def test_convert_array_of_file_refs(self, mock_build_from_mapping, mock_file):
+ file_id1 = str(uuid.uuid4())
+ file_id2 = str(uuid.uuid4())
+ output = {"files": [file_id1, file_id2]}
+ schema = {
+ "type": "object",
+ "properties": {
+ "files": {
+ "type": "array",
+ "items": {"type": "string", "format": FILE_REF_FORMAT},
+ },
+ },
+        }

-        assert converted == output
- assert files == []
+ result = convert_file_refs_in_output(output, schema, "tenant_123")
+
+ # Result should be wrapped in ArrayFileSegment
+ assert isinstance(result["files"], ArrayFileSegment)
+ assert list(result["files"].value) == [mock_file, mock_file]
+ assert mock_build_from_mapping.call_count == 2
+
+ def test_no_conversion_without_file_refs(self):
+ output = {"name": "test", "count": 5}
+ schema = {
+ "type": "object",
+ "properties": {
+ "name": {"type": "string"},
+ "count": {"type": "number"},
+ },
+ }
+
+ result = convert_file_refs_in_output(output, schema, "tenant_123")
+
+ assert result == {"name": "test", "count": 5}
+
+ def test_invalid_uuid_returns_none(self):
+ output = {"image": "not-a-valid-uuid"}
+ schema = {
+ "type": "object",
+ "properties": {
+ "image": {"type": "string", "format": FILE_REF_FORMAT},
+ },
+ }
+
+ result = convert_file_refs_in_output(output, schema, "tenant_123")
+
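+        # Values that are not valid UUIDs cannot be resolved, so the field becomes None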
+ assert result["image"] is None
+
+ def test_file_not_found_returns_none(self):
+ file_id = str(uuid.uuid4())
+ output = {"image": file_id}
+ schema = {
+ "type": "object",
+ "properties": {
+ "image": {"type": "string", "format": FILE_REF_FORMAT},
+ },
+ }
+
+ with patch("core.llm_generator.output_parser.file_ref.build_from_mapping") as mock:
+ mock.side_effect = ValueError("File not found")
+ result = convert_file_refs_in_output(output, schema, "tenant_123")
+
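+        # Resolution failures are swallowed and surface as None instead of raising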
+ assert result["image"] is None
+
+ def test_preserves_non_file_fields(self, mock_build_from_mapping, mock_file):
+ file_id = str(uuid.uuid4())
+ output = {"query": "search term", "image": file_id, "count": 10}
+ schema = {
+ "type": "object",
+ "properties": {
+ "query": {"type": "string"},
+ "image": {"type": "string", "format": FILE_REF_FORMAT},
+ "count": {"type": "number"},
+ },
+ }
+
+ result = convert_file_refs_in_output(output, schema, "tenant_123")
+
+ assert result["query"] == "search term"
+ assert isinstance(result["image"], FileSegment)
+ assert result["image"].value == mock_file
+ assert result["count"] == 10
+
+ def test_does_not_modify_original_output(self, mock_build_from_mapping, mock_file):
+ file_id = str(uuid.uuid4())
+ original = {"image": file_id}
+ output = dict(original)
+ schema = {
+ "type": "object",
+ "properties": {
+ "image": {"type": "string", "format": FILE_REF_FORMAT},
+ },
+ }
+
+ convert_file_refs_in_output(output, schema, "tenant_123")
+
+ # Original should still contain the string ID
+ assert original["image"] == file_id
diff --git a/api/tests/unit_tests/utils/structured_output_parser/test_structured_output_parser.py b/api/tests/unit_tests/utils/structured_output_parser/test_structured_output_parser.py
index 30f2f8cc73..df73c29004 100644
--- a/api/tests/unit_tests/utils/structured_output_parser/test_structured_output_parser.py
+++ b/api/tests/unit_tests/utils/structured_output_parser/test_structured_output_parser.py
@@ -1,6 +1,4 @@
-from collections.abc import Mapping
from decimal import Decimal
-from typing import Any, NotRequired, TypedDict
from unittest.mock import MagicMock, patch
import pytest
@@ -8,14 +6,16 @@ from pydantic import BaseModel, ConfigDict
from core.llm_generator.output_parser.errors import OutputParserError
from core.llm_generator.output_parser.structured_output import (
+ _get_default_value_for_type,
fill_defaults_from_schema,
- get_default_value_for_type,
invoke_llm_with_pydantic_model,
invoke_llm_with_structured_output,
)
-from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.llm_entities import (
LLMResult,
+ LLMResultChunk,
+ LLMResultChunkDelta,
+ LLMResultChunkWithStructuredOutput,
LLMResultWithStructuredOutput,
LLMUsage,
)
@@ -25,29 +25,7 @@ from core.model_runtime.entities.message_entities import (
TextPromptMessageContent,
UserPromptMessage,
)
-from core.model_runtime.entities.model_entities import (
- AIModelEntity,
- ModelType,
- ParameterRule,
- ParameterType,
-)
-
-
-class StructuredOutputTestCase(TypedDict):
- name: str
- provider: str
- model_name: str
- support_structure_output: bool
- stream: bool
- json_schema: Mapping[str, Any]
- expected_llm_response: LLMResult
- expected_result_type: type[LLMResultWithStructuredOutput] | None
- should_raise: bool
- expected_error: NotRequired[type[OutputParserError]]
- parameter_rules: NotRequired[list[ParameterRule]]
-
-
-SchemaData = dict[str, Any]
+from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
def create_mock_usage(prompt_tokens: int = 10, completion_tokens: int = 5) -> LLMUsage:
@@ -93,7 +71,7 @@ def get_model_instance() -> MagicMock:
def test_structured_output_parser():
"""Test cases for invoke_llm_with_structured_output function"""
- testcases: list[StructuredOutputTestCase] = [
+ testcases = [
# Test case 1: Model with native structured output support, non-streaming
{
"name": "native_structured_output_non_streaming",
@@ -110,6 +88,39 @@ def test_structured_output_parser():
"expected_result_type": LLMResultWithStructuredOutput,
"should_raise": False,
},
+ # Test case 2: Model with native structured output support, streaming
+ {
+ "name": "native_structured_output_streaming",
+ "provider": "openai",
+ "model_name": "gpt-4o",
+ "support_structure_output": True,
+ "stream": True,
+ "json_schema": {"type": "object", "properties": {"name": {"type": "string"}}},
+ "expected_llm_response": [
+ LLMResultChunk(
+ model="gpt-4o",
+ prompt_messages=[UserPromptMessage(content="test")],
+ system_fingerprint="test",
+ delta=LLMResultChunkDelta(
+ index=0,
+ message=AssistantPromptMessage(content='{"name":'),
+ usage=create_mock_usage(prompt_tokens=10, completion_tokens=2),
+ ),
+ ),
+ LLMResultChunk(
+ model="gpt-4o",
+ prompt_messages=[UserPromptMessage(content="test")],
+ system_fingerprint="test",
+ delta=LLMResultChunkDelta(
+ index=0,
+ message=AssistantPromptMessage(content=' "test"}'),
+ usage=create_mock_usage(prompt_tokens=10, completion_tokens=3),
+ ),
+ ),
+ ],
+ "expected_result_type": "generator",
+ "should_raise": False,
+ },
# Test case 3: Model without native structured output support, non-streaming
{
"name": "prompt_based_structured_output_non_streaming",
@@ -126,24 +137,78 @@ def test_structured_output_parser():
"expected_result_type": LLMResultWithStructuredOutput,
"should_raise": False,
},
+ # Test case 4: Model without native structured output support, streaming
{
- "name": "non_streaming_with_list_content",
+ "name": "prompt_based_structured_output_streaming",
+ "provider": "anthropic",
+ "model_name": "claude-3-sonnet",
+ "support_structure_output": False,
+ "stream": True,
+ "json_schema": {"type": "object", "properties": {"answer": {"type": "string"}}},
+ "expected_llm_response": [
+ LLMResultChunk(
+ model="claude-3-sonnet",
+ prompt_messages=[UserPromptMessage(content="test")],
+ system_fingerprint="test",
+ delta=LLMResultChunkDelta(
+ index=0,
+ message=AssistantPromptMessage(content='{"answer": "test'),
+ usage=create_mock_usage(prompt_tokens=15, completion_tokens=3),
+ ),
+ ),
+ LLMResultChunk(
+ model="claude-3-sonnet",
+ prompt_messages=[UserPromptMessage(content="test")],
+ system_fingerprint="test",
+ delta=LLMResultChunkDelta(
+ index=0,
+ message=AssistantPromptMessage(content=' response"}'),
+ usage=create_mock_usage(prompt_tokens=15, completion_tokens=5),
+ ),
+ ),
+ ],
+ "expected_result_type": "generator",
+ "should_raise": False,
+ },
+ # Test case 5: Streaming with list content
+ {
+ "name": "streaming_with_list_content",
"provider": "openai",
"model_name": "gpt-4o",
"support_structure_output": True,
- "stream": False,
+ "stream": True,
"json_schema": {"type": "object", "properties": {"data": {"type": "string"}}},
- "expected_llm_response": LLMResult(
- model="gpt-4o",
- message=AssistantPromptMessage(
- content=[
- TextPromptMessageContent(data='{"data":'),
- TextPromptMessageContent(data=' "value"}'),
- ]
+ "expected_llm_response": [
+ LLMResultChunk(
+ model="gpt-4o",
+ prompt_messages=[UserPromptMessage(content="test")],
+ system_fingerprint="test",
+ delta=LLMResultChunkDelta(
+ index=0,
+ message=AssistantPromptMessage(
+ content=[
+ TextPromptMessageContent(data='{"data":'),
+ ]
+ ),
+ usage=create_mock_usage(prompt_tokens=10, completion_tokens=2),
+ ),
),
- usage=create_mock_usage(prompt_tokens=10, completion_tokens=5),
- ),
- "expected_result_type": LLMResultWithStructuredOutput,
+ LLMResultChunk(
+ model="gpt-4o",
+ prompt_messages=[UserPromptMessage(content="test")],
+ system_fingerprint="test",
+ delta=LLMResultChunkDelta(
+ index=0,
+ message=AssistantPromptMessage(
+ content=[
+ TextPromptMessageContent(data=' "value"}'),
+ ]
+ ),
+ usage=create_mock_usage(prompt_tokens=10, completion_tokens=3),
+ ),
+ ),
+ ],
+ "expected_result_type": "generator",
"should_raise": False,
},
# Test case 6: Error case - non-string LLM response content (non-streaming)
@@ -188,13 +253,7 @@ def test_structured_output_parser():
"stream": False,
"json_schema": {"type": "object", "properties": {"result": {"type": "string"}}},
"parameter_rules": [
- ParameterRule(
- name="response_format",
- label=I18nObject(en_US="response_format"),
- type=ParameterType.STRING,
- required=False,
- options=["json_schema"],
- ),
+ MagicMock(name="response_format", options=["json_schema"], required=False),
],
"expected_llm_response": LLMResult(
model="gpt-4o",
@@ -213,13 +272,7 @@ def test_structured_output_parser():
"stream": False,
"json_schema": {"type": "object", "properties": {"output": {"type": "string"}}},
"parameter_rules": [
- ParameterRule(
- name="response_format",
- label=I18nObject(en_US="response_format"),
- type=ParameterType.STRING,
- required=False,
- options=["JSON"],
- ),
+ MagicMock(name="response_format", options=["JSON"], required=False),
],
"expected_llm_response": LLMResult(
model="claude-3-sonnet",
@@ -232,72 +285,89 @@ def test_structured_output_parser():
]
for case in testcases:
- provider = case["provider"]
- model_name = case["model_name"]
- support_structure_output = case["support_structure_output"]
- json_schema = case["json_schema"]
- stream = case["stream"]
+ # Setup model entity
+        model_schema = get_model_entity(case["provider"], case["model_name"], case["support_structure_output"])

-        model_schema = get_model_entity(provider, model_name, support_structure_output)
-
- parameter_rules = case.get("parameter_rules")
- if parameter_rules is not None:
- model_schema.parameter_rules = parameter_rules
+ # Add parameter rules if specified
+ if "parameter_rules" in case:
+ model_schema.parameter_rules = case["parameter_rules"]
+ # Setup model instance
model_instance = get_model_instance()
model_instance.invoke_llm.return_value = case["expected_llm_response"]
+ # Setup prompt messages
prompt_messages = [
SystemPromptMessage(content="You are a helpful assistant."),
UserPromptMessage(content="Generate a response according to the schema."),
]
if case["should_raise"]:
- expected_error = case.get("expected_error", OutputParserError)
- with pytest.raises(expected_error): # noqa: PT012
- if stream:
+ # Test error cases
+ with pytest.raises(case["expected_error"]): # noqa: PT012
+ if case["stream"]:
result_generator = invoke_llm_with_structured_output(
- provider=provider,
+ provider=case["provider"],
model_schema=model_schema,
model_instance=model_instance,
prompt_messages=prompt_messages,
- json_schema=json_schema,
+ json_schema=case["json_schema"],
)
+ # Consume the generator to trigger the error
list(result_generator)
else:
invoke_llm_with_structured_output(
- provider=provider,
+ provider=case["provider"],
model_schema=model_schema,
model_instance=model_instance,
prompt_messages=prompt_messages,
- json_schema=json_schema,
+ json_schema=case["json_schema"],
)
else:
+ # Test successful cases
with patch("core.llm_generator.output_parser.structured_output.json_repair.loads") as mock_json_repair:
+ # Configure json_repair mock for cases that need it
if case["name"] == "json_repair_scenario":
mock_json_repair.return_value = {"name": "test"}
result = invoke_llm_with_structured_output(
- provider=provider,
+ provider=case["provider"],
model_schema=model_schema,
model_instance=model_instance,
prompt_messages=prompt_messages,
- json_schema=json_schema,
+ json_schema=case["json_schema"],
model_parameters={"temperature": 0.7, "max_tokens": 100},
user="test_user",
)
- expected_result_type = case["expected_result_type"]
- assert expected_result_type is not None
- assert isinstance(result, expected_result_type)
- assert result.model == model_name
- assert result.structured_output is not None
- assert isinstance(result.structured_output, dict)
+ if case["expected_result_type"] == "generator":
+ # Test streaming results
+ assert hasattr(result, "__iter__")
+ chunks = list(result)
+ assert len(chunks) > 0
+ # Verify all chunks are LLMResultChunkWithStructuredOutput
+ for chunk in chunks[:-1]: # All except last
+ assert isinstance(chunk, LLMResultChunkWithStructuredOutput)
+ assert chunk.model == case["model_name"]
+
+ # Last chunk should have structured output
+ last_chunk = chunks[-1]
+ assert isinstance(last_chunk, LLMResultChunkWithStructuredOutput)
+ assert last_chunk.structured_output is not None
+ assert isinstance(last_chunk.structured_output, dict)
+ else:
+ # Test non-streaming results
+ assert isinstance(result, case["expected_result_type"])
+ assert result.model == case["model_name"]
+ assert result.structured_output is not None
+ assert isinstance(result.structured_output, dict)
+
+ # Verify model_instance.invoke_llm was called with correct parameters
model_instance.invoke_llm.assert_called_once()
call_args = model_instance.invoke_llm.call_args
- assert call_args.kwargs["stream"] == stream
+ assert call_args.kwargs["stream"] == case["stream"]
assert call_args.kwargs["user"] == "test_user"
assert "temperature" in call_args.kwargs["model_parameters"]
assert "max_tokens" in call_args.kwargs["model_parameters"]
@@ -306,32 +376,45 @@ def test_structured_output_parser():
def test_parse_structured_output_edge_cases():
"""Test edge cases for structured output parsing"""
- provider = "deepseek"
- model_name = "deepseek-r1"
- support_structure_output = False
- json_schema: SchemaData = {"type": "object", "properties": {"thought": {"type": "string"}}}
- expected_llm_response = LLMResult(
- model="deepseek-r1",
- message=AssistantPromptMessage(content='[{"thought": "reasoning process"}, "other content"]'),
- usage=create_mock_usage(prompt_tokens=10, completion_tokens=5),
+ # Test case with list that contains dict (reasoning model scenario)
+ testcase_list_with_dict = {
+ "name": "list_with_dict_parsing",
+ "provider": "deepseek",
+ "model_name": "deepseek-r1",
+ "support_structure_output": False,
+ "stream": False,
+ "json_schema": {"type": "object", "properties": {"thought": {"type": "string"}}},
+ "expected_llm_response": LLMResult(
+ model="deepseek-r1",
+ message=AssistantPromptMessage(content='[{"thought": "reasoning process"}, "other content"]'),
+ usage=create_mock_usage(prompt_tokens=10, completion_tokens=5),
+ ),
+ "expected_result_type": LLMResultWithStructuredOutput,
+ "should_raise": False,
+ }
+
+ # Setup for list parsing test
+ model_schema = get_model_entity(
+ testcase_list_with_dict["provider"],
+ testcase_list_with_dict["model_name"],
+ testcase_list_with_dict["support_structure_output"],
)
- model_schema = get_model_entity(provider, model_name, support_structure_output)
-
model_instance = get_model_instance()
- model_instance.invoke_llm.return_value = expected_llm_response
+ model_instance.invoke_llm.return_value = testcase_list_with_dict["expected_llm_response"]
prompt_messages = [UserPromptMessage(content="Test reasoning")]
with patch("core.llm_generator.output_parser.structured_output.json_repair.loads") as mock_json_repair:
+ # Mock json_repair to return a list with dict
mock_json_repair.return_value = [{"thought": "reasoning process"}, "other content"]
result = invoke_llm_with_structured_output(
- provider=provider,
+ provider=testcase_list_with_dict["provider"],
model_schema=model_schema,
model_instance=model_instance,
prompt_messages=prompt_messages,
- json_schema=json_schema,
+ json_schema=testcase_list_with_dict["json_schema"],
)
assert isinstance(result, LLMResultWithStructuredOutput)
@@ -341,16 +424,18 @@ def test_parse_structured_output_edge_cases():
def test_model_specific_schema_preparation():
"""Test schema preparation for different model types"""
- provider = "google"
- model_name = "gemini-pro"
- support_structure_output = True
- json_schema: SchemaData = {
- "type": "object",
- "properties": {"result": {"type": "boolean"}},
- "additionalProperties": False,
+ # Test Gemini model
+ gemini_case = {
+ "provider": "google",
+ "model_name": "gemini-pro",
+ "support_structure_output": True,
+ "stream": False,
+ "json_schema": {"type": "object", "properties": {"result": {"type": "boolean"}}, "additionalProperties": False},
}
- model_schema = get_model_entity(provider, model_name, support_structure_output)
+ model_schema = get_model_entity(
+ gemini_case["provider"], gemini_case["model_name"], gemini_case["support_structure_output"]
+ )
model_instance = get_model_instance()
model_instance.invoke_llm.return_value = LLMResult(
@@ -362,11 +447,11 @@ def test_model_specific_schema_preparation():
prompt_messages = [UserPromptMessage(content="Test")]
result = invoke_llm_with_structured_output(
- provider=provider,
+ provider=gemini_case["provider"],
model_schema=model_schema,
model_instance=model_instance,
prompt_messages=prompt_messages,
- json_schema=json_schema,
+ json_schema=gemini_case["json_schema"],
)
assert isinstance(result, LLMResultWithStructuredOutput)
@@ -408,26 +493,40 @@ def test_structured_output_with_pydantic_model_non_streaming():
assert result.name == "test"
-def test_structured_output_with_pydantic_model_list_content():
+def test_structured_output_with_pydantic_model_streaming():
model_schema = get_model_entity("openai", "gpt-4o", support_structure_output=True)
model_instance = get_model_instance()
- model_instance.invoke_llm.return_value = LLMResult(
- model="gpt-4o",
- message=AssistantPromptMessage(
- content=[
- TextPromptMessageContent(data='{"name":'),
- TextPromptMessageContent(data=' "test"}'),
- ]
- ),
- usage=create_mock_usage(prompt_tokens=8, completion_tokens=4),
- )
+
+ def mock_streaming_response():
+ yield LLMResultChunk(
+ model="gpt-4o",
+ prompt_messages=[UserPromptMessage(content="test")],
+ system_fingerprint="test",
+ delta=LLMResultChunkDelta(
+ index=0,
+ message=AssistantPromptMessage(content='{"name":'),
+ usage=create_mock_usage(prompt_tokens=8, completion_tokens=2),
+ ),
+ )
+ yield LLMResultChunk(
+ model="gpt-4o",
+ prompt_messages=[UserPromptMessage(content="test")],
+ system_fingerprint="test",
+ delta=LLMResultChunkDelta(
+ index=0,
+ message=AssistantPromptMessage(content=' "test"}'),
+ usage=create_mock_usage(prompt_tokens=8, completion_tokens=4),
+ ),
+ )
+
+ model_instance.invoke_llm.return_value = mock_streaming_response()
result = invoke_llm_with_pydantic_model(
provider="openai",
model_schema=model_schema,
model_instance=model_instance,
prompt_messages=[UserPromptMessage(content="Return a JSON object with name.")],
        output_model=ExampleOutput,
)
assert isinstance(result, ExampleOutput)
@@ -449,51 +548,51 @@ def test_structured_output_with_pydantic_model_validation_error():
model_schema=model_schema,
model_instance=model_instance,
prompt_messages=[UserPromptMessage(content="test")],
            output_model=ExampleOutput,
)
class TestGetDefaultValueForType:
- """Test cases for get_default_value_for_type function"""
+ """Test cases for _get_default_value_for_type function"""
def test_string_type(self):
- assert get_default_value_for_type("string") == ""
+ assert _get_default_value_for_type("string") == ""
def test_object_type(self):
- assert get_default_value_for_type("object") == {}
+ assert _get_default_value_for_type("object") == {}
def test_array_type(self):
- assert get_default_value_for_type("array") == []
+ assert _get_default_value_for_type("array") == []
def test_number_type(self):
- assert get_default_value_for_type("number") == 0
+ assert _get_default_value_for_type("number") == 0
def test_integer_type(self):
- assert get_default_value_for_type("integer") == 0
+ assert _get_default_value_for_type("integer") == 0
def test_boolean_type(self):
- assert get_default_value_for_type("boolean") is False
+ assert _get_default_value_for_type("boolean") is False
def test_null_type(self):
- assert get_default_value_for_type("null") is None
+ assert _get_default_value_for_type("null") is None
def test_none_type(self):
- assert get_default_value_for_type(None) is None
+ assert _get_default_value_for_type(None) is None
def test_unknown_type(self):
- assert get_default_value_for_type("unknown") is None
+ assert _get_default_value_for_type("unknown") is None
def test_union_type_string_null(self):
# ["string", "null"] should return "" (first non-null type)
- assert get_default_value_for_type(["string", "null"]) == ""
+ assert _get_default_value_for_type(["string", "null"]) == ""
def test_union_type_null_first(self):
# ["null", "integer"] should return 0 (first non-null type)
- assert get_default_value_for_type(["null", "integer"]) == 0
+ assert _get_default_value_for_type(["null", "integer"]) == 0
def test_union_type_only_null(self):
# ["null"] should return None
- assert get_default_value_for_type(["null"]) is None
+ assert _get_default_value_for_type(["null"]) is None
class TestFillDefaultsFromSchema:
@@ -501,7 +600,7 @@ class TestFillDefaultsFromSchema:
def test_simple_required_fields(self):
"""Test filling simple required fields"""
- schema: SchemaData = {
+ schema = {
"type": "object",
"properties": {
"name": {"type": "string"},
@@ -510,7 +609,7 @@ class TestFillDefaultsFromSchema:
},
"required": ["name", "age"],
}
- output: SchemaData = {"name": "Alice"}
+ output = {"name": "Alice"}
result = fill_defaults_from_schema(output, schema)
@@ -520,7 +619,7 @@ class TestFillDefaultsFromSchema:
def test_non_required_fields_not_filled(self):
"""Test that non-required fields are not filled"""
- schema: SchemaData = {
+ schema = {
"type": "object",
"properties": {
"required_field": {"type": "string"},
@@ -528,7 +627,7 @@ class TestFillDefaultsFromSchema:
},
"required": ["required_field"],
}
- output: SchemaData = {}
+ output = {}
result = fill_defaults_from_schema(output, schema)
@@ -537,7 +636,7 @@ class TestFillDefaultsFromSchema:
def test_nested_object_required_fields(self):
"""Test filling nested object required fields"""
- schema: SchemaData = {
+ schema = {
"type": "object",
"properties": {
"user": {
@@ -560,7 +659,7 @@ class TestFillDefaultsFromSchema:
},
"required": ["user"],
}
- output: SchemaData = {
+ output = {
"user": {
"name": "Alice",
"address": {
@@ -585,7 +684,7 @@ class TestFillDefaultsFromSchema:
def test_missing_nested_object_created(self):
"""Test that missing required nested objects are created"""
- schema: SchemaData = {
+ schema = {
"type": "object",
"properties": {
"metadata": {
@@ -599,7 +698,7 @@ class TestFillDefaultsFromSchema:
},
"required": ["metadata"],
}
- output: SchemaData = {}
+ output = {}
result = fill_defaults_from_schema(output, schema)
@@ -611,7 +710,7 @@ class TestFillDefaultsFromSchema:
def test_all_types_default_values(self):
"""Test default values for all types"""
- schema: SchemaData = {
+ schema = {
"type": "object",
"properties": {
"str_field": {"type": "string"},
@@ -623,7 +722,7 @@ class TestFillDefaultsFromSchema:
},
"required": ["str_field", "int_field", "num_field", "bool_field", "arr_field", "obj_field"],
}
- output: SchemaData = {}
+ output = {}
result = fill_defaults_from_schema(output, schema)
@@ -638,7 +737,7 @@ class TestFillDefaultsFromSchema:
def test_existing_values_preserved(self):
"""Test that existing values are not overwritten"""
- schema: SchemaData = {
+ schema = {
"type": "object",
"properties": {
"name": {"type": "string"},
@@ -646,7 +745,7 @@ class TestFillDefaultsFromSchema:
},
"required": ["name", "count"],
}
- output: SchemaData = {"name": "Bob", "count": 42}
+ output = {"name": "Bob", "count": 42}
result = fill_defaults_from_schema(output, schema)
@@ -654,7 +753,7 @@ class TestFillDefaultsFromSchema:
def test_complex_nested_structure(self):
"""Test complex nested structure with multiple levels"""
- schema: SchemaData = {
+ schema = {
"type": "object",
"properties": {
"user": {
@@ -690,7 +789,7 @@ class TestFillDefaultsFromSchema:
},
"required": ["user", "tags", "metadata", "is_active"],
}
- output: SchemaData = {
+ output = {
"user": {
"name": "Alice",
"age": 25,
@@ -730,8 +829,8 @@ class TestFillDefaultsFromSchema:
def test_empty_schema(self):
"""Test with empty schema"""
- schema: SchemaData = {}
- output: SchemaData = {"any": "value"}
+ schema = {}
+ output = {"any": "value"}
result = fill_defaults_from_schema(output, schema)
@@ -739,14 +838,14 @@ class TestFillDefaultsFromSchema:
def test_schema_without_required(self):
"""Test schema without required field"""
- schema: SchemaData = {
+ schema = {
"type": "object",
"properties": {
"optional1": {"type": "string"},
"optional2": {"type": "integer"},
},
}
- output: SchemaData = {}
+ output = {}
result = fill_defaults_from_schema(output, schema)