Skip to content

Commit

Permalink
use macro for subgroup size selection
Browse files Browse the repository at this point in the history
  • Loading branch information
upsj committed Dec 21, 2021
1 parent 9d30279 commit d3a82e5
Show file tree
Hide file tree
Showing 11 changed files with 220 additions and 225 deletions.
7 changes: 7 additions & 0 deletions dpcpp/base/config.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,13 @@ struct config {
};


/*
 * KERNEL_SUBGROUP_SIZE(val) expands to the kernel attribute requesting a
 * required sub-group size of `val` on a SYCL kernel lambda.
 *
 * SYCL 2020 standardized the attribute under the sycl:: namespace; earlier
 * DPC++ releases only accept the vendor spelling intel::reqd_sub_group_size.
 * Compilers implementing SYCL 2020 report SYCL_LANGUAGE_VERSION >= 202000,
 * so the check below selects the vendor spelling for anything older.  Note
 * that if SYCL_LANGUAGE_VERSION is not defined at all, the preprocessor
 * evaluates it as 0 and the pre-2020 (intel::) branch is taken, which is the
 * intended fallback for old toolchains.
 */
#if SYCL_LANGUAGE_VERSION < 202000
#define KERNEL_SUBGROUP_SIZE(val) [[intel::reqd_sub_group_size(val)]]
#else
#define KERNEL_SUBGROUP_SIZE(val) [[sycl::reqd_sub_group_size(val)]]
#endif


} // namespace dpcpp
} // namespace kernels
} // namespace gko
Expand Down
51 changes: 25 additions & 26 deletions dpcpp/base/helper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,19 +56,18 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
* @param name_ the name of the host function with config
* @param kernel_ the kernel name
*/
#define GKO_ENABLE_DEFAULT_HOST(name_, kernel_) \
template <typename... InferredArgs> \
void name_(dim3 grid, dim3 block, gko::size_type, sycl::queue* queue, \
InferredArgs... args) \
{ \
queue->submit([&](sycl::handler& cgh) { \
cgh.parallel_for( \
sycl_nd_range(grid, block), [= \
](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size( \
config::warp_size)]] { \
kernel_(args..., item_ct1); \
}); \
}); \
/*
 * Generates a templated host-side launcher `name_` that submits `kernel_`
 * to the given SYCL queue as a 3D nd-range kernel.
 *
 * The generated function mirrors a CUDA/HIP-style launch signature:
 * (grid, block, <ignored dynamic shared memory size>, queue, kernel args...).
 * All trailing arguments are captured by value into the kernel lambda and
 * forwarded to `kernel_` together with the sycl::nd_item.
 *
 * The lambda is annotated with KERNEL_SUBGROUP_SIZE(config::warp_size) so
 * the device compiler enforces a fixed sub-group size; the macro picks the
 * sycl:: or intel:: attribute spelling depending on the SYCL version.
 * NOTE: comments cannot go inside the macro body itself — a // comment
 * before a line-continuation backslash would swallow the next line.
 */
#define GKO_ENABLE_DEFAULT_HOST(name_, kernel_)                          \
    template <typename... InferredArgs>                                  \
    void name_(dim3 grid, dim3 block, gko::size_type, sycl::queue* queue, \
               InferredArgs... args)                                     \
    {                                                                    \
        queue->submit([&](sycl::handler& cgh) {                          \
            cgh.parallel_for(sycl_nd_range(grid, block),                 \
                             [=](sycl::nd_item<3> item_ct1)              \
                                 KERNEL_SUBGROUP_SIZE(config::warp_size) { \
                                     kernel_(args..., item_ct1);         \
                                 });                                     \
        });                                                              \
    }


