Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Commit

Permalink
[nas] fix issue introduced by the trial recovery feature (#5109)
Browse files Browse the repository at this point in the history
  • Loading branch information
QuanluZhang authored Oct 12, 2022
1 parent 87677df commit bcc640c
Show file tree
Hide file tree
Showing 12 changed files with 105 additions and 41 deletions.
17 changes: 15 additions & 2 deletions nni/algorithms/hpo/bohb_advisor/bohb_advisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -648,8 +648,11 @@ def handle_trial_end(self, data):
event: the job's state
hyper_params: the hyperparameters (a string) generated and returned by tuner
"""
logger.debug('Tuner handle trial end, result is %s', data)
hyper_params = nni.load(data['hyper_params'])
if self.is_created_in_previous_exp(hyper_params['parameter_id']):
# The end of the recovered trial is ignored
return
logger.debug('Tuner handle trial end, result is %s', data)
self._handle_trial_end(hyper_params['parameter_id'])
if data['trial_job_id'] in self.job_id_para_id_map:
del self.job_id_para_id_map[data['trial_job_id']]
Expand Down Expand Up @@ -695,6 +698,13 @@ def handle_report_metric_data(self, data):
ValueError
Data type not supported
"""
if self.is_created_in_previous_exp(data['parameter_id']):
if data['type'] == MetricType.FINAL:
# only deal with final metric using import data
param = self.get_previous_param(data['parameter_id'])
trial_data = [{'parameter': param, 'value': nni.load(data['value'])}]
self.handle_import_data(trial_data)
return
logger.debug('handle report metric data = %s', data)
if 'value' in data:
data['value'] = nni.load(data['value'])
Expand Down Expand Up @@ -752,7 +762,10 @@ def handle_report_metric_data(self, data):
'Data type not supported: {}'.format(data['type']))

def handle_add_customized_trial(self, data):
    """Move the module-level parameter-id counter past the ids found in ``data``.

    ``data`` (the trial parameters) is handed unchanged to
    ``self.recover_parameter_id``; the global ``_next_parameter_id`` becomes
    one more than the id that call returns.
    """
    global _next_parameter_id
    _next_parameter_id = self.recover_parameter_id(data) + 1

