Skip to content

Commit

Permalink
Some general fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
pratikvn committed Oct 10, 2023
1 parent d8bd424 commit 2ac9511
Show file tree
Hide file tree
Showing 10 changed files with 97 additions and 110 deletions.
16 changes: 2 additions & 14 deletions core/matrix/batch_ell.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -104,22 +104,10 @@ template <typename ValueType, typename IndexType>
std::unique_ptr<Ell<ValueType, IndexType>>
Ell<ValueType, IndexType>::create_with_config_of(
ptr_param<const Ell<ValueType, IndexType>> other)
{
// De-referencing `other` before calling the functions (instead of
// using operator `->`) is currently required to be compatible with
// CUDA 10.1.
// Otherwise, it results in a compile error.
return (*other).create_with_same_config();
}


template <typename ValueType, typename IndexType>
std::unique_ptr<Ell<ValueType, IndexType>>
Ell<ValueType, IndexType>::create_with_same_config() const
{
return Ell<ValueType, IndexType>::create(
this->get_executor(), this->get_size(),
this->get_num_stored_elements_per_row());
other->get_executor(), other->get_size(),
other->get_num_stored_elements_per_row());
}


Expand Down
2 changes: 1 addition & 1 deletion core/matrix/batch_struct.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ struct batch_item {
template <typename ValueType>
struct uniform_batch {
using value_type = ValueType;
using index_type = int;
using index_type = int32;
using entry_type = batch_item<value_type>;

ValueType* values;
Expand Down
16 changes: 8 additions & 8 deletions cuda/matrix/batch_struct.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,10 +100,10 @@ get_batch_struct(const batch::matrix::Ell<ValueType, int32>* const op)
return {as_cuda_type(op->get_const_values()),
op->get_const_col_idxs(),
op->get_num_batch_items(),
static_cast<int>(op->get_common_size()[0]),
static_cast<int>(op->get_common_size()[0]),
static_cast<int>(op->get_common_size()[1]),
static_cast<int>(op->get_num_stored_elements_per_row())};
static_cast<int32>(op->get_common_size()[0]),
static_cast<int32>(op->get_common_size()[0]),
static_cast<int32>(op->get_common_size()[1]),
static_cast<int32>(op->get_num_stored_elements_per_row())};
}


Expand All @@ -117,10 +117,10 @@ get_batch_struct(batch::matrix::Ell<ValueType, int32>* const op)
return {as_cuda_type(op->get_values()),
op->get_col_idxs(),
op->get_num_batch_items(),
static_cast<int>(op->get_common_size()[0]),
static_cast<int>(op->get_common_size()[0]),
static_cast<int>(op->get_common_size()[1]),
static_cast<int>(op->get_num_stored_elements_per_row())};
static_cast<int32>(op->get_common_size()[0]),
static_cast<int32>(op->get_common_size()[0]),
static_cast<int32>(op->get_common_size()[1]),
static_cast<int32>(op->get_num_stored_elements_per_row())};
}


Expand Down
58 changes: 29 additions & 29 deletions dpcpp/matrix/batch_ell_kernels.dp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -98,19 +98,19 @@ void simple_apply(std::shared_ptr<const DefaultExecutor> exec,
}

// Launch a kernel that has nbatches blocks, each block has max group size
(exec->get_queue())->submit([&](sycl::handler& cgh) {
exec->get_queue()->submit([&](sycl::handler& cgh) {
cgh.parallel_for(
sycl_nd_range(grid, block),
[=](sycl::nd_item<3> item_ct1)
[[sycl::reqd_sub_group_size(config::warp_size)]] {
auto group = item_ct1.get_group();
auto group_id = group.get_group_linear_id();
const auto mat_b =
batch::matrix::extract_batch_item(mat_ub, group_id);
const auto b_b = batch::extract_batch_item(b_ub, group_id);
const auto x_b = batch::extract_batch_item(x_ub, group_id);
simple_apply_kernel(mat_b, b_b, x_b, item_ct1);
});
sycl_nd_range(grid, block), [=
](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(
config::warp_size)]] {
auto group = item_ct1.get_group();
auto group_id = group.get_group_linear_id();
const auto mat_b =
batch::matrix::extract_batch_item(mat_ub, group_id);
const auto b_b = batch::extract_batch_item(b_ub, group_id);
const auto x_b = batch::extract_batch_item(x_ub, group_id);
simple_apply_kernel(mat_b, b_b, x_b, item_ct1);
});
});
}

