# airbyte/airbyte-cdk/python/airbyte_cdk/sources/streams/concurrent/cursor.py

#
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
#

import functools
from abc import ABC, abstractmethod
from typing import Any, List, Mapping, Optional, Protocol, Tuple

from airbyte_cdk.sources.connector_state_manager import ConnectorStateManager
from airbyte_cdk.sources.message import MessageRepository
from airbyte_cdk.sources.streams.concurrent.partitions.partition import Partition
from airbyte_cdk.sources.streams.concurrent.partitions.record import Record
from airbyte_cdk.sources.streams.concurrent.state_converter import ConcurrentStreamStateConverter


def _extract_value(mapping: Mapping[str, Any], path: List[str]) -> Any:
    return functools.reduce(lambda a, b: a[b], path, mapping)


class Comparable(Protocol):
    """Protocol for annotating comparable types."""

    @abstractmethod
    def __lt__(self: "Comparable", other: "Comparable") -> bool:
        pass


class CursorField:
    def __init__(self, cursor_field_key: str) -> None:
        self._cursor_field_key = cursor_field_key

    def extract_value(self, record: Record) -> Comparable:
        cursor_value = record.data.get(self._cursor_field_key)
        if cursor_value is None:
            raise ValueError(f"Could not find cursor field {self._cursor_field_key} in record")
        return cursor_value  # type: ignore  # we assume that the value the path points at is a comparable
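
    # Example (illustrative, not from the original module): for a record whose data
    # is {"updated_at": "2023-11-01T00:00:00Z"}, CursorField("updated_at").extract_value(record)
    # returns that timestamp; a missing or null cursor field raises ValueError.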


class Cursor(ABC):
    @abstractmethod
    def observe(self, record: Record) -> None:
        """
        Indicate to the cursor that the record has been emitted
        """
        raise NotImplementedError()

    @abstractmethod
    def close_partition(self, partition: Partition) -> None:
        """
        Indicate to the cursor that the partition has been successfully processed
        """
        raise NotImplementedError()


class NoopCursor(Cursor):
    def observe(self, record: Record) -> None:
        pass

    def close_partition(self, partition: Partition) -> None:
        pass


class ConcurrentCursor(Cursor):
    """
    Cursor for concurrent syncs: it keeps the stream state as a list of closed slice
    intervals, merges those intervals through the state converter, and emits a state
    message whenever closing a partition adds a new slice.
    """

    _START_BOUNDARY = 0
    _END_BOUNDARY = 1

    def __init__(
        self,
        stream_name: str,
        stream_namespace: Optional[str],
        stream_state: Any,
        message_repository: MessageRepository,
        connector_state_manager: ConnectorStateManager,
        connector_state_converter: ConcurrentStreamStateConverter,
        cursor_field: CursorField,
        slice_boundary_fields: Optional[Tuple[str, str]],
    ) -> None:
        self._stream_name = stream_name
        self._stream_namespace = stream_namespace
        self._message_repository = message_repository
        self._connector_state_converter = connector_state_converter
        self._connector_state_manager = connector_state_manager
        self._cursor_field = cursor_field
        # To see an example where the slice boundaries might not be defined, check https://github.com/airbytehq/airbyte/blob/1ce84d6396e446e1ac2377362446e3fb94509461/airbyte-integrations/connectors/source-stripe/source_stripe/streams.py#L363-L379
        self._slice_boundary_fields = slice_boundary_fields if slice_boundary_fields else tuple()
        self._most_recent_record: Optional[Record] = None
        self._has_closed_at_least_one_slice = False
        self._state = connector_state_converter.get_concurrent_stream_state(stream_state)

    def observe(self, record: Record) -> None:
        if self._slice_boundary_fields:
            # Given that slicing is done using the cursor field, we don't need to observe the record as we assume slices will
            # describe what has been emitted. If there is a chance that records are not yet populated for the most recent slice,
            # use a lookback window.
            return
        if not self._most_recent_record or self._extract_cursor_value(self._most_recent_record) < self._extract_cursor_value(record):
            self._most_recent_record = record

    def _extract_cursor_value(self, record: Record) -> Comparable:
        return self._cursor_field.extract_value(record)

    def close_partition(self, partition: Partition) -> None:
        slice_count_before = len(self._state["slices"])
        self._add_slice_to_state(partition)
        if slice_count_before < len(self._state["slices"]):  # only emit if at least one slice was added to the state
            self._merge_partitions()
            self._emit_state_message()
        self._has_closed_at_least_one_slice = True

    def _add_slice_to_state(self, partition: Partition) -> None:
        if self._slice_boundary_fields:
            if "slices" not in self._state:
                self._state["slices"] = []
            self._state["slices"].append(
                {
                    "start": self._extract_from_slice(partition, self._slice_boundary_fields[self._START_BOUNDARY]),
                    "end": self._extract_from_slice(partition, self._slice_boundary_fields[self._END_BOUNDARY]),
                }
            )
        elif self._most_recent_record:
            if self._has_closed_at_least_one_slice:
                raise ValueError(
                    "Given that slice_boundary_fields is not defined and that per-partition state is not supported, only one slice is "
                    "expected."
                )
            self._state["slices"].append(
                {
                    "start": 0,  # FIXME this only works with int datetime
                    "end": self._extract_cursor_value(self._most_recent_record),
                }
            )

    def _emit_state_message(self) -> None:
        self._connector_state_manager.update_state_for_stream(self._stream_name, self._stream_namespace, self._state)
        state_message = self._connector_state_manager.create_state_message(
            self._stream_name, self._stream_namespace, send_per_stream_state=True
        )
        self._message_repository.emit_message(state_message)

    def _merge_partitions(self) -> None:
        self._state["slices"] = self._connector_state_converter.merge_intervals(self._state["slices"])

    def _extract_from_slice(self, partition: Partition, key: str) -> Comparable:
        try:
            _slice = partition.to_slice()
            if not _slice:
                raise KeyError(f"Could not find key `{key}` in empty slice")
            return _slice[key]  # type: ignore  # we expect the devs to specify a key that would return a Comparable
        except KeyError as exception:
            raise KeyError(f"Partition is expected to have key `{key}` but could not be found") from exception