Fix MusicGen SDPA #31208
Conversation
@@ -572,6 +572,24 @@ def forward(
                output_attentions=output_attentions,
            )

        # Ignore copy
This `# Ignore copy` statement doesn't work. Any ideas why, @ydshieh?
Hi. `Copied from` and `Ignore copy` only work on a named entity (a function, method, or class). They cannot be used on a block that has no declared name. Either refactor the code or just remove the copy statement
`# Copied from transformers.models.bart.modeling_bart.BartSdpaAttention with Bart->Musicgen`
if no refactoring is possible. (See the placement sketch below.)
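For illustration, a minimal sketch of that placement rule; the class bodies below are placeholders, not the real transformers code, and the check referred to is the repo's copy-consistency script, `utils/check_copies.py`:

```python
# Placeholder classes for illustration only; not the actual transformers code.


class BartSdpaAttention:
    def forward(self, hidden_states):
        return hidden_states


# OK: the statement sits on a named entity (a class), so the copy checker can
# locate the whole definition and compare it against the Bart original.
# Copied from transformers.models.bart.modeling_bart.BartSdpaAttention with Bart->Musicgen
class MusicgenSdpaAttention(BartSdpaAttention):
    def forward(self, hidden_states):
        # Not OK: an "# Ignore copy" attached to an unnamed block inside the
        # method has no declared name for the checker to anchor to.
        # Ignore copy
        if hidden_states is None:
            return hidden_states
        return super().forward(hidden_states)
```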
Thanks @ydshieh, I'll remove the copy statement
Thanks for looking into and addressing this!
        # Ignore copy
        if (
            attention_mask is not None
            and (attention_mask.mean(dim=[1, 2, 3]) <= torch.finfo(attention_mask.dtype).min).any()
What's the reason for using finfo here and not just 0?
The attention mask has already been processed at this point, so masked positions are filled with the -inf (i.e. the minimum value) of the corresponding dtype!
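To make that concrete, a small standalone sketch (shapes and values are made up for illustration): the processed additive mask holds 0.0 for attended positions and `torch.finfo(dtype).min` for masked ones, so comparing the per-sample mean against that minimum flags samples where every position is masked, which a comparison against 0 could not do.

```python
import torch

dtype = torch.float32
min_value = torch.finfo(dtype).min

# (batch, 1, query_len, key_len) additive mask after preprocessing.
attention_mask = torch.zeros(3, 1, 1, 4, dtype=dtype)
attention_mask[1] = min_value           # sample 1 attends to no position at all
attention_mask[2, ..., -1] = min_value  # sample 2 is only partially masked

# The mean is <= finfo(dtype).min only when every entry of a sample is masked.
# Comparing against 0 instead would flag every sample, since the mask only
# contains zeros and large negative values.
fully_masked = attention_mask.mean(dim=[1, 2, 3]) <= min_value
print(fully_masked)  # tensor([False,  True, False])
```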
            return super().forward(
                hidden_states,
                key_value_states=key_value_states,
                past_key_value=past_key_value,
                attention_mask=attention_mask,
                layer_head_mask=layer_head_mask,
                output_attentions=output_attentions,
            )
Rather than falling back, I would just raise an exception. Otherwise this expensive check and fallback forward pass can easily go unnoticed.
Well, `sdpa` and `guidance_scale > 1` are used by default, so we'd raise the error almost every time. Also, even if the model uses eager mode for the cross-attention layers (where the bug happens), it still benefits from the speed-up in the self-attention layers.
Should we find a better way of testing the attention mask? For example, we could raise a warning that this will happen, here and here, and switch the cross-attention SDPA layers to eager layers by default when it happens? (A rough sketch of that idea follows below.)
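A self-contained toy sketch of that warn-and-fall-back idea; the function names and warning text are made up for illustration and are not what this PR implements:

```python
from typing import Optional
import warnings

import torch


def is_fully_masked(attention_mask: torch.Tensor) -> torch.Tensor:
    # Per-sample True when every position of the additive 4D mask is masked out.
    return attention_mask.mean(dim=[1, 2, 3]) <= torch.finfo(attention_mask.dtype).min


def pick_attention_path(attention_mask: Optional[torch.Tensor]) -> str:
    # Decide which implementation to dispatch to, warning when SDPA would produce NaN.
    if attention_mask is not None and is_fully_masked(attention_mask).any():
        warnings.warn(
            "At least one sequence attends to no position at all; using the eager "
            "cross-attention implementation to avoid NaN outputs from SDPA.",
            stacklevel=2,
        )
        return "eager"
    return "sdpa"
```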
OK, I see. I'm a bit concerned about this causing unexpected behaviour, in particular defaulting to eager like this, as it's a bit magic. But since there are other layers which can still use SDPA, this seems like a pragmatic solution.
Let's leave it as-is. If more users raise issues, then we'll have to re-think.
Thanks for digging into this and fixing it!
@@ -545,7 +545,6 @@ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query
        )


# Copied from transformers.models.bart.modeling_bart.BartSdpaAttention with Bart->Musicgen
I removed the statement here; I just want to make sure it's okay with you, @amyeroberts, before merging!
Yep!
Fix #31189 and #30020.
SDPA produces NaN when given a padding mask that attends to no position at all (see pytorch/pytorch#103749 (comment)); a minimal repro is sketched at the end of this description.
In the case of Musicgen, this can happen for two reasons:
There might be a more elegant way to do this, WDYT @amyeroberts?
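For context, a minimal standalone reproduction of the underlying SDPA behaviour; the tensor shapes are arbitrary, and on affected PyTorch versions/backends the fully masked query row comes back as NaN:

```python
import torch
import torch.nn.functional as F

# One batch, one head, two query positions, three key positions.
query = torch.randn(1, 1, 2, 4)
key = torch.randn(1, 1, 3, 4)
value = torch.randn(1, 1, 3, 4)

# Additive mask: the second query position attends to no key at all.
attn_mask = torch.zeros(1, 1, 2, 3)
attn_mask[0, 0, 1, :] = float("-inf")

out = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask)
print(out[0, 0, 0].isnan().any())  # regular row: finite values
print(out[0, 0, 1].isnan().any())  # fully masked row: NaN on affected versions/backends
```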