diff --git a/nni/common/hpo_utils.py b/nni/common/hpo_utils.py index f99e97eff1..3fb8cc1b6f 100644 --- a/nni/common/hpo_utils.py +++ b/nni/common/hpo_utils.py @@ -47,6 +47,11 @@ def validate_search_space( raise ValueError(f'search space "{name}"\'s value is not a list : {spec}') if type_ == 'choice': + if not all(isinstance(arg, (float, int, str)) for arg in args): + # FIXME: need further check for each algorithm which types are actually supported + # for now validation only prints warning so it doesn't harm + if not isinstance(args[0], dict) or '_name' not in args[0]: # not nested search space + raise ValueError(f'search space "{name}" (choice) should only contain numbers or strings : {spec}') continue if type_.startswith('q'): diff --git a/test/ut/sdk/test_hpo_utils.py b/test/ut/sdk/test_hpo_utils.py index e0f6d12294..c1204d0b9a 100644 --- a/test/ut/sdk/test_hpo_utils.py +++ b/test/ut/sdk/test_hpo_utils.py @@ -16,12 +16,22 @@ 'choice': good['choice'], 'randint': good['randint'], } +good_nested = { + 'outer': { + '_type': 'choice', + '_value': [ + { '_name': 'empty' }, + { '_name': 'a', 'a_1': { '_type': 'choice', '_value': ['a', 'b'] } } + ] + } +} bad_type = 'x' bad_spec_type = { 'x': [1, 2, 3] } bad_fields = { 'x': { 'type': 'choice', 'value': ['a', 'b'] } } bad_type_name = { 'x': { '_type': 'choic', '_value': ['a'] } } bad_value = { 'x': { '_type': 'choice', '_value': 'ab' } } +bad_choice_args = { 'x': { '_type': 'choice', 'value': [ 'a', object() ] } } bad_2_args = { 'x': { '_type': 'randint', '_value': [1, 2, 3] } } bad_3_args = { 'x': { '_type': 'quniform', '_value': [0] } } bad_int_args = { 'x': { '_type': 'randint', '_value': [1.0, 2.0] } } @@ -32,11 +42,13 @@ def test_hpo_utils(): assert validate_search_space(good, raise_exception=False) + assert validate_search_space(good_nested, raise_exception=False) assert not validate_search_space(bad_type, raise_exception=False) assert not validate_search_space(bad_spec_type, raise_exception=False) assert not validate_search_space(bad_fields, raise_exception=False) assert not validate_search_space(bad_type_name, raise_exception=False) assert not validate_search_space(bad_value, raise_exception=False) + assert not validate_search_space(bad_choice_args, raise_exception=False) assert not validate_search_space(bad_2_args, raise_exception=False) assert not validate_search_space(bad_3_args, raise_exception=False) assert not validate_search_space(bad_int_args, raise_exception=False)