From 5072dd52a7539befe27733952ccb3e8abe925f6a Mon Sep 17 00:00:00 2001 From: Sanskar Jethi Date: Mon, 21 Jun 2021 20:18:40 +0530 Subject: [PATCH 1/2] Add support for sync functions --- Cargo.lock | 4 ++-- robyn/__init__.py | 6 +++++- src/process.rs | 46 ++++++++++++++++++++++++++++++++++++---------- src/server.rs | 12 +++++++++++- test.py | 4 ++++ 5 files changed, 58 insertions(+), 14 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 90a6e7c31..1fee96282 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -575,9 +575,9 @@ dependencies = [ [[package]] name = "tokio" -version = "1.7.0" +version = "1.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c79ba603c337335df6ba6dd6afc38c38a7d5e1b0c871678439ea973cd62a118e" +checksum = "5fb2ed024293bb19f7a5dc54fe83bf86532a44c12a2bb8ba40d64a4509395ca2" dependencies = [ "autocfg", "bytes", diff --git a/robyn/__init__.py b/robyn/__init__.py index 710ef5719..0d058c332 100644 --- a/robyn/__init__.py +++ b/robyn/__init__.py @@ -1,5 +1,5 @@ -from .robyn import * from .robyn import Server +from asyncio import iscoroutinefunction class Robyn: """This is the python wrapper for the Robyn binaries. @@ -8,6 +8,10 @@ def __init__(self) -> None: self.server = Server() def add_route(self, route_type, endpoint, handler): + handler = { + "is_async": iscoroutinefunction(handler), + "handler": handler + } self.server.add_route(route_type, endpoint, handler) def start(self): diff --git a/src/process.rs b/src/process.rs index 805c7d01b..c3e2c8d8c 100644 --- a/src/process.rs +++ b/src/process.rs @@ -2,20 +2,46 @@ use tokio::io::AsyncWriteExt; use tokio::net::TcpStream; // pyO3 module use pyo3::prelude::*; +use pyo3::types::{PyAny, PyDict}; -pub async fn handle_message(handler: Py, mut stream: TcpStream) { - let f = Python::with_gil(|py| { - let coro = handler.as_ref(py).call0().unwrap(); - pyo3_asyncio::into_future(&coro).unwrap() +enum PyFunction { + CoRoutine(Py), + OutPut(String), +} + +pub async fn handle_message(process_object: Py, mut stream: TcpStream) { + let function: PyFunction = Python::with_gil(|py| { + let process_object_wrapper: &PyAny = process_object.as_ref(py); + let py_dict = process_object_wrapper.downcast::().unwrap(); + let is_async: bool = py_dict.get_item("is_coroutine").unwrap().extract().unwrap(); + let handler: &PyAny = py_dict.get_item("handler").unwrap(); + if is_async { + let coro = handler.call0().unwrap(); + PyFunction::CoRoutine(coro.into()) + } else { + let s: &str = handler.call0().unwrap().extract().unwrap(); + PyFunction::OutPut(String::from(s)) + } }); - let output = f.await.unwrap(); + let contents = match function { + PyFunction::CoRoutine(coro) => { + let x = Python::with_gil(|py| { + let x = coro.as_ref(py); + pyo3_asyncio::into_future(x).unwrap() + }); + let output = x.await.unwrap(); + Python::with_gil(|py| -> PyResult { + let contents: &str = output.extract(py).unwrap(); + Ok(contents.to_string()) + }) + .unwrap() + } + PyFunction::OutPut(x) => x, + }; + + // let output = op.await.unwrap(); let status_line = "HTTP/1.1 200 OK"; - let contents = Python::with_gil(|py| -> PyResult { - let contents: &str = output.extract(py).unwrap(); - Ok(contents.to_string()) - }) - .unwrap(); let len = contents.len(); let response = format!( diff --git a/src/server.rs b/src/server.rs index 730eae2e4..20eebb5af 100644 --- a/src/server.rs +++ b/src/server.rs @@ -5,6 +5,7 @@ use std::sync::Arc; // pyO3 module use pyo3::prelude::*; use pyo3::types::PyAny; +use pyo3::types::PyDict; use tokio::io::AsyncReadExt; use tokio::net::{TcpListener, TcpStream}; @@ -49,7 +50,16 @@ impl Server { }; } - pub fn add_route(&self, route_type: &str, route: String, handler: Py) { + pub fn add_route( + &self, + route_type: &str, + route: String, + handler: Py, + ) { + // Python::with_gil(|py| { + // let py_dict: &PyDict = py_obj.as_ref(py); + // println!("{}", py_dict.get_item("is_coroutine").unwrap()); + // }); println!("{} {} ", route_type, route); let route = Route::new(RouteType::Route((route, route_type.to_string()))); self.router.add_route(route_type, route, handler); diff --git a/test.py b/test.py index dd3b6b328..850b8237f 100644 --- a/test.py +++ b/test.py @@ -13,5 +13,9 @@ async def sleeper(): await asyncio.sleep(5) return "sleep function" +@app.get("/blocker") +def blocker(): + return "blocker function" + app.start() From 42263d021934d882da3fec3366fb0749bbba56d3 Mon Sep 17 00:00:00 2001 From: Sanskar Jethi Date: Mon, 21 Jun 2021 20:44:10 +0530 Subject: [PATCH 2/2] Fix minor bug with the py_obj keys --- src/process.rs | 2 +- test.py | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/src/process.rs b/src/process.rs index c3e2c8d8c..32fc0792a 100644 --- a/src/process.rs +++ b/src/process.rs @@ -13,7 +13,7 @@ pub async fn handle_message(process_object: Py, mut stream: TcpStream) { let function: PyFunction = Python::with_gil(|py| { let process_object_wrapper: &PyAny = process_object.as_ref(py); let py_dict = process_object_wrapper.downcast::().unwrap(); - let is_async: bool = py_dict.get_item("is_coroutine").unwrap().extract().unwrap(); + let is_async: bool = py_dict.get_item("is_async").unwrap().extract().unwrap(); let handler: &PyAny = py_dict.get_item("handler").unwrap(); if is_async { let coro = handler.call0().unwrap(); diff --git a/test.py b/test.py index 850b8237f..0c8569d24 100644 --- a/test.py +++ b/test.py @@ -15,6 +15,8 @@ async def sleeper(): @app.get("/blocker") def blocker(): + import time + time.sleep(100) return "blocker function" app.start()