1
0
mirror of synced 2025-12-30 21:02:43 -05:00
Files
airbyte/airbyte-cdk/python/airbyte_cdk/sources/connector_state_manager.py
Cole Snodgrass 2e099acc52 update headers from 2022 -> 2023 (#22594)
* It's 2023!

* 2022 -> 2023

---------

Co-authored-by: evantahler <evan@airbyte.io>
2023-02-08 13:01:16 -08:00

197 lines
10 KiB
Python

#
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
#
import copy
from typing import Any, List, Mapping, MutableMapping, Optional, Tuple, Union
from airbyte_cdk.models import AirbyteMessage, AirbyteStateBlob, AirbyteStateMessage, AirbyteStateType, AirbyteStreamState, StreamDescriptor
from airbyte_cdk.models import Type as MessageType
from airbyte_cdk.sources.streams import Stream
from pydantic import Extra
class HashableStreamDescriptor(StreamDescriptor):
    """
    Frozen variant of the StreamDescriptor class that is auto generated from the Airbyte Protocol. Its fields are frozen
    so that an instance can be used as a hash key (e.g. as a dictionary key). It is only marked public because it is
    used outside of this module by unit tests.
    """

    class Config:
        # frozen makes the model's fields immutable, which is what allows instances to serve as hash keys
        frozen = True
        extra = Extra.allow
class ConnectorStateManager:
    """
    ConnectorStateManager consolidates the various forms of a stream's incoming state message (STREAM / GLOBAL / LEGACY) under a common
    interface. It also provides methods to extract and update state.

    The consolidated state is stored in ``per_stream_states``, a mutable mapping of HashableStreamDescriptor (stream name + namespace)
    to that stream's state blob.
    """

    def __init__(
        self,
        stream_instance_map: Mapping[str, Stream],
        state: Optional[Union[List[AirbyteStateMessage], MutableMapping[str, Any]]] = None,
    ):
        """
        :param stream_instance_map: A mapping of stream name to stream instance, used to look up a stream's namespace when the
            incoming state is in the legacy format
        :param state: The incoming state, either a list of AirbyteStateMessage or a legacy mapping of stream name to state value.
            None results in an empty per-stream state mapping
        :raises ValueError: If the incoming state is a GLOBAL message that carries a shared_state, or if the state is neither a
            list nor a dict
        """
        shared_state, per_stream_states = self._extract_from_state_message(state, stream_instance_map)

        # We explicitly throw an error if we receive a GLOBAL state message that contains a shared_state because API sources are
        # designed to checkpoint state independently of one another. API sources should never be emitting a state message where
        # shared_state is populated. Rather than define how to handle shared_state without a clear use case, we're opting to throw an
        # error instead and if/when we find one, we will then implement processing of the shared_state value.
        if shared_state:
            raise ValueError(
                "Received a GLOBAL AirbyteStateMessage that contains a shared_state. This library only ever generates per-STREAM "
                "STATE messages so this was not generated by this connector. This must be an orchestrator or platform error. GLOBAL "
                "state messages with shared_state will not be processed correctly. "
            )
        self.per_stream_states = per_stream_states

    def get_stream_state(self, stream_name: str, namespace: Optional[str]) -> Mapping[str, Any]:
        """
        Retrieves the state of a given stream based on its descriptor (name + namespace).
        :param stream_name: Name of the stream being fetched
        :param namespace: Namespace of the stream being fetched
        :return: The per-stream state for a stream, or an empty mapping when no state is stored for the descriptor
        """
        stream_state = self.per_stream_states.get(HashableStreamDescriptor(name=stream_name, namespace=namespace))
        if stream_state:
            return stream_state.dict()
        return {}

    def update_state_for_stream(self, stream_name: str, namespace: Optional[str], value: Mapping[str, Any]) -> None:
        """
        Overwrites the state blob of a specific stream based on the provided stream name and optional namespace
        :param stream_name: The name of the stream whose state is being updated
        :param namespace: The namespace of the stream if it exists
        :param value: A stream state mapping that is being updated for a stream
        """
        stream_descriptor = HashableStreamDescriptor(name=stream_name, namespace=namespace)
        self.per_stream_states[stream_descriptor] = AirbyteStateBlob.parse_obj(value)

    def create_state_message(self, stream_name: str, namespace: Optional[str], send_per_stream_state: bool) -> AirbyteMessage:
        """
        Generates an AirbyteMessage using the current per-stream state of a specified stream in either the per-stream or legacy format
        :param stream_name: The name of the stream for the message that is being created
        :param namespace: The namespace of the stream for the message that is being created
        :param send_per_stream_state: Decides which state format the message should be generated as
        :return: The Airbyte state message to be emitted by the connector during a sync
        """
        # Both message formats carry the legacy (stream name -> state) mapping in the data field.
        legacy_data = dict(self._get_legacy_state())

        if not send_per_stream_state:
            return AirbyteMessage(type=MessageType.STATE, state=AirbyteStateMessage(data=legacy_data))

        hashable_descriptor = HashableStreamDescriptor(name=stream_name, namespace=namespace)
        stream_state = self.per_stream_states.get(hashable_descriptor) or AirbyteStateBlob()

        # According to the Airbyte protocol, the StreamDescriptor namespace field is not required. However, the platform will throw
        # a validation error if it receives namespace=null. That is why if namespace is None, the field should be omitted instead.
        if namespace is None:
            stream_descriptor = StreamDescriptor(name=stream_name)
        else:
            stream_descriptor = StreamDescriptor(name=stream_name, namespace=namespace)

        return AirbyteMessage(
            type=MessageType.STATE,
            state=AirbyteStateMessage(
                type=AirbyteStateType.STREAM,
                stream=AirbyteStreamState(stream_descriptor=stream_descriptor, stream_state=stream_state),
                data=legacy_data,
            ),
        )

    @classmethod
    def _extract_from_state_message(
        cls,
        state: Optional[Union[List[AirbyteStateMessage], MutableMapping[str, Any]]],
        stream_instance_map: Mapping[str, Stream],
    ) -> Tuple[Optional[AirbyteStateBlob], MutableMapping[HashableStreamDescriptor, Optional[AirbyteStateBlob]]]:
        """
        Takes an incoming list of state messages or the legacy state format and extracts state attributes according to type
        which can then be assigned to the new state manager being instantiated
        :param state: The incoming state input
        :param stream_instance_map: A mapping of stream name to stream instance, used to resolve namespaces for legacy state
        :return: A tuple of shared state and per stream state assembled from the incoming state list
        :raises ValueError: If the state is neither a dict nor a list
        """
        if state is None:
            return None, {}

        # The checks below are ordered from most to least specific; the per-stream check accepts any list, so it must come last.

        # Incoming pure legacy object format
        if cls._is_legacy_dict_state(state):
            return None, cls._create_descriptor_to_stream_state_mapping(state, stream_instance_map)

        # When processing incoming state in source.read_state(), legacy state gets deserialized into List[AirbyteStateMessage]
        # which can be translated into independent per-stream state values
        if cls._is_migrated_legacy_state(state):
            return None, cls._create_descriptor_to_stream_state_mapping(state[0].data, stream_instance_map)

        if cls._is_global_state(state):
            global_state = state[0].global_
            shared_state = copy.deepcopy(global_state.shared_state)
            streams = {
                HashableStreamDescriptor(
                    name=per_stream_state.stream_descriptor.name, namespace=per_stream_state.stream_descriptor.namespace
                ): per_stream_state.stream_state
                for per_stream_state in global_state.stream_states
            }
            return shared_state, streams

        if cls._is_per_stream_state(state):
            # Messages that are not of type STREAM, or that carry no stream payload, are skipped. The payload is compared against
            # None (rather than using hasattr) so that a message whose stream attribute exists but is unset is also skipped
            # instead of failing on attribute access.
            streams = {
                HashableStreamDescriptor(
                    name=per_stream_state.stream.stream_descriptor.name, namespace=per_stream_state.stream.stream_descriptor.namespace
                ): per_stream_state.stream.stream_state
                for per_stream_state in state
                if per_stream_state.type == AirbyteStateType.STREAM and getattr(per_stream_state, "stream", None) is not None
            }
            return None, streams

        raise ValueError("Input state should come in the form of list of Airbyte state messages or a mapping of states")

    @staticmethod
    def _create_descriptor_to_stream_state_mapping(
        state: MutableMapping[str, Any], stream_to_instance_map: Mapping[str, Stream]
    ) -> MutableMapping[HashableStreamDescriptor, Optional[AirbyteStateBlob]]:
        """
        Takes incoming state received in the legacy format and transforms it into a mapping of StreamDescriptor to AirbyteStreamState
        :param state: A mapping object representing the complete state of all streams in the legacy format
        :param stream_to_instance_map: A mapping of stream name to stream instance used to retrieve a stream's namespace
        :return: The mapping of all of a sync's streams to the corresponding stream state
        """
        streams = {}
        for stream_name, state_value in state.items():
            # Streams that are not present in the instance map are registered without a namespace
            namespace = stream_to_instance_map[stream_name].namespace if stream_name in stream_to_instance_map else None
            stream_descriptor = HashableStreamDescriptor(name=stream_name, namespace=namespace)
            # A falsy state value (e.g. None) is stored as an empty state blob
            streams[stream_descriptor] = AirbyteStateBlob.parse_obj(state_value or {})
        return streams

    def _get_legacy_state(self) -> Mapping[str, Any]:
        """
        Using the current per-stream state, creates a mapping of all the stream states for the connector being synced.
        The mapping is keyed by stream name only, so streams that share a name across namespaces map to a single entry.
        :return: A newly built mapping of stream name to stream state value
        """
        return {descriptor.name: state.dict() if state else {} for descriptor, state in self.per_stream_states.items()}

    @staticmethod
    def _is_legacy_dict_state(state: Union[List[AirbyteStateMessage], MutableMapping[str, Any]]) -> bool:
        """Returns True when the state is a plain dict, i.e. the pure legacy object format."""
        return isinstance(state, dict)

    @staticmethod
    def _is_migrated_legacy_state(state: Union[List[AirbyteStateMessage], MutableMapping[str, Any]]) -> bool:
        """Returns True when the state is a list holding exactly one AirbyteStateMessage of type LEGACY."""
        return (
            isinstance(state, list)
            and len(state) == 1
            and isinstance(state[0], AirbyteStateMessage)
            and state[0].type == AirbyteStateType.LEGACY
        )

    @staticmethod
    def _is_global_state(state: Union[List[AirbyteStateMessage], MutableMapping[str, Any]]) -> bool:
        """Returns True when the state is a list holding exactly one AirbyteStateMessage of type GLOBAL."""
        return (
            isinstance(state, list)
            and len(state) == 1
            and isinstance(state[0], AirbyteStateMessage)
            and state[0].type == AirbyteStateType.GLOBAL
        )

    @staticmethod
    def _is_per_stream_state(state: Union[List[AirbyteStateMessage], MutableMapping[str, Any]]) -> bool:
        """Returns True for any list; callers must rule out the more specific single-message shapes first."""
        return isinstance(state, list)