Skip to content

Commit

Permalink
Fix device placement of inv_freq after reset
Browse files Browse the repository at this point in the history
  • Loading branch information
trevor-m committed Oct 7, 2024
1 parent 2e11342 commit 1a7e62a
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion src/transformers/models/llama/modeling_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def _dynamic_frequency_update(self, position_ids, device):
self.max_seq_len_cached = seq_len

if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
self.register_buffer("inv_freq", self.original_inv_freq.to(device), persistent=False)
self.max_seq_len_cached = self.original_max_seq_len

@torch.no_grad()
Expand Down

0 comments on commit 1a7e62a

Please sign in to comment.