Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Falcon ASGI support #46

Merged
merged 14 commits into from
Jun 30, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ Yet another library to generate OpenAPI document and validate request & response
* Validate query, JSON data, response data with [pydantic](https://github.com/samuelcolvin/pydantic/) :wink:
* Current support:
* Flask [demo](#flask)
* Falcon [demo](#falcon)
* Falcon [demo](#falcon) (including ASGI under Falcon 3+)
* Starlette [demo](#starlette)

## Quick Start
Expand Down
2 changes: 1 addition & 1 deletion docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ Features
- Current support:

- Flask
- Falcon
- Falcon (including Falcon ASGI)
- Starlette

Quick Start
Expand Down
3 changes: 2 additions & 1 deletion spectree/plugins/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from .base import BasePlugin
from .falcon_plugin import FalconPlugin
from .falcon_plugin import FalconAsgiPlugin, FalconPlugin
from .flask_plugin import FlaskPlugin
from .starlette_plugin import StarlettePlugin

PLUGINS = {
"base": BasePlugin,
"flask": FlaskPlugin,
"falcon": FalconPlugin,
"falcon-asgi": FalconAsgiPlugin,
"starlette": StarlettePlugin,
}
92 changes: 86 additions & 6 deletions spectree/plugins/falcon_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,33 @@ def __init__(self, html, spec_url):

def on_get(self, req, resp):
resp.content_type = "text/html"
resp.body = self.page
# resp.body is deprecated in Falcon 3
if hasattr(resp, "text"):
resp.text = self.page
else:
resp.body = self.page


DOC_CLASS = [x.__name__ for x in (DocPage, OpenAPI)]
class OpenAPIAsgi(OpenAPI):
async def on_get(self, req, resp):
super().on_get(req, resp)


class DocPageAsgi(DocPage):
async def on_get(self, req, resp):
super().on_get(req, resp)


DOC_CLASS = [x.__name__ for x in (DocPage, OpenAPI, DocPageAsgi, OpenAPIAsgi)]

HTTP_422 = "422 Unprocessable Entity"
HTTP_500 = "500 Internal Service Response Validation Error"


class FalconPlugin(BasePlugin):
OPEN_API_ROUTE_CLASS = OpenAPI
DOC_PAGE_ROUTE_CLASS = DocPage

def __init__(self, spectree):
super().__init__(spectree)
try:
Expand Down Expand Up @@ -56,11 +76,13 @@ def __init__(self, spectree):

def register_route(self, app):
self.app = app
self.app.add_route(self.config.spec_url, OpenAPI(self.spectree.spec))
self.app.add_route(
self.config.spec_url, self.OPEN_API_ROUTE_CLASS(self.spectree.spec)
)
for ui in PAGES:
self.app.add_route(
f"/{self.config.PATH}/{ui}",
DocPage(PAGES[ui], self.config.spec_url),
self.DOC_PAGE_ROUTE_CLASS(PAGES[ui], self.config.spec_url),
)

def find_routes(self):
Expand Down Expand Up @@ -168,7 +190,7 @@ def validate(

except ValidationError as err:
req_validation_error = err
_resp.status = "422 Unprocessable Entity"
_resp.status = HTTP_422
_resp.media = err.errors()

before(_req, _resp, req_validation_error, _self)
Expand All @@ -183,7 +205,7 @@ def validate(
model.parse_obj(_resp.media)
except ValidationError as err:
resp_validation_error = err
_resp.status = "500 Internal Service Response Validation Error"
_resp.status = HTTP_500
_resp.media = err.errors()

after(_req, _resp, resp_validation_error, _self)
Expand All @@ -192,3 +214,61 @@ def bypass(self, func, method):
if isinstance(func, partial):
return True
return inspect.isfunction(func)


class FalconAsgiPlugin(FalconPlugin):
"""Light wrapper around default Falcon plug-in to support Falcon 3.0 ASGI apps"""

ASYNC = True
OPEN_API_ROUTE_CLASS = OpenAPIAsgi
DOC_PAGE_ROUTE_CLASS = DocPageAsgi

async def request_validation(self, req, query, json, headers, cookies):
if query:
req.context.query = query.parse_obj(req.params)
if headers:
req.context.headers = headers.parse_obj(req.headers)
if cookies:
req.context.cookies = cookies.parse_obj(req.cookies)
if json:
try:
media = await req.get_media()
except self.UnsupportedMediaType:
media = None
req.context.json = json.parse_obj(media)

async def validate(
self, func, query, json, headers, cookies, resp, before, after, *args, **kwargs
):
# falcon endpoint method arguments: (self, req, resp)
_self, _req, _resp = args[:3]
req_validation_error, resp_validation_error = None, None
try:
await self.request_validation(_req, query, json, headers, cookies)
if self.config.ANNOTATIONS:
for name in ("query", "json", "headers", "cookies"):
if func.__annotations__.get(name):
kwargs[name] = getattr(_req.context, name)

except ValidationError as err:
req_validation_error = err
_resp.status = HTTP_422
_resp.media = err.errors()

before(_req, _resp, req_validation_error, _self)
if req_validation_error:
return

await func(*args, **kwargs)

if resp and resp.has_model():
model = resp.find_model(_resp.status[:3])
if model:
try:
model.parse_obj(_resp.media)
except ValidationError as err:
resp_validation_error = err
_resp.status = HTTP_500
_resp.media = err.errors()

after(_req, _resp, resp_validation_error, _self)
2 changes: 1 addition & 1 deletion spectree/spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ class SpecTree:
"""
Interface

:param str backend_name: choose from ('flask', 'falcon', 'starlette')
:param str backend_name: choose from ('flask', 'falcon', 'falcon-asgi', 'starlette')
:param backend: a backend that inherit `SpecTree.plugins.base.BasePlugin`
:param app: backend framework application instance (can be registered later)
:param before: a callback function of the form
Expand Down
156 changes: 156 additions & 0 deletions tests/test_plugin_falcon_asgi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
from random import randint

import pytest
from falcon import testing

from spectree import Response, SpecTree

from .common import JSON, Cookies, Headers, Query, Resp, StrDict, api_tag

App = pytest.importorskip("falcon.asgi.App", reason="Missing required Falcon 3.0")


def before_handler(req, resp, err, instance):
if err:
resp.set_header("X-Error", "Validation Error")


def after_handler(req, resp, err, instance):
print(instance.name)
resp.set_header("X-Name", instance.name)
print(resp.get_header("X-Name"))


api = SpecTree(
"falcon-asgi", before=before_handler, after=after_handler, annotations=True
)


class Ping:
name = "health check"

@api.validate(headers=Headers, tags=["test", "health"])
async def on_get(self, req, resp):
"""summary
description
"""
resp.media = {"msg": "pong"}


class UserScore:
name = "sorted random score"

def extra_method(self):
pass

@api.validate(resp=Response(HTTP_200=StrDict))
async def on_get(self, req, resp, name):
self.extra_method()
resp.media = {"name": name}

@api.validate(
query=Query,
json=JSON,
cookies=Cookies,
resp=Response(HTTP_200=Resp, HTTP_401=None),
tags=[api_tag, "test"],
)
async def on_post(self, req, resp, name):
score = [randint(0, req.context.json.limit) for _ in range(5)]
score.sort(reverse=req.context.query.order)
assert req.context.cookies.pub == "abcdefg"
assert req.cookies["pub"] == "abcdefg"
resp.media = {"name": req.context.json.name, "score": score}


class UserScoreAnnotated:
name = "sorted random score"

def extra_method(self):
pass

@api.validate(resp=Response(HTTP_200=StrDict))
async def on_get(self, req, resp, name):
self.extra_method()
resp.media = {"name": name}

@api.validate(
resp=Response(HTTP_200=Resp, HTTP_401=None),
tags=[api_tag, "test"],
)
async def on_post(
self, req, resp, name, query: Query, json: JSON, cookies: Cookies
):
score = [randint(0, req.context.json.limit) for _ in range(5)]
score.sort(reverse=req.context.query.order)
assert req.context.cookies.pub == "abcdefg"
assert req.cookies["pub"] == "abcdefg"
resp.media = {"name": req.context.json.name, "score": score}


app = App()
app.add_route("/ping", Ping())
app.add_route("/api/user/{name}", UserScore())
app.add_route("/api/user_annotated/{name}", UserScoreAnnotated())
api.register(app)


@pytest.fixture
def client():
return testing.TestClient(app)


def test_falcon_validate(client):
resp = client.simulate_request(
"GET", "/ping", headers={"Content-Type": "text/plain"}
)
assert resp.status_code == 422
assert resp.headers.get("X-Error") == "Validation Error", resp.headers

resp = client.simulate_request(
"GET", "/ping", headers={"lang": "en-US", "Content-Type": "text/plain"}
)
assert resp.json == {"msg": "pong"}
assert resp.headers.get("X-Error") is None
assert resp.headers.get("X-Name") == "health check"

resp = client.simulate_request(
"GET", "/api/user/falcon", headers={"Content-Type": "text/plain"}
)
assert resp.json == {"name": "falcon"}

resp = client.simulate_request("POST", "/api/user/falcon")
assert resp.status_code == 422
assert resp.headers.get("X-Error") == "Validation Error"
assert resp.headers.get("X-Name") is None

resp = client.simulate_request(
"POST",
"/api/user/falcon?order=1",
json=dict(name="falcon", limit=10),
headers={"Cookie": "pub=abcdefg"},
)
assert resp.json["name"] == "falcon"
assert resp.json["score"] == sorted(resp.json["score"], reverse=True)
assert resp.headers.get("X-Name") == "sorted random score"

resp = client.simulate_request(
"POST",
"/api/user/falcon?order=0",
json=dict(name="falcon", limit=10),
headers={"Cookie": "pub=abcdefg"},
)
assert resp.json["name"] == "falcon"
assert resp.json["score"] == sorted(resp.json["score"], reverse=False)
assert resp.headers.get("X-Name") == "sorted random score"


def test_falcon_doc(client):
resp = client.simulate_get("/apidoc/openapi.json")
assert resp.json == api.spec

resp = client.simulate_get("/apidoc/redoc")
assert resp.status_code == 200

resp = client.simulate_get("/apidoc/swagger")
assert resp.status_code == 200