From 2b7538e901766ab2a22a2cee1b55f248aef8669e Mon Sep 17 00:00:00 2001 From: vpsx <19900057+vpsx@users.noreply.github.com> Date: Thu, 29 Apr 2021 10:53:20 -0500 Subject: [PATCH] PXP-6617 Add custom scopes validation and revert aud validation to default (#47) * feat(scope): Add class JWTScopeError(JWTError) * fix(aud): Validate aud claim the normal way, validate custom scopes claim * import new JWTScopeError * add new scope arg to core.validate_jwt * add new scope validation to core.validate_jwt * remove custom aud validation from core.validate_jwt * remove random_aud hack in core.validate_jwt; type(aud) now string-or-None, not set-or-list * pass aud through to PyJWT for normal validation * update docstring for core.validate_jwt * allow empty aud arg in validate.validate_jwt; cease raising ValueError * add new optional scope arg in validate.validate_jwt; pass through to core.validate_jwt * update docstring for validate.validate_jwt * fix(aud): allow passthrough of options arg to pyjwt * fix(aud-scope): switch require_auth_header to checking scopes not aud * fix(aud-scope): Skip aud validation in require_auth_header/validate_request * test(aud-scope): Change default_audiences fixture to default_scopes; rm aud from generic claims * fix(aud-scope): chg aud to scope in FastAPI access_token dependency * test(aud-scope): Upd tests to reflect new aud/scope usage * fix(aud-scope): chg aud to scope in CurrentUser call to validate_request * test(aud): add happy-path test for aud validation * style(black): Blacken, and update black rev in precommit config * test(aud): Explicitly pass None instead of default_audiences * because default_audience may change to not None in future * and because this better reflects the intention of the test * fix(aud): Re-enable aud claim validation in require_auth_header * fix(aud): Expect iss in aud claim by default in token.validate_jwt... * ...if a value for iss is avbl, from app cfg BASE_URL or USER_API. * Also clarify core.validate_jwt docstring. * test(app-fixture): Set app.config['BASE_URL'] as well as ['USER_API'] * fix(aud): Allow passing expected audience to FastAPI access_token dependency * test(aud): Include aud claim in default claims test fixture * Update default_audience fixture accordingly * Update tests to account for new default claims * fix(aud): Allow passing expected audience to require_auth_header and validate_request * Also let scope={} by default * Update calls to require_auth_header * fix(aud): Update set_current_user proxy fn to pass in expected aud * based on flask.current_app.config * Since this already assumes Flask request ctx, I think OK to look in Flask app cfg in this case * test(aud): Add test: no aud arg provided and no aud claim in token * test(aud): Rename fixture default_audiences to default_audience * chore(precommit): pre-commit autoupdate * fix(aud): fix incorrect kwargs logic * docs(aud): add missing audience arg to docstring * test(aud): use default_audience instead of iss in claims fixture * fix(aud-scope): error message Co-authored-by: Pauline Ribeyre * docs(aud-scope): Fix docstring Co-authored-by: Pauline Ribeyre Co-authored-by: Pauline Ribeyre --- .pre-commit-config.yaml | 6 +- .secrets.baseline | 64 +++++++++---- src/authutils/errors.py | 5 + src/authutils/token/core.py | 100 ++++++++++++++------ src/authutils/token/fastapi.py | 25 +++-- src/authutils/token/validate.py | 55 ++++++++--- src/authutils/user.py | 14 ++- tests/conftest.py | 25 +++-- tests/test_fastapi.py | 14 +-- tests/test_jwt.py | 157 ++++++++++++++++++++++---------- 10 files changed, 333 insertions(+), 132 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 4e960ea..9b96f9d 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,16 +1,16 @@ repos: - repo: git@github.com:Yelp/detect-secrets - rev: v0.13.1 + rev: v1.1.0 hooks: - id: detect-secrets args: ['--baseline', '.secrets.baseline'] - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v2.5.0 + rev: v3.4.0 hooks: - id: end-of-file-fixer - id: no-commit-to-branch args: [--branch, develop, --branch, master, --pattern, release/.*] - repo: https://github.com/psf/black - rev: 19.10b0 + rev: 20.8b1 hooks: - id: black diff --git a/.secrets.baseline b/.secrets.baseline index f9a1b4d..85b4bb3 100644 --- a/.secrets.baseline +++ b/.secrets.baseline @@ -1,8 +1,4 @@ { - "exclude": { - "files": null, - "lines": null - }, "generated_at": "2021-01-19T16:35:59Z", "plugins_used": [ { @@ -12,8 +8,8 @@ "name": "ArtifactoryDetector" }, { - "base64_limit": 4.5, - "name": "Base64HighEntropyString" + "name": "Base64HighEntropyString", + "limit": 4.5 }, { "name": "BasicAuthDetector" @@ -22,8 +18,8 @@ "name": "CloudantDetector" }, { - "hex_limit": 3, - "name": "HexHighEntropyString" + "name": "HexHighEntropyString", + "limit": 3 }, { "name": "IbmCloudIamDetector" @@ -60,26 +56,60 @@ "results": { "src/authutils/oauth2/client/blueprint.py": [ { + "type": "Secret Keyword", + "filename": "src/authutils/oauth2/client/blueprint.py", "hashed_secret": "6eae3a5b062c6d0d79f070c26e6d62486b40cb46", - "is_secret": false, "is_verified": false, "line_number": 15, - "type": "Secret Keyword" + "is_secret": false } ], "src/authutils/testing/fixtures/keys.py": [ { + "type": "Private Key", + "filename": "src/authutils/testing/fixtures/keys.py", "hashed_secret": "be4fc4886bd949b369d5e092eb87494f12e57e5b", - "is_secret": false, "is_verified": false, "line_number": 83, - "type": "Private Key" + "is_secret": false } ] }, - "version": "0.13.1", - "word_list": { - "file": null, - "hash": null - } + "version": "1.1.0", + "filters_used": [ + { + "path": "detect_secrets.filters.allowlist.is_line_allowlisted" + }, + { + "path": "detect_secrets.filters.heuristic.is_sequential_string" + }, + { + "path": "detect_secrets.filters.heuristic.is_potential_uuid" + }, + { + "path": "detect_secrets.filters.heuristic.is_likely_id_string" + }, + { + "path": "detect_secrets.filters.heuristic.is_templated_secret" + }, + { + "path": "detect_secrets.filters.heuristic.is_prefixed_with_dollar_sign" + }, + { + "path": "detect_secrets.filters.heuristic.is_indirect_reference" + }, + { + "path": "detect_secrets.filters.common.is_ignored_due_to_verification_policies", + "min_level": 2 + }, + { + "path": "detect_secrets.filters.heuristic.is_lock_file" + }, + { + "path": "detect_secrets.filters.heuristic.is_not_alphanumeric_string" + }, + { + "path": "detect_secrets.filters.heuristic.is_swagger_file" + } + ] } diff --git a/src/authutils/errors.py b/src/authutils/errors.py index 82ff82e..6ce5d97 100644 --- a/src/authutils/errors.py +++ b/src/authutils/errors.py @@ -24,3 +24,8 @@ class JWTPurposeError(JWTError): class JWTAudienceError(JWTError): pass + + +class JWTScopeError(JWTError): + + pass diff --git a/src/authutils/token/core.py b/src/authutils/token/core.py index 094f719..210940e 100644 --- a/src/authutils/token/core.py +++ b/src/authutils/token/core.py @@ -1,6 +1,12 @@ import jwt -from ..errors import JWTAudienceError, JWTExpiredError, JWTPurposeError, JWTError +from ..errors import ( + JWTAudienceError, + JWTExpiredError, + JWTPurposeError, + JWTScopeError, + JWTError, +) def get_keys_url(issuer): @@ -47,50 +53,76 @@ def validate_purpose(claims, pur): ) -def validate_jwt(encoded_token, public_key, aud, issuers): +def validate_jwt(encoded_token, public_key, aud, scope, issuers, options={}): """ Validate the encoded JWT ``encoded_token``, which must satisfy the - audiences ``aud``. + scopes ``scope``. This is just a slightly lower-level function to decode the token and perform the most basic checks on the token. - Decode JWT using public key; PyJWT will fail if iat or exp fields are invalid - - Check audiences: token audiences must be a superset of required audiences - (the ``aud`` argument); fail if not satisfied + - PyJWT will also fail if the aud field is present in the JWT but no + ``aud`` arg is passed, or if the ``aud`` arg does not match one of + the items in the token aud field + - Check issuers: token iss field must match one of the items in the + ``issuers`` arg + - Check scopes: token scopes must be a superset of required scopes + (the ``scope`` argument); fail if not satisfied Args: encoded_token (str): encoded JWT public_key (str): public key to validate the JWT signature - aud (set): non-empty set of audiences the JWT must satisfy + aud (Optional[str]): + audience with which the app identifies, usually an OIDC + client id, which the JWT will be expected to include in its ``aud`` + claim. Optional; if no ``aud`` argument given, then the JWT must + not have an ``aud`` claim, or validation will fail. + scope (Optional[Iterable[str]]): + set of scopes, each of which the JWT must satisfy in its + ``scope`` claim. Optional. issuers (list or set): allowed issuers whitelist + options (Optional[dict]): options to pass through to pyjwt's decode Return: dict: the decoded and validated JWT Raises: ValueError: if receiving an incorrectly-typed argument - JWTValidationError: if any step of the validation fails + JWTExpiredError: if token is expired + JWTAudienceError: if aud validation fails + JWTScopeError: if scope validation fails + JWTError: if some other token validation step fails """ + # Typecheck arguments. - if not isinstance(aud, set) and not isinstance(aud, list): - raise ValueError("aud must be set or list") + if not isinstance(aud, str) and not aud is None: + raise ValueError( + "aud must be string or None. Instead received aud of type {}".format( + type(aud) + ) + ) + if not isinstance(scope, set) and not isinstance(scope, list) and not scope is None: + raise ValueError( + "scope must be set or list or None. Instead received scope of type {}".format( + type(scope) + ) + ) if not isinstance(issuers, set) and not isinstance(issuers, list): - raise ValueError("issuers must be set or list") - - # To satisfy PyJWT, since the token will contain an aud field, decode has - # to be passed one of the audiences to check here (so PyJWT doesn't raise - # an InvalidAudienceError). Per the JWT specification, if the token - # contains an aud field, the validator MUST identify with one of the - # audiences listed in that field. This implementation is more strict, and - # allows the validator to demand multiple audiences which must all be - # satisfied by the token (see below). - aud = set(aud) - random_aud = list(aud)[0] + raise ValueError( + "issuers must be set or list. Instead received issuers of type {}".format( + type(issuers) + ) + ) + try: token = jwt.decode( - encoded_token, key=public_key, algorithms=["RS256"], audience=random_aud + encoded_token, + key=public_key, + algorithms=["RS256"], + audience=aud, + options=options, ) except jwt.InvalidAudienceError as e: raise JWTAudienceError(e) @@ -99,7 +131,7 @@ def validate_jwt(encoded_token, public_key, aud, issuers): except jwt.InvalidTokenError as e: raise JWTError(e) - # PyJWT validates iat and exp fields (and aud...sort of); everything else + # PyJWT validates iat, exp, and aud fields; everything else # must happen here. # iss @@ -108,12 +140,22 @@ def validate_jwt(encoded_token, public_key, aud, issuers): msg = "invalid issuer {}; expected: {}".format(token["iss"], issuers) raise JWTError(msg) - # aud - # The audiences listed in the token must completely satisfy all the - # required audiences provided. Note that this is stricter than the - # specification suggested in RFC 7519. - missing = aud - set(token["aud"]) - if missing: - raise JWTAudienceError("missing audiences: " + str(missing)) + # scope + # Check that if scope arg was non-empty then the token includes each given scope in its scope claim + if scope: + token_scopes = token.get("scope", []) + if isinstance(token_scopes, str): + token_scopes = [token_scopes] + if not isinstance(token_scopes, list): + raise JWTError( + "invalid format in scope claim: {}; expected string or list".format( + token["scopes"] + ) + ) + missing_scopes = set(scope) - set(token_scopes) + if missing_scopes: + raise JWTScopeError( + "token is missing required scopes: " + str(missing_scopes) + ) return token diff --git a/src/authutils/token/fastapi.py b/src/authutils/token/fastapi.py index 8523d35..4488cfa 100644 --- a/src/authutils/token/fastapi.py +++ b/src/authutils/token/fastapi.py @@ -14,7 +14,9 @@ _jwt_public_keys = {} -def access_token(*audiences, issuer=None, allowed_issuers=None, purpose=None): +def access_token( + *scopes, audience=None, issuer=None, allowed_issuers=None, purpose=None +): """ Validate and return the JWT bearer token in HTTP header:: @@ -25,8 +27,11 @@ def whoami(token=Depends(access_token("user", "openapi", purpose="access"))): return token["iss"] Args: - *audiences: Required, all must occur in ``aud``. - issuer: Force to use this issuer to validate the token if provided. + *scopes: Required, all must occur in ``scope``. + audience: Optional; if provided, JWT validation will require that the token's + ``aud`` value contains the arg value; if not provided, validation will require + that the token not have an aud field. + issuer: Optional; force to use this issuer to validate the token if provided. allowed_issuers: Optional allowed issuers whitelist, default: allow all. purpose: Optional, must match ``pur`` if provided. @@ -34,9 +39,9 @@ def whoami(token=Depends(access_token("user", "openapi", purpose="access"))): Decoded JWT claims as a :class:`dict`. """ - if not audiences: - raise ValueError("Missing parameter: audiences") - audiences = set(audiences) + if not scopes: + raise ValueError("Missing parameter: scopes") + scopes = set(scopes) if not allowed_issuers and issuer: allowed_issuers = [issuer] @@ -93,7 +98,13 @@ async def getter(token: HTTPAuthorizationCredentials = Security(bearer)): # decode and validate the token try: claims = await loop.run_in_executor( - None, core.validate_jwt, token, pub_key, audiences, allowed_issuers + None, + core.validate_jwt, + token, + pub_key, + audience, + scopes, + allowed_issuers, ) if purpose: diff --git a/src/authutils/token/validate.py b/src/authutils/token/validate.py index 5357307..ce6342b 100644 --- a/src/authutils/token/validate.py +++ b/src/authutils/token/validate.py @@ -67,57 +67,74 @@ def get_session_token(): def validate_jwt( encoded_token, - aud, + aud=None, + scope=None, purpose="access", issuers=None, public_key=None, attempt_refresh=True, logger=None, + options={}, ): """ Validate a JWT and return the claims. Args: encoded_token (str): the base64 encoding of the token - aud (Optional[Iterable[str]]): - list of audiences that the token must satisfy; defaults to - ``{'openid'}`` (minimum expected by OpenID provider) + aud (Optional[str]): + audience as which the app identifies, which the JWT will be + expected to include in its ``aud`` claim. + Optional; will default to issuer from flask.current_app.config + if available (either BASE_URL or USER_API). + To skip aud validation, pass the following in the options arg: + options={"verify_aud": False} + scope (Optional[Iterable[str]]): + scopes that the token must satisfy purpose (Optional[str]): which purpose the token is supposed to be used for (access, refresh, or id) issuers (Iterable[str]): list of allowed token issuers public_key (Optional[str]): public key to vaidate JWT with + attempt_refresh (Optional[bool]): + whether to attempt refresh of public keys if not found in cache + options (Optional[dict]): options to pass through to pyjwt's decode Return: dict: dictionary of claims from the validated JWT Raises: - ValueError: if ``aud`` is empty JWTError: if auth header is missing, decoding fails, or the JWT fails to satisfy any expectation """ logger = logger or get_logger(__name__, log_level="info") + if not issuers: issuers = [] for config_var in ["OIDC_ISSUER", "USER_API", "BASE_URL"]: value = flask.current_app.config.get(config_var) if value: issuers.append(value) + + # Can't set arg default to config[x] in fn def, so doing it this way. + if aud is None: + aud = flask.current_app.config.get("BASE_URL") + # Some Gen3 apps use BASE_URL and some use USER_API, so fall back on USER_API + if aud is None: + aud = flask.current_app.config.get("USER_API") + if public_key is None: public_key = get_public_key_for_token( encoded_token, attempt_refresh=attempt_refresh, logger=logger ) - if not aud: - raise ValueError("must provide at least one audience") - aud = set(aud) - claims = core.validate_jwt(encoded_token, public_key, aud, issuers) + + claims = core.validate_jwt(encoded_token, public_key, aud, scope, issuers, options) if purpose: core.validate_purpose(claims, purpose) return claims -def validate_request(aud, purpose="access", logger=None): +def validate_request(scope={}, audience=None, purpose="access", logger=None): """ Validate a ``flask.request`` by checking the JWT contained in the request headers. @@ -132,13 +149,19 @@ def validate_request(aud, purpose="access", logger=None): raise JWTError("no authorization header provided") # Pass token to ``validate_jwt``. - return validate_jwt(encoded_token, aud, purpose, logger=logger) + return validate_jwt( + encoded_token, + aud=audience, + scope=scope, + purpose=purpose, + logger=logger, + ) -def require_auth_header(aud, purpose=None, logger=None): +def require_auth_header(scope={}, audience=None, purpose=None, logger=None): """ Return a decorator which adds request validation to check the given - audiences and (optionally) purpose. + scopes, audience and purpose (all optional). """ logger = logger or get_logger(__name__, log_level="info") @@ -156,7 +179,11 @@ def wrapper(*args, **kwargs): the code inside the function can use the ``LocalProxy`` for the token (see top of this file). """ - set_current_token(validate_request(aud=aud, purpose=purpose, logger=logger)) + set_current_token( + validate_request( + scope=scope, audience=audience, purpose=purpose, logger=logger + ) + ) return f(*args, **kwargs) return wrapper diff --git a/src/authutils/user.py b/src/authutils/user.py index 33dffcf..d6169b2 100644 --- a/src/authutils/user.py +++ b/src/authutils/user.py @@ -11,6 +11,16 @@ def set_current_user(**kwargs): + default_expected_audience = flask.current_app.config.get("USER_API") + # Gen3 services use both USER_API and BASE_URL + if not default_expected_audience: + default_expected_audience = flask.current_app.config.get("BASE_URL") + + # If not already passed an aud to expect, default to the application's url + kwargs.setdefault("jwt_kwargs", {}).setdefault( + "audience", default_expected_audience + ) + flask.g.user = CurrentUser(**kwargs) set_current_token(flask.g.user._claims) return flask.g.user @@ -41,8 +51,8 @@ class CurrentUser(object): def __init__(self, claims=None, jwt_kwargs=None): jwt_kwargs = jwt_kwargs or {} - if "aud" not in jwt_kwargs: - jwt_kwargs["aud"] = {"openid"} + if "scope" not in jwt_kwargs: + jwt_kwargs["scope"] = {"openid"} self._claims = claims or validate_request(**jwt_kwargs) self.id = self._claims["sub"] self.username = self._get_user_info("name") diff --git a/tests/conftest.py b/tests/conftest.py index aa4d899..b4b3a85 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -38,17 +38,25 @@ def iss(): @pytest.fixture(scope="session") -def default_audiences(): +def default_audience(): """ - Return some default audiences to put in the claims of a JWT. + Return default audience to pass to core.validate_jwt calls. """ - # Note that ``test_aud`` here is the audience expected on the test endpoint + return USER_API + + +@pytest.fixture(scope="session") +def default_scopes(): + """ + Return some default scopes to put in the claims of a JWT. + """ + # Note that ``test_scope`` here is the scope expected on the test endpoint # in the test application. - return ["openid", "access", "user", "test_aud"] + return ["openid", "access", "user", "test_scope"] @pytest.fixture(scope="session") -def claims(default_audiences, iss): +def claims(default_audience, default_scopes, iss): """ Return some generic claims to put in a JWT. @@ -60,12 +68,13 @@ def claims(default_audiences, iss): exp = int((now + timedelta(seconds=600)).strftime("%s")) return { "pur": "access", - "aud": default_audiences, "sub": "1234", "iss": iss, + "aud": default_audience, "iat": iat, "exp": exp, "jti": str(uuid.uuid4()), + "scope": default_scopes, "context": {"user": {"name": "test-user", "projects": []}}, } @@ -140,10 +149,12 @@ def app(): """ app = flask.Flask(__name__) app.debug = True + # Gen3 services use both USER_API and BASE_URL app.config["USER_API"] = USER_API + app.config["BASE_URL"] = USER_API @app.route("/test") - @require_auth_header({"test_aud"}, "access") + @require_auth_header({"test_scope"}, USER_API, "access") def test_endpoint(): """ Define a simple endpoint for testing which requires a JWT header for diff --git a/tests/test_fastapi.py b/tests/test_fastapi.py index 4af26c6..f64c50c 100644 --- a/tests/test_fastapi.py +++ b/tests/test_fastapi.py @@ -7,21 +7,23 @@ @pytest.fixture(scope="function") -def async_client(default_audiences, mock_async_get, iss): +def async_client(default_scopes, mock_async_get, iss): mock_async_get() app = fastapi.FastAPI() @app.get("/whoami") def whoami( - token=fastapi.Depends(access_token(*default_audiences, purpose="access")) + token=fastapi.Depends( + access_token(*default_scopes, audience=iss, purpose="access") + ) ): return token @app.get("/force_issuer") def force_issuer( token=fastapi.Depends( - access_token(*default_audiences, issuer=iss, purpose="access") + access_token(*default_scopes, audience=iss, issuer=iss, purpose="access") ) ): return token @@ -30,7 +32,7 @@ def force_issuer( def whitelist( token=fastapi.Depends( access_token( - *default_audiences, + *default_scopes, allowed_issuers=["https://right.example.com"], purpose="access" ) @@ -42,8 +44,8 @@ def whitelist( yield client -def test_no_audience(): - with pytest.raises(ValueError, match="audiences"): +def test_no_scopes(): + with pytest.raises(ValueError, match="scopes"): access_token() diff --git a/tests/test_jwt.py b/tests/test_jwt.py index a5eb70d..9bc0ccd 100644 --- a/tests/test_jwt.py +++ b/tests/test_jwt.py @@ -1,12 +1,13 @@ # pylint: disable=unused-argument from collections import OrderedDict +import jwt import flask import pytest import httpx -from authutils.errors import JWTError, JWTAudienceError, JWTExpiredError +from authutils.errors import JWTError, JWTAudienceError, JWTExpiredError, JWTScopeError from authutils.token.keys import get_public_key from authutils.token.core import validate_jwt from authutils.token.validate import require_auth_header @@ -14,51 +15,127 @@ from tests.utils import TEST_RESPONSE_JSON -def test_valid_signature(claims, encoded_jwt, rsa_public_key, default_audiences, iss): +def test_valid_signature( + claims, encoded_jwt, rsa_public_key, default_audience, default_scopes, iss +): """ Do a basic test of the expected functionality with the sample payload in the fence README. """ - decoded_token = validate_jwt(encoded_jwt, rsa_public_key, default_audiences, [iss]) + decoded_token = validate_jwt( + encoded_jwt, rsa_public_key, default_audience, default_scopes, [iss] + ) assert decoded_token assert decoded_token == claims def test_expired_token_rejected( - encoded_jwt_expired, rsa_public_key, default_audiences, iss + encoded_jwt_expired, rsa_public_key, default_audience, default_scopes, iss ): with pytest.raises(JWTExpiredError): - validate_jwt(encoded_jwt_expired, rsa_public_key, default_audiences, [iss]) + validate_jwt( + encoded_jwt_expired, + rsa_public_key, + default_audience, + default_scopes, + [iss], + ) def test_invalid_signature_rejected( - encoded_jwt, rsa_public_key_2, default_audiences, iss + encoded_jwt, rsa_public_key_2, default_audience, default_scopes, iss ): """ Test that ``validate_jwt`` rejects JWTs signed with a private key not corresponding to the public key it is given. """ with pytest.raises(JWTError): - validate_jwt(encoded_jwt, rsa_public_key_2, default_audiences, [iss]) + validate_jwt( + encoded_jwt, rsa_public_key_2, default_audience, default_scopes, [iss] + ) + + +def test_invalid_scope_rejected(encoded_jwt, rsa_public_key, default_audience, iss): + """ + Test that if ``validate_jwt`` is passed values for ``scope`` which do not + appear in the token, a ``JWTScopeError`` is raised. + """ + with pytest.raises(JWTScopeError): + validate_jwt( + encoded_jwt, rsa_public_key, default_audience, {"not-in-scopes"}, [iss] + ) + + +def test_missing_aud_rejected(encoded_jwt, rsa_public_key, default_scopes, iss): + """ + Test that if ``validate_jwt`` is passed a value for ``aud`` which does not + appear in the token, a ``JWTError`` is raised. + """ + with pytest.raises(JWTError): + validate_jwt(encoded_jwt, rsa_public_key, "not-in-aud", default_scopes, [iss]) -def test_invalid_aud_rejected(encoded_jwt, rsa_public_key, iss): +def test_unexpected_aud_rejected( + encoded_jwt, + rsa_public_key, + default_scopes, + iss, +): """ - Test that if ``validate_jwt`` is passed values for ``aud`` which do not - appear in the token, a ``JWTAudienceError`` is raised. + Test that if the token contains an ``aud`` claim and no ``aud`` arg is passed + to ``validate_jwt``, a ``JWTAudienceError`` is raised. """ with pytest.raises(JWTAudienceError): - validate_jwt(encoded_jwt, rsa_public_key, {"not-in-aud"}, [iss]) + validate_jwt(encoded_jwt, rsa_public_key, None, default_scopes, [iss]) + + +def test_expected_missing_aud_accepted( + claims, + token_headers, + rsa_private_key, + rsa_public_key, + default_scopes, + iss, +): + """ + Test that if no ``aud`` arg is passed to ``validate_jwt`` and the token does NOT + contain an ``aud`` claim then validation passes. + """ + claims = claims.copy() + claims.pop("aud") + encoded_token = jwt.encode( + claims, headers=token_headers, key=rsa_private_key, algorithm="RS256" + ) + validate_jwt(encoded_token, rsa_public_key, None, default_scopes, [iss]) + + +def test_valid_aud_accepted( + claims, token_headers, rsa_private_key, rsa_public_key, default_scopes, iss +): + """ + Test that if the token contains multiple audience values in its ``aud`` claim + and one of those values is passed to ``validate_jwt`` then validation passes. + """ + claims = claims.copy() + claims["aud"] = ["foo", "bar", "baz"] + encoded_token = jwt.encode( + claims, headers=token_headers, key=rsa_private_key, algorithm="RS256" + ) + validate_jwt(encoded_token, rsa_public_key, "baz", default_scopes, [iss]) -def test_invalid_iss_rejected(encoded_jwt, rsa_public_key, iss): +def test_invalid_iss_rejected( + encoded_jwt, rsa_public_key, default_audience, default_scopes, iss +): """ Test that if ``validate_jwt`` receives a token whose value for ``iss`` does not match the expected value, a ``JWTValidationError`` is raised. """ wrong_iss = iss + "garbage" with pytest.raises(JWTError): - validate_jwt(encoded_jwt, rsa_public_key, {"not-in-aud"}, [wrong_iss]) + validate_jwt( + encoded_jwt, rsa_public_key, default_audience, default_scopes, [wrong_iss] + ) def test_get_public_key(app, example_keys_response, mock_get): @@ -114,56 +191,42 @@ def test_validate_request_jwt_bad_header(client, mock_get, encoded_jwt): client.get("/test", headers=incorrect_headers) -def test_validate_request_jwt_incorrect_usage(app, client, auth_header, mock_get): - """ - Test that if a ``require_auth_header`` caller does not give it any - audiences, a JWTAudienceError is raised. - """ - mock_get() - - # This should raise a ValueError, since no audiences are provided. - @require_auth_header({}, "access") - def bad(): - return flask.jsonify({"foo": "bar"}) - - app.add_url_rule("/test_incorrect_usage", "bad", bad) - - with pytest.raises(ValueError): - client.get("/test_incorrect_usage", headers=auth_header) - - -def test_validate_request_jwt_missing(app, client, auth_header, mock_get): +def test_validate_request_jwt_missing_all_scopes( + app, client, auth_header, default_audience, mock_get +): """ - Test that if the JWT is completely missing an audience which is required by - an endpoint, a ``jwt.InvalidAudienceError`` is raised. + Test that if the JWT is completely missing a scope which is required by + an endpoint, a ``JWTScopeError`` is raised. """ mock_get() - # This should raise jwt.InvalidAudienceError, since the audience it + # This should raise a JWTScopeError, since the scope it # requires does not appear in the default JWT anywhere. - @app.route("/test_missing_audience") - @require_auth_header({"missing_audience"}, "access") + @app.route("/test_missing_scope") + @require_auth_header({"missing_scope"}, default_audience, "access") def bad(): return flask.jsonify({"foo": "bar"}) - with pytest.raises(JWTAudienceError): - client.get("/test_missing_audience", headers=auth_header) + with pytest.raises(JWTScopeError): + client.get("/test_missing_scope", headers=auth_header) -def test_validate_request_jwt_missing_some(app, client, auth_header, mock_get): +def test_validate_request_jwt_missing_some_scopes( + app, client, auth_header, default_audience, mock_get +): """ - Test that if the JWT satisfies some audiences but is missing at least one - audience which is required by an endpoint, a ``jwt.InvalidAudienceError`` + Test that if the JWT satisfies some scopes but is missing at least one + scope which is required by an endpoint, a ``JWTScopeError`` is raised. """ mock_get() - # This should raise JWTAudienceError, since the audience it requires does + # This should raise JWTScopeError, since the scope it requires does # not appear in the default JWT anywhere. - @app.route("/test_missing_audience") - @require_auth_header({"access", "missing_audience"}, "access") + @app.route("/test_missing_scope") + @require_auth_header({"access", "missing_scope"}, default_audience, "access") def bad(): return flask.jsonify({"foo": "bar"}) - with pytest.raises(JWTAudienceError): - client.get("/test_missing_audience", headers=auth_header) + with pytest.raises(JWTScopeError): + client.get("/test_missing_scope", headers=auth_header)