From 06123ac4e27d78ddfd5b5c8c90876d39751b96cf Mon Sep 17 00:00:00 2001 From: liuzhe Date: Fri, 23 Jul 2021 12:15:26 +0800 Subject: [PATCH 1/3] add search space validation for choice types --- nni/common/hpo_utils.py | 4 ++++ test/ut/sdk/test_hpo_utils.py | 2 ++ 2 files changed, 6 insertions(+) diff --git a/nni/common/hpo_utils.py b/nni/common/hpo_utils.py index f99e97eff1..58f2dc499f 100644 --- a/nni/common/hpo_utils.py +++ b/nni/common/hpo_utils.py @@ -47,6 +47,10 @@ def validate_search_space( raise ValueError(f'search space "{name}"\'s value is not a list : {spec}') if type_ == 'choice': + if not all(isinstance(arg, (float, int, str)) for arg in args): + # FIXME: need further check for each algorithm which types are actually supported + # for now validation only prints warning so it doesn't harm + raise ValueError(f'search space "{name}" (choice) should only contain numbers or strings') continue if type_.startswith('q'): diff --git a/test/ut/sdk/test_hpo_utils.py b/test/ut/sdk/test_hpo_utils.py index e0f6d12294..f562ca43af 100644 --- a/test/ut/sdk/test_hpo_utils.py +++ b/test/ut/sdk/test_hpo_utils.py @@ -22,6 +22,7 @@ bad_fields = { 'x': { 'type': 'choice', 'value': ['a', 'b'] } } bad_type_name = { 'x': { '_type': 'choic', '_value': ['a'] } } bad_value = { 'x': { '_type': 'choice', '_value': 'ab' } } +bad_choice_args = { 'x': { '_type': 'choice', 'value': [ 'a', object() ] } } bad_2_args = { 'x': { '_type': 'randint', '_value': [1, 2, 3] } } bad_3_args = { 'x': { '_type': 'quniform', '_value': [0] } } bad_int_args = { 'x': { '_type': 'randint', '_value': [1.0, 2.0] } } @@ -37,6 +38,7 @@ def test_hpo_utils(): assert not validate_search_space(bad_fields, raise_exception=False) assert not validate_search_space(bad_type_name, raise_exception=False) assert not validate_search_space(bad_value, raise_exception=False) + assert not validate_search_space(bad_choice_args, raise_exception=False) assert not validate_search_space(bad_2_args, raise_exception=False) assert not validate_search_space(bad_3_args, raise_exception=False) assert not validate_search_space(bad_int_args, raise_exception=False) From 44cf25de75504fc1b07f5a094985d389145223b6 Mon Sep 17 00:00:00 2001 From: liuzhe Date: Fri, 23 Jul 2021 12:17:50 +0800 Subject: [PATCH 2/3] add message --- nni/common/hpo_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nni/common/hpo_utils.py b/nni/common/hpo_utils.py index 58f2dc499f..22f1e77110 100644 --- a/nni/common/hpo_utils.py +++ b/nni/common/hpo_utils.py @@ -50,7 +50,7 @@ def validate_search_space( if not all(isinstance(arg, (float, int, str)) for arg in args): # FIXME: need further check for each algorithm which types are actually supported # for now validation only prints warning so it doesn't harm - raise ValueError(f'search space "{name}" (choice) should only contain numbers or strings') + raise ValueError(f'search space "{name}" (choice) should only contain numbers or strings : {spec}') continue if type_.startswith('q'): From c6f4528d2c25988c7e1e463b33946929812c2a6b Mon Sep 17 00:00:00 2001 From: liuzhe Date: Mon, 26 Jul 2021 10:48:42 +0800 Subject: [PATCH 3/3] fix nested --- nni/common/hpo_utils.py | 3 ++- test/ut/sdk/test_hpo_utils.py | 10 ++++++++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/nni/common/hpo_utils.py b/nni/common/hpo_utils.py index 22f1e77110..3fb8cc1b6f 100644 --- a/nni/common/hpo_utils.py +++ b/nni/common/hpo_utils.py @@ -50,7 +50,8 @@ def validate_search_space( if not all(isinstance(arg, (float, int, str)) for arg in args): # FIXME: need further check for each algorithm which types are actually supported # for now validation only prints warning so it doesn't harm - raise ValueError(f'search space "{name}" (choice) should only contain numbers or strings : {spec}') + if not isinstance(args[0], dict) or '_name' not in args[0]: # not nested search space + raise ValueError(f'search space "{name}" (choice) should only contain numbers or strings : {spec}') continue if type_.startswith('q'): diff --git a/test/ut/sdk/test_hpo_utils.py b/test/ut/sdk/test_hpo_utils.py index f562ca43af..c1204d0b9a 100644 --- a/test/ut/sdk/test_hpo_utils.py +++ b/test/ut/sdk/test_hpo_utils.py @@ -16,6 +16,15 @@ 'choice': good['choice'], 'randint': good['randint'], } +good_nested = { + 'outer': { + '_type': 'choice', + '_value': [ + { '_name': 'empty' }, + { '_name': 'a', 'a_1': { '_type': 'choice', '_value': ['a', 'b'] } } + ] + } +} bad_type = 'x' bad_spec_type = { 'x': [1, 2, 3] } @@ -33,6 +42,7 @@ def test_hpo_utils(): assert validate_search_space(good, raise_exception=False) + assert validate_search_space(good_nested, raise_exception=False) assert not validate_search_space(bad_type, raise_exception=False) assert not validate_search_space(bad_spec_type, raise_exception=False) assert not validate_search_space(bad_fields, raise_exception=False)