Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ResidualNorm stop can compute the needed info in gen/check #818

Merged
merged 2 commits into from
Jul 15, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions core/stop/residual_norm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,26 @@ bool ResidualNormBase<ValueType>::check_impl(
dense_r->compute_norm2(u_dense_tau_.get());
}
dense_tau = u_dense_tau_.get();
} else if (updater.solution_ != nullptr && system_matrix_ != nullptr &&
b_ != nullptr) {
auto exec = this->get_executor();
// when LinOp is real but rhs is complex, we use real view on complex,
// so it still uses the same type of scalar in apply.
if (auto vec_b = std::dynamic_pointer_cast<const Vector>(b_)) {
auto dense_r = vec_b->clone();
system_matrix_->apply(neg_one_.get(), updater.solution_, one_.get(),
dense_r.get());
dense_r->compute_norm2(u_dense_tau_.get());
} else if (auto vec_b =
std::dynamic_pointer_cast<const ComplexVector>(b_)) {
auto dense_r = vec_b->clone();
system_matrix_->apply(neg_one_.get(), updater.solution_, one_.get(),
dense_r.get());
dense_r->compute_norm2(u_dense_tau_.get());
} else {
GKO_NOT_SUPPORTED(nullptr);
}
dense_tau = u_dense_tau_.get();
} else {
GKO_NOT_SUPPORTED(nullptr);
}
Expand Down
51 changes: 41 additions & 10 deletions include/ginkgo/core/stop/residual_norm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,21 +100,47 @@ class ResidualNormBase
: EnablePolymorphicObject<ResidualNormBase, Criterion>(exec),
device_storage_{exec, 2},
reduction_factor_{reduction_factor},
baseline_{baseline}
baseline_{baseline},
system_matrix_{args.system_matrix},
b_{args.b},
one_{gko::initialize<Vector>({1}, exec)},
neg_one_{gko::initialize<Vector>({-1}, exec)}
{
switch (baseline_) {
case mode::initial_resnorm: {
if (args.initial_residual == nullptr) {
GKO_NOT_SUPPORTED(nullptr);
}
this->starting_tau_ = NormVector::create(
exec, dim<2>{1, args.initial_residual->get_size()[1]});
if (dynamic_cast<const ComplexVector *>(args.initial_residual)) {
auto dense_r = as<ComplexVector>(args.initial_residual);
dense_r->compute_norm2(this->starting_tau_.get());
if (args.system_matrix == nullptr || args.b == nullptr ||
args.x == nullptr) {
GKO_NOT_SUPPORTED(nullptr);
} else {
this->starting_tau_ = NormVector::create(
exec, dim<2>{1, args.b->get_size()[1]});
auto b_clone = share(args.b->clone());
args.system_matrix->apply(neg_one_.get(), args.x,
one_.get(), b_clone.get());
if (auto vec =
std::dynamic_pointer_cast<const ComplexVector>(
b_clone)) {
vec->compute_norm2(this->starting_tau_.get());
} else if (auto vec =
std::dynamic_pointer_cast<const Vector>(
b_clone)) {
vec->compute_norm2(this->starting_tau_.get());
} else {
GKO_NOT_SUPPORTED(nullptr);
}
}
} else {
auto dense_r = as<Vector>(args.initial_residual);
dense_r->compute_norm2(this->starting_tau_.get());
this->starting_tau_ = NormVector::create(
exec, dim<2>{1, args.initial_residual->get_size()[1]});
if (dynamic_cast<const ComplexVector *>(
args.initial_residual)) {
auto dense_r = as<ComplexVector>(args.initial_residual);
dense_r->compute_norm2(this->starting_tau_.get());
} else {
auto dense_r = as<Vector>(args.initial_residual);
dense_r->compute_norm2(this->starting_tau_.get());
}
}
break;
}
Expand Down Expand Up @@ -157,6 +183,11 @@ class ResidualNormBase

private:
mode baseline_{mode::rhs_norm};
std::shared_ptr<const LinOp> system_matrix_{};
std::shared_ptr<const LinOp> b_{};
/* one/neg_one for residual computation */
std::shared_ptr<const Vector> one_{};
std::shared_ptr<const Vector> neg_one_{};
};


Expand Down
249 changes: 249 additions & 0 deletions reference/test/stop/residual_norm_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,205 @@ TYPED_TEST(ResidualNorm, WaitsTillResidualGoal)
}


