refactor(api): migrate core RAG layer to SQLAlchemy 2.0 select() API (#34965)

This commit is contained in:
wdeveloper16
2026-04-11 18:32:20 +02:00
committed by GitHub
parent 50206ae8a7
commit 12814b55d2
8 changed files with 66 additions and 81 deletions

View File

@@ -6,7 +6,7 @@ from collections.abc import Mapping
from typing import Any
from flask import current_app
from sqlalchemy import delete, func, select
from sqlalchemy import delete, func, select, update
from core.db.session_factory import session_factory
from core.rag.index_processor.constant.index_type import IndexTechniqueType
@@ -63,11 +63,11 @@ class IndexProcessor:
summary_index_setting: SummaryIndexSettingDict | None = None,
) -> IndexingResultDict:
with session_factory.create_session() as session:
document = session.query(Document).filter_by(id=document_id).first()
document = session.scalar(select(Document).where(Document.id == document_id).limit(1))
if not document:
raise KnowledgeIndexNodeError(f"Document {document_id} not found.")
dataset = session.query(Dataset).filter_by(id=dataset_id).first()
dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1))
if not dataset:
raise KnowledgeIndexNodeError(f"Dataset {dataset_id} not found.")
@@ -104,12 +104,12 @@ class IndexProcessor:
document.indexing_status = "completed"
document.completed_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
document.word_count = (
session.query(func.sum(DocumentSegment.word_count))
.where(
DocumentSegment.document_id == document_id,
DocumentSegment.dataset_id == dataset_id,
session.scalar(
select(func.sum(DocumentSegment.word_count)).where(
DocumentSegment.document_id == document_id,
DocumentSegment.dataset_id == dataset_id,
)
)
.scalar()
) or 0
# Update need_summary based on dataset's summary_index_setting
if summary_index_setting and summary_index_setting.get("enable") is True:
@@ -118,15 +118,17 @@ class IndexProcessor:
document.need_summary = False
session.add(document)
# update document segment status
session.query(DocumentSegment).where(
DocumentSegment.document_id == document_id,
DocumentSegment.dataset_id == dataset_id,
).update(
{
DocumentSegment.status: "completed",
DocumentSegment.enabled: True,
DocumentSegment.completed_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
}
session.execute(
update(DocumentSegment)
.where(
DocumentSegment.document_id == document_id,
DocumentSegment.dataset_id == dataset_id,
)
.values(
status="completed",
enabled=True,
completed_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
)
)
result: IndexingResultDict = {
@@ -151,11 +153,11 @@ class IndexProcessor:
doc_language = None
with session_factory.create_session() as session:
if document_id:
document = session.query(Document).filter_by(id=document_id).first()
document = session.scalar(select(Document).where(Document.id == document_id).limit(1))
else:
document = None
dataset = session.query(Dataset).filter_by(id=dataset_id).first()
dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1))
if not dataset:
raise KnowledgeIndexNodeError(f"Dataset {dataset_id} not found.")

View File

@@ -159,14 +159,12 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
if node_ids:
# Find segments by index_node_id
with session_factory.create_session() as session:
segments = (
session.query(DocumentSegment)
.filter(
segments = session.scalars(
select(DocumentSegment).where(
DocumentSegment.dataset_id == dataset.id,
DocumentSegment.index_node_id.in_(node_ids),
)
.all()
)
).all()
segment_ids = [segment.id for segment in segments]
if segment_ids:
SummaryIndexService.delete_summaries_for_segments(dataset, segment_ids)

View File

@@ -8,6 +8,7 @@ from typing import Any, TypedDict
import pandas as pd
from flask import Flask, current_app
from sqlalchemy import select
from werkzeug.datastructures import FileStorage
from core.db.session_factory import session_factory
@@ -163,14 +164,12 @@ class QAIndexProcessor(BaseIndexProcessor):
if node_ids:
# Find segments by index_node_id
with session_factory.create_session() as session:
segments = (
session.query(DocumentSegment)
.filter(
segments = session.scalars(
select(DocumentSegment).where(
DocumentSegment.dataset_id == dataset.id,
DocumentSegment.index_node_id.in_(node_ids),
)
.all()
)
).all()
segment_ids = [segment.id for segment in segments]
if segment_ids:
SummaryIndexService.delete_summaries_for_segments(dataset, segment_ids)

View File

@@ -14,7 +14,7 @@ from graphon.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMU
from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool
from graphon.model_runtime.entities.model_entities import ModelFeature, ModelType
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from sqlalchemy import and_, func, literal, or_, select
from sqlalchemy import and_, func, literal, or_, select, update
from sqlalchemy.orm import sessionmaker
from core.app.app_config.entities import (
@@ -276,8 +276,8 @@ class DatasetRetrieval:
document_ids = [i.segment.document_id for i in records]
with session_factory.create_session() as session:
datasets = session.query(Dataset).where(Dataset.id.in_(dataset_ids)).all()
documents = session.query(DatasetDocument).where(DatasetDocument.id.in_(document_ids)).all()
datasets = session.scalars(select(Dataset).where(Dataset.id.in_(dataset_ids))).all()
documents = session.scalars(select(DatasetDocument).where(DatasetDocument.id.in_(document_ids))).all()
dataset_map = {i.id: i for i in datasets}
document_map = {i.id: i for i in documents}
@@ -971,9 +971,11 @@ class DatasetRetrieval:
# Batch update hit_count for all segments
if segment_ids_to_update:
session.query(DocumentSegment).where(DocumentSegment.id.in_(segment_ids_to_update)).update(
{DocumentSegment.hit_count: DocumentSegment.hit_count + 1},
synchronize_session=False,
session.execute(
update(DocumentSegment)
.where(DocumentSegment.id.in_(segment_ids_to_update))
.values(hit_count=DocumentSegment.hit_count + 1)
.execution_options(synchronize_session=False)
)
self._send_trace_task(message_id, documents, timer)
@@ -1822,7 +1824,7 @@ class DatasetRetrieval:
def _get_available_datasets(self, tenant_id: str, dataset_ids: list[str]) -> list[Dataset]:
with session_factory.create_session() as session:
subquery = (
session.query(DocumentModel.dataset_id, func.count(DocumentModel.id).label("available_document_count"))
select(DocumentModel.dataset_id, func.count(DocumentModel.id).label("available_document_count"))
.where(
DocumentModel.indexing_status == "completed",
DocumentModel.enabled == True,
@@ -1834,13 +1836,12 @@ class DatasetRetrieval:
.subquery()
)
results = (
session.query(Dataset)
results = session.scalars(
select(Dataset)
.outerjoin(subquery, Dataset.id == subquery.c.dataset_id)
.where(Dataset.tenant_id == tenant_id, Dataset.id.in_(dataset_ids))
.where((subquery.c.available_document_count > 0) | (Dataset.provider == "external"))
.all()
)
).all()
available_datasets = []
for dataset in results:

View File

@@ -1,6 +1,8 @@
import concurrent.futures
import logging
from sqlalchemy import select
from core.db.session_factory import session_factory
from core.rag.index_processor.constant.index_type import IndexTechniqueType
from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict
@@ -21,7 +23,7 @@ class SummaryIndex:
) -> None:
if is_preview:
with session_factory.create_session() as session:
dataset = session.query(Dataset).filter_by(id=dataset_id).first()
dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1))
if not dataset or dataset.indexing_technique != IndexTechniqueType.HIGH_QUALITY:
return
@@ -34,32 +36,31 @@ class SummaryIndex:
if not document_id:
return
document = session.query(Document).filter_by(id=document_id).first()
document = session.scalar(select(Document).where(Document.id == document_id).limit(1))
# Skip qa_model documents
if document is None or document.doc_form == "qa_model":
return
query = session.query(DocumentSegment).filter_by(
dataset_id=dataset_id,
document_id=document_id,
status="completed",
enabled=True,
)
segments = query.all()
segments = session.scalars(
select(DocumentSegment).where(
DocumentSegment.dataset_id == dataset_id,
DocumentSegment.document_id == document_id,
DocumentSegment.status == "completed",
DocumentSegment.enabled == True,
)
).all()
segment_ids = [segment.id for segment in segments]
if not segment_ids:
return
existing_summaries = (
session.query(DocumentSegmentSummary)
.filter(
existing_summaries = session.scalars(
select(DocumentSegmentSummary).where(
DocumentSegmentSummary.chunk_id.in_(segment_ids),
DocumentSegmentSummary.dataset_id == dataset_id,
DocumentSegmentSummary.status == "completed",
)
.all()
)
).all()
completed_summary_segment_ids = {i.chunk_id for i in existing_summaries}
# Preview mode should process segments that are MISSING completed summaries
pending_segment_ids = [sid for sid in segment_ids if sid not in completed_summary_segment_ids]
@@ -73,7 +74,7 @@ class SummaryIndex:
def process_segment(segment_id: str) -> None:
"""Process a single segment in a thread with a fresh DB session."""
with session_factory.create_session() as session:
segment = session.query(DocumentSegment).filter_by(id=segment_id).first()
segment = session.scalar(select(DocumentSegment).where(DocumentSegment.id == segment_id).limit(1))
if segment is None:
return
try:

View File

@@ -258,10 +258,10 @@ class TestParentChildIndexProcessor:
session.commit.assert_called_once()
def test_clean_deletes_summaries_when_requested(self, processor: ParentChildIndexProcessor, dataset: Mock) -> None:
segment_query = Mock()
segment_query.filter.return_value.all.return_value = [SimpleNamespace(id="seg-1")]
scalars_result = Mock()
scalars_result.all.return_value = [SimpleNamespace(id="seg-1")]
session = Mock()
session.query.return_value = segment_query
session.scalars.return_value = scalars_result
session_ctx = MagicMock()
session_ctx.__enter__.return_value = session
session_ctx.__exit__.return_value = False

View File

@@ -220,10 +220,10 @@ class TestQAIndexProcessor:
self, processor: QAIndexProcessor, dataset: Mock
) -> None:
mock_segment = SimpleNamespace(id="seg-1")
mock_query = Mock()
mock_query.filter.return_value.all.return_value = [mock_segment]
scalars_result = Mock()
scalars_result.all.return_value = [mock_segment]
mock_session = Mock()
mock_session.query.return_value = mock_query
mock_session.scalars.return_value = scalars_result
session_context = MagicMock()
session_context.__enter__.return_value = mock_session
session_context.__exit__.return_value = False

View File

@@ -8,7 +8,6 @@ import pytest
from flask import Flask, current_app
from graphon.model_runtime.entities.llm_entities import LLMUsage
from graphon.model_runtime.entities.model_entities import ModelFeature
from sqlalchemy import column
from core.app.app_config.entities import (
DatasetEntity,
@@ -4039,21 +4038,9 @@ class TestDatasetRetrievalAdditionalHelpers:
def test_get_available_datasets(self, retrieval: DatasetRetrieval) -> None:
session = Mock()
subquery_query = Mock()
subquery_query.where.return_value = subquery_query
subquery_query.group_by.return_value = subquery_query
subquery_query.having.return_value = subquery_query
subquery_query.subquery.return_value = SimpleNamespace(
c=SimpleNamespace(
dataset_id=column("dataset_id"), available_document_count=column("available_document_count")
)
)
dataset_query = Mock()
dataset_query.outerjoin.return_value = dataset_query
dataset_query.where.return_value = dataset_query
dataset_query.all.return_value = [SimpleNamespace(id="d1"), None, SimpleNamespace(id="d2")]
session.query.side_effect = [subquery_query, dataset_query]
scalars_result = Mock()
scalars_result.all.return_value = [SimpleNamespace(id="d1"), None, SimpleNamespace(id="d2")]
session.scalars.return_value = scalars_result
session_ctx = MagicMock()
session_ctx.__enter__.return_value = session
@@ -4902,9 +4889,6 @@ class TestInternalHooksCoverage:
_scalars(segments),
_scalars(bindings),
]
query = Mock()
query.where.return_value = query
session.query.return_value = query
session_ctx = MagicMock()
session_ctx.__enter__.return_value = session
session_ctx.__exit__.return_value = False
@@ -4919,7 +4903,7 @@ class TestInternalHooksCoverage:
):
retrieval._on_retrieval_end(flask_app=app, documents=docs, message_id="m1", timer={"cost": 1})
query.update.assert_called_once()
session.execute.assert_called_once()
mock_trace.assert_called_once()
def test_retriever_variants(self, retrieval: DatasetRetrieval) -> None: