diff --git a/src/otx/algorithms/classification/adapters/mmcls/apis/train.py b/src/otx/algorithms/classification/adapters/mmcls/apis/train.py
index de22e38087f..5527e19a847 100644
--- a/src/otx/algorithms/classification/adapters/mmcls/apis/train.py
+++ b/src/otx/algorithms/classification/adapters/mmcls/apis/train.py
@@ -93,13 +93,14 @@ def train_model(model, dataset, cfg, distributed=False, validate=False, timestam
     optimizer = build_optimizer(model, cfg.optimizer)
 
     if cfg.device == "xpu":
-        if fp16_cfg is not None:
-            dtype = torch.bfloat16
-        else:
-            dtype = torch.float32
+        dtype = torch.bfloat16 if cfg.optimizer_config.get("bf16_training", False) else torch.float32
         model.train()
         model, optimizer = torch.xpu.optimize(model, optimizer=optimizer, dtype=dtype)
 
+    if "bf16_training" in cfg.optimizer_config:
+        # Remove unused parameters in runner
+        cfg.optimizer_config.pop("bf16_training")
+
     if cfg.get("runner") is None:
         cfg.runner = {"type": "EpochBasedRunner", "max_epochs": cfg.total_epochs}
         warnings.warn(
diff --git a/src/otx/algorithms/classification/configs/deit_tiny/model.py b/src/otx/algorithms/classification/configs/deit_tiny/model.py
index f69c8cdbbb6..083166b9e49 100644
--- a/src/otx/algorithms/classification/configs/deit_tiny/model.py
+++ b/src/otx/algorithms/classification/configs/deit_tiny/model.py
@@ -14,7 +14,7 @@
     backbone=dict(arch="deit-tiny", init_cfg=dict(type="Pretrained", checkpoint=ckpt_url, prefix="backbone")),
 )
 
-fp16 = dict(loss_scale=512.0)
+fp16 = dict(loss_scale=512.0, bf16_training=False)
 
 optimizer = dict(_delete_=True, type="AdamW", lr=0.01, weight_decay=0.05)
 optimizer_config = dict(_delete_=True)
diff --git a/src/otx/algorithms/classification/configs/efficientnet_b0_cls_incr/model.py b/src/otx/algorithms/classification/configs/efficientnet_b0_cls_incr/model.py
index 4055b7ff90b..c397f704649 100644
--- a/src/otx/algorithms/classification/configs/efficientnet_b0_cls_incr/model.py
+++ b/src/otx/algorithms/classification/configs/efficientnet_b0_cls_incr/model.py
@@ -22,4 +22,4 @@
     ),
 )
 
-fp16 = dict(loss_scale=512.0)
+fp16 = dict(loss_scale=512.0, bf16_training=False)
diff --git a/src/otx/algorithms/classification/configs/efficientnet_v2_s_cls_incr/model.py b/src/otx/algorithms/classification/configs/efficientnet_v2_s_cls_incr/model.py
index b0a278536cd..ea5ef1ff773 100644
--- a/src/otx/algorithms/classification/configs/efficientnet_v2_s_cls_incr/model.py
+++ b/src/otx/algorithms/classification/configs/efficientnet_v2_s_cls_incr/model.py
@@ -16,4 +16,4 @@
     head=dict(type="CustomLinearClsHead", loss=dict(type="CrossEntropyLoss", loss_weight=1.0)),
 )
 
-fp16 = dict(loss_scale=512.0)
+fp16 = dict(loss_scale=512.0, bf16_training=False)
diff --git a/src/otx/algorithms/classification/configs/mobilenet_v3_large_1_cls_incr/model.py b/src/otx/algorithms/classification/configs/mobilenet_v3_large_1_cls_incr/model.py
index 6441b503ee8..f8cbfe4b01a 100644
--- a/src/otx/algorithms/classification/configs/mobilenet_v3_large_1_cls_incr/model.py
+++ b/src/otx/algorithms/classification/configs/mobilenet_v3_large_1_cls_incr/model.py
@@ -22,4 +22,4 @@
     ),
 )
 
