From 12814b55d2cfd307c5b23a975a3344021b00a77a Mon Sep 17 00:00:00 2001 From: wdeveloper16 Date: Sat, 11 Apr 2026 18:32:20 +0200 Subject: [PATCH] refactor(api): migrate core RAG layer to SQLAlchemy 2.0 select() API (#34965) --- .../rag/index_processor/index_processor.py | 40 ++++++++++--------- .../processor/parent_child_index_processor.py | 8 ++-- .../processor/qa_index_processor.py | 9 ++--- api/core/rag/retrieval/dataset_retrieval.py | 23 ++++++----- api/core/rag/summary_index/summary_index.py | 31 +++++++------- .../test_parent_child_index_processor.py | 6 +-- .../processor/test_qa_index_processor.py | 6 +-- .../rag/retrieval/test_dataset_retrieval.py | 24 ++--------- 8 files changed, 66 insertions(+), 81 deletions(-) diff --git a/api/core/rag/index_processor/index_processor.py b/api/core/rag/index_processor/index_processor.py index 813a84cbbd..aded5315bd 100644 --- a/api/core/rag/index_processor/index_processor.py +++ b/api/core/rag/index_processor/index_processor.py @@ -6,7 +6,7 @@ from collections.abc import Mapping from typing import Any from flask import current_app -from sqlalchemy import delete, func, select +from sqlalchemy import delete, func, select, update from core.db.session_factory import session_factory from core.rag.index_processor.constant.index_type import IndexTechniqueType @@ -63,11 +63,11 @@ class IndexProcessor: summary_index_setting: SummaryIndexSettingDict | None = None, ) -> IndexingResultDict: with session_factory.create_session() as session: - document = session.query(Document).filter_by(id=document_id).first() + document = session.scalar(select(Document).where(Document.id == document_id).limit(1)) if not document: raise KnowledgeIndexNodeError(f"Document {document_id} not found.") - dataset = session.query(Dataset).filter_by(id=dataset_id).first() + dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1)) if not dataset: raise KnowledgeIndexNodeError(f"Dataset {dataset_id} not found.") @@ -104,12 +104,12 @@ class IndexProcessor: document.indexing_status = "completed" document.completed_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) document.word_count = ( - session.query(func.sum(DocumentSegment.word_count)) - .where( - DocumentSegment.document_id == document_id, - DocumentSegment.dataset_id == dataset_id, + session.scalar( + select(func.sum(DocumentSegment.word_count)).where( + DocumentSegment.document_id == document_id, + DocumentSegment.dataset_id == dataset_id, + ) ) - .scalar() ) or 0 # Update need_summary based on dataset's summary_index_setting if summary_index_setting and summary_index_setting.get("enable") is True: @@ -118,15 +118,17 @@ class IndexProcessor: document.need_summary = False session.add(document) # update document segment status - session.query(DocumentSegment).where( - DocumentSegment.document_id == document_id, - DocumentSegment.dataset_id == dataset_id, - ).update( - { - DocumentSegment.status: "completed", - DocumentSegment.enabled: True, - DocumentSegment.completed_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None), - } + session.execute( + update(DocumentSegment) + .where( + DocumentSegment.document_id == document_id, + DocumentSegment.dataset_id == dataset_id, + ) + .values( + status="completed", + enabled=True, + completed_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None), + ) ) result: IndexingResultDict = { @@ -151,11 +153,11 @@ class IndexProcessor: doc_language = None with session_factory.create_session() as session: if document_id: - document = 
session.query(Document).filter_by(id=document_id).first() + document = session.scalar(select(Document).where(Document.id == document_id).limit(1)) else: document = None - dataset = session.query(Dataset).filter_by(id=dataset_id).first() + dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1)) if not dataset: raise KnowledgeIndexNodeError(f"Dataset {dataset_id} not found.") diff --git a/api/core/rag/index_processor/processor/parent_child_index_processor.py b/api/core/rag/index_processor/processor/parent_child_index_processor.py index 2db233874a..ba277d5018 100644 --- a/api/core/rag/index_processor/processor/parent_child_index_processor.py +++ b/api/core/rag/index_processor/processor/parent_child_index_processor.py @@ -159,14 +159,12 @@ class ParentChildIndexProcessor(BaseIndexProcessor): if node_ids: # Find segments by index_node_id with session_factory.create_session() as session: - segments = ( - session.query(DocumentSegment) - .filter( + segments = session.scalars( + select(DocumentSegment).where( DocumentSegment.dataset_id == dataset.id, DocumentSegment.index_node_id.in_(node_ids), ) - .all() - ) + ).all() segment_ids = [segment.id for segment in segments] if segment_ids: SummaryIndexService.delete_summaries_for_segments(dataset, segment_ids) diff --git a/api/core/rag/index_processor/processor/qa_index_processor.py b/api/core/rag/index_processor/processor/qa_index_processor.py index b0f7928092..d3f311b08e 100644 --- a/api/core/rag/index_processor/processor/qa_index_processor.py +++ b/api/core/rag/index_processor/processor/qa_index_processor.py @@ -8,6 +8,7 @@ from typing import Any, TypedDict import pandas as pd from flask import Flask, current_app +from sqlalchemy import select from werkzeug.datastructures import FileStorage from core.db.session_factory import session_factory @@ -163,14 +164,12 @@ class QAIndexProcessor(BaseIndexProcessor): if node_ids: # Find segments by index_node_id with session_factory.create_session() as session: - segments = ( - session.query(DocumentSegment) - .filter( + segments = session.scalars( + select(DocumentSegment).where( DocumentSegment.dataset_id == dataset.id, DocumentSegment.index_node_id.in_(node_ids), ) - .all() - ) + ).all() segment_ids = [segment.id for segment in segments] if segment_ids: SummaryIndexService.delete_summaries_for_segments(dataset, segment_ids) diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py index 0f3351fd68..b681ff5db1 100644 --- a/api/core/rag/retrieval/dataset_retrieval.py +++ b/api/core/rag/retrieval/dataset_retrieval.py @@ -14,7 +14,7 @@ from graphon.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMU from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool from graphon.model_runtime.entities.model_entities import ModelFeature, ModelType from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from sqlalchemy import and_, func, literal, or_, select +from sqlalchemy import and_, func, literal, or_, select, update from sqlalchemy.orm import sessionmaker from core.app.app_config.entities import ( @@ -276,8 +276,8 @@ class DatasetRetrieval: document_ids = [i.segment.document_id for i in records] with session_factory.create_session() as session: - datasets = session.query(Dataset).where(Dataset.id.in_(dataset_ids)).all() - documents = session.query(DatasetDocument).where(DatasetDocument.id.in_(document_ids)).all() + datasets = 
session.scalars(select(Dataset).where(Dataset.id.in_(dataset_ids))).all() + documents = session.scalars(select(DatasetDocument).where(DatasetDocument.id.in_(document_ids))).all() dataset_map = {i.id: i for i in datasets} document_map = {i.id: i for i in documents} @@ -971,9 +971,11 @@ class DatasetRetrieval: # Batch update hit_count for all segments if segment_ids_to_update: - session.query(DocumentSegment).where(DocumentSegment.id.in_(segment_ids_to_update)).update( - {DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, - synchronize_session=False, + session.execute( + update(DocumentSegment) + .where(DocumentSegment.id.in_(segment_ids_to_update)) + .values(hit_count=DocumentSegment.hit_count + 1) + .execution_options(synchronize_session=False) ) self._send_trace_task(message_id, documents, timer) @@ -1822,7 +1824,7 @@ class DatasetRetrieval: def _get_available_datasets(self, tenant_id: str, dataset_ids: list[str]) -> list[Dataset]: with session_factory.create_session() as session: subquery = ( - session.query(DocumentModel.dataset_id, func.count(DocumentModel.id).label("available_document_count")) + select(DocumentModel.dataset_id, func.count(DocumentModel.id).label("available_document_count")) .where( DocumentModel.indexing_status == "completed", DocumentModel.enabled == True, @@ -1834,13 +1836,12 @@ class DatasetRetrieval: .subquery() ) - results = ( - session.query(Dataset) + results = session.scalars( + select(Dataset) .outerjoin(subquery, Dataset.id == subquery.c.dataset_id) .where(Dataset.tenant_id == tenant_id, Dataset.id.in_(dataset_ids)) .where((subquery.c.available_document_count > 0) | (Dataset.provider == "external")) - .all() - ) + ).all() available_datasets = [] for dataset in results: diff --git a/api/core/rag/summary_index/summary_index.py b/api/core/rag/summary_index/summary_index.py index 6f120bd471..bff5f85dec 100644 --- a/api/core/rag/summary_index/summary_index.py +++ b/api/core/rag/summary_index/summary_index.py @@ -1,6 +1,8 @@ import concurrent.futures import logging +from sqlalchemy import select + from core.db.session_factory import session_factory from core.rag.index_processor.constant.index_type import IndexTechniqueType from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict @@ -21,7 +23,7 @@ class SummaryIndex: ) -> None: if is_preview: with session_factory.create_session() as session: - dataset = session.query(Dataset).filter_by(id=dataset_id).first() + dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1)) if not dataset or dataset.indexing_technique != IndexTechniqueType.HIGH_QUALITY: return @@ -34,32 +36,31 @@ class SummaryIndex: if not document_id: return - document = session.query(Document).filter_by(id=document_id).first() + document = session.scalar(select(Document).where(Document.id == document_id).limit(1)) # Skip qa_model documents if document is None or document.doc_form == "qa_model": return - query = session.query(DocumentSegment).filter_by( - dataset_id=dataset_id, - document_id=document_id, - status="completed", - enabled=True, - ) - segments = query.all() + segments = session.scalars( + select(DocumentSegment).where( + DocumentSegment.dataset_id == dataset_id, + DocumentSegment.document_id == document_id, + DocumentSegment.status == "completed", + DocumentSegment.enabled == True, + ) + ).all() segment_ids = [segment.id for segment in segments] if not segment_ids: return - existing_summaries = ( - session.query(DocumentSegmentSummary) - .filter( + existing_summaries = 
session.scalars( + select(DocumentSegmentSummary).where( DocumentSegmentSummary.chunk_id.in_(segment_ids), DocumentSegmentSummary.dataset_id == dataset_id, DocumentSegmentSummary.status == "completed", ) - .all() - ) + ).all() completed_summary_segment_ids = {i.chunk_id for i in existing_summaries} # Preview mode should process segments that are MISSING completed summaries pending_segment_ids = [sid for sid in segment_ids if sid not in completed_summary_segment_ids] @@ -73,7 +74,7 @@ class SummaryIndex: def process_segment(segment_id: str) -> None: """Process a single segment in a thread with a fresh DB session.""" with session_factory.create_session() as session: - segment = session.query(DocumentSegment).filter_by(id=segment_id).first() + segment = session.scalar(select(DocumentSegment).where(DocumentSegment.id == segment_id).limit(1)) if segment is None: return try: diff --git a/api/tests/unit_tests/core/rag/indexing/processor/test_parent_child_index_processor.py b/api/tests/unit_tests/core/rag/indexing/processor/test_parent_child_index_processor.py index c241b44d52..8ef0e046ef 100644 --- a/api/tests/unit_tests/core/rag/indexing/processor/test_parent_child_index_processor.py +++ b/api/tests/unit_tests/core/rag/indexing/processor/test_parent_child_index_processor.py @@ -258,10 +258,10 @@ class TestParentChildIndexProcessor: session.commit.assert_called_once() def test_clean_deletes_summaries_when_requested(self, processor: ParentChildIndexProcessor, dataset: Mock) -> None: - segment_query = Mock() - segment_query.filter.return_value.all.return_value = [SimpleNamespace(id="seg-1")] + scalars_result = Mock() + scalars_result.all.return_value = [SimpleNamespace(id="seg-1")] session = Mock() - session.query.return_value = segment_query + session.scalars.return_value = scalars_result session_ctx = MagicMock() session_ctx.__enter__.return_value = session session_ctx.__exit__.return_value = False diff --git a/api/tests/unit_tests/core/rag/indexing/processor/test_qa_index_processor.py b/api/tests/unit_tests/core/rag/indexing/processor/test_qa_index_processor.py index 98c47bec8f..b1b1835a52 100644 --- a/api/tests/unit_tests/core/rag/indexing/processor/test_qa_index_processor.py +++ b/api/tests/unit_tests/core/rag/indexing/processor/test_qa_index_processor.py @@ -220,10 +220,10 @@ class TestQAIndexProcessor: self, processor: QAIndexProcessor, dataset: Mock ) -> None: mock_segment = SimpleNamespace(id="seg-1") - mock_query = Mock() - mock_query.filter.return_value.all.return_value = [mock_segment] + scalars_result = Mock() + scalars_result.all.return_value = [mock_segment] mock_session = Mock() - mock_session.query.return_value = mock_query + mock_session.scalars.return_value = scalars_result session_context = MagicMock() session_context.__enter__.return_value = mock_session session_context.__exit__.return_value = False diff --git a/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py b/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py index b98fec3854..1b17cbc368 100644 --- a/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py +++ b/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py @@ -8,7 +8,6 @@ import pytest from flask import Flask, current_app from graphon.model_runtime.entities.llm_entities import LLMUsage from graphon.model_runtime.entities.model_entities import ModelFeature -from sqlalchemy import column from core.app.app_config.entities import ( DatasetEntity, @@ -4039,21 +4038,9 @@ class TestDatasetRetrievalAdditionalHelpers: def 
test_get_available_datasets(self, retrieval: DatasetRetrieval) -> None: session = Mock() - subquery_query = Mock() - subquery_query.where.return_value = subquery_query - subquery_query.group_by.return_value = subquery_query - subquery_query.having.return_value = subquery_query - subquery_query.subquery.return_value = SimpleNamespace( - c=SimpleNamespace( - dataset_id=column("dataset_id"), available_document_count=column("available_document_count") - ) - ) - - dataset_query = Mock() - dataset_query.outerjoin.return_value = dataset_query - dataset_query.where.return_value = dataset_query - dataset_query.all.return_value = [SimpleNamespace(id="d1"), None, SimpleNamespace(id="d2")] - session.query.side_effect = [subquery_query, dataset_query] + scalars_result = Mock() + scalars_result.all.return_value = [SimpleNamespace(id="d1"), None, SimpleNamespace(id="d2")] + session.scalars.return_value = scalars_result session_ctx = MagicMock() session_ctx.__enter__.return_value = session @@ -4902,9 +4889,6 @@ class TestInternalHooksCoverage: _scalars(segments), _scalars(bindings), ] - query = Mock() - query.where.return_value = query - session.query.return_value = query session_ctx = MagicMock() session_ctx.__enter__.return_value = session session_ctx.__exit__.return_value = False @@ -4919,7 +4903,7 @@ class TestInternalHooksCoverage: ): retrieval._on_retrieval_end(flask_app=app, documents=docs, message_id="m1", timer={"cost": 1}) - query.update.assert_called_once() + session.execute.assert_called_once() mock_trace.assert_called_once() def test_retriever_variants(self, retrieval: DatasetRetrieval) -> None:
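
-- 
Reviewer notes (appended after the signature delimiter so the patch above still applies as-is):

Every read-path hunk in this patch follows the same Query-to-select() mapping. Below is a minimal, self-contained sketch of that mapping; the `Widget` model, the in-memory SQLite URL, and all values are hypothetical, for illustration only, and are not part of the patched codebase.

    # Read-path migration pattern used throughout the patch (sketch).
    # `Widget` is a hypothetical mapped class, not a model from this repo.
    from sqlalchemy import create_engine, func, select
    from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column

    class Base(DeclarativeBase):
        pass

    class Widget(Base):
        __tablename__ = "widgets"
        id: Mapped[int] = mapped_column(primary_key=True)
        name: Mapped[str]

    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)

    with Session(engine) as session:
        # 1.x: session.query(Widget).filter_by(id=1).first()
        one = session.scalar(select(Widget).where(Widget.id == 1).limit(1))

        # 1.x: session.query(Widget).filter(Widget.name == "a").all()
        many = session.scalars(select(Widget).where(Widget.name == "a")).all()

        # 1.x: session.query(func.count(Widget.id)).scalar()
        total = session.scalar(select(func.count(Widget.id))) or 0

`session.scalar()` returns the first column of the first row, or `None` when there is no row, which is why the patch keeps the trailing `or 0` on the `func.sum` aggregate and the explicit `if not document` / `if not dataset` checks.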
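
The write-path hunks (the segment-status update in `index_processor.py` and the `hit_count` bump in `dataset_retrieval.py`) move from `Query.update()` to the statement form of bulk UPDATE. A sketch of that pattern, reusing the hypothetical `Widget`/`Session` setup from the previous snippet:

    # Write-path migration pattern (sketch).
    from sqlalchemy import update

    with Session(engine) as session:
        # 1.x: session.query(Widget).filter(Widget.id.in_([1, 2]))
        #          .update({Widget.name: "renamed"}, synchronize_session=False)
        session.execute(
            update(Widget)
            .where(Widget.id.in_([1, 2]))
            .values(name="renamed")
            .execution_options(synchronize_session=False)
        )
        session.commit()

Carrying `synchronize_session=False` over via `.execution_options()` preserves the 1.x behavior of not reconciling already-loaded objects in the identity map after the bulk UPDATE.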
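
On the test side, the mocks simply shift from stubbing `session.query(...)` chains to stubbing the 2.0 entry points, as the updated unit tests above do. A compact sketch of the shape those tests rely on; the names and values are illustrative:

    # Test-side pattern (sketch): stub scalars()/execute() instead of query().
    from unittest.mock import Mock

    session = Mock()
    scalars_result = Mock()
    scalars_result.all.return_value = ["seg-1"]
    session.scalars.return_value = scalars_result

    assert session.scalars(object()).all() == ["seg-1"]
    session.execute(object())
    session.execute.assert_called_once()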