349 lines
16 KiB
Python
349 lines
16 KiB
Python
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
|
|
|
|
import json
|
|
import logging
|
|
import os
|
|
from typing import Any, Iterator, List, Mapping, Optional
|
|
from unittest import TestCase
|
|
from unittest.mock import Mock, patch
|
|
|
|
from airbyte_cdk.models import (
|
|
AirbyteAnalyticsTraceMessage,
|
|
AirbyteCatalog,
|
|
AirbyteErrorTraceMessage,
|
|
AirbyteLogMessage,
|
|
AirbyteMessage,
|
|
AirbyteMessageSerializer,
|
|
AirbyteRecordMessage,
|
|
AirbyteStateBlob,
|
|
AirbyteStateMessage,
|
|
AirbyteStreamState,
|
|
AirbyteStreamStateSerializer,
|
|
AirbyteStreamStatus,
|
|
AirbyteStreamStatusTraceMessage,
|
|
AirbyteTraceMessage,
|
|
ConfiguredAirbyteCatalogSerializer,
|
|
Level,
|
|
StreamDescriptor,
|
|
TraceType,
|
|
Type,
|
|
)
|
|
from airbyte_cdk.sources.abstract_source import AbstractSource
|
|
from airbyte_cdk.test.entrypoint_wrapper import EntrypointOutput, discover, read
|
|
from airbyte_cdk.test.state_builder import StateBuilder
|
|
from orjson import orjson
|
|
|
|
|
|
def _a_state_message(stream_name: str, stream_state: Mapping[str, Any]) -> AirbyteMessage:
    """Build a STATE ``AirbyteMessage`` holding per-stream state for ``stream_name``.

    The entries of ``stream_state`` are passed as keyword arguments to ``AirbyteStateBlob``.
    """
    descriptor = StreamDescriptor(name=stream_name)
    blob = AirbyteStateBlob(**stream_state)
    per_stream_state = AirbyteStreamState(stream_descriptor=descriptor, stream_state=blob)
    return AirbyteMessage(type=Type.STATE, state=AirbyteStateMessage(stream=per_stream_state))
|
|
|
|
|
|
def _a_status_message(stream_name: str, status: AirbyteStreamStatus) -> AirbyteMessage:
    """Build a TRACE ``AirbyteMessage`` of type STREAM_STATUS reporting ``status`` for ``stream_name``.

    ``emitted_at`` is fixed at 0 so messages built by this helper are deterministic.
    """
    status_trace = AirbyteStreamStatusTraceMessage(
        stream_descriptor=StreamDescriptor(name=stream_name),
        status=status,
    )
    trace = AirbyteTraceMessage(type=TraceType.STREAM_STATUS, emitted_at=0, stream_status=status_trace)
    return AirbyteMessage(type=Type.TRACE, trace=trace)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Canned protocol messages, serialized through ``_to_entrypoint_output`` to fake
# what the (mocked) ``AirbyteEntrypoint.run`` yields.
# ---------------------------------------------------------------------------
_A_CATALOG_MESSAGE = AirbyteMessage(type=Type.CATALOG, catalog=AirbyteCatalog(streams=[]))
_A_RECORD = AirbyteMessage(
    type=Type.RECORD,
    record=AirbyteRecordMessage(stream="stream", data={"record key": "record value"}, emitted_at=0),
)
_A_STATE_MESSAGE = _a_state_message("stream_name", {"state key": "state value for _A_STATE_MESSAGE"})
_A_LOG = AirbyteMessage(
    type=Type.LOG,
    log=AirbyteLogMessage(level=Level.INFO, message="This is an Airbyte log message"),
)
_AN_ERROR_MESSAGE = AirbyteMessage(
    type=Type.TRACE,
    trace=AirbyteTraceMessage(
        type=TraceType.ERROR,
        emitted_at=0,
        error=AirbyteErrorTraceMessage(message="AirbyteErrorTraceMessage message"),
    ),
)
_AN_ANALYTIC_MESSAGE = AirbyteMessage(
    type=Type.TRACE,
    trace=AirbyteTraceMessage(
        type=TraceType.ANALYTICS,
        emitted_at=0,
        analytics=AirbyteAnalyticsTraceMessage(type="an analytic type", value="an analytic value"),
    ),
)