def handle_import_data(self, data):
"""Import additional data for tuning
Expand Down
11 changes: 10 additions & 1 deletion nni/algorithms/hpo/hyperband_advisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -522,6 +522,9 @@ def handle_trial_end(self, data):
hyper_params: the hyperparameters (a string) generated and returned by tuner
"""
hyper_params = nni.load(data['hyper_params'])
if self.is_created_in_previous_exp(hyper_params['parameter_id']):
# The end of the recovered trial is ignored
return
self._handle_trial_end(hyper_params['parameter_id'])
if data['trial_job_id'] in self.job_id_para_id_map:
del self.job_id_para_id_map[data['trial_job_id']]
Expand All @@ -538,6 +541,9 @@ def handle_report_metric_data(self, data):
ValueError
Data type not supported
"""
if self.is_created_in_previous_exp(data['parameter_id']):
# do not support recovering the algorithm state
return
if 'value' in data:
data['value'] = nni.load(data['value'])
# multiphase? need to check
Expand Down Expand Up @@ -576,7 +582,10 @@ def handle_report_metric_data(self, data):
raise ValueError('Data type not supported: {}'.format(data['type']))

def handle_add_customized_trial(self, data):
    """Restart parameter-id numbering right after the ids carried by ``data``.

    The largest id reported by ``self.recover_parameter_id(data)`` is taken as
    already used, so the global ``_next_parameter_id`` is set to that id plus one.
    """
    global _next_parameter_id
    # data: parameters
    _next_parameter_id = 1 + self.recover_parameter_id(data)

def handle_import_data(self, data):
pass
13 changes: 0 additions & 13 deletions nni/algorithms/hpo/tpe_tuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,19 +218,6 @@ def import_data(self, data): # for resuming experiment
self.dedup.add_history(param)
_logger.info(f'Replayed {len(data)} FINISHED trials')

def import_customized_data(self, data):  # for dedup customized / resumed
    """Register customized / resumed trials as running and add them to the dedup history.

    ``data`` is a list of trial records, or a string that ``nni.load`` turns
    into one. Each record (itself possibly a string to load) looks like
    ``{'parameter_id': 0, 'parameter_source': 'resumed', 'parameters': {'batch_size': 128, ...}}``.
    """
    records = nni.load(data) if isinstance(data, str) else data

    for record in records:
        entry = nni.load(record) if isinstance(record, str) else record
        formatted = format_parameters(entry['parameters'], self.space)
        self._running_params[entry['parameter_id']] = formatted
        self.dedup.add_history(formatted)
    _logger.info(f'Replayed {len(records)} RUNING/WAITING trials')

def suggest(args, rng, space, history):
params = {}
for key, spec in space.items():
Expand Down
26 changes: 23 additions & 3 deletions nni/nas/execution/common/integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
__all__ = ['RetiariiAdvisor']

import logging
import time
import os
from typing import Any, Callable, Optional, Dict, List, Tuple

Expand Down Expand Up @@ -60,11 +61,12 @@ def __init__(self, url: str):
self.final_metric_callback: Optional[Callable[[int, MetricData], None]] = None

self.parameters_count = 0

# Sometimes messages arrive first before the callbacks get registered.
# Or in case that we allow engine to be absent during the experiment.
# Here we need to store the messages and invoke them later.
self.call_queue: List[Tuple[str, list]] = []
# Used to wait for the to-be-recovered trials from nnimanager before any new trial is sent
self._advisor_initialized = False

def register_callbacks(self, callbacks: Dict[str, Callable[..., None]]):
"""
Expand Down Expand Up @@ -167,6 +169,10 @@ def send_trial(self, parameters, placement_constraint=None):
Parameter ID that is assigned to this parameter,
which will be used for identification in future.
"""
while not self._advisor_initialized:
_logger.info('Wait for RetiariiAdvisor to be initialized...')
time.sleep(0.5)

self.parameters_count += 1
if placement_constraint is None:
placement_constraint = {
Expand Down Expand Up @@ -204,6 +210,7 @@ def mark_experiment_as_ending(self):
self.send(CommandType.NoMoreTrialJobs, '')

def handle_request_trial_jobs(self, num_trials):
    """Mark the advisor as initialized and pass the request to the ``request_trial_jobs`` callback."""
    # Set the flag before anything else: `send_trial` sleeps in a loop until it becomes True.
    self._advisor_initialized = True
    _logger.debug('Request trial jobs: %s', num_trials)
    self.invoke_callback('request_trial_jobs', num_trials)

Expand All @@ -212,10 +219,22 @@ def handle_update_search_space(self, data):
self.search_space = data

def handle_trial_end(self, data):
    """Forward a trial-end event to the ``trial_end`` callback exactly once.

    ``data['hyper_params']`` is decoded with ``nni.load`` to obtain the
    ``parameter_id``; ``data['event']`` is reported to the callback as a
    bool that is ``True`` only for ``'SUCCEEDED'``. The end of a trial that
    ``is_created_in_previous_exp`` reports as recovered is logged and dropped.
    """
    # TODO: we should properly handle the trials in self._customized_parameter_ids instead of ignoring
    id_ = nni.load(data['hyper_params'])['parameter_id']
    if self.is_created_in_previous_exp(id_):
        _logger.info('The end of the recovered trial %d is ignored', id_)
        return
    _logger.debug('Trial end: %s', data)
    # Invoke the callback a single time, reusing the id decoded above; a second
    # invocation (re-decoding hyper_params) would report the same end twice.
    self.invoke_callback('trial_end', id_, data['event'] == 'SUCCEEDED')

def handle_report_metric_data(self, data):
# TODO: we should properly handle the trials in self._customized_parameter_ids instead of ignoring
if self.is_created_in_previous_exp(data['parameter_id']):
_logger.info('The metrics of the recovered trial %d are ignored', data['parameter_id'])
return
# NOTE: this part is not aligned with hpo tuners.
# in hpo tuners, trial_job_id is used for intermediate results handling
# parameter_id is for final result handling.
_logger.debug('Metric reported: %s', data)
if data['type'] == MetricType.REQUEST_PARAMETER:
raise ValueError('Request parameter not supported')
Expand All @@ -239,4 +258,5 @@ def handle_import_data(self, data):
pass

def handle_add_customized_trial(self, data):
    """Set ``parameters_count`` to the id returned by ``recover_parameter_id(data)``."""
    self.parameters_count = self.recover_parameter_id(data)
1 change: 1 addition & 0 deletions nni/nas/execution/common/integration_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import nni
from nni.common.version import version_check


# NOTE: this is only for passing flake8, we cannot import RetiariiAdvisor
# because it would induce cycled import
RetiariiAdvisor = NewType('RetiariiAdvisor', Any)
Expand Down
30 changes: 30 additions & 0 deletions nni/recoverable.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,12 @@
from __future__ import annotations

import os
import nni

class Recoverable:
def __init__(self):
    # recovered_max_param_id: largest parameter id seen by recover_parameter_id;
    #   -1 means none yet, so no non-negative id compares as recovered.
    # recovered_trial_params: parameters of the recovered trials, keyed by parameter id.
    self.recovered_max_param_id, self.recovered_trial_params = -1, {}

def load_checkpoint(self) -> None:
pass
Expand All @@ -18,3 +22,29 @@ def get_checkpoint_path(self) -> str | None:
if ckp_path is not None and os.path.isdir(ckp_path):
return ckp_path
return None

def recover_parameter_id(self, data) -> int:
    """Record the parameters of resumed trials and report the largest id among them.

    Parameters
    ----------
    data
        One trial record or a list of them. A record is a dict (or a string
        that ``nni.load`` turns into one) such as
        ``{'parameter_id': 0, 'parameter_source': 'resumed', 'parameters': {'batch_size': 128, ...}}``.

    Returns
    -------
    int
        The largest integer ``parameter_id`` found in ``data``, or 0 when
        ``data`` contains no integer id.
    """
    # this is for handling the resuming of the interrupted data: parameters
    if not isinstance(data, list):
        data = [data]

    batch_max_param_id = None
    for trial in data:
        if isinstance(trial, str):
            trial = nni.load(trial)
        param_id = trial['parameter_id']
        if not isinstance(param_id, int):
            # for dealing with user customized trials
            # skip for now
            continue
        self.recovered_trial_params[param_id] = trial['parameters']
        if batch_max_param_id is None or batch_max_param_id < param_id:
            batch_max_param_id = param_id

    if batch_max_param_id is None:
        # Nothing with an integer id was recovered: keep the marker as it is,
        # otherwise id 0 would be reported as recovered although no
        # parameters were stored for it.
        return 0

    # Never lower the marker: a later, smaller batch must not un-mark ids
    # that were recovered earlier.
    self.recovered_max_param_id = max(self.recovered_max_param_id, batch_max_param_id)
    return max(batch_max_param_id, 0)

def is_created_in_previous_exp(self, param_id: int) -> bool:
    """Tell whether ``param_id`` is at or below the recovered-id marker.

    Returns ``True`` when ``param_id`` is an integer not greater than
    ``self.recovered_max_param_id``. A non-integer id (e.g. ``None``) is
    never reported as recovered.
    """
    if not isinstance(param_id, int):
        # recover_parameter_id skips records whose id is not an integer, so
        # such an id cannot belong to a recovered trial; comparing it with
        # ``<=`` would raise TypeError instead.
        return False
    return param_id <= self.recovered_max_param_id

def get_previous_param(self, param_id: int) -> dict:
    """Return the parameters stored in ``recovered_trial_params`` for ``param_id``.

    Raises ``KeyError`` when nothing was stored under that id.
    """
    recorded = self.recovered_trial_params
    return recorded[param_id]
24 changes: 15 additions & 9 deletions nni/runtime/msg_dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,15 +120,10 @@ def handle_import_data(self, data):
self.tuner.import_data(data)

def handle_add_customized_trial(self, data):
global _next_parameter_id
# data: parameters
if not isinstance(data, list):
data = [data]

for _ in data:
id_ = _create_parameter_id()
_customized_parameter_ids.add(id_)

self.tuner.import_customized_data(data)
previous_max_param_id = self.recover_parameter_id(data)
_next_parameter_id = previous_max_param_id + 1

def handle_report_metric_data(self, data):
"""
Expand All @@ -137,6 +132,13 @@ def handle_report_metric_data(self, data):
- 'value': metric value reported by nni.report_final_result()
- 'type': report type, support {'FINAL', 'PERIODICAL'}
"""
if self.is_created_in_previous_exp(data['parameter_id']):
if data['type'] == MetricType.FINAL:
# only deal with final metric using import data
param = self.get_previous_param(data['parameter_id'])
trial_data = [{'parameter': param, 'value': load(data['value'])}]
self.handle_import_data(trial_data)
return
# metrics value is dumped as json string in trial, so we need to decode it here
if 'value' in data:
data['value'] = load(data['value'])
Expand Down Expand Up @@ -166,14 +168,18 @@ def handle_trial_end(self, data):
- event: the job's state
- hyper_params: the hyperparameters generated and returned by tuner
"""
id_ = load(data['hyper_params'])['parameter_id']
if self.is_created_in_previous_exp(id_):
# The end of the recovered trial is ignored
return
trial_job_id = data['trial_job_id']
_ended_trials.add(trial_job_id)
if trial_job_id in _trial_history:
_trial_history.pop(trial_job_id)
if self.assessor is not None:
self.assessor.trial_end(trial_job_id, data['event'] == 'SUCCEEDED')
if self.tuner is not None:
self.tuner.trial_end(load(data['hyper_params'])['parameter_id'], data['event'] == 'SUCCEEDED')
self.tuner.trial_end(id_, data['event'] == 'SUCCEEDED')

