xpu: support xpu backend from stock pytorch (>=2.4)
Fixes: huggingface/transformers#31237

The XPU backend is available in stock PyTorch starting from
version 2.4, see [1]. This commit extends Hugging Face Accelerate
to support XPU from both IPEX and stock PyTorch. IPEX is tried
first.

See [1]: pytorch/pytorch#114842
Signed-off-by: Dmitry Rogozhkin <dmitry.v.rogozhkin@intel.com>
dvrogozh committed Jun 5, 2024
1 parent b7fa2fa commit d83f493
Showing 2 changed files with 25 additions and 21 deletions.
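
For orientation before the diff: a minimal sketch, not part of the commit, of what the stock XPU backend provides. In PyTorch >= 2.4, `torch.xpu.is_available()` reports whether an Intel GPU is usable without IPEX installed (the upstreaming work is tracked in pytorch/pytorch#114842).

    import torch

    # Stock PyTorch >= 2.4 exposes torch.xpu, mirroring torch.cuda; no IPEX import needed.
    device = torch.device("xpu" if torch.xpu.is_available() else "cpu")
    x = torch.randn(8, 8, device=device)  # allocates on the Intel GPU when one is present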
28 changes: 13 additions & 15 deletions src/accelerate/accelerator.py
@@ -78,7 +78,6 @@
     is_bf16_available,
     is_deepspeed_available,
     is_fp8_available,
-    is_ipex_available,
     is_lomo_available,
     is_megatron_lm_available,
     is_mlu_available,
@@ -473,6 +472,8 @@ def __init__(
             elif is_mlu_available():
                 self.scaler = torch.mlu.amp.GradScaler(**kwargs)
             elif is_npu_available():
                 self.scaler = torch.npu.amp.GradScaler(**kwargs)
+            elif is_xpu_available():
+                self.scaler = torch.amp.GradScaler('xpu', **kwargs)
             else:
                 self.scaler = torch.cuda.amp.GradScaler(**kwargs)

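The `torch.amp.GradScaler('xpu', ...)` call above relies on the device-generic scaler constructor in recent PyTorch. A hedged training-step sketch, assuming PyTorch >= 2.4 with an XPU present; the model, optimizer, and inputs are placeholders:

    import torch

    model = torch.nn.Linear(16, 16).to("xpu")
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    inputs = torch.randn(4, 16, device="xpu")

    scaler = torch.amp.GradScaler("xpu")  # device-generic form used by the hunk above
    with torch.autocast(device_type="xpu", dtype=torch.float16):
        loss = model(inputs).sum()        # forward pass under autocast
    scaler.scale(loss).backward()         # scale the loss so fp16 gradients do not underflow
    scaler.step(optimizer)                # unscales gradients, then calls optimizer.step()
    scaler.update()                       # adapt the scale factor for the next iteration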
@@ -1284,9 +1285,9 @@ def prepare(self, *args, device_placement=None):

         if self.distributed_type in [DistributedType.MULTI_CPU, DistributedType.MULTI_XPU, DistributedType.NO]:
             if self.device.type == "cpu" and self.state.use_ipex:
-                args = self._prepare_ipex(*args)
+                args = self._prepare_xpu(*args)
             elif self.device.type == "xpu" and is_xpu_available():
-                args = self._prepare_ipex(*args)
+                args = self._prepare_xpu(*args)
         if self.distributed_type == DistributedType.DEEPSPEED:
             result = self._prepare_deepspeed(*args)
         elif self.distributed_type == DistributedType.MEGATRON_LM:
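
The rename is internal; the user-facing flow is unchanged. A sketch of the path this hunk exercises, assuming accelerate is installed on an XPU machine:

    import torch
    from accelerate import Accelerator

    accelerator = Accelerator(mixed_precision="bf16")  # bf16 maps to dtype=torch.bfloat16 below
    model = torch.nn.Linear(16, 16)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    # prepare() routes through _prepare_xpu() when accelerator.device.type == "xpu"
    model, optimizer = accelerator.prepare(model, optimizer)
    print(accelerator.device)  # expected: device(type='xpu') when an XPU is detected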
@@ -1887,14 +1888,11 @@ def _prepare_megatron_lm(self, *args):
         )
         return tuple(result)

-    def _prepare_ipex(self, *args):
-        if not is_ipex_available():
+    def _prepare_xpu(self, *args):
+        if not is_xpu_available():
             raise ImportError(
-                "IPEX is not installed or IPEX's version does not match current PyTorch version. Please refer"
-                " to https://github.com/intel/intel-extension-for-pytorch."
+                "XPU is not available, neither from IPEX nor from stock PyTorch (>=2.4)"
             )
-        else:
-            import intel_extension_for_pytorch as ipex

         model = None
         optimizer = None
@@ -1907,13 +1905,13 @@ def _prepare_ipex(self, *args):
                 optimizer = obj
         if optimizer is not None and model is not None:
             dtype = torch.bfloat16 if self.state.mixed_precision == "bf16" else None
-            if self.device.type == "xpu" and is_xpu_available():
+            if self.device.type == "xpu":
                 model = model.to(self.device)
-                model, optimizer = torch.xpu.optimize(
-                    model, optimizer=optimizer, dtype=dtype, inplace=True, level="O1"
-                )
-            else:
-                model, optimizer = ipex.optimize(model, optimizer=optimizer, dtype=dtype, inplace=True, level="O1")
+                # torch.xpu.optimize is available only for xpu via IPEX
+                if hasattr(torch.xpu, "optimize"):
+                    model, optimizer = torch.xpu.optimize(
+                        model, optimizer=optimizer, dtype=dtype, inplace=True, level="O1"
+                    )
         for i in range(len(result)):
             if isinstance(result[i], torch.nn.Module):
                 result[i] = model
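
The `hasattr` guard is ordinary feature detection: IPEX attaches `optimize` to the `torch.xpu` namespace, while the stock backend does not define it. A standalone sketch of the same pattern; `maybe_ipex_optimize` is a hypothetical name, not part of accelerate:

    import torch

    def maybe_ipex_optimize(model, optimizer, dtype=None):
        # torch.xpu.optimize exists only when IPEX's XPU build has patched it in;
        # on stock PyTorch (>= 2.4) the attribute is absent, so pass objects through.
        if hasattr(torch.xpu, "optimize"):
            return torch.xpu.optimize(model, optimizer=optimizer, dtype=dtype, inplace=True, level="O1")
        return model, optimizer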
18 changes: 12 additions & 6 deletions src/accelerate/utils/imports.py
@@ -374,16 +374,22 @@ def is_xpu_available(check_device=False):
     "check if user disables it explicitly"
     if not parse_flag_from_env("ACCELERATE_USE_XPU", default=True):
         return False
-    "Checks if `intel_extension_for_pytorch` is installed and potentially if a XPU is in the environment"
-    if is_ipex_available():
-        import torch
+    """
+    Checks if XPU acceleration is available either via `intel_extension_for_pytorch` or
+    via stock PyTorch (>=2.4) and potentially if an XPU is in the environment
+    """
+    if importlib.util.find_spec("torch") is None:
+        return False

+    import torch
+    if is_ipex_available():
         if is_torch_version("<=", "1.12"):
             return False
-    else:
-        return False

-    import intel_extension_for_pytorch  # noqa: F401
+        import intel_extension_for_pytorch  # noqa: F401
+    else:
+        if is_torch_version("<=", "2.3"):
+            return False

     if check_device:
         try:
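With this change the check degrades gracefully: it first confirms `torch` is importable at all via `importlib.util.find_spec`, then applies the version gate that matches the backend in use (<= 1.12 rules out IPEX-era torch, <= 2.3 rules out stock torch without XPU). A hedged usage sketch; the import path is the one accelerate exposes:

    from accelerate.utils import is_xpu_available

    # True with either IPEX or stock PyTorch >= 2.4; False (rather than raising)
    # when torch itself is missing or too old for the detected backend.
    if is_xpu_available(check_device=True):  # check_device additionally probes for a real device
        print("XPU ready")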