# ---------------------------------------------------------------------------
# Inputs passed to ``discover`` / ``read``.
# NOTE(review): _A_STREAM_NAME ("a stream name") differs from the catalog's stream
# name ("a_stream_name"); every test patches AirbyteEntrypoint, so the two are never
# cross-checked here.
# ---------------------------------------------------------------------------
_A_STREAM_NAME = "a stream name"
_A_CONFIG = {"config_key": "config_value"}
_A_CATALOG = ConfiguredAirbyteCatalogSerializer.load(
    {
        "streams": [
            {
                "stream": {"name": "a_stream_name", "json_schema": {}, "supported_sync_modes": ["full_refresh"]},
                "sync_mode": "full_refresh",
                "destination_sync_mode": "append",
            }
        ]
    }
)
_A_STATE = (
    StateBuilder()
    .with_stream_state(_A_STREAM_NAME, {"state_key": "state_value"})
    .build()
)
_A_LOG_MESSAGE = "a log message"
|
|
|
|
|
|
def _to_entrypoint_output(messages: List[AirbyteMessage]) -> Iterator[str]:
    """Lazily serialize ``messages`` into the JSON strings used as the mocked ``run`` return value."""

    def _serialize(message: AirbyteMessage) -> str:
        return orjson.dumps(AirbyteMessageSerializer.dump(message)).decode()

    return map(_serialize, messages)
|
|
|
|
|
|
def _a_mocked_source() -> AbstractSource:
    """Return a ``Mock`` specced on ``AbstractSource`` whose ``message_repository`` is ``None``."""
    return Mock(spec=AbstractSource, message_repository=None)
|
|
|
|
|
|
def _validate_tmp_json_file(expected, file_path) -> None:
    """Assert that the JSON document stored at ``file_path`` equals ``expected``."""
    with open(file_path) as json_file:
        actual = json.load(json_file)
    assert actual == expected
|
|
|
|
|
|
def _validate_tmp_catalog(expected, file_path) -> None:
    """Assert that the catalog file at ``file_path`` deserializes to ``expected``.

    The file is parsed with ``orjson`` and loaded through ``ConfiguredAirbyteCatalogSerializer``,
    so the comparison is between catalog objects rather than raw JSON.
    """
    # Read inside a context manager so the handle is closed deterministically; the previous
    # bare ``open(...).read()`` left closing to the garbage collector (ResourceWarning).
    with open(file_path) as catalog_file:
        raw_catalog = orjson.loads(catalog_file.read())
    assert ConfiguredAirbyteCatalogSerializer.load(raw_catalog) == expected
|
|
|
|
|
|
def _create_tmp_file_validation(entrypoint, expected_config, expected_catalog: Optional[Any] = None, expected_state: Optional[Any] = None):
    """Build a ``side_effect`` for the mocked ``AirbyteEntrypoint.run`` that checks the temporary input files.

    The side effect runs while the wrapper is executing ``run``, i.e. while the temporary files named in the
    argv given to ``entrypoint.parse_args`` can still be read. It compares:
      * argv[2] (the value after ``--config``) with ``expected_config``;
      * argv[4] (the value after ``--catalog``) with ``expected_catalog``, if one is given;
      * argv[6] (the value after ``--state``) with ``expected_state``, if one is given.

    Note: the wrapper catches exceptions raised by ``run`` and reports them as error messages in its output,
    so callers must inspect ``output.errors`` for a failed comparison to fail their test.

    Args:
        entrypoint: the patched ``AirbyteEntrypoint`` mock.
        expected_config: expected JSON content of the config file.
        expected_catalog: expected catalog object, or ``None`` to skip the catalog check.
        expected_state: expected JSON content of the state file, or ``None`` to skip the state check.

    Returns:
        A callable suitable for ``entrypoint.return_value.run.side_effect``.
    """

    def _validate_tmp_files(parsed_args):
        argv = entrypoint.parse_args.call_args.args[0]
        _validate_tmp_json_file(expected_config, argv[2])
        # ``is not None`` rather than truthiness: an empty expected catalog/state must still be validated.
        if expected_catalog is not None:
            _validate_tmp_catalog(expected_catalog, argv[4])
        if expected_state is not None:
            _validate_tmp_json_file(expected_state, argv[6])
        # Hand back the mock's configured return value so the wrapper sees what it would without the side effect.
        return entrypoint.return_value.run.return_value

    return _validate_tmp_files
