diff --git a/api/core/rag/datasource/keyword/jieba/jieba.py b/api/core/rag/datasource/keyword/jieba/jieba.py index b07dc108be..b8d5db7a43 100644 --- a/api/core/rag/datasource/keyword/jieba/jieba.py +++ b/api/core/rag/datasource/keyword/jieba/jieba.py @@ -97,13 +97,13 @@ class Jieba(BaseKeyword): documents = [] - segment_query_stmt = db.session.query(DocumentSegment).where( + segment_query_stmt = select(DocumentSegment).where( DocumentSegment.dataset_id == self.dataset.id, DocumentSegment.index_node_id.in_(sorted_chunk_indices) ) if document_ids_filter: segment_query_stmt = segment_query_stmt.where(DocumentSegment.document_id.in_(document_ids_filter)) - segments = db.session.execute(segment_query_stmt).scalars().all() + segments = db.session.scalars(segment_query_stmt).all() segment_map = {segment.index_node_id: segment for segment in segments} for chunk_index in sorted_chunk_indices: segment = segment_map.get(chunk_index) diff --git a/api/core/rag/datasource/retrieval_service.py b/api/core/rag/datasource/retrieval_service.py index cc6ec12c75..203a8588d6 100644 --- a/api/core/rag/datasource/retrieval_service.py +++ b/api/core/rag/datasource/retrieval_service.py @@ -432,10 +432,11 @@ class RetrievalService: # Batch query dataset documents dataset_documents = { doc.id: doc - for doc in db.session.query(DatasetDocument) - .where(DatasetDocument.id.in_(document_ids)) - .options(load_only(DatasetDocument.id, DatasetDocument.doc_form, DatasetDocument.dataset_id)) - .all() + for doc in db.session.scalars( + select(DatasetDocument) + .where(DatasetDocument.id.in_(document_ids)) + .options(load_only(DatasetDocument.id, DatasetDocument.doc_form, DatasetDocument.dataset_id)) + ).all() } valid_dataset_documents = {} diff --git a/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py b/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py index 3c1d5e015f..69c81d521c 100644 --- a/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py +++ b/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py @@ -426,11 +426,10 @@ class TidbOnQdrantVectorFactory(AbstractVectorFactory): TIDB_ON_QDRANT_API_KEY = f"{tidb_auth_binding.account}:{tidb_auth_binding.password}" else: - idle_tidb_auth_binding = ( - db.session.query(TidbAuthBinding) + idle_tidb_auth_binding = db.session.scalar( + select(TidbAuthBinding) .where(TidbAuthBinding.active == False, TidbAuthBinding.status == "ACTIVE") .limit(1) - .one_or_none() ) if idle_tidb_auth_binding: idle_tidb_auth_binding.active = True diff --git a/api/core/rag/datasource/vdb/vector_factory.py b/api/core/rag/datasource/vdb/vector_factory.py index 5a8d3a2f3f..26531eab88 100644 --- a/api/core/rag/datasource/vdb/vector_factory.py +++ b/api/core/rag/datasource/vdb/vector_factory.py @@ -277,7 +277,7 @@ class Vector: return self._vector_processor.search_by_vector(query_vector, **kwargs) def search_by_file(self, file_id: str, **kwargs: Any) -> list[Document]: - upload_file: UploadFile | None = db.session.query(UploadFile).where(UploadFile.id == file_id).first() + upload_file: UploadFile | None = db.session.get(UploadFile, file_id) if not upload_file: return [] diff --git a/api/core/rag/docstore/dataset_docstore.py b/api/core/rag/docstore/dataset_docstore.py index e5b794f80d..40f45953af 100644 --- a/api/core/rag/docstore/dataset_docstore.py +++ b/api/core/rag/docstore/dataset_docstore.py @@ -4,7 +4,7 @@ from collections.abc import Sequence from typing import Any from graphon.model_runtime.entities.model_entities import ModelType -from sqlalchemy import func, select +from sqlalchemy import delete, func, select from core.model_manager import ModelManager from core.rag.index_processor.constant.index_type import IndexTechniqueType @@ -63,10 +63,8 @@ class DatasetDocumentStore: return output def add_documents(self, docs: Sequence[Document], allow_update: bool = True, save_child: bool = False): - max_position = ( - db.session.query(func.max(DocumentSegment.position)) - .where(DocumentSegment.document_id == self._document_id) - .scalar() + max_position = db.session.scalar( + select(func.max(DocumentSegment.position)).where(DocumentSegment.document_id == self._document_id) ) if max_position is None: @@ -155,12 +153,14 @@ class DatasetDocumentStore: ) if save_child and doc.children: # delete the existing child chunks - db.session.query(ChildChunk).where( - ChildChunk.tenant_id == self._dataset.tenant_id, - ChildChunk.dataset_id == self._dataset.id, - ChildChunk.document_id == self._document_id, - ChildChunk.segment_id == segment_document.id, - ).delete() + db.session.execute( + delete(ChildChunk).where( + ChildChunk.tenant_id == self._dataset.tenant_id, + ChildChunk.dataset_id == self._dataset.id, + ChildChunk.document_id == self._document_id, + ChildChunk.segment_id == segment_document.id, + ) + ) # add new child chunks for position, child in enumerate(doc.children, start=1): child_segment = ChildChunk( diff --git a/api/core/rag/embedding/cached_embedding.py b/api/core/rag/embedding/cached_embedding.py index 3bdad00712..8d1c0da392 100644 --- a/api/core/rag/embedding/cached_embedding.py +++ b/api/core/rag/embedding/cached_embedding.py @@ -6,6 +6,7 @@ from typing import Any, cast import numpy as np from graphon.model_runtime.entities.model_entities import ModelPropertyKey from graphon.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel +from sqlalchemy import select from sqlalchemy.exc import IntegrityError from configs import dify_config @@ -31,14 +32,14 @@ class CacheEmbedding(Embeddings): embedding_queue_indices = [] for i, text in enumerate(texts): hash = helper.generate_text_hash(text) - embedding = ( - db.session.query(Embedding) - .filter_by( - model_name=self._model_instance.model_name, - hash=hash, - provider_name=self._model_instance.provider, + embedding = db.session.scalar( + select(Embedding) + .where( + Embedding.model_name == self._model_instance.model_name, + Embedding.hash == hash, + Embedding.provider_name == self._model_instance.provider, ) - .first() + .limit(1) ) if embedding: text_embeddings[i] = embedding.get_embedding() @@ -112,14 +113,14 @@ class CacheEmbedding(Embeddings): embedding_queue_indices = [] for i, multimodel_document in enumerate(multimodel_documents): file_id = multimodel_document["file_id"] - embedding = ( - db.session.query(Embedding) - .filter_by( - model_name=self._model_instance.model_name, - hash=file_id, - provider_name=self._model_instance.provider, + embedding = db.session.scalar( + select(Embedding) + .where( + Embedding.model_name == self._model_instance.model_name, + Embedding.hash == file_id, + Embedding.provider_name == self._model_instance.provider, ) - .first() + .limit(1) ) if embedding: multimodel_embeddings[i] = embedding.get_embedding() diff --git a/api/core/rag/extractor/notion_extractor.py b/api/core/rag/extractor/notion_extractor.py index 372af8fd94..aa36160711 100644 --- a/api/core/rag/extractor/notion_extractor.py +++ b/api/core/rag/extractor/notion_extractor.py @@ -4,6 +4,7 @@ import operator from typing import Any, cast import httpx +from sqlalchemy import update from configs import dify_config from core.rag.extractor.extractor_base import BaseExtractor @@ -346,9 +347,11 @@ class NotionExtractor(BaseExtractor): if data_source_info: data_source_info["last_edited_time"] = last_edited_time - db.session.query(DocumentModel).filter_by(id=document_model.id).update( - {DocumentModel.data_source_info: json.dumps(data_source_info)} - ) # type: ignore + db.session.execute( + update(DocumentModel) + .where(DocumentModel.id == document_model.id) + .values(data_source_info=json.dumps(data_source_info)) + ) db.session.commit() def get_notion_last_edited_time(self) -> str: diff --git a/api/core/rag/index_processor/index_processor_base.py b/api/core/rag/index_processor/index_processor_base.py index a435dfc46a..7d504fdb35 100644 --- a/api/core/rag/index_processor/index_processor_base.py +++ b/api/core/rag/index_processor/index_processor_base.py @@ -11,6 +11,7 @@ from typing import TYPE_CHECKING, Any, NotRequired, Optional from urllib.parse import unquote, urlparse import httpx +from sqlalchemy import select from typing_extensions import TypedDict from configs import dify_config @@ -200,7 +201,7 @@ class BaseIndexProcessor(ABC): # Get unique IDs for database query unique_upload_file_ids = list(set(upload_file_id_list)) - upload_files = db.session.query(UploadFile).where(UploadFile.id.in_(unique_upload_file_ids)).all() + upload_files = db.session.scalars(select(UploadFile).where(UploadFile.id.in_(unique_upload_file_ids))).all() # Create a mapping from ID to UploadFile for quick lookup upload_file_map = {upload_file.id: upload_file for upload_file in upload_files} @@ -312,7 +313,7 @@ class BaseIndexProcessor(ABC): """ from services.file_service import FileService - tool_file = db.session.query(ToolFile).where(ToolFile.id == tool_file_id).first() + tool_file = db.session.get(ToolFile, tool_file_id) if not tool_file: return None blob = storage.load_once(tool_file.file_key) diff --git a/api/core/rag/index_processor/processor/paragraph_index_processor.py b/api/core/rag/index_processor/processor/paragraph_index_processor.py index 5c10ffbf2d..22ab492cbf 100644 --- a/api/core/rag/index_processor/processor/paragraph_index_processor.py +++ b/api/core/rag/index_processor/processor/paragraph_index_processor.py @@ -18,6 +18,7 @@ from graphon.model_runtime.entities.message_entities import ( UserPromptMessage, ) from graphon.model_runtime.entities.model_entities import ModelFeature, ModelType +from sqlalchemy import select from core.app.file_access import DatabaseFileAccessController from core.app.llm import deduct_llm_quota @@ -145,14 +146,12 @@ class ParagraphIndexProcessor(BaseIndexProcessor): if delete_summaries: if node_ids: # Find segments by index_node_id - segments = ( - db.session.query(DocumentSegment) - .filter( + segments = db.session.scalars( + select(DocumentSegment).where( DocumentSegment.dataset_id == dataset.id, DocumentSegment.index_node_id.in_(node_ids), ) - .all() - ) + ).all() segment_ids = [segment.id for segment in segments] if segment_ids: SummaryIndexService.delete_summaries_for_segments(dataset, segment_ids) @@ -537,11 +536,9 @@ class ParagraphIndexProcessor(BaseIndexProcessor): # Get unique IDs for database query unique_upload_file_ids = list(set(upload_file_id_list)) - upload_files = ( - db.session.query(UploadFile) - .where(UploadFile.id.in_(unique_upload_file_ids), UploadFile.tenant_id == tenant_id) - .all() - ) + upload_files = db.session.scalars( + select(UploadFile).where(UploadFile.id.in_(unique_upload_file_ids), UploadFile.tenant_id == tenant_id) + ).all() # Create File objects from UploadFile records file_objects = [] diff --git a/api/core/rag/index_processor/processor/parent_child_index_processor.py b/api/core/rag/index_processor/processor/parent_child_index_processor.py index 70504e6e50..1c5e02e9c8 100644 --- a/api/core/rag/index_processor/processor/parent_child_index_processor.py +++ b/api/core/rag/index_processor/processor/parent_child_index_processor.py @@ -6,6 +6,8 @@ import uuid from collections.abc import Mapping from typing import Any +from sqlalchemy import delete, select + from configs import dify_config from core.db.session_factory import session_factory from core.entities.knowledge_entities import PreviewDetail @@ -177,17 +179,16 @@ class ParentChildIndexProcessor(BaseIndexProcessor): child_node_ids = precomputed_child_node_ids else: # Fallback to original query (may fail if segments are already deleted) - child_node_ids = ( - db.session.query(ChildChunk.index_node_id) + rows = db.session.execute( + select(ChildChunk.index_node_id) .join(DocumentSegment, ChildChunk.segment_id == DocumentSegment.id) .where( DocumentSegment.dataset_id == dataset.id, DocumentSegment.index_node_id.in_(node_ids), ChildChunk.dataset_id == dataset.id, ) - .all() - ) - child_node_ids = [child_node_id[0] for child_node_id in child_node_ids if child_node_id[0]] + ).all() + child_node_ids = [row[0] for row in rows if row[0]] # Delete from vector index if child_node_ids: @@ -195,18 +196,22 @@ class ParentChildIndexProcessor(BaseIndexProcessor): # Delete from database if delete_child_chunks and child_node_ids: - db.session.query(ChildChunk).where( - ChildChunk.dataset_id == dataset.id, ChildChunk.index_node_id.in_(child_node_ids) - ).delete(synchronize_session=False) + db.session.execute( + delete(ChildChunk).where( + ChildChunk.dataset_id == dataset.id, ChildChunk.index_node_id.in_(child_node_ids) + ) + ) db.session.commit() else: vector.delete() if delete_child_chunks: # Use existing compound index: (tenant_id, dataset_id, ...) - db.session.query(ChildChunk).where( - ChildChunk.tenant_id == dataset.tenant_id, ChildChunk.dataset_id == dataset.id - ).delete(synchronize_session=False) + db.session.execute( + delete(ChildChunk).where( + ChildChunk.tenant_id == dataset.tenant_id, ChildChunk.dataset_id == dataset.id + ) + ) db.session.commit() def retrieve( diff --git a/api/core/rag/rerank/rerank_model.py b/api/core/rag/rerank/rerank_model.py index 211a9f5c5c..8283be19f9 100644 --- a/api/core/rag/rerank/rerank_model.py +++ b/api/core/rag/rerank/rerank_model.py @@ -134,9 +134,7 @@ class RerankModelRunner(BaseRerankRunner): ): if document.metadata.get("doc_type") == DocType.IMAGE: # Query file info within db.session context to ensure thread-safe access - upload_file = ( - db.session.query(UploadFile).where(UploadFile.id == document.metadata["doc_id"]).first() - ) + upload_file = db.session.get(UploadFile, document.metadata["doc_id"]) if upload_file: blob = storage.load_once(upload_file.key) document_file_base64 = base64.b64encode(blob).decode() @@ -169,7 +167,7 @@ class RerankModelRunner(BaseRerankRunner): return rerank_result, unique_documents elif query_type == QueryType.IMAGE_QUERY: # Query file info within db.session context to ensure thread-safe access - upload_file = db.session.query(UploadFile).where(UploadFile.id == query).first() + upload_file = db.session.get(UploadFile, query) if upload_file: blob = storage.load_once(upload_file.key) file_query = base64.b64encode(blob).decode() diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py index 1abea6639e..593e1f1420 100644 --- a/api/core/rag/retrieval/dataset_retrieval.py +++ b/api/core/rag/retrieval/dataset_retrieval.py @@ -1340,7 +1340,7 @@ class DatasetRetrieval: metadata_filtering_conditions: MetadataFilteringCondition | None, inputs: dict, ) -> tuple[dict[str, list[str]] | None, MetadataCondition | None]: - document_query = db.session.query(DatasetDocument).where( + document_query = select(DatasetDocument).where( DatasetDocument.dataset_id.in_(dataset_ids), DatasetDocument.indexing_status == "completed", DatasetDocument.enabled == True, @@ -1411,7 +1411,7 @@ class DatasetRetrieval: document_query = document_query.where(and_(*filters)) else: document_query = document_query.where(or_(*filters)) - documents = document_query.all() + documents = db.session.scalars(document_query).all() # group by dataset_id metadata_filter_document_ids = defaultdict(list) if documents else None # type: ignore for document in documents: diff --git a/api/tests/unit_tests/core/rag/datasource/keyword/jieba/test_jieba.py b/api/tests/unit_tests/core/rag/datasource/keyword/jieba/test_jieba.py index 795a325a6b..bbdd476914 100644 --- a/api/tests/unit_tests/core/rag/datasource/keyword/jieba/test_jieba.py +++ b/api/tests/unit_tests/core/rag/datasource/keyword/jieba/test_jieba.py @@ -201,27 +201,23 @@ def test_search_returns_documents_in_rank_order_and_applies_filter(monkeypatch, document_id = _Field("document_id") keyword = Jieba(_dataset(_dataset_keyword_table())) - query_stmt = _FakeQuery() - patched_runtime.session.query.return_value = query_stmt - patched_runtime.session.execute.return_value = _FakeExecuteResult( - [ - SimpleNamespace( - index_node_id="node-2", - content="segment-content", - index_node_hash="hash-2", - document_id="doc-2", - dataset_id="dataset-1", - ) - ] - ) + patched_runtime.session.scalars.return_value.all.return_value = [ + SimpleNamespace( + index_node_id="node-2", + content="segment-content", + index_node_hash="hash-2", + document_id="doc-2", + dataset_id="dataset-1", + ) + ] monkeypatch.setattr(jieba_module, "DocumentSegment", _FakeDocumentSegment) + monkeypatch.setattr(jieba_module, "select", lambda *_: _FakeSelect()) monkeypatch.setattr(keyword, "_get_dataset_keyword_table", MagicMock(return_value={"k": {"node-1", "node-2"}})) monkeypatch.setattr(keyword, "_retrieve_ids_by_query", MagicMock(return_value=["node-1", "node-2"])) documents = keyword.search("query", top_k=2, document_ids_filter=["doc-2"]) - assert len(query_stmt.where_calls) == 2 assert len(documents) == 1 assert documents[0].page_content == "segment-content" assert documents[0].metadata["doc_id"] == "node-2" diff --git a/api/tests/unit_tests/core/rag/datasource/test_datasource_retrieval.py b/api/tests/unit_tests/core/rag/datasource/test_datasource_retrieval.py index 63de4b8af2..5dbd62580a 100644 --- a/api/tests/unit_tests/core/rag/datasource/test_datasource_retrieval.py +++ b/api/tests/unit_tests/core/rag/datasource/test_datasource_retrieval.py @@ -714,13 +714,13 @@ class TestRetrievalServiceInternals: dataset_id="dataset-id", ) - dataset_query = Mock() - dataset_query.where.return_value.options.return_value.all.return_value = [ + scalars_result = Mock() + scalars_result.all.return_value = [ dataset_doc_parent, dataset_doc_text, dataset_doc_parent_summary, ] - monkeypatch.setattr(retrieval_service_module.db.session, "query", Mock(return_value=dataset_query)) + monkeypatch.setattr(retrieval_service_module.db.session, "scalars", Mock(return_value=scalars_result)) monkeypatch.setattr(retrieval_service_module, "RetrievalChildChunk", _SimpleRetrievalChildChunk) monkeypatch.setattr(retrieval_service_module, "RetrievalSegments", _SimpleRetrievalSegment) @@ -882,7 +882,7 @@ class TestRetrievalServiceInternals: def test_format_retrieval_documents_rolls_back_and_raises_when_db_fails(self, monkeypatch): rollback = Mock() monkeypatch.setattr(retrieval_service_module.db.session, "rollback", rollback) - monkeypatch.setattr(retrieval_service_module.db.session, "query", Mock(side_effect=RuntimeError("db error"))) + monkeypatch.setattr(retrieval_service_module.db.session, "scalars", Mock(side_effect=RuntimeError("db error"))) documents = [Document(page_content="content", metadata={"document_id": "doc-1"}, provider="dify")] diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/test_vector_factory.py b/api/tests/unit_tests/core/rag/datasource/vdb/test_vector_factory.py index 54ad6d330b..4e9ceddda9 100644 --- a/api/tests/unit_tests/core/rag/datasource/vdb/test_vector_factory.py +++ b/api/tests/unit_tests/core/rag/datasource/vdb/test_vector_factory.py @@ -340,15 +340,13 @@ def test_search_by_file_handles_missing_and_existing_upload(vector_factory_modul vector._embeddings = MagicMock() vector._vector_processor = MagicMock() + mock_session = SimpleNamespace(get=lambda _model, _id: None) monkeypatch.setattr(vector_factory_module, "UploadFile", SimpleNamespace(id=_Field())) - monkeypatch.setattr( - vector_factory_module, "db", SimpleNamespace(session=SimpleNamespace(query=lambda _model: upload_query)) - ) + monkeypatch.setattr(vector_factory_module, "db", SimpleNamespace(session=mock_session)) - upload_query.first.return_value = None assert vector.search_by_file("file-1") == [] - upload_query.first.return_value = SimpleNamespace(key="blob-key") + mock_session.get = lambda _model, _id: SimpleNamespace(key="blob-key") monkeypatch.setattr(vector_factory_module.storage, "load_once", MagicMock(return_value=b"file-bytes")) vector._embeddings.embed_multimodal_query.return_value = [0.3, 0.4] vector._vector_processor.search_by_vector.return_value = ["hit"] diff --git a/api/tests/unit_tests/core/rag/docstore/test_dataset_docstore.py b/api/tests/unit_tests/core/rag/docstore/test_dataset_docstore.py index 3ba0628fe2..a7b7c1595b 100644 --- a/api/tests/unit_tests/core/rag/docstore/test_dataset_docstore.py +++ b/api/tests/unit_tests/core/rag/docstore/test_dataset_docstore.py @@ -167,7 +167,7 @@ class TestDatasetDocumentStoreAddDocuments: ): mock_session = MagicMock() mock_db.session = mock_session - mock_db.session.query.return_value.where.return_value.scalar.return_value = None + mock_db.session.scalar.return_value = None mock_manager = MagicMock() mock_manager.get_model_instance.return_value = mock_model_instance @@ -211,7 +211,7 @@ class TestDatasetDocumentStoreAddDocuments: with patch("core.rag.docstore.dataset_docstore.db") as mock_db: mock_session = MagicMock() mock_db.session = mock_session - mock_db.session.query.return_value.where.return_value.scalar.return_value = 5 + mock_db.session.scalar.return_value = 5 with patch.object(DatasetDocumentStore, "get_document_segment", return_value=mock_existing_segment): with patch.object(DatasetDocumentStore, "add_multimodel_documents_binding"): @@ -276,7 +276,7 @@ class TestDatasetDocumentStoreAddDocuments: with patch("core.rag.docstore.dataset_docstore.db") as mock_db: mock_session = MagicMock() mock_db.session = mock_session - mock_db.session.query.return_value.where.return_value.scalar.return_value = None + mock_db.session.scalar.return_value = None with patch.object(DatasetDocumentStore, "get_document_segment", return_value=None): with patch.object(DatasetDocumentStore, "add_multimodel_documents_binding"): @@ -353,7 +353,7 @@ class TestDatasetDocumentStoreAddDocuments: with patch("core.rag.docstore.dataset_docstore.db") as mock_db: mock_session = MagicMock() mock_db.session = mock_session - mock_db.session.query.return_value.where.return_value.scalar.return_value = None + mock_db.session.scalar.return_value = None with patch.object(DatasetDocumentStore, "get_document_segment", return_value=None): with patch.object(DatasetDocumentStore, "add_multimodel_documents_binding"): @@ -755,7 +755,7 @@ class TestDatasetDocumentStoreAddDocumentsUpdateChild: with patch("core.rag.docstore.dataset_docstore.db") as mock_db: mock_session = MagicMock() mock_db.session = mock_session - mock_db.session.query.return_value.where.return_value.scalar.return_value = 5 + mock_db.session.scalar.return_value = 5 with patch.object(DatasetDocumentStore, "get_document_segment", return_value=mock_existing_segment): with patch.object(DatasetDocumentStore, "add_multimodel_documents_binding"): @@ -767,7 +767,7 @@ class TestDatasetDocumentStoreAddDocumentsUpdateChild: store.add_documents([mock_doc], save_child=True) - mock_db.session.query.return_value.where.return_value.delete.assert_called() + mock_db.session.execute.assert_called() mock_db.session.commit.assert_called() @@ -798,7 +798,7 @@ class TestDatasetDocumentStoreAddDocumentsUpdateAnswer: with patch("core.rag.docstore.dataset_docstore.db") as mock_db: mock_session = MagicMock() mock_db.session = mock_session - mock_db.session.query.return_value.where.return_value.scalar.return_value = 5 + mock_db.session.scalar.return_value = 5 with patch.object(DatasetDocumentStore, "get_document_segment", return_value=mock_existing_segment): with patch.object(DatasetDocumentStore, "add_multimodel_documents_binding"): diff --git a/api/tests/unit_tests/core/rag/embedding/test_cached_embedding.py b/api/tests/unit_tests/core/rag/embedding/test_cached_embedding.py index 6fd44be4d4..3563186186 100644 --- a/api/tests/unit_tests/core/rag/embedding/test_cached_embedding.py +++ b/api/tests/unit_tests/core/rag/embedding/test_cached_embedding.py @@ -69,7 +69,7 @@ class TestCacheEmbeddingMultimodalDocuments: documents = [{"file_id": "file123", "content": "test content"}] with patch("core.rag.embedding.cached_embedding.db.session") as mock_session: - mock_session.query.return_value.filter_by.return_value.first.return_value = None + mock_session.scalar.return_value = None mock_model_instance.invoke_multimodal_embedding.return_value = sample_multimodal_result result = cache_embedding.embed_multimodal_documents(documents) @@ -114,7 +114,7 @@ class TestCacheEmbeddingMultimodalDocuments: ) with patch("core.rag.embedding.cached_embedding.db.session") as mock_session: - mock_session.query.return_value.filter_by.return_value.first.return_value = None + mock_session.scalar.return_value = None mock_model_instance.invoke_multimodal_embedding.return_value = embedding_result result = cache_embedding.embed_multimodal_documents(documents) @@ -134,7 +134,7 @@ class TestCacheEmbeddingMultimodalDocuments: mock_cached_embedding.get_embedding.return_value = normalized_cached with patch("core.rag.embedding.cached_embedding.db.session") as mock_session: - mock_session.query.return_value.filter_by.return_value.first.return_value = mock_cached_embedding + mock_session.scalar.return_value = mock_cached_embedding result = cache_embedding.embed_multimodal_documents(documents) @@ -180,18 +180,7 @@ class TestCacheEmbeddingMultimodalDocuments: ) with patch("core.rag.embedding.cached_embedding.db.session") as mock_session: - call_count = [0] - - def mock_filter_by(**kwargs): - call_count[0] += 1 - mock_query = Mock() - if call_count[0] == 1: - mock_query.first.return_value = mock_cached_embedding - else: - mock_query.first.return_value = None - return mock_query - - mock_session.query.return_value.filter_by = mock_filter_by + mock_session.scalar.side_effect = [mock_cached_embedding, None, None] mock_model_instance.invoke_multimodal_embedding.return_value = embedding_result result = cache_embedding.embed_multimodal_documents(documents) @@ -224,7 +213,7 @@ class TestCacheEmbeddingMultimodalDocuments: ) with patch("core.rag.embedding.cached_embedding.db.session") as mock_session: - mock_session.query.return_value.filter_by.return_value.first.return_value = None + mock_session.scalar.return_value = None mock_model_instance.invoke_multimodal_embedding.return_value = embedding_result with patch("core.rag.embedding.cached_embedding.logger") as mock_logger: @@ -265,7 +254,7 @@ class TestCacheEmbeddingMultimodalDocuments: ) with patch("core.rag.embedding.cached_embedding.db.session") as mock_session: - mock_session.query.return_value.filter_by.return_value.first.return_value = None + mock_session.scalar.return_value = None batch_results = [create_batch_result(10), create_batch_result(10), create_batch_result(5)] mock_model_instance.invoke_multimodal_embedding.side_effect = batch_results @@ -281,7 +270,7 @@ class TestCacheEmbeddingMultimodalDocuments: documents = [{"file_id": "file123"}] with patch("core.rag.embedding.cached_embedding.db.session") as mock_session: - mock_session.query.return_value.filter_by.return_value.first.return_value = None + mock_session.scalar.return_value = None mock_model_instance.invoke_multimodal_embedding.side_effect = Exception("API Error") with pytest.raises(Exception) as exc_info: @@ -298,7 +287,7 @@ class TestCacheEmbeddingMultimodalDocuments: documents = [{"file_id": "file123"}] with patch("core.rag.embedding.cached_embedding.db.session") as mock_session: - mock_session.query.return_value.filter_by.return_value.first.return_value = None + mock_session.scalar.return_value = None mock_model_instance.invoke_multimodal_embedding.return_value = sample_multimodal_result mock_session.commit.side_effect = IntegrityError("Duplicate key", None, None) diff --git a/api/tests/unit_tests/core/rag/embedding/test_embedding_service.py b/api/tests/unit_tests/core/rag/embedding/test_embedding_service.py index d7ba944e58..408cf14a51 100644 --- a/api/tests/unit_tests/core/rag/embedding/test_embedding_service.py +++ b/api/tests/unit_tests/core/rag/embedding/test_embedding_service.py @@ -139,7 +139,7 @@ class TestCacheEmbeddingDocuments: # Mock database query to return no cached embedding (cache miss) with patch("core.rag.embedding.cached_embedding.db.session") as mock_session: - mock_session.query.return_value.filter_by.return_value.first.return_value = None + mock_session.scalar.return_value = None # Mock model invocation mock_model_instance.invoke_text_embedding.return_value = sample_embedding_result @@ -203,7 +203,7 @@ class TestCacheEmbeddingDocuments: ) with patch("core.rag.embedding.cached_embedding.db.session") as mock_session: - mock_session.query.return_value.filter_by.return_value.first.return_value = None + mock_session.scalar.return_value = None mock_model_instance.invoke_text_embedding.return_value = embedding_result # Act @@ -240,7 +240,7 @@ class TestCacheEmbeddingDocuments: with patch("core.rag.embedding.cached_embedding.db.session") as mock_session: # Mock database to return cached embedding (cache hit) - mock_session.query.return_value.filter_by.return_value.first.return_value = mock_cached_embedding + mock_session.scalar.return_value = mock_cached_embedding # Act result = cache_embedding.embed_documents(texts) @@ -313,19 +313,7 @@ class TestCacheEmbeddingDocuments: mock_hash.side_effect = generate_hash # Mock database to return cached embedding only for first text (hash_1) - call_count = [0] - - def mock_filter_by(**kwargs): - call_count[0] += 1 - mock_query = Mock() - # First call (hash_1) returns cached, others return None - if call_count[0] == 1: - mock_query.first.return_value = mock_cached_embedding - else: - mock_query.first.return_value = None - return mock_query - - mock_session.query.return_value.filter_by = mock_filter_by + mock_session.scalar.side_effect = [mock_cached_embedding, None, None] mock_model_instance.invoke_text_embedding.return_value = embedding_result # Act @@ -392,7 +380,7 @@ class TestCacheEmbeddingDocuments: ) with patch("core.rag.embedding.cached_embedding.db.session") as mock_session: - mock_session.query.return_value.filter_by.return_value.first.return_value = None + mock_session.scalar.return_value = None # Mock model to return appropriate batch results batch_results = [ @@ -455,7 +443,7 @@ class TestCacheEmbeddingDocuments: ) with patch("core.rag.embedding.cached_embedding.db.session") as mock_session: - mock_session.query.return_value.filter_by.return_value.first.return_value = None + mock_session.scalar.return_value = None mock_model_instance.invoke_text_embedding.return_value = embedding_result with patch("core.rag.embedding.cached_embedding.logger") as mock_logger: @@ -489,7 +477,7 @@ class TestCacheEmbeddingDocuments: texts = ["Test text"] with patch("core.rag.embedding.cached_embedding.db.session") as mock_session: - mock_session.query.return_value.filter_by.return_value.first.return_value = None + mock_session.scalar.return_value = None # Mock model to raise connection error mock_model_instance.invoke_text_embedding.side_effect = InvokeConnectionError("Failed to connect to API") @@ -515,7 +503,7 @@ class TestCacheEmbeddingDocuments: texts = ["Test text"] with patch("core.rag.embedding.cached_embedding.db.session") as mock_session: - mock_session.query.return_value.filter_by.return_value.first.return_value = None + mock_session.scalar.return_value = None # Mock model to raise rate limit error mock_model_instance.invoke_text_embedding.side_effect = InvokeRateLimitError("Rate limit exceeded") @@ -539,7 +527,7 @@ class TestCacheEmbeddingDocuments: texts = ["Test text"] with patch("core.rag.embedding.cached_embedding.db.session") as mock_session: - mock_session.query.return_value.filter_by.return_value.first.return_value = None + mock_session.scalar.return_value = None # Mock model to raise authorization error mock_model_instance.invoke_text_embedding.side_effect = InvokeAuthorizationError("Invalid API key") @@ -564,7 +552,7 @@ class TestCacheEmbeddingDocuments: texts = ["Test text"] with patch("core.rag.embedding.cached_embedding.db.session") as mock_session: - mock_session.query.return_value.filter_by.return_value.first.return_value = None + mock_session.scalar.return_value = None mock_model_instance.invoke_text_embedding.return_value = sample_embedding_result # Mock database commit to raise IntegrityError @@ -884,7 +872,7 @@ class TestEmbeddingModelSwitching: ) with patch("core.rag.embedding.cached_embedding.db.session") as mock_session: - mock_session.query.return_value.filter_by.return_value.first.return_value = None + mock_session.scalar.return_value = None model_instance_ada.invoke_text_embedding.return_value = result_ada model_instance_3_small.invoke_text_embedding.return_value = result_3_small @@ -1047,7 +1035,7 @@ class TestEmbeddingDimensionValidation: ) with patch("core.rag.embedding.cached_embedding.db.session") as mock_session: - mock_session.query.return_value.filter_by.return_value.first.return_value = None + mock_session.scalar.return_value = None mock_model_instance.invoke_text_embedding.return_value = embedding_result # Act @@ -1100,7 +1088,7 @@ class TestEmbeddingDimensionValidation: ) with patch("core.rag.embedding.cached_embedding.db.session") as mock_session: - mock_session.query.return_value.filter_by.return_value.first.return_value = None + mock_session.scalar.return_value = None mock_model_instance.invoke_text_embedding.return_value = embedding_result # Act @@ -1186,7 +1174,7 @@ class TestEmbeddingDimensionValidation: ) with patch("core.rag.embedding.cached_embedding.db.session") as mock_session: - mock_session.query.return_value.filter_by.return_value.first.return_value = None + mock_session.scalar.return_value = None model_instance_ada.invoke_text_embedding.return_value = result_ada model_instance_cohere.invoke_text_embedding.return_value = result_cohere @@ -1284,7 +1272,7 @@ class TestEmbeddingEdgeCases: ) with patch("core.rag.embedding.cached_embedding.db.session") as mock_session: - mock_session.query.return_value.filter_by.return_value.first.return_value = None + mock_session.scalar.return_value = None mock_model_instance.invoke_text_embedding.return_value = embedding_result # Act @@ -1327,7 +1315,7 @@ class TestEmbeddingEdgeCases: ) with patch("core.rag.embedding.cached_embedding.db.session") as mock_session: - mock_session.query.return_value.filter_by.return_value.first.return_value = None + mock_session.scalar.return_value = None mock_model_instance.invoke_text_embedding.return_value = embedding_result # Act @@ -1375,7 +1363,7 @@ class TestEmbeddingEdgeCases: ) with patch("core.rag.embedding.cached_embedding.db.session") as mock_session: - mock_session.query.return_value.filter_by.return_value.first.return_value = None + mock_session.scalar.return_value = None mock_model_instance.invoke_text_embedding.return_value = embedding_result # Act @@ -1427,7 +1415,7 @@ class TestEmbeddingEdgeCases: ) with patch("core.rag.embedding.cached_embedding.db.session") as mock_session: - mock_session.query.return_value.filter_by.return_value.first.return_value = None + mock_session.scalar.return_value = None mock_model_instance.invoke_text_embedding.return_value = embedding_result # Act @@ -1483,7 +1471,7 @@ class TestEmbeddingEdgeCases: ) with patch("core.rag.embedding.cached_embedding.db.session") as mock_session: - mock_session.query.return_value.filter_by.return_value.first.return_value = None + mock_session.scalar.return_value = None mock_model_instance.invoke_text_embedding.return_value = embedding_result # Act @@ -1551,7 +1539,7 @@ class TestEmbeddingEdgeCases: ) with patch("core.rag.embedding.cached_embedding.db.session") as mock_session: - mock_session.query.return_value.filter_by.return_value.first.return_value = None + mock_session.scalar.return_value = None mock_model_instance.invoke_text_embedding.return_value = embedding_result # Act @@ -1649,7 +1637,7 @@ class TestEmbeddingEdgeCases: ) with patch("core.rag.embedding.cached_embedding.db.session") as mock_session: - mock_session.query.return_value.filter_by.return_value.first.return_value = None + mock_session.scalar.return_value = None mock_model_instance.invoke_text_embedding.return_value = embedding_result # Act @@ -1728,7 +1716,7 @@ class TestEmbeddingCachePerformance: with patch("core.rag.embedding.cached_embedding.db.session") as mock_session: # First call: cache miss - mock_session.query.return_value.filter_by.return_value.first.return_value = None + mock_session.scalar.return_value = None usage = EmbeddingUsage( tokens=5, @@ -1756,7 +1744,7 @@ class TestEmbeddingCachePerformance: assert len(result1) == 1 # Arrange - Second call: cache hit - mock_session.query.return_value.filter_by.return_value.first.return_value = mock_cached_embedding + mock_session.scalar.return_value = mock_cached_embedding # Act - Second call (cache hit) result2 = cache_embedding.embed_documents([text]) @@ -1816,7 +1804,7 @@ class TestEmbeddingCachePerformance: ) with patch("core.rag.embedding.cached_embedding.db.session") as mock_session: - mock_session.query.return_value.filter_by.return_value.first.return_value = None + mock_session.scalar.return_value = None # Mock model to return appropriate batch results batch_results = [ diff --git a/api/tests/unit_tests/core/rag/extractor/test_notion_extractor.py b/api/tests/unit_tests/core/rag/extractor/test_notion_extractor.py index 6daee11f8f..808e41867e 100644 --- a/api/tests/unit_tests/core/rag/extractor/test_notion_extractor.py +++ b/api/tests/unit_tests/core/rag/extractor/test_notion_extractor.py @@ -405,35 +405,36 @@ class TestNotionMetadataAndCredentialMethods: class FakeDocumentModel: data_source_info = "data_source_info" + id = "id" - update_calls = [] + execute_calls = [] - class FakeQuery: - def filter_by(self, **kwargs): + class FakeUpdateStmt: + def where(self, *args): return self - def update(self, payload): - update_calls.append(payload) + def values(self, **kwargs): + return self class FakeSession: committed = False - def query(self, model): - assert model is FakeDocumentModel - return FakeQuery() + def execute(self, stmt): + execute_calls.append(stmt) def commit(self): self.committed = True fake_db = SimpleNamespace(session=FakeSession()) monkeypatch.setattr(notion_extractor, "DocumentModel", FakeDocumentModel) + monkeypatch.setattr(notion_extractor, "update", lambda model: FakeUpdateStmt()) monkeypatch.setattr(notion_extractor, "db", fake_db) monkeypatch.setattr(extractor, "get_notion_last_edited_time", lambda: "2026-01-01T00:00:00.000Z") doc_model = SimpleNamespace(id="doc-1", data_source_info_dict={"source": "notion"}) extractor.update_last_edited_time(doc_model) - assert update_calls + assert execute_calls assert fake_db.session.committed is True def test_get_notion_last_edited_time_uses_page_and_database_urls(self, mocker: MockerFixture): diff --git a/api/tests/unit_tests/core/rag/indexing/processor/test_paragraph_index_processor.py b/api/tests/unit_tests/core/rag/indexing/processor/test_paragraph_index_processor.py index cc2873dd3f..d4b987c832 100644 --- a/api/tests/unit_tests/core/rag/indexing/processor/test_paragraph_index_processor.py +++ b/api/tests/unit_tests/core/rag/indexing/processor/test_paragraph_index_processor.py @@ -188,10 +188,10 @@ class TestParagraphIndexProcessor: mock_keyword_cls.return_value.add_texts.assert_called_once_with(docs) def test_clean_deletes_summaries_and_vector(self, processor: ParagraphIndexProcessor, dataset: Mock) -> None: - segment_query = Mock() - segment_query.filter.return_value.all.return_value = [SimpleNamespace(id="seg-1")] + scalars_result = Mock() + scalars_result.all.return_value = [SimpleNamespace(id="seg-1")] session = Mock() - session.query.return_value = segment_query + session.scalars.return_value = scalars_result with ( patch("core.rag.index_processor.processor.paragraph_index_processor.db.session", session), @@ -531,10 +531,10 @@ class TestParagraphIndexProcessor: size=1, key="key", ) - query = Mock() - query.where.return_value.all.return_value = [image_upload, non_image_upload] + scalars_result = Mock() + scalars_result.all.return_value = [image_upload, non_image_upload] session = Mock() - session.query.return_value = query + session.scalars.return_value = scalars_result with ( patch("core.rag.index_processor.processor.paragraph_index_processor.db.session", session), @@ -565,10 +565,10 @@ class TestParagraphIndexProcessor: size=1, key="key", ) - query = Mock() - query.where.return_value.all.return_value = [image_upload] + scalars_result = Mock() + scalars_result.all.return_value = [image_upload] session = Mock() - session.query.return_value = query + session.scalars.return_value = scalars_result with ( patch("core.rag.index_processor.processor.paragraph_index_processor.db.session", session), diff --git a/api/tests/unit_tests/core/rag/indexing/processor/test_parent_child_index_processor.py b/api/tests/unit_tests/core/rag/indexing/processor/test_parent_child_index_processor.py index b1ed735ee7..d363a0804d 100644 --- a/api/tests/unit_tests/core/rag/indexing/processor/test_parent_child_index_processor.py +++ b/api/tests/unit_tests/core/rag/indexing/processor/test_parent_child_index_processor.py @@ -208,11 +208,7 @@ class TestParentChildIndexProcessor: vector.create_multimodal.assert_called_once_with(multimodal_docs) def test_clean_with_precomputed_child_ids(self, processor: ParentChildIndexProcessor, dataset: Mock) -> None: - delete_query = Mock() - where_query = Mock() - where_query.delete.return_value = 2 session = Mock() - session.query.return_value.where.return_value = where_query with ( patch("core.rag.index_processor.processor.parent_child_index_processor.Vector") as mock_vector_cls, @@ -227,16 +223,16 @@ class TestParentChildIndexProcessor: ) vector.delete_by_ids.assert_called_once_with(["child-1", "child-2"]) - where_query.delete.assert_called_once_with(synchronize_session=False) + session.execute.assert_called() session.commit.assert_called_once() def test_clean_queries_child_ids_when_not_precomputed( self, processor: ParentChildIndexProcessor, dataset: Mock ) -> None: - child_query = Mock() - child_query.join.return_value.where.return_value.all.return_value = [("child-1",), (None,), ("child-2",)] + execute_result = Mock() + execute_result.all.return_value = [("child-1",), (None,), ("child-2",)] session = Mock() - session.query.return_value = child_query + session.execute.return_value = execute_result with ( patch("core.rag.index_processor.processor.parent_child_index_processor.Vector") as mock_vector_cls, @@ -248,10 +244,7 @@ class TestParentChildIndexProcessor: vector.delete_by_ids.assert_called_once_with(["child-1", "child-2"]) def test_clean_dataset_wide_cleanup(self, processor: ParentChildIndexProcessor, dataset: Mock) -> None: - where_query = Mock() - where_query.delete.return_value = 3 session = Mock() - session.query.return_value.where.return_value = where_query with ( patch("core.rag.index_processor.processor.parent_child_index_processor.Vector") as mock_vector_cls, @@ -261,7 +254,7 @@ class TestParentChildIndexProcessor: processor.clean(dataset, None, delete_child_chunks=True) vector.delete.assert_called_once() - where_query.delete.assert_called_once_with(synchronize_session=False) + session.execute.assert_called() session.commit.assert_called_once() def test_clean_deletes_summaries_when_requested(self, processor: ParentChildIndexProcessor, dataset: Mock) -> None: diff --git a/api/tests/unit_tests/core/rag/indexing/test_index_processor_base.py b/api/tests/unit_tests/core/rag/indexing/test_index_processor_base.py index b31bb6eea7..12c5238f5e 100644 --- a/api/tests/unit_tests/core/rag/indexing/test_index_processor_base.py +++ b/api/tests/unit_tests/core/rag/indexing/test_index_processor_base.py @@ -133,10 +133,10 @@ class TestBaseIndexProcessor: upload_b = SimpleNamespace(id="bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb", name="b.png") upload_tool = SimpleNamespace(id="tool-upload-id", name="tool.png") upload_remote = SimpleNamespace(id="remote-upload-id", name="remote.png") - db_query = Mock() - db_query.where.return_value.all.return_value = [upload_a, upload_b, upload_tool, upload_remote] + scalars_result = Mock() + scalars_result.all.return_value = [upload_a, upload_b, upload_tool, upload_remote] db_session = Mock() - db_session.query.return_value = db_query + db_session.scalars.return_value = scalars_result with ( patch.object(processor, "_extract_markdown_images", return_value=images), @@ -170,10 +170,10 @@ class TestBaseIndexProcessor: def test_get_content_files_ignores_missing_upload_records(self, processor: _ForwardingBaseIndexProcessor) -> None: document = Document(page_content="ignored", metadata={"document_id": "doc-1", "dataset_id": "ds-1"}) images = ["/files/aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa/image-preview"] - db_query = Mock() - db_query.where.return_value.all.return_value = [] + scalars_result = Mock() + scalars_result.all.return_value = [] db_session = Mock() - db_session.query.return_value = db_query + db_session.scalars.return_value = scalars_result with ( patch.object(processor, "_extract_markdown_images", return_value=images), @@ -259,20 +259,16 @@ class TestBaseIndexProcessor: assert processor._download_image("https://example.com/image.png", current_user=Mock()) is None def test_download_tool_file_returns_none_when_not_found(self, processor: _ForwardingBaseIndexProcessor) -> None: - db_query = Mock() - db_query.where.return_value.first.return_value = None db_session = Mock() - db_session.query.return_value = db_query + db_session.get.return_value = None with patch("core.rag.index_processor.index_processor_base.db.session", db_session): assert processor._download_tool_file("tool-id", current_user=Mock()) is None def test_download_tool_file_uploads_file_when_found(self, processor: _ForwardingBaseIndexProcessor) -> None: tool_file = SimpleNamespace(file_key="k1", name="tool.png", mimetype="image/png") - db_query = Mock() - db_query.where.return_value.first.return_value = tool_file db_session = Mock() - db_session.query.return_value = db_query + db_session.get.return_value = tool_file mock_db = Mock() mock_db.session = db_session mock_db.engine = Mock() diff --git a/api/tests/unit_tests/core/rag/rerank/test_reranker.py b/api/tests/unit_tests/core/rag/rerank/test_reranker.py index 2ec7f0498e..c279b00d3b 100644 --- a/api/tests/unit_tests/core/rag/rerank/test_reranker.py +++ b/api/tests/unit_tests/core/rag/rerank/test_reranker.py @@ -473,12 +473,10 @@ class TestRerankModelRunnerMultimodal: metadata={}, provider="external", ) - query = Mock() - query.where.return_value.first.return_value = SimpleNamespace(key="image-key") rerank_result = RerankResult(model="rerank-model", docs=[]) with ( - patch("core.rag.rerank.rerank_model.db.session.query", return_value=query), + patch("core.rag.rerank.rerank_model.db.session.get", return_value=SimpleNamespace(key="image-key")), patch("core.rag.rerank.rerank_model.storage.load_once", return_value=b"image-bytes") as mock_load_once, patch.object( rerank_runner, @@ -504,12 +502,10 @@ class TestRerankModelRunnerMultimodal: metadata={"doc_id": "img-missing", "doc_type": DocType.IMAGE}, provider="dify", ) - query = Mock() - query.where.return_value.first.return_value = None rerank_result = RerankResult(model="rerank-model", docs=[]) with ( - patch("core.rag.rerank.rerank_model.db.session.query", return_value=query), + patch("core.rag.rerank.rerank_model.db.session.get", return_value=None), patch.object( rerank_runner, "fetch_text_rerank", @@ -533,8 +529,6 @@ class TestRerankModelRunnerMultimodal: metadata={"doc_id": "txt-1", "doc_type": DocType.TEXT}, provider="dify", ) - query_chain = Mock() - query_chain.where.return_value.first.return_value = SimpleNamespace(key="query-image-key") rerank_result = RerankResult( model="rerank-model", docs=[RerankDocument(index=0, text="text-content", score=0.77)], @@ -542,7 +536,7 @@ class TestRerankModelRunnerMultimodal: mock_model_instance.invoke_multimodal_rerank.return_value = rerank_result session = MagicMock() - session.query.return_value = query_chain + session.get.return_value = SimpleNamespace(key="query-image-key") with ( patch("core.rag.rerank.rerank_model.db.session", session), patch("core.rag.rerank.rerank_model.storage.load_once", return_value=b"query-image-bytes"), @@ -563,10 +557,7 @@ class TestRerankModelRunnerMultimodal: assert "user" not in invoke_kwargs def test_fetch_multimodal_rerank_raises_when_query_image_not_found(self, rerank_runner): - query_chain = Mock() - query_chain.where.return_value.first.return_value = None - - with patch("core.rag.rerank.rerank_model.db.session.query", return_value=query_chain): + with patch("core.rag.rerank.rerank_model.db.session.get", return_value=None): with pytest.raises(ValueError, match="Upload file not found for query"): rerank_runner.fetch_multimodal_rerank( query="missing-upload-id", diff --git a/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py b/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py index c11426163e..fee7b168ad 100644 --- a/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py +++ b/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py @@ -3971,11 +3971,10 @@ class TestDatasetRetrievalAdditionalHelpers: ) def test_get_metadata_filter_condition(self, retrieval: DatasetRetrieval) -> None: - db_query = Mock() - db_query.where.return_value = db_query - db_query.all.return_value = [SimpleNamespace(dataset_id="d1", id="doc-1")] + scalars_result = Mock() + scalars_result.all.return_value = [SimpleNamespace(dataset_id="d1", id="doc-1")] - with patch("core.rag.retrieval.dataset_retrieval.db.session.query", return_value=db_query): + with patch("core.rag.retrieval.dataset_retrieval.db.session.scalars", return_value=scalars_result): mapping, condition = retrieval.get_metadata_filter_condition( dataset_ids=["d1"], query="python", @@ -3991,7 +3990,7 @@ class TestDatasetRetrievalAdditionalHelpers: automatic_filters = [{"condition": "contains", "metadata_name": "author", "value": "Alice"}] with ( - patch("core.rag.retrieval.dataset_retrieval.db.session.query", return_value=db_query), + patch("core.rag.retrieval.dataset_retrieval.db.session.scalars", return_value=scalars_result), patch.object(retrieval, "_automatic_metadata_filter_func", return_value=automatic_filters), ): mapping, condition = retrieval.get_metadata_filter_condition( @@ -4012,7 +4011,7 @@ class TestDatasetRetrievalAdditionalHelpers: logical_operator="and", conditions=[AppCondition(name="author", comparison_operator="contains", value="{{name}}")], ) - with patch("core.rag.retrieval.dataset_retrieval.db.session.query", return_value=db_query): + with patch("core.rag.retrieval.dataset_retrieval.db.session.scalars", return_value=scalars_result): mapping, condition = retrieval.get_metadata_filter_condition( dataset_ids=["d1"], query="python", @@ -4027,7 +4026,7 @@ class TestDatasetRetrievalAdditionalHelpers: assert condition is not None assert condition.conditions[0].value == "Alice" - with patch("core.rag.retrieval.dataset_retrieval.db.session.query", return_value=db_query): + with patch("core.rag.retrieval.dataset_retrieval.db.session.scalars", return_value=scalars_result): with pytest.raises(ValueError, match="Invalid metadata filtering mode"): retrieval.get_metadata_filter_condition( dataset_ids=["d1"],