Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add middlewares #157

Merged
merged 9 commits into from
Feb 14, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 18 additions & 1 deletion integration_tests/base_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,27 @@ def message():
async def hello(request):
global callCount
callCount += 1
_message = "Called " + str(callCount) + " times"
message = "Called " + str(callCount) + " times"
print(message)
return jsonify(request)


@app.before_request("/")
async def hello_before_request(request):
global callCount
callCount += 1
print(request)
return ""


@app.after_request("/")
async def hello_after_request(request):
global callCount
callCount += 1
print(request)
return ""


@app.get("/test/:id")
async def test(request):
print(request)
Expand Down
83 changes: 82 additions & 1 deletion robyn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,9 @@ def __init__(self, file_object):
self.routes = []
self.headers = []
self.routes = []
self.directories = []
self.middlewares = []
self.web_sockets = {}
self.directories = []
self.event_handlers = {}

def add_route(self, route_type, endpoint, handler):
Expand All @@ -61,6 +62,84 @@ def add_route(self, route_type, endpoint, handler):
)
)

def add_middleware_route(self, route_type, endpoint, handler):
"""
[This is base handler for the middleware decorator]

:param route_type [str]: [??]
:param endpoint [str]: [endpoint for the route added]
:param handler [function]: [represents the sync or async function passed as a handler for the route]
"""

""" We will add the status code here only
"""
number_of_params = len(signature(handler).parameters)
self.middlewares.append(
(
route_type,
endpoint,
handler,
asyncio.iscoroutinefunction(handler),
number_of_params,
)
)

def before_request(self, endpoint):
"""
[The @app.before_request decorator to add a get route]

:param endpoint [str]: [endpoint to server the route]
"""

# This inner function is basically a wrapper arround the closure(decorator)
# being returned.
# It takes in a handler and converts it in into a closure
# and returns the arguments.
# Arguments are returned as they could be modified by the middlewares.
def inner(handler):
async def async_inner_handler(*args):
await handler(args)
return args

def inner_handler(*args):
handler(*args)
return args

if asyncio.iscoroutinefunction(handler):
self.add_middleware_route("BEFORE_REQUEST", endpoint, async_inner_handler)
else:
self.add_middleware_route("BEFORE_REQUEST", endpoint, inner_handler)

return inner

def after_request(self, endpoint):
"""
[The @app.after_request decorator to add a get route]

:param endpoint [str]: [endpoint to server the route]
"""

# This inner function is basically a wrapper arround the closure(decorator)
# being returned.
# It takes in a handler and converts it in into a closure
# and returns the arguments.
# Arguments are returned as they could be modified by the middlewares.
def inner(handler):
async def async_inner_handler(*args):
await handler(args)
return args

def inner_handler(*args):
handler(*args)
return args

if asyncio.iscoroutinefunction(handler):
self.add_middleware_route("AFTER_REQUEST", endpoint, async_inner_handler)
else:
self.add_middleware_route("AFTER_REQUEST", endpoint, inner_handler)

return inner