|
|
|
|
|
|
class EntrypointWrapperDiscoverTest(TestCase):
    """Tests for ``discover`` with ``AirbyteEntrypoint`` patched out, plus one ``EntrypointOutput`` parsing test."""

    def setUp(self) -> None:
        self._a_source = _a_mocked_source()

    def test_init_validation_error(self) -> None:
        # A line that does not parse as an AirbyteMessage is kept as an INFO log carrying the raw line.
        invalid_message = '{"type": "INVALID_TYPE"}'
        entrypoint_output = EntrypointOutput([invalid_message])
        messages = entrypoint_output._messages
        assert len(messages) == 1
        assert messages[0].type == Type.LOG
        assert messages[0].log.level == Level.INFO
        assert messages[0].log.message == invalid_message

    @patch("airbyte_cdk.test.entrypoint_wrapper.AirbyteEntrypoint")
    def test_when_discover_then_ensure_parameters(self, entrypoint):
        entrypoint.return_value.run.side_effect = _create_tmp_file_validation(entrypoint, _A_CONFIG)

        output = discover(self._a_source, _A_CONFIG)

        # The wrapper catches exceptions raised by ``run`` (see the uncaught-exception test below), so a failed
        # file-content assertion inside the side effect only surfaces as an error message in the output.
        assert not output.errors, output.errors
        entrypoint.assert_called_once_with(self._a_source)
        entrypoint.return_value.run.assert_called_once_with(entrypoint.parse_args.return_value)
        assert entrypoint.parse_args.call_count == 1
        assert entrypoint.parse_args.call_args.args[0][0] == "discover"
        assert entrypoint.parse_args.call_args.args[0][1] == "--config"

    @patch("airbyte_cdk.test.entrypoint_wrapper.AirbyteEntrypoint")
    def test_when_discover_then_ensure_files_are_temporary(self, entrypoint):
        discover(self._a_source, _A_CONFIG)

        assert not os.path.exists(entrypoint.parse_args.call_args.args[0][2])

    @patch("airbyte_cdk.test.entrypoint_wrapper.AirbyteEntrypoint")
    def test_given_logging_during_discover_when_discover_then_output_has_logs(self, entrypoint):
        def _do_some_logging(parsed_args):
            logging.getLogger("any logger").info(_A_LOG_MESSAGE)
            return entrypoint.return_value.run.return_value

        entrypoint.return_value.run.side_effect = _do_some_logging

        output = discover(self._a_source, _A_CONFIG)

        assert len(output.logs) == 1
        assert output.logs[0].log.message == _A_LOG_MESSAGE

    @patch("airbyte_cdk.test.entrypoint_wrapper.AirbyteEntrypoint")
    def test_given_catalog_when_discover_then_output_has_catalog(self, entrypoint):
        entrypoint.return_value.run.return_value = _to_entrypoint_output([_A_CATALOG_MESSAGE])
        output = discover(self._a_source, _A_CONFIG)
        assert AirbyteMessageSerializer.dump(output.catalog) == AirbyteMessageSerializer.dump(_A_CATALOG_MESSAGE)

    @patch("airbyte_cdk.test.entrypoint_wrapper.AirbyteEntrypoint")
    def test_given_log_when_discover_then_output_has_log(self, entrypoint):
        entrypoint.return_value.run.return_value = _to_entrypoint_output([_A_LOG])
        output = discover(self._a_source, _A_CONFIG)
        assert AirbyteMessageSerializer.dump(output.logs[0]) == AirbyteMessageSerializer.dump(_A_LOG)

    @patch("airbyte_cdk.test.entrypoint_wrapper.AirbyteEntrypoint")
    def test_given_trace_message_when_discover_then_output_has_trace_messages(self, entrypoint):
        entrypoint.return_value.run.return_value = _to_entrypoint_output([_AN_ANALYTIC_MESSAGE])
        output = discover(self._a_source, _A_CONFIG)
        assert AirbyteMessageSerializer.dump(output.analytics_messages[0]) == AirbyteMessageSerializer.dump(_AN_ANALYTIC_MESSAGE)

    @patch("airbyte_cdk.test.entrypoint_wrapper.print", create=True)
    @patch("airbyte_cdk.test.entrypoint_wrapper.AirbyteEntrypoint")
    def test_given_unexpected_exception_when_discover_then_print(self, entrypoint, print_mock):
        entrypoint.return_value.run.side_effect = ValueError("This error should be printed")
        discover(self._a_source, _A_CONFIG)
        assert print_mock.call_count > 0

    @patch("airbyte_cdk.test.entrypoint_wrapper.print", create=True)
    @patch("airbyte_cdk.test.entrypoint_wrapper.AirbyteEntrypoint")
    def test_given_expected_exception_when_discover_then_do_not_print(self, entrypoint, print_mock):
        entrypoint.return_value.run.side_effect = ValueError("This error should not be printed")
        discover(self._a_source, _A_CONFIG, expecting_exception=True)
        assert print_mock.call_count == 0

    @patch("airbyte_cdk.test.entrypoint_wrapper.AirbyteEntrypoint")
    def test_given_uncaught_exception_when_discover_then_output_has_error(self, entrypoint):
        entrypoint.return_value.run.side_effect = ValueError("An error")
        output = discover(self._a_source, _A_CONFIG)
        assert output.errors
