mirror of
https://github.com/langgenius/dify.git
synced 2026-04-10 12:00:26 -04:00
refactor(api): use sessionmaker in builtin tools manage service (#34812)
This commit is contained in:
@@ -5,7 +5,7 @@ from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import exists, select
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from configs import dify_config
|
||||
from constants import HIDDEN_VALUE, UNKNOWN_VALUE
|
||||
@@ -46,13 +46,12 @@ class BuiltinToolManageService:
|
||||
delete custom oauth client params
|
||||
"""
|
||||
tool_provider = ToolProviderID(provider)
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
session.query(ToolOAuthTenantClient).filter_by(
|
||||
tenant_id=tenant_id,
|
||||
provider=tool_provider.provider_name,
|
||||
plugin_id=tool_provider.plugin_id,
|
||||
).delete()
|
||||
session.commit()
|
||||
return {"result": "success"}
|
||||
|
||||
@staticmethod
|
||||
@@ -150,7 +149,7 @@ class BuiltinToolManageService:
|
||||
"""
|
||||
update builtin tool provider
|
||||
"""
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
# get if the provider exists
|
||||
db_provider = (
|
||||
session.query(BuiltinToolProvider)
|
||||
@@ -203,9 +202,7 @@ class BuiltinToolManageService:
|
||||
|
||||
db_provider.name = name
|
||||
|
||||
session.commit()
|
||||
except Exception as e:
|
||||
session.rollback()
|
||||
raise ValueError(str(e))
|
||||
return {"result": "success"}
|
||||
|
||||
@@ -222,7 +219,7 @@ class BuiltinToolManageService:
|
||||
"""
|
||||
add builtin tool provider
|
||||
"""
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
try:
|
||||
lock = f"builtin_tool_provider_create_lock:{tenant_id}_{provider}"
|
||||
with redis_client.lock(lock, timeout=20):
|
||||
@@ -281,9 +278,7 @@ class BuiltinToolManageService:
|
||||
)
|
||||
|
||||
session.add(db_provider)
|
||||
session.commit()
|
||||
except Exception as e:
|
||||
session.rollback()
|
||||
raise ValueError(str(e))
|
||||
|
||||
return {"result": "success"}
|
||||
@@ -379,7 +374,7 @@ class BuiltinToolManageService:
|
||||
"""
|
||||
delete tool provider
|
||||
"""
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
db_provider = (
|
||||
session.query(BuiltinToolProvider)
|
||||
.where(
|
||||
@@ -393,7 +388,6 @@ class BuiltinToolManageService:
|
||||
raise ValueError(f"you have not added provider {provider}")
|
||||
|
||||
session.delete(db_provider)
|
||||
session.commit()
|
||||
|
||||
# delete cache
|
||||
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
|
||||
@@ -409,7 +403,7 @@ class BuiltinToolManageService:
|
||||
"""
|
||||
set default provider
|
||||
"""
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
# get provider
|
||||
target_provider = session.query(BuiltinToolProvider).filter_by(id=id, tenant_id=tenant_id).first()
|
||||
if target_provider is None:
|
||||
@@ -422,7 +416,6 @@ class BuiltinToolManageService:
|
||||
|
||||
# set new default provider
|
||||
target_provider.is_default = True
|
||||
session.commit()
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
@@ -654,7 +647,7 @@ class BuiltinToolManageService:
|
||||
if not isinstance(provider_controller, (BuiltinToolProviderController, PluginToolProviderController)):
|
||||
raise ValueError(f"Provider {provider} is not a builtin or plugin provider")
|
||||
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
custom_client_params = (
|
||||
session.query(ToolOAuthTenantClient)
|
||||
.filter_by(
|
||||
@@ -690,7 +683,6 @@ class BuiltinToolManageService:
|
||||
if enable_oauth_custom_client is not None:
|
||||
custom_client_params.enabled = enable_oauth_custom_client
|
||||
|
||||
session.commit()
|
||||
return {"result": "success"}
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -15,17 +15,24 @@ def _mock_session(mock_session_cls):
|
||||
return session
|
||||
|
||||
|
||||
def _mock_sessionmaker(mock_sm_cls):
|
||||
"""Helper: set up a sessionmaker().begin() context manager mock and return the inner session."""
|
||||
session = MagicMock()
|
||||
mock_sm_cls.return_value.begin.return_value.__enter__ = MagicMock(return_value=session)
|
||||
mock_sm_cls.return_value.begin.return_value.__exit__ = MagicMock(return_value=False)
|
||||
return session
|
||||
|
||||
|
||||
class TestDeleteCustomOauthClientParams:
|
||||
@patch(f"{MODULE}.Session")
|
||||
@patch(f"{MODULE}.sessionmaker")
|
||||
@patch(f"{MODULE}.db")
|
||||
def test_deletes_and_returns_success(self, mock_db, mock_session_cls):
|
||||
session = _mock_session(mock_session_cls)
|
||||
def test_deletes_and_returns_success(self, mock_db, mock_sm_cls):
|
||||
session = _mock_sessionmaker(mock_sm_cls)
|
||||
|
||||
result = BuiltinToolManageService.delete_custom_oauth_client_params("tenant-1", "google")
|
||||
|
||||
assert result == {"result": "success"}
|
||||
session.query.return_value.filter_by.return_value.delete.assert_called_once()
|
||||
session.commit.assert_called_once()
|
||||
|
||||
|
||||
class TestListBuiltinToolProviderTools:
|
||||
@@ -138,10 +145,10 @@ class TestIsOauthCustomClientEnabled:
|
||||
class TestDeleteBuiltinToolProvider:
|
||||
@patch(f"{MODULE}.BuiltinToolManageService.create_tool_encrypter")
|
||||
@patch(f"{MODULE}.ToolManager")
|
||||
@patch(f"{MODULE}.Session")
|
||||
@patch(f"{MODULE}.sessionmaker")
|
||||
@patch(f"{MODULE}.db")
|
||||
def test_raises_when_not_found(self, mock_db, mock_session_cls, mock_tm, mock_enc):
|
||||
session = _mock_session(mock_session_cls)
|
||||
def test_raises_when_not_found(self, mock_db, mock_sm_cls, mock_tm, mock_enc):
|
||||
session = _mock_sessionmaker(mock_sm_cls)
|
||||
session.query.return_value.where.return_value.first.return_value = None
|
||||
|
||||
with pytest.raises(ValueError, match="you have not added provider"):
|
||||
@@ -149,10 +156,10 @@ class TestDeleteBuiltinToolProvider:
|
||||
|
||||
@patch(f"{MODULE}.BuiltinToolManageService.create_tool_encrypter")
|
||||
@patch(f"{MODULE}.ToolManager")
|
||||
@patch(f"{MODULE}.Session")
|
||||
@patch(f"{MODULE}.sessionmaker")
|
||||
@patch(f"{MODULE}.db")
|
||||
def test_deletes_provider_and_clears_cache(self, mock_db, mock_session_cls, mock_tm, mock_enc):
|
||||
session = _mock_session(mock_session_cls)
|
||||
def test_deletes_provider_and_clears_cache(self, mock_db, mock_sm_cls, mock_tm, mock_enc):
|
||||
session = _mock_sessionmaker(mock_sm_cls)
|
||||
db_provider = MagicMock()
|
||||
session.query.return_value.where.return_value.first.return_value = db_provider
|
||||
mock_cache = MagicMock()
|
||||
@@ -162,24 +169,23 @@ class TestDeleteBuiltinToolProvider:
|
||||
|
||||
assert result == {"result": "success"}
|
||||
session.delete.assert_called_once_with(db_provider)
|
||||
session.commit.assert_called_once()
|
||||
mock_cache.delete.assert_called_once()
|
||||
|
||||
|
||||
class TestSetDefaultProvider:
|
||||
@patch(f"{MODULE}.Session")
|
||||
@patch(f"{MODULE}.sessionmaker")
|
||||
@patch(f"{MODULE}.db")
|
||||
def test_raises_when_not_found(self, mock_db, mock_session_cls):
|
||||
session = _mock_session(mock_session_cls)
|
||||
def test_raises_when_not_found(self, mock_db, mock_sm_cls):
|
||||
session = _mock_sessionmaker(mock_sm_cls)
|
||||
session.query.return_value.filter_by.return_value.first.return_value = None
|
||||
|
||||
with pytest.raises(ValueError, match="provider not found"):
|
||||
BuiltinToolManageService.set_default_provider("t", "u", "p", "id")
|
||||
|
||||
@patch(f"{MODULE}.Session")
|
||||
@patch(f"{MODULE}.sessionmaker")
|
||||
@patch(f"{MODULE}.db")
|
||||
def test_sets_default_and_clears_old(self, mock_db, mock_session_cls):
|
||||
session = _mock_session(mock_session_cls)
|
||||
def test_sets_default_and_clears_old(self, mock_db, mock_sm_cls):
|
||||
session = _mock_sessionmaker(mock_sm_cls)
|
||||
target = MagicMock()
|
||||
session.query.return_value.filter_by.return_value.first.return_value = target
|
||||
|
||||
@@ -187,14 +193,13 @@ class TestSetDefaultProvider:
|
||||
|
||||
assert result == {"result": "success"}
|
||||
assert target.is_default is True
|
||||
session.commit.assert_called_once()
|
||||
|
||||
|
||||
class TestUpdateBuiltinToolProvider:
|
||||
@patch(f"{MODULE}.Session")
|
||||
@patch(f"{MODULE}.sessionmaker")
|
||||
@patch(f"{MODULE}.db")
|
||||
def test_raises_when_provider_not_exists(self, mock_db, mock_session_cls):
|
||||
session = _mock_session(mock_session_cls)
|
||||
def test_raises_when_provider_not_exists(self, mock_db, mock_sm_cls):
|
||||
session = _mock_sessionmaker(mock_sm_cls)
|
||||
session.query.return_value.where.return_value.first.return_value = None
|
||||
|
||||
with pytest.raises(ValueError, match="you have not added provider"):
|
||||
@@ -203,10 +208,10 @@ class TestUpdateBuiltinToolProvider:
|
||||
@patch(f"{MODULE}.BuiltinToolManageService.create_tool_encrypter")
|
||||
@patch(f"{MODULE}.CredentialType")
|
||||
@patch(f"{MODULE}.ToolManager")
|
||||
@patch(f"{MODULE}.Session")
|
||||
@patch(f"{MODULE}.sessionmaker")
|
||||
@patch(f"{MODULE}.db")
|
||||
def test_updates_credentials_and_commits(self, mock_db, mock_session_cls, mock_tm, mock_cred_type, mock_enc):
|
||||
session = _mock_session(mock_session_cls)
|
||||
def test_updates_credentials_and_commits(self, mock_db, mock_sm_cls, mock_tm, mock_cred_type, mock_enc):
|
||||
session = _mock_sessionmaker(mock_sm_cls)
|
||||
db_provider = MagicMock(credential_type="api_key", credentials="{}")
|
||||
session.query.return_value.where.return_value.first.return_value = db_provider
|
||||
|
||||
@@ -227,7 +232,6 @@ class TestUpdateBuiltinToolProvider:
|
||||
result = BuiltinToolManageService.update_builtin_tool_provider("u", "t", "p", "c", credentials={"key": "val"})
|
||||
|
||||
assert result == {"result": "success"}
|
||||
session.commit.assert_called_once()
|
||||
mock_cache.delete.assert_called_once()
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user