Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix crash from concurrent nb::make_iterator<> under free-threading. #832

Merged
merged 1 commit into from
Jan 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 23 additions & 20 deletions include/nanobind/make_iterator.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,27 +71,30 @@ typed<iterator, ValueType> make_iterator_impl(handle scope, const char *name,
"make_iterator_impl(): the generated __next__ would copy elements, so the "
"element type must be copy-constructible");

if (!type<State>().is_valid()) {
class_<State>(scope, name)
.def("__iter__", [](handle h) { return h; })
.def("__next__",
[](State &s) -> ValueType {
if (!s.first_or_done)
++s.it;
else
s.first_or_done = false;

if (s.it == s.end) {
s.first_or_done = true;
throw stop_iteration();
}

return Access()(s.it);
},
std::forward<Extra>(extra)...,
Policy);
{
static ft_mutex mu;
ft_lock_guard lock(mu);
if (!type<State>().is_valid()) {
class_<State>(scope, name)
.def("__iter__", [](handle h) { return h; })
.def("__next__",
[](State &s) -> ValueType {
if (!s.first_or_done)
++s.it;
else
s.first_or_done = false;

if (s.it == s.end) {
s.first_or_done = true;
throw stop_iteration();
}

return Access()(s.it);
},
std::forward<Extra>(extra)...,
Policy);
}
}

return borrow<typed<iterator, ValueType>>(cast(State{
std::forward<Iterator>(first), std::forward<Sentinel>(last), true }));
}
Expand Down
23 changes: 23 additions & 0 deletions tests/common.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import platform
import gc
import pytest
import threading

is_pypy = platform.python_implementation() == 'PyPy'
is_darwin = platform.system() == 'Darwin'
Expand All @@ -17,3 +18,25 @@ def collect() -> None:

xfail_on_pypy_darwin = pytest.mark.xfail(
is_pypy and is_darwin, reason="This test for some reason fails on PyPy/Darwin")


# Helper function to parallelize execution of a function. We intentionally
# don't use Python's thread pools here, so that threads are shut down and
# started afresh between test cases.
def parallelize(func, n_threads):
    """Run ``func()`` concurrently on ``n_threads`` newly created threads.

    Every worker first waits on a shared barrier, so no call to ``func``
    begins before all ``n_threads`` workers have been started.

    Args:
        func: Zero-argument callable that is invoked once per thread.
        n_threads: Number of threads to create.

    Returns:
        A list of length ``n_threads``; entry ``i`` holds the value returned
        by ``func`` on worker ``i``.

    Raises:
        BaseException: If ``func`` raised on one or more workers, the
            exception of the lowest-numbered failing worker is re-raised in
            the calling thread once all workers have been joined.
    """
    barrier = threading.Barrier(n_threads)
    result = [None] * n_threads
    errors = [None] * n_threads

    def wrapper(i):
        barrier.wait()
        try:
            result[i] = func()
        except BaseException as exc:
            # An exception that escapes a thread never reaches the code that
            # joins it, so a failing assertion inside ``func`` would go
            # unnoticed by the caller. Record it and re-raise after joining.
            errors[i] = exc

    workers = []
    for i in range(n_threads):
        worker = threading.Thread(target=wrapper, args=(i,))
        worker.start()
        workers.append(worker)

    for worker in workers:
        worker.join()

    for exc in errors:
        if exc is not None:
            raise exc
    return result
6 changes: 5 additions & 1 deletion tests/test_make_iterator.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import test_make_iterator_ext as t
import pytest
from common import parallelize

data = [
{},
Expand Down Expand Up @@ -30,6 +30,10 @@ def test03_items_iterator():
assert sorted(list(m.items_l())) == sorted(list(d.items()))


def test03_items_iterator_parallel(n_threads=8):
    """Run ``test03_items_iterator`` on ``n_threads`` threads at once."""
    parallelize(test03_items_iterator, n_threads)


def test04_passthrough_iterator():
for d in data:
m = t.StringMap(d)
Expand Down
25 changes: 1 addition & 24 deletions tests/test_thread.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,6 @@
import test_thread_ext as t
from test_thread_ext import Counter

import threading

# Helper function to parallelize execution of a function. We intentionally
# don't use Python's thread pools here, so that threads are shut down and
# started afresh between test cases.
def parallelize(func, n_threads):
    """Run ``func()`` concurrently on ``n_threads`` newly created threads.

    Every worker first waits on a shared barrier, so no call to ``func``
    begins before all ``n_threads`` workers have been started.

    Args:
        func: Zero-argument callable that is invoked once per thread.
        n_threads: Number of threads to create.

    Returns:
        A list of length ``n_threads``; entry ``i`` holds the value returned
        by ``func`` on worker ``i``.

    Raises:
        BaseException: If ``func`` raised on one or more workers, the
            exception of the lowest-numbered failing worker is re-raised in
            the calling thread once all workers have been joined.
    """
    barrier = threading.Barrier(n_threads)
    result = [None] * n_threads
    errors = [None] * n_threads

    def wrapper(i):
        barrier.wait()
        try:
            result[i] = func()
        except BaseException as exc:
            # An exception that escapes a thread never reaches the code that
            # joins it, so a failing assertion inside ``func`` would go
            # unnoticed by the caller. Record it and re-raise after joining.
            errors[i] = exc

    workers = []
    for i in range(n_threads):
        worker = threading.Thread(target=wrapper, args=(i,))
        worker.start()
        workers.append(worker)

    for worker in workers:
        worker.join()

    for exc in errors:
        if exc is not None:
            raise exc
    return result

from common import parallelize

def test01_object_creation(n_threads=8):
# This test hammers 'inst_c2p' from multiple threads, and
Expand Down
Loading