diff --git a/mypy/plugins/attrs.py b/mypy/plugins/attrs.py index c71d898e1c62a..45ce6db4596ed 100644 --- a/mypy/plugins/attrs.py +++ b/mypy/plugins/attrs.py @@ -6,8 +6,10 @@ from typing_extensions import Final, Literal import mypy.plugin # To avoid circular imports. +from mypy.applytype import apply_generic_arguments from mypy.checker import TypeChecker from mypy.errorcodes import LITERAL_REQ +from mypy.expandtype import expand_type from mypy.exprtotype import TypeTranslationError, expr_to_unanalyzed_type from mypy.messages import format_type_bare from mypy.nodes import ( @@ -23,6 +25,7 @@ Decorator, Expression, FuncDef, + IndexExpr, JsonDict, LambdaExpr, ListExpr, @@ -34,6 +37,7 @@ SymbolTableNode, TempNode, TupleExpr, + TypeApplication, TypeInfo, TypeVarExpr, Var, @@ -49,7 +53,7 @@ deserialize_and_fixup_type, ) from mypy.server.trigger import make_wildcard_trigger -from mypy.typeops import make_simplified_union, map_type_from_supertype +from mypy.typeops import get_type_vars, make_simplified_union, map_type_from_supertype from mypy.types import ( AnyType, CallableType, @@ -85,8 +89,9 @@ class Converter: """Holds information about a `converter=` argument""" - def __init__(self, init_type: Type | None = None) -> None: + def __init__(self, init_type: Type | None = None, ret_type: Type | None = None) -> None: self.init_type = init_type + self.ret_type = ret_type class Attribute: @@ -115,11 +120,20 @@ def __init__( def argument(self, ctx: mypy.plugin.ClassDefContext) -> Argument: """Return this attribute as an argument to __init__.""" assert self.init - init_type: Type | None = None if self.converter: if self.converter.init_type: init_type = self.converter.init_type + if init_type and self.init_type and self.converter.ret_type: + # The converter return type should be the same type as the attribute type. + # Copy type vars from attr type to converter. + converter_vars = get_type_vars(self.converter.ret_type) + init_vars = get_type_vars(self.init_type) + if converter_vars and len(converter_vars) == len(init_vars): + variables = { + binder.id: arg for binder, arg in zip(converter_vars, init_vars) + } + init_type = expand_type(init_type, variables) else: ctx.api.fail("Cannot determine __init__ type from converter", self.context) init_type = AnyType(TypeOfAny.from_error) @@ -653,6 +667,26 @@ def _parse_converter( from mypy.checkmember import type_object_type # To avoid import cycle. converter_type = type_object_type(converter_expr.node, ctx.api.named_type) + elif ( + isinstance(converter_expr, IndexExpr) + and isinstance(converter_expr.analyzed, TypeApplication) + and isinstance(converter_expr.base, RefExpr) + and isinstance(converter_expr.base.node, TypeInfo) + ): + # The converter is a generic type. + from mypy.checkmember import type_object_type # To avoid import cycle. + + converter_type = type_object_type(converter_expr.base.node, ctx.api.named_type) + if isinstance(converter_type, CallableType): + converter_type = apply_generic_arguments( + converter_type, + converter_expr.analyzed.types, + ctx.api.msg.incompatible_typevar_value, + converter_type, + ) + else: + converter_type = None + if isinstance(converter_expr, LambdaExpr): # TODO: should we send a fail if converter_expr.min_args > 1? converter_info.init_type = AnyType(TypeOfAny.unannotated) @@ -671,6 +705,8 @@ def _parse_converter( converter_type = get_proper_type(converter_type) if isinstance(converter_type, CallableType) and converter_type.arg_types: converter_info.init_type = converter_type.arg_types[0] + if not is_attr_converters_optional: + converter_info.ret_type = converter_type.ret_type elif isinstance(converter_type, Overloaded): types: list[Type] = [] for item in converter_type.items: diff --git a/test-data/unit/check-attr.test b/test-data/unit/check-attr.test index 83a441aca2330..89eab92ebabdf 100644 --- a/test-data/unit/check-attr.test +++ b/test-data/unit/check-attr.test @@ -469,6 +469,56 @@ A([1], '2') # E: Cannot infer type argument 1 of "A" [builtins fixtures/list.pyi] +[case testAttrsGenericWithConverter] +from typing import TypeVar, Generic, List, Iterable, Iterator, Callable +import attr +T = TypeVar('T') + +def int_gen() -> Iterator[int]: + yield 1 + +def list_converter(x: Iterable[T]) -> List[T]: + return list(x) + +@attr.s(auto_attribs=True) +class A(Generic[T]): + x: List[T] = attr.ib(converter=list_converter) + y: T = attr.ib() + def foo(self) -> List[T]: + return [self.y] + def bar(self) -> T: + return self.x[0] + def problem(self) -> T: + return self.x # E: Incompatible return value type (got "List[T]", expected "T") +reveal_type(A) # N: Revealed type is "def [T] (x: typing.Iterable[T`1], y: T`1) -> __main__.A[T`1]" +a1 = A([1], 2) +reveal_type(a1) # N: Revealed type is "__main__.A[builtins.int]" +reveal_type(a1.x) # N: Revealed type is "builtins.list[builtins.int]" +reveal_type(a1.y) # N: Revealed type is "builtins.int" + +a2 = A(int_gen(), 2) +reveal_type(a2) # N: Revealed type is "__main__.A[builtins.int]" +reveal_type(a2.x) # N: Revealed type is "builtins.list[builtins.int]" +reveal_type(a2.y) # N: Revealed type is "builtins.int" + + +def get_int() -> int: + return 1 + +class Other(Generic[T]): + def __init__(self, x: T) -> None: + pass + +@attr.s(auto_attribs=True) +class B(Generic[T]): + x: Other[Callable[..., T]] = attr.ib(converter=Other[Callable[..., T]]) + +b1 = B(get_int) +reveal_type(b1) # N: Revealed type is "__main__.B[builtins.int]" +reveal_type(b1.x) # N: Revealed type is "__main__.Other[def (*Any, **Any) -> builtins.int]" + +[builtins fixtures/list.pyi] + [case testAttrsUntypedGenericInheritance] from typing import Generic, TypeVar