From 831ed952807612e060c7d7f6fca4fc028d468870 Mon Sep 17 00:00:00 2001 From: Joshua Oreman Date: Mon, 4 Nov 2024 13:48:31 -0700 Subject: [PATCH] Add support for custom call policies --- docs/api_core.rst | 107 +++++++++++++++++++++++++++++++ docs/changelog.rst | 10 +++ docs/functions.rst | 7 ++ include/nanobind/nb_attr.h | 64 ++++++++++++++---- include/nanobind/nb_func.h | 20 +++--- src/nb_func.cpp | 32 ++++----- tests/test_functions.cpp | 79 +++++++++++++++++++++++ tests/test_functions.py | 75 ++++++++++++++++++++++ tests/test_functions_ext.pyi.ref | 4 ++ 9 files changed, 361 insertions(+), 37 deletions(-) diff --git a/docs/api_core.rst b/docs/api_core.rst index 3ef27967..a21db805 100644 --- a/docs/api_core.rst +++ b/docs/api_core.rst @@ -1895,6 +1895,113 @@ parameter of :cpp:func:`module_::def`, :cpp:func:`class_::def`, Analogous to :cpp:struct:`for_getter`, but for setters. +.. cpp:struct:: template call_policy + + Request that custom logic be inserted around each call to the + bound function, by calling ``Policy::precall(args, nargs, cleanup)`` before + Python-to-C++ argument conversion, and ``Policy::postcall(args, nargs, ret)`` + after C++-to-Python return value conversion. + + If multiple call policy annotations are provided for the same function, then + their precall and postcall hooks will both execute left-to-right according + to the order in which the annotations were specified when binding the + function. + + The :cpp:struct:`nb::call_guard\() ` annotation + should be preferred over ``call_policy`` unless the wrapper logic + depends on the function arguments or return value. + If both annotations are combined, then + :cpp:struct:`nb::call_guard\() ` always executes on + the "inside" (closest to the bound function, after argument + conversions and before return value conversion) regardless of its + position in the function annotations list. + + Your ``Policy`` class must define two static member functions: + + .. cpp:function:: static void precall(PyObject **args, size_t nargs, detail::cleanup_list *cleanup); + + A hook that will be invoked before calling the bound function. More + precisely, it is called after any :ref:`argument locks ` + have been obtained, but before the Python arguments are converted to C++ + objects for the function call. + + This hook may access or modify the function arguments using the + *args* array, which holds borrowed references in one-to-one + correspondence with the C++ arguments of the bound function. If + the bound function is a method, then ``args[0]`` is its *self* + argument. *nargs* is the number of function arguments. It is actually + passed as ``std::integral_constant()``, so you can + match on that type if you want to do compile-time checks with it. + + The *cleanup* list may be used as it is used in type casters, + to cause some Python object references to be released at some point + after the bound function completes. (If the bound function is part + of an overload set, the cleanup list isn't released until all overloads + have been tried.) + + ``precall()`` may choose to throw a C++ exception. If it does, + it will preempt execution of the bound function, and the + exception will be treated as if the bound function had thrown it. + + .. cpp:function:: static void postcall(PyObject **args, size_t nargs, handle ret); + + A hook that will be invoked after calling the bound function and + converting its return value to a Python object, but only if the + bound function returned normally. + + *args* stores the Python object arguments, with the same semantics + as in ``precall()``, except that arguments that participated in + implicit conversions will have had their ``args[i]`` pointer updated + to reflect the new Python object that the implicit conversion produced. + *nargs* is the number of arguments, passed as a ``std::integral_constant`` + in the same way as for ``precall()``. + + *ret* is the bound function's return value. If the bound function returned + normally but its C++ return value could not be converted to a Python + object, then ``postcall()`` will execute with *ret* set to null, + and the Python error indicator might or might not be set to explain why. + + If the bound function did not return normally -- either because its + Python object arguments couldn't be converted to the appropriate C++ + types, or because the C++ function threw an exception -- then + ``postcall()`` **will not execute**. If you need some cleanup logic to + run even in such cases, your ``precall()`` can add a capsule object to the + cleanup list; its destructor will run eventually, but with no promises + as to when. A :cpp:struct:`nb::call_guard ` might be a + better choice. + + ``postcall()`` may choose to throw a C++ exception. If it does, + the result of the wrapped function will be destroyed, + and the exception will be raised in its place, as if the bound function + had thrown it just before returning. + + Here is an example policy to demonstrate. + ``nb::call_policy>()`` behaves like + :cpp:class:`nb::keep_alive\<0, I\>() `, except that the + return value is a treated as a list of objects rather than a single one. + + .. code-block:: cpp + + template + struct returns_references_to { + static void precall(PyObject **, size_t, nb::detail::cleanup_list *) {} + + template + static void postcall(PyObject **args, + std::integral_constant, + nb::handle ret) { + static_assert(I > 0 && I < N, + "I in returns_references_to must be in the " + "range [1, number of C++ function arguments]"); + if (!nb::isinstance(ret)) { + throw std::runtime_error("return value should be a sequence"); + } + for (nb::handle nurse : ret) { + nb::detail::keep_alive(nurse.ptr(), args[I]); + } + } + }; + .. _class_binding_annotations: Class binding annotations diff --git a/docs/changelog.rst b/docs/changelog.rst index bde4da7e..81bed959 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -23,6 +23,16 @@ Version TBD (unreleased) ``std::variant`` might cast a Python object wrapping a ``T`` to a ``U`` if there is an implicit conversion available from ``T`` to ``U``. +- Added a function annotation :cpp:class:`nb::call_policy\() + ` which supports custom function wrapping logic, + calling ``Policy::precall()`` before the bound function and + ``Policy::postcall()`` after. This is a low-level interface intended + for advanced users. The precall and postcall hooks are able to + observe the Python objects forming the function arguments and return + value, and the precall hook can change the arguments. See the linked + documentation for more details, important caveats, and an example policy. + (PR `#767 `. Construction occurs left to right, while destruction occurs in reverse. +If your wrapping needs are more complex than +:cpp:class:`nb::call_guard\() ` can handle, it is also +possible to define a custom "call policy", which can observe or modify the +Python object arguments and observe the return value. See the documentation of +:cpp:class:`nb::call_policy\ ` for details. + + .. _higher_order_adv: Higher-order functions diff --git a/include/nanobind/nb_attr.h b/include/nanobind/nb_attr.h index 0476824e..99f628b0 100644 --- a/include/nanobind/nb_attr.h +++ b/include/nanobind/nb_attr.h @@ -151,6 +151,8 @@ struct sig { struct is_getter { }; +template struct call_policy final {}; + NAMESPACE_BEGIN(literals) constexpr arg operator"" _a(const char *name, size_t) { return arg(name); } NAMESPACE_END(literals) @@ -186,8 +188,9 @@ enum class func_flags : uint32_t { return_ref = (1 << 15), /// Does this overload specify a custom function signature (for docstrings, typing) has_signature = (1 << 16), - /// Does this function have one or more nb::keep_alive() annotations? - has_keep_alive = (1 << 17) + /// Does this function potentially modify the elements of the PyObject*[] array + /// representing its arguments? (nb::keep_alive() or call_policy annotations) + can_mutate_args = (1 << 17) }; enum cast_flags : uint8_t { @@ -384,12 +387,17 @@ NB_INLINE void func_extra_apply(F &, call_guard, size_t &) {} template NB_INLINE void func_extra_apply(F &f, nanobind::keep_alive, size_t &) { - f.flags |= (uint32_t) func_flags::has_keep_alive; + f.flags |= (uint32_t) func_flags::can_mutate_args; +} + +template +NB_INLINE void func_extra_apply(F &f, call_policy, size_t &) { + f.flags |= (uint32_t) func_flags::can_mutate_args; } template struct func_extra_info { using call_guard = void; - static constexpr bool keep_alive = false; + static constexpr bool pre_post_hooks = false; static constexpr size_t nargs_locked = 0; }; @@ -397,7 +405,7 @@ template struct func_extra_info : func_extra_info { }; template -struct func_extra_info, Ts...> : func_extra_info { +struct func_extra_info, Ts...> : func_extra_info { static_assert(std::is_same_v::call_guard, void>, "call_guard<> can only be specified once!"); using call_guard = nanobind::call_guard; @@ -405,29 +413,59 @@ struct func_extra_info, Ts...> : func_extra_info struct func_extra_info, Ts...> : func_extra_info { - static constexpr bool keep_alive = true; + static constexpr bool pre_post_hooks = true; +}; + +template +struct func_extra_info, Ts...> : func_extra_info { + static constexpr bool pre_post_hooks = true; }; template -struct func_extra_info : func_extra_info { +struct func_extra_info : func_extra_info { static constexpr size_t nargs_locked = 1 + func_extra_info::nargs_locked; }; template -struct func_extra_info : func_extra_info { +struct func_extra_info : func_extra_info { static constexpr size_t nargs_locked = 1 + func_extra_info::nargs_locked; }; -template -NB_INLINE void process_keep_alive(PyObject **, PyObject *, T *) { } +NB_INLINE void process_precall(PyObject **, size_t, detail::cleanup_list *, void *) { } + +template +NB_INLINE void +process_precall(PyObject **args, std::integral_constant nargs, + detail::cleanup_list *cleanup, call_policy *) { + Policy::precall(args, nargs, cleanup); +} + +NB_INLINE void process_postcall(PyObject **, size_t, PyObject *, void *) { } -template +template NB_INLINE void -process_keep_alive(PyObject **args, PyObject *result, - nanobind::keep_alive *) { +process_postcall(PyObject **args, std::integral_constant, + PyObject *result, nanobind::keep_alive *) { + static_assert(Nurse != Patient, + "keep_alive with the same argument as both nurse and patient " + "doesn't make sense"); + static_assert(Nurse <= NArgs && Patient <= NArgs, + "keep_alive template parameters must be in the range " + "[0, number of C++ function arguments]"); keep_alive(Nurse == 0 ? result : args[Nurse - 1], Patient == 0 ? result : args[Patient - 1]); } +template +NB_INLINE void +process_postcall(PyObject **args, std::integral_constant nargs, + PyObject *result, call_policy *) { + // result_guard avoids leaking a reference to the return object + // if postcall throws an exception + object result_guard = steal(result); + Policy::postcall(args, nargs, handle(result)); + result_guard.release(); +} + NAMESPACE_END(detail) NAMESPACE_END(NB_NAMESPACE) diff --git a/include/nanobind/nb_func.h b/include/nanobind/nb_func.h index da420349..10eb3994 100644 --- a/include/nanobind/nb_func.h +++ b/include/nanobind/nb_func.h @@ -11,14 +11,14 @@ NAMESPACE_BEGIN(NB_NAMESPACE) NAMESPACE_BEGIN(detail) template -bool from_python_keep_alive(Caster &c, PyObject **args, uint8_t *args_flags, - cleanup_list *cleanup, size_t index) { +bool from_python_remember_conv(Caster &c, PyObject **args, uint8_t *args_flags, + cleanup_list *cleanup, size_t index) { size_t size_before = cleanup->size(); if (!c.from_python(args[index], args_flags[index], cleanup)) return false; // If an implicit conversion took place, update the 'args' array so that - // the keep_alive annotation can later process this change + // any keep_alive annotation or postcall hook can be aware of this change size_t size_after = cleanup->size(); if (size_after != size_before) args[index] = (*cleanup)[size_after - 1]; @@ -244,9 +244,11 @@ NB_INLINE PyObject *func_create(Func &&func, Return (*)(Args...), } #endif - if constexpr (Info::keep_alive) { - if ((!from_python_keep_alive(in.template get(), args, - args_flags, cleanup, Is) || ...)) + if constexpr (Info::pre_post_hooks) { + std::integral_constant nargs_c; + (process_precall(args, nargs_c, cleanup, (Extra *) nullptr), ...); + if ((!from_python_remember_conv(in.template get(), args, + args_flags, cleanup, Is) || ...)) return NB_NEXT_OVERLOAD; } else { if ((!in.template get().from_python(args[Is], args_flags[Is], @@ -276,8 +278,10 @@ NB_INLINE PyObject *func_create(Func &&func, Return (*)(Args...), #endif } - if constexpr (Info::keep_alive) - (process_keep_alive(args, result, (Extra *) nullptr), ...); + if constexpr (Info::pre_post_hooks) { + std::integral_constant nargs_c; + (process_postcall(args, nargs_c, result, (Extra *) nullptr), ...); + } return result; }; diff --git a/src/nb_func.cpp b/src/nb_func.cpp index 08c4eefa..21883a62 100644 --- a/src/nb_func.cpp +++ b/src/nb_func.cpp @@ -196,21 +196,21 @@ PyObject *nb_func_new(const void *in_) noexcept { func_data_prelim<0> *f = (func_data_prelim<0> *) in_; arg_data *args_in = std::launder((arg_data *) f->args); - bool has_scope = f->flags & (uint32_t) func_flags::has_scope, - has_name = f->flags & (uint32_t) func_flags::has_name, - has_args = f->flags & (uint32_t) func_flags::has_args, - has_var_args = f->flags & (uint32_t) func_flags::has_var_kwargs, - has_var_kwargs = f->flags & (uint32_t) func_flags::has_var_args, - has_keep_alive = f->flags & (uint32_t) func_flags::has_keep_alive, - has_doc = f->flags & (uint32_t) func_flags::has_doc, - has_signature = f->flags & (uint32_t) func_flags::has_signature, - is_implicit = f->flags & (uint32_t) func_flags::is_implicit, - is_method = f->flags & (uint32_t) func_flags::is_method, - return_ref = f->flags & (uint32_t) func_flags::return_ref, - is_constructor = false, - is_init = false, - is_new = false, - is_setstate = false; + bool has_scope = f->flags & (uint32_t) func_flags::has_scope, + has_name = f->flags & (uint32_t) func_flags::has_name, + has_args = f->flags & (uint32_t) func_flags::has_args, + has_var_args = f->flags & (uint32_t) func_flags::has_var_kwargs, + has_var_kwargs = f->flags & (uint32_t) func_flags::has_var_args, + can_mutate_args = f->flags & (uint32_t) func_flags::can_mutate_args, + has_doc = f->flags & (uint32_t) func_flags::has_doc, + has_signature = f->flags & (uint32_t) func_flags::has_signature, + is_implicit = f->flags & (uint32_t) func_flags::is_implicit, + is_method = f->flags & (uint32_t) func_flags::is_method, + return_ref = f->flags & (uint32_t) func_flags::return_ref, + is_constructor = false, + is_init = false, + is_new = false, + is_setstate = false; PyObject *name = nullptr; PyObject *func_prev = nullptr; @@ -292,7 +292,7 @@ PyObject *nb_func_new(const void *in_) noexcept { maybe_make_immortal((PyObject *) func); // Check if the complex dispatch loop is needed - bool complex_call = has_keep_alive || has_var_kwargs || has_var_args || + bool complex_call = can_mutate_args || has_var_kwargs || has_var_args || f->nargs >= NB_MAXARGS_SIMPLE; if (has_args) { diff --git a/tests/test_functions.cpp b/tests/test_functions.cpp index 22450d87..34e584e3 100644 --- a/tests/test_functions.cpp +++ b/tests/test_functions.cpp @@ -1,6 +1,9 @@ +#include + #include #include #include +#include namespace nb = nanobind; using namespace nb::literals; @@ -12,6 +15,63 @@ struct my_call_guard { ~my_call_guard() { call_guard_value = 2; } }; +struct example_policy { + static inline std::vector> calls; + static void precall(PyObject **args, size_t nargs, + nb::detail::cleanup_list *cleanup) { + PyObject* tup = PyTuple_New(nargs); + for (size_t i = 0; i < nargs; ++i) { + if (!PyUnicode_CheckExact(args[i])) { + Py_DECREF(tup); + throw std::runtime_error("expected only strings"); + } + if (0 == PyUnicode_CompareWithASCIIString(args[i], "swapfrom")) { + nb::object replacement = nb::cast("swapto"); + args[i] = replacement.ptr(); + cleanup->append(replacement.release().ptr()); + } + Py_INCREF(args[i]); + PyTuple_SetItem(tup, i, args[i]); + } + calls.emplace_back(nb::steal(tup), nb::cast("")); + } + static void postcall(PyObject **args, size_t nargs, nb::handle ret) { + if (!ret.is_valid()) { + calls.back().second = nb::cast(""); + } else { + calls.back().second = nb::borrow(ret); + } + for (size_t i = 0; i < nargs; ++i) { + if (0 == PyUnicode_CompareWithASCIIString(args[i], "postthrow")) { + throw std::runtime_error("postcall exception"); + } + } + } +}; + +struct numeric_string { + unsigned long number; +}; + +template <> struct nb::detail::type_caster { + NB_TYPE_CASTER(numeric_string, const_name("str")) + + bool from_python(handle h, uint8_t flags, cleanup_list* cleanup) noexcept { + make_caster str_caster; + if (!str_caster.from_python(h, flags, cleanup)) + return false; + const char* str = str_caster.operator cast_t(); + if (!str) + return false; + char* endp; + value.number = strtoul(str, &endp, 10); + return *str && !*endp; + } + static handle from_cpp(numeric_string, rv_policy, handle) noexcept { + return nullptr; + } +}; + int test_31(int i) noexcept { return i; } NB_MODULE(test_functions_ext, m) { @@ -377,4 +437,23 @@ NB_MODULE(test_functions_ext, m) { m.def("test_bytearray_c_str", [](nb::bytearray o) -> const char * { return o.c_str(); }); m.def("test_bytearray_size", [](nb::bytearray o) { return o.size(); }); m.def("test_bytearray_resize", [](nb::bytearray c, int size) { return c.resize(size); }); + + // Test call_policy feature + m.def("test_call_policy", + [](const char* s, numeric_string n) -> const char* { + if (0 == strcmp(s, "returnfail")) { + return "not utf8 \xff"; + } + if (n.number > strlen(s)) { + throw std::runtime_error("offset too large"); + } + return s + n.number; + }, + nb::call_policy()); + + m.def("call_policy_record", + []() { + auto ret = std::move(example_policy::calls); + return ret; + }); } diff --git a/tests/test_functions.py b/tests/test_functions.py index 48824c6e..377ab86e 100644 --- a/tests/test_functions.py +++ b/tests/test_functions.py @@ -650,3 +650,78 @@ def test49_resize(): assert len(o) == 4 t.test_bytearray_resize(o, 8) assert len(o) == 8 + + +def test50_call_policy(): + def case(arg1, arg2, expect_ret): # type: (str, str, str | None) -> str + if hasattr(sys, "getrefcount"): + refs_before = (sys.getrefcount(arg1), sys.getrefcount(arg2)) + + ret = None + try: + ret = t.test_call_policy(arg1, arg2) + assert ret == expect_ret + return ret + finally: + if expect_ret is None: + assert t.call_policy_record() == [] + else: + (((arg1r, arg2r), recorded_ret),) = t.call_policy_record() + assert recorded_ret == expect_ret + assert ret is None or ret is recorded_ret + assert recorded_ret is not expect_ret + + if hasattr(sys, "getrefcount"): + # Make sure no reference leak occurred: should be + # one in getrefcount args, one or two in locals, + # zero or one in the pending-return-value slot. + # We have to decompose this to avoid getting confused + # by transient additional references added by pytest's + # assertion rewriting. + ret_refs = sys.getrefcount(recorded_ret) + assert ret_refs == 2 + 2 * (ret is not None) + + for (passed, recorded) in ((arg1, arg1r), (arg2, arg2r)): + if passed == "swapfrom": + assert recorded == "swapto" + if hasattr(sys, "getrefcount"): + recorded_refs = sys.getrefcount(recorded) + # recorded, arg1r, unnamed tuple, getrefcount arg + assert recorded_refs == 4 + else: + assert passed is recorded + + del passed, recorded, arg1r, arg2r + if hasattr(sys, "getrefcount"): + refs_after = (sys.getrefcount(arg1), sys.getrefcount(arg2)) + assert refs_before == refs_after + + # precall throws exception + with pytest.raises(RuntimeError, match="expected only strings"): + case(12345, "0", None) + + # conversion of args fails + with pytest.raises(TypeError): + case("string", "xxx", "") + + # function throws exception + with pytest.raises(RuntimeError, match="offset too large"): + case("abc", "4", "") + + # conversion of return value fails + with pytest.raises(UnicodeDecodeError): + case("returnfail", "4", "") + + # postcall throws exception + with pytest.raises(RuntimeError, match="postcall exception"): + case("postthrow", "4", "throw") + + # normal call + case("example", "1", "xample") + + # precall modifies args + case("swapfrom", "0", "swapto") + with pytest.raises(TypeError): + case("swapfrom", "xxx", "") + with pytest.raises(RuntimeError, match="offset too large"): + case("swapfrom", "10", "") diff --git a/tests/test_functions_ext.pyi.ref b/tests/test_functions_ext.pyi.ref index 5bfab347..05d5fe50 100644 --- a/tests/test_functions_ext.pyi.ref +++ b/tests/test_functions_ext.pyi.ref @@ -5,6 +5,8 @@ from typing import Annotated, Any, overload def call_guard_value() -> int: ... +def call_policy_record() -> list[tuple[tuple, object]]: ... + def hash_it(arg: object, /) -> int: ... def identity_i16(arg: int, /) -> int: ... @@ -178,6 +180,8 @@ def test_call_guard() -> int: ... def test_call_guard_wrapper_rvalue_ref(arg: int, /) -> int: ... +def test_call_policy(arg0: str, arg1: str, /) -> str: ... + def test_cast_char(arg: object, /) -> str: ... def test_cast_str(arg: object, /) -> str: ...