diff --git a/integration_tests/base_routes.py b/integration_tests/base_routes.py index 84f2f5f34..c3595b372 100644 --- a/integration_tests/base_routes.py +++ b/integration_tests/base_routes.py @@ -2,6 +2,7 @@ import pathlib from collections import defaultdict from typing import Optional +from base64 import b64encode from robyn import ( Request, @@ -13,7 +14,7 @@ serve_html, WebSocketConnector, ) -from robyn.authentication import AuthenticationHandler, BearerGetter, Identity +from robyn.authentication import AuthenticationHandler, BearerGetter, Identity, BasicGetter from robyn.robyn import Headers from robyn.templating import JinjaTemplate @@ -794,6 +795,34 @@ async def async_auth(request: Request): return "authenticated" +@app.get("/sync/auth/basic", auth_required=True, auth_middleware_name="basic") +def sync_auth_basic(request: Request): + assert request.identity is not None + assert request.identity.claims == {"key": "value"} + return "authenticated" + + +@app.get("/async/auth/basic", auth_required=True, auth_middleware_name="basic") +async def async_auth_basic(request: Request): + assert request.identity is not None + assert request.identity.claims == {"key": "value"} + return "authenticated" + + +@app.get("/sync/auth/bearer-2", auth_required=True, auth_middleware_name="bearer-2") +def sync_auth_bearer_2(request: Request): + assert request.identity is not None + assert request.identity.claims == {"key": "value"} + return "authenticated" + + +@app.get("/async/auth/bearer-2", auth_required=True, auth_middleware_name="bearer-2") +async def async_auth_bearer_2(request: Request): + assert request.identity is not None + assert request.identity.claims == {"key": "value"} + return "authenticated" + + # ===== Main ===== @@ -845,7 +874,7 @@ def main(): app.include_router(sub_router) app.include_router(di_subrouter) - class BasicAuthHandler(AuthenticationHandler): + class BearerAuthHandler(AuthenticationHandler): def authenticate(self, request: Request) -> Optional[Identity]: token = self.token_getter.get_token(request) if token is not None: @@ -855,7 +884,29 @@ def authenticate(self, request: Request) -> Optional[Identity]: return Identity(claims={"key": "value"}) return None - app.configure_authentication(BasicAuthHandler(token_getter=BearerGetter())) + class OtherBearerAuthHandler(AuthenticationHandler): + def authenticate(self, request: Request) -> Optional[Identity]: + token = self.token_getter.get_token(request) + if token is not None: + # Useless but we call the set_token method for testing purposes + self.token_getter.set_token(request, token) + if token == "valid-2": + return Identity(claims={"key": "value"}) + return None + + class BasicAuthHandler(AuthenticationHandler): + def authenticate(self, request: Request) -> Optional[Identity]: + username, password = self.token_getter.get_credentials(request) + if username is not None and password is not None: + # Useless but we call the set_token method for testing purposes + self.token_getter.set_token(request, b64encode(f"{username}:{password}".encode()).decode()) + if username == "valid" and password == "valid": + return Identity(claims={"key": "value"}) + return None + + app.configure_authentication(BasicAuthHandler(token_getter=BasicGetter(), name="basic")) + app.configure_authentication(BearerAuthHandler(token_getter=BearerGetter(), name="bearer", default=True)) + app.configure_authentication(OtherBearerAuthHandler(token_getter=BearerGetter(), name="bearer-2")) app.start(port=8080, _check_port=False) diff --git a/integration_tests/test_authentication.py b/integration_tests/test_authentication.py index 34a33f608..5391d6de9 100644 --- a/integration_tests/test_authentication.py +++ b/integration_tests/test_authentication.py @@ -1,4 +1,5 @@ import pytest +from base64 import b64encode from integration_tests.helpers.http_methods_helpers import get @@ -40,3 +41,81 @@ def test_invalid_authentication_no_token(session, function_type: str): r = get(f"/{function_type}/auth", should_check_response=False) assert r.status_code == 401 assert r.headers.get("WWW-Authenticate") == "BearerGetter" + + +@pytest.mark.benchmark +@pytest.mark.parametrize("function_type", ["sync", "async"]) +def test_valid_authentication_bearer_2(session, function_type: str): + r = get(f"/{function_type}/auth/bearer-2", headers={"Authorization": "Bearer valid-2"}) + assert r.text == "authenticated" + + +@pytest.mark.benchmark +@pytest.mark.parametrize("function_type", ["sync", "async"]) +def test_invalid_authentication_token_bearer_2(session, function_type: str): + r = get( + f"/{function_type}/auth/bearer-2", + headers={"Authorization": "Bearer invalid"}, + should_check_response=False, + ) + assert r.status_code == 401 + assert r.headers.get("WWW-Authenticate") == "BearerGetter" + + +@pytest.mark.benchmark +@pytest.mark.parametrize("function_type", ["sync", "async"]) +def test_invalid_authentication_header_bearer_2(session, function_type: str): + r = get( + f"/{function_type}/auth/bearer-2", + headers={"Authorization": "Bear valid-2"}, + should_check_response=False, + ) + assert r.status_code == 401 + assert r.headers.get("WWW-Authenticate") == "BearerGetter" + + +@pytest.mark.benchmark +@pytest.mark.parametrize("function_type", ["sync", "async"]) +def test_invalid_authentication_no_token_bearer_2(session, function_type: str): + r = get(f"/{function_type}/auth/bearer-2", should_check_response=False) + assert r.status_code == 401 + assert r.headers.get("WWW-Authenticate") == "BearerGetter" + + +@pytest.mark.benchmark +@pytest.mark.parametrize("function_type", ["sync", "async"]) +def test_valid_authentication_basic(session, function_type: str): + r = get(f"/{function_type}/auth/basic", headers={"Authorization": f"Basic {b64encode('valid:valid'.encode()).decode()}"}) + assert r.text == "authenticated" + + +@pytest.mark.benchmark +@pytest.mark.parametrize("function_type", ["sync", "async"]) +def test_invalid_authentication_token_basic(session, function_type: str): + r = get( + f"/{function_type}/auth/basic", + headers={"Authorization": "Basic invalid"}, + should_check_response=False, + ) + assert r.status_code == 401 + assert r.headers.get("WWW-Authenticate") == "BasicGetter" + + +@pytest.mark.benchmark +@pytest.mark.parametrize("function_type", ["sync", "async"]) +def test_invalid_authentication_header_basic(session, function_type: str): + r = get( + f"/{function_type}/auth/basic", + headers={"Authorization": "Bear valid-2"}, + should_check_response=False, + ) + assert r.status_code == 401 + assert r.headers.get("WWW-Authenticate") == "BasicGetter" + + +@pytest.mark.benchmark +@pytest.mark.parametrize("function_type", ["sync", "async"]) +def test_invalid_authentication_no_token_basic(session, function_type: str): + r = get(f"/{function_type}/auth/basic", should_check_response=False) + assert r.status_code == 401 + assert r.headers.get("WWW-Authenticate") == "BasicGetter" diff --git a/robyn/__init__.py b/robyn/__init__.py index e051e2116..cfd65134f 100644 --- a/robyn/__init__.py +++ b/robyn/__init__.py @@ -70,7 +70,7 @@ def __init__( self.directories: List[Directory] = [] self.event_handlers = {} self.exception_handler: Optional[Callable] = None - self.authentication_handler: Optional[AuthenticationHandler] = None + self.authentication_handler: List[AuthenticationHandler] = [] def _handle_dev_mode(self): cli_dev_mode = self.config.dev # --dev @@ -84,6 +84,28 @@ def _handle_dev_mode(self): logger.error("Ignoring ROBYN_DEV_MODE environment variable. Dev mode is not supported in the python wrapper.") raise SystemExit("Dev mode is not supported in the python wrapper. Please use the Robyn CLI. e.g. python3 -m robyn app.py") + def auth_handler_configured(self): + handler_count = len(self.authentication_handler) + if handler_count == 0: + return + if handler_count == 1: + self.authentication_handler[0].default = True + return + + default_handlers = [handler for handler in self.authentication_handler if handler.default] + + if len(default_handlers) == 0: + raise ValueError( + "Multiple authentication handlers are configured, but none is set as the default. " + "Please set one of the authentication handlers as the default." + ) + + if len(default_handlers) > 1: + raise ValueError( + "Multiple authentication handlers are configured with more than one default. " + "Please ensure only one authentication handler is set as the default." + ) + def add_route( self, route_type: Union[HttpMethod, str], @@ -91,6 +113,7 @@ def add_route( handler: Callable, is_const: bool = False, auth_required: bool = False, + auth_middleware_name: Optional[str] = None, ): """ Connect a URI to a handler @@ -100,6 +123,7 @@ def add_route( :param handler function: represents the sync or async function passed as a handler for the route :param is_const bool: represents if the handler is a const function or not :param auth_required bool: represents if the route needs authentication or not + :param auth_middleware_name str: represents auth handler name for the route """ """ We will add the status code here only @@ -107,7 +131,7 @@ def add_route( injected_dependencies = self.dependencies.get_dependency_map(self) if auth_required: - self.middleware_router.add_auth_middleware(endpoint)(handler) + self.middleware_router.add_auth_middleware(endpoint, auth_middleware_name)(handler) if isinstance(route_type, str): http_methods = { @@ -226,6 +250,8 @@ def start(self, host: str = "127.0.0.1", port: int = 8080, _check_port: bool = T port = int(os.getenv("ROBYN_PORT", port)) open_browser = bool(os.getenv("ROBYN_BROWSER_OPEN", self.config.open_browser)) + self.auth_handler_configured() + if _check_port: while self.is_port_in_use(port): logger.error("Port %s is already in use. Please use a different port.", port) @@ -302,7 +328,7 @@ def inner(handler): return inner - def get(self, endpoint: str, const: bool = False, auth_required: bool = False): + def get(self, endpoint: str, const: bool = False, auth_required: bool = False, auth_middleware_name: Optional[str] = None): """ The @app.get decorator to add a route with the GET method @@ -310,11 +336,11 @@ def get(self, endpoint: str, const: bool = False, auth_required: bool = False): """ def inner(handler): - return self.add_route(HttpMethod.GET, endpoint, handler, const, auth_required) + return self.add_route(HttpMethod.GET, endpoint, handler, const, auth_required, auth_middleware_name=auth_middleware_name) return inner - def post(self, endpoint: str, auth_required: bool = False): + def post(self, endpoint: str, auth_required: bool = False, auth_middleware_name: Optional[str] = None): """ The @app.post decorator to add a route with POST method @@ -322,11 +348,11 @@ def post(self, endpoint: str, auth_required: bool = False): """ def inner(handler): - return self.add_route(HttpMethod.POST, endpoint, handler, auth_required=auth_required) + return self.add_route(HttpMethod.POST, endpoint, handler, auth_required=auth_required, auth_middleware_name=auth_middleware_name) return inner - def put(self, endpoint: str, auth_required: bool = False): + def put(self, endpoint: str, auth_required: bool = False, auth_middleware_name: Optional[str] = None): """ The @app.put decorator to add a get route with PUT method @@ -334,11 +360,11 @@ def put(self, endpoint: str, auth_required: bool = False): """ def inner(handler): - return self.add_route(HttpMethod.PUT, endpoint, handler, auth_required=auth_required) + return self.add_route(HttpMethod.PUT, endpoint, handler, auth_required=auth_required, auth_middleware_name=auth_middleware_name) return inner - def delete(self, endpoint: str, auth_required: bool = False): + def delete(self, endpoint: str, auth_required: bool = False, auth_middleware_name: Optional[str] = None): """ The @app.delete decorator to add a route with DELETE method @@ -346,11 +372,11 @@ def delete(self, endpoint: str, auth_required: bool = False): """ def inner(handler): - return self.add_route(HttpMethod.DELETE, endpoint, handler, auth_required=auth_required) + return self.add_route(HttpMethod.DELETE, endpoint, handler, auth_required=auth_required, auth_middleware_name=auth_middleware_name) return inner - def patch(self, endpoint: str, auth_required: bool = False): + def patch(self, endpoint: str, auth_required: bool = False, auth_middleware_name: Optional[str] = None): """ The @app.patch decorator to add a route with PATCH method @@ -358,11 +384,11 @@ def patch(self, endpoint: str, auth_required: bool = False): """ def inner(handler): - return self.add_route(HttpMethod.PATCH, endpoint, handler, auth_required=auth_required) + return self.add_route(HttpMethod.PATCH, endpoint, handler, auth_required=auth_required, auth_middleware_name=auth_middleware_name) return inner - def head(self, endpoint: str, auth_required: bool = False): + def head(self, endpoint: str, auth_required: bool = False, auth_middleware_name: Optional[str] = None): """ The @app.head decorator to add a route with HEAD method @@ -370,11 +396,11 @@ def head(self, endpoint: str, auth_required: bool = False): """ def inner(handler): - return self.add_route(HttpMethod.HEAD, endpoint, handler, auth_required=auth_required) + return self.add_route(HttpMethod.HEAD, endpoint, handler, auth_required=auth_required, auth_middleware_name=auth_middleware_name) return inner - def options(self, endpoint: str, auth_required: bool = False): + def options(self, endpoint: str, auth_required: bool = False, auth_middleware_name: Optional[str] = None): """ The @app.options decorator to add a route with OPTIONS method @@ -382,11 +408,11 @@ def options(self, endpoint: str, auth_required: bool = False): """ def inner(handler): - return self.add_route(HttpMethod.OPTIONS, endpoint, handler, auth_required=auth_required) + return self.add_route(HttpMethod.OPTIONS, endpoint, handler, auth_required=auth_required, auth_middleware_name=auth_middleware_name) return inner - def connect(self, endpoint: str, auth_required: bool = False): + def connect(self, endpoint: str, auth_required: bool = False, auth_middleware_name: Optional[str] = None): """ The @app.connect decorator to add a route with CONNECT method @@ -394,11 +420,11 @@ def connect(self, endpoint: str, auth_required: bool = False): """ def inner(handler): - return self.add_route(HttpMethod.CONNECT, endpoint, handler, auth_required=auth_required) + return self.add_route(HttpMethod.CONNECT, endpoint, handler, auth_required=auth_required, auth_middleware_name=auth_middleware_name) return inner - def trace(self, endpoint: str, auth_required: bool = False): + def trace(self, endpoint: str, auth_required: bool = False, auth_middleware_name: Optional[str] = None): """ The @app.trace decorator to add a route with TRACE method @@ -406,7 +432,7 @@ def trace(self, endpoint: str, auth_required: bool = False): """ def inner(handler): - return self.add_route(HttpMethod.TRACE, endpoint, handler, auth_required=auth_required) + return self.add_route(HttpMethod.TRACE, endpoint, handler, auth_required=auth_required, auth_middleware_name=auth_middleware_name) return inner @@ -434,7 +460,7 @@ def configure_authentication(self, authentication_handler: AuthenticationHandler :param authentication_handler: the instance of a class inheriting the AuthenticationHandler base class """ - self.authentication_handler = authentication_handler + self.authentication_handler.append(authentication_handler) self.middleware_router.set_authentication_handler(authentication_handler) diff --git a/robyn/authentication.py b/robyn/authentication.py index 5dd76af54..44f9f4eab 100644 --- a/robyn/authentication.py +++ b/robyn/authentication.py @@ -1,5 +1,6 @@ from abc import ABC, abstractmethod -from typing import Optional +from typing import Optional, Tuple, Union +from base64 import b64decode from robyn.robyn import Headers, Identity, Request, Response from robyn.status_codes import HTTP_401_UNAUTHORIZED @@ -45,15 +46,30 @@ def set_token(cls, request: Request, token: str): """ raise NotImplementedError() + @classmethod + def get_credentials(cls, request: Request): + """ + Only available for Basic Getter. + Gets credentials from the request token basic. + This method will decode the token and return username password + :param request: The request object. + :return: Tuple of username and password. + """ + pass + class AuthenticationHandler(ABC): - def __init__(self, token_getter: TokenGetter): + def __init__(self, token_getter: TokenGetter, name: str, default: bool = False): """ Creates a new instance of the AuthenticationHandler class. This class is an abstract class used to authenticate a user. :param token_getter: The token getter used to get the token from the request. + :param name: The name of authentication handler ex: jwt, basic, etc. + :param default: set authentication handler is the default handler. """ self.token_getter = token_getter + self.name = name + self.default = default @property def unauthorized_response(self) -> Response: @@ -94,3 +110,32 @@ def get_token(cls, request: Request) -> Optional[str]: @classmethod def set_token(cls, request: Request, token: str): request.headers["Authorization"] = f"Bearer {token}" + + +class BasicGetter(TokenGetter): + """ + This class is used to get the token from the Authorization header. + The scheme of the header must be Basic. + """ + + @classmethod + def get_token(cls, request: Request) -> Optional[str]: + authorization_header = request.headers.get("authorization") + if not authorization_header or not authorization_header.startswith("Basic "): + return None + + return authorization_header[6:] # Remove the "Basic " prefix + + @classmethod + def set_token(cls, request: Request, token: str): + request.headers["Authorization"] = f"Basic {token}" + + @classmethod + def get_credentials(cls, request: Request) -> Union[Tuple[str, str], Tuple[None, None]]: + basic_token = cls.get_token(request) + try: + basic_token_decoded = b64decode(basic_token).decode() + username, _, password = basic_token_decoded.partition(":") + return (username, password) # Return username password from basic authorization + except Exception: + return (None, None) diff --git a/robyn/router.py b/robyn/router.py index 76fe522f9..e761c2f02 100644 --- a/robyn/router.py +++ b/robyn/router.py @@ -193,11 +193,11 @@ def __init__(self, dependencies: DependencyMap = DependencyMap()) -> None: super().__init__() self.global_middlewares: List[GlobalMiddleware] = [] self.route_middlewares: List[RouteMiddleware] = [] - self.authentication_handler: Optional[AuthenticationHandler] = None + self.authentication_handler: List[AuthenticationHandler] = [] self.dependencies = dependencies def set_authentication_handler(self, authentication_handler: AuthenticationHandler): - self.authentication_handler = authentication_handler + self.authentication_handler.append(authentication_handler) def add_route( self, @@ -226,7 +226,7 @@ def add_route( self.route_middlewares.append(RouteMiddleware(middleware_type, endpoint, function)) return handler - def add_auth_middleware(self, endpoint: str): + def add_auth_middleware(self, endpoint: str, name: str = None): """ This method adds an authentication middleware to the specified endpoint. """ @@ -236,11 +236,27 @@ def add_auth_middleware(self, endpoint: str): def decorator(handler): @wraps(handler) def inner_handler(request: Request, *args): - if not self.authentication_handler: + if len(self.authentication_handler) == 0: raise AuthenticationNotConfiguredError() - identity = self.authentication_handler.authenticate(request) + + auth_handler = None + + for authentication_handler in self.authentication_handler: + if name and authentication_handler.name == name: + auth_handler = authentication_handler + break + if not name and authentication_handler.default: + auth_handler = authentication_handler + break + + if not auth_handler: + raise AuthenticationNotConfiguredError() + + identity = auth_handler.authenticate(request) + if identity is None: - return self.authentication_handler.unauthorized_response + return auth_handler.unauthorized_response + request.identity = identity return request