def add_directory(
self, route, directory_path, index_file=None, show_files_listing=False
):
Expand Down Expand Up @@ -95,6 +174,7 @@ def start(self, url="127.0.0.1", port=5000):
if not self.dev:
workers = self.workers
socket = SocketHeld(url, port)
print(self.middlewares)
for _ in range(self.processes):
copied_socket = socket.try_clone()
p = Process(
Expand All @@ -103,6 +183,7 @@ def start(self, url="127.0.0.1", port=5000):
self.directories,
self.headers,
self.routes,
self.middlewares,
self.web_sockets,
self.event_handlers,
copied_socket,
Expand Down
7 changes: 6 additions & 1 deletion robyn/processpool.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@


def spawn_process(
directories, headers, routes, web_sockets, event_handlers, socket, workers
directories, headers, routes, middlewares, web_sockets, event_handlers, socket, workers
):
"""
This function is called by the main process handler to create a server runtime.
Expand All @@ -21,6 +21,7 @@ def spawn_process(
:param directories tuple: the list of all the directories and related data in a tuple
:param headers tuple: All the global headers in a tuple
:param routes tuple: The routes touple, containing the description about every route.
:param middlewares tuple: The middleware router touple, containing the description about every route.
sansyrox marked this conversation as resolved.
Show resolved Hide resolved
:param web_sockets list: This is a list of all the web socket routes
:param event_handlers Dict: This is an event dict that contains the event handlers
:param socket Socket: This is the main tcp socket, which is being shared across multiple processes.
Expand Down Expand Up @@ -53,6 +54,10 @@ def spawn_process(
route_type, endpoint, handler, is_async, number_of_params = route
server.add_route(route_type, endpoint, handler, is_async, number_of_params)

for route in middlewares:
route_type, endpoint, handler, is_async, number_of_params = route
server.add_middleware_route(route_type, endpoint, handler, is_async, number_of_params)

if "startup" in event_handlers:
server.add_startup_handler(event_handlers[Events.STARTUP][0], event_handlers[Events.STARTUP][1])

Expand Down
2 changes: 1 addition & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
mod processor;
mod router;
mod routers;
mod server;
mod shared_socket;
mod types;
Expand Down
141 changes: 135 additions & 6 deletions src/processor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use anyhow::{bail, Result};
use crate::types::{Headers, PyFunction};
use futures_util::stream::StreamExt;
use pyo3::prelude::*;
use pyo3::types::PyDict;
use pyo3::types::{PyDict, PyTuple};

use std::fs::File;
use std::io::Read;
Expand Down Expand Up @@ -40,7 +40,7 @@ pub async fn handle_request(
payload: &mut web::Payload,
req: &HttpRequest,
route_params: HashMap<String, String>,
queries: HashMap<&str, &str>,
queries: HashMap<String, String>,
) -> HttpResponse {
let contents = match execute_http_function(
function,
Expand All @@ -67,6 +67,36 @@ pub async fn handle_request(
response.body(contents)
}

pub async fn handle_middleware_request(
function: PyFunction,
number_of_params: u8,
headers: &Arc<Headers>,
payload: &mut web::Payload,
req: &HttpRequest,
route_params: HashMap<String, String>,
queries: HashMap<String, String>,
) -> Py<PyTuple> {
let contents = match execute_middleware_function(
function,
payload,
headers,
req,
route_params,
queries,
number_of_params,
)
.await
{
Ok(res) => res,
Err(err) => Python::with_gil(|py| {
println!("{:?}", err);
PyTuple::empty(py).into_py(py)
}),
};

contents
}

// ideally this should be async
/// A function to read lossy files and serve it as a html response
///
Expand All @@ -81,6 +111,101 @@ fn read_file(file_path: &str) -> String {
String::from_utf8_lossy(&buf).to_string()
}

async fn execute_middleware_function<'a>(
function: PyFunction,
payload: &mut web::Payload,
headers: &Headers,
req: &HttpRequest,
route_params: HashMap<String, String>,
queries: HashMap<String, String>,
number_of_params: u8,
) -> Result<Py<PyTuple>> {
// TODO:
// try executing the first version of middleware(s) here
// with just headers as params

let mut data: Option<Vec<u8>> = None;

if req.method() == Method::POST
|| req.method() == Method::PUT
|| req.method() == Method::PATCH
|| req.method() == Method::DELETE
{
let mut body = web::BytesMut::new();
while let Some(chunk) = payload.next().await {
let chunk = chunk?;
// limit max size of in-memory payload
if (body.len() + chunk.len()) > MAX_SIZE {
bail!("Body content Overflow");
}
body.extend_from_slice(&chunk);
}

data = Some(body.to_vec())
}

// request object accessible while creating routes
let mut request = HashMap::new();
let mut headers_python = HashMap::new();
for elem in headers.into_iter() {
headers_python.insert(elem.key().clone(), elem.value().clone());
}

match function {
PyFunction::CoRoutine(handler) => {
let output = Python::with_gil(|py| {
let handler = handler.as_ref(py);
request.insert("params", route_params.into_py(py));
request.insert("queries", queries.into_py(py));
request.insert("headers", headers_python.into_py(py));
request.insert("body", data.into_py(py));

// this makes the request object to be accessible across every route
let coro: PyResult<&PyAny> = match number_of_params {
0 => handler.call0(),
1 => handler.call1((request,)),
// this is done to accomodate any future params
2_u8..=u8::MAX => handler.call1((request,)),
};
pyo3_asyncio::tokio::into_future(coro?)
})?;

let output = output.await?;

let res = Python::with_gil(|py| -> PyResult<Py<PyTuple>> {
let output: Py<PyTuple> = output.extract(py).unwrap();
Ok(output)
})?;

Ok(res)
}

PyFunction::SyncFunction(handler) => {
tokio::task::spawn_blocking(move || {
Python::with_gil(|py| {
let handler = handler.as_ref(py);
request.insert("params", route_params.into_py(py));
request.insert("queries", queries.into_py(py));
request.insert("headers", headers_python.into_py(py));
request.insert("body", data.into_py(py));

let output: PyResult<&PyAny> = match number_of_params {
0 => handler.call0(),
1 => handler.call1((request,)),
// this is done to accomodate any future params
2_u8..=u8::MAX => handler.call1((request,)),
};

let output: Py<PyTuple> = output?.extract().unwrap();

Ok(output)
})
})
.await?
}
}
}

// Change this!
#[inline]
async fn execute_http_function(
Expand All @@ -89,7 +214,7 @@ async fn execute_http_function(
headers: &Headers,
req: &HttpRequest,
route_params: HashMap<String, String>,
queries: HashMap<&str, &str>,
queries: HashMap<String, String>,
number_of_params: u8,
) -> Result<String> {
let mut data: Option<Vec<u8>> = None;
Expand Down Expand Up @@ -211,9 +336,12 @@ async fn execute_http_function(
}
}

pub async fn execute_event_handler(event_handler: Option<PyFunction>, event_loop: Py<PyAny>) {
pub async fn execute_event_handler(
event_handler: Option<Arc<PyFunction>>,
event_loop: Arc<Py<PyAny>>,
) {
match event_handler {
Some(handler) => match handler {
Some(handler) => match &(*handler) {
PyFunction::SyncFunction(function) => {
println!("Startup event handler");
Python::with_gil(|py| {
Expand All @@ -225,7 +353,8 @@ pub async fn execute_event_handler(event_handler: Option<PyFunction>, event_loop
println!("Startup event handler async");

let coroutine = function.as_ref(py).call0().unwrap();
pyo3_asyncio::into_future_with_loop(event_loop.as_ref(py), coroutine).unwrap()
pyo3_asyncio::into_future_with_loop((*event_loop).as_ref(py), coroutine)
.unwrap()
});
future.await.unwrap();
}
Expand Down
Loading