diff --git a/robyn/__init__.py b/robyn/__init__.py index 268be447d..08c085991 100644 --- a/robyn/__init__.py +++ b/robyn/__init__.py @@ -110,7 +110,7 @@ def add_route( injected_dependencies = self.dependencies.get_dependency_map(self) if auth_required: - self.middleware_router.add_auth_middleware(endpoint)(handler) + self.middleware_router.add_auth_middleware(endpoint, route_type)(handler) if isinstance(route_type, str): http_methods = { diff --git a/robyn/processpool.py b/robyn/processpool.py index 85de6683e..2878387c0 100644 --- a/robyn/processpool.py +++ b/robyn/processpool.py @@ -180,8 +180,8 @@ def spawn_process( for middleware_type, middleware_function in global_middlewares: server.add_global_middleware(middleware_type, middleware_function) - for route_type, endpoint, function in route_middlewares: - server.add_middleware_route(route_type, endpoint, function) + for middleware_type, endpoint, function, route_type in route_middlewares: + server.add_middleware_route(middleware_type, endpoint, function, route_type) if Events.STARTUP in event_handlers: server.add_startup_handler(event_handlers[Events.STARTUP]) diff --git a/robyn/robyn.pyi b/robyn/robyn.pyi index f354f99bd..9402be920 100644 --- a/robyn/robyn.pyi +++ b/robyn/robyn.pyi @@ -343,6 +343,7 @@ class Server: middleware_type: MiddlewareType, route: str, function: FunctionInfo, + route_type: HttpMethod, ) -> None: pass def add_startup_handler(self, function: FunctionInfo) -> None: diff --git a/robyn/router.py b/robyn/router.py index d63017768..c777214cc 100644 --- a/robyn/router.py +++ b/robyn/router.py @@ -30,6 +30,7 @@ class RouteMiddleware(NamedTuple): middleware_type: MiddlewareType route: str function: FunctionInfo + route_type: HttpMethod class GlobalMiddleware(NamedTuple): @@ -263,6 +264,7 @@ def add_route( self, middleware_type: MiddlewareType, endpoint: str, + route_type: HttpMethod, handler: Callable, injected_dependencies: dict, ) -> Callable: @@ -283,10 +285,10 @@ def add_route( params, new_injected_dependencies, ) - self.route_middlewares.append(RouteMiddleware(middleware_type, endpoint, function)) + self.route_middlewares.append(RouteMiddleware(middleware_type, endpoint, function, route_type)) return handler - def add_auth_middleware(self, endpoint: str): + def add_auth_middleware(self, endpoint: str, route_type: HttpMethod): """ This method adds an authentication middleware to the specified endpoint. """ @@ -308,6 +310,7 @@ def inner_handler(request: Request, *args): self.add_route( MiddlewareType.BEFORE_REQUEST, endpoint, + route_type, inner_handler, injected_dependencies, ) @@ -336,11 +339,12 @@ def inner_handler(*args, **kwargs): self.add_route( middleware_type, endpoint, + HttpMethod.GET, async_inner_handler, injected_dependencies, ) else: - self.add_route(middleware_type, endpoint, inner_handler, injected_dependencies) + self.add_route(middleware_type, endpoint, HttpMethod.GET, inner_handler, injected_dependencies) else: params = dict(inspect.signature(handler).parameters) diff --git a/src/server.rs b/src/server.rs index 7d7771986..09fd9e430 100644 --- a/src/server.rs +++ b/src/server.rs @@ -365,13 +365,28 @@ impl Server { middleware_type: &MiddlewareType, route: &str, function: FunctionInfo, + http_method: HttpMethod, ) { + let mut endpoint_prefixed_with_method = http_method.to_string(); + + if !route.starts_with('/') { + endpoint_prefixed_with_method.push('/'); + } + + endpoint_prefixed_with_method.push_str(route); + debug!( "MiddleWare Route added for {:?} {} ", - middleware_type, route + middleware_type, &endpoint_prefixed_with_method ); + self.middleware_router - .add_route(middleware_type, route, function, None) + .add_route( + middleware_type, + &endpoint_prefixed_with_method, + function, + None, + ) .unwrap(); } @@ -420,13 +435,15 @@ async fn index( ) -> impl Responder { let mut request = Request::from_actix_request(&req, payload, &global_request_headers).await; + let route = format!("{}{}", req.method(), req.uri().path()); + // Before middleware // Global let mut before_middlewares = middleware_router.get_global_middlewares(&MiddlewareType::BeforeRequest); // Route specific if let Some((function, route_params)) = - middleware_router.get_route(&MiddlewareType::BeforeRequest, req.uri().path()) + middleware_router.get_route(&MiddlewareType::BeforeRequest, &route) { before_middlewares.push(function); request.path_params = route_params; @@ -487,8 +504,7 @@ async fn index( let mut after_middlewares = middleware_router.get_global_middlewares(&MiddlewareType::AfterRequest); // Route specific - if let Some((function, _)) = - middleware_router.get_route(&MiddlewareType::AfterRequest, req.uri().path()) + if let Some((function, _)) = middleware_router.get_route(&MiddlewareType::AfterRequest, &route) { after_middlewares.push(function); } diff --git a/src/types/mod.rs b/src/types/mod.rs index a272ee682..60932c537 100644 --- a/src/types/mod.rs +++ b/src/types/mod.rs @@ -50,6 +50,13 @@ impl HttpMethod { } } +// for: https://stackoverflow.com/a/32712140/9652621 +impl std::fmt::Display for HttpMethod { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "{:?}", self) + } +} + #[pyclass] #[derive(Default, Debug, Clone)] pub struct Url {