mirror of
https://github.com/getredash/redash.git
synced 2025-12-19 09:27:23 -05:00
Allow RSA key used for JWT to be specified as a file path (#6271)
- auth_jwt_auth_public_certs_url may file:// in addition to http/https - Log an error if payload does not contain an email address
This commit is contained in:
@@ -187,6 +187,10 @@ def jwt_token_load_user_from_request(request):
|
||||
if not payload:
|
||||
return
|
||||
|
||||
if "email" not in payload:
|
||||
logger.info("No email field in token, refusing to login")
|
||||
return
|
||||
|
||||
try:
|
||||
user = models.User.get_by_email_and_org(payload["email"], org)
|
||||
except models.NoResultFound:
|
||||
|
||||
@@ -6,6 +6,34 @@ import simplejson
|
||||
|
||||
logger = logging.getLogger("jwt_auth")
|
||||
|
||||
FILE_SCHEME_PREFIX = "file://"
|
||||
|
||||
|
||||
def get_public_key_from_file(url):
|
||||
file_path = url[len(FILE_SCHEME_PREFIX) :]
|
||||
with open(file_path) as key_file:
|
||||
key_str = key_file.read()
|
||||
|
||||
get_public_keys.key_cache[url] = [key_str]
|
||||
return key_str
|
||||
|
||||
|
||||
def get_public_key_from_net(url):
|
||||
r = requests.get(url)
|
||||
r.raise_for_status()
|
||||
data = r.json()
|
||||
if "keys" in data:
|
||||
public_keys = []
|
||||
for key_dict in data["keys"]:
|
||||
public_key = jwt.algorithms.RSAAlgorithm.from_jwk(simplejson.dumps(key_dict))
|
||||
public_keys.append(public_key)
|
||||
|
||||
get_public_keys.key_cache[url] = public_keys
|
||||
return public_keys
|
||||
else:
|
||||
get_public_keys.key_cache[url] = data
|
||||
return data
|
||||
|
||||
|
||||
def get_public_keys(url):
|
||||
"""
|
||||
@@ -13,23 +41,15 @@ def get_public_keys(url):
|
||||
List of RSA public keys usable by PyJWT.
|
||||
"""
|
||||
key_cache = get_public_keys.key_cache
|
||||
keys = {}
|
||||
if url in key_cache:
|
||||
return key_cache[url]
|
||||
keys = key_cache[url]
|
||||
else:
|
||||
r = requests.get(url)
|
||||
r.raise_for_status()
|
||||
data = r.json()
|
||||
if "keys" in data:
|
||||
public_keys = []
|
||||
for key_dict in data["keys"]:
|
||||
public_key = jwt.algorithms.RSAAlgorithm.from_jwk(simplejson.dumps(key_dict))
|
||||
public_keys.append(public_key)
|
||||
|
||||
get_public_keys.key_cache[url] = public_keys
|
||||
return public_keys
|
||||
if url.startswith(FILE_SCHEME_PREFIX):
|
||||
keys = [get_public_key_from_file(url)]
|
||||
else:
|
||||
get_public_keys.key_cache[url] = data
|
||||
return data
|
||||
keys = get_public_key_from_net(url)
|
||||
return keys
|
||||
|
||||
|
||||
get_public_keys.key_cache = {}
|
||||
@@ -58,4 +78,5 @@ def verify_jwt_token(jwt_token, expected_issuer, expected_audience, algorithms,
|
||||
break
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
|
||||
return payload, valid_token
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
jwcrypto==1.5.0
|
||||
pytest==7.4.0
|
||||
pytest-cov==4.1.0
|
||||
coverage==7.2.7
|
||||
|
||||
@@ -1,9 +1,14 @@
|
||||
import importlib
|
||||
import json
|
||||
import os
|
||||
import subprocess
|
||||
import time
|
||||
|
||||
import jwcrypto.jwk
|
||||
import jwt
|
||||
import requests
|
||||
from flask import request
|
||||
from mock import patch
|
||||
from mock import Mock, patch
|
||||
from sqlalchemy.orm.exc import NoResultFound
|
||||
|
||||
from redash import models, settings
|
||||
@@ -11,6 +16,8 @@ from redash.authentication import (
|
||||
api_key_load_user_from_request,
|
||||
get_login_url,
|
||||
hmac_load_user_from_request,
|
||||
jwt_auth,
|
||||
org_settings,
|
||||
sign,
|
||||
)
|
||||
from redash.authentication.google_oauth import (
|
||||
@@ -405,3 +412,69 @@ class TestUserForgotPassword(BaseTestCase):
|
||||
self.assertEqual(response.status_code, 200)
|
||||
send_password_reset_email_mock.assert_not_called()
|
||||
send_user_disabled_email_mock.assert_called_with(user)
|
||||
|
||||
|
||||
class TestJWTAuthentication(BaseTestCase):
|
||||
def setUp(self):
|
||||
super(TestJWTAuthentication, self).setUp()
|
||||
self.auth_audience = "My Org"
|
||||
self.auth_issuer = "Admin"
|
||||
self.token_name = "jwt-token"
|
||||
self.rsa_private_key = "/tmp/jwtRS256.key"
|
||||
self.rsa_public_key = "/tmp/jwtRS256.pem"
|
||||
|
||||
if not os.path.exists(self.rsa_public_key):
|
||||
subprocess.check_output(["openssl", "genrsa", "-out", self.rsa_private_key, "4096"])
|
||||
subprocess.check_output(
|
||||
["openssl", "rsa", "-pubout", "-in", self.rsa_private_key, "-out", self.rsa_public_key]
|
||||
)
|
||||
|
||||
org_settings["auth_jwt_login_enabled"] = True
|
||||
org_settings["auth_jwt_auth_public_certs_url"] = "file://{}".format(self.rsa_public_key)
|
||||
org_settings["auth_jwt_auth_issuer"] = self.auth_issuer
|
||||
org_settings["auth_jwt_auth_audience"] = self.auth_audience
|
||||
org_settings["auth_jwt_auth_header_name"] = self.token_name
|
||||
|
||||
def tearDown(self):
|
||||
org_settings["auth_jwt_login_enabled"] = False
|
||||
org_settings["auth_jwt_auth_public_certs_url"] = ""
|
||||
org_settings["auth_jwt_auth_issuer"] = ""
|
||||
org_settings["auth_jwt_auth_audience"] = ""
|
||||
org_settings["auth_jwt_auth_header_name"] = ""
|
||||
|
||||
def test_jwt_no_token(self):
|
||||
response = self.get_request("/data_sources", org=self.factory.org)
|
||||
self.assertEqual(response.status_code, 302)
|
||||
|
||||
def test_jwt_from_pem_file(self):
|
||||
user = self.factory.create_user()
|
||||
|
||||
issued_at_timestamp = time.time()
|
||||
expiration_timestamp = issued_at_timestamp + 60
|
||||
|
||||
data = {
|
||||
"aud": self.auth_audience,
|
||||
"email": user.email,
|
||||
"exp": expiration_timestamp,
|
||||
"iat": issued_at_timestamp,
|
||||
"iss": self.auth_issuer,
|
||||
}
|
||||
with open(self.rsa_private_key) as keyfile:
|
||||
sign_key = keyfile.read().strip()
|
||||
token_data = jwt.encode(data, sign_key, algorithm="RS256")
|
||||
|
||||
response = self.get_request("/data_sources", org=self.factory.org, headers={self.token_name: token_data})
|
||||
self.assertEqual(response.status_code, 200)
|
||||
|
||||
@patch.object(requests, "get")
|
||||
def test_jwk_decode(self, mock_get):
|
||||
with open(self.rsa_public_key, "rb") as keyfile:
|
||||
public_key = jwcrypto.jwk.JWK.from_pem(keyfile.read())
|
||||
jwk_keys = {"keys": [json.loads(public_key.export())]}
|
||||
|
||||
mockresponse = Mock()
|
||||
mockresponse.json = lambda: jwk_keys
|
||||
mock_get.return_value = mockresponse
|
||||
|
||||
keys = jwt_auth.get_public_keys("http://localhost/key.jwt")
|
||||
self.assertEqual(keys[0].key_size, 4096)
|
||||
|
||||
Reference in New Issue
Block a user