Mirror of https://github.com/langgenius/dify.git (synced 2026-02-13 16:00:55 -05:00)
revert: add tools for output in agent mode
Squashed commits reverted by this change:

- feat: hide output tools and improve JSON formatting for structured output
- fix: handle prompt template correctly to extract selectors for step run
- fix: emit StreamChunkEvent correctly for sandbox agent
- chore: better debug message
- fix: incorrect output tool runtime selection
- fix: type issues
- fix: align parameter list
- fix: hide internal builtin providers from tool list
- vibe: implement file structured output
- fix: refix parameter for tool
- fix: crash
- refactor: remove union types
- fix: type check
- Merge branch 'feat/structured-output-with-sandbox' into feat/support-agent-sandbox
- fix: provide json as text
- fix: get AgentResult correctly
- fix: provides correct prompts, tools and terminal predicates
- fix: circular import
- feat: support structured output in sandbox and tool mode
@@ -1,4 +1,3 @@
import json
import logging
from collections.abc import Generator
from copy import deepcopy
@@ -24,7 +23,7 @@ from core.model_runtime.entities import (
from core.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes
from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
from core.tools.__base.tool import Tool
from core.tools.entities.tool_entities import ToolInvokeFrom, ToolInvokeMeta
from core.tools.entities.tool_entities import ToolInvokeMeta
from core.tools.tool_engine import ToolEngine
from models.model import Message

@@ -103,13 +102,11 @@ class AgentAppRunner(BaseAgentRunner):
agent_strategy=self.config.strategy,
tool_invoke_hook=tool_invoke_hook,
instruction=instruction,
tenant_id=self.tenant_id,
invoke_from=self.application_generate_entity.invoke_from,
tool_invoke_from=ToolInvokeFrom.WORKFLOW,
)

# Initialize state variables
current_agent_thought_id = None
has_published_thought = False
current_tool_name: str | None = None
self._current_message_file_ids: list[str] = []

@@ -121,6 +118,7 @@ class AgentAppRunner(BaseAgentRunner):
prompt_messages=prompt_messages,
model_parameters=app_generate_entity.model_conf.parameters,
stop=app_generate_entity.model_conf.stop,
stream=True,
)

# Consume generator and collect result
@@ -135,10 +133,17 @@ class AgentAppRunner(BaseAgentRunner):
break

if isinstance(output, LLMResultChunk):
# No more streaming data is expected
continue
# Handle LLM chunk
if current_agent_thought_id and not has_published_thought:
self.queue_manager.publish(
QueueAgentThoughtEvent(agent_thought_id=current_agent_thought_id),
PublishFrom.APPLICATION_MANAGER,
)
has_published_thought = True

else:
yield output

elif isinstance(output, AgentLog):
# Handle Agent Log using log_type for type-safe dispatch
if output.status == AgentLog.LogStatus.START:
if output.log_type == AgentLog.LogType.ROUND:
@@ -151,6 +156,7 @@ class AgentAppRunner(BaseAgentRunner):
tool_input="",
messages_ids=message_file_ids,
)
has_published_thought = False

elif output.log_type == AgentLog.LogType.TOOL_CALL:
if current_agent_thought_id is None:
@@ -258,30 +264,23 @@ class AgentAppRunner(BaseAgentRunner):
raise

# Process final result
if not isinstance(result, AgentResult):
raise ValueError("Agent did not return AgentResult")
output_payload = result.output
if isinstance(output_payload, dict):
final_answer = json.dumps(output_payload, ensure_ascii=False)
elif isinstance(output_payload, str):
final_answer = output_payload
else:
raise ValueError("Final output is not a string or structured data.")
usage = result.usage or LLMUsage.empty_usage()
if isinstance(result, AgentResult):
final_answer = result.text
usage = result.usage or LLMUsage.empty_usage()

# Publish end event
self.queue_manager.publish(
QueueMessageEndEvent(
llm_result=LLMResult(
model=self.model_instance.model,
prompt_messages=prompt_messages,
message=AssistantPromptMessage(content=final_answer),
usage=usage,
system_fingerprint="",
)
),
PublishFrom.APPLICATION_MANAGER,
)
# Publish end event
self.queue_manager.publish(
QueueMessageEndEvent(
llm_result=LLMResult(
model=self.model_instance.model,
prompt_messages=prompt_messages,
message=AssistantPromptMessage(content=final_answer),
usage=usage,
system_fingerprint="",
)
),
PublishFrom.APPLICATION_MANAGER,
)

def _init_system_message(self, prompt_template: str, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
"""

@@ -7,7 +7,6 @@ from typing import Union, cast
from sqlalchemy import select

from core.agent.entities import AgentEntity, AgentToolEntity, ExecutionContext
from core.agent.output_tools import build_agent_output_tools
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfig
from core.app.apps.base_app_queue_manager import AppQueueManager
@@ -37,7 +36,6 @@ from core.model_runtime.model_providers.__base.large_language_model import Large
from core.prompt.utils.extract_thread_messages import extract_thread_messages
from core.tools.__base.tool import Tool
from core.tools.entities.tool_entities import (
ToolInvokeFrom,
ToolParameter,
)
from core.tools.tool_manager import ToolManager
@@ -253,14 +251,6 @@ class BaseAgentRunner(AppRunner):
# save tool entity
tool_instances[dataset_tool.entity.identity.name] = dataset_tool

output_tools = build_agent_output_tools(
tenant_id=self.tenant_id,
invoke_from=self.application_generate_entity.invoke_from,
tool_invoke_from=ToolInvokeFrom.AGENT,
)
for tool in output_tools:
tool_instances[tool.entity.identity.name] = tool

return tool_instances, prompt_messages_tools

def update_prompt_message_tool(self, tool: Tool, prompt_tool: PromptMessageTool) -> PromptMessageTool:

@@ -1,11 +1,10 @@
import uuid
from collections.abc import Mapping
from enum import StrEnum
from typing import Any
from typing import Any, Union

from pydantic import BaseModel, Field

from core.agent.output_tools import FINAL_OUTPUT_TOOL
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolProviderType


@@ -42,9 +41,9 @@ class AgentScratchpadUnit(BaseModel):
"""

action_name: str
action_input: dict[str, Any] | str
action_input: Union[dict, str]

def to_dict(self) -> dict[str, Any]:
def to_dict(self):
"""
Convert to dictionary.
"""
@@ -63,9 +62,9 @@ class AgentScratchpadUnit(BaseModel):
"""
Check if the scratchpad unit is final.
"""
if self.action is None:
return False
return self.action.action_name.lower() == FINAL_OUTPUT_TOOL
return self.action is None or (
"final" in self.action.action_name.lower() and "answer" in self.action.action_name.lower()
)

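Editor's note on the `is_final` hunk above: the restored predicate is a name heuristic rather than an exact match against `FINAL_OUTPUT_TOOL`. A minimal, self-contained sketch of the restored behavior (the `Action` class here is a stand-in for `AgentScratchpadUnit.Action`, not the Dify type):

```python
from dataclasses import dataclass


@dataclass
class Action:
    # Stand-in for AgentScratchpadUnit.Action; only the field the predicate reads.
    action_name: str


def is_final(action: Action | None) -> bool:
    # Restored heuristic: no action at all, or a name containing both
    # "final" and "answer" (e.g. "Final Answer"), counts as terminal.
    return action is None or (
        "final" in action.action_name.lower() and "answer" in action.action_name.lower()
    )


assert is_final(None)
assert is_final(Action("Final Answer"))
assert is_final(Action("final_output_answer"))  # the old tool name still matches
assert not is_final(Action("final_structured_output"))  # no "answer" in the name
```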
class AgentEntity(BaseModel):
@@ -126,7 +125,7 @@ class ExecutionContext(BaseModel):
"tenant_id": self.tenant_id,
}

def with_updates(self, **kwargs: Any) -> "ExecutionContext":
def with_updates(self, **kwargs) -> "ExecutionContext":
"""Create a new context with updated fields."""
data = self.to_dict()
data.update(kwargs)
@@ -179,23 +178,12 @@ class AgentLog(BaseModel):
metadata: Mapping[LogMetadata, Any] = Field(default={}, description="The metadata of the log")


class AgentOutputKind(StrEnum):
"""
Agent output kind.
"""

OUTPUT_TEXT = "output_text"
FINAL_OUTPUT_ANSWER = "final_output_answer"
FINAL_STRUCTURED_OUTPUT = "final_structured_output"
ILLEGAL_OUTPUT = "illegal_output"


class AgentResult(BaseModel):
"""
Agent execution result.
"""

output: str | dict = Field(default="", description="The generated output")
text: str = Field(default="", description="The generated text")
files: list[Any] = Field(default_factory=list, description="Files produced during execution")
usage: Any | None = Field(default=None, description="LLM usage statistics")
finish_reason: str | None = Field(default=None, description="Reason for completion")

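For orientation, a minimal sketch of the `AgentResult` shape this revert restores (fields copied from the hunk above; the removed `output: str | dict` field is gone, and `usage` stays loosely typed as in the diff):

```python
from typing import Any

from pydantic import BaseModel, Field


class AgentResult(BaseModel):
    # Restored shape: plain text output only, no separate structured `output` field.
    text: str = Field(default="", description="The generated text")
    files: list[Any] = Field(default_factory=list, description="Files produced during execution")
    usage: Any | None = Field(default=None, description="LLM usage statistics")
    finish_reason: str | None = Field(default=None, description="Reason for completion")


result = AgentResult(text="Hello!", finish_reason="stop")
print(result.model_dump())
```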
@@ -1,7 +1,7 @@
import json
import re
from collections.abc import Generator
from typing import Any, Union, cast
from typing import Union

from core.agent.entities import AgentScratchpadUnit
from core.model_runtime.entities.llm_entities import LLMResultChunk
@@ -10,52 +10,46 @@ from core.model_runtime.entities.llm_entities import LLMResultChunk
class CotAgentOutputParser:
@classmethod
def handle_react_stream_output(
cls, llm_response: Generator[LLMResultChunk, None, None], usage_dict: dict[str, Any]
cls, llm_response: Generator[LLMResultChunk, None, None], usage_dict: dict
) -> Generator[Union[str, AgentScratchpadUnit.Action], None, None]:
def parse_action(action: Any) -> Union[str, AgentScratchpadUnit.Action]:
action_name: str | None = None
action_input: Any | None = None
parsed_action: Any = action
if isinstance(parsed_action, str):
def parse_action(action) -> Union[str, AgentScratchpadUnit.Action]:
action_name = None
action_input = None
if isinstance(action, str):
try:
parsed_action = json.loads(parsed_action, strict=False)
action = json.loads(action, strict=False)
except json.JSONDecodeError:
return parsed_action or ""
return action or ""

# cohere always returns a list
if isinstance(parsed_action, list):
action_list: list[Any] = cast(list[Any], parsed_action)
if len(action_list) == 1:
parsed_action = action_list[0]
if isinstance(action, list) and len(action) == 1:
action = action[0]

if isinstance(parsed_action, dict):
action_dict: dict[str, Any] = cast(dict[str, Any], parsed_action)
for key, value in action_dict.items():
if "input" in key.lower():
action_input = value
elif isinstance(value, str):
action_name = value
else:
return json.dumps(parsed_action)
for key, value in action.items():
if "input" in key.lower():
action_input = value
else:
action_name = value

if action_name is not None and action_input is not None:
return AgentScratchpadUnit.Action(
action_name=action_name,
action_input=action_input,
)
return json.dumps(parsed_action)
else:
return json.dumps(action)

def extra_json_from_code_block(code_block: str) -> list[dict[str, Any] | list[Any]]:
def extra_json_from_code_block(code_block) -> list[Union[list, dict]]:
blocks = re.findall(r"```[json]*\s*([\[{].*[]}])\s*```", code_block, re.DOTALL | re.IGNORECASE)
if not blocks:
return []
try:
json_blocks: list[dict[str, Any] | list[Any]] = []
json_blocks = []
for block in blocks:
json_text = re.sub(r"^[a-zA-Z]+\n", "", block.strip(), flags=re.MULTILINE)
json_blocks.append(json.loads(json_text, strict=False))
return json_blocks
except Exception:
except:
return []

code_block_cache = ""

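To show what `extra_json_from_code_block` actually extracts, here is a standalone copy of its logic (same regexes as the diff above; note that `[json]*` is a character class, loosely matching an optional `json` language tag) run against a sample ReAct response:

```python
import json
import re
from typing import Union


def extra_json_from_code_block(code_block: str) -> list[Union[list, dict]]:
    # Pull fenced ```json ...``` blobs whose body starts with { or [.
    blocks = re.findall(r"```[json]*\s*([\[{].*[]}])\s*```", code_block, re.DOTALL | re.IGNORECASE)
    if not blocks:
        return []
    try:
        json_blocks = []
        for block in blocks:
            # Strip a leading language-tag line such as "json\n" if present.
            json_text = re.sub(r"^[a-zA-Z]+\n", "", block.strip(), flags=re.MULTILINE)
            json_blocks.append(json.loads(json_text, strict=False))
        return json_blocks
    except Exception:
        return []


sample = 'Action:```json\n{"action": "Final Answer", "action_input": "done"}\n```'
print(extra_json_from_code_block(sample))
# [{'action': 'Final Answer', 'action_input': 'done'}]
```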
@@ -1,108 +0,0 @@
from __future__ import annotations

from collections.abc import Mapping, Sequence
from typing import Any

from core.app.entities.app_invoke_entities import InvokeFrom
from core.tools.__base.tool import Tool
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolInvokeFrom, ToolInvokeMessage, ToolParameter, ToolProviderType
from core.tools.tool_manager import ToolManager

OUTPUT_TOOL_PROVIDER = "agent_output"

OUTPUT_TEXT_TOOL = "output_text"
FINAL_OUTPUT_TOOL = "final_output_answer"
FINAL_STRUCTURED_OUTPUT_TOOL = "final_structured_output"
ILLEGAL_OUTPUT_TOOL = "illegal_output"

OUTPUT_TOOL_NAMES: Sequence[str] = (
OUTPUT_TEXT_TOOL,
FINAL_OUTPUT_TOOL,
FINAL_STRUCTURED_OUTPUT_TOOL,
ILLEGAL_OUTPUT_TOOL,
)

TERMINAL_OUTPUT_TOOL_NAMES: Sequence[str] = (FINAL_OUTPUT_TOOL, FINAL_STRUCTURED_OUTPUT_TOOL)


TERMINAL_OUTPUT_MESSAGE = "Final output received. This ends the current session."


def is_terminal_output_tool(tool_name: str) -> bool:
return tool_name in TERMINAL_OUTPUT_TOOL_NAMES


def get_terminal_tool_name(structured_output_enabled: bool) -> str:
return FINAL_STRUCTURED_OUTPUT_TOOL if structured_output_enabled else FINAL_OUTPUT_TOOL


OUTPUT_TOOL_NAME_SET = set(OUTPUT_TOOL_NAMES)


def build_agent_output_tools(
tenant_id: str,
invoke_from: InvokeFrom,
tool_invoke_from: ToolInvokeFrom,
structured_output_schema: Mapping[str, Any] | None = None,
) -> list[Tool]:
def get_tool_runtime(_tool_name: str) -> Tool:
return ToolManager.get_tool_runtime(
provider_type=ToolProviderType.BUILT_IN,
provider_id=OUTPUT_TOOL_PROVIDER,
tool_name=_tool_name,
tenant_id=tenant_id,
invoke_from=invoke_from,
tool_invoke_from=tool_invoke_from,
)

tools: list[Tool] = [
get_tool_runtime(OUTPUT_TEXT_TOOL),
get_tool_runtime(ILLEGAL_OUTPUT_TOOL),
]

if structured_output_schema:
raw_tool = get_tool_runtime(FINAL_STRUCTURED_OUTPUT_TOOL)
raw_tool.entity = raw_tool.entity.model_copy(deep=True)
data_parameter = ToolParameter(
name="data",
type=ToolParameter.ToolParameterType.OBJECT,
form=ToolParameter.ToolParameterForm.LLM,
required=True,
input_schema=dict(structured_output_schema),
label=I18nObject(en_US="__Data", zh_Hans="__Data"),
)
raw_tool.entity.parameters = [data_parameter]

def invoke_tool(
user_id: str,
tool_parameters: dict[str, Any],
conversation_id: str | None = None,
app_id: str | None = None,
message_id: str | None = None,
) -> ToolInvokeMessage:
data = tool_parameters["data"]
if not data:
return ToolInvokeMessage(message=ToolInvokeMessage.TextMessage(text="`data` field is required"))
if not isinstance(data, dict):
return ToolInvokeMessage(message=ToolInvokeMessage.TextMessage(text="`data` must be a dict"))
return ToolInvokeMessage(message=ToolInvokeMessage.TextMessage(text=TERMINAL_OUTPUT_MESSAGE))

raw_tool._invoke = invoke_tool  # pyright: ignore[reportPrivateUsage]
tools.append(raw_tool)
else:
raw_tool = get_tool_runtime(FINAL_OUTPUT_TOOL)

def invoke_tool(
user_id: str,
tool_parameters: dict[str, Any],
conversation_id: str | None = None,
app_id: str | None = None,
message_id: str | None = None,
) -> ToolInvokeMessage:
return ToolInvokeMessage(message=ToolInvokeMessage.TextMessage(text=TERMINAL_OUTPUT_MESSAGE))

raw_tool._invoke = invoke_tool  # pyright: ignore[reportPrivateUsage]
tools.append(raw_tool)

return tools
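A minimal sketch of the naming and terminal-tool predicates in the deleted `output_tools` module above (constants copied from the diff; the `ToolManager` runtime lookup is omitted, so this only demonstrates the selection logic):

```python
OUTPUT_TEXT_TOOL = "output_text"
FINAL_OUTPUT_TOOL = "final_output_answer"
FINAL_STRUCTURED_OUTPUT_TOOL = "final_structured_output"
ILLEGAL_OUTPUT_TOOL = "illegal_output"

TERMINAL_OUTPUT_TOOL_NAMES = (FINAL_OUTPUT_TOOL, FINAL_STRUCTURED_OUTPUT_TOOL)


def is_terminal_output_tool(tool_name: str) -> bool:
    # True only for the two tools that end the agent loop.
    return tool_name in TERMINAL_OUTPUT_TOOL_NAMES


def get_terminal_tool_name(structured_output_enabled: bool) -> str:
    # With a structured-output schema, the structured terminal tool is exposed
    # instead of the plain-text one.
    return FINAL_STRUCTURED_OUTPUT_TOOL if structured_output_enabled else FINAL_OUTPUT_TOOL


assert is_terminal_output_tool(get_terminal_tool_name(True))
assert is_terminal_output_tool(get_terminal_tool_name(False))
assert not is_terminal_output_tool(OUTPUT_TEXT_TOOL)
```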
@@ -10,7 +10,6 @@ from collections.abc import Callable, Generator
from typing import TYPE_CHECKING, Any

from core.agent.entities import AgentLog, AgentResult, ExecutionContext
from core.agent.output_tools import ILLEGAL_OUTPUT_TOOL
from core.file import File
from core.model_manager import ModelInstance
from core.model_runtime.entities import (
@@ -57,7 +56,8 @@ class AgentPattern(ABC):

@abstractmethod
def run(
self, prompt_messages: list[PromptMessage], model_parameters: dict[str, Any], stop: list[str]
self, prompt_messages: list[PromptMessage], model_parameters: dict[str, Any], stop: list[str] = [],
stream: bool = True,
) -> Generator[LLMResultChunk | AgentLog, None, AgentResult]:
"""Execute the agent strategy."""
pass
@@ -462,8 +462,6 @@ class AgentPattern(ABC):
"""Convert tools to prompt message format."""
prompt_tools: list[PromptMessageTool] = []
for tool in self.tools:
if tool.entity.identity.name == ILLEGAL_OUTPUT_TOOL:
continue
prompt_tools.append(tool.to_prompt_message_tool())
return prompt_tools


@@ -1,15 +1,10 @@
"""Function Call strategy implementation."""

import json
import uuid
from collections.abc import Generator
from typing import Any, Literal, Protocol, Union, cast
from typing import Any, Union

from core.agent.entities import AgentLog, AgentResult
from core.agent.output_tools import (
ILLEGAL_OUTPUT_TOOL,
TERMINAL_OUTPUT_MESSAGE,
)
from core.file import File
from core.model_runtime.entities import (
AssistantPromptMessage,
@@ -30,7 +25,8 @@ class FunctionCallStrategy(AgentPattern):
"""Function Call strategy using model's native tool calling capability."""

def run(
self, prompt_messages: list[PromptMessage], model_parameters: dict[str, Any], stop: list[str]
self, prompt_messages: list[PromptMessage], model_parameters: dict[str, Any], stop: list[str] = [],
stream: bool = True,
) -> Generator[LLMResultChunk | AgentLog, None, AgentResult]:
"""Execute the function call agent strategy."""
# Convert tools to prompt format
@@ -43,23 +39,9 @@ class FunctionCallStrategy(AgentPattern):
total_usage: dict[str, LLMUsage | None] = {"usage": None}
messages: list[PromptMessage] = list(prompt_messages)  # Create mutable copy
final_text: str = ""
final_tool_args: dict[str, Any] = {"!!!": "!!!"}
finish_reason: str | None = None
output_files: list[File] = []  # Track files produced by tools

class _LLMInvoker(Protocol):
def invoke_llm(
self,
*,
prompt_messages: list[PromptMessage],
model_parameters: dict[str, Any],
tools: list[PromptMessageTool],
stop: list[str],
stream: Literal[False],
user: str | None,
callbacks: list[Any],
) -> LLMResult: ...

while function_call_state and iteration_step <= max_iterations:
function_call_state = False
round_log = self._create_log(
@@ -69,7 +51,8 @@ class FunctionCallStrategy(AgentPattern):
data={},
)
yield round_log

# On last iteration, remove tools to force final answer
current_tools: list[PromptMessageTool] = [] if iteration_step == max_iterations else prompt_tools
model_log = self._create_log(
label=f"{self.model_instance.model} Thought",
log_type=AgentLog.LogType.THOUGHT,
@@ -86,63 +69,47 @@ class FunctionCallStrategy(AgentPattern):
round_usage: dict[str, LLMUsage | None] = {"usage": None}

# Invoke model
invoker = cast(_LLMInvoker, self.model_instance)
chunks = invoker.invoke_llm(
chunks: Union[Generator[LLMResultChunk, None, None], LLMResult] = self.model_instance.invoke_llm(
prompt_messages=messages,
model_parameters=model_parameters,
tools=prompt_tools,
tools=current_tools,
stop=stop,
stream=False,
stream=stream,
user=self.context.user_id,
callbacks=[],
)

# Process response
tool_calls, response_content, chunk_finish_reason = yield from self._handle_chunks(
chunks, round_usage, model_log, emit_chunks=False
chunks, round_usage, model_log
)

if response_content:
replaced_tool_call = (
str(uuid.uuid4()),
ILLEGAL_OUTPUT_TOOL,
{
"raw": response_content,
},
)
tool_calls.append(replaced_tool_call)

messages.append(self._create_assistant_message("", tool_calls))
messages.append(self._create_assistant_message(response_content, tool_calls))

# Accumulate to total usage
round_usage_value = round_usage.get("usage")
if round_usage_value:
self._accumulate_usage(total_usage, round_usage_value)

# Update final text if no tool calls (this is likely the final answer)
if not tool_calls:
final_text = response_content

# Update finish reason
if chunk_finish_reason:
finish_reason = chunk_finish_reason

assert len(tool_calls) > 0

# Process tool calls
tool_outputs: dict[str, str] = {}
function_call_state = True
# Execute tools
for tool_call_id, tool_name, tool_args in tool_calls:
tool_response, tool_files, _ = yield from self._handle_tool_call(
tool_name, tool_args, tool_call_id, messages, round_log
)
tool_entity = self._find_tool_by_name(tool_name)
if not tool_entity:
raise ValueError(f"Tool {tool_name} not found")
tool_outputs[tool_name] = tool_response
# Track files produced by tools
output_files.extend(tool_files)
if tool_response == TERMINAL_OUTPUT_MESSAGE:
function_call_state = False
final_tool_args = tool_entity.transform_tool_parameters_type(tool_args)

if tool_calls:
function_call_state = True
# Execute tools
for tool_call_id, tool_name, tool_args in tool_calls:
tool_response, tool_files, _ = yield from self._handle_tool_call(
tool_name, tool_args, tool_call_id, messages, round_log
)
tool_outputs[tool_name] = tool_response
# Track files produced by tools
output_files.extend(tool_files)
yield self._finish_log(
round_log,
data={
@@ -161,19 +128,8 @@ class FunctionCallStrategy(AgentPattern):
# Return final result
from core.agent.entities import AgentResult

output_payload: str | dict
output_text = final_tool_args.get("text")
output_structured_payload = final_tool_args.get("data")

if isinstance(output_structured_payload, dict):
output_payload = output_structured_payload
elif isinstance(output_text, str):
output_payload = output_text
else:
raise ValueError(f"Final output ({final_tool_args}) is not a string or structured data.")

return AgentResult(
output=output_payload,
text=final_text,
files=output_files,
usage=total_usage.get("usage") or LLMUsage.empty_usage(),
finish_reason=finish_reason,
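The removed lines above picked the final payload out of the terminal tool's arguments. A standalone sketch of that selection rule (the function name is ours; the logic is copied from the diff, and the assertions use illustrative data):

```python
from typing import Any, Union


def select_output_payload(final_tool_args: dict[str, Any]) -> Union[str, dict]:
    # Prefer structured "data" (final_structured_output), fall back to
    # plain "text" (final_output_answer); anything else is an error.
    output_text = final_tool_args.get("text")
    output_structured_payload = final_tool_args.get("data")
    if isinstance(output_structured_payload, dict):
        return output_structured_payload
    if isinstance(output_text, str):
        return output_text
    raise ValueError(f"Final output ({final_tool_args}) is not a string or structured data.")


assert select_output_payload({"text": "done"}) == "done"
assert select_output_payload({"data": {"ok": True}}) == {"ok": True}
```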
@@ -184,8 +140,6 @@ class FunctionCallStrategy(AgentPattern):
chunks: Union[Generator[LLMResultChunk, None, None], LLMResult],
llm_usage: dict[str, LLMUsage | None],
start_log: AgentLog,
*,
emit_chunks: bool,
) -> Generator[
LLMResultChunk | AgentLog,
None,
@@ -217,8 +171,7 @@ class FunctionCallStrategy(AgentPattern):
if chunk.delta.finish_reason:
finish_reason = chunk.delta.finish_reason

if emit_chunks:
yield chunk
yield chunk
else:
# Non-streaming response
result: LLMResult = chunks
@@ -233,12 +186,11 @@ class FunctionCallStrategy(AgentPattern):
self._accumulate_usage(llm_usage, result.usage)

# Convert to streaming format
if emit_chunks:
yield LLMResultChunk(
model=result.model,
prompt_messages=result.prompt_messages,
delta=LLMResultChunkDelta(index=0, message=result.message, usage=result.usage),
)
yield LLMResultChunk(
model=result.model,
prompt_messages=result.prompt_messages,
delta=LLMResultChunkDelta(index=0, message=result.message, usage=result.usage),
)
yield self._finish_log(
start_log,
data={
@@ -248,14 +200,6 @@ class FunctionCallStrategy(AgentPattern):
)
return tool_calls, response_content, finish_reason

@staticmethod
def _format_output_text(value: Any) -> str:
if value is None:
return ""
if isinstance(value, str):
return value
return json.dumps(value, ensure_ascii=False)

def _create_assistant_message(
self, content: str, tool_calls: list[tuple[str, str, dict[str, Any]]] | None = None
) -> AssistantPromptMessage:

@@ -4,17 +4,10 @@ from __future__ import annotations

import json
from collections.abc import Generator
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Union

from core.agent.entities import AgentLog, AgentResult, AgentScratchpadUnit, ExecutionContext
from core.agent.output_parser.cot_output_parser import CotAgentOutputParser
from core.agent.output_tools import (
FINAL_OUTPUT_TOOL,
FINAL_STRUCTURED_OUTPUT_TOOL,
ILLEGAL_OUTPUT_TOOL,
OUTPUT_TEXT_TOOL,
OUTPUT_TOOL_NAME_SET,
)
from core.file import File
from core.model_manager import ModelInstance
from core.model_runtime.entities import (
@@ -59,7 +52,8 @@ class ReActStrategy(AgentPattern):
self.instruction = instruction

def run(
self, prompt_messages: list[PromptMessage], model_parameters: dict[str, Any], stop: list[str]
self, prompt_messages: list[PromptMessage], model_parameters: dict[str, Any], stop: list[str] = [],
stream: bool = True,
) -> Generator[LLMResultChunk | AgentLog, None, AgentResult]:
"""Execute the ReAct agent strategy."""
# Initialize tracking
@@ -71,19 +65,6 @@ class ReActStrategy(AgentPattern):
output_files: list[File] = []  # Track files produced by tools
final_text: str = ""
finish_reason: str | None = None
tool_instance_names = {tool.entity.identity.name for tool in self.tools}
available_output_tool_names = {
tool_name
for tool_name in tool_instance_names
if tool_name in OUTPUT_TOOL_NAME_SET and tool_name != ILLEGAL_OUTPUT_TOOL
}
if FINAL_STRUCTURED_OUTPUT_TOOL in available_output_tool_names:
terminal_tool_name = FINAL_STRUCTURED_OUTPUT_TOOL
elif FINAL_OUTPUT_TOOL in available_output_tool_names:
terminal_tool_name = FINAL_OUTPUT_TOOL
else:
raise ValueError("No terminal output tool configured")
allow_illegal_output = ILLEGAL_OUTPUT_TOOL in tool_instance_names

# Add "Observation" to stop sequences
if "Observation" not in stop:
@@ -100,15 +81,10 @@ class ReActStrategy(AgentPattern):
)
yield round_log

# Build prompt with tool restrictions on last iteration
if iteration_step == max_iterations:
tools_for_prompt = [
tool for tool in self.tools if tool.entity.identity.name in available_output_tool_names
]
else:
tools_for_prompt = [tool for tool in self.tools if tool.entity.identity.name != ILLEGAL_OUTPUT_TOOL]
# Build prompt with/without tools based on iteration
include_tools = iteration_step < max_iterations
current_messages = self._build_prompt_with_react_format(
prompt_messages, agent_scratchpad, tools_for_prompt, self.instruction
prompt_messages, agent_scratchpad, include_tools, self.instruction
)

model_log = self._create_log(
@@ -130,18 +106,18 @@ class ReActStrategy(AgentPattern):
messages_to_use = current_messages

# Invoke model
chunks = self.model_instance.invoke_llm(
chunks: Union[Generator[LLMResultChunk, None, None], LLMResult] = self.model_instance.invoke_llm(
prompt_messages=messages_to_use,
model_parameters=model_parameters,
stop=stop,
stream=False,
stream=stream,
user=self.context.user_id or "",
callbacks=[],
)

# Process response
scratchpad, chunk_finish_reason = yield from self._handle_chunks(
chunks, round_usage, model_log, current_messages, emit_chunks=False
chunks, round_usage, model_log, current_messages
)
agent_scratchpad.append(scratchpad)

@@ -155,44 +131,28 @@ class ReActStrategy(AgentPattern):
finish_reason = chunk_finish_reason

# Check if we have an action to execute
if scratchpad.action is None:
if not allow_illegal_output:
raise ValueError("Model did not call any tools")
illegal_action = AgentScratchpadUnit.Action(
action_name=ILLEGAL_OUTPUT_TOOL,
action_input={"raw": scratchpad.thought or ""},
)
scratchpad.action = illegal_action
scratchpad.action_str = illegal_action.model_dump_json()
if scratchpad.action and scratchpad.action.action_name.lower() != "final answer":
react_state = True
observation, tool_files = yield from self._handle_tool_call(illegal_action, current_messages, round_log)
scratchpad.observation = observation
output_files.extend(tool_files)
else:
action_name = scratchpad.action.action_name
if action_name == OUTPUT_TEXT_TOOL and isinstance(scratchpad.action.action_input, dict):
pass  # output_text_payload = scratchpad.action.action_input.get("text")
elif action_name == FINAL_STRUCTURED_OUTPUT_TOOL and isinstance(scratchpad.action.action_input, dict):
data = scratchpad.action.action_input.get("data")
if isinstance(data, dict):
pass  # structured_output_payload = data
elif action_name == FINAL_OUTPUT_TOOL:
if isinstance(scratchpad.action.action_input, dict):
final_text = self._format_output_text(scratchpad.action.action_input.get("text"))
else:
final_text = self._format_output_text(scratchpad.action.action_input)

# Execute tool
observation, tool_files = yield from self._handle_tool_call(
scratchpad.action, current_messages, round_log
)
scratchpad.observation = observation
# Track files produced by tools
output_files.extend(tool_files)

if action_name == terminal_tool_name:
pass  # terminal_output_seen = True
react_state = False
else:
react_state = True
# Add observation to scratchpad for display
yield self._create_text_chunk(f"\nObservation: {observation}\n", current_messages)
else:
# Extract final answer
if scratchpad.action and scratchpad.action.action_input:
final_answer = scratchpad.action.action_input
if isinstance(final_answer, dict):
final_answer = json.dumps(final_answer, ensure_ascii=False)
final_text = str(final_answer)
elif scratchpad.thought:
# If no action but we have thought, use thought as final answer
final_text = scratchpad.thought

yield self._finish_log(
round_log,
@@ -208,22 +168,17 @@ class ReActStrategy(AgentPattern):

# Return final result

output_payload: str | dict

# TODO
from core.agent.entities import AgentResult

return AgentResult(
output=output_payload,
files=output_files,
usage=total_usage.get("usage"),
finish_reason=finish_reason,
text=final_text, files=output_files, usage=total_usage.get("usage"), finish_reason=finish_reason
)

def _build_prompt_with_react_format(
self,
original_messages: list[PromptMessage],
agent_scratchpad: list[AgentScratchpadUnit],
tools: list[Tool] | None,
include_tools: bool = True,
instruction: str = "",
) -> list[PromptMessage]:
"""Build prompt messages with ReAct format."""
@@ -240,13 +195,9 @@ class ReActStrategy(AgentPattern):
# Format tools
tools_str = ""
tool_names = []
if tools:
if include_tools and self.tools:
# Convert tools to prompt message tools format
prompt_tools = [
tool.to_prompt_message_tool()
for tool in tools
if tool.entity.identity.name != ILLEGAL_OUTPUT_TOOL
]
prompt_tools = [tool.to_prompt_message_tool() for tool in self.tools]
tool_names = [tool.name for tool in prompt_tools]

# Format tools as JSON for comprehensive information
@@ -258,19 +209,12 @@ class ReActStrategy(AgentPattern):
tools_str = "No tools available"
tool_names_str = ""

final_tool_name = FINAL_OUTPUT_TOOL
if FINAL_STRUCTURED_OUTPUT_TOOL in tool_names:
final_tool_name = FINAL_STRUCTURED_OUTPUT_TOOL
if final_tool_name not in tool_names:
raise ValueError("No terminal output tool available for prompt")

# Replace placeholders in the existing system prompt
updated_content = msg.content
assert isinstance(updated_content, str)
updated_content = updated_content.replace("{{instruction}}", instruction)
updated_content = updated_content.replace("{{tools}}", tools_str)
updated_content = updated_content.replace("{{tool_names}}", tool_names_str)
updated_content = updated_content.replace("{{final_tool_name}}", final_tool_name)

# Create new SystemPromptMessage with updated content
messages[i] = SystemPromptMessage(content=updated_content)
@@ -302,12 +246,10 @@ class ReActStrategy(AgentPattern):

def _handle_chunks(
self,
chunks: LLMResult,
chunks: Union[Generator[LLMResultChunk, None, None], LLMResult],
llm_usage: dict[str, Any],
model_log: AgentLog,
current_messages: list[PromptMessage],
*,
emit_chunks: bool,
) -> Generator[
LLMResultChunk | AgentLog,
None,
@@ -319,20 +261,25 @@ class ReActStrategy(AgentPattern):
"""
usage_dict: dict[str, Any] = {}

def result_to_chunks() -> Generator[LLMResultChunk, None, None]:
yield LLMResultChunk(
model=chunks.model,
prompt_messages=chunks.prompt_messages,
delta=LLMResultChunkDelta(
index=0,
message=chunks.message,
usage=chunks.usage,
finish_reason=None,  # LLMResult doesn't have finish_reason, only streaming chunks do
),
system_fingerprint=chunks.system_fingerprint or "",
)
# Convert non-streaming to streaming format if needed
if isinstance(chunks, LLMResult):
# Create a generator from the LLMResult
def result_to_chunks() -> Generator[LLMResultChunk, None, None]:
yield LLMResultChunk(
model=chunks.model,
prompt_messages=chunks.prompt_messages,
delta=LLMResultChunkDelta(
index=0,
message=chunks.message,
usage=chunks.usage,
finish_reason=None,  # LLMResult doesn't have finish_reason, only streaming chunks do
),
system_fingerprint=chunks.system_fingerprint or "",
)

streaming_chunks = result_to_chunks()
streaming_chunks = result_to_chunks()
else:
streaming_chunks = chunks

react_chunks = CotAgentOutputParser.handle_react_stream_output(streaming_chunks, usage_dict)

@@ -356,18 +303,14 @@ class ReActStrategy(AgentPattern):
scratchpad.action_str = action_str
scratchpad.action = chunk

if emit_chunks:
yield self._create_text_chunk(json.dumps(chunk.model_dump()), current_messages)
elif isinstance(chunk, str):
yield self._create_text_chunk(json.dumps(chunk.model_dump()), current_messages)
else:
# Text chunk
chunk_text = str(chunk)
scratchpad.agent_response = (scratchpad.agent_response or "") + chunk_text
scratchpad.thought = (scratchpad.thought or "") + chunk_text

if emit_chunks:
yield self._create_text_chunk(chunk_text, current_messages)
else:
raise ValueError(f"Unexpected chunk type: {type(chunk)}")
yield self._create_text_chunk(chunk_text, current_messages)

# Update usage
if usage_dict.get("usage"):
@@ -391,14 +334,6 @@ class ReActStrategy(AgentPattern):

return scratchpad, finish_reason

@staticmethod
def _format_output_text(value: Any) -> str:
if value is None:
return ""
if isinstance(value, str):
return value
return json.dumps(value, ensure_ascii=False)

def _handle_tool_call(
self,
action: AgentScratchpadUnit.Action,

@@ -2,17 +2,13 @@

from __future__ import annotations

from collections.abc import Mapping
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING

from core.agent.entities import AgentEntity, ExecutionContext
from core.file.models import File
from core.model_manager import ModelInstance
from core.model_runtime.entities.model_entities import ModelFeature

from ...app.entities.app_invoke_entities import InvokeFrom
from ...tools.entities.tool_entities import ToolInvokeFrom
from ..output_tools import build_agent_output_tools
from .base import AgentPattern, ToolInvokeHook
from .function_call import FunctionCallStrategy
from .react import ReActStrategy
@@ -29,10 +25,6 @@ class StrategyFactory:

@staticmethod
def create_strategy(
*,
tenant_id: str,
invoke_from: InvokeFrom,
tool_invoke_from: ToolInvokeFrom,
model_features: list[ModelFeature],
model_instance: ModelInstance,
context: ExecutionContext,
@@ -43,16 +35,11 @@ class StrategyFactory:
agent_strategy: AgentEntity.Strategy | None = None,
tool_invoke_hook: ToolInvokeHook | None = None,
instruction: str = "",
structured_output_schema: Mapping[str, Any] | None = None,
) -> AgentPattern:
"""
Create an appropriate strategy based on model features.

Args:
tenant_id:
invoke_from:
tool_invoke_from:
structured_output_schema:
model_features: List of model features/capabilities
model_instance: Model instance to use
context: Execution context containing trace/audit information
@@ -67,14 +54,6 @@ class StrategyFactory:
Returns:
AgentStrategy instance
"""
output_tools = build_agent_output_tools(
tenant_id=tenant_id,
invoke_from=invoke_from,
tool_invoke_from=tool_invoke_from,
structured_output_schema=structured_output_schema,
)

tools.extend(output_tools)

# If explicit strategy is provided and it's Function Calling, try to use it if supported
if agent_strategy == AgentEntity.Strategy.FUNCTION_CALLING:

@@ -7,7 +7,7 @@ You have access to the following tools:
{{tools}}

Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
Valid "action" values: {{tool_names}}. You must call "{{final_tool_name}}" to finish.
Valid "action" values: "Final Answer" or {{tool_names}}

Provide only ONE action per $JSON_BLOB, as shown:

@@ -32,14 +32,12 @@ Thought: I know what to respond
Action:
```
{
"action": "{{final_tool_name}}",
"action_input": {
"text": "Final response to human"
}
"action": "Final Answer",
"action_input": "Final response to human"
}
```

Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Do not respond with plain text. Format is Action:```$JSON_BLOB```then Observation:.
Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary. Respond directly if appropriate. Format is Action:```$JSON_BLOB```then Observation:.
{{historic_messages}}
Question: {{query}}
{{agent_scratchpad}}
@@ -58,7 +56,7 @@ You have access to the following tools:
{{tools}}

Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
Valid "action" values: {{tool_names}}. You must call "{{final_tool_name}}" to finish.
Valid "action" values: "Final Answer" or {{tool_names}}

Provide only ONE action per $JSON_BLOB, as shown:

@@ -83,14 +81,12 @@ Thought: I know what to respond
Action:
```
{
"action": "{{final_tool_name}}",
"action_input": {
"text": "Final response to human"
}
"action": "Final Answer",
"action_input": "Final response to human"
}
```

Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Do not respond with plain text. Format is Action:```$JSON_BLOB```then Observation:.
Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary. Respond directly if appropriate. Format is Action:```$JSON_BLOB```then Observation:.
"""  # noqa: E501



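Both prompt variants above are filled by plain string replacement (see `_build_prompt_with_react_format` earlier in this diff). A tiny illustration of that substitution step, with made-up values:

```python
# Placeholder names come from the templates above; the values are illustrative.
template = (
    'Valid "action" values: "Final Answer" or {{tool_names}}\n'
    "{{instruction}}\n"
    "Question: {{query}}"
)

filled = (
    template.replace("{{tool_names}}", '"search", "calculator"')
    .replace("{{instruction}}", "Answer concisely.")
    .replace("{{query}}", "What is 2 + 2?")
)
print(filled)
```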
@@ -477,6 +477,7 @@ class LLMGenerator:
prompt_messages=complete_messages,
output_model=CodeNodeStructuredOutput,
model_parameters=model_parameters,
tenant_id=tenant_id,
)

return {
@@ -557,14 +558,10 @@ class LLMGenerator:

completion_params = model_config.get("completion_params", {}) if model_config else {}
try:
response = invoke_llm_with_pydantic_model(
provider=model_instance.provider,
model_schema=model_schema,
model_instance=model_instance,
prompt_messages=prompt_messages,
output_model=SuggestedQuestionsOutput,
model_parameters=completion_params,
)
response = invoke_llm_with_pydantic_model(provider=model_instance.provider, model_schema=model_schema,
model_instance=model_instance, prompt_messages=prompt_messages,
output_model=SuggestedQuestionsOutput,
model_parameters=completion_params, tenant_id=tenant_id)

return {"questions": response.questions, "error": ""}


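For context on the `invoke_llm_with_pydantic_model` call being reshaped above: the helper derives a JSON schema from a Pydantic model and validates the LLM's JSON against it. A self-contained sketch of that round-trip, with the LLM call stubbed out (`fake_llm_json` is a placeholder, not a Dify API):

```python
import json

from pydantic import BaseModel


class SuggestedQuestionsOutput(BaseModel):
    # Mirrors the output model used in the diff above.
    questions: list[str]


# The schema that would accompany the structured-output request.
schema = SuggestedQuestionsOutput.model_json_schema()

# Stand-in for the model response; the real call goes through ModelInstance.
fake_llm_json = '{"questions": ["What is Dify?", "How do agents work?"]}'

response = SuggestedQuestionsOutput.model_validate(json.loads(fake_llm_json))
print(schema["required"], response.questions)
```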
@@ -1,190 +1,188 @@
from collections.abc import Callable, Mapping, Sequence
from typing import Any, cast
"""
File reference detection and conversion for structured output.

This module provides utilities to:
1. Detect file reference fields in JSON Schema (format: "dify-file-ref")
2. Convert file ID strings to File objects after LLM returns
"""

import uuid
from collections.abc import Mapping
from typing import Any

from core.file import File
from core.variables.segments import ArrayFileSegment, FileSegment
from factories.file_factory import build_from_mapping

FILE_PATH_SCHEMA_TYPE = "file"
FILE_PATH_SCHEMA_FORMATS = {"file", "file-ref", "dify-file-ref"}
FILE_PATH_DESCRIPTION_SUFFIX = "Sandbox file path (relative paths supported)."
FILE_REF_FORMAT = "dify-file-ref"


def is_file_path_property(schema: Mapping[str, Any]) -> bool:
if schema.get("type") == FILE_PATH_SCHEMA_TYPE:
return True
format_value = schema.get("format")
if not isinstance(format_value, str):
return False
normalized_format = format_value.lower().replace("_", "-")
return normalized_format in FILE_PATH_SCHEMA_FORMATS
def is_file_ref_property(schema: dict) -> bool:
"""Check if a schema property is a file reference."""
return schema.get("type") == "string" and schema.get("format") == FILE_REF_FORMAT


def detect_file_path_fields(schema: Mapping[str, Any], path: str = "") -> list[str]:
file_path_fields: list[str] = []
def detect_file_ref_fields(schema: Mapping[str, Any], path: str = "") -> list[str]:
"""
Recursively detect file reference fields in schema.

Args:
schema: JSON Schema to analyze
path: Current path in the schema (used for recursion)

Returns:
List of JSON paths containing file refs, e.g., ["image_id", "files[*]"]
"""
file_ref_paths: list[str] = []
schema_type = schema.get("type")

if schema_type == "object":
properties = schema.get("properties")
if isinstance(properties, Mapping):
properties_mapping = cast(Mapping[str, Any], properties)
for prop_name, prop_schema in properties_mapping.items():
if not isinstance(prop_schema, Mapping):
continue
prop_schema_mapping = cast(Mapping[str, Any], prop_schema)
current_path = f"{path}.{prop_name}" if path else prop_name
for prop_name, prop_schema in schema.get("properties", {}).items():
current_path = f"{path}.{prop_name}" if path else prop_name

if is_file_path_property(prop_schema_mapping):
file_path_fields.append(current_path)
else:
file_path_fields.extend(detect_file_path_fields(prop_schema_mapping, current_path))
if is_file_ref_property(prop_schema):
file_ref_paths.append(current_path)
elif isinstance(prop_schema, dict):
file_ref_paths.extend(detect_file_ref_fields(prop_schema, current_path))

elif schema_type == "array":
items_schema = schema.get("items")
if not isinstance(items_schema, Mapping):
return file_path_fields
items_schema_mapping = cast(Mapping[str, Any], items_schema)
items_schema = schema.get("items", {})
array_path = f"{path}[*]" if path else "[*]"

if is_file_path_property(items_schema_mapping):
file_path_fields.append(array_path)
else:
file_path_fields.extend(detect_file_path_fields(items_schema_mapping, array_path))
if is_file_ref_property(items_schema):
file_ref_paths.append(array_path)
elif isinstance(items_schema, dict):
file_ref_paths.extend(detect_file_ref_fields(items_schema, array_path))

return file_path_fields
return file_ref_paths


def adapt_schema_for_sandbox_file_paths(schema: Mapping[str, Any]) -> tuple[dict[str, Any], list[str]]:
result = _deep_copy_value(schema)
if not isinstance(result, dict):
raise ValueError("structured_output_schema must be a JSON object")
result_dict = cast(dict[str, Any], result)

file_path_fields: list[str] = []
_adapt_schema_in_place(result_dict, path="", file_path_fields=file_path_fields)
return result_dict, file_path_fields


def convert_sandbox_file_paths_in_output(
def convert_file_refs_in_output(
output: Mapping[str, Any],
file_path_fields: Sequence[str],
file_resolver: Callable[[str], File],
) -> tuple[dict[str, Any], list[File]]:
if not file_path_fields:
return dict(output), []
json_schema: Mapping[str, Any],
tenant_id: str,
) -> dict[str, Any]:
"""
Convert file ID strings to File objects based on schema.

result = _deep_copy_value(output)
if not isinstance(result, dict):
raise ValueError("Structured output must be a JSON object")
result_dict = cast(dict[str, Any], result)
Args:
output: The structured_output from LLM result
json_schema: The original JSON schema (to detect file ref fields)
tenant_id: Tenant ID for file lookup

files: list[File] = []
for path in file_path_fields:
_convert_path_in_place(result_dict, path.split("."), file_resolver, files)
Returns:
Output with file references converted to File objects
"""
file_ref_paths = detect_file_ref_fields(json_schema)
if not file_ref_paths:
return dict(output)

return result_dict, files
result = _deep_copy_dict(output)

for path in file_ref_paths:
_convert_path_in_place(result, path.split("."), tenant_id)

return result


def _adapt_schema_in_place(schema: dict[str, Any], path: str, file_path_fields: list[str]) -> None:
schema_type = schema.get("type")

if schema_type == "object":
properties = schema.get("properties")
if isinstance(properties, Mapping):
properties_mapping = cast(Mapping[str, Any], properties)
for prop_name, prop_schema in properties_mapping.items():
if not isinstance(prop_schema, dict):
continue
prop_schema_dict = cast(dict[str, Any], prop_schema)
current_path = f"{path}.{prop_name}" if path else prop_name

if is_file_path_property(prop_schema_dict):
_normalize_file_path_schema(prop_schema_dict)
file_path_fields.append(current_path)
else:
_adapt_schema_in_place(prop_schema_dict, current_path, file_path_fields)

elif schema_type == "array":
items_schema = schema.get("items")
if not isinstance(items_schema, dict):
return
items_schema_dict = cast(dict[str, Any], items_schema)
array_path = f"{path}[*]" if path else "[*]"

if is_file_path_property(items_schema_dict):
_normalize_file_path_schema(items_schema_dict)
file_path_fields.append(array_path)
def _deep_copy_dict(obj: Mapping[str, Any]) -> dict[str, Any]:
"""Deep copy a mapping to a mutable dict."""
result: dict[str, Any] = {}
for key, value in obj.items():
if isinstance(value, Mapping):
result[key] = _deep_copy_dict(value)
elif isinstance(value, list):
result[key] = [_deep_copy_dict(item) if isinstance(item, Mapping) else item for item in value]
else:
_adapt_schema_in_place(items_schema_dict, array_path, file_path_fields)
result[key] = value
return result


def _normalize_file_path_schema(schema: dict[str, Any]) -> None:
schema["type"] = "string"
schema.pop("format", None)
description = schema.get("description", "")
if description:
schema["description"] = f"{description}\n{FILE_PATH_DESCRIPTION_SUFFIX}"
else:
schema["description"] = FILE_PATH_DESCRIPTION_SUFFIX


def _deep_copy_value(value: Any) -> Any:
if isinstance(value, Mapping):
mapping = cast(Mapping[str, Any], value)
return {key: _deep_copy_value(item) for key, item in mapping.items()}
if isinstance(value, list):
list_value = cast(list[Any], value)
return [_deep_copy_value(item) for item in list_value]
return value


def _convert_path_in_place(
obj: dict[str, Any],
path_parts: list[str],
file_resolver: Callable[[str], File],
files: list[File],
) -> None:
def _convert_path_in_place(obj: dict, path_parts: list[str], tenant_id: str) -> None:
"""Convert file refs at the given path in place, wrapping in Segment types."""
if not path_parts:
return

current = path_parts[0]
remaining = path_parts[1:]

# Handle array notation like "files[*]"
if current.endswith("[*]"):
key = current[:-3] if current != "[*]" else ""
target_value = obj.get(key) if key else obj
key = current[:-3] if current != "[*]" else None
target = obj.get(key) if key else obj

if isinstance(target_value, list):
target_list = cast(list[Any], target_value)
if isinstance(target, list):
if remaining:
for item in target_list:
# Nested array with remaining path - recurse into each item
for item in target:
if isinstance(item, dict):
item_dict = cast(dict[str, Any], item)
_convert_path_in_place(item_dict, remaining, file_resolver, files)
_convert_path_in_place(item, remaining, tenant_id)
else:
resolved_files: list[File] = []
for item in target_list:
if not isinstance(item, str):
raise ValueError("File path must be a string")
file = file_resolver(item)
files.append(file)
resolved_files.append(file)
# Array of file IDs - convert all and wrap in ArrayFileSegment
files: list[File] = []
for item in target:
file = _convert_file_id(item, tenant_id)
if file is not None:
files.append(file)
# Replace the array with ArrayFileSegment
if key:
obj[key] = ArrayFileSegment(value=resolved_files)
obj[key] = ArrayFileSegment(value=files)
return

if not remaining:
if current not in obj:
return
value = obj[current]
if value is None:
obj[current] = None
return
if not isinstance(value, str):
raise ValueError("File path must be a string")
file = file_resolver(value)
files.append(file)
obj[current] = FileSegment(value=file)
return
# Leaf node - convert the value and wrap in FileSegment
if current in obj:
file = _convert_file_id(obj[current], tenant_id)
if file is not None:
obj[current] = FileSegment(value=file)
else:
obj[current] = None
else:
# Recurse into nested object
if current in obj and isinstance(obj[current], dict):
_convert_path_in_place(obj[current], remaining, tenant_id)

if current in obj and isinstance(obj[current], dict):
_convert_path_in_place(obj[current], remaining, file_resolver, files)

def _convert_file_id(file_id: Any, tenant_id: str) -> File | None:
"""
Convert a file ID string to a File object.

Tries multiple file sources in order:
1. ToolFile (files generated by tools/workflows)
2. UploadFile (files uploaded by users)
"""
if not isinstance(file_id, str):
return None

# Validate UUID format
try:
uuid.UUID(file_id)
except ValueError:
return None

# Try ToolFile first (files generated by tools/workflows)
try:
return build_from_mapping(
mapping={
"transfer_method": "tool_file",
"tool_file_id": file_id,
},
tenant_id=tenant_id,
)
except ValueError:
pass

# Try UploadFile (files uploaded by users)
try:
return build_from_mapping(
mapping={
"transfer_method": "local_file",
"upload_file_id": file_id,
},
tenant_id=tenant_id,
)
except ValueError:
pass

# File not found in any source
return None

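To show what the restored detector returns, here is a standalone copy of `detect_file_ref_fields` (logic as in the diff above) applied to a small sample schema:

```python
from collections.abc import Mapping
from typing import Any

FILE_REF_FORMAT = "dify-file-ref"


def is_file_ref_property(schema: dict) -> bool:
    return schema.get("type") == "string" and schema.get("format") == FILE_REF_FORMAT


def detect_file_ref_fields(schema: Mapping[str, Any], path: str = "") -> list[str]:
    # Walk objects and arrays, collecting JSON paths of file-ref leaves.
    file_ref_paths: list[str] = []
    schema_type = schema.get("type")
    if schema_type == "object":
        for prop_name, prop_schema in schema.get("properties", {}).items():
            current_path = f"{path}.{prop_name}" if path else prop_name
            if is_file_ref_property(prop_schema):
                file_ref_paths.append(current_path)
            elif isinstance(prop_schema, dict):
                file_ref_paths.extend(detect_file_ref_fields(prop_schema, current_path))
    elif schema_type == "array":
        items_schema = schema.get("items", {})
        array_path = f"{path}[*]" if path else "[*]"
        if is_file_ref_property(items_schema):
            file_ref_paths.append(array_path)
        elif isinstance(items_schema, dict):
            file_ref_paths.extend(detect_file_ref_fields(items_schema, array_path))
    return file_ref_paths


schema = {
    "type": "object",
    "properties": {
        "image_id": {"type": "string", "format": "dify-file-ref"},
        "files": {"type": "array", "items": {"type": "string", "format": "dify-file-ref"}},
    },
}
print(detect_file_ref_fields(schema))  # ['image_id', 'files[*]']
```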
@@ -8,7 +8,7 @@ import json_repair
from pydantic import BaseModel, TypeAdapter, ValidationError

from core.llm_generator.output_parser.errors import OutputParserError
from core.llm_generator.output_parser.file_ref import detect_file_path_fields
from core.llm_generator.output_parser.file_ref import convert_file_refs_in_output
from core.llm_generator.prompts import STRUCTURED_OUTPUT_PROMPT
from core.model_manager import ModelInstance
from core.model_runtime.callbacks.base_callback import Callback
@@ -55,11 +55,12 @@ def invoke_llm_with_structured_output(
model_instance: ModelInstance,
prompt_messages: Sequence[PromptMessage],
json_schema: Mapping[str, Any],
model_parameters: Mapping[str, Any] | None = None,
model_parameters: Mapping | None = None,
tools: Sequence[PromptMessageTool] | None = None,
stop: list[str] | None = None,
user: str | None = None,
callbacks: list[Callback] | None = None,
tenant_id: str | None = None,
) -> LLMResultWithStructuredOutput:
"""
Invoke large language model with structured output.
@@ -77,12 +78,14 @@ def invoke_llm_with_structured_output(
:param stop: stop words
:param user: unique user id
:param callbacks: callbacks
:return: response with structured output
:param tenant_id: tenant ID for file reference conversion. When provided and
json_schema contains file reference fields (format: "dify-file-ref"),
file IDs in the output will be automatically converted to File objects.
:return: full response or stream response chunk generator result
"""
model_parameters_with_json_schema: dict[str, Any] = dict(model_parameters or {})

if detect_file_path_fields(json_schema):
raise OutputParserError("Structured output file paths are only supported in sandbox mode.")
model_parameters_with_json_schema: dict[str, Any] = {
**(model_parameters or {}),
}

# Determine structured output strategy

@@ -119,6 +122,14 @@ def invoke_llm_with_structured_output(
# Fill missing fields with default values
structured_output = fill_defaults_from_schema(structured_output, json_schema)

# Convert file references if tenant_id is provided
if tenant_id is not None:
structured_output = convert_file_refs_in_output(
output=structured_output,
json_schema=json_schema,
tenant_id=tenant_id,
)

return LLMResultWithStructuredOutput(
structured_output=structured_output,
model=llm_result.model,
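The conversion contract above can be pictured with a schema fragment; this sketch is not part of the commit, and the field name is illustrative:

json_schema = {
    "type": "object",
    "properties": {
        "image": {
            "type": "string",
            "format": "dify-file-ref",  # opts this field into file-reference conversion
            "description": "File ID (UUID) of the selected image",
        },
    },
    "required": ["image"],
}
# With tenant_id passed to invoke_llm_with_structured_output, the UUID the
# model returns for "image" is resolved via convert_file_refs_in_output into
# a File object; without tenant_id the raw string is left untouched.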
@@ -136,11 +147,12 @@ def invoke_llm_with_pydantic_model(
model_instance: ModelInstance,
prompt_messages: Sequence[PromptMessage],
output_model: type[T],
model_parameters: Mapping[str, Any] | None = None,
model_parameters: Mapping | None = None,
tools: Sequence[PromptMessageTool] | None = None,
stop: list[str] | None = None,
user: str | None = None,
callbacks: list[Callback] | None = None,
tenant_id: str | None = None,
) -> T:
"""
Invoke large language model with a Pydantic output model.
@@ -148,8 +160,11 @@ def invoke_llm_with_pydantic_model(
This helper generates a JSON schema from the Pydantic model, invokes the
structured-output LLM path, and validates the result.

The helper performs a non-streaming invocation and returns the validated
Pydantic model directly.
The stream parameter controls the underlying LLM invocation mode:
- stream=True (default): Uses streaming LLM call, consumes the generator internally
- stream=False: Uses non-streaming LLM call

In both cases, the function returns the validated Pydantic model directly.
"""
json_schema = _schema_from_pydantic(output_model)
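A hedged usage sketch for this helper; the Verdict model is illustrative, and any required parameters beyond those visible in the signature above are omitted here:

from pydantic import BaseModel

class Verdict(BaseModel):
    label: str
    confidence: float

verdict = invoke_llm_with_pydantic_model(
    model_instance=model_instance,    # assumed ModelInstance
    prompt_messages=prompt_messages,  # assumed prompt sequence
    output_model=Verdict,
)
# The helper derives a JSON schema from Verdict, runs the structured-output
# path above, and validates the parsed result back into a Verdict instance.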
@@ -164,6 +179,7 @@ def invoke_llm_with_pydantic_model(
stop=stop,
user=user,
callbacks=callbacks,
tenant_id=tenant_id,
)

structured_output = result.structured_output
@@ -220,27 +236,25 @@ def _extract_structured_output(llm_result: LLMResult) -> Mapping[str, Any]:
return _parse_structured_output(content)


def _parse_tool_call_arguments(arguments: str) -> dict[str, Any]:
def _parse_tool_call_arguments(arguments: str) -> Mapping[str, Any]:
"""Parse JSON from tool call arguments."""
if not arguments:
raise OutputParserError("Tool call arguments is empty")

try:
parsed_any = json.loads(arguments)
if not isinstance(parsed_any, dict):
parsed = json.loads(arguments)
if not isinstance(parsed, dict):
raise OutputParserError(f"Tool call arguments is not a dict: {arguments}")
parsed = cast(dict[str, Any], parsed_any)
return parsed
except json.JSONDecodeError:
# Try to repair malformed JSON
repaired_any = json_repair.loads(arguments)
if not isinstance(repaired_any, dict):
repaired = json_repair.loads(arguments)
if not isinstance(repaired, dict):
raise OutputParserError(f"Failed to parse tool call arguments: {arguments}")
repaired: dict[str, Any] = repaired_any
return repaired
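Illustrative inputs for the parser above, assuming json_repair's usual handling of truncated input:

_parse_tool_call_arguments('{"query": "dify"}')  # -> {"query": "dify"}
_parse_tool_call_arguments('{"query": "dify"')   # repaired -> {"query": "dify"}
_parse_tool_call_arguments('[1, 2]')             # raises OutputParserError (not a dict)
_parse_tool_call_arguments('')                   # raises OutputParserError (empty)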
def get_default_value_for_type(type_name: str | list[str] | None) -> Any:
def _get_default_value_for_type(type_name: str | list[str] | None) -> Any:
"""Get default empty value for a JSON schema type."""
# Handle array of types (e.g., ["string", "null"])
if isinstance(type_name, list):
@@ -297,7 +311,7 @@ def fill_defaults_from_schema(
# Create empty object and recursively fill its required fields
result[prop_name] = fill_defaults_from_schema({}, prop_schema)
else:
result[prop_name] = get_default_value_for_type(prop_type)
result[prop_name] = _get_default_value_for_type(prop_type)
elif isinstance(result[prop_name], dict) and prop_type == "object" and "properties" in prop_schema:
# Field exists and is an object, recursively fill nested required fields
result[prop_name] = fill_defaults_from_schema(result[prop_name], prop_schema)
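A small sketch of the default-filling pass; the exact empty values come from _get_default_value_for_type, so the result shown is an assumption consistent with the code above:

schema = {
    "type": "object",
    "required": ["name", "tags"],
    "properties": {
        "name": {"type": "string"},
        "tags": {"type": "array"},
    },
}
filled = fill_defaults_from_schema({"name": "report"}, schema)
# Expected shape: {"name": "report", "tags": []} - the missing required field
# is filled with an empty value matching its declared type.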
@@ -308,10 +322,10 @@
def _handle_native_json_schema(
provider: str,
model_schema: AIModelEntity,
structured_output_schema: Mapping[str, Any],
model_parameters: dict[str, Any],
structured_output_schema: Mapping,
model_parameters: dict,
rules: list[ParameterRule],
) -> dict[str, Any]:
):
"""
Handle structured output for models with native JSON schema support.

@@ -333,7 +347,7 @@ def _handle_native_json_schema(
return model_parameters


def _set_response_format(model_parameters: dict[str, Any], rules: list[ParameterRule]) -> None:
def _set_response_format(model_parameters: dict, rules: list):
"""
Set the appropriate response format parameter based on model rules.

@@ -349,7 +363,7 @@ def _set_response_format(model_parameters: dict[str, Any], rules: list[Parameter


def _handle_prompt_based_schema(
prompt_messages: Sequence[PromptMessage], structured_output_schema: Mapping[str, Any]
prompt_messages: Sequence[PromptMessage], structured_output_schema: Mapping
) -> list[PromptMessage]:
"""
Handle structured output for models without native JSON schema support.
@@ -386,27 +400,28 @@ def _handle_prompt_based_schema(
return updated_prompt


def _parse_structured_output(result_text: str) -> dict[str, Any]:
def _parse_structured_output(result_text: str) -> Mapping[str, Any]:
structured_output: Mapping[str, Any] = {}
parsed: Mapping[str, Any] = {}
try:
parsed = TypeAdapter(dict[str, Any]).validate_json(result_text)
return parsed
parsed = TypeAdapter(Mapping).validate_json(result_text)
if not isinstance(parsed, dict):
raise OutputParserError(f"Failed to parse structured output: {result_text}")
structured_output = parsed
except ValidationError:
# if the result_text is not valid json, try to repair it
temp_parsed: Any = json_repair.loads(result_text)
if isinstance(temp_parsed, list):
temp_parsed_list = cast(list[Any], temp_parsed)
dict_items: list[dict[str, Any]] = []
for item in temp_parsed_list:
if isinstance(item, dict):
dict_items.append(cast(dict[str, Any], item))
temp_parsed = dict_items[0] if dict_items else {}
temp_parsed = json_repair.loads(result_text)
if not isinstance(temp_parsed, dict):
raise OutputParserError(f"Failed to parse structured output: {result_text}")
temp_parsed_dict = cast(dict[str, Any], temp_parsed)
return temp_parsed_dict
# handle reasoning models like deepseek-r1 that emit a '<think>\n\n</think>\n' prefix
if isinstance(temp_parsed, list):
temp_parsed = next((item for item in temp_parsed if isinstance(item, dict)), {})
else:
raise OutputParserError(f"Failed to parse structured output: {result_text}")
structured_output = cast(dict, temp_parsed)
return structured_output
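Illustrative inputs for the parser above (a sketch; the second case assumes json_repair surfaces the reasoning prefix as a list wrapper, which is what the list branch handles):

_parse_structured_output('{"answer": 42}')  # valid JSON -> {"answer": 42}
_parse_structured_output('<think>\n\n</think>\n{"answer": 42}')
# ValidationError -> json_repair fallback; if the repaired value is a list,
# the first dict element found is used as the structured output.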
def _prepare_schema_for_model(provider: str, model_schema: AIModelEntity, schema: Mapping[str, Any]) -> dict[str, Any]:
def _prepare_schema_for_model(provider: str, model_schema: AIModelEntity, schema: Mapping):
"""
Prepare JSON schema based on model requirements.

@@ -418,49 +433,54 @@ def _prepare_schema_for_model(provider: str, model_schema: AIModelEntity, schema
"""

# Deep copy to avoid modifying the original schema
processed_schema = deepcopy(schema)
processed_schema_dict = dict(processed_schema)
processed_schema = dict(deepcopy(schema))

# Convert boolean types to string types (common requirement)
convert_boolean_to_string(processed_schema_dict)
convert_boolean_to_string(processed_schema)

# Apply model-specific transformations
if SpecialModelType.GEMINI in model_schema.model:
remove_additional_properties(processed_schema_dict)
return processed_schema_dict
if SpecialModelType.OLLAMA in provider:
return processed_schema_dict

# Default format with name field
return {"schema": processed_schema_dict, "name": "llm_response"}
remove_additional_properties(processed_schema)
return processed_schema
elif SpecialModelType.OLLAMA in provider:
return processed_schema
else:
# Default format with name field
return {"schema": processed_schema, "name": "llm_response"}


def remove_additional_properties(schema: dict[str, Any]) -> None:
def remove_additional_properties(schema: dict):
"""
Remove additionalProperties fields from JSON schema.
Used for models like Gemini that don't support this property.

:param schema: JSON schema to modify in-place
"""
if not isinstance(schema, dict):
return

# Remove additionalProperties at current level
schema.pop("additionalProperties", None)

# Process nested structures recursively
for value in schema.values():
if isinstance(value, dict):
remove_additional_properties(cast(dict[str, Any], value))
remove_additional_properties(value)
elif isinstance(value, list):
for item in cast(list[Any], value):
for item in value:
if isinstance(item, dict):
remove_additional_properties(cast(dict[str, Any], item))
remove_additional_properties(item)


def convert_boolean_to_string(schema: dict[str, Any]) -> None:
def convert_boolean_to_string(schema: dict):
"""
Convert boolean type specifications to string in JSON schema.

:param schema: JSON schema to modify in-place
"""
if not isinstance(schema, dict):
return

# Check for boolean type at current level
if schema.get("type") == "boolean":
schema["type"] = "string"
@@ -468,8 +488,8 @@ def convert_boolean_to_string(schema: dict[str, Any]) -> None:
# Process nested dictionaries and lists recursively
for value in schema.values():
if isinstance(value, dict):
convert_boolean_to_string(cast(dict[str, Any], value))
convert_boolean_to_string(value)
elif isinstance(value, list):
for item in cast(list[Any], value):
for item in value:
if isinstance(item, dict):
convert_boolean_to_string(cast(dict[str, Any], item))
convert_boolean_to_string(item)
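A quick sketch of the two in-place transforms above applied to a toy schema:

schema = {
    "type": "object",
    "additionalProperties": False,
    "properties": {"flag": {"type": "boolean"}},
}
convert_boolean_to_string(schema)
remove_additional_properties(schema)
# schema is now:
# {"type": "object", "properties": {"flag": {"type": "string"}}}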
@@ -1,13 +1,11 @@
from collections.abc import Generator
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, TypeVar
from typing import TypeVar, Union

from core.agent.entities import AgentInvokeMessage
from core.tools.entities.tool_entities import ToolInvokeMessage

if TYPE_CHECKING:
pass

MessageType = TypeVar("MessageType", bound=ToolInvokeMessage)
MessageType = TypeVar("MessageType", bound=Union[ToolInvokeMessage, AgentInvokeMessage])


@dataclass
@@ -89,7 +87,7 @@ def merge_blob_chunks(
),
meta=resp.meta,
)
assert isinstance(merged_message, ToolInvokeMessage)
assert isinstance(merged_message, (ToolInvokeMessage, AgentInvokeMessage))
yield merged_message  # type: ignore
# Clean up the buffer
del files[chunk_id]
@@ -14,8 +14,7 @@ from core.skill.entities import ToolAccessPolicy
from core.skill.entities.tool_dependencies import ToolDependencies
from core.tools.signature import sign_tool_file
from core.tools.tool_file_manager import ToolFileManager
from core.virtual_environment.__base.exec import CommandExecutionError
from core.virtual_environment.__base.helpers import execute, pipeline
from core.virtual_environment.__base.helpers import pipeline

from ..bash.dify_cli import DifyCliConfig
from ..entities import DifyCli
@@ -120,6 +119,21 @@ class SandboxBashSession:
return self._bash_tool

def collect_output_files(self, output_dir: str = SANDBOX_OUTPUT_DIR) -> list[File]:
"""
Collect files from sandbox output directory and save them as ToolFiles.

Scans the specified output directory in sandbox, downloads each file,
saves it as a ToolFile, and returns a list of File objects. The File
objects will have valid tool_file_id that can be referenced by subsequent
nodes via structured output.

Args:
output_dir: Directory path in sandbox to scan for output files.
Defaults to "output" (relative to workspace).

Returns:
List of File objects representing the collected files.
"""
vm = self._sandbox.vm
collected_files: list[File] = []

@@ -130,6 +144,8 @@ class SandboxBashSession:
logger.debug("Failed to list sandbox output files in %s: %s", output_dir, exc)
return collected_files

tool_file_manager = ToolFileManager()

for file_state in file_states:
# Skip files that are too large
if file_state.size > MAX_OUTPUT_FILE_SIZE:
@@ -146,14 +162,47 @@ class SandboxBashSession:
file_content = vm.download_file(file_state.path)
file_binary = file_content.getvalue()

# Determine mime type from extension
filename = os.path.basename(file_state.path)
file_obj = self._create_tool_file(filename=filename, file_binary=file_binary)
mime_type, _ = mimetypes.guess_type(filename)
if not mime_type:
mime_type = "application/octet-stream"

# Save as ToolFile
tool_file = tool_file_manager.create_file_by_raw(
user_id=self._user_id,
tenant_id=self._tenant_id,
conversation_id=None,
file_binary=file_binary,
mimetype=mime_type,
filename=filename,
)

# Determine file type from mime type
file_type = _get_file_type_from_mime(mime_type)
extension = os.path.splitext(filename)[1] if "." in filename else ".bin"
url = sign_tool_file(tool_file.id, extension)

# Create File object with tool_file_id as related_id
file_obj = File(
id=tool_file.id,  # Use tool_file_id as the File id for easy reference
tenant_id=self._tenant_id,
type=file_type,
transfer_method=FileTransferMethod.TOOL_FILE,
filename=filename,
extension=extension,
mime_type=mime_type,
size=len(file_binary),
related_id=tool_file.id,
url=url,
storage_key=tool_file.file_key,
)
collected_files.append(file_obj)

logger.info(
"Collected sandbox output file: %s -> tool_file_id=%s",
file_state.path,
file_obj.id,
tool_file.id,
)

except Exception as exc:
@@ -167,85 +216,6 @@ class SandboxBashSession:
)
return collected_files

def download_file(self, path: str) -> File:
path_kind = self._detect_path_kind(path)
if path_kind == "dir":
raise ValueError("Directory outputs are not supported")
if path_kind != "file":
raise ValueError(f"Sandbox file not found: {path}")

file_content = self._sandbox.vm.download_file(path)
file_binary = file_content.getvalue()
if len(file_binary) > MAX_OUTPUT_FILE_SIZE:
raise ValueError(f"Sandbox file exceeds size limit: {path}")

filename = os.path.basename(path) or "file"
return self._create_tool_file(filename=filename, file_binary=file_binary)

def _detect_path_kind(self, path: str) -> str:
script = r"""
import os
import sys

p = sys.argv[1]
if os.path.isdir(p):
print("dir")
raise SystemExit(0)
if os.path.isfile(p):
print("file")
raise SystemExit(0)
print("none")
raise SystemExit(2)
"""
try:
result = execute(
self._sandbox.vm,
[
"sh",
"-c",
'if command -v python3 >/dev/null 2>&1; then py=python3; else py=python; fi; "$py" -c "$0" "$@"',
script,
path,
],
timeout=10,
error_message="Failed to inspect sandbox path",
)
except CommandExecutionError as exc:
raise ValueError(str(exc)) from exc
return result.stdout.decode("utf-8", errors="replace").strip()

def _create_tool_file(self, *, filename: str, file_binary: bytes) -> File:
mime_type, _ = mimetypes.guess_type(filename)
if not mime_type:
mime_type = "application/octet-stream"

tool_file = ToolFileManager().create_file_by_raw(
user_id=self._user_id,
tenant_id=self._tenant_id,
conversation_id=None,
file_binary=file_binary,
mimetype=mime_type,
filename=filename,
)

file_type = _get_file_type_from_mime(mime_type)
extension = os.path.splitext(filename)[1] if "." in filename else ".bin"
url = sign_tool_file(tool_file.id, extension)

return File(
id=tool_file.id,
tenant_id=self._tenant_id,
type=file_type,
transfer_method=FileTransferMethod.TOOL_FILE,
filename=filename,
extension=extension,
mime_type=mime_type,
size=len(file_binary),
related_id=tool_file.id,
url=url,
storage_key=tool_file.file_key,
)


def _get_file_type_from_mime(mime_type: str) -> FileType:
"""Determine FileType from mime type."""
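A hedged usage sketch for the session helpers above; the sandbox/node wiring and the file path are assumed, following the call sites shown elsewhere in this commit:

with SandboxBashSession(sandbox=sandbox, node_id=node_id, tools=deps) as session:
    report = session.download_file("/workspace/output/report.pdf")  # single file, size-checked
    outputs = session.collect_output_files()  # scans the output dir, returns File objects
# Each returned File carries related_id == tool_file.id, so later nodes can
# reference it through structured output.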
@@ -57,7 +57,7 @@ class Tool(ABC):
tool_parameters.update(self.runtime.runtime_parameters)

# try to parse tool parameters into the correct type
tool_parameters = self.transform_tool_parameters_type(tool_parameters)
tool_parameters = self._transform_tool_parameters_type(tool_parameters)

result = self._invoke(
user_id=user_id,
@@ -82,7 +82,7 @@ class Tool(ABC):
else:
return result

def transform_tool_parameters_type(self, tool_parameters: dict[str, Any]) -> dict[str, Any]:
def _transform_tool_parameters_type(self, tool_parameters: dict[str, Any]) -> dict[str, Any]:
"""
Transform tool parameters type
"""
@@ -1,5 +0,0 @@
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="none">
<rect x="3" y="3" width="18" height="18" rx="4" stroke="#2F2F2F" stroke-width="1.5"/>
<path d="M7 12h10" stroke="#2F2F2F" stroke-width="1.5" stroke-linecap="round"/>
<path d="M12 7v10" stroke="#2F2F2F" stroke-width="1.5" stroke-linecap="round"/>
</svg>
@@ -1,8 +0,0 @@
from typing import Any

from core.tools.builtin_tool.provider import BuiltinToolProviderController


class AgentOutputProvider(BuiltinToolProviderController):
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]):
pass
@@ -1,10 +0,0 @@
identity:
author: Dify
name: agent_output
label:
en_US: Agent Output
description:
en_US: Internal tools for agent output control.
icon: icon.svg
tags:
- utilities
@@ -1,17 +0,0 @@
from collections.abc import Generator
from typing import Any

from core.tools.builtin_tool.tool import BuiltinTool
from core.tools.entities.tool_entities import ToolInvokeMessage


class FinalOutputAnswerTool(BuiltinTool):
def _invoke(
self,
user_id: str,
tool_parameters: dict[str, Any],
conversation_id: str | None = None,
app_id: str | None = None,
message_id: str | None = None,
) -> Generator[ToolInvokeMessage, None, None]:
yield self.create_text_message("Final answer recorded.")
@@ -1,18 +0,0 @@
identity:
name: final_output_answer
author: Dify
label:
en_US: Final Output Answer
description:
human:
en_US: Internal tool to deliver the final answer.
llm: Use this tool when you are ready to provide the final answer.
parameters:
- name: text
type: string
required: true
label:
en_US: Text
human_description:
en_US: Final answer text.
form: llm
@@ -1,17 +0,0 @@
from collections.abc import Generator
from typing import Any

from core.tools.builtin_tool.tool import BuiltinTool
from core.tools.entities.tool_entities import ToolInvokeMessage


class FinalStructuredOutputTool(BuiltinTool):
def _invoke(
self,
user_id: str,
tool_parameters: dict[str, Any],
conversation_id: str | None = None,
app_id: str | None = None,
message_id: str | None = None,
) -> Generator[ToolInvokeMessage, None, None]:
yield self.create_text_message("Structured output recorded.")
@@ -1,18 +0,0 @@
identity:
name: final_structured_output
author: Dify
label:
en_US: Final Structured Output
description:
human:
en_US: Internal tool to deliver structured output.
llm: Use this tool to provide structured output data.
parameters:
- name: data
type: object
required: true
label:
en_US: Data
human_description:
en_US: Structured output data.
form: llm
@@ -1,21 +0,0 @@
from collections.abc import Generator
from typing import Any

from core.tools.builtin_tool.tool import BuiltinTool
from core.tools.entities.tool_entities import ToolInvokeMessage


class IllegalOutputTool(BuiltinTool):
def _invoke(
self,
user_id: str,
tool_parameters: dict[str, Any],
conversation_id: str | None = None,
app_id: str | None = None,
message_id: str | None = None,
) -> Generator[ToolInvokeMessage, None, None]:
message = (
"Protocol violation: do not output plain text. "
"Call an output tool and finish with the configured terminal tool."
)
yield self.create_text_message(message)
@@ -1,18 +0,0 @@
identity:
name: illegal_output
author: Dify
label:
en_US: Illegal Output
description:
human:
en_US: Internal tool for output protocol violations.
llm: Use this tool to correct output protocol violations.
parameters:
- name: raw
type: string
required: true
label:
en_US: Raw Output
human_description:
en_US: Raw model output that violated the protocol.
form: llm
@@ -1,17 +0,0 @@
from collections.abc import Generator
from typing import Any

from core.tools.builtin_tool.tool import BuiltinTool
from core.tools.entities.tool_entities import ToolInvokeMessage


class OutputTextTool(BuiltinTool):
def _invoke(
self,
user_id: str,
tool_parameters: dict[str, Any],
conversation_id: str | None = None,
app_id: str | None = None,
message_id: str | None = None,
) -> Generator[ToolInvokeMessage, None, None]:
yield self.create_text_message("Output recorded.")
@@ -1,18 +0,0 @@
identity:
name: output_text
author: Dify
label:
en_US: Output Text
description:
human:
en_US: Internal tool to store intermediate text output.
llm: Use this tool to emit non-final text output.
parameters:
- name: text
type: string
required: true
label:
en_US: Text
human_description:
en_US: Output text.
form: llm
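Not part of the diff: a sketch of the tool calls these internal output tools expect, based on the YAML parameter definitions above (the payload values are hypothetical):

final_answer_call = {"name": "final_output_answer", "arguments": {"text": "Done."}}
structured_call = {
    "name": "final_structured_output",
    "arguments": {"data": {"label": "ok"}},  # object payload per the "data" parameter
}
# output_text carries intermediate text; illegal_output is invoked when the
# model streams plain text instead of finishing with a terminal output tool.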
@@ -1,5 +1,3 @@
from __future__ import annotations

import json
import logging
import mimetypes
@@ -33,8 +31,9 @@ from services.enterprise.plugin_manager_service import PluginCredentialType
from services.tools.mcp_tools_manage_service import MCPToolManageService

if TYPE_CHECKING:
from core.agent.entities import AgentToolEntity
from core.workflow.nodes.tool.entities import ToolEntity

from core.agent.entities import AgentToolEntity
from core.app.entities.app_invoke_entities import InvokeFrom
from core.helper.module_import_helper import load_single_subclass_from_source
from core.helper.position_helper import is_filtered
@@ -67,8 +66,6 @@ if TYPE_CHECKING:

logger = logging.getLogger(__name__)

INTERNAL_BUILTIN_TOOL_PROVIDERS = {"agent_output"}


class ApiProviderControllerItem(TypedDict):
provider: ApiToolProvider
@@ -364,7 +361,7 @@ class ToolManager:
app_id: str,
agent_tool: AgentToolEntity,
invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
variable_pool: Optional[VariablePool] = None,
variable_pool: Optional["VariablePool"] = None,
) -> Tool:
"""
get the agent tool runtime
@@ -404,9 +401,9 @@ class ToolManager:
tenant_id: str,
app_id: str,
node_id: str,
workflow_tool: ToolEntity,
workflow_tool: "ToolEntity",
invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
variable_pool: Optional[VariablePool] = None,
variable_pool: Optional["VariablePool"] = None,
) -> Tool:
"""
get the workflow tool runtime
@@ -594,10 +591,6 @@ class ToolManager:
cls._hardcoded_providers = {}
cls._builtin_providers_loaded = False

@classmethod
def is_internal_builtin_provider(cls, provider_name: str) -> bool:
return provider_name in INTERNAL_BUILTIN_TOOL_PROVIDERS

@classmethod
def get_tool_label(cls, tool_name: str) -> Union[I18nObject, None]:
"""
@@ -635,9 +628,9 @@ class ToolManager:
# MySQL: Use window function to achieve same result
sql = """
SELECT id FROM (
SELECT id,
SELECT id,
ROW_NUMBER() OVER (
PARTITION BY tenant_id, provider
PARTITION BY tenant_id, provider
ORDER BY is_default DESC, created_at DESC
) as rn
FROM tool_builtin_providers
@@ -674,8 +667,6 @@ class ToolManager:

# append builtin providers
for provider in builtin_providers:
if cls.is_internal_builtin_provider(provider.entity.identity.name):
continue
# handle include, exclude
if is_filtered(
include_set=dify_config.POSITION_TOOL_INCLUDES_SET,
@@ -1019,7 +1010,7 @@ class ToolManager:
def _convert_tool_parameters_type(
cls,
parameters: list[ToolParameter],
variable_pool: Optional[VariablePool],
variable_pool: Optional["VariablePool"],
tool_configurations: dict[str, Any],
typ: Literal["agent", "workflow", "tool"] = "workflow",
) -> dict[str, Any]:
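A hedged sketch of how the internal-provider guard above hides providers from listings; builtin_providers is assumed to be the iterable of provider controllers shown in the loop above:

visible = [
    p for p in builtin_providers
    if not ToolManager.is_internal_builtin_provider(p.entity.identity.name)
]
# With INTERNAL_BUILTIN_TOOL_PROVIDERS = {"agent_output"}, the agent_output
# provider never appears in user-facing tool lists.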
@@ -577,12 +577,12 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
prompt_messages=prompt_messages,
stop=stop,
user_id=self.user_id,
structured_output_schema=None,
structured_output_enabled=self.node_data.structured_output_enabled,
structured_output=None,
file_saver=self._llm_file_saver,
file_outputs=self._file_outputs,
node_id=self._node_id,
node_type=self.node_type,
tenant_id=self.tenant_id,
)

for event in generator:
@@ -7,7 +7,7 @@ from pydantic import BaseModel, ConfigDict, Field, field_serializer, field_valid
from core.agent.entities import AgentLog, AgentResult
from core.file import File
from core.model_runtime.entities import ImagePromptMessageContent, LLMMode
from core.model_runtime.entities.llm_entities import LLMStructuredOutput, LLMUsage
from core.model_runtime.entities.llm_entities import LLMUsage
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
from core.tools.entities.tool_entities import ToolProviderType
from core.workflow.entities import ToolCall, ToolCallResult
@@ -156,9 +156,6 @@ class LLMGenerationData(BaseModel):
finish_reason: str | None = Field(None, description="Finish reason from LLM")
files: list[File] = Field(default_factory=list, description="Generated files")
trace: list[LLMTraceSegment] = Field(default_factory=list, description="Streaming trace in emitted order")
structured_output: LLMStructuredOutput | None = Field(
default=None, description="Structured output from tool-only agent runs"
)


class ThinkTagStreamParser:
@@ -287,7 +284,6 @@ class AggregatedResult(BaseModel):
files: list[File] = Field(default_factory=list)
usage: LLMUsage = Field(default_factory=LLMUsage.empty_usage)
finish_reason: str | None = None
structured_output: LLMStructuredOutput | None = None


class AgentContext(BaseModel):
@@ -387,7 +383,7 @@ class LLMNodeData(BaseNodeData):
Strategy for handling model reasoning output.

separated: Return clean text (without <think> tags) + reasoning_content field.
Recommended for new workflows. Enables safe downstream parsing and
Recommended for new workflows. Enables safe downstream parsing and
workflow variable access: {{#node_id.reasoning_content#}}

tagged : Return original text (with <think> tags) + reasoning_content field.
@@ -257,8 +257,8 @@ def _build_file_descriptions(files: Sequence[Any]) -> str:
"""
Build a text description of generated files for inclusion in context.

The description includes file_id for context; structured output file paths
are only supported in sandbox mode.
The description includes file_id which can be used by subsequent nodes
to reference the files via structured output.
"""
if not files:
return ""
@@ -13,24 +13,13 @@ from typing import TYPE_CHECKING, Any, Literal, cast
from sqlalchemy import select

from core.agent.entities import AgentEntity, AgentLog, AgentResult, AgentToolEntity, ExecutionContext
from core.agent.output_tools import (
FINAL_OUTPUT_TOOL,
FINAL_STRUCTURED_OUTPUT_TOOL,
OUTPUT_TEXT_TOOL,
build_agent_output_tools,
)
from core.agent.patterns import StrategyFactory
from core.app.entities.app_asset_entities import AppAssetFileTree
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.app_assets.constants import AppAssetsAttrs
from core.file import FileTransferMethod, FileType, file_manager
from core.file import File, FileTransferMethod, FileType, file_manager
from core.helper.code_executor import CodeExecutor, CodeLanguage
from core.llm_generator.output_parser.errors import OutputParserError
from core.llm_generator.output_parser.file_ref import (
adapt_schema_for_sandbox_file_paths,
convert_sandbox_file_paths_in_output,
detect_file_path_fields,
)
from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output
from core.memory.base import BaseMemory
from core.model_manager import ModelInstance, ModelManager
@@ -73,7 +62,6 @@ from core.skill.entities.skill_document import SkillDocument
from core.skill.entities.tool_dependencies import ToolDependencies, ToolDependency
from core.skill.skill_compiler import SkillCompiler
from core.tools.__base.tool import Tool
from core.tools.entities.tool_entities import ToolInvokeFrom
from core.tools.signature import sign_upload_file
from core.tools.tool_manager import ToolManager
from core.variables import (
@@ -197,11 +185,12 @@ class LLMNode(Node[LLMNodeData]):
def _run(self) -> Generator:
node_inputs: dict[str, Any] = {}
process_data: dict[str, Any] = {}
usage: LLMUsage = LLMUsage.empty_usage()
finish_reason: str | None = None
reasoning_content: str = ""  # Initialize as empty string for consistency
clean_text = ""
usage = LLMUsage.empty_usage()
finish_reason = None
reasoning_content = ""  # Initialize as empty string for consistency
clean_text = ""  # Initialize clean_text to avoid UnboundLocalError
variable_pool: VariablePool = self.graph_runtime_state.variable_pool
variable_pool = self.graph_runtime_state.variable_pool

try:
# Parse prompt template to separate static messages and context references
@@ -261,9 +250,8 @@ class LLMNode(Node[LLMNodeData]):
)

query: str | None = None
memory_config = self.node_data.memory
if memory_config:
query = memory_config.query_prompt_template
if self.node_data.memory:
query = self.node_data.memory.query_prompt_template
if not query and (
query_variable := variable_pool.get((SYSTEM_VARIABLE_NODE_ID, SystemVariableKey.QUERY))
):
@@ -305,23 +293,9 @@ class LLMNode(Node[LLMNodeData]):
sandbox=self.graph_runtime_state.sandbox,
)

structured_output_schema: Mapping[str, Any] | None
structured_output_file_paths: list[str] = []

if self.node_data.structured_output_enabled:
if not self.node_data.structured_output:
raise ValueError("structured_output_enabled is True but structured_output is not set")
raw_schema = LLMNode.fetch_structured_output_schema(structured_output=self.node_data.structured_output)
if self.node_data.computer_use:
structured_output_schema, structured_output_file_paths = adapt_schema_for_sandbox_file_paths(
raw_schema
)
else:
if detect_file_path_fields(raw_schema):
raise LLMNodeError("Structured output file paths are only supported in sandbox mode.")
structured_output_schema = raw_schema
else:
structured_output_schema = None
# Variables for outputs
generation_data: LLMGenerationData | None = None
structured_output: LLMStructuredOutput | None = None

if self.node_data.computer_use:
sandbox = self.graph_runtime_state.sandbox
@@ -335,10 +309,7 @@ class LLMNode(Node[LLMNodeData]):
stop=stop,
variable_pool=variable_pool,
tool_dependencies=tool_dependencies,
structured_output_schema=structured_output_schema,
structured_output_file_paths=structured_output_file_paths,
)

elif self.tool_call_enabled:
generator = self._invoke_llm_with_tools(
model_instance=model_instance,
@@ -348,7 +319,6 @@ class LLMNode(Node[LLMNodeData]):
variable_pool=variable_pool,
node_inputs=node_inputs,
process_data=process_data,
structured_output_schema=structured_output_schema,
)
else:
# Use traditional LLM invocation
@@ -358,7 +328,8 @@ class LLMNode(Node[LLMNodeData]):
prompt_messages=prompt_messages,
stop=stop,
user_id=self.user_id,
structured_output_schema=structured_output_schema,
structured_output_enabled=self._node_data.structured_output_enabled,
structured_output=self._node_data.structured_output,
file_saver=self._llm_file_saver,
file_outputs=self._file_outputs,
node_id=self._node_id,
@@ -384,8 +355,6 @@ class LLMNode(Node[LLMNodeData]):
reasoning_content = ""
usage = generation_data.usage
finish_reason = generation_data.finish_reason
if generation_data.structured_output:
structured_output = generation_data.structured_output

# Unified process_data building
process_data = {
@@ -409,7 +378,16 @@ class LLMNode(Node[LLMNodeData]):
if tool.enabled
]

# Build generation field and determine files_to_output first
# Unified outputs building
outputs = {
"text": clean_text,
"reasoning_content": reasoning_content,
"usage": jsonable_encoder(usage),
"finish_reason": finish_reason,
"context": llm_utils.build_context(prompt_messages, clean_text, generation_data),
}

# Build generation field
if generation_data:
# Use generation_data from tool invocation (supports multi-turn)
generation = {
@@ -437,15 +415,6 @@ class LLMNode(Node[LLMNodeData]):
}
files_to_output = self._file_outputs

# Unified outputs building (files passed to context for subsequent node reference)
outputs = {
"text": clean_text,
"reasoning_content": reasoning_content,
"usage": jsonable_encoder(usage),
"finish_reason": finish_reason,
"context": llm_utils.build_context(prompt_messages, clean_text, generation_data, files=files_to_output),
}

outputs["generation"] = generation
if files_to_output:
outputs["files"] = ArrayFileSegment(value=files_to_output)
@@ -524,7 +493,8 @@ class LLMNode(Node[LLMNodeData]):
prompt_messages: Sequence[PromptMessage],
stop: Sequence[str] | None = None,
user_id: str,
structured_output_schema: Mapping[str, Any] | None,
structured_output_enabled: bool,
structured_output: Mapping[str, Any] | None = None,
file_saver: LLMFileSaver,
file_outputs: list[File],
node_id: str,
@@ -538,7 +508,10 @@ class LLMNode(Node[LLMNodeData]):
if not model_schema:
raise ValueError(f"Model schema not found for {node_data_model.name}")

if structured_output_schema:
if structured_output_enabled:
output_schema = LLMNode.fetch_structured_output_schema(
structured_output=structured_output or {},
)
request_start_time = time.perf_counter()

invoke_result = invoke_llm_with_structured_output(
@@ -546,12 +519,12 @@ class LLMNode(Node[LLMNodeData]):
model_schema=model_schema,
model_instance=model_instance,
prompt_messages=prompt_messages,
json_schema=structured_output_schema,
json_schema=output_schema,
model_parameters=node_data_model.completion_params,
stop=list(stop or []),
user=user_id,
tenant_id=tenant_id,
)

else:
request_start_time = time.perf_counter()
@@ -1288,16 +1261,18 @@ class LLMNode(Node[LLMNodeData]):
# Insert histories into the prompt
prompt_content = prompt_messages[0].content
# For issue #11247 - Check if prompt content is a string or a list
if isinstance(prompt_content, str):
prompt_content_type = type(prompt_content)
if prompt_content_type == str:
prompt_content = str(prompt_content)
if "#histories#" in prompt_content:
prompt_content = prompt_content.replace("#histories#", memory_text)
else:
prompt_content = memory_text + "\n" + prompt_content
prompt_messages[0].content = prompt_content
elif isinstance(prompt_content, list):
elif prompt_content_type == list:
prompt_content = prompt_content if isinstance(prompt_content, list) else []
for content_item in prompt_content:
if isinstance(content_item, TextPromptMessageContent):
if content_item.type == PromptMessageContentType.TEXT:
if "#histories#" in content_item.data:
content_item.data = content_item.data.replace("#histories#", memory_text)
else:
@@ -1307,12 +1282,13 @@ class LLMNode(Node[LLMNodeData]):

# Add current query to the prompt message
if sys_query:
if isinstance(prompt_content, str):
if prompt_content_type == str:
prompt_content = str(prompt_messages[0].content).replace("#sys.query#", sys_query)
prompt_messages[0].content = prompt_content
elif isinstance(prompt_content, list):
elif prompt_content_type == list:
prompt_content = prompt_content if isinstance(prompt_content, list) else []
for content_item in prompt_content:
if isinstance(content_item, TextPromptMessageContent):
if content_item.type == PromptMessageContentType.TEXT:
content_item.data = sys_query + "\n" + content_item.data
else:
raise ValueError("Invalid prompt content type")
@@ -1438,11 +1414,9 @@ class LLMNode(Node[LLMNodeData]):
if isinstance(item, PromptMessageContext):
if len(item.value_selector) >= 2:
prompt_context_selectors.append(item.value_selector)
elif isinstance(item, LLMNodeChatModelMessage):
elif isinstance(item, LLMNodeChatModelMessage) and item.edition_type == "jinja2":
variable_template_parser = VariableTemplateParser(template=item.text)
variable_selectors.extend(variable_template_parser.extract_variable_selectors())
else:
raise InvalidVariableTypeError(f"Invalid prompt template type: {type(prompt_template)}")
elif isinstance(prompt_template, LLMNodeCompletionModelPromptTemplate):
if prompt_template.edition_type != "jinja2":
variable_template_parser = VariableTemplateParser(template=prompt_template.text)
@@ -1478,14 +1452,13 @@ class LLMNode(Node[LLMNodeData]):
if typed_node_data.prompt_config:
enable_jinja = False

if isinstance(prompt_template, LLMNodeCompletionModelPromptTemplate):
if prompt_template.edition_type == "jinja2":
enable_jinja = True
else:
for prompt in prompt_template:
if prompt.edition_type == "jinja2":
if isinstance(prompt_template, list):
for item in prompt_template:
if isinstance(item, LLMNodeChatModelMessage) and item.edition_type == "jinja2":
enable_jinja = True
break
else:
enable_jinja = True

if enable_jinja:
for variable_selector in typed_node_data.prompt_config.jinja2_variables or []:
@@ -1898,7 +1871,6 @@ class LLMNode(Node[LLMNodeData]):
variable_pool: VariablePool,
node_inputs: dict[str, Any],
process_data: dict[str, Any],
structured_output_schema: Mapping[str, Any] | None,
) -> Generator[NodeEventBase, None, LLMGenerationData]:
"""Invoke LLM with tools support (from Agent V2).

@@ -1915,16 +1887,12 @@ class LLMNode(Node[LLMNodeData]):

# Use factory to create appropriate strategy
strategy = StrategyFactory.create_strategy(
tenant_id=self.tenant_id,
invoke_from=self.invoke_from,
tool_invoke_from=ToolInvokeFrom.WORKFLOW,
model_features=model_features,
model_instance=model_instance,
tools=tool_instances,
files=prompt_files,
max_iterations=self._node_data.max_iterations or 10,
context=ExecutionContext(user_id=self.user_id, app_id=self.app_id, tenant_id=self.tenant_id),
structured_output_schema=structured_output_schema,
)

# Run strategy
@@ -1932,6 +1900,7 @@ class LLMNode(Node[LLMNodeData]):
prompt_messages=list(prompt_messages),
model_parameters=self._node_data.model.completion_params,
stop=list(stop or []),
stream=True,
)

result = yield from self._process_tool_outputs(outputs)
@@ -1945,12 +1914,8 @@ class LLMNode(Node[LLMNodeData]):
stop: Sequence[str] | None,
variable_pool: VariablePool,
tool_dependencies: ToolDependencies | None,
structured_output_schema: Mapping[str, Any] | None,
structured_output_file_paths: Sequence[str] | None,
) -> Generator[NodeEventBase | LLMStructuredOutput, None, LLMGenerationData]:
) -> Generator[NodeEventBase, None, LLMGenerationData]:
result: LLMGenerationData | None = None
sandbox_output_files: list[File] = []
structured_output_files: list[File] = []

# FIXME(Mairuis): Async processing for bash session.
with SandboxBashSession(sandbox=sandbox, node_id=self.id, tools=tool_dependencies) as session:
@@ -1958,9 +1923,6 @@ class LLMNode(Node[LLMNodeData]):
model_features = self._get_model_features(model_instance)

strategy = StrategyFactory.create_strategy(
tenant_id=self.tenant_id,
invoke_from=self.invoke_from,
tool_invoke_from=ToolInvokeFrom.WORKFLOW,
model_features=model_features,
model_instance=model_instance,
tools=[session.bash_tool],
@@ -1968,55 +1930,20 @@ class LLMNode(Node[LLMNodeData]):
max_iterations=self._node_data.max_iterations or 100,
agent_strategy=AgentEntity.Strategy.FUNCTION_CALLING,
context=ExecutionContext(user_id=self.user_id, app_id=self.app_id, tenant_id=self.tenant_id),
structured_output_schema=structured_output_schema,
)

outputs = strategy.run(
prompt_messages=list(prompt_messages),
model_parameters=self._node_data.model.completion_params,
stop=list(stop or []),
stream=True,
)

result = yield from self._process_tool_outputs(outputs)

if result and result.structured_output and structured_output_file_paths:
structured_output_payload = result.structured_output.structured_output or {}
try:
converted_output, structured_output_files = convert_sandbox_file_paths_in_output(
output=structured_output_payload,
file_path_fields=structured_output_file_paths,
file_resolver=session.download_file,
)
except ValueError as exc:
raise LLMNodeError(str(exc)) from exc
result = result.model_copy(
update={"structured_output": LLMStructuredOutput(structured_output=converted_output)}
)

# Collect output files from sandbox before session ends
# Files are saved as ToolFiles with valid tool_file_id for later reference
sandbox_output_files = session.collect_output_files()

if result is None:
raise LLMNodeError("SandboxSession exited unexpectedly")

structured_output = result.structured_output
if structured_output is not None:
yield structured_output

# Merge sandbox output files into result
if sandbox_output_files or structured_output_files:
result = LLMGenerationData(
text=result.text,
reasoning_contents=result.reasoning_contents,
tool_calls=result.tool_calls,
sequence=result.sequence,
usage=result.usage,
finish_reason=result.finish_reason,
files=result.files + sandbox_output_files + structured_output_files,
trace=result.trace,
)

return result

def _get_model_features(self, model_instance: ModelInstance) -> list[ModelFeature]:
@@ -2084,20 +2011,6 @@ class LLMNode(Node[LLMNodeData]):
logger.warning("Failed to load tool %s: %s", tool, str(e))
continue

structured_output_schema = None
if self._node_data.structured_output_enabled:
structured_output_schema = LLMNode.fetch_structured_output_schema(
structured_output=self._node_data.structured_output or {},
)
tool_instances.extend(
build_agent_output_tools(
tenant_id=self.tenant_id,
invoke_from=self.invoke_from,
tool_invoke_from=ToolInvokeFrom.WORKFLOW,
structured_output_schema=structured_output_schema,
)
)

return tool_instances

def _extract_prompt_files(self, variable_pool: VariablePool) -> list[File]:
@@ -2261,45 +2174,18 @@ class LLMNode(Node[LLMNodeData]):
# Add tool call to pending list for model segment
buffers.pending_tool_calls.append(ToolCall(id=tool_call_id, name=tool_name, arguments=tool_arguments))

output_tool_names = {OUTPUT_TEXT_TOOL, FINAL_OUTPUT_TOOL, FINAL_STRUCTURED_OUTPUT_TOOL}

if tool_name not in output_tool_names:
yield ToolCallChunkEvent(
selector=[self._node_id, "generation", "tool_calls"],
chunk=tool_arguments,
tool_call=ToolCall(
id=tool_call_id,
name=tool_name,
arguments=tool_arguments,
icon=tool_icon,
icon_dark=tool_icon_dark,
),
is_final=False,
)

if tool_name in output_tool_names:
content = ""
if tool_name in (OUTPUT_TEXT_TOOL, FINAL_OUTPUT_TOOL):
content = payload.tool_args["text"]
elif tool_name == FINAL_STRUCTURED_OUTPUT_TOOL:
raw_content = json.dumps(
payload.tool_args["data"],
ensure_ascii=False,
indent=2
)
content = f"```json\n{raw_content}\n```"

if content:
yield StreamChunkEvent(
selector=[self._node_id, "text"],
chunk=content,
is_final=False,
)
yield StreamChunkEvent(
selector=[self._node_id, "generation", "content"],
chunk=content,
is_final=False,
)
yield ToolCallChunkEvent(
selector=[self._node_id, "generation", "tool_calls"],
chunk=tool_arguments,
tool_call=ToolCall(
id=tool_call_id,
name=tool_name,
arguments=tool_arguments,
icon=tool_icon,
icon_dark=tool_icon_dark,
),
is_final=False,
)

if output.log_type == AgentLog.LogType.TOOL_CALL and output.status != AgentLog.LogStatus.START:
tool_name = payload.tool_name
@@ -2543,7 +2429,6 @@ class LLMNode(Node[LLMNodeData]):
content_position = 0
tool_call_seen_index: dict[str, int] = {}
for trace_segment in trace_state.trace_segments:
# FIXME: These branches will never be hit
if trace_segment.type == "thought":
sequence.append({"type": "reasoning", "index": reasoning_index})
reasoning_index += 1
@@ -2586,15 +2471,8 @@ class LLMNode(Node[LLMNodeData]):
key=lambda item: trace_state.tool_call_index_map.get(item.id or "", len(trace_state.tool_call_index_map))
)

text_content: str
if aggregate.text:
text_content = aggregate.text
elif aggregate.structured_output:
text_content = json.dumps(aggregate.structured_output.structured_output)
else:
raise ValueError("Aggregate must have either text or structured output.")
return LLMGenerationData(
text=text_content,
text=aggregate.text,
reasoning_contents=buffers.reasoning_per_turn,
tool_calls=tool_calls_for_generation,
sequence=sequence,
@@ -2602,7 +2480,6 @@ class LLMNode(Node[LLMNodeData]):
finish_reason=aggregate.finish_reason,
files=aggregate.files,
trace=trace_state.trace_segments,
structured_output=aggregate.structured_output,
)

def _process_tool_outputs(
@@ -2613,33 +2490,22 @@ class LLMNode(Node[LLMNodeData]):
state = ToolOutputState()

try:
while True:
output = next(outputs)
for output in outputs:
if isinstance(output, AgentLog):
yield from self._handle_agent_log_output(output, state.stream, state.trace, state.agent)
else:
continue
yield from self._handle_llm_chunk_output(output, state.stream, state.trace, state.aggregate)
except StopIteration as exception:
if not isinstance(exception.value, AgentResult):
raise ValueError(f"Unexpected output type: {type(exception.value)}") from exception
state.agent.agent_result = exception.value
agent_result = state.agent.agent_result
if not agent_result:
raise ValueError("No agent result found in tool outputs")
output_payload = agent_result.output
if isinstance(output_payload, dict):
state.aggregate.structured_output = LLMStructuredOutput(structured_output=output_payload)
state.aggregate.text = json.dumps(output_payload)
elif isinstance(output_payload, str):
state.aggregate.text = output_payload
else:
raise ValueError(f"Unexpected output type: {type(output_payload)}")
if isinstance(getattr(exception, "value", None), AgentResult):
state.agent.agent_result = exception.value

state.aggregate.files = state.agent.agent_result.files
if state.agent.agent_result.usage:
state.aggregate.usage = state.agent.agent_result.usage
if state.agent.agent_result.finish_reason:
state.aggregate.finish_reason = state.agent.agent_result.finish_reason
if state.agent.agent_result:
state.aggregate.text = state.agent.agent_result.text or state.aggregate.text
state.aggregate.files = state.agent.agent_result.files
if state.agent.agent_result.usage:
state.aggregate.usage = state.agent.agent_result.usage
if state.agent.agent_result.finish_reason:
state.aggregate.finish_reason = state.agent.agent_result.finish_reason

yield from self._flush_remaining_stream(state.stream, state.trace, state.aggregate)
yield from self._close_streams()
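Background sketch (not from the diff): how a generator's return value is captured, which is the Python mechanism the StopIteration handling above relies on.

from collections.abc import Generator

def agent_outputs() -> Generator[str, None, dict]:
    yield "chunk"
    return {"finish_reason": "stop"}  # becomes StopIteration.value

gen = agent_outputs()
try:
    while True:
        next(gen)
except StopIteration as exc:
    final = exc.value  # {"finish_reason": "stop"}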
@@ -156,7 +156,8 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
prompt_messages=prompt_messages,
stop=stop,
user_id=self.user_id,
structured_output_schema=None,
structured_output_enabled=False,
structured_output=None,
file_saver=self._llm_file_saver,
file_outputs=self._file_outputs,
node_id=self._node_id,

@@ -91,8 +91,6 @@ class BuiltinToolManageService:
:return: the list of tools
"""
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
if ToolManager.is_internal_builtin_provider(provider_controller.entity.identity.name):
return []
tools = provider_controller.get_tools()

result: list[ToolApiEntity] = []
@@ -543,8 +541,6 @@ class BuiltinToolManageService:

for provider_controller in provider_controllers:
try:
if ToolManager.is_internal_builtin_provider(provider_controller.entity.identity.name):
continue
# handle include, exclude
if is_filtered(
include_set=dify_config.POSITION_TOOL_INCLUDES_SET,
api/tests/fixtures/file output schema.yml (vendored)
@@ -126,8 +126,9 @@ workflow:
additionalProperties: false
properties:
image:
description: Sandbox file path of the selected image
type: file
description: File ID (UUID) of the selected image
format: dify-file-ref
type: string
required:
- image
type: object
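The fixture change above captures the new file-output convention: instead of a bespoke "type: file", a file-valued property is an ordinary string with "format: dify-file-ref" whose value is the file's ID. The same contract restated as a Python dict for reference (FILE_REF_FORMAT is the constant the tests below import):

    FILE_REF_FORMAT = "dify-file-ref"  # constant name as used in the tests below

    selected_image_schema = {
        "type": "object",
        "additionalProperties": False,
        "properties": {
            "image": {
                # The model emits a file ID (UUID); the runtime resolves it to
                # a File object after the structured output is parsed.
                "type": "string",
                "format": FILE_REF_FORMAT,
                "description": "File ID (UUID) of the selected image",
            },
        },
        "required": ["image"],
    }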
@@ -1,3 +1,4 @@
import json
from collections.abc import Generator

from core.agent.entities import AgentScratchpadUnit
@@ -63,7 +64,7 @@ def test_cot_output_parser():
output += result
elif isinstance(result, AgentScratchpadUnit.Action):
if test_case["action"]:
assert result.model_dump() == test_case["action"]
output += result.model_dump_json()
assert result.to_dict() == test_case["action"]
output += json.dumps(result.to_dict())
if test_case["output"]:
assert output == test_case["output"]
@@ -13,7 +13,7 @@ from core.model_runtime.entities.llm_entities import LLMUsage
class ConcreteAgentPattern(AgentPattern):
"""Concrete implementation of AgentPattern for testing."""

def run(self, prompt_messages, model_parameters, stop=[]):
def run(self, prompt_messages, model_parameters, stop=[], stream=True):
"""Minimal implementation for testing."""
yield from []
@@ -329,15 +329,13 @@ class TestAgentLogProcessing:
)

result = AgentResult(
output="Final answer",
text="Final answer",
files=[],
usage=usage,
finish_reason="stop",
)

output_payload = result.output
assert isinstance(output_payload, str)
assert output_payload == "Final answer"
assert result.text == "Final answer"
assert result.files == []
assert result.usage == usage
assert result.finish_reason == "stop"

@@ -153,7 +153,7 @@ class TestAgentScratchpadUnit:
action_input={"query": "test"},
)

result = action.model_dump()
result = action.to_dict()

assert result == {
"action": "search",
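These two hunks track the same entity change: the tests stop reading a polymorphic result.output and assert on an explicit text field instead, and AgentScratchpadUnit.Action swaps pydantic's model_dump() for a custom to_dict(). A hedged sketch of the resulting AgentResult shape, inferred only from the fields the test constructs (the real class may carry more):

    from typing import Any
    from pydantic import BaseModel

    class AgentResult(BaseModel):
        # Field set inferred from the test above; illustrative only.
        text: str
        files: list[Any] = []
        usage: Any = None              # an LLMUsage instance in the tests
        finish_reason: str | None = None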
@@ -1,93 +1,269 @@
"""
Unit tests for sandbox file path detection and conversion.
Unit tests for file reference detection and conversion.
"""

import uuid
from unittest.mock import MagicMock, patch

import pytest

from core.file import File, FileTransferMethod, FileType
from core.llm_generator.output_parser.file_ref import (
FILE_PATH_DESCRIPTION_SUFFIX,
adapt_schema_for_sandbox_file_paths,
convert_sandbox_file_paths_in_output,
detect_file_path_fields,
is_file_path_property,
FILE_REF_FORMAT,
convert_file_refs_in_output,
detect_file_ref_fields,
is_file_ref_property,
)
from core.variables.segments import ArrayFileSegment, FileSegment


def _build_file(file_id: str) -> File:
return File(
id=file_id,
tenant_id="tenant_123",
type=FileType.IMAGE,
transfer_method=FileTransferMethod.TOOL_FILE,
filename="test.png",
extension=".png",
mime_type="image/png",
size=128,
related_id=file_id,
storage_key="sandbox/path",
)
class TestIsFileRefProperty:
"""Tests for is_file_ref_property function."""

def test_valid_file_ref(self):
schema = {"type": "string", "format": FILE_REF_FORMAT}
assert is_file_ref_property(schema) is True

def test_invalid_type(self):
schema = {"type": "number", "format": FILE_REF_FORMAT}
assert is_file_ref_property(schema) is False

def test_missing_format(self):
schema = {"type": "string"}
assert is_file_ref_property(schema) is False

def test_wrong_format(self):
schema = {"type": "string", "format": "uuid"}
assert is_file_ref_property(schema) is False


class TestFilePathSchema:
def test_is_file_path_property(self):
assert is_file_path_property({"type": "file"}) is True
assert is_file_path_property({"type": "string", "format": "dify-file-ref"}) is True
assert is_file_path_property({"type": "string"}) is False
class TestDetectFileRefFields:
"""Tests for detect_file_ref_fields function."""

def test_detect_file_path_fields(self):
def test_simple_file_ref(self):
schema = {
"type": "object",
"properties": {
"image": {"type": "string", "format": "dify-file-ref"},
"files": {"type": "array", "items": {"type": "string", "format": "dify-file-ref"}},
"meta": {"type": "object", "properties": {"doc": {"type": "file"}}},
"image": {"type": "string", "format": FILE_REF_FORMAT},
},
}
assert set(detect_file_path_fields(schema)) == {"image", "files[*]", "meta.doc"}
paths = detect_file_ref_fields(schema)
assert paths == ["image"]

def test_adapt_schema_for_sandbox_file_paths(self):
def test_multiple_file_refs(self):
schema = {
"type": "object",
"properties": {
"image": {"type": "string", "format": "dify-file-ref"},
"image": {"type": "string", "format": FILE_REF_FORMAT},
"document": {"type": "string", "format": FILE_REF_FORMAT},
"name": {"type": "string"},
},
}
adapted, fields = adapt_schema_for_sandbox_file_paths(schema)
paths = detect_file_ref_fields(schema)
assert set(paths) == {"image", "document"}

assert set(fields) == {"image"}
adapted_image = adapted["properties"]["image"]
assert adapted_image["type"] == "string"
assert "format" not in adapted_image
assert FILE_PATH_DESCRIPTION_SUFFIX in adapted_image["description"]
def test_array_of_file_refs(self):
schema = {
"type": "object",
"properties": {
"files": {
"type": "array",
"items": {"type": "string", "format": FILE_REF_FORMAT},
},
},
}
paths = detect_file_ref_fields(schema)
assert paths == ["files[*]"]

def test_nested_file_ref(self):
schema = {
"type": "object",
"properties": {
"data": {
"type": "object",
"properties": {
"image": {"type": "string", "format": FILE_REF_FORMAT},
},
},
},
}
paths = detect_file_ref_fields(schema)
assert paths == ["data.image"]

def test_no_file_refs(self):
schema = {
"type": "object",
"properties": {
"name": {"type": "string"},
"count": {"type": "number"},
},
}
paths = detect_file_ref_fields(schema)
assert paths == []

def test_empty_schema(self):
schema = {}
paths = detect_file_ref_fields(schema)
assert paths == []

def test_mixed_schema(self):
schema = {
"type": "object",
"properties": {
"query": {"type": "string"},
"image": {"type": "string", "format": FILE_REF_FORMAT},
"documents": {
"type": "array",
"items": {"type": "string", "format": FILE_REF_FORMAT},
},
},
}
paths = detect_file_ref_fields(schema)
assert set(paths) == {"image", "documents[*]"}


class TestConvertSandboxFilePaths:
def test_convert_sandbox_file_paths(self):
output = {"image": "a.png", "files": ["b.png", "c.png"], "name": "demo"}
class TestConvertFileRefsInOutput:
"""Tests for convert_file_refs_in_output function."""

def resolver(path: str) -> File:
return _build_file(path)
@pytest.fixture
def mock_file(self):
"""Create a mock File object with all required attributes."""
file = MagicMock(spec=File)
file.type = FileType.IMAGE
file.transfer_method = FileTransferMethod.TOOL_FILE
file.related_id = "test-related-id"
file.remote_url = None
file.tenant_id = "tenant_123"
file.id = None
file.filename = "test.png"
file.extension = ".png"
file.mime_type = "image/png"
file.size = 1024
file.dify_model_identity = "__dify__file__"
return file

converted, files = convert_sandbox_file_paths_in_output(output, ["image", "files[*]"], resolver)
@pytest.fixture
def mock_build_from_mapping(self, mock_file):
"""Mock the build_from_mapping function."""
with patch("core.llm_generator.output_parser.file_ref.build_from_mapping") as mock:
mock.return_value = mock_file
yield mock

assert isinstance(converted["image"], FileSegment)
assert isinstance(converted["files"], ArrayFileSegment)
assert converted["name"] == "demo"
assert [file.id for file in files] == ["a.png", "b.png", "c.png"]
def test_convert_simple_file_ref(self, mock_build_from_mapping, mock_file):
file_id = str(uuid.uuid4())
output = {"image": file_id}
schema = {
"type": "object",
"properties": {
"image": {"type": "string", "format": FILE_REF_FORMAT},
},
}

def test_invalid_path_value_raises(self):
def resolver(path: str) -> File:
return _build_file(path)
result = convert_file_refs_in_output(output, schema, "tenant_123")

with pytest.raises(ValueError):
convert_sandbox_file_paths_in_output({"image": 123}, ["image"], resolver)
# Result should be wrapped in FileSegment
assert isinstance(result["image"], FileSegment)
assert result["image"].value == mock_file
mock_build_from_mapping.assert_called_once_with(
mapping={"transfer_method": "tool_file", "tool_file_id": file_id},
tenant_id="tenant_123",
)

def test_no_file_paths_returns_output(self):
output = {"name": "demo"}
converted, files = convert_sandbox_file_paths_in_output(output, [], _build_file)
def test_convert_array_of_file_refs(self, mock_build_from_mapping, mock_file):
file_id1 = str(uuid.uuid4())
file_id2 = str(uuid.uuid4())
output = {"files": [file_id1, file_id2]}
schema = {
"type": "object",
"properties": {
"files": {
"type": "array",
"items": {"type": "string", "format": FILE_REF_FORMAT},
},
},
}

assert converted == output
assert files == []
result = convert_file_refs_in_output(output, schema, "tenant_123")

# Result should be wrapped in ArrayFileSegment
assert isinstance(result["files"], ArrayFileSegment)
assert list(result["files"].value) == [mock_file, mock_file]
assert mock_build_from_mapping.call_count == 2

def test_no_conversion_without_file_refs(self):
output = {"name": "test", "count": 5}
schema = {
"type": "object",
"properties": {
"name": {"type": "string"},
"count": {"type": "number"},
},
}

result = convert_file_refs_in_output(output, schema, "tenant_123")

assert result == {"name": "test", "count": 5}

def test_invalid_uuid_returns_none(self):
output = {"image": "not-a-valid-uuid"}
schema = {
"type": "object",
"properties": {
"image": {"type": "string", "format": FILE_REF_FORMAT},
},
}

result = convert_file_refs_in_output(output, schema, "tenant_123")

assert result["image"] is None

def test_file_not_found_returns_none(self):
file_id = str(uuid.uuid4())
output = {"image": file_id}
schema = {
"type": "object",
"properties": {
"image": {"type": "string", "format": FILE_REF_FORMAT},
},
}

with patch("core.llm_generator.output_parser.file_ref.build_from_mapping") as mock:
mock.side_effect = ValueError("File not found")
result = convert_file_refs_in_output(output, schema, "tenant_123")

assert result["image"] is None

def test_preserves_non_file_fields(self, mock_build_from_mapping, mock_file):
file_id = str(uuid.uuid4())
output = {"query": "search term", "image": file_id, "count": 10}
schema = {
"type": "object",
"properties": {
"query": {"type": "string"},
"image": {"type": "string", "format": FILE_REF_FORMAT},
"count": {"type": "number"},
},
}

result = convert_file_refs_in_output(output, schema, "tenant_123")

assert result["query"] == "search term"
assert isinstance(result["image"], FileSegment)
assert result["image"].value == mock_file
assert result["count"] == 10

def test_does_not_modify_original_output(self, mock_build_from_mapping, mock_file):
file_id = str(uuid.uuid4())
original = {"image": file_id}
output = dict(original)
schema = {
"type": "object",
"properties": {
"image": {"type": "string", "format": FILE_REF_FORMAT},
},
}

convert_file_refs_in_output(output, schema, "tenant_123")

# Original should still contain the string ID
assert original["image"] == file_id
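The suite above fully specifies the new helpers' observable behavior: invalid UUIDs and unresolvable files map to None, arrays come back as ArrayFileSegment, non-file fields pass through untouched, and the input mapping is never mutated. A minimal sketch consistent with those assertions (an illustrative reconstruction, not the shipped implementation; nested paths such as "data.image" are omitted for brevity, while detect_file_ref_fields, build_from_mapping, FileSegment, and ArrayFileSegment are the names the tests use):

    import uuid

    def _resolve_file_ref(value, tenant_id):
        """Resolve one file-ID string to a File, or None on any failure."""
        try:
            uuid.UUID(str(value))  # invalid UUIDs fall through to None
            return build_from_mapping(
                mapping={"transfer_method": "tool_file", "tool_file_id": value},
                tenant_id=tenant_id,
            )
        except ValueError:
            return None

    def convert_file_refs_in_output(output, schema, tenant_id):
        result = dict(output)  # never mutate the caller's mapping
        for path in detect_file_ref_fields(schema):
            if path.endswith("[*]"):
                key = path.removesuffix("[*]")
                files = [_resolve_file_ref(v, tenant_id) for v in result.get(key, [])]
                result[key] = ArrayFileSegment(value=files)
            elif path in result:
                file = _resolve_file_ref(result[path], tenant_id)
                result[path] = FileSegment(value=file) if file else None
        return result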
@@ -1,6 +1,4 @@
from collections.abc import Mapping
from decimal import Decimal
from typing import Any, NotRequired, TypedDict
from unittest.mock import MagicMock, patch

import pytest
@@ -8,14 +6,16 @@ from pydantic import BaseModel, ConfigDict

from core.llm_generator.output_parser.errors import OutputParserError
from core.llm_generator.output_parser.structured_output import (
_get_default_value_for_type,
fill_defaults_from_schema,
get_default_value_for_type,
invoke_llm_with_pydantic_model,
invoke_llm_with_structured_output,
)
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.llm_entities import (
LLMResult,
LLMResultChunk,
LLMResultChunkDelta,
LLMResultChunkWithStructuredOutput,
LLMResultWithStructuredOutput,
LLMUsage,
)
@@ -25,29 +25,7 @@ from core.model_runtime.entities.message_entities import (
TextPromptMessageContent,
UserPromptMessage,
)
from core.model_runtime.entities.model_entities import (
AIModelEntity,
ModelType,
ParameterRule,
ParameterType,
)


class StructuredOutputTestCase(TypedDict):
name: str
provider: str
model_name: str
support_structure_output: bool
stream: bool
json_schema: Mapping[str, Any]
expected_llm_response: LLMResult
expected_result_type: type[LLMResultWithStructuredOutput] | None
should_raise: bool
expected_error: NotRequired[type[OutputParserError]]
parameter_rules: NotRequired[list[ParameterRule]]


SchemaData = dict[str, Any]
from core.model_runtime.entities.model_entities import AIModelEntity, ModelType


def create_mock_usage(prompt_tokens: int = 10, completion_tokens: int = 5) -> LLMUsage:
@@ -93,7 +71,7 @@ def get_model_instance() -> MagicMock:
def test_structured_output_parser():
"""Test cases for invoke_llm_with_structured_output function"""

testcases: list[StructuredOutputTestCase] = [
testcases = [
# Test case 1: Model with native structured output support, non-streaming
{
"name": "native_structured_output_non_streaming",
@@ -110,6 +88,39 @@ def test_structured_output_parser():
"expected_result_type": LLMResultWithStructuredOutput,
"should_raise": False,
},
# Test case 2: Model with native structured output support, streaming
{
"name": "native_structured_output_streaming",
"provider": "openai",
"model_name": "gpt-4o",
"support_structure_output": True,
"stream": True,
"json_schema": {"type": "object", "properties": {"name": {"type": "string"}}},
"expected_llm_response": [
LLMResultChunk(
model="gpt-4o",
prompt_messages=[UserPromptMessage(content="test")],
system_fingerprint="test",
delta=LLMResultChunkDelta(
index=0,
message=AssistantPromptMessage(content='{"name":'),
usage=create_mock_usage(prompt_tokens=10, completion_tokens=2),
),
),
LLMResultChunk(
model="gpt-4o",
prompt_messages=[UserPromptMessage(content="test")],
system_fingerprint="test",
delta=LLMResultChunkDelta(
index=0,
message=AssistantPromptMessage(content=' "test"}'),
usage=create_mock_usage(prompt_tokens=10, completion_tokens=3),
),
),
],
"expected_result_type": "generator",
"should_raise": False,
},
# Test case 3: Model without native structured output support, non-streaming
{
"name": "prompt_based_structured_output_non_streaming",
@@ -126,24 +137,78 @@ def test_structured_output_parser():
"expected_result_type": LLMResultWithStructuredOutput,
"should_raise": False,
},
# Test case 4: Model without native structured output support, streaming
{
"name": "non_streaming_with_list_content",
"name": "prompt_based_structured_output_streaming",
"provider": "anthropic",
"model_name": "claude-3-sonnet",
"support_structure_output": False,
"stream": True,
"json_schema": {"type": "object", "properties": {"answer": {"type": "string"}}},
"expected_llm_response": [
LLMResultChunk(
model="claude-3-sonnet",
prompt_messages=[UserPromptMessage(content="test")],
system_fingerprint="test",
delta=LLMResultChunkDelta(
index=0,
message=AssistantPromptMessage(content='{"answer": "test'),
usage=create_mock_usage(prompt_tokens=15, completion_tokens=3),
),
),
LLMResultChunk(
model="claude-3-sonnet",
prompt_messages=[UserPromptMessage(content="test")],
system_fingerprint="test",
delta=LLMResultChunkDelta(
index=0,
message=AssistantPromptMessage(content=' response"}'),
usage=create_mock_usage(prompt_tokens=15, completion_tokens=5),
),
),
],
"expected_result_type": "generator",
"should_raise": False,
},
# Test case 5: Streaming with list content
{
"name": "streaming_with_list_content",
"provider": "openai",
"model_name": "gpt-4o",
"support_structure_output": True,
"stream": False,
"stream": True,
"json_schema": {"type": "object", "properties": {"data": {"type": "string"}}},
"expected_llm_response": LLMResult(
model="gpt-4o",
message=AssistantPromptMessage(
content=[
TextPromptMessageContent(data='{"data":'),
TextPromptMessageContent(data=' "value"}'),
]
"expected_llm_response": [
LLMResultChunk(
model="gpt-4o",
prompt_messages=[UserPromptMessage(content="test")],
system_fingerprint="test",
delta=LLMResultChunkDelta(
index=0,
message=AssistantPromptMessage(
content=[
TextPromptMessageContent(data='{"data":'),
]
),
usage=create_mock_usage(prompt_tokens=10, completion_tokens=2),
),
),
usage=create_mock_usage(prompt_tokens=10, completion_tokens=5),
),
"expected_result_type": LLMResultWithStructuredOutput,
LLMResultChunk(
model="gpt-4o",
prompt_messages=[UserPromptMessage(content="test")],
system_fingerprint="test",
delta=LLMResultChunkDelta(
index=0,
message=AssistantPromptMessage(
content=[
TextPromptMessageContent(data=' "value"}'),
]
),
usage=create_mock_usage(prompt_tokens=10, completion_tokens=3),
),
),
],
"expected_result_type": "generator",
"should_raise": False,
},
# Test case 6: Error case - non-string LLM response content (non-streaming)
@@ -188,13 +253,7 @@ def test_structured_output_parser():
"stream": False,
"json_schema": {"type": "object", "properties": {"result": {"type": "string"}}},
"parameter_rules": [
ParameterRule(
name="response_format",
label=I18nObject(en_US="response_format"),
type=ParameterType.STRING,
required=False,
options=["json_schema"],
),
MagicMock(name="response_format", options=["json_schema"], required=False),
],
"expected_llm_response": LLMResult(
model="gpt-4o",
@@ -213,13 +272,7 @@ def test_structured_output_parser():
"stream": False,
"json_schema": {"type": "object", "properties": {"output": {"type": "string"}}},
"parameter_rules": [
ParameterRule(
name="response_format",
label=I18nObject(en_US="response_format"),
type=ParameterType.STRING,
required=False,
options=["JSON"],
),
MagicMock(name="response_format", options=["JSON"], required=False),
],
"expected_llm_response": LLMResult(
model="claude-3-sonnet",
@@ -232,72 +285,89 @@ def test_structured_output_parser():
]

for case in testcases:
provider = case["provider"]
model_name = case["model_name"]
support_structure_output = case["support_structure_output"]
json_schema = case["json_schema"]
stream = case["stream"]
# Setup model entity
model_schema = get_model_entity(case["provider"], case["model_name"], case["support_structure_output"])

model_schema = get_model_entity(provider, model_name, support_structure_output)

parameter_rules = case.get("parameter_rules")
if parameter_rules is not None:
model_schema.parameter_rules = parameter_rules
# Add parameter rules if specified
if "parameter_rules" in case:
model_schema.parameter_rules = case["parameter_rules"]

# Setup model instance
model_instance = get_model_instance()
model_instance.invoke_llm.return_value = case["expected_llm_response"]

# Setup prompt messages
prompt_messages = [
SystemPromptMessage(content="You are a helpful assistant."),
UserPromptMessage(content="Generate a response according to the schema."),
]

if case["should_raise"]:
expected_error = case.get("expected_error", OutputParserError)
with pytest.raises(expected_error): # noqa: PT012
if stream:
# Test error cases
with pytest.raises(case["expected_error"]): # noqa: PT012
if case["stream"]:
result_generator = invoke_llm_with_structured_output(
provider=provider,
provider=case["provider"],
model_schema=model_schema,
model_instance=model_instance,
prompt_messages=prompt_messages,
json_schema=json_schema,
json_schema=case["json_schema"],
)
# Consume the generator to trigger the error
list(result_generator)
else:
invoke_llm_with_structured_output(
provider=provider,
provider=case["provider"],
model_schema=model_schema,
model_instance=model_instance,
prompt_messages=prompt_messages,
json_schema=json_schema,
json_schema=case["json_schema"],
)
else:
# Test successful cases
with patch("core.llm_generator.output_parser.structured_output.json_repair.loads") as mock_json_repair:
# Configure json_repair mock for cases that need it
if case["name"] == "json_repair_scenario":
mock_json_repair.return_value = {"name": "test"}

result = invoke_llm_with_structured_output(
provider=provider,
provider=case["provider"],
model_schema=model_schema,
model_instance=model_instance,
prompt_messages=prompt_messages,
json_schema=json_schema,
json_schema=case["json_schema"],
model_parameters={"temperature": 0.7, "max_tokens": 100},
user="test_user",
)

expected_result_type = case["expected_result_type"]
assert expected_result_type is not None
assert isinstance(result, expected_result_type)
assert result.model == model_name
assert result.structured_output is not None
assert isinstance(result.structured_output, dict)
if case["expected_result_type"] == "generator":
# Test streaming results
assert hasattr(result, "__iter__")
chunks = list(result)
assert len(chunks) > 0

# Verify all chunks are LLMResultChunkWithStructuredOutput
for chunk in chunks[:-1]: # All except last
assert isinstance(chunk, LLMResultChunkWithStructuredOutput)
assert chunk.model == case["model_name"]

# Last chunk should have structured output
last_chunk = chunks[-1]
assert isinstance(last_chunk, LLMResultChunkWithStructuredOutput)
assert last_chunk.structured_output is not None
assert isinstance(last_chunk.structured_output, dict)
else:
# Test non-streaming results
assert isinstance(result, case["expected_result_type"])
assert result.model == case["model_name"]
assert result.structured_output is not None
assert isinstance(result.structured_output, dict)

# Verify model_instance.invoke_llm was called with correct parameters
model_instance.invoke_llm.assert_called_once()
call_args = model_instance.invoke_llm.call_args

assert call_args.kwargs["stream"] == stream
assert call_args.kwargs["stream"] == case["stream"]
assert call_args.kwargs["user"] == "test_user"
assert "temperature" in call_args.kwargs["model_parameters"]
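Per the assertions above, the streaming path now yields LLMResultChunkWithStructuredOutput chunks and attaches the parsed object only to the final chunk. A consumption sketch under that contract (illustrative; result stands for the generator the test obtains with stream=True):

    chunks = list(result)                       # drain the stream
    text = "".join(
        chunk.delta.message.content
        for chunk in chunks
        if isinstance(chunk.delta.message.content, str)
    )
    structured = chunks[-1].structured_output   # dict, only on the last chunk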
@@ -306,32 +376,45 @@ def test_structured_output_parser():
def test_parse_structured_output_edge_cases():
"""Test edge cases for structured output parsing"""

provider = "deepseek"
model_name = "deepseek-r1"
support_structure_output = False
json_schema: SchemaData = {"type": "object", "properties": {"thought": {"type": "string"}}}
expected_llm_response = LLMResult(
model="deepseek-r1",
message=AssistantPromptMessage(content='[{"thought": "reasoning process"}, "other content"]'),
usage=create_mock_usage(prompt_tokens=10, completion_tokens=5),
# Test case with list that contains dict (reasoning model scenario)
testcase_list_with_dict = {
"name": "list_with_dict_parsing",
"provider": "deepseek",
"model_name": "deepseek-r1",
"support_structure_output": False,
"stream": False,
"json_schema": {"type": "object", "properties": {"thought": {"type": "string"}}},
"expected_llm_response": LLMResult(
model="deepseek-r1",
message=AssistantPromptMessage(content='[{"thought": "reasoning process"}, "other content"]'),
usage=create_mock_usage(prompt_tokens=10, completion_tokens=5),
),
"expected_result_type": LLMResultWithStructuredOutput,
"should_raise": False,
}

# Setup for list parsing test
model_schema = get_model_entity(
testcase_list_with_dict["provider"],
testcase_list_with_dict["model_name"],
testcase_list_with_dict["support_structure_output"],
)

model_schema = get_model_entity(provider, model_name, support_structure_output)

model_instance = get_model_instance()
model_instance.invoke_llm.return_value = expected_llm_response
model_instance.invoke_llm.return_value = testcase_list_with_dict["expected_llm_response"]

prompt_messages = [UserPromptMessage(content="Test reasoning")]

with patch("core.llm_generator.output_parser.structured_output.json_repair.loads") as mock_json_repair:
# Mock json_repair to return a list with dict
mock_json_repair.return_value = [{"thought": "reasoning process"}, "other content"]

result = invoke_llm_with_structured_output(
provider=provider,
provider=testcase_list_with_dict["provider"],
model_schema=model_schema,
model_instance=model_instance,
prompt_messages=prompt_messages,
json_schema=json_schema,
json_schema=testcase_list_with_dict["json_schema"],
)

assert isinstance(result, LLMResultWithStructuredOutput)
@@ -341,16 +424,18 @@ def test_parse_structured_output_edge_cases():
def test_model_specific_schema_preparation():
"""Test schema preparation for different model types"""

provider = "google"
model_name = "gemini-pro"
support_structure_output = True
json_schema: SchemaData = {
"type": "object",
"properties": {"result": {"type": "boolean"}},
"additionalProperties": False,
# Test Gemini model
gemini_case = {
"provider": "google",
"model_name": "gemini-pro",
"support_structure_output": True,
"stream": False,
"json_schema": {"type": "object", "properties": {"result": {"type": "boolean"}}, "additionalProperties": False},
}

model_schema = get_model_entity(provider, model_name, support_structure_output)
model_schema = get_model_entity(
gemini_case["provider"], gemini_case["model_name"], gemini_case["support_structure_output"]
)

model_instance = get_model_instance()
model_instance.invoke_llm.return_value = LLMResult(
@@ -362,11 +447,11 @@ def test_model_specific_schema_preparation():
prompt_messages = [UserPromptMessage(content="Test")]

result = invoke_llm_with_structured_output(
provider=provider,
provider=gemini_case["provider"],
model_schema=model_schema,
model_instance=model_instance,
prompt_messages=prompt_messages,
json_schema=json_schema,
json_schema=gemini_case["json_schema"],
)

assert isinstance(result, LLMResultWithStructuredOutput)
@@ -408,26 +493,40 @@ def test_structured_output_with_pydantic_model_non_streaming():
assert result.name == "test"


def test_structured_output_with_pydantic_model_list_content():
def test_structured_output_with_pydantic_model_streaming():
model_schema = get_model_entity("openai", "gpt-4o", support_structure_output=True)
model_instance = get_model_instance()
model_instance.invoke_llm.return_value = LLMResult(
model="gpt-4o",
message=AssistantPromptMessage(
content=[
TextPromptMessageContent(data='{"name":'),
TextPromptMessageContent(data=' "test"}'),
]
),
usage=create_mock_usage(prompt_tokens=8, completion_tokens=4),
)

def mock_streaming_response():
yield LLMResultChunk(
model="gpt-4o",
prompt_messages=[UserPromptMessage(content="test")],
system_fingerprint="test",
delta=LLMResultChunkDelta(
index=0,
message=AssistantPromptMessage(content='{"name":'),
usage=create_mock_usage(prompt_tokens=8, completion_tokens=2),
),
)
yield LLMResultChunk(
model="gpt-4o",
prompt_messages=[UserPromptMessage(content="test")],
system_fingerprint="test",
delta=LLMResultChunkDelta(
index=0,
message=AssistantPromptMessage(content=' "test"}'),
usage=create_mock_usage(prompt_tokens=8, completion_tokens=4),
),
)

model_instance.invoke_llm.return_value = mock_streaming_response()

result = invoke_llm_with_pydantic_model(
provider="openai",
model_schema=model_schema,
model_instance=model_instance,
prompt_messages=[UserPromptMessage(content="Return a JSON object with name.")],
output_model=ExampleOutput,
output_model=ExampleOutput
)

assert isinstance(result, ExampleOutput)
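A hedged sketch of the parsing step this test exercises: concatenate the streamed text deltas, then validate the assembled JSON against the caller's pydantic model via pydantic v2's model_validate_json. ExampleOutput mirrors the test's model; the helper name is invented for illustration:

    from pydantic import BaseModel

    class ExampleOutput(BaseModel):
        name: str

    def parse_streamed_json(chunks, output_model: type[BaseModel]) -> BaseModel:
        # Join the text deltas into one JSON document, then validate it.
        text = "".join(chunk.delta.message.content for chunk in chunks)
        return output_model.model_validate_json(text)  # ExampleOutput(name="test")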
@@ -449,51 +548,51 @@ def test_structured_output_with_pydantic_model_validation_error():
model_schema=model_schema,
model_instance=model_instance,
prompt_messages=[UserPromptMessage(content="test")],
output_model=ExampleOutput,
output_model=ExampleOutput
)


class TestGetDefaultValueForType:
"""Test cases for get_default_value_for_type function"""
"""Test cases for _get_default_value_for_type function"""

def test_string_type(self):
assert get_default_value_for_type("string") == ""
assert _get_default_value_for_type("string") == ""

def test_object_type(self):
assert get_default_value_for_type("object") == {}
assert _get_default_value_for_type("object") == {}

def test_array_type(self):
assert get_default_value_for_type("array") == []
assert _get_default_value_for_type("array") == []

def test_number_type(self):
assert get_default_value_for_type("number") == 0
assert _get_default_value_for_type("number") == 0

def test_integer_type(self):
assert get_default_value_for_type("integer") == 0
assert _get_default_value_for_type("integer") == 0

def test_boolean_type(self):
assert get_default_value_for_type("boolean") is False
assert _get_default_value_for_type("boolean") is False

def test_null_type(self):
assert get_default_value_for_type("null") is None
assert _get_default_value_for_type("null") is None

def test_none_type(self):
assert get_default_value_for_type(None) is None
assert _get_default_value_for_type(None) is None

def test_unknown_type(self):
assert get_default_value_for_type("unknown") is None
assert _get_default_value_for_type("unknown") is None

def test_union_type_string_null(self):
# ["string", "null"] should return "" (first non-null type)
assert get_default_value_for_type(["string", "null"]) == ""
assert _get_default_value_for_type(["string", "null"]) == ""

def test_union_type_null_first(self):
# ["null", "integer"] should return 0 (first non-null type)
assert get_default_value_for_type(["null", "integer"]) == 0
assert _get_default_value_for_type(["null", "integer"]) == 0

def test_union_type_only_null(self):
# ["null"] should return None
assert get_default_value_for_type(["null"]) is None
assert _get_default_value_for_type(["null"]) is None
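The assertions above pin down _get_default_value_for_type almost completely; a reconstruction consistent with them (illustrative, not the shipped code) needs only a union branch and a type lookup:

    def _get_default_value_for_type(json_type):
        if isinstance(json_type, list):
            for member in json_type:       # union: first non-"null" member wins
                if member != "null":
                    return _get_default_value_for_type(member)
            return None                    # ["null"] -> None
        match json_type:
            case "string":
                return ""
            case "object":
                return {}
            case "array":
                return []
            case "number" | "integer":
                return 0
            case "boolean":
                return False
            case _:                        # "null", None, or unknown types
                return None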
class TestFillDefaultsFromSchema:
@@ -501,7 +600,7 @@ class TestFillDefaultsFromSchema:

def test_simple_required_fields(self):
"""Test filling simple required fields"""
schema: SchemaData = {
schema = {
"type": "object",
"properties": {
"name": {"type": "string"},
@@ -510,7 +609,7 @@ class TestFillDefaultsFromSchema:
},
"required": ["name", "age"],
}
output: SchemaData = {"name": "Alice"}
output = {"name": "Alice"}

result = fill_defaults_from_schema(output, schema)

@@ -520,7 +619,7 @@ class TestFillDefaultsFromSchema:

def test_non_required_fields_not_filled(self):
"""Test that non-required fields are not filled"""
schema: SchemaData = {
schema = {
"type": "object",
"properties": {
"required_field": {"type": "string"},
@@ -528,7 +627,7 @@ class TestFillDefaultsFromSchema:
},
"required": ["required_field"],
}
output: SchemaData = {}
output = {}

result = fill_defaults_from_schema(output, schema)

@@ -537,7 +636,7 @@ class TestFillDefaultsFromSchema:

def test_nested_object_required_fields(self):
"""Test filling nested object required fields"""
schema: SchemaData = {
schema = {
"type": "object",
"properties": {
"user": {
@@ -560,7 +659,7 @@ class TestFillDefaultsFromSchema:
},
"required": ["user"],
}
output: SchemaData = {
output = {
"user": {
"name": "Alice",
"address": {
@@ -585,7 +684,7 @@ class TestFillDefaultsFromSchema:

def test_missing_nested_object_created(self):
"""Test that missing required nested objects are created"""
schema: SchemaData = {
schema = {
"type": "object",
"properties": {
"metadata": {
@@ -599,7 +698,7 @@ class TestFillDefaultsFromSchema:
},
"required": ["metadata"],
}
output: SchemaData = {}
output = {}

result = fill_defaults_from_schema(output, schema)

@@ -611,7 +710,7 @@ class TestFillDefaultsFromSchema:

def test_all_types_default_values(self):
"""Test default values for all types"""
schema: SchemaData = {
schema = {
"type": "object",
"properties": {
"str_field": {"type": "string"},
@@ -623,7 +722,7 @@ class TestFillDefaultsFromSchema:
},
"required": ["str_field", "int_field", "num_field", "bool_field", "arr_field", "obj_field"],
}
output: SchemaData = {}
output = {}

result = fill_defaults_from_schema(output, schema)

@@ -638,7 +737,7 @@ class TestFillDefaultsFromSchema:

def test_existing_values_preserved(self):
"""Test that existing values are not overwritten"""
schema: SchemaData = {
schema = {
"type": "object",
"properties": {
"name": {"type": "string"},
@@ -646,7 +745,7 @@ class TestFillDefaultsFromSchema:
},
"required": ["name", "count"],
}
output: SchemaData = {"name": "Bob", "count": 42}
output = {"name": "Bob", "count": 42}

result = fill_defaults_from_schema(output, schema)

@@ -654,7 +753,7 @@ class TestFillDefaultsFromSchema:

def test_complex_nested_structure(self):
"""Test complex nested structure with multiple levels"""
schema: SchemaData = {
schema = {
"type": "object",
"properties": {
"user": {
@@ -690,7 +789,7 @@ class TestFillDefaultsFromSchema:
},
"required": ["user", "tags", "metadata", "is_active"],
}
output: SchemaData = {
output = {
"user": {
"name": "Alice",
"age": 25,
@@ -730,8 +829,8 @@ class TestFillDefaultsFromSchema:

def test_empty_schema(self):
"""Test with empty schema"""
schema: SchemaData = {}
output: SchemaData = {"any": "value"}
schema = {}
output = {"any": "value"}

result = fill_defaults_from_schema(output, schema)

@@ -739,14 +838,14 @@ class TestFillDefaultsFromSchema:

def test_schema_without_required(self):
"""Test schema without required field"""
schema: SchemaData = {
schema = {
"type": "object",
"properties": {
"optional1": {"type": "string"},
"optional2": {"type": "integer"},
},
}
output: SchemaData = {}
output = {}

result = fill_defaults_from_schema(output, schema)
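Likewise, the TestFillDefaultsFromSchema cases determine the observable shape of fill_defaults_from_schema: only required fields are filled, existing values always win, and required nested objects are created and recursed into. A sketch consistent with those cases (an illustrative reconstruction, not the shipped implementation):

    def fill_defaults_from_schema(output: dict, schema: dict) -> dict:
        result = dict(output)
        properties = schema.get("properties", {})
        for field in schema.get("required", []):   # optional fields stay absent
            prop = properties.get(field, {})
            if field not in result:                # existing values are preserved
                result[field] = _get_default_value_for_type(prop.get("type"))
            # Recurse into nested objects so their required fields get filled too.
            if prop.get("type") == "object" and isinstance(result.get(field), dict):
                result[field] = fill_defaults_from_schema(result[field], prop)
        return result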