Enable single thread HPO (#140)
joneswong authored Jun 5, 2022
1 parent b0c0bc9 commit 93fe045
Showing 2 changed files with 71 additions and 39 deletions.
106 changes: 68 additions & 38 deletions federatedscope/autotune/algos.py
@@ -19,6 +19,21 @@
logger = logging.getLogger(__name__)


def make_trial(trial_cfg):
    setup_seed(trial_cfg.seed)
    data, modified_config = get_data(config=trial_cfg.clone())
    trial_cfg.merge_from_other_cfg(modified_config)
    trial_cfg.freeze()
    # TODO: enable client-wise configuration
    Fed_runner = FedRunner(data=data,
                           server_class=get_server_cls(trial_cfg),
                           client_class=get_client_cls(trial_cfg),
                           config=trial_cfg.clone())
    results = Fed_runner.run()
    key1, key2 = trial_cfg.hpo.metric.split('.')
    return results[key1][key2]
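
For orientation (not part of this diff): `make_trial` runs one FL trial end-to-end and reduces the nested results dict returned by `FedRunner.run()` to a single scalar through the dotted `hpo.metric` key. A minimal sketch of that last lookup, with a hypothetical metric name and value:

# Sketch of the hpo.metric lookup in make_trial; the metric name and the
# results dict here are hypothetical stand-ins, not values from this commit.
results = {'client_summarized_weighted_avg': {'test_loss': 0.42}}
metric = 'client_summarized_weighted_avg.test_loss'

key1, key2 = metric.split('.')  # -> 'client_summarized_weighted_avg', 'test_loss'
perf = results[key1][key2]      # -> 0.42
print(perf)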


class TrialExecutor(threading.Thread):
    """This class is responsible for executing the FL procedure with a given trial configuration in another thread.
    """
@@ -106,46 +121,61 @@ def _setup(self):
]

    def _evaluate(self, configs):
        flags = [threading.Event() for _ in range(self._cfg.hpo.num_workers)]
        for i in range(len(flags)):
            flags[i].set()
        threads = [None for _ in range(len(flags))]
        thread_results = [dict() for _ in range(len(flags))]

        perfs = [None for _ in range(len(configs))]
        for i, config in enumerate(configs):
            available_worker = 0
            while not flags[available_worker].is_set():
                available_worker = (available_worker + 1) % len(threads)
            if thread_results[available_worker]:
                completed_trial_results = thread_results[available_worker]
                cfg_idx = completed_trial_results['cfg_idx']
                perfs[cfg_idx] = completed_trial_results['perf']
                logger.info(
                    "Evaluate the {}-th config {} and get performance {}".
                    format(cfg_idx, configs[cfg_idx], perfs[cfg_idx]))
                thread_results[available_worker].clear()

            trial_cfg = self._cfg.clone()
            trial_cfg.merge_from_list(config2cmdargs(config))
            flags[available_worker].clear()
            trial = TrialExecutor(i, flags[available_worker],
                                  thread_results[available_worker], trial_cfg)
            trial.start()
            threads[available_worker] = trial

        for i in range(len(flags)):
            if not flags[i].is_set():
                threads[i].join()
        for i in range(len(thread_results)):
            if thread_results[i]:
                completed_trial_results = thread_results[i]
                cfg_idx = completed_trial_results['cfg_idx']
                perfs[cfg_idx] = completed_trial_results['perf']
        if self._cfg.hpo.num_workers:
            # execute FL in parallel by multi-threading
            flags = [
                threading.Event() for _ in range(self._cfg.hpo.num_workers)
            ]
            for i in range(len(flags)):
                flags[i].set()
            threads = [None for _ in range(len(flags))]
            thread_results = [dict() for _ in range(len(flags))]

            perfs = [None for _ in range(len(configs))]
            for i, config in enumerate(configs):
                available_worker = 0
                while not flags[available_worker].is_set():
                    available_worker = (available_worker + 1) % len(threads)
                if thread_results[available_worker]:
                    completed_trial_results = thread_results[available_worker]
                    cfg_idx = completed_trial_results['cfg_idx']
                    perfs[cfg_idx] = completed_trial_results['perf']
                    logger.info(
                        "Evaluate the {}-th config {} and get performance {}".
                        format(cfg_idx, configs[cfg_idx], perfs[cfg_idx]))
                    thread_results[available_worker].clear()

                trial_cfg = self._cfg.clone()
                trial_cfg.merge_from_list(config2cmdargs(config))
                flags[available_worker].clear()
                trial = TrialExecutor(i, flags[available_worker],
                                      thread_results[available_worker],
                                      trial_cfg)
                trial.start()
                threads[available_worker] = trial

            for i in range(len(flags)):
                if not flags[i].is_set():
                    threads[i].join()
            for i in range(len(thread_results)):
                if thread_results[i]:
                    completed_trial_results = thread_results[i]
                    cfg_idx = completed_trial_results['cfg_idx']
                    perfs[cfg_idx] = completed_trial_results['perf']
                    logger.info(
                        "Evaluate the {}-th config {} and get performance {}".
                        format(cfg_idx, configs[cfg_idx], perfs[cfg_idx]))
                    thread_results[i].clear()

        else:
            perfs = [None] * len(configs)
            for i, config in enumerate(configs):
                trial_cfg = self._cfg.clone()
                trial_cfg.merge_from_list(config2cmdargs(config))
                perfs[i] = make_trial(trial_cfg)
                logger.info(
                    "Evaluate the {}-th config {} and get performance {}".
                    format(i, config, perfs[i]))

        return perfs
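
A standalone sketch of the worker-slot pattern the multi-threaded branch implements (the names and the dummy workload are illustrative, not code from this commit): each slot's `Event` is set while the slot is idle; the scheduler polls the slots round-robin, harvests any finished result sitting in a slot, and launches the next trial there.

import threading
import time

def fake_trial(cfg_idx, flag, results):
    time.sleep(0.01)                  # stand-in for one full FL run
    results['cfg_idx'] = cfg_idx      # same result contract as TrialExecutor
    results['perf'] = cfg_idx * 0.1   # dummy performance value
    flag.set()                        # mark this worker slot idle again

num_workers, num_configs = 2, 5
flags = [threading.Event() for _ in range(num_workers)]
for f in flags:
    f.set()                           # every slot starts idle
slot_results = [dict() for _ in range(num_workers)]
threads = [None] * num_workers
perfs = [None] * num_configs

for i in range(num_configs):
    w = 0
    while not flags[w].is_set():      # poll round-robin for an idle slot
        w = (w + 1) % num_workers
    if slot_results[w]:               # harvest the slot's previous trial
        perfs[slot_results[w]['cfg_idx']] = slot_results[w]['perf']
        slot_results[w].clear()
    flags[w].clear()                  # mark the slot busy
    t = threading.Thread(target=fake_trial, args=(i, flags[w], slot_results[w]))
    t.start()
    threads[w] = t

for t in threads:                     # drain whatever is still running
    if t is not None:
        t.join()
for res in slot_results:              # harvest the final results
    if res:
        perfs[res['cfg_idx']] = res['perf']
print(perfs)                          # all five dummy scores filled in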

4 changes: 3 additions & 1 deletion federatedscope/core/configs/cfg_hpo.py
@@ -10,7 +10,7 @@ def extend_hpo_cfg(cfg):
    cfg.hpo = CN()
    cfg.hpo.working_folder = 'hpo'
    cfg.hpo.ss = ''
    cfg.hpo.num_workers = 1
    cfg.hpo.num_workers = 0
    #cfg.hpo.init_strategy = 'random'
    cfg.hpo.init_cand_num = 16
    cfg.hpo.log_scale = False
@@ -56,6 +56,8 @@ def assert_hpo_cfg(cfg):
    assert cfg.hpo.scheduler in ['rs', 'sha',
                                 'pbt'], "No HPO scheduler named {}".format(
                                     cfg.hpo.scheduler)
    assert cfg.hpo.num_workers >= 0, "#worker should be non-negative but given {}".format(
        cfg.hpo.num_workers)
    assert len(cfg.hpo.sha.budgets) == 0 or len(
        cfg.hpo.sha.budgets
    ) == cfg.hpo.sha.elim_round_num, \
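
With the default flipped from 1 to 0 in `cfg_hpo.py` above, HPO now runs each trial serially in the main thread unless a positive worker count is configured. A hypothetical usage sketch; the import path is assumed rather than shown in this diff:

# Hypothetical sketch: choosing between the serial and threaded HPO paths.
# The import path below is assumed, not shown in this diff.
from federatedscope.core.configs.config import global_cfg

cfg = global_cfg.clone()
cfg.hpo.num_workers = 0    # new default: each trial runs serially via make_trial
# cfg.hpo.num_workers = 4  # any positive value restores the TrialExecutor threads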
