mirror of
https://github.com/langgenius/dify.git
synced 2025-12-19 17:27:16 -05:00
perf: remove the n+1 query (#29483)
This commit is contained in:
@@ -835,7 +835,29 @@ class Conversation(Base):
|
||||
|
||||
@property
|
||||
def status_count(self):
|
||||
messages = db.session.scalars(select(Message).where(Message.conversation_id == self.id)).all()
|
||||
from models.workflow import WorkflowRun
|
||||
|
||||
# Get all messages with workflow_run_id for this conversation
|
||||
messages = db.session.scalars(
|
||||
select(Message).where(Message.conversation_id == self.id, Message.workflow_run_id.isnot(None))
|
||||
).all()
|
||||
|
||||
if not messages:
|
||||
return None
|
||||
|
||||
# Batch load all workflow runs in a single query, filtered by this conversation's app_id
|
||||
workflow_run_ids = [msg.workflow_run_id for msg in messages if msg.workflow_run_id]
|
||||
workflow_runs = {}
|
||||
|
||||
if workflow_run_ids:
|
||||
workflow_runs_query = db.session.scalars(
|
||||
select(WorkflowRun).where(
|
||||
WorkflowRun.id.in_(workflow_run_ids),
|
||||
WorkflowRun.app_id == self.app_id, # Filter by this conversation's app_id
|
||||
)
|
||||
).all()
|
||||
workflow_runs = {run.id: run for run in workflow_runs_query}
|
||||
|
||||
status_counts = {
|
||||
WorkflowExecutionStatus.RUNNING: 0,
|
||||
WorkflowExecutionStatus.SUCCEEDED: 0,
|
||||
@@ -845,18 +867,24 @@ class Conversation(Base):
|
||||
}
|
||||
|
||||
for message in messages:
|
||||
if message.workflow_run:
|
||||
status_counts[WorkflowExecutionStatus(message.workflow_run.status)] += 1
|
||||
# Guard against None to satisfy type checker and avoid invalid dict lookups
|
||||
if message.workflow_run_id is None:
|
||||
continue
|
||||
workflow_run = workflow_runs.get(message.workflow_run_id)
|
||||
if not workflow_run:
|
||||
continue
|
||||
|
||||
return (
|
||||
{
|
||||
"success": status_counts[WorkflowExecutionStatus.SUCCEEDED],
|
||||
"failed": status_counts[WorkflowExecutionStatus.FAILED],
|
||||
"partial_success": status_counts[WorkflowExecutionStatus.PARTIAL_SUCCEEDED],
|
||||
}
|
||||
if messages
|
||||
else None
|
||||
)
|
||||
try:
|
||||
status_counts[WorkflowExecutionStatus(workflow_run.status)] += 1
|
||||
except (ValueError, KeyError):
|
||||
# Handle invalid status values gracefully
|
||||
pass
|
||||
|
||||
return {
|
||||
"success": status_counts[WorkflowExecutionStatus.SUCCEEDED],
|
||||
"failed": status_counts[WorkflowExecutionStatus.FAILED],
|
||||
"partial_success": status_counts[WorkflowExecutionStatus.PARTIAL_SUCCEEDED],
|
||||
}
|
||||
|
||||
@property
|
||||
def first_message(self):
|
||||
|
||||
@@ -1149,3 +1149,258 @@ class TestModelIntegration:
|
||||
# Assert
|
||||
assert site.app_id == app.id
|
||||
assert app.enable_site is True
|
||||
|
||||
|
||||
class TestConversationStatusCount:
|
||||
"""Test suite for Conversation.status_count property N+1 query fix."""
|
||||
|
||||
def test_status_count_no_messages(self):
|
||||
"""Test status_count returns None when conversation has no messages."""
|
||||
# Arrange
|
||||
conversation = Conversation(
|
||||
app_id=str(uuid4()),
|
||||
mode=AppMode.CHAT,
|
||||
name="Test Conversation",
|
||||
status="normal",
|
||||
from_source="api",
|
||||
)
|
||||
conversation.id = str(uuid4())
|
||||
|
||||
# Mock the database query to return no messages
|
||||
with patch("models.model.db.session.scalars") as mock_scalars:
|
||||
mock_scalars.return_value.all.return_value = []
|
||||
|
||||
# Act
|
||||
result = conversation.status_count
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
|
||||
def test_status_count_messages_without_workflow_runs(self):
|
||||
"""Test status_count when messages have no workflow_run_id."""
|
||||
# Arrange
|
||||
app_id = str(uuid4())
|
||||
conversation_id = str(uuid4())
|
||||
|
||||
conversation = Conversation(
|
||||
app_id=app_id,
|
||||
mode=AppMode.CHAT,
|
||||
name="Test Conversation",
|
||||
status="normal",
|
||||
from_source="api",
|
||||
)
|
||||
conversation.id = conversation_id
|
||||
|
||||
# Mock the database query to return no messages with workflow_run_id
|
||||
with patch("models.model.db.session.scalars") as mock_scalars:
|
||||
mock_scalars.return_value.all.return_value = []
|
||||
|
||||
# Act
|
||||
result = conversation.status_count
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
|
||||
def test_status_count_batch_loading_implementation(self):
|
||||
"""Test that status_count uses batch loading instead of N+1 queries."""
|
||||
# Arrange
|
||||
from core.workflow.enums import WorkflowExecutionStatus
|
||||
|
||||
app_id = str(uuid4())
|
||||
conversation_id = str(uuid4())
|
||||
|
||||
# Create workflow run IDs
|
||||
workflow_run_id_1 = str(uuid4())
|
||||
workflow_run_id_2 = str(uuid4())
|
||||
workflow_run_id_3 = str(uuid4())
|
||||
|
||||
conversation = Conversation(
|
||||
app_id=app_id,
|
||||
mode=AppMode.CHAT,
|
||||
name="Test Conversation",
|
||||
status="normal",
|
||||
from_source="api",
|
||||
)
|
||||
conversation.id = conversation_id
|
||||
|
||||
# Mock messages with workflow_run_id
|
||||
mock_messages = [
|
||||
MagicMock(
|
||||
conversation_id=conversation_id,
|
||||
workflow_run_id=workflow_run_id_1,
|
||||
),
|
||||
MagicMock(
|
||||
conversation_id=conversation_id,
|
||||
workflow_run_id=workflow_run_id_2,
|
||||
),
|
||||
MagicMock(
|
||||
conversation_id=conversation_id,
|
||||
workflow_run_id=workflow_run_id_3,
|
||||
),
|
||||
]
|
||||
|
||||
# Mock workflow runs with different statuses
|
||||
mock_workflow_runs = [
|
||||
MagicMock(
|
||||
id=workflow_run_id_1,
|
||||
status=WorkflowExecutionStatus.SUCCEEDED.value,
|
||||
app_id=app_id,
|
||||
),
|
||||
MagicMock(
|
||||
id=workflow_run_id_2,
|
||||
status=WorkflowExecutionStatus.FAILED.value,
|
||||
app_id=app_id,
|
||||
),
|
||||
MagicMock(
|
||||
id=workflow_run_id_3,
|
||||
status=WorkflowExecutionStatus.PARTIAL_SUCCEEDED.value,
|
||||
app_id=app_id,
|
||||
),
|
||||
]
|
||||
|
||||
# Track database calls
|
||||
calls_made = []
|
||||
|
||||
def mock_scalars(query):
|
||||
calls_made.append(str(query))
|
||||
mock_result = MagicMock()
|
||||
|
||||
# Return messages for the first query (messages with workflow_run_id)
|
||||
if "messages" in str(query) and "conversation_id" in str(query):
|
||||
mock_result.all.return_value = mock_messages
|
||||
# Return workflow runs for the batch query
|
||||
elif "workflow_runs" in str(query):
|
||||
mock_result.all.return_value = mock_workflow_runs
|
||||
else:
|
||||
mock_result.all.return_value = []
|
||||
|
||||
return mock_result
|
||||
|
||||
# Act & Assert
|
||||
with patch("models.model.db.session.scalars", side_effect=mock_scalars):
|
||||
result = conversation.status_count
|
||||
|
||||
# Verify only 2 database queries were made (not N+1)
|
||||
assert len(calls_made) == 2, f"Expected 2 queries, got {len(calls_made)}: {calls_made}"
|
||||
|
||||
# Verify the first query gets messages
|
||||
assert "messages" in calls_made[0]
|
||||
assert "conversation_id" in calls_made[0]
|
||||
|
||||
# Verify the second query batch loads workflow runs with proper filtering
|
||||
assert "workflow_runs" in calls_made[1]
|
||||
assert "app_id" in calls_made[1] # Security filter applied
|
||||
assert "IN" in calls_made[1] # Batch loading with IN clause
|
||||
|
||||
# Verify correct status counts
|
||||
assert result["success"] == 1 # One SUCCEEDED
|
||||
assert result["failed"] == 1 # One FAILED
|
||||
assert result["partial_success"] == 1 # One PARTIAL_SUCCEEDED
|
||||
|
||||
def test_status_count_app_id_filtering(self):
|
||||
"""Test that status_count filters workflow runs by app_id for security."""
|
||||
# Arrange
|
||||
app_id = str(uuid4())
|
||||
other_app_id = str(uuid4())
|
||||
conversation_id = str(uuid4())
|
||||
workflow_run_id = str(uuid4())
|
||||
|
||||
conversation = Conversation(
|
||||
app_id=app_id,
|
||||
mode=AppMode.CHAT,
|
||||
name="Test Conversation",
|
||||
status="normal",
|
||||
from_source="api",
|
||||
)
|
||||
conversation.id = conversation_id
|
||||
|
||||
# Mock message with workflow_run_id
|
||||
mock_messages = [
|
||||
MagicMock(
|
||||
conversation_id=conversation_id,
|
||||
workflow_run_id=workflow_run_id,
|
||||
),
|
||||
]
|
||||
|
||||
calls_made = []
|
||||
|
||||
def mock_scalars(query):
|
||||
calls_made.append(str(query))
|
||||
mock_result = MagicMock()
|
||||
|
||||
if "messages" in str(query):
|
||||
mock_result.all.return_value = mock_messages
|
||||
elif "workflow_runs" in str(query):
|
||||
# Return empty list because no workflow run matches the correct app_id
|
||||
mock_result.all.return_value = [] # Workflow run filtered out by app_id
|
||||
else:
|
||||
mock_result.all.return_value = []
|
||||
|
||||
return mock_result
|
||||
|
||||
# Act
|
||||
with patch("models.model.db.session.scalars", side_effect=mock_scalars):
|
||||
result = conversation.status_count
|
||||
|
||||
# Assert - query should include app_id filter
|
||||
workflow_query = calls_made[1]
|
||||
assert "app_id" in workflow_query
|
||||
|
||||
# Since workflow run has wrong app_id, it shouldn't be included in counts
|
||||
assert result["success"] == 0
|
||||
assert result["failed"] == 0
|
||||
assert result["partial_success"] == 0
|
||||
|
||||
def test_status_count_handles_invalid_workflow_status(self):
|
||||
"""Test that status_count gracefully handles invalid workflow status values."""
|
||||
# Arrange
|
||||
app_id = str(uuid4())
|
||||
conversation_id = str(uuid4())
|
||||
workflow_run_id = str(uuid4())
|
||||
|
||||
conversation = Conversation(
|
||||
app_id=app_id,
|
||||
mode=AppMode.CHAT,
|
||||
name="Test Conversation",
|
||||
status="normal",
|
||||
from_source="api",
|
||||
)
|
||||
conversation.id = conversation_id
|
||||
|
||||
mock_messages = [
|
||||
MagicMock(
|
||||
conversation_id=conversation_id,
|
||||
workflow_run_id=workflow_run_id,
|
||||
),
|
||||
]
|
||||
|
||||
# Mock workflow run with invalid status
|
||||
mock_workflow_runs = [
|
||||
MagicMock(
|
||||
id=workflow_run_id,
|
||||
status="invalid_status", # Invalid status that should raise ValueError
|
||||
app_id=app_id,
|
||||
),
|
||||
]
|
||||
|
||||
with patch("models.model.db.session.scalars") as mock_scalars:
|
||||
# Mock the messages query
|
||||
def mock_scalars_side_effect(query):
|
||||
mock_result = MagicMock()
|
||||
if "messages" in str(query):
|
||||
mock_result.all.return_value = mock_messages
|
||||
elif "workflow_runs" in str(query):
|
||||
mock_result.all.return_value = mock_workflow_runs
|
||||
else:
|
||||
mock_result.all.return_value = []
|
||||
return mock_result
|
||||
|
||||
mock_scalars.side_effect = mock_scalars_side_effect
|
||||
|
||||
# Act - should not raise exception
|
||||
result = conversation.status_count
|
||||
|
||||
# Assert - should handle invalid status gracefully
|
||||
assert result["success"] == 0
|
||||
assert result["failed"] == 0
|
||||
assert result["partial_success"] == 0
|
||||
|
||||
Reference in New Issue
Block a user