From f062dfc29bbf73a0bd9690be596a00231e35883f Mon Sep 17 00:00:00 2001 From: "Yu-Hsiang M. Tsai" Date: Wed, 11 Oct 2023 14:30:18 +0200 Subject: [PATCH] make the type in deferred_factory explicitly --- include/ginkgo/core/base/abstract_factory.hpp | 44 +++++++++---------- .../distributed/preconditioner/schwarz.hpp | 2 +- include/ginkgo/core/preconditioner/ic.hpp | 14 +++--- include/ginkgo/core/preconditioner/ilu.hpp | 22 ++++++---- include/ginkgo/core/solver/direct.hpp | 2 +- include/ginkgo/core/solver/ir.hpp | 2 +- include/ginkgo/core/solver/multigrid.hpp | 12 ++--- include/ginkgo/core/solver/solver_base.hpp | 35 ++++++++++++--- 8 files changed, 83 insertions(+), 50 deletions(-) diff --git a/include/ginkgo/core/base/abstract_factory.hpp b/include/ginkgo/core/base/abstract_factory.hpp index cca440afe6c..a30afae4c16 100644 --- a/include/ginkgo/core/base/abstract_factory.hpp +++ b/include/ginkgo/core/base/abstract_factory.hpp @@ -338,14 +338,15 @@ class deferred_factory_parameter { * shared ownership. */ template >::value>* = nullptr> + std::enable_if_t< + std::is_base_of::value && + !(!std::is_const::value && + std::is_const::value)>* = nullptr> deferred_factory_parameter(std::shared_ptr factory) { - generator_ = - [factory = std::shared_ptr(std::move(factory))]( - std::shared_ptr) { return factory; }; + generator_ = [factory = + std::shared_ptr(std::move(factory))]( + std::shared_ptr) { return factory; }; } /** @@ -353,15 +354,16 @@ class deferred_factory_parameter { * preexisting factory with unique ownership. */ template >::value>* = nullptr> + std::enable_if_t< + std::is_base_of::value && + !(!std::is_const::value && + std::is_const::value)>* = nullptr> deferred_factory_parameter( std::unique_ptr factory) { - generator_ = - [factory = std::shared_ptr(std::move(factory))]( - std::shared_ptr) { return factory; }; + generator_ = [factory = + std::shared_ptr(std::move(factory))]( + std::shared_ptr) { return factory; }; } /** @@ -375,17 +377,14 @@ class deferred_factory_parameter { deferred_factory_parameter(ParametersType parameters) { generator_ = [parameters](std::shared_ptr exec) - -> std::shared_ptr { - return parameters.on(exec); - }; + -> std::shared_ptr { return parameters.on(exec); }; } /** * Instantiates the deferred parameter into an actual factory. This will * throw if the deferred factory parameter is empty. */ - std::shared_ptr on( - std::shared_ptr exec) const + std::shared_ptr on(std::shared_ptr exec) const { if (this->is_empty()) { GKO_NOT_SUPPORTED(*this); @@ -397,8 +396,7 @@ class deferred_factory_parameter { bool is_empty() const { return !bool(generator_); } private: - std::function( - std::shared_ptr)> + std::function(std::shared_ptr)> generator_; }; @@ -537,7 +535,7 @@ class deferred_factory_parameter { */ #define GKO_DEFERRED_FACTORY_PARAMETER(_name, _type) \ public: \ - std::shared_ptr _name{}; \ + std::shared_ptr<_type> _name{}; \ parameters_type& with_##_name(deferred_factory_parameter<_type> factory) \ { \ this->_name##_generator_ = std::move(factory); \ @@ -570,7 +568,7 @@ public: \ */ #define GKO_DEFERRED_FACTORY_VECTOR_PARAMETER(_name, _type) \ public: \ - std::vector> _name{}; \ + std::vector> _name{}; \ template \ + template >::value>> \ parameters_type& with_##_name(const std::vector& factories) \ { \ this->_name##_generator_.clear(); \ diff --git a/include/ginkgo/core/distributed/preconditioner/schwarz.hpp b/include/ginkgo/core/distributed/preconditioner/schwarz.hpp index f31bd96aa2e..ce8502c310d 100644 --- a/include/ginkgo/core/distributed/preconditioner/schwarz.hpp +++ b/include/ginkgo/core/distributed/preconditioner/schwarz.hpp @@ -94,7 +94,7 @@ class Schwarz /** * Local solver factory. */ - GKO_DEFERRED_FACTORY_PARAMETER(local_solver, LinOpFactory); + GKO_DEFERRED_FACTORY_PARAMETER(local_solver, const LinOpFactory); }; GKO_ENABLE_LIN_OP_FACTORY(Schwarz, parameters, Factory); GKO_ENABLE_BUILD_METHOD(Factory); diff --git a/include/ginkgo/core/preconditioner/ic.hpp b/include/ginkgo/core/preconditioner/ic.hpp index 97e7fe37871..c0bb4962663 100644 --- a/include/ginkgo/core/preconditioner/ic.hpp +++ b/include/ginkgo/core/preconditioner/ic.hpp @@ -136,13 +136,15 @@ class Ic : public EnableLinOp>, public Transposable { [[deprecated("use with_l_solver instead")]] parameters_type& with_l_solver_factory( - deferred_factory_parameter solver) + deferred_factory_parameter + solver) { return with_l_solver(std::move(solver)); } parameters_type& with_l_solver( - deferred_factory_parameter solver) + deferred_factory_parameter + solver) { this->l_solver_generator = std::move(solver); this->deferred_factories["l_solver"] = [](const auto& exec, @@ -157,13 +159,13 @@ class Ic : public EnableLinOp>, public Transposable { [[deprecated("use with_factorization instead")]] parameters_type& with_factorization_factory( - deferred_factory_parameter factorization) + deferred_factory_parameter factorization) { return with_factorization(std::move(factorization)); } parameters_type& with_factorization( - deferred_factory_parameter factorization) + deferred_factory_parameter factorization) { this->factorization_generator = std::move(factorization); this->deferred_factories["factorization"] = [](const auto& exec, @@ -177,10 +179,10 @@ class Ic : public EnableLinOp>, public Transposable { } private: - deferred_factory_parameter + deferred_factory_parameter l_solver_generator; - deferred_factory_parameter factorization_generator; + deferred_factory_parameter factorization_generator; }; GKO_ENABLE_LIN_OP_FACTORY(Ic, parameters, Factory); diff --git a/include/ginkgo/core/preconditioner/ilu.hpp b/include/ginkgo/core/preconditioner/ilu.hpp index d0f32c18c8c..683e157545c 100644 --- a/include/ginkgo/core/preconditioner/ilu.hpp +++ b/include/ginkgo/core/preconditioner/ilu.hpp @@ -154,13 +154,15 @@ class Ilu : public EnableLinOp< [[deprecated("use with_l_solver instead")]] parameters_type& with_l_solver_factory( - deferred_factory_parameter solver) + deferred_factory_parameter + solver) { return with_l_solver(std::move(solver)); } parameters_type& with_l_solver( - deferred_factory_parameter solver) + deferred_factory_parameter + solver) { this->l_solver_generator = std::move(solver); this->deferred_factories["l_solver"] = [](const auto& exec, @@ -175,13 +177,15 @@ class Ilu : public EnableLinOp< [[deprecated("use with_u_solver instead")]] parameters_type& with_u_solver_factory( - deferred_factory_parameter solver) + deferred_factory_parameter + solver) { return with_u_solver(std::move(solver)); } parameters_type& with_u_solver( - deferred_factory_parameter solver) + deferred_factory_parameter + solver) { this->u_solver_generator = std::move(solver); this->deferred_factories["u_solver"] = [](const auto& exec, @@ -196,13 +200,13 @@ class Ilu : public EnableLinOp< [[deprecated("use with_factorization instead")]] parameters_type& with_factorization_factory( - deferred_factory_parameter factorization) + deferred_factory_parameter factorization) { return with_factorization(std::move(factorization)); } parameters_type& with_factorization( - deferred_factory_parameter factorization) + deferred_factory_parameter factorization) { this->factorization_generator = std::move(factorization); this->deferred_factories["factorization"] = [](const auto& exec, @@ -216,13 +220,13 @@ class Ilu : public EnableLinOp< } private: - deferred_factory_parameter + deferred_factory_parameter l_solver_generator; - deferred_factory_parameter + deferred_factory_parameter u_solver_generator; - deferred_factory_parameter factorization_generator; + deferred_factory_parameter factorization_generator; }; GKO_ENABLE_LIN_OP_FACTORY(Ilu, parameters, Factory); diff --git a/include/ginkgo/core/solver/direct.hpp b/include/ginkgo/core/solver/direct.hpp index ee6783ff96d..c86db46434f 100644 --- a/include/ginkgo/core/solver/direct.hpp +++ b/include/ginkgo/core/solver/direct.hpp @@ -87,7 +87,7 @@ class Direct : public EnableLinOp>, gko::size_type GKO_FACTORY_PARAMETER_SCALAR(num_rhs, 1u); /** The factorization factory to use for generating the factors. */ - GKO_DEFERRED_FACTORY_PARAMETER(factorization, LinOpFactory); + GKO_DEFERRED_FACTORY_PARAMETER(factorization, const LinOpFactory); }; GKO_ENABLE_LIN_OP_FACTORY(Direct, parameters, Factory); GKO_ENABLE_BUILD_METHOD(Factory); diff --git a/include/ginkgo/core/solver/ir.hpp b/include/ginkgo/core/solver/ir.hpp index 468e539f487..7d00d82cbaa 100644 --- a/include/ginkgo/core/solver/ir.hpp +++ b/include/ginkgo/core/solver/ir.hpp @@ -184,7 +184,7 @@ class Ir : public EnableLinOp>, /** * Inner solver factory. */ - GKO_DEFERRED_FACTORY_PARAMETER(solver, LinOpFactory); + GKO_DEFERRED_FACTORY_PARAMETER(solver, const LinOpFactory); /** * Already generated solver. If one is provided, the factory `solver` diff --git a/include/ginkgo/core/solver/multigrid.hpp b/include/ginkgo/core/solver/multigrid.hpp index 21860844d3e..5888ff65813 100644 --- a/include/ginkgo/core/solver/multigrid.hpp +++ b/include/ginkgo/core/solver/multigrid.hpp @@ -225,7 +225,7 @@ class Multigrid : public EnableLinOp, /** * MultigridLevel Factory list */ - GKO_DEFERRED_FACTORY_VECTOR_PARAMETER(mg_level, LinOpFactory); + GKO_DEFERRED_FACTORY_VECTOR_PARAMETER(mg_level, const LinOpFactory); /** * Custom selector size_type (size_type level, const LinOp* fine_matrix) @@ -270,14 +270,15 @@ class Multigrid : public EnableLinOp, * If any element in the vector is a `nullptr` then the smoother * application at the corresponding level is skipped. */ - GKO_DEFERRED_FACTORY_VECTOR_PARAMETER(pre_smoother, LinOpFactory); + GKO_DEFERRED_FACTORY_VECTOR_PARAMETER(pre_smoother, const LinOpFactory); /** * Post-smooth Factory list. * It is similar to Pre-smooth Factory list. It is ignored if * the factory parameter post_uses_pre is set to true. */ - GKO_DEFERRED_FACTORY_VECTOR_PARAMETER(post_smoother, LinOpFactory); + GKO_DEFERRED_FACTORY_VECTOR_PARAMETER(post_smoother, + const LinOpFactory); /** * Mid-smooth Factory list. If it contains available elements, multigrid @@ -286,7 +287,7 @@ class Multigrid : public EnableLinOp, * Pre-smooth Factory list. It is ignored if the factory parameter * mid_case is not mid. */ - GKO_DEFERRED_FACTORY_VECTOR_PARAMETER(mid_smoother, LinOpFactory); + GKO_DEFERRED_FACTORY_VECTOR_PARAMETER(mid_smoother, const LinOpFactory); /** * Whether post-smoothing-related calls use corresponding @@ -326,7 +327,8 @@ class Multigrid : public EnableLinOp, * If not set, then a direct LU solver will be used as solver on the * coarsest level. */ - GKO_DEFERRED_FACTORY_VECTOR_PARAMETER(coarsest_solver, LinOpFactory); + GKO_DEFERRED_FACTORY_VECTOR_PARAMETER(coarsest_solver, + const LinOpFactory); /** * Custom coarsest_solver selector diff --git a/include/ginkgo/core/solver/solver_base.hpp b/include/ginkgo/core/solver/solver_base.hpp index 3888d7fe62d..b27ade844fb 100644 --- a/include/ginkgo/core/solver/solver_base.hpp +++ b/include/ginkgo/core/solver/solver_base.hpp @@ -881,11 +881,14 @@ struct enable_iterative_solver_factory_parameters * Provides stopping criteria via stop::CriterionFactory instances to be * used by the iterative solver in a fluent interface. */ - template + template >...>::value>> Parameters& with_criteria(Args&&... value) { this->criterion_generators = { - deferred_factory_parameter{ + deferred_factory_parameter{ std::forward(value)}...}; this->deferred_factories["criteria"] = [](const auto& exec, auto& params) { @@ -899,10 +902,32 @@ struct enable_iterative_solver_factory_parameters return *self(); } + template >::value>> + Parameters& with_criteria(const std::vector& criteria_vec) + { + this->criterion_generators.clear(); + for (const auto& factory : criteria_vec) { + this->criterion_generators.push_back(factory); + } + this->deferred_factories["criteria"] = [](const auto& exec, + auto& params) { + if (!params.criterion_generators.empty()) { + params.criteria.clear(); + for (auto& generator : params.criterion_generators) { + params.criteria.push_back(generator.on(exec)); + } + } + }; + return *self(); + } + private: GKO_ENABLE_SELF(Parameters); - std::vector> + std::vector> criterion_generators; }; @@ -937,7 +962,7 @@ struct enable_preconditioned_iterative_solver_factory_parameters * @see preconditioned_iterative_solver_factory_parameters::preconditioner */ Parameters& with_preconditioner( - deferred_factory_parameter preconditioner) + deferred_factory_parameter preconditioner) { this->preconditioner_generator = std::move(preconditioner); this->deferred_factories["preconditioner"] = [](const auto& exec, @@ -965,7 +990,7 @@ struct enable_preconditioned_iterative_solver_factory_parameters private: GKO_ENABLE_SELF(Parameters); - deferred_factory_parameter preconditioner_generator; + deferred_factory_parameter preconditioner_generator; };