1
0
mirror of synced 2025-12-26 14:02:10 -05:00
Files
airbyte/airbyte-cdk/python/unit_tests/test/test_entrypoint_wrapper.py

349 lines
16 KiB
Python

# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
import json
import logging
import os
from typing import Any, Iterator, List, Mapping, Optional
from unittest import TestCase
from unittest.mock import Mock, patch
from airbyte_cdk.models import (
AirbyteAnalyticsTraceMessage,
AirbyteCatalog,
AirbyteErrorTraceMessage,
AirbyteLogMessage,
AirbyteMessage,
AirbyteMessageSerializer,
AirbyteRecordMessage,
AirbyteStateBlob,
AirbyteStateMessage,
AirbyteStreamState,
AirbyteStreamStateSerializer,
AirbyteStreamStatus,
AirbyteStreamStatusTraceMessage,
AirbyteTraceMessage,
ConfiguredAirbyteCatalogSerializer,
Level,
StreamDescriptor,
TraceType,
Type,
)
from airbyte_cdk.sources.abstract_source import AbstractSource
from airbyte_cdk.test.entrypoint_wrapper import EntrypointOutput, discover, read
from airbyte_cdk.test.state_builder import StateBuilder
from orjson import orjson
def _a_state_message(stream_name: str, stream_state: Mapping[str, Any]) -> AirbyteMessage:
    """Build a per-stream STATE message for ``stream_name`` whose blob is ``stream_state``.

    Every key/value in ``stream_state`` is forwarded as a keyword argument to
    ``AirbyteStateBlob``.
    """
    descriptor = StreamDescriptor(name=stream_name)
    blob = AirbyteStateBlob(**stream_state)
    per_stream_state = AirbyteStreamState(stream_descriptor=descriptor, stream_state=blob)
    return AirbyteMessage(type=Type.STATE, state=AirbyteStateMessage(stream=per_stream_state))
def _a_status_message(stream_name: str, status: AirbyteStreamStatus) -> AirbyteMessage:
    """Build a STREAM_STATUS trace message reporting ``status`` for ``stream_name``.

    ``emitted_at`` is pinned to 0 so the message is deterministic.
    """
    status_payload = AirbyteStreamStatusTraceMessage(
        stream_descriptor=StreamDescriptor(name=stream_name),
        status=status,
    )
    trace = AirbyteTraceMessage(type=TraceType.STREAM_STATUS, emitted_at=0, stream_status=status_payload)
    return AirbyteMessage(type=Type.TRACE, trace=trace)
# --- Plain values shared by the tests ----------------------------------------
_A_STREAM_NAME = "a stream name"
_A_CONFIG = {"config_key": "config_value"}
_A_LOG_MESSAGE = "a log message"

# --- Canned protocol messages the mocked entrypoint can emit ------------------
_A_CATALOG_MESSAGE = AirbyteMessage(type=Type.CATALOG, catalog=AirbyteCatalog(streams=[]))
_A_RECORD = AirbyteMessage(
    type=Type.RECORD,
    record=AirbyteRecordMessage(
        stream="stream",
        data={"record key": "record value"},
        emitted_at=0,
    ),
)
_A_STATE_MESSAGE = _a_state_message("stream_name", {"state key": "state value for _A_STATE_MESSAGE"})
_A_LOG = AirbyteMessage(
    type=Type.LOG,
    log=AirbyteLogMessage(level=Level.INFO, message="This is an Airbyte log message"),
)
_AN_ERROR_MESSAGE = AirbyteMessage(
    type=Type.TRACE,
    trace=AirbyteTraceMessage(
        type=TraceType.ERROR,
        emitted_at=0,
        error=AirbyteErrorTraceMessage(message="AirbyteErrorTraceMessage message"),
    ),
)
_AN_ANALYTIC_MESSAGE = AirbyteMessage(
    type=Type.TRACE,
    trace=AirbyteTraceMessage(
        type=TraceType.ANALYTICS,
        emitted_at=0,
        analytics=AirbyteAnalyticsTraceMessage(type="an analytic type", value="an analytic value"),
    ),
)