-fp16 = dict(loss_scale=512.0)
+fp16 = dict(loss_scale=512.0, bf16_training=False)
diff --git a/src/otx/algorithms/common/adapters/mmcv/configurer.py b/src/otx/algorithms/common/adapters/mmcv/configurer.py
index f013c45c7ef..7342735f43d 100644
--- a/src/otx/algorithms/common/adapters/mmcv/configurer.py
+++ b/src/otx/algorithms/common/adapters/mmcv/configurer.py
@@ -263,7 +263,20 @@ def configure_fp16(cfg: Config):
                     logger.warning("SAMOptimizerHook is not supported on HPU. Changed to OptimizerHook.")
                 opts["type"] = "HPUOptimizerHook"
                 cfg.optimizer_config.update(opts)
-            elif torch.cuda.is_available() or is_xpu_available():
+            elif is_xpu_available():
+                opts.update({"distributed": distributed, **fp16_config})
+                if optim_type == "SAMOptimizerHook":
+                    logger.warning("SAMOptimizerHook is not supported on XPU yet, changed to OptimizerHook.")
+                    opts["type"] = "OptimizerHook"
+                if optim_type == "OptimizerHook":
+                    opts["type"] = "BFp16XPUOptimizerHook"
+                else:
+                    # does not support optimizerhook type
+                    # let mm library handle it
+                    cfg.fp16 = fp16_config
+                    opts = dict()
+                cfg.optimizer_config.update(opts)
+            elif torch.cuda.is_available():
                 opts.update({"distributed": distributed, **fp16_config})
                 if optim_type == "SAMOptimizerHook":
                     opts["type"] = "Fp16SAMOptimizerHook"
diff --git a/src/otx/algorithms/common/adapters/mmcv/hooks/__init__.py b/src/otx/algorithms/common/adapters/mmcv/hooks/__init__.py
index a7c41d80fee..4aed0db6e6d 100644
--- a/src/otx/algorithms/common/adapters/mmcv/hooks/__init__.py
+++ b/src/otx/algorithms/common/adapters/mmcv/hooks/__init__.py
@@ -98,3 +98,10 @@
     __all__ += ["HPUOptimizerHook"]
 except:  # noqa: E722
     pass
