diff --git a/airbyte-cdk/python/airbyte_cdk/base_python/cdk/abstract_source.py b/airbyte-cdk/python/airbyte_cdk/base_python/cdk/abstract_source.py
index 8f4624ad1be..6f765457f96 100644
--- a/airbyte-cdk/python/airbyte_cdk/base_python/cdk/abstract_source.py
+++ b/airbyte-cdk/python/airbyte_cdk/base_python/cdk/abstract_source.py
@@ -97,8 +97,13 @@ class AbstractSource(Source, ABC):
         # get the streams once in case the connector needs to make any queries to generate them
         stream_instances = {s.name: s for s in self.streams(config)}
         for configured_stream in catalog.streams:
+            stream_instance = stream_instances.get(configured_stream.stream.name)
+            if not stream_instance:
+                raise KeyError(
+                    f"The requested stream {configured_stream.stream.name} was not found in the source. Available streams: {stream_instances.keys()}"
+                )
+
             try:
-                stream_instance = stream_instances[configured_stream.stream.name]
                 yield from self._read_stream(
                     logger=logger, stream_instance=stream_instance, configured_stream=configured_stream, connector_state=connector_state
                 )
diff --git a/airbyte-cdk/python/type_check_and_test.sh b/airbyte-cdk/python/type_check_and_test.sh
index 17f2de1a13c..6e1d405a582 100755
--- a/airbyte-cdk/python/type_check_and_test.sh
+++ b/airbyte-cdk/python/type_check_and_test.sh
@@ -12,4 +12,4 @@ printf "\n"
 # Test with Coverage Report
 echo "Running tests.."
 # The -s flag instructs PyTest to capture stdout logging; simplifying debugging.
-pytest -s --cov=airbyte_cdk unit_tests/
+pytest -s -vv --cov=airbyte_cdk unit_tests/
diff --git a/airbyte-cdk/python/unit_tests/base_python/test_abstract_source.py b/airbyte-cdk/python/unit_tests/base_python/test_abstract_source.py
index e840ade72e6..bad04583830 100644
--- a/airbyte-cdk/python/unit_tests/base_python/test_abstract_source.py
+++ b/airbyte-cdk/python/unit_tests/base_python/test_abstract_source.py
@@ -21,11 +21,24 @@
 # SOFTWARE.
 
 
-from typing import Any, Callable, Iterable, List, Mapping, Optional, Tuple, Union
+from collections import defaultdict
+from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Tuple, Union
 
 import pytest
-from airbyte_cdk import AirbyteCatalog, AirbyteConnectionStatus, AirbyteStream, Status, SyncMode
+from airbyte_cdk import (
+    AirbyteCatalog,
+    AirbyteConnectionStatus,
+    AirbyteMessage,
+    AirbyteRecordMessage,
+    AirbyteStream,
+    ConfiguredAirbyteCatalog,
+    ConfiguredAirbyteStream,
+    Status,
+    SyncMode,
+    Type,
+)
 from airbyte_cdk.base_python import AbstractSource, AirbyteLogger, Stream
+from airbyte_cdk.models import AirbyteStateMessage, DestinationSyncMode
 
 
 class MockSource(AbstractSource):
@@ -34,9 +47,13 @@ class MockSource(AbstractSource):
         self.check_lambda = check_lambda
 
     def check_connection(self, logger: AirbyteLogger, config: Mapping[str, Any]) -> Tuple[bool, Optional[Any]]:
-        return self.check_lambda()
+        if self.check_lambda:
+            return self.check_lambda()
+        return (False, "Missing callable.")
 
     def streams(self, config: Mapping[str, Any]) -> List[Stream]:
+        if not self._streams:
+            raise Exception("Stream is not set")
         return self._streams
 
 
@@ -46,29 +63,41 @@ def logger() -> AirbyteLogger:
 
 
 def test_successful_check():
+    """Tests that if a source returns TRUE for the connection check, the appropriate connectionStatus success message is returned"""
     expected = AirbyteConnectionStatus(status=Status.SUCCEEDED)
     assert expected == MockSource(check_lambda=lambda: (True, None)).check(logger, {})
 
 
 def test_failed_check():
+    """Tests that if a source returns FALSE for the connection check, the appropriate connectionStatus failure message is returned"""
     expected = AirbyteConnectionStatus(status=Status.FAILED, message="womp womp")
     assert expected == MockSource(check_lambda=lambda: (False, "womp womp")).check(logger, {})
 
 
 def test_raising_check():
+    """Tests that if a source raises an unexpected exception during the connection check, the appropriate connectionStatus failure message is returned"""
     expected = AirbyteConnectionStatus(status=Status.FAILED, message=f"{Exception('this should fail')}")
     assert expected == MockSource(check_lambda=lambda: exec('raise Exception("this should fail")')).check(logger, {})
 
 
 class MockStream(Stream):
