From 2ac95112ff36532f2d572890c3dcea751a905741 Mon Sep 17 00:00:00 2001 From: Pratik Nayak Date: Tue, 10 Oct 2023 17:15:48 +0200 Subject: [PATCH] Some general fixes --- core/matrix/batch_ell.cpp | 16 +----- core/matrix/batch_struct.hpp | 2 +- cuda/matrix/batch_struct.hpp | 16 +++--- dpcpp/matrix/batch_ell_kernels.dp.cpp | 58 ++++++++++----------- dpcpp/matrix/batch_struct.hpp | 16 +++--- hip/matrix/batch_struct.hip.hpp | 16 +++--- include/ginkgo/core/matrix/batch_ell.hpp | 8 --- reference/matrix/batch_struct.hpp | 16 +++--- reference/test/matrix/batch_ell_kernels.cpp | 50 +++++++++--------- test/test_install/test_install.cpp | 9 +++- 10 files changed, 97 insertions(+), 110 deletions(-) diff --git a/core/matrix/batch_ell.cpp b/core/matrix/batch_ell.cpp index 0d903b10968..d33270edb11 100644 --- a/core/matrix/batch_ell.cpp +++ b/core/matrix/batch_ell.cpp @@ -104,22 +104,10 @@ template std::unique_ptr> Ell::create_with_config_of( ptr_param> other) -{ - // De-referencing `other` before calling the functions (instead of - // using operator `->`) is currently required to be compatible with - // CUDA 10.1. - // Otherwise, it results in a compile error. - return (*other).create_with_same_config(); -} - - -template -std::unique_ptr> -Ell::create_with_same_config() const { return Ell::create( - this->get_executor(), this->get_size(), - this->get_num_stored_elements_per_row()); + other->get_executor(), other->get_size(), + other->get_num_stored_elements_per_row()); } diff --git a/core/matrix/batch_struct.hpp b/core/matrix/batch_struct.hpp index 2eed40882bc..6d170f393b5 100644 --- a/core/matrix/batch_struct.hpp +++ b/core/matrix/batch_struct.hpp @@ -109,7 +109,7 @@ struct batch_item { template struct uniform_batch { using value_type = ValueType; - using index_type = int; + using index_type = int32; using entry_type = batch_item; ValueType* values; diff --git a/cuda/matrix/batch_struct.hpp b/cuda/matrix/batch_struct.hpp index 3feb8ed653a..e11acc02279 100644 --- a/cuda/matrix/batch_struct.hpp +++ b/cuda/matrix/batch_struct.hpp @@ -100,10 +100,10 @@ get_batch_struct(const batch::matrix::Ell* const op) return {as_cuda_type(op->get_const_values()), op->get_const_col_idxs(), op->get_num_batch_items(), - static_cast(op->get_common_size()[0]), - static_cast(op->get_common_size()[0]), - static_cast(op->get_common_size()[1]), - static_cast(op->get_num_stored_elements_per_row())}; + static_cast(op->get_common_size()[0]), + static_cast(op->get_common_size()[0]), + static_cast(op->get_common_size()[1]), + static_cast(op->get_num_stored_elements_per_row())}; } @@ -117,10 +117,10 @@ get_batch_struct(batch::matrix::Ell* const op) return {as_cuda_type(op->get_values()), op->get_col_idxs(), op->get_num_batch_items(), - static_cast(op->get_common_size()[0]), - static_cast(op->get_common_size()[0]), - static_cast(op->get_common_size()[1]), - static_cast(op->get_num_stored_elements_per_row())}; + static_cast(op->get_common_size()[0]), + static_cast(op->get_common_size()[0]), + static_cast(op->get_common_size()[1]), + static_cast(op->get_num_stored_elements_per_row())}; } diff --git a/dpcpp/matrix/batch_ell_kernels.dp.cpp b/dpcpp/matrix/batch_ell_kernels.dp.cpp index 1d1210cc270..70cfa12f9ee 100644 --- a/dpcpp/matrix/batch_ell_kernels.dp.cpp +++ b/dpcpp/matrix/batch_ell_kernels.dp.cpp @@ -98,19 +98,19 @@ void simple_apply(std::shared_ptr exec, } // Launch a kernel that has nbatches blocks, each block has max group size - (exec->get_queue())->submit([&](sycl::handler& cgh) { + exec->get_queue()->submit([&](sycl::handler& cgh) { cgh.parallel_for( - sycl_nd_range(grid, block), - [=](sycl::nd_item<3> item_ct1) - [[sycl::reqd_sub_group_size(config::warp_size)]] { - auto group = item_ct1.get_group(); - auto group_id = group.get_group_linear_id(); - const auto mat_b = - batch::matrix::extract_batch_item(mat_ub, group_id); - const auto b_b = batch::extract_batch_item(b_ub, group_id); - const auto x_b = batch::extract_batch_item(x_ub, group_id); - simple_apply_kernel(mat_b, b_b, x_b, item_ct1); - }); + sycl_nd_range(grid, block), [= + ](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size( + config::warp_size)]] { + auto group = item_ct1.get_group(); + auto group_id = group.get_group_linear_id(); + const auto mat_b = + batch::matrix::extract_batch_item(mat_ub, group_id); + const auto b_b = batch::extract_batch_item(b_ub, group_id); + const auto x_b = batch::extract_batch_item(x_ub, group_id); + simple_apply_kernel(mat_b, b_b, x_b, item_ct1); + }); }); } @@ -145,24 +145,24 @@ void advanced_apply(std::shared_ptr exec, const dim3 grid(num_batch_items); // Launch a kernel that has nbatches blocks, each block has max group size - (exec->get_queue())->submit([&](sycl::handler& cgh) { + exec->get_queue()->submit([&](sycl::handler& cgh) { cgh.parallel_for( - sycl_nd_range(grid, block), - [=](sycl::nd_item<3> item_ct1) - [[sycl::reqd_sub_group_size(config::warp_size)]] { - auto group = item_ct1.get_group(); - auto group_id = group.get_group_linear_id(); - const auto mat_b = - batch::matrix::extract_batch_item(mat_ub, group_id); - const auto b_b = batch::extract_batch_item(b_ub, group_id); - const auto x_b = batch::extract_batch_item(x_ub, group_id); - const auto alpha_b = - batch::extract_batch_item(alpha_ub, group_id); - const auto beta_b = - batch::extract_batch_item(beta_ub, group_id); - advanced_apply_kernel(alpha_b, mat_b, b_b, beta_b, x_b, - item_ct1); - }); + sycl_nd_range(grid, block), [= + ](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size( + config::warp_size)]] { + auto group = item_ct1.get_group(); + auto group_id = group.get_group_linear_id(); + const auto mat_b = + batch::matrix::extract_batch_item(mat_ub, group_id); + const auto b_b = batch::extract_batch_item(b_ub, group_id); + const auto x_b = batch::extract_batch_item(x_ub, group_id); + const auto alpha_b = + batch::extract_batch_item(alpha_ub, group_id); + const auto beta_b = + batch::extract_batch_item(beta_ub, group_id); + advanced_apply_kernel(alpha_b, mat_b, b_b, beta_b, x_b, + item_ct1); + }); }); } diff --git a/dpcpp/matrix/batch_struct.hpp b/dpcpp/matrix/batch_struct.hpp index 667085d354e..62af28bc06b 100644 --- a/dpcpp/matrix/batch_struct.hpp +++ b/dpcpp/matrix/batch_struct.hpp @@ -98,10 +98,10 @@ get_batch_struct(const batch::matrix::Ell* const op) return {op->get_const_values(), op->get_const_col_idxs(), op->get_num_batch_items(), - static_cast(op->get_common_size()[0]), - static_cast(op->get_common_size()[0]), - static_cast(op->get_common_size()[1]), - static_cast(op->get_num_stored_elements_per_row())}; + static_cast(op->get_common_size()[0]), + static_cast(op->get_common_size()[0]), + static_cast(op->get_common_size()[1]), + static_cast(op->get_num_stored_elements_per_row())}; } @@ -115,10 +115,10 @@ inline batch::matrix::batch_ell::uniform_batch get_batch_struct( return {op->get_values(), op->get_col_idxs(), op->get_num_batch_items(), - static_cast(op->get_common_size()[0]), - static_cast(op->get_common_size()[0]), - static_cast(op->get_common_size()[1]), - static_cast(op->get_num_stored_elements_per_row())}; + static_cast(op->get_common_size()[0]), + static_cast(op->get_common_size()[0]), + static_cast(op->get_common_size()[1]), + static_cast(op->get_num_stored_elements_per_row())}; } diff --git a/hip/matrix/batch_struct.hip.hpp b/hip/matrix/batch_struct.hip.hpp index db3f6e70182..b9f81a67ea5 100644 --- a/hip/matrix/batch_struct.hip.hpp +++ b/hip/matrix/batch_struct.hip.hpp @@ -100,10 +100,10 @@ get_batch_struct(const batch::matrix::Ell* const op) return {as_hip_type(op->get_const_values()), op->get_const_col_idxs(), op->get_num_batch_items(), - static_cast(op->get_common_size()[0]), - static_cast(op->get_common_size()[0]), - static_cast(op->get_common_size()[1]), - static_cast(op->get_num_stored_elements_per_row())}; + static_cast(op->get_common_size()[0]), + static_cast(op->get_common_size()[0]), + static_cast(op->get_common_size()[1]), + static_cast(op->get_num_stored_elements_per_row())}; } @@ -117,10 +117,10 @@ get_batch_struct(batch::matrix::Ell* const op) return {as_hip_type(op->get_values()), op->get_col_idxs(), op->get_num_batch_items(), - static_cast(op->get_common_size()[0]), - static_cast(op->get_common_size()[0]), - static_cast(op->get_common_size()[1]), - static_cast(op->get_num_stored_elements_per_row())}; + static_cast(op->get_common_size()[0]), + static_cast(op->get_common_size()[0]), + static_cast(op->get_common_size()[1]), + static_cast(op->get_num_stored_elements_per_row())}; } diff --git a/include/ginkgo/core/matrix/batch_ell.hpp b/include/ginkgo/core/matrix/batch_ell.hpp index 5cb5f73dec5..6f3db1bb96b 100644 --- a/include/ginkgo/core/matrix/batch_ell.hpp +++ b/include/ginkgo/core/matrix/batch_ell.hpp @@ -356,14 +356,6 @@ class Ell final col_idxs_.get_num_elems()); } - /** - * Creates a Ell matrix with the same configuration as the callers - * matrix. - * - * @returns a Ell matrix with the same configuration as the caller. - */ - std::unique_ptr create_with_same_config() const; - void apply_impl(const MultiVector* b, MultiVector* x) const; diff --git a/reference/matrix/batch_struct.hpp b/reference/matrix/batch_struct.hpp index b17f2fda103..572bd9e8fef 100644 --- a/reference/matrix/batch_struct.hpp +++ b/reference/matrix/batch_struct.hpp @@ -103,10 +103,10 @@ get_batch_struct(const batch::matrix::Ell* const op) return {op->get_const_values(), op->get_const_col_idxs(), op->get_num_batch_items(), - static_cast(op->get_common_size()[0]), - static_cast(op->get_common_size()[0]), - static_cast(op->get_common_size()[1]), - static_cast(op->get_num_stored_elements_per_row())}; + static_cast(op->get_common_size()[0]), + static_cast(op->get_common_size()[0]), + static_cast(op->get_common_size()[1]), + static_cast(op->get_num_stored_elements_per_row())}; } @@ -120,10 +120,10 @@ inline batch::matrix::batch_ell::uniform_batch get_batch_struct( return {op->get_values(), op->get_col_idxs(), op->get_num_batch_items(), - static_cast(op->get_common_size()[0]), - static_cast(op->get_common_size()[0]), - static_cast(op->get_common_size()[1]), - static_cast(op->get_num_stored_elements_per_row())}; + static_cast(op->get_common_size()[0]), + static_cast(op->get_common_size()[0]), + static_cast(op->get_common_size()[1]), + static_cast(op->get_num_stored_elements_per_row())}; } diff --git a/reference/test/matrix/batch_ell_kernels.cpp b/reference/test/matrix/batch_ell_kernels.cpp index 76b681c69f7..c3ff33ad7b1 100644 --- a/reference/test/matrix/batch_ell_kernels.cpp +++ b/reference/test/matrix/batch_ell_kernels.cpp @@ -59,7 +59,7 @@ class Ell : public ::testing::Test { using value_type = T; using size_type = gko::size_type; using Mtx = gko::batch::matrix::Ell; - using MVec = gko::batch::MultiVector; + using BMVec = gko::batch::MultiVector; using EllMtx = gko::matrix::Ell; using DenseMtx = gko::matrix::Dense; using ComplexMtx = gko::to_complex; @@ -74,7 +74,7 @@ class Ell : public ::testing::Test { {I({1.0, -1.0, 1.5}), I({-2.0, 2.0, 3.0})}, exec)), mtx_01(gko::initialize( {I({1.0, -2.0, -0.5}), I({1.0, -2.5, 4.0})}, exec)), - b_0(gko::batch::initialize( + b_0(gko::batch::initialize( {{I({1.0, 0.0, 1.0}), I({2.0, 0.0, 1.0}), I({1.0, 0.0, 2.0})}, {I({-1.0, 1.0, 1.0}), I({1.0, -1.0, 1.0}), @@ -88,7 +88,7 @@ class Ell : public ::testing::Test { {I({-1.0, 1.0, 1.0}), I({1.0, -1.0, 1.0}), I({1.0, 0.0, 2.0})}, exec)), - x_0(gko::batch::initialize( + x_0(gko::batch::initialize( {{I({2.0, 0.0, 1.0}), I({2.0, 0.0, 2.0})}, {I({-2.0, 1.0, 1.0}), I({1.0, -1.0, -1.0})}}, exec)), @@ -102,10 +102,10 @@ class Ell : public ::testing::Test { std::unique_ptr mtx_0; std::unique_ptr mtx_00; std::unique_ptr mtx_01; - std::unique_ptr b_0; + std::unique_ptr b_0; std::unique_ptr b_00; std::unique_ptr b_01; - std::unique_ptr x_0; + std::unique_ptr x_0; std::unique_ptr x_00; std::unique_ptr x_01; @@ -134,11 +134,11 @@ TYPED_TEST(Ell, AppliesToBatchMultiVector) TYPED_TEST(Ell, AppliesLinearCombinationWithSameAlphaToBatchMultiVector) { using Mtx = typename TestFixture::Mtx; - using MVec = typename TestFixture::MVec; + using BMVec = typename TestFixture::BMVec; using DenseMtx = typename TestFixture::DenseMtx; using T = typename TestFixture::value_type; - auto alpha = gko::batch::initialize(2, {1.5}, this->exec); - auto beta = gko::batch::initialize(2, {-4.0}, this->exec); + auto alpha = gko::batch::initialize(2, {1.5}, this->exec); + auto beta = gko::batch::initialize(2, {-4.0}, this->exec); auto alpha0 = gko::initialize({1.5}, this->exec); auto alpha1 = gko::initialize({1.5}, this->exec); auto beta0 = gko::initialize({-4.0}, this->exec); @@ -161,11 +161,11 @@ TYPED_TEST(Ell, AppliesLinearCombinationWithSameAlphaToBatchMultiVector) TYPED_TEST(Ell, AppliesLinearCombinationToBatchMultiVector) { using Mtx = typename TestFixture::Mtx; - using MVec = typename TestFixture::MVec; + using BMVec = typename TestFixture::BMVec; using DenseMtx = typename TestFixture::DenseMtx; using T = typename TestFixture::value_type; - auto alpha = gko::batch::initialize({{1.5}, {-1.0}}, this->exec); - auto beta = gko::batch::initialize({{2.5}, {-4.0}}, this->exec); + auto alpha = gko::batch::initialize({{1.5}, {-1.0}}, this->exec); + auto beta = gko::batch::initialize({{2.5}, {-4.0}}, this->exec); auto alpha0 = gko::initialize({1.5}, this->exec); auto alpha1 = gko::initialize({-1.0}, this->exec); auto beta0 = gko::initialize({2.5}, this->exec); @@ -187,8 +187,8 @@ TYPED_TEST(Ell, AppliesLinearCombinationToBatchMultiVector) TYPED_TEST(Ell, ApplyFailsOnWrongNumberOfResultCols) { - using MVec = typename TestFixture::MVec; - auto res = MVec::create(this->exec, gko::batch_dim<2>{2, gko::dim<2>{2}}); + using BMVec = typename TestFixture::BMVec; + auto res = BMVec::create(this->exec, gko::batch_dim<2>{2, gko::dim<2>{2}}); ASSERT_THROW(this->mtx_0->apply(this->b_0.get(), res.get()), gko::DimensionMismatch); @@ -197,8 +197,8 @@ TYPED_TEST(Ell, ApplyFailsOnWrongNumberOfResultCols) TYPED_TEST(Ell, ApplyFailsOnWrongNumberOfResultRows) { - using MVec = typename TestFixture::MVec; - auto res = MVec::create(this->exec, gko::batch_dim<2>{2, gko::dim<2>{3}}); + using BMVec = typename TestFixture::BMVec; + auto res = BMVec::create(this->exec, gko::batch_dim<2>{2, gko::dim<2>{3}}); ASSERT_THROW(this->mtx_0->apply(this->b_0.get(), res.get()), gko::DimensionMismatch); @@ -207,9 +207,9 @@ TYPED_TEST(Ell, ApplyFailsOnWrongNumberOfResultRows) TYPED_TEST(Ell, ApplyFailsOnWrongInnerDimension) { - using MVec = typename TestFixture::MVec; + using BMVec = typename TestFixture::BMVec; auto res = - MVec::create(this->exec, gko::batch_dim<2>{2, gko::dim<2>{2, 3}}); + BMVec::create(this->exec, gko::batch_dim<2>{2, gko::dim<2>{2, 3}}); ASSERT_THROW(this->mtx_0->apply(res.get(), this->x_0.get()), gko::DimensionMismatch); @@ -218,13 +218,13 @@ TYPED_TEST(Ell, ApplyFailsOnWrongInnerDimension) TYPED_TEST(Ell, AdvancedApplyFailsOnWrongInnerDimension) { - using MVec = typename TestFixture::MVec; + using BMVec = typename TestFixture::BMVec; auto res = - MVec::create(this->exec, gko::batch_dim<2>{2, gko::dim<2>{2, 3}}); + BMVec::create(this->exec, gko::batch_dim<2>{2, gko::dim<2>{2, 3}}); auto alpha = - MVec::create(this->exec, gko::batch_dim<2>{2, gko::dim<2>{1, 1}}); + BMVec::create(this->exec, gko::batch_dim<2>{2, gko::dim<2>{1, 1}}); auto beta = - MVec::create(this->exec, gko::batch_dim<2>{2, gko::dim<2>{1, 1}}); + BMVec::create(this->exec, gko::batch_dim<2>{2, gko::dim<2>{1, 1}}); ASSERT_THROW( this->mtx_0->apply(alpha.get(), res.get(), beta.get(), this->x_0.get()), @@ -234,13 +234,13 @@ TYPED_TEST(Ell, AdvancedApplyFailsOnWrongInnerDimension) TYPED_TEST(Ell, AdvancedApplyFailsOnWrongAlphaDimension) { - using MVec = typename TestFixture::MVec; + using BMVec = typename TestFixture::BMVec; auto res = - MVec::create(this->exec, gko::batch_dim<2>{2, gko::dim<2>{3, 3}}); + BMVec::create(this->exec, gko::batch_dim<2>{2, gko::dim<2>{3, 3}}); auto alpha = - MVec::create(this->exec, gko::batch_dim<2>{2, gko::dim<2>{2, 1}}); + BMVec::create(this->exec, gko::batch_dim<2>{2, gko::dim<2>{2, 1}}); auto beta = - MVec::create(this->exec, gko::batch_dim<2>{2, gko::dim<2>{1, 1}}); + BMVec::create(this->exec, gko::batch_dim<2>{2, gko::dim<2>{1, 1}}); ASSERT_THROW( this->mtx_0->apply(alpha.get(), res.get(), beta.get(), this->x_0.get()), diff --git a/test/test_install/test_install.cpp b/test/test_install/test_install.cpp index 325773f0b75..565dd36ddc5 100644 --- a/test/test_install/test_install.cpp +++ b/test/test_install/test_install.cpp @@ -219,13 +219,20 @@ int main() auto test = batch_multi_vector_type::create(exec); } - // core/base/batch_dense.hpp + // core/matrix/batch_dense.hpp { using type1 = float; using batch_dense_type = gko::batch::Dense; auto test = batch_dense_type::create(exec); } + // core/matrix/batch_ell.hpp + { + using type1 = float; + using batch_ell_type = gko::batch::Ell; + auto test = batch_ell_type::create(exec); + } + // core/base/combination.hpp { using type1 = int;