# --- Inputs handed to discover() / read() -------------------------------------
# A single full-refresh/append stream; the entrypoint is mocked, so the catalog
# only needs to round-trip through the temporary catalog file.
_A_CATALOG = ConfiguredAirbyteCatalogSerializer.load(
    {
        "streams": [
            {
                "stream": {
                    "name": "a_stream_name",
                    "json_schema": {},
                    "supported_sync_modes": ["full_refresh"],
                },
                "sync_mode": "full_refresh",
                "destination_sync_mode": "append",
            }
        ]
    }
)
_A_STATE = (
    StateBuilder()
    .with_stream_state(_A_STREAM_NAME, {"state_key": "state_value"})
    .build()
)
def _to_entrypoint_output(messages: List[AirbyteMessage]) -> Iterator[str]:
    """Lazily serialize ``messages`` into the JSON lines an entrypoint ``run`` would yield."""
    for message in messages:
        serialized = AirbyteMessageSerializer.dump(message)
        yield orjson.dumps(serialized).decode()
def _a_mocked_source() -> AbstractSource:
    """Return a ``Mock`` constrained to the ``AbstractSource`` interface.

    ``message_repository`` is explicitly set to ``None`` on the mock.
    """
    return Mock(spec=AbstractSource, message_repository=None)
def _validate_tmp_json_file(expected, file_path) -> None:
with open(file_path) as file:
assert json.load(file) == expected
def _validate_tmp_catalog(expected, file_path) -> None:
    """Assert that the configured catalog serialized at ``file_path`` loads to ``expected``.

    The file is read inside a ``with`` block so the handle is closed as soon as the
    content is read; the previous bare ``open(file_path).read()`` left the handle
    open until garbage collection (surfacing as a ``ResourceWarning``).

    Raises:
        AssertionError: if the deserialized catalog differs from ``expected``.
    """
    with open(file_path) as file:
        raw_catalog = orjson.loads(file.read())
    assert ConfiguredAirbyteCatalogSerializer.load(raw_catalog) == expected
def _create_tmp_file_validation(entrypoint, expected_config, expected_catalog: Optional[Any] = None, expected_state: Optional[Any] = None):
    """Return a side effect for ``entrypoint.return_value.run`` that checks the temp files.

    The paths are read from the CLI arguments given to ``entrypoint.parse_args``:
    config at index 2, catalog at index 4 and state at index 6. The check runs from
    inside ``run`` because the files are gone once ``discover``/``read`` returns
    (see the ``*_then_ensure_files_are_temporary`` tests). Catalog and state are
    only checked when the corresponding expectation is truthy.
    """
    def _validate_tmp_files(_parsed_args):
        cli_args = entrypoint.parse_args.call_args.args[0]
        _validate_tmp_json_file(expected_config, cli_args[2])
        if expected_catalog:
            _validate_tmp_catalog(expected_catalog, cli_args[4])
        if expected_state:
            _validate_tmp_json_file(expected_state, cli_args[6])
        # Mirror the mock's normal return value so the wrapper keeps working.
        return entrypoint.return_value.run.return_value

    return _validate_tmp_files
