423 lines
16 KiB
Python
423 lines
16 KiB
Python
#
|
|
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
|
|
#
|
|
import logging
|
|
from typing import Any, Iterable, List, Mapping, Optional, Tuple, Union
|
|
from unittest.mock import Mock
|
|
|
|
import freezegun
|
|
from airbyte_cdk.models import (
|
|
AirbyteMessage,
|
|
AirbyteRecordMessage,
|
|
AirbyteStream,
|
|
AirbyteStreamStatus,
|
|
AirbyteStreamStatusTraceMessage,
|
|
AirbyteTraceMessage,
|
|
ConfiguredAirbyteCatalog,
|
|
ConfiguredAirbyteStream,
|
|
DestinationSyncMode,
|
|
StreamDescriptor,
|
|
SyncMode,
|
|
TraceType,
|
|
)
|
|
from airbyte_cdk.models import Type as MessageType
|
|
from airbyte_cdk.sources import AbstractSource
|
|
from airbyte_cdk.sources.concurrent_source.concurrent_source import ConcurrentSource
|
|
from airbyte_cdk.sources.concurrent_source.concurrent_source_adapter import ConcurrentSourceAdapter
|
|
from airbyte_cdk.sources.message import InMemoryMessageRepository
|
|
from airbyte_cdk.sources.streams import Stream
|
|
from airbyte_cdk.sources.streams.concurrent.adapters import StreamFacade
|
|
from airbyte_cdk.sources.streams.concurrent.cursor import FinalStateCursor
|
|
from airbyte_cdk.sources.streams.core import StreamData
|
|
from airbyte_cdk.utils import AirbyteTracedException
|
|
from unit_tests.sources.streams.concurrent.scenarios.thread_based_concurrent_stream_source_builder import NeverLogSliceLogger
|
|
|
|
|
|
class _MockStream(Stream):
    """Test double for ``Stream`` that serves canned records keyed by partition.

    Each value in the partition mapping may contain plain record dicts or
    exception instances; exceptions are raised mid-read to simulate failures.
    """

    def __init__(self, slice_to_records: Mapping[str, List[Mapping[str, Any]]], name: str):
        self._records_by_partition = slice_to_records
        self._stream_name = name

    @property
    def name(self) -> str:
        return self._stream_name

    @property
    def primary_key(self) -> Optional[Union[str, List[str], List[List[str]]]]:
        # Mock streams carry no primary key.
        return None

    def stream_slices(
        self, *, sync_mode: SyncMode, cursor_field: Optional[List[str]] = None, stream_state: Optional[Mapping[str, Any]] = None
    ) -> Iterable[Optional[Mapping[str, Any]]]:
        # Emit exactly one slice per configured partition key.
        return ({"partition": partition_key} for partition_key in self._records_by_partition)

    def read_records(
        self,
        sync_mode: SyncMode,
        cursor_field: Optional[List[str]] = None,
        stream_slice: Optional[Mapping[str, Any]] = None,
        stream_state: Optional[Mapping[str, Any]] = None,
    ) -> Iterable[StreamData]:
        # Entries are either records (yielded) or exceptions (raised) so tests
        # can exercise both success and failure paths.
        for entry in self._records_by_partition[stream_slice["partition"]]:
            if isinstance(entry, Exception):
                raise entry
            yield entry

    def get_json_schema(self) -> Mapping[str, Any]:
        return {}
|
|
|
|
|
|
class _MockSource(AbstractSource):
    """Synchronous test source whose streams are injected via ``set_streams``."""

    message_repository = InMemoryMessageRepository()

    def check_connection(self, logger: logging.Logger, config: Mapping[str, Any]) -> Tuple[bool, Optional[Any]]:
        # Connection checking is irrelevant for these tests.
        pass

    def set_streams(self, streams) -> None:
        # Stash the pre-built streams; ``streams()`` simply returns them.
        self._streams = streams

    def streams(self, config: Mapping[str, Any]) -> List[Stream]:
        return self._streams
|
|
|
|
|
|
class _MockConcurrentSource(ConcurrentSourceAdapter):
    """Concurrent test source whose streams are injected via ``set_streams``.

    Configured with a single worker so message ordering stays deterministic
    enough for the multiset comparisons in these tests.
    """

    message_repository = InMemoryMessageRepository()

    def __init__(self, logger):
        backing_source = ConcurrentSource.create(1, 1, logger, NeverLogSliceLogger(), self.message_repository)
        super().__init__(backing_source)

    def check_connection(self, logger: logging.Logger, config: Mapping[str, Any]) -> Tuple[bool, Optional[Any]]:
        # Connection checking is irrelevant for these tests.
        pass

    def set_streams(self, streams) -> None:
        self._streams = streams

    def streams(self, config: Mapping[str, Any]) -> List[Stream]:
        return self._streams
