From 636fbdc59a0c28584dfbe6757a25e48b16702d77 Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Mon, 3 Jul 2023 12:08:06 +0100 Subject: [PATCH 1/3] Fix strict optional handling in dataclasses --- mypy/plugins/dataclasses.py | 30 +++++++++++++++------------ test-data/unit/check-dataclasses.test | 13 ++++++++++++ 2 files changed, 30 insertions(+), 13 deletions(-) diff --git a/mypy/plugins/dataclasses.py b/mypy/plugins/dataclasses.py index 9e054493828f..efa1338962a0 100644 --- a/mypy/plugins/dataclasses.py +++ b/mypy/plugins/dataclasses.py @@ -104,6 +104,7 @@ def __init__( info: TypeInfo, kw_only: bool, is_neither_frozen_nor_nonfrozen: bool, + api: SemanticAnalyzerPluginInterface, ) -> None: self.name = name self.alias = alias @@ -116,6 +117,7 @@ def __init__( self.info = info self.kw_only = kw_only self.is_neither_frozen_nor_nonfrozen = is_neither_frozen_nor_nonfrozen + self._api = api def to_argument(self, current_info: TypeInfo) -> Argument: arg_kind = ARG_POS @@ -138,7 +140,10 @@ def expand_type(self, current_info: TypeInfo) -> Optional[Type]: # however this plugin is called very late, so all types should be fully ready. # Also, it is tricky to avoid eager expansion of Self types here (e.g. because # we serialize attributes). - return expand_type(self.type, {self.info.self_type.id: fill_typevars(current_info)}) + with state.strict_optional_set(self._api.options.strict_optional): + return expand_type( + self.type, {self.info.self_type.id: fill_typevars(current_info)} + ) return self.type def to_var(self, current_info: TypeInfo) -> Var: @@ -165,13 +170,14 @@ def deserialize( ) -> DataclassAttribute: data = data.copy() typ = deserialize_and_fixup_type(data.pop("type"), api) - return cls(type=typ, info=info, **data) + return cls(type=typ, info=info, **data, api=api) def expand_typevar_from_subtype(self, sub_type: TypeInfo) -> None: """Expands type vars in the context of a subtype when an attribute is inherited from a generic super type.""" if self.type is not None: - self.type = map_type_from_supertype(self.type, sub_type, self.info) + with state.strict_optional_set(self._api.options.strict_optional): + self.type = map_type_from_supertype(self.type, sub_type, self.info) class DataclassTransformer: @@ -230,12 +236,11 @@ def transform(self) -> bool: and ("__init__" not in info.names or info.names["__init__"].plugin_generated) and attributes ): - with state.strict_optional_set(self._api.options.strict_optional): - args = [ - attr.to_argument(info) - for attr in attributes - if attr.is_in_init and not self._is_kw_only_type(attr.type) - ] + args = [ + attr.to_argument(info) + for attr in attributes + if attr.is_in_init and not self._is_kw_only_type(attr.type) + ] if info.fallback_to_any: # Make positional args optional since we don't know their order. @@ -355,8 +360,7 @@ def transform(self) -> bool: self._add_dataclass_fields_magic_attribute() if self._spec is _TRANSFORM_SPEC_FOR_DATACLASSES: - with state.strict_optional_set(self._api.options.strict_optional): - self._add_internal_replace_method(attributes) + self._add_internal_replace_method(attributes) if "__post_init__" in info.names: self._add_internal_post_init_method(attributes) @@ -546,8 +550,7 @@ def collect_attributes(self) -> list[DataclassAttribute] | None: # TODO: We shouldn't be performing type operations during the main # semantic analysis pass, since some TypeInfo attributes might # still be in flux. This should be performed in a later phase. - with state.strict_optional_set(self._api.options.strict_optional): - attr.expand_typevar_from_subtype(cls.info) + attr.expand_typevar_from_subtype(cls.info) found_attrs[name] = attr sym_node = cls.info.names.get(name) @@ -693,6 +696,7 @@ def collect_attributes(self) -> list[DataclassAttribute] | None: is_neither_frozen_nor_nonfrozen=_has_direct_dataclass_transform_metaclass( cls.info ), + api=self._api, ) all_attrs = list(found_attrs.values()) diff --git a/test-data/unit/check-dataclasses.test b/test-data/unit/check-dataclasses.test index 4a6e737ddd8d..86fc9768603d 100644 --- a/test-data/unit/check-dataclasses.test +++ b/test-data/unit/check-dataclasses.test @@ -2420,3 +2420,16 @@ class Test(Protocol): def reset(self) -> None: self.x = DEFAULT [builtins fixtures/dataclasses.pyi] + +[case testStrictOptionalAlwaysSet] +# flags: --strict-optional +from dataclasses import dataclass +from typing import Callable, Optional + +@dataclass +class Description: + name_fn: Callable[[Optional[int]], Optional[str]] + +def f(d: Description) -> None: + reveal_type(d.name_fn) # N: Revealed type is "def (Union[builtins.int, None]) -> Union[builtins.str, None]" +[builtins fixtures/dataclasses.pyi] From a7fecdb370c0fae44212c9d161e83bc32f465423 Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Mon, 3 Jul 2023 13:35:23 +0100 Subject: [PATCH 2/3] Move test to real stubs --- test-data/unit/check-dataclasses.test | 13 ------------- test-data/unit/pythoneval.test | 14 ++++++++++++++ 2 files changed, 14 insertions(+), 13 deletions(-) diff --git a/test-data/unit/check-dataclasses.test b/test-data/unit/check-dataclasses.test index 86fc9768603d..4a6e737ddd8d 100644 --- a/test-data/unit/check-dataclasses.test +++ b/test-data/unit/check-dataclasses.test @@ -2420,16 +2420,3 @@ class Test(Protocol): def reset(self) -> None: self.x = DEFAULT [builtins fixtures/dataclasses.pyi] - -[case testStrictOptionalAlwaysSet] -# flags: --strict-optional -from dataclasses import dataclass -from typing import Callable, Optional - -@dataclass -class Description: - name_fn: Callable[[Optional[int]], Optional[str]] - -def f(d: Description) -> None: - reveal_type(d.name_fn) # N: Revealed type is "def (Union[builtins.int, None]) -> Union[builtins.str, None]" -[builtins fixtures/dataclasses.pyi] diff --git a/test-data/unit/pythoneval.test b/test-data/unit/pythoneval.test index 1460002e1b65..e4837e79c61b 100644 --- a/test-data/unit/pythoneval.test +++ b/test-data/unit/pythoneval.test @@ -2109,3 +2109,17 @@ reveal_type(a2) [out] _testDataclassReplaceOptional.py:10: note: Revealed type is "_testDataclassReplaceOptional.A" _testDataclassReplaceOptional.py:12: note: Revealed type is "_testDataclassReplaceOptional.A" + +[case testDataclassStrictOptionalAlwaysSet] +# flags: --strict-optional +from dataclasses import dataclass +from typing import Callable, Optional + +@dataclass +class Description: + name_fn: Callable[[Optional[int]], Optional[str]] + +def f(d: Description) -> None: + reveal_type(d.name_fn) +[out] +_testDataclassStrictOptionalAlwaysSet.py:10: note: Revealed type is "def (Union[builtins.int, None]) -> Union[builtins.str, None]" From a3b46d35d72c2d7e8cfc2eced9c2a9aa7d029ec2 Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Mon, 3 Jul 2023 19:18:57 +0100 Subject: [PATCH 3/3] pythoneval already has strict optional --- test-data/unit/pythoneval.test | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/test-data/unit/pythoneval.test b/test-data/unit/pythoneval.test index e4837e79c61b..adc8c2e2e8ee 100644 --- a/test-data/unit/pythoneval.test +++ b/test-data/unit/pythoneval.test @@ -2094,7 +2094,6 @@ grouped = groupby(pairs, key=fst) [out] [case testDataclassReplaceOptional] -# flags: --strict-optional from dataclasses import dataclass, replace from typing import Optional @@ -2107,11 +2106,10 @@ reveal_type(a) a2 = replace(a, x=None) # OK reveal_type(a2) [out] -_testDataclassReplaceOptional.py:10: note: Revealed type is "_testDataclassReplaceOptional.A" -_testDataclassReplaceOptional.py:12: note: Revealed type is "_testDataclassReplaceOptional.A" +_testDataclassReplaceOptional.py:9: note: Revealed type is "_testDataclassReplaceOptional.A" +_testDataclassReplaceOptional.py:11: note: Revealed type is "_testDataclassReplaceOptional.A" [case testDataclassStrictOptionalAlwaysSet] -# flags: --strict-optional from dataclasses import dataclass from typing import Callable, Optional @@ -2122,4 +2120,4 @@ class Description: def f(d: Description) -> None: reveal_type(d.name_fn) [out] -_testDataclassStrictOptionalAlwaysSet.py:10: note: Revealed type is "def (Union[builtins.int, None]) -> Union[builtins.str, None]" +_testDataclassStrictOptionalAlwaysSet.py:9: note: Revealed type is "def (Union[builtins.int, None]) -> Union[builtins.str, None]"