1
0
mirror of synced 2025-12-25 02:09:19 -05:00

[ISSUE #7865] use HttpClient for bulk, authentication and schema disc… (#38563)

This commit is contained in:
Maxime Carbonneau-Leclerc
2024-05-27 09:56:29 -04:00
committed by GitHub
parent e2565fe21c
commit c6f8d23eef
13 changed files with 288 additions and 232 deletions

View File

@@ -27,14 +27,12 @@ from airbyte_cdk.models import (
from airbyte_cdk.sources.streams import Stream
from airbyte_cdk.sources.streams.concurrent.adapters import StreamFacade
from airbyte_cdk.test.catalog_builder import CatalogBuilder
from airbyte_cdk.test.entrypoint_wrapper import read
from airbyte_cdk.test.state_builder import StateBuilder
from airbyte_cdk.utils import AirbyteTracedException
from conftest import encoding_symbols_parameters, generate_stream
from requests.exceptions import ChunkedEncodingError, HTTPError
from requests.exceptions import ChunkedEncodingError
from salesforce_job_response_builder import JobInfoResponseBuilder
from source_salesforce.api import Salesforce
from source_salesforce.exceptions import AUTHENTICATION_ERROR_MESSAGE_MAPPING
from source_salesforce.source import SourceSalesforce
from source_salesforce.streams import (
CSV_FIELD_SIZE_LIMIT,
@@ -84,30 +82,27 @@ def test_stream_slice_step_validation(stream_slice_step: str, expected_error_mes
@pytest.mark.parametrize(
"login_status_code, login_json_resp, expected_error_msg, is_config_error",
"login_status_code, login_json_resp, expected_error_msg",
[
(
400,
{"error": "invalid_grant", "error_description": "expired access/refresh token"},
AUTHENTICATION_ERROR_MESSAGE_MAPPING.get("expired access/refresh token"),
True,
"The authentication to SalesForce has expired. Re-authenticate to restore access to SalesForce.",
),
(
400,
{"error": "invalid_grant", "error_description": "Authentication failure."},
'An error occurred: {"error": "invalid_grant", "error_description": "Authentication failure."}',
False,
),
(
401,
{"error": "Unauthorized", "error_description": "Unautorized"},
'An error occurred: {"error": "Unauthorized", "error_description": "Unautorized"}',
False,
),
],
)
def test_login_authentication_error_handler(
stream_config, requests_mock, login_status_code, login_json_resp, expected_error_msg, is_config_error
stream_config, requests_mock, login_status_code, login_json_resp, expected_error_msg
):
source = SourceSalesforce(_ANY_CATALOG, _ANY_CONFIG, _ANY_STATE)
logger = logging.getLogger("airbyte")
@@ -115,24 +110,19 @@ def test_login_authentication_error_handler(
"POST", "https://login.salesforce.com/services/oauth2/token", json=login_json_resp, status_code=login_status_code
)
if is_config_error:
with pytest.raises(AirbyteTracedException) as err:
source.check_connection(logger, stream_config)
assert err.value.message == expected_error_msg
else:
result, msg = source.check_connection(logger, stream_config)
assert result is False
assert msg == expected_error_msg
with pytest.raises(AirbyteTracedException) as err:
source.check_connection(logger, stream_config)
assert err.value.message == expected_error_msg
def test_bulk_sync_creation_failed(stream_config, stream_api):
stream: BulkIncrementalSalesforceStream = generate_stream("Account", stream_config, stream_api)
with requests_mock.Mocker() as m:
m.register_uri("POST", stream.path(), status_code=400, json=[{"message": "test_error"}])
with pytest.raises(HTTPError) as err:
with pytest.raises(AirbyteTracedException) as err:
stream_slices = next(iter(stream.stream_slices(sync_mode=SyncMode.incremental)))
next(stream.read_records(sync_mode=SyncMode.full_refresh, stream_slice=stream_slices))
assert err.value.response.json()[0]["message"] == "test_error"
assert "test_error" in str(err.value.message)
def test_bulk_stream_fallback_to_rest(mocker, requests_mock, stream_config, stream_api):
@@ -340,12 +330,6 @@ def test_encoding_symbols(stream_config, stream_api, chunk_size, content_type_he
@pytest.mark.parametrize(
"login_status_code, login_json_resp, discovery_status_code, discovery_resp_json, expected_error_msg",
(
(
403,
[{"errorCode": "REQUEST_LIMIT_EXCEEDED", "message": "TotalRequests Limit exceeded."}],
200,
{},
"API Call limit is exceeded. Make sure that you have enough API allocation for your organization needs or retry later. For more information, see https://developer.salesforce.com/docs/atlas.en-us.salesforce_app_limits_cheatsheet.meta/salesforce_app_limits_cheatsheet/salesforce_app_limits_platform_api.htm"),
(
200,
{"access_token": "access_token", "instance_url": "https://instance_url"},
@@ -366,9 +350,9 @@ def test_check_connection_rate_limit(
m.register_uri(
"GET", "https://instance_url/services/data/v57.0/sobjects", json=discovery_resp_json, status_code=discovery_status_code
)
result, msg = source.check_connection(logger, stream_config)
assert result is False
assert msg == expected_error_msg
with pytest.raises(AirbyteTracedException) as exception:
source.check_connection(logger, stream_config)
assert exception.value.message == expected_error_msg
def configure_request_params_mock(stream_1, stream_2):
@@ -446,50 +430,47 @@ def test_csv_reader_dialect_unix():
assert result == data
@patch("source_salesforce.source.BulkSalesforceStream._non_retryable_send_http_request")
def test_given_retryable_error_when_download_data_then_retry(send_http_request_patch):
send_http_request_patch.return_value.iter_content.side_effect = [HTTPError(), _A_CHUNKED_RESPONSE]
@pytest.fixture(name="mocked_response")
def _create_mocked_response():
    """Build the ``mocked_response`` fixture: a ``Mock`` with empty ``headers``.

    ``headers`` is set to a real (empty) dict rather than left as an
    auto-created child ``Mock``; every other attribute stays a default
    ``Mock`` attribute that individual tests configure themselves.
    """
    # Keyword arguments to Mock() are applied as attributes on the new mock,
    # so this is the same as creating the mock and then assigning `.headers`.
    return Mock(headers={})
@patch("source_salesforce.streams.HttpClient")
def test_given_retryable_error_when_download_data_then_retry(mocked_http_client, mocked_response):
mocked_http_client.return_value.send_request.return_value = (Mock(), mocked_response)
mocked_response.iter_content.side_effect = [ChunkedEncodingError(), _A_CHUNKED_RESPONSE]
BulkSalesforceStream(stream_name=_A_STREAM_NAME, sf_api=Mock(), pk=_A_PK).download_data(url="any url")
assert send_http_request_patch.call_count == 2
assert mocked_response.iter_content.call_count == 2
@patch("source_salesforce.source.BulkSalesforceStream._non_retryable_send_http_request")
def test_given_first_download_fail_when_download_data_then_retry_job_only_once(send_http_request_patch):
@patch("source_salesforce.streams.HttpClient")
def test_given_first_download_fail_when_download_data_then_retry_job_only_once(mocked_http_client, mocked_response):
sf_api = Mock()
sf_api.generate_schema.return_value = JobInfoResponseBuilder().with_state("JobComplete").get_response()
sf_api.instance_url = "http://test_given_first_download_fail_when_download_data_then_retry_job.com"
job_creation_return_values = [_A_JSON_RESPONSE, _A_SUCCESSFUL_JOB_CREATION_RESPONSE]
send_http_request_patch.return_value.json.side_effect = job_creation_return_values * 2
send_http_request_patch.return_value.iter_content.side_effect = HTTPError()
mocked_http_client.return_value.send_request.return_value = (Mock(), mocked_response)
mocked_response.json.side_effect = job_creation_return_values * 2
mocked_response.iter_content.side_effect = ChunkedEncodingError()
with pytest.raises(Exception):
list(BulkSalesforceStream(stream_name=_A_STREAM_NAME, sf_api=sf_api, pk=_A_PK).read_records(SyncMode.full_refresh))
assert send_http_request_patch.call_count == (len(job_creation_return_values) + _NUMBER_OF_DOWNLOAD_TRIES) * 2
assert mocked_response.json.call_count == len(job_creation_return_values) * 2
assert mocked_response.iter_content.call_count == _NUMBER_OF_DOWNLOAD_TRIES * 2
@patch("source_salesforce.source.BulkSalesforceStream._non_retryable_send_http_request")
def test_given_http_errors_when_create_stream_job_then_retry(send_http_request_patch):
send_http_request_patch.return_value.json.side_effect = [HTTPError(), _A_JSON_RESPONSE]
@patch("source_salesforce.streams.HttpClient")
def test_given_retryable_error_that_are_not_http_errors_when_create_stream_job_then_retry(mocked_http_client, mocked_response):
mocked_http_client.return_value.send_request.return_value = (Mock(), mocked_response)
mocked_response.json.side_effect = [ChunkedEncodingError(), _A_JSON_RESPONSE]
BulkSalesforceStream(stream_name=_A_STREAM_NAME, sf_api=Mock(), pk=_A_PK).create_stream_job(query="any query", url="any url")
assert send_http_request_patch.call_count == 2
@patch("source_salesforce.source.BulkSalesforceStream._non_retryable_send_http_request")
def test_given_fail_with_http_errors_when_create_stream_job_then_handle_error(send_http_request_patch):
mocked_response = Mock()
mocked_response.status_code = 666
send_http_request_patch.return_value.json.side_effect = HTTPError(response=mocked_response)
with pytest.raises(HTTPError):
BulkSalesforceStream(stream_name=_A_STREAM_NAME, sf_api=Mock(), pk=_A_PK).create_stream_job(query="any query", url="any url")
@patch("source_salesforce.source.BulkSalesforceStream._non_retryable_send_http_request")
def test_given_retryable_error_that_are_not_http_errors_when_create_stream_job_then_retry(send_http_request_patch):
send_http_request_patch.return_value.json.side_effect = [ChunkedEncodingError(), _A_JSON_RESPONSE]
BulkSalesforceStream(stream_name=_A_STREAM_NAME, sf_api=Mock(), pk=_A_PK).create_stream_job(query="any query", url="any url")
assert send_http_request_patch.call_count == 2
assert mocked_http_client.return_value.send_request.call_count == 2
@pytest.mark.parametrize(

View File

@@ -69,7 +69,7 @@ class FullRefreshTest(TestCase):
self._config = ConfigBuilder().client_id(_CLIENT_ID).client_secret(_CLIENT_SECRET).refresh_token(_REFRESH_TOKEN)
@HttpMocker()
def test_given_error_on_fetch_chunk_when_read_then_retry(self, http_mocker: HttpMocker) -> None:
def test_given_error_on_fetch_chunk_of_properties_when_read_then_retry(self, http_mocker: HttpMocker) -> None:
given_authentication(http_mocker, _CLIENT_ID, _CLIENT_SECRET, _REFRESH_TOKEN, _INSTANCE_URL)
given_stream(http_mocker, _BASE_URL, _STREAM_NAME, SalesforceDescribeResponseBuilder().field(_A_FIELD_NAME))
http_mocker.get(

View File

@@ -7,7 +7,7 @@ from airbyte_cdk.test.mock_http.response_builder import HttpResponseBuilder, fin
class JobCreateResponseBuilder:
def __init__(self):
def __init__(self) -> None:
self._response = {
"id": "any_id",
"operation": "query",
@@ -37,7 +37,7 @@ class JobCreateResponseBuilder:
class JobInfoResponseBuilder:
def __init__(self):
def __init__(self) -> None:
self._response = find_template("job_response", __file__)
self._status_code = 200