-    def read_records(
-        self,
-        sync_mode: SyncMode,
-        cursor_field: List[str] = None,
-        stream_slice: Mapping[str, Any] = None,
-        stream_state: Mapping[str, Any] = None,
-    ) -> Iterable[Mapping[str, Any]]:
-        pass
+    def __init__(self, inputs_and_mocked_outputs: List[Tuple[Mapping[str, Any], Iterable[Mapping[str, Any]]]] = None, name: str = None):
+        self._inputs_and_mocked_outputs = inputs_and_mocked_outputs
+        self._name = name
+
+    @property
+    def name(self):
+        return self._name
+
+    def read_records(self, **kwargs) -> Iterable[Mapping[str, Any]]:  # type: ignore
+        # Remove None values
+        kwargs = {k: v for k, v in kwargs.items() if v is not None}
+        if self._inputs_and_mocked_outputs:
+            for _input, output in self._inputs_and_mocked_outputs:
+                if kwargs == _input:
+                    return output
+
+        raise Exception(f"No mocked output supplied for input: {kwargs}. Mocked inputs/outputs: {self._inputs_and_mocked_outputs}")
 
     @property
     def primary_key(self) -> Optional[Union[str, List[str], List[List[str]]]]:
@@ -76,6 +105,7 @@ class MockStream(Stream):
 
 
 def test_discover(mocker):
