From 4174ec7d3cb6e052ad3ee751223344362a7253c0 Mon Sep 17 00:00:00 2001 From: Pauline <4224001+paulineribeyre@users.noreply.github.com> Date: Mon, 27 Jun 2022 13:55:41 -0500 Subject: [PATCH] Fix parsing of PEM and RSA keys --- .pre-commit-config.yaml | 2 +- src/authutils/token/fastapi.py | 5 ++- src/authutils/token/keys.py | 76 +++++++++++++++++++++------------- 3 files changed, 52 insertions(+), 31 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 436b8eb..d806d8b 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -11,6 +11,6 @@ repos: - id: no-commit-to-branch args: [--branch, develop, --branch, master, --pattern, release/.*] - repo: https://github.com/psf/black - rev: 20.8b1 + rev: 22.3.0 hooks: - id: black diff --git a/src/authutils/token/fastapi.py b/src/authutils/token/fastapi.py index 4488cfa..9607bb8 100644 --- a/src/authutils/token/fastapi.py +++ b/src/authutils/token/fastapi.py @@ -7,6 +7,7 @@ from starlette.status import HTTP_403_FORBIDDEN from . import core +from .keys import get_pem_key from ..errors import JWTError, AuthError bearer = HTTPBearer() @@ -76,7 +77,9 @@ async def getter(token: HTTPAuthorizationCredentials = Security(bearer)): async with httpx.AsyncClient() as client: resp = await client.get(core.get_keys_url(issuer)) resp.raise_for_status() - pub_keys.set_result(OrderedDict(resp.json()["keys"])) + pub_keys.set_result( + OrderedDict(get_pem_key(key) for key in resp.json()["keys"]) + ) except Exception as e: _jwt_public_keys.pop(issuer) pub_keys.set_exception( diff --git a/src/authutils/token/keys.py b/src/authutils/token/keys.py index e5e83ec..28fcbf6 100644 --- a/src/authutils/token/keys.py +++ b/src/authutils/token/keys.py @@ -26,7 +26,13 @@ from collections import OrderedDict from cdislogging import get_logger -import flask + +try: + import flask +except ImportError: + print( + "Unable to import flask. Some functionalities may not work. Flask can be installed as an extra." + ) import jwt import httpx from cryptography.hazmat.backends import default_backend @@ -38,6 +44,44 @@ from .core import get_keys_url, get_kid, get_iss +def get_pem_key(key, logger=None): + """ + The key is serialized to PEM if not already. + + Return: tuple (key id, key in PEM format) + """ + if "kty" in key and key["kty"] == "RSA": + if logger: + logger.debug( + "Serializing RSA public key (kid: {}) to PEM format.".format(key["kid"]) + ) + # Decode public numbers https://tools.ietf.org/html/rfc7518#section-6.3.1 + n_padded_bytes = base64.urlsafe_b64decode( + key["n"] + "=" * (4 - len(key["n"]) % 4) + ) + e_padded_bytes = base64.urlsafe_b64decode( + key["e"] + "=" * (4 - len(key["e"]) % 4) + ) + n = int.from_bytes(n_padded_bytes, "big", signed=False) + e = int.from_bytes(e_padded_bytes, "big", signed=False) + # Serialize and encode public key--PyJWT decode/validation requires PEM + rsa_public_key = rsa.RSAPublicNumbers(e, n).public_key(default_backend()) + public_bytes = rsa_public_key.public_bytes( + serialization.Encoding.PEM, + serialization.PublicFormat.SubjectPublicKeyInfo, + ) + # Cache the encoded key by issuer + return key["kid"], public_bytes + else: + if logger: + logger.debug( + "Key type (kty) is not 'RSA'; assuming PEM format. Skipping key serialization. (kid: {})".format( + key[0] + ) + ) + return key[0], key[1] + + def refresh_jwt_public_keys(user_api=None, pkey_cache=None, logger=None): """ Update the public keys that the Flask app is currently using to validate @@ -125,34 +169,8 @@ def refresh_jwt_public_keys(user_api=None, pkey_cache=None, logger=None): issuer_public_keys = {} for key in jwt_public_keys: - if "kty" in key and key["kty"] == "RSA": - logger.debug( - "Serializing RSA public key (kid: {}) to PEM format.".format(key["kid"]) - ) - # Decode public numbers https://tools.ietf.org/html/rfc7518#section-6.3.1 - n_padded_bytes = base64.urlsafe_b64decode( - key["n"] + "=" * (4 - len(key["n"]) % 4) - ) - e_padded_bytes = base64.urlsafe_b64decode( - key["e"] + "=" * (4 - len(key["e"]) % 4) - ) - n = int.from_bytes(n_padded_bytes, "big", signed=False) - e = int.from_bytes(e_padded_bytes, "big", signed=False) - # Serialize and encode public key--PyJWT decode/validation requires PEM - rsa_public_key = rsa.RSAPublicNumbers(e, n).public_key(default_backend()) - public_bytes = rsa_public_key.public_bytes( - serialization.Encoding.PEM, - serialization.PublicFormat.SubjectPublicKeyInfo, - ) - # Cache the encoded key by issuer - issuer_public_keys[key["kid"]] = public_bytes - else: - logger.debug( - "Key type (kty) is not 'RSA'; assuming PEM format. Skipping key serialization. (kid: {})".format( - key[0] - ) - ) - issuer_public_keys[key[0]] = key[1] + kid, pem_bytes = get_pem_key(key, logger) + issuer_public_keys[kid] = pem_bytes if flask.has_app_context(): flask.current_app.jwt_public_keys.update({user_api: issuer_public_keys})