refactor: migrate session.query to select API in rag pipeline task files (#34648)

This commit is contained in:
Renzo
2026-04-07 00:56:19 -05:00
committed by GitHub
parent bceb0eee9b
commit e2ecd68556
2 changed files with 10 additions and 8 deletions

View File

@@ -10,6 +10,7 @@ from typing import Any
import click
from celery import shared_task # type: ignore
from flask import current_app, g
from sqlalchemy import select
from sqlalchemy.orm import Session, sessionmaker
from configs import dify_config
@@ -118,20 +119,20 @@ def run_single_rag_pipeline_task(rag_pipeline_invoke_entity: Mapping[str, Any],
with Session(db.engine, expire_on_commit=False) as session:
# Load required entities
account = session.query(Account).where(Account.id == user_id).first()
account = session.scalar(select(Account).where(Account.id == user_id).limit(1))
if not account:
raise ValueError(f"Account {user_id} not found")
tenant = session.query(Tenant).where(Tenant.id == tenant_id).first()
tenant = session.scalar(select(Tenant).where(Tenant.id == tenant_id).limit(1))
if not tenant:
raise ValueError(f"Tenant {tenant_id} not found")
account.current_tenant = tenant
pipeline = session.query(Pipeline).where(Pipeline.id == pipeline_id).first()
pipeline = session.scalar(select(Pipeline).where(Pipeline.id == pipeline_id).limit(1))
if not pipeline:
raise ValueError(f"Pipeline {pipeline_id} not found")
workflow = session.query(Workflow).where(Workflow.id == pipeline.workflow_id).first()
workflow = session.scalar(select(Workflow).where(Workflow.id == pipeline.workflow_id).limit(1))
if not workflow:
raise ValueError(f"Workflow {pipeline.workflow_id} not found")

View File

@@ -11,6 +11,7 @@ from typing import Any
import click
from celery import group, shared_task
from flask import current_app, g
from sqlalchemy import select
from sqlalchemy.orm import Session, sessionmaker
from configs import dify_config
@@ -132,20 +133,20 @@ def run_single_rag_pipeline_task(rag_pipeline_invoke_entity: Mapping[str, Any],
with Session(db.engine) as session:
# Load required entities
account = session.query(Account).where(Account.id == user_id).first()
account = session.scalar(select(Account).where(Account.id == user_id).limit(1))
if not account:
raise ValueError(f"Account {user_id} not found")
tenant = session.query(Tenant).where(Tenant.id == tenant_id).first()
tenant = session.scalar(select(Tenant).where(Tenant.id == tenant_id).limit(1))
if not tenant:
raise ValueError(f"Tenant {tenant_id} not found")
account.current_tenant = tenant
pipeline = session.query(Pipeline).where(Pipeline.id == pipeline_id).first()
pipeline = session.scalar(select(Pipeline).where(Pipeline.id == pipeline_id).limit(1))
if not pipeline:
raise ValueError(f"Pipeline {pipeline_id} not found")
workflow = session.query(Workflow).where(Workflow.id == pipeline.workflow_id).first()
workflow = session.scalar(select(Workflow).where(Workflow.id == pipeline.workflow_id).limit(1))
if not workflow:
raise ValueError(f"Workflow {pipeline.workflow_id} not found")