Skip to content

Commit

Permalink
pythongh-118033: Fix __weakref__ not set for generic dataclasses
Browse files Browse the repository at this point in the history
  • Loading branch information
sobolevn committed Apr 19, 2024
1 parent 4605a19 commit 7c3d881
Show file tree
Hide file tree
Showing 3 changed files with 102 additions and 0 deletions.
5 changes: 5 additions & 0 deletions Lib/dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -1201,6 +1201,10 @@ def _get_slots(cls):
match cls.__dict__.get('__slots__'):
# A class which does not define __slots__ at all is equivalent
# to a class defining __slots__ = ('__dict__', '__weakref__')
case None if getattr(cls, '__weakrefoffset__', -1) == 0:
# Except for special cases, inheriting from them do not set
# any slots at all:
yield from ()
case None:
yield from ('__dict__', '__weakref__')
case str(slot):
Expand Down Expand Up @@ -1228,6 +1232,7 @@ def _add_slots(cls, is_frozen, weakref_slot):
inherited_slots = set(
itertools.chain.from_iterable(map(_get_slots, cls.__mro__[1:-1]))
)
print(inherited_slots)
# The slots for our class. Remove slots from our base classes. Add
# '__weakref__' if weakref_slot was given, unless it is already present.
cls_dict["__slots__"] = tuple(
Expand Down
95 changes: 95 additions & 0 deletions Lib/test/test_dataclasses/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3515,8 +3515,103 @@ class A:
class B(A):
pass

self.assertEqual(B.__slots__, ())
B()

def test_dataclass_derived_generic(self):
T = typing.TypeVar('T')

@dataclass(slots=True, weakref_slot=True)
class A(typing.Generic[T]):
pass
self.assertEqual(A.__slots__, ('__weakref__',))
self.assertTrue(A.__weakref__)
A()

@dataclass(slots=True, weakref_slot=True)
class B[T2]:
pass
self.assertEqual(B.__slots__, ('__weakref__',))
self.assertTrue(B.__weakref__)
B()

def test_dataclass_derived_generic_from_base(self):
T = typing.TypeVar('T')

class RawBase: ...

@dataclass(slots=True, weakref_slot=True)
class C1(typing.Generic[T], RawBase):
pass
self.assertEqual(C1.__slots__, ())
self.assertTrue(C1.__weakref__)
C1()
@dataclass(slots=True, weakref_slot=True)
class C2(RawBase, typing.Generic[T]):
pass
self.assertEqual(C2.__slots__, ())
self.assertTrue(C2.__weakref__)
C2()

@dataclass(slots=True, weakref_slot=True)
class D[T2](RawBase):
pass
self.assertEqual(D.__slots__, ())
self.assertTrue(D.__weakref__)
D()

def test_dataclass_derived_generic_from_slotted_base(self):
T = typing.TypeVar('T')

class WithSlots:
__slots__ = ('a', 'b')

@dataclass(slots=True, weakref_slot=True)
class E1(WithSlots, Generic[T]):
pass
self.assertEqual(E1.__slots__, ('__weakref__',))
self.assertTrue(E1.__weakref__)
E1()
@dataclass(slots=True, weakref_slot=True)
class E2(Generic[T], WithSlots):
pass
self.assertEqual(E2.__slots__, ('__weakref__',))
self.assertTrue(E2.__weakref__)
E2()

@dataclass(slots=True, weakref_slot=True)
class F[T2](WithSlots):
pass
self.assertEqual(F.__slots__, ('__weakref__',))
self.assertTrue(F.__weakref__)
F()

def test_dataclass_derived_generic_from_slotted_base(self):
T = typing.TypeVar('T')

class WithWeakrefSlot:
__slots__ = ('__weakref__',)

@dataclass(slots=True, weakref_slot=True)
class G1(WithWeakrefSlot, Generic[T]):
pass
self.assertEqual(G1.__slots__, ())
self.assertTrue(G1.__weakref__)
G1()
@dataclass(slots=True, weakref_slot=True)
class G2(Generic[T], WithWeakrefSlot):
pass
self.assertEqual(G2.__slots__, ())
self.assertTrue(G2.__weakref__)
G2()

@dataclass(slots=True, weakref_slot=True)
class H[T2](WithWeakrefSlot):
pass
self.assertEqual(H.__slots__, ())
self.assertTrue(H.__weakref__)
H()


class TestDescriptors(unittest.TestCase):
def test_set_name(self):
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Fix :func:`dataclasses.dataclass` not creating a ``__weakref__`` slot when
subclassing :class:`typing.Generic`.

0 comments on commit 7c3d881

Please sign in to comment.