xpu: support xpu backend from stock pytorch (>=2.4) #31238

Merged (2 commits, Jun 14, 2024)
Changes from all commits
@@ -22,7 +22,6 @@
 import torch
 import torch.utils.checkpoint
 from torch import nn
-from torch.cuda.amp import autocast

 from ...activations import ACT2FN
 from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
@@ -219,7 +218,7 @@ def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None, hea
             scale_factor /= float(self.layer_idx + 1)

         # Upcast (turn off autocast) and reorder (Scale K by 1 / root(dk))
-        with autocast(enabled=False):
+        with torch.amp.autocast(query.device.type, enabled=False):
             q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len)
             attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor)
             attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len)
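This hunk (and the identical one in modeling_gpt2.py below) replaces the CUDA-only `torch.cuda.amp.autocast` context manager with the device-generic `torch.amp.autocast`, which takes the device type as its first argument and therefore covers "xpu" as well as "cuda". A minimal sketch of the pattern, runnable on CPU too; the helper name `upcast_matmul` is illustrative, not from the PR:

import torch


def upcast_matmul(q, k):
    # Disable autocast for this region so the matmul runs in full float32,
    # whatever backend ("cuda", "xpu", "cpu", ...) the tensors live on.
    with torch.amp.autocast(q.device.type, enabled=False):
        return torch.matmul(q.float(), k.transpose(-1, -2).float())


q = torch.randn(2, 4, 8)  # any device; CPU shown for portability
k = torch.randn(2, 4, 8)
print(upcast_matmul(q, k).dtype)  # torch.float32, even inside an autocast region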
src/transformers/models/gpt2/modeling_gpt2.py (1 addition, 2 deletions)
@@ -25,7 +25,6 @@
 import torch.nn.functional as F
 import torch.utils.checkpoint
 from torch import nn
-from torch.cuda.amp import autocast
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

 from ...activations import ACT2FN
@@ -249,7 +248,7 @@ def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None, hea
             scale_factor /= float(self.layer_idx + 1)

         # Upcast (turn off autocast) and reorder (Scale K by 1 / root(dk))
-        with autocast(enabled=False):
+        with torch.amp.autocast(query.device.type, enabled=False):
             q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len)
             attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor)
             attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len)
src/transformers/testing_utils.py (8 additions, 7 deletions)
@@ -813,23 +813,24 @@ def require_torch_multi_npu(test_case):

 def require_torch_xpu(test_case):
     """
-    Decorator marking a test that requires XPU and IPEX.
+    Decorator marking a test that requires XPU (in PyTorch).

-    These tests are skipped when Intel Extension for PyTorch isn't installed or it does not match current PyTorch
-    version.
+    These tests are skipped when the XPU backend is not available. The XPU backend might be available either via stock
+    PyTorch (>=2.4) or via Intel Extension for PyTorch. In the latter case, if IPEX is installed, its version must
+    match the current PyTorch version.
     """
-    return unittest.skipUnless(is_torch_xpu_available(), "test requires IPEX and an XPU device")(test_case)
+    return unittest.skipUnless(is_torch_xpu_available(), "test requires XPU device")(test_case)


 def require_torch_multi_xpu(test_case):
     """
-    Decorator marking a test that requires a multi-XPU setup with IPEX and at least one XPU device. These tests are
-    skipped on a machine without IPEX or multiple XPUs.
+    Decorator marking a test that requires a multi-XPU setup (in PyTorch). These tests are skipped on a machine without
+    multiple XPUs.

     To run *only* the multi_xpu tests, assuming all test names contain multi_xpu: $ pytest -sv ./tests -k "multi_xpu"
     """
     if not is_torch_xpu_available():
-        return unittest.skip("test requires IPEX and at least one XPU device")(test_case)
+        return unittest.skip("test requires PyTorch XPU")(test_case)

     return unittest.skipUnless(torch.xpu.device_count() > 1, "test requires multiple XPUs")(test_case)
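With the relaxed availability check, these decorators now gate only on a usable XPU device, however the backend is provided. A hypothetical test module using them might look like this (class and test names are illustrative):

import unittest

import torch

from transformers.testing_utils import require_torch_multi_xpu, require_torch_xpu


class XpuSmokeTest(unittest.TestCase):
    @require_torch_xpu
    def test_tensor_on_xpu(self):
        # Skipped unless an XPU backend (stock PyTorch >= 2.4 or IPEX) is available.
        x = torch.ones(2, 2, device="xpu")
        self.assertEqual(x.device.type, "xpu")

    @require_torch_multi_xpu
    def test_multi_xpu(self):
        # Skipped unless more than one XPU is visible.
        self.assertGreater(torch.xpu.device_count(), 1)

As the docstring notes, `pytest -sv ./tests -k "multi_xpu"` selects only the multi-XPU tests.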
src/transformers/training_args.py (3 additions, 0 deletions)
@@ -40,6 +40,7 @@
     ExplicitEnum,
     cached_property,
     is_accelerate_available,
+    is_ipex_available,
     is_safetensors_available,
     is_sagemaker_dp_enabled,
     is_sagemaker_mp_enabled,
@@ -2136,6 +2137,8 @@ def _setup_devices(self) -> "torch.device":
         if self.use_cpu:
             device = torch.device("cpu")
         elif is_torch_xpu_available():
+            if not is_ipex_available() and not is_accelerate_available("0.32.0.dev"):
+                raise ImportError("Using the XPU PyTorch backend requires `accelerate>=0.32.0.dev`")
             device = torch.device("xpu:0")
             torch.xpu.set_device(device)
         elif is_torch_mlu_available():
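The effect of the new guard: on a machine where the XPU backend comes from stock PyTorch (no IPEX installed), TrainingArguments refuses to select the device unless a new-enough accelerate is present. A sketch of the happy path, assuming an XPU machine with `accelerate>=0.32.0.dev` installed:

from transformers import TrainingArguments

# Accessing `device` triggers _setup_devices; with an XPU present and the
# accelerate requirement satisfied, it resolves to xpu:0.
args = TrainingArguments(output_dir="tmp_out")
print(args.device)  # device(type='xpu', index=0)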
src/transformers/utils/import_utils.py (8 additions, 3 deletions)
@@ -747,13 +747,18 @@ def get_major_and_minor_from_version(full_version):

 @lru_cache
 def is_torch_xpu_available(check_device=False):
-    "Checks if `intel_extension_for_pytorch` is installed and potentially if a XPU is in the environment"
-    if not is_ipex_available():
+    """
+    Checks if XPU acceleration is available either via `intel_extension_for_pytorch` or
+    via stock PyTorch (>=2.4) and potentially if an XPU is in the environment
+    """
+    if not is_torch_available():
         return False

-    import intel_extension_for_pytorch  # noqa: F401
     import torch

+    if is_ipex_available():
+        import intel_extension_for_pytorch  # noqa: F401
+
     if check_device:
         try:
             # Will raise a RuntimeError if no XPU is found
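Downstream code can now rely on `is_torch_xpu_available()` without installing IPEX. A minimal device-selection helper built on the updated check (the `pick_device` name is illustrative):

import torch

from transformers.utils import is_torch_xpu_available


def pick_device() -> torch.device:
    # check_device=True also probes that an XPU is actually usable,
    # not just that the backend is present.
    if is_torch_xpu_available(check_device=True):
        return torch.device("xpu")
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")


print(pick_device())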