From dbdbb098d5cc04e7e89880f1e2ad43bad3e7cd5f Mon Sep 17 00:00:00 2001 From: Desel72 Date: Tue, 31 Mar 2026 17:28:05 +0300 Subject: [PATCH] =?UTF-8?q?refactor:=20use=20sessionmaker().begin()=20in?= =?UTF-8?q?=20console=20workspace=20and=20misc=20co=E2=80=A6=20(#34284)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- api/controllers/console/apikey.py | 4 +-- .../console/explore/conversation.py | 4 +-- api/controllers/console/workspace/__init__.py | 4 +-- api/controllers/console/workspace/account.py | 4 +-- .../console/workspace/tool_providers.py | 26 +++++++++---------- .../console/workspace/trigger_providers.py | 5 ++-- .../console/workspace/test_tool_provider.py | 4 +-- .../workspace/test_trigger_providers.py | 8 +++--- 8 files changed, 29 insertions(+), 30 deletions(-) diff --git a/api/controllers/console/apikey.py b/api/controllers/console/apikey.py index 783cb5c444..772bb9d0f1 100644 --- a/api/controllers/console/apikey.py +++ b/api/controllers/console/apikey.py @@ -2,7 +2,7 @@ import flask_restx from flask_restx import Resource, fields, marshal_with from flask_restx._http import HTTPStatus from sqlalchemy import delete, func, select -from sqlalchemy.orm import Session +from sqlalchemy.orm import sessionmaker from werkzeug.exceptions import Forbidden from extensions.ext_database import db @@ -34,7 +34,7 @@ api_key_list_model = console_ns.model( def _get_resource(resource_id, tenant_id, resource_model): - with Session(db.engine) as session: + with sessionmaker(db.engine).begin() as session: resource = session.execute( select(resource_model).filter_by(id=resource_id, tenant_id=tenant_id) ).scalar_one_or_none() diff --git a/api/controllers/console/explore/conversation.py b/api/controllers/console/explore/conversation.py index 933c80f509..092f509f1c 100644 --- a/api/controllers/console/explore/conversation.py +++ b/api/controllers/console/explore/conversation.py @@ -2,7 +2,7 @@ from typing import Any from flask import request from pydantic import BaseModel, Field, TypeAdapter, model_validator -from sqlalchemy.orm import Session +from sqlalchemy.orm import sessionmaker from werkzeug.exceptions import NotFound from controllers.common.schema import register_schema_models @@ -74,7 +74,7 @@ class ConversationListApi(InstalledAppResource): try: if not isinstance(current_user, Account): raise ValueError("current_user must be an Account instance") - with Session(db.engine) as session: + with sessionmaker(db.engine).begin() as session: pagination = WebConversationService.pagination_by_last_id( session=session, app_model=app_model, diff --git a/api/controllers/console/workspace/__init__.py b/api/controllers/console/workspace/__init__.py index 876e2301f2..9484cc773e 100644 --- a/api/controllers/console/workspace/__init__.py +++ b/api/controllers/console/workspace/__init__.py @@ -2,7 +2,7 @@ from collections.abc import Callable from functools import wraps from typing import ParamSpec, TypeVar -from sqlalchemy.orm import Session +from sqlalchemy.orm import sessionmaker from werkzeug.exceptions import Forbidden from extensions.ext_database import db @@ -24,7 +24,7 @@ def plugin_permission_required( user = current_user tenant_id = current_tenant_id - with Session(db.engine) as session: + with sessionmaker(db.engine).begin() as session: permission = ( session.query(TenantPluginPermission) .where( diff --git a/api/controllers/console/workspace/account.py b/api/controllers/console/workspace/account.py index 6f93ff1e70..dcd4438b67 100644 --- a/api/controllers/console/workspace/account.py +++ b/api/controllers/console/workspace/account.py @@ -8,7 +8,7 @@ from flask import request from flask_restx import Resource, fields, marshal_with from pydantic import BaseModel, Field, field_validator, model_validator from sqlalchemy import select -from sqlalchemy.orm import Session +from sqlalchemy.orm import sessionmaker from configs import dify_config from constants.languages import supported_language @@ -562,7 +562,7 @@ class ChangeEmailSendEmailApi(Resource): user_email = current_user.email else: - with Session(db.engine) as session: + with sessionmaker(db.engine).begin() as session: account = AccountService.get_account_by_email_with_case_fallback(args.email, session=session) if account is None: raise AccountNotFound() diff --git a/api/controllers/console/workspace/tool_providers.py b/api/controllers/console/workspace/tool_providers.py index 80216915cd..c9956501e2 100644 --- a/api/controllers/console/workspace/tool_providers.py +++ b/api/controllers/console/workspace/tool_providers.py @@ -7,7 +7,7 @@ from flask import make_response, redirect, request, send_file from flask_restx import Resource from graphon.model_runtime.utils.encoders import jsonable_encoder from pydantic import BaseModel, Field, HttpUrl, field_validator, model_validator -from sqlalchemy.orm import Session +from sqlalchemy.orm import sessionmaker from werkzeug.exceptions import Forbidden from configs import dify_config @@ -1019,7 +1019,7 @@ class ToolProviderMCPApi(Resource): # Step 1: Get provider data for URL validation (short-lived session, no network I/O) validation_data = None - with Session(db.engine) as session: + with sessionmaker(db.engine).begin() as session: service = MCPToolManageService(session=session) validation_data = service.get_provider_for_url_validation( tenant_id=current_tenant_id, provider_id=payload.provider_id @@ -1034,7 +1034,7 @@ class ToolProviderMCPApi(Resource): ) # Step 3: Perform database update in a transaction - with Session(db.engine) as session, session.begin(): + with sessionmaker(db.engine).begin() as session: service = MCPToolManageService(session=session) service.update_provider( tenant_id=current_tenant_id, @@ -1061,7 +1061,7 @@ class ToolProviderMCPApi(Resource): payload = MCPProviderDeletePayload.model_validate(console_ns.payload or {}) _, current_tenant_id = current_account_with_tenant() - with Session(db.engine) as session, session.begin(): + with sessionmaker(db.engine).begin() as session: service = MCPToolManageService(session=session) service.delete_provider(tenant_id=current_tenant_id, provider_id=payload.provider_id) @@ -1079,7 +1079,7 @@ class ToolMCPAuthApi(Resource): provider_id = payload.provider_id _, tenant_id = current_account_with_tenant() - with Session(db.engine) as session, session.begin(): + with sessionmaker(db.engine).begin() as session: service = MCPToolManageService(session=session) db_provider = service.get_provider(provider_id=provider_id, tenant_id=tenant_id) if not db_provider: @@ -1100,7 +1100,7 @@ class ToolMCPAuthApi(Resource): sse_read_timeout=provider_entity.sse_read_timeout, ): # Update credentials in new transaction - with Session(db.engine) as session, session.begin(): + with sessionmaker(db.engine).begin() as session: service = MCPToolManageService(session=session) service.update_provider_credentials( provider_id=provider_id, @@ -1118,17 +1118,17 @@ class ToolMCPAuthApi(Resource): resource_metadata_url=e.resource_metadata_url, scope_hint=e.scope_hint, ) - with Session(db.engine) as session, session.begin(): + with sessionmaker(db.engine).begin() as session: service = MCPToolManageService(session=session) response = service.execute_auth_actions(auth_result) return response except MCPRefreshTokenError as e: - with Session(db.engine) as session, session.begin(): + with sessionmaker(db.engine).begin() as session: service = MCPToolManageService(session=session) service.clear_provider_credentials(provider_id=provider_id, tenant_id=tenant_id) raise ValueError(f"Failed to refresh token, please try to authorize again: {e}") from e except (MCPError, ValueError) as e: - with Session(db.engine) as session, session.begin(): + with sessionmaker(db.engine).begin() as session: service = MCPToolManageService(session=session) service.clear_provider_credentials(provider_id=provider_id, tenant_id=tenant_id) raise ValueError(f"Failed to connect to MCP server: {e}") from e @@ -1141,7 +1141,7 @@ class ToolMCPDetailApi(Resource): @account_initialization_required def get(self, provider_id): _, tenant_id = current_account_with_tenant() - with Session(db.engine) as session, session.begin(): + with sessionmaker(db.engine).begin() as session: service = MCPToolManageService(session=session) provider = service.get_provider(provider_id=provider_id, tenant_id=tenant_id) return jsonable_encoder(ToolTransformService.mcp_provider_to_user_provider(provider, for_list=True)) @@ -1155,7 +1155,7 @@ class ToolMCPListAllApi(Resource): def get(self): _, tenant_id = current_account_with_tenant() - with Session(db.engine) as session, session.begin(): + with sessionmaker(db.engine).begin() as session: service = MCPToolManageService(session=session) # Skip sensitive data decryption for list view to improve performance tools = service.list_providers(tenant_id=tenant_id, include_sensitive=False) @@ -1170,7 +1170,7 @@ class ToolMCPUpdateApi(Resource): @account_initialization_required def get(self, provider_id): _, tenant_id = current_account_with_tenant() - with Session(db.engine) as session, session.begin(): + with sessionmaker(db.engine).begin() as session: service = MCPToolManageService(session=session) tools = service.list_provider_tools( tenant_id=tenant_id, @@ -1188,7 +1188,7 @@ class ToolMCPCallbackApi(Resource): authorization_code = query.code # Create service instance for handle_callback - with Session(db.engine) as session, session.begin(): + with sessionmaker(db.engine).begin() as session: mcp_service = MCPToolManageService(session=session) # handle_callback now returns state data and tokens state_data, tokens = handle_callback(state_key, authorization_code) diff --git a/api/controllers/console/workspace/trigger_providers.py b/api/controllers/console/workspace/trigger_providers.py index 76d64cb97c..7a28a09861 100644 --- a/api/controllers/console/workspace/trigger_providers.py +++ b/api/controllers/console/workspace/trigger_providers.py @@ -5,7 +5,7 @@ from flask import make_response, redirect, request from flask_restx import Resource from graphon.model_runtime.utils.encoders import jsonable_encoder from pydantic import BaseModel, model_validator -from sqlalchemy.orm import Session +from sqlalchemy.orm import sessionmaker from werkzeug.exceptions import BadRequest, Forbidden from configs import dify_config @@ -375,7 +375,7 @@ class TriggerSubscriptionDeleteApi(Resource): assert user.current_tenant_id is not None try: - with Session(db.engine) as session: + with sessionmaker(db.engine).begin() as session: # Delete trigger provider subscription TriggerProviderService.delete_trigger_provider( session=session, @@ -388,7 +388,6 @@ class TriggerSubscriptionDeleteApi(Resource): tenant_id=user.current_tenant_id, subscription_id=subscription_id, ) - session.commit() return {"result": "success"} except ValueError as e: raise BadRequest(str(e)) diff --git a/api/tests/test_containers_integration_tests/controllers/console/workspace/test_tool_provider.py b/api/tests/test_containers_integration_tests/controllers/console/workspace/test_tool_provider.py index e36bd213d9..f2e7104b18 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/workspace/test_tool_provider.py +++ b/api/tests/test_containers_integration_tests/controllers/console/workspace/test_tool_provider.py @@ -69,7 +69,7 @@ def client(flask_app_with_containers): return_value=(MagicMock(id="u1"), "t1"), autospec=True, ) -@patch("controllers.console.workspace.tool_providers.Session", autospec=True) +@patch("controllers.console.workspace.tool_providers.sessionmaker", autospec=True) @patch("controllers.console.workspace.tool_providers.MCPToolManageService._reconnect_with_url", autospec=True) @pytest.mark.usefixtures("_mock_cache", "_mock_user_tenant") def test_create_mcp_provider_populates_tools(mock_reconnect, mock_session, mock_current_account_with_tenant, client): @@ -88,7 +88,7 @@ def test_create_mcp_provider_populates_tools(mock_reconnect, mock_session, mock_ create_result.id = "provider-1" svc.create_provider.return_value = create_result svc.get_provider.return_value = MagicMock(id="provider-1", tenant_id="t1") # used by reload path - mock_session.return_value.__enter__.return_value = MagicMock() + mock_session.return_value.begin.return_value.__enter__.return_value = MagicMock() # Patch MCPToolManageService constructed inside controller with patch("controllers.console.workspace.tool_providers.MCPToolManageService", return_value=svc, autospec=True): payload = { diff --git a/api/tests/test_containers_integration_tests/controllers/console/workspace/test_trigger_providers.py b/api/tests/test_containers_integration_tests/controllers/console/workspace/test_trigger_providers.py index b4d12bff62..ca8195af53 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/workspace/test_trigger_providers.py +++ b/api/tests/test_containers_integration_tests/controllers/console/workspace/test_trigger_providers.py @@ -306,14 +306,14 @@ class TestTriggerSubscriptionCrud: app.test_request_context("/"), patch("controllers.console.workspace.trigger_providers.current_user", mock_user()), patch("controllers.console.workspace.trigger_providers.db") as mock_db, - patch("controllers.console.workspace.trigger_providers.Session") as mock_session_cls, + patch("controllers.console.workspace.trigger_providers.sessionmaker") as mock_session_cls, patch("controllers.console.workspace.trigger_providers.TriggerProviderService.delete_trigger_provider"), patch( "controllers.console.workspace.trigger_providers.TriggerSubscriptionOperatorService.delete_plugin_trigger_by_subscription" ), ): mock_db.engine = MagicMock() - mock_session_cls.return_value.__enter__.return_value = mock_session + mock_session_cls.return_value.begin.return_value.__enter__.return_value = mock_session result = method(api, "sub1") @@ -327,14 +327,14 @@ class TestTriggerSubscriptionCrud: app.test_request_context("/"), patch("controllers.console.workspace.trigger_providers.current_user", mock_user()), patch("controllers.console.workspace.trigger_providers.db") as mock_db, - patch("controllers.console.workspace.trigger_providers.Session") as session_cls, + patch("controllers.console.workspace.trigger_providers.sessionmaker") as session_cls, patch( "controllers.console.workspace.trigger_providers.TriggerProviderService.delete_trigger_provider", side_effect=ValueError("bad"), ), ): mock_db.engine = MagicMock() - session_cls.return_value.__enter__.return_value = MagicMock() + session_cls.return_value.begin.return_value.__enter__.return_value = MagicMock() with pytest.raises(BadRequest): method(api, "sub1")