diff --git a/redash/authentication/__init__.py b/redash/authentication/__init__.py index 9c41b286c..c7fa63808 100644 --- a/redash/authentication/__init__.py +++ b/redash/authentication/__init__.py @@ -187,6 +187,10 @@ def jwt_token_load_user_from_request(request): if not payload: return + if "email" not in payload: + logger.info("No email field in token, refusing to login") + return + try: user = models.User.get_by_email_and_org(payload["email"], org) except models.NoResultFound: diff --git a/redash/authentication/jwt_auth.py b/redash/authentication/jwt_auth.py index 9d4d80abe..1713bf9e5 100644 --- a/redash/authentication/jwt_auth.py +++ b/redash/authentication/jwt_auth.py @@ -6,6 +6,34 @@ import simplejson logger = logging.getLogger("jwt_auth") +FILE_SCHEME_PREFIX = "file://" + + +def get_public_key_from_file(url): + file_path = url[len(FILE_SCHEME_PREFIX) :] + with open(file_path) as key_file: + key_str = key_file.read() + + get_public_keys.key_cache[url] = [key_str] + return key_str + + +def get_public_key_from_net(url): + r = requests.get(url) + r.raise_for_status() + data = r.json() + if "keys" in data: + public_keys = [] + for key_dict in data["keys"]: + public_key = jwt.algorithms.RSAAlgorithm.from_jwk(simplejson.dumps(key_dict)) + public_keys.append(public_key) + + get_public_keys.key_cache[url] = public_keys + return public_keys + else: + get_public_keys.key_cache[url] = data + return data + def get_public_keys(url): """ @@ -13,23 +41,15 @@ def get_public_keys(url): List of RSA public keys usable by PyJWT. """ key_cache = get_public_keys.key_cache + keys = {} if url in key_cache: - return key_cache[url] + keys = key_cache[url] else: - r = requests.get(url) - r.raise_for_status() - data = r.json() - if "keys" in data: - public_keys = [] - for key_dict in data["keys"]: - public_key = jwt.algorithms.RSAAlgorithm.from_jwk(simplejson.dumps(key_dict)) - public_keys.append(public_key) - - get_public_keys.key_cache[url] = public_keys - return public_keys + if url.startswith(FILE_SCHEME_PREFIX): + keys = [get_public_key_from_file(url)] else: - get_public_keys.key_cache[url] = data - return data + keys = get_public_key_from_net(url) + return keys get_public_keys.key_cache = {} @@ -58,4 +78,5 @@ def verify_jwt_token(jwt_token, expected_issuer, expected_audience, algorithms, break except Exception as e: logging.exception(e) + return payload, valid_token diff --git a/requirements_dev.txt b/requirements_dev.txt index 0c674efca..3a78a85e1 100644 --- a/requirements_dev.txt +++ b/requirements_dev.txt @@ -1,3 +1,4 @@ +jwcrypto==1.5.0 pytest==7.4.0 pytest-cov==4.1.0 coverage==7.2.7 diff --git a/tests/test_authentication.py b/tests/test_authentication.py index 04275669f..8cfbef69e 100644 --- a/tests/test_authentication.py +++ b/tests/test_authentication.py @@ -1,9 +1,14 @@ import importlib +import json import os +import subprocess import time +import jwcrypto.jwk +import jwt +import requests from flask import request -from mock import patch +from mock import Mock, patch from sqlalchemy.orm.exc import NoResultFound from redash import models, settings @@ -11,6 +16,8 @@ from redash.authentication import ( api_key_load_user_from_request, get_login_url, hmac_load_user_from_request, + jwt_auth, + org_settings, sign, ) from redash.authentication.google_oauth import ( @@ -405,3 +412,69 @@ class TestUserForgotPassword(BaseTestCase): self.assertEqual(response.status_code, 200) send_password_reset_email_mock.assert_not_called() send_user_disabled_email_mock.assert_called_with(user) + + +class TestJWTAuthentication(BaseTestCase): + def setUp(self): + super(TestJWTAuthentication, self).setUp() + self.auth_audience = "My Org" + self.auth_issuer = "Admin" + self.token_name = "jwt-token" + self.rsa_private_key = "/tmp/jwtRS256.key" + self.rsa_public_key = "/tmp/jwtRS256.pem" + + if not os.path.exists(self.rsa_public_key): + subprocess.check_output(["openssl", "genrsa", "-out", self.rsa_private_key, "4096"]) + subprocess.check_output( + ["openssl", "rsa", "-pubout", "-in", self.rsa_private_key, "-out", self.rsa_public_key] + ) + + org_settings["auth_jwt_login_enabled"] = True + org_settings["auth_jwt_auth_public_certs_url"] = "file://{}".format(self.rsa_public_key) + org_settings["auth_jwt_auth_issuer"] = self.auth_issuer + org_settings["auth_jwt_auth_audience"] = self.auth_audience + org_settings["auth_jwt_auth_header_name"] = self.token_name + + def tearDown(self): + org_settings["auth_jwt_login_enabled"] = False + org_settings["auth_jwt_auth_public_certs_url"] = "" + org_settings["auth_jwt_auth_issuer"] = "" + org_settings["auth_jwt_auth_audience"] = "" + org_settings["auth_jwt_auth_header_name"] = "" + + def test_jwt_no_token(self): + response = self.get_request("/data_sources", org=self.factory.org) + self.assertEqual(response.status_code, 302) + + def test_jwt_from_pem_file(self): + user = self.factory.create_user() + + issued_at_timestamp = time.time() + expiration_timestamp = issued_at_timestamp + 60 + + data = { + "aud": self.auth_audience, + "email": user.email, + "exp": expiration_timestamp, + "iat": issued_at_timestamp, + "iss": self.auth_issuer, + } + with open(self.rsa_private_key) as keyfile: + sign_key = keyfile.read().strip() + token_data = jwt.encode(data, sign_key, algorithm="RS256") + + response = self.get_request("/data_sources", org=self.factory.org, headers={self.token_name: token_data}) + self.assertEqual(response.status_code, 200) + + @patch.object(requests, "get") + def test_jwk_decode(self, mock_get): + with open(self.rsa_public_key, "rb") as keyfile: + public_key = jwcrypto.jwk.JWK.from_pem(keyfile.read()) + jwk_keys = {"keys": [json.loads(public_key.export())]} + + mockresponse = Mock() + mockresponse.json = lambda: jwk_keys + mock_get.return_value = mockresponse + + keys = jwt_auth.get_public_keys("http://localhost/key.jwt") + self.assertEqual(keys[0].key_size, 4096)