diff --git a/optimum/habana/transformers/models/gpt_neox/modeling_gpt_neox.py b/optimum/habana/transformers/models/gpt_neox/modeling_gpt_neox.py
index faab1b4abf..57dfca70a0 100644
--- a/optimum/habana/transformers/models/gpt_neox/modeling_gpt_neox.py
+++ b/optimum/habana/transformers/models/gpt_neox/modeling_gpt_neox.py
@@ -457,19 +457,11 @@ def prepare_inputs_for_generation(
     return model_inputs
 
 
-def gaudi_gpt_neox_rotary_embedding_set_cos_sin_cache(self, seq_len, device, dtype):
-    self.max_seq_len_cached = seq_len
-    t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
-
-    freqs = torch.outer(t, self.inv_freq)
-    # Different from paper, but it uses a different permutation in order to obtain the same calculation
-    emb = torch.cat((freqs, freqs), dim=-1)
-    self.cos_cached = emb.cos()
-    self.sin_cached = emb.sin()
-
-
 def apply_customized_rope(q, k, cos, sin, position_ids, training=True):
-    if q.device.type == "hpu" and FusedRoPE:
-        return apply_customized_rope_module(q, k, cos, sin, position_ids, training)
+    if q.device.type == "hpu" and FusedRoPE is not None:
+        if training:
+            return apply_customized_rope_module(q.to(torch.float), k.to(torch.float), cos, sin, position_ids, training)
+        else:
+            return apply_customized_rope_module(q, k, cos, sin, position_ids, training)
     else:
         return apply_rotary_pos_emb(q.to(torch.float), k.to(torch.float), cos[position_ids], sin[position_ids])