Skip to content

Commit

Permalink
Fix inference when class and instance match protocol (#18587)
Browse files Browse the repository at this point in the history
Fixes #14688

The bug resulted from (accidentally) inferring against `Iterable` for
both instance and class object. While working on this I noticed there
are also couple flaws in direction handling in constrain inference,
namely:
* A protocol can never ever be a subtype of class object or a `Type[X]`
* When matching against callback protocol, subtype check direction must
match inference direction

I also (conservatively) fix some unrelated issues uncovered by the fix
(to avoid fallout):
* Callable subtyping with trivial suffixes was broken for
positional-only args
* Join of `Parameters` could lead to meaningless results in case of
incompatible arg kinds
* Protocol inference was inconsistent with protocol subtyping w.r.t.
metaclasses.
  • Loading branch information
ilevkivskyi authored Feb 3, 2025
1 parent c8489a2 commit 274af1c
Show file tree
Hide file tree
Showing 7 changed files with 132 additions and 28 deletions.
51 changes: 31 additions & 20 deletions mypy/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -756,40 +756,40 @@ def visit_instance(self, template: Instance) -> list[Constraint]:
"__call__", template, actual, is_operator=True
)
assert call is not None
if mypy.subtypes.is_subtype(actual, erase_typevars(call)):
subres = infer_constraints(call, actual, self.direction)
res.extend(subres)
if (
self.direction == SUPERTYPE_OF
and mypy.subtypes.is_subtype(actual, erase_typevars(call))
or self.direction == SUBTYPE_OF
and mypy.subtypes.is_subtype(erase_typevars(call), actual)
):
res.extend(infer_constraints(call, actual, self.direction))
template.type.inferring.pop()
if isinstance(actual, CallableType) and actual.fallback is not None:
if actual.is_type_obj() and template.type.is_protocol:
if (
actual.is_type_obj()
and template.type.is_protocol
and self.direction == SUPERTYPE_OF
):
ret_type = get_proper_type(actual.ret_type)
if isinstance(ret_type, TupleType):
ret_type = mypy.typeops.tuple_fallback(ret_type)
if isinstance(ret_type, Instance):
if self.direction == SUBTYPE_OF:
subtype = template
else:
subtype = ret_type
res.extend(
self.infer_constraints_from_protocol_members(
ret_type, template, subtype, template, class_obj=True
ret_type, template, ret_type, template, class_obj=True
)
)
actual = actual.fallback
if isinstance(actual, TypeType) and template.type.is_protocol:
if isinstance(actual.item, Instance):
if self.direction == SUBTYPE_OF:
subtype = template
else:
subtype = actual.item
res.extend(
self.infer_constraints_from_protocol_members(
actual.item, template, subtype, template, class_obj=True
)
)
if self.direction == SUPERTYPE_OF:
# Infer constraints for Type[T] via metaclass of T when it makes sense.
a_item = actual.item
if isinstance(a_item, Instance):
res.extend(
self.infer_constraints_from_protocol_members(
a_item, template, a_item, template, class_obj=True
)
)
# Infer constraints for Type[T] via metaclass of T when it makes sense.
if isinstance(a_item, TypeVarType):
a_item = get_proper_type(a_item.upper_bound)
if isinstance(a_item, Instance) and a_item.type.metaclass_type:
Expand Down Expand Up @@ -1043,6 +1043,17 @@ def infer_constraints_from_protocol_members(
return [] # See #11020
# The above is safe since at this point we know that 'instance' is a subtype
# of (erased) 'template', therefore it defines all protocol members
if class_obj:
# For class objects we must only infer constraints if possible, otherwise it
# can lead to confusion between class and instance, for example StrEnum is
# Iterable[str] for an instance, but Iterable[StrEnum] for a class object.
if not mypy.subtypes.is_subtype(
inst, erase_typevars(temp), ignore_pos_arg_names=True
):
continue
# This exception matches the one in subtypes.py, see PR #14121 for context.
if member == "__call__" and instance.type.is_metaclass():
continue
res.extend(infer_constraints(temp, inst, self.direction))
if mypy.subtypes.IS_SETTABLE in mypy.subtypes.get_member_flags(member, protocol):
# Settable members are invariant, add opposite constraints
Expand Down
12 changes: 11 additions & 1 deletion mypy/join.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,7 +355,8 @@ def visit_unpack_type(self, t: UnpackType) -> UnpackType:

def visit_parameters(self, t: Parameters) -> ProperType:
if isinstance(self.s, Parameters):
if len(t.arg_types) != len(self.s.arg_types):
if not is_similar_params(t, self.s):
# TODO: it would be prudent to return [*object, **object] instead of Any.
return self.default(self.s)
from mypy.meet import meet_types

Expand Down Expand Up @@ -724,6 +725,15 @@ def is_similar_callables(t: CallableType, s: CallableType) -> bool:
)


