Skip to content

Commit

Permalink
Add ell add_scaled_identity kernels
Browse files Browse the repository at this point in the history
  • Loading branch information
pratikvn committed Jan 14, 2024
1 parent 3d1e807 commit 8209ec7
Show file tree
Hide file tree
Showing 4 changed files with 104 additions and 2 deletions.
10 changes: 9 additions & 1 deletion common/cuda_hip/matrix/batch_ell_kernel_launcher.hpp.inc
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,15 @@ void add_scaled_identity(std::shared_ptr<const DefaultExecutor> exec,
const batch::MultiVector<ValueType>* alpha,
const batch::MultiVector<ValueType>* beta,
batch::matrix::Ell<ValueType, IndexType>* mat)
GKO_NOT_IMPLEMENTED;
{
const auto num_blocks = mat->get_num_batch_items();
const auto alpha_ub = get_batch_struct(alpha);
const auto beta_ub = get_batch_struct(beta);
const auto mat_ub = get_batch_struct(mat);
add_scaled_identity_kernel<<<num_blocks, default_block_size, 0,
exec->get_stream()>>>(alpha_ub, beta_ub,
mat_ub);
}

GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INT32_TYPE(
GKO_DECLARE_BATCH_ELL_ADD_SCALED_IDENTITY_KERNEL);
40 changes: 40 additions & 0 deletions common/cuda_hip/matrix/batch_ell_kernels.hpp.inc
Original file line number Diff line number Diff line change
Expand Up @@ -165,3 +165,43 @@ __global__ void scale_kernel(
scale(col_scale_b, row_scale_b, mat_b);
}
}


template <typename ValueType, typename IndexType>
__device__ __forceinline__ void add_scaled_identity(
const ValueType alpha, const ValueType beta,
const gko::batch::matrix::ell::batch_item<ValueType, IndexType>& mat)
{
for (int tidx = threadIdx.x; tidx < mat.num_rows; tidx += blockDim.x) {
for (size_type idx = 0; idx < mat.num_stored_elems_per_row; idx++) {
const auto ind = tidx + idx * mat.stride;
mat.values[ind] *= beta;
const auto col_idx = mat.col_idxs[ind];
if (col_idx == invalid_index<IndexType>()) {
break;
} else {
if (tidx == col_idx) {
mat.values[ind] += alpha;
}
}
}
}
}


template <typename ValueType, typename IndexType>
__global__ void add_scaled_identity_kernel(
const gko::batch::multi_vector::uniform_batch<const ValueType> alpha,
const gko::batch::multi_vector::uniform_batch<const ValueType> beta,
const gko::batch::matrix::ell::uniform_batch<ValueType, IndexType> mat)
{
const size_type num_batch_items = mat.num_batch_items;
for (size_type batch_id = blockIdx.x; batch_id < num_batch_items;
batch_id += gridDim.x) {
const auto alpha_b = gko::batch::extract_batch_item(alpha, batch_id);
const auto beta_b = gko::batch::extract_batch_item(beta, batch_id);
const auto mat_b =
gko::batch::matrix::extract_batch_item(mat, batch_id);
add_scaled_identity(alpha_b.values[0], beta_b.values[0], mat_b);
}
}
33 changes: 32 additions & 1 deletion dpcpp/matrix/batch_ell_kernels.dp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,38 @@ void add_scaled_identity(std::shared_ptr<const DefaultExecutor> exec,
const batch::MultiVector<ValueType>* alpha,
const batch::MultiVector<ValueType>* beta,
batch::matrix::Ell<ValueType, IndexType>* mat)
GKO_NOT_IMPLEMENTED;
{
const auto alpha_ub = get_batch_struct(alpha);
const auto beta_ub = get_batch_struct(beta);
const auto mat_ub = get_batch_struct(mat);

const auto num_batch_items = mat_ub.num_batch_items;
auto device = exec->get_queue()->get_device();
auto group_size =
device.get_info<sycl::info::device::max_work_group_size>();

const dim3 block(group_size);
const dim3 grid(num_batch_items);

// Launch a kernel that has nbatches blocks, each block has max group size
exec->get_queue()->submit([&](sycl::handler& cgh) {
cgh.parallel_for(
sycl_nd_range(grid, block),
[=](sycl::nd_item<3> item_ct1)
[[sycl::reqd_sub_group_size(config::warp_size)]] {
auto group = item_ct1.get_group();
auto group_id = group.get_group_linear_id();
const auto alpha_b =
gko::batch::extract_batch_item(alpha_ub, group_id);
const auto beta_b =
gko::batch::extract_batch_item(beta_ub, group_id);
const auto mat_b = gko::batch::matrix::extract_batch_item(
mat_ub, group_id);
add_scaled_identity_kernel(
alpha_b.values[0], beta_b.values[0], mat_b, item_ct1);
});
});
}

GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INT32_TYPE(
GKO_DECLARE_BATCH_ELL_ADD_SCALED_IDENTITY_KERNEL);
Expand Down
23 changes: 23 additions & 0 deletions dpcpp/matrix/batch_ell_kernels.hpp.inc
Original file line number Diff line number Diff line change
Expand Up @@ -65,3 +65,26 @@ __dpct_inline__ void scale_kernel(
}
}
}


template <typename ValueType, typename IndexType>
__dpct_inline__ void add_scaled_identity_kernel(
const ValueType alpha, const ValueType beta,
const gko::batch::matrix::ell::batch_item<ValueType, IndexType>& mat,
sycl::nd_item<3>& item_ct1)
{
for (int row = item_ct1.get_local_linear_id(); row < mat.num_rows;
row += item_ct1.get_local_range().size()) {
for (auto k = 0; k < mat.num_stored_elems_per_row; k++) {
auto col_idx = mat.col_idxs[row + mat.stride * k];
mat.values[row + k * mat.stride] *= beta;
if (col_idx == invalid_index<IndexType>()) {
break;
} else {
if (row == col_idx) {
mat.values[row + k * mat.stride] += alpha;
}
}
}
}
}

0 comments on commit 8209ec7

Please sign in to comment.