From e3cc4b83c878747c56d2ee95a2022abb5be2d431 Mon Sep 17 00:00:00 2001 From: Renzo <170978465+RenzoMXD@users.noreply.github.com> Date: Thu, 9 Apr 2026 00:46:36 -0500 Subject: [PATCH] refactor: migrate session.query to select API in clean dataset task (#34815) --- api/tasks/clean_dataset_task.py | 30 +++++++++++-------- .../tasks/test_clean_dataset_task.py | 27 ++++------------- 2 files changed, 23 insertions(+), 34 deletions(-) diff --git a/api/tasks/clean_dataset_task.py b/api/tasks/clean_dataset_task.py index 0d51a743ad..377d0e5cc7 100644 --- a/api/tasks/clean_dataset_task.py +++ b/api/tasks/clean_dataset_task.py @@ -112,7 +112,9 @@ def clean_dataset_task( segment_ids = [segment.id for segment in segments] for segment in segments: image_upload_file_ids = get_image_upload_file_ids(segment.content) - image_files = session.query(UploadFile).where(UploadFile.id.in_(image_upload_file_ids)).all() + image_files = session.scalars( + select(UploadFile).where(UploadFile.id.in_(image_upload_file_ids)) + ).all() for image_file in image_files: if image_file is None: continue @@ -150,20 +152,22 @@ def clean_dataset_task( ) session.execute(binding_delete_stmt) - session.query(DatasetProcessRule).where(DatasetProcessRule.dataset_id == dataset_id).delete() - session.query(DatasetQuery).where(DatasetQuery.dataset_id == dataset_id).delete() - session.query(AppDatasetJoin).where(AppDatasetJoin.dataset_id == dataset_id).delete() + session.execute(delete(DatasetProcessRule).where(DatasetProcessRule.dataset_id == dataset_id)) + session.execute(delete(DatasetQuery).where(DatasetQuery.dataset_id == dataset_id)) + session.execute(delete(AppDatasetJoin).where(AppDatasetJoin.dataset_id == dataset_id)) # delete dataset metadata - session.query(DatasetMetadata).where(DatasetMetadata.dataset_id == dataset_id).delete() - session.query(DatasetMetadataBinding).where(DatasetMetadataBinding.dataset_id == dataset_id).delete() + session.execute(delete(DatasetMetadata).where(DatasetMetadata.dataset_id == dataset_id)) + session.execute(delete(DatasetMetadataBinding).where(DatasetMetadataBinding.dataset_id == dataset_id)) # delete pipeline and workflow if pipeline_id: - session.query(Pipeline).where(Pipeline.id == pipeline_id).delete() - session.query(Workflow).where( - Workflow.tenant_id == tenant_id, - Workflow.app_id == pipeline_id, - Workflow.type == WorkflowType.RAG_PIPELINE, - ).delete() + session.execute(delete(Pipeline).where(Pipeline.id == pipeline_id)) + session.execute( + delete(Workflow).where( + Workflow.tenant_id == tenant_id, + Workflow.app_id == pipeline_id, + Workflow.type == WorkflowType.RAG_PIPELINE, + ) + ) # delete files if documents: file_ids = [] @@ -174,7 +178,7 @@ def clean_dataset_task( if data_source_info and "upload_file_id" in data_source_info: file_id = data_source_info["upload_file_id"] file_ids.append(file_id) - files = session.query(UploadFile).where(UploadFile.id.in_(file_ids)).all() + files = session.scalars(select(UploadFile).where(UploadFile.id.in_(file_ids))).all() for file in files: storage.delete(file.key) diff --git a/api/tests/unit_tests/tasks/test_clean_dataset_task.py b/api/tests/unit_tests/tasks/test_clean_dataset_task.py index 936a10d6c5..b4332334ab 100644 --- a/api/tests/unit_tests/tasks/test_clean_dataset_task.py +++ b/api/tests/unit_tests/tasks/test_clean_dataset_task.py @@ -60,12 +60,6 @@ def mock_db_session(): cm.__exit__.return_value = None mock_sf.create_session.return_value = cm - # Setup query chain - mock_query = MagicMock() - mock_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.delete.return_value = 0 - # Setup scalars for select queries mock_session.scalars.return_value.all.return_value = [] @@ -220,11 +214,6 @@ class TestPipelineAndWorkflowDeletion: - Pipeline record is deleted - Related workflow record is deleted """ - # Arrange - mock_query = mock_db_session.session.query.return_value - mock_query.where.return_value = mock_query - mock_query.delete.return_value = 1 - # Act clean_dataset_task( dataset_id=dataset_id, @@ -236,9 +225,9 @@ class TestPipelineAndWorkflowDeletion: pipeline_id=pipeline_id, ) - # Assert - verify delete was called for pipeline-related queries - # The actual count depends on total queries, but pipeline deletion should add 2 more - assert mock_query.delete.call_count >= 7 # 5 base + 2 pipeline/workflow + # Assert - verify execute was called for delete operations + # 1 attachment JOIN query + 5 base deletes + 2 pipeline/workflow deletes = 8 + assert mock_db_session.session.execute.call_count >= 8 def test_clean_dataset_task_without_pipeline_id( self, @@ -256,11 +245,6 @@ class TestPipelineAndWorkflowDeletion: Expected behavior: - Pipeline and workflow deletion queries are not executed """ - # Arrange - mock_query = mock_db_session.session.query.return_value - mock_query.where.return_value = mock_query - mock_query.delete.return_value = 1 - # Act clean_dataset_task( dataset_id=dataset_id, @@ -272,8 +256,9 @@ class TestPipelineAndWorkflowDeletion: pipeline_id=None, ) - # Assert - verify delete was called only for base queries (5 times) - assert mock_query.delete.call_count == 5 + # Assert - verify execute was called for delete operations + # 1 attachment JOIN query + 5 base deletes = 6 + assert mock_db_session.session.execute.call_count == 6 # ============================================================================