Skip to content

Commit

Permalink
Fixed item number 1
Browse files Browse the repository at this point in the history
  • Loading branch information
kliwongan committed Jan 19, 2023
1 parent f16e85c commit 6f5b5ae
Show file tree
Hide file tree
Showing 10 changed files with 143 additions and 72 deletions.
11 changes: 0 additions & 11 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

26 changes: 11 additions & 15 deletions integration_tests/base_routes.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
from robyn import Robyn, serve_file, jsonify, WS, serve_html
from robyn.robyn import Response
from dataclasses import dataclass
from typing import Optional

from robyn.templating import JinjaTemplate

Expand All @@ -10,7 +8,7 @@
import os
import pathlib
import logging
from query_types import TestQueryType
from conftest import NestedCls, TestQueryType, TestForwardRef, TestCtor

app = Robyn(__file__)
websocket = WS(app, "/web_socket")
Expand Down Expand Up @@ -279,16 +277,6 @@ async def file_download_async():
file_path = os.path.join(current_file_path, "downloads", "test.txt")
return serve_file(file_path)

@dataclass
class Test():
f: int
g: int

@dataclass
class NestedCls():
f: Test
special: Optional[str] = "Nice"

@app.post("/query_validation", validate=True)
async def test_validation(a: int, b: str):
return jsonify({'a': a, 'b': b})
Expand All @@ -297,10 +285,18 @@ async def test_validation(a: int, b: str):
async def test_validation_complex(a: int, b: str, c: NestedCls):
return jsonify({'a': a, 'b': b, 'c': c})

@app.post("/query_validation_forwardref", validate=True)
async def test_validation_forwardref(a: int, b: str, c: TestQueryType):
@app.post("/query_validation_import", validate=True)
async def test_validation_import(a: int, b: str, c: TestQueryType):
return jsonify({'a': a, 'b': b, 'c': c})

@app.post("/query_validation_forwardref", validate=True)
async def test_validation_forwardref(a: TestForwardRef):
return jsonify({'a': a})

@app.post("/query_validation_ctor", validate=True)
async def test_validation_ctor(a: TestCtor):
return jsonify({'a': a})

if __name__ == "__main__":
app.add_request_header("server", "robyn")
current_file_path = pathlib.Path(__file__).parent.resolve()
Expand Down
34 changes: 33 additions & 1 deletion integration_tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import signal
import sys
from typing import List
from typing import List, Optional
import pytest
import subprocess
import pathlib
Expand Down Expand Up @@ -89,3 +89,35 @@ def test_session():
time.sleep(5)
yield
kill_process(process)

# Classes for testing
class Test():
f: int
g: int

class NestedCls():
f: Test
special: Optional[str] = "Nice"

class TestCtor:
a: int
b: str

def __init__(self, a: int, b: str = "Nice"):
self.a = a
self.b = b

class Nested:
c: int
d: str

class TestQueryType:
a: int
b: str

class TestForwardRef:
a: 'Ref'

class Ref:
a: int
b: str
11 changes: 0 additions & 11 deletions integration_tests/query_types.py

This file was deleted.

59 changes: 43 additions & 16 deletions integration_tests/test_query_validations.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,21 +46,42 @@
}

CORRECT_BODY_COMPLEX_FORWARDREF = {
"a": 5,
"b": "hello",
"c": {
"a": 5,
"b": {"c": 7, "d": "Hello"}
}
"a": {
"a": {"a": 5, "b": "Hello"}
}
}

INCORRECT_BODY_COMPLEX_FORWARDREF = {
"a": 5,
"b": "hello",
"c": {
"a": {
"a": {"a": 5, "b": 7}
}
}

CORRECT_BODY_COMPLEX_CTOR = {
"a": {
"a": 5,
"b": {"c": 7, "d": 9}
}
"b": "Nicer"
}
}

INCORRECT_BODY_COMPLEX_CTOR = {
"a": {
"a": 5,
"b": 6
}
}

CORRECT_BODY_COMPLEX_CTOR_NODEFAULT = {
"a": {
"a": 5
}
}

CORRECT_BODY_COMPLEX_CTOR_NODEFAULT_RESULT = {
"a": {
"a": 5,
"b": "Nice"
}
}

def test_post_simple_correct(session):
Expand All @@ -86,11 +107,17 @@ def test_post_complex_default_correct(session):
assert res.status_code == 200
assert res.json() == CORRECT_BODY_COMPLEX_NODEFAULT_RESULT

def test_post_complex_forwardref_correct(session):
res = requests.post(f"{BASE_URL}/query_validation_forwardref", json=CORRECT_BODY_COMPLEX_FORWARDREF)
def test_post_complex_ctor_correct(session):
res = requests.post(f"{BASE_URL}/query_validation_ctor", json=CORRECT_BODY_COMPLEX_CTOR)
assert res.status_code == 200
assert res.json() == CORRECT_BODY_COMPLEX_FORWARDREF
assert res.json() == CORRECT_BODY_COMPLEX_CTOR

def test_post_complex_forwardref_incorrect(session):
res = requests.post(f"{BASE_URL}/query_validation_forwardref", json=INCORRECT_BODY_COMPLEX_FORWARDREF)
def test_post_complex_ctor_incorrect(session):
res = requests.post(f"{BASE_URL}/query_validation_ctor", json=INCORRECT_BODY_COMPLEX_CTOR)
assert res.status_code == 500

def test_post_complex_ctor_nodefault(session):
res = requests.post(f"{BASE_URL}/query_validation_ctor", json=CORRECT_BODY_COMPLEX_CTOR_NODEFAULT)
assert res.status_code == 200
assert res.json() == CORRECT_BODY_COMPLEX_CTOR_NODEFAULT_RESULT

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
name = "robyn"
version = "0.22.0"
version = "0.21.0"
description = "A web server that is fast!"
authors = ["Sanskar Jethi <sansyrox@gmail.com>"]

