diff --git a/server/src/api/v1.py b/server/src/api/v1.py index bd7c0a27..79758755 100644 --- a/server/src/api/v1.py +++ b/server/src/api/v1.py @@ -179,20 +179,24 @@ def check_token_reservation_timeout( def check_token_permissions( auth_token: str, secret_key: str, - priority: int, - queue: str, - reservation_timeout: int, + job_data: dict, ) -> bool: """ Validates token received from client and checks if it can push a job to the queue with the requested priority """ + priority_level = job_data.get("job_priority", 0) + job_queue = job_data["job_queue"] priority_allowed = check_token_priority( - auth_token, secret_key, queue, priority + auth_token, secret_key, job_queue, priority_level ) - queue_allowed = check_token_queue(auth_token, secret_key, queue) + queue_allowed = check_token_queue(auth_token, secret_key, job_queue) + + reserve_data = job_data.get("reserve_data", {}) + # default reservation timeout is 1 hour + reservation_timeout = reserve_data.get("timeout", 3600) reservation_time_allowed = check_token_reservation_timeout( - auth_token, secret_key, reservation_timeout, queue + auth_token, secret_key, reservation_timeout, job_queue ) return priority_allowed and queue_allowed and reservation_time_allowed @@ -222,15 +226,10 @@ def job_builder(data: dict, auth_token: str): priority_level = data.get("job_priority", 0) job_queue = data["job_queue"] - reserve_data = data.get("reserve_data", {}) - # default reservation timeout is 1 hour - reservation_timeout = reserve_data.get("timeout", 3600) allowed = check_token_permissions( auth_token, os.environ.get("JWT_SIGNING_KEY"), - priority_level, - job_queue, - reservation_timeout, + data, ) if not allowed: abort( @@ -792,14 +791,7 @@ def generate_token(allowed_resources, secret_key): "iat": datetime.now(timezone.utc), # Issued at time "sub": "access_token", } - if "max_priority" in allowed_resources: - token_payload["max_priority"] = allowed_resources["max_priority"] - if "allowed_queues" in allowed_resources: - token_payload["allowed_queues"] = allowed_resources["allowed_queues"] - if "max_reservation_time" in allowed_resources: - token_payload["max_reservation_time"] = allowed_resources[ - "max_reservation_time" - ] + token_payload.update(allowed_resources) token = jwt.encode(token_payload, secret_key, algorithm="HS256") return token @@ -822,12 +814,25 @@ def validate_client_key_pair(client_id: str, client_key: str): client_permissions_entry["client_secret_hash"].encode("utf8"), ): return None + client_permissions_entry.pop("_id", None) return client_permissions_entry @v1.post("/oauth2/token") def retrieve_token(): - """Get JWT with priority and queue permissions""" + """ + Get JWT with priority and queue permissions + + Before being encrypted, the JWT can contain fields like: + { + exp: , + iat: , + sub: , + max_priority: , + allowed_queues: , + max_reservation_time: , + } + """ auth_header = request.authorization if auth_header is None: return "No authorization header specified", 401