Skip to content

Commit

Permalink
resolve conflict
Browse files Browse the repository at this point in the history
  • Loading branch information
huangweiwu committed Oct 23, 2023
2 parents b515b1f + 3000536 commit da50335
Show file tree
Hide file tree
Showing 5 changed files with 119 additions and 21 deletions.
1 change: 1 addition & 0 deletions include/nanobind/nb_attr.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ template <typename... Ts> struct call_guard {
};

struct dynamic_attr {};
struct weak_referenceable {};
struct is_method {};
struct is_implicit {};
struct is_operator {};
Expand Down
7 changes: 7 additions & 0 deletions include/nanobind/nb_class.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ enum class type_flags : uint32_t {
/// If so, type_data::keep_shared_from_this_alive is also set.
has_shared_from_this = (1 << 12),

is_weak_referenceable = (1 << 13),

// Six more flag bits available (13 through 18) without needing
// a larger reorganization
};
Expand Down Expand Up @@ -98,6 +100,7 @@ struct type_data {
bool (*keep_shared_from_this_alive)(PyObject *) noexcept;
#if defined(Py_LIMITED_API)
size_t dictoffset;
size_t weaklistoffset;
#endif
};

Expand Down Expand Up @@ -152,6 +155,10 @@ NB_INLINE void type_extra_apply(type_init_data &t, dynamic_attr) {
t.flags |= (uint32_t) type_flags::has_dynamic_attr;
}

NB_INLINE void type_extra_apply(type_data & t, weak_referenceable) {
t.flags |= (uint32_t)type_flags::is_weak_referenceable;
}

