Mirror of https://github.com/langgenius/dify.git (synced 2026-02-17 19:00:55 -05:00)
chore: fix the llm node memory issue
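Previously, build_context stored only the non-system prompt messages plus the final assistant text, so for tool-enabled runs the intermediate tool_call messages and tool results were dropped from node memory. build_context now accepts the node's LLMGenerationData and, when its trace is non-empty, replays that trace into assistant tool_call messages and ToolPromptMessage results before appending the final response; without a trace it falls back to the previous two-part behavior.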
@@ -18,13 +18,14 @@ from core.model_runtime.entities.message_entities import (
     PromptMessage,
     PromptMessageContentUnionTypes,
     PromptMessageRole,
+    ToolPromptMessage,
 )
 from core.model_runtime.entities.model_entities import ModelType
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from core.prompt.entities.advanced_prompt_entities import MemoryConfig, MemoryMode
 from core.variables.segments import ArrayAnySegment, ArrayFileSegment, FileSegment, NoneSegment, StringSegment
 from core.workflow.enums import SystemVariableKey
-from core.workflow.nodes.llm.entities import ModelConfig
+from core.workflow.nodes.llm.entities import LLMGenerationData, ModelConfig
 from core.workflow.runtime import VariablePool
 from extensions.ext_database import db
 from libs.datetime_utils import naive_utc_now
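(The ToolPromptMessage and LLMGenerationData imports added here support the trace replay introduced in the next hunk.)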
@@ -214,21 +215,87 @@ def deduct_llm_quota(tenant_id: str, model_instance: ModelInstance, usage: LLMUsage):
 def build_context(
     prompt_messages: Sequence[PromptMessage],
     assistant_response: str,
+    generation_data: LLMGenerationData | None = None,
 ) -> list[PromptMessage]:
     """
     Build context from prompt messages and assistant response.
     Excludes system messages and includes the current LLM response.
     Returns list[PromptMessage] for use with ArrayPromptMessageSegment.
+
+    For tool-enabled runs, reconstructs the full conversation including tool calls and results.
+    Note: Multi-modal content base64 data is truncated to avoid storing large data in context.
+
+    Args:
+        prompt_messages: Initial prompt messages (user query, etc.)
+        assistant_response: Final assistant response text
+        generation_data: Optional generation data containing trace for tool-enabled runs
     """
 
     context_messages: list[PromptMessage] = [
         _truncate_multimodal_content(m) for m in prompt_messages if m.role != PromptMessageRole.SYSTEM
     ]
-    context_messages.append(AssistantPromptMessage(content=assistant_response))
+
+    # For tool-enabled runs, reconstruct messages from trace
+    if generation_data and generation_data.trace:
+        context_messages.extend(_build_messages_from_trace(generation_data, assistant_response))
+    else:
+        context_messages.append(AssistantPromptMessage(content=assistant_response))
+
     return context_messages
+
+
+def _build_messages_from_trace(
+    generation_data: LLMGenerationData,
+    assistant_response: str,
+) -> list[PromptMessage]:
+    """
+    Build assistant and tool messages from trace segments.
+
+    Processes trace in order to reconstruct the conversation flow:
+    - Model segments with tool_calls -> AssistantPromptMessage with tool_calls
+    - Tool segments -> ToolPromptMessage with result
+    - Final response -> AssistantPromptMessage with assistant_response
+    """
+    from core.workflow.nodes.llm.entities import ModelTraceSegment, ToolTraceSegment
+
+    messages: list[PromptMessage] = []
+
+    for segment in generation_data.trace:
+        if segment.type == "model" and isinstance(segment.output, ModelTraceSegment):
+            model_output = segment.output
+            segment_content = model_output.text or ""
+
+            if model_output.tool_calls:
+                # Build tool_calls for AssistantPromptMessage
+                tool_calls = [
+                    AssistantPromptMessage.ToolCall(
+                        id=tc.id or "",
+                        type="function",
+                        function=AssistantPromptMessage.ToolCall.ToolCallFunction(
+                            name=tc.name or "",
+                            arguments=tc.arguments or "",
+                        ),
+                    )
+                    for tc in model_output.tool_calls
+                ]
+                messages.append(AssistantPromptMessage(content=segment_content, tool_calls=tool_calls))
+
+        elif segment.type == "tool" and isinstance(segment.output, ToolTraceSegment):
+            tool_output = segment.output
+            messages.append(
+                ToolPromptMessage(
+                    content=tool_output.output or "",
+                    tool_call_id=tool_output.id or "",
+                    name=tool_output.name or "",
+                )
+            )
+
+    # Add final assistant response as the authoritative text
+    messages.append(AssistantPromptMessage(content=assistant_response))
+
+    return messages
+
+
 def _truncate_multimodal_content(message: PromptMessage) -> PromptMessage:
     """
     Truncate multi-modal content base64 data in a message to avoid storing large data.
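To see the replay logic in isolation: the sketch below mirrors the control flow of _build_messages_from_trace, but uses hypothetical stand-in classes (StubSegment, StubModelOut, StubToolOut, rebuild) instead of dify's real entities, so it runs standalone on Python 3.10+. It illustrates the pattern, not the project's API.

# Minimal stand-in sketch of the trace -> messages reconstruction above.
# All names here are hypothetical, NOT dify entities; they only mirror
# the fields the diff relies on (type, text, tool_calls, output).
from dataclasses import dataclass, field


@dataclass
class StubToolCall:
    id: str
    name: str
    arguments: str


@dataclass
class StubModelOut:
    text: str
    tool_calls: list[StubToolCall] = field(default_factory=list)


@dataclass
class StubToolOut:
    id: str
    name: str
    output: str


@dataclass
class StubSegment:
    type: str  # "model" or "tool"
    output: StubModelOut | StubToolOut


def rebuild(trace: list[StubSegment], final_text: str) -> list[tuple[str, str]]:
    """Walk the trace in order, emitting (role, payload) pairs the way
    _build_messages_from_trace emits PromptMessage objects."""
    messages: list[tuple[str, str]] = []
    for seg in trace:
        if seg.type == "model" and isinstance(seg.output, StubModelOut):
            # Only model turns that actually requested tools are kept.
            if seg.output.tool_calls:
                calls = ", ".join(f"{c.name}({c.arguments})" for c in seg.output.tool_calls)
                messages.append(("assistant", f"{seg.output.text} [tool_calls: {calls}]"))
        elif seg.type == "tool" and isinstance(seg.output, StubToolOut):
            messages.append(("tool", seg.output.output))
    # The final assistant response is always appended as the authoritative text.
    messages.append(("assistant", final_text))
    return messages


trace = [
    StubSegment("model", StubModelOut("Let me check.", [StubToolCall("call_123", "get_weather", '{"city": "Beijing"}')])),
    StubSegment("tool", StubToolOut("call_123", "get_weather", "Sunny, 25°C")),
]
print(rebuild(trace, "The weather in Beijing is sunny, 25°C."))

The ordering invariant matches the diff: each tool result is preceded by an assistant message carrying the matching tool_call id, and the final text comes last, which is the shape most chat-completion APIs require when the context is replayed.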
@@ -373,7 +373,7 @@ class LLMNode(Node[LLMNodeData]):
             "reasoning_content": reasoning_content,
             "usage": jsonable_encoder(usage),
             "finish_reason": finish_reason,
-            "context": llm_utils.build_context(prompt_messages, clean_text),
+            "context": llm_utils.build_context(prompt_messages, clean_text, generation_data),
         }
 
         # Build generation field
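The call site above simply threads the node's generation_data through. Separately, build_context's docstring notes that multi-modal base64 payloads are truncated before the context is stored; the diff cuts off _truncate_multimodal_content's body, but the idea is roughly the following sketch, with hypothetical field names and threshold (not the repo's actual implementation):

# Rough sketch of base64 truncation for stored context. The content model,
# the `data` field, and the 64-char threshold are assumptions for
# illustration; the real _truncate_multimodal_content body is not shown here.
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class StubImageContent:
    data: str  # may hold a large base64 payload


def truncate_content(content: StubImageContent, keep: int = 64) -> StubImageContent:
    if len(content.data) > keep:
        # Keep a small prefix so the stored context stays readable but small.
        return replace(content, data=content.data[:keep] + "...[truncated]")
    return content


big = StubImageContent(data="iVBORw0KGgo" * 100)
print(len(truncate_content(big).data))  # 64 + len("...[truncated]") == 78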
@@ -171,9 +171,7 @@ class TestSandboxLayer:
         layer.on_event(GraphRunSucceededEvent(outputs={}))
         layer.on_event(GraphRunFailedEvent(error="test error", exceptions_count=1))
 
-    def test_on_graph_end_releases_sandbox_and_unregisters_from_manager(
-        self, mock_sandbox_storage: MagicMock
-    ) -> None:
+    def test_on_graph_end_releases_sandbox_and_unregisters_from_manager(self, mock_sandbox_storage: MagicMock) -> None:
         sandbox_id = "test-exec-456"
         layer = create_layer(sandbox_id=sandbox_id, sandbox_storage=mock_sandbox_storage)
         mock_sandbox = MagicMock(spec=VirtualEnvironment)
@@ -123,6 +123,195 @@ class TestBuildContext:
         assert len(context) == 2
         assert context[1].content == "The answer is 4."
 
+    def test_builds_context_with_tool_calls_from_generation_data(self):
+        """Should reconstruct full conversation including tool calls when generation_data is provided."""
+        from core.model_runtime.entities.llm_entities import LLMUsage
+        from core.model_runtime.entities.message_entities import (
+            AssistantPromptMessage,
+            ToolPromptMessage,
+        )
+        from core.workflow.nodes.llm.entities import (
+            LLMGenerationData,
+            LLMTraceSegment,
+            ModelTraceSegment,
+            ToolCall,
+            ToolTraceSegment,
+        )
+
+        messages = [UserPromptMessage(content="What's the weather in Beijing?")]
+
+        # Create trace with tool call and result
+        generation_data = LLMGenerationData(
+            text="The weather in Beijing is sunny, 25°C.",
+            reasoning_contents=[],
+            tool_calls=[],
+            sequence=[],
+            usage=LLMUsage.empty_usage(),
+            finish_reason="stop",
+            files=[],
+            trace=[
+                LLMTraceSegment(
+                    type="model",
+                    duration=0.5,
+                    usage=None,
+                    output=ModelTraceSegment(
+                        text="Let me check the weather.",
+                        reasoning=None,
+                        tool_calls=[
+                            ToolCall(
+                                id="call_123",
+                                name="get_weather",
+                                arguments='{"city": "Beijing"}',
+                            )
+                        ],
+                    ),
+                ),
+                LLMTraceSegment(
+                    type="tool",
+                    duration=0.3,
+                    usage=None,
+                    output=ToolTraceSegment(
+                        id="call_123",
+                        name="get_weather",
+                        arguments='{"city": "Beijing"}',
+                        output="Sunny, 25°C",
+                    ),
+                ),
+            ],
+        )
+
+        context = build_context(messages, "The weather in Beijing is sunny, 25°C.", generation_data)
+
+        # Should have: user message + assistant with tool_call + tool result + final assistant
+        assert len(context) == 4
+        assert context[0].content == "What's the weather in Beijing?"
+        assert isinstance(context[1], AssistantPromptMessage)
+        assert context[1].content == "Let me check the weather."
+        assert len(context[1].tool_calls) == 1
+        assert context[1].tool_calls[0].id == "call_123"
+        assert context[1].tool_calls[0].function.name == "get_weather"
+        assert isinstance(context[2], ToolPromptMessage)
+        assert context[2].content == "Sunny, 25°C"
+        assert context[2].tool_call_id == "call_123"
+        assert isinstance(context[3], AssistantPromptMessage)
+        assert context[3].content == "The weather in Beijing is sunny, 25°C."
+
+    def test_builds_context_with_multiple_tool_calls(self):
+        """Should handle multiple tool calls in a single conversation."""
+        from core.model_runtime.entities.llm_entities import LLMUsage
+        from core.model_runtime.entities.message_entities import (
+            AssistantPromptMessage,
+            ToolPromptMessage,
+        )
+        from core.workflow.nodes.llm.entities import (
+            LLMGenerationData,
+            LLMTraceSegment,
+            ModelTraceSegment,
+            ToolCall,
+            ToolTraceSegment,
+        )
+
+        messages = [UserPromptMessage(content="Compare weather in Beijing and Shanghai")]
+
+        generation_data = LLMGenerationData(
+            text="Beijing is sunny at 25°C, Shanghai is cloudy at 22°C.",
+            reasoning_contents=[],
+            tool_calls=[],
+            sequence=[],
+            usage=LLMUsage.empty_usage(),
+            finish_reason="stop",
+            files=[],
+            trace=[
+                # First model call with two tool calls
+                LLMTraceSegment(
+                    type="model",
+                    duration=0.5,
+                    usage=None,
+                    output=ModelTraceSegment(
+                        text="I'll check both cities.",
+                        reasoning=None,
+                        tool_calls=[
+                            ToolCall(id="call_1", name="get_weather", arguments='{"city": "Beijing"}'),
+                            ToolCall(id="call_2", name="get_weather", arguments='{"city": "Shanghai"}'),
+                        ],
+                    ),
+                ),
+                # First tool result
+                LLMTraceSegment(
+                    type="tool",
+                    duration=0.2,
+                    usage=None,
+                    output=ToolTraceSegment(
+                        id="call_1",
+                        name="get_weather",
+                        arguments='{"city": "Beijing"}',
+                        output="Sunny, 25°C",
+                    ),
+                ),
+                # Second tool result
+                LLMTraceSegment(
+                    type="tool",
+                    duration=0.2,
+                    usage=None,
+                    output=ToolTraceSegment(
+                        id="call_2",
+                        name="get_weather",
+                        arguments='{"city": "Shanghai"}',
+                        output="Cloudy, 22°C",
+                    ),
+                ),
+            ],
+        )
+
+        context = build_context(messages, "Beijing is sunny at 25°C, Shanghai is cloudy at 22°C.", generation_data)
+
+        # Should have: user + assistant with 2 tool_calls + 2 tool results + final assistant
+        assert len(context) == 5
+        assert context[0].content == "Compare weather in Beijing and Shanghai"
+        assert isinstance(context[1], AssistantPromptMessage)
+        assert len(context[1].tool_calls) == 2
+        assert isinstance(context[2], ToolPromptMessage)
+        assert context[2].content == "Sunny, 25°C"
+        assert isinstance(context[3], ToolPromptMessage)
+        assert context[3].content == "Cloudy, 22°C"
+        assert isinstance(context[4], AssistantPromptMessage)
+        assert context[4].content == "Beijing is sunny at 25°C, Shanghai is cloudy at 22°C."
+
+    def test_builds_context_without_generation_data(self):
+        """Should fallback to simple context when no generation_data is provided."""
+        messages = [UserPromptMessage(content="Hello!")]
+
+        context = build_context(messages, "Hi there!", generation_data=None)
+
+        assert len(context) == 2
+        assert context[0].content == "Hello!"
+        assert context[1].content == "Hi there!"
+
+    def test_builds_context_with_empty_trace(self):
+        """Should fallback to simple context when trace is empty."""
+        from core.model_runtime.entities.llm_entities import LLMUsage
+        from core.workflow.nodes.llm.entities import LLMGenerationData
+
+        messages = [UserPromptMessage(content="Hello!")]
+
+        generation_data = LLMGenerationData(
+            text="Hi there!",
+            reasoning_contents=[],
+            tool_calls=[],
+            sequence=[],
+            usage=LLMUsage.empty_usage(),
+            finish_reason="stop",
+            files=[],
+            trace=[],  # Empty trace
+        )
+
+        context = build_context(messages, "Hi there!", generation_data)
+
+        # Should fallback to simple context
+        assert len(context) == 2
+        assert context[0].content == "Hello!"
+        assert context[1].content == "Hi there!"
+
 
 class TestRestoreMultimodalContentInMessages:
     """Tests for restore_multimodal_content_in_messages function."""
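Assuming these cases sit in the existing llm_utils test module, something like `pytest -k TestBuildContext` from the test directory runs just the behavior above; build_context only manipulates message objects, so no model provider or database is needed.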