diff --git a/api/controllers/console/app/message.py b/api/controllers/console/app/message.py index 4fb73f61f3..736e7dbe17 100644 --- a/api/controllers/console/app/message.py +++ b/api/controllers/console/app/message.py @@ -4,7 +4,7 @@ from typing import Literal from flask import request from flask_restx import Resource, fields, marshal_with from pydantic import BaseModel, Field, field_validator -from sqlalchemy import exists, select +from sqlalchemy import exists, func, select from werkzeug.exceptions import InternalServerError, NotFound from controllers.common.schema import register_schema_models @@ -244,27 +244,25 @@ class ChatMessageListApi(Resource): def get(self, app_model): args = ChatMessagesQuery.model_validate(request.args.to_dict()) - conversation = ( - db.session.query(Conversation) + conversation = db.session.scalar( + select(Conversation) .where(Conversation.id == args.conversation_id, Conversation.app_id == app_model.id) - .first() + .limit(1) ) if not conversation: raise NotFound("Conversation Not Exists.") if args.first_id: - first_message = ( - db.session.query(Message) - .where(Message.conversation_id == conversation.id, Message.id == args.first_id) - .first() + first_message = db.session.scalar( + select(Message).where(Message.conversation_id == conversation.id, Message.id == args.first_id).limit(1) ) if not first_message: raise NotFound("First message not found") - history_messages = ( - db.session.query(Message) + history_messages = db.session.scalars( + select(Message) .where( Message.conversation_id == conversation.id, Message.created_at < first_message.created_at, @@ -272,16 +270,14 @@ class ChatMessageListApi(Resource): ) .order_by(Message.created_at.desc()) .limit(args.limit) - .all() - ) + ).all() else: - history_messages = ( - db.session.query(Message) + history_messages = db.session.scalars( + select(Message) .where(Message.conversation_id == conversation.id) .order_by(Message.created_at.desc()) .limit(args.limit) - .all() - ) + ).all() # Initialize has_more based on whether we have a full page if len(history_messages) == args.limit: @@ -326,7 +322,9 @@ class MessageFeedbackApi(Resource): message_id = str(args.message_id) - message = db.session.query(Message).where(Message.id == message_id, Message.app_id == app_model.id).first() + message = db.session.scalar( + select(Message).where(Message.id == message_id, Message.app_id == app_model.id).limit(1) + ) if not message: raise NotFound("Message Not Exists.") @@ -375,7 +373,9 @@ class MessageAnnotationCountApi(Resource): @login_required @account_initialization_required def get(self, app_model): - count = db.session.query(MessageAnnotation).where(MessageAnnotation.app_id == app_model.id).count() + count = db.session.scalar( + select(func.count(MessageAnnotation.id)).where(MessageAnnotation.app_id == app_model.id) + ) return {"count": count} @@ -479,7 +479,9 @@ class MessageApi(Resource): def get(self, app_model, message_id: str): message_id = str(message_id) - message = db.session.query(Message).where(Message.id == message_id, Message.app_id == app_model.id).first() + message = db.session.scalar( + select(Message).where(Message.id == message_id, Message.app_id == app_model.id).limit(1) + ) if not message: raise NotFound("Message Not Exists.") diff --git a/api/tests/unit_tests/controllers/console/app/test_message.py b/api/tests/unit_tests/controllers/console/app/test_message.py index 3ffa53b6db..e6dfc0d3bd 100644 --- a/api/tests/unit_tests/controllers/console/app/test_message.py +++ b/api/tests/unit_tests/controllers/console/app/test_message.py @@ -170,7 +170,7 @@ class TestMessageEndpoints: mock_app_model, qs={"conversation_id": "123e4567-e89b-12d3-a456-426614174000"}, ) as (api, mock_db, v_args): - mock_db.data_query.where.return_value.first.return_value = None + mock_db.session.scalar.return_value = None with pytest.raises(NotFound): api.get(**v_args) @@ -198,11 +198,11 @@ class TestMessageEndpoints: mock_msg.message = {} mock_msg.message_metadata_dict = {} - # mock returns - q_mock = mock_db.data_query - q_mock.where.return_value.first.side_effect = [mock_conv] - q_mock.where.return_value.order_by.return_value.limit.return_value.all.return_value = [mock_msg] - mock_db.session.scalar.return_value = False + # scalar() is called twice: first for conversation lookup, second for has_more check + mock_db.session.scalar.side_effect = [mock_conv, False] + scalars_result = MagicMock() + scalars_result.all.return_value = [mock_msg] + mock_db.session.scalars.return_value = scalars_result resp = api.get(**v_args) assert resp["limit"] == 1 @@ -219,7 +219,7 @@ class TestMessageEndpoints: mock_app_model, payload={"message_id": "123e4567-e89b-12d3-a456-426614174000"}, ) as (api, mock_db, v_args): - mock_db.data_query.where.return_value.first.return_value = None + mock_db.session.scalar.return_value = None with pytest.raises(NotFound): api.post(**v_args) @@ -231,7 +231,7 @@ class TestMessageEndpoints: ) as (api, mock_db, v_args): mock_msg = MagicMock() mock_msg.admin_feedback = None - mock_db.data_query.where.return_value.first.return_value = mock_msg + mock_db.session.scalar.return_value = mock_msg resp = api.post(**v_args) assert resp == {"result": "success"} @@ -240,7 +240,7 @@ class TestMessageEndpoints: with setup_test_context( app, MessageAnnotationCountApi, "/apps/app_123/annotations/count", "GET", mock_account, mock_app_model ) as (api, mock_db, v_args): - mock_db.data_query.where.return_value.count.return_value = 5 + mock_db.session.scalar.return_value = 5 resp = api.get(**v_args) assert resp == {"count": 5} @@ -314,7 +314,7 @@ class TestMessageEndpoints: mock_msg.message = {} mock_msg.message_metadata_dict = {} - mock_db.data_query.where.return_value.first.return_value = mock_msg + mock_db.session.scalar.return_value = mock_msg resp = api.get(**v_args) assert resp["id"] == "msg_123"