Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Commit

Permalink
Add simple HPO search space validation (#3877)
Browse files Browse the repository at this point in the history
  • Loading branch information
liuzhe-lz authored Jun 30, 2021
1 parent 749a463 commit 32fdd32
Show file tree
Hide file tree
Showing 10 changed files with 144 additions and 0 deletions.
2 changes: 2 additions & 0 deletions nni/algorithms/hpo/batch_tuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ class BatchTuner
import logging

import nni
from nni.common.hpo_utils import validate_search_space
from nni.tuner import Tuner

TYPE = '_type'
Expand Down Expand Up @@ -75,6 +76,7 @@ def update_search_space(self, search_space):
----------
search_space : dict
"""
validate_search_space(search_space, ['choice'])
self._values = self.is_valid(search_space)

def generate_parameters(self, parameter_id, **kwargs):
Expand Down
2 changes: 2 additions & 0 deletions nni/algorithms/hpo/dngo_tuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import nni.parameter_expressions as parameter_expressions
from nni import ClassArgsValidator
from nni.common.hpo_utils import validate_search_space
from nni.tuner import Tuner

_logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -86,6 +87,7 @@ def generate_parameters(self, parameter_id, **kwargs):
return new_x

def update_search_space(self, search_space):
validate_search_space(search_space, ['choice', 'randint', 'uniform', 'quniform', 'loguniform', 'qloguniform'])
self.searchspace_json = search_space
self.random_state = np.random.RandomState()

Expand Down
2 changes: 2 additions & 0 deletions nni/algorithms/hpo/gp_tuner/gp_tuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from sklearn.gaussian_process import GaussianProcessRegressor

from nni import ClassArgsValidator
from nni.common.hpo_utils import validate_search_space
from nni.tuner import Tuner
from nni.utils import OptimizeMode, extract_scalar_reward

Expand Down Expand Up @@ -103,6 +104,7 @@ def update_search_space(self, search_space):
Override of the abstract method in :class:`~nni.tuner.Tuner`.
"""
validate_search_space(search_space, ['choice', 'randint', 'uniform', 'quniform', 'loguniform', 'qloguniform'])
self._space = TargetSpace(search_space, self._random_state)

def generate_parameters(self, parameter_id, **kwargs):
Expand Down
2 changes: 2 additions & 0 deletions nni/algorithms/hpo/gridsearch_tuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ class GridSearchTuner
import numpy as np

import nni
from nni.common.hpo_utils import validate_search_space
from nni.tuner import Tuner
from nni.utils import convert_dict2tuple

Expand Down Expand Up @@ -144,6 +145,7 @@ def update_search_space(self, search_space):
search_space : dict
The format could be referred to search space spec (https://nni.readthedocs.io/en/latest/Tutorial/SearchSpaceSpec.html).
"""
validate_search_space(search_space, ['choice', 'randint', 'quniform'])
self.expanded_search_space = self._json2parameter(search_space)

def generate_parameters(self, parameter_id, **kwargs):
Expand Down
2 changes: 2 additions & 0 deletions nni/algorithms/hpo/hyperband_advisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from schema import Schema, Optional

from nni import ClassArgsValidator
from nni.common.hpo_utils import validate_search_space
from nni.runtime.common import multi_phase_enabled
from nni.runtime.msg_dispatcher_base import MsgDispatcherBase
from nni.runtime.protocol import CommandType, send
Expand Down Expand Up @@ -379,6 +380,7 @@ def _get_one_trial_job(self):
def handle_update_search_space(self, data):
"""data: JSON object, which is search space
"""
validate_search_space(data)
self.searchspace_json = data
self.random_state = np.random.RandomState()

Expand Down
2 changes: 2 additions & 0 deletions nni/algorithms/hpo/hyperopt_tuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import numpy as np
from schema import Optional, Schema
from nni import ClassArgsValidator
from nni.common.hpo_utils import validate_search_space
from nni.tuner import Tuner
from nni.utils import NodeType, OptimizeMode, extract_scalar_reward, split_index

Expand Down Expand Up @@ -246,6 +247,7 @@ def update_search_space(self, search_space):
----------
search_space : dict
"""
validate_search_space(search_space)
self.json = search_space

search_space_instance = json2space(self.json)
Expand Down
3 changes: 3 additions & 0 deletions nni/algorithms/hpo/metis_tuner/metis_tuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from nni import ClassArgsValidator
from nni.tuner import Tuner
from nni.common.hpo_utils import validate_search_space
from nni.utils import OptimizeMode, extract_scalar_reward
from . import lib_constraint_summation
from . import lib_data
Expand Down Expand Up @@ -152,6 +153,8 @@ def update_search_space(self, search_space):
----------
search_space : dict
"""
validate_search_space(search_space, ['choice', 'randint', 'uniform', 'quniform'])

self.x_bounds = [[] for i in range(len(search_space))]
self.x_types = [NONE_TYPE for i in range(len(search_space))]

Expand Down
2 changes: 2 additions & 0 deletions nni/algorithms/hpo/smac_tuner/smac_tuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

import nni
from nni import ClassArgsValidator
from nni.common.hpo_utils import validate_search_space
from nni.tuner import Tuner
from nni.utils import OptimizeMode, extract_scalar_reward

Expand Down Expand Up @@ -143,6 +144,7 @@ def update_search_space(self, search_space):
The format could be referred to search space spec (https://nni.readthedocs.io/en/latest/Tutorial/SearchSpaceSpec.html).
"""
self.logger.info('update search space in SMAC.')
validate_search_space(search_space, ['choice', 'randint', 'uniform', 'quniform', 'loguniform'])
if not self.update_ss_done:
self.categorical_dict = generate_scenario(search_space)
if self.categorical_dict is None:
Expand Down
75 changes: 75 additions & 0 deletions nni/common/hpo_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
import logging
from typing import Any, List, Optional

common_search_space_types = [
'choice',
'randint',
'uniform',
'quniform',
'loguniform',
'qloguniform',
'normal',
'qnormal',
'lognormal',
'qlognormal',
]

def validate_search_space(
search_space: Any,
support_types: Optional[List[str]] = None,
raise_exception: bool = False # for now, in case false positive
) -> bool:

if not raise_exception:
try:
validate_search_space(search_space, support_types, True)
return True
except ValueError as e:
logging.getLogger(__name__).error(e.args[0])
return False

if support_types is None:
support_types = common_search_space_types

if not isinstance(search_space, dict):
raise ValueError(f'search space is a {type(search_space).__name__}, expect a dict : {repr(search_space)}')

for name, spec in search_space.items():
if not isinstance(spec, dict):
raise ValueError(f'search space "{name}" is a {type(spec).__name__}, expect a dict : {repr(spec)}')
if '_type' not in spec or '_value' not in spec:
raise ValueError(f'search space "{name}" does not have "_type" or "_value" : {spec}')
type_ = spec['_type']
if type_ not in support_types:
raise ValueError(f'search space "{name}" has unsupported type "{type_}" : {spec}')
args = spec['_value']
if not isinstance(args, list):
raise ValueError(f'search space "{name}"\'s value is not a list : {spec}')

if type_ == 'choice':
continue

if type_.startswith('q'):
if len(args) != 3:
raise ValueError(f'search space "{name}" ({type_}) must have 3 values : {spec}')
else:
if len(args) != 2:
raise ValueError(f'search space "{name}" ({type_}) must have 2 values : {spec}')

if type_ == 'randint':
if not all(isinstance(arg, int) for arg in args):
raise ValueError(f'search space "{name}" ({type_}) must have int values : {spec}')
else:
if not all(isinstance(arg, (float, int)) for arg in args):
raise ValueError(f'search space "{name}" ({type_}) must have float values : {spec}')

if 'normal' not in type_:
if args[0] >= args[1]:
raise ValueError(f'search space "{name}" ({type_}) must have high > low : {spec}')
if 'log' in type_ and args[0] <= 0:
raise ValueError(f'search space "{name}" ({type_}) must have low > 0 : {spec}')
else:
if args[1] <= 0:
raise ValueError(f'search space "{name}" ({type_}) must have sigma > 0 : {spec}')

return True
52 changes: 52 additions & 0 deletions test/ut/sdk/test_hpo_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
from nni.common.hpo_utils import validate_search_space

good = {
'choice': { '_type': 'choice', '_value': ['a', 'b'] },
'randint': { '_type': 'randint', '_value': [1, 10] },
'uniform': { '_type': 'uniform', '_value': [0, 1.0] },
'quniform': { '_type': 'quniform', '_value': [1, 10, 0.1] },
'loguniform': { '_type': 'loguniform', '_value': [0.001, 0.1] },
'qloguniform': { '_type': 'qloguniform', '_value': [0.001, 0.1, 0.001] },
'normal': { '_type': 'normal', '_value': [0, 0.1] },
'qnormal': { '_type': 'qnormal', '_value': [0.5, 0.1, 0.1] },
'lognormal': { '_type': 'lognormal', '_value': [0.0, 1] },
'qlognormal': { '_type': 'qlognormal', '_value': [-1, 1, 0.1] },
}
good_partial = {
'choice': good['choice'],
'randint': good['randint'],
}

bad_type = 'x'
bad_spec_type = { 'x': [1, 2, 3] }
bad_fields = { 'x': { 'type': 'choice', 'value': ['a', 'b'] } }
bad_type_name = { 'x': { '_type': 'choic', '_value': ['a'] } }
bad_value = { 'x': { '_type': 'choice', '_value': 'ab' } }
bad_2_args = { 'x': { '_type': 'randint', '_value': [1, 2, 3] } }
bad_3_args = { 'x': { '_type': 'quniform', '_value': [0] } }
bad_int_args = { 'x': { '_type': 'randint', '_value': [1.0, 2.0] } }
bad_float_args = { 'x': { '_type': 'uniform', '_value': ['0.1', '0.2'] } }
bad_low_high = { 'x': { '_type': 'quniform', '_value': [2, 1, 0.1] } }
bad_log = { 'x': { '_type': 'loguniform', '_value': [0, 1] } }
bad_sigma = { 'x': { '_type': 'normal', '_value': [0, 0] } }

def test_hpo_utils():
assert validate_search_space(good, raise_exception=False)
assert not validate_search_space(bad_type, raise_exception=False)
assert not validate_search_space(bad_spec_type, raise_exception=False)
assert not validate_search_space(bad_fields, raise_exception=False)
assert not validate_search_space(bad_type_name, raise_exception=False)
assert not validate_search_space(bad_value, raise_exception=False)
assert not validate_search_space(bad_2_args, raise_exception=False)
assert not validate_search_space(bad_3_args, raise_exception=False)
assert not validate_search_space(bad_int_args, raise_exception=False)
assert not validate_search_space(bad_float_args, raise_exception=False)
assert not validate_search_space(bad_low_high, raise_exception=False)
assert not validate_search_space(bad_log, raise_exception=False)
assert not validate_search_space(bad_sigma, raise_exception=False)

assert validate_search_space(good_partial, ['choice', 'randint'], False)
assert not validate_search_space(good, ['choice', 'randint'], False)

if __name__ == '__main__':
test_hpo_utils()

0 comments on commit 32fdd32

Please sign in to comment.