Alternative algorithm for union math #5255
Overall, looks good to me!
Most of my feedback is related to streamlining the logic -- the only substantive comment I left is with the asymptotic efficiency: I think this new algorithm is actually exponential, not quadratic, so I think it might be worth considering adding in some sort of limiter just in case.
mypy/checkexpr.py
Outdated
object_type: Optional[Type] = None) -> Tuple[Type, Type]:
object_type: Optional[Type] = None,
*,
arg_types_override: Optional[List[Type]] = None) -> Tuple[Type, Type]:
Maybe I missed this, but it doesn't look like we ever actually pass in something other than None for arg_types_override?
This param also feels a bit redundant given we also have the type_overrides field -- it seems like they're both trying to do the same thing.
Yes, this is an artefact of a previous attempt to do this the "right way".
mypy/checkexpr.py
Outdated
unioned_result = self.union_overload_result(plausible_targets, args, arg_types,
arg_kinds, arg_names,
callable_name, object_type,
context, arg_messages=unioned_errors)
Just to check: switching to doing union math on plausible_targets instead of erased_targets is an intentional change?
(The former variable is supposed to contain callables that just have the right number of args; the latter is supposed to contain (non-erased) callables that match after performing type erasure. The names admittedly aren't very clear, but I'm planning on submitting another PR soon that cleans that up)
I thought erased_targets are calculated using the original types (i.e. including full unions). I just wanted to throw away, once, all the overloads that have no way to match, while keeping those that can match after a union destructuring. I think plausible_targets is what is needed, or am I wrong?
I think erased_targets is computed in a way that it matches a callable if at least one of the union elements matches, not only if all of them match. That allowed it to act as an additional filter that pruned plausible_targets down even further.
That said, I think the way we compute the erased targets is currently a bit janky -- pretty much all of that code is from the old overloads implementation and is something I'm planning on refactoring/simplifying in a future PR. So, I wouldn't be too surprised if it didn't end up working here. (But if it did, it might be a way we could get a minor speedup in the case where the list of overloads is long.)
That said, I guess that's a bit of an edge case, so maybe just using plausible_targets is the right thing to do here, since it's more obviously correct.
args: List[Expression],
arg_types: List[Type],
arg_kinds: List[int],
arg_names: Optional[Sequence[Optional[str]]],
Off-topic: I wonder if it might be worth writing some sort of custom class/namedtuple that stores args, arg_types, and friends in some future refactoring PR. It's starting to become pretty awkward passing around all of these values.
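One possible shape for such a bundle, as a rough sketch only: the class name `CallArgs` and the helper `with_arg_type` are hypothetical and do not exist in mypy, and plain strings stand in for mypy's `Expression` and `Type` objects.

```python
from typing import List, NamedTuple, Optional, Sequence


class CallArgs(NamedTuple):
    """Hypothetical bundle for the values that are always passed around together."""
    args: List[str]                               # stand-in for List[Expression]
    arg_types: List[str]                          # stand-in for List[Type]
    arg_kinds: List[int]
    arg_names: Optional[Sequence[Optional[str]]]

    def with_arg_type(self, idx: int, typ: str) -> "CallArgs":
        """Return a copy in which the type of argument `idx` is replaced."""
        new_types = list(self.arg_types)
        new_types[idx] = typ
        return self._replace(arg_types=new_types)


call = CallArgs(args=["x", "y"], arg_types=["Union[int, str]", "bool"],
                arg_kinds=[0, 0], arg_names=[None, None])
narrowed = call.with_arg_type(0, "int")  # the original `call` is left untouched
```

A single immutable value like this would let a union-splitting step swap one argument type without threading four parallel parameters through every helper.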
Yes, we have many potential refactoring ideas on our internal roadmap; we can add this one to the list.
mypy/checkexpr.py
Outdated
object_type: Optional[Type],
context: Context,
arg_messages: Optional[MessageBuilder] = None,
) -> Optional[Tuple[Type, Type]]:
After thinking about this, I think this algorithm might actually be exponential, not quadratic. (E.g. assuming we have n union arguments, each with 2 options, we'd do T(n) = T(n - 1) + T(n - 1) work, which is 2^n.)
In light of that, I wonder if it might be worth adding in some sort of limit? E.g. if we recurse to a certain depth, we give up -- or maybe we add a check up above that skips union math entirely if the call has more than some set number of union arguments.
Yes, you are right, it can be much worse. Here is some more detailed analysis.
The relevant numbers that we have here are:
- number of (potentially matching) overloads, p
- number of unions in args, k
- number of items in every union, n_i, i = 1, ..., k
The worst-case scenario is when every check succeeds. In this case we have p * n_1 * ... * n_k inference attempts; assuming an average of n items per union, this is roughly p * n ** k. This might look very bad, but I have some intuition for why typical performance may be much better in practice. The key is that in practice it is hard to write an overload that will match so many combinations, plus there is a return None that will fully "unwind" the union call stack on the first failure. So for any realistic overload, the number of union item combinations that successfully match will be limited by the number of overloads p. For example, to successfully match two args with types Union[A, B] and Union[C, D], one will need four overloads. Thus we will typically make at most p ** 2 inference attempts in total.
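The worst-case bound can be written down directly. This is only a back-of-the-envelope sketch; the function name `worst_case_attempts` is made up for illustration and nothing here measures what mypy actually does.

```python
import math
from typing import Sequence


def worst_case_attempts(p: int, union_sizes: Sequence[int]) -> int:
    """Upper bound p * n_1 * ... * n_k on inference attempts.

    p is the number of potentially matching overloads and union_sizes
    holds n_1, ..., n_k, the number of items in each union argument.
    """
    return p * math.prod(union_sizes)


# Two unions of two items each, matched by four overloads: 4 * 2 * 2 attempts.
small = worst_case_attempts(4, [2, 2])
# Six unions of six items each and two overloads: the bound grows as n ** k.
large = worst_case_attempts(2, [6] * 6)
```

In the first case the bound coincides with p ** 2; in the second it is already in the tens of thousands, which is why the estimate depends so heavily on how many combinations really match.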
Actually, now thinking more about this, we can also bail out early on success, i.e. after splitting the first union and before splitting the next one, we can check whether the call already matches. For example:
@overload
def f(x: int, y: object, z: object) -> int: ...
@overload
def f(x: str, y: object, z: object) -> str: ...
x: Union[int, str]
f(x, x, x)
There is no need to split the second and third arguments. This will need just a few extra lines, but can be a decent performance gain.
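The early bail-out can be sketched on a toy model. Everything below is an assumption made for illustration, not mypy's implementation: a type is a plain string, a union is a tuple of strings, `object` accepts anything, and the names `accepts`, `direct_match` and `union_math` are invented.

```python
from typing import List, Optional, Set, Tuple

ArgType = Tuple[str, ...]                 # a union is modelled as a tuple of item names
Overload = Tuple[Tuple[str, ...], str]    # (parameter types, return type)


def accepts(param: str, arg: ArgType) -> bool:
    # "object" accepts anything; otherwise every item of the argument must equal the parameter.
    return param == "object" or all(item == param for item in arg)


def direct_match(overloads: List[Overload], args: List[ArgType]) -> Optional[str]:
    """Return the return type of the first overload that accepts the arguments as-is."""
    for params, ret in overloads:
        if len(params) == len(args) and all(accepts(p, a) for p, a in zip(params, args)):
            return ret
    return None


def union_math(overloads: List[Overload], args: List[ArgType],
               calls: List[int]) -> Optional[Set[str]]:
    calls[0] += 1                         # count how many times we try to match
    hit = direct_match(overloads, args)
    if hit is not None:
        return {hit}                      # bail out early: no further splitting needed
    idx = next((i for i, a in enumerate(args) if len(a) > 1), None)
    if idx is None:
        return None                       # nothing left to split and no match
    results: Set[str] = set()
    for item in args[idx]:
        sub = union_math(overloads, args[:idx] + [(item,)] + args[idx + 1:], calls)
        if sub is None:
            return None                   # one item fails, so the whole call fails
        results |= sub
    return results


overloads = [(("int", "object", "object"), "int"),
             (("str", "object", "object"), "str")]
x = ("int", "str")
calls = [0]
result = union_math(overloads, [x, x, x], calls)
```

In this toy model the call `f(x, x, x)` is resolved after three match attempts: one for the unsplit call and one per item of the first union. The second and third arguments are never split.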
An additional comment in the defence of this algorithm is its simplicity of implementation (especially as compared to the breadth of covered use cases). Anyway, if you still think it is dangerous it is easy to put a limit on the call stack depth.
EDIT: typo: overload -> argument
> The key is that in practice it is hard to write an overload that will match so many combinations
I'm not sure if it'll actually be that hard in practice, especially if the overload is intended to match a wide variety of different calls using a mixture of typevars, Any types, object types, *args, **kwargs...
For example, here's a program which calls a simplified version of map. It's a bit contrived, but I don't think it's too far off from the sorts of programs some users might write in practice:
from typing import overload, Any, TypeVar, Iterable, List, Dict, Callable, Union

S = TypeVar('S')

@overload
def simple_map() -> None: ...
@overload
def simple_map(func: Callable[..., S], *iterables: Iterable[Any]) -> S: ...
def simple_map(*args): pass

def format_row(*entries: object) -> str: pass

class DateTime: pass
JsonBlob = Dict[str, Any]
Column = Union[List[str], List[int], List[bool], List[float], List[DateTime], List[JsonBlob]]

def print_custom_table() -> None:
    a: Column
    b: Column
    c: Column
    d: Column
    e: Column
    f: Column
    g: Column

    for row in simple_map(format_row, a, b, c, d, e, f):
        print(row)
If I try running simple_map with 6 columns (a through f), it takes the new algorithm roughly a minute and a half on my computer to type check this program. If I tack on the 7th column (g), it takes mypy significantly longer. (I don't have an exact time unfortunately: I got bored of waiting after 10-15 minutes and killed the thing.)
I think type-checking after each split would probably only aggravate the problem, unfortunately (though I do think it could help in other cases).
> An additional comment in the defence of this algorithm is its simplicity of implementation (especially as compared to the breadth of covered use cases). Anyway, if you still think it is dangerous it is easy to put a limit on the call stack depth.
I also like the simplicity of this algorithm. (It's also not obvious to me how we could robustly solve this problem in any other way, in any case.)
But yeah, I still think we should add a limit of some sort. Even if the pathological cases end up being rare, I think it'd be a poor user experience if mypy starts unexpectedly stalling in those edge cases.
> I think type-checking after each split would probably only aggravate the problem, unfortunately (though I do think it could help in other cases).
It actually helped. I added it (in a simplistic WIP style; I need to polish it) and your example passed in a second, even with eight(!) union args. A few other complex examples I tried also passed instantly.
> But yeah, I still think we should add a limit of some sort.
I don't want to set a calculated limit. I would add a nesting/split level limit, probably 5, and then raise TooManyUnionsError, so that we can better diagnose an error and add a note like:
main.py: error: No overload variant blah-blah-blah..
main.py: note: Not all union combinations were tried because there are too many unions
There are some ideas to avoid hitting the limit, like adding some backtracking or multi-seeding (e.g. also starting to split the unions from the last arg). But I will not do this, because it would only add code complexity while helping in some super-rare corner cases.
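A nesting limit of this kind might look roughly like the sketch below. It is a toy model under stated assumptions: unions are tuples of strings, `split_unions` and `check_call` are invented names, the value 5 is the "probably 5" from the comment above, and the sketch only raises when a union still remains at the limit.

```python
from typing import List, Tuple

MAX_UNIONS = 5  # assumed nesting limit, following the "probably 5" suggestion


class TooManyUnionsError(Exception):
    """Raised when union splitting would nest deeper than MAX_UNIONS."""


def split_unions(args: List[Tuple[str, ...]], level: int = 0) -> int:
    """Return the number of fully split argument lists, or raise if the nesting is too deep."""
    idx = next((i for i, a in enumerate(args) if len(a) > 1), None)
    if idx is None:
        return 1                          # base case: no union left to split
    if level >= MAX_UNIONS:
        raise TooManyUnionsError
    return sum(split_unions(args[:idx] + [(item,)] + args[idx + 1:], level + 1)
               for item in args[idx])


def check_call(args: List[Tuple[str, ...]]) -> List[str]:
    """Run the split and turn a hit limit into a diagnostic note."""
    notes: List[str] = []
    try:
        split_unions(args)
    except TooManyUnionsError:
        notes.append("Not all union combinations were tried because there are too many unions")
    return notes
```

Using an exception keeps the recursive function free of extra return states: the caller decides in one place how to report that the limit was hit.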
mypy/checkexpr.py
Outdated
del self.type_overrides[arg]
return res
first_union = next(typ for typ in arg_types if self.real_union(typ))
idx = arg_types.index(first_union)
This is a bit of a micro-optimization, but I think we could combine the above two lines and the if-check into one loop. E.g. try finding the index first, then case on that to see if we should enter the base case or not.
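The single-loop version could be sketched as follows. The helper name `first_union_index` is hypothetical, and the predicate passed in stands in for `self.real_union`; types are modelled as strings in the usage lines.

```python
from typing import Callable, List, Optional, TypeVar

T = TypeVar("T")


def first_union_index(arg_types: List[T], is_union: Callable[[T], bool]) -> Optional[int]:
    """Return the index of the first union argument, or None if there is none."""
    for idx, typ in enumerate(arg_types):
        if is_union(typ):
            return idx
    return None


def looks_like_union(typ: str) -> bool:
    # Toy predicate for the string model used here.
    return typ.startswith("Union[")


idx = first_union_index(["int", "Union[int, str]", "Union[A, B]"], looks_like_union)
no_union = first_union_index(["int", "str"], looks_like_union)
```

The caller can then branch on `idx is None` for the base case, which replaces the separate `any(...)` check, the `next(...)` search and the `.index(...)` lookup with one pass.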
@@ -2519,10 +2518,10 @@ def wrapper() -> None:

obj2: Union[W1[A], W2[B]]

foo(obj2) # E: Cannot infer type argument 1 of "foo"
foo(obj2) # OK
bar(obj2) # E: Cannot infer type argument 1 of "bar"
Up to you if you want to do this or not, but it might be worth removing the bar function in this test. It exists mostly so we could compare the output of the overload with a manually-unioned callable -- but the comparison is no longer relevant now that we're not actually producing a unioned callable.
I think we have several similar manually-unioned functions in the surrounding test cases here.
Actually, what do you think about the new error messages? They are a bit shorter and better match the logic of "match the first overload", but there is less "symmetry" between overloads and normal functions that have unions in their signatures.
I think I'm ok with the new error messages. It does lose symmetry, but idk if the user was expecting that symmetry to begin with. The new messages also require less explanation: if we report back that a union is expected when the overload definitions don't contain any unions at all, I could see that confusing a user who wasn't aware of the union math feature.
@@ -2646,12 +2645,13 @@ def t_is_same_bound(arg1: T1, arg2: S) -> Tuple[T1, S]:
# The arguments in the tuple are swapped
x3: Union[List[S], List[Tuple[S, T1]]]
y3: S
Dummy[T1]().foo(x3, y3) # E: Argument 1 to "foo" of "Dummy" has incompatible type "Union[List[S], List[Tuple[S, T1]]]"; expected "Union[List[Tuple[T1, S]], List[S]]"
Dummy[T1]().foo(x3, y3) # E: Cannot infer type argument 1 of "foo" of "Dummy" \
# E: Argument 1 to "foo" of "Dummy" has incompatible type "Union[List[S], List[Tuple[S, T1]]]"; expected "List[Tuple[T1, Any]]"
Wait, what?! We can embed multiline error messages?
That's amazing!
mypy/checkexpr.py
Outdated
context, arg_messages=unioned_errors)
# Record if we succeeded. Next we need to see if maybe normal procedure
# gives a narrower type.
union_success = unioned_result is not None and not unioned_errors.is_errors()
Given the way you implemented union_overload_result, I think it's impossible for unioned_errors to ever actually contain an error. As a consequence, union_success will be true if and only if unioned_result is not None.
In that case, I think we can simplify some of the logic by getting rid of union_success and using just unioned_result. E.g. here, we replace this line with if unioned_errors.is_errors(): unioned_result = None just in case, and get rid of union_success entirely.
mypy/checkexpr.py
Outdated
context, arg_messages=unioned_errors)
# Record if we succeeded. Next we need to see if maybe normal procedure
# gives a narrower type.
union_success = unioned_result is not None and not unioned_errors.is_errors()

# Step 3: We try checking each branch one-by-one.
inferred_result = self.infer_overload_return_type(plausible_targets, args, arg_types,
GitHub won't let me leave a comment down below, but after doing the simplification suggested above, I think we can also simplify the following if statements to something roughly like:
# Success! Stop early by returning the best among normal and unioned.
if inferred_result is not None and unioned_result is not None:
    return inferred_result if is_subtype(inferred_result[0], unioned_result[0]) else unioned_result
elif inferred_result is not None:
    return inferred_result
elif unioned_result is not None:
    return unioned_result
erased_targets = self.overload_erased_call_targets(plausible_targets, arg_types,
arg_kinds, arg_names, context)
erased_targets = self.overload_erased_call_targets(plausible_targets, arg_types,
arg_kinds, arg_names, context)

# Step 5: We try and infer a second-best alternative if possible. If not, fall back
# to using 'Any'.
Github still won't let me leave a comment below, but I think after doing the above two simplifications, we can get rid of the next if statement entirely.
(It's a bit hard to tell, but based on new error messages in the changed tests, it seems like we're no longer entering into this case at all, so it's redundant either way.)
x: D
y: Union[D, Any]
# TODO: update after we decide on https://github.com/python/mypy/pull/5254
reveal_type(x.f(y)) # E: Revealed type is 'Any'
Just a comment to myself: we need to fix this before merging this PR. Jukka prefers Union[D, Any] here, which is actually the result of a unioned match, but the direct match is Any, so:
- Make overload checks more strict when there are multiple 'Any's #5254 should go in first
- after that we might tweak the choice between direct and unioned match
@Michael0x2a thanks for a detailed review! I will implement the changes tomorrow probably.
@Michael0x2a Btw what do you think about the
Regarding the
mypy/checkexpr.py
Outdated
if not any(self.real_union(typ) for typ in arg_types):
# No unions in args, just fall back to normal inference
for arg, typ in zip(args, arg_types):
self.type_overrides[arg] = typ
This should be made a context manager. Also, we need to be sure that the last type for an arg expression (the one that stays in the type map after we leave the call expression) is a union.
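A context manager for the overrides might be shaped like this. It is a sketch, not the code that was merged: the `Checker` class is a stand-in, strings replace mypy's `Expression` and `Type` objects, and the method name `type_overrides_set` is an assumption.

```python
from contextlib import contextmanager
from typing import Dict, Iterator, Sequence


class Checker:
    """Minimal stand-in for the expression checker; only the overrides map is modelled."""

    def __init__(self) -> None:
        self.type_overrides: Dict[str, str] = {}

    @contextmanager
    def type_overrides_set(self, exprs: Sequence[str],
                           overrides: Sequence[str]) -> Iterator[None]:
        """Temporarily record override types for the given expressions."""
        assert len(exprs) == len(overrides)
        for expr, typ in zip(exprs, overrides):
            self.type_overrides[expr] = typ
        try:
            yield
        finally:
            # pop() instead of del: the same expression may appear twice, as in f(x, x).
            for expr in exprs:
                self.type_overrides.pop(expr, None)


checker = Checker()
with checker.type_overrides_set(["x", "x"], ["int", "int"]):
    inside = dict(checker.type_overrides)
after = dict(checker.type_overrides)
```

The `try`/`finally` guarantees the cleanup that the manual `del self.type_overrides[arg]` loop performs today, including when inference raises partway through.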
@Michael0x2a This should be ready for review now.
Looks good, thanks! I did have one or two questions about union_overload_matches and several typo nits, but no real remaining concerns.
Feel free to merge whenever!
mypy/checkexpr.py
Outdated
if level >= MAX_UNIONS:
raise TooManyUnions

# Step 2: Find position of the first union in arguments. Return the normal infered
Really minor nit: "infered" -> "inferred"
mypy/checkexpr.py
Outdated
sound manner.

Assumes all of the given callables have argument counts compatible with the caller.
If there is at least one non-callabe type, return Any (this can happen if there is
"callabe" -> "callable"
mypy/checkexpr.py
Outdated
unioned_result = (UnionType.make_simplified_union(list(returns),
context.line,
context.column),
self.union_overload_matches(inferred_types))
Out of curiosity, is there any reason why we switched from using UnionType.make_simplified_union to self.union_overload_matches for the inferred types?
(fwiw, I don't mind whether we keep it or not -- if we remove it, it simplifies the code; if we keep it, it might help us perform better subtype checks with overloads in the future.)
The reason is that if I wrote
f: Union[Callable[[int], int], Callable[[str], str]]
x: Union[int, str]
then mypy would (correctly) flag the call f(x) as an error. Therefore I don't want mypy to infer a callable type that is formally inconsistent with the argument type(s). I will add a comment about this.
mypy/checkexpr.py
Outdated
@@ -58,6 +62,16 @@
ArgChecker = Callable[[Type, Type, int, Type, int, int, CallableType, Context, MessageBuilder],
None]

# Maximum nesting level for math union in overloads, setting this to large values
# may cause performance issues.
Maybe we should add a blurb + a link telling people to read this PR if they want more details? Up to you.
mypy/checkexpr.py
Outdated
object_type, context, arg_messages,
level + 1)
if sub_result is not None:
res_items.append(sub_result)
Another minor nit: I think it'd look cleaner if we do res_items.extend(sub_result)
here. We can then simplify the nested loop down below to just a single one. (And we'd need to change the comment to something like "Step 5: If splitting succeeded, filter out duplicate items".)
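The `extend`-then-deduplicate shape could look like the sketch below. The function name `collect_unique` and the pair-of-strings `Result` type are assumptions standing in for the (return type, inferred callable type) pairs used in the PR.

```python
from typing import Iterable, List, Set, Tuple

Result = Tuple[str, str]  # stand-in for (return type, inferred callable type)


def collect_unique(sub_results: Iterable[List[Result]]) -> List[Result]:
    """Flatten per-item results with extend(), then drop duplicates in one loop."""
    res_items: List[Result] = []
    for sub_result in sub_results:
        res_items.extend(sub_result)      # flat list, so no nested loop is needed later

    seen: Set[Result] = set()
    unique: List[Result] = []
    for item in res_items:
        if item not in seen:              # keep the first occurrence, preserving order
            seen.add(item)
            unique.append(item)
    return unique


merged = collect_unique([[("int", "f1")], [("str", "f2"), ("int", "f1")]])
```

Because the list is already flat after `extend`, the remaining step only has to filter duplicates, which matches the proposed "Step 5: If splitting succeeded, filter out duplicate items" wording.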
mypy/checkexpr.py
Outdated
# Some item doesn't match, return soon.
return None

# Step 5: If spliting succeeded, then flatten union results into a single
"spliting" -> "splitting"
# We fall back to Callable[..., Union[<returns>]] if the overloads do not have
# the exact same signature. The only exception is if one arg is optional and
# the other is positional: in that case, we continue unioning (and expect a
# positional arg).
What about falling back to Union[callable1, callable2, ..., callableN] instead? It would let us introduce fewer new Anys. Or does that end up breaking too many things?
(I think this is basically the same question as the previous one -- maybe add a comment here about this?)
> Or does that end up breaking too many things?
It actually breaks nothing; I just don't want to introduce internal inconsistencies (as I explained above). Also, the Anys here don't currently impact code precision: we mostly care about the return type, not the inferred callable type.
Fixes #5243
Fixes #5249
Some comments:
__repr__ is unstable.