mirror of
https://github.com/langgenius/dify.git
synced 2026-04-10 12:00:26 -04:00
refactor(api): use sessionmaker in datasource provider service (#34811)
This commit is contained in:
@@ -5,7 +5,7 @@ from typing import Any
|
||||
|
||||
from graphon.model_runtime.entities.provider_entities import FormType
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from configs import dify_config
|
||||
from constants import HIDDEN_VALUE, UNKNOWN_VALUE
|
||||
@@ -53,13 +53,12 @@ class DatasourceProviderService:
|
||||
"""
|
||||
remove oauth custom client params
|
||||
"""
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
session.query(DatasourceOauthTenantParamConfig).filter_by(
|
||||
tenant_id=tenant_id,
|
||||
provider=datasource_provider_id.provider_name,
|
||||
plugin_id=datasource_provider_id.plugin_id,
|
||||
).delete()
|
||||
session.commit()
|
||||
|
||||
def decrypt_datasource_provider_credentials(
|
||||
self,
|
||||
@@ -109,7 +108,7 @@ class DatasourceProviderService:
|
||||
"""
|
||||
get credential by id
|
||||
"""
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
if credential_id:
|
||||
datasource_provider = (
|
||||
session.query(DatasourceProvider).filter_by(tenant_id=tenant_id, id=credential_id).first()
|
||||
@@ -156,7 +155,6 @@ class DatasourceProviderService:
|
||||
datasource_provider=datasource_provider,
|
||||
)
|
||||
datasource_provider.expires_at = refreshed_credentials.expires_at
|
||||
session.commit()
|
||||
|
||||
return self.decrypt_datasource_provider_credentials(
|
||||
tenant_id=tenant_id,
|
||||
@@ -174,7 +172,7 @@ class DatasourceProviderService:
|
||||
"""
|
||||
get all datasource credentials by provider
|
||||
"""
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
datasource_providers = (
|
||||
session.query(DatasourceProvider)
|
||||
.filter_by(tenant_id=tenant_id, provider=provider, plugin_id=plugin_id)
|
||||
@@ -224,7 +222,6 @@ class DatasourceProviderService:
|
||||
provider=provider,
|
||||
)
|
||||
real_credentials_list.append(real_credentials)
|
||||
session.commit()
|
||||
|
||||
return real_credentials_list
|
||||
|
||||
@@ -234,7 +231,7 @@ class DatasourceProviderService:
|
||||
"""
|
||||
update datasource provider name
|
||||
"""
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
target_provider = (
|
||||
session.query(DatasourceProvider)
|
||||
.filter_by(
|
||||
@@ -266,7 +263,6 @@ class DatasourceProviderService:
|
||||
raise ValueError("Authorization name is already exists")
|
||||
|
||||
target_provider.name = name
|
||||
session.commit()
|
||||
return
|
||||
|
||||
def set_default_datasource_provider(
|
||||
@@ -275,7 +271,7 @@ class DatasourceProviderService:
|
||||
"""
|
||||
set default datasource provider
|
||||
"""
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
# get provider
|
||||
target_provider = (
|
||||
session.query(DatasourceProvider)
|
||||
@@ -300,7 +296,6 @@ class DatasourceProviderService:
|
||||
|
||||
# set new default provider
|
||||
target_provider.is_default = True
|
||||
session.commit()
|
||||
return {"result": "success"}
|
||||
|
||||
def setup_oauth_custom_client_params(
|
||||
@@ -315,7 +310,7 @@ class DatasourceProviderService:
|
||||
"""
|
||||
if client_params is None and enabled is None:
|
||||
return
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
tenant_oauth_client_params = (
|
||||
session.query(DatasourceOauthTenantParamConfig)
|
||||
.filter_by(
|
||||
@@ -349,7 +344,6 @@ class DatasourceProviderService:
|
||||
|
||||
if enabled is not None:
|
||||
tenant_oauth_client_params.enabled = enabled
|
||||
session.commit()
|
||||
|
||||
def is_system_oauth_params_exist(self, datasource_provider_id: DatasourceProviderID) -> bool:
|
||||
"""
|
||||
@@ -488,7 +482,7 @@ class DatasourceProviderService:
|
||||
"""
|
||||
update datasource oauth provider
|
||||
"""
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
lock = f"datasource_provider_create_lock:{tenant_id}_{provider_id}_{CredentialType.OAUTH2.value}"
|
||||
with redis_client.lock(lock, timeout=20):
|
||||
target_provider = (
|
||||
@@ -535,7 +529,6 @@ class DatasourceProviderService:
|
||||
target_provider.expires_at = expire_at
|
||||
target_provider.encrypted_credentials = credentials
|
||||
target_provider.avatar_url = avatar_url or target_provider.avatar_url
|
||||
session.commit()
|
||||
|
||||
def add_datasource_oauth_provider(
|
||||
self,
|
||||
@@ -550,7 +543,7 @@ class DatasourceProviderService:
|
||||
add datasource oauth provider
|
||||
"""
|
||||
credential_type = CredentialType.OAUTH2
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
lock = f"datasource_provider_create_lock:{tenant_id}_{provider_id}_{credential_type.value}"
|
||||
with redis_client.lock(lock, timeout=60):
|
||||
db_provider_name = name
|
||||
@@ -604,7 +597,6 @@ class DatasourceProviderService:
|
||||
expires_at=expire_at,
|
||||
)
|
||||
session.add(datasource_provider)
|
||||
session.commit()
|
||||
|
||||
def add_datasource_api_key_provider(
|
||||
self,
|
||||
@@ -623,7 +615,7 @@ class DatasourceProviderService:
|
||||
provider_name = provider_id.provider_name
|
||||
plugin_id = provider_id.plugin_id
|
||||
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
lock = f"datasource_provider_create_lock:{tenant_id}_{provider_id}_{CredentialType.API_KEY}"
|
||||
with redis_client.lock(lock, timeout=20):
|
||||
db_provider_name = name or self.generate_next_datasource_provider_name(
|
||||
@@ -670,7 +662,6 @@ class DatasourceProviderService:
|
||||
encrypted_credentials=credentials,
|
||||
)
|
||||
session.add(datasource_provider)
|
||||
session.commit()
|
||||
|
||||
def extract_secret_variables(self, tenant_id: str, provider_id: str, credential_type: CredentialType) -> list[str]:
|
||||
"""
|
||||
@@ -926,7 +917,7 @@ class DatasourceProviderService:
|
||||
update datasource credentials.
|
||||
"""
|
||||
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
datasource_provider = (
|
||||
session.query(DatasourceProvider)
|
||||
.filter_by(tenant_id=tenant_id, id=auth_id, provider=provider, plugin_id=plugin_id)
|
||||
@@ -980,7 +971,6 @@ class DatasourceProviderService:
|
||||
encrypted_credentials[key] = value
|
||||
|
||||
datasource_provider.encrypted_credentials = encrypted_credentials
|
||||
session.commit()
|
||||
|
||||
def remove_datasource_credentials(self, tenant_id: str, auth_id: str, provider: str, plugin_id: str) -> None:
|
||||
"""
|
||||
|
||||
@@ -40,7 +40,10 @@ class TestDatasourceProviderService:
|
||||
q returns itself for .filter_by(), .order_by(), .where() so any
|
||||
SQLAlchemy chaining pattern works without multiple brittle sub-mocks.
|
||||
"""
|
||||
with patch("services.datasource_provider_service.Session") as mock_cls:
|
||||
with (
|
||||
patch("services.datasource_provider_service.Session") as mock_cls,
|
||||
patch("services.datasource_provider_service.sessionmaker") as mock_sm,
|
||||
):
|
||||
sess = MagicMock(spec=Session)
|
||||
|
||||
q = MagicMock()
|
||||
@@ -63,6 +66,8 @@ class TestDatasourceProviderService:
|
||||
|
||||
mock_cls.return_value.__enter__.return_value = sess
|
||||
mock_cls.return_value.no_autoflush.__enter__.return_value = sess
|
||||
mock_sm.return_value.begin.return_value.__enter__.return_value = sess
|
||||
mock_sm.return_value.begin.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
yield sess
|
||||
|
||||
@@ -266,7 +271,6 @@ class TestDatasourceProviderService:
|
||||
patch.object(service, "decrypt_datasource_provider_credentials", return_value={"tok": "plain"}),
|
||||
):
|
||||
service.get_datasource_credentials("t1", "prov", "org/plug")
|
||||
mock_db_session.commit.assert_called_once()
|
||||
|
||||
def test_should_return_decrypted_credentials_when_api_key_not_expired(self, service, mock_db_session, mock_user):
|
||||
"""API key credentials with expires_at=-1 skip refresh and return directly."""
|
||||
@@ -333,7 +337,6 @@ class TestDatasourceProviderService:
|
||||
p.name = "same"
|
||||
mock_db_session.query().first.return_value = p
|
||||
service.update_datasource_provider_name("t1", make_id(), "same", "cred-id")
|
||||
mock_db_session.commit.assert_not_called()
|
||||
|
||||
def test_should_raise_value_error_when_name_already_exists(self, service, mock_db_session):
|
||||
p = MagicMock(spec=DatasourceProvider)
|
||||
@@ -352,7 +355,6 @@ class TestDatasourceProviderService:
|
||||
mock_db_session.query().count.return_value = 0
|
||||
service.update_datasource_provider_name("t1", make_id(), "new_name", "some-id")
|
||||
assert p.name == "new_name"
|
||||
mock_db_session.commit.assert_called_once()
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
# set_default_datasource_provider (lines 277-303)
|
||||
@@ -370,7 +372,6 @@ class TestDatasourceProviderService:
|
||||
mock_db_session.query().first.return_value = target
|
||||
service.set_default_datasource_provider("t1", make_id(), "new-id")
|
||||
assert target.is_default is True
|
||||
mock_db_session.commit.assert_called_once()
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
# get_oauth_encrypter (lines 404-420)
|
||||
@@ -460,7 +461,6 @@ class TestDatasourceProviderService:
|
||||
with patch.object(service, "extract_secret_variables", return_value=[]):
|
||||
service.add_datasource_oauth_provider("new", "t1", make_id(), "http://cb", 9999, {})
|
||||
mock_db_session.add.assert_called_once()
|
||||
mock_db_session.commit.assert_called_once()
|
||||
|
||||
def test_should_auto_rename_when_oauth_provider_name_conflicts(self, service, mock_db_session):
|
||||
"""Conflict on name results in auto-incremented name, not an error."""
|
||||
@@ -512,7 +512,6 @@ class TestDatasourceProviderService:
|
||||
mock_db_session.query().count.return_value = 0
|
||||
with patch.object(service, "extract_secret_variables", return_value=[]):
|
||||
service.reauthorize_datasource_oauth_provider("n", "t1", make_id(), "u", 1, {}, "oid")
|
||||
mock_db_session.commit.assert_called_once()
|
||||
|
||||
def test_should_auto_rename_when_reauth_name_conflicts(self, service, mock_db_session):
|
||||
p = MagicMock(spec=DatasourceProvider)
|
||||
@@ -523,7 +522,6 @@ class TestDatasourceProviderService:
|
||||
service.reauthorize_datasource_oauth_provider(
|
||||
"conflict_name", "t1", make_id(), "u", 9999, {"tok": "v"}, "cred-id"
|
||||
)
|
||||
mock_db_session.commit.assert_called_once()
|
||||
|
||||
def test_should_encrypt_secret_fields_when_reauthorizing(self, service, mock_db_session):
|
||||
p = MagicMock(spec=DatasourceProvider)
|
||||
@@ -571,7 +569,6 @@ class TestDatasourceProviderService:
|
||||
):
|
||||
service.add_datasource_api_key_provider(None, "t1", make_id(), {"sk": "v"})
|
||||
mock_db_session.add.assert_called_once()
|
||||
mock_db_session.commit.assert_called_once()
|
||||
|
||||
def test_should_acquire_redis_lock_when_adding_api_key_provider(self, service, mock_db_session, mock_user):
|
||||
mock_db_session.query().count.return_value = 0
|
||||
@@ -747,7 +744,6 @@ class TestDatasourceProviderService:
|
||||
# encrypter must have been called with the new secret value
|
||||
self._enc.encrypt_token.assert_called()
|
||||
# commit must be called exactly once
|
||||
mock_db_session.commit.assert_called_once()
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
# remove_datasource_credentials (lines 980-997)
|
||||
@@ -758,7 +754,6 @@ class TestDatasourceProviderService:
|
||||
mock_db_session.scalar.return_value = p
|
||||
service.remove_datasource_credentials("t1", "id", "prov", "org/plug")
|
||||
mock_db_session.delete.assert_called_once_with(p)
|
||||
mock_db_session.commit.assert_called_once()
|
||||
|
||||
def test_should_do_nothing_when_credential_not_found_on_remove(self, service, mock_db_session):
|
||||
"""No error raised; no delete called when record doesn't exist (lines 994 branch)."""
|
||||
|
||||
Reference in New Issue
Block a user