refactor(services): migrate datasource_provider_service to SQLAlchemy 2.0 select() API (#34974)

This commit is contained in:
wdeveloper16
2026-04-12 03:23:24 +02:00
committed by GitHub
parent 7515eee0a8
commit 7bd5e80323
2 changed files with 205 additions and 170 deletions

View File

@@ -4,7 +4,7 @@ from collections.abc import Mapping
from typing import Any
from graphon.model_runtime.entities.provider_entities import FormType
from sqlalchemy import func, select
from sqlalchemy import delete, func, select, update
from sqlalchemy.orm import Session, sessionmaker
from configs import dify_config
@@ -54,11 +54,13 @@ class DatasourceProviderService:
remove oauth custom client params
"""
with sessionmaker(bind=db.engine).begin() as session:
session.query(DatasourceOauthTenantParamConfig).filter_by(
tenant_id=tenant_id,
provider=datasource_provider_id.provider_name,
plugin_id=datasource_provider_id.plugin_id,
).delete()
session.execute(
delete(DatasourceOauthTenantParamConfig).where(
DatasourceOauthTenantParamConfig.tenant_id == tenant_id,
DatasourceOauthTenantParamConfig.provider == datasource_provider_id.provider_name,
DatasourceOauthTenantParamConfig.plugin_id == datasource_provider_id.plugin_id,
)
)
def decrypt_datasource_provider_credentials(
self,
@@ -110,15 +112,21 @@ class DatasourceProviderService:
"""
with sessionmaker(bind=db.engine).begin() as session:
if credential_id:
datasource_provider = (
session.query(DatasourceProvider).filter_by(tenant_id=tenant_id, id=credential_id).first()
datasource_provider = session.scalar(
select(DatasourceProvider)
.where(DatasourceProvider.tenant_id == tenant_id, DatasourceProvider.id == credential_id)
.limit(1)
)
else:
datasource_provider = (
session.query(DatasourceProvider)
.filter_by(tenant_id=tenant_id, provider=provider, plugin_id=plugin_id)
datasource_provider = session.scalar(
select(DatasourceProvider)
.where(
DatasourceProvider.tenant_id == tenant_id,
DatasourceProvider.provider == provider,
DatasourceProvider.plugin_id == plugin_id,
)
.order_by(DatasourceProvider.is_default.desc(), DatasourceProvider.created_at.asc())
.first()
.limit(1)
)
if not datasource_provider:
return {}
@@ -173,12 +181,15 @@ class DatasourceProviderService:
get all datasource credentials by provider
"""
with sessionmaker(bind=db.engine).begin() as session:
datasource_providers = (
session.query(DatasourceProvider)
.filter_by(tenant_id=tenant_id, provider=provider, plugin_id=plugin_id)
datasource_providers = session.scalars(
select(DatasourceProvider)
.where(
DatasourceProvider.tenant_id == tenant_id,
DatasourceProvider.provider == provider,
DatasourceProvider.plugin_id == plugin_id,
)
.order_by(DatasourceProvider.is_default.desc(), DatasourceProvider.created_at.asc())
.all()
)
).all()
if not datasource_providers:
return []
current_user = get_current_user()
@@ -232,15 +243,15 @@ class DatasourceProviderService:
update datasource provider name
"""
with sessionmaker(bind=db.engine).begin() as session:
target_provider = (
session.query(DatasourceProvider)
.filter_by(
tenant_id=tenant_id,
id=credential_id,
provider=datasource_provider_id.provider_name,
plugin_id=datasource_provider_id.plugin_id,
target_provider = session.scalar(
select(DatasourceProvider)
.where(
DatasourceProvider.tenant_id == tenant_id,
DatasourceProvider.id == credential_id,
DatasourceProvider.provider == datasource_provider_id.provider_name,
DatasourceProvider.plugin_id == datasource_provider_id.plugin_id,
)
.first()
.limit(1)
)
if target_provider is None:
raise ValueError("provider not found")
@@ -250,16 +261,16 @@ class DatasourceProviderService:
# check name is exist
if (
session.query(DatasourceProvider)
.filter_by(
tenant_id=tenant_id,
name=name,
provider=datasource_provider_id.provider_name,
plugin_id=datasource_provider_id.plugin_id,
session.scalar(
select(func.count(DatasourceProvider.id)).where(
DatasourceProvider.tenant_id == tenant_id,
DatasourceProvider.name == name,
DatasourceProvider.provider == datasource_provider_id.provider_name,
DatasourceProvider.plugin_id == datasource_provider_id.plugin_id,
)
)
.count()
> 0
):
or 0
) > 0:
raise ValueError("Authorization name is already exists")
target_provider.name = name
@@ -273,26 +284,31 @@ class DatasourceProviderService:
"""
with sessionmaker(bind=db.engine).begin() as session:
# get provider
target_provider = (
session.query(DatasourceProvider)
.filter_by(
tenant_id=tenant_id,
id=credential_id,
provider=datasource_provider_id.provider_name,
plugin_id=datasource_provider_id.plugin_id,
target_provider = session.scalar(
select(DatasourceProvider)
.where(
DatasourceProvider.tenant_id == tenant_id,
DatasourceProvider.id == credential_id,
DatasourceProvider.provider == datasource_provider_id.provider_name,
DatasourceProvider.plugin_id == datasource_provider_id.plugin_id,
)
.first()
.limit(1)
)
if target_provider is None:
raise ValueError("provider not found")
# clear default provider
session.query(DatasourceProvider).filter_by(
tenant_id=tenant_id,
provider=target_provider.provider,
plugin_id=target_provider.plugin_id,
is_default=True,
).update({"is_default": False})
session.execute(
update(DatasourceProvider)
.where(
DatasourceProvider.tenant_id == tenant_id,
DatasourceProvider.provider == target_provider.provider,
DatasourceProvider.plugin_id == target_provider.plugin_id,
DatasourceProvider.is_default.is_(True),
)
.values(is_default=False)
.execution_options(synchronize_session=False)
)
# set new default provider
target_provider.is_default = True
@@ -311,14 +327,14 @@ class DatasourceProviderService:
if client_params is None and enabled is None:
return
with sessionmaker(bind=db.engine).begin() as session:
tenant_oauth_client_params = (
session.query(DatasourceOauthTenantParamConfig)
.filter_by(
tenant_id=tenant_id,
provider=datasource_provider_id.provider_name,
plugin_id=datasource_provider_id.plugin_id,
tenant_oauth_client_params = session.scalar(
select(DatasourceOauthTenantParamConfig)
.where(
DatasourceOauthTenantParamConfig.tenant_id == tenant_id,
DatasourceOauthTenantParamConfig.provider == datasource_provider_id.provider_name,
DatasourceOauthTenantParamConfig.plugin_id == datasource_provider_id.plugin_id,
)
.first()
.limit(1)
)
if not tenant_oauth_client_params:
@@ -351,9 +367,14 @@ class DatasourceProviderService:
"""
with Session(db.engine).no_autoflush as session:
return (
session.query(DatasourceOauthParamConfig)
.filter_by(provider=datasource_provider_id.provider_name, plugin_id=datasource_provider_id.plugin_id)
.first()
session.scalar(
select(DatasourceOauthParamConfig)
.where(
DatasourceOauthParamConfig.provider == datasource_provider_id.provider_name,
DatasourceOauthParamConfig.plugin_id == datasource_provider_id.plugin_id,
)
.limit(1)
)
is not None
)
@@ -423,15 +444,15 @@ class DatasourceProviderService:
plugin_id = datasource_provider_id.plugin_id
with Session(db.engine).no_autoflush as session:
# get tenant oauth client params
tenant_oauth_client_params = (
session.query(DatasourceOauthTenantParamConfig)
.filter_by(
tenant_id=tenant_id,
provider=provider,
plugin_id=plugin_id,
enabled=True,
tenant_oauth_client_params = session.scalar(
select(DatasourceOauthTenantParamConfig)
.where(
DatasourceOauthTenantParamConfig.tenant_id == tenant_id,
DatasourceOauthTenantParamConfig.provider == provider,
DatasourceOauthTenantParamConfig.plugin_id == plugin_id,
DatasourceOauthTenantParamConfig.enabled.is_(True),
)
.first()
.limit(1)
)
if tenant_oauth_client_params:
encrypter, _ = self.get_oauth_encrypter(tenant_id, datasource_provider_id)
@@ -443,8 +464,13 @@ class DatasourceProviderService:
is_verified = PluginService.is_plugin_verified(tenant_id, provider_controller.plugin_unique_identifier)
if is_verified:
# fallback to system oauth client params
oauth_client_params = (
session.query(DatasourceOauthParamConfig).filter_by(provider=provider, plugin_id=plugin_id).first()
oauth_client_params = session.scalar(
select(DatasourceOauthParamConfig)
.where(
DatasourceOauthParamConfig.provider == provider,
DatasourceOauthParamConfig.plugin_id == plugin_id,
)
.limit(1)
)
if oauth_client_params:
return oauth_client_params.system_credentials
@@ -455,15 +481,13 @@ class DatasourceProviderService:
def generate_next_datasource_provider_name(
session: Session, tenant_id: str, provider_id: DatasourceProviderID, credential_type: CredentialType
) -> str:
db_providers = (
session.query(DatasourceProvider)
.filter_by(
tenant_id=tenant_id,
provider=provider_id.provider_name,
plugin_id=provider_id.plugin_id,
db_providers = session.scalars(
select(DatasourceProvider).where(
DatasourceProvider.tenant_id == tenant_id,
DatasourceProvider.provider == provider_id.provider_name,
DatasourceProvider.plugin_id == provider_id.plugin_id,
)
.all()
)
).all()
return generate_incremental_name(
[provider.name for provider in db_providers],
f"{credential_type.get_name()}",
@@ -485,8 +509,10 @@ class DatasourceProviderService:
with sessionmaker(bind=db.engine).begin() as session:
lock = f"datasource_provider_create_lock:{tenant_id}_{provider_id}_{CredentialType.OAUTH2.value}"
with redis_client.lock(lock, timeout=20):
target_provider = (
session.query(DatasourceProvider).filter_by(id=credential_id, tenant_id=tenant_id).first()
target_provider = session.scalar(
select(DatasourceProvider)
.where(DatasourceProvider.id == credential_id, DatasourceProvider.tenant_id == tenant_id)
.limit(1)
)
if target_provider is None:
raise ValueError("provider not found")
@@ -496,25 +522,28 @@ class DatasourceProviderService:
db_provider_name = target_provider.name
else:
name_conflict = (
session.query(DatasourceProvider)
.filter_by(
tenant_id=tenant_id,
name=db_provider_name,
provider=provider_id.provider_name,
plugin_id=provider_id.plugin_id,
auth_type=CredentialType.OAUTH2.value,
session.scalar(
select(func.count(DatasourceProvider.id)).where(
DatasourceProvider.tenant_id == tenant_id,
DatasourceProvider.name == db_provider_name,
DatasourceProvider.provider == provider_id.provider_name,
DatasourceProvider.plugin_id == provider_id.plugin_id,
DatasourceProvider.auth_type == CredentialType.OAUTH2.value,
)
)
.count()
or 0
)
if name_conflict > 0:
db_provider_name = generate_incremental_name(
[
provider.name
for provider in session.query(DatasourceProvider).filter_by(
tenant_id=tenant_id,
provider=provider_id.provider_name,
plugin_id=provider_id.plugin_id,
)
for provider in session.scalars(
select(DatasourceProvider).where(
DatasourceProvider.tenant_id == tenant_id,
DatasourceProvider.provider == provider_id.provider_name,
DatasourceProvider.plugin_id == provider_id.plugin_id,
)
).all()
],
db_provider_name,
)
@@ -556,25 +585,27 @@ class DatasourceProviderService:
)
else:
if (
session.query(DatasourceProvider)
.filter_by(
tenant_id=tenant_id,
name=db_provider_name,
provider=provider_id.provider_name,
plugin_id=provider_id.plugin_id,
auth_type=credential_type.value,
session.scalar(
select(func.count(DatasourceProvider.id)).where(
DatasourceProvider.tenant_id == tenant_id,
DatasourceProvider.name == db_provider_name,
DatasourceProvider.provider == provider_id.provider_name,
DatasourceProvider.plugin_id == provider_id.plugin_id,
DatasourceProvider.auth_type == credential_type.value,
)
)
.count()
> 0
):
or 0
) > 0:
db_provider_name = generate_incremental_name(
[
provider.name
for provider in session.query(DatasourceProvider).filter_by(
tenant_id=tenant_id,
provider=provider_id.provider_name,
plugin_id=provider_id.plugin_id,
)
for provider in session.scalars(
select(DatasourceProvider).where(
DatasourceProvider.tenant_id == tenant_id,
DatasourceProvider.provider == provider_id.provider_name,
DatasourceProvider.plugin_id == provider_id.plugin_id,
)
).all()
],
db_provider_name,
)
@@ -627,11 +658,16 @@ class DatasourceProviderService:
# check name is exist
if (
session.query(DatasourceProvider)
.filter_by(tenant_id=tenant_id, plugin_id=plugin_id, provider=provider_name, name=db_provider_name)
.count()
> 0
):
session.scalar(
select(func.count(DatasourceProvider.id)).where(
DatasourceProvider.tenant_id == tenant_id,
DatasourceProvider.plugin_id == plugin_id,
DatasourceProvider.provider == provider_name,
DatasourceProvider.name == db_provider_name,
)
)
or 0
) > 0:
raise ValueError("Authorization name is already exists")
try:
@@ -918,21 +954,31 @@ class DatasourceProviderService:
"""
with sessionmaker(bind=db.engine).begin() as session:
datasource_provider = (
session.query(DatasourceProvider)
.filter_by(tenant_id=tenant_id, id=auth_id, provider=provider, plugin_id=plugin_id)
.first()
datasource_provider = session.scalar(
select(DatasourceProvider)
.where(
DatasourceProvider.tenant_id == tenant_id,
DatasourceProvider.id == auth_id,
DatasourceProvider.provider == provider,
DatasourceProvider.plugin_id == plugin_id,
)
.limit(1)
)
if not datasource_provider:
raise ValueError("Datasource provider not found")
# update name
if name and name != datasource_provider.name:
if (
session.query(DatasourceProvider)
.filter_by(tenant_id=tenant_id, name=name, provider=provider, plugin_id=plugin_id)
.count()
> 0
):
session.scalar(
select(func.count(DatasourceProvider.id)).where(
DatasourceProvider.tenant_id == tenant_id,
DatasourceProvider.name == name,
DatasourceProvider.provider == provider,
DatasourceProvider.plugin_id == plugin_id,
)
)
or 0
) > 0:
raise ValueError("Authorization name is already exists")
datasource_provider.name = name

View File

@@ -179,11 +179,11 @@ class TestDatasourceProviderService:
# -----------------------------------------------------------------------
def test_should_return_true_when_system_oauth_params_exist(self, service, mock_db_session):
mock_db_session.query().first.return_value = MagicMock()
mock_db_session.scalar.return_value = MagicMock()
assert service.is_system_oauth_params_exist(make_id()) is True
def test_should_return_false_when_system_oauth_params_missing(self, service, mock_db_session):
mock_db_session.query().first.return_value = None
mock_db_session.scalar.return_value = None
assert service.is_system_oauth_params_exist(make_id()) is False
# -----------------------------------------------------------------------
@@ -205,7 +205,7 @@ class TestDatasourceProviderService:
def test_should_delete_tenant_config_when_removing_oauth_params(self, service, mock_db_session):
service.remove_oauth_custom_client_params("t1", make_id())
mock_db_session.query().delete.assert_called_once()
mock_db_session.execute.assert_called_once()
# -----------------------------------------------------------------------
# setup_oauth_custom_client_params (315-351)
@@ -217,14 +217,14 @@ class TestDatasourceProviderService:
mock_db_session.add.assert_not_called()
def test_should_create_new_config_when_none_exists(self, service, mock_db_session):
mock_db_session.query().first.return_value = None
mock_db_session.scalar.return_value = None
with patch.object(service, "get_oauth_encrypter", return_value=(self._enc, None)):
service.setup_oauth_custom_client_params("t1", make_id(), {"k": "v"}, True)
mock_db_session.add.assert_called_once()
def test_should_update_existing_config_when_record_found(self, service, mock_db_session):
existing = MagicMock()
mock_db_session.query().first.return_value = existing
mock_db_session.scalar.return_value = existing
with patch.object(service, "get_oauth_encrypter", return_value=(self._enc, None)):
service.setup_oauth_custom_client_params("t1", make_id(), {"k": "v"}, False)
mock_db_session.add.assert_not_called() # update in place, no add
@@ -255,7 +255,7 @@ class TestDatasourceProviderService:
def test_should_return_empty_dict_when_credential_not_found(self, service, mock_db_session, mock_user):
with patch("services.datasource_provider_service.get_current_user", return_value=mock_user):
mock_db_session.query().first.return_value = None
mock_db_session.scalar.return_value = None
assert service.get_datasource_credentials("t1", "prov", "org/plug") == {}
def test_should_refresh_oauth_tokens_when_expired(self, service, mock_db_session, mock_user):
@@ -264,7 +264,7 @@ class TestDatasourceProviderService:
p.auth_type = "oauth2"
p.expires_at = 0 # expired
p.encrypted_credentials = {"tok": "x"}
mock_db_session.query().first.return_value = p
mock_db_session.scalar.return_value = p
with (
patch("services.datasource_provider_service.get_current_user", return_value=mock_user),
patch.object(service, "get_oauth_client", return_value={"oc": "v"}),
@@ -278,7 +278,7 @@ class TestDatasourceProviderService:
p.auth_type = "api_key"
p.expires_at = -1 # sentinel: never expires
p.encrypted_credentials = {"k": "v"}
mock_db_session.query().first.return_value = p
mock_db_session.scalar.return_value = p
with (
patch("services.datasource_provider_service.get_current_user", return_value=mock_user),
patch.object(service, "decrypt_datasource_provider_credentials", return_value={"k": "plain"}),
@@ -292,7 +292,7 @@ class TestDatasourceProviderService:
p.auth_type = "api_key"
p.expires_at = -1
p.encrypted_credentials = {}
mock_db_session.query().first.return_value = p
mock_db_session.scalar.return_value = p
with (
patch("services.datasource_provider_service.get_current_user", return_value=mock_user),
patch.object(service, "decrypt_datasource_provider_credentials", return_value={"k": "v"}),
@@ -306,7 +306,7 @@ class TestDatasourceProviderService:
def test_should_return_empty_list_when_no_provider_credentials_exist(self, service, mock_db_session, mock_user):
with patch("services.datasource_provider_service.get_current_user", return_value=mock_user):
mock_db_session.query().all.return_value = []
mock_db_session.scalars.return_value.all.return_value = []
assert service.get_all_datasource_credentials_by_provider("t1", "prov", "org/plug") == []
def test_should_refresh_and_return_credentials_when_oauth_expired(self, service, mock_db_session, mock_user):
@@ -314,7 +314,7 @@ class TestDatasourceProviderService:
p.auth_type = "oauth2"
p.expires_at = 0
p.encrypted_credentials = {"t": "x"}
mock_db_session.query().all.return_value = [p]
mock_db_session.scalars.return_value.all.return_value = [p]
with (
patch("services.datasource_provider_service.get_current_user", return_value=mock_user),
patch.object(service, "get_oauth_client", return_value={"oc": "v"}),
@@ -328,22 +328,21 @@ class TestDatasourceProviderService:
# -----------------------------------------------------------------------
def test_should_raise_value_error_when_provider_not_found_on_name_update(self, service, mock_db_session):
mock_db_session.query().first.return_value = None
mock_db_session.scalar.return_value = None
with pytest.raises(ValueError, match="not found"):
service.update_datasource_provider_name("t1", make_id(), "new", "cred-id")
def test_should_return_early_when_new_name_matches_current(self, service, mock_db_session):
p = MagicMock(spec=DatasourceProvider)
p.name = "same"
mock_db_session.query().first.return_value = p
mock_db_session.scalar.return_value = p
service.update_datasource_provider_name("t1", make_id(), "same", "cred-id")
def test_should_raise_value_error_when_name_already_exists(self, service, mock_db_session):
p = MagicMock(spec=DatasourceProvider)
p.name = "old_name"
p.is_default = False
mock_db_session.query().first.return_value = p
mock_db_session.query().count.return_value = 1 # conflict
mock_db_session.scalar.side_effect = [p, 1] # first: fetch provider, second: name conflict count
with pytest.raises(ValueError, match="already exists"):
service.update_datasource_provider_name("t1", make_id(), "new_name", "some-id")
@@ -351,8 +350,7 @@ class TestDatasourceProviderService:
p = MagicMock(spec=DatasourceProvider)
p.name = "old_name"
p.is_default = False
mock_db_session.query().first.return_value = p
mock_db_session.query().count.return_value = 0
mock_db_session.scalar.side_effect = [p, 0] # first: fetch provider, second: name conflict count
service.update_datasource_provider_name("t1", make_id(), "new_name", "some-id")
assert p.name == "new_name"
@@ -361,7 +359,7 @@ class TestDatasourceProviderService:
# -----------------------------------------------------------------------
def test_should_raise_value_error_when_target_provider_not_found(self, service, mock_db_session):
mock_db_session.query().first.return_value = None
mock_db_session.scalar.return_value = None
with pytest.raises(ValueError, match="not found"):
service.set_default_datasource_provider("t1", make_id(), "bad-id")
@@ -369,7 +367,7 @@ class TestDatasourceProviderService:
target = MagicMock(spec=DatasourceProvider)
target.provider = "provider"
target.plugin_id = "org/plug"
mock_db_session.query().first.return_value = target
mock_db_session.scalar.return_value = target
service.set_default_datasource_provider("t1", make_id(), "new-id")
assert target.is_default is True
@@ -428,13 +426,13 @@ class TestDatasourceProviderService:
# -----------------------------------------------------------------------
def test_should_use_tenant_config_when_available(self, service, mock_db_session):
mock_db_session.query().first.return_value = MagicMock(client_params={"k": "v"})
mock_db_session.scalar.return_value = MagicMock(client_params={"k": "v"})
with patch.object(service, "get_oauth_encrypter", return_value=(self._enc, None)):
result = service.get_oauth_client("t1", make_id())
assert result == {"k": "dec"}
def test_should_fallback_to_system_credentials_when_tenant_config_missing(self, service, mock_db_session):
mock_db_session.query().first.side_effect = [None, MagicMock(system_credentials={"k": "sys"})]
mock_db_session.scalar.side_effect = [None, MagicMock(system_credentials={"k": "sys"})]
with (
patch.object(service.provider_manager, "fetch_datasource_provider"),
patch("services.datasource_provider_service.PluginService.is_plugin_verified", return_value=True),
@@ -444,7 +442,7 @@ class TestDatasourceProviderService:
def test_should_raise_value_error_when_no_oauth_config_available(self, service, mock_db_session):
"""Neither tenant nor system credentials → raises ValueError."""
mock_db_session.query().first.side_effect = [None, None]
mock_db_session.scalar.side_effect = [None, None]
with (
patch.object(service.provider_manager, "fetch_datasource_provider"),
patch("services.datasource_provider_service.PluginService.is_plugin_verified", return_value=False),
@@ -457,15 +455,14 @@ class TestDatasourceProviderService:
# -----------------------------------------------------------------------
def test_should_add_oauth_provider_successfully_when_name_is_unique(self, service, mock_db_session):
mock_db_session.query().count.return_value = 0
mock_db_session.scalar.return_value = 0
with patch.object(service, "extract_secret_variables", return_value=[]):
service.add_datasource_oauth_provider("new", "t1", make_id(), "http://cb", 9999, {})
mock_db_session.add.assert_called_once()
def test_should_auto_rename_when_oauth_provider_name_conflicts(self, service, mock_db_session):
"""Conflict on name results in auto-incremented name, not an error."""
mock_db_session.query().count.return_value = 1 # conflict first, then auto-named
mock_db_session.query().all.return_value = []
mock_db_session.scalar.return_value = 1 # conflict first, then auto-named
with (
patch.object(service, "extract_secret_variables", return_value=[]),
patch.object(service, "generate_next_datasource_provider_name", return_value="new_gen"),
@@ -475,8 +472,7 @@ class TestDatasourceProviderService:
def test_should_auto_generate_name_when_none_provided_for_oauth(self, service, mock_db_session):
"""name=None causes auto-generation via generate_next_datasource_provider_name."""
mock_db_session.query().count.return_value = 0
mock_db_session.query().all.return_value = []
mock_db_session.scalar.return_value = 0
with (
patch.object(service, "extract_secret_variables", return_value=[]),
patch.object(service, "generate_next_datasource_provider_name", return_value="auto"),
@@ -485,13 +481,13 @@ class TestDatasourceProviderService:
mock_db_session.add.assert_called_once()
def test_should_encrypt_secret_fields_when_adding_oauth_provider(self, service, mock_db_session):
mock_db_session.query().count.return_value = 0
mock_db_session.scalar.return_value = 0
with patch.object(service, "extract_secret_variables", return_value=["secret_key"]):
service.add_datasource_oauth_provider("nm", "t1", make_id(), "http://cb", 9999, {"secret_key": "value"})
self._enc.encrypt_token.assert_called()
def test_should_acquire_redis_lock_when_adding_oauth_provider(self, service, mock_db_session):
mock_db_session.query().count.return_value = 0
mock_db_session.scalar.return_value = 0
with patch.object(service, "extract_secret_variables", return_value=[]):
service.add_datasource_oauth_provider("nm", "t1", make_id(), "http://cb", 9999, {})
self._redis.lock.assert_called()
@@ -501,23 +497,21 @@ class TestDatasourceProviderService:
# -----------------------------------------------------------------------
def test_should_raise_value_error_when_credential_id_not_found_on_reauth(self, service, mock_db_session):
mock_db_session.query().first.return_value = None
mock_db_session.scalar.return_value = None
with patch.object(service, "extract_secret_variables", return_value=[]):
with pytest.raises(ValueError, match="not found"):
service.reauthorize_datasource_oauth_provider("n", "t1", make_id(), "u", 1, {}, "bad-id")
def test_should_reauthorize_and_commit_when_credential_found(self, service, mock_db_session):
p = MagicMock(spec=DatasourceProvider)
mock_db_session.query().first.return_value = p
mock_db_session.query().count.return_value = 0
mock_db_session.scalar.side_effect = [p, 0] # first: fetch provider, second: name conflict count
with patch.object(service, "extract_secret_variables", return_value=[]):
service.reauthorize_datasource_oauth_provider("n", "t1", make_id(), "u", 1, {}, "oid")
def test_should_auto_rename_when_reauth_name_conflicts(self, service, mock_db_session):
p = MagicMock(spec=DatasourceProvider)
mock_db_session.query().first.return_value = p
mock_db_session.query().count.return_value = 1 # conflict
mock_db_session.query().all.return_value = []
mock_db_session.scalar.side_effect = [p, 1] # first: fetch provider, second: name conflict count
mock_db_session.scalars.return_value.all.return_value = []
with patch.object(service, "extract_secret_variables", return_value=["tok"]):
service.reauthorize_datasource_oauth_provider(
"conflict_name", "t1", make_id(), "u", 9999, {"tok": "v"}, "cred-id"
@@ -525,16 +519,14 @@ class TestDatasourceProviderService:
def test_should_encrypt_secret_fields_when_reauthorizing(self, service, mock_db_session):
p = MagicMock(spec=DatasourceProvider)
mock_db_session.query().first.return_value = p
mock_db_session.query().count.return_value = 0
mock_db_session.scalar.side_effect = [p, 0] # first: fetch provider, second: name conflict count
with patch.object(service, "extract_secret_variables", return_value=["tok"]):
service.reauthorize_datasource_oauth_provider(None, "t1", make_id(), "u", 9999, {"tok": "val"}, "cred-id")
self._enc.encrypt_token.assert_called()
def test_should_acquire_redis_lock_when_reauthorizing(self, service, mock_db_session):
p = MagicMock(spec=DatasourceProvider)
mock_db_session.query().first.return_value = p
mock_db_session.query().count.return_value = 0
mock_db_session.scalar.side_effect = [p, 0] # first: fetch provider, second: name conflict count
with patch.object(service, "extract_secret_variables", return_value=[]):
service.reauthorize_datasource_oauth_provider("n", "t1", make_id(), "u", 1, {}, "oid")
self._redis.lock.assert_called()
@@ -545,13 +537,13 @@ class TestDatasourceProviderService:
def test_should_raise_value_error_when_api_key_name_already_exists(self, service, mock_db_session, mock_user):
"""explicit name supplied + conflict → raises ValueError immediately."""
mock_db_session.query().count.return_value = 1
mock_db_session.scalar.return_value = 1
with patch("services.datasource_provider_service.get_current_user", return_value=mock_user):
with pytest.raises(ValueError, match="already exists"):
service.add_datasource_api_key_provider("clash", "t1", make_id(), {"sk": "v"})
def test_should_raise_value_error_when_credentials_validation_fails(self, service, mock_db_session, mock_user):
mock_db_session.query().count.return_value = 0
mock_db_session.scalar.return_value = 0
with (
patch("services.datasource_provider_service.get_current_user", return_value=mock_user),
patch.object(service.provider_manager, "validate_provider_credentials", side_effect=Exception("bad cred")),
@@ -561,7 +553,7 @@ class TestDatasourceProviderService:
service.add_datasource_api_key_provider("nm", "t1", make_id(), {"k": "v"})
def test_should_add_api_key_provider_and_commit_when_valid(self, service, mock_db_session, mock_user):
mock_db_session.query().count.return_value = 0
mock_db_session.scalar.return_value = 0
with (
patch("services.datasource_provider_service.get_current_user", return_value=mock_user),
patch.object(service.provider_manager, "validate_provider_credentials"),
@@ -571,7 +563,7 @@ class TestDatasourceProviderService:
mock_db_session.add.assert_called_once()
def test_should_acquire_redis_lock_when_adding_api_key_provider(self, service, mock_db_session, mock_user):
mock_db_session.query().count.return_value = 0
mock_db_session.scalar.return_value = 0
with (
patch("services.datasource_provider_service.get_current_user", return_value=mock_user),
patch.object(service.provider_manager, "validate_provider_credentials"),
@@ -694,7 +686,7 @@ class TestDatasourceProviderService:
# -----------------------------------------------------------------------
def test_should_raise_value_error_when_credential_not_found_on_update(self, service, mock_db_session, mock_user):
mock_db_session.query().first.return_value = None
mock_db_session.scalar.return_value = None
with patch("services.datasource_provider_service.get_current_user", return_value=mock_user):
with pytest.raises(ValueError, match="not found"):
service.update_datasource_credentials("t1", "id", "prov", "org/plug", {}, "name")
@@ -704,8 +696,7 @@ class TestDatasourceProviderService:
p.name = "old_name"
p.auth_type = "api_key"
p.encrypted_credentials = {"sk": "e"}
mock_db_session.query().first.return_value = p
mock_db_session.query().count.return_value = 1
mock_db_session.scalar.side_effect = [p, 1] # first: fetch provider, second: name conflict count
with patch("services.datasource_provider_service.get_current_user", return_value=mock_user):
with pytest.raises(ValueError, match="already exists"):
service.update_datasource_credentials("t1", "id", "prov", "org/plug", {}, "new_name")
@@ -717,8 +708,7 @@ class TestDatasourceProviderService:
p.name = "old_name"
p.auth_type = "api_key"
p.encrypted_credentials = {"sk": "e"}
mock_db_session.query().first.return_value = p
mock_db_session.query().count.return_value = 0
mock_db_session.scalar.side_effect = [p, 0] # first: fetch provider, second: name conflict count
with (
patch("services.datasource_provider_service.get_current_user", return_value=mock_user),
patch.object(service, "extract_secret_variables", return_value=["sk"]),
@@ -733,8 +723,7 @@ class TestDatasourceProviderService:
p.name = "old_name"
p.auth_type = "api_key"
p.encrypted_credentials = {"sk": "old_enc"}
mock_db_session.query().first.return_value = p
mock_db_session.query().count.return_value = 0
mock_db_session.scalar.side_effect = [p, 0] # first: fetch provider, second: name conflict count
with (
patch("services.datasource_provider_service.get_current_user", return_value=mock_user),
patch.object(service, "extract_secret_variables", return_value=["sk"]),