From a990e913d645e638a98b9d766fa248f383ddae82 Mon Sep 17 00:00:00 2001 From: Aditya Kashi Date: Wed, 5 Jan 2022 22:21:27 +0100 Subject: [PATCH] refactored code to add scaled identity, introduced new mixin --- core/matrix/csr.cpp | 14 +++++++ core/matrix/dense.cpp | 14 +++++++ core/matrix/identity.cpp | 38 +------------------ .../ginkgo/core/base/exception_helpers.hpp | 20 ++++++++++ include/ginkgo/core/base/lin_op.hpp | 27 ++++++++++++- include/ginkgo/core/matrix/csr.hpp | 5 ++- include/ginkgo/core/matrix/dense.hpp | 5 ++- reference/test/matrix/csr_kernels.cpp | 16 ++++++++ reference/test/matrix/dense_kernels.cpp | 16 ++++++++ 9 files changed, 116 insertions(+), 39 deletions(-) diff --git a/core/matrix/csr.cpp b/core/matrix/csr.cpp index 473ecc9ef2c..0594dfa6a7b 100644 --- a/core/matrix/csr.cpp +++ b/core/matrix/csr.cpp @@ -51,6 +51,7 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #include "core/components/device_matrix_data_kernels.hpp" #include "core/components/fill_array_kernels.hpp" #include "core/components/prefix_sum_kernels.hpp" +#include "core/factorization/factorization_kernels.hpp" #include "core/matrix/csr_kernels.hpp" @@ -99,6 +100,9 @@ GKO_REGISTER_OPERATION(outplace_absolute_array, components::outplace_absolute_array); GKO_REGISTER_OPERATION(scale, csr::scale); GKO_REGISTER_OPERATION(inv_scale, csr::inv_scale); +GKO_REGISTER_OPERATION(add_scaled_identity, csr::add_scaled_identity); +GKO_REGISTER_OPERATION(add_diagonal_elems, + factorization::add_diagonal_elements); } // anonymous namespace @@ -619,6 +623,16 @@ void Csr::inv_scale_impl(const LinOp* alpha) } +template +void Csr::add_scaled_identity_impl(const LinOp* const a, + const LinOp* const b) +{ + this->get_executor()->run(csr::make_add_diagonal_elems(this, false)); + this->get_executor()->run(csr::make_add_scaled_identity( + as>(a), as>(b), this)); +} + + #define GKO_DECLARE_CSR_MATRIX(ValueType, IndexType) \ class Csr GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(GKO_DECLARE_CSR_MATRIX); diff --git a/core/matrix/dense.cpp b/core/matrix/dense.cpp index 67f53542c07..0fbc3643830 100644 --- a/core/matrix/dense.cpp +++ b/core/matrix/dense.cpp @@ -103,6 +103,7 @@ GKO_REGISTER_OPERATION(outplace_absolute_dense, dense::outplace_absolute_dense); GKO_REGISTER_OPERATION(make_complex, dense::make_complex); GKO_REGISTER_OPERATION(get_real, dense::get_real); GKO_REGISTER_OPERATION(get_imag, dense::get_imag); +GKO_REGISTER_OPERATION(add_scaled_identity, dense::add_scaled_identity); } // anonymous namespace @@ -1276,6 +1277,19 @@ void Dense::get_imag( } +template +void Dense::add_scaled_identity_impl(const LinOp* const a, + const LinOp* const b) +{ + precision_dispatch_real_complex( + [this](auto dense_alpha, auto dense_beta, auto dense_x) { + this->get_executor()->run(dense::make_add_scaled_identity( + dense_alpha, dense_beta, dense_x)); + }, + a, b, this); +} + + #define GKO_DECLARE_DENSE_MATRIX(_type) class Dense<_type> GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_DENSE_MATRIX); diff --git a/core/matrix/identity.cpp b/core/matrix/identity.cpp index 6b44dc3057c..f30877dbad2 100644 --- a/core/matrix/identity.cpp +++ b/core/matrix/identity.cpp @@ -39,25 +39,8 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #include -#include "core/factorization/factorization_kernels.hpp" -#include "core/matrix/csr_kernels.hpp" -#include "core/matrix/dense_kernels.hpp" - - namespace gko { namespace matrix { -namespace identity { -namespace { - - -GKO_REGISTER_OPERATION(dense_add_scaled_identity, dense::add_scaled_identity); -GKO_REGISTER_OPERATION(csr_add_scaled_identity, csr::add_scaled_identity); -GKO_REGISTER_OPERATION(csr_add_diagonal_elems, - factorization::add_diagonal_elements); - - -} // anonymous namespace -} // namespace identity template @@ -72,25 +55,8 @@ void Identity::apply_impl(const LinOp* alpha, const LinOp* b, const LinOp* beta, LinOp* x) const { if (auto bI = dynamic_cast*>(b)) { - GKO_ASSERT_IS_SQUARE_MATRIX(x); - if (auto xd = dynamic_cast*>(x)) { - precision_dispatch_real_complex( - [this](auto dense_alpha, auto dense_beta, auto dense_x) { - this->get_executor()->run( - identity::make_dense_add_scaled_identity( - dense_alpha, dense_beta, dense_x)); - }, - alpha, beta, x); - } else if (auto xc = dynamic_cast*>(x)) { - this->get_executor()->run( - identity::make_csr_add_diagonal_elems(xc, false)); - this->get_executor()->run(identity::make_csr_add_scaled_identity( - as>(alpha), as>(beta), xc)); - } else if (auto xc = dynamic_cast*>(x)) { - this->get_executor()->run( - identity::make_csr_add_diagonal_elems(xc, false)); - this->get_executor()->run(identity::make_csr_add_scaled_identity( - as>(alpha), as>(beta), xc)); + if (auto xs = dynamic_cast(x)) { + xs->add_scaled_identity(alpha, beta); } else { GKO_NOT_IMPLEMENTED; } diff --git a/include/ginkgo/core/base/exception_helpers.hpp b/include/ginkgo/core/base/exception_helpers.hpp index 315a8855a73..3ffe0578894 100644 --- a/include/ginkgo/core/base/exception_helpers.hpp +++ b/include/ginkgo/core/base/exception_helpers.hpp @@ -671,6 +671,26 @@ inline T ensure_allocated_impl(T ptr, const std::string& file, int line, "semi-colon warnings") +/** + * Checks that the operator is a scalar, ie., has size 1x1. + * + * @param _op Operator to be checked. + * + * @throw BadDimension if _op does not have size 1x1. + */ +#define GKO_ASSERT_IS_SCALAR(_op) \ + { \ + auto sz = gko::detail::get_size(_op); \ + if (sz[0] != 1 || sz[1] != 1) { \ + throw ::gko::BadDimension(__FILE__, __LINE__, __func__, #_op, \ + sz[0], sz[1], "expected scalar"); \ + } \ + } \ + static_assert(true, \ + "This assert is used to counter the false positive extra " \ + "semi-colon warnings") + + } // namespace gko diff --git a/include/ginkgo/core/base/lin_op.hpp b/include/ginkgo/core/base/lin_op.hpp index 6a883164951..2f39979f7f7 100644 --- a/include/ginkgo/core/base/lin_op.hpp +++ b/include/ginkgo/core/base/lin_op.hpp @@ -60,7 +60,7 @@ template class Diagonal; -} +} // namespace matrix /** @@ -767,6 +767,31 @@ class EnableAbsoluteComputation : public AbsoluteComputable { }; +/** + * Mix-in that adds the operation M <- a I + b M for matrix M, identity + * operator I and scalars a and b, where M is the calling object. + */ +class EnableScaledIdentityAddition { +public: + /** + * Scales this and adds another scalar times the identity to it. + * + * @param a Scalar to multiply the identity operator by before adding. + * @param b Scalar to multiply this before adding the scaled identity to + * it. + */ + void add_scaled_identity(const LinOp* const a, const LinOp* const b) + { + GKO_ASSERT_IS_SCALAR(a); + GKO_ASSERT_IS_SCALAR(b); + add_scaled_identity_impl(a, b); + } + +private: + virtual void add_scaled_identity_impl(const LinOp* a, const LinOp* b) = 0; +}; + + /** * The EnableLinOp mixin can be used to provide sensible default implementations * of the majority of the LinOp and PolymorphicObject interface. diff --git a/include/ginkgo/core/matrix/csr.hpp b/include/ginkgo/core/matrix/csr.hpp index fe148f3bc4f..c2e29cdc9ca 100644 --- a/include/ginkgo/core/matrix/csr.hpp +++ b/include/ginkgo/core/matrix/csr.hpp @@ -132,7 +132,8 @@ class Csr : public EnableLinOp>, public Transposable, public Permutable, public EnableAbsoluteComputation< - remove_complex>> { + remove_complex>>, + public EnableScaledIdentityAddition { friend class EnableCreateMethod; friend class EnablePolymorphicObject; friend class Coo; @@ -1168,6 +1169,8 @@ class Csr : public EnableLinOp>, Array row_ptrs_; Array srow_; std::shared_ptr strategy_; + + void add_scaled_identity_impl(const LinOp* a, const LinOp* b) override; }; diff --git a/include/ginkgo/core/matrix/dense.hpp b/include/ginkgo/core/matrix/dense.hpp index 0d1bb042e70..f23d5028c60 100644 --- a/include/ginkgo/core/matrix/dense.hpp +++ b/include/ginkgo/core/matrix/dense.hpp @@ -113,7 +113,8 @@ class Dense public Transposable, public Permutable, public Permutable, - public EnableAbsoluteComputation>> { + public EnableAbsoluteComputation>>, + public EnableScaledIdentityAddition { friend class EnableCreateMethod; friend class EnablePolymorphicObject; friend class Coo; @@ -1061,6 +1062,8 @@ class Dense private: Array values_; size_type stride_; + + void add_scaled_identity_impl(const LinOp* a, const LinOp* b) override; }; diff --git a/reference/test/matrix/csr_kernels.cpp b/reference/test/matrix/csr_kernels.cpp index 3458ba3693c..d068d3b0439 100644 --- a/reference/test/matrix/csr_kernels.cpp +++ b/reference/test/matrix/csr_kernels.cpp @@ -1544,6 +1544,22 @@ TYPED_TEST(Csr, InvScalesData) } +TYPED_TEST(Csr, ScaleCsrAddIdentityRectangular) +{ + using Vec = typename TestFixture::Vec; + using T = typename TestFixture::value_type; + using Csr = typename TestFixture::Mtx; + auto alpha = gko::initialize({2.0}, this->exec); + auto beta = gko::initialize({-1.0}, this->exec); + auto b = gko::initialize( + {I{2.0, 0.0}, I{1.0, 2.5}, I{0.0, -4.0}}, this->exec); + + b->add_scaled_identity(alpha.get(), beta.get()); + + GKO_ASSERT_MTX_NEAR(b, l({{0.0, 0.0}, {-1.0, -0.5}, {0.0, 4.0}}), 0.0); +} + + template class CsrComplex : public ::testing::Test { protected: diff --git a/reference/test/matrix/dense_kernels.cpp b/reference/test/matrix/dense_kernels.cpp index 28b3804672e..66a506ace1a 100644 --- a/reference/test/matrix/dense_kernels.cpp +++ b/reference/test/matrix/dense_kernels.cpp @@ -4081,6 +4081,22 @@ TYPED_TEST(Dense, MakeTemporaryConversionConstDoesntConvertBack) } +TYPED_TEST(Dense, ScaleAddIdentityRectangular) +{ + using T = typename TestFixture::value_type; + using Vec = typename TestFixture::Mtx; + using MixedVec = typename TestFixture::MixedMtx; + auto alpha = gko::initialize({2.0}, this->exec); + auto beta = gko::initialize({-1.0}, this->exec); + auto b = gko::initialize( + {I{2.0, 0.0}, I{1.0, 2.5}, I{0.0, -4.0}}, this->exec); + + b->add_scaled_identity(alpha.get(), beta.get()); + + GKO_ASSERT_MTX_NEAR(b, l({{0.0, 0.0}, {-1.0, -0.5}, {0.0, 4.0}}), 0.0); +} + + template class DenseComplex : public ::testing::Test { protected: