Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Python Auth Updates #714

Merged
merged 5 commits into from
Nov 14, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
208 changes: 48 additions & 160 deletions backend/src/xfd_django/xfd_api/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,22 +12,19 @@
# Third-Party Libraries
from django.conf import settings
from django.forms.models import model_to_dict
from fastapi import Depends, HTTPException, Security, status
from fastapi.security import APIKeyHeader, OAuth2PasswordBearer
from fastapi import Depends, HTTPException, Request, Security, status
from fastapi.security import APIKeyHeader
import jwt
from jwt import ExpiredSignatureError, InvalidTokenError
import requests

# from .helpers import user_to_dict
from .models import ApiKey, Organization, OrganizationTag, Role, User

# JWT_ALGORITHM = "RS256"
JWT_SECRET = os.getenv("JWT_SECRET")
SECRET_KEY = settings.SECRET_KEY
JWT_ALGORITHM = "HS256"
JWT_TIMEOUT_HOURS = 4

oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
api_key_header = APIKeyHeader(name="X-API-KEY", auto_error=False)


Expand Down Expand Up @@ -79,6 +76,7 @@ def decode_jwt_token(token):
Returns:
User: The user object decoded from the token, or None if invalid or expired.
"""

try:
payload = jwt.decode(token, JWT_SECRET, algorithm=JWT_ALGORITHM)
user = User.objects.get(id=payload["id"])
Expand All @@ -94,83 +92,8 @@ def hash_key(key: str) -> str:
Returns:
str: hashed API key value
"""
return hashlib.sha256(key.encode()).hexdigest()


# TODO: Confirm still needed
# async def get_user_info_from_cognito(token):
# """
# Get user info from cognito

# Args:
# token (_type_): _description_

# Returns:
# _type_: _description_
# """
# jwks_url = f"https://cognito-idp.us-east-1.amazonaws.com/{os.getenv('REACT_APP_USER_POOL_ID')}/.well-known/jwks.json"
# response = requests.get(jwks_url)
# jwks = response.json()
# unverified_header = jwt.get_unverified_header(token)

# for key in jwks["keys"]:
# if key["kid"] == unverified_header["kid"]:
# rsa_key = {
# "kty": key["kty"],
# "kid": key["kid"],
# "use": key["use"],
# "n": key["n"],
# "e": key["e"],
# }

# user_info = decode_jwt_token(token)
# return user_info


# def create_jwt_token(user):
# """
# Create a JWT token for a given user.

# Args:
# user (User): The user object for whom the token is created.

# Returns:
# str: The encoded JWT token.
# """
# payload = {
# "id": str(user.id),
# "email": user.email,
# "exp": datetime.now(datetime.timezone.utc) + timedelta(hours=JWT_TIMEOUT_HOURS),
# }
# return jwt.encode(payload, JWT_SECRET, algorithm=JWT_ALGORITHM)


# def decode_jwt_token(token):
# """
# Decode a JWT token to retrieve the user.

# Args:
# token (str): The JWT token to decode.

# Returns:
# User: The user object decoded from the token, or None if invalid or expired.
# """
# try:
# payload = jwt.decode(token, JWT_SECRET, algorithm=JWT_ALGORITHM)
# user = User.objects.get(id=payload["id"])
# return user
# except (ExpiredSignatureError, InvalidTokenError, User.DoesNotExist):
# return None


# def hash_key(key: str) -> str:
# """
# Helper to hash API key.

# Returns:
# str: hashed API key value
# """
# return hashlib.sha256(key.encode()).hexdigest()
return hashlib.sha256(key.encode()).hexdigest()


# TODO: Confirm still needed
Expand All @@ -180,7 +103,6 @@ async def get_user_info_from_cognito(token):
response = requests.get(jwks_url)
jwks = response.json()
unverified_header = jwt.get_unverified_header(token)

for key in jwks["keys"]:
if key["kid"] == unverified_header["kid"]:
rsa_key = {
Expand All @@ -190,11 +112,28 @@ async def get_user_info_from_cognito(token):
"n": key["n"],
"e": key["e"],
}

user_info = decode_jwt_token(token)
return user_info


async def get_token_from_header(request: Request) -> Optional[str]:
"""
Extract token from the Authorization header, allowing 'Bearer' or raw tokens.
Args:
request (Request): The incoming request object.
Returns:
Optional[str]: The token extracted from the Authorization header, or None if missing.
"""
auth_header = request.headers.get("Authorization")
if auth_header:
if auth_header.startswith("Bearer "):
return auth_header[7:] # Remove 'Bearer ' prefix
return auth_header # Return the token directly if no 'Bearer ' prefix
return None


def get_user_by_api_key(api_key: str):
"""Get a user by their API key."""
hashed_key = sha256(api_key.encode()).hexdigest()
Expand All @@ -209,14 +148,28 @@ def get_user_by_api_key(api_key: str):


