Skip to content

Commit

Permalink
Refine HPO (#2175)
Browse files Browse the repository at this point in the history
* change pipe to queue

* fix a bug where HPO doesn't work correctly with multiple GPUs

* add unit test

* remove unused hpo weight

* add try except statement

* stop HPO if trials failed more than three times

* align with pre-commit

* fix mypy issue

* update CHANGELOG.md

* align with mypy
  • Loading branch information
eunwoosh authored May 24, 2023
1 parent e1004eb commit ca29281
Show file tree
Hide file tree
Showing 8 changed files with 231 additions and 103 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ All notable changes to this project will be documented in this file.
- Action task refactoring (<https://github.com/openvinotoolkit/training_extensions/pull/1993>)
- Optimize data preprocessing time and enhance overall performance in semantic segmentation (<https://github.com/openvinotoolkit/training_extensions/pull/2020>)
- Support automatic batch size decrease when there is not enough GPU memory (<https://github.com/openvinotoolkit/training_extensions/pull/2022>)
- Refine HPO usability (<https://github.com/openvinotoolkit/training_extensions/pull/2175>)

### Bug fixes

Expand Down
38 changes: 28 additions & 10 deletions otx/cli/utils/hpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,15 +447,15 @@ def _align_batch_size_search_space_to_dataset_size(self):
self._fixed_hp[batch_size_name] = self._train_dataset_size
self._environment.set_hyper_parameter_using_str_key(self._fixed_hp)

def run_hpo(self, train_func: Callable, data_roots: Dict[str, Dict]) -> Dict[str, Any]:
def run_hpo(self, train_func: Callable, data_roots: Dict[str, Dict]) -> Union[Dict[str, Any], None]:
"""Run HPO and provides optimized hyper parameters.
Args:
train_func (Callable): training model function
data_roots (Dict[str, Dict]): dataset path of each dataset type
Returns:
Dict[str, Any]: optimized hyper parameters
Union[Dict[str, Any], None]: Optimized hyper parameters. If there is no best hyper parameter, return None.
"""
self._environment.save_initial_weight(self._get_initial_model_weight_path())
hpo_algo = self._get_hpo_algo()
Expand All @@ -474,7 +474,8 @@ def run_hpo(self, train_func: Callable, data_roots: Dict[str, Dict]) -> Dict[str
resource_type, # type: ignore
)
best_config = hpo_algo.get_best_config()
self._restore_fixed_hp(best_config["config"])
if best_config is not None:
self._restore_fixed_hp(best_config["config"])
hpo_algo.print_result()

return best_config
Expand Down Expand Up @@ -574,17 +575,28 @@ def run_hpo(
logger.info("completed hyper-parameter optimization")

env_manager = TaskEnvironmentManager(environment)
env_manager.set_hyper_parameter_using_str_key(best_config["config"])
best_hpo_weight = get_best_hpo_weight(hpo_save_path, best_config["id"])
if best_hpo_weight is None:
logger.warning("Can not find the best HPO weight. Best HPO wegiht won't be used.")
else:
logger.debug(f"{best_hpo_weight} will be loaded as best HPO weight")
env_manager.load_model_weight(best_hpo_weight, dataset)
best_hpo_weight = None

if best_config is not None:
env_manager.set_hyper_parameter_using_str_key(best_config["config"])
best_hpo_weight = get_best_hpo_weight(hpo_save_path, best_config["id"])
if best_hpo_weight is None:
logger.warning("Can not find the best HPO weight. Best HPO wegiht won't be used.")
else:
logger.debug(f"{best_hpo_weight} will be loaded as best HPO weight")
env_manager.load_model_weight(best_hpo_weight, dataset)

_remove_unused_model_weights(hpo_save_path, best_hpo_weight)
return env_manager.environment


def _remove_unused_model_weights(hpo_save_path: Path, best_hpo_weight: Optional[str] = None):
    """Delete every ``*.pth`` file under the HPO directory except the best HPO weight.

    Args:
        hpo_save_path (Path): Directory searched recursively for ``*.pth`` files.
        best_hpo_weight (Optional[str]): Path of the weight file to keep.
            If None, every ``*.pth`` file that is found is removed.
    """
    # Compare resolved paths rather than raw strings, so the best weight is still recognized when
    # it is spelled differently from what rglob yields (relative vs. absolute path, ".." parts).
    weight_to_keep = Path(best_hpo_weight).resolve() if best_hpo_weight is not None else None
    for weight in hpo_save_path.rglob("*.pth"):
        if weight_to_keep is not None and weight.resolve() == weight_to_keep:
            continue
        weight.unlink()


def get_best_hpo_weight(hpo_dir: Union[str, Path], trial_id: Union[str, Path]) -> Optional[str]:
"""Get best model weight path of the HPO trial.
Expand Down Expand Up @@ -679,6 +691,12 @@ def run(self):
need_to_save_initial_weight = False
resume_weight_path = self._get_resume_weight_path()
if resume_weight_path is not None:
ret = re.search(r"(\d+)\.pth", resume_weight_path)
if ret is not None:
resume_epoch = int(ret.group(1))
if self._epoch <= resume_epoch: # given epoch is already done
self._report_func(0, 0, done=True)
return
environment.resume_model_weight(resume_weight_path, dataset)
else:
initial_weight = self._load_fixed_initial_weight()
Expand Down
18 changes: 8 additions & 10 deletions otx/hpo/hpo_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,7 @@ def __init__(self, trial_id: Any, configuration: Dict, train_environment: Option
self._train_environment = train_environment
self._iteration = None
self.status: TrialStatus = TrialStatus.READY
self._done = False

@property
def id(self):
Expand All @@ -204,6 +205,8 @@ def iteration(self, val):
"""Setter for iteration."""
check_positive(val, "iteration")
self._iteration = val
if self.get_progress() < val:
self._done = False

@property
def train_environment(self):
Expand Down Expand Up @@ -279,21 +282,16 @@ def save_results(self, save_path: str):
json.dump(results, f)

def finalize(self):
"""Let the trial know that training is done.
If the trial isn't trained until given resource, then make it pretend to train until resouce.
"""
if self.get_progress() < self.iteration:
best_score = self.get_best_score()
if best_score is None:
raise RuntimeError(f"Although {self.id} trial doesn't report any score but it's done")
self.register_score(best_score, self.iteration)
"""Set done as True."""
if not self.score:
raise RuntimeError(f"Trial{self.id} didn't report any score but tries to be done.")
self._done = True

def is_done(self):
"""Check the trial is done."""
if self.iteration is None:
raise ValueError("iteration isn't set yet.")
return self.get_progress() >= self.iteration
return self._done or self.get_progress() >= self.iteration


class TrialStatus(IntEnum):
Expand Down
135 changes: 95 additions & 40 deletions otx/hpo/hpo_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,12 @@
import logging
import multiprocessing
import os
import queue
import signal
import sys
import time
from copy import deepcopy
from dataclasses import dataclass
from functools import partial
from typing import Any, Callable, Dict, Literal, Optional, Union

Expand All @@ -27,6 +32,15 @@
logger = logging.getLogger(__name__)


@dataclass
class RunningTrial:
    """Data class for a running trial.

    Groups the objects HpoLoop keeps for each trial it has launched.
    """

    process: multiprocessing.Process  # child process created to run the trial
    trial: Trial  # trial whose train configuration is given to the process
    queue: multiprocessing.Queue  # queue on which the loop puts the trial status for the process


class HpoLoop:
"""HPO loop manager to run trials.
Expand Down Expand Up @@ -54,26 +68,41 @@ def __init__(
):
self._hpo_algo = hpo_algo
self._train_func = train_func
self._running_trials: Dict[int, Dict] = {}
self._running_trials: Dict[int, RunningTrial] = {}
self._mp = multiprocessing.get_context("spawn")
self._report_queue = self._mp.Queue()
self._uid_index = 0
self._trial_fault_count = 0
self._resource_manager = get_resource_manager(
resource_type, num_parallel_trial, num_gpu_for_single_trial, available_gpu
)
self._main_pid = os.getpid()

signal.signal(signal.SIGINT, self._terminate_signal_handler)
signal.signal(signal.SIGTERM, self._terminate_signal_handler)

def run(self):
"""Run a HPO loop."""
logger.info("HPO loop starts.")
while not self._hpo_algo.is_done():
if self._resource_manager.have_available_resource():
trial = self._hpo_algo.get_next_sample()
if trial is not None:
self._start_trial_process(trial)
try:
while not self._hpo_algo.is_done() and self._trial_fault_count < 3:
if self._resource_manager.have_available_resource():
trial = self._hpo_algo.get_next_sample()
if trial is not None:
self._start_trial_process(trial)

self._remove_finished_process()
self._get_reports()

time.sleep(1)
except Exception as e:
self._terminate_all_running_processes()
raise e
logger.info("HPO loop is done.")

self._remove_finished_process()
self._get_reports()
if self._trial_fault_count >= 3:
logger.warning("HPO trials exited abnormally more than three times. HPO is suspended.")

logger.info("HPO loop is done.")
self._get_reports()
self._join_all_processes()

Expand All @@ -90,56 +119,51 @@ def _start_trial_process(self, trial: Trial):
for key, val in env.items():
os.environ[key] = val

pipe1, pipe2 = self._mp.Pipe(True)
trial_queue = self._mp.Queue()
process = self._mp.Process(
target=_run_train,
args=(
self._train_func,
trial.get_train_configuration(),
partial(_report_score, pipe=pipe2, trial_id=trial.id),
partial(_report_score, recv_queue=trial_queue, send_queue=self._report_queue, uid=uid),
),
)
os.environ = origin_env
self._running_trials[uid] = {"process": process, "trial": trial, "pipe": pipe1}
self._running_trials[uid] = RunningTrial(process, trial, trial_queue) # type: ignore
process.start()

def _remove_finished_process(self):
trial_to_remove = []
for uid, val in self._running_trials.items():
process = val["process"]
if not process.is_alive():
val["pipe"].close()
process.join()
for uid, trial in self._running_trials.items():
if not trial.process.is_alive():
if trial.process.exitcode != 0:
self._trial_fault_count += 1
trial.queue.close()
trial.process.join()
trial_to_remove.append(uid)

for uid in trial_to_remove:
trial = self._running_trials[uid]["trial"]
trial.status = TrialStatus.STOP
self._running_trials[uid].trial.status = TrialStatus.STOP
self._resource_manager.release_resource(uid)
del self._running_trials[uid]

def _get_reports(self):
for trial in self._running_trials.values():
pipe = trial["pipe"]
if pipe.poll():
try:
report = pipe.recv()
except EOFError:
continue
trial_status = self._hpo_algo.report_score(
report["score"], report["progress"], report["trial_id"], report["done"]
)
pipe.send(trial_status)
while not self._report_queue.empty():
report = self._report_queue.get_nowait()
trial = self._running_trials[report["uid"]]
trial_status = self._hpo_algo.report_score(
report["score"], report["progress"], trial.trial.id, report["done"]
)
trial.queue.put_nowait(trial_status)

self._hpo_algo.save_results()

def _join_all_processes(self):
for val in self._running_trials.values():
val["pipe"].close()
val.queue.close()

for val in self._running_trials.values():
process = val["process"]
process.join()
val.process.join()

self._running_trials = {}

Expand All @@ -148,23 +172,54 @@ def _get_uid(self) -> int:
self._uid_index += 1
return uid

def _terminate_all_running_processes(self):
    """Close the queue of every running trial and kill its process if it is still alive."""
    for running in self._running_trials.values():
        running.queue.close()
        child = running.process
        if not child.is_alive():
            continue  # already exited, nothing to kill
        logger.info(f"Kill child process {child.pid}")
        child.kill()

def _terminate_signal_handler(self, signum, _frame):
    """Kill the running trial processes and exit when a termination signal is received.

    Args:
        signum: Number of the received signal.
        _frame: Stack frame passed by the signal module. It isn't used.
    """
    # This check prevents child processes from being killed unintentionally by processes
    # forked from the main process: such a process only exits by itself.
    if self._main_pid != os.getpid():
        sys.exit()

    self._terminate_all_running_processes()

    # Look the name up in the signal module instead of a hard-coded number table,
    # so an unexpected signal number can't raise KeyError before the process exits.
    try:
        signal_name = signal.Signals(signum).name
    except ValueError:
        signal_name = f"signal {signum}"
    logger.warning(f"{signal_name} is sent. process exited.")

    sys.exit(1)


def _run_train(train_func: Callable, hp_config: Dict, report_func: Callable):
    """Run train_func with the given hyper parameter configuration and report function.

    Args:
        train_func (Callable): Function called as ``train_func(hp_config, report_func)``.
        hp_config (Dict): Configuration forwarded to train_func as the first argument.
        report_func (Callable): Callable forwarded to train_func as the second argument.
    """
    # set multi process method as default before training starts
    multiprocessing.set_start_method(method=None, force=True)  # type: ignore
    train_func(hp_config, report_func)


def _report_score(score: Union[int, float], progress: Union[int, float], pipe, trial_id: Any, done: bool = False):
logger.debug(f"score : {score}, progress : {progress}, trial_id : {trial_id}, pid : {os.getpid()}, done : {done}")
def _report_score(
score: Union[int, float],
progress: Union[int, float],
recv_queue: multiprocessing.Queue,
send_queue: multiprocessing.Queue,
uid: Any,
done: bool = False,
):
logger.debug(f"score : {score}, progress : {progress}, uid : {uid}, pid : {os.getpid()}, done : {done}")
try:
pipe.send({"score": score, "progress": progress, "trial_id": trial_id, "pid": os.getpid(), "done": done})
except BrokenPipeError:
send_queue.put_nowait({"score": score, "progress": progress, "uid": uid, "pid": os.getpid(), "done": done})
except ValueError:
return TrialStatus.STOP

try:
trial_status = pipe.recv()
except EOFError:
return TrialStatus.STOP
trial_status = recv_queue.get(timeout=3)
except queue.Empty:
return TrialStatus.RUNNING

while not recv_queue.empty():
trial_status = recv_queue.get_nowait()

logger.debug(f"trial_status : {trial_status}")
return trial_status
Expand Down
4 changes: 3 additions & 1 deletion otx/hpo/hyperband.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def __init__(self, trial_id: Any, configuration: Dict, train_environment: Option
super().__init__(trial_id, configuration, train_environment)
self._rung: Optional[int] = None
self._bracket: Optional[int] = None
self.estimating_max_resource: bool = False

@property
def rung(self):
Expand Down Expand Up @@ -708,6 +709,7 @@ def _make_trial_to_estimate_resource(self) -> AshaTrial:
if len(self._trials) == 1: # first trial to estimate
trial.bracket = 0
trial.iteration = self.num_full_iterations
trial.estimating_max_resource = True
elif self._minimum_resource is not None:
trial.iteration = self._minimum_resource
else:
Expand Down Expand Up @@ -917,7 +919,7 @@ def report_score(
"""
trial = self._trials[trial_id]
if done:
if self.maximum_resource is None:
if self.maximum_resource is None and trial.estimating_max_resource:
self.maximum_resource = trial.get_progress()
self.num_full_iterations = self.maximum_resource
if not self._need_to_find_resource_value():
Expand Down
Loading

0 comments on commit ca29281

Please sign in to comment.