Skip to content

Commit

Permalink
[smart_holder] Keep parent alive when returning raw pointers (google#…
Browse files Browse the repository at this point in the history
…4609)

* Avoid dangling pointers.

* Add test for const ptr

* Fix test failures.

* Fix ClangTidy

* fix emplace_back
  • Loading branch information
wangxf123456 authored Apr 6, 2023
1 parent b37a1cd commit 99cf27a
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 1 deletion.
6 changes: 5 additions & 1 deletion include/pybind11/detail/smart_holder_type_casters.h
Original file line number Diff line number Diff line change
Expand Up @@ -715,7 +715,11 @@ struct smart_holder_type_caster : smart_holder_type_caster_load<T>,

static handle cast(T *src, return_value_policy policy, handle parent) {
if (policy == return_value_policy::_clif_automatic) {
policy = return_value_policy::reference;
if (parent) {
policy = return_value_policy::reference_internal;
} else {
policy = return_value_policy::reference;
}
}
return cast(const_cast<T const *>(src), policy, parent); // Mutbl2Const
}
Expand Down
46 changes: 46 additions & 0 deletions tests/test_return_value_policy_override.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,37 @@

#include "pybind11_tests.h"

#include <vector>

namespace test_return_value_policy_override {

struct some_type {};

struct data_field {
int value = -99;

explicit data_field(int v) : value(v) {}
};

struct data_fields_holder {
std::vector<data_field> vec;

explicit data_fields_holder(std::size_t vec_size) {
for (std::size_t i = 0; i < vec_size; i++) {
vec.emplace_back(13 + static_cast<int>(i) * 11);
}
}

data_field *vec_at(std::size_t index) {
if (index >= vec.size()) {
return nullptr;
}
return &vec[index];
}

const data_field *vec_at_const_ptr(std::size_t index) { return vec_at(index); }
};

// cp = copyable, mv = movable, 1 = yes, 0 = no
struct type_cp1_mv1 {
std::string mtxt;
Expand Down Expand Up @@ -156,6 +183,8 @@ std::unique_ptr<type_cp0_mv0> return_unique_pointer_nocopy_nomove() {

} // namespace test_return_value_policy_override

using test_return_value_policy_override::data_field;
using test_return_value_policy_override::data_fields_holder;
using test_return_value_policy_override::some_type;
using test_return_value_policy_override::type_cp0_mv0;
using test_return_value_policy_override::type_cp0_mv1;
Expand Down Expand Up @@ -205,6 +234,8 @@ struct type_caster<some_type> : type_caster_base<some_type> {
} // namespace detail
} // namespace pybind11

PYBIND11_SMART_HOLDER_TYPE_CASTERS(data_field)
PYBIND11_SMART_HOLDER_TYPE_CASTERS(data_fields_holder)
PYBIND11_SMART_HOLDER_TYPE_CASTERS(type_cp1_mv1)
PYBIND11_SMART_HOLDER_TYPE_CASTERS(type_cp0_mv1)
PYBIND11_SMART_HOLDER_TYPE_CASTERS(type_cp1_mv0)
Expand Down Expand Up @@ -239,6 +270,21 @@ TEST_SUBMODULE(return_value_policy_override, m) {
},
py::return_value_policy::_clif_automatic);

py::classh<data_field>(m, "data_field").def_readwrite("value", &data_field::value);
py::classh<data_fields_holder>(m, "data_fields_holder")
.def(py::init<std::size_t>())
.def("vec_at",
[](const py::object &self_py, std::size_t index) {
auto *self = py::cast<data_fields_holder *>(self_py);
return py::cast(
self->vec_at(index), py::return_value_policy::_clif_automatic, self_py);
})
.def("vec_at_const_ptr", [](const py::object &self_py, std::size_t index) {
auto *self = py::cast<data_fields_holder *>(self_py);
return py::cast(
self->vec_at_const_ptr(index), py::return_value_policy::_clif_automatic, self_py);
});

py::classh<type_cp1_mv1>(m, "type_cp1_mv1")
.def(py::init<std::string>())
.def_readonly("mtxt", &type_cp1_mv1::mtxt);
Expand Down
15 changes: 15 additions & 0 deletions tests/test_return_value_policy_override.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,21 @@ def test_return_pointer():
assert m.return_pointer_with_policy_clif_automatic() == "_clif_automatic"


def test_persistent_holder():
h = m.data_fields_holder(2)
assert h.vec_at(0).value == 13
assert h.vec_at(1).value == 24
assert h.vec_at_const_ptr(0).value == 13
assert h.vec_at_const_ptr(1).value == 24


def test_temporary_holder():
data_field = m.data_fields_holder(2).vec_at(1)
assert data_field.value == 24
data_field = m.data_fields_holder(2).vec_at_const_ptr(1)
assert data_field.value == 24


@pytest.mark.parametrize(
("func", "expected"),
[
Expand Down

0 comments on commit 99cf27a

Please sign in to comment.