Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

gh-125710: [Enum] fix hashable<->nonhashable comparisons for member values #125735

Merged
merged 2 commits into from
Oct 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 20 additions & 6 deletions Lib/enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,8 @@ def __set_name__(self, enum_class, member_name):
# to the map, and by-value lookups for this value will be
# linear.
enum_class._value2member_map_.setdefault(value, enum_member)
if value not in enum_class._hashable_values_:
enum_class._hashable_values_.append(value)
except TypeError:
# keep track of the value in a list so containment checks are quick
enum_class._unhashable_values_.append(value)
Expand Down Expand Up @@ -538,7 +540,8 @@ def __new__(metacls, cls, bases, classdict, *, boundary=None, _simple=False, **k
classdict['_member_names_'] = []
classdict['_member_map_'] = {}
classdict['_value2member_map_'] = {}
classdict['_unhashable_values_'] = []
classdict['_hashable_values_'] = [] # for comparing with non-hashable types
classdict['_unhashable_values_'] = [] # e.g. frozenset() with set()
classdict['_unhashable_values_map_'] = {}
classdict['_member_type_'] = member_type
# now set the __repr__ for the value
Expand Down Expand Up @@ -748,7 +751,10 @@ def __contains__(cls, value):
try:
return value in cls._value2member_map_
except TypeError:
return value in cls._unhashable_values_
return (
value in cls._unhashable_values_ # both structures are lists
or value in cls._hashable_values_
)

def __delattr__(cls, attr):
# nicer error message when someone tries to delete an attribute
Expand Down Expand Up @@ -1166,8 +1172,11 @@ def __new__(cls, value):
pass
except TypeError:
# not there, now do long search -- O(n) behavior
for name, values in cls._unhashable_values_map_.items():
if value in values:
for name, unhashable_values in cls._unhashable_values_map_.items():
if value in unhashable_values:
return cls[name]
for name, member in cls._member_map_.items():
if value == member._value_:
return cls[name]
# still not found -- verify that members exist, in-case somebody got here mistakenly
# (such as via super when trying to override __new__)
Expand Down Expand Up @@ -1233,6 +1242,7 @@ def _add_value_alias_(self, value):
# to the map, and by-value lookups for this value will be
# linear.
cls._value2member_map_.setdefault(value, self)
cls._hashable_values_.append(value)
except TypeError:
# keep track of the value in a list so containment checks are quick
cls._unhashable_values_.append(value)
Expand Down Expand Up @@ -1763,6 +1773,7 @@ def convert_class(cls):
body['_member_names_'] = member_names = []
body['_member_map_'] = member_map = {}
body['_value2member_map_'] = value2member_map = {}
body['_hashable_values_'] = hashable_values = []
body['_unhashable_values_'] = unhashable_values = []
body['_unhashable_values_map_'] = {}
body['_member_type_'] = member_type = etype._member_type_
Expand Down Expand Up @@ -1826,7 +1837,7 @@ def convert_class(cls):
contained = value2member_map.get(member._value_)
except TypeError:
contained = None
if member._value_ in unhashable_values:
if member._value_ in unhashable_values or member.value in hashable_values:
for m in enum_class:
if m._value_ == member._value_:
contained = m
Expand All @@ -1846,6 +1857,7 @@ def convert_class(cls):
else:
enum_class._add_member_(name, member)
value2member_map[value] = member
hashable_values.append(value)
if _is_single_bit(value):
# not a multi-bit alias, record in _member_names_ and _flag_mask_
member_names.append(name)
Expand Down Expand Up @@ -1882,7 +1894,7 @@ def convert_class(cls):
contained = value2member_map.get(member._value_)
except TypeError:
contained = None
if member._value_ in unhashable_values:
if member._value_ in unhashable_values or member._value_ in hashable_values:
for m in enum_class:
if m._value_ == member._value_:
contained = m
Expand All @@ -1908,6 +1920,8 @@ def convert_class(cls):
# to the map, and by-value lookups for this value will be
# linear.
enum_class._value2member_map_.setdefault(value, member)
if value not in hashable_values:
hashable_values.append(value)
except TypeError:
# keep track of the value in a list so containment checks are quick
enum_class._unhashable_values_.append(value)
Expand Down
7 changes: 7 additions & 0 deletions Lib/test/test_enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -3460,6 +3460,13 @@ def test_empty_names(self):
self.assertRaisesRegex(TypeError, '.int. object is not iterable', Enum, 'bad_enum', names=0)
self.assertRaisesRegex(TypeError, '.int. object is not iterable', Enum, 'bad_enum', 0, type=int)

def test_nonhashable_matches_hashable(self): # issue 125710
class Directions(Enum):
DOWN_ONLY = frozenset({"sc"})
UP_ONLY = frozenset({"cs"})
UNRESTRICTED = frozenset({"sc", "cs"})
self.assertIs(Directions({"sc"}), Directions.DOWN_ONLY)


class TestOrder(unittest.TestCase):
"test usage of the `_order_` attribute"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
[Enum] fix hashable<->nonhashable comparisons for member values
Loading