Skip to content

Commit

Permalink
Add review suggestions
Browse files Browse the repository at this point in the history
Co-authored-by: Yuhsiang Tsai <yhmtsai@gmail.com>
  • Loading branch information
greole and yhmtsai committed Oct 21, 2023
1 parent 74c3bbf commit 10a31d0
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 9 deletions.
26 changes: 20 additions & 6 deletions core/distributed/preconditioner/schwarz.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,20 @@ void Schwarz<ValueType, LocalIndexType, GlobalIndexType>::apply_impl(
}


template <typename ValueType, typename LocalIndexType, typename GlobalIndexType>
void Schwarz<ValueType, LocalIndexType, GlobalIndexType>::set_solver(
std::shared_ptr<const LinOp> new_solver)
{
auto exec = this->get_executor();
if (new_solver) {
if (new_solver->get_executor() != exec) {
new_solver = gko::clone(exec, new_solver);
}
}
this->local_solver_ = new_solver;
}


template <typename ValueType, typename LocalIndexType, typename GlobalIndexType>
void Schwarz<ValueType, LocalIndexType, GlobalIndexType>::generate(
std::shared_ptr<const LinOp> system_matrix)
Expand All @@ -113,13 +127,13 @@ void Schwarz<ValueType, LocalIndexType, GlobalIndexType>::generate(
}

if (parameters_.local_solver) {
this->local_solver_ = parameters_.local_solver->generate(
as<experimental::distributed::Matrix<ValueType, LocalIndexType,
GlobalIndexType>>(
system_matrix)
->get_local_matrix());
this->set_solver(gko::share(parameters_.local_solver->generate(
as<experimental::distributed::Matrix<
ValueType, LocalIndexType, GlobalIndexType>>(system_matrix)
->get_local_matrix())));

} else {
this->local_solver_ = parameters_.generated_local_solver;
this->set_solver(parameters_.generated_local_solver);
}
}

Expand Down
10 changes: 8 additions & 2 deletions include/ginkgo/core/distributed/preconditioner/schwarz.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ class Schwarz
/**
* Generated Inner solvers.
*/
std::shared_ptr<const LinOp> GKO_FACTORY_PARAMETER(
std::shared_ptr<const LinOp> GKO_FACTORY_PARAMETER_SCALAR(
generated_local_solver, nullptr);
};
GKO_ENABLE_LIN_OP_FACTORY(Schwarz, parameters, Factory);
Expand Down Expand Up @@ -136,7 +136,6 @@ class Schwarz
*/
void generate(std::shared_ptr<const LinOp> system_matrix);


void apply_impl(const LinOp* b, LinOp* x) const override;

template <typename VectorType>
Expand All @@ -146,6 +145,13 @@ class Schwarz
LinOp* x) const override;

private:
/**
* Sets the solver operator used as the local solver.
*
* @param new_solver the new local solver
*/
void set_solver(std::shared_ptr<const LinOp> new_solver);

std::shared_ptr<const LinOp> local_solver_;
};

Expand Down
3 changes: 2 additions & 1 deletion test/mpi/preconditioner/schwarz.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,7 @@ TYPED_TEST(SchwarzPreconditioner, GenerateFailsIfNoSolverProvided)
{
using prec = typename TestFixture::dist_prec_type;
auto schwarz_no_solver = prec::build().on(this->exec);

ASSERT_THROW(schwarz_no_solver->generate(this->dist_mat),
gko::InvalidStateError);
}
Expand Down Expand Up @@ -273,7 +274,7 @@ TYPED_TEST(SchwarzPreconditioner, CanApplyPreconditionedSolverWithPregenSolver)
auto dist_x_pregen = gko::share(this->dist_x->clone());

precond->apply(this->dist_b.get(), dist_x.get());
precond->apply(this->dist_b.get(), dist_x_pregen.get());
precond_pregen->apply(this->dist_b.get(), dist_x_pregen.get());

GKO_ASSERT_MTX_NEAR(dist_x->get_local_vector(),
dist_x_pregen->get_local_vector(),
Expand Down

0 comments on commit 10a31d0

Please sign in to comment.