def _handle_final_metric_data(self, data):
"""Call tuner to process final results
Expand Down
1 change: 1 addition & 0 deletions nni/runtime/msg_dispatcher_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ class MsgDispatcherBase(Recoverable):
"""

def __init__(self, command_channel_url=None):
super().__init__()
self.stopping = False
if command_channel_url is None:
command_channel_url = dispatcher_env_vars.NNI_TUNER_COMMAND_CHANNEL
Expand Down
8 changes: 0 additions & 8 deletions nni/tuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,14 +219,6 @@ def import_data(self, data: list[TrialRecord]) -> None:
# data: a list of dictionaries, each of which has at least two keys, 'parameter' and 'value'
pass

def import_customized_data(self, data: list[TrialRecord]) -> None:
    """
    Internal API under revising, not recommended for end users.
    """
    # Hook for importing resume data so that duplications can be avoided; does nothing here.
    # data: a list of dictionaries, each of which has at least two keys, 'parameter_id' and 'parameters'
    return None

def _on_exit(self) -> None:
pass

Expand Down
3 changes: 3 additions & 0 deletions test/algo/nas/test_cgo_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,9 @@ def test_submit_models(self):
advisor._channel = protocol.LegacyCommandChannel()
advisor.default_worker.start()
advisor.assessor_worker.start()
# this is because RetiariiAdvisor only works after `_advisor_initialized` becomes True.
# normally it becomes true when `handle_request_trial_jobs` is invoked
advisor._advisor_initialized = True

