xpu: enable gpt2 and decision_transformer tests for xpu pytorch backend
Note that running xpu tests requires TRANSFORMERS_TEST_DEVICE_SPEC=spec.py to be
passed to the test runner:

  import torch
  DEVICE_NAME = 'xpu'
  MANUAL_SEED_FN = torch.xpu.manual_seed
  EMPTY_CACHE_FN = torch.xpu.empty_cache
  DEVICE_COUNT_FN = torch.xpu.device_count
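
An illustrative invocation, assuming pytest is the test runner and the
repository's standard tests/models layout (not taken from the commit itself):

  TRANSFORMERS_TEST_DEVICE_SPEC=spec.py python -m pytest tests/models/gpt2 tests/models/decision_transformer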

Signed-off-by: Dmitry Rogozhkin <dmitry.v.rogozhkin@intel.com>
dvrogozh authored and ydshieh committed Jun 14, 2024
1 parent 0c8a30c commit 5ecf513
Showing 2 changed files with 2 additions and 4 deletions.
@@ -22,7 +22,6 @@
 import torch
 import torch.utils.checkpoint
 from torch import nn
-from torch.cuda.amp import autocast

 from ...activations import ACT2FN
 from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
@@ -219,7 +218,7 @@ def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None, hea
             scale_factor /= float(self.layer_idx + 1)

         # Upcast (turn off autocast) and reorder (Scale K by 1 / root(dk))
-        with autocast(enabled=False):
+        with torch.amp.autocast(query.device.type, enabled=False):
             q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len)
             attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor)
             attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len)
3 changes: 1 addition & 2 deletions src/transformers/models/gpt2/modeling_gpt2.py
@@ -25,7 +25,6 @@
 import torch.nn.functional as F
 import torch.utils.checkpoint
 from torch import nn
-from torch.cuda.amp import autocast
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

 from ...activations import ACT2FN
@@ -249,7 +248,7 @@ def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None, hea
             scale_factor /= float(self.layer_idx + 1)

         # Upcast (turn off autocast) and reorder (Scale K by 1 / root(dk))
-        with autocast(enabled=False):
+        with torch.amp.autocast(query.device.type, enabled=False):
             q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len)
             attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor)
             attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len)
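
The change itself swaps the CUDA-only torch.cuda.amp.autocast context manager
for the device-agnostic torch.amp.autocast, which takes the device type as a
string and therefore also covers xpu. A minimal sketch of the pattern, assuming
a recent PyTorch build where torch.amp.autocast accepts the device type string
(tensor shapes and names below are made up for illustration):

  import torch

  x = torch.randn(4, 4)        # lives on "cpu" here; "xpu" or "cuda" on a GPU build
  device_type = x.device.type  # pass the device type string, not the device itself
  # Disabling autocast around the matmul keeps it in full float32 precision,
  # mirroring the upcast attention path in the diff above.
  with torch.amp.autocast(device_type, enabled=False):
      y = x.float() @ x.float().transpose(0, 1)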