|
|
|
|
|
|
class EntrypointWrapperReadTest(TestCase):
    """Tests for ``read`` with ``AirbyteEntrypoint`` patched out."""

    def setUp(self) -> None:
        self._a_source = _a_mocked_source()

    @patch("airbyte_cdk.test.entrypoint_wrapper.AirbyteEntrypoint")
    def test_when_read_then_ensure_parameters(self, entrypoint):
        # Local import: only this test needs the JSON form of the expected state.
        from airbyte_cdk.models import AirbyteStateMessageSerializer

        # The state file is compared as parsed JSON, so compare it with the serialized state messages.
        # NOTE(review): assumes StateBuilder.build() returns AirbyteStateMessage objects and that `read`
        # writes them via AirbyteStateMessageSerializer.dump — confirm against entrypoint_wrapper.
        expected_state_json = [AirbyteStateMessageSerializer.dump(state_message) for state_message in _A_STATE]
        entrypoint.return_value.run.side_effect = _create_tmp_file_validation(entrypoint, _A_CONFIG, _A_CATALOG, expected_state_json)

        output = read(self._a_source, _A_CONFIG, _A_CATALOG, _A_STATE)

        # The wrapper catches exceptions raised by ``run`` (see the uncaught-exception test below), so a failed
        # file-content assertion inside the side effect only surfaces as an error message in the output.
        assert not output.errors, output.errors
        entrypoint.assert_called_once_with(self._a_source)
        entrypoint.return_value.run.assert_called_once_with(entrypoint.parse_args.return_value)
        assert entrypoint.parse_args.call_count == 1
        assert entrypoint.parse_args.call_args.args[0][0] == "read"
        assert entrypoint.parse_args.call_args.args[0][1] == "--config"
        assert entrypoint.parse_args.call_args.args[0][3] == "--catalog"
        assert entrypoint.parse_args.call_args.args[0][5] == "--state"

    @patch("airbyte_cdk.test.entrypoint_wrapper.AirbyteEntrypoint")
    def test_when_read_then_ensure_files_are_temporary(self, entrypoint):
        read(self._a_source, _A_CONFIG, _A_CATALOG, _A_STATE)

        assert not os.path.exists(entrypoint.parse_args.call_args.args[0][2])
        assert not os.path.exists(entrypoint.parse_args.call_args.args[0][4])
        assert not os.path.exists(entrypoint.parse_args.call_args.args[0][6])

    @patch("airbyte_cdk.test.entrypoint_wrapper.AirbyteEntrypoint")
    def test_given_logging_during_run_when_read_then_output_has_logs(self, entrypoint):
        def _do_some_logging(parsed_args):
            logging.getLogger("any logger").info(_A_LOG_MESSAGE)
            return entrypoint.return_value.run.return_value

        entrypoint.return_value.run.side_effect = _do_some_logging

        output = read(self._a_source, _A_CONFIG, _A_CATALOG, _A_STATE)

        assert len(output.logs) == 1
        assert output.logs[0].log.message == _A_LOG_MESSAGE

    @patch("airbyte_cdk.test.entrypoint_wrapper.AirbyteEntrypoint")
    def test_given_record_when_read_then_output_has_record(self, entrypoint):
        entrypoint.return_value.run.return_value = _to_entrypoint_output([_A_RECORD])
        output = read(self._a_source, _A_CONFIG, _A_CATALOG, _A_STATE)
        assert AirbyteMessageSerializer.dump(output.records[0]) == AirbyteMessageSerializer.dump(_A_RECORD)

    @patch("airbyte_cdk.test.entrypoint_wrapper.AirbyteEntrypoint")
    def test_given_state_message_when_read_then_output_has_state_message(self, entrypoint):
        entrypoint.return_value.run.return_value = _to_entrypoint_output([_A_STATE_MESSAGE])
        output = read(self._a_source, _A_CONFIG, _A_CATALOG, _A_STATE)
        assert AirbyteMessageSerializer.dump(output.state_messages[0]) == AirbyteMessageSerializer.dump(_A_STATE_MESSAGE)

    @patch("airbyte_cdk.test.entrypoint_wrapper.AirbyteEntrypoint")
    def test_given_state_message_and_records_when_read_then_output_has_records_and_state_message(self, entrypoint):
        entrypoint.return_value.run.return_value = _to_entrypoint_output([_A_RECORD, _A_STATE_MESSAGE])
        output = read(self._a_source, _A_CONFIG, _A_CATALOG, _A_STATE)
        assert [AirbyteMessageSerializer.dump(message) for message in output.records_and_state_messages] == [
            AirbyteMessageSerializer.dump(message) for message in (_A_RECORD, _A_STATE_MESSAGE)
        ]

    @patch("airbyte_cdk.test.entrypoint_wrapper.AirbyteEntrypoint")
    def test_given_many_state_messages_when_read_then_most_recent_state_is_last_emitted(self, entrypoint):
        state_value = {"state_key": "last state value"}
        last_emitted_state = AirbyteStreamState(
            stream_descriptor=StreamDescriptor(name="stream_name"), stream_state=AirbyteStateBlob(**state_value)
        )
        entrypoint.return_value.run.return_value = _to_entrypoint_output([_A_STATE_MESSAGE, _a_state_message("stream_name", state_value)])

        output = read(self._a_source, _A_CONFIG, _A_CATALOG, _A_STATE)

        assert AirbyteStreamStateSerializer.dump(output.most_recent_state) == AirbyteStreamStateSerializer.dump(last_emitted_state)

    @patch("airbyte_cdk.test.entrypoint_wrapper.AirbyteEntrypoint")
    def test_given_log_when_read_then_output_has_log(self, entrypoint):
        entrypoint.return_value.run.return_value = _to_entrypoint_output([_A_LOG])
        output = read(self._a_source, _A_CONFIG, _A_CATALOG, _A_STATE)
        assert AirbyteMessageSerializer.dump(output.logs[0]) == AirbyteMessageSerializer.dump(_A_LOG)

    @patch("airbyte_cdk.test.entrypoint_wrapper.AirbyteEntrypoint")
    def test_given_trace_message_when_read_then_output_has_trace_messages(self, entrypoint):
        entrypoint.return_value.run.return_value = _to_entrypoint_output([_AN_ANALYTIC_MESSAGE])
        output = read(self._a_source, _A_CONFIG, _A_CATALOG, _A_STATE)
        assert AirbyteMessageSerializer.dump(output.analytics_messages[0]) == AirbyteMessageSerializer.dump(_AN_ANALYTIC_MESSAGE)

    @patch("airbyte_cdk.test.entrypoint_wrapper.AirbyteEntrypoint")
    def test_given_stream_statuses_when_read_then_return_statuses(self, entrypoint):
        status_messages = [
            _a_status_message(_A_STREAM_NAME, AirbyteStreamStatus.STARTED),
            _a_status_message(_A_STREAM_NAME, AirbyteStreamStatus.COMPLETE),
        ]
        entrypoint.return_value.run.return_value = _to_entrypoint_output(status_messages)
        output = read(self._a_source, _A_CONFIG, _A_CATALOG, _A_STATE)
        assert output.get_stream_statuses(_A_STREAM_NAME) == [AirbyteStreamStatus.STARTED, AirbyteStreamStatus.COMPLETE]

    @patch("airbyte_cdk.test.entrypoint_wrapper.AirbyteEntrypoint")
    def test_given_stream_statuses_for_many_streams_when_read_then_filter_other_streams(self, entrypoint):
        status_messages = [
            _a_status_message(_A_STREAM_NAME, AirbyteStreamStatus.STARTED),
            _a_status_message("another stream name", AirbyteStreamStatus.INCOMPLETE),
            _a_status_message(_A_STREAM_NAME, AirbyteStreamStatus.COMPLETE),
        ]
        entrypoint.return_value.run.return_value = _to_entrypoint_output(status_messages)
        output = read(self._a_source, _A_CONFIG, _A_CATALOG, _A_STATE)
        assert len(output.get_stream_statuses(_A_STREAM_NAME)) == 2

    @patch("airbyte_cdk.test.entrypoint_wrapper.print", create=True)
    @patch("airbyte_cdk.test.entrypoint_wrapper.AirbyteEntrypoint")
    def test_given_unexpected_exception_when_read_then_print(self, entrypoint, print_mock):
        entrypoint.return_value.run.side_effect = ValueError("This error should be printed")
        read(self._a_source, _A_CONFIG, _A_CATALOG, _A_STATE)
        assert print_mock.call_count > 0

    @patch("airbyte_cdk.test.entrypoint_wrapper.print", create=True)
    @patch("airbyte_cdk.test.entrypoint_wrapper.AirbyteEntrypoint")
    def test_given_expected_exception_when_read_then_do_not_print(self, entrypoint, print_mock):
        entrypoint.return_value.run.side_effect = ValueError("This error should not be printed")
        read(self._a_source, _A_CONFIG, _A_CATALOG, _A_STATE, expecting_exception=True)
        assert print_mock.call_count == 0

    @patch("airbyte_cdk.test.entrypoint_wrapper.AirbyteEntrypoint")
    def test_given_uncaught_exception_when_read_then_output_has_error(self, entrypoint):
        entrypoint.return_value.run.side_effect = ValueError("An error")
        output = read(self._a_source, _A_CONFIG, _A_CATALOG, _A_STATE)
        assert output.errors
|