+
+try:
+    from .xpu_optimizer_hook import BFp16XPUOptimizerHook
+
+    __all__ += ["BFp16XPUOptimizerHook"]
+except:  # noqa: E722
+    pass
diff --git a/src/otx/algorithms/common/adapters/mmcv/hooks/xpu_optimizer_hook.py b/src/otx/algorithms/common/adapters/mmcv/hooks/xpu_optimizer_hook.py
new file mode 100644
index 00000000000..2f3bd5d944a
--- /dev/null
+++ b/src/otx/algorithms/common/adapters/mmcv/hooks/xpu_optimizer_hook.py
@@ -0,0 +1,38 @@
+"""Custom Optimizer Hook for mixed precision training on XPU."""
+# Copyright (C) 2023 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+
+from typing import Optional, Union
+
+from mmcv.runner.hooks import HOOKS, Fp16OptimizerHook
+
+from otx.algorithms.common.adapters.torch.amp import XPUGradScaler
+
+
+@HOOKS.register_module()
+class BFp16XPUOptimizerHook(Fp16OptimizerHook):
+    """Custom Optimizer Hook for mixed & lower precision training on XPU."""
+
+    def __init__(
+        self,
+        grad_clip: Optional[dict] = None,
+        coalesce: bool = True,
+        bucket_size_mb: int = -1,
+        loss_scale: Union[float, str, dict] = 512.0,
+        distributed: bool = True,
+    ) -> None:
+        self.grad_clip = grad_clip
+        self.coalesce = coalesce
+        self.bucket_size_mb = bucket_size_mb
+        self.distributed = distributed
+        self._scale_update_param = None
+        if loss_scale == "dynamic":
+            self.loss_scaler = XPUGradScaler()
+        elif isinstance(loss_scale, float):
+            self._scale_update_param = loss_scale
+            self.loss_scaler = XPUGradScaler(init_scale=loss_scale)
+        elif isinstance(loss_scale, dict):
+            self.loss_scaler = XPUGradScaler(**loss_scale)
+        else:
+            raise ValueError("loss_scale must be of type float, dict, or " f'"dynamic", got {loss_scale}')
diff --git a/src/otx/algorithms/common/adapters/mmcv/utils/fp16_utils.py b/src/otx/algorithms/common/adapters/mmcv/utils/fp16_utils.py
new file mode 100644
index 00000000000..b5961575db7
--- /dev/null
+++ b/src/otx/algorithms/common/adapters/mmcv/utils/fp16_utils.py
@@ -0,0 +1,133 @@
+"""Custom fp16 related modules to enable XPU modules."""
+# Copyright (C) 2023 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+import functools
+from inspect import getfullargspec
+from typing import Callable, Iterable, Optional
+
+import torch
+from mmcv.runner.fp16_utils import cast_tensor_type
+from mmcv.utils import IS_NPU_AVAILABLE, TORCH_VERSION, digit_version
+from torch import nn
+
+from otx.algorithms.common.utils import is_xpu_available
+
+try:
+    if is_xpu_available():
+        from torch.xpu.amp import autocast
+    elif IS_NPU_AVAILABLE:
+        from torch.npu.amp import autocast
+    else:
+        from torch.cuda.amp import autocast
+except ImportError:
+    pass
+
+
+def custom_auto_fp16(
+    apply_to: Optional[Iterable] = None,
+    out_fp32: bool = False,
+    supported_types: tuple = (nn.Module,),
+) -> Callable:
+    """Custom decorator to enable fp16 training automatically on XPU as well."""
+
+    def auto_fp16_wrapper(old_func: Callable) -> Callable:
+        @functools.wraps(old_func)
+        def new_func(*args, **kwargs) -> Callable:
+            # check if the module has set the attribute `fp16_enabled`, if not,
+            # just fallback to the original method.
+            if not isinstance(args[0], supported_types):
+                raise TypeError(
+                    "@auto_fp16 can only be used to decorate the " f"method of those classes {supported_types}"
+                )
+            if not (hasattr(args[0], "fp16_enabled") and args[0].fp16_enabled):
+                return old_func(*args, **kwargs)
+
+            target_dtype = torch.bfloat16 if is_xpu_available() else torch.half
+            # get the arg spec of the decorated method
+            args_info = getfullargspec(old_func)
+            # get the argument names to be casted
+            args_to_cast = args_info.args if apply_to is None else apply_to
+            # convert the args that need to be processed
+            new_args = []
+            # NOTE: default args are not taken into consideration
+            if args:
+                arg_names = args_info.args[: len(args)]
+                for i, arg_name in enumerate(arg_names):
+                    if arg_name in args_to_cast:
+                        new_args.append(cast_tensor_type(args[i], torch.float, target_dtype))
+                    else:
+                        new_args.append(args[i])
+            # convert the kwargs that need to be processed
+            new_kwargs = {}
+            if kwargs:
+                for arg_name, arg_value in kwargs.items():
+                    if arg_name in args_to_cast:
+                        new_kwargs[arg_name] = cast_tensor_type(arg_value, torch.float, target_dtype)
+                    else:
+                        new_kwargs[arg_name] = arg_value
+            # apply converted arguments to the decorated method
+            if TORCH_VERSION != "parrots" and digit_version(TORCH_VERSION) >= digit_version("1.6.0"):
+                with autocast(enabled=True):
+                    output = old_func(*new_args, **new_kwargs)
+            else:
+                output = old_func(*new_args, **new_kwargs)
+            # cast the results back to fp32 if necessary
+            if out_fp32:
+                output = cast_tensor_type(output, target_dtype, torch.float)
+            return output
+
+        return new_func
+
+    return auto_fp16_wrapper
+
+
+def custom_force_fp32(apply_to: Optional[Iterable] = None, out_fp16: bool = False) -> Callable:
+    """Custom decorator to convert input arguments to fp32 in force on XPU as well."""
+
+    def force_fp32_wrapper(old_func):
+        @functools.wraps(old_func)
+        def new_func(*args, **kwargs) -> Callable:
+            # check if the module has set the attribute `fp16_enabled`, if not,
+            # just fallback to the original method.
+            if not isinstance(args[0], torch.nn.Module):
+                raise TypeError("@force_fp32 can only be used to decorate the " "method of nn.Module")
+            if not (hasattr(args[0], "fp16_enabled") and args[0].fp16_enabled):
+                return old_func(*args, **kwargs)
+
+            source_dtype = torch.bfloat16 if is_xpu_available() else torch.half
+            # get the arg spec of the decorated method
+            args_info = getfullargspec(old_func)
+            # get the argument names to be casted
+            args_to_cast = args_info.args if apply_to is None else apply_to
+            # convert the args that need to be processed
+            new_args = []
+            if args:
+                arg_names = args_info.args[: len(args)]
+                for i, arg_name in enumerate(arg_names):
+                    if arg_name in args_to_cast:
+                        new_args.append(cast_tensor_type(args[i], source_dtype, torch.float))
+                    else:
+                        new_args.append(args[i])
+            # convert the kwargs that need to be processed
+            new_kwargs = dict()
+            if kwargs:
+                for arg_name, arg_value in kwargs.items():
+                    if arg_name in args_to_cast:
+                        new_kwargs[arg_name] = cast_tensor_type(arg_value, source_dtype, torch.float)
+                    else:
+                        new_kwargs[arg_name] = arg_value
+            # apply converted arguments to the decorated method
+            if TORCH_VERSION != "parrots" and digit_version(TORCH_VERSION) >= digit_version("1.6.0"):
+                with autocast(enabled=False):
+                    output = old_func(*new_args, **new_kwargs)
+            else:
+                output = old_func(*new_args, **new_kwargs)
+            # cast the results back to fp32 if necessary
+            if out_fp16:
+                output = cast_tensor_type(output, torch.float, source_dtype)
+            return output
+
+        return new_func
+
+    return force_fp32_wrapper
diff --git a/src/otx/algorithms/common/adapters/torch/amp/__init__.py b/src/otx/algorithms/common/adapters/torch/amp/__init__.py
new file mode 100644
index 00000000000..1b0ce69ed34
--- /dev/null
+++ b/src/otx/algorithms/common/adapters/torch/amp/__init__.py
@@ -0,0 +1,9 @@
+"""Custom AMP (Automatic Mixed Precision package) in OTX."""
+# Copyright (C) 2023 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+#
+
+try:
+    from .xpu_grad_scaler import XPUGradScaler  # noqa: F401
+except:  # noqa: E722
+    pass
diff --git a/src/otx/algorithms/common/adapters/torch/amp/xpu_grad_scaler.py b/src/otx/algorithms/common/adapters/torch/amp/xpu_grad_scaler.py
new file mode 100644
index 00000000000..f3994050cae
--- /dev/null
+++ b/src/otx/algorithms/common/adapters/torch/amp/xpu_grad_scaler.py
@@ -0,0 +1,114 @@
+"""Custom GradScaler to scale loss."""
+# Copyright (C) 2023 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+#
+
+from collections import abc, defaultdict
+from typing import List
+
+import torch
+from intel_extension_for_pytorch.cpu.autocast._grad_scaler import _MultiDeviceReplicator
+from torch.cuda.amp.grad_scaler import GradScaler, _refresh_per_optimizer_state
+
+
+class XPUGradScaler(GradScaler):
+    """GradScaler for XPU."""
+
+    def __init__(self, init_scale=2.0**16, growth_factor=2.0, backoff_factor=0.5, growth_interval=2000, enabled=True):
+        self._enabled = enabled
+
+        if self._enabled:
+            assert growth_factor > 1.0, "The growth factor must be > 1.0."
+            assert backoff_factor < 1.0, "The backoff factor must be < 1.0."
+
+            self._init_scale = init_scale
+            # self._scale will be lazily initialized during the first call to scale()
+            self._scale = None
+            self._growth_factor = growth_factor
+            self._backoff_factor = backoff_factor
+            self._growth_interval = growth_interval
+            self._init_growth_tracker = 0
+            # self._growth_tracker will be lazily initialized during the first call to scale()
+            self._growth_tracker = None
+            self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state)
+
+    def scale(self, outputs):
+        """Multiplies ('scales') a tensor or list of tensors by the scale factor.
+
+        Returns scaled outputs. If this instance of :class:`GradScaler` is not enabled, outputs are returned
+        unmodified.
+
+        Args:
+            outputs (Tensor or iterable of Tensors): Outputs to scale.
+        """
+        if not self._enabled:
+            return outputs
+
+        # Short-circuit for the common case.
+        if isinstance(outputs, torch.Tensor):
+            assert outputs.device.type == "xpu"
+            if self._scale is None:
+                self._lazy_init_scale_growth_tracker(outputs.device)
+            assert self._scale is not None
+            return outputs * self._scale.to(device=outputs.device, non_blocking=True)
+
+        # Invoke the more complex machinery only if we're treating multiple outputs.
+        stash: List[_MultiDeviceReplicator] = []  # holds a reference that can be overwritten by apply_scale
+
+        def apply_scale(val):
+            if isinstance(val, torch.Tensor):
+                assert val.device.type == "xpu"
+                if len(stash) == 0:
+                    if self._scale is None:
+                        self._lazy_init_scale_growth_tracker(val.device)
+                    assert self._scale is not None
+                    stash.append(_MultiDeviceReplicator(self._scale))
+                return val * stash[0].get(val.device)
+            elif isinstance(val, abc.Iterable):
+                iterable = map(apply_scale, val)
+                if isinstance(val, (list, tuple)):
+                    return type(val)(iterable)
+                else:
+                    return iterable
+            else:
+                raise ValueError("outputs must be a Tensor or an iterable of Tensors")
+
+        return apply_scale(outputs)
+
+    def _unscale_grads_(self, optimizer, inv_scale, found_inf, allow_bf16=False):
+        per_device_inv_scale = _MultiDeviceReplicator(inv_scale)
+        per_device_found_inf = _MultiDeviceReplicator(found_inf)
+
+        # To set up _amp_foreach_non_finite_check_and_unscale_, split grads by device and dtype.
+        # There could be hundreds of grads, so we'd like to iterate through them just once.
+        # However, we don't know their devices or dtypes in advance.
+
+        # https://stackoverflow.com/questions/5029934/defaultdict-of-defaultdict
+        # Google says mypy struggles with defaultdicts type annotations.
+        per_device_and_dtype_grads = defaultdict(lambda: defaultdict(list))  # type: ignore[var-annotated]
+        with torch.no_grad():
+            for group in optimizer.param_groups:
+                for param in group["params"]:
+                    if param.grad is None:
+                        continue
+                    if param.grad.is_sparse:
+                        # is_coalesced() == False means the sparse grad has values with duplicate indices.
+                        # coalesce() deduplicates indices and adds all values that have the same index.
+                        # For scaled bf16 values, there's a good chance coalescing will cause overflow,
+                        # so we should check the coalesced _values().
+                        if param.grad.dtype is torch.bfloat16:
+                            param.grad = param.grad.coalesce()
+                        to_unscale = param.grad._values()
+                    else:
+                        to_unscale = param.grad
+
+                    # TODO: is there a way to split by device and dtype without appending in the inner loop?
+                    per_device_and_dtype_grads[to_unscale.device][to_unscale.dtype].append(to_unscale)
+
+        for device, per_dtype_grads in per_device_and_dtype_grads.items():
+            for grads in per_dtype_grads.values():
+                torch._amp_foreach_non_finite_check_and_unscale_(
+                    grads, per_device_found_inf.get(device), per_device_inv_scale.get(device)
+                )
+
+        return per_device_found_inf._per_device_tensors
diff --git a/src/otx/algorithms/common/utils/__init__.py b/src/otx/algorithms/common/utils/__init__.py
index 80372c59b4b..fd5bccd3657 100644
--- a/src/otx/algorithms/common/utils/__init__.py
+++ b/src/otx/algorithms/common/utils/__init__.py
@@ -63,3 +63,15 @@
 if is_hpu_available():
     os.environ["PT_HPU_LAZY_MODE"] = "1"
     import habana_frameworks.torch.gpu_migration  # noqa: F401
