
Alternative algorithm for union math #5255

Merged: 13 commits, Jul 3, 2018

Conversation

@ilevkivskyi (Member) commented Jun 20, 2018

Fixes #5243
Fixes #5249

Some comments:

  • I went ahead with a slow but very simple recursive algorithm that treats all the various complex cases correctly. On one hand it can be quadratic, but on the other hand the complexity only gets bad if a user abuses lots of unions
  • This is WIP because I use a hack necessitated by the fact that currently most function inference functions pass argument expressions instead of types
  • It may look like there are many changes in tests, but actually there are not; the differences are because:
    • Error messages now show the first potentially matching overload (which is kind of OK, I think)
    • The order of items in many unions is reversed; apparently union __repr__ is unstable.
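In miniature, the recursive algorithm can be sketched like this (a toy sketch, not the real mypy code: unions are modeled as tuples of strings, and `check` stands in for normal, union-free overload inference):

```python
from typing import Callable, List, Optional, Tuple, Union

# Toy model: a plain type is a str, a union is a tuple of strs.
ToyType = Union[str, Tuple[str, ...]]
Check = Callable[[List[str]], Optional[str]]

def union_math(arg_types: List[ToyType], check: Check) -> Optional[List[str]]:
    """Destructure the first union argument and recurse on each of its items."""
    idx = next((i for i, t in enumerate(arg_types) if isinstance(t, tuple)), None)
    if idx is None:
        # Base case: no unions left, run normal overload inference once.
        res = check([t for t in arg_types if isinstance(t, str)])
        return [res] if res is not None else None
    results: List[str] = []
    for item in arg_types[idx]:
        sub = union_math(arg_types[:idx] + [item] + arg_types[idx + 1:], check)
        if sub is None:
            return None  # a single failing item fails the whole union call
        results.extend(sub)
    return results
```

The final return type would then be the union of the results; the recursion-depth cap (the MAX_UNIONS / TooManyUnions limit that comes up later in the review) is omitted here.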

@ilevkivskyi ilevkivskyi requested a review from Michael0x2a June 20, 2018 17:44
@Michael0x2a (Collaborator) left a comment

Overall, looks good to me!

Most of my feedback is related to streamlining the logic -- the only substantive comment I left concerns asymptotic efficiency: I think this new algorithm is actually exponential, not quadratic, so it might be worth adding some sort of limiter just in case.

object_type: Optional[Type] = None) -> Tuple[Type, Type]:
object_type: Optional[Type] = None,
*,
arg_types_override: Optional[List[Type]] = None) -> Tuple[Type, Type]:
Collaborator:

Maybe I missed this, but it doesn't look like we ever actually pass in something other than None for arg_types_override?

This param also feels a bit redundant given we also have the type_overrides field -- it seems like they're both trying to do the same thing.

Member Author:

Yes, this is an artefact of a previous attempt to do this the "right way".

unioned_result = self.union_overload_result(plausible_targets, args, arg_types,
                                            arg_kinds, arg_names,
                                            callable_name, object_type,
                                            context, arg_messages=unioned_errors)
Collaborator:

Just to check: switching to doing union math on plausible_targets instead of erased_targets is an intentional change?

(The former variable is supposed to contain callables that just have the right number of args; the latter is supposed to contain (non-erased) callables that match after performing type erasure. The names admittedly aren't very clear, but I'm planning on submitting another PR soon that cleans that up)

Member Author:

I thought erased_targets are calculated using original types (i.e. including full unions). I just wanted to throw away, once and for all, the overloads that have no way to match, while keeping those that can match after a union destructuring. I think plausible_targets is what is needed, or am I wrong?

Collaborator:

I think erased_targets is computed in a way that it matches a callable if at least one of the union elements matches, not only if all of them match. That allowed it to act as an additional filter that pruned plausible_targets down even further.

That said, I think the way we compute the erased targets is currently a bit janky -- pretty much all of that code is from the old overloads implementation and is something I'm planning on refactoring/simplifying in a future PR. So, I wouldn't be too surprised if it didn't end up working here. (But if it did, it might be a way we could get a minor speedup in the case where the list of overloads is long.)

That said, I guess that's a bit of an edge case, so maybe just using plausible_targets is the right thing to do here, since it's more obviously correct.

args: List[Expression],
arg_types: List[Type],
arg_kinds: List[int],
arg_names: Optional[Sequence[Optional[str]]],
Collaborator:

Off-topic: I wonder if it might be worth writing some sort of custom class/namedtuple that stores args, arg_types, and friends in some future refactoring PR. It's starting to become pretty awkward passing around all of these types.

Member Author:

Yes, we have many potential refactoring ideas on our internal roadmap, we can add this one to the list.

object_type: Optional[Type],
context: Context,
arg_messages: Optional[MessageBuilder] = None,
) -> Optional[Tuple[Type, Type]]:
Collaborator:

After thinking about this, I think this algorithm might actually be exponential, not quadratic. (E.g. assuming we have n union arguments, each with 2 options, we'd do T(n) = T(n - 1) + T(n - 1) work, which is 2^n.)

In light of that, I wonder if it might be worth adding in some sort of limit? E.g. if we recurse to a certain depth, we give up -- or maybe we add a check up above that skips union math entirely if the call has more than some set number of union arguments.

Member Author (@ilevkivskyi, Jun 20, 2018):

Yes, you are right, it can be much worse. Here is some more detailed analysis.

The relevant numbers that we have here are:

  • number of (potentially matching) overloads, p
  • number of unions in args, k
  • numbers of items in every union n_i, i = 1, ..., k

The worst-case scenario is when every check succeeds; in this case we make p * n_1 * ... * n_k inference attempts, and assuming an average of n items per union this is roughly p * n ** k. This might look very bad, but I have some intuition why in practice average/typical performance may be much better. The key is that in practice it is hard to write an overload that will match so many combinations; plus there is a return None that fully "unwinds" the union call stack on the first failure. So for any realistic overload, the number of union item combinations that successfully match will be limited by the number of overloads p; for example, to successfully match two args with types Union[A, B] and Union[C, D] one would need four overloads. Thus we will typically make a total of at most p ** 2 inference attempts.
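As a quick sanity check of this counting, the worst-case bound p * n_1 * ... * n_k can be computed directly (a trivial sketch, not mypy code):

```python
from functools import reduce
from typing import List

def worst_case_attempts(p: int, union_sizes: List[int]) -> int:
    # p overloads times one inference attempt per combination of union items:
    # p * n_1 * ... * n_k.
    return p * reduce(lambda acc, n: acc * n, union_sizes, 1)
```

E.g. 2 overloads with three 2-item union arguments already give 2 * 2**3 = 16 attempts.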

Actually, now thinking more about this, we can also bail out early on success, i.e. after splitting the first union, before we split the next one, we can check whether it already matches; for example:

@overload
def f(x: int, y: object, z: object) -> int: ...
@overload
def f(x: str, y: object, z: object) -> str: ...

x: Union[int, str]
f(x, x, x)

There is no need to split the second and third arguments. This needs just a few extra lines, but can be a decent performance gain.

An additional comment in the defence of this algorithm is its simplicity of implementation (especially as compared to the breadth of covered use cases). Anyway, if you still think it is dangerous it is easy to put a limit on the call stack depth.

EDIT: typo: overload -> argument

Collaborator:

The key is that in practice it is hard to write an overload that will match so many combinations

I'm not sure if it'll actually be that hard in practice, especially if the overload is intended to match a wide variety of different calls using a mixture of typevars, Any types, object types, *args, **kwargs...

For example, here's a program which calls a simplified version of map. It's a bit contrived, but I don't think it's too far off from the sorts of programs some users might write in practice:

from typing import overload, Any, TypeVar, Iterable, List, Dict, Callable, Union

S = TypeVar('S')

@overload
def simple_map() -> None: ...
@overload
def simple_map(func: Callable[..., S], *iterables: Iterable[Any]) -> S: ...
def simple_map(*args): pass

def format_row(*entries: object) -> str: pass

class DateTime: pass
JsonBlob = Dict[str, Any]
Column = Union[List[str], List[int], List[bool], List[float], List[DateTime], List[JsonBlob]]

def print_custom_table() -> None:
    a: Column
    b: Column
    c: Column
    d: Column
    e: Column
    f: Column
    g: Column

    for row in simple_map(format_row, a, b, c, d, e, f):
        print(row)

If I try running simple_map with 6 columns (a through f), it takes the new algorithm roughly a minute and a half on my computer to type check this program. If I tack on the 7th column (g), it takes mypy significantly longer. (I don't have an exact time unfortunately: I got bored of waiting after 10-15 minutes and killed the thing.)

I think type-checking after each split would probably only aggravate the problem, unfortunately (though I do think it could help in other cases).

An additional comment in the defence of this algorithm is its simplicity of implementation (especially as compared to the breadth of covered use cases). Anyway, if you still think it is dangerous it is easy to put a limit on the call stack depth.

I also like the simplicity of this algorithm. (It's also not obvious to me how we could robustly solve this problem in any other way, in any case.)

But yeah, I still think we should add a limit of some sort. Even if the pathological cases end up being rare, I think it'd be a poor user experience if mypy starts unexpectedly stalling in those edge cases.

Member Author:

I think type-checking after each split would probably only aggravate the problem, unfortunately (though I do think it could help in other cases).

It actually helped. I added it (simplistic WIP style, needs polish) and your example passed in a second (even with eight(!) union args); a few other complex examples I tried also passed instantly.

But yeah, I still think we should add a limit of some sort.

I don't want to set a calculated limit; I would add a nesting/split level limit, probably 5, and then raise TooManyUnionsError, so that we can better diagnose the error and add a note like:

main.py: error: No overload variant blah-blah-blah..
main.py: note: Not all union combinations were tried because there are too many unions

There are some ideas to avoid hitting the limit, like adding some back-tracking or multi-seeding (e.g. also starting to split the unions from the last arg). But I will not do this, because it would only add code complexity while gaining some super-rare corner cases.

del self.type_overrides[arg]
return res
first_union = next(typ for typ in arg_types if self.real_union(typ))
idx = arg_types.index(first_union)
Collaborator:

This is a bit of a micro-optimization, but I think we could combine the above two lines and the if-check into one loop. E.g. try finding the index first, then case on that to see if we should enter the base case or not.
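The combined loop might look roughly like this (a sketch; real_union and the toy type model are stand-ins for the checker internals):

```python
from typing import List, Optional, Tuple, Union

ToyType = Union[str, Tuple[str, ...]]

def real_union(typ: ToyType) -> bool:
    """Stand-in predicate: in this toy model, unions are tuples."""
    return isinstance(typ, tuple)

def first_union_index(arg_types: List[ToyType]) -> Optional[int]:
    # One pass instead of next(...) followed by arg_types.index(...).
    for idx, typ in enumerate(arg_types):
        if real_union(typ):
            return idx
    return None  # caller takes the base case when no union is present
```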

@@ -2519,10 +2518,10 @@ def wrapper() -> None:

obj2: Union[W1[A], W2[B]]

foo(obj2) # E: Cannot infer type argument 1 of "foo"
foo(obj2) # OK
bar(obj2) # E: Cannot infer type argument 1 of "bar"
Collaborator:

Up to you if you want to do this or not, but it might be worth removing the bar function in this test. It exists mostly so we could compare the output of the overload with a manually-unioned callable -- but the comparison is no longer relevant now that we're not actually producing a unioned callable.

I think we have several similar manually-unioned functions in the surrounding test cases here.

Member Author:

Actually, what do you think about the new error messages? They are a bit shorter and better match the logic of "match the first overload", but there is less "symmetry" between overloads and normal functions that have unions in the signature.

Collaborator:

I think I'm ok with the new error messages. It does lose symmetry, but I don't know if the user was expecting that symmetry to begin with. The new messages also require less explanation: if we report back that a union is expected when the overload definitions don't contain any unions at all, I could see that confusing a user who wasn't aware of the union math feature.

@@ -2646,12 +2645,13 @@ def t_is_same_bound(arg1: T1, arg2: S) -> Tuple[T1, S]:
# The arguments in the tuple are swapped
x3: Union[List[S], List[Tuple[S, T1]]]
y3: S
Dummy[T1]().foo(x3, y3) # E: Argument 1 to "foo" of "Dummy" has incompatible type "Union[List[S], List[Tuple[S, T1]]]"; expected "Union[List[Tuple[T1, S]], List[S]]"
Dummy[T1]().foo(x3, y3) # E: Cannot infer type argument 1 of "foo" of "Dummy" \
# E: Argument 1 to "foo" of "Dummy" has incompatible type "Union[List[S], List[Tuple[S, T1]]]"; expected "List[Tuple[T1, Any]]"
Collaborator:

Wait, what?! We can embed multiline error messages?

That's amazing!

context, arg_messages=unioned_errors)
# Record if we succeeded. Next we need to see if maybe normal procedure
# gives a narrower type.
union_success = unioned_result is not None and not unioned_errors.is_errors()
Collaborator:

Given the way you implemented union_overload_result, I think it's impossible for unioned_errors to actually ever contain an error. As a consequence, union_success will be true if and only if unioned_result is not None is also true.

In that case, I think we can simplify some of the logic by getting rid of union_success and using just unioned_result. E.g. here, we replace this line with if unioned_errors.is_errors(): unioned_result = None just in case, and get rid of union_success entirely.

context, arg_messages=unioned_errors)
# Record if we succeeded. Next we need to see if maybe normal procedure
# gives a narrower type.
union_success = unioned_result is not None and not unioned_errors.is_errors()

# Step 3: We try checking each branch one-by-one.
inferred_result = self.infer_overload_return_type(plausible_targets, args, arg_types,
Collaborator:

GitHub won't let me leave a comment further down, but after doing the simplification suggested above, I think we can also simplify the following if statements to something roughly like:

# Success! Stop early by returning the best among normal and unioned.
if inferred_result is not None and unioned_result is not None:
    return inferred_result if is_subtype(inferred_result[0], unioned_result[0]) else unioned_result
elif inferred_result is not None:
    return inferred_result
elif unioned_result is not None:
    return unioned_result

erased_targets = self.overload_erased_call_targets(plausible_targets, arg_types,
arg_kinds, arg_names, context)

# Step 5: We try and infer a second-best alternative if possible. If not, fall back
# to using 'Any'.
Collaborator:

GitHub still won't let me leave a comment below, but I think after doing the above two simplifications, we can get rid of the next if statement entirely.

(It's a bit hard to tell, but based on new error messages in the changed tests, it seems like we're no longer entering into this case at all, so it's redundant either way.)

x: D
y: Union[D, Any]
# TODO: update after we decide on https://github.com/python/mypy/pull/5254
reveal_type(x.f(y)) # E: Revealed type is 'Any'
Member Author:

Just a comment to myself: we need to fix this before merging this PR. Jukka prefers Union[D, Any] here, which is actually the result of a unioned match, but the direct match is Any.

@ilevkivskyi (Member Author):

@Michael0x2a thanks for the detailed review! I will probably implement the changes tomorrow.

@ilevkivskyi (Member Author):

@Michael0x2a Btw, what do you think about the type_overrides hack? I am actually a bit worried about its performance implications; the cost is probably small, 1-2%, but it applies to every expression visit. We have something very similar for multiassign from union, but somehow that works only by temporarily modifying type_map (btw, I should probably also make type_overrides a context manager, like in the case of multiassign from union).

@Michael0x2a (Collaborator):

Regarding the type_overrides hack: I agree that it's suboptimal and that something using the arg_types_override approach would be nice. If there's no other way to make this approach work, though, I guess I wouldn't be too upset -- it's at least easy to understand.

if not any(self.real_union(typ) for typ in arg_types):
# No unions in args, just fall back to normal inference
for arg, typ in zip(args, arg_types):
self.type_overrides[arg] = typ
Member Author:

This should be made a context manager; also, we need to be sure that the last type for an arg expression (the one that stays in the type map after we leave the call expression) is a union.
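A context-manager version could look roughly like this (a sketch; it assumes type_overrides is a dict keyed by argument expressions, and all names here are illustrative, not the actual mypy API):

```python
from contextlib import contextmanager
from typing import Dict, Iterator, List, Tuple

class Checker:
    def __init__(self) -> None:
        # Maps argument expressions to the types forced during union math.
        self.type_overrides: Dict[str, str] = {}

    @contextmanager
    def overridden_arg_types(self, pairs: List[Tuple[str, str]]) -> Iterator[None]:
        """Install per-expression overrides, guaranteeing cleanup on any exit."""
        for arg, typ in pairs:
            self.type_overrides[arg] = typ
        try:
            yield
        finally:
            for arg, _ in pairs:
                del self.type_overrides[arg]
```

The try/finally guarantees the overrides are removed even if inference raises, which is the main point of making it a context manager.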

@ilevkivskyi ilevkivskyi changed the title [WIP] Alternative algorithm for union math Alternative algorithm for union math Jul 1, 2018
@ilevkivskyi (Member Author):

@Michael0x2a This should be ready for review now.

@Michael0x2a (Collaborator) left a comment

Looks good, thanks! I did have one or two questions about union_overload_matches and several typo nits, but no real remaining concerns.

Feel free to merge whenever!

if level >= MAX_UNIONS:
    raise TooManyUnions

# Step 2: Find position of the first union in arguments. Return the normal infered
Collaborator:

Really minor nit: "infered" -> "inferred"

sound manner.

Assumes all of the given callables have argument counts compatible with the caller.
If there is at least one non-callabe type, return Any (this can happen if there is
Collaborator:

"callabe" -> "callable"

unioned_result = (UnionType.make_simplified_union(list(returns),
                                                  context.line,
                                                  context.column),
                  self.union_overload_matches(inferred_types))
Collaborator:

Out of curiosity, is there any reason why we switched from using UnionType.make_simplified_union to self.union_overload_matches for the inferred types?

(fwiw, I don't mind whether we keep it or not -- if we remove it, it simplifies the code; if we keep it, it might help us perform better subtype checks with overloads in the future.)

Member Author:

The reason is that if I wrote

f: Union[Callable[[int], int], Callable[[str], str]]
x: Union[int, str]
f(x)

then mypy would (correctly) flag the call as an error. Therefore I don't want mypy to infer a callable type that is formally inconsistent with the argument type(s). I will add a comment about this.

@@ -58,6 +62,16 @@
ArgChecker = Callable[[Type, Type, int, Type, int, int, CallableType, Context, MessageBuilder],
None]

# Maximum nesting level for math union in overloads, setting this to large values
# may cause performance issues.
Collaborator:

Maybe we should add a blurb + a link telling people to read this PR if they want more details? Up to you.

object_type, context, arg_messages,
level + 1)
if sub_result is not None:
    res_items.append(sub_result)
Collaborator:

Another minor nit: I think it'd look cleaner if we do res_items.extend(sub_result) here. We can then simplify the nested loop down below to just a single one. (And we'd need to change the comment to something like "Step 5: If splitting succeeded, filter out duplicate items".)
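In miniature, the difference (with illustrative lists rather than the real sub-results):

```python
sub_results = [["int", "str"], ["str"]]

# With append, res_items nests, and flattening needs a nested loop later:
nested = []
for sub in sub_results:
    nested.append(sub)

# With extend, res_items stays flat:
flat = []
for sub in sub_results:
    flat.extend(sub)

# The later dedup step then becomes a single pass over the flat list.
unique = []
for item in flat:
    if item not in unique:
        unique.append(item)
```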

# Some item doesn't match, return soon.
return None

# Step 5: If spliting succeeded, then flatten union results into a single
Collaborator:

"spliting" -> "splitting"

# We fall back to Callable[..., Union[<returns>]] if the overloads do not have
# the exact same signature. The only exception is if one arg is optional and
# the other is positional: in that case, we continue unioning (and expect a
# positional arg).
Collaborator:

What about falling back to Union[callable1, callable2, ..., callableN] instead? It would let us introduce fewer new Anys. Or does that end up breaking too many things?

(I think this is basically the same question as the previous one -- maybe add a comment here about this?)

Member Author:

Or does that end up breaking too many things?

It actually breaks nothing; I just don't want to introduce internal inconsistencies (as I explained above). Also, the Anys here don't currently impact code precision; we mostly care about the return type, not the inferred callable type.

@ilevkivskyi ilevkivskyi merged commit 0ca6bf9 into python:master Jul 3, 2018
@ilevkivskyi ilevkivskyi deleted the new-union-math branch July 3, 2018 11:48