Skip to content

Commit

Permalink
update minres to new solver base
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcelKoch committed May 12, 2022
1 parent 16fc9af commit 6a991bc
Show file tree
Hide file tree
Showing 6 changed files with 70 additions and 116 deletions.
6 changes: 3 additions & 3 deletions common/unified/solver/minres_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ void initialize(
matrix::Dense<ValueType>* cos_prev, matrix::Dense<ValueType>* cos,
matrix::Dense<ValueType>* sin_prev, matrix::Dense<ValueType>* sin,
matrix::Dense<ValueType>* eta_next, matrix::Dense<ValueType>* eta,
Array<stopping_status>* stop_status)
array<stopping_status>* stop_status)
{
run_kernel(
exec,
Expand Down Expand Up @@ -137,7 +137,7 @@ void step_1(std::shared_ptr<const DefaultExecutor> exec,
matrix::Dense<ValueType>* sin_prev, matrix::Dense<ValueType>* sin,
matrix::Dense<ValueType>* eta, matrix::Dense<ValueType>* eta_next,
typename matrix::Dense<ValueType>::absolute_type* tau,
const Array<stopping_status>* stop_status)
const array<stopping_status>* stop_status)
{
run_kernel(
exec,
Expand Down Expand Up @@ -183,7 +183,7 @@ void step_2(std::shared_ptr<const DefaultExecutor> exec,
matrix::Dense<ValueType>* beta, matrix::Dense<ValueType>* gamma,
matrix::Dense<ValueType>* delta, matrix::Dense<ValueType>* cos,
matrix::Dense<ValueType>* eta,
const Array<stopping_status>* stop_status)
const array<stopping_status>* stop_status)
{
run_kernel_solver(
exec,
Expand Down
20 changes: 11 additions & 9 deletions core/solver/minres.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ std::unique_ptr<LinOp> Minres<ValueType>::transpose() const
return build()
.with_generated_preconditioner(
share(as<Transposable>(this->get_preconditioner())->transpose()))
.with_criteria(this->stop_criterion_factory_)
.with_criteria(this->get_stop_criterion_factory())
.on(this->get_executor())
->generate(
share(as<Transposable>(this->get_system_matrix())->transpose()));
Expand All @@ -77,7 +77,7 @@ std::unique_ptr<LinOp> Minres<ValueType>::conj_transpose() const
return build()
.with_generated_preconditioner(share(
as<Transposable>(this->get_preconditioner())->conj_transpose()))
.with_criteria(this->stop_criterion_factory_)
.with_criteria(this->get_stop_criterion_factory())
.on(this->get_executor())
->generate(share(
as<Transposable>(this->get_system_matrix())->conj_transpose()));
Expand Down Expand Up @@ -157,21 +157,22 @@ void Minres<ValueType>::apply_dense_impl(
auto sin = Vector::create_with_config_of(alpha.get());

bool one_changed{};
Array<stopping_status> stop_status(alpha->get_executor(),
array<stopping_status> stop_status(alpha->get_executor(),
dense_b->get_size()[1]);

// r = dense_b
r = gko::clone(dense_b);
system_matrix_->apply(neg_one_op.get(), dense_x, one_op.get(), r.get());
auto stop_criterion = stop_criterion_factory_->generate(
system_matrix_,
this->get_system_matrix()->apply(neg_one_op.get(), dense_x, one_op.get(),
r.get());
auto stop_criterion = this->get_stop_criterion_factory()->generate(
this->get_system_matrix(),
std::shared_ptr<const LinOp>(dense_b, [](const LinOp*) {}), dense_x,
r.get());

// z = M^-1 * r
// beta = <r, z>
// tau = ||z||_2
get_preconditioner()->apply(r.get(), z.get());
this->get_preconditioner()->apply(r.get(), z.get());
r->compute_conj_dot(z.get(), beta.get());
z->compute_norm2(tau.get());

Expand Down Expand Up @@ -214,10 +215,11 @@ void Minres<ValueType>::apply_dense_impl(
// v = v - alpha * q
// z_tilde = M * v
// beta = <v, z_tilde>
system_matrix_->apply(one_op.get(), z.get(), neg_one_op.get(), v.get());
this->get_system_matrix()->apply(one_op.get(), z.get(),
neg_one_op.get(), v.get());
v->compute_conj_dot(z.get(), alpha.get());
v->sub_scaled(alpha.get(), q.get());
get_preconditioner()->apply(v.get(), z_tilde.get());
this->get_preconditioner()->apply(v.get(), z_tilde.get());
v->compute_conj_dot(z_tilde.get(), beta.get());

// Updates scalars (row vectors)
Expand Down
6 changes: 3 additions & 3 deletions core/solver/minres_kernels.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ namespace cg {
matrix::Dense<_type>* cos_prev, matrix::Dense<_type>* cos, \
matrix::Dense<_type>* sin_prev, matrix::Dense<_type>* sin, \
matrix::Dense<_type>* eta_next, matrix::Dense<_type>* eta, \
Array<stopping_status>* stop_status)
array<stopping_status>* stop_status)


#define GKO_DECLARE_MINRES_STEP_1_KERNEL(_type) \
Expand All @@ -73,7 +73,7 @@ namespace cg {
matrix::Dense<_type>* sin_prev, matrix::Dense<_type>* sin, \
matrix::Dense<_type>* eta, matrix::Dense<_type>* eta_next, \
typename matrix::Dense<_type>::absolute_type* tau, \
const Array<stopping_status>* stop_status)
const array<stopping_status>* stop_status)

#define GKO_DECLARE_MINRES_STEP_2_KERNEL(_type) \
void step_2(std::shared_ptr<const DefaultExecutor> exec, \
Expand All @@ -84,7 +84,7 @@ namespace cg {
matrix::Dense<_type>* alpha, matrix::Dense<_type>* beta, \
matrix::Dense<_type>* gamma, matrix::Dense<_type>* delta, \
matrix::Dense<_type>* cos, matrix::Dense<_type>* eta, \
const Array<stopping_status>* stop_status)
const array<stopping_status>* stop_status)


#define GKO_DECLARE_ALL_AS_TEMPLATES \
Expand Down
68 changes: 9 additions & 59 deletions include/ginkgo/core/solver/minres.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include <ginkgo/core/log/logger.hpp>
#include <ginkgo/core/matrix/dense.hpp>
#include <ginkgo/core/matrix/identity.hpp>
#include <ginkgo/core/solver/solver_base.hpp>
#include <ginkgo/core/stop/combined.hpp>
#include <ginkgo/core/stop/criterion.hpp>

Expand All @@ -69,26 +70,17 @@ namespace solver {
* @ingroup LinOp
*/
template <typename ValueType = default_precision>
class Minres : public EnableLinOp<Minres<ValueType>>,
public Preconditionable,
public Transposable {
class Minres
: public EnableLinOp<Minres<ValueType>>,
public EnablePreconditionedIterativeSolver<ValueType, Minres<ValueType>>,
public Transposable {
friend class EnableLinOp<Minres>;
friend class EnablePolymorphicObject<Minres, LinOp>;

public:
using value_type = ValueType;
using transposed_type = Minres<ValueType>;

/**
* Gets the system operator (matrix) of the linear system.
*
* @return the system operator (matrix)
*/
std::shared_ptr<const LinOp> get_system_matrix() const
{
return system_matrix_;
}

std::unique_ptr<LinOp> transpose() const override;

std::unique_ptr<LinOp> conj_transpose() const override;
Expand All @@ -100,28 +92,6 @@ class Minres : public EnableLinOp<Minres<ValueType>>,
*/
bool apply_uses_initial_guess() const override { return true; }

/**
* Gets the stopping criterion factory of the solver.
*
* @return the stopping criterion factory
*/
std::shared_ptr<const stop::CriterionFactory> get_stop_criterion_factory()
const
{
return stop_criterion_factory_;
}

/**
* Sets the stopping criterion of the solver.
*
* @param other the new stopping criterion factory
*/
void set_stop_criterion_factory(
std::shared_ptr<const stop::CriterionFactory> other)
{
stop_criterion_factory_ = std::move(other);
}

GKO_CREATE_FACTORY_PARAMETERS(parameters, Factory)
{
/**
Expand Down Expand Up @@ -163,30 +133,10 @@ class Minres : public EnableLinOp<Minres<ValueType>>,
std::shared_ptr<const LinOp> system_matrix)
: EnableLinOp<Minres>(factory->get_executor(),
gko::transpose(system_matrix->get_size())),
parameters_{factory->get_parameters()},
system_matrix_{std::move(system_matrix)}
{
GKO_ASSERT_IS_SQUARE_MATRIX(system_matrix_);
if (parameters_.generated_preconditioner) {
GKO_ASSERT_EQUAL_DIMENSIONS(parameters_.generated_preconditioner,
this);
Preconditionable::set_preconditioner(
parameters_.generated_preconditioner);
} else if (parameters_.preconditioner) {
Preconditionable::set_preconditioner(
parameters_.preconditioner->generate(system_matrix_));
} else {
Preconditionable::set_preconditioner(
matrix::Identity<ValueType>::create(this->get_executor(),
this->get_size()));
}
stop_criterion_factory_ =
stop::combine(std::move(parameters_.criteria));
}

private:
std::shared_ptr<const LinOp> system_matrix_{};
std::shared_ptr<const stop::CriterionFactory> stop_criterion_factory_{};
EnablePreconditionedIterativeSolver<ValueType, Minres<ValueType>>{
std::move(system_matrix), factory->get_parameters()},
parameters_{factory->get_parameters()}
{}
};


Expand Down
6 changes: 3 additions & 3 deletions reference/solver/minres_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ void initialize(
matrix::Dense<ValueType>* cos_prev, matrix::Dense<ValueType>* cos,
matrix::Dense<ValueType>* sin_prev, matrix::Dense<ValueType>* sin,
matrix::Dense<ValueType>* eta_next, matrix::Dense<ValueType>* eta,
Array<stopping_status>* stop_status)
array<stopping_status>* stop_status)
{
for (size_type j = 0; j < r->get_size()[1]; ++j) {
delta->at(j) = gamma->at(j) = cos_prev->at(j) = sin_prev->at(j) =
Expand Down Expand Up @@ -110,7 +110,7 @@ void step_1(std::shared_ptr<const DefaultExecutor> exec,
matrix::Dense<ValueType>* sin_prev, matrix::Dense<ValueType>* sin,
matrix::Dense<ValueType>* eta, matrix::Dense<ValueType>* eta_next,
typename matrix::Dense<ValueType>::absolute_type* tau,
const Array<stopping_status>* stop_status)
const array<stopping_status>* stop_status)
{
for (size_type j = 0; j < alpha->get_size()[1]; ++j) {
if (stop_status->get_const_data()[j].has_stopped()) {
Expand Down Expand Up @@ -149,7 +149,7 @@ void step_2(std::shared_ptr<const DefaultExecutor> exec,
matrix::Dense<ValueType>* beta, matrix::Dense<ValueType>* gamma,
matrix::Dense<ValueType>* delta, matrix::Dense<ValueType>* cos,
matrix::Dense<ValueType>* eta,
const Array<stopping_status>* stop_status)
const array<stopping_status>* stop_status)
{
for (size_type i = 0; i < x->get_size()[0]; ++i) {
for (size_type j = 0; j < x->get_size()[1]; ++j) {
Expand Down
80 changes: 41 additions & 39 deletions test/solver/minres_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

#include "core/solver/minres_kernels.hpp"
#include "core/test/utils.hpp"
#include "core/test/utils/matrix_utils.hpp"
#include "core/utils/matrix_utils.hpp"
#include "test/utils/executor.hpp"

namespace {
Expand Down Expand Up @@ -80,14 +80,19 @@ class Minres : public ::testing::Test {
}

std::unique_ptr<Mtx> gen_mtx(gko::size_type num_rows,
gko::size_type num_cols, gko::size_type stride)
gko::size_type num_cols, gko::size_type stride,
bool make_hermitian)
{
auto tmp_mtx = gko::test::generate_random_matrix<Mtx>(
auto tmp_mtx = gko::test::generate_random_matrix_data<value_type , gko::int32>(
num_rows, num_cols,
std::uniform_int_distribution<>(num_cols, num_cols),
std::normal_distribution<value_type>(-1.0, 1.0), rand_engine, ref);
std::normal_distribution<value_type>(-1.0, 1.0), rand_engine);
if (make_hermitian) {
gko::utils::make_unit_diagonal(tmp_mtx);
gko::utils::make_hermitian(tmp_mtx);
}
auto result = Mtx::create(ref, gko::dim<2>{num_rows, num_cols}, stride);
result->copy_from(tmp_mtx.get());
result->read(tmp_mtx);
return result;
}

Expand All @@ -96,31 +101,31 @@ class Minres : public ::testing::Test {
gko::size_type m = 597;
gko::size_type n = 43;
// all vectors need the same stride as b, except x
b = gen_mtx(m, n, n + 2);
r = gen_mtx(m, n, n + 2);
z = gen_mtx(m, n, n + 2);
z_tilde = gen_mtx(m, n, n + 2);
p = gen_mtx(m, n, n + 2);
p_prev = gen_mtx(m, n, n + 2);
q = gen_mtx(m, n, n + 2);
q_prev = gen_mtx(m, n, n + 2);
v = gen_mtx(m, n, n + 2);
x = gen_mtx(m, n, n + 3);
alpha = gen_mtx(1, n, n);
beta = gen_mtx(1, n, n)->compute_absolute();
gamma = gen_mtx(1, n, n);
delta = gen_mtx(1, n, n);
cos_prev = gen_mtx(1, n, n);
cos = gen_mtx(1, n, n);
sin_prev = gen_mtx(1, n, n);
sin = gen_mtx(1, n, n);
eta_next = gen_mtx(1, n, n);
eta = gen_mtx(1, n, n);
tau = gen_mtx(1, n, n)->compute_absolute();
b = gen_mtx(m, n, n + 2, false);
r = gen_mtx(m, n, n + 2, false);
z = gen_mtx(m, n, n + 2, false);
z_tilde = gen_mtx(m, n, n + 2, false);
p = gen_mtx(m, n, n + 2, false);
p_prev = gen_mtx(m, n, n + 2, false);
q = gen_mtx(m, n, n + 2, false);
q_prev = gen_mtx(m, n, n + 2, false);
v = gen_mtx(m, n, n + 2, false);
x = gen_mtx(m, n, n + 3, false);
alpha = gen_mtx(1, n, n, false);
beta = gen_mtx(1, n, n, false)->compute_absolute();
gamma = gen_mtx(1, n, n, false);
delta = gen_mtx(1, n, n, false);
cos_prev = gen_mtx(1, n, n, false);
cos = gen_mtx(1, n, n, false);
sin_prev = gen_mtx(1, n, n, false);
sin = gen_mtx(1, n, n, false);
eta_next = gen_mtx(1, n, n, false);
eta = gen_mtx(1, n, n, false);
tau = gen_mtx(1, n, n, false)->compute_absolute();
// check correct handling for zero values
beta->at(2) = gko::zero<value_type>();
stop_status =
std::make_unique<gko::Array<gko::stopping_status>>(ref, n);
std::make_unique<gko::array<gko::stopping_status>>(ref, n);
for (size_t i = 0; i < stop_status->get_num_elems(); ++i) {
stop_status->get_data()[i].reset();
}
Expand Down Expand Up @@ -148,7 +153,7 @@ class Minres : public ::testing::Test {
d_cos = gko::clone(exec, cos);
d_sin_prev = gko::clone(exec, sin_prev);
d_sin = gko::clone(exec, sin);
d_stop_status = std::make_unique<gko::Array<gko::stopping_status>>(
d_stop_status = std::make_unique<gko::array<gko::stopping_status>>(
exec, *stop_status);
}

Expand Down Expand Up @@ -203,8 +208,8 @@ class Minres : public ::testing::Test {
std::unique_ptr<Mtx> d_sin_prev;
std::unique_ptr<Mtx> d_sin;

std::unique_ptr<gko::Array<gko::stopping_status>> stop_status;
std::unique_ptr<gko::Array<gko::stopping_status>> d_stop_status;
std::unique_ptr<gko::array<gko::stopping_status>> stop_status;
std::unique_ptr<gko::array<gko::stopping_status>> d_stop_status;
};


Expand Down Expand Up @@ -296,10 +301,9 @@ TEST_F(Minres, MinresStep2IsEquivalentToStep2)

TEST_F(Minres, ApplyIsEquivalentToRef)
{
auto mtx = gen_mtx(50, 50, 53);
gko::test::make_hermitian(mtx.get());
auto x = gen_mtx(50, 1, 5);
auto b = gen_mtx(50, 1, 4);
auto mtx = gen_mtx(50, 50, 53, true);
auto x = gen_mtx(50, 1, 5, false);
auto b = gen_mtx(50, 1, 4, false);
auto d_mtx = gko::clone(exec, mtx);
auto d_x = gko::clone(exec, x);
auto d_b = gko::clone(exec, b);
Expand Down Expand Up @@ -331,11 +335,9 @@ TEST_F(Minres, ApplyIsEquivalentToRef)

TEST_F(Minres, PreconditionedApplyIsEquivalentToRef)
{

auto mtx = gen_mtx(50, 50, 53);
gko::test::make_hpd(mtx.get());
auto x = gen_mtx(50, 1, 5);
auto b = gen_mtx(50, 1, 4);
auto mtx = gen_mtx(50, 50, 53, true);
auto x = gen_mtx(50, 1, 5, false);
auto b = gen_mtx(50, 1, 4, false);
auto d_mtx = gko::clone(exec, mtx);
auto d_x = gko::clone(exec, x);
auto d_b = gko::clone(exec, b);
Expand Down

0 comments on commit 6a991bc

Please sign in to comment.