Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: auth route conflict #987

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
2 changes: 1 addition & 1 deletion robyn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def add_route(
injected_dependencies = self.dependencies.get_dependency_map(self)

if auth_required:
self.middleware_router.add_auth_middleware(endpoint)(handler)
self.middleware_router.add_auth_middleware(endpoint, route_type)(handler)

if isinstance(route_type, str):
http_methods = {
Expand Down
4 changes: 2 additions & 2 deletions robyn/processpool.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,8 +180,8 @@ def spawn_process(
for middleware_type, middleware_function in global_middlewares:
server.add_global_middleware(middleware_type, middleware_function)

for route_type, endpoint, function in route_middlewares:
server.add_middleware_route(route_type, endpoint, function)
for middleware_type, endpoint, function, route_type in route_middlewares:
server.add_middleware_route(middleware_type, endpoint, function, route_type)

if Events.STARTUP in event_handlers:
server.add_startup_handler(event_handlers[Events.STARTUP])
Expand Down
1 change: 1 addition & 0 deletions robyn/robyn.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,7 @@ class Server:
middleware_type: MiddlewareType,
route: str,
function: FunctionInfo,
route_type: HttpMethod,
) -> None:
pass
def add_startup_handler(self, function: FunctionInfo) -> None:
Expand Down
10 changes: 7 additions & 3 deletions robyn/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ class RouteMiddleware(NamedTuple):
middleware_type: MiddlewareType
route: str
function: FunctionInfo
route_type: HttpMethod


class GlobalMiddleware(NamedTuple):
Expand Down Expand Up @@ -263,6 +264,7 @@ def add_route(
self,
middleware_type: MiddlewareType,
endpoint: str,
route_type: HttpMethod,
handler: Callable,
injected_dependencies: dict,
) -> Callable:
Expand All @@ -283,10 +285,10 @@ def add_route(
params,
new_injected_dependencies,
)
self.route_middlewares.append(RouteMiddleware(middleware_type, endpoint, function))
self.route_middlewares.append(RouteMiddleware(middleware_type, endpoint, function, route_type))
return handler

def add_auth_middleware(self, endpoint: str):
def add_auth_middleware(self, endpoint: str, route_type: HttpMethod):
"""
This method adds an authentication middleware to the specified endpoint.
"""
Expand All @@ -308,6 +310,7 @@ def inner_handler(request: Request, *args):
self.add_route(
MiddlewareType.BEFORE_REQUEST,
endpoint,
route_type,
inner_handler,
injected_dependencies,
)
Expand Down Expand Up @@ -336,11 +339,12 @@ def inner_handler(*args, **kwargs):
self.add_route(
middleware_type,
endpoint,
HttpMethod.GET,
async_inner_handler,
injected_dependencies,
)
else:
self.add_route(middleware_type, endpoint, inner_handler, injected_dependencies)
self.add_route(middleware_type, endpoint, HttpMethod.GET, inner_handler, injected_dependencies)
Comment on lines +342 to +347
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why is this a GET http method? Is a middleware always treated as a get request?

else:
params = dict(inspect.signature(handler).parameters)

Expand Down
26 changes: 21 additions & 5 deletions src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -365,13 +365,28 @@ impl Server {
middleware_type: &MiddlewareType,
route: &str,
function: FunctionInfo,
http_method: HttpMethod,
) {
let mut endpoint_prefixed_with_method = http_method.to_string();

if !route.starts_with('/') {
endpoint_prefixed_with_method.push('/');
}

endpoint_prefixed_with_method.push_str(route);

debug!(
"MiddleWare Route added for {:?} {} ",
middleware_type, route
middleware_type, &endpoint_prefixed_with_method
);

self.middleware_router
.add_route(middleware_type, route, function, None)
.add_route(
middleware_type,
&endpoint_prefixed_with_method,
function,
None,
)
.unwrap();
}

Expand Down Expand Up @@ -420,13 +435,15 @@ async fn index(
) -> impl Responder {
let mut request = Request::from_actix_request(&req, payload, &global_request_headers).await;

let route = format!("{}{}", req.method(), req.uri().path());

// Before middleware
// Global
let mut before_middlewares =
middleware_router.get_global_middlewares(&MiddlewareType::BeforeRequest);
// Route specific
if let Some((function, route_params)) =
middleware_router.get_route(&MiddlewareType::BeforeRequest, req.uri().path())
middleware_router.get_route(&MiddlewareType::BeforeRequest, &route)
{
before_middlewares.push(function);
request.path_params = route_params;
Expand Down Expand Up @@ -487,8 +504,7 @@ async fn index(
let mut after_middlewares =
middleware_router.get_global_middlewares(&MiddlewareType::AfterRequest);
// Route specific
if let Some((function, _)) =
middleware_router.get_route(&MiddlewareType::AfterRequest, req.uri().path())
if let Some((function, _)) = middleware_router.get_route(&MiddlewareType::AfterRequest, &route)
{
after_middlewares.push(function);
}
Expand Down
7 changes: 7 additions & 0 deletions src/types/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,13 @@ impl HttpMethod {
}
}

// for: https://stackoverflow.com/a/32712140/9652621
impl std::fmt::Display for HttpMethod {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(f, "{:?}", self)
}
}

#[pyclass]
#[derive(Default, Debug, Clone)]
pub struct Url {
Expand Down
Loading