diff --git a/api/services/audio_service.py b/api/services/audio_service.py index 90e72d5f34..1c7027efb4 100644 --- a/api/services/audio_service.py +++ b/api/services/audio_service.py @@ -132,7 +132,7 @@ class AudioService: uuid.UUID(message_id) except ValueError: return None - message = db.session.query(Message).where(Message.id == message_id).first() + message = db.session.get(Message, message_id) if message is None: return None if message.answer == "" and message.status in {MessageStatus.NORMAL, MessageStatus.PAUSED}: diff --git a/api/services/billing_service.py b/api/services/billing_service.py index 54c595e0cb..9970b2e604 100644 --- a/api/services/billing_service.py +++ b/api/services/billing_service.py @@ -6,6 +6,7 @@ from typing import Literal import httpx from pydantic import TypeAdapter +from sqlalchemy import select from tenacity import retry, retry_if_exception_type, stop_before_delay, wait_fixed from typing_extensions import TypedDict from werkzeug.exceptions import InternalServerError @@ -158,10 +159,10 @@ class BillingService: def is_tenant_owner_or_admin(current_user: Account): tenant_id = current_user.current_tenant_id - join: TenantAccountJoin | None = ( - db.session.query(TenantAccountJoin) + join: TenantAccountJoin | None = db.session.scalar( + select(TenantAccountJoin) .where(TenantAccountJoin.tenant_id == tenant_id, TenantAccountJoin.account_id == current_user.id) - .first() + .limit(1) ) if not join: diff --git a/api/services/conversation_service.py b/api/services/conversation_service.py index ba1e7bb826..95482a2235 100644 --- a/api/services/conversation_service.py +++ b/api/services/conversation_service.py @@ -137,11 +137,11 @@ class ConversationService: @classmethod def auto_generate_name(cls, app_model: App, conversation: Conversation): # get conversation first message - message = ( - db.session.query(Message) + message = db.session.scalar( + select(Message) .where(Message.app_id == app_model.id, Message.conversation_id == conversation.id) .order_by(Message.created_at.asc()) - .first() + .limit(1) ) if not message: @@ -160,8 +160,8 @@ class ConversationService: @classmethod def get_conversation(cls, app_model: App, conversation_id: str, user: Union[Account, EndUser] | None): - conversation = ( - db.session.query(Conversation) + conversation = db.session.scalar( + select(Conversation) .where( Conversation.id == conversation_id, Conversation.app_id == app_model.id, @@ -170,7 +170,7 @@ class ConversationService: Conversation.from_account_id == (user.id if isinstance(user, Account) else None), Conversation.is_deleted == False, ) - .first() + .limit(1) ) if not conversation: diff --git a/api/services/credit_pool_service.py b/api/services/credit_pool_service.py index 2894826935..7826695366 100644 --- a/api/services/credit_pool_service.py +++ b/api/services/credit_pool_service.py @@ -1,6 +1,6 @@ import logging -from sqlalchemy import update +from sqlalchemy import select, update from sqlalchemy.orm import Session from configs import dify_config @@ -29,13 +29,13 @@ class CreditPoolService: @classmethod def get_pool(cls, tenant_id: str, pool_type: str = "trial") -> TenantCreditPool | None: """get tenant credit pool""" - return ( - db.session.query(TenantCreditPool) - .filter_by( - tenant_id=tenant_id, - pool_type=pool_type, + return db.session.scalar( + select(TenantCreditPool) + .where( + TenantCreditPool.tenant_id == tenant_id, + TenantCreditPool.pool_type == pool_type, ) - .first() + .limit(1) ) @classmethod diff --git a/api/services/enterprise/account_deletion_sync.py b/api/services/enterprise/account_deletion_sync.py index c7ff42894d..b5107fb0f6 100644 --- a/api/services/enterprise/account_deletion_sync.py +++ b/api/services/enterprise/account_deletion_sync.py @@ -4,6 +4,7 @@ import uuid from datetime import UTC, datetime from redis import RedisError +from sqlalchemy import select from configs import dify_config from extensions.ext_database import db @@ -104,7 +105,9 @@ def sync_account_deletion(account_id: str, *, source: str) -> bool: return True # Fetch all workspaces the account belongs to - workspace_joins = db.session.query(TenantAccountJoin).filter_by(account_id=account_id).all() + workspace_joins = db.session.scalars( + select(TenantAccountJoin).where(TenantAccountJoin.account_id == account_id) + ).all() # Queue sync task for each workspace success = True diff --git a/api/services/rag_pipeline/pipeline_generate_service.py b/api/services/rag_pipeline/pipeline_generate_service.py index 07e1b8f20e..10e89b1dba 100644 --- a/api/services/rag_pipeline/pipeline_generate_service.py +++ b/api/services/rag_pipeline/pipeline_generate_service.py @@ -110,7 +110,7 @@ class PipelineGenerateService: Update document status to waiting :param document_id: document id """ - document = db.session.query(Document).where(Document.id == document_id).first() + document = db.session.get(Document, document_id) if document: document.indexing_status = IndexingStatus.WAITING db.session.add(document) diff --git a/api/services/rag_pipeline/pipeline_template/customized/customized_retrieval.py b/api/services/rag_pipeline/pipeline_template/customized/customized_retrieval.py index 4ac2e0792b..2ee871a266 100644 --- a/api/services/rag_pipeline/pipeline_template/customized/customized_retrieval.py +++ b/api/services/rag_pipeline/pipeline_template/customized/customized_retrieval.py @@ -1,4 +1,5 @@ import yaml +from sqlalchemy import select from extensions.ext_database import db from libs.login import current_account_with_tenant @@ -32,12 +33,11 @@ class CustomizedPipelineTemplateRetrieval(PipelineTemplateRetrievalBase): :param language: language :return: """ - pipeline_customized_templates = ( - db.session.query(PipelineCustomizedTemplate) + pipeline_customized_templates = db.session.scalars( + select(PipelineCustomizedTemplate) .where(PipelineCustomizedTemplate.tenant_id == tenant_id, PipelineCustomizedTemplate.language == language) .order_by(PipelineCustomizedTemplate.position.asc(), PipelineCustomizedTemplate.created_at.desc()) - .all() - ) + ).all() recommended_pipelines_results = [] for pipeline_customized_template in pipeline_customized_templates: recommended_pipeline_result = { @@ -59,9 +59,7 @@ class CustomizedPipelineTemplateRetrieval(PipelineTemplateRetrievalBase): :param template_id: Template ID :return: """ - pipeline_template = ( - db.session.query(PipelineCustomizedTemplate).where(PipelineCustomizedTemplate.id == template_id).first() - ) + pipeline_template = db.session.get(PipelineCustomizedTemplate, template_id) if not pipeline_template: return None diff --git a/api/services/rag_pipeline/pipeline_template/database/database_retrieval.py b/api/services/rag_pipeline/pipeline_template/database/database_retrieval.py index 908f9a2684..43b21a7b32 100644 --- a/api/services/rag_pipeline/pipeline_template/database/database_retrieval.py +++ b/api/services/rag_pipeline/pipeline_template/database/database_retrieval.py @@ -1,4 +1,5 @@ import yaml +from sqlalchemy import select from extensions.ext_database import db from models.dataset import PipelineBuiltInTemplate @@ -30,8 +31,10 @@ class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase): :return: """ - pipeline_built_in_templates: list[PipelineBuiltInTemplate] = ( - db.session.query(PipelineBuiltInTemplate).where(PipelineBuiltInTemplate.language == language).all() + pipeline_built_in_templates = list( + db.session.scalars( + select(PipelineBuiltInTemplate).where(PipelineBuiltInTemplate.language == language) + ).all() ) recommended_pipelines_results = [] @@ -58,9 +61,7 @@ class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase): :return: """ # is in public recommended list - pipeline_template = ( - db.session.query(PipelineBuiltInTemplate).where(PipelineBuiltInTemplate.id == template_id).first() - ) + pipeline_template = db.session.get(PipelineBuiltInTemplate, template_id) if not pipeline_template: return None diff --git a/api/services/recommend_app/database/database_retrieval.py b/api/services/recommend_app/database/database_retrieval.py index d0c49325dc..6fb90d356d 100644 --- a/api/services/recommend_app/database/database_retrieval.py +++ b/api/services/recommend_app/database/database_retrieval.py @@ -77,17 +77,15 @@ class DatabaseRecommendAppRetrieval(RecommendAppRetrievalBase): :return: """ # is in public recommended list - recommended_app = ( - db.session.query(RecommendedApp) - .where(RecommendedApp.is_listed == True, RecommendedApp.app_id == app_id) - .first() + recommended_app = db.session.scalar( + select(RecommendedApp).where(RecommendedApp.is_listed == True, RecommendedApp.app_id == app_id).limit(1) ) if not recommended_app: return None # get app detail - app_model = db.session.query(App).where(App.id == app_id).first() + app_model = db.session.get(App, app_id) if not app_model or not app_model.is_public: return None diff --git a/api/services/web_conversation_service.py b/api/services/web_conversation_service.py index e028e3e5e3..5ef9e9be61 100644 --- a/api/services/web_conversation_service.py +++ b/api/services/web_conversation_service.py @@ -64,15 +64,15 @@ class WebConversationService: def pin(cls, app_model: App, conversation_id: str, user: Union[Account, EndUser] | None): if not user: return - pinned_conversation = ( - db.session.query(PinnedConversation) + pinned_conversation = db.session.scalar( + select(PinnedConversation) .where( PinnedConversation.app_id == app_model.id, PinnedConversation.conversation_id == conversation_id, PinnedConversation.created_by_role == ("account" if isinstance(user, Account) else "end_user"), PinnedConversation.created_by == user.id, ) - .first() + .limit(1) ) if pinned_conversation: @@ -96,15 +96,15 @@ class WebConversationService: def unpin(cls, app_model: App, conversation_id: str, user: Union[Account, EndUser] | None): if not user: return - pinned_conversation = ( - db.session.query(PinnedConversation) + pinned_conversation = db.session.scalar( + select(PinnedConversation) .where( PinnedConversation.app_id == app_model.id, PinnedConversation.conversation_id == conversation_id, PinnedConversation.created_by_role == ("account" if isinstance(user, Account) else "end_user"), PinnedConversation.created_by == user.id, ) - .first() + .limit(1) ) if not pinned_conversation: diff --git a/api/services/webapp_auth_service.py b/api/services/webapp_auth_service.py index 5ca0b63001..eaea79af2f 100644 --- a/api/services/webapp_auth_service.py +++ b/api/services/webapp_auth_service.py @@ -3,6 +3,7 @@ import secrets from datetime import UTC, datetime, timedelta from typing import Any +from sqlalchemy import select from werkzeug.exceptions import NotFound, Unauthorized from configs import dify_config @@ -92,10 +93,10 @@ class WebAppAuthService: @classmethod def create_end_user(cls, app_code, email) -> EndUser: - site = db.session.query(Site).where(Site.code == app_code).first() + site = db.session.scalar(select(Site).where(Site.code == app_code).limit(1)) if not site: raise NotFound("Site not found.") - app_model = db.session.query(App).where(App.id == site.app_id).first() + app_model = db.session.get(App, site.app_id) if not app_model: raise NotFound("App not found.") end_user = EndUser( diff --git a/api/services/workflow/workflow_converter.py b/api/services/workflow/workflow_converter.py index 31367f72fa..399c82849f 100644 --- a/api/services/workflow/workflow_converter.py +++ b/api/services/workflow/workflow_converter.py @@ -6,6 +6,7 @@ from graphon.model_runtime.entities.llm_entities import LLMMode from graphon.model_runtime.utils.encoders import jsonable_encoder from graphon.nodes import BuiltinNodeTypes from graphon.variables.input_entities import VariableEntity +from sqlalchemy import select from typing_extensions import TypedDict from core.app.app_config.entities import ( @@ -648,10 +649,10 @@ class WorkflowConverter: :param api_based_extension_id: api based extension id :return: """ - api_based_extension = ( - db.session.query(APIBasedExtension) + api_based_extension = db.session.scalar( + select(APIBasedExtension) .where(APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id) - .first() + .limit(1) ) if not api_based_extension: diff --git a/api/services/workspace_service.py b/api/services/workspace_service.py index 84a8b03329..eb4671cfaa 100644 --- a/api/services/workspace_service.py +++ b/api/services/workspace_service.py @@ -1,4 +1,5 @@ from flask_login import current_user +from sqlalchemy import select from configs import dify_config from enums.cloud_plan import CloudPlan @@ -24,10 +25,10 @@ class WorkspaceService: } # Get role of user - tenant_account_join = ( - db.session.query(TenantAccountJoin) + tenant_account_join = db.session.scalar( + select(TenantAccountJoin) .where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == current_user.id) - .first() + .limit(1) ) assert tenant_account_join is not None, "TenantAccountJoin not found" tenant_info["role"] = tenant_account_join.role diff --git a/api/tests/unit_tests/services/test_audio_service.py b/api/tests/unit_tests/services/test_audio_service.py index 175fd3ee01..cede6671ce 100644 --- a/api/tests/unit_tests/services/test_audio_service.py +++ b/api/tests/unit_tests/services/test_audio_service.py @@ -421,11 +421,8 @@ class TestAudioServiceTTS: answer="Message answer text", ) - # Mock database query - mock_query = MagicMock() - mock_db_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.first.return_value = message + # Mock database lookup + mock_db_session.get.return_value = message # Mock ModelManager mock_model_manager = mock_model_manager_class.return_value @@ -568,11 +565,8 @@ class TestAudioServiceTTS: # Arrange app = factory.create_app_mock() - # Mock database query returning None - mock_query = MagicMock() - mock_db_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.first.return_value = None + # Mock database lookup returning None + mock_db_session.get.return_value = None # Act result = AudioService.transcript_tts( @@ -594,11 +588,8 @@ class TestAudioServiceTTS: status=MessageStatus.NORMAL, ) - # Mock database query - mock_query = MagicMock() - mock_db_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.first.return_value = message + # Mock database lookup + mock_db_session.get.return_value = message # Act result = AudioService.transcript_tts( diff --git a/api/tests/unit_tests/services/test_billing_service.py b/api/tests/unit_tests/services/test_billing_service.py index b3d2e60802..168ab6cf0d 100644 --- a/api/tests/unit_tests/services/test_billing_service.py +++ b/api/tests/unit_tests/services/test_billing_service.py @@ -865,16 +865,11 @@ class TestBillingServiceAccountManagement: mock_join = MagicMock(spec=TenantAccountJoin) mock_join.role = TenantAccountRole.OWNER - mock_query = MagicMock() - mock_query.where.return_value.first.return_value = mock_join - mock_db_session.query.return_value = mock_query + mock_db_session.scalar.return_value = mock_join # Act - should not raise exception BillingService.is_tenant_owner_or_admin(current_user) - # Assert - mock_db_session.query.assert_called_once() - def test_is_tenant_owner_or_admin_admin(self, mock_db_session): """Test tenant owner/admin check for admin role.""" # Arrange @@ -885,16 +880,11 @@ class TestBillingServiceAccountManagement: mock_join = MagicMock(spec=TenantAccountJoin) mock_join.role = TenantAccountRole.ADMIN - mock_query = MagicMock() - mock_query.where.return_value.first.return_value = mock_join - mock_db_session.query.return_value = mock_query + mock_db_session.scalar.return_value = mock_join # Act - should not raise exception BillingService.is_tenant_owner_or_admin(current_user) - # Assert - mock_db_session.query.assert_called_once() - def test_is_tenant_owner_or_admin_normal_user_raises_error(self, mock_db_session): """Test tenant owner/admin check raises error for normal user.""" # Arrange @@ -905,9 +895,7 @@ class TestBillingServiceAccountManagement: mock_join = MagicMock(spec=TenantAccountJoin) mock_join.role = TenantAccountRole.NORMAL - mock_query = MagicMock() - mock_query.where.return_value.first.return_value = mock_join - mock_db_session.query.return_value = mock_query + mock_db_session.scalar.return_value = mock_join # Act & Assert with pytest.raises(ValueError) as exc_info: @@ -921,9 +909,7 @@ class TestBillingServiceAccountManagement: current_user.id = "account-123" current_user.current_tenant_id = "tenant-456" - mock_query = MagicMock() - mock_query.where.return_value.first.return_value = None - mock_db_session.query.return_value = mock_query + mock_db_session.scalar.return_value = None # Act & Assert with pytest.raises(ValueError) as exc_info: @@ -1135,9 +1121,7 @@ class TestBillingServiceEdgeCases: mock_join.role = TenantAccountRole.EDITOR # Editor is not privileged with patch("services.billing_service.db.session") as mock_session: - mock_query = MagicMock() - mock_query.where.return_value.first.return_value = mock_join - mock_session.query.return_value = mock_query + mock_session.scalar.return_value = mock_join # Act & Assert with pytest.raises(ValueError) as exc_info: @@ -1155,9 +1139,7 @@ class TestBillingServiceEdgeCases: mock_join.role = TenantAccountRole.DATASET_OPERATOR # Dataset operator is not privileged with patch("services.billing_service.db.session") as mock_session: - mock_query = MagicMock() - mock_query.where.return_value.first.return_value = mock_join - mock_session.query.return_value = mock_query + mock_session.scalar.return_value = mock_join # Act & Assert with pytest.raises(ValueError) as exc_info: diff --git a/api/tests/unit_tests/services/test_conversation_service.py b/api/tests/unit_tests/services/test_conversation_service.py index 1bf4c0e172..a4359f00b8 100644 --- a/api/tests/unit_tests/services/test_conversation_service.py +++ b/api/tests/unit_tests/services/test_conversation_service.py @@ -355,15 +355,13 @@ class TestConversationServiceGetConversation: from_account_id=user.id, from_source=ConversationFromSource.CONSOLE ) - mock_query = mock_db_session.query.return_value - mock_query.where.return_value.first.return_value = conversation + mock_db_session.scalar.return_value = conversation # Act result = ConversationService.get_conversation(app_model, "conv-123", user) # Assert assert result == conversation - mock_db_session.query.assert_called_once_with(Conversation) @patch("services.conversation_service.db.session") def test_get_conversation_success_with_end_user(self, mock_db_session): @@ -379,8 +377,7 @@ class TestConversationServiceGetConversation: from_end_user_id=user.id, from_source=ConversationFromSource.API ) - mock_query = mock_db_session.query.return_value - mock_query.where.return_value.first.return_value = conversation + mock_db_session.scalar.return_value = conversation # Act result = ConversationService.get_conversation(app_model, "conv-123", user) @@ -399,8 +396,7 @@ class TestConversationServiceGetConversation: app_model = ConversationServiceTestDataFactory.create_app_mock() user = ConversationServiceTestDataFactory.create_account_mock() - mock_query = mock_db_session.query.return_value - mock_query.where.return_value.first.return_value = None + mock_db_session.scalar.return_value = None # Act & Assert with pytest.raises(ConversationNotExistsError): @@ -489,8 +485,7 @@ class TestConversationServiceAutoGenerateName: ) # Mock database query to return message - mock_query = mock_db_session.query.return_value - mock_query.where.return_value.order_by.return_value.first.return_value = message + mock_db_session.scalar.return_value = message # Mock LLM generator mock_llm_generator.generate_conversation_name.return_value = "Generated Name" @@ -518,8 +513,7 @@ class TestConversationServiceAutoGenerateName: conversation = ConversationServiceTestDataFactory.create_conversation_mock() # Mock database query to return None - mock_query = mock_db_session.query.return_value - mock_query.where.return_value.order_by.return_value.first.return_value = None + mock_db_session.scalar.return_value = None # Act & Assert with pytest.raises(MessageNotExistsError): @@ -541,8 +535,7 @@ class TestConversationServiceAutoGenerateName: ) # Mock database query to return message - mock_query = mock_db_session.query.return_value - mock_query.where.return_value.order_by.return_value.first.return_value = message + mock_db_session.scalar.return_value = message # Mock LLM generator to raise exception mock_llm_generator.generate_conversation_name.side_effect = Exception("LLM Error")