diff --git a/optimum/habana/transformers/models/deepseek_v2/modeling_deepseek_v2.py b/optimum/habana/transformers/models/deepseek_v2/modeling_deepseek_v2.py
index 4d7e3320be..5d794c9810 100644
--- a/optimum/habana/transformers/models/deepseek_v2/modeling_deepseek_v2.py
+++ b/optimum/habana/transformers/models/deepseek_v2/modeling_deepseek_v2.py
@@ -1203,7 +1203,16 @@ def prefill_forward(
                     "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                     "with a layer index."
                 )
-            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+            if token_idx is None:
+                if hasattr(past_key_value, "get_usable_length"):
+                    kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+                else:
+                    kv_seq_len += past_key_value[0].shape[-2]
+            else:
+                if reuse_cache:
+                    kv_seq_len = past_key_value[0][-2]
+                else:
+                    kv_seq_len = past_key_value[0].shape[-2]
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
         q_pe, k_pe = apply_customized_rope(q_pe, k_pe, cos, sin, position_ids)
         # update & get all compressed_kv, k_pe
@@ -1247,7 +1256,7 @@ def prefill_forward(
                 f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
                 f" {attn_weights.size()}"
             )
-
+        assert attention_mask is not None
         if attention_mask is not None:
             if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                 raise ValueError(
@@ -1360,14 +1369,14 @@ def decode_forward(
             torch.matmul(q_pe, k_pe.mT) + torch.matmul(q_nope, compressed_kv.unsqueeze(-3).mT)
         ) * self.softmax_scale
 
+        assert attention_mask is not None
         if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
             raise ValueError(
                 f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
                 f" {attn_weights.size()}"
             )
-        assert attention_mask is not None
-        if attention_mask is not None:
-            attn_weights = attn_weights + attention_mask
+
+        attn_weights = attn_weights + attention_mask
 
         # upcast attention to fp32
         attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q_nope.dtype)
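
Note: as a minimal standalone sketch (not part of the patch itself), the kv_seq_len selection introduced in prefill_forward can be read as the helper below. The helper name and signature are hypothetical; it assumes past_key_value is either an HF Cache-style object exposing get_usable_length() or a legacy (key, value) tuple, and that under reuse_cache the first cache entry holds a recorded shape tuple rather than a tensor.

    from typing import Any, Optional

    import torch


    def resolve_kv_seq_len(
        kv_seq_len: int,
        past_key_value: Any,
        layer_idx: int,
        token_idx: Optional[torch.Tensor],
        reuse_cache: bool,
    ) -> int:
        # Hypothetical helper mirroring the branch added above.
        if token_idx is None:
            # No static token index: grow the length by whatever is already cached.
            if hasattr(past_key_value, "get_usable_length"):
                kv_seq_len += past_key_value.get_usable_length(kv_seq_len, layer_idx)
            else:
                kv_seq_len += past_key_value[0].shape[-2]
        else:
            # Static-shape path (token_idx in use): the cache is pre-allocated,
            # so the length is read from the cache instead of accumulated.
            if reuse_cache:
                kv_seq_len = past_key_value[0][-2]  # recorded shape tuple
            else:
                kv_seq_len = past_key_value[0].shape[-2]
        return kv_seq_len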