Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix inference when class and instance match protocol #18587

Merged
merged 3 commits into from
Feb 3, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 31 additions & 20 deletions mypy/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -756,40 +756,40 @@ def visit_instance(self, template: Instance) -> list[Constraint]:
"__call__", template, actual, is_operator=True
)
assert call is not None
if mypy.subtypes.is_subtype(actual, erase_typevars(call)):
subres = infer_constraints(call, actual, self.direction)
res.extend(subres)
if (
self.direction == SUPERTYPE_OF
and mypy.subtypes.is_subtype(actual, erase_typevars(call))
or self.direction == SUBTYPE_OF
and mypy.subtypes.is_subtype(erase_typevars(call), actual)
):
res.extend(infer_constraints(call, actual, self.direction))
template.type.inferring.pop()
if isinstance(actual, CallableType) and actual.fallback is not None:
if actual.is_type_obj() and template.type.is_protocol:
if (
actual.is_type_obj()
and template.type.is_protocol
and self.direction == SUPERTYPE_OF
):
ret_type = get_proper_type(actual.ret_type)
if isinstance(ret_type, TupleType):
ret_type = mypy.typeops.tuple_fallback(ret_type)
if isinstance(ret_type, Instance):
if self.direction == SUBTYPE_OF:
subtype = template
else:
subtype = ret_type
res.extend(
self.infer_constraints_from_protocol_members(
ret_type, template, subtype, template, class_obj=True
ret_type, template, ret_type, template, class_obj=True
)
)
actual = actual.fallback
if isinstance(actual, TypeType) and template.type.is_protocol:
if isinstance(actual.item, Instance):
if self.direction == SUBTYPE_OF:
subtype = template
else:
subtype = actual.item
res.extend(
self.infer_constraints_from_protocol_members(
actual.item, template, subtype, template, class_obj=True
)
)
if self.direction == SUPERTYPE_OF:
# Infer constraints for Type[T] via metaclass of T when it makes sense.
a_item = actual.item
if isinstance(a_item, Instance):
res.extend(
self.infer_constraints_from_protocol_members(
a_item, template, a_item, template, class_obj=True
)
)
# Infer constraints for Type[T] via metaclass of T when it makes sense.
if isinstance(a_item, TypeVarType):
a_item = get_proper_type(a_item.upper_bound)
if isinstance(a_item, Instance) and a_item.type.metaclass_type:
Expand Down Expand Up @@ -1043,6 +1043,17 @@ def infer_constraints_from_protocol_members(
return [] # See #11020
# The above is safe since at this point we know that 'instance' is a subtype
# of (erased) 'template', therefore it defines all protocol members
if class_obj:
# For class objects we must only infer constraints if possible, otherwise it
# can lead to confusion between class and instance, for example StrEnum is
# Iterable[str] for an instance, but Iterable[StrEnum] for a class object.
if not mypy.subtypes.is_subtype(
inst, erase_typevars(temp), ignore_pos_arg_names=True
):
continue
# This exception matches the one in subtypes.py, see PR #14121 for context.
if member == "__call__" and instance.type.is_metaclass():
continue
res.extend(infer_constraints(temp, inst, self.direction))
if mypy.subtypes.IS_SETTABLE in mypy.subtypes.get_member_flags(member, protocol):
# Settable members are invariant, add opposite constraints
Expand Down
12 changes: 11 additions & 1 deletion mypy/join.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,7 +355,8 @@ def visit_unpack_type(self, t: UnpackType) -> UnpackType:

def visit_parameters(self, t: Parameters) -> ProperType:
if isinstance(self.s, Parameters):
if len(t.arg_types) != len(self.s.arg_types):
if not is_similar_params(t, self.s):
# TODO: it would be prudent to return [*object, **object] instead of Any.
return self.default(self.s)
from mypy.meet import meet_types

Expand Down Expand Up @@ -724,6 +725,15 @@ def is_similar_callables(t: CallableType, s: CallableType) -> bool:
)


def is_similar_params(t: Parameters, s: Parameters) -> bool:
# This matches the logic in is_similar_callables() above.
return (
len(t.arg_types) == len(s.arg_types)
and t.min_args == s.min_args
and (t.var_arg() is not None) == (s.var_arg() is not None)
)


def update_callable_ids(c: CallableType, ids: list[TypeVarId]) -> CallableType:
tv_map = {}
tvs = []
Expand Down
9 changes: 7 additions & 2 deletions mypy/subtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1719,11 +1719,16 @@ def _incompatible(left_arg: FormalArgument | None, right_arg: FormalArgument | N
):
return False

if trivial_suffix:
# For trivial right suffix we *only* check that every non-star right argument
# has a valid match on the left.
return True