Expand All @@ -80,19 +79,19 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
* @param name_ the name of the host function with config
* @param kernel_ the kernel name
*/
#define GKO_ENABLE_DEFAULT_HOST_CONFIG(name_, kernel_) \
template <std::uint32_t encoded, typename... InferredArgs> \
inline void name_(dim3 grid, dim3 block, gko::size_type, \
sycl::queue* queue, InferredArgs... args) \
{ \
queue->submit([&](sycl::handler& cgh) { \
cgh.parallel_for( \
sycl_nd_range(grid, block), [= \
](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size( \
KCfg::decode<1>(encoded))]] { \
kernel_<encoded>(args..., item_ct1); \
}); \
}); \
/*
 * Generates a templated host-side launcher `name_` that submits the
 * configuration-encoded kernel `kernel_<encoded>` to the given SYCL queue
 * as a 3D nd-range kernel.
 *
 * `encoded` is a compile-time std::uint32_t packing the kernel launch
 * configuration; the required sub-group size is extracted from it with
 * KCFG_1D::decode<1>(encoded) and applied to the kernel lambda via
 * KERNEL_SUBGROUP_SIZE (sycl:: vs intel:: attribute spelling chosen by
 * SYCL version).  The launch signature mirrors CUDA/HIP:
 * (grid, block, <ignored dynamic shared memory size>, queue, args...).
 */
#define GKO_ENABLE_DEFAULT_HOST_CONFIG(name_, kernel_)                   \
    template <std::uint32_t encoded, typename... InferredArgs>           \
    inline void name_(dim3 grid, dim3 block, gko::size_type,             \
                      sycl::queue* queue, InferredArgs... args)          \
    {                                                                    \
        queue->submit([&](sycl::handler& cgh) {                          \
            cgh.parallel_for(                                            \
                sycl_nd_range(grid, block),                              \
                [=](sycl::nd_item<3> item_ct1)                           \
                    KERNEL_SUBGROUP_SIZE(KCFG_1D::decode<1>(encoded)) {  \
                        kernel_<encoded>(args..., item_ct1);             \
                    });                                                  \
        });                                                              \
    }

