From ce2f00559498f4070821d540ba446e3a59766154 Mon Sep 17 00:00:00 2001 From: vfdev Date: Thu, 7 Nov 2024 18:32:09 +0100 Subject: [PATCH] Fixed data race in all_type_info in free-threading mode (#5419) * Fix data race all_type_info_populate in free-threading mode Description: - fixed data race all_type_info_populate in free-threading mode - added test For example, we have 2 threads entering `all_type_info`. Both enter `all_type_info_get_cache`` function and there is a first one which inserts a tuple (type, empty_vector) to the map and second is waiting. Inserting thread gets the (iter_to_key, True) and non-inserting thread after waiting gets (iter_to_key, False). Inserting thread than will add a weakref and will then call into `all_type_info_populate`. However, non-inserting thread is not entering `if (ins.second) {` clause and returns `ins.first->second;`` which is just empty_vector. Finally, non-inserting thread is failing the check in `allocate_layout`: ```c++ if (n_types == 0) { pybind11_fail( "instance allocation failed: new instance has no pybind11-registered base types"); } ``` * style: pre-commit fixes * Addressed PR comments --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- include/pybind11/detail/type_caster_base.h | 9 +------ include/pybind11/pybind11.h | 15 ++++++++--- tests/pybind11_tests.cpp | 5 ++++ tests/pybind11_tests.h | 21 ++++++++++++++++ tests/test_class.py | 29 ++++++++++++++++++++++ 5 files changed, 67 insertions(+), 12 deletions(-) diff --git a/include/pybind11/detail/type_caster_base.h b/include/pybind11/detail/type_caster_base.h index 0898be0140..d5d86dc6c1 100644 --- a/include/pybind11/detail/type_caster_base.h +++ b/include/pybind11/detail/type_caster_base.h @@ -117,7 +117,6 @@ PYBIND11_NOINLINE void all_type_info_populate(PyTypeObject *t, std::vector(t->tp_bases)) { check.push_back((PyTypeObject *) parent.ptr()); } - auto const &type_dict = get_internals().registered_types_py; for (size_t i = 0; i < check.size(); i++) { auto *type = check[i]; @@ -176,13 +175,7 @@ PYBIND11_NOINLINE void all_type_info_populate(PyTypeObject *t, std::vector &all_type_info(PyTypeObject *type) { - auto ins = all_type_info_get_cache(type); - if (ins.second) { - // New cache entry: populate it - all_type_info_populate(type, ins.first->second); - } - - return ins.first->second; + return all_type_info_get_cache(type).first->second; } /** diff --git a/include/pybind11/pybind11.h b/include/pybind11/pybind11.h index 2527d25faf..b4f93f1a6a 100644 --- a/include/pybind11/pybind11.h +++ b/include/pybind11/pybind11.h @@ -2326,13 +2326,20 @@ keep_alive_impl(size_t Nurse, size_t Patient, function_call &call, handle ret) { inline std::pair all_type_info_get_cache(PyTypeObject *type) { auto res = with_internals([type](internals &internals) { - return internals - .registered_types_py + auto ins = internals + .registered_types_py #ifdef __cpp_lib_unordered_map_try_emplace - .try_emplace(type); + .try_emplace(type); #else - .emplace(type, std::vector()); + .emplace(type, std::vector()); #endif + if (ins.second) { + // For free-threading mode, this call must be under + // the with_internals() mutex lock, to avoid that other threads + // continue running with the empty ins.first->second. + all_type_info_populate(type, ins.first->second); + } + return ins; }); if (res.second) { // New cache entry created; set up a weak reference to automatically remove it if the type diff --git a/tests/pybind11_tests.cpp b/tests/pybind11_tests.cpp index 3d2d84e77a..818d53a548 100644 --- a/tests/pybind11_tests.cpp +++ b/tests/pybind11_tests.cpp @@ -128,4 +128,9 @@ PYBIND11_MODULE(pybind11_tests, m, py::mod_gil_not_used()) { for (const auto &initializer : initializers()) { initializer(m); } + + py::class_(m, "TestContext") + .def(py::init<>(&TestContext::createNewContextForInit)) + .def("__enter__", &TestContext::contextEnter) + .def("__exit__", &TestContext::contextExit); } diff --git a/tests/pybind11_tests.h b/tests/pybind11_tests.h index 7be58feb6c..0eb0398df0 100644 --- a/tests/pybind11_tests.h +++ b/tests/pybind11_tests.h @@ -96,3 +96,24 @@ void ignoreOldStyleInitWarnings(F &&body) { )", py::dict(py::arg("body") = py::cpp_function(body))); } + +// See PR #5419 for background. +class TestContext { +public: + TestContext() = delete; + TestContext(const TestContext &) = delete; + TestContext(TestContext &&) = delete; + static TestContext *createNewContextForInit() { return new TestContext("new-context"); } + + pybind11::object contextEnter() { + py::object contextObj = py::cast(*this); + return contextObj; + } + void contextExit(const pybind11::object & /*excType*/, + const pybind11::object & /*excVal*/, + const pybind11::object & /*excTb*/) {} + +private: + explicit TestContext(const std::string &context) : context(context) {} + std::string context; +}; diff --git a/tests/test_class.py b/tests/test_class.py index f424db5c35..01963d0122 100644 --- a/tests/test_class.py +++ b/tests/test_class.py @@ -1,5 +1,6 @@ from __future__ import annotations +import sys from unittest import mock import pytest @@ -508,3 +509,31 @@ def test_pr4220_tripped_over_this(): m.Empty0().get_msg() == "This is really only meant to exercise successful compilation." ) + + +@pytest.mark.skipif(sys.platform.startswith("emscripten"), reason="Requires threads") +def test_all_type_info_multithreaded(): + # See PR #5419 for background. + import threading + + from pybind11_tests import TestContext + + class Context(TestContext): + pass + + num_runs = 10 + num_threads = 4 + barrier = threading.Barrier(num_threads) + + def func(): + barrier.wait() + with Context(): + pass + + for _ in range(num_runs): + threads = [threading.Thread(target=func) for _ in range(num_threads)] + for thread in threads: + thread.start() + + for thread in threads: + thread.join()