From ba5ba62f4157ad9c601b6c50c31c4f77443c4edd Mon Sep 17 00:00:00 2001 From: Sanskar Jethi Date: Mon, 25 Oct 2021 21:18:06 +0100 Subject: [PATCH] Reqeusts object is now optional --- robyn/__init__.py | 4 +++- src/processor.rs | 32 ++++++++++++++++++++++++++------ src/router.rs | 35 +++++++++++++++++++++-------------- src/server.rs | 24 ++++++++++++++++++++---- test_python/base_routes.py | 6 +++--- 5 files changed, 73 insertions(+), 28 deletions(-) diff --git a/robyn/__init__.py b/robyn/__init__.py index 3593f50c7..fcebaa44a 100644 --- a/robyn/__init__.py +++ b/robyn/__init__.py @@ -2,6 +2,7 @@ import os import argparse import asyncio +from inspect import signature from .robyn import Server from .responses import static_file, jsonify @@ -40,8 +41,9 @@ def add_route(self, route_type, endpoint, handler): """ We will add the status code here only """ + number_of_params = len(signature(handler).parameters) self.server.add_route( - route_type, endpoint, handler, asyncio.iscoroutinefunction(handler) + route_type, endpoint, handler, asyncio.iscoroutinefunction(handler), number_of_params ) def add_directory(self, route, directory_path, index_file=None, show_files_listing=False): diff --git a/src/processor.rs b/src/processor.rs index acf157a54..5a656379f 100644 --- a/src/processor.rs +++ b/src/processor.rs @@ -35,12 +35,22 @@ pub fn apply_headers(response: &mut HttpResponseBuilder, headers: &Arc) /// pub async fn handle_request( function: PyFunction, + number_of_params: u8, headers: &Arc, payload: &mut web::Payload, req: &HttpRequest, route_params: HashMap, ) -> HttpResponse { - let contents = match execute_function(function, payload, headers, req, route_params).await { + let contents = match execute_function( + function, + payload, + headers, + req, + route_params, + number_of_params, + ) + .await + { Ok(res) => res, Err(err) => { println!("Error: {:?}", err); @@ -76,6 +86,7 @@ async fn execute_function( headers: &Headers, req: &HttpRequest, route_params: HashMap, + number_of_params: u8, ) -> Result { let mut data: Option> = None; @@ -119,7 +130,12 @@ async fn execute_function( }; // this makes the request object to be accessible across every route - let coro: PyResult<&PyAny> = handler.call1((request,)); + let coro: PyResult<&PyAny> = match number_of_params { + 0 => handler.call0(), + 1 => handler.call1((request,)), + // this is done to accomodate any future params + 2_u8..=u8::MAX => handler.call1((request,)), + }; pyo3_asyncio::tokio::into_future(coro?) })?; @@ -166,16 +182,20 @@ async fn execute_function( let handler = handler.as_ref(py); request.insert("params", route_params.into_py(py)); request.insert("headers", headers_python.into_py(py)); - let output: PyResult<&PyAny> = match data { + match data { Some(res) => { let data = res.into_py(py); request.insert("body", data); - - handler.call1((request,)) } - None => handler.call1((request,)), + None => {} }; + let output: PyResult<&PyAny> = match number_of_params { + 0 => handler.call0(), + 1 => handler.call1((request,)), + // this is done to accomodate any future params + 2_u8..=u8::MAX => handler.call1((request,)), + }; let output: &str = output?.extract()?; Ok(output.to_string()) }) diff --git a/src/router.rs b/src/router.rs index 21b84ada4..5f3029ddb 100644 --- a/src/router.rs +++ b/src/router.rs @@ -11,15 +11,15 @@ use matchit::Node; /// Contains the thread safe hashmaps of different routes pub struct Router { - get_routes: Arc>>, - post_routes: Arc>>, - put_routes: Arc>>, - delete_routes: Arc>>, - patch_routes: Arc>>, - head_routes: Arc>>, - options_routes: Arc>>, - connect_routes: Arc>>, - trace_routes: Arc>>, + get_routes: Arc>>, + post_routes: Arc>>, + put_routes: Arc>>, + delete_routes: Arc>>, + patch_routes: Arc>>, + head_routes: Arc>>, + options_routes: Arc>>, + connect_routes: Arc>>, + trace_routes: Arc>>, } impl Router { @@ -38,7 +38,7 @@ impl Router { } #[inline] - fn get_relevant_map(&self, route: Method) -> Option<&Arc>>> { + fn get_relevant_map(&self, route: Method) -> Option<&Arc>>> { match route { Method::GET => Some(&self.get_routes), Method::POST => Some(&self.post_routes), @@ -54,7 +54,7 @@ impl Router { } #[inline] - fn get_relevant_map_str(&self, route: &str) -> Option<&Arc>>> { + fn get_relevant_map_str(&self, route: &str) -> Option<&Arc>>> { let method = match Method::from_bytes(route.as_bytes()) { Ok(res) => res, Err(_) => return None, @@ -65,7 +65,14 @@ impl Router { // Checks if the functions is an async function // Inserts them in the router according to their nature(CoRoutine/SyncFunction) - pub fn add_route(&self, route_type: &str, route: &str, handler: Py, is_async: bool) { + pub fn add_route( + &self, + route_type: &str, + route: &str, + handler: Py, + is_async: bool, + number_of_params: u8, + ) { let table = match self.get_relevant_map_str(route_type) { Some(table) => table, None => return, @@ -80,7 +87,7 @@ impl Router { table .write() .unwrap() - .insert(route.to_string(), function) + .insert(route.to_string(), (function, number_of_params)) .unwrap(); } @@ -88,7 +95,7 @@ impl Router { &self, route_method: Method, route: &str, - ) -> Option<(PyFunction, HashMap)> { + ) -> Option<((PyFunction, u8), HashMap)> { let table = self.get_relevant_map(route_method)?; match table.read().unwrap().at(route) { Ok(res) => { diff --git a/src/server.rs b/src/server.rs index 0be087cd2..9a74595be 100644 --- a/src/server.rs +++ b/src/server.rs @@ -143,9 +143,17 @@ impl Server { /// Add a new route to the routing tables /// can be called after the server has been started - pub fn add_route(&self, route_type: &str, route: &str, handler: Py, is_async: bool) { + pub fn add_route( + &self, + route_type: &str, + route: &str, + handler: Py, + is_async: bool, + number_of_params: u8, + ) { println!("Route added for {} {} ", route_type, route); - self.router.add_route(route_type, route, handler, is_async); + self.router + .add_route(route_type, route, handler, is_async, number_of_params); } } @@ -164,8 +172,16 @@ async fn index( req: HttpRequest, ) -> impl Responder { match router.get_route(req.method().clone(), req.uri().path()) { - Some((handler_function, route_params)) => { - handle_request(handler_function, &headers, &mut payload, &req, route_params).await + Some(((handler_function, number_of_params), route_params)) => { + handle_request( + handler_function, + number_of_params, + &headers, + &mut payload, + &req, + route_params, + ) + .await } None => { let mut response = HttpResponse::Ok(); diff --git a/test_python/base_routes.py b/test_python/base_routes.py index 2ec3f3a25..19c8c6c19 100644 --- a/test_python/base_routes.py +++ b/test_python/base_routes.py @@ -12,8 +12,8 @@ @app.get("/") -async def h(request): - print(request) +async def h(requests): + print(requests) global callCount callCount += 1 message = "Called " + str(callCount) + " times" @@ -26,7 +26,7 @@ async def test(request): return static_file("./index.html") @app.get("/jsonify") -async def json_get(request): +async def json_get(): return jsonify({"hello": "world"})