Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for variadic type aliases #15219

Merged
merged 11 commits into from
May 21, 2023
48 changes: 44 additions & 4 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,11 +151,13 @@
UninhabitedType,
UnionType,
UnpackType,
flatten_nested_tuples,
flatten_nested_unions,
get_proper_type,
get_proper_types,
has_recursive_types,
is_named_instance,
split_with_prefix_and_suffix,
)
from mypy.types_utils import is_generic_instance, is_optional, is_self_type_like, remove_optional
from mypy.typestate import type_state
Expand Down Expand Up @@ -4070,6 +4072,35 @@ class LongName(Generic[T]): ...
# The _SpecialForm type can be used in some runtime contexts (e.g. it may have __or__).
return self.named_type("typing._SpecialForm")

def split_for_callable(
self, t: CallableType, args: Sequence[Type], ctx: Context
) -> list[Type]:
"""Handle directly applying type arguments to a variadic Callable.

This is needed in situations where e.g. variadic class object appears in
runtime context. For example:
class C(Generic[T, Unpack[Ts]]): ...
x = C[int, str]()

We simply group the arguments that need to go into Ts variable into a TupleType,
similar to how it is done in other places using split_with_prefix_and_suffix().
"""
vars = t.variables
if not vars or not any(isinstance(v, TypeVarTupleType) for v in vars):
return list(args)

prefix = next(i for (i, v) in enumerate(vars) if isinstance(v, TypeVarTupleType))
suffix = len(vars) - prefix - 1
args = flatten_nested_tuples(args)
if len(args) < len(vars) - 1:
self.msg.incompatible_type_application(len(vars), len(args), ctx)
return [AnyType(TypeOfAny.from_error)] * len(vars)

tvt = vars[prefix]
assert isinstance(tvt, TypeVarTupleType)
start, middle, end = split_with_prefix_and_suffix(tuple(args), prefix, suffix)
return list(start) + [TupleType(list(middle), tvt.tuple_fallback)] + list(end)

def apply_type_arguments_to_callable(
self, tp: Type, args: Sequence[Type], ctx: Context
) -> Type:
Expand All @@ -4083,19 +4114,28 @@ def apply_type_arguments_to_callable(
tp = get_proper_type(tp)

if isinstance(tp, CallableType):
if len(tp.variables) != len(args):
if len(tp.variables) != len(args) and not any(
isinstance(v, TypeVarTupleType) for v in tp.variables
):
if tp.is_type_obj() and tp.type_object().fullname == "builtins.tuple":
# TODO: Specialize the callable for the type arguments
return tp
self.msg.incompatible_type_application(len(tp.variables), len(args), ctx)
return AnyType(TypeOfAny.from_error)
return self.apply_generic_arguments(tp, args, ctx)
return self.apply_generic_arguments(tp, self.split_for_callable(tp, args, ctx), ctx)
if isinstance(tp, Overloaded):
for it in tp.items:
if len(it.variables) != len(args):
if len(it.variables) != len(args) and not any(
isinstance(v, TypeVarTupleType) for v in it.variables
):
self.msg.incompatible_type_application(len(it.variables), len(args), ctx)
return AnyType(TypeOfAny.from_error)
return Overloaded([self.apply_generic_arguments(it, args, ctx) for it in tp.items])
return Overloaded(
[
self.apply_generic_arguments(it, self.split_for_callable(it, args, ctx), ctx)
for it in tp.items
]
)
return AnyType(TypeOfAny.special_form)

def visit_list_expr(self, e: ListExpr) -> Type:
Expand Down
14 changes: 5 additions & 9 deletions mypy/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from __future__ import annotations

from typing import TYPE_CHECKING, Iterable, List, Sequence
from typing import TYPE_CHECKING, Iterable, List, Sequence, cast
from typing_extensions import Final

import mypy.subtypes
Expand Down Expand Up @@ -46,15 +46,11 @@
has_recursive_types,
has_type_vars,
is_named_instance,
split_with_prefix_and_suffix,
)
from mypy.types_utils import is_union_with_any
from mypy.typestate import type_state
from mypy.typevartuples import (
extract_unpack,
find_unpack_in_list,
split_with_mapped_and_template,
split_with_prefix_and_suffix,
)
from mypy.typevartuples import extract_unpack, find_unpack_in_list, split_with_mapped_and_template

if TYPE_CHECKING:
from mypy.infer import ArgumentInferContext
Expand Down Expand Up @@ -669,7 +665,7 @@ def visit_instance(self, template: Instance) -> list[Constraint]:
instance.type.type_var_tuple_prefix,
instance.type.type_var_tuple_suffix,
)
tvars = list(tvars_prefix + tvars_suffix)
tvars = cast("list[TypeVarLikeType]", list(tvars_prefix + tvars_suffix))
else:
mapped_args = mapped.args
instance_args = instance.args
Expand Down Expand Up @@ -738,7 +734,7 @@ def visit_instance(self, template: Instance) -> list[Constraint]:
template.type.type_var_tuple_prefix,
template.type.type_var_tuple_suffix,
)
tvars = list(tvars_prefix + tvars_suffix)
tvars = cast("list[TypeVarLikeType]", list(tvars_prefix + tvars_suffix))
else:
mapped_args = mapped.args
template_args = template.args
Expand Down
40 changes: 26 additions & 14 deletions mypy/expandtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,12 @@
UninhabitedType,
UnionType,
UnpackType,
flatten_nested_tuples,
flatten_nested_unions,
get_proper_type,
)
from mypy.typevartuples import (
find_unpack_in_list,
split_with_instance,
split_with_prefix_and_suffix,
)
from mypy.typevartuples import find_unpack_in_list, split_with_instance

# WARNING: these functions should never (directly or indirectly) depend on
# is_subtype(), meet_types(), join_types() etc.
Expand Down Expand Up @@ -115,6 +113,7 @@ def expand_type_by_instance(typ: Type, instance: Instance) -> Type:
instance_args = instance.args

for binder, arg in zip(tvars, instance_args):
assert isinstance(binder, TypeVarLikeType)
variables[binder.id] = arg

return expand_type(typ, variables)
Expand Down Expand Up @@ -282,12 +281,14 @@ def visit_type_var_tuple(self, t: TypeVarTupleType) -> Type:
raise NotImplementedError

def visit_unpack_type(self, t: UnpackType) -> Type:
# It is impossible to reasonally implement visit_unpack_type, because
# It is impossible to reasonably implement visit_unpack_type, because
# unpacking inherently expands to something more like a list of types.
#
# Relevant sections that can call unpack should call expand_unpack()
# instead.
assert False, "Mypy bug: unpacking must happen at a higher level"
# However, if the item is a variadic tuple, we can simply carry it over.
# it is hard to assert this without getting proper type.
return UnpackType(t.type.accept(self))

