fix: optimize database queries when retrieving knowledge in App (#29467)

Jyong
2025-12-11 13:50:46 +08:00
committed by GitHub
parent aac6f44562
commit 69a22af1c9

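The change collapses an N+1 query pattern in `_on_retrieval_end`: instead of issuing one SELECT (and one UPDATE) per retrieved document, ids are collected first, the matching rows are fetched with a single `IN` query, and hit counts are bumped with one bulk UPDATE. A minimal sketch of that pattern follows; the `Segment` model and the SQLite engine are stand-ins for illustration, not Dify's actual schema, and SQLAlchemy 2.x is assumed.

    from sqlalchemy import String, create_engine, select
    from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


    class Base(DeclarativeBase):
        pass


    class Segment(Base):
        # Stand-in for a segment table with a hit counter; not Dify's schema.
        __tablename__ = "segments"
        id: Mapped[str] = mapped_column(String, primary_key=True)
        hit_count: Mapped[int] = mapped_column(default=0)


    engine = create_engine("sqlite:///:memory:")
    Base.metadata.create_all(engine)

    segment_ids = ["a", "b", "c"]
    with Session(engine) as session:
        session.add_all(Segment(id=i) for i in segment_ids)
        session.flush()

        # Before: one SELECT per retrieved document (N+1 round trips).
        for segment_id in segment_ids:
            seg = session.scalar(select(Segment).where(Segment.id == segment_id))

        # After: one SELECT for all ids, then an in-memory lookup map,
        # followed by a single bulk UPDATE over the collected ids.
        seg_map = {s.id: s for s in session.scalars(select(Segment).where(Segment.id.in_(segment_ids)))}
        session.query(Segment).where(Segment.id.in_(segment_ids)).update(
            {Segment.hit_count: Segment.hit_count + 1},
            synchronize_session=False,
        )
        session.commit()

As in the diff below, `synchronize_session=False` skips reconciling in-memory objects with the bulk UPDATE, which is safe here because the session is discarded right afterwards.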

@@ -592,111 +592,116 @@ class DatasetRetrieval:
"""Handle retrieval end."""
with flask_app.app_context():
dify_documents = [document for document in documents if document.provider == "dify"]
segment_ids = []
segment_index_node_ids = []
if not dify_documents:
self._send_trace_task(message_id, documents, timer)
return
with Session(db.engine) as session:
-                for document in dify_documents:
-                    if document.metadata is not None:
-                        dataset_document_stmt = select(DatasetDocument).where(
-                            DatasetDocument.id == document.metadata["document_id"]
-                        )
-                        dataset_document = session.scalar(dataset_document_stmt)
-                        if dataset_document:
-                            if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
-                                segment_id = None
-                                if (
-                                    "doc_type" not in document.metadata
-                                    or document.metadata.get("doc_type") == DocType.TEXT
-                                ):
-                                    child_chunk_stmt = select(ChildChunk).where(
-                                        ChildChunk.index_node_id == document.metadata["doc_id"],
-                                        ChildChunk.dataset_id == dataset_document.dataset_id,
-                                        ChildChunk.document_id == dataset_document.id,
-                                    )
-                                    child_chunk = session.scalar(child_chunk_stmt)
-                                    if child_chunk:
-                                        segment_id = child_chunk.segment_id
-                                elif (
-                                    "doc_type" in document.metadata
-                                    and document.metadata.get("doc_type") == DocType.IMAGE
-                                ):
-                                    attachment_info_dict = RetrievalService.get_segment_attachment_info(
-                                        dataset_document.dataset_id,
-                                        dataset_document.tenant_id,
-                                        document.metadata.get("doc_id") or "",
-                                        session,
-                                    )
-                                    if attachment_info_dict:
-                                        segment_id = attachment_info_dict["segment_id"]
+                # Collect all document_ids and batch fetch DatasetDocuments
+                document_ids = {
+                    doc.metadata["document_id"]
+                    for doc in dify_documents
+                    if doc.metadata and "document_id" in doc.metadata
+                }
+                if not document_ids:
+                    self._send_trace_task(message_id, documents, timer)
+                    return
+                dataset_docs_stmt = select(DatasetDocument).where(DatasetDocument.id.in_(document_ids))
+                dataset_docs = session.scalars(dataset_docs_stmt).all()
+                dataset_doc_map = {str(doc.id): doc for doc in dataset_docs}
+                # Categorize documents by type and collect necessary IDs
+                parent_child_text_docs: list[tuple[Document, DatasetDocument]] = []
+                parent_child_image_docs: list[tuple[Document, DatasetDocument]] = []
+                normal_text_docs: list[tuple[Document, DatasetDocument]] = []
+                normal_image_docs: list[tuple[Document, DatasetDocument]] = []
+                for doc in dify_documents:
+                    if not doc.metadata or "document_id" not in doc.metadata:
+                        continue
+                    dataset_doc = dataset_doc_map.get(doc.metadata["document_id"])
+                    if not dataset_doc:
+                        continue
+                    is_image = doc.metadata.get("doc_type") == DocType.IMAGE
+                    is_parent_child = dataset_doc.doc_form == IndexStructureType.PARENT_CHILD_INDEX
+                    if is_parent_child:
+                        if is_image:
+                            parent_child_image_docs.append((doc, dataset_doc))
+                        else:
+                            parent_child_text_docs.append((doc, dataset_doc))
+                    else:
+                        if is_image:
+                            normal_image_docs.append((doc, dataset_doc))
+                        else:
+                            normal_text_docs.append((doc, dataset_doc))
+                segment_ids_to_update: set[str] = set()
+                # Process PARENT_CHILD_INDEX text documents - batch fetch ChildChunks
+                if parent_child_text_docs:
+                    index_node_ids = [doc.metadata["doc_id"] for doc, _ in parent_child_text_docs if doc.metadata]
+                    if index_node_ids:
+                        child_chunks_stmt = select(ChildChunk).where(ChildChunk.index_node_id.in_(index_node_ids))
+                        child_chunks = session.scalars(child_chunks_stmt).all()
+                        child_chunk_map = {chunk.index_node_id: chunk.segment_id for chunk in child_chunks}
+                        for doc, _ in parent_child_text_docs:
+                            if doc.metadata:
+                                segment_id = child_chunk_map.get(doc.metadata["doc_id"])
+                                if segment_id:
+                                    segment_ids_to_update.add(str(segment_id))
-                                if segment_id:
-                                    if segment_id not in segment_ids:
-                                        segment_ids.append(segment_id)
-                                        _ = (
-                                            session.query(DocumentSegment)
-                                            .where(DocumentSegment.id == segment_id)
-                                            .update(
-                                                {DocumentSegment.hit_count: DocumentSegment.hit_count + 1},
-                                                synchronize_session=False,
-                                            )
-                                        )
-                            else:
-                                query = None
-                                if (
-                                    "doc_type" not in document.metadata
-                                    or document.metadata.get("doc_type") == DocType.TEXT
-                                ):
-                                    if document.metadata["doc_id"] not in segment_index_node_ids:
-                                        segment = (
-                                            session.query(DocumentSegment)
-                                            .where(DocumentSegment.index_node_id == document.metadata["doc_id"])
-                                            .first()
-                                        )
-                                        if segment:
-                                            segment_index_node_ids.append(document.metadata["doc_id"])
-                                            segment_ids.append(segment.id)
-                                            query = session.query(DocumentSegment).where(
-                                                DocumentSegment.id == segment.id
-                                            )
-                                elif (
-                                    "doc_type" in document.metadata
-                                    and document.metadata.get("doc_type") == DocType.IMAGE
-                                ):
-                                    attachment_info_dict = RetrievalService.get_segment_attachment_info(
-                                        dataset_document.dataset_id,
-                                        dataset_document.tenant_id,
-                                        document.metadata.get("doc_id") or "",
-                                        session,
-                                    )
-                                    if attachment_info_dict:
-                                        segment_id = attachment_info_dict["segment_id"]
-                                        if segment_id not in segment_ids:
-                                            segment_ids.append(segment_id)
-                                            query = session.query(DocumentSegment).where(DocumentSegment.id == segment_id)
-                                if query:
-                                    if "dataset_id" in document.metadata:
-                                        query = query.where(
-                                            DocumentSegment.dataset_id == document.metadata["dataset_id"]
-                                        )
-                                    # add hit count to document segment
-                                    query.update(
-                                        {DocumentSegment.hit_count: DocumentSegment.hit_count + 1},
-                                        synchronize_session=False,
-                                    )
+                # Process non-PARENT_CHILD_INDEX text documents - batch fetch DocumentSegments
+                if normal_text_docs:
+                    index_node_ids = [doc.metadata["doc_id"] for doc, _ in normal_text_docs if doc.metadata]
+                    if index_node_ids:
+                        segments_stmt = select(DocumentSegment).where(DocumentSegment.index_node_id.in_(index_node_ids))
+                        segments = session.scalars(segments_stmt).all()
+                        segment_map = {seg.index_node_id: seg.id for seg in segments}
+                        for doc, _ in normal_text_docs:
+                            if doc.metadata:
+                                segment_id = segment_map.get(doc.metadata["doc_id"])
+                                if segment_id:
+                                    segment_ids_to_update.add(str(segment_id))
-                db.session.commit()
+                # Process IMAGE documents - batch fetch SegmentAttachmentBindings
+                all_image_docs = parent_child_image_docs + normal_image_docs
+                if all_image_docs:
+                    attachment_ids = [
+                        doc.metadata["doc_id"]
+                        for doc, _ in all_image_docs
+                        if doc.metadata and doc.metadata.get("doc_id")
+                    ]
+                    if attachment_ids:
+                        bindings_stmt = select(SegmentAttachmentBinding).where(
+                            SegmentAttachmentBinding.attachment_id.in_(attachment_ids)
+                        )
+                        bindings = session.scalars(bindings_stmt).all()
+                        segment_ids_to_update.update(str(binding.segment_id) for binding in bindings)
-            # get tracing instance
-            trace_manager: TraceQueueManager | None = (
-                self.application_generate_entity.trace_manager if self.application_generate_entity else None
-            )
-            if trace_manager:
-                trace_manager.add_trace_task(
-                    TraceTask(
-                        TraceTaskName.DATASET_RETRIEVAL_TRACE, message_id=message_id, documents=documents, timer=timer
-                    )
-                )
+                # Batch update hit_count for all segments
+                if segment_ids_to_update:
+                    session.query(DocumentSegment).where(DocumentSegment.id.in_(segment_ids_to_update)).update(
+                        {DocumentSegment.hit_count: DocumentSegment.hit_count + 1},
+                        synchronize_session=False,
+                    )
+                    session.commit()
+            self._send_trace_task(message_id, documents, timer)
+
+    def _send_trace_task(self, message_id: str | None, documents: list[Document], timer: dict | None):
+        """Send trace task if trace manager is available."""
+        trace_manager: TraceQueueManager | None = (
+            self.application_generate_entity.trace_manager if self.application_generate_entity else None
+        )
+        if trace_manager:
+            trace_manager.add_trace_task(
+                TraceTask(
+                    TraceTaskName.DATASET_RETRIEVAL_TRACE, message_id=message_id, documents=documents, timer=timer
+                )
+            )
+
     def _on_query(
         self,