Expand Down Expand Up @@ -145,24 +145,24 @@ void advanced_apply(std::shared_ptr<const DefaultExecutor> exec,
const dim3 grid(num_batch_items);

// Launch a kernel that has nbatches blocks, each block has max group size
(exec->get_queue())->submit([&](sycl::handler& cgh) {
exec->get_queue()->submit([&](sycl::handler& cgh) {
cgh.parallel_for(
sycl_nd_range(grid, block),
[=](sycl::nd_item<3> item_ct1)
[[sycl::reqd_sub_group_size(config::warp_size)]] {
auto group = item_ct1.get_group();
auto group_id = group.get_group_linear_id();
const auto mat_b =
batch::matrix::extract_batch_item(mat_ub, group_id);
const auto b_b = batch::extract_batch_item(b_ub, group_id);
const auto x_b = batch::extract_batch_item(x_ub, group_id);
const auto alpha_b =
batch::extract_batch_item(alpha_ub, group_id);
const auto beta_b =
batch::extract_batch_item(beta_ub, group_id);
advanced_apply_kernel(alpha_b, mat_b, b_b, beta_b, x_b,
item_ct1);
});
sycl_nd_range(grid, block), [=
](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(
config::warp_size)]] {
auto group = item_ct1.get_group();
auto group_id = group.get_group_linear_id();
const auto mat_b =
batch::matrix::extract_batch_item(mat_ub, group_id);
const auto b_b = batch::extract_batch_item(b_ub, group_id);
const auto x_b = batch::extract_batch_item(x_ub, group_id);
const auto alpha_b =
batch::extract_batch_item(alpha_ub, group_id);
const auto beta_b =
batch::extract_batch_item(beta_ub, group_id);
advanced_apply_kernel(alpha_b, mat_b, b_b, beta_b, x_b,
item_ct1);
});
});
}

Expand Down
16 changes: 8 additions & 8 deletions dpcpp/matrix/batch_struct.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -98,10 +98,10 @@ get_batch_struct(const batch::matrix::Ell<ValueType, int32>* const op)
return {op->get_const_values(),
op->get_const_col_idxs(),
op->get_num_batch_items(),
static_cast<int>(op->get_common_size()[0]),
static_cast<int>(op->get_common_size()[0]),
static_cast<int>(op->get_common_size()[1]),
static_cast<int>(op->get_num_stored_elements_per_row())};
static_cast<int32>(op->get_common_size()[0]),
static_cast<int32>(op->get_common_size()[0]),
static_cast<int32>(op->get_common_size()[1]),
static_cast<int32>(op->get_num_stored_elements_per_row())};
}


Expand All @@ -115,10 +115,10 @@ inline batch::matrix::batch_ell::uniform_batch<ValueType> get_batch_struct(
return {op->get_values(),
op->get_col_idxs(),
op->get_num_batch_items(),
static_cast<int>(op->get_common_size()[0]),
static_cast<int>(op->get_common_size()[0]),
static_cast<int>(op->get_common_size()[1]),
static_cast<int>(op->get_num_stored_elements_per_row())};
static_cast<int32>(op->get_common_size()[0]),
static_cast<int32>(op->get_common_size()[0]),
static_cast<int32>(op->get_common_size()[1]),
static_cast<int32>(op->get_num_stored_elements_per_row())};
}


Expand Down
16 changes: 8 additions & 8 deletions hip/matrix/batch_struct.hip.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,10 +100,10 @@ get_batch_struct(const batch::matrix::Ell<ValueType, int32>* const op)
return {as_hip_type(op->get_const_values()),
op->get_const_col_idxs(),
op->get_num_batch_items(),
static_cast<int>(op->get_common_size()[0]),
static_cast<int>(op->get_common_size()[0]),
static_cast<int>(op->get_common_size()[1]),
static_cast<int>(op->get_num_stored_elements_per_row())};
static_cast<int32>(op->get_common_size()[0]),
static_cast<int32>(op->get_common_size()[0]),
static_cast<int32>(op->get_common_size()[1]),
static_cast<int32>(op->get_num_stored_elements_per_row())};
}


