Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove unnecessary model weights during HPO #3459

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 35 additions & 17 deletions src/otx/cli/utils/hpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,6 +478,11 @@ def run_hpo(self, train_func: Callable, data_roots: Dict[str, Dict]) -> Union[Di
progress_updater_thread = Thread(target=self._update_hpo_progress, args=[hpo_algo], daemon=True)
progress_updater_thread.start()

remove_unused_model_weight = Thread(
target=self._remove_unused_weight, args=[hpo_algo, self._hpo_workdir], daemon=True
)
remove_unused_model_weight.start()

if torch.cuda.is_available():
resource_type = "gpu"
elif is_xpu_available():
Expand Down Expand Up @@ -579,6 +584,22 @@ def _update_hpo_progress(self, hpo_algo: HpoBase):
self._progress_updater_callback(hpo_algo.get_progress() * 100)
time.sleep(1)

def _remove_unused_weight(self, hpo_algo: HpoBase, hpo_work_dir: Path):
    """Function for a thread to remove unused model weights regularly during HPO.

    Polls the HPO algorithm once per second until it finishes, deleting the
    weight directories of trials which can no longer become the best trial.
    Intended to run as a daemon thread alongside the HPO loop.

    Args:
        hpo_algo (HpoBase): HPO algorithm instance.
        hpo_work_dir (Path): HPO work directory.
    """

    while not hpo_algo.is_done():
        finished_trials = hpo_algo.get_inferior_trials()
        for trial in finished_trials:
            # Each trial stores its checkpoints under <hpo_work_dir>/weight/<trial id>.
            dir_to_remove = hpo_work_dir / "weight" / str(trial.id)
            if dir_to_remove.exists():
                shutil.rmtree(dir_to_remove)
        time.sleep(1)  # poll interval; avoids busy-waiting


def run_hpo(
hpo_time_ratio: int,
Expand Down Expand Up @@ -763,7 +784,6 @@ def run(self):
score_report_callback = self._prepare_score_report_callback(task)
task.train(dataset=dataset, output_model=output_model, train_parameters=score_report_callback)
self._finalize_trial(task)
self._delete_unused_model_weight()

def _prepare_hyper_parameter(self):
return create(self._model_template.hyper_parameters.data)
Expand Down Expand Up @@ -831,29 +851,27 @@ def _get_initial_weight_path(self) -> Path:
return self._hpo_workdir / self._initial_weight_name

def _finalize_trial(self, task):
self._report_func(0, 0, done=True)

weight_dir_path = self._get_weight_dir_path()
weight_dir_path.mkdir(parents=True, exist_ok=True)
self._task.copy_weight(task.project_path, weight_dir_path)
self._report_func(0, 0, done=True)
necessary_weights = [
self._task.get_latest_weight(weight_dir_path),
get_best_hpo_weight(self._hpo_workdir, self._hp_config["id"]),
]
while None in necessary_weights:
necessary_weights.remove(None)
for each_model_weight in weight_dir_path.iterdir():
for necessary_weight in necessary_weights:
if each_model_weight.samefile(necessary_weight):
break
else:
each_model_weight.unlink()

def _get_weight_dir_path(self) -> Path:
    """Return this trial's weight directory: <hpo_workdir>/weight/<trial id>."""
    return self._hpo_workdir / "weight" / self._hp_config["id"]

def _delete_unused_model_weight(self):
    """Delete model weights except best and latest model weight.

    Scans the HPO work directory for per-trial score reports (files named
    "<trial number>.json"), and for each reported trial removes every weight
    file that is neither the latest nor the best checkpoint.
    """
    for json_file in self._hpo_workdir.rglob("*.json"):
        # Only files whose stem is a trial number are score reports.
        if not json_file.stem.isnumeric():
            continue
        trial_num = json_file.stem
        weight_dir = self._hpo_workdir / "weight" / trial_num
        if not weight_dir.exists():
            continue
        latest_model_weight = self._task.get_latest_weight(weight_dir)
        best_model_weight = get_best_hpo_weight(self._hpo_workdir, trial_num)
        for each_model_weight in weight_dir.iterdir():
            # NOTE(review): compares str(path) against the helpers' return values —
            # confirm both helpers return str paths built the same way, otherwise
            # the membership test never matches and everything is unlinked.
            if str(each_model_weight) not in [latest_model_weight, best_model_weight]:
                each_model_weight.unlink()


def run_trial(
hp_config: Dict[str, Any],
Expand Down
11 changes: 8 additions & 3 deletions src/otx/hpo/hpo_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import tempfile
from abc import ABC, abstractmethod
from enum import IntEnum
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Literal, Optional, Union

from otx.hpo.search_space import SearchSpace
from otx.hpo.utils import check_mode_input, check_positive
Expand All @@ -36,7 +36,7 @@ class HpoBase(ABC):
Args:
search_space (Dict[str, Dict[str, Any]]): hyper parameter search space to find.
save_path (Optional[str]): path where result of HPO is saved.
mode (str, optinal): One of {min, max}. Determines whether objective is
mode (Literal["min", "max"], optional): One of {min, max}. Determines whether objective is
minimizing or maximizing the metric attribute.
num_trials (Optional[int]): How many training to conduct for HPO.
num_workers (int): How many trains are executed in parallel.
Expand Down Expand Up @@ -66,7 +66,7 @@ def __init__(
self,
search_space: Dict[str, Dict[str, Any]],
save_path: Optional[str] = None,
mode: str = "max",
mode: Literal["min", "max"] = "max",
num_trials: Optional[int] = None,
num_workers: int = 1,
num_full_iterations: Union[int, float] = 1,
Expand Down Expand Up @@ -166,6 +166,11 @@ def get_best_config(self):
"""Get best config of HPO algorithm."""
raise NotImplementedError

@abstractmethod
def get_inferior_trials(self) -> List["Trial"]:
    """Return trials which can no longer become the best trial."""
    raise NotImplementedError


class Trial:
"""Trial to train with given hyper parameters.
Expand Down
46 changes: 43 additions & 3 deletions src/otx/hpo/hyperband.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import math
import os
from os import path as osp
from typing import Any, Dict, List, Literal, Optional, Union
from typing import Any, Dict, List, Literal, Optional, Tuple, Union

from scipy.stats.qmc import LatinHypercube

Expand Down Expand Up @@ -240,6 +240,29 @@ def get_next_trial(self) -> Optional[AshaTrial]:
return trial
return None

def get_inferior_trials(self, mode: Literal["min", "max"] = "max") -> List[AshaTrial]:
    """Get trials which are done but can no longer be promoted from this rung.

    Args:
        mode (Literal["min", "max"], optional): Whether a higher ("max") or
            lower ("min") score is better. Defaults to "max".

    Returns:
        List[AshaTrial]: Trials whose best score at this rung's resource is
            strictly worse than the worst score still eligible for promotion.
    """
    finished_trials = []
    rung_trials: List[Tuple[AshaTrial, Union[int, float]]] = []
    # Only the top 1 / reduction_factor of a full rung can be promoted.
    num_trials_to_promote = self._num_required_trial // self._reduction_factor
    if num_trials_to_promote <= 0:
        return []
    for trial in self._trials:
        # Skip trials still running at this rung; their final score is unknown.
        if trial.rung == self._rung_idx and not trial.is_done():
            continue
        # NOTE(review): assumes every non-skipped trial has a score recorded at
        # self.resource — confirm get_best_score can't return None here.
        rung_trials.append((trial, trial.get_best_score(mode, self.resource)))

    if len(rung_trials) <= num_trials_to_promote:
        return []

    # Sort so promotable trials come first; the score at the promotion boundary
    # is the criteria a trial must at least tie to stay alive.
    rung_trials = sorted(rung_trials, key=lambda x: x[1], reverse=mode == "max")
    criteria = rung_trials[num_trials_to_promote - 1][1]
    for trial, trial_score in rung_trials[num_trials_to_promote:]:
        # NOTE: "left_vlaue_is_better" is the helper's actual (typo'd) name in otx.hpo.utils.
        if left_vlaue_is_better(criteria, trial_score, mode):
            finished_trials.append(trial)

    return finished_trials


class Bracket:
"""Bracket class. It operates a single SHA using multiple rungs.
Expand All @@ -251,7 +274,8 @@ class Bracket:
hyper_parameter_configurations (List[AshaTrial]): Hyper parameter configuration to try.
reduction_factor (int): Decides how many trials to promote to the next rung.
Only top 1 / reduction_factor of rung trials can be promoted.
mode (str, optional): Decide which trial is better between having highest score or lowest score.
mode (Literal["min", "max"], optional):
Decide which trial is better between having highest score or lowest score.
Defaults to "max".
asynchronous_sha (bool, optional): Whether to operate SHA asynchronously. Defaults to True.
"""
Expand All @@ -265,7 +289,7 @@ def __init__(
maximum_resource: Union[float, int],
hyper_parameter_configurations: List[AshaTrial],
reduction_factor: int = 3,
mode: str = "max",
mode: Literal["min", "max"] = "max",
asynchronous_sha: bool = True,
):
# pylint: disable=too-many-arguments
Expand Down Expand Up @@ -484,6 +508,14 @@ def _get_result(self):
],
}

def get_inferior_trials(self) -> List[AshaTrial]:
    """Collect, from every rung except the last, trials which can no longer become the best trial."""
    return [
        trial
        for rung in self._rungs[:-1]
        for trial in rung.get_inferior_trials(self._mode)
    ]


class HyperBand(HpoBase):
"""It implements the Asyncronous HyperBand scheduler with iterations only.
Expand Down Expand Up @@ -978,3 +1010,11 @@ def print_result(self):
)
for bracket in self._brackets.values():
bracket.print_result()

def get_inferior_trials(self) -> List[AshaTrial]:  # type: ignore[override]
    """Gather from all brackets the trials which can no longer become the best trial."""
    return [
        trial
        for bracket in self._brackets.values()
        for trial in bracket.get_inferior_trials()
    ]
59 changes: 17 additions & 42 deletions tests/unit/cli/utils/test_hpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -464,8 +464,13 @@ def test_init_fix_batch_size(self, cls_task_env, diff_from_min_bs):
hpo_runner = HpoRunner(cls_task_env, train_dataset_size, 10, "fake_path")
assert batch_size_name in hpo_runner._fixed_hp

@pytest.fixture
def mock_thread(self, mocker) -> MagicMock:
    # Patch Thread in the hpo module so run_hpo doesn't spawn real background
    # threads (progress updater / weight remover) during the test.
    mock_thread = mocker.patch.object(hpo, "Thread")
    return mock_thread

@e2e_pytest_unit
def test_run_hpo(self, mocker, cls_task_env):
def test_run_hpo(self, mocker, cls_task_env, mock_thread):
cls_task_env.model = None
hpo_runner = HpoRunner(cls_task_env, 100, 10, "fake_path")
mock_run_hpo_loop = mocker.patch("otx.cli.utils.hpo.run_hpo_loop")
Expand All @@ -477,7 +482,7 @@ def test_run_hpo(self, mocker, cls_task_env):
mock_hb.assert_called() # make hyperband

@e2e_pytest_unit
def test_run_hpo_w_dataset_smaller_than_batch(self, mocker, cls_task_env):
def test_run_hpo_w_dataset_smaller_than_batch(self, mocker, cls_task_env, mock_thread):
cls_task_env.model = None
hpo_runner = HpoRunner(cls_task_env, 2, 10, "fake_path")
mock_run_hpo_loop = mocker.patch("otx.cli.utils.hpo.run_hpo_loop")
Expand All @@ -494,6 +499,8 @@ class TestTrainer:
def setup(self, tmp_dir):
    # Shared state for TestTrainer: weight file naming scheme, a fresh HPO
    # work directory, and the trial id used throughout the tests.
    self.weight_format = "epoch_{}.pth"
    self.hpo_workdir = Path(tmp_dir) / "hpo_dir"
    self.hpo_workdir.mkdir()
    self.trial_id = "1"

@pytest.fixture
def tmp_dir(self):
Expand All @@ -519,6 +526,8 @@ def mock_task(self, mocker, tmp_dir):
fake_project_path.mkdir(parents=True)
for i in range(1, 5):
(fake_project_path / self.weight_format.format(i)).write_text("fake")
with (self.hpo_workdir / f"{self.trial_id}.json").open("w") as f:
json.dump({"id": self.trial_id, "score": {"1": 1, "2": 2, "3": 5, "4": 4}}, f)

mock_get_train_task = mocker.patch.object(TaskEnvironmentManager, "get_train_task")
mock_task = mocker.MagicMock()
Expand Down Expand Up @@ -552,8 +561,12 @@ def test_run(self, mocker, cls_template_path, mock_task, tmp_dir):
# check
mock_report_func.assert_called_once_with(0, 0, done=True) # finilize report
assert self.hpo_workdir.exists() # make a directory to copy weight
for i in range(1, 5): # check model weights are copied
assert (self.hpo_workdir / "weight" / trial_id / self.weight_format.format(i)).exists()
assert (
self.hpo_workdir / "weight" / trial_id / self.weight_format.format(3)
).exists() # check best weight exists
assert (
self.hpo_workdir / "weight" / trial_id / self.weight_format.format(4)
).exists() # check last weight exists

mock_task.train.assert_called() # check task.train() is called

Expand Down Expand Up @@ -589,44 +602,6 @@ def test_run_trial_already_done(self, mocker, cls_template_path, mock_task, tmp_
mock_report_func.assert_called_once_with(0, 0, done=True) # finilize report
mock_task.train.assert_not_called() # check task.train() is called

@e2e_pytest_unit
def test_delete_unused_model_weight(self, mocker, cls_template_path):
    # prepare: trial "0" has a score report json, trial "1" does not, so only
    # trial 0's weight directory should be cleaned.
    trial0_weight_dir = self.hpo_workdir / "weight" / "0"
    mocker.patch(
        "otx.cli.utils.hpo.TaskManager.get_latest_weight", return_value=str(trial0_weight_dir / "latest.pth")
    )
    mocker.patch("otx.cli.utils.hpo.get_best_hpo_weight", return_value=str(trial0_weight_dir / "best.pth"))

    self.hpo_workdir.mkdir()
    (self.hpo_workdir / "0.json").touch()
    for i in range(2):
        weight_dir = self.hpo_workdir / "weight" / str(i)
        weight_dir.mkdir(parents=True)
        (weight_dir / "latest.pth").touch()
        (weight_dir / "best.pth").touch()
        (weight_dir / "unused.pth").touch()

    # run
    trainer = Trainer(
        hp_config={"configuration": {"iterations": 10}, "id": "1"},
        report_func=mocker.MagicMock(),
        model_template=find_and_parse_model_template(cls_template_path),
        data_roots=mocker.MagicMock(),
        task_type=TaskType.CLASSIFICATION,
        hpo_workdir=self.hpo_workdir,
        initial_weight_name="fake",
        metric="fake",
    )
    trainer._delete_unused_model_weight()

    # check: trial 0 keeps only the necessary weights; trial 1 is untouched.
    assert sorted([f.name for f in (self.hpo_workdir / "weight" / "0").iterdir()]) == sorted(
        ["latest.pth", "best.pth"]
    )
    assert sorted([f.name for f in (self.hpo_workdir / "weight" / "1").iterdir()]) == sorted(
        ["latest.pth", "best.pth", "unused.pth"]
    )


class TestHpoCallback:
@e2e_pytest_unit
Expand Down
Loading