Skip to content

Commit

Permalink
Merge pull request #714 from cisagov/NV-python-auth-updates
Browse files Browse the repository at this point in the history
Python Auth Updates
  • Loading branch information
nickviola authored Nov 14, 2024
2 parents a572a2d + 8a42c31 commit a9a0d02
Showing 1 changed file with 48 additions and 160 deletions.
208 changes: 48 additions & 160 deletions backend/src/xfd_django/xfd_api/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,22 +12,19 @@
# Third-Party Libraries
from django.conf import settings
from django.forms.models import model_to_dict
from fastapi import Depends, HTTPException, Security, status
from fastapi.security import APIKeyHeader, OAuth2PasswordBearer
from fastapi import Depends, HTTPException, Request, Security, status
from fastapi.security import APIKeyHeader
import jwt
from jwt import ExpiredSignatureError, InvalidTokenError
import requests

# from .helpers import user_to_dict
from .models import ApiKey, Organization, OrganizationTag, Role, User

# JWT_ALGORITHM = "RS256"
JWT_SECRET = os.getenv("JWT_SECRET")
SECRET_KEY = settings.SECRET_KEY
JWT_ALGORITHM = "HS256"
JWT_TIMEOUT_HOURS = 4

oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
api_key_header = APIKeyHeader(name="X-API-KEY", auto_error=False)


Expand Down Expand Up @@ -79,6 +76,7 @@ def decode_jwt_token(token):
Returns:
User: The user object decoded from the token, or None if invalid or expired.
"""

try:
payload = jwt.decode(token, JWT_SECRET, algorithm=JWT_ALGORITHM)
user = User.objects.get(id=payload["id"])
Expand All @@ -94,83 +92,8 @@ def hash_key(key: str) -> str:
Returns:
str: hashed API key value
"""
return hashlib.sha256(key.encode()).hexdigest()


# TODO: Confirm still needed
# async def get_user_info_from_cognito(token):
# """
# Get user info from cognito

# Args:
# token (_type_): _description_

# Returns:
# _type_: _description_
# """
# jwks_url = f"https://cognito-idp.us-east-1.amazonaws.com/{os.getenv('REACT_APP_USER_POOL_ID')}/.well-known/jwks.json"
# response = requests.get(jwks_url)
# jwks = response.json()
# unverified_header = jwt.get_unverified_header(token)

# for key in jwks["keys"]:
# if key["kid"] == unverified_header["kid"]:
# rsa_key = {
# "kty": key["kty"],
# "kid": key["kid"],
# "use": key["use"],
# "n": key["n"],
# "e": key["e"],
# }

# user_info = decode_jwt_token(token)
# return user_info


# def create_jwt_token(user):
# """
# Create a JWT token for a given user.

# Args:
# user (User): The user object for whom the token is created.

# Returns:
# str: The encoded JWT token.
# """
# payload = {
# "id": str(user.id),
# "email": user.email,
# "exp": datetime.now(datetime.timezone.utc) + timedelta(hours=JWT_TIMEOUT_HOURS),
# }
# return jwt.encode(payload, JWT_SECRET, algorithm=JWT_ALGORITHM)


# def decode_jwt_token(token):
# """
# Decode a JWT token to retrieve the user.

# Args:
# token (str): The JWT token to decode.

# Returns:
# User: The user object decoded from the token, or None if invalid or expired.
# """
# try:
# payload = jwt.decode(token, JWT_SECRET, algorithm=JWT_ALGORITHM)
# user = User.objects.get(id=payload["id"])
# return user
# except (ExpiredSignatureError, InvalidTokenError, User.DoesNotExist):
# return None


# def hash_key(key: str) -> str:
# """
# Helper to hash API key.

# Returns:
# str: hashed API key value
# """
# return hashlib.sha256(key.encode()).hexdigest()
return hashlib.sha256(key.encode()).hexdigest()


# TODO: Confirm still needed
Expand All @@ -180,7 +103,6 @@ async def get_user_info_from_cognito(token):
response = requests.get(jwks_url)
jwks = response.json()
unverified_header = jwt.get_unverified_header(token)

for key in jwks["keys"]:
if key["kid"] == unverified_header["kid"]:
rsa_key = {
Expand All @@ -190,11 +112,28 @@ async def get_user_info_from_cognito(token):
"n": key["n"],
"e": key["e"],
}

user_info = decode_jwt_token(token)
return user_info


async def get_token_from_header(request: Request) -> Optional[str]:
"""
Extract token from the Authorization header, allowing 'Bearer' or raw tokens.
Args:
request (Request): The incoming request object.
Returns:
Optional[str]: The token extracted from the Authorization header, or None if missing.
"""
auth_header = request.headers.get("Authorization")
if auth_header:
if auth_header.startswith("Bearer "):
return auth_header[7:] # Remove 'Bearer ' prefix
return auth_header # Return the token directly if no 'Bearer ' prefix
return None


def get_user_by_api_key(api_key: str):
"""Get a user by their API key."""
hashed_key = sha256(api_key.encode()).hexdigest()
Expand All @@ -209,14 +148,28 @@ def get_user_by_api_key(api_key: str):


