Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: import plugin using import_module #266

Merged
merged 3 commits into from
Oct 15, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,18 @@ SOURCE_FILES=spectree tests examples setup.py
install:
pip install -e .[email,flask,falcon,starlette,dev]

test:
import_test:
pip install -e .[email]
for module in flask falcon starlette; do \
pip install -U $$module; \
bash -c "python tests/import_module/test_$${module}_plugin.py" || exit 1; \
pip uninstall $$module -y; \
done

test: import_test
pip install -U -e .[email,flask,falcon,starlette]
pytest tests -vv -rs

doc:
cd docs && make html

Expand Down
19 changes: 11 additions & 8 deletions spectree/plugins/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
from collections import namedtuple

from .base import BasePlugin
from .falcon_plugin import FalconAsgiPlugin, FalconPlugin
from .flask_plugin import FlaskPlugin
from .starlette_plugin import StarlettePlugin

__all__ = ["BasePlugin"]

Plugin = namedtuple("Plugin", ("name", "package", "class_name"))

PLUGINS = {
"base": BasePlugin,
"flask": FlaskPlugin,
"falcon": FalconPlugin,
"falcon-asgi": FalconAsgiPlugin,
"starlette": StarlettePlugin,
"base": Plugin(".base", __name__, "BasePlugin"),
"flask": Plugin(".flask_plugin", __name__, "FlaskPlugin"),
"falcon": Plugin(".falcon_plugin", __name__, "FalconPlugin"),
"falcon-asgi": Plugin(".falcon_plugin", __name__, "FalconAsgiPlugin"),
"starlette": Plugin(".starlette_plugin", __name__, "StarlettePlugin"),
}
17 changes: 7 additions & 10 deletions spectree/plugins/falcon_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from functools import partial
from typing import Any, Callable, Dict, List, Mapping, Optional, get_type_hints

from falcon import HTTP_400, HTTP_415, HTTPError
from falcon.routing.compiled import _FIELD_PATTERN as FALCON_FIELD_PATTERN
from pydantic import ValidationError

from .._types import ModelType
Expand Down Expand Up @@ -51,12 +53,7 @@ class FalconPlugin(BasePlugin):
def __init__(self, spectree):
super().__init__(spectree)

from falcon import HTTP_400, HTTP_415, HTTPError
from falcon.routing.compiled import _FIELD_PATTERN

self.FALCON_HTTP_ERROR = HTTPError
self.FALCON_MEDIA_ERROR_CODE = (HTTP_400, HTTP_415)
self.FIELD_PATTERN = _FIELD_PATTERN
# NOTE from `falcon.routing.compiled.CompiledRouterNode`
self.ESCAPE = r"[\.\(\)\[\]\?\$\*\+\^\|]"
self.ESCAPE_TO = r"\\\g<0>"
Expand Down Expand Up @@ -109,13 +106,13 @@ def parse_func(self, route: Any) -> Dict[str, Any]:
def parse_path(self, route, path_parameter_descriptions):
subs, parameters = [], []
for segment in route.uri_template.strip("/").split("/"):
matches = self.FIELD_PATTERN.finditer(segment)
matches = FALCON_FIELD_PATTERN.finditer(segment)
if not matches:
subs.append(segment)
continue

escaped = re.sub(self.ESCAPE, self.ESCAPE_TO, segment)
subs.append(self.FIELD_PATTERN.sub(self.EXTRACT, escaped))
subs.append(FALCON_FIELD_PATTERN.sub(self.EXTRACT, escaped))

