This commit is contained in:
committed by
GitHub
parent
e2565fe21c
commit
c6f8d23eef
@@ -27,14 +27,12 @@ from airbyte_cdk.models import (
|
||||
from airbyte_cdk.sources.streams import Stream
|
||||
from airbyte_cdk.sources.streams.concurrent.adapters import StreamFacade
|
||||
from airbyte_cdk.test.catalog_builder import CatalogBuilder
|
||||
from airbyte_cdk.test.entrypoint_wrapper import read
|
||||
from airbyte_cdk.test.state_builder import StateBuilder
|
||||
from airbyte_cdk.utils import AirbyteTracedException
|
||||
from conftest import encoding_symbols_parameters, generate_stream
|
||||
from requests.exceptions import ChunkedEncodingError, HTTPError
|
||||
from requests.exceptions import ChunkedEncodingError
|
||||
from salesforce_job_response_builder import JobInfoResponseBuilder
|
||||
from source_salesforce.api import Salesforce
|
||||
from source_salesforce.exceptions import AUTHENTICATION_ERROR_MESSAGE_MAPPING
|
||||
from source_salesforce.source import SourceSalesforce
|
||||
from source_salesforce.streams import (
|
||||
CSV_FIELD_SIZE_LIMIT,
|
||||
@@ -84,30 +82,27 @@ def test_stream_slice_step_validation(stream_slice_step: str, expected_error_mes
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"login_status_code, login_json_resp, expected_error_msg, is_config_error",
|
||||
"login_status_code, login_json_resp, expected_error_msg",
|
||||
[
|
||||
(
|
||||
400,
|
||||
{"error": "invalid_grant", "error_description": "expired access/refresh token"},
|
||||
AUTHENTICATION_ERROR_MESSAGE_MAPPING.get("expired access/refresh token"),
|
||||
True,
|
||||
"The authentication to SalesForce has expired. Re-authenticate to restore access to SalesForce.",
|
||||
),
|
||||
(
|
||||
400,
|
||||
{"error": "invalid_grant", "error_description": "Authentication failure."},
|
||||
'An error occurred: {"error": "invalid_grant", "error_description": "Authentication failure."}',
|
||||
False,
|
||||
),
|
||||
(
|
||||
401,
|
||||
{"error": "Unauthorized", "error_description": "Unautorized"},
|
||||
'An error occurred: {"error": "Unauthorized", "error_description": "Unautorized"}',
|
||||
False,
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_login_authentication_error_handler(
|
||||
stream_config, requests_mock, login_status_code, login_json_resp, expected_error_msg, is_config_error
|
||||
stream_config, requests_mock, login_status_code, login_json_resp, expected_error_msg
|
||||
):
|
||||
source = SourceSalesforce(_ANY_CATALOG, _ANY_CONFIG, _ANY_STATE)
|
||||
logger = logging.getLogger("airbyte")
|
||||
@@ -115,24 +110,19 @@ def test_login_authentication_error_handler(
|
||||
"POST", "https://login.salesforce.com/services/oauth2/token", json=login_json_resp, status_code=login_status_code
|
||||
)
|
||||
|
||||
if is_config_error:
|
||||
with pytest.raises(AirbyteTracedException) as err:
|
||||
source.check_connection(logger, stream_config)
|
||||
assert err.value.message == expected_error_msg
|
||||
else:
|
||||
result, msg = source.check_connection(logger, stream_config)
|
||||
assert result is False
|
||||
assert msg == expected_error_msg
|
||||
with pytest.raises(AirbyteTracedException) as err:
|
||||
source.check_connection(logger, stream_config)
|
||||
assert err.value.message == expected_error_msg
|
||||
|
||||
|
||||
def test_bulk_sync_creation_failed(stream_config, stream_api):
|
||||
stream: BulkIncrementalSalesforceStream = generate_stream("Account", stream_config, stream_api)
|
||||
with requests_mock.Mocker() as m:
|
||||
m.register_uri("POST", stream.path(), status_code=400, json=[{"message": "test_error"}])
|
||||
with pytest.raises(HTTPError) as err:
|
||||
with pytest.raises(AirbyteTracedException) as err:
|
||||
stream_slices = next(iter(stream.stream_slices(sync_mode=SyncMode.incremental)))
|
||||
next(stream.read_records(sync_mode=SyncMode.full_refresh, stream_slice=stream_slices))
|
||||
assert err.value.response.json()[0]["message"] == "test_error"
|
||||
assert "test_error" in str(err.value.message)
|
||||
|
||||
|
||||
def test_bulk_stream_fallback_to_rest(mocker, requests_mock, stream_config, stream_api):
|
||||
@@ -340,12 +330,6 @@ def test_encoding_symbols(stream_config, stream_api, chunk_size, content_type_he
|
||||
@pytest.mark.parametrize(
|
||||
"login_status_code, login_json_resp, discovery_status_code, discovery_resp_json, expected_error_msg",
|
||||
(
|
||||
(
|
||||
403,
|
||||
[{"errorCode": "REQUEST_LIMIT_EXCEEDED", "message": "TotalRequests Limit exceeded."}],
|
||||
200,
|
||||
{},
|
||||
"API Call limit is exceeded. Make sure that you have enough API allocation for your organization needs or retry later. For more information, see https://developer.salesforce.com/docs/atlas.en-us.salesforce_app_limits_cheatsheet.meta/salesforce_app_limits_cheatsheet/salesforce_app_limits_platform_api.htm"),
|
||||
(
|
||||
200,
|
||||
{"access_token": "access_token", "instance_url": "https://instance_url"},
|
||||
@@ -366,9 +350,9 @@ def test_check_connection_rate_limit(
|
||||
m.register_uri(
|
||||
"GET", "https://instance_url/services/data/v57.0/sobjects", json=discovery_resp_json, status_code=discovery_status_code
|
||||
)
|
||||
result, msg = source.check_connection(logger, stream_config)
|
||||
assert result is False
|
||||
assert msg == expected_error_msg
|
||||
with pytest.raises(AirbyteTracedException) as exception:
|
||||
source.check_connection(logger, stream_config)
|
||||
assert exception.value.message == expected_error_msg
|
||||
|
||||
|
||||
def configure_request_params_mock(stream_1, stream_2):
|
||||
@@ -446,50 +430,47 @@ def test_csv_reader_dialect_unix():
|
||||
assert result == data
|
||||
|
||||
|
||||
@patch("source_salesforce.source.BulkSalesforceStream._non_retryable_send_http_request")
|
||||
def test_given_retryable_error_when_download_data_then_retry(send_http_request_patch):
|
||||
send_http_request_patch.return_value.iter_content.side_effect = [HTTPError(), _A_CHUNKED_RESPONSE]
|
||||
@pytest.fixture(name="mocked_response")
|
||||
def _create_mocked_response():
|
||||
http_response = Mock()
|
||||
http_response.headers = {}
|
||||
return http_response
|
||||
|
||||
|
||||
@patch("source_salesforce.streams.HttpClient")
|
||||
def test_given_retryable_error_when_download_data_then_retry(mocked_http_client, mocked_response):
|
||||
mocked_http_client.return_value.send_request.return_value = (Mock(), mocked_response)
|
||||
mocked_response.iter_content.side_effect = [ChunkedEncodingError(), _A_CHUNKED_RESPONSE]
|
||||
|
||||
BulkSalesforceStream(stream_name=_A_STREAM_NAME, sf_api=Mock(), pk=_A_PK).download_data(url="any url")
|
||||
assert send_http_request_patch.call_count == 2
|
||||
|
||||
assert mocked_response.iter_content.call_count == 2
|
||||
|
||||
|
||||
@patch("source_salesforce.source.BulkSalesforceStream._non_retryable_send_http_request")
|
||||
def test_given_first_download_fail_when_download_data_then_retry_job_only_once(send_http_request_patch):
|
||||
@patch("source_salesforce.streams.HttpClient")
|
||||
def test_given_first_download_fail_when_download_data_then_retry_job_only_once(mocked_http_client, mocked_response):
|
||||
sf_api = Mock()
|
||||
sf_api.generate_schema.return_value = JobInfoResponseBuilder().with_state("JobComplete").get_response()
|
||||
sf_api.instance_url = "http://test_given_first_download_fail_when_download_data_then_retry_job.com"
|
||||
job_creation_return_values = [_A_JSON_RESPONSE, _A_SUCCESSFUL_JOB_CREATION_RESPONSE]
|
||||
send_http_request_patch.return_value.json.side_effect = job_creation_return_values * 2
|
||||
send_http_request_patch.return_value.iter_content.side_effect = HTTPError()
|
||||
|
||||
mocked_http_client.return_value.send_request.return_value = (Mock(), mocked_response)
|
||||
mocked_response.json.side_effect = job_creation_return_values * 2
|
||||
mocked_response.iter_content.side_effect = ChunkedEncodingError()
|
||||
|
||||
with pytest.raises(Exception):
|
||||
list(BulkSalesforceStream(stream_name=_A_STREAM_NAME, sf_api=sf_api, pk=_A_PK).read_records(SyncMode.full_refresh))
|
||||
|
||||
assert send_http_request_patch.call_count == (len(job_creation_return_values) + _NUMBER_OF_DOWNLOAD_TRIES) * 2
|
||||
assert mocked_response.json.call_count == len(job_creation_return_values) * 2
|
||||
assert mocked_response.iter_content.call_count == _NUMBER_OF_DOWNLOAD_TRIES * 2
|
||||
|
||||
|
||||
@patch("source_salesforce.source.BulkSalesforceStream._non_retryable_send_http_request")
|
||||
def test_given_http_errors_when_create_stream_job_then_retry(send_http_request_patch):
|
||||
send_http_request_patch.return_value.json.side_effect = [HTTPError(), _A_JSON_RESPONSE]
|
||||
@patch("source_salesforce.streams.HttpClient")
|
||||
def test_given_retryable_error_that_are_not_http_errors_when_create_stream_job_then_retry(mocked_http_client, mocked_response):
|
||||
mocked_http_client.return_value.send_request.return_value = (Mock(), mocked_response)
|
||||
mocked_response.json.side_effect = [ChunkedEncodingError(), _A_JSON_RESPONSE]
|
||||
BulkSalesforceStream(stream_name=_A_STREAM_NAME, sf_api=Mock(), pk=_A_PK).create_stream_job(query="any query", url="any url")
|
||||
assert send_http_request_patch.call_count == 2
|
||||
|
||||
|
||||
@patch("source_salesforce.source.BulkSalesforceStream._non_retryable_send_http_request")
|
||||
def test_given_fail_with_http_errors_when_create_stream_job_then_handle_error(send_http_request_patch):
|
||||
mocked_response = Mock()
|
||||
mocked_response.status_code = 666
|
||||
send_http_request_patch.return_value.json.side_effect = HTTPError(response=mocked_response)
|
||||
|
||||
with pytest.raises(HTTPError):
|
||||
BulkSalesforceStream(stream_name=_A_STREAM_NAME, sf_api=Mock(), pk=_A_PK).create_stream_job(query="any query", url="any url")
|
||||
|
||||
|
||||
@patch("source_salesforce.source.BulkSalesforceStream._non_retryable_send_http_request")
|
||||
def test_given_retryable_error_that_are_not_http_errors_when_create_stream_job_then_retry(send_http_request_patch):
|
||||
send_http_request_patch.return_value.json.side_effect = [ChunkedEncodingError(), _A_JSON_RESPONSE]
|
||||
BulkSalesforceStream(stream_name=_A_STREAM_NAME, sf_api=Mock(), pk=_A_PK).create_stream_job(query="any query", url="any url")
|
||||
assert send_http_request_patch.call_count == 2
|
||||
assert mocked_http_client.return_value.send_request.call_count == 2
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
||||
@@ -69,7 +69,7 @@ class FullRefreshTest(TestCase):
|
||||
self._config = ConfigBuilder().client_id(_CLIENT_ID).client_secret(_CLIENT_SECRET).refresh_token(_REFRESH_TOKEN)
|
||||
|
||||
@HttpMocker()
|
||||
def test_given_error_on_fetch_chunk_when_read_then_retry(self, http_mocker: HttpMocker) -> None:
|
||||
def test_given_error_on_fetch_chunk_of_properties_when_read_then_retry(self, http_mocker: HttpMocker) -> None:
|
||||
given_authentication(http_mocker, _CLIENT_ID, _CLIENT_SECRET, _REFRESH_TOKEN, _INSTANCE_URL)
|
||||
given_stream(http_mocker, _BASE_URL, _STREAM_NAME, SalesforceDescribeResponseBuilder().field(_A_FIELD_NAME))
|
||||
http_mocker.get(
|
||||
|
||||
@@ -7,7 +7,7 @@ from airbyte_cdk.test.mock_http.response_builder import HttpResponseBuilder, fin
|
||||
|
||||
|
||||
class JobCreateResponseBuilder:
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
self._response = {
|
||||
"id": "any_id",
|
||||
"operation": "query",
|
||||
@@ -37,7 +37,7 @@ class JobCreateResponseBuilder:
|
||||
|
||||
|
||||
class JobInfoResponseBuilder:
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
self._response = find_template("job_response", __file__)
|
||||
self._status_code = 200
|
||||
|
||||
|
||||
Reference in New Issue
Block a user