diff --git a/api/core/workflow/nodes/llm/llm_utils.py b/api/core/workflow/nodes/llm/llm_utils.py
index 966c34a0d7..b864d3b368 100644
--- a/api/core/workflow/nodes/llm/llm_utils.py
+++ b/api/core/workflow/nodes/llm/llm_utils.py
@@ -18,13 +18,14 @@ from core.model_runtime.entities.message_entities import (
     PromptMessage,
     PromptMessageContentUnionTypes,
     PromptMessageRole,
+    ToolPromptMessage,
 )
 from core.model_runtime.entities.model_entities import ModelType
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from core.prompt.entities.advanced_prompt_entities import MemoryConfig, MemoryMode
 from core.variables.segments import ArrayAnySegment, ArrayFileSegment, FileSegment, NoneSegment, StringSegment
 from core.workflow.enums import SystemVariableKey
-from core.workflow.nodes.llm.entities import ModelConfig
+from core.workflow.nodes.llm.entities import LLMGenerationData, ModelConfig
 from core.workflow.runtime import VariablePool
 from extensions.ext_database import db
 from libs.datetime_utils import naive_utc_now
@@ -214,21 +215,87 @@ def deduct_llm_quota(tenant_id: str, model_instance: ModelInstance, usage: LLMUs
 def build_context(
     prompt_messages: Sequence[PromptMessage],
     assistant_response: str,
+    generation_data: LLMGenerationData | None = None,
 ) -> list[PromptMessage]:
     """
     Build context from prompt messages and assistant response.
     Excludes system messages and includes the current LLM response.
     Returns list[PromptMessage] for use with ArrayPromptMessageSegment.
+    For tool-enabled runs, reconstructs the full conversation including tool calls and results.
 
     Note: Multi-modal content base64 data is truncated to avoid storing large data in context.
+
+    Args:
+        prompt_messages: Initial prompt messages (user query, etc.)
+        assistant_response: Final assistant response text
+        generation_data: Optional generation data containing trace for tool-enabled runs
     """
+
     context_messages: list[PromptMessage] = [
         _truncate_multimodal_content(m) for m in prompt_messages if m.role != PromptMessageRole.SYSTEM
     ]
-    context_messages.append(AssistantPromptMessage(content=assistant_response))
+
+    # For tool-enabled runs, reconstruct messages from trace
+    if generation_data and generation_data.trace:
+        context_messages.extend(_build_messages_from_trace(generation_data, assistant_response))
+    else:
+        context_messages.append(AssistantPromptMessage(content=assistant_response))
+
     return context_messages
 
 
+def _build_messages_from_trace(
+    generation_data: LLMGenerationData,
+    assistant_response: str,
+) -> list[PromptMessage]:
+    """
+    Build assistant and tool messages from trace segments.
+
+    Processes trace in order to reconstruct the conversation flow:
+    - Model segments with tool_calls -> AssistantPromptMessage with tool_calls
+    - Tool segments -> ToolPromptMessage with result
+    - Final response -> AssistantPromptMessage with assistant_response
+    """
+    from core.workflow.nodes.llm.entities import ModelTraceSegment, ToolTraceSegment
+
+    messages: list[PromptMessage] = []
+
+    for segment in generation_data.trace:
+        if segment.type == "model" and isinstance(segment.output, ModelTraceSegment):
+            model_output = segment.output
+            segment_content = model_output.text or ""
+
+            if model_output.tool_calls:
+                # Build tool_calls for AssistantPromptMessage
+                tool_calls = [
+                    AssistantPromptMessage.ToolCall(
+                        id=tc.id or "",
+                        type="function",
+                        function=AssistantPromptMessage.ToolCall.ToolCallFunction(
+                            name=tc.name or "",
+                            arguments=tc.arguments or "",
+                        ),
+                    )
+                    for tc in model_output.tool_calls
+                ]
+                messages.append(AssistantPromptMessage(content=segment_content, tool_calls=tool_calls))
+
+        elif segment.type == "tool" and isinstance(segment.output, ToolTraceSegment):
+            tool_output = segment.output
+            messages.append(
+                ToolPromptMessage(
+                    content=tool_output.output or "",
+                    tool_call_id=tool_output.id or "",
+                    name=tool_output.name or "",
+                )
+            )
+
+    # Add final assistant response as the authoritative text
+    messages.append(AssistantPromptMessage(content=assistant_response))
+
+    return messages
+
+
 def _truncate_multimodal_content(message: PromptMessage) -> PromptMessage:
     """
     Truncate multi-modal content base64 data in a message to avoid storing large data.
diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py
index 1853739650..c404ba4318 100644
--- a/api/core/workflow/nodes/llm/node.py
+++ b/api/core/workflow/nodes/llm/node.py
@@ -373,7 +373,7 @@ class LLMNode(Node[LLMNodeData]):
             "reasoning_content": reasoning_content,
             "usage": jsonable_encoder(usage),
             "finish_reason": finish_reason,
-            "context": llm_utils.build_context(prompt_messages, clean_text),
+            "context": llm_utils.build_context(prompt_messages, clean_text, generation_data),
         }
 
         # Build generation field
diff --git a/api/tests/unit_tests/core/app/layers/test_sandbox_layer.py b/api/tests/unit_tests/core/app/layers/test_sandbox_layer.py
index 76fd68cfcb..3929c64896 100644
--- a/api/tests/unit_tests/core/app/layers/test_sandbox_layer.py
+++ b/api/tests/unit_tests/core/app/layers/test_sandbox_layer.py
@@ -171,9 +171,7 @@ class TestSandboxLayer:
         layer.on_event(GraphRunSucceededEvent(outputs={}))
         layer.on_event(GraphRunFailedEvent(error="test error", exceptions_count=1))
 
-    def test_on_graph_end_releases_sandbox_and_unregisters_from_manager(
-        self, mock_sandbox_storage: MagicMock
-    ) -> None:
+    def test_on_graph_end_releases_sandbox_and_unregisters_from_manager(self, mock_sandbox_storage: MagicMock) -> None:
         sandbox_id = "test-exec-456"
         layer = create_layer(sandbox_id=sandbox_id, sandbox_storage=mock_sandbox_storage)
         mock_sandbox = MagicMock(spec=VirtualEnvironment)
diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_llm_utils.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_llm_utils.py
index e327e03159..477c108aeb 100644
--- a/api/tests/unit_tests/core/workflow/nodes/llm/test_llm_utils.py
+++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_llm_utils.py
@@ -123,6 +123,195 @@ class TestBuildContext:
         assert len(context) == 2
         assert context[1].content == "The answer is 4."
 
+    def test_builds_context_with_tool_calls_from_generation_data(self):
+        """Should reconstruct full conversation including tool calls when generation_data is provided."""
+        from core.model_runtime.entities.llm_entities import LLMUsage
+        from core.model_runtime.entities.message_entities import (
+            AssistantPromptMessage,
+            ToolPromptMessage,
+        )
+        from core.workflow.nodes.llm.entities import (
+            LLMGenerationData,
+            LLMTraceSegment,
+            ModelTraceSegment,
+            ToolCall,
+            ToolTraceSegment,
+        )
+
+        messages = [UserPromptMessage(content="What's the weather in Beijing?")]
+
+        # Create trace with tool call and result
+        generation_data = LLMGenerationData(
+            text="The weather in Beijing is sunny, 25°C.",
+            reasoning_contents=[],
+            tool_calls=[],
+            sequence=[],
+            usage=LLMUsage.empty_usage(),
+            finish_reason="stop",
+            files=[],
+            trace=[
+                LLMTraceSegment(
+                    type="model",
+                    duration=0.5,
+                    usage=None,
+                    output=ModelTraceSegment(
+                        text="Let me check the weather.",
+                        reasoning=None,
+                        tool_calls=[
+                            ToolCall(
+                                id="call_123",
+                                name="get_weather",
+                                arguments='{"city": "Beijing"}',
+                            )
+                        ],
+                    ),
+                ),
+                LLMTraceSegment(
+                    type="tool",
+                    duration=0.3,
+                    usage=None,
+                    output=ToolTraceSegment(
+                        id="call_123",
+                        name="get_weather",
+                        arguments='{"city": "Beijing"}',
+                        output="Sunny, 25°C",
+                    ),
+                ),
+            ],
+        )
+
+        context = build_context(messages, "The weather in Beijing is sunny, 25°C.", generation_data)
+
+        # Should have: user message + assistant with tool_call + tool result + final assistant
+        assert len(context) == 4
+        assert context[0].content == "What's the weather in Beijing?"
+        assert isinstance(context[1], AssistantPromptMessage)
+        assert context[1].content == "Let me check the weather."
+        assert len(context[1].tool_calls) == 1
+        assert context[1].tool_calls[0].id == "call_123"
+        assert context[1].tool_calls[0].function.name == "get_weather"
+        assert isinstance(context[2], ToolPromptMessage)
+        assert context[2].content == "Sunny, 25°C"
+        assert context[2].tool_call_id == "call_123"
+        assert isinstance(context[3], AssistantPromptMessage)
+        assert context[3].content == "The weather in Beijing is sunny, 25°C."
+
+    def test_builds_context_with_multiple_tool_calls(self):
+        """Should handle multiple tool calls in a single conversation."""
+        from core.model_runtime.entities.llm_entities import LLMUsage
+        from core.model_runtime.entities.message_entities import (
+            AssistantPromptMessage,
+            ToolPromptMessage,
+        )
+        from core.workflow.nodes.llm.entities import (
+            LLMGenerationData,
+            LLMTraceSegment,
+            ModelTraceSegment,
+            ToolCall,
+            ToolTraceSegment,
+        )
+
+        messages = [UserPromptMessage(content="Compare weather in Beijing and Shanghai")]
+
+        generation_data = LLMGenerationData(
+            text="Beijing is sunny at 25°C, Shanghai is cloudy at 22°C.",
+            reasoning_contents=[],
+            tool_calls=[],
+            sequence=[],
+            usage=LLMUsage.empty_usage(),
+            finish_reason="stop",
+            files=[],
+            trace=[
+                # First model call with two tool calls
+                LLMTraceSegment(
+                    type="model",
+                    duration=0.5,
+                    usage=None,
+                    output=ModelTraceSegment(
+                        text="I'll check both cities.",
+                        reasoning=None,
+                        tool_calls=[
+                            ToolCall(id="call_1", name="get_weather", arguments='{"city": "Beijing"}'),
+                            ToolCall(id="call_2", name="get_weather", arguments='{"city": "Shanghai"}'),
+                        ],
+                    ),
+                ),
+                # First tool result
+                LLMTraceSegment(
+                    type="tool",
+                    duration=0.2,
+                    usage=None,
+                    output=ToolTraceSegment(
+                        id="call_1",
+                        name="get_weather",
+                        arguments='{"city": "Beijing"}',
+                        output="Sunny, 25°C",
+                    ),
+                ),
+                # Second tool result
+                LLMTraceSegment(
+                    type="tool",
+                    duration=0.2,
+                    usage=None,
+                    output=ToolTraceSegment(
+                        id="call_2",
+                        name="get_weather",
+                        arguments='{"city": "Shanghai"}',
+                        output="Cloudy, 22°C",
+                    ),
+                ),
+            ],
+        )
+
+        context = build_context(messages, "Beijing is sunny at 25°C, Shanghai is cloudy at 22°C.", generation_data)
+
+        # Should have: user + assistant with 2 tool_calls + 2 tool results + final assistant
+        assert len(context) == 5
+        assert context[0].content == "Compare weather in Beijing and Shanghai"
+        assert isinstance(context[1], AssistantPromptMessage)
+        assert len(context[1].tool_calls) == 2
+        assert isinstance(context[2], ToolPromptMessage)
+        assert context[2].content == "Sunny, 25°C"
+        assert isinstance(context[3], ToolPromptMessage)
+        assert context[3].content == "Cloudy, 22°C"
+        assert isinstance(context[4], AssistantPromptMessage)
+        assert context[4].content == "Beijing is sunny at 25°C, Shanghai is cloudy at 22°C."
+
+    def test_builds_context_without_generation_data(self):
+        """Should fallback to simple context when no generation_data is provided."""
+        messages = [UserPromptMessage(content="Hello!")]
+
+        context = build_context(messages, "Hi there!", generation_data=None)
+
+        assert len(context) == 2
+        assert context[0].content == "Hello!"
+        assert context[1].content == "Hi there!"
+
+    def test_builds_context_with_empty_trace(self):
+        """Should fallback to simple context when trace is empty."""
+        from core.model_runtime.entities.llm_entities import LLMUsage
+        from core.workflow.nodes.llm.entities import LLMGenerationData
+
+        messages = [UserPromptMessage(content="Hello!")]
+
+        generation_data = LLMGenerationData(
+            text="Hi there!",
+            reasoning_contents=[],
+            tool_calls=[],
+            sequence=[],
+            usage=LLMUsage.empty_usage(),
+            finish_reason="stop",
+            files=[],
+            trace=[],  # Empty trace
+        )
+
+        context = build_context(messages, "Hi there!", generation_data)
+
+        # Should fallback to simple context
+        assert len(context) == 2
+        assert context[0].content == "Hello!"
+        assert context[1].content == "Hi there!"
+
 
 class TestRestoreMultimodalContentInMessages:
     """Tests for restore_multimodal_content_in_messages function."""
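
For reference, below is a minimal, self-contained sketch of the reconstruction idea behind the new build_context path for tool-enabled runs: a trace of "model" and "tool" segments plus the final answer is reordered into user -> assistant(with tool_calls) -> tool -> assistant messages. The names Segment, ToolCall, and rebuild_context here are simplified stand-ins invented for illustration; they are not the repository's actual entities or API, which the diff above defines.

# Illustrative sketch only: simplified stand-ins, not Dify's real classes.
from dataclasses import dataclass, field


@dataclass
class ToolCall:
    id: str
    name: str
    arguments: str


@dataclass
class Segment:
    type: str  # "model" or "tool"
    text: str = ""
    tool_calls: list[ToolCall] = field(default_factory=list)
    tool_call_id: str = ""
    output: str = ""


def rebuild_context(user_query: str, trace: list[Segment], final_answer: str) -> list[dict]:
    """Reorder a tool-enabled run into chat messages: user -> assistant(tool_calls) -> tool -> assistant."""
    messages: list[dict] = [{"role": "user", "content": user_query}]
    for seg in trace:
        if seg.type == "model" and seg.tool_calls:
            # A model segment that requested tools becomes an assistant message carrying the tool calls.
            messages.append(
                {
                    "role": "assistant",
                    "content": seg.text,
                    "tool_calls": [{"id": tc.id, "name": tc.name, "arguments": tc.arguments} for tc in seg.tool_calls],
                }
            )
        elif seg.type == "tool":
            # Each tool segment becomes a tool message keyed by the originating call id.
            messages.append({"role": "tool", "tool_call_id": seg.tool_call_id, "content": seg.output})
    # The final answer is appended last as the authoritative assistant text.
    messages.append({"role": "assistant", "content": final_answer})
    return messages


if __name__ == "__main__":
    trace = [
        Segment(type="model", text="Let me check the weather.",
                tool_calls=[ToolCall("call_123", "get_weather", '{"city": "Beijing"}')]),
        Segment(type="tool", tool_call_id="call_123", output="Sunny, 25°C"),
    ]
    for m in rebuild_context("What's the weather in Beijing?", trace, "It is sunny, 25°C."):
        print(m)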