1
0
mirror of synced 2025-12-25 02:09:19 -05:00

bug(cdk) Always return a connection status even if an exception was raised (#45205)

This commit is contained in:
Alexandre Girard
2024-09-26 21:26:26 -07:00
committed by GitHub
parent 65622a9b0a
commit f01a43cc4b
5 changed files with 87 additions and 14 deletions

View File

@@ -20,6 +20,7 @@ from airbyte_cdk.connector import TConfig
from airbyte_cdk.exception_handler import init_uncaught_exception_handler
from airbyte_cdk.logger import init_logger
from airbyte_cdk.models import ( # type: ignore [attr-defined]
AirbyteConnectionStatus,
AirbyteMessage,
AirbyteMessageSerializer,
AirbyteStateStats,
@@ -139,12 +140,28 @@ class AirbyteEntrypoint(object):
self.validate_connection(source_spec, config)
except AirbyteTracedException as traced_exc:
connection_status = traced_exc.as_connection_status_message()
# The platform uses the exit code to surface unexpected failures, so we raise the exception if the failure type is not a config error
# If the failure is not exceptional, we'll emit a failed connection status message and return
if traced_exc.failure_type != FailureType.config_error:
raise traced_exc
if connection_status:
yield from self._emit_queued_messages(self.source)
yield connection_status
return
check_result = self.source.check(self.logger, config)
try:
check_result = self.source.check(self.logger, config)
except AirbyteTracedException as traced_exc:
yield traced_exc.as_airbyte_message()
# The platform uses the exit code to surface unexpected failures, so we raise the exception if the failure type is not a config error
# If the failure is not exceptional, we'll emit a failed connection status message and return
if traced_exc.failure_type != FailureType.config_error:
raise traced_exc
else:
yield AirbyteMessage(
type=Type.CONNECTION_STATUS, connectionStatus=AirbyteConnectionStatus(status=Status.FAILED, message=traced_exc.message)
)
return
if check_result.status == Status.SUCCEEDED:
self.logger.info("Check succeeded")
else:

View File

@@ -1,7 +1,7 @@
#
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
#
from airbyte_cdk import AirbyteTracedException
from airbyte_cdk.sources.file_based.exceptions import FileBasedSourceError
from unit_tests.sources.file_based.helpers import (
FailingSchemaValidationPolicy,
@@ -130,7 +130,7 @@ error_empty_stream_scenario = (
_base_failure_scenario.copy()
.set_name("error_empty_stream_scenario")
.set_source_builder(_base_failure_scenario.copy().source_builder.copy().set_files({}))
.set_expected_check_error(AirbyteTracedException, FileBasedSourceError.EMPTY_STREAM.value)
.set_expected_check_error(None, FileBasedSourceError.EMPTY_STREAM.value)
).build()
@@ -142,7 +142,7 @@ error_listing_files_scenario = (
TestErrorListMatchingFilesInMemoryFilesStreamReader(files=_base_failure_scenario.source_builder._files, file_type="csv")
)
)
.set_expected_check_error(AirbyteTracedException, FileBasedSourceError.ERROR_LISTING_FILES.value)
.set_expected_check_error(None, FileBasedSourceError.ERROR_LISTING_FILES.value)
).build()
@@ -154,7 +154,7 @@ error_reading_file_scenario = (
TestErrorOpenFileInMemoryFilesStreamReader(files=_base_failure_scenario.source_builder._files, file_type="csv")
)
)
.set_expected_check_error(AirbyteTracedException, FileBasedSourceError.ERROR_READING_FILE.value)
.set_expected_check_error(None, FileBasedSourceError.ERROR_READING_FILE.value)
).build()
@@ -216,5 +216,5 @@ error_multi_stream_scenario = (
],
}
)
.set_expected_check_error(AirbyteTracedException, FileBasedSourceError.ERROR_READING_FILE.value)
.set_expected_check_error(None, FileBasedSourceError.ERROR_READING_FILE.value)
).build()

View File

