chore(llm): include tool calls and tool results when building LLM node memory context

This commit is contained in:
Novice
2026-01-20 13:52:15 +08:00
parent 8154d0af53
commit 27de07e93d
4 changed files with 260 additions and 6 deletions

View File

@@ -18,13 +18,14 @@ from core.model_runtime.entities.message_entities import (
PromptMessage,
PromptMessageContentUnionTypes,
PromptMessageRole,
ToolPromptMessage,
)
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.prompt.entities.advanced_prompt_entities import MemoryConfig, MemoryMode
from core.variables.segments import ArrayAnySegment, ArrayFileSegment, FileSegment, NoneSegment, StringSegment
from core.workflow.enums import SystemVariableKey
from core.workflow.nodes.llm.entities import ModelConfig
from core.workflow.nodes.llm.entities import LLMGenerationData, ModelConfig
from core.workflow.runtime import VariablePool
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
@@ -214,21 +215,87 @@ def deduct_llm_quota(tenant_id: str, model_instance: ModelInstance, usage: LLMUs
def build_context(
    prompt_messages: Sequence[PromptMessage],
    assistant_response: str,
    generation_data: LLMGenerationData | None = None,
) -> list[PromptMessage]:
    """
    Build conversation context from prompt messages and the assistant response.

    Excludes system messages and appends the current LLM response exactly once.
    For tool-enabled runs (``generation_data`` with a non-empty trace), the full
    conversation — assistant tool calls and tool results — is reconstructed from
    the trace; otherwise the plain assistant response is appended.

    Note: multi-modal content base64 data is truncated to avoid storing large
    payloads in the context.

    Args:
        prompt_messages: Initial prompt messages (user query, etc.)
        assistant_response: Final assistant response text
        generation_data: Optional generation data containing the trace for
            tool-enabled runs

    Returns:
        list[PromptMessage] for use with ArrayPromptMessageSegment.
    """
    # Drop system messages and truncate embedded base64 payloads.
    context_messages: list[PromptMessage] = [
        _truncate_multimodal_content(m) for m in prompt_messages if m.role != PromptMessageRole.SYSTEM
    ]
    if generation_data and generation_data.trace:
        # Tool-enabled run: rebuild assistant/tool turns from the trace.
        # _build_messages_from_trace appends the final assistant response
        # itself, so it must NOT also be appended here — doing so would
        # duplicate the final assistant message in the stored context.
        context_messages.extend(_build_messages_from_trace(generation_data, assistant_response))
    else:
        context_messages.append(AssistantPromptMessage(content=assistant_response))
    return context_messages
def _build_messages_from_trace(
    generation_data: LLMGenerationData,
    assistant_response: str,
) -> list[PromptMessage]:
    """
    Reconstruct assistant and tool messages from the generation trace.

    Walks the trace segments in order:
    - model segments that carry tool_calls become an AssistantPromptMessage
      with the corresponding tool_calls attached
    - tool segments become a ToolPromptMessage holding the tool result

    The authoritative ``assistant_response`` is appended last as the closing
    AssistantPromptMessage.
    """
    from core.workflow.nodes.llm.entities import ModelTraceSegment, ToolTraceSegment

    rebuilt: list[PromptMessage] = []
    for entry in generation_data.trace:
        payload = entry.output
        if entry.type == "model" and isinstance(payload, ModelTraceSegment):
            if payload.tool_calls:
                # Mirror each traced call into the prompt-message ToolCall shape.
                calls: list[AssistantPromptMessage.ToolCall] = []
                for call in payload.tool_calls:
                    fn = AssistantPromptMessage.ToolCall.ToolCallFunction(
                        name=call.name or "",
                        arguments=call.arguments or "",
                    )
                    calls.append(
                        AssistantPromptMessage.ToolCall(id=call.id or "", type="function", function=fn)
                    )
                rebuilt.append(AssistantPromptMessage(content=payload.text or "", tool_calls=calls))
        elif entry.type == "tool" and isinstance(payload, ToolTraceSegment):
            rebuilt.append(
                ToolPromptMessage(
                    content=payload.output or "",
                    tool_call_id=payload.id or "",
                    name=payload.name or "",
                )
            )
    # Close with the final assistant response as the authoritative text.
    rebuilt.append(AssistantPromptMessage(content=assistant_response))
    return rebuilt
def _truncate_multimodal_content(message: PromptMessage) -> PromptMessage:
"""
Truncate multi-modal content base64 data in a message to avoid storing large data.

View File

@@ -373,7 +373,7 @@ class LLMNode(Node[LLMNodeData]):
"reasoning_content": reasoning_content,
"usage": jsonable_encoder(usage),
"finish_reason": finish_reason,
"context": llm_utils.build_context(prompt_messages, clean_text),
"context": llm_utils.build_context(prompt_messages, clean_text, generation_data),
}
# Build generation field

View File

@@ -171,9 +171,7 @@ class TestSandboxLayer:
layer.on_event(GraphRunSucceededEvent(outputs={}))
layer.on_event(GraphRunFailedEvent(error="test error", exceptions_count=1))
def test_on_graph_end_releases_sandbox_and_unregisters_from_manager(
self, mock_sandbox_storage: MagicMock
) -> None:
def test_on_graph_end_releases_sandbox_and_unregisters_from_manager(self, mock_sandbox_storage: MagicMock) -> None:
sandbox_id = "test-exec-456"
layer = create_layer(sandbox_id=sandbox_id, sandbox_storage=mock_sandbox_storage)
mock_sandbox = MagicMock(spec=VirtualEnvironment)

View File

@@ -123,6 +123,195 @@ class TestBuildContext:
assert len(context) == 2
assert context[1].content == "The answer is 4."
def test_builds_context_with_tool_calls_from_generation_data(self):
    """Should reconstruct full conversation including tool calls when generation_data is provided."""
    from core.model_runtime.entities.llm_entities import LLMUsage
    from core.model_runtime.entities.message_entities import (
        AssistantPromptMessage,
        ToolPromptMessage,
    )
    from core.workflow.nodes.llm.entities import (
        LLMGenerationData,
        LLMTraceSegment,
        ModelTraceSegment,
        ToolCall,
        ToolTraceSegment,
    )

    prompts = [UserPromptMessage(content="What's the weather in Beijing?")]
    # Trace: one model turn emitting a tool call, then the tool's result.
    gen_data = LLMGenerationData(
        text="The weather in Beijing is sunny, 25°C.",
        reasoning_contents=[],
        tool_calls=[],
        sequence=[],
        usage=LLMUsage.empty_usage(),
        finish_reason="stop",
        files=[],
        trace=[
            LLMTraceSegment(
                type="model",
                duration=0.5,
                usage=None,
                output=ModelTraceSegment(
                    text="Let me check the weather.",
                    reasoning=None,
                    tool_calls=[
                        ToolCall(
                            id="call_123",
                            name="get_weather",
                            arguments='{"city": "Beijing"}',
                        )
                    ],
                ),
            ),
            LLMTraceSegment(
                type="tool",
                duration=0.3,
                usage=None,
                output=ToolTraceSegment(
                    id="call_123",
                    name="get_weather",
                    arguments='{"city": "Beijing"}',
                    output="Sunny, 25°C",
                ),
            ),
        ],
    )

    ctx = build_context(prompts, "The weather in Beijing is sunny, 25°C.", gen_data)

    # Expected shape: user -> assistant(tool_call) -> tool result -> final assistant.
    assert len(ctx) == 4
    assert ctx[0].content == "What's the weather in Beijing?"
    assert isinstance(ctx[1], AssistantPromptMessage)
    assert ctx[1].content == "Let me check the weather."
    assert len(ctx[1].tool_calls) == 1
    assert ctx[1].tool_calls[0].id == "call_123"
    assert ctx[1].tool_calls[0].function.name == "get_weather"
    assert isinstance(ctx[2], ToolPromptMessage)
    assert ctx[2].content == "Sunny, 25°C"
    assert ctx[2].tool_call_id == "call_123"
    assert isinstance(ctx[3], AssistantPromptMessage)
    assert ctx[3].content == "The weather in Beijing is sunny, 25°C."
def test_builds_context_with_multiple_tool_calls(self):
    """Should handle multiple tool calls in a single conversation."""
    from core.model_runtime.entities.llm_entities import LLMUsage
    from core.model_runtime.entities.message_entities import (
        AssistantPromptMessage,
        ToolPromptMessage,
    )
    from core.workflow.nodes.llm.entities import (
        LLMGenerationData,
        LLMTraceSegment,
        ModelTraceSegment,
        ToolCall,
        ToolTraceSegment,
    )

    prompts = [UserPromptMessage(content="Compare weather in Beijing and Shanghai")]
    # Trace: one model turn fanning out two tool calls, then both results.
    gen_data = LLMGenerationData(
        text="Beijing is sunny at 25°C, Shanghai is cloudy at 22°C.",
        reasoning_contents=[],
        tool_calls=[],
        sequence=[],
        usage=LLMUsage.empty_usage(),
        finish_reason="stop",
        files=[],
        trace=[
            # First model call with two tool calls
            LLMTraceSegment(
                type="model",
                duration=0.5,
                usage=None,
                output=ModelTraceSegment(
                    text="I'll check both cities.",
                    reasoning=None,
                    tool_calls=[
                        ToolCall(id="call_1", name="get_weather", arguments='{"city": "Beijing"}'),
                        ToolCall(id="call_2", name="get_weather", arguments='{"city": "Shanghai"}'),
                    ],
                ),
            ),
            # First tool result
            LLMTraceSegment(
                type="tool",
                duration=0.2,
                usage=None,
                output=ToolTraceSegment(
                    id="call_1",
                    name="get_weather",
                    arguments='{"city": "Beijing"}',
                    output="Sunny, 25°C",
                ),
            ),
            # Second tool result
            LLMTraceSegment(
                type="tool",
                duration=0.2,
                usage=None,
                output=ToolTraceSegment(
                    id="call_2",
                    name="get_weather",
                    arguments='{"city": "Shanghai"}',
                    output="Cloudy, 22°C",
                ),
            ),
        ],
    )

    ctx = build_context(prompts, "Beijing is sunny at 25°C, Shanghai is cloudy at 22°C.", gen_data)

    # Expected shape: user -> assistant(2 tool_calls) -> 2 tool results -> final assistant.
    assert len(ctx) == 5
    assert ctx[0].content == "Compare weather in Beijing and Shanghai"
    assert isinstance(ctx[1], AssistantPromptMessage)
    assert len(ctx[1].tool_calls) == 2
    assert isinstance(ctx[2], ToolPromptMessage)
    assert ctx[2].content == "Sunny, 25°C"
    assert isinstance(ctx[3], ToolPromptMessage)
    assert ctx[3].content == "Cloudy, 22°C"
    assert isinstance(ctx[4], AssistantPromptMessage)
    assert ctx[4].content == "Beijing is sunny at 25°C, Shanghai is cloudy at 22°C."
def test_builds_context_without_generation_data(self):
    """Should fallback to simple context when no generation_data is provided."""
    prompts = [UserPromptMessage(content="Hello!")]

    ctx = build_context(prompts, "Hi there!", generation_data=None)

    # Simple two-message context: user message + assistant response.
    assert len(ctx) == 2
    assert ctx[0].content == "Hello!"
    assert ctx[1].content == "Hi there!"
def test_builds_context_with_empty_trace(self):
    """Should fallback to simple context when trace is empty."""
    from core.model_runtime.entities.llm_entities import LLMUsage
    from core.workflow.nodes.llm.entities import LLMGenerationData

    prompts = [UserPromptMessage(content="Hello!")]
    gen_data = LLMGenerationData(
        text="Hi there!",
        reasoning_contents=[],
        tool_calls=[],
        sequence=[],
        usage=LLMUsage.empty_usage(),
        finish_reason="stop",
        files=[],
        trace=[],  # Empty trace
    )

    ctx = build_context(prompts, "Hi there!", gen_data)

    # Empty trace must take the simple (non-tool) path.
    assert len(ctx) == 2
    assert ctx[0].content == "Hello!"
    assert ctx[1].content == "Hi there!"
class TestRestoreMultimodalContentInMessages:
"""Tests for restore_multimodal_content_in_messages function."""