diff --git a/include/nanobind/nb_attr.h b/include/nanobind/nb_attr.h index 57db1399..055cab57 100644 --- a/include/nanobind/nb_attr.h +++ b/include/nanobind/nb_attr.h @@ -48,6 +48,7 @@ template struct call_guard { }; struct dynamic_attr {}; +struct weak_referenceable {}; struct is_method {}; struct is_implicit {}; struct is_operator {}; diff --git a/include/nanobind/nb_class.h b/include/nanobind/nb_class.h index d2e408e6..1ae63036 100644 --- a/include/nanobind/nb_class.h +++ b/include/nanobind/nb_class.h @@ -49,6 +49,8 @@ enum class type_flags : uint32_t { /// If so, type_data::keep_shared_from_this_alive is also set. has_shared_from_this = (1 << 12), + is_weak_referenceable = (1 << 13), + // Six more flag bits available (13 through 18) without needing // a larger reorganization }; @@ -98,6 +100,7 @@ struct type_data { bool (*keep_shared_from_this_alive)(PyObject *) noexcept; #if defined(Py_LIMITED_API) size_t dictoffset; + size_t weaklistoffset; #endif }; @@ -152,6 +155,10 @@ NB_INLINE void type_extra_apply(type_init_data &t, dynamic_attr) { t.flags |= (uint32_t) type_flags::has_dynamic_attr; } +NB_INLINE void type_extra_apply(type_data & t, weak_referenceable) { + t.flags |= (uint32_t)type_flags::is_weak_referenceable; +} + template NB_INLINE void type_extra_apply(type_init_data &t, supplement) { static_assert(std::is_trivially_default_constructible_v, diff --git a/src/nb_type.cpp b/src/nb_type.cpp index 39c05ccc..16b3d75c 100644 --- a/src/nb_type.cpp +++ b/src/nb_type.cpp @@ -18,23 +18,35 @@ NAMESPACE_BEGIN(detail) static PyObject **nb_dict_ptr(PyObject *self) { PyTypeObject *tp = Py_TYPE(self); -#if !defined(Py_LIMITED_API) - return (PyObject **) ((uint8_t *) self + tp->tp_dictoffset); +#if defined(Py_LIMITED_API) + Py_ssize_t dictoffset = nb_type_data(tp)->dictoffset; +#else + Py_ssize_t dictoffset = tp->tp_dictoffset; +#endif + return dictoffset ? (PyObject**)((uint8_t*)self + dictoffset) : nullptr; +} + +static PyObject **nb_weaklist_ptr(PyObject * self) { + PyTypeObject *tp = Py_TYPE(self); +#if defined(Py_LIMITED_API) + Py_ssize_t weaklistoffset = nb_type_data(tp)->weaklistoffset; #else - return (PyObject **) ((uint8_t *) self + nb_type_data(tp)->dictoffset); + Py_ssize_t weaklistoffset = tp->tp_weaklistoffset; #endif + return weaklistoffset ? (PyObject**)((uint8_t*)self + weaklistoffset) : nullptr; } static int inst_clear(PyObject *self) { - PyObject *&dict = *nb_dict_ptr(self); - Py_CLEAR(dict); + PyObject **dict = nb_dict_ptr(self); + if (dict) + Py_CLEAR(*dict); return 0; } static int inst_traverse(PyObject *self, visitproc visit, void *arg) { - PyObject *&dict = *nb_dict_ptr(self); + PyObject **dict = nb_dict_ptr(self); if (dict) - Py_VISIT(dict); + Py_VISIT(*dict); #if PY_VERSION_HEX >= 0x03090000 Py_VISIT(Py_TYPE(self)); #endif @@ -183,12 +195,24 @@ static void inst_dealloc(PyObject *self) { if (NB_UNLIKELY(gc)) { PyObject_GC_UnTrack(self); - if (t->flags & (uint32_t) type_flags::has_dynamic_attr) { - PyObject *&dict = *nb_dict_ptr(self); - Py_CLEAR(dict); + if (t->flags & (uint32_t)type_flags::has_dynamic_attr) { + PyObject **dict = nb_dict_ptr(self); + if (dict) + Py_CLEAR(*dict); } } + if (t->flags & (uint32_t)type_flags::is_weak_referenceable && + nb_weaklist_ptr(self) != nullptr) { +#if defined(PYPY_VERSION) + PyObject **weaklist = nb_weaklist_ptr(self); + if (weaklist) + Py_CLEAR(*weaklist); +#else + PyObject_ClearWeakRefs(self); +#endif + } + nb_inst *inst = (nb_inst *) self; void *p = inst_ptr(inst); @@ -771,6 +795,7 @@ PyObject *nb_type_new(const type_init_data *t) noexcept { has_type_slots = t->flags & (uint32_t) type_init_flags::has_type_slots, has_supplement = t->flags & (uint32_t) type_init_flags::has_supplement, has_dynamic_attr = t->flags & (uint32_t) type_flags::has_dynamic_attr, + is_weak_referenceable = t->flags & (uint32_t)type_flags::is_weak_referenceable, intrusive_ptr = t->flags & (uint32_t) type_flags::intrusive_ptr, has_shared_from_this = t->flags & (uint32_t) type_flags::has_shared_from_this; @@ -834,6 +859,9 @@ PyObject *nb_type_new(const type_init_data *t) noexcept { if (tb->flags & (uint32_t) type_flags::has_dynamic_attr) has_dynamic_attr = true; + if (tb->flags & (uint32_t)type_flags::is_weak_referenceable) + is_weak_referenceable = true; + /* Handle a corner case (base class larger than derived class) which can arise when extending trampoline base classes */ size_t base_basicsize = sizeof(nb_inst) + tb->size; @@ -853,7 +881,7 @@ PyObject *nb_type_new(const type_init_data *t) noexcept { nb_total_slots = nb_type_max_slots + nb_extra_slots + 1; - PyMemberDef members[2] { }; + PyMemberDef members[3] { }; PyType_Slot slots[nb_total_slots], *s = slots; PyType_Spec spec = { /* .name = */ name_copy, @@ -898,26 +926,48 @@ PyObject *nb_type_new(const type_init_data *t) noexcept { for (PyType_Slot *ts = slots; ts != s; ++ts) has_traverse |= ts->slot == Py_tp_traverse; - if (has_dynamic_attr) { - // realign to sizeof(void*), add one pointer + Py_ssize_t dictoffset = 0; + Py_ssize_t weaklistoffset = 0; + int num_members = 0; + if (has_dynamic_attr || is_weak_referenceable) { basicsize = (basicsize + ptr_size - 1) / ptr_size * ptr_size; + } + if (has_dynamic_attr) { + dictoffset = (Py_ssize_t)basicsize; basicsize += ptr_size; - members[0] = PyMemberDef{ "__dictoffset__", T_PYSSIZET, - (Py_ssize_t) (basicsize - ptr_size), READONLY, - nullptr }; - *s++ = { Py_tp_members, (void *) members }; + members[num_members] = PyMemberDef{ "__dictoffset__", T_PYSSIZET, + dictoffset, READONLY, + nullptr }; + ++num_members; // Install GC traverse and clear routines if not inherited/overridden if (!has_traverse) { - *s++ = { Py_tp_traverse, (void *) inst_traverse }; - *s++ = { Py_tp_clear, (void *) inst_clear }; + *s++ = { Py_tp_traverse, (void*)inst_traverse }; + *s++ = { Py_tp_clear, (void*)inst_clear }; has_traverse = true; } + spec.basicsize = (int)basicsize; + } + if (is_weak_referenceable) { + weaklistoffset = (Py_ssize_t)basicsize; + basicsize += ptr_size; - spec.basicsize = (int) basicsize; + members[num_members] = PyMemberDef{ "__weaklistoffset__", T_PYSSIZET, + weaklistoffset, READONLY, + nullptr }; + ++num_members; + if (!has_traverse) { + *s++ = { Py_tp_traverse, (void*)inst_traverse }; + *s++ = { Py_tp_clear, (void*)inst_clear }; + has_traverse = true; + } + spec.basicsize = (int)basicsize; } + if (num_members > 0) + *s++ = { Py_tp_members, (void*)members }; + if (has_traverse) spec.flags |= Py_TPFLAGS_HAVE_GC; @@ -955,7 +1005,14 @@ PyObject *nb_type_new(const type_init_data *t) noexcept { if (has_dynamic_attr) { to->flags |= (uint32_t) type_flags::has_dynamic_attr; #if defined(Py_LIMITED_API) - to->dictoffset = (size_t) (basicsize - ptr_size); + to->dictoffset = dictoffset; + #endif + } + + if (is_weak_referenceable) { + to->flags |= (uint32_t)type_flags::is_weak_referenceable; + #if defined(Py_LIMITED_API) + to->weaklistoffset = weaklistoffset; #endif } diff --git a/tests/test_classes.cpp b/tests/test_classes.cpp index 633aa171..b8ff7ff2 100644 --- a/tests/test_classes.cpp +++ b/tests/test_classes.cpp @@ -90,6 +90,10 @@ struct Wrapper { std::shared_ptr value; }; +struct StructWithWeakrefs : Struct { }; + +struct StructWithWeakrefsAndDynamicAttrs : Struct { }; + int wrapper_tp_traverse(PyObject *self, visitproc visit, void *arg) { Wrapper *w = nb::inst_ptr(self); @@ -554,4 +558,11 @@ NB_MODULE(test_classes_ext, m) { "get_incrementing_struct_value", [](IncrementingStruct &s) { return new Struct(s.i + 100); }, nb::keep_alive<0, 1>()); + + nb::class_(m, "StructWithWeakrefs", nb::weak_referenceable()) + .def(nb::init()); + + nb::class_(m, "StructWithWeakrefsAndDynamicAttrs", + nb::weak_referenceable(), nb::dynamic_attr()) + .def(nb::init()); } diff --git a/tests/test_classes.py b/tests/test_classes.py index eb0e9672..fa2b4e28 100644 --- a/tests/test_classes.py +++ b/tests/test_classes.py @@ -741,3 +741,25 @@ def test41_implicit_conversion_keep_alive(): assert d1 == [] assert d2 == [5] assert d3 == [106, 6] + +def test42_weak_references(): + import weakref + import gc + import time + o = t.StructWithWeakrefs(42) + w = weakref.ref(o) + assert w() is o + del o + gc.collect() + gc.collect() + assert w() is None + + p = t.StructWithWeakrefsAndDynamicAttrs(43) + p.a_dynamic_attr = 101 + w = weakref.ref(p) + assert w() is p + assert w().a_dynamic_attr == 101 + del p + gc.collect() + gc.collect() + assert w() is None