From 66e588c8caaa7e89fbac3baa197d754d9d628db9 Mon Sep 17 00:00:00 2001 From: carlos4s <71615127+carlos4s@users.noreply.github.com> Date: Thu, 9 Apr 2026 00:58:38 -0500 Subject: [PATCH] refactor(api): use sessionmaker in builtin tools manage service (#34812) --- .../tools/builtin_tools_manage_service.py | 22 +++----- .../test_builtin_tools_manage_service.py | 54 ++++++++++--------- 2 files changed, 36 insertions(+), 40 deletions(-) diff --git a/api/services/tools/builtin_tools_manage_service.py b/api/services/tools/builtin_tools_manage_service.py index d529d2f065..3daaf9a263 100644 --- a/api/services/tools/builtin_tools_manage_service.py +++ b/api/services/tools/builtin_tools_manage_service.py @@ -5,7 +5,7 @@ from pathlib import Path from typing import Any from sqlalchemy import exists, select -from sqlalchemy.orm import Session +from sqlalchemy.orm import Session, sessionmaker from configs import dify_config from constants import HIDDEN_VALUE, UNKNOWN_VALUE @@ -46,13 +46,12 @@ class BuiltinToolManageService: delete custom oauth client params """ tool_provider = ToolProviderID(provider) - with Session(db.engine) as session: + with sessionmaker(bind=db.engine).begin() as session: session.query(ToolOAuthTenantClient).filter_by( tenant_id=tenant_id, provider=tool_provider.provider_name, plugin_id=tool_provider.plugin_id, ).delete() - session.commit() return {"result": "success"} @staticmethod @@ -150,7 +149,7 @@ class BuiltinToolManageService: """ update builtin tool provider """ - with Session(db.engine) as session: + with sessionmaker(bind=db.engine).begin() as session: # get if the provider exists db_provider = ( session.query(BuiltinToolProvider) @@ -203,9 +202,7 @@ class BuiltinToolManageService: db_provider.name = name - session.commit() except Exception as e: - session.rollback() raise ValueError(str(e)) return {"result": "success"} @@ -222,7 +219,7 @@ class BuiltinToolManageService: """ add builtin tool provider """ - with Session(db.engine) as session: + with sessionmaker(bind=db.engine).begin() as session: try: lock = f"builtin_tool_provider_create_lock:{tenant_id}_{provider}" with redis_client.lock(lock, timeout=20): @@ -281,9 +278,7 @@ class BuiltinToolManageService: ) session.add(db_provider) - session.commit() except Exception as e: - session.rollback() raise ValueError(str(e)) return {"result": "success"} @@ -379,7 +374,7 @@ class BuiltinToolManageService: """ delete tool provider """ - with Session(db.engine) as session: + with sessionmaker(bind=db.engine).begin() as session: db_provider = ( session.query(BuiltinToolProvider) .where( @@ -393,7 +388,6 @@ class BuiltinToolManageService: raise ValueError(f"you have not added provider {provider}") session.delete(db_provider) - session.commit() # delete cache provider_controller = ToolManager.get_builtin_provider(provider, tenant_id) @@ -409,7 +403,7 @@ class BuiltinToolManageService: """ set default provider """ - with Session(db.engine) as session: + with sessionmaker(bind=db.engine).begin() as session: # get provider target_provider = session.query(BuiltinToolProvider).filter_by(id=id, tenant_id=tenant_id).first() if target_provider is None: @@ -422,7 +416,6 @@ class BuiltinToolManageService: # set new default provider target_provider.is_default = True - session.commit() return {"result": "success"} @@ -654,7 +647,7 @@ class BuiltinToolManageService: if not isinstance(provider_controller, (BuiltinToolProviderController, PluginToolProviderController)): raise ValueError(f"Provider {provider} is not a builtin or plugin provider") - with Session(db.engine) as session: + with sessionmaker(bind=db.engine).begin() as session: custom_client_params = ( session.query(ToolOAuthTenantClient) .filter_by( @@ -690,7 +683,6 @@ class BuiltinToolManageService: if enable_oauth_custom_client is not None: custom_client_params.enabled = enable_oauth_custom_client - session.commit() return {"result": "success"} @staticmethod diff --git a/api/tests/unit_tests/services/tools/test_builtin_tools_manage_service.py b/api/tests/unit_tests/services/tools/test_builtin_tools_manage_service.py index 175900071b..e80c306854 100644 --- a/api/tests/unit_tests/services/tools/test_builtin_tools_manage_service.py +++ b/api/tests/unit_tests/services/tools/test_builtin_tools_manage_service.py @@ -15,17 +15,24 @@ def _mock_session(mock_session_cls): return session +def _mock_sessionmaker(mock_sm_cls): + """Helper: set up a sessionmaker().begin() context manager mock and return the inner session.""" + session = MagicMock() + mock_sm_cls.return_value.begin.return_value.__enter__ = MagicMock(return_value=session) + mock_sm_cls.return_value.begin.return_value.__exit__ = MagicMock(return_value=False) + return session + + class TestDeleteCustomOauthClientParams: - @patch(f"{MODULE}.Session") + @patch(f"{MODULE}.sessionmaker") @patch(f"{MODULE}.db") - def test_deletes_and_returns_success(self, mock_db, mock_session_cls): - session = _mock_session(mock_session_cls) + def test_deletes_and_returns_success(self, mock_db, mock_sm_cls): + session = _mock_sessionmaker(mock_sm_cls) result = BuiltinToolManageService.delete_custom_oauth_client_params("tenant-1", "google") assert result == {"result": "success"} session.query.return_value.filter_by.return_value.delete.assert_called_once() - session.commit.assert_called_once() class TestListBuiltinToolProviderTools: @@ -138,10 +145,10 @@ class TestIsOauthCustomClientEnabled: class TestDeleteBuiltinToolProvider: @patch(f"{MODULE}.BuiltinToolManageService.create_tool_encrypter") @patch(f"{MODULE}.ToolManager") - @patch(f"{MODULE}.Session") + @patch(f"{MODULE}.sessionmaker") @patch(f"{MODULE}.db") - def test_raises_when_not_found(self, mock_db, mock_session_cls, mock_tm, mock_enc): - session = _mock_session(mock_session_cls) + def test_raises_when_not_found(self, mock_db, mock_sm_cls, mock_tm, mock_enc): + session = _mock_sessionmaker(mock_sm_cls) session.query.return_value.where.return_value.first.return_value = None with pytest.raises(ValueError, match="you have not added provider"): @@ -149,10 +156,10 @@ class TestDeleteBuiltinToolProvider: @patch(f"{MODULE}.BuiltinToolManageService.create_tool_encrypter") @patch(f"{MODULE}.ToolManager") - @patch(f"{MODULE}.Session") + @patch(f"{MODULE}.sessionmaker") @patch(f"{MODULE}.db") - def test_deletes_provider_and_clears_cache(self, mock_db, mock_session_cls, mock_tm, mock_enc): - session = _mock_session(mock_session_cls) + def test_deletes_provider_and_clears_cache(self, mock_db, mock_sm_cls, mock_tm, mock_enc): + session = _mock_sessionmaker(mock_sm_cls) db_provider = MagicMock() session.query.return_value.where.return_value.first.return_value = db_provider mock_cache = MagicMock() @@ -162,24 +169,23 @@ class TestDeleteBuiltinToolProvider: assert result == {"result": "success"} session.delete.assert_called_once_with(db_provider) - session.commit.assert_called_once() mock_cache.delete.assert_called_once() class TestSetDefaultProvider: - @patch(f"{MODULE}.Session") + @patch(f"{MODULE}.sessionmaker") @patch(f"{MODULE}.db") - def test_raises_when_not_found(self, mock_db, mock_session_cls): - session = _mock_session(mock_session_cls) + def test_raises_when_not_found(self, mock_db, mock_sm_cls): + session = _mock_sessionmaker(mock_sm_cls) session.query.return_value.filter_by.return_value.first.return_value = None with pytest.raises(ValueError, match="provider not found"): BuiltinToolManageService.set_default_provider("t", "u", "p", "id") - @patch(f"{MODULE}.Session") + @patch(f"{MODULE}.sessionmaker") @patch(f"{MODULE}.db") - def test_sets_default_and_clears_old(self, mock_db, mock_session_cls): - session = _mock_session(mock_session_cls) + def test_sets_default_and_clears_old(self, mock_db, mock_sm_cls): + session = _mock_sessionmaker(mock_sm_cls) target = MagicMock() session.query.return_value.filter_by.return_value.first.return_value = target @@ -187,14 +193,13 @@ class TestSetDefaultProvider: assert result == {"result": "success"} assert target.is_default is True - session.commit.assert_called_once() class TestUpdateBuiltinToolProvider: - @patch(f"{MODULE}.Session") + @patch(f"{MODULE}.sessionmaker") @patch(f"{MODULE}.db") - def test_raises_when_provider_not_exists(self, mock_db, mock_session_cls): - session = _mock_session(mock_session_cls) + def test_raises_when_provider_not_exists(self, mock_db, mock_sm_cls): + session = _mock_sessionmaker(mock_sm_cls) session.query.return_value.where.return_value.first.return_value = None with pytest.raises(ValueError, match="you have not added provider"): @@ -203,10 +208,10 @@ class TestUpdateBuiltinToolProvider: @patch(f"{MODULE}.BuiltinToolManageService.create_tool_encrypter") @patch(f"{MODULE}.CredentialType") @patch(f"{MODULE}.ToolManager") - @patch(f"{MODULE}.Session") + @patch(f"{MODULE}.sessionmaker") @patch(f"{MODULE}.db") - def test_updates_credentials_and_commits(self, mock_db, mock_session_cls, mock_tm, mock_cred_type, mock_enc): - session = _mock_session(mock_session_cls) + def test_updates_credentials_and_commits(self, mock_db, mock_sm_cls, mock_tm, mock_cred_type, mock_enc): + session = _mock_sessionmaker(mock_sm_cls) db_provider = MagicMock(credential_type="api_key", credentials="{}") session.query.return_value.where.return_value.first.return_value = db_provider @@ -227,7 +232,6 @@ class TestUpdateBuiltinToolProvider: result = BuiltinToolManageService.update_builtin_tool_provider("u", "t", "p", "c", credentials={"key": "val"}) assert result == {"result": "success"} - session.commit.assert_called_once() mock_cache.delete.assert_called_once()