Skip to content

Commit

Permalink
Add MultiVector scaling and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
pratikvn committed Jan 13, 2024
1 parent 1eb406e commit 53073cc
Show file tree
Hide file tree
Showing 5 changed files with 118 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,14 @@ GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(
template <typename ValueType>
void element_wise_scale(std::shared_ptr<const DefaultExecutor> exec,
const batch::MultiVector<ValueType>* alpha,
batch::MultiVector<ValueType>* x) GKO_NOT_IMPLEMENTED;
batch::MultiVector<ValueType>* x)
{
const auto num_blocks = x->get_num_batch_items();
const auto alpha_ub = get_batch_struct(alpha);
const auto x_ub = get_batch_struct(x);
elem_wise_scale_kernel<<<num_blocks, default_block_size, 0,
exec->get_stream()>>>(alpha_ub, x_ub);
}

GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(
GKO_DECLARE_BATCH_MULTI_VECTOR_ELEMENT_WISE_SCALE_KERNEL);
Expand Down
41 changes: 41 additions & 0 deletions common/cuda_hip/base/batch_multi_vector_kernels.hpp.inc
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ __device__ __forceinline__ void scale(
}
}


template <typename ValueType, typename Mapping>
__global__
__launch_bounds__(default_block_size, sm_oversubscription) void scale_kernel(
Expand All @@ -32,6 +33,45 @@ __launch_bounds__(default_block_size, sm_oversubscription) void scale_kernel(
}


template <typename ValueType>
__device__ __forceinline__ void elem_wise_scale(
const gko::batch::multi_vector::batch_item<const ValueType>& alpha,
const gko::batch::multi_vector::batch_item<ValueType>& x)
{
const int max_li = x.num_rows * x.num_rhs;
for (int li = threadIdx.x; li < max_li; li += blockDim.x) {
const int row = li / x.num_rhs;
const int col = li % x.num_rhs;

x.values[row * x.stride + col] *=
alpha.values[row * alpha.stride + col];
}
}


template <typename ValueType>
__global__ __launch_bounds__(
default_block_size,
sm_oversubscription) void elem_wise_scale_kernel(const gko::batch::
multi_vector::
uniform_batch<
const ValueType>
alpha,
const gko::batch::
multi_vector::
uniform_batch<
ValueType>
x)
{
for (size_type batch_id = blockIdx.x; batch_id < x.num_batch_items;
batch_id += gridDim.x) {
const auto alpha_b = gko::batch::extract_batch_item(alpha, batch_id);
const auto x_b = gko::batch::extract_batch_item(x, batch_id);
elem_wise_scale(alpha_b, x_b);
}
}


template <typename ValueType, typename Mapping>
__device__ __forceinline__ void add_scaled(
const gko::batch::multi_vector::batch_item<const ValueType>& alpha,
Expand All @@ -48,6 +88,7 @@ __device__ __forceinline__ void add_scaled(
}
}


template <typename ValueType, typename Mapping>
__global__ __launch_bounds__(
default_block_size,
Expand Down
32 changes: 31 additions & 1 deletion dpcpp/base/batch_multi_vector_kernels.dp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,37 @@ GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(
template <typename ValueType>
void element_wise_scale(std::shared_ptr<const DefaultExecutor> exec,
const batch::MultiVector<ValueType>* alpha,
batch::MultiVector<ValueType>* x) GKO_NOT_IMPLEMENTED;
batch::MultiVector<ValueType>* x)
{
const auto alpha_ub = get_batch_struct(alpha);
const auto x_ub = get_batch_struct(x);

const int num_rows = x->get_common_size()[0];
constexpr int max_subgroup_size = config::warp_size;
const auto num_batches = x_ub.num_batch_items;
auto device = exec->get_queue()->get_device();
long max_group_size =
device.get_info<sycl::info::device::max_work_group_size>();
int group_size =
std::min(ceildiv(num_rows, max_subgroup_size) * max_subgroup_size,
max_group_size);

const dim3 block(group_size);
const dim3 grid(num_batches);

// Launch a kernel that has nbatches blocks, each block has max group size
exec->get_queue()->submit([&](sycl::handler& cgh) {
cgh.parallel_for(
sycl_nd_range(grid, block), [=](sycl::nd_item<3> item_ct1) {
auto group = item_ct1.get_group();
auto group_id = group.get_group_linear_id();
const auto alpha_b =
batch::extract_batch_item(alpha_ub, group_id);
const auto x_b = batch::extract_batch_item(x_ub, group_id);
elem_wise_scale_kernel(alpha_b, x_b, item_ct1);
});
});
}

GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(
GKO_DECLARE_BATCH_MULTI_VECTOR_ELEMENT_WISE_SCALE_KERNEL);
Expand Down
18 changes: 18 additions & 0 deletions dpcpp/base/batch_multi_vector_kernels.hpp.inc
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,24 @@ __dpct_inline__ void scale_kernel(
}


template <typename ValueType>
__dpct_inline__ void elem_wise_scale_kernel(
const gko::batch::multi_vector::batch_item<const ValueType>& alpha,
const gko::batch::multi_vector::batch_item<ValueType>& x,
sycl::nd_item<3>& item_ct1)
{
const int max_li = x.num_rows * x.num_rhs;
for (int li = item_ct1.get_local_linear_id(); li < max_li;
li += item_ct1.get_local_range().size()) {
const int row = li / x.num_rhs;
const int col = li % x.num_rhs;

x.values[row * x.stride + col] *=
alpha.values[row * alpha.stride + col];
}
}


template <typename ValueType, typename Mapping>
__dpct_inline__ void add_scaled_kernel(
const gko::batch::multi_vector::batch_item<const ValueType>& alpha,
Expand Down
20 changes: 20 additions & 0 deletions test/base/batch_multi_vector_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,15 @@ class MultiVector : public CommonTestFixture {
std::normal_distribution<>(-1.0, 1.0), rand_engine, ref);
}

void set_up_elem_scale_vector_data(gko::size_type num_vecs,
const int num_rows = 252)
{
x = gen_mtx<Mtx>(batch_size, num_rows, num_vecs);
alpha = gen_mtx<Mtx>(batch_size, num_rows, num_vecs);
dx = gko::clone(exec, x);
dalpha = gko::clone(exec, alpha);
}

void set_up_vector_data(gko::size_type num_vecs, const int num_rows = 252,
bool different_alpha = false)
{
Expand Down Expand Up @@ -156,6 +165,17 @@ TEST_F(MultiVector, MultipleVectorScaleWithDifferentAlphaIsEquivalentToRef)
}


TEST_F(MultiVector, MultipleVectorElemWiseScaleIsEquivalentToRef)
{
set_up_elem_scale_vector_data(20);

x->scale(alpha.get());
dx->scale(dalpha.get());

GKO_ASSERT_BATCH_MTX_NEAR(dx, x, 5 * r<value_type>::value);
}


TEST_F(MultiVector, ComputeNorm2SingleSmallIsEquivalentToRef)
{
set_up_vector_data(1, 10);
Expand Down

0 comments on commit 53073cc

Please sign in to comment.