Expand All @@ -117,10 +117,10 @@ get_batch_struct(batch::matrix::Ell<ValueType, int32>* const op)
return {as_hip_type(op->get_values()),
op->get_col_idxs(),
op->get_num_batch_items(),
static_cast<int>(op->get_common_size()[0]),
static_cast<int>(op->get_common_size()[0]),
static_cast<int>(op->get_common_size()[1]),
static_cast<int>(op->get_num_stored_elements_per_row())};
static_cast<int32>(op->get_common_size()[0]),
static_cast<int32>(op->get_common_size()[0]),
static_cast<int32>(op->get_common_size()[1]),
static_cast<int32>(op->get_num_stored_elements_per_row())};
}


Expand Down
8 changes: 0 additions & 8 deletions include/ginkgo/core/matrix/batch_ell.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -356,14 +356,6 @@ class Ell final
col_idxs_.get_num_elems());
}

/**
* Creates an Ell matrix with the same configuration as the caller's
* matrix.
*
* @returns an Ell matrix with the same configuration as the caller.
*/
std::unique_ptr<Ell> create_with_same_config() const;

void apply_impl(const MultiVector<value_type>* b,
MultiVector<value_type>* x) const;

Expand Down
16 changes: 8 additions & 8 deletions reference/matrix/batch_struct.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,10 +103,10 @@ get_batch_struct(const batch::matrix::Ell<ValueType, int32>* const op)
return {op->get_const_values(),
op->get_const_col_idxs(),
op->get_num_batch_items(),
static_cast<int>(op->get_common_size()[0]),
static_cast<int>(op->get_common_size()[0]),
static_cast<int>(op->get_common_size()[1]),
static_cast<int>(op->get_num_stored_elements_per_row())};
static_cast<int32>(op->get_common_size()[0]),
static_cast<int32>(op->get_common_size()[0]),
static_cast<int32>(op->get_common_size()[1]),
static_cast<int32>(op->get_num_stored_elements_per_row())};
}


Expand All @@ -120,10 +120,10 @@ inline batch::matrix::batch_ell::uniform_batch<ValueType> get_batch_struct(
return {op->get_values(),
op->get_col_idxs(),
op->get_num_batch_items(),
static_cast<int>(op->get_common_size()[0]),
static_cast<int>(op->get_common_size()[0]),
static_cast<int>(op->get_common_size()[1]),
static_cast<int>(op->get_num_stored_elements_per_row())};
static_cast<int32>(op->get_common_size()[0]),
static_cast<int32>(op->get_common_size()[0]),
static_cast<int32>(op->get_common_size()[1]),
static_cast<int32>(op->get_num_stored_elements_per_row())};
}


