
Alternative algorithm for union math #5255

Merged: 13 commits, Jul 3, 2018
178 changes: 74 additions & 104 deletions mypy/checkexpr.py
@@ -118,6 +118,7 @@ def __init__(self,
self.msg = msg
self.plugin = plugin
self.type_context = [None]
self.type_overrides = {} # type: Dict[Expression, Type]
self.strfrm_checker = StringFormatterChecker(self, self.chk, self.msg)

def visit_name_expr(self, e: NameExpr) -> Type:
@@ -519,7 +520,9 @@ def check_call(self, callee: Type, args: List[Expression],
callable_node: Optional[Expression] = None,
arg_messages: Optional[MessageBuilder] = None,
callable_name: Optional[str] = None,
object_type: Optional[Type] = None) -> Tuple[Type, Type]:
object_type: Optional[Type] = None,
*,
arg_types_override: Optional[List[Type]] = None) -> Tuple[Type, Type]:
Collaborator:

Maybe I missed this, but it doesn't look like we ever actually pass in something other than None for arg_types_override?

This param also feels a bit redundant given we also have the type_overrides field -- it seems like they're both trying to do the same thing.

Member Author:

Yes, this is an artefact of a previous attempt to do this the "right way".

"""Type check a call.

Also infer type arguments if the callee is a generic function.
@@ -575,9 +578,11 @@ def check_call(self, callee: Type, args: List[Expression],
callee, context)
callee = self.infer_function_type_arguments(
callee, args, arg_kinds, formal_to_actual, context)

arg_types = self.infer_arg_types_in_context2(
callee, args, arg_kinds, formal_to_actual)
if arg_types_override is not None:
arg_types = arg_types_override.copy()
else:
arg_types = self.infer_arg_types_in_context2(
callee, args, arg_kinds, formal_to_actual)

self.check_argument_count(callee, arg_types, arg_kinds,
arg_names, formal_to_actual, context, self.msg)
@@ -1130,22 +1135,15 @@ def check_overload_call(self,
unioned_result = None # type: Optional[Tuple[Type, Type]]
unioned_errors = None # type: Optional[MessageBuilder]
union_success = False
if any(isinstance(arg, UnionType) and len(arg.relevant_items()) > 1 # "real" union
for arg in arg_types):
erased_targets = self.overload_erased_call_targets(plausible_targets, arg_types,
arg_kinds, arg_names, context)
unioned_callable = self.union_overload_matches(erased_targets)

if unioned_callable is not None:
unioned_errors = arg_messages.clean_copy()
unioned_result = self.check_call(unioned_callable, args, arg_kinds,
context, arg_names,
arg_messages=unioned_errors,
callable_name=callable_name,
object_type=object_type)
# Record if we succeeded. Next we need to see if maybe normal procedure
# gives a narrower type.
union_success = unioned_result is not None and not unioned_errors.is_errors()
if any(self.real_union(arg) for arg in arg_types):
unioned_errors = arg_messages.clean_copy()
unioned_result = self.union_overload_result(plausible_targets, args, arg_types,
arg_kinds, arg_names,
callable_name, object_type,
context, arg_messages=unioned_errors)
Collaborator:

Just to check: switching to doing union math on plausible_targets instead of erased_targets is an intentional change?

(The former variable is supposed to contain callables that just have the right number of args; the latter is supposed to contain (non-erased) callables that match after performing type erasure. The names admittedly aren't very clear, but I'm planning on submitting another PR soon that cleans that up)

Member Author:

I thought erased_targets were calculated using the original types (i.e. including full unions). I just wanted to throw away at once all the overloads that have no way to match, while keeping those that can match after a union destructuring. I think plausible_targets is what is needed, or am I wrong?

Collaborator:

I think erased_targets is computed in a way that matches a callable if at least one of the union elements matches, not if all of them do. That allowed it to act as an additional filter that pruned plausible_targets down even further.

That said, I think the way we compute the erased targets is currently a bit janky -- pretty much all of that code is from the old overloads implementation and is something I'm planning on refactoring/simplifying in a future PR. So, I wouldn't be too surprised if it didn't end up working here. (But if it did, it might be a way we could get a minor speedup in the case where the list of overloads is long.)

That said, I guess that's a bit of an edge case, so maybe just using plausible_targets is the right thing to do here, since it's more obviously correct.

# Record if we succeeded. Next we need to see if maybe normal procedure
# gives a narrower type.
union_success = unioned_result is not None and not unioned_errors.is_errors()
Collaborator:

Given the way you implemented union_overload_result, I think it's impossible for unioned_errors to ever actually contain an error. As a consequence, union_success will be true if and only if unioned_result is not None.

In that case, I think we can simplify some of the logic by getting rid of union_success and using just unioned_result. E.g. here, we replace this line with if unioned_errors.is_errors(): unioned_result = None just in case, and get rid of union_success entirely.


# Step 3: We try checking each branch one-by-one.
inferred_result = self.infer_overload_return_type(plausible_targets, args, arg_types,
Collaborator:

GitHub won't let me leave a comment down below, but after doing the simplification suggested up above, I think we can also simplify the following if statements to something roughly like:

# Success! Stop early by returning the best among normal and unioned.
if inferred_result is not None and unioned_result is not None:
    return inferred_result if is_subtype(inferred_result[0], unioned_result[0]) else unioned_result
elif inferred_result is not None:
    return inferred_result
elif unioned_result is not None:
    return unioned_result

@@ -1173,9 +1171,8 @@ def check_overload_call(self,
#
# Neither alternative matches, but we can guess the user probably wants the
# second one.
if erased_targets is None:
erased_targets = self.overload_erased_call_targets(plausible_targets, arg_types,
arg_kinds, arg_names, context)
erased_targets = self.overload_erased_call_targets(plausible_targets, arg_types,
arg_kinds, arg_names, context)

# Step 5: We try and infer a second-best alternative if possible. If not, fall back
# to using 'Any'.
Collaborator:

GitHub still won't let me leave a comment below, but I think after doing the above two simplifications, we can get rid of the next if statement entirely.

(It's a bit hard to tell, but based on new error messages in the changed tests, it seems like we're no longer entering into this case at all, so it's redundant either way.)

@@ -1350,91 +1347,62 @@ def overload_erased_call_targets(self,
matches.append(typ)
return matches

def union_overload_matches(self, callables: List[CallableType]) -> Optional[CallableType]:
"""Accepts a list of overload signatures and attempts to combine them together into a
new CallableType consisting of the union of all of the given arguments and return types.

Returns None if it is not possible to combine the different callables together in a
sound manner.

Assumes all of the given callables have argument counts compatible with the caller.
def union_overload_result(self,
plausible_targets: List[CallableType],
args: List[Expression],
arg_types: List[Type],
arg_kinds: List[int],
arg_names: Optional[Sequence[Optional[str]]],
Collaborator:

Off-topic: I wonder if it might be worth writing some sort of custom class/namedtuple that stores args, arg_types, and friends in some future refactoring PR. It's starting to become pretty awkward passing around all of these types.
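One hypothetical shape for such a container (all names here are illustrative, not from this PR or mypy's codebase):

```python
from typing import List, NamedTuple, Optional, Sequence

class ArgInfo(NamedTuple):
    """Bundles the per-call argument data that is currently passed around
    as four parallel parameters (args, arg_types, arg_kinds, arg_names)."""
    args: List[str]        # stand-in for List[Expression]
    arg_types: List[str]   # stand-in for List[Type]
    arg_kinds: List[int]
    arg_names: Optional[Sequence[Optional[str]]]

# A single value can now travel through the call-checking helpers:
info = ArgInfo(args=['x'], arg_types=['int'], arg_kinds=[0], arg_names=['x'])
```

A NamedTuple keeps the fields immutable and self-documenting while still unpacking into the old parallel lists where needed.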

Member Author:

Yes, we have many potential refactoring ideas on our internal roadmap, we can add this one to the list.

callable_name: Optional[str],
object_type: Optional[Type],
context: Context,
arg_messages: Optional[MessageBuilder] = None,
) -> Optional[Tuple[Type, Type]]:
Collaborator:

After thinking about this, I think this algorithm might actually be exponential, not quadratic. (E.g. assuming we have n union arguments, each with 2 options, we'd do T(n) = T(n - 1) + T(n - 1) work, which is 2^n).

In light of that, I wonder if it might maybe be worth adding in some sort of limit? E.g. if we recurse to a certain depth, we give up -- or maybe we add a check up above that skips union math entirely if the call has more than some set number of union arguments.

Member Author (@ilevkivskyi, Jun 20, 2018):

Yes, you are right, it can be much worse. Here is some more detailed analysis.

The relevant numbers that we have here are:

  • number of (potentially matching) overloads, p
  • number of unions in args, k
  • numbers of items in every union n_i, i = 1, ..., k

The worst-case scenario is when every check succeeds; in this case we have p * n_1 * ... * n_k inference attempts. Assuming an average of n items per union, this is roughly p * n ** k. This might look very bad, but I have some intuition why in practice average/typical performance may be much better. The key is that in practice it is hard to write an overload that will match so many combinations, plus there is a return None that will fully "unwind" the union call stack on the first failure. So for any realistic overload, the number of union item combinations that successfully match will be limited by the number of overloads p; for example, to successfully match two args with types Union[A, B] and Union[C, D], one will need four overloads. Thus we will typically make at most p ** 2 inference attempts in total.
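The worst-case count above can be sketched as a toy calculation (this helper is illustrative, not code from the PR):

```python
from functools import reduce
from operator import mul
from typing import List

def worst_case_attempts(p: int, union_sizes: List[int]) -> int:
    """Upper bound on inference attempts when every combination of union
    items matches: p overloads times the product n_1 * ... * n_k."""
    return p * reduce(mul, union_sizes, 1)

# 3 overloads, two union args with 2 and 3 items: 3 * 2 * 3 = 18 attempts.
print(worst_case_attempts(3, [2, 3]))  # → 18
```

With k unions of n items each this grows like p * n ** k, matching the exponential behaviour discussed above.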

Actually, now thinking more about this, we can also bail out early on success: after splitting the first union, before we split the next one, we can check whether it already matches. For example:

@overload
def f(x: int, y: object, z: object) -> int: ...
@overload
def f(x: str, y: object, z: object) -> str: ...

x: Union[int, str]
f(x, x, x)

There is no need to split the second and third arguments. This will need just a few extra lines, but can be a decent performance gain.
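The early-bailout idea can be sketched as a toy recursion. The matcher and the type model below are stand-ins, not mypy internals: a "union" is modeled as a tuple of options, and `match` plays the role of overload resolution.

```python
from typing import Callable, List, Optional, Tuple, Union

# A "union" is modeled as a tuple of options; anything else is a plain type.
ToyType = Union[str, Tuple[str, ...]]

def resolve(arg_types: List[ToyType],
            match: Callable[[List[ToyType]], Optional[str]]) -> Optional[str]:
    """Try a direct match first; only if that fails, split the first union."""
    direct = match(arg_types)
    if direct is not None:
        return direct  # early bailout: no need to split the remaining unions
    for i, typ in enumerate(arg_types):
        if isinstance(typ, tuple):  # found the first "union" argument
            results = []
            for item in typ:
                sub = resolve(arg_types[:i] + [item] + arg_types[i + 1:], match)
                if sub is None:
                    return None  # one item failed: unwind the whole attempt
                results.append(sub)
            # Deduplicate while keeping order, like a simplified union of returns.
            return ' | '.join(dict.fromkeys(results))
    return None  # no unions left and the direct match failed
```

In the `f(x, x, x)` example above, the direct check succeeds right after the first split, so the second and third unions are never destructured.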

An additional comment in defence of this algorithm is its simplicity of implementation (especially as compared to the breadth of covered use cases). Anyway, if you still think it is dangerous, it is easy to put a limit on the call stack depth.

EDIT: typo: overload -> argument

Collaborator:

The key is that in practice it is hard to write an overload that will match so many combinations

I'm not sure if it'll actually be that hard in practice, especially if the overload is intended to match a wide variety of different calls using a mixture of typevars, Any types, object types, *args, **kwargs...

For example, here's a program which calls a simplified version of map. It's a bit contrived, but I don't think it's too far off from the sorts of programs some users might write in practice:

from typing import overload, Any, TypeVar, Iterable, List, Dict, Callable, Union

S = TypeVar('S')

@overload
def simple_map() -> None: ...
@overload
def simple_map(func: Callable[..., S], *iterables: Iterable[Any]) -> S: ...
def simple_map(*args): pass

def format_row(*entries: object) -> str: pass

class DateTime: pass
JsonBlob = Dict[str, Any]
Column = Union[List[str], List[int], List[bool], List[float], List[DateTime], List[JsonBlob]]

def print_custom_table() -> None:
    a: Column
    b: Column
    c: Column
    d: Column
    e: Column
    f: Column
    g: Column

    for row in simple_map(format_row, a, b, c, d, e, f):
        print(row)

If I try running simple_map with 6 columns (a through f), it takes the new algorithm roughly a minute and a half on my computer to type check this program. If I tack on the 7th column (g), it takes mypy significantly longer. (I don't have an exact time unfortunately: I got bored of waiting after 10-15 minutes and killed the thing.)

I think type-checking after each split would probably only aggravate the problem, unfortunately (though I do think it could help in other cases).

An additional comment in the defence of this algorithm is its simplicity of implementation (especially as compared to the breadth of covered use cases). Anyway, if you still think it is dangerous it is easy to put a limit on the call stack depth.

I also like the simplicity of this algorithm. (It's also not obvious to me how we could robustly solve this problem in any other way, in any case.)

But yeah, I still think we should add a limit of some sort. Even if the pathological cases end up being rare, I think it'd be a poor user experience if mypy starts unexpectedly stalling in those edge cases.

Member Author:

I think type-checking after each split would probably only aggravate the problem, unfortunately (though I do think it could help in other cases).

It actually helped. I added it (simplistic WIP style, needs polishing) and your example passed in a second (even with eight(!) union args); a few other complex examples I tried also passed instantly.

But yeah, I still think we should add a limit of some sort.

I don't want to set a calculated limit; I would add a nesting/split-level limit, probably 5, and then raise TooManyUnionsError, so that we can better diagnose the error and add a note like:

main.py: error: No overload variant blah-blah-blah..
main.py: note: Not all union combinations were tried because there are too many unions

There are some ideas to avoid hitting the limit, like adding some back-tracking or multi-seeding (e.g. starting to split the unions from the last arg as well). But I will not do this, because it would only add code complexity while gaining only some super-rare corner cases.
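A minimal sketch of the proposed depth limit. The exception name follows the comment above; the split logic and toy type model (a tuple stands in for a union) are simplified stand-ins, not the PR's code:

```python
from typing import List, Tuple, Union

MAX_UNIONS = 5  # proposed nesting/split-level limit

class TooManyUnionsError(Exception):
    """Raised when union splitting recurses past MAX_UNIONS levels."""

ToyType = Union[str, Tuple[str, ...]]  # a tuple stands in for a union type

def split_unions(arg_types: List[ToyType], level: int = 0) -> List[List[str]]:
    """Enumerate all union-item combinations, aborting past the depth limit."""
    for i, typ in enumerate(arg_types):
        if isinstance(typ, tuple):
            if level >= MAX_UNIONS:
                raise TooManyUnionsError
            out = []
            for item in typ:
                out.extend(split_unions(arg_types[:i] + [item] + arg_types[i + 1:],
                                        level + 1))
            return out
    return [list(arg_types)]  # no unions left: one concrete combination
```

The caller would catch TooManyUnionsError and emit the "Not all union combinations were tried" note instead of stalling.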

"""Accepts a list of overload signatures and attempts to match calls by destructuring
the first union. Returns None if there is no match.
"""
if len(callables) == 0:
return None
elif len(callables) == 1:
return callables[0]

# Note: we are assuming here that if a user uses some TypeVar 'T' in
# two different overloads, they meant for that TypeVar to mean the
# same thing.
#
# This function will make sure that all instances of that TypeVar 'T'
# refer to the same underlying TypeVarType and TypeVarDef objects to
# simplify the union-ing logic below.
#
# (If the user did *not* mean for 'T' to be consistently bound to the
# same type in their overloads, well, their code is probably too
# confusing and ought to be re-written anyways.)
callables, variables = merge_typevars_in_callables_by_name(callables)

new_args = [[] for _ in range(len(callables[0].arg_types))] # type: List[List[Type]]
new_kinds = list(callables[0].arg_kinds)
new_returns = [] # type: List[Type]

for target in callables:
# We conservatively end if the overloads do not have the exact same signature.
# The only exception is if one arg is optional and the other is positional: in that
# case, we continue unioning (and expect a positional arg).
# TODO: Enhance the union overload logic to handle a wider variety of signatures.
if len(new_kinds) != len(target.arg_kinds):
if not any(self.real_union(typ) for typ in arg_types):
# No unions in args, just fall back to normal inference
for arg, typ in zip(args, arg_types):
self.type_overrides[arg] = typ
Member Author:

This should be made a context manager. Also, we need to be sure that the last type for an arg expression (the one that stays in the type map after we leave the call expression) is a union.

res = self.infer_overload_return_type(plausible_targets, args, arg_types,
arg_kinds, arg_names, callable_name,
object_type, context, arg_messages)
for arg, typ in zip(args, arg_types):
del self.type_overrides[arg]
return res
# Try direct match before splitting
direct = self.infer_overload_return_type(plausible_targets, args, arg_types,
arg_kinds, arg_names, callable_name,
object_type, context, arg_messages)
if direct is not None and not isinstance(direct[0], UnionType):
# We only return non-unions early, to avoid a greedy match.
return direct
first_union = next(typ for typ in arg_types if self.real_union(typ))
idx = arg_types.index(first_union)
Collaborator:

This is a bit of a micro-optimization, but I think we could combine the above two lines and the if-check into one loop. E.g. try finding the index first, then case on that to see if we should enter the base case or not.
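The suggested single pass could look roughly like this (an illustrative sketch, not the PR's code):

```python
from typing import Callable, List, Optional

def first_union_index(arg_types: List[object],
                      is_real_union: Callable[[object], bool]) -> Optional[int]:
    """Find the index of the first 'real' union in one pass, instead of a
    next(...) scan followed by a second arg_types.index(...) scan."""
    for i, typ in enumerate(arg_types):
        if is_real_union(typ):
            return i
    return None  # no unions: the caller enters the base case
```

Returning None for "no union found" also replaces the separate `any(...)` pre-check, so the list is walked once instead of three times.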

assert isinstance(first_union, UnionType)
returns = []
inferred_types = []
for item in first_union.relevant_items():
new_arg_types = arg_types.copy()
new_arg_types[idx] = item
sub_result = self.union_overload_result(plausible_targets, args, new_arg_types,
arg_kinds, arg_names, callable_name,
object_type, context, arg_messages)
if sub_result is not None:
ret, inferred = sub_result
returns.append(ret)
inferred_types.append(inferred)
else:
return None
for i, (new_kind, target_kind) in enumerate(zip(new_kinds, target.arg_kinds)):
if new_kind == target_kind:
continue
elif new_kind in (ARG_POS, ARG_OPT) and target_kind in (ARG_POS, ARG_OPT):
new_kinds[i] = ARG_POS
else:
return None

for i, arg in enumerate(target.arg_types):
new_args[i].append(arg)
new_returns.append(target.ret_type)

union_count = 0
final_args = []
for args_list in new_args:
new_type = UnionType.make_simplified_union(args_list)
union_count += 1 if isinstance(new_type, UnionType) else 0
final_args.append(new_type)

# TODO: Modify this check to be less conservative.
#
# Currently, we permit only one union in the arguments because if we allow
# multiple, we can't always guarantee the synthesized callable will be correct.
#
# For example, suppose we had the following two overloads:
#
# @overload
# def f(x: A, y: B) -> None: ...
# @overload
# def f(x: B, y: A) -> None: ...
#
# If we continued and synthesize "def f(x: Union[A,B], y: Union[A,B]) -> None: ...",
# then we'd incorrectly accept calls like "f(A(), A())" when they really ought to
# be rejected.
#
# However, that means we'll also give up if the original overloads contained
# any unions. This is likely unnecessary -- we only really need to give up if
# there are more then one *synthesized* union arguments.
if union_count >= 2:
return None
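The unsoundness described in the comment above can be made concrete with a toy acceptance check (A, B and the signature lists are stand-ins for the real types):

```python
class A: pass
class B: pass

# The two original overload signatures: (A, B) -> None and (B, A) -> None.
overloads = [(A, B), (B, A)]

def accepted_by_overloads(x: object, y: object) -> bool:
    """A call is valid if it matches at least one original signature."""
    return any(isinstance(x, s0) and isinstance(y, s1) for s0, s1 in overloads)

def accepted_by_synthesized(x: object, y: object) -> bool:
    """The naively synthesized signature: (Union[A, B], Union[A, B]) -> None."""
    return isinstance(x, (A, B)) and isinstance(y, (A, B))

# f(A(), A()) matches neither overload, but the synthesized union signature
# would incorrectly accept it:
print(accepted_by_overloads(A(), A()))    # → False
print(accepted_by_synthesized(A(), A()))  # → True
```

This is exactly why the code gives up when more than one argument position would need a synthesized union.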
if returns:
return (UnionType.make_simplified_union(returns, context.line, context.column),
UnionType.make_simplified_union(inferred_types, context.line, context.column))
Collaborator:

Another micro-optimization, but I wonder if it would be more efficient to return a list or set of the returns and inferred_types and run UnionType.make_simplified_union just once at the top level. (My main concern is that make_simplified_union repeatedly calls is_proper_subtype, and I'm not sure whether that call is supposed to be efficient.)

Up to you if you decide to try this or not though -- this sounds like a minor hassle/might not end up being worth it.

Member Author:

I am not sure we can use sets (although I really want to); the problem is that we need a stable/predictable order of union items to avoid flaky tests.

Member Author:

OK, I think I figured out how to avoid multiple intermediate unions and still have a stable order: we can return a set of pairs (index of matched overload, return type); at the very end we construct a union once, but first sort the items by index.

Another important note: I think returning a union of inferred callable types may be wrong here; we should actually return a single unioned callable (or fall back to Callable[..., Union[<ret_types>]] if the matched variants are too different to combine). The motivation is that (Union[int -> None, str -> None])(Union[int, str]) is not a valid call, and maybe some code (even future code) relies on the fact that the inferred callable type and the initial arg types are consistent.
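A sketch of the stable-order idea: collect (overload index, return type) pairs in a set, then sort by index before building the union once (toy string types stand in for mypy's Type objects):

```python
from typing import List, Set, Tuple

def ordered_returns(pairs: Set[Tuple[int, str]]) -> List[str]:
    """Deduplicate via the set, then restore a deterministic order using the
    index of the matched overload, so union items are never flaky."""
    ordered = [ret for _, ret in sorted(pairs)]
    # Drop duplicates while keeping first occurrence, like a simplified union:
    return list(dict.fromkeys(ordered))

print(ordered_returns({(2, 'str'), (0, 'int'), (1, 'str')}))  # → ['int', 'str']
```

The real code would pass this ordered list to UnionType.make_simplified_union exactly once at the top level.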

return None

return callables[0].copy_modified(
arg_types=final_args,
arg_kinds=new_kinds,
ret_type=UnionType.make_simplified_union(new_returns),
variables=variables,
implicit=True)
def real_union(self, typ: Type) -> bool:
return isinstance(typ, UnionType) and len(typ.relevant_items()) > 1

def erased_signature_similarity(self, arg_types: List[Type], arg_kinds: List[int],
arg_names: Optional[Sequence[Optional[str]]],
@@ -2666,6 +2634,8 @@ def accept(self,
is True and this expression is a call, allow it to return None. This
applies only to this expression and not any subexpressions.
"""
if node in self.type_overrides:
return self.type_overrides[node]
self.type_context.append(type_context)
try:
if allow_none_return and isinstance(node, CallExpr):