Skip to content

Commit

Permalink
Alternative algorithm for union math (#5255)
Browse files Browse the repository at this point in the history
Fixes #5243
Fixes #5249

Some comments:
* I went ahead with a slow but very simple recursive algorithm that treats all various complex cases correctly. On one hand it can be exponential, but on the other hand, the complexity will be bad _only_ if user abuses lots of unions
* I use a hack caused by the fact that currently most function inference functions pass argument _expressions_ instead of types, I left a TODO to use a more unified approach similar to multiassign_from_union
* It may look like there are many changes in tests, but actually there are not, the differences are because:
  - Error messages now show the _first potentially matching_ overload (which is OK I think)
  - Order of items in many unions turned to the opposite, apparently union `__repr__` is unstable.
  • Loading branch information
ilevkivskyi authored Jul 3, 2018
1 parent 44e789d commit 0ca6bf9
Show file tree
Hide file tree
Showing 3 changed files with 326 additions and 109 deletions.
272 changes: 191 additions & 81 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
"""Expression type checker. This file is conceptually part of TypeChecker."""

from collections import OrderedDict
from typing import cast, Dict, Set, List, Tuple, Callable, Union, Optional, Iterable, Sequence, Any
from contextlib import contextmanager
from typing import (
cast, Dict, Set, List, Tuple, Callable, Union, Optional, Iterable,
Sequence, Any, Iterator
)

from mypy.errors import report_internal_error
from mypy.typeanal import (
Expand Down Expand Up @@ -58,6 +62,18 @@
ArgChecker = Callable[[Type, Type, int, Type, int, int, CallableType, Context, MessageBuilder],
None]

# Maximum nesting level for math union in overloads, setting this to large values
# may cause performance issues. The reason is that although union math algorithm we use
# nicely captures most corner cases, its worst case complexity is exponential,
# see https://github.com/python/mypy/pull/5255#discussion_r196896335 for discussion.
MAX_UNIONS = 5


class TooManyUnions(Exception):
"""Indicates that we need to stop splitting unions in an attempt
to match an overload in order to save performance.
"""


def extract_refexpr_names(expr: RefExpr) -> Set[str]:
"""Recursively extracts all module references from a reference expression.
Expand Down Expand Up @@ -120,6 +136,11 @@ def __init__(self,
self.msg = msg
self.plugin = plugin
self.type_context = [None]
# Temporary overrides for expression types. This is currently
# used by the union math in overloads.
# TODO: refactor this to use a pattern similar to one in
# multiassign_from_union, or maybe even combine the two?
self.type_overrides = {} # type: Dict[Expression, Type]
self.strfrm_checker = StringFormatterChecker(self, self.chk, self.msg)

def visit_name_expr(self, e: NameExpr) -> Type:
Expand Down Expand Up @@ -1138,41 +1159,46 @@ def check_overload_call(self,
# typevar. See https://github.com/python/mypy/issues/4063 for related discussion.
erased_targets = None # type: Optional[List[CallableType]]
unioned_result = None # type: Optional[Tuple[Type, Type]]
unioned_errors = None # type: Optional[MessageBuilder]
union_success = False
if any(isinstance(arg, UnionType) and len(arg.relevant_items()) > 1 # "real" union
for arg in arg_types):
erased_targets = self.overload_erased_call_targets(plausible_targets, arg_types,
arg_kinds, arg_names, context)
unioned_callable = self.union_overload_matches(erased_targets)

if unioned_callable is not None:
unioned_errors = arg_messages.clean_copy()
unioned_result = self.check_call(unioned_callable, args, arg_kinds,
context, arg_names,
arg_messages=unioned_errors,
callable_name=callable_name,
object_type=object_type)
union_interrupted = False # did we try all union combinations?
if any(self.real_union(arg) for arg in arg_types):
unioned_errors = arg_messages.clean_copy()
try:
unioned_return = self.union_overload_result(plausible_targets, args,
arg_types, arg_kinds, arg_names,
callable_name, object_type,
context,
arg_messages=unioned_errors)
except TooManyUnions:
union_interrupted = True
else:
# Record if we succeeded. Next we need to see if maybe normal procedure
# gives a narrower type.
union_success = unioned_result is not None and not unioned_errors.is_errors()
if unioned_return:
returns, inferred_types = zip(*unioned_return)
# Note that we use `union_overload_matches` instead of just returning
# a union of inferred callables because for example a call
# Union[int -> int, str -> str](Union[int, str]) is invalid and
# we don't want to introduce internal inconsistencies.
unioned_result = (UnionType.make_simplified_union(list(returns),
context.line,
context.column),
self.union_overload_matches(inferred_types))

# Step 3: We try checking each branch one-by-one.
inferred_result = self.infer_overload_return_type(plausible_targets, args, arg_types,
arg_kinds, arg_names, callable_name,
object_type, context, arg_messages)
if inferred_result is not None:
# Success! Stop early by returning the best among normal and unioned.
if not union_success:
# If any of checks succeed, stop early.
if inferred_result is not None and unioned_result is not None:
# Both unioned and direct checks succeeded, choose the more precise type.
if (is_subtype(inferred_result[0], unioned_result[0]) and
not isinstance(inferred_result[0], AnyType)):
return inferred_result
else:
assert unioned_result is not None
if is_subtype(inferred_result[0], unioned_result[0]):
return inferred_result
return unioned_result
elif union_success:
assert unioned_result is not None
return unioned_result
elif unioned_result is not None:
return unioned_result
elif inferred_result is not None:
return inferred_result

# Step 4: Failure. At this point, we know there is no match. We fall back to trying
# to find a somewhat plausible overload target using the erased types
Expand All @@ -1183,19 +1209,12 @@ def check_overload_call(self,
#
# Neither alternative matches, but we can guess the user probably wants the
# second one.
if erased_targets is None:
erased_targets = self.overload_erased_call_targets(plausible_targets, arg_types,
arg_kinds, arg_names, context)
erased_targets = self.overload_erased_call_targets(plausible_targets, arg_types,
arg_kinds, arg_names, context)

# Step 5: We try and infer a second-best alternative if possible. If not, fall back
# to using 'Any'.
if unioned_result is not None:
# When possible, return the error messages generated from the union-math attempt:
# they tend to be a little nicer.
assert unioned_errors is not None
arg_messages.add_errors(unioned_errors)
return unioned_result
elif len(erased_targets) > 0:
if len(erased_targets) > 0:
# Pick the first plausible erased target as the fallback
# TODO: Adjust the error message here to make it clear there was no match.
target = erased_targets[0] # type: Type
Expand All @@ -1204,11 +1223,14 @@ def check_overload_call(self,
if not self.chk.should_suppress_optional_error(arg_types):
arg_messages.no_variant_matches_arguments(callee, arg_types, context)
target = AnyType(TypeOfAny.from_error)

return self.check_call(target, args, arg_kinds, context, arg_names,
arg_messages=arg_messages,
callable_name=callable_name,
object_type=object_type)
result = self.check_call(target, args, arg_kinds, context, arg_names,
arg_messages=arg_messages,
callable_name=callable_name,
object_type=object_type)
if union_interrupted:
self.chk.msg.note("Not all union combinations were tried"
" because there are too many unions", context)
return result

def plausible_overload_call_targets(self,
arg_types: List[Type],
Expand Down Expand Up @@ -1358,18 +1380,110 @@ def overload_erased_call_targets(self,
matches.append(typ)
return matches

def union_overload_matches(self, callables: List[CallableType]) -> Optional[CallableType]:
def union_overload_result(self,
plausible_targets: List[CallableType],
args: List[Expression],
arg_types: List[Type],
arg_kinds: List[int],
arg_names: Optional[Sequence[Optional[str]]],
callable_name: Optional[str],
object_type: Optional[Type],
context: Context,
arg_messages: Optional[MessageBuilder] = None,
level: int = 0
) -> Optional[List[Tuple[Type, Type]]]:
"""Accepts a list of overload signatures and attempts to match calls by destructuring
the first union.
Return a list of (<return type>, <inferred variant type>) if call succeeds for every
item of the desctructured union. Returns None if there is no match.
"""
# Step 1: If we are already too deep, then stop immediately. Otherwise mypy might
# hang for long time because of a weird overload call. The caller will get
# the exception and generate an appropriate note message, if needed.
if level >= MAX_UNIONS:
raise TooManyUnions

# Step 2: Find position of the first union in arguments. Return the normal inferred
# type if no more unions left.
for idx, typ in enumerate(arg_types):
if self.real_union(typ):
break
else:
# No unions in args, just fall back to normal inference
with self.type_overrides_set(args, arg_types):
res = self.infer_overload_return_type(plausible_targets, args, arg_types,
arg_kinds, arg_names, callable_name,
object_type, context, arg_messages)
if res is not None:
return [res]
return None

# Step 3: Try a direct match before splitting to avoid unnecessary union splits
# and save performance.
with self.type_overrides_set(args, arg_types):
direct = self.infer_overload_return_type(plausible_targets, args, arg_types,
arg_kinds, arg_names, callable_name,
object_type, context, arg_messages)
if direct is not None and not isinstance(direct[0], (UnionType, AnyType)):
# We only return non-unions soon, to avoid greedy match.
return [direct]

# Step 4: Split the first remaining union type in arguments into items and
# try to match each item individually (recursive).
first_union = arg_types[idx]
assert isinstance(first_union, UnionType)
res_items = []
for item in first_union.relevant_items():
new_arg_types = arg_types.copy()
new_arg_types[idx] = item
sub_result = self.union_overload_result(plausible_targets, args, new_arg_types,
arg_kinds, arg_names, callable_name,
object_type, context, arg_messages,
level + 1)
if sub_result is not None:
res_items.extend(sub_result)
else:
# Some item doesn't match, return soon.
return None

# Step 5: If splitting succeeded, then filter out duplicate items before returning.
seen = set() # type: Set[Tuple[Type, Type]]
result = []
for pair in res_items:
if pair not in seen:
seen.add(pair)
result.append(pair)
return result

def real_union(self, typ: Type) -> bool:
return isinstance(typ, UnionType) and len(typ.relevant_items()) > 1

@contextmanager
def type_overrides_set(self, exprs: Sequence[Expression],
overrides: Sequence[Type]) -> Iterator[None]:
"""Set _temporary_ type overrides for given expressions."""
assert len(exprs) == len(overrides)
for expr, typ in zip(exprs, overrides):
self.type_overrides[expr] = typ
try:
yield
finally:
for expr in exprs:
del self.type_overrides[expr]

def union_overload_matches(self, types: Sequence[Type]) -> Union[AnyType, CallableType]:
"""Accepts a list of overload signatures and attempts to combine them together into a
new CallableType consisting of the union of all of the given arguments and return types.
Returns None if it is not possible to combine the different callables together in a
sound manner.
Assumes all of the given callables have argument counts compatible with the caller.
If there is at least one non-callable type, return Any (this can happen if there is
an ambiguity because of Any in arguments).
"""
if len(callables) == 0:
return None
elif len(callables) == 1:
assert types, "Trying to merge no callables"
if not all(isinstance(c, CallableType) for c in types):
return AnyType(TypeOfAny.special_form)
callables = cast(List[CallableType], types)
if len(callables) == 1:
return callables[0]

# Note: we are assuming here that if a user uses some TypeVar 'T' in
Expand All @@ -1389,58 +1503,52 @@ def union_overload_matches(self, callables: List[CallableType]) -> Optional[Call
new_kinds = list(callables[0].arg_kinds)
new_returns = [] # type: List[Type]

too_complex = False
for target in callables:
# We conservatively end if the overloads do not have the exact same signature.
# The only exception is if one arg is optional and the other is positional: in that
# case, we continue unioning (and expect a positional arg).
# TODO: Enhance the union overload logic to handle a wider variety of signatures.
# We fall back to Callable[..., Union[<returns>]] if the overloads do not have
# the exact same signature. The only exception is if one arg is optional and
# the other is positional: in that case, we continue unioning (and expect a
# positional arg).
# TODO: Enhance the merging logic to handle a wider variety of signatures.
if len(new_kinds) != len(target.arg_kinds):
return None
too_complex = True
break
for i, (new_kind, target_kind) in enumerate(zip(new_kinds, target.arg_kinds)):
if new_kind == target_kind:
continue
elif new_kind in (ARG_POS, ARG_OPT) and target_kind in (ARG_POS, ARG_OPT):
new_kinds[i] = ARG_POS
else:
return None
too_complex = True
break

if too_complex:
break # outer loop

for i, arg in enumerate(target.arg_types):
new_args[i].append(arg)
new_returns.append(target.ret_type)

union_count = 0
union_return = UnionType.make_simplified_union(new_returns)
if too_complex:
any = AnyType(TypeOfAny.special_form)
return callables[0].copy_modified(
arg_types=[any, any],
arg_kinds=[ARG_STAR, ARG_STAR2],
arg_names=[None, None],
ret_type=union_return,
variables=variables,
implicit=True)

final_args = []
for args_list in new_args:
new_type = UnionType.make_simplified_union(args_list)
union_count += 1 if isinstance(new_type, UnionType) else 0
final_args.append(new_type)

# TODO: Modify this check to be less conservative.
#
# Currently, we permit only one union in the arguments because if we allow
# multiple, we can't always guarantee the synthesized callable will be correct.
#
# For example, suppose we had the following two overloads:
#
# @overload
# def f(x: A, y: B) -> None: ...
# @overload
# def f(x: B, y: A) -> None: ...
#
# If we continued and synthesize "def f(x: Union[A,B], y: Union[A,B]) -> None: ...",
# then we'd incorrectly accept calls like "f(A(), A())" when they really ought to
# be rejected.
#
# However, that means we'll also give up if the original overloads contained
# any unions. This is likely unnecessary -- we only really need to give up if
# there are more then one *synthesized* union arguments.
if union_count >= 2:
return None

return callables[0].copy_modified(
arg_types=final_args,
arg_kinds=new_kinds,
ret_type=UnionType.make_simplified_union(new_returns),
ret_type=union_return,
variables=variables,
implicit=True)

Expand Down Expand Up @@ -2733,6 +2841,8 @@ def accept(self,
is True and this expression is a call, allow it to return None. This
applies only to this expression and not any subexpressions.
"""
if node in self.type_overrides:
return self.type_overrides[node]
self.type_context.append(type_context)
try:
if allow_none_return and isinstance(node, CallExpr):
Expand Down
Loading

0 comments on commit 0ca6bf9

Please sign in to comment.