diff --git a/src/api/qualicharge/api/v1/__init__.py b/src/api/qualicharge/api/v1/__init__.py index a69e3e4f..68d5345b 100644 --- a/src/api/qualicharge/api/v1/__init__.py +++ b/src/api/qualicharge/api/v1/__init__.py @@ -1,18 +1,32 @@ """QualiCharge API v1.""" import logging -from typing import Annotated +from typing import Annotated, Union -from fastapi import FastAPI, Security +from fastapi import FastAPI, Request, Security, status +from fastapi.responses import JSONResponse from qualicharge.auth.models import IDToken, User from qualicharge.auth.oidc import get_token +from qualicharge.exceptions import OIDCAuthenticationError, OIDCProviderException logger = logging.getLogger(__name__) app = FastAPI(title="QualiCharge API (v1)") +@app.exception_handler(OIDCAuthenticationError) +@app.exception_handler(OIDCProviderException) +async def authentication_exception_handler( + request: Request, exc: Union[OIDCAuthenticationError, OIDCProviderException] +): + """Handle authentication errors.""" + return JSONResponse( + status_code=status.HTTP_403_FORBIDDEN, + content={"message": f"Authentication failed: {exc.name}"}, + ) + + @app.get("/whoami") async def me(token: Annotated[IDToken, Security(get_token)]) -> User: """A test endpoint to validate user authentication.""" diff --git a/src/api/qualicharge/auth/oidc.py b/src/api/qualicharge/auth/oidc.py index 6461d39c..aba95a6c 100644 --- a/src/api/qualicharge/auth/oidc.py +++ b/src/api/qualicharge/auth/oidc.py @@ -104,9 +104,15 @@ def get_token( "verify_aud": True, }, ) - except (ExpiredSignatureError, JWTError, JWTClaimsError) as exc: + except ExpiredSignatureError as exc: + logger.error("Token signature expired: %s", exc) + raise OIDCAuthenticationError("Token signature expired") from exc + except JWTError as exc: logger.error("Unable to decode the ID token: %s", exc) raise OIDCAuthenticationError("Unable to decode ID token") from exc + except JWTClaimsError as exc: + logger.error("Bad token claims: %s", exc) + raise OIDCAuthenticationError("Bad token claims") from exc logger.debug(f"{decoded_token=}") return IDToken(**decoded_token) diff --git a/src/api/qualicharge/exceptions.py b/src/api/qualicharge/exceptions.py index d7e4e1de..a299f6d3 100644 --- a/src/api/qualicharge/exceptions.py +++ b/src/api/qualicharge/exceptions.py @@ -1,9 +1,17 @@ """QualiCharge exceptions.""" -class OIDCAuthenticationError(Exception): +class QualiChargeExceptionMixin: + """A mixin for QualiCharge exceptions.""" + + def __init__(self, name: str): + """Add name property for our exception handler.""" + self.name = name + + +class OIDCAuthenticationError(QualiChargeExceptionMixin, Exception): """Raised when the OIDC authentication flow fails.""" -class OIDCProviderException(Exception): +class OIDCProviderException(QualiChargeExceptionMixin, Exception): """Raised when the OIDC provider does not behave as expected.""" diff --git a/src/api/tests/api/v1/test_root.py b/src/api/tests/api/v1/test_root.py index 6a4cb9ac..a8babc7c 100644 --- a/src/api/tests/api/v1/test_root.py +++ b/src/api/tests/api/v1/test_root.py @@ -15,3 +15,5 @@ def test_whoami_auth(client_auth): response = client_auth.get("/whoami") assert response.status_code == status.HTTP_200_OK assert response.json() == {"email": "john@doe.com"} + +# FIXME add tests with invalid tokens diff --git a/src/api/tests/test_auth.py b/src/api/tests/test_auth.py index f1c706a3..17f93541 100644 --- a/src/api/tests/test_auth.py +++ b/src/api/tests/test_auth.py @@ -96,5 +96,5 @@ def test_get_token(httpx_mock, monkeypatch, id_token_factory: IDTokenFactory): key="secret", ), ) - with pytest.raises(OIDCAuthenticationError, match="Unable to decode ID token"): + with pytest.raises(OIDCAuthenticationError, match="Token signature expired"): get_token(security_scopes=SecurityScopes(), token=bearer_token)