mirror of
https://github.com/langgenius/dify.git
synced 2026-04-12 18:00:24 -04:00
refactor(api): migrate core RAG layer to SQLAlchemy 2.0 select() API (#34965)
This commit is contained in:
@@ -6,7 +6,7 @@ from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from flask import current_app
|
||||
from sqlalchemy import delete, func, select
|
||||
from sqlalchemy import delete, func, select, update
|
||||
|
||||
from core.db.session_factory import session_factory
|
||||
from core.rag.index_processor.constant.index_type import IndexTechniqueType
|
||||
@@ -63,11 +63,11 @@ class IndexProcessor:
|
||||
summary_index_setting: SummaryIndexSettingDict | None = None,
|
||||
) -> IndexingResultDict:
|
||||
with session_factory.create_session() as session:
|
||||
document = session.query(Document).filter_by(id=document_id).first()
|
||||
document = session.scalar(select(Document).where(Document.id == document_id).limit(1))
|
||||
if not document:
|
||||
raise KnowledgeIndexNodeError(f"Document {document_id} not found.")
|
||||
|
||||
dataset = session.query(Dataset).filter_by(id=dataset_id).first()
|
||||
dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1))
|
||||
if not dataset:
|
||||
raise KnowledgeIndexNodeError(f"Dataset {dataset_id} not found.")
|
||||
|
||||
@@ -104,12 +104,12 @@ class IndexProcessor:
|
||||
document.indexing_status = "completed"
|
||||
document.completed_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
|
||||
document.word_count = (
|
||||
session.query(func.sum(DocumentSegment.word_count))
|
||||
.where(
|
||||
DocumentSegment.document_id == document_id,
|
||||
DocumentSegment.dataset_id == dataset_id,
|
||||
session.scalar(
|
||||
select(func.sum(DocumentSegment.word_count)).where(
|
||||
DocumentSegment.document_id == document_id,
|
||||
DocumentSegment.dataset_id == dataset_id,
|
||||
)
|
||||
)
|
||||
.scalar()
|
||||
) or 0
|
||||
# Update need_summary based on dataset's summary_index_setting
|
||||
if summary_index_setting and summary_index_setting.get("enable") is True:
|
||||
@@ -118,15 +118,17 @@ class IndexProcessor:
|
||||
document.need_summary = False
|
||||
session.add(document)
|
||||
# update document segment status
|
||||
session.query(DocumentSegment).where(
|
||||
DocumentSegment.document_id == document_id,
|
||||
DocumentSegment.dataset_id == dataset_id,
|
||||
).update(
|
||||
{
|
||||
DocumentSegment.status: "completed",
|
||||
DocumentSegment.enabled: True,
|
||||
DocumentSegment.completed_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
|
||||
}
|
||||
session.execute(
|
||||
update(DocumentSegment)
|
||||
.where(
|
||||
DocumentSegment.document_id == document_id,
|
||||
DocumentSegment.dataset_id == dataset_id,
|
||||
)
|
||||
.values(
|
||||
status="completed",
|
||||
enabled=True,
|
||||
completed_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
|
||||
)
|
||||
)
|
||||
|
||||
result: IndexingResultDict = {
|
||||
@@ -151,11 +153,11 @@ class IndexProcessor:
|
||||
doc_language = None
|
||||
with session_factory.create_session() as session:
|
||||
if document_id:
|
||||
document = session.query(Document).filter_by(id=document_id).first()
|
||||
document = session.scalar(select(Document).where(Document.id == document_id).limit(1))
|
||||
else:
|
||||
document = None
|
||||
|
||||
dataset = session.query(Dataset).filter_by(id=dataset_id).first()
|
||||
dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1))
|
||||
if not dataset:
|
||||
raise KnowledgeIndexNodeError(f"Dataset {dataset_id} not found.")
|
||||
|
||||
|
||||
@@ -159,14 +159,12 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
|
||||
if node_ids:
|
||||
# Find segments by index_node_id
|
||||
with session_factory.create_session() as session:
|
||||
segments = (
|
||||
session.query(DocumentSegment)
|
||||
.filter(
|
||||
segments = session.scalars(
|
||||
select(DocumentSegment).where(
|
||||
DocumentSegment.dataset_id == dataset.id,
|
||||
DocumentSegment.index_node_id.in_(node_ids),
|
||||
)
|
||||
.all()
|
||||
)
|
||||
).all()
|
||||
segment_ids = [segment.id for segment in segments]
|
||||
if segment_ids:
|
||||
SummaryIndexService.delete_summaries_for_segments(dataset, segment_ids)
|
||||
|
||||
@@ -8,6 +8,7 @@ from typing import Any, TypedDict
|
||||
|
||||
import pandas as pd
|
||||
from flask import Flask, current_app
|
||||
from sqlalchemy import select
|
||||
from werkzeug.datastructures import FileStorage
|
||||
|
||||
from core.db.session_factory import session_factory
|
||||
@@ -163,14 +164,12 @@ class QAIndexProcessor(BaseIndexProcessor):
|
||||
if node_ids:
|
||||
# Find segments by index_node_id
|
||||
with session_factory.create_session() as session:
|
||||
segments = (
|
||||
session.query(DocumentSegment)
|
||||
.filter(
|
||||
segments = session.scalars(
|
||||
select(DocumentSegment).where(
|
||||
DocumentSegment.dataset_id == dataset.id,
|
||||
DocumentSegment.index_node_id.in_(node_ids),
|
||||
)
|
||||
.all()
|
||||
)
|
||||
).all()
|
||||
segment_ids = [segment.id for segment in segments]
|
||||
if segment_ids:
|
||||
SummaryIndexService.delete_summaries_for_segments(dataset, segment_ids)
|
||||
|
||||
@@ -14,7 +14,7 @@ from graphon.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMU
|
||||
from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool
|
||||
from graphon.model_runtime.entities.model_entities import ModelFeature, ModelType
|
||||
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from sqlalchemy import and_, func, literal, or_, select
|
||||
from sqlalchemy import and_, func, literal, or_, select, update
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from core.app.app_config.entities import (
|
||||
@@ -276,8 +276,8 @@ class DatasetRetrieval:
|
||||
document_ids = [i.segment.document_id for i in records]
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
datasets = session.query(Dataset).where(Dataset.id.in_(dataset_ids)).all()
|
||||
documents = session.query(DatasetDocument).where(DatasetDocument.id.in_(document_ids)).all()
|
||||
datasets = session.scalars(select(Dataset).where(Dataset.id.in_(dataset_ids))).all()
|
||||
documents = session.scalars(select(DatasetDocument).where(DatasetDocument.id.in_(document_ids))).all()
|
||||
|
||||
dataset_map = {i.id: i for i in datasets}
|
||||
document_map = {i.id: i for i in documents}
|
||||
@@ -971,9 +971,11 @@ class DatasetRetrieval:
|
||||
|
||||
# Batch update hit_count for all segments
|
||||
if segment_ids_to_update:
|
||||
session.query(DocumentSegment).where(DocumentSegment.id.in_(segment_ids_to_update)).update(
|
||||
{DocumentSegment.hit_count: DocumentSegment.hit_count + 1},
|
||||
synchronize_session=False,
|
||||
session.execute(
|
||||
update(DocumentSegment)
|
||||
.where(DocumentSegment.id.in_(segment_ids_to_update))
|
||||
.values(hit_count=DocumentSegment.hit_count + 1)
|
||||
.execution_options(synchronize_session=False)
|
||||
)
|
||||
|
||||
self._send_trace_task(message_id, documents, timer)
|
||||
@@ -1822,7 +1824,7 @@ class DatasetRetrieval:
|
||||
def _get_available_datasets(self, tenant_id: str, dataset_ids: list[str]) -> list[Dataset]:
|
||||
with session_factory.create_session() as session:
|
||||
subquery = (
|
||||
session.query(DocumentModel.dataset_id, func.count(DocumentModel.id).label("available_document_count"))
|
||||
select(DocumentModel.dataset_id, func.count(DocumentModel.id).label("available_document_count"))
|
||||
.where(
|
||||
DocumentModel.indexing_status == "completed",
|
||||
DocumentModel.enabled == True,
|
||||
@@ -1834,13 +1836,12 @@ class DatasetRetrieval:
|
||||
.subquery()
|
||||
)
|
||||
|
||||
results = (
|
||||
session.query(Dataset)
|
||||
results = session.scalars(
|
||||
select(Dataset)
|
||||
.outerjoin(subquery, Dataset.id == subquery.c.dataset_id)
|
||||
.where(Dataset.tenant_id == tenant_id, Dataset.id.in_(dataset_ids))
|
||||
.where((subquery.c.available_document_count > 0) | (Dataset.provider == "external"))
|
||||
.all()
|
||||
)
|
||||
).all()
|
||||
|
||||
available_datasets = []
|
||||
for dataset in results:
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import concurrent.futures
|
||||
import logging
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from core.db.session_factory import session_factory
|
||||
from core.rag.index_processor.constant.index_type import IndexTechniqueType
|
||||
from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict
|
||||
@@ -21,7 +23,7 @@ class SummaryIndex:
|
||||
) -> None:
|
||||
if is_preview:
|
||||
with session_factory.create_session() as session:
|
||||
dataset = session.query(Dataset).filter_by(id=dataset_id).first()
|
||||
dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1))
|
||||
if not dataset or dataset.indexing_technique != IndexTechniqueType.HIGH_QUALITY:
|
||||
return
|
||||
|
||||
@@ -34,32 +36,31 @@ class SummaryIndex:
|
||||
if not document_id:
|
||||
return
|
||||
|
||||
document = session.query(Document).filter_by(id=document_id).first()
|
||||
document = session.scalar(select(Document).where(Document.id == document_id).limit(1))
|
||||
# Skip qa_model documents
|
||||
if document is None or document.doc_form == "qa_model":
|
||||
return
|
||||
|
||||
query = session.query(DocumentSegment).filter_by(
|
||||
dataset_id=dataset_id,
|
||||
document_id=document_id,
|
||||
status="completed",
|
||||
enabled=True,
|
||||
)
|
||||
segments = query.all()
|
||||
segments = session.scalars(
|
||||
select(DocumentSegment).where(
|
||||
DocumentSegment.dataset_id == dataset_id,
|
||||
DocumentSegment.document_id == document_id,
|
||||
DocumentSegment.status == "completed",
|
||||
DocumentSegment.enabled == True,
|
||||
)
|
||||
).all()
|
||||
segment_ids = [segment.id for segment in segments]
|
||||
|
||||
if not segment_ids:
|
||||
return
|
||||
|
||||
existing_summaries = (
|
||||
session.query(DocumentSegmentSummary)
|
||||
.filter(
|
||||
existing_summaries = session.scalars(
|
||||
select(DocumentSegmentSummary).where(
|
||||
DocumentSegmentSummary.chunk_id.in_(segment_ids),
|
||||
DocumentSegmentSummary.dataset_id == dataset_id,
|
||||
DocumentSegmentSummary.status == "completed",
|
||||
)
|
||||
.all()
|
||||
)
|
||||
).all()
|
||||
completed_summary_segment_ids = {i.chunk_id for i in existing_summaries}
|
||||
# Preview mode should process segments that are MISSING completed summaries
|
||||
pending_segment_ids = [sid for sid in segment_ids if sid not in completed_summary_segment_ids]
|
||||
@@ -73,7 +74,7 @@ class SummaryIndex:
|
||||
def process_segment(segment_id: str) -> None:
|
||||
"""Process a single segment in a thread with a fresh DB session."""
|
||||
with session_factory.create_session() as session:
|
||||
segment = session.query(DocumentSegment).filter_by(id=segment_id).first()
|
||||
segment = session.scalar(select(DocumentSegment).where(DocumentSegment.id == segment_id).limit(1))
|
||||
if segment is None:
|
||||
return
|
||||
try:
|
||||
|
||||
@@ -258,10 +258,10 @@ class TestParentChildIndexProcessor:
|
||||
session.commit.assert_called_once()
|
||||
|
||||
def test_clean_deletes_summaries_when_requested(self, processor: ParentChildIndexProcessor, dataset: Mock) -> None:
|
||||
segment_query = Mock()
|
||||
segment_query.filter.return_value.all.return_value = [SimpleNamespace(id="seg-1")]
|
||||
scalars_result = Mock()
|
||||
scalars_result.all.return_value = [SimpleNamespace(id="seg-1")]
|
||||
session = Mock()
|
||||
session.query.return_value = segment_query
|
||||
session.scalars.return_value = scalars_result
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__.return_value = session
|
||||
session_ctx.__exit__.return_value = False
|
||||
|
||||
@@ -220,10 +220,10 @@ class TestQAIndexProcessor:
|
||||
self, processor: QAIndexProcessor, dataset: Mock
|
||||
) -> None:
|
||||
mock_segment = SimpleNamespace(id="seg-1")
|
||||
mock_query = Mock()
|
||||
mock_query.filter.return_value.all.return_value = [mock_segment]
|
||||
scalars_result = Mock()
|
||||
scalars_result.all.return_value = [mock_segment]
|
||||
mock_session = Mock()
|
||||
mock_session.query.return_value = mock_query
|
||||
mock_session.scalars.return_value = scalars_result
|
||||
session_context = MagicMock()
|
||||
session_context.__enter__.return_value = mock_session
|
||||
session_context.__exit__.return_value = False
|
||||
|
||||
@@ -8,7 +8,6 @@ import pytest
|
||||
from flask import Flask, current_app
|
||||
from graphon.model_runtime.entities.llm_entities import LLMUsage
|
||||
from graphon.model_runtime.entities.model_entities import ModelFeature
|
||||
from sqlalchemy import column
|
||||
|
||||
from core.app.app_config.entities import (
|
||||
DatasetEntity,
|
||||
@@ -4039,21 +4038,9 @@ class TestDatasetRetrievalAdditionalHelpers:
|
||||
|
||||
def test_get_available_datasets(self, retrieval: DatasetRetrieval) -> None:
|
||||
session = Mock()
|
||||
subquery_query = Mock()
|
||||
subquery_query.where.return_value = subquery_query
|
||||
subquery_query.group_by.return_value = subquery_query
|
||||
subquery_query.having.return_value = subquery_query
|
||||
subquery_query.subquery.return_value = SimpleNamespace(
|
||||
c=SimpleNamespace(
|
||||
dataset_id=column("dataset_id"), available_document_count=column("available_document_count")
|
||||
)
|
||||
)
|
||||
|
||||
dataset_query = Mock()
|
||||
dataset_query.outerjoin.return_value = dataset_query
|
||||
dataset_query.where.return_value = dataset_query
|
||||
dataset_query.all.return_value = [SimpleNamespace(id="d1"), None, SimpleNamespace(id="d2")]
|
||||
session.query.side_effect = [subquery_query, dataset_query]
|
||||
scalars_result = Mock()
|
||||
scalars_result.all.return_value = [SimpleNamespace(id="d1"), None, SimpleNamespace(id="d2")]
|
||||
session.scalars.return_value = scalars_result
|
||||
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__.return_value = session
|
||||
@@ -4902,9 +4889,6 @@ class TestInternalHooksCoverage:
|
||||
_scalars(segments),
|
||||
_scalars(bindings),
|
||||
]
|
||||
query = Mock()
|
||||
query.where.return_value = query
|
||||
session.query.return_value = query
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__.return_value = session
|
||||
session_ctx.__exit__.return_value = False
|
||||
@@ -4919,7 +4903,7 @@ class TestInternalHooksCoverage:
|
||||
):
|
||||
retrieval._on_retrieval_end(flask_app=app, documents=docs, message_id="m1", timer={"cost": 1})
|
||||
|
||||
query.update.assert_called_once()
|
||||
session.execute.assert_called_once()
|
||||
mock_trace.assert_called_once()
|
||||
|
||||
def test_retriever_variants(self, retrieval: DatasetRetrieval) -> None:
|
||||
|
||||
Reference in New Issue
Block a user