
Commit

add attn_weights judgment and remove inference code in train_forward
ranzhejiang committed Mar 6, 2025
1 parent 1dc414b commit bc34f7f
Showing 1 changed file with 23 additions and 36 deletions.
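The change itself is small: train_forward no longer carries the q_len == 1 (next-token decode) branch, which is dead code during training since training always processes full sequences, and attn_weights is now returned only when output_attentions is requested, matching the usual transformers convention. Below is a minimal, self-contained sketch of that gating pattern; it uses stock PyTorch scaled_dot_product_attention as a stand-in for Gaudi's FusedSDPA kernel, and the module and argument names are illustrative, not taken from this repository:

import torch
import torch.nn.functional as F


class ToyAttention(torch.nn.Module):
    """Illustrative stand-in for the attention module touched by this commit."""

    def train_forward(self, q, k, v, attention_mask=None, output_attentions=False, use_flash_attention=True):
        if use_flash_attention:
            # Fused/flash kernels never materialize the attention matrix,
            # so there are no weights to hand back on this path.
            attn_output = F.scaled_dot_product_attention(q, k, v, attn_mask=attention_mask)
            attn_weights = None
        else:
            # Eager path: explicit softmax(Q K^T / sqrt(d)) V keeps the weights around.
            scores = q @ k.transpose(-2, -1) * q.shape[-1] ** -0.5
            if attention_mask is not None:
                scores = scores + attention_mask
            attn_weights = torch.softmax(scores, dim=-1)
            attn_output = attn_weights @ v

        # The "attn_weights judgment" added by this commit: only return weights when asked.
        if not output_attentions:
            attn_weights = None
        return attn_output, attn_weights


q = k = v = torch.randn(1, 2, 8, 16)  # (batch, heads, seq_len, head_dim), requires torch >= 2.0
out, weights = ToyAttention().train_forward(q, k, v, use_flash_attention=False, output_attentions=True)
print(out.shape, None if weights is None else weights.shape)  # torch.Size([1, 2, 8, 16]) torch.Size([1, 2, 8, 8])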
@@ -1019,6 +1019,7 @@ def train_forward(
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
         past_key_value: Optional[Cache] = None,
+        output_attentions: bool = False,
         token_idx: Optional[torch.Tensor] = None,
         cache_position: Optional[torch.LongTensor] = None,
         attn_softmax_bf16: Optional[bool] = False,
@@ -1089,8 +1090,22 @@ def train_forward(
             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
         # optimization
         if use_flash_attention and FusedSDPA is not None:
-            if q_len == 1:
-                # next token
+            softmax_mode = "fast" if flash_attention_fast_softmax else "None"
+            if flash_attention_causal_mask:
+                attn_output = self.fused_scaled_dot_product_attention(
+                    query_states,
+                    key_states,
+                    value_states,
+                    None,
+                    0.0,
+                    True,
+                    None,
+                    softmax_mode,
+                    flash_attention_recompute,
+                    valid_sequence_lengths,
+                    "left",
+                )
+            else:
                 attn_output = self.fused_scaled_dot_product_attention(
                     query_states,
                     key_states,
@@ -1099,43 +1114,11 @@ def train_forward(
                     0.0,
                     False,
                     None,
-                    "None",
-                    False,
+                    softmax_mode,
+                    flash_attention_recompute,
                     None,
                     "None",
                 )
-            else:
-                # first token
-                softmax_mode = "fast" if flash_attention_fast_softmax else "None"
-                if flash_attention_causal_mask:
-                    attn_output = self.fused_scaled_dot_product_attention(
-                        query_states,
-                        key_states,
-                        value_states,
-                        None,
-                        0.0,
-                        True,
-                        None,
-                        softmax_mode,
-                        flash_attention_recompute,
-                        valid_sequence_lengths,
-                        "left",
-                    )
-                else:
-                    attn_output = self.fused_scaled_dot_product_attention(
-                        query_states,
-                        key_states,
-                        value_states,
-                        attention_mask,
-                        0.0,
-                        False,
-                        None,
-                        softmax_mode,
-                        flash_attention_recompute,
-                        None,
-                        "None",
-                    )

         else:
             query_states, key_states, value_states, attention_mask = gaudi_deepseekv2_repeat_kv(
                 query_states, key_states, value_states, attention_mask, self.num_key_value_groups
@@ -1159,6 +1142,9 @@ def train_forward(
             attn_weights = torch.nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
             attn_output = self.matmul_av(attn_weights, value_states)

+        if not output_attentions:
+            attn_weights = None
+
         return attn_output, attn_weights, past_key_value

     def prefill_forward(
Expand Down Expand Up @@ -1428,6 +1414,7 @@ def forward(
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
token_idx=token_idx,
cache_position=cache_position,
attn_softmax_bf16=attn_softmax_bf16,
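For reference, the eleven positional arguments passed to fused_scaled_dot_product_attention above appear to correspond to (query, key, value, attn_mask, dropout_p, is_causal, scale, softmax_mode, recompute_mode, valid_sequence_lengths, padding_side), i.e. the usual Gaudi ModuleFusedSDPA-style wrapper; that mapping is inferred from similar Habana model ports, not stated in this diff. The retained training path therefore chooses between two call shapes: with flash_attention_causal_mask it passes no mask and lets the kernel apply causality (plus valid_sequence_lengths and "left" padding), otherwise it passes the explicit attention_mask with the causal flag disabled. The Gaudi-specific arguments have no stock-PyTorch equivalent, but the causal-flag-versus-explicit-mask equivalence for a full, unpadded sequence can be sketched with standard PyTorch SDPA:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
q = torch.randn(1, 2, 6, 16)  # (batch, heads, seq_len, head_dim)
k = torch.randn(1, 2, 6, 16)
v = torch.randn(1, 2, 6, 16)

# Variant 1: no mask, let the kernel apply causality itself (is_causal=True).
out_causal = F.scaled_dot_product_attention(q, k, v, attn_mask=None, is_causal=True)

# Variant 2: build the lower-triangular additive mask explicitly, is_causal=False.
seq_len = q.shape[-2]
causal_mask = torch.zeros(seq_len, seq_len).masked_fill(
    torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1), float("-inf")
)
out_masked = F.scaled_dot_product_attention(q, k, v, attn_mask=causal_mask, is_causal=False)

# For a full, unpadded sequence the two produce the same output.
print(torch.allclose(out_causal, out_masked, atol=1e-6))  # True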
