diff --git a/api/services/datasource_provider_service.py b/api/services/datasource_provider_service.py index d5f8cd30bd..9e7de36593 100644 --- a/api/services/datasource_provider_service.py +++ b/api/services/datasource_provider_service.py @@ -4,7 +4,7 @@ from collections.abc import Mapping from typing import Any from graphon.model_runtime.entities.provider_entities import FormType -from sqlalchemy import func, select +from sqlalchemy import delete, func, select, update from sqlalchemy.orm import Session, sessionmaker from configs import dify_config @@ -54,11 +54,13 @@ class DatasourceProviderService: remove oauth custom client params """ with sessionmaker(bind=db.engine).begin() as session: - session.query(DatasourceOauthTenantParamConfig).filter_by( - tenant_id=tenant_id, - provider=datasource_provider_id.provider_name, - plugin_id=datasource_provider_id.plugin_id, - ).delete() + session.execute( + delete(DatasourceOauthTenantParamConfig).where( + DatasourceOauthTenantParamConfig.tenant_id == tenant_id, + DatasourceOauthTenantParamConfig.provider == datasource_provider_id.provider_name, + DatasourceOauthTenantParamConfig.plugin_id == datasource_provider_id.plugin_id, + ) + ) def decrypt_datasource_provider_credentials( self, @@ -110,15 +112,21 @@ class DatasourceProviderService: """ with sessionmaker(bind=db.engine).begin() as session: if credential_id: - datasource_provider = ( - session.query(DatasourceProvider).filter_by(tenant_id=tenant_id, id=credential_id).first() + datasource_provider = session.scalar( + select(DatasourceProvider) + .where(DatasourceProvider.tenant_id == tenant_id, DatasourceProvider.id == credential_id) + .limit(1) ) else: - datasource_provider = ( - session.query(DatasourceProvider) - .filter_by(tenant_id=tenant_id, provider=provider, plugin_id=plugin_id) + datasource_provider = session.scalar( + select(DatasourceProvider) + .where( + DatasourceProvider.tenant_id == tenant_id, + DatasourceProvider.provider == provider, + DatasourceProvider.plugin_id == plugin_id, + ) .order_by(DatasourceProvider.is_default.desc(), DatasourceProvider.created_at.asc()) - .first() + .limit(1) ) if not datasource_provider: return {} @@ -173,12 +181,15 @@ class DatasourceProviderService: get all datasource credentials by provider """ with sessionmaker(bind=db.engine).begin() as session: - datasource_providers = ( - session.query(DatasourceProvider) - .filter_by(tenant_id=tenant_id, provider=provider, plugin_id=plugin_id) + datasource_providers = session.scalars( + select(DatasourceProvider) + .where( + DatasourceProvider.tenant_id == tenant_id, + DatasourceProvider.provider == provider, + DatasourceProvider.plugin_id == plugin_id, + ) .order_by(DatasourceProvider.is_default.desc(), DatasourceProvider.created_at.asc()) - .all() - ) + ).all() if not datasource_providers: return [] current_user = get_current_user() @@ -232,15 +243,15 @@ class DatasourceProviderService: update datasource provider name """ with sessionmaker(bind=db.engine).begin() as session: - target_provider = ( - session.query(DatasourceProvider) - .filter_by( - tenant_id=tenant_id, - id=credential_id, - provider=datasource_provider_id.provider_name, - plugin_id=datasource_provider_id.plugin_id, + target_provider = session.scalar( + select(DatasourceProvider) + .where( + DatasourceProvider.tenant_id == tenant_id, + DatasourceProvider.id == credential_id, + DatasourceProvider.provider == datasource_provider_id.provider_name, + DatasourceProvider.plugin_id == datasource_provider_id.plugin_id, ) - .first() + .limit(1) ) if target_provider is None: raise ValueError("provider not found") @@ -250,16 +261,16 @@ class DatasourceProviderService: # check name is exist if ( - session.query(DatasourceProvider) - .filter_by( - tenant_id=tenant_id, - name=name, - provider=datasource_provider_id.provider_name, - plugin_id=datasource_provider_id.plugin_id, + session.scalar( + select(func.count(DatasourceProvider.id)).where( + DatasourceProvider.tenant_id == tenant_id, + DatasourceProvider.name == name, + DatasourceProvider.provider == datasource_provider_id.provider_name, + DatasourceProvider.plugin_id == datasource_provider_id.plugin_id, + ) ) - .count() - > 0 - ): + or 0 + ) > 0: raise ValueError("Authorization name is already exists") target_provider.name = name @@ -273,26 +284,31 @@ class DatasourceProviderService: """ with sessionmaker(bind=db.engine).begin() as session: # get provider - target_provider = ( - session.query(DatasourceProvider) - .filter_by( - tenant_id=tenant_id, - id=credential_id, - provider=datasource_provider_id.provider_name, - plugin_id=datasource_provider_id.plugin_id, + target_provider = session.scalar( + select(DatasourceProvider) + .where( + DatasourceProvider.tenant_id == tenant_id, + DatasourceProvider.id == credential_id, + DatasourceProvider.provider == datasource_provider_id.provider_name, + DatasourceProvider.plugin_id == datasource_provider_id.plugin_id, ) - .first() + .limit(1) ) if target_provider is None: raise ValueError("provider not found") # clear default provider - session.query(DatasourceProvider).filter_by( - tenant_id=tenant_id, - provider=target_provider.provider, - plugin_id=target_provider.plugin_id, - is_default=True, - ).update({"is_default": False}) + session.execute( + update(DatasourceProvider) + .where( + DatasourceProvider.tenant_id == tenant_id, + DatasourceProvider.provider == target_provider.provider, + DatasourceProvider.plugin_id == target_provider.plugin_id, + DatasourceProvider.is_default.is_(True), + ) + .values(is_default=False) + .execution_options(synchronize_session=False) + ) # set new default provider target_provider.is_default = True @@ -311,14 +327,14 @@ class DatasourceProviderService: if client_params is None and enabled is None: return with sessionmaker(bind=db.engine).begin() as session: - tenant_oauth_client_params = ( - session.query(DatasourceOauthTenantParamConfig) - .filter_by( - tenant_id=tenant_id, - provider=datasource_provider_id.provider_name, - plugin_id=datasource_provider_id.plugin_id, + tenant_oauth_client_params = session.scalar( + select(DatasourceOauthTenantParamConfig) + .where( + DatasourceOauthTenantParamConfig.tenant_id == tenant_id, + DatasourceOauthTenantParamConfig.provider == datasource_provider_id.provider_name, + DatasourceOauthTenantParamConfig.plugin_id == datasource_provider_id.plugin_id, ) - .first() + .limit(1) ) if not tenant_oauth_client_params: @@ -351,9 +367,14 @@ class DatasourceProviderService: """ with Session(db.engine).no_autoflush as session: return ( - session.query(DatasourceOauthParamConfig) - .filter_by(provider=datasource_provider_id.provider_name, plugin_id=datasource_provider_id.plugin_id) - .first() + session.scalar( + select(DatasourceOauthParamConfig) + .where( + DatasourceOauthParamConfig.provider == datasource_provider_id.provider_name, + DatasourceOauthParamConfig.plugin_id == datasource_provider_id.plugin_id, + ) + .limit(1) + ) is not None ) @@ -423,15 +444,15 @@ class DatasourceProviderService: plugin_id = datasource_provider_id.plugin_id with Session(db.engine).no_autoflush as session: # get tenant oauth client params - tenant_oauth_client_params = ( - session.query(DatasourceOauthTenantParamConfig) - .filter_by( - tenant_id=tenant_id, - provider=provider, - plugin_id=plugin_id, - enabled=True, + tenant_oauth_client_params = session.scalar( + select(DatasourceOauthTenantParamConfig) + .where( + DatasourceOauthTenantParamConfig.tenant_id == tenant_id, + DatasourceOauthTenantParamConfig.provider == provider, + DatasourceOauthTenantParamConfig.plugin_id == plugin_id, + DatasourceOauthTenantParamConfig.enabled.is_(True), ) - .first() + .limit(1) ) if tenant_oauth_client_params: encrypter, _ = self.get_oauth_encrypter(tenant_id, datasource_provider_id) @@ -443,8 +464,13 @@ class DatasourceProviderService: is_verified = PluginService.is_plugin_verified(tenant_id, provider_controller.plugin_unique_identifier) if is_verified: # fallback to system oauth client params - oauth_client_params = ( - session.query(DatasourceOauthParamConfig).filter_by(provider=provider, plugin_id=plugin_id).first() + oauth_client_params = session.scalar( + select(DatasourceOauthParamConfig) + .where( + DatasourceOauthParamConfig.provider == provider, + DatasourceOauthParamConfig.plugin_id == plugin_id, + ) + .limit(1) ) if oauth_client_params: return oauth_client_params.system_credentials @@ -455,15 +481,13 @@ class DatasourceProviderService: def generate_next_datasource_provider_name( session: Session, tenant_id: str, provider_id: DatasourceProviderID, credential_type: CredentialType ) -> str: - db_providers = ( - session.query(DatasourceProvider) - .filter_by( - tenant_id=tenant_id, - provider=provider_id.provider_name, - plugin_id=provider_id.plugin_id, + db_providers = session.scalars( + select(DatasourceProvider).where( + DatasourceProvider.tenant_id == tenant_id, + DatasourceProvider.provider == provider_id.provider_name, + DatasourceProvider.plugin_id == provider_id.plugin_id, ) - .all() - ) + ).all() return generate_incremental_name( [provider.name for provider in db_providers], f"{credential_type.get_name()}", @@ -485,8 +509,10 @@ class DatasourceProviderService: with sessionmaker(bind=db.engine).begin() as session: lock = f"datasource_provider_create_lock:{tenant_id}_{provider_id}_{CredentialType.OAUTH2.value}" with redis_client.lock(lock, timeout=20): - target_provider = ( - session.query(DatasourceProvider).filter_by(id=credential_id, tenant_id=tenant_id).first() + target_provider = session.scalar( + select(DatasourceProvider) + .where(DatasourceProvider.id == credential_id, DatasourceProvider.tenant_id == tenant_id) + .limit(1) ) if target_provider is None: raise ValueError("provider not found") @@ -496,25 +522,28 @@ class DatasourceProviderService: db_provider_name = target_provider.name else: name_conflict = ( - session.query(DatasourceProvider) - .filter_by( - tenant_id=tenant_id, - name=db_provider_name, - provider=provider_id.provider_name, - plugin_id=provider_id.plugin_id, - auth_type=CredentialType.OAUTH2.value, + session.scalar( + select(func.count(DatasourceProvider.id)).where( + DatasourceProvider.tenant_id == tenant_id, + DatasourceProvider.name == db_provider_name, + DatasourceProvider.provider == provider_id.provider_name, + DatasourceProvider.plugin_id == provider_id.plugin_id, + DatasourceProvider.auth_type == CredentialType.OAUTH2.value, + ) ) - .count() + or 0 ) if name_conflict > 0: db_provider_name = generate_incremental_name( [ provider.name - for provider in session.query(DatasourceProvider).filter_by( - tenant_id=tenant_id, - provider=provider_id.provider_name, - plugin_id=provider_id.plugin_id, - ) + for provider in session.scalars( + select(DatasourceProvider).where( + DatasourceProvider.tenant_id == tenant_id, + DatasourceProvider.provider == provider_id.provider_name, + DatasourceProvider.plugin_id == provider_id.plugin_id, + ) + ).all() ], db_provider_name, ) @@ -556,25 +585,27 @@ class DatasourceProviderService: ) else: if ( - session.query(DatasourceProvider) - .filter_by( - tenant_id=tenant_id, - name=db_provider_name, - provider=provider_id.provider_name, - plugin_id=provider_id.plugin_id, - auth_type=credential_type.value, + session.scalar( + select(func.count(DatasourceProvider.id)).where( + DatasourceProvider.tenant_id == tenant_id, + DatasourceProvider.name == db_provider_name, + DatasourceProvider.provider == provider_id.provider_name, + DatasourceProvider.plugin_id == provider_id.plugin_id, + DatasourceProvider.auth_type == credential_type.value, + ) ) - .count() - > 0 - ): + or 0 + ) > 0: db_provider_name = generate_incremental_name( [ provider.name - for provider in session.query(DatasourceProvider).filter_by( - tenant_id=tenant_id, - provider=provider_id.provider_name, - plugin_id=provider_id.plugin_id, - ) + for provider in session.scalars( + select(DatasourceProvider).where( + DatasourceProvider.tenant_id == tenant_id, + DatasourceProvider.provider == provider_id.provider_name, + DatasourceProvider.plugin_id == provider_id.plugin_id, + ) + ).all() ], db_provider_name, ) @@ -627,11 +658,16 @@ class DatasourceProviderService: # check name is exist if ( - session.query(DatasourceProvider) - .filter_by(tenant_id=tenant_id, plugin_id=plugin_id, provider=provider_name, name=db_provider_name) - .count() - > 0 - ): + session.scalar( + select(func.count(DatasourceProvider.id)).where( + DatasourceProvider.tenant_id == tenant_id, + DatasourceProvider.plugin_id == plugin_id, + DatasourceProvider.provider == provider_name, + DatasourceProvider.name == db_provider_name, + ) + ) + or 0 + ) > 0: raise ValueError("Authorization name is already exists") try: @@ -918,21 +954,31 @@ class DatasourceProviderService: """ with sessionmaker(bind=db.engine).begin() as session: - datasource_provider = ( - session.query(DatasourceProvider) - .filter_by(tenant_id=tenant_id, id=auth_id, provider=provider, plugin_id=plugin_id) - .first() + datasource_provider = session.scalar( + select(DatasourceProvider) + .where( + DatasourceProvider.tenant_id == tenant_id, + DatasourceProvider.id == auth_id, + DatasourceProvider.provider == provider, + DatasourceProvider.plugin_id == plugin_id, + ) + .limit(1) ) if not datasource_provider: raise ValueError("Datasource provider not found") # update name if name and name != datasource_provider.name: if ( - session.query(DatasourceProvider) - .filter_by(tenant_id=tenant_id, name=name, provider=provider, plugin_id=plugin_id) - .count() - > 0 - ): + session.scalar( + select(func.count(DatasourceProvider.id)).where( + DatasourceProvider.tenant_id == tenant_id, + DatasourceProvider.name == name, + DatasourceProvider.provider == provider, + DatasourceProvider.plugin_id == plugin_id, + ) + ) + or 0 + ) > 0: raise ValueError("Authorization name is already exists") datasource_provider.name = name diff --git a/api/tests/unit_tests/services/test_datasource_provider_service.py b/api/tests/unit_tests/services/test_datasource_provider_service.py index 70ecc158d6..c00a4938bb 100644 --- a/api/tests/unit_tests/services/test_datasource_provider_service.py +++ b/api/tests/unit_tests/services/test_datasource_provider_service.py @@ -179,11 +179,11 @@ class TestDatasourceProviderService: # ----------------------------------------------------------------------- def test_should_return_true_when_system_oauth_params_exist(self, service, mock_db_session): - mock_db_session.query().first.return_value = MagicMock() + mock_db_session.scalar.return_value = MagicMock() assert service.is_system_oauth_params_exist(make_id()) is True def test_should_return_false_when_system_oauth_params_missing(self, service, mock_db_session): - mock_db_session.query().first.return_value = None + mock_db_session.scalar.return_value = None assert service.is_system_oauth_params_exist(make_id()) is False # ----------------------------------------------------------------------- @@ -205,7 +205,7 @@ class TestDatasourceProviderService: def test_should_delete_tenant_config_when_removing_oauth_params(self, service, mock_db_session): service.remove_oauth_custom_client_params("t1", make_id()) - mock_db_session.query().delete.assert_called_once() + mock_db_session.execute.assert_called_once() # ----------------------------------------------------------------------- # setup_oauth_custom_client_params (315-351) @@ -217,14 +217,14 @@ class TestDatasourceProviderService: mock_db_session.add.assert_not_called() def test_should_create_new_config_when_none_exists(self, service, mock_db_session): - mock_db_session.query().first.return_value = None + mock_db_session.scalar.return_value = None with patch.object(service, "get_oauth_encrypter", return_value=(self._enc, None)): service.setup_oauth_custom_client_params("t1", make_id(), {"k": "v"}, True) mock_db_session.add.assert_called_once() def test_should_update_existing_config_when_record_found(self, service, mock_db_session): existing = MagicMock() - mock_db_session.query().first.return_value = existing + mock_db_session.scalar.return_value = existing with patch.object(service, "get_oauth_encrypter", return_value=(self._enc, None)): service.setup_oauth_custom_client_params("t1", make_id(), {"k": "v"}, False) mock_db_session.add.assert_not_called() # update in place, no add @@ -255,7 +255,7 @@ class TestDatasourceProviderService: def test_should_return_empty_dict_when_credential_not_found(self, service, mock_db_session, mock_user): with patch("services.datasource_provider_service.get_current_user", return_value=mock_user): - mock_db_session.query().first.return_value = None + mock_db_session.scalar.return_value = None assert service.get_datasource_credentials("t1", "prov", "org/plug") == {} def test_should_refresh_oauth_tokens_when_expired(self, service, mock_db_session, mock_user): @@ -264,7 +264,7 @@ class TestDatasourceProviderService: p.auth_type = "oauth2" p.expires_at = 0 # expired p.encrypted_credentials = {"tok": "x"} - mock_db_session.query().first.return_value = p + mock_db_session.scalar.return_value = p with ( patch("services.datasource_provider_service.get_current_user", return_value=mock_user), patch.object(service, "get_oauth_client", return_value={"oc": "v"}), @@ -278,7 +278,7 @@ class TestDatasourceProviderService: p.auth_type = "api_key" p.expires_at = -1 # sentinel: never expires p.encrypted_credentials = {"k": "v"} - mock_db_session.query().first.return_value = p + mock_db_session.scalar.return_value = p with ( patch("services.datasource_provider_service.get_current_user", return_value=mock_user), patch.object(service, "decrypt_datasource_provider_credentials", return_value={"k": "plain"}), @@ -292,7 +292,7 @@ class TestDatasourceProviderService: p.auth_type = "api_key" p.expires_at = -1 p.encrypted_credentials = {} - mock_db_session.query().first.return_value = p + mock_db_session.scalar.return_value = p with ( patch("services.datasource_provider_service.get_current_user", return_value=mock_user), patch.object(service, "decrypt_datasource_provider_credentials", return_value={"k": "v"}), @@ -306,7 +306,7 @@ class TestDatasourceProviderService: def test_should_return_empty_list_when_no_provider_credentials_exist(self, service, mock_db_session, mock_user): with patch("services.datasource_provider_service.get_current_user", return_value=mock_user): - mock_db_session.query().all.return_value = [] + mock_db_session.scalars.return_value.all.return_value = [] assert service.get_all_datasource_credentials_by_provider("t1", "prov", "org/plug") == [] def test_should_refresh_and_return_credentials_when_oauth_expired(self, service, mock_db_session, mock_user): @@ -314,7 +314,7 @@ class TestDatasourceProviderService: p.auth_type = "oauth2" p.expires_at = 0 p.encrypted_credentials = {"t": "x"} - mock_db_session.query().all.return_value = [p] + mock_db_session.scalars.return_value.all.return_value = [p] with ( patch("services.datasource_provider_service.get_current_user", return_value=mock_user), patch.object(service, "get_oauth_client", return_value={"oc": "v"}), @@ -328,22 +328,21 @@ class TestDatasourceProviderService: # ----------------------------------------------------------------------- def test_should_raise_value_error_when_provider_not_found_on_name_update(self, service, mock_db_session): - mock_db_session.query().first.return_value = None + mock_db_session.scalar.return_value = None with pytest.raises(ValueError, match="not found"): service.update_datasource_provider_name("t1", make_id(), "new", "cred-id") def test_should_return_early_when_new_name_matches_current(self, service, mock_db_session): p = MagicMock(spec=DatasourceProvider) p.name = "same" - mock_db_session.query().first.return_value = p + mock_db_session.scalar.return_value = p service.update_datasource_provider_name("t1", make_id(), "same", "cred-id") def test_should_raise_value_error_when_name_already_exists(self, service, mock_db_session): p = MagicMock(spec=DatasourceProvider) p.name = "old_name" p.is_default = False - mock_db_session.query().first.return_value = p - mock_db_session.query().count.return_value = 1 # conflict + mock_db_session.scalar.side_effect = [p, 1] # first: fetch provider, second: name conflict count with pytest.raises(ValueError, match="already exists"): service.update_datasource_provider_name("t1", make_id(), "new_name", "some-id") @@ -351,8 +350,7 @@ class TestDatasourceProviderService: p = MagicMock(spec=DatasourceProvider) p.name = "old_name" p.is_default = False - mock_db_session.query().first.return_value = p - mock_db_session.query().count.return_value = 0 + mock_db_session.scalar.side_effect = [p, 0] # first: fetch provider, second: name conflict count service.update_datasource_provider_name("t1", make_id(), "new_name", "some-id") assert p.name == "new_name" @@ -361,7 +359,7 @@ class TestDatasourceProviderService: # ----------------------------------------------------------------------- def test_should_raise_value_error_when_target_provider_not_found(self, service, mock_db_session): - mock_db_session.query().first.return_value = None + mock_db_session.scalar.return_value = None with pytest.raises(ValueError, match="not found"): service.set_default_datasource_provider("t1", make_id(), "bad-id") @@ -369,7 +367,7 @@ class TestDatasourceProviderService: target = MagicMock(spec=DatasourceProvider) target.provider = "provider" target.plugin_id = "org/plug" - mock_db_session.query().first.return_value = target + mock_db_session.scalar.return_value = target service.set_default_datasource_provider("t1", make_id(), "new-id") assert target.is_default is True @@ -428,13 +426,13 @@ class TestDatasourceProviderService: # ----------------------------------------------------------------------- def test_should_use_tenant_config_when_available(self, service, mock_db_session): - mock_db_session.query().first.return_value = MagicMock(client_params={"k": "v"}) + mock_db_session.scalar.return_value = MagicMock(client_params={"k": "v"}) with patch.object(service, "get_oauth_encrypter", return_value=(self._enc, None)): result = service.get_oauth_client("t1", make_id()) assert result == {"k": "dec"} def test_should_fallback_to_system_credentials_when_tenant_config_missing(self, service, mock_db_session): - mock_db_session.query().first.side_effect = [None, MagicMock(system_credentials={"k": "sys"})] + mock_db_session.scalar.side_effect = [None, MagicMock(system_credentials={"k": "sys"})] with ( patch.object(service.provider_manager, "fetch_datasource_provider"), patch("services.datasource_provider_service.PluginService.is_plugin_verified", return_value=True), @@ -444,7 +442,7 @@ class TestDatasourceProviderService: def test_should_raise_value_error_when_no_oauth_config_available(self, service, mock_db_session): """Neither tenant nor system credentials → raises ValueError.""" - mock_db_session.query().first.side_effect = [None, None] + mock_db_session.scalar.side_effect = [None, None] with ( patch.object(service.provider_manager, "fetch_datasource_provider"), patch("services.datasource_provider_service.PluginService.is_plugin_verified", return_value=False), @@ -457,15 +455,14 @@ class TestDatasourceProviderService: # ----------------------------------------------------------------------- def test_should_add_oauth_provider_successfully_when_name_is_unique(self, service, mock_db_session): - mock_db_session.query().count.return_value = 0 + mock_db_session.scalar.return_value = 0 with patch.object(service, "extract_secret_variables", return_value=[]): service.add_datasource_oauth_provider("new", "t1", make_id(), "http://cb", 9999, {}) mock_db_session.add.assert_called_once() def test_should_auto_rename_when_oauth_provider_name_conflicts(self, service, mock_db_session): """Conflict on name results in auto-incremented name, not an error.""" - mock_db_session.query().count.return_value = 1 # conflict first, then auto-named - mock_db_session.query().all.return_value = [] + mock_db_session.scalar.return_value = 1 # conflict first, then auto-named with ( patch.object(service, "extract_secret_variables", return_value=[]), patch.object(service, "generate_next_datasource_provider_name", return_value="new_gen"), @@ -475,8 +472,7 @@ class TestDatasourceProviderService: def test_should_auto_generate_name_when_none_provided_for_oauth(self, service, mock_db_session): """name=None causes auto-generation via generate_next_datasource_provider_name.""" - mock_db_session.query().count.return_value = 0 - mock_db_session.query().all.return_value = [] + mock_db_session.scalar.return_value = 0 with ( patch.object(service, "extract_secret_variables", return_value=[]), patch.object(service, "generate_next_datasource_provider_name", return_value="auto"), @@ -485,13 +481,13 @@ class TestDatasourceProviderService: mock_db_session.add.assert_called_once() def test_should_encrypt_secret_fields_when_adding_oauth_provider(self, service, mock_db_session): - mock_db_session.query().count.return_value = 0 + mock_db_session.scalar.return_value = 0 with patch.object(service, "extract_secret_variables", return_value=["secret_key"]): service.add_datasource_oauth_provider("nm", "t1", make_id(), "http://cb", 9999, {"secret_key": "value"}) self._enc.encrypt_token.assert_called() def test_should_acquire_redis_lock_when_adding_oauth_provider(self, service, mock_db_session): - mock_db_session.query().count.return_value = 0 + mock_db_session.scalar.return_value = 0 with patch.object(service, "extract_secret_variables", return_value=[]): service.add_datasource_oauth_provider("nm", "t1", make_id(), "http://cb", 9999, {}) self._redis.lock.assert_called() @@ -501,23 +497,21 @@ class TestDatasourceProviderService: # ----------------------------------------------------------------------- def test_should_raise_value_error_when_credential_id_not_found_on_reauth(self, service, mock_db_session): - mock_db_session.query().first.return_value = None + mock_db_session.scalar.return_value = None with patch.object(service, "extract_secret_variables", return_value=[]): with pytest.raises(ValueError, match="not found"): service.reauthorize_datasource_oauth_provider("n", "t1", make_id(), "u", 1, {}, "bad-id") def test_should_reauthorize_and_commit_when_credential_found(self, service, mock_db_session): p = MagicMock(spec=DatasourceProvider) - mock_db_session.query().first.return_value = p - mock_db_session.query().count.return_value = 0 + mock_db_session.scalar.side_effect = [p, 0] # first: fetch provider, second: name conflict count with patch.object(service, "extract_secret_variables", return_value=[]): service.reauthorize_datasource_oauth_provider("n", "t1", make_id(), "u", 1, {}, "oid") def test_should_auto_rename_when_reauth_name_conflicts(self, service, mock_db_session): p = MagicMock(spec=DatasourceProvider) - mock_db_session.query().first.return_value = p - mock_db_session.query().count.return_value = 1 # conflict - mock_db_session.query().all.return_value = [] + mock_db_session.scalar.side_effect = [p, 1] # first: fetch provider, second: name conflict count + mock_db_session.scalars.return_value.all.return_value = [] with patch.object(service, "extract_secret_variables", return_value=["tok"]): service.reauthorize_datasource_oauth_provider( "conflict_name", "t1", make_id(), "u", 9999, {"tok": "v"}, "cred-id" @@ -525,16 +519,14 @@ class TestDatasourceProviderService: def test_should_encrypt_secret_fields_when_reauthorizing(self, service, mock_db_session): p = MagicMock(spec=DatasourceProvider) - mock_db_session.query().first.return_value = p - mock_db_session.query().count.return_value = 0 + mock_db_session.scalar.side_effect = [p, 0] # first: fetch provider, second: name conflict count with patch.object(service, "extract_secret_variables", return_value=["tok"]): service.reauthorize_datasource_oauth_provider(None, "t1", make_id(), "u", 9999, {"tok": "val"}, "cred-id") self._enc.encrypt_token.assert_called() def test_should_acquire_redis_lock_when_reauthorizing(self, service, mock_db_session): p = MagicMock(spec=DatasourceProvider) - mock_db_session.query().first.return_value = p - mock_db_session.query().count.return_value = 0 + mock_db_session.scalar.side_effect = [p, 0] # first: fetch provider, second: name conflict count with patch.object(service, "extract_secret_variables", return_value=[]): service.reauthorize_datasource_oauth_provider("n", "t1", make_id(), "u", 1, {}, "oid") self._redis.lock.assert_called() @@ -545,13 +537,13 @@ class TestDatasourceProviderService: def test_should_raise_value_error_when_api_key_name_already_exists(self, service, mock_db_session, mock_user): """explicit name supplied + conflict → raises ValueError immediately.""" - mock_db_session.query().count.return_value = 1 + mock_db_session.scalar.return_value = 1 with patch("services.datasource_provider_service.get_current_user", return_value=mock_user): with pytest.raises(ValueError, match="already exists"): service.add_datasource_api_key_provider("clash", "t1", make_id(), {"sk": "v"}) def test_should_raise_value_error_when_credentials_validation_fails(self, service, mock_db_session, mock_user): - mock_db_session.query().count.return_value = 0 + mock_db_session.scalar.return_value = 0 with ( patch("services.datasource_provider_service.get_current_user", return_value=mock_user), patch.object(service.provider_manager, "validate_provider_credentials", side_effect=Exception("bad cred")), @@ -561,7 +553,7 @@ class TestDatasourceProviderService: service.add_datasource_api_key_provider("nm", "t1", make_id(), {"k": "v"}) def test_should_add_api_key_provider_and_commit_when_valid(self, service, mock_db_session, mock_user): - mock_db_session.query().count.return_value = 0 + mock_db_session.scalar.return_value = 0 with ( patch("services.datasource_provider_service.get_current_user", return_value=mock_user), patch.object(service.provider_manager, "validate_provider_credentials"), @@ -571,7 +563,7 @@ class TestDatasourceProviderService: mock_db_session.add.assert_called_once() def test_should_acquire_redis_lock_when_adding_api_key_provider(self, service, mock_db_session, mock_user): - mock_db_session.query().count.return_value = 0 + mock_db_session.scalar.return_value = 0 with ( patch("services.datasource_provider_service.get_current_user", return_value=mock_user), patch.object(service.provider_manager, "validate_provider_credentials"), @@ -694,7 +686,7 @@ class TestDatasourceProviderService: # ----------------------------------------------------------------------- def test_should_raise_value_error_when_credential_not_found_on_update(self, service, mock_db_session, mock_user): - mock_db_session.query().first.return_value = None + mock_db_session.scalar.return_value = None with patch("services.datasource_provider_service.get_current_user", return_value=mock_user): with pytest.raises(ValueError, match="not found"): service.update_datasource_credentials("t1", "id", "prov", "org/plug", {}, "name") @@ -704,8 +696,7 @@ class TestDatasourceProviderService: p.name = "old_name" p.auth_type = "api_key" p.encrypted_credentials = {"sk": "e"} - mock_db_session.query().first.return_value = p - mock_db_session.query().count.return_value = 1 + mock_db_session.scalar.side_effect = [p, 1] # first: fetch provider, second: name conflict count with patch("services.datasource_provider_service.get_current_user", return_value=mock_user): with pytest.raises(ValueError, match="already exists"): service.update_datasource_credentials("t1", "id", "prov", "org/plug", {}, "new_name") @@ -717,8 +708,7 @@ class TestDatasourceProviderService: p.name = "old_name" p.auth_type = "api_key" p.encrypted_credentials = {"sk": "e"} - mock_db_session.query().first.return_value = p - mock_db_session.query().count.return_value = 0 + mock_db_session.scalar.side_effect = [p, 0] # first: fetch provider, second: name conflict count with ( patch("services.datasource_provider_service.get_current_user", return_value=mock_user), patch.object(service, "extract_secret_variables", return_value=["sk"]), @@ -733,8 +723,7 @@ class TestDatasourceProviderService: p.name = "old_name" p.auth_type = "api_key" p.encrypted_credentials = {"sk": "old_enc"} - mock_db_session.query().first.return_value = p - mock_db_session.query().count.return_value = 0 + mock_db_session.scalar.side_effect = [p, 0] # first: fetch provider, second: name conflict count with ( patch("services.datasource_provider_service.get_current_user", return_value=mock_user), patch.object(service, "extract_secret_variables", return_value=["sk"]),