From 8b1f67e495251d840221fe5b4c42390ee31a6ab9 Mon Sep 17 00:00:00 2001
From: "Shin, Eunwoo"
Date: Fri, 3 May 2024 17:20:30 +0900
Subject: [PATCH 1/4] draft implementation

---
 src/otx/cli/utils/hpo.py | 85 +++++++++++++++++++++++++++-------------
 src/otx/hpo/hyperband.py | 36 +++++++++++++++++
 2 files changed, 93 insertions(+), 28 deletions(-)

diff --git a/src/otx/cli/utils/hpo.py b/src/otx/cli/utils/hpo.py
index 3511ebf3e15..b9f30b63140 100644
--- a/src/otx/cli/utils/hpo.py
+++ b/src/otx/cli/utils/hpo.py
@@ -478,6 +478,9 @@ def run_hpo(self, train_func: Callable, data_roots: Dict[str, Dict]) -> Union[Di
         progress_updater_thread = Thread(target=self._update_hpo_progress, args=[hpo_algo], daemon=True)
         progress_updater_thread.start()
 
+        remove_unused_model_weight = Thread(target=self._remove_unused_weight, args=[hpo_algo, self._hpo_workdir])
+        remove_unused_model_weight.start()
+
         if torch.cuda.is_available():
             resource_type = "gpu"
         elif is_xpu_available():
@@ -579,6 +582,23 @@ def _update_hpo_progress(self, hpo_algo: HpoBase):
             self._progress_updater_callback(hpo_algo.get_progress() * 100)
             time.sleep(1)
 
+    def _remove_unused_weight(self, hpo_algo: HpoBase, hpo_work_dir: Path):
+        """Function for a thread to report a HPO progress regularly.
+
+        Args:
+            hpo_algo (HpoBase): HPO algorithm class
+        """
+
+        while not hpo_algo.is_done():
+            finished_trials = hpo_algo.get_finished_trials()
+            for trial in finished_trials:
+                dir_to_remove = hpo_work_dir / "weight" / str(trial.id)
+                if dir_to_remove.exists():
+                    for file in dir_to_remove.iterdir():
+                        file.unlink()
+                    # shutil.rmtree(dir_to_remove)
+            time.sleep(1)
+
 
 def run_hpo(
     hpo_time_ratio: int,
@@ -636,11 +656,11 @@ def run_hpo(
         logger.debug(f"{best_hpo_weight} will be loaded as best HPO weight")
         env_manager.load_model_weight(best_hpo_weight, dataset)
 
-    _remove_unused_model_weights(hpo_save_path, best_hpo_weight)
+    # _remove_model_weights_except_best(hpo_save_path, best_hpo_weight)
     return env_manager.environment
 
 
-def _remove_unused_model_weights(hpo_save_path: Path, best_hpo_weight: Optional[str] = None):
+def _remove_model_weights_except_best(hpo_save_path: Path, best_hpo_weight: Optional[str] = None):
     for weight in hpo_save_path.rglob("*.pth"):
         if best_hpo_weight is not None and str(weight) == best_hpo_weight:
             continue
@@ -663,7 +683,27 @@ def get_best_hpo_weight(hpo_dir: Union[str, Path], trial_id: Union[str, Path]) -
         return None
     trial_output_file = trial_output_files[0]
 
-    with trial_output_file.open("r") as f:
+    _, best_epochs = _get_best_score_and_epoch(trial_output_file)
+
+    best_weight = None
+    for best_epoch in best_epochs:
+        best_weight_path = list(hpo_dir.glob(f"weight/{trial_id}/*epoch*{best_epoch}*"))
+        if best_weight_path:
+            best_weight = str(best_weight_path[0])
+
+    return best_weight
+
+
+def _get_best_score_and_epoch(trial_json: Path)-> tuple[int | float | None, list[int]]:
+    """Get best score and epochs according to json file of the trial.
+
+    Args:
+        trial_json (Path): Json file of the trial.
+
+    Returns:
+        tuple[int | float | None, list[int]]: best score and best epochs list.
+    """
+    with trial_json.open("r") as f:
         trial_output = json.load(f)
 
     best_epochs = []
@@ -677,14 +717,8 @@ def get_best_hpo_weight(hpo_dir: Union[str, Path], trial_id: Union[str, Path]) -
             best_epochs = [eph]
         elif best_score == score:
             best_epochs.append(eph)
-
-    best_weight = None
-    for best_epoch in best_epochs:
-        best_weight_path = list(hpo_dir.glob(f"weight/{trial_id}/*epoch*{best_epoch}*"))
-        if best_weight_path:
-            best_weight = str(best_weight_path[0])
-
-    return best_weight
+
+    return best_score, best_epochs
 
 
 class Trainer:
@@ -740,6 +774,7 @@ def run(self):
 
         need_to_save_initial_weight = False
         resume_weight_path = self._get_resume_weight_path()
+        resume_epoch = None
        if resume_weight_path is not None:
             ret = re.search(r"(\d+)\.pth", resume_weight_path)
             if ret is not None:
@@ -763,7 +798,6 @@ def run(self):
         score_report_callback = self._prepare_score_report_callback(task)
         task.train(dataset=dataset, output_model=output_model, train_parameters=score_report_callback)
         self._finalize_trial(task)
-        self._delete_unused_model_weight()
 
     def _prepare_hyper_parameter(self):
         return create(self._model_template.hyper_parameters.data)
@@ -833,27 +867,22 @@ def _get_initial_weight_path(self) -> Path:
     def _finalize_trial(self, task):
         weight_dir_path = self._get_weight_dir_path()
         weight_dir_path.mkdir(parents=True, exist_ok=True)
-        self._task.copy_weight(task.project_path, weight_dir_path)
         self._report_func(0, 0, done=True)
+        trial_id: str = self._hp_config["id"]
+        self._task.copy_weight(task.project_path, weight_dir_path)
+        weight_dir = self._hpo_workdir / "weight" / trial_id
+        if not weight_dir.exists():
+            return
+        latest_model_weight = self._task.get_latest_weight(weight_dir)
+        best_model_weight = get_best_hpo_weight(self._hpo_workdir, trial_id)
+        for each_model_weight in weight_dir.iterdir():
+            if str(each_model_weight) not in [latest_model_weight, best_model_weight]:
+                each_model_weight.unlink()
+
     def _get_weight_dir_path(self) -> Path:
         return self._hpo_workdir / "weight" / self._hp_config["id"]
 
-    def _delete_unused_model_weight(self):
-        """Delete model weights except best and latest model weight."""
-        for json_file in self._hpo_workdir.rglob("*.json"):
-            if not json_file.stem.isnumeric():
-                continue
-            trial_num = json_file.stem
-            weight_dir = self._hpo_workdir / "weight" / trial_num
-            if not weight_dir.exists():
-                continue
-            latest_model_weight = self._task.get_latest_weight(weight_dir)
-            best_model_weight = get_best_hpo_weight(self._hpo_workdir, trial_num)
-            for each_model_weight in weight_dir.iterdir():
-                if str(each_model_weight) not in [latest_model_weight, best_model_weight]:
-                    each_model_weight.unlink()
-
 
 def run_trial(
     hp_config: Dict[str, Any],
diff --git a/src/otx/hpo/hyperband.py b/src/otx/hpo/hyperband.py
index 49e5b5003ed..a347be7a076 100644
--- a/src/otx/hpo/hyperband.py
+++ b/src/otx/hpo/hyperband.py
@@ -240,6 +240,28 @@ def get_next_trial(self) -> Optional[AshaTrial]:
                 return trial
         return None
 
+    def get_finished_trials(self, mode: str = "max") -> List[AshaTrial]:
+        finished_trials = []
+        num_temp = []
+        num_trials_to_promote = self._num_required_trial // self._reduction_factor
+        if num_trials_to_promote <= 0:
+            return []
+        for trial in self._trials:
+            if trial.rung == self._rung_idx and not trial.is_done():
+                continue
+            num_temp.append((trial, trial.get_best_score(mode, self.resource)))
+
+        if len(num_temp) <= num_trials_to_promote:
+            return []
+
+        num_temp = sorted(num_temp, key=lambda x : x[1], reverse=mode=="max")
+        criteria = num_temp[num_trials_to_promote - 1][1]
+        for trial, trial_score in num_temp[num_trials_to_promote:]:
+            if left_vlaue_is_better(criteria, trial_score, mode):
+                finished_trials.append(trial)
+
+        return finished_trials
+
 
 class Bracket:
     """Bracket class. It operates a single SHA using multiple rungs.
@@ -484,6 +506,13 @@ def _get_result(self):
             ],
         }
 
+    def get_finished_trials(self) -> List[AshaTrial]:
+        finished_trials = []
+        for rung in self._rungs[:-1]:
+            finished_trials.extend(rung.get_finished_trials(self._mode))
+
+        return finished_trials
+
 
 class HyperBand(HpoBase):
     """It implements the Asyncronous HyperBand scheduler with iterations only.
@@ -978,3 +1007,10 @@ def print_result(self):
         )
         for bracket in self._brackets.values():
             bracket.print_result()
+
+    def get_finished_trials(self) -> List[AshaTrial]:
+        finished_trials = []
+        for bracket in self._brackets.values():
+            finished_trials.extend(bracket.get_finished_trials())
+
+        return finished_trials

From 3f89d7705e6109c707a65bb9b20c7ce86578ad7b Mon Sep 17 00:00:00 2001
From: "Shin, Eunwoo"
Date: Tue, 7 May 2024 09:48:42 +0900
Subject: [PATCH 2/4] remove whole directory

---
 src/otx/cli/utils/hpo.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/src/otx/cli/utils/hpo.py b/src/otx/cli/utils/hpo.py
index b9f30b63140..273b9dc067e 100644
--- a/src/otx/cli/utils/hpo.py
+++ b/src/otx/cli/utils/hpo.py
@@ -594,9 +594,7 @@ def _remove_unused_weight(self, hpo_algo: HpoBase, hpo_work_dir: Path):
             for trial in finished_trials:
                 dir_to_remove = hpo_work_dir / "weight" / str(trial.id)
                 if dir_to_remove.exists():
-                    for file in dir_to_remove.iterdir():
-                        file.unlink()
-                    # shutil.rmtree(dir_to_remove)
+                    shutil.rmtree(dir_to_remove)
             time.sleep(1)
 
 

From 11a087090a7c31c5a007300f2436351fcd3715fb Mon Sep 17 00:00:00 2001
From: "Shin, Eunwoo"
Date: Tue, 7 May 2024 10:39:09 +0900
Subject: [PATCH 3/4] refactor code

---
 src/otx/cli/utils/hpo.py | 69 ++++++++++++++--------------------------
 src/otx/hpo/hpo_base.py  | 11 +++++--
 src/otx/hpo/hyperband.py | 34 +++++++++++---------
 3 files changed, 51 insertions(+), 63 deletions(-)

diff --git a/src/otx/cli/utils/hpo.py b/src/otx/cli/utils/hpo.py
index 273b9dc067e..5d64608e41f 100644
--- a/src/otx/cli/utils/hpo.py
+++ b/src/otx/cli/utils/hpo.py
@@ -478,7 +478,9 @@ def run_hpo(self, train_func: Callable, data_roots: Dict[str, Dict]) -> Union[Di
         progress_updater_thread = Thread(target=self._update_hpo_progress, args=[hpo_algo], daemon=True)
         progress_updater_thread.start()
 
-        remove_unused_model_weight = Thread(target=self._remove_unused_weight, args=[hpo_algo, self._hpo_workdir])
+        remove_unused_model_weight = Thread(
+            target=self._remove_unused_weight, args=[hpo_algo, self._hpo_workdir], daemon=True
+        )
         remove_unused_model_weight.start()
 
         if torch.cuda.is_available():
@@ -586,11 +588,12 @@ def _remove_unused_weight(self, hpo_algo: HpoBase, hpo_work_dir: Path):
         """Function for a thread to report a HPO progress regularly.
 
         Args:
-            hpo_algo (HpoBase): HPO algorithm class
+            hpo_algo (HpoBase): HPO algorithm instance.
+            hpo_work_dir (Path): HPO work directory.
         """
 
         while not hpo_algo.is_done():
-            finished_trials = hpo_algo.get_finished_trials()
+            finished_trials = hpo_algo.get_inferior_trials()
             for trial in finished_trials:
                 dir_to_remove = hpo_work_dir / "weight" / str(trial.id)
                 if dir_to_remove.exists():
@@ -654,17 +657,9 @@ def run_hpo(
         logger.debug(f"{best_hpo_weight} will be loaded as best HPO weight")
         env_manager.load_model_weight(best_hpo_weight, dataset)
 
-    # _remove_model_weights_except_best(hpo_save_path, best_hpo_weight)
     return env_manager.environment
 
 
-def _remove_model_weights_except_best(hpo_save_path: Path, best_hpo_weight: Optional[str] = None):
-    for weight in hpo_save_path.rglob("*.pth"):
-        if best_hpo_weight is not None and str(weight) == best_hpo_weight:
-            continue
-        weight.unlink()
-
-
 def get_best_hpo_weight(hpo_dir: Union[str, Path], trial_id: Union[str, Path]) -> Optional[str]:
     """Get best model weight path of the HPO trial.
@@ -681,27 +676,7 @@ def get_best_hpo_weight(hpo_dir: Union[str, Path], trial_id: Union[str, Path]) -
         return None
     trial_output_file = trial_output_files[0]
 
-    _, best_epochs = _get_best_score_and_epoch(trial_output_file)
-
-    best_weight = None
-    for best_epoch in best_epochs:
-        best_weight_path = list(hpo_dir.glob(f"weight/{trial_id}/*epoch*{best_epoch}*"))
-        if best_weight_path:
-            best_weight = str(best_weight_path[0])
-
-    return best_weight
-
-
-def _get_best_score_and_epoch(trial_json: Path)-> tuple[int | float | None, list[int]]:
-    """Get best score and epochs according to json file of the trial.
-
-    Args:
-        trial_json (Path): Json file of the trial.
-
-    Returns:
-        tuple[int | float | None, list[int]]: best score and best epochs list.
-    """
-    with trial_json.open("r") as f:
+    with trial_output_file.open("r") as f:
         trial_output = json.load(f)
 
     best_epochs = []
@@ -715,8 +690,14 @@ def get_best_hpo_weight(hpo_dir: Union[str, Path], trial_id: Union[str, Path]) -
             best_epochs = [eph]
         elif best_score == score:
             best_epochs.append(eph)
-
-    return best_score, best_epochs
+
+    best_weight = None
+    for best_epoch in best_epochs:
+        best_weight_path = list(hpo_dir.glob(f"weight/{trial_id}/*epoch*{best_epoch}*"))
+        if best_weight_path:
+            best_weight = str(best_weight_path[0])
+
+    return best_weight
 
 
 class Trainer:
@@ -772,7 +753,6 @@ def run(self):
 
         need_to_save_initial_weight = False
         resume_weight_path = self._get_resume_weight_path()
-        resume_epoch = None
         if resume_weight_path is not None:
             ret = re.search(r"(\d+)\.pth", resume_weight_path)
             if ret is not None:
@@ -863,19 +843,18 @@ def _get_initial_weight_path(self) -> Path:
         return self._hpo_workdir / self._initial_weight_name
 
     def _finalize_trial(self, task):
-        weight_dir_path = self._get_weight_dir_path()
-        weight_dir_path.mkdir(parents=True, exist_ok=True)
         self._report_func(0, 0, done=True)
-        trial_id: str = self._hp_config["id"]
+        weight_dir_path = self._get_weight_dir_path()
+        weight_dir_path.mkdir(parents=True, exist_ok=True)
         self._task.copy_weight(task.project_path, weight_dir_path)
-        weight_dir = self._hpo_workdir / "weight" / trial_id
-        if not weight_dir.exists():
-            return
-        latest_model_weight = self._task.get_latest_weight(weight_dir)
-        best_model_weight = get_best_hpo_weight(self._hpo_workdir, trial_id)
-        for each_model_weight in weight_dir.iterdir():
-            if str(each_model_weight) not in [latest_model_weight, best_model_weight]:
+        latest_model_weight = self._task.get_latest_weight(weight_dir_path)
+        best_model_weight = get_best_hpo_weight(self._hpo_workdir, self._hp_config["id"])
+        for each_model_weight in weight_dir_path.iterdir():
+            for neccesary_weight in [latest_model_weight, best_model_weight]:
+                if each_model_weight.samefile(neccesary_weight):
+                    break
+            else:
                 each_model_weight.unlink()
 
     def _get_weight_dir_path(self) -> Path:
diff --git a/src/otx/hpo/hpo_base.py b/src/otx/hpo/hpo_base.py
index 17ebc9da4be..2d5e4b3a1e5 100644
--- a/src/otx/hpo/hpo_base.py
+++ b/src/otx/hpo/hpo_base.py
@@ -18,7 +18,7 @@
 import tempfile
 from abc import ABC, abstractmethod
 from enum import IntEnum
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, List, Literal, Optional, Union
 
 from otx.hpo.search_space import SearchSpace
 from otx.hpo.utils import check_mode_input, check_positive
@@ -36,7 +36,7 @@ class HpoBase(ABC):
     Args:
         search_space (Dict[str, Dict[str, Any]]): hyper parameter search space to find.
         save_path (Optional[str]): path where result of HPO is saved.
-        mode (str, optinal): One of {min, max}. Determines whether objective is
+        mode (Literal["min", "max], optinal): One of {min, max}. Determines whether objective is
             minimizing or maximizing the metric attribute.
         num_trials (Optional[int]): How many training to conduct for HPO.
         num_workers (int): How many trains are executed in parallel.
@@ -66,7 +66,7 @@ def __init__(
         self,
         search_space: Dict[str, Dict[str, Any]],
         save_path: Optional[str] = None,
-        mode: str = "max",
+        mode: Literal["min", "max"] = "max",
         num_trials: Optional[int] = None,
         num_workers: int = 1,
         num_full_iterations: Union[int, float] = 1,
@@ -166,6 +166,11 @@ def get_best_config(self):
         """Get best config of HPO algorithm."""
         raise NotImplementedError
 
+    @abstractmethod
+    def get_inferior_trials(self) -> List["Trial"]:
+        """Get trials which can't be best a trial."""
+        raise NotImplementedError
+
 
 class Trial:
     """Trial to train with given hyper parameters.
diff --git a/src/otx/hpo/hyperband.py b/src/otx/hpo/hyperband.py
index a347be7a076..98fd2316c46 100644
--- a/src/otx/hpo/hyperband.py
+++ b/src/otx/hpo/hyperband.py
@@ -18,7 +18,7 @@
 import math
 import os
 from os import path as osp
-from typing import Any, Dict, List, Literal, Optional, Union
+from typing import Any, Dict, List, Literal, Optional, Tuple, Union
 
 from scipy.stats.qmc import LatinHypercube
 
@@ -240,23 +240,24 @@ def get_next_trial(self) -> Optional[AshaTrial]:
                 return trial
         return None
 
-    def get_finished_trials(self, mode: str = "max") -> List[AshaTrial]:
+    def get_inferior_trials(self, mode: Literal["min", "max"] = "max") -> List[AshaTrial]:
+        """Get trials which was done but can't be promoted."""
         finished_trials = []
-        num_temp = []
-        num_trials_to_promote = self._num_required_trial // self._reduction_factor
+        rung_trials: List[Tuple[AshaTrial, Union[int, float]]] = []
+        num_trials_to_promote = self._num_required_trial // self._reduction_factor
         if num_trials_to_promote <= 0:
             return []
         for trial in self._trials:
             if trial.rung == self._rung_idx and not trial.is_done():
                 continue
-            num_temp.append((trial, trial.get_best_score(mode, self.resource)))
+            rung_trials.append((trial, trial.get_best_score(mode, self.resource)))
 
-        if len(num_temp) <= num_trials_to_promote:
+        if len(rung_trials) <= num_trials_to_promote:
             return []
 
-        num_temp = sorted(num_temp, key=lambda x : x[1], reverse=mode=="max")
-        criteria = num_temp[num_trials_to_promote - 1][1]
-        for trial, trial_score in num_temp[num_trials_to_promote:]:
+        rung_trials = sorted(rung_trials, key=lambda x: x[1], reverse=mode == "max")
+        criteria = rung_trials[num_trials_to_promote - 1][1]
+        for trial, trial_score in rung_trials[num_trials_to_promote:]:
             if left_vlaue_is_better(criteria, trial_score, mode):
                 finished_trials.append(trial)
 
         return finished_trials
@@ -273,7 +274,8 @@ class Bracket:
         hyper_parameter_configurations (List[AshaTrial]): Hyper parameter configuration to try.
         reduction_factor (int): Decicdes how many trials to promote to next rung.
                                 Only top 1 / reduction_factor of rung trials can be promoted.
-        mode (str, optional): Decide which trial is better between having highest score or lowest score.
+        mode (Literal["min", "max], optional):
+            Decide which trial is better between having highest score or lowest score.
             Defaults to "max".
         asynchronous_sha (bool, optional): Whether to operate SHA asynchronously. Defaults to True.
     """
@@ -287,7 +289,7 @@ def __init__(
         maximum_resource: Union[float, int],
         hyper_parameter_configurations: List[AshaTrial],
         reduction_factor: int = 3,
-        mode: str = "max",
+        mode: Literal["min", "max"] = "max",
         asynchronous_sha: bool = True,
     ):
         # pylint: disable=too-many-arguments
@@ -506,10 +508,11 @@ def _get_result(self):
             ],
         }
 
-    def get_finished_trials(self) -> List[AshaTrial]:
+    def get_inferior_trials(self) -> List[AshaTrial]:
+        """Get trials which can't be best a trial."""
         finished_trials = []
         for rung in self._rungs[:-1]:
-            finished_trials.extend(rung.get_finished_trials(self._mode))
+            finished_trials.extend(rung.get_inferior_trials(self._mode))
 
         return finished_trials
 
@@ -1008,9 +1011,10 @@ def print_result(self):
         for bracket in self._brackets.values():
             bracket.print_result()
 
-    def get_finished_trials(self) -> List[AshaTrial]:
+    def get_inferior_trials(self) -> List[AshaTrial]:  # type: ignore[override]
+        """Get trials which can't be best a trial."""
         finished_trials = []
         for bracket in self._brackets.values():
-            finished_trials.extend(bracket.get_finished_trials())
+            finished_trials.extend(bracket.get_inferior_trials())
 
         return finished_trials

From d6554c2868bcb709ed59328d82915c1deff4fe7f Mon Sep 17 00:00:00 2001
From: "Shin, Eunwoo"
Date: Tue, 7 May 2024 12:34:04 +0900
Subject: [PATCH 4/4] update unit test

---
 src/otx/cli/utils/hpo.py         | 20 ++++++++---
 tests/unit/cli/utils/test_hpo.py | 59 +++++++++-----------------------
 2 files changed, 33 insertions(+), 46 deletions(-)

diff --git a/src/otx/cli/utils/hpo.py b/src/otx/cli/utils/hpo.py
index 5d64608e41f..062b30ec214 100644
--- a/src/otx/cli/utils/hpo.py
+++ b/src/otx/cli/utils/hpo.py
@@ -657,9 +657,17 @@ def run_hpo(
         logger.debug(f"{best_hpo_weight} will be loaded as best HPO weight")
         env_manager.load_model_weight(best_hpo_weight, dataset)
 
+    _remove_unused_model_weights(hpo_save_path, best_hpo_weight)
     return env_manager.environment
 
 
+def _remove_unused_model_weights(hpo_save_path: Path, best_hpo_weight: Optional[str] = None):
+    for weight in hpo_save_path.rglob("*.pth"):
+        if best_hpo_weight is not None and str(weight) == best_hpo_weight:
+            continue
+        weight.unlink()
+
+
 def get_best_hpo_weight(hpo_dir: Union[str, Path], trial_id: Union[str, Path]) -> Optional[str]:
     """Get best model weight path of the HPO trial.
 
@@ -848,11 +856,15 @@ def _finalize_trial(self, task):
         weight_dir_path = self._get_weight_dir_path()
         weight_dir_path.mkdir(parents=True, exist_ok=True)
         self._task.copy_weight(task.project_path, weight_dir_path)
-        latest_model_weight = self._task.get_latest_weight(weight_dir_path)
-        best_model_weight = get_best_hpo_weight(self._hpo_workdir, self._hp_config["id"])
+        necessary_weights = [
+            self._task.get_latest_weight(weight_dir_path),
+            get_best_hpo_weight(self._hpo_workdir, self._hp_config["id"]),
+        ]
+        while None in necessary_weights:
+            necessary_weights.remove(None)
         for each_model_weight in weight_dir_path.iterdir():
-            for neccesary_weight in [latest_model_weight, best_model_weight]:
-                if each_model_weight.samefile(neccesary_weight):
+            for necessary_weight in necessary_weights:
+                if each_model_weight.samefile(necessary_weight):
                     break
             else:
                 each_model_weight.unlink()
diff --git a/tests/unit/cli/utils/test_hpo.py b/tests/unit/cli/utils/test_hpo.py
index f01a048a195..e9858077d96 100644
--- a/tests/unit/cli/utils/test_hpo.py
+++ b/tests/unit/cli/utils/test_hpo.py
@@ -464,8 +464,13 @@ def test_init_fix_batch_size(self, cls_task_env, diff_from_min_bs):
         hpo_runner = HpoRunner(cls_task_env, train_dataset_size, 10, "fake_path")
         assert batch_size_name in hpo_runner._fixed_hp
 
+    @pytest.fixture
+    def mock_thread(self, mocker) -> MagicMock:
+        mock_thread = mocker.patch.object(hpo, "Thread")
+        return mock_thread
+
     @e2e_pytest_unit
-    def test_run_hpo(self, mocker, cls_task_env):
+    def test_run_hpo(self, mocker, cls_task_env, mock_thread):
         cls_task_env.model = None
         hpo_runner = HpoRunner(cls_task_env, 100, 10, "fake_path")
         mock_run_hpo_loop = mocker.patch("otx.cli.utils.hpo.run_hpo_loop")
@@ -477,7 +482,7 @@ def test_run_hpo(self, mocker, cls_task_env):
         mock_hb.assert_called()  # make hyperband
 
     @e2e_pytest_unit
-    def test_run_hpo_w_dataset_smaller_than_batch(self, mocker, cls_task_env):
+    def test_run_hpo_w_dataset_smaller_than_batch(self, mocker, cls_task_env, mock_thread):
         cls_task_env.model = None
         hpo_runner = HpoRunner(cls_task_env, 2, 10, "fake_path")
         mock_run_hpo_loop = mocker.patch("otx.cli.utils.hpo.run_hpo_loop")
@@ -494,6 +499,8 @@ class TestTrainer:
     def setup(self, tmp_dir):
         self.weight_format = "epoch_{}.pth"
         self.hpo_workdir = Path(tmp_dir) / "hpo_dir"
+        self.hpo_workdir.mkdir()
+        self.trial_id = "1"
 
     @pytest.fixture
     def tmp_dir(self):
@@ -519,6 +526,8 @@ def mock_task(self, mocker, tmp_dir):
         fake_project_path.mkdir(parents=True)
         for i in range(1, 5):
             (fake_project_path / self.weight_format.format(i)).write_text("fake")
+        with (self.hpo_workdir / f"{self.trial_id}.json").open("w") as f:
+            json.dump({"id": self.trial_id, "score": {"1": 1, "2": 2, "3": 5, "4": 4}}, f)
 
         mock_get_train_task = mocker.patch.object(TaskEnvironmentManager, "get_train_task")
         mock_task = mocker.MagicMock()
@@ -552,8 +561,12 @@ def test_run(self, mocker, cls_template_path, mock_task, tmp_dir):
 
         # check
         mock_report_func.assert_called_once_with(0, 0, done=True)  # finilize report
         assert self.hpo_workdir.exists()  # make a directory to copy weight
-        for i in range(1, 5):  # check model weights are copied
-            assert (self.hpo_workdir / "weight" / trial_id / self.weight_format.format(i)).exists()
+        assert (
+            self.hpo_workdir / "weight" / trial_id / self.weight_format.format(3)
+        ).exists()  # check best weight exists
+        assert (
+            self.hpo_workdir / "weight" / trial_id / self.weight_format.format(4)
+        ).exists()  # check last weight exists
         mock_task.train.assert_called()  # check task.train() is called
 
@@ -589,44 +602,6 @@ def test_run_trial_already_done(self, mocker, cls_template_path, mock_task, tmp_
         mock_report_func.assert_called_once_with(0, 0, done=True)  # finilize report
         mock_task.train.assert_not_called()  # check task.train() is called
 
-    @e2e_pytest_unit
-    def test_delete_unused_model_weight(self, mocker, cls_template_path):
-        # prepare
-        trial0_weight_dir = self.hpo_workdir / "weight" / "0"
-        mocker.patch(
-            "otx.cli.utils.hpo.TaskManager.get_latest_weight", return_value=str(trial0_weight_dir / "latest.pth")
-        )
-        mocker.patch("otx.cli.utils.hpo.get_best_hpo_weight", return_value=str(trial0_weight_dir / "best.pth"))
-
-        self.hpo_workdir.mkdir()
-        (self.hpo_workdir / "0.json").touch()
-        for i in range(2):
-            weight_dir = self.hpo_workdir / "weight" / str(i)
-            weight_dir.mkdir(parents=True)
-            (weight_dir / "latest.pth").touch()
-            (weight_dir / "best.pth").touch()
-            (weight_dir / "unused.pth").touch()
-
-        # run
-        trainer = Trainer(
-            hp_config={"configuration": {"iterations": 10}, "id": "1"},
-            report_func=mocker.MagicMock(),
-            model_template=find_and_parse_model_template(cls_template_path),
-            data_roots=mocker.MagicMock(),
-            task_type=TaskType.CLASSIFICATION,
-            hpo_workdir=self.hpo_workdir,
-            initial_weight_name="fake",
-            metric="fake",
-        )
-        trainer._delete_unused_model_weight()
-
-        assert sorted([f.name for f in (self.hpo_workdir / "weight" / "0").iterdir()]) == sorted(
-            ["latest.pth", "best.pth"]
-        )
-        assert sorted([f.name for f in (self.hpo_workdir / "weight" / "1").iterdir()]) == sorted(
-            ["latest.pth", "best.pth", "unused.pth"]
-        )
-
 
 class TestHpoCallback:
     @e2e_pytest_unit