Skip to content

Commit

Permalink
Merge pull request #251 from DisruptiveLabs/before_v1_use_get_type_hints
Browse files Browse the repository at this point in the history
Use typing.get_type_hints instead of .__annotations__
  • Loading branch information
kemingy authored Aug 7, 2022
2 parents 3fa67be + 7ee06cc commit 668c88a
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 12 deletions.
8 changes: 5 additions & 3 deletions spectree/plugins/falcon_plugin.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import inspect
import re
from functools import partial
from typing import Any, Callable, Dict, List, Mapping, Optional
from typing import Any, Callable, Dict, List, Mapping, Optional, get_type_hints

from pydantic import ValidationError

Expand Down Expand Up @@ -212,8 +212,9 @@ def validate(
try:
self.request_validation(_req, query, json, headers, cookies)
if self.config.annotations:
annotations = get_type_hints(func)
for name in ("query", "json", "headers", "cookies"):
if func.__annotations__.get(name):
if annotations.get(name):
kwargs[name] = getattr(_req.context, name)

except ValidationError as err:
Expand Down Expand Up @@ -293,8 +294,9 @@ async def validate(
try:
await self.request_validation(_req, query, json, headers, cookies)
if self.config.annotations:
annotations = get_type_hints(func)
for name in ("query", "json", "headers", "cookies"):
if func.__annotations__.get(name):
if annotations.get(name):
kwargs[name] = getattr(_req.context, name)

except ValidationError as err:
Expand Down
5 changes: 3 additions & 2 deletions spectree/plugins/flask_plugin.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Callable, Optional
from typing import Any, Callable, Optional, get_type_hints

from pydantic import BaseModel, ValidationError

Expand Down Expand Up @@ -184,8 +184,9 @@ def validate(
try:
self.request_validation(request, query, json, headers, cookies)
if self.config.annotations:
annotations = get_type_hints(func)
for name in ("query", "json", "headers", "cookies"):
if func.__annotations__.get(name):
if annotations.get(name):
kwargs[name] = getattr(request.context, name)
except ValidationError as err:
req_validation_error = err
Expand Down
5 changes: 3 additions & 2 deletions spectree/plugins/starlette_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from collections import namedtuple
from functools import partial
from json import JSONDecodeError
from typing import Any, Callable, Optional
from typing import Any, Callable, Optional, get_type_hints

from pydantic import ValidationError

Expand Down Expand Up @@ -93,8 +93,9 @@ async def validate(
try:
await self.request_validation(request, query, json, headers, cookies)
if self.config.annotations:
annotations = get_type_hints(func)
for name in ("query", "json", "headers", "cookies"):
if func.__annotations__.get(name):
if annotations.get(name):
kwargs[name] = getattr(request.context, name)
except ValidationError as err:
req_validation_error = err
Expand Down
20 changes: 15 additions & 5 deletions spectree/spec.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,16 @@
from collections import defaultdict
from copy import deepcopy
from functools import wraps
from typing import Any, Callable, Dict, Mapping, Optional, Sequence, Type
from typing import (
Any,
Callable,
Dict,
Mapping,
Optional,
Sequence,
Type,
get_type_hints,
)

from ._types import FunctionDecorator, ModelType
from .config import Configuration, ModeEnum
Expand Down Expand Up @@ -188,14 +197,15 @@ async def async_validate(*args: Any, **kwargs: Any):
)

if self.config.annotations:
annotations = get_type_hints(func)
nonlocal query
query = func.__annotations__.get("query", query)
query = annotations.get("query", query)
nonlocal json
json = func.__annotations__.get("json", json)
json = annotations.get("json", json)
nonlocal headers
headers = func.__annotations__.get("headers", headers)
headers = annotations.get("headers", headers)
nonlocal cookies
cookies = func.__annotations__.get("cookies", cookies)
cookies = annotations.get("cookies", cookies)

# register
for name, model in zip(
Expand Down

0 comments on commit 668c88a

Please sign in to comment.