From 104db2fa15ad9c06296ad76600c732469a5a2718 Mon Sep 17 00:00:00 2001 From: Keming Date: Sat, 15 Oct 2022 11:15:31 +0800 Subject: [PATCH] feat: import plugin using import_module (#266) * feat: import plugin using import_module Signed-off-by: Keming * exit if err Signed-off-by: Keming * fix test Signed-off-by: Keming Signed-off-by: Keming (cherry picked from commit ebcb3d664374d82e6facaea60ad625f9f9e2c47f) --- Makefile | 11 +++++++++- spectree/plugins/__init__.py | 19 ++++++++++------- spectree/plugins/falcon_plugin.py | 22 +++++++++++--------- spectree/plugins/flask_plugin.py | 12 ++--------- spectree/plugins/starlette_plugin.py | 11 ++++------ spectree/spec.py | 9 ++++++-- tests/import_module/__init__.py | 0 tests/import_module/test_falcon_plugin.py | 5 +++++ tests/import_module/test_flask_plugin.py | 4 ++++ tests/import_module/test_starlette_plugin.py | 4 ++++ tests/test_spec.py | 2 +- 11 files changed, 60 insertions(+), 39 deletions(-) create mode 100644 tests/import_module/__init__.py create mode 100644 tests/import_module/test_falcon_plugin.py create mode 100644 tests/import_module/test_flask_plugin.py create mode 100644 tests/import_module/test_starlette_plugin.py diff --git a/Makefile b/Makefile index a7cc6353..e0a41528 100644 --- a/Makefile +++ b/Makefile @@ -5,11 +5,20 @@ SOURCE_FILES=spectree tests examples setup.py install: pip install -e .[email,flask,falcon,starlette,dev] -test: +import_test: + pip install -e .[email] + for module in flask falcon starlette; do \ + pip install -U $$module; \ + bash -c "python tests/import_module/test_$${module}_plugin.py" || exit 1; \ + pip uninstall $$module -y; \ + done + +test: import_test pip install -U -e .[email,flask,falcon,starlette] pytest tests -vv -rs pip uninstall falcon email-validator -y && pip install falcon==2.0.0 pytest tests -vv -rs + doc: cd docs && make html diff --git a/spectree/plugins/__init__.py b/spectree/plugins/__init__.py index bebc018b..6796218d 100644 --- a/spectree/plugins/__init__.py +++ b/spectree/plugins/__init__.py @@ -1,12 +1,15 @@ +from collections import namedtuple + from .base import BasePlugin -from .falcon_plugin import FalconAsgiPlugin, FalconPlugin -from .flask_plugin import FlaskPlugin -from .starlette_plugin import StarlettePlugin + +__all__ = ["BasePlugin"] + +Plugin = namedtuple("Plugin", ("name", "package", "class_name")) PLUGINS = { - "base": BasePlugin, - "flask": FlaskPlugin, - "falcon": FalconPlugin, - "falcon-asgi": FalconAsgiPlugin, - "starlette": StarlettePlugin, + "base": Plugin(".base", __name__, "BasePlugin"), + "flask": Plugin(".flask_plugin", __name__, "FlaskPlugin"), + "falcon": Plugin(".falcon_plugin", __name__, "FalconPlugin"), + "falcon-asgi": Plugin(".falcon_plugin", __name__, "FalconAsgiPlugin"), + "starlette": Plugin(".starlette_plugin", __name__, "StarlettePlugin"), } diff --git a/spectree/plugins/falcon_plugin.py b/spectree/plugins/falcon_plugin.py index a9515025..918c973c 100644 --- a/spectree/plugins/falcon_plugin.py +++ b/spectree/plugins/falcon_plugin.py @@ -3,6 +3,8 @@ from functools import partial from typing import Any, Callable, Dict, List, Mapping, Optional, get_type_hints +from falcon import HTTP_400, HTTP_415, HTTPError +from falcon.routing.compiled import _FIELD_PATTERN as FALCON_FIELD_PATTERN from pydantic import ValidationError from .._types import ModelType @@ -55,13 +57,7 @@ class FalconPlugin(BasePlugin): def __init__(self, spectree): super().__init__(spectree) - from falcon import HTTP_400, HTTP_415, HTTPError - from falcon.routing.compiled import _FIELD_PATTERN - - # used to detect falcon 3.0 request media parse error - self.FALCON_HTTP_ERROR = HTTPError self.FALCON_MEDIA_ERROR_CODE = (HTTP_400, HTTP_415) - self.FIELD_PATTERN = _FIELD_PATTERN # NOTE from `falcon.routing.compiled.CompiledRouterNode` self.ESCAPE = r"[\.\(\)\[\]\?\$\*\+\^\|]" self.ESCAPE_TO = r"\\\g<0>" @@ -114,13 +110,13 @@ def parse_func(self, route: Any) -> Dict[str, Any]: def parse_path(self, route, path_parameter_descriptions): subs, parameters = [], [] for segment in route.uri_template.strip("/").split("/"): - matches = self.FIELD_PATTERN.finditer(segment) + matches = FALCON_FIELD_PATTERN.finditer(segment) if not matches: subs.append(segment) continue escaped = re.sub(self.ESCAPE, self.ESCAPE_TO, segment) - subs.append(self.FIELD_PATTERN.sub(self.EXTRACT, escaped)) + subs.append(FALCON_FIELD_PATTERN.sub(self.EXTRACT, escaped)) for field in matches: variable, converter, argstr = [ @@ -184,11 +180,17 @@ def request_validation(self, req, query, json, headers, cookies): req.context.cookies = cookies.parse_obj(req.cookies) try: media = req.media - except self.FALCON_HTTP_ERROR as err: + except HTTPError as err: if err.status not in self.FALCON_MEDIA_ERROR_CODE: raise media = None if json: + try: + media = req.media + except HTTPError as err: + if err.status not in self.FALCON_MEDIA_ERROR_CODE: + raise + media = None req.context.json = json.parse_obj(media) def validate( @@ -267,7 +269,7 @@ async def request_validation(self, req, query, json, headers, cookies): if json: try: media = await req.get_media() - except self.FALCON_HTTP_ERROR as err: + except HTTPError as err: if err.status not in self.FALCON_MEDIA_ERROR_CODE: raise media = None diff --git a/spectree/plugins/flask_plugin.py b/spectree/plugins/flask_plugin.py index 9f822a68..70d55690 100644 --- a/spectree/plugins/flask_plugin.py +++ b/spectree/plugins/flask_plugin.py @@ -1,6 +1,8 @@ from typing import Any, Callable, Optional, get_type_hints +from flask import Blueprint, abort, current_app, jsonify, make_response, request from pydantic import BaseModel, ValidationError +from werkzeug.routing import parse_converter_args from .._types import ModelType from ..response import Response @@ -13,8 +15,6 @@ class FlaskPlugin(BasePlugin): FORM_MIMETYPE = ("application/x-www-form-urlencoded", "multipart/form-data") def find_routes(self): - from flask import current_app - for rule in current_app.url_map.iter_rules(): if any( str(rule).startswith(path) @@ -40,8 +40,6 @@ def bypass(self, func, method): return method in ["HEAD", "OPTIONS"] def parse_func(self, route: Any): - from flask import current_app - if self.blueprint_state: func = self.blueprint_state.app.view_functions[route.endpoint] else: @@ -59,8 +57,6 @@ def parse_func(self, route: Any): yield method, func def parse_path(self, route, path_parameter_descriptions): - from werkzeug.routing import parse_converter_args - subs = [] parameters = [] @@ -178,8 +174,6 @@ def validate( *args: Any, **kwargs: Any, ): - from flask import abort, jsonify, make_response, request - response, req_validation_error, resp_validation_error = None, None, None try: self.request_validation(request, query, json, headers, cookies) @@ -235,8 +229,6 @@ def validate( return response def register_route(self, app): - from flask import Blueprint, jsonify - app.add_url_rule( rule=self.config.spec_url, endpoint=f"openapi_{self.config.path}", diff --git a/spectree/plugins/starlette_plugin.py b/spectree/plugins/starlette_plugin.py index 1b8b6e88..d5d2752a 100644 --- a/spectree/plugins/starlette_plugin.py +++ b/spectree/plugins/starlette_plugin.py @@ -5,6 +5,10 @@ from typing import Any, Callable, Optional, get_type_hints from pydantic import ValidationError +from starlette.convertors import CONVERTOR_TYPES +from starlette.requests import Request +from starlette.responses import HTMLResponse, JSONResponse +from starlette.routing import compile_path from .._types import ModelType from ..response import Response @@ -15,8 +19,6 @@ def PydanticResponse(content): - from starlette.responses import JSONResponse - class _PydanticResponse(JSONResponse): def render(self, content) -> bytes: self._model_class = content.__class__ @@ -30,13 +32,11 @@ class StarlettePlugin(BasePlugin): def __init__(self, spectree): super().__init__(spectree) - from starlette.convertors import CONVERTOR_TYPES self.conv2type = {conv: typ for typ, conv in CONVERTOR_TYPES.items()} def register_route(self, app): self.app = app - from starlette.responses import HTMLResponse, JSONResponse self.app.add_route( self.config.spec_url, @@ -79,8 +79,6 @@ async def validate( *args: Any, **kwargs: Any, ): - from starlette.requests import Request - from starlette.responses import JSONResponse if isinstance(args[0], Request): instance, request = None, args[0] @@ -186,7 +184,6 @@ def parse_func(self, route): yield method, route.func def parse_path(self, route, path_parameter_descriptions): - from starlette.routing import compile_path _, path, variables = compile_path(route.path) parameters = [] diff --git a/spectree/spec.py b/spectree/spec.py index 0a22b45b..84ea332e 100644 --- a/spectree/spec.py +++ b/spectree/spec.py @@ -1,6 +1,7 @@ from collections import defaultdict from copy import deepcopy from functools import wraps +from importlib import import_module from typing import ( Any, Callable, @@ -67,8 +68,12 @@ def __init__( self.validation_error_status = validation_error_status self.config: Configuration = Configuration.parse_obj(kwargs) self.backend_name = backend_name - self.backend = backend(self) if backend else PLUGINS[backend_name](self) - # init + if backend: + self.backend = backend(self) + else: + plugin = PLUGINS[backend_name] + module = import_module(plugin.name, plugin.package) + self.backend = getattr(module, plugin.class_name)(self) self.models: Dict[str, Any] = {} if app: self.register(app) diff --git a/tests/import_module/__init__.py b/tests/import_module/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/import_module/test_falcon_plugin.py b/tests/import_module/test_falcon_plugin.py new file mode 100644 index 00000000..82837308 --- /dev/null +++ b/tests/import_module/test_falcon_plugin.py @@ -0,0 +1,5 @@ +from spectree import SpecTree + +SpecTree("falcon") +SpecTree("falcon-asgi") +print("=> passed falcon plugin import test") diff --git a/tests/import_module/test_flask_plugin.py b/tests/import_module/test_flask_plugin.py new file mode 100644 index 00000000..6afebb60 --- /dev/null +++ b/tests/import_module/test_flask_plugin.py @@ -0,0 +1,4 @@ +from spectree import SpecTree + +SpecTree("flask") +print("=> passed flask plugin import test") diff --git a/tests/import_module/test_starlette_plugin.py b/tests/import_module/test_starlette_plugin.py new file mode 100644 index 00000000..5ae74c1f --- /dev/null +++ b/tests/import_module/test_starlette_plugin.py @@ -0,0 +1,4 @@ +from spectree import SpecTree + +SpecTree("starlette") +print("=> passed starlette plugin import test") diff --git a/tests/test_spec.py b/tests/test_spec.py index 7b2e864a..9b7fefac 100644 --- a/tests/test_spec.py +++ b/tests/test_spec.py @@ -11,7 +11,7 @@ from spectree import Response from spectree.config import Configuration from spectree.models import Server, ValidationError -from spectree.plugins import FlaskPlugin +from spectree.plugins.flask_plugin import FlaskPlugin from spectree.spec import SpecTree from .common import get_paths