diff --git a/common/cuda_hip/matrix/batch_ell_kernel_launcher.hpp.inc b/common/cuda_hip/matrix/batch_ell_kernel_launcher.hpp.inc
index 7b4112ec347..367b9b4ff4c 100644
--- a/common/cuda_hip/matrix/batch_ell_kernel_launcher.hpp.inc
+++ b/common/cuda_hip/matrix/batch_ell_kernel_launcher.hpp.inc
@@ -55,7 +55,12 @@ void scale(std::shared_ptr<const DefaultExecutor> exec,
            const array<ValueType>* col_scale, const array<ValueType>* row_scale,
            batch::matrix::Ell<ValueType, IndexType>* input)
 {
-    GKO_NOT_IMPLEMENTED;
+    const auto num_blocks = input->get_num_batch_items();
+    const auto col_scale_vals = col_scale->get_const_data();
+    const auto row_scale_vals = row_scale->get_const_data();
+    const auto mat_ub = get_batch_struct(input);
+    scale_kernel<<<num_blocks, default_block_size, 0, exec->get_stream()>>>(
+        as_device_type(col_scale_vals), as_device_type(row_scale_vals), mat_ub);
 }
 
 GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INT32_TYPE(
diff --git a/common/cuda_hip/matrix/batch_ell_kernels.hpp.inc b/common/cuda_hip/matrix/batch_ell_kernels.hpp.inc
index f50c43b2018..6bfa2291d70 100644
--- a/common/cuda_hip/matrix/batch_ell_kernels.hpp.inc
+++ b/common/cuda_hip/matrix/batch_ell_kernels.hpp.inc
@@ -125,3 +125,43 @@ __global__ __launch_bounds__(
         x_b.values);
     }
 }
+
+
+template <typename ValueType, typename IndexType>
+__device__ __forceinline__ void scale(
+    const ValueType* const __restrict__ col_scale,
+    const ValueType* const __restrict__ row_scale,
+    const gko::batch::matrix::ell::batch_item<ValueType, IndexType>& mat)
+{
+    for (int tidx = threadIdx.x; tidx < mat.num_rows; tidx += blockDim.x) {
+        auto r_scale = row_scale[tidx];
+        for (size_type idx = 0; idx < mat.num_stored_elems_per_row; idx++) {
+            const auto ind = tidx + idx * mat.stride;
+            const auto col_idx = mat.col_idxs[ind];
+            if (col_idx == invalid_index<IndexType>()) {
+                break;
+            } else {
+                mat.values[ind] *= r_scale * col_scale[col_idx];
+            }
+        }
+    }
+}
+
+
+template <typename ValueType, typename IndexType>
+__global__ void scale_kernel(
+    const ValueType* const __restrict__ col_scale_vals,
+    const ValueType* const __restrict__ row_scale_vals,
+    const gko::batch::matrix::ell::uniform_batch<ValueType, IndexType> mat)
+{
+    auto num_rows = mat.num_rows;
+    auto num_cols = mat.num_cols;
+    for (size_type batch_id = blockIdx.x; batch_id < mat.num_batch_items;
+         batch_id += gridDim.x) {
+        const auto mat_b =
+            gko::batch::matrix::extract_batch_item(mat, batch_id);
+        const auto col_scale_b = col_scale_vals + num_cols * batch_id;
+        const auto row_scale_b = row_scale_vals + num_rows * batch_id;
+        scale(col_scale_b, row_scale_b, mat_b);
+    }
+}
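Note: the CUDA/HIP kernels above implement a batched two-sided diagonal scaling, `A_b <- diag(row_scale_b) * A_b * diag(col_scale_b)`. Blocks stride over batch items, threads within a block stride over rows, and the inner loop walks a row's ELL-packed entries until it reaches an `invalid_index` padding entry. As a reading aid, here is a sequential host-side sketch of the same update; the `ell_item` struct is a hypothetical stand-in for `gko::batch::matrix::ell::batch_item`, and `-1` stands in for `invalid_index<IndexType>()`:

```cpp
// Hypothetical stand-in for gko::batch::matrix::ell::batch_item: ELL
// storage keeps num_stored_elems_per_row packed entries per row, laid
// out column-major with leading dimension `stride`; shorter rows are
// padded with column index -1.
struct ell_item {
    double* values;
    const int* col_idxs;
    int stride;
    int num_rows;
    int num_stored_elems_per_row;
};

// Sequential reference for the device-side scale() above:
// values[ind] *= row_scale[row] * col_scale[col], i.e.
// A <- diag(row_scale) * A * diag(col_scale) for one batch item.
inline void scale_reference(const double* col_scale, const double* row_scale,
                            const ell_item& mat)
{
    for (int row = 0; row < mat.num_rows; row++) {
        const auto r_scale = row_scale[row];
        for (int k = 0; k < mat.num_stored_elems_per_row; k++) {
            const auto ind = row + k * mat.stride;
            const auto col = mat.col_idxs[ind];
            if (col == -1) {
                break;  // padding entry: the rest of this row is empty
            }
            mat.values[ind] *= r_scale * col_scale[col];
        }
    }
}
```

The device kernel is the same loop nest with the outer row loop distributed over the threads of one block (`threadIdx.x` with stride `blockDim.x`) and batch items distributed over blocks.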
diff --git a/dpcpp/matrix/batch_ell_kernels.dp.cpp b/dpcpp/matrix/batch_ell_kernels.dp.cpp
index 34fdac3b4c7..404e38569d8 100644
--- a/dpcpp/matrix/batch_ell_kernels.dp.cpp
+++ b/dpcpp/matrix/batch_ell_kernels.dp.cpp
@@ -147,7 +147,37 @@ void scale(std::shared_ptr<const DefaultExecutor> exec,
            const array<ValueType>* col_scale, const array<ValueType>* row_scale,
            batch::matrix::Ell<ValueType, IndexType>* input)
 {
-    GKO_NOT_IMPLEMENTED;
+    const auto col_scale_vals = col_scale->get_const_data();
+    const auto row_scale_vals = row_scale->get_const_data();
+    const auto num_rows = static_cast<int>(input->get_common_size()[0]);
+    const auto num_cols = static_cast<int>(input->get_common_size()[1]);
+    auto mat_ub = get_batch_struct(input);
+
+    const auto num_batch_items = mat_ub.num_batch_items;
+    auto device = exec->get_queue()->get_device();
+    // TODO: use runtime selection of group size based on num_rows.
+    auto group_size =
+        device.get_info<sycl::info::device::max_work_group_size>();
+
+    const dim3 block(group_size);
+    const dim3 grid(num_batch_items);
+
+    exec->get_queue()->submit([&](sycl::handler& cgh) {
+        cgh.parallel_for(
+            sycl_nd_range(grid, block),
+            [=](sycl::nd_item<3> item_ct1)
+                [[sycl::reqd_sub_group_size(config::warp_size)]] {
+                    auto group = item_ct1.get_group();
+                    auto group_id = group.get_group_linear_id();
+                    const auto col_scale_b =
+                        col_scale_vals + num_cols * group_id;
+                    const auto row_scale_b =
+                        row_scale_vals + num_rows * group_id;
+                    auto mat_item =
+                        batch::matrix::extract_batch_item(mat_ub, group_id);
+                    scale_kernel(col_scale_b, row_scale_b, mat_item, item_ct1);
+                });
+    });
 }
 
 GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INT32_TYPE(
diff --git a/dpcpp/matrix/batch_ell_kernels.hpp.inc b/dpcpp/matrix/batch_ell_kernels.hpp.inc
index f755c74ec8a..fa27d5bf37c 100644
--- a/dpcpp/matrix/batch_ell_kernels.hpp.inc
+++ b/dpcpp/matrix/batch_ell_kernels.hpp.inc
@@ -44,3 +44,24 @@ __dpct_inline__ void advanced_apply_kernel(
         x[tidx] = alpha * temp + beta * x[tidx];
     }
 }
+
+
+template <typename ValueType, typename IndexType>
+__dpct_inline__ void scale_kernel(
+    const ValueType* const col_scale, const ValueType* const row_scale,
+    gko::batch::matrix::ell::batch_item<ValueType, IndexType>& mat,
+    sycl::nd_item<3>& item_ct1)
+{
+    for (int row = item_ct1.get_local_linear_id(); row < mat.num_rows;
+         row += item_ct1.get_local_range().size()) {
+        const ValueType rscale = row_scale[row];
+        for (auto k = 0; k < mat.num_stored_elems_per_row; k++) {
+            auto col_idx = mat.col_idxs[row + mat.stride * k];
+            if (col_idx == invalid_index<IndexType>()) {
+                break;
+            } else {
+                mat.values[row + mat.stride * k] *= rscale * col_scale[col_idx];
+            }
+        }
+    }
+}
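Both backends expect the scaling factors packed batch-item-major: `row_scale` holds `num_batch_items * num_rows` values and `col_scale` holds `num_batch_items * num_cols` values, since the kernels offset into them by `num_rows * batch_id` and `num_cols * batch_id` (`group_id` on the SYCL side). A hedged host-side sketch of preparing such arrays, assuming Ginkgo's public `gko::array` and executor APIs; the sizes and values are purely illustrative:

```cpp
#include <ginkgo/ginkgo.hpp>

#include <vector>

int main()
{
    // Any device executor works; CUDA is used here for illustration.
    auto exec =
        gko::CudaExecutor::create(0, gko::ReferenceExecutor::create());

    // Illustrative sizes; they must match the batch::matrix::Ell that is
    // going to be scaled.
    const gko::size_type num_batch_items = 4;
    const gko::size_type num_rows = 8;
    const gko::size_type num_cols = 8;

    // Batch-item-major packing expected by the kernels: the factors for
    // batch item b start at offset b * num_rows (rows) and b * num_cols
    // (columns).
    std::vector<double> row_scale_host(num_batch_items * num_rows, 2.0);
    std::vector<double> col_scale_host(num_batch_items * num_cols, 0.5);

    // Move the factors to the device.
    auto row_scale = gko::array<double>(exec, row_scale_host.begin(),
                                        row_scale_host.end());
    auto col_scale = gko::array<double>(exec, col_scale_host.begin(),
                                        col_scale_host.end());

    // &col_scale and &row_scale can now be passed (in that order) to the
    // scale() launchers added in this diff, together with the matrix.
}
```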