diff --git a/api/services/datasource_provider_service.py b/api/services/datasource_provider_service.py index faa978afdc..d5f8cd30bd 100644 --- a/api/services/datasource_provider_service.py +++ b/api/services/datasource_provider_service.py @@ -5,7 +5,7 @@ from typing import Any from graphon.model_runtime.entities.provider_entities import FormType from sqlalchemy import func, select -from sqlalchemy.orm import Session +from sqlalchemy.orm import Session, sessionmaker from configs import dify_config from constants import HIDDEN_VALUE, UNKNOWN_VALUE @@ -53,13 +53,12 @@ class DatasourceProviderService: """ remove oauth custom client params """ - with Session(db.engine) as session: + with sessionmaker(bind=db.engine).begin() as session: session.query(DatasourceOauthTenantParamConfig).filter_by( tenant_id=tenant_id, provider=datasource_provider_id.provider_name, plugin_id=datasource_provider_id.plugin_id, ).delete() - session.commit() def decrypt_datasource_provider_credentials( self, @@ -109,7 +108,7 @@ class DatasourceProviderService: """ get credential by id """ - with Session(db.engine) as session: + with sessionmaker(bind=db.engine).begin() as session: if credential_id: datasource_provider = ( session.query(DatasourceProvider).filter_by(tenant_id=tenant_id, id=credential_id).first() @@ -156,7 +155,6 @@ class DatasourceProviderService: datasource_provider=datasource_provider, ) datasource_provider.expires_at = refreshed_credentials.expires_at - session.commit() return self.decrypt_datasource_provider_credentials( tenant_id=tenant_id, @@ -174,7 +172,7 @@ class DatasourceProviderService: """ get all datasource credentials by provider """ - with Session(db.engine) as session: + with sessionmaker(bind=db.engine).begin() as session: datasource_providers = ( session.query(DatasourceProvider) .filter_by(tenant_id=tenant_id, provider=provider, plugin_id=plugin_id) @@ -224,7 +222,6 @@ class DatasourceProviderService: provider=provider, ) real_credentials_list.append(real_credentials) - session.commit() return real_credentials_list @@ -234,7 +231,7 @@ class DatasourceProviderService: """ update datasource provider name """ - with Session(db.engine) as session: + with sessionmaker(bind=db.engine).begin() as session: target_provider = ( session.query(DatasourceProvider) .filter_by( @@ -266,7 +263,6 @@ class DatasourceProviderService: raise ValueError("Authorization name is already exists") target_provider.name = name - session.commit() return def set_default_datasource_provider( @@ -275,7 +271,7 @@ class DatasourceProviderService: """ set default datasource provider """ - with Session(db.engine) as session: + with sessionmaker(bind=db.engine).begin() as session: # get provider target_provider = ( session.query(DatasourceProvider) @@ -300,7 +296,6 @@ class DatasourceProviderService: # set new default provider target_provider.is_default = True - session.commit() return {"result": "success"} def setup_oauth_custom_client_params( @@ -315,7 +310,7 @@ class DatasourceProviderService: """ if client_params is None and enabled is None: return - with Session(db.engine) as session: + with sessionmaker(bind=db.engine).begin() as session: tenant_oauth_client_params = ( session.query(DatasourceOauthTenantParamConfig) .filter_by( @@ -349,7 +344,6 @@ class DatasourceProviderService: if enabled is not None: tenant_oauth_client_params.enabled = enabled - session.commit() def is_system_oauth_params_exist(self, datasource_provider_id: DatasourceProviderID) -> bool: """ @@ -488,7 +482,7 @@ class DatasourceProviderService: """ update datasource oauth provider """ - with Session(db.engine) as session: + with sessionmaker(bind=db.engine).begin() as session: lock = f"datasource_provider_create_lock:{tenant_id}_{provider_id}_{CredentialType.OAUTH2.value}" with redis_client.lock(lock, timeout=20): target_provider = ( @@ -535,7 +529,6 @@ class DatasourceProviderService: target_provider.expires_at = expire_at target_provider.encrypted_credentials = credentials target_provider.avatar_url = avatar_url or target_provider.avatar_url - session.commit() def add_datasource_oauth_provider( self, @@ -550,7 +543,7 @@ class DatasourceProviderService: add datasource oauth provider """ credential_type = CredentialType.OAUTH2 - with Session(db.engine) as session: + with sessionmaker(bind=db.engine).begin() as session: lock = f"datasource_provider_create_lock:{tenant_id}_{provider_id}_{credential_type.value}" with redis_client.lock(lock, timeout=60): db_provider_name = name @@ -604,7 +597,6 @@ class DatasourceProviderService: expires_at=expire_at, ) session.add(datasource_provider) - session.commit() def add_datasource_api_key_provider( self, @@ -623,7 +615,7 @@ class DatasourceProviderService: provider_name = provider_id.provider_name plugin_id = provider_id.plugin_id - with Session(db.engine) as session: + with sessionmaker(bind=db.engine).begin() as session: lock = f"datasource_provider_create_lock:{tenant_id}_{provider_id}_{CredentialType.API_KEY}" with redis_client.lock(lock, timeout=20): db_provider_name = name or self.generate_next_datasource_provider_name( @@ -670,7 +662,6 @@ class DatasourceProviderService: encrypted_credentials=credentials, ) session.add(datasource_provider) - session.commit() def extract_secret_variables(self, tenant_id: str, provider_id: str, credential_type: CredentialType) -> list[str]: """ @@ -926,7 +917,7 @@ class DatasourceProviderService: update datasource credentials. """ - with Session(db.engine) as session: + with sessionmaker(bind=db.engine).begin() as session: datasource_provider = ( session.query(DatasourceProvider) .filter_by(tenant_id=tenant_id, id=auth_id, provider=provider, plugin_id=plugin_id) @@ -980,7 +971,6 @@ class DatasourceProviderService: encrypted_credentials[key] = value datasource_provider.encrypted_credentials = encrypted_credentials - session.commit() def remove_datasource_credentials(self, tenant_id: str, auth_id: str, provider: str, plugin_id: str) -> None: """ diff --git a/api/tests/unit_tests/services/test_datasource_provider_service.py b/api/tests/unit_tests/services/test_datasource_provider_service.py index bc4120e2af..70ecc158d6 100644 --- a/api/tests/unit_tests/services/test_datasource_provider_service.py +++ b/api/tests/unit_tests/services/test_datasource_provider_service.py @@ -40,7 +40,10 @@ class TestDatasourceProviderService: q returns itself for .filter_by(), .order_by(), .where() so any SQLAlchemy chaining pattern works without multiple brittle sub-mocks. """ - with patch("services.datasource_provider_service.Session") as mock_cls: + with ( + patch("services.datasource_provider_service.Session") as mock_cls, + patch("services.datasource_provider_service.sessionmaker") as mock_sm, + ): sess = MagicMock(spec=Session) q = MagicMock() @@ -63,6 +66,8 @@ class TestDatasourceProviderService: mock_cls.return_value.__enter__.return_value = sess mock_cls.return_value.no_autoflush.__enter__.return_value = sess + mock_sm.return_value.begin.return_value.__enter__.return_value = sess + mock_sm.return_value.begin.return_value.__exit__ = MagicMock(return_value=False) yield sess @@ -266,7 +271,6 @@ class TestDatasourceProviderService: patch.object(service, "decrypt_datasource_provider_credentials", return_value={"tok": "plain"}), ): service.get_datasource_credentials("t1", "prov", "org/plug") - mock_db_session.commit.assert_called_once() def test_should_return_decrypted_credentials_when_api_key_not_expired(self, service, mock_db_session, mock_user): """API key credentials with expires_at=-1 skip refresh and return directly.""" @@ -333,7 +337,6 @@ class TestDatasourceProviderService: p.name = "same" mock_db_session.query().first.return_value = p service.update_datasource_provider_name("t1", make_id(), "same", "cred-id") - mock_db_session.commit.assert_not_called() def test_should_raise_value_error_when_name_already_exists(self, service, mock_db_session): p = MagicMock(spec=DatasourceProvider) @@ -352,7 +355,6 @@ class TestDatasourceProviderService: mock_db_session.query().count.return_value = 0 service.update_datasource_provider_name("t1", make_id(), "new_name", "some-id") assert p.name == "new_name" - mock_db_session.commit.assert_called_once() # ----------------------------------------------------------------------- # set_default_datasource_provider (lines 277-303) @@ -370,7 +372,6 @@ class TestDatasourceProviderService: mock_db_session.query().first.return_value = target service.set_default_datasource_provider("t1", make_id(), "new-id") assert target.is_default is True - mock_db_session.commit.assert_called_once() # ----------------------------------------------------------------------- # get_oauth_encrypter (lines 404-420) @@ -460,7 +461,6 @@ class TestDatasourceProviderService: with patch.object(service, "extract_secret_variables", return_value=[]): service.add_datasource_oauth_provider("new", "t1", make_id(), "http://cb", 9999, {}) mock_db_session.add.assert_called_once() - mock_db_session.commit.assert_called_once() def test_should_auto_rename_when_oauth_provider_name_conflicts(self, service, mock_db_session): """Conflict on name results in auto-incremented name, not an error.""" @@ -512,7 +512,6 @@ class TestDatasourceProviderService: mock_db_session.query().count.return_value = 0 with patch.object(service, "extract_secret_variables", return_value=[]): service.reauthorize_datasource_oauth_provider("n", "t1", make_id(), "u", 1, {}, "oid") - mock_db_session.commit.assert_called_once() def test_should_auto_rename_when_reauth_name_conflicts(self, service, mock_db_session): p = MagicMock(spec=DatasourceProvider) @@ -523,7 +522,6 @@ class TestDatasourceProviderService: service.reauthorize_datasource_oauth_provider( "conflict_name", "t1", make_id(), "u", 9999, {"tok": "v"}, "cred-id" ) - mock_db_session.commit.assert_called_once() def test_should_encrypt_secret_fields_when_reauthorizing(self, service, mock_db_session): p = MagicMock(spec=DatasourceProvider) @@ -571,7 +569,6 @@ class TestDatasourceProviderService: ): service.add_datasource_api_key_provider(None, "t1", make_id(), {"sk": "v"}) mock_db_session.add.assert_called_once() - mock_db_session.commit.assert_called_once() def test_should_acquire_redis_lock_when_adding_api_key_provider(self, service, mock_db_session, mock_user): mock_db_session.query().count.return_value = 0 @@ -747,7 +744,6 @@ class TestDatasourceProviderService: # encrypter must have been called with the new secret value self._enc.encrypt_token.assert_called() # commit must be called exactly once - mock_db_session.commit.assert_called_once() # ----------------------------------------------------------------------- # remove_datasource_credentials (lines 980-997) @@ -758,7 +754,6 @@ class TestDatasourceProviderService: mock_db_session.scalar.return_value = p service.remove_datasource_credentials("t1", "id", "prov", "org/plug") mock_db_session.delete.assert_called_once_with(p) - mock_db_session.commit.assert_called_once() def test_should_do_nothing_when_credential_not_found_on_remove(self, service, mock_db_session): """No error raised; no delete called when record doesn't exist (lines 994 branch)."""