diff --git a/core/stop/residual_norm.cpp b/core/stop/residual_norm.cpp index c4c37162408..73630204a22 100644 --- a/core/stop/residual_norm.cpp +++ b/core/stop/residual_norm.cpp @@ -75,6 +75,26 @@ bool ResidualNormBase::check_impl( dense_r->compute_norm2(u_dense_tau_.get()); } dense_tau = u_dense_tau_.get(); + } else if (updater.solution_ != nullptr && system_matrix_ != nullptr && + b_ != nullptr) { + auto exec = this->get_executor(); + // when LinOp is real but rhs is complex, we use real view on complex, + // so it still uses the same type of scalar in apply. + if (auto vec_b = std::dynamic_pointer_cast(b_)) { + auto dense_r = vec_b->clone(); + system_matrix_->apply(neg_one_.get(), updater.solution_, one_.get(), + dense_r.get()); + dense_r->compute_norm2(u_dense_tau_.get()); + } else if (auto vec_b = + std::dynamic_pointer_cast(b_)) { + auto dense_r = vec_b->clone(); + system_matrix_->apply(neg_one_.get(), updater.solution_, one_.get(), + dense_r.get()); + dense_r->compute_norm2(u_dense_tau_.get()); + } else { + GKO_NOT_SUPPORTED(nullptr); + } + dense_tau = u_dense_tau_.get(); } else { GKO_NOT_SUPPORTED(nullptr); } diff --git a/include/ginkgo/core/stop/residual_norm.hpp b/include/ginkgo/core/stop/residual_norm.hpp index e2456403643..8dbbdc88544 100644 --- a/include/ginkgo/core/stop/residual_norm.hpp +++ b/include/ginkgo/core/stop/residual_norm.hpp @@ -100,21 +100,47 @@ class ResidualNormBase : EnablePolymorphicObject(exec), device_storage_{exec, 2}, reduction_factor_{reduction_factor}, - baseline_{baseline} + baseline_{baseline}, + system_matrix_{args.system_matrix}, + b_{args.b}, + one_{gko::initialize({1}, exec)}, + neg_one_{gko::initialize({-1}, exec)} { switch (baseline_) { case mode::initial_resnorm: { if (args.initial_residual == nullptr) { - GKO_NOT_SUPPORTED(nullptr); - } - this->starting_tau_ = NormVector::create( - exec, dim<2>{1, args.initial_residual->get_size()[1]}); - if (dynamic_cast(args.initial_residual)) { - auto dense_r = as(args.initial_residual); - dense_r->compute_norm2(this->starting_tau_.get()); + if (args.system_matrix == nullptr || args.b == nullptr || + args.x == nullptr) { + GKO_NOT_SUPPORTED(nullptr); + } else { + this->starting_tau_ = NormVector::create( + exec, dim<2>{1, args.b->get_size()[1]}); + auto b_clone = share(args.b->clone()); + args.system_matrix->apply(neg_one_.get(), args.x, + one_.get(), b_clone.get()); + if (auto vec = + std::dynamic_pointer_cast( + b_clone)) { + vec->compute_norm2(this->starting_tau_.get()); + } else if (auto vec = + std::dynamic_pointer_cast( + b_clone)) { + vec->compute_norm2(this->starting_tau_.get()); + } else { + GKO_NOT_SUPPORTED(nullptr); + } + } } else { - auto dense_r = as(args.initial_residual); - dense_r->compute_norm2(this->starting_tau_.get()); + this->starting_tau_ = NormVector::create( + exec, dim<2>{1, args.initial_residual->get_size()[1]}); + if (dynamic_cast( + args.initial_residual)) { + auto dense_r = as(args.initial_residual); + dense_r->compute_norm2(this->starting_tau_.get()); + } else { + auto dense_r = as(args.initial_residual); + dense_r->compute_norm2(this->starting_tau_.get()); + } } break; } @@ -157,6 +183,11 @@ class ResidualNormBase private: mode baseline_{mode::rhs_norm}; + std::shared_ptr system_matrix_{}; + std::shared_ptr b_{}; + /* one/neg_one for residual computation */ + std::shared_ptr one_{}; + std::shared_ptr neg_one_{}; }; diff --git a/reference/test/stop/residual_norm_kernels.cpp b/reference/test/stop/residual_norm_kernels.cpp index b7676c807ee..78721d0f1ae 100644 --- a/reference/test/stop/residual_norm_kernels.cpp +++ b/reference/test/stop/residual_norm_kernels.cpp @@ -237,6 +237,205 @@ TYPED_TEST(ResidualNorm, WaitsTillResidualGoal) } +TYPED_TEST(ResidualNorm, SelfCalulatesThrowWithoutMatrix) +{ + using Mtx = typename TestFixture::Mtx; + using NormVector = typename TestFixture::NormVector; + using T = TypeParam; + using T_nc = gko::remove_complex; + auto initial_res = gko::initialize({100.0}, this->exec_); + + T rhs_val = 10.0; + std::shared_ptr rhs = + gko::initialize({rhs_val}, this->exec_); + auto rhs_criterion = + this->rhs_factory_->generate(nullptr, rhs, nullptr, initial_res.get()); + auto rel_criterion = + this->rel_factory_->generate(nullptr, rhs, nullptr, initial_res.get()); + auto abs_criterion = + this->abs_factory_->generate(nullptr, rhs, nullptr, initial_res.get()); + { + auto solution = gko::initialize({rhs_val - T{10.0}}, this->exec_); + auto rhs_norm = gko::initialize({100.0}, this->exec_); + gko::as(rhs)->compute_norm2(rhs_norm.get()); + constexpr gko::uint8 RelativeStoppingId{1}; + bool one_changed{}; + gko::Array stop_status(this->exec_, 1); + stop_status.get_data()[0].reset(); + + ASSERT_THROW( + rhs_criterion->update() + .solution(solution.get()) + .check(RelativeStoppingId, true, &stop_status, &one_changed), + gko::NotSupported); + } + { + T initial_norm = 100.0; + auto solution = + gko::initialize({rhs_val - initial_norm}, this->exec_); + constexpr gko::uint8 RelativeStoppingId{1}; + bool one_changed{}; + gko::Array stop_status(this->exec_, 1); + stop_status.get_data()[0].reset(); + + ASSERT_THROW( + rel_criterion->update() + .solution(solution.get()) + .check(RelativeStoppingId, true, &stop_status, &one_changed), + gko::NotSupported); + } + { + auto solution = gko::initialize({rhs_val - T{100.0}}, this->exec_); + constexpr gko::uint8 RelativeStoppingId{1}; + bool one_changed{}; + gko::Array stop_status(this->exec_, 1); + stop_status.get_data()[0].reset(); + + ASSERT_THROW( + abs_criterion->update() + .solution(solution.get()) + .check(RelativeStoppingId, true, &stop_status, &one_changed), + gko::NotSupported); + } +} + + +TYPED_TEST(ResidualNorm, RelativeSelfCalulatesThrowWithoutRhs) +{ + // only relative residual norm allows generation without rhs. + using Mtx = typename TestFixture::Mtx; + using NormVector = typename TestFixture::NormVector; + using T = TypeParam; + using T_nc = gko::remove_complex; + auto initial_res = gko::initialize({100.0}, this->exec_); + + T rhs_val = 10.0; + auto rel_criterion = this->rel_factory_->generate(nullptr, nullptr, nullptr, + initial_res.get()); + T initial_norm = 100.0; + auto solution = gko::initialize({rhs_val - initial_norm}, this->exec_); + constexpr gko::uint8 RelativeStoppingId{1}; + bool one_changed{}; + gko::Array stop_status(this->exec_, 1); + stop_status.get_data()[0].reset(); + + ASSERT_THROW( + rel_criterion->update() + .solution(solution.get()) + .check(RelativeStoppingId, true, &stop_status, &one_changed), + gko::NotSupported); +} + + +TYPED_TEST(ResidualNorm, SelfCalulatesAndWaitsTillResidualGoal) +{ + using Mtx = typename TestFixture::Mtx; + using NormVector = typename TestFixture::NormVector; + using T = TypeParam; + using T_nc = gko::remove_complex; + auto initial_res = gko::initialize({100.0}, this->exec_); + auto system_mtx = share(gko::initialize({1.0}, this->exec_)); + + T rhs_val = 10.0; + std::shared_ptr rhs = + gko::initialize({rhs_val}, this->exec_); + auto rhs_criterion = this->rhs_factory_->generate(system_mtx, rhs, nullptr, + initial_res.get()); + auto rel_criterion = this->rel_factory_->generate(system_mtx, rhs, nullptr, + initial_res.get()); + auto abs_criterion = this->abs_factory_->generate(system_mtx, rhs, nullptr, + initial_res.get()); + { + auto solution = gko::initialize({rhs_val - T{10.0}}, this->exec_); + auto rhs_norm = gko::initialize({100.0}, this->exec_); + gko::as(rhs)->compute_norm2(rhs_norm.get()); + constexpr gko::uint8 RelativeStoppingId{1}; + bool one_changed{}; + gko::Array stop_status(this->exec_, 1); + stop_status.get_data()[0].reset(); + + ASSERT_FALSE( + rhs_criterion->update() + .solution(solution.get()) + .check(RelativeStoppingId, true, &stop_status, &one_changed)); + + solution->at(0) = rhs_val - r::value * T{1.1} * rhs_norm->at(0); + ASSERT_FALSE( + rhs_criterion->update() + .solution(solution.get()) + .check(RelativeStoppingId, true, &stop_status, &one_changed)); + ASSERT_EQ(stop_status.get_data()[0].has_converged(), false); + ASSERT_EQ(one_changed, false); + + solution->at(0) = rhs_val - r::value * T{0.9} * rhs_norm->at(0); + ASSERT_TRUE( + rhs_criterion->update() + .solution(solution.get()) + .check(RelativeStoppingId, true, &stop_status, &one_changed)); + ASSERT_EQ(stop_status.get_data()[0].has_converged(), true); + ASSERT_EQ(one_changed, true); + } + { + T initial_norm = 100.0; + auto solution = + gko::initialize({rhs_val - initial_norm}, this->exec_); + constexpr gko::uint8 RelativeStoppingId{1}; + bool one_changed{}; + gko::Array stop_status(this->exec_, 1); + stop_status.get_data()[0].reset(); + + ASSERT_FALSE( + rel_criterion->update() + .solution(solution.get()) + .check(RelativeStoppingId, true, &stop_status, &one_changed)); + + solution->at(0) = rhs_val - r::value * T{1.1} * initial_norm; + ASSERT_FALSE( + rel_criterion->update() + .solution(solution.get()) + .check(RelativeStoppingId, true, &stop_status, &one_changed)); + ASSERT_EQ(stop_status.get_data()[0].has_converged(), false); + ASSERT_EQ(one_changed, false); + + solution->at(0) = rhs_val - r::value * T{0.9} * initial_norm; + ASSERT_TRUE( + rel_criterion->update() + .solution(solution.get()) + .check(RelativeStoppingId, true, &stop_status, &one_changed)); + ASSERT_EQ(stop_status.get_data()[0].has_converged(), true); + ASSERT_EQ(one_changed, true); + } + { + auto solution = gko::initialize({rhs_val - T{100.0}}, this->exec_); + constexpr gko::uint8 RelativeStoppingId{1}; + bool one_changed{}; + gko::Array stop_status(this->exec_, 1); + stop_status.get_data()[0].reset(); + + ASSERT_FALSE( + abs_criterion->update() + .solution(solution.get()) + .check(RelativeStoppingId, true, &stop_status, &one_changed)); + + solution->at(0) = rhs_val - r::value * T{1.2}; + ASSERT_FALSE( + abs_criterion->update() + .solution(solution.get()) + .check(RelativeStoppingId, true, &stop_status, &one_changed)); + ASSERT_EQ(stop_status.get_data()[0].has_converged(), false); + ASSERT_EQ(one_changed, false); + + solution->at(0) = rhs_val - r::value * T{0.9}; + ASSERT_TRUE( + abs_criterion->update() + .solution(solution.get()) + .check(RelativeStoppingId, true, &stop_status, &one_changed)); + ASSERT_EQ(stop_status.get_data()[0].has_converged(), true); + ASSERT_EQ(one_changed, true); + } +} + + TYPED_TEST(ResidualNorm, WaitsTillResidualGoalMultipleRHS) { using Mtx = typename TestFixture::Mtx; @@ -370,6 +569,20 @@ class ResidualNormReduction : public ::testing::Test { TYPED_TEST_SUITE(ResidualNormReduction, gko::test::ValueTypes); +TYPED_TEST(ResidualNormReduction, + CanCreateCriterionWithMtxRhsXWithoutInitialRes) +{ + using Mtx = typename TestFixture::Mtx; + std::shared_ptr x = gko::initialize({100.0}, this->exec_); + std::shared_ptr mtx = gko::initialize({1.0}, this->exec_); + std::shared_ptr b = gko::initialize({10.0}, this->exec_); + + auto criterion = this->factory_->generate(mtx, b, x.get()); + + ASSERT_NE(criterion, nullptr); +} + + TYPED_TEST(ResidualNormReduction, WaitsTillResidualGoal) { using Mtx = typename TestFixture::Mtx; @@ -407,6 +620,42 @@ TYPED_TEST(ResidualNormReduction, WaitsTillResidualGoal) } +TYPED_TEST(ResidualNormReduction, WaitsTillResidualGoalWithoutInitialRes) +{ + using T = TypeParam; + using Mtx = typename TestFixture::Mtx; + using NormVector = typename TestFixture::NormVector; + T initial_res = 100; + T rhs_val = 10; + std::shared_ptr rhs = + gko::initialize({rhs_val}, this->exec_); + std::shared_ptr x = + gko::initialize({rhs_val - initial_res}, this->exec_); + std::shared_ptr mtx = gko::initialize({1.0}, this->exec_); + + auto criterion = this->factory_->generate(mtx, rhs, x.get()); + bool one_changed{}; + constexpr gko::uint8 RelativeStoppingId{1}; + gko::Array stop_status(this->exec_, 1); + stop_status.get_data()[0].reset(); + + ASSERT_FALSE(criterion->update().solution(x.get()).check( + RelativeStoppingId, true, &stop_status, &one_changed)); + + x->at(0) = rhs_val - r::value * T{1.1} * initial_res; + ASSERT_FALSE(criterion->update().solution(x.get()).check( + RelativeStoppingId, true, &stop_status, &one_changed)); + ASSERT_EQ(stop_status.get_data()[0].has_converged(), false); + ASSERT_EQ(one_changed, false); + + x->at(0) = rhs_val - r::value * T{0.9} * initial_res; + ASSERT_TRUE(criterion->update().solution(x.get()).check( + RelativeStoppingId, true, &stop_status, &one_changed)); + ASSERT_EQ(stop_status.get_data()[0].has_converged(), true); + ASSERT_EQ(one_changed, true); +} + + TYPED_TEST(ResidualNormReduction, WaitsTillResidualGoalMultipleRHS) { using Mtx = typename TestFixture::Mtx;