diff --git a/examples/trials/mnist-pbt-tuner-pytorch/__init__.py b/examples/trials/mnist-pbt-tuner-pytorch/__init__.py
deleted file mode 100644
index e69de29bb2..0000000000
diff --git a/examples/trials/mnist-pbt-tuner-pytorch/mnist.py b/examples/trials/mnist-pbt-tuner-pytorch/mnist.py
index 2161191e9e..b8653b40dc 100644
--- a/examples/trials/mnist-pbt-tuner-pytorch/mnist.py
+++ b/examples/trials/mnist-pbt-tuner-pytorch/mnist.py
@@ -155,8 +155,8 @@ def get_params():
                         help='learning rate (default: 0.01)')
     parser.add_argument('--momentum', type=float, default=0.5, metavar='M',
                         help='SGD momentum (default: 0.5)')
-    parser.add_argument('--epochs', type=int, default=10, metavar='N',
-                        help='number of epochs to train (default: 10)')
+    parser.add_argument('--epochs', type=int, default=1, metavar='N',
+                        help='number of epochs to train (default: 1)')
     parser.add_argument('--seed', type=int, default=1, metavar='S',
                         help='random seed (default: 1)')
     parser.add_argument('--no_cuda', action='store_true', default=False,
diff --git a/src/sdk/pynni/nni/pbt_tuner/pbt_tuner.py b/src/sdk/pynni/nni/pbt_tuner/pbt_tuner.py
index 9e4acd586c..e943752e84 100755
--- a/src/sdk/pynni/nni/pbt_tuner/pbt_tuner.py
+++ b/src/sdk/pynni/nni/pbt_tuner/pbt_tuner.py
@@ -4,9 +4,11 @@
 import copy
 import logging
 import os
+import random
 
 import numpy as np
 
 import nni
+import nni.parameter_expressions
 from nni.tuner import Tuner
 from nni.utils import OptimizeMode, extract_scalar_reward, split_index, json2parameter, json2space
@@ -14,7 +16,42 @@
 logger = logging.getLogger('pbt_tuner_AutoML')
 
 
-def exploit_and_explore(bot_trial_info, top_trial_info, factors, epoch, search_space):
+def perturbation(hyperparameter_type, value, resample_probability, uv, ub, lv, lb, random_state):
+    """
+    Perturbation for hyperparameters
+
+    Parameters
+    ----------
+    hyperparameter_type : str
+        type of hyperparameter
+    value : list
+        parameters for sampling hyperparameter
+    resample_probability : float
+        probability for resampling
+    uv : float/int
+        upper value after perturbation
+    ub : float/int
+        upper bound
+    lv : float/int
+        lower value after perturbation
+    lb : float/int
+        lower bound
+    random_state : RandomState
+        random state
+    """
+    if random.random() < resample_probability:
+        if hyperparameter_type == "choice":
+            return value.index(nni.parameter_expressions.choice(value, random_state))
+        else:
+            return getattr(nni.parameter_expressions, hyperparameter_type)(*(value + [random_state]))
+    else:
+        if random.random() > 0.5:
+            return min(uv, ub)
+        else:
+            return max(lv, lb)
+
+
+def exploit_and_explore(bot_trial_info, top_trial_info, factor, resample_probability, epoch, search_space):
     """
     Replace checkpoint of bot_trial with top, and perturb hyperparameters
 
@@ -24,8 +61,10 @@ def exploit_and_explore(bot_trial_info, top_trial_info, factors, epoch, search_s
         bottom model whose parameters should be replaced
     top_trial_info : TrialInfo
         better model
-    factors : float
-        factors for perturbation
+    factor : float
+        factor for perturbation
+    resample_probability : float
+        probability for resampling
     epoch : int
         step of PBTTuner
     search_space : dict
@@ -34,21 +73,72 @@ def exploit_and_explore(bot_trial_info, top_trial_info, factors, epoch, search_s
     bot_checkpoint_dir = bot_trial_info.checkpoint_dir
     top_hyper_parameters = top_trial_info.hyper_parameters
     hyper_parameters = copy.deepcopy(top_hyper_parameters)
-    # TODO think about different type of hyperparameters for 1.perturbation 2.within search space
+    random_state = np.random.RandomState()
     for key in hyper_parameters.keys():
+        hyper_parameter = hyper_parameters[key]
         if key == 'load_checkpoint_dir':
             hyper_parameters[key] = hyper_parameters['save_checkpoint_dir']
+            continue
         elif key == 'save_checkpoint_dir':
             hyper_parameters[key] = os.path.join(bot_checkpoint_dir, str(epoch))
-        elif isinstance(hyper_parameters[key], float):
-            perturb = np.random.choice(factors)
-            val = hyper_parameters[key] * perturb
+            continue
+        elif search_space[key]["_type"] == "choice":
+            choices = search_space[key]["_value"]
+            ub, uv = len(choices) - 1, choices.index(hyper_parameter["_value"]) + 1
+            lb, lv = 0, choices.index(hyper_parameter["_value"]) - 1
+        elif search_space[key]["_type"] == "randint":
             lb, ub = search_space[key]["_value"][:2]
-            if search_space[key]["_type"] in ("uniform", "normal"):
-                val = np.clip(val, lb, ub).item()
-            hyper_parameters[key] = val
+            ub -= 1
+            uv = hyper_parameter + 1
+            lv = hyper_parameter - 1
+        elif search_space[key]["_type"] == "uniform":
+            lb, ub = search_space[key]["_value"][:2]
+            perturb = (ub - lb) * factor
+            uv = hyper_parameter + perturb
+            lv = hyper_parameter - perturb
+        elif search_space[key]["_type"] == "quniform":
+            lb, ub, q = search_space[key]["_value"][:3]
+            multi = round(hyper_parameter / q)
+            uv = (multi + 1) * q
+            lv = (multi - 1) * q
+        elif search_space[key]["_type"] == "loguniform":
+            lb, ub = search_space[key]["_value"][:2]
+            perturb = (np.log(ub) - np.log(lb)) * factor
+            uv = np.exp(min(np.log(hyper_parameter) + perturb, np.log(ub)))
+            lv = np.exp(max(np.log(hyper_parameter) - perturb, np.log(lb)))
+        elif search_space[key]["_type"] == "qloguniform":
+            lb, ub, q = search_space[key]["_value"][:3]
+            multi = round(hyper_parameter / q)
+            uv = (multi + 1) * q
+            lv = (multi - 1) * q
+        elif search_space[key]["_type"] == "normal":
+            sigma = search_space[key]["_value"][1]
+            perturb = sigma * factor
+            uv = ub = hyper_parameter + perturb
+            lv = lb = hyper_parameter - perturb
+        elif search_space[key]["_type"] == "qnormal":
+            q = search_space[key]["_value"][2]
+            uv = ub = hyper_parameter + q
+            lv = lb = hyper_parameter - q
+        elif search_space[key]["_type"] == "lognormal":
+            sigma = search_space[key]["_value"][1]
+            perturb = sigma * factor
+            uv = ub = np.exp(np.log(hyper_parameter) + perturb)
+            lv = lb = np.exp(np.log(hyper_parameter) - perturb)
+        elif search_space[key]["_type"] == "qlognormal":
+            q = search_space[key]["_value"][2]
+            uv = ub = hyper_parameter + q
+            lv, lb = hyper_parameter - q, 1E-10
         else:
+            logger.warning("Illegal type to perturb: %s", search_space[key]["_type"])
             continue
+        if search_space[key]["_type"] == "choice":
+            idx = perturbation(search_space[key]["_type"], search_space[key]["_value"],
+                               resample_probability, uv, ub, lv, lb, random_state)
+            hyper_parameters[key] = {'_index': idx, '_value': choices[idx]}
+        else:
+            hyper_parameters[key] = perturbation(search_space[key]["_type"], search_space[key]["_value"],
+                                                 resample_probability, uv, ub, lv, lb, random_state)
     bot_trial_info.hyper_parameters = hyper_parameters
     bot_trial_info.clean_id()
 
@@ -70,7 +160,8 @@ def clean_id(self):
 
 
 class PBTTuner(Tuner):
-    def __init__(self, optimize_mode="maximize", all_checkpoint_dir=None, population_size=10, factors=(1.2, 0.8), fraction=0.2):
+    def __init__(self, optimize_mode="maximize", all_checkpoint_dir=None, population_size=10, factor=0.2,
+                 resample_probability=0.25, fraction=0.2):
         """
         Initialization
 
@@ -82,8 +173,10 @@ def __init__(self, optimize_mode="maximize", all_checkpoint_dir=None, population
             directory to store training model checkpoint
         population_size : int
            number of trials for each epoch
-        factors : tuple
-            factors for perturbation
+        factor : float
+            factor for perturbation
+        resample_probability : float
+            probability for resampling
         fraction : float
             fraction for selecting bottom and top trials
         """
@@ -93,7 +186,8 @@ def __init__(self, optimize_mode="maximize", all_checkpoint_dir=None, population
             logger.info("Checkpoint dir is set to %s by default.", all_checkpoint_dir)
         self.all_checkpoint_dir = all_checkpoint_dir
         self.population_size = population_size
-        self.factors = factors
+        self.factor = factor
+        self.resample_probability = resample_probability
         self.fraction = fraction
         # defined in trial code
         #self.perturbation_interval = perturbation_interval
@@ -237,7 +331,7 @@ def receive_trial_result(self, parameter_id, parameters, value, **kwargs):
             bottoms = self.finished[self.finished_trials - cutoff:]
             for bottom in bottoms:
                 top = np.random.choice(tops)
-                exploit_and_explore(bottom, top, self.factors, self.epoch, self.searchspace_json)
+                exploit_and_explore(bottom, top, self.factor, self.resample_probability, self.epoch, self.searchspace_json)
             for trial in self.finished:
                 if trial not in bottoms:
                     trial.clean_id()
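
Note: the explore step introduced above boils down to one rule per hyperparameter: with probability resample_probability the value is resampled from its search-space definition (via nni.parameter_expressions); otherwise it is nudged to either its upper or lower perturbed value, clipped to the search-space bounds (the min(uv, ub) / max(lv, lb) branch). The following is a minimal standalone sketch of that behavior for a "uniform"-typed hyperparameter; the function name perturb_uniform and its defaults are illustrative only and not part of this diff.

import random

import numpy as np


def perturb_uniform(value, low, high, factor=0.2, resample_probability=0.25, rng=None):
    """Sketch of the explore step for a 'uniform'-typed hyperparameter."""
    if rng is None:
        rng = np.random.RandomState()
    if random.random() < resample_probability:
        # Resample uniformly from the original search range.
        return rng.uniform(low, high)
    # Otherwise nudge the value up or down by `factor` of the range,
    # staying inside [low, high].
    step = (high - low) * factor
    if random.random() > 0.5:
        return min(value + step, high)
    return max(value - step, low)


# e.g. perturb a hyperparameter defined as uniform(0.001, 0.1)
print(perturb_uniform(0.01, 0.001, 0.1))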