Skip to content

Commit

Permalink
feat: import plugin using import_module (#266)
Browse files Browse the repository at this point in the history
* feat: import plugin using import_module

Signed-off-by: Keming <kemingy94@gmail.com>

* exit if err

Signed-off-by: Keming <kemingy94@gmail.com>

* fix test

Signed-off-by: Keming <kemingy94@gmail.com>

Signed-off-by: Keming <kemingy94@gmail.com>
(cherry picked from commit ebcb3d6)
  • Loading branch information
kemingy committed Oct 15, 2022
1 parent 86fd888 commit 104db2f
Show file tree
Hide file tree
Showing 11 changed files with 60 additions and 39 deletions.
11 changes: 10 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,20 @@ SOURCE_FILES=spectree tests examples setup.py
install:
pip install -e .[email,flask,falcon,starlette,dev]

test:
import_test:
pip install -e .[email]
for module in flask falcon starlette; do \
pip install -U $$module; \
bash -c "python tests/import_module/test_$${module}_plugin.py" || exit 1; \
pip uninstall $$module -y; \
done

test: import_test
pip install -U -e .[email,flask,falcon,starlette]
pytest tests -vv -rs
pip uninstall falcon email-validator -y && pip install falcon==2.0.0
pytest tests -vv -rs

doc:
cd docs && make html

Expand Down
19 changes: 11 additions & 8 deletions spectree/plugins/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
from collections import namedtuple

from .base import BasePlugin
from .falcon_plugin import FalconAsgiPlugin, FalconPlugin
from .flask_plugin import FlaskPlugin
from .starlette_plugin import StarlettePlugin

__all__ = ["BasePlugin"]

Plugin = namedtuple("Plugin", ("name", "package", "class_name"))

PLUGINS = {
"base": BasePlugin,
"flask": FlaskPlugin,
"falcon": FalconPlugin,
"falcon-asgi": FalconAsgiPlugin,
"starlette": StarlettePlugin,
"base": Plugin(".base", __name__, "BasePlugin"),
"flask": Plugin(".flask_plugin", __name__, "FlaskPlugin"),
"falcon": Plugin(".falcon_plugin", __name__, "FalconPlugin"),
"falcon-asgi": Plugin(".falcon_plugin", __name__, "FalconAsgiPlugin"),
"starlette": Plugin(".starlette_plugin", __name__, "StarlettePlugin"),
}
22 changes: 12 additions & 10 deletions spectree/plugins/falcon_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from functools import partial
from typing import Any, Callable, Dict, List, Mapping, Optional, get_type_hints

from falcon import HTTP_400, HTTP_415, HTTPError
from falcon.routing.compiled import _FIELD_PATTERN as FALCON_FIELD_PATTERN
from pydantic import ValidationError

from .._types import ModelType
Expand Down Expand Up @@ -55,13 +57,7 @@ class FalconPlugin(BasePlugin):
def __init__(self, spectree):
super().__init__(spectree)

from falcon import HTTP_400, HTTP_415, HTTPError
from falcon.routing.compiled import _FIELD_PATTERN

# used to detect falcon 3.0 request media parse error
self.FALCON_HTTP_ERROR = HTTPError
self.FALCON_MEDIA_ERROR_CODE = (HTTP_400, HTTP_415)
self.FIELD_PATTERN = _FIELD_PATTERN
# NOTE from `falcon.routing.compiled.CompiledRouterNode`
self.ESCAPE = r"[\.\(\)\[\]\?\$\*\+\^\|]"
self.ESCAPE_TO = r"\\\g<0>"
Expand Down Expand Up @@ -114,13 +110,13 @@ def parse_func(self, route: Any) -> Dict[str, Any]:
def parse_path(self, route, path_parameter_descriptions):
subs, parameters = [], []
for segment in route.uri_template.strip("/").split("/"):
matches = self.FIELD_PATTERN.finditer(segment)
matches = FALCON_FIELD_PATTERN.finditer(segment)
if not matches:
subs.append(segment)
continue

escaped = re.sub(self.ESCAPE, self.ESCAPE_TO, segment)
subs.append(self.FIELD_PATTERN.sub(self.EXTRACT, escaped))
subs.append(FALCON_FIELD_PATTERN.sub(self.EXTRACT, escaped))

for field in matches:
variable, converter, argstr = [
Expand Down Expand Up @@ -184,11 +180,17 @@ def request_validation(self, req, query, json, headers, cookies):
req.context.cookies = cookies.parse_obj(req.cookies)
try:
media = req.media
except self.FALCON_HTTP_ERROR as err:
except HTTPError as err:
if err.status not in self.FALCON_MEDIA_ERROR_CODE:
raise
media = None
if json:
try:
media = req.media
except HTTPError as err:
if err.status not in self.FALCON_MEDIA_ERROR_CODE:
raise
media = None
req.context.json = json.parse_obj(media)

def validate(
Expand Down Expand Up @@ -267,7 +269,7 @@ async def request_validation(self, req, query, json, headers, cookies):
if json:
try:
media = await req.get_media()
except self.FALCON_HTTP_ERROR as err:
except HTTPError as err:
if err.status not in self.FALCON_MEDIA_ERROR_CODE:
raise
media = None
Expand Down
12 changes: 2 additions & 10 deletions spectree/plugins/flask_plugin.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from typing import Any, Callable, Optional, get_type_hints

from flask import Blueprint, abort, current_app, jsonify, make_response, request
from pydantic import BaseModel, ValidationError
from werkzeug.routing import parse_converter_args

from .._types import ModelType
from ..response import Response
Expand All @@ -13,8 +15,6 @@ class FlaskPlugin(BasePlugin):
FORM_MIMETYPE = ("application/x-www-form-urlencoded", "multipart/form-data")

def find_routes(self):
from flask import current_app

for rule in current_app.url_map.iter_rules():
if any(
str(rule).startswith(path)
Expand All @@ -40,8 +40,6 @@ def bypass(self, func, method):
return method in ["HEAD", "OPTIONS"]

def parse_func(self, route: Any):
from flask import current_app

if self.blueprint_state:
func = self.blueprint_state.app.view_functions[route.endpoint]
else:
Expand All @@ -59,8 +57,6 @@ def parse_func(self, route: Any):
yield method, func

def parse_path(self, route, path_parameter_descriptions):
from werkzeug.routing import parse_converter_args

subs = []
parameters = []

Expand Down Expand Up @@ -178,8 +174,6 @@ def validate(
*args: Any,
**kwargs: Any,
):
from flask import abort, jsonify, make_response, request

response, req_validation_error, resp_validation_error = None, None, None
try:
self.request_validation(request, query, json, headers, cookies)
Expand Down Expand Up @@ -235,8 +229,6 @@ def validate(
return response

def register_route(self, app):
from flask import Blueprint, jsonify

app.add_url_rule(
rule=self.config.spec_url,
endpoint=f"openapi_{self.config.path}",
Expand Down
11 changes: 4 additions & 7 deletions spectree/plugins/starlette_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@
from typing import Any, Callable, Optional, get_type_hints

from pydantic import ValidationError
from starlette.convertors import CONVERTOR_TYPES
from starlette.requests import Request
from starlette.responses import HTMLResponse, JSONResponse
from starlette.routing import compile_path

from .._types import ModelType
from ..response import Response
Expand All @@ -15,8 +19,6 @@


def PydanticResponse(content):
from starlette.responses import JSONResponse

class _PydanticResponse(JSONResponse):
def render(self, content) -> bytes:
self._model_class = content.__class__
Expand All @@ -30,13 +32,11 @@ class StarlettePlugin(BasePlugin):

def __init__(self, spectree):
super().__init__(spectree)
from starlette.convertors import CONVERTOR_TYPES

self.conv2type = {conv: typ for typ, conv in CONVERTOR_TYPES.items()}

def register_route(self, app):
self.app = app
from starlette.responses import HTMLResponse, JSONResponse

self.app.add_route(
self.config.spec_url,
Expand Down Expand Up @@ -79,8 +79,6 @@ async def validate(
*args: Any,
**kwargs: Any,
):
from starlette.requests import Request
from starlette.responses import JSONResponse

if isinstance(args[0], Request):
instance, request = None, args[0]
Expand Down Expand Up @@ -186,7 +184,6 @@ def parse_func(self, route):
yield method, route.func

def parse_path(self, route, path_parameter_descriptions):
from starlette.routing import compile_path

_, path, variables = compile_path(route.path)
parameters = []
Expand Down
9 changes: 7 additions & 2 deletions spectree/spec.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from collections import defaultdict
from copy import deepcopy
from functools import wraps
from importlib import import_module
from typing import (
Any,
Callable,
Expand Down Expand Up @@ -67,8 +68,12 @@ def __init__(
self.validation_error_status = validation_error_status
self.config: Configuration = Configuration.parse_obj(kwargs)
self.backend_name = backend_name
self.backend = backend(self) if backend else PLUGINS[backend_name](self)
# init
if backend:
self.backend = backend(self)
else:
plugin = PLUGINS[backend_name]
module = import_module(plugin.name, plugin.package)
self.backend = getattr(module, plugin.class_name)(self)
self.models: Dict[str, Any] = {}
if app:
self.register(app)
Expand Down
Empty file added tests/import_module/__init__.py
Empty file.
5 changes: 5 additions & 0 deletions tests/import_module/test_falcon_plugin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from spectree import SpecTree

SpecTree("falcon")
SpecTree("falcon-asgi")
print("=> passed falcon plugin import test")
4 changes: 4 additions & 0 deletions tests/import_module/test_flask_plugin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from spectree import SpecTree

SpecTree("flask")
print("=> passed flask plugin import test")
4 changes: 4 additions & 0 deletions tests/import_module/test_starlette_plugin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from spectree import SpecTree

SpecTree("starlette")
print("=> passed starlette plugin import test")
2 changes: 1 addition & 1 deletion tests/test_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from spectree import Response
from spectree.config import Configuration
from spectree.models import Server, ValidationError
from spectree.plugins import FlaskPlugin
from spectree.plugins.flask_plugin import FlaskPlugin
from spectree.spec import SpecTree

from .common import get_paths
Expand Down

0 comments on commit 104db2f

Please sign in to comment.