Skip to content

Commit

Permalink
(PXP-6339): Fetch and cache public keys for JWT validation (#52)
Browse files Browse the repository at this point in the history
* fix(jwks-uri): get jwks_uri from .well-known md doc if avbl

* fix(jwks-uri): make return statement neater, add comment, blacken

* fix(jwks-uri): serialize and cache generic public keys

* fix(scope-validation): split space-delimited scope strings

* fix(jwks-uri): handle errors when well-known endpt not avbl

* test(jwks-uri): Fix call count assertion
  • Loading branch information
vpsx authored May 26, 2021
1 parent e0d8b0b commit 7b1d7f5
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 15 deletions.
13 changes: 11 additions & 2 deletions src/authutils/token/core.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import httpx
import jwt

from ..errors import (
Expand All @@ -10,7 +11,15 @@


def get_keys_url(issuer):
return "/".join([issuer.strip("/"), "jwt", "keys"])
# Prefer OIDC discovery doc, but fall back on Fence-specific /jwt/keys for backwards compatibility
openid_cfg_path = "/".join(
[issuer.strip("/"), ".well-known", "openid-configuration"]
)
try:
jwks_uri = httpx.get(openid_cfg_path).json().get("jwks_uri", "")
return jwks_uri
except:
return "/".join([issuer.strip("/"), "jwt", "keys"])


def get_kid(encoded_token):
Expand Down Expand Up @@ -145,7 +154,7 @@ def validate_jwt(encoded_token, public_key, aud, scope, issuers, options={}):
if scope:
token_scopes = token.get("scope", [])
if isinstance(token_scopes, str):
token_scopes = [token_scopes]
token_scopes = token_scopes.split()
if not isinstance(token_scopes, list):
raise JWTError(
"invalid format in scope claim: {}; expected string or list".format(
Expand Down
85 changes: 73 additions & 12 deletions src/authutils/token/keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,17 @@
}
"""

from collections import OrderedDict
import base64
import json
from collections import OrderedDict

from cdislogging import get_logger
import flask
import jwt
import httpx
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import rsa


from authutils.errors import JWTError
Expand All @@ -39,7 +43,11 @@ def refresh_jwt_public_keys(user_api=None, logger=None):
Update the public keys that the Flask app is currently using to validate
JWTs.
Response from ``/jwt/keys`` should look like this:
The get_keys_url helper function will prefer the user_api's
.well-known/openid-configuration endpoint, but if no jwks_uri
is found, will default to /jwt/keys.
In the latter case, the response from ``/jwt/keys`` should look like this:
.. code-block:: javascript
Expand All @@ -56,20 +64,23 @@ def refresh_jwt_public_keys(user_api=None, logger=None):
]
}
Take out the array of keys, put it in an ordered dictionary, and assign
that to ``flask.current_app.jwt_public_keys``.
In either case, the keys are put into a dictionary and assigned to
``flask.current_app.jwt_public_keys`` with user_api as the key.
Keys are serialized to PEM if not already.
Args:
user_api (Optional[str]):
the URL of the user API to get the keys from; default to whatever
the flask app is configured to use
logger (Optional[Logger]):
the logger; default to app's parent logger
Return:
None
Side Effects:
- Reassign ``flask.current_app.jwt_public_keys`` to the keys obtained
from ``get_jwt_public_keys``, as an OrderedDict.
- Reassign ``flask.current_app.jwt_public_keys[user_api]`` to the keys obtained
from ``get_jwt_public_keys``, as a dictionary.
Raises:
ValueError: if user_api is not provided or set in app config
Expand All @@ -85,13 +96,54 @@ def refresh_jwt_public_keys(user_api=None, logger=None):
user_api = user_api or flask.current_app.config.get("USER_API")
if not user_api:
raise ValueError("no URL(s) provided for user API")

path = get_keys_url(user_api)
jwt_public_keys = httpx.get(path).json()["keys"]
logger.info(
"refreshing public keys; updated to:\n"
+ json.dumps(str(jwt_public_keys), indent=4)
try:
jwt_public_keys = httpx.get(path).json()["keys"]
except:
raise JWTError(
"Attempted to refresh public keys for {},"
"but could not get keys from path {}.".format(user_api, path)
)

logger.info("Refreshing public key cache for issuer {}...".format(user_api))
logger.debug(
"Received public keys:\n{}".format(json.dumps(str(jwt_public_keys), indent=4))
)
flask.current_app.jwt_public_keys.update({user_api: OrderedDict(jwt_public_keys)})

issuer_public_keys = {}
for key in jwt_public_keys:
if "kty" in key and key["kty"] == "RSA":
logger.debug(
"Serializing RSA public key (kid: {}) to PEM format.".format(key["kid"])
)
# Decode public numbers https://tools.ietf.org/html/rfc7518#section-6.3.1
n_padded_bytes = base64.urlsafe_b64decode(
key["n"] + "=" * (4 - len(key["n"]) % 4)
)
e_padded_bytes = base64.urlsafe_b64decode(
key["e"] + "=" * (4 - len(key["e"]) % 4)
)
n = int.from_bytes(n_padded_bytes, "big", signed=False)
e = int.from_bytes(e_padded_bytes, "big", signed=False)
# Serialize and encode public key--PyJWT decode/validation requires PEM
rsa_public_key = rsa.RSAPublicNumbers(e, n).public_key(default_backend())
public_bytes = rsa_public_key.public_bytes(
serialization.Encoding.PEM,
serialization.PublicFormat.SubjectPublicKeyInfo,
)
# Cache the encoded key by issuer
issuer_public_keys[key["kid"]] = public_bytes
else:
logger.debug(
"Key type (kty) is not 'RSA'; assuming PEM format. Skipping key serialization. (kid: {})".format(
key[0]
)
)
issuer_public_keys[key[0]] = key[1]

flask.current_app.jwt_public_keys.update({user_api: issuer_public_keys})
logger.info("Done refreshing public key cache for issuer {}.".format(user_api))


def get_public_key(kid, iss=None, attempt_refresh=True, logger=None):
Expand Down Expand Up @@ -129,6 +181,7 @@ def get_public_key(kid, iss=None, attempt_refresh=True, logger=None):
JWTValidationError:
if the key id is provided and public key with that key id is found
"""

iss = (
iss
or flask.current_app.config.get("OIDC_ISSUER")
Expand All @@ -140,8 +193,16 @@ def get_public_key(kid, iss=None, attempt_refresh=True, logger=None):
)
if need_refresh and attempt_refresh:
refresh_jwt_public_keys(iss, logger=logger)
elif need_refresh and not attempt_refresh:
logger.warn(
"Public key {} not cached, but application is not attempting refresh.".format(
kid
)
)

if iss not in flask.current_app.jwt_public_keys:
raise JWTError("issuer not found: {}".format(iss))
raise JWTError("Public key for issuer {} not found.".format(iss))

iss_public_keys = flask.current_app.jwt_public_keys[iss]
try:
return iss_public_keys[kid]
Expand Down
4 changes: 3 additions & 1 deletion tests/test_jwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,9 @@ def test_get_public_key(app, example_keys_response, mock_get):
iss = app.config["USER_API"]
expected_jwt_public_keys_dict = {iss: OrderedDict(example_keys_response["keys"])}
key = get_public_key(kid=test_kid)
httpx.get.assert_called_once()
# httpx.get should be called twice: once attempting to get the jwks_uri from
# .well-known/openid-configuration, another to actually hit the jwks_uri
assert httpx.get.call_count == 2
assert key
assert key == expected_key
assert app.jwt_public_keys == expected_jwt_public_keys_dict
Expand Down

0 comments on commit 7b1d7f5

Please sign in to comment.