mirror of
https://github.com/getredash/redash.git
synced 2025-12-19 17:37:19 -05:00
Allow RSA key used for JWT to be specified as a file path (#6271)
- auth_jwt_auth_public_certs_url may file:// in addition to http/https - Log an error if payload does not contain an email address
This commit is contained in:
@@ -187,6 +187,10 @@ def jwt_token_load_user_from_request(request):
|
|||||||
if not payload:
|
if not payload:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
if "email" not in payload:
|
||||||
|
logger.info("No email field in token, refusing to login")
|
||||||
|
return
|
||||||
|
|
||||||
try:
|
try:
|
||||||
user = models.User.get_by_email_and_org(payload["email"], org)
|
user = models.User.get_by_email_and_org(payload["email"], org)
|
||||||
except models.NoResultFound:
|
except models.NoResultFound:
|
||||||
|
|||||||
@@ -6,6 +6,34 @@ import simplejson
|
|||||||
|
|
||||||
logger = logging.getLogger("jwt_auth")
|
logger = logging.getLogger("jwt_auth")
|
||||||
|
|
||||||
|
FILE_SCHEME_PREFIX = "file://"
|
||||||
|
|
||||||
|
|
||||||
|
def get_public_key_from_file(url):
|
||||||
|
file_path = url[len(FILE_SCHEME_PREFIX) :]
|
||||||
|
with open(file_path) as key_file:
|
||||||
|
key_str = key_file.read()
|
||||||
|
|
||||||
|
get_public_keys.key_cache[url] = [key_str]
|
||||||
|
return key_str
|
||||||
|
|
||||||
|
|
||||||
|
def get_public_key_from_net(url):
|
||||||
|
r = requests.get(url)
|
||||||
|
r.raise_for_status()
|
||||||
|
data = r.json()
|
||||||
|
if "keys" in data:
|
||||||
|
public_keys = []
|
||||||
|
for key_dict in data["keys"]:
|
||||||
|
public_key = jwt.algorithms.RSAAlgorithm.from_jwk(simplejson.dumps(key_dict))
|
||||||
|
public_keys.append(public_key)
|
||||||
|
|
||||||
|
get_public_keys.key_cache[url] = public_keys
|
||||||
|
return public_keys
|
||||||
|
else:
|
||||||
|
get_public_keys.key_cache[url] = data
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
def get_public_keys(url):
|
def get_public_keys(url):
|
||||||
"""
|
"""
|
||||||
@@ -13,23 +41,15 @@ def get_public_keys(url):
|
|||||||
List of RSA public keys usable by PyJWT.
|
List of RSA public keys usable by PyJWT.
|
||||||
"""
|
"""
|
||||||
key_cache = get_public_keys.key_cache
|
key_cache = get_public_keys.key_cache
|
||||||
|
keys = {}
|
||||||
if url in key_cache:
|
if url in key_cache:
|
||||||
return key_cache[url]
|
keys = key_cache[url]
|
||||||
else:
|
else:
|
||||||
r = requests.get(url)
|
if url.startswith(FILE_SCHEME_PREFIX):
|
||||||
r.raise_for_status()
|
keys = [get_public_key_from_file(url)]
|
||||||
data = r.json()
|
|
||||||
if "keys" in data:
|
|
||||||
public_keys = []
|
|
||||||
for key_dict in data["keys"]:
|
|
||||||
public_key = jwt.algorithms.RSAAlgorithm.from_jwk(simplejson.dumps(key_dict))
|
|
||||||
public_keys.append(public_key)
|
|
||||||
|
|
||||||
get_public_keys.key_cache[url] = public_keys
|
|
||||||
return public_keys
|
|
||||||
else:
|
else:
|
||||||
get_public_keys.key_cache[url] = data
|
keys = get_public_key_from_net(url)
|
||||||
return data
|
return keys
|
||||||
|
|
||||||
|
|
||||||
get_public_keys.key_cache = {}
|
get_public_keys.key_cache = {}
|
||||||
@@ -58,4 +78,5 @@ def verify_jwt_token(jwt_token, expected_issuer, expected_audience, algorithms,
|
|||||||
break
|
break
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.exception(e)
|
logging.exception(e)
|
||||||
|
|
||||||
return payload, valid_token
|
return payload, valid_token
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
jwcrypto==1.5.0
|
||||||
pytest==7.4.0
|
pytest==7.4.0
|
||||||
pytest-cov==4.1.0
|
pytest-cov==4.1.0
|
||||||
coverage==7.2.7
|
coverage==7.2.7
|
||||||
|
|||||||
@@ -1,9 +1,14 @@
|
|||||||
import importlib
|
import importlib
|
||||||
|
import json
|
||||||
import os
|
import os
|
||||||
|
import subprocess
|
||||||
import time
|
import time
|
||||||
|
|
||||||
|
import jwcrypto.jwk
|
||||||
|
import jwt
|
||||||
|
import requests
|
||||||
from flask import request
|
from flask import request
|
||||||
from mock import patch
|
from mock import Mock, patch
|
||||||
from sqlalchemy.orm.exc import NoResultFound
|
from sqlalchemy.orm.exc import NoResultFound
|
||||||
|
|
||||||
from redash import models, settings
|
from redash import models, settings
|
||||||
@@ -11,6 +16,8 @@ from redash.authentication import (
|
|||||||
api_key_load_user_from_request,
|
api_key_load_user_from_request,
|
||||||
get_login_url,
|
get_login_url,
|
||||||
hmac_load_user_from_request,
|
hmac_load_user_from_request,
|
||||||
|
jwt_auth,
|
||||||
|
org_settings,
|
||||||
sign,
|
sign,
|
||||||
)
|
)
|
||||||
from redash.authentication.google_oauth import (
|
from redash.authentication.google_oauth import (
|
||||||
@@ -405,3 +412,69 @@ class TestUserForgotPassword(BaseTestCase):
|
|||||||
self.assertEqual(response.status_code, 200)
|
self.assertEqual(response.status_code, 200)
|
||||||
send_password_reset_email_mock.assert_not_called()
|
send_password_reset_email_mock.assert_not_called()
|
||||||
send_user_disabled_email_mock.assert_called_with(user)
|
send_user_disabled_email_mock.assert_called_with(user)
|
||||||
|
|
||||||
|
|
||||||
|
class TestJWTAuthentication(BaseTestCase):
|
||||||
|
def setUp(self):
|
||||||
|
super(TestJWTAuthentication, self).setUp()
|
||||||
|
self.auth_audience = "My Org"
|
||||||
|
self.auth_issuer = "Admin"
|
||||||
|
self.token_name = "jwt-token"
|
||||||
|
self.rsa_private_key = "/tmp/jwtRS256.key"
|
||||||
|
self.rsa_public_key = "/tmp/jwtRS256.pem"
|
||||||
|
|
||||||
|
if not os.path.exists(self.rsa_public_key):
|
||||||
|
subprocess.check_output(["openssl", "genrsa", "-out", self.rsa_private_key, "4096"])
|
||||||
|
subprocess.check_output(
|
||||||
|
["openssl", "rsa", "-pubout", "-in", self.rsa_private_key, "-out", self.rsa_public_key]
|
||||||
|
)
|
||||||
|
|
||||||
|
org_settings["auth_jwt_login_enabled"] = True
|
||||||
|
org_settings["auth_jwt_auth_public_certs_url"] = "file://{}".format(self.rsa_public_key)
|
||||||
|
org_settings["auth_jwt_auth_issuer"] = self.auth_issuer
|
||||||
|
org_settings["auth_jwt_auth_audience"] = self.auth_audience
|
||||||
|
org_settings["auth_jwt_auth_header_name"] = self.token_name
|
||||||
|
|
||||||
|
def tearDown(self):
|
||||||
|
org_settings["auth_jwt_login_enabled"] = False
|
||||||
|
org_settings["auth_jwt_auth_public_certs_url"] = ""
|
||||||
|
org_settings["auth_jwt_auth_issuer"] = ""
|
||||||
|
org_settings["auth_jwt_auth_audience"] = ""
|
||||||
|
org_settings["auth_jwt_auth_header_name"] = ""
|
||||||
|
|
||||||
|
def test_jwt_no_token(self):
|
||||||
|
response = self.get_request("/data_sources", org=self.factory.org)
|
||||||
|
self.assertEqual(response.status_code, 302)
|
||||||
|
|
||||||
|
def test_jwt_from_pem_file(self):
|
||||||
|
user = self.factory.create_user()
|
||||||
|
|
||||||
|
issued_at_timestamp = time.time()
|
||||||
|
expiration_timestamp = issued_at_timestamp + 60
|
||||||
|
|
||||||
|
data = {
|
||||||
|
"aud": self.auth_audience,
|
||||||
|
"email": user.email,
|
||||||
|
"exp": expiration_timestamp,
|
||||||
|
"iat": issued_at_timestamp,
|
||||||
|
"iss": self.auth_issuer,
|
||||||
|
}
|
||||||
|
with open(self.rsa_private_key) as keyfile:
|
||||||
|
sign_key = keyfile.read().strip()
|
||||||
|
token_data = jwt.encode(data, sign_key, algorithm="RS256")
|
||||||
|
|
||||||
|
response = self.get_request("/data_sources", org=self.factory.org, headers={self.token_name: token_data})
|
||||||
|
self.assertEqual(response.status_code, 200)
|
||||||
|
|
||||||
|
@patch.object(requests, "get")
|
||||||
|
def test_jwk_decode(self, mock_get):
|
||||||
|
with open(self.rsa_public_key, "rb") as keyfile:
|
||||||
|
public_key = jwcrypto.jwk.JWK.from_pem(keyfile.read())
|
||||||
|
jwk_keys = {"keys": [json.loads(public_key.export())]}
|
||||||
|
|
||||||
|
mockresponse = Mock()
|
||||||
|
mockresponse.json = lambda: jwk_keys
|
||||||
|
mock_get.return_value = mockresponse
|
||||||
|
|
||||||
|
keys = jwt_auth.get_public_keys("http://localhost/key.jwt")
|
||||||
|
self.assertEqual(keys[0].key_size, 4096)
|
||||||
|
|||||||
Reference in New Issue
Block a user