refactor: migrate session.query to select API in clean dataset task (#34815)

This commit is contained in:
Renzo
2026-04-09 00:46:36 -05:00
committed by GitHub
parent b5acc8e392
commit e3cc4b83c8
2 changed files with 23 additions and 34 deletions

View File

@@ -112,7 +112,9 @@ def clean_dataset_task(
segment_ids = [segment.id for segment in segments]
for segment in segments:
image_upload_file_ids = get_image_upload_file_ids(segment.content)
image_files = session.query(UploadFile).where(UploadFile.id.in_(image_upload_file_ids)).all()
image_files = session.scalars(
select(UploadFile).where(UploadFile.id.in_(image_upload_file_ids))
).all()
for image_file in image_files:
if image_file is None:
continue
@@ -150,20 +152,22 @@ def clean_dataset_task(
)
session.execute(binding_delete_stmt)
session.query(DatasetProcessRule).where(DatasetProcessRule.dataset_id == dataset_id).delete()
session.query(DatasetQuery).where(DatasetQuery.dataset_id == dataset_id).delete()
session.query(AppDatasetJoin).where(AppDatasetJoin.dataset_id == dataset_id).delete()
session.execute(delete(DatasetProcessRule).where(DatasetProcessRule.dataset_id == dataset_id))
session.execute(delete(DatasetQuery).where(DatasetQuery.dataset_id == dataset_id))
session.execute(delete(AppDatasetJoin).where(AppDatasetJoin.dataset_id == dataset_id))
# delete dataset metadata
session.query(DatasetMetadata).where(DatasetMetadata.dataset_id == dataset_id).delete()
session.query(DatasetMetadataBinding).where(DatasetMetadataBinding.dataset_id == dataset_id).delete()
session.execute(delete(DatasetMetadata).where(DatasetMetadata.dataset_id == dataset_id))
session.execute(delete(DatasetMetadataBinding).where(DatasetMetadataBinding.dataset_id == dataset_id))
# delete pipeline and workflow
if pipeline_id:
session.query(Pipeline).where(Pipeline.id == pipeline_id).delete()
session.query(Workflow).where(
Workflow.tenant_id == tenant_id,
Workflow.app_id == pipeline_id,
Workflow.type == WorkflowType.RAG_PIPELINE,
).delete()
session.execute(delete(Pipeline).where(Pipeline.id == pipeline_id))
session.execute(
delete(Workflow).where(
Workflow.tenant_id == tenant_id,
Workflow.app_id == pipeline_id,
Workflow.type == WorkflowType.RAG_PIPELINE,
)
)
# delete files
if documents:
file_ids = []
@@ -174,7 +178,7 @@ def clean_dataset_task(
if data_source_info and "upload_file_id" in data_source_info:
file_id = data_source_info["upload_file_id"]
file_ids.append(file_id)
files = session.query(UploadFile).where(UploadFile.id.in_(file_ids)).all()
files = session.scalars(select(UploadFile).where(UploadFile.id.in_(file_ids))).all()
for file in files:
storage.delete(file.key)

View File

@@ -60,12 +60,6 @@ def mock_db_session():
cm.__exit__.return_value = None
mock_sf.create_session.return_value = cm
# Setup query chain
mock_query = MagicMock()
mock_session.query.return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.delete.return_value = 0
# Setup scalars for select queries
mock_session.scalars.return_value.all.return_value = []
@@ -220,11 +214,6 @@ class TestPipelineAndWorkflowDeletion:
- Pipeline record is deleted
- Related workflow record is deleted
"""
# Arrange
mock_query = mock_db_session.session.query.return_value
mock_query.where.return_value = mock_query
mock_query.delete.return_value = 1
# Act
clean_dataset_task(
dataset_id=dataset_id,
@@ -236,9 +225,9 @@ class TestPipelineAndWorkflowDeletion:
pipeline_id=pipeline_id,
)
# Assert - verify delete was called for pipeline-related queries
# The actual count depends on total queries, but pipeline deletion should add 2 more
assert mock_query.delete.call_count >= 7 # 5 base + 2 pipeline/workflow
# Assert - verify execute was called for the attachment query and the delete operations
# 1 attachment JOIN query + 5 base deletes + 2 pipeline/workflow deletes = 8
assert mock_db_session.session.execute.call_count >= 8
def test_clean_dataset_task_without_pipeline_id(
self,
@@ -256,11 +245,6 @@ class TestPipelineAndWorkflowDeletion:
Expected behavior:
- Pipeline and workflow deletion queries are not executed
"""
# Arrange
mock_query = mock_db_session.session.query.return_value
mock_query.where.return_value = mock_query
mock_query.delete.return_value = 1
# Act
clean_dataset_task(
dataset_id=dataset_id,
@@ -272,8 +256,9 @@ class TestPipelineAndWorkflowDeletion:
pipeline_id=None,
)
# Assert - verify delete was called only for base queries (5 times)
assert mock_query.delete.call_count == 5
# Assert - verify execute was called only for the attachment query and the base delete operations
# 1 attachment JOIN query + 5 base deletes = 6
assert mock_db_session.session.execute.call_count == 6
# ============================================================================