diff --git a/common/cuda_hip/matrix/batch_ell_kernel_launcher.hpp.inc b/common/cuda_hip/matrix/batch_ell_kernel_launcher.hpp.inc index 50be43172b8..966f49e8638 100644 --- a/common/cuda_hip/matrix/batch_ell_kernel_launcher.hpp.inc +++ b/common/cuda_hip/matrix/batch_ell_kernel_launcher.hpp.inc @@ -48,3 +48,14 @@ void advanced_apply(std::shared_ptr exec, GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INT32_TYPE( GKO_DECLARE_BATCH_ELL_ADVANCED_APPLY_KERNEL); + + +template +void add_scaled_identity(std::shared_ptr exec, + const batch::MultiVector* alpha, + const batch::MultiVector* beta, + batch::matrix::Ell* mat) + GKO_NOT_IMPLEMENTED; + +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INT32_TYPE( + GKO_DECLARE_BATCH_ELL_ADD_SCALED_IDENTITY_KERNEL); diff --git a/core/device_hooks/common_kernels.inc.cpp b/core/device_hooks/common_kernels.inc.cpp index 6d941b9c947..6aaedad2e14 100644 --- a/core/device_hooks/common_kernels.inc.cpp +++ b/core/device_hooks/common_kernels.inc.cpp @@ -314,6 +314,7 @@ namespace batch_ell { GKO_STUB_VALUE_AND_INT32_TYPE(GKO_DECLARE_BATCH_ELL_SIMPLE_APPLY_KERNEL); GKO_STUB_VALUE_AND_INT32_TYPE(GKO_DECLARE_BATCH_ELL_ADVANCED_APPLY_KERNEL); +GKO_STUB_VALUE_AND_INT32_TYPE(GKO_DECLARE_BATCH_ELL_ADD_SCALED_IDENTITY_KERNEL); } // namespace batch_ell diff --git a/core/matrix/batch_ell.cpp b/core/matrix/batch_ell.cpp index d52d3b5297a..6e152c7dc7e 100644 --- a/core/matrix/batch_ell.cpp +++ b/core/matrix/batch_ell.cpp @@ -30,6 +30,7 @@ namespace { GKO_REGISTER_OPERATION(simple_apply, batch_ell::simple_apply); GKO_REGISTER_OPERATION(advanced_apply, batch_ell::advanced_apply); +GKO_REGISTER_OPERATION(add_scaled_identity, batch_ell::add_scaled_identity); } // namespace @@ -182,6 +183,20 @@ void Ell::apply_impl(const MultiVector* alpha, } +template +void Ell::add_scaled_identity( + ptr_param> alpha, + ptr_param> beta) +{ + GKO_ASSERT_BATCH_EQUAL_NUM_ITEMS(alpha, beta); + GKO_ASSERT_BATCH_EQUAL_NUM_ITEMS(this, beta); + auto exec = this->get_executor(); + exec->run(ell::make_add_scaled_identity( + make_temporary_clone(exec, alpha).get(), + make_temporary_clone(exec, beta).get(), this)); +} + + template void Ell::convert_to( Ell, IndexType>* result) const diff --git a/core/matrix/batch_ell_kernels.hpp b/core/matrix/batch_ell_kernels.hpp index 3c41a80e951..6f32a3f4a55 100644 --- a/core/matrix/batch_ell_kernels.hpp +++ b/core/matrix/batch_ell_kernels.hpp @@ -35,11 +35,19 @@ namespace kernels { const batch::MultiVector<_vtype>* beta, \ batch::MultiVector<_vtype>* c) -#define GKO_DECLARE_ALL_AS_TEMPLATES \ - template \ - GKO_DECLARE_BATCH_ELL_SIMPLE_APPLY_KERNEL(ValueType, IndexType); \ - template \ - GKO_DECLARE_BATCH_ELL_ADVANCED_APPLY_KERNEL(ValueType, IndexType) +#define GKO_DECLARE_BATCH_ELL_ADD_SCALED_IDENTITY_KERNEL(_vtype, _itype) \ + void add_scaled_identity(std::shared_ptr exec, \ + const batch::MultiVector<_vtype>* alpha, \ + const batch::MultiVector<_vtype>* beta, \ + batch::matrix::Ell<_vtype, _itype>* mat) + +#define GKO_DECLARE_ALL_AS_TEMPLATES \ + template \ + GKO_DECLARE_BATCH_ELL_SIMPLE_APPLY_KERNEL(ValueType, IndexType); \ + template \ + GKO_DECLARE_BATCH_ELL_ADVANCED_APPLY_KERNEL(ValueType, IndexType); \ + template \ + GKO_DECLARE_BATCH_ELL_ADD_SCALED_IDENTITY_KERNEL(ValueType, IndexType) GKO_DECLARE_FOR_ALL_EXECUTOR_NAMESPACES(batch_ell, diff --git a/dpcpp/matrix/batch_ell_kernels.dp.cpp b/dpcpp/matrix/batch_ell_kernels.dp.cpp index d52a6d4d627..8d2e6bc826d 100644 --- a/dpcpp/matrix/batch_ell_kernels.dp.cpp +++ b/dpcpp/matrix/batch_ell_kernels.dp.cpp @@ -142,6 +142,17 @@ GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INT32_TYPE( GKO_DECLARE_BATCH_ELL_ADVANCED_APPLY_KERNEL); +template +void add_scaled_identity(std::shared_ptr exec, + const batch::MultiVector* alpha, + const batch::MultiVector* beta, + batch::matrix::Ell* mat) + GKO_NOT_IMPLEMENTED; + +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INT32_TYPE( + GKO_DECLARE_BATCH_ELL_ADD_SCALED_IDENTITY_KERNEL); + + } // namespace batch_ell } // namespace dpcpp } // namespace kernels diff --git a/include/ginkgo/core/matrix/batch_ell.hpp b/include/ginkgo/core/matrix/batch_ell.hpp index 627c5b7fd6e..752015fb394 100644 --- a/include/ginkgo/core/matrix/batch_ell.hpp +++ b/include/ginkgo/core/matrix/batch_ell.hpp @@ -283,6 +283,14 @@ class Ell final ptr_param> beta, ptr_param> x) const; + /** + * Performs the operation a = alpha*I + beta*a. + * + * Performs the operation in-place for this batch matrix + */ + void add_scaled_identity(ptr_param> alpha, + ptr_param> beta); + private: size_type compute_num_elems(const batch_dim<2>& size, IndexType num_elems_per_row) diff --git a/omp/matrix/batch_ell_kernels.cpp b/omp/matrix/batch_ell_kernels.cpp index 71bc3a6d87f..33a2d985856 100644 --- a/omp/matrix/batch_ell_kernels.cpp +++ b/omp/matrix/batch_ell_kernels.cpp @@ -83,6 +83,29 @@ GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INT32_TYPE( GKO_DECLARE_BATCH_ELL_ADVANCED_APPLY_KERNEL); +template +void add_scaled_identity(std::shared_ptr exec, + const batch::MultiVector* alpha, + const batch::MultiVector* beta, + batch::matrix::Ell* mat) +{ + const auto mat_ub = host::get_batch_struct(mat); + const auto alpha_ub = host::get_batch_struct(alpha); + const auto beta_ub = host::get_batch_struct(beta); +#pragma omp parallel for + for (size_type batch_id = 0; batch_id < mat->get_num_batch_items(); + ++batch_id) { + const auto alpha_b = batch::extract_batch_item(alpha_ub, batch_id); + const auto beta_b = batch::extract_batch_item(beta_ub, batch_id); + const auto mat_b = batch::matrix::extract_batch_item(mat_ub, batch_id); + add_scaled_identity_kernel(alpha_b.values[0], beta_b.values[0], mat_b); + } +} + +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INT32_TYPE( + GKO_DECLARE_BATCH_ELL_ADD_SCALED_IDENTITY_KERNEL); + + } // namespace batch_ell } // namespace omp } // namespace kernels diff --git a/reference/matrix/batch_ell_kernels.cpp b/reference/matrix/batch_ell_kernels.cpp index 932068f9bec..500f566d944 100644 --- a/reference/matrix/batch_ell_kernels.cpp +++ b/reference/matrix/batch_ell_kernels.cpp @@ -81,6 +81,28 @@ GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INT32_TYPE( GKO_DECLARE_BATCH_ELL_ADVANCED_APPLY_KERNEL); +template +void add_scaled_identity(std::shared_ptr exec, + const batch::MultiVector* alpha, + const batch::MultiVector* beta, + batch::matrix::Ell* mat) +{ + const auto mat_ub = host::get_batch_struct(mat); + const auto alpha_ub = host::get_batch_struct(alpha); + const auto beta_ub = host::get_batch_struct(beta); + for (size_type batch_id = 0; batch_id < mat->get_num_batch_items(); + ++batch_id) { + const auto alpha_b = batch::extract_batch_item(alpha_ub, batch_id); + const auto beta_b = batch::extract_batch_item(beta_ub, batch_id); + const auto mat_b = batch::matrix::extract_batch_item(mat_ub, batch_id); + add_scaled_identity_kernel(alpha_b.values[0], beta_b.values[0], mat_b); + } +} + +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INT32_TYPE( + GKO_DECLARE_BATCH_ELL_ADD_SCALED_IDENTITY_KERNEL); + + } // namespace batch_ell } // namespace reference } // namespace kernels diff --git a/reference/matrix/batch_ell_kernels.hpp.inc b/reference/matrix/batch_ell_kernels.hpp.inc index 3389a90c5f5..29cbbbecc4e 100644 --- a/reference/matrix/batch_ell_kernels.hpp.inc +++ b/reference/matrix/batch_ell_kernels.hpp.inc @@ -50,3 +50,19 @@ inline void advanced_apply_kernel( } } } + + +template +inline void add_scaled_identity_kernel( + const ValueType alpha, const ValueType beta, + const gko::batch::matrix::ell::batch_item& mat) +{ + for (int row = 0; row < mat.num_rows; row++) { + for (int k = 0; k < mat.num_stored_elems_per_row; k++) { + mat.values[row + k * mat.stride] *= beta; + if (row == mat.col_idxs[row + k * mat.stride]) { + mat.values[row + k * mat.stride] += alpha; + } + } + } +} diff --git a/reference/test/matrix/batch_csr_kernels.cpp b/reference/test/matrix/batch_csr_kernels.cpp index e75a8ac8955..6c1d784a208 100644 --- a/reference/test/matrix/batch_csr_kernels.cpp +++ b/reference/test/matrix/batch_csr_kernels.cpp @@ -186,6 +186,29 @@ TYPED_TEST(Csr, CanTwoSidedScale) } +TYPED_TEST(Csr, CanAddScaledIdentity) +{ + using value_type = typename TestFixture::value_type; + using index_type = gko::int32; + using BMtx = typename TestFixture::BMtx; + using BMVec = typename TestFixture::BMVec; + auto alpha = gko::batch::initialize({{2.0}, {-1.0}}, this->exec); + auto beta = gko::batch::initialize({{3.0}, {-2.0}}, this->exec); + auto mat = gko::batch::initialize( + {{{1.0, 2.0, 0.0}, {3.0, 1.0, 1.0}, {0.0, 1.0, 1.0}}, + {{2.0, -2.0, 0.0}, {1.0, -1.0, 2.0}, {0.0, 2.0, 1.0}}}, + this->exec, 7); + + mat->add_scaled_identity(alpha, beta); + + auto result_mat = gko::batch::initialize( + {{{5.0, 6.0, 0.0}, {9.0, 5.0, 3.0}, {0.0, 3.0, 5.0}}, + {{-5.0, 4.0, 0.0}, {-2.0, 1.0, -4.0}, {0.0, -4.0, -3.0}}}, + this->exec, 7); + GKO_ASSERT_BATCH_MTX_NEAR(mat.get(), result_mat.get(), 0.); +} + + TYPED_TEST(Csr, ApplyFailsOnWrongNumberOfResultCols) { using BMVec = typename TestFixture::BMVec; diff --git a/reference/test/matrix/batch_dense_kernels.cpp b/reference/test/matrix/batch_dense_kernels.cpp index bc8449b03e4..23ab5e1d893 100644 --- a/reference/test/matrix/batch_dense_kernels.cpp +++ b/reference/test/matrix/batch_dense_kernels.cpp @@ -144,6 +144,29 @@ TYPED_TEST(Dense, CanTwoSidedScale) } +TYPED_TEST(Dense, CanAddScaledIdentity) +{ + using value_type = typename TestFixture::value_type; + using index_type = gko::int32; + using BMtx = typename TestFixture::BMtx; + using BMVec = typename TestFixture::BMVec; + auto alpha = gko::batch::initialize({{2.0}, {-1.0}}, this->exec); + auto beta = gko::batch::initialize({{3.0}, {-2.0}}, this->exec); + auto mat = gko::batch::initialize( + {{{1.0, 2.0, 0.0}, {3.0, 1.0, 1.0}, {0.0, 1.0, 1.0}}, + {{2.0, -2.0, 0.0}, {1.0, -1.0, 2.0}, {0.0, 2.0, 1.0}}}, + this->exec); + + mat->add_scaled_identity(alpha, beta); + + auto result_mat = gko::batch::initialize( + {{{5.0, 6.0, 0.0}, {9.0, 5.0, 3.0}, {0.0, 3.0, 5.0}}, + {{-5.0, 4.0, 0.0}, {-2.0, 1.0, -4.0}, {0.0, -4.0, -3.0}}}, + this->exec); + GKO_ASSERT_BATCH_MTX_NEAR(mat.get(), result_mat.get(), 0.); +} + + TYPED_TEST(Dense, ApplyFailsOnWrongNumberOfResultCols) { using BMVec = typename TestFixture::BMVec; diff --git a/reference/test/matrix/batch_ell_kernels.cpp b/reference/test/matrix/batch_ell_kernels.cpp index 7adf6c2e443..e106130f0a1 100644 --- a/reference/test/matrix/batch_ell_kernels.cpp +++ b/reference/test/matrix/batch_ell_kernels.cpp @@ -167,6 +167,29 @@ TYPED_TEST(Ell, ConstAppliesLinearCombinationToBatchMultiVector) } +TYPED_TEST(Ell, CanAddScaledIdentity) +{ + using value_type = typename TestFixture::value_type; + using index_type = gko::int32; + using BMtx = typename TestFixture::BMtx; + using BMVec = typename TestFixture::BMVec; + auto alpha = gko::batch::initialize({{2.0}, {-1.0}}, this->exec); + auto beta = gko::batch::initialize({{3.0}, {-2.0}}, this->exec); + auto mat = gko::batch::initialize( + {{{1.0, 2.0, 0.0}, {0.0, 1.0, 1.0}, {0.0, 1.0, 1.0}}, + {{2.0, -2.0, 0.0}, {0.0, -1.0, 2.0}, {0.0, 2.0, 1.0}}}, + this->exec, 2); + + mat->add_scaled_identity(alpha, beta); + + auto result_mat = gko::batch::initialize( + {{{5.0, 6.0, 0.0}, {0.0, 5.0, 3.0}, {0.0, 3.0, 5.0}}, + {{-5.0, 4.0, 0.0}, {0.0, 1.0, -4.0}, {0.0, -4.0, -3.0}}}, + this->exec, 2); + GKO_ASSERT_BATCH_MTX_NEAR(mat.get(), result_mat.get(), 0.); +} + + TYPED_TEST(Ell, ApplyFailsOnWrongNumberOfResultCols) { using BMVec = typename TestFixture::BMVec;