TYPED_TEST(ResidualNorm, SelfCalulatesThrowWithoutMatrix)
{
using Mtx = typename TestFixture::Mtx;
using NormVector = typename TestFixture::NormVector;
using T = TypeParam;
using T_nc = gko::remove_complex<TypeParam>;
auto initial_res = gko::initialize<Mtx>({100.0}, this->exec_);

T rhs_val = 10.0;
std::shared_ptr<gko::LinOp> rhs =
gko::initialize<Mtx>({rhs_val}, this->exec_);
auto rhs_criterion =
this->rhs_factory_->generate(nullptr, rhs, nullptr, initial_res.get());
auto rel_criterion =
this->rel_factory_->generate(nullptr, rhs, nullptr, initial_res.get());
auto abs_criterion =
this->abs_factory_->generate(nullptr, rhs, nullptr, initial_res.get());
{
auto solution = gko::initialize<Mtx>({rhs_val - T{10.0}}, this->exec_);
auto rhs_norm = gko::initialize<NormVector>({100.0}, this->exec_);
gko::as<Mtx>(rhs)->compute_norm2(rhs_norm.get());
constexpr gko::uint8 RelativeStoppingId{1};
bool one_changed{};
gko::Array<gko::stopping_status> stop_status(this->exec_, 1);
stop_status.get_data()[0].reset();

ASSERT_THROW(
rhs_criterion->update()
.solution(solution.get())
.check(RelativeStoppingId, true, &stop_status, &one_changed),
gko::NotSupported);
}
{
T initial_norm = 100.0;
auto solution =
gko::initialize<Mtx>({rhs_val - initial_norm}, this->exec_);
constexpr gko::uint8 RelativeStoppingId{1};
bool one_changed{};
gko::Array<gko::stopping_status> stop_status(this->exec_, 1);
stop_status.get_data()[0].reset();

ASSERT_THROW(
rel_criterion->update()
.solution(solution.get())
.check(RelativeStoppingId, true, &stop_status, &one_changed),
gko::NotSupported);
}
{
auto solution = gko::initialize<Mtx>({rhs_val - T{100.0}}, this->exec_);
constexpr gko::uint8 RelativeStoppingId{1};
bool one_changed{};
gko::Array<gko::stopping_status> stop_status(this->exec_, 1);
stop_status.get_data()[0].reset();

ASSERT_THROW(
abs_criterion->update()
.solution(solution.get())
.check(RelativeStoppingId, true, &stop_status, &one_changed),
gko::NotSupported);
}
}


TYPED_TEST(ResidualNorm, RelativeSelfCalulatesThrowWithoutRhs)
{
// only relative residual norm allows generation without rhs.
using Mtx = typename TestFixture::Mtx;
using NormVector = typename TestFixture::NormVector;
using T = TypeParam;
using T_nc = gko::remove_complex<TypeParam>;
auto initial_res = gko::initialize<Mtx>({100.0}, this->exec_);

T rhs_val = 10.0;
auto rel_criterion = this->rel_factory_->generate(nullptr, nullptr, nullptr,
initial_res.get());
T initial_norm = 100.0;
auto solution = gko::initialize<Mtx>({rhs_val - initial_norm}, this->exec_);
constexpr gko::uint8 RelativeStoppingId{1};
bool one_changed{};
gko::Array<gko::stopping_status> stop_status(this->exec_, 1);
stop_status.get_data()[0].reset();

ASSERT_THROW(
rel_criterion->update()
.solution(solution.get())
.check(RelativeStoppingId, true, &stop_status, &one_changed),
gko::NotSupported);
}


TYPED_TEST(ResidualNorm, SelfCalulatesAndWaitsTillResidualGoal)
{
using Mtx = typename TestFixture::Mtx;
using NormVector = typename TestFixture::NormVector;
using T = TypeParam;
using T_nc = gko::remove_complex<TypeParam>;
auto initial_res = gko::initialize<Mtx>({100.0}, this->exec_);
auto system_mtx = share(gko::initialize<Mtx>({1.0}, this->exec_));

T rhs_val = 10.0;
std::shared_ptr<gko::LinOp> rhs =
gko::initialize<Mtx>({rhs_val}, this->exec_);
auto rhs_criterion = this->rhs_factory_->generate(system_mtx, rhs, nullptr,
initial_res.get());
auto rel_criterion = this->rel_factory_->generate(system_mtx, rhs, nullptr,
initial_res.get());
auto abs_criterion = this->abs_factory_->generate(system_mtx, rhs, nullptr,
initial_res.get());
{
auto solution = gko::initialize<Mtx>({rhs_val - T{10.0}}, this->exec_);
auto rhs_norm = gko::initialize<NormVector>({100.0}, this->exec_);
gko::as<Mtx>(rhs)->compute_norm2(rhs_norm.get());
constexpr gko::uint8 RelativeStoppingId{1};
bool one_changed{};
gko::Array<gko::stopping_status> stop_status(this->exec_, 1);
stop_status.get_data()[0].reset();

ASSERT_FALSE(
rhs_criterion->update()
.solution(solution.get())
.check(RelativeStoppingId, true, &stop_status, &one_changed));

solution->at(0) = rhs_val - r<T>::value * T{1.1} * rhs_norm->at(0);
ASSERT_FALSE(
rhs_criterion->update()
.solution(solution.get())
.check(RelativeStoppingId, true, &stop_status, &one_changed));
ASSERT_EQ(stop_status.get_data()[0].has_converged(), false);
ASSERT_EQ(one_changed, false);

solution->at(0) = rhs_val - r<T>::value * T{0.9} * rhs_norm->at(0);
ASSERT_TRUE(
rhs_criterion->update()
.solution(solution.get())
.check(RelativeStoppingId, true, &stop_status, &one_changed));
ASSERT_EQ(stop_status.get_data()[0].has_converged(), true);
ASSERT_EQ(one_changed, true);
}
{
T initial_norm = 100.0;
auto solution =
gko::initialize<Mtx>({rhs_val - initial_norm}, this->exec_);
constexpr gko::uint8 RelativeStoppingId{1};
bool one_changed{};
gko::Array<gko::stopping_status> stop_status(this->exec_, 1);
stop_status.get_data()[0].reset();

ASSERT_FALSE(
rel_criterion->update()
.solution(solution.get())
.check(RelativeStoppingId, true, &stop_status, &one_changed));

solution->at(0) = rhs_val - r<T>::value * T{1.1} * initial_norm;
ASSERT_FALSE(
rel_criterion->update()
.solution(solution.get())
.check(RelativeStoppingId, true, &stop_status, &one_changed));
ASSERT_EQ(stop_status.get_data()[0].has_converged(), false);
ASSERT_EQ(one_changed, false);

solution->at(0) = rhs_val - r<T>::value * T{0.9} * initial_norm;
ASSERT_TRUE(
rel_criterion->update()
.solution(solution.get())
.check(RelativeStoppingId, true, &stop_status, &one_changed));
ASSERT_EQ(stop_status.get_data()[0].has_converged(), true);
ASSERT_EQ(one_changed, true);
}
{
auto solution = gko::initialize<Mtx>({rhs_val - T{100.0}}, this->exec_);
constexpr gko::uint8 RelativeStoppingId{1};
bool one_changed{};
gko::Array<gko::stopping_status> stop_status(this->exec_, 1);
stop_status.get_data()[0].reset();

ASSERT_FALSE(
abs_criterion->update()
.solution(solution.get())
.check(RelativeStoppingId, true, &stop_status, &one_changed));

solution->at(0) = rhs_val - r<T>::value * T{1.2};
ASSERT_FALSE(
abs_criterion->update()
.solution(solution.get())
.check(RelativeStoppingId, true, &stop_status, &one_changed));
ASSERT_EQ(stop_status.get_data()[0].has_converged(), false);
ASSERT_EQ(one_changed, false);

solution->at(0) = rhs_val - r<T>::value * T{0.9};
ASSERT_TRUE(
abs_criterion->update()
.solution(solution.get())
.check(RelativeStoppingId, true, &stop_status, &one_changed));
ASSERT_EQ(stop_status.get_data()[0].has_converged(), true);
ASSERT_EQ(one_changed, true);
}
}


TYPED_TEST(ResidualNorm, WaitsTillResidualGoalMultipleRHS)
{
using Mtx = typename TestFixture::Mtx;
Expand Down Expand Up @@ -370,6 +569,20 @@ class ResidualNormReduction : public ::testing::Test {
TYPED_TEST_SUITE(ResidualNormReduction, gko::test::ValueTypes);


TYPED_TEST(ResidualNormReduction,
CanCreateCriterionWithMtxRhsXWithoutInitialRes)
{
using Mtx = typename TestFixture::Mtx;
std::shared_ptr<gko::LinOp> x = gko::initialize<Mtx>({100.0}, this->exec_);
std::shared_ptr<gko::LinOp> mtx = gko::initialize<Mtx>({1.0}, this->exec_);
std::shared_ptr<gko::LinOp> b = gko::initialize<Mtx>({10.0}, this->exec_);

auto criterion = this->factory_->generate(mtx, b, x.get());

ASSERT_NE(criterion, nullptr);
}


TYPED_TEST(ResidualNormReduction, WaitsTillResidualGoal)
{
using Mtx = typename TestFixture::Mtx;
Expand Down Expand Up @@ -407,6 +620,42 @@ TYPED_TEST(ResidualNormReduction, WaitsTillResidualGoal)
}


TYPED_TEST(ResidualNormReduction, WaitsTillResidualGoalWithoutInitialRes)
{
using T = TypeParam;
using Mtx = typename TestFixture::Mtx;
using NormVector = typename TestFixture::NormVector;
T initial_res = 100;
T rhs_val = 10;
std::shared_ptr<gko::LinOp> rhs =
gko::initialize<Mtx>({rhs_val}, this->exec_);
std::shared_ptr<Mtx> x =
gko::initialize<Mtx>({rhs_val - initial_res}, this->exec_);
std::shared_ptr<gko::LinOp> mtx = gko::initialize<Mtx>({1.0}, this->exec_);

auto criterion = this->factory_->generate(mtx, rhs, x.get());
bool one_changed{};
constexpr gko::uint8 RelativeStoppingId{1};
gko::Array<gko::stopping_status> stop_status(this->exec_, 1);
stop_status.get_data()[0].reset();

ASSERT_FALSE(criterion->update().solution(x.get()).check(
RelativeStoppingId, true, &stop_status, &one_changed));

x->at(0) = rhs_val - r<T>::value * T{1.1} * initial_res;
ASSERT_FALSE(criterion->update().solution(x.get()).check(
RelativeStoppingId, true, &stop_status, &one_changed));
ASSERT_EQ(stop_status.get_data()[0].has_converged(), false);
ASSERT_EQ(one_changed, false);

x->at(0) = rhs_val - r<T>::value * T{0.9} * initial_res;
ASSERT_TRUE(criterion->update().solution(x.get()).check(
RelativeStoppingId, true, &stop_status, &one_changed));
ASSERT_EQ(stop_status.get_data()[0].has_converged(), true);
ASSERT_EQ(one_changed, true);
}


TYPED_TEST(ResidualNormReduction, WaitsTillResidualGoalMultipleRHS)
{
using Mtx = typename TestFixture::Mtx;
Expand Down