Skip to content

Commit

Permalink
make the type in deferred_factory explicitly
Browse files Browse the repository at this point in the history
  • Loading branch information
yhmtsai committed Oct 19, 2023
1 parent 0b5eaf8 commit f062dfc
Show file tree
Hide file tree
Showing 8 changed files with 83 additions and 50 deletions.
44 changes: 22 additions & 22 deletions include/ginkgo/core/base/abstract_factory.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -338,30 +338,32 @@ class deferred_factory_parameter {
* shared ownership.
*/
template <typename ConcreteFactoryType,
std::enable_if_t<std::is_base_of<
FactoryType,
std::remove_const_t<ConcreteFactoryType>>::value>* = nullptr>
std::enable_if_t<
std::is_base_of<FactoryType, ConcreteFactoryType>::value &&
!(!std::is_const<FactoryType>::value &&
std::is_const<ConcreteFactoryType>::value)>* = nullptr>
deferred_factory_parameter(std::shared_ptr<ConcreteFactoryType> factory)
{
generator_ =
[factory = std::shared_ptr<const FactoryType>(std::move(factory))](
std::shared_ptr<const Executor>) { return factory; };
generator_ = [factory =
std::shared_ptr<FactoryType>(std::move(factory))](
std::shared_ptr<const Executor>) { return factory; };
}

/**
* Creates a deferred factory parameter by taking ownership of a
* preexisting factory with unique ownership.
*/
template <typename ConcreteFactoryType, typename Deleter,
std::enable_if_t<std::is_base_of<
FactoryType,
std::remove_const_t<ConcreteFactoryType>>::value>* = nullptr>
std::enable_if_t<
std::is_base_of<FactoryType, ConcreteFactoryType>::value &&
!(!std::is_const<FactoryType>::value &&
std::is_const<ConcreteFactoryType>::value)>* = nullptr>
deferred_factory_parameter(
std::unique_ptr<ConcreteFactoryType, Deleter> factory)
{
generator_ =
[factory = std::shared_ptr<const FactoryType>(std::move(factory))](
std::shared_ptr<const Executor>) { return factory; };
generator_ = [factory =
std::shared_ptr<FactoryType>(std::move(factory))](
std::shared_ptr<const Executor>) { return factory; };
}

/**
Expand All @@ -375,17 +377,14 @@ class deferred_factory_parameter {
deferred_factory_parameter(ParametersType parameters)
{
generator_ = [parameters](std::shared_ptr<const Executor> exec)
-> std::shared_ptr<const FactoryType> {
return parameters.on(exec);
};
-> std::shared_ptr<FactoryType> { return parameters.on(exec); };
}

/**
* Instantiates the deferred parameter into an actual factory. This will
* throw if the deferred factory parameter is empty.
*/
std::shared_ptr<const FactoryType> on(
std::shared_ptr<const Executor> exec) const
std::shared_ptr<FactoryType> on(std::shared_ptr<const Executor> exec) const
{
if (this->is_empty()) {
GKO_NOT_SUPPORTED(*this);
Expand All @@ -397,8 +396,7 @@ class deferred_factory_parameter {
bool is_empty() const { return !bool(generator_); }

private:
std::function<std::shared_ptr<const FactoryType>(
std::shared_ptr<const Executor>)>
std::function<std::shared_ptr<FactoryType>(std::shared_ptr<const Executor>)>
generator_;
};

Expand Down Expand Up @@ -537,7 +535,7 @@ class deferred_factory_parameter {
*/
#define GKO_DEFERRED_FACTORY_PARAMETER(_name, _type) \
public: \
std::shared_ptr<const _type> _name{}; \
std::shared_ptr<_type> _name{}; \
parameters_type& with_##_name(deferred_factory_parameter<_type> factory) \
{ \
this->_name##_generator_ = std::move(factory); \
Expand Down Expand Up @@ -570,7 +568,7 @@ public: \
*/
#define GKO_DEFERRED_FACTORY_VECTOR_PARAMETER(_name, _type) \
public: \
std::vector<std::shared_ptr<const _type>> _name{}; \
std::vector<std::shared_ptr<_type>> _name{}; \
template <typename... Args, \
typename = \
std::enable_if_t<xstd::conjunction<std::is_convertible< \
Expand All @@ -590,7 +588,9 @@ public: \
}; \
return *this; \
} \
template <typename FactoryType> \
template <typename FactoryType, \
typename = std::enable_if_t<std::is_convertible< \
FactoryType, deferred_factory_parameter<_type>>::value>> \
parameters_type& with_##_name(const std::vector<FactoryType>& factories) \
{ \
this->_name##_generator_.clear(); \
Expand Down
2 changes: 1 addition & 1 deletion include/ginkgo/core/distributed/preconditioner/schwarz.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ class Schwarz
/**
* Local solver factory.
*/
GKO_DEFERRED_FACTORY_PARAMETER(local_solver, LinOpFactory);
GKO_DEFERRED_FACTORY_PARAMETER(local_solver, const LinOpFactory);
};
GKO_ENABLE_LIN_OP_FACTORY(Schwarz, parameters, Factory);
GKO_ENABLE_BUILD_METHOD(Factory);
Expand Down
14 changes: 8 additions & 6 deletions include/ginkgo/core/preconditioner/ic.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -136,13 +136,15 @@ class Ic : public EnableLinOp<Ic<LSolverType, IndexType>>, public Transposable {

[[deprecated("use with_l_solver instead")]] parameters_type&
with_l_solver_factory(
deferred_factory_parameter<typename l_solver_type::Factory> solver)
deferred_factory_parameter<const typename l_solver_type::Factory>
solver)
{
return with_l_solver(std::move(solver));
}

parameters_type& with_l_solver(
deferred_factory_parameter<typename l_solver_type::Factory> solver)
deferred_factory_parameter<const typename l_solver_type::Factory>
solver)
{
this->l_solver_generator = std::move(solver);
this->deferred_factories["l_solver"] = [](const auto& exec,
Expand All @@ -157,13 +159,13 @@ class Ic : public EnableLinOp<Ic<LSolverType, IndexType>>, public Transposable {

[[deprecated("use with_factorization instead")]] parameters_type&
with_factorization_factory(
deferred_factory_parameter<LinOpFactory> factorization)
deferred_factory_parameter<const LinOpFactory> factorization)
{
return with_factorization(std::move(factorization));
}

parameters_type& with_factorization(
deferred_factory_parameter<LinOpFactory> factorization)
deferred_factory_parameter<const LinOpFactory> factorization)
{
this->factorization_generator = std::move(factorization);
this->deferred_factories["factorization"] = [](const auto& exec,
Expand All @@ -177,10 +179,10 @@ class Ic : public EnableLinOp<Ic<LSolverType, IndexType>>, public Transposable {
}

private:
deferred_factory_parameter<typename l_solver_type::Factory>
deferred_factory_parameter<const typename l_solver_type::Factory>
l_solver_generator;

deferred_factory_parameter<LinOpFactory> factorization_generator;
deferred_factory_parameter<const LinOpFactory> factorization_generator;
};

GKO_ENABLE_LIN_OP_FACTORY(Ic, parameters, Factory);
Expand Down
22 changes: 13 additions & 9 deletions include/ginkgo/core/preconditioner/ilu.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -154,13 +154,15 @@ class Ilu : public EnableLinOp<

[[deprecated("use with_l_solver instead")]] parameters_type&
with_l_solver_factory(
deferred_factory_parameter<typename l_solver_type::Factory> solver)
deferred_factory_parameter<const typename l_solver_type::Factory>
solver)
{
return with_l_solver(std::move(solver));
}

parameters_type& with_l_solver(
deferred_factory_parameter<typename l_solver_type::Factory> solver)
deferred_factory_parameter<const typename l_solver_type::Factory>
solver)
{
this->l_solver_generator = std::move(solver);
this->deferred_factories["l_solver"] = [](const auto& exec,
Expand All @@ -175,13 +177,15 @@ class Ilu : public EnableLinOp<

[[deprecated("use with_u_solver instead")]] parameters_type&
with_u_solver_factory(
deferred_factory_parameter<typename u_solver_type::Factory> solver)
deferred_factory_parameter<const typename u_solver_type::Factory>
solver)
{
return with_u_solver(std::move(solver));
}

parameters_type& with_u_solver(
deferred_factory_parameter<typename u_solver_type::Factory> solver)
deferred_factory_parameter<const typename u_solver_type::Factory>
solver)
{
this->u_solver_generator = std::move(solver);
this->deferred_factories["u_solver"] = [](const auto& exec,
Expand All @@ -196,13 +200,13 @@ class Ilu : public EnableLinOp<

[[deprecated("use with_factorization instead")]] parameters_type&
with_factorization_factory(
deferred_factory_parameter<LinOpFactory> factorization)
deferred_factory_parameter<const LinOpFactory> factorization)
{
return with_factorization(std::move(factorization));
}

parameters_type& with_factorization(
deferred_factory_parameter<LinOpFactory> factorization)
deferred_factory_parameter<const LinOpFactory> factorization)
{
this->factorization_generator = std::move(factorization);
this->deferred_factories["factorization"] = [](const auto& exec,
Expand All @@ -216,13 +220,13 @@ class Ilu : public EnableLinOp<
}

private:
deferred_factory_parameter<typename l_solver_type::Factory>
deferred_factory_parameter<const typename l_solver_type::Factory>
l_solver_generator;

deferred_factory_parameter<typename u_solver_type::Factory>
deferred_factory_parameter<const typename u_solver_type::Factory>
u_solver_generator;

deferred_factory_parameter<LinOpFactory> factorization_generator;
deferred_factory_parameter<const LinOpFactory> factorization_generator;
};

GKO_ENABLE_LIN_OP_FACTORY(Ilu, parameters, Factory);
Expand Down
2 changes: 1 addition & 1 deletion include/ginkgo/core/solver/direct.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ class Direct : public EnableLinOp<Direct<ValueType, IndexType>>,
gko::size_type GKO_FACTORY_PARAMETER_SCALAR(num_rhs, 1u);

/** The factorization factory to use for generating the factors. */
GKO_DEFERRED_FACTORY_PARAMETER(factorization, LinOpFactory);
GKO_DEFERRED_FACTORY_PARAMETER(factorization, const LinOpFactory);
};
GKO_ENABLE_LIN_OP_FACTORY(Direct, parameters, Factory);
GKO_ENABLE_BUILD_METHOD(Factory);
Expand Down
2 changes: 1 addition & 1 deletion include/ginkgo/core/solver/ir.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ class Ir : public EnableLinOp<Ir<ValueType>>,
/**
* Inner solver factory.
*/
GKO_DEFERRED_FACTORY_PARAMETER(solver, LinOpFactory);
GKO_DEFERRED_FACTORY_PARAMETER(solver, const LinOpFactory);

/**
* Already generated solver. If one is provided, the factory `solver`
Expand Down
12 changes: 7 additions & 5 deletions include/ginkgo/core/solver/multigrid.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ class Multigrid : public EnableLinOp<Multigrid>,
/**
* MultigridLevel Factory list
*/
GKO_DEFERRED_FACTORY_VECTOR_PARAMETER(mg_level, LinOpFactory);
GKO_DEFERRED_FACTORY_VECTOR_PARAMETER(mg_level, const LinOpFactory);

/**
* Custom selector size_type (size_type level, const LinOp* fine_matrix)
Expand Down Expand Up @@ -270,14 +270,15 @@ class Multigrid : public EnableLinOp<Multigrid>,
* If any element in the vector is a `nullptr` then the smoother
* application at the corresponding level is skipped.
*/
GKO_DEFERRED_FACTORY_VECTOR_PARAMETER(pre_smoother, LinOpFactory);
GKO_DEFERRED_FACTORY_VECTOR_PARAMETER(pre_smoother, const LinOpFactory);

/**
* Post-smooth Factory list.
* It is similar to Pre-smooth Factory list. It is ignored if
* the factory parameter post_uses_pre is set to true.
*/
GKO_DEFERRED_FACTORY_VECTOR_PARAMETER(post_smoother, LinOpFactory);
GKO_DEFERRED_FACTORY_VECTOR_PARAMETER(post_smoother,
const LinOpFactory);

/**
* Mid-smooth Factory list. If it contains available elements, multigrid
Expand All @@ -286,7 +287,7 @@ class Multigrid : public EnableLinOp<Multigrid>,
* Pre-smooth Factory list. It is ignored if the factory parameter
* mid_case is not mid.
*/
GKO_DEFERRED_FACTORY_VECTOR_PARAMETER(mid_smoother, LinOpFactory);
GKO_DEFERRED_FACTORY_VECTOR_PARAMETER(mid_smoother, const LinOpFactory);

/**
* Whether post-smoothing-related calls use corresponding
Expand Down Expand Up @@ -326,7 +327,8 @@ class Multigrid : public EnableLinOp<Multigrid>,
* If not set, then a direct LU solver will be used as solver on the
* coarsest level.
*/
GKO_DEFERRED_FACTORY_VECTOR_PARAMETER(coarsest_solver, LinOpFactory);
GKO_DEFERRED_FACTORY_VECTOR_PARAMETER(coarsest_solver,
const LinOpFactory);

/**
* Custom coarsest_solver selector
Expand Down
35 changes: 30 additions & 5 deletions include/ginkgo/core/solver/solver_base.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -881,11 +881,14 @@ struct enable_iterative_solver_factory_parameters
* Provides stopping criteria via stop::CriterionFactory instances to be
* used by the iterative solver in a fluent interface.
*/
template <typename... Args>
template <typename... Args,
typename = std::enable_if_t<xstd::conjunction<std::is_convertible<
Args, deferred_factory_parameter<
const stop::CriterionFactory>>...>::value>>
Parameters& with_criteria(Args&&... value)
{
this->criterion_generators = {
deferred_factory_parameter<stop::CriterionFactory>{
deferred_factory_parameter<const stop::CriterionFactory>{
std::forward<Args>(value)}...};
this->deferred_factories["criteria"] = [](const auto& exec,
auto& params) {
Expand All @@ -899,10 +902,32 @@ struct enable_iterative_solver_factory_parameters
return *self();
}

template <typename FactoryType,
typename = std::enable_if_t<std::is_convertible<
FactoryType, deferred_factory_parameter<
const stop::CriterionFactory>>::value>>
Parameters& with_criteria(const std::vector<FactoryType>& criteria_vec)
{
this->criterion_generators.clear();
for (const auto& factory : criteria_vec) {
this->criterion_generators.push_back(factory);
}
this->deferred_factories["criteria"] = [](const auto& exec,
auto& params) {
if (!params.criterion_generators.empty()) {
params.criteria.clear();
for (auto& generator : params.criterion_generators) {
params.criteria.push_back(generator.on(exec));
}
}
};
return *self();
}

private:
GKO_ENABLE_SELF(Parameters);

std::vector<deferred_factory_parameter<stop::CriterionFactory>>
std::vector<deferred_factory_parameter<const stop::CriterionFactory>>
criterion_generators;
};

Expand Down Expand Up @@ -937,7 +962,7 @@ struct enable_preconditioned_iterative_solver_factory_parameters
* @see preconditioned_iterative_solver_factory_parameters::preconditioner
*/
Parameters& with_preconditioner(
deferred_factory_parameter<LinOpFactory> preconditioner)
deferred_factory_parameter<const LinOpFactory> preconditioner)
{
this->preconditioner_generator = std::move(preconditioner);
this->deferred_factories["preconditioner"] = [](const auto& exec,
Expand Down Expand Up @@ -965,7 +990,7 @@ struct enable_preconditioned_iterative_solver_factory_parameters
private:
GKO_ENABLE_SELF(Parameters);

deferred_factory_parameter<LinOpFactory> preconditioner_generator;
deferred_factory_parameter<const LinOpFactory> preconditioner_generator;
};


Expand Down

0 comments on commit f062dfc

Please sign in to comment.