1
0
mirror of synced 2026-02-02 16:02:07 -05:00
Files
airbyte/airbyte-cdk/python/airbyte_cdk/sources/streams/concurrent/cursor.py

161 lines
6.9 KiB
Python

#
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
#
import functools
from abc import ABC, abstractmethod
from typing import Any, List, Mapping, Optional, Protocol, Tuple
from airbyte_cdk.sources.connector_state_manager import ConnectorStateManager
from airbyte_cdk.sources.message import MessageRepository
from airbyte_cdk.sources.streams.concurrent.partitions.partition import Partition
from airbyte_cdk.sources.streams.concurrent.partitions.record import Record
from airbyte_cdk.sources.streams.concurrent.state_converters.abstract_stream_state_converter import AbstractStreamStateConverter
def _extract_value(mapping: Mapping[str, Any], path: List[str]) -> Any:
return functools.reduce(lambda a, b: a[b], path, mapping)
class Comparable(Protocol):
    """Structural type for values ordered with ``<`` (used for cursor values)."""

    @abstractmethod
    def __lt__(self: "Comparable", other: "Comparable") -> bool:
        ...
class CursorField:
    """Holds the key of the cursor field and reads its value off a record's data."""

    def __init__(self, cursor_field_key: str) -> None:
        self.cursor_field_key = cursor_field_key

    def extract_value(self, record: Record) -> Comparable:
        """Return the cursor value from ``record.data``.

        Raises:
            ValueError: if the key is absent or maps to ``None``.
        """
        cursor_value = record.data.get(self.cursor_field_key)
        if cursor_value is not None:
            return cursor_value  # type: ignore # we assume that the value the path points at is a comparable
        raise ValueError(f"Could not find cursor field {self.cursor_field_key} in record")
class Cursor(ABC):
    """Interface through which concurrent streams report sync progress to a cursor."""

    @abstractmethod
    def observe(self, record: Record) -> None:
        """
        Indicate to the cursor that the record has been emitted
        """
        raise NotImplementedError()

    @abstractmethod
    def close_partition(self, partition: Partition) -> None:
        """
        Indicate to the cursor that the partition has been successfully processed
        """
        raise NotImplementedError()
class NoopCursor(Cursor):
    """Cursor implementation that ignores both records and closed partitions."""

    def observe(self, record: Record) -> None:
        return None

    def close_partition(self, partition: Partition) -> None:
        return None
