Skip to content

Commit

Permalink
pysnippetGH-22: Handle possible exceptions in token data obtaining flow
Browse files Browse the repository at this point in the history
  • Loading branch information
ArtyomVancyan committed Oct 6, 2023
1 parent 23658d3 commit 7b18c0a
Showing 1 changed file with 15 additions and 8 deletions.
23 changes: 15 additions & 8 deletions src/fastapi_oauth2/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,19 @@
from urllib.parse import urljoin

import httpx
from oauthlib.oauth2 import OAuth2Error
from oauthlib.oauth2 import WebApplicationClient
from oauthlib.oauth2.rfc6749.errors import CustomOAuth2Error
from social_core.backends.oauth import BaseOAuth2
from social_core.exceptions import AuthException
from social_core.strategy import BaseStrategy
from starlette.requests import Request
from starlette.responses import RedirectResponse

from .claims import Claims
from .client import OAuth2Client
from .exceptions import OAuth2LoginError
from .exceptions import OAuth2AuthenticationError
from .exceptions import OAuth2BadCredentialsError
from .exceptions import OAuth2InvalidRequestError


class OAuth2Strategy(BaseStrategy):
Expand Down Expand Up @@ -92,11 +95,11 @@ def authorization_redirect(self, request: Request) -> RedirectResponse:

async def token_data(self, request: Request, **httpx_client_args) -> dict:
if not request.query_params.get("code"):
raise OAuth2LoginError(400, "'code' parameter was not found in callback request")
raise OAuth2InvalidRequestError(400, "'code' parameter was not found in callback request")
if not request.query_params.get("state"):
raise OAuth2LoginError(400, "'state' parameter was not found in callback request")
raise OAuth2InvalidRequestError(400, "'state' parameter was not found in callback request")
if request.query_params.get("state") != self._state:
raise OAuth2LoginError(400, "'state' parameter does not match")
raise OAuth2InvalidRequestError(400, "'state' parameter does not match")

redirect_uri = self.get_redirect_uri(request)
scheme = "http" if request.auth.http else "https"
Expand All @@ -113,12 +116,16 @@ async def token_data(self, request: Request, **httpx_client_args) -> dict:
headers.update({"Accept": "application/json"})
auth = httpx.BasicAuth(self.client_id, self.client_secret)
async with httpx.AsyncClient(auth=auth, **httpx_client_args) as session:
response = await session.post(token_url, headers=headers, content=content)
try:
response = await session.post(token_url, headers=headers, content=content)
self._oauth_client.parse_request_body_response(json.dumps(response.json()))
return self.standardize(self.backend.user_data(self.access_token))
except (CustomOAuth2Error, Exception) as e:
raise OAuth2LoginError(400, str(e))
except OAuth2Error as e:
raise OAuth2InvalidRequestError(400, str(e))
except httpx.HTTPError as e:
raise OAuth2BadCredentialsError(400, str(e))
except (AuthException, Exception) as e:
raise OAuth2AuthenticationError(401, str(e))

async def token_redirect(self, request: Request, **kwargs) -> RedirectResponse:
access_token = request.auth.jwt_create(await self.token_data(request, **kwargs))
Expand Down

0 comments on commit 7b18c0a

Please sign in to comment.