From d83f493f839d93ca10d4d515891389c536a66338 Mon Sep 17 00:00:00 2001
From: Dmitry Rogozhkin
Date: Fri, 31 May 2024 19:16:32 +0000
Subject: [PATCH] xpu: support xpu backend from stock pytorch (>=2.4)

Fixes: https://github.com/huggingface/transformers/issues/31237

The XPU backend is available in stock PyTorch starting from version 2.4,
see [1]. This commit extends huggingface accelerate to support XPU
coming from either IPEX or stock PyTorch. IPEX is tried first.

[1] https://github.com/pytorch/pytorch/issues/114842

Signed-off-by: Dmitry Rogozhkin
---
 src/accelerate/accelerator.py   | 28 +++++++++++++---------------
 src/accelerate/utils/imports.py | 18 ++++++++++++------
 2 files changed, 25 insertions(+), 21 deletions(-)

diff --git a/src/accelerate/accelerator.py b/src/accelerate/accelerator.py
index ef95c5ec049..04edc60f281 100755
--- a/src/accelerate/accelerator.py
+++ b/src/accelerate/accelerator.py
@@ -78,7 +78,6 @@
     is_bf16_available,
     is_deepspeed_available,
     is_fp8_available,
-    is_ipex_available,
     is_lomo_available,
     is_megatron_lm_available,
     is_mlu_available,
@@ -473,6 +472,8 @@ def __init__(
                 self.scaler = torch.mlu.amp.GradScaler(**kwargs)
             elif is_npu_available():
                 self.scaler = torch.npu.amp.GradScaler(**kwargs)
+            elif is_xpu_available():
+                self.scaler = torch.amp.GradScaler("xpu", **kwargs)
             else:
                 self.scaler = torch.cuda.amp.GradScaler(**kwargs)

@@ -1284,9 +1285,9 @@ def prepare(self, *args, device_placement=None):

         if self.distributed_type in [DistributedType.MULTI_CPU, DistributedType.MULTI_XPU, DistributedType.NO]:
             if self.device.type == "cpu" and self.state.use_ipex:
-                args = self._prepare_ipex(*args)
+                args = self._prepare_xpu(*args)
             elif self.device.type == "xpu" and is_xpu_available():
-                args = self._prepare_ipex(*args)
+                args = self._prepare_xpu(*args)
         if self.distributed_type == DistributedType.DEEPSPEED:
             result = self._prepare_deepspeed(*args)
         elif self.distributed_type == DistributedType.MEGATRON_LM:
@@ -1887,14 +1888,11 @@ def _prepare_megatron_lm(self, *args):
         )
         return tuple(result)

-    def _prepare_ipex(self, *args):
-        if not is_ipex_available():
+    def _prepare_xpu(self, *args):
+        if not is_xpu_available():
             raise ImportError(
-                "IPEX is not installed or IPEX's version does not match current PyTorch version. Please refer"
-                " to https://github.com/intel/intel-extension-for-pytorch."
+ "XPU is not available neither from IPEX nor from stock PyTorch (>=2.4)" ) - else: - import intel_extension_for_pytorch as ipex model = None optimizer = None @@ -1907,13 +1905,13 @@ def _prepare_ipex(self, *args): optimizer = obj if optimizer is not None and model is not None: dtype = torch.bfloat16 if self.state.mixed_precision == "bf16" else None - if self.device.type == "xpu" and is_xpu_available(): + if self.device.type == "xpu": model = model.to(self.device) - model, optimizer = torch.xpu.optimize( - model, optimizer=optimizer, dtype=dtype, inplace=True, level="O1" - ) - else: - model, optimizer = ipex.optimize(model, optimizer=optimizer, dtype=dtype, inplace=True, level="O1") + # torch.xpu.optimize is available only for xpu via IPEX + if hasattr(torch.xpu, "optimize"): + model, optimizer = torch.xpu.optimize( + model, optimizer=optimizer, dtype=dtype, inplace=True, level="O1" + ) for i in range(len(result)): if isinstance(result[i], torch.nn.Module): result[i] = model diff --git a/src/accelerate/utils/imports.py b/src/accelerate/utils/imports.py index 5c204d15d67..630875e1488 100644 --- a/src/accelerate/utils/imports.py +++ b/src/accelerate/utils/imports.py @@ -374,16 +374,22 @@ def is_xpu_available(check_device=False): "check if user disables it explicitly" if not parse_flag_from_env("ACCELERATE_USE_XPU", default=True): return False - "Checks if `intel_extension_for_pytorch` is installed and potentially if a XPU is in the environment" - if is_ipex_available(): - import torch + """ + Checks if XPU acceleration is available either via `intel_extension_for_pytorch` or + via stock PyTorch (>=2.4) and potentially if a XPU is in the environment + """ + if importlib.util.find_spec("torch") is None: + return False + import torch + if is_ipex_available(): if is_torch_version("<=", "1.12"): return False - else: - return False - import intel_extension_for_pytorch # noqa: F401 + import intel_extension_for_pytorch # noqa: F401 + else: + if is_torch_version("<=", "2.3"): + return False if check_device: try: