Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement refresh endpoint #11

Merged
merged 15 commits into from
Jan 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 22 additions & 1 deletion ldap_jwt_auth/auth/jwt_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from ldap_jwt_auth.core.config import config
from ldap_jwt_auth.core.constants import PRIVATE_KEY, PUBLIC_KEY
from ldap_jwt_auth.core.exceptions import InvalidJWTError
from ldap_jwt_auth.core.exceptions import InvalidJWTError, JWTRefreshError

logger = logging.getLogger()

Expand Down Expand Up @@ -44,6 +44,27 @@ def get_refresh_token(self) -> str:
}
return self._pack_jwt(payload)

def refresh_access_token(self, access_token: str, refresh_token: str):
"""
Refreshes the JWT access token by updating its expiry time, provided that the JWT refresh token is valid.
:param access_token: The JWT access token to refresh.
:param refresh_token: The JWT refresh token.
:raises JWTRefreshError: If the JWT access token cannot be refreshed.
:return: JWT access token with an updated expiry time.
"""
logger.info("Refreshing access token")
self.verify_token(refresh_token)
try:
payload = self._get_jwt_payload(access_token, {"verify_exp": False})
payload["exp"] = datetime.now(timezone.utc) + timedelta(
minutes=config.authentication.access_token_validity_minutes
)
return self._pack_jwt(payload)
except Exception as exc:
message = "Unable to refresh access token"
logger.exception(message)
raise JWTRefreshError(message) from exc

def verify_token(self, token: str) -> Dict[str, Any]:
"""
Verifies that the provided JWT token is valid. It does this by checking that it was signed by the corresponding
Expand Down
6 changes: 6 additions & 0 deletions ldap_jwt_auth/core/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,12 @@ class InvalidJWTError(Exception):
"""


class JWTRefreshError(Exception):
"""
Exception raised when JWT access token cannot be refreshed.
"""


