From 8f970761e2f87489bc1cff73eb21cb0d637446c0 Mon Sep 17 00:00:00 2001 From: Peter Schutt Date: Wed, 21 Jun 2023 15:45:30 +1000 Subject: [PATCH 1/5] feat(internal): add msgspec signature model. Linting fixes. Update tests/unit/test_kwargs/test_path_params.py fix signature namespace issue support min_length and max_length support min_length and max_length handle constraints on union types fix error message tests chore(signature-model): remove pydantic and attrs signature models chore(signature model): fix python 3.8 compat --- .pre-commit-config.yaml | 14 +- litestar/_kwargs/dependencies.py | 2 +- litestar/_signature/__init__.py | 6 +- litestar/_signature/model.py | 230 +++++++++ litestar/_signature/models/__init__.py | 3 - .../models/attrs_signature_model.py | 471 ------------------ litestar/_signature/models/base.py | 160 ------ .../models/pydantic_signature_model.py | 184 ------- litestar/_signature/utils.py | 135 +---- litestar/app.py | 14 +- litestar/handlers/base.py | 9 +- .../handlers/websocket_handlers/listener.py | 5 +- litestar/serialization.py | 45 +- litestar/testing/helpers.py | 4 - litestar/utils/typing.py | 1 + poetry.lock | 31 +- pyproject.toml | 7 +- .../e2e/test_routing/test_path_resolution.py | 6 +- tests/unit/test_app.py | 2 +- tests/unit/test_dto/test_attrs_fail.py | 21 - .../test_dto/test_factory/test_integration.py | 2 +- tests/unit/test_dto/test_integration.py | 2 +- .../test_http_handlers/test_to_response.py | 5 +- tests/unit/test_openapi/test_parameters.py | 5 +- tests/unit/test_openapi/test_schema.py | 5 +- tests/unit/test_params.py | 35 +- .../test_attrs_signature_modelling.py | 129 ----- tests/unit/test_signature/test_parsing.py | 411 ++------------- tests/unit/test_signature/test_utils.py | 27 - tests/unit/test_signature/test_validation.py | 319 ++++++++++++ 30 files changed, 694 insertions(+), 1596 deletions(-) create mode 100644 litestar/_signature/model.py delete mode 100644 litestar/_signature/models/__init__.py delete mode 100644 litestar/_signature/models/attrs_signature_model.py delete mode 100644 litestar/_signature/models/base.py delete mode 100644 litestar/_signature/models/pydantic_signature_model.py delete mode 100644 tests/unit/test_dto/test_attrs_fail.py delete mode 100644 tests/unit/test_signature/test_attrs_signature_modelling.py delete mode 100644 tests/unit/test_signature/test_utils.py create mode 100644 tests/unit/test_signature/test_validation.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index d6cf0cf609..ea2863bd0f 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -70,7 +70,7 @@ repos: exclude: "test_apps|tools|docs|tests/examples|tests/docker_service_fixtures" additional_dependencies: [ - polyfactory, + msgspec>=0.17.0, aiosqlite, annotated_types, async_timeout, @@ -81,7 +81,6 @@ repos: beanie, beautifulsoup4, brotli, - cattrs, click, fakeredis>=2.10.2, fast-query-parsers, @@ -92,13 +91,14 @@ repos: jsbeautifier, mako, mongomock_motor, - msgspec, multidict, opentelemetry-instrumentation-asgi, opentelemetry-sdk, oracledb, piccolo, picologging, + polyfactory, + prometheus_client, psycopg, pydantic, pytest, @@ -121,7 +121,6 @@ repos: types-pyyaml, types-redis, uvicorn, - prometheus_client, ] - repo: https://github.com/RobertCraigie/pyright-python rev: v1.1.317 @@ -130,7 +129,7 @@ repos: exclude: "test_apps|tools|docs|_openapi|tests/examples|tests/docker_service_fixtures" additional_dependencies: [ - polyfactory, + msgspec>=0.17.0, aiosqlite, annotated_types, async_timeout, @@ -141,7 +140,6 @@ repos: beanie, beautifulsoup4, brotli, - cattrs, click, fakeredis>=2.10.2, fast-query-parsers, @@ -152,13 +150,14 @@ repos: jsbeautifier, mako, mongomock_motor, - msgspec, multidict, opentelemetry-instrumentation-asgi, opentelemetry-sdk, oracledb, piccolo, picologging, + polyfactory, + prometheus_client, psycopg, pydantic, pytest, @@ -181,7 +180,6 @@ repos: types-pyyaml, types-redis, uvicorn, - prometheus_client, ] - repo: local hooks: diff --git a/litestar/_kwargs/dependencies.py b/litestar/_kwargs/dependencies.py index a5bea1d1f8..17b9d8bb06 100644 --- a/litestar/_kwargs/dependencies.py +++ b/litestar/_kwargs/dependencies.py @@ -3,7 +3,7 @@ from inspect import isasyncgen, isgenerator from typing import TYPE_CHECKING, Any -from litestar._signature.utils import get_signature_model +from litestar._signature import get_signature_model from litestar.utils.compat import async_next __all__ = ("Dependency", "create_dependency_batches", "map_dependencies_recursively", "resolve_dependency") diff --git a/litestar/_signature/__init__.py b/litestar/_signature/__init__.py index 5307d1f9fc..5b26c373cd 100644 --- a/litestar/_signature/__init__.py +++ b/litestar/_signature/__init__.py @@ -1,4 +1,4 @@ -from .models.base import SignatureModel -from .utils import create_signature_model, get_signature_model +from .model import SignatureModel +from .utils import get_signature_model -__all__ = ("create_signature_model", "SignatureModel", "get_signature_model") +__all__ = ("SignatureModel", "get_signature_model") diff --git a/litestar/_signature/model.py b/litestar/_signature/model.py new file mode 100644 index 0000000000..e52aa8806c --- /dev/null +++ b/litestar/_signature/model.py @@ -0,0 +1,230 @@ +# ruff: noqa: UP006 +from __future__ import annotations + +import re +from typing import TYPE_CHECKING, Any, ClassVar, Dict, Literal, Optional, Sequence, Set, TypedDict, Union, cast + +from msgspec import NODEFAULT, Meta, Struct, ValidationError, convert, defstruct +from msgspec.structs import asdict +from pydantic import ValidationError as PydanticValidationError +from typing_extensions import Annotated + +from litestar._signature.utils import create_type_overrides, validate_signature_dependencies +from litestar.enums import ScopeType +from litestar.exceptions import InternalServerException, ValidationException +from litestar.params import DependencyKwarg, KwargDefinition, ParameterKwarg +from litestar.serialization import dec_hook +from litestar.typing import FieldDefinition # noqa: TCH +from litestar.utils import make_non_optional_union +from litestar.utils.dataclass import simple_asdict +from litestar.utils.typing import unwrap_union + +if TYPE_CHECKING: + from typing_extensions import NotRequired + + from litestar.connection import ASGIConnection + from litestar.types import AnyCallable + from litestar.utils.signature import ParsedSignature + + +__all__ = ( + "ErrorMessage", + "SignatureModel", +) + + +class ErrorMessage(TypedDict): + # key may not be set in some cases, like when a query param is set but + # doesn't match the required length during `attrs` validation + # in this case, we don't show a key at all as it will be empty + key: NotRequired[str] + message: str + source: NotRequired[Literal["cookie", "body", "header", "query"]] + + +MSGSPEC_CONSTRAINT_FIELDS = ( + "gt", + "ge", + "lt", + "le", + "multiple_of", + "pattern", + "min_length", + "max_length", +) + +ERR_RE = re.compile(r"`\$\.(.+)`$") + + +class SignatureModel(Struct): + """Model that represents a function signature that uses a msgspec specific type or types.""" + + # NOTE: we have to use Set and Dict here because python 3.8 goes haywire if we use 'set' and 'dict' + dependency_name_set: ClassVar[Set[str]] + fields: ClassVar[Dict[str, FieldDefinition]] + return_annotation: ClassVar[Any] + + @classmethod + def _create_exception(cls, connection: ASGIConnection, messages: list[ErrorMessage]) -> Exception: + """Create an exception class - either a ValidationException or an InternalServerException, depending on whether + the failure is in client provided values or injected dependencies. + + Args: + connection: An ASGI connection instance. + messages: A list of error messages. + + Returns: + An Exception + """ + method = connection.method if hasattr(connection, "method") else ScopeType.WEBSOCKET # pyright: ignore + if client_errors := [ + err_message + for err_message in messages + if ("key" in err_message and err_message["key"] not in cls.dependency_name_set) or "key" not in err_message + ]: + return ValidationException(detail=f"Validation failed for {method} {connection.url}", extra=client_errors) + return InternalServerException() + + @classmethod + def _build_error_message(cls, keys: Sequence[str], exc_msg: str, connection: ASGIConnection) -> ErrorMessage: + """Build an error message. + + Args: + keys: A list of keys. + exc_msg: A message. + connection: An ASGI connection instance. + + Returns: + An ErrorMessage + """ + + message: ErrorMessage = {"message": exc_msg.split(" - ")[0]} + + if not keys: + return message + + message["key"] = key = ".".join(keys) + + if key in connection.query_params: + message["source"] = cast("Literal['cookie', 'body', 'header', 'query']", "query") + + elif key in cls.fields and isinstance(cls.fields[key].kwarg_definition, ParameterKwarg): + if cast(ParameterKwarg, cls.fields[key].kwarg_definition).cookie: + source = "cookie" + elif cast(ParameterKwarg, cls.fields[key].kwarg_definition).header: + source = "header" + else: + source = "query" + message["source"] = cast("Literal['cookie', 'body', 'header', 'query']", source) + + return message + + @classmethod + def parse_values_from_connection_kwargs(cls, connection: ASGIConnection, **kwargs: Any) -> dict[str, Any]: + """Extract values from the connection instance and return a dict of parsed values. + + Args: + connection: The ASGI connection instance. + **kwargs: A dictionary of kwargs. + + Raises: + ValidationException: If validation failed. + InternalServerException: If another exception has been raised. + + Returns: + A dictionary of parsed values + """ + messages: list[ErrorMessage] = [] + try: + return convert(kwargs, cls, strict=False, dec_hook=dec_hook).to_dict() + except PydanticValidationError as e: + for exc in e.errors(): + keys = [str(loc) for loc in exc["loc"]] + message = cls._build_error_message(keys=keys, exc_msg=exc["msg"], connection=connection) + messages.append(message) + raise cls._create_exception(messages=messages, connection=connection) from e + except ValidationError as e: + match = ERR_RE.search(str(e)) + keys = [str(match.group(1)) if match else "n/a"] + message = cls._build_error_message(keys=keys, exc_msg=str(e), connection=connection) + messages.append(message) + raise cls._create_exception(messages=messages, connection=connection) from e + + def to_dict(self) -> dict[str, Any]: + """Normalize access to the signature model's dictionary method, because different backends use different methods + for this. + + Returns: A dictionary of string keyed values. + """ + return asdict(self) + + @classmethod + def create( + cls, + dependency_name_set: set[str], + fn: AnyCallable, + parsed_signature: ParsedSignature, + has_data_dto: bool = False, + ) -> type[SignatureModel]: + fn_name = ( + fn_name if (fn_name := getattr(fn, "__name__", "anonymous")) and fn_name != "" else "anonymous" + ) + + dependency_names = validate_signature_dependencies( + dependency_name_set=dependency_name_set, fn_name=fn_name, parsed_signature=parsed_signature + ) + type_overrides = create_type_overrides(parsed_signature, has_data_dto) + + struct_fields: list[tuple[str, Any, Any]] = [] + + for field_definition in parsed_signature.parameters.values(): + annotation = type_overrides.get(field_definition.name, field_definition.annotation) + + if isinstance(field_definition.kwarg_definition, KwargDefinition): + meta_kwargs: dict[str, Any] = {"extra": {}} + + kwarg_definition = simple_asdict(field_definition.kwarg_definition, exclude_empty=True) + if min_items := kwarg_definition.pop("min_items", None): + meta_kwargs["min_length"] = min_items + if max_items := kwarg_definition.pop("max_items", None): + meta_kwargs["max_length"] = max_items + + for k, v in kwarg_definition.items(): + if hasattr(Meta, k) and v is not None: + meta_kwargs[k] = v + else: + meta_kwargs["extra"][k] = v + + meta = Meta(**meta_kwargs) + if field_definition.is_optional: + annotation = Optional[Annotated[make_non_optional_union(annotation), meta]] + elif field_definition.is_union and meta_kwargs.keys() & MSGSPEC_CONSTRAINT_FIELDS: + # unwrap inner types of a union and apply constraints to each individual type + # see https://github.com/jcrist/msgspec/issues/447 + annotation = Union[ + tuple(Annotated[inner_type, meta] for inner_type in unwrap_union(annotation)) # pyright: ignore + ] + else: + annotation = Annotated[annotation, meta] + + elif ( + isinstance(field_definition.kwarg_definition, DependencyKwarg) + and field_definition.kwarg_definition.skip_validation + ): + annotation = Any + + default = field_definition.default if field_definition.has_default else NODEFAULT + struct_fields.append((field_definition.name, annotation, default)) + + return defstruct( # type:ignore[return-value] + f"{fn_name}_signature_model", + struct_fields, + bases=(cls,), + module=getattr(fn, "__module__", None), + namespace={ + "return_annotation": parsed_signature.return_type.annotation, + "dependency_name_set": dependency_names, + "fields": parsed_signature.parameters, + }, + kw_only=True, + ) diff --git a/litestar/_signature/models/__init__.py b/litestar/_signature/models/__init__.py deleted file mode 100644 index b4e8a46cbf..0000000000 --- a/litestar/_signature/models/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .base import SignatureModel - -__all__ = ("SignatureModel",) diff --git a/litestar/_signature/models/attrs_signature_model.py b/litestar/_signature/models/attrs_signature_model.py deleted file mode 100644 index 81eea1e285..0000000000 --- a/litestar/_signature/models/attrs_signature_model.py +++ /dev/null @@ -1,471 +0,0 @@ -from __future__ import annotations - -import re -import traceback -from dataclasses import asdict, replace -from datetime import date, datetime, time, timedelta, timezone -from functools import lru_cache, partial -from pathlib import PurePath -from typing import ( - TYPE_CHECKING, - Any, - Callable, - Union, - cast, -) -from uuid import UUID - -from _decimal import Decimal -from typing_extensions import get_args - -from litestar._signature.models.base import ErrorMessage, SignatureModel -from litestar.connection import ASGIConnection, Request, WebSocket -from litestar.datastructures import ImmutableState, MultiDict, State, UploadFile -from litestar.exceptions import MissingDependencyException -from litestar.params import DependencyKwarg, KwargDefinition -from litestar.types import Empty -from litestar.utils.predicates import is_optional_union, is_union -from litestar.utils.typing import get_origin_or_inner_type, make_non_optional_union, unwrap_union - -try: - import attr - import attrs - import cattrs -except ImportError as e: - raise MissingDependencyException("attrs") from e - -try: - from dateutil.parser import parse -except ImportError as e: - raise MissingDependencyException("python-dateutil", "attrs") from e - -try: - from pytimeparse.timeparse import timeparse -except ImportError as e: - raise MissingDependencyException("pytimeparse", "attrs") from e - -if TYPE_CHECKING: - from litestar.utils.signature import ParsedSignature - -__all__ = ("AttrsSignatureModel",) -key_re = re.compile("@ (attribute|index) (.*)|'(.*)'") -TRUE_SET = {"1", "true", "on", "t", "y", "yes"} -FALSE_SET = {"0", "false", "off", "f", "n", "no"} - -try: - import pydantic - - def _structure_base_model(value: Any, cls: type[pydantic.BaseModel]) -> pydantic.BaseModel: - return value if isinstance(value, pydantic.BaseModel) else cls(**value) - - pydantic_hooks: list[tuple[type[Any], Callable[[Any, type[Any]], Any]]] = [ - (pydantic.BaseModel, _structure_base_model), - ] -except ImportError: - pydantic_hooks = [] - - -StructureException = Union[ - cattrs.ClassValidationError, cattrs.IterableValidationError, ValueError, TypeError, AttributeError -] - - -def _pass_through_structure_hook(value: Any, _: type[Any]) -> Any: - return value - - -def _pass_through_unstructure_hook(value: Any) -> Any: - return value - - -def _structure_bool(value: Any, _: type[bool]) -> bool: - if isinstance(value, bytes): - value = value.decode("utf-8").lower() - - if isinstance(value, str): - value = value.lower() - - if value == 0 or value in FALSE_SET: - return False - - if value == 1 or value in TRUE_SET: - return True - - raise ValueError(f"Cannot convert {value} to bool") - - -def _structure_datetime(value: Any, cls: type[datetime]) -> datetime: - if isinstance(value, datetime): - return value - - try: - return cls.fromtimestamp(float(value), tz=timezone.utc) - except (ValueError, TypeError): - pass - - return parse(value) - - -def _structure_date(value: Any, cls: type[date]) -> date: - if isinstance(value, date) and not isinstance(value, datetime): - return value - - if isinstance(value, (float, int, Decimal)): - return datetime.fromtimestamp(float(value), tz=timezone.utc).date() - - dt = _structure_datetime(value=value, cls=datetime) - return cls(year=dt.year, month=dt.month, day=dt.day) - - -def _structure_time(value: Any, cls: type[time]) -> time: - if isinstance(value, time): - return value - - if isinstance(value, str): - return cls.fromisoformat(value) - - dt = _structure_datetime(value=value, cls=datetime) - return cls(hour=dt.hour, minute=dt.minute, second=dt.second, microsecond=dt.microsecond, tzinfo=dt.tzinfo) - - -def _structure_timedelta(value: Any, cls: type[timedelta]) -> timedelta: - if isinstance(value, timedelta): - return value - if isinstance(value, (float, int, Decimal)): - return cls(seconds=int(value)) - return cls(seconds=timeparse(value)) # pyright: ignore - - -def _structure_decimal(value: Any, cls: type[Decimal]) -> Decimal: - return cls(str(value)) - - -def _structure_path(value: Any, cls: type[PurePath]) -> PurePath: - return cls(str(value)) - - -def _structure_uuid(value: Any, cls: type[UUID]) -> UUID: - return value if isinstance(value, UUID) else cls(str(value)) - - -def _structure_multidict(value: Any, cls: type[MultiDict]) -> MultiDict: - return cls(value) - - -def _structure_str(value: Any, cls: type[str]) -> str: - # see: https://github.com/python-attrs/cattrs/issues/26#issuecomment-358594015 - if value is None: - raise ValueError - return cls(value) - - -hooks: list[tuple[type[Any], Callable[[Any, type[Any]], Any]]] = [ - (ASGIConnection, _pass_through_structure_hook), - (Decimal, _structure_decimal), - (ImmutableState, _pass_through_structure_hook), - (MultiDict, _structure_multidict), - (PurePath, _structure_path), - (Request, _pass_through_structure_hook), - (State, _pass_through_structure_hook), - (UUID, _structure_uuid), - (UploadFile, _pass_through_structure_hook), - (WebSocket, _pass_through_structure_hook), - (bool, _structure_bool), - (date, _structure_date), - (datetime, _structure_datetime), - (str, _structure_str), - (time, _structure_time), - (timedelta, _structure_timedelta), - *pydantic_hooks, -] - - -def _create_default_structuring_hooks( - converter: cattrs.Converter, -) -> tuple[Callable, Callable]: - """Create scoped default hooks for a given converter. - - Notes: - - We are forced to use this pattern because some types cannot be handled by cattrs out of the box. For example, - union types, optionals, complex union types etc. - - See: https://github.com/python-attrs/cattrs/issues/311 - Args: - converter: A converter instance - - Returns: - A tuple of hook handlers - """ - - @lru_cache(1024) - def _default_structuring_hook(value: Any, annotation: Any) -> Any: - for arg in unwrap_union(annotation) or get_args(annotation): - try: - return converter.structure(arg, value) - except ValueError: # pragma: no cover - continue - return value - - return ( - _pass_through_unstructure_hook, - _default_structuring_hook, - ) - - -class Converter(cattrs.Converter): - def __init__(self) -> None: - super().__init__() - - # this is a hack to create a catch-all hook, see: https://github.com/python-attrs/cattrs/issues/311 - self._structure_func._function_dispatch._handler_pairs[-1] = ( - *_create_default_structuring_hooks(self), - False, - ) - - # ensure attrs instances are not unstructured into dict - self.register_unstructure_hook_factory( - # the first parameter is a predicate that tests the value. In this case we are testing for an attrs - # decorated class that does not have the AttrsSignatureModel anywhere in its mro chain. - lambda x: attrs.has(x) and AttrsSignatureModel not in list(x.__mro__), - # the "unstructuring" hook we are registering is a lambda that receives the class constructor and returns - # another lambda that will take a value and receive it unmodified. - # this is a hack to ensure that no attrs constructors are called during unstructuring. - lambda x: lambda x: x, - ) - - for cls, structure_hook in hooks: - self.register_structure_hook(cls, structure_hook) - self.register_unstructure_hook(cls, _pass_through_unstructure_hook) - - -_converter: Converter = Converter() - - -def _create_validators( - annotation: Any, kwarg_definition: KwargDefinition -) -> list[Callable[[Any, attrs.Attribute[Any], Any], Any]] | Callable[[Any, attrs.Attribute[Any], Any], Any]: - validators: list[Callable[[Any, attrs.Attribute[Any], Any], Any]] = [ - validator(value) # type: ignore[operator] - for value, validator in [ - (kwarg_definition.gt, attrs.validators.gt), - (kwarg_definition.ge, attrs.validators.ge), - (kwarg_definition.lt, attrs.validators.lt), - (kwarg_definition.le, attrs.validators.le), - (kwarg_definition.min_length, attrs.validators.min_len), - (kwarg_definition.max_length, attrs.validators.max_len), - (kwarg_definition.min_items, attrs.validators.min_len), - (kwarg_definition.max_items, attrs.validators.max_len), - ( - kwarg_definition.pattern, - partial(attrs.validators.matches_re, flags=0), - ), - ] - if value is not None - ] - if is_optional_union(annotation): - annotation = make_non_optional_union(annotation) - instance_of_validator = attrs.validators.instance_of( - unwrap_union(annotation) if is_union(annotation) else (get_origin_or_inner_type(annotation) or annotation) - ) - return attrs.validators.optional([instance_of_validator, *validators]) - - instance_of_validator = attrs.validators.instance_of(get_origin_or_inner_type(annotation) or annotation) - return [instance_of_validator, *validators] - - -@attr.define -class AttrsSignatureModel(SignatureModel): - """Model that represents a function signature that uses a pydantic specific type or types.""" - - @classmethod - def parse_values_from_connection_kwargs(cls, connection: ASGIConnection, **kwargs: Any) -> dict[str, Any]: - try: - signature = _converter.structure(obj=kwargs, cl=cls) - except (cattrs.ClassValidationError, ValueError, TypeError, AttributeError) as e: - raise cls._create_exception(messages=cls._extract_exceptions(e, connection), connection=connection) from e - - return cast("dict[str, Any]", _converter.unstructure(obj=signature)) - - def to_dict(self) -> dict[str, Any]: - return attrs.asdict(self) - - @classmethod - def _extract_exceptions(cls, e: StructureException, connection: ASGIConnection) -> list[ErrorMessage]: - """Extracts and normalizes cattrs exceptions. - - Args: - e: An ExceptionGroup - which is a py3.11 feature. We use hasattr instead of instance checks to avoid installing this. - connection: The connection instance. - - Returns: - A list of normalized exception messages. - """ - - error_messages: list[ErrorMessage] = [] - - if isinstance(e, cattrs.ClassValidationError): - for exc in cast("list[StructureException]", e.exceptions): - if messages := cls._get_messages_from_traceback(exc, connection): - error_messages.extend(messages) - - return error_messages - - @classmethod - def _get_messages_from_traceback(cls, exc: StructureException, connection: ASGIConnection) -> list[ErrorMessage]: - """Gets a message from an attrs validation error. - - The key will be a dot-separated string of the attribute path that failed - validation. The message will be the error message from the exception (or the - last exception in the exception group, when applicable). - - Args: - exc: The exception to get the message from - connection: The connection instance - - Returns: - An error message - """ - - error_data = cls._get_data_from_exception(exc=exc) - return cls._get_error_messages(error_data=error_data, connection=connection) - - @classmethod - def _get_data_from_exception( - cls, exc: StructureException, prefix: str = "", error_data: dict | None = None - ) -> dict[str, str]: - """Gets the keys from an attrs validation error. - - Handles nested structures (e.g. a model attribute references another - model) by going through all exceptions in the exception group - """ - - error_data = error_data or {} - - if isinstance(exc, (cattrs.ClassValidationError, cattrs.IterableValidationError)): - formatted_exception = traceback.format_exception_only(type(exc), value=exc) - key = cls._get_key_from_formatted_exception(formatted_exception) - - new_prefix = f"{prefix}.{key}" if prefix else key - - for sub_exc in cast("list[StructureException]", exc.exceptions): - error_data = cls._get_data_from_exception(sub_exc, new_prefix, error_data) - # when using attrs as the preferred backend validation but - # pydantic as the model, you can still get pydantic - # validation errors. - elif isinstance(exc, pydantic.ValidationError): - formatted_exception = traceback.format_exception_only(type(exc), value=exc) - key = cls._get_key_from_formatted_exception(formatted_exception) - - for error in exc.errors(): - error_key = ".".join([key, *[str(loc) for loc in error["loc"]]]) - error_data[error_key] = error["msg"] - else: - formatted_exception = traceback.format_exception(type(exc), value=exc, tb=exc.__traceback__) - key = cls._get_key_from_formatted_exception(formatted_exception) - key = f"{prefix}.{key}" if prefix else key - - error_data[key] = str(exc) - - return error_data - - @classmethod - def _get_key_from_formatted_exception(cls, formatted_exception: list[str]) -> str: - """Gets the key from a formatted exception.""" - return next( - (key for line in formatted_exception if (match := key_re.findall(line)) and (key := match[0][1].strip())), - "", - ) - - @classmethod - def _get_error_messages(cls, error_data: dict[str, str], connection: ASGIConnection) -> list[ErrorMessage]: - """Build an error message. - - Args: - error_data: A mapping of error location (dot-notated) to their error. - connection: An ASGI connection instance. - - Returns: - An ErrorMessage - """ - - messages: list[ErrorMessage] = [] - - for key, error in error_data.items(): - keys = key.split(".") - message = super()._build_error_message(keys=keys, exc_msg=error, connection=connection) - messages.append(message) - - return messages - - @classmethod - def populate_field_definitions(cls) -> None: - cls.fields = {} - - for key, attribute in attrs.fields_dict(cls).items(): - metadata = dict(attribute.metadata) - field_definition = metadata.pop("field_definition") - cls.fields[key] = replace( - field_definition, - name=key, - default=attribute.default if attribute.default is not attr.NOTHING else Empty, - extra=metadata, - ) - - @classmethod - def create( - cls, - fn_name: str, - fn_module: str | None, - parsed_signature: ParsedSignature, - dependency_names: set[str], - type_overrides: dict[str, Any], - ) -> type[SignatureModel]: - attributes: dict[str, Any] = {} - - for parameter in parsed_signature.parameters.values(): - annotation = type_overrides.get(parameter.name, parameter.annotation) - - if kwarg_definition := parameter.kwarg_definition: - if isinstance(kwarg_definition, DependencyKwarg): - attribute = attr.attrib( - type=Any if kwarg_definition.skip_validation else annotation, - default=kwarg_definition.default if kwarg_definition.default is not Empty else None, - metadata={ - "kwarg_definition": kwarg_definition, - "field_definition": parameter, - }, - ) - else: - attribute = attr.attrib( - type=annotation, - metadata={ - **asdict(kwarg_definition), - "kwarg_definition": kwarg_definition, - "field_definition": parameter, - }, - default=kwarg_definition.default if kwarg_definition.default is not Empty else attr.NOTHING, - validator=_create_validators(annotation=annotation, kwarg_definition=kwarg_definition), - ) - elif parameter.has_default: - attribute = attr.attrib( - type=annotation, default=parameter.default, metadata={"field_definition": parameter} - ) - else: - attribute = attr.attrib( - type=annotation, - default=None if parameter.is_optional else attr.NOTHING, - metadata={"field_definition": parameter}, - ) - - attributes[parameter.name] = attribute - - model: type[AttrsSignatureModel] = attrs.make_class( - f"{fn_name}_signature_model", - attrs=attributes, - bases=(AttrsSignatureModel,), - slots=True, - kw_only=True, - ) - model.return_annotation = parsed_signature.return_type.annotation # pyright: ignore - model.dependency_name_set = dependency_names # pyright: ignore - model.populate_field_definitions() # pyright: ignore - return model diff --git a/litestar/_signature/models/base.py b/litestar/_signature/models/base.py deleted file mode 100644 index c4ce3eaf64..0000000000 --- a/litestar/_signature/models/base.py +++ /dev/null @@ -1,160 +0,0 @@ -from __future__ import annotations - -from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Any, ClassVar, Literal, Sequence, TypedDict, cast - -from litestar.enums import ScopeType -from litestar.exceptions import InternalServerException, ValidationException -from litestar.params import ParameterKwarg - -if TYPE_CHECKING: - from typing_extensions import NotRequired - - from litestar.connection import ASGIConnection - from litestar.typing import FieldDefinition - from litestar.utils.signature import ParsedSignature - -__all__ = ("SignatureModel",) - - -class ErrorMessage(TypedDict): - # key may not be set in some cases, like when a query param is set but - # doesn't match the required length during `attrs` validation - # in this case, we don't show a key at all as it will be empty - key: NotRequired[str] - message: str - source: NotRequired[Literal["cookie", "body", "header", "query"]] - - -class SignatureModel(ABC): - """Base model for Signature modelling.""" - - dependency_name_set: ClassVar[set[str]] - return_annotation: ClassVar[Any] - fields: ClassVar[dict[str, FieldDefinition]] - - @classmethod - def _create_exception(cls, connection: ASGIConnection, messages: list[ErrorMessage]) -> Exception: - """Create an exception class - either a ValidationException or an InternalServerException, depending on whether - the failure is in client provided values or injected dependencies. - - Args: - connection: An ASGI connection instance. - messages: A list of error messages. - - Returns: - An Exception - """ - method = connection.method if hasattr(connection, "method") else ScopeType.WEBSOCKET # pyright: ignore - if client_errors := [ - err_message - for err_message in messages - if ("key" in err_message and err_message["key"] not in cls.dependency_name_set) or "key" not in err_message - ]: - return ValidationException(detail=f"Validation failed for {method} {connection.url}", extra=client_errors) - return InternalServerException( - detail=f"A dependency failed validation for {method} {connection.url}", extra=messages - ) - - @classmethod - def _build_error_message(cls, keys: Sequence[str], exc_msg: str, connection: ASGIConnection) -> ErrorMessage: - """Build an error message. - - Args: - keys: A list of keys. - exc_msg: A message. - connection: An ASGI connection instance. - - Returns: - An ErrorMessage - """ - - message: ErrorMessage = {"message": exc_msg} - - if len(keys) > 1: - key_start = 0 - - if keys[0] == "data": - key_start = 1 - message["source"] = "body" - - message["key"] = ".".join(keys[key_start:]) - elif keys: - key = keys[0] - message["key"] = key - - if key in connection.query_params: - message["source"] = cast("Literal['cookie', 'body', 'header', 'query']", "query") - - elif key in cls.fields and isinstance(cls.fields[key].kwarg_definition, ParameterKwarg): - if cast(ParameterKwarg, cls.fields[key].kwarg_definition).cookie: - source = "cookie" - elif cast(ParameterKwarg, cls.fields[key].kwarg_definition).header: - source = "header" - else: - source = "query" - message["source"] = cast("Literal['cookie', 'body', 'header', 'query']", source) - - return message - - @classmethod - @abstractmethod - def parse_values_from_connection_kwargs(cls, connection: ASGIConnection, **kwargs: Any) -> dict[str, Any]: - """Extract values from the connection instance and return a dict of parsed values. - - Args: - connection: The ASGI connection instance. - **kwargs: A dictionary of kwargs. - - Raises: - ValidationException: If validation failed. - InternalServerException: If another exception has been raised. - - Returns: - A dictionary of parsed values - """ - raise NotImplementedError - - @abstractmethod - def to_dict(self) -> dict[str, Any]: - """Normalize access to the signature model's dictionary method, because different backends use different methods - for this. - - Returns: A dictionary of string keyed values. - """ - raise NotImplementedError - - @classmethod - @abstractmethod - def populate_field_definitions(cls) -> None: - """Populate the class signature fields. - - Returns: - None. - """ - raise NotImplementedError - - @classmethod - @abstractmethod - def create( - cls, - fn_name: str, - fn_module: str | None, - parsed_signature: ParsedSignature, - dependency_names: set[str], - type_overrides: dict[str, Any], - ) -> type[SignatureModel]: - """Create a SignatureModel. - - Args: - fn_name: Name of the callable. - fn_module: Name of the function's module, if any. - parsed_signature: A parsed signature. - dependency_names: A set of dependency names. - type_overrides: A dictionary of type overrides, either will override a parameter type with a type derived - from a plugin, or set the type to ``Any`` if validation should be skipped for the parameter. - - Returns: - The created SignatureModel. - """ - raise NotImplementedError diff --git a/litestar/_signature/models/pydantic_signature_model.py b/litestar/_signature/models/pydantic_signature_model.py deleted file mode 100644 index 65d6b24d84..0000000000 --- a/litestar/_signature/models/pydantic_signature_model.py +++ /dev/null @@ -1,184 +0,0 @@ -from __future__ import annotations - -from dataclasses import asdict, replace -from typing import TYPE_CHECKING, Any - -from pydantic import BaseConfig, BaseModel, ValidationError, create_model -from pydantic.fields import FieldInfo, ModelField - -from litestar._signature.models.base import ErrorMessage, SignatureModel -from litestar.constants import UNDEFINED_SENTINELS -from litestar.params import DependencyKwarg -from litestar.types import Empty -from litestar.typing import FieldDefinition -from litestar.utils.predicates import is_pydantic_constrained_field - -if TYPE_CHECKING: - from litestar.connection import ASGIConnection - from litestar.utils.signature import ParsedSignature - -__all__ = ("PydanticSignatureModel",) - - -class PydanticSignatureModel(SignatureModel, BaseModel): - """Model that represents a function signature that uses a pydantic specific type or types.""" - - class Config(BaseConfig): - copy_on_model_validation = "none" - arbitrary_types_allowed = True - - @classmethod - def parse_values_from_connection_kwargs(cls, connection: ASGIConnection, **kwargs: Any) -> dict[str, Any]: - """Extract values from the connection instance and return a dict of parsed values. - - Args: - connection: The ASGI connection instance. - **kwargs: A dictionary of kwargs. - - Raises: - ValidationException: If validation failed. - InternalServerException: If another exception has been raised. - - Returns: - A dictionary of parsed values - """ - try: - signature = cls(**kwargs) - except ValidationError as e: - messages = cls._get_error_messages(e, connection) - raise cls._create_exception(messages=messages, connection=connection) from e - - return signature.to_dict() - - def to_dict(self) -> dict[str, Any]: - """Normalize access to the signature model's dictionary method, because different backends use different methods - for this. - - Returns: A dictionary of string keyed values. - """ - return {key: self.__getattribute__(key) for key in self.__fields__} - - @classmethod - def field_definition_from_model_field(cls, model_field: ModelField) -> FieldDefinition: - """Create a FieldDefinition instance from a pydantic ModelField. - - Args: - model_field: A pydantic ModelField instance. - - Returns: - A FieldDefinition - """ - inner_types = ( - tuple(cls.field_definition_from_model_field(sub_field) for sub_field in model_field.sub_fields) - if model_field.sub_fields - else None - ) - - default = model_field.field_info.default if model_field.field_info.default not in UNDEFINED_SENTINELS else Empty - - return FieldDefinition.from_kwarg( - inner_types=inner_types, - default=default, - extra=model_field.field_info.extra or {}, - annotation=model_field.annotation, - name=model_field.name, - ) - - @classmethod - def populate_field_definitions(cls) -> None: - """Populate the class signature fields. - - Returns: - None. - """ - cls.fields = {} - - for field_name, field in cls.__fields__.items(): - field_definition = field.field_info.extra.pop("field_definition") - default = field.field_info.default if field.field_info.default not in UNDEFINED_SENTINELS else Empty - if field_definition.is_optional and default is Empty: - default = None - - cls.fields[field_name] = replace(field_definition, default=default) - - @classmethod - def create( - cls, - fn_name: str, - fn_module: str | None, - parsed_signature: ParsedSignature, - dependency_names: set[str], - type_overrides: dict[str, Any], - ) -> type[PydanticSignatureModel]: - """Create a pydantic based SignatureModel. - - Args: - fn_name: Name of the callable. - fn_module: Name of the function's module, if any. - parsed_signature: A ParsedSignature instance. - dependency_names: A set of dependency names. - type_overrides: A dictionary of type overrides, either will override a parameter type with a type derived - from a plugin, or set the type to ``Any`` if validation should be skipped for the parameter. - - Returns: - The created PydanticSignatureModel. - """ - field_definitions: dict[str, tuple[Any, Any]] = {} - - for parameter in parsed_signature.parameters.values(): - field_type = type_overrides.get(parameter.name, parameter.annotation) - - if kwarg_definition := parameter.kwarg_definition: - if isinstance(kwarg_definition, DependencyKwarg): - field_info = FieldInfo( - default=kwarg_definition.default if kwarg_definition.default is not Empty else None, - kwarg_definition=kwarg_definition, - field_definition=parameter, - ) - if kwarg_definition.skip_validation: - field_type = Any - else: - kwargs: dict[str, Any] = {k: v for k, v in asdict(kwarg_definition).items() if v is not Empty} - - if "pattern" in kwargs: - kwargs["regex"] = kwargs["pattern"] - - field_info = FieldInfo( - **kwargs, - kwarg_definition=kwarg_definition, - field_definition=parameter, - ) - else: - field_info = FieldInfo(default=..., field_definition=parameter) - - if is_pydantic_constrained_field(parameter.default): - field_type = parameter.default - elif parameter.has_default: - field_info.default = parameter.default - elif parameter.is_optional: - field_info.default = None - - field_definitions[parameter.name] = (field_type, field_info) - - model: type[PydanticSignatureModel] = create_model( # type: ignore - f"{fn_name}_signature_model", - __base__=PydanticSignatureModel, - __module__=fn_module or "pydantic.main", - **field_definitions, # pyright: ignore - ) - model.return_annotation = parsed_signature.return_type.annotation - model.dependency_name_set = dependency_names - model.populate_field_definitions() - return model - - @classmethod - def _get_error_messages(cls, e: ValidationError, connection: ASGIConnection) -> list[ErrorMessage]: - """Get error messages from a ValidationError.""" - messages: list[ErrorMessage] = [] - - for exc in e.errors(): - keys = [str(loc) for loc in exc["loc"]] - message = super()._build_error_message(keys=keys, exc_msg=exc["msg"], connection=connection) - messages.append(message) - - return messages diff --git a/litestar/_signature/utils.py b/litestar/_signature/utils.py index c4cafe97f9..34eab43d7d 100644 --- a/litestar/_signature/utils.py +++ b/litestar/_signature/utils.py @@ -1,151 +1,66 @@ from __future__ import annotations -from inspect import getmembers, isclass -from typing import TYPE_CHECKING, Any, Literal, cast +from typing import TYPE_CHECKING, Any, cast -import pydantic - -from litestar._signature.models.pydantic_signature_model import PydanticSignatureModel from litestar.constants import SKIP_VALIDATION_NAMES from litestar.exceptions import ImproperlyConfiguredException from litestar.params import DependencyKwarg -from litestar.types import AnyCallable, Empty -from litestar.utils.helpers import unwrap_partial -from litestar.utils.predicates import is_attrs_class - -pydantic_types: tuple[Any, ...] = tuple( - cls for _, cls in getmembers(pydantic.types, isclass) if "pydantic.types" in repr(cls) -) - +from litestar.types import Empty if TYPE_CHECKING: - from typing_extensions import TypeAlias - - from litestar._signature.models.base import SignatureModel - from litestar.typing import FieldDefinition from litestar.utils.signature import ParsedSignature -__all__ = ( - "create_signature_model", - "get_signature_model", -) - - -SignatureModelType: TypeAlias = "type[SignatureModel]" - - -def create_signature_model( - dependency_name_set: set[str], - fn: AnyCallable, - preferred_validation_backend: Literal["pydantic", "attrs"], - parsed_signature: ParsedSignature, - has_data_dto: bool = False, -) -> type[SignatureModel]: - """Create a model for a callable's signature. The model can than be used to parse and validate before passing it to - the callable. - - Args: - dependency_name_set: A set of dependency names - fn: A callable. - preferred_validation_backend: Validation/Parsing backend to prefer, if installed - parsed_signature: A parsed signature for the handler/dependency function. - has_data_dto: Is a data DTO defined for the handler? - - Returns: - A signature model. - """ - - unwrapped_fn = cast("AnyCallable", unwrap_partial(fn)) - fn_name = getattr(fn, "__name__", "anonymous") - fn_module = getattr(fn, "__module__", None) - - if fn_name == "": - fn_name = "anonymous" - - dependency_names = _validate_dependencies( - dependency_name_set=dependency_name_set, fn=unwrapped_fn, parsed_signature=parsed_signature - ) + from .model import SignatureModel - model_class = _get_signature_model_type( - preferred_validation_backend=preferred_validation_backend, parsed_signature=parsed_signature - ) - - type_overrides = _create_type_overrides(parsed_signature, has_data_dto) - - return model_class.create( - fn_name=fn_name, - fn_module=fn_module, - parsed_signature=parsed_signature, - dependency_names={*dependency_name_set, *dependency_names}, - type_overrides=type_overrides, - ) +__all__ = ("create_type_overrides", "validate_signature_dependencies", "get_signature_model") def get_signature_model(value: Any) -> type[SignatureModel]: """Retrieve and validate the signature model from a provider or handler.""" try: - return cast(SignatureModelType, value.signature_model) + return cast("type[SignatureModel]", value.signature_model) except AttributeError as e: # pragma: no cover raise ImproperlyConfiguredException(f"The 'signature_model' attribute for {value} is not set") from e -def _any_attrs_annotation(parsed_signature: ParsedSignature) -> bool: - return any( - any(is_attrs_class(t.annotation) for t in field_definition.inner_types) - or is_attrs_class(field_definition.annotation) - for field_definition in parsed_signature.parameters.values() - ) +def create_type_overrides(parsed_signature: ParsedSignature, has_data_dto: bool) -> dict[str, Any]: + """Create typing overrides for field definitions. + Args: + parsed_signature: A parsed function signature. + has_data_dto: Whether the signature contains a data DTO. -def _create_type_overrides(parsed_signature: ParsedSignature, has_data_dto: bool) -> dict[str, Any]: + Returns: + A dictionary of typing overrides + """ type_overrides = {} - for parameter in parsed_signature.parameters.values(): - if _should_skip_validation(parameter): - type_overrides[parameter.name] = Any + for field_definition in parsed_signature.parameters.values(): + if field_definition.name in SKIP_VALIDATION_NAMES or ( + isinstance(field_definition.kwarg_definition, DependencyKwarg) + and field_definition.kwarg_definition.skip_validation + ): + type_overrides[field_definition.name] = Any + if has_data_dto and "data" in parsed_signature.parameters: type_overrides["data"] = Any - return type_overrides - -def _get_signature_model_type( - preferred_validation_backend: Literal["pydantic", "attrs"], - parsed_signature: ParsedSignature, -) -> type[SignatureModel]: - if preferred_validation_backend == "attrs" or _any_attrs_annotation(parsed_signature): - from litestar._signature.models.attrs_signature_model import AttrsSignatureModel - - return AttrsSignatureModel - return PydanticSignatureModel - - -def _should_skip_validation(field_definition: FieldDefinition) -> bool: - """Whether the parameter should skip validation. - - Returns: - A boolean indicating whether the parameter should be validated. - """ - return field_definition.name in SKIP_VALIDATION_NAMES or ( - isinstance(field_definition.kwarg_definition, DependencyKwarg) - and field_definition.kwarg_definition.skip_validation - ) + return type_overrides -def _validate_dependencies( - dependency_name_set: set[str], fn: AnyCallable, parsed_signature: ParsedSignature +def validate_signature_dependencies( + dependency_name_set: set[str], fn_name: str, parsed_signature: ParsedSignature ) -> set[str]: """Validate dependencies of ``parsed_signature``. Args: dependency_name_set: A set of dependency names - fn: A callable. + fn_name: A callable's name. parsed_signature: A parsed signature. Returns: A set of validated dependency names. """ - fn_name = getattr(fn, "__name__", "anonymous") - - dependency_names: set[str] = set() + dependency_names: set[str] = set(dependency_name_set) for parameter in parsed_signature.parameters.values(): if isinstance(parameter.kwarg_definition, DependencyKwarg) and parameter.name not in dependency_name_set: diff --git a/litestar/app.py b/litestar/app.py index e973c40a02..2a163893ee 100644 --- a/litestar/app.py +++ b/litestar/app.py @@ -8,7 +8,7 @@ from functools import partial from itertools import chain from pathlib import Path -from typing import TYPE_CHECKING, Any, AsyncGenerator, Callable, Literal, Mapping, Sequence, TypedDict, cast +from typing import TYPE_CHECKING, Any, AsyncGenerator, Callable, Mapping, Sequence, TypedDict, cast from litestar._asgi import ASGIRouter from litestar._asgi.utils import get_route_handlers, wrap_in_exception_handler @@ -39,7 +39,7 @@ from litestar.stores.registry import StoreRegistry from litestar.types import Empty from litestar.types.internal_types import PathParameterDefinition -from litestar.utils import AsyncCallable, join_paths, unique, warn_deprecation +from litestar.utils import AsyncCallable, join_paths, unique from litestar.utils.dataclass import extract_dataclass_items from litestar.utils.predicates import is_async_callable from litestar.utils.warnings import warn_pdb_on_exception @@ -129,7 +129,6 @@ class Litestar(Router): "_lifespan_managers", "_debug", "_openapi_schema", - "_preferred_validation_backend", "after_exception", "allowed_hosts", "asgi_handler", @@ -209,7 +208,6 @@ def __init__( websocket_class: type[WebSocket] | None = None, lifespan: list[Callable[[Litestar], AbstractAsyncContextManager] | AbstractAsyncContextManager] | None = None, pdb_on_exception: bool | None = None, - _preferred_validation_backend: Literal["attrs", "pydantic"] | None = None, ) -> None: """Initialize a ``Litestar`` application. @@ -305,13 +303,6 @@ def __init__( if pdb_on_exception is None: pdb_on_exception = os.getenv("LITESTAR_PDB", "0") == "1" - if _preferred_validation_backend is not None: - warn_deprecation( - version="2.0.0beta1", - kind="parameter", - deprecated_name="_preferred_validation_backend", - ) - config = AppConfig( after_exception=list(after_exception or []), after_request=after_request, @@ -388,7 +379,6 @@ def __init__( self.on_startup = config.on_startup self.openapi_config = config.openapi_config self.openapi_schema_plugins = [p for p in config.plugins if isinstance(p, OpenAPISchemaPluginProtocol)] - self._preferred_validation_backend: Literal["attrs", "pydantic"] = _preferred_validation_backend or "pydantic" self.request_class = config.request_class or Request self.response_cache_config = config.response_cache_config self.serialization_plugins = [p for p in config.plugins if isinstance(p, SerializationPluginProtocol)] diff --git a/litestar/handlers/base.py b/litestar/handlers/base.py index 85a71444c1..29cdc2c45e 100644 --- a/litestar/handlers/base.py +++ b/litestar/handlers/base.py @@ -3,7 +3,7 @@ from copy import copy from typing import TYPE_CHECKING, Any, Mapping, Sequence, cast -from litestar._signature import create_signature_model +from litestar._signature import SignatureModel from litestar.di import Provide from litestar.dto.interface import HandlerContext from litestar.exceptions import ImproperlyConfiguredException @@ -18,7 +18,6 @@ from typing_extensions import Self from litestar import Litestar - from litestar._signature.models import SignatureModel from litestar.connection import ASGIConnection from litestar.controller import Controller from litestar.dto.interface import DTOInterface @@ -421,10 +420,9 @@ def _set_runtime_callables(self) -> None: def _create_signature_model(self, app: Litestar) -> None: """Create signature model for handler function.""" if not self.signature_model: - self.signature_model = create_signature_model( + self.signature_model = SignatureModel.create( dependency_name_set=self.dependency_name_set, fn=cast("AnyCallable", self.fn.value), - preferred_validation_backend=app._preferred_validation_backend, parsed_signature=self.parsed_fn_signature, has_data_dto=bool(self.resolve_dto()), ) @@ -433,10 +431,9 @@ def _create_provider_signature_models(self, app: Litestar) -> None: """Create signature models for dependency providers.""" for provider in self.resolve_dependencies().values(): if not getattr(provider, "signature_model", None): - provider.signature_model = create_signature_model( + provider.signature_model = SignatureModel.create( dependency_name_set=self.dependency_name_set, fn=provider.dependency.value, - preferred_validation_backend=app._preferred_validation_backend, parsed_signature=ParsedSignature.from_fn( unwrap_partial(provider.dependency.value), self.resolve_signature_namespace() ), diff --git a/litestar/handlers/websocket_handlers/listener.py b/litestar/handlers/websocket_handlers/listener.py index 9f05ad3c0f..937f51fe72 100644 --- a/litestar/handlers/websocket_handlers/listener.py +++ b/litestar/handlers/websocket_handlers/listener.py @@ -17,7 +17,7 @@ from msgspec.json import Encoder as JsonEncoder -from litestar._signature import create_signature_model +from litestar._signature import SignatureModel from litestar.connection import WebSocket from litestar.dto.interface import HandlerContext from litestar.exceptions import ImproperlyConfiguredException, WebSocketDisconnect @@ -290,10 +290,9 @@ def _create_signature_model(self, app: Litestar) -> None: new_signature = create_handler_signature( self._listener_context.listener_callback_signature.original_signature ) - self.signature_model = create_signature_model( + self.signature_model = SignatureModel.create( dependency_name_set=self.dependency_name_set, fn=cast("AnyCallable", self.fn.value), - preferred_validation_backend=app._preferred_validation_backend, parsed_signature=ParsedSignature.from_signature(new_signature, self.resolve_signature_namespace()), ) diff --git a/litestar/serialization.py b/litestar/serialization.py index 1bdfd2624e..3f57f56980 100644 --- a/litestar/serialization.py +++ b/litestar/serialization.py @@ -14,9 +14,12 @@ from pathlib import Path, PurePath from re import Pattern from typing import TYPE_CHECKING, Any, Callable, Mapping, TypeVar, overload +from uuid import UUID import msgspec +from msgspec import ValidationError from pydantic import ( + UUID1, BaseModel, ByteSize, ConstrainedBytes, @@ -65,7 +68,7 @@ def _enc_constrained_date(date: ConstrainedDate) -> str: return date.isoformat() -def _enc_pattern(pattern: Pattern) -> Any: +def _enc_pattern(pattern: Pattern[str]) -> Any: return pattern.pattern @@ -125,6 +128,35 @@ def default_serializer(value: Any, type_encoders: Mapping[Any, Callable[[Any], A raise TypeError(f"Unsupported type: {type(value)!r}") +PydanticUUIDType = TypeVar("PydanticUUIDType", bound="UUID1") + + +def _dec_pydantic_uuid(type_: type[PydanticUUIDType], val: Any) -> PydanticUUIDType: + if isinstance(val, str): + val = type_(val) + elif isinstance(val, (bytes, bytearray)): + try: + val = type_(val.decode()) + except ValueError: + # 16 bytes in big-endian order as the bytes argument fail + # the above check + val = type_(bytes=val) + elif isinstance(val, UUID): + val = type_(str(val)) + + if not isinstance(val, type_): + raise ValidationError(f"Invalid UUID: {val!r}") + + if type_._required_version != val.version: # type:ignore[attr-defined] + raise ValidationError(f"Invalid UUID version: {val!r}") + + return val + + +def _dec_pydantic(type_: type[BaseModel], value: Any) -> BaseModel: + return type_.parse_obj(value) + + def dec_hook(type_: Any, value: Any) -> Any: # pragma: no cover """Transform values non-natively supported by ``msgspec`` @@ -135,9 +167,16 @@ def dec_hook(type_: Any, value: Any) -> Any: # pragma: no cover Returns: A ``msgspec``-supported type """ + + from litestar.datastructures.state import ImmutableState + + if issubclass(type_, UUID1): + return _dec_pydantic_uuid(type_, value) + if isinstance(value, type_): + return value if issubclass(type_, BaseModel): - return type_.parse_obj(value) - if issubclass(type_, (Path, PurePath)): + return _dec_pydantic(type_, value) + if issubclass(type_, (Path, PurePath, ImmutableState, UUID)): return type_(value) raise TypeError(f"Unsupported type: {type(value)!r}") diff --git a/litestar/testing/helpers.py b/litestar/testing/helpers.py index d087cd90f5..10a5438545 100644 --- a/litestar/testing/helpers.py +++ b/litestar/testing/helpers.py @@ -107,7 +107,6 @@ def create_test_client( timeout: float | None = None, type_encoders: TypeEncodersMap | None = None, websocket_class: type[WebSocket] | None = None, - _preferred_validation_backend: Literal["pydantic", "attrs"] | None = None, ) -> TestClient[Litestar]: """Create a Litestar app instance and initializes it. @@ -281,7 +280,6 @@ def test_my_handler() -> None: template_config=template_config, type_encoders=type_encoders, websocket_class=websocket_class, - _preferred_validation_backend=_preferred_validation_backend, ) return TestClient[Litestar]( @@ -351,7 +349,6 @@ def create_async_test_client( timeout: float | None = None, type_encoders: TypeEncodersMap | None = None, websocket_class: type[WebSocket] | None = None, - _preferred_validation_backend: Literal["pydantic", "attrs"] | None = None, ) -> AsyncTestClient[Litestar]: """Create a Litestar app instance and initializes it. @@ -525,7 +522,6 @@ def test_my_handler() -> None: template_config=template_config, type_encoders=type_encoders, websocket_class=websocket_class, - _preferred_validation_backend=_preferred_validation_backend, ) return AsyncTestClient[Litestar]( diff --git a/litestar/utils/typing.py b/litestar/utils/typing.py index 93261da003..b3802c2b4e 100644 --- a/litestar/utils/typing.py +++ b/litestar/utils/typing.py @@ -37,6 +37,7 @@ "instantiable_type_mapping", "make_non_optional_union", "unwrap_annotation", + "unwrap_union", ) diff --git a/poetry.lock b/poetry.lock index f7c1a0c178..5b0d6363ea 100644 --- a/poetry.lock +++ b/poetry.lock @@ -540,31 +540,6 @@ files = [ {file = "cachetools-5.3.1.tar.gz", hash = "sha256:dce83f2d9b4e1f732a8cd44af8e8fab2dbe46201467fc98b3ef8f269092bf62b"}, ] -[[package]] -name = "cattrs" -version = "23.1.2" -description = "Composable complex class support for attrs and dataclasses." -optional = false -python-versions = ">=3.7" -files = [ - {file = "cattrs-23.1.2-py3-none-any.whl", hash = "sha256:b2bb14311ac17bed0d58785e5a60f022e5431aca3932e3fc5cc8ed8639de50a4"}, - {file = "cattrs-23.1.2.tar.gz", hash = "sha256:db1c821b8c537382b2c7c66678c3790091ca0275ac486c76f3c8f3920e83c657"}, -] - -[package.dependencies] -attrs = ">=20" -exceptiongroup = {version = "*", markers = "python_version < \"3.11\""} -typing_extensions = {version = ">=4.1.0", markers = "python_version < \"3.11\""} - -[package.extras] -bson = ["pymongo (>=4.2.0,<5.0.0)"] -cbor2 = ["cbor2 (>=5.4.6,<6.0.0)"] -msgpack = ["msgpack (>=1.0.2,<2.0.0)"] -orjson = ["orjson (>=3.5.2,<4.0.0)"] -pyyaml = ["PyYAML (>=6.0,<7.0)"] -tomlkit = ["tomlkit (>=0.11.4,<0.12.0)"] -ujson = ["ujson (>=5.4.0,<6.0.0)"] - [[package]] name = "certifi" version = "2023.5.7" @@ -4309,11 +4284,11 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [extras] annotated-types = ["annotated-types"] -attrs = ["attrs", "cattrs", "python-dateutil", "pytimeparse"] +attrs = ["attrs"] brotli = ["brotli"] cli = ["click", "jsbeautifier", "rich", "rich-click", "uvicorn"] cryptography = ["cryptography"] -full = ["alembic", "attrs", "brotli", "cattrs", "click", "cryptography", "jinja2", "jsbeautifier", "opentelemetry-instrumentation-asgi", "prometheus-client", "python-dateutil", "python-jose", "pytimeparse", "redis", "rich", "sqlalchemy", "structlog", "uvicorn"] +full = ["alembic", "attrs", "brotli", "click", "cryptography", "jinja2", "jsbeautifier", "opentelemetry-instrumentation-asgi", "prometheus-client", "python-dateutil", "python-jose", "pytimeparse", "redis", "rich", "sqlalchemy", "structlog", "uvicorn"] jinja = ["jinja2"] jwt = ["cryptography", "python-jose"] opentelemetry = ["opentelemetry-instrumentation-asgi"] @@ -4328,4 +4303,4 @@ tortoise-orm = ["tortoise-orm"] [metadata] lock-version = "2.0" python-versions = ">=3.8,<4.0" -content-hash = "70f5c93f8dda8cc11bf8c1939968fe289983304abe84a221d88d3316e57f7104" +content-hash = "850899772834e4726e6f195e57c6d43a84bb1a324f8b0f801df8fbe86ef73893" diff --git a/pyproject.toml b/pyproject.toml index 1b88b68baa..13b665cd7a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -79,7 +79,6 @@ annotated-types = { version = "*", optional = true } anyio = ">=3" attrs = { version = "*", optional = true } brotli = { version = "*", optional = true } -cattrs = { version = "*", optional = true } click = { version = "*", optional = true } cryptography = { version = "*", optional = true } fast-query-parsers = "*" @@ -89,7 +88,7 @@ importlib-resources = { version = ">=5.12.0", python = "<3.9" } jinja2 = { version = ">=3.1.2", optional = true } jsbeautifier = { version = "*", optional = true } mako = { version = ">=1.2.4", optional = true } -msgspec = ">=0.16.0" +msgspec = ">=0.17.0" multidict = ">=6.0.2" opentelemetry-instrumentation-asgi = { version = "*", optional = true } picologging = { version = "*", optional = true } @@ -119,7 +118,6 @@ attrs = "*" beanie = "*" beautifulsoup4 = "*" brotli = "*" -cattrs = "*" click = "*" cryptography = "*" duckdb-engine = "*" @@ -197,7 +195,7 @@ types-redis = "*" [tool.poetry.extras] annotated-types = ["annotated-types"] -attrs = ["attrs", "cattrs", "python-dateutil", "pytimeparse"] +attrs = ["attrs"] brotli = ["brotli"] cli = ["click", "rich", "rich-click", "jsbeautifier", "uvicorn"] cryptography = ["cryptography"] @@ -216,7 +214,6 @@ full = [ "alembic", "attrs", "brotli", - "cattrs", "click", "cryptography", "jinja2", diff --git a/tests/e2e/test_routing/test_path_resolution.py b/tests/e2e/test_routing/test_path_resolution.py index d01df7b7b8..90bf7556e3 100644 --- a/tests/e2e/test_routing/test_path_resolution.py +++ b/tests/e2e/test_routing/test_path_resolution.py @@ -245,9 +245,9 @@ def test_support_for_path_type_parameters() -> None: def lower_handler(string_param: str) -> str: return string_param - @get(path="/{string_param:str}/{parth_param:path}") - def upper_handler(string_param: str, parth_param: Path) -> str: - return string_param + str(parth_param) + @get(path="/{string_param:str}/{path_param:path}") + def upper_handler(string_param: str, path_param: Path) -> str: + return string_param + str(path_param) with create_test_client([lower_handler, upper_handler]) as client: response = client.get("/abc") diff --git a/tests/unit/test_app.py b/tests/unit/test_app.py index 60c7ed4530..e3f4125338 100644 --- a/tests/unit/test_app.py +++ b/tests/unit/test_app.py @@ -139,7 +139,7 @@ def test_app_params_defined_on_app_config_object() -> None: litestar_signature = inspect.signature(Litestar) app_config_fields = {f.name for f in fields(AppConfig)} for name in litestar_signature.parameters: - if name in {"on_app_init", "initial_state", "_preferred_validation_backend"}: + if name in {"on_app_init", "initial_state"}: continue assert name in app_config_fields # ensure there are not fields defined on AppConfig that aren't in the Litestar signature diff --git a/tests/unit/test_dto/test_attrs_fail.py b/tests/unit/test_dto/test_attrs_fail.py deleted file mode 100644 index 817c96016e..0000000000 --- a/tests/unit/test_dto/test_attrs_fail.py +++ /dev/null @@ -1,21 +0,0 @@ -from __future__ import annotations - -import pytest - -from litestar import post -from litestar.testing import create_test_client - -from . import MockDTO, Model - - -@pytest.mark.xfail -def test_dto_defined_on_handler() -> None: - @post(dto=MockDTO) - def handler(data: Model) -> Model: - assert data == Model(a=1, b="2") - return data - - with create_test_client(route_handlers=handler, _preferred_validation_backend="attrs") as client: - response = client.post("/", json={"what": "ever"}) - assert response.status_code == 201 - assert response.json() == {"a": 1, "b": "2"} diff --git a/tests/unit/test_dto/test_factory/test_integration.py b/tests/unit/test_dto/test_factory/test_integration.py index e556c6d77b..7a736bafee 100644 --- a/tests/unit/test_dto/test_factory/test_integration.py +++ b/tests/unit/test_dto/test_factory/test_integration.py @@ -78,7 +78,7 @@ def handler(data: Foo) -> Foo: assert data.bar == "hello" return data - with create_test_client(route_handlers=[handler]) as client: + with create_test_client(route_handlers=[handler], debug=True) as client: response = client.post("/", json={"baz": "hello"}) assert response.json() == {"baz": "hello"} diff --git a/tests/unit/test_dto/test_integration.py b/tests/unit/test_dto/test_integration.py index 0c98fbac11..a398e95a5d 100644 --- a/tests/unit/test_dto/test_integration.py +++ b/tests/unit/test_dto/test_integration.py @@ -15,7 +15,7 @@ def handler(data: Model) -> Model: assert data == Model(a=1, b="2") return data - with create_test_client(route_handlers=handler, _preferred_validation_backend="pydantic") as client: + with create_test_client(route_handlers=handler) as client: response = client.post("/", json={"what": "ever"}) assert response.status_code == 201 assert response.json() == {"a": 1, "b": "2"} diff --git a/tests/unit/test_handlers/test_http_handlers/test_to_response.py b/tests/unit/test_handlers/test_http_handlers/test_to_response.py index 694d4aad7d..c4ab4f3bf3 100644 --- a/tests/unit/test_handlers/test_http_handlers/test_to_response.py +++ b/tests/unit/test_handlers/test_http_handlers/test_to_response.py @@ -10,7 +10,7 @@ from starlette.responses import Response as StarletteResponse from litestar import HttpMethod, Litestar, MediaType, Request, Response, get, route -from litestar._signature import create_signature_model +from litestar._signature import SignatureModel from litestar.background_tasks import BackgroundTask from litestar.contrib.jinja import JinjaTemplateEngine from litestar.datastructures import Cookie, ResponseHeader @@ -90,9 +90,8 @@ async def test_function(data: Person) -> Person: return data person_instance = PersonFactory.build() - test_function.signature_model = create_signature_model( + test_function.signature_model = SignatureModel.create( fn=test_function.fn.value, - preferred_validation_backend="pydantic", dependency_name_set=set(), parsed_signature=ParsedSignature.from_fn(test_function.fn.value, {}), ) diff --git a/tests/unit/test_openapi/test_parameters.py b/tests/unit/test_openapi/test_parameters.py index d7a953c421..3d8a6664af 100644 --- a/tests/unit/test_openapi/test_parameters.py +++ b/tests/unit/test_openapi/test_parameters.py @@ -7,7 +7,7 @@ from litestar._openapi.parameters import create_parameter_for_handler from litestar._openapi.schema_generation import SchemaCreator from litestar._openapi.typescript_converter.schema_parsing import is_schema_value -from litestar._signature import create_signature_model +from litestar._signature import SignatureModel from litestar.di import Provide from litestar.enums import ParamType from litestar.exceptions import ImproperlyConfiguredException @@ -28,10 +28,9 @@ def _create_parameters(app: Litestar, path: str) -> List["OpenAPIParameter"]: handler = route_handler.fn.value assert callable(handler) - handler_fields = create_signature_model( + handler_fields = SignatureModel.create( fn=handler, dependency_name_set=set(), - preferred_validation_backend=app._preferred_validation_backend, parsed_signature=route_handler.parsed_fn_signature, ).fields diff --git a/tests/unit/test_openapi/test_schema.py b/tests/unit/test_openapi/test_schema.py index 99a182ad4f..b693bd74e5 100644 --- a/tests/unit/test_openapi/test_schema.py +++ b/tests/unit/test_openapi/test_schema.py @@ -15,7 +15,6 @@ SchemaCreator, create_schema_for_annotation, ) -from litestar._signature.models.pydantic_signature_model import PydanticSignatureModel from litestar.app import DEFAULT_OPENAPI_CONFIG from litestar.di import Provide from litestar.enums import ParamType @@ -110,9 +109,7 @@ class Opts(str, Enum): class M(BaseModel): opt: Opts - schema = create_schema_for_annotation( - annotation=PydanticSignatureModel.field_definition_from_model_field(M.__fields__["opt"]).annotation - ) + schema = create_schema_for_annotation(annotation=M.__annotations__["opt"]) assert schema assert schema.enum == ["opt1", "opt2"] diff --git a/tests/unit/test_params.py b/tests/unit/test_params.py index 2b7cb20049..a8fb2722ec 100644 --- a/tests/unit/test_params.py +++ b/tests/unit/test_params.py @@ -11,13 +11,12 @@ from litestar.testing import create_test_client -@pytest.mark.parametrize("backend", ("pydantic", "attrs")) -def test_parsing_of_parameter_as_annotated(backend: Any) -> None: +def test_parsing_of_parameter_as_annotated() -> None: @get(path="/") def handler(param: Annotated[str, Parameter(min_length=1)]) -> str: return param - with create_test_client(handler, _preferred_validation_backend=backend) as client: + with create_test_client(handler) as client: response = client.get("/") assert response.status_code == HTTP_400_BAD_REQUEST @@ -25,13 +24,12 @@ def handler(param: Annotated[str, Parameter(min_length=1)]) -> str: assert response.status_code == HTTP_200_OK -@pytest.mark.parametrize("backend", ("pydantic", "attrs")) -def test_parsing_of_parameter_as_default(backend: Any) -> None: +def test_parsing_of_parameter_as_default() -> None: @get(path="/") def handler(param: str = Parameter(min_length=1)) -> str: return param - with create_test_client(handler, _preferred_validation_backend=backend) as client: + with create_test_client(handler) as client: response = client.get("/?param=") assert response.status_code == HTTP_400_BAD_REQUEST @@ -39,13 +37,12 @@ def handler(param: str = Parameter(min_length=1)) -> str: assert response.status_code == HTTP_200_OK -@pytest.mark.parametrize("backend", ("pydantic", "attrs")) -def test_parsing_of_body_as_annotated(backend: Any) -> None: +def test_parsing_of_body_as_annotated() -> None: @post(path="/") def handler(data: Annotated[List[str], Body(min_items=1)]) -> List[str]: return data - with create_test_client(handler, _preferred_validation_backend=backend) as client: + with create_test_client(handler) as client: response = client.post("/", json=[]) assert response.status_code == HTTP_400_BAD_REQUEST @@ -53,13 +50,12 @@ def handler(data: Annotated[List[str], Body(min_items=1)]) -> List[str]: assert response.status_code == HTTP_201_CREATED -@pytest.mark.parametrize("backend", ("pydantic", "attrs")) -def test_parsing_of_body_as_default(backend: Any) -> None: +def test_parsing_of_body_as_default() -> None: @post(path="/") def handler(data: List[str] = Body(min_items=1)) -> List[str]: return data - with create_test_client(handler, _preferred_validation_backend=backend) as client: + with create_test_client(handler) as client: response = client.post("/", json=[]) assert response.status_code == HTTP_400_BAD_REQUEST @@ -67,24 +63,22 @@ def handler(data: List[str] = Body(min_items=1)) -> List[str]: assert response.status_code == HTTP_201_CREATED -@pytest.mark.parametrize("backend", ("pydantic", "attrs")) -def test_parsing_of_dependency_as_annotated(backend: Any) -> None: +def test_parsing_of_dependency_as_annotated() -> None: @get(path="/", dependencies={"dep": Provide(lambda: None, sync_to_thread=False)}) def handler(dep: Annotated[int, Dependency(skip_validation=True)]) -> int: return dep - with create_test_client(handler, _preferred_validation_backend=backend) as client: + with create_test_client(handler) as client: response = client.get("/") assert response.text == "null" -@pytest.mark.parametrize("backend", ("pydantic", "attrs")) -def test_parsing_of_dependency_as_default(backend: Any) -> None: +def test_parsing_of_dependency_as_default() -> None: @get(path="/", dependencies={"dep": Provide(lambda: None, sync_to_thread=False)}) def handler(dep: int = Dependency(skip_validation=True)) -> int: return dep - with create_test_client(handler, _preferred_validation_backend=backend) as client: + with create_test_client(handler) as client: response = client.get("/") assert response.text == "null" @@ -211,14 +205,13 @@ def get_seq(seq: List[str]) -> List[str]: assert resp.json() == ["a", "b", "c"] -@pytest.mark.parametrize("backend", ("pydantic", "attrs")) -def test_regex_validation(backend: Any) -> None: +def test_regex_validation() -> None: # https://github.com/litestar-org/litestar/issues/1860 @get(path="/val_regex", media_type=MediaType.TEXT) async def regex_val(text: Annotated[str, Parameter(title="a or b", pattern="[a|b]")]) -> str: return f"str: {text}" - with create_test_client(route_handlers=[regex_val], _preferred_validation_backend=backend) as client: + with create_test_client(route_handlers=[regex_val]) as client: for letter in ("a", "b"): response = client.get(f"/val_regex?text={letter}") assert response.status_code == HTTP_200_OK diff --git a/tests/unit/test_signature/test_attrs_signature_modelling.py b/tests/unit/test_signature/test_attrs_signature_modelling.py deleted file mode 100644 index 34604e40b2..0000000000 --- a/tests/unit/test_signature/test_attrs_signature_modelling.py +++ /dev/null @@ -1,129 +0,0 @@ -from datetime import date, datetime, time, timedelta, timezone -from typing import Any - -import pytest - -from litestar._signature.models.attrs_signature_model import _converter -from tests import Person, PersonFactory - -now = datetime.now(tz=timezone.utc) -today = now.date() -time_now = time(hour=now.hour, minute=now.minute, second=now.second, microsecond=now.microsecond) -one_minute = timedelta(minutes=1) -person = PersonFactory.build() - - -@pytest.mark.parametrize( - "value,expected", - ( - ("1", True), - (b"1", True), - ("True", True), - (b"True", True), - ("on", True), - (b"on", True), - ("t", True), - (b"t", True), - ("true", True), - (b"true", True), - ("y", True), - (b"y", True), - ("yes", True), - (b"yes", True), - (1, True), - (True, True), - ("0", False), - (b"0", False), - ("False", False), - (b"False", False), - ("f", False), - (b"f", False), - ("false", False), - (b"false", False), - ("n", False), - (b"n", False), - ("no", False), - (b"no", False), - ("off", False), - (b"off", False), - (0, False), - (False, False), - ), -) -def test_cattrs_converter_structure_bool(value: Any, expected: Any) -> None: - result = _converter.structure(value, bool) - assert result == expected - - -def test_cattrs_converter_structure_bool_value_error() -> None: - with pytest.raises(ValueError): - _converter.structure(None, bool) - _converter.structure("foofoofoo", bool) - _converter.structure(object(), bool) - _converter.structure(type, bool) - _converter.structure({}, bool) - _converter.structure([], bool) - - -@pytest.mark.parametrize( - "value,cls,expected", - ( - (now, datetime, now.isoformat()), - (now.isoformat(), datetime, now.isoformat()), - ), -) -def test_cattrs_converter_structure_datetime(value: Any, cls: Any, expected: Any) -> None: - result = _converter.structure(value, cls).isoformat() - assert result == expected - - -@pytest.mark.parametrize( - "value,cls,expected", - ( - (now, date, today.isoformat()), - (now.isoformat(), date, today.isoformat()), - (now.timestamp(), date, today.isoformat()), - (today, date, today.isoformat()), - (today.isoformat(), date, today.isoformat()), - ), -) -def test_cattrs_converter_structure_date(value: Any, cls: Any, expected: Any) -> None: - result = _converter.structure(value, cls).isoformat() - assert result == expected - - -@pytest.mark.parametrize( - "value,cls,expected", - ( - (time_now, time, time_now.isoformat()), - (time_now.isoformat(), time, time_now.isoformat()), - ), -) -def test_cattrs_converter_structure_time(value: Any, cls: Any, expected: Any) -> None: - result = _converter.structure(value, cls).isoformat() - assert result == expected - - -@pytest.mark.parametrize( - "value,cls,expected", - ( - (one_minute, timedelta, one_minute.total_seconds()), - (one_minute.total_seconds(), timedelta, one_minute.total_seconds()), - ("1 minute", timedelta, one_minute.total_seconds()), - ), -) -def test_cattrs_converter_structure_timedelta(value: Any, cls: Any, expected: Any) -> None: - result = _converter.structure(value, cls).total_seconds() - assert result == expected - - -@pytest.mark.parametrize( - "value,cls,expected", - ( - (person, Person, person.dict()), - (person.dict(), Person, person.dict()), - ), -) -def test_cattrs_converter_structure_pydantic(value: Any, cls: Any, expected: Any) -> None: - result = _converter.structure(value, cls).dict() - assert result == expected diff --git a/tests/unit/test_signature/test_parsing.py b/tests/unit/test_signature/test_parsing.py index babcbe152a..bc29064be2 100644 --- a/tests/unit/test_signature/test_parsing.py +++ b/tests/unit/test_signature/test_parsing.py @@ -1,73 +1,45 @@ -from dataclasses import dataclass from types import ModuleType -from typing import Any, Callable, Iterable, List, Literal, Optional, Sequence +from typing import Any, Callable, Iterable, List, Optional, Sequence, Union from unittest.mock import MagicMock import pytest -from attr import define from pydantic import BaseModel -from typing_extensions import TypedDict +from typing_extensions import Annotated -from litestar import get, post -from litestar._signature import create_signature_model -from litestar.di import Provide -from litestar.exceptions import ImproperlyConfiguredException, ValidationException -from litestar.params import Dependency, Parameter -from litestar.status_codes import HTTP_200_OK, HTTP_204_NO_CONTENT, HTTP_400_BAD_REQUEST, HTTP_500_INTERNAL_SERVER_ERROR +from litestar import get +from litestar._signature import SignatureModel +from litestar.params import Body, Parameter +from litestar.status_codes import HTTP_200_OK, HTTP_204_NO_CONTENT from litestar.testing import RequestFactory, TestClient, create_test_client +from litestar.types import Empty from litestar.types.helper_types import OptionalSequence from litestar.utils.signature import ParsedSignature -@pytest.mark.parametrize("preferred_validation_backend", ("attrs", "pydantic")) -def test_parses_values_from_connection_kwargs_without_plugin( - preferred_validation_backend: Literal["attrs", "pydantic"] -) -> None: +def test_parses_values_from_connection_kwargs_without_plugin() -> None: class MyModel(BaseModel): name: str def fn(a: MyModel) -> None: pass - model = create_signature_model( + model = SignatureModel.create( fn=fn, dependency_name_set=set(), - preferred_validation_backend=preferred_validation_backend, parsed_signature=ParsedSignature.from_fn(fn, {}), ) result = model.parse_values_from_connection_kwargs(connection=RequestFactory().get(), a={"name": "my name"}) assert result == {"a": MyModel(name="my name")} -@pytest.mark.parametrize("preferred_validation_backend", ("attrs", "pydantic")) -def test_parses_values_from_connection_kwargs_raises( - preferred_validation_backend: Literal["attrs", "pydantic"] -) -> None: - def fn(a: int) -> None: - pass - - model = create_signature_model( - fn=fn, - dependency_name_set=set(), - preferred_validation_backend=preferred_validation_backend, - parsed_signature=ParsedSignature.from_fn(fn, {}), - ) - with pytest.raises(ValidationException): - model.parse_values_from_connection_kwargs(connection=RequestFactory().get(), a="not an int") - - -@pytest.mark.parametrize("preferred_validation_backend", ("attrs", "pydantic")) -def test_create_function_signature_model_parameter_parsing( - preferred_validation_backend: Literal["attrs", "pydantic"] -) -> None: +def test_create_function_signature_model_parameter_parsing() -> None: @get() def my_fn(a: int, b: str, c: Optional[bytes], d: bytes = b"123", e: Optional[dict] = None) -> None: pass - model = create_signature_model( + model = SignatureModel.create( fn=my_fn.fn.value, dependency_name_set=set(), - preferred_validation_backend=preferred_validation_backend, parsed_signature=ParsedSignature.from_fn(my_fn.fn.value, {}), ) fields = model.fields @@ -77,133 +49,27 @@ def my_fn(a: int, b: str, c: Optional[bytes], d: bytes = b"123", e: Optional[dic assert not fields["b"].is_optional assert fields["c"].annotation is Optional[bytes] assert fields["c"].is_optional - assert fields["c"].default is None + assert fields["c"].default is Empty assert fields["d"].annotation is bytes assert fields["d"].default == b"123" assert fields["e"].annotation == Optional[dict] assert fields["e"].is_optional - assert fields["e"].default is None - - -@pytest.mark.parametrize("preferred_validation_backend", ("attrs", "pydantic")) -def test_create_signature_validation(preferred_validation_backend: Literal["attrs", "pydantic"]) -> None: - @get() - def my_fn(typed: int, untyped) -> None: # type: ignore - pass + assert fields["e"].default is Empty - with pytest.raises(ImproperlyConfiguredException): - create_signature_model( - fn=my_fn.fn.value, - dependency_name_set=set(), - preferred_validation_backend=preferred_validation_backend, - parsed_signature=ParsedSignature.from_fn(my_fn.fn.value, {}), - ) - -@pytest.mark.parametrize("preferred_validation_backend", ("attrs", "pydantic")) -def test_create_function_signature_model_ignore_return_annotation( - preferred_validation_backend: Literal["attrs", "pydantic"] -) -> None: +def test_create_function_signature_model_ignore_return_annotation() -> None: @get(path="/health", status_code=HTTP_204_NO_CONTENT) async def health_check() -> None: return None - signature_model_type = create_signature_model( + signature_model_type = SignatureModel.create( fn=health_check.fn.value, dependency_name_set=set(), - preferred_validation_backend=preferred_validation_backend, parsed_signature=ParsedSignature.from_fn(health_check.fn.value, {}), ) assert signature_model_type().to_dict() == {} -@pytest.mark.parametrize( - "preferred_validation_backend, error_extra", - ( - ( - "attrs", - [{"key": "dep", "message": "invalid literal for int() with base 10: 'thirteen'"}], - ), - ( - "pydantic", - [{"key": "dep", "message": "value is not a valid integer"}], - ), - ), -) -def test_dependency_validation_failure_raises_500( - preferred_validation_backend: Literal["attrs", "pydantic"], - error_extra: Any, -) -> None: - dependencies = {"dep": Provide(lambda: "thirteen", sync_to_thread=False)} - - @get("/") - def test(dep: int, param: int, optional_dep: Optional[int] = Dependency()) -> None: - ... - - with create_test_client( - route_handlers=[test], dependencies=dependencies, _preferred_validation_backend=preferred_validation_backend - ) as client: - response = client.get("/?param=13") - - assert response.json() == { - "detail": "Internal Server Error", - "extra": error_extra, - "status_code": HTTP_500_INTERNAL_SERVER_ERROR, - } - - -@pytest.mark.parametrize( - "preferred_validation_backend, error_extra", - ( - ( - "attrs", - [{"key": "param", "message": "invalid literal for int() with base 10: 'thirteen'", "source": "query"}], - ), - ), -) -def test_validation_failure_raises_400( - preferred_validation_backend: Literal["attrs", "pydantic"], error_extra: Any -) -> None: - dependencies = {"dep": Provide(lambda: 13, sync_to_thread=False)} - - @get("/") - def test(dep: int, param: int, optional_dep: Optional[int] = Dependency()) -> None: - ... - - with create_test_client( - route_handlers=[test], dependencies=dependencies, _preferred_validation_backend=preferred_validation_backend - ) as client: - response = client.get("/?param=thirteen") - - assert response.json() == { - "detail": "Validation failed for GET http://testserver.local/?param=thirteen", - "extra": error_extra, - "status_code": 400, - } - - -def test_client_pydantic_backend_error_precedence_over_server_error() -> None: - dependencies = { - "dep": Provide(lambda: "thirteen", sync_to_thread=False), - "optional_dep": Provide(lambda: "thirty-one", sync_to_thread=False), - } - - @get("/") - def test(dep: int, param: int, optional_dep: Optional[int] = Dependency()) -> None: - ... - - with create_test_client( - route_handlers=[test], dependencies=dependencies, _preferred_validation_backend="pydantic" - ) as client: - response = client.get("/?param=thirteen") - - assert response.json() == { - "detail": "Validation failed for GET http://testserver.local/?param=thirteen", - "extra": [{"key": "param", "message": "value is not a valid integer", "source": "query"}], - "status_code": 400, - } - - def test_signature_model_resolves_forward_ref_annotations(create_module: Callable[[str], ModuleType]) -> None: module = create_module( """ @@ -244,15 +110,13 @@ def test(a: Optional[List[int]] = Parameter(query="a", default=None, required=Fa assert response.json() == exp -@pytest.mark.parametrize("preferred_validation_backend", ("attrs", "pydantic")) -def test_field_definition_is_non_string_iterable(preferred_validation_backend: Literal["attrs", "pydantic"]) -> None: +def test_field_definition_is_non_string_iterable() -> None: def fn(a: Iterable[int], b: Optional[Iterable[int]]) -> None: pass - model = create_signature_model( + model = SignatureModel.create( fn=fn, dependency_name_set=set(), - preferred_validation_backend=preferred_validation_backend, parsed_signature=ParsedSignature.from_fn(fn, {}), ) @@ -260,15 +124,13 @@ def fn(a: Iterable[int], b: Optional[Iterable[int]]) -> None: assert model.fields["b"].is_non_string_iterable -@pytest.mark.parametrize("preferred_validation_backend", ("attrs", "pydantic")) -def test_field_definition_is_non_string_sequence(preferred_validation_backend: Literal["attrs", "pydantic"]) -> None: +def test_field_definition_is_non_string_sequence() -> None: def fn(a: Sequence[int], b: OptionalSequence[int]) -> None: pass - model = create_signature_model( + model = SignatureModel.create( fn=fn, dependency_name_set=set(), - preferred_validation_backend=preferred_validation_backend, parsed_signature=ParsedSignature.from_fn(fn, signature_namespace={}), ) @@ -276,242 +138,29 @@ def fn(a: Sequence[int], b: OptionalSequence[int]) -> None: assert model.fields["b"].is_non_string_sequence -@pytest.mark.parametrize("signature_backend", ["pydantic", "attrs"]) @pytest.mark.parametrize("query,expected", [("1", True), ("true", True), ("0", False), ("false", False)]) -def test_query_param_bool(query: str, expected: bool, signature_backend: Literal["pydantic", "attrs"]) -> None: +def test_query_param_bool(query: str, expected: bool) -> None: mock = MagicMock() @get("/") def handler(param: bool) -> None: mock(param) - with create_test_client(route_handlers=[handler], _preferred_validation_backend=signature_backend) as client: + with create_test_client(route_handlers=[handler]) as client: response = client.get(f"/?param={query}") assert response.status_code == HTTP_200_OK, response.json() mock.assert_called_once_with(expected) -@pytest.mark.parametrize("preferred_validation_backend", ("attrs", "pydantic")) -def test_validation_error_exception_key(preferred_validation_backend: Literal["attrs", "pydantic"]) -> None: - class OtherChild(BaseModel): - val: List[int] - - class Child(BaseModel): - val: int - other_val: int - - class Parent(BaseModel): - child: Child - other_child: OtherChild - - def fn(model: Parent) -> None: - pass - - model = create_signature_model( - fn=fn, - dependency_name_set=set(), - preferred_validation_backend=preferred_validation_backend, - parsed_signature=ParsedSignature.from_fn(fn, {}), - ) - - with pytest.raises(ValidationException) as exc_info: - model.parse_values_from_connection_kwargs( - connection=RequestFactory().get(), model={"child": {}, "other_child": {}} - ) - - assert isinstance(exc_info.value.extra, list) - assert exc_info.value.extra[0]["key"] == "model.child.val" - assert exc_info.value.extra[1]["key"] == "model.child.other_val" - assert exc_info.value.extra[2]["key"] == "model.other_child.val" - - -def test_invalid_input_pydantic() -> None: - class OtherChild(BaseModel): - val: List[int] - - class Child(BaseModel): - val: int - other_val: int - - class Parent(BaseModel): - child: Child - other_child: OtherChild - - @post("/") - def test( - data: Parent, - int_param: int, - length_param: str = Parameter(min_length=2), - int_header: int = Parameter(header="X-SOME-INT"), - int_cookie: int = Parameter(cookie="int-cookie"), - ) -> None: - ... - - with create_test_client(route_handlers=[test]) as client: - response = client.post( - "/", - json={"child": {"val": "a", "other_val": "b"}, "other_child": {"val": [1, "c"]}}, - params={"int_param": "param", "length_param": "d"}, - headers={"X-SOME-INT": "header"}, - cookies={"int-cookie": "cookie"}, - ) - - assert response.status_code == HTTP_400_BAD_REQUEST - - data = response.json() - - assert data - assert data["extra"] == [ - {"key": "child.val", "message": "value is not a valid integer", "source": "body"}, - {"key": "child.other_val", "message": "value is not a valid integer", "source": "body"}, - {"key": "other_child.val.1", "message": "value is not a valid integer", "source": "body"}, - {"key": "int_param", "message": "value is not a valid integer", "source": "query"}, - {"key": "length_param", "message": "ensure this value has at least 2 characters", "source": "query"}, - {"key": "int_header", "message": "value is not a valid integer", "source": "header"}, - {"key": "int_cookie", "message": "value is not a valid integer", "source": "cookie"}, - ] - - -def test_invalid_input_attrs() -> None: - @define - class OtherChild: - val: List[int] - - @define - class Child: - val: int - other_val: int - - @define - class Parent: - child: Child - other_child: OtherChild - - @post("/") - def test( - data: Parent, - int_param: int, - int_header: int = Parameter(header="X-SOME-INT"), - int_cookie: int = Parameter(cookie="int-cookie"), - ) -> None: - ... +def test_union_constraint_handling() -> None: + mock = MagicMock() - with create_test_client(route_handlers=[test]) as client: - response = client.post( - "/", - json={"child": {"val": "a", "other_val": "b"}, "other_child": {"val": [1, "c"]}}, - params={"int_param": "param"}, - headers={"X-SOME-INT": "header"}, - cookies={"int-cookie": "cookie"}, - ) - - assert response.status_code == HTTP_400_BAD_REQUEST - - data = response.json() - - assert data - assert data["extra"] == [ - {"key": "child.val", "message": "invalid literal for int() with base 10: 'a'", "source": "body"}, - {"key": "child.other_val", "message": "invalid literal for int() with base 10: 'b'", "source": "body"}, - {"key": "other_child.val.1", "message": "invalid literal for int() with base 10: 'c'", "source": "body"}, - {"key": "int_param", "message": "invalid literal for int() with base 10: 'param'", "source": "query"}, - {"key": "int_header", "message": "invalid literal for int() with base 10: 'header'", "source": "header"}, - {"key": "int_cookie", "message": "invalid literal for int() with base 10: 'cookie'", "source": "cookie"}, - ] - - -def test_invalid_input_dataclass() -> None: - @dataclass - class OtherChild: - val: List[int] - - @dataclass - class Child: - val: int - other_val: int - - @dataclass - class Parent: - child: Child - other_child: OtherChild - - @post("/") - def test( - data: Parent, - int_param: int, - length_param: str = Parameter(min_length=2), - int_header: int = Parameter(header="X-SOME-INT"), - int_cookie: int = Parameter(cookie="int-cookie"), - ) -> None: - ... + @get("/") + def handler(param: Annotated[Union[str, List[str]], Body(max_length=3, max_items=3)]) -> None: + mock(param) - with create_test_client(route_handlers=[test]) as client: - response = client.post( - "/", - json={"child": {"val": "a", "other_val": "b"}, "other_child": {"val": [1, "c"]}}, - params={"int_param": "param", "length_param": "d"}, - headers={"X-SOME-INT": "header"}, - cookies={"int-cookie": "cookie"}, - ) - - assert response.status_code == HTTP_400_BAD_REQUEST - - data = response.json() - - assert data - assert data["extra"] == [ - {"key": "child.val", "message": "value is not a valid integer", "source": "body"}, - {"key": "child.other_val", "message": "value is not a valid integer", "source": "body"}, - {"key": "other_child.val.1", "message": "value is not a valid integer", "source": "body"}, - {"key": "int_param", "message": "value is not a valid integer", "source": "query"}, - {"key": "length_param", "message": "ensure this value has at least 2 characters", "source": "query"}, - {"key": "int_header", "message": "value is not a valid integer", "source": "header"}, - {"key": "int_cookie", "message": "value is not a valid integer", "source": "cookie"}, - ] - - -def test_invalid_input_typed_dict() -> None: - class OtherChild(TypedDict): - val: List[int] - - class Child(TypedDict): - val: int - other_val: int - - class Parent(TypedDict): - child: Child - other_child: OtherChild - - @post("/") - def test( - data: Parent, - int_param: int, - length_param: str = Parameter(min_length=2), - int_header: int = Parameter(header="X-SOME-INT"), - int_cookie: int = Parameter(cookie="int-cookie"), - ) -> None: - ... + with create_test_client([handler]) as client: + response = client.get("/?param=foo") - with create_test_client(route_handlers=[test]) as client: - response = client.post( - "/", - json={"child": {"val": "a", "other_val": "b"}, "other_child": {"val": [1, "c"]}}, - params={"int_param": "param", "length_param": "d"}, - headers={"X-SOME-INT": "header"}, - cookies={"int-cookie": "cookie"}, - ) - - assert response.status_code == HTTP_400_BAD_REQUEST - - data = response.json() - - assert data - assert data["extra"] == [ - {"key": "child.val", "message": "value is not a valid integer", "source": "body"}, - {"key": "child.other_val", "message": "value is not a valid integer", "source": "body"}, - {"key": "other_child.val.1", "message": "value is not a valid integer", "source": "body"}, - {"key": "int_param", "message": "value is not a valid integer", "source": "query"}, - {"key": "length_param", "message": "ensure this value has at least 2 characters", "source": "query"}, - {"key": "int_header", "message": "value is not a valid integer", "source": "header"}, - {"key": "int_cookie", "message": "value is not a valid integer", "source": "cookie"}, - ] + assert response.status_code == 200 + mock.assert_called_once_with("foo") diff --git a/tests/unit/test_signature/test_utils.py b/tests/unit/test_signature/test_utils.py deleted file mode 100644 index cbba6f941c..0000000000 --- a/tests/unit/test_signature/test_utils.py +++ /dev/null @@ -1,27 +0,0 @@ -from __future__ import annotations - -from typing import Any, List, Optional - -import attrs -import pytest -from pydantic import BaseModel - -from litestar._signature.utils import _any_attrs_annotation -from litestar.utils.signature import ParsedSignature - - -@attrs.define -class Foo: - bar: str - - -class Bar(BaseModel): - foo: str - - -@pytest.mark.parametrize("annotation", [Foo, List[Foo], Optional[Foo]]) -def test_any_attrs_annotation(annotation: Any) -> None: - def fn(foo: annotation) -> None: - ... - - assert _any_attrs_annotation(ParsedSignature.from_fn(fn, {"annotation": annotation})) is True diff --git a/tests/unit/test_signature/test_validation.py b/tests/unit/test_signature/test_validation.py new file mode 100644 index 0000000000..c93f4dac0f --- /dev/null +++ b/tests/unit/test_signature/test_validation.py @@ -0,0 +1,319 @@ +from dataclasses import dataclass +from typing import List, Optional + +import pytest +from attr import define +from pydantic import BaseModel +from typing_extensions import TypedDict + +from litestar import get, post +from litestar._signature import SignatureModel +from litestar.di import Provide +from litestar.exceptions import ImproperlyConfiguredException, ValidationException +from litestar.params import Dependency, Parameter +from litestar.status_codes import HTTP_400_BAD_REQUEST, HTTP_500_INTERNAL_SERVER_ERROR +from litestar.testing import RequestFactory, create_test_client +from litestar.utils.signature import ParsedSignature + + +def test_parses_values_from_connection_kwargs_raises() -> None: + def fn(a: int) -> None: + pass + + model = SignatureModel.create( + fn=fn, + dependency_name_set=set(), + parsed_signature=ParsedSignature.from_fn(fn, {}), + ) + with pytest.raises(ValidationException): + model.parse_values_from_connection_kwargs(connection=RequestFactory().get(), a="not an int") + + +def test_create_signature_validation() -> None: + @get() + def my_fn(typed: int, untyped) -> None: # type: ignore + pass + + with pytest.raises(ImproperlyConfiguredException): + SignatureModel.create( + fn=my_fn.fn.value, + dependency_name_set=set(), + parsed_signature=ParsedSignature.from_fn(my_fn.fn.value, {}), + ) + + +def test_dependency_validation_failure_raises_500() -> None: + dependencies = {"dep": Provide(lambda: "thirteen", sync_to_thread=False)} + + @get("/") + def test(dep: int, param: int, optional_dep: Optional[int] = Dependency()) -> None: + ... + + with create_test_client( + route_handlers=[test], + dependencies=dependencies, + ) as client: + response = client.get("/?param=13") + + assert response.json() == {"detail": "Internal Server Error", "status_code": HTTP_500_INTERNAL_SERVER_ERROR} + + +def test_validation_failure_raises_400() -> None: + dependencies = {"dep": Provide(lambda: 13, sync_to_thread=False)} + + @get("/") + def test(dep: int, param: int, optional_dep: Optional[int] = Dependency()) -> None: + ... + + with create_test_client(route_handlers=[test], dependencies=dependencies) as client: + response = client.get("/?param=thirteen") + + assert response.json() == { + "detail": "Validation failed for GET http://testserver.local/?param=thirteen", + "extra": [{"key": "param", "message": "Expected `int`, got `str`", "source": "query"}], + "status_code": 400, + } + + +def test_client_backend_error_precedence_over_server_error() -> None: + dependencies = { + "dep": Provide(lambda: "thirteen", sync_to_thread=False), + "optional_dep": Provide(lambda: "thirty-one", sync_to_thread=False), + } + + @get("/") + def test(dep: int, param: int, optional_dep: Optional[int] = Dependency()) -> None: + ... + + with create_test_client(route_handlers=[test], dependencies=dependencies) as client: + response = client.get("/?param=thirteen") + + assert response.json() == { + "detail": "Validation failed for GET http://testserver.local/?param=thirteen", + "extra": [{"key": "param", "message": "Expected `int`, got `str`", "source": "query"}], + "status_code": 400, + } + + +def test_validation_error_exception_key() -> None: + class OtherChild(BaseModel): + val: List[int] + + class Child(BaseModel): + val: int + other_val: int + + class Parent(BaseModel): + child: Child + other_child: OtherChild + + def fn(data: Parent) -> None: + pass + + model = SignatureModel.create( + fn=fn, + dependency_name_set=set(), + parsed_signature=ParsedSignature.from_fn(fn, {}), + ) + + with pytest.raises(ValidationException) as exc_info: + model.parse_values_from_connection_kwargs( + connection=RequestFactory().get(), data={"child": {}, "other_child": {}} + ) + + assert isinstance(exc_info.value.extra, list) + assert exc_info.value.extra[0]["key"] == "child.val" + assert exc_info.value.extra[1]["key"] == "child.other_val" + assert exc_info.value.extra[2]["key"] == "other_child.val" + + +def test_invalid_input_pydantic() -> None: + class OtherChild(BaseModel): + val: List[int] + + class Child(BaseModel): + val: int + other_val: int + + class Parent(BaseModel): + child: Child + other_child: OtherChild + + @post("/") + def test( + data: Parent, + int_param: int, + length_param: str = Parameter(min_length=2), + int_header: int = Parameter(header="X-SOME-INT"), + int_cookie: int = Parameter(cookie="int-cookie"), + ) -> None: + ... + + with create_test_client(route_handlers=[test]) as client: + response = client.post( + "/", + json={"child": {"val": "a", "other_val": "b"}, "other_child": {"val": [1, "c"]}}, + params={"int_param": "param", "length_param": "d"}, + headers={"X-SOME-INT": "header"}, + cookies={"int-cookie": "cookie"}, + ) + + assert response.status_code == HTTP_400_BAD_REQUEST + + data = response.json() + + assert data + assert data["extra"] == [ + {"key": "child.val", "message": "value is not a valid integer", "source": "body"}, + {"key": "child.other_val", "message": "value is not a valid integer", "source": "body"}, + {"key": "other_child.val.1", "message": "value is not a valid integer", "source": "body"}, + {"key": "int_param", "message": "value is not a valid integer", "source": "query"}, + {"key": "length_param", "message": "ensure this value has at least 2 characters", "source": "query"}, + {"key": "int_header", "message": "value is not a valid integer", "source": "header"}, + {"key": "int_cookie", "message": "value is not a valid integer", "source": "cookie"}, + ] + + +def test_invalid_input_attrs() -> None: + @define + class OtherChild: + val: List[int] + + @define + class Child: + val: int + other_val: int + + @define + class Parent: + child: Child + other_child: OtherChild + + @post("/") + def test( + data: Parent, + int_param: int, + int_header: int = Parameter(header="X-SOME-INT"), + int_cookie: int = Parameter(cookie="int-cookie"), + ) -> None: + ... + + with create_test_client(route_handlers=[test]) as client: + response = client.post( + "/", + json={"child": {"val": "a", "other_val": "b"}, "other_child": {"val": [1, "c"]}}, + params={"int_param": "param"}, + headers={"X-SOME-INT": "header"}, + cookies={"int-cookie": "cookie"}, + ) + + assert response.status_code == HTTP_400_BAD_REQUEST + + data = response.json() + + assert data + assert data["extra"] == [ + {"key": "child.val", "message": "invalid literal for int() with base 10: 'a'", "source": "body"}, + {"key": "child.other_val", "message": "invalid literal for int() with base 10: 'b'", "source": "body"}, + {"key": "other_child.val.1", "message": "invalid literal for int() with base 10: 'c'", "source": "body"}, + {"key": "int_param", "message": "invalid literal for int() with base 10: 'param'", "source": "query"}, + {"key": "int_header", "message": "invalid literal for int() with base 10: 'header'", "source": "header"}, + {"key": "int_cookie", "message": "invalid literal for int() with base 10: 'cookie'", "source": "cookie"}, + ] + + +def test_invalid_input_dataclass() -> None: + @dataclass + class OtherChild: + val: List[int] + + @dataclass + class Child: + val: int + other_val: int + + @dataclass + class Parent: + child: Child + other_child: OtherChild + + @post("/") + def test( + data: Parent, + int_param: int, + length_param: str = Parameter(min_length=2), + int_header: int = Parameter(header="X-SOME-INT"), + int_cookie: int = Parameter(cookie="int-cookie"), + ) -> None: + ... + + with create_test_client(route_handlers=[test]) as client: + response = client.post( + "/", + json={"child": {"val": "a", "other_val": "b"}, "other_child": {"val": [1, "c"]}}, + params={"int_param": "param", "length_param": "d"}, + headers={"X-SOME-INT": "header"}, + cookies={"int-cookie": "cookie"}, + ) + + assert response.status_code == HTTP_400_BAD_REQUEST + + data = response.json() + + assert data + assert data["extra"] == [ + {"key": "child.val", "message": "value is not a valid integer", "source": "body"}, + {"key": "child.other_val", "message": "value is not a valid integer", "source": "body"}, + {"key": "other_child.val.1", "message": "value is not a valid integer", "source": "body"}, + {"key": "int_param", "message": "value is not a valid integer", "source": "query"}, + {"key": "length_param", "message": "ensure this value has at least 2 characters", "source": "query"}, + {"key": "int_header", "message": "value is not a valid integer", "source": "header"}, + {"key": "int_cookie", "message": "value is not a valid integer", "source": "cookie"}, + ] + + +def test_invalid_input_typed_dict() -> None: + class OtherChild(TypedDict): + val: List[int] + + class Child(TypedDict): + val: int + other_val: int + + class Parent(TypedDict): + child: Child + other_child: OtherChild + + @post("/") + def test( + data: Parent, + int_param: int, + length_param: str = Parameter(min_length=2), + int_header: int = Parameter(header="X-SOME-INT"), + int_cookie: int = Parameter(cookie="int-cookie"), + ) -> None: + ... + + with create_test_client(route_handlers=[test]) as client: + response = client.post( + "/", + json={"child": {"val": "a", "other_val": "b"}, "other_child": {"val": [1, "c"]}}, + params={"int_param": "param", "length_param": "d"}, + headers={"X-SOME-INT": "header"}, + cookies={"int-cookie": "cookie"}, + ) + + assert response.status_code == HTTP_400_BAD_REQUEST + + data = response.json() + + assert data + assert data["extra"] == [ + {"key": "child.val", "message": "value is not a valid integer", "source": "body"}, + {"key": "child.other_val", "message": "value is not a valid integer", "source": "body"}, + {"key": "other_child.val.1", "message": "value is not a valid integer", "source": "body"}, + {"key": "int_param", "message": "value is not a valid integer", "source": "query"}, + {"key": "length_param", "message": "ensure this value has at least 2 characters", "source": "query"}, + {"key": "int_header", "message": "value is not a valid integer", "source": "header"}, + {"key": "int_cookie", "message": "value is not a valid integer", "source": "cookie"}, + ] From 6a7af343fa0916b9e8b995e6b2106b3dbce74d67 Mon Sep 17 00:00:00 2001 From: Na'aman Hirschfeld Date: Fri, 14 Jul 2023 21:45:39 +0200 Subject: [PATCH 2/5] Pydantic v2 integration (#1952) * feat(internal): add pydantic v2 support Linting fixes. Update tests/unit/test_kwargs/test_path_params.py fix signature namespace issue support min_length and max_length support min_length and max_length handle constraints on union types fix error message tests chore(signature-model): remove pydantic and attrs signature models chore(signature model): fix python 3.8 compat feat: msgspec signature model. Linting fixes. Update tests/unit/test_kwargs/test_path_params.py fix signature namespace issue support min_length and max_length support min_length and max_length handle constraints on union types fix error message tests chore(signature-model): remove pydantic and attrs signature models feat: msgspec signature model. Linting fixes. Update tests/unit/test_kwargs/test_path_params.py fix signature namespace issue support min_length and max_length support min_length and max_length handle constraints on union types fix error message tests feat(signature-model): add pydantic v2 support * feat(internal): handle pydantic errors --------- Co-authored-by: Peter Schutt --- .github/workflows/ci.yaml | 7 +- .github/workflows/test.yaml | 10 +- .pre-commit-config.yaml | 6 +- docs/examples/startup_and_shutdown.py | 11 +- litestar/_kwargs/extractors.py | 1 + litestar/_kwargs/parameter_definition.py | 5 +- litestar/_openapi/schema_generation/schema.py | 207 ++++++++--------- litestar/_signature/model.py | 9 +- litestar/constants.py | 18 +- litestar/contrib/msgspec.py | 1 + litestar/contrib/piccolo.py | 3 +- litestar/contrib/pydantic.py | 61 +++-- litestar/datastructures/upload_file.py | 20 -- .../dto/factory/_backends/pydantic/backend.py | 24 +- litestar/dto/factory/stdlib/dataclass.py | 1 + litestar/openapi/spec/schema.py | 2 +- litestar/partial.py | 5 +- litestar/plugins.py | 1 + litestar/routes/base.py | 46 +++- litestar/serialization.py | 182 ++++++++------- litestar/types/serialization.py | 19 +- litestar/typing.py | 21 +- litestar/utils/predicates.py | 3 +- poetry.lock | 216 +++++++++++++----- pyproject.toml | 33 +-- .../test_injection_of_generic_models.py | 12 +- tests/examples/test_request_data.py | 62 ++--- tests/unit/test_app.py | 5 +- tests/unit/test_contrib/test_pydantic.py | 3 + tests/unit/test_kwargs/test_header_params.py | 2 +- tests/unit/test_kwargs/test_multipart_data.py | 12 +- tests/unit/test_kwargs/test_path_params.py | 5 +- .../test_reserved_kwargs_injection.py | 5 +- .../test_openapi/test_constrained_fields.py | 102 ++++++++- tests/unit/test_openapi/test_schema.py | 3 +- tests/unit/test_openapi/utils.py | 39 ++-- tests/unit/test_partial.py | 50 +++- tests/unit/test_serialization.py | 198 +++++++++------- tests/unit/test_signature/test_validation.py | 68 +++--- 39 files changed, 922 insertions(+), 556 deletions(-) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index e0ca5d787a..c5fa1b4bea 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -29,11 +29,13 @@ jobs: fail-fast: true matrix: python-version: ["3.8", "3.9", "3.10", "3.11"] + pydantic-version: ["1", "2"] uses: ./.github/workflows/test.yaml with: + coverage: ${{ matrix.python-version == '3.11' && matrix.pydantic-version == '2' }} + integration: ${{ matrix.python-version == '3.11' && matrix.pydantic-version == '2' }} + pydantic-version: ${{ matrix.pydantic-version }} python-version: ${{ matrix.python-version }} - coverage: ${{ matrix.python-version == '3.11' }} - integration: ${{ matrix.python-version == '3.11' }} test-platform-compat: if: github.event_name == 'push' @@ -44,6 +46,7 @@ jobs: uses: ./.github/workflows/test.yaml with: python-version: "3.11" + pydantic-version: "2" os: ${{ matrix.os }} sonar: diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 205195c0ff..8f59979718 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -6,6 +6,9 @@ on: python-version: required: true type: string + pydantic-version: + required: true + type: string coverage: required: false type: boolean @@ -45,16 +48,19 @@ jobs: uses: actions/cache@v3 with: path: .venv - key: v1-venv-${{ runner.os }}-${{ inputs.python-version }}-${{ hashFiles('**/poetry.lock') }} + key: v1-venv-${{ runner.os }}-${{ inputs.python-version }}-${{ inputs.pydantic-version }}-${{ hashFiles('**/poetry.lock') }} - name: Load cached pip wheels if: runner.os == 'Windows' id: cached-pip-wheels uses: actions/cache@v3 with: path: ~/.cache - key: cache-${{ runner.os }}-${{ inputs.python-version }}-${{ hashFiles('**/poetry.lock') }} + key: cache-${{ runner.os }}-${{ inputs.python-version }}-${{ inputs.pydantic-version }}-${{ hashFiles('**/poetry.lock') }} - name: Install dependencies run: poetry install --no-interaction + - if: ${{ inputs.pydantic-version == '1' }} + name: Install pydantic v1 + run: source .venv/bin/activate && pip install "pydantic>=1.10.10" - name: Set pythonpath run: echo "PYTHONPATH=$PWD" >> $GITHUB_ENV - name: Test diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index ea2863bd0f..ae35620461 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -100,7 +100,8 @@ repos: polyfactory, prometheus_client, psycopg, - pydantic, + pydantic>=2, + pydantic_extra_types, pytest, pytest-lazy-fixture, pytest-mock, @@ -159,7 +160,8 @@ repos: polyfactory, prometheus_client, psycopg, - pydantic, + pydantic>=2, + pydantic_extra_types, pytest, pytest-lazy-fixture, pytest-mock, diff --git a/docs/examples/startup_and_shutdown.py b/docs/examples/startup_and_shutdown.py index a1d912e4d8..847336b571 100644 --- a/docs/examples/startup_and_shutdown.py +++ b/docs/examples/startup_and_shutdown.py @@ -1,16 +1,11 @@ +import os from typing import cast -from pydantic import BaseSettings from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine from litestar import Litestar - -class AppSettings(BaseSettings): - DATABASE_URI: str = "postgresql+asyncpg://postgres:mysecretpassword@pg.db:5432/db" - - -settings = AppSettings() +DB_URI = os.environ.get("DATABASE_URI", "postgresql+asyncpg://postgres:mysecretpassword@pg.db:5432/db") def get_db_connection(app: Litestar) -> AsyncEngine: @@ -19,7 +14,7 @@ def get_db_connection(app: Litestar) -> AsyncEngine: If it doesn't exist, creates it and saves it in on the application state object """ if not getattr(app.state, "engine", None): - app.state.engine = create_async_engine(settings.DATABASE_URI) + app.state.engine = create_async_engine(DB_URI) return cast("AsyncEngine", app.state.engine) diff --git a/litestar/_kwargs/extractors.py b/litestar/_kwargs/extractors.py index 68e7e42292..37f2a288bf 100644 --- a/litestar/_kwargs/extractors.py +++ b/litestar/_kwargs/extractors.py @@ -24,6 +24,7 @@ from litestar.dto.interface import DTOInterface from litestar.typing import FieldDefinition + __all__ = ( "body_extractor", "cookies_extractor", diff --git a/litestar/_kwargs/parameter_definition.py b/litestar/_kwargs/parameter_definition.py index afc9b8628e..02b09fcebd 100644 --- a/litestar/_kwargs/parameter_definition.py +++ b/litestar/_kwargs/parameter_definition.py @@ -5,12 +5,11 @@ from litestar.enums import ParamType from litestar.params import ParameterKwarg -__all__ = ("ParameterDefinition", "create_parameter_definition", "merge_parameter_sets") - - if TYPE_CHECKING: from litestar.typing import FieldDefinition +__all__ = ("ParameterDefinition", "create_parameter_definition", "merge_parameter_sets") + class ParameterDefinition(NamedTuple): """Tuple defining a kwarg representing a request parameter.""" diff --git a/litestar/_openapi/schema_generation/schema.py b/litestar/_openapi/schema_generation/schema.py index a70841b890..b205cf3361 100644 --- a/litestar/_openapi/schema_generation/schema.py +++ b/litestar/_openapi/schema_generation/schema.py @@ -34,7 +34,7 @@ from _decimal import Decimal from msgspec.structs import fields as msgspec_struct_fields -from typing_extensions import NotRequired, Required, get_args, get_type_hints +from typing_extensions import Annotated, NotRequired, Required, get_args, get_type_hints from litestar._openapi.schema_generation.constrained_fields import ( create_date_constrained_field_schema, @@ -73,31 +73,9 @@ from litestar.plugins import OpenAPISchemaPluginProtocol try: - from pydantic import ( - BaseModel, - ConstrainedBytes, - ConstrainedDate, - ConstrainedDecimal, - ConstrainedFloat, - ConstrainedFrozenSet, - ConstrainedInt, - ConstrainedList, - ConstrainedSet, - ConstrainedStr, - ) - from pydantic.fields import ModelField + from pydantic import BaseModel except ImportError: BaseModel = Any # type: ignore - ConstrainedBytes = Any # type: ignore - ConstrainedDate = Any # type: ignore - ConstrainedDecimal = Any # type: ignore - ConstrainedFloat = Any # type: ignore - ConstrainedFrozenSet = Any # type: ignore - ConstrainedInt = Any # type: ignore - ConstrainedList = Any # type: ignore - ConstrainedSet = Any # type: ignore - ConstrainedStr = Any # type: ignore - ModelField = Any # type: ignore try: from attrs import AttrsInstance @@ -107,40 +85,8 @@ import pydantic PYDANTIC_TYPE_MAP: dict[type[Any] | None | Any, Schema] = { - pydantic.UUID1: Schema( - type=OpenAPIType.STRING, - format=OpenAPIFormat.UUID, - description="UUID1 string", - ), - pydantic.UUID3: Schema( - type=OpenAPIType.STRING, - format=OpenAPIFormat.UUID, - description="UUID3 string", - ), - pydantic.UUID4: Schema( - type=OpenAPIType.STRING, - format=OpenAPIFormat.UUID, - description="UUID4 string", - ), - pydantic.UUID5: Schema( - type=OpenAPIType.STRING, - format=OpenAPIFormat.UUID, - description="UUID5 string", - ), - pydantic.AnyHttpUrl: Schema( - type=OpenAPIType.STRING, format=OpenAPIFormat.URL, description="must be a valid HTTP based URL" - ), - pydantic.AnyUrl: Schema(type=OpenAPIType.STRING, format=OpenAPIFormat.URL), pydantic.ByteSize: Schema(type=OpenAPIType.INTEGER), - pydantic.DirectoryPath: Schema(type=OpenAPIType.STRING, format=OpenAPIFormat.URI_REFERENCE), pydantic.EmailStr: Schema(type=OpenAPIType.STRING, format=OpenAPIFormat.EMAIL), - pydantic.FilePath: Schema(type=OpenAPIType.STRING, format=OpenAPIFormat.URI_REFERENCE), - pydantic.HttpUrl: Schema( - type=OpenAPIType.STRING, - format=OpenAPIFormat.URL, - description="must be a valid HTTP based URL", - max_length=2083, - ), pydantic.IPvAnyAddress: Schema( one_of=[ Schema( @@ -185,48 +131,95 @@ ), pydantic.Json: Schema(type=OpenAPIType.OBJECT, format=OpenAPIFormat.JSON_POINTER), pydantic.NameEmail: Schema(type=OpenAPIType.STRING, format=OpenAPIFormat.EMAIL, description="Name and email"), - pydantic.NegativeFloat: Schema(type=OpenAPIType.NUMBER, exclusive_maximum=0.0), - pydantic.NegativeInt: Schema(type=OpenAPIType.INTEGER, exclusive_maximum=0), - pydantic.NonNegativeInt: Schema(type=OpenAPIType.INTEGER, minimum=0), - pydantic.NonPositiveFloat: Schema(type=OpenAPIType.NUMBER, maximum=0.0), - pydantic.PaymentCardNumber: Schema(type=OpenAPIType.STRING, min_length=12, max_length=19), - pydantic.PositiveFloat: Schema(type=OpenAPIType.NUMBER, exclusive_minimum=0.0), - pydantic.PositiveInt: Schema(type=OpenAPIType.INTEGER, exclusive_minimum=0), - pydantic.PostgresDsn: Schema(type=OpenAPIType.STRING, format=OpenAPIFormat.URI, description="postgres DSN"), - pydantic.PyObject: Schema( - type=OpenAPIType.STRING, - description="dot separated path identifying a python object, e.g. 'decimal.Decimal'", - ), - pydantic.RedisDsn: Schema(type=OpenAPIType.STRING, format=OpenAPIFormat.URI, description="redis DSN"), - pydantic.SecretBytes: Schema(type=OpenAPIType.STRING), - pydantic.SecretStr: Schema(type=OpenAPIType.STRING), - pydantic.StrictBool: Schema(type=OpenAPIType.BOOLEAN), - pydantic.StrictBytes: Schema(type=OpenAPIType.STRING), - pydantic.StrictFloat: Schema(type=OpenAPIType.NUMBER), - pydantic.StrictInt: Schema(type=OpenAPIType.INTEGER), - pydantic.StrictStr: Schema(type=OpenAPIType.STRING), } + + if pydantic.VERSION.startswith("1"): + # pydantic v1 values only - some are removed in v2, others are Annotated[] based and require a different + # logic + PYDANTIC_TYPE_MAP.update( + { + # removed in v2 + pydantic.PyObject: Schema( + type=OpenAPIType.STRING, + description="dot separated path identifying a python object, e.g. 'decimal.Decimal'", + ), + # annotated in v2 + pydantic.UUID1: Schema( + type=OpenAPIType.STRING, + format=OpenAPIFormat.UUID, + description="UUID1 string", + ), + pydantic.UUID3: Schema( + type=OpenAPIType.STRING, + format=OpenAPIFormat.UUID, + description="UUID3 string", + ), + pydantic.UUID4: Schema( + type=OpenAPIType.STRING, + format=OpenAPIFormat.UUID, + description="UUID4 string", + ), + pydantic.UUID5: Schema( + type=OpenAPIType.STRING, + format=OpenAPIFormat.UUID, + description="UUID5 string", + ), + pydantic.DirectoryPath: Schema(type=OpenAPIType.STRING, format=OpenAPIFormat.URI_REFERENCE), + pydantic.AnyUrl: Schema(type=OpenAPIType.STRING, format=OpenAPIFormat.URL), + pydantic.AnyHttpUrl: Schema( + type=OpenAPIType.STRING, format=OpenAPIFormat.URL, description="must be a valid HTTP based URL" + ), + pydantic.FilePath: Schema(type=OpenAPIType.STRING, format=OpenAPIFormat.URI_REFERENCE), + pydantic.HttpUrl: Schema( + type=OpenAPIType.STRING, + format=OpenAPIFormat.URL, + description="must be a valid HTTP based URL", + max_length=2083, + ), + pydantic.RedisDsn: Schema(type=OpenAPIType.STRING, format=OpenAPIFormat.URI, description="redis DSN"), + pydantic.PostgresDsn: Schema( + type=OpenAPIType.STRING, format=OpenAPIFormat.URI, description="postgres DSN" + ), + pydantic.SecretBytes: Schema(type=OpenAPIType.STRING), + pydantic.SecretStr: Schema(type=OpenAPIType.STRING), + pydantic.StrictBool: Schema(type=OpenAPIType.BOOLEAN), + pydantic.StrictBytes: Schema(type=OpenAPIType.STRING), + pydantic.StrictFloat: Schema(type=OpenAPIType.NUMBER), + pydantic.StrictInt: Schema(type=OpenAPIType.INTEGER), + pydantic.StrictStr: Schema(type=OpenAPIType.STRING), + pydantic.NegativeFloat: Schema(type=OpenAPIType.NUMBER, exclusive_maximum=0.0), + pydantic.NegativeInt: Schema(type=OpenAPIType.INTEGER, exclusive_maximum=0), + pydantic.NonNegativeInt: Schema(type=OpenAPIType.INTEGER, minimum=0), + pydantic.NonPositiveFloat: Schema(type=OpenAPIType.NUMBER, maximum=0.0), + pydantic.PaymentCardNumber: Schema(type=OpenAPIType.STRING, min_length=12, max_length=19), + pydantic.PositiveFloat: Schema(type=OpenAPIType.NUMBER, exclusive_minimum=0.0), + pydantic.PositiveInt: Schema(type=OpenAPIType.INTEGER, exclusive_minimum=0), + } + ) + except ImportError: PYDANTIC_TYPE_MAP = {} KWARG_DEFINITION_ATTRIBUTE_TO_OPENAPI_PROPERTY_MAP: dict[str, str] = { + "content_encoding": "contentEncoding", "default": "default", - "multiple_of": "multipleOf", + "description": "description", + "enum": "enum", + "examples": "examples", + "external_docs": "externalDocs", + "format": "format", "ge": "minimum", + "gt": "exclusiveMinimum", "le": "maximum", "lt": "exclusiveMaximum", - "gt": "exclusiveMinimum", - "max_length": "maxLength", - "min_length": "minLength", "max_items": "maxItems", + "max_length": "maxLength", "min_items": "minItems", + "min_length": "minLength", + "multiple_of": "multipleOf", "pattern": "pattern", "title": "title", - "description": "description", - "examples": "examples", - "external_docs": "externalDocs", - "content_encoding": "contentEncoding", } TYPE_MAP: dict[type[Any] | None | Any, Schema] = { @@ -590,7 +583,7 @@ def for_plugin(self, field_definition: FieldDefinition, plugin: OpenAPISchemaPlu ) return schema # pragma: no cover - def for_pydantic_model(self, annotation: type[BaseModel], dto_for: ForType | None) -> Schema: + def for_pydantic_model(self, annotation: type[BaseModel], dto_for: ForType | None) -> Schema: # pyright: ignore """Create a schema object for a given pydantic model class. Args: @@ -600,8 +593,15 @@ def for_pydantic_model(self, annotation: type[BaseModel], dto_for: ForType | Non Returns: A schema instance. """ + annotation_hints = get_type_hints(annotation, include_extras=True) model_config = getattr(annotation, "__config__", getattr(annotation, "model_config", Empty)) + model_fields: dict[str, pydantic.fields.FieldInfo] = { + k: getattr(f, "field_info", f) + for k, f in getattr(annotation, "__fields__", getattr(annotation, "model_fields", {})).items() + } + + # pydantic v2 logic if isinstance(model_config, dict): title = model_config.get("title") example = model_config.get("example") @@ -609,22 +609,28 @@ def for_pydantic_model(self, annotation: type[BaseModel], dto_for: ForType | Non title = getattr(model_config, "title", None) example = getattr(model_config, "example", None) + field_definitions = { + f.alias + if f.alias and self.prefer_alias + else k: FieldDefinition.from_kwarg( + annotation=Annotated[annotation_hints[k], f, f.metadata] # pyright: ignore + if pydantic.VERSION.startswith("2") + else Annotated[annotation_hints[k], f], # pyright: ignore + name=f.alias if f.alias and self.prefer_alias else k, + default=f.default if f.default not in UNDEFINED_SENTINELS else Empty, + ) + for k, f in model_fields.items() + } + return Schema( - required=sorted(self.get_field_name(field) for field in annotation.__fields__.values() if field.required), - properties={ - self.get_field_name(f): self.for_field_definition( - FieldDefinition.from_kwarg( - annotation=annotation_hints[f.name], name=self.get_field_name(f), default=f.field_info - ) - ) - for f in annotation.__fields__.values() - }, + required=sorted(f.name for f in field_definitions.values() if f.is_required), + properties={k: self.for_field_definition(f) for k, f in field_definitions.items()}, type=OpenAPIType.OBJECT, title=title or _get_type_schema_name(annotation, dto_for), examples=[Example(example)] if example else None, ) - def for_attrs_class(self, annotation: type[AttrsInstance], dto_for: ForType | None) -> Schema: + def for_attrs_class(self, annotation: type[AttrsInstance], dto_for: ForType | None) -> Schema: # pyright: ignore """Create a schema object for a given attrs class. Args: @@ -782,17 +788,6 @@ def for_collection_constrained_field(self, field_definition: FieldDefinition) -> ) return schema - def get_field_name(self, field_definition: ModelField) -> str: - """Get the preferred name for a model field. - - Args: - field_definition: A model field instance. - - Returns: - The preferred name for the field. - """ - return (field_definition.alias or field_definition.name) if self.prefer_alias else field_definition.name - def process_schema_result(self, field: FieldDefinition, schema: Schema) -> Schema | Reference: if field.kwarg_definition and field.is_const and field.has_default and schema.const is None: schema.const = field.default diff --git a/litestar/_signature/model.py b/litestar/_signature/model.py index e52aa8806c..406c95eb76 100644 --- a/litestar/_signature/model.py +++ b/litestar/_signature/model.py @@ -6,15 +6,14 @@ from msgspec import NODEFAULT, Meta, Struct, ValidationError, convert, defstruct from msgspec.structs import asdict -from pydantic import ValidationError as PydanticValidationError from typing_extensions import Annotated from litestar._signature.utils import create_type_overrides, validate_signature_dependencies from litestar.enums import ScopeType from litestar.exceptions import InternalServerException, ValidationException from litestar.params import DependencyKwarg, KwargDefinition, ParameterKwarg -from litestar.serialization import dec_hook -from litestar.typing import FieldDefinition # noqa: TCH +from litestar.serialization import ExtendedMsgSpecValidationError, dec_hook +from litestar.typing import FieldDefinition # noqa from litestar.utils import make_non_optional_union from litestar.utils.dataclass import simple_asdict from litestar.utils.typing import unwrap_union @@ -137,8 +136,8 @@ def parse_values_from_connection_kwargs(cls, connection: ASGIConnection, **kwarg messages: list[ErrorMessage] = [] try: return convert(kwargs, cls, strict=False, dec_hook=dec_hook).to_dict() - except PydanticValidationError as e: - for exc in e.errors(): + except ExtendedMsgSpecValidationError as e: + for exc in e.errors: keys = [str(loc) for loc in exc["loc"]] message = cls._build_error_message(keys=keys, exc_msg=exc["msg"], connection=connection) messages.append(message) diff --git a/litestar/constants.py b/litestar/constants.py index 79fb8fa994..a25fe0abd4 100644 --- a/litestar/constants.py +++ b/litestar/constants.py @@ -1,8 +1,7 @@ +from dataclasses import MISSING from inspect import Signature from typing import Literal -from pydantic.fields import Undefined - from litestar.enums import MediaType from litestar.types import Empty @@ -20,6 +19,19 @@ SCOPE_STATE_NAMESPACE: Literal["__litestar__"] = "__litestar__" SCOPE_STATE_RESPONSE_COMPRESSED: Literal["response_compressed"] = "response_compressed" SKIP_VALIDATION_NAMES = {"request", "socket", "scope", "receive", "send"} -UNDEFINED_SENTINELS = {Undefined, Signature.empty, Empty, Ellipsis} +UNDEFINED_SENTINELS = {Signature.empty, Empty, Ellipsis, MISSING} WEBSOCKET_CLOSE: Literal["websocket.close"] = "websocket.close" WEBSOCKET_DISCONNECT: Literal["websocket.disconnect"] = "websocket.disconnect" + +try: + import pydantic + + if pydantic.VERSION.startswith("2"): + from pydantic_core import PydanticUndefined + else: # pragma: no cover + from pydantic.fields import Undefined as PydanticUndefined # type: ignore + + UNDEFINED_SENTINELS.add(PydanticUndefined) + +except ImportError: # pragma: no cover + pass diff --git a/litestar/contrib/msgspec.py b/litestar/contrib/msgspec.py index 0269d377c3..cc021d75a9 100644 --- a/litestar/contrib/msgspec.py +++ b/litestar/contrib/msgspec.py @@ -17,6 +17,7 @@ from litestar.typing import FieldDefinition + __all__ = ("MsgspecDTO",) T = TypeVar("T", bound="Struct | Collection[Struct]") diff --git a/litestar/contrib/piccolo.py b/litestar/contrib/piccolo.py index ee04b47277..b64c6e6a39 100644 --- a/litestar/contrib/piccolo.py +++ b/litestar/contrib/piccolo.py @@ -12,7 +12,6 @@ from litestar.dto.factory.field import DTOField, Mark from litestar.exceptions import MissingDependencyException from litestar.types import Empty -from litestar.typing import FieldDefinition from litestar.utils.helpers import get_fully_qualified_class_name try: @@ -23,6 +22,8 @@ from piccolo.columns import Column, column_types from piccolo.table import Table +from litestar.typing import FieldDefinition + T = TypeVar("T", bound=Table) __all__ = ("PiccoloDTO",) diff --git a/litestar/contrib/pydantic.py b/litestar/contrib/pydantic.py index 2e49e40712..34f5c97dfc 100644 --- a/litestar/contrib/pydantic.py +++ b/litestar/contrib/pydantic.py @@ -3,36 +3,33 @@ from dataclasses import replace from typing import TYPE_CHECKING, Collection, Generic, TypeVar -from pydantic import BaseModel - from litestar.dto.factory.base import AbstractDTOFactory from litestar.dto.factory.data_structures import DTOFieldDefinition from litestar.dto.factory.field import DTO_FIELD_META_KEY, DTOField from litestar.dto.factory.utils import get_model_type_hints +from litestar.exceptions import MissingDependencyException from litestar.types.empty import Empty from litestar.utils.helpers import get_fully_qualified_class_name if TYPE_CHECKING: - from typing import Any, ClassVar, Generator - - from pydantic.fields import ModelField + from typing import ClassVar, Generator from litestar.typing import FieldDefinition -__all__ = ("PydanticDTO",) -T = TypeVar("T", bound="BaseModel | Collection[BaseModel]") +try: + import pydantic + if pydantic.VERSION.startswith("2"): + from pydantic_core import PydanticUndefined + else: # pragma: no cover + from pydantic.fields import Undefined as PydanticUndefined # type: ignore +except ImportError as e: + raise MissingDependencyException("pydantic") from e -def _determine_default(field_definition: FieldDefinition, model_field: ModelField) -> Any: - if ( - model_field.default is Ellipsis - or model_field.default_factory is not None - or (model_field.default is None and not field_definition.is_optional) - ): - return Empty +__all__ = ("PydanticDTO",) - return model_field.default +T = TypeVar("T", bound="pydantic.BaseModel | Collection[pydantic.BaseModel]") class PydanticDTO(AbstractDTOFactory[T], Generic[T]): @@ -40,28 +37,44 @@ class PydanticDTO(AbstractDTOFactory[T], Generic[T]): __slots__ = () - model_type: ClassVar[type[BaseModel]] + model_type: ClassVar[type[pydantic.BaseModel]] @classmethod - def generate_field_definitions(cls, model_type: type[BaseModel]) -> Generator[DTOFieldDefinition, None, None]: + def generate_field_definitions( + cls, model_type: type[pydantic.BaseModel] + ) -> Generator[DTOFieldDefinition, None, None]: model_field_definitions = get_model_type_hints(model_type) - for key, model_field in model_type.__fields__.items(): - field_definition = model_field_definitions[key] - model_field = model_type.__fields__[key] + + if pydantic.VERSION.startswith("1"): + model_fields: dict[str, pydantic.fields.FieldInfo] = {k: model_field.field_info for k, model_field in model_type.__fields__.items()} # type: ignore + else: + model_fields = dict(model_type.model_fields) + + for field_name, field_info in model_fields.items(): + field_definition = model_field_definitions[field_name] dto_field = (field_definition.extra or {}).pop(DTO_FIELD_META_KEY, DTOField()) + if field_info.default is not PydanticUndefined: + default = field_info.default + elif field_definition.is_optional: + default = None + else: + default = Empty + yield replace( DTOFieldDefinition.from_field_definition( field_definition=field_definition, dto_field=dto_field, unique_model_name=get_fully_qualified_class_name(model_type), - default_factory=model_field.default_factory or Empty, + default_factory=field_info.default_factory + if field_info.default_factory and field_info.default_factory is not PydanticUndefined # type: ignore[comparison-overlap] + else Empty, dto_for=None, ), - default=_determine_default(field_definition, model_field), - name=key, + default=default, + name=field_name, ) @classmethod def detect_nested_field(cls, field_definition: FieldDefinition) -> bool: - return field_definition.is_subclass_of(BaseModel) + return field_definition.is_subclass_of(pydantic.BaseModel) diff --git a/litestar/datastructures/upload_file.py b/litestar/datastructures/upload_file.py index 7ddaf4ff47..07b425514f 100644 --- a/litestar/datastructures/upload_file.py +++ b/litestar/datastructures/upload_file.py @@ -1,20 +1,14 @@ from __future__ import annotations from tempfile import SpooledTemporaryFile -from typing import TYPE_CHECKING, Any from anyio.to_thread import run_sync from litestar.constants import ONE_MEGABYTE -from litestar.openapi.spec.enums import OpenAPIType __all__ = ("UploadFile",) -if TYPE_CHECKING: - from pydantic.fields import ModelField - - class UploadFile: """Representation of a file upload""" @@ -106,17 +100,3 @@ async def close(self) -> None: def __repr__(self) -> str: return f"{self.filename} - {self.content_type}" - - @classmethod - def __modify_schema__(cls, field_schema: dict[str, Any], field: ModelField | None) -> None: - """Create a pydantic JSON schema. - - Args: - field_schema: The schema being generated for the field. - field: the model class field. - - Returns: - None - """ - if field: - field_schema.update({"type": OpenAPIType.STRING.value, "contentMediaType": "application/octet-stream"}) diff --git a/litestar/dto/factory/_backends/pydantic/backend.py b/litestar/dto/factory/_backends/pydantic/backend.py index 4f192a4bbd..61f2ad63f3 100644 --- a/litestar/dto/factory/_backends/pydantic/backend.py +++ b/litestar/dto/factory/_backends/pydantic/backend.py @@ -2,9 +2,8 @@ from typing import TYPE_CHECKING, TypeVar -from pydantic import BaseModel, parse_obj_as - from litestar.dto.factory._backends.abc import AbstractDTOBackend +from litestar.exceptions import MissingDependencyException from litestar.serialization import decode_media_type from .utils import _create_model_for_field_definitions @@ -15,25 +14,38 @@ from litestar.dto.factory._backends.types import FieldDefinitionsType from litestar.dto.interface import ConnectionContext +try: + import pydantic +except ImportError as e: + raise MissingDependencyException("pydantic") from e + __all__ = ("PydanticDTOBackend",) T = TypeVar("T") -class PydanticDTOBackend(AbstractDTOBackend[BaseModel]): +class PydanticDTOBackend(AbstractDTOBackend[pydantic.BaseModel]): __slots__ = () - def create_transfer_model_type(self, unique_name: str, field_definitions: FieldDefinitionsType) -> type[BaseModel]: + def create_transfer_model_type( + self, unique_name: str, field_definitions: FieldDefinitionsType + ) -> type[pydantic.BaseModel]: fqn_uid: str = self._gen_unique_name_id(unique_name) model = _create_model_for_field_definitions(fqn_uid, field_definitions) setattr(model, "__schema_name__", unique_name) return model - def parse_raw(self, raw: bytes, connection_context: ConnectionContext) -> BaseModel | Collection[BaseModel]: + def parse_raw( + self, raw: bytes, connection_context: ConnectionContext + ) -> pydantic.BaseModel | Collection[pydantic.BaseModel]: return decode_media_type( # type:ignore[no-any-return] raw, connection_context.request_encoding_type, type_=self.annotation ) def parse_builtins(self, builtins: Any, connection_context: ConnectionContext) -> Any: - return parse_obj_as(self.annotation, builtins) + return ( + pydantic.TypeAdapter(self.annotation).validate_python(builtins, strict=False) + if pydantic.VERSION.startswith("2") + else pydantic.parse_obj_as(self.annotation, builtins) + ) diff --git a/litestar/dto/factory/stdlib/dataclass.py b/litestar/dto/factory/stdlib/dataclass.py index f92144fef7..4cbe3b91d1 100644 --- a/litestar/dto/factory/stdlib/dataclass.py +++ b/litestar/dto/factory/stdlib/dataclass.py @@ -17,6 +17,7 @@ from litestar.types.protocols import DataclassProtocol from litestar.typing import FieldDefinition + __all__ = ("DataclassDTO", "T") T = TypeVar("T", bound="DataclassProtocol | Collection[DataclassProtocol]") diff --git a/litestar/openapi/spec/schema.py b/litestar/openapi/spec/schema.py index 6583b4b886..661dfd8e04 100644 --- a/litestar/openapi/spec/schema.py +++ b/litestar/openapi/spec/schema.py @@ -14,7 +14,7 @@ from litestar.openapi.spec.xml import XML from litestar.types import DataclassProtocol -__all__ = ("Schema",) +__all__ = ("Schema", "SchemaDataContainer") def _recursive_hash(value: Hashable | Sequence | Mapping | DataclassProtocol | type[DataclassProtocol]) -> int: diff --git a/litestar/partial.py b/litestar/partial.py index 7b1d527fbb..ae5e1db7f1 100644 --- a/litestar/partial.py +++ b/litestar/partial.py @@ -9,11 +9,10 @@ Type, TypeVar, Union, - get_type_hints, ) import msgspec -from typing_extensions import TypeAlias, TypedDict +from typing_extensions import TypeAlias, TypedDict, get_type_hints from litestar.exceptions import ImproperlyConfiguredException from litestar.types.builtin_types import NoneType @@ -55,7 +54,7 @@ def _create_partial_type_name(item: SupportedTypes) -> str: def _extract_type_hints(item: Any) -> tuple[tuple[str, Any], ...]: return tuple( (field_name, field_type) - for field_name, field_type in get_type_hints(item).items() + for field_name, field_type in get_type_hints(item, include_extras=True).items() if not is_class_var(field_type) ) diff --git a/litestar/plugins.py b/litestar/plugins.py index 9c68ca9cea..49c25de305 100644 --- a/litestar/plugins.py +++ b/litestar/plugins.py @@ -14,6 +14,7 @@ from litestar.openapi.spec import Schema from litestar.typing import FieldDefinition + __all__ = ("SerializationPluginProtocol", "InitPluginProtocol", "OpenAPISchemaPluginProtocol", "PluginProtocol") ModelT = TypeVar("ModelT") diff --git a/litestar/routes/base.py b/litestar/routes/base.py index 2769eeb9b7..8cc99b6707 100644 --- a/litestar/routes/base.py +++ b/litestar/routes/base.py @@ -2,18 +2,14 @@ import re from abc import ABC, abstractmethod -from datetime import date, datetime, time, timedelta +from datetime import date, datetime, time, timedelta, timezone from decimal import Decimal from pathlib import Path from typing import TYPE_CHECKING, Any, Callable from uuid import UUID -from pydantic.datetime_parse import ( - parse_date, - parse_datetime, - parse_duration, - parse_time, -) +from dateutil.parser import parse as parse_datetime +from pytimeparse.timeparse import timeparse as parse_time from litestar._kwargs import KwargsModel from litestar._signature import get_signature_model @@ -26,6 +22,34 @@ from litestar.handlers.base import BaseRouteHandler from litestar.types import Method, Receive, Scope, Send + +def _parse_datetime(value: str) -> datetime: + try: + return datetime.fromtimestamp(float(value), tz=timezone.utc) + except (ValueError, TypeError): + return parse_datetime(value) + + +def _parse_date(value: str) -> date: + dt = _parse_datetime(value=value) + return date(year=dt.year, month=dt.month, day=dt.day) + + +def _parse_time(value: str) -> time: + try: + return time.fromisoformat(value) + except ValueError: + dt = _parse_datetime(value) + return time(hour=dt.hour, minute=dt.minute, second=dt.second, microsecond=dt.microsecond, tzinfo=dt.tzinfo) + + +def _parse_timedelta(value: str) -> timedelta: + try: + return timedelta(seconds=int(float(value))) + except ValueError: + return timedelta(seconds=parse_time(value)) # pyright: ignore + + param_match_regex = re.compile(r"{(.*?)}") param_type_map = { @@ -47,10 +71,10 @@ int: int, Decimal: Decimal, UUID: UUID, - date: parse_date, - datetime: parse_datetime, - time: parse_time, - timedelta: parse_duration, + date: _parse_date, + datetime: _parse_datetime, + time: _parse_time, + timedelta: _parse_timedelta, } diff --git a/litestar/serialization.py b/litestar/serialization.py index 3f57f56980..a9efae8d8a 100644 --- a/litestar/serialization.py +++ b/litestar/serialization.py @@ -1,6 +1,7 @@ from __future__ import annotations from collections import deque +from datetime import date, datetime, time from decimal import Decimal from functools import partial from ipaddress import ( @@ -13,29 +14,20 @@ ) from pathlib import Path, PurePath from re import Pattern -from typing import TYPE_CHECKING, Any, Callable, Mapping, TypeVar, overload +from typing import TYPE_CHECKING, Any, Callable, Mapping, TypeVar, cast, overload from uuid import UUID import msgspec from msgspec import ValidationError -from pydantic import ( - UUID1, - BaseModel, - ByteSize, - ConstrainedBytes, - ConstrainedDate, - NameEmail, - SecretField, - StrictBool, -) -from pydantic.color import Color -from pydantic.json import decimal_encoder from litestar.enums import MediaType from litestar.exceptions import SerializationException from litestar.types import Empty, Serializer +from litestar.utils import is_class_and_subclass, is_pydantic_model_class if TYPE_CHECKING: + from typing_extensions import TypeAlias + from litestar.types import TypeEncodersMap __all__ = ( @@ -47,53 +39,114 @@ "encode_json", "encode_msgpack", "get_serializer", + "ExtendedMsgSpecValidationError", ) T = TypeVar("T") +PYDANTIC_DECODERS: list[tuple[Callable[[Any], bool], Callable[[Any, Any], Any]]] = [] -def _enc_base_model(model: BaseModel) -> Any: - return model.dict() - - -def _enc_byte_size(bytes_: ByteSize) -> int: - return bytes_.real - - -def _enc_constrained_bytes(bytes_: ConstrainedBytes) -> str: - return bytes_.decode("utf-8") +class ExtendedMsgSpecValidationError(ValidationError): + def __init__(self, errors: list[dict[str, Any]]) -> None: + self.errors = errors + super().__init__(errors) -def _enc_constrained_date(date: ConstrainedDate) -> str: - return date.isoformat() +try: + import pydantic -def _enc_pattern(pattern: Pattern[str]) -> Any: - return pattern.pattern + PYDANTIC_ENCODERS: dict[Any, Callable[[Any], Any]] = { + pydantic.EmailStr: str, + pydantic.NameEmail: str, + pydantic.ByteSize: lambda val: val.real, + } + def _dec_pydantic(type_: type[pydantic.BaseModel], value: Any) -> pydantic.BaseModel: + try: + return ( + type_.model_validate(value, strict=False) + if hasattr(type_, "model_validate") + else type_.parse_obj(value) + ) + except pydantic.ValidationError as e: + raise ExtendedMsgSpecValidationError(errors=cast(list[dict[str, Any]], e.errors())) from e + + PYDANTIC_DECODERS.append((is_pydantic_model_class, _dec_pydantic)) + + if pydantic.VERSION.startswith("1"): # pragma: no cover + PYDANTIC_ENCODERS.update( + { + pydantic.BaseModel: lambda model: model.dict(), + pydantic.SecretField: str, + pydantic.StrictBool: int, + pydantic.color.Color: str, # pyright: ignore + pydantic.ConstrainedBytes: lambda val: val.decode("utf-8"), + pydantic.ConstrainedDate: lambda val: val.isoformat(), + } + ) + + PydanticUUIDType: TypeAlias = ( + "type[pydantic.UUID1] | type[pydantic.UUID3] | type[pydantic.UUID4] | type[pydantic.UUID5]" + ) + + def _dec_pydantic_uuid(type_: PydanticUUIDType, val: Any) -> PydanticUUIDType: + if isinstance(val, str): + val = type_(val) + + elif isinstance(val, (bytes, bytearray)): + try: + val = type_(val.decode()) + except ValueError: + # 16 bytes in big-endian order as the bytes argument fail + # the above check + val = type_(bytes=val) + elif isinstance(val, UUID): + val = type_(str(val)) + + if not isinstance(val, type_): + raise ValidationError(f"Invalid UUID: {val!r}") + + if type_._required_version != val.version: # type: ignore + raise ValidationError(f"Invalid UUID version: {val!r}") + + return cast("PydanticUUIDType", val) + + def _is_pydantic_uuid(value: Any) -> bool: + return is_class_and_subclass(value, (pydantic.UUID1, pydantic.UUID3, pydantic.UUID4, pydantic.UUID5)) + + PYDANTIC_DECODERS.append((_is_pydantic_uuid, _dec_pydantic_uuid)) + else: + from pydantic_extra_types import color + + PYDANTIC_ENCODERS.update( + { + pydantic.BaseModel: lambda model: model.model_dump(mode="json"), + color.Color: str, + pydantic.types.SecretStr: lambda val: "**********" if val else "", + pydantic.types.SecretBytes: lambda val: "**********" if val else "", + } + ) + + +except ImportError: + PYDANTIC_ENCODERS = {} DEFAULT_TYPE_ENCODERS: TypeEncodersMap = { Path: str, PurePath: str, - # pydantic specific types - BaseModel: _enc_base_model, - ByteSize: _enc_byte_size, - NameEmail: str, - Color: str, - SecretField: str, - ConstrainedBytes: _enc_constrained_bytes, - ConstrainedDate: _enc_constrained_date, IPv4Address: str, IPv4Interface: str, IPv4Network: str, IPv6Address: str, IPv6Interface: str, IPv6Network: str, - # pydantic compatibility + datetime: lambda val: val.isoformat(), + date: lambda val: val.isoformat(), + time: lambda val: val.isoformat(), deque: list, - Decimal: decimal_encoder, - StrictBool: int, - Pattern: _enc_pattern, + Decimal: lambda val: int(val) if val.as_tuple().exponent >= 0 else float(val), + Pattern: lambda val: val.pattern, # support subclasses of stdlib types, If no previous type matched, these will be # the last type in the mro, so we use this to (attempt to) convert a subclass into # its base class. # see https://github.com/jcrist/msgspec/issues/248 @@ -103,6 +156,8 @@ def _enc_pattern(pattern: Pattern[str]) -> Any: float: float, set: set, frozenset: frozenset, + bytes: bytes, + **PYDANTIC_ENCODERS, } @@ -119,42 +174,15 @@ def default_serializer(value: Any, type_encoders: Mapping[Any, Callable[[Any], A """ if type_encoders is None: type_encoders = DEFAULT_TYPE_ENCODERS + for base in value.__class__.__mro__[:-1]: try: encoder = type_encoders[base] + return encoder(value) except KeyError: continue - return encoder(value) - raise TypeError(f"Unsupported type: {type(value)!r}") - - -PydanticUUIDType = TypeVar("PydanticUUIDType", bound="UUID1") - -def _dec_pydantic_uuid(type_: type[PydanticUUIDType], val: Any) -> PydanticUUIDType: - if isinstance(val, str): - val = type_(val) - elif isinstance(val, (bytes, bytearray)): - try: - val = type_(val.decode()) - except ValueError: - # 16 bytes in big-endian order as the bytes argument fail - # the above check - val = type_(bytes=val) - elif isinstance(val, UUID): - val = type_(str(val)) - - if not isinstance(val, type_): - raise ValidationError(f"Invalid UUID: {val!r}") - - if type_._required_version != val.version: # type:ignore[attr-defined] - raise ValidationError(f"Invalid UUID version: {val!r}") - - return val - - -def _dec_pydantic(type_: type[BaseModel], value: Any) -> BaseModel: - return type_.parse_obj(value) + raise TypeError(f"Unsupported type: {type(value)!r}") def dec_hook(type_: Any, value: Any) -> Any: # pragma: no cover @@ -170,14 +198,16 @@ def dec_hook(type_: Any, value: Any) -> Any: # pragma: no cover from litestar.datastructures.state import ImmutableState - if issubclass(type_, UUID1): - return _dec_pydantic_uuid(type_, value) if isinstance(value, type_): return value - if issubclass(type_, BaseModel): - return _dec_pydantic(type_, value) + + for predicate, decoder in PYDANTIC_DECODERS: + if predicate(type_): + return decoder(type_, value) + if issubclass(type_, (Path, PurePath, ImmutableState, UUID)): return type_(value) + raise TypeError(f"Unsupported type: {type(value)!r}") @@ -187,7 +217,7 @@ def dec_hook(type_: Any, value: Any) -> Any: # pragma: no cover _msgspec_msgpack_decoder = msgspec.msgpack.Decoder(dec_hook=dec_hook) -def encode_json(obj: Any, default: Callable[[Any], Any] | None = default_serializer) -> bytes: +def encode_json(obj: Any, default: Callable[[Any], Any] | None = None) -> bytes: """Encode a value into JSON. Args: @@ -201,9 +231,7 @@ def encode_json(obj: Any, default: Callable[[Any], Any] | None = default_seriali SerializationException: If error encoding ``obj``. """ try: - if default is None or default is default_serializer: - return _msgspec_json_encoder.encode(obj) - return msgspec.json.encode(obj, enc_hook=default) + return msgspec.json.encode(obj, enc_hook=default) if default else _msgspec_json_encoder.encode(obj) except (TypeError, msgspec.EncodeError) as msgspec_error: raise SerializationException(str(msgspec_error)) from msgspec_error diff --git a/litestar/types/serialization.py b/litestar/types/serialization.py index b9af13a3be..0ffe888da5 100644 --- a/litestar/types/serialization.py +++ b/litestar/types/serialization.py @@ -25,10 +25,7 @@ from pydantic import ( BaseModel, ByteSize, - ConstrainedBytes, - ConstrainedDate, NameEmail, - SecretField, StrictBool, ) from pydantic.color import Color @@ -36,6 +33,18 @@ from litestar.types import DataclassProtocol + try: # pragma: no cover + # pydantic v1 only values + from pydantic import ( + ConstrainedBytes, + ConstrainedDate, + SecretField, + ) + except ImportError: + ConstrainedBytes = BaseModel + ConstrainedDate = BaseModel + SecretField = BaseModel + __all__ = ( "LitestarEncodableType", "EncodableBuiltinType", @@ -55,8 +64,6 @@ "IPv4Address | IPv4Interface | IPv4Network | IPv6Address | IPv6Interface | IPv6Network" ) EncodableMsgSpecType: TypeAlias = "Ext | Raw | Struct" -EncodablePydanticType: TypeAlias = ( - "BaseModel | ByteSize | ConstrainedBytes | ConstrainedDate | NameEmail | SecretField | StrictBool | Color" -) +EncodablePydanticType: TypeAlias = "BaseModel | ByteSize | ConstrainedBytes | ConstrainedDate | NameEmail | SecretField | StrictBool | Color" # type: ignore LitestarEncodableType: TypeAlias = "EncodableBuiltinType | EncodableBuiltinCollectionType | EncodableStdLibType | EncodableStdLibIPType | EncodableMsgSpecType | EncodablePydanticType" diff --git a/litestar/typing.py b/litestar/typing.py index 5a5ac6cff5..34ca5ed901 100644 --- a/litestar/typing.py +++ b/litestar/typing.py @@ -1,11 +1,10 @@ from __future__ import annotations -from collections import abc +from collections import abc, deque from dataclasses import dataclass, replace from inspect import Parameter, Signature from typing import Any, AnyStr, Collection, ForwardRef, Literal, Mapping, Sequence, TypeVar, cast -from pydantic.fields import FieldInfo from typing_extensions import Annotated, NotRequired, Required, get_args, get_origin from litestar.exceptions import ImproperlyConfiguredException @@ -29,6 +28,11 @@ unwrap_annotation, ) +try: + from pydantic.fields import FieldInfo +except ImportError: + FieldInfo = Empty # type: ignore + __all__ = ("FieldDefinition",) T = TypeVar("T", bound=KwargDefinition) @@ -64,7 +68,10 @@ def _parse_metadata(value: Any, is_sequence_container: bool, extra: dict[str, An Returns: A dictionary of constraints, which fulfill the kwargs of a KwargDefinition class. """ - extra = cast("dict[str, Any]", extra or getattr(value, "extra", None) or {}) + extra = { + **cast("dict[str, Any]", extra or getattr(value, "extra", None) or {}), + **(getattr(value, "json_schema_extra", None) or {}), + } if example := extra.pop("example", None): example_list = [Example(value=example)] elif examples := getattr(value, "examples", None): @@ -116,7 +123,13 @@ def _traverse_metadata( """ constraints: dict[str, Any] = {} for value in metadata: - if is_annotated_type(value) and (type_args := [v for v in get_args(value) if v is not None]): + if isinstance(value, (list, set, frozenset, deque)): + constraints.update( + _traverse_metadata( + metadata=cast("Sequence[Any]", value), is_sequence_container=is_sequence_container, extra=extra + ) + ) + elif is_annotated_type(value) and (type_args := [v for v in get_args(value) if v is not None]): # annotated values can be nested inside other annotated values # this behaviour is buggy in python 3.8, hence we need to guard here. if len(type_args) > 1: diff --git a/litestar/utils/predicates.py b/litestar/utils/predicates.py index 0537a2c1ee..7a71f4119a 100644 --- a/litestar/utils/predicates.py +++ b/litestar/utils/predicates.py @@ -73,6 +73,7 @@ "is_pydantic_constrained_field", "is_pydantic_model_class", "is_pydantic_model_instance", + "is_struct_class", "is_sync_or_async_generator", "is_typed_dict", "is_union", @@ -350,7 +351,7 @@ def is_pydantic_constrained_field( ) return any( - is_class_and_subclass(annotation, constrained_type) + is_class_and_subclass(annotation, constrained_type) # type: ignore[arg-type] for constrained_type in ( ConstrainedBytes, ConstrainedDate, diff --git a/poetry.lock b/poetry.lock index 5b0d6363ea..06e8d113b6 100644 --- a/poetry.lock +++ b/poetry.lock @@ -339,27 +339,22 @@ tzdata = ["tzdata"] [[package]] name = "beanie" -version = "1.20.0" +version = "1.19.2" description = "Asynchronous Python ODM for MongoDB" optional = false python-versions = ">=3.7,<4.0" files = [ - {file = "beanie-1.20.0-py3-none-any.whl", hash = "sha256:3212cfc20c1b30d5b4ae9d9dfc02eda0cbc09c26f23c7aeb079baa1d90889a57"}, - {file = "beanie-1.20.0.tar.gz", hash = "sha256:d83004a8330dab9055ea57a6247a324ed40011f9cdcd527aaf4c0dc9253e9d21"}, + {file = "beanie-1.19.2-py3-none-any.whl", hash = "sha256:fba803e954eff3f036db2236c1e02fe7afffe4330db2108d405ac820076e3ae1"}, + {file = "beanie-1.19.2.tar.gz", hash = "sha256:1894c984a9f129bce03e19a9cb52ad47002bb3b4ea1a6b7af3a8a65978a35b78"}, ] [package.dependencies] click = ">=7" lazy-model = ">=0.0.3" -motor = ">=2.5.0,<4.0.0" -pydantic = ">=1.10.0,<2.0.0" +motor = ">=2.5,<4.0" +pydantic = ">=1.10.0" toml = "*" -[package.extras] -doc = ["Markdown (>=3.3)", "Pygments (>=2.8.0)", "jinja2 (>=3.0.3)", "mkdocs (>=1.4)", "mkdocs-material (>=9.0)", "pydoc-markdown (==4.6)"] -queue = ["beanie-batteries-queue (>=0.2)"] -test = ["asgi-lifespan (>=1.0.1)", "dnspython (>=2.1.0)", "fastapi (>=0.78.0)", "flake8 (>=3)", "httpx (>=0.23.0)", "pre-commit (>=2.3.0)", "pyright (>=0)", "pytest (>=6.0.0)", "pytest-asyncio (>=0.21.0)", "pytest-cov (>=2.8.1)"] - [[package]] name = "beautifulsoup4" version = "4.12.2" @@ -2505,13 +2500,13 @@ files = [ [[package]] name = "piccolo" -version = "0.118.0" +version = "0.116.0" description = "A fast, user friendly ORM and query builder which supports asyncio." optional = false -python-versions = ">=3.8.0" +python-versions = ">=3.7.0" files = [ - {file = "piccolo-0.118.0-py3-none-any.whl", hash = "sha256:2ca856f76cf591705cb1834b3d073068736775c280ffdd700cf1fa765f61e55a"}, - {file = "piccolo-0.118.0.tar.gz", hash = "sha256:67ebc76b439dd25472fa024b0373bc42e5676612caf918906bf0ee5819b074d3"}, + {file = "piccolo-0.116.0-py3-none-any.whl", hash = "sha256:c32d50425c283adaa3df3c57fa08df83a076babecf71b46d23ab85c5f29bfadd"}, + {file = "piccolo-0.116.0.tar.gz", hash = "sha256:e41440a89dc3b7a5706f71497375f620279de0845789fdb9a962b690451d760b"}, ] [package.dependencies] @@ -2519,7 +2514,7 @@ black = "*" colorama = ">=0.4.0" inflection = ">=0.5.1" Jinja2 = ">=2.11.0" -pydantic = {version = ">=1.6,<2.0", extras = ["email"]} +pydantic = {version = ">=1.6", extras = ["email"]} targ = ">=0.3.7" typing-extensions = ">=4.3.0" @@ -2762,56 +2757,153 @@ files = [ [[package]] name = "pydantic" -version = "1.10.11" -description = "Data validation and settings management using python type hints" +version = "2.0.3" +description = "Data validation using Python type hints" optional = false python-versions = ">=3.7" files = [ - {file = "pydantic-1.10.11-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:ff44c5e89315b15ff1f7fdaf9853770b810936d6b01a7bcecaa227d2f8fe444f"}, - {file = "pydantic-1.10.11-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:a6c098d4ab5e2d5b3984d3cb2527e2d6099d3de85630c8934efcfdc348a9760e"}, - {file = "pydantic-1.10.11-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:16928fdc9cb273c6af00d9d5045434c39afba5f42325fb990add2c241402d151"}, - {file = "pydantic-1.10.11-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0588788a9a85f3e5e9ebca14211a496409cb3deca5b6971ff37c556d581854e7"}, - {file = "pydantic-1.10.11-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:e9baf78b31da2dc3d3f346ef18e58ec5f12f5aaa17ac517e2ffd026a92a87588"}, - {file = "pydantic-1.10.11-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:373c0840f5c2b5b1ccadd9286782852b901055998136287828731868027a724f"}, - {file = "pydantic-1.10.11-cp310-cp310-win_amd64.whl", hash = "sha256:c3339a46bbe6013ef7bdd2844679bfe500347ac5742cd4019a88312aa58a9847"}, - {file = "pydantic-1.10.11-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:08a6c32e1c3809fbc49debb96bf833164f3438b3696abf0fbeceb417d123e6eb"}, - {file = "pydantic-1.10.11-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:a451ccab49971af043ec4e0d207cbc8cbe53dbf148ef9f19599024076fe9c25b"}, - {file = "pydantic-1.10.11-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5b02d24f7b2b365fed586ed73582c20f353a4c50e4be9ba2c57ab96f8091ddae"}, - {file = "pydantic-1.10.11-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3f34739a89260dfa420aa3cbd069fbcc794b25bbe5c0a214f8fb29e363484b66"}, - {file = "pydantic-1.10.11-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:e297897eb4bebde985f72a46a7552a7556a3dd11e7f76acda0c1093e3dbcf216"}, - {file = "pydantic-1.10.11-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:d185819a7a059550ecb85d5134e7d40f2565f3dd94cfd870132c5f91a89cf58c"}, - {file = "pydantic-1.10.11-cp311-cp311-win_amd64.whl", hash = "sha256:4400015f15c9b464c9db2d5d951b6a780102cfa5870f2c036d37c23b56f7fc1b"}, - {file = "pydantic-1.10.11-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:2417de68290434461a266271fc57274a138510dca19982336639484c73a07af6"}, - {file = "pydantic-1.10.11-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:331c031ba1554b974c98679bd0780d89670d6fd6f53f5d70b10bdc9addee1713"}, - {file = "pydantic-1.10.11-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8268a735a14c308923e8958363e3a3404f6834bb98c11f5ab43251a4e410170c"}, - {file = "pydantic-1.10.11-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:44e51ba599c3ef227e168424e220cd3e544288c57829520dc90ea9cb190c3248"}, - {file = "pydantic-1.10.11-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:d7781f1d13b19700b7949c5a639c764a077cbbdd4322ed505b449d3ca8edcb36"}, - {file = "pydantic-1.10.11-cp37-cp37m-win_amd64.whl", hash = "sha256:7522a7666157aa22b812ce14c827574ddccc94f361237ca6ea8bb0d5c38f1629"}, - {file = "pydantic-1.10.11-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:bc64eab9b19cd794a380179ac0e6752335e9555d214cfcb755820333c0784cb3"}, - {file = "pydantic-1.10.11-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:8dc77064471780262b6a68fe67e013298d130414d5aaf9b562c33987dbd2cf4f"}, - {file = "pydantic-1.10.11-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fe429898f2c9dd209bd0632a606bddc06f8bce081bbd03d1c775a45886e2c1cb"}, - {file = "pydantic-1.10.11-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:192c608ad002a748e4a0bed2ddbcd98f9b56df50a7c24d9a931a8c5dd053bd3d"}, - {file = "pydantic-1.10.11-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:ef55392ec4bb5721f4ded1096241e4b7151ba6d50a50a80a2526c854f42e6a2f"}, - {file = "pydantic-1.10.11-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:41e0bb6efe86281623abbeeb0be64eab740c865388ee934cd3e6a358784aca6e"}, - {file = "pydantic-1.10.11-cp38-cp38-win_amd64.whl", hash = "sha256:265a60da42f9f27e0b1014eab8acd3e53bd0bad5c5b4884e98a55f8f596b2c19"}, - {file = "pydantic-1.10.11-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:469adf96c8e2c2bbfa655fc7735a2a82f4c543d9fee97bd113a7fb509bf5e622"}, - {file = "pydantic-1.10.11-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:e6cbfbd010b14c8a905a7b10f9fe090068d1744d46f9e0c021db28daeb8b6de1"}, - {file = "pydantic-1.10.11-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:abade85268cc92dff86d6effcd917893130f0ff516f3d637f50dadc22ae93999"}, - {file = "pydantic-1.10.11-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e9738b0f2e6c70f44ee0de53f2089d6002b10c33264abee07bdb5c7f03038303"}, - {file = "pydantic-1.10.11-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:787cf23e5a0cde753f2eabac1b2e73ae3844eb873fd1f5bdbff3048d8dbb7604"}, - {file = "pydantic-1.10.11-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:174899023337b9fc685ac8adaa7b047050616136ccd30e9070627c1aaab53a13"}, - {file = "pydantic-1.10.11-cp39-cp39-win_amd64.whl", hash = "sha256:1954f8778489a04b245a1e7b8b22a9d3ea8ef49337285693cf6959e4b757535e"}, - {file = "pydantic-1.10.11-py3-none-any.whl", hash = "sha256:008c5e266c8aada206d0627a011504e14268a62091450210eda7c07fabe6963e"}, - {file = "pydantic-1.10.11.tar.gz", hash = "sha256:f66d479cf7eb331372c470614be6511eae96f1f120344c25f3f9bb59fb1b5528"}, + {file = "pydantic-2.0.3-py3-none-any.whl", hash = "sha256:614eb3321eb600c81899a88fa9858b008e3c79e0d4f1b49ab1f516b4b0c27cfb"}, + {file = "pydantic-2.0.3.tar.gz", hash = "sha256:94f13e0dcf139a5125e88283fc999788d894e14ed90cf478bcc2ee50bd4fc630"}, ] [package.dependencies] -email-validator = {version = ">=1.0.3", optional = true, markers = "extra == \"email\""} -typing-extensions = ">=4.2.0" +annotated-types = ">=0.4.0" +email-validator = {version = ">=2.0.0", optional = true, markers = "extra == \"email\""} +pydantic-core = "2.3.0" +typing-extensions = ">=4.6.1" + +[package.extras] +email = ["email-validator (>=2.0.0)"] + +[[package]] +name = "pydantic-core" +version = "2.3.0" +description = "" +optional = false +python-versions = ">=3.7" +files = [ + {file = "pydantic_core-2.3.0-cp310-cp310-macosx_10_7_x86_64.whl", hash = "sha256:4542c98b8364b976593703a2dda97377433b102f380b61bc3a2cbc2fbdae1d1f"}, + {file = "pydantic_core-2.3.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:9342de50824b40f55d2600f66c6f9a91a3a24851eca39145a749a3dc804ee599"}, + {file = "pydantic_core-2.3.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:539432f911686cb80284c30b33eaf9f4fd9a11e1111fe0dc98fdbdce69b49821"}, + {file = "pydantic_core-2.3.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:38a0e7ee65c8999394d92d9c724434cb629279d19844f2b69d9bbc46dc8b8b61"}, + {file = "pydantic_core-2.3.0-cp310-cp310-manylinux_2_24_armv7l.whl", hash = "sha256:e3ed6834cc005798187a56c248a2240207cb8ffdda1c89e9afda4c3d526c2ea0"}, + {file = "pydantic_core-2.3.0-cp310-cp310-manylinux_2_24_ppc64le.whl", hash = "sha256:e72ac299a6bf732a60852d052acf3999d234686755a02ba111e85e7ebf8155b1"}, + {file = "pydantic_core-2.3.0-cp310-cp310-manylinux_2_24_s390x.whl", hash = "sha256:616b3451b05ca63b8f433c627f68046b39543faeaa4e50d8c6699a2a1e4b85a5"}, + {file = "pydantic_core-2.3.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:adcb9c8848e15c613e483e0b99767ae325af27fe0dbd866df01fe5849d06e6e1"}, + {file = "pydantic_core-2.3.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:464bf799b422be662e5e562e62beeffc9eaa907d381a9d63a2556615bbda286d"}, + {file = "pydantic_core-2.3.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:4638ebc17de08c2f3acba557efeb6f195c88b7299d8c55c0bb4e20638bbd4d03"}, + {file = "pydantic_core-2.3.0-cp310-none-win32.whl", hash = "sha256:9ff322c7e1030543d35d83bb521b69114d3d150750528d7757544f639def9ad6"}, + {file = "pydantic_core-2.3.0-cp310-none-win_amd64.whl", hash = "sha256:4824eb018f0a4680b1e434697a9bf3f41c7799b80076d06530cbbd212e040ccc"}, + {file = "pydantic_core-2.3.0-cp311-cp311-macosx_10_7_x86_64.whl", hash = "sha256:0aa429578e23885b3984c49d687cd05ab06f0b908ea1711a8bf7e503b7f97160"}, + {file = "pydantic_core-2.3.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:20d710c1f79af930b8891bcebd84096798e4387ab64023ef41521d58f21277d3"}, + {file = "pydantic_core-2.3.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:309f45d4d7481d6f09cb9e35c72caa0e50add4a30bb08c04c5fe5956a0158633"}, + {file = "pydantic_core-2.3.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1bcfb7be905aa849bd882262e1df3f75b564e2f708b4b4c7ad2d3deaf5410562"}, + {file = "pydantic_core-2.3.0-cp311-cp311-manylinux_2_24_armv7l.whl", hash = "sha256:85cd9c0af34e371390e3cb2f3a470b0b40cc07568c1e966c638c49062be6352d"}, + {file = "pydantic_core-2.3.0-cp311-cp311-manylinux_2_24_ppc64le.whl", hash = "sha256:37c5028cebdf731298724070838fb3a71ef1fbd201d193d311ac2cbdbca25a23"}, + {file = "pydantic_core-2.3.0-cp311-cp311-manylinux_2_24_s390x.whl", hash = "sha256:e4208f23f12d0ad206a07a489ef4cb15722c10b62774c4460ee4123250be938e"}, + {file = "pydantic_core-2.3.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:c24465dd11b65c8510f251b095fc788c7c91481c81840112fe3f76c30793a455"}, + {file = "pydantic_core-2.3.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:3cd7ee8bbfab277ab56e272221886fd33a1b5943fbf45ae9195aa6a48715a8a0"}, + {file = "pydantic_core-2.3.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:0fc7e0b056b66cc536e97ef60f48b3b289f6b3b62ac225afd4b22a42434617bf"}, + {file = "pydantic_core-2.3.0-cp311-none-win32.whl", hash = "sha256:4788135db4bd83a5edc3522b11544b013be7d25b74b155e08dd3b20cd6663bbb"}, + {file = "pydantic_core-2.3.0-cp311-none-win_amd64.whl", hash = "sha256:f93c867e5e85584a28c6a6feb6f2086d717266eb5d1210d096dd717b7f4dec04"}, + {file = "pydantic_core-2.3.0-cp312-cp312-macosx_10_7_x86_64.whl", hash = "sha256:73f62bb7fd862d9bcd886e10612bade6fe042eda8b47e8c129892bcfb7b45e84"}, + {file = "pydantic_core-2.3.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:4d889d498fce64bfcd8adf1a78579a7f626f825cbeb2956a24a29b35f9a1df32"}, + {file = "pydantic_core-2.3.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7d55e38a89ec2ae17b2fa7ffeda6b70f63afab1888bd0d57aaa7b7879760acb4"}, + {file = "pydantic_core-2.3.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1aefebb506bc1fe355d91d25f12bcdea7f4d7c2d9f0f6716dd025543777c99a5"}, + {file = "pydantic_core-2.3.0-cp312-cp312-manylinux_2_24_armv7l.whl", hash = "sha256:6441a29f42585f085db0c04cd0557d4cbbb46fa68a0972409b1cfe9f430280c1"}, + {file = "pydantic_core-2.3.0-cp312-cp312-manylinux_2_24_ppc64le.whl", hash = "sha256:47e8f034be31390a8f525431eb5e803a78ce7e2e11b32abf5361a972e14e6b61"}, + {file = "pydantic_core-2.3.0-cp312-cp312-manylinux_2_24_s390x.whl", hash = "sha256:ad814864aba263be9c83ada44a95f72d10caabbf91589321f95c29c902bdcff0"}, + {file = "pydantic_core-2.3.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:9eff3837d447fccf2ac38c259b14ab9cbde700df355a45a1f3ff244d5e78f8b6"}, + {file = "pydantic_core-2.3.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:534f3f63c000f08050c6f7f4378bf2b52d7ba9214e9d35e3f60f7ad24a4d6425"}, + {file = "pydantic_core-2.3.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:ef6a222d54f742c24f6b143aab088702db3a827b224e75b9dd28b38597c595fe"}, + {file = "pydantic_core-2.3.0-cp312-none-win32.whl", hash = "sha256:4e26944e64ecc1d7b19db954c0f7b471f3b141ec8e1a9f57cfe27671525cd248"}, + {file = "pydantic_core-2.3.0-cp312-none-win_amd64.whl", hash = "sha256:019c5c41941438570dfc7d3f0ae389b2425add1775a357ce1e83ed1434f943d6"}, + {file = "pydantic_core-2.3.0-cp37-cp37m-macosx_10_7_x86_64.whl", hash = "sha256:27c1bbfb9d84a75cf33b7f19b53c29eb7ead99b235fce52aced5507174ab8f98"}, + {file = "pydantic_core-2.3.0-cp37-cp37m-macosx_11_0_arm64.whl", hash = "sha256:7cb496e934b71f1ade844ab91d6ccac78a3520e5df02fdb2357f85a71e541e69"}, + {file = "pydantic_core-2.3.0-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5af2d43b1978958d91351afbcc9b4d0cfe144c46c61740e82aaac8bb39ab1a4d"}, + {file = "pydantic_core-2.3.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4d3097c39d7d4e8dba2ef86de171dcccad876c36d8379415ba18a5a4d0533510"}, + {file = "pydantic_core-2.3.0-cp37-cp37m-manylinux_2_24_armv7l.whl", hash = "sha256:dd3b023f3317dbbbc775e43651ce1a31a9cea46216ad0b5be37afc18a2007699"}, + {file = "pydantic_core-2.3.0-cp37-cp37m-manylinux_2_24_ppc64le.whl", hash = "sha256:27babb9879bf2c45ed655d02639f4c30e2b9ef1b71ce59c2305bbf7287910a18"}, + {file = "pydantic_core-2.3.0-cp37-cp37m-manylinux_2_24_s390x.whl", hash = "sha256:2183a9e18cdc0de53bdaa1675f237259162abeb62d6ac9e527c359c1074dc55d"}, + {file = "pydantic_core-2.3.0-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:c089d8e7f1b4db08b2f8e4107304eec338df046275dad432635a9be9531e2fc8"}, + {file = "pydantic_core-2.3.0-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:2f10aa5452b865818dd0137f568d443f5e93b60a27080a01aa4b7512c7ba13a3"}, + {file = "pydantic_core-2.3.0-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:f642313d559f9d9a00c4de6820124059cc3342a0d0127b18301de2c680d5ea40"}, + {file = "pydantic_core-2.3.0-cp37-none-win32.whl", hash = "sha256:45327fc57afbe3f2c3d7f54a335d5cecee8a9fdb3906a2fbed8af4092f4926df"}, + {file = "pydantic_core-2.3.0-cp37-none-win_amd64.whl", hash = "sha256:e427b66596a6441a5607dfc0085b47d36073f88da7ac48afd284263b9b99e6ce"}, + {file = "pydantic_core-2.3.0-cp38-cp38-macosx_10_7_x86_64.whl", hash = "sha256:0b3d781c71b8bfb621ef23b9c874933e2cd33237c1a65cc20eeb37437f8e7e18"}, + {file = "pydantic_core-2.3.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:ad46027dbd5c1db87dc0b49becbe23093b143a20302028d387dae37ee5ef95f5"}, + {file = "pydantic_core-2.3.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:39aa09ed7ce2a648c904f79032d16dda29e6913112af8465a7bf710eef23c7ca"}, + {file = "pydantic_core-2.3.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:05b4bf8c58409586a7a04c858a86ab10f28c6c1a7c33da65e0326c59d5b0ab16"}, + {file = "pydantic_core-2.3.0-cp38-cp38-manylinux_2_24_armv7l.whl", hash = "sha256:ba2b807d2b62c446120906b8580cddae1d76d3de4efbb95ccc87f5e35c75b4b2"}, + {file = "pydantic_core-2.3.0-cp38-cp38-manylinux_2_24_ppc64le.whl", hash = "sha256:ea955e4ed21f4bbb9b83fea09fc6af0bed82e69ecf6b35ec89237a0a49633033"}, + {file = "pydantic_core-2.3.0-cp38-cp38-manylinux_2_24_s390x.whl", hash = "sha256:06884c07956526ac9ebfef40fe21a11605569b8fc0e2054a375fb39c978bf48f"}, + {file = "pydantic_core-2.3.0-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:f868e731a18b403b88aa434d960489ceeed0ddeb44ebc02389540731a67705e0"}, + {file = "pydantic_core-2.3.0-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:cb08fab0fc1db15c277b72e33ac74ad9c0c789413da8984a3eacb22a94b42ef4"}, + {file = "pydantic_core-2.3.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:6ca34c29fbd6592de5fd39e80c1993634d704c4e7e14ba54c87b2c7c53da68fe"}, + {file = "pydantic_core-2.3.0-cp38-none-win32.whl", hash = "sha256:cd782807d35c8a41aaa7d30b5107784420eefd9fdc1c760d86007d43ae00b15d"}, + {file = "pydantic_core-2.3.0-cp38-none-win_amd64.whl", hash = "sha256:01f56d5ee70b1d39c0fd08372cc5142274070ab7181d17c86035f130eebc05b8"}, + {file = "pydantic_core-2.3.0-cp39-cp39-macosx_10_7_x86_64.whl", hash = "sha256:78b1ac0151271ce62bc2b33755f1043eda6a310373143a2f27e2bcd3d5fc8633"}, + {file = "pydantic_core-2.3.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:64bfd2c35a2c350f73ac52dc134d8775f93359c4c969280a6fe5301b5b6e7431"}, + {file = "pydantic_core-2.3.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:937c0fe9538f1212b62df6a68f8d78df3572fe3682d9a0dd8851eac8a4e46063"}, + {file = "pydantic_core-2.3.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4d965c7c4b40d1cedec9188782e98bd576f9a04868835604200c3a6e817b824f"}, + {file = "pydantic_core-2.3.0-cp39-cp39-manylinux_2_24_armv7l.whl", hash = "sha256:ad442b8585ed4a3c2d22e4bf7b465d9b7d281e055b09719a8aeb5b576422dc9b"}, + {file = "pydantic_core-2.3.0-cp39-cp39-manylinux_2_24_ppc64le.whl", hash = "sha256:4bf20c9722821fce766e685718e739deeccc60d6bc7be5029281db41f999ee0c"}, + {file = "pydantic_core-2.3.0-cp39-cp39-manylinux_2_24_s390x.whl", hash = "sha256:f3dd5333049b5b3faa739e0f40b77cc8b7a1aded2f2da0e28794c81586d7b08a"}, + {file = "pydantic_core-2.3.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:0dc5f516b24d24bc9e8dd9305460899f38302b3c4f9752663b396ef9848557bf"}, + {file = "pydantic_core-2.3.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:055f7ea6b1fbb37880d66d70eefd22dd319b09c79d2cb99b1dbfeb34b653b0b2"}, + {file = "pydantic_core-2.3.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:af693a89db6d6ac97dd84dd7769b3f2bd9007b578127d0e7dda03053f4d3b34b"}, + {file = "pydantic_core-2.3.0-cp39-none-win32.whl", hash = "sha256:f60e31e3e15e8c294bf70c60f8ae4d0c3caf3af8f26466e9aa8ea4c01302749b"}, + {file = "pydantic_core-2.3.0-cp39-none-win_amd64.whl", hash = "sha256:2b79f3681481f4424d7845cc7a261d5a4baa810d656b631fa844dc9967b36a7b"}, + {file = "pydantic_core-2.3.0-pp310-pypy310_pp73-macosx_10_7_x86_64.whl", hash = "sha256:a666134b41712e30a71afaa26deeb4da374179f769fa49784cdf0e7698880fab"}, + {file = "pydantic_core-2.3.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1c119e9227487ad3d7c3c737d896afe548a6be554091f9745da1f4b489c40561"}, + {file = "pydantic_core-2.3.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:73929a2fb600a2333fce2efd92596cff5e6bf8946e20e93c067b220760064862"}, + {file = "pydantic_core-2.3.0-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:41bbc2678a5b6a19371b2cb51f30ccea71f0c14b26477d2d884fed761cea42c7"}, + {file = "pydantic_core-2.3.0-pp310-pypy310_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:dcbff997f47d45bf028bda4c3036bb3101e89a3df271281d392b6175f71c71d1"}, + {file = "pydantic_core-2.3.0-pp310-pypy310_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:afa8808159169368b66e4fbeafac6c6fd8f26246dc4d0dcc2caf94bd9cf1b828"}, + {file = "pydantic_core-2.3.0-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:12be3b5f54f8111ca38e6b7277f26c23ba5cb3344fae06f879a0a93dfc8b479e"}, + {file = "pydantic_core-2.3.0-pp37-pypy37_pp73-macosx_10_7_x86_64.whl", hash = "sha256:ed5babdcd3d052ba5cf8832561f18df20778c7ccf12587b2d82f7bf3bf259a0e"}, + {file = "pydantic_core-2.3.0-pp37-pypy37_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3d642e5c029e2acfacf6aa0a7a3e822086b3b777c70d364742561f9ca64c1ffc"}, + {file = "pydantic_core-2.3.0-pp37-pypy37_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8ba3073eb38a1294e8c7902989fb80a7a147a69db2396818722bd078476586a0"}, + {file = "pydantic_core-2.3.0-pp37-pypy37_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:d5146a6749b1905e04e62e0ad4622f079e5582f8b3abef5fb64516c623127908"}, + {file = "pydantic_core-2.3.0-pp37-pypy37_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:deeb64335f489c3c11949cbd1d1668b3f1fb2d1c6a5bf40e126ef7bf95f9fa40"}, + {file = "pydantic_core-2.3.0-pp37-pypy37_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:31acc37288b8e69e4849f618c3d5cf13b58077c1a1ff9ade0b3065ba974cd385"}, + {file = "pydantic_core-2.3.0-pp37-pypy37_pp73-win_amd64.whl", hash = "sha256:e09d9f6d722de9d4c1c5f122ea9bc6b25a05f975457805af4dcab7b0128aacbf"}, + {file = "pydantic_core-2.3.0-pp38-pypy38_pp73-macosx_10_7_x86_64.whl", hash = "sha256:ba6a8cf089222a171b8f84e6ec2d10f7a9d14f26be3a347b14775a8741810676"}, + {file = "pydantic_core-2.3.0-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ef1fd1b24e9bcddcb168437686677104e205c8e25b066e73ffdf331d3bb8792b"}, + {file = "pydantic_core-2.3.0-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:eda1a89c4526826c0a87d33596a4cd15b8f58e9250f503e39af1699ba9c878e8"}, + {file = "pydantic_core-2.3.0-pp38-pypy38_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:a3e9a18401a28db4358da2e191508702dbf065f2664c710708cdf9552b9fa50c"}, + {file = "pydantic_core-2.3.0-pp38-pypy38_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:a439fd0d45d51245bbde799726adda5bd18aed3fa2b01ab2e6a64d6d13776fa3"}, + {file = "pydantic_core-2.3.0-pp38-pypy38_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:bf6a1d2c920cc9528e884850a4b2ee7629e3d362d5c44c66526d4097bbb07a1a"}, + {file = "pydantic_core-2.3.0-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:e33fcbea3b63a339dd94de0fc442fefacfe681cc7027ce63f67af9f7ceec7422"}, + {file = "pydantic_core-2.3.0-pp39-pypy39_pp73-macosx_10_7_x86_64.whl", hash = "sha256:bf3ed993bdf4754909f175ff348cf8f78d4451215b8aa338633f149ca3b1f37a"}, + {file = "pydantic_core-2.3.0-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7584171eb3115acd4aba699bc836634783f5bd5aab131e88d8eeb8a3328a4a72"}, + {file = "pydantic_core-2.3.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1624baa76d1740711b2048f302ae9a6d73d277c55a8c3e88b53b773ebf73a971"}, + {file = "pydantic_core-2.3.0-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:06f33f695527f5a86e090f208978f9fd252c9cfc7e869d3b679bd71f7cb2c1fa"}, + {file = "pydantic_core-2.3.0-pp39-pypy39_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:7ecf0a67b212900e92f328181fed02840d74ed39553cdb38d27314e2b9c89dfa"}, + {file = "pydantic_core-2.3.0-pp39-pypy39_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:45fa1e8ad6f4367ad73674ca560da8e827cc890eaf371f3ee063d6d7366a207b"}, + {file = "pydantic_core-2.3.0-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:8d0dbcc57839831ae79fd24b1b83d42bc9448d79feaf3ed3fb5cbf94ffbf3eb7"}, + {file = "pydantic_core-2.3.0.tar.gz", hash = "sha256:5cfb5ac4e82c47d5dc25b209dd4c3989e284b80109f9e08b33c895080c424b4f"}, +] + +[package.dependencies] +typing-extensions = ">=4.6.0,<4.7.0 || >4.7.0" + +[[package]] +name = "pydantic-extra-types" +version = "2.0.0" +description = "Extra Pydantic types." +optional = false +python-versions = ">=3.7" +files = [ + {file = "pydantic_extra_types-2.0.0-py3-none-any.whl", hash = "sha256:63e5109f00815e71fff2b82090ff0523baef6b8a51889356fd984ef50c184e64"}, + {file = "pydantic_extra_types-2.0.0.tar.gz", hash = "sha256:137ddacb168d95ea77591dbb3739ec4da5eeac0fc4df7f797371d9904451a178"}, +] + +[package.dependencies] +pydantic = ">=2.0b3" [package.extras] -dotenv = ["python-dotenv (>=0.10.4)"] -email = ["email-validator (>=1.0.3)"] +all = ["phonenumbers (>=8,<9)", "pycountry (>=22,<23)"] [[package]] name = "pydata-sphinx-theme" @@ -4269,18 +4361,18 @@ files = [ [[package]] name = "zipp" -version = "3.16.1" +version = "3.16.2" description = "Backport of pathlib-compatible object wrapper for zip files" optional = false python-versions = ">=3.8" files = [ - {file = "zipp-3.16.1-py3-none-any.whl", hash = "sha256:0b37c326d826d5ca35f2b9685cd750292740774ef16190008b00a0227c256fe0"}, - {file = "zipp-3.16.1.tar.gz", hash = "sha256:857b158da2cbf427b376da1c24fd11faecbac5a4ac7523c3607f8a01f94c2ec0"}, + {file = "zipp-3.16.2-py3-none-any.whl", hash = "sha256:679e51dd4403591b2d6838a48de3d283f3d188412a9782faadf845f298736ba0"}, + {file = "zipp-3.16.2.tar.gz", hash = "sha256:ebc15946aa78bd63458992fc81ec3b6f7b1e92d51c35e6de1c3804e73b799147"}, ] [package.extras] docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"] -testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy (>=0.9.1)", "pytest-ruff"] +testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-ignore-flaky", "pytest-mypy (>=0.9.1)", "pytest-ruff"] [extras] annotated-types = ["annotated-types"] @@ -4303,4 +4395,4 @@ tortoise-orm = ["tortoise-orm"] [metadata] lock-version = "2.0" python-versions = ">=3.8,<4.0" -content-hash = "850899772834e4726e6f195e57c6d43a84bb1a324f8b0f801df8fbe86ef73893" +content-hash = "3a803a93020a303daa341b9f32f91d5f9be5c4353d3916c9b5b5987ce80db47f" diff --git a/pyproject.toml b/pyproject.toml index 13b665cd7a..4d0b6f6db7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -83,6 +83,7 @@ click = { version = "*", optional = true } cryptography = { version = "*", optional = true } fast-query-parsers = "*" httpx = ">=0.22" +pydantic-extra-types = { version = "*", optional = true } importlib-metadata = { version = "*", python = "<3.10" } importlib-resources = { version = ">=5.12.0", python = "<3.9" } jinja2 = { version = ">=3.1.2", optional = true } @@ -94,10 +95,10 @@ opentelemetry-instrumentation-asgi = { version = "*", optional = true } picologging = { version = "*", optional = true } polyfactory = ">=2.3.2" prometheus-client = { version = "*", optional = true } -pydantic = "<2" -python-dateutil = { version = "*", optional = true } +pydantic = "*" +python-dateutil = "*" python-jose = { version = "*", optional = true } -pytimeparse = { version = "*", optional = true } +pytimeparse = "*" pyyaml = "*" redis = { version = ">=4.4.4, <4.5.0", optional = true, extras = ["hiredis"] } rich = { version = ">=13.0.0", optional = true } @@ -138,6 +139,7 @@ picologging = "*" pre-commit = "*" prometheus-client = "*" psycopg = "*" +pydantic = ">=2" pytest = "*" pytest-asyncio = "*" pytest-cov = "*" @@ -158,6 +160,7 @@ structlog = "*" tortoise-orm = "*" trio = "*" uvicorn = "*" +pydantic-extra-types = "*" [tool.poetry.group.docs] optional = true @@ -402,30 +405,6 @@ known-first-party = ["litestar", "tests", "examples"] "tests/unit/test_contrib/test_sqlalchemy/**/*.*" = ["UP006"] "docs/examples/application_hooks/before_send_hook.py" = ["UP006"] "docs/examples/contrib/sqlalchemy/plugins/**/*.*" = ["UP006"] -"docs/examples/tests/**/*.*" = [ - "A", - "ARG", - "B", - "BLE", - "C901", - "D", - "DTZ", - "EM", - "FBT", - "G", - "N", - "PGH", - "PIE", - "PLR", - "PLW", - "PTH", - "RSE", - "S", - "S101", - "SIM", - "TCH", - "TRY", -] "docs/**/*.*" = ["S", "B", "DTZ", "A", "TCH", "ERA", "D", "RET"] "docs/examples/**" = ["T201"] "docs/examples/data_transfer_objects**/*.*" = ["UP006"] diff --git a/tests/e2e/test_dependency_injection/test_injection_of_generic_models.py b/tests/e2e/test_dependency_injection/test_injection_of_generic_models.py index a1b0b978cf..85c6e92bf4 100644 --- a/tests/e2e/test_dependency_injection/test_injection_of_generic_models.py +++ b/tests/e2e/test_dependency_injection/test_injection_of_generic_models.py @@ -1,17 +1,21 @@ from typing import Generic, Optional, Type, TypeVar -from pydantic import BaseModel -from pydantic.generics import GenericModel +from pydantic import VERSION, BaseModel from litestar import get from litestar.di import Provide from litestar.status_codes import HTTP_200_OK from litestar.testing import create_test_client +if VERSION.startswith("1"): + from pydantic.generics import GenericModel +else: + GenericModel = BaseModel + T = TypeVar("T") -class Store(GenericModel, Generic[T]): +class Store(GenericModel, Generic[T]): # type: ignore[misc] """Abstract store.""" model: Type[T] @@ -32,7 +36,7 @@ def get(self, value_id: str) -> Optional[Item]: async def get_item_store() -> DictStore: - return DictStore(model=Item) # type: ignore + return DictStore(model=Item) def test_generic_model_injection() -> None: diff --git a/tests/examples/test_request_data.py b/tests/examples/test_request_data.py index 3fae11a7fa..f86d20a94b 100644 --- a/tests/examples/test_request_data.py +++ b/tests/examples/test_request_data.py @@ -16,35 +16,35 @@ def test_request_data_1() -> None: with TestClient(app=app) as client: - res = client.post("/", json={"hello": "world"}) - assert res.status_code == 201 - assert res.json() == {"hello": "world"} + response = client.post("/", json={"hello": "world"}) + assert response.status_code == 201 + assert response.json() == {"hello": "world"} def test_request_data_2() -> None: with TestClient(app=app_2) as client: - res = client.post("/", json={"id": 1, "name": "John"}) - assert res.status_code == 201 - assert res.json() == {"id": 1, "name": "John"} + response = client.post("/", json={"id": 1, "name": "John"}) + assert response.status_code == 201 + assert response.json() == {"id": 1, "name": "John"} def test_request_data_3() -> None: with TestClient(app=app_3) as client: - res = client.post("/", json={"id": 1, "name": "John"}) - assert res.status_code == 201 - assert res.json() == {"id": 1, "name": "John"} + response = client.post("/", json={"id": 1, "name": "John"}) + assert response.status_code == 201 + assert response.json() == {"id": 1, "name": "John"} def test_request_data_4() -> None: with TestClient(app=app_4) as client: - res = client.post("/", data={"id": 1, "name": "John"}) - assert res.status_code == 201 - assert res.json() == {"id": 1, "name": "John"} + response = client.post("/", data={"id": 1, "name": "John"}) + assert response.status_code == 201 + assert response.json() == {"id": 1, "name": "John"} def test_request_data_5() -> None: with TestClient(app=app_5) as client: - res = client.post( + response = client.post( "/", content=b'--d26a9a4ed2f441fba9ab42d04b42099e\r\nContent-Disposition: form-data; name="id"\r\n\r\n1\r\n--d26a9a4ed2f441fba9ab42d04b42099e\r\nContent-Disposition: form-data; name="name"\r\n\r\nJohn\r\n--d26a9a4ed2f441fba9ab42d04b42099e--\r\n', headers={ @@ -52,43 +52,45 @@ def test_request_data_5() -> None: "Content-Type": "multipart/form-data; boundary=d26a9a4ed2f441fba9ab42d04b42099e", }, ) - assert res.json() == {"id": 1, "name": "John"} - assert res.status_code == 201 + assert response.json() == {"id": 1, "name": "John"} + assert response.status_code == 201 def test_request_data_6() -> None: with TestClient(app=app_6) as client: - res = client.post("/", files={"upload": ("hello", b"world")}) - assert res.status_code == 201 - assert res.text == "hello, world" + response = client.post("/", files={"upload": ("hello", b"world")}) + assert response.status_code == 201 + assert response.text == "hello, world" def test_request_data_7() -> None: with TestClient(app=app_7) as client: - res = client.post("/", files={"upload": ("hello", b"world")}) - assert res.status_code == 201 - assert res.text == "hello, world" + response = client.post("/", files={"upload": ("hello", b"world")}) + assert response.status_code == 201 + assert response.text == "hello, world" def test_request_data_8() -> None: with TestClient(app=app_8) as client: - res = client.post("/", files={"cv": ("cv.odf", b"very impressive"), "diploma": ("diploma.pdf", b"the best")}) - assert res.status_code == 201 - assert res.json() == {"cv": "very impressive", "diploma": "the best"} + response = client.post( + "/", files={"cv": ("cv.odf", b"very impressive"), "diploma": ("diploma.pdf", b"the best")} + ) + assert response.status_code == 201 + assert response.json() == {"cv": "very impressive", "diploma": "the best"} def test_request_data_9() -> None: with TestClient(app=app_9) as client: - res = client.post("/", files={"hello": b"there", "i'm": "steve"}) - assert res.status_code == 201 - assert res.json() == {"hello": "there", "i'm": "steve"} + response = client.post("/", files={"hello": b"there", "i'm": "steve"}) + assert response.status_code == 201 + assert response.json() == {"hello": "there", "i'm": "steve"} def test_request_data_10() -> None: with TestClient(app=app_10) as client: - res = client.post("/", files={"foo": ("foo.txt", b"hello"), "bar": ("bar.txt", b"world")}) - assert res.status_code == 201 - assert res.json() == {"foo.txt": "hello", "bar.txt": "world"} + response = client.post("/", files={"foo": ("foo.txt", b"hello"), "bar": ("bar.txt", b"world")}) + assert response.status_code == 201 + assert response.json() == {"foo.txt": "hello", "bar.txt": "world"} def test_msgpack_app() -> None: diff --git a/tests/unit/test_app.py b/tests/unit/test_app.py index e3f4125338..a8dcfdb5b8 100644 --- a/tests/unit/test_app.py +++ b/tests/unit/test_app.py @@ -8,6 +8,7 @@ from unittest.mock import MagicMock, Mock, PropertyMock import pytest +from pydantic import VERSION from pytest import MonkeyPatch from litestar import Litestar, MediaType, Request, Response, get, post @@ -139,7 +140,7 @@ def test_app_params_defined_on_app_config_object() -> None: litestar_signature = inspect.signature(Litestar) app_config_fields = {f.name for f in fields(AppConfig)} for name in litestar_signature.parameters: - if name in {"on_app_init", "initial_state"}: + if name in {"on_app_init", "initial_state", "_preferred_validation_backend"}: continue assert name in app_config_fields # ensure there are not fields defined on AppConfig that aren't in the Litestar signature @@ -248,7 +249,7 @@ def my_route_handler(param: int, data: Person) -> None: response = client.post("/123", json={"first_name": "moishe"}) extra = response.json().get("extra") assert extra is not None - assert len(extra) == 3 + assert 3 if len(extra) == VERSION.startswith("1") else 4 def test_using_custom_http_exception_handler() -> None: diff --git a/tests/unit/test_contrib/test_pydantic.py b/tests/unit/test_contrib/test_pydantic.py index c1b0a91718..e2436bca48 100644 --- a/tests/unit/test_contrib/test_pydantic.py +++ b/tests/unit/test_contrib/test_pydantic.py @@ -2,6 +2,8 @@ from typing import TYPE_CHECKING +import pydantic +import pytest from pydantic import BaseModel, Field from typing_extensions import Annotated @@ -45,6 +47,7 @@ class NotModel: assert PydanticDTO.detect_nested_field(FieldDefinition.from_annotation(NotModel)) is False +@pytest.mark.skipif(pydantic.VERSION.startswith("2"), reason="Beanie does not support pydantic 2 yet") def test_generate_field_definitions_from_beanie_models(create_module: Callable[[str], ModuleType]) -> None: module = create_module( """ diff --git a/tests/unit/test_kwargs/test_header_params.py b/tests/unit/test_kwargs/test_header_params.py index fb9578da04..d756748aef 100644 --- a/tests/unit/test_kwargs/test_header_params.py +++ b/tests/unit/test_kwargs/test_header_params.py @@ -52,6 +52,6 @@ async def my_method( assert user_id assert token == test_token - with create_test_client(my_method) as client: + with create_test_client(my_method, debug=True) as client: response = client.get(f"/users/{uuid4()}/", headers={"X-API-KEY": test_token}) assert response.status_code == HTTP_200_OK diff --git a/tests/unit/test_kwargs/test_multipart_data.py b/tests/unit/test_kwargs/test_multipart_data.py index f95b007acf..9d8d46de1c 100644 --- a/tests/unit/test_kwargs/test_multipart_data.py +++ b/tests/unit/test_kwargs/test_multipart_data.py @@ -6,6 +6,7 @@ from typing import Any, DefaultDict, Dict, List, Optional, Type import pytest +from msgspec import convert from pydantic import BaseConfig, BaseModel from litestar import Request, post @@ -14,6 +15,7 @@ from litestar.params import Body from litestar.status_codes import HTTP_201_CREATED, HTTP_400_BAD_REQUEST from litestar.testing import create_test_client +from dataclasses import dataclass from tests import Person, PersonFactory from . import Form @@ -101,19 +103,17 @@ def test_method(data: t_type = body) -> None: # type: ignore def test_request_body_multi_part_mixed_field_content_types() -> None: person = PersonFactory.build() - class MultiPartFormWithMixedFields(BaseModel): - class Config(BaseConfig): - arbitrary_types_allowed = True - + @dataclass + class MultiPartFormWithMixedFields: image: UploadFile - tags: List[str] + tags: List[int] profile: Person @post(path="/form") async def test_method(data: MultiPartFormWithMixedFields = Body(media_type=RequestEncodingType.MULTI_PART)) -> None: file_data = await data.image.read() assert file_data == b"data" - assert data.tags == ["1", "2", "3"] + assert data.tags == [1, 2, 3] assert data.profile == person with create_test_client(test_method) as client: diff --git a/tests/unit/test_kwargs/test_path_params.py b/tests/unit/test_kwargs/test_path_params.py index 5e3b15c5d3..f14fc0404e 100644 --- a/tests/unit/test_kwargs/test_path_params.py +++ b/tests/unit/test_kwargs/test_path_params.py @@ -5,7 +5,6 @@ from uuid import UUID, uuid1, uuid4 import pytest -from pydantic import UUID4 from litestar import Litestar, MediaType, get, post from litestar.exceptions import ImproperlyConfiguredException @@ -60,7 +59,7 @@ "user_id": "abc", "order_id": str(uuid1()), }, - True, + False, ), ], ) @@ -69,7 +68,7 @@ def test_path_params(params_dict: dict, should_raise: bool) -> None: @get(path=test_path) def test_method( - order_id: UUID4, + order_id: UUID, version: float = Parameter(gt=0.1, le=4.0), service_id: int = Parameter(gt=0, le=100), user_id: str = Parameter(min_length=1, max_length=10), diff --git a/tests/unit/test_kwargs/test_reserved_kwargs_injection.py b/tests/unit/test_kwargs/test_reserved_kwargs_injection.py index 96e33edf4d..c139f4e713 100644 --- a/tests/unit/test_kwargs/test_reserved_kwargs_injection.py +++ b/tests/unit/test_kwargs/test_reserved_kwargs_injection.py @@ -1,5 +1,6 @@ from typing import Any, List, Optional, Type, cast +import pydantic import pytest from polyfactory.factories.pydantic_factory import ModelFactory from pydantic import BaseModel, Field @@ -66,7 +67,9 @@ def route_handler(state: state_typing) -> str: # type: ignore class QueryParams(BaseModel): first: str - second: List[str] = Field(min_items=3) + second: List[str] = ( + Field(min_items=3) if pydantic.VERSION.startswith("1") else Field(min_length=1) # pyright: ignore + ) third: Optional[int] diff --git a/tests/unit/test_openapi/test_constrained_fields.py b/tests/unit/test_openapi/test_constrained_fields.py index 8b9e600a1a..7ba8332f16 100644 --- a/tests/unit/test_openapi/test_constrained_fields.py +++ b/tests/unit/test_openapi/test_constrained_fields.py @@ -1,6 +1,7 @@ from datetime import date from typing import Any, Pattern, Union +import pydantic import pytest from pydantic import conlist, conset @@ -13,8 +14,8 @@ from litestar.openapi.spec.enums import OpenAPIFormat, OpenAPIType from litestar.params import KwargDefinition from litestar.typing import FieldDefinition - -from .utils import ( +from litestar.utils import is_class_and_subclass +from tests.unit.test_openapi.utils import ( constrained_collection, constrained_dates, constrained_numbers, @@ -22,8 +23,9 @@ ) +@pytest.mark.skipif(pydantic.VERSION.startswith("2"), reason="pydantic 1 specific logic") @pytest.mark.parametrize("annotation", constrained_collection) -def test_create_collection_constrained_field_schema(annotation: Any) -> None: +def test_create_collection_constrained_field_schema_pydantic_v1(annotation: Any) -> None: schema = SchemaCreator().for_collection_constrained_field(FieldDefinition.from_annotation(annotation)) assert schema.type == OpenAPIType.ARRAY assert schema.items.type == OpenAPIType.INTEGER # type: ignore[union-attr] @@ -31,9 +33,24 @@ def test_create_collection_constrained_field_schema(annotation: Any) -> None: assert schema.max_items == annotation.max_items +@pytest.mark.skipif(pydantic.VERSION.startswith("1"), reason="pydantic 2 specific logic") +@pytest.mark.parametrize("annotation", constrained_collection) +def test_create_collection_constrained_field_schema_pydantic_v2(annotation: Any) -> None: + field_definition = FieldDefinition.from_annotation(annotation) + schema = SchemaCreator().for_collection_constrained_field(field_definition) + assert schema.type == OpenAPIType.ARRAY + assert schema.items.type == OpenAPIType.INTEGER # type: ignore[union-attr] + assert any(getattr(m, "min_length", None) == schema.min_items for m in field_definition.metadata if m) + assert any(getattr(m, "max_length", None) == schema.max_items for m in field_definition.metadata if m) + + def test_create_collection_constrained_field_schema_sub_fields() -> None: for pydantic_fn in (conlist, conset): - field_definition = FieldDefinition.from_annotation(pydantic_fn(Union[str, int], min_items=1, max_items=10)) # type: ignore + if pydantic.VERSION.startswith("1"): + annotation = pydantic_fn(Union[str, int], min_items=1, max_items=10) # type: ignore + else: + annotation = pydantic_fn(Union[str, int], min_length=1, max_length=10) # type: ignore + field_definition = FieldDefinition.from_annotation(annotation) schema = SchemaCreator().for_collection_constrained_field(field_definition) assert schema.type == OpenAPIType.ARRAY expected = { @@ -49,16 +66,18 @@ def test_create_collection_constrained_field_schema_sub_fields() -> None: assert schema.to_schema() == expected +@pytest.mark.skipif(pydantic.version.VERSION.startswith("2"), reason="pydantic 1 specific logic") @pytest.mark.parametrize("annotation", constrained_string) -def test_create_string_constrained_field_schema(annotation: Any) -> None: +def test_create_string_constrained_field_schema_pydantic_v1(annotation: Any) -> None: field_definition = FieldDefinition.from_annotation(annotation) assert isinstance(field_definition.kwarg_definition, KwargDefinition) schema = create_string_constrained_field_schema(field_definition.annotation, field_definition.kwarg_definition) assert schema.type == OpenAPIType.STRING + assert schema.min_length == annotation.min_length assert schema.max_length == annotation.max_length - if pattern := getattr(annotation, "regex", getattr(annotation, "pattern", None)): + if pattern := getattr(annotation, "regex", None): assert schema.pattern == pattern.pattern if isinstance(pattern, Pattern) else pattern if annotation.to_lower: assert schema.description @@ -66,13 +85,31 @@ def test_create_string_constrained_field_schema(annotation: Any) -> None: assert schema.description +@pytest.mark.skipif(pydantic.version.VERSION.startswith("1"), reason="pydantic 2 specific logic") +@pytest.mark.parametrize("annotation", constrained_string) +def test_create_string_constrained_field_schema_pydantic_v2(annotation: Any) -> None: + field_definition = FieldDefinition.from_annotation(annotation) + + assert isinstance(field_definition.kwarg_definition, KwargDefinition) + schema = create_string_constrained_field_schema(field_definition.annotation, field_definition.kwarg_definition) + assert schema.type == OpenAPIType.STRING + + assert any(getattr(m, "min_length", None) == schema.min_length for m in field_definition.metadata if m) + assert any(getattr(m, "max_length", None) == schema.max_length for m in field_definition.metadata if m) + if pattern := getattr(annotation, "regex", getattr(annotation, "pattern", None)): + assert schema.pattern == pattern.pattern if isinstance(pattern, Pattern) else pattern + if any(getattr(m, "to_lower", getattr(m, "to_upper", None)) for m in field_definition.metadata if m): + assert schema.description + + +@pytest.mark.skipif(pydantic.version.VERSION.startswith("2"), reason="pydantic 1 specific logic") @pytest.mark.parametrize("annotation", constrained_numbers) -def test_create_numerical_constrained_field_schema(annotation: Any) -> None: +def test_create_numerical_constrained_field_schema_pydantic_v1(annotation: Any) -> None: field_definition = FieldDefinition.from_annotation(annotation) assert isinstance(field_definition.kwarg_definition, KwargDefinition) schema = create_numerical_constrained_field_schema(field_definition.annotation, field_definition.kwarg_definition) - assert schema.type == OpenAPIType.INTEGER if issubclass(annotation, int) else OpenAPIType.NUMBER + assert schema.type == OpenAPIType.INTEGER if is_class_and_subclass(annotation, int) else OpenAPIType.NUMBER assert schema.exclusive_minimum == annotation.gt assert schema.minimum == annotation.ge assert schema.exclusive_maximum == annotation.lt @@ -80,8 +117,24 @@ def test_create_numerical_constrained_field_schema(annotation: Any) -> None: assert schema.multiple_of == annotation.multiple_of +@pytest.mark.skipif(pydantic.version.VERSION.startswith("1"), reason="pydantic 2 specific logic") +@pytest.mark.parametrize("annotation", constrained_numbers) +def test_create_numerical_constrained_field_schema_pydantic_v2(annotation: Any) -> None: + field_definition = FieldDefinition.from_annotation(annotation) + + assert isinstance(field_definition.kwarg_definition, KwargDefinition) + schema = create_numerical_constrained_field_schema(field_definition.annotation, field_definition.kwarg_definition) + assert schema.type == OpenAPIType.INTEGER if is_class_and_subclass(annotation, int) else OpenAPIType.NUMBER + assert any(getattr(m, "gt", None) == schema.exclusive_minimum for m in field_definition.metadata if m) + assert any(getattr(m, "ge", None) == schema.minimum for m in field_definition.metadata if m) + assert any(getattr(m, "lt", None) == schema.exclusive_maximum for m in field_definition.metadata if m) + assert any(getattr(m, "le", None) == schema.maximum for m in field_definition.metadata if m) + assert any(getattr(m, "multiple_of", None) == schema.multiple_of for m in field_definition.metadata if m) + + +@pytest.mark.skipif(pydantic.version.VERSION.startswith("2"), reason="pydantic 1 specific logic") @pytest.mark.parametrize("annotation", constrained_dates) -def test_create_date_constrained_field_schema(annotation: Any) -> None: +def test_create_date_constrained_field_schema_pydantic_v1(annotation: Any) -> None: field_definition = FieldDefinition.from_annotation(annotation) assert isinstance(field_definition.kwarg_definition, KwargDefinition) @@ -94,6 +147,37 @@ def test_create_date_constrained_field_schema(annotation: Any) -> None: assert (date.fromtimestamp(schema.maximum) if schema.maximum else None) == annotation.le +@pytest.mark.skipif(pydantic.version.VERSION.startswith("1"), reason="pydantic 2 specific logic") +@pytest.mark.parametrize("annotation", constrained_dates) +def test_create_date_constrained_field_schema_pydantic_v2(annotation: Any) -> None: + field_definition = FieldDefinition.from_annotation(annotation) + + assert isinstance(field_definition.kwarg_definition, KwargDefinition) + schema = create_date_constrained_field_schema(field_definition.annotation, field_definition.kwarg_definition) + assert schema.type == OpenAPIType.STRING + assert schema.format == OpenAPIFormat.DATE + assert any( + getattr(m, "gt", None) == (date.fromtimestamp(schema.exclusive_minimum) if schema.exclusive_minimum else None) + for m in field_definition.metadata + if m + ) + assert any( + getattr(m, "ge", None) == (date.fromtimestamp(schema.minimum) if schema.minimum else None) + for m in field_definition.metadata + if m + ) + assert any( + getattr(m, "lt", None) == (date.fromtimestamp(schema.exclusive_maximum) if schema.exclusive_maximum else None) + for m in field_definition.metadata + if m + ) + assert any( + getattr(m, "le", None) == (date.fromtimestamp(schema.maximum) if schema.maximum else None) + for m in field_definition.metadata + if m + ) + + @pytest.mark.parametrize( "annotation", [*constrained_numbers, *constrained_collection, *constrained_string, *constrained_dates] ) diff --git a/tests/unit/test_openapi/test_schema.py b/tests/unit/test_openapi/test_schema.py index b693bd74e5..3ebd72ab7b 100644 --- a/tests/unit/test_openapi/test_schema.py +++ b/tests/unit/test_openapi/test_schema.py @@ -261,7 +261,8 @@ class Model(BaseModel): value: str = Field(title="title", description="description", example="example", max_length=16) schemas: Dict[str, Schema] = {} - SchemaCreator(schemas=schemas).for_field_definition(FieldDefinition.from_kwarg(name="Model", annotation=Model)) + field_definition = FieldDefinition.from_kwarg(name="Model", annotation=Model) + SchemaCreator(schemas=schemas).for_field_definition(field_definition) schema = schemas["Model"] assert schema.properties["value"].description == "description" # type: ignore diff --git a/tests/unit/test_openapi/utils.py b/tests/unit/test_openapi/utils.py index d427c298ea..36d8da2031 100644 --- a/tests/unit/test_openapi/utils.py +++ b/tests/unit/test_openapi/utils.py @@ -2,6 +2,7 @@ from decimal import Decimal from enum import Enum +import pydantic from pydantic import ( conbytes, condate, @@ -41,28 +42,38 @@ class Gender(str, Enum): condecimal(ge=Decimal("10"), le=Decimal("100"), multiple_of=Decimal("2")), ] constrained_string = [ - constr(regex="^[a-zA-Z]$"), - constr(to_upper=True, min_length=1, regex="^[a-zA-Z]$"), - constr(to_lower=True, min_length=1, regex="^[a-zA-Z]$"), - constr(to_lower=True, min_length=10, regex="^[a-zA-Z]$"), - constr(to_lower=True, min_length=10, max_length=100, regex="^[a-zA-Z]$"), - constr(min_length=1), + constr(regex="^[a-zA-Z]$") if pydantic.VERSION.startswith("1") else constr(pattern="^[a-zA-Z]$"), # type: ignore[call-arg] + constr(to_upper=True, min_length=1, regex="^[a-zA-Z]$") # type: ignore[call-arg] + if pydantic.VERSION.startswith("1") + else constr(to_upper=True, min_length=1, pattern="^[a-zA-Z]$"), + constr(to_lower=True, min_length=1, regex="^[a-zA-Z]$") # type: ignore[call-arg] + if pydantic.VERSION.startswith("1") + else constr(to_lower=True, min_length=1, pattern="^[a-zA-Z]$"), + constr(to_lower=True, min_length=10, regex="^[a-zA-Z]$") # type: ignore[call-arg] + if pydantic.VERSION.startswith("1") + else constr(to_lower=True, min_length=10, pattern="^[a-zA-Z]$"), + constr(to_lower=True, min_length=10, max_length=100, regex="^[a-zA-Z]$") # type: ignore[call-arg] + if pydantic.VERSION.startswith("1") + else constr(to_lower=True, min_length=10, max_length=100, pattern="^[a-zA-Z]$"), constr(min_length=1), constr(min_length=10), constr(min_length=10, max_length=100), - conbytes(to_lower=True, min_length=1), - conbytes(to_lower=True, min_length=10), - conbytes(to_upper=True, min_length=10), - conbytes(to_lower=True, min_length=10, max_length=100), + conbytes(min_length=1), + conbytes(min_length=10), + conbytes(min_length=10, max_length=100), conbytes(min_length=1), conbytes(min_length=10), conbytes(min_length=10, max_length=100), ] constrained_collection = [ - conlist(int, min_items=1), - conlist(int, min_items=1, max_items=10), - conset(int, min_items=1), - conset(int, min_items=1, max_items=10), + conlist(int, min_items=1) if pydantic.VERSION.startswith("1") else conlist(int, min_length=1), # type: ignore[call-arg] + conlist(int, min_items=1, max_items=10) # type: ignore[call-arg] + if pydantic.VERSION.startswith("1") + else conlist(int, min_length=1, max_length=10), + conset(int, min_items=1) if pydantic.VERSION.startswith("1") else conset(int, min_length=1), # type: ignore[call-arg] + conset(int, min_items=1, max_items=10) # type: ignore[call-arg] + if pydantic.VERSION.startswith("1") + else conset(int, min_length=1, max_length=10), ] constrained_dates = [ condate(gt=date.today() - timedelta(days=10), lt=date.today() + timedelta(days=100)), diff --git a/tests/unit/test_partial.py b/tests/unit/test_partial.py index 8b16315df2..fae43000de 100644 --- a/tests/unit/test_partial.py +++ b/tests/unit/test_partial.py @@ -1,10 +1,11 @@ import dataclasses -from typing import Any, ClassVar, Optional, TypedDict, get_type_hints +from typing import Any, ClassVar, Optional, TypedDict, cast, get_type_hints import pydantic import pytest from msgspec.inspect import type_info from pydantic import BaseModel +from pydantic.fields import FieldInfo from typing_extensions import get_args from litestar.exceptions import ImproperlyConfiguredException @@ -26,7 +27,8 @@ from typing import _GenericAlias as GenericAlias # type: ignore -def test_partial_pydantic_model() -> None: +@pytest.mark.skipif(pydantic.VERSION.startswith("2"), reason="pydantic v1 only logic") +def test_partial_pydantic_v1_model() -> None: class PersonWithClassVar(Person): cls_var: ClassVar[int] @@ -46,6 +48,26 @@ class PersonWithClassVar(Person): assert NoneType not in get_args(annotation) +@pytest.mark.skipif(pydantic.VERSION.startswith("1"), reason="pydantic v2 only logic") +def test_partial_pydantic_v2_model() -> None: + class PersonWithClassVar(Person): + cls_var: ClassVar[int] + + partial = Partial[PersonWithClassVar] + partial_model_fields = cast("dict[str,FieldInfo]", partial.model_fields) # type: ignore + assert len(partial_model_fields) == len(Person.model_fields) # + + for field in partial_model_fields.values(): + assert not field.is_required() + + for annotation in get_type_hints(partial).values(): + if not is_class_var(annotation): + assert isinstance(annotation, GenericAlias) + assert NoneType in get_args(annotation) + else: + assert NoneType not in get_args(annotation) + + def test_partial_vanilla_dataclass() -> None: @dataclasses.dataclass class VanillaDataClassPersonWithClassVar(VanillaDataClassPerson): @@ -126,7 +148,8 @@ class PersonWithClassVar(AttrsPerson): assert NoneType not in get_args(annotation) -def test_partial_pydantic_model_with_superclass() -> None: +@pytest.mark.skipif(pydantic.VERSION.startswith("2"), reason="pydantic v1 only logic") +def test_partial_pydantic_v1_model_with_superclass() -> None: """Test that Partial returns the correct annotations for nested models.""" class Parent(BaseModel): @@ -147,6 +170,27 @@ class Child(Parent): } +@pytest.mark.skipif(pydantic.VERSION.startswith("1"), reason="pydantic v2 only logic") +def test_partial_pydantic_v2_model_with_superclass() -> None: + """Test that Partial returns the correct annotations for nested models.""" + + class Parent(BaseModel): + parent_attribute: int + + class Child(Parent): + child_attribute: int + + partial_child = Partial[Child] + partial_model_fields = cast("dict[str,FieldInfo]", partial_child.model_fields) # type: ignore + for field in partial_model_fields.values(): + assert not field.is_required() + + assert get_type_hints(partial_child) == { + "parent_attribute": Optional[int], + "child_attribute": Optional[int], + } + + def test_partial_dataclass_with_superclass() -> None: """Test that Partial returns the correct annotations for nested models.""" diff --git a/tests/unit/test_serialization.py b/tests/unit/test_serialization.py index 462f158e1e..7f136e78bd 100644 --- a/tests/unit/test_serialization.py +++ b/tests/unit/test_serialization.py @@ -4,25 +4,35 @@ import pydantic import pytest +from _decimal import Decimal +from dateutil.utils import today from pydantic import ( + VERSION, BaseModel, ByteSize, - ConstrainedBytes, - ConstrainedDate, - ConstrainedDecimal, - ConstrainedFloat, - ConstrainedFrozenSet, - ConstrainedInt, - ConstrainedList, - ConstrainedSet, - ConstrainedStr, EmailStr, NameEmail, - PaymentCardNumber, SecretBytes, SecretStr, + conbytes, + condate, + condecimal, + confloat, + confrozenset, + conint, + conlist, + conset, + constr, ) -from pydantic.color import Color + +if VERSION.startswith("1"): + from pydantic import ( + PaymentCardNumber, + ) + from pydantic.color import Color +else: + from pydantic_extra_types.color import Color # type: ignore + from pydantic_extra_types.payment import PaymentCardNumber # type: ignore from litestar.enums import MediaType from litestar.exceptions import SerializationException @@ -34,9 +44,6 @@ encode_json, encode_msgpack, ) -from tests import PersonFactory - -person = PersonFactory.build() class CustomStr(str): @@ -68,99 +75,136 @@ class CustomTuple(tuple): class Model(BaseModel): - path: Path = Path("example") - - email_str: pydantic.EmailStr = EmailStr("info@example.org") - name_email: NameEmail = NameEmail("info", "info@example.org") - color: Color = Color("rgb(255, 255, 255)") - bytesize: ByteSize = ByteSize.validate("100b") - secret_str: SecretStr = SecretStr("hello") - secret_bytes: SecretBytes = SecretBytes(b"hello") - payment_card_number: PaymentCardNumber = PaymentCardNumber("4000000000000002") - - constr: pydantic.constr() = ConstrainedStr("hello") # type: ignore[valid-type] - conbytes: pydantic.conbytes() = ConstrainedBytes(b"hello") # type: ignore[valid-type] - condate: pydantic.condate() = ConstrainedDate.today() # type: ignore[valid-type] - condecimal: pydantic.condecimal() = ConstrainedDecimal(3.14) # type: ignore[valid-type] - confloat: pydantic.confloat() = ConstrainedFloat(1.0) # type: ignore[valid-type] - conset: pydantic.conset(int) = ConstrainedSet([1]) # type: ignore[valid-type] - confrozenset: pydantic.confrozenset(int) = ConstrainedFrozenSet([1]) # type: ignore[valid-type] - conint: pydantic.conint() = ConstrainedInt(1) # type: ignore[valid-type] - conlist: pydantic.conlist(int, min_items=1) = ConstrainedList([1]) # type: ignore[valid-type] - - custom_str: CustomStr = CustomStr() - custom_int: CustomInt = CustomInt() - custom_float: CustomFloat = CustomFloat() - custom_list: CustomList = CustomList() - custom_set: CustomSet = CustomSet() - custom_frozenset: CustomFrozenSet = CustomFrozenSet() - custom_tuple: CustomTuple = CustomTuple() - - -model = Model() + if VERSION.startswith("1"): + + class Config: + arbitrary_types_allowed = True + + custom_str: CustomStr = CustomStr() + custom_int: CustomInt = CustomInt() + custom_float: CustomFloat = CustomFloat() + custom_list: CustomList = CustomList() + custom_set: CustomSet = CustomSet() + custom_frozenset: CustomFrozenSet = CustomFrozenSet() + custom_tuple: CustomTuple = CustomTuple() + + conset: conset(int, min_items=1) # type: ignore + confrozenset: confrozenset(int, min_items=1) # type: ignore + conlist: conlist(int, min_items=1) if pydantic.VERSION.startswith("2") else conlist(int, min_items=1) # type: ignore + + else: + model_config = {"arbitrary_types_allowed": True} + conset: conset(int, min_length=1) # type: ignore + confrozenset: confrozenset(int, min_length=1) # type: ignore + conlist: conlist(int, min_length=1) if pydantic.VERSION.startswith("2") else conlist(int, min_items=1) # type: ignore + + path: Path + + email_str: EmailStr + name_email: NameEmail + color: Color + bytesize: ByteSize + secret_str: SecretStr + secret_bytes: SecretBytes + payment_card_number: PaymentCardNumber + + constr: constr(min_length=1) # type: ignore + conbytes: conbytes(min_length=1) # type: ignore + condate: condate(ge=today().date()) # type: ignore + condecimal: condecimal(ge=Decimal("1")) # type: ignore + confloat: confloat(ge=0) # type: ignore + + conint: conint(ge=0) # type: ignore + + +@pytest.fixture() +def model() -> Model: + return Model( + path=Path("example"), + email_str="info@example.org", + name_email=NameEmail("info", "info@example.org"), + color=Color("rgb(255, 255, 255)"), + bytesize=ByteSize(100), + secret_str=SecretStr("hello"), + secret_bytes=SecretBytes(b"hello"), + payment_card_number=PaymentCardNumber("4000000000000002"), + constr="hello", + conbytes=b"hello", + condate=today(), + condecimal=Decimal("3.14"), + confloat=1.0, + conset={1}, + confrozenset=frozenset([1]), + conint=1, + conlist=[1], + ) @pytest.mark.parametrize( - "value, expected", + "attribute_name, expected", [ - (model.email_str, "info@example.org"), - (model.name_email, "info "), - (model.color, "white"), - (model.bytesize, 100), - (model.secret_str, "**********"), - (model.secret_bytes, "**********"), - (model.payment_card_number, "4000000000000002"), - (model.constr, "hello"), - (model.conbytes, "hello"), - (model.condate, model.condate.isoformat()), - (model.condecimal, 3.14), - (model.conset, {1}), - (model.confrozenset, frozenset([1])), - (model.conint, 1), - (model, model.dict()), - (model.custom_str, ""), - (model.custom_int, 0), - (model.custom_float, 0.0), - (model.custom_set, set()), - (model.custom_frozenset, frozenset()), + ("path", "example"), + ("email_str", "info@example.org"), + ("name_email", "info "), + ("color", "white"), + ("bytesize", 100), + ("secret_str", "**********"), + ("secret_bytes", "**********"), + ("payment_card_number", "4000000000000002"), + ("constr", "hello"), + ("conbytes", b"hello"), + ("condate", today().date().isoformat()), + ("condecimal", 3.14), + ("conset", {1}), + ("confrozenset", frozenset([1])), + ("conint", 1), ], ) -def test_default_serializer(value: Any, expected: Any) -> None: - assert default_serializer(value) == expected +def test_default_serializer(model: BaseModel, attribute_name: str, expected: Any) -> None: + assert default_serializer(getattr(model, attribute_name)) == expected + +def test_serialization_of_model_instance(model: BaseModel) -> None: + assert default_serializer(model) == model.model_dump(mode="json") if hasattr(model, "model_dump") else model.dict() -def test_pydantic_json_compatibility() -> None: - assert json.loads(model.json()) == json.loads(encode_json(model)) + +def test_pydantic_json_compatibility(model: BaseModel) -> None: + raw = model.model_dump_json() if hasattr(model, "model_dump_json") else model.json() + encoded_json = encode_json(model) + assert json.loads(raw) == json.loads(encoded_json) @pytest.mark.parametrize("encoder", [encode_json, encode_msgpack]) -def test_encoder_raises_serialization_exception(encoder: Any) -> None: +def test_encoder_raises_serialization_exception(model: BaseModel, encoder: Any) -> None: with pytest.raises(SerializationException): encoder(object()) @pytest.mark.parametrize("decoder", [decode_json, decode_msgpack]) -def test_decode_json_raises_serialization_exception(decoder: Any) -> None: +def test_decode_json_raises_serialization_exception(model: BaseModel, decoder: Any) -> None: with pytest.raises(SerializationException): decoder(b"str") -def test_decode_json_typed() -> None: - model_json = model.json() - assert decode_json(model_json, Model).json() == model_json +def test_decode_json_typed(model: BaseModel) -> None: + dumped_model = model.model_dump_json() if hasattr(model, "model_dump_json") else model.json() + decoded_model = decode_json(dumped_model, Model) + assert ( + decoded_model.model_dump_json() if hasattr(decoded_model, "model_dump_json") else decoded_model.json() + ) == dumped_model -def test_decode_msgpack_typed() -> None: +def test_decode_msgpack_typed(model: BaseModel) -> None: model_json = model.json() assert decode_msgpack(encode_msgpack(model), Model).json() == model_json -def test_decode_media_type() -> None: +def test_decode_media_type(model: BaseModel) -> None: model_json = model.json() assert decode_media_type(model_json.encode("utf-8"), MediaType.JSON, Model).json() == model_json assert decode_media_type(encode_msgpack(model), MediaType.MESSAGEPACK, Model).json() == model_json -def test_decode_media_type_unsupported_media_type() -> None: +def test_decode_media_type_unsupported_media_type(model: BaseModel) -> None: with pytest.raises(SerializationException): decode_media_type(b"", MediaType.HTML, Model) diff --git a/tests/unit/test_signature/test_validation.py b/tests/unit/test_signature/test_validation.py index c93f4dac0f..8f079c7c14 100644 --- a/tests/unit/test_signature/test_validation.py +++ b/tests/unit/test_signature/test_validation.py @@ -3,7 +3,7 @@ import pytest from attr import define -from pydantic import BaseModel +from pydantic import BaseModel, VERSION from typing_extensions import TypedDict from litestar import get, post @@ -50,8 +50,8 @@ def test(dep: int, param: int, optional_dep: Optional[int] = Dependency()) -> No ... with create_test_client( - route_handlers=[test], - dependencies=dependencies, + route_handlers=[test], + dependencies=dependencies, ) as client: response = client.get("/?param=13") @@ -141,11 +141,11 @@ class Parent(BaseModel): @post("/") def test( - data: Parent, - int_param: int, - length_param: str = Parameter(min_length=2), - int_header: int = Parameter(header="X-SOME-INT"), - int_cookie: int = Parameter(cookie="int-cookie"), + data: Parent, + int_param: int, + length_param: str = Parameter(min_length=2), + int_header: int = Parameter(header="X-SOME-INT"), + int_cookie: int = Parameter(cookie="int-cookie"), ) -> None: ... @@ -163,15 +163,21 @@ def test( data = response.json() assert data - assert data["extra"] == [ - {"key": "child.val", "message": "value is not a valid integer", "source": "body"}, - {"key": "child.other_val", "message": "value is not a valid integer", "source": "body"}, - {"key": "other_child.val.1", "message": "value is not a valid integer", "source": "body"}, - {"key": "int_param", "message": "value is not a valid integer", "source": "query"}, - {"key": "length_param", "message": "ensure this value has at least 2 characters", "source": "query"}, - {"key": "int_header", "message": "value is not a valid integer", "source": "header"}, - {"key": "int_cookie", "message": "value is not a valid integer", "source": "cookie"}, - ] + if VERSION.startswith("1"): + assert data["extra"] == [ + {"key": "child.val", "message": "value is not a valid integer", "source": "body"}, + {"key": "child.other_val", "message": "value is not a valid integer", "source": "body"}, + {"key": "other_child.val.1", "message": "value is not a valid integer", "source": "body"}, + {"key": "int_param", "message": "value is not a valid integer", "source": "query"}, + {"key": "length_param", "message": "ensure this value has at least 2 characters", "source": "query"}, + {"key": "int_header", "message": "value is not a valid integer", "source": "header"}, + {"key": "int_cookie", "message": "value is not a valid integer", "source": "cookie"}, + ] + else: + assert data["extra"] == [ + {'message': 'Input should be a valid integer, unable to parse string as an integer', 'key': 'child.val'}, + {'message': 'Input should be a valid integer, unable to parse string as an integer', 'key': 'child.other_val'}, + {'message': 'Input should be a valid integer, unable to parse string as an integer', 'key': 'other_child.val.1'}] def test_invalid_input_attrs() -> None: @@ -191,10 +197,10 @@ class Parent: @post("/") def test( - data: Parent, - int_param: int, - int_header: int = Parameter(header="X-SOME-INT"), - int_cookie: int = Parameter(cookie="int-cookie"), + data: Parent, + int_param: int, + int_header: int = Parameter(header="X-SOME-INT"), + int_cookie: int = Parameter(cookie="int-cookie"), ) -> None: ... @@ -239,11 +245,11 @@ class Parent: @post("/") def test( - data: Parent, - int_param: int, - length_param: str = Parameter(min_length=2), - int_header: int = Parameter(header="X-SOME-INT"), - int_cookie: int = Parameter(cookie="int-cookie"), + data: Parent, + int_param: int, + length_param: str = Parameter(min_length=2), + int_header: int = Parameter(header="X-SOME-INT"), + int_cookie: int = Parameter(cookie="int-cookie"), ) -> None: ... @@ -286,11 +292,11 @@ class Parent(TypedDict): @post("/") def test( - data: Parent, - int_param: int, - length_param: str = Parameter(min_length=2), - int_header: int = Parameter(header="X-SOME-INT"), - int_cookie: int = Parameter(cookie="int-cookie"), + data: Parent, + int_param: int, + length_param: str = Parameter(min_length=2), + int_header: int = Parameter(header="X-SOME-INT"), + int_cookie: int = Parameter(cookie="int-cookie"), ) -> None: ... From 65e1177a6960a5ab2f7b946123774dc02ec88863 Mon Sep 17 00:00:00 2001 From: Na'aman Hirschfeld Date: Sat, 15 Jul 2023 09:56:05 +0200 Subject: [PATCH 3/5] feat(internal): handle multiple messages for msgspec --- .pre-commit-config.yaml | 12 +- litestar/_kwargs/dependencies.py | 2 +- litestar/_kwargs/kwargs_model.py | 4 +- litestar/_openapi/parameters.py | 2 +- litestar/_openapi/path_item.py | 2 +- litestar/_signature/model.py | 66 ++++++----- litestar/serialization/__init__.py | 21 ++++ litestar/serialization/_msgspec_utils.py | 11 ++ .../serialization/_pydantic_serialization.py | 105 ++++++++++++++++++ .../msgspec_hooks.py} | 101 ++--------------- poetry.lock | 14 +-- pyproject.toml | 6 +- tests/unit/test_openapi/test_parameters.py | 2 +- tests/unit/test_openapi/test_request_body.py | 2 +- tests/unit/test_signature/test_parsing.py | 10 +- tests/unit/test_signature/test_validation.py | 94 ++++++++-------- 16 files changed, 262 insertions(+), 192 deletions(-) create mode 100644 litestar/serialization/__init__.py create mode 100644 litestar/serialization/_msgspec_utils.py create mode 100644 litestar/serialization/_pydantic_serialization.py rename litestar/{serialization.py => serialization/msgspec_hooks.py} (67%) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index ae35620461..57df4e9e93 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -70,7 +70,7 @@ repos: exclude: "test_apps|tools|docs|tests/examples|tests/docker_service_fixtures" additional_dependencies: [ - msgspec>=0.17.0, + polyfactory, aiosqlite, annotated_types, async_timeout, @@ -91,14 +91,13 @@ repos: jsbeautifier, mako, mongomock_motor, + msgspec, multidict, opentelemetry-instrumentation-asgi, opentelemetry-sdk, oracledb, piccolo, picologging, - polyfactory, - prometheus_client, psycopg, pydantic>=2, pydantic_extra_types, @@ -122,6 +121,7 @@ repos: types-pyyaml, types-redis, uvicorn, + prometheus_client, ] - repo: https://github.com/RobertCraigie/pyright-python rev: v1.1.317 @@ -130,7 +130,7 @@ repos: exclude: "test_apps|tools|docs|_openapi|tests/examples|tests/docker_service_fixtures" additional_dependencies: [ - msgspec>=0.17.0, + polyfactory, aiosqlite, annotated_types, async_timeout, @@ -151,14 +151,13 @@ repos: jsbeautifier, mako, mongomock_motor, + msgspec, multidict, opentelemetry-instrumentation-asgi, opentelemetry-sdk, oracledb, piccolo, picologging, - polyfactory, - prometheus_client, psycopg, pydantic>=2, pydantic_extra_types, @@ -182,6 +181,7 @@ repos: types-pyyaml, types-redis, uvicorn, + prometheus_client, ] - repo: local hooks: diff --git a/litestar/_kwargs/dependencies.py b/litestar/_kwargs/dependencies.py index 17b9d8bb06..30843c0d87 100644 --- a/litestar/_kwargs/dependencies.py +++ b/litestar/_kwargs/dependencies.py @@ -61,7 +61,7 @@ async def resolve_dependency( signature_model = get_signature_model(dependency.provide) dependency_kwargs = ( signature_model.parse_values_from_connection_kwargs(connection=connection, **kwargs) - if signature_model.fields + if signature_model._fields else {} ) value = await dependency.provide(**dependency_kwargs) diff --git a/litestar/_kwargs/kwargs_model.py b/litestar/_kwargs/kwargs_model.py index cb39eb1488..6b59555ec8 100644 --- a/litestar/_kwargs/kwargs_model.py +++ b/litestar/_kwargs/kwargs_model.py @@ -277,7 +277,7 @@ def create_for_signature_model( An instance of KwargsModel """ - field_definitions = signature_model.fields + field_definitions = signature_model._fields cls._validate_raw_kwargs( path_parameters=path_parameters, @@ -405,7 +405,7 @@ def _create_dependency_graph(cls, key: str, dependencies: dict[str, Provide]) -> list. """ provide = dependencies[key] - sub_dependency_keys = [k for k in get_signature_model(provide).fields if k in dependencies] + sub_dependency_keys = [k for k in get_signature_model(provide)._fields if k in dependencies] return Dependency( key=key, provide=provide, diff --git a/litestar/_openapi/parameters.py b/litestar/_openapi/parameters.py index d1d789e750..c500d415b5 100644 --- a/litestar/_openapi/parameters.py +++ b/litestar/_openapi/parameters.py @@ -136,7 +136,7 @@ def get_recursive_handler_parameters( ) ] - dependency_fields = dependency_providers[field_name].signature_model.fields + dependency_fields = dependency_providers[field_name].signature_model._fields return create_parameter_for_handler( route_handler=route_handler, handler_fields=dependency_fields, diff --git a/litestar/_openapi/path_item.py b/litestar/_openapi/path_item.py index 7d6e6ab2df..5f96c2670f 100644 --- a/litestar/_openapi/path_item.py +++ b/litestar/_openapi/path_item.py @@ -93,7 +93,7 @@ def create_path_item( route_handler, _ = handler_tuple if route_handler.include_in_schema: - handler_fields = route_handler.signature_model.fields if route_handler.signature_model else {} + handler_fields = route_handler.signature_model._fields if route_handler.signature_model else {} parameters = ( create_parameter_for_handler( route_handler=route_handler, diff --git a/litestar/_signature/model.py b/litestar/_signature/model.py index 406c95eb76..c96bfbe5ef 100644 --- a/litestar/_signature/model.py +++ b/litestar/_signature/model.py @@ -9,10 +9,11 @@ from typing_extensions import Annotated from litestar._signature.utils import create_type_overrides, validate_signature_dependencies -from litestar.enums import ScopeType +from litestar.enums import ParamType, ScopeType from litestar.exceptions import InternalServerException, ValidationException from litestar.params import DependencyKwarg, KwargDefinition, ParameterKwarg -from litestar.serialization import ExtendedMsgSpecValidationError, dec_hook +from litestar.serialization import dec_hook +from litestar.serialization._msgspec_utils import ExtendedMsgSpecValidationError from litestar.typing import FieldDefinition # noqa from litestar.utils import make_non_optional_union from litestar.utils.dataclass import simple_asdict @@ -38,7 +39,7 @@ class ErrorMessage(TypedDict): # in this case, we don't show a key at all as it will be empty key: NotRequired[str] message: str - source: NotRequired[Literal["cookie", "body", "header", "query"]] + source: NotRequired[Literal["body"] | ParamType] MSGSPEC_CONSTRAINT_FIELDS = ( @@ -59,9 +60,9 @@ class SignatureModel(Struct): """Model that represents a function signature that uses a msgspec specific type or types.""" # NOTE: we have to use Set and Dict here because python 3.8 goes haywire if we use 'set' and 'dict' - dependency_name_set: ClassVar[Set[str]] - fields: ClassVar[Dict[str, FieldDefinition]] - return_annotation: ClassVar[Any] + _dependency_name_set: ClassVar[Set[str]] + _fields: ClassVar[Dict[str, FieldDefinition]] + _return_annotation: ClassVar[Any] @classmethod def _create_exception(cls, connection: ASGIConnection, messages: list[ErrorMessage]) -> Exception: @@ -79,7 +80,7 @@ def _create_exception(cls, connection: ASGIConnection, messages: list[ErrorMessa if client_errors := [ err_message for err_message in messages - if ("key" in err_message and err_message["key"] not in cls.dependency_name_set) or "key" not in err_message + if ("key" in err_message and err_message["key"] not in cls._dependency_name_set) or "key" not in err_message ]: return ValidationException(detail=f"Validation failed for {method} {connection.url}", extra=client_errors) return InternalServerException() @@ -103,21 +104,35 @@ def _build_error_message(cls, keys: Sequence[str], exc_msg: str, connection: ASG return message message["key"] = key = ".".join(keys) - - if key in connection.query_params: - message["source"] = cast("Literal['cookie', 'body', 'header', 'query']", "query") - - elif key in cls.fields and isinstance(cls.fields[key].kwarg_definition, ParameterKwarg): - if cast(ParameterKwarg, cls.fields[key].kwarg_definition).cookie: - source = "cookie" - elif cast(ParameterKwarg, cls.fields[key].kwarg_definition).header: - source = "header" + if keys[0].startswith("data"): + message["key"] = message["key"].replace("data.", "") + message["source"] = "body" + elif key in connection.query_params: + message["source"] = ParamType.QUERY + + elif key in cls._fields and isinstance(cls._fields[key].kwarg_definition, ParameterKwarg): + if cast(ParameterKwarg, cls._fields[key].kwarg_definition).cookie: + message["source"] = ParamType.COOKIE + elif cast(ParameterKwarg, cls._fields[key].kwarg_definition).header: + message["source"] = ParamType.HEADER else: - source = "query" - message["source"] = cast("Literal['cookie', 'body', 'header', 'query']", source) + message["source"] = ParamType.QUERY return message + @classmethod + def _collect_errors(cls, **kwargs: Any) -> list[tuple[str, Exception]]: + exceptions: list[tuple[str, Exception]] = [] + for field_name in cls._fields: + try: + raw_value = kwargs[field_name] + annotation = cls.__annotations__[field_name] + convert(raw_value, type=annotation, strict=False, dec_hook=dec_hook, str_keys=True) + except Exception as e: # noqa: BLE001 + exceptions.append((field_name, e)) + + return exceptions + @classmethod def parse_values_from_connection_kwargs(cls, connection: ASGIConnection, **kwargs: Any) -> dict[str, Any]: """Extract values from the connection instance and return a dict of parsed values. @@ -143,10 +158,11 @@ def parse_values_from_connection_kwargs(cls, connection: ASGIConnection, **kwarg messages.append(message) raise cls._create_exception(messages=messages, connection=connection) from e except ValidationError as e: - match = ERR_RE.search(str(e)) - keys = [str(match.group(1)) if match else "n/a"] - message = cls._build_error_message(keys=keys, exc_msg=str(e), connection=connection) - messages.append(message) + for field_name, exc in cls._collect_errors(**kwargs): # type: ignore[assignment] + match = ERR_RE.search(str(exc)) + keys = [field_name, str(match.group(1))] if match else [field_name] + message = cls._build_error_message(keys=keys, exc_msg=str(e), connection=connection) + messages.append(message) raise cls._create_exception(messages=messages, connection=connection) from e def to_dict(self) -> dict[str, Any]: @@ -221,9 +237,9 @@ def create( bases=(cls,), module=getattr(fn, "__module__", None), namespace={ - "return_annotation": parsed_signature.return_type.annotation, - "dependency_name_set": dependency_names, - "fields": parsed_signature.parameters, + "_return_annotation": parsed_signature.return_type.annotation, + "_dependency_name_set": dependency_names, + "_fields": parsed_signature.parameters, }, kw_only=True, ) diff --git a/litestar/serialization/__init__.py b/litestar/serialization/__init__.py new file mode 100644 index 0000000000..ff67490d9b --- /dev/null +++ b/litestar/serialization/__init__.py @@ -0,0 +1,21 @@ +from .msgspec_hooks import ( + dec_hook, + decode_json, + decode_media_type, + decode_msgpack, + default_serializer, + encode_json, + encode_msgpack, + get_serializer, +) + +__all__ = ( + "dec_hook", + "decode_json", + "decode_media_type", + "decode_msgpack", + "default_serializer", + "encode_json", + "encode_msgpack", + "get_serializer", +) diff --git a/litestar/serialization/_msgspec_utils.py b/litestar/serialization/_msgspec_utils.py new file mode 100644 index 0000000000..ac174cc043 --- /dev/null +++ b/litestar/serialization/_msgspec_utils.py @@ -0,0 +1,11 @@ +from __future__ import annotations + +from typing import Any + +from msgspec import ValidationError + + +class ExtendedMsgSpecValidationError(ValidationError): + def __init__(self, errors: list[dict[str, Any]]) -> None: + self.errors = errors + super().__init__(errors) diff --git a/litestar/serialization/_pydantic_serialization.py b/litestar/serialization/_pydantic_serialization.py new file mode 100644 index 0000000000..593714837a --- /dev/null +++ b/litestar/serialization/_pydantic_serialization.py @@ -0,0 +1,105 @@ +from __future__ import annotations + +from typing import Any, Callable, TypeVar, cast +from uuid import UUID + +from msgspec import ValidationError + +from litestar.serialization._msgspec_utils import ExtendedMsgSpecValidationError +from litestar.utils import is_class_and_subclass, is_pydantic_model_class + +__all__ = ( + "create_pydantic_decoders", + "create_pydantic_encoders", +) + +T = TypeVar("T") + + +def create_pydantic_decoders() -> list[tuple[Callable[[Any], bool], Callable[[Any, Any], Any]]]: + decoders: list[tuple[Callable[[Any], bool], Callable[[Any, Any], Any]]] = [] + try: + import pydantic + + def _dec_pydantic(type_: type[pydantic.BaseModel], value: Any) -> pydantic.BaseModel: + try: + return ( + type_.model_validate(value, strict=False) + if hasattr(type_, "model_validate") + else type_.parse_obj(value) + ) + except pydantic.ValidationError as e: + raise ExtendedMsgSpecValidationError(errors=cast(list[dict[str, Any]], e.errors())) from e + + decoders.append((is_pydantic_model_class, _dec_pydantic)) + + def _dec_pydantic_uuid( + type_: type[pydantic.UUID1] | type[pydantic.UUID3] | type[pydantic.UUID4] | type[pydantic.UUID5], val: Any + ) -> type[pydantic.UUID1] | type[pydantic.UUID3] | type[pydantic.UUID4] | type[pydantic.UUID5]: + if isinstance(val, str): + val = type_(val) + + elif isinstance(val, (bytes, bytearray)): + try: + val = type_(val.decode()) + except ValueError: + # 16 bytes in big-endian order as the bytes argument fail + # the above check + val = type_(bytes=val) + elif isinstance(val, UUID): + val = type_(str(val)) + + if not isinstance(val, type_): + raise ValidationError(f"Invalid UUID: {val!r}") + + if type_._required_version != val.version: # type: ignore + raise ValidationError(f"Invalid UUID version: {val!r}") + + return cast( + "type[pydantic.UUID1] | type[pydantic.UUID3] | type[pydantic.UUID4] | type[pydantic.UUID5]", val + ) + + def _is_pydantic_uuid(value: Any) -> bool: + return is_class_and_subclass(value, (pydantic.UUID1, pydantic.UUID3, pydantic.UUID4, pydantic.UUID5)) + + decoders.append((_is_pydantic_uuid, _dec_pydantic_uuid)) + return decoders + except ImportError: + return decoders + + +def create_pydantic_encoders() -> dict[Any, Callable[[Any], Any]]: + try: + import pydantic + + encoders: dict[Any, Callable[[Any], Any]] = { + pydantic.EmailStr: str, + pydantic.NameEmail: str, + pydantic.ByteSize: lambda val: val.real, + } + + if pydantic.VERSION.startswith("1"): # pragma: no cover + encoders.update( + { + pydantic.BaseModel: lambda model: model.dict(), + pydantic.SecretField: str, + pydantic.StrictBool: int, + pydantic.color.Color: str, # pyright: ignore + pydantic.ConstrainedBytes: lambda val: val.decode("utf-8"), + pydantic.ConstrainedDate: lambda val: val.isoformat(), + } + ) + else: + from pydantic_extra_types import color + + encoders.update( + { + pydantic.BaseModel: lambda model: model.model_dump(mode="json"), + color.Color: str, + pydantic.types.SecretStr: lambda val: "**********" if val else "", + pydantic.types.SecretBytes: lambda val: "**********" if val else "", + } + ) + return encoders + except ImportError: + return {} diff --git a/litestar/serialization.py b/litestar/serialization/msgspec_hooks.py similarity index 67% rename from litestar/serialization.py rename to litestar/serialization/msgspec_hooks.py index a9efae8d8a..c9c7f4bbe8 100644 --- a/litestar/serialization.py +++ b/litestar/serialization/msgspec_hooks.py @@ -14,20 +14,17 @@ ) from pathlib import Path, PurePath from re import Pattern -from typing import TYPE_CHECKING, Any, Callable, Mapping, TypeVar, cast, overload +from typing import TYPE_CHECKING, Any, Callable, Mapping, TypeVar, overload from uuid import UUID import msgspec -from msgspec import ValidationError from litestar.enums import MediaType from litestar.exceptions import SerializationException +from litestar.serialization._pydantic_serialization import create_pydantic_decoders, create_pydantic_encoders from litestar.types import Empty, Serializer -from litestar.utils import is_class_and_subclass, is_pydantic_model_class if TYPE_CHECKING: - from typing_extensions import TypeAlias - from litestar.types import TypeEncodersMap __all__ = ( @@ -39,98 +36,14 @@ "encode_json", "encode_msgpack", "get_serializer", - "ExtendedMsgSpecValidationError", ) T = TypeVar("T") -PYDANTIC_DECODERS: list[tuple[Callable[[Any], bool], Callable[[Any, Any], Any]]] = [] - - -class ExtendedMsgSpecValidationError(ValidationError): - def __init__(self, errors: list[dict[str, Any]]) -> None: - self.errors = errors - super().__init__(errors) - +EXTRA_DECODERS: list[tuple[Callable[[Any], bool], Callable[[Any, Any], Any]]] = [] -try: - import pydantic - - PYDANTIC_ENCODERS: dict[Any, Callable[[Any], Any]] = { - pydantic.EmailStr: str, - pydantic.NameEmail: str, - pydantic.ByteSize: lambda val: val.real, - } - - def _dec_pydantic(type_: type[pydantic.BaseModel], value: Any) -> pydantic.BaseModel: - try: - return ( - type_.model_validate(value, strict=False) - if hasattr(type_, "model_validate") - else type_.parse_obj(value) - ) - except pydantic.ValidationError as e: - raise ExtendedMsgSpecValidationError(errors=cast(list[dict[str, Any]], e.errors())) from e - - PYDANTIC_DECODERS.append((is_pydantic_model_class, _dec_pydantic)) - - if pydantic.VERSION.startswith("1"): # pragma: no cover - PYDANTIC_ENCODERS.update( - { - pydantic.BaseModel: lambda model: model.dict(), - pydantic.SecretField: str, - pydantic.StrictBool: int, - pydantic.color.Color: str, # pyright: ignore - pydantic.ConstrainedBytes: lambda val: val.decode("utf-8"), - pydantic.ConstrainedDate: lambda val: val.isoformat(), - } - ) - - PydanticUUIDType: TypeAlias = ( - "type[pydantic.UUID1] | type[pydantic.UUID3] | type[pydantic.UUID4] | type[pydantic.UUID5]" - ) - - def _dec_pydantic_uuid(type_: PydanticUUIDType, val: Any) -> PydanticUUIDType: - if isinstance(val, str): - val = type_(val) - - elif isinstance(val, (bytes, bytearray)): - try: - val = type_(val.decode()) - except ValueError: - # 16 bytes in big-endian order as the bytes argument fail - # the above check - val = type_(bytes=val) - elif isinstance(val, UUID): - val = type_(str(val)) - - if not isinstance(val, type_): - raise ValidationError(f"Invalid UUID: {val!r}") - - if type_._required_version != val.version: # type: ignore - raise ValidationError(f"Invalid UUID version: {val!r}") - - return cast("PydanticUUIDType", val) - - def _is_pydantic_uuid(value: Any) -> bool: - return is_class_and_subclass(value, (pydantic.UUID1, pydantic.UUID3, pydantic.UUID4, pydantic.UUID5)) - - PYDANTIC_DECODERS.append((_is_pydantic_uuid, _dec_pydantic_uuid)) - else: - from pydantic_extra_types import color - - PYDANTIC_ENCODERS.update( - { - pydantic.BaseModel: lambda model: model.model_dump(mode="json"), - color.Color: str, - pydantic.types.SecretStr: lambda val: "**********" if val else "", - pydantic.types.SecretBytes: lambda val: "**********" if val else "", - } - ) - - -except ImportError: - PYDANTIC_ENCODERS = {} +if pydantic_decoders := create_pydantic_decoders(): + EXTRA_DECODERS.extend(pydantic_decoders) DEFAULT_TYPE_ENCODERS: TypeEncodersMap = { Path: str, @@ -157,7 +70,7 @@ def _is_pydantic_uuid(value: Any) -> bool: set: set, frozenset: frozenset, bytes: bytes, - **PYDANTIC_ENCODERS, + **create_pydantic_encoders(), } @@ -201,7 +114,7 @@ def dec_hook(type_: Any, value: Any) -> Any: # pragma: no cover if isinstance(value, type_): return value - for predicate, decoder in PYDANTIC_DECODERS: + for predicate, decoder in EXTRA_DECODERS: if predicate(type_): return decoder(type_, value) diff --git a/poetry.lock b/poetry.lock index 06e8d113b6..a1417e33b1 100644 --- a/poetry.lock +++ b/poetry.lock @@ -4075,13 +4075,13 @@ socks = ["PySocks (>=1.5.6,!=1.5.7,<2.0)"] [[package]] name = "uvicorn" -version = "0.22.0" +version = "0.23.0" description = "The lightning-fast ASGI server." optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" files = [ - {file = "uvicorn-0.22.0-py3-none-any.whl", hash = "sha256:e9434d3bbf05f310e762147f769c9f21235ee118ba2d2bf1155a7196448bd996"}, - {file = "uvicorn-0.22.0.tar.gz", hash = "sha256:79277ae03db57ce7d9aa0567830bbb51d7a612f54d6e1e3e92da3ef24c2c8ed8"}, + {file = "uvicorn-0.23.0-py3-none-any.whl", hash = "sha256:479599b2c0bb1b9b394c6d43901a1eb0c1ec72c7d237b5bafea23c5b2d4cdf10"}, + {file = "uvicorn-0.23.0.tar.gz", hash = "sha256:d38ab90c0e2c6fe3a054cddeb962cfd5d0e0e6608eaaff4a01d5c36a67f3168c"}, ] [package.dependencies] @@ -4144,13 +4144,13 @@ test = ["Cython (>=0.29.32,<0.30.0)", "aiohttp", "flake8 (>=3.9.2,<3.10.0)", "my [[package]] name = "virtualenv" -version = "20.23.1" +version = "20.24.0" description = "Virtual Python Environment builder" optional = false python-versions = ">=3.7" files = [ - {file = "virtualenv-20.23.1-py3-none-any.whl", hash = "sha256:34da10f14fea9be20e0fd7f04aba9732f84e593dac291b757ce42e3368a39419"}, - {file = "virtualenv-20.23.1.tar.gz", hash = "sha256:8ff19a38c1021c742148edc4f81cb43d7f8c6816d2ede2ab72af5b84c749ade1"}, + {file = "virtualenv-20.24.0-py3-none-any.whl", hash = "sha256:18d1b37fc75cc2670625702d76849a91ebd383768b4e91382a8d51be3246049e"}, + {file = "virtualenv-20.24.0.tar.gz", hash = "sha256:e2a7cef9da880d693b933db7654367754f14e20650dc60e8ee7385571f8593a3"}, ] [package.dependencies] diff --git a/pyproject.toml b/pyproject.toml index 4d0b6f6db7..96b264c879 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -83,7 +83,6 @@ click = { version = "*", optional = true } cryptography = { version = "*", optional = true } fast-query-parsers = "*" httpx = ">=0.22" -pydantic-extra-types = { version = "*", optional = true } importlib-metadata = { version = "*", python = "<3.10" } importlib-resources = { version = ">=5.12.0", python = "<3.9" } jinja2 = { version = ">=3.1.2", optional = true } @@ -96,6 +95,7 @@ picologging = { version = "*", optional = true } polyfactory = ">=2.3.2" prometheus-client = { version = "*", optional = true } pydantic = "*" +pydantic-extra-types = { version = "*", optional = true } python-dateutil = "*" python-jose = { version = "*", optional = true } pytimeparse = "*" @@ -140,6 +140,7 @@ pre-commit = "*" prometheus-client = "*" psycopg = "*" pydantic = ">=2" +pydantic-extra-types = "*" pytest = "*" pytest-asyncio = "*" pytest-cov = "*" @@ -160,7 +161,6 @@ structlog = "*" tortoise-orm = "*" trio = "*" uvicorn = "*" -pydantic-extra-types = "*" [tool.poetry.group.docs] optional = true @@ -206,12 +206,12 @@ jinja = ["jinja2"] jwt = ["python-jose", "cryptography"] opentelemetry = ["opentelemetry-instrumentation-asgi"] picologging = ["picologging"] +prometheus = ["prometheus-client"] redis = ["redis"] sqlalchemy = ["sqlalchemy", "alembic"] standard = ["click", "jinja2", "jsbeautifier", "rich", "uvicorn", "rich-click"] structlog = ["structlog"] tortoise-orm = ["tortoise-orm"] -prometheus = ["prometheus-client"] full = [ "alembic", diff --git a/tests/unit/test_openapi/test_parameters.py b/tests/unit/test_openapi/test_parameters.py index 3d8a6664af..7b7d134d6c 100644 --- a/tests/unit/test_openapi/test_parameters.py +++ b/tests/unit/test_openapi/test_parameters.py @@ -32,7 +32,7 @@ def _create_parameters(app: Litestar, path: str) -> List["OpenAPIParameter"]: fn=handler, dependency_name_set=set(), parsed_signature=route_handler.parsed_fn_signature, - ).fields + )._fields return create_parameter_for_handler( route_handler, handler_fields, route.path_parameters, SchemaCreator(generate_examples=True) diff --git a/tests/unit/test_openapi/test_request_body.py b/tests/unit/test_openapi/test_request_body.py index 39cb27c0b5..2c1d7e5605 100644 --- a/tests/unit/test_openapi/test_request_body.py +++ b/tests/unit/test_openapi/test_request_body.py @@ -24,7 +24,7 @@ class Config(BaseConfig): def test_create_request_body(person_controller: Type[Controller]) -> None: for route in Litestar(route_handlers=[person_controller]).routes: for route_handler, _ in route.route_handler_map.values(): # type: ignore - handler_fields = route_handler.signature_model.fields # type: ignore + handler_fields = route_handler.signature_model._fields # type: ignore if "data" in handler_fields: request_body = create_request_body( route_handler=route_handler, diff --git a/tests/unit/test_signature/test_parsing.py b/tests/unit/test_signature/test_parsing.py index bc29064be2..8d8790ac68 100644 --- a/tests/unit/test_signature/test_parsing.py +++ b/tests/unit/test_signature/test_parsing.py @@ -42,7 +42,7 @@ def my_fn(a: int, b: str, c: Optional[bytes], d: bytes = b"123", e: Optional[dic dependency_name_set=set(), parsed_signature=ParsedSignature.from_fn(my_fn.fn.value, {}), ) - fields = model.fields + fields = model._fields assert fields["a"].annotation is int assert not fields["a"].is_optional assert fields["b"].annotation is str @@ -120,8 +120,8 @@ def fn(a: Iterable[int], b: Optional[Iterable[int]]) -> None: parsed_signature=ParsedSignature.from_fn(fn, {}), ) - assert model.fields["a"].is_non_string_iterable - assert model.fields["b"].is_non_string_iterable + assert model._fields["a"].is_non_string_iterable + assert model._fields["b"].is_non_string_iterable def test_field_definition_is_non_string_sequence() -> None: @@ -134,8 +134,8 @@ def fn(a: Sequence[int], b: OptionalSequence[int]) -> None: parsed_signature=ParsedSignature.from_fn(fn, signature_namespace={}), ) - assert model.fields["a"].is_non_string_sequence - assert model.fields["b"].is_non_string_sequence + assert model._fields["a"].is_non_string_sequence + assert model._fields["b"].is_non_string_sequence @pytest.mark.parametrize("query,expected", [("1", True), ("true", True), ("0", False), ("false", False)]) diff --git a/tests/unit/test_signature/test_validation.py b/tests/unit/test_signature/test_validation.py index 8f079c7c14..5c61da390a 100644 --- a/tests/unit/test_signature/test_validation.py +++ b/tests/unit/test_signature/test_validation.py @@ -3,7 +3,7 @@ import pytest from attr import define -from pydantic import BaseModel, VERSION +from pydantic import VERSION, BaseModel from typing_extensions import TypedDict from litestar import get, post @@ -50,8 +50,8 @@ def test(dep: int, param: int, optional_dep: Optional[int] = Dependency()) -> No ... with create_test_client( - route_handlers=[test], - dependencies=dependencies, + route_handlers=[test], + dependencies=dependencies, ) as client: response = client.get("/?param=13") @@ -141,11 +141,11 @@ class Parent(BaseModel): @post("/") def test( - data: Parent, - int_param: int, - length_param: str = Parameter(min_length=2), - int_header: int = Parameter(header="X-SOME-INT"), - int_cookie: int = Parameter(cookie="int-cookie"), + data: Parent, + int_param: int, + length_param: str = Parameter(min_length=2), + int_header: int = Parameter(header="X-SOME-INT"), + int_cookie: int = Parameter(cookie="int-cookie"), ) -> None: ... @@ -175,9 +175,19 @@ def test( ] else: assert data["extra"] == [ - {'message': 'Input should be a valid integer, unable to parse string as an integer', 'key': 'child.val'}, - {'message': 'Input should be a valid integer, unable to parse string as an integer', 'key': 'child.other_val'}, - {'message': 'Input should be a valid integer, unable to parse string as an integer', 'key': 'other_child.val.1'}] + { + "message": "Input should be a valid integer, unable to parse string as an integer", + "key": "child.val", + }, + { + "message": "Input should be a valid integer, unable to parse string as an integer", + "key": "child.other_val", + }, + { + "message": "Input should be a valid integer, unable to parse string as an integer", + "key": "other_child.val.1", + }, + ] def test_invalid_input_attrs() -> None: @@ -197,10 +207,10 @@ class Parent: @post("/") def test( - data: Parent, - int_param: int, - int_header: int = Parameter(header="X-SOME-INT"), - int_cookie: int = Parameter(cookie="int-cookie"), + data: Parent, + int_param: int, + int_header: int = Parameter(header="X-SOME-INT"), + int_cookie: int = Parameter(cookie="int-cookie"), ) -> None: ... @@ -219,12 +229,10 @@ def test( assert data assert data["extra"] == [ - {"key": "child.val", "message": "invalid literal for int() with base 10: 'a'", "source": "body"}, - {"key": "child.other_val", "message": "invalid literal for int() with base 10: 'b'", "source": "body"}, - {"key": "other_child.val.1", "message": "invalid literal for int() with base 10: 'c'", "source": "body"}, - {"key": "int_param", "message": "invalid literal for int() with base 10: 'param'", "source": "query"}, - {"key": "int_header", "message": "invalid literal for int() with base 10: 'header'", "source": "header"}, - {"key": "int_cookie", "message": "invalid literal for int() with base 10: 'cookie'", "source": "cookie"}, + {"message": "Expected `int`, got `str`", "key": "child.val", "source": "body"}, + {"message": "Expected `int`, got `str`", "key": "int_param", "source": "query"}, + {"message": "Expected `int`, got `str`", "key": "int_header", "source": "header"}, + {"message": "Expected `int`, got `str`", "key": "int_cookie", "source": "cookie"}, ] @@ -245,11 +253,11 @@ class Parent: @post("/") def test( - data: Parent, - int_param: int, - length_param: str = Parameter(min_length=2), - int_header: int = Parameter(header="X-SOME-INT"), - int_cookie: int = Parameter(cookie="int-cookie"), + data: Parent, + int_param: int, + length_param: str = Parameter(min_length=2), + int_header: int = Parameter(header="X-SOME-INT"), + int_cookie: int = Parameter(cookie="int-cookie"), ) -> None: ... @@ -268,13 +276,11 @@ def test( assert data assert data["extra"] == [ - {"key": "child.val", "message": "value is not a valid integer", "source": "body"}, - {"key": "child.other_val", "message": "value is not a valid integer", "source": "body"}, - {"key": "other_child.val.1", "message": "value is not a valid integer", "source": "body"}, - {"key": "int_param", "message": "value is not a valid integer", "source": "query"}, - {"key": "length_param", "message": "ensure this value has at least 2 characters", "source": "query"}, - {"key": "int_header", "message": "value is not a valid integer", "source": "header"}, - {"key": "int_cookie", "message": "value is not a valid integer", "source": "cookie"}, + {"message": "Expected `int`, got `str`", "key": "child.val", "source": "body"}, + {"message": "Expected `int`, got `str`", "key": "int_param", "source": "query"}, + {"message": "Expected `int`, got `str`", "key": "length_param", "source": "query"}, + {"message": "Expected `int`, got `str`", "key": "int_header", "source": "header"}, + {"message": "Expected `int`, got `str`", "key": "int_cookie", "source": "cookie"}, ] @@ -292,11 +298,11 @@ class Parent(TypedDict): @post("/") def test( - data: Parent, - int_param: int, - length_param: str = Parameter(min_length=2), - int_header: int = Parameter(header="X-SOME-INT"), - int_cookie: int = Parameter(cookie="int-cookie"), + data: Parent, + int_param: int, + length_param: str = Parameter(min_length=2), + int_header: int = Parameter(header="X-SOME-INT"), + int_cookie: int = Parameter(cookie="int-cookie"), ) -> None: ... @@ -315,11 +321,9 @@ def test( assert data assert data["extra"] == [ - {"key": "child.val", "message": "value is not a valid integer", "source": "body"}, - {"key": "child.other_val", "message": "value is not a valid integer", "source": "body"}, - {"key": "other_child.val.1", "message": "value is not a valid integer", "source": "body"}, - {"key": "int_param", "message": "value is not a valid integer", "source": "query"}, - {"key": "length_param", "message": "ensure this value has at least 2 characters", "source": "query"}, - {"key": "int_header", "message": "value is not a valid integer", "source": "header"}, - {"key": "int_cookie", "message": "value is not a valid integer", "source": "cookie"}, + {"message": "Expected `int`, got `str`", "key": "child.val", "source": "body"}, + {"message": "Expected `int`, got `str`", "key": "int_param", "source": "query"}, + {"message": "Expected `int`, got `str`", "key": "length_param", "source": "query"}, + {"message": "Expected `int`, got `str`", "key": "int_header", "source": "header"}, + {"message": "Expected `int`, got `str`", "key": "int_cookie", "source": "cookie"}, ] From d21ee1fece5e97f3bba7eb951f71defa6fcdd5f6 Mon Sep 17 00:00:00 2001 From: Na'aman Hirschfeld Date: Sat, 15 Jul 2023 10:23:20 +0200 Subject: [PATCH 4/5] chore(infra): fix ci issues --- .github/workflows/test.yaml | 5 +- docs/conf.py | 2 +- .../sqlalchemy/sqlalchemy_async_repository.py | 3 +- .../sqlalchemy_repository_extension.py | 3 +- .../sqlalchemy/sqlalchemy_sync_repository.py | 3 +- .../dto/factory/_backends/pydantic/utils.py | 13 +++-- .../serialization/_pydantic_serialization.py | 4 +- poetry.lock | 50 ++----------------- pyproject.toml | 3 +- tests/unit/test_serialization.py | 12 ++++- tests/unit/test_signature/test_validation.py | 10 ++-- 11 files changed, 41 insertions(+), 67 deletions(-) diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 8f59979718..6cc1ad9ad9 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -60,7 +60,10 @@ jobs: run: poetry install --no-interaction - if: ${{ inputs.pydantic-version == '1' }} name: Install pydantic v1 - run: source .venv/bin/activate && pip install "pydantic>=1.10.10" + run: poetry add "pydantic>=1.10.10,<2" && poetry remove pydantic-extra-types + - if: ${{ inputs.pydantic-version == '2' }} + name: Install pydantic v2 + run: poetry add "pydantic>=2" && poetry add pydantic-extra-types - name: Set pythonpath run: echo "PYTHONPATH=$PWD" >> $GITHUB_ENV - name: Test diff --git a/docs/conf.py b/docs/conf.py index ac6aa5e7fe..9327834a1d 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -87,7 +87,6 @@ (PY_METH, "_types.TypeDecorator.process_bind_param"), (PY_METH, "_types.TypeDecorator.process_result_value"), (PY_METH, "type_engine"), - (PY_METH, "litestar.typing.ParsedType.is_subclass_of"), # type vars and aliases / intentionally undocumented (PY_CLASS, "RouteHandlerType"), (PY_OBJ, "litestar.security.base.AuthType"), @@ -114,6 +113,7 @@ (PY_CLASS, "litestar.response.RedirectResponse"), (PY_CLASS, "anyio.abc.BlockingPortal"), (PY_CLASS, "litestar.typing.ParsedType"), + (PY_CLASS, "pydantic.BaseModel"), ] nitpick_ignore_regex = [ diff --git a/docs/examples/contrib/sqlalchemy/sqlalchemy_async_repository.py b/docs/examples/contrib/sqlalchemy/sqlalchemy_async_repository.py index 5674932f92..f00f132abf 100644 --- a/docs/examples/contrib/sqlalchemy/sqlalchemy_async_repository.py +++ b/docs/examples/contrib/sqlalchemy/sqlalchemy_async_repository.py @@ -25,8 +25,7 @@ class BaseModel(_BaseModel): """Extend Pydantic's BaseModel to enable ORM mode""" - class Config: - orm_mode = True + model_config = {"from_attributes": True} # the SQLAlchemy base includes a declarative model for you to use in your models. diff --git a/docs/examples/contrib/sqlalchemy/sqlalchemy_repository_extension.py b/docs/examples/contrib/sqlalchemy/sqlalchemy_repository_extension.py index fe446b94f3..85fe99a052 100644 --- a/docs/examples/contrib/sqlalchemy/sqlalchemy_repository_extension.py +++ b/docs/examples/contrib/sqlalchemy/sqlalchemy_repository_extension.py @@ -25,8 +25,7 @@ class BaseModel(_BaseModel): """Extend Pydantic's BaseModel to enable ORM mode""" - class Config: - orm_mode = True + model_config = {"from_attributes": True} # we are going to add a simple "slug" to our model that is a URL safe surrogate key to diff --git a/docs/examples/contrib/sqlalchemy/sqlalchemy_sync_repository.py b/docs/examples/contrib/sqlalchemy/sqlalchemy_sync_repository.py index 35c4a565d0..2c3bbb5005 100644 --- a/docs/examples/contrib/sqlalchemy/sqlalchemy_sync_repository.py +++ b/docs/examples/contrib/sqlalchemy/sqlalchemy_sync_repository.py @@ -25,8 +25,7 @@ class BaseModel(_BaseModel): """Extend Pydantic's BaseModel to enable ORM mode""" - class Config: - orm_mode = True + model_config = {"from_attributes": True} # the SQLAlchemy base includes a declarative model for you to use in your models. diff --git a/litestar/dto/factory/_backends/pydantic/utils.py b/litestar/dto/factory/_backends/pydantic/utils.py index bd8f413b3c..690e64256e 100644 --- a/litestar/dto/factory/_backends/pydantic/utils.py +++ b/litestar/dto/factory/_backends/pydantic/utils.py @@ -3,7 +3,7 @@ from typing import TYPE_CHECKING, TypeVar, Union from msgspec import UNSET, UnsetType -from pydantic import BaseModel, create_model +from pydantic import VERSION, BaseModel, create_model from pydantic.fields import FieldInfo from litestar.dto.factory._backends.utils import create_transfer_model_type_annotation @@ -21,9 +21,14 @@ class _OrmModeBase(BaseModel): - class Config: - arbitrary_types_allowed = True - orm_mode = True + if VERSION.startswith("1"): + + class Config: + arbitrary_types_allowed = True + orm_mode = True + + else: + model_config = {"arbitrary_types_allowed": True, "from_attributes": True} def _create_field_info(field_definition: TransferDTOFieldDefinition) -> FieldInfo: diff --git a/litestar/serialization/_pydantic_serialization.py b/litestar/serialization/_pydantic_serialization.py index fe742d78ff..4504cd9f2f 100644 --- a/litestar/serialization/_pydantic_serialization.py +++ b/litestar/serialization/_pydantic_serialization.py @@ -81,7 +81,9 @@ def create_pydantic_encoders() -> dict[Any, Callable[[Any], Any]]: if pydantic.VERSION.startswith("1"): # pragma: no cover encoders.update( { - pydantic.BaseModel: lambda model: model.dict(), + pydantic.BaseModel: lambda model: { + k: v if not isinstance(v, bytes) else v.decode() for k, v in model.dict().items() + }, pydantic.SecretField: str, pydantic.StrictBool: int, pydantic.color.Color: str, # pyright: ignore diff --git a/poetry.lock b/poetry.lock index 5648ab4e4f..af8600cadb 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.5.0 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.5.1 and should not be changed by hand. [[package]] name = "accessible-pygments" @@ -3720,46 +3720,6 @@ description = "Database Abstraction Library" optional = false python-versions = ">=3.7" files = [ - {file = "SQLAlchemy-2.0.18-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:7ddd6d35c598af872f9a0a5bce7f7c4a1841684a72dab3302e3df7f17d1b5249"}, - {file = "SQLAlchemy-2.0.18-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:00aa050faf24ce5f2af643e2b86822fa1d7149649995f11bc1e769bbfbf9010b"}, - {file = "SQLAlchemy-2.0.18-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b52c6741073de5a744d27329f9803938dcad5c9fee7e61690c705f72973f4175"}, - {file = "SQLAlchemy-2.0.18-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7db97eabd440327c35b751d5ebf78a107f505586485159bcc87660da8bb1fdca"}, - {file = "SQLAlchemy-2.0.18-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:589aba9a35869695b319ed76c6f673d896cd01a7ff78054be1596df7ad9b096f"}, - {file = "SQLAlchemy-2.0.18-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:9da4ee8f711e077633730955c8f3cd2485c9abf5ea0f80aac23221a3224b9a8c"}, - {file = "SQLAlchemy-2.0.18-cp310-cp310-win32.whl", hash = "sha256:5dd574a37be388512c72fe0d7318cb8e31743a9b2699847a025e0c08c5bf579d"}, - {file = "SQLAlchemy-2.0.18-cp310-cp310-win_amd64.whl", hash = "sha256:6852cd34d96835e4c9091c1e6087325efb5b607b75fd9f7075616197d1c4688a"}, - {file = "SQLAlchemy-2.0.18-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:10e001a84f820fea2640e4500e12322b03afc31d8f4f6b813b44813b2a7c7e0d"}, - {file = "SQLAlchemy-2.0.18-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:bffd6cd47c2e68970039c0d3e355c9ed761d3ca727b204e63cd294cad0e3df90"}, - {file = "SQLAlchemy-2.0.18-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8b7b3ebfa9416c8eafaffa65216e229480c495e305a06ba176dcac32710744e6"}, - {file = "SQLAlchemy-2.0.18-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:79228a7b90d95957354f37b9d46f2cc8926262ae17b0d3ed8f36c892f2a37e06"}, - {file = "SQLAlchemy-2.0.18-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:ba633b51835036ff0f402c21f3ff567c565a22ff0a5732b060a68f4660e2a38f"}, - {file = "SQLAlchemy-2.0.18-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:8da677135eff43502b7afab5a1e641edfb2dc734ba7fc146e9b1b86817a728e2"}, - {file = "SQLAlchemy-2.0.18-cp311-cp311-win32.whl", hash = "sha256:82edf3a6090554a83942cec79151d6b5eb96e63d143e80e4cf6671e5d772f6be"}, - {file = "SQLAlchemy-2.0.18-cp311-cp311-win_amd64.whl", hash = "sha256:69ae0e9509c43474e33152abe1385b8954922544616426bf793481e1a37e094f"}, - {file = "SQLAlchemy-2.0.18-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:09397a18733fa2a4c7680b746094f980060666ee549deafdb5e102a99ce4619b"}, - {file = "SQLAlchemy-2.0.18-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:45b07470571bda5ee7f5ec471271bbde97267cc8403fce05e280c36ea73f4754"}, - {file = "SQLAlchemy-2.0.18-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1aac42a21a7fa6c9665392c840b295962992ddf40aecf0a88073bc5c76728117"}, - {file = "SQLAlchemy-2.0.18-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:da46beef0ce882546d92b7b2e8deb9e04dbb8fec72945a8eb28b347ca46bc15a"}, - {file = "SQLAlchemy-2.0.18-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:a6f1d8256d06f58e6ece150fbe05c63c7f9510df99ee8ac37423f5476a2cebb4"}, - {file = "SQLAlchemy-2.0.18-cp37-cp37m-win32.whl", hash = "sha256:67fbb40db3985c0cfb942fe8853ad94a5e9702d2987dec03abadc2f3b6a24afb"}, - {file = "SQLAlchemy-2.0.18-cp37-cp37m-win_amd64.whl", hash = "sha256:afb322ca05e2603deedbcd2e9910f11a3fd2f42bdeafe63018e5641945c7491c"}, - {file = "SQLAlchemy-2.0.18-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:908c850b98cac1e203ababd4ba76868d19ae0d7172cdc75d3f1b7829b16837d2"}, - {file = "SQLAlchemy-2.0.18-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:10514adc41fc8f5922728fbac13d401a1aefcf037f009e64ca3b92464e33bf0e"}, - {file = "SQLAlchemy-2.0.18-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2b791577c546b6bbd7b43953565fcb0a2fec63643ad605353dd48afbc3c48317"}, - {file = "SQLAlchemy-2.0.18-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:420bc6d06d4ae7fb6921524334689eebcbea7bf2005efef070a8562cc9527a37"}, - {file = "SQLAlchemy-2.0.18-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:ebdd2418ab4e2e26d572d9a1c03877f8514a9b7436729525aa571862507b3fea"}, - {file = "SQLAlchemy-2.0.18-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:556dc18e39b6edb76239acfd1c010e37395a54c7fde8c57481c15819a3ffb13e"}, - {file = "SQLAlchemy-2.0.18-cp38-cp38-win32.whl", hash = "sha256:7b8cba5a25e95041e3413d91f9e50616bcfaec95afa038ce7dc02efefe576745"}, - {file = "SQLAlchemy-2.0.18-cp38-cp38-win_amd64.whl", hash = "sha256:0f7fdcce52cd882b559a57b484efc92e108efeeee89fab6b623aba1ac68aad2e"}, - {file = "SQLAlchemy-2.0.18-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:d7a2c1e711ce59ac9d0bba780318bcd102d2958bb423209f24c6354d8c4da930"}, - {file = "SQLAlchemy-2.0.18-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:5c95e3e7cc6285bf7ff263eabb0d3bfe3def9a1ff98124083d45e5ece72f4579"}, - {file = "SQLAlchemy-2.0.18-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fc44e50f9d5e96af1a561faa36863f9191f27364a4df3eb70bca66e9370480b6"}, - {file = "SQLAlchemy-2.0.18-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bfa1a0f83bdf8061db8d17c2029454722043f1e4dd1b3d3d3120d1b54e75825a"}, - {file = "SQLAlchemy-2.0.18-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:194f2d5a7cb3739875c4d25b3fe288ab0b3dc33f7c857ba2845830c8c51170a0"}, - {file = "SQLAlchemy-2.0.18-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:4ebc542d2289c0b016d6945fd07a7e2e23f4abc41e731ac8ad18a9e0c2fd0ec2"}, - {file = "SQLAlchemy-2.0.18-cp39-cp39-win32.whl", hash = "sha256:774bd401e7993452ba0596e741c0c4d6d22f882dd2a798993859181dbffadc62"}, - {file = "SQLAlchemy-2.0.18-cp39-cp39-win_amd64.whl", hash = "sha256:2756485f49e7df5c2208bdc64263d19d23eba70666f14ad12d6d8278a2fff65f"}, - {file = "SQLAlchemy-2.0.18-py3-none-any.whl", hash = "sha256:6c5bae4c288bda92a7550fe8de9e068c0a7cd56b1c5d888aae5b40f0e13b40bd"}, {file = "SQLAlchemy-2.0.18.tar.gz", hash = "sha256:1fb792051db66e09c200e7bc3bda3b1eb18a5b8eb153d2cedb2b14b56a68b8cb"}, ] @@ -3769,7 +3729,7 @@ typing-extensions = ">=4.2.0" [package.extras] aiomysql = ["aiomysql", "greenlet (!=0.4.17)"] -aiosqlite = ["aiosqlite", "greenlet (!=0.4.17)", "typing-extensions (!=3.10.0.1)"] +aiosqlite = ["aiosqlite", "greenlet (!=0.4.17)", "typing_extensions (!=3.10.0.1)"] asyncio = ["greenlet (!=0.4.17)"] asyncmy = ["asyncmy (>=0.2.3,!=0.2.4,!=0.2.6)", "greenlet (!=0.4.17)"] mariadb-connector = ["mariadb (>=1.0.1,!=1.1.2,!=1.1.5)"] @@ -3779,7 +3739,7 @@ mssql-pyodbc = ["pyodbc"] mypy = ["mypy (>=0.910)"] mysql = ["mysqlclient (>=1.4.0)"] mysql-connector = ["mysql-connector-python"] -oracle = ["cx-oracle (>=7)"] +oracle = ["cx_oracle (>=7)"] oracle-oracledb = ["oracledb (>=1.0.1)"] postgresql = ["psycopg2 (>=2.7)"] postgresql-asyncpg = ["asyncpg", "greenlet (!=0.4.17)"] @@ -3789,7 +3749,7 @@ postgresql-psycopg2binary = ["psycopg2-binary"] postgresql-psycopg2cffi = ["psycopg2cffi"] postgresql-psycopgbinary = ["psycopg[binary] (>=3.0.7)"] pymysql = ["pymysql"] -sqlcipher = ["sqlcipher3-binary"] +sqlcipher = ["sqlcipher3_binary"] [[package]] name = "sqlalchemy-spanner" @@ -4435,4 +4395,4 @@ tortoise-orm = ["tortoise-orm"] [metadata] lock-version = "2.0" python-versions = ">=3.8,<4.0" -content-hash = "3a803a93020a303daa341b9f32f91d5f9be5c4353d3916c9b5b5987ce80db47f" +content-hash = "f0e156f0fa0734ce9f6e19bfe50d94c87b98bb7bc2b98cd31803a67ff986bc2c" diff --git a/pyproject.toml b/pyproject.toml index 96b264c879..803298701b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -139,7 +139,7 @@ picologging = "*" pre-commit = "*" prometheus-client = "*" psycopg = "*" -pydantic = ">=2" +pydantic = "*" pydantic-extra-types = "*" pytest = "*" pytest-asyncio = "*" @@ -259,6 +259,7 @@ exclude_lines = [ 'except ImportError:', '\.\.\.', 'raise NotImplementedError', + 'if VERSION.startswith("1"):', ] [tool.pytest.ini_options] diff --git a/tests/unit/test_serialization.py b/tests/unit/test_serialization.py index 7f136e78bd..9908fd07b5 100644 --- a/tests/unit/test_serialization.py +++ b/tests/unit/test_serialization.py @@ -171,7 +171,17 @@ def test_serialization_of_model_instance(model: BaseModel) -> None: def test_pydantic_json_compatibility(model: BaseModel) -> None: raw = model.model_dump_json() if hasattr(model, "model_dump_json") else model.json() encoded_json = encode_json(model) - assert json.loads(raw) == json.loads(encoded_json) + + raw_result = json.loads(raw) + encoded_result = json.loads(encoded_json) + + if VERSION.startswith("1"): + # pydantic v1 dumps decimals into floats as json, we therefore regard this as an error + assert raw_result.get("condecimal") == float(encoded_result.get("condecimal")) + del raw_result["condecimal"] + del encoded_result["condecimal"] + + assert raw_result == encoded_result @pytest.mark.parametrize("encoder", [encode_json, encode_msgpack]) diff --git a/tests/unit/test_signature/test_validation.py b/tests/unit/test_signature/test_validation.py index 5c61da390a..85cd786747 100644 --- a/tests/unit/test_signature/test_validation.py +++ b/tests/unit/test_signature/test_validation.py @@ -165,13 +165,9 @@ def test( assert data if VERSION.startswith("1"): assert data["extra"] == [ - {"key": "child.val", "message": "value is not a valid integer", "source": "body"}, - {"key": "child.other_val", "message": "value is not a valid integer", "source": "body"}, - {"key": "other_child.val.1", "message": "value is not a valid integer", "source": "body"}, - {"key": "int_param", "message": "value is not a valid integer", "source": "query"}, - {"key": "length_param", "message": "ensure this value has at least 2 characters", "source": "query"}, - {"key": "int_header", "message": "value is not a valid integer", "source": "header"}, - {"key": "int_cookie", "message": "value is not a valid integer", "source": "cookie"}, + {"key": "child.val", "message": "value is not a valid integer"}, + {"key": "child.other_val", "message": "value is not a valid integer"}, + {"key": "other_child.val.1", "message": "value is not a valid integer"}, ] else: assert data["extra"] == [ From d6848b9e998edbfa4c45136930db5035a85f1e4a Mon Sep 17 00:00:00 2001 From: Na'aman Hirschfeld Date: Sat, 15 Jul 2023 11:16:33 +0200 Subject: [PATCH 5/5] chore(docs): fix warning --- docs/conf.py | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/conf.py b/docs/conf.py index 9327834a1d..335490ae26 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -87,6 +87,7 @@ (PY_METH, "_types.TypeDecorator.process_bind_param"), (PY_METH, "_types.TypeDecorator.process_result_value"), (PY_METH, "type_engine"), + (PY_METH, "litestar.typing.ParsedType.is_subclass_of"), # type vars and aliases / intentionally undocumented (PY_CLASS, "RouteHandlerType"), (PY_OBJ, "litestar.security.base.AuthType"),