From 6e802a343ed224c7e8c3275bafbcc4e94f5655e9 Mon Sep 17 00:00:00 2001 From: wangxiaolei Date: Thu, 11 Dec 2025 15:18:27 +0800 Subject: [PATCH] perf: remove the n+1 query (#29483) --- api/models/model.py | 52 +++- .../unit_tests/models/test_app_models.py | 255 ++++++++++++++++++ 2 files changed, 295 insertions(+), 12 deletions(-) diff --git a/api/models/model.py b/api/models/model.py index c8fa6fd406..c8fbdc40ec 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -835,7 +835,29 @@ class Conversation(Base): @property def status_count(self): - messages = db.session.scalars(select(Message).where(Message.conversation_id == self.id)).all() + from models.workflow import WorkflowRun + + # Get all messages with workflow_run_id for this conversation + messages = db.session.scalars( + select(Message).where(Message.conversation_id == self.id, Message.workflow_run_id.isnot(None)) + ).all() + + if not messages: + return None + + # Batch load all workflow runs in a single query, filtered by this conversation's app_id + workflow_run_ids = [msg.workflow_run_id for msg in messages if msg.workflow_run_id] + workflow_runs = {} + + if workflow_run_ids: + workflow_runs_query = db.session.scalars( + select(WorkflowRun).where( + WorkflowRun.id.in_(workflow_run_ids), + WorkflowRun.app_id == self.app_id, # Filter by this conversation's app_id + ) + ).all() + workflow_runs = {run.id: run for run in workflow_runs_query} + status_counts = { WorkflowExecutionStatus.RUNNING: 0, WorkflowExecutionStatus.SUCCEEDED: 0, @@ -845,18 +867,24 @@ class Conversation(Base): } for message in messages: - if message.workflow_run: - status_counts[WorkflowExecutionStatus(message.workflow_run.status)] += 1 + # Guard against None to satisfy type checker and avoid invalid dict lookups + if message.workflow_run_id is None: + continue + workflow_run = workflow_runs.get(message.workflow_run_id) + if not workflow_run: + continue - return ( - { - "success": status_counts[WorkflowExecutionStatus.SUCCEEDED], - "failed": status_counts[WorkflowExecutionStatus.FAILED], - "partial_success": status_counts[WorkflowExecutionStatus.PARTIAL_SUCCEEDED], - } - if messages - else None - ) + try: + status_counts[WorkflowExecutionStatus(workflow_run.status)] += 1 + except (ValueError, KeyError): + # Handle invalid status values gracefully + pass + + return { + "success": status_counts[WorkflowExecutionStatus.SUCCEEDED], + "failed": status_counts[WorkflowExecutionStatus.FAILED], + "partial_success": status_counts[WorkflowExecutionStatus.PARTIAL_SUCCEEDED], + } @property def first_message(self): diff --git a/api/tests/unit_tests/models/test_app_models.py b/api/tests/unit_tests/models/test_app_models.py index 268ba1282a..e35788660d 100644 --- a/api/tests/unit_tests/models/test_app_models.py +++ b/api/tests/unit_tests/models/test_app_models.py @@ -1149,3 +1149,258 @@ class TestModelIntegration: # Assert assert site.app_id == app.id assert app.enable_site is True + + +class TestConversationStatusCount: + """Test suite for Conversation.status_count property N+1 query fix.""" + + def test_status_count_no_messages(self): + """Test status_count returns None when conversation has no messages.""" + # Arrange + conversation = Conversation( + app_id=str(uuid4()), + mode=AppMode.CHAT, + name="Test Conversation", + status="normal", + from_source="api", + ) + conversation.id = str(uuid4()) + + # Mock the database query to return no messages + with patch("models.model.db.session.scalars") as mock_scalars: + mock_scalars.return_value.all.return_value = [] + + # Act + result = conversation.status_count + + # Assert + assert result is None + + def test_status_count_messages_without_workflow_runs(self): + """Test status_count when messages have no workflow_run_id.""" + # Arrange + app_id = str(uuid4()) + conversation_id = str(uuid4()) + + conversation = Conversation( + app_id=app_id, + mode=AppMode.CHAT, + name="Test Conversation", + status="normal", + from_source="api", + ) + conversation.id = conversation_id + + # Mock the database query to return no messages with workflow_run_id + with patch("models.model.db.session.scalars") as mock_scalars: + mock_scalars.return_value.all.return_value = [] + + # Act + result = conversation.status_count + + # Assert + assert result is None + + def test_status_count_batch_loading_implementation(self): + """Test that status_count uses batch loading instead of N+1 queries.""" + # Arrange + from core.workflow.enums import WorkflowExecutionStatus + + app_id = str(uuid4()) + conversation_id = str(uuid4()) + + # Create workflow run IDs + workflow_run_id_1 = str(uuid4()) + workflow_run_id_2 = str(uuid4()) + workflow_run_id_3 = str(uuid4()) + + conversation = Conversation( + app_id=app_id, + mode=AppMode.CHAT, + name="Test Conversation", + status="normal", + from_source="api", + ) + conversation.id = conversation_id + + # Mock messages with workflow_run_id + mock_messages = [ + MagicMock( + conversation_id=conversation_id, + workflow_run_id=workflow_run_id_1, + ), + MagicMock( + conversation_id=conversation_id, + workflow_run_id=workflow_run_id_2, + ), + MagicMock( + conversation_id=conversation_id, + workflow_run_id=workflow_run_id_3, + ), + ] + + # Mock workflow runs with different statuses + mock_workflow_runs = [ + MagicMock( + id=workflow_run_id_1, + status=WorkflowExecutionStatus.SUCCEEDED.value, + app_id=app_id, + ), + MagicMock( + id=workflow_run_id_2, + status=WorkflowExecutionStatus.FAILED.value, + app_id=app_id, + ), + MagicMock( + id=workflow_run_id_3, + status=WorkflowExecutionStatus.PARTIAL_SUCCEEDED.value, + app_id=app_id, + ), + ] + + # Track database calls + calls_made = [] + + def mock_scalars(query): + calls_made.append(str(query)) + mock_result = MagicMock() + + # Return messages for the first query (messages with workflow_run_id) + if "messages" in str(query) and "conversation_id" in str(query): + mock_result.all.return_value = mock_messages + # Return workflow runs for the batch query + elif "workflow_runs" in str(query): + mock_result.all.return_value = mock_workflow_runs + else: + mock_result.all.return_value = [] + + return mock_result + + # Act & Assert + with patch("models.model.db.session.scalars", side_effect=mock_scalars): + result = conversation.status_count + + # Verify only 2 database queries were made (not N+1) + assert len(calls_made) == 2, f"Expected 2 queries, got {len(calls_made)}: {calls_made}" + + # Verify the first query gets messages + assert "messages" in calls_made[0] + assert "conversation_id" in calls_made[0] + + # Verify the second query batch loads workflow runs with proper filtering + assert "workflow_runs" in calls_made[1] + assert "app_id" in calls_made[1] # Security filter applied + assert "IN" in calls_made[1] # Batch loading with IN clause + + # Verify correct status counts + assert result["success"] == 1 # One SUCCEEDED + assert result["failed"] == 1 # One FAILED + assert result["partial_success"] == 1 # One PARTIAL_SUCCEEDED + + def test_status_count_app_id_filtering(self): + """Test that status_count filters workflow runs by app_id for security.""" + # Arrange + app_id = str(uuid4()) + other_app_id = str(uuid4()) + conversation_id = str(uuid4()) + workflow_run_id = str(uuid4()) + + conversation = Conversation( + app_id=app_id, + mode=AppMode.CHAT, + name="Test Conversation", + status="normal", + from_source="api", + ) + conversation.id = conversation_id + + # Mock message with workflow_run_id + mock_messages = [ + MagicMock( + conversation_id=conversation_id, + workflow_run_id=workflow_run_id, + ), + ] + + calls_made = [] + + def mock_scalars(query): + calls_made.append(str(query)) + mock_result = MagicMock() + + if "messages" in str(query): + mock_result.all.return_value = mock_messages + elif "workflow_runs" in str(query): + # Return empty list because no workflow run matches the correct app_id + mock_result.all.return_value = [] # Workflow run filtered out by app_id + else: + mock_result.all.return_value = [] + + return mock_result + + # Act + with patch("models.model.db.session.scalars", side_effect=mock_scalars): + result = conversation.status_count + + # Assert - query should include app_id filter + workflow_query = calls_made[1] + assert "app_id" in workflow_query + + # Since workflow run has wrong app_id, it shouldn't be included in counts + assert result["success"] == 0 + assert result["failed"] == 0 + assert result["partial_success"] == 0 + + def test_status_count_handles_invalid_workflow_status(self): + """Test that status_count gracefully handles invalid workflow status values.""" + # Arrange + app_id = str(uuid4()) + conversation_id = str(uuid4()) + workflow_run_id = str(uuid4()) + + conversation = Conversation( + app_id=app_id, + mode=AppMode.CHAT, + name="Test Conversation", + status="normal", + from_source="api", + ) + conversation.id = conversation_id + + mock_messages = [ + MagicMock( + conversation_id=conversation_id, + workflow_run_id=workflow_run_id, + ), + ] + + # Mock workflow run with invalid status + mock_workflow_runs = [ + MagicMock( + id=workflow_run_id, + status="invalid_status", # Invalid status that should raise ValueError + app_id=app_id, + ), + ] + + with patch("models.model.db.session.scalars") as mock_scalars: + # Mock the messages query + def mock_scalars_side_effect(query): + mock_result = MagicMock() + if "messages" in str(query): + mock_result.all.return_value = mock_messages + elif "workflow_runs" in str(query): + mock_result.all.return_value = mock_workflow_runs + else: + mock_result.all.return_value = [] + return mock_result + + mock_scalars.side_effect = mock_scalars_side_effect + + # Act - should not raise exception + result = conversation.status_count + + # Assert - should handle invalid status gracefully + assert result["success"] == 0 + assert result["failed"] == 0 + assert result["partial_success"] == 0