Revert adaptive batch size (#3340)
revert autobs
eunwoosh authored Apr 17, 2024
1 parent 321ee2f commit f66acdd
Showing 8 changed files with 293 additions and 459 deletions.
10 changes: 5 additions & 5 deletions src/otx/algorithms/action/adapters/mmaction/task.py
@@ -19,6 +19,7 @@
 import time
 from contextlib import nullcontext
 from copy import deepcopy
+from functools import partial
 from typing import Dict, Optional, Union
 
 import torch
@@ -323,13 +324,12 @@ def _train_model(
         validate = bool(cfg.data.get("val", None))
 
         if self._hyperparams.learning_parameters.auto_adapt_batch_size != BatchSizeAdaptType.NONE:
+            train_func = partial(train_model, meta=deepcopy(meta), model=deepcopy(model), distributed=False)
             adapt_batch_size(
-                train_model,
-                model,
-                datasets,
+                train_func,
                 cfg,
-                cfg.distributed,
-                meta=meta,
+                datasets,
+                validate,
                 not_increase=(self._hyperparams.learning_parameters.auto_adapt_batch_size == BatchSizeAdaptType.SAFE),
             )
 
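For context, the restored call site wraps the trainer with `functools.partial` so that everything except the dataset, config, and validate flag is bound up front. A minimal sketch of the pattern with a stubbed-out `train_model` (the real one comes from mmaction, and the real code binds deep copies of `model` and `meta`):

    from functools import partial

    def train_model(model, dataset, cfg, distributed, validate, meta):
        """Stand-in for the mmaction training entry point used above."""
        print(f"trial run: validate={validate}, distributed={distributed}")

    # Bind the arguments that stay fixed across trial runs; the batch-size
    # search then only needs to vary dataset, cfg, and validate.
    train_func = partial(train_model, meta={}, model=None, distributed=False)
    train_func(dataset=[], cfg=None, validate=False)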
19 changes: 8 additions & 11 deletions src/otx/algorithms/classification/adapters/mmcls/task.py
@@ -8,6 +8,7 @@
 import time
 from contextlib import nullcontext
 from copy import deepcopy
+from functools import partial
 from typing import Any, Dict, Optional, Type, Union
 
 import torch
@@ -379,6 +380,9 @@ def _train_model(
             htcore.hpu.ModuleCacher(max_graphs=10)(model=model.backbone, inplace=True)
             htcore.hpu.ModuleCacher(max_graphs=10)(model=model.head, inplace=True)
 
+        if cfg.distributed:
+            convert_sync_batchnorm(model)
+
         validate = bool(cfg.data.get("val", None))
         if validate:
             val_dataset = build_dataset(cfg.data.val, dict(test_mode=True))
@@ -406,22 +410,15 @@ def _train_model(
             )
 
         if self._hyperparams.learning_parameters.auto_adapt_batch_size != BatchSizeAdaptType.NONE:
-            is_nncf = isinstance(self, NNCFBaseTask)
+            train_func = partial(train_model, meta=deepcopy(meta), model=deepcopy(model), distributed=False)
             adapt_batch_size(
-                train_model,
-                model,
-                datasets,
+                train_func,
                 cfg,
-                cfg.distributed,
-                is_nncf,
-                meta=meta,
+                datasets,
+                isinstance(self, NNCFBaseTask),  # nncf needs eval hooks
                 not_increase=(self._hyperparams.learning_parameters.auto_adapt_batch_size == BatchSizeAdaptType.SAFE),
-                model_builder=getattr(self, "model_builder") if is_nncf else None,
             )
 
-        if cfg.distributed:
-            convert_sync_batchnorm(model)
-
         train_model(
             model,
             datasets,
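The revert also moves the SyncBatchNorm conversion ahead of the batch-size search instead of after it. The OTX `convert_sync_batchnorm` helper's internals are not shown in this diff; as a rough stand-in, torch ships a stock converter that swaps every `BatchNorm*d` layer for `SyncBatchNorm`:

    from torch import nn

    # Sketch only: torch's built-in converter, assumed here to behave like
    # the OTX helper used above for the purpose of this illustration.
    model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
    model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
    assert isinstance(model[1], nn.SyncBatchNorm)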
264 changes: 50 additions & 214 deletions src/otx/algorithms/common/adapters/mmcv/utils/automatic_bs.py
@@ -3,26 +3,14 @@
 # Copyright (C) 2023 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0
 
-import inspect
-from copy import copy
-from importlib import import_module
+from copy import deepcopy
 from math import sqrt
-from pathlib import Path
-from tempfile import TemporaryDirectory
-from typing import Any, Callable, Dict, List, Optional
+from typing import Callable, Dict, List
 
 import numpy as np
 import torch
-from mmcv import Config
-from mmcv.runner import wrap_fp16_model
-from torch import distributed as dist
 from torch.cuda import is_available as cuda_available
-from torch.utils.data import Dataset
 
-from otx.algorithms.common.adapters.mmcv.utils.config_utils import OTXConfig
-from otx.algorithms.common.adapters.torch.utils import BsSearchAlgo, sync_batchnorm_2_batchnorm
-from otx.algorithms.common.utils import is_xpu_available
-from otx.core.data import caching
+from otx.algorithms.common.adapters.torch.utils import BsSearchAlgo
 from otx.utils.logger import get_logger
 
 logger = get_logger()
@@ -50,142 +38,7 @@ def _set_value_at_dict_in_dict(target: Dict, key_path: str, value):
     target[keys[-1]] = value
 
 
-def _build_model(model_builder: Callable, cfg: Config) -> torch.nn.Module:
-    model = model_builder(cfg)
-    if cfg.get("fp16", False):
-        wrap_fp16_model(model)
-    return model
-
-
-NNCF_PATCH_MODULE = {
-    "mmcls": "otx.algorithms.classification.adapters.mmcls.nncf.patches",
-    "mmdet": "otx.algorithms.detection.adapters.mmdet.nncf.patches",
-    "mmseg": "otx.algorithms.segmentation.adapters.mmseg.nncf.patches",
-}
-
-
-def _train_func_single_iter(
-    batch_size: int,
-    train_func: Callable,
-    datasets: List[Dataset],
-    cfg: OTXConfig,
-    is_nncf: bool = False,
-    meta: Optional[Dict[str, Any]] = None,
-    model: Optional[torch.nn.Module] = None,
-    model_builder: Optional[Callable] = None,
-) -> None:
-    caching.MemCacheHandlerSingleton.create("null", 0)  # initialize mem cache
-    _set_batch_size(cfg, batch_size)
-    _set_max_epoch(cfg, 1)  # set up single-iteration training to save time
-
-    new_dataset = [SubDataset(datasets[0], batch_size)]
-
-    validate = is_nncf  # nncf needs eval hooks
-    if is_nncf:
-        pkg_name = inspect.getmodule(train_func).__package__
-        for framework in ["mmcls", "mmdet", "mmseg"]:
-            if framework in pkg_name:
-                import_module(NNCF_PATCH_MODULE[framework])
-                break
-        else:
-            framework = None
-
-        if framework == "mmcls":
-            validate = False  # the classification task has its own custom eval hook
-
-    if model is None:
-        model = _build_model(model_builder, cfg)
-
-    if is_nncf:
-        model.nncf._uncompressed_model_accuracy = 0
-
-    sync_batchnorm_2_batchnorm(model)
-
-    train_func(
-        model=model,
-        dataset=new_dataset,
-        cfg=cfg,
-        distributed=False,
-        validate=validate,
-        meta=meta,
-    )
-
-
-def _save_nncf_model_weight(model: torch.nn.Module, cfg: OTXConfig, save_path: Path) -> str:
-    """Save the nncf model weight after nncf finishes building a model.
-
-    NNCF analyzes and gathers statistics while building a model, which is time-consuming.
-    To skip that step, the nncf model weight is saved here and loaded in the new process.
-    """
-    from otx.algorithms.common.adapters.nncf.compression import NNCFMetaState
-
-    file_path = save_path / "nncf_model.pth"
-    for custom_hook in cfg.custom_hooks:
-        if custom_hook["type"] == "CompressionHook":
-            compression_ctrl = custom_hook["compression_ctrl"].get_compression_state()
-            break
-    else:
-        msg = "CompressionHook doesn't exist in custom hooks."
-        raise RuntimeError(msg)
-
-    torch.save(
-        {
-            "state_dict": model.state_dict(),
-            "meta": {
-                "nncf_meta": NNCFMetaState(
-                    state_to_build=cfg.runner.nncf_meta.state_to_build,
-                    data_to_build=cfg.runner.nncf_meta.data_to_build,
-                    compression_ctrl=compression_ctrl,
-                ),
-                "nncf_enable_compression": True,
-            },
-        },
-        file_path,
-    )
-
-    return str(file_path)
-
-
-def _organize_custom_hooks(custom_hooks: List, is_nncf: bool = False) -> None:
-    # Remove hooks for the reasons below.
-    # For the nncf task:
-    #   OTXProgressHook and CompressionHook are added when building a model; remove them to avoid duplication.
-    # For a normal task:
-    #   OTXProgressHook => prevent the progress bar from bouncing between 0 and 100 repeatedly
-    #   CancelInterfaceHook => avoid a segmentation fault
-    #   earlystoppinghook => if the eval hook is excluded, this hook raises an error because the score history is absent
-    #   CustomEvalHook => exclude validation in the classification task
-
-    if is_nncf:
-        hooks_to_remove = ["OTXProgressHook", "CompressionHook"]
-    else:
-        hooks_to_remove = ["OTXProgressHook", "earlystoppinghook", "CustomEvalHook", "CancelInterfaceHook"]
-
-    idx_hooks_to_remove = []
-    for i, hook in enumerate(custom_hooks):
-        if not is_nncf and hook["type"] == "AdaptiveTrainSchedulingHook":
-            hook["enable_eval_before_run"] = False
-        for hook_to_remove in hooks_to_remove:
-            if hook_to_remove.lower() in hook["type"].lower():
-                idx_hooks_to_remove.append(i)
-
-    if idx_hooks_to_remove:
-        idx_hooks_to_remove.sort()
-        for i in reversed(idx_hooks_to_remove):
-            custom_hooks.pop(i)
-
-
-def adapt_batch_size(
-    train_func: Callable,
-    model: torch.nn.Module,
-    datasets: List[Dataset],
-    cfg: OTXConfig,
-    distributed: bool = False,
-    is_nncf: bool = False,
-    meta: Optional[Dict[str, Any]] = None,
-    not_increase: bool = True,
-    model_builder: Optional[Callable] = None,
-) -> None:
+def adapt_batch_size(train_func: Callable, cfg, datasets: List, validate: bool = False, not_increase: bool = True):
     """Decrease the batch size if the default batch size doesn't fit on the current GPU device.
 
     This function only sets up single-iteration training to reduce the time spent adapting.
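Both the deleted `_organize_custom_hooks` above and the restored closure in the next hunk rely on the same idiom: collect the indices of hooks whose type matches a deny-list, then delete from the back so earlier indices stay valid. A self-contained sketch with illustrative hook entries:

    custom_hooks = [
        {"type": "OTXProgressHook"},
        {"type": "AdaptiveTrainSchedulingHook"},
        {"type": "CustomEvalHook"},
    ]
    hooks_to_remove = ["OTXProgressHook", "earlystoppinghook", "CustomEvalHook"]

    idx_hooks_to_remove = []
    for i, hook in enumerate(custom_hooks):
        for name in hooks_to_remove:
            if name.lower() in hook["type"].lower():  # case-insensitive match
                idx_hooks_to_remove.append(i)

    # Delete from the back so the remaining indices are not shifted.
    for i in reversed(sorted(idx_hooks_to_remove)):
        del custom_hooks[i]

    print([h["type"] for h in custom_hooks])  # ['AdaptiveTrainSchedulingHook']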
@@ -194,69 +47,59 @@ def adapt_batch_size(
     Args:
         train_func (Callable): The function to train a model.
             Only cfg, dataset and meta are passed to the function when invoking it.
-        model (torch.nn.Module): Model to train.
+        cfg: Configuration of a training.
+        meta (Dict): A dict recording some meta information of a training.
         datasets (List): List of datasets.
-        cfg (OTXConfig): Configuration of a training.
-        distributed (bool): Whether training is distributed or not.
-        is_nncf (bool): Whether nncf is used or not.
-        meta (Optional[Dict[str, Any]]): Meta information.
+        validate (bool): Whether to run validation or not.
         not_increase (bool): Whether to allow adapting the batch size to a larger value than the default or not.
-        model_builder (Optional[Callable]):
-            Function for building a model. If it exists, a model built by model_builder is used instead of the model
-            in the argument. It's required for nncf because nncf changes the model, which prevents it from pickling.
     """
 
-    if not (cuda_available() or is_xpu_available()):
-        logger.warning("Skip Auto-adaptive batch size: Adaptive batch size supports CUDA or XPU.")
+    if not cuda_available():
+        logger.warning("Skip Auto-adaptive batch size: CUDA should be available, but it isn't.")
         return
 
-    copied_cfg = copy(cfg)
-    copied_cfg.custom_hooks = copy(cfg.custom_hooks)
-    copied_cfg.pop("algo_backend", None)
-
-    if is_nncf:
-        if model_builder is None:
-            msg = "model_builder should be passed for building a nncf model."
-            raise RuntimeError(msg)
-        temp_dir = TemporaryDirectory("adaptive-bs")
-        copied_cfg.load_from = _save_nncf_model_weight(model, cfg, Path(temp_dir.name))
-
-    _organize_custom_hooks(copied_cfg.custom_hooks, is_nncf)
-
-    default_bs = _get_batch_size(cfg)
-    if not distributed or (rank := dist.get_rank()) == 0:
-        train_func_kwargs = {
-            "train_func": train_func,
-            "datasets": datasets,
-            "cfg": copied_cfg,
-            "is_nncf": is_nncf,
-            "meta": meta,
-        }
-        if model_builder is None:
-            train_func_kwargs["model"] = model
-        else:
-            train_func_kwargs["model_builder"] = model_builder
-
-        bs_search_algo = BsSearchAlgo(
-            train_func=_train_func_single_iter,
-            train_func_kwargs=train_func_kwargs,
-            default_bs=default_bs,
-            max_bs=len(datasets[0]),
+    def train_func_single_iter(batch_size):
+        copied_cfg = deepcopy(cfg)
+        _set_batch_size(copied_cfg, batch_size)
+        _set_max_epoch(copied_cfg, 1)  # set up single-iteration training to reduce time
+
+        # Remove hooks for the reasons below.
+        # OTXProgressHook => prevent the progress bar from bouncing between 0 and 100 repeatedly
+        # earlystoppinghook => if the eval hook is excluded, this hook raises an error because the score history is absent
+        # CustomEvalHook => exclude validation in the classification task
+        idx_hooks_to_remove = []
+        hooks_to_remove = ["OTXProgressHook", "earlystoppinghook", "CustomEvalHook"]
+        for i, hook in enumerate(copied_cfg.custom_hooks):
+            if not validate and hook["type"] == "AdaptiveTrainSchedulingHook":
+                hook["enable_eval_before_run"] = False
+            for hook_to_remove in hooks_to_remove:
+                if hook_to_remove.lower() in hook["type"].lower():
+                    idx_hooks_to_remove.append(i)
+
+        if idx_hooks_to_remove:
+            idx_hooks_to_remove.sort()
+            for i in reversed(idx_hooks_to_remove):
+                del copied_cfg.custom_hooks[i]
+
+        new_datasets = [SubDataset(datasets[0], batch_size)]
+
+        train_func(
+            dataset=new_datasets,
+            cfg=copied_cfg,
+            validate=validate,
         )
-        if not_increase:
-            new_batch_size = bs_search_algo.auto_decrease_batch_size()
-        else:
-            drop_last = cfg.data.get("train_dataloader", {}).get("drop_last", False)
-            new_batch_size = bs_search_algo.find_big_enough_batch_size(drop_last)
 
-    if distributed:
-        if rank == 0:
-            total_try_result = torch.tensor([new_batch_size], dtype=torch.int)
-        else:
-            total_try_result = torch.empty(1, dtype=torch.int)
-        total_try_result = total_try_result.cuda() if torch.cuda.is_available() else total_try_result.xpu()
-        dist.broadcast(total_try_result, src=0)
-        new_batch_size = total_try_result[0].item()
+    default_bs = _get_batch_size(cfg)
+    bs_search_algo = BsSearchAlgo(
+        train_func=train_func_single_iter,
+        default_bs=default_bs,
+        max_bs=len(datasets[0]),
+    )
+    if not_increase:
+        new_batch_size = bs_search_algo.auto_decrease_batch_size()
+    else:
+        drop_last = cfg.data.get("train_dataloader", {}).get("drop_last", False)
+        new_batch_size = bs_search_algo.find_big_enough_batch_size(drop_last)
 
     if default_bs != new_batch_size:
         _set_batch_size(cfg, new_batch_size)
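`BsSearchAlgo` lives in `otx.algorithms.common.adapters.torch.utils` and is not part of this diff. As a rough mental model only (an assumption, not the actual implementation), `auto_decrease_batch_size` behaves like a retry loop that shrinks the batch until the single-iteration trial run stops running out of memory:

    def auto_decrease_batch_size(train_func, default_bs: int) -> int:
        """Toy model of the decrease-only search; the real BsSearchAlgo may
        probe memory differently (e.g. a binary search between bounds)."""
        bs = default_bs
        while bs > 1:
            try:
                train_func(bs)  # one short trial run at this batch size
                return bs       # the first size that fits wins
            except RuntimeError as err:
                if "out of memory" not in str(err):
                    raise
                bs //= 2        # halve and retry
        return 1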
@@ -315,18 +158,11 @@ def __init__(self, fullset, num_samples: int):

         self.fullset = fullset
         self.num_samples = num_samples
-        self._img_indices = {  # for the class-incremental case
+        self.img_indices = {  # for the class-incremental case
             "old": [i for i in range(num_samples // 2)],
             "new": [i for i in range(num_samples // 2, num_samples)],
         }
 
-    @property
-    def img_indices(self):
-        """img_indices getter."""
-        img_indices = copy(getattr(self.fullset, "img_indices", {}))
-        img_indices.update(self._img_indices)
-        return img_indices
-
     def __len__(self) -> int:
         """Get the length of the subset."""
         return self.num_samples
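With the property gone, `img_indices` is a plain attribute computed once in `__init__`. Usage stays simple; a sketch with a list standing in for a real mmcv dataset (the unexpanded parts of `SubDataset` may expect more from `fullset` than shown here):

    fullset = list(range(100))   # stand-in for the real training dataset
    subset = SubDataset(fullset, num_samples=8)

    len(subset)         # 8 -> the trial run sees exactly one batch
    subset.img_indices  # {'old': [0, 1, 2, 3], 'new': [4, 5, 6, 7]}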
