#
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
#

import logging
from typing import Any, List, Mapping, MutableMapping, Optional, Tuple

import requests
from requests.auth import HTTPBasicAuth

from airbyte_cdk.sources import AbstractSource
from airbyte_cdk.sources.streams import Stream
from airbyte_cdk.sources.streams.http.requests_native_auth import Oauth2Authenticator
from source_tplcentral.streams import Customers, Inventory, Items, Orders, StockDetails, StockSummaries
|
|
|
|
|
|
class TplcentralAuthenticator(Oauth2Authenticator):
    """OAuth2 authenticator that obtains tokens with the ``client_credentials`` grant.

    The token request is a JSON POST to ``token_refresh_endpoint`` with the
    client id and secret sent as HTTP Basic credentials. No refresh token is
    used (``refresh_token=None`` is passed to the base class).
    """

    # Upper bound, in seconds, for the token request. ``requests`` applies no
    # timeout by default, so without this an unresponsive auth server would
    # block the connector indefinitely.
    _TOKEN_REQUEST_TIMEOUT_SECONDS = 60

    def __init__(
        self,
        token_refresh_endpoint: str,
        client_id: str,
        client_secret: str,
        user_login_id: Optional[int] = None,
        user_login: Optional[str] = None,
        scopes: Optional[List[str]] = None,
    ):
        """Store the credentials and optional login/scope values used to build token requests.

        Args:
            token_refresh_endpoint: URL the token request is POSTed to.
            client_id: Username part of the HTTP Basic credentials.
            client_secret: Password part of the HTTP Basic credentials.
            user_login_id: Optional value sent as ``user_login_id`` in the request body.
            user_login: Optional value sent as ``user_login`` in the request body.
            scopes: Optional list sent as ``scopes`` in the request body.
        """
        super().__init__(
            token_refresh_endpoint=token_refresh_endpoint,
            client_id=client_id,
            client_secret=client_secret,
            refresh_token=None,
        )

        self.token_refresh_endpoint = token_refresh_endpoint
        self.client_id = client_id
        self.client_secret = client_secret
        self.scopes = scopes
        # Keys read from the token response in refresh_access_token().
        self.access_token_name = "access_token"
        self.expires_in_name = "expires_in"
        self.user_login_id = user_login_id
        self.user_login = user_login

    def get_refresh_request_body(self) -> Mapping[str, Any]:
        """Build the JSON body of the token request.

        ``grant_type`` is always present; ``scopes``, ``user_login_id`` and
        ``user_login`` are added only when their value is truthy, so unset or
        empty values are left out of the request.
        """
        payload: MutableMapping[str, Any] = {
            "grant_type": "client_credentials",
        }

        if self.scopes:
            payload["scopes"] = self.scopes

        if self.user_login_id:
            payload["user_login_id"] = self.user_login_id

        if self.user_login:
            payload["user_login"] = self.user_login

        return payload

    def refresh_access_token(self) -> Tuple[str, int]:
        """Request a new access token from the token endpoint.

        Returns:
            The ``access_token`` and ``expires_in`` values taken from the JSON response.

        Raises:
            Exception: If the request fails or times out, the server answers
                with an error status, or the response lacks an expected key.
                The underlying error is chained as the cause.
        """
        try:
            response = requests.post(
                self.token_refresh_endpoint,
                auth=HTTPBasicAuth(self.client_id, self.client_secret),
                json=self.get_refresh_request_body(),
                timeout=self._TOKEN_REQUEST_TIMEOUT_SECONDS,
            )
            response.raise_for_status()
            response_json = response.json()
            return response_json[self.access_token_name], response_json[self.expires_in_name]
        except Exception as e:
            raise Exception(f"Error while refreshing access token: {e}") from e
|
|
|
|
|
|
class SourceTplcentral(AbstractSource):
    """Source that authenticates with ``TplcentralAuthenticator`` and exposes the Tplcentral streams."""

    def _auth(self, config: Mapping[str, Any]) -> TplcentralAuthenticator:
        """Build an authenticator from the connector configuration.

        The token endpoint is ``config['url_base']`` with ``AuthServer/api/Token``
        appended directly, so ``url_base`` must already end with a path separator.
        ``user_login_id`` and ``user_login`` are optional config keys.
        """
        return TplcentralAuthenticator(
            token_refresh_endpoint=f"{config['url_base']}AuthServer/api/Token",
            client_id=config["client_id"],
            client_secret=config["client_secret"],
            user_login_id=config.get("user_login_id"),
            user_login=config.get("user_login"),
        )

    def check_connection(self, logger: logging.Logger, config: Mapping[str, Any]) -> Tuple[bool, Any]:
        """Verify the configuration by requesting an auth header.

        Returns:
            ``(True, None)`` when an auth header could be obtained, otherwise
            ``(False, error)`` where ``error`` is the exception that was raised.
        """
        try:
            self._auth(config).get_auth_header()
        except Exception as e:
            # Report failure as an explicit False so the first element is always a bool.
            return False, e

        return True, None

    def streams(self, config: Mapping[str, Any]) -> List[Stream]:
        """Instantiate every stream with the configuration plus an ``authenticator`` entry."""
        # Work on a copy: the caller's mapping may be read-only, and the
        # authenticator object should not leak into the caller's config.
        stream_config = {**config, "authenticator": self._auth(config)}

        return [
            StockSummaries(stream_config),
            Customers(stream_config),
            Items(stream_config),
            StockDetails(stream_config),
            Inventory(stream_config),
            Orders(stream_config),
        ]
|