From 6826313227f3550d8050fbec555fddb98196d2b5 Mon Sep 17 00:00:00 2001 From: "Shin, Eunwoo" Date: Wed, 17 Apr 2024 15:19:20 +0900 Subject: [PATCH] revert autobs --- .../action/adapters/mmaction/task.py | 10 +- .../classification/adapters/mmcls/task.py | 19 +- .../adapters/mmcv/utils/automatic_bs.py | 264 ++++-------------- .../adapters/torch/utils/bs_search_algo.py | 204 ++++++-------- .../detection/adapters/mmdet/task.py | 18 +- .../segmentation/adapters/mmseg/task.py | 19 +- .../adapters/mmcv/utils/test_automatic_bs.py | 41 +-- .../torch/utils/test_bs_search_algo.py | 177 ++++++++---- 8 files changed, 293 insertions(+), 459 deletions(-) diff --git a/src/otx/algorithms/action/adapters/mmaction/task.py b/src/otx/algorithms/action/adapters/mmaction/task.py index 1ba1a55f29e..66a5e981e35 100644 --- a/src/otx/algorithms/action/adapters/mmaction/task.py +++ b/src/otx/algorithms/action/adapters/mmaction/task.py @@ -19,6 +19,7 @@ import time from contextlib import nullcontext from copy import deepcopy +from functools import partial from typing import Dict, Optional, Union import torch @@ -323,13 +324,12 @@ def _train_model( validate = bool(cfg.data.get("val", None)) if self._hyperparams.learning_parameters.auto_adapt_batch_size != BatchSizeAdaptType.NONE: + train_func = partial(train_model, meta=deepcopy(meta), model=deepcopy(model), distributed=False) adapt_batch_size( - train_model, - model, - datasets, + train_func, cfg, - cfg.distributed, - meta=meta, + datasets, + validate, not_increase=(self._hyperparams.learning_parameters.auto_adapt_batch_size == BatchSizeAdaptType.SAFE), ) diff --git a/src/otx/algorithms/classification/adapters/mmcls/task.py b/src/otx/algorithms/classification/adapters/mmcls/task.py index f7ea06d6c13..9ae0b5721e8 100644 --- a/src/otx/algorithms/classification/adapters/mmcls/task.py +++ b/src/otx/algorithms/classification/adapters/mmcls/task.py @@ -8,6 +8,7 @@ import time from contextlib import nullcontext from copy import deepcopy +from functools import partial from typing import Any, Dict, Optional, Type, Union import torch @@ -379,6 +380,9 @@ def _train_model( htcore.hpu.ModuleCacher(max_graphs=10)(model=model.backbone, inplace=True) htcore.hpu.ModuleCacher(max_graphs=10)(model=model.head, inplace=True) + if cfg.distributed: + convert_sync_batchnorm(model) + validate = bool(cfg.data.get("val", None)) if validate: val_dataset = build_dataset(cfg.data.val, dict(test_mode=True)) @@ -406,22 +410,15 @@ def _train_model( ) if self._hyperparams.learning_parameters.auto_adapt_batch_size != BatchSizeAdaptType.NONE: - is_nncf = isinstance(self, NNCFBaseTask) + train_func = partial(train_model, meta=deepcopy(meta), model=deepcopy(model), distributed=False) adapt_batch_size( - train_model, - model, - datasets, + train_func, cfg, - cfg.distributed, - is_nncf, - meta=meta, + datasets, + isinstance(self, NNCFBaseTask), # nncf needs eval hooks not_increase=(self._hyperparams.learning_parameters.auto_adapt_batch_size == BatchSizeAdaptType.SAFE), - model_builder=getattr(self, "model_builder") if is_nncf else None, ) - if cfg.distributed: - convert_sync_batchnorm(model) - train_model( model, datasets, diff --git a/src/otx/algorithms/common/adapters/mmcv/utils/automatic_bs.py b/src/otx/algorithms/common/adapters/mmcv/utils/automatic_bs.py index 016c3da7f50..cfc4b6eb07d 100644 --- a/src/otx/algorithms/common/adapters/mmcv/utils/automatic_bs.py +++ b/src/otx/algorithms/common/adapters/mmcv/utils/automatic_bs.py @@ -3,26 +3,14 @@ # Copyright (C) 2023 Intel Corporation # 
SPDX-License-Identifier: Apache-2.0 -import inspect -from copy import copy -from importlib import import_module +from copy import deepcopy from math import sqrt -from pathlib import Path -from tempfile import TemporaryDirectory -from typing import Any, Callable, Dict, List, Optional +from typing import Callable, Dict, List import numpy as np -import torch -from mmcv import Config -from mmcv.runner import wrap_fp16_model -from torch import distributed as dist from torch.cuda import is_available as cuda_available -from torch.utils.data import Dataset -from otx.algorithms.common.adapters.mmcv.utils.config_utils import OTXConfig -from otx.algorithms.common.adapters.torch.utils import BsSearchAlgo, sync_batchnorm_2_batchnorm -from otx.algorithms.common.utils import is_xpu_available -from otx.core.data import caching +from otx.algorithms.common.adapters.torch.utils import BsSearchAlgo from otx.utils.logger import get_logger logger = get_logger() @@ -50,142 +38,7 @@ def _set_value_at_dict_in_dict(target: Dict, key_path: str, value): target[keys[-1]] = value -def _build_model(model_builder: Callable, cfg: Config) -> torch.nn.Module: - model = model_builder(cfg) - if cfg.get("fp16", False): - wrap_fp16_model(model) - return model - - -NNCF_PATCH_MODULE = { - "mmcls": "otx.algorithms.classification.adapters.mmcls.nncf.patches", - "mmdet": "otx.algorithms.detection.adapters.mmdet.nncf.patches", - "mmseg": "otx.algorithms.segmentation.adapters.mmseg.nncf.patches", -} - - -def _train_func_single_iter( - batch_size: int, - train_func: Callable, - datasets: List[Dataset], - cfg: OTXConfig, - is_nncf: bool = False, - meta: Optional[Dict[str, Any]] = None, - model: Optional[torch.nn.Module] = None, - model_builder: Optional[Callable] = None, -) -> None: - caching.MemCacheHandlerSingleton.create("null", 0) # initialize mem cache - _set_batch_size(cfg, batch_size) - _set_max_epoch(cfg, 1) # setup for training a single iter to save time - - new_dataset = [SubDataset(datasets[0], batch_size)] - - validate = is_nncf # nncf needs eval hooks - if is_nncf: - pkg_name = inspect.getmodule(train_func).__package__ - for framework in ["mmcls", "mmdet", "mmseg"]: - if framework in pkg_name: - import_module(NNCF_PATCH_MODULE[framework]) - break - else: - framework = None - - if framework == "mmcls": - validate = False # classification task has own custom eval hook - - if model is None: - model = _build_model(model_builder, cfg) - - if is_nncf: - model.nncf._uncompressed_model_accuracy = 0 - - sync_batchnorm_2_batchnorm(model) - - train_func( - model=model, - dataset=new_dataset, - cfg=cfg, - distributed=False, - validate=validate, - meta=meta, - ) - - -def _save_nncf_model_weight(model: torch.nn.Module, cfg: OTXConfig, save_path: Path) -> str: - """Save nncf model weight after nncf finishes to build a model. - - NNCF analyzes and get some statistics when buliding a model, which is time consuming. - To skip this part, nncf model weight is saved and load it on new process. - """ - from otx.algorithms.common.adapters.nncf.compression import NNCFMetaState - - file_path = save_path / "nncf_model.pth" - for custom_hook in cfg.custom_hooks: - if custom_hook["type"] == "CompressionHook": - compression_ctrl = custom_hook["compression_ctrl"].get_compression_state() - break - else: - msg = "CompressionHook doesn't exist in custom hooks." 
- raise RuntimeError(msg) - - torch.save( - { - "state_dict": model.state_dict(), - "meta": { - "nncf_meta": NNCFMetaState( - state_to_build=cfg.runner.nncf_meta.state_to_build, - data_to_build=cfg.runner.nncf_meta.data_to_build, - compression_ctrl=compression_ctrl, - ), - "nncf_enable_compression": True, - }, - }, - file_path, - ) - - return str(file_path) - - -def _organize_custom_hooks(custom_hooks: List, is_nncf: bool = False) -> None: - # Remove hooks due to reasons below - # for nncf task - # OTXProgressHook and CompressionHook are added when building a model. Need to remove them to avoid duplication. - # for normal task - # OTXProgressHook => prevent progress bar from being 0 and 100 repeatably - # CancelInterfaceHook => avoid segmentation fault - # earlystoppinghook => if eval hook is excluded, this hook makes an error due to absence of score history - # CustomEvalHook => exclude validation in classification task - - if is_nncf: - hooks_to_remove = ["OTXProgressHook", "CompressionHook"] - else: - hooks_to_remove = ["OTXProgressHook", "earlystoppinghook", "CustomEvalHook", "CancelInterfaceHook"] - - idx_hooks_to_remove = [] - for i, hook in enumerate(custom_hooks): - if not is_nncf and hook["type"] == "AdaptiveTrainSchedulingHook": - hook["enable_eval_before_run"] = False - for hook_to_remove in hooks_to_remove: - if hook_to_remove.lower() in hook["type"].lower(): - idx_hooks_to_remove.append(i) - - if idx_hooks_to_remove: - idx_hooks_to_remove.sort() - for i in reversed(idx_hooks_to_remove): - custom_hooks.pop(i) - - -def adapt_batch_size( - train_func: Callable, - model: torch.nn.Module, - datasets: List[Dataset], - cfg: OTXConfig, - distributed: bool = False, - is_nncf: bool = False, - meta: Optional[Dict[str, Any]] = None, - not_increase: bool = True, - model_builder: Optional[Callable] = None, -) -> None: +def adapt_batch_size(train_func: Callable, cfg, datasets: List, validate: bool = False, not_increase: bool = True): """Decrease batch size if default batch size isn't fit to current GPU device. This function just setup for single iteration training to reduce time for adapting. @@ -194,69 +47,59 @@ def adapt_batch_size( Args: train_func (Callable): The function to train a model. Only cfg, dataset and meta are passed to the function when invoking it. - model (torch.nn.Module): Model to train. + cfg: Configuration of a training. + meta (Dict): A dict records some meta information of a training. datasets (List): List of datasets. - cfg (OTXConfig): Configuration of a training. - distributed (bool): whether distributed training or not. - is_nncf (bool): Whether nncf or not. - meta (Optional[Dict[str, Any]]): meta information. + validate (bool): Whether do vlidation or not. not_increase (bool) : Whether adapting batch size to larger value than default value or not. - model_builder (Optional[Callable]): - Function for building a model. If it exsits, a model build from model_builder is used instead of the model - in the argument. It's required for nncf because nncf changes model , which prevent model from pickling. 
""" - if not (cuda_available() or is_xpu_available()): - logger.warning("Skip Auto-adaptive batch size: Adaptive batch size supports CUDA or XPU.") + if not cuda_available(): + logger.warning("Skip Auto-adaptive batch size: CUDA should be available, but it isn't.") return - copied_cfg = copy(cfg) - copied_cfg.custom_hooks = copy(cfg.custom_hooks) - copied_cfg.pop("algo_backend", None) - - if is_nncf: - if model_builder is None: - msg = "model_builder should be possed for building a nncf model." - raise RuntimeError(msg) - temp_dir = TemporaryDirectory("adaptive-bs") - copied_cfg.load_from = _save_nncf_model_weight(model, cfg, Path(temp_dir.name)) - - _organize_custom_hooks(copied_cfg.custom_hooks, is_nncf) - - default_bs = _get_batch_size(cfg) - if not distributed or (rank := dist.get_rank()) == 0: - train_func_kwargs = { - "train_func": train_func, - "datasets": datasets, - "cfg": copied_cfg, - "is_nncf": is_nncf, - "meta": meta, - } - if model_builder is None: - train_func_kwargs["model"] = model - else: - train_func_kwargs["model_builder"] = model_builder - - bs_search_algo = BsSearchAlgo( - train_func=_train_func_single_iter, - train_func_kwargs=train_func_kwargs, - default_bs=default_bs, - max_bs=len(datasets[0]), + def train_func_single_iter(batch_size): + copied_cfg = deepcopy(cfg) + _set_batch_size(copied_cfg, batch_size) + _set_max_epoch(copied_cfg, 1) # setup for training a single iter to reduce time + + # Remove hooks due to reasons below + # OTXProgressHook => prevent progress bar from being 0 and 100 repeatably + # earlystoppinghook => if eval hook is excluded, this hook makes an error due to absence of score history + # CustomEvalHook => exclude validation in classification task + idx_hooks_to_remove = [] + hooks_to_remove = ["OTXProgressHook", "earlystoppinghook", "CustomEvalHook"] + for i, hook in enumerate(copied_cfg.custom_hooks): + if not validate and hook["type"] == "AdaptiveTrainSchedulingHook": + hook["enable_eval_before_run"] = False + for hook_to_remove in hooks_to_remove: + if hook_to_remove.lower() in hook["type"].lower(): + idx_hooks_to_remove.append(i) + + if idx_hooks_to_remove: + idx_hooks_to_remove.sort() + for i in reversed(idx_hooks_to_remove): + del copied_cfg.custom_hooks[i] + + new_datasets = [SubDataset(datasets[0], batch_size)] + + train_func( + dataset=new_datasets, + cfg=copied_cfg, + validate=validate, ) - if not_increase: - new_batch_size = bs_search_algo.auto_decrease_batch_size() - else: - drop_last = cfg.data.get("train_dataloader", {}).get("drop_last", False) - new_batch_size = bs_search_algo.find_big_enough_batch_size(drop_last) - if distributed: - if rank == 0: - total_try_result = torch.tensor([new_batch_size], dtype=torch.int) - else: - total_try_result = torch.empty(1, dtype=torch.int) - total_try_result = total_try_result.cuda() if torch.cuda.is_available() else total_try_result.xpu() - dist.broadcast(total_try_result, src=0) - new_batch_size = total_try_result[0].item() + default_bs = _get_batch_size(cfg) + bs_search_algo = BsSearchAlgo( + train_func=train_func_single_iter, + default_bs=default_bs, + max_bs=len(datasets[0]), + ) + if not_increase: + new_batch_size = bs_search_algo.auto_decrease_batch_size() + else: + drop_last = cfg.data.get("train_dataloader", {}).get("drop_last", False) + new_batch_size = bs_search_algo.find_big_enough_batch_size(drop_last) if default_bs != new_batch_size: _set_batch_size(cfg, new_batch_size) @@ -315,18 +158,11 @@ def __init__(self, fullset, num_samples: int): self.fullset = fullset self.num_samples = 
num_samples - self._img_indices = { # for class incremental case + self.img_indices = { # for class incremental case "old": [i for i in range(num_samples // 2)], "new": [i for i in range(num_samples // 2, num_samples)], } - @property - def img_indices(self): - """img_indices getter.""" - img_indices = copy(getattr(self.fullset, "img_indices", {})) - img_indices.update(self._img_indices) - return img_indices - def __len__(self) -> int: """Get length of subset.""" return self.num_samples diff --git a/src/otx/algorithms/common/adapters/torch/utils/bs_search_algo.py b/src/otx/algorithms/common/adapters/torch/utils/bs_search_algo.py index 899a884fd0a..eaf8c1116e6 100644 --- a/src/otx/algorithms/common/adapters/torch/utils/bs_search_algo.py +++ b/src/otx/algorithms/common/adapters/torch/utils/bs_search_algo.py @@ -1,57 +1,28 @@ -"""Algorithm to find a proper batch size which is fit to current device.""" +"""Algorithm to find a proper batch size which is fit to current GPU device.""" # Copyright (C) 2023 Intel Corporation # SPDX-License-Identifier: Apache-2.0 -import multiprocessing as mp -import queue -from typing import Any, Callable, Dict, Tuple +from typing import Callable, Dict, Tuple import torch +import torch.distributed as dist -from otx.algorithms.common.utils import is_xpu_available from otx.utils.logger import get_logger logger = get_logger() -def _run_trial(train_func: Callable, train_func_kwargs: Dict[str, Any], bs: int, trial_queue: mp.Queue) -> None: - mp.set_start_method(None, True) # reset mp start method - - oom = False - try: - kwargs = train_func_kwargs - kwargs["batch_size"] = bs - train_func(**kwargs) - except RuntimeError as e: - if str(e).startswith("CUDA out of memory.") or str(e).startswith( # CUDA OOM - "Allocation is out of device memory on current platform." # XPU OOM - ): - oom = True - else: - raise e - - max_memory_reserved = _get_max_memory_reserved() - - trial_queue.put( - { - "oom": oom, - "max_memory_reserved": max_memory_reserved, - } - ) - - class BsSearchAlgo: """Algorithm class to find optimal batch size. Args: train_func (Callable[[int], None]): Training function with single arugment to set batch size. - train_func_kwargs (Dict[str, Any]): Keyword arguments for train_func. default_bs (int): Default batch size. It should be bigger than 0. max_bs (int): Maximum batch size. It should be bigger than 0. 
""" - def __init__(self, train_func: Callable, train_func_kwargs: Dict[str, Any], default_bs: int, max_bs: int): + def __init__(self, train_func: Callable[[int], None], default_bs: int, max_bs: int): if default_bs <= 0: raise ValueError("Batch size should be bigger than 0.") if max_bs <= 0: @@ -61,45 +32,62 @@ def __init__(self, train_func: Callable, train_func_kwargs: Dict[str, Any], defa default_bs = max_bs self._train_func = train_func - self._train_func_kwargs = train_func_kwargs self._default_bs = default_bs self._max_bs = max_bs self._bs_try_history: Dict[int, int] = {} - self._total_mem = _get_total_memory_size() + _, self._total_mem = torch.cuda.mem_get_info() self._mem_lower_bound = 0.8 * self._total_mem self._mem_upper_bound = 0.85 * self._total_mem - self._mp_ctx = mp.get_context("spawn") def _try_batch_size(self, bs: int) -> Tuple[bool, int]: - trial_queue = self._mp_ctx.Queue() - proc = self._mp_ctx.Process( - target=_run_trial, args=(self._train_func, self._train_func_kwargs, bs, trial_queue) - ) - proc.start() - output = None - while proc.is_alive(): - try: - output = trial_queue.get(timeout=1) - break - except queue.Empty: - pass - proc.join() - if output is None: - msg = "There is no output from the trial for adaptive batch size." - raise RuntimeError(msg) + cuda_oom = False + torch.cuda.reset_max_memory_cached(device=None) + torch.cuda.empty_cache() + + try: + self._train_func(bs) + except RuntimeError as e: + if str(e).startswith("CUDA out of memory."): + cuda_oom = True + else: + raise e + + max_memory_reserved = torch.cuda.max_memory_reserved(device=None) + + if dist.is_initialized(): # Aggregate all results and broadcast to all processes + rank = dist.get_rank() + try_result = torch.tensor([int(cuda_oom), max_memory_reserved], dtype=torch.int64).cuda() + + if rank == 0: + try_result_arr = [torch.empty(2, dtype=torch.int64).cuda() for _ in range(dist.get_world_size())] + dist.gather(try_result, gather_list=try_result_arr, dst=0) + else: + dist.gather(try_result, dst=0) + + if rank == 0: + try_result_arr = torch.stack(try_result_arr) + cuda_oom = torch.any(try_result_arr[:, 0]) # type: ignore + max_memory_reserved = torch.max(try_result_arr[:, 1]) # type: ignore + total_try_result = torch.tensor([cuda_oom, max_memory_reserved], dtype=torch.int64).cuda() + else: + total_try_result = torch.empty(2, dtype=torch.int64).cuda() + + dist.broadcast(total_try_result, src=0) - oom = output["oom"] - max_memory_reserved = output["max_memory_reserved"] + cuda_oom = total_try_result[0].bool().item() + max_memory_reserved = total_try_result[1].item() - if not oom: + if not cuda_oom: + # Because heapq only supports min heap, use negatized batch size self._bs_try_history[bs] = max_memory_reserved logger.debug( - f"Adapting Batch size => bs : {bs}, OOM : {oom}, " - f"memory usage : {max_memory_reserved / self._total_mem}%" + f"Adapting Batch size => bs : {bs}, CUDA_OOM : {cuda_oom}, " + f"GPU memory usage : {max_memory_reserved / self._total_mem}%" ) + torch.cuda.empty_cache() - return oom, max_memory_reserved + return cuda_oom, max_memory_reserved @staticmethod def _get_even_center_val(val1: int, val2: int) -> int: @@ -109,7 +97,7 @@ def _get_even_center_val(val1: int, val2: int) -> int: return ret def auto_decrease_batch_size(self) -> int: - """Decrease batch size if default batch size isn't fit to current device. + """Decrease batch size if default batch size isn't fit to current GPU device. 
Returns: int: Proper batch size possibly decreased as default value isn't fit @@ -119,10 +107,10 @@ def auto_decrease_batch_size(self) -> int: lowest_unavailable_bs = self._default_bs + 2 while True: - oom, max_memory_reserved = self._try_batch_size(current_bs) + cuda_oom, max_memory_reserved = self._try_batch_size(current_bs) - # If memory usage is too close to limit, OOM can be raised during training - if oom or max_memory_reserved > self._mem_upper_bound: + # If GPU memory usage is too close to limit, CUDA OOM can be raised during training + if cuda_oom or max_memory_reserved > self._mem_upper_bound: if current_bs < lowest_unavailable_bs: lowest_unavailable_bs = current_bs current_bs = self._get_even_center_val(current_bs, available_bs) @@ -144,7 +132,7 @@ def find_big_enough_batch_size(self, drop_last: bool = False) -> int: This function finds a big enough batch size by training with various batch sizes. It estimate a batch size using equation is estimated using training history. The reason why using the word "big enough" is that it tries to find not maxmium but big enough value which uses - memory between lower and upper bound. + GPU memory between lower and upper bound. Args: drop_last (bool): Whether to drop the last incomplete batch. @@ -158,8 +146,8 @@ def find_big_enough_batch_size(self, drop_last: bool = False) -> int: estimated_bs = self._default_bs # try default batch size - oom, bs_mem_usage = self._try_batch_size(estimated_bs) - if oom or bs_mem_usage > self._mem_upper_bound: + cuda_oom, bs_mem_usage = self._try_batch_size(estimated_bs) + if cuda_oom or bs_mem_usage > self._mem_upper_bound: self._default_bs -= 2 if self._default_bs <= 0: raise RuntimeError("Current device can't train model even with 2.") @@ -170,8 +158,8 @@ def find_big_enough_batch_size(self, drop_last: bool = False) -> int: estimated_bs += 2 if estimated_bs > self._max_bs: return self._default_bs - oom, bs_mem_usage = self._try_batch_size(estimated_bs) - if oom or bs_mem_usage > self._mem_upper_bound: + cuda_oom, bs_mem_usage = self._try_batch_size(estimated_bs) + if cuda_oom or bs_mem_usage > self._mem_upper_bound: return self._default_bs # estimate batch size using equation @@ -180,9 +168,9 @@ def find_big_enough_batch_size(self, drop_last: bool = False) -> int: estimated_bs = self._estimate_batch_size(estimation_pct) if estimated_bs in self._bs_try_history: break - oom, mem_usage = self._try_batch_size(estimated_bs) + cuda_oom, mem_usage = self._try_batch_size(estimated_bs) - if oom: + if cuda_oom: estimation_pct -= 0.1 if estimation_pct <= 0: estimated_bs = self._default_bs + 2 @@ -199,7 +187,7 @@ def find_big_enough_batch_size(self, drop_last: bool = False) -> int: def _estimate_batch_size(self, estimation_pct: float) -> int: if len(self._bs_try_history) < 2: - raise RuntimeError("At least two trials should be done without OOM to estimate batch size.") + raise RuntimeError("At least two trials should be done without CUDA OOM to estimate batch size.") def distance_from_bound(val): if val[1] < self._mem_lower_bound: @@ -209,63 +197,39 @@ def distance_from_bound(val): # if memory usage is same, then lower batch size is preferred return val[1] - self._mem_upper_bound + val[0] / 10000 else: - return min(abs(self._mem_lower_bound - val[1], abs(val[1] - self._mem_upper_bound))) + return 0 - bs_arr = sorted([(bs, mem_usage) for bs, mem_usage in self._bs_try_history.items()], key=lambda x: x[0]) - for idx in range(len(bs_arr) - 1, -1, -1): - if bs_arr[idx][1] < self._mem_upper_bound: - cur_max_bs_idx = idx - 
break - else: - logger.warning("All batch size tried used more memory size than upper bound.") - return bs_arr[0][0] - - def check_bs_suitable(estimated_bs) -> bool: - # Check batch size is between largest bs which uses lower memory than uppper bound - # and smallest bs which uses higher memory than upper bound. - if estimated_bs >= bs_arr[cur_max_bs_idx][0]: - if cur_max_bs_idx + 1 < len(bs_arr): - if estimated_bs < bs_arr[cur_max_bs_idx + 1][0]: - return True - else: - return True - return False - - x_idx, y_idx = 0, len(bs_arr) - 1 - - while x_idx < y_idx: - graident = (bs_arr[y_idx][1] - bs_arr[x_idx][1]) / (bs_arr[y_idx][0] - bs_arr[x_idx][0]) - b = bs_arr[y_idx][1] - graident * bs_arr[y_idx][0] + bs_arr = sorted([(bs, mem_usage) for bs, mem_usage in self._bs_try_history.items()], key=distance_from_bound) + bs1 = bs_arr[0][0] + bs1_mem_usage = bs_arr[0][1] + + for i in range(1, len(bs_arr)): + graident = (bs_arr[i][1] - bs1_mem_usage) / (bs_arr[i][0] - bs1) + b = bs1_mem_usage - graident * bs1 if graident != 0: - estimated_bs = round(((self._total_mem * estimation_pct) - b) / (graident * 2)) * 2 - if check_bs_suitable(estimated_bs): - break + break - if distance_from_bound(bs_arr[x_idx + 1]) < distance_from_bound(bs_arr[y_idx - 1]): - x_idx += 1 + if graident == 0: # all batch size history used same GPU memory + if bs1_mem_usage < self._mem_lower_bound: + return bs1 + 2 + elif bs1_mem_usage > self._mem_upper_bound: + if bs1 <= 2: + return 2 + return bs1 - 2 else: - y_idx -= 1 + return bs1 - if x_idx == y_idx: - if check_bs_suitable(bs_arr[cur_max_bs_idx][0] + 2): - estimated_bs = bs_arr[cur_max_bs_idx][0] + 2 - else: - estimated_bs = bs_arr[cur_max_bs_idx][0] + estimated_bs = round(((self._total_mem * estimation_pct) - b) / (graident * 2)) * 2 + + # If estimated_bs is already tried and it used GPU memory more than upper bound, + # set estimated_bs as lowest value of batch sizes using GPU memory more than uppoer bound - 2 + if estimated_bs in self._bs_try_history and self._bs_try_history[estimated_bs] > self._mem_upper_bound: + for bs, mem_usage in bs_arr: + if mem_usage > self._mem_upper_bound: + estimated_bs = bs - 2 + break if estimated_bs > self._max_bs: estimated_bs = self._max_bs return estimated_bs - - -def _get_max_memory_reserved() -> int: - if is_xpu_available(): - return torch.xpu.max_memory_reserved(device=None) - return torch.cuda.max_memory_reserved(device=None) - - -def _get_total_memory_size() -> int: - if is_xpu_available(): - return torch.xpu.get_device_properties(0).total_memory - _, total_mem = torch.cuda.mem_get_info() - return total_mem diff --git a/src/otx/algorithms/detection/adapters/mmdet/task.py b/src/otx/algorithms/detection/adapters/mmdet/task.py index 274f7bfebda..58910adf155 100644 --- a/src/otx/algorithms/detection/adapters/mmdet/task.py +++ b/src/otx/algorithms/detection/adapters/mmdet/task.py @@ -254,25 +254,21 @@ def _train_model( model.train() model.CLASSES = target_classes + if cfg.distributed: + convert_sync_batchnorm(model) + validate = bool(cfg.data.get("val", None)) if self._hyperparams.learning_parameters.auto_adapt_batch_size != BatchSizeAdaptType.NONE: - is_nncf = isinstance(self, NNCFBaseTask) + train_func = partial(train_detector, meta=deepcopy(meta), model=deepcopy(model), distributed=False) adapt_batch_size( - train_detector, - model, - datasets, + train_func, cfg, - cfg.distributed, - is_nncf, - meta=meta, + datasets, + isinstance(self, NNCFBaseTask), # nncf needs eval hooks 
not_increase=(self._hyperparams.learning_parameters.auto_adapt_batch_size == BatchSizeAdaptType.SAFE), - model_builder=getattr(self, "model_builder") if is_nncf else None, ) - if cfg.distributed: - convert_sync_batchnorm(model) - train_detector( model, datasets, diff --git a/src/otx/algorithms/segmentation/adapters/mmseg/task.py b/src/otx/algorithms/segmentation/adapters/mmseg/task.py index 4789fc80bf2..ade39b64224 100644 --- a/src/otx/algorithms/segmentation/adapters/mmseg/task.py +++ b/src/otx/algorithms/segmentation/adapters/mmseg/task.py @@ -10,6 +10,7 @@ import time from contextlib import nullcontext from copy import deepcopy +from functools import partial from typing import Any, Dict, Optional, Union import torch @@ -360,25 +361,21 @@ def _train_model( htcore.hpu.ModuleCacher(max_graphs=10)(model=model.backbone, inplace=True) htcore.hpu.ModuleCacher(max_graphs=10)(model=model.decode_head, inplace=True) + if cfg.distributed: + convert_sync_batchnorm(model) + validate = bool(cfg.data.get("val", None)) if self._hyperparams.learning_parameters.auto_adapt_batch_size != BatchSizeAdaptType.NONE: - is_nncf = isinstance(self, NNCFBaseTask) + train_func = partial(train_segmentor, meta=deepcopy(meta), model=deepcopy(model), distributed=False) adapt_batch_size( - train_segmentor, - model, - datasets, + train_func, cfg, - cfg.distributed, - is_nncf, - meta=meta, + datasets, + isinstance(self, NNCFBaseTask), # nncf needs eval hooks not_increase=(self._hyperparams.learning_parameters.auto_adapt_batch_size == BatchSizeAdaptType.SAFE), - model_builder=getattr(self, "model_builder") if is_nncf else None, ) - if cfg.distributed: - convert_sync_batchnorm(model) - train_segmentor( model, datasets, diff --git a/tests/unit/algorithms/common/adapters/mmcv/utils/test_automatic_bs.py b/tests/unit/algorithms/common/adapters/mmcv/utils/test_automatic_bs.py index 3c7f4cc447c..8fd3122d5bc 100644 --- a/tests/unit/algorithms/common/adapters/mmcv/utils/test_automatic_bs.py +++ b/tests/unit/algorithms/common/adapters/mmcv/utils/test_automatic_bs.py @@ -1,5 +1,4 @@ -from unittest.mock import MagicMock - +from otx.algorithms.common.utils.utils import is_xpu_available import pytest from math import sqrt @@ -13,20 +12,19 @@ class MockBsSearchAlgo: - def __init__(self, train_func, train_func_kwargs, default_bs: int, max_bs: int): + def __init__(self, train_func, default_bs: int, max_bs: int): self.train_func = train_func - self.train_func_kwargs = train_func_kwargs self.default_bs = default_bs self.max_bs = max_bs def auto_decrease_batch_size(self): - self.train_func(batch_size=self.default_bs, **self.train_func_kwargs) - self.train_func(batch_size=self.default_bs // 2, **self.train_func_kwargs) + self.train_func(self.default_bs) + self.train_func(self.default_bs // 2) return self.default_bs // 2 def find_big_enough_batch_size(self, drop_last: bool): - self.train_func(batch_size=self.default_bs, **self.train_func_kwargs) - self.train_func(batch_size=self.default_bs + 2, **self.train_func_kwargs) + self.train_func(self.default_bs) + self.train_func(self.default_bs + 2) return self.default_bs + 2 @@ -65,24 +63,14 @@ def mock_dataset(mocker): return mock_ds -@pytest.fixture -def mock_model(): - return MagicMock() - - @pytest.mark.parametrize("not_increase", [True, False]) @pytest.mark.parametrize("is_action_task", [True, False]) @pytest.mark.parametrize("is_iter_based_runner", [True, False]) def test_adapt_batch_size( - mocker, - mock_adapt_algo_cls, - common_cfg, - mock_dataset, - not_increase, - is_action_task, - 
is_iter_based_runner, - mock_model, + mocker, mock_adapt_algo_cls, common_cfg, mock_dataset, not_increase, is_action_task, is_iter_based_runner ): + if is_xpu_available(): + pytest.skip("Adaptive batch size is not supported on XPU") # prepare mock_train_func = mocker.MagicMock() new_bs = DEFAULT_BS // 2 if not_increase else DEFAULT_BS + 2 @@ -98,7 +86,7 @@ def test_adapt_batch_size( mock_config = set_mock_cfg_not_action(common_cfg) # execute - adapt_batch_size(mock_train_func, mock_model, mock_dataset, mock_config, not_increase=not_increase) + adapt_batch_size(mock_train_func, mock_config, mock_dataset, False, not_increase) # check adapted batch size is applied if is_action_task: @@ -145,18 +133,11 @@ def set_up(self, mocker): self.sub_dataset = SubDataset(self.fullset, self.num_samples) def test_init(self, mocker): - class MockDataset: - def __init__(self): - self.img_indices = {"cls_0": 1, "cls_1": 2} - - fullset = MockDataset() + fullset = mocker.MagicMock() subset = SubDataset(fullset, 3) # test for class incremental case. If below assert can't be passed, ClsIncrSampler can't work well. assert len(subset.img_indices["new"]) / len(subset.img_indices["old"]) + 1 <= self.num_samples - # test existing num_indices values still exist - assert subset.img_indices["cls_0"] == 1 - assert subset.img_indices["cls_1"] == 2 @pytest.mark.parametrize("num_samples", [-1, 0]) def test_init_w_wrong_num_samples(self, mocker, num_samples): diff --git a/tests/unit/algorithms/common/adapters/torch/utils/test_bs_search_algo.py b/tests/unit/algorithms/common/adapters/torch/utils/test_bs_search_algo.py index 557a65c80ac..a347968dc5e 100644 --- a/tests/unit/algorithms/common/adapters/torch/utils/test_bs_search_algo.py +++ b/tests/unit/algorithms/common/adapters/torch/utils/test_bs_search_algo.py @@ -1,57 +1,34 @@ -from unittest.mock import MagicMock -from typing import Optional, Callable +from typing import Optional, List import pytest +import torch from tests.test_suite.e2e_test_system import e2e_pytest_unit from otx.algorithms.common.adapters.torch.utils import BsSearchAlgo from otx.algorithms.common.adapters.torch.utils import bs_search_algo -@pytest.fixture -def train_func_kwargs(): - return MagicMock() - - @e2e_pytest_unit class TestBsSearchAlgo: @pytest.fixture(autouse=True) def setup_test(self, mocker): self.mock_torch = mocker.patch.object(bs_search_algo, "torch") self.mock_torch.cuda.mem_get_info.return_value = (1, 10000) - self.mock_mp = mocker.patch.object(bs_search_algo, "mp") - mocker.patch.object(bs_search_algo, "is_xpu_available", return_value=False) + self.mock_dist = mocker.patch.object(bs_search_algo, "dist") + self.mock_dist.is_initialized.return_value = False - def test_init(self, mocker, train_func_kwargs): - BsSearchAlgo(mocker.MagicMock(), train_func_kwargs, 4, 10) + def test_init(self, mocker): + BsSearchAlgo(mocker.MagicMock(), 4, 10) @pytest.mark.parametrize("default_bs", [-2, 0]) - def test_init_w_wrong_default_bs(self, mocker, default_bs, train_func_kwargs): + def test_init_w_wrong_default_bs(self, mocker, default_bs): with pytest.raises(ValueError): - BsSearchAlgo(mocker.MagicMock(), train_func_kwargs, default_bs=default_bs, max_bs=10) + BsSearchAlgo(mocker.MagicMock(), default_bs=default_bs, max_bs=10) @pytest.mark.parametrize("max_bs", [-2, 0]) - def test_init_w_wrong_default_bs(self, mocker, max_bs, train_func_kwargs): + def test_init_w_wrong_default_bs(self, mocker, max_bs): with pytest.raises(ValueError): - BsSearchAlgo(mocker.MagicMock(), train_func_kwargs, default_bs=4, 
max_bs=max_bs) - - def set_mp_process(self, train_func): - def mock_process(target, args): - batch_size = args[-2] - oom = False - mem_usage = 0 - - try: - mem_usage = train_func(batch_size) - except RuntimeError: - oom = True - - trial_queue = args[-1] - trial_queue.get.return_value = {"oom": oom, "max_memory_reserved": mem_usage} - - return MagicMock() - - self.mock_mp.get_context.return_value.Process.side_effect = mock_process + BsSearchAlgo(mocker.MagicMock(), default_bs=4, max_bs=max_bs) def get_mock_train_func(self, cuda_oom_bound: int, max_runnable_bs: int): def mock_train_func(batch_size): @@ -66,41 +43,131 @@ def mock_train_func(batch_size): self.mock_torch.cuda.max_memory_reserved.return_value = mem_usage return mem_usage - self.set_mp_process(mock_train_func) - return mock_train_func - def test_try_batch_size(self, train_func_kwargs): + def test_try_batch_size(self): mock_train_func = self.get_mock_train_func(cuda_oom_bound=10000, max_runnable_bs=80) - bs_search_algo = BsSearchAlgo(mock_train_func, train_func_kwargs, 128, 1000) + bs_search_algo = BsSearchAlgo(mock_train_func, 128, 1000) batch_size = 40 cuda_oom, max_memory_reserved = bs_search_algo._try_batch_size(batch_size) assert cuda_oom is False assert max_memory_reserved == mock_train_func(batch_size) + self.mock_torch.cuda.reset_max_memory_cached.assert_called() + self.mock_torch.cuda.empty_cache.assert_called() - def test_try_batch_size_cuda_oom(self, train_func_kwargs): + def test_try_batch_size_cuda_oom(self): mock_train_func = self.get_mock_train_func(cuda_oom_bound=100, max_runnable_bs=80) - bs_search_algo = BsSearchAlgo(mock_train_func, train_func_kwargs, 128, 1000) + bs_search_algo = BsSearchAlgo(mock_train_func, 128, 1000) batch_size = 200 cuda_oom, _ = bs_search_algo._try_batch_size(batch_size) assert cuda_oom is True + self.mock_torch.cuda.reset_max_memory_cached.assert_called() + self.mock_torch.cuda.empty_cache.assert_called() + + def _prepare_dist_test(self, broadcast_val: torch.Tensor, gather_val: Optional[List[torch.Tensor]] = None): + self.mock_dist.is_initialized.return_value = True + + # mocking torch.distributed.broadcast + def mock_broadcast(tensor: torch.Tensor, src: int): + tensor.copy_(broadcast_val) + + self.mock_dist.broadcast.side_effect = mock_broadcast + + # mocking torch.distributed.gather if gather_val is given + def mock_gather(tensor: torch.Tensor, gather_list: Optional[List[torch.Tensor]] = None, dst: int = 0): + for i in range(len(gather_list)): + gather_list[i].copy_(gather_val[i]) + + if gather_val is not None: + self.mock_dist.gather.side_effect = mock_gather + + # revert some of torch function + def mock_tensor_cuda(self, *args, **kwargs): + return self + + torch.Tensor.cuda = mock_tensor_cuda + self.mock_torch.tensor = torch.tensor + self.mock_torch.int64 = torch.int64 + self.mock_torch.max = torch.max + self.mock_torch.any = torch.any + self.mock_torch.stack = torch.stack + self.mock_torch.empty = torch.empty + + def test_try_batch_size_distributed_not_rank_0(self): + self.mock_dist.get_rank.return_value = 1 + broadcasted_cuda_oom = False + broadcasted_max_memory_reserved = 4000 + self._prepare_dist_test( + broadcast_val=torch.tensor([broadcasted_cuda_oom, broadcasted_max_memory_reserved], dtype=torch.int64) + ) + mock_train_func = self.get_mock_train_func(cuda_oom_bound=10000, max_runnable_bs=80) + batch_size = 40 + bs_search_algo = BsSearchAlgo(mock_train_func, 128, 1000) + w1_max_memory_reserved = mock_train_func(batch_size) - def test_auto_decrease_batch_size(self, 
train_func_kwargs): + cuda_oom, max_memory_reserved = bs_search_algo._try_batch_size(batch_size) + + # check dist.gather is called and get [cuda_oom, maxmemory_reserved] as arguments. + self.mock_dist.gather.assert_called_once() + assert self.mock_dist.gather.call_args.args[0][0].item() == False + assert self.mock_dist.gather.call_args.args[0][1].item() == w1_max_memory_reserved + assert self.mock_dist.gather.call_args.kwargs["dst"] == 0 + # check dist.broadcast is called + self.mock_dist.broadcast.assert_called_once() + assert self.mock_dist.broadcast.call_args.kwargs["src"] == 0 + # check broadcased values are returned + assert cuda_oom is broadcasted_cuda_oom + assert max_memory_reserved == broadcasted_max_memory_reserved + + def test_try_batch_size_distributed_rank_0(self): + self.mock_dist.get_rank.return_value = 0 + self.mock_dist.get_world_size.return_value = 2 + self._prepare_dist_test( + broadcast_val=torch.tensor([True, 4000], dtype=torch.int64), + gather_val=[ + torch.tensor([False, 3000], dtype=torch.int64), + torch.tensor([True, 4000], dtype=torch.int64), + ], + ) + mock_train_func = self.get_mock_train_func(cuda_oom_bound=10000, max_runnable_bs=80) + batch_size = 40 + bs_search_algo = BsSearchAlgo(mock_train_func, 128, 1000) + w0_max_memory_reserved = mock_train_func(batch_size) + + cuda_oom, max_memory_reserved = bs_search_algo._try_batch_size(batch_size) + + # check dist.gather is called and get [cuda_oom, max_memory_reserved] as arguments. + self.mock_dist.gather.assert_called_once() + assert self.mock_dist.gather.call_args.args[0][0].item() == False + assert self.mock_dist.gather.call_args.args[0][1].item() == w0_max_memory_reserved + assert self.mock_dist.gather.call_args.kwargs["dst"] == 0 + # check if any process get cuda oom then set cuda_oom to True and + # set max_memory_reserved to maximum value of processes' + self.mock_dist.broadcast.assert_called_once() + self.mock_dist.broadcast.assert_called_once() + assert self.mock_dist.broadcast.call_args.kwargs["src"] == 0 + assert self.mock_dist.broadcast.call_args.args[0][0].item() == True + assert self.mock_dist.broadcast.call_args.args[0][1].item() == 4000 + # check proper values are returned + assert cuda_oom is True + assert max_memory_reserved == 4000 + + def test_auto_decrease_batch_size(self): mock_train_func = self.get_mock_train_func(cuda_oom_bound=10000, max_runnable_bs=80) - bs_search_algo = BsSearchAlgo(mock_train_func, train_func_kwargs, 128, 1000) + bs_search_algo = BsSearchAlgo(mock_train_func, 128, 1000) adapted_bs = bs_search_algo.auto_decrease_batch_size() assert adapted_bs == 80 - def test_find_max_usable_bs_gpu_memory_too_small(self, train_func_kwargs): + def test_find_max_usable_bs_gpu_memory_too_small(self): mock_train_func = self.get_mock_train_func(cuda_oom_bound=4, max_runnable_bs=1) - bs_search_algo = BsSearchAlgo(mock_train_func, train_func_kwargs, 128, 1000) + bs_search_algo = BsSearchAlgo(mock_train_func, 128, 1000) with pytest.raises(RuntimeError): bs_search_algo.auto_decrease_batch_size() @@ -113,10 +180,10 @@ def test_find_max_usable_bs_gpu_memory_too_small(self, train_func_kwargs): (66, 1000, None), ], ) - def test_find_big_enough_batch_size(self, max_runnable_bs, max_bs, expected_bs, train_func_kwargs): + def test_find_big_enough_batch_size(self, max_runnable_bs, max_bs, expected_bs): mock_train_func = self.get_mock_train_func(cuda_oom_bound=10000, max_runnable_bs=max_runnable_bs) - bs_search_algo = BsSearchAlgo(mock_train_func, train_func_kwargs, 64, max_bs) + bs_search_algo = 
BsSearchAlgo(mock_train_func, 64, max_bs) adapted_bs = bs_search_algo.find_big_enough_batch_size() if expected_bs is None: @@ -124,14 +191,14 @@ def test_find_big_enough_batch_size(self, max_runnable_bs, max_bs, expected_bs, else: assert adapted_bs == expected_bs - def test_find_big_enough_batch_size_gpu_memory_too_small(self, train_func_kwargs): + def test_find_big_enough_batch_size_gpu_memory_too_small(self): mock_train_func = self.get_mock_train_func(cuda_oom_bound=4, max_runnable_bs=1) - bs_search_algo = BsSearchAlgo(mock_train_func, train_func_kwargs, 128, 1000) + bs_search_algo = BsSearchAlgo(mock_train_func, 128, 1000) with pytest.raises(RuntimeError): bs_search_algo.find_big_enough_batch_size() - def test_find_big_enough_batch_size_gradient_zero(self, train_func_kwargs): + def test_find_big_enough_batch_size_gradient_zero(self): def mock_train_func(batch_size): if batch_size > 1000: mem_usage = 10000 @@ -143,14 +210,12 @@ def mock_train_func(batch_size): self.mock_torch.cuda.max_memory_reserved.return_value = mem_usage return mem_usage - self.set_mp_process(mock_train_func) - - bs_search_algo = BsSearchAlgo(mock_train_func, train_func_kwargs, 64, 1000) + bs_search_algo = BsSearchAlgo(mock_train_func, 64, 1000) adapted_bs = bs_search_algo.find_big_enough_batch_size() assert adapted_bs == 100 - def test_find_big_enough_batch_size_not_exceed_upper_bound(self, train_func_kwargs): + def test_find_big_enough_batch_size_not_exceed_upper_bound(self): def mock_train_func(batch_size): if batch_size > 1000: mem_usage = 10000 @@ -162,17 +227,15 @@ def mock_train_func(batch_size): self.mock_torch.cuda.max_memory_reserved.return_value = mem_usage return mem_usage - self.set_mp_process(mock_train_func) - - bs_search_algo = BsSearchAlgo(mock_train_func, train_func_kwargs, 64, 1000) + bs_search_algo = BsSearchAlgo(mock_train_func, 64, 1000) adapted_bs = bs_search_algo.find_big_enough_batch_size() assert mock_train_func(adapted_bs) <= 8500 - def test_find_big_enough_batch_size_drop_last(self, train_func_kwargs): + def test_find_big_enough_batch_size_drop_last(self): mock_train_func = self.get_mock_train_func(cuda_oom_bound=10000, max_runnable_bs=180) - bs_search_algo = BsSearchAlgo(mock_train_func, train_func_kwargs, 64, 200) + bs_search_algo = BsSearchAlgo(mock_train_func, 64, 200) adapted_bs = bs_search_algo.find_big_enough_batch_size(True) assert adapted_bs == 100
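
Reviewer note (not part of the patch): after this revert, adapt_batch_size no longer receives the model, the distributed flag, or an NNCF model builder; each task wraps its mm* train entry point with functools.partial and passes only the config, the datasets, and a validate flag. Below is a minimal sketch of the restored call pattern. It assumes the surrounding task context shown in the classification hunk above (train_model, cfg, datasets, meta, model are taken from _train_model), and the import path simply mirrors the file location of automatic_bs.py.

    from copy import deepcopy
    from functools import partial

    from otx.algorithms.common.adapters.mmcv.utils.automatic_bs import adapt_batch_size

    # Wrap the task-specific trainer so adapt_batch_size() only needs to supply
    # dataset, cfg and validate when it runs its single-iteration trials.
    train_func = partial(train_model, meta=deepcopy(meta), model=deepcopy(model), distributed=False)

    adapt_batch_size(
        train_func,
        cfg,                 # training config; its batch size is updated in place if a new one is found
        datasets,            # list of datasets; only datasets[0] is wrapped in SubDataset per trial
        validate=False,      # whether the trial runs eval hooks (cls/det/seg pass isinstance(self, NNCFBaseTask) here)
        not_increase=True,   # True for BatchSizeAdaptType.SAFE: only decrease, never search upward
    )

With not_increase=False (BatchSizeAdaptType.FULL), BsSearchAlgo.find_big_enough_batch_size() is used instead of auto_decrease_batch_size(), and drop_last is read from cfg.data.train_dataloader as in the reverted function body.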