Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update form pretrained to make TP a first class citizen #36335

Merged
merged 79 commits into from
Feb 26, 2025
Merged
Show file tree
Hide file tree
Changes from 45 commits
Commits
Show all changes
79 commits
Select commit Hold shift + click to select a range
4fdb6a4
clean code
ArthurZucker Feb 21, 2025
fe79aef
Merge branch 'main' of github.com:huggingface/transformers into safe-…
ArthurZucker Feb 21, 2025
7f0ca67
oups
ArthurZucker Feb 21, 2025
db4f78b
fix merge
ArthurZucker Feb 21, 2025
9575332
yups
ArthurZucker Feb 21, 2025
7346376
fix if
ArthurZucker Feb 21, 2025
3fe75a0
now you can play
ArthurZucker Feb 21, 2025
034beab
fix shape issue
ArthurZucker Feb 21, 2025
7be4157
try non blocking
ArthurZucker Feb 21, 2025
82a471f
fix
ArthurZucker Feb 21, 2025
fbf2912
updates
ArthurZucker Feb 21, 2025
60824df
up
ArthurZucker Feb 21, 2025
4981ffc
updates
ArthurZucker Feb 24, 2025
1be42fc
fix most of thetests
ArthurZucker Feb 24, 2025
995f225
update
ArthurZucker Feb 24, 2025
8571401
update
ArthurZucker Feb 24, 2025
b1ee64c
small updates
ArthurZucker Feb 24, 2025
a78842a
up
ArthurZucker Feb 24, 2025
d92b4e3
fix the remaining bug?
ArthurZucker Feb 24, 2025
d573acb
update
ArthurZucker Feb 24, 2025
740b52b
Merge branch 'main' into safe-tensors
ArthurZucker Feb 24, 2025
25bb569
rename when you read from the file
ArthurZucker Feb 24, 2025
efb1116
Merge branch 'safe-tensors' of github.com:huggingface/transformers in…
ArthurZucker Feb 24, 2025
d8aa45f
buffer issues
ArthurZucker Feb 24, 2025
6a35248
current status
ArthurZucker Feb 24, 2025
8e3a6ae
cleanup
ArthurZucker Feb 24, 2025
dfc864f
properly allocate dumb memory
ArthurZucker Feb 24, 2025
a08c849
update a small bug
ArthurZucker Feb 24, 2025
2c7ab61
fix colwise rep issue
ArthurZucker Feb 24, 2025
7efe219
fix keep in float 32 that was keeping everything in float 32
ArthurZucker Feb 24, 2025
179e26b
typo
ArthurZucker Feb 24, 2025
046d6a1
more fixes with keep_in_fp32_modules as we use to serach on it
ArthurZucker Feb 24, 2025
52eda20
fix ROPE dtype for TP
ArthurZucker Feb 24, 2025
ae79fad
remove what's breaking the tests
ArthurZucker Feb 24, 2025
acb45d6
updates
ArthurZucker Feb 24, 2025
93555c0
update and fixes
ArthurZucker Feb 25, 2025
c14eccc
Merge branch 'main' into safe-tensors
ArthurZucker Feb 25, 2025
71fe672
Merge branch 'main' of github.com:huggingface/transformers into safe-…
ArthurZucker Feb 25, 2025
c3c6d85
Merge branch 'safe-tensors' of github.com:huggingface/transformers in…
ArthurZucker Feb 25, 2025
f140cc8
small cleanup after merging
ArthurZucker Feb 25, 2025
1fdd522
allocate 2x to be safe
ArthurZucker Feb 25, 2025
11b1107
style, auto
ArthurZucker Feb 25, 2025
d5c6023
update
ArthurZucker Feb 25, 2025
0b3a18b
yup nit
ArthurZucker Feb 25, 2025
d224cf8
fix
ArthurZucker Feb 25, 2025
f6893c9
remove slow as fuck torch api :(
ArthurZucker Feb 25, 2025
1898823
work
ArthurZucker Feb 25, 2025
4c2087f
fixup
ArthurZucker Feb 25, 2025
3e2526e
update
ArthurZucker Feb 25, 2025
42c6119
brting the fix back
ArthurZucker Feb 25, 2025
4522505
fix and update
ArthurZucker Feb 25, 2025
6b9f243
fixes
ArthurZucker Feb 26, 2025
752bc95
updates because some suggestions were wrong :eyes:
ArthurZucker Feb 26, 2025
a5b84ec
update?
ArthurZucker Feb 26, 2025
b53d381
fuck this bloated function
ArthurZucker Feb 26, 2025
0c4e173
typo
ArthurZucker Feb 26, 2025
4e8a8d5
fix the dumb prefix thing once and forall
ArthurZucker Feb 26, 2025
a9adbeb
fixes here and there
ArthurZucker Feb 26, 2025
feaf7f1
updates
ArthurZucker Feb 26, 2025
a0b7af4
remove prints
ArthurZucker Feb 26, 2025
4e03eb9
Merge branch 'main' of github.com:huggingface/transformers into safe-…
ArthurZucker Feb 26, 2025
0af4c22
fix strict cases
ArthurZucker Feb 26, 2025
366aa1f
styel
ArthurZucker Feb 26, 2025
634016a
properly fix keys on load!
ArthurZucker Feb 26, 2025
95fd001
update
ArthurZucker Feb 26, 2025
640dc38
fix base model prefix issue
ArthurZucker Feb 26, 2025
e6bbf62
style
ArthurZucker Feb 26, 2025
13522e9
update
ArthurZucker Feb 26, 2025
d85e53f
fix all?
ArthurZucker Feb 26, 2025
750c04c
remoce 1 print
ArthurZucker Feb 26, 2025
31b90df
fix the final etsts
ArthurZucker Feb 26, 2025
796cfb7
fixup
ArthurZucker Feb 26, 2025
0e2eca8
last nits
ArthurZucker Feb 26, 2025
7cab57e
fix the detach issue which cause a 2x slowdown
ArthurZucker Feb 26, 2025
3c3a51e
fixup
ArthurZucker Feb 26, 2025
050425c
small fixes
ArthurZucker Feb 26, 2025
2824530
ultra nit
ArthurZucker Feb 26, 2025
62ed1be
fix
ArthurZucker Feb 26, 2025
c049e79
fix
ArthurZucker Feb 26, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
433 changes: 221 additions & 212 deletions src/transformers/modeling_utils.py

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion src/transformers/models/aria/modeling_aria.py
Original file line number Diff line number Diff line change
Expand Up @@ -774,7 +774,7 @@ def forward(self, x, position_ids):
device_type = x.device.type
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/bamba/modeling_bamba.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ def forward(self, x, position_ids):
device_type = x.device.type
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/diffllama/modeling_diffllama.py
Original file line number Diff line number Diff line change
Expand Up @@ -663,7 +663,7 @@ def forward(self, x, position_ids):
device_type = x.device.type
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/emu3/modeling_emu3.py
Original file line number Diff line number Diff line change
Expand Up @@ -1246,7 +1246,7 @@ def forward(self, x, position_ids):
device_type = x.device.type
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/falcon/modeling_falcon.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ def forward(self, x, position_ids):
device_type = x.device.type
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/gemma/modeling_gemma.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def forward(self, x, position_ids):
device_type = x.device.type
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/gemma2/modeling_gemma2.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,7 +395,7 @@ def forward(self, x, position_ids):
device_type = x.device.type
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/glm/modeling_glm.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,7 @@ def forward(self, x, position_ids):
device_type = x.device.type
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/gpt_neox/modeling_gpt_neox.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,7 @@ def forward(self, x, position_ids):
device_type = x.device.type
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,7 @@ def forward(self, x, position_ids):
device_type = x.device.type
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/granite/modeling_granite.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,7 @@ def forward(self, x, position_ids):
device_type = x.device.type
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/granitemoe/modeling_granitemoe.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ def forward(self, x, position_ids):
device_type = x.device.type
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -795,7 +795,7 @@ def forward(self, x, position_ids):
device_type = x.device.type
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/helium/modeling_helium.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def forward(self, x, position_ids):
device_type = x.device.type
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/jetmoe/modeling_jetmoe.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,7 +436,7 @@ def forward(self, x, position_ids):
device_type = x.device.type
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/llama/modeling_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def forward(self, x, position_ids):
device_type = x.device.type
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/mimi/modeling_mimi.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,7 +413,7 @@ def forward(self, x, position_ids):
device_type = x.device.type
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/mistral/modeling_mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,7 @@ def forward(self, x, position_ids):
device_type = x.device.type
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/mixtral/modeling_mixtral.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,7 +441,7 @@ def forward(self, x, position_ids):
device_type = x.device.type
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/modernbert/modeling_modernbert.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,7 @@ def forward(self, x, position_ids):
device_type = x.device.type
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/moonshine/modeling_moonshine.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,7 @@ def forward(self, x, position_ids):
device_type = x.device.type
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/moshi/modeling_moshi.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,7 @@ def forward(self, x, position_ids):
device_type = x.device.type
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/nemotron/modeling_nemotron.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def forward(self, x, position_ids):
device_type = x.device.type
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/olmo/modeling_olmo.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,7 +323,7 @@ def forward(self, x, position_ids):
device_type = x.device.type
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/olmo2/modeling_olmo2.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,7 @@ def forward(self, x, position_ids):
device_type = x.device.type
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/olmoe/modeling_olmoe.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ def forward(self, x, position_ids):
device_type = x.device.type
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/persimmon/modeling_persimmon.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def forward(self, x, position_ids):
device_type = x.device.type
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/phi/modeling_phi.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,7 @@ def forward(self, x, position_ids):
device_type = x.device.type
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/qwen2/modeling_qwen2.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,7 @@ def forward(self, x, position_ids):
device_type = x.device.type
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/qwen2_moe/modeling_qwen2_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ def forward(self, x, position_ids):
device_type = x.device.type
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/stablelm/modeling_stablelm.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def forward(self, x, position_ids):
device_type = x.device.type
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/starcoder2/modeling_starcoder2.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,7 @@ def forward(self, x, position_ids):
device_type = x.device.type
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/zamba2/modeling_zamba2.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ def forward(self, x, position_ids):
device_type = x.device.type
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()
Expand Down