From e0082dbf182a00d345baa3f59bb3e6ffdfbe702b Mon Sep 17 00:00:00 2001
From: Stream
Date: Wed, 4 Feb 2026 21:13:07 +0800
Subject: [PATCH] revert: add tools for output in agent mode

feat: hide output tools and improve JSON formatting for structured output
feat: hide output tools and improve JSON formatting for structured output
fix: handle prompt template correctly to extract selectors for step run
fix: emit StreamChunkEvent correctly for sandbox agent
chore: better debug message
fix: incorrect output tool runtime selection
fix: type issues
fix: align parameter list
fix: align parameter list
fix: hide internal builtin providers from tool list
vibe: implement file structured output
vibe: implement file structured output
fix: refix parameter for tool
fix: crash
fix: crash
refactor: remove union types
fix: type check
Merge branch 'feat/structured-output-with-sandbox' into feat/support-agent-sandbox
fix: provide json as text
fix: provide json as text
fix: get AgentResult correctly
fix: provides correct prompts, tools and terminal predicates
fix: provides correct prompts, tools and terminal predicates
fix: circular import
feat: support structured output in sandbox and tool mode
---
 api/core/agent/agent_app_runner.py            |  61 ++-
 api/core/agent/base_agent_runner.py           |  10 -
 api/core/agent/entities.py                    |  28 +-
 .../agent/output_parser/cot_output_parser.py  |  46 +-
 api/core/agent/output_tools.py                | 108 -----
 api/core/agent/patterns/base.py               |   6 +-
 api/core/agent/patterns/function_call.py      | 118 ++----
 api/core/agent/patterns/react.py              | 167 +++-----
 api/core/agent/patterns/strategy_factory.py   |  23 +-
 api/core/agent/prompt/template.py             |  20 +-
 api/core/llm_generator/llm_generator.py       |  13 +-
 .../llm_generator/output_parser/file_ref.py   | 290 +++++------
 .../output_parser/structured_output.py        | 134 +++---
 api/core/plugin/utils/chunk_merger.py         |  10 +-
 api/core/sandbox/bash/session.py              | 136 +++---
 api/core/tools/__base/tool.py                 |   4 +-
 .../providers/agent_output/_assets/icon.svg   |   5 -
 .../providers/agent_output/agent_output.py    |   8 -
 .../providers/agent_output/agent_output.yaml  |  10 -
 .../agent_output/tools/final_output_answer.py |  17 -
 .../tools/final_output_answer.yaml            |  18 -
 .../tools/final_structured_output.py          |  17 -
 .../tools/final_structured_output.yaml        |  18 -
 .../agent_output/tools/illegal_output.py      |  21 -
 .../agent_output/tools/illegal_output.yaml    |  18 -
 .../agent_output/tools/output_text.py         |  17 -
 .../agent_output/tools/output_text.yaml       |  18 -
 api/core/tools/tool_manager.py                |  25 +-
 .../knowledge_retrieval_node.py               |   4 +-
 api/core/workflow/nodes/llm/entities.py       |   8 +-
 api/core/workflow/nodes/llm/llm_utils.py      |   4 +-
 api/core/workflow/nodes/llm/node.py           | 280 ++++---------
 .../question_classifier_node.py               |   3 +-
 .../tools/builtin_tools_manage_service.py     |   4 -
 api/tests/fixtures/file output schema.yml     |   5 +-
 .../output_parser/test_cot_output_parser.py   |   5 +-
 .../core/agent/patterns/test_base.py          |   2 +-
 .../core/agent/test_agent_app_runner.py       |   6 +-
 .../unit_tests/core/agent/test_entities.py    |   2 +-
 .../output_parser/test_file_ref.py            | 290 ++++++++++---
 .../test_structured_output_parser.py          | 393 +++++++++++-------
 41 files changed, 1014 insertions(+), 1358 deletions(-)
 delete mode 100644 api/core/agent/output_tools.py
 delete mode 100644 api/core/tools/builtin_tool/providers/agent_output/_assets/icon.svg
 delete mode 100644 api/core/tools/builtin_tool/providers/agent_output/agent_output.py
 delete mode 100644 api/core/tools/builtin_tool/providers/agent_output/agent_output.yaml
 delete mode 100644 api/core/tools/builtin_tool/providers/agent_output/tools/final_output_answer.py
 delete mode 100644 api/core/tools/builtin_tool/providers/agent_output/tools/final_output_answer.yaml
 delete mode 100644 api/core/tools/builtin_tool/providers/agent_output/tools/final_structured_output.py
 delete mode 100644 api/core/tools/builtin_tool/providers/agent_output/tools/final_structured_output.yaml
 delete mode 100644 api/core/tools/builtin_tool/providers/agent_output/tools/illegal_output.py
 delete mode 100644 api/core/tools/builtin_tool/providers/agent_output/tools/illegal_output.yaml
 delete mode 100644 api/core/tools/builtin_tool/providers/agent_output/tools/output_text.py
 delete mode 100644 api/core/tools/builtin_tool/providers/agent_output/tools/output_text.yaml

diff --git a/api/core/agent/agent_app_runner.py b/api/core/agent/agent_app_runner.py
index c67d2d831d..2ee0a23aab 100644
--- a/api/core/agent/agent_app_runner.py
+++ b/api/core/agent/agent_app_runner.py
@@ -1,4 +1,3 @@
-import json
 import logging
 from collections.abc import Generator
 from copy import deepcopy
@@ -24,7 +23,7 @@ from core.model_runtime.entities import (
 from core.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes
 from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
 from core.tools.__base.tool import Tool
-from core.tools.entities.tool_entities import ToolInvokeFrom, ToolInvokeMeta
+from core.tools.entities.tool_entities import ToolInvokeMeta
 from core.tools.tool_engine import ToolEngine
 from models.model import Message

@@ -103,13 +102,11 @@ class AgentAppRunner(BaseAgentRunner):
             agent_strategy=self.config.strategy,
             tool_invoke_hook=tool_invoke_hook,
             instruction=instruction,
-            tenant_id=self.tenant_id,
-            invoke_from=self.application_generate_entity.invoke_from,
-            tool_invoke_from=ToolInvokeFrom.WORKFLOW,
         )

         # Initialize state variables
         current_agent_thought_id = None
+        has_published_thought = False
         current_tool_name: str | None = None
         self._current_message_file_ids: list[str] = []

@@ -121,6 +118,7 @@ class AgentAppRunner(BaseAgentRunner):
             prompt_messages=prompt_messages,
             model_parameters=app_generate_entity.model_conf.parameters,
             stop=app_generate_entity.model_conf.stop,
+            stream=True,
         )

         # Consume generator and collect result
@@ -135,10 +133,17 @@ class AgentAppRunner(BaseAgentRunner):
                     break

                 if isinstance(output, LLMResultChunk):
-                    # No more expect streaming data
-                    continue
+                    # Handle LLM chunk
+                    if current_agent_thought_id and not has_published_thought:
+                        self.queue_manager.publish(
+                            QueueAgentThoughtEvent(agent_thought_id=current_agent_thought_id),
+                            PublishFrom.APPLICATION_MANAGER,
+                        )
+                        has_published_thought = True

-                else:
+                    yield output
+
+                elif isinstance(output, AgentLog):
                     # Handle Agent Log using log_type for type-safe dispatch
                     if output.status == AgentLog.LogStatus.START:
                         if output.log_type == AgentLog.LogType.ROUND:
@@ -151,6 +156,7 @@ class AgentAppRunner(BaseAgentRunner):
                                 tool_input="",
                                 messages_ids=message_file_ids,
                             )
+                            has_published_thought = False

                         elif output.log_type == AgentLog.LogType.TOOL_CALL:
                             if current_agent_thought_id is None:
@@ -258,30 +264,23 @@ class AgentAppRunner(BaseAgentRunner):
             raise

         # Process final result
-        if not isinstance(result, AgentResult):
-            raise ValueError("Agent did not return AgentResult")
-        output_payload = result.output
-        if isinstance(output_payload, dict):
-            final_answer = json.dumps(output_payload, ensure_ascii=False)
-        elif isinstance(output_payload, str):
-            final_answer = output_payload
-        else:
- raise ValueError("Final output is not a string or structured data.") - usage = result.usage or LLMUsage.empty_usage() + if isinstance(result, AgentResult): + final_answer = result.text + usage = result.usage or LLMUsage.empty_usage() - # Publish end event - self.queue_manager.publish( - QueueMessageEndEvent( - llm_result=LLMResult( - model=self.model_instance.model, - prompt_messages=prompt_messages, - message=AssistantPromptMessage(content=final_answer), - usage=usage, - system_fingerprint="", - ) - ), - PublishFrom.APPLICATION_MANAGER, - ) + # Publish end event + self.queue_manager.publish( + QueueMessageEndEvent( + llm_result=LLMResult( + model=self.model_instance.model, + prompt_messages=prompt_messages, + message=AssistantPromptMessage(content=final_answer), + usage=usage, + system_fingerprint="", + ) + ), + PublishFrom.APPLICATION_MANAGER, + ) def _init_system_message(self, prompt_template: str, prompt_messages: list[PromptMessage]) -> list[PromptMessage]: """ diff --git a/api/core/agent/base_agent_runner.py b/api/core/agent/base_agent_runner.py index e5da2b3e12..b5459611b1 100644 --- a/api/core/agent/base_agent_runner.py +++ b/api/core/agent/base_agent_runner.py @@ -7,7 +7,6 @@ from typing import Union, cast from sqlalchemy import select from core.agent.entities import AgentEntity, AgentToolEntity, ExecutionContext -from core.agent.output_tools import build_agent_output_tools from core.app.app_config.features.file_upload.manager import FileUploadConfigManager from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfig from core.app.apps.base_app_queue_manager import AppQueueManager @@ -37,7 +36,6 @@ from core.model_runtime.model_providers.__base.large_language_model import Large from core.prompt.utils.extract_thread_messages import extract_thread_messages from core.tools.__base.tool import Tool from core.tools.entities.tool_entities import ( - ToolInvokeFrom, ToolParameter, ) from core.tools.tool_manager import ToolManager @@ -253,14 +251,6 @@ class BaseAgentRunner(AppRunner): # save tool entity tool_instances[dataset_tool.entity.identity.name] = dataset_tool - output_tools = build_agent_output_tools( - tenant_id=self.tenant_id, - invoke_from=self.application_generate_entity.invoke_from, - tool_invoke_from=ToolInvokeFrom.AGENT, - ) - for tool in output_tools: - tool_instances[tool.entity.identity.name] = tool - return tool_instances, prompt_messages_tools def update_prompt_message_tool(self, tool: Tool, prompt_tool: PromptMessageTool) -> PromptMessageTool: diff --git a/api/core/agent/entities.py b/api/core/agent/entities.py index ad8bdb723a..46af4d2d72 100644 --- a/api/core/agent/entities.py +++ b/api/core/agent/entities.py @@ -1,11 +1,10 @@ import uuid from collections.abc import Mapping from enum import StrEnum -from typing import Any +from typing import Any, Union from pydantic import BaseModel, Field -from core.agent.output_tools import FINAL_OUTPUT_TOOL from core.tools.entities.tool_entities import ToolInvokeMessage, ToolProviderType @@ -42,9 +41,9 @@ class AgentScratchpadUnit(BaseModel): """ action_name: str - action_input: dict[str, Any] | str + action_input: Union[dict, str] - def to_dict(self) -> dict[str, Any]: + def to_dict(self): """ Convert to dictionary. """ @@ -63,9 +62,9 @@ class AgentScratchpadUnit(BaseModel): """ Check if the scratchpad unit is final. 
""" - if self.action is None: - return False - return self.action.action_name.lower() == FINAL_OUTPUT_TOOL + return self.action is None or ( + "final" in self.action.action_name.lower() and "answer" in self.action.action_name.lower() + ) class AgentEntity(BaseModel): @@ -126,7 +125,7 @@ class ExecutionContext(BaseModel): "tenant_id": self.tenant_id, } - def with_updates(self, **kwargs: Any) -> "ExecutionContext": + def with_updates(self, **kwargs) -> "ExecutionContext": """Create a new context with updated fields.""" data = self.to_dict() data.update(kwargs) @@ -179,23 +178,12 @@ class AgentLog(BaseModel): metadata: Mapping[LogMetadata, Any] = Field(default={}, description="The metadata of the log") -class AgentOutputKind(StrEnum): - """ - Agent output kind. - """ - - OUTPUT_TEXT = "output_text" - FINAL_OUTPUT_ANSWER = "final_output_answer" - FINAL_STRUCTURED_OUTPUT = "final_structured_output" - ILLEGAL_OUTPUT = "illegal_output" - - class AgentResult(BaseModel): """ Agent execution result. """ - output: str | dict = Field(default="", description="The generated output") + text: str = Field(default="", description="The generated text") files: list[Any] = Field(default_factory=list, description="Files produced during execution") usage: Any | None = Field(default=None, description="LLM usage statistics") finish_reason: str | None = Field(default=None, description="Reason for completion") diff --git a/api/core/agent/output_parser/cot_output_parser.py b/api/core/agent/output_parser/cot_output_parser.py index 2d4ce58b6a..7c8f09e6b9 100644 --- a/api/core/agent/output_parser/cot_output_parser.py +++ b/api/core/agent/output_parser/cot_output_parser.py @@ -1,7 +1,7 @@ import json import re from collections.abc import Generator -from typing import Any, Union, cast +from typing import Union from core.agent.entities import AgentScratchpadUnit from core.model_runtime.entities.llm_entities import LLMResultChunk @@ -10,52 +10,46 @@ from core.model_runtime.entities.llm_entities import LLMResultChunk class CotAgentOutputParser: @classmethod def handle_react_stream_output( - cls, llm_response: Generator[LLMResultChunk, None, None], usage_dict: dict[str, Any] + cls, llm_response: Generator[LLMResultChunk, None, None], usage_dict: dict ) -> Generator[Union[str, AgentScratchpadUnit.Action], None, None]: - def parse_action(action: Any) -> Union[str, AgentScratchpadUnit.Action]: - action_name: str | None = None - action_input: Any | None = None - parsed_action: Any = action - if isinstance(parsed_action, str): + def parse_action(action) -> Union[str, AgentScratchpadUnit.Action]: + action_name = None + action_input = None + if isinstance(action, str): try: - parsed_action = json.loads(parsed_action, strict=False) + action = json.loads(action, strict=False) except json.JSONDecodeError: - return parsed_action or "" + return action or "" # cohere always returns a list - if isinstance(parsed_action, list): - action_list: list[Any] = cast(list[Any], parsed_action) - if len(action_list) == 1: - parsed_action = action_list[0] + if isinstance(action, list) and len(action) == 1: + action = action[0] - if isinstance(parsed_action, dict): - action_dict: dict[str, Any] = cast(dict[str, Any], parsed_action) - for key, value in action_dict.items(): - if "input" in key.lower(): - action_input = value - elif isinstance(value, str): - action_name = value - else: - return json.dumps(parsed_action) + for key, value in action.items(): + if "input" in key.lower(): + action_input = value + else: + action_name = value if action_name is 
not None and action_input is not None: return AgentScratchpadUnit.Action( action_name=action_name, action_input=action_input, ) - return json.dumps(parsed_action) + else: + return json.dumps(action) - def extra_json_from_code_block(code_block: str) -> list[dict[str, Any] | list[Any]]: + def extra_json_from_code_block(code_block) -> list[Union[list, dict]]: blocks = re.findall(r"```[json]*\s*([\[{].*[]}])\s*```", code_block, re.DOTALL | re.IGNORECASE) if not blocks: return [] try: - json_blocks: list[dict[str, Any] | list[Any]] = [] + json_blocks = [] for block in blocks: json_text = re.sub(r"^[a-zA-Z]+\n", "", block.strip(), flags=re.MULTILINE) json_blocks.append(json.loads(json_text, strict=False)) return json_blocks - except Exception: + except: return [] code_block_cache = "" diff --git a/api/core/agent/output_tools.py b/api/core/agent/output_tools.py deleted file mode 100644 index 253d138c3a..0000000000 --- a/api/core/agent/output_tools.py +++ /dev/null @@ -1,108 +0,0 @@ -from __future__ import annotations - -from collections.abc import Mapping, Sequence -from typing import Any - -from core.app.entities.app_invoke_entities import InvokeFrom -from core.tools.__base.tool import Tool -from core.tools.entities.common_entities import I18nObject -from core.tools.entities.tool_entities import ToolInvokeFrom, ToolInvokeMessage, ToolParameter, ToolProviderType -from core.tools.tool_manager import ToolManager - -OUTPUT_TOOL_PROVIDER = "agent_output" - -OUTPUT_TEXT_TOOL = "output_text" -FINAL_OUTPUT_TOOL = "final_output_answer" -FINAL_STRUCTURED_OUTPUT_TOOL = "final_structured_output" -ILLEGAL_OUTPUT_TOOL = "illegal_output" - -OUTPUT_TOOL_NAMES: Sequence[str] = ( - OUTPUT_TEXT_TOOL, - FINAL_OUTPUT_TOOL, - FINAL_STRUCTURED_OUTPUT_TOOL, - ILLEGAL_OUTPUT_TOOL, -) - -TERMINAL_OUTPUT_TOOL_NAMES: Sequence[str] = (FINAL_OUTPUT_TOOL, FINAL_STRUCTURED_OUTPUT_TOOL) - - -TERMINAL_OUTPUT_MESSAGE = "Final output received. This ends the current session." 
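To make the reverted parse_action() contract concrete: a JSON blob with an "*input*"-named key plus any other value becomes an Action; anything unparseable is passed through, and anything parseable but incomplete is re-serialized. A runnable sketch that mirrors the logic above (a plain tuple stands in for AgentScratchpadUnit.Action):

import json

def parse_action(action):
    if isinstance(action, str):
        try:
            action = json.loads(action, strict=False)
        except json.JSONDecodeError:
            return action or ""
    if isinstance(action, list) and len(action) == 1:  # cohere always returns a list
        action = action[0]
    action_name = action_input = None
    for key, value in action.items():
        if "input" in key.lower():
            action_input = value
        else:
            action_name = value
    if action_name is not None and action_input is not None:
        return (action_name, action_input)  # stand-in for AgentScratchpadUnit.Action
    return json.dumps(action)

print(parse_action('{"action": "search", "action_input": {"query": "x"}}'))  # ('search', {'query': 'x'})
print(parse_action('{"note": "no input key"}'))  # re-serialized: '{"note": "no input key"}'

Note that, as in the patched code, a parsed value that is neither a dict nor a one-element list would reach .items() and raise; the old code guarded this with an isinstance check.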
diff --git a/api/core/agent/output_tools.py b/api/core/agent/output_tools.py
deleted file mode 100644
index 253d138c3a..0000000000
--- a/api/core/agent/output_tools.py
+++ /dev/null
@@ -1,108 +0,0 @@
-from __future__ import annotations
-
-from collections.abc import Mapping, Sequence
-from typing import Any
-
-from core.app.entities.app_invoke_entities import InvokeFrom
-from core.tools.__base.tool import Tool
-from core.tools.entities.common_entities import I18nObject
-from core.tools.entities.tool_entities import ToolInvokeFrom, ToolInvokeMessage, ToolParameter, ToolProviderType
-from core.tools.tool_manager import ToolManager
-
-OUTPUT_TOOL_PROVIDER = "agent_output"
-
-OUTPUT_TEXT_TOOL = "output_text"
-FINAL_OUTPUT_TOOL = "final_output_answer"
-FINAL_STRUCTURED_OUTPUT_TOOL = "final_structured_output"
-ILLEGAL_OUTPUT_TOOL = "illegal_output"
-
-OUTPUT_TOOL_NAMES: Sequence[str] = (
-    OUTPUT_TEXT_TOOL,
-    FINAL_OUTPUT_TOOL,
-    FINAL_STRUCTURED_OUTPUT_TOOL,
-    ILLEGAL_OUTPUT_TOOL,
-)
-
-TERMINAL_OUTPUT_TOOL_NAMES: Sequence[str] = (FINAL_OUTPUT_TOOL, FINAL_STRUCTURED_OUTPUT_TOOL)
-
-
-TERMINAL_OUTPUT_MESSAGE = "Final output received. This ends the current session."
-
-
-def is_terminal_output_tool(tool_name: str) -> bool:
-    return tool_name in TERMINAL_OUTPUT_TOOL_NAMES
-
-
-def get_terminal_tool_name(structured_output_enabled: bool) -> str:
-    return FINAL_STRUCTURED_OUTPUT_TOOL if structured_output_enabled else FINAL_OUTPUT_TOOL
-
-
-OUTPUT_TOOL_NAME_SET = set(OUTPUT_TOOL_NAMES)
-
-
-def build_agent_output_tools(
-    tenant_id: str,
-    invoke_from: InvokeFrom,
-    tool_invoke_from: ToolInvokeFrom,
-    structured_output_schema: Mapping[str, Any] | None = None,
-) -> list[Tool]:
-    def get_tool_runtime(_tool_name: str) -> Tool:
-        return ToolManager.get_tool_runtime(
-            provider_type=ToolProviderType.BUILT_IN,
-            provider_id=OUTPUT_TOOL_PROVIDER,
-            tool_name=_tool_name,
-            tenant_id=tenant_id,
-            invoke_from=invoke_from,
-            tool_invoke_from=tool_invoke_from,
-        )
-
-    tools: list[Tool] = [
-        get_tool_runtime(OUTPUT_TEXT_TOOL),
-        get_tool_runtime(ILLEGAL_OUTPUT_TOOL),
-    ]
-
-    if structured_output_schema:
-        raw_tool = get_tool_runtime(FINAL_STRUCTURED_OUTPUT_TOOL)
-        raw_tool.entity = raw_tool.entity.model_copy(deep=True)
-        data_parameter = ToolParameter(
-            name="data",
-            type=ToolParameter.ToolParameterType.OBJECT,
-            form=ToolParameter.ToolParameterForm.LLM,
-            required=True,
-            input_schema=dict(structured_output_schema),
-            label=I18nObject(en_US="__Data", zh_Hans="__Data"),
-        )
-        raw_tool.entity.parameters = [data_parameter]
-
-        def invoke_tool(
-            user_id: str,
-            tool_parameters: dict[str, Any],
-            conversation_id: str | None = None,
-            app_id: str | None = None,
-            message_id: str | None = None,
-        ) -> ToolInvokeMessage:
-            data = tool_parameters["data"]
-            if not data:
-                return ToolInvokeMessage(message=ToolInvokeMessage.TextMessage(text="`data` field is required"))
-            if not isinstance(data, dict):
-                return ToolInvokeMessage(message=ToolInvokeMessage.TextMessage(text="`data` must be a dict"))
-            return ToolInvokeMessage(message=ToolInvokeMessage.TextMessage(text=TERMINAL_OUTPUT_MESSAGE))
-
-        raw_tool._invoke = invoke_tool  # pyright: ignore[reportPrivateUsage]
-        tools.append(raw_tool)
-    else:
-        raw_tool = get_tool_runtime(FINAL_OUTPUT_TOOL)
-
-        def invoke_tool(
-            user_id: str,
-            tool_parameters: dict[str, Any],
-            conversation_id: str | None = None,
-            app_id: str | None = None,
-            message_id: str | None = None,
-        ) -> ToolInvokeMessage:
-            return ToolInvokeMessage(message=ToolInvokeMessage.TextMessage(text=TERMINAL_OUTPUT_MESSAGE))
-
-        raw_tool._invoke = invoke_tool  # pyright: ignore[reportPrivateUsage]
-        tools.append(raw_tool)
-
-    return tools
diff --git a/api/core/agent/patterns/base.py b/api/core/agent/patterns/base.py
index a97b796837..03c349a475 100644
--- a/api/core/agent/patterns/base.py
+++ b/api/core/agent/patterns/base.py
@@ -10,7 +10,6 @@ from collections.abc import Callable, Generator
 from typing import TYPE_CHECKING, Any

 from core.agent.entities import AgentLog, AgentResult, ExecutionContext
-from core.agent.output_tools import ILLEGAL_OUTPUT_TOOL
 from core.file import File
 from core.model_manager import ModelInstance
 from core.model_runtime.entities import (
@@ -57,7 +56,8 @@ class AgentPattern(ABC):

     @abstractmethod
     def run(
-        self, prompt_messages: list[PromptMessage], model_parameters: dict[str, Any], stop: list[str]
+        self, prompt_messages: list[PromptMessage], model_parameters: dict[str, Any], stop: list[str] = [],
+        stream: bool = True,
     ) -> Generator[LLMResultChunk | AgentLog, None, AgentResult]:
         """Execute the agent strategy."""
         pass
@@ -462,8 +462,6 @@ class AgentPattern(ABC):
         """Convert tools to prompt message format."""
         prompt_tools: list[PromptMessageTool] = []
         for tool in self.tools:
-            if tool.entity.identity.name == ILLEGAL_OUTPUT_TOOL:
-                continue
             prompt_tools.append(tool.to_prompt_message_tool())
         return prompt_tools
diff --git a/api/core/agent/patterns/function_call.py b/api/core/agent/patterns/function_call.py
index 9d58f4e00d..9c87c186d8 100644
--- a/api/core/agent/patterns/function_call.py
+++ b/api/core/agent/patterns/function_call.py
@@ -1,15 +1,10 @@
 """Function Call strategy implementation."""

 import json
-import uuid
 from collections.abc import Generator
-from typing import Any, Literal, Protocol, Union, cast
+from typing import Any, Union

 from core.agent.entities import AgentLog, AgentResult
-from core.agent.output_tools import (
-    ILLEGAL_OUTPUT_TOOL,
-    TERMINAL_OUTPUT_MESSAGE,
-)
 from core.file import File
 from core.model_runtime.entities import (
     AssistantPromptMessage,
@@ -30,7 +25,8 @@ class FunctionCallStrategy(AgentPattern):
     """Function Call strategy using model's native tool calling capability."""

     def run(
-        self, prompt_messages: list[PromptMessage], model_parameters: dict[str, Any], stop: list[str]
+        self, prompt_messages: list[PromptMessage], model_parameters: dict[str, Any], stop: list[str] = [],
+        stream: bool = True,
     ) -> Generator[LLMResultChunk | AgentLog, None, AgentResult]:
         """Execute the function call agent strategy."""
         # Convert tools to prompt format
@@ -43,23 +39,9 @@ class FunctionCallStrategy(AgentPattern):
         total_usage: dict[str, LLMUsage | None] = {"usage": None}
         messages: list[PromptMessage] = list(prompt_messages)  # Create mutable copy
         final_text: str = ""
-        final_tool_args: dict[str, Any] = {"!!!": "!!!"}
         finish_reason: str | None = None
         output_files: list[File] = []  # Track files produced by tools

-        class _LLMInvoker(Protocol):
-            def invoke_llm(
-                self,
-                *,
-                prompt_messages: list[PromptMessage],
-                model_parameters: dict[str, Any],
-                tools: list[PromptMessageTool],
-                stop: list[str],
-                stream: Literal[False],
-                user: str | None,
-                callbacks: list[Any],
-            ) -> LLMResult: ...
-
         while function_call_state and iteration_step <= max_iterations:
             function_call_state = False
             round_log = self._create_log(
@@ -69,7 +51,8 @@ class FunctionCallStrategy(AgentPattern):
                 data={},
             )
             yield round_log
-
+            # On last iteration, remove tools to force final answer
+            current_tools: list[PromptMessageTool] = [] if iteration_step == max_iterations else prompt_tools
             model_log = self._create_log(
                 label=f"{self.model_instance.model} Thought",
                 log_type=AgentLog.LogType.THOUGHT,
@@ -86,63 +69,47 @@ class FunctionCallStrategy(AgentPattern):
             round_usage: dict[str, LLMUsage | None] = {"usage": None}

             # Invoke model
-            invoker = cast(_LLMInvoker, self.model_instance)
-            chunks = invoker.invoke_llm(
+            chunks: Union[Generator[LLMResultChunk, None, None], LLMResult] = self.model_instance.invoke_llm(
                 prompt_messages=messages,
                 model_parameters=model_parameters,
-                tools=prompt_tools,
+                tools=current_tools,
                 stop=stop,
-                stream=False,
+                stream=stream,
                 user=self.context.user_id,
                 callbacks=[],
             )

             # Process response
             tool_calls, response_content, chunk_finish_reason = yield from self._handle_chunks(
-                chunks, round_usage, model_log, emit_chunks=False
+                chunks, round_usage, model_log
             )
-
-            if response_content:
-                replaced_tool_call = (
-                    str(uuid.uuid4()),
-                    ILLEGAL_OUTPUT_TOOL,
-                    {
-                        "raw": response_content,
-                    },
-                )
-                tool_calls.append(replaced_tool_call)
-
-            messages.append(self._create_assistant_message("", tool_calls))
+            messages.append(self._create_assistant_message(response_content, tool_calls))

             # Accumulate to total usage
             round_usage_value = round_usage.get("usage")
             if round_usage_value:
                 self._accumulate_usage(total_usage, round_usage_value)

+            # Update final text if no tool calls (this is likely the final answer)
+            if not tool_calls:
+                final_text = response_content
+
             # Update finish reason
             if chunk_finish_reason:
                 finish_reason = chunk_finish_reason

-            assert len(tool_calls) > 0
-
             # Process tool calls
             tool_outputs: dict[str, str] = {}
-            function_call_state = True
-            # Execute tools
-            for tool_call_id, tool_name, tool_args in tool_calls:
-                tool_response, tool_files, _ = yield from self._handle_tool_call(
-                    tool_name, tool_args, tool_call_id, messages, round_log
-                )
-                tool_entity = self._find_tool_by_name(tool_name)
-                if not tool_entity:
-                    raise ValueError(f"Tool {tool_name} not found")
-                tool_outputs[tool_name] = tool_response
-                # Track files produced by tools
-                output_files.extend(tool_files)
-                if tool_response == TERMINAL_OUTPUT_MESSAGE:
-                    function_call_state = False
-                    final_tool_args = tool_entity.transform_tool_parameters_type(tool_args)
-
+            if tool_calls:
+                function_call_state = True
+                # Execute tools
+                for tool_call_id, tool_name, tool_args in tool_calls:
+                    tool_response, tool_files, _ = yield from self._handle_tool_call(
+                        tool_name, tool_args, tool_call_id, messages, round_log
+                    )
+                    tool_outputs[tool_name] = tool_response
+                    # Track files produced by tools
+                    output_files.extend(tool_files)

             yield self._finish_log(
                 round_log,
                 data={
@@ -161,19 +128,8 @@ class FunctionCallStrategy(AgentPattern):
         # Return final result
         from core.agent.entities import AgentResult

-        output_payload: str | dict
-        output_text = final_tool_args.get("text")
-        output_structured_payload = final_tool_args.get("data")
-
-        if isinstance(output_structured_payload, dict):
-            output_payload = output_structured_payload
-        elif isinstance(output_text, str):
-            output_payload = output_text
-        else:
-            raise ValueError(f"Final output ({final_tool_args}) is not a string or structured data.")
-
         return AgentResult(
-            output=output_payload,
+            text=final_text,
             files=output_files,
             usage=total_usage.get("usage") or LLMUsage.empty_usage(),
             finish_reason=finish_reason,
@@ -184,8 +140,6 @@ class FunctionCallStrategy(AgentPattern):
         chunks: Union[Generator[LLMResultChunk, None, None], LLMResult],
         llm_usage: dict[str, LLMUsage | None],
         start_log: AgentLog,
-        *,
-        emit_chunks: bool,
     ) -> Generator[
         LLMResultChunk | AgentLog,
         None,
@@ -217,8 +171,7 @@ class FunctionCallStrategy(AgentPattern):
                     if chunk.delta.finish_reason:
                         finish_reason = chunk.delta.finish_reason

-                if emit_chunks:
-                    yield chunk
+                yield chunk
         else:
             # Non-streaming response
             result: LLMResult = chunks
@@ -233,12 +186,11 @@ class FunctionCallStrategy(AgentPattern):
                 self._accumulate_usage(llm_usage, result.usage)

             # Convert to streaming format
-            if emit_chunks:
-                yield LLMResultChunk(
-                    model=result.model,
-                    prompt_messages=result.prompt_messages,
-                    delta=LLMResultChunkDelta(index=0, message=result.message, usage=result.usage),
-                )
+            yield LLMResultChunk(
+                model=result.model,
+                prompt_messages=result.prompt_messages,
+                delta=LLMResultChunkDelta(index=0, message=result.message, usage=result.usage),
+            )

         yield self._finish_log(
             start_log,
             data={
@@ -248,14 +200,6 @@ class FunctionCallStrategy(AgentPattern):
         )
         return tool_calls, response_content, finish_reason

-    @staticmethod
-    def _format_output_text(value: Any) -> str:
-        if value is None:
-            return ""
-        if isinstance(value, str):
-            return value
-        return json.dumps(value, ensure_ascii=False)
-
     def _create_assistant_message(
         self, content: str, tool_calls: list[tuple[str, str, dict[str, Any]]] | None = None
     ) -> AssistantPromptMessage:
diff --git a/api/core/agent/patterns/react.py b/api/core/agent/patterns/react.py
index 88b4e8e930..b81d572aea 100644
--- a/api/core/agent/patterns/react.py
+++ b/api/core/agent/patterns/react.py
@@ -4,17 +4,10 @@ from __future__ import annotations

 import json
 from collections.abc import Generator
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any, Union

 from core.agent.entities import AgentLog, AgentResult, AgentScratchpadUnit, ExecutionContext
 from core.agent.output_parser.cot_output_parser import CotAgentOutputParser
-from core.agent.output_tools import (
-    FINAL_OUTPUT_TOOL,
-    FINAL_STRUCTURED_OUTPUT_TOOL,
-    ILLEGAL_OUTPUT_TOOL,
-    OUTPUT_TEXT_TOOL,
-    OUTPUT_TOOL_NAME_SET,
-)
 from core.file import File
 from core.model_manager import ModelInstance
 from core.model_runtime.entities import (
@@ -59,7 +52,8 @@ class ReActStrategy(AgentPattern):
         self.instruction = instruction

     def run(
-        self, prompt_messages: list[PromptMessage], model_parameters: dict[str, Any], stop: list[str]
+        self, prompt_messages: list[PromptMessage], model_parameters: dict[str, Any], stop: list[str] = [],
+        stream: bool = True,
     ) -> Generator[LLMResultChunk | AgentLog, None, AgentResult]:
         """Execute the ReAct agent strategy."""
         # Initialize tracking
@@ -71,19 +65,6 @@ class ReActStrategy(AgentPattern):
         output_files: list[File] = []  # Track files produced by tools
         final_text: str = ""
         finish_reason: str | None = None
-        tool_instance_names = {tool.entity.identity.name for tool in self.tools}
-        available_output_tool_names = {
-            tool_name
-            for tool_name in tool_instance_names
-            if tool_name in OUTPUT_TOOL_NAME_SET and tool_name != ILLEGAL_OUTPUT_TOOL
-        }
-        if FINAL_STRUCTURED_OUTPUT_TOOL in available_output_tool_names:
-            terminal_tool_name = FINAL_STRUCTURED_OUTPUT_TOOL
-        elif FINAL_OUTPUT_TOOL in available_output_tool_names:
-            terminal_tool_name = FINAL_OUTPUT_TOOL
-        else:
-            raise ValueError("No terminal output tool configured")
-        allow_illegal_output = ILLEGAL_OUTPUT_TOOL in tool_instance_names

         # Add "Observation" to stop sequences
         if "Observation" not in stop:
@@ -100,15 +81,10 @@ class ReActStrategy(AgentPattern):
             )
             yield round_log

-            # Build prompt with tool restrictions on last iteration
-            if iteration_step == max_iterations:
-                tools_for_prompt = [
-                    tool for tool in self.tools if tool.entity.identity.name in available_output_tool_names
-                ]
-            else:
-                tools_for_prompt = [tool for tool in self.tools if tool.entity.identity.name != ILLEGAL_OUTPUT_TOOL]
+            # Build prompt with/without tools based on iteration
+            include_tools = iteration_step < max_iterations
             current_messages = self._build_prompt_with_react_format(
-                prompt_messages, agent_scratchpad, tools_for_prompt, self.instruction
+                prompt_messages, agent_scratchpad, include_tools, self.instruction
             )

             model_log = self._create_log(
@@ -130,18 +106,18 @@ class ReActStrategy(AgentPattern):
                 messages_to_use = current_messages

             # Invoke model
-            chunks = self.model_instance.invoke_llm(
+            chunks: Union[Generator[LLMResultChunk, None, None], LLMResult] = self.model_instance.invoke_llm(
                 prompt_messages=messages_to_use,
                 model_parameters=model_parameters,
                 stop=stop,
-                stream=False,
+                stream=stream,
                 user=self.context.user_id or "",
                 callbacks=[],
             )

             # Process response
             scratchpad, chunk_finish_reason = yield from self._handle_chunks(
-                chunks, round_usage, model_log, current_messages, emit_chunks=False
+                chunks, round_usage, model_log, current_messages
             )
             agent_scratchpad.append(scratchpad)

@@ -155,44 +131,28 @@ class ReActStrategy(AgentPattern):
                 finish_reason = chunk_finish_reason

             # Check if we have an action to execute
-            if scratchpad.action is None:
-                if not allow_illegal_output:
-                    raise ValueError("Model did not call any tools")
-                illegal_action = AgentScratchpadUnit.Action(
-                    action_name=ILLEGAL_OUTPUT_TOOL,
-                    action_input={"raw": scratchpad.thought or ""},
-                )
-                scratchpad.action = illegal_action
-                scratchpad.action_str = illegal_action.model_dump_json()
+            if scratchpad.action and scratchpad.action.action_name.lower() != "final answer":
                 react_state = True
-                observation, tool_files = yield from self._handle_tool_call(illegal_action, current_messages, round_log)
-                scratchpad.observation = observation
-                output_files.extend(tool_files)
-            else:
-                action_name = scratchpad.action.action_name
-                if action_name == OUTPUT_TEXT_TOOL and isinstance(scratchpad.action.action_input, dict):
-                    pass  # output_text_payload = scratchpad.action.action_input.get("text")
-                elif action_name == FINAL_STRUCTURED_OUTPUT_TOOL and isinstance(scratchpad.action.action_input, dict):
-                    data = scratchpad.action.action_input.get("data")
-                    if isinstance(data, dict):
-                        pass  # structured_output_payload = data
-                elif action_name == FINAL_OUTPUT_TOOL:
-                    if isinstance(scratchpad.action.action_input, dict):
-                        final_text = self._format_output_text(scratchpad.action.action_input.get("text"))
-                    else:
-                        final_text = self._format_output_text(scratchpad.action.action_input)
-
+                # Execute tool
                 observation, tool_files = yield from self._handle_tool_call(
                     scratchpad.action, current_messages, round_log
                 )
                 scratchpad.observation = observation
+                # Track files produced by tools
                 output_files.extend(tool_files)
-                if action_name == terminal_tool_name:
-                    pass  # terminal_output_seen = True
-                    react_state = False
-                else:
-                    react_state = True
+                # Add observation to scratchpad for display
+                yield self._create_text_chunk(f"\nObservation: {observation}\n", current_messages)
+            else:
+                # Extract final answer
+                if scratchpad.action and scratchpad.action.action_input:
+                    final_answer = scratchpad.action.action_input
+                    if isinstance(final_answer, dict):
+                        final_answer = json.dumps(final_answer, ensure_ascii=False)
+                    final_text = str(final_answer)
+                elif scratchpad.thought:
+                    # If no action but we have thought, use thought as final answer
+                    final_text = scratchpad.thought

             yield self._finish_log(
                 round_log,
@@ -208,22 +168,17 @@ class ReActStrategy(AgentPattern):

         # Return final result
-        output_payload: str | dict
-
-        # TODO
+        from core.agent.entities import AgentResult

         return AgentResult(
-            output=output_payload,
-            files=output_files,
-            usage=total_usage.get("usage"),
-            finish_reason=finish_reason,
+            text=final_text, files=output_files, usage=total_usage.get("usage"), finish_reason=finish_reason
         )

     def _build_prompt_with_react_format(
         self,
         original_messages: list[PromptMessage],
         agent_scratchpad: list[AgentScratchpadUnit],
-        tools: list[Tool] | None,
+        include_tools: bool = True,
         instruction: str = "",
     ) -> list[PromptMessage]:
         """Build prompt messages with ReAct format."""
@@ -240,13 +195,9 @@ class ReActStrategy(AgentPattern):
         # Format tools
         tools_str = ""
         tool_names = []
-        if tools:
+        if include_tools and self.tools:
             # Convert tools to prompt message tools format
-            prompt_tools = [
-                tool.to_prompt_message_tool()
-                for tool in tools
-                if tool.entity.identity.name != ILLEGAL_OUTPUT_TOOL
-            ]
+            prompt_tools = [tool.to_prompt_message_tool() for tool in self.tools]
             tool_names = [tool.name for tool in prompt_tools]

             # Format tools as JSON for comprehensive information
@@ -258,19 +209,12 @@ class ReActStrategy(AgentPattern):
             tools_str = "No tools available"
             tool_names_str = ""

-        final_tool_name = FINAL_OUTPUT_TOOL
-        if FINAL_STRUCTURED_OUTPUT_TOOL in tool_names:
-            final_tool_name = FINAL_STRUCTURED_OUTPUT_TOOL
-        if final_tool_name not in tool_names:
-            raise ValueError("No terminal output tool available for prompt")
-
         # Replace placeholders in the existing system prompt
         updated_content = msg.content
         assert isinstance(updated_content, str)
         updated_content = updated_content.replace("{{instruction}}", instruction)
         updated_content = updated_content.replace("{{tools}}", tools_str)
         updated_content = updated_content.replace("{{tool_names}}", tool_names_str)
-        updated_content = updated_content.replace("{{final_tool_name}}", final_tool_name)

         # Create new SystemPromptMessage with updated content
         messages[i] = SystemPromptMessage(content=updated_content)
@@ -302,12 +246,10 @@ class ReActStrategy(AgentPattern):

     def _handle_chunks(
         self,
-        chunks: LLMResult,
+        chunks: Union[Generator[LLMResultChunk, None, None], LLMResult],
         llm_usage: dict[str, Any],
         model_log: AgentLog,
         current_messages: list[PromptMessage],
-        *,
-        emit_chunks: bool,
     ) -> Generator[
         LLMResultChunk | AgentLog,
         None,
@@ -319,20 +261,25 @@ class ReActStrategy(AgentPattern):
         """
         usage_dict: dict[str, Any] = {}

-        def result_to_chunks() -> Generator[LLMResultChunk, None, None]:
-            yield LLMResultChunk(
-                model=chunks.model,
-                prompt_messages=chunks.prompt_messages,
-                delta=LLMResultChunkDelta(
-                    index=0,
-                    message=chunks.message,
-                    usage=chunks.usage,
-                    finish_reason=None,  # LLMResult doesn't have finish_reason, only streaming chunks do
-                ),
-                system_fingerprint=chunks.system_fingerprint or "",
-            )
+        # Convert non-streaming to streaming format if needed
+        if isinstance(chunks, LLMResult):
+            # Create a generator from the LLMResult
+            def result_to_chunks() -> Generator[LLMResultChunk, None, None]:
+                yield LLMResultChunk(
+                    model=chunks.model,
+                    prompt_messages=chunks.prompt_messages,
+                    delta=LLMResultChunkDelta(
+                        index=0,
+                        message=chunks.message,
+                        usage=chunks.usage,
+                        finish_reason=None,  # LLMResult doesn't have finish_reason, only streaming chunks do
+                    ),
+                    system_fingerprint=chunks.system_fingerprint or "",
+                )

-        streaming_chunks = result_to_chunks()
+            streaming_chunks = result_to_chunks()
+        else:
+            streaming_chunks = chunks

         react_chunks = CotAgentOutputParser.handle_react_stream_output(streaming_chunks, usage_dict)

@@ -356,18 +303,14 @@ class ReActStrategy(AgentPattern):
                 scratchpad.action_str = action_str
                 scratchpad.action = chunk

-                if emit_chunks:
-                    yield self._create_text_chunk(json.dumps(chunk.model_dump()), current_messages)
-            elif isinstance(chunk, str):
+                yield self._create_text_chunk(json.dumps(chunk.model_dump()), current_messages)
+            else:
                 # Text chunk
                 chunk_text = str(chunk)
                 scratchpad.agent_response = (scratchpad.agent_response or "") + chunk_text
                 scratchpad.thought = (scratchpad.thought or "") + chunk_text

-                if emit_chunks:
-                    yield self._create_text_chunk(chunk_text, current_messages)
-            else:
-                raise ValueError(f"Unexpected chunk type: {type(chunk)}")
+                yield self._create_text_chunk(chunk_text, current_messages)

             # Update usage
             if usage_dict.get("usage"):
@@ -391,14 +334,6 @@ class ReActStrategy(AgentPattern):

         return scratchpad, finish_reason

-    @staticmethod
-    def _format_output_text(value: Any) -> str:
-        if value is None:
-            return ""
-        if isinstance(value, str):
-            return value
-        return json.dumps(value, ensure_ascii=False)
-
     def _handle_tool_call(
         self,
         action: AgentScratchpadUnit.Action,
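The reverted final-answer path above normalizes the "Final Answer" action input before assigning final_text. A condensed, runnable sketch of that normalization (standalone function names are illustrative, not from the patch):

import json

def extract_final_text(action_input, thought: str | None) -> str:
    if action_input:
        # Dict inputs are stringified, mirroring the json.dumps(..., ensure_ascii=False) above.
        if isinstance(action_input, dict):
            return json.dumps(action_input, ensure_ascii=False)
        return str(action_input)
    # No usable action_input: fall back to the accumulated thought.
    return thought or ""

print(extract_final_text("Paris", None))                    # Paris
print(extract_final_text({"city": "Paris"}, None))          # {"city": "Paris"}
print(extract_final_text(None, "I think the answer is 4"))  # falls back to thought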
diff --git a/api/core/agent/patterns/strategy_factory.py b/api/core/agent/patterns/strategy_factory.py
index fbb550534d..2ec845c9b0 100644
--- a/api/core/agent/patterns/strategy_factory.py
+++ b/api/core/agent/patterns/strategy_factory.py
@@ -2,17 +2,13 @@

 from __future__ import annotations

-from collections.abc import Mapping
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING

 from core.agent.entities import AgentEntity, ExecutionContext
 from core.file.models import File
 from core.model_manager import ModelInstance
 from core.model_runtime.entities.model_entities import ModelFeature

-from ...app.entities.app_invoke_entities import InvokeFrom
-from ...tools.entities.tool_entities import ToolInvokeFrom
-from ..output_tools import build_agent_output_tools
 from .base import AgentPattern, ToolInvokeHook
 from .function_call import FunctionCallStrategy
 from .react import ReActStrategy
@@ -29,10 +25,6 @@ class StrategyFactory:

     @staticmethod
     def create_strategy(
-        *,
-        tenant_id: str,
-        invoke_from: InvokeFrom,
-        tool_invoke_from: ToolInvokeFrom,
         model_features: list[ModelFeature],
         model_instance: ModelInstance,
         context: ExecutionContext,
@@ -43,16 +35,11 @@ class StrategyFactory:
         agent_strategy: AgentEntity.Strategy | None = None,
         tool_invoke_hook: ToolInvokeHook | None = None,
         instruction: str = "",
-        structured_output_schema: Mapping[str, Any] | None = None,
     ) -> AgentPattern:
         """
         Create an appropriate strategy based on model features.

         Args:
-            tenant_id:
-            invoke_from:
-            tool_invoke_from:
-            structured_output_schema:
             model_features: List of model features/capabilities
             model_instance: Model instance to use
             context: Execution context containing trace/audit information
@@ -67,14 +54,6 @@ class StrategyFactory:
         Returns:
             AgentStrategy instance
         """
-        output_tools = build_agent_output_tools(
-            tenant_id=tenant_id,
-            invoke_from=invoke_from,
-            tool_invoke_from=tool_invoke_from,
-            structured_output_schema=structured_output_schema,
-        )
-
-        tools.extend(output_tools)

         # If explicit strategy is provided and it's Function Calling, try to use it if supported
         if agent_strategy == AgentEntity.Strategy.FUNCTION_CALLING:
""" # noqa: E501 diff --git a/api/core/llm_generator/llm_generator.py b/api/core/llm_generator/llm_generator.py index 5510aacf02..6cb4f450eb 100644 --- a/api/core/llm_generator/llm_generator.py +++ b/api/core/llm_generator/llm_generator.py @@ -477,6 +477,7 @@ class LLMGenerator: prompt_messages=complete_messages, output_model=CodeNodeStructuredOutput, model_parameters=model_parameters, + tenant_id=tenant_id, ) return { @@ -557,14 +558,10 @@ class LLMGenerator: completion_params = model_config.get("completion_params", {}) if model_config else {} try: - response = invoke_llm_with_pydantic_model( - provider=model_instance.provider, - model_schema=model_schema, - model_instance=model_instance, - prompt_messages=prompt_messages, - output_model=SuggestedQuestionsOutput, - model_parameters=completion_params, - ) + response = invoke_llm_with_pydantic_model(provider=model_instance.provider, model_schema=model_schema, + model_instance=model_instance, prompt_messages=prompt_messages, + output_model=SuggestedQuestionsOutput, + model_parameters=completion_params, tenant_id=tenant_id) return {"questions": response.questions, "error": ""} diff --git a/api/core/llm_generator/output_parser/file_ref.py b/api/core/llm_generator/output_parser/file_ref.py index 74e47570c6..83489e6a79 100644 --- a/api/core/llm_generator/output_parser/file_ref.py +++ b/api/core/llm_generator/output_parser/file_ref.py @@ -1,190 +1,188 @@ -from collections.abc import Callable, Mapping, Sequence -from typing import Any, cast +""" +File reference detection and conversion for structured output. + +This module provides utilities to: +1. Detect file reference fields in JSON Schema (format: "dify-file-ref") +2. Convert file ID strings to File objects after LLM returns +""" + +import uuid +from collections.abc import Mapping +from typing import Any from core.file import File from core.variables.segments import ArrayFileSegment, FileSegment +from factories.file_factory import build_from_mapping -FILE_PATH_SCHEMA_TYPE = "file" -FILE_PATH_SCHEMA_FORMATS = {"file", "file-ref", "dify-file-ref"} -FILE_PATH_DESCRIPTION_SUFFIX = "Sandbox file path (relative paths supported)." +FILE_REF_FORMAT = "dify-file-ref" -def is_file_path_property(schema: Mapping[str, Any]) -> bool: - if schema.get("type") == FILE_PATH_SCHEMA_TYPE: - return True - format_value = schema.get("format") - if not isinstance(format_value, str): - return False - normalized_format = format_value.lower().replace("_", "-") - return normalized_format in FILE_PATH_SCHEMA_FORMATS +def is_file_ref_property(schema: dict) -> bool: + """Check if a schema property is a file reference.""" + return schema.get("type") == "string" and schema.get("format") == FILE_REF_FORMAT -def detect_file_path_fields(schema: Mapping[str, Any], path: str = "") -> list[str]: - file_path_fields: list[str] = [] +def detect_file_ref_fields(schema: Mapping[str, Any], path: str = "") -> list[str]: + """ + Recursively detect file reference fields in schema. 
diff --git a/api/core/llm_generator/output_parser/file_ref.py b/api/core/llm_generator/output_parser/file_ref.py
index 74e47570c6..83489e6a79 100644
--- a/api/core/llm_generator/output_parser/file_ref.py
+++ b/api/core/llm_generator/output_parser/file_ref.py
@@ -1,190 +1,188 @@
-from collections.abc import Callable, Mapping, Sequence
-from typing import Any, cast
+"""
+File reference detection and conversion for structured output.
+
+This module provides utilities to:
+1. Detect file reference fields in JSON Schema (format: "dify-file-ref")
+2. Convert file ID strings to File objects after LLM returns
+"""
+
+import uuid
+from collections.abc import Mapping
+from typing import Any

 from core.file import File
 from core.variables.segments import ArrayFileSegment, FileSegment
+from factories.file_factory import build_from_mapping

-FILE_PATH_SCHEMA_TYPE = "file"
-FILE_PATH_SCHEMA_FORMATS = {"file", "file-ref", "dify-file-ref"}
-FILE_PATH_DESCRIPTION_SUFFIX = "Sandbox file path (relative paths supported)."
+FILE_REF_FORMAT = "dify-file-ref"

-def is_file_path_property(schema: Mapping[str, Any]) -> bool:
-    if schema.get("type") == FILE_PATH_SCHEMA_TYPE:
-        return True
-    format_value = schema.get("format")
-    if not isinstance(format_value, str):
-        return False
-    normalized_format = format_value.lower().replace("_", "-")
-    return normalized_format in FILE_PATH_SCHEMA_FORMATS
+def is_file_ref_property(schema: dict) -> bool:
+    """Check if a schema property is a file reference."""
+    return schema.get("type") == "string" and schema.get("format") == FILE_REF_FORMAT

-def detect_file_path_fields(schema: Mapping[str, Any], path: str = "") -> list[str]:
-    file_path_fields: list[str] = []
+def detect_file_ref_fields(schema: Mapping[str, Any], path: str = "") -> list[str]:
+    """
+    Recursively detect file reference fields in schema.
+
+    Args:
+        schema: JSON Schema to analyze
+        path: Current path in the schema (used for recursion)
+
+    Returns:
+        List of JSON paths containing file refs, e.g., ["image_id", "files[*]"]
+    """
+    file_ref_paths: list[str] = []
     schema_type = schema.get("type")

     if schema_type == "object":
-        properties = schema.get("properties")
-        if isinstance(properties, Mapping):
-            properties_mapping = cast(Mapping[str, Any], properties)
-            for prop_name, prop_schema in properties_mapping.items():
-                if not isinstance(prop_schema, Mapping):
-                    continue
-                prop_schema_mapping = cast(Mapping[str, Any], prop_schema)
-                current_path = f"{path}.{prop_name}" if path else prop_name
+        for prop_name, prop_schema in schema.get("properties", {}).items():
+            current_path = f"{path}.{prop_name}" if path else prop_name

-                if is_file_path_property(prop_schema_mapping):
-                    file_path_fields.append(current_path)
-                else:
-                    file_path_fields.extend(detect_file_path_fields(prop_schema_mapping, current_path))
+            if is_file_ref_property(prop_schema):
+                file_ref_paths.append(current_path)
+            elif isinstance(prop_schema, dict):
+                file_ref_paths.extend(detect_file_ref_fields(prop_schema, current_path))

     elif schema_type == "array":
-        items_schema = schema.get("items")
-        if not isinstance(items_schema, Mapping):
-            return file_path_fields
-        items_schema_mapping = cast(Mapping[str, Any], items_schema)
+        items_schema = schema.get("items", {})
         array_path = f"{path}[*]" if path else "[*]"

-        if is_file_path_property(items_schema_mapping):
-            file_path_fields.append(array_path)
-        else:
-            file_path_fields.extend(detect_file_path_fields(items_schema_mapping, array_path))
+        if is_file_ref_property(items_schema):
+            file_ref_paths.append(array_path)
+        elif isinstance(items_schema, dict):
+            file_ref_paths.extend(detect_file_ref_fields(items_schema, array_path))

-    return file_path_fields
+    return file_ref_paths

-def adapt_schema_for_sandbox_file_paths(schema: Mapping[str, Any]) -> tuple[dict[str, Any], list[str]]:
-    result = _deep_copy_value(schema)
-    if not isinstance(result, dict):
-        raise ValueError("structured_output_schema must be a JSON object")
-    result_dict = cast(dict[str, Any], result)
-
-    file_path_fields: list[str] = []
-    _adapt_schema_in_place(result_dict, path="", file_path_fields=file_path_fields)
-    return result_dict, file_path_fields
-
-
-def convert_sandbox_file_paths_in_output(
+def convert_file_refs_in_output(
     output: Mapping[str, Any],
-    file_path_fields: Sequence[str],
-    file_resolver: Callable[[str], File],
-) -> tuple[dict[str, Any], list[File]]:
-    if not file_path_fields:
-        return dict(output), []
+    json_schema: Mapping[str, Any],
+    tenant_id: str,
+) -> dict[str, Any]:
+    """
+    Convert file ID strings to File objects based on schema.
+
+    Args:
+        output: The structured_output from LLM result
+        json_schema: The original JSON schema (to detect file ref fields)
+        tenant_id: Tenant ID for file lookup
+
+    Returns:
+        Output with file references converted to File objects
+    """
+    file_ref_paths = detect_file_ref_fields(json_schema)
+    if not file_ref_paths:
+        return dict(output)

-    result = _deep_copy_value(output)
-    if not isinstance(result, dict):
-        raise ValueError("Structured output must be a JSON object")
-    result_dict = cast(dict[str, Any], result)
+    result = _deep_copy_dict(output)

-    files: list[File] = []
-    for path in file_path_fields:
-        _convert_path_in_place(result_dict, path.split("."), file_resolver, files)
+    for path in file_ref_paths:
+        _convert_path_in_place(result, path.split("."), tenant_id)

-    return result_dict, files
+    return result
-def _adapt_schema_in_place(schema: dict[str, Any], path: str, file_path_fields: list[str]) -> None:
-    schema_type = schema.get("type")
-
-    if schema_type == "object":
-        properties = schema.get("properties")
-        if isinstance(properties, Mapping):
-            properties_mapping = cast(Mapping[str, Any], properties)
-            for prop_name, prop_schema in properties_mapping.items():
-                if not isinstance(prop_schema, dict):
-                    continue
-                prop_schema_dict = cast(dict[str, Any], prop_schema)
-                current_path = f"{path}.{prop_name}" if path else prop_name
-
-                if is_file_path_property(prop_schema_dict):
-                    _normalize_file_path_schema(prop_schema_dict)
-                    file_path_fields.append(current_path)
-                else:
-                    _adapt_schema_in_place(prop_schema_dict, current_path, file_path_fields)
-
-    elif schema_type == "array":
-        items_schema = schema.get("items")
-        if not isinstance(items_schema, dict):
-            return
-        items_schema_dict = cast(dict[str, Any], items_schema)
-        array_path = f"{path}[*]" if path else "[*]"
-
-        if is_file_path_property(items_schema_dict):
-            _normalize_file_path_schema(items_schema_dict)
-            file_path_fields.append(array_path)
+def _deep_copy_dict(obj: Mapping[str, Any]) -> dict[str, Any]:
+    """Deep copy a mapping to a mutable dict."""
+    result: dict[str, Any] = {}
+    for key, value in obj.items():
+        if isinstance(value, Mapping):
+            result[key] = _deep_copy_dict(value)
+        elif isinstance(value, list):
+            result[key] = [_deep_copy_dict(item) if isinstance(item, Mapping) else item for item in value]
         else:
-            _adapt_schema_in_place(items_schema_dict, array_path, file_path_fields)
+            result[key] = value
+    return result

-def _normalize_file_path_schema(schema: dict[str, Any]) -> None:
-    schema["type"] = "string"
-    schema.pop("format", None)
-    description = schema.get("description", "")
-    if description:
-        schema["description"] = f"{description}\n{FILE_PATH_DESCRIPTION_SUFFIX}"
-    else:
-        schema["description"] = FILE_PATH_DESCRIPTION_SUFFIX
-
-
-def _deep_copy_value(value: Any) -> Any:
-    if isinstance(value, Mapping):
-        mapping = cast(Mapping[str, Any], value)
-        return {key: _deep_copy_value(item) for key, item in mapping.items()}
-    if isinstance(value, list):
-        list_value = cast(list[Any], value)
-        return [_deep_copy_value(item) for item in list_value]
-    return value
-
-
-def _convert_path_in_place(
-    obj: dict[str, Any],
-    path_parts: list[str],
-    file_resolver: Callable[[str], File],
-    files: list[File],
-) -> None:
+def _convert_path_in_place(obj: dict, path_parts: list[str], tenant_id: str) -> None:
+    """Convert file refs at the given path in place, wrapping in Segment types."""
     if not path_parts:
         return

     current = path_parts[0]
     remaining = path_parts[1:]

+    # Handle array notation like "files[*]"
     if current.endswith("[*]"):
-        key = current[:-3] if current != "[*]" else ""
-        target_value = obj.get(key) if key else obj
+        key = current[:-3] if current != "[*]" else None
+        target = obj.get(key) if key else obj

-        if isinstance(target_value, list):
-            target_list = cast(list[Any], target_value)
+        if isinstance(target, list):
             if remaining:
-                for item in target_list:
+                # Nested array with remaining path - recurse into each item
+                for item in target:
                     if isinstance(item, dict):
-                        item_dict = cast(dict[str, Any], item)
-                        _convert_path_in_place(item_dict, remaining, file_resolver, files)
+                        _convert_path_in_place(item, remaining, tenant_id)
             else:
-                resolved_files: list[File] = []
-                for item in target_list:
-                    if not isinstance(item, str):
-                        raise ValueError("File path must be a string")
-                    file = file_resolver(item)
-                    files.append(file)
-                    resolved_files.append(file)
+                # Array of file IDs - convert all and wrap in ArrayFileSegment
+                files: list[File] = []
+                for item in target:
+                    file = _convert_file_id(item, tenant_id)
+                    if file is not None:
+                        files.append(file)
+                # Replace the array with ArrayFileSegment
                 if key:
-                    obj[key] = ArrayFileSegment(value=resolved_files)
+                    obj[key] = ArrayFileSegment(value=files)
         return

     if not remaining:
-        if current not in obj:
-            return
-        value = obj[current]
-        if value is None:
-            obj[current] = None
-            return
-        if not isinstance(value, str):
-            raise ValueError("File path must be a string")
-        file = file_resolver(value)
-        files.append(file)
-        obj[current] = FileSegment(value=file)
-        return
+        # Leaf node - convert the value and wrap in FileSegment
+        if current in obj:
+            file = _convert_file_id(obj[current], tenant_id)
+            if file is not None:
+                obj[current] = FileSegment(value=file)
+            else:
+                obj[current] = None
+    else:
+        # Recurse into nested object
+        if current in obj and isinstance(obj[current], dict):
+            _convert_path_in_place(obj[current], remaining, tenant_id)

-    if current in obj and isinstance(obj[current], dict):
-        _convert_path_in_place(obj[current], remaining, file_resolver, files)

+def _convert_file_id(file_id: Any, tenant_id: str) -> File | None:
+    """
+    Convert a file ID string to a File object.
+
+    Tries multiple file sources in order:
+    1. ToolFile (files generated by tools/workflows)
+    2. UploadFile (files uploaded by users)
+    """
+    if not isinstance(file_id, str):
+        return None
+
+    # Validate UUID format
+    try:
+        uuid.UUID(file_id)
+    except ValueError:
+        return None
+
+    # Try ToolFile first (files generated by tools/workflows)
+    try:
+        return build_from_mapping(
+            mapping={
+                "transfer_method": "tool_file",
+                "tool_file_id": file_id,
+            },
+            tenant_id=tenant_id,
+        )
+    except ValueError:
+        pass
+
+    # Try UploadFile (files uploaded by users)
+    try:
+        return build_from_mapping(
+            mapping={
+                "transfer_method": "local_file",
+                "upload_file_id": file_id,
+            },
+            tenant_id=tenant_id,
+        )
+    except ValueError:
+        pass
+
+    # File not found in any source
+    return None
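To make the path notation concrete, a small sketch against the new helpers (the schema is invented; inside the api codebase this runs as-is, and convert_file_refs_in_output would then wrap cover_image in a FileSegment and attachments in an ArrayFileSegment, assuming the IDs resolve via build_from_mapping):

from core.llm_generator.output_parser.file_ref import detect_file_ref_fields

schema = {
    "type": "object",
    "properties": {
        "report": {"type": "string"},
        "cover_image": {"type": "string", "format": "dify-file-ref"},
        "attachments": {
            "type": "array",
            "items": {"type": "string", "format": "dify-file-ref"},
        },
    },
}
print(detect_file_ref_fields(schema))  # ['cover_image', 'attachments[*]']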
diff --git a/api/core/llm_generator/output_parser/structured_output.py b/api/core/llm_generator/output_parser/structured_output.py
index c790feda53..a069f0409c 100644
--- a/api/core/llm_generator/output_parser/structured_output.py
+++ b/api/core/llm_generator/output_parser/structured_output.py
@@ -8,7 +8,7 @@ import json_repair
 from pydantic import BaseModel, TypeAdapter, ValidationError

 from core.llm_generator.output_parser.errors import OutputParserError
-from core.llm_generator.output_parser.file_ref import detect_file_path_fields
+from core.llm_generator.output_parser.file_ref import convert_file_refs_in_output
 from core.llm_generator.prompts import STRUCTURED_OUTPUT_PROMPT
 from core.model_manager import ModelInstance
 from core.model_runtime.callbacks.base_callback import Callback
@@ -55,11 +55,12 @@ def invoke_llm_with_structured_output(
     model_instance: ModelInstance,
     prompt_messages: Sequence[PromptMessage],
     json_schema: Mapping[str, Any],
-    model_parameters: Mapping[str, Any] | None = None,
+    model_parameters: Mapping | None = None,
     tools: Sequence[PromptMessageTool] | None = None,
     stop: list[str] | None = None,
     user: str | None = None,
     callbacks: list[Callback] | None = None,
+    tenant_id: str | None = None,
 ) -> LLMResultWithStructuredOutput:
     """
     Invoke large language model with structured output.
@@ -77,12 +78,14 @@ def invoke_llm_with_structured_output(
     :param stop: stop words
     :param user: unique user id
     :param callbacks: callbacks
-    :return: response with structured output
+    :param tenant_id: tenant ID for file reference conversion. When provided and
+        json_schema contains file reference fields (format: "dify-file-ref"),
+        file IDs in the output will be automatically converted to File objects.
+    :return: full response or stream response chunk generator result
     """
-    model_parameters_with_json_schema: dict[str, Any] = dict(model_parameters or {})
-
-    if detect_file_path_fields(json_schema):
-        raise OutputParserError("Structured output file paths are only supported in sandbox mode.")
+    model_parameters_with_json_schema: dict[str, Any] = {
+        **(model_parameters or {}),
+    }

     # Determine structured output strategy
@@ -119,6 +122,14 @@ def invoke_llm_with_structured_output(
         # Fill missing fields with default values
         structured_output = fill_defaults_from_schema(structured_output, json_schema)

+        # Convert file references if tenant_id is provided
+        if tenant_id is not None:
+            structured_output = convert_file_refs_in_output(
+                output=structured_output,
+                json_schema=json_schema,
+                tenant_id=tenant_id,
+            )
+
         return LLMResultWithStructuredOutput(
             structured_output=structured_output,
             model=llm_result.model,
@@ -136,11 +147,12 @@ def invoke_llm_with_pydantic_model(
     model_instance: ModelInstance,
     prompt_messages: Sequence[PromptMessage],
     output_model: type[T],
-    model_parameters: Mapping[str, Any] | None = None,
+    model_parameters: Mapping | None = None,
     tools: Sequence[PromptMessageTool] | None = None,
     stop: list[str] | None = None,
     user: str | None = None,
     callbacks: list[Callback] | None = None,
+    tenant_id: str | None = None,
 ) -> T:
     """
     Invoke large language model with a Pydantic output model.
@@ -148,8 +160,11 @@ def invoke_llm_with_pydantic_model(
     This helper generates a JSON schema from the Pydantic model, invokes the
     structured-output LLM path, and validates the result.

-    The helper performs a non-streaming invocation and returns the validated
-    Pydantic model directly.
+    The stream parameter controls the underlying LLM invocation mode:
+    - stream=True (default): Uses streaming LLM call, consumes the generator internally
+    - stream=False: Uses non-streaming LLM call
+
+    In both cases, the function returns the validated Pydantic model directly.
""" json_schema = _schema_from_pydantic(output_model) @@ -164,6 +179,7 @@ def invoke_llm_with_pydantic_model( stop=stop, user=user, callbacks=callbacks, + tenant_id=tenant_id, ) structured_output = result.structured_output @@ -220,27 +236,25 @@ def _extract_structured_output(llm_result: LLMResult) -> Mapping[str, Any]: return _parse_structured_output(content) -def _parse_tool_call_arguments(arguments: str) -> dict[str, Any]: +def _parse_tool_call_arguments(arguments: str) -> Mapping[str, Any]: """Parse JSON from tool call arguments.""" if not arguments: raise OutputParserError("Tool call arguments is empty") try: - parsed_any = json.loads(arguments) - if not isinstance(parsed_any, dict): + parsed = json.loads(arguments) + if not isinstance(parsed, dict): raise OutputParserError(f"Tool call arguments is not a dict: {arguments}") - parsed = cast(dict[str, Any], parsed_any) return parsed except json.JSONDecodeError: # Try to repair malformed JSON - repaired_any = json_repair.loads(arguments) - if not isinstance(repaired_any, dict): + repaired = json_repair.loads(arguments) + if not isinstance(repaired, dict): raise OutputParserError(f"Failed to parse tool call arguments: {arguments}") - repaired: dict[str, Any] = repaired_any return repaired -def get_default_value_for_type(type_name: str | list[str] | None) -> Any: +def _get_default_value_for_type(type_name: str | list[str] | None) -> Any: """Get default empty value for a JSON schema type.""" # Handle array of types (e.g., ["string", "null"]) if isinstance(type_name, list): @@ -297,7 +311,7 @@ def fill_defaults_from_schema( # Create empty object and recursively fill its required fields result[prop_name] = fill_defaults_from_schema({}, prop_schema) else: - result[prop_name] = get_default_value_for_type(prop_type) + result[prop_name] = _get_default_value_for_type(prop_type) elif isinstance(result[prop_name], dict) and prop_type == "object" and "properties" in prop_schema: # Field exists and is an object, recursively fill nested required fields result[prop_name] = fill_defaults_from_schema(result[prop_name], prop_schema) @@ -308,10 +322,10 @@ def fill_defaults_from_schema( def _handle_native_json_schema( provider: str, model_schema: AIModelEntity, - structured_output_schema: Mapping[str, Any], - model_parameters: dict[str, Any], + structured_output_schema: Mapping, + model_parameters: dict, rules: list[ParameterRule], -) -> dict[str, Any]: +): """ Handle structured output for models with native JSON schema support. @@ -333,7 +347,7 @@ def _handle_native_json_schema( return model_parameters -def _set_response_format(model_parameters: dict[str, Any], rules: list[ParameterRule]) -> None: +def _set_response_format(model_parameters: dict, rules: list): """ Set the appropriate response format parameter based on model rules. @@ -349,7 +363,7 @@ def _set_response_format(model_parameters: dict[str, Any], rules: list[Parameter def _handle_prompt_based_schema( - prompt_messages: Sequence[PromptMessage], structured_output_schema: Mapping[str, Any] + prompt_messages: Sequence[PromptMessage], structured_output_schema: Mapping ) -> list[PromptMessage]: """ Handle structured output for models without native JSON schema support. 
@@ -386,27 +400,28 @@
     return updated_prompt
 
 
-def _parse_structured_output(result_text: str) -> dict[str, Any]:
+def _parse_structured_output(result_text: str) -> Mapping[str, Any]:
+    structured_output: Mapping[str, Any] = {}
+    parsed: Mapping[str, Any] = {}
     try:
-        parsed = TypeAdapter(dict[str, Any]).validate_json(result_text)
-        return parsed
+        parsed = TypeAdapter(Mapping).validate_json(result_text)
+        if not isinstance(parsed, dict):
+            raise OutputParserError(f"Failed to parse structured output: {result_text}")
+        structured_output = parsed
     except ValidationError:
         # if the result_text is not a valid json, try to repair it
-        temp_parsed: Any = json_repair.loads(result_text)
-        if isinstance(temp_parsed, list):
-            temp_parsed_list = cast(list[Any], temp_parsed)
-            dict_items: list[dict[str, Any]] = []
-            for item in temp_parsed_list:
-                if isinstance(item, dict):
-                    dict_items.append(cast(dict[str, Any], item))
-            temp_parsed = dict_items[0] if dict_items else {}
+        temp_parsed = json_repair.loads(result_text)
         if not isinstance(temp_parsed, dict):
-            raise OutputParserError(f"Failed to parse structured output: {result_text}")
-        temp_parsed_dict = cast(dict[str, Any], temp_parsed)
-        return temp_parsed_dict
+            # Handle reasoning models like deepseek-r1, which may emit a '\n\n\n' prefix
+            if isinstance(temp_parsed, list):
+                temp_parsed = next((item for item in temp_parsed if isinstance(item, dict)), {})
+            else:
+                raise OutputParserError(f"Failed to parse structured output: {result_text}")
+        structured_output = cast(dict, temp_parsed)
+    return structured_output
 
 
-def _prepare_schema_for_model(provider: str, model_schema: AIModelEntity, schema: Mapping[str, Any]) -> dict[str, Any]:
+def _prepare_schema_for_model(provider: str, model_schema: AIModelEntity, schema: Mapping):
     """
     Prepare JSON schema based on model requirements.
 
@@ -418,49 +433,54 @@ def _prepare_schema_for_model(provider: str, model_schema: AIModelEntity, schema
     """
     # Deep copy to avoid modifying the original schema
-    processed_schema = deepcopy(schema)
-    processed_schema_dict = dict(processed_schema)
+    processed_schema = dict(deepcopy(schema))
 
     # Convert boolean types to string types (common requirement)
-    convert_boolean_to_string(processed_schema_dict)
+    convert_boolean_to_string(processed_schema)
 
     # Apply model-specific transformations
     if SpecialModelType.GEMINI in model_schema.model:
-        remove_additional_properties(processed_schema_dict)
-        return processed_schema_dict
-    if SpecialModelType.OLLAMA in provider:
-        return processed_schema_dict
-
-    # Default format with name field
-    return {"schema": processed_schema_dict, "name": "llm_response"}
+        remove_additional_properties(processed_schema)
+        return processed_schema
+    elif SpecialModelType.OLLAMA in provider:
+        return processed_schema
+    else:
+        # Default format with name field
+        return {"schema": processed_schema, "name": "llm_response"}
 
 
-def remove_additional_properties(schema: dict[str, Any]) -> None:
+def remove_additional_properties(schema: dict):
     """
     Remove additionalProperties fields from JSON schema.
     Used for models like Gemini that don't support this property.
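+
+    A hedged sketch of the in-place effect (schema contents illustrative):
+
+        schema = {
+            "type": "object",
+            "additionalProperties": False,
+            "properties": {
+                "meta": {"type": "object", "additionalProperties": False},
+            },
+        }
+        remove_additional_properties(schema)
+        # Every "additionalProperties" key is now removed, at all levels.
+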
:param schema: JSON schema to modify in-place """ + if not isinstance(schema, dict): + return + # Remove additionalProperties at current level schema.pop("additionalProperties", None) # Process nested structures recursively for value in schema.values(): if isinstance(value, dict): - remove_additional_properties(cast(dict[str, Any], value)) + remove_additional_properties(value) elif isinstance(value, list): - for item in cast(list[Any], value): + for item in value: if isinstance(item, dict): - remove_additional_properties(cast(dict[str, Any], item)) + remove_additional_properties(item) -def convert_boolean_to_string(schema: dict[str, Any]) -> None: +def convert_boolean_to_string(schema: dict): """ Convert boolean type specifications to string in JSON schema. :param schema: JSON schema to modify in-place """ + if not isinstance(schema, dict): + return + # Check for boolean type at current level if schema.get("type") == "boolean": schema["type"] = "string" @@ -468,8 +488,8 @@ def convert_boolean_to_string(schema: dict[str, Any]) -> None: # Process nested dictionaries and lists recursively for value in schema.values(): if isinstance(value, dict): - convert_boolean_to_string(cast(dict[str, Any], value)) + convert_boolean_to_string(value) elif isinstance(value, list): - for item in cast(list[Any], value): + for item in value: if isinstance(item, dict): - convert_boolean_to_string(cast(dict[str, Any], item)) + convert_boolean_to_string(item) diff --git a/api/core/plugin/utils/chunk_merger.py b/api/core/plugin/utils/chunk_merger.py index 6a266d1d9d..28cb70f96a 100644 --- a/api/core/plugin/utils/chunk_merger.py +++ b/api/core/plugin/utils/chunk_merger.py @@ -1,13 +1,11 @@ from collections.abc import Generator from dataclasses import dataclass, field -from typing import TYPE_CHECKING, TypeVar +from typing import TypeVar, Union +from core.agent.entities import AgentInvokeMessage from core.tools.entities.tool_entities import ToolInvokeMessage -if TYPE_CHECKING: - pass - -MessageType = TypeVar("MessageType", bound=ToolInvokeMessage) +MessageType = TypeVar("MessageType", bound=Union[ToolInvokeMessage, AgentInvokeMessage]) @dataclass @@ -89,7 +87,7 @@ def merge_blob_chunks( ), meta=resp.meta, ) - assert isinstance(merged_message, ToolInvokeMessage) + assert isinstance(merged_message, (ToolInvokeMessage, AgentInvokeMessage)) yield merged_message # type: ignore # Clean up the buffer del files[chunk_id] diff --git a/api/core/sandbox/bash/session.py b/api/core/sandbox/bash/session.py index 60691790a2..57e790b2c2 100644 --- a/api/core/sandbox/bash/session.py +++ b/api/core/sandbox/bash/session.py @@ -14,8 +14,7 @@ from core.skill.entities import ToolAccessPolicy from core.skill.entities.tool_dependencies import ToolDependencies from core.tools.signature import sign_tool_file from core.tools.tool_file_manager import ToolFileManager -from core.virtual_environment.__base.exec import CommandExecutionError -from core.virtual_environment.__base.helpers import execute, pipeline +from core.virtual_environment.__base.helpers import pipeline from ..bash.dify_cli import DifyCliConfig from ..entities import DifyCli @@ -120,6 +119,21 @@ class SandboxBashSession: return self._bash_tool def collect_output_files(self, output_dir: str = SANDBOX_OUTPUT_DIR) -> list[File]: + """ + Collect files from sandbox output directory and save them as ToolFiles. + + Scans the specified output directory in sandbox, downloads each file, + saves it as a ToolFile, and returns a list of File objects. 
The File + objects will have valid tool_file_id that can be referenced by subsequent + nodes via structured output. + + Args: + output_dir: Directory path in sandbox to scan for output files. + Defaults to "output" (relative to workspace). + + Returns: + List of File objects representing the collected files. + """ vm = self._sandbox.vm collected_files: list[File] = [] @@ -130,6 +144,8 @@ class SandboxBashSession: logger.debug("Failed to list sandbox output files in %s: %s", output_dir, exc) return collected_files + tool_file_manager = ToolFileManager() + for file_state in file_states: # Skip files that are too large if file_state.size > MAX_OUTPUT_FILE_SIZE: @@ -146,14 +162,47 @@ class SandboxBashSession: file_content = vm.download_file(file_state.path) file_binary = file_content.getvalue() + # Determine mime type from extension filename = os.path.basename(file_state.path) - file_obj = self._create_tool_file(filename=filename, file_binary=file_binary) + mime_type, _ = mimetypes.guess_type(filename) + if not mime_type: + mime_type = "application/octet-stream" + + # Save as ToolFile + tool_file = tool_file_manager.create_file_by_raw( + user_id=self._user_id, + tenant_id=self._tenant_id, + conversation_id=None, + file_binary=file_binary, + mimetype=mime_type, + filename=filename, + ) + + # Determine file type from mime type + file_type = _get_file_type_from_mime(mime_type) + extension = os.path.splitext(filename)[1] if "." in filename else ".bin" + url = sign_tool_file(tool_file.id, extension) + + # Create File object with tool_file_id as related_id + file_obj = File( + id=tool_file.id, # Use tool_file_id as the File id for easy reference + tenant_id=self._tenant_id, + type=file_type, + transfer_method=FileTransferMethod.TOOL_FILE, + filename=filename, + extension=extension, + mime_type=mime_type, + size=len(file_binary), + related_id=tool_file.id, + url=url, + storage_key=tool_file.file_key, + ) collected_files.append(file_obj) logger.info( "Collected sandbox output file: %s -> tool_file_id=%s", file_state.path, - file_obj.id, + tool_file.id, ) except Exception as exc: @@ -167,85 +216,6 @@ class SandboxBashSession: ) return collected_files - def download_file(self, path: str) -> File: - path_kind = self._detect_path_kind(path) - if path_kind == "dir": - raise ValueError("Directory outputs are not supported") - if path_kind != "file": - raise ValueError(f"Sandbox file not found: {path}") - - file_content = self._sandbox.vm.download_file(path) - file_binary = file_content.getvalue() - if len(file_binary) > MAX_OUTPUT_FILE_SIZE: - raise ValueError(f"Sandbox file exceeds size limit: {path}") - - filename = os.path.basename(path) or "file" - return self._create_tool_file(filename=filename, file_binary=file_binary) - - def _detect_path_kind(self, path: str) -> str: - script = r""" -import os -import sys - -p = sys.argv[1] -if os.path.isdir(p): - print("dir") - raise SystemExit(0) -if os.path.isfile(p): - print("file") - raise SystemExit(0) -print("none") -raise SystemExit(2) -""" - try: - result = execute( - self._sandbox.vm, - [ - "sh", - "-c", - 'if command -v python3 >/dev/null 2>&1; then py=python3; else py=python; fi; "$py" -c "$0" "$@"', - script, - path, - ], - timeout=10, - error_message="Failed to inspect sandbox path", - ) - except CommandExecutionError as exc: - raise ValueError(str(exc)) from exc - return result.stdout.decode("utf-8", errors="replace").strip() - - def _create_tool_file(self, *, filename: str, file_binary: bytes) -> File: - mime_type, _ = mimetypes.guess_type(filename) 
- if not mime_type: - mime_type = "application/octet-stream" - - tool_file = ToolFileManager().create_file_by_raw( - user_id=self._user_id, - tenant_id=self._tenant_id, - conversation_id=None, - file_binary=file_binary, - mimetype=mime_type, - filename=filename, - ) - - file_type = _get_file_type_from_mime(mime_type) - extension = os.path.splitext(filename)[1] if "." in filename else ".bin" - url = sign_tool_file(tool_file.id, extension) - - return File( - id=tool_file.id, - tenant_id=self._tenant_id, - type=file_type, - transfer_method=FileTransferMethod.TOOL_FILE, - filename=filename, - extension=extension, - mime_type=mime_type, - size=len(file_binary), - related_id=tool_file.id, - url=url, - storage_key=tool_file.file_key, - ) - def _get_file_type_from_mime(mime_type: str) -> FileType: """Determine FileType from mime type.""" diff --git a/api/core/tools/__base/tool.py b/api/core/tools/__base/tool.py index 9c0e8d0963..24fc11aefc 100644 --- a/api/core/tools/__base/tool.py +++ b/api/core/tools/__base/tool.py @@ -57,7 +57,7 @@ class Tool(ABC): tool_parameters.update(self.runtime.runtime_parameters) # try parse tool parameters into the correct type - tool_parameters = self.transform_tool_parameters_type(tool_parameters) + tool_parameters = self._transform_tool_parameters_type(tool_parameters) result = self._invoke( user_id=user_id, @@ -82,7 +82,7 @@ class Tool(ABC): else: return result - def transform_tool_parameters_type(self, tool_parameters: dict[str, Any]) -> dict[str, Any]: + def _transform_tool_parameters_type(self, tool_parameters: dict[str, Any]) -> dict[str, Any]: """ Transform tool parameters type """ diff --git a/api/core/tools/builtin_tool/providers/agent_output/_assets/icon.svg b/api/core/tools/builtin_tool/providers/agent_output/_assets/icon.svg deleted file mode 100644 index 82e3b54edd..0000000000 --- a/api/core/tools/builtin_tool/providers/agent_output/_assets/icon.svg +++ /dev/null @@ -1,5 +0,0 @@ - - - - - diff --git a/api/core/tools/builtin_tool/providers/agent_output/agent_output.py b/api/core/tools/builtin_tool/providers/agent_output/agent_output.py deleted file mode 100644 index 65c3eb5a1d..0000000000 --- a/api/core/tools/builtin_tool/providers/agent_output/agent_output.py +++ /dev/null @@ -1,8 +0,0 @@ -from typing import Any - -from core.tools.builtin_tool.provider import BuiltinToolProviderController - - -class AgentOutputProvider(BuiltinToolProviderController): - def _validate_credentials(self, user_id: str, credentials: dict[str, Any]): - pass diff --git a/api/core/tools/builtin_tool/providers/agent_output/agent_output.yaml b/api/core/tools/builtin_tool/providers/agent_output/agent_output.yaml deleted file mode 100644 index 1e4f624104..0000000000 --- a/api/core/tools/builtin_tool/providers/agent_output/agent_output.yaml +++ /dev/null @@ -1,10 +0,0 @@ -identity: - author: Dify - name: agent_output - label: - en_US: Agent Output - description: - en_US: Internal tools for agent output control. 
- icon: icon.svg - tags: - - utilities diff --git a/api/core/tools/builtin_tool/providers/agent_output/tools/final_output_answer.py b/api/core/tools/builtin_tool/providers/agent_output/tools/final_output_answer.py deleted file mode 100644 index ed5edc9fd3..0000000000 --- a/api/core/tools/builtin_tool/providers/agent_output/tools/final_output_answer.py +++ /dev/null @@ -1,17 +0,0 @@ -from collections.abc import Generator -from typing import Any - -from core.tools.builtin_tool.tool import BuiltinTool -from core.tools.entities.tool_entities import ToolInvokeMessage - - -class FinalOutputAnswerTool(BuiltinTool): - def _invoke( - self, - user_id: str, - tool_parameters: dict[str, Any], - conversation_id: str | None = None, - app_id: str | None = None, - message_id: str | None = None, - ) -> Generator[ToolInvokeMessage, None, None]: - yield self.create_text_message("Final answer recorded.") diff --git a/api/core/tools/builtin_tool/providers/agent_output/tools/final_output_answer.yaml b/api/core/tools/builtin_tool/providers/agent_output/tools/final_output_answer.yaml deleted file mode 100644 index b8e4f14d85..0000000000 --- a/api/core/tools/builtin_tool/providers/agent_output/tools/final_output_answer.yaml +++ /dev/null @@ -1,18 +0,0 @@ -identity: - name: final_output_answer - author: Dify - label: - en_US: Final Output Answer -description: - human: - en_US: Internal tool to deliver the final answer. - llm: Use this tool when you are ready to provide the final answer. -parameters: - - name: text - type: string - required: true - label: - en_US: Text - human_description: - en_US: Final answer text. - form: llm diff --git a/api/core/tools/builtin_tool/providers/agent_output/tools/final_structured_output.py b/api/core/tools/builtin_tool/providers/agent_output/tools/final_structured_output.py deleted file mode 100644 index 916dc9f0bc..0000000000 --- a/api/core/tools/builtin_tool/providers/agent_output/tools/final_structured_output.py +++ /dev/null @@ -1,17 +0,0 @@ -from collections.abc import Generator -from typing import Any - -from core.tools.builtin_tool.tool import BuiltinTool -from core.tools.entities.tool_entities import ToolInvokeMessage - - -class FinalStructuredOutputTool(BuiltinTool): - def _invoke( - self, - user_id: str, - tool_parameters: dict[str, Any], - conversation_id: str | None = None, - app_id: str | None = None, - message_id: str | None = None, - ) -> Generator[ToolInvokeMessage, None, None]: - yield self.create_text_message("Structured output recorded.") diff --git a/api/core/tools/builtin_tool/providers/agent_output/tools/final_structured_output.yaml b/api/core/tools/builtin_tool/providers/agent_output/tools/final_structured_output.yaml deleted file mode 100644 index 79eb2d6412..0000000000 --- a/api/core/tools/builtin_tool/providers/agent_output/tools/final_structured_output.yaml +++ /dev/null @@ -1,18 +0,0 @@ -identity: - name: final_structured_output - author: Dify - label: - en_US: Final Structured Output -description: - human: - en_US: Internal tool to deliver structured output. - llm: Use this tool to provide structured output data. -parameters: - - name: data - type: object - required: true - label: - en_US: Data - human_description: - en_US: Structured output data. 
- form: llm diff --git a/api/core/tools/builtin_tool/providers/agent_output/tools/illegal_output.py b/api/core/tools/builtin_tool/providers/agent_output/tools/illegal_output.py deleted file mode 100644 index 2276c527e9..0000000000 --- a/api/core/tools/builtin_tool/providers/agent_output/tools/illegal_output.py +++ /dev/null @@ -1,21 +0,0 @@ -from collections.abc import Generator -from typing import Any - -from core.tools.builtin_tool.tool import BuiltinTool -from core.tools.entities.tool_entities import ToolInvokeMessage - - -class IllegalOutputTool(BuiltinTool): - def _invoke( - self, - user_id: str, - tool_parameters: dict[str, Any], - conversation_id: str | None = None, - app_id: str | None = None, - message_id: str | None = None, - ) -> Generator[ToolInvokeMessage, None, None]: - message = ( - "Protocol violation: do not output plain text. " - "Call an output tool and finish with the configured terminal tool." - ) - yield self.create_text_message(message) diff --git a/api/core/tools/builtin_tool/providers/agent_output/tools/illegal_output.yaml b/api/core/tools/builtin_tool/providers/agent_output/tools/illegal_output.yaml deleted file mode 100644 index b293a41659..0000000000 --- a/api/core/tools/builtin_tool/providers/agent_output/tools/illegal_output.yaml +++ /dev/null @@ -1,18 +0,0 @@ -identity: - name: illegal_output - author: Dify - label: - en_US: Illegal Output -description: - human: - en_US: Internal tool for output protocol violations. - llm: Use this tool to correct output protocol violations. -parameters: - - name: raw - type: string - required: true - label: - en_US: Raw Output - human_description: - en_US: Raw model output that violated the protocol. - form: llm diff --git a/api/core/tools/builtin_tool/providers/agent_output/tools/output_text.py b/api/core/tools/builtin_tool/providers/agent_output/tools/output_text.py deleted file mode 100644 index 942846f507..0000000000 --- a/api/core/tools/builtin_tool/providers/agent_output/tools/output_text.py +++ /dev/null @@ -1,17 +0,0 @@ -from collections.abc import Generator -from typing import Any - -from core.tools.builtin_tool.tool import BuiltinTool -from core.tools.entities.tool_entities import ToolInvokeMessage - - -class OutputTextTool(BuiltinTool): - def _invoke( - self, - user_id: str, - tool_parameters: dict[str, Any], - conversation_id: str | None = None, - app_id: str | None = None, - message_id: str | None = None, - ) -> Generator[ToolInvokeMessage, None, None]: - yield self.create_text_message("Output recorded.") diff --git a/api/core/tools/builtin_tool/providers/agent_output/tools/output_text.yaml b/api/core/tools/builtin_tool/providers/agent_output/tools/output_text.yaml deleted file mode 100644 index 48d237e521..0000000000 --- a/api/core/tools/builtin_tool/providers/agent_output/tools/output_text.yaml +++ /dev/null @@ -1,18 +0,0 @@ -identity: - name: output_text - author: Dify - label: - en_US: Output Text -description: - human: - en_US: Internal tool to store intermediate text output. - llm: Use this tool to emit non-final text output. -parameters: - - name: text - type: string - required: true - label: - en_US: Text - human_description: - en_US: Output text. 
- form: llm diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index b524de34e4..5dae773841 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -1,5 +1,3 @@ -from __future__ import annotations - import json import logging import mimetypes @@ -33,8 +31,9 @@ from services.enterprise.plugin_manager_service import PluginCredentialType from services.tools.mcp_tools_manage_service import MCPToolManageService if TYPE_CHECKING: - from core.agent.entities import AgentToolEntity from core.workflow.nodes.tool.entities import ToolEntity + +from core.agent.entities import AgentToolEntity from core.app.entities.app_invoke_entities import InvokeFrom from core.helper.module_import_helper import load_single_subclass_from_source from core.helper.position_helper import is_filtered @@ -67,8 +66,6 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -INTERNAL_BUILTIN_TOOL_PROVIDERS = {"agent_output"} - class ApiProviderControllerItem(TypedDict): provider: ApiToolProvider @@ -364,7 +361,7 @@ class ToolManager: app_id: str, agent_tool: AgentToolEntity, invoke_from: InvokeFrom = InvokeFrom.DEBUGGER, - variable_pool: Optional[VariablePool] = None, + variable_pool: Optional["VariablePool"] = None, ) -> Tool: """ get the agent tool runtime @@ -404,9 +401,9 @@ class ToolManager: tenant_id: str, app_id: str, node_id: str, - workflow_tool: ToolEntity, + workflow_tool: "ToolEntity", invoke_from: InvokeFrom = InvokeFrom.DEBUGGER, - variable_pool: Optional[VariablePool] = None, + variable_pool: Optional["VariablePool"] = None, ) -> Tool: """ get the workflow tool runtime @@ -594,10 +591,6 @@ class ToolManager: cls._hardcoded_providers = {} cls._builtin_providers_loaded = False - @classmethod - def is_internal_builtin_provider(cls, provider_name: str) -> bool: - return provider_name in INTERNAL_BUILTIN_TOOL_PROVIDERS - @classmethod def get_tool_label(cls, tool_name: str) -> Union[I18nObject, None]: """ @@ -635,9 +628,9 @@ class ToolManager: # MySQL: Use window function to achieve same result sql = """ SELECT id FROM ( - SELECT id, + SELECT id, ROW_NUMBER() OVER ( - PARTITION BY tenant_id, provider + PARTITION BY tenant_id, provider ORDER BY is_default DESC, created_at DESC ) as rn FROM tool_builtin_providers @@ -674,8 +667,6 @@ class ToolManager: # append builtin providers for provider in builtin_providers: - if cls.is_internal_builtin_provider(provider.entity.identity.name): - continue # handle include, exclude if is_filtered( include_set=dify_config.POSITION_TOOL_INCLUDES_SET, @@ -1019,7 +1010,7 @@ class ToolManager: def _convert_tool_parameters_type( cls, parameters: list[ToolParameter], - variable_pool: Optional[VariablePool], + variable_pool: Optional["VariablePool"], tool_configurations: dict[str, Any], typ: Literal["agent", "workflow", "tool"] = "workflow", ) -> dict[str, Any]: diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index dded96eacf..0827494a48 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -577,12 +577,12 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD prompt_messages=prompt_messages, stop=stop, user_id=self.user_id, - structured_output_schema=None, + structured_output_enabled=self.node_data.structured_output_enabled, + structured_output=None, file_saver=self._llm_file_saver, 
file_outputs=self._file_outputs, node_id=self._node_id, node_type=self.node_type, - tenant_id=self.tenant_id, ) for event in generator: diff --git a/api/core/workflow/nodes/llm/entities.py b/api/core/workflow/nodes/llm/entities.py index b0eacd00ea..1a33d137d6 100644 --- a/api/core/workflow/nodes/llm/entities.py +++ b/api/core/workflow/nodes/llm/entities.py @@ -7,7 +7,7 @@ from pydantic import BaseModel, ConfigDict, Field, field_serializer, field_valid from core.agent.entities import AgentLog, AgentResult from core.file import File from core.model_runtime.entities import ImagePromptMessageContent, LLMMode -from core.model_runtime.entities.llm_entities import LLMStructuredOutput, LLMUsage +from core.model_runtime.entities.llm_entities import LLMUsage from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig from core.tools.entities.tool_entities import ToolProviderType from core.workflow.entities import ToolCall, ToolCallResult @@ -156,9 +156,6 @@ class LLMGenerationData(BaseModel): finish_reason: str | None = Field(None, description="Finish reason from LLM") files: list[File] = Field(default_factory=list, description="Generated files") trace: list[LLMTraceSegment] = Field(default_factory=list, description="Streaming trace in emitted order") - structured_output: LLMStructuredOutput | None = Field( - default=None, description="Structured output from tool-only agent runs" - ) class ThinkTagStreamParser: @@ -287,7 +284,6 @@ class AggregatedResult(BaseModel): files: list[File] = Field(default_factory=list) usage: LLMUsage = Field(default_factory=LLMUsage.empty_usage) finish_reason: str | None = None - structured_output: LLMStructuredOutput | None = None class AgentContext(BaseModel): @@ -387,7 +383,7 @@ class LLMNodeData(BaseNodeData): Strategy for handling model reasoning output. separated: Return clean text (without tags) + reasoning_content field. - Recommended for new workflows. Enables safe downstream parsing and + Recommended for new workflows. Enables safe downstream parsing and workflow variable access: {{#node_id.reasoning_content#}} tagged : Return original text (with tags) + reasoning_content field. diff --git a/api/core/workflow/nodes/llm/llm_utils.py b/api/core/workflow/nodes/llm/llm_utils.py index b1460ae402..86bf8f473e 100644 --- a/api/core/workflow/nodes/llm/llm_utils.py +++ b/api/core/workflow/nodes/llm/llm_utils.py @@ -257,8 +257,8 @@ def _build_file_descriptions(files: Sequence[Any]) -> str: """ Build a text description of generated files for inclusion in context. - The description includes file_id for context; structured output file paths - are only supported in sandbox mode. + The description includes file_id which can be used by subsequent nodes + to reference the files via structured output. 
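+
+    A hedged sketch of the intended round trip (the exact description wording
+    is implementation-defined and not guaranteed here):
+
+        context_note = _build_file_descriptions(files)
+        # The note lists each generated file together with its file_id (a
+        # UUID); a later node can echo that UUID into a "dify-file-ref"
+        # string field, where it is resolved back into a File object.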
""" if not files: return "" diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py index 82c2e1f1c1..feb76e6510 100644 --- a/api/core/workflow/nodes/llm/node.py +++ b/api/core/workflow/nodes/llm/node.py @@ -13,24 +13,13 @@ from typing import TYPE_CHECKING, Any, Literal, cast from sqlalchemy import select from core.agent.entities import AgentEntity, AgentLog, AgentResult, AgentToolEntity, ExecutionContext -from core.agent.output_tools import ( - FINAL_OUTPUT_TOOL, - FINAL_STRUCTURED_OUTPUT_TOOL, - OUTPUT_TEXT_TOOL, - build_agent_output_tools, -) from core.agent.patterns import StrategyFactory from core.app.entities.app_asset_entities import AppAssetFileTree from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.app_assets.constants import AppAssetsAttrs -from core.file import FileTransferMethod, FileType, file_manager +from core.file import File, FileTransferMethod, FileType, file_manager from core.helper.code_executor import CodeExecutor, CodeLanguage from core.llm_generator.output_parser.errors import OutputParserError -from core.llm_generator.output_parser.file_ref import ( - adapt_schema_for_sandbox_file_paths, - convert_sandbox_file_paths_in_output, - detect_file_path_fields, -) from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output from core.memory.base import BaseMemory from core.model_manager import ModelInstance, ModelManager @@ -73,7 +62,6 @@ from core.skill.entities.skill_document import SkillDocument from core.skill.entities.tool_dependencies import ToolDependencies, ToolDependency from core.skill.skill_compiler import SkillCompiler from core.tools.__base.tool import Tool -from core.tools.entities.tool_entities import ToolInvokeFrom from core.tools.signature import sign_upload_file from core.tools.tool_manager import ToolManager from core.variables import ( @@ -197,11 +185,12 @@ class LLMNode(Node[LLMNodeData]): def _run(self) -> Generator: node_inputs: dict[str, Any] = {} process_data: dict[str, Any] = {} - usage: LLMUsage = LLMUsage.empty_usage() - finish_reason: str | None = None - reasoning_content: str = "" # Initialize as empty string for consistency + clean_text = "" + usage = LLMUsage.empty_usage() + finish_reason = None + reasoning_content = "" # Initialize as empty string for consistency clean_text = "" # Initialize clean_text to avoid UnboundLocalError - variable_pool: VariablePool = self.graph_runtime_state.variable_pool + variable_pool = self.graph_runtime_state.variable_pool try: # Parse prompt template to separate static messages and context references @@ -261,9 +250,8 @@ class LLMNode(Node[LLMNodeData]): ) query: str | None = None - memory_config = self.node_data.memory - if memory_config: - query = memory_config.query_prompt_template + if self.node_data.memory: + query = self.node_data.memory.query_prompt_template if not query and ( query_variable := variable_pool.get((SYSTEM_VARIABLE_NODE_ID, SystemVariableKey.QUERY)) ): @@ -305,23 +293,9 @@ class LLMNode(Node[LLMNodeData]): sandbox=self.graph_runtime_state.sandbox, ) - structured_output_schema: Mapping[str, Any] | None - structured_output_file_paths: list[str] = [] - - if self.node_data.structured_output_enabled: - if not self.node_data.structured_output: - raise ValueError("structured_output_enabled is True but structured_output is not set") - raw_schema = LLMNode.fetch_structured_output_schema(structured_output=self.node_data.structured_output) - if self.node_data.computer_use: - 
structured_output_schema, structured_output_file_paths = adapt_schema_for_sandbox_file_paths( - raw_schema - ) - else: - if detect_file_path_fields(raw_schema): - raise LLMNodeError("Structured output file paths are only supported in sandbox mode.") - structured_output_schema = raw_schema - else: - structured_output_schema = None + # Variables for outputs + generation_data: LLMGenerationData | None = None + structured_output: LLMStructuredOutput | None = None if self.node_data.computer_use: sandbox = self.graph_runtime_state.sandbox @@ -335,10 +309,7 @@ class LLMNode(Node[LLMNodeData]): stop=stop, variable_pool=variable_pool, tool_dependencies=tool_dependencies, - structured_output_schema=structured_output_schema, - structured_output_file_paths=structured_output_file_paths, ) - elif self.tool_call_enabled: generator = self._invoke_llm_with_tools( model_instance=model_instance, @@ -348,7 +319,6 @@ class LLMNode(Node[LLMNodeData]): variable_pool=variable_pool, node_inputs=node_inputs, process_data=process_data, - structured_output_schema=structured_output_schema, ) else: # Use traditional LLM invocation @@ -358,7 +328,8 @@ class LLMNode(Node[LLMNodeData]): prompt_messages=prompt_messages, stop=stop, user_id=self.user_id, - structured_output_schema=structured_output_schema, + structured_output_enabled=self._node_data.structured_output_enabled, + structured_output=self._node_data.structured_output, file_saver=self._llm_file_saver, file_outputs=self._file_outputs, node_id=self._node_id, @@ -384,8 +355,6 @@ class LLMNode(Node[LLMNodeData]): reasoning_content = "" usage = generation_data.usage finish_reason = generation_data.finish_reason - if generation_data.structured_output: - structured_output = generation_data.structured_output # Unified process_data building process_data = { @@ -409,7 +378,16 @@ class LLMNode(Node[LLMNodeData]): if tool.enabled ] - # Build generation field and determine files_to_output first + # Unified outputs building + outputs = { + "text": clean_text, + "reasoning_content": reasoning_content, + "usage": jsonable_encoder(usage), + "finish_reason": finish_reason, + "context": llm_utils.build_context(prompt_messages, clean_text, generation_data), + } + + # Build generation field if generation_data: # Use generation_data from tool invocation (supports multi-turn) generation = { @@ -437,15 +415,6 @@ class LLMNode(Node[LLMNodeData]): } files_to_output = self._file_outputs - # Unified outputs building (files passed to context for subsequent node reference) - outputs = { - "text": clean_text, - "reasoning_content": reasoning_content, - "usage": jsonable_encoder(usage), - "finish_reason": finish_reason, - "context": llm_utils.build_context(prompt_messages, clean_text, generation_data, files=files_to_output), - } - outputs["generation"] = generation if files_to_output: outputs["files"] = ArrayFileSegment(value=files_to_output) @@ -524,7 +493,8 @@ class LLMNode(Node[LLMNodeData]): prompt_messages: Sequence[PromptMessage], stop: Sequence[str] | None = None, user_id: str, - structured_output_schema: Mapping[str, Any] | None, + structured_output_enabled: bool, + structured_output: Mapping[str, Any] | None = None, file_saver: LLMFileSaver, file_outputs: list[File], node_id: str, @@ -538,7 +508,10 @@ class LLMNode(Node[LLMNodeData]): if not model_schema: raise ValueError(f"Model schema not found for {node_data_model.name}") - if structured_output_schema: + if structured_output_enabled: + output_schema = LLMNode.fetch_structured_output_schema( + structured_output=structured_output or 
{}, + ) request_start_time = time.perf_counter() invoke_result = invoke_llm_with_structured_output( @@ -546,12 +519,12 @@ class LLMNode(Node[LLMNodeData]): model_schema=model_schema, model_instance=model_instance, prompt_messages=prompt_messages, - json_schema=structured_output_schema, + json_schema=output_schema, model_parameters=node_data_model.completion_params, stop=list(stop or []), user=user_id, + tenant_id=tenant_id, ) - else: request_start_time = time.perf_counter() @@ -1288,16 +1261,18 @@ class LLMNode(Node[LLMNodeData]): # Insert histories into the prompt prompt_content = prompt_messages[0].content # For issue #11247 - Check if prompt content is a string or a list - if isinstance(prompt_content, str): + prompt_content_type = type(prompt_content) + if prompt_content_type == str: prompt_content = str(prompt_content) if "#histories#" in prompt_content: prompt_content = prompt_content.replace("#histories#", memory_text) else: prompt_content = memory_text + "\n" + prompt_content prompt_messages[0].content = prompt_content - elif isinstance(prompt_content, list): + elif prompt_content_type == list: + prompt_content = prompt_content if isinstance(prompt_content, list) else [] for content_item in prompt_content: - if isinstance(content_item, TextPromptMessageContent): + if content_item.type == PromptMessageContentType.TEXT: if "#histories#" in content_item.data: content_item.data = content_item.data.replace("#histories#", memory_text) else: @@ -1307,12 +1282,13 @@ class LLMNode(Node[LLMNodeData]): # Add current query to the prompt message if sys_query: - if isinstance(prompt_content, str): + if prompt_content_type == str: prompt_content = str(prompt_messages[0].content).replace("#sys.query#", sys_query) prompt_messages[0].content = prompt_content - elif isinstance(prompt_content, list): + elif prompt_content_type == list: + prompt_content = prompt_content if isinstance(prompt_content, list) else [] for content_item in prompt_content: - if isinstance(content_item, TextPromptMessageContent): + if content_item.type == PromptMessageContentType.TEXT: content_item.data = sys_query + "\n" + content_item.data else: raise ValueError("Invalid prompt content type") @@ -1438,11 +1414,9 @@ class LLMNode(Node[LLMNodeData]): if isinstance(item, PromptMessageContext): if len(item.value_selector) >= 2: prompt_context_selectors.append(item.value_selector) - elif isinstance(item, LLMNodeChatModelMessage): + elif isinstance(item, LLMNodeChatModelMessage) and item.edition_type == "jinja2": variable_template_parser = VariableTemplateParser(template=item.text) variable_selectors.extend(variable_template_parser.extract_variable_selectors()) - else: - raise InvalidVariableTypeError(f"Invalid prompt template type: {type(prompt_template)}") elif isinstance(prompt_template, LLMNodeCompletionModelPromptTemplate): if prompt_template.edition_type != "jinja2": variable_template_parser = VariableTemplateParser(template=prompt_template.text) @@ -1478,14 +1452,13 @@ class LLMNode(Node[LLMNodeData]): if typed_node_data.prompt_config: enable_jinja = False - if isinstance(prompt_template, LLMNodeCompletionModelPromptTemplate): - if prompt_template.edition_type == "jinja2": - enable_jinja = True - else: - for prompt in prompt_template: - if prompt.edition_type == "jinja2": + if isinstance(prompt_template, list): + for item in prompt_template: + if isinstance(item, LLMNodeChatModelMessage) and item.edition_type == "jinja2": enable_jinja = True break + else: + enable_jinja = True if enable_jinja: for variable_selector in 
typed_node_data.prompt_config.jinja2_variables or []: @@ -1898,7 +1871,6 @@ class LLMNode(Node[LLMNodeData]): variable_pool: VariablePool, node_inputs: dict[str, Any], process_data: dict[str, Any], - structured_output_schema: Mapping[str, Any] | None, ) -> Generator[NodeEventBase, None, LLMGenerationData]: """Invoke LLM with tools support (from Agent V2). @@ -1915,16 +1887,12 @@ class LLMNode(Node[LLMNodeData]): # Use factory to create appropriate strategy strategy = StrategyFactory.create_strategy( - tenant_id=self.tenant_id, - invoke_from=self.invoke_from, - tool_invoke_from=ToolInvokeFrom.WORKFLOW, model_features=model_features, model_instance=model_instance, tools=tool_instances, files=prompt_files, max_iterations=self._node_data.max_iterations or 10, context=ExecutionContext(user_id=self.user_id, app_id=self.app_id, tenant_id=self.tenant_id), - structured_output_schema=structured_output_schema, ) # Run strategy @@ -1932,6 +1900,7 @@ class LLMNode(Node[LLMNodeData]): prompt_messages=list(prompt_messages), model_parameters=self._node_data.model.completion_params, stop=list(stop or []), + stream=True, ) result = yield from self._process_tool_outputs(outputs) @@ -1945,12 +1914,8 @@ class LLMNode(Node[LLMNodeData]): stop: Sequence[str] | None, variable_pool: VariablePool, tool_dependencies: ToolDependencies | None, - structured_output_schema: Mapping[str, Any] | None, - structured_output_file_paths: Sequence[str] | None, - ) -> Generator[NodeEventBase | LLMStructuredOutput, None, LLMGenerationData]: + ) -> Generator[NodeEventBase, None, LLMGenerationData]: result: LLMGenerationData | None = None - sandbox_output_files: list[File] = [] - structured_output_files: list[File] = [] # FIXME(Mairuis): Async processing for bash session. with SandboxBashSession(sandbox=sandbox, node_id=self.id, tools=tool_dependencies) as session: @@ -1958,9 +1923,6 @@ class LLMNode(Node[LLMNodeData]): model_features = self._get_model_features(model_instance) strategy = StrategyFactory.create_strategy( - tenant_id=self.tenant_id, - invoke_from=self.invoke_from, - tool_invoke_from=ToolInvokeFrom.WORKFLOW, model_features=model_features, model_instance=model_instance, tools=[session.bash_tool], @@ -1968,55 +1930,20 @@ class LLMNode(Node[LLMNodeData]): max_iterations=self._node_data.max_iterations or 100, agent_strategy=AgentEntity.Strategy.FUNCTION_CALLING, context=ExecutionContext(user_id=self.user_id, app_id=self.app_id, tenant_id=self.tenant_id), - structured_output_schema=structured_output_schema, ) outputs = strategy.run( prompt_messages=list(prompt_messages), model_parameters=self._node_data.model.completion_params, stop=list(stop or []), + stream=True, ) result = yield from self._process_tool_outputs(outputs) - if result and result.structured_output and structured_output_file_paths: - structured_output_payload = result.structured_output.structured_output or {} - try: - converted_output, structured_output_files = convert_sandbox_file_paths_in_output( - output=structured_output_payload, - file_path_fields=structured_output_file_paths, - file_resolver=session.download_file, - ) - except ValueError as exc: - raise LLMNodeError(str(exc)) from exc - result = result.model_copy( - update={"structured_output": LLMStructuredOutput(structured_output=converted_output)} - ) - - # Collect output files from sandbox before session ends - # Files are saved as ToolFiles with valid tool_file_id for later reference - sandbox_output_files = session.collect_output_files() - if result is None: raise LLMNodeError("SandboxSession 
exited unexpectedly") - structured_output = result.structured_output - if structured_output is not None: - yield structured_output - - # Merge sandbox output files into result - if sandbox_output_files or structured_output_files: - result = LLMGenerationData( - text=result.text, - reasoning_contents=result.reasoning_contents, - tool_calls=result.tool_calls, - sequence=result.sequence, - usage=result.usage, - finish_reason=result.finish_reason, - files=result.files + sandbox_output_files + structured_output_files, - trace=result.trace, - ) - return result def _get_model_features(self, model_instance: ModelInstance) -> list[ModelFeature]: @@ -2084,20 +2011,6 @@ class LLMNode(Node[LLMNodeData]): logger.warning("Failed to load tool %s: %s", tool, str(e)) continue - structured_output_schema = None - if self._node_data.structured_output_enabled: - structured_output_schema = LLMNode.fetch_structured_output_schema( - structured_output=self._node_data.structured_output or {}, - ) - tool_instances.extend( - build_agent_output_tools( - tenant_id=self.tenant_id, - invoke_from=self.invoke_from, - tool_invoke_from=ToolInvokeFrom.WORKFLOW, - structured_output_schema=structured_output_schema, - ) - ) - return tool_instances def _extract_prompt_files(self, variable_pool: VariablePool) -> list[File]: @@ -2261,45 +2174,18 @@ class LLMNode(Node[LLMNodeData]): # Add tool call to pending list for model segment buffers.pending_tool_calls.append(ToolCall(id=tool_call_id, name=tool_name, arguments=tool_arguments)) - output_tool_names = {OUTPUT_TEXT_TOOL, FINAL_OUTPUT_TOOL, FINAL_STRUCTURED_OUTPUT_TOOL} - - if tool_name not in output_tool_names: - yield ToolCallChunkEvent( - selector=[self._node_id, "generation", "tool_calls"], - chunk=tool_arguments, - tool_call=ToolCall( - id=tool_call_id, - name=tool_name, - arguments=tool_arguments, - icon=tool_icon, - icon_dark=tool_icon_dark, - ), - is_final=False, - ) - - if tool_name in output_tool_names: - content = "" - if tool_name in (OUTPUT_TEXT_TOOL, FINAL_OUTPUT_TOOL): - content = payload.tool_args["text"] - elif tool_name == FINAL_STRUCTURED_OUTPUT_TOOL: - raw_content = json.dumps( - payload.tool_args["data"], - ensure_ascii=False, - indent=2 - ) - content = f"```json\n{raw_content}\n```" - - if content: - yield StreamChunkEvent( - selector=[self._node_id, "text"], - chunk=content, - is_final=False, - ) - yield StreamChunkEvent( - selector=[self._node_id, "generation", "content"], - chunk=content, - is_final=False, - ) + yield ToolCallChunkEvent( + selector=[self._node_id, "generation", "tool_calls"], + chunk=tool_arguments, + tool_call=ToolCall( + id=tool_call_id, + name=tool_name, + arguments=tool_arguments, + icon=tool_icon, + icon_dark=tool_icon_dark, + ), + is_final=False, + ) if output.log_type == AgentLog.LogType.TOOL_CALL and output.status != AgentLog.LogStatus.START: tool_name = payload.tool_name @@ -2543,7 +2429,6 @@ class LLMNode(Node[LLMNodeData]): content_position = 0 tool_call_seen_index: dict[str, int] = {} for trace_segment in trace_state.trace_segments: - # FIXME: These if will never happen if trace_segment.type == "thought": sequence.append({"type": "reasoning", "index": reasoning_index}) reasoning_index += 1 @@ -2586,15 +2471,8 @@ class LLMNode(Node[LLMNodeData]): key=lambda item: trace_state.tool_call_index_map.get(item.id or "", len(trace_state.tool_call_index_map)) ) - text_content: str - if aggregate.text: - text_content = aggregate.text - elif aggregate.structured_output: - text_content = 
json.dumps(aggregate.structured_output.structured_output) - else: - raise ValueError("Aggregate must have either text or structured output.") return LLMGenerationData( - text=text_content, + text=aggregate.text, reasoning_contents=buffers.reasoning_per_turn, tool_calls=tool_calls_for_generation, sequence=sequence, @@ -2602,7 +2480,6 @@ class LLMNode(Node[LLMNodeData]): finish_reason=aggregate.finish_reason, files=aggregate.files, trace=trace_state.trace_segments, - structured_output=aggregate.structured_output, ) def _process_tool_outputs( @@ -2613,33 +2490,22 @@ class LLMNode(Node[LLMNodeData]): state = ToolOutputState() try: - while True: - output = next(outputs) + for output in outputs: if isinstance(output, AgentLog): yield from self._handle_agent_log_output(output, state.stream, state.trace, state.agent) else: - continue + yield from self._handle_llm_chunk_output(output, state.stream, state.trace, state.aggregate) except StopIteration as exception: - if not isinstance(exception.value, AgentResult): - raise ValueError(f"Unexpected output type: {type(exception.value)}") from exception - state.agent.agent_result = exception.value - agent_result = state.agent.agent_result - if not agent_result: - raise ValueError("No agent result found in tool outputs") - output_payload = agent_result.output - if isinstance(output_payload, dict): - state.aggregate.structured_output = LLMStructuredOutput(structured_output=output_payload) - state.aggregate.text = json.dumps(output_payload) - elif isinstance(output_payload, str): - state.aggregate.text = output_payload - else: - raise ValueError(f"Unexpected output type: {type(output_payload)}") + if isinstance(getattr(exception, "value", None), AgentResult): + state.agent.agent_result = exception.value - state.aggregate.files = state.agent.agent_result.files - if state.agent.agent_result.usage: - state.aggregate.usage = state.agent.agent_result.usage - if state.agent.agent_result.finish_reason: - state.aggregate.finish_reason = state.agent.agent_result.finish_reason + if state.agent.agent_result: + state.aggregate.text = state.agent.agent_result.text or state.aggregate.text + state.aggregate.files = state.agent.agent_result.files + if state.agent.agent_result.usage: + state.aggregate.usage = state.agent.agent_result.usage + if state.agent.agent_result.finish_reason: + state.aggregate.finish_reason = state.agent.agent_result.finish_reason yield from self._flush_remaining_stream(state.stream, state.trace, state.aggregate) yield from self._close_streams() diff --git a/api/core/workflow/nodes/question_classifier/question_classifier_node.py b/api/core/workflow/nodes/question_classifier/question_classifier_node.py index c8dfe7ccf9..564e548e9f 100644 --- a/api/core/workflow/nodes/question_classifier/question_classifier_node.py +++ b/api/core/workflow/nodes/question_classifier/question_classifier_node.py @@ -156,7 +156,8 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]): prompt_messages=prompt_messages, stop=stop, user_id=self.user_id, - structured_output_schema=None, + structured_output_enabled=False, + structured_output=None, file_saver=self._llm_file_saver, file_outputs=self._file_outputs, node_id=self._node_id, diff --git a/api/services/tools/builtin_tools_manage_service.py b/api/services/tools/builtin_tools_manage_service.py index 57c52c1c86..4c37a867b1 100644 --- a/api/services/tools/builtin_tools_manage_service.py +++ b/api/services/tools/builtin_tools_manage_service.py @@ -91,8 +91,6 @@ class BuiltinToolManageService: :return: the list of 
tools """ provider_controller = ToolManager.get_builtin_provider(provider, tenant_id) - if ToolManager.is_internal_builtin_provider(provider_controller.entity.identity.name): - return [] tools = provider_controller.get_tools() result: list[ToolApiEntity] = [] @@ -543,8 +541,6 @@ class BuiltinToolManageService: for provider_controller in provider_controllers: try: - if ToolManager.is_internal_builtin_provider(provider_controller.entity.identity.name): - continue # handle include, exclude if is_filtered( include_set=dify_config.POSITION_TOOL_INCLUDES_SET, diff --git a/api/tests/fixtures/file output schema.yml b/api/tests/fixtures/file output schema.yml index 8ae3072572..37fc9c72c7 100644 --- a/api/tests/fixtures/file output schema.yml +++ b/api/tests/fixtures/file output schema.yml @@ -126,8 +126,9 @@ workflow: additionalProperties: false properties: image: - description: Sandbox file path of the selected image - type: file + description: File ID (UUID) of the selected image + format: dify-file-ref + type: string required: - image type: object diff --git a/api/tests/unit_tests/core/agent/output_parser/test_cot_output_parser.py b/api/tests/unit_tests/core/agent/output_parser/test_cot_output_parser.py index aec077124d..4a613e35b0 100644 --- a/api/tests/unit_tests/core/agent/output_parser/test_cot_output_parser.py +++ b/api/tests/unit_tests/core/agent/output_parser/test_cot_output_parser.py @@ -1,3 +1,4 @@ +import json from collections.abc import Generator from core.agent.entities import AgentScratchpadUnit @@ -63,7 +64,7 @@ def test_cot_output_parser(): output += result elif isinstance(result, AgentScratchpadUnit.Action): if test_case["action"]: - assert result.model_dump() == test_case["action"] - output += result.model_dump_json() + assert result.to_dict() == test_case["action"] + output += json.dumps(result.to_dict()) if test_case["output"]: assert output == test_case["output"] diff --git a/api/tests/unit_tests/core/agent/patterns/test_base.py b/api/tests/unit_tests/core/agent/patterns/test_base.py index dd42b31bad..b0e0d44940 100644 --- a/api/tests/unit_tests/core/agent/patterns/test_base.py +++ b/api/tests/unit_tests/core/agent/patterns/test_base.py @@ -13,7 +13,7 @@ from core.model_runtime.entities.llm_entities import LLMUsage class ConcreteAgentPattern(AgentPattern): """Concrete implementation of AgentPattern for testing.""" - def run(self, prompt_messages, model_parameters, stop=[]): + def run(self, prompt_messages, model_parameters, stop=[], stream=True): """Minimal implementation for testing.""" yield from [] diff --git a/api/tests/unit_tests/core/agent/test_agent_app_runner.py b/api/tests/unit_tests/core/agent/test_agent_app_runner.py index 43d28d39b7..d9301ccfe0 100644 --- a/api/tests/unit_tests/core/agent/test_agent_app_runner.py +++ b/api/tests/unit_tests/core/agent/test_agent_app_runner.py @@ -329,15 +329,13 @@ class TestAgentLogProcessing: ) result = AgentResult( - output="Final answer", + text="Final answer", files=[], usage=usage, finish_reason="stop", ) - output_payload = result.output - assert isinstance(output_payload, str) - assert output_payload == "Final answer" + assert result.text == "Final answer" assert result.files == [] assert result.usage == usage assert result.finish_reason == "stop" diff --git a/api/tests/unit_tests/core/agent/test_entities.py b/api/tests/unit_tests/core/agent/test_entities.py index 61b0ac0052..5136f48aab 100644 --- a/api/tests/unit_tests/core/agent/test_entities.py +++ b/api/tests/unit_tests/core/agent/test_entities.py @@ -153,7 +153,7 @@ 
class TestAgentScratchpadUnit: action_input={"query": "test"}, ) - result = action.model_dump() + result = action.to_dict() assert result == { "action": "search", diff --git a/api/tests/unit_tests/core/llm_generator/output_parser/test_file_ref.py b/api/tests/unit_tests/core/llm_generator/output_parser/test_file_ref.py index 2c16de8f6f..6d18ac7fc9 100644 --- a/api/tests/unit_tests/core/llm_generator/output_parser/test_file_ref.py +++ b/api/tests/unit_tests/core/llm_generator/output_parser/test_file_ref.py @@ -1,93 +1,269 @@ """ -Unit tests for sandbox file path detection and conversion. +Unit tests for file reference detection and conversion. """ +import uuid +from unittest.mock import MagicMock, patch + import pytest from core.file import File, FileTransferMethod, FileType from core.llm_generator.output_parser.file_ref import ( - FILE_PATH_DESCRIPTION_SUFFIX, - adapt_schema_for_sandbox_file_paths, - convert_sandbox_file_paths_in_output, - detect_file_path_fields, - is_file_path_property, + FILE_REF_FORMAT, + convert_file_refs_in_output, + detect_file_ref_fields, + is_file_ref_property, ) from core.variables.segments import ArrayFileSegment, FileSegment -def _build_file(file_id: str) -> File: - return File( - id=file_id, - tenant_id="tenant_123", - type=FileType.IMAGE, - transfer_method=FileTransferMethod.TOOL_FILE, - filename="test.png", - extension=".png", - mime_type="image/png", - size=128, - related_id=file_id, - storage_key="sandbox/path", - ) +class TestIsFileRefProperty: + """Tests for is_file_ref_property function.""" + + def test_valid_file_ref(self): + schema = {"type": "string", "format": FILE_REF_FORMAT} + assert is_file_ref_property(schema) is True + + def test_invalid_type(self): + schema = {"type": "number", "format": FILE_REF_FORMAT} + assert is_file_ref_property(schema) is False + + def test_missing_format(self): + schema = {"type": "string"} + assert is_file_ref_property(schema) is False + + def test_wrong_format(self): + schema = {"type": "string", "format": "uuid"} + assert is_file_ref_property(schema) is False -class TestFilePathSchema: - def test_is_file_path_property(self): - assert is_file_path_property({"type": "file"}) is True - assert is_file_path_property({"type": "string", "format": "dify-file-ref"}) is True - assert is_file_path_property({"type": "string"}) is False +class TestDetectFileRefFields: + """Tests for detect_file_ref_fields function.""" - def test_detect_file_path_fields(self): + def test_simple_file_ref(self): schema = { "type": "object", "properties": { - "image": {"type": "string", "format": "dify-file-ref"}, - "files": {"type": "array", "items": {"type": "string", "format": "dify-file-ref"}}, - "meta": {"type": "object", "properties": {"doc": {"type": "file"}}}, + "image": {"type": "string", "format": FILE_REF_FORMAT}, }, } - assert set(detect_file_path_fields(schema)) == {"image", "files[*]", "meta.doc"} + paths = detect_file_ref_fields(schema) + assert paths == ["image"] - def test_adapt_schema_for_sandbox_file_paths(self): + def test_multiple_file_refs(self): schema = { "type": "object", "properties": { - "image": {"type": "string", "format": "dify-file-ref"}, + "image": {"type": "string", "format": FILE_REF_FORMAT}, + "document": {"type": "string", "format": FILE_REF_FORMAT}, "name": {"type": "string"}, }, } - adapted, fields = adapt_schema_for_sandbox_file_paths(schema) + paths = detect_file_ref_fields(schema) + assert set(paths) == {"image", "document"} - assert set(fields) == {"image"} - adapted_image = adapted["properties"]["image"] - 
assert adapted_image["type"] == "string" - assert "format" not in adapted_image - assert FILE_PATH_DESCRIPTION_SUFFIX in adapted_image["description"] + def test_array_of_file_refs(self): + schema = { + "type": "object", + "properties": { + "files": { + "type": "array", + "items": {"type": "string", "format": FILE_REF_FORMAT}, + }, + }, + } + paths = detect_file_ref_fields(schema) + assert paths == ["files[*]"] + + def test_nested_file_ref(self): + schema = { + "type": "object", + "properties": { + "data": { + "type": "object", + "properties": { + "image": {"type": "string", "format": FILE_REF_FORMAT}, + }, + }, + }, + } + paths = detect_file_ref_fields(schema) + assert paths == ["data.image"] + + def test_no_file_refs(self): + schema = { + "type": "object", + "properties": { + "name": {"type": "string"}, + "count": {"type": "number"}, + }, + } + paths = detect_file_ref_fields(schema) + assert paths == [] + + def test_empty_schema(self): + schema = {} + paths = detect_file_ref_fields(schema) + assert paths == [] + + def test_mixed_schema(self): + schema = { + "type": "object", + "properties": { + "query": {"type": "string"}, + "image": {"type": "string", "format": FILE_REF_FORMAT}, + "documents": { + "type": "array", + "items": {"type": "string", "format": FILE_REF_FORMAT}, + }, + }, + } + paths = detect_file_ref_fields(schema) + assert set(paths) == {"image", "documents[*]"} -class TestConvertSandboxFilePaths: - def test_convert_sandbox_file_paths(self): - output = {"image": "a.png", "files": ["b.png", "c.png"], "name": "demo"} +class TestConvertFileRefsInOutput: + """Tests for convert_file_refs_in_output function.""" - def resolver(path: str) -> File: - return _build_file(path) + @pytest.fixture + def mock_file(self): + """Create a mock File object with all required attributes.""" + file = MagicMock(spec=File) + file.type = FileType.IMAGE + file.transfer_method = FileTransferMethod.TOOL_FILE + file.related_id = "test-related-id" + file.remote_url = None + file.tenant_id = "tenant_123" + file.id = None + file.filename = "test.png" + file.extension = ".png" + file.mime_type = "image/png" + file.size = 1024 + file.dify_model_identity = "__dify__file__" + return file - converted, files = convert_sandbox_file_paths_in_output(output, ["image", "files[*]"], resolver) + @pytest.fixture + def mock_build_from_mapping(self, mock_file): + """Mock the build_from_mapping function.""" + with patch("core.llm_generator.output_parser.file_ref.build_from_mapping") as mock: + mock.return_value = mock_file + yield mock - assert isinstance(converted["image"], FileSegment) - assert isinstance(converted["files"], ArrayFileSegment) - assert converted["name"] == "demo" - assert [file.id for file in files] == ["a.png", "b.png", "c.png"] + def test_convert_simple_file_ref(self, mock_build_from_mapping, mock_file): + file_id = str(uuid.uuid4()) + output = {"image": file_id} + schema = { + "type": "object", + "properties": { + "image": {"type": "string", "format": FILE_REF_FORMAT}, + }, + } - def test_invalid_path_value_raises(self): - def resolver(path: str) -> File: - return _build_file(path) + result = convert_file_refs_in_output(output, schema, "tenant_123") - with pytest.raises(ValueError): - convert_sandbox_file_paths_in_output({"image": 123}, ["image"], resolver) + # Result should be wrapped in FileSegment + assert isinstance(result["image"], FileSegment) + assert result["image"].value == mock_file + mock_build_from_mapping.assert_called_once_with( + mapping={"transfer_method": "tool_file", "tool_file_id": 
file_id},
+            tenant_id="tenant_123",
+        )
 
-    def test_no_file_paths_returns_output(self):
-        output = {"name": "demo"}
-        converted, files = convert_sandbox_file_paths_in_output(output, [], _build_file)
+    def test_convert_array_of_file_refs(self, mock_build_from_mapping, mock_file):
+        file_id1 = str(uuid.uuid4())
+        file_id2 = str(uuid.uuid4())
+        output = {"files": [file_id1, file_id2]}
+        schema = {
+            "type": "object",
+            "properties": {
+                "files": {
+                    "type": "array",
+                    "items": {"type": "string", "format": FILE_REF_FORMAT},
+                },
+            },
+        }
 
-        assert converted == output
-        assert files == []
+        result = convert_file_refs_in_output(output, schema, "tenant_123")
+
+        # Result should be wrapped in ArrayFileSegment
+        assert isinstance(result["files"], ArrayFileSegment)
+        assert list(result["files"].value) == [mock_file, mock_file]
+        assert mock_build_from_mapping.call_count == 2
+
+    def test_no_conversion_without_file_refs(self):
+        output = {"name": "test", "count": 5}
+        schema = {
+            "type": "object",
+            "properties": {
+                "name": {"type": "string"},
+                "count": {"type": "number"},
+            },
+        }
+
+        result = convert_file_refs_in_output(output, schema, "tenant_123")
+
+        assert result == {"name": "test", "count": 5}
+
+    def test_invalid_uuid_returns_none(self):
+        output = {"image": "not-a-valid-uuid"}
+        schema = {
+            "type": "object",
+            "properties": {
+                "image": {"type": "string", "format": FILE_REF_FORMAT},
+            },
+        }
+
+        result = convert_file_refs_in_output(output, schema, "tenant_123")
+
+        assert result["image"] is None
+
+    def test_file_not_found_returns_none(self):
+        file_id = str(uuid.uuid4())
+        output = {"image": file_id}
+        schema = {
+            "type": "object",
+            "properties": {
+                "image": {"type": "string", "format": FILE_REF_FORMAT},
+            },
+        }
+
+        with patch("core.llm_generator.output_parser.file_ref.build_from_mapping") as mock:
+            mock.side_effect = ValueError("File not found")
+            result = convert_file_refs_in_output(output, schema, "tenant_123")
+
+        assert result["image"] is None
+
+    def test_preserves_non_file_fields(self, mock_build_from_mapping, mock_file):
+        file_id = str(uuid.uuid4())
+        output = {"query": "search term", "image": file_id, "count": 10}
+        schema = {
+            "type": "object",
+            "properties": {
+                "query": {"type": "string"},
+                "image": {"type": "string", "format": FILE_REF_FORMAT},
+                "count": {"type": "number"},
+            },
+        }
+
+        result = convert_file_refs_in_output(output, schema, "tenant_123")
+
+        assert result["query"] == "search term"
+        assert isinstance(result["image"], FileSegment)
+        assert result["image"].value == mock_file
+        assert result["count"] == 10
+
+    def test_does_not_modify_original_output(self, mock_build_from_mapping, mock_file):
+        file_id = str(uuid.uuid4())
+        original = {"image": file_id}
+        output = dict(original)
+        schema = {
+            "type": "object",
+            "properties": {
+                "image": {"type": "string", "format": FILE_REF_FORMAT},
+            },
+        }
+
+        convert_file_refs_in_output(output, schema, "tenant_123")
+
+        # The dict passed to the converter should be left unmodified (still the raw string ID)
+        assert output == original
diff --git a/api/tests/unit_tests/utils/structured_output_parser/test_structured_output_parser.py b/api/tests/unit_tests/utils/structured_output_parser/test_structured_output_parser.py
index 30f2f8cc73..df73c29004 100644
--- a/api/tests/unit_tests/utils/structured_output_parser/test_structured_output_parser.py
+++ b/api/tests/unit_tests/utils/structured_output_parser/test_structured_output_parser.py
@@ -1,6 +1,4 @@
-from collections.abc import Mapping
 from decimal import Decimal
-from typing import Any, NotRequired, 
TypedDict from unittest.mock import MagicMock, patch import pytest @@ -8,14 +6,16 @@ from pydantic import BaseModel, ConfigDict from core.llm_generator.output_parser.errors import OutputParserError from core.llm_generator.output_parser.structured_output import ( + _get_default_value_for_type, fill_defaults_from_schema, - get_default_value_for_type, invoke_llm_with_pydantic_model, invoke_llm_with_structured_output, ) -from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.llm_entities import ( LLMResult, + LLMResultChunk, + LLMResultChunkDelta, + LLMResultChunkWithStructuredOutput, LLMResultWithStructuredOutput, LLMUsage, ) @@ -25,29 +25,7 @@ from core.model_runtime.entities.message_entities import ( TextPromptMessageContent, UserPromptMessage, ) -from core.model_runtime.entities.model_entities import ( - AIModelEntity, - ModelType, - ParameterRule, - ParameterType, -) - - -class StructuredOutputTestCase(TypedDict): - name: str - provider: str - model_name: str - support_structure_output: bool - stream: bool - json_schema: Mapping[str, Any] - expected_llm_response: LLMResult - expected_result_type: type[LLMResultWithStructuredOutput] | None - should_raise: bool - expected_error: NotRequired[type[OutputParserError]] - parameter_rules: NotRequired[list[ParameterRule]] - - -SchemaData = dict[str, Any] +from core.model_runtime.entities.model_entities import AIModelEntity, ModelType def create_mock_usage(prompt_tokens: int = 10, completion_tokens: int = 5) -> LLMUsage: @@ -93,7 +71,7 @@ def get_model_instance() -> MagicMock: def test_structured_output_parser(): """Test cases for invoke_llm_with_structured_output function""" - testcases: list[StructuredOutputTestCase] = [ + testcases = [ # Test case 1: Model with native structured output support, non-streaming { "name": "native_structured_output_non_streaming", @@ -110,6 +88,39 @@ def test_structured_output_parser(): "expected_result_type": LLMResultWithStructuredOutput, "should_raise": False, }, + # Test case 2: Model with native structured output support, streaming + { + "name": "native_structured_output_streaming", + "provider": "openai", + "model_name": "gpt-4o", + "support_structure_output": True, + "stream": True, + "json_schema": {"type": "object", "properties": {"name": {"type": "string"}}}, + "expected_llm_response": [ + LLMResultChunk( + model="gpt-4o", + prompt_messages=[UserPromptMessage(content="test")], + system_fingerprint="test", + delta=LLMResultChunkDelta( + index=0, + message=AssistantPromptMessage(content='{"name":'), + usage=create_mock_usage(prompt_tokens=10, completion_tokens=2), + ), + ), + LLMResultChunk( + model="gpt-4o", + prompt_messages=[UserPromptMessage(content="test")], + system_fingerprint="test", + delta=LLMResultChunkDelta( + index=0, + message=AssistantPromptMessage(content=' "test"}'), + usage=create_mock_usage(prompt_tokens=10, completion_tokens=3), + ), + ), + ], + "expected_result_type": "generator", + "should_raise": False, + }, # Test case 3: Model without native structured output support, non-streaming { "name": "prompt_based_structured_output_non_streaming", @@ -126,24 +137,78 @@ def test_structured_output_parser(): "expected_result_type": LLMResultWithStructuredOutput, "should_raise": False, }, + # Test case 4: Model without native structured output support, streaming { - "name": "non_streaming_with_list_content", + "name": "prompt_based_structured_output_streaming", + "provider": "anthropic", + "model_name": "claude-3-sonnet", + "support_structure_output": 
False,
+            "stream": True,
+            "json_schema": {"type": "object", "properties": {"answer": {"type": "string"}}},
+            "expected_llm_response": [
+                LLMResultChunk(
+                    model="claude-3-sonnet",
+                    prompt_messages=[UserPromptMessage(content="test")],
+                    system_fingerprint="test",
+                    delta=LLMResultChunkDelta(
+                        index=0,
+                        message=AssistantPromptMessage(content='{"answer": "test'),
+                        usage=create_mock_usage(prompt_tokens=15, completion_tokens=3),
+                    ),
+                ),
+                LLMResultChunk(
+                    model="claude-3-sonnet",
+                    prompt_messages=[UserPromptMessage(content="test")],
+                    system_fingerprint="test",
+                    delta=LLMResultChunkDelta(
+                        index=0,
+                        message=AssistantPromptMessage(content=' response"}'),
+                        usage=create_mock_usage(prompt_tokens=15, completion_tokens=5),
+                    ),
+                ),
+            ],
+            "expected_result_type": "generator",
+            "should_raise": False,
+        },
+        # Test case 5: Streaming with list content
+        {
+            "name": "streaming_with_list_content",
             "provider": "openai",
             "model_name": "gpt-4o",
             "support_structure_output": True,
-            "stream": False,
+            "stream": True,
             "json_schema": {"type": "object", "properties": {"data": {"type": "string"}}},
-            "expected_llm_response": LLMResult(
-                model="gpt-4o",
-                message=AssistantPromptMessage(
-                    content=[
-                        TextPromptMessageContent(data='{"data":'),
-                        TextPromptMessageContent(data=' "value"}'),
-                    ]
+            "expected_llm_response": [
+                LLMResultChunk(
+                    model="gpt-4o",
+                    prompt_messages=[UserPromptMessage(content="test")],
+                    system_fingerprint="test",
+                    delta=LLMResultChunkDelta(
+                        index=0,
+                        message=AssistantPromptMessage(
+                            content=[
+                                TextPromptMessageContent(data='{"data":'),
+                            ]
+                        ),
+                        usage=create_mock_usage(prompt_tokens=10, completion_tokens=2),
+                    ),
                 ),
-                usage=create_mock_usage(prompt_tokens=10, completion_tokens=5),
-            ),
-            "expected_result_type": LLMResultWithStructuredOutput,
+                LLMResultChunk(
+                    model="gpt-4o",
+                    prompt_messages=[UserPromptMessage(content="test")],
+                    system_fingerprint="test",
+                    delta=LLMResultChunkDelta(
+                        index=0,
+                        message=AssistantPromptMessage(
+                            content=[
+                                TextPromptMessageContent(data=' "value"}'),
+                            ]
+                        ),
+                        usage=create_mock_usage(prompt_tokens=10, completion_tokens=3),
+                    ),
+                ),
+            ],
+            "expected_result_type": "generator",
             "should_raise": False,
         },
         # Test case 6: Error case - non-string LLM response content (non-streaming)
@@ -188,13 +253,7 @@
             "stream": False,
             "json_schema": {"type": "object", "properties": {"result": {"type": "string"}}},
             "parameter_rules": [
-                ParameterRule(
-                    name="response_format",
-                    label=I18nObject(en_US="response_format"),
-                    type=ParameterType.STRING,
-                    required=False,
-                    options=["json_schema"],
-                ),
+                type("Rule", (), {"name": "response_format", "options": ["json_schema"], "required": False})(),
             ],
             "expected_llm_response": LLMResult(
                 model="gpt-4o",
@@ -213,13 +272,7 @@
             "stream": False,
             "json_schema": {"type": "object", "properties": {"output": {"type": "string"}}},
             "parameter_rules": [
-                ParameterRule(
-                    name="response_format",
-                    label=I18nObject(en_US="response_format"),
-                    type=ParameterType.STRING,
-                    required=False,
-                    options=["JSON"],
-                ),
+                type("Rule", (), {"name": "response_format", "options": ["JSON"], "required": False})(),
             ],
             "expected_llm_response": LLMResult(
                 model="claude-3-sonnet",
@@ -232,72 +285,89 @@
     ]
 
     for case in testcases:
-        provider = case["provider"]
-        model_name = case["model_name"]
-        support_structure_output = case["support_structure_output"]
-        json_schema = case["json_schema"]
-        stream = case["stream"]
+        # Setup model entity
+        model_schema = 
get_model_entity(case["provider"], case["model_name"], case["support_structure_output"]) - model_schema = get_model_entity(provider, model_name, support_structure_output) - - parameter_rules = case.get("parameter_rules") - if parameter_rules is not None: - model_schema.parameter_rules = parameter_rules + # Add parameter rules if specified + if "parameter_rules" in case: + model_schema.parameter_rules = case["parameter_rules"] + # Setup model instance model_instance = get_model_instance() model_instance.invoke_llm.return_value = case["expected_llm_response"] + # Setup prompt messages prompt_messages = [ SystemPromptMessage(content="You are a helpful assistant."), UserPromptMessage(content="Generate a response according to the schema."), ] if case["should_raise"]: - expected_error = case.get("expected_error", OutputParserError) - with pytest.raises(expected_error): # noqa: PT012 - if stream: + # Test error cases + with pytest.raises(case["expected_error"]): # noqa: PT012 + if case["stream"]: result_generator = invoke_llm_with_structured_output( - provider=provider, + provider=case["provider"], model_schema=model_schema, model_instance=model_instance, prompt_messages=prompt_messages, - json_schema=json_schema, + json_schema=case["json_schema"], ) + # Consume the generator to trigger the error list(result_generator) else: invoke_llm_with_structured_output( - provider=provider, + provider=case["provider"], model_schema=model_schema, model_instance=model_instance, prompt_messages=prompt_messages, - json_schema=json_schema, + json_schema=case["json_schema"], ) else: + # Test successful cases with patch("core.llm_generator.output_parser.structured_output.json_repair.loads") as mock_json_repair: + # Configure json_repair mock for cases that need it if case["name"] == "json_repair_scenario": mock_json_repair.return_value = {"name": "test"} result = invoke_llm_with_structured_output( - provider=provider, + provider=case["provider"], model_schema=model_schema, model_instance=model_instance, prompt_messages=prompt_messages, - json_schema=json_schema, + json_schema=case["json_schema"], model_parameters={"temperature": 0.7, "max_tokens": 100}, user="test_user", ) - expected_result_type = case["expected_result_type"] - assert expected_result_type is not None - assert isinstance(result, expected_result_type) - assert result.model == model_name - assert result.structured_output is not None - assert isinstance(result.structured_output, dict) + if case["expected_result_type"] == "generator": + # Test streaming results + assert hasattr(result, "__iter__") + chunks = list(result) + assert len(chunks) > 0 + # Verify all chunks are LLMResultChunkWithStructuredOutput + for chunk in chunks[:-1]: # All except last + assert isinstance(chunk, LLMResultChunkWithStructuredOutput) + assert chunk.model == case["model_name"] + + # Last chunk should have structured output + last_chunk = chunks[-1] + assert isinstance(last_chunk, LLMResultChunkWithStructuredOutput) + assert last_chunk.structured_output is not None + assert isinstance(last_chunk.structured_output, dict) + else: + # Test non-streaming results + assert isinstance(result, case["expected_result_type"]) + assert result.model == case["model_name"] + assert result.structured_output is not None + assert isinstance(result.structured_output, dict) + + # Verify model_instance.invoke_llm was called with correct parameters model_instance.invoke_llm.assert_called_once() call_args = model_instance.invoke_llm.call_args - assert call_args.kwargs["stream"] == stream + assert 
call_args.kwargs["stream"] == case["stream"] assert call_args.kwargs["user"] == "test_user" assert "temperature" in call_args.kwargs["model_parameters"] assert "max_tokens" in call_args.kwargs["model_parameters"] @@ -306,32 +376,45 @@ def test_structured_output_parser(): def test_parse_structured_output_edge_cases(): """Test edge cases for structured output parsing""" - provider = "deepseek" - model_name = "deepseek-r1" - support_structure_output = False - json_schema: SchemaData = {"type": "object", "properties": {"thought": {"type": "string"}}} - expected_llm_response = LLMResult( - model="deepseek-r1", - message=AssistantPromptMessage(content='[{"thought": "reasoning process"}, "other content"]'), - usage=create_mock_usage(prompt_tokens=10, completion_tokens=5), + # Test case with list that contains dict (reasoning model scenario) + testcase_list_with_dict = { + "name": "list_with_dict_parsing", + "provider": "deepseek", + "model_name": "deepseek-r1", + "support_structure_output": False, + "stream": False, + "json_schema": {"type": "object", "properties": {"thought": {"type": "string"}}}, + "expected_llm_response": LLMResult( + model="deepseek-r1", + message=AssistantPromptMessage(content='[{"thought": "reasoning process"}, "other content"]'), + usage=create_mock_usage(prompt_tokens=10, completion_tokens=5), + ), + "expected_result_type": LLMResultWithStructuredOutput, + "should_raise": False, + } + + # Setup for list parsing test + model_schema = get_model_entity( + testcase_list_with_dict["provider"], + testcase_list_with_dict["model_name"], + testcase_list_with_dict["support_structure_output"], ) - model_schema = get_model_entity(provider, model_name, support_structure_output) - model_instance = get_model_instance() - model_instance.invoke_llm.return_value = expected_llm_response + model_instance.invoke_llm.return_value = testcase_list_with_dict["expected_llm_response"] prompt_messages = [UserPromptMessage(content="Test reasoning")] with patch("core.llm_generator.output_parser.structured_output.json_repair.loads") as mock_json_repair: + # Mock json_repair to return a list with dict mock_json_repair.return_value = [{"thought": "reasoning process"}, "other content"] result = invoke_llm_with_structured_output( - provider=provider, + provider=testcase_list_with_dict["provider"], model_schema=model_schema, model_instance=model_instance, prompt_messages=prompt_messages, - json_schema=json_schema, + json_schema=testcase_list_with_dict["json_schema"], ) assert isinstance(result, LLMResultWithStructuredOutput) @@ -341,16 +424,18 @@ def test_parse_structured_output_edge_cases(): def test_model_specific_schema_preparation(): """Test schema preparation for different model types""" - provider = "google" - model_name = "gemini-pro" - support_structure_output = True - json_schema: SchemaData = { - "type": "object", - "properties": {"result": {"type": "boolean"}}, - "additionalProperties": False, + # Test Gemini model + gemini_case = { + "provider": "google", + "model_name": "gemini-pro", + "support_structure_output": True, + "stream": False, + "json_schema": {"type": "object", "properties": {"result": {"type": "boolean"}}, "additionalProperties": False}, } - model_schema = get_model_entity(provider, model_name, support_structure_output) + model_schema = get_model_entity( + gemini_case["provider"], gemini_case["model_name"], gemini_case["support_structure_output"] + ) model_instance = get_model_instance() model_instance.invoke_llm.return_value = LLMResult( @@ -362,11 +447,11 @@ def 
test_model_specific_schema_preparation(): prompt_messages = [UserPromptMessage(content="Test")] result = invoke_llm_with_structured_output( - provider=provider, + provider=gemini_case["provider"], model_schema=model_schema, model_instance=model_instance, prompt_messages=prompt_messages, - json_schema=json_schema, + json_schema=gemini_case["json_schema"], ) assert isinstance(result, LLMResultWithStructuredOutput) @@ -408,26 +493,40 @@ def test_structured_output_with_pydantic_model_non_streaming(): assert result.name == "test" -def test_structured_output_with_pydantic_model_list_content(): +def test_structured_output_with_pydantic_model_streaming(): model_schema = get_model_entity("openai", "gpt-4o", support_structure_output=True) model_instance = get_model_instance() - model_instance.invoke_llm.return_value = LLMResult( - model="gpt-4o", - message=AssistantPromptMessage( - content=[ - TextPromptMessageContent(data='{"name":'), - TextPromptMessageContent(data=' "test"}'), - ] - ), - usage=create_mock_usage(prompt_tokens=8, completion_tokens=4), - ) + + def mock_streaming_response(): + yield LLMResultChunk( + model="gpt-4o", + prompt_messages=[UserPromptMessage(content="test")], + system_fingerprint="test", + delta=LLMResultChunkDelta( + index=0, + message=AssistantPromptMessage(content='{"name":'), + usage=create_mock_usage(prompt_tokens=8, completion_tokens=2), + ), + ) + yield LLMResultChunk( + model="gpt-4o", + prompt_messages=[UserPromptMessage(content="test")], + system_fingerprint="test", + delta=LLMResultChunkDelta( + index=0, + message=AssistantPromptMessage(content=' "test"}'), + usage=create_mock_usage(prompt_tokens=8, completion_tokens=4), + ), + ) + + model_instance.invoke_llm.return_value = mock_streaming_response() result = invoke_llm_with_pydantic_model( provider="openai", model_schema=model_schema, model_instance=model_instance, prompt_messages=[UserPromptMessage(content="Return a JSON object with name.")], - output_model=ExampleOutput, + output_model=ExampleOutput ) assert isinstance(result, ExampleOutput) @@ -449,51 +548,51 @@ def test_structured_output_with_pydantic_model_validation_error(): model_schema=model_schema, model_instance=model_instance, prompt_messages=[UserPromptMessage(content="test")], - output_model=ExampleOutput, + output_model=ExampleOutput ) class TestGetDefaultValueForType: - """Test cases for get_default_value_for_type function""" + """Test cases for _get_default_value_for_type function""" def test_string_type(self): - assert get_default_value_for_type("string") == "" + assert _get_default_value_for_type("string") == "" def test_object_type(self): - assert get_default_value_for_type("object") == {} + assert _get_default_value_for_type("object") == {} def test_array_type(self): - assert get_default_value_for_type("array") == [] + assert _get_default_value_for_type("array") == [] def test_number_type(self): - assert get_default_value_for_type("number") == 0 + assert _get_default_value_for_type("number") == 0 def test_integer_type(self): - assert get_default_value_for_type("integer") == 0 + assert _get_default_value_for_type("integer") == 0 def test_boolean_type(self): - assert get_default_value_for_type("boolean") is False + assert _get_default_value_for_type("boolean") is False def test_null_type(self): - assert get_default_value_for_type("null") is None + assert _get_default_value_for_type("null") is None def test_none_type(self): - assert get_default_value_for_type(None) is None + assert _get_default_value_for_type(None) is None def 
test_unknown_type(self): - assert get_default_value_for_type("unknown") is None + assert _get_default_value_for_type("unknown") is None def test_union_type_string_null(self): # ["string", "null"] should return "" (first non-null type) - assert get_default_value_for_type(["string", "null"]) == "" + assert _get_default_value_for_type(["string", "null"]) == "" def test_union_type_null_first(self): # ["null", "integer"] should return 0 (first non-null type) - assert get_default_value_for_type(["null", "integer"]) == 0 + assert _get_default_value_for_type(["null", "integer"]) == 0 def test_union_type_only_null(self): # ["null"] should return None - assert get_default_value_for_type(["null"]) is None + assert _get_default_value_for_type(["null"]) is None class TestFillDefaultsFromSchema: @@ -501,7 +600,7 @@ class TestFillDefaultsFromSchema: def test_simple_required_fields(self): """Test filling simple required fields""" - schema: SchemaData = { + schema = { "type": "object", "properties": { "name": {"type": "string"}, @@ -510,7 +609,7 @@ class TestFillDefaultsFromSchema: }, "required": ["name", "age"], } - output: SchemaData = {"name": "Alice"} + output = {"name": "Alice"} result = fill_defaults_from_schema(output, schema) @@ -520,7 +619,7 @@ class TestFillDefaultsFromSchema: def test_non_required_fields_not_filled(self): """Test that non-required fields are not filled""" - schema: SchemaData = { + schema = { "type": "object", "properties": { "required_field": {"type": "string"}, @@ -528,7 +627,7 @@ class TestFillDefaultsFromSchema: }, "required": ["required_field"], } - output: SchemaData = {} + output = {} result = fill_defaults_from_schema(output, schema) @@ -537,7 +636,7 @@ class TestFillDefaultsFromSchema: def test_nested_object_required_fields(self): """Test filling nested object required fields""" - schema: SchemaData = { + schema = { "type": "object", "properties": { "user": { @@ -560,7 +659,7 @@ class TestFillDefaultsFromSchema: }, "required": ["user"], } - output: SchemaData = { + output = { "user": { "name": "Alice", "address": { @@ -585,7 +684,7 @@ class TestFillDefaultsFromSchema: def test_missing_nested_object_created(self): """Test that missing required nested objects are created""" - schema: SchemaData = { + schema = { "type": "object", "properties": { "metadata": { @@ -599,7 +698,7 @@ class TestFillDefaultsFromSchema: }, "required": ["metadata"], } - output: SchemaData = {} + output = {} result = fill_defaults_from_schema(output, schema) @@ -611,7 +710,7 @@ class TestFillDefaultsFromSchema: def test_all_types_default_values(self): """Test default values for all types""" - schema: SchemaData = { + schema = { "type": "object", "properties": { "str_field": {"type": "string"}, @@ -623,7 +722,7 @@ class TestFillDefaultsFromSchema: }, "required": ["str_field", "int_field", "num_field", "bool_field", "arr_field", "obj_field"], } - output: SchemaData = {} + output = {} result = fill_defaults_from_schema(output, schema) @@ -638,7 +737,7 @@ class TestFillDefaultsFromSchema: def test_existing_values_preserved(self): """Test that existing values are not overwritten""" - schema: SchemaData = { + schema = { "type": "object", "properties": { "name": {"type": "string"}, @@ -646,7 +745,7 @@ class TestFillDefaultsFromSchema: }, "required": ["name", "count"], } - output: SchemaData = {"name": "Bob", "count": 42} + output = {"name": "Bob", "count": 42} result = fill_defaults_from_schema(output, schema) @@ -654,7 +753,7 @@ class TestFillDefaultsFromSchema: def test_complex_nested_structure(self): 
"""Test complex nested structure with multiple levels""" - schema: SchemaData = { + schema = { "type": "object", "properties": { "user": { @@ -690,7 +789,7 @@ class TestFillDefaultsFromSchema: }, "required": ["user", "tags", "metadata", "is_active"], } - output: SchemaData = { + output = { "user": { "name": "Alice", "age": 25, @@ -730,8 +829,8 @@ class TestFillDefaultsFromSchema: def test_empty_schema(self): """Test with empty schema""" - schema: SchemaData = {} - output: SchemaData = {"any": "value"} + schema = {} + output = {"any": "value"} result = fill_defaults_from_schema(output, schema) @@ -739,14 +838,14 @@ class TestFillDefaultsFromSchema: def test_schema_without_required(self): """Test schema without required field""" - schema: SchemaData = { + schema = { "type": "object", "properties": { "optional1": {"type": "string"}, "optional2": {"type": "integer"}, }, } - output: SchemaData = {} + output = {} result = fill_defaults_from_schema(output, schema)