Skip to content

Commit

Permalink
Fix MusicGen SDPA (#31208)
Browse files Browse the repository at this point in the history
* fix sdpa musicgen

* make style

* remove the `# Copied from` marker comment from the Musicgen SDPA attention class (its implementation now diverges from Bart's)
  • Loading branch information
ylacombe authored Jun 14, 2024
1 parent 833fc17 commit 43ee585
Showing 1 changed file with 17 additions and 1 deletion.
18 changes: 17 additions & 1 deletion src/transformers/models/musicgen/modeling_musicgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -545,7 +545,6 @@ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query
)


# Copied from transformers.models.bart.modeling_bart.BartSdpaAttention with Bart->Musicgen
class MusicgenSdpaAttention(MusicgenAttention):
def forward(
self,
Expand All @@ -572,6 +571,23 @@ def forward(
output_attentions=output_attentions,
)

if (
attention_mask is not None
and (attention_mask.mean(dim=[1, 2, 3]) <= torch.finfo(attention_mask.dtype).min).any()
):
logger.warning_once(
'`torch.nn.functional.scaled_dot_product_attention` does not support having an empty attention mask. Falling back to the manual attention implementation. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
"Note that this probably happens because `guidance_scale>1` or because you used `get_unconditional_inputs`. See https://github.com/huggingface/transformers/issues/31189 for more information."
)
return super().forward(
hidden_states,
key_value_states=key_value_states,
past_key_value=past_key_value,
attention_mask=attention_mask,
layer_head_mask=layer_head_mask,
output_attentions=output_attentions,
)

# if key_value_states are provided this layer is used as a cross-attention layer
# for the decoder
is_cross_attention = key_value_states is not None
Expand Down

0 comments on commit 43ee585

Please sign in to comment.