From 2e191cbb6c8ab72a22061466ca9e652dffbe0a2f Mon Sep 17 00:00:00 2001 From: Joshua Oreman Date: Mon, 4 Nov 2024 13:48:31 -0700 Subject: [PATCH 1/2] Add support for custom call policies --- docs/api_core.rst | 111 +++++++++++++++++++++++++ docs/changelog.rst | 10 +++ docs/functions.rst | 7 ++ include/nanobind/nb_attr.h | 64 ++++++++++++--- include/nanobind/nb_func.h | 20 +++-- src/nb_func.cpp | 32 ++++---- tests/CMakeLists.txt | 2 + tests/test_callbacks.cpp | 137 +++++++++++++++++++++++++++++++ tests/test_callbacks.py | 58 +++++++++++++ tests/test_functions.cpp | 90 ++++++++++++++++++++ tests/test_functions.py | 75 +++++++++++++++++ tests/test_functions_ext.pyi.ref | 4 + 12 files changed, 573 insertions(+), 37 deletions(-) create mode 100644 tests/test_callbacks.cpp create mode 100644 tests/test_callbacks.py diff --git a/docs/api_core.rst b/docs/api_core.rst index 7cf15947..2b9d2d70 100644 --- a/docs/api_core.rst +++ b/docs/api_core.rst @@ -1899,6 +1899,117 @@ parameter of :cpp:func:`module_::def`, :cpp:func:`class_::def`, Analogous to :cpp:struct:`for_getter`, but for setters. +.. cpp:struct:: template call_policy + + Request that custom logic be inserted around each call to the + bound function, by calling ``Policy::precall(args, nargs, cleanup)`` before + Python-to-C++ argument conversion, and ``Policy::postcall(args, nargs, ret)`` + after C++-to-Python return value conversion. + + If multiple call policy annotations are provided for the same function, then + their precall and postcall hooks will both execute left-to-right according + to the order in which the annotations were specified when binding the + function. + + The :cpp:struct:`nb::call_guard\() ` annotation + should be preferred over ``call_policy`` unless the wrapper logic + depends on the function arguments or return value. 
+ If both annotations are combined, then + :cpp:struct:`nb::call_guard\() ` always executes on + the "inside" (closest to the bound function, after argument + conversions and before return value conversion) regardless of its + position in the function annotations list. + + Your ``Policy`` class must define two static member functions: + + .. cpp:function:: static void precall(PyObject **args, size_t nargs, detail::cleanup_list *cleanup); + + A hook that will be invoked before calling the bound function. More + precisely, it is called after any :ref:`argument locks ` + have been obtained, but before the Python arguments are converted to C++ + objects for the function call. + + This hook may access or modify the function arguments using the + *args* array, which holds borrowed references in one-to-one + correspondence with the C++ arguments of the bound function. If + the bound function is a method, then ``args[0]`` is its *self* + argument. *nargs* is the number of function arguments. It is actually + passed as ``std::integral_constant()``, so you can + match on that type if you want to do compile-time checks with it. + + The *cleanup* list may be used as it is used in type casters, + to cause some Python object references to be released at some point + after the bound function completes. (If the bound function is part + of an overload set, the cleanup list isn't released until all overloads + have been tried.) + + ``precall()`` may choose to throw a C++ exception. If it does, + it will preempt execution of the bound function, and the + exception will be treated as if the bound function had thrown it. + + .. cpp:function:: static void postcall(PyObject **args, size_t nargs, handle ret); + + A hook that will be invoked after calling the bound function and + converting its return value to a Python object, but only if the + bound function returned normally. 
+ + *args* stores the Python object arguments, with the same semantics + as in ``precall()``, except that arguments that participated in + implicit conversions will have had their ``args[i]`` pointer updated + to reflect the new Python object that the implicit conversion produced. + *nargs* is the number of arguments, passed as a ``std::integral_constant`` + in the same way as for ``precall()``. + + *ret* is the bound function's return value. If the bound function returned + normally but its C++ return value could not be converted to a Python + object, then ``postcall()`` will execute with *ret* set to null, + and the Python error indicator might or might not be set to explain why. + + If the bound function did not return normally -- either because its + Python object arguments couldn't be converted to the appropriate C++ + types, or because the C++ function threw an exception -- then + ``postcall()`` **will not execute**. If you need some cleanup logic to + run even in such cases, your ``precall()`` can add a capsule object to the + cleanup list; its destructor will run eventually, but with no promises + as to when. A :cpp:struct:`nb::call_guard <call_guard>` might be a + better choice. + + ``postcall()`` may choose to throw a C++ exception. If it does, + the result of the wrapped function will be destroyed, + and the exception will be raised in its place, as if the bound function + had thrown it just before returning. + + Here is an example policy to demonstrate. + ``nb::call_policy<returns_references_to<I>>()`` behaves like + :cpp:class:`nb::keep_alive\<0, I\>() <keep_alive>`, except that the + return value is treated as a list of objects rather than a single one. + + .. 
code-block:: cpp + + template <size_t I> + struct returns_references_to { + static void precall(PyObject **, size_t, nb::detail::cleanup_list *) {} + + template <size_t N> + static void postcall(PyObject **args, + std::integral_constant<size_t, N>, + nb::handle ret) { + static_assert(I > 0 && I <= N, + "I in returns_references_to<I> must be in the " + "range [1, number of C++ function arguments]"); + if (!nb::isinstance<nb::sequence>(ret)) { + throw std::runtime_error("return value should be a sequence"); + } + for (nb::handle nurse : ret) { + nb::detail::keep_alive(nurse.ptr(), args[I - 1]); + } + } + }; + + For a more complex example (binding an object that uses trivially-copyable + callbacks), see ``tests/test_callbacks.cpp`` in the nanobind source + distribution. + .. _class_binding_annotations: Class binding annotations diff --git a/docs/changelog.rst b/docs/changelog.rst index bde4da7e..b98820e2 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -23,6 +23,16 @@ Version TBD (unreleased) ``std::variant`` might cast a Python object wrapping a ``T`` to a ``U`` if there is an implicit conversion available from ``T`` to ``U``. +- Added a function annotation :cpp:class:`nb::call_policy\() + ` which supports custom function wrapping logic, + calling ``Policy::precall()`` before the bound function and + ``Policy::postcall()`` after. This is a low-level interface intended + for advanced users. The precall and postcall hooks are able to + observe the Python objects forming the function arguments and return + value, and the precall hook can change the arguments. See the linked + documentation for more details, important caveats, and an example policy. + (PR `#767 `. Construction occurs left to right, while destruction occurs in reverse. +If your wrapping needs are more complex than +:cpp:class:`nb::call_guard\() ` can handle, it is also +possible to define a custom "call policy", which can observe or modify the +Python object arguments and observe the return value. 
See the documentation of +:cpp:class:`nb::call_policy\ ` for details. + + .. _higher_order_adv: Higher-order functions diff --git a/include/nanobind/nb_attr.h b/include/nanobind/nb_attr.h index 0476824e..99f628b0 100644 --- a/include/nanobind/nb_attr.h +++ b/include/nanobind/nb_attr.h @@ -151,6 +151,8 @@ struct sig { struct is_getter { }; +template struct call_policy final {}; + NAMESPACE_BEGIN(literals) constexpr arg operator"" _a(const char *name, size_t) { return arg(name); } NAMESPACE_END(literals) @@ -186,8 +188,9 @@ enum class func_flags : uint32_t { return_ref = (1 << 15), /// Does this overload specify a custom function signature (for docstrings, typing) has_signature = (1 << 16), - /// Does this function have one or more nb::keep_alive() annotations? - has_keep_alive = (1 << 17) + /// Does this function potentially modify the elements of the PyObject*[] array + /// representing its arguments? (nb::keep_alive() or call_policy annotations) + can_mutate_args = (1 << 17) }; enum cast_flags : uint8_t { @@ -384,12 +387,17 @@ NB_INLINE void func_extra_apply(F &, call_guard, size_t &) {} template NB_INLINE void func_extra_apply(F &f, nanobind::keep_alive, size_t &) { - f.flags |= (uint32_t) func_flags::has_keep_alive; + f.flags |= (uint32_t) func_flags::can_mutate_args; +} + +template +NB_INLINE void func_extra_apply(F &f, call_policy, size_t &) { + f.flags |= (uint32_t) func_flags::can_mutate_args; } template struct func_extra_info { using call_guard = void; - static constexpr bool keep_alive = false; + static constexpr bool pre_post_hooks = false; static constexpr size_t nargs_locked = 0; }; @@ -397,7 +405,7 @@ template struct func_extra_info : func_extra_info { }; template -struct func_extra_info, Ts...> : func_extra_info { +struct func_extra_info, Ts...> : func_extra_info { static_assert(std::is_same_v::call_guard, void>, "call_guard<> can only be specified once!"); using call_guard = nanobind::call_guard; @@ -405,29 +413,59 @@ struct func_extra_info, Ts...> : 
func_extra_info struct func_extra_info, Ts...> : func_extra_info { - static constexpr bool keep_alive = true; + static constexpr bool pre_post_hooks = true; +}; + +template +struct func_extra_info, Ts...> : func_extra_info { + static constexpr bool pre_post_hooks = true; }; template -struct func_extra_info : func_extra_info { +struct func_extra_info : func_extra_info { static constexpr size_t nargs_locked = 1 + func_extra_info::nargs_locked; }; template -struct func_extra_info : func_extra_info { +struct func_extra_info : func_extra_info { static constexpr size_t nargs_locked = 1 + func_extra_info::nargs_locked; }; -template -NB_INLINE void process_keep_alive(PyObject **, PyObject *, T *) { } +NB_INLINE void process_precall(PyObject **, size_t, detail::cleanup_list *, void *) { } + +template +NB_INLINE void +process_precall(PyObject **args, std::integral_constant nargs, + detail::cleanup_list *cleanup, call_policy *) { + Policy::precall(args, nargs, cleanup); +} + +NB_INLINE void process_postcall(PyObject **, size_t, PyObject *, void *) { } -template +template NB_INLINE void -process_keep_alive(PyObject **args, PyObject *result, - nanobind::keep_alive *) { +process_postcall(PyObject **args, std::integral_constant, + PyObject *result, nanobind::keep_alive *) { + static_assert(Nurse != Patient, + "keep_alive with the same argument as both nurse and patient " + "doesn't make sense"); + static_assert(Nurse <= NArgs && Patient <= NArgs, + "keep_alive template parameters must be in the range " + "[0, number of C++ function arguments]"); keep_alive(Nurse == 0 ? result : args[Nurse - 1], Patient == 0 ? 
result : args[Patient - 1]); } +template +NB_INLINE void +process_postcall(PyObject **args, std::integral_constant nargs, + PyObject *result, call_policy *) { + // result_guard avoids leaking a reference to the return object + // if postcall throws an exception + object result_guard = steal(result); + Policy::postcall(args, nargs, handle(result)); + result_guard.release(); +} + NAMESPACE_END(detail) NAMESPACE_END(NB_NAMESPACE) diff --git a/include/nanobind/nb_func.h b/include/nanobind/nb_func.h index da420349..10eb3994 100644 --- a/include/nanobind/nb_func.h +++ b/include/nanobind/nb_func.h @@ -11,14 +11,14 @@ NAMESPACE_BEGIN(NB_NAMESPACE) NAMESPACE_BEGIN(detail) template -bool from_python_keep_alive(Caster &c, PyObject **args, uint8_t *args_flags, - cleanup_list *cleanup, size_t index) { +bool from_python_remember_conv(Caster &c, PyObject **args, uint8_t *args_flags, + cleanup_list *cleanup, size_t index) { size_t size_before = cleanup->size(); if (!c.from_python(args[index], args_flags[index], cleanup)) return false; // If an implicit conversion took place, update the 'args' array so that - // the keep_alive annotation can later process this change + // any keep_alive annotation or postcall hook can be aware of this change size_t size_after = cleanup->size(); if (size_after != size_before) args[index] = (*cleanup)[size_after - 1]; @@ -244,9 +244,11 @@ NB_INLINE PyObject *func_create(Func &&func, Return (*)(Args...), } #endif - if constexpr (Info::keep_alive) { - if ((!from_python_keep_alive(in.template get(), args, - args_flags, cleanup, Is) || ...)) + if constexpr (Info::pre_post_hooks) { + std::integral_constant nargs_c; + (process_precall(args, nargs_c, cleanup, (Extra *) nullptr), ...); + if ((!from_python_remember_conv(in.template get(), args, + args_flags, cleanup, Is) || ...)) return NB_NEXT_OVERLOAD; } else { if ((!in.template get().from_python(args[Is], args_flags[Is], @@ -276,8 +278,10 @@ NB_INLINE PyObject *func_create(Func &&func, Return (*)(Args...), 
#endif } - if constexpr (Info::keep_alive) - (process_keep_alive(args, result, (Extra *) nullptr), ...); + if constexpr (Info::pre_post_hooks) { + std::integral_constant nargs_c; + (process_postcall(args, nargs_c, result, (Extra *) nullptr), ...); + } return result; }; diff --git a/src/nb_func.cpp b/src/nb_func.cpp index 08c4eefa..21883a62 100644 --- a/src/nb_func.cpp +++ b/src/nb_func.cpp @@ -196,21 +196,21 @@ PyObject *nb_func_new(const void *in_) noexcept { func_data_prelim<0> *f = (func_data_prelim<0> *) in_; arg_data *args_in = std::launder((arg_data *) f->args); - bool has_scope = f->flags & (uint32_t) func_flags::has_scope, - has_name = f->flags & (uint32_t) func_flags::has_name, - has_args = f->flags & (uint32_t) func_flags::has_args, - has_var_args = f->flags & (uint32_t) func_flags::has_var_kwargs, - has_var_kwargs = f->flags & (uint32_t) func_flags::has_var_args, - has_keep_alive = f->flags & (uint32_t) func_flags::has_keep_alive, - has_doc = f->flags & (uint32_t) func_flags::has_doc, - has_signature = f->flags & (uint32_t) func_flags::has_signature, - is_implicit = f->flags & (uint32_t) func_flags::is_implicit, - is_method = f->flags & (uint32_t) func_flags::is_method, - return_ref = f->flags & (uint32_t) func_flags::return_ref, - is_constructor = false, - is_init = false, - is_new = false, - is_setstate = false; + bool has_scope = f->flags & (uint32_t) func_flags::has_scope, + has_name = f->flags & (uint32_t) func_flags::has_name, + has_args = f->flags & (uint32_t) func_flags::has_args, + has_var_args = f->flags & (uint32_t) func_flags::has_var_kwargs, + has_var_kwargs = f->flags & (uint32_t) func_flags::has_var_args, + can_mutate_args = f->flags & (uint32_t) func_flags::can_mutate_args, + has_doc = f->flags & (uint32_t) func_flags::has_doc, + has_signature = f->flags & (uint32_t) func_flags::has_signature, + is_implicit = f->flags & (uint32_t) func_flags::is_implicit, + is_method = f->flags & (uint32_t) func_flags::is_method, + return_ref = f->flags & 
(uint32_t) func_flags::return_ref, + is_constructor = false, + is_init = false, + is_new = false, + is_setstate = false; PyObject *name = nullptr; PyObject *func_prev = nullptr; @@ -292,7 +292,7 @@ PyObject *nb_func_new(const void *in_) noexcept { maybe_make_immortal((PyObject *) func); // Check if the complex dispatch loop is needed - bool complex_call = has_keep_alive || has_var_kwargs || has_var_args || + bool complex_call = can_mutate_args || has_var_kwargs || has_var_args || f->nargs >= NB_MAXARGS_SIMPLE; if (has_args) { diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 9e8fcb4e..c5cfef42 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -62,6 +62,7 @@ endif() set(TEST_NAMES functions + callbacks classes holders stl @@ -137,6 +138,7 @@ target_link_libraries(test_inter_module_2_ext PRIVATE inter_module) set(TEST_FILES common.py + test_callbacks.py test_classes.py test_eigen.py test_enum.py diff --git a/tests/test_callbacks.cpp b/tests/test_callbacks.cpp new file mode 100644 index 00000000..f178baed --- /dev/null +++ b/tests/test_callbacks.cpp @@ -0,0 +1,137 @@ +// This is an example of using nb::call_policy to support binding an +// object that takes non-owning callbacks. Since the callbacks can't +// directly keep a Python object alive (they're trivially copyable), we +// maintain a sideband structure to manage the lifetimes. + +#include +#include +#include + +#include +#include + +namespace nb = nanobind; + +// The callback type accepted by the object, which we assume we can't change. +// It's trivially copyable, so it can't directly keep a Python object alive. 
+struct callback { + void *context; + void (*func)(void *context, int arg); + + void operator()(int arg) const { (*func)(context, arg); } + bool operator==(const callback& other) const { + return context == other.context && func == other.func; + } +}; + +// An object that uses these callbacks, which we want to write bindings for +class publisher { + public: + void subscribe(callback cb) { cbs.push_back(cb); } + void unsubscribe(callback cb) { + cbs.erase(std::remove(cbs.begin(), cbs.end(), cb), cbs.end()); + } + void emit(int arg) const { for (auto cb : cbs) cb(arg); } + private: + std::vector cbs; +}; + +template <> struct nanobind::detail::type_caster { + static void wrap_call(void *context, int arg) { + borrow((PyObject *) context)(arg); + } + bool from_python(handle src, uint8_t, cleanup_list*) noexcept { + if (!isinstance(src)) return false; + value = {(void *) src.ptr(), &wrap_call}; + return true; + } + static handle from_cpp(callback cb, rv_policy policy, cleanup_list*) noexcept { + if (cb.func == &wrap_call) + return handle((PyObject *) cb.context).inc_ref(); + if (policy == rv_policy::none) + return handle(); + return cpp_function(cb, policy).release(); + } + NB_TYPE_CASTER(callback, const_name("Callable[[int], None]")) +}; + +nb::dict cb_registry() { + return nb::cast( + nb::module_::import_("test_callbacks_ext").attr("registry")); +} + +struct callback_data { + struct py_hash { + size_t operator()(const nb::object& obj) const { return nb::hash(obj); } + }; + struct py_eq { + bool operator()(const nb::object& a, const nb::object& b) const { + return a.equal(b); + } + }; + std::unordered_set subscribers; +}; + +callback_data& callbacks_for(nb::handle publisher) { + auto registry = cb_registry(); + nb::weakref key(publisher, registry.attr("__delitem__")); + if (nb::handle value = PyDict_GetItem(registry.ptr(), key.ptr())) { + return nb::cast(value); + } + nb::object new_data = nb::cast(callback_data{}); + registry[key] = new_data; + return 
nb::cast(new_data); +} + +struct cb_policy_common { + using TwoArgs = std::integral_constant; + static void precall(PyObject **args, TwoArgs, + nb::detail::cleanup_list *cleanup) { + nb::handle self = args[0], cb = args[1]; + auto& cbs = callbacks_for(self); + auto it = cbs.subscribers.find(nb::borrow(cb)); + if (it != cbs.subscribers.end() && !it->is(cb)) { + // A callback is already subscribed that is + // equal-but-not-identical to the one passed in. + // Adjust args to refer to that one, to work around + // the fact that the C++ object does not understand py-equality. + args[1] = it->ptr(); + + // This ensures that the normalized callback won't be + // immediately destroyed if it's removed from the registry + // in the unsubscribe postcall hook. Such destruction could + // result in a use-after-free if you have other postcall hooks + // or keep_alives that try to inspect the function args. + // It's not strictly necessary if each arg is inspected by + // only one call policy or keep_alive. 
+ cleanup->append(it->inc_ref().ptr()); + } + } +}; + +struct subscribe_policy : cb_policy_common { + static void postcall(PyObject **args, TwoArgs, nb::handle) { + nb::handle self = args[0], cb = args[1]; + callbacks_for(self).subscribers.insert(nb::borrow(cb)); + } +}; + +struct unsubscribe_policy : cb_policy_common { + static void postcall(PyObject **args, TwoArgs, nb::handle) { + nb::handle self = args[0], cb = args[1]; + callbacks_for(self).subscribers.erase(nb::borrow(cb)); + } +}; + +NB_MODULE(test_callbacks_ext, m) { + m.attr("registry") = nb::dict(); + nb::class_(m, "callback_data") + .def_ro("subscribers", &callback_data::subscribers); + nb::class_(m, "publisher", nb::is_weak_referenceable()) + .def(nb::init<>()) + .def("subscribe", &publisher::subscribe, + nb::call_policy()) + .def("unsubscribe", &publisher::unsubscribe, + nb::call_policy()) + .def("emit", &publisher::emit); +} diff --git a/tests/test_callbacks.py b/tests/test_callbacks.py new file mode 100644 index 00000000..5f6796c3 --- /dev/null +++ b/tests/test_callbacks.py @@ -0,0 +1,58 @@ +import test_callbacks_ext as t +import gc + + +def test_callbacks(): + pub1 = t.publisher() + pub2 = t.publisher() + record = [] + + def sub1(x): + record.append(x + 10) + + def sub2(x): + record.append(x + 20) + + pub1.subscribe(sub1) + pub2.subscribe(sub2) + for pub in (pub1, pub2): + pub.subscribe(record.append) + + pub1.emit(1) + assert record == [11, 1] + del record[:] + + pub2.emit(2) + assert record == [22, 2] + del record[:] + + pub1_w, pub2_w = t.registry.keys() # weakrefs to pub1, pub2 + assert pub1_w() is pub1 + assert pub2_w() is pub2 + assert t.registry[pub1_w].subscribers == {sub1, record.append} + assert t.registry[pub2_w].subscribers == {sub2, record.append} + + # NB: this `record.append` is a different object than the one we subscribed + # above, so we're testing the normalization logic in unsubscribe_policy + pub1.unsubscribe(record.append) + assert t.registry[pub1_w].subscribers == {sub1} + 
pub1.emit(3) + assert record == [13] + del record[:] + + del pub, pub1 + gc.collect() + gc.collect() + assert pub1_w() is None + assert pub2_w() is pub2 + assert t.registry.keys() == {pub2_w} + + pub2.emit(4) + assert record == [24, 4] + del record[:] + + del pub2 + gc.collect() + gc.collect() + assert pub2_w() is None + assert not t.registry diff --git a/tests/test_functions.cpp b/tests/test_functions.cpp index 22450d87..0bc645d1 100644 --- a/tests/test_functions.cpp +++ b/tests/test_functions.cpp @@ -1,6 +1,9 @@ +#include + #include #include #include +#include namespace nb = nanobind; using namespace nb::literals; @@ -12,6 +15,74 @@ struct my_call_guard { ~my_call_guard() { call_guard_value = 2; } }; +// Example call policy for use with nb::call_policy<>. Each call will add +// an entry to `calls` containing the arguments tuple and return value. +// The return value will be recorded as "" if the function +// did not return (still executing or threw an exception) and as +// "" if the function returned something that we +// couldn't convert to a Python object. 
+// Additional features to test particular interactions: +// - the precall hook will throw if any arguments are not strings +// - any argument equal to "swapfrom" will be replaced by a temporary +// string object equal to "swapto", which will be destroyed at end of call +// - the postcall hook will throw if any argument equals "postthrow" +struct example_policy { + static inline std::vector> calls; + static void precall(PyObject **args, size_t nargs, + nb::detail::cleanup_list *cleanup) { + PyObject* tup = PyTuple_New(nargs); + for (size_t i = 0; i < nargs; ++i) { + if (!PyUnicode_CheckExact(args[i])) { + Py_DECREF(tup); + throw std::runtime_error("expected only strings"); + } + if (0 == PyUnicode_CompareWithASCIIString(args[i], "swapfrom")) { + nb::object replacement = nb::cast("swapto"); + args[i] = replacement.ptr(); + cleanup->append(replacement.release().ptr()); + } + Py_INCREF(args[i]); + PyTuple_SetItem(tup, i, args[i]); + } + calls.emplace_back(nb::steal(tup), nb::cast("")); + } + static void postcall(PyObject **args, size_t nargs, nb::handle ret) { + if (!ret.is_valid()) { + calls.back().second = nb::cast(""); + } else { + calls.back().second = nb::borrow(ret); + } + for (size_t i = 0; i < nargs; ++i) { + if (0 == PyUnicode_CompareWithASCIIString(args[i], "postthrow")) { + throw std::runtime_error("postcall exception"); + } + } + } +}; + +struct numeric_string { + unsigned long number; +}; + +template <> struct nb::detail::type_caster { + NB_TYPE_CASTER(numeric_string, const_name("str")) + + bool from_python(handle h, uint8_t flags, cleanup_list* cleanup) noexcept { + make_caster str_caster; + if (!str_caster.from_python(h, flags, cleanup)) + return false; + const char* str = str_caster.operator cast_t(); + if (!str) + return false; + char* endp; + value.number = strtoul(str, &endp, 10); + return *str && !*endp; + } + static handle from_cpp(numeric_string, rv_policy, handle) noexcept { + return nullptr; + } +}; + int test_31(int i) noexcept { return i; } 
NB_MODULE(test_functions_ext, m) { @@ -377,4 +448,23 @@ NB_MODULE(test_functions_ext, m) { m.def("test_bytearray_c_str", [](nb::bytearray o) -> const char * { return o.c_str(); }); m.def("test_bytearray_size", [](nb::bytearray o) { return o.size(); }); m.def("test_bytearray_resize", [](nb::bytearray c, int size) { return c.resize(size); }); + + // Test call_policy feature + m.def("test_call_policy", + [](const char* s, numeric_string n) -> const char* { + if (0 == strcmp(s, "returnfail")) { + return "not utf8 \xff"; + } + if (n.number > strlen(s)) { + throw std::runtime_error("offset too large"); + } + return s + n.number; + }, + nb::call_policy()); + + m.def("call_policy_record", + []() { + auto ret = std::move(example_policy::calls); + return ret; + }); } diff --git a/tests/test_functions.py b/tests/test_functions.py index 48824c6e..377ab86e 100644 --- a/tests/test_functions.py +++ b/tests/test_functions.py @@ -650,3 +650,78 @@ def test49_resize(): assert len(o) == 4 t.test_bytearray_resize(o, 8) assert len(o) == 8 + + +def test50_call_policy(): + def case(arg1, arg2, expect_ret): # type: (str, str, str | None) -> str + if hasattr(sys, "getrefcount"): + refs_before = (sys.getrefcount(arg1), sys.getrefcount(arg2)) + + ret = None + try: + ret = t.test_call_policy(arg1, arg2) + assert ret == expect_ret + return ret + finally: + if expect_ret is None: + assert t.call_policy_record() == [] + else: + (((arg1r, arg2r), recorded_ret),) = t.call_policy_record() + assert recorded_ret == expect_ret + assert ret is None or ret is recorded_ret + assert recorded_ret is not expect_ret + + if hasattr(sys, "getrefcount"): + # Make sure no reference leak occurred: should be + # one in getrefcount args, one or two in locals, + # zero or one in the pending-return-value slot. + # We have to decompose this to avoid getting confused + # by transient additional references added by pytest's + # assertion rewriting. 
+ ret_refs = sys.getrefcount(recorded_ret) + assert ret_refs == 2 + 2 * (ret is not None) + + for (passed, recorded) in ((arg1, arg1r), (arg2, arg2r)): + if passed == "swapfrom": + assert recorded == "swapto" + if hasattr(sys, "getrefcount"): + recorded_refs = sys.getrefcount(recorded) + # recorded, arg1r, unnamed tuple, getrefcount arg + assert recorded_refs == 4 + else: + assert passed is recorded + + del passed, recorded, arg1r, arg2r + if hasattr(sys, "getrefcount"): + refs_after = (sys.getrefcount(arg1), sys.getrefcount(arg2)) + assert refs_before == refs_after + + # precall throws exception + with pytest.raises(RuntimeError, match="expected only strings"): + case(12345, "0", None) + + # conversion of args fails + with pytest.raises(TypeError): + case("string", "xxx", "") + + # function throws exception + with pytest.raises(RuntimeError, match="offset too large"): + case("abc", "4", "") + + # conversion of return value fails + with pytest.raises(UnicodeDecodeError): + case("returnfail", "4", "") + + # postcall throws exception + with pytest.raises(RuntimeError, match="postcall exception"): + case("postthrow", "4", "throw") + + # normal call + case("example", "1", "xample") + + # precall modifies args + case("swapfrom", "0", "swapto") + with pytest.raises(TypeError): + case("swapfrom", "xxx", "") + with pytest.raises(RuntimeError, match="offset too large"): + case("swapfrom", "10", "") diff --git a/tests/test_functions_ext.pyi.ref b/tests/test_functions_ext.pyi.ref index 5bfab347..05d5fe50 100644 --- a/tests/test_functions_ext.pyi.ref +++ b/tests/test_functions_ext.pyi.ref @@ -5,6 +5,8 @@ from typing import Annotated, Any, overload def call_guard_value() -> int: ... +def call_policy_record() -> list[tuple[tuple, object]]: ... + def hash_it(arg: object, /) -> int: ... def identity_i16(arg: int, /) -> int: ... @@ -178,6 +180,8 @@ def test_call_guard() -> int: ... def test_call_guard_wrapper_rvalue_ref(arg: int, /) -> int: ... 
+def test_call_policy(arg0: str, arg1: str, /) -> str: ... + def test_cast_char(arg: object, /) -> str: ... def test_cast_str(arg: object, /) -> str: ... From 0a9138211a2b6feab72dd1fea32a73e1e2f6436d Mon Sep 17 00:00:00 2001 From: Joshua Oreman Date: Wed, 6 Nov 2024 23:55:39 -0700 Subject: [PATCH 2/2] Fix CI failure on 3.12+ Win/Mac --- src/nb_type.cpp | 21 ++++++++++----------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/src/nb_type.cpp b/src/nb_type.cpp index 7cec9efd..5110ceaf 100644 --- a/src/nb_type.cpp +++ b/src/nb_type.cpp @@ -1318,19 +1318,18 @@ PyObject *nb_type_new(const type_init_data *t) noexcept { to->alias_chain = nullptr; to->init = nullptr; - if (has_dynamic_attr) { + if (has_dynamic_attr) to->flags |= (uint32_t) type_flags::has_dynamic_attr; - #if defined(Py_LIMITED_API) - to->dictoffset = (uint32_t) dictoffset; - #endif - } - - if (is_weak_referenceable) { + if (is_weak_referenceable) to->flags |= (uint32_t) type_flags::is_weak_referenceable; - #if defined(Py_LIMITED_API) - to->weaklistoffset = (uint32_t) weaklistoffset; - #endif - } + + #if defined(Py_LIMITED_API) + /* These must be set unconditionally so that nb_dict_ptr() / + nb_weaklist_ptr() return null (rather than garbage) on + objects whose types don't use the corresponding feature. */ + to->dictoffset = (uint32_t) dictoffset; + to->weaklistoffset = (uint32_t) weaklistoffset; + #endif if (t->scope != nullptr) setattr(t->scope, t_name, result);