refactor(services): migrate builtin_tools_manage_service to SQLAlchemy 2.0 select() API (#34973)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
wdeveloper16
2026-04-12 03:25:51 +02:00
committed by GitHub
parent 45561bed9d
commit 4ef67fef3a
2 changed files with 117 additions and 93 deletions

View File

@@ -4,7 +4,7 @@ from collections.abc import Mapping
from pathlib import Path
from typing import Any
from sqlalchemy import exists, select
from sqlalchemy import delete, exists, func, select, update
from sqlalchemy.orm import Session, sessionmaker
from configs import dify_config
@@ -47,11 +47,15 @@ class BuiltinToolManageService:
"""
tool_provider = ToolProviderID(provider)
with sessionmaker(bind=db.engine).begin() as session:
session.query(ToolOAuthTenantClient).filter_by(
tenant_id=tenant_id,
provider=tool_provider.provider_name,
plugin_id=tool_provider.plugin_id,
).delete()
session.execute(
delete(ToolOAuthTenantClient)
.where(
ToolOAuthTenantClient.tenant_id == tenant_id,
ToolOAuthTenantClient.provider == tool_provider.provider_name,
ToolOAuthTenantClient.plugin_id == tool_provider.plugin_id,
)
.execution_options(synchronize_session=False)
)
return {"result": "success"}
@staticmethod
@@ -151,13 +155,13 @@ class BuiltinToolManageService:
"""
with sessionmaker(bind=db.engine).begin() as session:
# get if the provider exists
db_provider = (
session.query(BuiltinToolProvider)
db_provider = session.scalar(
select(BuiltinToolProvider)
.where(
BuiltinToolProvider.tenant_id == tenant_id,
BuiltinToolProvider.id == credential_id,
)
.first()
.limit(1)
)
if db_provider is None:
raise ValueError(f"you have not added provider {provider}")
@@ -228,7 +232,13 @@ class BuiltinToolManageService:
raise ValueError(f"provider {provider} does not need credentials")
provider_count = (
session.query(BuiltinToolProvider).filter_by(tenant_id=tenant_id, provider=provider).count()
session.scalar(
select(func.count(BuiltinToolProvider.id)).where(
BuiltinToolProvider.tenant_id == tenant_id,
BuiltinToolProvider.provider == provider,
)
)
or 0
)
# check if the provider count is reached the limit
@@ -304,16 +314,15 @@ class BuiltinToolManageService:
def generate_builtin_tool_provider_name(
session: Session, tenant_id: str, provider: str, credential_type: CredentialType
) -> str:
db_providers = (
session.query(BuiltinToolProvider)
.filter_by(
tenant_id=tenant_id,
provider=provider,
credential_type=credential_type,
db_providers = session.scalars(
select(BuiltinToolProvider)
.where(
BuiltinToolProvider.tenant_id == tenant_id,
BuiltinToolProvider.provider == provider,
BuiltinToolProvider.credential_type == credential_type,
)
.order_by(BuiltinToolProvider.created_at.desc())
.all()
)
).all()
return generate_incremental_name(
[provider.name for provider in db_providers],
f"{credential_type.get_name()}",
@@ -375,13 +384,13 @@ class BuiltinToolManageService:
delete tool provider
"""
with sessionmaker(bind=db.engine).begin() as session:
db_provider = (
session.query(BuiltinToolProvider)
db_provider = session.scalar(
select(BuiltinToolProvider)
.where(
BuiltinToolProvider.tenant_id == tenant_id,
BuiltinToolProvider.id == credential_id,
)
.first()
.limit(1)
)
if db_provider is None:
@@ -405,14 +414,26 @@ class BuiltinToolManageService:
"""
with sessionmaker(bind=db.engine).begin() as session:
# get provider
target_provider = session.query(BuiltinToolProvider).filter_by(id=id, tenant_id=tenant_id).first()
target_provider = session.scalar(
select(BuiltinToolProvider)
.where(BuiltinToolProvider.id == id, BuiltinToolProvider.tenant_id == tenant_id)
.limit(1)
)
if target_provider is None:
raise ValueError("provider not found")
# clear default provider
session.query(BuiltinToolProvider).filter_by(
tenant_id=tenant_id, user_id=user_id, provider=provider, is_default=True
).update({"is_default": False})
session.execute(
update(BuiltinToolProvider)
.where(
BuiltinToolProvider.tenant_id == tenant_id,
BuiltinToolProvider.user_id == user_id,
BuiltinToolProvider.provider == provider,
BuiltinToolProvider.is_default.is_(True),
)
.values(is_default=False)
.execution_options(synchronize_session=False)
)
# set new default provider
target_provider.is_default = True
@@ -426,10 +447,13 @@ class BuiltinToolManageService:
"""
tool_provider = ToolProviderID(provider_name)
with Session(db.engine, autoflush=False) as session:
system_client: ToolOAuthSystemClient | None = (
session.query(ToolOAuthSystemClient)
.filter_by(plugin_id=tool_provider.plugin_id, provider=tool_provider.provider_name)
.first()
system_client = session.scalar(
select(ToolOAuthSystemClient)
.where(
ToolOAuthSystemClient.plugin_id == tool_provider.plugin_id,
ToolOAuthSystemClient.provider == tool_provider.provider_name,
)
.limit(1)
)
return system_client is not None
@@ -440,15 +464,15 @@ class BuiltinToolManageService:
"""
tool_provider = ToolProviderID(provider)
with Session(db.engine, autoflush=False) as session:
user_client: ToolOAuthTenantClient | None = (
session.query(ToolOAuthTenantClient)
.filter_by(
tenant_id=tenant_id,
provider=tool_provider.provider_name,
plugin_id=tool_provider.plugin_id,
enabled=True,
user_client = session.scalar(
select(ToolOAuthTenantClient)
.where(
ToolOAuthTenantClient.tenant_id == tenant_id,
ToolOAuthTenantClient.provider == tool_provider.provider_name,
ToolOAuthTenantClient.plugin_id == tool_provider.plugin_id,
ToolOAuthTenantClient.enabled.is_(True),
)
.first()
.limit(1)
)
return user_client is not None and user_client.enabled
@@ -465,15 +489,15 @@ class BuiltinToolManageService:
cache=NoOpProviderCredentialCache(),
)
with Session(db.engine, autoflush=False) as session:
user_client: ToolOAuthTenantClient | None = (
session.query(ToolOAuthTenantClient)
.filter_by(
tenant_id=tenant_id,
provider=tool_provider.provider_name,
plugin_id=tool_provider.plugin_id,
enabled=True,
user_client = session.scalar(
select(ToolOAuthTenantClient)
.where(
ToolOAuthTenantClient.tenant_id == tenant_id,
ToolOAuthTenantClient.provider == tool_provider.provider_name,
ToolOAuthTenantClient.plugin_id == tool_provider.plugin_id,
ToolOAuthTenantClient.enabled.is_(True),
)
.first()
.limit(1)
)
oauth_params: Mapping[str, Any] | None = None
if user_client:
@@ -487,10 +511,13 @@ class BuiltinToolManageService:
if not is_verified:
return oauth_params
system_client: ToolOAuthSystemClient | None = (
session.query(ToolOAuthSystemClient)
.filter_by(plugin_id=tool_provider.plugin_id, provider=tool_provider.provider_name)
.first()
system_client = session.scalar(
select(ToolOAuthSystemClient)
.where(
ToolOAuthSystemClient.plugin_id == tool_provider.plugin_id,
ToolOAuthSystemClient.provider == tool_provider.provider_name,
)
.limit(1)
)
if system_client:
try:
@@ -582,8 +609,8 @@ class BuiltinToolManageService:
provider_name = provider_id_entity.provider_name
if provider_id_entity.organization != "langgenius":
provider = (
session.query(BuiltinToolProvider)
provider = session.scalar(
select(BuiltinToolProvider)
.where(
BuiltinToolProvider.tenant_id == tenant_id,
BuiltinToolProvider.provider == full_provider_name,
@@ -592,11 +619,11 @@ class BuiltinToolManageService:
BuiltinToolProvider.is_default.desc(), # default=True first
BuiltinToolProvider.created_at.asc(), # oldest first
)
.first()
.limit(1)
)
else:
provider = (
session.query(BuiltinToolProvider)
provider = session.scalar(
select(BuiltinToolProvider)
.where(
BuiltinToolProvider.tenant_id == tenant_id,
(BuiltinToolProvider.provider == provider_name)
@@ -606,7 +633,7 @@ class BuiltinToolManageService:
BuiltinToolProvider.is_default.desc(), # default=True first
BuiltinToolProvider.created_at.asc(), # oldest first
)
.first()
.limit(1)
)
if provider is None:
@@ -616,14 +643,14 @@ class BuiltinToolManageService:
return provider
except Exception:
# it's an old provider without organization
return (
session.query(BuiltinToolProvider)
return session.scalar(
select(BuiltinToolProvider)
.where(BuiltinToolProvider.tenant_id == tenant_id, BuiltinToolProvider.provider == provider_name)
.order_by(
BuiltinToolProvider.is_default.desc(), # default=True first
BuiltinToolProvider.created_at.asc(), # oldest first
)
.first()
.limit(1)
)
@staticmethod
@@ -648,14 +675,14 @@ class BuiltinToolManageService:
raise ValueError(f"Provider {provider} is not a builtin or plugin provider")
with sessionmaker(bind=db.engine).begin() as session:
custom_client_params = (
session.query(ToolOAuthTenantClient)
.filter_by(
tenant_id=tenant_id,
plugin_id=tool_provider.plugin_id,
provider=tool_provider.provider_name,
custom_client_params = session.scalar(
select(ToolOAuthTenantClient)
.where(
ToolOAuthTenantClient.tenant_id == tenant_id,
ToolOAuthTenantClient.plugin_id == tool_provider.plugin_id,
ToolOAuthTenantClient.provider == tool_provider.provider_name,
)
.first()
.limit(1)
)
# if the record does not exist, create a basic record
@@ -692,14 +719,14 @@ class BuiltinToolManageService:
"""
with Session(db.engine) as session:
tool_provider = ToolProviderID(provider)
custom_oauth_client_params: ToolOAuthTenantClient | None = (
session.query(ToolOAuthTenantClient)
.filter_by(
tenant_id=tenant_id,
plugin_id=tool_provider.plugin_id,
provider=tool_provider.provider_name,
custom_oauth_client_params = session.scalar(
select(ToolOAuthTenantClient)
.where(
ToolOAuthTenantClient.tenant_id == tenant_id,
ToolOAuthTenantClient.plugin_id == tool_provider.plugin_id,
ToolOAuthTenantClient.provider == tool_provider.provider_name,
)
.first()
.limit(1)
)
if custom_oauth_client_params is None:
return {}

View File

@@ -32,7 +32,7 @@ class TestDeleteCustomOauthClientParams:
result = BuiltinToolManageService.delete_custom_oauth_client_params("tenant-1", "google")
assert result == {"result": "success"}
session.query.return_value.filter_by.return_value.delete.assert_called_once()
session.execute.assert_called_once()
class TestListBuiltinToolProviderTools:
@@ -111,7 +111,7 @@ class TestIsOauthSystemClientExists:
@patch(f"{MODULE}.db")
def test_true_when_exists(self, mock_db, mock_session_cls):
session = _mock_session(mock_session_cls)
session.query.return_value.filter_by.return_value.first.return_value = MagicMock()
session.scalar.return_value = MagicMock()
assert BuiltinToolManageService.is_oauth_system_client_exists("google") is True
@@ -119,7 +119,7 @@ class TestIsOauthSystemClientExists:
@patch(f"{MODULE}.db")
def test_false_when_missing(self, mock_db, mock_session_cls):
session = _mock_session(mock_session_cls)
session.query.return_value.filter_by.return_value.first.return_value = None
session.scalar.return_value = None
assert BuiltinToolManageService.is_oauth_system_client_exists("google") is False
@@ -129,7 +129,7 @@ class TestIsOauthCustomClientEnabled:
@patch(f"{MODULE}.db")
def test_true_when_enabled(self, mock_db, mock_session_cls):
session = _mock_session(mock_session_cls)
session.query.return_value.filter_by.return_value.first.return_value = MagicMock(enabled=True)
session.scalar.return_value = MagicMock(enabled=True)
assert BuiltinToolManageService.is_oauth_custom_client_enabled("t", "g") is True
@@ -137,7 +137,7 @@ class TestIsOauthCustomClientEnabled:
@patch(f"{MODULE}.db")
def test_false_when_none(self, mock_db, mock_session_cls):
session = _mock_session(mock_session_cls)
session.query.return_value.filter_by.return_value.first.return_value = None
session.scalar.return_value = None
assert BuiltinToolManageService.is_oauth_custom_client_enabled("t", "g") is False
@@ -149,7 +149,7 @@ class TestDeleteBuiltinToolProvider:
@patch(f"{MODULE}.db")
def test_raises_when_not_found(self, mock_db, mock_sm_cls, mock_tm, mock_enc):
session = _mock_sessionmaker(mock_sm_cls)
session.query.return_value.where.return_value.first.return_value = None
session.scalar.return_value = None
with pytest.raises(ValueError, match="you have not added provider"):
BuiltinToolManageService.delete_builtin_tool_provider("t", "p", "id")
@@ -161,7 +161,7 @@ class TestDeleteBuiltinToolProvider:
def test_deletes_provider_and_clears_cache(self, mock_db, mock_sm_cls, mock_tm, mock_enc):
session = _mock_sessionmaker(mock_sm_cls)
db_provider = MagicMock()
session.query.return_value.where.return_value.first.return_value = db_provider
session.scalar.return_value = db_provider
mock_cache = MagicMock()
mock_enc.return_value = (MagicMock(), mock_cache)
@@ -177,7 +177,7 @@ class TestSetDefaultProvider:
@patch(f"{MODULE}.db")
def test_raises_when_not_found(self, mock_db, mock_sm_cls):
session = _mock_sessionmaker(mock_sm_cls)
session.query.return_value.filter_by.return_value.first.return_value = None
session.scalar.return_value = None
with pytest.raises(ValueError, match="provider not found"):
BuiltinToolManageService.set_default_provider("t", "u", "p", "id")
@@ -187,7 +187,7 @@ class TestSetDefaultProvider:
def test_sets_default_and_clears_old(self, mock_db, mock_sm_cls):
session = _mock_sessionmaker(mock_sm_cls)
target = MagicMock()
session.query.return_value.filter_by.return_value.first.return_value = target
session.scalar.return_value = target
result = BuiltinToolManageService.set_default_provider("t", "u", "p", "id")
@@ -200,7 +200,7 @@ class TestUpdateBuiltinToolProvider:
@patch(f"{MODULE}.db")
def test_raises_when_provider_not_exists(self, mock_db, mock_sm_cls):
session = _mock_sessionmaker(mock_sm_cls)
session.query.return_value.where.return_value.first.return_value = None
session.scalar.return_value = None
with pytest.raises(ValueError, match="you have not added provider"):
BuiltinToolManageService.update_builtin_tool_provider("u", "t", "p", "c")
@@ -213,7 +213,7 @@ class TestUpdateBuiltinToolProvider:
def test_updates_credentials_and_commits(self, mock_db, mock_sm_cls, mock_tm, mock_cred_type, mock_enc):
session = _mock_sessionmaker(mock_sm_cls)
db_provider = MagicMock(credential_type="api_key", credentials="{}")
session.query.return_value.where.return_value.first.return_value = db_provider
session.scalar.return_value = db_provider
mock_cred_instance = MagicMock()
mock_cred_instance.is_editable.return_value = True
@@ -274,7 +274,7 @@ class TestGetOauthClient:
mock_create_enc.return_value = (mock_encrypter, MagicMock())
user_client = MagicMock(oauth_params='{"encrypted": "data"}')
session.query.return_value.filter_by.return_value.first.return_value = user_client
session.scalar.return_value = user_client
result = BuiltinToolManageService.get_oauth_client("t", "google")
@@ -297,10 +297,7 @@ class TestGetOauthClient:
mock_create_enc.return_value = (MagicMock(), MagicMock())
system_client = MagicMock(encrypted_oauth_params="enc")
session.query.return_value.filter_by.return_value.first.side_effect = [
None, # user client
system_client, # system client
]
session.scalar.side_effect = [None, system_client]
result = BuiltinToolManageService.get_oauth_client("t", "google")
@@ -325,7 +322,7 @@ class TestGetCustomOauthClientParams:
@patch(f"{MODULE}.db")
def test_returns_empty_when_none(self, mock_db, mock_session_cls):
session = _mock_session(mock_session_cls)
session.query.return_value.filter_by.return_value.first.return_value = None
session.scalar.return_value = None
result = BuiltinToolManageService.get_custom_oauth_client_params("t", "p")
@@ -391,7 +388,7 @@ class TestGetBuiltinProvider:
session = _mock_session(mock_session_cls)
mock_prov_id.return_value.provider_name = "google"
mock_prov_id.return_value.organization = "langgenius"
session.query.return_value.where.return_value.order_by.return_value.first.return_value = None
session.scalar.return_value = None
result = BuiltinToolManageService.get_builtin_provider("google", "t")
@@ -417,7 +414,7 @@ class TestGetBuiltinProvider:
return m
mock_prov_id.side_effect = prov_id_side_effect
session.query.return_value.where.return_value.order_by.return_value.first.return_value = db_provider
session.scalar.return_value = db_provider
result = BuiltinToolManageService.get_builtin_provider("google", "t")
@@ -439,7 +436,7 @@ class TestGetBuiltinProvider:
mock_prov_id.side_effect = prov_id_side_effect
db_provider = MagicMock(provider="third-party/custom/custom-tool")
session.query.return_value.where.return_value.order_by.return_value.first.return_value = db_provider
session.scalar.return_value = db_provider
result = BuiltinToolManageService.get_builtin_provider("third-party/custom/custom-tool", "t")
@@ -452,7 +449,7 @@ class TestGetBuiltinProvider:
session = _mock_session(mock_session_cls)
mock_prov_id.side_effect = Exception("parse error")
fallback = MagicMock()
session.query.return_value.where.return_value.order_by.return_value.first.return_value = fallback
session.scalar.return_value = fallback
result = BuiltinToolManageService.get_builtin_provider("old-provider", "t")