Skip to content

Commit

Permalink
fix tests by fixing the protocol member listing
Browse files Browse the repository at this point in the history
  • Loading branch information
jorenham committed Feb 26, 2024
1 parent e0b2991 commit e7c2511
Showing 1 changed file with 29 additions and 14 deletions.
43 changes: 29 additions & 14 deletions tests/test_protocols.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import functools
import inspect
from types import ModuleType
from typing import Protocol, cast

Expand All @@ -23,6 +25,7 @@ def _is_runtime_protocol(cls: type) -> bool:
return _is_protocol(cls) and getattr(cls, '_is_runtime_protocol', False)


@functools.cache
def _get_protocol_members(cls: type) -> frozenset[str]:
"""
A variant of `typing_extensions.get_protocol_members()` that doesn't
Expand All @@ -33,7 +36,9 @@ def _get_protocol_members(cls: type) -> frozenset[str]:
assert _is_protocol(cls)

module = cls.__module__
members = cls.__annotations__.keys() | {
annotations = cls.__annotations__

members = annotations.keys() | {
name for name, v in vars(cls).items()
if (
callable(v) and (
Expand Down Expand Up @@ -63,19 +68,29 @@ def _get_protocol_members(cls: type) -> frozenset[str]:
and v.fget.__module__ == module
)
}
if not members:
# no idea why this happens, probably something to do with inheritance..
# ... investigating this can of worms any further will physically harm
# my very soul, or at least, what's left of it at this point.
# ...
# anyway, this hack here is plagiarized from the (often incorrect)
# `typing_extensions.get_protocol_members`.
# Maybe the `typing.get_protocol_members` that's coming in 3.13 will
# won't be as broken. I have little hope though...
members = cast(
set[str],
getattr(cls, '__protocol_attrs__', None) or set(),
)

# this hack here is plagiarized from the (often incorrect)
# `typing_extensions.get_protocol_members`.
# Maybe the `typing.get_protocol_member`s` that's coming in 3.13 will
# won't be as broken. I have little hope though...
members |= cast(
set[str],
getattr(cls, '__protocol_attrs__', None) or set(),
)

# sometimes __protocol_attrs__ hallicunates some non-existing dunders.
# the `getattr_static` avoids potential descriptor magic
members = {
member for member in members
if member in annotations
or inspect.getattr_static(cls, member) is not None
# or getattr(cls, member) is not None
}

# also include any of the parents
for supercls in cls.mro()[1:]:
if _is_protocol(supercls):
members |= _get_protocol_members(supercls)

return frozenset(members)

Expand Down

0 comments on commit e7c2511

Please sign in to comment.