Skip to content

Commit

Permalink
use preconditioner not inner_solver for chebyshev
Browse files Browse the repository at this point in the history
Co-authored-by: Marcel Koch <marcel.koch@kit.edu>
  • Loading branch information
yhmtsai and MarcelKoch committed Aug 8, 2023
1 parent f9d0e28 commit 6a66ee5
Show file tree
Hide file tree
Showing 6 changed files with 63 additions and 87 deletions.
59 changes: 18 additions & 41 deletions core/solver/chebyshev.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,20 +63,10 @@ Chebyshev<ValueType>::Chebyshev(const Factory* factory,
std::shared_ptr<const LinOp> system_matrix)
: EnableLinOp<Chebyshev>(factory->get_executor(),
gko::transpose(system_matrix->get_size())),
EnableSolverBase<Chebyshev>{std::move(system_matrix)},
EnableIterativeBase<Chebyshev>{
stop::combine(factory->get_parameters().criteria)},
EnablePreconditionedIterativeSolver<ValueType, Chebyshev<ValueType>>{
std::move(system_matrix), factory->get_parameters()},
parameters_{factory->get_parameters()}
{
if (parameters_.generated_solver) {
this->set_solver(parameters_.generated_solver);
} else if (parameters_.solver) {
this->set_solver(
parameters_.solver->generate(this->get_system_matrix()));
} else {
this->set_solver(matrix::Identity<ValueType>::create(
this->get_executor(), this->get_size()));
}
this->set_default_initial_guess(parameters_.default_initial_guess);
center_ = (std::get<0>(parameters_.foci) + std::get<1>(parameters_.foci)) /
ValueType{2};
Expand All @@ -89,30 +79,17 @@ Chebyshev<ValueType>::Chebyshev(const Factory* factory,
}


template <typename ValueType>
void Chebyshev<ValueType>::set_solver(std::shared_ptr<const LinOp> new_solver)
{
auto exec = this->get_executor();
if (new_solver) {
GKO_ASSERT_EQUAL_DIMENSIONS(new_solver, this);
GKO_ASSERT_IS_SQUARE_MATRIX(new_solver);
if (new_solver->get_executor() != exec) {
new_solver = gko::clone(exec, new_solver);
}
}
solver_ = new_solver;
}


template <typename ValueType>
Chebyshev<ValueType>& Chebyshev<ValueType>::operator=(const Chebyshev& other)
{
if (&other != this) {
EnableLinOp<Chebyshev>::operator=(other);
EnableSolverBase<Chebyshev>::operator=(other);
EnableIterativeBase<Chebyshev>::operator=(other);
EnablePreconditionedIterativeSolver<
ValueType, Chebyshev<ValueType>>::operator=(other);
this->parameters_ = other.parameters_;
this->set_solver(other.get_solver());
// the workspace is not copied.
this->num_generated_scalar_ = 0;
this->num_max_generation_ = 3;
}
return *this;
}
Expand All @@ -123,11 +100,11 @@ Chebyshev<ValueType>& Chebyshev<ValueType>::operator=(Chebyshev&& other)
{
if (&other != this) {
EnableLinOp<Chebyshev>::operator=(std::move(other));
EnableSolverBase<Chebyshev>::operator=(std::move(other));
EnableIterativeBase<Chebyshev>::operator=(std::move(other));
this->parameters_ = std::exchange(other.parameters_, parameters_type{});
this->set_solver(other.get_solver());
other.set_solver(nullptr);
EnablePreconditionedIterativeSolver<
ValueType, Chebyshev<ValueType>>::operator=(std::move(other));
// the workspace is not moved.
this->num_generated_scalar_ = 0;
this->num_max_generation_ = 3;
}
return *this;
}
Expand All @@ -153,8 +130,8 @@ template <typename ValueType>
std::unique_ptr<LinOp> Chebyshev<ValueType>::transpose() const
{
return build()
.with_generated_solver(
share(as<Transposable>(this->get_solver())->transpose()))
.with_generated_preconditioner(
share(as<Transposable>(this->get_preconditioner())->transpose()))
.with_criteria(this->get_stop_criterion_factory())
.with_foci(parameters_.foci)
.on(this->get_executor())
Expand All @@ -167,8 +144,8 @@ template <typename ValueType>
std::unique_ptr<LinOp> Chebyshev<ValueType>::conj_transpose() const
{
return build()
.with_generated_solver(
share(as<Transposable>(this->get_solver())->conj_transpose()))
.with_generated_preconditioner(share(
as<Transposable>(this->get_preconditioner())->conj_transpose()))
.with_criteria(this->get_stop_criterion_factory())
.with_foci(conj(std::get<0>(parameters_.foci)),
conj(std::get<1>(parameters_.foci)))
Expand Down Expand Up @@ -283,13 +260,13 @@ void Chebyshev<ValueType>::apply_dense_impl(const VectorType* dense_b,
break;
}

if (solver_->apply_uses_initial_guess()) {
if (this->get_preconditioner()->apply_uses_initial_guess()) {
// Use the inner solver to solve
// A * inner_solution = residual
// with residual as initial guess.
inner_solution->copy_from(residual_ptr);
}
solver_->apply(residual_ptr, inner_solution);
this->get_preconditioner()->apply(residual_ptr, inner_solution);
size_type index =
(iter >= num_max_generation_) ? num_max_generation_ : iter;
auto alpha_scalar =
Expand Down
39 changes: 20 additions & 19 deletions core/test/solver/chebyshev.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -104,9 +104,9 @@ TYPED_TEST(Chebyshev, ChebyshevFactoryCreatesCorrectSolver)
{
using Solver = typename TestFixture::Solver;
ASSERT_EQ(this->solver->get_size(), gko::dim<2>(3, 3));
auto cg_solver = static_cast<Solver*>(this->solver.get());
ASSERT_NE(cg_solver->get_system_matrix(), nullptr);
ASSERT_EQ(cg_solver->get_system_matrix(), this->mtx);
auto solver = static_cast<Solver*>(this->solver.get());
ASSERT_NE(solver->get_system_matrix(), nullptr);
ASSERT_EQ(solver->get_system_matrix(), this->mtx);
}


Expand Down Expand Up @@ -196,20 +196,20 @@ TYPED_TEST(Chebyshev, CanSetInnerSolverInFactory)
Solver::build()
.with_criteria(
gko::stop::Iteration::build().with_max_iters(3u).on(this->exec))
.with_solver(
.with_preconditioner(
Solver::build()
.with_criteria(
gko::stop::Iteration::build().with_max_iters(3u).on(
this->exec))
.on(this->exec))
.on(this->exec);
auto solver = chebyshev_factory->generate(this->mtx);
auto inner_solver = dynamic_cast<const Solver*>(
static_cast<Solver*>(solver.get())->get_solver().get());
auto preconditioner = dynamic_cast<const Solver*>(
static_cast<Solver*>(solver.get())->get_preconditioner().get());

ASSERT_NE(inner_solver, nullptr);
ASSERT_EQ(inner_solver->get_size(), gko::dim<2>(3, 3));
ASSERT_EQ(inner_solver->get_system_matrix(), this->mtx);
ASSERT_NE(preconditioner, nullptr);
ASSERT_EQ(preconditioner->get_size(), gko::dim<2>(3, 3));
ASSERT_EQ(preconditioner->get_system_matrix(), this->mtx);
}


Expand All @@ -227,13 +227,13 @@ TYPED_TEST(Chebyshev, CanSetGeneratedInnerSolverInFactory)
Solver::build()
.with_criteria(
gko::stop::Iteration::build().with_max_iters(3u).on(this->exec))
.with_generated_solver(chebyshev_solver)
.with_generated_preconditioner(chebyshev_solver)
.on(this->exec);
auto solver = chebyshev_factory->generate(this->mtx);
auto inner_solver = solver->get_solver();
auto preconditioner = solver->get_preconditioner();

ASSERT_NE(inner_solver.get(), nullptr);
ASSERT_EQ(inner_solver.get(), chebyshev_solver.get());
ASSERT_NE(preconditioner.get(), nullptr);
ASSERT_EQ(preconditioner.get(), chebyshev_solver.get());
}


Expand Down Expand Up @@ -279,7 +279,7 @@ TYPED_TEST(Chebyshev, ThrowsOnWrongInnerSolverInFactory)
Solver::build()
.with_criteria(
gko::stop::Iteration::build().with_max_iters(3u).on(this->exec))
.with_generated_solver(chebyshev_solver)
.with_generated_preconditioner(chebyshev_solver)
.on(this->exec);

ASSERT_THROW(chebyshev_factory->generate(this->mtx),
Expand All @@ -303,11 +303,11 @@ TYPED_TEST(Chebyshev, CanSetInnerSolver)
gko::stop::Iteration::build().with_max_iters(3u).on(this->exec))
.on(this->exec);
auto solver = chebyshev_factory->generate(this->mtx);
solver->set_solver(chebyshev_solver);
auto inner_solver = solver->get_solver();
solver->set_preconditioner(chebyshev_solver);
auto preconditioner = solver->get_preconditioner();

ASSERT_NE(inner_solver.get(), nullptr);
ASSERT_EQ(inner_solver.get(), chebyshev_solver.get());
ASSERT_NE(preconditioner.get(), nullptr);
ASSERT_EQ(preconditioner.get(), chebyshev_solver.get());
}


Expand Down Expand Up @@ -353,7 +353,8 @@ TYPED_TEST(Chebyshev, ThrowOnWrongInnerSolverSet)
.on(this->exec);
auto solver = chebyshev_factory->generate(this->mtx);

ASSERT_THROW(solver->set_solver(chebyshev_solver), gko::DimensionMismatch);
ASSERT_THROW(solver->set_preconditioner(chebyshev_solver),
gko::DimensionMismatch);
}


Expand Down
37 changes: 12 additions & 25 deletions include/ginkgo/core/solver/chebyshev.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ namespace solver {
* solution = initial_guess
* while not converged:
* residual = b - A solution
* error = solver(A, residual)
* error = preconditioner(A) * residual
* solution = solution + alpha_i * error + beta_i * (solution_i -
* solution_{i-1})
* ```
Expand All @@ -76,11 +76,12 @@ namespace solver {
* @ingroup LinOp
*/
template <typename ValueType = default_precision>
class Chebyshev : public EnableLinOp<Chebyshev<ValueType>>,
public EnableSolverBase<Chebyshev<ValueType>>,
public EnableIterativeBase<Chebyshev<ValueType>>,
public EnableApplyWithInitialGuess<Chebyshev<ValueType>>,
public Transposable {
class Chebyshev
: public EnableLinOp<Chebyshev<ValueType>>,
public EnablePreconditionedIterativeSolver<ValueType,
Chebyshev<ValueType>>,
public EnableApplyWithInitialGuess<Chebyshev<ValueType>>,
public Transposable {
friend class EnableLinOp<Chebyshev>;
friend class EnablePolymorphicObject<Chebyshev, LinOp>;
friend class EnableApplyWithInitialGuess<Chebyshev>;
Expand All @@ -104,20 +105,6 @@ class Chebyshev : public EnableLinOp<Chebyshev<ValueType>>,
initial_guess_mode::provided;
}

/**
* Returns the solver operator used as the inner solver.
*
* @return the solver operator used as the inner solver
*/
std::shared_ptr<const LinOp> get_solver() const { return solver_; }

/**
* Sets the solver operator used as the inner solver.
*
* @param new_solver the new inner solver
*/
void set_solver(std::shared_ptr<const LinOp> new_solver);

/**
* Copy-assigns a Chebyshev solver. Preserves the executor, shallow-copies
* inner solver, stopping criterion and system matrix. If the executors
Expand Down Expand Up @@ -158,18 +145,18 @@ class Chebyshev : public EnableLinOp<Chebyshev<ValueType>>,
GKO_FACTORY_PARAMETER_VECTOR(criteria, nullptr);

/**
* Inner solver (preconditioner) factory. If not provided this will
* Preconditioner factory. If not provided this will
* result in a non-preconditioned Chebyshev iteration.
*/
std::shared_ptr<const LinOpFactory> GKO_FACTORY_PARAMETER_SCALAR(
solver, nullptr);
preconditioner, nullptr);

/**
* Already generated solver. If one is provided, the factory `solver`
* will be ignored.
* Already generated preconditioner. If one is provided, the factory
* `preconditioner` will be ignored.
*/
std::shared_ptr<const LinOp> GKO_FACTORY_PARAMETER_SCALAR(
generated_solver, nullptr);
generated_preconditioner, nullptr);

/**
* The pair of foci of ellipse, which covers the eigenvalues of
Expand Down
1 change: 1 addition & 0 deletions include/ginkgo/ginkgo.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include <ginkgo/core/solver/cb_gmres.hpp>
#include <ginkgo/core/solver/cg.hpp>
#include <ginkgo/core/solver/cgs.hpp>
#include <ginkgo/core/solver/chebyshev.hpp>
#include <ginkgo/core/solver/direct.hpp>
#include <ginkgo/core/solver/fcg.hpp>
#include <ginkgo/core/solver/gcr.hpp>
Expand Down
4 changes: 2 additions & 2 deletions reference/test/solver/chebyshev_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,7 @@ TYPED_TEST(Chebyshev, SolvesTriangularSystemWithIterativeInnerSolver)
using value_type = typename TestFixture::value_type;

const gko::remove_complex<value_type> inner_reduction_factor = 1e-2;
auto inner_solver_factory = gko::share(
auto precond_factory = gko::share(
gko::solver::Gmres<value_type>::build()
.with_criteria(gko::stop::ResidualNorm<value_type>::build()
.with_reduction_factor(inner_reduction_factor)
Expand All @@ -354,7 +354,7 @@ TYPED_TEST(Chebyshev, SolvesTriangularSystemWithIterativeInnerSolver)
gko::stop::ResidualNorm<value_type>::build()
.with_reduction_factor(r<value_type>::value)
.on(this->exec))
.with_solver(inner_solver_factory)
.with_preconditioner(precond_factory)
.with_foci(value_type{0.9}, value_type{1.1})
.on(this->exec);
auto b = gko::initialize<Mtx>({3.9, 9.0, 2.2}, this->exec);
Expand Down
10 changes: 10 additions & 0 deletions test/test_install/test_install.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -470,6 +470,16 @@ int main()
check_solver<Solver>(exec, A_raw, b, x);
}

// core/solver/chebyshev.hpp
{
using Solver = gko::solver::Chebyshev<>;
auto test =
Solver::build()
.with_criteria(
gko::stop::Iteration::build().with_max_iters(1u).on(exec))
.on(exec);
}

// core/solver/fcg.hpp
{
using Solver = gko::solver::Fcg<>;
Expand Down

0 comments on commit 6a66ee5

Please sign in to comment.