1
0
mirror of synced 2026-01-01 09:02:59 -05:00
Files
airbyte/airbyte-cdk/python/unit_tests/sources/test_abstract_source.py
Alexandre Girard 3894134d11 Bump year in license short to 2022 (#13191)
* Bump to 2022

* format
2022-05-25 17:56:49 -07:00

542 lines
19 KiB
Python

#
# Copyright (c) 2022 Airbyte, Inc., all rights reserved.
#
import logging
from collections import defaultdict
from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Tuple, Union
from unittest.mock import call
import pytest
from airbyte_cdk.models import (
AirbyteCatalog,
AirbyteConnectionStatus,
AirbyteMessage,
AirbyteRecordMessage,
AirbyteStateMessage,
AirbyteStream,
ConfiguredAirbyteCatalog,
ConfiguredAirbyteStream,
DestinationSyncMode,
Status,
SyncMode,
Type,
)
from airbyte_cdk.sources import AbstractSource
from airbyte_cdk.sources.streams import Stream
from airbyte_cdk.utils.traced_exception import AirbyteTracedException
logger = logging.getLogger("airbyte")
class MockSource(AbstractSource):
def __init__(
self,
check_lambda: Callable[[], Tuple[bool, Optional[Any]]] = None,
streams: List[Stream] = None,
):
self._streams = streams
self.check_lambda = check_lambda
def check_connection(self, logger: logging.Logger, config: Mapping[str, Any]) -> Tuple[bool, Optional[Any]]:
if self.check_lambda:
return self.check_lambda()
return (False, "Missing callable.")
def streams(self, config: Mapping[str, Any]) -> List[Stream]:
if not self._streams:
raise Exception("Stream is not set")
return self._streams
def test_successful_check():
"""Tests that if a source returns TRUE for the connection check the appropriate connectionStatus success message is returned"""
expected = AirbyteConnectionStatus(status=Status.SUCCEEDED)
assert expected == MockSource(check_lambda=lambda: (True, None)).check(logger, {})
def test_failed_check():
"""Tests that if a source returns FALSE for the connection check the appropriate connectionStatus failure message is returned"""
expected = AirbyteConnectionStatus(status=Status.FAILED, message="'womp womp'")
assert expected == MockSource(check_lambda=lambda: (False, "womp womp")).check(logger, {})
def test_raising_check():
"""Tests that if a source raises an unexpected exception the appropriate connectionStatus failure message is returned."""
expected = AirbyteConnectionStatus(status=Status.FAILED, message="Exception('this should fail')")
assert expected == MockSource(check_lambda=lambda: exec('raise Exception("this should fail")')).check(logger, {})
class MockStream(Stream):
def __init__(
self,
inputs_and_mocked_outputs: List[Tuple[Mapping[str, Any], Iterable[Mapping[str, Any]]]] = None,
name: str = None,
):
self._inputs_and_mocked_outputs = inputs_and_mocked_outputs
self._name = name
@property
def name(self):
return self._name
def read_records(self, **kwargs) -> Iterable[Mapping[str, Any]]: # type: ignore
# Remove None values
kwargs = {k: v for k, v in kwargs.items() if v is not None}
if self._inputs_and_mocked_outputs:
for _input, output in self._inputs_and_mocked_outputs:
if kwargs == _input:
return output
raise Exception(f"No mocked output supplied for input: {kwargs}. Mocked inputs/outputs: {self._inputs_and_mocked_outputs}")
@property
def primary_key(self) -> Optional[Union[str, List[str], List[List[str]]]]:
return "pk"
class MockStreamWithState(MockStream):
cursor_field = "cursor"
@property
def state(self):
return {}
@state.setter
def state(self, value):
pass
def test_discover(mocker):
"""Tests that the appropriate AirbyteCatalog is returned from the discover method"""
airbyte_stream1 = AirbyteStream(
name="1",
json_schema={},
supported_sync_modes=[SyncMode.full_refresh, SyncMode.incremental],
default_cursor_field=["cursor"],
source_defined_cursor=True,
source_defined_primary_key=[["pk"]],
)
airbyte_stream2 = AirbyteStream(name="2", json_schema={}, supported_sync_modes=[SyncMode.full_refresh])
stream1 = MockStream()
stream2 = MockStream()
mocker.patch.object(stream1, "as_airbyte_stream", return_value=airbyte_stream1)
mocker.patch.object(stream2, "as_airbyte_stream", return_value=airbyte_stream2)
expected = AirbyteCatalog(streams=[airbyte_stream1, airbyte_stream2])
src = MockSource(check_lambda=lambda: (True, None), streams=[stream1, stream2])
assert expected == src.discover(logger, {})
def test_read_nonexistent_stream_raises_exception(mocker):
"""Tests that attempting to sync a stream which the source does not return from the `streams` method raises an exception"""
s1 = MockStream(name="s1")
s2 = MockStream(name="this_stream_doesnt_exist_in_the_source")
mocker.patch.object(MockStream, "get_json_schema", return_value={})
src = MockSource(streams=[s1])
catalog = ConfiguredAirbyteCatalog(streams=[_configured_stream(s2, SyncMode.full_refresh)])
with pytest.raises(KeyError):
list(src.read(logger, {}, catalog))
def test_read_stream_with_error_gets_display_message(mocker):
stream = MockStream(name="my_stream")
mocker.patch.object(MockStream, "get_json_schema", return_value={})
mocker.patch.object(MockStream, "read_records", side_effect=RuntimeError("oh no!"))
source = MockSource(streams=[stream])
catalog = ConfiguredAirbyteCatalog(streams=[_configured_stream(stream, SyncMode.full_refresh)])
# without get_error_display_message
with pytest.raises(RuntimeError, match="oh no!"):
list(source.read(logger, {}, catalog))
mocker.patch.object(MockStream, "get_error_display_message", return_value="my message")
with pytest.raises(AirbyteTracedException, match="oh no!") as exc:
list(source.read(logger, {}, catalog))
assert exc.value.message == "my message"
GLOBAL_EMITTED_AT = 1
def _as_record(stream: str, data: Dict[str, Any]) -> AirbyteMessage:
return AirbyteMessage(
type=Type.RECORD,
record=AirbyteRecordMessage(stream=stream, data=data, emitted_at=GLOBAL_EMITTED_AT),
)
def _as_records(stream: str, data: List[Dict[str, Any]]) -> List[AirbyteMessage]:
return [_as_record(stream, datum) for datum in data]
def _configured_stream(stream: Stream, sync_mode: SyncMode):
return ConfiguredAirbyteStream(
stream=stream.as_airbyte_stream(),
sync_mode=sync_mode,
destination_sync_mode=DestinationSyncMode.overwrite,
)
def _fix_emitted_at(messages: List[AirbyteMessage]) -> List[AirbyteMessage]:
for msg in messages:
if msg.type == Type.RECORD and msg.record:
msg.record.emitted_at = GLOBAL_EMITTED_AT
return messages
def test_valid_full_refresh_read_no_slices(mocker):
"""Tests that running a full refresh sync on streams which don't specify slices produces the expected AirbyteMessages"""
stream_output = [{"k1": "v1"}, {"k2": "v2"}]
s1 = MockStream([({"sync_mode": SyncMode.full_refresh}, stream_output)], name="s1")
s2 = MockStream([({"sync_mode": SyncMode.full_refresh}, stream_output)], name="s2")
mocker.patch.object(MockStream, "get_json_schema", return_value={})
src = MockSource(streams=[s1, s2])
catalog = ConfiguredAirbyteCatalog(
streams=[
_configured_stream(s1, SyncMode.full_refresh),
_configured_stream(s2, SyncMode.full_refresh),
]
)
expected = _as_records("s1", stream_output) + _as_records("s2", stream_output)
messages = _fix_emitted_at(list(src.read(logger, {}, catalog)))
assert expected == messages
def test_valid_full_refresh_read_with_slices(mocker):
"""Tests that running a full refresh sync on streams which use slices produces the expected AirbyteMessages"""
slices = [{"1": "1"}, {"2": "2"}]
# When attempting to sync a slice, just output that slice as a record
s1 = MockStream(
[({"sync_mode": SyncMode.full_refresh, "stream_slice": s}, [s]) for s in slices],
name="s1",
)
s2 = MockStream(
[({"sync_mode": SyncMode.full_refresh, "stream_slice": s}, [s]) for s in slices],
name="s2",
)
mocker.patch.object(MockStream, "get_json_schema", return_value={})
mocker.patch.object(MockStream, "stream_slices", return_value=slices)
src = MockSource(streams=[s1, s2])
catalog = ConfiguredAirbyteCatalog(
streams=[
_configured_stream(s1, SyncMode.full_refresh),
_configured_stream(s2, SyncMode.full_refresh),
]
)
expected = [*_as_records("s1", slices), *_as_records("s2", slices)]
messages = _fix_emitted_at(list(src.read(logger, {}, catalog)))
assert expected == messages
def _state(state_data: Dict[str, Any]):
return AirbyteMessage(type=Type.STATE, state=AirbyteStateMessage(data=state_data))
class TestIncrementalRead:
def test_with_state_attribute(self, mocker):
"""Test correct state passing for the streams that have a state attribute"""
stream_output = [{"k1": "v1"}, {"k2": "v2"}]
old_state = {"cursor": "old_value"}
new_state = {"cursor": "new_value"}
s1 = MockStreamWithState(
[
(
{"sync_mode": SyncMode.incremental, "stream_state": old_state},
stream_output,
)
],
name="s1",
)
s2 = MockStreamWithState(
[({"sync_mode": SyncMode.incremental, "stream_state": {}}, stream_output)],
name="s2",
)
mocker.patch.object(MockStreamWithState, "get_updated_state", return_value={})
state_property = mocker.patch.object(
MockStreamWithState,
"state",
new_callable=mocker.PropertyMock,
return_value=new_state,
)
mocker.patch.object(MockStreamWithState, "get_json_schema", return_value={})
src = MockSource(streams=[s1, s2])
catalog = ConfiguredAirbyteCatalog(
streams=[
_configured_stream(s1, SyncMode.incremental),
_configured_stream(s2, SyncMode.incremental),
]
)
expected = [
_as_record("s1", stream_output[0]),
_as_record("s1", stream_output[1]),
_state({"s1": new_state}),
_as_record("s2", stream_output[0]),
_as_record("s2", stream_output[1]),
_state({"s1": new_state, "s2": new_state}),
]
messages = _fix_emitted_at(list(src.read(logger, {}, catalog, state={"s1": old_state})))
assert expected == messages
assert state_property.mock_calls == [
call(old_state), # set state for s1
call(), # get state in the end of slice for s1
call(), # get state in the end of slice for s2
]
def test_with_checkpoint_interval(self, mocker):
"""Tests that an incremental read which doesn't specify a checkpoint interval outputs a STATE message
after reading N records within a stream.
"""
stream_output = [{"k1": "v1"}, {"k2": "v2"}]
s1 = MockStream(
[({"sync_mode": SyncMode.incremental, "stream_state": {}}, stream_output)],
name="s1",
)
s2 = MockStream(
[({"sync_mode": SyncMode.incremental, "stream_state": {}}, stream_output)],
name="s2",
)
state = {"cursor": "value"}
mocker.patch.object(MockStream, "get_updated_state", return_value=state)
mocker.patch.object(MockStream, "supports_incremental", return_value=True)
mocker.patch.object(MockStream, "get_json_schema", return_value={})
# Tell the source to output one state message per record
mocker.patch.object(
MockStream,
"state_checkpoint_interval",
new_callable=mocker.PropertyMock,
return_value=1,
)
src = MockSource(streams=[s1, s2])
catalog = ConfiguredAirbyteCatalog(
streams=[
_configured_stream(s1, SyncMode.incremental),
_configured_stream(s2, SyncMode.incremental),
]
)
expected = [
_as_record("s1", stream_output[0]),
_state({"s1": state}),
_as_record("s1", stream_output[1]),
_state({"s1": state}),
_state({"s1": state}),
_as_record("s2", stream_output[0]),
_state({"s1": state, "s2": state}),
_as_record("s2", stream_output[1]),
_state({"s1": state, "s2": state}),
_state({"s1": state, "s2": state}),
]
messages = _fix_emitted_at(list(src.read(logger, {}, catalog, state=defaultdict(dict))))
assert expected == messages
def test_with_no_interval(self, mocker):
"""Tests that an incremental read which doesn't specify a checkpoint interval outputs
a STATE message only after fully reading the stream and does not output any STATE messages during syncing the stream.
"""
stream_output = [{"k1": "v1"}, {"k2": "v2"}]
s1 = MockStream(
[({"sync_mode": SyncMode.incremental, "stream_state": {}}, stream_output)],
name="s1",
)
s2 = MockStream(
[({"sync_mode": SyncMode.incremental, "stream_state": {}}, stream_output)],
name="s2",
)
state = {"cursor": "value"}
mocker.patch.object(MockStream, "get_updated_state", return_value=state)
mocker.patch.object(MockStream, "supports_incremental", return_value=True)
mocker.patch.object(MockStream, "get_json_schema", return_value={})
src = MockSource(streams=[s1, s2])
catalog = ConfiguredAirbyteCatalog(
streams=[
_configured_stream(s1, SyncMode.incremental),
_configured_stream(s2, SyncMode.incremental),
]
)
expected = [
*_as_records("s1", stream_output),
_state({"s1": state}),
*_as_records("s2", stream_output),
_state({"s1": state, "s2": state}),
]
messages = _fix_emitted_at(list(src.read(logger, {}, catalog, state=defaultdict(dict))))
assert expected == messages
def test_with_slices(self, mocker):
"""Tests that an incremental read which uses slices outputs each record in the slice followed by a STATE message, for each slice"""
slices = [{"1": "1"}, {"2": "2"}]
stream_output = [{"k1": "v1"}, {"k2": "v2"}, {"k3": "v3"}]
s1 = MockStream(
[
(
{
"sync_mode": SyncMode.incremental,
"stream_slice": s,
"stream_state": mocker.ANY,
},
stream_output,
)
for s in slices
],
name="s1",
)
s2 = MockStream(
[
(
{
"sync_mode": SyncMode.incremental,
"stream_slice": s,
"stream_state": mocker.ANY,
},
stream_output,
)
for s in slices
],
name="s2",
)
state = {"cursor": "value"}
mocker.patch.object(MockStream, "get_updated_state", return_value=state)
mocker.patch.object(MockStream, "supports_incremental", return_value=True)
mocker.patch.object(MockStream, "get_json_schema", return_value={})
mocker.patch.object(MockStream, "stream_slices", return_value=slices)
src = MockSource(streams=[s1, s2])
catalog = ConfiguredAirbyteCatalog(
streams=[
_configured_stream(s1, SyncMode.incremental),
_configured_stream(s2, SyncMode.incremental),
]
)
expected = [
# stream 1 slice 1
*_as_records("s1", stream_output),
_state({"s1": state}),
# stream 1 slice 2
*_as_records("s1", stream_output),
_state({"s1": state}),
# stream 2 slice 1
*_as_records("s2", stream_output),
_state({"s1": state, "s2": state}),
# stream 2 slice 2
*_as_records("s2", stream_output),
_state({"s1": state, "s2": state}),
]
messages = _fix_emitted_at(list(src.read(logger, {}, catalog, state=defaultdict(dict))))
assert expected == messages
def test_with_slices_and_interval(self, mocker):
"""
Tests that an incremental read which uses slices and a checkpoint interval:
1. outputs all records
2. outputs a state message every N records (N=checkpoint_interval)
3. outputs a state message after reading the entire slice
"""
slices = [{"1": "1"}, {"2": "2"}]
stream_output = [{"k1": "v1"}, {"k2": "v2"}, {"k3": "v3"}]
s1 = MockStream(
[
(
{
"sync_mode": SyncMode.incremental,
"stream_slice": s,
"stream_state": mocker.ANY,
},
stream_output,
)
for s in slices
],
name="s1",
)
s2 = MockStream(
[
(
{
"sync_mode": SyncMode.incremental,
"stream_slice": s,
"stream_state": mocker.ANY,
},
stream_output,
)
for s in slices
],
name="s2",
)
state = {"cursor": "value"}
mocker.patch.object(MockStream, "get_updated_state", return_value=state)
mocker.patch.object(MockStream, "supports_incremental", return_value=True)
mocker.patch.object(MockStream, "get_json_schema", return_value={})
mocker.patch.object(MockStream, "stream_slices", return_value=slices)
mocker.patch.object(
MockStream,
"state_checkpoint_interval",
new_callable=mocker.PropertyMock,
return_value=2,
)
src = MockSource(streams=[s1, s2])
catalog = ConfiguredAirbyteCatalog(
streams=[
_configured_stream(s1, SyncMode.incremental),
_configured_stream(s2, SyncMode.incremental),
]
)
expected = [
# stream 1 slice 1
_as_record("s1", stream_output[0]),
_as_record("s1", stream_output[1]),
_state({"s1": state}),
_as_record("s1", stream_output[2]),
_state({"s1": state}),
# stream 1 slice 2
_as_record("s1", stream_output[0]),
_as_record("s1", stream_output[1]),
_state({"s1": state}),
_as_record("s1", stream_output[2]),
_state({"s1": state}),
# stream 2 slice 1
_as_record("s2", stream_output[0]),
_as_record("s2", stream_output[1]),
_state({"s1": state, "s2": state}),
_as_record("s2", stream_output[2]),
_state({"s1": state, "s2": state}),
# stream 2 slice 2
_as_record("s2", stream_output[0]),
_as_record("s2", stream_output[1]),
_state({"s1": state, "s2": state}),
_as_record("s2", stream_output[2]),
_state({"s1": state, "s2": state}),
]
messages = _fix_emitted_at(list(src.read(logger, {}, catalog, state=defaultdict(dict))))
assert expected == messages