From 75b02ad1435aeb8f8d6f659327bca95005ebaee3 Mon Sep 17 00:00:00 2001 From: huangweiwu Date: Sun, 18 Feb 2024 16:47:18 +0800 Subject: [PATCH] Make instances weak-referenceable (#335) The logic closely follows that of internal attribute dictionaries and involves an additional pointer to store a weak reference list. --- include/nanobind/nb_attr.h | 1 + include/nanobind/nb_class.h | 8 +++ src/nb_type.cpp | 111 +++++++++++++++++++++++++++--------- tests/test_classes.cpp | 11 ++++ tests/test_classes.py | 22 +++++++ 5 files changed, 127 insertions(+), 26 deletions(-) diff --git a/include/nanobind/nb_attr.h b/include/nanobind/nb_attr.h index 9e0017748..74bad134f 100644 --- a/include/nanobind/nb_attr.h +++ b/include/nanobind/nb_attr.h @@ -48,6 +48,7 @@ template struct call_guard { }; struct dynamic_attr {}; +struct weak_referenceable {}; struct is_method {}; struct is_implicit {}; struct is_operator {}; diff --git a/include/nanobind/nb_class.h b/include/nanobind/nb_class.h index d2e408e6a..75a6b8f78 100644 --- a/include/nanobind/nb_class.h +++ b/include/nanobind/nb_class.h @@ -49,6 +49,9 @@ enum class type_flags : uint32_t { /// If so, type_data::keep_shared_from_this_alive is also set. has_shared_from_this = (1 << 12), + /// Instances of this type can be referenced by 'weakref' + is_weak_referenceable = (1 << 13), + // Six more flag bits available (13 through 18) without needing // a larger reorganization }; @@ -98,6 +101,7 @@ struct type_data { bool (*keep_shared_from_this_alive)(PyObject *) noexcept; #if defined(Py_LIMITED_API) size_t dictoffset; + size_t weaklistoffset; #endif }; @@ -152,6 +156,10 @@ NB_INLINE void type_extra_apply(type_init_data &t, dynamic_attr) { t.flags |= (uint32_t) type_flags::has_dynamic_attr; } +NB_INLINE void type_extra_apply(type_data & t, weak_referenceable) { + t.flags |= (uint32_t)type_flags::is_weak_referenceable; +} + template NB_INLINE void type_extra_apply(type_init_data &t, supplement) { static_assert(std::is_trivially_default_constructible_v, diff --git a/src/nb_type.cpp b/src/nb_type.cpp index a4b3610c5..454ac74e1 100644 --- a/src/nb_type.cpp +++ b/src/nb_type.cpp @@ -18,23 +18,35 @@ NAMESPACE_BEGIN(detail) static PyObject **nb_dict_ptr(PyObject *self) { PyTypeObject *tp = Py_TYPE(self); -#if !defined(Py_LIMITED_API) - return (PyObject **) ((uint8_t *) self + tp->tp_dictoffset); +#if defined(Py_LIMITED_API) + Py_ssize_t dictoffset = nb_type_data(tp)->dictoffset; #else - return (PyObject **) ((uint8_t *) self + nb_type_data(tp)->dictoffset); + Py_ssize_t dictoffset = tp->tp_dictoffset; #endif + return dictoffset ? (PyObject **) ((uint8_t *) self + dictoffset) : nullptr; +} + +static PyObject **nb_weaklist_ptr(PyObject *self) { + PyTypeObject *tp = Py_TYPE(self); +#if defined(Py_LIMITED_API) + Py_ssize_t weaklistoffset = nb_type_data(tp)->weaklistoffset; +#else + Py_ssize_t weaklistoffset = tp->tp_weaklistoffset; +#endif + return weaklistoffset ? (PyObject **) ((uint8_t *) self + weaklistoffset) : nullptr; } static int inst_clear(PyObject *self) { - PyObject *&dict = *nb_dict_ptr(self); - Py_CLEAR(dict); + PyObject **dict = nb_dict_ptr(self); + if (dict) + Py_CLEAR(*dict); return 0; } static int inst_traverse(PyObject *self, visitproc visit, void *arg) { - PyObject *&dict = *nb_dict_ptr(self); + PyObject **dict = nb_dict_ptr(self); if (dict) - Py_VISIT(dict); + Py_VISIT(*dict); #if PY_VERSION_HEX >= 0x03090000 Py_VISIT(Py_TYPE(self)); #endif @@ -183,12 +195,24 @@ static void inst_dealloc(PyObject *self) { if (NB_UNLIKELY(gc)) { PyObject_GC_UnTrack(self); - if (t->flags & (uint32_t) type_flags::has_dynamic_attr) { - PyObject *&dict = *nb_dict_ptr(self); - Py_CLEAR(dict); + if (t->flags & (uint32_t)type_flags::has_dynamic_attr) { + PyObject **dict = nb_dict_ptr(self); + if (dict) + Py_CLEAR(*dict); } } + if (t->flags & (uint32_t)type_flags::is_weak_referenceable && + nb_weaklist_ptr(self) != nullptr) { +#if defined(PYPY_VERSION) + PyObject **weaklist = nb_weaklist_ptr(self); + if (weaklist) + Py_CLEAR(*weaklist); +#else + PyObject_ClearWeakRefs(self); +#endif + } + nb_inst *inst = (nb_inst *) self; void *p = inst_ptr(inst); @@ -765,14 +789,15 @@ static PyTypeObject *nb_type_tp(size_t supplement) noexcept { /// Called when a C++ type is bound via nb::class_<> PyObject *nb_type_new(const type_init_data *t) noexcept { - bool has_doc = t->flags & (uint32_t) type_init_flags::has_doc, - has_base = t->flags & (uint32_t) type_init_flags::has_base, - has_base_py = t->flags & (uint32_t) type_init_flags::has_base_py, - has_type_slots = t->flags & (uint32_t) type_init_flags::has_type_slots, - has_supplement = t->flags & (uint32_t) type_init_flags::has_supplement, - has_dynamic_attr = t->flags & (uint32_t) type_flags::has_dynamic_attr, - intrusive_ptr = t->flags & (uint32_t) type_flags::intrusive_ptr, - has_shared_from_this = t->flags & (uint32_t) type_flags::has_shared_from_this; + bool has_doc = t->flags & (uint32_t) type_init_flags::has_doc, + has_base = t->flags & (uint32_t) type_init_flags::has_base, + has_base_py = t->flags & (uint32_t) type_init_flags::has_base_py, + has_type_slots = t->flags & (uint32_t) type_init_flags::has_type_slots, + has_supplement = t->flags & (uint32_t) type_init_flags::has_supplement, + has_dynamic_attr = t->flags & (uint32_t) type_flags::has_dynamic_attr, + is_weak_referenceable = t->flags & (uint32_t) type_flags::is_weak_referenceable, + intrusive_ptr = t->flags & (uint32_t) type_flags::intrusive_ptr, + has_shared_from_this = t->flags & (uint32_t) type_flags::has_shared_from_this; str name(t->name), qualname = name; object modname; @@ -834,6 +859,9 @@ PyObject *nb_type_new(const type_init_data *t) noexcept { if (tb->flags & (uint32_t) type_flags::has_dynamic_attr) has_dynamic_attr = true; + if (tb->flags & (uint32_t) type_flags::is_weak_referenceable) + is_weak_referenceable = true; + /* Handle a corner case (base class larger than derived class) which can arise when extending trampoline base classes */ size_t base_basicsize = sizeof(nb_inst) + tb->size; @@ -853,7 +881,7 @@ PyObject *nb_type_new(const type_init_data *t) noexcept { nb_total_slots = nb_type_max_slots + nb_extra_slots + 1; - PyMemberDef members[2] { }; + PyMemberDef members[3] { }; PyType_Slot slots[nb_total_slots], *s = slots; PyType_Spec spec = { /* .name = */ name_copy, @@ -898,15 +926,20 @@ PyObject *nb_type_new(const type_init_data *t) noexcept { for (PyType_Slot *ts = slots; ts != s; ++ts) has_traverse |= ts->slot == Py_tp_traverse; - if (has_dynamic_attr) { - // realign to sizeof(void*), add one pointer + Py_ssize_t dictoffset = 0, weaklistoffset = 0; + int num_members = 0; + + // realign to sizeof(void*) if needed + if (has_dynamic_attr || is_weak_referenceable) basicsize = (basicsize + ptr_size - 1) / ptr_size * ptr_size; + + if (has_dynamic_attr) { + dictoffset = (Py_ssize_t) basicsize; basicsize += ptr_size; - members[0] = PyMemberDef{ "__dictoffset__", T_PYSSIZET, - (Py_ssize_t) (basicsize - ptr_size), READONLY, - nullptr }; - *s++ = { Py_tp_members, (void *) members }; + members[num_members] = PyMemberDef{ "__dictoffset__", T_PYSSIZET, + dictoffset, READONLY, nullptr }; + ++num_members; // Install GC traverse and clear routines if not inherited/overridden if (!has_traverse) { @@ -914,10 +947,29 @@ PyObject *nb_type_new(const type_init_data *t) noexcept { *s++ = { Py_tp_clear, (void *) inst_clear }; has_traverse = true; } + spec.basicsize = (int) basicsize; + } + + if (is_weak_referenceable) { + weaklistoffset = (Py_ssize_t) basicsize; + basicsize += ptr_size; + members[num_members] = PyMemberDef{ "__weaklistoffset__", T_PYSSIZET, + weaklistoffset, READONLY, nullptr }; + ++num_members; + + // Install GC traverse and clear routines if not inherited/overridden + if (!has_traverse) { + *s++ = { Py_tp_traverse, (void *) inst_traverse }; + *s++ = { Py_tp_clear, (void *) inst_clear }; + has_traverse = true; + } spec.basicsize = (int) basicsize; } + if (num_members > 0) + *s++ = { Py_tp_members, (void*)members }; + if (has_traverse) spec.flags |= Py_TPFLAGS_HAVE_GC; @@ -955,7 +1007,14 @@ PyObject *nb_type_new(const type_init_data *t) noexcept { if (has_dynamic_attr) { to->flags |= (uint32_t) type_flags::has_dynamic_attr; #if defined(Py_LIMITED_API) - to->dictoffset = (size_t) (basicsize - ptr_size); + to->dictoffset = dictoffset; + #endif + } + + if (is_weak_referenceable) { + to->flags |= (uint32_t)type_flags::is_weak_referenceable; + #if defined(Py_LIMITED_API) + to->weaklistoffset = weaklistoffset; #endif } diff --git a/tests/test_classes.cpp b/tests/test_classes.cpp index 6efa5d649..513799994 100644 --- a/tests/test_classes.cpp +++ b/tests/test_classes.cpp @@ -90,6 +90,10 @@ struct Wrapper { std::shared_ptr value; }; +struct StructWithWeakrefs : Struct { }; + +struct StructWithWeakrefsAndDynamicAttrs : Struct { }; + int wrapper_tp_traverse(PyObject *self, visitproc visit, void *arg) { Wrapper *w = nb::inst_ptr(self); @@ -554,4 +558,11 @@ NB_MODULE(test_classes_ext, m) { "get_incrementing_struct_value", [](IncrementingStruct &s) { return new Struct(s.i + 100); }, nb::keep_alive<0, 1>()); + + nb::class_(m, "StructWithWeakrefs", nb::weak_referenceable()) + .def(nb::init()); + + nb::class_(m, "StructWithWeakrefsAndDynamicAttrs", + nb::weak_referenceable(), nb::dynamic_attr()) + .def(nb::init()); } diff --git a/tests/test_classes.py b/tests/test_classes.py index 84efb41f3..0c37dc88e 100644 --- a/tests/test_classes.py +++ b/tests/test_classes.py @@ -759,3 +759,25 @@ def test41_implicit_conversion_keep_alive(): assert d1 == [] assert d2 == [5] assert d3 == [106, 6] + +def test42_weak_references(): + import weakref + import gc + import time + o = t.StructWithWeakrefs(42) + w = weakref.ref(o) + assert w() is o + del o + gc.collect() + gc.collect() + assert w() is None + + p = t.StructWithWeakrefsAndDynamicAttrs(43) + p.a_dynamic_attr = 101 + w = weakref.ref(p) + assert w() is p + assert w().a_dynamic_attr == 101 + del p + gc.collect() + gc.collect() + assert w() is None