+
+
+if is_xpu_available():
+    try:
+        import mmcv
+
+        from otx.algorithms.common.adapters.mmcv.utils.fp16_utils import custom_auto_fp16, custom_force_fp32
+
+        mmcv.runner.auto_fp16 = custom_auto_fp16
+        mmcv.runner.force_fp32 = custom_force_fp32
+    except ImportError:
+        pass
diff --git a/src/otx/algorithms/detection/adapters/mmdet/apis/train.py b/src/otx/algorithms/detection/adapters/mmdet/apis/train.py
index 4565631880d..cd2638ec1ec 100644
--- a/src/otx/algorithms/detection/adapters/mmdet/apis/train.py
+++ b/src/otx/algorithms/detection/adapters/mmdet/apis/train.py
@@ -142,13 +142,14 @@ def train_detector(model, dataset, cfg, distributed=False, validate=False, times
     optimizer = build_optimizer(model, cfg.optimizer)
 
     if cfg.device == "xpu":
-        if fp16_cfg is not None:
-            dtype = torch.bfloat16
-        else:
-            dtype = torch.float32
+        dtype = torch.bfloat16 if cfg.optimizer_config.get("bf16_training", False) else torch.float32
         model.train()
         model, optimizer = torch.xpu.optimize(model, optimizer=optimizer, dtype=dtype)
 
+    if "bf16_training" in cfg.optimizer_config:
+        # Remove unused parameters in runner
+        cfg.optimizer_config.pop("bf16_training")
+
     runner = build_runner(
         cfg.runner, default_args=dict(model=model, optimizer=optimizer, work_dir=cfg.work_dir, logger=logger, meta=meta)
     )
diff --git a/src/otx/algorithms/detection/adapters/mmdet/models/detectors/custom_yolox_detector.py b/src/otx/algorithms/detection/adapters/mmdet/models/detectors/custom_yolox_detector.py
index b53cf2777db..c8658c489f6 100644
--- a/src/otx/algorithms/detection/adapters/mmdet/models/detectors/custom_yolox_detector.py
+++ b/src/otx/algorithms/detection/adapters/mmdet/models/detectors/custom_yolox_detector.py
@@ -58,9 +58,6 @@ def forward_train(self, img, img_metas, gt_bboxes, gt_labels, gt_bboxes_ignore=N
 
     def extract_feat(self, img):
         """Directly extract features from the backbone+neck."""
-        # workaround for xpu device, since the input converted to fp16 by mmcv
-        if "xpu" in str(img.device) and img.dtype == torch.float16:
-            img = img.to(torch.bfloat16)
         x = self.backbone(img)
         if self.with_neck:
             x = self.neck(x)
diff --git a/src/otx/algorithms/detection/configs/detection/cspdarknet_yolox_l/model.py b/src/otx/algorithms/detection/configs/detection/cspdarknet_yolox_l/model.py
index dd300926d08..2d20a9ede43 100644
--- a/src/otx/algorithms/detection/configs/detection/cspdarknet_yolox_l/model.py
+++ b/src/otx/algorithms/detection/configs/detection/cspdarknet_yolox_l/model.py
@@ -20,5 +20,5 @@
 load_from = "https://download.openmmlab.com/mmdetection/v2.0/yolox/\
 yolox_l_8x8_300e_coco/yolox_l_8x8_300e_coco_20211126_140236-d3bd2b23.pth"
 
-fp16 = dict(loss_scale=512.0)
+fp16 = dict(loss_scale=512.0, bf16_training=False)
 ignore = False
diff --git a/src/otx/algorithms/detection/configs/detection/cspdarknet_yolox_s/model.py b/src/otx/algorithms/detection/configs/detection/cspdarknet_yolox_s/model.py
index 1a0356f69c5..58114e4f173 100644
--- a/src/otx/algorithms/detection/configs/detection/cspdarknet_yolox_s/model.py
+++ b/src/otx/algorithms/detection/configs/detection/cspdarknet_yolox_s/model.py
@@ -20,5 +20,5 @@
 load_from = "https://download.openmmlab.com/mmdetection/v2.0/yolox/\
 yolox_s_8x8_300e_coco/yolox_s_8x8_300e_coco_20211121_095711-4592a793.pth"
 
-fp16 = dict(loss_scale=512.0)
+fp16 = dict(loss_scale=512.0, bf16_training=False)
 ignore = False
diff --git a/src/otx/algorithms/detection/configs/detection/cspdarknet_yolox_tiny/model.py b/src/otx/algorithms/detection/configs/detection/cspdarknet_yolox_tiny/model.py
index d8c806e3c57..6d88d486b7e 100644
--- a/src/otx/algorithms/detection/configs/detection/cspdarknet_yolox_tiny/model.py
+++ b/src/otx/algorithms/detection/configs/detection/cspdarknet_yolox_tiny/model.py
@@ -31,5 +31,5 @@
 load_from = "https://storage.openvinotoolkit.org/repositories/openvino_training_extensions\
 /models/object_detection/v2/yolox_tiny_8x8.pth"
 
-fp16 = dict(loss_scale=512.0)
+fp16 = dict(loss_scale=512.0, bf16_training=False)
 ignore = False
diff --git a/src/otx/algorithms/detection/configs/detection/cspdarknet_yolox_x/model.py b/src/otx/algorithms/detection/configs/detection/cspdarknet_yolox_x/model.py
index 857021810d1..e8c197507bc 100644
--- a/src/otx/algorithms/detection/configs/detection/cspdarknet_yolox_x/model.py
+++ b/src/otx/algorithms/detection/configs/detection/cspdarknet_yolox_x/model.py
@@ -20,5 +20,5 @@
 load_from = "https://download.openmmlab.com/mmdetection/v2.0/yolox\
 /yolox_x_8x8_300e_coco/yolox_x_8x8_300e_coco_20211126_140254-1ef88d67.pth"
 
-fp16 = dict(loss_scale=512.0)
+fp16 = dict(loss_scale=512.0, bf16_training=False)
 ignore = False
diff --git a/src/otx/algorithms/detection/configs/detection/mobilenetv2_atss/model.py b/src/otx/algorithms/detection/configs/detection/mobilenetv2_atss/model.py
index 2b0e1d9d71c..fccdd79317b 100644
--- a/src/otx/algorithms/detection/configs/detection/mobilenetv2_atss/model.py
+++ b/src/otx/algorithms/detection/configs/detection/mobilenetv2_atss/model.py
@@ -88,4 +88,4 @@
 load_from = "https://storage.openvinotoolkit.org/repositories/openvino_training_extensions\
 /models/object_detection/v2/mobilenet_v2-atss.pth"
 
-fp16 = dict(loss_scale=512.0)
+fp16 = dict(loss_scale=512.0, bf16_training=False)
diff --git a/src/otx/algorithms/detection/configs/detection/mobilenetv2_ssd/model.py b/src/otx/algorithms/detection/configs/detection/mobilenetv2_ssd/model.py
index 45847b0b80c..4c88b23d60a 100644
--- a/src/otx/algorithms/detection/configs/detection/mobilenetv2_ssd/model.py
+++ b/src/otx/algorithms/detection/configs/detection/mobilenetv2_ssd/model.py
@@ -95,5 +95,5 @@
 load_from = "https://storage.openvinotoolkit.org/repositories/openvino_training_extensions\
 /models/object_detection/v2/mobilenet_v2-2s_ssd-992x736.pth"
 
-fp16 = dict(loss_scale=512.0)
+fp16 = dict(loss_scale=512.0, bf16_training=False)
 ignore = False
diff --git a/src/otx/algorithms/detection/configs/detection/resnext101_atss/model.py b/src/otx/algorithms/detection/configs/detection/resnext101_atss/model.py
index 579e382adb7..53f592fb229 100644
--- a/src/otx/algorithms/detection/configs/detection/resnext101_atss/model.py
+++ b/src/otx/algorithms/detection/configs/detection/resnext101_atss/model.py
@@ -80,4 +80,4 @@
 load_from = "https://storage.openvinotoolkit.org/repositories/openvino_training_extensions/\
 models/object_detection/v2/resnext101_atss_070623.pth"
 
-fp16 = dict(loss_scale=512.0)
+fp16 = dict(loss_scale=512.0, bf16_training=False)
diff --git a/src/otx/algorithms/detection/configs/instance_segmentation/efficientnetb2b_maskrcnn/model.py b/src/otx/algorithms/detection/configs/instance_segmentation/efficientnetb2b_maskrcnn/model.py
index 03cb21733dc..65e751b21dc 100644
--- a/src/otx/algorithms/detection/configs/instance_segmentation/efficientnetb2b_maskrcnn/model.py
+++ b/src/otx/algorithms/detection/configs/instance_segmentation/efficientnetb2b_maskrcnn/model.py
@@ -127,5 +127,5 @@
 v2/efficientnet_b2b-mask_rcnn-576x576.pth"
 evaluation = dict(interval=1, metric="mAP", save_best="mAP", iou_thr=[0.5])
 
-fp16 = dict(loss_scale=512.0)
+fp16 = dict(loss_scale=512.0, bf16_training=False)
 ignore = True
diff --git a/src/otx/algorithms/detection/configs/instance_segmentation/maskrcnn_swin_t/model.py b/src/otx/algorithms/detection/configs/instance_segmentation/maskrcnn_swin_t/model.py
index 203470d2fac..9c41dd65052 100644
--- a/src/otx/algorithms/detection/configs/instance_segmentation/maskrcnn_swin_t/model.py
+++ b/src/otx/algorithms/detection/configs/instance_segmentation/maskrcnn_swin_t/model.py
@@ -153,7 +153,7 @@
 
 optimizer_config = dict(_delete_=True, grad_clip=None)
 
-fp16 = dict(loss_scale=dict(init_scale=512))
+fp16 = dict(loss_scale=dict(init_scale=512), bf16_training=False)
 
 load_from = (
     "https://download.openmmlab.com/mmdetection/v2.0/swin/"
diff --git a/src/otx/algorithms/segmentation/adapters/mmseg/apis/train.py b/src/otx/algorithms/segmentation/adapters/mmseg/apis/train.py
index 6ed2ec50dc8..77aba2c8066 100644
--- a/src/otx/algorithms/segmentation/adapters/mmseg/apis/train.py
+++ b/src/otx/algorithms/segmentation/adapters/mmseg/apis/train.py
@@ -92,14 +92,14 @@ def train_segmentor(model, dataset, cfg, distributed=False, validate=False, time
     optimizer = build_optimizer(model, cfg.optimizer)
 
     if cfg.device == "xpu":
-        fp16_cfg = cfg.get("fp16_", None)
-        if fp16_cfg is not None:
-            dtype = torch.bfloat16
-        else:
-            dtype = torch.float32
+        dtype = torch.bfloat16 if cfg.optimizer_config.get("bf16_training", False) else torch.float32
         model.train()
         model, optimizer = torch.xpu.optimize(model, optimizer=optimizer, dtype=dtype)
 
+    if "bf16_training" in cfg.optimizer_config:
+        # Remove unused parameters in runner
+        cfg.optimizer_config.pop("bf16_training")
+
     if cfg.get("runner") is None:
         cfg.runner = {"type": "IterBasedRunner", "max_iters": cfg.total_iters}
         warnings.warn(