Skip to content

Commit

Permalink
HPO code flow modification (#3259)
Browse files Browse the repository at this point in the history
* Update HPO documentation (#3235)

* refactor hpo code

* write draft hpo docs

* update test

* bugfix

* align with pre-commit

* fix link format in docs

* update integration test

* hpo_config only gets HpoConfig

* skip padim hpo

* add api to set metric name for HPO

* enable test_otx_hpo test

* align with pre-commit

* disable scheduler HPO

* change warning to raising an error

* fix way to find a dataset size

* update hpo integration test

* update e2e test

* if metric is loss, change hpo model to min

* pass checkpoint args when need to resume
  • Loading branch information
eunwoosh authored Apr 16, 2024
1 parent 255be7a commit e687cd1
Show file tree
Hide file tree
Showing 15 changed files with 219 additions and 147 deletions.
3 changes: 3 additions & 0 deletions src/otx/algo/classification/efficientnet_b0.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,9 @@ def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.model
"""Load the previous OTX ckpt according to OTX2.0."""
return OTXv1Helper.load_cls_effnet_b0_ckpt(state_dict, "multiclass", add_prefix)

def _reset_prediction_layer(self, num_classes: int) -> None:
return


class EfficientNetB0ForMultilabelCls(ExplainableMixInMMPretrainModel, MMPretrainMultilabelClsModel):
"""EfficientNetB0 Model for multi-label classification task."""
Expand Down
5 changes: 2 additions & 3 deletions src/otx/core/config/hpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,13 @@ class HpoConfig:
save_path: str | None = None
mode: Literal["max", "min"] = "max"
num_trials: int | None = None
num_workers: int = 1
num_workers: int = torch.cuda.device_count() if torch.cuda.is_available() else 1
expected_time_ratio: int | float | None = 4
maximum_resource: int | float | None = None
subset_ratio: float | int | None = None
min_subset_size: int = 500
prior_hyper_parameters: dict | list[dict] | None = None
acceptable_additional_time_ratio: float | int = 1.0
minimum_resource: int | float | None = None
reduction_factor: int = 3
asynchronous_bracket: bool = True
asynchronous_sha: bool = torch.cuda.device_count() != 1
metric_name: str | None = None
4 changes: 1 addition & 3 deletions src/otx/engine/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ def train(
resume: bool = False,
metric: MetricCallable | None = None,
run_hpo: bool = False,
hpo_config: HpoConfig | None = None,
hpo_config: HpoConfig = HpoConfig(), # noqa: B008 https://github.com/omni-us/jsonargparse/issues/423
checkpoint: PathLike | None = None,
**kwargs,
) -> dict[str, Any]:
Expand Down Expand Up @@ -241,8 +241,6 @@ def train(
checkpoint = checkpoint if checkpoint is not None else self.checkpoint

if run_hpo:
if hpo_config is None:
hpo_config = HpoConfig()
best_config, best_trial_weight = execute_hpo(engine=self, **locals())
if best_config is not None:
update_hyper_parameter(self, best_config)
Expand Down
80 changes: 57 additions & 23 deletions src/otx/engine/hpo/hpo_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,10 @@
from otx.utils.utils import get_decimal_point, get_using_dot_delimited_key, remove_matched_files

from .hpo_trial import run_hpo_trial
from .utils import find_trial_file, get_best_hpo_weight, get_hpo_weight_dir
from .utils import find_trial_file, get_best_hpo_weight, get_callable_args_name, get_hpo_weight_dir, get_metric

if TYPE_CHECKING:
from lightning import Callback
from lightning.pytorch.cli import OptimizerCallable

from otx.engine.engine import Engine
Expand All @@ -34,45 +35,51 @@

AVAILABLE_HP_NAME_MAP = {
"data.config.train_subset.batch_size": "datamodule.config.train_subset.batch_size",
"optimizer": "optimizer.keywords",
"scheduler": "scheduler.keywords",
"optimizer": "optimizer_callable.optimizer_kwargs",
# "scheduler": "scheduler.keywords", NOTE need to revisit after SchedulerCallableSupportHPO is implemented
}


def execute_hpo(
engine: Engine,
max_epochs: int,
hpo_config: HpoConfig | None = None,
hpo_config: HpoConfig,
progress_update_callback: Callable[[int | float], None] | None = None,
callbacks: list[Callback] | Callback | None = None,
**train_args,
) -> tuple[dict[str, Any] | None, Path | None]:
"""Execute HPO.
Args:
engine (Engine): engine instance.
max_epochs (int): max epochs to train.
hpo_config (HpoConfig | None, optional): Configuration for HPO.
hpo_config (HpoConfig): Configuration for HPO.
progress_update_callback (Callable[[int | float], None] | None, optional):
callback to update progress. If it's given, it's called with progress every second. Defaults to None.
callbacks (list[Callback] | Callback | None, optional): callbacks used during training. Defaults to None.
Returns:
tuple[dict[str, Any] | None, Path | None]:
best hyper parameters and model weight trained with best hyper parameters. If it doesn't exist,
return None.
"""
if engine.task == OTXTaskType.ZERO_SHOT_VISUAL_PROMPTING: # type: ignore[has-type]
logger.warning("Zero shot visual prompting task doesn't support HPO.")
return None, None
msg = "Zero shot visual prompting task doesn't support HPO."
raise RuntimeError(msg)
if "anomaly.padim" in str(type(engine.model)).lower():
msg = "Padim doesn't need HPO. HPO is skipped."
raise RuntimeError(msg)

engine.model.patch_optimizer_and_scheduler_for_hpo()

hpo_workdir = Path(engine.work_dir) / "hpo"
hpo_workdir.mkdir(exist_ok=True)
hpo_configurator = HPOConfigurator(
engine,
max_epochs,
hpo_workdir,
hpo_config,
engine=engine,
max_epochs=max_epochs,
hpo_config=hpo_config,
hpo_workdir=hpo_workdir,
callbacks=callbacks,
)
if (hpo_algo := hpo_configurator.get_hpo_algo()) is None:
logger.warning("HPO is skipped.")
Expand All @@ -88,9 +95,12 @@ def execute_hpo(
hpo_workdir=hpo_workdir,
engine=engine,
max_epochs=max_epochs,
callbacks=callbacks,
metric_name=hpo_config.metric_name,
**_adjust_train_args(train_args),
),
"gpu" if torch.cuda.is_available() else "cpu",
num_parallel_trial=hpo_configurator.hpo_config["num_workers"],
)

best_trial = hpo_algo.get_best_config()
Expand All @@ -113,21 +123,24 @@ class HPOConfigurator:
Args:
engine (Engine): engine instance.
max_epoch (int): max epochs to train.
max_epochs (int): max epochs to train.
hpo_config (HpoConfig): Configuration for HPO.
hpo_workdir (Path | None, optional): HPO work directory. Defaults to None.
hpo_config (HpoConfig | None, optional): Configuration for HPO.
callbacks (list[Callback] | Callback | None, optional): callbacks used during training. Defaults to None.
"""

def __init__(
self,
engine: Engine,
max_epoch: int,
max_epochs: int,
hpo_config: HpoConfig,
hpo_workdir: Path | None = None,
hpo_config: HpoConfig | None = None,
callbacks: list[Callback] | Callback | None = None,
) -> None:
self._engine = engine
self._max_epoch = max_epoch
self._max_epochs = max_epochs
self._hpo_workdir = hpo_workdir if hpo_workdir is not None else Path(engine.work_dir) / "hpo"
self._callbacks = callbacks
self.hpo_config: dict[str, Any] = hpo_config # type: ignore[assignment]

@property
Expand All @@ -136,19 +149,40 @@ def hpo_config(self) -> dict[str, Any]:
return self._hpo_config

@hpo_config.setter
def hpo_config(self, hpo_config: HpoConfig | None) -> None:
train_dataset_size = len(self._engine.datamodule.train_dataloader())
def hpo_config(self, hpo_config: HpoConfig) -> None:
train_dataset_size = len(
self._engine.datamodule.subsets[self._engine.datamodule.config.train_subset.subset_name],
)

if hpo_config.metric_name is None:
if self._callbacks is None:
msg = (
"HPOConfigurator can't find the metric because callback doesn't exist. "
"Please set hpo_config.metric_name."
)
raise RuntimeError(msg)
hpo_config.metric_name = get_metric(self._callbacks)

if "loss" in hpo_config.metric_name and hpo_config.mode == "max":
logger.warning(
f"Because metric for HPO is {hpo_config.metric_name}, hpo_config.mode is changed from max to min.",
)
hpo_config.mode = "min"

self._hpo_config: dict[str, Any] = { # default setting
"save_path": str(self._hpo_workdir),
"num_full_iterations": self._max_epoch,
"num_full_iterations": self._max_epochs,
"full_dataset_size": train_dataset_size,
}

if hpo_config is not None:
self._hpo_config.update(
{key: val for key, val in dataclasses.asdict(hpo_config).items() if val is not None},
)
hb_arg_names = get_callable_args_name(HyperBand)
self._hpo_config.update(
{
key: val
for key, val in dataclasses.asdict(hpo_config).items()
if val is not None and key in hb_arg_names
},
)

if "search_space" not in self._hpo_config:
self._hpo_config["search_space"] = self._get_default_search_space()
Expand Down
25 changes: 12 additions & 13 deletions src/otx/engine/hpo/hpo_trial.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from otx.hpo import TrialStatus
from otx.utils.utils import find_file_recursively, remove_matched_files, set_using_dot_delimited_key

from .utils import find_trial_file, get_best_hpo_weight, get_hpo_weight_dir
from .utils import find_trial_file, get_best_hpo_weight, get_hpo_weight_dir, get_metric

if TYPE_CHECKING:
from lightning import LightningModule, Trainer
Expand Down Expand Up @@ -51,6 +51,7 @@ def run_hpo_trial(
hpo_workdir: Path,
engine: Engine,
callbacks: list[Callback] | Callback | None = None,
metric_name: str | None = None,
**train_args,
) -> None:
"""Run HPO trial. After it's done, best weight and last weight are saved for later use.
Expand All @@ -61,6 +62,8 @@ def run_hpo_trial(
hpo_workdir (Path): HPO work directory.
engine (Engine): engine instance.
callbacks (list[Callback] | Callback | None, optional): callbacks used during training. Defaults to None.
metric_name (str | None, optional):
metric name to determine trial performance. If it's None, get it from ModelCheckpoint callback.
train_args: Arguments for 'engine.train'.
"""
trial_id = hp_config["id"]
Expand All @@ -69,10 +72,10 @@ def run_hpo_trial(
_set_trial_hyper_parameter(hp_config["configuration"], engine, train_args)

if (checkpoint := _find_last_weight(hpo_weight_dir)) is not None:
engine.checkpoint = checkpoint
train_args["checkpoint"] = checkpoint
train_args["resume"] = True

callbacks = _register_hpo_callback(report_func, callbacks)
callbacks = _register_hpo_callback(report_func, callbacks, metric_name)
_set_to_validate_every_epoch(callbacks, train_args)

with TemporaryDirectory(prefix="OTX-HPO-") as temp_dir:
Expand All @@ -93,23 +96,19 @@ def _find_last_weight(weight_dir: Path) -> Path | None:
return find_file_recursively(weight_dir, "last.ckpt")


def _register_hpo_callback(report_func: Callable, callbacks: list[Callback] | Callback | None) -> list[Callback]:
def _register_hpo_callback(
report_func: Callable,
callbacks: list[Callback] | Callback | None = None,
metric_name: str | None = None,
) -> list[Callback]:
if isinstance(callbacks, Callback):
callbacks = [callbacks]
elif callbacks is None:
callbacks = []
callbacks.append(HPOCallback(report_func, _get_metric(callbacks)))
callbacks.append(HPOCallback(report_func, get_metric(callbacks) if metric_name is None else metric_name))
return callbacks


def _get_metric(callbacks: list[Callback]) -> str:
for callback in callbacks:
if isinstance(callback, ModelCheckpoint):
return callback.monitor
error_msg = "Failed to find a metric. There is no ModelCheckpoint in callback list."
raise RuntimeError(error_msg)


def _set_to_validate_every_epoch(callbacks: list[Callback], train_args: dict[str, Any]) -> None:
for callback in callbacks:
if isinstance(callback, AdaptiveTrainScheduling):
Expand Down
41 changes: 40 additions & 1 deletion src/otx/engine/hpo/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,19 @@

from __future__ import annotations

import inspect
import json
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Callable

from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint

from otx.utils.utils import find_file_recursively

if TYPE_CHECKING:
from pathlib import Path

from lightning import Callback


def find_trial_file(hpo_workdir: Path, trial_id: str) -> Path | None:
"""Find a trial file which store trial record.
Expand Down Expand Up @@ -78,3 +83,37 @@ def get_hpo_weight_dir(hpo_workdir: Path, trial_id: str) -> Path:
if not hpo_weight_dir.exists():
hpo_weight_dir.mkdir(parents=True)
return hpo_weight_dir


def get_callable_args_name(module: Callable) -> list[str]:
    """Collect the parameter names declared in a callable's signature.

    Args:
        module (Callable): callable whose signature is inspected.

    Returns:
        list[str]: parameter names, in the order the signature lists them.
    """
    signature = inspect.signature(module)
    return list(signature.parameters.keys())


def get_metric(callbacks: list[Callback] | Callback) -> str:
    """Find a metric name from ModelCheckpoint callback.

    The first ModelCheckpoint that actually monitors a metric is used. A ModelCheckpoint whose
    ``monitor`` is None is skipped: it has no metric name to offer, and returning None would
    violate the ``str`` return contract and fail later with an unrelated error.

    Args:
        callbacks (list[Callback] | Callback): Callback list, or a single callback.

    Raises:
        RuntimeError: If ModelCheckpoint doesn't exist, or no ModelCheckpoint monitors a metric.

    Returns:
        str: metric name.
    """
    if not isinstance(callbacks, list):
        callbacks = [callbacks]

    found_checkpoint = False
    for callback in callbacks:
        if not isinstance(callback, ModelCheckpoint):
            continue
        found_checkpoint = True
        if callback.monitor is not None:
            return callback.monitor

    # Distinguish the two failure modes so the user knows what to fix.
    if found_checkpoint:
        msg = "Failed to find a metric. ModelCheckpoint in callback list doesn't monitor any metric."
    else:
        msg = "Failed to find a metric. There is no ModelCheckpoint in callback list."
    raise RuntimeError(msg)
13 changes: 0 additions & 13 deletions src/otx/hpo/hpo_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,6 @@ class HpoBase(ABC):
HPO use time about expected_time_ratio *
train time after HPO times.
maximum_resource (int | float | None, optional): Maximum resource to use for training each trial.
subset_ratio (float | int | None, optional): ratio to how many train dataset to use for each trial.
The lower value is, the faster the speed is.
But If it's too low, HPO can be unstable.
min_subset_size (int, optional) : Minimum size of subset. Default value is 500.
resume (bool, optional): resume flag decide to use previous HPO results.
If HPO completed, you can just use optimized hyper parameters.
If HPO stopped in middle, you can resume in middle.
Expand All @@ -66,8 +62,6 @@ def __init__(
full_dataset_size: int = 0,
expected_time_ratio: int | float | None = None,
maximum_resource: int | float | None = None,
subset_ratio: float | int | None = None,
min_subset_size: int = 500,
resume: bool = False,
prior_hyper_parameters: dict | list[dict] | None = None,
acceptable_additional_time_ratio: float | int = 1.0,
Expand All @@ -81,11 +75,6 @@ def __init__(
if num_trials is not None:
check_positive(num_trials, "num_trials")
check_positive(num_workers, "num_workers")
if subset_ratio is not None and not 0 < subset_ratio <= 1:
error_msg = (
f"subset_ratio should be greater than 0 and lesser than or equal to 1. Your value is {subset_ratio}"
)
raise ValueError(error_msg)

if save_path is None:
save_path = tempfile.mkdtemp(prefix="OTX-hpo-")
Expand All @@ -98,8 +87,6 @@ def __init__(
self.full_dataset_size = full_dataset_size
self.expected_time_ratio = expected_time_ratio
self.maximum_resource: int | float | None = maximum_resource
self.subset_ratio = subset_ratio
self.min_subset_size = min_subset_size
self.resume = resume
self.hpo_status: dict = {}
self.acceptable_additional_time_ratio = acceptable_additional_time_ratio
Expand Down
Loading

0 comments on commit e687cd1

Please sign in to comment.