@@ -1940,7 +1940,7 @@ schemaless_with_user_input_schema_fails_connection_check_scenario: TestScenario[
}
)
.set_expected_check_status("FAILED")
.set_expected_check_error(AirbyteTracedException, FileBasedSourceError.CONFIG_VALIDATION_ERROR.value)
.set_expected_check_error(None, FileBasedSourceError.CONFIG_VALIDATION_ERROR.value)
.set_expected_discover_error(ConfigValidationError, FileBasedSourceError.CONFIG_VALIDATION_ERROR.value)
.set_expected_read_error(ConfigValidationError, FileBasedSourceError.CONFIG_VALIDATION_ERROR.value)
).build()
@@ -2030,7 +2030,7 @@ schemaless_with_user_input_schema_fails_connection_check_multi_stream_scenario:
}
)
.set_expected_check_status("FAILED")
.set_expected_check_error(AirbyteTracedException, FileBasedSourceError.CONFIG_VALIDATION_ERROR.value)
.set_expected_check_error(None, FileBasedSourceError.CONFIG_VALIDATION_ERROR.value)
.set_expected_discover_error(ConfigValidationError, FileBasedSourceError.CONFIG_VALIDATION_ERROR.value)
.set_expected_read_error(ConfigValidationError, FileBasedSourceError.CONFIG_VALIDATION_ERROR.value)
).build()
@@ -3240,7 +3240,6 @@ earlier_csv_scenario: TestScenario[InMemoryFilesSource] = (
}
)
.set_expected_records(None)
.set_expected_check_error(AirbyteTracedException, None)
).build()
csv_no_records_scenario: TestScenario[InMemoryFilesSource] = (
@@ -3340,5 +3339,4 @@ csv_no_files_scenario: TestScenario[InMemoryFilesSource] = (
}
)
.set_expected_records(None)
.set_expected_check_error(AirbyteTracedException, None)
).build()

View File

@@ -206,7 +206,15 @@ def check(capsys: CaptureFixture[str], tmp_path: PosixPath, scenario: TestScenar
["check", "--config", make_file(tmp_path / "config.json", scenario.config)],
)
captured = capsys.readouterr()
return json.loads(captured.out.splitlines()[0])["connectionStatus"] # type: ignore
return _find_connection_status(captured.out.splitlines())
def _find_connection_status(output: List[str]) -> Mapping[str, Any]:
for line in output:
json_line = json.loads(line)
if "connectionStatus" in json_line:
return json_line["connectionStatus"]
raise ValueError("No valid connectionStatus found in output")
def discover(capsys: CaptureFixture[str], tmp_path: PosixPath, scenario: TestScenario[AbstractSource]) -> Dict[str, Any]:

View File

