Skip to content

Commit

Permalink
pysnippetGH-22: Using starlette like exception handling
Browse files Browse the repository at this point in the history
That is:
- raising starlette.authentication.AuthenticationError
- providing an on_error callback turning starlet 400 into 401
  to keep same api
- letting the user provide their own on_error when instantiating
  the middleware.
  • Loading branch information
vokimon committed Jul 30, 2024
1 parent 089d648 commit b3ce107
Showing 1 changed file with 11 additions and 3 deletions.
14 changes: 11 additions & 3 deletions src/fastapi_oauth2/middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,13 @@
from jose.jwt import encode as jwt_encode
from starlette.authentication import AuthCredentials
from starlette.authentication import AuthenticationBackend
from starlette.authentication import AuthenticationError
from starlette.authentication import BaseUser
from starlette.middleware.authentication import AuthenticationMiddleware
from starlette.requests import Request
from starlette.requests import HTTPConnection
from starlette.responses import PlainTextResponse
from starlette.responses import Response
from starlette.types import ASGIApp
from starlette.types import Receive
from starlette.types import Scope
Expand Down Expand Up @@ -111,9 +114,9 @@ async def authenticate(self, request: Request) -> Optional[Tuple[Auth, User]]:
try:
token_data = Auth.jwt_decode(param)
except JOSEError as e:
raise OAuth2AuthenticationError(401, str(e))
raise AuthenticationError(str(e))
if token_data["exp"] and token_data["exp"] < int(datetime.now(timezone.utc).timestamp()):
raise OAuth2AuthenticationError(401, "Token expired")
raise AuthenticationError("Token expired")

user = User(token_data)
auth = Auth(user.pop("scope", []))
Expand All @@ -138,6 +141,7 @@ def __init__(
app: ASGIApp,
config: Union[OAuth2Config, dict],
callback: Callable[[Auth, User], Union[Awaitable[None], None]] = None,
on_error: Callable[[HTTPConnection, AuthenticationError], Response] | None = None,
**kwargs, # AuthenticationMiddleware kwargs
) -> None:
"""Initiates the middleware with the given configuration.
Expand All @@ -151,9 +155,13 @@ def __init__(
elif not isinstance(config, OAuth2Config):
raise TypeError("config is not a valid type")
self.default_application_middleware = app
self.auth_middleware = AuthenticationMiddleware(app, backend=OAuth2Backend(config, callback), **kwargs)
self.auth_middleware = AuthenticationMiddleware(app, backend=OAuth2Backend(config, callback), on_error = on_error or self.on_error, **kwargs)

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if scope["type"] == "http":
return await self.auth_middleware(scope, receive, send)
await self.default_application_middleware(scope, receive, send)

@staticmethod
def on_error(conn: HTTPConnection, exc: Exception) -> Response:
return PlainTextResponse(str(exc), status_code=401)

0 comments on commit b3ce107

Please sign in to comment.