From 53073cc70d8322012e5747d0d216ac3367ab9f37 Mon Sep 17 00:00:00 2001 From: Pratik Nayak Date: Sat, 13 Jan 2024 20:34:25 +0100 Subject: [PATCH] Add MultiVector scaling and tests --- ...batch_multi_vector_kernel_launcher.hpp.inc | 9 +++- .../base/batch_multi_vector_kernels.hpp.inc | 41 +++++++++++++++++++ dpcpp/base/batch_multi_vector_kernels.dp.cpp | 32 ++++++++++++++- dpcpp/base/batch_multi_vector_kernels.hpp.inc | 18 ++++++++ test/base/batch_multi_vector_kernels.cpp | 20 +++++++++ 5 files changed, 118 insertions(+), 2 deletions(-) diff --git a/common/cuda_hip/base/batch_multi_vector_kernel_launcher.hpp.inc b/common/cuda_hip/base/batch_multi_vector_kernel_launcher.hpp.inc index 8ef1e6de2af..806bccbc692 100644 --- a/common/cuda_hip/base/batch_multi_vector_kernel_launcher.hpp.inc +++ b/common/cuda_hip/base/batch_multi_vector_kernel_launcher.hpp.inc @@ -26,7 +26,14 @@ GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE( template void element_wise_scale(std::shared_ptr exec, const batch::MultiVector* alpha, - batch::MultiVector* x) GKO_NOT_IMPLEMENTED; + batch::MultiVector* x) +{ + const auto num_blocks = x->get_num_batch_items(); + const auto alpha_ub = get_batch_struct(alpha); + const auto x_ub = get_batch_struct(x); + elem_wise_scale_kernel<<get_stream()>>>(alpha_ub, x_ub); +} GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE( GKO_DECLARE_BATCH_MULTI_VECTOR_ELEMENT_WISE_SCALE_KERNEL); diff --git a/common/cuda_hip/base/batch_multi_vector_kernels.hpp.inc b/common/cuda_hip/base/batch_multi_vector_kernels.hpp.inc index 773bdd24637..f5a46056201 100644 --- a/common/cuda_hip/base/batch_multi_vector_kernels.hpp.inc +++ b/common/cuda_hip/base/batch_multi_vector_kernels.hpp.inc @@ -17,6 +17,7 @@ __device__ __forceinline__ void scale( } } + template __global__ __launch_bounds__(default_block_size, sm_oversubscription) void scale_kernel( @@ -32,6 +33,45 @@ __launch_bounds__(default_block_size, sm_oversubscription) void scale_kernel( } +template +__device__ __forceinline__ void elem_wise_scale( + const 
gko::batch::multi_vector::batch_item& alpha, + const gko::batch::multi_vector::batch_item& x) +{ + const int max_li = x.num_rows * x.num_rhs; + for (int li = threadIdx.x; li < max_li; li += blockDim.x) { + const int row = li / x.num_rhs; + const int col = li % x.num_rhs; + + x.values[row * x.stride + col] *= + alpha.values[row * alpha.stride + col]; + } +} + + +template +__global__ __launch_bounds__( + default_block_size, + sm_oversubscription) void elem_wise_scale_kernel(const gko::batch:: + multi_vector:: + uniform_batch< + const ValueType> + alpha, + const gko::batch:: + multi_vector:: + uniform_batch< + ValueType> + x) +{ + for (size_type batch_id = blockIdx.x; batch_id < x.num_batch_items; + batch_id += gridDim.x) { + const auto alpha_b = gko::batch::extract_batch_item(alpha, batch_id); + const auto x_b = gko::batch::extract_batch_item(x, batch_id); + elem_wise_scale(alpha_b, x_b); + } +} + + template __device__ __forceinline__ void add_scaled( const gko::batch::multi_vector::batch_item& alpha, @@ -48,6 +88,7 @@ __device__ __forceinline__ void add_scaled( } } + template __global__ __launch_bounds__( default_block_size, diff --git a/dpcpp/base/batch_multi_vector_kernels.dp.cpp b/dpcpp/base/batch_multi_vector_kernels.dp.cpp index 317f8773bc4..d407aef8840 100644 --- a/dpcpp/base/batch_multi_vector_kernels.dp.cpp +++ b/dpcpp/base/batch_multi_vector_kernels.dp.cpp @@ -102,7 +102,37 @@ GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE( template void element_wise_scale(std::shared_ptr exec, const batch::MultiVector* alpha, - batch::MultiVector* x) GKO_NOT_IMPLEMENTED; + batch::MultiVector* x) +{ + const auto alpha_ub = get_batch_struct(alpha); + const auto x_ub = get_batch_struct(x); + + const int num_rows = x->get_common_size()[0]; + constexpr int max_subgroup_size = config::warp_size; + const auto num_batches = x_ub.num_batch_items; + auto device = exec->get_queue()->get_device(); + long max_group_size = + device.get_info(); + int group_size = + std::min(ceildiv(num_rows, 
max_subgroup_size) * max_subgroup_size, + max_group_size); + + const dim3 block(group_size); + const dim3 grid(num_batches); + + // Launch one work-group per batch item, each with group_size work-items + exec->get_queue()->submit([&](sycl::handler& cgh) { + cgh.parallel_for( + sycl_nd_range(grid, block), [=](sycl::nd_item<3> item_ct1) { + auto group = item_ct1.get_group(); + auto group_id = group.get_group_linear_id(); + const auto alpha_b = + batch::extract_batch_item(alpha_ub, group_id); + const auto x_b = batch::extract_batch_item(x_ub, group_id); + elem_wise_scale_kernel(alpha_b, x_b, item_ct1); + }); + }); +} GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE( GKO_DECLARE_BATCH_MULTI_VECTOR_ELEMENT_WISE_SCALE_KERNEL); diff --git a/dpcpp/base/batch_multi_vector_kernels.hpp.inc b/dpcpp/base/batch_multi_vector_kernels.hpp.inc index 1312632021a..f43ef2ff0da 100644 --- a/dpcpp/base/batch_multi_vector_kernels.hpp.inc +++ b/dpcpp/base/batch_multi_vector_kernels.hpp.inc @@ -20,6 +20,24 @@ __dpct_inline__ void scale_kernel( } +template +__dpct_inline__ void elem_wise_scale_kernel( + const gko::batch::multi_vector::batch_item& alpha, + const gko::batch::multi_vector::batch_item& x, + sycl::nd_item<3>& item_ct1) +{ + const int max_li = x.num_rows * x.num_rhs; + for (int li = item_ct1.get_local_linear_id(); li < max_li; + li += item_ct1.get_local_range().size()) { + const int row = li / x.num_rhs; + const int col = li % x.num_rhs; + + x.values[row * x.stride + col] *= + alpha.values[row * alpha.stride + col]; + } +} + + template __dpct_inline__ void add_scaled_kernel( const gko::batch::multi_vector::batch_item& alpha, diff --git a/test/base/batch_multi_vector_kernels.cpp b/test/base/batch_multi_vector_kernels.cpp index b896509d3e4..ab15e1a99a3 100644 --- a/test/base/batch_multi_vector_kernels.cpp +++ b/test/base/batch_multi_vector_kernels.cpp @@ -42,6 +42,15 @@ class MultiVector : public CommonTestFixture { std::normal_distribution<>(-1.0, 1.0), rand_engine, ref); } + void 
set_up_elem_scale_vector_data(gko::size_type num_vecs, + const int num_rows = 252) + { + x = gen_mtx(batch_size, num_rows, num_vecs); + alpha = gen_mtx(batch_size, num_rows, num_vecs); + dx = gko::clone(exec, x); + dalpha = gko::clone(exec, alpha); + } + void set_up_vector_data(gko::size_type num_vecs, const int num_rows = 252, bool different_alpha = false) { @@ -156,6 +165,17 @@ TEST_F(MultiVector, MultipleVectorScaleWithDifferentAlphaIsEquivalentToRef) } +TEST_F(MultiVector, MultipleVectorElemWiseScaleIsEquivalentToRef) +{ + set_up_elem_scale_vector_data(20); + + x->scale(alpha.get()); + dx->scale(dalpha.get()); + + GKO_ASSERT_BATCH_MTX_NEAR(dx, x, 5 * r::value); +} + + TEST_F(MultiVector, ComputeNorm2SingleSmallIsEquivalentToRef) { set_up_vector_data(1, 10);