diff --git a/core/config/multigrid_config.cpp b/core/config/multigrid_config.cpp index 55b65a6c345..0da6c42abcf 100644 --- a/core/config/multigrid_config.cpp +++ b/core/config/multigrid_config.cpp @@ -1,34 +1,6 @@ -/************************************************************* -Copyright (c) 2017-2023, the Ginkgo authors -All rights reserved. - -Redistribution and use in source and binary forms, with or without -modification, are permitted provided that the following conditions -are met: - -1. Redistributions of source code must retain the above copyright -notice, this list of conditions and the following disclaimer. - -2. Redistributions in binary form must reproduce the above copyright -notice, this list of conditions and the following disclaimer in the -documentation and/or other materials provided with the distribution. - -3. Neither the name of the copyright holder nor the names of its -contributors may be used to endorse or promote products derived from -this software without specific prior written permission. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS -IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED -TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A -PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT -HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, -SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT -LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, -DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY -THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -*************************************************************/ +// SPDX-FileCopyrightText: 2017-2023 The Ginkgo authors +// +// SPDX-License-Identifier: BSD-3-Clause #include #include @@ -37,6 +9,7 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #include #include + #include "core/config/config.hpp" #include "core/config/dispatch.hpp" @@ -45,161 +18,38 @@ namespace gko { namespace config { -template -class PgmConfigurator { -public: - static std::unique_ptr< - typename multigrid::Pgm::Factory> - build_from_config(const pnode& config, const registry& context, - std::shared_ptr exec, - type_descriptor td_for_child) - { - auto factory = multigrid::Pgm::build(); - SET_VALUE(factory, unsigned, max_iterations, config); - SET_VALUE(factory, double, max_unassigned_ratio, config); - SET_VALUE(factory, bool, deterministic, config); - SET_VALUE(factory, bool, skip_sorting, config); - return factory.on(exec); - } -}; - - template <> -std::unique_ptr build_from_config( - const pnode& config, const registry& context, - std::shared_ptr& exec, gko::config::type_descriptor td) +deferred_factory_parameter +build_from_config(const pnode& config, + const registry& context, + gko::config::type_descriptor td) { auto updated = update_type(config, td); - return dispatch( - updated.first + "," + updated.second, config, context, exec, updated, + return dispatch( + updated.first + "," + updated.second, config, context, updated, value_type_list(), index_type_list()); } - -template -class FixedCoarseningConfigurator { -public: - static std::unique_ptr< - typename multigrid::FixedCoarsening::Factory> - build_from_config(const pnode& config, const registry& context, - std::shared_ptr exec, - type_descriptor td_for_child) - { - auto factory = - multigrid::FixedCoarsening::build(); - SET_VALUE_ARRAY(factory, array, coarse_rows, config, exec); - SET_VALUE(factory, bool, skip_sorting, config); - return factory.on(exec); - } -}; - - template <> -std::unique_ptr +deferred_factory_parameter build_from_config( const pnode& config, const registry& context, - std::shared_ptr& exec, gko::config::type_descriptor td) + gko::config::type_descriptor td) { auto updated = update_type(config, td); - return dispatch( - updated.first + "," + updated.second, config, context, exec, updated, + return dispatch( + updated.first + "," + updated.second, config, context, updated, value_type_list(), index_type_list()); } - -std::function get_selector( - std::string key) -{ - static std::map> - selector_map{ - {{"first_for_top", [](const size_type level, const LinOp*) { - return (level == 0) ? 0 : 1; - }}}}; - return selector_map.at(key); -} - - -class MultigridConfigurator { -public: - static std::unique_ptr - build_from_config(const pnode& config, const registry& context, - std::shared_ptr exec, - type_descriptor td_for_child) - { - auto factory = solver::Multigrid::build(); - SET_POINTER_VECTOR(factory, const stop::CriterionFactory, criteria, - config, context, exec, td_for_child); - SET_POINTER_VECTOR(factory, const gko::LinOpFactory, mg_level, config, - context, exec, td_for_child); - if (config.contains("level_selector")) { - auto str = config.at("level_selector").get_data(); - factory.with_level_selector(get_selector(str)); - } - SET_POINTER_VECTOR(factory, const LinOpFactory, pre_smoother, config, - context, exec, td_for_child); - SET_POINTER_VECTOR(factory, const LinOpFactory, post_smoother, config, - context, exec, td_for_child); - SET_POINTER_VECTOR(factory, const LinOpFactory, mid_smoother, config, - context, exec, td_for_child); - SET_VALUE(factory, bool, post_uses_pre, config); - if (config.contains("mid_case")) { - auto str = config.at("mid_case").get_data(); - if (str == "both") { - factory.with_mid_case(solver::multigrid::mid_smooth_type::both); - } else if (str == "post_smoother") { - factory.with_mid_case( - solver::multigrid::mid_smooth_type::post_smoother); - } else if (str == "pre_smoother") { - factory.with_mid_case( - solver::multigrid::mid_smooth_type::pre_smoother); - } else if (str == "standalone") { - factory.with_mid_case( - solver::multigrid::mid_smooth_type::standalone); - } else { - GKO_INVALID_STATE("Not valid mid_smooth_type value"); - } - } - SET_VALUE(factory, size_type, max_levels, config); - SET_VALUE(factory, size_type, min_coarse_rows, config); - SET_POINTER_VECTOR(factory, const LinOpFactory, coarsest_solver, config, - context, exec, td_for_child); - if (config.contains("solver_selector")) { - auto str = config.at("solver_selector").get_data(); - factory.with_solver_selector(get_selector(str)); - } - if (config.contains("cycle")) { - auto str = config.at("cycle").get_data(); - if (str == "v") { - factory.with_cycle(solver::multigrid::cycle::v); - } else if (str == "w") { - factory.with_cycle(solver::multigrid::cycle::w); - } else if (str == "f") { - factory.with_cycle(solver::multigrid::cycle::f); - } else { - GKO_INVALID_STATE("Not valid cycle value"); - } - } - SET_VALUE(factory, size_type, kcycle_base, config); - SET_VALUE(factory, double, kcycle_rel_tol, config); - SET_VALUE(factory, std::complex, smoother_relax, config); - SET_VALUE(factory, size_type, smoother_iters, config); - SET_VALUE(factory, solver::initial_guess_mode, default_initial_guess, - config); - return factory.on(exec); - } -}; - - template <> -std::unique_ptr -build_from_config( - const pnode& config, const registry& context, - std::shared_ptr& exec, gko::config::type_descriptor td) +deferred_factory_parameter +build_from_config(const pnode& config, + const registry& context, + gko::config::type_descriptor td) { auto updated = update_type(config, td); - return MultigridConfigurator::build_from_config(config, context, exec, - updated); + return solver::Multigrid::build_from_config(config, context, updated); } diff --git a/core/multigrid/fixed_coarsening.cpp b/core/multigrid/fixed_coarsening.cpp index 413614abf28..281261214d7 100644 --- a/core/multigrid/fixed_coarsening.cpp +++ b/core/multigrid/fixed_coarsening.cpp @@ -20,6 +20,7 @@ #include "core/base/utils.hpp" #include "core/components/fill_array_kernels.hpp" +#include "core/config/config.hpp" #include "core/matrix/csr_builder.hpp" @@ -37,6 +38,19 @@ GKO_REGISTER_OPERATION(fill_seq_array, components::fill_seq_array); } // namespace fixed_coarsening +template +typename FixedCoarsening::parameters_type +FixedCoarsening::build_from_config( + const config::pnode& config, const config::registry& context, + config::type_descriptor td_for_child) +{ + auto factory = FixedCoarsening::build(); + // TODO: ARRAY + SET_VALUE(factory, bool, skip_sorting, config); + return factory; +} + + template void FixedCoarsening::generate() { diff --git a/core/multigrid/pgm.cpp b/core/multigrid/pgm.cpp index 22d9b8c6052..5fd2ff10de5 100644 --- a/core/multigrid/pgm.cpp +++ b/core/multigrid/pgm.cpp @@ -22,6 +22,7 @@ #include "core/base/utils.hpp" #include "core/components/fill_array_kernels.hpp" #include "core/components/format_conversion_kernels.hpp" +#include "core/config/config.hpp" #include "core/matrix/csr_builder.hpp" #include "core/multigrid/pgm_kernels.hpp" @@ -115,6 +116,21 @@ std::shared_ptr> generate_coarse( } // namespace +template +typename Pgm::parameters_type +Pgm::build_from_config( + const config::pnode& config, const config::registry& context, + config::type_descriptor td_for_child) +{ + auto factory = Pgm::build(); + SET_VALUE(factory, unsigned, max_iterations, config); + SET_VALUE(factory, double, max_unassigned_ratio, config); + SET_VALUE(factory, bool, deterministic, config); + SET_VALUE(factory, bool, skip_sorting, config); + return factory; +} + + template void Pgm::generate() { diff --git a/core/solver/multigrid.cpp b/core/solver/multigrid.cpp index ea3099cf185..e3956412ffb 100644 --- a/core/solver/multigrid.cpp +++ b/core/solver/multigrid.cpp @@ -27,6 +27,7 @@ #include "core/base/dispatch_helper.hpp" #include "core/components/fill_array_kernels.hpp" +#include "core/config/config.hpp" #include "core/solver/ir_kernels.hpp" #include "core/solver/multigrid_kernels.hpp" #include "core/solver/solver_base.hpp" @@ -467,6 +468,83 @@ void MultigridState::run_cycle(multigrid::cycle cycle, size_type level, } // namespace multigrid +std::function get_selector( + std::string key) +{ + static std::map> + selector_map{ + {{"first_for_top", [](const size_type level, const LinOp*) { + return (level == 0) ? 0 : 1; + }}}}; + return selector_map.at(key); +} + + +typename Multigrid::parameters_type Multigrid::build_from_config( + const config::pnode& config, const config::registry& context, + config::type_descriptor td_for_child) +{ + auto factory = Multigrid::build(); + SET_FACTORY_VECTOR(factory, const stop::CriterionFactory, criteria, config, + context, td_for_child); + SET_FACTORY_VECTOR(factory, const gko::LinOpFactory, mg_level, config, + context, td_for_child); + if (config.contains("level_selector")) { + auto str = config.at("level_selector").get_data(); + factory.with_level_selector(get_selector(str)); + } + SET_FACTORY_VECTOR(factory, const LinOpFactory, pre_smoother, config, + context, td_for_child); + SET_FACTORY_VECTOR(factory, const LinOpFactory, post_smoother, config, + context, td_for_child); + SET_FACTORY_VECTOR(factory, const LinOpFactory, mid_smoother, config, + context, td_for_child); + SET_VALUE(factory, bool, post_uses_pre, config); + if (config.contains("mid_case")) { + auto str = config.at("mid_case").get_data(); + if (str == "both") { + factory.with_mid_case(multigrid::mid_smooth_type::both); + } else if (str == "post_smoother") { + factory.with_mid_case(multigrid::mid_smooth_type::post_smoother); + } else if (str == "pre_smoother") { + factory.with_mid_case(multigrid::mid_smooth_type::pre_smoother); + } else if (str == "standalone") { + factory.with_mid_case(multigrid::mid_smooth_type::standalone); + } else { + GKO_INVALID_STATE("Not valid mid_smooth_type value"); + } + } + SET_VALUE(factory, size_type, max_levels, config); + SET_VALUE(factory, size_type, min_coarse_rows, config); + SET_FACTORY_VECTOR(factory, const LinOpFactory, coarsest_solver, config, + context, td_for_child); + if (config.contains("solver_selector")) { + auto str = config.at("solver_selector").get_data(); + factory.with_solver_selector(get_selector(str)); + } + if (config.contains("cycle")) { + auto str = config.at("cycle").get_data(); + if (str == "v") { + factory.with_cycle(multigrid::cycle::v); + } else if (str == "w") { + factory.with_cycle(multigrid::cycle::w); + } else if (str == "f") { + factory.with_cycle(multigrid::cycle::f); + } else { + GKO_INVALID_STATE("Not valid cycle value"); + } + } + SET_VALUE(factory, size_type, kcycle_base, config); + SET_VALUE(factory, double, kcycle_rel_tol, config); + SET_VALUE(factory, std::complex, smoother_relax, config); + SET_VALUE(factory, size_type, smoother_iters, config); + SET_VALUE(factory, solver::initial_guess_mode, default_initial_guess, + config); + return factory; +} + + void Multigrid::generate() { // generate coarse matrix until reaching max_level or min_coarse_rows diff --git a/core/test/config/multigrid.cpp b/core/test/config/multigrid.cpp index 5340164392d..1b7ac43c717 100644 --- a/core/test/config/multigrid.cpp +++ b/core/test/config/multigrid.cpp @@ -1,34 +1,6 @@ -/************************************************************* -Copyright (c) 2017-2023, the Ginkgo authors -All rights reserved. - -Redistribution and use in source and binary forms, with or without -modification, are permitted provided that the following conditions -are met: - -1. Redistributions of source code must retain the above copyright -notice, this list of conditions and the following disclaimer. - -2. Redistributions in binary form must reproduce the above copyright -notice, this list of conditions and the following disclaimer in the -documentation and/or other materials provided with the distribution. - -3. Neither the name of the copyright holder nor the names of its -contributors may be used to endorse or promote products derived from -this software without specific prior written permission. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS -IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED -TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A -PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT -HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, -SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT -LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, -DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY -THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -*************************************************************/ +// SPDX-FileCopyrightText: 2017-2023 The Ginkgo authors +// +// SPDX-License-Identifier: BSD-3-Clause #include @@ -61,27 +33,30 @@ struct MultigridLevelConfigTest { static void change_template(pnode& config) { - config.get_list()["ValueType"] = pnode{"float"}; - config.get_list()["IndexType"] = pnode{"int64"}; + config.get_map()["ValueType"] = pnode{"float"}; + config.get_map()["IndexType"] = pnode{"int64"}; } }; struct Pgm : MultigridLevelConfigTest, gko::multigrid::Pgm> { - static pnode setup_base() { return pnode{{{"Type", pnode{"Pgm"}}}}; } + static pnode setup_base() + { + return pnode{std::map{{"Type", pnode{"Pgm"}}}}; + } template static void set(pnode& config, ParamType& param, registry reg, std::shared_ptr exec) { - config.get_list()["max_iterations"] = pnode{20}; + config.get_map()["max_iterations"] = pnode{20}; param.with_max_iterations(20u); - config.get_list()["max_unassigned_ratio"] = pnode{0.1}; + config.get_map()["max_unassigned_ratio"] = pnode{0.1}; param.with_max_unassigned_ratio(0.1); - config.get_list()["deterministic"] = pnode{true}; + config.get_map()["deterministic"] = pnode{true}; param.with_deterministic(true); - config.get_list()["skip_sorting"] = pnode{true}; + config.get_map()["skip_sorting"] = pnode{true}; param.with_skip_sorting(true); } @@ -105,17 +80,18 @@ struct FixedCoarsening : MultigridLevelConfigTest< gko::multigrid::FixedCoarsening> { static pnode setup_base() { - return pnode{{{"Type", pnode{"FixedCoarsening"}}}}; + return pnode{ + std::map{{"Type", pnode{"FixedCoarsening"}}}}; } template static void set(pnode& config, ParamType& param, registry reg, std::shared_ptr exec) { - config.get_list()["coarse_rows"] = - pnode{std::vector{{2}, {3}, {5}}}; - param.with_coarse_rows(gko::array(exec, {2, 3, 5})); - config.get_list()["skip_sorting"] = pnode{true}; + // config.get_map()["coarse_rows"] = + // pnode{std::vector{{2}, {3}, {5}}}; + // param.with_coarse_rows(gko::array(exec, {2, 3, 5})); + config.get_map()["skip_sorting"] = pnode{true}; param.with_skip_sorting(true); } @@ -159,7 +135,7 @@ TYPED_TEST(MultigridLevel, CreateDefault) using Config = typename TestFixture::Config; auto config = Config::setup_base(); - auto res = build_from_config(config, this->reg, this->exec, this->td); + auto res = build_from_config(config, this->reg, this->td).on(this->exec); auto ans = Config::default_type::build().on(this->exec); Config::validate(res.get(), ans.get()); @@ -172,7 +148,7 @@ TYPED_TEST(MultigridLevel, ExplicitTemplate) auto config = Config::setup_base(); Config::change_template(config); - auto res = build_from_config(config, this->reg, this->exec, this->td); + auto res = build_from_config(config, this->reg, this->td).on(this->exec); auto ans = Config::explicit_type::build().on(this->exec); Config::validate(res.get(), ans.get()); @@ -187,7 +163,7 @@ TYPED_TEST(MultigridLevel, Set) auto param = Config::explicit_type::build(); Config::set(config, param, this->reg, this->exec); - auto res = build_from_config(config, this->reg, this->exec, this->td); + auto res = build_from_config(config, this->reg, this->td).on(this->exec); auto ans = param.on(this->exec); Config::validate(res.get(), ans.get()); @@ -199,79 +175,86 @@ using DummySmoother = gko::solver::Ir; using DummyStop = gko::stop::Iteration; struct MultigridConfig { - static pnode setup_base() { return pnode{{{"Type", pnode{"Multigrid"}}}}; } + static pnode setup_base() + { + return pnode{ + std::map{{"Type", pnode{"Multigrid"}}}}; + } template static void set(pnode& config, ParamType& param, registry reg, std::shared_ptr exec) { - config.get_list()["post_uses_pre"] = pnode{false}; + config.get_map()["post_uses_pre"] = pnode{false}; param.with_post_uses_pre(false); - config.get_list()["mid_case"] = pnode{"both"}; + config.get_map()["mid_case"] = pnode{"both"}; param.with_mid_case(gko::solver::multigrid::mid_smooth_type::both); - config.get_list()["max_levels"] = pnode{20u}; + config.get_map()["max_levels"] = pnode{20u}; param.with_max_levels(20u); - config.get_list()["min_coarse_rows"] = pnode{32u}; + config.get_map()["min_coarse_rows"] = pnode{32u}; param.with_min_coarse_rows(32u); - config.get_list()["cycle"] = pnode{"w"}; + config.get_map()["cycle"] = pnode{"w"}; param.with_cycle(gko::solver::multigrid::cycle::w); - config.get_list()["kcycle_base"] = pnode{2u}; + config.get_map()["kcycle_base"] = pnode{2u}; param.with_kcycle_base(2u); - config.get_list()["kcycle_rel_tol"] = pnode{0.5}; + config.get_map()["kcycle_rel_tol"] = pnode{0.5}; param.with_kcycle_rel_tol(0.5); - config.get_list()["smoother_relax"] = pnode{0.3}; + config.get_map()["smoother_relax"] = pnode{0.3}; param.with_smoother_relax(0.3); - config.get_list()["smoother_iters"] = pnode{2u}; + config.get_map()["smoother_iters"] = pnode{2u}; param.with_smoother_iters(2u); - config.get_list()["default_initial_guess"] = pnode{"provided"}; + config.get_map()["default_initial_guess"] = pnode{"provided"}; param.with_default_initial_guess( gko::solver::initial_guess_mode::provided); if (from_reg) { - config.get_list()["criteria"] = pnode{"criterion_factory"}; + config.get_map()["criteria"] = pnode{"criterion_factory"}; param.with_criteria(reg.search_data( "criterion_factory")); - config.get_list()["mg_level"] = + config.get_map()["mg_level"] = pnode{std::vector{{"mg_level_0"}, {"mg_level_1"}}}; param.with_mg_level( reg.search_data("mg_level_0"), reg.search_data("mg_level_1")); - config.get_list()["pre_smoother"] = pnode{"pre_smoother"}; + config.get_map()["pre_smoother"] = pnode{"pre_smoother"}; param.with_pre_smoother( reg.search_data("pre_smoother")); - config.get_list()["post_smoother"] = pnode{"post_smoother"}; + config.get_map()["post_smoother"] = pnode{"post_smoother"}; param.with_post_smoother( reg.search_data("post_smoother")); - config.get_list()["mid_smoother"] = pnode{"mid_smoother"}; + config.get_map()["mid_smoother"] = pnode{"mid_smoother"}; param.with_mid_smoother( reg.search_data("mid_smoother")); - config.get_list()["coarsest_solver"] = pnode{"coarsest_solver"}; + config.get_map()["coarsest_solver"] = pnode{"coarsest_solver"}; param.with_coarsest_solver( reg.search_data("coarsest_solver")); } else { - config.get_list()["criteria"] = - pnode{{{"Type", pnode{"Iteration"}}}}; + config.get_map()["criteria"] = pnode{ + std::map{{"Type", pnode{"Iteration"}}}}; param.with_criteria(DummyStop::build().on(exec)); - config.get_list()["mg_level"] = pnode{std::vector{ + config.get_map()["mg_level"] = pnode{std::vector{ pnode{std::map{{"Type", {"Pgm"}}}}, pnode{std::map{{"Type", {"Pgm"}}}}}}; param.with_mg_level(DummyMgLevel::build().on(exec), DummyMgLevel::build().on(exec)); - config.get_list()["pre_smoother"] = pnode{{{"Type", pnode{"Ir"}}}}; + config.get_map()["pre_smoother"] = + pnode{std::map{{"Type", pnode{"Ir"}}}}; param.with_pre_smoother(DummySmoother::build().on(exec)); - config.get_list()["post_smoother"] = pnode{{{"Type", pnode{"Ir"}}}}; + config.get_map()["post_smoother"] = + pnode{std::map{{"Type", pnode{"Ir"}}}}; param.with_post_smoother(DummySmoother::build().on(exec)); - config.get_list()["mid_smoother"] = pnode{{{"Type", pnode{"Ir"}}}}; + config.get_map()["mid_smoother"] = + pnode{std::map{{"Type", pnode{"Ir"}}}}; param.with_mid_smoother(DummySmoother::build().on(exec)); - config.get_list()["coarsest_solver"] = - pnode{{{"Type", pnode{"Ir"}}}}; + config.get_map()["coarsest_solver"] = + pnode{std::map{{"Type", pnode{"Ir"}}}}; param.with_coarsest_solver(DummySmoother::build().on(exec)); } - config.get_list()["level_selector"] = pnode{"first_for_top"}; + config.get_map()["level_selector"] = pnode{"first_for_top"}; param.with_level_selector( [](const gko::size_type level, const gko::LinOp*) { return level == 0 ? 0 : 1; }); - config.get_list()["solver_selector"] = pnode{"first_for_top"}; + config.get_map()["solver_selector"] = pnode{"first_for_top"}; param.with_solver_selector( [](const gko::size_type level, const gko::LinOp*) { return level == 0 ? 0 : 1; @@ -392,7 +375,7 @@ TEST_F(MultigridT, CreateDefault) { auto config = Config::setup_base(); - auto res = build_from_config(config, this->reg, this->exec, this->td); + auto res = build_from_config(config, this->reg, this->td).on(this->exec); auto ans = gko::solver::Multigrid::build().on(this->exec); Config::template validate(res.get(), ans.get()); @@ -405,7 +388,7 @@ TEST_F(MultigridT, SetFromRegistry) auto param = gko::solver::Multigrid::build(); Config::template set(config, param, this->reg, this->exec); - auto res = build_from_config(config, this->reg, this->exec, this->td); + auto res = build_from_config(config, this->reg, this->td).on(exec); auto ans = param.on(this->exec); Config::template validate(res.get(), ans.get()); @@ -418,8 +401,8 @@ TEST_F(MultigridT, SetFromConfig) auto param = gko::solver::Multigrid::build(); Config::template set(config, param, this->reg, this->exec); - auto res = build_from_config(config, this->reg, this->exec, this->td); + auto res = build_from_config(config, this->reg, this->td).on(exec); auto ans = param.on(this->exec); Config::template validate(res.get(), ans.get()); -} \ No newline at end of file +} diff --git a/include/ginkgo/core/multigrid/fixed_coarsening.hpp b/include/ginkgo/core/multigrid/fixed_coarsening.hpp index d8c2231c498..c3ccb50c351 100644 --- a/include/ginkgo/core/multigrid/fixed_coarsening.hpp +++ b/include/ginkgo/core/multigrid/fixed_coarsening.hpp @@ -13,10 +13,13 @@ #include #include #include +#include +#include #include #include #include + namespace gko { namespace multigrid { @@ -83,6 +86,10 @@ class FixedCoarsening GKO_ENABLE_LIN_OP_FACTORY(FixedCoarsening, parameters, Factory); GKO_ENABLE_BUILD_METHOD(Factory); + static parameters_type build_from_config( + const config::pnode& config, const config::registry& context, + config::type_descriptor td_for_child); + protected: void apply_impl(const LinOp* b, LinOp* x) const override { diff --git a/include/ginkgo/core/multigrid/pgm.hpp b/include/ginkgo/core/multigrid/pgm.hpp index 779a06ff3fb..5aeffee9f64 100644 --- a/include/ginkgo/core/multigrid/pgm.hpp +++ b/include/ginkgo/core/multigrid/pgm.hpp @@ -13,10 +13,13 @@ #include #include #include +#include +#include #include #include #include + namespace gko { namespace multigrid { @@ -127,6 +130,10 @@ class Pgm : public EnableLinOp>, GKO_ENABLE_LIN_OP_FACTORY(Pgm, parameters, Factory); GKO_ENABLE_BUILD_METHOD(Factory); + static parameters_type build_from_config( + const config::pnode& config, const config::registry& context, + config::type_descriptor td_for_child); + protected: void apply_impl(const LinOp* b, LinOp* x) const override { diff --git a/include/ginkgo/core/solver/multigrid.hpp b/include/ginkgo/core/solver/multigrid.hpp index 9646e2779d7..7a940c905ac 100644 --- a/include/ginkgo/core/solver/multigrid.hpp +++ b/include/ginkgo/core/solver/multigrid.hpp @@ -19,6 +19,8 @@ #include #include #include +#include +#include #include #include #include @@ -379,6 +381,10 @@ class Multigrid : public EnableLinOp, GKO_ENABLE_LIN_OP_FACTORY(Multigrid, parameters, Factory); GKO_ENABLE_BUILD_METHOD(Factory); + static parameters_type build_from_config( + const config::pnode& config, const config::registry& context, + config::type_descriptor td_for_child); + protected: void apply_impl(const LinOp* b, LinOp* x) const override;