1
0
mirror of synced 2026-01-18 15:02:51 -05:00
Files
airbyte/airbyte-cdk/python/unit_tests/sources/declarative/auth/test_jwt.py
2024-04-19 09:17:54 -07:00

172 lines
5.4 KiB
Python

#
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
#
import base64
import logging
from datetime import datetime
import freezegun
import jwt
import pytest
from airbyte_cdk.sources.declarative.auth.jwt import JwtAuthenticator
# Module-level logger for this test module (currently not used by the tests below).
LOGGER = logging.getLogger(__name__)
class TestJwtAuthenticator:
    """
    Test class for JWT Authenticator.

    Covers header construction, payload construction, secret-key handling,
    token signing and header-prefix retrieval of ``JwtAuthenticator``.
    """

    @pytest.mark.parametrize(
        "algorithm, kid, typ, cty, additional_jwt_headers, expected",
        [
            (
                "ALGORITHM",
                "test_kid",
                "test_typ",
                "test_cty",
                {"test": "test"},
                {"kid": "test_kid", "typ": "test_typ", "cty": "test_cty", "test": "test", "alg": "ALGORITHM"},
            ),
            ("ALGORITHM", None, None, None, None, {"alg": "ALGORITHM"}),
        ],
    )
    def test_get_jwt_headers(self, algorithm, kid, typ, cty, additional_jwt_headers, expected):
        """Headers merge alg/kid/typ/cty with additional headers; unset (None) fields are omitted."""
        authenticator = JwtAuthenticator(
            config={},
            parameters={},
            algorithm=algorithm,
            secret_key="test_key",
            token_duration=1200,
            kid=kid,
            typ=typ,
            cty=cty,
            additional_jwt_headers=additional_jwt_headers,
        )
        assert authenticator._get_jwt_headers() == expected

    def test_given_overriden_reserverd_properties_get_jwt_headers_throws_error(self):
        """Supplying a reserved header (``kid``) via additional_jwt_headers raises ValueError."""
        authenticator = JwtAuthenticator(
            config={},
            parameters={},
            algorithm="ALGORITHM",
            secret_key="test_key",
            token_duration=1200,
            additional_jwt_headers={"kid": "test_kid"},
        )
        with pytest.raises(ValueError):
            authenticator._get_jwt_headers()

    @pytest.mark.parametrize(
        "iss, sub, aud, additional_jwt_payload, expected",
        [
            (
                "test_iss",
                "test_sub",
                "test_aud",
                {"test": "test"},
                {"iss": "test_iss", "sub": "test_sub", "aud": "test_aud", "test": "test"},
            ),
            (None, None, None, None, {}),
        ],
    )
    def test_get_jwt_payload(self, iss, sub, aud, additional_jwt_payload, expected):
        """Payload holds the configured claims plus iat/nbf (now) and exp (now + token_duration)."""
        authenticator = JwtAuthenticator(
            config={},
            parameters={},
            algorithm="ALGORITHM",
            secret_key="test_key",
            token_duration=1000,
            iss=iss,
            sub=sub,
            aud=aud,
            additional_jwt_payload=additional_jwt_payload,
        )
        with freezegun.freeze_time("2022-01-01 00:00:00"):
            now = int(datetime.now().timestamp())
            # Build a new dict instead of mutating `expected`: parametrize values are
            # shared objects created once at collection time.
            expected_payload = {**expected, "iat": now, "exp": now + 1000, "nbf": now}
            assert authenticator._get_jwt_payload() == expected_payload

    def test_given_overriden_reserverd_properties_get_jwt_payload_throws_error(self):
        """Supplying a reserved claim (``exp``) via additional_jwt_payload raises ValueError."""
        authenticator = JwtAuthenticator(
            config={},
            parameters={},
            algorithm="ALGORITHM",
            secret_key="test_key",
            token_duration=0,
            additional_jwt_payload={"exp": 1234},
        )
        with pytest.raises(ValueError):
            authenticator._get_jwt_payload()

    @pytest.mark.parametrize(
        "base64_encode_secret_key, secret_key, expected",
        [
            (True, "test", base64.b64encode("test".encode()).decode()),
            (False, "test", "test"),
        ],
    )
    def test_get_secret_key(self, base64_encode_secret_key, secret_key, expected):
        """The secret key is base64-encoded only when base64_encode_secret_key is True."""
        authenticator = JwtAuthenticator(
            config={},
            parameters={},
            secret_key=secret_key,
            algorithm="test_algo",
            token_duration=1200,
            base64_encode_secret_key=base64_encode_secret_key,
        )
        assert authenticator._get_secret_key() == expected

    def test_get_signed_token(self):
        """The signed token equals a PyJWT encoding of the authenticator's own payload and headers."""
        authenticator = JwtAuthenticator(
            config={},
            parameters={},
            secret_key="test",
            algorithm="HS256",
            token_duration=1000,
            typ="JWT",
            iss="iss",
        )
        # The payload's iat/nbf/exp come from the current time. Freeze the clock so the
        # token built by _get_signed_token and the expected token use identical claims;
        # otherwise the test fails whenever a second boundary passes between the calls.
        with freezegun.freeze_time("2022-01-01 00:00:00"):
            token = authenticator._get_signed_token()
            expected_token = jwt.encode(
                payload=authenticator._get_jwt_payload(),
                key=authenticator._get_secret_key(),
                algorithm=authenticator._algorithm,
                headers=authenticator._get_jwt_headers(),
            )
        assert token == expected_token

    def test_given_invalid_algorithm_get_signed_token_throws_error(self):
        """Signing with an unsupported algorithm name raises ValueError."""
        authenticator = JwtAuthenticator(
            config={},
            parameters={},
            secret_key="test",
            algorithm="invalid algorithm type",
            token_duration=1000,
            base64_encode_secret_key=False,
            header_prefix="Bearer",
            typ="JWT",
            iss="iss",
            additional_jwt_headers={},
            additional_jwt_payload={},
        )
        with pytest.raises(ValueError):
            authenticator._get_signed_token()

    @pytest.mark.parametrize(
        "header_prefix, expected",
        [
            ("test", "test"),
            (None, None),
        ],
    )
    def test_get_header_prefix(self, header_prefix, expected):
        """The configured header prefix is returned unchanged, including None."""
        authenticator = JwtAuthenticator(
            config={},
            parameters={},
            secret_key="key",
            algorithm="test_algo",
            token_duration=1200,
            header_prefix=header_prefix,
        )
        assert authenticator._get_header_prefix() == expected