From 374e265fcbdfcef41a77bfcd7ec2acd2450d181d Mon Sep 17 00:00:00 2001 From: Brian Lai <51336873+brianjlai@users.noreply.github.com> Date: Fri, 8 Jul 2022 16:49:16 -0400 Subject: [PATCH] [Low Code CDK] configurable oauth request payload (#13993) * configurable oauth request payload * support interpolation for dictionaries that are not new subcomponents * rewrite a declarative oauth authenticator that performs interpolation at runtime * formatting * whatever i don't know why factory gets flagged w/ the newline change * we java now * remove duplicate oauth * add some comments * parse time properly from string interpolation * move declarative oauth to its own package in declarative module * add changelog info --- airbyte-cdk/python/CHANGELOG.md | 3 + .../sources/declarative/auth/__init__.py | 9 ++ .../sources/declarative/auth/oauth.py | 137 ++++++++++++++++++ .../sources/declarative/parsers/factory.py | 1 - .../http/requests_native_auth/__init__.py | 3 +- .../requests_native_auth/abstract_oauth.py | 129 +++++++++++++++++ .../http/requests_native_auth/oauth.py | 127 ++++++++++------ airbyte-cdk/python/setup.py | 2 +- .../sources/declarative/auth/__init__.py | 3 + .../sources/declarative/auth/test_oauth.py | 90 ++++++++++++ .../sources/declarative/test_factory.py | 18 ++- .../test_requests_native_auth.py | 46 ++++-- 12 files changed, 498 insertions(+), 70 deletions(-) create mode 100644 airbyte-cdk/python/airbyte_cdk/sources/declarative/auth/__init__.py create mode 100644 airbyte-cdk/python/airbyte_cdk/sources/declarative/auth/oauth.py create mode 100644 airbyte-cdk/python/airbyte_cdk/sources/streams/http/requests_native_auth/abstract_oauth.py create mode 100644 airbyte-cdk/python/unit_tests/sources/declarative/auth/__init__.py create mode 100644 airbyte-cdk/python/unit_tests/sources/declarative/auth/test_oauth.py diff --git a/airbyte-cdk/python/CHANGELOG.md b/airbyte-cdk/python/CHANGELOG.md index 026db2b453c..388a7d688b3 100644 --- a/airbyte-cdk/python/CHANGELOG.md +++ b/airbyte-cdk/python/CHANGELOG.md @@ -1,5 +1,8 @@ # Changelog +## 0.1.64 +- Add support for configurable oauth request payload and declarative oauth authenticator. + ## 0.1.63 - Define `namespace` property on the `Stream` class inside `core.py`. diff --git a/airbyte-cdk/python/airbyte_cdk/sources/declarative/auth/__init__.py b/airbyte-cdk/python/airbyte_cdk/sources/declarative/auth/__init__.py new file mode 100644 index 00000000000..eb3b84dde72 --- /dev/null +++ b/airbyte-cdk/python/airbyte_cdk/sources/declarative/auth/__init__.py @@ -0,0 +1,9 @@ +# +# Copyright (c) 2022 Airbyte, Inc., all rights reserved. +# + +from airbyte_cdk.sources.declarative.auth.oauth import DeclarativeOauth2Authenticator + +__all__ = [ + "DeclarativeOauth2Authenticator", +] diff --git a/airbyte-cdk/python/airbyte_cdk/sources/declarative/auth/oauth.py b/airbyte-cdk/python/airbyte_cdk/sources/declarative/auth/oauth.py new file mode 100644 index 00000000000..95cc5f162b3 --- /dev/null +++ b/airbyte-cdk/python/airbyte_cdk/sources/declarative/auth/oauth.py @@ -0,0 +1,137 @@ +# +# Copyright (c) 2022 Airbyte, Inc., all rights reserved. +# + +from typing import Any, List, Mapping + +import pendulum +from airbyte_cdk.sources.declarative.interpolation.interpolated_mapping import InterpolatedMapping +from airbyte_cdk.sources.declarative.interpolation.interpolated_string import InterpolatedString +from airbyte_cdk.sources.streams.http.requests_native_auth.abstract_oauth import AbstractOauth2Authenticator + + +class DeclarativeOauth2Authenticator(AbstractOauth2Authenticator): + """ + Generates OAuth2.0 access tokens from an OAuth2.0 refresh token and client credentials based on + a declarative connector configuration file. Credentials can be defined explicitly or via interpolation + at runtime. The generated access token is attached to each request via the Authorization header. + """ + + def __init__( + self, + token_refresh_endpoint: str, + client_id: str, + client_secret: str, + refresh_token: str, + config: Mapping[str, Any], + scopes: List[str] = None, + token_expiry_date: str = None, + access_token_name: str = "access_token", + expires_in_name: str = "expires_in", + refresh_request_body: Mapping[str, Any] = None, + ): + self.config = config + self.token_refresh_endpoint = InterpolatedString(token_refresh_endpoint) + self.client_secret = InterpolatedString(client_secret) + self.client_id = InterpolatedString(client_id) + self.refresh_token = InterpolatedString(refresh_token) + self.scopes = scopes + self.access_token_name = InterpolatedString(access_token_name) + self.expires_in_name = InterpolatedString(expires_in_name) + self.refresh_request_body = InterpolatedMapping(refresh_request_body) + + self.token_expiry_date = ( + pendulum.parse(InterpolatedString(token_expiry_date).eval(self.config)) + if token_expiry_date + else pendulum.now().subtract(days=1) + ) + self.access_token = None + + @property + def config(self) -> Mapping[str, Any]: + return self._config + + @config.setter + def config(self, value: Mapping[str, Any]): + self._config = value + + @property + def token_refresh_endpoint(self) -> InterpolatedString: + get_some = self._token_refresh_endpoint.eval(self.config) + return get_some + + @token_refresh_endpoint.setter + def token_refresh_endpoint(self, value: InterpolatedString): + self._token_refresh_endpoint = value + + @property + def client_id(self) -> InterpolatedString: + return self._client_id.eval(self.config) + + @client_id.setter + def client_id(self, value: InterpolatedString): + self._client_id = value + + @property + def client_secret(self) -> InterpolatedString: + return self._client_secret.eval(self.config) + + @client_secret.setter + def client_secret(self, value: InterpolatedString): + self._client_secret = value + + @property + def refresh_token(self) -> InterpolatedString: + return self._refresh_token.eval(self.config) + + @refresh_token.setter + def refresh_token(self, value: InterpolatedString): + self._refresh_token = value + + @property + def scopes(self) -> [str]: + return self._scopes + + @scopes.setter + def scopes(self, value: [str]): + self._scopes = value + + @property + def token_expiry_date(self) -> pendulum.DateTime: + return self._token_expiry_date + + @token_expiry_date.setter + def token_expiry_date(self, value: pendulum.DateTime): + self._token_expiry_date = value + + @property + def access_token_name(self) -> InterpolatedString: + return self._access_token_name.eval(self.config) + + @access_token_name.setter + def access_token_name(self, value: InterpolatedString): + self._access_token_name = value + + @property + def expires_in_name(self) -> InterpolatedString: + return self._expires_in_name.eval(self.config) + + @expires_in_name.setter + def expires_in_name(self, value: InterpolatedString): + self._expires_in_name = value + + @property + def refresh_request_body(self) -> InterpolatedMapping: + return self._refresh_request_body.eval(self.config) + + @refresh_request_body.setter + def refresh_request_body(self, value: InterpolatedMapping): + self._refresh_request_body = value + + @property + def access_token(self) -> str: + return self._access_token + + @access_token.setter + def access_token(self, value: str): + self._access_token = value diff --git a/airbyte-cdk/python/airbyte_cdk/sources/declarative/parsers/factory.py b/airbyte-cdk/python/airbyte_cdk/sources/declarative/parsers/factory.py index 4d66f96317b..1ca2eef42d3 100644 --- a/airbyte-cdk/python/airbyte_cdk/sources/declarative/parsers/factory.py +++ b/airbyte-cdk/python/airbyte_cdk/sources/declarative/parsers/factory.py @@ -67,7 +67,6 @@ class DeclarativeComponentFactory: if self.is_object_definition_with_class_name(definition): # propagate kwargs to inner objects definition["options"] = self._merge_dicts(kwargs.get("options", dict()), definition.get("options", dict())) - return self.create_component(definition, config)() elif self.is_object_definition_with_type(definition): # If type is set instead of class_name, get the class_name from the CLASS_TYPES_REGISTRY diff --git a/airbyte-cdk/python/airbyte_cdk/sources/streams/http/requests_native_auth/__init__.py b/airbyte-cdk/python/airbyte_cdk/sources/streams/http/requests_native_auth/__init__.py index 08bdcd068be..26e6f60099a 100644 --- a/airbyte-cdk/python/airbyte_cdk/sources/streams/http/requests_native_auth/__init__.py +++ b/airbyte-cdk/python/airbyte_cdk/sources/streams/http/requests_native_auth/__init__.py @@ -1,7 +1,6 @@ # -# Copyright (c) 2021 Airbyte, Inc., all rights reserved. +# Copyright (c) 2022 Airbyte, Inc., all rights reserved. # - from .oauth import Oauth2Authenticator from .token import MultipleTokenAuthenticator, TokenAuthenticator diff --git a/airbyte-cdk/python/airbyte_cdk/sources/streams/http/requests_native_auth/abstract_oauth.py b/airbyte-cdk/python/airbyte_cdk/sources/streams/http/requests_native_auth/abstract_oauth.py new file mode 100644 index 00000000000..75afb711425 --- /dev/null +++ b/airbyte-cdk/python/airbyte_cdk/sources/streams/http/requests_native_auth/abstract_oauth.py @@ -0,0 +1,129 @@ +# +# Copyright (c) 2022 Airbyte, Inc., all rights reserved. +# + +from abc import abstractmethod +from typing import Any, Mapping, MutableMapping, Tuple + +import pendulum +import requests +from requests.auth import AuthBase + + +class AbstractOauth2Authenticator(AuthBase): + """ + Abstract class for an OAuth authenticators that implements the OAuth token refresh flow. The authenticator + is designed to generically perform the refresh flow without regard to how config fields are get/set by + delegating that behavior to the classes implementing the interface. + """ + + def __call__(self, request): + request.headers.update(self.get_auth_header()) + return request + + def get_auth_header(self) -> Mapping[str, Any]: + return {"Authorization": f"Bearer {self.get_access_token()}"} + + def get_access_token(self): + if self.token_has_expired(): + t0 = pendulum.now() + token, expires_in = self.refresh_access_token() + self.access_token = token + self.token_expiry_date = t0.add(seconds=expires_in) + + return self.access_token + + def token_has_expired(self) -> bool: + return pendulum.now() > self.token_expiry_date + + def get_refresh_request_body(self) -> Mapping[str, Any]: + """Override to define additional parameters""" + payload: MutableMapping[str, Any] = { + "grant_type": "refresh_token", + "client_id": self.client_id, + "client_secret": self.client_secret, + "refresh_token": self.refresh_token, + } + + if self.scopes: + payload["scopes"] = self.scopes + + if self.refresh_request_body: + for key, val in self.refresh_request_body.items(): + # We defer to existing oauth constructs over custom configured fields + if key not in payload: + payload[key] = val + + return payload + + def refresh_access_token(self) -> Tuple[str, int]: + """ + returns a tuple of (access_token, token_lifespan_in_seconds) + """ + try: + response = requests.request(method="POST", url=self.token_refresh_endpoint, data=self.get_refresh_request_body()) + response.raise_for_status() + response_json = response.json() + return response_json[self.access_token_name], response_json[self.expires_in_name] + except Exception as e: + raise Exception(f"Error while refreshing access token: {e}") from e + + @property + @abstractmethod + def token_refresh_endpoint(self): + pass + + @property + @abstractmethod + def client_id(self): + pass + + @property + @abstractmethod + def client_secret(self): + pass + + @property + @abstractmethod + def refresh_token(self): + pass + + @property + @abstractmethod + def scopes(self): + pass + + @property + @abstractmethod + def token_expiry_date(self): + pass + + @token_expiry_date.setter + @abstractmethod + def token_expiry_date(self, value): + pass + + @property + @abstractmethod + def access_token_name(self): + pass + + @property + @abstractmethod + def expires_in_name(self): + pass + + @property + @abstractmethod + def refresh_request_body(self): + pass + + @property + @abstractmethod + def access_token(self): + pass + + @access_token.setter + @abstractmethod + def access_token(self, value): + pass diff --git a/airbyte-cdk/python/airbyte_cdk/sources/streams/http/requests_native_auth/oauth.py b/airbyte-cdk/python/airbyte_cdk/sources/streams/http/requests_native_auth/oauth.py index 84fed4ab977..ec37f436b6e 100644 --- a/airbyte-cdk/python/airbyte_cdk/sources/streams/http/requests_native_auth/oauth.py +++ b/airbyte-cdk/python/airbyte_cdk/sources/streams/http/requests_native_auth/oauth.py @@ -2,15 +2,13 @@ # Copyright (c) 2022 Airbyte, Inc., all rights reserved. # - -from typing import Any, List, Mapping, MutableMapping, Tuple +from typing import Any, List, Mapping import pendulum -import requests -from requests.auth import AuthBase +from airbyte_cdk.sources.streams.http.requests_native_auth.abstract_oauth import AbstractOauth2Authenticator -class Oauth2Authenticator(AuthBase): +class Oauth2Authenticator(AbstractOauth2Authenticator): """ Generates OAuth2.0 access tokens from an OAuth2.0 refresh token and client credentials. The generated access token is attached to each request via the Authorization header. @@ -26,6 +24,7 @@ class Oauth2Authenticator(AuthBase): token_expiry_date: pendulum.DateTime = None, access_token_name: str = "access_token", expires_in_name: str = "expires_in", + refresh_request_body: Mapping[str, Any] = None, ): self.token_refresh_endpoint = token_refresh_endpoint self.client_secret = client_secret @@ -34,51 +33,87 @@ class Oauth2Authenticator(AuthBase): self.scopes = scopes self.access_token_name = access_token_name self.expires_in_name = expires_in_name + self.refresh_request_body = refresh_request_body - self._token_expiry_date = token_expiry_date or pendulum.now().subtract(days=1) - self._access_token = None + self.token_expiry_date = token_expiry_date or pendulum.now().subtract(days=1) + self.access_token = None - def __call__(self, request): - request.headers.update(self.get_auth_header()) - return request + @property + def token_refresh_endpoint(self) -> str: + return self._token_refresh_endpoint - def get_auth_header(self) -> Mapping[str, Any]: - return {"Authorization": f"Bearer {self.get_access_token()}"} + @token_refresh_endpoint.setter + def token_refresh_endpoint(self, value: str): + self._token_refresh_endpoint = value - def get_access_token(self): - if self.token_has_expired(): - t0 = pendulum.now() - token, expires_in = self.refresh_access_token() - self._access_token = token - self._token_expiry_date = t0.add(seconds=expires_in) + @property + def client_id(self) -> str: + return self._client_id + @client_id.setter + def client_id(self, value: str): + self._client_id = value + + @property + def client_secret(self) -> str: + return self._client_secret + + @client_secret.setter + def client_secret(self, value: str): + self._client_secret = value + + @property + def refresh_token(self) -> str: + return self._refresh_token + + @refresh_token.setter + def refresh_token(self, value: str): + self._refresh_token = value + + @property + def access_token_name(self) -> str: + return self._access_token_name + + @access_token_name.setter + def access_token_name(self, value: str): + self._access_token_name = value + + @property + def scopes(self) -> [str]: + return self._scopes + + @scopes.setter + def scopes(self, value: [str]): + self._scopes = value + + @property + def token_expiry_date(self) -> pendulum.DateTime: + return self._token_expiry_date + + @token_expiry_date.setter + def token_expiry_date(self, value: pendulum.DateTime): + self._token_expiry_date = value + + @property + def expires_in_name(self) -> str: + return self._expires_in_name + + @expires_in_name.setter + def expires_in_name(self, value): + self._expires_in_name = value + + @property + def refresh_request_body(self) -> Mapping[str, Any]: + return self._refresh_request_body + + @refresh_request_body.setter + def refresh_request_body(self, value: Mapping[str, Any]): + self._refresh_request_body = value + + @property + def access_token(self) -> str: return self._access_token - def token_has_expired(self) -> bool: - return pendulum.now() > self._token_expiry_date - - def get_refresh_request_body(self) -> Mapping[str, Any]: - """Override to define additional parameters""" - payload: MutableMapping[str, Any] = { - "grant_type": "refresh_token", - "client_id": self.client_id, - "client_secret": self.client_secret, - "refresh_token": self.refresh_token, - } - - if self.scopes: - payload["scopes"] = self.scopes - - return payload - - def refresh_access_token(self) -> Tuple[str, int]: - """ - returns a tuple of (access_token, token_lifespan_in_seconds) - """ - try: - response = requests.request(method="POST", url=self.token_refresh_endpoint, data=self.get_refresh_request_body()) - response.raise_for_status() - response_json = response.json() - return response_json[self.access_token_name], response_json[self.expires_in_name] - except Exception as e: - raise Exception(f"Error while refreshing access token: {e}") from e + @access_token.setter + def access_token(self, value: str): + self._access_token = value diff --git a/airbyte-cdk/python/setup.py b/airbyte-cdk/python/setup.py index 64f8a66f748..fcdfa0270e5 100644 --- a/airbyte-cdk/python/setup.py +++ b/airbyte-cdk/python/setup.py @@ -15,7 +15,7 @@ README = (HERE / "README.md").read_text() setup( name="airbyte-cdk", - version="0.1.63", + version="0.1.64", description="A framework for writing Airbyte Connectors.", long_description=README, long_description_content_type="text/markdown", diff --git a/airbyte-cdk/python/unit_tests/sources/declarative/auth/__init__.py b/airbyte-cdk/python/unit_tests/sources/declarative/auth/__init__.py new file mode 100644 index 00000000000..1100c1c58cf --- /dev/null +++ b/airbyte-cdk/python/unit_tests/sources/declarative/auth/__init__.py @@ -0,0 +1,3 @@ +# +# Copyright (c) 2022 Airbyte, Inc., all rights reserved. +# diff --git a/airbyte-cdk/python/unit_tests/sources/declarative/auth/test_oauth.py b/airbyte-cdk/python/unit_tests/sources/declarative/auth/test_oauth.py new file mode 100644 index 00000000000..2d0c1d265fa --- /dev/null +++ b/airbyte-cdk/python/unit_tests/sources/declarative/auth/test_oauth.py @@ -0,0 +1,90 @@ +# +# Copyright (c) 2022 Airbyte, Inc., all rights reserved. +# + +import logging + +import pendulum +import requests +from airbyte_cdk.sources.declarative.auth import DeclarativeOauth2Authenticator +from requests import Response + +LOGGER = logging.getLogger(__name__) + +resp = Response() + +config = { + "refresh_endpoint": "refresh_end", + "client_id": "some_client_id", + "client_secret": "some_client_secret", + "refresh_token": "some_refresh_token", + "token_expiry_date": pendulum.now().subtract(days=2).to_rfc3339_string(), + "custom_field": "in_outbound_request", + "another_field": "exists_in_body", +} + + +class TestOauth2Authenticator: + """ + Test class for OAuth2Authenticator. + """ + + def test_refresh_request_body(self): + """ + Request body should match given configuration. + """ + scopes = ["scope1", "scope2"] + oauth = DeclarativeOauth2Authenticator( + token_refresh_endpoint="{{ config['refresh_endpoint'] }}", + client_id="{{ config['client_id'] }}", + client_secret="{{ config['client_secret'] }}", + refresh_token="{{ config['refresh_token'] }}", + config=config, + scopes=["scope1", "scope2"], + token_expiry_date="{{ config['token_expiry_date'] }}", + refresh_request_body={ + "custom_field": "{{ config['custom_field'] }}", + "another_field": "{{ config['another_field'] }}", + "scopes": ["no_override"], + }, + ) + body = oauth.get_refresh_request_body() + expected = { + "grant_type": "refresh_token", + "client_id": "some_client_id", + "client_secret": "some_client_secret", + "refresh_token": "some_refresh_token", + "scopes": scopes, + "custom_field": "in_outbound_request", + "another_field": "exists_in_body", + } + assert body == expected + + def test_refresh_access_token(self, mocker): + oauth = DeclarativeOauth2Authenticator( + token_refresh_endpoint="{{ config['refresh_endpoint'] }}", + client_id="{{ config['client_id'] }}", + client_secret="{{ config['client_secret'] }}", + refresh_token="{{ config['refresh_token'] }}", + config=config, + scopes=["scope1", "scope2"], + token_expiry_date="{{ config['token_expiry_date'] }}", + refresh_request_body={ + "custom_field": "{{ config['custom_field'] }}", + "another_field": "{{ config['another_field'] }}", + "scopes": ["no_override"], + }, + ) + + resp.status_code = 200 + mocker.patch.object(resp, "json", return_value={"access_token": "access_token", "expires_in": 1000}) + mocker.patch.object(requests, "request", side_effect=mock_request, autospec=True) + token = oauth.refresh_access_token() + + assert ("access_token", 1000) == token + + +def mock_request(method, url, data): + if url == "refresh_end": + return resp + raise Exception(f"Error while refreshing access token with request: {method}, {url}, {data}") diff --git a/airbyte-cdk/python/unit_tests/sources/declarative/test_factory.py b/airbyte-cdk/python/unit_tests/sources/declarative/test_factory.py index a23ab74369e..29d77e7eea5 100644 --- a/airbyte-cdk/python/unit_tests/sources/declarative/test_factory.py +++ b/airbyte-cdk/python/unit_tests/sources/declarative/test_factory.py @@ -53,13 +53,23 @@ def test_factory(): def test_interpolate_config(): content = """ -authenticator: - class_name: airbyte_cdk.sources.streams.http.requests_native_auth.token.TokenAuthenticator - token: "{{ config['apikey'] }}" + authenticator: + class_name: airbyte_cdk.sources.declarative.auth.oauth.DeclarativeOauth2Authenticator + client_id: "some_client_id" + client_secret: "some_client_secret" + token_refresh_endpoint: "https://api.sendgrid.com/v3/auth" + refresh_token: "{{ config['apikey'] }}" + refresh_request_body: + body_field: "yoyoyo" + interpolated_body_field: "{{ config['apikey'] }}" """ config = parser.parse(content) authenticator = factory.create_component(config["authenticator"], input_config)() - assert authenticator._tokens == ["verysecrettoken"] + assert authenticator._client_id._string == "some_client_id" + assert authenticator._client_secret._string == "some_client_secret" + assert authenticator._token_refresh_endpoint._string == "https://api.sendgrid.com/v3/auth" + assert authenticator._refresh_token._string == "verysecrettoken" + assert authenticator._refresh_request_body._mapping == {"body_field": "yoyoyo", "interpolated_body_field": "{{ config['apikey'] }}"} def test_list_based_stream_slicer_with_values_refd(): diff --git a/airbyte-cdk/python/unit_tests/sources/streams/http/requests_native_auth/test_requests_native_auth.py b/airbyte-cdk/python/unit_tests/sources/streams/http/requests_native_auth/test_requests_native_auth.py index b4e737c19e8..5185950f4a8 100644 --- a/airbyte-cdk/python/unit_tests/sources/streams/http/requests_native_auth/test_requests_native_auth.py +++ b/airbyte-cdk/python/unit_tests/sources/streams/http/requests_native_auth/test_requests_native_auth.py @@ -2,15 +2,17 @@ # Copyright (c) 2022 Airbyte, Inc., all rights reserved. # - import logging +import pendulum import requests from airbyte_cdk.sources.streams.http.requests_native_auth import MultipleTokenAuthenticator, Oauth2Authenticator, TokenAuthenticator from requests import Response LOGGER = logging.getLogger(__name__) +resp = Response() + def test_token_authenticator(): """ @@ -96,34 +98,40 @@ class TestOauth2Authenticator: """ scopes = ["scope1", "scope2"] oauth = Oauth2Authenticator( - token_refresh_endpoint=TestOauth2Authenticator.refresh_endpoint, - client_id=TestOauth2Authenticator.client_id, - client_secret=TestOauth2Authenticator.client_secret, - refresh_token=TestOauth2Authenticator.refresh_token, - scopes=scopes, + token_refresh_endpoint="refresh_end", + client_id="some_client_id", + client_secret="some_client_secret", + refresh_token="some_refresh_token", + scopes=["scope1", "scope2"], + token_expiry_date=pendulum.now().add(days=3), + refresh_request_body={"custom_field": "in_outbound_request", "another_field": "exists_in_body", "scopes": ["no_override"]}, ) body = oauth.get_refresh_request_body() expected = { "grant_type": "refresh_token", - "client_id": "client_id", - "client_secret": "client_secret", - "refresh_token": "refresh_token", + "client_id": "some_client_id", + "client_secret": "some_client_secret", + "refresh_token": "some_refresh_token", "scopes": scopes, + "custom_field": "in_outbound_request", + "another_field": "exists_in_body", } assert body == expected def test_refresh_access_token(self, mocker): oauth = Oauth2Authenticator( - token_refresh_endpoint=TestOauth2Authenticator.refresh_endpoint, - client_id=TestOauth2Authenticator.client_id, - client_secret=TestOauth2Authenticator.client_secret, - refresh_token=TestOauth2Authenticator.refresh_token, + token_refresh_endpoint="refresh_end", + client_id="some_client_id", + client_secret="some_client_secret", + refresh_token="some_refresh_token", + scopes=["scope1", "scope2"], + token_expiry_date=pendulum.now().add(days=3), + refresh_request_body={"custom_field": "in_outbound_request", "another_field": "exists_in_body", "scopes": ["no_override"]}, ) - resp = Response() - resp.status_code = 200 - mocker.patch.object(requests, "request", return_value=resp) + resp.status_code = 200 mocker.patch.object(resp, "json", return_value={"access_token": "access_token", "expires_in": 1000}) + mocker.patch.object(requests, "request", side_effect=mock_request, autospec=True) token = oauth.refresh_access_token() assert ("access_token", 1000) == token @@ -142,3 +150,9 @@ class TestOauth2Authenticator: oauth(prepared_request) assert {"Authorization": "Bearer access_token"} == prepared_request.headers + + +def mock_request(method, url, data): + if url == "refresh_end": + return resp + raise Exception(f"Error while refreshing access token with request: {method}, {url}, {data}")