+    """Tests that the appropriate AirbyteCatalog is returned from the discover method"""
     airbyte_stream1 = AirbyteStream(
         name="1",
         json_schema={},
@@ -97,25 +127,237 @@ def test_discover(mocker):
     assert expected == src.discover(logger, {})
 
 
-def test_read_nonexistent_stream_raises_exception():
-    pass
+def test_read_nonexistent_stream_raises_exception(mocker, logger):
+    """Tests that attempting to sync a stream which the source does not return from the `streams` method raises an exception"""
+    s1 = MockStream(name="s1")
+    s2 = MockStream(name="this_stream_doesnt_exist_in_the_source")
+
+    mocker.patch.object(MockStream, "get_json_schema", return_value={})
+
+    src = MockSource(streams=[s1])
+    catalog = ConfiguredAirbyteCatalog(streams=[_configured_stream(s2, SyncMode.full_refresh)])
+    with pytest.raises(KeyError):
+        list(src.read(logger, {}, catalog))
 
 
-def test_valid_fullrefresh_read_no_slices():
-    pass
+GLOBAL_EMITTED_AT = 1
 
 
-def test_valid_full_refresh_read_with_slices():
-    pass
+def _as_record(stream: str, data: Dict[str, Any]) -> AirbyteMessage:
+    return AirbyteMessage(type=Type.RECORD, record=AirbyteRecordMessage(stream=stream, data=data, emitted_at=GLOBAL_EMITTED_AT))
 
 
-def test_valid_incremental_read_with_record_interval():
-    pass
+def _as_records(stream: str, data: List[Dict[str, Any]]) -> List[AirbyteMessage]:
+    return [_as_record(stream, datum) for datum in data]
 
 
-def test_valid_incremental_read_with_no_interval():
-    pass
+def _configured_stream(stream: Stream, sync_mode: SyncMode):
+    return ConfiguredAirbyteStream(
+        stream=stream.as_airbyte_stream(), sync_mode=sync_mode, destination_sync_mode=DestinationSyncMode.overwrite
+    )
 
 
-def test_valid_incremental_read_with_slices():
-    pass
+def _fix_emitted_at(messages: List[AirbyteMessage]) -> List[AirbyteMessage]:
+    for msg in messages:
+        if msg.type == Type.RECORD and msg.record:
+            msg.record.emitted_at = GLOBAL_EMITTED_AT
+    return messages
+
+
+def test_valid_full_refresh_read_no_slices(logger, mocker):
+    """Tests that running a full refresh sync on streams which don't specify slices produces the expected AirbyteMessages"""
+    stream_output = [{"k1": "v1"}, {"k2": "v2"}]
+    s1 = MockStream([({"sync_mode": SyncMode.full_refresh}, stream_output)], name="s1")
+    s2 = MockStream([({"sync_mode": SyncMode.full_refresh}, stream_output)], name="s2")
+
+    mocker.patch.object(MockStream, "get_json_schema", return_value={})
+
+    src = MockSource(streams=[s1, s2])
+    catalog = ConfiguredAirbyteCatalog(
+        streams=[_configured_stream(s1, SyncMode.full_refresh), _configured_stream(s2, SyncMode.full_refresh)]
+    )
+
+    expected = _as_records("s1", stream_output) + _as_records("s2", stream_output)
+    messages = _fix_emitted_at(list(src.read(logger, {}, catalog)))
+
+    assert expected == messages
+
+
+def test_valid_full_refresh_read_with_slices(mocker, logger):
+    """Tests that running a full refresh sync on streams which use slices produces the expected AirbyteMessages"""
+    slices = [{"1": "1"}, {"2": "2"}]
+    # When attempting to sync a slice, just output that slice as a record
+    s1 = MockStream([({"sync_mode": SyncMode.full_refresh, "stream_slice": s}, [s]) for s in slices], name="s1")
+    s2 = MockStream([({"sync_mode": SyncMode.full_refresh, "stream_slice": s}, [s]) for s in slices], name="s2")
+
+    mocker.patch.object(MockStream, "get_json_schema", return_value={})
+    mocker.patch.object(MockStream, "stream_slices", return_value=slices)
+
+    src = MockSource(streams=[s1, s2])
+    catalog = ConfiguredAirbyteCatalog(
+        streams=[_configured_stream(s1, SyncMode.full_refresh), _configured_stream(s2, SyncMode.full_refresh)]
+    )
+
+    expected = [*_as_records("s1", slices), *_as_records("s2", slices)]
+
+    messages = _fix_emitted_at(list(src.read(logger, {}, catalog)))
+
+    assert expected == messages
+
+
+def _state(state_data: Dict[str, Any]):
+    return AirbyteMessage(type=Type.STATE, state=AirbyteStateMessage(data=state_data))
+
+
+def test_valid_incremental_read_with_checkpoint_interval(mocker, logger):
+    """Tests that an incremental read which specifies a checkpoint interval outputs a STATE message after reading N records within a stream"""
+    stream_output = [{"k1": "v1"}, {"k2": "v2"}]
+    s1 = MockStream([({"sync_mode": SyncMode.incremental, "stream_state": {}}, stream_output)], name="s1")
+    s2 = MockStream([({"sync_mode": SyncMode.incremental, "stream_state": {}}, stream_output)], name="s2")
+    state = {"cursor": "value"}
+    mocker.patch.object(MockStream, "get_updated_state", return_value=state)
+    mocker.patch.object(MockStream, "supports_incremental", return_value=True)
+    mocker.patch.object(MockStream, "get_json_schema", return_value={})
+    # Tell the source to output one state message per record
+    mocker.patch.object(MockStream, "state_checkpoint_interval", new_callable=mocker.PropertyMock, return_value=1)
+
+    src = MockSource(streams=[s1, s2])
+    catalog = ConfiguredAirbyteCatalog(streams=[_configured_stream(s1, SyncMode.incremental), _configured_stream(s2, SyncMode.incremental)])
+
+    expected = [
+        _as_record("s1", stream_output[0]),
+        _state({"s1": state}),
+        _as_record("s1", stream_output[1]),
+        _state({"s1": state}),
+        _state({"s1": state}),
+        _as_record("s2", stream_output[0]),
+        _state({"s1": state, "s2": state}),
+        _as_record("s2", stream_output[1]),
+        _state({"s1": state, "s2": state}),
+        _state({"s1": state, "s2": state}),
+    ]
+    messages = _fix_emitted_at(list(src.read(logger, {}, catalog, state=defaultdict(dict))))
+
+    assert expected == messages
+
+
+def test_valid_incremental_read_with_no_interval(mocker, logger):
+    """Tests that an incremental read which doesn't specify a checkpoint interval outputs a STATE message only after fully reading the stream and does
+    not output any STATE messages during syncing the stream."""
+    stream_output = [{"k1": "v1"}, {"k2": "v2"}]
+    s1 = MockStream([({"sync_mode": SyncMode.incremental, "stream_state": {}}, stream_output)], name="s1")
+    s2 = MockStream([({"sync_mode": SyncMode.incremental, "stream_state": {}}, stream_output)], name="s2")
+    state = {"cursor": "value"}
+    mocker.patch.object(MockStream, "get_updated_state", return_value=state)
+    mocker.patch.object(MockStream, "supports_incremental", return_value=True)
+    mocker.patch.object(MockStream, "get_json_schema", return_value={})
+
+    src = MockSource(streams=[s1, s2])
+    catalog = ConfiguredAirbyteCatalog(streams=[_configured_stream(s1, SyncMode.incremental), _configured_stream(s2, SyncMode.incremental)])
+
+    expected = [
+        *_as_records("s1", stream_output),
+        _state({"s1": state}),
+        *_as_records("s2", stream_output),
+        _state({"s1": state, "s2": state}),
+    ]
+
+    messages = _fix_emitted_at(list(src.read(logger, {}, catalog, state=defaultdict(dict))))
+
+    assert expected == messages
+
+
+def test_valid_incremental_read_with_slices(mocker, logger):
+    """Tests that an incremental read which uses slices outputs each record in the slice followed by a STATE message, for each slice"""
+    slices = [{"1": "1"}, {"2": "2"}]
+    stream_output = [{"k1": "v1"}, {"k2": "v2"}, {"k3": "v3"}]
+    s1 = MockStream(
+        [({"sync_mode": SyncMode.incremental, "stream_slice": s, "stream_state": mocker.ANY}, stream_output) for s in slices], name="s1"
+    )
+    s2 = MockStream(
+        [({"sync_mode": SyncMode.incremental, "stream_slice": s, "stream_state": mocker.ANY}, stream_output) for s in slices], name="s2"
+    )
+    state = {"cursor": "value"}
+    mocker.patch.object(MockStream, "get_updated_state", return_value=state)
+    mocker.patch.object(MockStream, "supports_incremental", return_value=True)
+    mocker.patch.object(MockStream, "get_json_schema", return_value={})
+    mocker.patch.object(MockStream, "stream_slices", return_value=slices)
+
+    src = MockSource(streams=[s1, s2])
+    catalog = ConfiguredAirbyteCatalog(streams=[_configured_stream(s1, SyncMode.incremental), _configured_stream(s2, SyncMode.incremental)])
+
+    expected = [
+        # stream 1 slice 1
+        *_as_records("s1", stream_output),
+        _state({"s1": state}),
+        # stream 1 slice 2
+        *_as_records("s1", stream_output),
+        _state({"s1": state}),
+        # stream 2 slice 1
+        *_as_records("s2", stream_output),
+        _state({"s1": state, "s2": state}),
+        # stream 2 slice 2
+        *_as_records("s2", stream_output),
+        _state({"s1": state, "s2": state}),
+    ]
+
+    messages = _fix_emitted_at(list(src.read(logger, {}, catalog, state=defaultdict(dict))))
+
+    assert expected == messages
+
+
+def test_valid_incremental_read_with_slices_and_interval(mocker, logger):
+    """
+    Tests that an incremental read which uses slices and a checkpoint interval:
+        1. outputs all records
+        2. outputs a state message every N records (N=checkpoint_interval)
+        3. outputs a state message after reading the entire slice
+    """
+    slices = [{"1": "1"}, {"2": "2"}]
+    stream_output = [{"k1": "v1"}, {"k2": "v2"}, {"k3": "v3"}]
+    s1 = MockStream(
+        [({"sync_mode": SyncMode.incremental, "stream_slice": s, "stream_state": mocker.ANY}, stream_output) for s in slices], name="s1"
+    )
+    s2 = MockStream(
+        [({"sync_mode": SyncMode.incremental, "stream_slice": s, "stream_state": mocker.ANY}, stream_output) for s in slices], name="s2"
+    )
+    state = {"cursor": "value"}
+    mocker.patch.object(MockStream, "get_updated_state", return_value=state)
+    mocker.patch.object(MockStream, "supports_incremental", return_value=True)
+    mocker.patch.object(MockStream, "get_json_schema", return_value={})
+    mocker.patch.object(MockStream, "stream_slices", return_value=slices)
+    mocker.patch.object(MockStream, "state_checkpoint_interval", new_callable=mocker.PropertyMock, return_value=2)
+
+    src = MockSource(streams=[s1, s2])
+    catalog = ConfiguredAirbyteCatalog(streams=[_configured_stream(s1, SyncMode.incremental), _configured_stream(s2, SyncMode.incremental)])
+
+    expected = [
+        # stream 1 slice 1
+        _as_record("s1", stream_output[0]),
+        _as_record("s1", stream_output[1]),
+        _state({"s1": state}),
+        _as_record("s1", stream_output[2]),
+        _state({"s1": state}),
+        # stream 1 slice 2
+        _as_record("s1", stream_output[0]),
+        _as_record("s1", stream_output[1]),
+        _state({"s1": state}),
+        _as_record("s1", stream_output[2]),
+        _state({"s1": state}),
+        # stream 2 slice 1
+        _as_record("s2", stream_output[0]),
+        _as_record("s2", stream_output[1]),
+        _state({"s1": state, "s2": state}),
+        _as_record("s2", stream_output[2]),
+        _state({"s1": state, "s2": state}),
+        # stream 2 slice 2
+        _as_record("s2", stream_output[0]),
+        _as_record("s2", stream_output[1]),
+        _state({"s1": state, "s2": state}),
+        _as_record("s2", stream_output[2]),
+        _state({"s1": state, "s2": state}),
+    ]
+
+    messages = _fix_emitted_at(list(src.read(logger, {}, catalog, state=defaultdict(dict))))
+
+    assert expected == messages