[RFR for API Sources] Add SubstreamResumableFullRefreshCursor to the Python CDK (#42429)
@@ -2,34 +2,16 @@
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
#

import json
from typing import Any, Callable, Iterable, Mapping, MutableMapping, Optional, Union

from airbyte_cdk.sources.declarative.incremental.declarative_cursor import DeclarativeCursor
from airbyte_cdk.sources.declarative.partition_routers.partition_router import PartitionRouter
from airbyte_cdk.sources.streams.checkpoint.per_partition_key_serializer import PerPartitionKeySerializer
from airbyte_cdk.sources.types import Record, StreamSlice, StreamState
from airbyte_cdk.utils import AirbyteTracedException
from airbyte_protocol.models import FailureType


class PerPartitionKeySerializer:
    """
    We are concerned about the performance of looping through the `states` list and evaluating equality on the partition. To reduce
    this concern, we wanted to use dictionaries to map `partition -> cursor`. However, partitions are dicts, and dicts can't be used
    as dict keys since they are not hashable. By serializing the partition dict to a JSON string, we can use it as a dict key since
    strings are hashable.
    """

    @staticmethod
    def to_partition_key(to_serialize: Any) -> str:
        # The default separators changed in Python 3.4. To avoid being impacted by future changes, we explicitly specify our own values
        return json.dumps(to_serialize, indent=None, separators=(",", ":"), sort_keys=True)

    @staticmethod
    def to_partition(to_deserialize: Any) -> Mapping[str, Any]:
        return json.loads(to_deserialize)  # type: ignore  # The partition is known to be a dict, but the type hint is Any


class CursorFactory:
    def __init__(self, create_function: Callable[[], DeclarativeCursor]):
        self._create_function = create_function

@@ -143,7 +143,7 @@ class CursorBasedCheckpointReader(CheckpointReader):
            # current_slice is None represents the first time we are iterating over a stream's slices. The first slice to
            # sync has not been assigned yet and must first be read from the iterator
            next_slice = self.read_and_convert_slice()
            state_for_slice = self._cursor.select_state(self.current_slice)
            state_for_slice = self._cursor.select_state(next_slice)
            if state_for_slice == FULL_REFRESH_COMPLETE_STATE:
                # Skip every slice that already has the terminal complete value indicating that a previous attempt
                # successfully synced the slice

@@ -0,0 +1,22 @@
# Copyright (c) 2024 Airbyte, Inc., all rights reserved.

import json
from typing import Any, Mapping


class PerPartitionKeySerializer:
    """
    We are concerned about the performance of looping through the `states` list and evaluating equality on the partition. To reduce
    this concern, we wanted to use dictionaries to map `partition -> cursor`. However, partitions are dicts, and dicts can't be used
    as dict keys since they are not hashable. By serializing the partition dict to a JSON string, we can use it as a dict key since
    strings are hashable.
    """

    @staticmethod
    def to_partition_key(to_serialize: Any) -> str:
        # The default separators changed in Python 3.4. To avoid being impacted by future changes, we explicitly specify our own values
        return json.dumps(to_serialize, indent=None, separators=(",", ":"), sort_keys=True)

    @staticmethod
    def to_partition(to_deserialize: Any) -> Mapping[str, Any]:
        return json.loads(to_deserialize)  # type: ignore  # The partition is known to be a dict, but the type hint is Any
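
As a quick illustration of why the serializer above works (a minimal sketch, not part of the commit; the partition keys are made up), `sort_keys=True` makes the key independent of insertion order, so equal partitions always map to the same dictionary slot:

# Illustrative round-trip for PerPartitionKeySerializer.
serializer = PerPartitionKeySerializer()
key_a = serializer.to_partition_key({"account_id": 7, "region": "us"})
key_b = serializer.to_partition_key({"region": "us", "account_id": 7})
assert key_a == key_b == '{"account_id":7,"region":"us"}'
assert serializer.to_partition(key_a) == {"account_id": 7, "region": "us"}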
@@ -0,0 +1,113 @@
# Copyright (c) 2024 Airbyte, Inc., all rights reserved.

from dataclasses import dataclass
from typing import Any, Mapping, MutableMapping, Optional

from airbyte_cdk.sources.streams.checkpoint import Cursor
from airbyte_cdk.sources.streams.checkpoint.per_partition_key_serializer import PerPartitionKeySerializer
from airbyte_cdk.sources.types import Record, StreamSlice, StreamState
from airbyte_cdk.utils import AirbyteTracedException
from airbyte_protocol.models import FailureType

FULL_REFRESH_COMPLETE_STATE: Mapping[str, Any] = {"__ab_full_refresh_sync_complete": True}


@dataclass
class SubstreamResumableFullRefreshCursor(Cursor):
    def __init__(self) -> None:
        self._per_partition_state: MutableMapping[str, StreamState] = {}
        self._partition_serializer = PerPartitionKeySerializer()

    def get_stream_state(self) -> StreamState:
        states = []
        for partition_tuple, partition_state in self._per_partition_state.items():
            states.append(
                {
                    "partition": self._to_dict(partition_tuple),
                    "cursor": partition_state,
                }
            )
        state: dict[str, Any] = {"states": states}

        return state

    def set_initial_state(self, stream_state: StreamState) -> None:
        """
        Set the initial state for the cursors.

        This method initializes the state for each partition cursor using the provided stream state.
        If a partition state is provided in the stream state, it will update the corresponding partition cursor with this state.

        To simplify processing and state management, we do not maintain the checkpointed state of the parent partitions.
        Instead, we track whether a parent has already successfully synced on a prior attempt and skip over it, allowing
        the sync to continue making progress. This works for RFR because the platform will dispose of this state on the
        next sync job.

        Args:
            stream_state (StreamState): The state of the streams to be set. The format of the stream state should be:
                {
                    "states": [
                        {
                            "partition": {
                                "partition_key": "value_0"
                            },
                            "cursor": {
                                "__ab_full_refresh_sync_complete": True
                            }
                        },
                        {
                            "partition": {
                                "partition_key": "value_1"
                            },
                            "cursor": {},
                        },
                    ]
                }
        """
        if not stream_state:
            return

        if "states" not in stream_state:
            raise AirbyteTracedException(
                internal_message=f"Could not parse the following state: {stream_state}",
                message="The state format is invalid. Validate that the migration steps included a reset and that it was performed "
                "properly. Otherwise, please contact Airbyte support.",
                failure_type=FailureType.config_error,
            )

        for state in stream_state["states"]:
            self._per_partition_state[self._to_partition_key(state["partition"])] = state["cursor"]

    def observe(self, stream_slice: StreamSlice, record: Record) -> None:
        """
        Substream resumable full refresh manages state by closing the slice after syncing a parent, so observe is not used.
        """
        pass

    def close_slice(self, stream_slice: StreamSlice, *args: Any) -> None:
        self._per_partition_state[self._to_partition_key(stream_slice.partition)] = FULL_REFRESH_COMPLETE_STATE

    def should_be_synced(self, record: Record) -> bool:
        """
        Unlike date-based cursors which filter out records outside slice boundaries, resumable full refresh records exist within
        pages that don't have filterable bounds. We should always return them.
        """
        return True

    def is_greater_than_or_equal(self, first: Record, second: Record) -> bool:
        """
        RFR records don't have an ordering that can be compared between one another.
        """
        return False

    def select_state(self, stream_slice: Optional[StreamSlice] = None) -> Optional[StreamState]:
        if not stream_slice:
            raise ValueError("A partition needs to be provided in order to extract a state")

        return self._per_partition_state.get(self._to_partition_key(stream_slice.partition))

    def _to_partition_key(self, partition: Mapping[str, Any]) -> str:
        return self._partition_serializer.to_partition_key(partition)

    def _to_dict(self, partition_key: str) -> Mapping[str, Any]:
        return self._partition_serializer.to_partition(partition_key)
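
A minimal sketch of the cursor's lifecycle, using only the methods defined above (the partition value is made up): a parent partition that has never been synced has no state, and closing its slice records the terminal complete value.

# Illustrative lifecycle; not part of the commit.
cursor = SubstreamResumableFullRefreshCursor()
first = StreamSlice(partition={"parent_id": "100"}, cursor_slice={})

assert cursor.select_state(first) is None  # never synced before

cursor.close_slice(first)  # marks the parent partition as fully synced
assert cursor.get_stream_state() == {
    "states": [
        {"partition": {"parent_id": "100"}, "cursor": {"__ab_full_refresh_sync_complete": True}},
    ]
}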
@@ -15,6 +15,7 @@ from airbyte_cdk.sources.message.repository import InMemoryMessageRepository
from airbyte_cdk.sources.streams.call_rate import APIBudget
from airbyte_cdk.sources.streams.checkpoint.cursor import Cursor
from airbyte_cdk.sources.streams.checkpoint.resumable_full_refresh_cursor import ResumableFullRefreshCursor
from airbyte_cdk.sources.streams.checkpoint.substream_resumable_full_refresh_cursor import SubstreamResumableFullRefreshCursor
from airbyte_cdk.sources.streams.core import CheckpointMixin, Stream, StreamData
from airbyte_cdk.sources.streams.http.error_handlers import BackoffStrategy, ErrorHandler, HttpStatusErrorHandler
from airbyte_cdk.sources.streams.http.error_handlers.response_models import ErrorResolution, ResponseAction
@@ -49,13 +50,13 @@ class HttpStream(Stream, CheckpointMixin, ABC):
            message_repository=InMemoryMessageRepository(),
        )

        # Streams with at least one cursor_field are incremental and thus a superior sync to RFR. We also cannot
        # assume that streams overriding read_records() will call the parent implementation
        # which
        if len(self.cursor_field) == 0 and type(self).read_records is HttpStream.read_records:
        # There are three conditions that dictate if RFR should automatically be applied to a stream
        # 1. Streams that explicitly initialize their own cursor should defer to it and not automatically apply RFR
        # 2. Streams with at least one cursor_field are incremental and thus a superior sync to RFR.
        # 3. Streams overriding read_records() do not guarantee that they will call the parent implementation which can perform
        #    per-page checkpointing, so RFR is only supported if a stream uses the default `HttpStream.read_records()` method
        if not self.cursor and len(self.cursor_field) == 0 and type(self).read_records is HttpStream.read_records:
            self.cursor = ResumableFullRefreshCursor()
        else:
            self.cursor = None

    @property
    def exit_on_rate_limit(self) -> bool:
@@ -360,11 +361,12 @@ class HttpStream(Stream, CheckpointMixin, ABC):
    def get_cursor(self) -> Optional[Cursor]:
        # I don't love that this is semi-stateful, but I'm not sure what else to do. We don't know exactly what type of cursor to
        # instantiate when creating the class. We can make a few assumptions, like if there is a cursor_field, which implies
        # incremental, but we don't know until runtime if this is a substream. This makes for a slightly awkward getter that
        # is stateful because `has_multiple_slices` can be reassigned at runtime after we inspect slices.
        # This code temporarily gates RFR off for substreams by always removing the RFR cursor if there are multiple slices
        if self.has_multiple_slices:
            return None
        # incremental, but we don't know until runtime if this is a substream. Ideally, a stream should explicitly define
        # its cursor, but because we're trying to automatically apply RFR we're stuck with this logic where we replace the
        # cursor at runtime once we detect this is a substream based on self.has_multiple_slices being reassigned
        if self.has_multiple_slices and isinstance(self.cursor, ResumableFullRefreshCursor):
            self.cursor = SubstreamResumableFullRefreshCursor()
            return self.cursor
        else:
            return self.cursor

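A hedged sketch of the runtime swap described in the comment above (`SomeFullRefreshStream` is a hypothetical HttpStream subclass with no cursor_field and the default read_records()): once `has_multiple_slices` is reassigned, the next `get_cursor()` call replaces the plain RFR cursor with the substream variant.

# Hypothetical stream, used only to illustrate the cursor swap in get_cursor().
stream = SomeFullRefreshStream()
assert isinstance(stream.get_cursor(), ResumableFullRefreshCursor)

stream.has_multiple_slices = True  # reassigned at runtime once slices reveal a substream
assert isinstance(stream.get_cursor(), SubstreamResumableFullRefreshCursor)
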
@@ -376,6 +378,8 @@ class HttpStream(Stream, CheckpointMixin, ABC):
        stream_slice: Optional[Mapping[str, Any]] = None,
        stream_state: Optional[Mapping[str, Any]] = None,
    ) -> Iterable[StreamData]:
        partition, _, _ = self._extract_slice_fields(stream_slice=stream_slice)

        stream_state = stream_state or {}
        pagination_complete = False
        next_page_token = None
@@ -387,6 +391,12 @@ class HttpStream(Stream, CheckpointMixin, ABC):
            if not next_page_token:
                pagination_complete = True

        cursor = self.get_cursor()
        if cursor and isinstance(cursor, SubstreamResumableFullRefreshCursor):
            # Substreams checkpoint state by marking an entire parent partition as completed so that on the subsequent attempt
            # after a failure, completed parents are skipped and the sync can make progress
            cursor.close_slice(StreamSlice(cursor_slice={}, partition=partition))

        # Always return an empty generator just in case no records were ever yielded
        yield from []

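To make the checkpointing above concrete (a sketch, assuming `substream` is a hypothetical HttpSubStream instance whose get_cursor() returns a SubstreamResumableFullRefreshCursor): after read_records() drains a parent's pages, the slice is closed, so a retried sync skips that parent.

# Hypothetical usage; not part of the commit.
parent_slice = StreamSlice(partition={"parent": {"parent_id": "100"}}, cursor_slice={})
list(substream.read_records(SyncMode.full_refresh, stream_slice=parent_slice))
assert substream.get_cursor().select_state(parent_slice) == {"__ab_full_refresh_sync_complete": True}
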
@@ -473,7 +483,14 @@ class HttpSubStream(HttpStream, ABC):
        super().__init__(**kwargs)
        self.parent = parent
        self.has_multiple_slices = True  # Substreams are based on parent records which implies there are multiple slices
        self.cursor = None  # todo: HttpSubStream does not currently support RFR so this is set to None

        # There are three conditions that dictate if RFR should automatically be applied to a stream
        # 1. Streams that explicitly initialize their own cursor should defer to it and not automatically apply RFR
        # 2. Streams with at least one cursor_field are incremental and thus a superior sync to RFR.
        # 3. Streams overriding read_records() do not guarantee that they will call the parent implementation which can perform
        #    per-page checkpointing, so RFR is only supported if a stream uses the default `HttpStream.read_records()` method
        if not self.cursor and len(self.cursor_field) == 0 and type(self).read_records is HttpStream.read_records:
            self.cursor = SubstreamResumableFullRefreshCursor()

    def stream_slices(
        self, sync_mode: SyncMode, cursor_field: Optional[List[str]] = None, stream_state: Optional[Mapping[str, Any]] = None

@@ -246,12 +246,21 @@ class FullRefreshStreamTest(TestCase):
        assert actual_messages.state_messages[0].state.sourceStats.recordCount == 2.0

    @HttpMocker()
    def test_full_refresh_with_slices(self, http_mocker):
    def test_substream_resumable_full_refresh_with_parent_slices(self, http_mocker):
        start_datetime = _NOW - timedelta(days=14)
        config = {
            "start_date": start_datetime.strftime("%Y-%m-%dT%H:%M:%SZ")
        }

        expected_first_substream_per_stream_state = [
            {'partition': {'divide_category': 'dukes'}, 'cursor': {'__ab_full_refresh_sync_complete': True}},
        ]

        expected_second_substream_per_stream_state = [
            {'partition': {'divide_category': 'dukes'}, 'cursor': {'__ab_full_refresh_sync_complete': True}},
            {'partition': {'divide_category': 'mentats'}, 'cursor': {'__ab_full_refresh_sync_complete': True}},
        ]

        http_mocker.get(
            _create_dividers_request().with_category("dukes").build(),
            _create_response().with_record(record=_create_record("dividers")).with_record(record=_create_record("dividers")).build(),
@@ -267,11 +276,12 @@ class FullRefreshStreamTest(TestCase):

        assert emits_successful_sync_status_messages(actual_messages.get_stream_statuses("dividers"))
        assert len(actual_messages.records) == 4
        assert len(actual_messages.state_messages) == 1
        validate_message_order([Type.RECORD, Type.RECORD, Type.RECORD, Type.RECORD, Type.STATE], actual_messages.records_and_state_messages)
        assert actual_messages.state_messages[0].state.stream.stream_descriptor.name == "dividers"
        assert actual_messages.state_messages[0].state.stream.stream_state == AirbyteStateBlob(__ab_no_cursor_state_message=True)
        assert actual_messages.state_messages[0].state.sourceStats.recordCount == 4.0
        assert len(actual_messages.state_messages) == 2
        validate_message_order([Type.RECORD, Type.RECORD, Type.STATE, Type.RECORD, Type.RECORD, Type.STATE], actual_messages.records_and_state_messages)
        assert actual_messages.state_messages[0].state.stream.stream_state == AirbyteStateBlob(states=expected_first_substream_per_stream_state)
        assert actual_messages.state_messages[0].state.sourceStats.recordCount == 2.0
        assert actual_messages.state_messages[1].state.stream.stream_state == AirbyteStateBlob(states=expected_second_substream_per_stream_state)
        assert actual_messages.state_messages[1].state.sourceStats.recordCount == 2.0


@freezegun.freeze_time(_NOW)
@@ -428,6 +438,15 @@ class MultipleStreamTest(TestCase):
            "start_date": start_datetime.strftime("%Y-%m-%dT%H:%M:%SZ")
        }

        expected_first_substream_per_stream_state = [
            {'partition': {'divide_category': 'dukes'}, 'cursor': {'__ab_full_refresh_sync_complete': True}},
        ]

        expected_second_substream_per_stream_state = [
            {'partition': {'divide_category': 'dukes'}, 'cursor': {'__ab_full_refresh_sync_complete': True}},
            {'partition': {'divide_category': 'mentats'}, 'cursor': {'__ab_full_refresh_sync_complete': True}},
        ]

        # Mocks for users full refresh stream
        http_mocker.get(
            _create_users_request().build(),
@@ -466,7 +485,7 @@ class MultipleStreamTest(TestCase):
        assert emits_successful_sync_status_messages(actual_messages.get_stream_statuses("dividers"))

        assert len(actual_messages.records) == 11
        assert len(actual_messages.state_messages) == 4
        assert len(actual_messages.state_messages) == 5
        validate_message_order([
            Type.RECORD,
            Type.RECORD,
@@ -480,6 +499,7 @@ class MultipleStreamTest(TestCase):
            Type.STATE,
            Type.RECORD,
            Type.RECORD,
            Type.STATE,
            Type.RECORD,
            Type.RECORD,
            Type.STATE
@@ -494,5 +514,8 @@ class MultipleStreamTest(TestCase):
        assert actual_messages.state_messages[2].state.stream.stream_state == AirbyteStateBlob(created_at=last_record_date_1)
        assert actual_messages.state_messages[2].state.sourceStats.recordCount == 2.0
        assert actual_messages.state_messages[3].state.stream.stream_descriptor.name == "dividers"
        assert actual_messages.state_messages[3].state.stream.stream_state == AirbyteStateBlob(__ab_no_cursor_state_message=True)
        assert actual_messages.state_messages[3].state.sourceStats.recordCount == 4.0
        assert actual_messages.state_messages[3].state.stream.stream_state == AirbyteStateBlob(states=expected_first_substream_per_stream_state)
        assert actual_messages.state_messages[3].state.sourceStats.recordCount == 2.0
        assert actual_messages.state_messages[4].state.stream.stream_descriptor.name == "dividers"
        assert actual_messages.state_messages[4].state.stream.stream_state == AirbyteStateBlob(states=expected_second_substream_per_stream_state)
        assert actual_messages.state_messages[4].state.sourceStats.recordCount == 2.0

@@ -0,0 +1,162 @@
# Copyright (c) 2024 Airbyte, Inc., all rights reserved.

import pytest
from airbyte_cdk.sources.streams.checkpoint.substream_resumable_full_refresh_cursor import SubstreamResumableFullRefreshCursor
from airbyte_cdk.sources.types import StreamSlice
from airbyte_cdk.utils import AirbyteTracedException


def test_substream_resumable_full_refresh_cursor():
    """
    Test scenario where a set of parent record partitions are iterated over by the cursor resulting in a completed sync
    """
    expected_starting_state = {"states": []}

    expected_ending_state = {
        "states": [
            {
                "partition": {
                    "musician_id": "kousei_arima"
                },
                "cursor": {
                    "__ab_full_refresh_sync_complete": True
                }
            },
            {
                "partition": {
                    "musician_id": "kaori_miyazono"
                },
                "cursor": {
                    "__ab_full_refresh_sync_complete": True
                }
            }
        ]
    }

    partitions = [
        StreamSlice(partition={"musician_id": "kousei_arima"}, cursor_slice={}),
        StreamSlice(partition={"musician_id": "kaori_miyazono"}, cursor_slice={}),
    ]

    cursor = SubstreamResumableFullRefreshCursor()

    starting_state = cursor.get_stream_state()
    assert starting_state == expected_starting_state

    for partition in partitions:
        partition_state = cursor.select_state(partition)
        assert partition_state is None
        cursor.close_slice(partition)

    ending_state = cursor.get_stream_state()
    assert ending_state == expected_ending_state


def test_substream_resumable_full_refresh_cursor_with_state():
    """
    Test scenario where a set of parent record partitions are iterated over and previously completed parents are skipped
    """
    initial_state = {
        "states": [
            {
                "partition": {
                    "musician_id": "kousei_arima"
                },
                "cursor": {
                    "__ab_full_refresh_sync_complete": True
                }
            },
            {
                "partition": {
                    "musician_id": "kaori_miyazono"
                },
                "cursor": {
                    "__ab_full_refresh_sync_complete": True
                }
            },
            {
                "partition": {
                    "musician_id": "takeshi_aiza"
                },
                "cursor": {}
            }
        ]
    }

    expected_ending_state = {
        "states": [
            {
                "partition": {
                    "musician_id": "kousei_arima"
                },
                "cursor": {
                    "__ab_full_refresh_sync_complete": True
                }
            },
            {
                "partition": {
                    "musician_id": "kaori_miyazono"
                },
                "cursor": {
                    "__ab_full_refresh_sync_complete": True
                }
            },
            {
                "partition": {
                    "musician_id": "takeshi_aiza"
                },
                "cursor": {
                    "__ab_full_refresh_sync_complete": True
                }
            },
            {
                "partition": {
                    "musician_id": "emi_igawa"
                },
                "cursor": {
                    "__ab_full_refresh_sync_complete": True
                }
            }
        ]
    }

    partitions = [
        StreamSlice(partition={"musician_id": "kousei_arima"}, cursor_slice={}),
        StreamSlice(partition={"musician_id": "kaori_miyazono"}, cursor_slice={}),
        StreamSlice(partition={"musician_id": "takeshi_aiza"}, cursor_slice={}),
        StreamSlice(partition={"musician_id": "emi_igawa"}, cursor_slice={}),
    ]

    cursor = SubstreamResumableFullRefreshCursor()
    cursor.set_initial_state(initial_state)

    starting_state = cursor.get_stream_state()
    assert starting_state == initial_state

    for i, partition in enumerate(partitions):
        partition_state = cursor.select_state(partition)
        if i < len(initial_state.get("states")):
            assert partition_state == initial_state.get("states")[i].get("cursor")
        else:
            assert partition_state is None
        cursor.close_slice(partition)

    ending_state = cursor.get_stream_state()
    assert ending_state == expected_ending_state


def test_set_initial_state_invalid_incoming_state():
    bad_state = {
        "next_page_token": 2
    }
    cursor = SubstreamResumableFullRefreshCursor()

    with pytest.raises(AirbyteTracedException):
        cursor.set_initial_state(bad_state)


def test_select_state_without_slice():
    cursor = SubstreamResumableFullRefreshCursor()

    with pytest.raises(ValueError):
        cursor.select_state()
@@ -6,7 +6,7 @@
import json
import logging
from http import HTTPStatus
from typing import Any, Callable, Iterable, List, Mapping, MutableMapping, Optional, Union
from typing import Any, Callable, Iterable, List, Mapping, MutableMapping, Optional, Tuple, Union
from unittest.mock import ANY, MagicMock, patch

import pytest
@@ -14,6 +14,7 @@ import requests
from airbyte_cdk.models import AirbyteLogMessage, AirbyteMessage, Level, SyncMode, Type
from airbyte_cdk.sources.streams import CheckpointMixin
from airbyte_cdk.sources.streams.checkpoint import ResumableFullRefreshCursor
from airbyte_cdk.sources.streams.checkpoint.substream_resumable_full_refresh_cursor import SubstreamResumableFullRefreshCursor
from airbyte_cdk.sources.streams.core import StreamData
from airbyte_cdk.sources.streams.http import HttpStream, HttpSubStream
from airbyte_cdk.sources.streams.http.error_handlers import ErrorHandler, HttpStatusErrorHandler
@@ -1067,6 +1068,319 @@ def test_resumable_full_refresh_legacy_stream_slice(mocker):
    assert records == expected


class StubSubstreamResumableFullRefreshStream(HttpSubStream, CheckpointMixin):
    primary_key = "primary_key"

    counter = 0

    def __init__(self, parent: HttpStream, partition_id_to_child_records: Mapping[str, List[Mapping[str, Any]]]):
        super().__init__(parent=parent)
        self._partition_id_to_child_records = partition_id_to_child_records

    @property
    def url_base(self) -> str:
        return "https://airbyte.io/api/v1"

    def path(self, *, stream_state: Optional[Mapping[str, Any]] = None, stream_slice: Optional[Mapping[str, Any]] = None,
             next_page_token: Optional[Mapping[str, Any]] = None) -> str:
        return f"/parents/{stream_slice.get('parent_id')}/children"

    def next_page_token(self, response: requests.Response) -> Optional[Mapping[str, Any]]:
        return None

    def _fetch_next_page(
        self,
        stream_slice: Optional[Mapping[str, Any]] = None,
        stream_state: Optional[Mapping[str, Any]] = None,
        next_page_token: Optional[Mapping[str, Any]] = None,
    ) -> Tuple[requests.PreparedRequest, requests.Response]:
        return requests.PreparedRequest(), requests.Response()

    def parse_response(
        self,
        response: requests.Response,
        *,
        stream_state: Mapping[str, Any],
        stream_slice: Optional[Mapping[str, Any]] = None,
        next_page_token: Optional[Mapping[str, Any]] = None
    ) -> Iterable[Mapping[str, Any]]:
        partition_id = stream_slice.get("parent").get("parent_id")
        if partition_id in self._partition_id_to_child_records:
            yield from self._partition_id_to_child_records.get(partition_id)
        else:
            raise Exception(f"No mocked output supplied for parent partition_id: {partition_id}")

    def get_json_schema(self) -> Mapping[str, Any]:
        return {}


def test_substream_resumable_full_refresh_read_from_start(mocker):
    """
    Validates the default behavior of a substream that supports resumable full refresh: each parent partition is synced
    one read_records() invocation at a time, and state marking the partition complete is emitted afterward.
    """

    parent_records = [
        {"parent_id": "100", "name": "christopher_nolan"},
        {"parent_id": "101", "name": "celine_song"},
        {"parent_id": "102", "name": "david_fincher"},
    ]
    parent_stream = StubParentHttpStream(records=parent_records)

    parents_to_children_records = {
        "100": [{"id": "a200", "parent_id": "100", "film": "interstellar"}, {"id": "a201", "parent_id": "100", "film": "oppenheimer"}, {"id": "a202", "parent_id": "100", "film": "inception"}],
        "101": [{"id": "b200", "parent_id": "101", "film": "past_lives"}, {"id": "b201", "parent_id": "101", "film": "materialists"}],
        "102": [{"id": "c200", "parent_id": "102", "film": "the_social_network"}, {"id": "c201", "parent_id": "102", "film": "gone_girl"}, {"id": "c202", "parent_id": "102", "film": "the_curious_case_of_benjamin_button"}],
    }
    stream = StubSubstreamResumableFullRefreshStream(parent=parent_stream, partition_id_to_child_records=parents_to_children_records)

    blank_response = {}  # Sending a blank response is fine since we ignore the response in `parse_response` anyway
    mocker.patch.object(stream._http_client, "send_request", return_value=(None, blank_response))

    # Wrap all methods we're interested in testing with mocked objects so we can spy on their input args and verify they were what we expect
    mocker.patch.object(stream, "_read_pages", wraps=getattr(stream, "_read_pages"))

    checkpoint_reader = stream._get_checkpoint_reader(
        cursor_field=[], logger=logging.getLogger("airbyte"), sync_mode=SyncMode.full_refresh, stream_state={}
    )
    next_stream_slice = checkpoint_reader.next()
    records = []

    expected_checkpoints = [
        {
            "states": [
                {
                    "cursor": {
                        "__ab_full_refresh_sync_complete": True
                    },
                    "partition": {
                        "parent": {"name": "christopher_nolan", "parent_id": "100"}
                    }
                }
            ]
        },
        {
            "states": [
                {
                    "cursor": {
                        "__ab_full_refresh_sync_complete": True
                    },
                    "partition": {
                        "parent": {"name": "christopher_nolan", "parent_id": "100"}
                    }
                },
                {
                    "cursor": {
                        "__ab_full_refresh_sync_complete": True
                    },
                    "partition": {
                        "parent": {"name": "celine_song", "parent_id": "101"}
                    }
                }
            ]
        },
        {
            "states": [
                {
                    "cursor": {
                        "__ab_full_refresh_sync_complete": True
                    },
                    "partition": {
                        "parent": {"name": "christopher_nolan", "parent_id": "100"}
                    }
                },
                {
                    "cursor": {
                        "__ab_full_refresh_sync_complete": True
                    },
                    "partition": {
                        "parent": {"name": "celine_song", "parent_id": "101"}
                    }
                },
                {
                    "cursor": {
                        "__ab_full_refresh_sync_complete": True
                    },
                    "partition": {
                        "parent": {"name": "david_fincher", "parent_id": "102"}
                    }
                }
            ]
        },
    ]

    i = 0
    while next_stream_slice is not None:
        next_records = list(stream.read_records(SyncMode.full_refresh, stream_slice=next_stream_slice))
        records.extend(next_records)
        checkpoint_reader.observe(stream.state)
        assert checkpoint_reader.get_checkpoint() == expected_checkpoints[i]
        next_stream_slice = checkpoint_reader.next()
        i += 1

    assert getattr(stream, "_read_pages").call_count == 3

    expected = [
        {
            "film": "interstellar",
            "id": "a200",
            "parent_id": "100"
        },
        {
            "film": "oppenheimer",
            "id": "a201",
            "parent_id": "100"
        },
        {
            "film": "inception",
            "id": "a202",
            "parent_id": "100"
        },
        {
            "film": "past_lives",
            "id": "b200",
            "parent_id": "101"
        },
        {
            "film": "materialists",
            "id": "b201",
            "parent_id": "101"
        },
        {
            "film": "the_social_network",
            "id": "c200",
            "parent_id": "102"
        },
        {
            "film": "gone_girl",
            "id": "c201",
            "parent_id": "102"
        },
        {
            "film": "the_curious_case_of_benjamin_button",
            "id": "c202",
            "parent_id": "102"
        }
    ]

    assert records == expected


def test_substream_resumable_full_refresh_read_from_state(mocker):
    """
    Validates that a substream that supports resumable full refresh skips parent partitions that were already completed
    on a prior attempt and only syncs the remaining parents.
    """

    parent_records = [
        {"parent_id": "100", "name": "christopher_nolan"},
        {"parent_id": "101", "name": "celine_song"},
    ]
    parent_stream = StubParentHttpStream(records=parent_records)

    parents_to_children_records = {
        "100": [{"id": "a200", "parent_id": "100", "film": "interstellar"}, {"id": "a201", "parent_id": "100", "film": "oppenheimer"},
                {"id": "a202", "parent_id": "100", "film": "inception"}],
        "101": [{"id": "b200", "parent_id": "101", "film": "past_lives"}, {"id": "b201", "parent_id": "101", "film": "materialists"}],
    }
    stream = StubSubstreamResumableFullRefreshStream(parent=parent_stream, partition_id_to_child_records=parents_to_children_records)

    blank_response = {}  # Sending a blank response is fine since we ignore the response in `parse_response` anyway
    mocker.patch.object(stream._http_client, "send_request", return_value=(None, blank_response))

    # Wrap all methods we're interested in testing with mocked objects so we can spy on their input args and verify they were what we expect
    mocker.patch.object(stream, "_read_pages", wraps=getattr(stream, "_read_pages"))

    checkpoint_reader = stream._get_checkpoint_reader(
        cursor_field=[],
        logger=logging.getLogger("airbyte"),
        sync_mode=SyncMode.full_refresh,
        stream_state={
            "states": [
                {
                    "cursor": {
                        "__ab_full_refresh_sync_complete": True
                    },
                    "partition": {
                        "parent": {"name": "christopher_nolan", "parent_id": "100"}
                    }
                },
            ]
        }
    )
    next_stream_slice = checkpoint_reader.next()
    records = []

    expected_checkpoints = [
        {
            "states": [
                {
                    "cursor": {
                        "__ab_full_refresh_sync_complete": True
                    },
                    "partition": {
                        "parent": {"name": "christopher_nolan", "parent_id": "100"}
                    }
                },
                {
                    "cursor": {
                        "__ab_full_refresh_sync_complete": True
                    },
                    "partition": {
                        "parent": {"name": "celine_song", "parent_id": "101"}
                    }
                }
            ]
        },
    ]

    i = 0
    while next_stream_slice is not None:
        next_records = list(stream.read_records(SyncMode.full_refresh, stream_slice=next_stream_slice))
        records.extend(next_records)
        checkpoint_reader.observe(stream.state)
        assert checkpoint_reader.get_checkpoint() == expected_checkpoints[i]
        next_stream_slice = checkpoint_reader.next()
        i += 1

    assert getattr(stream, "_read_pages").call_count == 1

    expected = [
        {
            "film": "past_lives",
            "id": "b200",
            "parent_id": "101"
        },
        {
            "film": "materialists",
            "id": "b201",
            "parent_id": "101"
        },
    ]

    assert records == expected


class StubWithCursorFields(StubBasicReadHttpStream):
    def __init__(self, has_multiple_slices: bool, set_cursor_field: List[str], deduplicate_query_params: bool = False, **kwargs):
        self.has_multiple_slices = has_multiple_slices
@@ -1083,8 +1397,8 @@ class StubWithCursorFields(StubBasicReadHttpStream):
    [
        pytest.param([], False, ResumableFullRefreshCursor(), id="test_stream_supports_resumable_full_refresh_cursor"),
        pytest.param(["updated_at"], False, None, id="test_incremental_stream_does_not_use_cursor"),
        pytest.param([], True, None, id="test_substream_does_not_use_cursor"),
        pytest.param(["updated_at"], True, None, id="test_incremental_substream_does_not_use_cursor"),
        pytest.param([], True, SubstreamResumableFullRefreshCursor(), id="test_full_refresh_substream_automatically_applies_substream_resumable_full_refresh_cursor"),
    ]
)
def test_get_cursor(cursor_field, is_substream, expected_cursor):