|
|
|
|
|
|
@freezegun.freeze_time("2020-01-01T00:00:00")
def test_concurrent_source_yields_the_same_messages_as_abstract_source_when_no_exceptions_are_raised():
    """Happy path: the concurrent source emits the expected status traces and records.

    Time is frozen so every message carries the deterministic epoch-millis
    timestamp 1577836800000 (2020-01-01T00:00:00). Two streams are configured,
    each with two partitions of two records; per stream we expect
    STARTED/RUNNING traces, the four records, then a COMPLETE trace.
    """
    records_stream_1_partition_1 = [
        {"id": 1, "partition": "1"},
        {"id": 2, "partition": "1"},
    ]
    records_stream_1_partition_2 = [
        {"id": 3, "partition": "2"},
        {"id": 4, "partition": "2"},
    ]
    records_stream_2_partition_1 = [
        {"id": 100, "partition": "A"},
        {"id": 200, "partition": "A"},
    ]
    records_stream_2_partition_2 = [
        {"id": 300, "partition": "B"},
        {"id": 400, "partition": "B"},
    ]
    stream_1_slice_to_partition = {"1": records_stream_1_partition_1, "2": records_stream_1_partition_2}
    stream_2_slice_to_partition = {"A": records_stream_2_partition_1, "B": records_stream_2_partition_2}
    state = None
    logger = _init_logger()

    source, concurrent_source = _init_sources([stream_1_slice_to_partition, stream_2_slice_to_partition], state, logger)

    config = {}
    catalog = _create_configured_catalog(source._streams)
    # FIXME this is currently unused in this test
    # messages_from_abstract_source = _read_from_source(source, logger, config, catalog, state, None)
    messages_from_concurrent_source = _read_from_source(concurrent_source, logger, config, catalog, state, None)

    expected_messages = [
        # stream0: STARTED status trace
        AirbyteMessage(
            type=MessageType.TRACE,
            trace=AirbyteTraceMessage(
                type=TraceType.STREAM_STATUS,
                emitted_at=1577836800000.0,
                error=None,
                estimate=None,
                stream_status=AirbyteStreamStatusTraceMessage(
                    stream_descriptor=StreamDescriptor(name="stream0"), status=AirbyteStreamStatus(AirbyteStreamStatus.STARTED)
                ),
            ),
        ),
        # stream0: RUNNING status trace (emitted once the first record is produced)
        AirbyteMessage(
            type=MessageType.TRACE,
            trace=AirbyteTraceMessage(
                type=TraceType.STREAM_STATUS,
                emitted_at=1577836800000.0,
                error=None,
                estimate=None,
                stream_status=AirbyteStreamStatusTraceMessage(
                    stream_descriptor=StreamDescriptor(name="stream0"), status=AirbyteStreamStatus(AirbyteStreamStatus.RUNNING)
                ),
            ),
        ),
        # stream0: records from both partitions
        AirbyteMessage(
            type=MessageType.RECORD,
            record=AirbyteRecordMessage(
                stream="stream0",
                data=records_stream_1_partition_1[0],
                emitted_at=1577836800000,
            ),
        ),
        AirbyteMessage(
            type=MessageType.RECORD,
            record=AirbyteRecordMessage(
                stream="stream0",
                data=records_stream_1_partition_1[1],
                emitted_at=1577836800000,
            ),
        ),
        AirbyteMessage(
            type=MessageType.RECORD,
            record=AirbyteRecordMessage(
                stream="stream0",
                data=records_stream_1_partition_2[0],
                emitted_at=1577836800000,
            ),
        ),
        AirbyteMessage(
            type=MessageType.RECORD,
            record=AirbyteRecordMessage(
                stream="stream0",
                data=records_stream_1_partition_2[1],
                emitted_at=1577836800000,
            ),
        ),
        # stream0: COMPLETE status trace
        AirbyteMessage(
            type=MessageType.TRACE,
            trace=AirbyteTraceMessage(
                type=TraceType.STREAM_STATUS,
                emitted_at=1577836800000.0,
                error=None,
                estimate=None,
                stream_status=AirbyteStreamStatusTraceMessage(
                    stream_descriptor=StreamDescriptor(name="stream0"), status=AirbyteStreamStatus(AirbyteStreamStatus.COMPLETE)
                ),
            ),
        ),
        # stream1: STARTED status trace
        AirbyteMessage(
            type=MessageType.TRACE,
            trace=AirbyteTraceMessage(
                type=TraceType.STREAM_STATUS,
                emitted_at=1577836800000.0,
                error=None,
                estimate=None,
                stream_status=AirbyteStreamStatusTraceMessage(
                    stream_descriptor=StreamDescriptor(name="stream1"), status=AirbyteStreamStatus(AirbyteStreamStatus.STARTED)
                ),
            ),
        ),
        # stream1: RUNNING status trace
        AirbyteMessage(
            type=MessageType.TRACE,
            trace=AirbyteTraceMessage(
                type=TraceType.STREAM_STATUS,
                emitted_at=1577836800000.0,
                error=None,
                estimate=None,
                stream_status=AirbyteStreamStatusTraceMessage(
                    stream_descriptor=StreamDescriptor(name="stream1"), status=AirbyteStreamStatus(AirbyteStreamStatus.RUNNING)
                ),
            ),
        ),
        # stream1: records from both partitions
        AirbyteMessage(
            type=MessageType.RECORD,
            record=AirbyteRecordMessage(
                stream="stream1",
                data=records_stream_2_partition_1[0],
                emitted_at=1577836800000,
            ),
        ),
        AirbyteMessage(
            type=MessageType.RECORD,
            record=AirbyteRecordMessage(
                stream="stream1",
                data=records_stream_2_partition_1[1],
                emitted_at=1577836800000,
            ),
        ),
        AirbyteMessage(
            type=MessageType.RECORD,
            record=AirbyteRecordMessage(
                stream="stream1",
                data=records_stream_2_partition_2[0],
                emitted_at=1577836800000,
            ),
        ),
        AirbyteMessage(
            type=MessageType.RECORD,
            record=AirbyteRecordMessage(
                stream="stream1",
                data=records_stream_2_partition_2[1],
                emitted_at=1577836800000,
            ),
        ),
        # stream1: COMPLETE status trace
        AirbyteMessage(
            type=MessageType.TRACE,
            trace=AirbyteTraceMessage(
                type=TraceType.STREAM_STATUS,
                emitted_at=1577836800000.0,
                error=None,
                estimate=None,
                stream_status=AirbyteStreamStatusTraceMessage(
                    stream_descriptor=StreamDescriptor(name="stream1"), status=AirbyteStreamStatus(AirbyteStreamStatus.COMPLETE)
                ),
            ),
        ),
    ]
    # Order-insensitive comparison: the concurrent framework interleaves messages.
    _verify_messages(expected_messages, messages_from_concurrent_source)