def get_current_active_user(
request: Request,
api_key: Optional[str] = Security(api_key_header),
token: Optional[str] = Depends(oauth2_scheme),
token: Optional[str] = Depends(get_token_from_header),
):
"""Ensure the current user is authenticated and active."""
"""
Ensure the current user is authenticated and active, supporting either API key or token.
Args:
request (Request): The incoming request object.
api_key (Optional[str]): The API key provided in headers.
token (Optional[str]): The JWT token from the Authorization header.
Returns:
User: The authenticated user object.
Raises:
HTTPException: If authentication fails or credentials are invalid.
"""
user = None
if api_key:
user = get_user_by_api_key(api_key)
else:
elif token:
try:
# Decode token in Authorization header to get user
payload = jwt.decode(token, JWT_SECRET, algorithms=[JWT_ALGORITHM])
Expand Down Expand Up @@ -245,6 +198,12 @@ def get_current_active_user(
detail="Invalid token",
headers={"WWW-Authenticate": "Bearer"},
)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="No valid authentication credentials provided",
)

if user is None:
print("User not authenticated")
raise HTTPException(
Expand All @@ -255,9 +214,8 @@ def get_current_active_user(


async def process_user(decoded_token, access_token, refresh_token):
# Find the user by email
"""Process a user based on decoded token information."""
user = User.objects.filter(email=decoded_token["email"]).first()

if not user:
# Create a new user if they don't exist from Okta fields in SAML Response
user = User(
Expand All @@ -275,19 +233,9 @@ async def process_user(decoded_token, access_token, refresh_token):
user.lastLoggedIn = datetime.now()
user.save()

# # Create response object
# response = JSONResponse({"message": "User processed"})

# # Set cookies for access token and refresh token
# response.set_cookie(key="access_token", value=access_token, httponly=True, secure=True)
# response.set_cookie(key="refresh_token", value=refresh_token, httponly=True, secure=True)
# print(f"Response output: {str(response.headers)}")

# If user exists, generate a signed JWT token
if user:
if not JWT_SECRET:
raise HTTPException(status_code=500, detail="JWT_SECRET is not defined")

# Generate JWT token
signed_token = jwt.encode(
{
Expand All @@ -299,32 +247,14 @@ async def process_user(decoded_token, access_token, refresh_token):
algorithm=JWT_ALGORITHM,
)

# Set JWT token as a cookie
# response.set_cookie(key="id_token", value=signed_token, httponly=True, secure=True)

# Return the response with token and user info
# return JSONResponse(

process_resp = {
"token": signed_token,
"user": user_to_dict(user)
# "user": {
# "id": str(user.id),
# "email": user.email,
# "firstName": user.firstName,
# "lastName": user.lastName,
# "state": user.state,
# "regionId": user.regionId,
# }
}
print(f"process_resp: {process_resp}")
process_resp = {"token": signed_token, "user": user_to_dict(user)}
return process_resp

else:
raise HTTPException(status_code=400, detail="User not found")


async def get_jwt_from_code(auth_code: str):
"""Exchange authorization code for JWT tokens and decode."""
try:
callback_url = os.getenv("REACT_APP_COGNITO_CALLBACK_URL")
client_id = os.getenv("REACT_APP_COGNITO_CLIENT_ID")
Expand All @@ -342,13 +272,10 @@ async def get_jwt_from_code(auth_code: str):
"Content-Type": "application/x-www-form-urlencoded",
}

# Make Oauth2/token request with code
response = requests.post(
authorize_token_url, headers=headers, data=urlencode(authorize_token_body)
)
token_response = response.json()
print(f"oauth2/token response: {token_response}")

# Convert the id_token to bytes
id_token = token_response["id_token"].encode("utf-8")
access_token = token_response.get("access_token")
Expand All @@ -357,7 +284,6 @@ async def get_jwt_from_code(auth_code: str):
# Decode the token without verifying the signature (if needed)
decoded_token = jwt.decode(id_token, options={"verify_signature": False})
print(f"decoded token: {decoded_token}")

return {
"refresh_token": refresh_token,
"id_token": id_token,
Expand All @@ -370,44 +296,6 @@ async def get_jwt_from_code(auth_code: str):
pass


# TODO: determine if we still need.
# async def handle_cognito_callback(body):
# try:
# print(f"handle_cognito_callback body input: {str(body)}")
# user_info = await get_user_info_from_cognito(body["token"])
# print(f"handle_cognito_callback user_info: {str(user_info)}")
# user = await update_or_create_user(user_info)
# token = create_jwt_token(user)
# print(f"handle_cognito_callback token: {str(token)}")
# return token, user
# except Exception as error:
# print(f"Error : {str(error)}")
# raise HTTPException(
# status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(error)
# ) from error


# # TODO: Uncomment the token and if not user token once the JWT from OKTA is working
# def get_current_active_user(
# api_key: str = Security(api_key_header),
# # token: str = Depends(oauth2_scheme),
# ):
# """Ensure the current user is authenticated and active."""
# user = None
# if api_key:
# user = get_user_by_api_key(api_key)
# # if not user and token:
# # user = decode_jwt_token(token)
# if user is None:
# print("User not authenticated")
# raise HTTPException(
# status_code=status.HTTP_401_UNAUTHORIZED,
# detail="Invalid authentication credentials",
# )
# print(f"Authenticated user: {user.id}")
# return user


def is_global_write_admin(current_user) -> bool:
"""Check if the user has global write admin permissions."""
return current_user and current_user.userType == "globalAdmin"
Expand Down

0 comments on commit a9a0d02

Please sign in to comment.