Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Track patients with unordered_set rather than vector #1253

Open
wants to merge 8 commits into
base: master
Choose a base branch
from
8 changes: 4 additions & 4 deletions include/pybind11/detail/class.h
Original file line number Diff line number Diff line change
Expand Up @@ -362,8 +362,8 @@ inline void add_patient(PyObject *nurse, PyObject *patient) {
auto &internals = get_internals();
auto instance = reinterpret_cast<detail::instance *>(nurse);
instance->has_patients = true;
Py_INCREF(patient);
internals.patients[nurse].push_back(patient);
auto it = internals.patients[nurse].insert(patient);
if (it.second) Py_INCREF(patient);
}

inline void clear_patients(PyObject *self) {
Expand All @@ -372,12 +372,12 @@ inline void clear_patients(PyObject *self) {
auto pos = internals.patients.find(self);
assert(pos != internals.patients.end());
// Clearing the patients can cause more Python code to run, which
// can invalidate the iterator. Extract the vector of patients
// can invalidate the iterator. Extract the set of patients
// from the unordered_map first.
auto patients = std::move(pos->second);
internals.patients.erase(pos);
instance->has_patients = false;
for (PyObject *&patient : patients)
for (PyObject *patient : patients)
Py_CLEAR(patient);
}

Expand Down
4 changes: 2 additions & 2 deletions include/pybind11/detail/internals.h
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ struct internals {
std::unordered_multimap<const void *, instance*> registered_instances; // void * -> instance*
std::unordered_set<std::pair<const PyObject *, const char *>, override_hash> inactive_override_cache;
type_map<std::vector<bool (*)(PyObject *, void *&)>> direct_conversions;
std::unordered_map<const PyObject *, std::vector<PyObject *>> patients;
std::unordered_map<const PyObject *, std::unordered_set<PyObject *>> patients;
std::forward_list<ExceptionTranslator> registered_exception_translators;
std::unordered_map<std::string, void *> shared_data; // Custom data to be shared across extensions
std::vector<PyObject *> loader_patient_stack; // Used by `loader_life_support`
Expand Down Expand Up @@ -154,7 +154,7 @@ struct type_info {
};

/// Tracks the `internals` and `type_info` ABI version independent of the main library version
#define PYBIND11_INTERNALS_VERSION 4
#define PYBIND11_INTERNALS_VERSION 5

/// On MSVC, debug and release builds are not ABI-compatible!
#if defined(_MSC_VER) && defined(_DEBUG)
Expand Down
16 changes: 16 additions & 0 deletions tests/test_call_policies.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
*/

#include "pybind11_tests.h"
#include "constructor_stats.h"

struct CustomGuard {
static bool enabled;
Expand Down Expand Up @@ -68,6 +69,21 @@ TEST_SUBMODULE(call_policies, m) {
m.def("free_function", [](Parent*, Child*) {}, py::keep_alive<1, 2>());
m.def("invalid_arg_index", []{}, py::keep_alive<0, 1>());

// test_keep_alive_single
m.def("add_patient", [](py::object /*nurse*/, py::object /*patient*/) { }, py::keep_alive<1, 2>());
m.def("get_patients", [](py::object nurse) {
py::set patients;
for (PyObject *p : pybind11::detail::get_internals().patients[nurse.ptr()])
patients.add(py::reinterpret_borrow<py::object>(p));
return patients;
});
m.def("has_patients", [](uint64_t nurse_id) {
// This assumes that id() and PyObject* are equivalent.
// We use this to allow the original `nurse` object to be garbage collected.
PyObject *nurse_ptr = (PyObject*)nurse_id;
return pybind11::detail::get_internals().patients.count(nurse_ptr);
});

#if !defined(PYPY_VERSION)
// test_alive_gc
class ParentGC : public Parent {
Expand Down
50 changes: 49 additions & 1 deletion tests/test_call_policies.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
# -*- coding: utf-8 -*-
import sys

import pytest

import env # noqa: F401
from pybind11_tests import ConstructorStats
from pybind11_tests import ConstructorStats, UserType
from pybind11_tests import call_policies as m


Expand Down Expand Up @@ -115,6 +117,52 @@ def test_keep_alive_return_value(capture):
)


def refcount(h):
pytest.gc_collect()
return sys.getrefcount(h)


@pytest.mark.xfail("env.PYPY", reason="getrefcount is unimplemented")
def test_keep_alive_single():
"""Issue #1251 - patients are stored multiple times when given to the same nurse"""

nurse, p1, p2 = UserType(), UserType(), UserType()
b = refcount(nurse)
nurse_id = id(nurse)
assert [refcount(nurse), refcount(p1), refcount(p2)] == [b, b, b]
m.add_patient(nurse, p1)
assert m.get_patients(nurse) == {
p1,
}
assert m.has_patients(nurse_id)
assert [refcount(nurse), refcount(p1), refcount(p2)] == [b, b + 1, b]
m.add_patient(nurse, p1)
assert m.get_patients(nurse) == {
p1,
}
assert [refcount(nurse), refcount(p1), refcount(p2)] == [b, b + 1, b]
m.add_patient(nurse, p1)
assert m.get_patients(nurse) == {
p1,
}
assert [refcount(nurse), refcount(p1), refcount(p2)] == [b, b + 1, b]
m.add_patient(nurse, p2)
assert m.get_patients(nurse) == {p1, p2}
assert [refcount(nurse), refcount(p1), refcount(p2)] == [b, b + 1, b + 1]
m.add_patient(nurse, p2)
assert m.get_patients(nurse) == {p1, p2}
assert [refcount(nurse), refcount(p1), refcount(p2)] == [b, b + 1, b + 1]
m.add_patient(nurse, p2)
m.add_patient(nurse, p1)
assert m.get_patients(nurse) == {p1, p2}
assert [refcount(nurse), refcount(p1), refcount(p2)] == [b, b + 1, b + 1]
del nurse
pytest.gc_collect()
assert not m.has_patients(nurse_id)
# Ensure that nurse entry is removed once it goes out of scope.
assert [refcount(p1), refcount(p2)] == [b, b]


# https://foss.heptapod.net/pypy/pypy/-/issues/2447
@pytest.mark.xfail("env.PYPY", reason="_PyObject_GetDictPtr is unimplemented")
def test_alive_gc(capture):
Expand Down