Fix mixed & lower precision training (#2668)
* Remove `dtype` argument in `torch.xpu.optimize`

* Add `custom_auto_fp16` to use xpu autocast

* Update `forward`s to use `custom_auto_fp16`

* precommit

* Disable FP16 training

* Add `custom_force_fp32`

* Remove forced casting of tensors to bf16

* Add `XPUOptimizerHook` and `XPUGradScaler`

* precommit

* Enable lower precision training

* Remove dtype check for lower precision

* Add `bf16_training` in recipe

* fix

* Remove unused module

* Change `XPUOptimizerHook` to `BFp16XPUOptimizerHook`

* Fix for other devices that don't use bf16

* precommit

* Enable using `auto_fp16` as-is

* Add try-except to avoid mmcv import error

* Fix error type

Co-authored-by: Eunwoo Shin <eunwoo.shin@intel.com>

---------

Co-authored-by: Shin, Eunwoo <eunwoo.shin@intel.com>
sungchul2 and eunwoosh authored Nov 29, 2023
1 parent 2239ac1 commit 03e87f5
Showing 24 changed files with 355 additions and 30 deletions.
@@ -93,13 +93,14 @@ def train_model(model, dataset, cfg, distributed=False, validate=False, timestam
     optimizer = build_optimizer(model, cfg.optimizer)
 
     if cfg.device == "xpu":
-        if fp16_cfg is not None:
-            dtype = torch.bfloat16
-        else:
-            dtype = torch.float32
+        dtype = torch.bfloat16 if cfg.optimizer_config.get("bf16_training", False) else torch.float32
         model.train()
         model, optimizer = torch.xpu.optimize(model, optimizer=optimizer, dtype=dtype)
 
+    if "bf16_training" in cfg.optimizer_config:
+        # Remove unused parameters in runner
+        cfg.optimizer_config.pop("bf16_training")
+
     if cfg.get("runner") is None:
         cfg.runner = {"type": "EpochBasedRunner", "max_epochs": cfg.total_epochs}
         warnings.warn(
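For context on the `torch.xpu.optimize` call above: it is the Intel Extension for PyTorch (IPEX) entry point that prepares a model and optimizer for XPU execution, and `dtype` selects bf16 weight/optimizer handling. A minimal standalone sketch, assuming an IPEX build with XPU support and an available `xpu` device (the tiny model is only an illustration, not OTX code):

import torch
import intel_extension_for_pytorch  # noqa: F401  # registers the torch.xpu backend

model = torch.nn.Linear(8, 2).to("xpu")  # placeholder model
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

bf16_training = False  # mirrors the recipe flag read in train_model above
dtype = torch.bfloat16 if bf16_training else torch.float32
model.train()
model, optimizer = torch.xpu.optimize(model, optimizer=optimizer, dtype=dtype)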
@@ -14,7 +14,7 @@
     backbone=dict(arch="deit-tiny", init_cfg=dict(type="Pretrained", checkpoint=ckpt_url, prefix="backbone")),
 )
 
-fp16 = dict(loss_scale=512.0)
+fp16 = dict(loss_scale=512.0, bf16_training=False)
 
 optimizer = dict(_delete_=True, type="AdamW", lr=0.01, weight_decay=0.05)
 optimizer_config = dict(_delete_=True)
@@ -22,4 +22,4 @@
     ),
 )
 
-fp16 = dict(loss_scale=512.0)
+fp16 = dict(loss_scale=512.0, bf16_training=False)
@@ -16,4 +16,4 @@
     head=dict(type="CustomLinearClsHead", loss=dict(type="CrossEntropyLoss", loss_weight=1.0)),
 )
 
-fp16 = dict(loss_scale=512.0)
+fp16 = dict(loss_scale=512.0, bf16_training=False)
@@ -22,4 +22,4 @@
     ),
 )
 
-fp16 = dict(loss_scale=512.0)
+fp16 = dict(loss_scale=512.0, bf16_training=False)
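The recipe changes above keep fp16 loss scaling but explicitly opt out of bf16 training. Based on the `bf16_training` handling shown in `train_model` above, a recipe that opts in on XPU would presumably look like this (illustrative value only, not one of the changed recipes):

fp16 = dict(loss_scale=512.0, bf16_training=True)  # request bf16 lower-precision training on XPU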
15 changes: 14 additions & 1 deletion src/otx/algorithms/common/adapters/mmcv/configurer.py
@@ -263,7 +263,20 @@ def configure_fp16(cfg: Config):
                 logger.warning("SAMOptimizerHook is not supported on HPU. Changed to OptimizerHook.")
                 opts["type"] = "HPUOptimizerHook"
             cfg.optimizer_config.update(opts)
-        elif torch.cuda.is_available() or is_xpu_available():
+        elif is_xpu_available():
+            opts.update({"distributed": distributed, **fp16_config})
+            if optim_type == "SAMOptimizerHook":
+                logger.warning("SAMOptimizerHook is not supported on XPU yet, changed to OptimizerHook.")
+                opts["type"] = "OptimizerHook"
+            if optim_type == "OptimizerHook":
+                opts["type"] = "BFp16XPUOptimizerHook"
+            else:
+                # does not support optimizerhook type
+                # let mm library handle it
+                cfg.fp16 = fp16_config
+                opts = dict()
+            cfg.optimizer_config.update(opts)
+        elif torch.cuda.is_available():
             opts.update({"distributed": distributed, **fp16_config})
             if optim_type == "SAMOptimizerHook":
                 opts["type"] = "Fp16SAMOptimizerHook"
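To make the data flow concrete, here is a minimal sketch (not OTX code) of what the XPU branch above produces for a recipe whose optimizer hook is the default `OptimizerHook`, assuming non-distributed training:

fp16_config = dict(loss_scale=512.0, bf16_training=False)  # from the recipe's `fp16` block
opts = dict(type="OptimizerHook", distributed=False, **fp16_config)
opts["type"] = "BFp16XPUOptimizerHook"  # the XPU-specific hook replaces the plain OptimizerHook
print(opts)
# {'type': 'BFp16XPUOptimizerHook', 'distributed': False, 'loss_scale': 512.0, 'bf16_training': False}
# `bf16_training` is popped again in train_model() before the runner builds the hook.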
7 changes: 7 additions & 0 deletions src/otx/algorithms/common/adapters/mmcv/hooks/__init__.py
@@ -98,3 +98,10 @@
     __all__ += ["HPUOptimizerHook"]
 except:  # noqa: E722
     pass
+
+try:
+    from .xpu_optimizer_hook import BFp16XPUOptimizerHook
+
+    __all__ += ["BFp16XPUOptimizerHook"]
+except:  # noqa: E722
+    pass
@@ -0,0 +1,38 @@
"""Custom Optimizer Hook for mixed precision training on XPU."""
# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0


from typing import Optional, Union

from mmcv.runner.hooks import HOOKS, Fp16OptimizerHook

from otx.algorithms.common.adapters.torch.amp import XPUGradScaler


@HOOKS.register_module()
class BFp16XPUOptimizerHook(Fp16OptimizerHook):
    """Custom Optimizer Hook for mixed & lower precision training on XPU."""

    def __init__(
        self,
        grad_clip: Optional[dict] = None,
        coalesce: bool = True,
        bucket_size_mb: int = -1,
        loss_scale: Union[float, str, dict] = 512.0,
        distributed: bool = True,
    ) -> None:
        self.grad_clip = grad_clip
        self.coalesce = coalesce
        self.bucket_size_mb = bucket_size_mb
        self.distributed = distributed
        self._scale_update_param = None
        if loss_scale == "dynamic":
            self.loss_scaler = XPUGradScaler()
        elif isinstance(loss_scale, float):
            self._scale_update_param = loss_scale
            self.loss_scaler = XPUGradScaler(init_scale=loss_scale)
        elif isinstance(loss_scale, dict):
            self.loss_scaler = XPUGradScaler(**loss_scale)
        else:
            raise ValueError("loss_scale must be of type float, dict, or " f'"dynamic", got {loss_scale}')
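Because the hook is registered in mmcv's `HOOKS` registry, the runner can build it from the `optimizer_config` dict assembled in `configure_fp16`. A small illustrative sketch, assuming an environment where IPEX is installed so the hook actually registers (values mirror the recipes above):

from mmcv.runner.hooks import HOOKS

import otx.algorithms.common.adapters.mmcv.hooks  # noqa: F401  # triggers hook registration

hook_cfg = dict(type="BFp16XPUOptimizerHook", loss_scale=512.0, distributed=False)
optimizer_hook = HOOKS.build(hook_cfg)  # -> BFp16XPUOptimizerHook wrapping an XPUGradScaler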
133 changes: 133 additions & 0 deletions src/otx/algorithms/common/adapters/mmcv/utils/fp16_utils.py
@@ -0,0 +1,133 @@
"""Custom fp16 related modules to enable XPU modules."""
# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import functools
from inspect import getfullargspec
from typing import Callable, Iterable, Optional

import torch
from mmcv.runner.fp16_utils import cast_tensor_type
from mmcv.utils import IS_NPU_AVAILABLE, TORCH_VERSION, digit_version
from torch import nn

from otx.algorithms.common.utils import is_xpu_available

try:
    if is_xpu_available():
        from torch.xpu.amp import autocast
    elif IS_NPU_AVAILABLE:
        from torch.npu.amp import autocast
    else:
        from torch.cuda.amp import autocast
except ImportError:
    pass


def custom_auto_fp16(
    apply_to: Optional[Iterable] = None,
    out_fp32: bool = False,
    supported_types: tuple = (nn.Module,),
) -> Callable:
    """Custom decorator to enable fp16 training automatically on XPU as well."""

    def auto_fp16_wrapper(old_func: Callable) -> Callable:
        @functools.wraps(old_func)
        def new_func(*args, **kwargs) -> Callable:
            # check if the module has set the attribute `fp16_enabled`, if not,
            # just fallback to the original method.
            if not isinstance(args[0], supported_types):
                raise TypeError(
                    "@auto_fp16 can only be used to decorate the " f"method of those classes {supported_types}"
                )
            if not (hasattr(args[0], "fp16_enabled") and args[0].fp16_enabled):
                return old_func(*args, **kwargs)

            target_dtype = torch.bfloat16 if is_xpu_available() else torch.half
            # get the arg spec of the decorated method
            args_info = getfullargspec(old_func)
            # get the argument names to be casted
            args_to_cast = args_info.args if apply_to is None else apply_to
            # convert the args that need to be processed
            new_args = []
            # NOTE: default args are not taken into consideration
            if args:
                arg_names = args_info.args[: len(args)]
                for i, arg_name in enumerate(arg_names):
                    if arg_name in args_to_cast:
                        new_args.append(cast_tensor_type(args[i], torch.float, target_dtype))
                    else:
                        new_args.append(args[i])
            # convert the kwargs that need to be processed
            new_kwargs = {}
            if kwargs:
                for arg_name, arg_value in kwargs.items():
                    if arg_name in args_to_cast:
                        new_kwargs[arg_name] = cast_tensor_type(arg_value, torch.float, target_dtype)
                    else:
                        new_kwargs[arg_name] = arg_value
            # apply converted arguments to the decorated method
            if TORCH_VERSION != "parrots" and digit_version(TORCH_VERSION) >= digit_version("1.6.0"):
                with autocast(enabled=True):
                    output = old_func(*new_args, **new_kwargs)
            else:
                output = old_func(*new_args, **new_kwargs)
            # cast the results back to fp32 if necessary
            if out_fp32:
                output = cast_tensor_type(output, target_dtype, torch.float)
            return output

        return new_func

    return auto_fp16_wrapper


def custom_force_fp32(apply_to: Optional[Iterable] = None, out_fp16: bool = False) -> Callable:
    """Custom decorator to convert input arguments to fp32 in force on XPU as well."""

    def force_fp32_wrapper(old_func):
        @functools.wraps(old_func)
        def new_func(*args, **kwargs) -> Callable:
            # check if the module has set the attribute `fp16_enabled`, if not,
            # just fallback to the original method.
            if not isinstance(args[0], torch.nn.Module):
                raise TypeError("@force_fp32 can only be used to decorate the " "method of nn.Module")
            if not (hasattr(args[0], "fp16_enabled") and args[0].fp16_enabled):
                return old_func(*args, **kwargs)

            source_dtype = torch.bfloat16 if is_xpu_available() else torch.half
            # get the arg spec of the decorated method
            args_info = getfullargspec(old_func)
            # get the argument names to be casted
            args_to_cast = args_info.args if apply_to is None else apply_to
            # convert the args that need to be processed
            new_args = []
            if args:
                arg_names = args_info.args[: len(args)]
                for i, arg_name in enumerate(arg_names):
                    if arg_name in args_to_cast:
                        new_args.append(cast_tensor_type(args[i], source_dtype, torch.float))
                    else:
                        new_args.append(args[i])
            # convert the kwargs that need to be processed
            new_kwargs = dict()
            if kwargs:
                for arg_name, arg_value in kwargs.items():
                    if arg_name in args_to_cast:
                        new_kwargs[arg_name] = cast_tensor_type(arg_value, source_dtype, torch.float)
                    else:
                        new_kwargs[arg_name] = arg_value
            # apply converted arguments to the decorated method
            if TORCH_VERSION != "parrots" and digit_version(TORCH_VERSION) >= digit_version("1.6.0"):
                with autocast(enabled=False):
                    output = old_func(*new_args, **new_kwargs)
            else:
                output = old_func(*new_args, **new_kwargs)
            # cast the results back to fp32 if necessary
            if out_fp16:
                output = cast_tensor_type(output, torch.float, source_dtype)
            return output

        return new_func

    return force_fp32_wrapper
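Usage of these drop-in replacements mirrors mmcv's `auto_fp16`/`force_fp32`: the owning module must set `fp16_enabled`, the listed arguments are cast into the low-precision dtype (bf16 on XPU, fp16 elsewhere) on the way in, and `custom_force_fp32` casts them back to fp32 for numerically sensitive code. A minimal sketch with a hypothetical module that is not part of this commit:

import torch
from torch import nn

from otx.algorithms.common.adapters.mmcv.utils.fp16_utils import custom_auto_fp16, custom_force_fp32


class ToyHead(nn.Module):
    """Hypothetical head used only to illustrate the decorators."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(8, 2)
        self.fp16_enabled = True  # both decorators are no-ops without this flag

    @custom_auto_fp16(apply_to=("feats",))
    def forward(self, feats):
        # `feats` arrives cast to bf16 (XPU) or fp16, and the body runs under autocast
        return self.fc(feats)

    @custom_force_fp32(apply_to=("logits",))
    def loss(self, logits, labels):
        # `logits` is cast back to fp32 so the loss is computed in full precision
        return nn.functional.cross_entropy(logits, labels)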
9 changes: 9 additions & 0 deletions src/otx/algorithms/common/adapters/torch/amp/__init__.py
@@ -0,0 +1,9 @@
"""Custom AMP (Automatic Mixed Precision package) in OTX."""
# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#

try:
    from .xpu_grad_scaler import XPUGradScaler  # noqa: F401
except:  # noqa: E722
    pass
114 changes: 114 additions & 0 deletions src/otx/algorithms/common/adapters/torch/amp/xpu_grad_scaler.py
@@ -0,0 +1,114 @@
"""Custom GradScaler to scale loss."""
# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#

from collections import abc, defaultdict
from typing import List

import torch
from intel_extension_for_pytorch.cpu.autocast._grad_scaler import _MultiDeviceReplicator
from torch.cuda.amp.grad_scaler import GradScaler, _refresh_per_optimizer_state


class XPUGradScaler(GradScaler):
    """GradScaler for XPU."""

    def __init__(self, init_scale=2.0**16, growth_factor=2.0, backoff_factor=0.5, growth_interval=2000, enabled=True):
        self._enabled = enabled

        if self._enabled:
            assert growth_factor > 1.0, "The growth factor must be > 1.0."
            assert backoff_factor < 1.0, "The backoff factor must be < 1.0."

            self._init_scale = init_scale
            # self._scale will be lazily initialized during the first call to scale()
            self._scale = None
            self._growth_factor = growth_factor
            self._backoff_factor = backoff_factor
            self._growth_interval = growth_interval
            self._init_growth_tracker = 0
            # self._growth_tracker will be lazily initialized during the first call to scale()
            self._growth_tracker = None
            self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state)

    def scale(self, outputs):
        """Multiplies ('scales') a tensor or list of tensors by the scale factor.

        Returns scaled outputs. If this instance of :class:`GradScaler` is not enabled, outputs are returned
        unmodified.

        Args:
            outputs (Tensor or iterable of Tensors): Outputs to scale.
        """
        if not self._enabled:
            return outputs

        # Short-circuit for the common case.
        if isinstance(outputs, torch.Tensor):
            assert outputs.device.type == "xpu"
            if self._scale is None:
                self._lazy_init_scale_growth_tracker(outputs.device)
            assert self._scale is not None
            return outputs * self._scale.to(device=outputs.device, non_blocking=True)

        # Invoke the more complex machinery only if we're treating multiple outputs.
        stash: List[_MultiDeviceReplicator] = []  # holds a reference that can be overwritten by apply_scale

        def apply_scale(val):
            if isinstance(val, torch.Tensor):
                assert val.device.type == "xpu"
                if len(stash) == 0:
                    if self._scale is None:
                        self._lazy_init_scale_growth_tracker(val.device)
                    assert self._scale is not None
                    stash.append(_MultiDeviceReplicator(self._scale))
                return val * stash[0].get(val.device)
            elif isinstance(val, abc.Iterable):
                iterable = map(apply_scale, val)
                if isinstance(val, (list, tuple)):
                    return type(val)(iterable)
                else:
                    return iterable
            else:
                raise ValueError("outputs must be a Tensor or an iterable of Tensors")

        return apply_scale(outputs)

    def _unscale_grads_(self, optimizer, inv_scale, found_inf, allow_bf16=False):
        per_device_inv_scale = _MultiDeviceReplicator(inv_scale)
        per_device_found_inf = _MultiDeviceReplicator(found_inf)

        # To set up _amp_foreach_non_finite_check_and_unscale_, split grads by device and dtype.
        # There could be hundreds of grads, so we'd like to iterate through them just once.
        # However, we don't know their devices or dtypes in advance.

        # https://stackoverflow.com/questions/5029934/defaultdict-of-defaultdict
        # Google says mypy struggles with defaultdicts type annotations.
        per_device_and_dtype_grads = defaultdict(lambda: defaultdict(list))  # type: ignore[var-annotated]
        with torch.no_grad():
            for group in optimizer.param_groups:
                for param in group["params"]:
                    if param.grad is None:
                        continue
                    if param.grad.is_sparse:
                        # is_coalesced() == False means the sparse grad has values with duplicate indices.
                        # coalesce() deduplicates indices and adds all values that have the same index.
                        # For scaled bf16 values, there's a good chance coalescing will cause overflow,
                        # so we should check the coalesced _values().
                        if param.grad.dtype is torch.bfloat16:
                            param.grad = param.grad.coalesce()
                        to_unscale = param.grad._values()
                    else:
                        to_unscale = param.grad

                    # TODO: is there a way to split by device and dtype without appending in the inner loop?
                    per_device_and_dtype_grads[to_unscale.device][to_unscale.dtype].append(to_unscale)

            for device, per_dtype_grads in per_device_and_dtype_grads.items():
                for grads in per_dtype_grads.values():
                    torch._amp_foreach_non_finite_check_and_unscale_(
                        grads, per_device_found_inf.get(device), per_device_inv_scale.get(device)
                    )

        return per_device_found_inf._per_device_tensors
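`XPUGradScaler` keeps the `GradScaler` interface it inherits, so a training step looks like standard AMP code. A minimal sketch, assuming an IPEX build with an available `xpu` device (model and data are placeholders, not OTX code):

import torch

from otx.algorithms.common.adapters.torch.amp import XPUGradScaler

model = torch.nn.Linear(8, 2).to("xpu")
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
scaler = XPUGradScaler(init_scale=512.0)

inputs = torch.randn(4, 8, device="xpu")
targets = torch.randint(0, 2, (4,), device="xpu")

with torch.xpu.amp.autocast(enabled=True):  # bf16 autocast on XPU
    loss = torch.nn.functional.cross_entropy(model(inputs), targets)

optimizer.zero_grad()
scaler.scale(loss).backward()  # multiplies the loss by the current scale
scaler.step(optimizer)         # unscales gradients and skips the step on inf/NaN
scaler.update()                # grows or backs off the scale factor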
12 changes: 12 additions & 0 deletions src/otx/algorithms/common/utils/__init__.py
@@ -63,3 +63,15 @@
 if is_hpu_available():
     os.environ["PT_HPU_LAZY_MODE"] = "1"
     import habana_frameworks.torch.gpu_migration  # noqa: F401
+
+
+if is_xpu_available():
+    try:
+        import mmcv
+
+        from otx.algorithms.common.adapters.mmcv.utils.fp16_utils import custom_auto_fp16, custom_force_fp32
+
+        mmcv.runner.auto_fp16 = custom_auto_fp16
+        mmcv.runner.force_fp32 = custom_force_fp32
+    except ImportError:
+        pass
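The net effect of this patch is that, on XPU machines, OTX model code importing `auto_fp16`/`force_fp32` from mmcv after `otx.algorithms.common.utils` has been imported gets the XPU-aware versions. A small sketch of how that could be checked, assuming an XPU environment with mmcv and IPEX installed:

import mmcv.runner

import otx.algorithms.common.utils  # noqa: F401  # applies the patch at import time on XPU
from otx.algorithms.common.adapters.mmcv.utils.fp16_utils import custom_auto_fp16, custom_force_fp32

assert mmcv.runner.auto_fp16 is custom_auto_fp16
assert mmcv.runner.force_fp32 is custom_force_fp32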