for field in matches:
variable, converter, argstr = [
Expand Down Expand Up @@ -180,7 +177,7 @@ def request_validation(self, req, query, json, form, headers, cookies):
if json:
try:
media = req.media
except self.FALCON_HTTP_ERROR as err:
except HTTPError as err:
if err.status not in self.FALCON_MEDIA_ERROR_CODE:
raise
media = None
Expand Down Expand Up @@ -268,15 +265,15 @@ async def request_validation(self, req, query, json, form, headers, cookies):
if json:
try:
media = await req.get_media()
except self.FALCON_HTTP_ERROR as err:
except HTTPError as err:
if err.status not in self.FALCON_MEDIA_ERROR_CODE:
raise
media = None
req.context.json = json.parse_obj(media)
if form:
try:
form_data = await req.get_media()
except self.FALCON_HTTP_ERROR as err:
except HTTPError as err:
if err.status not in self.FALCON_MEDIA_ERROR_CODE:
raise
req.context.form = None
Expand Down
12 changes: 2 additions & 10 deletions spectree/plugins/flask_plugin.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from typing import Any, Callable, Mapping, Optional, Tuple, get_type_hints

from flask import Blueprint, abort, current_app, jsonify, make_response, request
from pydantic import BaseModel, ValidationError
from werkzeug.routing import parse_converter_args

from .._types import ModelType
from ..response import Response
Expand All @@ -12,8 +14,6 @@ class FlaskPlugin(BasePlugin):
blueprint_state = None

def find_routes(self):
from flask import current_app

for rule in current_app.url_map.iter_rules():
if any(
str(rule).startswith(path)
Expand All @@ -39,8 +39,6 @@ def bypass(self, func, method):
return method in ["HEAD", "OPTIONS"]

def parse_func(self, route: Any):
from flask import current_app

if self.blueprint_state:
func = self.blueprint_state.app.view_functions[route.endpoint]
else:
Expand All @@ -62,8 +60,6 @@ def parse_path(
route: Optional[Mapping[str, str]],
path_parameter_descriptions: Optional[Mapping[str, str]],
) -> Tuple[str, list]:
from werkzeug.routing import parse_converter_args

subs = []
parameters = []

Expand Down Expand Up @@ -183,8 +179,6 @@ def validate(
*args: Any,
**kwargs: Any,
):
from flask import abort, jsonify, make_response, request

response, req_validation_error, resp_validation_error = None, None, None
try:
self.request_validation(request, query, json, form, headers, cookies)
Expand Down Expand Up @@ -240,8 +234,6 @@ def validate(
return response

def register_route(self, app):
from flask import Blueprint, jsonify

app.add_url_rule(
rule=self.config.spec_url,
endpoint=f"openapi_{self.config.path}",
Expand Down
11 changes: 4 additions & 7 deletions spectree/plugins/starlette_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@
from typing import Any, Callable, Optional, get_type_hints

from pydantic import ValidationError
from starlette.convertors import CONVERTOR_TYPES
from starlette.requests import Request
from starlette.responses import HTMLResponse, JSONResponse
from starlette.routing import compile_path

from .._types import ModelType
from ..response import Response
Expand All @@ -15,8 +19,6 @@


def PydanticResponse(content):
from starlette.responses import JSONResponse

class _PydanticResponse(JSONResponse):
def render(self, content) -> bytes:
self._model_class = content.__class__
Expand All @@ -30,13 +32,11 @@ class StarlettePlugin(BasePlugin):

def __init__(self, spectree):
super().__init__(spectree)
from starlette.convertors import CONVERTOR_TYPES

self.conv2type = {conv: typ for typ, conv in CONVERTOR_TYPES.items()}

def register_route(self, app):
self.app = app
from starlette.responses import HTMLResponse, JSONResponse

self.app.add_route(
self.config.spec_url,
Expand Down Expand Up @@ -86,8 +86,6 @@ async def validate(
*args: Any,
**kwargs: Any,
):
from starlette.requests import Request
from starlette.responses import JSONResponse

if isinstance(args[0], Request):
instance, request = None, args[0]
Expand Down Expand Up @@ -193,7 +191,6 @@ def parse_func(self, route):
yield method, route.func

def parse_path(self, route, path_parameter_descriptions):
from starlette.routing import compile_path

_, path, variables = compile_path(route.path)
parameters = []
Expand Down
9 changes: 7 additions & 2 deletions spectree/spec.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from collections import defaultdict
from copy import deepcopy
from functools import wraps
from importlib import import_module
from typing import (
Any,
Callable,
Expand Down Expand Up @@ -67,8 +68,12 @@ def __init__(
self.validation_error_status = validation_error_status
self.config: Configuration = Configuration.parse_obj(kwargs)
self.backend_name = backend_name
self.backend = backend(self) if backend else PLUGINS[backend_name](self)
# init
if backend:
self.backend = backend(self)
else:
plugin = PLUGINS[backend_name]
module = import_module(plugin.name, plugin.package)
self.backend = getattr(module, plugin.class_name)(self)
self.models: Dict[str, Any] = {}
if app:
self.register(app)
Expand Down
Empty file added tests/import_module/__init__.py
Empty file.
5 changes: 5 additions & 0 deletions tests/import_module/test_falcon_plugin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from spectree import SpecTree

SpecTree("falcon")
SpecTree("falcon-asgi")
print("=> passed falcon plugin import test")
4 changes: 4 additions & 0 deletions tests/import_module/test_flask_plugin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from spectree import SpecTree

SpecTree("flask")
print("=> passed flask plugin import test")
4 changes: 4 additions & 0 deletions tests/import_module/test_starlette_plugin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from spectree import SpecTree

SpecTree("starlette")
print("=> passed starlette plugin import test")
2 changes: 1 addition & 1 deletion tests/test_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from spectree import Response
from spectree.config import Configuration
from spectree.models import Server, ValidationError
from spectree.plugins import FlaskPlugin
from spectree.plugins.flask_plugin import FlaskPlugin
from spectree.spec import SpecTree

from .common import get_paths
Expand Down