
[RFR for API Sources] Add SubstreamResumableFullRefreshCursor to the Python CDK (#42429)

Brian Lai
2024-08-01 16:39:14 -04:00
committed by GitHub
parent 57a2623283
commit 197cb810b0
8 changed files with 676 additions and 43 deletions


@@ -2,34 +2,16 @@
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
#
import json
from typing import Any, Callable, Iterable, Mapping, MutableMapping, Optional, Union
from airbyte_cdk.sources.declarative.incremental.declarative_cursor import DeclarativeCursor
from airbyte_cdk.sources.declarative.partition_routers.partition_router import PartitionRouter
from airbyte_cdk.sources.streams.checkpoint.per_partition_key_serializer import PerPartitionKeySerializer
from airbyte_cdk.sources.types import Record, StreamSlice, StreamState
from airbyte_cdk.utils import AirbyteTracedException
from airbyte_protocol.models import FailureType
class PerPartitionKeySerializer:
"""
We are concerned about the performance of looping through the `states` list and evaluating equality on the partition. To reduce
this concern, we wanted to use dictionaries to map `partition -> cursor`. However, partitions are dicts, and dicts can't be used
as dict keys since they are not hashable. By creating a JSON string from the dict, we can use that string as a key to the
mapping, since strings are hashable.
"""
@staticmethod
def to_partition_key(to_serialize: Any) -> str:
# The default separators changed in Python 3.4. To avoid being impacted by further changes, we explicitly specify our own values
return json.dumps(to_serialize, indent=None, separators=(",", ":"), sort_keys=True)
@staticmethod
def to_partition(to_deserialize: Any) -> Mapping[str, Any]:
return json.loads(to_deserialize) # type: ignore # The partition is known to be a dict, but the type hint is Any
class CursorFactory:
def __init__(self, create_function: Callable[[], DeclarativeCursor]):
self._create_function = create_function


@@ -143,7 +143,7 @@ class CursorBasedCheckpointReader(CheckpointReader):
# current_slice being None represents the first time we are iterating over a stream's slices. The first slice to
# sync has not been assigned yet and must first be read from the iterator
next_slice = self.read_and_convert_slice()
state_for_slice = self._cursor.select_state(self.current_slice)
state_for_slice = self._cursor.select_state(next_slice)
if state_for_slice == FULL_REFRESH_COMPLETE_STATE:
# Skip every slice that already has the terminal complete value indicating that a previous attempt
# successfully synced the slice
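
For context, a minimal sketch of the corrected skip loop (hypothetical helper name; the reader and cursor attributes are as shown in this hunk). The fix matters because current_slice is still None on the first iteration, so state must be looked up for the slice that was just read:

FULL_REFRESH_COMPLETE_STATE = {"__ab_full_refresh_sync_complete": True}

def next_unfinished_slice(reader):
    # Read the first candidate slice, then keep skipping any slice whose
    # saved state carries the terminal "complete" sentinel.
    next_slice = reader.read_and_convert_slice()
    state_for_slice = reader._cursor.select_state(next_slice)  # not reader.current_slice
    while state_for_slice == FULL_REFRESH_COMPLETE_STATE:
        next_slice = reader.read_and_convert_slice()  # raises StopIteration when exhausted
        state_for_slice = reader._cursor.select_state(next_slice)
    return next_slice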


@@ -0,0 +1,22 @@
# Copyright (c) 2024 Airbyte, Inc., all rights reserved.
import json
from typing import Any, Mapping
class PerPartitionKeySerializer:
"""
We are concerned about the performance of looping through the `states` list and evaluating equality on the partition. To reduce
this concern, we wanted to use dictionaries to map `partition -> cursor`. However, partitions are dicts, and dicts can't be used
as dict keys since they are not hashable. By creating a JSON string from the dict, we can use that string as a key to the
mapping, since strings are hashable.
"""
@staticmethod
def to_partition_key(to_serialize: Any) -> str:
# The default separators changed in Python 3.4. To avoid being impacted by further changes, we explicitly specify our own values
return json.dumps(to_serialize, indent=None, separators=(",", ":"), sort_keys=True)
@staticmethod
def to_partition(to_deserialize: Any) -> Mapping[str, Any]:
return json.loads(to_deserialize) # type: ignore # The partition is known to be a dict, but the type hint is Any
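
A short usage sketch of the serializer added here (the partition keys are hypothetical). Because sort_keys=True canonicalizes the JSON, equivalent dicts with different insertion orders map to the same key:

from airbyte_cdk.sources.streams.checkpoint.per_partition_key_serializer import PerPartitionKeySerializer

serializer = PerPartitionKeySerializer()

# Two equivalent partitions, different insertion order, one canonical key.
key_a = serializer.to_partition_key({"account_id": 1, "region": "us"})
key_b = serializer.to_partition_key({"region": "us", "account_id": 1})
assert key_a == key_b == '{"account_id":1,"region":"us"}'

# The key round-trips back to the original mapping.
assert serializer.to_partition(key_a) == {"account_id": 1, "region": "us"}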


@@ -0,0 +1,113 @@
# Copyright (c) 2024 Airbyte, Inc., all rights reserved.
from dataclasses import dataclass
from typing import Any, Mapping, MutableMapping, Optional
from airbyte_cdk.sources.streams.checkpoint import Cursor
from airbyte_cdk.sources.streams.checkpoint.per_partition_key_serializer import PerPartitionKeySerializer
from airbyte_cdk.sources.types import Record, StreamSlice, StreamState
from airbyte_cdk.utils import AirbyteTracedException
from airbyte_protocol.models import FailureType
FULL_REFRESH_COMPLETE_STATE: Mapping[str, Any] = {"__ab_full_refresh_sync_complete": True}
@dataclass
class SubstreamResumableFullRefreshCursor(Cursor):
def __init__(self) -> None:
self._per_partition_state: MutableMapping[str, StreamState] = {}
self._partition_serializer = PerPartitionKeySerializer()
def get_stream_state(self) -> StreamState:
states = []
for partition_tuple, partition_state in self._per_partition_state.items():
states.append(
{
"partition": self._to_dict(partition_tuple),
"cursor": partition_state,
}
)
state: dict[str, Any] = {"states": states}
return state
def set_initial_state(self, stream_state: StreamState) -> None:
"""
Set the initial state for the cursors.
This method initializes the state for each partition cursor using the provided stream state.
If a partition state is provided in the stream state, it will update the corresponding partition cursor with this state.
To simplify processing and state management, we do not maintain the checkpointed state of the parent partitions.
Instead, we track whether a parent has already successfully synced on a prior attempt and skip over it, allowing
the sync to continue making progress. This works for RFR because the platform will dispose of this state on the
next sync job.
Args:
stream_state (StreamState): The state of the streams to be set. The format of the stream state should be:
{
"states": [
{
"partition": {
"partition_key": "value_0"
},
"cursor": {
"__ab_full_refresh_sync_complete": True
}
},
{
"partition": {
"partition_key": "value_1"
},
"cursor": {},
},
]
}
"""
if not stream_state:
return
if "states" not in stream_state:
raise AirbyteTracedException(
internal_message=f"Could not sync parse the following state: {stream_state}",
message="The state for is format invalid. Validate that the migration steps included a reset and that it was performed "
"properly. Otherwise, please contact Airbyte support.",
failure_type=FailureType.config_error,
)
for state in stream_state["states"]:
self._per_partition_state[self._to_partition_key(state["partition"])] = state["cursor"]
def observe(self, stream_slice: StreamSlice, record: Record) -> None:
"""
Substream resumable full refresh manages state by closing the slice after syncing a parent so observe is not used.
"""
pass
def close_slice(self, stream_slice: StreamSlice, *args: Any) -> None:
self._per_partition_state[self._to_partition_key(stream_slice.partition)] = FULL_REFRESH_COMPLETE_STATE
def should_be_synced(self, record: Record) -> bool:
"""
Unlike date-based cursors which filter out records outside slice boundaries, resumable full refresh records exist within pages
that don't have filterable bounds. We should always return them.
"""
return True
def is_greater_than_or_equal(self, first: Record, second: Record) -> bool:
"""
RFR records don't have an ordering that can be compared between one another.
"""
return False
def select_state(self, stream_slice: Optional[StreamSlice] = None) -> Optional[StreamState]:
if not stream_slice:
raise ValueError("A partition needs to be provided in order to extract a state")
return self._per_partition_state.get(self._to_partition_key(stream_slice.partition))
def _to_partition_key(self, partition: Mapping[str, Any]) -> str:
return self._partition_serializer.to_partition_key(partition)
def _to_dict(self, partition_key: str) -> Mapping[str, Any]:
return self._partition_serializer.to_partition(partition_key)
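
A condensed usage sketch of the new cursor (hypothetical partition values, mirroring the unit tests added below): a parent that was not completed on a prior attempt returns None from select_state(), gets marked complete via close_slice(), and the per-partition state round-trips through get_stream_state():

from airbyte_cdk.sources.streams.checkpoint.substream_resumable_full_refresh_cursor import (
    FULL_REFRESH_COMPLETE_STATE,
    SubstreamResumableFullRefreshCursor,
)
from airbyte_cdk.sources.types import StreamSlice

cursor = SubstreamResumableFullRefreshCursor()
parent_slices = [
    StreamSlice(partition={"parent_id": "100"}, cursor_slice={}),
    StreamSlice(partition={"parent_id": "101"}, cursor_slice={}),
]

for stream_slice in parent_slices:
    if cursor.select_state(stream_slice) == FULL_REFRESH_COMPLETE_STATE:
        continue  # a prior attempt already synced this parent
    # ... sync all child records for this parent ...
    cursor.close_slice(stream_slice)  # mark the entire parent partition complete

assert cursor.get_stream_state() == {
    "states": [
        {"partition": {"parent_id": "100"}, "cursor": {"__ab_full_refresh_sync_complete": True}},
        {"partition": {"parent_id": "101"}, "cursor": {"__ab_full_refresh_sync_complete": True}},
    ]
}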


@@ -15,6 +15,7 @@ from airbyte_cdk.sources.message.repository import InMemoryMessageRepository
from airbyte_cdk.sources.streams.call_rate import APIBudget
from airbyte_cdk.sources.streams.checkpoint.cursor import Cursor
from airbyte_cdk.sources.streams.checkpoint.resumable_full_refresh_cursor import ResumableFullRefreshCursor
from airbyte_cdk.sources.streams.checkpoint.substream_resumable_full_refresh_cursor import SubstreamResumableFullRefreshCursor
from airbyte_cdk.sources.streams.core import CheckpointMixin, Stream, StreamData
from airbyte_cdk.sources.streams.http.error_handlers import BackoffStrategy, ErrorHandler, HttpStatusErrorHandler
from airbyte_cdk.sources.streams.http.error_handlers.response_models import ErrorResolution, ResponseAction
@@ -49,13 +50,13 @@ class HttpStream(Stream, CheckpointMixin, ABC):
message_repository=InMemoryMessageRepository(),
)
# Streams with at least one cursor_field are incremental and thus a superior sync to RFR. We also cannot
# assume that streams overriding read_records() will call the parent implementation which can perform
# per-page checkpointing
if len(self.cursor_field) == 0 and type(self).read_records is HttpStream.read_records:
# There are three conditions that dictate if RFR should automatically be applied to a stream
# 1. Streams that explicitly initialize their own cursor should defer to it and not automatically apply RFR
# 2. Streams with at least one cursor_field are incremental and thus a superior sync to RFR.
# 3. Streams overriding read_records() do not guarantee that they will call the parent implementation which can perform
# per-page checkpointing, so RFR is only supported if a stream uses the default `HttpStream.read_records()` method
if not self.cursor and len(self.cursor_field) == 0 and type(self).read_records is HttpStream.read_records:
self.cursor = ResumableFullRefreshCursor()
else:
self.cursor = None
@property
def exit_on_rate_limit(self) -> bool:
@@ -360,11 +361,12 @@ class HttpStream(Stream, CheckpointMixin, ABC):
def get_cursor(self) -> Optional[Cursor]:
# I don't love that this is semi-stateful but not sure what else to do. We don't know exactly what type of cursor to
# instantiate when creating the class. We can make a few assumptions like if there is a cursor_field which implies
# incremental, but we don't know until runtime if this is a substream. This makes for a slightly awkward getter that
# is stateful because `has_multiple_slices` can be reassigned at runtime after we inspect slices.
# This code temporarily gates RFR off for substreams by always removing the RFR cursor if there are multiple slices
if self.has_multiple_slices:
return None
# incremental, but we don't know until runtime if this is a substream. Ideally, a stream should explicitly define
# its cursor, but because we're trying to automatically apply RFR we're stuck with this logic where we replace the
# cursor at runtime once we detect this is a substream based on self.has_multiple_slices being reassigned
if self.has_multiple_slices and isinstance(self.cursor, ResumableFullRefreshCursor):
self.cursor = SubstreamResumableFullRefreshCursor()
return self.cursor
else:
return self.cursor
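
To make the swap concrete, a self-contained sketch of the getter's behavior (using a stand-in object rather than a real stream): the replacement happens at most once, because the substream cursor is not a ResumableFullRefreshCursor, after which it is returned directly:

from types import SimpleNamespace

from airbyte_cdk.sources.streams.checkpoint import ResumableFullRefreshCursor
from airbyte_cdk.sources.streams.checkpoint.substream_resumable_full_refresh_cursor import (
    SubstreamResumableFullRefreshCursor,
)

def resolve_cursor(stream):
    # Mirrors the getter above: swap in the substream cursor the first time
    # the stream is detected to have multiple slices.
    if stream.has_multiple_slices and isinstance(stream.cursor, ResumableFullRefreshCursor):
        stream.cursor = SubstreamResumableFullRefreshCursor()
    return stream.cursor

# A stand-in object is enough to see the swap happen exactly once.
stream = SimpleNamespace(cursor=ResumableFullRefreshCursor(), has_multiple_slices=False)
assert isinstance(resolve_cursor(stream), ResumableFullRefreshCursor)
stream.has_multiple_slices = True  # reassigned at runtime once slices are inspected
assert isinstance(resolve_cursor(stream), SubstreamResumableFullRefreshCursor)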
@@ -376,6 +378,8 @@ class HttpStream(Stream, CheckpointMixin, ABC):
stream_slice: Optional[Mapping[str, Any]] = None,
stream_state: Optional[Mapping[str, Any]] = None,
) -> Iterable[StreamData]:
partition, _, _ = self._extract_slice_fields(stream_slice=stream_slice)
stream_state = stream_state or {}
pagination_complete = False
next_page_token = None
@@ -387,6 +391,12 @@ class HttpStream(Stream, CheckpointMixin, ABC):
if not next_page_token:
pagination_complete = True
cursor = self.get_cursor()
if cursor and isinstance(cursor, SubstreamResumableFullRefreshCursor):
# Substreams checkpoint state by marking an entire parent partition as completed so that on the subsequent attempt
# after a failure, completed parents are skipped and the sync can make progress
cursor.close_slice(StreamSlice(cursor_slice={}, partition=partition))
# Always return an empty generator just in case no records were ever yielded
yield from []
@@ -473,7 +483,14 @@ class HttpSubStream(HttpStream, ABC):
super().__init__(**kwargs)
self.parent = parent
self.has_multiple_slices = True # Substreams are based on parent records which implies there are multiple slices
self.cursor = None # todo: HttpSubStream does not currently support RFR so this is set to None
# There are three conditions that dictate if RFR should automatically be applied to a stream
# 1. Streams that explicitly initialize their own cursor should defer to it and not automatically apply RFR
# 2. Streams with at least one cursor_field are incremental and thus a superior sync to RFR.
# 3. Streams overriding read_records() do not guarantee that they will call the parent implementation which can perform
# per-page checkpointing, so RFR is only supported if a stream uses the default `HttpStream.read_records()` method
if not self.cursor and len(self.cursor_field) == 0 and type(self).read_records is HttpStream.read_records:
self.cursor = SubstreamResumableFullRefreshCursor()
def stream_slices(
self, sync_mode: SyncMode, cursor_field: Optional[List[str]] = None, stream_state: Optional[Mapping[str, Any]] = None
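
The same three-way gate now runs in both HttpStream.__init__ and HttpSubStream.__init__. A condensed restatement as a hypothetical free function (the real check lives inline in the constructors, not in the CDK as a helper):

from typing import Optional

from airbyte_cdk.sources.streams.checkpoint import Cursor, ResumableFullRefreshCursor
from airbyte_cdk.sources.streams.checkpoint.substream_resumable_full_refresh_cursor import (
    SubstreamResumableFullRefreshCursor,
)
from airbyte_cdk.sources.streams.http import HttpStream, HttpSubStream

def auto_rfr_cursor(stream: HttpStream) -> Optional[Cursor]:
    """Return the cursor the constructors above would assign automatically."""
    if getattr(stream, "cursor", None) is not None:
        return stream.cursor  # 1. an explicitly initialized cursor wins
    if len(stream.cursor_field) > 0:
        return None  # 2. a cursor_field implies incremental, which is superior to RFR
    if type(stream).read_records is not HttpStream.read_records:
        return None  # 3. an overridden read_records() may skip per-page checkpointing
    return SubstreamResumableFullRefreshCursor() if isinstance(stream, HttpSubStream) else ResumableFullRefreshCursor()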


@@ -246,12 +246,21 @@ class FullRefreshStreamTest(TestCase):
assert actual_messages.state_messages[0].state.sourceStats.recordCount == 2.0
@HttpMocker()
def test_full_refresh_with_slices(self, http_mocker):
def test_substream_resumable_full_refresh_with_parent_slices(self, http_mocker):
start_datetime = _NOW - timedelta(days=14)
config = {
"start_date": start_datetime.strftime("%Y-%m-%dT%H:%M:%SZ")
}
expected_first_substream_per_stream_state = [
{'partition': {'divide_category': 'dukes'}, 'cursor': {'__ab_full_refresh_sync_complete': True}},
]
expected_second_substream_per_stream_state = [
{'partition': {'divide_category': 'dukes'}, 'cursor': {'__ab_full_refresh_sync_complete': True}},
{'partition': {'divide_category': 'mentats'}, 'cursor': {'__ab_full_refresh_sync_complete': True}},
]
http_mocker.get(
_create_dividers_request().with_category("dukes").build(),
_create_response().with_record(record=_create_record("dividers")).with_record(record=_create_record("dividers")).build(),
@@ -267,11 +276,12 @@ class FullRefreshStreamTest(TestCase):
assert emits_successful_sync_status_messages(actual_messages.get_stream_statuses("dividers"))
assert len(actual_messages.records) == 4
assert len(actual_messages.state_messages) == 1
validate_message_order([Type.RECORD, Type.RECORD, Type.RECORD, Type.RECORD, Type.STATE], actual_messages.records_and_state_messages)
assert actual_messages.state_messages[0].state.stream.stream_descriptor.name == "dividers"
assert actual_messages.state_messages[0].state.stream.stream_state == AirbyteStateBlob(__ab_no_cursor_state_message=True)
assert actual_messages.state_messages[0].state.sourceStats.recordCount == 4.0
assert len(actual_messages.state_messages) == 2
validate_message_order([Type.RECORD, Type.RECORD, Type.STATE, Type.RECORD, Type.RECORD, Type.STATE], actual_messages.records_and_state_messages)
assert actual_messages.state_messages[0].state.stream.stream_state == AirbyteStateBlob(states=expected_first_substream_per_stream_state)
assert actual_messages.state_messages[0].state.sourceStats.recordCount == 2.0
assert actual_messages.state_messages[1].state.stream.stream_state == AirbyteStateBlob(states=expected_second_substream_per_stream_state)
assert actual_messages.state_messages[1].state.sourceStats.recordCount == 2.0
@freezegun.freeze_time(_NOW)
@@ -428,6 +438,15 @@ class MultipleStreamTest(TestCase):
"start_date": start_datetime.strftime("%Y-%m-%dT%H:%M:%SZ")
}
expected_first_substream_per_stream_state = [
{'partition': {'divide_category': 'dukes'}, 'cursor': {'__ab_full_refresh_sync_complete': True}},
]
expected_second_substream_per_stream_state = [
{'partition': {'divide_category': 'dukes'}, 'cursor': {'__ab_full_refresh_sync_complete': True}},
{'partition': {'divide_category': 'mentats'}, 'cursor': {'__ab_full_refresh_sync_complete': True}},
]
# Mocks for users full refresh stream
http_mocker.get(
_create_users_request().build(),
@@ -466,7 +485,7 @@ class MultipleStreamTest(TestCase):
assert emits_successful_sync_status_messages(actual_messages.get_stream_statuses("dividers"))
assert len(actual_messages.records) == 11
assert len(actual_messages.state_messages) == 4
assert len(actual_messages.state_messages) == 5
validate_message_order([
Type.RECORD,
Type.RECORD,
@@ -480,6 +499,7 @@ class MultipleStreamTest(TestCase):
Type.STATE,
Type.RECORD,
Type.RECORD,
Type.STATE,
Type.RECORD,
Type.RECORD,
Type.STATE
@@ -494,5 +514,8 @@ class MultipleStreamTest(TestCase):
assert actual_messages.state_messages[2].state.stream.stream_state == AirbyteStateBlob(created_at=last_record_date_1)
assert actual_messages.state_messages[2].state.sourceStats.recordCount == 2.0
assert actual_messages.state_messages[3].state.stream.stream_descriptor.name == "dividers"
assert actual_messages.state_messages[3].state.stream.stream_state == AirbyteStateBlob(__ab_no_cursor_state_message=True)
assert actual_messages.state_messages[3].state.sourceStats.recordCount == 4.0
assert actual_messages.state_messages[3].state.stream.stream_state == AirbyteStateBlob(states=expected_first_substream_per_stream_state)
assert actual_messages.state_messages[3].state.sourceStats.recordCount == 2.0
assert actual_messages.state_messages[4].state.stream.stream_descriptor.name == "dividers"
assert actual_messages.state_messages[4].state.stream.stream_state == AirbyteStateBlob(states=expected_second_substream_per_stream_state)
assert actual_messages.state_messages[4].state.sourceStats.recordCount == 2.0


@@ -0,0 +1,162 @@
# Copyright (c) 2024 Airbyte, Inc., all rights reserved.
import pytest
from airbyte_cdk.sources.streams.checkpoint.substream_resumable_full_refresh_cursor import SubstreamResumableFullRefreshCursor
from airbyte_cdk.sources.types import StreamSlice
from airbyte_cdk.utils import AirbyteTracedException
def test_substream_resumable_full_refresh_cursor():
"""
Test scenario where a set of parent record partitions are iterated over by the cursor resulting in a completed sync
"""
expected_starting_state = {"states": []}
expected_ending_state = {
"states": [
{
"partition": {
"musician_id": "kousei_arima"
},
"cursor": {
"__ab_full_refresh_sync_complete": True
}
},
{
"partition": {
"musician_id": "kaori_miyazono"
},
"cursor": {
"__ab_full_refresh_sync_complete": True
}
}
]
}
partitions = [
StreamSlice(partition={"musician_id": "kousei_arima"}, cursor_slice={}),
StreamSlice(partition={"musician_id": "kaori_miyazono"}, cursor_slice={}),
]
cursor = SubstreamResumableFullRefreshCursor()
starting_state = cursor.get_stream_state()
assert starting_state == expected_starting_state
for partition in partitions:
partition_state = cursor.select_state(partition)
assert partition_state is None
cursor.close_slice(partition)
ending_state = cursor.get_stream_state()
assert ending_state == expected_ending_state
def test_substream_resumable_full_refresh_cursor_with_state():
"""
Test scenario where a set of parent record partitions are iterated over and previously completed parents are skipped
"""
initial_state = {
"states": [
{
"partition": {
"musician_id": "kousei_arima"
},
"cursor": {
"__ab_full_refresh_sync_complete": True
}
},
{
"partition": {
"musician_id": "kaori_miyazono"
},
"cursor": {
"__ab_full_refresh_sync_complete": True
}
},
{
"partition": {
"musician_id": "takeshi_aiza"
},
"cursor": {}
}
]
}
expected_ending_state = {
"states": [
{
"partition": {
"musician_id": "kousei_arima"
},
"cursor": {
"__ab_full_refresh_sync_complete": True
}
},
{
"partition": {
"musician_id": "kaori_miyazono"
},
"cursor": {
"__ab_full_refresh_sync_complete": True
}
},
{
"partition": {
"musician_id": "takeshi_aiza"
},
"cursor": {
"__ab_full_refresh_sync_complete": True
}
},
{
"partition": {
"musician_id": "emi_igawa"
},
"cursor": {
"__ab_full_refresh_sync_complete": True
}
}
]
}
partitions = [
StreamSlice(partition={"musician_id": "kousei_arima"}, cursor_slice={}),
StreamSlice(partition={"musician_id": "kaori_miyazono"}, cursor_slice={}),
StreamSlice(partition={"musician_id": "takeshi_aiza"}, cursor_slice={}),
StreamSlice(partition={"musician_id": "emi_igawa"}, cursor_slice={}),
]
cursor = SubstreamResumableFullRefreshCursor()
cursor.set_initial_state(initial_state)
starting_state = cursor.get_stream_state()
assert starting_state == initial_state
for i, partition in enumerate(partitions):
partition_state = cursor.select_state(partition)
if i < len(initial_state.get("states")):
assert partition_state == initial_state.get("states")[i].get("cursor")
else:
assert partition_state is None
cursor.close_slice(partition)
ending_state = cursor.get_stream_state()
assert ending_state == expected_ending_state
def test_set_initial_state_invalid_incoming_state():
bad_state = {
"next_page_token": 2
}
cursor = SubstreamResumableFullRefreshCursor()
with pytest.raises(AirbyteTracedException):
cursor.set_initial_state(bad_state)
def test_select_state_without_slice():
cursor = SubstreamResumableFullRefreshCursor()
with pytest.raises(ValueError):
cursor.select_state()


@@ -6,7 +6,7 @@
import json
import logging
from http import HTTPStatus
from typing import Any, Callable, Iterable, List, Mapping, MutableMapping, Optional, Union
from typing import Any, Callable, Iterable, List, Mapping, MutableMapping, Optional, Tuple, Union
from unittest.mock import ANY, MagicMock, patch
import pytest
@@ -14,6 +14,7 @@ import requests
from airbyte_cdk.models import AirbyteLogMessage, AirbyteMessage, Level, SyncMode, Type
from airbyte_cdk.sources.streams import CheckpointMixin
from airbyte_cdk.sources.streams.checkpoint import ResumableFullRefreshCursor
from airbyte_cdk.sources.streams.checkpoint.substream_resumable_full_refresh_cursor import SubstreamResumableFullRefreshCursor
from airbyte_cdk.sources.streams.core import StreamData
from airbyte_cdk.sources.streams.http import HttpStream, HttpSubStream
from airbyte_cdk.sources.streams.http.error_handlers import ErrorHandler, HttpStatusErrorHandler
@@ -1067,6 +1068,319 @@ def test_resumable_full_refresh_legacy_stream_slice(mocker):
assert records == expected
class StubSubstreamResumableFullRefreshStream(HttpSubStream, CheckpointMixin):
primary_key = "primary_key"
counter = 0
def __init__(self, parent: HttpStream, partition_id_to_child_records: Mapping[str, List[Mapping[str, Any]]]):
super().__init__(parent=parent)
self._partition_id_to_child_records = partition_id_to_child_records
@property
def url_base(self) -> str:
return "https://airbyte.io/api/v1"
def path(self, *, stream_state: Optional[Mapping[str, Any]] = None, stream_slice: Optional[Mapping[str, Any]] = None,
next_page_token: Optional[Mapping[str, Any]] = None) -> str:
return f"/parents/{stream_slice.get('parent_id')}/children"
def next_page_token(self, response: requests.Response) -> Optional[Mapping[str, Any]]:
return None
def _fetch_next_page(
self,
stream_slice: Optional[Mapping[str, Any]] = None,
stream_state: Optional[Mapping[str, Any]] = None,
next_page_token: Optional[Mapping[str, Any]] = None,
) -> Tuple[requests.PreparedRequest, requests.Response]:
return requests.PreparedRequest(), requests.Response()
def parse_response(
self,
response: requests.Response,
*,
stream_state: Mapping[str, Any],
stream_slice: Optional[Mapping[str, Any]] = None,
next_page_token: Optional[Mapping[str, Any]] = None
) -> Iterable[Mapping[str, Any]]:
partition_id = stream_slice.get("parent").get("parent_id")
if partition_id in self._partition_id_to_child_records:
yield from self._partition_id_to_child_records.get(partition_id)
else:
raise Exception(f"No mocked output supplied for parent partition_id: {partition_id}")
def get_json_schema(self) -> Mapping[str, Any]:
return {}
def test_substream_resumable_full_refresh_read_from_start(mocker):
"""
Validates the default behavior of a substream that supports resumable full refresh: each parent partition is synced
in full via read_records(), and per-partition state is emitted after each slice.
"""
parent_records = [
{"parent_id": "100", "name": "christopher_nolan"},
{"parent_id": "101", "name": "celine_song"},
{"parent_id": "102", "name": "david_fincher"},
]
parent_stream = StubParentHttpStream(records=parent_records)
parents_to_children_records = {
"100": [{"id": "a200", "parent_id": "100", "film": "interstellar"}, {"id": "a201", "parent_id": "100", "film": "oppenheimer"}, {"id": "a202", "parent_id": "100", "film": "inception"}],
"101": [{"id": "b200", "parent_id": "101", "film": "past_lives"}, {"id": "b201", "parent_id": "101", "film": "materialists"}],
"102": [{"id": "c200", "parent_id": "102", "film": "the_social_network"}, {"id": "c201", "parent_id": "102", "film": "gone_girl"}, {"id": "c202", "parent_id": "102", "film": "the_curious_case_of_benjamin_button"}],
}
stream = StubSubstreamResumableFullRefreshStream(parent=parent_stream, partition_id_to_child_records=parents_to_children_records)
blank_response = {}  # Sending a blank response is fine as we ignore the response in `parse_response()` anyway.
mocker.patch.object(stream._http_client, "send_request", return_value=(None, blank_response))
# Wrap all methods we're interested in testing with mocked objects to spy on their input args and verify they were what we expect
mocker.patch.object(stream, "_read_pages", wraps=getattr(stream, "_read_pages"))
checkpoint_reader = stream._get_checkpoint_reader(
cursor_field=[], logger=logging.getLogger("airbyte"), sync_mode=SyncMode.full_refresh, stream_state={}
)
next_stream_slice = checkpoint_reader.next()
records = []
expected_checkpoints = [
{
"states": [
{
"cursor": {
"__ab_full_refresh_sync_complete": True
},
"partition": {
"parent": {"name": "christopher_nolan", "parent_id": "100"}
}
}
]
},
{
"states": [
{
"cursor": {
"__ab_full_refresh_sync_complete": True
},
"partition": {
"parent": {"name": "christopher_nolan", "parent_id": "100"}
}
},
{
"cursor": {
"__ab_full_refresh_sync_complete": True
},
"partition": {
"parent": {"name": "celine_song", "parent_id": "101"}
}
}
]
},
{
"states": [
{
"cursor": {
"__ab_full_refresh_sync_complete": True
},
"partition": {
"parent": {"name": "christopher_nolan", "parent_id": "100"}
}
},
{
"cursor": {
"__ab_full_refresh_sync_complete": True
},
"partition": {
"parent": {"name": "celine_song", "parent_id": "101"}
}
},
{
"cursor": {
"__ab_full_refresh_sync_complete": True
},
"partition": {
"parent": {"name": "david_fincher", "parent_id": "102"}
}
}
]
},
]
i = 0
while next_stream_slice is not None:
next_records = list(stream.read_records(SyncMode.full_refresh, stream_slice=next_stream_slice))
records.extend(next_records)
checkpoint_reader.observe(stream.state)
assert checkpoint_reader.get_checkpoint() == expected_checkpoints[i]
next_stream_slice = checkpoint_reader.next()
i += 1
assert getattr(stream, "_read_pages").call_count == 3
expected = [
{
"film": "interstellar",
"id": "a200",
"parent_id": "100"
},
{
"film": "oppenheimer",
"id": "a201",
"parent_id": "100"
},
{
"film": "inception",
"id": "a202",
"parent_id": "100"
},
{
"film": "past_lives",
"id": "b200",
"parent_id": "101"
},
{
"film": "materialists",
"id": "b201",
"parent_id": "101"
},
{
"film": "the_social_network",
"id": "c200",
"parent_id": "102"
},
{
"film": "gone_girl",
"id": "c201",
"parent_id": "102"
},
{
"film": "the_curious_case_of_benjamin_button",
"id": "c202",
"parent_id": "102"
}
]
assert records == expected
def test_substream_resumable_full_refresh_read_from_state(mocker):
"""
Validates that a substream that supports resumable full refresh skips parent partitions already marked complete in
the incoming state and only syncs the remaining partitions.
"""
parent_records = [
{"parent_id": "100", "name": "christopher_nolan"},
{"parent_id": "101", "name": "celine_song"},
]
parent_stream = StubParentHttpStream(records=parent_records)
parents_to_children_records = {
"100": [{"id": "a200", "parent_id": "100", "film": "interstellar"}, {"id": "a201", "parent_id": "100", "film": "oppenheimer"},
{"id": "a202", "parent_id": "100", "film": "inception"}],
"101": [{"id": "b200", "parent_id": "101", "film": "past_lives"}, {"id": "b201", "parent_id": "101", "film": "materialists"}],
}
stream = StubSubstreamResumableFullRefreshStream(parent=parent_stream, partition_id_to_child_records=parents_to_children_records)
blank_response = {}  # Sending a blank response is fine as we ignore the response in `parse_response()` anyway.
mocker.patch.object(stream._http_client, "send_request", return_value=(None, blank_response))
# Wrap all methods we're interested in testing with mocked objects to spy on their input args and verify they were what we expect
mocker.patch.object(stream, "_read_pages", wraps=getattr(stream, "_read_pages"))
checkpoint_reader = stream._get_checkpoint_reader(
cursor_field=[],
logger=logging.getLogger("airbyte"),
sync_mode=SyncMode.full_refresh,
stream_state={
"states": [
{
"cursor": {
"__ab_full_refresh_sync_complete": True
},
"partition": {
"parent": {"name": "christopher_nolan", "parent_id": "100"}
}
},
]
}
)
next_stream_slice = checkpoint_reader.next()
records = []
expected_checkpoints = [
{
"states": [
{
"cursor": {
"__ab_full_refresh_sync_complete": True
},
"partition": {
"parent": {"name": "christopher_nolan", "parent_id": "100"}
}
},
{
"cursor": {
"__ab_full_refresh_sync_complete": True
},
"partition": {
"parent": {"name": "celine_song", "parent_id": "101"}
}
}
]
},
]
i = 0
while next_stream_slice is not None:
next_records = list(stream.read_records(SyncMode.full_refresh, stream_slice=next_stream_slice))
records.extend(next_records)
checkpoint_reader.observe(stream.state)
assert checkpoint_reader.get_checkpoint() == expected_checkpoints[i]
next_stream_slice = checkpoint_reader.next()
i += 1
assert getattr(stream, "_read_pages").call_count == 1
expected = [
{
"film": "past_lives",
"id": "b200",
"parent_id": "101"
},
{
"film": "materialists",
"id": "b201",
"parent_id": "101"
},
]
assert records == expected
class StubWithCursorFields(StubBasicReadHttpStream):
def __init__(self, has_multiple_slices: bool, set_cursor_field: List[str], deduplicate_query_params: bool = False, **kwargs):
self.has_multiple_slices = has_multiple_slices
@@ -1083,8 +1397,8 @@ class StubWithCursorFields(StubBasicReadHttpStream):
[
pytest.param([], False, ResumableFullRefreshCursor(), id="test_stream_supports_resumable_full_refresh_cursor"),
pytest.param(["updated_at"], False, None, id="test_incremental_stream_does_not_use_cursor"),
pytest.param([], True, None, id="test_substream_does_not_use_cursor"),
pytest.param(["updated_at"], True, None, id="test_incremental_substream_does_not_use_cursor"),
pytest.param([], True, SubstreamResumableFullRefreshCursor(), id="test_full_refresh_substream_automatically_applies_substream_resumable_full_refresh_cursor"),
]
)
def test_get_cursor(cursor_field, is_substream, expected_cursor):