def is_similar_params(t: Parameters, s: Parameters) -> bool:
# This matches the logic in is_similar_callables() above.
return (
len(t.arg_types) == len(s.arg_types)
and t.min_args == s.min_args
and (t.var_arg() is not None) == (s.var_arg() is not None)
)


def update_callable_ids(c: CallableType, ids: list[TypeVarId]) -> CallableType:
tv_map = {}
tvs = []
Expand Down
9 changes: 7 additions & 2 deletions mypy/subtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1719,11 +1719,16 @@ def _incompatible(left_arg: FormalArgument | None, right_arg: FormalArgument | N
):
return False

if trivial_suffix:
# For trivial right suffix we *only* check that every non-star right argument
# has a valid match on the left.
return True

# Phase 1c: Check var args. Right has an infinite series of optional positional
# arguments. Get all further positional args of left, and make sure
# they're more general than the corresponding member in right.
# TODO: are we handling UnpackType correctly here?
if right_star is not None and not trivial_suffix:
if right_star is not None:
# Synthesize an anonymous formal argument for the right
right_by_position = right.try_synthesizing_arg_from_vararg(None)
assert right_by_position is not None
Expand All @@ -1750,7 +1755,7 @@ def _incompatible(left_arg: FormalArgument | None, right_arg: FormalArgument | N
# Phase 1d: Check kw args. Right has an infinite series of optional named
# arguments. Get all further named args of left, and make sure
# they're more general than the corresponding member in right.
if right_star2 is not None and not trivial_suffix:
if right_star2 is not None:
right_names = {name for name in right.arg_names if name is not None}
left_only_names = set()
for name, kind in zip(left.arg_names, left.arg_kinds):
Expand Down
22 changes: 22 additions & 0 deletions test-data/unit/check-enum.test
Original file line number Diff line number Diff line change
Expand Up @@ -2394,3 +2394,25 @@ def do_check(value: E) -> None:

[builtins fixtures/primitives.pyi]
[typing fixtures/typing-full.pyi]

[case testStrEnumClassCorrectIterable]
from enum import StrEnum
from typing import Type, TypeVar

class Choices(StrEnum):
LOREM = "lorem"
IPSUM = "ipsum"

var = list(Choices)
reveal_type(var) # N: Revealed type is "builtins.list[__main__.Choices]"

e: type[StrEnum]
reveal_type(list(e)) # N: Revealed type is "builtins.list[enum.StrEnum]"

T = TypeVar("T", bound=StrEnum)
def list_vals(e: Type[T]) -> list[T]:
reveal_type(list(e)) # N: Revealed type is "builtins.list[T`-1]"
return list(e)

reveal_type(list_vals(Choices)) # N: Revealed type is "builtins.list[__main__.Choices]"
[builtins fixtures/enum.pyi]
30 changes: 26 additions & 4 deletions test-data/unit/check-functions.test
Original file line number Diff line number Diff line change
Expand Up @@ -105,16 +105,38 @@ if int():
h = h

[case testSubtypingFunctionsDoubleCorrespondence]
def l(x) -> None: ...
def r(__x, *, x) -> None: ...
r = l # E: Incompatible types in assignment (expression has type "Callable[[Any], None]", variable has type "Callable[[Any, NamedArg(Any, 'x')], None]")

[case testSubtypingFunctionsDoubleCorrespondenceNamedOptional]
def l(x) -> None: ...
def r(__, *, x) -> None: ...
r = l # E: Incompatible types in assignment (expression has type "Callable[[Any], None]", variable has type "Callable[[Any, NamedArg(Any, 'x')], None]")
def r(__x, *, x = 1) -> None: ...
r = l # E: Incompatible types in assignment (expression has type "Callable[[Any], None]", variable has type "Callable[[Any, DefaultNamedArg(Any, 'x')], None]")

[case testSubtypingFunctionsRequiredLeftArgNotPresent]
[case testSubtypingFunctionsDoubleCorrespondenceBothNamedOptional]
def l(x = 1) -> None: ...
def r(__x, *, x = 1) -> None: ...
r = l # E: Incompatible types in assignment (expression has type "Callable[[Any], None]", variable has type "Callable[[Any, DefaultNamedArg(Any, 'x')], None]")

[case testSubtypingFunctionsTrivialSuffixRequired]
def l(__x) -> None: ...
def r(x, *args, **kwargs) -> None: ...

r = l # E: Incompatible types in assignment (expression has type "Callable[[Any], None]", variable has type "Callable[[Arg(Any, 'x'), VarArg(Any), KwArg(Any)], None]")
[builtins fixtures/dict.pyi]

[case testSubtypingFunctionsTrivialSuffixOptional]
def l(__x = 1) -> None: ...
def r(x = 1, *args, **kwargs) -> None: ...

r = l # E: Incompatible types in assignment (expression has type "Callable[[DefaultArg(Any)], None]", variable has type "Callable[[DefaultArg(Any, 'x'), VarArg(Any), KwArg(Any)], None]")
[builtins fixtures/dict.pyi]

[case testSubtypingFunctionsRequiredLeftArgNotPresent]
def l(x, y) -> None: ...
def r(x) -> None: ...
r = l # E: Incompatible types in assignment (expression has type "Callable[[Any, Any], None]", variable has type "Callable[[Any], None]")
r = l # E: Incompatible types in assignment (expression has type "Callable[[Any, Any], None]", variable has type "Callable[[Any], None]")

[case testSubtypingFunctionsImplicitNames]
from typing import Any
Expand Down
27 changes: 27 additions & 0 deletions test-data/unit/check-parameter-specification.test
Original file line number Diff line number Diff line change
Expand Up @@ -2532,3 +2532,30 @@ class GenericWrapper(Generic[P]):
def contains(c: Callable[P, None], *args: P.args, **kwargs: P.kwargs) -> None: ...
def inherits(*args: P.args, **kwargs: P.kwargs) -> None: ...
[builtins fixtures/paramspec.pyi]

[case testCallbackProtocolClassObjectParamSpec]
from typing import Any, Callable, Protocol, Optional, Generic
from typing_extensions import ParamSpec

P = ParamSpec("P")

class App: ...

class MiddlewareFactory(Protocol[P]):
def __call__(self, app: App, /, *args: P.args, **kwargs: P.kwargs) -> App:
...

class Capture(Generic[P]): ...

class ServerErrorMiddleware(App):
def __init__(
self,
app: App,
handler: Optional[str] = None,
debug: bool = False,
) -> None: ...

def fn(f: MiddlewareFactory[P]) -> Capture[P]: ...

reveal_type(fn(ServerErrorMiddleware)) # N: Revealed type is "__main__.Capture[[handler: Union[builtins.str, None] =, debug: builtins.bool =]]"
[builtins fixtures/paramspec.pyi]
9 changes: 8 additions & 1 deletion test-data/unit/fixtures/enum.pyi
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Minimal set of builtins required to work with Enums
from typing import TypeVar, Generic
from typing import TypeVar, Generic, Iterator, Sequence, overload, Iterable

T = TypeVar('T')

Expand All @@ -13,6 +13,13 @@ class tuple(Generic[T]):
class int: pass
class str:
def __len__(self) -> int: pass
def __iter__(self) -> Iterator[str]: pass

class dict: pass
class ellipsis: pass

class list(Sequence[T]):
@overload
def __init__(self) -> None: pass
@overload
def __init__(self, x: Iterable[T]) -> None: pass

0 comments on commit 274af1c

Please sign in to comment.