Skip to content

Commit

Permalink
Add support for generic NamedTuple following python/cpython#92027
Browse files Browse the repository at this point in the history
  • Loading branch information
Fatal1ty committed Jan 22, 2023
1 parent 985e91a commit 3467f0b
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 15 deletions.
10 changes: 8 additions & 2 deletions mashumaro/core/meta/types/pack.py
Original file line number Diff line number Diff line change
Expand Up @@ -494,7 +494,13 @@ def pack_tuple(spec: ValueSpec, args: Tuple[Type, ...]) -> Expression:


def pack_named_tuple(spec: ValueSpec) -> Expression:
annotations = getattr(spec.type, "__annotations__", {})
resolved = resolve_type_params(spec.origin_type, get_args(spec.type))[
spec.origin_type
]
annotations = {
k: resolved.get(v, v)
for k, v in getattr(spec.origin_type, "__annotations__", {}).items()
}
fields = getattr(spec.type, "_fields", ())
packers = []
as_dict = spec.builder.get_config().namedtuple_as_dict
Expand Down Expand Up @@ -619,7 +625,7 @@ def inner_expr(
elif issubclass(spec.origin_type, str):
return spec.expression
elif issubclass(spec.origin_type, Tuple): # type: ignore
if is_named_tuple(spec.type):
if is_named_tuple(spec.origin_type):
return pack_named_tuple(spec)
elif ensure_generic_collection(spec):
return pack_tuple(spec, args)
Expand Down
10 changes: 8 additions & 2 deletions mashumaro/core/meta/types/unpack.py
Original file line number Diff line number Diff line change
Expand Up @@ -580,7 +580,13 @@ def unpack_tuple(spec: ValueSpec, args: Tuple[Type, ...]) -> Expression:


def unpack_named_tuple(spec: ValueSpec) -> Expression:
annotations = getattr(spec.type, "__annotations__", {})
resolved = resolve_type_params(spec.origin_type, get_args(spec.type))[
spec.origin_type
]
annotations = {
k: resolved.get(v, v)
for k, v in getattr(spec.origin_type, "__annotations__", {}).items()
}
fields = getattr(spec.type, "_fields", ())
defaults = getattr(spec.type, "_field_defaults", {})
unpackers = []
Expand Down Expand Up @@ -757,7 +763,7 @@ def inner_expr(
f"for value in {spec.expression}])"
)
elif issubclass(spec.origin_type, Tuple): # type: ignore
if is_named_tuple(spec.type):
if is_named_tuple(spec.origin_type):
return unpack_named_tuple(spec)
elif ensure_generic_collection(spec):
return unpack_tuple(spec, args)
Expand Down
18 changes: 7 additions & 11 deletions tests/entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,7 @@
from datetime import date, datetime
from enum import Enum, Flag, IntEnum, IntFlag
from os import PathLike
from typing import (
Any,
Generic,
List,
NamedTuple,
NewType,
Optional,
TypeVar,
Union,
)
from typing import Any, Generic, List, NewType, Optional, TypeVar, Union

try:
from enum import StrEnum
Expand All @@ -22,7 +13,7 @@ class StrEnum(str, Enum):
pass


from typing_extensions import TypedDict
from typing_extensions import NamedTuple, TypedDict

from mashumaro import DataClassDictMixin
from mashumaro.config import TO_DICT_ADD_OMIT_NONE_FLAG, BaseConfig
Expand Down Expand Up @@ -287,4 +278,9 @@ class MyNamedTupleWithOptional(NamedTuple):
)


class GenericNamedTuple(NamedTuple, Generic[T]):
x: T
y: int


MyDatetimeNewType = NewType("MyDatetimeNewType", datetime)
21 changes: 21 additions & 0 deletions tests/test_data_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
from .entities import (
CustomPath,
DataClassWithoutMixin,
GenericNamedTuple,
GenericSerializableList,
GenericSerializableTypeDataClass,
GenericTypedDict,
Expand Down Expand Up @@ -1319,6 +1320,26 @@ class DataClass(DataClassDictMixin):
assert DataClass().to_dict() == {"x": [None, 7]}


def test_unbound_generic_named_tuple():
@dataclass
class DataClass(DataClassDictMixin):
x: GenericNamedTuple

obj = DataClass(GenericNamedTuple("2023-01-22", 42))
assert DataClass.from_dict({"x": ["2023-01-22", "42"]}) == obj
assert obj.to_dict() == {"x": ["2023-01-22", 42]}


def test_bound_generic_named_tuple():
@dataclass
class DataClass(DataClassDictMixin):
x: GenericNamedTuple[date]

obj = DataClass(GenericNamedTuple(date(2023, 1, 22), 42))
assert DataClass.from_dict({"x": ["2023-01-22", "42"]}) == obj
assert obj.to_dict() == {"x": ["2023-01-22", 42]}


def test_typed_dict_required_keys_with_optional():
@dataclass
class DataClass(DataClassDictMixin):
Expand Down

0 comments on commit 3467f0b

Please sign in to comment.