diff --git a/core/distributed/helpers.hpp b/core/distributed/helpers.hpp index 0e4f7b34e55..ef689ffced8 100644 --- a/core/distributed/helpers.hpp +++ b/core/distributed/helpers.hpp @@ -30,6 +30,10 @@ THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. *************************************************************/ +#ifndef GKO_CORE_DISTRIBUTED_HELPERS_HPP_ +#define GKO_CORE_DISTRIBUTED_HELPERS_HPP_ + + #include @@ -124,5 +128,48 @@ bool is_distributed(Arg* linop, Rest*... rest) } +/** + * Cast an input linop to the correct underlying vector type (dense/distributed) + * and passes it to the given function. + * + * @tparam ValueType The value type of the underlying dense or distributed + * vector. + * @tparam T The linop type, either LinOp, or const LinOp. + * @tparam F The function type. + * @tparam Args The types for the additional arguments of f. + * + * @param linop The linop to be casted into either a dense or distributed + * vector. + * @param f The function that is to be called with the correctly casted linop. + * @param args The additional arguments of f. + */ +template +void vector_dispatch(T* linop, F&& f, Args&&... args) +{ +#if GINKGO_BUILD_MPI + if (is_distributed(linop)) { + using type = std::conditional_t< + std::is_const::value, + const experimental::distributed::Vector, + experimental::distributed::Vector>; + f(dynamic_cast(linop), std::forward(args)...); + } else +#endif + { + using type = std::conditional_t::value, + const matrix::Dense, + matrix::Dense>; + if (auto concrete_linop = dynamic_cast(linop)) { + f(concrete_linop, std::forward(args)...); + } else { + GKO_NOT_SUPPORTED(linop); + } + } +} + + } // namespace detail } // namespace gko + + +#endif // GKO_CORE_DISTRIBUTED_HELPERS_HPP_ diff --git a/core/log/convergence.cpp b/core/log/convergence.cpp index 591df7cf758..0f186a432a8 100644 --- a/core/log/convergence.cpp +++ b/core/log/convergence.cpp @@ -35,10 +35,15 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #include #include +#include #include #include +#include "core/base/dispatch_helper.hpp" +#include "core/distributed/helpers.hpp" + + namespace gko { namespace log { @@ -73,12 +78,14 @@ void Convergence::on_criterion_check_completed( if (residual_norm != nullptr) { this->residual_norm_.reset(residual_norm->clone().release()); } else if (residual != nullptr) { - using Vector = matrix::Dense; using NormVector = matrix::Dense>; - this->residual_norm_ = NormVector::create( - residual->get_executor(), dim<2>{1, residual->get_size()[1]}); - auto dense_r = as(residual); - dense_r->compute_norm2(this->residual_norm_.get()); + detail::vector_dispatch( + residual, [&](const auto* dense_r) { + this->residual_norm_ = + NormVector::create(residual->get_executor(), + dim<2>{1, residual->get_size()[1]}); + dense_r->compute_norm2(this->residual_norm_.get()); + }); } } } diff --git a/core/log/papi.cpp b/core/log/papi.cpp index cd9c8584027..31fc8f519ad 100644 --- a/core/log/papi.cpp +++ b/core/log/papi.cpp @@ -37,6 +37,9 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #include +#include "core/distributed/helpers.hpp" + + namespace gko { namespace log { @@ -243,12 +246,13 @@ void Papi::on_criterion_check_completed( residual_norm_d = static_cast(std::real(dense_r_norm->at(0, 0))); } else if (residual != nullptr) { - auto tmp_res_norm = Vector::create(residual->get_executor(), - dim<2>{1, residual->get_size()[1]}); - auto dense_r = as(residual); - dense_r->compute_norm2(tmp_res_norm.get()); - residual_norm_d = - static_cast(std::real(tmp_res_norm->at(0, 0))); + detail::vector_dispatch(residual, [&](const auto* dense_r) { + auto tmp_res_norm = Vector::create( + residual->get_executor(), dim<2>{1, residual->get_size()[1]}); + dense_r->compute_norm2(tmp_res_norm.get()); + residual_norm_d = + static_cast(std::real(tmp_res_norm->at(0, 0))); + }); } const auto tmp = reinterpret_cast(criterion); diff --git a/core/test/log/convergence.cpp b/core/test/log/convergence.cpp index 8e35ed413ed..5e906fcd523 100644 --- a/core/test/log/convergence.cpp +++ b/core/test/log/convergence.cpp @@ -44,17 +44,81 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. namespace { + template -class Convergence : public ::testing::Test {}; +class Convergence : public ::testing::Test { +public: + using Dense = gko::matrix::Dense; + using AbsoluteDense = gko::matrix::Dense>; + + Convergence() + { + status.get_data()[0].reset(); + status.get_data()[0].converge(0); + } + + std::shared_ptr exec = + gko::ReferenceExecutor::create(); + + std::unique_ptr residual = gko::initialize({3, 4}, exec); + std::unique_ptr residual_norm = + gko::initialize({5}, exec); + std::unique_ptr implicit_sq_resnorm = + gko::initialize({6}, exec); + std::unique_ptr solution = gko::initialize({-2, 7}, exec); + + gko::array status = {exec, 1}; +}; TYPED_TEST_SUITE(Convergence, gko::test::ValueTypes, TypenameNameGenerator); -TYPED_TEST(Convergence, CanGetData) +TYPED_TEST(Convergence, CanGetEmptyData) +{ + auto logger = gko::log::Convergence::create( + gko::log::Logger::criterion_events_mask); + + ASSERT_EQ(logger->has_converged(), false); + ASSERT_EQ(logger->get_num_iterations(), 0); + ASSERT_EQ(logger->get_residual(), nullptr); + ASSERT_EQ(logger->get_residual_norm(), nullptr); + ASSERT_EQ(logger->get_implicit_sq_resnorm(), nullptr); +} + + +TYPED_TEST(Convergence, CanLogData) +{ + using Dense = gko::matrix::Dense; + using AbsoluteDense = gko::matrix::Dense>; + auto logger = gko::log::Convergence::create( + gko::log::Logger::criterion_events_mask); + + logger->template on( + nullptr, 100, this->residual.get(), this->residual_norm.get(), + this->implicit_sq_resnorm.get(), this->solution.get(), 0, false, + &this->status, false, true); + + ASSERT_EQ(logger->has_converged(), true); + ASSERT_EQ(logger->get_num_iterations(), 100); + GKO_ASSERT_MTX_NEAR(gko::as(logger->get_residual()), + this->residual.get(), 0); + GKO_ASSERT_MTX_NEAR(gko::as(logger->get_residual_norm()), + this->residual_norm.get(), 0); + GKO_ASSERT_MTX_NEAR( + gko::as(logger->get_implicit_sq_resnorm()), + this->implicit_sq_resnorm.get(), 0); +} + + +TYPED_TEST(Convergence, DoesNotLogIfNotStopped) { - auto exec = gko::ReferenceExecutor::create(); auto logger = gko::log::Convergence::create( - gko::log::Logger::iteration_complete_mask); + gko::log::Logger::criterion_events_mask); + + logger->template on( + nullptr, 100, this->residual.get(), this->residual_norm.get(), + this->implicit_sq_resnorm.get(), this->solution.get(), 0, false, + &this->status, false, false); ASSERT_EQ(logger->has_converged(), false); ASSERT_EQ(logger->get_num_iterations(), 0); @@ -63,4 +127,19 @@ TYPED_TEST(Convergence, CanGetData) } +TYPED_TEST(Convergence, CanComputeResidualNorm) +{ + using AbsoluteDense = gko::matrix::Dense>; + auto logger = gko::log::Convergence::create( + gko::log::Logger::criterion_events_mask); + + logger->template on( + nullptr, 100, this->residual.get(), nullptr, nullptr, nullptr, 0, false, + &this->status, false, true); + + GKO_ASSERT_MTX_NEAR(gko::as(logger->get_residual_norm()), + this->residual_norm, r::value); +} + + } // namespace diff --git a/core/test/mpi/distributed/CMakeLists.txt b/core/test/mpi/distributed/CMakeLists.txt index 2e35c68c4ac..49375072010 100644 --- a/core/test/mpi/distributed/CMakeLists.txt +++ b/core/test/mpi/distributed/CMakeLists.txt @@ -1 +1,2 @@ +ginkgo_create_test(helpers MPI_SIZE 1) ginkgo_create_test(matrix MPI_SIZE 1) diff --git a/core/test/mpi/distributed/helpers.cpp b/core/test/mpi/distributed/helpers.cpp new file mode 100644 index 00000000000..b311cea437e --- /dev/null +++ b/core/test/mpi/distributed/helpers.cpp @@ -0,0 +1,134 @@ +/************************************************************* +Copyright (c) 2017-2022, the Ginkgo authors +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions +are met: + +1. Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer in the +documentation and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS +IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED +TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A +PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*************************************************************/ + +#include + + +#include + + +#include "core/distributed/helpers.hpp" +#include "core/test/utils.hpp" + + +int run_function(gko::experimental::distributed::Vector<>*) { return 1; } + +int run_function(const gko::experimental::distributed::Vector<>*) { return 2; } + +int run_function(gko::matrix::Dense<>*) { return 3; } + +int run_function(const gko::matrix::Dense<>*) { return 4; } + + +class RunVector : public ::testing::Test { +public: + std::shared_ptr exec = + gko::ReferenceExecutor::create(); +}; + + +TEST_F(RunVector, PicksDistributedVectorCorrectly) +{ + std::unique_ptr dist_vector = + gko::experimental::distributed::Vector<>::create(exec, MPI_COMM_WORLD); + int result; + + gko::detail::vector_dispatch( + dist_vector.get(), [&](auto* dense) { result = run_function(dense); }); + + ASSERT_EQ(result, + run_function(gko::as>( + dist_vector.get()))); +} + + +TEST_F(RunVector, PicksConstDistributedVectorCorrectly) +{ + std::unique_ptr const_dist_vector = + gko::experimental::distributed::Vector<>::create(exec, MPI_COMM_WORLD); + int result; + + gko::detail::vector_dispatch( + const_dist_vector.get(), + [&](auto* dense) { result = run_function(dense); }); + + ASSERT_EQ( + result, + run_function(gko::as>( + const_dist_vector.get()))); +} + + +TEST_F(RunVector, PicksDenseVectorCorrectly) +{ + std::unique_ptr dense_vector = + gko::matrix::Dense<>::create(exec); + int result; + + gko::detail::vector_dispatch( + dense_vector.get(), [&](auto* dense) { result = run_function(dense); }); + + ASSERT_EQ(result, + run_function(gko::as>(dense_vector.get()))); +} + + +TEST_F(RunVector, PicksConstDenseVectorCorrectly) +{ + std::unique_ptr const_dense_vector = + gko::matrix::Dense<>::create(exec); + int result; + + gko::detail::vector_dispatch( + const_dense_vector.get(), + [&](auto* dense) { result = run_function(dense); }); + + ASSERT_EQ(result, run_function(gko::as>( + const_dense_vector.get()))); +} + +TEST_F(RunVector, ThrowsIfWrongType) +{ + std::unique_ptr csr = gko::matrix::Csr<>::create(exec); + + ASSERT_THROW( + gko::detail::vector_dispatch(csr.get(), [&](auto* dense) {}), + gko::NotSupported); +} + + +TEST_F(RunVector, ThrowsIfNullptr) +{ + ASSERT_THROW(gko::detail::vector_dispatch( + static_cast(nullptr), [&](auto* dense) {}), + gko::NotSupported); +}