class EntrypointWrapperDiscoverTest(TestCase):
    """Tests for ``discover`` with ``AirbyteEntrypoint`` patched out.

    Each test drives the mocked entrypoint's ``run`` (through ``return_value`` or
    ``side_effect``) and asserts on the CLI arguments given to ``parse_args`` or on
    the parsed ``EntrypointOutput``.
    """

    def setUp(self) -> None:
        self._a_source = _a_mocked_source()

    @staticmethod
    def test_init_validation_error():
        # A line that is not a valid AirbyteMessage is kept as an INFO log whose
        # message is the raw line.
        invalid_message = '{"type": "INVALID_TYPE"}'
        entrypoint_output = EntrypointOutput([invalid_message])
        messages = entrypoint_output._messages
        assert len(messages) == 1
        assert messages[0].type == Type.LOG
        assert messages[0].log.level == Level.INFO
        assert messages[0].log.message == invalid_message

    @patch("airbyte_cdk.test.entrypoint_wrapper.AirbyteEntrypoint")
    def test_when_discover_then_ensure_parameters(self, entrypoint):
        # The temp config file is validated from within run(), while it still exists.
        entrypoint.return_value.run.side_effect = _create_tmp_file_validation(entrypoint, _A_CONFIG)

        discover(self._a_source, _A_CONFIG)

        entrypoint.assert_called_once_with(self._a_source)
        entrypoint.return_value.run.assert_called_once_with(entrypoint.parse_args.return_value)
        assert entrypoint.parse_args.call_count == 1
        assert entrypoint.parse_args.call_args.args[0][0] == "discover"
        assert entrypoint.parse_args.call_args.args[0][1] == "--config"

    @patch("airbyte_cdk.test.entrypoint_wrapper.AirbyteEntrypoint")
    def test_when_discover_then_ensure_files_are_temporary(self, entrypoint):
        discover(self._a_source, _A_CONFIG)

        assert not os.path.exists(entrypoint.parse_args.call_args.args[0][2])

    @patch("airbyte_cdk.test.entrypoint_wrapper.AirbyteEntrypoint")
    def test_given_logging_during_discover_when_discover_then_output_has_logs(self, entrypoint):
        # The parameter receives the parsed args passed to run(); it is unused and
        # deliberately not named ``self`` to avoid shadowing the test instance.
        def _do_some_logging(_parsed_args):
            logging.getLogger("any logger").info(_A_LOG_MESSAGE)
            return entrypoint.return_value.run.return_value

        entrypoint.return_value.run.side_effect = _do_some_logging

        output = discover(self._a_source, _A_CONFIG)

        assert len(output.logs) == 1
        assert output.logs[0].log.message == _A_LOG_MESSAGE

    @patch("airbyte_cdk.test.entrypoint_wrapper.AirbyteEntrypoint")
    def test_given_catalog_when_discover_then_output_has_catalog(self, entrypoint):
        entrypoint.return_value.run.return_value = _to_entrypoint_output([_A_CATALOG_MESSAGE])

        output = discover(self._a_source, _A_CONFIG)

        assert AirbyteMessageSerializer.dump(output.catalog) == AirbyteMessageSerializer.dump(_A_CATALOG_MESSAGE)

    @patch("airbyte_cdk.test.entrypoint_wrapper.AirbyteEntrypoint")
    def test_given_log_when_discover_then_output_has_log(self, entrypoint):
        entrypoint.return_value.run.return_value = _to_entrypoint_output([_A_LOG])

        output = discover(self._a_source, _A_CONFIG)

        assert AirbyteMessageSerializer.dump(output.logs[0]) == AirbyteMessageSerializer.dump(_A_LOG)

    @patch("airbyte_cdk.test.entrypoint_wrapper.AirbyteEntrypoint")
    def test_given_trace_message_when_discover_then_output_has_trace_messages(self, entrypoint):
        entrypoint.return_value.run.return_value = _to_entrypoint_output([_AN_ANALYTIC_MESSAGE])

        output = discover(self._a_source, _A_CONFIG)

        assert AirbyteMessageSerializer.dump(output.analytics_messages[0]) == AirbyteMessageSerializer.dump(_AN_ANALYTIC_MESSAGE)

    @patch("airbyte_cdk.test.entrypoint_wrapper.print", create=True)
    @patch("airbyte_cdk.test.entrypoint_wrapper.AirbyteEntrypoint")
    def test_given_unexpected_exception_when_discover_then_print(self, entrypoint, print_mock):
        entrypoint.return_value.run.side_effect = ValueError("This error should be printed")

        discover(self._a_source, _A_CONFIG)

        assert print_mock.call_count > 0

    @patch("airbyte_cdk.test.entrypoint_wrapper.print", create=True)
    @patch("airbyte_cdk.test.entrypoint_wrapper.AirbyteEntrypoint")
    def test_given_expected_exception_when_discover_then_do_not_print(self, entrypoint, print_mock):
        entrypoint.return_value.run.side_effect = ValueError("This error should not be printed")

        discover(self._a_source, _A_CONFIG, expecting_exception=True)

        assert print_mock.call_count == 0

    @patch("airbyte_cdk.test.entrypoint_wrapper.AirbyteEntrypoint")
    def test_given_uncaught_exception_when_discover_then_output_has_error(self, entrypoint):
        entrypoint.return_value.run.side_effect = ValueError("An error")

        output = discover(self._a_source, _A_CONFIG)

        assert output.errors
