-
-
Notifications
You must be signed in to change notification settings - Fork 2.9k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Alternative algorithm for union math #5255
Changes from 2 commits
057e329
b796af7
d127ecb
e98cb76
7057796
4e6a03d
33e845f
dd8adba
c934465
5f91877
b9659e1
7d02ba8
643c3b8
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -118,6 +118,7 @@ def __init__(self, | |
self.msg = msg | ||
self.plugin = plugin | ||
self.type_context = [None] | ||
self.type_overrides = {} # type: Dict[Expression, Type] | ||
self.strfrm_checker = StringFormatterChecker(self, self.chk, self.msg) | ||
|
||
def visit_name_expr(self, e: NameExpr) -> Type: | ||
|
@@ -519,7 +520,9 @@ def check_call(self, callee: Type, args: List[Expression], | |
callable_node: Optional[Expression] = None, | ||
arg_messages: Optional[MessageBuilder] = None, | ||
callable_name: Optional[str] = None, | ||
object_type: Optional[Type] = None) -> Tuple[Type, Type]: | ||
object_type: Optional[Type] = None, | ||
*, | ||
arg_types_override: Optional[List[Type]] = None) -> Tuple[Type, Type]: | ||
"""Type check a call. | ||
|
||
Also infer type arguments if the callee is a generic function. | ||
|
@@ -575,9 +578,11 @@ def check_call(self, callee: Type, args: List[Expression], | |
callee, context) | ||
callee = self.infer_function_type_arguments( | ||
callee, args, arg_kinds, formal_to_actual, context) | ||
|
||
arg_types = self.infer_arg_types_in_context2( | ||
callee, args, arg_kinds, formal_to_actual) | ||
if arg_types_override is not None: | ||
arg_types = arg_types_override.copy() | ||
else: | ||
arg_types = self.infer_arg_types_in_context2( | ||
callee, args, arg_kinds, formal_to_actual) | ||
|
||
self.check_argument_count(callee, arg_types, arg_kinds, | ||
arg_names, formal_to_actual, context, self.msg) | ||
|
@@ -1130,22 +1135,15 @@ def check_overload_call(self, | |
unioned_result = None # type: Optional[Tuple[Type, Type]] | ||
unioned_errors = None # type: Optional[MessageBuilder] | ||
union_success = False | ||
if any(isinstance(arg, UnionType) and len(arg.relevant_items()) > 1 # "real" union | ||
for arg in arg_types): | ||
erased_targets = self.overload_erased_call_targets(plausible_targets, arg_types, | ||
arg_kinds, arg_names, context) | ||
unioned_callable = self.union_overload_matches(erased_targets) | ||
|
||
if unioned_callable is not None: | ||
unioned_errors = arg_messages.clean_copy() | ||
unioned_result = self.check_call(unioned_callable, args, arg_kinds, | ||
context, arg_names, | ||
arg_messages=unioned_errors, | ||
callable_name=callable_name, | ||
object_type=object_type) | ||
# Record if we succeeded. Next we need to see if maybe normal procedure | ||
# gives a narrower type. | ||
union_success = unioned_result is not None and not unioned_errors.is_errors() | ||
if any(self.real_union(arg) for arg in arg_types): | ||
unioned_errors = arg_messages.clean_copy() | ||
unioned_result = self.union_overload_result(plausible_targets, args, arg_types, | ||
arg_kinds, arg_names, | ||
callable_name, object_type, | ||
context, arg_messages=unioned_errors) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Just to check: switching to doing union math on (The former variable is supposed to contain callables that just have the right number of args; the latter is supposed to contain (non-erased) callables that match after performing type erasure. The names admittedly aren't very clear, but I'm planning on submitting another PR soon that cleans that up) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I thought There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think That said, I think the way we compute the erased targets is currently a bit janky -- pretty much all of that code is from the old overloads implementation and is something I'm planning on refactoring/simplifying in a future PR. So, I wouldn't be too surprised if it didn't end up working here. (But if it did, it might be a way we could get a minor speedup in the case where the list of overloads is long.) That said, I guess that's a bit of an edge case, so maybe just using |
||
# Record if we succeeded. Next we need to see if maybe normal procedure | ||
# gives a narrower type. | ||
union_success = unioned_result is not None and not unioned_errors.is_errors() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Given the way you implemented In that case, I think we can simplify some of the logic by getting rid of |
||
|
||
# Step 3: We try checking each branch one-by-one. | ||
inferred_result = self.infer_overload_return_type(plausible_targets, args, arg_types, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Github won't let me leave a comment down below after doing the simplification suggested up above, I think we can also simplify the following if statements to something roughly like:
|
||
|
@@ -1173,9 +1171,8 @@ def check_overload_call(self, | |
# | ||
# Neither alternative matches, but we can guess the user probably wants the | ||
# second one. | ||
if erased_targets is None: | ||
erased_targets = self.overload_erased_call_targets(plausible_targets, arg_types, | ||
arg_kinds, arg_names, context) | ||
erased_targets = self.overload_erased_call_targets(plausible_targets, arg_types, | ||
arg_kinds, arg_names, context) | ||
|
||
# Step 5: We try and infer a second-best alternative if possible. If not, fall back | ||
# to using 'Any'. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Github still won't let me leave a comment below, but I think after doing the above two simplifications, we can get rid of the next if statement entirely. (It's a bit hard to tell, but based on new error messages in the changed tests, it seems like we're no longer entering into this case at all, so it's redundant either way.) |
||
|
@@ -1350,91 +1347,62 @@ def overload_erased_call_targets(self, | |
matches.append(typ) | ||
return matches | ||
|
||
def union_overload_matches(self, callables: List[CallableType]) -> Optional[CallableType]: | ||
"""Accepts a list of overload signatures and attempts to combine them together into a | ||
new CallableType consisting of the union of all of the given arguments and return types. | ||
|
||
Returns None if it is not possible to combine the different callables together in a | ||
sound manner. | ||
|
||
Assumes all of the given callables have argument counts compatible with the caller. | ||
def union_overload_result(self, | ||
plausible_targets: List[CallableType], | ||
args: List[Expression], | ||
arg_types: List[Type], | ||
arg_kinds: List[int], | ||
arg_names: Optional[Sequence[Optional[str]]], | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Off-topic: I wonder if it might be worth writing some sort of custom class/namedtuple that stores There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, we have many potential refactoring ideas on our internal roadmap, we can add this one to the list. |
||
callable_name: Optional[str], | ||
object_type: Optional[Type], | ||
context: Context, | ||
arg_messages: Optional[MessageBuilder] = None, | ||
) -> Optional[Tuple[Type, Type]]: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. After thinking about this, I think this algorithm might actually be exponential, not quadratic. (E.g. assuming we have In light of that, I wonder if it might maybe be worth adding in some sort of limit? E.g. if we recurse to a certain depth, we give up -- or maybe we add a check up above that skips union math entirely if the call has more then some set number of union arguments. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, you are right, it can be much worse. Here is some more detailed analysis. The relevant numbers that we have here are:
The worst case scenario is when every check succeeds, in this case we have Actually, now thinking more about this, we can also bailout soon on success, i.e. after splitting the first union, before we split the next one, we can check, maybe it already matches, for example: @overload
def f(x: int, y: object, z: object) -> int: ...
@overload
def f(x: str, y: object, z: object) -> str: ...
x: Union[int, str]
f(x, x, x) There is not need to split the second and third argument. This will need just few extra lines, but can be a decent performance gain. An additional comment in the defence of this algorithm is its simplicity of implementation (especially as compared to the breadth of covered use cases). Anyway, if you still think it is dangerous it is easy to put a limit on the call stack depth. EDIT: typo: overload -> argument There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
I'm not sure if it'll actually be that hard in practice, especially if the overload is intended to match a wide variety of different calls using a mixture of typevars, For example, here's a program which calls a simplified version of from typing import overload, Any, TypeVar, Iterable, List, Dict, Callable, Union
S = TypeVar('S')
@overload
def simple_map() -> None: ...
@overload
def simple_map(func: Callable[..., S], *iterables: Iterable[Any]) -> S: ...
def simple_map(*args): pass
def format_row(*entries: object) -> str: pass
class DateTime: pass
JsonBlob = Dict[str, Any]
Column = Union[List[str], List[int], List[bool], List[float], List[DateTime], List[JsonBlob]]
def print_custom_table() -> None:
a: Column
b: Column
c: Column
d: Column
e: Column
f: Column
g: Column
for row in simple_map(format_row, a, b, c, d, e, f):
print(row) If I try running I think type-checking after each split would probably only aggravate the problem, unfortunately (though I do think it could help in other cases).
I also like the simplicity of this algorithm. (It's also not obvious to me how we could robustly solve this problem in any other way, in any case.) But yeah, I still think we should add a limit of some sort. Even if the pathological cases end up being rare, I think it'd be a poor user experience if mypy starts unexpectedly stalling in those edge cases. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
It actually helped. I added it (simplistic WIP style, need to polish it) and your example passed in a second (even with eight(!) union args) also few other complex examples I tried passed instantly.
I don't want to set a calculated limit, I would add a nesting/split level limit probably 5, and then raise
There are some ideas to avoid hitting the limit, like add some back-tracking add multi-seed (like start splitting the unions from last arg as well). But I will not do this, because this will only add code complexity while gaining some super-rare corner cases. |
||
"""Accepts a list of overload signatures and attempts to match calls by destructuring | ||
the first union. Returns None if there is no match. | ||
""" | ||
if len(callables) == 0: | ||
return None | ||
elif len(callables) == 1: | ||
return callables[0] | ||
|
||
# Note: we are assuming here that if a user uses some TypeVar 'T' in | ||
# two different overloads, they meant for that TypeVar to mean the | ||
# same thing. | ||
# | ||
# This function will make sure that all instances of that TypeVar 'T' | ||
# refer to the same underlying TypeVarType and TypeVarDef objects to | ||
# simplify the union-ing logic below. | ||
# | ||
# (If the user did *not* mean for 'T' to be consistently bound to the | ||
# same type in their overloads, well, their code is probably too | ||
# confusing and ought to be re-written anyways.) | ||
callables, variables = merge_typevars_in_callables_by_name(callables) | ||
|
||
new_args = [[] for _ in range(len(callables[0].arg_types))] # type: List[List[Type]] | ||
new_kinds = list(callables[0].arg_kinds) | ||
new_returns = [] # type: List[Type] | ||
|
||
for target in callables: | ||
# We conservatively end if the overloads do not have the exact same signature. | ||
# The only exception is if one arg is optional and the other is positional: in that | ||
# case, we continue unioning (and expect a positional arg). | ||
# TODO: Enhance the union overload logic to handle a wider variety of signatures. | ||
if len(new_kinds) != len(target.arg_kinds): | ||
if not any(self.real_union(typ) for typ in arg_types): | ||
# No unions in args, just fall back to normal inference | ||
for arg, typ in zip(args, arg_types): | ||
self.type_overrides[arg] = typ | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This should be made a context manager, also we need to be sure that the last type for an arg expression (the one that stays in type map after we leave the call expression) is a union. |
||
res = self.infer_overload_return_type(plausible_targets, args, arg_types, | ||
arg_kinds, arg_names, callable_name, | ||
object_type, context, arg_messages) | ||
for arg, typ in zip(args, arg_types): | ||
del self.type_overrides[arg] | ||
return res | ||
# Try direct match before splitting | ||
direct = self.infer_overload_return_type(plausible_targets, args, arg_types, | ||
arg_kinds, arg_names, callable_name, | ||
object_type, context, arg_messages) | ||
if direct is not None and not isinstance(direct[0], UnionType): | ||
# We only return non-unions soon, to avoid gredy match. | ||
return direct | ||
first_union = next(typ for typ in arg_types if self.real_union(typ)) | ||
idx = arg_types.index(first_union) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is a bit of a micro-optimization, but I think we could combine the above two lines and the if-check into one loop. E.g. try finding the index first, then case on that to see if we should enter the base case or not. |
||
assert isinstance(first_union, UnionType) | ||
returns = [] | ||
inferred_types = [] | ||
for item in first_union.relevant_items(): | ||
new_arg_types = arg_types.copy() | ||
new_arg_types[idx] = item | ||
sub_result = self.union_overload_result(plausible_targets, args, new_arg_types, | ||
arg_kinds, arg_names, callable_name, | ||
object_type, context, arg_messages) | ||
if sub_result is not None: | ||
ret, inferred = sub_result | ||
returns.append(ret) | ||
inferred_types.append(inferred) | ||
else: | ||
return None | ||
for i, (new_kind, target_kind) in enumerate(zip(new_kinds, target.arg_kinds)): | ||
if new_kind == target_kind: | ||
continue | ||
elif new_kind in (ARG_POS, ARG_OPT) and target_kind in (ARG_POS, ARG_OPT): | ||
new_kinds[i] = ARG_POS | ||
else: | ||
return None | ||
|
||
for i, arg in enumerate(target.arg_types): | ||
new_args[i].append(arg) | ||
new_returns.append(target.ret_type) | ||
|
||
union_count = 0 | ||
final_args = [] | ||
for args_list in new_args: | ||
new_type = UnionType.make_simplified_union(args_list) | ||
union_count += 1 if isinstance(new_type, UnionType) else 0 | ||
final_args.append(new_type) | ||
|
||
# TODO: Modify this check to be less conservative. | ||
# | ||
# Currently, we permit only one union in the arguments because if we allow | ||
# multiple, we can't always guarantee the synthesized callable will be correct. | ||
# | ||
# For example, suppose we had the following two overloads: | ||
# | ||
# @overload | ||
# def f(x: A, y: B) -> None: ... | ||
# @overload | ||
# def f(x: B, y: A) -> None: ... | ||
# | ||
# If we continued and synthesize "def f(x: Union[A,B], y: Union[A,B]) -> None: ...", | ||
# then we'd incorrectly accept calls like "f(A(), A())" when they really ought to | ||
# be rejected. | ||
# | ||
# However, that means we'll also give up if the original overloads contained | ||
# any unions. This is likely unnecessary -- we only really need to give up if | ||
# there are more then one *synthesized* union arguments. | ||
if union_count >= 2: | ||
return None | ||
if returns: | ||
return (UnionType.make_simplified_union(returns, context.line, context.column), | ||
UnionType.make_simplified_union(inferred_types, context.line, context.column)) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Another micro-optimization, but I wonder if it would be more efficient to return a list or set of the returns and inferred_types and run Up to you if you decide to try this or not though -- this sounds like a minor hassle/might not end up being worth it. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I am not sure we can use sets (although I really want to), the problem is that we need to have a stable/predictable order of union items to avoid flaky tests. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. OK, I think I figured out how to avoid multiple intermediate unions, and still have stable order, we can return a set of pairs (indexes of matched overload, return type), at the very end we construct a union once but before sort items by index. Another important note: I think returning a union of inferred callable types may be wrong here, we should actually return a single unioned callable (or fall back to |
||
return None | ||
|
||
return callables[0].copy_modified( | ||
arg_types=final_args, | ||
arg_kinds=new_kinds, | ||
ret_type=UnionType.make_simplified_union(new_returns), | ||
variables=variables, | ||
implicit=True) | ||
def real_union(self, typ: Type) -> bool: | ||
return isinstance(typ, UnionType) and len(typ.relevant_items()) > 1 | ||
|
||
def erased_signature_similarity(self, arg_types: List[Type], arg_kinds: List[int], | ||
arg_names: Optional[Sequence[Optional[str]]], | ||
|
@@ -2666,6 +2634,8 @@ def accept(self, | |
is True and this expression is a call, allow it to return None. This | ||
applies only to this expression and not any subexpressions. | ||
""" | ||
if node in self.type_overrides: | ||
return self.type_overrides[node] | ||
self.type_context.append(type_context) | ||
try: | ||
if allow_none_return and isinstance(node, CallExpr): | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe I missed this, but it doesn't look like we ever actually pass in something other then
None
forarg_types_override
?This param also feels a bit redundant given we also have the
type_overrides
field -- it seems like they're both trying to do the same thing.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, this is an artefact of a previous attempt to do this "right way".