Fix mixed & lower precision training (#2668)
* Remove dtype argument in torch.xpu.optimize
* Add `custom_auto_fp16` to use XPU autocast
* Update `forward`s to use `custom_auto_fp16`
* Run pre-commit
* Disable FP16 training
* Add `custom_force_fp32`
* Remove forced casting of tensors to bf16
* Add `XPUOptimizerHook` and `XPUGradScaler`
* Run pre-commit
* Enable lower precision training
* Remove dtype check for lower precision
* Add `bf16_training` to recipes
* Fix
* Remove unused module
* Rename `XPUOptimizerHook` to `BFp16XPUOptimizerHook`
* Fix handling for common devices that don't use bf16
* Run pre-commit
* Allow using `auto_fp16` as-is
* Add try-except to avoid mmcv import error
* Fix error type

Co-authored-by: Eunwoo Shin <eunwoo.shin@intel.com>
Showing 24 changed files with 355 additions and 30 deletions.
@@ -22,4 +22,4 @@
 ),
 )

-fp16 = dict(loss_scale=512.0)
+fp16 = dict(loss_scale=512.0, bf16_training=False)
The same one-line change is applied to a second recipe config:

@@ -22,4 +22,4 @@
 ),
 )

-fp16 = dict(loss_scale=512.0)
+fp16 = dict(loss_scale=512.0, bf16_training=False)
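The `bf16_training` key added here is the recipe-level switch mentioned in the commit message ("Add `bf16_training` to recipes"); the shipped recipes keep it disabled. As a heavily hedged sketch, and purely an assumption since this diff does not show how the flag is consumed, opting a recipe into bfloat16 lower-precision training would presumably look like:

fp16 = dict(loss_scale=512.0, bf16_training=True)  # assumption: True opts the XPU path into bfloat16 training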
src/otx/algorithms/common/adapters/mmcv/hooks/xpu_optimizer_hook.py (38 additions, 0 deletions)
@@ -0,0 +1,38 @@
"""Custom Optimizer Hook for mixed precision training on XPU."""
# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from typing import Optional, Union

from mmcv.runner.hooks import HOOKS, Fp16OptimizerHook

from otx.algorithms.common.adapters.torch.amp import XPUGradScaler


@HOOKS.register_module()
class BFp16XPUOptimizerHook(Fp16OptimizerHook):
    """Custom Optimizer Hook for mixed & lower precision training on XPU."""

    def __init__(
        self,
        grad_clip: Optional[dict] = None,
        coalesce: bool = True,
        bucket_size_mb: int = -1,
        loss_scale: Union[float, str, dict] = 512.0,
        distributed: bool = True,
    ) -> None:
        self.grad_clip = grad_clip
        self.coalesce = coalesce
        self.bucket_size_mb = bucket_size_mb
        self.distributed = distributed
        self._scale_update_param = None
        if loss_scale == "dynamic":
            self.loss_scaler = XPUGradScaler()
        elif isinstance(loss_scale, float):
            self._scale_update_param = loss_scale
            self.loss_scaler = XPUGradScaler(init_scale=loss_scale)
        elif isinstance(loss_scale, dict):
            self.loss_scaler = XPUGradScaler(**loss_scale)
        else:
            raise ValueError("loss_scale must be of type float, dict, or " f'"dynamic", got {loss_scale}')
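A hedged usage sketch, not part of this commit: because the hook is registered in mmcv's `HOOKS` registry above, a trainer config can select it by type name, with keys mirroring the constructor arguments. Where exactly this dict sits in an OTX/mmcv config is assumed here.

# Assumed mmcv-style config snippet selecting the XPU hook; the hook is built through the HOOKS registry.
optimizer_config = dict(
    type="BFp16XPUOptimizerHook",
    loss_scale=512.0,    # a float sets a static scale: XPUGradScaler(init_scale=512.0)
    distributed=False,   # single-device training
)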
src/otx/algorithms/common/adapters/mmcv/utils/fp16_utils.py (133 additions, 0 deletions)
@@ -0,0 +1,133 @@
"""Custom fp16 related modules to enable XPU modules."""
# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import functools
from inspect import getfullargspec
from typing import Callable, Iterable, Optional

import torch
from mmcv.runner.fp16_utils import cast_tensor_type
from mmcv.utils import IS_NPU_AVAILABLE, TORCH_VERSION, digit_version
from torch import nn

from otx.algorithms.common.utils import is_xpu_available

try:
    if is_xpu_available():
        from torch.xpu.amp import autocast
    elif IS_NPU_AVAILABLE:
        from torch.npu.amp import autocast
    else:
        from torch.cuda.amp import autocast
except ImportError:
    pass


def custom_auto_fp16(
    apply_to: Optional[Iterable] = None,
    out_fp32: bool = False,
    supported_types: tuple = (nn.Module,),
) -> Callable:
    """Custom decorator to enable fp16 training automatically on XPU as well."""

    def auto_fp16_wrapper(old_func: Callable) -> Callable:
        @functools.wraps(old_func)
        def new_func(*args, **kwargs) -> Callable:
            # check if the module has set the attribute `fp16_enabled`, if not,
            # just fallback to the original method.
            if not isinstance(args[0], supported_types):
                raise TypeError(
                    "@auto_fp16 can only be used to decorate the " f"method of those classes {supported_types}"
                )
            if not (hasattr(args[0], "fp16_enabled") and args[0].fp16_enabled):
                return old_func(*args, **kwargs)

            target_dtype = torch.bfloat16 if is_xpu_available() else torch.half
            # get the arg spec of the decorated method
            args_info = getfullargspec(old_func)
            # get the argument names to be casted
            args_to_cast = args_info.args if apply_to is None else apply_to
            # convert the args that need to be processed
            new_args = []
            # NOTE: default args are not taken into consideration
            if args:
                arg_names = args_info.args[: len(args)]
                for i, arg_name in enumerate(arg_names):
                    if arg_name in args_to_cast:
                        new_args.append(cast_tensor_type(args[i], torch.float, target_dtype))
                    else:
                        new_args.append(args[i])
            # convert the kwargs that need to be processed
            new_kwargs = {}
            if kwargs:
                for arg_name, arg_value in kwargs.items():
                    if arg_name in args_to_cast:
                        new_kwargs[arg_name] = cast_tensor_type(arg_value, torch.float, target_dtype)
                    else:
                        new_kwargs[arg_name] = arg_value
            # apply converted arguments to the decorated method
            if TORCH_VERSION != "parrots" and digit_version(TORCH_VERSION) >= digit_version("1.6.0"):
                with autocast(enabled=True):
                    output = old_func(*new_args, **new_kwargs)
            else:
                output = old_func(*new_args, **new_kwargs)
            # cast the results back to fp32 if necessary
            if out_fp32:
                output = cast_tensor_type(output, target_dtype, torch.float)
            return output

        return new_func

    return auto_fp16_wrapper


def custom_force_fp32(apply_to: Optional[Iterable] = None, out_fp16: bool = False) -> Callable:
    """Custom decorator to convert input arguments to fp32 in force on XPU as well."""

    def force_fp32_wrapper(old_func):
        @functools.wraps(old_func)
        def new_func(*args, **kwargs) -> Callable:
            # check if the module has set the attribute `fp16_enabled`, if not,
            # just fallback to the original method.
            if not isinstance(args[0], torch.nn.Module):
                raise TypeError("@force_fp32 can only be used to decorate the " "method of nn.Module")
            if not (hasattr(args[0], "fp16_enabled") and args[0].fp16_enabled):
                return old_func(*args, **kwargs)

            source_dtype = torch.bfloat16 if is_xpu_available() else torch.half
            # get the arg spec of the decorated method
            args_info = getfullargspec(old_func)
            # get the argument names to be casted
            args_to_cast = args_info.args if apply_to is None else apply_to
            # convert the args that need to be processed
            new_args = []
            if args:
                arg_names = args_info.args[: len(args)]
                for i, arg_name in enumerate(arg_names):
                    if arg_name in args_to_cast:
                        new_args.append(cast_tensor_type(args[i], source_dtype, torch.float))
                    else:
                        new_args.append(args[i])
            # convert the kwargs that need to be processed
            new_kwargs = dict()
            if kwargs:
                for arg_name, arg_value in kwargs.items():
                    if arg_name in args_to_cast:
                        new_kwargs[arg_name] = cast_tensor_type(arg_value, source_dtype, torch.float)
                    else:
                        new_kwargs[arg_name] = arg_value
            # apply converted arguments to the decorated method
            if TORCH_VERSION != "parrots" and digit_version(TORCH_VERSION) >= digit_version("1.6.0"):
                with autocast(enabled=False):
                    output = old_func(*new_args, **new_kwargs)
            else:
                output = old_func(*new_args, **new_kwargs)
            # cast the results back to fp32 if necessary
            if out_fp16:
                output = cast_tensor_type(output, torch.float, source_dtype)
            return output

        return new_func

    return force_fp32_wrapper
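A hedged usage sketch of the two decorators; the module below is hypothetical and not part of the commit. `custom_auto_fp16` casts the listed inputs to bfloat16 on XPU (half elsewhere) and runs the method under autocast, while `custom_force_fp32` casts them back to fp32 and disables autocast for numerically sensitive code such as losses. Neither decorator casts anything until `fp16_enabled` is set to True on the module.

import torch
from torch import nn

from otx.algorithms.common.adapters.mmcv.utils.fp16_utils import custom_auto_fp16, custom_force_fp32


class ToyHead(nn.Module):
    """Hypothetical head used only to illustrate the decorators."""

    def __init__(self) -> None:
        super().__init__()
        self.fp16_enabled = False  # flipped to True when mixed precision training is enabled
        self.fc = nn.Linear(16, 4)

    @custom_auto_fp16(apply_to=("feats",))
    def forward(self, feats: torch.Tensor) -> torch.Tensor:
        # `feats` arrives as bfloat16 on XPU (fp16 on CUDA/NPU) and the body runs under autocast
        return self.fc(feats)

    @custom_force_fp32(apply_to=("cls_score",))
    def loss(self, cls_score: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        # `cls_score` is cast back to fp32 and autocast is disabled around the loss computation
        return nn.functional.cross_entropy(cls_score, labels)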
@@ -0,0 +1,9 @@
"""Custom AMP (Automatic Mixed Precision package) in OTX."""
# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#

try:
    from .xpu_grad_scaler import XPUGradScaler  # noqa: F401
except:  # noqa: E722
    pass
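This try/except re-export (presumably the `__init__` of the `otx.algorithms.common.adapters.torch.amp` package, judging from the relative import and the absolute import used by the hook above; the file path is not shown in this view) keeps OTX importable on machines without Intel Extension for PyTorch: the package loads, but the name is simply absent. A hedged sketch of how downstream code could probe for it:

# Assumed downstream pattern, not from this commit: detect whether the XPU scaler is available.
try:
    from otx.algorithms.common.adapters.torch.amp import XPUGradScaler
    HAS_XPU_SCALER = True
except ImportError:  # raised when intel_extension_for_pytorch (and thus the scaler) is unavailable
    XPUGradScaler = None  # type: ignore[assignment,misc]
    HAS_XPU_SCALER = False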
src/otx/algorithms/common/adapters/torch/amp/xpu_grad_scaler.py (114 additions, 0 deletions)
@@ -0,0 +1,114 @@
"""Custom GradScaler to scale loss."""
# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#

from collections import abc, defaultdict
from typing import List

import torch
from intel_extension_for_pytorch.cpu.autocast._grad_scaler import _MultiDeviceReplicator
from torch.cuda.amp.grad_scaler import GradScaler, _refresh_per_optimizer_state


class XPUGradScaler(GradScaler):
    """GradScaler for XPU."""

    def __init__(self, init_scale=2.0**16, growth_factor=2.0, backoff_factor=0.5, growth_interval=2000, enabled=True):
        self._enabled = enabled

        if self._enabled:
            assert growth_factor > 1.0, "The growth factor must be > 1.0."
            assert backoff_factor < 1.0, "The backoff factor must be < 1.0."

            self._init_scale = init_scale
            # self._scale will be lazily initialized during the first call to scale()
            self._scale = None
            self._growth_factor = growth_factor
            self._backoff_factor = backoff_factor
            self._growth_interval = growth_interval
            self._init_growth_tracker = 0
            # self._growth_tracker will be lazily initialized during the first call to scale()
            self._growth_tracker = None
            self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state)

    def scale(self, outputs):
        """Multiplies ('scales') a tensor or list of tensors by the scale factor.

        Returns scaled outputs. If this instance of :class:`GradScaler` is not enabled, outputs are returned
        unmodified.

        Args:
            outputs (Tensor or iterable of Tensors): Outputs to scale.
        """
        if not self._enabled:
            return outputs

        # Short-circuit for the common case.
        if isinstance(outputs, torch.Tensor):
            assert outputs.device.type == "xpu"
            if self._scale is None:
                self._lazy_init_scale_growth_tracker(outputs.device)
            assert self._scale is not None
            return outputs * self._scale.to(device=outputs.device, non_blocking=True)

        # Invoke the more complex machinery only if we're treating multiple outputs.
        stash: List[_MultiDeviceReplicator] = []  # holds a reference that can be overwritten by apply_scale

        def apply_scale(val):
            if isinstance(val, torch.Tensor):
                assert val.device.type == "xpu"
                if len(stash) == 0:
                    if self._scale is None:
                        self._lazy_init_scale_growth_tracker(val.device)
                    assert self._scale is not None
                    stash.append(_MultiDeviceReplicator(self._scale))
                return val * stash[0].get(val.device)
            elif isinstance(val, abc.Iterable):
                iterable = map(apply_scale, val)
                if isinstance(val, (list, tuple)):
                    return type(val)(iterable)
                else:
                    return iterable
            else:
                raise ValueError("outputs must be a Tensor or an iterable of Tensors")

        return apply_scale(outputs)

    def _unscale_grads_(self, optimizer, inv_scale, found_inf, allow_bf16=False):
        per_device_inv_scale = _MultiDeviceReplicator(inv_scale)
        per_device_found_inf = _MultiDeviceReplicator(found_inf)

        # To set up _amp_foreach_non_finite_check_and_unscale_, split grads by device and dtype.
        # There could be hundreds of grads, so we'd like to iterate through them just once.
        # However, we don't know their devices or dtypes in advance.

        # https://stackoverflow.com/questions/5029934/defaultdict-of-defaultdict
        # Google says mypy struggles with defaultdicts type annotations.
        per_device_and_dtype_grads = defaultdict(lambda: defaultdict(list))  # type: ignore[var-annotated]
        with torch.no_grad():
            for group in optimizer.param_groups:
                for param in group["params"]:
                    if param.grad is None:
                        continue
                    if param.grad.is_sparse:
                        # is_coalesced() == False means the sparse grad has values with duplicate indices.
                        # coalesce() deduplicates indices and adds all values that have the same index.
                        # For scaled bf16 values, there's a good chance coalescing will cause overflow,
                        # so we should check the coalesced _values().
                        if param.grad.dtype is torch.bfloat16:
                            param.grad = param.grad.coalesce()
                        to_unscale = param.grad._values()
                    else:
                        to_unscale = param.grad

                    # TODO: is there a way to split by device and dtype without appending in the inner loop?
                    per_device_and_dtype_grads[to_unscale.device][to_unscale.dtype].append(to_unscale)

            for device, per_dtype_grads in per_device_and_dtype_grads.items():
                for grads in per_dtype_grads.values():
                    torch._amp_foreach_non_finite_check_and_unscale_(
                        grads, per_device_found_inf.get(device), per_device_inv_scale.get(device)
                    )

        return per_device_found_inf._per_device_tensors
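For orientation, a minimal training-step sketch using the scaler (assumed usage, not part of this commit): it follows the standard `GradScaler` pattern the class inherits, requires intel_extension_for_pytorch so that the `"xpu"` device and `torch.xpu.amp.autocast` exist, and the toy model and data are placeholders.

import torch
from torch import nn

from otx.algorithms.common.adapters.torch.amp import XPUGradScaler

model = nn.Linear(8, 2).to("xpu")
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
scaler = XPUGradScaler(init_scale=512.0)

for _ in range(10):
    inputs = torch.randn(4, 8, device="xpu")
    targets = torch.randn(4, 2, device="xpu")
    optimizer.zero_grad()
    with torch.xpu.amp.autocast(enabled=True):
        loss = nn.functional.mse_loss(model(inputs), targets)
    scaler.scale(loss).backward()  # scale() asserts the loss tensor lives on an XPU device
    scaler.step(optimizer)         # unscales gradients; skips the step if inf/NaN gradients were found
    scaler.update()                # grows or backs off the scale factor for the next iteration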