diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py
index 3115cee78f76..0ce41d6d1318 100644
--- a/src/transformers/models/llama/modeling_llama.py
+++ b/src/transformers/models/llama/modeling_llama.py
@@ -135,7 +135,7 @@ def _dynamic_frequency_update(self, position_ids, device):
             self.max_seq_len_cached = seq_len

         if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len:  # reset
-            self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
+            self.register_buffer("inv_freq", self.original_inv_freq.to(device), persistent=False)
             self.max_seq_len_cached = self.original_max_seq_len

     @torch.no_grad()