Skip to content

Commit

Permalink
Add ell device kernels
Browse files Browse the repository at this point in the history
  • Loading branch information
pratikvn committed Jan 13, 2024
1 parent f9a3443 commit 1eb406e
Show file tree
Hide file tree
Showing 4 changed files with 98 additions and 2 deletions.
7 changes: 6 additions & 1 deletion common/cuda_hip/matrix/batch_ell_kernel_launcher.hpp.inc
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,12 @@ void scale(std::shared_ptr<const DefaultExecutor> exec,
const array<ValueType>* col_scale, const array<ValueType>* row_scale,
batch::matrix::Ell<ValueType, IndexType>* input)
{
GKO_NOT_IMPLEMENTED;
const auto num_blocks = input->get_num_batch_items();
const auto col_scale_vals = col_scale->get_const_data();
const auto row_scale_vals = row_scale->get_const_data();
const auto mat_ub = get_batch_struct(input);
scale_kernel<<<num_blocks, default_block_size, 0, exec->get_stream()>>>(
as_device_type(col_scale_vals), as_device_type(row_scale_vals), mat_ub);
}

GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INT32_TYPE(
Expand Down
40 changes: 40 additions & 0 deletions common/cuda_hip/matrix/batch_ell_kernels.hpp.inc
Original file line number Diff line number Diff line change
Expand Up @@ -125,3 +125,43 @@ __global__ __launch_bounds__(
x_b.values);
}
}


/**
 * Scales the stored entries of a single ELL batch item in place.
 *
 * Every stored value at (row, col) is multiplied by
 * `row_scale[row] * col_scale[col]`. Rows are distributed over the threads
 * of the calling block: a thread starts at row `threadIdx.x` and advances by
 * `blockDim.x`. The k-th stored entry of a row is read from position
 * `row + k * mat.stride` of `mat.col_idxs` / `mat.values`.
 *
 * @param col_scale  scaling factors, indexed by the column index of an entry
 * @param row_scale  scaling factors, indexed by the row of an entry
 * @param mat  the batch item whose `values` are updated
 */
template <typename ValueType, typename IndexType>
__device__ __forceinline__ void scale(
    const ValueType* const __restrict__ col_scale,
    const ValueType* const __restrict__ row_scale,
    const gko::batch::matrix::ell::batch_item<ValueType, IndexType>& mat)
{
    for (int row = threadIdx.x; row < mat.num_rows; row += blockDim.x) {
        // The row factor is the same for every entry of this row.
        const auto row_factor = row_scale[row];
        for (size_type slot = 0; slot < mat.num_stored_elems_per_row;
             ++slot) {
            const auto pos = row + slot * mat.stride;
            const auto col = mat.col_idxs[pos];
            // The first invalid column index ends the row; the remaining
            // slots are not visited (presumably padding - confirm against
            // the format definition).
            if (col == invalid_index<IndexType>()) {
                break;
            }
            mat.values[pos] *= row_factor * col_scale[col];
        }
    }
}


/**
 * Kernel that applies row and column scaling to every item of a uniform
 * batch of ELL matrices.
 *
 * A block processes the batch items `blockIdx.x`, `blockIdx.x + gridDim.x`,
 * ... until `mat.num_batch_items` is reached, so any grid size covers the
 * whole batch. The per-item work is delegated to the device function
 * `scale`.
 *
 * @param col_scale_vals  column scaling factors of all batch items; the
 *                        factors of item `i` start at offset
 *                        `i * mat.num_cols`
 * @param row_scale_vals  row scaling factors of all batch items; the factors
 *                        of item `i` start at offset `i * mat.num_rows`
 * @param mat  the batch of matrices to be scaled
 */