remote = RemoteConfig(machine_list=[])
remote.machine_list.append(RemoteMachineConfig(host='test', gpu_indices=[0,1,2,3]))
Expand Down
2 changes: 2 additions & 0 deletions test/ut/nas/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def test_base_execution_engine(self):
nni.retiarii.integration_api._advisor = None
nni.retiarii.execution.api._execution_engine = None
advisor = RetiariiAdvisor('ws://_unittest_placeholder_')
advisor._advisor_initialized = True
advisor._channel = LegacyCommandChannel()
advisor.default_worker.start()
advisor.assessor_worker.start()
Expand All @@ -44,6 +45,7 @@ def test_py_execution_engine(self):
nni.retiarii.integration_api._advisor = None
nni.retiarii.execution.api._execution_engine = None
advisor = RetiariiAdvisor('ws://_unittest_placeholder_')
advisor._advisor_initialized = True
advisor._channel = LegacyCommandChannel()
advisor.default_worker.start()
advisor.assessor_worker.start()
Expand Down
10 changes: 5 additions & 5 deletions test/ut/sdk/test_assessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,11 @@ class AssessorTestCase(TestCase):
def test_assessor(self):
pass
_reverse_io()
send(CommandType.ReportMetricData, '{"trial_job_id":"A","type":"PERIODICAL","sequence":0,"value":"2"}')
send(CommandType.ReportMetricData, '{"trial_job_id":"B","type":"PERIODICAL","sequence":0,"value":"2"}')
send(CommandType.ReportMetricData, '{"trial_job_id":"A","type":"PERIODICAL","sequence":1,"value":"3"}')
send(CommandType.TrialEnd, '{"trial_job_id":"A","event":"SYS_CANCELED"}')
send(CommandType.TrialEnd, '{"trial_job_id":"B","event":"SUCCEEDED"}')
send(CommandType.ReportMetricData, '{"parameter_id": 0,"trial_job_id":"A","type":"PERIODICAL","sequence":0,"value":"2"}')
send(CommandType.ReportMetricData, '{"parameter_id": 1,"trial_job_id":"B","type":"PERIODICAL","sequence":0,"value":"2"}')
send(CommandType.ReportMetricData, '{"parameter_id": 0,"trial_job_id":"A","type":"PERIODICAL","sequence":1,"value":"3"}')
send(CommandType.TrialEnd, '{"trial_job_id":"A","event":"SYS_CANCELED","hyper_params":"{\\"parameter_id\\": 0}"}')
send(CommandType.TrialEnd, '{"trial_job_id":"B","event":"SUCCEEDED","hyper_params":"{\\"parameter_id\\": 1}"}')
send(CommandType.NewTrialJob, 'null')
_restore_io()

Expand Down

0 comments on commit bcc640c

Please sign in to comment.