refactor(api): use sessionmaker in datasource provider service (#34811)

This commit is contained in:
carlos4s
2026-04-09 00:43:13 -05:00
committed by GitHub
parent be1f4b34f8
commit a76a8876d1
2 changed files with 17 additions and 32 deletions

View File

@@ -5,7 +5,7 @@ from typing import Any
from graphon.model_runtime.entities.provider_entities import FormType
from sqlalchemy import func, select
from sqlalchemy.orm import Session
from sqlalchemy.orm import Session, sessionmaker
from configs import dify_config
from constants import HIDDEN_VALUE, UNKNOWN_VALUE
@@ -53,13 +53,12 @@ class DatasourceProviderService:
"""
remove oauth custom client params
"""
with Session(db.engine) as session:
with sessionmaker(bind=db.engine).begin() as session:
session.query(DatasourceOauthTenantParamConfig).filter_by(
tenant_id=tenant_id,
provider=datasource_provider_id.provider_name,
plugin_id=datasource_provider_id.plugin_id,
).delete()
session.commit()
def decrypt_datasource_provider_credentials(
self,
@@ -109,7 +108,7 @@ class DatasourceProviderService:
"""
get credential by id
"""
with Session(db.engine) as session:
with sessionmaker(bind=db.engine).begin() as session:
if credential_id:
datasource_provider = (
session.query(DatasourceProvider).filter_by(tenant_id=tenant_id, id=credential_id).first()
@@ -156,7 +155,6 @@ class DatasourceProviderService:
datasource_provider=datasource_provider,
)
datasource_provider.expires_at = refreshed_credentials.expires_at
session.commit()
return self.decrypt_datasource_provider_credentials(
tenant_id=tenant_id,
@@ -174,7 +172,7 @@ class DatasourceProviderService:
"""
get all datasource credentials by provider
"""
with Session(db.engine) as session:
with sessionmaker(bind=db.engine).begin() as session:
datasource_providers = (
session.query(DatasourceProvider)
.filter_by(tenant_id=tenant_id, provider=provider, plugin_id=plugin_id)
@@ -224,7 +222,6 @@ class DatasourceProviderService:
provider=provider,
)
real_credentials_list.append(real_credentials)
session.commit()
return real_credentials_list
@@ -234,7 +231,7 @@ class DatasourceProviderService:
"""
update datasource provider name
"""
with Session(db.engine) as session:
with sessionmaker(bind=db.engine).begin() as session:
target_provider = (
session.query(DatasourceProvider)
.filter_by(
@@ -266,7 +263,6 @@ class DatasourceProviderService:
raise ValueError("Authorization name is already exists")
target_provider.name = name
session.commit()
return
def set_default_datasource_provider(
@@ -275,7 +271,7 @@ class DatasourceProviderService:
"""
set default datasource provider
"""
with Session(db.engine) as session:
with sessionmaker(bind=db.engine).begin() as session:
# get provider
target_provider = (
session.query(DatasourceProvider)
@@ -300,7 +296,6 @@ class DatasourceProviderService:
# set new default provider
target_provider.is_default = True
session.commit()
return {"result": "success"}
def setup_oauth_custom_client_params(
@@ -315,7 +310,7 @@ class DatasourceProviderService:
"""
if client_params is None and enabled is None:
return
with Session(db.engine) as session:
with sessionmaker(bind=db.engine).begin() as session:
tenant_oauth_client_params = (
session.query(DatasourceOauthTenantParamConfig)
.filter_by(
@@ -349,7 +344,6 @@ class DatasourceProviderService:
if enabled is not None:
tenant_oauth_client_params.enabled = enabled
session.commit()
def is_system_oauth_params_exist(self, datasource_provider_id: DatasourceProviderID) -> bool:
"""
@@ -488,7 +482,7 @@ class DatasourceProviderService:
"""
update datasource oauth provider
"""
with Session(db.engine) as session:
with sessionmaker(bind=db.engine).begin() as session:
lock = f"datasource_provider_create_lock:{tenant_id}_{provider_id}_{CredentialType.OAUTH2.value}"
with redis_client.lock(lock, timeout=20):
target_provider = (
@@ -535,7 +529,6 @@ class DatasourceProviderService:
target_provider.expires_at = expire_at
target_provider.encrypted_credentials = credentials
target_provider.avatar_url = avatar_url or target_provider.avatar_url
session.commit()
def add_datasource_oauth_provider(
self,
@@ -550,7 +543,7 @@ class DatasourceProviderService:
add datasource oauth provider
"""
credential_type = CredentialType.OAUTH2
with Session(db.engine) as session:
with sessionmaker(bind=db.engine).begin() as session:
lock = f"datasource_provider_create_lock:{tenant_id}_{provider_id}_{credential_type.value}"
with redis_client.lock(lock, timeout=60):
db_provider_name = name
@@ -604,7 +597,6 @@ class DatasourceProviderService:
expires_at=expire_at,
)
session.add(datasource_provider)
session.commit()
def add_datasource_api_key_provider(
self,
@@ -623,7 +615,7 @@ class DatasourceProviderService:
provider_name = provider_id.provider_name
plugin_id = provider_id.plugin_id
with Session(db.engine) as session:
with sessionmaker(bind=db.engine).begin() as session:
lock = f"datasource_provider_create_lock:{tenant_id}_{provider_id}_{CredentialType.API_KEY}"
with redis_client.lock(lock, timeout=20):
db_provider_name = name or self.generate_next_datasource_provider_name(
@@ -670,7 +662,6 @@ class DatasourceProviderService:
encrypted_credentials=credentials,
)
session.add(datasource_provider)
session.commit()
def extract_secret_variables(self, tenant_id: str, provider_id: str, credential_type: CredentialType) -> list[str]:
"""
@@ -926,7 +917,7 @@ class DatasourceProviderService:
update datasource credentials.
"""
with Session(db.engine) as session:
with sessionmaker(bind=db.engine).begin() as session:
datasource_provider = (
session.query(DatasourceProvider)
.filter_by(tenant_id=tenant_id, id=auth_id, provider=provider, plugin_id=plugin_id)
@@ -980,7 +971,6 @@ class DatasourceProviderService:
encrypted_credentials[key] = value
datasource_provider.encrypted_credentials = encrypted_credentials
session.commit()
def remove_datasource_credentials(self, tenant_id: str, auth_id: str, provider: str, plugin_id: str) -> None:
"""

View File

@@ -40,7 +40,10 @@ class TestDatasourceProviderService:
q returns itself for .filter_by(), .order_by(), .where() so any
SQLAlchemy chaining pattern works without multiple brittle sub-mocks.
"""
with patch("services.datasource_provider_service.Session") as mock_cls:
with (
patch("services.datasource_provider_service.Session") as mock_cls,
patch("services.datasource_provider_service.sessionmaker") as mock_sm,
):
sess = MagicMock(spec=Session)
q = MagicMock()
@@ -63,6 +66,8 @@ class TestDatasourceProviderService:
mock_cls.return_value.__enter__.return_value = sess
mock_cls.return_value.no_autoflush.__enter__.return_value = sess
mock_sm.return_value.begin.return_value.__enter__.return_value = sess
mock_sm.return_value.begin.return_value.__exit__ = MagicMock(return_value=False)
yield sess
@@ -266,7 +271,6 @@ class TestDatasourceProviderService:
patch.object(service, "decrypt_datasource_provider_credentials", return_value={"tok": "plain"}),
):
service.get_datasource_credentials("t1", "prov", "org/plug")
mock_db_session.commit.assert_called_once()
def test_should_return_decrypted_credentials_when_api_key_not_expired(self, service, mock_db_session, mock_user):
"""API key credentials with expires_at=-1 skip refresh and return directly."""
@@ -333,7 +337,6 @@ class TestDatasourceProviderService:
p.name = "same"
mock_db_session.query().first.return_value = p
service.update_datasource_provider_name("t1", make_id(), "same", "cred-id")
mock_db_session.commit.assert_not_called()
def test_should_raise_value_error_when_name_already_exists(self, service, mock_db_session):
p = MagicMock(spec=DatasourceProvider)
@@ -352,7 +355,6 @@ class TestDatasourceProviderService:
mock_db_session.query().count.return_value = 0
service.update_datasource_provider_name("t1", make_id(), "new_name", "some-id")
assert p.name == "new_name"
mock_db_session.commit.assert_called_once()
# -----------------------------------------------------------------------
# set_default_datasource_provider (lines 277-303)
@@ -370,7 +372,6 @@ class TestDatasourceProviderService:
mock_db_session.query().first.return_value = target
service.set_default_datasource_provider("t1", make_id(), "new-id")
assert target.is_default is True
mock_db_session.commit.assert_called_once()
# -----------------------------------------------------------------------
# get_oauth_encrypter (lines 404-420)
@@ -460,7 +461,6 @@ class TestDatasourceProviderService:
with patch.object(service, "extract_secret_variables", return_value=[]):
service.add_datasource_oauth_provider("new", "t1", make_id(), "http://cb", 9999, {})
mock_db_session.add.assert_called_once()
mock_db_session.commit.assert_called_once()
def test_should_auto_rename_when_oauth_provider_name_conflicts(self, service, mock_db_session):
"""Conflict on name results in auto-incremented name, not an error."""
@@ -512,7 +512,6 @@ class TestDatasourceProviderService:
mock_db_session.query().count.return_value = 0
with patch.object(service, "extract_secret_variables", return_value=[]):
service.reauthorize_datasource_oauth_provider("n", "t1", make_id(), "u", 1, {}, "oid")
mock_db_session.commit.assert_called_once()
def test_should_auto_rename_when_reauth_name_conflicts(self, service, mock_db_session):
p = MagicMock(spec=DatasourceProvider)
@@ -523,7 +522,6 @@ class TestDatasourceProviderService:
service.reauthorize_datasource_oauth_provider(
"conflict_name", "t1", make_id(), "u", 9999, {"tok": "v"}, "cred-id"
)
mock_db_session.commit.assert_called_once()
def test_should_encrypt_secret_fields_when_reauthorizing(self, service, mock_db_session):
p = MagicMock(spec=DatasourceProvider)
@@ -571,7 +569,6 @@ class TestDatasourceProviderService:
):
service.add_datasource_api_key_provider(None, "t1", make_id(), {"sk": "v"})
mock_db_session.add.assert_called_once()
mock_db_session.commit.assert_called_once()
def test_should_acquire_redis_lock_when_adding_api_key_provider(self, service, mock_db_session, mock_user):
mock_db_session.query().count.return_value = 0
@@ -747,7 +744,6 @@ class TestDatasourceProviderService:
# encrypter must have been called with the new secret value
self._enc.encrypt_token.assert_called()
# commit must be called exactly once
mock_db_session.commit.assert_called_once()
# -----------------------------------------------------------------------
# remove_datasource_credentials (lines 980-997)
@@ -758,7 +754,6 @@ class TestDatasourceProviderService:
mock_db_session.scalar.return_value = p
service.remove_datasource_credentials("t1", "id", "prov", "org/plug")
mock_db_session.delete.assert_called_once_with(p)
mock_db_session.commit.assert_called_once()
def test_should_do_nothing_when_credential_not_found_on_remove(self, service, mock_db_session):
"""No error raised; no delete called when record doesn't exist (lines 994 branch)."""