Skip to content

Commit

Permalink
Merge pull request #804 from steffenlarsen/steffen/fix_scan_order
Browse files: browse the repository at this point in the history
Store scan results by group local ID instead of global ID
  • Loading branch information
bader authored Nov 2, 2023
2 parents 72c07a7 + 075c46c commit 19a205b
Showing 1 changed file with 22 additions and 12 deletions.
34 changes: 22 additions & 12 deletions tests/group_functions/group_scan.h
Original file line number Diff line number Diff line change
@@ -416,29 +416,39 @@ void check_scan_over_group(sycl::queue& queue, sycl::range<D> range, OpT op,
sycl::group<D> group = item.get_group();
sycl::sub_group sub_group = item.get_sub_group();

auto index = item.get_global_linear_id();
local_id_acc[index] = group.get_local_linear_id();
local_id_acc[range_size + index] =
sub_group.get_local_linear_id();
// Use the local id of the item in the group to place results of
// the scan operation in the order of the items.
auto g_index = group.get_group_linear_id() *
group.get_group_linear_range() +
group.get_local_linear_id();
local_id_acc[g_index] = group.get_local_linear_id();

auto res_g_e = exclusive_scan_over_group_helper<T>(
group, ref_input_acc[index], op, with_init);
res_acc[index] = res_g_e;
group, ref_input_acc[g_index], op, with_init);
res_acc[g_index] = res_g_e;
ret_type_acc[0] = std::is_same_v<T, decltype(res_g_e)>;

auto res_g_i = inclusive_scan_over_group_helper<T>(
group, ref_input_acc[index], op, with_init);
res_acc[range_size + index] = res_g_i;
group, ref_input_acc[g_index], op, with_init);
res_acc[range_size + g_index] = res_g_i;
ret_type_acc[1] = std::is_same_v<T, decltype(res_g_i)>;

// Use the local id of the item in the sub-group to place
// results of the scan operation in the order of the items.
auto sg_index = sub_group.get_group_linear_id() *
sub_group.get_group_linear_range() +
sub_group.get_local_linear_id();
local_id_acc[range_size + sg_index] =
sub_group.get_local_linear_id();

auto res_sg_e = exclusive_scan_over_group_helper<T>(
sub_group, ref_input_acc[index], op, with_init);
res_acc[range_size * 2 + index] = res_sg_e;
sub_group, ref_input_acc[sg_index], op, with_init);
res_acc[range_size * 2 + sg_index] = res_sg_e;
ret_type_acc[2] = std::is_same_v<T, decltype(res_sg_e)>;

auto res_sg_i = inclusive_scan_over_group_helper<T>(
sub_group, ref_input_acc[index], op, with_init);
res_acc[range_size * 3 + index] = res_sg_i;
sub_group, ref_input_acc[sg_index], op, with_init);
res_acc[range_size * 3 + sg_index] = res_sg_i;
ret_type_acc[3] = std::is_same_v<T, decltype(res_sg_i)>;
});
})
Expand Down

0 comments on commit 19a205b

Please sign in to comment.