Skip to content

Commit

Permalink
Fix thread safety for pybind11 loader_life_support (#3237)
Browse files Browse the repository at this point in the history
* Fix thread safety for pybind11 loader_life_support

Fixes issue: #2765

This converts the vector of PyObjects to either a single void* or
a per-thread void* depending on the WITH_THREAD define.

The new field is used by each thread to construct a stack
of loader_life_support frames that can extend the life of python
objects.

The pointer is updated when the loader_life_support object is allocated
(which happens before a call) as well as on release.

Each loader_life_support maintains a set of PyObject references
that need to be lifetime extended; this is done by storing them
in a c++ std::unordered_set and clearing the references when the
method completes.

* Also update the internals version as the internal struct is no longer compatible

* Add test demonstrating threading works correctly.

It may be appropriate to run this under msan/tsan/etc.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update test to use lifetime-extended references rather than
std::string_view, as that's a C++ 17 feature.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Make loader_life_support members private

* Update version to dev2

* Update test to use python threading rather than concurrent.futures

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Remove unnecessary env in test

* Remove unnecessary pytest in test

* Use native C++ thread_local in place of python per-thread data structures to retain compatibility

* clang-format test_thread.cpp

* Add a note about debugging the py::cast() error

* test_thread.py now propagates exceptions on join() calls.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* remove unused sys / merge

* Update include order in test_thread.cpp

* Remove spurious whitespace

* Update comment / whitespace.

* Address review comments

* lint cleanup

* Fix test IntStruct constructor.

* Add explicit to constructor

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Aaron Gokaslan <skylion.aaron@gmail.com>
  • Loading branch information
3 people authored Sep 10, 2021
1 parent 121b91f commit 0e59958
Show file tree
Hide file tree
Showing 7 changed files with 148 additions and 30 deletions.
4 changes: 2 additions & 2 deletions include/pybind11/detail/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@

#define PYBIND11_VERSION_MAJOR 2
#define PYBIND11_VERSION_MINOR 8
#define PYBIND11_VERSION_PATCH 0.dev1
#define PYBIND11_VERSION_PATCH 0.dev2

// Similar to Python's convention: https://docs.python.org/3/c-api/apiabiversion.html
// Additional convention: 0xD = dev
#define PYBIND11_VERSION_HEX 0x020800D1
#define PYBIND11_VERSION_HEX 0x020800D2

#define PYBIND11_NAMESPACE_BEGIN(name) namespace name {
#define PYBIND11_NAMESPACE_END(name) }
Expand Down
6 changes: 3 additions & 3 deletions include/pybind11/detail/internals.h
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ struct internals {
std::unordered_map<const PyObject *, std::vector<PyObject *>> patients;
std::forward_list<ExceptionTranslator> registered_exception_translators;
std::unordered_map<std::string, void *> shared_data; // Custom data to be shared across extensions
std::vector<PyObject *> loader_patient_stack; // Used by `loader_life_support`
std::vector<PyObject *> unused_loader_patient_stack_remove_at_v5;
std::forward_list<std::string> static_strings; // Stores the std::strings backing detail::c_str()
PyTypeObject *static_property_type;
PyTypeObject *default_metaclass;
Expand Down Expand Up @@ -298,12 +298,12 @@ PYBIND11_NOINLINE internals &get_internals() {
#if PY_VERSION_HEX >= 0x03070000
internals_ptr->tstate = PyThread_tss_alloc();
if (!internals_ptr->tstate || (PyThread_tss_create(internals_ptr->tstate) != 0))
pybind11_fail("get_internals: could not successfully initialize the TSS key!");
pybind11_fail("get_internals: could not successfully initialize the tstate TSS key!");
PyThread_tss_set(internals_ptr->tstate, tstate);
#else
internals_ptr->tstate = PyThread_create_key();
if (internals_ptr->tstate == -1)
pybind11_fail("get_internals: could not successfully initialize the TLS key!");
pybind11_fail("get_internals: could not successfully initialize the tstate TLS key!");
PyThread_set_key_value(internals_ptr->tstate, tstate);
#endif
internals_ptr->istate = tstate->interp;
Expand Down
55 changes: 31 additions & 24 deletions include/pybind11/detail/type_caster_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,47 +31,54 @@ PYBIND11_NAMESPACE_BEGIN(detail)
/// A life support system for temporary objects created by `type_caster::load()`.
/// Adding a patient will keep it alive up until the enclosing function returns.
class loader_life_support {
private:
loader_life_support* parent = nullptr;
std::unordered_set<PyObject *> keep_alive;

static loader_life_support** get_stack_pp() {
#if defined(WITH_THREAD)
thread_local static loader_life_support* per_thread_stack = nullptr;
return &per_thread_stack;
#else
static loader_life_support* global_stack = nullptr;
return &global_stack;
#endif
}

public:
/// A new patient frame is created when a function is entered
loader_life_support() {
get_internals().loader_patient_stack.push_back(nullptr);
loader_life_support** stack = get_stack_pp();
parent = *stack;
*stack = this;
}

/// ... and destroyed after it returns
~loader_life_support() {
auto &stack = get_internals().loader_patient_stack;
if (stack.empty())
loader_life_support** stack = get_stack_pp();
if (*stack != this)
pybind11_fail("loader_life_support: internal error");

auto ptr = stack.back();
stack.pop_back();
Py_CLEAR(ptr);

// A heuristic to reduce the stack's capacity (e.g. after long recursive calls)
if (stack.capacity() > 16 && !stack.empty() && stack.capacity() / stack.size() > 2)
stack.shrink_to_fit();
*stack = parent;
for (auto* item : keep_alive)
Py_DECREF(item);
}

/// This can only be used inside a pybind11-bound function, either by `argument_loader`
/// at argument preparation time or by `py::cast()` at execution time.
PYBIND11_NOINLINE static void add_patient(handle h) {
auto &stack = get_internals().loader_patient_stack;
if (stack.empty())
loader_life_support* frame = *get_stack_pp();
if (!frame) {
// NOTE: It would be nice to include the stack frames here, as this indicates
// use of pybind11::cast<> outside the normal call framework, finding such
// a location is challenging. Developers could consider printing out
// stack frame addresses here using something like __builtin_frame_address(0)
throw cast_error("When called outside a bound function, py::cast() cannot "
"do Python -> C++ conversions which require the creation "
"of temporary values");

auto &list_ptr = stack.back();
if (list_ptr == nullptr) {
list_ptr = PyList_New(1);
if (!list_ptr)
pybind11_fail("loader_life_support: error allocating list");
PyList_SET_ITEM(list_ptr, 0, h.inc_ref().ptr());
} else {
auto result = PyList_Append(list_ptr, h.ptr());
if (result == -1)
pybind11_fail("loader_life_support: error adding patient");
}

if (frame->keep_alive.insert(h.ptr()).second)
Py_INCREF(h.ptr());
}
};

Expand Down
2 changes: 1 addition & 1 deletion pybind11/_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,5 @@ def _to_int(s):
return s


__version__ = "2.8.0.dev1"
__version__ = "2.8.0.dev2"
version_info = tuple(_to_int(s) for s in __version__.split("."))
1 change: 1 addition & 0 deletions tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ set(PYBIND11_TEST_FILES
test_stl.cpp
test_stl_binders.cpp
test_tagbased_polymorphic.cpp
test_thread.cpp
test_union.cpp
test_virtual_functions.cpp)

Expand Down
66 changes: 66 additions & 0 deletions tests/test_thread.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
/*
tests/test_thread.cpp -- call pybind11 bound methods in threads
Copyright (c) 2021 Laramie Leavitt (Google LLC) <lar@google.com>
All rights reserved. Use of this source code is governed by a
BSD-style license that can be found in the LICENSE file.
*/

#include <pybind11/cast.h>
#include <pybind11/pybind11.h>

#include <chrono>
#include <thread>

#include "pybind11_tests.h"

namespace py = pybind11;

namespace {

/// Small copyable wrapper around a single `int`, constructible only explicitly.
///
/// The destructor negates `value`. Presumably this is so that a caller still
/// holding a reference to an already-destroyed instance reads a number that no
/// longer matches what it was constructed with (note that 0 negates to 0, so
/// that particular value cannot reveal a stale read).
struct IntStruct {
    /// Stores `v` as the wrapped value.
    explicit IntStruct(int v) : value(v) {}

    /// Flips the sign of the stored value on destruction.
    ~IntStruct() { value = -value; }

    // Copy operations are defaulted explicitly because a user-declared
    // destructor is present.
    IntStruct(const IntStruct &) = default;
    IntStruct &operator=(const IntStruct &) = default;

    int value;
};

} // namespace

TEST_SUBMODULE(thread, m) {

    // Expose IntStruct to Python with a constructor that takes a single int.
    py::class_<IntStruct> int_struct(m, "IntStruct");
    int_struct.def(py::init([](const int i) { return IntStruct(i); }));

    // When an implicit conversion is needed, implicitly_convertible relies on
    // loader_life_support to lifetime extend the converted reference.
    //
    // Running this test under ASAN makes it more effective.
    py::implicitly_convertible<int, IntStruct>();

    // Sleeps briefly with the GIL released, then verifies that the converted
    // argument still carries the expected value.
    m.def("test", [](int expected, const IntStruct &in) {
        {
            // The GIL is dropped only for the duration of this inner scope.
            py::gil_scoped_release unlocked;
            std::this_thread::sleep_for(std::chrono::milliseconds(5));
        }

        if (in.value == expected) {
            return;
        }
        throw std::runtime_error("Value changed!!");
    });

    // Same check, but the GIL is released for the whole call through a
    // call_guard rather than a scoped release inside the body.
    m.def(
        "test_no_gil",
        [](int expected, const IntStruct &in) {
            std::this_thread::sleep_for(std::chrono::milliseconds(5));
            if (in.value == expected) {
                return;
            }
            throw std::runtime_error("Value changed!!");
        },
        py::call_guard<py::gil_scoped_release>());

    // NOTE: std::string_view keeps its string contents alive through
    // loader_life_support as well, but it is a C++17 feature.
}
44 changes: 44 additions & 0 deletions tests/test_thread.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# -*- coding: utf-8 -*-

import threading

from pybind11_tests import thread as m


class Thread(threading.Thread):
    """Thread that calls ``fn(i, i)`` for ``i`` in ``range(10)``.

    Any ``Exception`` raised by ``fn`` stops the loop and is stored on the
    instance; ``join()`` re-raises it in the joining thread so that a failure
    inside the worker is not lost.
    """

    def __init__(self, fn):
        """Store the callable to run; ``fn`` must accept two positional arguments."""
        super(Thread, self).__init__()
        self.fn = fn
        # Exception captured by run(), or None if none has been raised.
        self.e = None

    def run(self):
        """Call ``fn(i, i)`` ten times, recording the first exception raised."""
        try:
            for i in range(10):
                self.fn(i, i)
        except Exception as e:
            self.e = e

    def join(self, timeout=None):
        """Wait for the thread, then re-raise any exception captured by ``run()``.

        ``timeout`` is forwarded to ``threading.Thread.join`` so this override
        keeps the base-class signature.
        """
        super(Thread, self).join(timeout)
        # Compare against None explicitly: an exception type may define its
        # own truthiness.
        if self.e is not None:
            raise self.e


def test_implicit_conversion():
    """Run ``m.test`` from three threads at once; a failure surfaces on join."""
    workers = [Thread(m.test) for _ in range(3)]
    for worker in workers:
        worker.start()
    # Join in the reverse of the start order.
    for worker in reversed(workers):
        worker.join()


def test_implicit_conversion_no_gil():
    """Run ``m.test_no_gil`` from three threads at once; a failure surfaces on join."""
    workers = [Thread(m.test_no_gil) for _ in range(3)]
    for worker in workers:
        worker.start()
    # Join in the reverse of the start order.
    for worker in reversed(workers):
        worker.join()

0 comments on commit 0e59958

Please sign in to comment.