Skip to content

Commit

Permalink
Add ell ref/omp scale kernels and test
Browse files Browse the repository at this point in the history
  • Loading branch information
pratikvn committed Jan 13, 2024
1 parent 792ca20 commit f9a3443
Show file tree
Hide file tree
Showing 10 changed files with 153 additions and 12 deletions.
12 changes: 12 additions & 0 deletions common/cuda_hip/matrix/batch_ell_kernel_launcher.hpp.inc
Original file line number Diff line number Diff line change
Expand Up @@ -48,3 +48,15 @@ void advanced_apply(std::shared_ptr<const DefaultExecutor> exec,

GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INT32_TYPE(
GKO_DECLARE_BATCH_ELL_ADVANCED_APPLY_KERNEL);


template <typename ValueType, typename IndexType>
void scale(std::shared_ptr<const DefaultExecutor> exec,
const array<ValueType>* col_scale, const array<ValueType>* row_scale,
batch::matrix::Ell<ValueType, IndexType>* input)
{
GKO_NOT_IMPLEMENTED;
}

GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INT32_TYPE(
GKO_DECLARE_BATCH_ELL_SCALE_KERNEL);
1 change: 1 addition & 0 deletions core/device_hooks/common_kernels.inc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,7 @@ namespace batch_ell {

GKO_STUB_VALUE_AND_INT32_TYPE(GKO_DECLARE_BATCH_ELL_SIMPLE_APPLY_KERNEL);
GKO_STUB_VALUE_AND_INT32_TYPE(GKO_DECLARE_BATCH_ELL_ADVANCED_APPLY_KERNEL);
GKO_STUB_VALUE_AND_INT32_TYPE(GKO_DECLARE_BATCH_ELL_SCALE_KERNEL);


} // namespace batch_ell
Expand Down
8 changes: 7 additions & 1 deletion core/matrix/batch_ell.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ namespace {

GKO_REGISTER_OPERATION(simple_apply, batch_ell::simple_apply);
GKO_REGISTER_OPERATION(advanced_apply, batch_ell::advanced_apply);
GKO_REGISTER_OPERATION(scale, batch_ell::scale);


} // namespace
Expand Down Expand Up @@ -210,7 +211,12 @@ void two_sided_scale(const array<ValueType>& col_scale,
const array<ValueType>& row_scale,
batch::matrix::Ell<ValueType, IndexType>* in_out)
{
GKO_NOT_IMPLEMENTED;
GKO_ASSERT_EQ(col_scale.get_size(), (in_out->get_common_size()[1] *
in_out->get_num_batch_items()));
GKO_ASSERT_EQ(row_scale.get_size(), (in_out->get_common_size()[0] *
in_out->get_num_batch_items()));
in_out->get_executor()->run(
ell::make_scale(&col_scale, &row_scale, in_out));
}


Expand Down
19 changes: 14 additions & 5 deletions core/matrix/batch_ell_kernels.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,20 @@ namespace kernels {
const batch::MultiVector<_vtype>* beta, \
batch::MultiVector<_vtype>* c)

#define GKO_DECLARE_ALL_AS_TEMPLATES \
template <typename ValueType, typename IndexType> \
GKO_DECLARE_BATCH_ELL_SIMPLE_APPLY_KERNEL(ValueType, IndexType); \
template <typename ValueType, typename IndexType> \
GKO_DECLARE_BATCH_ELL_ADVANCED_APPLY_KERNEL(ValueType, IndexType)
#define GKO_DECLARE_BATCH_ELL_SCALE_KERNEL(_vtype, _itype) \
void scale(std::shared_ptr<const DefaultExecutor> exec, \
const array<_vtype>* left_scale, \
const array<_vtype>* right_scale, \
batch::matrix::Ell<_vtype, _itype>* input)


#define GKO_DECLARE_ALL_AS_TEMPLATES \
template <typename ValueType, typename IndexType> \
GKO_DECLARE_BATCH_ELL_SIMPLE_APPLY_KERNEL(ValueType, IndexType); \
template <typename ValueType, typename IndexType> \
GKO_DECLARE_BATCH_ELL_ADVANCED_APPLY_KERNEL(ValueType, IndexType); \
template <typename ValueType, typename IndexType> \
GKO_DECLARE_BATCH_ELL_SCALE_KERNEL(ValueType, IndexType)


GKO_DECLARE_FOR_ALL_EXECUTOR_NAMESPACES(batch_ell,
Expand Down
12 changes: 12 additions & 0 deletions dpcpp/matrix/batch_ell_kernels.dp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,18 @@ GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INT32_TYPE(
GKO_DECLARE_BATCH_ELL_ADVANCED_APPLY_KERNEL);


template <typename ValueType, typename IndexType>
void scale(std::shared_ptr<const DefaultExecutor> exec,
const array<ValueType>* col_scale, const array<ValueType>* row_scale,
batch::matrix::Ell<ValueType, IndexType>* input)
{
GKO_NOT_IMPLEMENTED;
}

GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INT32_TYPE(
GKO_DECLARE_BATCH_ELL_SCALE_KERNEL);


} // namespace batch_ell
} // namespace dpcpp
} // namespace kernels
Expand Down
27 changes: 27 additions & 0 deletions omp/matrix/batch_ell_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,33 @@ GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INT32_TYPE(
GKO_DECLARE_BATCH_ELL_ADVANCED_APPLY_KERNEL);


