Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: expose request/connection info #441

Merged
merged 7 commits into from
Mar 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions integration_tests/base_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,31 @@ async def async_param(request):
return id


@app.get("/sync/extra/*extra")
def sync_param_extra(request):
extra = request["params"]["extra"]
return extra


@app.get("/async/extra/*extra")
async def async_param_extra(request):
extra = request["params"]["extra"]
return extra


# Request Info


@app.get("/sync/http/param")
def sync_http_param(request):
return jsonify({"url": request["url"], "method": request["method"]})


@app.get("/async/http/param")
async def async_http_param(request):
return jsonify({"url": request["url"], "method": request["method"]})


# HTML serving


Expand Down
35 changes: 35 additions & 0 deletions integration_tests/test_basic_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,3 +62,38 @@ def test_json_get(route: str, expected_json: dict, session):
for key in expected_json.keys():
assert key in res.json()
assert res.json()[key] == expected_json[key]


@pytest.mark.benchmark
@pytest.mark.parametrize(
"route, expected_json",
[
(
"/sync/http/param",
{
"method": "GET",
"url": {
"host": "127.0.0.1:8080",
"path": "/sync/http/param",
"scheme": "http",
},
},
),
(
"/async/http/param",
{
"method": "GET",
"url": {
"host": "127.0.0.1:8080",
"path": "/async/http/param",
"scheme": "http",
},
},
),
],
)
def test_http_request_info_get(route: str, expected_json: dict, session):
res = get(route)
for key in expected_json.keys():
assert key in res.json()
assert res.json()[key] == expected_json[key]
9 changes: 9 additions & 0 deletions integration_tests/test_get_requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,15 @@ def test_param(function_type: str, session):
assert r.text == "12345"


@pytest.mark.benchmark
@pytest.mark.parametrize("function_type", ["sync", "async"])
def test_param_suffix(function_type: str, session):
r = get(f"/{function_type}/extra/foo/1/baz")
assert r.text == "foo/1/baz"
r = get(f"/{function_type}/extra/foo/bar/baz")
assert r.text == "foo/bar/baz"


@pytest.mark.benchmark
@pytest.mark.parametrize("function_type", ["sync", "async"])
def test_serve_html(function_type: str, session):
Expand Down
39 changes: 37 additions & 2 deletions src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ use actix_web::{http::Method, HttpRequest};
use anyhow::Result;
use dashmap::DashMap;
use pyo3::exceptions::PyValueError;
use pyo3::types::{PyBytes, PyString};
use pyo3::{exceptions, prelude::*};
use pyo3::types::{PyBytes, PyDict, PyString};
use pyo3::{exceptions, intern, prelude::*};

use crate::io_helpers::read_file;

Expand Down Expand Up @@ -108,13 +108,39 @@ impl FunctionInfo {
}
}

#[derive(Default)]
pub struct Url {
pub scheme: String,
pub host: String,
pub path: String,
}

impl Url {
fn new(scheme: &str, host: &str, path: &str) -> Self {
Self {
scheme: scheme.to_string(),
host: host.to_string(),
path: path.to_string(),
}
}

pub fn to_object(&self, py: Python<'_>) -> PyResult<PyObject> {
let dict = PyDict::new(py);
dict.set_item(intern!(py, "scheme"), self.scheme.as_str())?;
dict.set_item(intern!(py, "host"), self.host.as_str())?;
dict.set_item(intern!(py, "path"), self.path.as_str())?;
Ok(dict.into_py(py))
}
}

#[derive(Default)]
pub struct Request {
pub queries: HashMap<String, String>,
pub headers: HashMap<String, String>,
pub method: Method,
pub params: HashMap<String, String>,
pub body: Bytes,
pub url: Url,
}

impl Request {
Expand All @@ -139,15 +165,24 @@ impl Request {
method: req.method().clone(),
params: HashMap::new(),
body,
url: Url::new(
req.connection_info().scheme(),
req.connection_info().host(),
req.path(),
),
}
}

pub fn to_hashmap(&self, py: Python<'_>) -> Result<HashMap<&str, Py<PyAny>>> {
let mut result = HashMap::new();

result.insert("method", self.method.as_ref().to_object(py));
result.insert("params", self.params.to_object(py));
result.insert("queries", self.queries.to_object(py));
result.insert("headers", self.headers.to_object(py));
result.insert("body", self.body.to_object(py));
result.insert("url", self.url.to_object(py)?);

Ok(result)
}
}
Expand Down