class ConcurrentCursor(Cursor):
    """Cursor for streams whose partitions are processed concurrently.

    Progress is stored as a list of closed intervals under ``state["slices"]``.
    Overlapping intervals are collapsed via the state converter's
    ``merge_intervals``, and the state is converted back to the legacy
    sequential format before being emitted.
    """

    # Indexes into ``slice_boundary_fields``: (start_key, end_key).
    _START_BOUNDARY = 0
    _END_BOUNDARY = 1

    def __init__(
        self,
        stream_name: str,
        stream_namespace: Optional[str],
        stream_state: Any,
        message_repository: MessageRepository,
        connector_state_manager: ConnectorStateManager,
        connector_state_converter: AbstractStreamStateConverter,
        cursor_field: CursorField,
        slice_boundary_fields: Optional[Tuple[str, str]],
    ) -> None:
        self._stream_name = stream_name
        self._stream_namespace = stream_namespace
        self._message_repository = message_repository
        self._connector_state_converter = connector_state_converter
        self._connector_state_manager = connector_state_manager
        self._cursor_field = cursor_field
        # To see some example where the slice boundaries might not be defined, check https://github.com/airbytehq/airbyte/blob/1ce84d6396e446e1ac2377362446e3fb94509461/airbyte-integrations/connectors/source-stripe/source_stripe/streams.py#L363-L379
        self._slice_boundary_fields = slice_boundary_fields if slice_boundary_fields else tuple()
        self._most_recent_record: Optional[Record] = None
        self._has_closed_at_least_one_slice = False
        self.state = stream_state

    def observe(self, record: Record) -> None:
        """Track the record with the greatest cursor value seen so far.

        Only used when no slice boundary fields are defined; otherwise the
        slices themselves describe what has been emitted.
        """
        if self._slice_boundary_fields:
            # Given that slicing is done using the cursor field, we don't need to observe the record as we assume slices will describe what
            # has been emitted. Assuming there is a chance that records might not be yet populated for the most recent slice, use a lookback
            # window
            return

        if not self._most_recent_record or self._extract_cursor_value(self._most_recent_record) < self._extract_cursor_value(record):
            self._most_recent_record = record

    def _extract_cursor_value(self, record: Record) -> Any:
        # Parse through the converter so record values compare consistently with slice boundaries.
        return self._connector_state_converter.parse_value(self._cursor_field.extract_value(record))

    def close_partition(self, partition: Partition) -> None:
        """Fold the closed partition's interval into the state and emit a state message."""
        slice_count_before = len(self.state.get("slices", []))
        self._add_slice_to_state(partition)
        if slice_count_before < len(self.state["slices"]):  # only emit if at least one slice was added to the state
            self._merge_partitions()
            self._emit_state_message()
        self._has_closed_at_least_one_slice = True

    def _add_slice_to_state(self, partition: Partition) -> None:
        """Append the interval covered by ``partition`` to ``self.state["slices"]``.

        Raises:
            ValueError: in the no-boundary-fields mode, if more than one slice is closed.
        """
        # "slices" may be absent from the state (close_partition reads it with a
        # default too), so ensure the list exists before either branch appends.
        # Previously only the boundary-fields branch initialized it, making the
        # record-based branch raise KeyError on the first closed partition.
        slices = self.state.setdefault("slices", [])

        if self._slice_boundary_fields:
            slices.append(
                {
                    "start": self._extract_from_slice(partition, self._slice_boundary_fields[self._START_BOUNDARY]),
                    "end": self._extract_from_slice(partition, self._slice_boundary_fields[self._END_BOUNDARY]),
                }
            )
        elif self._most_recent_record:
            if self._has_closed_at_least_one_slice:
                raise ValueError(
                    "Given that slice_boundary_fields is not defined and that per-partition state is not supported, only one slice is "
                    "expected."
                )

            slices.append(
                {
                    # TODO: if we migrate stored state to the concurrent state format, we may want this to be the config start date
                    # instead of zero_value.
                    "start": self._connector_state_converter.zero_value,
                    "end": self._extract_cursor_value(self._most_recent_record),
                }
            )

    def _emit_state_message(self) -> None:
        """Persist the (sequential-format) state and emit it through the message repository."""
        self._connector_state_manager.update_state_for_stream(
            self._stream_name,
            self._stream_namespace,
            self._connector_state_converter.convert_to_sequential_state(self._cursor_field, self.state),
        )
        # TODO: if we migrate stored state to the concurrent state format
        #  (aka stop calling self._connector_state_converter.convert_to_sequential_state`), we'll need to cast datetimes to string or
        #  int before emitting state
        state_message = self._connector_state_manager.create_state_message(
            self._stream_name, self._stream_namespace, send_per_stream_state=True
        )
        self._message_repository.emit_message(state_message)

    def _merge_partitions(self) -> None:
        # Collapse overlapping intervals so the slice list stays bounded over time.
        self.state["slices"] = self._connector_state_converter.merge_intervals(self.state["slices"])

    def _extract_from_slice(self, partition: Partition, key: str) -> Comparable:
        """Read boundary ``key`` from the partition's slice, parsed by the state converter.

        Raises:
            KeyError: if the partition has no slice or the slice lacks ``key``.
        """
        try:
            _slice = partition.to_slice()
            if not _slice:
                raise KeyError(f"Could not find key `{key}` in empty slice")
            return self._connector_state_converter.parse_value(_slice[key])  # type: ignore # we expect the devs to specify a key that would return a Comparable
        except KeyError as exception:
            raise KeyError(f"Partition is expected to have key `{key}` but could not be found") from exception