|
|
|
|
|
|
@freezegun.freeze_time("2020-01-01T00:00:00")
def test_concurrent_source_yields_the_same_messages_as_abstract_source_when_a_traced_exception_is_raised():
    """Both frameworks must surface the same statuses, records, and error count
    when a stream raises an AirbyteTracedException mid-read."""
    partition_entries = [{"id": 1, "partition": "1"}, AirbyteTracedException()]
    slice_to_partition = {"1": partition_entries}

    state = None
    logger = _init_logger()
    source, concurrent_source = _init_sources([slice_to_partition], state, logger)

    config = {}
    catalog = _create_configured_catalog(source._streams)
    abstract_messages = _read_from_source(source, logger, config, catalog, state, AirbyteTracedException)
    concurrent_messages = _read_from_source(concurrent_source, logger, config, catalog, state, AirbyteTracedException)

    _assert_status_messages(abstract_messages, concurrent_messages)
    _assert_record_messages(abstract_messages, concurrent_messages)
    _assert_errors(abstract_messages, concurrent_messages)
|
|
|
|
|
|
@freezegun.freeze_time("2020-01-01T00:00:00")
def test_concurrent_source_yields_the_same_messages_as_abstract_source_when_an_exception_is_raised():
    """Both frameworks must surface the same statuses, records, and error count
    when a stream raises a plain RuntimeError mid-read (both frameworks wrap it
    into an AirbyteTracedException)."""
    partition_entries = [{"id": 1, "partition": "1"}, RuntimeError()]
    slice_to_partition = {"1": partition_entries}

    state = None
    logger = _init_logger()
    source, concurrent_source = _init_sources([slice_to_partition], state, logger)

    config = {}
    catalog = _create_configured_catalog(source._streams)
    abstract_messages = _read_from_source(source, logger, config, catalog, state, AirbyteTracedException)
    concurrent_messages = _read_from_source(concurrent_source, logger, config, catalog, state, AirbyteTracedException)

    _assert_status_messages(abstract_messages, concurrent_messages)
    _assert_record_messages(abstract_messages, concurrent_messages)
    _assert_errors(abstract_messages, concurrent_messages)
