Skip to content

Commit

Permalink
recover the test and update doc
Browse files Browse the repository at this point in the history
Co-authored-by: Thomas Grützmacher <thomas.gruetzmacher@kit.edu>
  • Loading branch information
yhmtsai and Thomas Grützmacher committed Nov 7, 2022
1 parent 3a2d29c commit 00a4e75
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 74 deletions.
32 changes: 18 additions & 14 deletions core/solver/multigrid.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -194,14 +194,14 @@ struct MultigridState {
/**
* allocate_memory is a helper function to allocate the memory of one level
*
* @tparam VT the value type of memory
* @tparam ValueType the value type of memory
*
* @param level the current level index
* @param cycle the multigrid cycle
* @param current_nrows the number of rows of current fine matrix
* @param next_nrows the number of rows of next coarse matrix
*/
template <typename VT>
template <typename ValueType>
void allocate_memory(int level, multigrid::cycle cycle,
size_type current_nrows, size_type next_nrows);

Expand All @@ -222,11 +222,11 @@ struct MultigridState {
/**
* @copydoc run_cycle
*
* @tparam VT the value type
* @tparam ValueType the value type
*
* @note it is the version with known ValueType
*/
template <typename VT>
template <typename ValueType>
void run_cycle(multigrid::cycle cycle, size_type level,
const std::shared_ptr<const LinOp>& matrix, const LinOp* b,
LinOp* x, bool x_is_zero, bool is_first, bool is_end);
Expand Down Expand Up @@ -287,13 +287,14 @@ void MultigridState::generate(const LinOp* system_matrix_in,
}
}

template <typename VT>

template <typename ValueType>
void MultigridState::allocate_memory(int level, multigrid::cycle cycle,
size_type current_nrows,
size_type next_nrows)
{
using vec = matrix::Dense<VT>;
using norm_vec = matrix::Dense<remove_complex<VT>>;
using vec = matrix::Dense<ValueType>;
using norm_vec = matrix::Dense<remove_complex<ValueType>>;

auto exec =
as<LinOp>(multigrid->get_mg_level_list().at(level))->get_executor();
Expand All @@ -302,18 +303,19 @@ void MultigridState::allocate_memory(int level, multigrid::cycle cycle,
// allocate the previous level
g_list.emplace_back(vec::create(exec, dim<2>{current_nrows, nrhs}));
e_list.emplace_back(vec::create(exec, dim<2>{current_nrows, nrhs}));
next_one_list.emplace_back(initialize<vec>({gko::one<VT>()}, exec));
next_one_list.emplace_back(initialize<vec>({one<ValueType>()}, exec));
}
if (level + 1 == multigrid->get_mg_level_list().size()) {
// the last level allocate the g, e for coarsest solver
g_list.emplace_back(vec::create(exec, dim<2>{next_nrows, nrhs}));
e_list.emplace_back(vec::create(exec, dim<2>{next_nrows, nrhs}));
next_one_list.emplace_back(initialize<vec>({gko::one<VT>()}, exec));
next_one_list.emplace_back(initialize<vec>({one<ValueType>()}, exec));
}
one_list.emplace_back(initialize<vec>({gko::one<VT>()}, exec));
neg_one_list.emplace_back(initialize<vec>({-gko::one<VT>()}, exec));
one_list.emplace_back(initialize<vec>({one<ValueType>()}, exec));
neg_one_list.emplace_back(initialize<vec>({-one<ValueType>()}, exec));
}


void MultigridState::run_cycle(multigrid::cycle cycle, size_type level,
const std::shared_ptr<const LinOp>& matrix,
const LinOp* b, LinOp* x, bool x_is_zero,
Expand All @@ -334,7 +336,8 @@ void MultigridState::run_cycle(multigrid::cycle cycle, size_type level,
});
}

template <typename VT>

