Skip to content

Commit

Permalink
config key check for criterion
Browse files Browse the repository at this point in the history
  • Loading branch information
yhmtsai committed Feb 18, 2025
1 parent 88efcd8 commit 2ea21ad
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 1 deletion.
12 changes: 12 additions & 0 deletions core/config/config_helper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,18 @@ parse_minimal_criteria(const pnode& config, const registry& context,

type_descriptor updated_td = update_type(config, td);

// Although it can still be caught in the following map, it gives consistent
// error exception when some keys are not allowed.
std::set<std::string> allowed_keys{"time",
"iteration",
"relative_residual_norm",
"initial_residual_norm",
"absolute_residual_norm",
"relative_implicit_residual_norm",
"initial_implicit_residual_norm",
"absolute_implicit_residual_norm"};
check_allowed_keys(config, allowed_keys);

std::vector<deferred_factory_parameter<const stop::CriterionFactory>> res;
for (const auto& it : config.get_map()) {
if (it.first == "value_type") {
Expand Down
13 changes: 12 additions & 1 deletion core/config/stop_config.cpp
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors
// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors
//
// SPDX-License-Identifier: BSD-3-Clause

#include "core/config/stop_config.hpp"

#include <set>
#include <string>

#include <ginkgo/core/base/exception_helpers.hpp>
#include <ginkgo/core/config/config.hpp>
#include <ginkgo/core/config/registry.hpp>
Expand All @@ -27,6 +30,8 @@ namespace config {
deferred_factory_parameter<stop::CriterionFactory> configure_time(
const pnode& config, const registry& context, const type_descriptor& td)
{
std::set<std::string> allowed_keys{"time_limit"};
gko::config::check_allowed_keys(config, allowed_keys);
auto params = stop::Time::build();
if (auto& obj = config.get("time_limit")) {
params.with_time_limit(gko::config::get_value<long long int>(obj));
Expand All @@ -38,6 +43,8 @@ deferred_factory_parameter<stop::CriterionFactory> configure_time(
deferred_factory_parameter<stop::CriterionFactory> configure_iter(
const pnode& config, const registry& context, const type_descriptor& td)
{
std::set<std::string> allowed_keys{"max_iters"};
gko::config::check_allowed_keys(config, allowed_keys);
auto params = stop::Iteration::build();
if (auto& obj = config.get("max_iters")) {
params.with_max_iters(gko::config::get_value<size_type>(obj));
Expand Down Expand Up @@ -68,6 +75,8 @@ class ResidualNormConfigurer {
const gko::config::registry& context,
const gko::config::type_descriptor& td_for_child)
{
std::set<std::string> allowed_keys{"reduction_factor", "baseline"};
gko::config::check_allowed_keys(config, allowed_keys);
auto params = stop::ResidualNorm<ValueType>::build();
if (auto& obj = config.get("reduction_factor")) {
params.with_reduction_factor(
Expand Down Expand Up @@ -100,6 +109,8 @@ class ImplicitResidualNormConfigurer {
const gko::config::registry& context,
const gko::config::type_descriptor& td_for_child)
{
std::set<std::string> allowed_keys{"reduction_factor", "baseline"};
gko::config::check_allowed_keys(config, allowed_keys);
auto params = stop::ImplicitResidualNorm<ValueType>::build();
if (auto& obj = config.get("reduction_factor")) {
params.with_reduction_factor(
Expand Down
33 changes: 33 additions & 0 deletions core/test/config/config.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

#include <gtest/gtest.h>

#include <ginkgo/core/base/exception.hpp>
#include <ginkgo/core/base/executor.hpp>
#include <ginkgo/core/config/config.hpp>
#include <ginkgo/core/matrix/dense.hpp>
Expand Down Expand Up @@ -123,6 +124,26 @@ TEST_F(Config, GenerateObjectWithCustomBuild)
}


TEST_F(Config, ThrowWhenKeyIsInvalidInCriterion)
{
auto reg = registry();
reg.emplace("precond", this->mtx);

for (const auto& stop :
{"Time", "Iteration", "ResidualNorm", "ImplicitResidualNorm"}) {
pnode stop_config{
{{"type", pnode{stop}}, {"invalid_key", pnode{"no"}}}};
pnode p{{{"generated_preconditioner", pnode{"precond"}},
{"criteria", stop_config}}};

ASSERT_THROW(parse<LinOpFactoryType::Cg>(
p, reg, type_descriptor{"float32", "void"})
.on(this->exec),
gko::InvalidStateError);
}
}


TEST_F(Config, GenerateCriteriaFromMinimalConfig)
{
// the map is ordered, since this allows for easier comparison in the test
Expand Down Expand Up @@ -256,6 +277,18 @@ TEST_F(Config, GenerateCriteriaFromMinimalConfigWithValueType)
}


TEST_F(Config, MinimalConfigThrowWhenKeyIsInvalid)
{
pnode minimal_stop{{{"time", pnode{100}}, {"invalid", pnode{"no"}}}};
pnode p{{{"criteria", minimal_stop}}};

ASSERT_THROW(parse<LinOpFactoryType::Cg>(p, registry(),
type_descriptor{"float32", "void"})
.on(this->exec),
gko::InvalidStateError);
}


TEST(GetValue, IndexType)
{
long long int value = 123;
Expand Down

0 comments on commit 2ea21ad

Please sign in to comment.