airbyte/airbyte-cdk/python/unit_tests/sources/test_abstract_source.py
Alexandre Girard 1a608f846a Stream returns AirbyteMessages (#18572)
* method yielding airbytemessage

* move to Stream

* update abstract source

* reset

* missing file

* add test docker image

* script to run acceptance tests with local cdk

* Update conftest to use a different image

* extract to method

* dont use a different image tag

* Always install local cdk

* break the cdk

* get path from current working directory

* or

* ignore unit test

* debug log

* Revert "AMI change: ami-0f23be2f917510c26 -> ami-005924fb76f7477ce (#18689)"

This reverts commit bf06decf73.

* build from the top

* Update source-acceptance-test

* fix

* copy setup

* some work on the gradle plugin

* reset to master

* delete unused file

* delete unused file

* reset to master

* optional argument

* delete dead code

* use latest cdk with sendgrid

* fix sendgrid dockerfile

* break the cdk

* use local file

* Revert "break the cdk"

This reverts commit 600c195541.

* dont raise an exception

* reset to master

* unit tests

* missing test

* more unit tests

* newline

* reset to master

* remove files

* reset

* Update abstract source

* remove method from stream

* convert to airbytemessage

* unittests

* Update

* unit test

* remove debug logs

* Revert "remove debug logs"

This reverts commit a1a139ef37.

* Revert "Revert "remove debug logs""

This reverts commit b1d62cdb60.

* Revert "reset to master"

This reverts commit 3fa6a004c1.

* Revert "Revert "reset to master""

This reverts commit 5dac7c2804.

* reset to master

* reset to master

* test

* Revert "test"

This reverts commit 2f91b803b0.

* test

* Revert "test"

This reverts commit 62d95ebbb5.

* test

* Revert "test"

This reverts commit 27150ba341.

* format

#
# Copyright (c) 2022 Airbyte, Inc., all rights reserved.
#
import copy
import logging
from collections import defaultdict
from typing import Any, Callable, Dict, Iterable, List, Mapping, MutableMapping, Optional, Tuple, Union
from unittest.mock import call
import pytest
from airbyte_cdk.models import (
AirbyteCatalog,
AirbyteConnectionStatus,
AirbyteLogMessage,
AirbyteMessage,
AirbyteRecordMessage,
AirbyteStateBlob,
AirbyteStateMessage,
AirbyteStateType,
AirbyteStream,
AirbyteStreamState,
ConfiguredAirbyteCatalog,
ConfiguredAirbyteStream,
DestinationSyncMode,
Level,
Status,
StreamDescriptor,
SyncMode,
Type,
)
from airbyte_cdk.sources import AbstractSource
from airbyte_cdk.sources.connector_state_manager import ConnectorStateManager
from airbyte_cdk.sources.streams import IncrementalMixin, Stream
from airbyte_cdk.sources.utils.record_helper import stream_data_to_airbyte_message
from airbyte_cdk.utils.traced_exception import AirbyteTracedException
logger = logging.getLogger("airbyte")
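# Minimal AbstractSource test double: check/streams behaviour is injected through the constructor, and per_stream controls whether state is reported in the per-stream format.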
class MockSource(AbstractSource):
def __init__(
self,
check_lambda: Callable[[], Tuple[bool, Optional[Any]]] = None,
streams: List[Stream] = None,
per_stream: bool = True,
):
self._streams = streams
self.check_lambda = check_lambda
self.per_stream = per_stream
def check_connection(self, logger: logging.Logger, config: Mapping[str, Any]) -> Tuple[bool, Optional[Any]]:
if self.check_lambda:
return self.check_lambda()
return False, "Missing callable."
def streams(self, config: Mapping[str, Any]) -> List[Stream]:
if not self._streams:
raise Exception("Stream is not set")
return self._streams
@property
def per_stream_state_enabled(self) -> bool:
return self.per_stream
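# Stream that does not override the state property; checkpointing has to fall back to the stream_state it is handed.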
class StreamNoStateMethod(Stream):
name = "managers"
primary_key = None
def read_records(self, *args, **kwargs) -> Iterable[Mapping[str, Any]]:
return {}
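# Stream that manages its own cursor via the IncrementalMixin state property; checkpointing should read state from the instance rather than from the stream_state argument.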
class MockStreamOverridesStateMethod(Stream, IncrementalMixin):
name = "teams"
primary_key = None
cursor_field = "updated_at"
_cursor_value = ""
start_date = "1984-12-12"
def read_records(self, *args, **kwargs) -> Iterable[Mapping[str, Any]]:
return {}
@property
def state(self) -> MutableMapping[str, Any]:
return {self.cursor_field: self._cursor_value} if self._cursor_value else {}
@state.setter
def state(self, value: MutableMapping[str, Any]):
self._cursor_value = value.get(self.cursor_field, self.start_date)
def test_successful_check():
"""Tests that if a source returns TRUE for the connection check the appropriate connectionStatus success message is returned"""
expected = AirbyteConnectionStatus(status=Status.SUCCEEDED)
assert expected == MockSource(check_lambda=lambda: (True, None)).check(logger, {})
def test_failed_check():
"""Tests that if a source returns FALSE for the connection check the appropriate connectionStatus failure message is returned"""
expected = AirbyteConnectionStatus(status=Status.FAILED, message="'womp womp'")
assert expected == MockSource(check_lambda=lambda: (False, "womp womp")).check(logger, {})
def test_raising_check():
"""Tests that if a source raises an unexpected exception the appropriate connectionStatus failure message is returned."""
expected = AirbyteConnectionStatus(status=Status.FAILED, message="Exception('this should fail')")
assert expected == MockSource(check_lambda=lambda: exec('raise Exception("this should fail")')).check(logger, {})
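# Stream test double that replays canned outputs: read_records matches its (non-None) kwargs against the supplied input/output pairs and raises if nothing matches.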
class MockStream(Stream):
def __init__(
self,
inputs_and_mocked_outputs: List[Tuple[Mapping[str, Any], Iterable[Mapping[str, Any]]]] = None,
name: str = None,
):
self._inputs_and_mocked_outputs = inputs_and_mocked_outputs
self._name = name
@property
def name(self):
return self._name
def read_records(self, **kwargs) -> Iterable[Mapping[str, Any]]: # type: ignore
# Remove None values
kwargs = {k: v for k, v in kwargs.items() if v is not None}
if self._inputs_and_mocked_outputs:
for _input, output in self._inputs_and_mocked_outputs:
if kwargs == _input:
return output
raise Exception(f"No mocked output supplied for input: {kwargs}. Mocked inputs/outputs: {self._inputs_and_mocked_outputs}")
@property
def primary_key(self) -> Optional[Union[str, List[str], List[List[str]]]]:
return "pk"
class MockStreamWithState(MockStream):
cursor_field = "cursor"
def __init__(self, inputs_and_mocked_outputs: List[Tuple[Mapping[str, Any], Iterable[Mapping[str, Any]]]], name: str, state=None):
super().__init__(inputs_and_mocked_outputs, name)
self._state = state
@property
def state(self):
return self._state
@state.setter
def state(self, value):
pass
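# MockStreamWithState variant whose mocked outputs may mix plain record mappings with AirbyteMessages such as log messages.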
class MockStreamEmittingAirbyteMessages(MockStreamWithState):
def __init__(
self, inputs_and_mocked_outputs: List[Tuple[Mapping[str, Any], Iterable[AirbyteMessage]]] = None, name: str = None, state=None
):
super().__init__(inputs_and_mocked_outputs, name, state)
self._inputs_and_mocked_outputs = inputs_and_mocked_outputs
self._name = name
@property
def name(self):
return self._name
@property
def primary_key(self) -> Optional[Union[str, List[str], List[List[str]]]]:
return "pk"
@property
def state(self) -> MutableMapping[str, Any]:
return {self.cursor_field: self._cursor_value} if self._cursor_value else {}
@state.setter
def state(self, value: MutableMapping[str, Any]):
self._cursor_value = value.get(self.cursor_field, self.start_date)
def test_discover(mocker):
"""Tests that the appropriate AirbyteCatalog is returned from the discover method"""
airbyte_stream1 = AirbyteStream(
name="1",
json_schema={},
supported_sync_modes=[SyncMode.full_refresh, SyncMode.incremental],
default_cursor_field=["cursor"],
source_defined_cursor=True,
source_defined_primary_key=[["pk"]],
)
airbyte_stream2 = AirbyteStream(name="2", json_schema={}, supported_sync_modes=[SyncMode.full_refresh])
stream1 = MockStream()
stream2 = MockStream()
mocker.patch.object(stream1, "as_airbyte_stream", return_value=airbyte_stream1)
mocker.patch.object(stream2, "as_airbyte_stream", return_value=airbyte_stream2)
expected = AirbyteCatalog(streams=[airbyte_stream1, airbyte_stream2])
src = MockSource(check_lambda=lambda: (True, None), streams=[stream1, stream2])
assert expected == src.discover(logger, {})
def test_read_nonexistent_stream_raises_exception(mocker):
"""Tests that attempting to sync a stream which the source does not return from the `streams` method raises an exception"""
s1 = MockStream(name="s1")
s2 = MockStream(name="this_stream_doesnt_exist_in_the_source")
mocker.patch.object(MockStream, "get_json_schema", return_value={})
src = MockSource(streams=[s1])
catalog = ConfiguredAirbyteCatalog(streams=[_configured_stream(s2, SyncMode.full_refresh)])
with pytest.raises(KeyError):
list(src.read(logger, {}, catalog))
def test_read_stream_with_error_gets_display_message(mocker):
stream = MockStream(name="my_stream")
mocker.patch.object(MockStream, "get_json_schema", return_value={})
mocker.patch.object(MockStream, "read_records", side_effect=RuntimeError("oh no!"))
source = MockSource(streams=[stream])
catalog = ConfiguredAirbyteCatalog(streams=[_configured_stream(stream, SyncMode.full_refresh)])
# without get_error_display_message
with pytest.raises(RuntimeError, match="oh no!"):
list(source.read(logger, {}, catalog))
mocker.patch.object(MockStream, "get_error_display_message", return_value="my message")
with pytest.raises(AirbyteTracedException, match="oh no!") as exc:
list(source.read(logger, {}, catalog))
assert exc.value.message == "my message"
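# Helpers that build the expected AirbyteMessages (records, state, configured streams) and normalise emitted_at so expected and actual message lists compare equal.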
GLOBAL_EMITTED_AT = 1
def _as_record(stream: str, data: Dict[str, Any]) -> AirbyteMessage:
return AirbyteMessage(
type=Type.RECORD,
record=AirbyteRecordMessage(stream=stream, data=data, emitted_at=GLOBAL_EMITTED_AT),
)
def _as_records(stream: str, data: List[Dict[str, Any]]) -> List[AirbyteMessage]:
return [_as_record(stream, datum) for datum in data]
def _as_state(state_data: Dict[str, Any], stream_name: str = "", per_stream_state: Dict[str, Any] = None):
if per_stream_state:
return AirbyteMessage(
type=Type.STATE,
state=AirbyteStateMessage(
type=AirbyteStateType.STREAM,
stream=AirbyteStreamState(
stream_descriptor=StreamDescriptor(name=stream_name), stream_state=AirbyteStateBlob.parse_obj(per_stream_state)
),
data=state_data,
),
)
return AirbyteMessage(type=Type.STATE, state=AirbyteStateMessage(data=state_data))
def _configured_stream(stream: Stream, sync_mode: SyncMode):
return ConfiguredAirbyteStream(
stream=stream.as_airbyte_stream(),
sync_mode=sync_mode,
destination_sync_mode=DestinationSyncMode.overwrite,
)
def _fix_emitted_at(messages: List[AirbyteMessage]) -> List[AirbyteMessage]:
for msg in messages:
if msg.type == Type.RECORD and msg.record:
msg.record.emitted_at = GLOBAL_EMITTED_AT
return messages
def test_valid_full_refresh_read_no_slices(mocker):
"""Tests that running a full refresh sync on streams which don't specify slices produces the expected AirbyteMessages"""
stream_output = [{"k1": "v1"}, {"k2": "v2"}]
s1 = MockStream([({"sync_mode": SyncMode.full_refresh}, stream_output)], name="s1")
s2 = MockStream([({"sync_mode": SyncMode.full_refresh}, stream_output)], name="s2")
mocker.patch.object(MockStream, "get_json_schema", return_value={})
src = MockSource(streams=[s1, s2])
catalog = ConfiguredAirbyteCatalog(
streams=[
_configured_stream(s1, SyncMode.full_refresh),
_configured_stream(s2, SyncMode.full_refresh),
]
)
expected = _as_records("s1", stream_output) + _as_records("s2", stream_output)
messages = _fix_emitted_at(list(src.read(logger, {}, catalog)))
assert expected == messages
def test_valid_full_refresh_read_with_slices(mocker):
"""Tests that running a full refresh sync on streams which use slices produces the expected AirbyteMessages"""
slices = [{"1": "1"}, {"2": "2"}]
# When attempting to sync a slice, just output that slice as a record
s1 = MockStream(
[({"sync_mode": SyncMode.full_refresh, "stream_slice": s}, [s]) for s in slices],
name="s1",
)
s2 = MockStream(
[({"sync_mode": SyncMode.full_refresh, "stream_slice": s}, [s]) for s in slices],
name="s2",
)
mocker.patch.object(MockStream, "get_json_schema", return_value={})
mocker.patch.object(MockStream, "stream_slices", return_value=slices)
src = MockSource(streams=[s1, s2])
catalog = ConfiguredAirbyteCatalog(
streams=[
_configured_stream(s1, SyncMode.full_refresh),
_configured_stream(s2, SyncMode.full_refresh),
]
)
expected = [*_as_records("s1", slices), *_as_records("s2", slices)]
messages = _fix_emitted_at(list(src.read(logger, {}, catalog)))
assert expected == messages
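# Incremental read scenarios; most tests are parametrized over the incoming state format (legacy vs. per-stream) and over whether the source emits per-stream state.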
class TestIncrementalRead:
@pytest.mark.parametrize(
"use_legacy",
[
pytest.param(True, id="test_incoming_stream_state_as_legacy_format"),
pytest.param(False, id="test_incoming_stream_state_as_per_stream_format"),
],
)
@pytest.mark.parametrize(
"per_stream_enabled",
[
pytest.param(True, id="test_source_emits_state_as_per_stream_format"),
pytest.param(False, id="test_source_emits_state_as_legacy_format"),
],
)
def test_with_state_attribute(self, mocker, use_legacy, per_stream_enabled):
"""Test correct state passing for the streams that have a state attribute"""
stream_output = [{"k1": "v1"}, {"k2": "v2"}]
old_state = {"cursor": "old_value"}
if use_legacy:
input_state = {"s1": old_state}
else:
input_state = [
AirbyteStateMessage(
type=AirbyteStateType.STREAM,
stream=AirbyteStreamState(
stream_descriptor=StreamDescriptor(name="s1"), stream_state=AirbyteStateBlob.parse_obj(old_state)
),
),
]
new_state_from_connector = {"cursor": "new_value"}
stream_1 = MockStreamWithState(
[
(
{"sync_mode": SyncMode.incremental, "stream_state": old_state},
stream_output,
)
],
name="s1",
)
stream_2 = MockStreamWithState(
[({"sync_mode": SyncMode.incremental, "stream_state": {}}, stream_output)],
name="s2",
)
mocker.patch.object(MockStreamWithState, "get_updated_state", return_value={})
state_property = mocker.patch.object(
MockStreamWithState,
"state",
new_callable=mocker.PropertyMock,
return_value=new_state_from_connector,
)
mocker.patch.object(MockStreamWithState, "get_json_schema", return_value={})
src = MockSource(streams=[stream_1, stream_2], per_stream=per_stream_enabled)
catalog = ConfiguredAirbyteCatalog(
streams=[
_configured_stream(stream_1, SyncMode.incremental),
_configured_stream(stream_2, SyncMode.incremental),
]
)
expected = [
_as_record("s1", stream_output[0]),
_as_record("s1", stream_output[1]),
_as_state({"s1": new_state_from_connector}, "s1", new_state_from_connector)
if per_stream_enabled
else _as_state({"s1": new_state_from_connector}),
_as_record("s2", stream_output[0]),
_as_record("s2", stream_output[1]),
_as_state({"s1": new_state_from_connector, "s2": new_state_from_connector}, "s2", new_state_from_connector)
if per_stream_enabled
else _as_state({"s1": new_state_from_connector, "s2": new_state_from_connector}),
]
messages = _fix_emitted_at(list(src.read(logger, {}, catalog, state=input_state)))
assert messages == expected
assert state_property.mock_calls == [
call(old_state), # set state for s1
call(), # get state in the end of slice for s1
call(), # get state in the end of slice for s2
]
@pytest.mark.parametrize(
"use_legacy",
[
pytest.param(True, id="test_incoming_stream_state_as_legacy_format"),
pytest.param(False, id="test_incoming_stream_state_as_per_stream_format"),
],
)
@pytest.mark.parametrize(
"per_stream_enabled",
[
pytest.param(True, id="test_source_emits_state_as_per_stream_format"),
pytest.param(False, id="test_source_emits_state_as_legacy_format"),
],
)
def test_with_checkpoint_interval(self, mocker, use_legacy, per_stream_enabled):
"""Tests that an incremental read which doesn't specify a checkpoint interval outputs a STATE message
after reading N records within a stream.
"""
if use_legacy:
input_state = defaultdict(dict)
else:
input_state = []
stream_output = [{"k1": "v1"}, {"k2": "v2"}]
stream_1 = MockStream(
[({"sync_mode": SyncMode.incremental, "stream_state": {}}, stream_output)],
name="s1",
)
stream_2 = MockStream(
[({"sync_mode": SyncMode.incremental, "stream_state": {}}, stream_output)],
name="s2",
)
state = {"cursor": "value"}
mocker.patch.object(MockStream, "get_updated_state", return_value=state)
mocker.patch.object(MockStream, "supports_incremental", return_value=True)
mocker.patch.object(MockStream, "get_json_schema", return_value={})
# Tell the source to output one state message per record
mocker.patch.object(
MockStream,
"state_checkpoint_interval",
new_callable=mocker.PropertyMock,
return_value=1,
)
src = MockSource(streams=[stream_1, stream_2], per_stream=per_stream_enabled)
catalog = ConfiguredAirbyteCatalog(
streams=[
_configured_stream(stream_1, SyncMode.incremental),
_configured_stream(stream_2, SyncMode.incremental),
]
)
expected = [
_as_record("s1", stream_output[0]),
_as_state({"s1": state}, "s1", state) if per_stream_enabled else _as_state({"s1": state}),
_as_record("s1", stream_output[1]),
_as_state({"s1": state}, "s1", state) if per_stream_enabled else _as_state({"s1": state}),
_as_state({"s1": state}, "s1", state) if per_stream_enabled else _as_state({"s1": state}),
_as_record("s2", stream_output[0]),
_as_state({"s1": state, "s2": state}, "s2", state) if per_stream_enabled else _as_state({"s1": state, "s2": state}),
_as_record("s2", stream_output[1]),
_as_state({"s1": state, "s2": state}, "s2", state) if per_stream_enabled else _as_state({"s1": state, "s2": state}),
_as_state({"s1": state, "s2": state}, "s2", state) if per_stream_enabled else _as_state({"s1": state, "s2": state}),
]
messages = _fix_emitted_at(list(src.read(logger, {}, catalog, state=input_state)))
assert expected == messages
@pytest.mark.parametrize(
"use_legacy",
[
pytest.param(True, id="test_incoming_stream_state_as_legacy_format"),
pytest.param(False, id="test_incoming_stream_state_as_per_stream_format"),
],
)
@pytest.mark.parametrize(
"per_stream_enabled",
[
pytest.param(True, id="test_source_emits_state_as_per_stream_format"),
pytest.param(False, id="test_source_emits_state_as_legacy_format"),
],
)
def test_with_no_interval(self, mocker, use_legacy, per_stream_enabled):
"""Tests that an incremental read which doesn't specify a checkpoint interval outputs
a STATE message only after fully reading the stream, and does not output any STATE messages while the stream is being read.
"""
if use_legacy:
input_state = defaultdict(dict)
else:
input_state = []
stream_output = [{"k1": "v1"}, {"k2": "v2"}]
stream_1 = MockStream(
[({"sync_mode": SyncMode.incremental, "stream_state": {}}, stream_output)],
name="s1",
)
stream_2 = MockStream(
[({"sync_mode": SyncMode.incremental, "stream_state": {}}, stream_output)],
name="s2",
)
state = {"cursor": "value"}
mocker.patch.object(MockStream, "get_updated_state", return_value=state)
mocker.patch.object(MockStream, "supports_incremental", return_value=True)
mocker.patch.object(MockStream, "get_json_schema", return_value={})
src = MockSource(streams=[stream_1, stream_2], per_stream=per_stream_enabled)
catalog = ConfiguredAirbyteCatalog(
streams=[
_configured_stream(stream_1, SyncMode.incremental),
_configured_stream(stream_2, SyncMode.incremental),
]
)
expected = [
*_as_records("s1", stream_output),
_as_state({"s1": state}, "s1", state) if per_stream_enabled else _as_state({"s1": state}),
*_as_records("s2", stream_output),
_as_state({"s1": state, "s2": state}, "s2", state) if per_stream_enabled else _as_state({"s1": state, "s2": state}),
]
messages = _fix_emitted_at(list(src.read(logger, {}, catalog, state=input_state)))
assert expected == messages
@pytest.mark.parametrize(
"use_legacy",
[
pytest.param(True, id="test_incoming_stream_state_as_legacy_format"),
pytest.param(False, id="test_incoming_stream_state_as_per_stream_format"),
],
)
@pytest.mark.parametrize(
"per_stream_enabled",
[
pytest.param(True, id="test_source_emits_state_as_per_stream_format"),
pytest.param(False, id="test_source_emits_state_as_legacy_format"),
],
)
def test_with_slices(self, mocker, use_legacy, per_stream_enabled):
"""Tests that an incremental read which uses slices outputs each record in the slice followed by a STATE message, for each slice"""
if use_legacy:
input_state = defaultdict(dict)
else:
input_state = []
slices = [{"1": "1"}, {"2": "2"}]
stream_output = [{"k1": "v1"}, {"k2": "v2"}, {"k3": "v3"}]
stream_1 = MockStream(
[
(
{
"sync_mode": SyncMode.incremental,
"stream_slice": s,
"stream_state": mocker.ANY,
},
stream_output,
)
for s in slices
],
name="s1",
)
stream_2 = MockStream(
[
(
{
"sync_mode": SyncMode.incremental,
"stream_slice": s,
"stream_state": mocker.ANY,
},
stream_output,
)
for s in slices
],
name="s2",
)
state = {"cursor": "value"}
mocker.patch.object(MockStream, "get_updated_state", return_value=state)
mocker.patch.object(MockStream, "supports_incremental", return_value=True)
mocker.patch.object(MockStream, "get_json_schema", return_value={})
mocker.patch.object(MockStream, "stream_slices", return_value=slices)
src = MockSource(streams=[stream_1, stream_2], per_stream=per_stream_enabled)
catalog = ConfiguredAirbyteCatalog(
streams=[
_configured_stream(stream_1, SyncMode.incremental),
_configured_stream(stream_2, SyncMode.incremental),
]
)
expected = [
# stream 1 slice 1
*_as_records("s1", stream_output),
_as_state({"s1": state}, "s1", state) if per_stream_enabled else _as_state({"s1": state}),
# stream 1 slice 2
*_as_records("s1", stream_output),
_as_state({"s1": state}, "s1", state) if per_stream_enabled else _as_state({"s1": state}),
# stream 2 slice 1
*_as_records("s2", stream_output),
_as_state({"s1": state, "s2": state}, "s2", state) if per_stream_enabled else _as_state({"s1": state, "s2": state}),
# stream 2 slice 2
*_as_records("s2", stream_output),
_as_state({"s1": state, "s2": state}, "s2", state) if per_stream_enabled else _as_state({"s1": state, "s2": state}),
]
messages = _fix_emitted_at(list(src.read(logger, {}, catalog, state=input_state)))
assert expected == messages
@pytest.mark.parametrize(
"use_legacy",
[
pytest.param(True, id="test_incoming_stream_state_as_legacy_format"),
pytest.param(False, id="test_incoming_stream_state_as_per_stream_format"),
],
)
@pytest.mark.parametrize(
"per_stream_enabled",
[
pytest.param(True, id="test_source_emits_state_as_per_stream_format"),
pytest.param(False, id="test_source_emits_state_as_legacy_format"),
],
)
@pytest.mark.parametrize("slices", [pytest.param([], id="test_slices_as_list"), pytest.param(iter([]), id="test_slices_as_iterator")])
def test_no_slices(self, mocker, use_legacy, per_stream_enabled, slices):
"""
Tests that an incremental read returns at least one state message even if no records were read:
1. outputs a state message after reading the entire stream
"""
if use_legacy:
input_state = defaultdict(dict)
else:
input_state = []
stream_output = [{"k1": "v1"}, {"k2": "v2"}, {"k3": "v3"}]
state = {"cursor": "value"}
stream_1 = MockStreamWithState(
[
(
{
"sync_mode": SyncMode.incremental,
"stream_slice": s,
"stream_state": mocker.ANY,
},
stream_output,
)
for s in slices
],
name="s1",
state=state,
)
stream_2 = MockStreamWithState(
[
(
{
"sync_mode": SyncMode.incremental,
"stream_slice": s,
"stream_state": mocker.ANY,
},
stream_output,
)
for s in slices
],
name="s2",
state=state,
)
mocker.patch.object(MockStreamWithState, "supports_incremental", return_value=True)
mocker.patch.object(MockStreamWithState, "get_json_schema", return_value={})
mocker.patch.object(MockStreamWithState, "stream_slices", return_value=slices)
mocker.patch.object(
MockStreamWithState,
"state_checkpoint_interval",
new_callable=mocker.PropertyMock,
return_value=2,
)
src = MockSource(streams=[stream_1, stream_2], per_stream=per_stream_enabled)
catalog = ConfiguredAirbyteCatalog(
streams=[
_configured_stream(stream_1, SyncMode.incremental),
_configured_stream(stream_2, SyncMode.incremental),
]
)
expected = [
_as_state({"s1": state}, "s1", state) if per_stream_enabled else _as_state({"s1": state}),
_as_state({"s1": state, "s2": state}, "s2", state) if per_stream_enabled else _as_state({"s1": state, "s2": state}),
]
messages = _fix_emitted_at(list(src.read(logger, {}, catalog, state=input_state)))
assert expected == messages
@pytest.mark.parametrize(
"use_legacy",
[
pytest.param(True, id="test_incoming_stream_state_as_legacy_format"),
pytest.param(False, id="test_incoming_stream_state_as_per_stream_format"),
],
)
@pytest.mark.parametrize(
"per_stream_enabled",
[
pytest.param(True, id="test_source_emits_state_as_per_stream_format"),
pytest.param(False, id="test_source_emits_state_as_per_stream_format"),
],
)
def test_with_slices_and_interval(self, mocker, use_legacy, per_stream_enabled):
"""
Tests that an incremental read which uses slices and a checkpoint interval:
1. outputs all records
2. outputs a state message every N records (N=checkpoint_interval)
3. outputs a state message after reading the entire slice
"""
if use_legacy:
input_state = defaultdict(dict)
else:
input_state = []
slices = [{"1": "1"}, {"2": "2"}]
stream_output = [{"k1": "v1"}, {"k2": "v2"}, {"k3": "v3"}]
stream_1 = MockStream(
[
(
{
"sync_mode": SyncMode.incremental,
"stream_slice": s,
"stream_state": mocker.ANY,
},
stream_output,
)
for s in slices
],
name="s1",
)
stream_2 = MockStream(
[
(
{
"sync_mode": SyncMode.incremental,
"stream_slice": s,
"stream_state": mocker.ANY,
},
stream_output,
)
for s in slices
],
name="s2",
)
state = {"cursor": "value"}
mocker.patch.object(MockStream, "get_updated_state", return_value=state)
mocker.patch.object(MockStream, "supports_incremental", return_value=True)
mocker.patch.object(MockStream, "get_json_schema", return_value={})
mocker.patch.object(MockStream, "stream_slices", return_value=slices)
mocker.patch.object(
MockStream,
"state_checkpoint_interval",
new_callable=mocker.PropertyMock,
return_value=2,
)
src = MockSource(streams=[stream_1, stream_2], per_stream=per_stream_enabled)
catalog = ConfiguredAirbyteCatalog(
streams=[
_configured_stream(stream_1, SyncMode.incremental),
_configured_stream(stream_2, SyncMode.incremental),
]
)
expected = [
# stream 1 slice 1
_as_record("s1", stream_output[0]),
_as_record("s1", stream_output[1]),
_as_state({"s1": state}, "s1", state) if per_stream_enabled else _as_state({"s1": state}),
_as_record("s1", stream_output[2]),
_as_state({"s1": state}, "s1", state) if per_stream_enabled else _as_state({"s1": state}),
# stream 1 slice 2
_as_record("s1", stream_output[0]),
_as_record("s1", stream_output[1]),
_as_state({"s1": state}, "s1", state) if per_stream_enabled else _as_state({"s1": state}),
_as_record("s1", stream_output[2]),
_as_state({"s1": state}, "s1", state) if per_stream_enabled else _as_state({"s1": state}),
# stream 2 slice 1
_as_record("s2", stream_output[0]),
_as_record("s2", stream_output[1]),
_as_state({"s1": state, "s2": state}, "s2", state) if per_stream_enabled else _as_state({"s1": state, "s2": state}),
_as_record("s2", stream_output[2]),
_as_state({"s1": state, "s2": state}, "s2", state) if per_stream_enabled else _as_state({"s1": state, "s2": state}),
# stream 2 slice 2
_as_record("s2", stream_output[0]),
_as_record("s2", stream_output[1]),
_as_state({"s1": state, "s2": state}, "s2", state) if per_stream_enabled else _as_state({"s1": state, "s2": state}),
_as_record("s2", stream_output[2]),
_as_state({"s1": state, "s2": state}, "s2", state) if per_stream_enabled else _as_state({"s1": state, "s2": state}),
]
messages = _fix_emitted_at(list(src.read(logger, {}, catalog, state=input_state)))
assert expected == messages
@pytest.mark.parametrize(
"per_stream_enabled",
[
pytest.param(False, id="test_source_emits_state_as_legacy_format"),
],
)
def test_emit_non_records(self, mocker, per_stream_enabled):
"""
Tests that an incremental read which uses slices and a checkpoint interval:
1. outputs all records
2. outputs a state message every N records (N=checkpoint_interval)
3. outputs a state message after reading the entire slice
"""
input_state = []
slices = [{"1": "1"}, {"2": "2"}]
stream_output = [
{"k1": "v1"},
AirbyteLogMessage(level=Level.INFO, message="HELLO"),
{"k2": "v2"},
{"k3": "v3"},
]
stream_1 = MockStreamEmittingAirbyteMessages(
[
(
{
"sync_mode": SyncMode.incremental,
"stream_slice": s,
"stream_state": mocker.ANY,
},
stream_output,
)
for s in slices
],
name="s1",
state=copy.deepcopy(input_state),
)
stream_2 = MockStreamEmittingAirbyteMessages(
[
(
{
"sync_mode": SyncMode.incremental,
"stream_slice": s,
"stream_state": mocker.ANY,
},
stream_output,
)
for s in slices
],
name="s2",
state=copy.deepcopy(input_state),
)
state = {"cursor": "value"}
mocker.patch.object(MockStream, "get_updated_state", return_value=state)
mocker.patch.object(MockStream, "supports_incremental", return_value=True)
mocker.patch.object(MockStream, "get_json_schema", return_value={})
mocker.patch.object(MockStream, "stream_slices", return_value=slices)
mocker.patch.object(
MockStream,
"state_checkpoint_interval",
new_callable=mocker.PropertyMock,
return_value=2,
)
src = MockSource(streams=[stream_1, stream_2], per_stream=per_stream_enabled)
catalog = ConfiguredAirbyteCatalog(
streams=[
_configured_stream(stream_1, SyncMode.incremental),
_configured_stream(stream_2, SyncMode.incremental),
]
)
expected = _fix_emitted_at(
[
# stream 1 slice 1
stream_data_to_airbyte_message("s1", stream_output[0]),
stream_data_to_airbyte_message("s1", stream_output[1]),
stream_data_to_airbyte_message("s1", stream_output[2]),
_as_state({"s1": state}, "s1", state) if per_stream_enabled else _as_state({"s1": state}),
stream_data_to_airbyte_message("s1", stream_output[3]),
_as_state({"s1": state}, "s1", state) if per_stream_enabled else _as_state({"s1": state}),
# stream 1 slice 2
stream_data_to_airbyte_message("s1", stream_output[0]),
stream_data_to_airbyte_message("s1", stream_output[1]),
stream_data_to_airbyte_message("s1", stream_output[2]),
_as_state({"s1": state}, "s1", state) if per_stream_enabled else _as_state({"s1": state}),
stream_data_to_airbyte_message("s1", stream_output[3]),
_as_state({"s1": state}, "s1", state) if per_stream_enabled else _as_state({"s1": state}),
# stream 2 slice 1
stream_data_to_airbyte_message("s2", stream_output[0]),
stream_data_to_airbyte_message("s2", stream_output[1]),
stream_data_to_airbyte_message("s2", stream_output[2]),
_as_state({"s1": state, "s2": state}, "s2", state) if per_stream_enabled else _as_state({"s1": state, "s2": state}),
stream_data_to_airbyte_message("s2", stream_output[3]),
_as_state({"s1": state, "s2": state}, "s2", state) if per_stream_enabled else _as_state({"s1": state, "s2": state}),
# stream 2 slice 2
stream_data_to_airbyte_message("s2", stream_output[0]),
stream_data_to_airbyte_message("s2", stream_output[1]),
stream_data_to_airbyte_message("s2", stream_output[2]),
_as_state({"s1": state, "s2": state}, "s2", state) if per_stream_enabled else _as_state({"s1": state, "s2": state}),
stream_data_to_airbyte_message("s2", stream_output[3]),
_as_state({"s1": state, "s2": state}, "s2", state) if per_stream_enabled else _as_state({"s1": state, "s2": state}),
]
)
messages = _fix_emitted_at(list(src.read(logger, {}, catalog, state=input_state)))
assert expected == messages
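# _checkpoint_state should prefer the stream's own state property when the stream defines one, and otherwise use the stream_state argument it is given.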
def test_checkpoint_state_from_stream_instance():
teams_stream = MockStreamOverridesStateMethod()
managers_stream = StreamNoStateMethod()
src = MockSource(streams=[teams_stream, managers_stream])
state_manager = ConnectorStateManager({"teams": teams_stream, "managers": managers_stream}, [])
# The stream_state passed to _checkpoint_state() should be ignored since the stream implements the state property
teams_stream.state = {"updated_at": "2022-09-11"}
actual_message = src._checkpoint_state(teams_stream, {"ignored": "state"}, state_manager)
assert actual_message == _as_state({"teams": {"updated_at": "2022-09-11"}}, "teams", {"updated_at": "2022-09-11"})
# The stream_state passed to _checkpoint_state() should be used since the stream does not implement the state property
actual_message = src._checkpoint_state(managers_stream, {"updated": "expected_here"}, state_manager)
assert actual_message == _as_state(
{"teams": {"updated_at": "2022-09-11"}, "managers": {"updated": "expected_here"}}, "managers", {"updated": "expected_here"}
)