Mirror of https://github.com/langgenius/dify.git, synced 2026-04-04 12:00:30 -04:00

refactor: core/rag docstore, datasource, embedding, rerank, retrieval (#34203)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: Asuka Minato <i@asukaminato.eu.org>
@@ -97,13 +97,13 @@ class Jieba(BaseKeyword):
         documents = []
-        segment_query_stmt = db.session.query(DocumentSegment).where(
+        segment_query_stmt = select(DocumentSegment).where(
             DocumentSegment.dataset_id == self.dataset.id, DocumentSegment.index_node_id.in_(sorted_chunk_indices)
         )
         if document_ids_filter:
             segment_query_stmt = segment_query_stmt.where(DocumentSegment.document_id.in_(document_ids_filter))
-        segments = db.session.execute(segment_query_stmt).scalars().all()
+        segments = db.session.scalars(segment_query_stmt).all()
         segment_map = {segment.index_node_id: segment for segment in segments}
         for chunk_index in sorted_chunk_indices:
             segment = segment_map.get(chunk_index)
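
Note for readers following the migration: the hunk above is the pattern repeated throughout this commit, a legacy Query built off the session becomes a standalone select() statement executed via Session.scalars(). A minimal, self-contained sketch of the before/after, using a toy model and in-memory SQLite rather than Dify's actual DocumentSegment:

# Minimal sketch of the query -> select migration; the Segment model and
# engine here are stand-ins, nothing from Dify is imported.
from sqlalchemy import create_engine, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class Segment(Base):
    __tablename__ = "segments"
    id: Mapped[int] = mapped_column(primary_key=True)
    dataset_id: Mapped[str]
    index_node_id: Mapped[str]


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add(Segment(dataset_id="ds-1", index_node_id="node-1"))
    session.commit()

    # Legacy 1.x style: a Query object built from the session.
    legacy = session.query(Segment).where(Segment.dataset_id == "ds-1").all()

    # 2.0 style: a select() statement, executed through Session.scalars().
    stmt = select(Segment).where(Segment.dataset_id == "ds-1")
    modern = session.scalars(stmt).all()

    assert [s.id for s in legacy] == [s.id for s in modern]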
@@ -432,10 +432,11 @@ class RetrievalService:
         # Batch query dataset documents
         dataset_documents = {
             doc.id: doc
-            for doc in db.session.query(DatasetDocument)
-            .where(DatasetDocument.id.in_(document_ids))
-            .options(load_only(DatasetDocument.id, DatasetDocument.doc_form, DatasetDocument.dataset_id))
-            .all()
+            for doc in db.session.scalars(
+                select(DatasetDocument)
+                .where(DatasetDocument.id.in_(document_ids))
+                .options(load_only(DatasetDocument.id, DatasetDocument.doc_form, DatasetDocument.dataset_id))
+            ).all()
         }

         valid_dataset_documents = {}
@@ -426,11 +426,10 @@ class TidbOnQdrantVectorFactory(AbstractVectorFactory):
                 TIDB_ON_QDRANT_API_KEY = f"{tidb_auth_binding.account}:{tidb_auth_binding.password}"

             else:
-                idle_tidb_auth_binding = (
-                    db.session.query(TidbAuthBinding)
-                    .where(TidbAuthBinding.active == False, TidbAuthBinding.status == "ACTIVE")
-                    .one_or_none()
+                idle_tidb_auth_binding = db.session.scalar(
+                    select(TidbAuthBinding)
+                    .where(TidbAuthBinding.active == False, TidbAuthBinding.status == "ACTIVE")
+                    .limit(1)
                 )
                 if idle_tidb_auth_binding:
                     idle_tidb_auth_binding.active = True
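
One behavioral nuance in this hunk worth flagging: Query.one_or_none() raises MultipleResultsFound when more than one row matches, while scalar(select(...).limit(1)) just returns the first match. A runnable sketch of the difference, with a toy Binding model standing in for TidbAuthBinding:

# Toy demonstration; not Dify code.
from sqlalchemy import create_engine, select
from sqlalchemy.exc import MultipleResultsFound
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class Binding(Base):
    __tablename__ = "bindings"
    id: Mapped[int] = mapped_column(primary_key=True)
    active: Mapped[bool] = mapped_column(default=False)


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add_all([Binding(), Binding()])  # two idle bindings
    session.commit()

    try:
        session.query(Binding).where(Binding.active == False).one_or_none()
    except MultipleResultsFound:
        print("legacy call raises when duplicates exist")

    # New form silently picks the first match instead.
    first = session.scalar(select(Binding).where(Binding.active == False).limit(1))
    print("new call returns:", first.id)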
@@ -277,7 +277,7 @@ class Vector:
         return self._vector_processor.search_by_vector(query_vector, **kwargs)

     def search_by_file(self, file_id: str, **kwargs: Any) -> list[Document]:
-        upload_file: UploadFile | None = db.session.query(UploadFile).where(UploadFile.id == file_id).first()
+        upload_file: UploadFile | None = db.session.get(UploadFile, file_id)

         if not upload_file:
             return []
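
Primary-key lookups like this one map naturally onto Session.get(), which also consults the identity map before emitting SQL. A minimal sketch with a toy UploadFile stand-in:

# Toy model; Dify's UploadFile has more columns.
from sqlalchemy import create_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class UploadFile(Base):
    __tablename__ = "upload_files"
    id: Mapped[str] = mapped_column(primary_key=True)
    key: Mapped[str]


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add(UploadFile(id="file-1", key="blob-key"))
    session.commit()

    # Legacy: a full query round-trip for a key lookup.
    legacy = session.query(UploadFile).where(UploadFile.id == "file-1").first()

    # 2.0 style: get() checks the identity map first and returns None if absent.
    modern = session.get(UploadFile, "file-1")
    assert modern is legacy
    assert session.get(UploadFile, "missing") is None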
@@ -4,7 +4,7 @@ from collections.abc import Sequence
 from typing import Any

 from graphon.model_runtime.entities.model_entities import ModelType
-from sqlalchemy import func, select
+from sqlalchemy import delete, func, select

 from core.model_manager import ModelManager
 from core.rag.index_processor.constant.index_type import IndexTechniqueType
@@ -63,10 +63,8 @@ class DatasetDocumentStore:
         return output

     def add_documents(self, docs: Sequence[Document], allow_update: bool = True, save_child: bool = False):
-        max_position = (
-            db.session.query(func.max(DocumentSegment.position))
-            .where(DocumentSegment.document_id == self._document_id)
-            .scalar()
+        max_position = db.session.scalar(
+            select(func.max(DocumentSegment.position)).where(DocumentSegment.document_id == self._document_id)
         )

         if max_position is None:
@@ -155,12 +153,14 @@ class DatasetDocumentStore:
                 )
                 if save_child and doc.children:
                     # delete the existing child chunks
-                    db.session.query(ChildChunk).where(
-                        ChildChunk.tenant_id == self._dataset.tenant_id,
-                        ChildChunk.dataset_id == self._dataset.id,
-                        ChildChunk.document_id == self._document_id,
-                        ChildChunk.segment_id == segment_document.id,
-                    ).delete()
+                    db.session.execute(
+                        delete(ChildChunk).where(
+                            ChildChunk.tenant_id == self._dataset.tenant_id,
+                            ChildChunk.dataset_id == self._dataset.id,
+                            ChildChunk.document_id == self._document_id,
+                            ChildChunk.segment_id == segment_document.id,
+                        )
+                    )
                     # add new child chunks
                     for position, child in enumerate(doc.children, start=1):
                         child_segment = ChildChunk(
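
This file exercises two of the statement forms the commit migrates to: a scalar aggregate via Session.scalar(select(func.max(...))) and a bulk DELETE via Session.execute(delete(...)). A self-contained sketch of both, with a toy Chunk model in place of ChildChunk/DocumentSegment:

from sqlalchemy import create_engine, delete, func, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class Chunk(Base):
    __tablename__ = "chunks"
    id: Mapped[int] = mapped_column(primary_key=True)
    document_id: Mapped[str]
    position: Mapped[int]


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add_all(Chunk(document_id="doc-1", position=p) for p in (1, 2, 3))
    session.commit()

    # Aggregate as a single scalar; None when no rows match.
    max_position = session.scalar(
        select(func.max(Chunk.position)).where(Chunk.document_id == "doc-1")
    )
    assert max_position == 3

    # Bulk delete expressed as a statement rather than a Query method chain.
    session.execute(delete(Chunk).where(Chunk.document_id == "doc-1"))
    session.commit()
    assert session.scalar(select(func.count()).select_from(Chunk)) == 0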
@@ -6,6 +6,7 @@ from typing import Any, cast
 import numpy as np
 from graphon.model_runtime.entities.model_entities import ModelPropertyKey
 from graphon.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
+from sqlalchemy import select
 from sqlalchemy.exc import IntegrityError

 from configs import dify_config
@@ -31,14 +32,14 @@ class CacheEmbedding(Embeddings):
         embedding_queue_indices = []
         for i, text in enumerate(texts):
             hash = helper.generate_text_hash(text)
-            embedding = (
-                db.session.query(Embedding)
-                .filter_by(
-                    model_name=self._model_instance.model_name,
-                    hash=hash,
-                    provider_name=self._model_instance.provider,
+            embedding = db.session.scalar(
+                select(Embedding)
+                .where(
+                    Embedding.model_name == self._model_instance.model_name,
+                    Embedding.hash == hash,
+                    Embedding.provider_name == self._model_instance.provider,
                 )
-                .first()
+                .limit(1)
             )
             if embedding:
                 text_embeddings[i] = embedding.get_embedding()
@@ -112,14 +113,14 @@ class CacheEmbedding(Embeddings):
         embedding_queue_indices = []
         for i, multimodel_document in enumerate(multimodel_documents):
             file_id = multimodel_document["file_id"]
-            embedding = (
-                db.session.query(Embedding)
-                .filter_by(
-                    model_name=self._model_instance.model_name,
-                    hash=file_id,
-                    provider_name=self._model_instance.provider,
+            embedding = db.session.scalar(
+                select(Embedding)
+                .where(
+                    Embedding.model_name == self._model_instance.model_name,
+                    Embedding.hash == file_id,
+                    Embedding.provider_name == self._model_instance.provider,
                 )
-                .first()
+                .limit(1)
             )
             if embedding:
                 multimodel_embeddings[i] = embedding.get_embedding()
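
Both hunks above also translate filter_by() keyword-equality filters into explicit column comparisons on the select(). A small sketch of that translation, with a toy Embedding model and made-up hash values:

from sqlalchemy import create_engine, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class Embedding(Base):
    __tablename__ = "embeddings"
    id: Mapped[int] = mapped_column(primary_key=True)
    model_name: Mapped[str]
    hash: Mapped[str]


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add(Embedding(model_name="text-embedding-3-small", hash="abc123"))
    session.commit()

    # Legacy: implicit equality via keyword arguments.
    legacy = session.query(Embedding).filter_by(hash="abc123").first()

    # 2.0 style: explicit comparisons; limit(1) mirrors .first()'s row cap.
    modern = session.scalar(
        select(Embedding).where(Embedding.hash == "abc123").limit(1)
    )
    assert modern is legacy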
@@ -4,6 +4,7 @@ import operator
 from typing import Any, cast

 import httpx
+from sqlalchemy import update

 from configs import dify_config
 from core.rag.extractor.extractor_base import BaseExtractor
@@ -346,9 +347,11 @@ class NotionExtractor(BaseExtractor):
         if data_source_info:
             data_source_info["last_edited_time"] = last_edited_time

-            db.session.query(DocumentModel).filter_by(id=document_model.id).update(
-                {DocumentModel.data_source_info: json.dumps(data_source_info)}
-            )  # type: ignore
+            db.session.execute(
+                update(DocumentModel)
+                .where(DocumentModel.id == document_model.id)
+                .values(data_source_info=json.dumps(data_source_info))
+            )
             db.session.commit()

     def get_notion_last_edited_time(self) -> str:
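
The Query.update() call becomes a 2.0-style update() statement: criteria go in .where(), new values in .values(), and the # type: ignore goes away. A runnable sketch with a toy Doc model standing in for DocumentModel:

import json

from sqlalchemy import create_engine, select, update
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class Doc(Base):
    __tablename__ = "docs"
    id: Mapped[str] = mapped_column(primary_key=True)
    data_source_info: Mapped[str]


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add(Doc(id="doc-1", data_source_info="{}"))
    session.commit()

    # Statement form of the update in this hunk.
    session.execute(
        update(Doc)
        .where(Doc.id == "doc-1")
        .values(data_source_info=json.dumps({"last_edited_time": "2026-01-01"}))
    )
    session.commit()

    assert "last_edited_time" in session.scalars(select(Doc)).one().data_source_info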
@@ -11,6 +11,7 @@ from typing import TYPE_CHECKING, Any, NotRequired, Optional
 from urllib.parse import unquote, urlparse

 import httpx
+from sqlalchemy import select
 from typing_extensions import TypedDict

 from configs import dify_config
@@ -200,7 +201,7 @@ class BaseIndexProcessor(ABC):

         # Get unique IDs for database query
         unique_upload_file_ids = list(set(upload_file_id_list))
-        upload_files = db.session.query(UploadFile).where(UploadFile.id.in_(unique_upload_file_ids)).all()
+        upload_files = db.session.scalars(select(UploadFile).where(UploadFile.id.in_(unique_upload_file_ids))).all()

         # Create a mapping from ID to UploadFile for quick lookup
         upload_file_map = {upload_file.id: upload_file for upload_file in upload_files}
@@ -312,7 +313,7 @@ class BaseIndexProcessor(ABC):
         """
         from services.file_service import FileService

-        tool_file = db.session.query(ToolFile).where(ToolFile.id == tool_file_id).first()
+        tool_file = db.session.get(ToolFile, tool_file_id)
         if not tool_file:
             return None
         blob = storage.load_once(tool_file.file_key)
@@ -18,6 +18,7 @@ from graphon.model_runtime.entities.message_entities import (
     UserPromptMessage,
 )
 from graphon.model_runtime.entities.model_entities import ModelFeature, ModelType
+from sqlalchemy import select

 from core.app.file_access import DatabaseFileAccessController
 from core.app.llm import deduct_llm_quota
@@ -145,14 +146,12 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
         if delete_summaries:
             if node_ids:
                 # Find segments by index_node_id
-                segments = (
-                    db.session.query(DocumentSegment)
-                    .filter(
+                segments = db.session.scalars(
+                    select(DocumentSegment).where(
                         DocumentSegment.dataset_id == dataset.id,
                         DocumentSegment.index_node_id.in_(node_ids),
                     )
-                    .all()
-                )
+                ).all()
                 segment_ids = [segment.id for segment in segments]
                 if segment_ids:
                     SummaryIndexService.delete_summaries_for_segments(dataset, segment_ids)
@@ -537,11 +536,9 @@ class ParagraphIndexProcessor(BaseIndexProcessor):

         # Get unique IDs for database query
         unique_upload_file_ids = list(set(upload_file_id_list))
-        upload_files = (
-            db.session.query(UploadFile)
-            .where(UploadFile.id.in_(unique_upload_file_ids), UploadFile.tenant_id == tenant_id)
-            .all()
-        )
+        upload_files = db.session.scalars(
+            select(UploadFile).where(UploadFile.id.in_(unique_upload_file_ids), UploadFile.tenant_id == tenant_id)
+        ).all()

         # Create File objects from UploadFile records
         file_objects = []
@@ -6,6 +6,8 @@ import uuid
 from collections.abc import Mapping
 from typing import Any

+from sqlalchemy import delete, select
+
 from configs import dify_config
 from core.db.session_factory import session_factory
 from core.entities.knowledge_entities import PreviewDetail
@@ -177,17 +179,16 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
                 child_node_ids = precomputed_child_node_ids
             else:
                 # Fallback to original query (may fail if segments are already deleted)
-                child_node_ids = (
-                    db.session.query(ChildChunk.index_node_id)
+                rows = db.session.execute(
+                    select(ChildChunk.index_node_id)
                     .join(DocumentSegment, ChildChunk.segment_id == DocumentSegment.id)
                     .where(
                         DocumentSegment.dataset_id == dataset.id,
                         DocumentSegment.index_node_id.in_(node_ids),
                         ChildChunk.dataset_id == dataset.id,
                     )
-                    .all()
-                )
-                child_node_ids = [child_node_id[0] for child_node_id in child_node_ids if child_node_id[0]]
+                ).all()
+                child_node_ids = [row[0] for row in rows if row[0]]

             # Delete from vector index
             if child_node_ids:
@@ -195,18 +196,22 @@ class ParentChildIndexProcessor(BaseIndexProcessor):

             # Delete from database
             if delete_child_chunks and child_node_ids:
-                db.session.query(ChildChunk).where(
-                    ChildChunk.dataset_id == dataset.id, ChildChunk.index_node_id.in_(child_node_ids)
-                ).delete(synchronize_session=False)
+                db.session.execute(
+                    delete(ChildChunk).where(
+                        ChildChunk.dataset_id == dataset.id, ChildChunk.index_node_id.in_(child_node_ids)
+                    )
+                )
                 db.session.commit()
         else:
             vector.delete()

             if delete_child_chunks:
                 # Use existing compound index: (tenant_id, dataset_id, ...)
-                db.session.query(ChildChunk).where(
-                    ChildChunk.tenant_id == dataset.tenant_id, ChildChunk.dataset_id == dataset.id
-                ).delete(synchronize_session=False)
+                db.session.execute(
+                    delete(ChildChunk).where(
+                        ChildChunk.tenant_id == dataset.tenant_id, ChildChunk.dataset_id == dataset.id
+                    )
+                )
                 db.session.commit()

     def retrieve(
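
In the fallback branch above the new code keeps execute() rather than scalars(), so results arrive as Row tuples (hence row[0]). A sketch of the difference for a single-column select, with toy models in place of ChildChunk/DocumentSegment:

from sqlalchemy import create_engine, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class Child(Base):
    __tablename__ = "children"
    id: Mapped[int] = mapped_column(primary_key=True)
    index_node_id: Mapped[str | None]


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add_all([Child(index_node_id="child-1"), Child(index_node_id=None)])
    session.commit()

    # execute() yields Row tuples, as in the diff ...
    rows = session.execute(select(Child.index_node_id)).all()
    ids = [row[0] for row in rows if row[0]]

    # ... while scalars() would yield the bare column values directly.
    same = [v for v in session.scalars(select(Child.index_node_id)) if v]
    assert ids == same == ["child-1"]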
@@ -134,9 +134,7 @@ class RerankModelRunner(BaseRerankRunner):
             ):
                 if document.metadata.get("doc_type") == DocType.IMAGE:
                     # Query file info within db.session context to ensure thread-safe access
-                    upload_file = (
-                        db.session.query(UploadFile).where(UploadFile.id == document.metadata["doc_id"]).first()
-                    )
+                    upload_file = db.session.get(UploadFile, document.metadata["doc_id"])
                     if upload_file:
                         blob = storage.load_once(upload_file.key)
                         document_file_base64 = base64.b64encode(blob).decode()
@@ -169,7 +167,7 @@ class RerankModelRunner(BaseRerankRunner):
             return rerank_result, unique_documents
         elif query_type == QueryType.IMAGE_QUERY:
             # Query file info within db.session context to ensure thread-safe access
-            upload_file = db.session.query(UploadFile).where(UploadFile.id == query).first()
+            upload_file = db.session.get(UploadFile, query)
             if upload_file:
                 blob = storage.load_once(upload_file.key)
                 file_query = base64.b64encode(blob).decode()
@@ -1340,7 +1340,7 @@ class DatasetRetrieval:
         metadata_filtering_conditions: MetadataFilteringCondition | None,
         inputs: dict,
     ) -> tuple[dict[str, list[str]] | None, MetadataCondition | None]:
-        document_query = db.session.query(DatasetDocument).where(
+        document_query = select(DatasetDocument).where(
             DatasetDocument.dataset_id.in_(dataset_ids),
             DatasetDocument.indexing_status == "completed",
             DatasetDocument.enabled == True,
@@ -1411,7 +1411,7 @@ class DatasetRetrieval:
                 document_query = document_query.where(and_(*filters))
             else:
                 document_query = document_query.where(or_(*filters))
-        documents = document_query.all()
+        documents = db.session.scalars(document_query).all()
         # group by dataset_id
         metadata_filter_document_ids = defaultdict(list) if documents else None  # type: ignore
        for document in documents:
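
These two hunks show a nice property of the statement API: a select() is built up conditionally across many lines of code (each .where() returns a new immutable statement) and only hits the database at the final scalars() call. A toy sketch of the same shape:

from sqlalchemy import and_, create_engine, or_, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class Doc(Base):
    __tablename__ = "docs"
    id: Mapped[int] = mapped_column(primary_key=True)
    dataset_id: Mapped[str]
    enabled: Mapped[bool]


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add_all([Doc(dataset_id="d1", enabled=True), Doc(dataset_id="d2", enabled=True)])
    session.commit()

    # Build the statement incrementally; nothing executes yet.
    stmt = select(Doc).where(Doc.enabled == True)
    filters = [Doc.dataset_id == "d1", Doc.dataset_id == "d2"]
    use_and = False
    stmt = stmt.where(and_(*filters)) if use_and else stmt.where(or_(*filters))

    # Execution happens once, at the end.
    docs = session.scalars(stmt).all()
    assert len(docs) == 2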
@@ -201,27 +201,23 @@ def test_search_returns_documents_in_rank_order_and_applies_filter(monkeypatch,
     document_id = _Field("document_id")

     keyword = Jieba(_dataset(_dataset_keyword_table()))
-    query_stmt = _FakeQuery()
-    patched_runtime.session.query.return_value = query_stmt
-    patched_runtime.session.execute.return_value = _FakeExecuteResult(
-        [
-            SimpleNamespace(
-                index_node_id="node-2",
-                content="segment-content",
-                index_node_hash="hash-2",
-                document_id="doc-2",
-                dataset_id="dataset-1",
-            )
-        ]
-    )
+    patched_runtime.session.scalars.return_value.all.return_value = [
+        SimpleNamespace(
+            index_node_id="node-2",
+            content="segment-content",
+            index_node_hash="hash-2",
+            document_id="doc-2",
+            dataset_id="dataset-1",
+        )
+    ]

     monkeypatch.setattr(jieba_module, "DocumentSegment", _FakeDocumentSegment)
+    monkeypatch.setattr(jieba_module, "select", lambda *_: _FakeSelect())
     monkeypatch.setattr(keyword, "_get_dataset_keyword_table", MagicMock(return_value={"k": {"node-1", "node-2"}}))
     monkeypatch.setattr(keyword, "_retrieve_ids_by_query", MagicMock(return_value=["node-1", "node-2"]))

     documents = keyword.search("query", top_k=2, document_ids_filter=["doc-2"])

-    assert len(query_stmt.where_calls) == 2
     assert len(documents) == 1
     assert documents[0].page_content == "segment-content"
     assert documents[0].metadata["doc_id"] == "node-2"
@@ -714,13 +714,13 @@ class TestRetrievalServiceInternals:
             dataset_id="dataset-id",
         )

-        dataset_query = Mock()
-        dataset_query.where.return_value.options.return_value.all.return_value = [
+        scalars_result = Mock()
+        scalars_result.all.return_value = [
             dataset_doc_parent,
             dataset_doc_text,
             dataset_doc_parent_summary,
         ]
-        monkeypatch.setattr(retrieval_service_module.db.session, "query", Mock(return_value=dataset_query))
+        monkeypatch.setattr(retrieval_service_module.db.session, "scalars", Mock(return_value=scalars_result))
         monkeypatch.setattr(retrieval_service_module, "RetrievalChildChunk", _SimpleRetrievalChildChunk)
         monkeypatch.setattr(retrieval_service_module, "RetrievalSegments", _SimpleRetrievalSegment)
@@ -882,7 +882,7 @@ class TestRetrievalServiceInternals:
     def test_format_retrieval_documents_rolls_back_and_raises_when_db_fails(self, monkeypatch):
         rollback = Mock()
         monkeypatch.setattr(retrieval_service_module.db.session, "rollback", rollback)
-        monkeypatch.setattr(retrieval_service_module.db.session, "query", Mock(side_effect=RuntimeError("db error")))
+        monkeypatch.setattr(retrieval_service_module.db.session, "scalars", Mock(side_effect=RuntimeError("db error")))

         documents = [Document(page_content="content", metadata={"document_id": "doc-1"}, provider="dify")]
@@ -340,15 +340,13 @@ def test_search_by_file_handles_missing_and_existing_upload(vector_factory_module, monkeypatch):
     vector._embeddings = MagicMock()
     vector._vector_processor = MagicMock()

+    mock_session = SimpleNamespace(get=lambda _model, _id: None)
     monkeypatch.setattr(vector_factory_module, "UploadFile", SimpleNamespace(id=_Field()))
-    monkeypatch.setattr(
-        vector_factory_module, "db", SimpleNamespace(session=SimpleNamespace(query=lambda _model: upload_query))
-    )
+    monkeypatch.setattr(vector_factory_module, "db", SimpleNamespace(session=mock_session))

-    upload_query.first.return_value = None
     assert vector.search_by_file("file-1") == []

-    upload_query.first.return_value = SimpleNamespace(key="blob-key")
+    mock_session.get = lambda _model, _id: SimpleNamespace(key="blob-key")
     monkeypatch.setattr(vector_factory_module.storage, "load_once", MagicMock(return_value=b"file-bytes"))
     vector._embeddings.embed_multimodal_query.return_value = [0.3, 0.4]
     vector._vector_processor.search_by_vector.return_value = ["hit"]
@@ -167,7 +167,7 @@ class TestDatasetDocumentStoreAddDocuments:
         ):
             mock_session = MagicMock()
             mock_db.session = mock_session
-            mock_db.session.query.return_value.where.return_value.scalar.return_value = None
+            mock_db.session.scalar.return_value = None

             mock_manager = MagicMock()
             mock_manager.get_model_instance.return_value = mock_model_instance
@@ -211,7 +211,7 @@ class TestDatasetDocumentStoreAddDocuments:
         with patch("core.rag.docstore.dataset_docstore.db") as mock_db:
             mock_session = MagicMock()
             mock_db.session = mock_session
-            mock_db.session.query.return_value.where.return_value.scalar.return_value = 5
+            mock_db.session.scalar.return_value = 5

             with patch.object(DatasetDocumentStore, "get_document_segment", return_value=mock_existing_segment):
                 with patch.object(DatasetDocumentStore, "add_multimodel_documents_binding"):
@@ -276,7 +276,7 @@ class TestDatasetDocumentStoreAddDocuments:
         with patch("core.rag.docstore.dataset_docstore.db") as mock_db:
             mock_session = MagicMock()
             mock_db.session = mock_session
-            mock_db.session.query.return_value.where.return_value.scalar.return_value = None
+            mock_db.session.scalar.return_value = None

             with patch.object(DatasetDocumentStore, "get_document_segment", return_value=None):
                 with patch.object(DatasetDocumentStore, "add_multimodel_documents_binding"):
@@ -353,7 +353,7 @@ class TestDatasetDocumentStoreAddDocuments:
         with patch("core.rag.docstore.dataset_docstore.db") as mock_db:
             mock_session = MagicMock()
             mock_db.session = mock_session
-            mock_db.session.query.return_value.where.return_value.scalar.return_value = None
+            mock_db.session.scalar.return_value = None

             with patch.object(DatasetDocumentStore, "get_document_segment", return_value=None):
                 with patch.object(DatasetDocumentStore, "add_multimodel_documents_binding"):
@@ -755,7 +755,7 @@ class TestDatasetDocumentStoreAddDocumentsUpdateChild:
         with patch("core.rag.docstore.dataset_docstore.db") as mock_db:
             mock_session = MagicMock()
             mock_db.session = mock_session
-            mock_db.session.query.return_value.where.return_value.scalar.return_value = 5
+            mock_db.session.scalar.return_value = 5

             with patch.object(DatasetDocumentStore, "get_document_segment", return_value=mock_existing_segment):
                 with patch.object(DatasetDocumentStore, "add_multimodel_documents_binding"):
@@ -767,7 +767,7 @@ class TestDatasetDocumentStoreAddDocumentsUpdateChild:

                     store.add_documents([mock_doc], save_child=True)

-                    mock_db.session.query.return_value.where.return_value.delete.assert_called()
+                    mock_db.session.execute.assert_called()
                     mock_db.session.commit.assert_called()


@@ -798,7 +798,7 @@ class TestDatasetDocumentStoreAddDocumentsUpdateAnswer:
         with patch("core.rag.docstore.dataset_docstore.db") as mock_db:
             mock_session = MagicMock()
             mock_db.session = mock_session
-            mock_db.session.query.return_value.where.return_value.scalar.return_value = 5
+            mock_db.session.scalar.return_value = 5

             with patch.object(DatasetDocumentStore, "get_document_segment", return_value=mock_existing_segment):
                 with patch.object(DatasetDocumentStore, "add_multimodel_documents_binding"):
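
A side effect of the migration visible in all of these test hunks: the 2.0 API leaves a single session entry point (scalar / scalars / execute / get) to stub, instead of a multi-hop query().where().scalar() chain. A minimal illustration with unittest.mock only (no Dify imports):

from unittest.mock import MagicMock

mock_session = MagicMock()

# Legacy chain: every hop of the fluent API had to be wired up.
mock_session.query.return_value.where.return_value.scalar.return_value = 5

# 2.0 style: one attribute carries the result.
mock_session.scalar.return_value = 5

assert mock_session.scalar("any-stmt") == 5
mock_session.execute("any-stmt")  # e.g. the bulk-delete path
mock_session.execute.assert_called()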
@@ -69,7 +69,7 @@ class TestCacheEmbeddingMultimodalDocuments:
         documents = [{"file_id": "file123", "content": "test content"}]

         with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
-            mock_session.query.return_value.filter_by.return_value.first.return_value = None
+            mock_session.scalar.return_value = None
             mock_model_instance.invoke_multimodal_embedding.return_value = sample_multimodal_result

             result = cache_embedding.embed_multimodal_documents(documents)
@@ -114,7 +114,7 @@ class TestCacheEmbeddingMultimodalDocuments:
         )

         with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
-            mock_session.query.return_value.filter_by.return_value.first.return_value = None
+            mock_session.scalar.return_value = None
             mock_model_instance.invoke_multimodal_embedding.return_value = embedding_result

             result = cache_embedding.embed_multimodal_documents(documents)
@@ -134,7 +134,7 @@ class TestCacheEmbeddingMultimodalDocuments:
         mock_cached_embedding.get_embedding.return_value = normalized_cached

         with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
-            mock_session.query.return_value.filter_by.return_value.first.return_value = mock_cached_embedding
+            mock_session.scalar.return_value = mock_cached_embedding

             result = cache_embedding.embed_multimodal_documents(documents)
@@ -180,18 +180,7 @@ class TestCacheEmbeddingMultimodalDocuments:
         )

         with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
-            call_count = [0]
-
-            def mock_filter_by(**kwargs):
-                call_count[0] += 1
-                mock_query = Mock()
-                if call_count[0] == 1:
-                    mock_query.first.return_value = mock_cached_embedding
-                else:
-                    mock_query.first.return_value = None
-                return mock_query
-
-            mock_session.query.return_value.filter_by = mock_filter_by
+            mock_session.scalar.side_effect = [mock_cached_embedding, None, None]
             mock_model_instance.invoke_multimodal_embedding.return_value = embedding_result

             result = cache_embedding.embed_multimodal_documents(documents)
@@ -224,7 +213,7 @@ class TestCacheEmbeddingMultimodalDocuments:
         )

         with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
-            mock_session.query.return_value.filter_by.return_value.first.return_value = None
+            mock_session.scalar.return_value = None
             mock_model_instance.invoke_multimodal_embedding.return_value = embedding_result

             with patch("core.rag.embedding.cached_embedding.logger") as mock_logger:
@@ -265,7 +254,7 @@ class TestCacheEmbeddingMultimodalDocuments:
         )

         with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
-            mock_session.query.return_value.filter_by.return_value.first.return_value = None
+            mock_session.scalar.return_value = None

             batch_results = [create_batch_result(10), create_batch_result(10), create_batch_result(5)]
             mock_model_instance.invoke_multimodal_embedding.side_effect = batch_results
@@ -281,7 +270,7 @@ class TestCacheEmbeddingMultimodalDocuments:
         documents = [{"file_id": "file123"}]

         with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
-            mock_session.query.return_value.filter_by.return_value.first.return_value = None
+            mock_session.scalar.return_value = None
             mock_model_instance.invoke_multimodal_embedding.side_effect = Exception("API Error")

             with pytest.raises(Exception) as exc_info:
@@ -298,7 +287,7 @@ class TestCacheEmbeddingMultimodalDocuments:
         documents = [{"file_id": "file123"}]

         with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
-            mock_session.query.return_value.filter_by.return_value.first.return_value = None
+            mock_session.scalar.return_value = None
             mock_model_instance.invoke_multimodal_embedding.return_value = sample_multimodal_result

             mock_session.commit.side_effect = IntegrityError("Duplicate key", None, None)
@@ -139,7 +139,7 @@ class TestCacheEmbeddingDocuments:

         # Mock database query to return no cached embedding (cache miss)
         with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
-            mock_session.query.return_value.filter_by.return_value.first.return_value = None
+            mock_session.scalar.return_value = None

             # Mock model invocation
             mock_model_instance.invoke_text_embedding.return_value = sample_embedding_result
@@ -203,7 +203,7 @@ class TestCacheEmbeddingDocuments:
         )

         with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
-            mock_session.query.return_value.filter_by.return_value.first.return_value = None
+            mock_session.scalar.return_value = None
             mock_model_instance.invoke_text_embedding.return_value = embedding_result

             # Act
@@ -240,7 +240,7 @@ class TestCacheEmbeddingDocuments:

         with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
             # Mock database to return cached embedding (cache hit)
-            mock_session.query.return_value.filter_by.return_value.first.return_value = mock_cached_embedding
+            mock_session.scalar.return_value = mock_cached_embedding

             # Act
             result = cache_embedding.embed_documents(texts)
@@ -313,19 +313,7 @@ class TestCacheEmbeddingDocuments:
         mock_hash.side_effect = generate_hash

         # Mock database to return cached embedding only for first text (hash_1)
-        call_count = [0]
-
-        def mock_filter_by(**kwargs):
-            call_count[0] += 1
-            mock_query = Mock()
-            # First call (hash_1) returns cached, others return None
-            if call_count[0] == 1:
-                mock_query.first.return_value = mock_cached_embedding
-            else:
-                mock_query.first.return_value = None
-            return mock_query
-
-        mock_session.query.return_value.filter_by = mock_filter_by
+        mock_session.scalar.side_effect = [mock_cached_embedding, None, None]
         mock_model_instance.invoke_text_embedding.return_value = embedding_result

         # Act
@@ -392,7 +380,7 @@ class TestCacheEmbeddingDocuments:
         )

         with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
-            mock_session.query.return_value.filter_by.return_value.first.return_value = None
+            mock_session.scalar.return_value = None

             # Mock model to return appropriate batch results
             batch_results = [
@@ -455,7 +443,7 @@ class TestCacheEmbeddingDocuments:
         )

         with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
-            mock_session.query.return_value.filter_by.return_value.first.return_value = None
+            mock_session.scalar.return_value = None
             mock_model_instance.invoke_text_embedding.return_value = embedding_result

             with patch("core.rag.embedding.cached_embedding.logger") as mock_logger:
@@ -489,7 +477,7 @@ class TestCacheEmbeddingDocuments:
         texts = ["Test text"]

         with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
-            mock_session.query.return_value.filter_by.return_value.first.return_value = None
+            mock_session.scalar.return_value = None

             # Mock model to raise connection error
             mock_model_instance.invoke_text_embedding.side_effect = InvokeConnectionError("Failed to connect to API")
@@ -515,7 +503,7 @@ class TestCacheEmbeddingDocuments:
         texts = ["Test text"]

         with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
-            mock_session.query.return_value.filter_by.return_value.first.return_value = None
+            mock_session.scalar.return_value = None

             # Mock model to raise rate limit error
             mock_model_instance.invoke_text_embedding.side_effect = InvokeRateLimitError("Rate limit exceeded")
@@ -539,7 +527,7 @@ class TestCacheEmbeddingDocuments:
         texts = ["Test text"]

         with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
-            mock_session.query.return_value.filter_by.return_value.first.return_value = None
+            mock_session.scalar.return_value = None

             # Mock model to raise authorization error
             mock_model_instance.invoke_text_embedding.side_effect = InvokeAuthorizationError("Invalid API key")
@@ -564,7 +552,7 @@ class TestCacheEmbeddingDocuments:
         texts = ["Test text"]

         with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
-            mock_session.query.return_value.filter_by.return_value.first.return_value = None
+            mock_session.scalar.return_value = None
            mock_model_instance.invoke_text_embedding.return_value = sample_embedding_result

             # Mock database commit to raise IntegrityError
@@ -884,7 +872,7 @@ class TestEmbeddingModelSwitching:
         )

         with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
-            mock_session.query.return_value.filter_by.return_value.first.return_value = None
+            mock_session.scalar.return_value = None

             model_instance_ada.invoke_text_embedding.return_value = result_ada
             model_instance_3_small.invoke_text_embedding.return_value = result_3_small
@@ -1047,7 +1035,7 @@ class TestEmbeddingDimensionValidation:
         )

         with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
-            mock_session.query.return_value.filter_by.return_value.first.return_value = None
+            mock_session.scalar.return_value = None
             mock_model_instance.invoke_text_embedding.return_value = embedding_result

             # Act
@@ -1100,7 +1088,7 @@ class TestEmbeddingDimensionValidation:
         )

         with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
-            mock_session.query.return_value.filter_by.return_value.first.return_value = None
+            mock_session.scalar.return_value = None
             mock_model_instance.invoke_text_embedding.return_value = embedding_result

             # Act
@@ -1186,7 +1174,7 @@ class TestEmbeddingDimensionValidation:
         )

         with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
-            mock_session.query.return_value.filter_by.return_value.first.return_value = None
+            mock_session.scalar.return_value = None

             model_instance_ada.invoke_text_embedding.return_value = result_ada
             model_instance_cohere.invoke_text_embedding.return_value = result_cohere
@@ -1284,7 +1272,7 @@ class TestEmbeddingEdgeCases:
         )

         with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
-            mock_session.query.return_value.filter_by.return_value.first.return_value = None
+            mock_session.scalar.return_value = None
             mock_model_instance.invoke_text_embedding.return_value = embedding_result

             # Act
@@ -1327,7 +1315,7 @@ class TestEmbeddingEdgeCases:
         )

         with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
-            mock_session.query.return_value.filter_by.return_value.first.return_value = None
+            mock_session.scalar.return_value = None
             mock_model_instance.invoke_text_embedding.return_value = embedding_result

             # Act
@@ -1375,7 +1363,7 @@ class TestEmbeddingEdgeCases:
         )

         with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
-            mock_session.query.return_value.filter_by.return_value.first.return_value = None
+            mock_session.scalar.return_value = None
             mock_model_instance.invoke_text_embedding.return_value = embedding_result

             # Act
@@ -1427,7 +1415,7 @@ class TestEmbeddingEdgeCases:
         )

         with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
-            mock_session.query.return_value.filter_by.return_value.first.return_value = None
+            mock_session.scalar.return_value = None
             mock_model_instance.invoke_text_embedding.return_value = embedding_result

             # Act
@@ -1483,7 +1471,7 @@ class TestEmbeddingEdgeCases:
        )

         with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
-            mock_session.query.return_value.filter_by.return_value.first.return_value = None
+            mock_session.scalar.return_value = None
             mock_model_instance.invoke_text_embedding.return_value = embedding_result

             # Act
@@ -1551,7 +1539,7 @@ class TestEmbeddingEdgeCases:
         )

         with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
-            mock_session.query.return_value.filter_by.return_value.first.return_value = None
+            mock_session.scalar.return_value = None
             mock_model_instance.invoke_text_embedding.return_value = embedding_result

             # Act
@@ -1649,7 +1637,7 @@ class TestEmbeddingEdgeCases:
         )

         with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
-            mock_session.query.return_value.filter_by.return_value.first.return_value = None
+            mock_session.scalar.return_value = None
             mock_model_instance.invoke_text_embedding.return_value = embedding_result

             # Act
@@ -1728,7 +1716,7 @@ class TestEmbeddingCachePerformance:

         with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
             # First call: cache miss
-            mock_session.query.return_value.filter_by.return_value.first.return_value = None
+            mock_session.scalar.return_value = None

             usage = EmbeddingUsage(
                 tokens=5,
@@ -1756,7 +1744,7 @@ class TestEmbeddingCachePerformance:
             assert len(result1) == 1

             # Arrange - Second call: cache hit
-            mock_session.query.return_value.filter_by.return_value.first.return_value = mock_cached_embedding
+            mock_session.scalar.return_value = mock_cached_embedding

             # Act - Second call (cache hit)
             result2 = cache_embedding.embed_documents([text])
@@ -1816,7 +1804,7 @@ class TestEmbeddingCachePerformance:
         )

         with patch("core.rag.embedding.cached_embedding.db.session") as mock_session:
-            mock_session.query.return_value.filter_by.return_value.first.return_value = None
+            mock_session.scalar.return_value = None

             # Mock model to return appropriate batch results
             batch_results = [
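
The two large hunks in this file (@@ -180,18 and @@ -313,19) replace a stateful fake filter_by() function with Mock.side_effect, which replays one result per session.scalar() call. A minimal illustration, pure unittest.mock, with an arbitrary sentinel for the cached value:

from unittest.mock import MagicMock

mock_session = MagicMock()
cached = object()

# First lookup hits the cache, the next two miss.
mock_session.scalar.side_effect = [cached, None, None]

assert mock_session.scalar("stmt-1") is cached
assert mock_session.scalar("stmt-2") is None
assert mock_session.scalar("stmt-3") is None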
@@ -405,35 +405,36 @@ class TestNotionMetadataAndCredentialMethods:

         class FakeDocumentModel:
             data_source_info = "data_source_info"
             id = "id"

-        update_calls = []
+        execute_calls = []

-        class FakeQuery:
-            def filter_by(self, **kwargs):
+        class FakeUpdateStmt:
+            def where(self, *args):
                 return self

-            def update(self, payload):
-                update_calls.append(payload)
+            def values(self, **kwargs):
+                return self

         class FakeSession:
             committed = False

-            def query(self, model):
-                assert model is FakeDocumentModel
-                return FakeQuery()
+            def execute(self, stmt):
+                execute_calls.append(stmt)

             def commit(self):
                 self.committed = True

         fake_db = SimpleNamespace(session=FakeSession())
         monkeypatch.setattr(notion_extractor, "DocumentModel", FakeDocumentModel)
+        monkeypatch.setattr(notion_extractor, "update", lambda model: FakeUpdateStmt())
         monkeypatch.setattr(notion_extractor, "db", fake_db)
         monkeypatch.setattr(extractor, "get_notion_last_edited_time", lambda: "2026-01-01T00:00:00.000Z")

         doc_model = SimpleNamespace(id="doc-1", data_source_info_dict={"source": "notion"})
         extractor.update_last_edited_time(doc_model)

-        assert update_calls
+        assert execute_calls
         assert fake_db.session.committed is True

     def test_get_notion_last_edited_time_uses_page_and_database_urls(self, mocker: MockerFixture):
@@ -188,10 +188,10 @@ class TestParagraphIndexProcessor:
         mock_keyword_cls.return_value.add_texts.assert_called_once_with(docs)

     def test_clean_deletes_summaries_and_vector(self, processor: ParagraphIndexProcessor, dataset: Mock) -> None:
-        segment_query = Mock()
-        segment_query.filter.return_value.all.return_value = [SimpleNamespace(id="seg-1")]
+        scalars_result = Mock()
+        scalars_result.all.return_value = [SimpleNamespace(id="seg-1")]
         session = Mock()
-        session.query.return_value = segment_query
+        session.scalars.return_value = scalars_result

         with (
             patch("core.rag.index_processor.processor.paragraph_index_processor.db.session", session),
@@ -531,10 +531,10 @@ class TestParagraphIndexProcessor:
             size=1,
             key="key",
         )
-        query = Mock()
-        query.where.return_value.all.return_value = [image_upload, non_image_upload]
+        scalars_result = Mock()
+        scalars_result.all.return_value = [image_upload, non_image_upload]
         session = Mock()
-        session.query.return_value = query
+        session.scalars.return_value = scalars_result

         with (
             patch("core.rag.index_processor.processor.paragraph_index_processor.db.session", session),
@@ -565,10 +565,10 @@ class TestParagraphIndexProcessor:
             size=1,
             key="key",
         )
-        query = Mock()
-        query.where.return_value.all.return_value = [image_upload]
+        scalars_result = Mock()
+        scalars_result.all.return_value = [image_upload]
         session = Mock()
-        session.query.return_value = query
+        session.scalars.return_value = scalars_result

         with (
             patch("core.rag.index_processor.processor.paragraph_index_processor.db.session", session),
@@ -208,11 +208,7 @@ class TestParentChildIndexProcessor:
         vector.create_multimodal.assert_called_once_with(multimodal_docs)

     def test_clean_with_precomputed_child_ids(self, processor: ParentChildIndexProcessor, dataset: Mock) -> None:
-        delete_query = Mock()
-        where_query = Mock()
-        where_query.delete.return_value = 2
         session = Mock()
-        session.query.return_value.where.return_value = where_query

         with (
             patch("core.rag.index_processor.processor.parent_child_index_processor.Vector") as mock_vector_cls,
@@ -227,16 +223,16 @@ class TestParentChildIndexProcessor:
         )

         vector.delete_by_ids.assert_called_once_with(["child-1", "child-2"])
-        where_query.delete.assert_called_once_with(synchronize_session=False)
+        session.execute.assert_called()
         session.commit.assert_called_once()

     def test_clean_queries_child_ids_when_not_precomputed(
         self, processor: ParentChildIndexProcessor, dataset: Mock
     ) -> None:
-        child_query = Mock()
-        child_query.join.return_value.where.return_value.all.return_value = [("child-1",), (None,), ("child-2",)]
+        execute_result = Mock()
+        execute_result.all.return_value = [("child-1",), (None,), ("child-2",)]
         session = Mock()
-        session.query.return_value = child_query
+        session.execute.return_value = execute_result

         with (
             patch("core.rag.index_processor.processor.parent_child_index_processor.Vector") as mock_vector_cls,
@@ -248,10 +244,7 @@ class TestParentChildIndexProcessor:
         vector.delete_by_ids.assert_called_once_with(["child-1", "child-2"])

     def test_clean_dataset_wide_cleanup(self, processor: ParentChildIndexProcessor, dataset: Mock) -> None:
-        where_query = Mock()
-        where_query.delete.return_value = 3
         session = Mock()
-        session.query.return_value.where.return_value = where_query

         with (
             patch("core.rag.index_processor.processor.parent_child_index_processor.Vector") as mock_vector_cls,
@@ -261,7 +254,7 @@ class TestParentChildIndexProcessor:
         processor.clean(dataset, None, delete_child_chunks=True)

         vector.delete.assert_called_once()
-        where_query.delete.assert_called_once_with(synchronize_session=False)
+        session.execute.assert_called()
         session.commit.assert_called_once()

     def test_clean_deletes_summaries_when_requested(self, processor: ParentChildIndexProcessor, dataset: Mock) -> None:
||||
@@ -133,10 +133,10 @@ class TestBaseIndexProcessor:
|
||||
upload_b = SimpleNamespace(id="bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb", name="b.png")
|
||||
upload_tool = SimpleNamespace(id="tool-upload-id", name="tool.png")
|
||||
upload_remote = SimpleNamespace(id="remote-upload-id", name="remote.png")
|
||||
db_query = Mock()
|
||||
db_query.where.return_value.all.return_value = [upload_a, upload_b, upload_tool, upload_remote]
|
||||
scalars_result = Mock()
|
||||
scalars_result.all.return_value = [upload_a, upload_b, upload_tool, upload_remote]
|
||||
db_session = Mock()
|
||||
db_session.query.return_value = db_query
|
||||
db_session.scalars.return_value = scalars_result
|
||||
|
||||
with (
|
||||
patch.object(processor, "_extract_markdown_images", return_value=images),
|
||||
@@ -170,10 +170,10 @@ class TestBaseIndexProcessor:
|
||||
def test_get_content_files_ignores_missing_upload_records(self, processor: _ForwardingBaseIndexProcessor) -> None:
|
||||
document = Document(page_content="ignored", metadata={"document_id": "doc-1", "dataset_id": "ds-1"})
|
||||
images = ["/files/aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa/image-preview"]
|
||||
db_query = Mock()
|
||||
db_query.where.return_value.all.return_value = []
|
||||
scalars_result = Mock()
|
||||
scalars_result.all.return_value = []
|
||||
db_session = Mock()
|
||||
db_session.query.return_value = db_query
|
||||
db_session.scalars.return_value = scalars_result
|
||||
|
||||
with (
|
||||
patch.object(processor, "_extract_markdown_images", return_value=images),
|
||||
@@ -259,20 +259,16 @@ class TestBaseIndexProcessor:
|
||||
assert processor._download_image("https://example.com/image.png", current_user=Mock()) is None
|
||||
|
||||
def test_download_tool_file_returns_none_when_not_found(self, processor: _ForwardingBaseIndexProcessor) -> None:
|
||||
db_query = Mock()
|
||||
db_query.where.return_value.first.return_value = None
|
||||
db_session = Mock()
|
||||
db_session.query.return_value = db_query
|
||||
db_session.get.return_value = None
|
||||
|
||||
with patch("core.rag.index_processor.index_processor_base.db.session", db_session):
|
||||
assert processor._download_tool_file("tool-id", current_user=Mock()) is None
|
||||
|
||||
def test_download_tool_file_uploads_file_when_found(self, processor: _ForwardingBaseIndexProcessor) -> None:
|
||||
tool_file = SimpleNamespace(file_key="k1", name="tool.png", mimetype="image/png")
|
||||
db_query = Mock()
|
||||
db_query.where.return_value.first.return_value = tool_file
|
||||
db_session = Mock()
|
||||
db_session.query.return_value = db_query
|
||||
db_session.get.return_value = tool_file
|
||||
mock_db = Mock()
|
||||
mock_db.session = db_session
|
||||
mock_db.engine = Mock()
|
||||
|
||||
@@ -473,12 +473,10 @@ class TestRerankModelRunnerMultimodal:
|
||||
metadata={},
|
||||
provider="external",
|
||||
)
|
||||
query = Mock()
|
||||
query.where.return_value.first.return_value = SimpleNamespace(key="image-key")
|
||||
rerank_result = RerankResult(model="rerank-model", docs=[])
|
||||
|
||||
with (
|
||||
patch("core.rag.rerank.rerank_model.db.session.query", return_value=query),
|
||||
patch("core.rag.rerank.rerank_model.db.session.get", return_value=SimpleNamespace(key="image-key")),
|
||||
patch("core.rag.rerank.rerank_model.storage.load_once", return_value=b"image-bytes") as mock_load_once,
|
||||
patch.object(
|
||||
rerank_runner,
|
||||
@@ -504,12 +502,10 @@ class TestRerankModelRunnerMultimodal:
|
||||
metadata={"doc_id": "img-missing", "doc_type": DocType.IMAGE},
|
||||
provider="dify",
|
||||
)
|
||||
query = Mock()
|
||||
query.where.return_value.first.return_value = None
|
||||
rerank_result = RerankResult(model="rerank-model", docs=[])
|
||||
|
||||
with (
|
||||
patch("core.rag.rerank.rerank_model.db.session.query", return_value=query),
|
||||
patch("core.rag.rerank.rerank_model.db.session.get", return_value=None),
|
||||
patch.object(
|
||||
rerank_runner,
|
||||
"fetch_text_rerank",
|
||||
@@ -533,8 +529,6 @@ class TestRerankModelRunnerMultimodal:
|
||||
metadata={"doc_id": "txt-1", "doc_type": DocType.TEXT},
|
||||
provider="dify",
|
||||
)
|
||||
query_chain = Mock()
|
||||
query_chain.where.return_value.first.return_value = SimpleNamespace(key="query-image-key")
|
||||
rerank_result = RerankResult(
|
||||
model="rerank-model",
|
||||
docs=[RerankDocument(index=0, text="text-content", score=0.77)],
|
||||
@@ -542,7 +536,7 @@ class TestRerankModelRunnerMultimodal:
|
||||
mock_model_instance.invoke_multimodal_rerank.return_value = rerank_result
|
||||
|
||||
session = MagicMock()
|
||||
session.query.return_value = query_chain
|
||||
session.get.return_value = SimpleNamespace(key="query-image-key")
|
||||
with (
|
||||
patch("core.rag.rerank.rerank_model.db.session", session),
|
||||
patch("core.rag.rerank.rerank_model.storage.load_once", return_value=b"query-image-bytes"),
|
||||
@@ -563,10 +557,7 @@ class TestRerankModelRunnerMultimodal:
|
||||
assert "user" not in invoke_kwargs
|
||||
|
||||
def test_fetch_multimodal_rerank_raises_when_query_image_not_found(self, rerank_runner):
|
||||
query_chain = Mock()
|
||||
query_chain.where.return_value.first.return_value = None
|
||||
|
||||
with patch("core.rag.rerank.rerank_model.db.session.query", return_value=query_chain):
|
||||
with patch("core.rag.rerank.rerank_model.db.session.get", return_value=None):
|
||||
with pytest.raises(ValueError, match="Upload file not found for query"):
|
||||
rerank_runner.fetch_multimodal_rerank(
|
||||
query="missing-upload-id",
|
||||
|
||||
@@ -3971,11 +3971,10 @@ class TestDatasetRetrievalAdditionalHelpers:
|
||||
)
|
||||
|
||||
def test_get_metadata_filter_condition(self, retrieval: DatasetRetrieval) -> None:
|
||||
db_query = Mock()
|
||||
db_query.where.return_value = db_query
|
||||
db_query.all.return_value = [SimpleNamespace(dataset_id="d1", id="doc-1")]
|
||||
scalars_result = Mock()
|
||||
scalars_result.all.return_value = [SimpleNamespace(dataset_id="d1", id="doc-1")]
|
||||
|
||||
with patch("core.rag.retrieval.dataset_retrieval.db.session.query", return_value=db_query):
|
||||
with patch("core.rag.retrieval.dataset_retrieval.db.session.scalars", return_value=scalars_result):
|
||||
mapping, condition = retrieval.get_metadata_filter_condition(
|
||||
dataset_ids=["d1"],
|
||||
query="python",
|
||||
@@ -3991,7 +3990,7 @@ class TestDatasetRetrievalAdditionalHelpers:
|
||||
|
||||
automatic_filters = [{"condition": "contains", "metadata_name": "author", "value": "Alice"}]
|
||||
with (
|
||||
patch("core.rag.retrieval.dataset_retrieval.db.session.query", return_value=db_query),
|
||||
patch("core.rag.retrieval.dataset_retrieval.db.session.scalars", return_value=scalars_result),
|
||||
patch.object(retrieval, "_automatic_metadata_filter_func", return_value=automatic_filters),
|
||||
):
|
||||
mapping, condition = retrieval.get_metadata_filter_condition(
|
||||
@@ -4012,7 +4011,7 @@ class TestDatasetRetrievalAdditionalHelpers:
|
||||
logical_operator="and",
|
||||
conditions=[AppCondition(name="author", comparison_operator="contains", value="{{name}}")],
|
||||
)
|
||||
with patch("core.rag.retrieval.dataset_retrieval.db.session.query", return_value=db_query):
|
||||
with patch("core.rag.retrieval.dataset_retrieval.db.session.scalars", return_value=scalars_result):
|
||||
mapping, condition = retrieval.get_metadata_filter_condition(
|
||||
dataset_ids=["d1"],
|
||||
query="python",
|
||||
@@ -4027,7 +4026,7 @@ class TestDatasetRetrievalAdditionalHelpers:
|
||||
assert condition is not None
|
||||
assert condition.conditions[0].value == "Alice"
|
||||
|
||||
with patch("core.rag.retrieval.dataset_retrieval.db.session.query", return_value=db_query):
|
||||
with patch("core.rag.retrieval.dataset_retrieval.db.session.scalars", return_value=scalars_result):
|
||||
with pytest.raises(ValueError, match="Invalid metadata filtering mode"):
|
||||
retrieval.get_metadata_filter_condition(
|
||||
dataset_ids=["d1"],
|
||||
|
||||