Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Add simple HPO search space validation #3877

Merged
merged 3 commits into from
Jun 30, 2021
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions nni/algorithms/hpo/batch_tuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ class BatchTuner
import logging

import nni
from nni.common.hpo_utils import validate_search_space
from nni.tuner import Tuner

TYPE = '_type'
Expand Down Expand Up @@ -75,6 +76,7 @@ def update_search_space(self, search_space):
----------
search_space : dict
"""
validate_search_space(search_space, ['choice'])
self._values = self.is_valid(search_space)

def generate_parameters(self, parameter_id, **kwargs):
Expand Down
2 changes: 2 additions & 0 deletions nni/algorithms/hpo/dngo_tuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import nni.parameter_expressions as parameter_expressions
from nni import ClassArgsValidator
from nni.common.hpo_utils import validate_search_space
from nni.tuner import Tuner

_logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -86,6 +87,7 @@ def generate_parameters(self, parameter_id, **kwargs):
return new_x

def update_search_space(self, search_space):
validate_search_space(search_space, ['choice', 'randint', 'uniform', 'quniform', 'loguniform', 'qloguniform'])
self.searchspace_json = search_space
self.random_state = np.random.RandomState()

Expand Down
2 changes: 2 additions & 0 deletions nni/algorithms/hpo/gp_tuner/gp_tuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from sklearn.gaussian_process import GaussianProcessRegressor

from nni import ClassArgsValidator
from nni.common.hpo_utils import validate_search_space
from nni.tuner import Tuner
from nni.utils import OptimizeMode, extract_scalar_reward

Expand Down Expand Up @@ -103,6 +104,7 @@ def update_search_space(self, search_space):

Override of the abstract method in :class:`~nni.tuner.Tuner`.
"""
validate_search_space(search_space, ['choice', 'randint', 'uniform', 'quniform', 'loguniform', 'qloguniform'])
self._space = TargetSpace(search_space, self._random_state)

def generate_parameters(self, parameter_id, **kwargs):
Expand Down
2 changes: 2 additions & 0 deletions nni/algorithms/hpo/gridsearch_tuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ class GridSearchTuner
import numpy as np

import nni
from nni.common.hpo_utils import validate_search_space
from nni.tuner import Tuner
from nni.utils import convert_dict2tuple