class EntrypointWrapperReadTest(TestCase):
    """Tests for ``read`` with ``AirbyteEntrypoint`` patched out.

    Each test drives the mocked entrypoint's ``run`` (through ``return_value`` or
    ``side_effect``) and asserts on the CLI arguments given to ``parse_args`` or on
    the parsed ``EntrypointOutput``.
    """

    def setUp(self) -> None:
        self._a_source = _a_mocked_source()

    @patch("airbyte_cdk.test.entrypoint_wrapper.AirbyteEntrypoint")
    def test_when_read_then_ensure_parameters(self, entrypoint):
        # The temp config/catalog/state files are validated from within run(), while they still exist.
        entrypoint.return_value.run.side_effect = _create_tmp_file_validation(entrypoint, _A_CONFIG, _A_CATALOG, _A_STATE)

        read(self._a_source, _A_CONFIG, _A_CATALOG, _A_STATE)

        entrypoint.assert_called_once_with(self._a_source)
        entrypoint.return_value.run.assert_called_once_with(entrypoint.parse_args.return_value)
        assert entrypoint.parse_args.call_count == 1
        assert entrypoint.parse_args.call_args.args[0][0] == "read"
        assert entrypoint.parse_args.call_args.args[0][1] == "--config"
        assert entrypoint.parse_args.call_args.args[0][3] == "--catalog"
        assert entrypoint.parse_args.call_args.args[0][5] == "--state"

    @patch("airbyte_cdk.test.entrypoint_wrapper.AirbyteEntrypoint")
    def test_when_read_then_ensure_files_are_temporary(self, entrypoint):
        read(self._a_source, _A_CONFIG, _A_CATALOG, _A_STATE)

        assert not os.path.exists(entrypoint.parse_args.call_args.args[0][2])
        assert not os.path.exists(entrypoint.parse_args.call_args.args[0][4])
        assert not os.path.exists(entrypoint.parse_args.call_args.args[0][6])

    @patch("airbyte_cdk.test.entrypoint_wrapper.AirbyteEntrypoint")
    def test_given_logging_during_run_when_read_then_output_has_logs(self, entrypoint):
        # The parameter receives the parsed args passed to run(); it is unused and
        # deliberately not named ``self`` to avoid shadowing the test instance.
        def _do_some_logging(_parsed_args):
            logging.getLogger("any logger").info(_A_LOG_MESSAGE)
            return entrypoint.return_value.run.return_value

        entrypoint.return_value.run.side_effect = _do_some_logging

        output = read(self._a_source, _A_CONFIG, _A_CATALOG, _A_STATE)

        assert len(output.logs) == 1
        assert output.logs[0].log.message == _A_LOG_MESSAGE

    @patch("airbyte_cdk.test.entrypoint_wrapper.AirbyteEntrypoint")
    def test_given_record_when_read_then_output_has_record(self, entrypoint):
        entrypoint.return_value.run.return_value = _to_entrypoint_output([_A_RECORD])

        output = read(self._a_source, _A_CONFIG, _A_CATALOG, _A_STATE)

        assert AirbyteMessageSerializer.dump(output.records[0]) == AirbyteMessageSerializer.dump(_A_RECORD)

    @patch("airbyte_cdk.test.entrypoint_wrapper.AirbyteEntrypoint")
    def test_given_state_message_when_read_then_output_has_state_message(self, entrypoint):
        entrypoint.return_value.run.return_value = _to_entrypoint_output([_A_STATE_MESSAGE])

        output = read(self._a_source, _A_CONFIG, _A_CATALOG, _A_STATE)

        assert AirbyteMessageSerializer.dump(output.state_messages[0]) == AirbyteMessageSerializer.dump(_A_STATE_MESSAGE)

    @patch("airbyte_cdk.test.entrypoint_wrapper.AirbyteEntrypoint")
    def test_given_state_message_and_records_when_read_then_output_has_records_and_state_message(self, entrypoint):
        entrypoint.return_value.run.return_value = _to_entrypoint_output([_A_RECORD, _A_STATE_MESSAGE])

        output = read(self._a_source, _A_CONFIG, _A_CATALOG, _A_STATE)

        assert [AirbyteMessageSerializer.dump(message) for message in output.records_and_state_messages] == [
            AirbyteMessageSerializer.dump(message) for message in (_A_RECORD, _A_STATE_MESSAGE)
        ]

    @patch("airbyte_cdk.test.entrypoint_wrapper.AirbyteEntrypoint")
    def test_given_many_state_messages_when_read_then_most_recent_state_is_last_emitted(self, entrypoint):
        state_value = {"state_key": "last state value"}
        last_emitted_state = AirbyteStreamState(
            stream_descriptor=StreamDescriptor(name="stream_name"), stream_state=AirbyteStateBlob(**state_value)
        )
        entrypoint.return_value.run.return_value = _to_entrypoint_output([_A_STATE_MESSAGE, _a_state_message("stream_name", state_value)])

        output = read(self._a_source, _A_CONFIG, _A_CATALOG, _A_STATE)

        assert AirbyteStreamStateSerializer.dump(output.most_recent_state) == AirbyteStreamStateSerializer.dump(last_emitted_state)

    @patch("airbyte_cdk.test.entrypoint_wrapper.AirbyteEntrypoint")
    def test_given_log_when_read_then_output_has_log(self, entrypoint):
        entrypoint.return_value.run.return_value = _to_entrypoint_output([_A_LOG])

        output = read(self._a_source, _A_CONFIG, _A_CATALOG, _A_STATE)

        assert AirbyteMessageSerializer.dump(output.logs[0]) == AirbyteMessageSerializer.dump(_A_LOG)

    @patch("airbyte_cdk.test.entrypoint_wrapper.AirbyteEntrypoint")
    def test_given_trace_message_when_read_then_output_has_trace_messages(self, entrypoint):
        entrypoint.return_value.run.return_value = _to_entrypoint_output([_AN_ANALYTIC_MESSAGE])

        output = read(self._a_source, _A_CONFIG, _A_CATALOG, _A_STATE)

        assert AirbyteMessageSerializer.dump(output.analytics_messages[0]) == AirbyteMessageSerializer.dump(_AN_ANALYTIC_MESSAGE)

    @patch("airbyte_cdk.test.entrypoint_wrapper.AirbyteEntrypoint")
    def test_given_stream_statuses_when_read_then_return_statuses(self, entrypoint):
        status_messages = [
            _a_status_message(_A_STREAM_NAME, AirbyteStreamStatus.STARTED),
            _a_status_message(_A_STREAM_NAME, AirbyteStreamStatus.COMPLETE),
        ]
        entrypoint.return_value.run.return_value = _to_entrypoint_output(status_messages)

        output = read(self._a_source, _A_CONFIG, _A_CATALOG, _A_STATE)

        assert output.get_stream_statuses(_A_STREAM_NAME) == [AirbyteStreamStatus.STARTED, AirbyteStreamStatus.COMPLETE]

    @patch("airbyte_cdk.test.entrypoint_wrapper.AirbyteEntrypoint")
    def test_given_stream_statuses_for_many_streams_when_read_then_filter_other_streams(self, entrypoint):
        status_messages = [
            _a_status_message(_A_STREAM_NAME, AirbyteStreamStatus.STARTED),
            _a_status_message("another stream name", AirbyteStreamStatus.INCOMPLETE),
            _a_status_message(_A_STREAM_NAME, AirbyteStreamStatus.COMPLETE),
        ]
        entrypoint.return_value.run.return_value = _to_entrypoint_output(status_messages)

        output = read(self._a_source, _A_CONFIG, _A_CATALOG, _A_STATE)

        # Pin the exact statuses and their order: a length check alone would also
        # pass if the other stream's INCOMPLETE status leaked through.
        assert output.get_stream_statuses(_A_STREAM_NAME) == [AirbyteStreamStatus.STARTED, AirbyteStreamStatus.COMPLETE]

    @patch("airbyte_cdk.test.entrypoint_wrapper.print", create=True)
    @patch("airbyte_cdk.test.entrypoint_wrapper.AirbyteEntrypoint")
    def test_given_unexpected_exception_when_read_then_print(self, entrypoint, print_mock):
        entrypoint.return_value.run.side_effect = ValueError("This error should be printed")

        read(self._a_source, _A_CONFIG, _A_CATALOG, _A_STATE)

        assert print_mock.call_count > 0

    @patch("airbyte_cdk.test.entrypoint_wrapper.print", create=True)
    @patch("airbyte_cdk.test.entrypoint_wrapper.AirbyteEntrypoint")
    def test_given_expected_exception_when_read_then_do_not_print(self, entrypoint, print_mock):
        entrypoint.return_value.run.side_effect = ValueError("This error should not be printed")

        read(self._a_source, _A_CONFIG, _A_CATALOG, _A_STATE, expecting_exception=True)

        assert print_mock.call_count == 0

    @patch("airbyte_cdk.test.entrypoint_wrapper.AirbyteEntrypoint")
    def test_given_uncaught_exception_when_read_then_output_has_error(self, entrypoint):
        entrypoint.return_value.run.side_effect = ValueError("An error")

        output = read(self._a_source, _A_CONFIG, _A_CATALOG, _A_STATE)

        assert output.errors