refactor: core/rag docstore, datasource, embedding, rerank, retrieval (#34203)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: Asuka Minato <i@asukaminato.eu.org>
This commit is contained in:
Renzo
2026-03-30 10:09:49 +02:00
committed by GitHub
parent 40fa0f365c
commit 456684dfc3
24 changed files with 170 additions and 214 deletions

View File

@@ -97,13 +97,13 @@ class Jieba(BaseKeyword):
documents = []
segment_query_stmt = db.session.query(DocumentSegment).where(
segment_query_stmt = select(DocumentSegment).where(
DocumentSegment.dataset_id == self.dataset.id, DocumentSegment.index_node_id.in_(sorted_chunk_indices)
)
if document_ids_filter:
segment_query_stmt = segment_query_stmt.where(DocumentSegment.document_id.in_(document_ids_filter))
segments = db.session.execute(segment_query_stmt).scalars().all()
segments = db.session.scalars(segment_query_stmt).all()
segment_map = {segment.index_node_id: segment for segment in segments}
for chunk_index in sorted_chunk_indices:
segment = segment_map.get(chunk_index)

View File

@@ -432,10 +432,11 @@ class RetrievalService:
# Batch query dataset documents
dataset_documents = {
doc.id: doc
for doc in db.session.query(DatasetDocument)
.where(DatasetDocument.id.in_(document_ids))
.options(load_only(DatasetDocument.id, DatasetDocument.doc_form, DatasetDocument.dataset_id))
.all()
for doc in db.session.scalars(
select(DatasetDocument)
.where(DatasetDocument.id.in_(document_ids))
.options(load_only(DatasetDocument.id, DatasetDocument.doc_form, DatasetDocument.dataset_id))
).all()
}
valid_dataset_documents = {}

View File

@@ -426,11 +426,10 @@ class TidbOnQdrantVectorFactory(AbstractVectorFactory):
TIDB_ON_QDRANT_API_KEY = f"{tidb_auth_binding.account}:{tidb_auth_binding.password}"
else:
idle_tidb_auth_binding = (
db.session.query(TidbAuthBinding)
idle_tidb_auth_binding = db.session.scalar(
select(TidbAuthBinding)
.where(TidbAuthBinding.active == False, TidbAuthBinding.status == "ACTIVE")
.limit(1)
.one_or_none()
)
if idle_tidb_auth_binding:
idle_tidb_auth_binding.active = True

View File

@@ -277,7 +277,7 @@ class Vector:
return self._vector_processor.search_by_vector(query_vector, **kwargs)
def search_by_file(self, file_id: str, **kwargs: Any) -> list[Document]:
upload_file: UploadFile | None = db.session.query(UploadFile).where(UploadFile.id == file_id).first()
upload_file: UploadFile | None = db.session.get(UploadFile, file_id)
if not upload_file:
return []

View File

@@ -4,7 +4,7 @@ from collections.abc import Sequence
from typing import Any
from graphon.model_runtime.entities.model_entities import ModelType
from sqlalchemy import func, select
from sqlalchemy import delete, func, select
from core.model_manager import ModelManager
from core.rag.index_processor.constant.index_type import IndexTechniqueType
@@ -63,10 +63,8 @@ class DatasetDocumentStore:
return output
def add_documents(self, docs: Sequence[Document], allow_update: bool = True, save_child: bool = False):
max_position = (
db.session.query(func.max(DocumentSegment.position))
.where(DocumentSegment.document_id == self._document_id)
.scalar()
max_position = db.session.scalar(
select(func.max(DocumentSegment.position)).where(DocumentSegment.document_id == self._document_id)
)
if max_position is None:
@@ -155,12 +153,14 @@ class DatasetDocumentStore:
)
if save_child and doc.children:
# delete the existing child chunks
db.session.query(ChildChunk).where(
ChildChunk.tenant_id == self._dataset.tenant_id,
ChildChunk.dataset_id == self._dataset.id,
ChildChunk.document_id == self._document_id,
ChildChunk.segment_id == segment_document.id,
).delete()
db.session.execute(
delete(ChildChunk).where(
ChildChunk.tenant_id == self._dataset.tenant_id,
ChildChunk.dataset_id == self._dataset.id,
ChildChunk.document_id == self._document_id,
ChildChunk.segment_id == segment_document.id,
)
)
# add new child chunks
for position, child in enumerate(doc.children, start=1):
child_segment = ChildChunk(

View File

@@ -6,6 +6,7 @@ from typing import Any, cast
import numpy as np
from graphon.model_runtime.entities.model_entities import ModelPropertyKey
from graphon.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from sqlalchemy import select
from sqlalchemy.exc import IntegrityError
from configs import dify_config
@@ -31,14 +32,14 @@ class CacheEmbedding(Embeddings):
embedding_queue_indices = []
for i, text in enumerate(texts):
hash = helper.generate_text_hash(text)
embedding = (
db.session.query(Embedding)
.filter_by(
model_name=self._model_instance.model_name,
hash=hash,
provider_name=self._model_instance.provider,
embedding = db.session.scalar(
select(Embedding)
.where(
Embedding.model_name == self._model_instance.model_name,
Embedding.hash == hash,
Embedding.provider_name == self._model_instance.provider,
)
.first()
.limit(1)
)
if embedding:
text_embeddings[i] = embedding.get_embedding()
@@ -112,14 +113,14 @@ class CacheEmbedding(Embeddings):
embedding_queue_indices = []
for i, multimodel_document in enumerate(multimodel_documents):
file_id = multimodel_document["file_id"]
embedding = (
db.session.query(Embedding)
.filter_by(
model_name=self._model_instance.model_name,
hash=file_id,
provider_name=self._model_instance.provider,
embedding = db.session.scalar(
select(Embedding)
.where(
Embedding.model_name == self._model_instance.model_name,
Embedding.hash == file_id,
Embedding.provider_name == self._model_instance.provider,
)
.first()
.limit(1)
)
if embedding:
multimodel_embeddings[i] = embedding.get_embedding()

View File

@@ -4,6 +4,7 @@ import operator
from typing import Any, cast
import httpx
from sqlalchemy import update
from configs import dify_config
from core.rag.extractor.extractor_base import BaseExtractor
@@ -346,9 +347,11 @@ class NotionExtractor(BaseExtractor):
if data_source_info:
data_source_info["last_edited_time"] = last_edited_time
db.session.query(DocumentModel).filter_by(id=document_model.id).update(
{DocumentModel.data_source_info: json.dumps(data_source_info)}
) # type: ignore
db.session.execute(
update(DocumentModel)
.where(DocumentModel.id == document_model.id)
.values(data_source_info=json.dumps(data_source_info))
)
db.session.commit()
def get_notion_last_edited_time(self) -> str:

View File

@@ -11,6 +11,7 @@ from typing import TYPE_CHECKING, Any, NotRequired, Optional
from urllib.parse import unquote, urlparse
import httpx
from sqlalchemy import select
from typing_extensions import TypedDict
from configs import dify_config
@@ -200,7 +201,7 @@ class BaseIndexProcessor(ABC):
# Get unique IDs for database query
unique_upload_file_ids = list(set(upload_file_id_list))
upload_files = db.session.query(UploadFile).where(UploadFile.id.in_(unique_upload_file_ids)).all()
upload_files = db.session.scalars(select(UploadFile).where(UploadFile.id.in_(unique_upload_file_ids))).all()
# Create a mapping from ID to UploadFile for quick lookup
upload_file_map = {upload_file.id: upload_file for upload_file in upload_files}
@@ -312,7 +313,7 @@ class BaseIndexProcessor(ABC):
"""
from services.file_service import FileService
tool_file = db.session.query(ToolFile).where(ToolFile.id == tool_file_id).first()
tool_file = db.session.get(ToolFile, tool_file_id)
if not tool_file:
return None
blob = storage.load_once(tool_file.file_key)

View File

@@ -18,6 +18,7 @@ from graphon.model_runtime.entities.message_entities import (
UserPromptMessage,
)
from graphon.model_runtime.entities.model_entities import ModelFeature, ModelType
from sqlalchemy import select
from core.app.file_access import DatabaseFileAccessController
from core.app.llm import deduct_llm_quota
@@ -145,14 +146,12 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
if delete_summaries:
if node_ids:
# Find segments by index_node_id
segments = (
db.session.query(DocumentSegment)
.filter(
segments = db.session.scalars(
select(DocumentSegment).where(
DocumentSegment.dataset_id == dataset.id,
DocumentSegment.index_node_id.in_(node_ids),
)
.all()
)
).all()
segment_ids = [segment.id for segment in segments]
if segment_ids:
SummaryIndexService.delete_summaries_for_segments(dataset, segment_ids)
@@ -537,11 +536,9 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
# Get unique IDs for database query
unique_upload_file_ids = list(set(upload_file_id_list))
upload_files = (
db.session.query(UploadFile)
.where(UploadFile.id.in_(unique_upload_file_ids), UploadFile.tenant_id == tenant_id)
.all()
)
upload_files = db.session.scalars(
select(UploadFile).where(UploadFile.id.in_(unique_upload_file_ids), UploadFile.tenant_id == tenant_id)
).all()
# Create File objects from UploadFile records
file_objects = []

View File

@@ -6,6 +6,8 @@ import uuid
from collections.abc import Mapping
from typing import Any
from sqlalchemy import delete, select
from configs import dify_config
from core.db.session_factory import session_factory
from core.entities.knowledge_entities import PreviewDetail
@@ -177,17 +179,16 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
child_node_ids = precomputed_child_node_ids
else:
# Fallback to original query (may fail if segments are already deleted)
child_node_ids = (
db.session.query(ChildChunk.index_node_id)
rows = db.session.execute(
select(ChildChunk.index_node_id)
.join(DocumentSegment, ChildChunk.segment_id == DocumentSegment.id)
.where(
DocumentSegment.dataset_id == dataset.id,
DocumentSegment.index_node_id.in_(node_ids),
ChildChunk.dataset_id == dataset.id,
)
.all()
)
child_node_ids = [child_node_id[0] for child_node_id in child_node_ids if child_node_id[0]]
).all()
child_node_ids = [row[0] for row in rows if row[0]]
# Delete from vector index
if child_node_ids:
@@ -195,18 +196,22 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
# Delete from database
if delete_child_chunks and child_node_ids:
db.session.query(ChildChunk).where(
ChildChunk.dataset_id == dataset.id, ChildChunk.index_node_id.in_(child_node_ids)
).delete(synchronize_session=False)
db.session.execute(
delete(ChildChunk).where(
ChildChunk.dataset_id == dataset.id, ChildChunk.index_node_id.in_(child_node_ids)
)
)
db.session.commit()
else:
vector.delete()
if delete_child_chunks:
# Use existing compound index: (tenant_id, dataset_id, ...)
db.session.query(ChildChunk).where(
ChildChunk.tenant_id == dataset.tenant_id, ChildChunk.dataset_id == dataset.id
).delete(synchronize_session=False)
db.session.execute(
delete(ChildChunk).where(
ChildChunk.tenant_id == dataset.tenant_id, ChildChunk.dataset_id == dataset.id
)
)
db.session.commit()
def retrieve(

View File

@@ -134,9 +134,7 @@ class RerankModelRunner(BaseRerankRunner):
):
if document.metadata.get("doc_type") == DocType.IMAGE:
# Query file info within db.session context to ensure thread-safe access
upload_file = (
db.session.query(UploadFile).where(UploadFile.id == document.metadata["doc_id"]).first()
)
upload_file = db.session.get(UploadFile, document.metadata["doc_id"])
if upload_file:
blob = storage.load_once(upload_file.key)
document_file_base64 = base64.b64encode(blob).decode()
@@ -169,7 +167,7 @@ class RerankModelRunner(BaseRerankRunner):
return rerank_result, unique_documents
elif query_type == QueryType.IMAGE_QUERY:
# Query file info within db.session context to ensure thread-safe access
upload_file = db.session.query(UploadFile).where(UploadFile.id == query).first()
upload_file = db.session.get(UploadFile, query)
if upload_file:
blob = storage.load_once(upload_file.key)
file_query = base64.b64encode(blob).decode()

View File

@@ -1340,7 +1340,7 @@ class DatasetRetrieval:
metadata_filtering_conditions: MetadataFilteringCondition | None,
inputs: dict,
) -> tuple[dict[str, list[str]] | None, MetadataCondition | None]:
document_query = db.session.query(DatasetDocument).where(
document_query = select(DatasetDocument).where(
DatasetDocument.dataset_id.in_(dataset_ids),
DatasetDocument.indexing_status == "completed",
DatasetDocument.enabled == True,
@@ -1411,7 +1411,7 @@ class DatasetRetrieval:
document_query = document_query.where(and_(*filters))
else:
document_query = document_query.where(or_(*filters))
documents = document_query.all()
documents = db.session.scalars(document_query).all()
# group by dataset_id
metadata_filter_document_ids = defaultdict(list) if documents else None # type: ignore
for document in documents:

View File

@@ -201,27 +201,23 @@ def test_search_returns_documents_in_rank_order_and_applies_filter(monkeypatch,
document_id = _Field("document_id")
keyword = Jieba(_dataset(_dataset_keyword_table()))
query_stmt = _FakeQuery()
patched_runtime.session.query.return_value = query_stmt
patched_runtime.session.execute.return_value = _FakeExecuteResult(
[
SimpleNamespace(
index_node_id="node-2",
content="segment-content",
index_node_hash="hash-2",
document_id="doc-2",
dataset_id="dataset-1",
)
]
)
patched_runtime.session.scalars.return_value.all.return_value = [
SimpleNamespace(
index_node_id="node-2",
content="segment-content",
index_node_hash="hash-2",
document_id="doc-2",
dataset_id="dataset-1",
)
]
monkeypatch.setattr(jieba_module, "DocumentSegment", _FakeDocumentSegment)
monkeypatch.setattr(jieba_module, "select", lambda *_: _FakeSelect())
monkeypatch.setattr(keyword, "_get_dataset_keyword_table", MagicMock(return_value={"k": {"node-1", "node-2"}}))
monkeypatch.setattr(keyword, "_retrieve_ids_by_query", MagicMock(return_value=["node-1", "node-2"]))
documents = keyword.search("query", top_k=2, document_ids_filter=["doc-2"])
assert len(query_stmt.where_calls) == 2
assert len(documents) == 1
assert documents[0].page_content == "segment-content"
assert documents[0].metadata["doc_id"] == "node-2"

View File

@@ -714,13 +714,13 @@ class TestRetrievalServiceInternals:
dataset_id="dataset-id",
)
dataset_query = Mock()
dataset_query.where.return_value.options.return_value.all.return_value = [
scalars_result = Mock()
scalars_result.all.return_value = [
dataset_doc_parent,
dataset_doc_text,
dataset_doc_parent_summary,
]
monkeypatch.setattr(retrieval_service_module.db.session, "query", Mock(return_value=dataset_query))
monkeypatch.setattr(retrieval_service_module.db.session, "scalars", Mock(return_value=scalars_result))
monkeypatch.setattr(retrieval_service_module, "RetrievalChildChunk", _SimpleRetrievalChildChunk)
monkeypatch.setattr(retrieval_service_module, "RetrievalSegments", _SimpleRetrievalSegment)
@@ -882,7 +882,7 @@ class TestRetrievalServiceInternals:
def test_format_retrieval_documents_rolls_back_and_raises_when_db_fails(self, monkeypatch):
rollback = Mock()
monkeypatch.setattr(retrieval_service_module.db.session, "rollback", rollback)
monkeypatch.setattr(retrieval_service_module.db.session, "query", Mock(side_effect=RuntimeError("db error")))
monkeypatch.setattr(retrieval_service_module.db.session, "scalars", Mock(side_effect=RuntimeError("db error")))
documents = [Document(page_content="content", metadata={"document_id": "doc-1"}, provider="dify")]

View File

@@ -340,15 +340,13 @@ def test_search_by_file_handles_missing_and_existing_upload(vector_factory_modul
vector._embeddings = MagicMock()
vector._vector_processor = MagicMock()
mock_session = SimpleNamespace(get=lambda _model, _id: None)
monkeypatch.setattr(vector_factory_module, "UploadFile", SimpleNamespace(id=_Field()))
monkeypatch.setattr(
vector_factory_module, "db", SimpleNamespace(session=SimpleNamespace(query=lambda _model: upload_query))
)
monkeypatch.setattr(vector_factory_module, "db", SimpleNamespace(session=mock_session))
upload_query.first.return_value = None
assert vector.search_by_file("file-1") == []
upload_query.first.return_value = SimpleNamespace(key="blob-key")
mock_session.get = lambda _model, _id: SimpleNamespace(key="blob-key")
monkeypatch.setattr(vector_factory_module.storage, "load_once", MagicMock(return_value=b"file-bytes"))
vector._embeddings.embed_multimodal_query.return_value = [0.3, 0.4]
vector._vector_processor.search_by_vector.return_value = ["hit"]

View File

@@ -167,7 +167,7 @@ class TestDatasetDocumentStoreAddDocuments:
):
mock_session = MagicMock()
mock_db.session = mock_session
mock_db.session.query.return_value.where.return_value.scalar.return_value = None
mock_db.session.scalar.return_value = None
mock_manager = MagicMock()
mock_manager.get_model_instance.return_value = mock_model_instance
@@ -211,7 +211,7 @@ class TestDatasetDocumentStoreAddDocuments:
with patch("core.rag.docstore.dataset_docstore.db") as mock_db:
mock_session = MagicMock()
mock_db.session = mock_session
mock_db.session.query.return_value.where.return_value.scalar.return_value = 5
mock_db.session.scalar.return_value = 5
with patch.object(DatasetDocumentStore, "get_document_segment", return_value=mock_existing_segment):
with patch.object(DatasetDocumentStore, "add_multimodel_documents_binding"):
@@ -276,7 +276,7 @@ class TestDatasetDocumentStoreAddDocuments:
with patch("core.rag.docstore.dataset_docstore.db") as mock_db:
mock_session = MagicMock()
mock_db.session = mock_session
mock_db.session.query.return_value.where.return_value.scalar.return_value = None
mock_db.session.scalar.return_value = None
with patch.object(DatasetDocumentStore, "get_document_segment", return_value=None):
with patch.object(DatasetDocumentStore, "add_multimodel_documents_binding"):
@@ -353,7 +353,7 @@ class TestDatasetDocumentStoreAddDocuments:
with patch("core.rag.docstore.dataset_docstore.db") as mock_db:
mock_session = MagicMock()
mock_db.session = mock_session
mock_db.session.query.return_value.where.return_value.scalar.return_value = None
mock_db.session.scalar.return_value = None
with patch.object(DatasetDocumentStore, "get_document_segment", return_value=None):
with patch.object(DatasetDocumentStore, "add_multimodel_documents_binding"):
@@ -755,7 +755,7 @@ class TestDatasetDocumentStoreAddDocumentsUpdateChild:
with patch("core.rag.docstore.dataset_docstore.db") as mock_db:
mock_session = MagicMock()
mock_db.session = mock_session
mock_db.session.query.return_value.where.return_value.scalar.return_value = 5
mock_db.session.scalar.return_value = 5
with patch.object(DatasetDocumentStore, "get_document_segment", return_value=mock_existing_segment):
with patch.object(DatasetDocumentStore, "add_multimodel_documents_binding"):
@@ -767,7 +767,7 @@ class TestDatasetDocumentStoreAddDocumentsUpdateChild:
store.add_documents([mock_doc], save_child=True)
mock_db.session.query.return_value.where.return_value.delete.assert_called()
mock_db.session.execute.assert_called()
mock_db.session.commit.assert_called()
@@ -798,7 +798,7 @@ class TestDatasetDocumentStoreAddDocumentsUpdateAnswer:
with patch("core.rag.docstore.dataset_docstore.db") as mock_db:
mock_session = MagicMock()
mock_db.session = mock_session
mock_db.session.query.return_value.where.return_value.scalar.return_value = 5
mock_db.session.scalar.return_value = 5
with patch.object(DatasetDocumentStore, "get_document_segment", return_value=mock_existing_segment):
with patch.object(DatasetDocumentStore, "add_multimodel_documents_binding"):

View File

@@ -69,7 +69,7 @@ class TestCacheEmbeddingMultimodalDocuments:
documents = [{"file_id": "file123", "content": "test content"}]
with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
mock_session.query.return_value.filter_by.return_value.first.return_value = None
mock_session.scalar.return_value = None
mock_model_instance.invoke_multimodal_embedding.return_value = sample_multimodal_result
result = cache_embedding.embed_multimodal_documents(documents)
@@ -114,7 +114,7 @@ class TestCacheEmbeddingMultimodalDocuments:
)
with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
mock_session.query.return_value.filter_by.return_value.first.return_value = None
mock_session.scalar.return_value = None
mock_model_instance.invoke_multimodal_embedding.return_value = embedding_result
result = cache_embedding.embed_multimodal_documents(documents)
@@ -134,7 +134,7 @@ class TestCacheEmbeddingMultimodalDocuments:
mock_cached_embedding.get_embedding.return_value = normalized_cached
with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
mock_session.query.return_value.filter_by.return_value.first.return_value = mock_cached_embedding
mock_session.scalar.return_value = mock_cached_embedding
result = cache_embedding.embed_multimodal_documents(documents)
@@ -180,18 +180,7 @@ class TestCacheEmbeddingMultimodalDocuments:
)
with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
call_count = [0]
def mock_filter_by(**kwargs):
call_count[0] += 1
mock_query = Mock()
if call_count[0] == 1:
mock_query.first.return_value = mock_cached_embedding
else:
mock_query.first.return_value = None
return mock_query
mock_session.query.return_value.filter_by = mock_filter_by
mock_session.scalar.side_effect = [mock_cached_embedding, None, None]
mock_model_instance.invoke_multimodal_embedding.return_value = embedding_result
result = cache_embedding.embed_multimodal_documents(documents)
@@ -224,7 +213,7 @@ class TestCacheEmbeddingMultimodalDocuments:
)
with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
mock_session.query.return_value.filter_by.return_value.first.return_value = None
mock_session.scalar.return_value = None
mock_model_instance.invoke_multimodal_embedding.return_value = embedding_result
with patch("core.rag.embedding.cached_embedding.logger") as mock_logger:
@@ -265,7 +254,7 @@ class TestCacheEmbeddingMultimodalDocuments:
)
with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
mock_session.query.return_value.filter_by.return_value.first.return_value = None
mock_session.scalar.return_value = None
batch_results = [create_batch_result(10), create_batch_result(10), create_batch_result(5)]
mock_model_instance.invoke_multimodal_embedding.side_effect = batch_results
@@ -281,7 +270,7 @@ class TestCacheEmbeddingMultimodalDocuments:
documents = [{"file_id": "file123"}]
with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
mock_session.query.return_value.filter_by.return_value.first.return_value = None
mock_session.scalar.return_value = None
mock_model_instance.invoke_multimodal_embedding.side_effect = Exception("API Error")
with pytest.raises(Exception) as exc_info:
@@ -298,7 +287,7 @@ class TestCacheEmbeddingMultimodalDocuments:
documents = [{"file_id": "file123"}]
with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
mock_session.query.return_value.filter_by.return_value.first.return_value = None
mock_session.scalar.return_value = None
mock_model_instance.invoke_multimodal_embedding.return_value = sample_multimodal_result
mock_session.commit.side_effect = IntegrityError("Duplicate key", None, None)

View File

@@ -139,7 +139,7 @@ class TestCacheEmbeddingDocuments:
# Mock database query to return no cached embedding (cache miss)
with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
mock_session.query.return_value.filter_by.return_value.first.return_value = None
mock_session.scalar.return_value = None
# Mock model invocation
mock_model_instance.invoke_text_embedding.return_value = sample_embedding_result
@@ -203,7 +203,7 @@ class TestCacheEmbeddingDocuments:
)
with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
mock_session.query.return_value.filter_by.return_value.first.return_value = None
mock_session.scalar.return_value = None
mock_model_instance.invoke_text_embedding.return_value = embedding_result
# Act
@@ -240,7 +240,7 @@ class TestCacheEmbeddingDocuments:
with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
# Mock database to return cached embedding (cache hit)
mock_session.query.return_value.filter_by.return_value.first.return_value = mock_cached_embedding
mock_session.scalar.return_value = mock_cached_embedding
# Act
result = cache_embedding.embed_documents(texts)
@@ -313,19 +313,7 @@ class TestCacheEmbeddingDocuments:
mock_hash.side_effect = generate_hash
# Mock database to return cached embedding only for first text (hash_1)
call_count = [0]
def mock_filter_by(**kwargs):
call_count[0] += 1
mock_query = Mock()
# First call (hash_1) returns cached, others return None
if call_count[0] == 1:
mock_query.first.return_value = mock_cached_embedding
else:
mock_query.first.return_value = None
return mock_query
mock_session.query.return_value.filter_by = mock_filter_by
mock_session.scalar.side_effect = [mock_cached_embedding, None, None]
mock_model_instance.invoke_text_embedding.return_value = embedding_result
# Act
@@ -392,7 +380,7 @@ class TestCacheEmbeddingDocuments:
)
with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
mock_session.query.return_value.filter_by.return_value.first.return_value = None
mock_session.scalar.return_value = None
# Mock model to return appropriate batch results
batch_results = [
@@ -455,7 +443,7 @@ class TestCacheEmbeddingDocuments:
)
with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
mock_session.query.return_value.filter_by.return_value.first.return_value = None
mock_session.scalar.return_value = None
mock_model_instance.invoke_text_embedding.return_value = embedding_result
with patch("core.rag.embedding.cached_embedding.logger") as mock_logger:
@@ -489,7 +477,7 @@ class TestCacheEmbeddingDocuments:
texts = ["Test text"]
with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
mock_session.query.return_value.filter_by.return_value.first.return_value = None
mock_session.scalar.return_value = None
# Mock model to raise connection error
mock_model_instance.invoke_text_embedding.side_effect = InvokeConnectionError("Failed to connect to API")
@@ -515,7 +503,7 @@ class TestCacheEmbeddingDocuments:
texts = ["Test text"]
with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
mock_session.query.return_value.filter_by.return_value.first.return_value = None
mock_session.scalar.return_value = None
# Mock model to raise rate limit error
mock_model_instance.invoke_text_embedding.side_effect = InvokeRateLimitError("Rate limit exceeded")
@@ -539,7 +527,7 @@ class TestCacheEmbeddingDocuments:
texts = ["Test text"]
with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
mock_session.query.return_value.filter_by.return_value.first.return_value = None
mock_session.scalar.return_value = None
# Mock model to raise authorization error
mock_model_instance.invoke_text_embedding.side_effect = InvokeAuthorizationError("Invalid API key")
@@ -564,7 +552,7 @@ class TestCacheEmbeddingDocuments:
texts = ["Test text"]
with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
mock_session.query.return_value.filter_by.return_value.first.return_value = None
mock_session.scalar.return_value = None
mock_model_instance.invoke_text_embedding.return_value = sample_embedding_result
# Mock database commit to raise IntegrityError
@@ -884,7 +872,7 @@ class TestEmbeddingModelSwitching:
)
with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
mock_session.query.return_value.filter_by.return_value.first.return_value = None
mock_session.scalar.return_value = None
model_instance_ada.invoke_text_embedding.return_value = result_ada
model_instance_3_small.invoke_text_embedding.return_value = result_3_small
@@ -1047,7 +1035,7 @@ class TestEmbeddingDimensionValidation:
)
with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
mock_session.query.return_value.filter_by.return_value.first.return_value = None
mock_session.scalar.return_value = None
mock_model_instance.invoke_text_embedding.return_value = embedding_result
# Act
@@ -1100,7 +1088,7 @@ class TestEmbeddingDimensionValidation:
)
with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
mock_session.query.return_value.filter_by.return_value.first.return_value = None
mock_session.scalar.return_value = None
mock_model_instance.invoke_text_embedding.return_value = embedding_result
# Act
@@ -1186,7 +1174,7 @@ class TestEmbeddingDimensionValidation:
)
with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
mock_session.query.return_value.filter_by.return_value.first.return_value = None
mock_session.scalar.return_value = None
model_instance_ada.invoke_text_embedding.return_value = result_ada
model_instance_cohere.invoke_text_embedding.return_value = result_cohere
@@ -1284,7 +1272,7 @@ class TestEmbeddingEdgeCases:
)
with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
mock_session.query.return_value.filter_by.return_value.first.return_value = None
mock_session.scalar.return_value = None
mock_model_instance.invoke_text_embedding.return_value = embedding_result
# Act
@@ -1327,7 +1315,7 @@ class TestEmbeddingEdgeCases:
)
with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
mock_session.query.return_value.filter_by.return_value.first.return_value = None
mock_session.scalar.return_value = None
mock_model_instance.invoke_text_embedding.return_value = embedding_result
# Act
@@ -1375,7 +1363,7 @@ class TestEmbeddingEdgeCases:
)
with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
mock_session.query.return_value.filter_by.return_value.first.return_value = None
mock_session.scalar.return_value = None
mock_model_instance.invoke_text_embedding.return_value = embedding_result
# Act
@@ -1427,7 +1415,7 @@ class TestEmbeddingEdgeCases:
)
with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
mock_session.query.return_value.filter_by.return_value.first.return_value = None
mock_session.scalar.return_value = None
mock_model_instance.invoke_text_embedding.return_value = embedding_result
# Act
@@ -1483,7 +1471,7 @@ class TestEmbeddingEdgeCases:
)
with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
mock_session.query.return_value.filter_by.return_value.first.return_value = None
mock_session.scalar.return_value = None
mock_model_instance.invoke_text_embedding.return_value = embedding_result
# Act
@@ -1551,7 +1539,7 @@ class TestEmbeddingEdgeCases:
)
with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
mock_session.query.return_value.filter_by.return_value.first.return_value = None
mock_session.scalar.return_value = None
mock_model_instance.invoke_text_embedding.return_value = embedding_result
# Act
@@ -1649,7 +1637,7 @@ class TestEmbeddingEdgeCases:
)
with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
mock_session.query.return_value.filter_by.return_value.first.return_value = None
mock_session.scalar.return_value = None
mock_model_instance.invoke_text_embedding.return_value = embedding_result
# Act
@@ -1728,7 +1716,7 @@ class TestEmbeddingCachePerformance:
with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
# First call: cache miss
mock_session.query.return_value.filter_by.return_value.first.return_value = None
mock_session.scalar.return_value = None
usage = EmbeddingUsage(
tokens=5,
@@ -1756,7 +1744,7 @@ class TestEmbeddingCachePerformance:
assert len(result1) == 1
# Arrange - Second call: cache hit
mock_session.query.return_value.filter_by.return_value.first.return_value = mock_cached_embedding
mock_session.scalar.return_value = mock_cached_embedding
# Act - Second call (cache hit)
result2 = cache_embedding.embed_documents([text])
@@ -1816,7 +1804,7 @@ class TestEmbeddingCachePerformance:
)
with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
mock_session.query.return_value.filter_by.return_value.first.return_value = None
mock_session.scalar.return_value = None
# Mock model to return appropriate batch results
batch_results = [

View File

@@ -405,35 +405,36 @@ class TestNotionMetadataAndCredentialMethods:
class FakeDocumentModel:
data_source_info = "data_source_info"
id = "id"
update_calls = []
execute_calls = []
class FakeQuery:
def filter_by(self, **kwargs):
class FakeUpdateStmt:
def where(self, *args):
return self
def update(self, payload):
update_calls.append(payload)
def values(self, **kwargs):
return self
class FakeSession:
committed = False
def query(self, model):
assert model is FakeDocumentModel
return FakeQuery()
def execute(self, stmt):
execute_calls.append(stmt)
def commit(self):
self.committed = True
fake_db = SimpleNamespace(session=FakeSession())
monkeypatch.setattr(notion_extractor, "DocumentModel", FakeDocumentModel)
monkeypatch.setattr(notion_extractor, "update", lambda model: FakeUpdateStmt())
monkeypatch.setattr(notion_extractor, "db", fake_db)
monkeypatch.setattr(extractor, "get_notion_last_edited_time", lambda: "2026-01-01T00:00:00.000Z")
doc_model = SimpleNamespace(id="doc-1", data_source_info_dict={"source": "notion"})
extractor.update_last_edited_time(doc_model)
assert update_calls
assert execute_calls
assert fake_db.session.committed is True
def test_get_notion_last_edited_time_uses_page_and_database_urls(self, mocker: MockerFixture):

View File

@@ -188,10 +188,10 @@ class TestParagraphIndexProcessor:
mock_keyword_cls.return_value.add_texts.assert_called_once_with(docs)
def test_clean_deletes_summaries_and_vector(self, processor: ParagraphIndexProcessor, dataset: Mock) -> None:
segment_query = Mock()
segment_query.filter.return_value.all.return_value = [SimpleNamespace(id="seg-1")]
scalars_result = Mock()
scalars_result.all.return_value = [SimpleNamespace(id="seg-1")]
session = Mock()
session.query.return_value = segment_query
session.scalars.return_value = scalars_result
with (
patch("core.rag.index_processor.processor.paragraph_index_processor.db.session", session),
@@ -531,10 +531,10 @@ class TestParagraphIndexProcessor:
size=1,
key="key",
)
query = Mock()
query.where.return_value.all.return_value = [image_upload, non_image_upload]
scalars_result = Mock()
scalars_result.all.return_value = [image_upload, non_image_upload]
session = Mock()
session.query.return_value = query
session.scalars.return_value = scalars_result
with (
patch("core.rag.index_processor.processor.paragraph_index_processor.db.session", session),
@@ -565,10 +565,10 @@ class TestParagraphIndexProcessor:
size=1,
key="key",
)
query = Mock()
query.where.return_value.all.return_value = [image_upload]
scalars_result = Mock()
scalars_result.all.return_value = [image_upload]
session = Mock()
session.query.return_value = query
session.scalars.return_value = scalars_result
with (
patch("core.rag.index_processor.processor.paragraph_index_processor.db.session", session),

View File

@@ -208,11 +208,7 @@ class TestParentChildIndexProcessor:
vector.create_multimodal.assert_called_once_with(multimodal_docs)
def test_clean_with_precomputed_child_ids(self, processor: ParentChildIndexProcessor, dataset: Mock) -> None:
delete_query = Mock()
where_query = Mock()
where_query.delete.return_value = 2
session = Mock()
session.query.return_value.where.return_value = where_query
with (
patch("core.rag.index_processor.processor.parent_child_index_processor.Vector") as mock_vector_cls,
@@ -227,16 +223,16 @@ class TestParentChildIndexProcessor:
)
vector.delete_by_ids.assert_called_once_with(["child-1", "child-2"])
where_query.delete.assert_called_once_with(synchronize_session=False)
session.execute.assert_called()
session.commit.assert_called_once()
def test_clean_queries_child_ids_when_not_precomputed(
self, processor: ParentChildIndexProcessor, dataset: Mock
) -> None:
child_query = Mock()
child_query.join.return_value.where.return_value.all.return_value = [("child-1",), (None,), ("child-2",)]
execute_result = Mock()
execute_result.all.return_value = [("child-1",), (None,), ("child-2",)]
session = Mock()
session.query.return_value = child_query
session.execute.return_value = execute_result
with (
patch("core.rag.index_processor.processor.parent_child_index_processor.Vector") as mock_vector_cls,
@@ -248,10 +244,7 @@ class TestParentChildIndexProcessor:
vector.delete_by_ids.assert_called_once_with(["child-1", "child-2"])
def test_clean_dataset_wide_cleanup(self, processor: ParentChildIndexProcessor, dataset: Mock) -> None:
where_query = Mock()
where_query.delete.return_value = 3
session = Mock()
session.query.return_value.where.return_value = where_query
with (
patch("core.rag.index_processor.processor.parent_child_index_processor.Vector") as mock_vector_cls,
@@ -261,7 +254,7 @@ class TestParentChildIndexProcessor:
processor.clean(dataset, None, delete_child_chunks=True)
vector.delete.assert_called_once()
where_query.delete.assert_called_once_with(synchronize_session=False)
session.execute.assert_called()
session.commit.assert_called_once()
def test_clean_deletes_summaries_when_requested(self, processor: ParentChildIndexProcessor, dataset: Mock) -> None:

View File

@@ -133,10 +133,10 @@ class TestBaseIndexProcessor:
upload_b = SimpleNamespace(id="bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb", name="b.png")
upload_tool = SimpleNamespace(id="tool-upload-id", name="tool.png")
upload_remote = SimpleNamespace(id="remote-upload-id", name="remote.png")
db_query = Mock()
db_query.where.return_value.all.return_value = [upload_a, upload_b, upload_tool, upload_remote]
scalars_result = Mock()
scalars_result.all.return_value = [upload_a, upload_b, upload_tool, upload_remote]
db_session = Mock()
db_session.query.return_value = db_query
db_session.scalars.return_value = scalars_result
with (
patch.object(processor, "_extract_markdown_images", return_value=images),
@@ -170,10 +170,10 @@ class TestBaseIndexProcessor:
def test_get_content_files_ignores_missing_upload_records(self, processor: _ForwardingBaseIndexProcessor) -> None:
document = Document(page_content="ignored", metadata={"document_id": "doc-1", "dataset_id": "ds-1"})
images = ["/files/aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa/image-preview"]
db_query = Mock()
db_query.where.return_value.all.return_value = []
scalars_result = Mock()
scalars_result.all.return_value = []
db_session = Mock()
db_session.query.return_value = db_query
db_session.scalars.return_value = scalars_result
with (
patch.object(processor, "_extract_markdown_images", return_value=images),
@@ -259,20 +259,16 @@ class TestBaseIndexProcessor:
assert processor._download_image("https://example.com/image.png", current_user=Mock()) is None
def test_download_tool_file_returns_none_when_not_found(self, processor: _ForwardingBaseIndexProcessor) -> None:
db_query = Mock()
db_query.where.return_value.first.return_value = None
db_session = Mock()
db_session.query.return_value = db_query
db_session.get.return_value = None
with patch("core.rag.index_processor.index_processor_base.db.session", db_session):
assert processor._download_tool_file("tool-id", current_user=Mock()) is None
def test_download_tool_file_uploads_file_when_found(self, processor: _ForwardingBaseIndexProcessor) -> None:
tool_file = SimpleNamespace(file_key="k1", name="tool.png", mimetype="image/png")
db_query = Mock()
db_query.where.return_value.first.return_value = tool_file
db_session = Mock()
db_session.query.return_value = db_query
db_session.get.return_value = tool_file
mock_db = Mock()
mock_db.session = db_session
mock_db.engine = Mock()

View File

@@ -473,12 +473,10 @@ class TestRerankModelRunnerMultimodal:
metadata={},
provider="external",
)
query = Mock()
query.where.return_value.first.return_value = SimpleNamespace(key="image-key")
rerank_result = RerankResult(model="rerank-model", docs=[])
with (
patch("core.rag.rerank.rerank_model.db.session.query", return_value=query),
patch("core.rag.rerank.rerank_model.db.session.get", return_value=SimpleNamespace(key="image-key")),
patch("core.rag.rerank.rerank_model.storage.load_once", return_value=b"image-bytes") as mock_load_once,
patch.object(
rerank_runner,
@@ -504,12 +502,10 @@ class TestRerankModelRunnerMultimodal:
metadata={"doc_id": "img-missing", "doc_type": DocType.IMAGE},
provider="dify",
)
query = Mock()
query.where.return_value.first.return_value = None
rerank_result = RerankResult(model="rerank-model", docs=[])
with (
patch("core.rag.rerank.rerank_model.db.session.query", return_value=query),
patch("core.rag.rerank.rerank_model.db.session.get", return_value=None),
patch.object(
rerank_runner,
"fetch_text_rerank",
@@ -533,8 +529,6 @@ class TestRerankModelRunnerMultimodal:
metadata={"doc_id": "txt-1", "doc_type": DocType.TEXT},
provider="dify",
)
query_chain = Mock()
query_chain.where.return_value.first.return_value = SimpleNamespace(key="query-image-key")
rerank_result = RerankResult(
model="rerank-model",
docs=[RerankDocument(index=0, text="text-content", score=0.77)],
@@ -542,7 +536,7 @@ class TestRerankModelRunnerMultimodal:
mock_model_instance.invoke_multimodal_rerank.return_value = rerank_result
session = MagicMock()
session.query.return_value = query_chain
session.get.return_value = SimpleNamespace(key="query-image-key")
with (
patch("core.rag.rerank.rerank_model.db.session", session),
patch("core.rag.rerank.rerank_model.storage.load_once", return_value=b"query-image-bytes"),
@@ -563,10 +557,7 @@ class TestRerankModelRunnerMultimodal:
assert "user" not in invoke_kwargs
def test_fetch_multimodal_rerank_raises_when_query_image_not_found(self, rerank_runner):
query_chain = Mock()
query_chain.where.return_value.first.return_value = None
with patch("core.rag.rerank.rerank_model.db.session.query", return_value=query_chain):
with patch("core.rag.rerank.rerank_model.db.session.get", return_value=None):
with pytest.raises(ValueError, match="Upload file not found for query"):
rerank_runner.fetch_multimodal_rerank(
query="missing-upload-id",

View File

@@ -3971,11 +3971,10 @@ class TestDatasetRetrievalAdditionalHelpers:
)
def test_get_metadata_filter_condition(self, retrieval: DatasetRetrieval) -> None:
db_query = Mock()
db_query.where.return_value = db_query
db_query.all.return_value = [SimpleNamespace(dataset_id="d1", id="doc-1")]
scalars_result = Mock()
scalars_result.all.return_value = [SimpleNamespace(dataset_id="d1", id="doc-1")]
with patch("core.rag.retrieval.dataset_retrieval.db.session.query", return_value=db_query):
with patch("core.rag.retrieval.dataset_retrieval.db.session.scalars", return_value=scalars_result):
mapping, condition = retrieval.get_metadata_filter_condition(
dataset_ids=["d1"],
query="python",
@@ -3991,7 +3990,7 @@ class TestDatasetRetrievalAdditionalHelpers:
automatic_filters = [{"condition": "contains", "metadata_name": "author", "value": "Alice"}]
with (
patch("core.rag.retrieval.dataset_retrieval.db.session.query", return_value=db_query),
patch("core.rag.retrieval.dataset_retrieval.db.session.scalars", return_value=scalars_result),
patch.object(retrieval, "_automatic_metadata_filter_func", return_value=automatic_filters),
):
mapping, condition = retrieval.get_metadata_filter_condition(
@@ -4012,7 +4011,7 @@ class TestDatasetRetrievalAdditionalHelpers:
logical_operator="and",
conditions=[AppCondition(name="author", comparison_operator="contains", value="{{name}}")],
)
with patch("core.rag.retrieval.dataset_retrieval.db.session.query", return_value=db_query):
with patch("core.rag.retrieval.dataset_retrieval.db.session.scalars", return_value=scalars_result):
mapping, condition = retrieval.get_metadata_filter_condition(
dataset_ids=["d1"],
query="python",
@@ -4027,7 +4026,7 @@ class TestDatasetRetrievalAdditionalHelpers:
assert condition is not None
assert condition.conditions[0].value == "Alice"
with patch("core.rag.retrieval.dataset_retrieval.db.session.query", return_value=db_query):
with patch("core.rag.retrieval.dataset_retrieval.db.session.scalars", return_value=scalars_result):
with pytest.raises(ValueError, match="Invalid metadata filtering mode"):
retrieval.get_metadata_filter_condition(
dataset_ids=["d1"],