Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[group] Refine scan_over_group for sub-group #839

Merged
merged 5 commits into from
Dec 1, 2023
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 8 additions & 15 deletions tests/group_functions/group_scan.h
Original file line number Diff line number Diff line change
Expand Up @@ -416,11 +416,9 @@ void check_scan_over_group(sycl::queue& queue, sycl::range<D> range, OpT op,
sycl::group<D> group = item.get_group();
sycl::sub_group sub_group = item.get_sub_group();

// Use the local id of the item in the group to place results of
// the scan operation in the order of the items.
auto g_index = group.get_group_linear_id() *
group.get_local_linear_range() +
group.get_local_linear_id();
// Use the global linear id of the item in the group to place
// results of the scan operation in the order of the items.
auto g_index = item.get_global_linear_id();
steffenlarsen marked this conversation as resolved.
Show resolved Hide resolved
local_id_acc[g_index] = group.get_local_linear_id();

auto res_g_e = exclusive_scan_over_group_helper<T>(
Expand All @@ -433,22 +431,17 @@ void check_scan_over_group(sycl::queue& queue, sycl::range<D> range, OpT op,
res_acc[range_size + g_index] = res_g_i;
ret_type_acc[1] = std::is_same_v<T, decltype(res_g_i)>;

// Use the local id of the item in the sub-group to place
// results of the scan operation in the order of the items.
auto sg_index = sub_group.get_group_linear_id() *
sub_group.get_local_linear_range() +
sub_group.get_local_linear_id();
local_id_acc[range_size + sg_index] =
local_id_acc[range_size + g_index] =
sub_group.get_local_linear_id();

auto res_sg_e = exclusive_scan_over_group_helper<T>(
sub_group, ref_input_acc[sg_index], op, with_init);
res_acc[range_size * 2 + sg_index] = res_sg_e;
sub_group, ref_input_acc[g_index], op, with_init);
res_acc[range_size * 2 + g_index] = res_sg_e;
ret_type_acc[2] = std::is_same_v<T, decltype(res_sg_e)>;

auto res_sg_i = inclusive_scan_over_group_helper<T>(
sub_group, ref_input_acc[sg_index], op, with_init);
res_acc[range_size * 3 + sg_index] = res_sg_i;
sub_group, ref_input_acc[g_index], op, with_init);
res_acc[range_size * 3 + g_index] = res_sg_i;
ret_type_acc[3] = std::is_same_v<T, decltype(res_sg_i)>;
});
})
Expand Down