diff --git a/realtime/_async/client.py b/realtime/_async/client.py index fda8487..c099e9d 100644 --- a/realtime/_async/client.py +++ b/realtime/_async/client.py @@ -2,7 +2,10 @@ import json import logging import re +from base64 import b64decode +from datetime import datetime from functools import wraps +from math import floor from typing import Any, Callable, Dict, List, Optional import websockets @@ -256,6 +259,30 @@ async def set_auth(self, token: Optional[str]) -> None: Returns: None """ + # No empty string tokens. + if isinstance(token, str) and len(token.strip()) == 0: + raise ValueError("Provide a valid jwt token") + + if token: + parsed = None + try: + payload = token.split(".")[1] + "==" + parsed = json.loads(b64decode(payload).decode("utf-8")) + except Exception: + raise ValueError("InvalidJWTToken") + + if parsed: + # Handle expired token if any. + if "exp" in parsed: + now = floor(datetime.now().timestamp()) + valid = now - parsed["exp"] < 0 + if not valid: + raise ValueError( + f"InvalidJWTToken: Invalid value for JWT claim 'exp' with value { parsed['exp'] }" + ) + else: + raise ValueError("InvalidJWTToken: expected claim 'exp'") + self.access_token = token for _, channel in self.channels.items(): diff --git a/tests/test_connection.py b/tests/test_connection.py index fd510c9..1ee658a 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -49,8 +49,8 @@ async def access_token() -> str: async def test_set_auth(socket: AsyncRealtimeClient): await socket.connect() - await socket.set_auth("jwt") - assert socket.access_token == "jwt" + with pytest.raises(ValueError): + await socket.set_auth("jwt") # Invalid JWT. await socket.close()