diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index b4252e1a3e..be13d40f3e 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -682,7 +682,7 @@ class ToolManager: with Session(db.engine, autoflush=False) as session: ids = [row.id for row in session.execute(sa.text(sql), {"tenant_id": tenant_id}).all()] - return session.query(BuiltinToolProvider).where(BuiltinToolProvider.id.in_(ids)).all() + return list(session.scalars(select(BuiltinToolProvider).where(BuiltinToolProvider.id.in_(ids)))) @classmethod def list_providers_from_api( diff --git a/api/services/account_service.py b/api/services/account_service.py index 1f5f81e5bd..ccc4a7c1fa 100644 --- a/api/services/account_service.py +++ b/api/services/account_service.py @@ -809,11 +809,11 @@ class AccountService: rest of the system gradually normalizes new inputs. """ with session_factory.create_session() as session: - account = session.execute(select(Account).filter_by(email=email)).scalar_one_or_none() + account = session.execute(select(Account).where(Account.email == email)).scalar_one_or_none() if account or email == email.lower(): return account - return session.execute(select(Account).filter_by(email=email.lower())).scalar_one_or_none() + return session.execute(select(Account).where(Account.email == email.lower())).scalar_one_or_none() @classmethod def get_email_code_login_data(cls, token: str) -> dict[str, Any] | None: diff --git a/api/services/plugin/plugin_migration.py b/api/services/plugin/plugin_migration.py index d6f6ee8086..43a726b100 100644 --- a/api/services/plugin/plugin_migration.py +++ b/api/services/plugin/plugin_migration.py @@ -13,6 +13,7 @@ import sqlalchemy as sa import tqdm from flask import Flask, current_app from pydantic import TypeAdapter +from sqlalchemy import func, select from sqlalchemy.orm import Session from core.agent.entities import AgentToolEntity @@ -66,7 +67,7 @@ class PluginMigration: current_time = started_at with Session(db.engine) as session: - total_tenant_count = session.query(Tenant.id).count() + total_tenant_count = session.scalar(select(func.count(Tenant.id))) or 0 click.echo(click.style(f"Total tenant count: {total_tenant_count}", fg="white")) @@ -123,9 +124,12 @@ class PluginMigration: tenant_count = 0 for test_interval in test_intervals: tenant_count = ( - session.query(Tenant.id) - .where(Tenant.created_at.between(current_time, current_time + test_interval)) - .count() + session.scalar( + select(func.count(Tenant.id)).where( + Tenant.created_at.between(current_time, current_time + test_interval) + ) + ) + or 0 ) if tenant_count <= 100: interval = test_interval @@ -147,8 +151,8 @@ class PluginMigration: batch_end = min(current_time + interval, ended_at) - rs = ( - session.query(Tenant.id) + rs = session.execute( + select(Tenant.id) .where(Tenant.created_at.between(current_time, batch_end)) .order_by(Tenant.created_at) ) @@ -235,7 +239,7 @@ class PluginMigration: Extract tool tables. """ with Session(db.engine) as session: - rs = session.query(BuiltinToolProvider).where(BuiltinToolProvider.tenant_id == tenant_id).all() + rs = session.scalars(select(BuiltinToolProvider).where(BuiltinToolProvider.tenant_id == tenant_id)).all() result = [] for row in rs: result.append(ToolProviderID(row.provider).plugin_id) @@ -249,7 +253,7 @@ class PluginMigration: """ with Session(db.engine) as session: - rs = session.query(Workflow).where(Workflow.tenant_id == tenant_id).all() + rs = session.scalars(select(Workflow).where(Workflow.tenant_id == tenant_id)).all() result = [] for row in rs: graph = row.graph_dict @@ -272,7 +276,7 @@ class PluginMigration: Extract app tables. """ with Session(db.engine) as session: - apps = session.query(App).where(App.tenant_id == tenant_id).all() + apps = session.scalars(select(App).where(App.tenant_id == tenant_id)).all() if not apps: return [] @@ -280,7 +284,7 @@ class PluginMigration: app.app_model_config_id for app in apps if app.is_agent or app.mode == AppMode.AGENT_CHAT ] - rs = session.query(AppModelConfig).where(AppModelConfig.id.in_(agent_app_model_config_ids)).all() + rs = session.scalars(select(AppModelConfig).where(AppModelConfig.id.in_(agent_app_model_config_ids))).all() result = [] for row in rs: agent_config = row.agent_mode_dict diff --git a/api/services/rag_pipeline/rag_pipeline_dsl_service.py b/api/services/rag_pipeline/rag_pipeline_dsl_service.py index c24bf3d649..65bdf43af5 100644 --- a/api/services/rag_pipeline/rag_pipeline_dsl_service.py +++ b/api/services/rag_pipeline/rag_pipeline_dsl_service.py @@ -283,7 +283,9 @@ class RagPipelineDslService: ): raise ValueError("Chunk structure is not compatible with the published pipeline") if not dataset: - datasets = self._session.query(Dataset).filter_by(tenant_id=account.current_tenant_id).all() + datasets = self._session.scalars( + select(Dataset).where(Dataset.tenant_id == account.current_tenant_id) + ).all() names = [dataset.name for dataset in datasets] generate_name = generate_incremental_name(names, name) dataset = Dataset( @@ -303,8 +305,8 @@ class RagPipelineDslService: chunk_structure=knowledge_configuration.chunk_structure, ) if knowledge_configuration.indexing_technique == IndexTechniqueType.HIGH_QUALITY: - dataset_collection_binding = ( - self._session.query(DatasetCollectionBinding) + dataset_collection_binding = self._session.scalar( + select(DatasetCollectionBinding) .where( DatasetCollectionBinding.provider_name == knowledge_configuration.embedding_model_provider, @@ -312,7 +314,7 @@ class RagPipelineDslService: DatasetCollectionBinding.type == CollectionBindingType.DATASET, ) .order_by(DatasetCollectionBinding.created_at) - .first() + .limit(1) ) if not dataset_collection_binding: @@ -440,8 +442,8 @@ class RagPipelineDslService: dataset.runtime_mode = DatasetRuntimeMode.RAG_PIPELINE dataset.chunk_structure = knowledge_configuration.chunk_structure if knowledge_configuration.indexing_technique == IndexTechniqueType.HIGH_QUALITY: - dataset_collection_binding = ( - self._session.query(DatasetCollectionBinding) + dataset_collection_binding = self._session.scalar( + select(DatasetCollectionBinding) .where( DatasetCollectionBinding.provider_name == knowledge_configuration.embedding_model_provider, @@ -449,7 +451,7 @@ class RagPipelineDslService: DatasetCollectionBinding.type == CollectionBindingType.DATASET, ) .order_by(DatasetCollectionBinding.created_at) - .first() + .limit(1) ) if not dataset_collection_binding: @@ -591,14 +593,14 @@ class RagPipelineDslService: IMPORT_INFO_REDIS_EXPIRY, CheckDependenciesPendingData(pipeline_id=pipeline.id, dependencies=dependencies).model_dump_json(), ) - workflow = ( - self._session.query(Workflow) + workflow = self._session.scalar( + select(Workflow) .where( Workflow.tenant_id == pipeline.tenant_id, Workflow.app_id == pipeline.id, Workflow.version == "draft", ) - .first() + .limit(1) ) # create draft workflow if not found @@ -665,14 +667,12 @@ class RagPipelineDslService: :param pipeline: Pipeline instance """ - workflow = ( - self._session.query(Workflow) - .where( + workflow = self._session.scalar( + select(Workflow).where( Workflow.tenant_id == pipeline.tenant_id, Workflow.app_id == pipeline.id, Workflow.version == "draft", ) - .first() ) if not workflow: raise ValueError("Missing draft workflow configuration, please check.") @@ -904,15 +904,16 @@ class RagPipelineDslService: ): if rag_pipeline_dataset_create_entity.name: # check if dataset name already exists - if ( - self._session.query(Dataset) - .filter_by(name=rag_pipeline_dataset_create_entity.name, tenant_id=tenant_id) - .first() + if self._session.scalar( + select(Dataset).where( + Dataset.name == rag_pipeline_dataset_create_entity.name, + Dataset.tenant_id == tenant_id, + ) ): raise ValueError(f"Dataset with name {rag_pipeline_dataset_create_entity.name} already exists.") else: # generate a random name as Untitled 1 2 3 ... - datasets = self._session.query(Dataset).filter_by(tenant_id=tenant_id).all() + datasets = self._session.scalars(select(Dataset).where(Dataset.tenant_id == tenant_id)).all() names = [dataset.name for dataset in datasets] rag_pipeline_dataset_create_entity.name = generate_incremental_name( names, diff --git a/api/services/workflow_draft_variable_service.py b/api/services/workflow_draft_variable_service.py index 1c1b94ae9d..8057d9b2c7 100644 --- a/api/services/workflow_draft_variable_service.py +++ b/api/services/workflow_draft_variable_service.py @@ -19,7 +19,7 @@ from graphon.variables.segments import ( ) from graphon.variables.types import SegmentType from graphon.variables.utils import dumps_with_segments -from sqlalchemy import Engine, orm, select +from sqlalchemy import Engine, delete, orm, select from sqlalchemy.dialects.mysql import insert as mysql_insert from sqlalchemy.dialects.postgresql import insert as pg_insert from sqlalchemy.orm import Session, sessionmaker @@ -222,11 +222,10 @@ class WorkflowDraftVariableService: ) def get_variable(self, variable_id: str) -> WorkflowDraftVariable | None: - return ( - self._session.query(WorkflowDraftVariable) + return self._session.scalar( + select(WorkflowDraftVariable) .options(orm.selectinload(WorkflowDraftVariable.variable_file)) .where(WorkflowDraftVariable.id == variable_id) - .first() ) def get_draft_variables_by_selectors( @@ -254,20 +253,21 @@ class WorkflowDraftVariableService: # Alternatively, a `SELECT` statement could be constructed for each selector and # combined using `UNION` to fetch all rows. # Benchmarking indicates that both approaches yield comparable performance. - query = ( - self._session.query(WorkflowDraftVariable) - .options( - orm.selectinload(WorkflowDraftVariable.variable_file).selectinload( - WorkflowDraftVariableFile.upload_file + return list( + self._session.scalars( + select(WorkflowDraftVariable) + .options( + orm.selectinload(WorkflowDraftVariable.variable_file).selectinload( + WorkflowDraftVariableFile.upload_file + ) + ) + .where( + WorkflowDraftVariable.app_id == app_id, + WorkflowDraftVariable.user_id == user_id, + or_(*ors), ) ) - .where( - WorkflowDraftVariable.app_id == app_id, - WorkflowDraftVariable.user_id == user_id, - or_(*ors), - ) ) - return query.all() def list_variables_without_values( self, app_id: str, page: int, limit: int, user_id: str @@ -277,18 +277,21 @@ class WorkflowDraftVariableService: WorkflowDraftVariable.user_id == user_id, ] total = None - query = self._session.query(WorkflowDraftVariable).where(*criteria) + base_stmt = select(WorkflowDraftVariable).where(*criteria) if page == 1: - total = query.count() - variables = ( - # Do not load the `value` field - query.options( - orm.defer(WorkflowDraftVariable.value, raiseload=True), + from sqlalchemy import func as sa_func + + total = self._session.scalar(select(sa_func.count()).select_from(base_stmt.subquery())) + variables = list( + self._session.scalars( + # Do not load the `value` field + base_stmt.options( + orm.defer(WorkflowDraftVariable.value, raiseload=True), + ) + .order_by(WorkflowDraftVariable.created_at.desc()) + .limit(limit) + .offset((page - 1) * limit) ) - .order_by(WorkflowDraftVariable.created_at.desc()) - .limit(limit) - .offset((page - 1) * limit) - .all() ) return WorkflowDraftVariableList(variables=variables, total=total) @@ -299,11 +302,13 @@ class WorkflowDraftVariableService: WorkflowDraftVariable.node_id == node_id, WorkflowDraftVariable.user_id == user_id, ] - query = self._session.query(WorkflowDraftVariable).where(*criteria) - variables = ( - query.options(orm.selectinload(WorkflowDraftVariable.variable_file)) - .order_by(WorkflowDraftVariable.created_at.desc()) - .all() + variables = list( + self._session.scalars( + select(WorkflowDraftVariable) + .options(orm.selectinload(WorkflowDraftVariable.variable_file)) + .where(*criteria) + .order_by(WorkflowDraftVariable.created_at.desc()) + ) ) return WorkflowDraftVariableList(variables=variables) @@ -326,8 +331,8 @@ class WorkflowDraftVariableService: return self._get_variable(app_id, node_id, name, user_id=user_id) def _get_variable(self, app_id: str, node_id: str, name: str, user_id: str) -> WorkflowDraftVariable | None: - return ( - self._session.query(WorkflowDraftVariable) + return self._session.scalar( + select(WorkflowDraftVariable) .options(orm.selectinload(WorkflowDraftVariable.variable_file)) .where( WorkflowDraftVariable.app_id == app_id, @@ -335,7 +340,6 @@ class WorkflowDraftVariableService: WorkflowDraftVariable.name == name, WorkflowDraftVariable.user_id == user_id, ) - .first() ) def update_variable( @@ -488,20 +492,20 @@ class WorkflowDraftVariableService: self._session.delete(variable) def delete_user_workflow_variables(self, app_id: str, user_id: str): - ( - self._session.query(WorkflowDraftVariable) + self._session.execute( + delete(WorkflowDraftVariable) .where( WorkflowDraftVariable.app_id == app_id, WorkflowDraftVariable.user_id == user_id, ) - .delete(synchronize_session=False) + .execution_options(synchronize_session=False) ) def delete_app_workflow_variables(self, app_id: str): - ( - self._session.query(WorkflowDraftVariable) + self._session.execute( + delete(WorkflowDraftVariable) .where(WorkflowDraftVariable.app_id == app_id) - .delete(synchronize_session=False) + .execution_options(synchronize_session=False) ) def delete_workflow_draft_variable_file(self, deletions: list[DraftVarFileDeletion]): @@ -540,14 +544,14 @@ class WorkflowDraftVariableService: return self._delete_node_variables(app_id, node_id, user_id=user_id) def _delete_node_variables(self, app_id: str, node_id: str, user_id: str): - ( - self._session.query(WorkflowDraftVariable) + self._session.execute( + delete(WorkflowDraftVariable) .where( WorkflowDraftVariable.app_id == app_id, WorkflowDraftVariable.node_id == node_id, WorkflowDraftVariable.user_id == user_id, ) - .delete(synchronize_session=False) + .execution_options(synchronize_session=False) ) def _get_conversation_id_from_draft_variable(self, app_id: str, user_id: str) -> str | None: @@ -588,13 +592,11 @@ class WorkflowDraftVariableService: conv_id = self._get_conversation_id_from_draft_variable(workflow.app_id, account_id) if conv_id is not None: - conversation = ( - self._session.query(Conversation) - .where( + conversation = self._session.scalar( + select(Conversation).where( Conversation.id == conv_id, Conversation.app_id == workflow.app_id, ) - .first() ) # Only return the conversation ID if it exists and is valid (has a correspond conversation record in DB). if conversation is not None: diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index c28704e83b..839b9e3319 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -1512,14 +1512,12 @@ class WorkflowService: # Don't use workflow.tool_published as it's not accurate for specific workflow versions # Check if there's a tool provider using this specific workflow version - tool_provider = ( - session.query(WorkflowToolProvider) - .where( + tool_provider = session.scalar( + select(WorkflowToolProvider).where( WorkflowToolProvider.tenant_id == workflow.tenant_id, WorkflowToolProvider.app_id == workflow.app_id, WorkflowToolProvider.version == workflow.version, ) - .first() ) if tool_provider: diff --git a/api/tests/unit_tests/core/tools/test_tool_manager.py b/api/tests/unit_tests/core/tools/test_tool_manager.py index 31b68f0b3f..9ebaa0417b 100644 --- a/api/tests/unit_tests/core/tools/test_tool_manager.py +++ b/api/tests/unit_tests/core/tools/test_tool_manager.py @@ -637,7 +637,7 @@ def test_list_default_builtin_providers_for_postgres_and_mysql(): for scheme in ("postgresql", "mysql"): session = Mock() session.execute.return_value.all.return_value = [SimpleNamespace(id="id-1"), SimpleNamespace(id="id-2")] - session.query.return_value.where.return_value.all.return_value = provider_records + session.scalars.return_value = iter(provider_records) with patch("core.tools.tool_manager.dify_config", SimpleNamespace(SQLALCHEMY_DATABASE_URI_SCHEME=scheme)): with patch("core.tools.tool_manager.db") as mock_db: diff --git a/api/tests/unit_tests/services/rag_pipeline/test_rag_pipeline_dsl_service.py b/api/tests/unit_tests/services/rag_pipeline/test_rag_pipeline_dsl_service.py index f4fdac5f9f..6813a1bf2a 100644 --- a/api/tests/unit_tests/services/rag_pipeline/test_rag_pipeline_dsl_service.py +++ b/api/tests/unit_tests/services/rag_pipeline/test_rag_pipeline_dsl_service.py @@ -247,10 +247,11 @@ workflow: dataset_mock = Mock() dataset_mock.id = "d1" mocker.patch("services.rag_pipeline.rag_pipeline_dsl_service.Dataset", return_value=dataset_mock) + mocker.patch("services.rag_pipeline.rag_pipeline_dsl_service.select", return_value=MagicMock()) session = cast(MagicMock, Mock()) service = RagPipelineDslService(session=cast(Session, session)) - session.query.return_value.filter_by.return_value.all.return_value = [] + session.scalars.return_value.all.return_value = [] account = Mock(current_tenant_id="t1") result = service.import_rag_pipeline(account=account, import_mode="yaml-content", yaml_content=yaml_content) @@ -320,6 +321,7 @@ workflow: dataset_mock.id = "d1" mocker.patch("services.rag_pipeline.rag_pipeline_dsl_service.Dataset", return_value=dataset_mock) mocker.patch("services.rag_pipeline.rag_pipeline_dsl_service.DatasetCollectionBinding", return_value=Mock(id="b1")) + mocker.patch("services.rag_pipeline.rag_pipeline_dsl_service.select", return_value=MagicMock()) service = RagPipelineDslService(session=Mock()) # Mocking self._session.scalar for the pipeline lookup @@ -406,12 +408,14 @@ def test_create_or_update_pipeline_create_new(mocker) -> None: mocker.patch("services.rag_pipeline.rag_pipeline_dsl_service.current_user", SimpleNamespace(id="u1")) mocker.patch("services.rag_pipeline.rag_pipeline_dsl_service.Workflow", return_value=Mock()) + mocker.patch("services.rag_pipeline.rag_pipeline_dsl_service.select", return_value=MagicMock()) pipeline_cls = mocker.patch("services.rag_pipeline.rag_pipeline_dsl_service.Pipeline") pipeline_instance = pipeline_cls.return_value pipeline_instance.tenant_id = "t1" pipeline_instance.id = "p1" pipeline_instance.name = "P" pipeline_instance.is_published = False + session.scalar.return_value = None result = service._create_or_update_pipeline(pipeline=None, data=data, account=account, dependencies=[]) @@ -447,8 +451,7 @@ def test_export_rag_pipeline_dsl_with_workflow(mocker) -> None: workflow.rag_pipeline_variables = [] workflow.to_dict.return_value = {"graph": {"nodes": []}} - # Mocking single .where() call - session.query.return_value.where.return_value.first.return_value = workflow + session.scalar.return_value = workflow mocker.patch( "services.rag_pipeline.rag_pipeline_dsl_service.DependenciesAnalysisService.generate_dependencies", return_value=[], @@ -550,7 +553,7 @@ def test_append_workflow_export_data_filters_credentials(mocker) -> None: ] } } - session.query.return_value.where.return_value.first.return_value = workflow + session.scalar.return_value = workflow mocker.patch( "services.rag_pipeline.rag_pipeline_dsl_service.DependenciesAnalysisService.generate_dependencies", return_value=[], @@ -568,7 +571,7 @@ def test_append_workflow_export_data_filters_credentials(mocker) -> None: def test_create_rag_pipeline_dataset_raises_when_name_conflicts(mocker) -> None: session = cast(MagicMock, Mock()) service = RagPipelineDslService(session=cast(Session, session)) - session.query.return_value.filter_by.return_value.first.return_value = Mock() + session.scalar.return_value = Mock() create_entity = RagPipelineDatasetCreateEntity( name="Existing Name", description="", @@ -584,8 +587,8 @@ def test_create_rag_pipeline_dataset_raises_when_name_conflicts(mocker) -> None: def test_create_rag_pipeline_dataset_generates_name_when_missing(mocker) -> None: session = cast(MagicMock, Mock()) service = RagPipelineDslService(session=cast(Session, session)) - session.query.return_value.filter_by.return_value.first.return_value = None - session.query.return_value.filter_by.return_value.all.return_value = [Mock(name="Untitled")] + session.scalar.return_value = None + session.scalars.return_value.all.return_value = [Mock(name="Untitled")] mocker.patch("services.rag_pipeline.rag_pipeline_dsl_service.generate_incremental_name", return_value="Untitled 2") mocker.patch("services.rag_pipeline.rag_pipeline_dsl_service.current_user", Mock(id="u1", current_tenant_id="t1")) mocker.patch.object( @@ -632,7 +635,7 @@ def test_append_workflow_export_data_encrypts_knowledge_retrieval_dataset_ids(mo ] } } - session.query.return_value.where.return_value.first.return_value = workflow + session.scalar.return_value = workflow mocker.patch.object(service, "encrypt_dataset_id", side_effect=lambda dataset_id, tenant_id: f"enc-{dataset_id}") mocker.patch( "services.rag_pipeline.rag_pipeline_dsl_service.DependenciesAnalysisService.generate_dependencies", @@ -727,7 +730,7 @@ def test_create_or_update_pipeline_decrypts_knowledge_retrieval_dataset_ids(mock }, } draft_workflow = Mock(id="wf1") - session.query.return_value.where.return_value.first.return_value = draft_workflow + session.scalar.return_value = draft_workflow mocker.patch.object(service, "decrypt_dataset_id", side_effect=["d1", None]) result = service._create_or_update_pipeline(pipeline=pipeline, data=data, account=account) @@ -743,7 +746,8 @@ def test_create_or_update_pipeline_creates_draft_when_missing(mocker) -> None: account = Mock(id="u1", current_tenant_id="t1") pipeline = Mock(id="p1", tenant_id="t1", name="N", description="D") data = {"rag_pipeline": {"name": "N2", "description": "D2"}, "workflow": {"graph": {"nodes": []}}} - session.query.return_value.where.return_value.first.return_value = None + session.scalar.return_value = None + mocker.patch("services.rag_pipeline.rag_pipeline_dsl_service.select", return_value=MagicMock()) workflow_cls = mocker.patch("services.rag_pipeline.rag_pipeline_dsl_service.Workflow") workflow_cls.return_value.id = "wf-new" @@ -817,7 +821,7 @@ def test_import_rag_pipeline_fails_for_non_string_version_type() -> None: def test_append_workflow_export_data_raises_when_draft_workflow_missing() -> None: session = cast(MagicMock, Mock()) service = RagPipelineDslService(session=cast(Session, session)) - session.query.return_value.where.return_value.first.return_value = None + session.scalar.return_value = None with pytest.raises(ValueError, match="Missing draft workflow configuration"): service._append_workflow_export_data(export_data={}, pipeline=Mock(tenant_id="t1"), include_secret=False) @@ -841,7 +845,7 @@ def test_append_workflow_export_data_keeps_secret_fields_when_include_secret_tru ] } } - session.query.return_value.where.return_value.first.return_value = workflow + session.scalar.return_value = workflow mocker.patch( "services.rag_pipeline.rag_pipeline_dsl_service.DependenciesAnalysisService.generate_dependencies", return_value=[], @@ -1003,7 +1007,8 @@ def test_import_rag_pipeline_sets_default_version_and_kind(mocker) -> None: ) dataset = Mock(id="d1") mocker.patch("services.rag_pipeline.rag_pipeline_dsl_service.Dataset", return_value=dataset) - session.query.return_value.filter_by.return_value.all.return_value = [] + mocker.patch("services.rag_pipeline.rag_pipeline_dsl_service.select", return_value=MagicMock()) + session.scalars.return_value.all.return_value = [] mocker.patch("services.rag_pipeline.rag_pipeline_dsl_service.generate_incremental_name", return_value="P") result = service.import_rag_pipeline( @@ -1061,7 +1066,7 @@ def test_append_workflow_export_data_skips_empty_node_data(mocker) -> None: workflow = Mock() workflow.graph_dict = {"nodes": []} workflow.to_dict.return_value = {"graph": {"nodes": [{"data": {}}, {}]}} - session.query.return_value.where.return_value.first.return_value = workflow + session.scalar.return_value = workflow mocker.patch( "services.rag_pipeline.rag_pipeline_dsl_service.DependenciesAnalysisService.generate_dependencies", return_value=[], @@ -1246,11 +1251,12 @@ def test_create_or_update_pipeline_saves_dependencies_to_redis(mocker) -> None: account = Mock(id="u1", current_tenant_id="t1") mocker.patch("services.rag_pipeline.rag_pipeline_dsl_service.current_user", SimpleNamespace(id="u1")) mocker.patch("services.rag_pipeline.rag_pipeline_dsl_service.Workflow", return_value=Mock(id="wf-1")) + mocker.patch("services.rag_pipeline.rag_pipeline_dsl_service.select", return_value=MagicMock()) pipeline_cls = mocker.patch("services.rag_pipeline.rag_pipeline_dsl_service.Pipeline") pipeline = pipeline_cls.return_value pipeline.tenant_id = "t1" pipeline.id = "p1" - session.query.return_value.where.return_value.first.return_value = None + session.scalar.return_value = None setex = mocker.patch("services.rag_pipeline.rag_pipeline_dsl_service.redis_client.setex") dependency = PluginDependency( type=PluginDependency.Type.Marketplace, diff --git a/api/tests/unit_tests/services/test_workflow_service.py b/api/tests/unit_tests/services/test_workflow_service.py index 76fcb19ab2..406b4fb9d0 100644 --- a/api/tests/unit_tests/services/test_workflow_service.py +++ b/api/tests/unit_tests/services/test_workflow_service.py @@ -969,8 +969,7 @@ class TestWorkflowService: # 1. Workflow exists # 2. No app is currently using it # 3. Not published as a tool - mock_session.scalar.side_effect = [mock_workflow, None] # workflow exists, no app using it - mock_session.query.return_value.where.return_value.first.return_value = None # no tool provider + mock_session.scalar.side_effect = [mock_workflow, None, None] # workflow, no app using it, no tool provider with patch("services.workflow_service.select") as mock_select: mock_stmt = MagicMock() @@ -1045,8 +1044,7 @@ class TestWorkflowService: mock_tool_provider = MagicMock() mock_session = MagicMock() - mock_session.scalar.side_effect = [mock_workflow, None] # workflow exists, no app using it - mock_session.query.return_value.where.return_value.first.return_value = mock_tool_provider + mock_session.scalar.side_effect = [mock_workflow, None, mock_tool_provider] # workflow, no app, tool provider with patch("services.workflow_service.select") as mock_select: mock_stmt = MagicMock()