From cf50d7c7b52449bd3c8ee40b8b0f80b3663b485e Mon Sep 17 00:00:00 2001 From: Desel72 Date: Tue, 31 Mar 2026 16:10:16 +0300 Subject: [PATCH] refactor: use sessionmaker().begin() in console app controllers (#34282) Co-authored-by: Asuka Minato --- api/controllers/console/app/app.py | 5 ++- api/controllers/console/app/app_import.py | 10 +++--- .../console/app/conversation_variables.py | 4 +-- api/controllers/console/app/workflow.py | 17 +++------- .../console/app/workflow_app_log.py | 6 ++-- .../console/app/workflow_draft_variable.py | 8 ++--- .../console/app/workflow_trigger.py | 11 +++---- .../controllers/console/app/test_app_apis.py | 33 +++++++++++++++---- 8 files changed, 51 insertions(+), 43 deletions(-) diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py index 738e77b371..ec56cd3baa 100644 --- a/api/controllers/console/app/app.py +++ b/api/controllers/console/app/app.py @@ -9,7 +9,7 @@ from graphon.enums import WorkflowExecutionStatus from graphon.file import helpers as file_helpers from pydantic import AliasChoices, BaseModel, ConfigDict, Field, computed_field, field_validator from sqlalchemy import select -from sqlalchemy.orm import Session +from sqlalchemy.orm import sessionmaker from werkzeug.exceptions import BadRequest from controllers.common.helpers import FileInfo @@ -642,7 +642,7 @@ class AppCopyApi(Resource): args = CopyAppPayload.model_validate(console_ns.payload or {}) - with Session(db.engine) as session: + with sessionmaker(db.engine, expire_on_commit=False).begin() as session: import_service = AppDslService(session) yaml_content = import_service.export_dsl(app_model=app_model, include_secret=True) result = import_service.import_app( @@ -655,7 +655,6 @@ class AppCopyApi(Resource): icon=args.icon, icon_background=args.icon_background, ) - session.commit() # Inherit web app permission from original app if result.app_id and FeatureService.get_system_features().webapp_auth.enabled: diff --git a/api/controllers/console/app/app_import.py b/api/controllers/console/app/app_import.py index fdef54ba5a..16e1fa3245 100644 --- a/api/controllers/console/app/app_import.py +++ b/api/controllers/console/app/app_import.py @@ -1,6 +1,6 @@ from flask_restx import Resource, fields, marshal_with from pydantic import BaseModel, Field -from sqlalchemy.orm import Session +from sqlalchemy.orm import sessionmaker from controllers.console.app.wraps import get_app_model from controllers.console.wraps import ( @@ -71,7 +71,7 @@ class AppImportApi(Resource): args = AppImportPayload.model_validate(console_ns.payload) # Create service with session - with Session(db.engine) as session: + with sessionmaker(db.engine).begin() as session: import_service = AppDslService(session) # Import app account = current_user @@ -87,7 +87,6 @@ class AppImportApi(Resource): icon_background=args.icon_background, app_id=args.app_id, ) - session.commit() if result.app_id and FeatureService.get_system_features().webapp_auth.enabled: # update web app setting as private EnterpriseService.WebAppAuth.update_app_access_mode(result.app_id, "private") @@ -112,12 +111,11 @@ class AppImportConfirmApi(Resource): current_user, _ = current_account_with_tenant() # Create service with session - with Session(db.engine) as session: + with sessionmaker(db.engine).begin() as session: import_service = AppDslService(session) # Confirm import account = current_user result = import_service.confirm_import(import_id=import_id, account=account) - session.commit() # Return appropriate status code based on result if result.status == ImportStatus.FAILED: @@ -134,7 +132,7 @@ class AppImportCheckDependenciesApi(Resource): @marshal_with(app_import_check_dependencies_model) @edit_permission_required def get(self, app_model: App): - with Session(db.engine) as session: + with sessionmaker(db.engine).begin() as session: import_service = AppDslService(session) result = import_service.check_dependencies(app_model=app_model) diff --git a/api/controllers/console/app/conversation_variables.py b/api/controllers/console/app/conversation_variables.py index 368a6112ba..369c26a80c 100644 --- a/api/controllers/console/app/conversation_variables.py +++ b/api/controllers/console/app/conversation_variables.py @@ -2,7 +2,7 @@ from flask import request from flask_restx import Resource, fields, marshal_with from pydantic import BaseModel, Field from sqlalchemy import select -from sqlalchemy.orm import Session +from sqlalchemy.orm import sessionmaker from controllers.console import console_ns from controllers.console.app.wraps import get_app_model @@ -69,7 +69,7 @@ class ConversationVariablesApi(Resource): page_size = 100 stmt = stmt.limit(page_size).offset((page - 1) * page_size) - with Session(db.engine) as session: + with sessionmaker(db.engine, expire_on_commit=False).begin() as session: rows = session.scalars(stmt).all() return { diff --git a/api/controllers/console/app/workflow.py b/api/controllers/console/app/workflow.py index 1f5a84c0b2..6df8f7032e 100644 --- a/api/controllers/console/app/workflow.py +++ b/api/controllers/console/app/workflow.py @@ -10,7 +10,7 @@ from graphon.file import File from graphon.graph_engine.manager import GraphEngineManager from graphon.model_runtime.utils.encoders import jsonable_encoder from pydantic import BaseModel, Field, field_validator -from sqlalchemy.orm import Session +from sqlalchemy.orm import sessionmaker from werkzeug.exceptions import BadRequest, Forbidden, InternalServerError, NotFound import services @@ -840,7 +840,7 @@ class PublishedWorkflowApi(Resource): args = PublishWorkflowPayload.model_validate(console_ns.payload or {}) workflow_service = WorkflowService() - with Session(db.engine) as session: + with sessionmaker(db.engine).begin() as session: workflow = workflow_service.publish_workflow( session=session, app_model=app_model, @@ -858,8 +858,6 @@ class PublishedWorkflowApi(Resource): workflow_created_at = TimestampField().format(workflow.created_at) - session.commit() - return { "result": "success", "created_at": workflow_created_at, @@ -982,7 +980,7 @@ class PublishedAllWorkflowApi(Resource): raise Forbidden() workflow_service = WorkflowService() - with Session(db.engine) as session: + with sessionmaker(db.engine).begin() as session: workflows, has_more = workflow_service.get_all_published_workflow( session=session, app_model=app_model, @@ -1072,7 +1070,7 @@ class WorkflowByIdApi(Resource): workflow_service = WorkflowService() # Create a session and manage the transaction - with Session(db.engine, expire_on_commit=False) as session: + with sessionmaker(db.engine, expire_on_commit=False).begin() as session: workflow = workflow_service.update_workflow( session=session, workflow_id=workflow_id, @@ -1084,9 +1082,6 @@ class WorkflowByIdApi(Resource): if not workflow: raise NotFound("Workflow not found") - # Commit the transaction in the controller - session.commit() - return workflow @setup_required @@ -1101,13 +1096,11 @@ class WorkflowByIdApi(Resource): workflow_service = WorkflowService() # Create a session and manage the transaction - with Session(db.engine) as session: + with sessionmaker(db.engine).begin() as session: try: workflow_service.delete_workflow( session=session, workflow_id=workflow_id, tenant_id=app_model.tenant_id ) - # Commit the transaction in the controller - session.commit() except WorkflowInUseError as e: abort(400, description=str(e)) except DraftWorkflowDeletionError as e: diff --git a/api/controllers/console/app/workflow_app_log.py b/api/controllers/console/app/workflow_app_log.py index f0e26c86a5..3b24c2a402 100644 --- a/api/controllers/console/app/workflow_app_log.py +++ b/api/controllers/console/app/workflow_app_log.py @@ -5,7 +5,7 @@ from flask import request from flask_restx import Resource, marshal_with from graphon.enums import WorkflowExecutionStatus from pydantic import BaseModel, Field, field_validator -from sqlalchemy.orm import Session +from sqlalchemy.orm import sessionmaker from controllers.console import console_ns from controllers.console.app.wraps import get_app_model @@ -87,7 +87,7 @@ class WorkflowAppLogApi(Resource): # get paginate workflow app logs workflow_app_service = WorkflowAppService() - with Session(db.engine) as session: + with sessionmaker(db.engine).begin() as session: workflow_app_log_pagination = workflow_app_service.get_paginate_workflow_app_logs( session=session, app_model=app_model, @@ -124,7 +124,7 @@ class WorkflowArchivedLogApi(Resource): args = WorkflowAppLogQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore workflow_app_service = WorkflowAppService() - with Session(db.engine) as session: + with sessionmaker(db.engine).begin() as session: workflow_app_log_pagination = workflow_app_service.get_paginate_workflow_archive_logs( session=session, app_model=app_model, diff --git a/api/controllers/console/app/workflow_draft_variable.py b/api/controllers/console/app/workflow_draft_variable.py index 4052897e9a..35e2df847c 100644 --- a/api/controllers/console/app/workflow_draft_variable.py +++ b/api/controllers/console/app/workflow_draft_variable.py @@ -10,7 +10,7 @@ from graphon.variables.segment_group import SegmentGroup from graphon.variables.segments import ArrayFileSegment, FileSegment, Segment from graphon.variables.types import SegmentType from pydantic import BaseModel, Field -from sqlalchemy.orm import Session +from sqlalchemy.orm import sessionmaker from controllers.console import console_ns from controllers.console.app.error import ( @@ -244,7 +244,7 @@ class WorkflowVariableCollectionApi(Resource): raise DraftWorkflowNotExist() # fetch draft workflow by app_model - with Session(bind=db.engine, expire_on_commit=False) as session: + with sessionmaker(bind=db.engine, expire_on_commit=False).begin() as session: draft_var_srv = WorkflowDraftVariableService( session=session, ) @@ -298,7 +298,7 @@ class NodeVariableCollectionApi(Resource): @marshal_with(workflow_draft_variable_list_model) def get(self, app_model: App, node_id: str): validate_node_id(node_id) - with Session(bind=db.engine, expire_on_commit=False) as session: + with sessionmaker(bind=db.engine, expire_on_commit=False).begin() as session: draft_var_srv = WorkflowDraftVariableService( session=session, ) @@ -465,7 +465,7 @@ class VariableResetApi(Resource): def _get_variable_list(app_model: App, node_id) -> WorkflowDraftVariableList: - with Session(bind=db.engine, expire_on_commit=False) as session: + with sessionmaker(bind=db.engine, expire_on_commit=False).begin() as session: draft_var_srv = WorkflowDraftVariableService( session=session, ) diff --git a/api/controllers/console/app/workflow_trigger.py b/api/controllers/console/app/workflow_trigger.py index 8236e766ae..aa37d24738 100644 --- a/api/controllers/console/app/workflow_trigger.py +++ b/api/controllers/console/app/workflow_trigger.py @@ -4,7 +4,7 @@ from flask import request from flask_restx import Resource, fields, marshal_with from pydantic import BaseModel from sqlalchemy import select -from sqlalchemy.orm import Session +from sqlalchemy.orm import sessionmaker from werkzeug.exceptions import NotFound from configs import dify_config @@ -64,7 +64,7 @@ class WebhookTriggerApi(Resource): node_id = args.node_id - with Session(db.engine) as session: + with sessionmaker(db.engine).begin() as session: # Get webhook trigger for this app and node webhook_trigger = ( session.query(WorkflowWebhookTrigger) @@ -95,7 +95,7 @@ class AppTriggersApi(Resource): assert isinstance(current_user, Account) assert current_user.current_tenant_id is not None - with Session(db.engine) as session: + with sessionmaker(db.engine).begin() as session: # Get all triggers for this app using select API triggers = ( session.execute( @@ -137,7 +137,7 @@ class AppTriggerEnableApi(Resource): assert current_user.current_tenant_id is not None trigger_id = args.trigger_id - with Session(db.engine) as session: + with sessionmaker(db.engine, expire_on_commit=False).begin() as session: # Find the trigger using select trigger = session.execute( select(AppTrigger).where( @@ -153,9 +153,6 @@ class AppTriggerEnableApi(Resource): # Update status based on enable_trigger boolean trigger.status = AppTriggerStatus.ENABLED if args.enable_trigger else AppTriggerStatus.DISABLED - session.commit() - session.refresh(trigger) - # Add computed icon field url_prefix = dify_config.CONSOLE_API_URL + "/console/api/workspaces/current/tool-provider/builtin/" if trigger.trigger_type == "trigger-plugin": diff --git a/api/tests/test_containers_integration_tests/controllers/console/app/test_app_apis.py b/api/tests/test_containers_integration_tests/controllers/console/app/test_app_apis.py index fbaec069bb..0841217fcf 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/app/test_app_apis.py +++ b/api/tests/test_containers_integration_tests/controllers/console/app/test_app_apis.py @@ -383,14 +383,21 @@ class TestWorkflowAppLogEndpoints: monkeypatch.setattr(workflow_app_log_module, "db", SimpleNamespace(engine=MagicMock())) - class DummySession: + class DummySessionCtx: def __enter__(self): return "session" def __exit__(self, exc_type, exc, tb): return False - monkeypatch.setattr(workflow_app_log_module, "Session", lambda *args, **kwargs: DummySession()) + class DummySessionMaker: + def __init__(self, *args, **kwargs): + pass + + def begin(self): + return DummySessionCtx() + + monkeypatch.setattr(workflow_app_log_module, "sessionmaker", DummySessionMaker) def fake_get_paginate(self, **_kwargs): return {"items": [], "total": 0} @@ -423,13 +430,20 @@ class TestWorkflowDraftVariableEndpoints: monkeypatch.setattr(workflow_draft_variable_module, "db", SimpleNamespace(engine=MagicMock())) monkeypatch.setattr(workflow_draft_variable_module, "current_user", SimpleNamespace(id="user-1")) - class DummySession: + class DummySessionCtx: def __enter__(self): return "session" def __exit__(self, exc_type, exc, tb): return False + class DummySessionMaker: + def __init__(self, *args, **kwargs): + pass + + def begin(self): + return DummySessionCtx() + class DummyDraftService: def __init__(self, session): self.session = session @@ -437,7 +451,7 @@ class TestWorkflowDraftVariableEndpoints: def list_variables_without_values(self, **_kwargs): return {"items": [], "total": 0} - monkeypatch.setattr(workflow_draft_variable_module, "Session", lambda *args, **kwargs: DummySession()) + monkeypatch.setattr(workflow_draft_variable_module, "sessionmaker", DummySessionMaker) class DummyWorkflowService: def is_workflow_exist(self, *args, **kwargs): @@ -543,14 +557,21 @@ class TestWorkflowTriggerEndpoints: session = MagicMock() session.query.return_value.where.return_value.first.return_value = trigger - class DummySession: + class DummySessionCtx: def __enter__(self): return session def __exit__(self, exc_type, exc, tb): return False - monkeypatch.setattr(workflow_trigger_module, "Session", lambda *_args, **_kwargs: DummySession()) + class DummySessionMaker: + def __init__(self, *args, **kwargs): + pass + + def begin(self): + return DummySessionCtx() + + monkeypatch.setattr(workflow_trigger_module, "sessionmaker", DummySessionMaker) with app.test_request_context("/?node_id=node-1"): result = method(app_model=SimpleNamespace(id="app-1"))