From b82613a39945342fd2bc7368c335ea247b84dca6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tom=C3=A1=C5=A1=20Trval?= Date: Tue, 3 Oct 2023 10:34:26 +0200 Subject: [PATCH] fix: flask import of _endpoint_from_view_func is now resolved new util function import_check_view_func #567 --- CHANGELOG.rst | 2 +- flask_restx/api.py | 11 +++++------ flask_restx/utils.py | 46 ++++++++++++++++++++++++++++++++++++++++++++ tests/test_utils.py | 14 ++++++++++++++ 4 files changed, 66 insertions(+), 7 deletions(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 73eea75a..7013ddaf 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -35,7 +35,7 @@ Bug Fixes :: * Fixing test as HTTP Header MIMEAccept expects quality-factor number in form of `X.X` (#547) [chipndell] - + * Fixing flask 3.0+ compatibility of `flask.scaffold import _endpoint_from_view_func` Import error. (#565) [Ryu-CZ] .. _enhancements-1.2.0: diff --git a/flask_restx/api.py b/flask_restx/api.py index 5996dd59..bd0413dd 100644 --- a/flask_restx/api.py +++ b/flask_restx/api.py @@ -14,10 +14,6 @@ from flask import url_for, request, current_app from flask import make_response as original_flask_make_response -try: - from flask.helpers import _endpoint_from_view_func -except ImportError: - from flask.scaffold import _endpoint_from_view_func from flask.signals import got_request_exception from jsonschema import RefResolver @@ -45,10 +41,13 @@ from .postman import PostmanCollectionV1 from .resource import Resource from .swagger import Swagger -from .utils import default_id, camel_to_dash, unpack +from .utils import default_id, camel_to_dash, unpack, import_check_view_func from .representations import output_json from ._http import HTTPStatus +endpoint_from_view_func = import_check_view_func() + + RE_RULES = re.compile("(<.*>)") # List headers that should never be handled by Flask-RESTX @@ -850,7 +849,7 @@ def _blueprint_setup_add_url_rule_patch( rule = blueprint_setup.url_prefix + rule options.setdefault("subdomain", blueprint_setup.subdomain) if endpoint is None: - endpoint = _endpoint_from_view_func(view_func) + endpoint = endpoint_from_view_func(view_func) defaults = blueprint_setup.url_defaults if "defaults" in options: defaults = dict(defaults, **options.pop("defaults")) diff --git a/flask_restx/utils.py b/flask_restx/utils.py index 809a29b3..35dec2ae 100644 --- a/flask_restx/utils.py +++ b/flask_restx/utils.py @@ -1,4 +1,6 @@ import re +import warnings +import typing from collections import OrderedDict from copy import deepcopy @@ -20,6 +22,10 @@ ) +class FlaskCompatibilityWarning(DeprecationWarning): + pass + + def merge(first, second): """ Recursively merges two dictionaries. @@ -118,3 +124,43 @@ def unpack(response, default_code=HTTPStatus.OK): return data, code or default_code, headers else: raise ValueError("Too many response values") + + +def to_view_name(view_func: typing.Callable) -> str: + """Helper that returns the default endpoint for a given + function. This always is the function name. + + Note: copy of simple flask internal helper + """ + assert view_func is not None, "expected view func if endpoint is not provided." + return view_func.__name__ + + +def import_check_view_func(): + """ + Resolve import flask _endpoint_from_view_func. + + Show warning if function cannot be found and provide copy of last known implementation. + + Note: This helper method exists because reoccurring problem with flask function, but + actual method body remaining the same in each flask version. + """ + import importlib.metadata + + flask_version = importlib.metadata.version("flask").split(".") + try: + if flask_version[0] == "1": + from flask.helpers import _endpoint_from_view_func + elif flask_version[0] == "2": + from flask.scaffold import _endpoint_from_view_func + elif flask_version[0] == "3": + from flask.sansio.scaffold import _endpoint_from_view_func + else: + warnings.simplefilter("once", FlaskCompatibilityWarning) + _endpoint_from_view_func = None + except ImportError: + warnings.simplefilter("once", FlaskCompatibilityWarning) + _endpoint_from_view_func = None + if _endpoint_from_view_func is None: + _endpoint_from_view_func = to_view_name + return _endpoint_from_view_func diff --git a/tests/test_utils.py b/tests/test_utils.py index d98d68d0..fe3a1adb 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -98,3 +98,17 @@ def test_value_headers_default_code(self): def test_too_many_values(self): with pytest.raises(ValueError): utils.unpack((None, None, None, None)) + + +class ToViewNameTest(object): + def test_none(self): + with pytest.raises(AssertionError): + _ = utils.to_view_name(None) + + def test_name(self): + assert self.test_none == self.test_none.__name__ + + +class ImportCheckViewFuncTest(object): + def test_callable(self): + assert callable(utils.import_check_view_func())