diff --git a/include/ginkgo/core/solver/minres.hpp b/include/ginkgo/core/solver/minres.hpp index 271f76d1a15..d3772514ef8 100644 --- a/include/ginkgo/core/solver/minres.hpp +++ b/include/ginkgo/core/solver/minres.hpp @@ -170,13 +170,15 @@ class Minres : public EnableLinOp>, if (parameters_.generated_preconditioner) { GKO_ASSERT_EQUAL_DIMENSIONS(parameters_.generated_preconditioner, this); - set_preconditioner(parameters_.generated_preconditioner); + Preconditionable::set_preconditioner( + parameters_.generated_preconditioner); } else if (parameters_.preconditioner) { - set_preconditioner( + Preconditionable::set_preconditioner( parameters_.preconditioner->generate(system_matrix_)); } else { - set_preconditioner(matrix::Identity::create( - this->get_executor(), this->get_size())); + Preconditionable::set_preconditioner( + matrix::Identity::create(this->get_executor(), + this->get_size())); } stop_criterion_factory_ = stop::combine(std::move(parameters_.criteria)); diff --git a/reference/test/solver/minres_kernels.cpp b/reference/test/solver/minres_kernels.cpp index dd17e40ef97..61f351518f0 100644 --- a/reference/test/solver/minres_kernels.cpp +++ b/reference/test/solver/minres_kernels.cpp @@ -39,6 +39,7 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #include #include #include +#include #include #include #include @@ -94,6 +95,19 @@ class Minres : public ::testing::Test { gko::stop::ResidualNorm::build() .with_reduction_factor(r::value / 2) .on(exec)) + .on(exec)), + preconditioned_minres_factory( + Solver::build() + .with_criteria( + gko::stop::Iteration::build().with_max_iters(40u).on( + exec), + gko::stop::ResidualNorm::build() + .with_reduction_factor(r::value / 2) + .on(exec)) + .with_preconditioner( + gko::preconditioner::Jacobi::build() + .with_max_block_size(1u) + .on(exec)) .on(exec)) { stopped.stop(1); @@ -135,6 +149,7 @@ class Minres : public ::testing::Test { gko::Array small_stop; std::unique_ptr minres_factory; + std::unique_ptr preconditioned_minres_factory; }; TYPED_TEST_SUITE(Minres, gko::test::ValueTypes, TypenameNameGenerator); @@ -308,4 +323,23 @@ TYPED_TEST(Minres, SolvesSystem) } +TYPED_TEST(Minres, SolvesPreconditionedSystem) +{ + using Mtx = typename TestFixture::Mtx; + using vt = typename TestFixture::value_type; + auto one_op = gko::initialize({gko::one()}, this->exec); + auto neg_one_op = gko::initialize({-gko::one()}, this->exec); + auto solver = this->preconditioned_minres_factory->generate(this->mtx); + auto x = gko::initialize({-1., 2., 3., 4.}, this->exec); + auto sol = gko::clone(this->exec, x); + auto b = Mtx::create(this->exec, x->get_size()); + this->mtx->apply(x.get(), b.get()); + x->fill(0.); + + solver->apply(b.get(), x.get()); + + GKO_ASSERT_MTX_NEAR(x, sol, r::value * 10); +} + + } // namespace diff --git a/test/solver/minres_kernels.cpp b/test/solver/minres_kernels.cpp index 186a799f766..bedfb0064aa 100644 --- a/test/solver/minres_kernels.cpp +++ b/test/solver/minres_kernels.cpp @@ -40,6 +40,7 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #include #include #include +#include #include #include #include @@ -328,4 +329,48 @@ TEST_F(Minres, ApplyIsEquivalentToRef) } +TEST_F(Minres, PreconditionedApplyIsEquivalentToRef) +{ + + auto mtx = gen_mtx(50, 50, 53); + gko::test::make_hpd(mtx.get()); + auto x = gen_mtx(50, 1, 5); + auto b = gen_mtx(50, 1, 4); + auto d_mtx = gko::clone(exec, mtx); + auto d_x = gko::clone(exec, x); + auto d_b = gko::clone(exec, b); + auto minres_factory = + gko::solver::Minres::build() + .with_criteria( + gko::stop::Iteration::build().with_max_iters(400u).on(ref), + gko::stop::ResidualNorm::build() + .with_reduction_factor(::r::value) + .on(ref)) + .with_preconditioner( + gko::preconditioner::Jacobi::build() + .with_max_block_size(1u) + .on(ref)) + .on(ref); + auto d_minres_factory = + gko::solver::Minres::build() + .with_criteria( + gko::stop::Iteration::build().with_max_iters(400u).on(exec), + gko::stop::ResidualNorm::build() + .with_reduction_factor(::r::value) + .on(exec)) + .with_preconditioner( + gko::preconditioner::Jacobi::build() + .with_max_block_size(1u) + .on(exec)) + .on(exec); + auto solver = minres_factory->generate(std::move(mtx)); + auto d_solver = d_minres_factory->generate(std::move(d_mtx)); + + solver->apply(b.get(), x.get()); + d_solver->apply(d_b.get(), d_x.get()); + + GKO_ASSERT_MTX_NEAR(d_x, x, ::r::value * 100); +} + + } // namespace