def expand_unpack(self, t: UnpackType) -> list[Type] | Instance | AnyType | None:
return expand_unpack_with_variables(t, self.variables)
Expand Down Expand Up @@ -356,7 +357,15 @@ def interpolate_args_for_unpack(

# Extract the typevartuple so we can get a tuple fallback from it.
expanded_unpacked_tvt = expanded_unpack.type
assert isinstance(expanded_unpacked_tvt, TypeVarTupleType)
if isinstance(expanded_unpacked_tvt, TypeVarTupleType):
fallback = expanded_unpacked_tvt.tuple_fallback
else:
# This can happen when tuple[Any, ...] is used to "patch" a variadic
# generic type without type arguments provided.
assert isinstance(expanded_unpacked_tvt, ProperType)
assert isinstance(expanded_unpacked_tvt, Instance)
ilevkivskyi marked this conversation as resolved.
Show resolved Hide resolved
assert expanded_unpacked_tvt.type.fullname == "builtins.tuple"
fallback = expanded_unpacked_tvt

prefix_len = expanded_unpack_index
arg_names = t.arg_names[:star_index] + [None] * prefix_len + t.arg_names[star_index:]
Expand All @@ -368,11 +377,7 @@ def interpolate_args_for_unpack(
+ expanded_items[:prefix_len]
# Constructing the Unpack containing the tuple without the prefix.
+ [
UnpackType(
TupleType(
expanded_items[prefix_len:], expanded_unpacked_tvt.tuple_fallback
)
)
UnpackType(TupleType(expanded_items[prefix_len:], fallback))
if len(expanded_items) - prefix_len > 1
else expanded_items[0]
]
Expand Down Expand Up @@ -456,9 +461,12 @@ def expand_types_with_unpack(
indicates use of Any or some error occurred earlier. In this case callers should
simply propagate the resulting type.
"""
# TODO: this will cause a crash on aliases like A = Tuple[int, Unpack[A]].
# Although it is unlikely anyone will write this, we should fail gracefully.
typs = flatten_nested_tuples(typs)
items: list[Type] = []
for item in typs:
if isinstance(item, UnpackType):
if isinstance(item, UnpackType) and isinstance(item.type, TypeVarTupleType):
unpacked_items = self.expand_unpack(item)
if unpacked_items is None:
# TODO: better error, something like tuple of unknown?
Expand Down Expand Up @@ -523,7 +531,11 @@ def visit_type_type(self, t: TypeType) -> Type:
def visit_type_alias_type(self, t: TypeAliasType) -> Type:
# Target of the type alias cannot contain type variables (not bound by the type
# alias itself), so we just expand the arguments.
return t.copy_modified(args=self.expand_types(t.args))
args = self.expand_types_with_unpack(t.args)
if isinstance(args, list):
return t.copy_modified(args=args)
else:
return args

def expand_types(self, types: Iterable[Type]) -> list[Type]:
a: list[Type] = []
Expand Down
5 changes: 5 additions & 0 deletions mypy/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -3471,6 +3471,7 @@ def f(x: B[T]) -> T: ... # without T, Any would be used here
"normalized",
"_is_recursive",
"eager",
"tvar_tuple_index",
)

__match_args__ = ("name", "target", "alias_tvars", "no_args")
Expand Down Expand Up @@ -3498,6 +3499,10 @@ def __init__(
# it is the cached value.
self._is_recursive: bool | None = None
self.eager = eager
self.tvar_tuple_index = None
for i, t in enumerate(alias_tvars):
if isinstance(t, mypy.types.TypeVarTupleType):
self.tvar_tuple_index = i
super().__init__(line, column)

@classmethod
Expand Down
13 changes: 12 additions & 1 deletion mypy/semanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,7 @@
TypeOfAny,
TypeType,
TypeVarLikeType,
TypeVarTupleType,
TypeVarType,
UnboundType,
UnpackType,
Expand Down Expand Up @@ -3424,8 +3425,18 @@ def analyze_alias(
allowed_alias_tvars=tvar_defs,
)

# There can be only one variadic variable at most, the error is reported elsewhere.
new_tvar_defs = []
variadic = False
for td in tvar_defs:
if isinstance(td, TypeVarTupleType):
if variadic:
continue
variadic = True
new_tvar_defs.append(td)

qualified_tvars = [node.fullname for _name, node in found_type_vars]
return analyzed, tvar_defs, depends_on, qualified_tvars
return analyzed, new_tvar_defs, depends_on, qualified_tvars

def is_pep_613(self, s: AssignmentStmt) -> bool:
if s.unanalyzed_type is not None and isinstance(s.unanalyzed_type, UnboundType):
Expand Down
50 changes: 46 additions & 4 deletions mypy/semanal_typeargs.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from mypy.options import Options
from mypy.scope import Scope
from mypy.subtypes import is_same_type, is_subtype
from mypy.typeanal import set_any_tvars
from mypy.types import (
AnyType,
Instance,
Expand All @@ -32,8 +33,10 @@
TypeVarType,
UnboundType,
UnpackType,
flatten_nested_tuples,
get_proper_type,
get_proper_types,
split_with_prefix_and_suffix,
)


Expand Down Expand Up @@ -79,10 +82,34 @@ def visit_type_alias_type(self, t: TypeAliasType) -> None:
self.seen_aliases.add(t)
# Some recursive aliases may produce spurious args. In principle this is not very
# important, as we would simply ignore them when expanding, but it is better to keep
# correct aliases.
if t.alias and len(t.args) != len(t.alias.alias_tvars):
t.args = [AnyType(TypeOfAny.from_error) for _ in t.alias.alias_tvars]
# correct aliases. Also, variadic aliases are better to check when fully analyzed,
# so we do this here.
assert t.alias is not None, f"Unfixed type alias {t.type_ref}"
args = flatten_nested_tuples(t.args)
if t.alias.tvar_tuple_index is not None:
correct = len(args) >= len(t.alias.alias_tvars) - 1
if any(
isinstance(a, UnpackType) and isinstance(get_proper_type(a.type), Instance)
for a in args
):
correct = True
else:
correct = len(args) == len(t.alias.alias_tvars)
if not correct:
if t.alias.tvar_tuple_index is not None:
exp_len = f"at least {len(t.alias.alias_tvars) - 1}"
else:
exp_len = f"{len(t.alias.alias_tvars)}"
self.fail(
f"Bad number of arguments for type alias, expected: {exp_len}, given: {len(args)}",
t,
code=codes.TYPE_ARG,
)
t.args = set_any_tvars(
t.alias, t.line, t.column, self.options, from_error=True, fail=self.fail
).args
else:
t.args = args
is_error = self.validate_args(t.alias.name, t.args, t.alias.alias_tvars, t)
if not is_error:
# If there was already an error for the alias itself, there is no point in checking
Expand All @@ -101,6 +128,17 @@ def visit_instance(self, t: Instance) -> None:
def validate_args(
self, name: str, args: Sequence[Type], type_vars: list[TypeVarLikeType], ctx: Context
) -> bool:
# TODO: we need to do flatten_nested_tuples and validate arg count for instances
# similar to how do we do this for type aliases above, but this may have perf penalty.
if any(isinstance(v, TypeVarTupleType) for v in type_vars):
prefix = next(i for (i, v) in enumerate(type_vars) if isinstance(v, TypeVarTupleType))
tvt = type_vars[prefix]
assert isinstance(tvt, TypeVarTupleType)
start, middle, end = split_with_prefix_and_suffix(
tuple(args), prefix, len(type_vars) - prefix - 1
)
args = list(start) + [TupleType(list(middle), tvt.tuple_fallback)] + list(end)

is_error = False
for (i, arg), tvar in zip(enumerate(args), type_vars):
if isinstance(tvar, TypeVarType):
Expand Down Expand Up @@ -167,7 +205,11 @@ def visit_unpack_type(self, typ: UnpackType) -> None:
return
if isinstance(proper_type, Instance) and proper_type.type.fullname == "builtins.tuple":
return
if isinstance(proper_type, AnyType) and proper_type.type_of_any == TypeOfAny.from_error:
if (
isinstance(proper_type, UnboundType)
or isinstance(proper_type, AnyType)
and proper_type.type_of_any == TypeOfAny.from_error
):
return

# TODO: Infer something when it can't be unpacked to allow rest of
Expand Down
2 changes: 2 additions & 0 deletions mypy/subtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -661,6 +661,8 @@ def visit_type_var_tuple(self, left: TypeVarTupleType) -> bool:
def visit_unpack_type(self, left: UnpackType) -> bool:
if isinstance(self.right, UnpackType):
return self._is_subtype(left.type, self.right.type)
if isinstance(self.right, Instance) and self.right.type.fullname == "builtins.object":
return True
return False

def visit_parameters(self, left: Parameters) -> bool:
Expand Down
Loading