* Update square stream_slices * Fix unit tests * Version bump and changelogs * Update changelogs * Format code * Update CDK version to fix bug * Fix unit tests * Increase CDK version
92 lines
3.6 KiB
Python
92 lines
3.6 KiB
Python
#
|
|
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
|
|
#
|
|
|
|
import logging
|
|
from dataclasses import dataclass
|
|
from datetime import datetime, timezone
|
|
from typing import Any, Iterable, Mapping, Optional
|
|
|
|
from airbyte_cdk.models import SyncMode
|
|
from airbyte_cdk.sources.declarative.auth import DeclarativeOauth2Authenticator
|
|
from airbyte_cdk.sources.declarative.auth.declarative_authenticator import DeclarativeAuthenticator
|
|
from airbyte_cdk.sources.declarative.auth.token import BearerAuthenticator
|
|
from airbyte_cdk.sources.declarative.incremental import DatetimeBasedCursor
|
|
from airbyte_cdk.sources.declarative.types import StreamSlice, StreamState
|
|
from airbyte_cdk.sources.streams.core import Stream
|
|
|
|
|
|
@dataclass
class AuthenticatorSquare(DeclarativeAuthenticator):
    """Pick the authenticator that matches the credentials found in the config.

    Instantiating this class does not produce an ``AuthenticatorSquare``:
    ``__new__`` hands back one of the two authenticators it was given.
    """

    # Connector configuration; ``credentials.api_key`` decides which authenticator is returned.
    config: Mapping[str, Any]
    # Returned when the configuration carries a truthy API key.
    bearer: BearerAuthenticator
    # Returned in every other case.
    oauth: DeclarativeOauth2Authenticator

    def __new__(cls, bearer, oauth, config, *args, **kwargs):
        """Return ``bearer`` when ``config["credentials"]["api_key"]`` is truthy, else ``oauth``."""
        credentials = config.get("credentials", {})
        return bearer if credentials.get("api_key") else oauth
|
|
|
|
|
|
@dataclass
class SquareSubstreamIncrementalSync(DatetimeBasedCursor):
    """Datetime cursor that slices a substream by batches of parent record keys.

    Each slice carries up to ``parent_records_per_request`` values of
    ``parent_key`` read from ``parent_stream`` and, once a cursor value exists,
    the datetime from which the next request should start.
    """

    # Stream whose records supply the ids placed in each slice.
    parent_stream: Stream = None
    # Field of a parent record whose value is collected into the slice.
    parent_key: str = None
    # Maximum number of parent ids sent in a single slice.
    parent_records_per_request: int = 10

    @property
    def logger(self):
        """Return the logger named after the parent stream."""
        return logging.getLogger(f"airbyte.streams.{self.parent_stream.name}")

    def get_request_body_json(
        self,
        *,
        stream_state: Optional[StreamState] = None,
        stream_slice: Optional[StreamSlice] = None,
        next_page_token: Optional[Mapping[str, Any]] = None,
    ) -> Optional[Mapping]:
        """Build the JSON body of a request for one slice.

        Args:
            stream_state: Accepted for interface compatibility; not read here.
            stream_slice: Slice produced by ``stream_slices``. Its entries are
                copied into the body and its cursor value, when present,
                becomes the ``start_at`` filter.
            next_page_token: Pagination token; its ``"cursor"`` entry is
                forwarded when the token is truthy.

        Returns:
            The request body. Without a slice it holds only the pagination
            cursor (or is empty).
        """
        json_payload = {"cursor": next_page_token["cursor"]} if next_page_token else {}
        if not stream_slice:
            return json_payload

        json_payload.update(stream_slice)
        # Fall back to the configured start datetime when the slice carries no cursor value.
        initial_start_time = self._format_datetime(self.start_datetime.get_datetime(self.config, stream_state={}))
        start_at = stream_slice.get(self.cursor_field.eval(self.config), initial_start_time)
        json_payload["query"] = {
            "filter": {
                "date_time_filter": {
                    "updated_at": {
                        "start_at": start_at,
                    }
                }
            },
            "sort": {"sort_field": "UPDATED_AT", "sort_order": "ASC"},
        }
        return json_payload

    def stream_slices(self) -> Iterable[StreamSlice]:
        """Yield one slice per batch of parent ids.

        All parent records are read with a full refresh first. When none are
        found an error is logged and nothing is yielded. When a cursor value
        is set, each slice also carries it under the cursor field, capped at
        the current UTC time.
        """
        locations_records = self.parent_stream.read_records(sync_mode=SyncMode.full_refresh)
        location_ids = [location[self.parent_key] for location in locations_records]

        if not location_ids:
            self.logger.error(
                "No locations found. Orders cannot be extracted without locations. "
                "Check https://developer.squareup.com/explorer/square/locations-api/list-locations"
            )
            return

        batch_size = self.parent_records_per_request
        separated_locations = [location_ids[i : i + batch_size] for i in range(0, len(location_ids), batch_size)]
        for location in separated_locations:
            stream_slice = {"location_ids": location}
            cursor_field = self.cursor_field.eval(self.config)
            # Evaluated per slice on purpose: the cursor is re-read each time the generator resumes.
            if self._cursor:
                # The Square API throws an error when a datetime is greater than the current time.
                current_datetime = datetime.now(timezone.utc)
                cursor_datetime = self.parse_date(self._cursor)
                # min() keeps the cursor unless it lies strictly in the future.
                stream_slice[cursor_field] = min(cursor_datetime, current_datetime).strftime(self.datetime_format)
            yield stream_slice
|