Expand Down
2 changes: 1 addition & 1 deletion robyn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@
from robyn.robyn import FunctionInfo, SocketHeld
from robyn.router import MiddlewareRouter, Router, WebSocketRouter
from robyn.types import Directory, Header
from robyn.dependencies import check_params_dependencies
from robyn.ws import WS
from robyn.dependencies import get_signature, check_params_dependencies
from robyn.env_populator import load_vars

logger = logging.getLogger(__name__)
Expand Down
64 changes: 50 additions & 14 deletions robyn/dependencies.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@

from typing import Any, Callable, List, Dict, Union, ForwardRef, cast
from inspect import signature, Signature, Parameter, getmembers
from typing import Any, Callable, List, Dict, Union, ForwardRef, cast, get_type_hints
from inspect import signature, Signature, Parameter, getmembers, isfunction
import msgspec
import sys

Expand All @@ -10,8 +10,11 @@
class BaseValidation(msgspec.Struct, forbid_unknown_fields=True):
"""
Base class for validation, used to explicitly forbid unknown fields
The response body MUST be exactly as the params specify
The response body MUST be exactly as the params specify,
hence forbid_unknown_fields = True
"""
# Create a metaclass so we can "vary" the arguments sent into msgspec.Struct
# to increase user customizability?
pass

def decode_bytearray(req):
Expand Down Expand Up @@ -39,6 +42,14 @@ def get_typed_annotation(annotation: Any, globalns: Dict[str, Any]) -> Any:
annotation = evaluate_forwardref(annotation, globalns, globalns)
return annotation

def model_to_dict(model: object) -> dict:
ret_val = dict()
for key, value in getmembers(model):
if not key.startswith('_'):
ret_val[key] = value

return ret_val

def check_custom_class(annotation) -> bool:
"""
Checks if the given class/annotation is a user-defined class
Expand All @@ -49,9 +60,34 @@ def check_custom_class(annotation) -> bool:
and not issubclass(annotation, sequence_types + (dict, )) \
and not issubclass(annotation, defaults)

def get_class_signature(call: Callable[..., Any]):
"""
Gets the signature of a class that has no constructor
"""
# Get the type annotations of the class
annotations = get_type_hints(call)
actual_params = annotations.keys()
globalns = getattr(call, "__globals__", {})

# Get object representation as dictionary (to get the default values)
class_dict = model_to_dict(call)

# Generate the signature
params = [
Parameter(
name = key,
default = class_dict.get(key, Parameter.empty),
kind = Parameter.POSITIONAL_OR_KEYWORD,
annotation = get_typed_annotation(annotations[key], globalns),
) for key in actual_params
]

return Signature(params)


def get_signature(call: Callable[..., Any]) -> Signature:
"""
Gets the dependencies of the function wrapped by the routing decorators
Gets the function signature wrapped by the routing decorators
"""
# Credit to FastAPI
sign = signature(call)
Expand All @@ -74,7 +110,15 @@ def create_model(call: Callable[..., Any]) -> msgspec.Struct:
"""
Creates a validation model for a given handler function
"""
sign = get_signature(call)
sign: Any = Any
# If constructor does not exist, and the class is a non-default class
# then use get_class_signature. Otherwise, if the constructor exists, we
# base the model off of the constructor's type annotations
if check_custom_class(type(call)) and not isfunction(call.__init__):
sign = get_class_signature(call)
else:
sign = get_signature(call)

cstruct = list()
for param in sign.parameters.values():
cust_model = None
Expand All @@ -85,14 +129,6 @@ def create_model(call: Callable[..., Any]) -> msgspec.Struct:

return msgspec.defstruct(call.__name__, fields=cstruct, bases=(BaseValidation,))

def model_to_dict(model: BaseValidation) -> dict:
ret_val = dict()
for key, value in getmembers(model):
if not key.startswith('_'):
ret_val[key] = value

return ret_val

def check_params_dependencies(call: Callable[..., Any], request: Dict[Any, Any]):
"""
Checks if the params match the dependencies of the route
Expand All @@ -106,7 +142,7 @@ def check_params_dependencies(call: Callable[..., Any], request: Dict[Any, Any])
func_input = msgspec.json.decode(bytes(request['body']), type=model)
except msgspec.ValidationError as exc:
raise msgspec.ValidationError from exc

return model_to_dict(func_input)


Expand Down
2 changes: 1 addition & 1 deletion robyn/responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def serve_file(file_path: str) -> Dict[str, Any]:
}


def jsonify(input_dict: dict) -> bytes:
def jsonify(input_dict: dict) -> str:
"""
This function serializes input dict to a json string
Expand Down
4 changes: 3 additions & 1 deletion src/executors/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,18 @@ fn get_function_output<'a>(
let request_hashmap = request.to_hashmap(py).unwrap();

if function.validate_params {
println!("Validating params");
// Perform query param validation
let request_hashmap = request.to_hashmap(py).unwrap();
let handler = function.handler.as_ref(py);
let robyn = py.import("robyn").unwrap();
let check_dependencies = robyn.call_method1("check_params_dependencies", (handler, request_hashmap,));

println!("Got result of check dependencies {:?}", check_dependencies);
// Match error so that if the dependencies don't match
// we raise an internal server error
match check_dependencies {
Ok(r) => {
println!("Okay validation response from check dependencies");
let kwargs: &PyDict = r.extract().unwrap();
let response = handler.call((), Some(kwargs)).unwrap();
Ok(response)
Expand Down

0 comments on commit 6f5b5ae

Please sign in to comment.