class LDAPServerError(Exception):
"""
Exception raised when there is problem with the LDAP server.
Expand Down
3 changes: 2 additions & 1 deletion ldap_jwt_auth/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from ldap_jwt_auth.core.config import config
from ldap_jwt_auth.core.logger_setup import setup_logger
from ldap_jwt_auth.routers import login, verify
from ldap_jwt_auth.routers import login, refresh, verify

app = FastAPI(title=config.api.title, description=config.api.description)

Expand Down Expand Up @@ -63,6 +63,7 @@ async def custom_validation_exception_handler(request: Request, exc: RequestVali
)

app.include_router(login.router)
app.include_router(refresh.router)
app.include_router(verify.router)


Expand Down
39 changes: 39 additions & 0 deletions ldap_jwt_auth/routers/refresh.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
"""
Module for providing an API router which defines a route for managing the refreshing/updating of a JWT access token
using a JWT refresh token.
"""
import logging
from typing import Annotated

from fastapi import APIRouter, Body, Cookie, Depends, HTTPException, status
from fastapi.responses import JSONResponse

from ldap_jwt_auth.auth.jwt_handler import JWTHandler
from ldap_jwt_auth.core.exceptions import JWTRefreshError, InvalidJWTError

logger = logging.getLogger()

router = APIRouter(prefix="/refresh", tags=["authentication"])


@router.post(
path="",
summary="Generate an updated JWT access token using the JWT refresh token",
response_description="A JWT access token",
)
def refresh_access_token(
jwt_handler: Annotated[JWTHandler, Depends(JWTHandler)],
token: Annotated[str, Body(description="The JWT access token to refresh", embed=True)],
refresh_token: Annotated[str | None, Cookie(description="The JWT refresh token from an HTTP-only cookie")] = None,
) -> JSONResponse:
# pylint: disable=missing-function-docstring
if refresh_token is None:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="No JWT refresh token found")

try:
access_token = jwt_handler.refresh_access_token(token, refresh_token)
return JSONResponse(content=access_token)
except (InvalidJWTError, JWTRefreshError) as exc:
message = "Unable to refresh access token"
logger.exception(message)
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=message) from exc
153 changes: 107 additions & 46 deletions test/unit/auth/test_jwt_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,49 @@
import pytest

from ldap_jwt_auth.auth.jwt_handler import JWTHandler
from ldap_jwt_auth.core.exceptions import InvalidJWTError
from ldap_jwt_auth.core.exceptions import InvalidJWTError, JWTRefreshError

VALID_ACCESS_TOKEN = (
"eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJ1c2VybmFtZSI6InVzZXJuYW1lIiwiZXhwIjoyNTM0MDIzMDA3OTl9.bagU2Wix8wKzydVU_L3Z"
"ZuuMAxGxV4OTuZq_kS2Fuwm839_8UZOkICnPTkkpvsm1je0AWJaIXLGgwEa5zUjpG6lTrMMmzR9Zi63F0NXpJqQqoOZpTBMYBaggsXqFkdsv-yAKUZ"
"8MfjCEyk3UZ4PXZmEcUZcLhKcXZr4kYJPjio2e5WOGpdjK6q7s-iHGs9DQFT_IoCnw9CkyOKwYdgpB35hIGHkNjiwVSHpyKbFQvzJmIv5XCTSRYqq0"
"1fldh-QYuZqZeuaFidKbLRH610o2-1IfPMUr-yPtj5PZ-AaX-XTLkuMqdVMCk0_jeW9Os2BPtyUDkpcu1fvW3_S6_dK3nQ"
)

VALID_REFRESH_TOKEN = (
"eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOjI1MzQwMjMwMDc5OX0.h4Hv_sq4-ika1rpuRx7k3pp0cF_BZ65WVSbIHS7oh9SjPpGHt"
"GhVHU1IJXzFtyA9TH-68JpAZ24Dm6bXbH6VJKoc7RCbmJXm44ufN32ga7jDqXH340oKvi_wdhEHaCf2HXjzsHHD7_D6XIcxU71v2W5_j8Vuwpr3SdX"
"6ea_yLIaCDWynN6FomPtUepQAOg3c7DdKohbJD8WhKIDV8UKuLtFdRBfN4HEK5nNs0JroROPhcYM9L_JIQZpdI0c83fDFuXQC-cAygzrSnGJ6O4DyS"
"cNL3VBNSmNTBtqYOs1szvkpvF9rICPgbEEJnbS6g5kmGld3eioeuDJIxeQglSbxog"
)

EXPIRED_ACCESS_TOKEN = (
"eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJ1c2VybmFtZSI6InVzZXJuYW1lIiwiZXhwIjotNjIxMzU1OTY4MDB9.G_cfC8PNYE5yERyyQNRk"
"9mTmDusU_rEPgm7feo2lWQF6QMNnf8PUN-61FfMNRVE0QDSvAmIMMNEOa8ma0JHZARafgnYJfn1_FSJSoRxC740GpG8EFSWrpM-dQXnoD263V9FlK-"
"On6IbhF-4Rh9MdoxNyZk2Lj7NvCzJ7gbgbgYM5-sJXLxB-I5LfMfuYM3fx2cRixZFA153l46tFzcMVBrAiBxl_LdyxTIOPfHF0UGlaW2UtFi02gyBU"
"4E4wTOqPc4t_CSi1oBSbY7h9O63i8IU99YsOCdvZ7AD3ePxyM1xJR7CFHycg9Z_IDouYnJmXpTpbFMMl7SjME3cVMfMrAQ"
)

EXPIRED_REFRESH_TOKEN = (
"eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOi02MjEzNTU5NjgwMH0.Er0A8dvdZi7o1FK3b-Te2IkUjDJZjI0aANsP7bbAbeITPRnR0"
"YEhavmuLT1zaoALQjUzfSgtH0s3I-YbUr2ssqG1DnKh83uts3J2_EXIXQZBeuZisCW1nN1LC2nsR6o4HQEsbMsINjJviHeMWS8nRC06XXpN1WFPaGB"
"xXkLFeDWb3SXiirZ79m7lUBwQvVzpfeA337e_AejG45mtadgfW3xpDCw-6sVVIA-cuzruxnjRKAzJrw_goA9X4MukRXbnzou2mgkxFKs_-6hdTFDI-"
"B47wYqalP6KC5nqzjrCpvjmukgM-DN0uAhm2TUzUmE5EXtRLEYMRqsSmog4hYq1Nw"
)

EXPECTED_ACCESS_TOKEN = (
"eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJ1c2VybmFtZSI6InVzZXJuYW1lIiwiZXhwIjoxNzA1NDg1OTAwfQ.aWJ8T8RGHF93YhRSP9nOAD"
"EKY9nFjVIDu7RQhPGiMpvhgdpPBP17VQPbJ6Smt8mG1TjLXjquJZaDQRF7syrJd8ESDo-lh3ef-cMWg2hWZpbtpQaPaNHLAAMrjZo97qLxrBjeOKjY"
"ggqwKMr-7g_LlB--z9GiQrLJVhpGxAXjnTy9VSrioZIU7OE9L9tUyOI7LGjY0X2znWQ3Loy5sMwCP_SeFHBPolKXiErKeLItriaxYNEc5l5VXD2wsK"
"G9L8dDZZwe4BSU2eyT_2hhPTrVNfI8-J1KtwpLywC0NfS0Vaksy4HG2IbH8hpl6gaLZhtr2C5_0H_IpkTsvm_Zsnzhbg"
)

EXPECTED_REFRESH_TOKEN = (
"eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOjE3MDYwOTA0MDB9.IHua0NcHiLOz7vamvcR4lxt-t51_UgzIQzho5vYK2UdHjG-bA5Sk"
"9YhHQy480UK4FiIKohpb8G70OwmsSCjzxvbo41MZKdz3z0z_4-L0_LSGLGGmxbvPaHy6_SI8qI1f7KOAD6T3OU1zIFTcyoREEN2uNRyjMnGcQzh72d"
"NkRAFEF3um4S2WVL0mwQ6ZltAjCiA2R8o5Eu3Aq67lkbq00ml69rfecT1JXiAfjrnW0J64COJDbQ9kVCNM1YrpqLBmROHMOOw9o7Qz1h78LbtKarVk"
"VGaPIxhdZsWKjZwDD-6h15NZuKTAmcPUaucx6Dd4uCjJHld1BNsfKfX_81G03g"
)


def mock_datetime_now() -> datetime:
Expand All @@ -23,52 +65,81 @@ def test_get_access_token(datetime_mock):
"""
Test getting an access token.
"""
expected_access_token = (
"eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJ1c2VybmFtZSI6InVzZXJuYW1lIiwiZXhwIjoxNzA1NDg1OTAwfQ.aWJ8T8RGHF93YhRSP9"
"nOADEKY9nFjVIDu7RQhPGiMpvhgdpPBP17VQPbJ6Smt8mG1TjLXjquJZaDQRF7syrJd8ESDo-lh3ef-cMWg2hWZpbtpQaPaNHLAAMrjZo97qLx"
"rBjeOKjYggqwKMr-7g_LlB--z9GiQrLJVhpGxAXjnTy9VSrioZIU7OE9L9tUyOI7LGjY0X2znWQ3Loy5sMwCP_SeFHBPolKXiErKeLItriaxYN"
"Ec5l5VXD2wsKG9L8dDZZwe4BSU2eyT_2hhPTrVNfI8-J1KtwpLywC0NfS0Vaksy4HG2IbH8hpl6gaLZhtr2C5_0H_IpkTsvm_Zsnzhbg"
)
datetime_mock.now.return_value = mock_datetime_now()

jwt_handler = JWTHandler()
access_token = jwt_handler.get_access_token("username")

assert access_token == expected_access_token
assert access_token == EXPECTED_ACCESS_TOKEN


@patch("ldap_jwt_auth.auth.jwt_handler.datetime")
def test_get_refresh_token(datetime_mock):
"""
Test getting a refresh token.
"""
expected_refresh_token = (
"eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOjE3MDYwOTA0MDB9.IHua0NcHiLOz7vamvcR4lxt-t51_UgzIQzho5vYK2UdHjG-b"
"A5Sk9YhHQy480UK4FiIKohpb8G70OwmsSCjzxvbo41MZKdz3z0z_4-L0_LSGLGGmxbvPaHy6_SI8qI1f7KOAD6T3OU1zIFTcyoREEN2uNRyjMn"
"GcQzh72dNkRAFEF3um4S2WVL0mwQ6ZltAjCiA2R8o5Eu3Aq67lkbq00ml69rfecT1JXiAfjrnW0J64COJDbQ9kVCNM1YrpqLBmROHMOOw9o7Qz"
"1h78LbtKarVkVGaPIxhdZsWKjZwDD-6h15NZuKTAmcPUaucx6Dd4uCjJHld1BNsfKfX_81G03g"
)
datetime_mock.now.return_value = mock_datetime_now()

jwt_handler = JWTHandler()
refresh_token = jwt_handler.get_refresh_token()

assert refresh_token == expected_refresh_token
assert refresh_token == EXPECTED_REFRESH_TOKEN


@patch("ldap_jwt_auth.auth.jwt_handler.datetime")
def test_refresh_access_token(datetime_mock):
"""
Test refreshing an expired access token with a valid refresh token.
"""
datetime_mock.now.return_value = mock_datetime_now()

jwt_handler = JWTHandler()
access_token = jwt_handler.refresh_access_token(EXPIRED_ACCESS_TOKEN, VALID_REFRESH_TOKEN)

assert access_token == EXPECTED_ACCESS_TOKEN


@patch("ldap_jwt_auth.auth.jwt_handler.datetime")
def test_refresh_access_token_with_valid_access_token(datetime_mock):
"""
Test refreshing a valid access token with a valid refresh token.
"""
datetime_mock.now.return_value = mock_datetime_now()

jwt_handler = JWTHandler()
access_token = jwt_handler.refresh_access_token(VALID_ACCESS_TOKEN, VALID_REFRESH_TOKEN)

assert access_token == EXPECTED_ACCESS_TOKEN


def test_refresh_access_token_with_invalid_access_token():
"""
Test refreshing an invalid access token with a valid refresh token.
"""
jwt_handler = JWTHandler()

with pytest.raises(JWTRefreshError) as exc:
jwt_handler.refresh_access_token("invalid", VALID_REFRESH_TOKEN)
assert str(exc.value) == "Unable to refresh access token"


def test_refresh_access_token_with_expired_refresh_token():
"""
Test refreshing an expired access token with an expired refresh token.
"""
jwt_handler = JWTHandler()

with pytest.raises(InvalidJWTError) as exc:
jwt_handler.refresh_access_token(EXPIRED_ACCESS_TOKEN, EXPIRED_REFRESH_TOKEN)
assert str(exc.value) == "Invalid JWT token"


def test_verify_token_with_access_token():
"""
Test verifying a valid access token.
"""
access_token = (
"eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJ1c2VybmFtZSI6InVzZXJuYW1lIiwiZXhwIjoyNTM0MDIzMDA3OTl9.bagU2Wix8wKzydVU"
"_L3ZZuuMAxGxV4OTuZq_kS2Fuwm839_8UZOkICnPTkkpvsm1je0AWJaIXLGgwEa5zUjpG6lTrMMmzR9Zi63F0NXpJqQqoOZpTBMYBaggsXqFkd"
"sv-yAKUZ8MfjCEyk3UZ4PXZmEcUZcLhKcXZr4kYJPjio2e5WOGpdjK6q7s-iHGs9DQFT_IoCnw9CkyOKwYdgpB35hIGHkNjiwVSHpyKbFQvzJm"
"Iv5XCTSRYqq01fldh-QYuZqZeuaFidKbLRH610o2-1IfPMUr-yPtj5PZ-AaX-XTLkuMqdVMCk0_jeW9Os2BPtyUDkpcu1fvW3_S6_dK3nQ"
)

jwt_handler = JWTHandler()
payload = jwt_handler.verify_token(access_token)
payload = jwt_handler.verify_token(VALID_ACCESS_TOKEN)

assert payload == {"username": "username", "exp": 253402300799}

Expand All @@ -77,15 +148,8 @@ def test_verify_token_with_refresh_token():
"""
Test verifying a valid refresh token.
"""
refresh_token = (
"eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOjI1MzQwMjMwMDc5OX0.h4Hv_sq4-ika1rpuRx7k3pp0cF_BZ65WVSbIHS7oh9SjP"
"pGHtGhVHU1IJXzFtyA9TH-68JpAZ24Dm6bXbH6VJKoc7RCbmJXm44ufN32ga7jDqXH340oKvi_wdhEHaCf2HXjzsHHD7_D6XIcxU71v2W5_j8V"
"uwpr3SdX6ea_yLIaCDWynN6FomPtUepQAOg3c7DdKohbJD8WhKIDV8UKuLtFdRBfN4HEK5nNs0JroROPhcYM9L_JIQZpdI0c83fDFuXQC-cAyg"
"zrSnGJ6O4DyScNL3VBNSmNTBtqYOs1szvkpvF9rICPgbEEJnbS6g5kmGld3eioeuDJIxeQglSbxog"
)

jwt_handler = JWTHandler()
payload = jwt_handler.verify_token(refresh_token)
payload = jwt_handler.verify_token(VALID_REFRESH_TOKEN)

assert payload == {"exp": 253402300799}

Expand All @@ -94,33 +158,30 @@ def test_verify_token_with_expired_access_token():
"""
Test verifying an expired access token.
"""
expired_access_token = (
"eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJ1c2VybmFtZSI6InVzZXJuYW1lIiwiZXhwIjotNjIxMzU1OTY4MDB9.G_cfC8PNYE5yERyy"
"QNRk9mTmDusU_rEPgm7feo2lWQF6QMNnf8PUN-61FfMNRVE0QDSvAmIMMNEOa8ma0JHZARafgnYJfn1_FSJSoRxC740GpG8EFSWrpM-dQXnoD2"
"63V9FlK-On6IbhF-4Rh9MdoxNyZk2Lj7NvCzJ7gbgbgYM5-sJXLxB-I5LfMfuYM3fx2cRixZFA153l46tFzcMVBrAiBxl_LdyxTIOPfHF0UGla"
"W2UtFi02gyBU4E4wTOqPc4t_CSi1oBSbY7h9O63i8IU99YsOCdvZ7AD3ePxyM1xJR7CFHycg9Z_IDouYnJmXpTpbFMMl7SjME3cVMfMrAQ"
)

jwt_handler = JWTHandler()

with pytest.raises(InvalidJWTError) as exc:
jwt_handler.verify_token(expired_access_token)
jwt_handler.verify_token(EXPIRED_ACCESS_TOKEN)
assert str(exc.value) == "Invalid JWT token"


def test_verify_token_with_expired_refresh_token():
"""
Test verifying an expired refresh token.
"""
expired_refresh_tokenb = (
"eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOi02MjEzNTU5NjgwMH0.Er0A8dvdZi7o1FK3b-Te2IkUjDJZjI0aANsP7bbAbeITP"
"RnR0YEhavmuLT1zaoALQjUzfSgtH0s3I-YbUr2ssqG1DnKh83uts3J2_EXIXQZBeuZisCW1nN1LC2nsR6o4HQEsbMsINjJviHeMWS8nRC06XXp"
"N1WFPaGBxXkLFeDWb3SXiirZ79m7lUBwQvVzpfeA337e_AejG45mtadgfW3xpDCw-6sVVIA-cuzruxnjRKAzJrw_goA9X4MukRXbnzou2mgkxF"
"Ks_-6hdTFDI-B47wYqalP6KC5nqzjrCpvjmukgM-DN0uAhm2TUzUmE5EXtRLEYMRqsSmog4hYq1Nw"
)
jwt_handler = JWTHandler()

with pytest.raises(InvalidJWTError) as exc:
jwt_handler.verify_token(EXPIRED_REFRESH_TOKEN)
assert str(exc.value) == "Invalid JWT token"


def test_verify_token_with_invalid_token():
"""
Test verifying an invalid access token.
"""
jwt_handler = JWTHandler()

with pytest.raises(InvalidJWTError) as exc:
jwt_handler.verify_token(expired_refresh_tokenb)
jwt_handler.verify_token("invalid")
assert str(exc.value) == "Invalid JWT token"