diff --git a/app/apps/users/auth.py b/app/apps/users/auth.py index e0f006dbd..28cc1eecb 100644 --- a/app/apps/users/auth.py +++ b/app/apps/users/auth.py @@ -1,4 +1,5 @@ import logging +import time from django.conf import settings from drf_spectacular.contrib.rest_framework_simplejwt import SimpleJWTScheme @@ -8,6 +9,58 @@ from .auth_dev import DevelopmentAuthenticationBackend + +class OIDCAuthenticationBackend(OIDCAuthenticationBackend): + def validate_issuer(self, payload): + issuer = self.get_settings("OIDC_OP_ISSUER") + if not issuer == payload["iss"]: + raise Exception( + '"iss": %r does not match configured value for OIDC_OP_ISSUER: %r' + % (payload["iss"], issuer) + ) + + def validate_audience(self, payload): + client_id = self.get_settings("OIDC_RP_CLIENT_ID") + trusted_audiences = self.get_settings("OIDC_TRUSTED_AUDIENCES", []) + trusted_audiences = set(trusted_audiences) + trusted_audiences.add(client_id) + + audience = payload["aud"] + if not isinstance(audience, list): + audience = [audience] + audience = set(audience) + distrusted_audiences = audience.difference(trusted_audiences) + if distrusted_audiences: + raise Exception( + '"aud" contains distrusted audiences: %r' % distrusted_audiences + ) + + def validate_expiry(self, payload): + expire_time = payload["exp"] + now = time.time() + if now > expire_time: + raise Exception("Id-token is expired %r > %r" % (now, expire_time)) + + def validate_id_token(self, payload): + """Validate the content of the id token as required by OpenID Connect 1.0 + + This aims to fulfill point 2. 3. and 9. under section 3.1.3.7. ID Token + Validation + """ + self.validate_issuer(payload) + self.validate_audience(payload) + self.validate_expiry(payload) + return payload + + def get_userinfo(self, access_token, id_token=None, payload=None): + """ + Get user info from the OIDC provider. + """ + userinfo = self.verify_token(access_token) + self.validate_id_token(userinfo) + return userinfo + + LOGGER = logging.getLogger(__name__) if settings.LOCAL_DEVELOPMENT_AUTHENTICATION: diff --git a/app/config/settings.py b/app/config/settings.py index 97f6e79eb..a2212a1f0 100644 --- a/app/config/settings.py +++ b/app/config/settings.py @@ -295,6 +295,11 @@ def filter_traces(envelope): "OIDC_OP_JWKS_ENDPOINT", "https://login.microsoftonline.com/72fca1b1-2c2e-4376-a445-294d80196804/discovery/v2.0/keys", ) +OIDC_RP_SIGN_ALGO = "RS256" +OIDC_OP_ISSUER = os.getenv( + "OIDC_OP_ISSUER", + "https://login.microsoftonline.com/72fca1b1-2c2e-4376-a445-294d80196804/v2.0", +) LOCAL_DEVELOPMENT_AUTHENTICATION = ( os.getenv("LOCAL_DEVELOPMENT_AUTHENTICATION", False) == "True"