diff --git a/federatedscope/autotune/algos.py b/federatedscope/autotune/algos.py
index 392f5bfcb..c37a885e0 100644
--- a/federatedscope/autotune/algos.py
+++ b/federatedscope/autotune/algos.py
@@ -19,6 +19,21 @@
 logger = logging.getLogger(__name__)
 
 
+def make_trial(trial_cfg):
+    setup_seed(trial_cfg.seed)
+    data, modified_config = get_data(config=trial_cfg.clone())
+    trial_cfg.merge_from_other_cfg(modified_config)
+    trial_cfg.freeze()
+    # TODO: enable client-wise configuration
+    Fed_runner = FedRunner(data=data,
+                           server_class=get_server_cls(trial_cfg),
+                           client_class=get_client_cls(trial_cfg),
+                           config=trial_cfg.clone())
+    results = Fed_runner.run()
+    key1, key2 = trial_cfg.hpo.metric.split('.')
+    return results[key1][key2]
+
+
 class TrialExecutor(threading.Thread):
     """This class is responsible for executing the FL procedure with a given
     trial configuration in another thread.
     """
@@ -106,46 +121,61 @@ def _setup(self):
         ]
 
     def _evaluate(self, configs):
-        flags = [threading.Event() for _ in range(self._cfg.hpo.num_workers)]
-        for i in range(len(flags)):
-            flags[i].set()
-        threads = [None for _ in range(len(flags))]
-        thread_results = [dict() for _ in range(len(flags))]
-
-        perfs = [None for _ in range(len(configs))]
-        for i, config in enumerate(configs):
-            available_worker = 0
-            while not flags[available_worker].is_set():
-                available_worker = (available_worker + 1) % len(threads)
-            if thread_results[available_worker]:
-                completed_trial_results = thread_results[available_worker]
-                cfg_idx = completed_trial_results['cfg_idx']
-                perfs[cfg_idx] = completed_trial_results['perf']
-                logger.info(
-                    "Evaluate the {}-th config {} and get performance {}".
-                    format(cfg_idx, configs[cfg_idx], perfs[cfg_idx]))
-                thread_results[available_worker].clear()
-
-            trial_cfg = self._cfg.clone()
-            trial_cfg.merge_from_list(config2cmdargs(config))
-            flags[available_worker].clear()
-            trial = TrialExecutor(i, flags[available_worker],
-                                  thread_results[available_worker], trial_cfg)
-            trial.start()
-            threads[available_worker] = trial
-
-        for i in range(len(flags)):
-            if not flags[i].is_set():
-                threads[i].join()
-        for i in range(len(thread_results)):
-            if thread_results[i]:
-                completed_trial_results = thread_results[i]
-                cfg_idx = completed_trial_results['cfg_idx']
-                perfs[cfg_idx] = completed_trial_results['perf']
+        if self._cfg.hpo.num_workers:
+            # execute FL in parallel by multi-threading
+            flags = [
+                threading.Event() for _ in range(self._cfg.hpo.num_workers)
+            ]
+            for i in range(len(flags)):
+                flags[i].set()
+            threads = [None for _ in range(len(flags))]
+            thread_results = [dict() for _ in range(len(flags))]
+
+            perfs = [None for _ in range(len(configs))]
+            for i, config in enumerate(configs):
+                available_worker = 0
+                while not flags[available_worker].is_set():
+                    available_worker = (available_worker + 1) % len(threads)
+                if thread_results[available_worker]:
+                    completed_trial_results = thread_results[available_worker]
+                    cfg_idx = completed_trial_results['cfg_idx']
+                    perfs[cfg_idx] = completed_trial_results['perf']
+                    logger.info(
+                        "Evaluate the {}-th config {} and get performance {}".
+                        format(cfg_idx, configs[cfg_idx], perfs[cfg_idx]))
+                    thread_results[available_worker].clear()
+
+                trial_cfg = self._cfg.clone()
+                trial_cfg.merge_from_list(config2cmdargs(config))
+                flags[available_worker].clear()
+                trial = TrialExecutor(i, flags[available_worker],
+                                      thread_results[available_worker],
+                                      trial_cfg)
+                trial.start()
+                threads[available_worker] = trial
+
+            for i in range(len(flags)):
+                if not flags[i].is_set():
+                    threads[i].join()
+            for i in range(len(thread_results)):
+                if thread_results[i]:
+                    completed_trial_results = thread_results[i]
+                    cfg_idx = completed_trial_results['cfg_idx']
+                    perfs[cfg_idx] = completed_trial_results['perf']
+                    logger.info(
+                        "Evaluate the {}-th config {} and get performance {}".
+                        format(cfg_idx, configs[cfg_idx], perfs[cfg_idx]))
+                    thread_results[i].clear()
+
+        else:
+            perfs = [None] * len(configs)
+            for i, config in enumerate(configs):
+                trial_cfg = self._cfg.clone()
+                trial_cfg.merge_from_list(config2cmdargs(config))
+                perfs[i] = make_trial(trial_cfg)
                 logger.info(
                     "Evaluate the {}-th config {} and get performance {}".
-                    format(cfg_idx, configs[cfg_idx], perfs[cfg_idx]))
-                thread_results[i].clear()
+                    format(i, config, perfs[i]))
 
         return perfs
diff --git a/federatedscope/core/configs/cfg_hpo.py b/federatedscope/core/configs/cfg_hpo.py
index a54bf0298..dae42b6e3 100644
--- a/federatedscope/core/configs/cfg_hpo.py
+++ b/federatedscope/core/configs/cfg_hpo.py
@@ -10,7 +10,7 @@ def extend_hpo_cfg(cfg):
     cfg.hpo = CN()
     cfg.hpo.working_folder = 'hpo'
     cfg.hpo.ss = ''
-    cfg.hpo.num_workers = 1
+    cfg.hpo.num_workers = 0
     #cfg.hpo.init_strategy = 'random'
     cfg.hpo.init_cand_num = 16
     cfg.hpo.log_scale = False
@@ -56,6 +56,8 @@ def assert_hpo_cfg(cfg):
     assert cfg.hpo.scheduler in ['rs', 'sha',
                                  'pbt'], "No HPO scheduler named {}".format(
                                      cfg.hpo.scheduler)
+    assert cfg.hpo.num_workers >= 0, "#worker should be non-negative but given {}".format(
+        cfg.hpo.num_workers)
     assert len(cfg.hpo.sha.budgets) == 0 or len(
         cfg.hpo.sha.budgets
     ) == cfg.hpo.sha.elim_round_num, \
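
For reference, below is a minimal, self-contained sketch (not part of the patch) of the dispatch pattern the revised `_evaluate` implements: with `num_workers == 0`, trials run one-by-one in the caller's thread via the new `make_trial` path; a positive value schedules them over a fixed pool of worker threads whose idle state is tracked by one `threading.Event` each. The names `run_trial`, `Trial`, and `evaluate` here are illustrative stand-ins for `make_trial`, `TrialExecutor`, and the scheduler's `_evaluate`; the real code additionally clones the base config and merges each candidate into it before launching a trial.

```python
import threading
import time
import random


def run_trial(cfg_idx, config):
    """Stand-in for one FL run; returns a fake performance number."""
    time.sleep(random.uniform(0.1, 0.3))
    return sum(config.values())


class Trial(threading.Thread):
    """Stand-in for TrialExecutor: run one trial, hand back the result,
    then mark this worker slot as idle again."""
    def __init__(self, cfg_idx, config, signal, returns):
        super().__init__()
        self._cfg_idx = cfg_idx
        self._config = config
        self._signal = signal    # set() once the trial has finished
        self._returns = returns  # per-worker dict for handing back results

    def run(self):
        perf = run_trial(self._cfg_idx, self._config)
        self._returns['cfg_idx'] = self._cfg_idx
        self._returns['perf'] = perf
        self._signal.set()


def evaluate(configs, num_workers=0):
    if num_workers:
        # Parallel branch: round-robin poll for an idle worker, harvest any
        # finished result it still holds, then hand it the next config.
        flags = [threading.Event() for _ in range(num_workers)]
        for flag in flags:
            flag.set()
        threads = [None] * num_workers
        results = [dict() for _ in range(num_workers)]
        perfs = [None] * len(configs)

        for i, config in enumerate(configs):
            w = 0
            while not flags[w].is_set():  # busy-wait, as in the patch
                w = (w + 1) % num_workers
            if results[w]:
                perfs[results[w]['cfg_idx']] = results[w]['perf']
                results[w].clear()
            flags[w].clear()
            threads[w] = Trial(i, config, flags[w], results[w])
            threads[w].start()

        # Drain: wait for in-flight trials and collect remaining results.
        for w in range(num_workers):
            if not flags[w].is_set():
                threads[w].join()
            if results[w]:
                perfs[results[w]['cfg_idx']] = results[w]['perf']
                results[w].clear()
    else:
        # Serial branch, mirroring the new make_trial code path.
        perfs = [run_trial(i, config) for i, config in enumerate(configs)]
    return perfs


if __name__ == '__main__':
    configs = [{'lr': 0.1 * k, 'wd': 0.01} for k in range(1, 6)]
    print(evaluate(configs, num_workers=0))  # serial, the new default
    print(evaluate(configs, num_workers=2))  # threaded, as before
```

With the default changed to `hpo.num_workers = 0`, trials now run serially in the main thread unless a positive worker count is configured, which restores the previous multi-threaded behavior; the added assert in `assert_hpo_cfg` rejects negative values.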