1
0
mirror of synced 2026-01-20 03:07:18 -05:00
Files
airbyte/airbyte-integrations/connectors/source-square/source_square/source.py
Marcos Marx dca2256a7c Bump 2022 license version (#13233)
* Bump year in license short to 2022

* remove protocol from cdk
2022-05-26 15:00:42 -03:00

478 lines
17 KiB
Python

#
# Copyright (c) 2022 Airbyte, Inc., all rights reserved.
#
import json
from abc import ABC, abstractmethod
from typing import Any, Iterable, List, Mapping, MutableMapping, Optional, Tuple, Union
import pendulum
import requests
from airbyte_cdk.logger import AirbyteLogger
from airbyte_cdk.models import SyncMode
from airbyte_cdk.sources import AbstractSource
from airbyte_cdk.sources.streams import Stream
from airbyte_cdk.sources.streams.http import HttpStream
from airbyte_cdk.sources.streams.http.auth.core import HttpAuthenticator
from airbyte_cdk.sources.streams.http.requests_native_auth import Oauth2Authenticator, TokenAuthenticator
from requests.auth import AuthBase
from source_square.utils import separate_items_by_count
class SquareException(Exception):
"""Just for formatting the exception as Square"""
def __init__(self, status_code, errors):
self.status_code = status_code
self.errors = errors
def __str__(self):
return f"Code: {self.status_code}, Detail: {self.errors}"
def parse_square_error_response(error: requests.exceptions.HTTPError) -> SquareException:
if error.response.content:
content = json.loads(error.response.content.decode())
if content and "errors" in content:
return SquareException(error.response.status_code, content["errors"])
class SquareStream(HttpStream, ABC):
def __init__(
self,
is_sandbox: bool,
api_version: str,
start_date: str,
include_deleted_objects: bool,
authenticator: Union[AuthBase, HttpAuthenticator],
):
super().__init__(authenticator)
self._authenticator = authenticator
self.is_sandbox = is_sandbox
self.api_version = api_version
# Converting users ISO 8601 format (YYYY-MM-DD) to RFC 3339 (2021-06-14T13:47:56.799Z)
# Because this standard is used by square in 'updated_at' records field
self.start_date = pendulum.parse(start_date).to_rfc3339_string()
self.include_deleted_objects = include_deleted_objects
data_field = None
primary_key = "id"
items_per_page_limit = 100
@property
def url_base(self) -> str:
return "https://connect.squareup{}.com/v2/".format("sandbox" if self.is_sandbox else "")
def next_page_token(self, response: requests.Response) -> Optional[Mapping[str, Any]]:
next_page_cursor = response.json().get("cursor", False)
if next_page_cursor:
return {"cursor": next_page_cursor}
def request_headers(
self, stream_state: Mapping[str, Any], stream_slice: Mapping[str, Any] = None, next_page_token: Mapping[str, Any] = None
) -> Mapping[str, Any]:
return {"Square-Version": self.api_version, "Content-Type": "application/json"}
def parse_response(self, response: requests.Response, **kwargs) -> Iterable[Mapping]:
json_response = response.json()
records = json_response.get(self.data_field, []) if self.data_field is not None else json_response
yield from records
def _send_request(self, request: requests.PreparedRequest, request_kwargs: Mapping[str, Any]) -> requests.Response:
try:
return super()._send_request(request, request_kwargs)
except requests.exceptions.HTTPError as e:
square_exception = parse_square_error_response(e)
if square_exception:
self.logger.error(str(square_exception))
raise e
# Some streams require next_page_token in request query parameters (TeamMemberWages, Customers)
# but others in JSON payload (Items, Discounts, Orders, etc)
# That's why this 2 classes SquareStreamPageParam and SquareStreamPageJson are made
class SquareStreamPageParam(SquareStream, ABC):
def request_params(
self, stream_state: Mapping[str, Any], stream_slice: Mapping[str, Any] = None, next_page_token: Mapping[str, Any] = None
) -> MutableMapping[str, Any]:
return {"cursor": next_page_token["cursor"]} if next_page_token else {}
class SquareStreamPageJson(SquareStream, ABC):
def request_body_json(
self, stream_state: Mapping[str, Any], stream_slice: Mapping[str, Any] = None, next_page_token: Mapping[str, Any] = None
) -> Optional[Mapping]:
return {"cursor": next_page_token["cursor"]} if next_page_token else {}
class SquareStreamPageJsonAndLimit(SquareStreamPageJson, ABC):
def request_body_json(
self, stream_state: Mapping[str, Any], stream_slice: Mapping[str, Any] = None, next_page_token: Mapping[str, Any] = None
) -> Optional[Mapping]:
json_payload = {"limit": self.items_per_page_limit}
if next_page_token:
json_payload.update(next_page_token)
return json_payload
class SquareCatalogObjectsStream(SquareStreamPageJson):
data_field = "objects"
http_method = "POST"
items_per_page_limit = 1000
def path(self, **kwargs) -> str:
return "catalog/search"
def request_body_json(
self, stream_state: Mapping[str, Any], stream_slice: Mapping[str, Any] = None, next_page_token: Mapping[str, Any] = None
) -> Optional[Mapping]:
json_payload = super().request_body_json(stream_state, stream_slice, next_page_token)
if self.path() == "catalog/search":
json_payload["include_deleted_objects"] = self.include_deleted_objects
json_payload["include_related_objects"] = False
json_payload["limit"] = self.items_per_page_limit
return json_payload
class IncrementalSquareGenericStream(SquareStream, ABC):
def get_updated_state(self, current_stream_state: MutableMapping[str, Any], latest_record: Mapping[str, Any]) -> Mapping[str, Any]:
if current_stream_state is not None and self.cursor_field in current_stream_state:
return {self.cursor_field: max(current_stream_state[self.cursor_field], latest_record[self.cursor_field])}
else:
return {self.cursor_field: self.start_date}
class IncrementalSquareCatalogObjectsStream(SquareCatalogObjectsStream, IncrementalSquareGenericStream, ABC):
@property
@abstractmethod
def object_type(self):
"""Object type property"""
state_checkpoint_interval = SquareCatalogObjectsStream.items_per_page_limit
cursor_field = "updated_at"
def request_body_json(self, stream_state: Mapping[str, Any], **kwargs) -> Optional[Mapping]:
json_payload = super().request_body_json(stream_state, **kwargs)
if stream_state:
json_payload["begin_time"] = stream_state[self.cursor_field]
json_payload["object_types"] = [self.object_type]
return json_payload
class IncrementalSquareStream(IncrementalSquareGenericStream, SquareStreamPageParam, ABC):
state_checkpoint_interval = SquareStream.items_per_page_limit
cursor_field = "created_at"
def request_params(
self,
stream_state: Mapping[str, Any],
stream_slice: Mapping[str, Any] = None,
next_page_token: Mapping[str, Any] = None,
) -> MutableMapping[str, Any]:
params_payload = super().request_params(stream_state, stream_slice, next_page_token)
if stream_state:
params_payload["begin_time"] = stream_state[self.cursor_field]
params_payload["limit"] = self.items_per_page_limit
return params_payload
class Items(IncrementalSquareCatalogObjectsStream):
"""Docs: https://developer.squareup.com/explorer/square/catalog-api/search-catalog-objects
with object_types = ITEM"""
object_type = "ITEM"
class Categories(IncrementalSquareCatalogObjectsStream):
"""Docs: https://developer.squareup.com/explorer/square/catalog-api/search-catalog-objects
with object_types = CATEGORY"""
object_type = "CATEGORY"
class Discounts(IncrementalSquareCatalogObjectsStream):
"""Docs: https://developer.squareup.com/explorer/square/catalog-api/search-catalog-objects
with object_types = DISCOUNT"""
object_type = "DISCOUNT"
class Taxes(IncrementalSquareCatalogObjectsStream):
"""Docs: https://developer.squareup.com/explorer/square/catalog-api/search-catalog-objects
with object_types = TAX"""
object_type = "TAX"
class ModifierList(IncrementalSquareCatalogObjectsStream):
"""Docs: https://developer.squareup.com/explorer/square/catalog-api/search-catalog-objects
with object_types = MODIFIER_LIST"""
object_type = "MODIFIER_LIST"
class Refunds(IncrementalSquareStream):
"""Docs: https://developer.squareup.com/reference/square_2021-06-16/refunds-api/list-payment-refunds"""
data_field = "refunds"
def path(self, **kwargs) -> str:
return "refunds"
def request_params(self, **kwargs) -> MutableMapping[str, Any]:
params_payload = super().request_params(**kwargs)
params_payload["sort_order"] = "ASC"
return params_payload
class Payments(IncrementalSquareStream):
"""Docs: https://developer.squareup.com/reference/square_2021-06-16/payments-api/list-payments"""
data_field = "payments"
def path(self, **kwargs) -> str:
return "payments"
def request_params(self, **kwargs) -> MutableMapping[str, Any]:
params_payload = super().request_params(**kwargs)
params_payload["sort_order"] = "ASC"
return params_payload
class Locations(SquareStream):
"""Docs: https://developer.squareup.com/explorer/square/locations-api/list-locations"""
data_field = "locations"
def path(self, **kwargs) -> str:
return "locations"
class Shifts(SquareStreamPageJsonAndLimit):
"""Docs: https://developer.squareup.com/reference/square/labor-api/search-shifts"""
data_field = "shifts"
http_method = "POST"
items_per_page_limit = 200
def path(self, **kwargs) -> str:
return "labor/shifts/search"
class TeamMembers(SquareStreamPageJsonAndLimit):
"""Docs: https://developer.squareup.com/reference/square/team-api/search-team-members"""
data_field = "team_members"
http_method = "POST"
def path(self, **kwargs) -> str:
return "team-members/search"
class TeamMemberWages(SquareStreamPageParam):
"""Docs: https://developer.squareup.com/reference/square_2021-06-16/labor-api/list-team-member-wages"""
data_field = "team_member_wages"
items_per_page_limit = 200
def path(self, **kwargs) -> str:
return "labor/team-member-wages"
def request_params(self, **kwargs) -> MutableMapping[str, Any]:
params_payload = super().request_params(**kwargs)
params_payload = params_payload or {}
params_payload["limit"] = self.items_per_page_limit
return params_payload
# This stream is tricky because once in a while it returns 404 error 'Not Found for url'.
# Thus the retry strategy was implemented.
def should_retry(self, response: requests.Response) -> bool:
return response.status_code == 404 or super().should_retry(response)
def backoff_time(self, response: requests.Response) -> Optional[float]:
return 3
class Customers(SquareStreamPageParam):
"""Docs: https://developer.squareup.com/reference/square_2021-06-16/customers-api/list-customers"""
data_field = "customers"
def path(self, **kwargs) -> str:
return "customers"
def request_params(self, **kwargs) -> MutableMapping[str, Any]:
params_payload = super().request_params(**kwargs)
params_payload = params_payload or {}
params_payload["sort_order"] = "ASC"
params_payload["sort_field"] = "CREATED_AT"
return params_payload
class Orders(SquareStreamPageJson):
"""Docs: https://developer.squareup.com/reference/square/orders-api/search-orders"""
data_field = "orders"
http_method = "POST"
items_per_page_limit = 500
# There is a restriction in the documentation where only 10 locations can be send at one request
# https://developer.squareup.com/reference/square/orders-api/search-orders#request__property-location_ids
locations_per_requets = 10
def path(self, **kwargs) -> str:
return "orders/search"
def request_body_json(self, stream_slice: Mapping[str, Any] = None, **kwargs) -> Optional[Mapping]:
json_payload = super().request_body_json(stream_slice=stream_slice, **kwargs)
json_payload = json_payload or {}
if stream_slice:
json_payload.update(stream_slice)
json_payload["limit"] = self.items_per_page_limit
return json_payload
def stream_slices(self, **kwargs) -> Iterable[Optional[Mapping[str, Any]]]:
locations_stream = Locations(
authenticator=self.authenticator,
is_sandbox=self.is_sandbox,
api_version=self.api_version,
start_date=self.start_date,
include_deleted_objects=self.include_deleted_objects,
)
locations_records = locations_stream.read_records(sync_mode=SyncMode.full_refresh)
location_ids = [location["id"] for location in locations_records]
if not location_ids:
self.logger.error(
"No locations found. Orders cannot be extracted without locations. "
"Check https://developer.squareup.com/explorer/square/locations-api/list-locations"
)
yield from []
separated_locations = separate_items_by_count(location_ids, self.locations_per_requets)
for location in separated_locations:
yield {"location_ids": location}
class Oauth2AuthenticatorSquare(Oauth2Authenticator):
def refresh_access_token(self) -> Tuple[str, int]:
"""Handle differences in expiration attr:
from API: "expires_at": "2021-11-05T14:26:57Z"
expected: "expires_in": number of seconds
"""
token, expires_at = super().refresh_access_token()
expires_in = pendulum.parse(expires_at) - pendulum.now()
return token, expires_in.seconds
class SourceSquare(AbstractSource):
api_version = "2021-09-15" # Latest Stable Release
@staticmethod
def get_auth(config: Mapping[str, Any]) -> AuthBase:
credential = config.get("credentials", {})
auth_type = credential.get("auth_type")
if auth_type == "Oauth":
# scopes needed for all currently supported streams:
scopes = [
"CUSTOMERS_READ",
"EMPLOYEES_READ",
"ITEMS_READ",
"MERCHANT_PROFILE_READ",
"ORDERS_READ",
"PAYMENTS_READ",
"TIMECARDS_READ",
# OAuth Permissions:
# https://developer.squareup.com/docs/oauth-api/square-permissions
# https://developer.squareup.com/reference/square/enums/OAuthPermission
# "DISPUTES_READ",
# "GIFTCARDS_READ",
# "INVENTORY_READ",
# "INVOICES_READ",
# "TIMECARDS_SETTINGS_READ",
# "LOYALTY_READ",
# "ONLINE_STORE_SITE_READ",
# "ONLINE_STORE_SNIPPETS_READ",
# "SUBSCRIPTIONS_READ",
]
auth = Oauth2AuthenticatorSquare(
token_refresh_endpoint="https://connect.squareup.com/oauth2/token",
client_secret=credential.get("client_secret"),
client_id=credential.get("client_id"),
refresh_token=credential.get("refresh_token"),
scopes=scopes,
expires_in_name="expires_at",
)
elif auth_type == "Apikey":
auth = TokenAuthenticator(token=credential.get("api_key"))
elif not auth_type and config.get("api_key"):
auth = TokenAuthenticator(token=config.get("api_key"))
else:
raise Exception(f"Invalid auth type: {auth_type}")
return auth
def check_connection(self, logger: AirbyteLogger, config: Mapping[str, Any]) -> Tuple[bool, any]:
headers = {
"Square-Version": self.api_version,
"Content-Type": "application/json",
}
auth = self.get_auth(config)
headers.update(auth.get_auth_header())
url = "https://connect.squareup{}.com/v2/catalog/info".format("sandbox" if config["is_sandbox"] else "")
try:
session = requests.get(url, headers=headers)
session.raise_for_status()
return True, None
except requests.exceptions.RequestException as e:
square_exception = parse_square_error_response(e)
if square_exception:
return False, square_exception.errors[0]["detail"]
return False, e
def streams(self, config: Mapping[str, Any]) -> List[Stream]:
args = {
"authenticator": self.get_auth(config),
"is_sandbox": config["is_sandbox"],
"api_version": self.api_version,
"start_date": config["start_date"],
"include_deleted_objects": config["include_deleted_objects"],
}
return [
Items(**args),
Categories(**args),
Discounts(**args),
Taxes(**args),
Locations(**args),
TeamMembers(**args),
TeamMemberWages(**args),
Refunds(**args),
Payments(**args),
Customers(**args),
ModifierList(**args),
Shifts(**args),
Orders(**args),
]