Skip to content

Commit

Permalink
Fix device in rope module when using dynamic updates (#35608)
Browse files Browse the repository at this point in the history
fix rope device
  • Loading branch information
Cyrilvallez authored Jan 13, 2025
1 parent 15bd3e6 commit cd44bdb
Show file tree
Hide file tree
Showing 34 changed files with 105 additions and 0 deletions.
3 changes: 3 additions & 0 deletions src/transformers/models/aria/modeling_aria.py
Original file line number Diff line number Diff line change
Expand Up @@ -754,6 +754,9 @@ def _dynamic_frequency_update(self, position_ids, device):
self.max_seq_len_cached = seq_len

if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
# This .to() is needed if the model has been moved to a device after being initialized (because
# the buffer is automatically moved, but not the original copy)
self.original_inv_freq = self.original_inv_freq.to(device)
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
self.max_seq_len_cached = self.original_max_seq_len

Expand Down
3 changes: 3 additions & 0 deletions src/transformers/models/bamba/modeling_bamba.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,9 @@ def _dynamic_frequency_update(self, position_ids, device):
self.max_seq_len_cached = seq_len

if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
# This .to() is needed if the model has been moved to a device after being initialized (because
# the buffer is automatically moved, but not the original copy)
self.original_inv_freq = self.original_inv_freq.to(device)
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
self.max_seq_len_cached = self.original_max_seq_len

Expand Down
3 changes: 3 additions & 0 deletions src/transformers/models/cohere/modeling_cohere.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,9 @@ def _dynamic_frequency_update(self, position_ids, device):
self.max_seq_len_cached = seq_len

if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
# This .to() is needed if the model has been moved to a device after being initialized (because
# the buffer is automatically moved, but not the original copy)
self.original_inv_freq = self.original_inv_freq.to(device)
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
self.max_seq_len_cached = self.original_max_seq_len

Expand Down
3 changes: 3 additions & 0 deletions src/transformers/models/cohere2/modeling_cohere2.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,9 @@ def _dynamic_frequency_update(self, position_ids, device):
self.max_seq_len_cached = seq_len

if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
# This .to() is needed if the model has been moved to a device after being initialized (because
# the buffer is automatically moved, but not the original copy)
self.original_inv_freq = self.original_inv_freq.to(device)
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
self.max_seq_len_cached = self.original_max_seq_len

Expand Down
3 changes: 3 additions & 0 deletions src/transformers/models/diffllama/modeling_diffllama.py
Original file line number Diff line number Diff line change
Expand Up @@ -643,6 +643,9 @@ def _dynamic_frequency_update(self, position_ids, device):
self.max_seq_len_cached = seq_len

if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
# This .to() is needed if the model has been moved to a device after being initialized (because
# the buffer is automatically moved, but not the original copy)
self.original_inv_freq = self.original_inv_freq.to(device)
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
self.max_seq_len_cached = self.original_max_seq_len

Expand Down
3 changes: 3 additions & 0 deletions src/transformers/models/emu3/modeling_emu3.py
Original file line number Diff line number Diff line change
Expand Up @@ -1227,6 +1227,9 @@ def _dynamic_frequency_update(self, position_ids, device):
self.max_seq_len_cached = seq_len

if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
# This .to() is needed if the model has been moved to a device after being initialized (because
# the buffer is automatically moved, but not the original copy)
self.original_inv_freq = self.original_inv_freq.to(device)
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
self.max_seq_len_cached = self.original_max_seq_len

Expand Down
3 changes: 3 additions & 0 deletions src/transformers/models/falcon/modeling_falcon.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,9 @@ def _dynamic_frequency_update(self, position_ids, device):
self.max_seq_len_cached = seq_len

if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
# This .to() is needed if the model has been moved to a device after being initialized (because
# the buffer is automatically moved, but not the original copy)
self.original_inv_freq = self.original_inv_freq.to(device)
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
self.max_seq_len_cached = self.original_max_seq_len

Expand Down
3 changes: 3 additions & 0 deletions src/transformers/models/gemma/modeling_gemma.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,9 @@ def _dynamic_frequency_update(self, position_ids, device):
self.max_seq_len_cached = seq_len

if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
# This .to() is needed if the model has been moved to a device after being initialized (because
# the buffer is automatically moved, but not the original copy)
self.original_inv_freq = self.original_inv_freq.to(device)
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
self.max_seq_len_cached = self.original_max_seq_len

Expand Down
3 changes: 3 additions & 0 deletions src/transformers/models/gemma2/modeling_gemma2.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,9 @@ def _dynamic_frequency_update(self, position_ids, device):
self.max_seq_len_cached = seq_len

if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
# This .to() is needed if the model has been moved to a device after being initialized (because
# the buffer is automatically moved, but not the original copy)
self.original_inv_freq = self.original_inv_freq.to(device)
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
self.max_seq_len_cached = self.original_max_seq_len

Expand Down
3 changes: 3 additions & 0 deletions src/transformers/models/glm/modeling_glm.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,9 @@ def _dynamic_frequency_update(self, position_ids, device):
self.max_seq_len_cached = seq_len

if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
# This .to() is needed if the model has been moved to a device after being initialized (because
# the buffer is automatically moved, but not the original copy)
self.original_inv_freq = self.original_inv_freq.to(device)
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
self.max_seq_len_cached = self.original_max_seq_len

Expand Down
3 changes: 3 additions & 0 deletions src/transformers/models/gpt_neox/modeling_gpt_neox.py
Original file line number Diff line number Diff line change
Expand Up @@ -521,6 +521,9 @@ def _dynamic_frequency_update(self, position_ids, device):
self.max_seq_len_cached = seq_len

if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
# This .to() is needed if the model has been moved to a device after being initialized (because
# the buffer is automatically moved, but not the original copy)
self.original_inv_freq = self.original_inv_freq.to(device)
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
self.max_seq_len_cached = self.original_max_seq_len

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,9 @@ def _dynamic_frequency_update(self, position_ids, device):
self.max_seq_len_cached = seq_len

if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
# This .to() is needed if the model has been moved to a device after being initialized (because
# the buffer is automatically moved, but not the original copy)
self.original_inv_freq = self.original_inv_freq.to(device)
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
self.max_seq_len_cached = self.original_max_seq_len

Expand Down
3 changes: 3 additions & 0 deletions src/transformers/models/granite/modeling_granite.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,9 @@ def _dynamic_frequency_update(self, position_ids, device):
self.max_seq_len_cached = seq_len

if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
# This .to() is needed if the model has been moved to a device after being initialized (because
# the buffer is automatically moved, but not the original copy)
self.original_inv_freq = self.original_inv_freq.to(device)
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
self.max_seq_len_cached = self.original_max_seq_len

Expand Down
3 changes: 3 additions & 0 deletions src/transformers/models/granitemoe/modeling_granitemoe.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,9 @@ def _dynamic_frequency_update(self, position_ids, device):
self.max_seq_len_cached = seq_len

if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
# This .to() is needed if the model has been moved to a device after being initialized (because
# the buffer is automatically moved, but not the original copy)
self.original_inv_freq = self.original_inv_freq.to(device)
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
self.max_seq_len_cached = self.original_max_seq_len

Expand Down
3 changes: 3 additions & 0 deletions src/transformers/models/jetmoe/modeling_jetmoe.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,6 +416,9 @@ def _dynamic_frequency_update(self, position_ids, device):
self.max_seq_len_cached = seq_len

if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
# This .to() is needed if the model has been moved to a device after being initialized (because
# the buffer is automatically moved, but not the original copy)
self.original_inv_freq = self.original_inv_freq.to(device)
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
self.max_seq_len_cached = self.original_max_seq_len

Expand Down
3 changes: 3 additions & 0 deletions src/transformers/models/llama/modeling_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,9 @@ def _dynamic_frequency_update(self, position_ids, device):
self.max_seq_len_cached = seq_len

if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
# This .to() is needed if the model has been moved to a device after being initialized (because
# the buffer is automatically moved, but not the original copy)
self.original_inv_freq = self.original_inv_freq.to(device)
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
self.max_seq_len_cached = self.original_max_seq_len

Expand Down
3 changes: 3 additions & 0 deletions src/transformers/models/mimi/modeling_mimi.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,9 @@ def _dynamic_frequency_update(self, position_ids, device):
self.max_seq_len_cached = seq_len

if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
# This .to() is needed if the model has been moved to a device after being initialized (because
# the buffer is automatically moved, but not the original copy)
self.original_inv_freq = self.original_inv_freq.to(device)
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
self.max_seq_len_cached = self.original_max_seq_len

Expand Down
3 changes: 3 additions & 0 deletions src/transformers/models/mistral/modeling_mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,9 @@ def _dynamic_frequency_update(self, position_ids, device):
self.max_seq_len_cached = seq_len

if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
# This .to() is needed if the model has been moved to a device after being initialized (because
# the buffer is automatically moved, but not the original copy)
self.original_inv_freq = self.original_inv_freq.to(device)
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
self.max_seq_len_cached = self.original_max_seq_len

Expand Down
3 changes: 3 additions & 0 deletions src/transformers/models/mixtral/modeling_mixtral.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,6 +422,9 @@ def _dynamic_frequency_update(self, position_ids, device):
self.max_seq_len_cached = seq_len

if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
# This .to() is needed if the model has been moved to a device after being initialized (because
# the buffer is automatically moved, but not the original copy)
self.original_inv_freq = self.original_inv_freq.to(device)
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
self.max_seq_len_cached = self.original_max_seq_len

Expand Down
3 changes: 3 additions & 0 deletions src/transformers/models/modernbert/modeling_modernbert.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,9 @@ def _dynamic_frequency_update(self, position_ids, device):
self.max_seq_len_cached = seq_len

if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
# This .to() is needed if the model has been moved to a device after being initialized (because
# the buffer is automatically moved, but not the original copy)
self.original_inv_freq = self.original_inv_freq.to(device)
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
self.max_seq_len_cached = self.original_max_seq_len

Expand Down
3 changes: 3 additions & 0 deletions src/transformers/models/moonshine/modeling_moonshine.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,9 @@ def _dynamic_frequency_update(self, position_ids, device):
self.max_seq_len_cached = seq_len

if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
# This .to() is needed if the model has been moved to a device after being initialized (because
# the buffer is automatically moved, but not the original copy)
self.original_inv_freq = self.original_inv_freq.to(device)
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
self.max_seq_len_cached = self.original_max_seq_len

Expand Down
3 changes: 3 additions & 0 deletions src/transformers/models/moshi/modeling_moshi.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,9 @@ def _dynamic_frequency_update(self, position_ids, device):
self.max_seq_len_cached = seq_len

if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
# This .to() is needed if the model has been moved to a device after being initialized (because
# the buffer is automatically moved, but not the original copy)
self.original_inv_freq = self.original_inv_freq.to(device)
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
self.max_seq_len_cached = self.original_max_seq_len

Expand Down
3 changes: 3 additions & 0 deletions src/transformers/models/nemotron/modeling_nemotron.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,9 @@ def _dynamic_frequency_update(self, position_ids, device):
self.max_seq_len_cached = seq_len

if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
# This .to() is needed if the model has been moved to a device after being initialized (because
# the buffer is automatically moved, but not the original copy)
self.original_inv_freq = self.original_inv_freq.to(device)
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
self.max_seq_len_cached = self.original_max_seq_len

Expand Down
3 changes: 3 additions & 0 deletions src/transformers/models/olmo/modeling_olmo.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,9 @@ def _dynamic_frequency_update(self, position_ids, device):
self.max_seq_len_cached = seq_len

if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
# This .to() is needed if the model has been moved to a device after being initialized (because
# the buffer is automatically moved, but not the original copy)
self.original_inv_freq = self.original_inv_freq.to(device)
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
self.max_seq_len_cached = self.original_max_seq_len

Expand Down
3 changes: 3 additions & 0 deletions src/transformers/models/olmo2/modeling_olmo2.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,9 @@ def _dynamic_frequency_update(self, position_ids, device):
self.max_seq_len_cached = seq_len

if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
# This .to() is needed if the model has been moved to a device after being initialized (because
# the buffer is automatically moved, but not the original copy)
self.original_inv_freq = self.original_inv_freq.to(device)
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
self.max_seq_len_cached = self.original_max_seq_len

Expand Down
3 changes: 3 additions & 0 deletions src/transformers/models/olmoe/modeling_olmoe.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,9 @@ def _dynamic_frequency_update(self, position_ids, device):
self.max_seq_len_cached = seq_len

if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
# This .to() is needed if the model has been moved to a device after being initialized (because
# the buffer is automatically moved, but not the original copy)
self.original_inv_freq = self.original_inv_freq.to(device)
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
self.max_seq_len_cached = self.original_max_seq_len

Expand Down
3 changes: 3 additions & 0 deletions src/transformers/models/persimmon/modeling_persimmon.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,9 @@ def _dynamic_frequency_update(self, position_ids, device):
self.max_seq_len_cached = seq_len

if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
# This .to() is needed if the model has been moved to a device after being initialized (because
# the buffer is automatically moved, but not the original copy)
self.original_inv_freq = self.original_inv_freq.to(device)
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
self.max_seq_len_cached = self.original_max_seq_len

Expand Down
3 changes: 3 additions & 0 deletions src/transformers/models/phi/modeling_phi.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,9 @@ def _dynamic_frequency_update(self, position_ids, device):
self.max_seq_len_cached = seq_len

if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
# This .to() is needed if the model has been moved to a device after being initialized (because
# the buffer is automatically moved, but not the original copy)
self.original_inv_freq = self.original_inv_freq.to(device)
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
self.max_seq_len_cached = self.original_max_seq_len

Expand Down
6 changes: 6 additions & 0 deletions src/transformers/models/phi3/modeling_phi3.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,9 @@ def _dynamic_frequency_update(self, position_ids, device):
self.max_seq_len_cached = seq_len

if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
# This .to() is needed if the model has been moved to a device after being initialized (because
# the buffer is automatically moved, but not the original copy)
self.original_inv_freq = self.original_inv_freq.to(device)
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
self.max_seq_len_cached = self.original_max_seq_len

Expand Down Expand Up @@ -390,6 +393,9 @@ def _longrope_frequency_update(self, position_ids, device):
)
self.register_buffer("inv_freq", self.long_inv_freq, persistent=False)
else:
# This .to() is needed if the model has been moved to a device after being initialized (because
# the buffer is automatically moved, but not the original copy)
self.original_inv_freq = self.original_inv_freq.to(device)
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)


Expand Down
3 changes: 3 additions & 0 deletions src/transformers/models/phi3/modular_phi3.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,9 @@ def _longrope_frequency_update(self, position_ids, device):
)
self.register_buffer("inv_freq", self.long_inv_freq, persistent=False)
else:
# This .to() is needed if the model has been moved to a device after being initialized (because
# the buffer is automatically moved, but not the original copy)
self.original_inv_freq = self.original_inv_freq.to(device)
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)

@torch.no_grad()
Expand Down
Loading

0 comments on commit cd44bdb

Please sign in to comment.