Skip to content

Commit

Permalink
Major improvements to the std::function<> caster (roundtrips, cyclic …
Browse files Browse the repository at this point in the history
…GC) (#95)

- Roundtrip support! When a Python function has been wrapped in a
  ``std::function<>``, subsequent conversion to a Python object will
  return the original Python function object.

  Note that this is the opposite of pybind11's ``std::function<>``
  caster, where roundtrip support is implemented for C function pointers
  (which isn't possible in nanobind currently, and seems less useful in
  retrospect).

- Building on the previous point, the following C++ snippet
  ```cpp
  nb::object o = nb::cast(std_function_instance, nb::rv_policy::none);
  ```
  can be used to attempt a conversion of a ``std::function<>`` instance
  into a Python object. This will either return the function object or
  an invalid (``!o.is_valid()``) object if the conversion fails.

Why was this added? A useful feature of nanobind is that one can set
callback methods on bound C++ instances that redirect control flow
back to Python.

```python
a = MyCppClass()
a.f = lambda x: ...
```

A major potential issue here are reference leaks. What if the lambda
function assigned to ``a.f`` captures some variables from the
surrounding environment, which in turn reference the instance `a`? Then
we have a reference cycle that spans the Python <-> C++ boundary, and
that whole set of objects will never be deleted.

Fortunately, Python provides a garbage collector that can collect such
cycles, but we must provide it with further information so that it can
properly do its job. It must be able to traverse the C++ instance to
discover contained Python objects.

Below is a fully worked out example.

```cpp
// Type definition
struct MyCppClass {
    // A callback function that could be implemented in either language
    std::function<void(void>) f;
};

// Traversal method that may be invoked by Python's cyclic GC
int mycppclass_tp_traverse(PyObject *self, visitproc visit, void *arg) {
    MyCppClass *m = nb::cast<MyCppClass *>(nb::handle(self));
    if (m) {
        nb::object f = nb::cast(m->f, nb::rv_policy::none);

        // If 'f' is a Python function object, then traverse it recursively
        if (f.is_valid())
            Py_VISIT(f.ptr());
    }

    return 0;
};

// Callback to register additional type slots in the bindings
void mycppclass_type_callback = [](PyType_Slot **s) noexcept {
    *(*s)++ = { Py_tp_traverse, (void *) mycppclass_tp_traverse };
};

// .. binding code ...
nb::class_<MyCppClass>(m, "MyCppClass", nb::type_callback(mycppclass_type_callback))
    .def(nb::init<>())
    .def_readwrite("f", &FuncWrapper::f);
```

This commit also adds an example of such a cycle to the test suite,
which will fail if it cannot be garbage-collected.
  • Loading branch information
wjakob authored Oct 25, 2022
1 parent 81c31e9 commit c12dcd5
Show file tree
Hide file tree
Showing 5 changed files with 148 additions and 63 deletions.
56 changes: 41 additions & 15 deletions include/nanobind/stl/function.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,21 +15,33 @@
NAMESPACE_BEGIN(NB_NAMESPACE)
NAMESPACE_BEGIN(detail)

struct function_handle {
object f;
explicit function_handle(handle h): f(borrow(h)) { }
function_handle(function_handle &&h) noexcept : f(std::move(h.f)) { }
function_handle(const function_handle &h) {
gil_scoped_acquire acq;
f = h.f;
struct pyfunc_wrapper {
PyObject *f;

explicit pyfunc_wrapper(PyObject *f) : f(f) {
Py_INCREF(f);
}

pyfunc_wrapper(pyfunc_wrapper &&w) noexcept : f(w.f) {
w.f = nullptr;
}

pyfunc_wrapper(const pyfunc_wrapper &w) : f(w.f) {
if (f) {
gil_scoped_acquire acq;
Py_INCREF(f);
}
}

~function_handle() {
if (f.is_valid()) {
~pyfunc_wrapper() {
if (f) {
gil_scoped_acquire acq;
f.release().dec_ref();
Py_DECREF(f);
}
}

pyfunc_wrapper &operator=(const pyfunc_wrapper) = delete;
pyfunc_wrapper &operator=(pyfunc_wrapper &&) = delete;
};

template <typename Return, typename... Args>
Expand All @@ -42,25 +54,39 @@ struct type_caster<std::function<Return(Args...)>> {
concat(make_caster<Args>::Name...) + const_name("], ") +
ReturnCaster::Name + const_name("]"));

struct pyfunc_wrapper_t : pyfunc_wrapper {
using pyfunc_wrapper::pyfunc_wrapper;

Return operator()(Args... args) const {
gil_scoped_acquire acq;
return cast<Return>(handle(f)((forward_t<Args>) args...));
}
};

bool from_python(handle src, uint8_t flags, cleanup_list *) noexcept {
if (src.is_none())
return flags & cast_flags::convert;

if (!PyCallable_Check(src.ptr()))
return false;

value = [f = function_handle(src)](Args... args) -> Return {
gil_scoped_acquire acq;
return cast<Return>(f.f((forward_t<Args>) args...));
};
value = pyfunc_wrapper_t(src.ptr());

return true;
}

static handle from_cpp(const Value &value, rv_policy,
static handle from_cpp(const Value &value, rv_policy rvp,
cleanup_list *) noexcept {
const pyfunc_wrapper_t *wrapper = value.template target<pyfunc_wrapper_t>();
if (wrapper)
return handle(wrapper->f).inc_ref();

if (!value)
return none().release();

if (rvp == rv_policy::none)
return handle();

return cpp_function(value).release();
}
};
Expand Down
2 changes: 1 addition & 1 deletion src/nb_internals.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

/// Tracks the ABI of nanobind
#ifndef NB_INTERNALS_VERSION
# define NB_INTERNALS_VERSION 4
# define NB_INTERNALS_VERSION 5
#endif

/// On MSVC, debug and release builds are not ABI-compatible!
Expand Down
2 changes: 0 additions & 2 deletions tests/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,5 @@ def test28_ellipsis():
def test29_traceback():
result = t.test_30(fail_fn)
regexp = r'Traceback \(most recent call last\):\n.*\n File "[^"]*", line 8, in fail_fn\n.*RuntimeError: Foo'
print("'%s'\n"%result)
print("'%s'\n"%regexp)
matches = re.findall(regexp, result, re.MULTILINE | re.DOTALL)
assert len(matches) == 1
56 changes: 43 additions & 13 deletions tests/test_stl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,30 @@ struct Copyable {
};

struct StructWithReadonlyMap {
std::map<std::string, uint64_t> map;
std::map<std::string, uint64_t> map;
};

struct FuncWrapper {
std::function<void(void)> f;
static int alive;
FuncWrapper() { alive++; }
~FuncWrapper() { alive--; }
};

int funcwrapper_tp_traverse(PyObject *self, visitproc visit, void *arg) {
FuncWrapper *w = nb::cast<FuncWrapper *>(nb::handle(self));
if (w) {
nb::object f = nb::cast(w->f, nb::rv_policy::none);
if (f.is_valid())
Py_VISIT(f.ptr());
}

return 0;
};


int FuncWrapper::alive = 0;

void fail() { throw std::exception(); }

NB_MODULE(test_stl_ext, m) {
Expand Down Expand Up @@ -85,7 +106,7 @@ NB_MODULE(test_stl_ext, m) {
.def(nb::init<>())
.def_readonly("map", &StructWithReadonlyMap::map);

// ----- test01-test12 ------ */
// ----- test01-test12 ------

m.def("return_movable", []() { return Movable(); });
m.def("return_movable_ptr", []() { return new Movable(); });
Expand All @@ -100,7 +121,7 @@ NB_MODULE(test_stl_ext, m) {
m.def("copyable_in_rvalue_ref", [](Copyable &&m) { Copyable x(m); if (x.value != 5) fail(); });
m.def("copyable_in_ptr", [](Copyable *m) { if (m->value != 5) fail(); });

// ----- test13-test20 ------ */
// ----- test13-test20 ------

m.def("tuple_return_movable", []() { return std::make_tuple(Movable()); });
m.def("tuple_return_movable_ptr", []() { return std::make_tuple(new Movable()); });
Expand All @@ -111,7 +132,7 @@ NB_MODULE(test_stl_ext, m) {
m.def("tuple_movable_in_rvalue_ref_2", [](std::tuple<Movable> &&m) { Movable x(std::move(std::get<0>(m))); if (x.value != 5) fail(); });
m.def("tuple_movable_in_ptr", [](std::tuple<Movable*> m) { if (std::get<0>(m)->value != 5) fail(); });

// ----- test21 ------ */
// ----- test21 ------

m.def("empty_tuple", [](std::tuple<>) { return std::tuple<>(); });
m.def("swap_tuple", [](const std::tuple<int, float> &v) {
Expand All @@ -121,7 +142,7 @@ NB_MODULE(test_stl_ext, m) {
return std::pair<float, int>(std::get<1>(v), std::get<0>(v));
});

// ----- test22 ------ */
// ----- test22 ------
m.def("vec_return_movable", [](){
std::vector<Movable> x;
x.reserve(10);
Expand Down Expand Up @@ -183,14 +204,14 @@ NB_MODULE(test_stl_ext, m) {
fail();
});

// ----- test29 ------ */
// ----- test29 ------
using fvec = std::vector<float, std::allocator<float>>;
nb::class_<fvec>(m, "float_vec")
.def(nb::init<>())
.def("push_back", [](fvec *fv, float f) { fv->push_back(f); })
.def("size", [](const fvec &fv) { return fv.size(); });

// ----- test30 ------ */
// ----- test30 ------

m.def("return_empty_function", []() -> std::function<int(int)> {
return {};
Expand All @@ -210,11 +231,20 @@ NB_MODULE(test_stl_ext, m) {

m.def("identity_list", [](std::list<int> &x) { return x; });

// ----- test33 ------ */
auto callback = [](PyType_Slot **s) noexcept {
*(*s)++ = { Py_tp_traverse, (void *) funcwrapper_tp_traverse };
};

nb::class_<FuncWrapper>(m, "FuncWrapper", nb::type_callback(callback))
.def(nb::init<>())
.def_readwrite("f", &FuncWrapper::f)
.def_readonly_static("alive", &FuncWrapper::alive);

// ----- test35 ------
m.def("identity_string", [](std::string& x) { return x; });
m.def("identity_string_view", [](std::string_view& x) { return x; });

// ----- test34-test40 ------ */
// ----- test36-test42 ------
m.def("optional_copyable", [](std::optional<Copyable> &) {}, nb::arg("x").none());
m.def("optional_copyable_ptr", [](std::optional<Copyable *> &) {}, nb::arg("x").none());
m.def("optional_none", [](std::optional<Copyable> &x) { if(x) fail(); }, nb::arg("x").none());
Expand All @@ -223,7 +253,7 @@ NB_MODULE(test_stl_ext, m) {
m.def("optional_ret_opt_none", []() { return std::optional<Movable>(); });
m.def("optional_unbound_type", [](std::optional<int> &x) { return x; }, nb::arg("x") = nb::none());

// ----- test41-test47 ------ */
// ----- test43-test50 ------
m.def("variant_copyable", [](std::variant<Copyable, int> &) {});
m.def("variant_copyable_none", [](std::variant<std::monostate, Copyable, int> &) {}, nb::arg("x").none());
m.def("variant_copyable_ptr", [](std::variant<Copyable *, int> &) {});
Expand All @@ -233,7 +263,7 @@ NB_MODULE(test_stl_ext, m) {
m.def("variant_unbound_type", [](std::variant<std::monostate, nb::list, nb::tuple, int> &x) { return x; },
nb::arg("x") = nb::none());

// ----- test48-test55 ------ */
// ----- test50-test57 ------
m.def("map_return_movable_value", [](){
std::map<std::string, Movable> x;
for (int i = 0; i < 10; ++i)
Expand Down Expand Up @@ -303,11 +333,11 @@ NB_MODULE(test_stl_ext, m) {
return x;
});

// test56
// test58
m.def("array_out", [](){ return std::array<int, 3>{1, 2, 3}; });
m.def("array_in", [](std::array<int, 3> x) { return x[0] + x[1] + x[2]; });

// ----- test58-test62 ------ */
// ----- test60-test64 ------
m.def("set_return_value", []() {
std::set<std::string> x;
for (int i = 0; i < 10; ++i)
Expand Down
Loading

0 comments on commit c12dcd5

Please sign in to comment.