mirror of https://github.com/langgenius/dify.git
synced 2025-12-25 01:00:42 -05:00
fix: optimize database query when retrieval knowledge in App (#29467)
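
This change removes an N+1 query pattern from DatasetRetrieval._on_retrieval_end. The old code looped over every retrieved document and issued a separate SELECT for its DatasetDocument, another for its ChildChunk or DocumentSegment, and an UPDATE per hit, committing as it went. The new code gathers the ids first, loads DatasetDocument, ChildChunk, DocumentSegment, and SegmentAttachmentBinding rows with bulk IN (...) filters, and increments hit_count for all matched segments in one UPDATE and one commit. A minimal sketch of the pattern, using a hypothetical toy Segment model (not Dify's actual schema), assuming SQLAlchemy 2.0:

    # Toy model standing in for DocumentSegment; names here are illustrative only.
    from sqlalchemy import Integer, String, create_engine, select
    from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column

    class Base(DeclarativeBase):
        pass

    class Segment(Base):
        __tablename__ = "segments"
        id: Mapped[int] = mapped_column(Integer, primary_key=True)
        index_node_id: Mapped[str] = mapped_column(String)
        hit_count: Mapped[int] = mapped_column(Integer, default=0)

    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)

    with Session(engine) as session:
        session.add_all([Segment(index_node_id=f"node-{i}") for i in range(3)])
        session.commit()

        retrieved_node_ids = ["node-0", "node-2"]

        # Old shape: one SELECT (and one UPDATE) per retrieved document -- N+1.
        # for node_id in retrieved_node_ids:
        #     seg = session.scalar(select(Segment).where(Segment.index_node_id == node_id))
        #     ...

        # New shape: a single SELECT ... WHERE index_node_id IN (...),
        # collect the ids, then one bulk UPDATE for every hit segment.
        segments = session.scalars(
            select(Segment).where(Segment.index_node_id.in_(retrieved_node_ids))
        ).all()
        segment_ids_to_update = {seg.id for seg in segments}
        session.query(Segment).where(Segment.id.in_(segment_ids_to_update)).update(
            {Segment.hit_count: Segment.hit_count + 1},
            synchronize_session=False,
        )
        session.commit()

The diff below applies the same shape to the real models.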
@@ -592,111 +592,116 @@ class DatasetRetrieval:
         """Handle retrieval end."""
-        with flask_app.app_context():
-            dify_documents = [document for document in documents if document.provider == "dify"]
-            segment_ids = []
-            segment_index_node_ids = []
-            with Session(db.engine) as session:
-                for document in dify_documents:
-                    if document.metadata is not None:
-                        dataset_document_stmt = select(DatasetDocument).where(
-                            DatasetDocument.id == document.metadata["document_id"]
-                        )
-                        dataset_document = session.scalar(dataset_document_stmt)
-                        if dataset_document:
-                            if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
-                                segment_id = None
-                                if (
-                                    "doc_type" not in document.metadata
-                                    or document.metadata.get("doc_type") == DocType.TEXT
-                                ):
-                                    child_chunk_stmt = select(ChildChunk).where(
-                                        ChildChunk.index_node_id == document.metadata["doc_id"],
-                                        ChildChunk.dataset_id == dataset_document.dataset_id,
-                                        ChildChunk.document_id == dataset_document.id,
-                                    )
-                                    child_chunk = session.scalar(child_chunk_stmt)
-                                    if child_chunk:
-                                        segment_id = child_chunk.segment_id
-                                elif (
-                                    "doc_type" in document.metadata
-                                    and document.metadata.get("doc_type") == DocType.IMAGE
-                                ):
-                                    attachment_info_dict = RetrievalService.get_segment_attachment_info(
-                                        dataset_document.dataset_id,
-                                        dataset_document.tenant_id,
-                                        document.metadata.get("doc_id") or "",
-                                        session,
-                                    )
-                                    if attachment_info_dict:
-                                        segment_id = attachment_info_dict["segment_id"]
-                                if segment_id:
-                                    if segment_id not in segment_ids:
-                                        segment_ids.append(segment_id)
-                                        _ = (
-                                            session.query(DocumentSegment)
-                                            .where(DocumentSegment.id == segment_id)
-                                            .update(
-                                                {DocumentSegment.hit_count: DocumentSegment.hit_count + 1},
-                                                synchronize_session=False,
-                                            )
-                                        )
-                            else:
-                                query = None
-                                if (
-                                    "doc_type" not in document.metadata
-                                    or document.metadata.get("doc_type") == DocType.TEXT
-                                ):
-                                    if document.metadata["doc_id"] not in segment_index_node_ids:
-                                        segment = (
-                                            session.query(DocumentSegment)
-                                            .where(DocumentSegment.index_node_id == document.metadata["doc_id"])
-                                            .first()
-                                        )
-                                        if segment:
-                                            segment_index_node_ids.append(document.metadata["doc_id"])
-                                            segment_ids.append(segment.id)
-                                            query = session.query(DocumentSegment).where(
-                                                DocumentSegment.id == segment.id
-                                            )
-                                elif (
-                                    "doc_type" in document.metadata
-                                    and document.metadata.get("doc_type") == DocType.IMAGE
-                                ):
-                                    attachment_info_dict = RetrievalService.get_segment_attachment_info(
-                                        dataset_document.dataset_id,
-                                        dataset_document.tenant_id,
-                                        document.metadata.get("doc_id") or "",
-                                        session,
-                                    )
-                                    if attachment_info_dict:
-                                        segment_id = attachment_info_dict["segment_id"]
-                                        if segment_id not in segment_ids:
-                                            segment_ids.append(segment_id)
-                                            query = session.query(DocumentSegment).where(DocumentSegment.id == segment_id)
-                                if query:
-                                    if "dataset_id" in document.metadata:
-                                        query = query.where(
-                                            DocumentSegment.dataset_id == document.metadata["dataset_id"]
-                                        )
-
-                                    # add hit count to document segment
-                                    query.update(
-                                        {DocumentSegment.hit_count: DocumentSegment.hit_count + 1},
-                                        synchronize_session=False,
-                                    )
-
-                db.session.commit()
-
-            # get tracing instance
-            trace_manager: TraceQueueManager | None = (
-                self.application_generate_entity.trace_manager if self.application_generate_entity else None
-            )
-            if trace_manager:
-                trace_manager.add_trace_task(
-                    TraceTask(
-                        TraceTaskName.DATASET_RETRIEVAL_TRACE, message_id=message_id, documents=documents, timer=timer
-                    )
-                )
+        dify_documents = [document for document in documents if document.provider == "dify"]
+        if not dify_documents:
+            self._send_trace_task(message_id, documents, timer)
+            return
+
+        with Session(db.engine) as session:
+            # Collect all document_ids and batch fetch DatasetDocuments
+            document_ids = {
+                doc.metadata["document_id"]
+                for doc in dify_documents
+                if doc.metadata and "document_id" in doc.metadata
+            }
+            if not document_ids:
+                self._send_trace_task(message_id, documents, timer)
+                return
+
+            dataset_docs_stmt = select(DatasetDocument).where(DatasetDocument.id.in_(document_ids))
+            dataset_docs = session.scalars(dataset_docs_stmt).all()
+            dataset_doc_map = {str(doc.id): doc for doc in dataset_docs}
+
+            # Categorize documents by type and collect necessary IDs
+            parent_child_text_docs: list[tuple[Document, DatasetDocument]] = []
+            parent_child_image_docs: list[tuple[Document, DatasetDocument]] = []
+            normal_text_docs: list[tuple[Document, DatasetDocument]] = []
+            normal_image_docs: list[tuple[Document, DatasetDocument]] = []
+
+            for doc in dify_documents:
+                if not doc.metadata or "document_id" not in doc.metadata:
+                    continue
+                dataset_doc = dataset_doc_map.get(doc.metadata["document_id"])
+                if not dataset_doc:
+                    continue
+
+                is_image = doc.metadata.get("doc_type") == DocType.IMAGE
+                is_parent_child = dataset_doc.doc_form == IndexStructureType.PARENT_CHILD_INDEX
+
+                if is_parent_child:
+                    if is_image:
+                        parent_child_image_docs.append((doc, dataset_doc))
+                    else:
+                        parent_child_text_docs.append((doc, dataset_doc))
+                else:
+                    if is_image:
+                        normal_image_docs.append((doc, dataset_doc))
+                    else:
+                        normal_text_docs.append((doc, dataset_doc))
+
+            segment_ids_to_update: set[str] = set()
+
+            # Process PARENT_CHILD_INDEX text documents - batch fetch ChildChunks
+            if parent_child_text_docs:
+                index_node_ids = [doc.metadata["doc_id"] for doc, _ in parent_child_text_docs if doc.metadata]
+                if index_node_ids:
+                    child_chunks_stmt = select(ChildChunk).where(ChildChunk.index_node_id.in_(index_node_ids))
+                    child_chunks = session.scalars(child_chunks_stmt).all()
+                    child_chunk_map = {chunk.index_node_id: chunk.segment_id for chunk in child_chunks}
+                    for doc, _ in parent_child_text_docs:
+                        if doc.metadata:
+                            segment_id = child_chunk_map.get(doc.metadata["doc_id"])
+                            if segment_id:
+                                segment_ids_to_update.add(str(segment_id))
+
+            # Process non-PARENT_CHILD_INDEX text documents - batch fetch DocumentSegments
+            if normal_text_docs:
+                index_node_ids = [doc.metadata["doc_id"] for doc, _ in normal_text_docs if doc.metadata]
+                if index_node_ids:
+                    segments_stmt = select(DocumentSegment).where(DocumentSegment.index_node_id.in_(index_node_ids))
+                    segments = session.scalars(segments_stmt).all()
+                    segment_map = {seg.index_node_id: seg.id for seg in segments}
+                    for doc, _ in normal_text_docs:
+                        if doc.metadata:
+                            segment_id = segment_map.get(doc.metadata["doc_id"])
+                            if segment_id:
+                                segment_ids_to_update.add(str(segment_id))
+
+            # Process IMAGE documents - batch fetch SegmentAttachmentBindings
+            all_image_docs = parent_child_image_docs + normal_image_docs
+            if all_image_docs:
+                attachment_ids = [
+                    doc.metadata["doc_id"]
+                    for doc, _ in all_image_docs
+                    if doc.metadata and doc.metadata.get("doc_id")
+                ]
+                if attachment_ids:
+                    bindings_stmt = select(SegmentAttachmentBinding).where(
+                        SegmentAttachmentBinding.attachment_id.in_(attachment_ids)
+                    )
+                    bindings = session.scalars(bindings_stmt).all()
+                    segment_ids_to_update.update(str(binding.segment_id) for binding in bindings)
+
+            # Batch update hit_count for all segments
+            if segment_ids_to_update:
+                session.query(DocumentSegment).where(DocumentSegment.id.in_(segment_ids_to_update)).update(
+                    {DocumentSegment.hit_count: DocumentSegment.hit_count + 1},
+                    synchronize_session=False,
+                )
+                session.commit()
+
+        self._send_trace_task(message_id, documents, timer)
+
+    def _send_trace_task(self, message_id: str | None, documents: list[Document], timer: dict | None):
+        """Send trace task if trace manager is available."""
+        trace_manager: TraceQueueManager | None = (
+            self.application_generate_entity.trace_manager if self.application_generate_entity else None
+        )
+        if trace_manager:
+            trace_manager.add_trace_task(
+                TraceTask(
+                    TraceTaskName.DATASET_RETRIEVAL_TRACE, message_id=message_id, documents=documents, timer=timer
+                )
+            )
 
     def _on_query(
         self,
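
Two details of the batched update are easy to miss. The value DocumentSegment.hit_count + 1 is a SQL expression, so the increment runs server-side (SET hit_count = hit_count + 1) and concurrent retrievals cannot lose hits to a read-modify-write race. And synchronize_session=False skips reconciling the bulk UPDATE with objects already loaded into the session, which is safe here because the segments are never read back before session.commit(); the old per-iteration db.session.commit() also collapses into one commit per retrieval. Continuing the toy sketch above (run inside the same session block), the stale-object behavior looks like this:

    # With synchronize_session=False, a bulk UPDATE does not touch objects
    # already in the identity map; they stay stale until expired or refreshed.
    seg = session.scalars(select(Segment)).first()
    before = seg.hit_count
    session.query(Segment).where(Segment.id == seg.id).update(
        {Segment.hit_count: Segment.hit_count + 1},
        synchronize_session=False,
    )
    print(seg.hit_count == before)       # True: in-memory value is stale
    session.refresh(seg)
    print(seg.hit_count == before + 1)   # True: reloaded from the database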