#
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
#
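"""Tests that streams adapted into the concurrent framework behave like their synchronous
counterparts, emitting the same records, log messages, and state messages across
full-refresh and incremental reads."""
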
import logging
from copy import deepcopy
from typing import Any, Dict, Iterable, List, Mapping, MutableMapping, Optional, Union
from unittest.mock import Mock

import pytest
from airbyte_cdk.models import (
    AirbyteLogMessage,
    AirbyteMessage,
    AirbyteStateBlob,
    AirbyteStateMessage,
    AirbyteStateType,
    AirbyteStream,
    AirbyteStreamState,
    ConfiguredAirbyteStream,
    DestinationSyncMode,
    Level,
    StreamDescriptor,
    SyncMode,
)
from airbyte_cdk.models import Type as MessageType
from airbyte_cdk.sources.concurrent_source.concurrent_read_processor import ConcurrentReadProcessor
from airbyte_cdk.sources.concurrent_source.thread_pool_manager import ThreadPoolManager
from airbyte_cdk.sources.connector_state_manager import ConnectorStateManager
from airbyte_cdk.sources.message import InMemoryMessageRepository, MessageRepository
from airbyte_cdk.sources.streams import Stream
from airbyte_cdk.sources.streams.concurrent.adapters import StreamFacade
from airbyte_cdk.sources.streams.concurrent.cursor import Cursor, FinalStateCursor
from airbyte_cdk.sources.streams.concurrent.partition_enqueuer import PartitionEnqueuer
from airbyte_cdk.sources.streams.concurrent.partition_reader import PartitionReader
from airbyte_cdk.sources.streams.concurrent.partitions.partition import Partition
from airbyte_cdk.sources.streams.concurrent.partitions.record import Record
from airbyte_cdk.sources.streams.core import CheckpointMixin, StreamData
from airbyte_cdk.sources.utils.schema_helpers import InternalConfig
from airbyte_cdk.sources.utils.slice_logger import DebugSliceLogger

_A_CURSOR_FIELD = ["NESTED", "CURSOR"]
_DEFAULT_INTERNAL_CONFIG = InternalConfig()
_STREAM_NAME = "STREAM"
_NO_STATE = None


class _MockStream(Stream):
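    """Full-refresh Stream stub that returns canned records for each configured slice."""
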
    def __init__(self, slice_to_records: Mapping[str, List[Mapping[str, Any]]], json_schema: Optional[Dict[str, Any]] = None):
        self._slice_to_records = slice_to_records
        self._mocked_json_schema = json_schema or {}

    @property
    def primary_key(self) -> Optional[Union[str, List[str], List[List[str]]]]:
        return None

    def stream_slices(
        self, *, sync_mode: SyncMode, cursor_field: Optional[List[str]] = None, stream_state: Optional[Mapping[str, Any]] = None
    ) -> Iterable[Optional[Mapping[str, Any]]]:
        for partition in self._slice_to_records.keys():
            yield {"partition_key": partition}

    def read_records(
        self,
        sync_mode: SyncMode,
        cursor_field: Optional[List[str]] = None,
        stream_slice: Optional[Mapping[str, Any]] = None,
        stream_state: Optional[Mapping[str, Any]] = None,
    ) -> Iterable[StreamData]:
        yield from self._slice_to_records[stream_slice["partition_key"]]

    def get_json_schema(self) -> Mapping[str, Any]:
        return self._mocked_json_schema


class _MockIncrementalStream(_MockStream, CheckpointMixin):
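    """Incremental variant of ``_MockStream`` that keeps the greatest ``created_at`` value seen in its state."""
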
    # NOTE: class-level attribute; shared across instances unless reassigned through the setter.
    _state = {}

    @property
    def state(self) -> MutableMapping[str, Any]:
        return self._state

    @state.setter
    def state(self, value: MutableMapping[str, Any]) -> None:
        """State setter, accepts state serialized by the state getter."""
        self._state = value

    @property
    def cursor_field(self) -> Union[str, List[str]]:
        return ["created_at"]

    def read_records(
        self,
        sync_mode: SyncMode,
        cursor_field: Optional[List[str]] = None,
        stream_slice: Optional[Mapping[str, Any]] = None,
        stream_state: Optional[Mapping[str, Any]] = None,
    ) -> Iterable[StreamData]:
        cursor = self.cursor_field[0]
        for record in self._slice_to_records[stream_slice["partition_key"]]:
            yield record
            if cursor not in self._state:
                self._state[cursor] = record.get(cursor)
            else:
                self._state[cursor] = max(self._state[cursor], record.get(cursor))


class MockConcurrentCursor(Cursor):
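    """Concurrent ``Cursor`` stub that records per-partition ``created_at`` values and emits one STATE message when a partition closes."""
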
    _state: MutableMapping[str, Any]
    _message_repository: MessageRepository

    def __init__(self, message_repository: MessageRepository):
        self._message_repository = message_repository
        self._state = {}

    @property
    def state(self) -> MutableMapping[str, Any]:
        return self._state

    def observe(self, record: Record) -> None:
        partition = str(record.data.get("partition"))
        timestamp = record.data.get("created_at")
        self._state[partition] = {"created_at": timestamp}

    def close_partition(self, partition: Partition) -> None:
        self._message_repository.emit_message(
            AirbyteMessage(
                type=MessageType.STATE,
                state=AirbyteStateMessage(
                    type=AirbyteStateType.STREAM,
                    stream=AirbyteStreamState(
                        stream_descriptor=StreamDescriptor(name="__mock_stream", namespace=None),
                        stream_state=AirbyteStateBlob(**self._state),
                    ),
                ),
            )
        )

    def ensure_at_least_one_state_emitted(self) -> None:
        pass


def _stream(slice_to_partition_mapping, slice_logger, logger, message_repository, json_schema=None):
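    """Builds the synchronous flavor of the stream under test; the unused arguments keep the constructor signatures interchangeable for parametrization."""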
    return _MockStream(slice_to_partition_mapping, json_schema=json_schema)


def _concurrent_stream(slice_to_partition_mapping, slice_logger, logger, message_repository, cursor: Optional[Cursor] = None):
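    """Wraps the synchronous mock stream in a ``StreamFacade`` so it runs through the concurrent code path."""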
    stream = _stream(slice_to_partition_mapping, slice_logger, logger, message_repository)
    cursor = cursor or FinalStateCursor(stream_name=stream.name, stream_namespace=stream.namespace, message_repository=message_repository)
    source = Mock()
    source._slice_logger = slice_logger
    source.message_repository = message_repository
    stream = StreamFacade.create_from_stream(stream, source, logger, _NO_STATE, cursor)
    stream.logger.setLevel(logger.level)
    return stream


def _incremental_stream(slice_to_partition_mapping, slice_logger, logger, message_repository, timestamp):
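    """Builds the synchronous incremental flavor; ``timestamp`` is unused but keeps the constructor signatures interchangeable."""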
    stream = _MockIncrementalStream(slice_to_partition_mapping)
    return stream


def _incremental_concurrent_stream(slice_to_partition_mapping, slice_logger, logger, message_repository, cursor):
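    """Builds the concurrent incremental flavor around an externally supplied cursor."""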
    stream = _concurrent_stream(slice_to_partition_mapping, slice_logger, logger, message_repository, cursor)
    return stream


def _stream_with_no_cursor_field(slice_to_partition_mapping, slice_logger, logger, message_repository):
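    """Builds a stream whose ``get_updated_state`` must never be invoked; full-refresh reads should not call it."""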
    def get_updated_state(current_stream_state: MutableMapping[str, Any], latest_record: Mapping[str, Any]) -> MutableMapping[str, Any]:
        raise Exception("I shouldn't be invoked by a full_refresh stream")

    mock_stream = _MockStream(slice_to_partition_mapping)
    mock_stream.get_updated_state = get_updated_state
    return mock_stream


@pytest.mark.parametrize(
    "constructor",
    [
        pytest.param(_stream, id="synchronous_reader"),
        pytest.param(_concurrent_stream, id="concurrent_reader"),
    ],
)
def test_full_refresh_read_a_single_slice_with_debug(constructor):
    # This test verifies that a concurrent stream adapted from a Stream behaves the same as the Stream object.
    # This is done by running the same test cases against both streams.
    configured_stream = ConfiguredAirbyteStream(
        stream=AirbyteStream(name="mock_stream", supported_sync_modes=[SyncMode.full_refresh], json_schema={}),
        sync_mode=SyncMode.full_refresh,
        destination_sync_mode=DestinationSyncMode.overwrite,
    )
    internal_config = InternalConfig()
    records = [
        {"id": 1, "partition_key": 1},
        {"id": 2, "partition_key": 1},
    ]
    slice_to_partition = {1: records}
    slice_logger = DebugSliceLogger()
    logger = _mock_logger(True)
    message_repository = InMemoryMessageRepository(Level.DEBUG)
    stream = constructor(slice_to_partition, slice_logger, logger, message_repository)
    state_manager = ConnectorStateManager()
    expected_records = [
        AirbyteMessage(
            type=MessageType.LOG,
            log=AirbyteLogMessage(
                level=Level.INFO,
                message='slice:{"partition_key": 1}',
            ),
        ),
        *records,
    ]
    # Synchronous streams emit a final state message to indicate that the stream has finished reading.
    # Concurrent streams don't emit their own state messages - the concurrent source observes the cursor
    # and emits the state messages. Therefore, we can only check the value of the cursor's state at the end.
    if constructor == _stream:
        expected_records.append(
            AirbyteMessage(
                type=MessageType.STATE,
                state=AirbyteStateMessage(
                    type=AirbyteStateType.STREAM,
                    stream=AirbyteStreamState(
                        stream_descriptor=StreamDescriptor(name="__mock_stream", namespace=None),
                        stream_state=AirbyteStateBlob(__ab_no_cursor_state_message=True),
                    ),
                ),
            ),
        )
    actual_records = _read(stream, configured_stream, logger, slice_logger, message_repository, state_manager, internal_config)
    if constructor == _concurrent_stream:
        assert hasattr(stream._cursor, "state")
        assert str(stream._cursor.state) == "{'__ab_no_cursor_state_message': True}"
    assert actual_records == expected_records


@pytest.mark.parametrize(
    "constructor",
    [
        pytest.param(_stream, id="synchronous_reader"),
        pytest.param(_concurrent_stream, id="concurrent_reader"),
    ],
)
def test_full_refresh_read_a_single_slice(constructor):
    # This test verifies that a concurrent stream adapted from a Stream behaves the same as the Stream object.
    # This is done by running the same test cases against both streams.
    configured_stream = ConfiguredAirbyteStream(
        stream=AirbyteStream(name="mock_stream", supported_sync_modes=[SyncMode.full_refresh], json_schema={}),
        sync_mode=SyncMode.full_refresh,
        destination_sync_mode=DestinationSyncMode.overwrite,
    )
    internal_config = InternalConfig()
    logger = _mock_logger()
    slice_logger = DebugSliceLogger()
    message_repository = InMemoryMessageRepository(Level.INFO)
    state_manager = ConnectorStateManager()
    records = [
        {"id": 1, "partition": 1},
        {"id": 2, "partition": 1},
    ]
    slice_to_partition = {1: records}
    stream = constructor(slice_to_partition, slice_logger, logger, message_repository)
    expected_records = [*records]
    # Synchronous streams emit a final state message to indicate that the stream has finished reading.
    # Concurrent streams don't emit their own state messages - the concurrent source observes the cursor
    # and emits the state messages. Therefore, we can only check the value of the cursor's state at the end.
    if constructor == _stream:
        expected_records.append(
            AirbyteMessage(
                type=MessageType.STATE,
                state=AirbyteStateMessage(
                    type=AirbyteStateType.STREAM,
                    stream=AirbyteStreamState(
                        stream_descriptor=StreamDescriptor(name="__mock_stream", namespace=None),
                        stream_state=AirbyteStateBlob(__ab_no_cursor_state_message=True),
                    ),
                ),
            ),
        )
    actual_records = _read(stream, configured_stream, logger, slice_logger, message_repository, state_manager, internal_config)
    if constructor == _concurrent_stream:
        assert hasattr(stream._cursor, "state")
        assert str(stream._cursor.state) == "{'__ab_no_cursor_state_message': True}"
    assert actual_records == expected_records


@pytest.mark.parametrize(
    "constructor",
    [
        pytest.param(_stream, id="synchronous_reader"),
        pytest.param(_concurrent_stream, id="concurrent_reader"),
        pytest.param(_stream_with_no_cursor_field, id="no_cursor_field"),
    ],
)
def test_full_refresh_read_two_slices(constructor):
    # This test verifies that a concurrent stream adapted from a Stream behaves the same as the Stream object.
    # This is done by running the same test cases against both streams.
    configured_stream = ConfiguredAirbyteStream(
        stream=AirbyteStream(name="mock_stream", supported_sync_modes=[SyncMode.full_refresh], json_schema={}),
        sync_mode=SyncMode.full_refresh,
        destination_sync_mode=DestinationSyncMode.overwrite,
    )
    internal_config = InternalConfig()
    logger = _mock_logger()
    slice_logger = DebugSliceLogger()
    message_repository = InMemoryMessageRepository(Level.INFO)
    state_manager = ConnectorStateManager()
    records_partition_1 = [
        {"id": 1, "partition": 1},
        {"id": 2, "partition": 1},
    ]
    records_partition_2 = [
        {"id": 3, "partition": 2},
        {"id": 4, "partition": 2},
    ]
    slice_to_partition = {1: records_partition_1, 2: records_partition_2}
    stream = constructor(slice_to_partition, slice_logger, logger, message_repository)
    expected_records = [
        *records_partition_1,
        *records_partition_2,
    ]
    # Synchronous streams emit a final state message to indicate that the stream has finished reading.
    # Concurrent streams don't emit their own state messages - the concurrent source observes the cursor
    # and emits the state messages. Therefore, we can only check the value of the cursor's state at the end.
    if constructor == _stream or constructor == _stream_with_no_cursor_field:
        expected_records.append(
            AirbyteMessage(
                type=MessageType.STATE,
                state=AirbyteStateMessage(
                    type=AirbyteStateType.STREAM,
                    stream=AirbyteStreamState(
                        stream_descriptor=StreamDescriptor(name="__mock_stream", namespace=None),
                        stream_state=AirbyteStateBlob(__ab_no_cursor_state_message=True),
                    ),
                ),
            ),
        )
    actual_records = _read(stream, configured_stream, logger, slice_logger, message_repository, state_manager, internal_config)
    if constructor == _concurrent_stream:
        assert hasattr(stream._cursor, "state")
        assert str(stream._cursor.state) == "{'__ab_no_cursor_state_message': True}"
    # Concurrent reads don't guarantee ordering across partitions, so compare record sets rather than sequences.
    for record in expected_records:
        assert record in actual_records
    assert len(actual_records) == len(expected_records)


def test_incremental_read_two_slices():
    # This test verifies that a stream running in incremental mode emits state messages correctly.
    configured_stream = ConfiguredAirbyteStream(
        stream=AirbyteStream(name="mock_stream", supported_sync_modes=[SyncMode.full_refresh, SyncMode.incremental], json_schema={}),
        sync_mode=SyncMode.incremental,
        cursor_field=["created_at"],
        destination_sync_mode=DestinationSyncMode.overwrite,
    )
    internal_config = InternalConfig()
    logger = _mock_logger()
    slice_logger = DebugSliceLogger()
    message_repository = InMemoryMessageRepository(Level.INFO)
    state_manager = ConnectorStateManager()
    timestamp = "1708899427"
    records_partition_1 = [
        {"id": 1, "partition": 1, "created_at": "1708899000"},
        {"id": 2, "partition": 1, "created_at": "1708899000"},
    ]
    records_partition_2 = [
        {"id": 3, "partition": 2, "created_at": "1708899400"},
        {"id": 4, "partition": 2, "created_at": "1708899427"},
    ]
    slice_to_partition = {1: records_partition_1, 2: records_partition_2}
    stream = _incremental_stream(slice_to_partition, slice_logger, logger, message_repository, timestamp)
    expected_records = [
        *records_partition_1,
        _create_state_message("__mock_incremental_stream", {"created_at": timestamp}),
        *records_partition_2,
        _create_state_message("__mock_incremental_stream", {"created_at": timestamp}),
    ]
    actual_records = _read(stream, configured_stream, logger, slice_logger, message_repository, state_manager, internal_config)
    for record in expected_records:
        assert record in actual_records
    assert len(actual_records) == len(expected_records)


def test_concurrent_incremental_read_two_slices():
    # This test verifies that an incremental concurrent stream manages state correctly when multiple slices sync concurrently.
    configured_stream = ConfiguredAirbyteStream(
        stream=AirbyteStream(name="mock_stream", supported_sync_modes=[SyncMode.full_refresh, SyncMode.incremental], json_schema={}),
        sync_mode=SyncMode.incremental,
        destination_sync_mode=DestinationSyncMode.overwrite,
    )
    internal_config = InternalConfig()
    logger = _mock_logger()
    slice_logger = DebugSliceLogger()
    message_repository = InMemoryMessageRepository(Level.INFO)
    state_manager = ConnectorStateManager()
    slice_timestamp_1 = "1708850000"
    slice_timestamp_2 = "1708950000"
    cursor = MockConcurrentCursor(message_repository)
    records_partition_1 = [
        {"id": 1, "partition": 1, "created_at": "1708800000"},
        {"id": 2, "partition": 1, "created_at": slice_timestamp_1},
    ]
    records_partition_2 = [
        {"id": 3, "partition": 2, "created_at": "1708900000"},
        {"id": 4, "partition": 2, "created_at": slice_timestamp_2},
    ]
    slice_to_partition = {1: records_partition_1, 2: records_partition_2}
    stream = _incremental_concurrent_stream(slice_to_partition, slice_logger, logger, message_repository, cursor)
    expected_records = [
        *records_partition_1,
        *records_partition_2,
    ]
    expected_state = _create_state_message(
        "__mock_stream", {"1": {"created_at": slice_timestamp_1}, "2": {"created_at": slice_timestamp_2}}
    )
    actual_records = _read(stream, configured_stream, logger, slice_logger, message_repository, state_manager, internal_config)
    handler = ConcurrentReadProcessor(
        [stream],
        Mock(spec=PartitionEnqueuer),
        Mock(spec=ThreadPoolManager),
        logger,
        slice_logger,
        message_repository,
        Mock(spec=PartitionReader),
    )
    for record in expected_records:
        assert record in actual_records
    # We need to run on_record so the cursor observes each record's cursor value.
    for record in actual_records:
        list(handler.on_record(Record(record, Mock(spec=Partition, **{"stream_name.return_value": "__mock_stream"}))))
    assert len(actual_records) == len(expected_records)
    # We don't have a real source that reads state from the message_repository, so we read from the queue directly
    # to verify that the cursor observed the records correctly and updated the partition states.
    mock_partition = Mock()
    cursor.close_partition(mock_partition)
    actual_state = list(message_repository.consume_queue())
    assert len(actual_state) == 1
    assert actual_state[0] == expected_state


def setup_stream_dependencies(configured_json_schema):
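    """Builds the shared fixtures for the configured-json-schema tests below."""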
    configured_stream = ConfiguredAirbyteStream(
        stream=AirbyteStream(name="mock_stream", supported_sync_modes=[SyncMode.full_refresh], json_schema=configured_json_schema),
        sync_mode=SyncMode.full_refresh,
        destination_sync_mode=DestinationSyncMode.overwrite,
    )
    internal_config = InternalConfig()
    logger = _mock_logger()
    slice_logger = DebugSliceLogger()
    message_repository = InMemoryMessageRepository(Level.INFO)
    state_manager = ConnectorStateManager()
    return configured_stream, internal_config, logger, slice_logger, message_repository, state_manager


def test_configured_json_schema():
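    """After a read, the configured catalog schema should be attached to the stream as ``configured_json_schema``."""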
    current_json_schema = {
        "$schema": "https://json-schema.org/draft-07/schema#",
        "type": "object",
        "properties": {
            "id": {"type": ["null", "number"]},
            "name": {"type": ["null", "string"]},
        },
    }
    configured_stream, internal_config, logger, slice_logger, message_repository, state_manager = setup_stream_dependencies(
        current_json_schema
    )
    records = [
        {"id": 1, "partition": 1},
        {"id": 2, "partition": 1},
    ]
    slice_to_partition = {1: records}
    stream = _stream(slice_to_partition, slice_logger, logger, message_repository, json_schema=current_json_schema)
    assert not stream.configured_json_schema
    _read(stream, configured_stream, logger, slice_logger, message_repository, state_manager, internal_config)
    assert stream.configured_json_schema == current_json_schema


def test_configured_json_schema_with_invalid_properties():
    """
    Configured schemas can contain very old fields, so we need to do some housekeeping ourselves.
    The purpose of this test is to ensure that the correct cleanup occurs when the configured catalog schema
    is compared with the current stream schema.
    """
    old_user_insights = "old_user_insights"
    old_feature_info = "old_feature_info"
    configured_json_schema = {
        "$schema": "https://json-schema.org/draft-07/schema#",
        "type": "object",
        "properties": {
            "id": {"type": ["null", "number"]},
            "name": {"type": ["null", "string"]},
            "cost_per_conversation": {"type": ["null", "string"]},
            old_user_insights: {"type": ["null", "string"]},
            old_feature_info: {"type": ["null", "string"]},
        },
    }
    # The stream schema has been updated, e.g. some fields from the old API version are deprecated.
    stream_schema = deepcopy(configured_json_schema)
    del stream_schema["properties"][old_user_insights]
    del stream_schema["properties"][old_feature_info]
    configured_stream, internal_config, logger, slice_logger, message_repository, state_manager = setup_stream_dependencies(
        configured_json_schema
    )
    records = [
        {"id": 1, "partition": 1},
        {"id": 2, "partition": 1},
    ]
    slice_to_partition = {1: records}
    stream = _stream(slice_to_partition, slice_logger, logger, message_repository, json_schema=stream_schema)
    assert not stream.configured_json_schema
    _read(stream, configured_stream, logger, slice_logger, message_repository, state_manager, internal_config)
    assert stream.configured_json_schema != configured_json_schema
    configured_json_schema_properties = stream.configured_json_schema["properties"]
    assert old_user_insights not in configured_json_schema_properties
    assert old_feature_info not in configured_json_schema_properties
    for stream_schema_property in stream_schema["properties"]:
        assert (
            stream_schema_property in configured_json_schema_properties
        ), f"Stream schema property: {stream_schema_property} missing in configured schema"
        assert stream_schema["properties"][stream_schema_property] == configured_json_schema_properties[stream_schema_property]


def _read(stream, configured_stream, logger, slice_logger, message_repository, state_manager, internal_config):
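    """Drains the stream, interleaving messages consumed from the repository with the records in emission order."""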
    records = []
    for record in stream.read(configured_stream, logger, slice_logger, {}, state_manager, internal_config):
        for message in message_repository.consume_queue():
            records.append(message)
        records.append(record)
    return records


def _mock_partition_generator(name: str, slices, records_per_partition, *, available=True, debug_log=False):
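    """Builds a fully mocked stream exposing canned partitions, records, and availability."""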
    stream = Mock()
    stream.name = name
    stream.get_json_schema.return_value = {}
    stream.generate_partitions.return_value = iter(slices)
    stream.read_records.side_effect = [iter(records) for records in records_per_partition]
    stream.logger.isEnabledFor.return_value = debug_log
    if available:
        stream.check_availability.return_value = True, None
    else:
        stream.check_availability.return_value = False, "A reason why the stream is unavailable"
    return stream


def _mock_logger(enabled_for_debug=False):
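    """Builds a mock logger; ``enabled_for_debug`` controls both the reported level and ``isEnabledFor``."""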
    logger = Mock()
    logger.isEnabledFor.return_value = enabled_for_debug
    logger.level = logging.DEBUG if enabled_for_debug else logging.INFO
    return logger


def _create_state_message(stream: str, state: Mapping[str, Any]) -> AirbyteMessage:
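    """Builds the per-stream STATE message the tests expect for the given stream name and state blob."""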
    return AirbyteMessage(
        type=MessageType.STATE,
        state=AirbyteStateMessage(
            type=AirbyteStateType.STREAM,
            stream=AirbyteStreamState(
                stream_descriptor=StreamDescriptor(name=stream, namespace=None),
                stream_state=AirbyteStateBlob(**state),
            ),
        ),
    )