# Phase 1c: Check var args. Right has an infinite series of optional positional
# arguments. Get all further positional args of left, and make sure
# they're more general than the corresponding member in right.
# TODO: are we handling UnpackType correctly here?
if right_star is not None and not trivial_suffix:
if right_star is not None:
# Synthesize an anonymous formal argument for the right
right_by_position = right.try_synthesizing_arg_from_vararg(None)
assert right_by_position is not None
Expand All @@ -1750,7 +1755,7 @@ def _incompatible(left_arg: FormalArgument | None, right_arg: FormalArgument | N
# Phase 1d: Check kw args. Right has an infinite series of optional named
# arguments. Get all further named args of left, and make sure
# they're more general than the corresponding member in right.
if right_star2 is not None and not trivial_suffix:
if right_star2 is not None:
right_names = {name for name in right.arg_names if name is not None}
left_only_names = set()
for name, kind in zip(left.arg_names, left.arg_kinds):
Expand Down
22 changes: 22 additions & 0 deletions test-data/unit/check-enum.test
Original file line number Diff line number Diff line change
Expand Up @@ -2394,3 +2394,25 @@ def do_check(value: E) -> None:

[builtins fixtures/primitives.pyi]
[typing fixtures/typing-full.pyi]

[case testStrEnumClassCorrectIterable]
from enum import StrEnum
from typing import Type, TypeVar

class Choices(StrEnum):
LOREM = "lorem"
IPSUM = "ipsum"

var = list(Choices)
reveal_type(var) # N: Revealed type is "builtins.list[__main__.Choices]"

e: type[StrEnum]
reveal_type(list(e)) # N: Revealed type is "builtins.list[enum.StrEnum]"

T = TypeVar("T", bound=StrEnum)
def list_vals(e: Type[T]) -> list[T]:
reveal_type(list(e)) # N: Revealed type is "builtins.list[T`-1]"
return list(e)

reveal_type(list_vals(Choices)) # N: Revealed type is "builtins.list[__main__.Choices]"
[builtins fixtures/enum.pyi]
30 changes: 26 additions & 4 deletions test-data/unit/check-functions.test
Original file line number Diff line number Diff line change
Expand Up @@ -105,16 +105,38 @@ if int():
h = h

[case testSubtypingFunctionsDoubleCorrespondence]
def l(x) -> None: ...
def r(__x, *, x) -> None: ...
r = l # E: Incompatible types in assignment (expression has type "Callable[[Any], None]", variable has type "Callable[[Any, NamedArg(Any, 'x')], None]")

[case testSubtypingFunctionsDoubleCorrespondenceNamedOptional]
def l(x) -> None: ...
def r(__, *, x) -> None: ...
r = l # E: Incompatible types in assignment (expression has type "Callable[[Any], None]", variable has type "Callable[[Any, NamedArg(Any, 'x')], None]")
def r(__x, *, x = 1) -> None: ...
r = l # E: Incompatible types in assignment (expression has type "Callable[[Any], None]", variable has type "Callable[[Any, DefaultNamedArg(Any, 'x')], None]")

[case testSubtypingFunctionsRequiredLeftArgNotPresent]
[case testSubtypingFunctionsDoubleCorrespondenceBothNamedOptional]
def l(x = 1) -> None: ...
def r(__x, *, x = 1) -> None: ...
r = l # E: Incompatible types in assignment (expression has type "Callable[[Any], None]", variable has type "Callable[[Any, DefaultNamedArg(Any, 'x')], None]")

[case testSubtypingFunctionsTrivialSuffixRequired]
def l(__x) -> None: ...
def r(x, *args, **kwargs) -> None: ...

r = l # E: Incompatible types in assignment (expression has type "Callable[[Any], None]", variable has type "Callable[[Arg(Any, 'x'), VarArg(Any), KwArg(Any)], None]")
[builtins fixtures/dict.pyi]

[case testSubtypingFunctionsTrivialSuffixOptional]
def l(__x = 1) -> None: ...
def r(x = 1, *args, **kwargs) -> None: ...

r = l # E: Incompatible types in assignment (expression has type "Callable[[DefaultArg(Any)], None]", variable has type "Callable[[DefaultArg(Any, 'x'), VarArg(Any), KwArg(Any)], None]")
[builtins fixtures/dict.pyi]

[case testSubtypingFunctionsRequiredLeftArgNotPresent]
def l(x, y) -> None: ...
def r(x) -> None: ...
r = l # E: Incompatible types in assignment (expression has type "Callable[[Any, Any], None]", variable has type "Callable[[Any], None]")
r = l # E: Incompatible types in assignment (expression has type "Callable[[Any, Any], None]", variable has type "Callable[[Any], None]")

[case testSubtypingFunctionsImplicitNames]
from typing import Any
Expand Down
27 changes: 27 additions & 0 deletions test-data/unit/check-parameter-specification.test
Original file line number Diff line number Diff line change
Expand Up @@ -2532,3 +2532,30 @@ class GenericWrapper(Generic[P]):
def contains(c: Callable[P, None], *args: P.args, **kwargs: P.kwargs) -> None: ...
def inherits(*args: P.args, **kwargs: P.kwargs) -> None: ...
[builtins fixtures/paramspec.pyi]

[case testCallbackProtocolClassObjectParamSpec]
from typing import Any, Callable, Protocol, Optional, Generic
from typing_extensions import ParamSpec

P = ParamSpec("P")

class App: ...

class MiddlewareFactory(Protocol[P]):
def __call__(self, app: App, /, *args: P.args, **kwargs: P.kwargs) -> App:
...

class Capture(Generic[P]): ...

class ServerErrorMiddleware(App):
def __init__(
self,
app: App,
handler: Optional[str] = None,
debug: bool = False,
) -> None: ...

def fn(f: MiddlewareFactory[P]) -> Capture[P]: ...

reveal_type(fn(ServerErrorMiddleware)) # N: Revealed type is "__main__.Capture[[handler: Union[builtins.str, None] =, debug: builtins.bool =]]"
[builtins fixtures/paramspec.pyi]
9 changes: 8 additions & 1 deletion test-data/unit/fixtures/enum.pyi
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Minimal set of builtins required to work with Enums
from typing import TypeVar, Generic
from typing import TypeVar, Generic, Iterator, Sequence, overload, Iterable

T = TypeVar('T')

Expand All @@ -13,6 +13,13 @@ class tuple(Generic[T]):
class int: pass
class str:
def __len__(self) -> int: pass
def __iter__(self) -> Iterator[str]: pass

class dict: pass
class ellipsis: pass

class list(Sequence[T]):
@overload
def __init__(self) -> None: pass
@overload
def __init__(self, x: Iterable[T]) -> None: pass