mirror of
https://github.com/langgenius/dify.git
synced 2026-04-13 12:00:19 -04:00
refactor(services): migrate builtin_tools_manage_service to SQLAlchemy 2.0 select() API (#34973)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
@@ -4,7 +4,7 @@ from collections.abc import Mapping
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import exists, select
|
||||
from sqlalchemy import delete, exists, func, select, update
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from configs import dify_config
|
||||
@@ -47,11 +47,15 @@ class BuiltinToolManageService:
|
||||
"""
|
||||
tool_provider = ToolProviderID(provider)
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
session.query(ToolOAuthTenantClient).filter_by(
|
||||
tenant_id=tenant_id,
|
||||
provider=tool_provider.provider_name,
|
||||
plugin_id=tool_provider.plugin_id,
|
||||
).delete()
|
||||
session.execute(
|
||||
delete(ToolOAuthTenantClient)
|
||||
.where(
|
||||
ToolOAuthTenantClient.tenant_id == tenant_id,
|
||||
ToolOAuthTenantClient.provider == tool_provider.provider_name,
|
||||
ToolOAuthTenantClient.plugin_id == tool_provider.plugin_id,
|
||||
)
|
||||
.execution_options(synchronize_session=False)
|
||||
)
|
||||
return {"result": "success"}
|
||||
|
||||
@staticmethod
|
||||
@@ -151,13 +155,13 @@ class BuiltinToolManageService:
|
||||
"""
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
# get if the provider exists
|
||||
db_provider = (
|
||||
session.query(BuiltinToolProvider)
|
||||
db_provider = session.scalar(
|
||||
select(BuiltinToolProvider)
|
||||
.where(
|
||||
BuiltinToolProvider.tenant_id == tenant_id,
|
||||
BuiltinToolProvider.id == credential_id,
|
||||
)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
if db_provider is None:
|
||||
raise ValueError(f"you have not added provider {provider}")
|
||||
@@ -228,7 +232,13 @@ class BuiltinToolManageService:
|
||||
raise ValueError(f"provider {provider} does not need credentials")
|
||||
|
||||
provider_count = (
|
||||
session.query(BuiltinToolProvider).filter_by(tenant_id=tenant_id, provider=provider).count()
|
||||
session.scalar(
|
||||
select(func.count(BuiltinToolProvider.id)).where(
|
||||
BuiltinToolProvider.tenant_id == tenant_id,
|
||||
BuiltinToolProvider.provider == provider,
|
||||
)
|
||||
)
|
||||
or 0
|
||||
)
|
||||
|
||||
# check if the provider count is reached the limit
|
||||
@@ -304,16 +314,15 @@ class BuiltinToolManageService:
|
||||
def generate_builtin_tool_provider_name(
|
||||
session: Session, tenant_id: str, provider: str, credential_type: CredentialType
|
||||
) -> str:
|
||||
db_providers = (
|
||||
session.query(BuiltinToolProvider)
|
||||
.filter_by(
|
||||
tenant_id=tenant_id,
|
||||
provider=provider,
|
||||
credential_type=credential_type,
|
||||
db_providers = session.scalars(
|
||||
select(BuiltinToolProvider)
|
||||
.where(
|
||||
BuiltinToolProvider.tenant_id == tenant_id,
|
||||
BuiltinToolProvider.provider == provider,
|
||||
BuiltinToolProvider.credential_type == credential_type,
|
||||
)
|
||||
.order_by(BuiltinToolProvider.created_at.desc())
|
||||
.all()
|
||||
)
|
||||
).all()
|
||||
return generate_incremental_name(
|
||||
[provider.name for provider in db_providers],
|
||||
f"{credential_type.get_name()}",
|
||||
@@ -375,13 +384,13 @@ class BuiltinToolManageService:
|
||||
delete tool provider
|
||||
"""
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
db_provider = (
|
||||
session.query(BuiltinToolProvider)
|
||||
db_provider = session.scalar(
|
||||
select(BuiltinToolProvider)
|
||||
.where(
|
||||
BuiltinToolProvider.tenant_id == tenant_id,
|
||||
BuiltinToolProvider.id == credential_id,
|
||||
)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
if db_provider is None:
|
||||
@@ -405,14 +414,26 @@ class BuiltinToolManageService:
|
||||
"""
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
# get provider
|
||||
target_provider = session.query(BuiltinToolProvider).filter_by(id=id, tenant_id=tenant_id).first()
|
||||
target_provider = session.scalar(
|
||||
select(BuiltinToolProvider)
|
||||
.where(BuiltinToolProvider.id == id, BuiltinToolProvider.tenant_id == tenant_id)
|
||||
.limit(1)
|
||||
)
|
||||
if target_provider is None:
|
||||
raise ValueError("provider not found")
|
||||
|
||||
# clear default provider
|
||||
session.query(BuiltinToolProvider).filter_by(
|
||||
tenant_id=tenant_id, user_id=user_id, provider=provider, is_default=True
|
||||
).update({"is_default": False})
|
||||
session.execute(
|
||||
update(BuiltinToolProvider)
|
||||
.where(
|
||||
BuiltinToolProvider.tenant_id == tenant_id,
|
||||
BuiltinToolProvider.user_id == user_id,
|
||||
BuiltinToolProvider.provider == provider,
|
||||
BuiltinToolProvider.is_default.is_(True),
|
||||
)
|
||||
.values(is_default=False)
|
||||
.execution_options(synchronize_session=False)
|
||||
)
|
||||
|
||||
# set new default provider
|
||||
target_provider.is_default = True
|
||||
@@ -426,10 +447,13 @@ class BuiltinToolManageService:
|
||||
"""
|
||||
tool_provider = ToolProviderID(provider_name)
|
||||
with Session(db.engine, autoflush=False) as session:
|
||||
system_client: ToolOAuthSystemClient | None = (
|
||||
session.query(ToolOAuthSystemClient)
|
||||
.filter_by(plugin_id=tool_provider.plugin_id, provider=tool_provider.provider_name)
|
||||
.first()
|
||||
system_client = session.scalar(
|
||||
select(ToolOAuthSystemClient)
|
||||
.where(
|
||||
ToolOAuthSystemClient.plugin_id == tool_provider.plugin_id,
|
||||
ToolOAuthSystemClient.provider == tool_provider.provider_name,
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
return system_client is not None
|
||||
|
||||
@@ -440,15 +464,15 @@ class BuiltinToolManageService:
|
||||
"""
|
||||
tool_provider = ToolProviderID(provider)
|
||||
with Session(db.engine, autoflush=False) as session:
|
||||
user_client: ToolOAuthTenantClient | None = (
|
||||
session.query(ToolOAuthTenantClient)
|
||||
.filter_by(
|
||||
tenant_id=tenant_id,
|
||||
provider=tool_provider.provider_name,
|
||||
plugin_id=tool_provider.plugin_id,
|
||||
enabled=True,
|
||||
user_client = session.scalar(
|
||||
select(ToolOAuthTenantClient)
|
||||
.where(
|
||||
ToolOAuthTenantClient.tenant_id == tenant_id,
|
||||
ToolOAuthTenantClient.provider == tool_provider.provider_name,
|
||||
ToolOAuthTenantClient.plugin_id == tool_provider.plugin_id,
|
||||
ToolOAuthTenantClient.enabled.is_(True),
|
||||
)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
return user_client is not None and user_client.enabled
|
||||
|
||||
@@ -465,15 +489,15 @@ class BuiltinToolManageService:
|
||||
cache=NoOpProviderCredentialCache(),
|
||||
)
|
||||
with Session(db.engine, autoflush=False) as session:
|
||||
user_client: ToolOAuthTenantClient | None = (
|
||||
session.query(ToolOAuthTenantClient)
|
||||
.filter_by(
|
||||
tenant_id=tenant_id,
|
||||
provider=tool_provider.provider_name,
|
||||
plugin_id=tool_provider.plugin_id,
|
||||
enabled=True,
|
||||
user_client = session.scalar(
|
||||
select(ToolOAuthTenantClient)
|
||||
.where(
|
||||
ToolOAuthTenantClient.tenant_id == tenant_id,
|
||||
ToolOAuthTenantClient.provider == tool_provider.provider_name,
|
||||
ToolOAuthTenantClient.plugin_id == tool_provider.plugin_id,
|
||||
ToolOAuthTenantClient.enabled.is_(True),
|
||||
)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
oauth_params: Mapping[str, Any] | None = None
|
||||
if user_client:
|
||||
@@ -487,10 +511,13 @@ class BuiltinToolManageService:
|
||||
if not is_verified:
|
||||
return oauth_params
|
||||
|
||||
system_client: ToolOAuthSystemClient | None = (
|
||||
session.query(ToolOAuthSystemClient)
|
||||
.filter_by(plugin_id=tool_provider.plugin_id, provider=tool_provider.provider_name)
|
||||
.first()
|
||||
system_client = session.scalar(
|
||||
select(ToolOAuthSystemClient)
|
||||
.where(
|
||||
ToolOAuthSystemClient.plugin_id == tool_provider.plugin_id,
|
||||
ToolOAuthSystemClient.provider == tool_provider.provider_name,
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
if system_client:
|
||||
try:
|
||||
@@ -582,8 +609,8 @@ class BuiltinToolManageService:
|
||||
provider_name = provider_id_entity.provider_name
|
||||
|
||||
if provider_id_entity.organization != "langgenius":
|
||||
provider = (
|
||||
session.query(BuiltinToolProvider)
|
||||
provider = session.scalar(
|
||||
select(BuiltinToolProvider)
|
||||
.where(
|
||||
BuiltinToolProvider.tenant_id == tenant_id,
|
||||
BuiltinToolProvider.provider == full_provider_name,
|
||||
@@ -592,11 +619,11 @@ class BuiltinToolManageService:
|
||||
BuiltinToolProvider.is_default.desc(), # default=True first
|
||||
BuiltinToolProvider.created_at.asc(), # oldest first
|
||||
)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
else:
|
||||
provider = (
|
||||
session.query(BuiltinToolProvider)
|
||||
provider = session.scalar(
|
||||
select(BuiltinToolProvider)
|
||||
.where(
|
||||
BuiltinToolProvider.tenant_id == tenant_id,
|
||||
(BuiltinToolProvider.provider == provider_name)
|
||||
@@ -606,7 +633,7 @@ class BuiltinToolManageService:
|
||||
BuiltinToolProvider.is_default.desc(), # default=True first
|
||||
BuiltinToolProvider.created_at.asc(), # oldest first
|
||||
)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
if provider is None:
|
||||
@@ -616,14 +643,14 @@ class BuiltinToolManageService:
|
||||
return provider
|
||||
except Exception:
|
||||
# it's an old provider without organization
|
||||
return (
|
||||
session.query(BuiltinToolProvider)
|
||||
return session.scalar(
|
||||
select(BuiltinToolProvider)
|
||||
.where(BuiltinToolProvider.tenant_id == tenant_id, BuiltinToolProvider.provider == provider_name)
|
||||
.order_by(
|
||||
BuiltinToolProvider.is_default.desc(), # default=True first
|
||||
BuiltinToolProvider.created_at.asc(), # oldest first
|
||||
)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@@ -648,14 +675,14 @@ class BuiltinToolManageService:
|
||||
raise ValueError(f"Provider {provider} is not a builtin or plugin provider")
|
||||
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
custom_client_params = (
|
||||
session.query(ToolOAuthTenantClient)
|
||||
.filter_by(
|
||||
tenant_id=tenant_id,
|
||||
plugin_id=tool_provider.plugin_id,
|
||||
provider=tool_provider.provider_name,
|
||||
custom_client_params = session.scalar(
|
||||
select(ToolOAuthTenantClient)
|
||||
.where(
|
||||
ToolOAuthTenantClient.tenant_id == tenant_id,
|
||||
ToolOAuthTenantClient.plugin_id == tool_provider.plugin_id,
|
||||
ToolOAuthTenantClient.provider == tool_provider.provider_name,
|
||||
)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
# if the record does not exist, create a basic record
|
||||
@@ -692,14 +719,14 @@ class BuiltinToolManageService:
|
||||
"""
|
||||
with Session(db.engine) as session:
|
||||
tool_provider = ToolProviderID(provider)
|
||||
custom_oauth_client_params: ToolOAuthTenantClient | None = (
|
||||
session.query(ToolOAuthTenantClient)
|
||||
.filter_by(
|
||||
tenant_id=tenant_id,
|
||||
plugin_id=tool_provider.plugin_id,
|
||||
provider=tool_provider.provider_name,
|
||||
custom_oauth_client_params = session.scalar(
|
||||
select(ToolOAuthTenantClient)
|
||||
.where(
|
||||
ToolOAuthTenantClient.tenant_id == tenant_id,
|
||||
ToolOAuthTenantClient.plugin_id == tool_provider.plugin_id,
|
||||
ToolOAuthTenantClient.provider == tool_provider.provider_name,
|
||||
)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
if custom_oauth_client_params is None:
|
||||
return {}
|
||||
|
||||
@@ -32,7 +32,7 @@ class TestDeleteCustomOauthClientParams:
|
||||
result = BuiltinToolManageService.delete_custom_oauth_client_params("tenant-1", "google")
|
||||
|
||||
assert result == {"result": "success"}
|
||||
session.query.return_value.filter_by.return_value.delete.assert_called_once()
|
||||
session.execute.assert_called_once()
|
||||
|
||||
|
||||
class TestListBuiltinToolProviderTools:
|
||||
@@ -111,7 +111,7 @@ class TestIsOauthSystemClientExists:
|
||||
@patch(f"{MODULE}.db")
|
||||
def test_true_when_exists(self, mock_db, mock_session_cls):
|
||||
session = _mock_session(mock_session_cls)
|
||||
session.query.return_value.filter_by.return_value.first.return_value = MagicMock()
|
||||
session.scalar.return_value = MagicMock()
|
||||
|
||||
assert BuiltinToolManageService.is_oauth_system_client_exists("google") is True
|
||||
|
||||
@@ -119,7 +119,7 @@ class TestIsOauthSystemClientExists:
|
||||
@patch(f"{MODULE}.db")
|
||||
def test_false_when_missing(self, mock_db, mock_session_cls):
|
||||
session = _mock_session(mock_session_cls)
|
||||
session.query.return_value.filter_by.return_value.first.return_value = None
|
||||
session.scalar.return_value = None
|
||||
|
||||
assert BuiltinToolManageService.is_oauth_system_client_exists("google") is False
|
||||
|
||||
@@ -129,7 +129,7 @@ class TestIsOauthCustomClientEnabled:
|
||||
@patch(f"{MODULE}.db")
|
||||
def test_true_when_enabled(self, mock_db, mock_session_cls):
|
||||
session = _mock_session(mock_session_cls)
|
||||
session.query.return_value.filter_by.return_value.first.return_value = MagicMock(enabled=True)
|
||||
session.scalar.return_value = MagicMock(enabled=True)
|
||||
|
||||
assert BuiltinToolManageService.is_oauth_custom_client_enabled("t", "g") is True
|
||||
|
||||
@@ -137,7 +137,7 @@ class TestIsOauthCustomClientEnabled:
|
||||
@patch(f"{MODULE}.db")
|
||||
def test_false_when_none(self, mock_db, mock_session_cls):
|
||||
session = _mock_session(mock_session_cls)
|
||||
session.query.return_value.filter_by.return_value.first.return_value = None
|
||||
session.scalar.return_value = None
|
||||
|
||||
assert BuiltinToolManageService.is_oauth_custom_client_enabled("t", "g") is False
|
||||
|
||||
@@ -149,7 +149,7 @@ class TestDeleteBuiltinToolProvider:
|
||||
@patch(f"{MODULE}.db")
|
||||
def test_raises_when_not_found(self, mock_db, mock_sm_cls, mock_tm, mock_enc):
|
||||
session = _mock_sessionmaker(mock_sm_cls)
|
||||
session.query.return_value.where.return_value.first.return_value = None
|
||||
session.scalar.return_value = None
|
||||
|
||||
with pytest.raises(ValueError, match="you have not added provider"):
|
||||
BuiltinToolManageService.delete_builtin_tool_provider("t", "p", "id")
|
||||
@@ -161,7 +161,7 @@ class TestDeleteBuiltinToolProvider:
|
||||
def test_deletes_provider_and_clears_cache(self, mock_db, mock_sm_cls, mock_tm, mock_enc):
|
||||
session = _mock_sessionmaker(mock_sm_cls)
|
||||
db_provider = MagicMock()
|
||||
session.query.return_value.where.return_value.first.return_value = db_provider
|
||||
session.scalar.return_value = db_provider
|
||||
mock_cache = MagicMock()
|
||||
mock_enc.return_value = (MagicMock(), mock_cache)
|
||||
|
||||
@@ -177,7 +177,7 @@ class TestSetDefaultProvider:
|
||||
@patch(f"{MODULE}.db")
|
||||
def test_raises_when_not_found(self, mock_db, mock_sm_cls):
|
||||
session = _mock_sessionmaker(mock_sm_cls)
|
||||
session.query.return_value.filter_by.return_value.first.return_value = None
|
||||
session.scalar.return_value = None
|
||||
|
||||
with pytest.raises(ValueError, match="provider not found"):
|
||||
BuiltinToolManageService.set_default_provider("t", "u", "p", "id")
|
||||
@@ -187,7 +187,7 @@ class TestSetDefaultProvider:
|
||||
def test_sets_default_and_clears_old(self, mock_db, mock_sm_cls):
|
||||
session = _mock_sessionmaker(mock_sm_cls)
|
||||
target = MagicMock()
|
||||
session.query.return_value.filter_by.return_value.first.return_value = target
|
||||
session.scalar.return_value = target
|
||||
|
||||
result = BuiltinToolManageService.set_default_provider("t", "u", "p", "id")
|
||||
|
||||
@@ -200,7 +200,7 @@ class TestUpdateBuiltinToolProvider:
|
||||
@patch(f"{MODULE}.db")
|
||||
def test_raises_when_provider_not_exists(self, mock_db, mock_sm_cls):
|
||||
session = _mock_sessionmaker(mock_sm_cls)
|
||||
session.query.return_value.where.return_value.first.return_value = None
|
||||
session.scalar.return_value = None
|
||||
|
||||
with pytest.raises(ValueError, match="you have not added provider"):
|
||||
BuiltinToolManageService.update_builtin_tool_provider("u", "t", "p", "c")
|
||||
@@ -213,7 +213,7 @@ class TestUpdateBuiltinToolProvider:
|
||||
def test_updates_credentials_and_commits(self, mock_db, mock_sm_cls, mock_tm, mock_cred_type, mock_enc):
|
||||
session = _mock_sessionmaker(mock_sm_cls)
|
||||
db_provider = MagicMock(credential_type="api_key", credentials="{}")
|
||||
session.query.return_value.where.return_value.first.return_value = db_provider
|
||||
session.scalar.return_value = db_provider
|
||||
|
||||
mock_cred_instance = MagicMock()
|
||||
mock_cred_instance.is_editable.return_value = True
|
||||
@@ -274,7 +274,7 @@ class TestGetOauthClient:
|
||||
mock_create_enc.return_value = (mock_encrypter, MagicMock())
|
||||
|
||||
user_client = MagicMock(oauth_params='{"encrypted": "data"}')
|
||||
session.query.return_value.filter_by.return_value.first.return_value = user_client
|
||||
session.scalar.return_value = user_client
|
||||
|
||||
result = BuiltinToolManageService.get_oauth_client("t", "google")
|
||||
|
||||
@@ -297,10 +297,7 @@ class TestGetOauthClient:
|
||||
mock_create_enc.return_value = (MagicMock(), MagicMock())
|
||||
|
||||
system_client = MagicMock(encrypted_oauth_params="enc")
|
||||
session.query.return_value.filter_by.return_value.first.side_effect = [
|
||||
None, # user client
|
||||
system_client, # system client
|
||||
]
|
||||
session.scalar.side_effect = [None, system_client]
|
||||
|
||||
result = BuiltinToolManageService.get_oauth_client("t", "google")
|
||||
|
||||
@@ -325,7 +322,7 @@ class TestGetCustomOauthClientParams:
|
||||
@patch(f"{MODULE}.db")
|
||||
def test_returns_empty_when_none(self, mock_db, mock_session_cls):
|
||||
session = _mock_session(mock_session_cls)
|
||||
session.query.return_value.filter_by.return_value.first.return_value = None
|
||||
session.scalar.return_value = None
|
||||
|
||||
result = BuiltinToolManageService.get_custom_oauth_client_params("t", "p")
|
||||
|
||||
@@ -391,7 +388,7 @@ class TestGetBuiltinProvider:
|
||||
session = _mock_session(mock_session_cls)
|
||||
mock_prov_id.return_value.provider_name = "google"
|
||||
mock_prov_id.return_value.organization = "langgenius"
|
||||
session.query.return_value.where.return_value.order_by.return_value.first.return_value = None
|
||||
session.scalar.return_value = None
|
||||
|
||||
result = BuiltinToolManageService.get_builtin_provider("google", "t")
|
||||
|
||||
@@ -417,7 +414,7 @@ class TestGetBuiltinProvider:
|
||||
return m
|
||||
|
||||
mock_prov_id.side_effect = prov_id_side_effect
|
||||
session.query.return_value.where.return_value.order_by.return_value.first.return_value = db_provider
|
||||
session.scalar.return_value = db_provider
|
||||
|
||||
result = BuiltinToolManageService.get_builtin_provider("google", "t")
|
||||
|
||||
@@ -439,7 +436,7 @@ class TestGetBuiltinProvider:
|
||||
|
||||
mock_prov_id.side_effect = prov_id_side_effect
|
||||
db_provider = MagicMock(provider="third-party/custom/custom-tool")
|
||||
session.query.return_value.where.return_value.order_by.return_value.first.return_value = db_provider
|
||||
session.scalar.return_value = db_provider
|
||||
|
||||
result = BuiltinToolManageService.get_builtin_provider("third-party/custom/custom-tool", "t")
|
||||
|
||||
@@ -452,7 +449,7 @@ class TestGetBuiltinProvider:
|
||||
session = _mock_session(mock_session_cls)
|
||||
mock_prov_id.side_effect = Exception("parse error")
|
||||
fallback = MagicMock()
|
||||
session.query.return_value.where.return_value.order_by.return_value.first.return_value = fallback
|
||||
session.scalar.return_value = fallback
|
||||
|
||||
result = BuiltinToolManageService.get_builtin_provider("old-provider", "t")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user