template <typename ValueType, typename IndexType>
__global__ void scale_kernel(
    const ValueType* const __restrict__ col_scale_vals,
    const ValueType* const __restrict__ row_scale_vals,
    const gko::batch::matrix::ell::uniform_batch<ValueType, IndexType> mat)
{
    const auto rows_per_item = mat.num_rows;
    const auto cols_per_item = mat.num_cols;
    size_type item = blockIdx.x;
    while (item < mat.num_batch_items) {
        // Select the slices of the scaling vectors that belong to this item.
        const auto item_col_scale = col_scale_vals + cols_per_item * item;
        const auto item_row_scale = row_scale_vals + rows_per_item * item;
        const auto item_mat =
            gko::batch::matrix::extract_batch_item(mat, item);
        scale(item_col_scale, item_row_scale, item_mat);
        item += gridDim.x;
    }
}
32 changes: 31 additions & 1 deletion dpcpp/matrix/batch_ell_kernels.dp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,37 @@ void scale(std::shared_ptr<const DefaultExecutor> exec,
const array<ValueType>* col_scale, const array<ValueType>* row_scale,
batch::matrix::Ell<ValueType, IndexType>* input)
{
GKO_NOT_IMPLEMENTED;
const auto col_scale_vals = col_scale->get_const_data();
const auto row_scale_vals = row_scale->get_const_data();
const auto num_rows = static_cast<int>(input->get_common_size()[0]);
const auto num_cols = static_cast<int>(input->get_common_size()[1]);
auto mat_ub = get_batch_struct(input);

const auto num_batch_items = mat_ub.num_batch_items;
auto device = exec->get_queue()->get_device();
// TODO: use runtime selection of group size based on num_rows.
auto group_size =
device.get_info<sycl::info::device::max_work_group_size>();

const dim3 block(group_size);
const dim3 grid(num_batch_items);

exec->get_queue()->submit([&](sycl::handler& cgh) {
cgh.parallel_for(
sycl_nd_range(grid, block),
[=](sycl::nd_item<3> item_ct1)
[[sycl::reqd_sub_group_size(config::warp_size)]] {
auto group = item_ct1.get_group();
auto group_id = group.get_group_linear_id();
const auto col_scale_b =
col_scale_vals + num_cols * group_id;
const auto row_scale_b =
row_scale_vals + num_rows * group_id;
auto mat_item =
batch::matrix::extract_batch_item(mat_ub, group_id);
scale_kernel(col_scale_b, row_scale_b, mat_item, item_ct1);
});
});
}

GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INT32_TYPE(
Expand Down
21 changes: 21 additions & 0 deletions dpcpp/matrix/batch_ell_kernels.hpp.inc
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,24 @@ __dpct_inline__ void advanced_apply_kernel(
x[tidx] = alpha * temp + beta * x[tidx];
}
}


/**
 * Scales the stored entries of a single ELL batch item in place.
 *
 * Every stored value at (row, col) is multiplied by
 * `row_scale[row] * col_scale[col]`. Rows are distributed over the
 * work-items of the calling work-group: a work-item starts at its local
 * linear id and advances by the work-group size. The k-th stored entry of a
 * row is read from position `row + mat.stride * k` of `mat.col_idxs` /
 * `mat.values`.
 *
 * @param col_scale  scaling factors, indexed by the column index of an entry
 * @param row_scale  scaling factors, indexed by the row of an entry
 * @param mat  the batch item whose `values` are updated
 * @param item_ct1  the nd_item of the calling work-item
 */
template <typename ValueType, typename IndexType>
__dpct_inline__ void scale_kernel(
    const ValueType* const col_scale, const ValueType* const row_scale,
    gko::batch::matrix::ell::batch_item<ValueType, IndexType>& mat,
    sycl::nd_item<3>& item_ct1)
{
    // The work-group size does not change while the kernel runs.
    const auto local_size = item_ct1.get_local_range().size();
    for (int i = item_ct1.get_local_linear_id(); i < mat.num_rows;
         i += local_size) {
        const ValueType row_factor = row_scale[i];
        for (auto slot = 0; slot < mat.num_stored_elems_per_row; ++slot) {
            const auto pos = i + mat.stride * slot;
            const auto col = mat.col_idxs[pos];
            // The first invalid column index ends the row; the remaining
            // slots are not visited (presumably padding - confirm against
            // the format definition).
            if (col == invalid_index<IndexType>()) {
                break;
            }
            mat.values[pos] *= row_factor * col_scale[col];
        }
    }
}

0 comments on commit 1eb406e

Please sign in to comment.