/**
Expand Down
152 changes: 72 additions & 80 deletions dpcpp/base/kernel_launch_reduction.dp.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,7 @@ void generic_kernel_reduction_1d(sycl::handler& cgh, int64 size,
const auto global_size = num_workgroups * wg_size;

cgh.parallel_for(
range, [=
](sycl::nd_item<3> idx) [[sycl::reqd_sub_group_size(sg_size)]] {
range, [=](sycl::nd_item<3> idx) KERNEL_SUBGROUP_SIZE(sg_size) {
auto subgroup_partial = &(*subgroup_partial_acc.get_pointer())[0];
const auto tidx = thread::get_thread_id_flat<int64>(idx);
const auto local_tidx = static_cast<int64>(tidx % wg_size);
Expand Down Expand Up @@ -129,8 +128,7 @@ void generic_kernel_reduction_2d(sycl::handler& cgh, int64 rows, int64 cols,
const auto global_size = num_workgroups * wg_size;

cgh.parallel_for(
range, [=
](sycl::nd_item<3> idx) [[sycl::reqd_sub_group_size(sg_size)]] {
range, [=](sycl::nd_item<3> idx) KERNEL_SUBGROUP_SIZE(sg_size) {
auto subgroup_partial = &(*subgroup_partial_acc.get_pointer())[0];
const auto tidx = thread::get_thread_id_flat<int64>(idx);
const auto local_tidx = static_cast<int64>(tidx % wg_size);
Expand Down Expand Up @@ -310,38 +308,35 @@ void generic_kernel_row_reduction_2d(syn::value_list<int, ssg_size>,
const auto num_workgroups = ceildiv(rows * col_blocks * ssg_size, wg_size);
const auto range = sycl_nd_range(dim3(num_workgroups), dim3(wg_size));
exec->get_queue()->submit([&](sycl::handler& cgh) {
cgh.parallel_for(
range, [=
](sycl::nd_item<3> id) [[sycl::reqd_sub_group_size(sg_size)]] {
const auto idx =
thread::get_subwarp_id_flat<ssg_size, int64>(id);
const auto row = idx % rows;
const auto col_block = idx / rows;
auto partial = identity;
auto subgroup = group::tiled_partition<sg_size>(
group::this_thread_block(id));
auto ssg_rank =
static_cast<int64>(subgroup.thread_rank() % ssg_size);
if (col_block < col_blocks) {
const auto cols_per_part =
ceildiv(ceildiv(cols, ssg_size), col_blocks) * ssg_size;
const auto begin = cols_per_part * col_block;
const auto end = min(begin + cols_per_part, cols);
for (auto col = begin + ssg_rank; col < end;
col += ssg_size) {
partial = op(partial, fn(row, col, args...));
}
cgh.parallel_for(range, [=](sycl::nd_item<3> id) KERNEL_SUBGROUP_SIZE(
sg_size) {
const auto idx = thread::get_subwarp_id_flat<ssg_size, int64>(id);
const auto row = idx % rows;
const auto col_block = idx / rows;
auto partial = identity;
auto subgroup =
group::tiled_partition<sg_size>(group::this_thread_block(id));
auto ssg_rank =
static_cast<int64>(subgroup.thread_rank() % ssg_size);
if (col_block < col_blocks) {
const auto cols_per_part =
ceildiv(ceildiv(cols, ssg_size), col_blocks) * ssg_size;
const auto begin = cols_per_part * col_block;
const auto end = min(begin + cols_per_part, cols);
for (auto col = begin + ssg_rank; col < end; col += ssg_size) {
partial = op(partial, fn(row, col, args...));
}
}
// since we do a sub-subgroup reduction, we can't use reduce
#pragma unroll
for (int i = 1; i < ssg_size; i *= 2) {
partial = op(partial, subgroup.shfl_xor(partial, i));
}
if (col_block < col_blocks && ssg_rank == 0) {
result[(row + col_block * rows) * result_stride] =
finalize(partial);
}
});
for (int i = 1; i < ssg_size; i *= 2) {
partial = op(partial, subgroup.shfl_xor(partial, i));
}
if (col_block < col_blocks && ssg_rank == 0) {
result[(row + col_block * rows) * result_stride] =
finalize(partial);
}
});
});
}

Expand All @@ -367,60 +362,57 @@ void generic_kernel_col_reduction_2d_small(
sycl::access_mode::read_write, sycl::access::target::local>
block_partial_acc(cgh);
const auto range = sycl_nd_range(dim3(row_blocks), dim3(wg_size));
cgh.parallel_for(
range, [=](sycl::nd_item<3> id) [[sycl::reqd_sub_group_size(sg_size)]] {
auto block_partial = &(*block_partial_acc.get_pointer())[0];
const auto ssg_id =
thread::get_subwarp_id_flat<ssg_size, int64>(id);
const auto local_sg_id = id.get_local_id(2) / sg_size;
const auto local_ssg_id = id.get_local_id(2) % sg_size / ssg_size;
const auto ssg_num =
thread::get_subwarp_num_flat<ssg_size, int64>(id);
const auto workgroup = group::this_thread_block(id);
const auto subgroup = group::tiled_partition<sg_size>(workgroup);
const auto sg_rank = subgroup.thread_rank();
const auto ssg_rank = sg_rank % ssg_size;
const auto col = static_cast<int64>(ssg_rank);
auto partial = identity;
// accumulate within a thread
if (col < cols) {
for (auto row = ssg_id; row < rows; row += ssg_num) {
partial = op(partial, fn(row, col, args...));
}
cgh.parallel_for(range, [=](sycl::nd_item<3> id) KERNEL_SUBGROUP_SIZE(
sg_size) {
auto block_partial = &(*block_partial_acc.get_pointer())[0];
const auto ssg_id = thread::get_subwarp_id_flat<ssg_size, int64>(id);
const auto local_sg_id = id.get_local_id(2) / sg_size;
const auto local_ssg_id = id.get_local_id(2) % sg_size / ssg_size;
const auto ssg_num = thread::get_subwarp_num_flat<ssg_size, int64>(id);
const auto workgroup = group::this_thread_block(id);
const auto subgroup = group::tiled_partition<sg_size>(workgroup);
const auto sg_rank = subgroup.thread_rank();
const auto ssg_rank = sg_rank % ssg_size;
const auto col = static_cast<int64>(ssg_rank);
auto partial = identity;
// accumulate within a thread
if (col < cols) {
for (auto row = ssg_id; row < rows; row += ssg_num) {
partial = op(partial, fn(row, col, args...));
}
}
// accumulate between all subsubgroups in the subgroup
#pragma unroll
for (unsigned i = ssg_size; i < sg_size; i *= 2) {
partial = op(partial, subgroup.shfl_xor(partial, i));
}
// store the result to shared memory
if (local_ssg_id == 0) {
block_partial[local_sg_id * ssg_size + ssg_rank] = partial;
}
workgroup.sync();
// in a single thread: accumulate the results
if (local_sg_id == 0) {
partial = identity;
// accumulate the partial results within a thread
if (shared_storage >= sg_size) {
for (unsigned i = ssg_size; i < sg_size; i *= 2) {
partial = op(partial, subgroup.shfl_xor(partial, i));
}
// store the result to shared memory
if (local_ssg_id == 0) {
block_partial[local_sg_id * ssg_size + ssg_rank] = partial;
}
workgroup.sync();
// in a single thread: accumulate the results
if (local_sg_id == 0) {
partial = identity;
// accumulate the partial results within a thread
if (shared_storage >= sg_size) {
#pragma unroll
for (int i = 0; i < shared_storage; i += sg_size) {
partial = op(partial, block_partial[i + sg_rank]);
}
} else if (sg_rank < shared_storage) {
partial = op(partial, block_partial[sg_rank]);
for (int i = 0; i < shared_storage; i += sg_size) {
partial = op(partial, block_partial[i + sg_rank]);
}
} else if (sg_rank < shared_storage) {
partial = op(partial, block_partial[sg_rank]);
}
// accumulate between all subsubgroups in the subgroup
#pragma unroll
for (unsigned i = ssg_size; i < sg_size; i *= 2) {
partial = op(partial, subgroup.shfl_xor(partial, i));
}
if (sg_rank < cols) {
result[sg_rank + id.get_group(2) * cols] =
finalize(partial);
}
for (unsigned i = ssg_size; i < sg_size; i *= 2) {
partial = op(partial, subgroup.shfl_xor(partial, i));
}
});
if (sg_rank < cols) {
result[sg_rank + id.get_group(2) * cols] = finalize(partial);
}
}
});
}


Expand All @@ -440,7 +432,7 @@ void generic_kernel_col_reduction_2d_blocked(
sycl::access_mode::read_write, sycl::access::target::local>
block_partial_acc(cgh);
cgh.parallel_for(
range, [=](sycl::nd_item<3> id) [[sycl::reqd_sub_group_size(sg_size)]] {
range, [=](sycl::nd_item<3> id) KERNEL_SUBGROUP_SIZE(sg_size) {
const auto sg_id = thread::get_subwarp_id_flat<sg_size, int64>(id);
const auto sg_num =
thread::get_subwarp_num_flat<sg_size, int64>(id);
Expand Down
33 changes: 16 additions & 17 deletions dpcpp/matrix/csr_kernels.dp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -678,9 +678,8 @@ void abstract_classical_spmv(dim3 grid, dim3 block,
{
queue->submit([&](sycl::handler& cgh) {
cgh.parallel_for(
sycl_nd_range(grid, block), [=
](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(
subgroup_size)]] {
sycl_nd_range(grid, block),
[=](sycl::nd_item<3> item_ct1) KERNEL_SUBGROUP_SIZE(subgroup_size) {
abstract_classical_spmv<subgroup_size>(num_rows, val, col_idxs,
row_ptrs, b, b_stride, c,
c_stride, item_ct1);
Expand Down Expand Up @@ -971,13 +970,14 @@ void reduce_total_cols(dim3 grid, dim3 block, size_type dynamic_shared_memory,
sycl::access::target::local>
block_result_acc_ct1(sycl::range<1>(default_block_size), cgh);

cgh.parallel_for(
sycl_nd_range(grid, block), [=
](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(
config::warp_size)]] {
reduce_total_cols(num_slices, max_nnz_per_slice, result,
item_ct1, block_result_acc_ct1.get_pointer());
});
cgh.parallel_for(sycl_nd_range(grid, block),
[=](sycl::nd_item<3> item_ct1)
KERNEL_SUBGROUP_SIZE(config::warp_size) {
reduce_total_cols(
num_slices, max_nnz_per_slice, result,
item_ct1,
block_result_acc_ct1.get_pointer());
});
});
}

Expand All @@ -1004,13 +1004,12 @@ void reduce_max_nnz(dim3 grid, dim3 block, size_type dynamic_shared_memory,
sycl::access::target::local>
block_max_acc_ct1(sycl::range<1>(default_block_size), cgh);

cgh.parallel_for(
sycl_nd_range(grid, block), [=
](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(
config::warp_size)]] {
reduce_max_nnz(size, nnz_per_row, result, item_ct1,
block_max_acc_ct1.get_pointer());
});
cgh.parallel_for(sycl_nd_range(grid, block),
[=](sycl::nd_item<3> item_ct1) KERNEL_SUBGROUP_SIZE(
config::warp_size) {
reduce_max_nnz(size, nnz_per_row, result, item_ct1,
block_max_acc_ct1.get_pointer());
});
});
}

Expand Down
30 changes: 15 additions & 15 deletions dpcpp/matrix/dense_kernels.dp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -301,13 +301,13 @@ void reduce_max_nnz(dim3 grid, dim3 block, size_type dynamic_shared_memory,
dpct_local_acc_ct1(sycl::range<1>(dynamic_shared_memory), cgh);


cgh.parallel_for(
sycl_nd_range(grid, block), [=
](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(
KCFG_1D::decode<1>(cfg))]] {
reduce_max_nnz<cfg>(size, nnz_per_row, result, item_ct1,
dpct_local_acc_ct1.get_pointer().get());
});
cgh.parallel_for(sycl_nd_range(grid, block),
[=](sycl::nd_item<3> item_ct1)
KERNEL_SUBGROUP_SIZE(KCFG_1D::decode<1>(cfg)) {
reduce_max_nnz<cfg>(
size, nnz_per_row, result, item_ct1,
dpct_local_acc_ct1.get_pointer().get());
});
});
}

Expand Down Expand Up @@ -382,14 +382,14 @@ void reduce_total_cols(dim3 grid, dim3 block, size_type dynamic_shared_memory,
sycl::access::target::local>
dpct_local_acc_ct1(sycl::range<1>(dynamic_shared_memory), cgh);

cgh.parallel_for(
sycl_nd_range(grid, block), [=
](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(
KCFG_1D::decode<1>(cfg))]] {
reduce_total_cols<cfg>(num_slices, max_nnz_per_slice, result,
item_ct1,
dpct_local_acc_ct1.get_pointer().get());
});
cgh.parallel_for(sycl_nd_range(grid, block),
[=](sycl::nd_item<3> item_ct1)
KERNEL_SUBGROUP_SIZE(KCFG_1D::decode<1>(cfg)) {
reduce_total_cols<cfg>(
num_slices, max_nnz_per_slice, result,
item_ct1,
dpct_local_acc_ct1.get_pointer().get());
});
});
}

Expand Down
13 changes: 7 additions & 6 deletions dpcpp/matrix/ell_kernels.dp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -371,12 +371,13 @@ void count_nnz_per_row(dim3 grid, dim3 block, size_type dynamic_shared_memory,
const ValueType* values, IndexType* result)
{
queue->submit([&](sycl::handler& cgh) {
cgh.parallel_for(
sycl_nd_range(grid, block), [=
](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(warp_size)]] {
count_nnz_per_row(num_rows, max_nnz_per_row, stride, values,
result, item_ct1);
});
cgh.parallel_for(sycl_nd_range(grid, block),
[=](sycl::nd_item<3> item_ct1)
KERNEL_SUBGROUP_SIZE(config::warp_size) {
count_nnz_per_row(num_rows, max_nnz_per_row,
stride, values, result,
item_ct1);
});
});
}

Expand Down
Loading

0 comments on commit d3a82e5

Please sign in to comment.