diff --git a/src/nb_internals.h b/src/nb_internals.h index b168fdb7..1ffdf0dd 100644 --- a/src/nb_internals.h +++ b/src/nb_internals.h @@ -13,6 +13,9 @@ #include #include +#if defined(NB_FREE_THREADED) +#include +#endif #include #include #include @@ -236,6 +239,32 @@ struct NB_SHARD_ALIGNMENT nb_shard { #endif }; + +/** + * Wraps a std::atomic if free-threading is enabled, otherwise a raw value. + */ +#if defined(NB_FREE_THREADED) +template +struct nb_maybe_atomic { + nb_maybe_atomic(T v) : value(v) {} + + std::atomic value; + T load_acquire() { return value.load(std::memory_order_acquire); } + T load_relaxed() { return value.load(std::memory_order_relaxed); } + void store_release(T w) { value.store(w, std::memory_order_release); } +}; +#else +template +struct nb_maybe_atomic { + nb_maybe_atomic(T v) : value(v) {} + + T value; + T load_acquire() { return value; } + T load_relaxed() { return value; } + void store_release(T w) { value = w; } +}; +#endif + /** * `nb_internals` is the central data structure storing information related to * function/type bindings and instances. Separate nanobind extensions within the @@ -318,7 +347,7 @@ struct nb_internals { PyTypeObject *nb_func, *nb_method, *nb_bound_method; /// Property variant for static attributes (created on demand) - PyTypeObject *nb_static_property = nullptr; + nb_maybe_atomic nb_static_property = nullptr; descrsetfunc nb_static_property_descr_set = nullptr; #if defined(NB_FREE_THREADED) @@ -328,7 +357,7 @@ struct nb_internals { #endif /// N-dimensional array wrapper (created on demand) - PyTypeObject *nb_ndarray = nullptr; + nb_maybe_atomic nb_ndarray = nullptr; #if defined(NB_FREE_THREADED) nb_shard *shards = nullptr; diff --git a/src/nb_ndarray.cpp b/src/nb_ndarray.cpp index cd118d76..74bcd8fd 100644 --- a/src/nb_ndarray.cpp +++ b/src/nb_ndarray.cpp @@ -175,11 +175,11 @@ static PyMethodDef nb_ndarray_members[] = { static PyTypeObject *nd_ndarray_tp() noexcept { nb_internals *internals_ = internals; - PyTypeObject *tp = internals_->nb_ndarray; + PyTypeObject *tp = internals_->nb_ndarray.load_acquire(); if (NB_UNLIKELY(!tp)) { lock_internals guard(internals_); - tp = internals_->nb_ndarray; + tp = internals_->nb_ndarray.load_relaxed(); if (tp) return tp; @@ -209,7 +209,7 @@ static PyTypeObject *nd_ndarray_tp() noexcept { tp->tp_as_buffer->bf_releasebuffer = nb_ndarray_releasebuffer; #endif - internals_->nb_ndarray = tp; + internals_->nb_ndarray.store_release(tp); } return tp; diff --git a/src/nb_static_property.cpp b/src/nb_static_property.cpp index 0c8bc639..51be6610 100644 --- a/src/nb_static_property.cpp +++ b/src/nb_static_property.cpp @@ -30,12 +30,12 @@ static int nb_static_property_descr_set(PyObject *self, PyObject *obj, PyObject PyTypeObject *nb_static_property_tp() noexcept { nb_internals *internals_ = internals; - PyTypeObject *tp = internals_->nb_static_property; + PyTypeObject *tp = internals_->nb_static_property.load_acquire(); if (NB_UNLIKELY(!tp)) { lock_internals guard(internals_); - tp = internals_->nb_static_property; + tp = internals_->nb_static_property.load_relaxed(); if (tp) return tp; @@ -65,8 +65,8 @@ PyTypeObject *nb_static_property_tp() noexcept { tp = (PyTypeObject *) PyType_FromSpec(&spec); check(tp, "nb_static_property type creation failed!"); - internals_->nb_static_property = tp; internals_->nb_static_property_descr_set = nb_static_property_descr_set; + internals_->nb_static_property.store_release(tp); } return tp; diff --git a/src/nb_type.cpp b/src/nb_type.cpp index a536416a..94e07d56 100644 --- a/src/nb_type.cpp +++ b/src/nb_type.cpp @@ -512,7 +512,7 @@ int nb_type_setattro(PyObject* obj, PyObject* name, PyObject* value) { #endif if (cur) { - PyTypeObject *tp = int_p->nb_static_property; + PyTypeObject *tp = int_p->nb_static_property.load_acquire(); // For type.static_prop = value, call the setter. // For type.static_prop = another_static_prop, replace the descriptor. if (Py_TYPE(cur) == tp && Py_TYPE(value) != tp) {