|
|
|
|
|
|
def _assert_status_messages(messages_from_abstract_source, messages_from_concurrent_source):
    """Assert both sources emitted the same multiset of STREAM_STATUS traces."""

    def _status_traces(messages):
        # Keep only TRACE messages carrying a stream-status payload.
        return [m for m in messages if m.type == MessageType.TRACE and m.trace.type == TraceType.STREAM_STATUS]

    concurrent_statuses = _status_traces(messages_from_concurrent_source)

    assert concurrent_statuses
    _verify_messages(_status_traces(messages_from_abstract_source), concurrent_statuses)
|
|
|
|
|
|
def _assert_record_messages(messages_from_abstract_source, messages_from_concurrent_source):
    """Assert both sources emitted the same multiset of RECORD messages."""

    def _records_only(messages):
        return [m for m in messages if m.type == MessageType.RECORD]

    concurrent_records = _records_only(messages_from_concurrent_source)

    assert concurrent_records
    _verify_messages(_records_only(messages_from_abstract_source), concurrent_records)
|
|
|
|
|
|
def _assert_errors(messages_from_abstract_source, messages_from_concurrent_source):
    """Assert both sources emitted the same number of ERROR trace messages."""

    def _error_traces(messages):
        return [m for m in messages if m.type == MessageType.TRACE and m.trace.type == TraceType.ERROR]

    concurrent_errors = _error_traces(messages_from_concurrent_source)
    abstract_errors = _error_traces(messages_from_abstract_source)

    assert concurrent_errors
    # exceptions might differ from both framework hence we only assert the count
    assert len(concurrent_errors) == len(abstract_errors)
|
|
|
|
|
|
def _init_logger():
|
|
logger = Mock()
|
|
logger.level = logging.INFO
|
|
logger.isEnabledFor.return_value = False
|
|
return logger
|
|
|
|
|
|
def _init_sources(stream_slice_to_partitions, state, logger):
    """Create a synchronous and a concurrent source over the same partition data."""
    return (
        _init_source(stream_slice_to_partitions, state, logger, _MockSource()),
        _init_source(stream_slice_to_partitions, state, logger, _MockConcurrentSource(logger)),
    )
|
|
|
|
|
|
def _init_source(stream_slice_to_partitions, state, logger, source):
    """Wrap each partition mapping in a StreamFacade and attach them to *source*.

    Streams are named "stream0", "stream1", ... in input order.
    """
    streams = []
    for index, slices in enumerate(stream_slice_to_partitions):
        stream_name = f"stream{index}"
        facade = StreamFacade.create_from_stream(
            _MockStream(slices, stream_name),
            source,
            logger,
            state,
            # Full-refresh mocks: a FinalStateCursor is all that's needed.
            FinalStateCursor(stream_name=stream_name, stream_namespace=None, message_repository=InMemoryMessageRepository()),
        )
        streams.append(facade)
    source.set_streams(streams)
    return source
|
|
|
|
|
|
def _create_configured_catalog(streams):
    """Build a full-refresh/overwrite configured catalog covering *streams*."""
    configured = []
    for stream in streams:
        configured.append(
            ConfiguredAirbyteStream(
                stream=AirbyteStream(name=stream.name, json_schema={}, supported_sync_modes=[SyncMode.full_refresh]),
                sync_mode=SyncMode.full_refresh,
                cursor_field=None,
                destination_sync_mode=DestinationSyncMode.overwrite,
            )
        )
    return ConfiguredAirbyteCatalog(streams=configured)
|
|
|
|
|
|
def _read_from_source(source, logger, config, catalog, state, expected_exception):
|
|
messages = []
|
|
try:
|
|
for m in source.read(logger, config, catalog, state):
|
|
messages.append(m)
|
|
except Exception as e:
|
|
if expected_exception:
|
|
assert isinstance(e, expected_exception)
|
|
return messages
|
|
|
|
|
|
def _verify_messages(expected_messages, messages_from_concurrent_source):
    """Assert the two message collections are equal as unordered multisets."""
    assert _compare(expected_messages, messages_from_concurrent_source)
|
|
|
|
|
|
def _compare(s, t):
|
|
# Use a compare method that does not require ordering or hashing the elements
|
|
# We can't rely on the ordering because of the multithreading
|
|
# AirbyteMessage does not implement __eq__ and __hash__
|
|
t = list(t)
|
|
try:
|
|
for elem in s:
|
|
t.remove(elem)
|
|
except ValueError:
|
|
print(f"ValueError: {elem}")
|
|
return False
|
|
return not t
|