template <typename ValueType>
void MultigridState::run_cycle(multigrid::cycle cycle, size_type level,
const std::shared_ptr<const LinOp>& matrix,
const LinOp* b, LinOp* x, bool x_is_zero,
Expand Down Expand Up @@ -369,7 +372,8 @@ void MultigridState::run_cycle(multigrid::cycle cycle, size_type level,
if (x_is_zero) {
// when level is zero, the x_ptr is already filled by zero
if (level != 0) {
dynamic_cast<matrix::Dense<VT>*>(x_ptr)->fill(zero<VT>());
dynamic_cast<matrix::Dense<ValueType>*>(x_ptr)->fill(
zero<ValueType>());
}
if (auto pre_allow_zero_input =
std::dynamic_pointer_cast<const ApplyWithInitialGuess>(
Expand Down Expand Up @@ -398,7 +402,7 @@ void MultigridState::run_cycle(multigrid::cycle cycle, size_type level,
// next level
if (level + 1 == total_level) {
// the coarsest solver use the last level valuetype
as_vec<VT>(e)->fill(zero<VT>());
as_vec<ValueType>(e)->fill(zero<ValueType>());
}
auto next_level_matrix =
(level + 1 < total_level)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,6 @@ int main(int argc, char* argv[])
// Create multigrid factory
std::shared_ptr<gko::LinOpFactory> multigrid_gen;
if (use_mixed) {
std::cout << "USE" << std::endl;
multigrid_gen =
mg::build()
.with_max_levels(10u)
Expand All @@ -185,7 +184,6 @@ int main(int argc, char* argv[])
.with_mg_level(mg_level_gen, mg_level_gen_f)
.with_level_selector([](const gko::size_type level,
const gko::LinOp*) -> gko::size_type {
std::cout << "level " << level << std::endl;
return level >= 1 ? 1 : 0;
})
.with_coarsest_solver(coarsest_gen_f)
Expand All @@ -209,7 +207,6 @@ int main(int argc, char* argv[])
gko::stop::Iteration::build().with_max_iters(1u).on(exec))
.on(exec);
}
std::cout << "multigrid_gen " << multigrid_gen.get() << std::endl;
// Create solver factory
auto solver_gen = cg::build()
.with_criteria(iter_stop, tol_stop)
Expand Down
5 changes: 1 addition & 4 deletions include/ginkgo/core/solver/multigrid.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,6 @@ namespace gko {
namespace solver {


class Multigrid;


/**
* @brief The solver multigrid namespace.
*
Expand Down Expand Up @@ -529,7 +526,7 @@ class Multigrid : public EnableLinOp<Multigrid>,
std::function<size_type(const size_type, const LinOp*)> solver_selector_;

/**
* Manages three vectors as a cache, so there is no need to allocate them
* Manages MultigridState as a cache, so there is no need to allocate them
* every time an intermediate vector is required. Copying an instance
* will only yield an empty object since copying the cached vector would
* not make sense.
Expand Down
106 changes: 53 additions & 53 deletions test/solver/multigrid_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -166,56 +166,56 @@ class Multigrid : public CommonTestFixture {
};


// TEST_F(Multigrid, MultigridKCycleStep1IsEquivalentToRef)
// {
// initialize_data();

// gko::kernels::reference::multigrid::kcycle_step_1(
// ref, alpha.get(), rho.get(), v.get(), g.get(), d.get(), e.get());
// gko::kernels::EXEC_NAMESPACE::multigrid::kcycle_step_1(
// exec, d_alpha.get(), d_rho.get(), d_v.get(), d_g.get(), d_d.get(),
// d_e.get());

// GKO_ASSERT_MTX_NEAR(d_g, g, 1e-14);
// GKO_ASSERT_MTX_NEAR(d_d, d, 1e-14);
// GKO_ASSERT_MTX_NEAR(d_e, e, 1e-14);
// }


// TEST_F(Multigrid, MultigridKCycleStep2IsEquivalentToRef)
// {
// initialize_data();

// gko::kernels::reference::multigrid::kcycle_step_2(
// ref, alpha.get(), rho.get(), gamma.get(), beta.get(), zeta.get(),
// d.get(), e.get());
// gko::kernels::EXEC_NAMESPACE::multigrid::kcycle_step_2(
// exec, d_alpha.get(), d_rho.get(), d_gamma.get(), d_beta.get(),
// d_zeta.get(), d_d.get(), d_e.get());

// GKO_ASSERT_MTX_NEAR(d_e, e, 1e-14);
// }


// TEST_F(Multigrid, MultigridKCycleCheckStopIsEquivalentToRef)
// {
// initialize_data();
// bool is_stop_10;
// bool d_is_stop_10;
// bool is_stop_5;
// bool d_is_stop_5;

// gko::kernels::reference::multigrid::kcycle_check_stop(
// ref, old_norm.get(), new_norm.get(), 1.0, is_stop_10);
// gko::kernels::EXEC_NAMESPACE::multigrid::kcycle_check_stop(
// exec, d_old_norm.get(), d_new_norm.get(), 1.0, d_is_stop_10);
// gko::kernels::reference::multigrid::kcycle_check_stop(
// ref, old_norm.get(), new_norm.get(), 0.5, is_stop_5);
// gko::kernels::EXEC_NAMESPACE::multigrid::kcycle_check_stop(
// exec, d_old_norm.get(), d_new_norm.get(), 0.5, d_is_stop_5);

// GKO_ASSERT_EQ(d_is_stop_10, is_stop_10);
// GKO_ASSERT_EQ(d_is_stop_10, true);
// GKO_ASSERT_EQ(d_is_stop_5, is_stop_5);
// GKO_ASSERT_EQ(d_is_stop_5, false);
// }
TEST_F(Multigrid, MultigridKCycleStep1IsEquivalentToRef)
{
initialize_data();

gko::kernels::reference::multigrid::kcycle_step_1(
ref, alpha.get(), rho.get(), v.get(), g.get(), d.get(), e.get());
gko::kernels::EXEC_NAMESPACE::multigrid::kcycle_step_1(
exec, d_alpha.get(), d_rho.get(), d_v.get(), d_g.get(), d_d.get(),
d_e.get());

GKO_ASSERT_MTX_NEAR(d_g, g, 1e-14);
GKO_ASSERT_MTX_NEAR(d_d, d, 1e-14);
GKO_ASSERT_MTX_NEAR(d_e, e, 1e-14);
}


TEST_F(Multigrid, MultigridKCycleStep2IsEquivalentToRef)
{
initialize_data();

gko::kernels::reference::multigrid::kcycle_step_2(
ref, alpha.get(), rho.get(), gamma.get(), beta.get(), zeta.get(),
d.get(), e.get());
gko::kernels::EXEC_NAMESPACE::multigrid::kcycle_step_2(
exec, d_alpha.get(), d_rho.get(), d_gamma.get(), d_beta.get(),
d_zeta.get(), d_d.get(), d_e.get());

GKO_ASSERT_MTX_NEAR(d_e, e, 1e-14);
}


TEST_F(Multigrid, MultigridKCycleCheckStopIsEquivalentToRef)
{
initialize_data();
bool is_stop_10;
bool d_is_stop_10;
bool is_stop_5;
bool d_is_stop_5;

gko::kernels::reference::multigrid::kcycle_check_stop(
ref, old_norm.get(), new_norm.get(), 1.0, is_stop_10);
gko::kernels::EXEC_NAMESPACE::multigrid::kcycle_check_stop(
exec, d_old_norm.get(), d_new_norm.get(), 1.0, d_is_stop_10);
gko::kernels::reference::multigrid::kcycle_check_stop(
ref, old_norm.get(), new_norm.get(), 0.5, is_stop_5);
gko::kernels::EXEC_NAMESPACE::multigrid::kcycle_check_stop(
exec, d_old_norm.get(), d_new_norm.get(), 0.5, d_is_stop_5);

GKO_ASSERT_EQ(d_is_stop_10, is_stop_10);
GKO_ASSERT_EQ(d_is_stop_10, true);
GKO_ASSERT_EQ(d_is_stop_5, is_stop_5);
GKO_ASSERT_EQ(d_is_stop_5, false);
}

0 comments on commit 00a4e75

Please sign in to comment.