Expand Down
50 changes: 25 additions & 25 deletions reference/test/matrix/batch_ell_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ class Ell : public ::testing::Test {
using value_type = T;
using size_type = gko::size_type;
using Mtx = gko::batch::matrix::Ell<value_type>;
using MVec = gko::batch::MultiVector<value_type>;
using BMVec = gko::batch::MultiVector<value_type>;
using EllMtx = gko::matrix::Ell<value_type>;
using DenseMtx = gko::matrix::Dense<value_type>;
using ComplexMtx = gko::to_complex<Mtx>;
Expand All @@ -74,7 +74,7 @@ class Ell : public ::testing::Test {
{I<T>({1.0, -1.0, 1.5}), I<T>({-2.0, 2.0, 3.0})}, exec)),
mtx_01(gko::initialize<EllMtx>(
{I<T>({1.0, -2.0, -0.5}), I<T>({1.0, -2.5, 4.0})}, exec)),
b_0(gko::batch::initialize<MVec>(
b_0(gko::batch::initialize<BMVec>(
{{I<T>({1.0, 0.0, 1.0}), I<T>({2.0, 0.0, 1.0}),
I<T>({1.0, 0.0, 2.0})},
{I<T>({-1.0, 1.0, 1.0}), I<T>({1.0, -1.0, 1.0}),
Expand All @@ -88,7 +88,7 @@ class Ell : public ::testing::Test {
{I<T>({-1.0, 1.0, 1.0}), I<T>({1.0, -1.0, 1.0}),
I<T>({1.0, 0.0, 2.0})},
exec)),
x_0(gko::batch::initialize<MVec>(
x_0(gko::batch::initialize<BMVec>(
{{I<T>({2.0, 0.0, 1.0}), I<T>({2.0, 0.0, 2.0})},
{I<T>({-2.0, 1.0, 1.0}), I<T>({1.0, -1.0, -1.0})}},
exec)),
Expand All @@ -102,10 +102,10 @@ class Ell : public ::testing::Test {
std::unique_ptr<Mtx> mtx_0;
std::unique_ptr<EllMtx> mtx_00;
std::unique_ptr<EllMtx> mtx_01;
std::unique_ptr<MVec> b_0;
std::unique_ptr<BMVec> b_0;
std::unique_ptr<DenseMtx> b_00;
std::unique_ptr<DenseMtx> b_01;
std::unique_ptr<MVec> x_0;
std::unique_ptr<BMVec> x_0;
std::unique_ptr<DenseMtx> x_00;
std::unique_ptr<DenseMtx> x_01;

Expand Down Expand Up @@ -134,11 +134,11 @@ TYPED_TEST(Ell, AppliesToBatchMultiVector)
TYPED_TEST(Ell, AppliesLinearCombinationWithSameAlphaToBatchMultiVector)
{
using Mtx = typename TestFixture::Mtx;
using MVec = typename TestFixture::MVec;
using BMVec = typename TestFixture::BMVec;
using DenseMtx = typename TestFixture::DenseMtx;
using T = typename TestFixture::value_type;
auto alpha = gko::batch::initialize<MVec>(2, {1.5}, this->exec);
auto beta = gko::batch::initialize<MVec>(2, {-4.0}, this->exec);
auto alpha = gko::batch::initialize<BMVec>(2, {1.5}, this->exec);
auto beta = gko::batch::initialize<BMVec>(2, {-4.0}, this->exec);
auto alpha0 = gko::initialize<DenseMtx>({1.5}, this->exec);
auto alpha1 = gko::initialize<DenseMtx>({1.5}, this->exec);
auto beta0 = gko::initialize<DenseMtx>({-4.0}, this->exec);
Expand All @@ -161,11 +161,11 @@ TYPED_TEST(Ell, AppliesLinearCombinationWithSameAlphaToBatchMultiVector)
TYPED_TEST(Ell, AppliesLinearCombinationToBatchMultiVector)
{
using Mtx = typename TestFixture::Mtx;
using MVec = typename TestFixture::MVec;
using BMVec = typename TestFixture::BMVec;
using DenseMtx = typename TestFixture::DenseMtx;
using T = typename TestFixture::value_type;
auto alpha = gko::batch::initialize<MVec>({{1.5}, {-1.0}}, this->exec);
auto beta = gko::batch::initialize<MVec>({{2.5}, {-4.0}}, this->exec);
auto alpha = gko::batch::initialize<BMVec>({{1.5}, {-1.0}}, this->exec);
auto beta = gko::batch::initialize<BMVec>({{2.5}, {-4.0}}, this->exec);
auto alpha0 = gko::initialize<DenseMtx>({1.5}, this->exec);
auto alpha1 = gko::initialize<DenseMtx>({-1.0}, this->exec);
auto beta0 = gko::initialize<DenseMtx>({2.5}, this->exec);
Expand All @@ -187,8 +187,8 @@ TYPED_TEST(Ell, AppliesLinearCombinationToBatchMultiVector)

TYPED_TEST(Ell, ApplyFailsOnWrongNumberOfResultCols)
{
using MVec = typename TestFixture::MVec;
auto res = MVec::create(this->exec, gko::batch_dim<2>{2, gko::dim<2>{2}});
using BMVec = typename TestFixture::BMVec;
auto res = BMVec::create(this->exec, gko::batch_dim<2>{2, gko::dim<2>{2}});

ASSERT_THROW(this->mtx_0->apply(this->b_0.get(), res.get()),
gko::DimensionMismatch);
Expand All @@ -197,8 +197,8 @@ TYPED_TEST(Ell, ApplyFailsOnWrongNumberOfResultCols)

TYPED_TEST(Ell, ApplyFailsOnWrongNumberOfResultRows)
{
using MVec = typename TestFixture::MVec;
auto res = MVec::create(this->exec, gko::batch_dim<2>{2, gko::dim<2>{3}});
using BMVec = typename TestFixture::BMVec;
auto res = BMVec::create(this->exec, gko::batch_dim<2>{2, gko::dim<2>{3}});

ASSERT_THROW(this->mtx_0->apply(this->b_0.get(), res.get()),
gko::DimensionMismatch);
Expand All @@ -207,9 +207,9 @@ TYPED_TEST(Ell, ApplyFailsOnWrongNumberOfResultRows)

TYPED_TEST(Ell, ApplyFailsOnWrongInnerDimension)
{
using MVec = typename TestFixture::MVec;
using BMVec = typename TestFixture::BMVec;
auto res =
MVec::create(this->exec, gko::batch_dim<2>{2, gko::dim<2>{2, 3}});
BMVec::create(this->exec, gko::batch_dim<2>{2, gko::dim<2>{2, 3}});

ASSERT_THROW(this->mtx_0->apply(res.get(), this->x_0.get()),
gko::DimensionMismatch);
Expand All @@ -218,13 +218,13 @@ TYPED_TEST(Ell, ApplyFailsOnWrongInnerDimension)

TYPED_TEST(Ell, AdvancedApplyFailsOnWrongInnerDimension)
{
using MVec = typename TestFixture::MVec;
using BMVec = typename TestFixture::BMVec;
auto res =
MVec::create(this->exec, gko::batch_dim<2>{2, gko::dim<2>{2, 3}});
BMVec::create(this->exec, gko::batch_dim<2>{2, gko::dim<2>{2, 3}});
auto alpha =
MVec::create(this->exec, gko::batch_dim<2>{2, gko::dim<2>{1, 1}});
BMVec::create(this->exec, gko::batch_dim<2>{2, gko::dim<2>{1, 1}});
auto beta =
MVec::create(this->exec, gko::batch_dim<2>{2, gko::dim<2>{1, 1}});
BMVec::create(this->exec, gko::batch_dim<2>{2, gko::dim<2>{1, 1}});

ASSERT_THROW(
this->mtx_0->apply(alpha.get(), res.get(), beta.get(), this->x_0.get()),
Expand All @@ -234,13 +234,13 @@ TYPED_TEST(Ell, AdvancedApplyFailsOnWrongInnerDimension)

TYPED_TEST(Ell, AdvancedApplyFailsOnWrongAlphaDimension)
{
using MVec = typename TestFixture::MVec;
using BMVec = typename TestFixture::BMVec;
auto res =
MVec::create(this->exec, gko::batch_dim<2>{2, gko::dim<2>{3, 3}});
BMVec::create(this->exec, gko::batch_dim<2>{2, gko::dim<2>{3, 3}});
auto alpha =
MVec::create(this->exec, gko::batch_dim<2>{2, gko::dim<2>{2, 1}});
BMVec::create(this->exec, gko::batch_dim<2>{2, gko::dim<2>{2, 1}});
auto beta =
MVec::create(this->exec, gko::batch_dim<2>{2, gko::dim<2>{1, 1}});
BMVec::create(this->exec, gko::batch_dim<2>{2, gko::dim<2>{1, 1}});

ASSERT_THROW(
this->mtx_0->apply(alpha.get(), res.get(), beta.get(), this->x_0.get()),
Expand Down
Loading

0 comments on commit 2ac9511

Please sign in to comment.