Make threaded 'for' API more explicit #79952

Closed
45 changes: 24 additions & 21 deletions core/object/worker_thread_pool.h
@@ -187,6 +187,30 @@ class WorkerThreadPool : public Object {
ud->userdata = p_userdata;
return _add_group_task(Callable(), nullptr, nullptr, ud, p_elements, p_tasks, p_high_priority, p_description);
}

+// Syntactic sugar for add_template_group_task().
+// Despite the seeming simplicity, the decision to use this over a plain loop
+// should be carefully considered, since this won't benefit most cases.
+template <typename F>
+static void parallel_for(int p_begin, int p_end, bool p_parallel, String p_name, F p_function) {
+if (!p_parallel) {
+for (int i = p_begin; i < p_end; i++) {
+p_function(i);
+}
+return;
+}
+
+auto wrapper = [&](int p_index, void *) {
+p_function(p_index + p_begin);
+};
+
+WorkerThreadPool::GroupID gid = singleton->add_template_group_task(
+&wrapper, &decltype(wrapper)::operator(), nullptr,
+p_end - p_begin, -1,
+true, p_name);
+singleton->wait_for_group_task_completion(gid);
+}

GroupID add_native_group_task(void (*p_func)(void *, uint32_t), void *p_userdata, int p_elements, int p_tasks = -1, bool p_high_priority = false, const String &p_description = String());
GroupID add_group_task(const Callable &p_action, int p_elements, int p_tasks = -1, bool p_high_priority = false, const String &p_description = String());
uint32_t get_group_processed_element_count(GroupID p_group) const;
@@ -202,25 +226,4 @@
~WorkerThreadPool();
};

-template <typename F>
-static _FORCE_INLINE_ void for_range(int i_begin, int i_end, bool parallel, String name, F f) {
-if (!parallel) {
-for (int i = i_begin; i < i_end; i++) {
-f(i);
-}
-return;
-}
-
-auto wrapper = [&](int i, void *unused) {
-f(i + i_begin);
-};
-
-WorkerThreadPool *wtp = WorkerThreadPool::get_singleton();
-WorkerThreadPool::GroupID gid = wtp->add_template_group_task(
-&wrapper, &decltype(wrapper)::operator(), nullptr,
-i_end - i_begin, -1,
-true, name);
-wtp->wait_for_group_task_completion(gid);
-}

#endif // WORKER_THREAD_POOL_H
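
For context, a minimal usage sketch of the new WorkerThreadPool::parallel_for helper (not part of this pull request; scale_values, p_values, p_count, and p_scale are hypothetical names, and the p_count > 1024 threshold mirrors the vertices_size > 1024 check in the raycast change below):

#include "core/object/worker_thread_pool.h"

// Hypothetical caller: scale a buffer of floats, threading only when the batch
// is large enough to amortize the task-group overhead.
void scale_values(float *p_values, int p_count, float p_scale) {
	WorkerThreadPool::parallel_for(0, p_count, p_count > 1024, "ScaleValues", [&](int i) {
		p_values[i] *= p_scale;
	});
}

Either way the call is synchronous: the serial branch runs the loop in place, while the parallel branch submits the range as a single high-priority group task and waits for it to complete before returning.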
2 changes: 1 addition & 1 deletion modules/raycast/raycast_occlusion_cull.cpp
@@ -355,7 +355,7 @@ void RaycastOcclusionCull::Scenario::_update_dirty_instance(int p_idx, RID *p_in
// Embree requires the last element to be readable by a 16-byte SSE load instruction, so we add padding to be safe.
occ_inst->xformed_vertices.resize(vertices_size + 1);

-for_range(0, vertices_size, vertices_size > 1024, SNAME("RaycastOcclusionCull"), [&](const int i) {
+WorkerThreadPool::parallel_for(0, vertices_size, vertices_size > 1024, SNAME("RaycastOcclusionCull"), [&](const int i) {
occ_inst->xformed_vertices[i] = occ_inst->xform.xform(occ->vertices[i]);
});

6 changes: 3 additions & 3 deletions tests/core/threads/test_worker_thread_pool.h
@@ -113,17 +113,17 @@ TEST_CASE("[WorkerThreadPool] Parallel foreach") {
LocalVector<int> c;
c.resize(count_max);

-for_range(0, count_max, true, String(), [&](int i) {
+WorkerThreadPool::parallel_for(0, count_max, true, String(), [&](int i) {
c[i] = 1;
});
c.sort();
CHECK(c[0] == 1);
CHECK(c[0] == c[count_max - 1]);

-for_range(0, midpoint, false, String(), [&](int i) {
+WorkerThreadPool::parallel_for(0, midpoint, false, String(), [&](int i) {
c[i]++;
});
-for_range(midpoint, count_max, true, String(), [&](int i) {
+WorkerThreadPool::parallel_for(midpoint, count_max, true, String(), [&](int i) {
c[i]++;
});
c.sort();