Expand Down Expand Up @@ -144,6 +145,7 @@ def update_search_space(self, search_space):
search_space : dict
The format could be referred to search space spec (https://nni.readthedocs.io/en/latest/Tutorial/SearchSpaceSpec.html).
"""
validate_search_space(search_space, ['choice', 'randint', 'quniform'])
self.expanded_search_space = self._json2parameter(search_space)

def generate_parameters(self, parameter_id, **kwargs):
Expand Down
2 changes: 2 additions & 0 deletions nni/algorithms/hpo/hyperband_advisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from schema import Schema, Optional

from nni import ClassArgsValidator
from nni.common.hpo_utils import validate_search_space
from nni.runtime.common import multi_phase_enabled
from nni.runtime.msg_dispatcher_base import MsgDispatcherBase
from nni.runtime.protocol import CommandType, send
Expand Down Expand Up @@ -379,6 +380,7 @@ def _get_one_trial_job(self):
def handle_update_search_space(self, data):
"""data: JSON object, which is search space
"""
validate_search_space(data)
self.searchspace_json = data
self.random_state = np.random.RandomState()

Expand Down
2 changes: 2 additions & 0 deletions nni/algorithms/hpo/hyperopt_tuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import numpy as np
from schema import Optional, Schema
from nni import ClassArgsValidator
from nni.common.hpo_utils import validate_search_space
from nni.tuner import Tuner
from nni.utils import NodeType, OptimizeMode, extract_scalar_reward, split_index

Expand Down Expand Up @@ -246,6 +247,7 @@ def update_search_space(self, search_space):
----------
search_space : dict
"""
validate_search_space(search_space)
self.json = search_space

search_space_instance = json2space(self.json)
Expand Down
3 changes: 3 additions & 0 deletions nni/algorithms/hpo/metis_tuner/metis_tuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from nni import ClassArgsValidator
from nni.tuner import Tuner
from nni.common.hpo_utils import validate_search_space
from nni.utils import OptimizeMode, extract_scalar_reward
from . import lib_constraint_summation
from . import lib_data
Expand Down Expand Up @@ -152,6 +153,8 @@ def update_search_space(self, search_space):
----------
search_space : dict
"""
validate_search_space(search_space, ['choice', 'randint', 'uniform', 'quniform'])

self.x_bounds = [[] for i in range(len(search_space))]
self.x_types = [NONE_TYPE for i in range(len(search_space))]

Expand Down
2 changes: 2 additions & 0 deletions nni/algorithms/hpo/smac_tuner/smac_tuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

import nni
from nni import ClassArgsValidator
from nni.common.hpo_utils import validate_search_space
from nni.tuner import Tuner
from nni.utils import OptimizeMode, extract_scalar_reward

Expand Down Expand Up @@ -143,6 +144,7 @@ def update_search_space(self, search_space):
The format could be referred to search space spec (https://nni.readthedocs.io/en/latest/Tutorial/SearchSpaceSpec.html).
"""
self.logger.info('update search space in SMAC.')
validate_search_space(search_space, ['choice', 'randint', 'uniform', 'quniform', 'loguniform'])
if not self.update_ss_done:
self.categorical_dict = generate_scenario(search_space)
if self.categorical_dict is None:
Expand Down
73 changes: 73 additions & 0 deletions nni/common/hpo_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
import logging
from typing import Any, List, Optional

common_search_space_types = [
'choice',
'randint',
'uniform',
'quniform',
'loguniform',
'qloguniform',
'normal',
'qnormal',
'lognormal',
'qlognormal',
]

def validate_search_space(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we never use the return value, change "validate_search_space" as void should be ok?

Copy link
Contributor Author

@liuzhe-lz liuzhe-lz Jun 29, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As a public API, the caller might want validation result to decide following action.
For example, I think it's better to integrate validation into tuner API, and then perform early validation in nnictl. In this case, nnictl should halt launching when validation failed, but it won't want to raise the exception to end user.

search_space: Any,
support_types: Optional[List[str]] = None,
raise_exception: bool = False # for now, in case false positive
) -> bool:

if not raise_exception:
try:
validate_search_space(search_space, support_types, True)
return True
except ValueError as e:
logging.getLogger(__name__).error(e.args[0])
return False

if support_types is None:
support_types = common_search_space_types

if not isinstance(search_space, dict):
raise ValueError('search space is not a dict')

for name, spec in search_space.items():
if '_type' not in spec or '_value' not in spec:
raise ValueError(f'search space "{name}" does not have "_type" or "_value"')
type_ = spec['_type']
if type_ not in support_types:
raise ValueError(f'search space "{name}" has unsupported type {type_}')
args = spec['_value']
if not isinstance(args, list):
raise ValueError(f'search space "{name}"\'s value is not a list')

if type_ == 'choice':
continue

if type_.startswith('q'):
if len(args) != 3:
raise ValueError(f'search space "{name}" ({type_}) must have 3 values')
else:
if len(args) != 2:
raise ValueError(f'search space "{name}" ({type_}) must have 2 values')

if type_ == 'randint':
if not all(isinstance(arg, int) for arg in args):
raise ValueError(f'search space "{name}" ({type_}) must have int values')
else:
if not all(isinstance(arg, (float, int)) for arg in args):
raise ValueError(f'search space "{name}" ({type_}) must have float values')

if 'normal' not in type_:
if args[0] >= args[1]:
raise ValueError(f'search space "{name}" ({type_}) must have high > low')
if 'log' in type_ and args[0] <= 0:
raise ValueError(f'search space "{name}" ({type_}) must have low > 0')
else:
if args[1] <= 0:
raise ValueError(f'search space "{name}" ({type_}) must have sigma > 0')

return True
50 changes: 50 additions & 0 deletions test/ut/sdk/test_hpo_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
from nni.common.hpo_utils import validate_search_space

good = {
'choice': { '_type': 'choice', '_value': ['a', 'b'] },
'randint': { '_type': 'randint', '_value': [1, 10] },
'uniform': { '_type': 'uniform', '_value': [0, 1.0] },
'quniform': { '_type': 'quniform', '_value': [1, 10, 0.1] },
'loguniform': { '_type': 'loguniform', '_value': [0.001, 0.1] },
'qloguniform': { '_type': 'qloguniform', '_value': [0.001, 0.1, 0.001] },
'normal': { '_type': 'normal', '_value': [0, 0.1] },
'qnormal': { '_type': 'qnormal', '_value': [0.5, 0.1, 0.1] },
'lognormal': { '_type': 'lognormal', '_value': [0.0, 1] },
'qlognormal': { '_type': 'qlognormal', '_value': [-1, 1, 0.1] },
}
good_partial = {
'choice': good['choice'],
'randint': good['randint'],
}

bad_type = 'x'
bad_fields = { 'x': { 'type': 'choice', 'value': ['a', 'b'] } }
bad_type_name = { 'x': { '_type': 'choic', '_value': ['a'] } }
bad_value = { 'x': { '_type': 'choice', '_value': 'ab' } }
bad_2_args = { 'x': { '_type': 'randint', '_value': [1, 2, 3] } }
bad_3_args = { 'x': { '_type': 'quniform', '_value': [0] } }
bad_int_args = { 'x': { '_type': 'randint', '_value': [1.0, 2.0] } }
bad_float_args = { 'x': { '_type': 'uniform', '_value': ['0.1', '0.2'] } }
bad_low_high = { 'x': { '_type': 'quniform', '_value': [2, 1, 0.1] } }
bad_log = { 'x': { '_type': 'loguniform', '_value': [0, 1] } }
bad_sigma = { 'x': { '_type': 'normal', '_value': [0, 0] } }

def test_hpo_utils():
assert validate_search_space(good, raise_exception=False)
assert not validate_search_space(bad_type, raise_exception=False)
assert not validate_search_space(bad_fields, raise_exception=False)
assert not validate_search_space(bad_type_name, raise_exception=False)
assert not validate_search_space(bad_value, raise_exception=False)
assert not validate_search_space(bad_2_args, raise_exception=False)
assert not validate_search_space(bad_3_args, raise_exception=False)
assert not validate_search_space(bad_int_args, raise_exception=False)
assert not validate_search_space(bad_float_args, raise_exception=False)
assert not validate_search_space(bad_low_high, raise_exception=False)
assert not validate_search_space(bad_log, raise_exception=False)
assert not validate_search_space(bad_sigma, raise_exception=False)

assert validate_search_space(good_partial, ['choice', 'randint'], False)
assert not validate_search_space(good, ['choice', 'randint'], False)

if __name__ == '__main__':
test_hpo_utils()