Fix MusicGen SDPA #31208
```diff
@@ -545,7 +545,6 @@ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query
         )


-# Copied from transformers.models.bart.modeling_bart.BartSdpaAttention with Bart->Musicgen
 class MusicgenSdpaAttention(MusicgenAttention):
     def forward(
         self,
```
```diff
@@ -572,6 +571,23 @@ def forward(
             output_attentions=output_attentions,
         )

+        if (
+            attention_mask is not None
+            and (attention_mask.mean(dim=[1, 2, 3]) <= torch.finfo(attention_mask.dtype).min).any()
```
**Review comment:** What's the reason for using `finfo` here and not just 0?

**Reply:** The attention mask has already been processed at this point and is filled with `torch.finfo(dtype).min` values rather than 0s and 1s.
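A minimal standalone sketch of the check under discussion (illustrative shapes and values, not code from the PR):

```python
import torch

# By this point in the model the attention mask is additive: 0.0 where attention
# is allowed and torch.finfo(dtype).min where it is blocked -- hence the finfo
# threshold instead of 0.
batch, heads, tgt_len, src_len = 2, 1, 1, 4
min_value = torch.finfo(torch.float32).min

attention_mask = torch.zeros(batch, heads, tgt_len, src_len)
attention_mask[1] = min_value  # sample 1: every position masked (an "empty" mask)

# Sample 0 (nothing masked) has mean 0.0; sample 1 (everything masked) has a mean
# at or below finfo.min, so only it is flagged for the eager fallback.
fully_masked = attention_mask.mean(dim=[1, 2, 3]) <= min_value
print(fully_masked)  # tensor([False,  True])
```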
```diff
+        ):
+            logger.warning_once(
+                '`torch.nn.functional.scaled_dot_product_attention` does not support having an empty attention mask. Falling back to the manual attention implementation. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+                "Note that this probably happens because `guidance_scale>1` or because you used `get_unconditional_inputs`. See https://github.com/huggingface/transformers/issues/31189 for more information."
+            )
+            return super().forward(
+                hidden_states,
+                key_value_states=key_value_states,
+                past_key_value=past_key_value,
+                attention_mask=attention_mask,
+                layer_head_mask=layer_head_mask,
+                output_attentions=output_attentions,
+            )
```
**Comment on lines +582 to +589**

**Review comment:** Rather than falling back, I would just raise an exception. Otherwise this expensive check and forward pass can easily go unnoticed.

**Reply:** Well, should we find a better way of testing the attention mask? For example, we could raise a warning that this will happen, here and here, and switch the cross-attention SDPA layers to eager layers by default when it happens?

**Review comment:** OK, I see. I'm a bit concerned about this causing unexpected behaviour, in particular defaulting to eager like this, as it's a bit magic. As there are other layers which can still use SDPA, this seems like a pragmatic solution. Let's leave it as-is. If more users raise issues, then we'll have to re-think.
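For context, the flagged situation is easy to hit in normal MusicGen usage. A sketch, assuming the standard `transformers` MusicGen generation API (checkpoint name is illustrative):

```python
from transformers import MusicgenForConditionalGeneration

model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small")

# `get_unconditional_inputs` builds a "null" prompt whose text attention mask is
# all zeros; converted to an additive mask it becomes all finfo.min -- exactly
# the case the SDPA layer above now falls back to eager for. The same happens
# for the unconditional half of the batch whenever `guidance_scale > 1`.
unconditional = model.get_unconditional_inputs(num_samples=1)
audio_values = model.generate(**unconditional, do_sample=True, max_new_tokens=16)
```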
```diff
+
         # if key_value_states are provided this layer is used as a cross-attention layer
         # for the decoder
         is_cross_attention = key_value_states is not None
```
**Review comment:** I removed the `# Copied from` statement here, just want to make sure that it's okay for you @amyeroberts before merging!

**Reply:** Yep!
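Background on that last exchange: `# Copied from` markers are enforced by the repository's consistency tooling (`make fix-copies`), so adding MusicGen-specific logic to the class body forces the marker's removal. A hypothetical, self-contained illustration of the convention:

```python
class MusicgenAttention:  # stand-in for the real base class, for illustration only
    def forward(self):
        return "eager attention path"


# Copied from transformers.models.bart.modeling_bart.BartSdpaAttention with Bart->Musicgen
class MusicgenSdpaAttention(MusicgenAttention):
    def forward(self):
        # While the marker above is present, the fix-copies check requires this
        # body to match BartSdpaAttention's (modulo the Bart->Musicgen renames).
        # The PR's new fallback breaks that equality, so the marker was removed.
        return super().forward()
```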