From 7f2fdfd53e680ea014772572ef352efdeff260ba Mon Sep 17 00:00:00 2001
From: Yaser Afshar
Date: Tue, 10 Dec 2024 14:12:26 -0800
Subject: [PATCH] Fix Accuracy Calculation Issue in GPT-NeoX (#1591)

---
 .../models/gpt_neox/modeling_gpt_neox.py | 18 +++++-------------
 1 file changed, 5 insertions(+), 13 deletions(-)

diff --git a/optimum/habana/transformers/models/gpt_neox/modeling_gpt_neox.py b/optimum/habana/transformers/models/gpt_neox/modeling_gpt_neox.py
index faab1b4abf..57dfca70a0 100644
--- a/optimum/habana/transformers/models/gpt_neox/modeling_gpt_neox.py
+++ b/optimum/habana/transformers/models/gpt_neox/modeling_gpt_neox.py
@@ -457,19 +457,11 @@ def prepare_inputs_for_generation(
     return model_inputs
 
 
-def gaudi_gpt_neox_rotary_embedding_set_cos_sin_cache(self, seq_len, device, dtype):
-    self.max_seq_len_cached = seq_len
-    t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
-
-    freqs = torch.outer(t, self.inv_freq)
-    # Different from paper, but it uses a different permutation in order to obtain the same calculation
-    emb = torch.cat((freqs, freqs), dim=-1)
-    self.cos_cached = emb.cos()
-    self.sin_cached = emb.sin()
-
-
 def apply_customized_rope(q, k, cos, sin, position_ids, training=True):
-    if q.device.type == "hpu" and FusedRoPE:
-        return apply_customized_rope_module(q, k, cos, sin, position_ids, training)
+    if q.device.type == "hpu" and FusedRoPE is not None:
+        if training:
+            return apply_customized_rope_module(q.to(torch.float), k.to(torch.float), cos, sin, position_ids, training)
+        else:
+            return apply_customized_rope_module(q, k, cos, sin, position_ids, training)
     else:
         return apply_rotary_pos_emb(q.to(torch.float), k.to(torch.float), cos[position_ids], sin[position_ids])
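
Note on the change: the patch (a) replaces the truthiness check on FusedRoPE with an explicit `is not None` guard, (b) upcasts q and k to fp32 before the fused RoPE kernel during training, presumably because computing the rotation in the model's low-precision dtype was the source of the accuracy issue the subject line refers to, and (c) removes the Gaudi-specific cos/sin cache helper, which this code path no longer calls. The eval branch keeps the model dtype, so only training pays the upcast.

Below is a minimal, CPU-runnable sketch of the new dispatch logic. It is an illustration under simplifying assumptions, not the shipped implementation: tensors are [batch, seq, dim] rather than [batch, heads, seq, dim], FusedRoPE and apply_customized_rope_module are stand-in stubs for the habana_frameworks symbols, and apply_rotary_pos_emb is the standard rotate-half reference.

import torch

# Stand-ins for the Habana-specific symbols used by the patch. On a real
# Gaudi stack, FusedRoPE comes from habana_frameworks and is non-None on HPU.
FusedRoPE = None

def apply_customized_rope_module(q, k, cos, sin, position_ids, training):
    # Placeholder for the fused-RoPE wrapper; unreachable on CPU in this sketch.
    raise NotImplementedError("fused RoPE is only available on HPU")

def rotate_half(x):
    # Standard RoPE helper: swap and negate the two halves of the last dim.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin):
    # Reference (non-fused) rotary embedding; cos/sin are already gathered
    # per position and broadcast against q and k.
    return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)

def apply_customized_rope(q, k, cos, sin, position_ids, training=True):
    if q.device.type == "hpu" and FusedRoPE is not None:
        if training:
            # The accuracy fix: run the fused kernel on fp32 copies of q/k
            # while training, instead of the (typically bf16) model dtype.
            return apply_customized_rope_module(
                q.to(torch.float), k.to(torch.float), cos, sin, position_ids, training
            )
        return apply_customized_rope_module(q, k, cos, sin, position_ids, training)
    # Non-HPU fallback: always upcast and use the reference implementation.
    return apply_rotary_pos_emb(
        q.to(torch.float), k.to(torch.float), cos[position_ids], sin[position_ids]
    )

# Quick CPU check (exercises only the fallback branch):
b, s, d = 2, 8, 16
q = torch.randn(b, s, d, dtype=torch.bfloat16)
k = torch.randn(b, s, d, dtype=torch.bfloat16)
inv_freq = 1.0 / (10000 ** (torch.arange(0, d, 2).float() / d))
freqs = torch.outer(torch.arange(32).float(), inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)  # same layout the removed cache helper built
position_ids = torch.arange(s).expand(b, s)
q_rot, k_rot = apply_customized_rope(q, k, emb.cos(), emb.sin(), position_ids, training=False)
assert q_rot.shape == (b, s, d) and q_rot.dtype == torch.float32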