-
-
Notifications
You must be signed in to change notification settings - Fork 2.9k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Alternative algorithm for union math #5255
Changes from all commits
057e329
b796af7
d127ecb
e98cb76
7057796
4e6a03d
33e845f
dd8adba
c934465
5f91877
b9659e1
7d02ba8
643c3b8
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,11 @@ | ||
"""Expression type checker. This file is conceptually part of TypeChecker.""" | ||
|
||
from collections import OrderedDict | ||
from typing import cast, Dict, Set, List, Tuple, Callable, Union, Optional, Iterable, Sequence, Any | ||
from contextlib import contextmanager | ||
from typing import ( | ||
cast, Dict, Set, List, Tuple, Callable, Union, Optional, Iterable, | ||
Sequence, Any, Iterator | ||
) | ||
|
||
from mypy.errors import report_internal_error | ||
from mypy.typeanal import ( | ||
|
@@ -58,6 +62,18 @@ | |
ArgChecker = Callable[[Type, Type, int, Type, int, int, CallableType, Context, MessageBuilder], | ||
None] | ||
|
||
# Maximum nesting level for math union in overloads, setting this to large values | ||
# may cause performance issues. The reason is that although the union math algorithm we use | ||
# nicely captures most corner cases, its worst case complexity is exponential, | ||
# see https://github.com/python/mypy/pull/5255#discussion_r196896335 for discussion. | ||
MAX_UNIONS = 5 | ||
|
||
|
||
class TooManyUnions(Exception): | ||
"""Indicates that we need to stop splitting unions in an attempt | ||
to match an overload in order to save performance. | ||
""" | ||
|
||
|
||
def extract_refexpr_names(expr: RefExpr) -> Set[str]: | ||
"""Recursively extracts all module references from a reference expression. | ||
|
@@ -120,6 +136,11 @@ def __init__(self, | |
self.msg = msg | ||
self.plugin = plugin | ||
self.type_context = [None] | ||
# Temporary overrides for expression types. This is currently | ||
# used by the union math in overloads. | ||
# TODO: refactor this to use a pattern similar to one in | ||
# multiassign_from_union, or maybe even combine the two? | ||
self.type_overrides = {} # type: Dict[Expression, Type] | ||
self.strfrm_checker = StringFormatterChecker(self, self.chk, self.msg) | ||
|
||
def visit_name_expr(self, e: NameExpr) -> Type: | ||
|
@@ -1138,41 +1159,46 @@ def check_overload_call(self, | |
# typevar. See https://github.com/python/mypy/issues/4063 for related discussion. | ||
erased_targets = None # type: Optional[List[CallableType]] | ||
unioned_result = None # type: Optional[Tuple[Type, Type]] | ||
unioned_errors = None # type: Optional[MessageBuilder] | ||
union_success = False | ||
if any(isinstance(arg, UnionType) and len(arg.relevant_items()) > 1 # "real" union | ||
for arg in arg_types): | ||
erased_targets = self.overload_erased_call_targets(plausible_targets, arg_types, | ||
arg_kinds, arg_names, context) | ||
unioned_callable = self.union_overload_matches(erased_targets) | ||
|
||
if unioned_callable is not None: | ||
unioned_errors = arg_messages.clean_copy() | ||
unioned_result = self.check_call(unioned_callable, args, arg_kinds, | ||
context, arg_names, | ||
arg_messages=unioned_errors, | ||
callable_name=callable_name, | ||
object_type=object_type) | ||
union_interrupted = False # did we try all union combinations? | ||
if any(self.real_union(arg) for arg in arg_types): | ||
unioned_errors = arg_messages.clean_copy() | ||
try: | ||
unioned_return = self.union_overload_result(plausible_targets, args, | ||
arg_types, arg_kinds, arg_names, | ||
callable_name, object_type, | ||
context, | ||
arg_messages=unioned_errors) | ||
except TooManyUnions: | ||
union_interrupted = True | ||
else: | ||
# Record if we succeeded. Next we need to see if maybe normal procedure | ||
# gives a narrower type. | ||
union_success = unioned_result is not None and not unioned_errors.is_errors() | ||
if unioned_return: | ||
returns, inferred_types = zip(*unioned_return) | ||
# Note that we use `union_overload_matches` instead of just returning | ||
# a union of inferred callables because for example a call | ||
# Union[int -> int, str -> str](Union[int, str]) is invalid and | ||
# we don't want to introduce internal inconsistencies. | ||
unioned_result = (UnionType.make_simplified_union(list(returns), | ||
context.line, | ||
context.column), | ||
self.union_overload_matches(inferred_types)) | ||
|
||
# Step 3: We try checking each branch one-by-one. | ||
inferred_result = self.infer_overload_return_type(plausible_targets, args, arg_types, | ||
arg_kinds, arg_names, callable_name, | ||
object_type, context, arg_messages) | ||
if inferred_result is not None: | ||
# Success! Stop early by returning the best among normal and unioned. | ||
if not union_success: | ||
# If any of checks succeed, stop early. | ||
if inferred_result is not None and unioned_result is not None: | ||
# Both unioned and direct checks succeeded, choose the more precise type. | ||
if (is_subtype(inferred_result[0], unioned_result[0]) and | ||
not isinstance(inferred_result[0], AnyType)): | ||
return inferred_result | ||
else: | ||
assert unioned_result is not None | ||
if is_subtype(inferred_result[0], unioned_result[0]): | ||
return inferred_result | ||
return unioned_result | ||
elif union_success: | ||
assert unioned_result is not None | ||
return unioned_result | ||
elif unioned_result is not None: | ||
return unioned_result | ||
elif inferred_result is not None: | ||
return inferred_result | ||
|
||
# Step 4: Failure. At this point, we know there is no match. We fall back to trying | ||
# to find a somewhat plausible overload target using the erased types | ||
|
@@ -1183,19 +1209,12 @@ def check_overload_call(self, | |
# | ||
# Neither alternative matches, but we can guess the user probably wants the | ||
# second one. | ||
if erased_targets is None: | ||
erased_targets = self.overload_erased_call_targets(plausible_targets, arg_types, | ||
arg_kinds, arg_names, context) | ||
erased_targets = self.overload_erased_call_targets(plausible_targets, arg_types, | ||
arg_kinds, arg_names, context) | ||
|
||
# Step 5: We try and infer a second-best alternative if possible. If not, fall back | ||
# to using 'Any'. | ||
if unioned_result is not None: | ||
# When possible, return the error messages generated from the union-math attempt: | ||
# they tend to be a little nicer. | ||
assert unioned_errors is not None | ||
arg_messages.add_errors(unioned_errors) | ||
return unioned_result | ||
elif len(erased_targets) > 0: | ||
if len(erased_targets) > 0: | ||
# Pick the first plausible erased target as the fallback | ||
# TODO: Adjust the error message here to make it clear there was no match. | ||
target = erased_targets[0] # type: Type | ||
|
@@ -1204,11 +1223,14 @@ def check_overload_call(self, | |
if not self.chk.should_suppress_optional_error(arg_types): | ||
arg_messages.no_variant_matches_arguments(callee, arg_types, context) | ||
target = AnyType(TypeOfAny.from_error) | ||
|
||
return self.check_call(target, args, arg_kinds, context, arg_names, | ||
arg_messages=arg_messages, | ||
callable_name=callable_name, | ||
object_type=object_type) | ||
result = self.check_call(target, args, arg_kinds, context, arg_names, | ||
arg_messages=arg_messages, | ||
callable_name=callable_name, | ||
object_type=object_type) | ||
if union_interrupted: | ||
self.chk.msg.note("Not all union combinations were tried" | ||
" because there are too many unions", context) | ||
return result | ||
|
||
def plausible_overload_call_targets(self, | ||
arg_types: List[Type], | ||
|
@@ -1358,18 +1380,110 @@ def overload_erased_call_targets(self, | |
matches.append(typ) | ||
return matches | ||
|
||
def union_overload_matches(self, callables: List[CallableType]) -> Optional[CallableType]: | ||
def union_overload_result(self, | ||
plausible_targets: List[CallableType], | ||
args: List[Expression], | ||
arg_types: List[Type], | ||
arg_kinds: List[int], | ||
arg_names: Optional[Sequence[Optional[str]]], | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Off-topic: I wonder if it might be worth writing some sort of custom class/namedtuple that stores There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, we have many potential refactoring ideas on our internal roadmap, we can add this one to the list. |
||
callable_name: Optional[str], | ||
object_type: Optional[Type], | ||
context: Context, | ||
arg_messages: Optional[MessageBuilder] = None, | ||
level: int = 0 | ||
) -> Optional[List[Tuple[Type, Type]]]: | ||
"""Accepts a list of overload signatures and attempts to match calls by destructuring | ||
the first union. | ||
|
||
Return a list of (<return type>, <inferred variant type>) if call succeeds for every | ||
item of the destructured union. Returns None if there is no match. | ||
""" | ||
# Step 1: If we are already too deep, then stop immediately. Otherwise mypy might | ||
# hang for a long time because of a weird overload call. The caller will get | ||
# the exception and generate an appropriate note message, if needed. | ||
if level >= MAX_UNIONS: | ||
raise TooManyUnions | ||
|
||
# Step 2: Find position of the first union in arguments. Return the normal inferred | ||
# type if no more unions left. | ||
for idx, typ in enumerate(arg_types): | ||
if self.real_union(typ): | ||
break | ||
else: | ||
# No unions in args, just fall back to normal inference | ||
with self.type_overrides_set(args, arg_types): | ||
res = self.infer_overload_return_type(plausible_targets, args, arg_types, | ||
arg_kinds, arg_names, callable_name, | ||
object_type, context, arg_messages) | ||
if res is not None: | ||
return [res] | ||
return None | ||
|
||
# Step 3: Try a direct match before splitting to avoid unnecessary union splits | ||
# and save performance. | ||
with self.type_overrides_set(args, arg_types): | ||
direct = self.infer_overload_return_type(plausible_targets, args, arg_types, | ||
arg_kinds, arg_names, callable_name, | ||
object_type, context, arg_messages) | ||
if direct is not None and not isinstance(direct[0], (UnionType, AnyType)): | ||
# We only return non-unions early, to avoid a greedy match. | ||
return [direct] | ||
|
||
# Step 4: Split the first remaining union type in arguments into items and | ||
# try to match each item individually (recursive). | ||
first_union = arg_types[idx] | ||
assert isinstance(first_union, UnionType) | ||
res_items = [] | ||
for item in first_union.relevant_items(): | ||
new_arg_types = arg_types.copy() | ||
new_arg_types[idx] = item | ||
sub_result = self.union_overload_result(plausible_targets, args, new_arg_types, | ||
arg_kinds, arg_names, callable_name, | ||
object_type, context, arg_messages, | ||
level + 1) | ||
if sub_result is not None: | ||
res_items.extend(sub_result) | ||
else: | ||
# Some item doesn't match, return early. | ||
return None | ||
|
||
# Step 5: If splitting succeeded, then filter out duplicate items before returning. | ||
seen = set() # type: Set[Tuple[Type, Type]] | ||
result = [] | ||
for pair in res_items: | ||
if pair not in seen: | ||
seen.add(pair) | ||
result.append(pair) | ||
return result | ||
|
||
def real_union(self, typ: Type) -> bool: | ||
return isinstance(typ, UnionType) and len(typ.relevant_items()) > 1 | ||
|
||
@contextmanager | ||
def type_overrides_set(self, exprs: Sequence[Expression], | ||
overrides: Sequence[Type]) -> Iterator[None]: | ||
"""Set _temporary_ type overrides for given expressions.""" | ||
assert len(exprs) == len(overrides) | ||
for expr, typ in zip(exprs, overrides): | ||
self.type_overrides[expr] = typ | ||
try: | ||
yield | ||
finally: | ||
for expr in exprs: | ||
del self.type_overrides[expr] | ||
|
||
def union_overload_matches(self, types: Sequence[Type]) -> Union[AnyType, CallableType]: | ||
"""Accepts a list of overload signatures and attempts to combine them together into a | ||
new CallableType consisting of the union of all of the given arguments and return types. | ||
|
||
Returns None if it is not possible to combine the different callables together in a | ||
sound manner. | ||
|
||
Assumes all of the given callables have argument counts compatible with the caller. | ||
If there is at least one non-callable type, return Any (this can happen if there is | ||
an ambiguity because of Any in arguments). | ||
""" | ||
if len(callables) == 0: | ||
return None | ||
elif len(callables) == 1: | ||
assert types, "Trying to merge no callables" | ||
if not all(isinstance(c, CallableType) for c in types): | ||
return AnyType(TypeOfAny.special_form) | ||
callables = cast(List[CallableType], types) | ||
if len(callables) == 1: | ||
return callables[0] | ||
|
||
# Note: we are assuming here that if a user uses some TypeVar 'T' in | ||
|
@@ -1389,58 +1503,52 @@ def union_overload_matches(self, callables: List[CallableType]) -> Optional[Call | |
new_kinds = list(callables[0].arg_kinds) | ||
new_returns = [] # type: List[Type] | ||
|
||
too_complex = False | ||
for target in callables: | ||
# We conservatively end if the overloads do not have the exact same signature. | ||
# The only exception is if one arg is optional and the other is positional: in that | ||
# case, we continue unioning (and expect a positional arg). | ||
# TODO: Enhance the union overload logic to handle a wider variety of signatures. | ||
# We fall back to Callable[..., Union[<returns>]] if the overloads do not have | ||
# the exact same signature. The only exception is if one arg is optional and | ||
# the other is positional: in that case, we continue unioning (and expect a | ||
# positional arg). | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What about falling back to (I think this is basically the same question as the previous one -- maybe add a comment here about this?) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
It actually breaks nothing, I just don't want to introduce some internal inconsistencies (as I explained above). Also the |
||
# TODO: Enhance the merging logic to handle a wider variety of signatures. | ||
if len(new_kinds) != len(target.arg_kinds): | ||
return None | ||
too_complex = True | ||
break | ||
for i, (new_kind, target_kind) in enumerate(zip(new_kinds, target.arg_kinds)): | ||
if new_kind == target_kind: | ||
continue | ||
elif new_kind in (ARG_POS, ARG_OPT) and target_kind in (ARG_POS, ARG_OPT): | ||
new_kinds[i] = ARG_POS | ||
else: | ||
return None | ||
too_complex = True | ||
break | ||
|
||
if too_complex: | ||
break # outer loop | ||
|
||
for i, arg in enumerate(target.arg_types): | ||
new_args[i].append(arg) | ||
new_returns.append(target.ret_type) | ||
|
||
union_count = 0 | ||
union_return = UnionType.make_simplified_union(new_returns) | ||
if too_complex: | ||
any = AnyType(TypeOfAny.special_form) | ||
return callables[0].copy_modified( | ||
arg_types=[any, any], | ||
arg_kinds=[ARG_STAR, ARG_STAR2], | ||
arg_names=[None, None], | ||
ret_type=union_return, | ||
variables=variables, | ||
implicit=True) | ||
|
||
final_args = [] | ||
for args_list in new_args: | ||
new_type = UnionType.make_simplified_union(args_list) | ||
union_count += 1 if isinstance(new_type, UnionType) else 0 | ||
final_args.append(new_type) | ||
|
||
# TODO: Modify this check to be less conservative. | ||
# | ||
# Currently, we permit only one union in the arguments because if we allow | ||
# multiple, we can't always guarantee the synthesized callable will be correct. | ||
# | ||
# For example, suppose we had the following two overloads: | ||
# | ||
# @overload | ||
# def f(x: A, y: B) -> None: ... | ||
# @overload | ||
# def f(x: B, y: A) -> None: ... | ||
# | ||
# If we continued and synthesize "def f(x: Union[A,B], y: Union[A,B]) -> None: ...", | ||
# then we'd incorrectly accept calls like "f(A(), A())" when they really ought to | ||
# be rejected. | ||
# | ||
# However, that means we'll also give up if the original overloads contained | ||
# any unions. This is likely unnecessary -- we only really need to give up if | ||
# there is more than one *synthesized* union argument. | ||
if union_count >= 2: | ||
return None | ||
|
||
return callables[0].copy_modified( | ||
arg_types=final_args, | ||
arg_kinds=new_kinds, | ||
ret_type=UnionType.make_simplified_union(new_returns), | ||
ret_type=union_return, | ||
variables=variables, | ||
implicit=True) | ||
|
||
|
@@ -2733,6 +2841,8 @@ def accept(self, | |
is True and this expression is a call, allow it to return None. This | ||
applies only to this expression and not any subexpressions. | ||
""" | ||
if node in self.type_overrides: | ||
return self.type_overrides[node] | ||
self.type_context.append(type_context) | ||
try: | ||
if allow_none_return and isinstance(node, CallExpr): | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Github still won't let me leave a comment below, but I think after doing the above two simplifications, we can get rid of the next if statement entirely.
(It's a bit hard to tell, but based on new error messages in the changed tests, it seems like we're no longer entering into this case at all, so it's redundant either way.)