def get_current_active_user(
request: Request,
api_key: Optional[str] = Security(api_key_header),
token: Optional[str] = Depends(oauth2_scheme),
token: Optional[str] = Depends(get_token_from_header),
):
"""Ensure the current user is authenticated and active."""
"""
Ensure the current user is authenticated and active, supporting either API key or token.
Args:
request (Request): The incoming request object.
api_key (Optional[str]): The API key provided in headers.
token (Optional[str]): The JWT token from the Authorization header.
Returns:
User: The authenticated user object.
Raises:
HTTPException: If authentication fails or credentials are invalid.
"""
user = None
if api_key:
user = get_user_by_api_key(api_key)
else:
elif token:
try:
# Decode token in Authorization header to get user
payload = jwt.decode(token, JWT_SECRET, algorithms=[JWT_ALGORITHM])
Expand Down Expand Up @@ -245,6 +198,12 @@ def get_current_active_user(
detail="Invalid token",
headers={"WWW-Authenticate": "Bearer"},
)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="No valid authentication credentials provided",
)

if user is None:
print("User not authenticated")
raise HTTPException(
Expand All @@ -255,9 +214,8 @@ def get_current_active_user(


async def process_user(decoded_token, access_token, refresh_token):
# Find the user by email
"""Process a user based on decoded token information."""
user = User.objects.filter(email=decoded_token["email"]).first()

if not user:
# Create a new user if they don't exist from Okta fields in SAML Response
user = User(
Expand All @@ -275,19 +233,9 @@ async def process_user(decoded_token, access_token, refresh_token):
user.lastLoggedIn = datetime.now()
user.save()

# # Create response object
# response = JSONResponse({"message": "User processed"})

# # Set cookies for access token and refresh token
# response.set_cookie(key="access_token", value=access_token, httponly=True, secure=True)
# response.set_cookie(key="refresh_token", value=refresh_token, httponly=True, secure=True)
# print(f"Response output: {str(response.headers)}")

# If user exists, generate a signed JWT token
if user:
if not JWT_SECRET:
raise HTTPException(status_code=500, detail="JWT_SECRET is not defined")

# Generate JWT token
signed_token = jwt.encode(
{
Expand All @@ -299,32 +247,14 @@ async def process_user(decoded_token, access_token, refresh_token):
algorithm=JWT_ALGORITHM,
)

# Set JWT token as a cookie
# response.set_cookie(key="id_token", value=signed_token, httponly=True, secure=True)

# Return the response with token and user info
# return JSONResponse(

process_resp = {
"token": signed_token,
"user": user_to_dict(user)
# "user": {
# "id": str(user.id),
# "email": user.email,
# "firstName": user.firstName,
# "lastName": user.lastName,
# "state": user.state,
# "regionId": user.regionId,
# }
}
print(f"process_resp: {process_resp}")
process_resp = {"token": signed_token, "user": user_to_dict(user)}
return process_resp

else:
raise HTTPException(status_code=400, detail="User not found")


async def get_jwt_from_code(auth_code: str):
"""Exchange authorization code for JWT tokens and decode."""
try:
callback_url = os.getenv("REACT_APP_COGNITO_CALLBACK_URL")
client_id = os.getenv("REACT_APP_COGNITO_CLIENT_ID")
Expand All @@ -342,13 +272,10 @@ async def get_jwt_from_code(auth_code: str):
"Content-Type": "application/x-www-form-urlencoded",
}

# Make Oauth2/token request with code
response = requests.post(
authorize_token_url, headers=headers, data=urlencode(authorize_token_body)
)
token_response = response.json()
print(f"oauth2/token response: {token_response}")

# Convert the id_token to bytes
id_token = token_response["id_token"].encode("utf-8")
access_token = token_response.get("access_token")
Expand All @@ -357,7 +284,6 @@ async def get_jwt_from_code(auth_code: str):
# Decode the token without verifying the signature (if needed)
decoded_token = jwt.decode(id_token, options={"verify_signature": False})
print(f"decoded token: {decoded_token}")

return {
"refresh_token": refresh_token,
"id_token": id_token,
Expand All @@ -370,44 +296,6 @@ async def get_jwt_from_code(auth_code: str):
pass


# TODO: determine if we still need.
# async def handle_cognito_callback(body):
# try:
# print(f"handle_cognito_callback body input: {str(body)}")
# user_info = await get_user_info_from_cognito(body["token"])
# print(f"handle_cognito_callback user_info: {str(user_info)}")
# user = await update_or_create_user(user_info)
# token = create_jwt_token(user)
# print(f"handle_cognito_callback token: {str(token)}")
# return token, user
# except Exception as error:
# print(f"Error : {str(error)}")
# raise HTTPException(
# status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(error)
# ) from error


# # TODO: Uncomment the token and if not user token once the JWT from OKTA is working
# def get_current_active_user(
# api_key: str = Security(api_key_header),
# # token: str = Depends(oauth2_scheme),
# ):
# """Ensure the current user is authenticated and active."""
# user = None
# if api_key:
# user = get_user_by_api_key(api_key)
# # if not user and token:
# # user = decode_jwt_token(token)
# if user is None:
# print("User not authenticated")
# raise HTTPException(
# status_code=status.HTTP_401_UNAUTHORIZED,
# detail="Invalid authentication credentials",
# )
# print(f"Authenticated user: {user.id}")
# return user


def is_global_write_admin(current_user) -> bool:
"""Check if the user has global write admin permissions."""
return current_user and current_user.userType == "globalAdmin"
Expand Down
Loading