@@ -10,6 +10,7 @@ from typing import Any, List, Mapping, MutableMapping, Union
from unittest import mock
from unittest.mock import MagicMock, patch
import freezegun
import pytest
import requests
from airbyte_cdk import AirbyteEntrypoint
@@ -32,6 +33,7 @@ from airbyte_cdk.models import (
AirbyteStreamStatusTraceMessage,
AirbyteTraceMessage,
ConnectorSpecification,
FailureType,
OrchestratorType,
Status,
StreamDescriptor,
@@ -151,6 +153,8 @@ def _wrap_message(submessage: Union[AirbyteConnectionStatus, ConnectorSpecificat
message = AirbyteMessage(type=Type.CATALOG, catalog=submessage)
elif isinstance(submessage, AirbyteRecordMessage):
message = AirbyteMessage(type=Type.RECORD, record=submessage)
elif isinstance(submessage, AirbyteTraceMessage):
message = AirbyteMessage(type=Type.TRACE, trace=submessage)
else:
raise Exception(f"Unknown message type: {submessage}")
@@ -219,13 +223,59 @@ def test_run_check(entrypoint: AirbyteEntrypoint, mocker, spec_mock, config_mock
assert spec_mock.called
# NOTE(review): this span looks like a rendered diff hunk whose +/- markers and indentation
# were lost, so pre- and post-change lines appear side by side (`MockSource.check` is patched
# twice, and both an assertion on `messages` and a bare `list(entrypoint.run(...))` call are
# present). Confirm against the repository which lines are current before relying on it.
@freezegun.freeze_time("1970-01-01T00:00:00.001Z")
def test_run_check_with_exception(entrypoint: AirbyteEntrypoint, mocker, spec_mock, config_mock):
exception = ValueError("Any error")
parsed_args = Namespace(command="check", config="config_path")
# Both patches below make `MockSource.check` raise a ValueError("Any error") when called.
mocker.patch.object(MockSource, "check", side_effect=ValueError("Any error"))
mocker.patch.object(MockSource, "check", side_effect=exception)
# Running the `check` command is expected to raise the ValueError out of `entrypoint.run`.
with pytest.raises(ValueError):
messages = list(entrypoint.run(parsed_args))
assert [orjson.dumps(AirbyteMessageSerializer.dump(MESSAGE_FROM_REPOSITORY)).decode()] == messages
list(entrypoint.run(parsed_args))
@freezegun.freeze_time("1970-01-01T00:00:00.001Z")
def test_run_check_with_traced_exception(entrypoint: AirbyteEntrypoint, mocker, spec_mock, config_mock):
    """An AirbyteTracedException raised by `check` (failure type left as built) escapes `entrypoint.run`."""
    traced_error = AirbyteTracedException.from_exception(ValueError("Any error"))
    mocker.patch.object(MockSource, "check", side_effect=traced_error)
    check_args = Namespace(command="check", config="config_path")

    # Draining the generator is what triggers the patched `check` call.
    with pytest.raises(AirbyteTracedException):
        list(entrypoint.run(check_args))
@freezegun.freeze_time("1970-01-01T00:00:00.001Z")
def test_run_check_with_config_error(entrypoint: AirbyteEntrypoint, mocker, spec_mock, config_mock):
    """A config-error traced exception from `check` is reported as messages instead of being raised.

    The run is expected to emit, in order: the repository message, the exception's trace
    message, and a FAILED connection status carrying the exception's message.
    """
    config_error = AirbyteTracedException.from_exception(ValueError("Any error"))
    config_error.failure_type = FailureType.config_error
    mocker.patch.object(MockSource, "check", side_effect=config_error)

    emitted = list(entrypoint.run(Namespace(command="check", config="config_path")))

    # Pin both timestamps to 1 so the expected trace serializes identically to the emitted one
    # (presumably the frozen clock expressed in epoch milliseconds — TODO confirm).
    trace_message = config_error.as_airbyte_message()
    trace_message.emitted_at = 1
    trace_message.trace.emitted_at = 1

    failed_status = AirbyteConnectionStatus(
        status=Status.FAILED,
        message=AirbyteTracedException.from_exception(config_error).message,
    )

    assert emitted == [
        orjson.dumps(AirbyteMessageSerializer.dump(MESSAGE_FROM_REPOSITORY)).decode(),
        orjson.dumps(AirbyteMessageSerializer.dump(trace_message)).decode(),
        _wrap_message(failed_status),
    ]
@freezegun.freeze_time("1970-01-01T00:00:00.001Z")
def test_run_check_with_transient_error(entrypoint: AirbyteEntrypoint, mocker, spec_mock, config_mock):
    """A transient-error traced exception raised by `check` escapes `entrypoint.run`."""
    transient_error = AirbyteTracedException.from_exception(ValueError("Any error"))
    transient_error.failure_type = FailureType.transient_error
    mocker.patch.object(MockSource, "check", side_effect=transient_error)
    check_args = Namespace(command="check", config="config_path")

    # Draining the generator is what triggers the patched `check` call.
    with pytest.raises(AirbyteTracedException):
        list(entrypoint.run(check_args))
def test_run_discover(entrypoint: AirbyteEntrypoint, mocker, spec_mock, config_mock):