template <typename ValueType, typename IndexType>
void scale(std::shared_ptr<const DefaultExecutor> exec,
const array<ValueType>* col_scale, const array<ValueType>* row_scale,
batch::matrix::Ell<ValueType, IndexType>* input)
{
const auto col_scale_vals = col_scale->get_const_data();
const auto row_scale_vals = row_scale->get_const_data();
auto input_vals = input->get_values();
const auto num_rows = static_cast<int>(input->get_common_size()[0]);
const auto num_cols = static_cast<int>(input->get_common_size()[1]);
const auto stride = input->get_common_size()[1];
const auto mat_ub = host::get_batch_struct(input);
#pragma omp parallel for
for (size_type batch_id = 0; batch_id < input->get_num_batch_items();
++batch_id) {
const auto col_scale_b = col_scale_vals + num_cols * batch_id;
const auto row_scale_b = row_scale_vals + num_rows * batch_id;
const auto mat_item =
batch::matrix::extract_batch_item(mat_ub, batch_id);
scale(col_scale_b, row_scale_b, mat_item);
}
}

GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INT32_TYPE(
GKO_DECLARE_BATCH_ELL_SCALE_KERNEL);


} // namespace batch_ell
} // namespace omp
} // namespace kernels
Expand Down
26 changes: 26 additions & 0 deletions reference/matrix/batch_ell_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,32 @@ GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INT32_TYPE(
GKO_DECLARE_BATCH_ELL_ADVANCED_APPLY_KERNEL);


template <typename ValueType, typename IndexType>
void scale(std::shared_ptr<const DefaultExecutor> exec,
const array<ValueType>* col_scale, const array<ValueType>* row_scale,
batch::matrix::Ell<ValueType, IndexType>* input)
{
const auto col_scale_vals = col_scale->get_const_data();
const auto row_scale_vals = row_scale->get_const_data();
auto input_vals = input->get_values();
const auto num_rows = static_cast<int>(input->get_common_size()[0]);
const auto num_cols = static_cast<int>(input->get_common_size()[1]);
const auto stride = input->get_common_size()[1];
const auto mat_ub = host::get_batch_struct(input);
for (size_type batch_id = 0; batch_id < input->get_num_batch_items();
++batch_id) {
const auto col_scale_b = col_scale_vals + num_cols * batch_id;
const auto row_scale_b = row_scale_vals + num_rows * batch_id;
const auto mat_item =
batch::matrix::extract_batch_item(mat_ub, batch_id);
scale(col_scale_b, row_scale_b, mat_item);
}
}

GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INT32_TYPE(
GKO_DECLARE_BATCH_ELL_SCALE_KERNEL);


} // namespace batch_ell
} // namespace reference
} // namespace kernels
Expand Down
16 changes: 10 additions & 6 deletions reference/matrix/batch_ell_kernels.hpp.inc
Original file line number Diff line number Diff line change
Expand Up @@ -53,15 +53,19 @@ inline void advanced_apply_kernel(


template <typename ValueType, typename IndexType>
inline void add_scaled_identity_kernel(
const ValueType alpha, const ValueType beta,
inline void scale(
const ValueType* const col_scale, const ValueType* const row_scale,
const gko::batch::matrix::ell::batch_item<ValueType, IndexType>& mat)
{
for (int row = 0; row < mat.num_rows; row++) {
for (int k = 0; k < mat.num_stored_elems_per_row; k++) {
mat.values[row + k * mat.stride] *= beta;
if (row == mat.col_idxs[row + k * mat.stride]) {
mat.values[row + k * mat.stride] += alpha;
const ValueType r_scalar = row_scale[row];
for (auto k = 0; k < mat.num_stored_elems_per_row; ++k) {
auto col_idx = mat.col_idxs[row + mat.stride * k];
if (col_idx == invalid_index<IndexType>()) {
break;
} else {
mat.values[row + mat.stride * k] *=
r_scalar * col_scale[col_idx];
}
}
}
Expand Down
21 changes: 21 additions & 0 deletions reference/test/matrix/batch_ell_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,27 @@ TYPED_TEST(Ell, ConstAppliesLinearCombinationToBatchMultiVector)
}


TYPED_TEST(Ell, CanTwoSidedScale)
{
using value_type = typename TestFixture::value_type;
using index_type = gko::int32;
using BMtx = typename TestFixture::BMtx;
auto col_scale = gko::array<value_type>(this->exec, 3 * 2);
auto row_scale = gko::array<value_type>(this->exec, 2 * 2);
col_scale.fill(2);
row_scale.fill(3);

gko::batch::matrix::two_sided_scale<value_type, index_type>(
col_scale, row_scale, this->mtx_0.get());

auto scaled_mtx_0 =
gko::batch::initialize<BMtx>({{{6.0, -6.0, 9.0}, {-12.0, 12.0, 18.0}},
{{6.0, -12.0, -3.0}, {6.0, -15.0, 24.0}}},
this->exec);
GKO_ASSERT_BATCH_MTX_NEAR(this->mtx_0.get(), scaled_mtx_0.get(), 0.);
}


TYPED_TEST(Ell, ApplyFailsOnWrongNumberOfResultCols)
{
using BMVec = typename TestFixture::BMVec;
Expand Down
23 changes: 23 additions & 0 deletions test/matrix/batch_ell_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,14 @@ class Ell : public CommonTestFixture {
dy = gko::clone(exec, y);
dalpha = gko::clone(exec, alpha);
dbeta = gko::clone(exec, beta);
row_scale = gko::test::generate_random_array<value_type>(
num_rows * batch_size, std::normal_distribution<>(2.0, 0.5),
rand_engine, ref);
col_scale = gko::test::generate_random_array<value_type>(
num_cols * batch_size, std::normal_distribution<>(4.0, 0.5),
rand_engine, ref);
drow_scale = gko::array<value_type>(exec, row_scale);
dcol_scale = gko::array<value_type>(exec, col_scale);
expected = BMVec::create(
ref,
gko::batch_dim<2>(batch_size, gko::dim<2>{num_rows, num_vecs}));
Expand All @@ -90,6 +98,10 @@ class Ell : public CommonTestFixture {
std::unique_ptr<BMVec> dy;
std::unique_ptr<BMVec> dalpha;
std::unique_ptr<BMVec> dbeta;
gko::array<value_type> row_scale;
gko::array<value_type> col_scale;
gko::array<value_type> drow_scale;
gko::array<value_type> dcol_scale;
};


Expand All @@ -113,3 +125,14 @@ TEST_F(Ell, SingleVectorAdvancedApplyIsEquivalentToRef)

GKO_ASSERT_BATCH_MTX_NEAR(dresult, expected, r<value_type>::value);
}


TEST_F(Ell, TwoSidedScaleIsEquivalentToRef)
{
set_up_apply_data(257);

gko::batch::matrix::two_sided_scale(col_scale, row_scale, mat.get());
gko::batch::matrix::two_sided_scale(dcol_scale, drow_scale, dmat.get());

GKO_ASSERT_BATCH_MTX_NEAR(dmat, mat, r<value_type>::value);
}

0 comments on commit f9a3443

Please sign in to comment.