diff --git a/src/otx/core/model/base.py b/src/otx/core/model/base.py index cda12a7321f..c2968323fb1 100644 --- a/src/otx/core/model/base.py +++ b/src/otx/core/model/base.py @@ -737,7 +737,7 @@ def lr_scheduler_step(self, scheduler: LRSchedulerTypeUnion, metric: Tensor) -> return super().lr_scheduler_step(scheduler=scheduler, metric=metric) def patch_optimizer_and_scheduler_for_hpo(self) -> None: - """Patch optimizer and scheduler for hyperparameter optimization. + """Patch optimizer and scheduler for hyperparameter optimization and adaptive batch size. This is inplace function changing inner states (`optimizer_callable` and `scheduler_callable`). Both will be changed to be picklable. In addition, `optimizer_callable` is changed diff --git a/src/otx/core/optimizer/callable.py b/src/otx/core/optimizer/callable.py index 2d8ea1177b5..ddb11d260ff 100644 --- a/src/otx/core/optimizer/callable.py +++ b/src/otx/core/optimizer/callable.py @@ -20,6 +20,9 @@ class OptimizerCallableSupportHPO: """Optimizer callable supports OTX hyper-parameter optimization (HPO) algorithm. + It makes OptimizerCallable picklable and makes its parameters accessible. + It is used for HPO and adaptive batch size. + Args: optimizer_cls: Optimizer class type or string class import path. See examples for details. optimizer_kwargs: Keyword arguments used for the initialization of the given `optimizer_cls`. diff --git a/src/otx/core/schedulers/callable.py b/src/otx/core/schedulers/callable.py index 5667d40a9f1..e854f1bbb24 100644 --- a/src/otx/core/schedulers/callable.py +++ b/src/otx/core/schedulers/callable.py @@ -23,6 +23,9 @@ class SchedulerCallableSupportHPO: """LR scheduler callable supports OTX hyper-parameter optimization (HPO) algorithm. + It makes SchedulerCallable picklable and makes its parameters accessible. + It is used for HPO and adaptive batch size. + Args: scheduler_cls: `LRScheduler` class type or string class import path. See examples for details. scheduler_kwargs: Keyword arguments used for the initialization of the given `scheduler_cls`.
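For context, both callables need to be picklable because every adaptive-batch-size trial (and every HPO trial) runs in a separately spawned process, and the search afterwards reads and rescales `optimizer_kwargs["lr"]` directly. A minimal sketch of the pattern these wrappers follow — the class below is illustrative only, not the actual OTX implementation:

```python
import pickle

import torch
from torch.optim import SGD


class PicklableOptimizerCallable:
    """Illustrative stand-in for OptimizerCallableSupportHPO.

    Storing the optimizer class and its kwargs as plain attributes (instead of a
    lambda/closure) keeps the callable picklable and leaves hyper-parameters such
    as `lr` readable and writable from the outside.
    """

    def __init__(self, optimizer_cls, **optimizer_kwargs):
        self.optimizer_cls = optimizer_cls
        self.optimizer_kwargs = optimizer_kwargs

    def __call__(self, params):
        return self.optimizer_cls(params, **self.optimizer_kwargs)


optimizer_callable = PicklableOptimizerCallable(SGD, lr=0.01, momentum=0.9)
optimizer_callable.optimizer_kwargs["lr"] *= 2.0  # HPO / adaptive BS can rescale the lr
pickle.dumps(optimizer_callable)  # pickles cleanly, so it survives a spawned subprocess
optimizer = optimizer_callable([torch.nn.Parameter(torch.zeros(1))])  # builds the real optimizer
```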
diff --git a/src/otx/engine/adaptive_bs/__init__.py b/src/otx/engine/adaptive_bs/__init__.py new file mode 100644 index 00000000000..67d70b7de47 --- /dev/null +++ b/src/otx/engine/adaptive_bs/__init__.py @@ -0,0 +1,8 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +"""API for adaptive batch size.""" + + +from .adaptive_bs_api import adapt_batch_size + +__all__ = ["adapt_batch_size"] diff --git a/src/otx/engine/adaptive_bs/adaptive_bs_api.py b/src/otx/engine/adaptive_bs/adaptive_bs_api.py new file mode 100644 index 00000000000..887d2cadf66 --- /dev/null +++ b/src/otx/engine/adaptive_bs/adaptive_bs_api.py @@ -0,0 +1,173 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +"""Algorithm to find a proper batch size that fits the current GPU device.""" + +from __future__ import annotations + +import logging +import os +from functools import partial +from math import sqrt +from typing import TYPE_CHECKING, Any + +from lightning import Callback +from lightning.pytorch.loggers.logger import DummyLogger +from torch.cuda import is_available as is_cuda_available + +from otx.core.types.task import OTXTaskType +from otx.utils.utils import is_xpu_available + +from .bs_search_algo import BsSearchAlgo + +if TYPE_CHECKING: + from lightning import LightningModule, Trainer + + from otx.engine.engine import Engine + +logger = logging.getLogger(__name__) + + +def adapt_batch_size( + engine: Engine, + not_increase: bool = True, + callbacks: list[Callback] | Callback | None = None, + **train_args, +) -> None: + """Change the actual batch size depending on the current GPU status. + + If not_increase is True, check whether the current batch size fits into GPU memory and, if not, decrease it. + If not_increase is False, increase the batch size to use most of the GPU memory. + + Args: + engine (Engine): Engine instance. + not_increase (bool): If True, only decrease the batch size when needed; never increase it beyond the default value. + callbacks (list[Callback] | Callback | None, optional): callbacks used during training. Defaults to None. + """ + if not (is_cuda_available() or is_xpu_available()): + msg = "Adaptive batch size supports CUDA or XPU." + raise RuntimeError(msg) + if engine.task == OTXTaskType.ZERO_SHOT_VISUAL_PROMPTING: # type: ignore[has-type] + msg = "Zero shot visual prompting task doesn't support adaptive batch size."
+ raise RuntimeError(msg) + + engine.model.patch_optimizer_and_scheduler_for_hpo() + default_bs = engine.datamodule.config.train_subset.batch_size + + if "ADAPTIVE_BS_FOR_DIST" in os.environ: # main process of distributed training already executes adapt_batch_size + new_batch_size = int(os.environ["ADAPTIVE_BS_FOR_DIST"]) + if default_bs != new_batch_size: + _apply_new_batch_size(engine, new_batch_size) + return + + train_func = partial(_train_model, engine=engine, callbacks=callbacks, **_adjust_train_args(train_args)) + bs_search_algo = BsSearchAlgo( + train_func=train_func, + default_bs=default_bs, + max_bs=( + len(engine.datamodule.subsets[engine.datamodule.config.train_subset.subset_name]) // engine.device.devices + ), + ) + if not_increase: + new_batch_size = bs_search_algo.auto_decrease_batch_size() + else: + new_batch_size = bs_search_algo.find_big_enough_batch_size() + + if engine.device.devices != 1: + os.environ["ADAPTIVE_BS_FOR_DIST"] = str(new_batch_size) + + if default_bs != new_batch_size: + origin_lr = engine.model.optimizer_callable.optimizer_kwargs["lr"] # type: ignore[attr-defined] + _apply_new_batch_size(engine, new_batch_size) + msg = ( + "Adapting batch size is done.\n" + f"Batch size is adapted: {default_bs} -> {new_batch_size}\n" + f"Learning rate is adapted: {origin_lr} -> {engine.model.optimizer_callable.optimizer_kwargs['lr']}" # type: ignore[attr-defined] + ) + logger.info(msg) + else: + logger.info("Adapting batch size is done. Batch size isn't changed.") + + +def _adjust_train_args(train_args: dict[str, Any]) -> dict[str, Any]: + train_args.update(train_args.pop("kwargs", {})) + train_args.pop("self", None) + train_args.pop("run_hpo", None) + train_args.pop("adaptive_bs") + return train_args + + +def _train_model(bs: int, engine: Engine, callbacks: list[Callback] | Callback | None = None, **train_args) -> None: + if bs <= 0: + msg = f"Batch size should be greater than 0, but {bs} is given." + raise ValueError(msg) + if engine.device.devices != 1: # TODO(Eunwoo): Need to change after device api is updated + engine._cache.update(devices=1) # noqa: SLF001 + + engine.datamodule.config.train_subset.batch_size = bs + engine.train(callbacks=_register_callback(callbacks), **train_args) + + +def _register_callback(callbacks: list[Callback] | Callback | None = None) -> list[Callback]: + if isinstance(callbacks, Callback): + callbacks = [callbacks] + elif callbacks is None: + callbacks = [] + callbacks.append(BatchSizeFinder()) + return callbacks + + +class BatchSizeFinder(Callback): + """This callback makes the trainer run a specified number of iterations and then exit. + + Args: + steps_per_trial: number of steps to run with a given batch size. + Ideally 1 should be enough to test if an OOM error occurs; however, in practice a few are needed. + """ + + def __init__( + self, + steps_per_trial: int = 3, + ) -> None: + self._steps_per_trial = steps_per_trial + + def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str | None = None) -> None: + """Check that the current stage is fit.""" + if stage != "fit": + msg = "Adaptive batch size supports only training." + raise RuntimeError(msg) + + def on_fit_start(self, trainer: Trainer, pl_module: LightningModule) -> None: + """Run steps_per_trial iterations and exit.""" + _scale_batch_reset_params(trainer, self._steps_per_trial) + _try_loop_run(trainer) + + +def _try_loop_run(trainer: Trainer) -> None: + loop = trainer._active_loop # noqa: SLF001 + if loop is None: + msg = "There is no active loop."
+ raise RuntimeError(msg) + loop.restarting = False + loop.run() + + +def _scale_batch_reset_params(trainer: Trainer, steps_per_trial: int) -> None: + trainer.logger = DummyLogger() if trainer.logger is not None else None + trainer.callbacks = [] + + loop = trainer._active_loop # noqa: SLF001 + if loop is None: + msg = "There is no active loop." + raise RuntimeError(msg) + trainer.limit_train_batches = 1.0 + if trainer.limit_val_batches != 0: + trainer.limit_val_batches = steps_per_trial + trainer.fit_loop.epoch_loop.max_steps = steps_per_trial + + +def _apply_new_batch_size(engine: Engine, new_batch_size: int) -> None: + origin_bs = engine.datamodule.config.train_subset.batch_size + if new_batch_size == origin_bs: + return + engine.datamodule.config.train_subset.batch_size = new_batch_size + engine.model.optimizer_callable.optimizer_kwargs["lr"] *= sqrt(new_batch_size / origin_bs) # type: ignore[attr-defined] diff --git a/src/otx/engine/adaptive_bs/bs_search_algo.py b/src/otx/engine/adaptive_bs/bs_search_algo.py new file mode 100644 index 00000000000..a029d10aa6d --- /dev/null +++ b/src/otx/engine/adaptive_bs/bs_search_algo.py @@ -0,0 +1,278 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +"""Algorithm to find a proper batch size that fits the current device.""" + +from __future__ import annotations + +import logging +import multiprocessing as mp +import queue +from typing import Any, Callable + +import torch + +from otx.utils.utils import is_xpu_available + +logger = logging.getLogger(__name__) + + +class BsSearchAlgo: + """Algorithm class to find optimal batch size. + + Args: + train_func (Callable[[int], Any]): Training function with a single argument to set the batch size. + default_bs (int): Default batch size. It should be bigger than 0. + max_bs (int): Maximum batch size. It should be bigger than 0. + """ + + def __init__( + self, + train_func: Callable[[int], Any], + default_bs: int, + max_bs: int, + ): + if default_bs <= 0: + msg = "Batch size should be bigger than 0." + raise ValueError(msg) + if max_bs <= 0: + msg = "train data set size should be bigger than 0." + raise ValueError(msg) + + if max_bs < default_bs: + default_bs = max_bs + + self._train_func = train_func + self._default_bs = default_bs + self._max_bs = max_bs + self._bs_try_history: dict[int, int] = {} + self._total_mem = _get_total_memory_size() + self._mem_lower_bound = 0.8 * self._total_mem + self._mem_upper_bound = 0.85 * self._total_mem + self._mp_ctx = mp.get_context("spawn") + + def _try_batch_size(self, bs: int) -> tuple[bool, int]: + trial_queue = self._mp_ctx.Queue() + proc = self._mp_ctx.Process(target=_run_trial, args=(self._train_func, bs, trial_queue)) + proc.start() + output = None + while proc.is_alive(): + try: + output = trial_queue.get(timeout=1) + break + except queue.Empty: + pass + proc.join() + if output is None: + msg = "There is no output from the trial for adaptive batch size."
+ raise RuntimeError(msg) + + oom = output["oom"] + max_memory_reserved = output["max_memory_reserved"] + + if not oom: + self._bs_try_history[bs] = max_memory_reserved + + logger.debug( + f"Adapting Batch size => bs : {bs}, OOM : {oom}, " + f"memory usage : {max_memory_reserved / self._total_mem}%", + ) + + return oom, max_memory_reserved + + @staticmethod + def _get_even_center_val(val1: int, val2: int) -> int: + ret = (val1 + val2) // 2 + if ret % 2 == 1: + ret += 1 + return ret + + def auto_decrease_batch_size(self) -> int: + """Decrease the batch size if the default batch size doesn't fit the current device. + + Returns: + int: A proper batch size, possibly decreased because the default value doesn't fit + """ + available_bs = 0 + current_bs = self._default_bs + lowest_unavailable_bs = self._default_bs + 2 + + while True: + oom, max_memory_reserved = self._try_batch_size(current_bs) + + # If memory usage is too close to limit, OOM can be raised during training + if oom or max_memory_reserved > self._mem_upper_bound: + if current_bs < lowest_unavailable_bs: + lowest_unavailable_bs = current_bs + current_bs = self._get_even_center_val(current_bs, available_bs) + else: + available_bs = current_bs + current_bs = self._get_even_center_val(current_bs, lowest_unavailable_bs) + + if lowest_unavailable_bs - available_bs <= 2: + break + + if available_bs == 0: + msg = "Current device can't train the model even with a batch size of 2." + raise RuntimeError(msg) + + return available_bs + + def find_big_enough_batch_size(self, drop_last: bool = False) -> int: + """Find a big enough batch size. + + This function finds a big enough batch size by running training with various batch sizes. + It estimates the next batch size to try with an equation fit to the training history. + The term "big enough" is used because it looks not for the maximum batch size but for a sufficiently large one + whose memory usage lies between the lower and upper bound. + + Args: + drop_last (bool): Whether to drop the last incomplete batch. + + Raises: + RuntimeError: If training can't run even with a batch size of 2, raise an error. + + Returns: + int: Big enough batch size. + """ + estimated_bs = self._default_bs + + # try default batch size + oom, bs_mem_usage = self._try_batch_size(estimated_bs) + if oom or bs_mem_usage > self._mem_upper_bound: + self._default_bs -= 2 + if self._default_bs <= 0: + msg = "Current device can't train the model even with a batch size of 2." + raise RuntimeError(msg) + + return self.auto_decrease_batch_size() + + # try default batch size + 2 + estimated_bs += 2 + if estimated_bs > self._max_bs: + return self._default_bs + oom, bs_mem_usage = self._try_batch_size(estimated_bs) + if oom or bs_mem_usage > self._mem_upper_bound: + return self._default_bs + + # estimate batch size using equation + estimation_pct = 0.82 + while True: + estimated_bs = self._estimate_batch_size(estimation_pct) + if estimated_bs in self._bs_try_history: + break + oom, mem_usage = self._try_batch_size(estimated_bs) + + if oom: + estimation_pct -= 0.1 + if estimation_pct <= 0: + estimated_bs = self._default_bs + 2 + break + elif self._mem_lower_bound <= mem_usage <= self._mem_upper_bound: + break + else: + estimation_pct = 0.82 + + if drop_last and (self._max_bs // 2 < estimated_bs < self._max_bs): + estimated_bs = self._max_bs // 2 + + return estimated_bs + + def _estimate_batch_size(self, estimation_pct: float) -> int: + if len(self._bs_try_history) < 2: + msg = "At least two trials should be done without OOM to estimate batch size."
+ raise RuntimeError(msg) + + def distance_from_bound(val: tuple[int, int | float]) -> float: + if val[1] < self._mem_lower_bound: + # if memory usage is same, then higher batch size is preferred + return self._mem_lower_bound - val[1] - val[0] / 10000 + if self._mem_upper_bound < val[1]: + # if memory usage is same, then lower batch size is preferred + return val[1] - self._mem_upper_bound + val[0] / 10000 + return min(abs(self._mem_lower_bound - val[1]), abs(val[1] - self._mem_upper_bound)) + + bs_arr = sorted(self._bs_try_history.items(), key=lambda x: x[0]) + for idx in range(len(bs_arr) - 1, -1, -1): + if bs_arr[idx][1] < self._mem_upper_bound: + cur_max_bs_idx = idx + break + else: + logger.warning("All tried batch sizes used more memory than the upper bound.") + return bs_arr[0][0] + + def check_bs_suitable(estimated_bs: int) -> bool: + # Check that the batch size is between the largest bs which uses less memory than the upper bound + # and the smallest bs which uses more memory than the upper bound. + if estimated_bs >= bs_arr[cur_max_bs_idx][0]: + if cur_max_bs_idx + 1 < len(bs_arr): + if estimated_bs < bs_arr[cur_max_bs_idx + 1][0]: + return True + else: + return True + return False + + x_idx, y_idx = 0, len(bs_arr) - 1 + + while x_idx < y_idx: + gradient = (bs_arr[y_idx][1] - bs_arr[x_idx][1]) / (bs_arr[y_idx][0] - bs_arr[x_idx][0]) + b = bs_arr[y_idx][1] - gradient * bs_arr[y_idx][0] + if gradient != 0: + estimated_bs = round(((self._total_mem * estimation_pct) - b) / (gradient * 2)) * 2 + if check_bs_suitable(estimated_bs): + break + + if distance_from_bound(bs_arr[x_idx + 1]) < distance_from_bound(bs_arr[y_idx - 1]): + x_idx += 1 + else: + y_idx -= 1 + + if x_idx == y_idx: + if check_bs_suitable(bs_arr[cur_max_bs_idx][0] + 2): + estimated_bs = bs_arr[cur_max_bs_idx][0] + 2 + else: + estimated_bs = bs_arr[cur_max_bs_idx][0] + + if estimated_bs > self._max_bs: + estimated_bs = self._max_bs + + return estimated_bs + + +def _run_trial(train_func: Callable[[int], Any], bs: int, trial_queue: mp.Queue) -> None: + mp.set_start_method(None, True) # reset mp start method + + oom = False + try: + train_func(bs) + except RuntimeError as e: + if str(e).startswith("CUDA out of memory.") or str(e).startswith( # CUDA OOM + "Allocation is out of device memory on current platform.", # XPU OOM + ): + oom = True + else: + raise + except AttributeError as e: + if str(e).startswith("'NoneType' object has no attribute 'best_model_path'"): + pass + else: + raise + + trial_queue.put( + { + "oom": oom, + "max_memory_reserved": _get_max_memory_reserved(), + }, + ) + + +def _get_max_memory_reserved() -> int: + if is_xpu_available(): + return torch.xpu.max_memory_reserved(device=None) + return torch.cuda.max_memory_reserved(device=None) + + +def _get_total_memory_size() -> int: + if is_xpu_available(): + return torch.xpu.get_device_properties(0).total_memory + _, total_mem = torch.cuda.mem_get_info() + return total_mem diff --git a/src/otx/engine/engine.py b/src/otx/engine/engine.py index 2ed0e539678..55bcaf2ad02 100644 --- a/src/otx/engine/engine.py +++ b/src/otx/engine/engine.py @@ -30,6 +30,7 @@ from otx.core.utils.cache import TrainerArgumentsCache from otx.utils.utils import is_xpu_available +from .adaptive_bs import adapt_batch_size from .hpo import execute_hpo, update_hyper_parameter from .utils.auto_configurator import DEFAULT_CONFIG_PER_TASK, AutoConfigurator @@ -183,6 +184,7 @@ def train( run_hpo: bool = False, hpo_config: HpoConfig = HpoConfig(), # noqa: B008 https://github.com/omni-us/jsonargparse/issues/423
checkpoint: PathLike | None = None, + adaptive_bs: Literal["None", "Safe", "Full"] = "None", **kwargs, ) -> dict[str, Any]: """Trains the model using the provided LightningModule and OTXDataModule. @@ -203,6 +205,9 @@ def train( run_hpo (bool, optional): If True, optimizer hyper parameters before training a model. hpo_config (HpoConfig | None, optional): Configuration for HPO. checkpoint (PathLike | None, optional): Path to the checkpoint file. Defaults to None. + adaptive_bs (Literal["None", "Safe", "Full"]): + Change the actual batch size depending on the current GPU status. + Safe => Prevent GPU out of memory. Full => Find a batch size using most of GPU memory. **kwargs: Additional keyword arguments for pl.Trainer configuration. Returns: @@ -240,6 +245,9 @@ def train( """ checkpoint = checkpoint if checkpoint is not None else self.checkpoint + if adaptive_bs != "None": + adapt_batch_size(engine=self, **locals(), not_increase=(adaptive_bs != "Full")) + if run_hpo: best_config, best_trial_weight = execute_hpo(engine=self, **locals()) if best_config is not None: diff --git a/src/otx/engine/hpo/hpo_api.py b/src/otx/engine/hpo/hpo_api.py index c974e19b998..46131dfdb3f 100644 --- a/src/otx/engine/hpo/hpo_api.py +++ b/src/otx/engine/hpo/hpo_api.py @@ -67,7 +67,7 @@ def execute_hpo( msg = "Zero shot visual prompting task doesn't support HPO." raise RuntimeError(msg) if "anomaly.padim" in str(type(engine.model)).lower(): - msg = "Padim doesn't need HPO. HPO is skipped." + msg = "Padim doesn't need HPO." raise RuntimeError(msg) engine.model.patch_optimizer_and_scheduler_for_hpo() @@ -81,6 +81,13 @@ def execute_hpo( hpo_workdir=hpo_workdir, callbacks=callbacks, ) + if ( + train_args.get("adaptive_bs", None) == "Full" + and "datamodule.config.train_subset.batch_size" in hpo_configurator.hpo_config["search_space"] + ): + logger.info("Because adaptive_bs is set as Full, batch size is excluded from HPO.") + hpo_configurator.hpo_config["search_space"].pop("datamodule.config.train_subset.batch_size") + if (hpo_algo := hpo_configurator.get_hpo_algo()) is None: logger.warning("HPO is skipped.") return None, None @@ -288,6 +295,7 @@ def _adjust_train_args(train_args: dict[str, Any]) -> dict[str, Any]: train_args.update(train_args.pop("kwargs", {})) train_args.pop("self", None) train_args.pop("run_hpo", None) + train_args.pop("adaptive_bs", None) return train_args diff --git a/tests/conftest.py b/tests/conftest.py index 0e4dcaaa958..121b2102d52 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -368,3 +368,19 @@ def fxt_hlabel_multilabel_info() -> HLabelInfo: ["Spade_King", "Spade"], ], ) + + +@pytest.fixture() +def fxt_xpu_support_task() -> list[OTXTaskType]: + return [ + OTXTaskType.ANOMALY_CLASSIFICATION, + OTXTaskType.ANOMALY_DETECTION, + OTXTaskType.ANOMALY_SEGMENTATION, + OTXTaskType.MULTI_CLASS_CLS, + OTXTaskType.MULTI_LABEL_CLS, + OTXTaskType.H_LABEL_CLS, + OTXTaskType.DETECTION, + OTXTaskType.ROTATED_DETECTION, + OTXTaskType.DETECTION_SEMI_SL, + OTXTaskType.SEMANTIC_SEGMENTATION, + ] diff --git a/tests/integration/cli/test_cli.py b/tests/integration/cli/test_cli.py index 07acab85b16..cf65fa0922d 100644 --- a/tests/integration/cli/test_cli.py +++ b/tests/integration/cli/test_cli.py @@ -1,6 +1,7 @@ # Copyright (C) 2023 Intel Corporation # SPDX-License-Identifier: Apache-2.0 +from __future__ import annotations from pathlib import Path @@ -478,3 +479,59 @@ def test_otx_hpo_e2e( return assert len([val for val in hpo_work_dor.rglob("*.json") if str(val.stem).isdigit()]) == 2 + + 
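At the user level, the whole feature reduces to one new argument. A usage sketch before the new e2e test below — note that the `Engine` constructor arguments here are placeholders, not taken from this diff:

```python
from otx.engine import Engine

# Hypothetical setup; data_root and task values are placeholders.
engine = Engine(data_root="data/my_dataset", task="MULTI_CLASS_CLS")

# "Safe": keep the configured batch size unless it would not fit, then shrink it.
engine.train(adaptive_bs="Safe")

# "Full": grow the batch size until roughly 80-85% of device memory is used.
# Combined with HPO, batch size is then removed from the HPO search space.
engine.train(adaptive_bs="Full", run_hpo=True)
```

The equivalent CLI switch is `--adaptive_bs Safe|Full`, which is what the new integration test below exercises.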
+@pytest.mark.parametrize("task", pytest.TASK_LIST) +@pytest.mark.parametrize("bs_adapt_type", ["Safe", "Full"]) +def test_otx_adaptive_bs_e2e( + task: OTXTaskType, + tmp_path: Path, + fxt_accelerator: str, + fxt_target_dataset_per_task: dict, + fxt_cli_override_command_per_task: dict, + fxt_open_subprocess: bool, + fxt_xpu_support_task: list[OTXTaskType], + bs_adapt_type: str, +) -> None: + """ + Test adaptive batch size e2e commands with default template of each task. + + Args: + task (OTXTaskType): The task to run adaptive batch size with. + tmp_path (Path): The temporary path for storing the training outputs. + + Returns: + None + """ + if fxt_accelerator not in ["gpu", "xpu"]: + pytest.skip("Adaptive batch size only supports GPU and XPU.") + if fxt_accelerator == "xpu" and task not in fxt_xpu_support_task: + pytest.skip(f"{task} doesn't support XPU.") + if task not in DEFAULT_CONFIG_PER_TASK: + pytest.skip(f"Task {task} is not supported in the auto-configuration.") + if task == OTXTaskType.ZERO_SHOT_VISUAL_PROMPTING: + pytest.skip("ZERO_SHOT_VISUAL_PROMPTING doesn't support adaptive batch size.") + + task = task.lower() + tmp_path_adap_bs = tmp_path / f"otx_adaptive_bs_{task}" + tmp_path_adap_bs.mkdir(parents=True) + + command_cfg = [ + "otx", + "train", + "--task", + task.upper(), + "--data_root", + fxt_target_dataset_per_task[task], + "--work_dir", + str(tmp_path_adap_bs), + "--engine.device", + fxt_accelerator, + "--adaptive_bs", + bs_adapt_type, + "--max_epoch", + "1", + *fxt_cli_override_command_per_task[task], + ] + + run_main(command_cfg=command_cfg, open_subprocess=fxt_open_subprocess) diff --git a/tests/unit/engine/adaptive_bs/__init__.py b/tests/unit/engine/adaptive_bs/__init__.py new file mode 100644 index 00000000000..916f3a44b27 --- /dev/null +++ b/tests/unit/engine/adaptive_bs/__init__.py @@ -0,0 +1,2 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 diff --git a/tests/unit/engine/adaptive_bs/test_bs_search_algo.py b/tests/unit/engine/adaptive_bs/test_bs_search_algo.py new file mode 100644 index 00000000000..e1e28bbe0cd --- /dev/null +++ b/tests/unit/engine/adaptive_bs/test_bs_search_algo.py @@ -0,0 +1,166 @@ +from unittest.mock import MagicMock + +import pytest +from otx.engine.adaptive_bs import bs_search_algo as target_file +from otx.engine.adaptive_bs.bs_search_algo import BsSearchAlgo + + +class TestBsSearchAlgo: + @pytest.fixture(autouse=True) + def setup_test(self, mocker): + self.mock_torch = mocker.patch.object(target_file, "torch") + self.mock_torch.cuda.mem_get_info.return_value = (1, 10000) + self.mock_mp = mocker.patch.object(target_file, "mp") + mocker.patch.object(target_file, "is_xpu_available", return_value=False) + + def test_init(self, mocker): + BsSearchAlgo(mocker.MagicMock(), 4, 10) + + @pytest.mark.parametrize("default_bs", [-2, 0]) + def test_init_w_wrong_default_bs(self, mocker, default_bs): + with pytest.raises(ValueError, match="Batch size should be bigger than 0."): + BsSearchAlgo(mocker.MagicMock(), default_bs=default_bs, max_bs=10) + + @pytest.mark.parametrize("max_bs", [-2, 0]) + def test_init_w_wrong_max_bs(self, mocker, max_bs): + with pytest.raises(ValueError, match="train data set size should be bigger than 0."): + BsSearchAlgo(mocker.MagicMock(), default_bs=4, max_bs=max_bs) + + def set_mp_process(self, train_func): + def mock_process(target, args) -> MagicMock: # noqa: ARG001 + batch_size = args[-2] + oom = False + mem_usage = 0 + + try: + mem_usage = train_func(batch_size) + except RuntimeError: + oom 
= True + + trial_queue = args[-1] + trial_queue.get.return_value = {"oom": oom, "max_memory_reserved": mem_usage} + + return MagicMock() + + self.mock_mp.get_context.return_value.Process.side_effect = mock_process + + def get_mock_train_func(self, cuda_oom_bound: int, max_runnable_bs: int): + def mock_train_func(batch_size) -> int: + if batch_size > cuda_oom_bound: + mem_usage = 10000 + msg = "CUDA out of memory." + raise RuntimeError(msg) + if batch_size > max_runnable_bs: + mem_usage = 8500 + 1500 * batch_size / (cuda_oom_bound - max_runnable_bs) + else: + mem_usage = 8500 * batch_size / max_runnable_bs + + self.mock_torch.cuda.max_memory_reserved.return_value = mem_usage + return mem_usage + + self.set_mp_process(mock_train_func) + + return mock_train_func + + def test_try_batch_size(self): + mock_train_func = self.get_mock_train_func(cuda_oom_bound=10000, max_runnable_bs=80) + bs_search_algo = BsSearchAlgo(mock_train_func, 128, 1000) + batch_size = 40 + + cuda_oom, max_memory_reserved = bs_search_algo._try_batch_size(batch_size) + + assert cuda_oom is False + assert max_memory_reserved == mock_train_func(batch_size) + + def test_try_batch_size_cuda_oom(self): + mock_train_func = self.get_mock_train_func(cuda_oom_bound=100, max_runnable_bs=80) + bs_search_algo = BsSearchAlgo(mock_train_func, 128, 1000) + batch_size = 200 + + cuda_oom, _ = bs_search_algo._try_batch_size(batch_size) + + assert cuda_oom is True + + def test_auto_decrease_batch_size(self): + mock_train_func = self.get_mock_train_func(cuda_oom_bound=10000, max_runnable_bs=80) + + bs_search_algo = BsSearchAlgo(mock_train_func, 128, 1000) + adapted_bs = bs_search_algo.auto_decrease_batch_size() + + assert adapted_bs == 80 + + def test_find_max_usable_bs_gpu_memory_too_small(self): + mock_train_func = self.get_mock_train_func(cuda_oom_bound=4, max_runnable_bs=1) + + bs_search_algo = BsSearchAlgo(mock_train_func, 128, 1000) + with pytest.raises(RuntimeError): + bs_search_algo.auto_decrease_batch_size() + + @pytest.mark.parametrize( + ("max_runnable_bs", "max_bs", "expected_bs"), + [ + (100, 1000, None), + (32, 1000, None), + (100, 64, 64), + (66, 1000, None), + ], + ) + def test_find_big_enough_batch_size(self, max_runnable_bs, max_bs, expected_bs): + mock_train_func = self.get_mock_train_func(cuda_oom_bound=10000, max_runnable_bs=max_runnable_bs) + + bs_search_algo = BsSearchAlgo(mock_train_func, 64, max_bs) + adapted_bs = bs_search_algo.find_big_enough_batch_size() + + if expected_bs is None: + assert 7500 <= mock_train_func(adapted_bs) <= 8500 + else: + assert adapted_bs == expected_bs + + def test_find_big_enough_batch_size_gpu_memory_too_small(self): + mock_train_func = self.get_mock_train_func(cuda_oom_bound=4, max_runnable_bs=1) + + bs_search_algo = BsSearchAlgo(mock_train_func, 128, 1000) + with pytest.raises(RuntimeError): + bs_search_algo.find_big_enough_batch_size() + + def test_find_big_enough_batch_size_gradient_zero(self): + def mock_train_func(batch_size) -> int: + if batch_size > 1000: + mem_usage = 10000 + msg = "CUDA out of memory." 
+ raise RuntimeError(msg) + mem_usage = 9000 if batch_size > 100 else 1000 + self.mock_torch.cuda.max_memory_reserved.return_value = mem_usage + return mem_usage + + self.set_mp_process(mock_train_func) + + bs_search_algo = BsSearchAlgo(mock_train_func, 64, 1000) + adapted_bs = bs_search_algo.find_big_enough_batch_size() + + assert adapted_bs == 100 + + def test_find_big_enough_batch_size_not_exceed_upper_bound(self): + def mock_train_func(batch_size) -> int: + if batch_size > 1000: + mem_usage = 10000 + msg = "CUDA out of memory." + raise RuntimeError(msg) + mem_usage = 9000 if batch_size > 100 else 1000 + batch_size / 1000 + self.mock_torch.cuda.max_memory_reserved.return_value = mem_usage + return mem_usage + + self.set_mp_process(mock_train_func) + + bs_search_algo = BsSearchAlgo(mock_train_func, 64, 1000) + adapted_bs = bs_search_algo.find_big_enough_batch_size() + + assert mock_train_func(adapted_bs) <= 8500 + + def test_find_big_enough_batch_size_drop_last(self): + mock_train_func = self.get_mock_train_func(cuda_oom_bound=10000, max_runnable_bs=180) + + bs_search_algo = BsSearchAlgo(mock_train_func, 64, 200) + adapted_bs = bs_search_algo.find_big_enough_batch_size(True) + + assert adapted_bs == 100
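A few condensed sketches of the search logic above, for reviewers who don't want to trace the full implementation. First, the decrease-only search (`auto_decrease_batch_size`) is a bisection over even batch sizes between the largest known-good and the smallest known-bad value; this standalone rewrite is simplified (no memory-usage bookkeeping) but follows the same control flow:

```python
from typing import Callable


def auto_decrease(try_bs: Callable[[int], bool], default_bs: int) -> int:
    """Bisection sketch of BsSearchAlgo.auto_decrease_batch_size.

    try_bs(bs) should return True when `bs` trains without OOM and stays below the
    memory upper bound. Batch sizes are kept even, and the search stops once the gap
    between the largest known-good and the smallest known-bad batch size is <= 2.
    """

    def even_mid(val1: int, val2: int) -> int:
        mid = (val1 + val2) // 2
        return mid + 1 if mid % 2 else mid

    good, bad, current = 0, default_bs + 2, default_bs
    while True:
        if try_bs(current):
            good = current
            current = even_mid(current, bad)
        else:
            bad = min(bad, current)
            current = even_mid(current, good)
        if bad - good <= 2:
            break
    if good == 0:
        raise RuntimeError("Even a batch size of 2 does not fit on this device.")
    return good


# e.g. a device that fits at most 80 samples per batch, starting from a default of 128:
print(auto_decrease(lambda bs: bs <= 80, default_bs=128))  # -> 80
```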
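Second, the "Full" mode estimates the next batch size to try by fitting a straight line through measured (batch size, memory) points and solving for ~82% of total memory. A simplified, self-contained version of that equation (the real `_estimate_batch_size` additionally scans the trial history and clamps the result between known-good and known-bad batch sizes):

```python
def estimate_bs_from_two_trials(bs_mem: dict[int, float], total_mem: float, target_pct: float = 0.82) -> int:
    """Fit mem(bs) = gradient * bs + intercept through the smallest and largest tried
    batch sizes, then solve for the even batch size whose predicted memory usage is
    target_pct of the total device memory."""
    trials = sorted(bs_mem.items())
    (bs_lo, mem_lo), (bs_hi, mem_hi) = trials[0], trials[-1]
    gradient = (mem_hi - mem_lo) / (bs_hi - bs_lo)
    intercept = mem_hi - gradient * bs_hi
    if gradient == 0:  # flat history gives no signal; keep the largest tried value
        return bs_hi
    return round((target_pct * total_mem - intercept) / (gradient * 2)) * 2


# 32 -> 4 GB and 64 -> 7 GB on a 10 GB device suggests trying a batch size of 76 next.
print(estimate_bs_from_two_trials({32: 4.0e9, 64: 7.0e9}, total_mem=10.0e9))  # -> 76
```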
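Finally, when the batch size does change, `_apply_new_batch_size` rescales the learning rate with the square-root scaling heuristic; in numbers:

```python
from math import sqrt


def rescale_lr(origin_lr: float, origin_bs: int, new_bs: int) -> float:
    """Square-root LR scaling as used in _apply_new_batch_size: lr *= sqrt(new_bs / origin_bs)."""
    return origin_lr * sqrt(new_bs / origin_bs)


# Quadrupling the batch size doubles the learning rate; halving it scales the lr by ~0.71.
assert abs(rescale_lr(0.01, origin_bs=16, new_bs=64) - 0.02) < 1e-9
assert abs(rescale_lr(0.01, origin_bs=64, new_bs=32) - 0.00707) < 1e-4
```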