#
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
#
import logging
from typing import Any, Iterable, List, Mapping, Optional, Tuple, Union
from unittest.mock import Mock

import freezegun
from airbyte_cdk.models import (
    AirbyteMessage,
    AirbyteRecordMessage,
    AirbyteStream,
    AirbyteStreamStatus,
    AirbyteStreamStatusTraceMessage,
    AirbyteTraceMessage,
    ConfiguredAirbyteCatalog,
    ConfiguredAirbyteStream,
    DestinationSyncMode,
    StreamDescriptor,
    SyncMode,
    TraceType,
)
from airbyte_cdk.models import Type as MessageType
from airbyte_cdk.sources import AbstractSource
from airbyte_cdk.sources.concurrent_source.concurrent_source import ConcurrentSource
from airbyte_cdk.sources.concurrent_source.concurrent_source_adapter import ConcurrentSourceAdapter
from airbyte_cdk.sources.message import InMemoryMessageRepository
from airbyte_cdk.sources.streams import Stream
from airbyte_cdk.sources.streams.concurrent.adapters import StreamFacade
from airbyte_cdk.sources.streams.concurrent.cursor import FinalStateCursor
from airbyte_cdk.sources.streams.core import StreamData
from airbyte_cdk.utils import AirbyteTracedException
from unit_tests.sources.streams.concurrent.scenarios.thread_based_concurrent_stream_source_builder import NeverLogSliceLogger


class _MockStream(Stream):
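    """Test double for Stream that replays canned records per slice.

    Items in a slice's record list that are Exception instances are raised
    instead of yielded (see read_records).
    """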
    def __init__(self, slice_to_records: Mapping[str, List[Mapping[str, Any]]], name: str):
        self._slice_to_records = slice_to_records
        self._name = name

    @property
    def name(self) -> str:
        return self._name

    @property
    def primary_key(self) -> Optional[Union[str, List[str], List[List[str]]]]:
        return None

    def stream_slices(
        self, *, sync_mode: SyncMode, cursor_field: Optional[List[str]] = None, stream_state: Optional[Mapping[str, Any]] = None
    ) -> Iterable[Optional[Mapping[str, Any]]]:
        for partition in self._slice_to_records.keys():
            yield {"partition": partition}

    def read_records(
        self,
        sync_mode: SyncMode,
        cursor_field: Optional[List[str]] = None,
        stream_slice: Optional[Mapping[str, Any]] = None,
        stream_state: Optional[Mapping[str, Any]] = None,
    ) -> Iterable[StreamData]:
        for record_or_exception in self._slice_to_records[stream_slice["partition"]]:
            if isinstance(record_or_exception, Exception):
                raise record_or_exception
            else:
                yield record_or_exception

    def get_json_schema(self) -> Mapping[str, Any]:
        return {}


class _MockSource(AbstractSource):
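    """Minimal synchronous AbstractSource whose streams are injected via set_streams()."""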
    message_repository = InMemoryMessageRepository()

    def check_connection(self, logger: logging.Logger, config: Mapping[str, Any]) -> Tuple[bool, Optional[Any]]:
        pass

    def set_streams(self, streams):
        self._streams = streams

    def streams(self, config: Mapping[str, Any]) -> List[Stream]:
        return self._streams


class _MockConcurrentSource(ConcurrentSourceAdapter):
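    """Concurrent counterpart of _MockSource, driven through ConcurrentSourceAdapter.

    The two leading integer arguments to ConcurrentSource.create below are, per the
    CDK signature, the number of worker threads and the initial number of partitions
    to generate; both are kept at 1 so the test stays deterministic.
    """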
    message_repository = InMemoryMessageRepository()

    def __init__(self, logger):
        concurrent_source = ConcurrentSource.create(1, 1, logger, NeverLogSliceLogger(), self.message_repository)
        super().__init__(concurrent_source)

    def check_connection(self, logger: logging.Logger, config: Mapping[str, Any]) -> Tuple[bool, Optional[Any]]:
        pass

    def set_streams(self, streams):
        self._streams = streams

    def streams(self, config: Mapping[str, Any]) -> List[Stream]:
        return self._streams


@freezegun.freeze_time("2020-01-01T00:00:00")
def test_concurrent_source_yields_the_same_messages_as_abstract_source_when_no_exceptions_are_raised():
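    """Happy path: the concurrent source should emit, per stream, a STARTED and a
    RUNNING status trace, the records of every partition, then a COMPLETE trace."""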
    records_stream_1_partition_1 = [
        {"id": 1, "partition": "1"},
        {"id": 2, "partition": "1"},
    ]
    records_stream_1_partition_2 = [
        {"id": 3, "partition": "2"},
        {"id": 4, "partition": "2"},
    ]
    records_stream_2_partition_1 = [
        {"id": 100, "partition": "A"},
        {"id": 200, "partition": "A"},
    ]
    records_stream_2_partition_2 = [
        {"id": 300, "partition": "B"},
        {"id": 400, "partition": "B"},
    ]
    stream_1_slice_to_partition = {"1": records_stream_1_partition_1, "2": records_stream_1_partition_2}
    stream_2_slice_to_partition = {"A": records_stream_2_partition_1, "B": records_stream_2_partition_2}
    state = None
    logger = _init_logger()
    source, concurrent_source = _init_sources([stream_1_slice_to_partition, stream_2_slice_to_partition], state, logger)
    config = {}
    catalog = _create_configured_catalog(source._streams)
    # FIXME this is currently unused in this test
    # messages_from_abstract_source = _read_from_source(source, logger, config, catalog, state, None)
    messages_from_concurrent_source = _read_from_source(concurrent_source, logger, config, catalog, state, None)
    expected_messages = [
        AirbyteMessage(
            type=MessageType.TRACE,
            trace=AirbyteTraceMessage(
                type=TraceType.STREAM_STATUS,
                emitted_at=1577836800000.0,
                error=None,
                estimate=None,
                stream_status=AirbyteStreamStatusTraceMessage(
                    stream_descriptor=StreamDescriptor(name="stream0"), status=AirbyteStreamStatus(AirbyteStreamStatus.STARTED)
                ),
            ),
        ),
        AirbyteMessage(
            type=MessageType.TRACE,
            trace=AirbyteTraceMessage(
                type=TraceType.STREAM_STATUS,
                emitted_at=1577836800000.0,
                error=None,
                estimate=None,
                stream_status=AirbyteStreamStatusTraceMessage(
                    stream_descriptor=StreamDescriptor(name="stream0"), status=AirbyteStreamStatus(AirbyteStreamStatus.RUNNING)
                ),
            ),
        ),
        AirbyteMessage(
            type=MessageType.RECORD,
            record=AirbyteRecordMessage(
                stream="stream0",
                data=records_stream_1_partition_1[0],
                emitted_at=1577836800000,
            ),
        ),
        AirbyteMessage(
            type=MessageType.RECORD,
            record=AirbyteRecordMessage(
                stream="stream0",
                data=records_stream_1_partition_1[1],
                emitted_at=1577836800000,
            ),
        ),
        AirbyteMessage(
            type=MessageType.RECORD,
            record=AirbyteRecordMessage(
                stream="stream0",
                data=records_stream_1_partition_2[0],
                emitted_at=1577836800000,
            ),
        ),
        AirbyteMessage(
            type=MessageType.RECORD,
            record=AirbyteRecordMessage(
                stream="stream0",
                data=records_stream_1_partition_2[1],
                emitted_at=1577836800000,
            ),
        ),
        AirbyteMessage(
            type=MessageType.TRACE,
            trace=AirbyteTraceMessage(
                type=TraceType.STREAM_STATUS,
                emitted_at=1577836800000.0,
                error=None,
                estimate=None,
                stream_status=AirbyteStreamStatusTraceMessage(
                    stream_descriptor=StreamDescriptor(name="stream0"), status=AirbyteStreamStatus(AirbyteStreamStatus.COMPLETE)
                ),
            ),
        ),
        AirbyteMessage(
            type=MessageType.TRACE,
            trace=AirbyteTraceMessage(
                type=TraceType.STREAM_STATUS,
                emitted_at=1577836800000.0,
                error=None,
                estimate=None,
                stream_status=AirbyteStreamStatusTraceMessage(
                    stream_descriptor=StreamDescriptor(name="stream1"), status=AirbyteStreamStatus(AirbyteStreamStatus.STARTED)
                ),
            ),
        ),
        AirbyteMessage(
            type=MessageType.TRACE,
            trace=AirbyteTraceMessage(
                type=TraceType.STREAM_STATUS,
                emitted_at=1577836800000.0,
                error=None,
                estimate=None,
                stream_status=AirbyteStreamStatusTraceMessage(
                    stream_descriptor=StreamDescriptor(name="stream1"), status=AirbyteStreamStatus(AirbyteStreamStatus.RUNNING)
                ),
            ),
        ),
        AirbyteMessage(
            type=MessageType.RECORD,
            record=AirbyteRecordMessage(
                stream="stream1",
                data=records_stream_2_partition_1[0],
                emitted_at=1577836800000,
            ),
        ),
        AirbyteMessage(
            type=MessageType.RECORD,
            record=AirbyteRecordMessage(
                stream="stream1",
                data=records_stream_2_partition_1[1],
                emitted_at=1577836800000,
            ),
        ),
        AirbyteMessage(
            type=MessageType.RECORD,
            record=AirbyteRecordMessage(
                stream="stream1",
                data=records_stream_2_partition_2[0],
                emitted_at=1577836800000,
            ),
        ),
        AirbyteMessage(
            type=MessageType.RECORD,
            record=AirbyteRecordMessage(
                stream="stream1",
                data=records_stream_2_partition_2[1],
                emitted_at=1577836800000,
            ),
        ),
        AirbyteMessage(
            type=MessageType.TRACE,
            trace=AirbyteTraceMessage(
                type=TraceType.STREAM_STATUS,
                emitted_at=1577836800000.0,
                error=None,
                estimate=None,
                stream_status=AirbyteStreamStatusTraceMessage(
                    stream_descriptor=StreamDescriptor(name="stream1"), status=AirbyteStreamStatus(AirbyteStreamStatus.COMPLETE)
                ),
            ),
        ),
    ]
    _verify_messages(expected_messages, messages_from_concurrent_source)


@freezegun.freeze_time("2020-01-01T00:00:00")
def test_concurrent_source_yields_the_same_messages_as_abstract_source_when_a_traced_exception_is_raised():
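    """An AirbyteTracedException raised mid-stream should produce the same status,
    record, and error-trace messages from both source implementations."""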
    records = [{"id": 1, "partition": "1"}, AirbyteTracedException()]
    stream_slice_to_partition = {"1": records}
    logger = _init_logger()
    state = None
    source, concurrent_source = _init_sources([stream_slice_to_partition], state, logger)
    config = {}
    catalog = _create_configured_catalog(source._streams)
    messages_from_abstract_source = _read_from_source(source, logger, config, catalog, state, AirbyteTracedException)
    messages_from_concurrent_source = _read_from_source(concurrent_source, logger, config, catalog, state, AirbyteTracedException)
    _assert_status_messages(messages_from_abstract_source, messages_from_concurrent_source)
    _assert_record_messages(messages_from_abstract_source, messages_from_concurrent_source)
    _assert_errors(messages_from_abstract_source, messages_from_concurrent_source)


@freezegun.freeze_time("2020-01-01T00:00:00")
def test_concurrent_source_yields_the_same_messages_as_abstract_source_when_an_exception_is_raised():
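    """A plain RuntimeError should surface as an AirbyteTracedException from both
    read paths, with matching status, record, and error-trace output."""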
    records = [{"id": 1, "partition": "1"}, RuntimeError()]
    stream_slice_to_partition = {"1": records}
    logger = _init_logger()
    state = None
    source, concurrent_source = _init_sources([stream_slice_to_partition], state, logger)
    config = {}
    catalog = _create_configured_catalog(source._streams)
    messages_from_abstract_source = _read_from_source(source, logger, config, catalog, state, AirbyteTracedException)
    messages_from_concurrent_source = _read_from_source(concurrent_source, logger, config, catalog, state, AirbyteTracedException)
    _assert_status_messages(messages_from_abstract_source, messages_from_concurrent_source)
    _assert_record_messages(messages_from_abstract_source, messages_from_concurrent_source)
    _assert_errors(messages_from_abstract_source, messages_from_concurrent_source)


def _assert_status_messages(messages_from_abstract_source, messages_from_concurrent_source):
    status_from_concurrent_source = [
        message
        for message in messages_from_concurrent_source
        if message.type == MessageType.TRACE and message.trace.type == TraceType.STREAM_STATUS
    ]
    assert status_from_concurrent_source
    _verify_messages(
        [
            message
            for message in messages_from_abstract_source
            if message.type == MessageType.TRACE and message.trace.type == TraceType.STREAM_STATUS
        ],
        status_from_concurrent_source,
    )


def _assert_record_messages(messages_from_abstract_source, messages_from_concurrent_source):
    records_from_concurrent_source = [message for message in messages_from_concurrent_source if message.type == MessageType.RECORD]
    assert records_from_concurrent_source
    _verify_messages(
        [message for message in messages_from_abstract_source if message.type == MessageType.RECORD],
        records_from_concurrent_source,
    )


def _assert_errors(messages_from_abstract_source, messages_from_concurrent_source):
    errors_from_concurrent_source = [
        message
        for message in messages_from_concurrent_source
        if message.type == MessageType.TRACE and message.trace.type == TraceType.ERROR
    ]
    errors_from_abstract_source = [
        message for message in messages_from_abstract_source if message.type == MessageType.TRACE and message.trace.type == TraceType.ERROR
    ]
    assert errors_from_concurrent_source
    # The exception payloads may differ between the two frameworks, hence we only assert the count
    assert len(errors_from_concurrent_source) == len(errors_from_abstract_source)


def _init_logger():
    logger = Mock()
    logger.level = logging.INFO
    logger.isEnabledFor.return_value = False
    return logger


def _init_sources(stream_slice_to_partitions, state, logger):
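    """Builds a plain and a concurrent source over the same slice-to-records data."""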
    source = _init_source(stream_slice_to_partitions, state, logger, _MockSource())
    concurrent_source = _init_source(stream_slice_to_partitions, state, logger, _MockConcurrentSource(logger))
    return source, concurrent_source


def _init_source(stream_slice_to_partitions, state, logger, source):
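    """Wraps each _MockStream in a StreamFacade with a FinalStateCursor so the same
    stream definitions can back both the synchronous and the concurrent source."""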
    streams = [
        StreamFacade.create_from_stream(
            _MockStream(stream_slices, f"stream{i}"),
            source,
            logger,
            state,
            FinalStateCursor(stream_name=f"stream{i}", stream_namespace=None, message_repository=InMemoryMessageRepository()),
        )
        for i, stream_slices in enumerate(stream_slice_to_partitions)
    ]
    source.set_streams(streams)
    return source


def _create_configured_catalog(streams):
    return ConfiguredAirbyteCatalog(
        streams=[
            ConfiguredAirbyteStream(
                stream=AirbyteStream(name=s.name, json_schema={}, supported_sync_modes=[SyncMode.full_refresh]),
                sync_mode=SyncMode.full_refresh,
                cursor_field=None,
                destination_sync_mode=DestinationSyncMode.overwrite,
            )
            for s in streams
        ]
    )


def _read_from_source(source, logger, config, catalog, state, expected_exception):
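    """Drains source.read() into a list; if expected_exception is set, the raised
    exception must be an instance of it, otherwise any exception propagates."""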
    messages = []
    try:
        for m in source.read(logger, config, catalog, state):
            messages.append(m)
    except Exception as e:
        if expected_exception:
            assert isinstance(e, expected_exception)
        else:
            # No exception was expected: re-raise instead of silently swallowing it
            raise
    return messages


def _verify_messages(expected_messages, messages_from_concurrent_source):
    assert _compare(expected_messages, messages_from_concurrent_source)


def _compare(s, t):
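    # Illustrative (hypothetical values): _compare([1, 2, 2], [2, 1, 2]) is True,
    # while _compare([1, 2], [1, 2, 2]) is False because an element of t is left over.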
    # Compare as multisets: we can't rely on ordering because of the multithreading,
    # and AirbyteMessage does not implement __hash__, so remove matched elements one by one
    t = list(t)
    try:
        for elem in s:
            t.remove(elem)
    except ValueError:
        print(f"ValueError: {elem}")
        return False
    return not t