template <typename T>
NB_INLINE void type_extra_apply(type_init_data &t, supplement<T>) {
static_assert(std::is_trivially_default_constructible_v<T>,
Expand Down
99 changes: 78 additions & 21 deletions src/nb_type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,23 +18,35 @@ NAMESPACE_BEGIN(detail)

static PyObject **nb_dict_ptr(PyObject *self) {
PyTypeObject *tp = Py_TYPE(self);
#if !defined(Py_LIMITED_API)
return (PyObject **) ((uint8_t *) self + tp->tp_dictoffset);
#if defined(Py_LIMITED_API)
Py_ssize_t dictoffset = nb_type_data(tp)->dictoffset;
#else
Py_ssize_t dictoffset = tp->tp_dictoffset;
#endif
return dictoffset ? (PyObject**)((uint8_t*)self + dictoffset) : nullptr;
}

static PyObject **nb_weaklist_ptr(PyObject * self) {
PyTypeObject *tp = Py_TYPE(self);
#if defined(Py_LIMITED_API)
Py_ssize_t weaklistoffset = nb_type_data(tp)->weaklistoffset;
#else
return (PyObject **) ((uint8_t *) self + nb_type_data(tp)->dictoffset);
Py_ssize_t weaklistoffset = tp->tp_weaklistoffset;
#endif
return weaklistoffset ? (PyObject**)((uint8_t*)self + weaklistoffset) : nullptr;
}

static int inst_clear(PyObject *self) {
PyObject *&dict = *nb_dict_ptr(self);
Py_CLEAR(dict);
PyObject **dict = nb_dict_ptr(self);
if (dict)
Py_CLEAR(*dict);
return 0;
}

static int inst_traverse(PyObject *self, visitproc visit, void *arg) {
PyObject *&dict = *nb_dict_ptr(self);
PyObject **dict = nb_dict_ptr(self);
if (dict)
Py_VISIT(dict);
Py_VISIT(*dict);
#if PY_VERSION_HEX >= 0x03090000
Py_VISIT(Py_TYPE(self));
#endif
Expand Down Expand Up @@ -183,12 +195,24 @@ static void inst_dealloc(PyObject *self) {
if (NB_UNLIKELY(gc)) {
PyObject_GC_UnTrack(self);

if (t->flags & (uint32_t) type_flags::has_dynamic_attr) {
PyObject *&dict = *nb_dict_ptr(self);
Py_CLEAR(dict);
if (t->flags & (uint32_t)type_flags::has_dynamic_attr) {
PyObject **dict = nb_dict_ptr(self);
if (dict)
Py_CLEAR(*dict);
}
}

if (t->flags & (uint32_t)type_flags::is_weak_referenceable &&
nb_weaklist_ptr(self) != nullptr) {
#if defined(PYPY_VERSION)
PyObject **weaklist = nb_weaklist_ptr(self);
if (weaklist)
Py_CLEAR(*weaklist);
#else
PyObject_ClearWeakRefs(self);
#endif
}

nb_inst *inst = (nb_inst *) self;
void *p = inst_ptr(inst);

Expand Down Expand Up @@ -771,6 +795,7 @@ PyObject *nb_type_new(const type_init_data *t) noexcept {
has_type_slots = t->flags & (uint32_t) type_init_flags::has_type_slots,
has_supplement = t->flags & (uint32_t) type_init_flags::has_supplement,
has_dynamic_attr = t->flags & (uint32_t) type_flags::has_dynamic_attr,
is_weak_referenceable = t->flags & (uint32_t)type_flags::is_weak_referenceable,
intrusive_ptr = t->flags & (uint32_t) type_flags::intrusive_ptr,
has_shared_from_this = t->flags & (uint32_t) type_flags::has_shared_from_this;

Expand Down Expand Up @@ -834,6 +859,9 @@ PyObject *nb_type_new(const type_init_data *t) noexcept {
if (tb->flags & (uint32_t) type_flags::has_dynamic_attr)
has_dynamic_attr = true;

if (tb->flags & (uint32_t)type_flags::is_weak_referenceable)
is_weak_referenceable = true;

/* Handle a corner case (base class larger than derived class)
which can arise when extending trampoline base classes */
size_t base_basicsize = sizeof(nb_inst) + tb->size;
Expand All @@ -853,7 +881,7 @@ PyObject *nb_type_new(const type_init_data *t) noexcept {
nb_total_slots = nb_type_max_slots +
nb_extra_slots + 1;

PyMemberDef members[2] { };
PyMemberDef members[3] { };
PyType_Slot slots[nb_total_slots], *s = slots;
PyType_Spec spec = {
/* .name = */ name_copy,
Expand Down Expand Up @@ -898,26 +926,48 @@ PyObject *nb_type_new(const type_init_data *t) noexcept {
for (PyType_Slot *ts = slots; ts != s; ++ts)
has_traverse |= ts->slot == Py_tp_traverse;

if (has_dynamic_attr) {
// realign to sizeof(void*), add one pointer
Py_ssize_t dictoffset = 0;
Py_ssize_t weaklistoffset = 0;
int num_members = 0;
if (has_dynamic_attr || is_weak_referenceable) {
basicsize = (basicsize + ptr_size - 1) / ptr_size * ptr_size;
}
if (has_dynamic_attr) {
dictoffset = (Py_ssize_t)basicsize;
basicsize += ptr_size;

members[0] = PyMemberDef{ "__dictoffset__", T_PYSSIZET,
(Py_ssize_t) (basicsize - ptr_size), READONLY,
nullptr };
*s++ = { Py_tp_members, (void *) members };
members[num_members] = PyMemberDef{ "__dictoffset__", T_PYSSIZET,
dictoffset, READONLY,
nullptr };
++num_members;

// Install GC traverse and clear routines if not inherited/overridden
if (!has_traverse) {
*s++ = { Py_tp_traverse, (void *) inst_traverse };
*s++ = { Py_tp_clear, (void *) inst_clear };
*s++ = { Py_tp_traverse, (void*)inst_traverse };
*s++ = { Py_tp_clear, (void*)inst_clear };
has_traverse = true;
}
spec.basicsize = (int)basicsize;
}
if (is_weak_referenceable) {
weaklistoffset = (Py_ssize_t)basicsize;
basicsize += ptr_size;

spec.basicsize = (int) basicsize;
members[num_members] = PyMemberDef{ "__weaklistoffset__", T_PYSSIZET,
weaklistoffset, READONLY,
nullptr };
++num_members;
if (!has_traverse) {
*s++ = { Py_tp_traverse, (void*)inst_traverse };
*s++ = { Py_tp_clear, (void*)inst_clear };
has_traverse = true;
}
spec.basicsize = (int)basicsize;
}

if (num_members > 0)
*s++ = { Py_tp_members, (void*)members };

if (has_traverse)
spec.flags |= Py_TPFLAGS_HAVE_GC;

Expand Down Expand Up @@ -955,7 +1005,14 @@ PyObject *nb_type_new(const type_init_data *t) noexcept {
if (has_dynamic_attr) {
to->flags |= (uint32_t) type_flags::has_dynamic_attr;
#if defined(Py_LIMITED_API)
to->dictoffset = (size_t) (basicsize - ptr_size);
to->dictoffset = dictoffset;
#endif
}

if (is_weak_referenceable) {
to->flags |= (uint32_t)type_flags::is_weak_referenceable;
#if defined(Py_LIMITED_API)
to->weaklistoffset = weaklistoffset;
#endif
}

Expand Down
11 changes: 11 additions & 0 deletions tests/test_classes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,10 @@ struct Wrapper {
std::shared_ptr<Wrapper> value;
};

struct StructWithWeakrefs : Struct { };

struct StructWithWeakrefsAndDynamicAttrs : Struct { };

int wrapper_tp_traverse(PyObject *self, visitproc visit, void *arg) {
Wrapper *w = nb::inst_ptr<Wrapper>(self);

Expand Down Expand Up @@ -554,4 +558,11 @@ NB_MODULE(test_classes_ext, m) {
"get_incrementing_struct_value",
[](IncrementingStruct &s) { return new Struct(s.i + 100); },
nb::keep_alive<0, 1>());

nb::class_<StructWithWeakrefs, Struct>(m, "StructWithWeakrefs", nb::weak_referenceable())
.def(nb::init<int>());

nb::class_<StructWithWeakrefsAndDynamicAttrs, Struct>(m, "StructWithWeakrefsAndDynamicAttrs",
nb::weak_referenceable(), nb::dynamic_attr())
.def(nb::init<int>());
}
22 changes: 22 additions & 0 deletions tests/test_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -741,3 +741,25 @@ def test41_implicit_conversion_keep_alive():
assert d1 == []
assert d2 == [5]
assert d3 == [106, 6]

def test42_weak_references():
import weakref
import gc
import time
o = t.StructWithWeakrefs(42)
w = weakref.ref(o)
assert w() is o
del o
gc.collect()
gc.collect()
assert w() is None

p = t.StructWithWeakrefsAndDynamicAttrs(43)
p.a_dynamic_attr = 101
w = weakref.ref(p)
assert w() is p
assert w().a_dynamic_attr == 101
del p
gc.collect()
gc.collect()
assert w() is None

0 comments on commit da50335

Please sign in to comment.