From bc34f7fc11d2e767624d8555acd16af246b815be Mon Sep 17 00:00:00 2001
From: ranzhejiang
Date: Thu, 6 Mar 2025 07:22:55 +0000
Subject: [PATCH] add output_attentions check for attn_weights and remove inference code from train_forward

---
 .../deepseek_v2/modeling_deepseek_v2.py | 59 ++++++++-----------
 1 file changed, 23 insertions(+), 36 deletions(-)

diff --git a/optimum/habana/transformers/models/deepseek_v2/modeling_deepseek_v2.py b/optimum/habana/transformers/models/deepseek_v2/modeling_deepseek_v2.py
index 5d794c9810..a377ddc281 100644
--- a/optimum/habana/transformers/models/deepseek_v2/modeling_deepseek_v2.py
+++ b/optimum/habana/transformers/models/deepseek_v2/modeling_deepseek_v2.py
@@ -1019,6 +1019,7 @@ def train_forward(
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
         past_key_value: Optional[Cache] = None,
+        output_attentions: bool = False,
         token_idx: Optional[torch.Tensor] = None,
         cache_position: Optional[torch.LongTensor] = None,
         attn_softmax_bf16: Optional[bool] = False,
@@ -1089,8 +1090,22 @@ def train_forward(
             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
         # optimization
         if use_flash_attention and FusedSDPA is not None:
-            if q_len == 1:
-                # next token
+            softmax_mode = "fast" if flash_attention_fast_softmax else "None"
+            if flash_attention_causal_mask:
+                attn_output = self.fused_scaled_dot_product_attention(
+                    query_states,
+                    key_states,
+                    value_states,
+                    None,
+                    0.0,
+                    True,
+                    None,
+                    softmax_mode,
+                    flash_attention_recompute,
+                    valid_sequence_lengths,
+                    "left",
+                )
+            else:
                 attn_output = self.fused_scaled_dot_product_attention(
                     query_states,
                     key_states,
                     value_states,
@@ -1099,43 +1114,11 @@ def train_forward(
                     attention_mask,
                     0.0,
                     False,
                     None,
-                    "None",
-                    False,
+                    softmax_mode,
+                    flash_attention_recompute,
                     None,
                     "None",
                 )
-            else:
-                # first token
-                softmax_mode = "fast" if flash_attention_fast_softmax else "None"
-                if flash_attention_causal_mask:
-                    attn_output = self.fused_scaled_dot_product_attention(
-                        query_states,
-                        key_states,
-                        value_states,
-                        None,
-                        0.0,
-                        True,
-                        None,
-                        softmax_mode,
-                        flash_attention_recompute,
-                        valid_sequence_lengths,
-                        "left",
-                    )
-                else:
-                    attn_output = self.fused_scaled_dot_product_attention(
-                        query_states,
-                        key_states,
-                        value_states,
-                        attention_mask,
-                        0.0,
-                        False,
-                        None,
-                        softmax_mode,
-                        flash_attention_recompute,
-                        None,
-                        "None",
-                    )
-
         else:
             query_states, key_states, value_states, attention_mask = gaudi_deepseekv2_repeat_kv(
                 query_states, key_states, value_states, attention_mask, self.num_key_value_groups
             )
@@ -1159,6 +1142,9 @@ def train_forward(
         attn_weights = torch.nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
         attn_output = self.matmul_av(attn_weights, value_states)
 
+        if not output_attentions:
+            attn_weights = None
+
         return attn_output, attn_weights, past_key_value
 
     def prefill_forward(
@@ -1428,6 +1414,7 @@ def forward(
                 attention_mask=attention_mask,
                 position_ids=position_ids,
                 past_key_value=past_key_value,
+                output_attentions=output_attentions,
                 token_idx=token_idx,
                 cache_position=cache_position,
                 attn_softmax_bf16=attn_softmax_bf16,
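
Note for reviewers: the gist of the new `output_attentions` gating in the eager (non-fused) path is sketched below. This is a minimal, self-contained approximation using plain PyTorch ops rather than the Gaudi `Matmul`/`FusedSDPA` wrappers or `gaudi_deepseekv2_repeat_kv`; the function name and the `dropout_p`/`training` parameters are illustrative assumptions, not part of the patched module.

```python
# Minimal sketch (assumption: plain PyTorch stand-ins for the Gaudi ops) of how
# train_forward now gates attn_weights on output_attentions in the eager path.
import math
from typing import Optional, Tuple

import torch


def eager_attention_sketch(
    query_states: torch.Tensor,                  # (batch, heads, q_len, head_dim)
    key_states: torch.Tensor,                    # (batch, heads, kv_len, head_dim)
    value_states: torch.Tensor,                  # (batch, heads, kv_len, head_dim)
    attention_mask: Optional[torch.Tensor] = None,
    dropout_p: float = 0.0,
    training: bool = True,
    output_attentions: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    # Scaled dot-product scores, optional additive mask, softmax, dropout, AV matmul.
    scale = 1.0 / math.sqrt(query_states.size(-1))
    attn_weights = torch.matmul(query_states, key_states.transpose(-2, -1)) * scale
    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask
    attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
    attn_weights = torch.nn.functional.dropout(attn_weights, p=dropout_p, training=training)
    attn_output = torch.matmul(attn_weights, value_states)

    # The gating added by this patch: only keep the (potentially large)
    # probability matrix when the caller explicitly asked for it.
    if not output_attentions:
        attn_weights = None

    return attn_output, attn_weights


if __name__ == "__main__":
    q = torch.randn(1, 2, 4, 8)
    k = torch.randn(1, 2, 4, 8)
    v = torch.randn(1, 2, 4, 8)
    out, weights = eager_attention_sketch(q, k, v, output_attentions=False)
    assert weights is None and out.shape == (1, 2, 4, 8)
```

On the fused path, the patch also drops the old `q_len == 1` (next-token) branch: decoding never goes through `train_forward`, so only the causal and mask-based FusedSDPA calls remain, both now using `softmax_mode` and `flash_attention_recompute`.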