perf: remove the n+1 query (#29483)

This commit is contained in:
wangxiaolei
2025-12-11 15:18:27 +08:00
committed by GitHub
parent a30cbe3c95
commit 6e802a343e
2 changed files with 295 additions and 12 deletions

View File

@@ -835,7 +835,29 @@ class Conversation(Base):
@property
def status_count(self):
messages = db.session.scalars(select(Message).where(Message.conversation_id == self.id)).all()
from models.workflow import WorkflowRun
# Get all messages with workflow_run_id for this conversation
messages = db.session.scalars(
select(Message).where(Message.conversation_id == self.id, Message.workflow_run_id.isnot(None))
).all()
if not messages:
return None
# Batch load all workflow runs in a single query, filtered by this conversation's app_id
workflow_run_ids = [msg.workflow_run_id for msg in messages if msg.workflow_run_id]
workflow_runs = {}
if workflow_run_ids:
workflow_runs_query = db.session.scalars(
select(WorkflowRun).where(
WorkflowRun.id.in_(workflow_run_ids),
WorkflowRun.app_id == self.app_id, # Filter by this conversation's app_id
)
).all()
workflow_runs = {run.id: run for run in workflow_runs_query}
status_counts = {
WorkflowExecutionStatus.RUNNING: 0,
WorkflowExecutionStatus.SUCCEEDED: 0,
@@ -845,18 +867,24 @@ class Conversation(Base):
}
for message in messages:
if message.workflow_run:
status_counts[WorkflowExecutionStatus(message.workflow_run.status)] += 1
# Guard against None to satisfy type checker and avoid invalid dict lookups
if message.workflow_run_id is None:
continue
workflow_run = workflow_runs.get(message.workflow_run_id)
if not workflow_run:
continue
return (
{
"success": status_counts[WorkflowExecutionStatus.SUCCEEDED],
"failed": status_counts[WorkflowExecutionStatus.FAILED],
"partial_success": status_counts[WorkflowExecutionStatus.PARTIAL_SUCCEEDED],
}
if messages
else None
)
try:
status_counts[WorkflowExecutionStatus(workflow_run.status)] += 1
except (ValueError, KeyError):
# Handle invalid status values gracefully
pass
return {
"success": status_counts[WorkflowExecutionStatus.SUCCEEDED],
"failed": status_counts[WorkflowExecutionStatus.FAILED],
"partial_success": status_counts[WorkflowExecutionStatus.PARTIAL_SUCCEEDED],
}
@property
def first_message(self):

View File

@@ -1149,3 +1149,258 @@ class TestModelIntegration:
# Assert
assert site.app_id == app.id
assert app.enable_site is True
class TestConversationStatusCount:
"""Test suite for Conversation.status_count property N+1 query fix."""
def test_status_count_no_messages(self):
"""Test status_count returns None when conversation has no messages."""
# Arrange
conversation = Conversation(
app_id=str(uuid4()),
mode=AppMode.CHAT,
name="Test Conversation",
status="normal",
from_source="api",
)
conversation.id = str(uuid4())
# Mock the database query to return no messages
with patch("models.model.db.session.scalars") as mock_scalars:
mock_scalars.return_value.all.return_value = []
# Act
result = conversation.status_count
# Assert
assert result is None
def test_status_count_messages_without_workflow_runs(self):
"""Test status_count when messages have no workflow_run_id."""
# Arrange
app_id = str(uuid4())
conversation_id = str(uuid4())
conversation = Conversation(
app_id=app_id,
mode=AppMode.CHAT,
name="Test Conversation",
status="normal",
from_source="api",
)
conversation.id = conversation_id
# Mock the database query to return no messages with workflow_run_id
with patch("models.model.db.session.scalars") as mock_scalars:
mock_scalars.return_value.all.return_value = []
# Act
result = conversation.status_count
# Assert
assert result is None
def test_status_count_batch_loading_implementation(self):
"""Test that status_count uses batch loading instead of N+1 queries."""
# Arrange
from core.workflow.enums import WorkflowExecutionStatus
app_id = str(uuid4())
conversation_id = str(uuid4())
# Create workflow run IDs
workflow_run_id_1 = str(uuid4())
workflow_run_id_2 = str(uuid4())
workflow_run_id_3 = str(uuid4())
conversation = Conversation(
app_id=app_id,
mode=AppMode.CHAT,
name="Test Conversation",
status="normal",
from_source="api",
)
conversation.id = conversation_id
# Mock messages with workflow_run_id
mock_messages = [
MagicMock(
conversation_id=conversation_id,
workflow_run_id=workflow_run_id_1,
),
MagicMock(
conversation_id=conversation_id,
workflow_run_id=workflow_run_id_2,
),
MagicMock(
conversation_id=conversation_id,
workflow_run_id=workflow_run_id_3,
),
]
# Mock workflow runs with different statuses
mock_workflow_runs = [
MagicMock(
id=workflow_run_id_1,
status=WorkflowExecutionStatus.SUCCEEDED.value,
app_id=app_id,
),
MagicMock(
id=workflow_run_id_2,
status=WorkflowExecutionStatus.FAILED.value,
app_id=app_id,
),
MagicMock(
id=workflow_run_id_3,
status=WorkflowExecutionStatus.PARTIAL_SUCCEEDED.value,
app_id=app_id,
),
]
# Track database calls
calls_made = []
def mock_scalars(query):
calls_made.append(str(query))
mock_result = MagicMock()
# Return messages for the first query (messages with workflow_run_id)
if "messages" in str(query) and "conversation_id" in str(query):
mock_result.all.return_value = mock_messages
# Return workflow runs for the batch query
elif "workflow_runs" in str(query):
mock_result.all.return_value = mock_workflow_runs
else:
mock_result.all.return_value = []
return mock_result
# Act & Assert
with patch("models.model.db.session.scalars", side_effect=mock_scalars):
result = conversation.status_count
# Verify only 2 database queries were made (not N+1)
assert len(calls_made) == 2, f"Expected 2 queries, got {len(calls_made)}: {calls_made}"
# Verify the first query gets messages
assert "messages" in calls_made[0]
assert "conversation_id" in calls_made[0]
# Verify the second query batch loads workflow runs with proper filtering
assert "workflow_runs" in calls_made[1]
assert "app_id" in calls_made[1] # Security filter applied
assert "IN" in calls_made[1] # Batch loading with IN clause
# Verify correct status counts
assert result["success"] == 1 # One SUCCEEDED
assert result["failed"] == 1 # One FAILED
assert result["partial_success"] == 1 # One PARTIAL_SUCCEEDED
def test_status_count_app_id_filtering(self):
"""Test that status_count filters workflow runs by app_id for security."""
# Arrange
app_id = str(uuid4())
other_app_id = str(uuid4())
conversation_id = str(uuid4())
workflow_run_id = str(uuid4())
conversation = Conversation(
app_id=app_id,
mode=AppMode.CHAT,
name="Test Conversation",
status="normal",
from_source="api",
)
conversation.id = conversation_id
# Mock message with workflow_run_id
mock_messages = [
MagicMock(
conversation_id=conversation_id,
workflow_run_id=workflow_run_id,
),
]
calls_made = []
def mock_scalars(query):
calls_made.append(str(query))
mock_result = MagicMock()
if "messages" in str(query):
mock_result.all.return_value = mock_messages
elif "workflow_runs" in str(query):
# Return empty list because no workflow run matches the correct app_id
mock_result.all.return_value = [] # Workflow run filtered out by app_id
else:
mock_result.all.return_value = []
return mock_result
# Act
with patch("models.model.db.session.scalars", side_effect=mock_scalars):
result = conversation.status_count
# Assert - query should include app_id filter
workflow_query = calls_made[1]
assert "app_id" in workflow_query
# Since workflow run has wrong app_id, it shouldn't be included in counts
assert result["success"] == 0
assert result["failed"] == 0
assert result["partial_success"] == 0
def test_status_count_handles_invalid_workflow_status(self):
"""Test that status_count gracefully handles invalid workflow status values."""
# Arrange
app_id = str(uuid4())
conversation_id = str(uuid4())
workflow_run_id = str(uuid4())
conversation = Conversation(
app_id=app_id,
mode=AppMode.CHAT,
name="Test Conversation",
status="normal",
from_source="api",
)
conversation.id = conversation_id
mock_messages = [
MagicMock(
conversation_id=conversation_id,
workflow_run_id=workflow_run_id,
),
]
# Mock workflow run with invalid status
mock_workflow_runs = [
MagicMock(
id=workflow_run_id,
status="invalid_status", # Invalid status that should raise ValueError
app_id=app_id,
),
]
with patch("models.model.db.session.scalars") as mock_scalars:
# Mock the messages query
def mock_scalars_side_effect(query):
mock_result = MagicMock()
if "messages" in str(query):
mock_result.all.return_value = mock_messages
elif "workflow_runs" in str(query):
mock_result.all.return_value = mock_workflow_runs
else:
mock_result.all.return_value = []
return mock_result
mock_scalars.side_effect = mock_scalars_side_effect
# Act - should not raise exception
result = conversation.status_count
# Assert - should handle invalid status gracefully
assert result["success"] == 0
assert result["failed"] == 0
assert result["partial_success"] == 0