mirror of
https://github.com/langgenius/dify.git
synced 2026-04-13 21:00:25 -04:00
refactor(services): migrate datasource_provider_service to SQLAlchemy 2.0 select() API (#34974)
This commit is contained in:
@@ -4,7 +4,7 @@ from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from graphon.model_runtime.entities.provider_entities import FormType
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy import delete, func, select, update
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from configs import dify_config
|
||||
@@ -54,11 +54,13 @@ class DatasourceProviderService:
|
||||
remove oauth custom client params
|
||||
"""
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
session.query(DatasourceOauthTenantParamConfig).filter_by(
|
||||
tenant_id=tenant_id,
|
||||
provider=datasource_provider_id.provider_name,
|
||||
plugin_id=datasource_provider_id.plugin_id,
|
||||
).delete()
|
||||
session.execute(
|
||||
delete(DatasourceOauthTenantParamConfig).where(
|
||||
DatasourceOauthTenantParamConfig.tenant_id == tenant_id,
|
||||
DatasourceOauthTenantParamConfig.provider == datasource_provider_id.provider_name,
|
||||
DatasourceOauthTenantParamConfig.plugin_id == datasource_provider_id.plugin_id,
|
||||
)
|
||||
)
|
||||
|
||||
def decrypt_datasource_provider_credentials(
|
||||
self,
|
||||
@@ -110,15 +112,21 @@ class DatasourceProviderService:
|
||||
"""
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
if credential_id:
|
||||
datasource_provider = (
|
||||
session.query(DatasourceProvider).filter_by(tenant_id=tenant_id, id=credential_id).first()
|
||||
datasource_provider = session.scalar(
|
||||
select(DatasourceProvider)
|
||||
.where(DatasourceProvider.tenant_id == tenant_id, DatasourceProvider.id == credential_id)
|
||||
.limit(1)
|
||||
)
|
||||
else:
|
||||
datasource_provider = (
|
||||
session.query(DatasourceProvider)
|
||||
.filter_by(tenant_id=tenant_id, provider=provider, plugin_id=plugin_id)
|
||||
datasource_provider = session.scalar(
|
||||
select(DatasourceProvider)
|
||||
.where(
|
||||
DatasourceProvider.tenant_id == tenant_id,
|
||||
DatasourceProvider.provider == provider,
|
||||
DatasourceProvider.plugin_id == plugin_id,
|
||||
)
|
||||
.order_by(DatasourceProvider.is_default.desc(), DatasourceProvider.created_at.asc())
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
if not datasource_provider:
|
||||
return {}
|
||||
@@ -173,12 +181,15 @@ class DatasourceProviderService:
|
||||
get all datasource credentials by provider
|
||||
"""
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
datasource_providers = (
|
||||
session.query(DatasourceProvider)
|
||||
.filter_by(tenant_id=tenant_id, provider=provider, plugin_id=plugin_id)
|
||||
datasource_providers = session.scalars(
|
||||
select(DatasourceProvider)
|
||||
.where(
|
||||
DatasourceProvider.tenant_id == tenant_id,
|
||||
DatasourceProvider.provider == provider,
|
||||
DatasourceProvider.plugin_id == plugin_id,
|
||||
)
|
||||
.order_by(DatasourceProvider.is_default.desc(), DatasourceProvider.created_at.asc())
|
||||
.all()
|
||||
)
|
||||
).all()
|
||||
if not datasource_providers:
|
||||
return []
|
||||
current_user = get_current_user()
|
||||
@@ -232,15 +243,15 @@ class DatasourceProviderService:
|
||||
update datasource provider name
|
||||
"""
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
target_provider = (
|
||||
session.query(DatasourceProvider)
|
||||
.filter_by(
|
||||
tenant_id=tenant_id,
|
||||
id=credential_id,
|
||||
provider=datasource_provider_id.provider_name,
|
||||
plugin_id=datasource_provider_id.plugin_id,
|
||||
target_provider = session.scalar(
|
||||
select(DatasourceProvider)
|
||||
.where(
|
||||
DatasourceProvider.tenant_id == tenant_id,
|
||||
DatasourceProvider.id == credential_id,
|
||||
DatasourceProvider.provider == datasource_provider_id.provider_name,
|
||||
DatasourceProvider.plugin_id == datasource_provider_id.plugin_id,
|
||||
)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
if target_provider is None:
|
||||
raise ValueError("provider not found")
|
||||
@@ -250,16 +261,16 @@ class DatasourceProviderService:
|
||||
|
||||
# check name is exist
|
||||
if (
|
||||
session.query(DatasourceProvider)
|
||||
.filter_by(
|
||||
tenant_id=tenant_id,
|
||||
name=name,
|
||||
provider=datasource_provider_id.provider_name,
|
||||
plugin_id=datasource_provider_id.plugin_id,
|
||||
session.scalar(
|
||||
select(func.count(DatasourceProvider.id)).where(
|
||||
DatasourceProvider.tenant_id == tenant_id,
|
||||
DatasourceProvider.name == name,
|
||||
DatasourceProvider.provider == datasource_provider_id.provider_name,
|
||||
DatasourceProvider.plugin_id == datasource_provider_id.plugin_id,
|
||||
)
|
||||
)
|
||||
.count()
|
||||
> 0
|
||||
):
|
||||
or 0
|
||||
) > 0:
|
||||
raise ValueError("Authorization name is already exists")
|
||||
|
||||
target_provider.name = name
|
||||
@@ -273,26 +284,31 @@ class DatasourceProviderService:
|
||||
"""
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
# get provider
|
||||
target_provider = (
|
||||
session.query(DatasourceProvider)
|
||||
.filter_by(
|
||||
tenant_id=tenant_id,
|
||||
id=credential_id,
|
||||
provider=datasource_provider_id.provider_name,
|
||||
plugin_id=datasource_provider_id.plugin_id,
|
||||
target_provider = session.scalar(
|
||||
select(DatasourceProvider)
|
||||
.where(
|
||||
DatasourceProvider.tenant_id == tenant_id,
|
||||
DatasourceProvider.id == credential_id,
|
||||
DatasourceProvider.provider == datasource_provider_id.provider_name,
|
||||
DatasourceProvider.plugin_id == datasource_provider_id.plugin_id,
|
||||
)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
if target_provider is None:
|
||||
raise ValueError("provider not found")
|
||||
|
||||
# clear default provider
|
||||
session.query(DatasourceProvider).filter_by(
|
||||
tenant_id=tenant_id,
|
||||
provider=target_provider.provider,
|
||||
plugin_id=target_provider.plugin_id,
|
||||
is_default=True,
|
||||
).update({"is_default": False})
|
||||
session.execute(
|
||||
update(DatasourceProvider)
|
||||
.where(
|
||||
DatasourceProvider.tenant_id == tenant_id,
|
||||
DatasourceProvider.provider == target_provider.provider,
|
||||
DatasourceProvider.plugin_id == target_provider.plugin_id,
|
||||
DatasourceProvider.is_default.is_(True),
|
||||
)
|
||||
.values(is_default=False)
|
||||
.execution_options(synchronize_session=False)
|
||||
)
|
||||
|
||||
# set new default provider
|
||||
target_provider.is_default = True
|
||||
@@ -311,14 +327,14 @@ class DatasourceProviderService:
|
||||
if client_params is None and enabled is None:
|
||||
return
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
tenant_oauth_client_params = (
|
||||
session.query(DatasourceOauthTenantParamConfig)
|
||||
.filter_by(
|
||||
tenant_id=tenant_id,
|
||||
provider=datasource_provider_id.provider_name,
|
||||
plugin_id=datasource_provider_id.plugin_id,
|
||||
tenant_oauth_client_params = session.scalar(
|
||||
select(DatasourceOauthTenantParamConfig)
|
||||
.where(
|
||||
DatasourceOauthTenantParamConfig.tenant_id == tenant_id,
|
||||
DatasourceOauthTenantParamConfig.provider == datasource_provider_id.provider_name,
|
||||
DatasourceOauthTenantParamConfig.plugin_id == datasource_provider_id.plugin_id,
|
||||
)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
if not tenant_oauth_client_params:
|
||||
@@ -351,9 +367,14 @@ class DatasourceProviderService:
|
||||
"""
|
||||
with Session(db.engine).no_autoflush as session:
|
||||
return (
|
||||
session.query(DatasourceOauthParamConfig)
|
||||
.filter_by(provider=datasource_provider_id.provider_name, plugin_id=datasource_provider_id.plugin_id)
|
||||
.first()
|
||||
session.scalar(
|
||||
select(DatasourceOauthParamConfig)
|
||||
.where(
|
||||
DatasourceOauthParamConfig.provider == datasource_provider_id.provider_name,
|
||||
DatasourceOauthParamConfig.plugin_id == datasource_provider_id.plugin_id,
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
is not None
|
||||
)
|
||||
|
||||
@@ -423,15 +444,15 @@ class DatasourceProviderService:
|
||||
plugin_id = datasource_provider_id.plugin_id
|
||||
with Session(db.engine).no_autoflush as session:
|
||||
# get tenant oauth client params
|
||||
tenant_oauth_client_params = (
|
||||
session.query(DatasourceOauthTenantParamConfig)
|
||||
.filter_by(
|
||||
tenant_id=tenant_id,
|
||||
provider=provider,
|
||||
plugin_id=plugin_id,
|
||||
enabled=True,
|
||||
tenant_oauth_client_params = session.scalar(
|
||||
select(DatasourceOauthTenantParamConfig)
|
||||
.where(
|
||||
DatasourceOauthTenantParamConfig.tenant_id == tenant_id,
|
||||
DatasourceOauthTenantParamConfig.provider == provider,
|
||||
DatasourceOauthTenantParamConfig.plugin_id == plugin_id,
|
||||
DatasourceOauthTenantParamConfig.enabled.is_(True),
|
||||
)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
if tenant_oauth_client_params:
|
||||
encrypter, _ = self.get_oauth_encrypter(tenant_id, datasource_provider_id)
|
||||
@@ -443,8 +464,13 @@ class DatasourceProviderService:
|
||||
is_verified = PluginService.is_plugin_verified(tenant_id, provider_controller.plugin_unique_identifier)
|
||||
if is_verified:
|
||||
# fallback to system oauth client params
|
||||
oauth_client_params = (
|
||||
session.query(DatasourceOauthParamConfig).filter_by(provider=provider, plugin_id=plugin_id).first()
|
||||
oauth_client_params = session.scalar(
|
||||
select(DatasourceOauthParamConfig)
|
||||
.where(
|
||||
DatasourceOauthParamConfig.provider == provider,
|
||||
DatasourceOauthParamConfig.plugin_id == plugin_id,
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
if oauth_client_params:
|
||||
return oauth_client_params.system_credentials
|
||||
@@ -455,15 +481,13 @@ class DatasourceProviderService:
|
||||
def generate_next_datasource_provider_name(
|
||||
session: Session, tenant_id: str, provider_id: DatasourceProviderID, credential_type: CredentialType
|
||||
) -> str:
|
||||
db_providers = (
|
||||
session.query(DatasourceProvider)
|
||||
.filter_by(
|
||||
tenant_id=tenant_id,
|
||||
provider=provider_id.provider_name,
|
||||
plugin_id=provider_id.plugin_id,
|
||||
db_providers = session.scalars(
|
||||
select(DatasourceProvider).where(
|
||||
DatasourceProvider.tenant_id == tenant_id,
|
||||
DatasourceProvider.provider == provider_id.provider_name,
|
||||
DatasourceProvider.plugin_id == provider_id.plugin_id,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
).all()
|
||||
return generate_incremental_name(
|
||||
[provider.name for provider in db_providers],
|
||||
f"{credential_type.get_name()}",
|
||||
@@ -485,8 +509,10 @@ class DatasourceProviderService:
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
lock = f"datasource_provider_create_lock:{tenant_id}_{provider_id}_{CredentialType.OAUTH2.value}"
|
||||
with redis_client.lock(lock, timeout=20):
|
||||
target_provider = (
|
||||
session.query(DatasourceProvider).filter_by(id=credential_id, tenant_id=tenant_id).first()
|
||||
target_provider = session.scalar(
|
||||
select(DatasourceProvider)
|
||||
.where(DatasourceProvider.id == credential_id, DatasourceProvider.tenant_id == tenant_id)
|
||||
.limit(1)
|
||||
)
|
||||
if target_provider is None:
|
||||
raise ValueError("provider not found")
|
||||
@@ -496,25 +522,28 @@ class DatasourceProviderService:
|
||||
db_provider_name = target_provider.name
|
||||
else:
|
||||
name_conflict = (
|
||||
session.query(DatasourceProvider)
|
||||
.filter_by(
|
||||
tenant_id=tenant_id,
|
||||
name=db_provider_name,
|
||||
provider=provider_id.provider_name,
|
||||
plugin_id=provider_id.plugin_id,
|
||||
auth_type=CredentialType.OAUTH2.value,
|
||||
session.scalar(
|
||||
select(func.count(DatasourceProvider.id)).where(
|
||||
DatasourceProvider.tenant_id == tenant_id,
|
||||
DatasourceProvider.name == db_provider_name,
|
||||
DatasourceProvider.provider == provider_id.provider_name,
|
||||
DatasourceProvider.plugin_id == provider_id.plugin_id,
|
||||
DatasourceProvider.auth_type == CredentialType.OAUTH2.value,
|
||||
)
|
||||
)
|
||||
.count()
|
||||
or 0
|
||||
)
|
||||
if name_conflict > 0:
|
||||
db_provider_name = generate_incremental_name(
|
||||
[
|
||||
provider.name
|
||||
for provider in session.query(DatasourceProvider).filter_by(
|
||||
tenant_id=tenant_id,
|
||||
provider=provider_id.provider_name,
|
||||
plugin_id=provider_id.plugin_id,
|
||||
)
|
||||
for provider in session.scalars(
|
||||
select(DatasourceProvider).where(
|
||||
DatasourceProvider.tenant_id == tenant_id,
|
||||
DatasourceProvider.provider == provider_id.provider_name,
|
||||
DatasourceProvider.plugin_id == provider_id.plugin_id,
|
||||
)
|
||||
).all()
|
||||
],
|
||||
db_provider_name,
|
||||
)
|
||||
@@ -556,25 +585,27 @@ class DatasourceProviderService:
|
||||
)
|
||||
else:
|
||||
if (
|
||||
session.query(DatasourceProvider)
|
||||
.filter_by(
|
||||
tenant_id=tenant_id,
|
||||
name=db_provider_name,
|
||||
provider=provider_id.provider_name,
|
||||
plugin_id=provider_id.plugin_id,
|
||||
auth_type=credential_type.value,
|
||||
session.scalar(
|
||||
select(func.count(DatasourceProvider.id)).where(
|
||||
DatasourceProvider.tenant_id == tenant_id,
|
||||
DatasourceProvider.name == db_provider_name,
|
||||
DatasourceProvider.provider == provider_id.provider_name,
|
||||
DatasourceProvider.plugin_id == provider_id.plugin_id,
|
||||
DatasourceProvider.auth_type == credential_type.value,
|
||||
)
|
||||
)
|
||||
.count()
|
||||
> 0
|
||||
):
|
||||
or 0
|
||||
) > 0:
|
||||
db_provider_name = generate_incremental_name(
|
||||
[
|
||||
provider.name
|
||||
for provider in session.query(DatasourceProvider).filter_by(
|
||||
tenant_id=tenant_id,
|
||||
provider=provider_id.provider_name,
|
||||
plugin_id=provider_id.plugin_id,
|
||||
)
|
||||
for provider in session.scalars(
|
||||
select(DatasourceProvider).where(
|
||||
DatasourceProvider.tenant_id == tenant_id,
|
||||
DatasourceProvider.provider == provider_id.provider_name,
|
||||
DatasourceProvider.plugin_id == provider_id.plugin_id,
|
||||
)
|
||||
).all()
|
||||
],
|
||||
db_provider_name,
|
||||
)
|
||||
@@ -627,11 +658,16 @@ class DatasourceProviderService:
|
||||
|
||||
# check name is exist
|
||||
if (
|
||||
session.query(DatasourceProvider)
|
||||
.filter_by(tenant_id=tenant_id, plugin_id=plugin_id, provider=provider_name, name=db_provider_name)
|
||||
.count()
|
||||
> 0
|
||||
):
|
||||
session.scalar(
|
||||
select(func.count(DatasourceProvider.id)).where(
|
||||
DatasourceProvider.tenant_id == tenant_id,
|
||||
DatasourceProvider.plugin_id == plugin_id,
|
||||
DatasourceProvider.provider == provider_name,
|
||||
DatasourceProvider.name == db_provider_name,
|
||||
)
|
||||
)
|
||||
or 0
|
||||
) > 0:
|
||||
raise ValueError("Authorization name is already exists")
|
||||
|
||||
try:
|
||||
@@ -918,21 +954,31 @@ class DatasourceProviderService:
|
||||
"""
|
||||
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
datasource_provider = (
|
||||
session.query(DatasourceProvider)
|
||||
.filter_by(tenant_id=tenant_id, id=auth_id, provider=provider, plugin_id=plugin_id)
|
||||
.first()
|
||||
datasource_provider = session.scalar(
|
||||
select(DatasourceProvider)
|
||||
.where(
|
||||
DatasourceProvider.tenant_id == tenant_id,
|
||||
DatasourceProvider.id == auth_id,
|
||||
DatasourceProvider.provider == provider,
|
||||
DatasourceProvider.plugin_id == plugin_id,
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
if not datasource_provider:
|
||||
raise ValueError("Datasource provider not found")
|
||||
# update name
|
||||
if name and name != datasource_provider.name:
|
||||
if (
|
||||
session.query(DatasourceProvider)
|
||||
.filter_by(tenant_id=tenant_id, name=name, provider=provider, plugin_id=plugin_id)
|
||||
.count()
|
||||
> 0
|
||||
):
|
||||
session.scalar(
|
||||
select(func.count(DatasourceProvider.id)).where(
|
||||
DatasourceProvider.tenant_id == tenant_id,
|
||||
DatasourceProvider.name == name,
|
||||
DatasourceProvider.provider == provider,
|
||||
DatasourceProvider.plugin_id == plugin_id,
|
||||
)
|
||||
)
|
||||
or 0
|
||||
) > 0:
|
||||
raise ValueError("Authorization name is already exists")
|
||||
datasource_provider.name = name
|
||||
|
||||
|
||||
@@ -179,11 +179,11 @@ class TestDatasourceProviderService:
|
||||
# -----------------------------------------------------------------------
|
||||
|
||||
def test_should_return_true_when_system_oauth_params_exist(self, service, mock_db_session):
|
||||
mock_db_session.query().first.return_value = MagicMock()
|
||||
mock_db_session.scalar.return_value = MagicMock()
|
||||
assert service.is_system_oauth_params_exist(make_id()) is True
|
||||
|
||||
def test_should_return_false_when_system_oauth_params_missing(self, service, mock_db_session):
|
||||
mock_db_session.query().first.return_value = None
|
||||
mock_db_session.scalar.return_value = None
|
||||
assert service.is_system_oauth_params_exist(make_id()) is False
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
@@ -205,7 +205,7 @@ class TestDatasourceProviderService:
|
||||
|
||||
def test_should_delete_tenant_config_when_removing_oauth_params(self, service, mock_db_session):
|
||||
service.remove_oauth_custom_client_params("t1", make_id())
|
||||
mock_db_session.query().delete.assert_called_once()
|
||||
mock_db_session.execute.assert_called_once()
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
# setup_oauth_custom_client_params (315-351)
|
||||
@@ -217,14 +217,14 @@ class TestDatasourceProviderService:
|
||||
mock_db_session.add.assert_not_called()
|
||||
|
||||
def test_should_create_new_config_when_none_exists(self, service, mock_db_session):
|
||||
mock_db_session.query().first.return_value = None
|
||||
mock_db_session.scalar.return_value = None
|
||||
with patch.object(service, "get_oauth_encrypter", return_value=(self._enc, None)):
|
||||
service.setup_oauth_custom_client_params("t1", make_id(), {"k": "v"}, True)
|
||||
mock_db_session.add.assert_called_once()
|
||||
|
||||
def test_should_update_existing_config_when_record_found(self, service, mock_db_session):
|
||||
existing = MagicMock()
|
||||
mock_db_session.query().first.return_value = existing
|
||||
mock_db_session.scalar.return_value = existing
|
||||
with patch.object(service, "get_oauth_encrypter", return_value=(self._enc, None)):
|
||||
service.setup_oauth_custom_client_params("t1", make_id(), {"k": "v"}, False)
|
||||
mock_db_session.add.assert_not_called() # update in place, no add
|
||||
@@ -255,7 +255,7 @@ class TestDatasourceProviderService:
|
||||
|
||||
def test_should_return_empty_dict_when_credential_not_found(self, service, mock_db_session, mock_user):
|
||||
with patch("services.datasource_provider_service.get_current_user", return_value=mock_user):
|
||||
mock_db_session.query().first.return_value = None
|
||||
mock_db_session.scalar.return_value = None
|
||||
assert service.get_datasource_credentials("t1", "prov", "org/plug") == {}
|
||||
|
||||
def test_should_refresh_oauth_tokens_when_expired(self, service, mock_db_session, mock_user):
|
||||
@@ -264,7 +264,7 @@ class TestDatasourceProviderService:
|
||||
p.auth_type = "oauth2"
|
||||
p.expires_at = 0 # expired
|
||||
p.encrypted_credentials = {"tok": "x"}
|
||||
mock_db_session.query().first.return_value = p
|
||||
mock_db_session.scalar.return_value = p
|
||||
with (
|
||||
patch("services.datasource_provider_service.get_current_user", return_value=mock_user),
|
||||
patch.object(service, "get_oauth_client", return_value={"oc": "v"}),
|
||||
@@ -278,7 +278,7 @@ class TestDatasourceProviderService:
|
||||
p.auth_type = "api_key"
|
||||
p.expires_at = -1 # sentinel: never expires
|
||||
p.encrypted_credentials = {"k": "v"}
|
||||
mock_db_session.query().first.return_value = p
|
||||
mock_db_session.scalar.return_value = p
|
||||
with (
|
||||
patch("services.datasource_provider_service.get_current_user", return_value=mock_user),
|
||||
patch.object(service, "decrypt_datasource_provider_credentials", return_value={"k": "plain"}),
|
||||
@@ -292,7 +292,7 @@ class TestDatasourceProviderService:
|
||||
p.auth_type = "api_key"
|
||||
p.expires_at = -1
|
||||
p.encrypted_credentials = {}
|
||||
mock_db_session.query().first.return_value = p
|
||||
mock_db_session.scalar.return_value = p
|
||||
with (
|
||||
patch("services.datasource_provider_service.get_current_user", return_value=mock_user),
|
||||
patch.object(service, "decrypt_datasource_provider_credentials", return_value={"k": "v"}),
|
||||
@@ -306,7 +306,7 @@ class TestDatasourceProviderService:
|
||||
|
||||
def test_should_return_empty_list_when_no_provider_credentials_exist(self, service, mock_db_session, mock_user):
|
||||
with patch("services.datasource_provider_service.get_current_user", return_value=mock_user):
|
||||
mock_db_session.query().all.return_value = []
|
||||
mock_db_session.scalars.return_value.all.return_value = []
|
||||
assert service.get_all_datasource_credentials_by_provider("t1", "prov", "org/plug") == []
|
||||
|
||||
def test_should_refresh_and_return_credentials_when_oauth_expired(self, service, mock_db_session, mock_user):
|
||||
@@ -314,7 +314,7 @@ class TestDatasourceProviderService:
|
||||
p.auth_type = "oauth2"
|
||||
p.expires_at = 0
|
||||
p.encrypted_credentials = {"t": "x"}
|
||||
mock_db_session.query().all.return_value = [p]
|
||||
mock_db_session.scalars.return_value.all.return_value = [p]
|
||||
with (
|
||||
patch("services.datasource_provider_service.get_current_user", return_value=mock_user),
|
||||
patch.object(service, "get_oauth_client", return_value={"oc": "v"}),
|
||||
@@ -328,22 +328,21 @@ class TestDatasourceProviderService:
|
||||
# -----------------------------------------------------------------------
|
||||
|
||||
def test_should_raise_value_error_when_provider_not_found_on_name_update(self, service, mock_db_session):
|
||||
mock_db_session.query().first.return_value = None
|
||||
mock_db_session.scalar.return_value = None
|
||||
with pytest.raises(ValueError, match="not found"):
|
||||
service.update_datasource_provider_name("t1", make_id(), "new", "cred-id")
|
||||
|
||||
def test_should_return_early_when_new_name_matches_current(self, service, mock_db_session):
|
||||
p = MagicMock(spec=DatasourceProvider)
|
||||
p.name = "same"
|
||||
mock_db_session.query().first.return_value = p
|
||||
mock_db_session.scalar.return_value = p
|
||||
service.update_datasource_provider_name("t1", make_id(), "same", "cred-id")
|
||||
|
||||
def test_should_raise_value_error_when_name_already_exists(self, service, mock_db_session):
|
||||
p = MagicMock(spec=DatasourceProvider)
|
||||
p.name = "old_name"
|
||||
p.is_default = False
|
||||
mock_db_session.query().first.return_value = p
|
||||
mock_db_session.query().count.return_value = 1 # conflict
|
||||
mock_db_session.scalar.side_effect = [p, 1] # first: fetch provider, second: name conflict count
|
||||
with pytest.raises(ValueError, match="already exists"):
|
||||
service.update_datasource_provider_name("t1", make_id(), "new_name", "some-id")
|
||||
|
||||
@@ -351,8 +350,7 @@ class TestDatasourceProviderService:
|
||||
p = MagicMock(spec=DatasourceProvider)
|
||||
p.name = "old_name"
|
||||
p.is_default = False
|
||||
mock_db_session.query().first.return_value = p
|
||||
mock_db_session.query().count.return_value = 0
|
||||
mock_db_session.scalar.side_effect = [p, 0] # first: fetch provider, second: name conflict count
|
||||
service.update_datasource_provider_name("t1", make_id(), "new_name", "some-id")
|
||||
assert p.name == "new_name"
|
||||
|
||||
@@ -361,7 +359,7 @@ class TestDatasourceProviderService:
|
||||
# -----------------------------------------------------------------------
|
||||
|
||||
def test_should_raise_value_error_when_target_provider_not_found(self, service, mock_db_session):
|
||||
mock_db_session.query().first.return_value = None
|
||||
mock_db_session.scalar.return_value = None
|
||||
with pytest.raises(ValueError, match="not found"):
|
||||
service.set_default_datasource_provider("t1", make_id(), "bad-id")
|
||||
|
||||
@@ -369,7 +367,7 @@ class TestDatasourceProviderService:
|
||||
target = MagicMock(spec=DatasourceProvider)
|
||||
target.provider = "provider"
|
||||
target.plugin_id = "org/plug"
|
||||
mock_db_session.query().first.return_value = target
|
||||
mock_db_session.scalar.return_value = target
|
||||
service.set_default_datasource_provider("t1", make_id(), "new-id")
|
||||
assert target.is_default is True
|
||||
|
||||
@@ -428,13 +426,13 @@ class TestDatasourceProviderService:
|
||||
# -----------------------------------------------------------------------
|
||||
|
||||
def test_should_use_tenant_config_when_available(self, service, mock_db_session):
|
||||
mock_db_session.query().first.return_value = MagicMock(client_params={"k": "v"})
|
||||
mock_db_session.scalar.return_value = MagicMock(client_params={"k": "v"})
|
||||
with patch.object(service, "get_oauth_encrypter", return_value=(self._enc, None)):
|
||||
result = service.get_oauth_client("t1", make_id())
|
||||
assert result == {"k": "dec"}
|
||||
|
||||
def test_should_fallback_to_system_credentials_when_tenant_config_missing(self, service, mock_db_session):
|
||||
mock_db_session.query().first.side_effect = [None, MagicMock(system_credentials={"k": "sys"})]
|
||||
mock_db_session.scalar.side_effect = [None, MagicMock(system_credentials={"k": "sys"})]
|
||||
with (
|
||||
patch.object(service.provider_manager, "fetch_datasource_provider"),
|
||||
patch("services.datasource_provider_service.PluginService.is_plugin_verified", return_value=True),
|
||||
@@ -444,7 +442,7 @@ class TestDatasourceProviderService:
|
||||
|
||||
def test_should_raise_value_error_when_no_oauth_config_available(self, service, mock_db_session):
|
||||
"""Neither tenant nor system credentials → raises ValueError."""
|
||||
mock_db_session.query().first.side_effect = [None, None]
|
||||
mock_db_session.scalar.side_effect = [None, None]
|
||||
with (
|
||||
patch.object(service.provider_manager, "fetch_datasource_provider"),
|
||||
patch("services.datasource_provider_service.PluginService.is_plugin_verified", return_value=False),
|
||||
@@ -457,15 +455,14 @@ class TestDatasourceProviderService:
|
||||
# -----------------------------------------------------------------------
|
||||
|
||||
def test_should_add_oauth_provider_successfully_when_name_is_unique(self, service, mock_db_session):
|
||||
mock_db_session.query().count.return_value = 0
|
||||
mock_db_session.scalar.return_value = 0
|
||||
with patch.object(service, "extract_secret_variables", return_value=[]):
|
||||
service.add_datasource_oauth_provider("new", "t1", make_id(), "http://cb", 9999, {})
|
||||
mock_db_session.add.assert_called_once()
|
||||
|
||||
def test_should_auto_rename_when_oauth_provider_name_conflicts(self, service, mock_db_session):
|
||||
"""Conflict on name results in auto-incremented name, not an error."""
|
||||
mock_db_session.query().count.return_value = 1 # conflict first, then auto-named
|
||||
mock_db_session.query().all.return_value = []
|
||||
mock_db_session.scalar.return_value = 1 # conflict first, then auto-named
|
||||
with (
|
||||
patch.object(service, "extract_secret_variables", return_value=[]),
|
||||
patch.object(service, "generate_next_datasource_provider_name", return_value="new_gen"),
|
||||
@@ -475,8 +472,7 @@ class TestDatasourceProviderService:
|
||||
|
||||
def test_should_auto_generate_name_when_none_provided_for_oauth(self, service, mock_db_session):
|
||||
"""name=None causes auto-generation via generate_next_datasource_provider_name."""
|
||||
mock_db_session.query().count.return_value = 0
|
||||
mock_db_session.query().all.return_value = []
|
||||
mock_db_session.scalar.return_value = 0
|
||||
with (
|
||||
patch.object(service, "extract_secret_variables", return_value=[]),
|
||||
patch.object(service, "generate_next_datasource_provider_name", return_value="auto"),
|
||||
@@ -485,13 +481,13 @@ class TestDatasourceProviderService:
|
||||
mock_db_session.add.assert_called_once()
|
||||
|
||||
def test_should_encrypt_secret_fields_when_adding_oauth_provider(self, service, mock_db_session):
|
||||
mock_db_session.query().count.return_value = 0
|
||||
mock_db_session.scalar.return_value = 0
|
||||
with patch.object(service, "extract_secret_variables", return_value=["secret_key"]):
|
||||
service.add_datasource_oauth_provider("nm", "t1", make_id(), "http://cb", 9999, {"secret_key": "value"})
|
||||
self._enc.encrypt_token.assert_called()
|
||||
|
||||
def test_should_acquire_redis_lock_when_adding_oauth_provider(self, service, mock_db_session):
|
||||
mock_db_session.query().count.return_value = 0
|
||||
mock_db_session.scalar.return_value = 0
|
||||
with patch.object(service, "extract_secret_variables", return_value=[]):
|
||||
service.add_datasource_oauth_provider("nm", "t1", make_id(), "http://cb", 9999, {})
|
||||
self._redis.lock.assert_called()
|
||||
@@ -501,23 +497,21 @@ class TestDatasourceProviderService:
|
||||
# -----------------------------------------------------------------------
|
||||
|
||||
def test_should_raise_value_error_when_credential_id_not_found_on_reauth(self, service, mock_db_session):
|
||||
mock_db_session.query().first.return_value = None
|
||||
mock_db_session.scalar.return_value = None
|
||||
with patch.object(service, "extract_secret_variables", return_value=[]):
|
||||
with pytest.raises(ValueError, match="not found"):
|
||||
service.reauthorize_datasource_oauth_provider("n", "t1", make_id(), "u", 1, {}, "bad-id")
|
||||
|
||||
def test_should_reauthorize_and_commit_when_credential_found(self, service, mock_db_session):
|
||||
p = MagicMock(spec=DatasourceProvider)
|
||||
mock_db_session.query().first.return_value = p
|
||||
mock_db_session.query().count.return_value = 0
|
||||
mock_db_session.scalar.side_effect = [p, 0] # first: fetch provider, second: name conflict count
|
||||
with patch.object(service, "extract_secret_variables", return_value=[]):
|
||||
service.reauthorize_datasource_oauth_provider("n", "t1", make_id(), "u", 1, {}, "oid")
|
||||
|
||||
def test_should_auto_rename_when_reauth_name_conflicts(self, service, mock_db_session):
|
||||
p = MagicMock(spec=DatasourceProvider)
|
||||
mock_db_session.query().first.return_value = p
|
||||
mock_db_session.query().count.return_value = 1 # conflict
|
||||
mock_db_session.query().all.return_value = []
|
||||
mock_db_session.scalar.side_effect = [p, 1] # first: fetch provider, second: name conflict count
|
||||
mock_db_session.scalars.return_value.all.return_value = []
|
||||
with patch.object(service, "extract_secret_variables", return_value=["tok"]):
|
||||
service.reauthorize_datasource_oauth_provider(
|
||||
"conflict_name", "t1", make_id(), "u", 9999, {"tok": "v"}, "cred-id"
|
||||
@@ -525,16 +519,14 @@ class TestDatasourceProviderService:
|
||||
|
||||
def test_should_encrypt_secret_fields_when_reauthorizing(self, service, mock_db_session):
|
||||
p = MagicMock(spec=DatasourceProvider)
|
||||
mock_db_session.query().first.return_value = p
|
||||
mock_db_session.query().count.return_value = 0
|
||||
mock_db_session.scalar.side_effect = [p, 0] # first: fetch provider, second: name conflict count
|
||||
with patch.object(service, "extract_secret_variables", return_value=["tok"]):
|
||||
service.reauthorize_datasource_oauth_provider(None, "t1", make_id(), "u", 9999, {"tok": "val"}, "cred-id")
|
||||
self._enc.encrypt_token.assert_called()
|
||||
|
||||
def test_should_acquire_redis_lock_when_reauthorizing(self, service, mock_db_session):
|
||||
p = MagicMock(spec=DatasourceProvider)
|
||||
mock_db_session.query().first.return_value = p
|
||||
mock_db_session.query().count.return_value = 0
|
||||
mock_db_session.scalar.side_effect = [p, 0] # first: fetch provider, second: name conflict count
|
||||
with patch.object(service, "extract_secret_variables", return_value=[]):
|
||||
service.reauthorize_datasource_oauth_provider("n", "t1", make_id(), "u", 1, {}, "oid")
|
||||
self._redis.lock.assert_called()
|
||||
@@ -545,13 +537,13 @@ class TestDatasourceProviderService:
|
||||
|
||||
def test_should_raise_value_error_when_api_key_name_already_exists(self, service, mock_db_session, mock_user):
|
||||
"""explicit name supplied + conflict → raises ValueError immediately."""
|
||||
mock_db_session.query().count.return_value = 1
|
||||
mock_db_session.scalar.return_value = 1
|
||||
with patch("services.datasource_provider_service.get_current_user", return_value=mock_user):
|
||||
with pytest.raises(ValueError, match="already exists"):
|
||||
service.add_datasource_api_key_provider("clash", "t1", make_id(), {"sk": "v"})
|
||||
|
||||
def test_should_raise_value_error_when_credentials_validation_fails(self, service, mock_db_session, mock_user):
|
||||
mock_db_session.query().count.return_value = 0
|
||||
mock_db_session.scalar.return_value = 0
|
||||
with (
|
||||
patch("services.datasource_provider_service.get_current_user", return_value=mock_user),
|
||||
patch.object(service.provider_manager, "validate_provider_credentials", side_effect=Exception("bad cred")),
|
||||
@@ -561,7 +553,7 @@ class TestDatasourceProviderService:
|
||||
service.add_datasource_api_key_provider("nm", "t1", make_id(), {"k": "v"})
|
||||
|
||||
def test_should_add_api_key_provider_and_commit_when_valid(self, service, mock_db_session, mock_user):
|
||||
mock_db_session.query().count.return_value = 0
|
||||
mock_db_session.scalar.return_value = 0
|
||||
with (
|
||||
patch("services.datasource_provider_service.get_current_user", return_value=mock_user),
|
||||
patch.object(service.provider_manager, "validate_provider_credentials"),
|
||||
@@ -571,7 +563,7 @@ class TestDatasourceProviderService:
|
||||
mock_db_session.add.assert_called_once()
|
||||
|
||||
def test_should_acquire_redis_lock_when_adding_api_key_provider(self, service, mock_db_session, mock_user):
|
||||
mock_db_session.query().count.return_value = 0
|
||||
mock_db_session.scalar.return_value = 0
|
||||
with (
|
||||
patch("services.datasource_provider_service.get_current_user", return_value=mock_user),
|
||||
patch.object(service.provider_manager, "validate_provider_credentials"),
|
||||
@@ -694,7 +686,7 @@ class TestDatasourceProviderService:
|
||||
# -----------------------------------------------------------------------
|
||||
|
||||
def test_should_raise_value_error_when_credential_not_found_on_update(self, service, mock_db_session, mock_user):
|
||||
mock_db_session.query().first.return_value = None
|
||||
mock_db_session.scalar.return_value = None
|
||||
with patch("services.datasource_provider_service.get_current_user", return_value=mock_user):
|
||||
with pytest.raises(ValueError, match="not found"):
|
||||
service.update_datasource_credentials("t1", "id", "prov", "org/plug", {}, "name")
|
||||
@@ -704,8 +696,7 @@ class TestDatasourceProviderService:
|
||||
p.name = "old_name"
|
||||
p.auth_type = "api_key"
|
||||
p.encrypted_credentials = {"sk": "e"}
|
||||
mock_db_session.query().first.return_value = p
|
||||
mock_db_session.query().count.return_value = 1
|
||||
mock_db_session.scalar.side_effect = [p, 1] # first: fetch provider, second: name conflict count
|
||||
with patch("services.datasource_provider_service.get_current_user", return_value=mock_user):
|
||||
with pytest.raises(ValueError, match="already exists"):
|
||||
service.update_datasource_credentials("t1", "id", "prov", "org/plug", {}, "new_name")
|
||||
@@ -717,8 +708,7 @@ class TestDatasourceProviderService:
|
||||
p.name = "old_name"
|
||||
p.auth_type = "api_key"
|
||||
p.encrypted_credentials = {"sk": "e"}
|
||||
mock_db_session.query().first.return_value = p
|
||||
mock_db_session.query().count.return_value = 0
|
||||
mock_db_session.scalar.side_effect = [p, 0] # first: fetch provider, second: name conflict count
|
||||
with (
|
||||
patch("services.datasource_provider_service.get_current_user", return_value=mock_user),
|
||||
patch.object(service, "extract_secret_variables", return_value=["sk"]),
|
||||
@@ -733,8 +723,7 @@ class TestDatasourceProviderService:
|
||||
p.name = "old_name"
|
||||
p.auth_type = "api_key"
|
||||
p.encrypted_credentials = {"sk": "old_enc"}
|
||||
mock_db_session.query().first.return_value = p
|
||||
mock_db_session.query().count.return_value = 0
|
||||
mock_db_session.scalar.side_effect = [p, 0] # first: fetch provider, second: name conflict count
|
||||
with (
|
||||
patch("services.datasource_provider_service.get_current_user", return_value=mock_user),
|
||||
patch.object(service, "extract_secret_variables", return_value=["sk"]),
|
||||
|
||||
Reference in New Issue
Block a user