From 3a00f83f2f8411a97941cb2f93c7f9c11aeedd2b Mon Sep 17 00:00:00 2001
From: JB Lau <1557853+hackyon@users.noreply.github.com>
Date: Fri, 26 Apr 2024 15:59:29 -0400
Subject: [PATCH 01/10] Adding SDPA support for RoBERTa-based models

---
 docs/source/en/perf_infer_gpu_one.md          |   7 +-
 .../bridgetower/modeling_bridgetower.py       |   2 +-
 .../models/camembert/modeling_camembert.py    | 173 ++++++++++++++--
 .../models/roberta/modeling_roberta.py        | 183 +++++++++++++---
 .../modeling_roberta_prelayernorm.py          |   1 -
 .../xlm_roberta/modeling_xlm_roberta.py       | 182 +++++++++++++---
 .../xlm_roberta_xl/modeling_xlm_roberta_xl.py | 195 +++++++++++++++---
 .../test_modeling_xlm_roberta_xl.py           |  79 ++++++-
 utils/check_support_list.py                   |   3 +-
 9 files changed, 733 insertions(+), 92 deletions(-)

diff --git a/docs/source/en/perf_infer_gpu_one.md b/docs/source/en/perf_infer_gpu_one.md
index de49d4427b56..53f4f167255d 100644
--- a/docs/source/en/perf_infer_gpu_one.md
+++ b/docs/source/en/perf_infer_gpu_one.md
@@ -192,6 +192,7 @@ PyTorch's [`torch.nn.functional.scaled_dot_product_attention`](https://pytorch.o
 For now, Transformers supports SDPA inference and training for the following architectures:
 * [Bart](https://huggingface.co/docs/transformers/model_doc/bart#transformers.BartModel)
 * [Bert](https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertModel)
+* [CamemBERT](https://huggingface.co/docs/transformers/model_doc/camembert#transformers.CamembertModel)
 * [Cohere](https://huggingface.co/docs/transformers/model_doc/cohere#transformers.CohereModel)
 * [Dbrx](https://huggingface.co/docs/transformers/model_doc/dbrx#transformers.DbrxModel)
 * [Dpr](https://huggingface.co/docs/transformers/model_doc/dpr#transformers.DprReader)
@@ -217,8 +218,10 @@ For now, Transformers supports SDPA inference and training for the following arc
 * [data2vec_audio](https://huggingface.co/docs/transformers/main/en/model_doc/data2vec#transformers.Data2VecAudioModel)
 * [Sew](https://huggingface.co/docs/transformers/main/en/model_doc/sew#transformers.SEWModel)
 * [UniSpeech](https://huggingface.co/docs/transformers/v4.39.3/en/model_doc/unispeech#transformers.UniSpeechModel)
-* [unispeech_sat](https://huggingface.co/docs/transformers/v4.39.3/en/model_doc/unispeech-sat#transformers.UniSpeechSatModel)
-
+* [unispeech_sat](https://huggingface.co/docs/transformers/v4.39.3/en/model_doc/unispeech-sat#transformers.UniSpeechSatModel)
+* [RoBERTa](https://huggingface.co/docs/transformers/model_doc/roberta#transformers.RobertaModel)
+* [XLM-RoBERTa](https://huggingface.co/docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaModel)
+* [XLM-RoBERTa-XL](https://huggingface.co/docs/transformers/model_doc/xlm-roberta-xl#transformers.XLMRobertaXLModel)
diff --git a/src/transformers/models/bridgetower/modeling_bridgetower.py b/src/transformers/models/bridgetower/modeling_bridgetower.py
index 3fc9f755aab9..04b7fc9b2bdd 100644
--- a/src/transformers/models/bridgetower/modeling_bridgetower.py
+++ b/src/transformers/models/bridgetower/modeling_bridgetower.py
@@ -1066,7 +1066,7 @@ class PreTrainedModel
         for layer, heads in heads_to_prune.items():
             self.encoder.layer[layer].attention.prune_heads(heads)
 
-    # Copied from transformers.models.roberta.modeling_roberta.RobertaModel.forward
+    # Copied from transformers.models.clap.modeling_clap.ClapTextModel.forward
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
diff --git 
a/src/transformers/models/camembert/modeling_camembert.py b/src/transformers/models/camembert/modeling_camembert.py index f399fb3f5cfb..ca5037f1d8c6 100644 --- a/src/transformers/models/camembert/modeling_camembert.py +++ b/src/transformers/models/camembert/modeling_camembert.py @@ -20,10 +20,15 @@ import torch import torch.utils.checkpoint +from packaging import version from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN, gelu +from ...modeling_attn_mask_utils import ( + _prepare_4d_attention_mask_for_sdpa, + _prepare_4d_causal_attention_mask_for_sdpa, +) from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions, @@ -40,6 +45,7 @@ add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, + get_torch_version, logging, replace_return_docstrings, ) @@ -297,6 +303,104 @@ def forward( return outputs +# Copied from transformers.models.roberta.modeling_roberta.RobertaSdpaSelfAttention with Roberta->Camembert +class CamembertSdpaSelfAttention(CamembertSelfAttention): + def __init__(self, config, position_embedding_type=None): + super().__init__(config, position_embedding_type=position_embedding_type) + self.dropout_prob = config.attention_probs_dropout_prob + self.require_contiguous_qkv = version.parse(get_torch_version()) < version.parse("2.2.0") + + # Adapted from CamembertSelfAttention + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: + if self.position_embedding_type != "absolute" or output_attentions or head_mask is not None: + # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once implemented. + logger.warning_once( + "CamembertSdpaSelfAttention is used but `torch.nn.functional.scaled_dot_product_attention` does not support " + "non-absolute `position_embedding_type` or `output_attentions=True` or `head_mask`. Falling back to " + "the manual attention implementation, but specifying the manual implementation will be required from " + "Transformers version v5.0.0 onwards. This warning can be removed using the argument " + '`attn_implementation="eager"` when loading the model.' + ) + return super().forward( + hidden_states, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + + bsz, tgt_len, _ = hidden_states.size() + + query_layer = self.transpose_for_scores(self.query(hidden_states)) + + # If this is instantiated as a cross-attention module, the keys and values come from an encoder; the attention + # mask needs to be such that the encoder's padding tokens are not attended to. 
+ is_cross_attention = encoder_hidden_states is not None + + current_states = encoder_hidden_states if is_cross_attention else hidden_states + attention_mask = encoder_attention_mask if is_cross_attention else attention_mask + + # Check `seq_length` of `past_key_value` == `len(current_states)` to support prefix tuning + if is_cross_attention and past_key_value and past_key_value[0].shape[2] == current_states.shape[1]: + key_layer, value_layer = past_key_value + else: + key_layer = self.transpose_for_scores(self.key(current_states)) + value_layer = self.transpose_for_scores(self.value(current_states)) + if past_key_value is not None and not is_cross_attention: + key_layer = torch.cat([past_key_value[0], key_layer], dim=2) + value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_layer, value_layer) + + # SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom + # attn_mask, so we need to call `.contiguous()` here. This was fixed in torch==2.2.0. + # Reference: https://github.com/pytorch/pytorch/issues/112577 + if self.require_contiguous_qkv and query_layer.device.type == "cuda" and attention_mask is not None: + query_layer = query_layer.contiguous() + key_layer = key_layer.contiguous() + value_layer = value_layer.contiguous() + + # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal + # mask in case tgt_len == 1. 
+ is_causal = self.is_decoder and attention_mask is None and tgt_len > 1 + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_layer, + key_layer, + value_layer, + attn_mask=attention_mask, + dropout_p=self.dropout_prob if self.training else 0.0, + is_causal=is_causal, + ) + + attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.reshape(bsz, tgt_len, self.all_head_size) + + outputs = (attn_output,) + if self.is_decoder: + outputs = outputs + (past_key_value,) + return outputs + + # Copied from transformers.models.roberta.modeling_roberta.RobertaSelfOutput with Roberta->Camembert class CamembertSelfOutput(nn.Module): def __init__(self, config): @@ -314,6 +418,7 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to CAMEMBERT_SELF_ATTENTION_CLASSES = { "eager": CamembertSelfAttention, + "sdpa": CamembertSdpaSelfAttention, } @@ -606,6 +711,7 @@ class CamembertPreTrainedModel(PreTrainedModel): config_class = CamembertConfig base_model_prefix = "roberta" supports_gradient_checkpointing = True + _supports_sdpa = True # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights def _init_weights(self, module): @@ -752,7 +858,7 @@ class CamembertModel(CamembertPreTrainedModel): _no_split_modules = [] - # Copied from transformers.models.clap.modeling_clap.ClapTextModel.__init__ with ClapText->Camembert + # Copied from transformers.models.roberta.modeling_roberta.RobertaModel.__init__ with Roberta->Camembert def __init__(self, config, add_pooling_layer=True): super().__init__(config) self.config = config @@ -762,6 +868,9 @@ def __init__(self, config, add_pooling_layer=True): self.pooler = CamembertPooler(config) if add_pooling_layer else None + self.attn_implementation = config._attn_implementation + self.position_embedding_type = config.position_embedding_type + # Initialize weights and apply final processing self.post_init() @@ -785,7 +894,7 @@ class PreTrainedModel output_type=BaseModelOutputWithPoolingAndCrossAttentions, config_class=_CONFIG_FOR_DOC, ) - # Copied from transformers.models.clap.modeling_clap.ClapTextModel.forward + # Copied from transformers.models.roberta.modeling_roberta.RobertaModel.forward def forward( self, input_ids: Optional[torch.Tensor] = None, @@ -849,9 +958,6 @@ def forward( # past_key_values_length past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 - if attention_mask is None: - attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) - if token_type_ids is None: if hasattr(self.embeddings, "token_type_ids"): buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length] @@ -860,9 +966,43 @@ def forward( else: token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) - # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] - # ourselves in which case we just need to make it broadcastable to all heads. 
- extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape) + embedding_output = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + token_type_ids=token_type_ids, + inputs_embeds=inputs_embeds, + past_key_values_length=past_key_values_length, + ) + + if attention_mask is None: + attention_mask = torch.ones((batch_size, seq_length + past_key_values_length), device=device) + + use_sdpa_attention_masks = ( + self.attn_implementation == "sdpa" + and self.position_embedding_type == "absolute" + and head_mask is None + and not output_attentions + ) + + # Expand the attention mask + if use_sdpa_attention_masks: + # Expand the attention mask for SDPA. + # [bsz, seq_len] -> [bsz, 1, seq_len, seq_len] + if self.config.is_decoder: + extended_attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( + attention_mask, + input_shape, + embedding_output, + past_key_values_length, + ) + else: + extended_attention_mask = _prepare_4d_attention_mask_for_sdpa( + attention_mask, embedding_output.dtype, tgt_len=seq_length + ) + else: + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape) # If a 2D or 3D attention mask is provided for the cross-attention # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] @@ -871,7 +1011,15 @@ def forward( encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) if encoder_attention_mask is None: encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) - encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) + + if use_sdpa_attention_masks: + # Expand the attention mask for SDPA. 
+ # [bsz, seq_len] -> [bsz, 1, seq_len, seq_len] + encoder_extended_attention_mask = _prepare_4d_attention_mask_for_sdpa( + encoder_attention_mask, embedding_output.dtype, tgt_len=seq_length + ) + else: + encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) else: encoder_extended_attention_mask = None @@ -882,13 +1030,6 @@ def forward( # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) - embedding_output = self.embeddings( - input_ids=input_ids, - position_ids=position_ids, - token_type_ids=token_type_ids, - inputs_embeds=inputs_embeds, - past_key_values_length=past_key_values_length, - ) encoder_outputs = self.encoder( embedding_output, attention_mask=extended_attention_mask, diff --git a/src/transformers/models/roberta/modeling_roberta.py b/src/transformers/models/roberta/modeling_roberta.py index 640139212081..07cfce730da7 100644 --- a/src/transformers/models/roberta/modeling_roberta.py +++ b/src/transformers/models/roberta/modeling_roberta.py @@ -20,10 +20,15 @@ import torch import torch.utils.checkpoint +from packaging import version from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN, gelu +from ...modeling_attn_mask_utils import ( + _prepare_4d_attention_mask_for_sdpa, + _prepare_4d_causal_attention_mask_for_sdpa, +) from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions, @@ -40,6 +45,7 @@ add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, + get_torch_version, logging, replace_return_docstrings, ) @@ -279,6 +285,104 @@ def forward( return outputs +# Copied from transformers.models.bert.modeling_bert.BertSdpaSelfAttention with Bert->Roberta +class RobertaSdpaSelfAttention(RobertaSelfAttention): + def __init__(self, config, position_embedding_type=None): + super().__init__(config, position_embedding_type=position_embedding_type) + self.dropout_prob = config.attention_probs_dropout_prob + self.require_contiguous_qkv = version.parse(get_torch_version()) < version.parse("2.2.0") + + # Adapted from RobertaSelfAttention + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: + if self.position_embedding_type != "absolute" or output_attentions or head_mask is not None: + # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once implemented. + logger.warning_once( + "RobertaSdpaSelfAttention is used but `torch.nn.functional.scaled_dot_product_attention` does not support " + "non-absolute `position_embedding_type` or `output_attentions=True` or `head_mask`. Falling back to " + "the manual attention implementation, but specifying the manual implementation will be required from " + "Transformers version v5.0.0 onwards. This warning can be removed using the argument " + '`attn_implementation="eager"` when loading the model.' 
+ ) + return super().forward( + hidden_states, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + + bsz, tgt_len, _ = hidden_states.size() + + query_layer = self.transpose_for_scores(self.query(hidden_states)) + + # If this is instantiated as a cross-attention module, the keys and values come from an encoder; the attention + # mask needs to be such that the encoder's padding tokens are not attended to. + is_cross_attention = encoder_hidden_states is not None + + current_states = encoder_hidden_states if is_cross_attention else hidden_states + attention_mask = encoder_attention_mask if is_cross_attention else attention_mask + + # Check `seq_length` of `past_key_value` == `len(current_states)` to support prefix tuning + if is_cross_attention and past_key_value and past_key_value[0].shape[2] == current_states.shape[1]: + key_layer, value_layer = past_key_value + else: + key_layer = self.transpose_for_scores(self.key(current_states)) + value_layer = self.transpose_for_scores(self.value(current_states)) + if past_key_value is not None and not is_cross_attention: + key_layer = torch.cat([past_key_value[0], key_layer], dim=2) + value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_layer, value_layer) + + # SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom + # attn_mask, so we need to call `.contiguous()` here. This was fixed in torch==2.2.0. + # Reference: https://github.com/pytorch/pytorch/issues/112577 + if self.require_contiguous_qkv and query_layer.device.type == "cuda" and attention_mask is not None: + query_layer = query_layer.contiguous() + key_layer = key_layer.contiguous() + value_layer = value_layer.contiguous() + + # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal + # mask in case tgt_len == 1. 
+ is_causal = self.is_decoder and attention_mask is None and tgt_len > 1 + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_layer, + key_layer, + value_layer, + attn_mask=attention_mask, + dropout_p=self.dropout_prob if self.training else 0.0, + is_causal=is_causal, + ) + + attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.reshape(bsz, tgt_len, self.all_head_size) + + outputs = (attn_output,) + if self.is_decoder: + outputs = outputs + (past_key_value,) + return outputs + + # Copied from transformers.models.bert.modeling_bert.BertSelfOutput class RobertaSelfOutput(nn.Module): def __init__(self, config): @@ -296,6 +400,7 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to ROBERTA_SELF_ATTENTION_CLASSES = { "eager": RobertaSelfAttention, + "sdpa": RobertaSdpaSelfAttention, } @@ -588,7 +693,8 @@ class RobertaPreTrainedModel(PreTrainedModel): config_class = RobertaConfig base_model_prefix = "roberta" supports_gradient_checkpointing = True - _no_split_modules = ["RobertaEmbeddings", "RobertaSelfAttention"] + _no_split_modules = ["RobertaEmbeddings", "RobertaSelfAttention", "RobertaSdpaSelfAttention"] + _supports_sdpa = True # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights def _init_weights(self, module): @@ -679,23 +785,20 @@ def _init_weights(self, module): "The bare RoBERTa Model transformer outputting raw hidden-states without any specific head on top.", ROBERTA_START_DOCSTRING, ) +# Copied from transformers.models.bert.modeling_bert.BertModel with Bert->Roberta, BERT->ROBERTA class RobertaModel(RobertaPreTrainedModel): """ The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of - cross-attention is added between the self-attention layers, following the architecture described in *Attention is - all you need*_ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz - Kaiser and Illia Polosukhin. + cross-attention is added between the self-attention layers, following the architecture described in [Attention is + all you need](https://arxiv.org/abs/1706.03762) by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, + Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin. To behave as an decoder the model needs to be initialized with the `is_decoder` argument of the configuration set to `True`. To be used in a Seq2Seq model, the model needs to initialized with both `is_decoder` argument and `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass. - - .. 
_*Attention is all you need*: https://arxiv.org/abs/1706.03762 - """ - # Copied from transformers.models.clap.modeling_clap.ClapTextModel.__init__ with ClapText->Roberta def __init__(self, config, add_pooling_layer=True): super().__init__(config) self.config = config @@ -705,6 +808,9 @@ def __init__(self, config, add_pooling_layer=True): self.pooler = RobertaPooler(config) if add_pooling_layer else None + self.attn_implementation = config._attn_implementation + self.position_embedding_type = config.position_embedding_type + # Initialize weights and apply final processing self.post_init() @@ -728,7 +834,6 @@ class PreTrainedModel output_type=BaseModelOutputWithPoolingAndCrossAttentions, config_class=_CONFIG_FOR_DOC, ) - # Copied from transformers.models.clap.modeling_clap.ClapTextModel.forward def forward( self, input_ids: Optional[torch.Tensor] = None, @@ -792,9 +897,6 @@ def forward( # past_key_values_length past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 - if attention_mask is None: - attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) - if token_type_ids is None: if hasattr(self.embeddings, "token_type_ids"): buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length] @@ -803,9 +905,43 @@ def forward( else: token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) - # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] - # ourselves in which case we just need to make it broadcastable to all heads. - extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape) + embedding_output = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + token_type_ids=token_type_ids, + inputs_embeds=inputs_embeds, + past_key_values_length=past_key_values_length, + ) + + if attention_mask is None: + attention_mask = torch.ones((batch_size, seq_length + past_key_values_length), device=device) + + use_sdpa_attention_masks = ( + self.attn_implementation == "sdpa" + and self.position_embedding_type == "absolute" + and head_mask is None + and not output_attentions + ) + + # Expand the attention mask + if use_sdpa_attention_masks: + # Expand the attention mask for SDPA. + # [bsz, seq_len] -> [bsz, 1, seq_len, seq_len] + if self.config.is_decoder: + extended_attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( + attention_mask, + input_shape, + embedding_output, + past_key_values_length, + ) + else: + extended_attention_mask = _prepare_4d_attention_mask_for_sdpa( + attention_mask, embedding_output.dtype, tgt_len=seq_length + ) + else: + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape) # If a 2D or 3D attention mask is provided for the cross-attention # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] @@ -814,7 +950,15 @@ def forward( encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) if encoder_attention_mask is None: encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) - encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) + + if use_sdpa_attention_masks: + # Expand the attention mask for SDPA. 
+ # [bsz, seq_len] -> [bsz, 1, seq_len, seq_len] + encoder_extended_attention_mask = _prepare_4d_attention_mask_for_sdpa( + encoder_attention_mask, embedding_output.dtype, tgt_len=seq_length + ) + else: + encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) else: encoder_extended_attention_mask = None @@ -825,13 +969,6 @@ def forward( # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) - embedding_output = self.embeddings( - input_ids=input_ids, - position_ids=position_ids, - token_type_ids=token_type_ids, - inputs_embeds=inputs_embeds, - past_key_values_length=past_key_values_length, - ) encoder_outputs = self.encoder( embedding_output, attention_mask=extended_attention_mask, diff --git a/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py b/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py index 468cb1a243ca..df328223e913 100644 --- a/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +++ b/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py @@ -572,7 +572,6 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return pooled_output -# Copied from transformers.models.roberta.modeling_roberta.RobertaPreTrainedModel with Roberta->RobertaPreLayerNorm,roberta->roberta_prelayernorm class RobertaPreLayerNormPreTrainedModel(PreTrainedModel): """ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained diff --git a/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py b/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py index 48c6898811d1..b5225c1b10f4 100644 --- a/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py +++ b/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py @@ -20,10 +20,15 @@ import torch import torch.utils.checkpoint +from packaging import version from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN, gelu +from ...modeling_attn_mask_utils import ( + _prepare_4d_attention_mask_for_sdpa, + _prepare_4d_causal_attention_mask_for_sdpa, +) from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions, @@ -40,6 +45,7 @@ add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, + get_torch_version, logging, replace_return_docstrings, ) @@ -280,6 +286,104 @@ def forward( return outputs +# Copied from transformers.models.roberta.modeling_roberta.RobertaSdpaSelfAttention with Roberta->XLMRoberta +class XLMRobertaSdpaSelfAttention(XLMRobertaSelfAttention): + def __init__(self, config, position_embedding_type=None): + super().__init__(config, position_embedding_type=position_embedding_type) + self.dropout_prob = config.attention_probs_dropout_prob + self.require_contiguous_qkv = version.parse(get_torch_version()) < version.parse("2.2.0") + + # Adapted from XLMRobertaSelfAttention + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + output_attentions: Optional[bool] = False, + ) -> 
Tuple[torch.Tensor]: + if self.position_embedding_type != "absolute" or output_attentions or head_mask is not None: + # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once implemented. + logger.warning_once( + "XLMRobertaSdpaSelfAttention is used but `torch.nn.functional.scaled_dot_product_attention` does not support " + "non-absolute `position_embedding_type` or `output_attentions=True` or `head_mask`. Falling back to " + "the manual attention implementation, but specifying the manual implementation will be required from " + "Transformers version v5.0.0 onwards. This warning can be removed using the argument " + '`attn_implementation="eager"` when loading the model.' + ) + return super().forward( + hidden_states, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + + bsz, tgt_len, _ = hidden_states.size() + + query_layer = self.transpose_for_scores(self.query(hidden_states)) + + # If this is instantiated as a cross-attention module, the keys and values come from an encoder; the attention + # mask needs to be such that the encoder's padding tokens are not attended to. + is_cross_attention = encoder_hidden_states is not None + + current_states = encoder_hidden_states if is_cross_attention else hidden_states + attention_mask = encoder_attention_mask if is_cross_attention else attention_mask + + # Check `seq_length` of `past_key_value` == `len(current_states)` to support prefix tuning + if is_cross_attention and past_key_value and past_key_value[0].shape[2] == current_states.shape[1]: + key_layer, value_layer = past_key_value + else: + key_layer = self.transpose_for_scores(self.key(current_states)) + value_layer = self.transpose_for_scores(self.value(current_states)) + if past_key_value is not None and not is_cross_attention: + key_layer = torch.cat([past_key_value[0], key_layer], dim=2) + value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_layer, value_layer) + + # SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom + # attn_mask, so we need to call `.contiguous()` here. This was fixed in torch==2.2.0. + # Reference: https://github.com/pytorch/pytorch/issues/112577 + if self.require_contiguous_qkv and query_layer.device.type == "cuda" and attention_mask is not None: + query_layer = query_layer.contiguous() + key_layer = key_layer.contiguous() + value_layer = value_layer.contiguous() + + # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal + # mask in case tgt_len == 1. 
+ is_causal = self.is_decoder and attention_mask is None and tgt_len > 1 + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_layer, + key_layer, + value_layer, + attn_mask=attention_mask, + dropout_p=self.dropout_prob if self.training else 0.0, + is_causal=is_causal, + ) + + attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.reshape(bsz, tgt_len, self.all_head_size) + + outputs = (attn_output,) + if self.is_decoder: + outputs = outputs + (past_key_value,) + return outputs + + # Copied from transformers.models.roberta.modeling_roberta.RobertaSelfOutput with Roberta->XLMRoberta class XLMRobertaSelfOutput(nn.Module): def __init__(self, config): @@ -297,6 +401,7 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to XLM_ROBERTA_SELF_ATTENTION_CLASSES = { "eager": XLMRobertaSelfAttention, + "sdpa": XLMRobertaSdpaSelfAttention, } @@ -590,7 +695,8 @@ class XLMRobertaPreTrainedModel(PreTrainedModel): config_class = XLMRobertaConfig base_model_prefix = "roberta" supports_gradient_checkpointing = True - _no_split_modules = ["XLMRobertaEmbeddings", "XLMRobertaSelfAttention"] + _no_split_modules = ["XLMRobertaEmbeddings", "XLMRobertaSelfAttention", "XLMRobertaSdpaSelfAttention"] + _supports_sdpa = True # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights def _init_weights(self, module): @@ -685,19 +791,15 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel): """ The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of - cross-attention is added between the self-attention layers, following the architecture described in *Attention is - all you need*_ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz - Kaiser and Illia Polosukhin. + cross-attention is added between the self-attention layers, following the architecture described in [Attention is + all you need](https://arxiv.org/abs/1706.03762) by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, + Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin. To behave as an decoder the model needs to be initialized with the `is_decoder` argument of the configuration set to `True`. To be used in a Seq2Seq model, the model needs to initialized with both `is_decoder` argument and `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass. - - .. 
_*Attention is all you need*: https://arxiv.org/abs/1706.03762 - """ - # Copied from transformers.models.clap.modeling_clap.ClapTextModel.__init__ with ClapText->XLMRoberta def __init__(self, config, add_pooling_layer=True): super().__init__(config) self.config = config @@ -707,6 +809,9 @@ def __init__(self, config, add_pooling_layer=True): self.pooler = XLMRobertaPooler(config) if add_pooling_layer else None + self.attn_implementation = config._attn_implementation + self.position_embedding_type = config.position_embedding_type + # Initialize weights and apply final processing self.post_init() @@ -730,7 +835,6 @@ class PreTrainedModel output_type=BaseModelOutputWithPoolingAndCrossAttentions, config_class=_CONFIG_FOR_DOC, ) - # Copied from transformers.models.clap.modeling_clap.ClapTextModel.forward def forward( self, input_ids: Optional[torch.Tensor] = None, @@ -794,9 +898,6 @@ def forward( # past_key_values_length past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 - if attention_mask is None: - attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) - if token_type_ids is None: if hasattr(self.embeddings, "token_type_ids"): buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length] @@ -805,9 +906,43 @@ def forward( else: token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) - # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] - # ourselves in which case we just need to make it broadcastable to all heads. - extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape) + embedding_output = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + token_type_ids=token_type_ids, + inputs_embeds=inputs_embeds, + past_key_values_length=past_key_values_length, + ) + + if attention_mask is None: + attention_mask = torch.ones((batch_size, seq_length + past_key_values_length), device=device) + + use_sdpa_attention_masks = ( + self.attn_implementation == "sdpa" + and self.position_embedding_type == "absolute" + and head_mask is None + and not output_attentions + ) + + # Expand the attention mask + if use_sdpa_attention_masks: + # Expand the attention mask for SDPA. + # [bsz, seq_len] -> [bsz, 1, seq_len, seq_len] + if self.config.is_decoder: + extended_attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( + attention_mask, + input_shape, + embedding_output, + past_key_values_length, + ) + else: + extended_attention_mask = _prepare_4d_attention_mask_for_sdpa( + attention_mask, embedding_output.dtype, tgt_len=seq_length + ) + else: + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape) # If a 2D or 3D attention mask is provided for the cross-attention # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] @@ -816,7 +951,15 @@ def forward( encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) if encoder_attention_mask is None: encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) - encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) + + if use_sdpa_attention_masks: + # Expand the attention mask for SDPA. 
+ # [bsz, seq_len] -> [bsz, 1, seq_len, seq_len] + encoder_extended_attention_mask = _prepare_4d_attention_mask_for_sdpa( + encoder_attention_mask, embedding_output.dtype, tgt_len=seq_length + ) + else: + encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) else: encoder_extended_attention_mask = None @@ -827,13 +970,6 @@ def forward( # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) - embedding_output = self.embeddings( - input_ids=input_ids, - position_ids=position_ids, - token_type_ids=token_type_ids, - inputs_embeds=inputs_embeds, - past_key_values_length=past_key_values_length, - ) encoder_outputs = self.encoder( embedding_output, attention_mask=extended_attention_mask, diff --git a/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py b/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py index d8994e335b12..b0e9b5148f1d 100644 --- a/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +++ b/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py @@ -19,10 +19,15 @@ import torch import torch.utils.checkpoint +from packaging import version from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN, gelu +from ...modeling_attn_mask_utils import ( + _prepare_4d_attention_mask_for_sdpa, + _prepare_4d_causal_attention_mask_for_sdpa, +) from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions, @@ -39,6 +44,7 @@ add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, + get_torch_version, logging, replace_return_docstrings, ) @@ -277,6 +283,104 @@ def forward( return outputs +# Copied from transformers.models.bert.modeling_bert.BertSdpaSelfAttention with Bert->XLMRobertaXL +class XLMRobertaXLSdpaSelfAttention(XLMRobertaXLSelfAttention): + def __init__(self, config, position_embedding_type=None): + super().__init__(config, position_embedding_type=position_embedding_type) + self.dropout_prob = config.attention_probs_dropout_prob + self.require_contiguous_qkv = version.parse(get_torch_version()) < version.parse("2.2.0") + + # Adapted from XLMRobertaXLSelfAttention + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: + if self.position_embedding_type != "absolute" or output_attentions or head_mask is not None: + # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once implemented. + logger.warning_once( + "XLMRobertaXLSdpaSelfAttention is used but `torch.nn.functional.scaled_dot_product_attention` does not support " + "non-absolute `position_embedding_type` or `output_attentions=True` or `head_mask`. Falling back to " + "the manual attention implementation, but specifying the manual implementation will be required from " + "Transformers version v5.0.0 onwards. This warning can be removed using the argument " + '`attn_implementation="eager"` when loading the model.' 
+ ) + return super().forward( + hidden_states, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + + bsz, tgt_len, _ = hidden_states.size() + + query_layer = self.transpose_for_scores(self.query(hidden_states)) + + # If this is instantiated as a cross-attention module, the keys and values come from an encoder; the attention + # mask needs to be such that the encoder's padding tokens are not attended to. + is_cross_attention = encoder_hidden_states is not None + + current_states = encoder_hidden_states if is_cross_attention else hidden_states + attention_mask = encoder_attention_mask if is_cross_attention else attention_mask + + # Check `seq_length` of `past_key_value` == `len(current_states)` to support prefix tuning + if is_cross_attention and past_key_value and past_key_value[0].shape[2] == current_states.shape[1]: + key_layer, value_layer = past_key_value + else: + key_layer = self.transpose_for_scores(self.key(current_states)) + value_layer = self.transpose_for_scores(self.value(current_states)) + if past_key_value is not None and not is_cross_attention: + key_layer = torch.cat([past_key_value[0], key_layer], dim=2) + value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_layer, value_layer) + + # SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom + # attn_mask, so we need to call `.contiguous()` here. This was fixed in torch==2.2.0. + # Reference: https://github.com/pytorch/pytorch/issues/112577 + if self.require_contiguous_qkv and query_layer.device.type == "cuda" and attention_mask is not None: + query_layer = query_layer.contiguous() + key_layer = key_layer.contiguous() + value_layer = value_layer.contiguous() + + # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal + # mask in case tgt_len == 1. 
+ is_causal = self.is_decoder and attention_mask is None and tgt_len > 1 + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_layer, + key_layer, + value_layer, + attn_mask=attention_mask, + dropout_p=self.dropout_prob if self.training else 0.0, + is_causal=is_causal, + ) + + attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.reshape(bsz, tgt_len, self.all_head_size) + + outputs = (attn_output,) + if self.is_decoder: + outputs = outputs + (past_key_value,) + return outputs + + class XLMRobertaXLSelfOutput(nn.Module): def __init__(self, config): super().__init__() @@ -290,11 +394,19 @@ def forward(self, hidden_states, input_tensor): return hidden_states +XLMROBERTAXL_SELF_ATTENTION_CLASSES = { + "eager": XLMRobertaXLSelfAttention, + "sdpa": XLMRobertaXLSdpaSelfAttention, +} + + class XLMRobertaXLAttention(nn.Module): def __init__(self, config, position_embedding_type=None): super().__init__() self.self_attn_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - self.self = XLMRobertaXLSelfAttention(config, position_embedding_type=position_embedding_type) + self.self = XLMROBERTAXL_SELF_ATTENTION_CLASSES[config._attn_implementation]( + config, position_embedding_type=position_embedding_type + ) self.output = XLMRobertaXLSelfOutput(config) self.pruned_heads = set() @@ -575,6 +687,7 @@ class XLMRobertaXLPreTrainedModel(PreTrainedModel): config_class = XLMRobertaXLConfig base_model_prefix = "roberta" + _supports_sdpa = True # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights def _init_weights(self, module): @@ -653,18 +766,20 @@ def _init_weights(self, module): "The bare XLM-RoBERTa-XL Model transformer outputting raw hidden-states without any specific head on top.", XLM_ROBERTA_XL_START_DOCSTRING, ) +# Copied from transformers.models.bert.modeling_bert.BertModel with Bert->XLMRobertaXL, BERT->XLM_ROBERTA_XL class XLMRobertaXLModel(XLMRobertaXLPreTrainedModel): """ + The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of - cross-attention is added between the self-attention layers, following the architecture described in *Attention is - all you need*_ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz - Kaiser and Illia Polosukhin. To behave as an decoder the model needs to be initialized with the `is_decoder` - argument of the configuration set to `True`. To be used in a Seq2Seq model, the model needs to initialized with - both `is_decoder` argument and `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as - an input to the forward pass. .. _*Attention is all you need*: https://arxiv.org/abs/1706.03762 + cross-attention is added between the self-attention layers, following the architecture described in [Attention is + all you need](https://arxiv.org/abs/1706.03762) by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, + Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin. + + To behave as an decoder the model needs to be initialized with the `is_decoder` argument of the configuration set + to `True`. To be used in a Seq2Seq model, the model needs to initialized with both `is_decoder` argument and + `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass. 
""" - # Copied from transformers.models.clap.modeling_clap.ClapTextModel.__init__ with ClapText->XLMRobertaXL def __init__(self, config, add_pooling_layer=True): super().__init__(config) self.config = config @@ -674,6 +789,9 @@ def __init__(self, config, add_pooling_layer=True): self.pooler = XLMRobertaXLPooler(config) if add_pooling_layer else None + self.attn_implementation = config._attn_implementation + self.position_embedding_type = config.position_embedding_type + # Initialize weights and apply final processing self.post_init() @@ -697,7 +815,6 @@ class PreTrainedModel output_type=BaseModelOutputWithPoolingAndCrossAttentions, config_class=_CONFIG_FOR_DOC, ) - # Copied from transformers.models.clap.modeling_clap.ClapTextModel.forward def forward( self, input_ids: Optional[torch.Tensor] = None, @@ -761,9 +878,6 @@ def forward( # past_key_values_length past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 - if attention_mask is None: - attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) - if token_type_ids is None: if hasattr(self.embeddings, "token_type_ids"): buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length] @@ -772,9 +886,43 @@ def forward( else: token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) - # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] - # ourselves in which case we just need to make it broadcastable to all heads. - extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape) + embedding_output = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + token_type_ids=token_type_ids, + inputs_embeds=inputs_embeds, + past_key_values_length=past_key_values_length, + ) + + if attention_mask is None: + attention_mask = torch.ones((batch_size, seq_length + past_key_values_length), device=device) + + use_sdpa_attention_masks = ( + self.attn_implementation == "sdpa" + and self.position_embedding_type == "absolute" + and head_mask is None + and not output_attentions + ) + + # Expand the attention mask + if use_sdpa_attention_masks: + # Expand the attention mask for SDPA. + # [bsz, seq_len] -> [bsz, 1, seq_len, seq_len] + if self.config.is_decoder: + extended_attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( + attention_mask, + input_shape, + embedding_output, + past_key_values_length, + ) + else: + extended_attention_mask = _prepare_4d_attention_mask_for_sdpa( + attention_mask, embedding_output.dtype, tgt_len=seq_length + ) + else: + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape) # If a 2D or 3D attention mask is provided for the cross-attention # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] @@ -783,7 +931,15 @@ def forward( encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) if encoder_attention_mask is None: encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) - encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) + + if use_sdpa_attention_masks: + # Expand the attention mask for SDPA. 
+ # [bsz, seq_len] -> [bsz, 1, seq_len, seq_len] + encoder_extended_attention_mask = _prepare_4d_attention_mask_for_sdpa( + encoder_attention_mask, embedding_output.dtype, tgt_len=seq_length + ) + else: + encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) else: encoder_extended_attention_mask = None @@ -794,13 +950,6 @@ def forward( # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) - embedding_output = self.embeddings( - input_ids=input_ids, - position_ids=position_ids, - token_type_ids=token_type_ids, - inputs_embeds=inputs_embeds, - past_key_values_length=past_key_values_length, - ) encoder_outputs = self.encoder( embedding_output, attention_mask=extended_attention_mask, diff --git a/tests/models/xlm_roberta_xl/test_modeling_xlm_roberta_xl.py b/tests/models/xlm_roberta_xl/test_modeling_xlm_roberta_xl.py index 828d6a02a6a3..c217bb2c891d 100644 --- a/tests/models/xlm_roberta_xl/test_modeling_xlm_roberta_xl.py +++ b/tests/models/xlm_roberta_xl/test_modeling_xlm_roberta_xl.py @@ -14,10 +14,11 @@ # limitations under the License. +import tempfile import unittest -from transformers import XLMRobertaXLConfig, is_torch_available -from transformers.testing_utils import require_torch, slow, torch_device +from transformers import XLMRobertaXLConfig, is_torch_available, set_seed +from transformers.testing_utils import require_torch, require_torch_sdpa, slow, torch_device from ...generation.test_utils import GenerationTesterMixin from ...test_configuration_common import ConfigTester @@ -515,6 +516,80 @@ def test_create_position_ids_from_inputs_embeds(self): self.assertEqual(position_ids.shape, expected_positions.shape) self.assertTrue(torch.all(torch.eq(position_ids, expected_positions))) + # This test was copied from the common test_eager_matches_sdpa_generate(), but without low_cpu_mem_usage=True. + # TODO: Remove this and use the parent method (in common tests) once XLM RoBERTa XL supports low_cpu_mem_usage=True. 
+ @require_torch_sdpa + @slow + def test_eager_matches_sdpa_generate(self): + set_seed(0) + max_new_tokens = 30 + + if len(self.all_generative_model_classes) == 0: + self.skipTest(f"{self.__class__.__name__} tests a model that does support generate: skipping this test") + + for model_class in self.all_generative_model_classes: + if not model_class._supports_sdpa: + self.skipTest(f"{model_class.__name__} does not support SDPA") + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + dummy_input = inputs_dict[model_class.main_input_name] + if dummy_input.dtype in [torch.float32, torch.bfloat16]: + dummy_input = dummy_input.to(torch.float16) + + # make sure that all models have enough positions for generation + if hasattr(config, "max_position_embeddings"): + config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1 + + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + + dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input)) + + model_sdpa = model_class.from_pretrained( + tmpdirname, + torch_dtype=torch.float16, + # low_cpu_mem_usage=True, + ).to(torch_device) + + self.assertTrue(model_sdpa.config._attn_implementation == "sdpa") + + model_eager = model_class.from_pretrained( + tmpdirname, + torch_dtype=torch.float16, + # low_cpu_mem_usage=True, + attn_implementation="eager", + ).to(torch_device) + + self.assertTrue(model_eager.config._attn_implementation == "eager") + + for name, submodule in model_eager.named_modules(): + class_name = submodule.__class__.__name__ + if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name: + raise ValueError("The eager model should not have SDPA attention layers") + + has_sdpa = False + for name, submodule in model_sdpa.named_modules(): + class_name = submodule.__class__.__name__ + if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name: + has_sdpa = True + break + if not has_sdpa: + raise ValueError("The SDPA model should have SDPA attention layers") + + # Just test that a large cache works as expected + res_eager = model_eager.generate( + dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=max_new_tokens, do_sample=False + ) + + res_sdpa = model_sdpa.generate( + dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=max_new_tokens, do_sample=False + ) + + self.assertTrue(torch.allclose(res_eager, res_sdpa)) + @require_torch class XLMRobertaModelXLIntegrationTest(unittest.TestCase): diff --git a/utils/check_support_list.py b/utils/check_support_list.py index f6aaa2bb67dc..c670dc41ad2b 100644 --- a/utils/check_support_list.py +++ b/utils/check_support_list.py @@ -69,6 +69,7 @@ def check_sdpa_support_list(): "For now, Transformers supports SDPA inference and training for the following architectures:" )[1] doctext = doctext.split("Note that FlashAttention can only be used for models using the")[0] + doctext = doctext.lower() patterns = glob(os.path.join(REPO_PATH, "src/transformers/models/**/modeling_*.py")) patterns_tf = glob(os.path.join(REPO_PATH, "src/transformers/models/**/modeling_tf_*.py")) @@ -84,7 +85,7 @@ def check_sdpa_support_list(): archs_supporting_sdpa.append(model_name) for arch in archs_supporting_sdpa: - if arch not in doctext: + if not any(term in doctext for term in [arch, arch.replace("_", "-"), arch.replace("_", " ")]): raise ValueError( f"{arch} should be in listed in the SDPA documentation but is not. Please update the documentation." 
) From 715f2149cda73fa53ef105bb659cdec4f53ffab0 Mon Sep 17 00:00:00 2001 From: fxmarty <9808326+fxmarty@users.noreply.github.com> Date: Mon, 24 Jun 2024 10:34:52 +0200 Subject: [PATCH 02/10] add not is_cross_attention --- src/transformers/models/camembert/modeling_camembert.py | 2 +- src/transformers/models/roberta/modeling_roberta.py | 2 +- src/transformers/models/xlm_roberta/modeling_xlm_roberta.py | 2 +- .../models/xlm_roberta_xl/modeling_xlm_roberta_xl.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/transformers/models/camembert/modeling_camembert.py b/src/transformers/models/camembert/modeling_camembert.py index ca5037f1d8c6..9e39ac4e0074 100644 --- a/src/transformers/models/camembert/modeling_camembert.py +++ b/src/transformers/models/camembert/modeling_camembert.py @@ -381,7 +381,7 @@ def forward( # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal # mask in case tgt_len == 1. - is_causal = self.is_decoder and attention_mask is None and tgt_len > 1 + is_causal = self.is_decoder and attention_mask is None and tgt_len > 1 and not is_cross_attention attn_output = torch.nn.functional.scaled_dot_product_attention( query_layer, diff --git a/src/transformers/models/roberta/modeling_roberta.py b/src/transformers/models/roberta/modeling_roberta.py index 07cfce730da7..f0e401a471b0 100644 --- a/src/transformers/models/roberta/modeling_roberta.py +++ b/src/transformers/models/roberta/modeling_roberta.py @@ -363,7 +363,7 @@ def forward( # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal # mask in case tgt_len == 1. - is_causal = self.is_decoder and attention_mask is None and tgt_len > 1 + is_causal = self.is_decoder and attention_mask is None and tgt_len > 1 and not is_cross_attention attn_output = torch.nn.functional.scaled_dot_product_attention( query_layer, diff --git a/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py b/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py index b5225c1b10f4..77dd6123ae6d 100644 --- a/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py +++ b/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py @@ -364,7 +364,7 @@ def forward( # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal # mask in case tgt_len == 1. - is_causal = self.is_decoder and attention_mask is None and tgt_len > 1 + is_causal = self.is_decoder and attention_mask is None and tgt_len > 1 and not is_cross_attention attn_output = torch.nn.functional.scaled_dot_product_attention( query_layer, diff --git a/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py b/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py index b0e9b5148f1d..bcc5e9599727 100644 --- a/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +++ b/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py @@ -361,7 +361,7 @@ def forward( # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal # mask in case tgt_len == 1. 
- is_causal = self.is_decoder and attention_mask is None and tgt_len > 1 + is_causal = self.is_decoder and attention_mask is None and tgt_len > 1 and not is_cross_attention attn_output = torch.nn.functional.scaled_dot_product_attention( query_layer, From 1c9e1911019ab12af284c01ee3b7e449afff6a89 Mon Sep 17 00:00:00 2001 From: fxmarty <9808326+fxmarty@users.noreply.github.com> Date: Mon, 24 Jun 2024 10:59:33 +0200 Subject: [PATCH 03/10] fix copies --- .../models/camembert/modeling_camembert.py | 10 +++++++--- .../models/roberta/modeling_roberta.py | 12 +++++++++--- .../models/xlm_roberta/modeling_xlm_roberta.py | 12 +++++++++--- .../xlm_roberta_xl/modeling_xlm_roberta_xl.py | 15 +++++++++------ 4 files changed, 34 insertions(+), 15 deletions(-) diff --git a/src/transformers/models/camembert/modeling_camembert.py b/src/transformers/models/camembert/modeling_camembert.py index 8e6e87787c8a..d3d637826d47 100644 --- a/src/transformers/models/camembert/modeling_camembert.py +++ b/src/transformers/models/camembert/modeling_camembert.py @@ -376,9 +376,13 @@ def forward( key_layer = key_layer.contiguous() value_layer = value_layer.contiguous() - # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal - # mask in case tgt_len == 1. - is_causal = self.is_decoder and attention_mask is None and tgt_len > 1 and not is_cross_attention + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create + # a causal mask in case tgt_len == 1. + is_causal = ( + True if self.is_decoder and not is_cross_attention and attention_mask is None and tgt_len > 1 else False + ) attn_output = torch.nn.functional.scaled_dot_product_attention( query_layer, diff --git a/src/transformers/models/roberta/modeling_roberta.py b/src/transformers/models/roberta/modeling_roberta.py index b43c53869329..93393081dec6 100644 --- a/src/transformers/models/roberta/modeling_roberta.py +++ b/src/transformers/models/roberta/modeling_roberta.py @@ -358,9 +358,13 @@ def forward( key_layer = key_layer.contiguous() value_layer = value_layer.contiguous() - # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal - # mask in case tgt_len == 1. - is_causal = self.is_decoder and attention_mask is None and tgt_len > 1 and not is_cross_attention + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create + # a causal mask in case tgt_len == 1. + is_causal = ( + True if self.is_decoder and not is_cross_attention and attention_mask is None and tgt_len > 1 else False + ) attn_output = torch.nn.functional.scaled_dot_product_attention( query_layer, @@ -796,6 +800,8 @@ class RobertaModel(RobertaPreTrainedModel): `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass. 
""" + _no_split_modules = ["RobertaEmbeddings", "RobertaLayer"] + def __init__(self, config, add_pooling_layer=True): super().__init__(config) self.config = config diff --git a/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py b/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py index e0e62e49c471..3349cd2c858a 100644 --- a/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py +++ b/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py @@ -359,9 +359,13 @@ def forward( key_layer = key_layer.contiguous() value_layer = value_layer.contiguous() - # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal - # mask in case tgt_len == 1. - is_causal = self.is_decoder and attention_mask is None and tgt_len > 1 and not is_cross_attention + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create + # a causal mask in case tgt_len == 1. + is_causal = ( + True if self.is_decoder and not is_cross_attention and attention_mask is None and tgt_len > 1 else False + ) attn_output = torch.nn.functional.scaled_dot_product_attention( query_layer, @@ -797,6 +801,8 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel): `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass. """ + _no_split_modules = ["XLMRobertaEmbeddings", "XLMRobertaLayer"] + def __init__(self, config, add_pooling_layer=True): super().__init__(config) self.config = config diff --git a/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py b/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py index 37dc4e0205b3..6f711cf2d9ed 100644 --- a/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +++ b/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py @@ -356,9 +356,13 @@ def forward( key_layer = key_layer.contiguous() value_layer = value_layer.contiguous() - # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal - # mask in case tgt_len == 1. - is_causal = self.is_decoder and attention_mask is None and tgt_len > 1 and not is_cross_attention + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create + # a causal mask in case tgt_len == 1. + is_causal = ( + True if self.is_decoder and not is_cross_attention and attention_mask is None and tgt_len > 1 else False + ) attn_output = torch.nn.functional.scaled_dot_product_attention( query_layer, @@ -778,6 +782,8 @@ class XLMRobertaXLModel(XLMRobertaXLPreTrainedModel): `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass. 
""" + _no_split_modules = ["XLMRobertaXLEmbeddings", "XLMRobertaXLLayer"] + def __init__(self, config, add_pooling_layer=True): super().__init__(config) self.config = config @@ -876,9 +882,6 @@ def forward( # past_key_values_length past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 - if attention_mask is None: - attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) - if token_type_ids is None: if hasattr(self.embeddings, "token_type_ids"): buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length] From 575fd79804deabab571b71fd24ee67a82bcf1b9b Mon Sep 17 00:00:00 2001 From: fxmarty <9808326+fxmarty@users.noreply.github.com> Date: Mon, 24 Jun 2024 11:29:05 +0200 Subject: [PATCH 04/10] fix test --- src/transformers/pipelines/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/pipelines/base.py b/src/transformers/pipelines/base.py index 81d12459edd1..09f77402a143 100644 --- a/src/transformers/pipelines/base.py +++ b/src/transformers/pipelines/base.py @@ -850,7 +850,7 @@ def __init__( or is_torch_xpu_available(check_device=True) or is_torch_mps_available() ): - logging.warning( + logger.warning( "Hardware accelerator e.g. GPU is available in the environment, but no `device` argument" " is passed to the `Pipeline` object. Model will be on CPU." ) From 759508c8062c9d7202258e78719d947ed74a158d Mon Sep 17 00:00:00 2001 From: fxmarty <9808326+fxmarty@users.noreply.github.com> Date: Mon, 24 Jun 2024 14:04:57 +0200 Subject: [PATCH 05/10] add minimal test for camembert and xlm_roberta as their test class does not inherit from ModelTesterMixin --- tests/models/camembert/test_modeling_camembert.py | 11 ++++++++++- tests/models/xlm_roberta/test_modeling_xlm_roberta.py | 11 ++++++++++- 2 files changed, 20 insertions(+), 2 deletions(-) diff --git a/tests/models/camembert/test_modeling_camembert.py b/tests/models/camembert/test_modeling_camembert.py index f2fba59496da..8919c2685fed 100644 --- a/tests/models/camembert/test_modeling_camembert.py +++ b/tests/models/camembert/test_modeling_camembert.py @@ -17,6 +17,7 @@ from transformers import is_torch_available from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device +from transformers.utils.import_utils import is_torch_sdpa_available if is_torch_available(): @@ -31,7 +32,7 @@ class CamembertModelIntegrationTest(unittest.TestCase): @slow def test_output_embeds_base_model(self): - model = CamembertModel.from_pretrained("almanach/camembert-base") + model = CamembertModel.from_pretrained("almanach/camembert-base", attn_implementation="eager") model.to(torch_device) input_ids = torch.tensor( @@ -54,3 +55,11 @@ def test_output_embeds_base_model(self): # expected_slice = roberta.model.forward(input_ids)[0][:, :3, :3].detach() self.assertTrue(torch.allclose(output[:, :3, :3], expected_slice, atol=1e-4)) + + if is_torch_sdpa_available(): + model = CamembertModel.from_pretrained("almanach/camembert-base", attn_implementation="sdpa").to( + torch_device + ) + with torch.no_grad(): + output_sdpa = model(input_ids)["last_hidden_state"].detach() + self.assertTrue(torch.allclose(output, output_sdpa, atol=1e-3)) diff --git a/tests/models/xlm_roberta/test_modeling_xlm_roberta.py b/tests/models/xlm_roberta/test_modeling_xlm_roberta.py index d9b69bb9ab5f..14465ba1b2f2 100644 --- a/tests/models/xlm_roberta/test_modeling_xlm_roberta.py +++ 
b/tests/models/xlm_roberta/test_modeling_xlm_roberta.py @@ -18,6 +18,7 @@ from transformers import is_torch_available from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow +from transformers.utils.import_utils import is_torch_sdpa_available if is_torch_available(): @@ -32,7 +33,7 @@ class XLMRobertaModelIntegrationTest(unittest.TestCase): @slow def test_xlm_roberta_base(self): - model = XLMRobertaModel.from_pretrained("FacebookAI/xlm-roberta-base") + model = XLMRobertaModel.from_pretrained("FacebookAI/xlm-roberta-base", attn_implementation="eager") input_ids = torch.tensor([[0, 581, 10269, 83, 99942, 136, 60742, 23, 70, 80583, 18276, 2]]) # The dog is cute and lives in the garden house @@ -49,6 +50,14 @@ def test_xlm_roberta_base(self): # compare the actual values for a slice of last dim self.assertTrue(torch.allclose(output[:, :, -1], expected_output_values_last_dim, atol=1e-3)) + if is_torch_sdpa_available(): + model = XLMRobertaModel.from_pretrained("FacebookAI/xlm-roberta-base", attn_implementation="sdpa") + with torch.no_grad(): + output_sdpa = model(input_ids)["last_hidden_state"].detach() + self.assertEqual(output.shape, expected_output_shape) + # compare the actual values for a slice of last dim + self.assertTrue(torch.allclose(output, output_sdpa, atol=1e-3)) + @slow def test_xlm_roberta_large(self): model = XLMRobertaModel.from_pretrained("FacebookAI/xlm-roberta-large") From c8970ddbb152d3455c5ee29951511f0f74d35c07 Mon Sep 17 00:00:00 2001 From: fxmarty <9808326+fxmarty@users.noreply.github.com> Date: Wed, 10 Jul 2024 14:21:56 +0200 Subject: [PATCH 06/10] address some review comments --- .../xlm_roberta_xl/modeling_xlm_roberta_xl.py | 1 + .../camembert/test_modeling_camembert.py | 37 ++++++++++++++----- .../xlm_roberta/test_modeling_xlm_roberta.py | 32 +++++++++++----- 3 files changed, 52 insertions(+), 18 deletions(-) diff --git a/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py b/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py index 6f711cf2d9ed..82ed47166ca0 100644 --- a/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +++ b/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py @@ -1005,6 +1005,7 @@ def get_output_embeddings(self): def set_output_embeddings(self, new_embeddings): self.lm_head.decoder = new_embeddings + self.lm_head.bias = new_embeddings.bias @add_start_docstrings_to_model_forward(XLM_ROBERTA_XL_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC) diff --git a/tests/models/camembert/test_modeling_camembert.py b/tests/models/camembert/test_modeling_camembert.py index 8919c2685fed..f779c3a80909 100644 --- a/tests/models/camembert/test_modeling_camembert.py +++ b/tests/models/camembert/test_modeling_camembert.py @@ -16,8 +16,14 @@ import unittest from transformers import is_torch_available -from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device -from transformers.utils.import_utils import is_torch_sdpa_available +from transformers.testing_utils import ( + require_sentencepiece, + require_tokenizers, + require_torch, + require_torch_sdpa, + slow, + torch_device, +) if is_torch_available(): @@ -56,10 +62,23 @@ def test_output_embeds_base_model(self): self.assertTrue(torch.allclose(output[:, :3, :3], expected_slice, atol=1e-4)) - if is_torch_sdpa_available(): - model = 
CamembertModel.from_pretrained("almanach/camembert-base", attn_implementation="sdpa").to( - torch_device - ) - with torch.no_grad(): - output_sdpa = model(input_ids)["last_hidden_state"].detach() - self.assertTrue(torch.allclose(output, output_sdpa, atol=1e-3)) + @slow + @require_torch_sdpa + def test_output_embeds_base_model_sdpa(self): + input_ids = torch.tensor( + [[5, 121, 11, 660, 16, 730, 25543, 110, 83, 6]], + device=torch_device, + dtype=torch.long, + ) # J'aime le camembert ! + + expected_slice = torch.tensor( + [[[-0.0254, 0.0235, 0.1027], [0.0606, -0.1811, -0.0418], [-0.1561, -0.1127, 0.2687]]], + device=torch_device, + dtype=torch.float, + ) + + model = CamembertModel.from_pretrained("almanach/camembert-base", attn_implementation="sdpa").to(torch_device) + with torch.no_grad(): + output = model(input_ids)["last_hidden_state"].detach() + + self.assertTrue(torch.allclose(output[:, :3, :3], expected_slice, atol=1e-4)) diff --git a/tests/models/xlm_roberta/test_modeling_xlm_roberta.py b/tests/models/xlm_roberta/test_modeling_xlm_roberta.py index 14465ba1b2f2..f8ec1f5b7671 100644 --- a/tests/models/xlm_roberta/test_modeling_xlm_roberta.py +++ b/tests/models/xlm_roberta/test_modeling_xlm_roberta.py @@ -17,8 +17,13 @@ import unittest from transformers import is_torch_available -from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow -from transformers.utils.import_utils import is_torch_sdpa_available +from transformers.testing_utils import ( + require_sentencepiece, + require_tokenizers, + require_torch, + require_torch_sdpa, + slow, +) if is_torch_available(): @@ -50,13 +55,22 @@ def test_xlm_roberta_base(self): # compare the actual values for a slice of last dim self.assertTrue(torch.allclose(output[:, :, -1], expected_output_values_last_dim, atol=1e-3)) - if is_torch_sdpa_available(): - model = XLMRobertaModel.from_pretrained("FacebookAI/xlm-roberta-base", attn_implementation="sdpa") - with torch.no_grad(): - output_sdpa = model(input_ids)["last_hidden_state"].detach() - self.assertEqual(output.shape, expected_output_shape) - # compare the actual values for a slice of last dim - self.assertTrue(torch.allclose(output, output_sdpa, atol=1e-3)) + @require_torch_sdpa + def test_xlm_roberta_base_sdpa(self): + input_ids = torch.tensor([[0, 581, 10269, 83, 99942, 136, 60742, 23, 70, 80583, 18276, 2]]) + # The dog is cute and lives in the garden house + + expected_output_shape = torch.Size((1, 12, 768)) # batch_size, sequence_length, embedding_vector_dim + expected_output_values_last_dim = torch.tensor( + [[-0.0101, 0.1218, -0.0803, 0.0801, 0.1327, 0.0776, -0.1215, 0.2383, 0.3338, 0.3106, 0.0300, 0.0252]] + ) + + model = XLMRobertaModel.from_pretrained("FacebookAI/xlm-roberta-base", attn_implementation="sdpa") + with torch.no_grad(): + output = model(input_ids)["last_hidden_state"].detach() + self.assertEqual(output.shape, expected_output_shape) + # compare the actual values for a slice of last dim + self.assertTrue(torch.allclose(output[:, :, -1], expected_output_values_last_dim, atol=1e-3)) @slow def test_xlm_roberta_large(self): From 0d3ea404412b1991ae34d8b435f0af82a686b73b Mon Sep 17 00:00:00 2001 From: fxmarty <9808326+fxmarty@users.noreply.github.com> Date: Wed, 10 Jul 2024 14:36:24 +0200 Subject: [PATCH 07/10] use copied from --- .../xlm_roberta_xl/test_modeling_xlm_roberta_xl.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/tests/models/xlm_roberta_xl/test_modeling_xlm_roberta_xl.py 
b/tests/models/xlm_roberta_xl/test_modeling_xlm_roberta_xl.py index d195d080b656..f72a5b944564 100644 --- a/tests/models/xlm_roberta_xl/test_modeling_xlm_roberta_xl.py +++ b/tests/models/xlm_roberta_xl/test_modeling_xlm_roberta_xl.py @@ -516,12 +516,14 @@ def test_create_position_ids_from_inputs_embeds(self): self.assertEqual(position_ids.shape, expected_positions.shape) self.assertTrue(torch.all(torch.eq(position_ids, expected_positions))) - # This test was copied from the common test_eager_matches_sdpa_generate(), but without low_cpu_mem_usage=True. # TODO: Remove this and use the parent method (in common tests) once XLM RoBERTa XL supports low_cpu_mem_usage=True. @require_torch_sdpa @slow + # Copied from tests.test_modeling_common.ModelTesterMixin.test_eager_matches_sdpa_generate def test_eager_matches_sdpa_generate(self): - set_seed(0) + if not self.has_attentions: + self.skipTest(reason="Model architecture does not support attentions") + max_new_tokens = 30 if len(self.all_generative_model_classes) == 0: @@ -548,18 +550,20 @@ def test_eager_matches_sdpa_generate(self): dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input)) + # Ignore copy model_sdpa = model_class.from_pretrained( tmpdirname, torch_dtype=torch.float16, - # low_cpu_mem_usage=True, + low_cpu_mem_usage=False, ).to(torch_device) self.assertTrue(model_sdpa.config._attn_implementation == "sdpa") + # Ignore copy model_eager = model_class.from_pretrained( tmpdirname, torch_dtype=torch.float16, - # low_cpu_mem_usage=True, + low_cpu_mem_usage=False, attn_implementation="eager", ).to(torch_device) @@ -590,7 +594,6 @@ def test_eager_matches_sdpa_generate(self): self.assertTrue(torch.allclose(res_eager, res_sdpa)) - @require_torch class XLMRobertaModelXLIntegrationTest(unittest.TestCase): @slow From bed22497448c7a69d0426ff31939d7ab95ebd108 Mon Sep 17 00:00:00 2001 From: fxmarty <9808326+fxmarty@users.noreply.github.com> Date: Wed, 10 Jul 2024 14:43:25 +0200 Subject: [PATCH 08/10] style --- tests/models/xlm_roberta_xl/test_modeling_xlm_roberta_xl.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/models/xlm_roberta_xl/test_modeling_xlm_roberta_xl.py b/tests/models/xlm_roberta_xl/test_modeling_xlm_roberta_xl.py index f72a5b944564..a73f5618ff7e 100644 --- a/tests/models/xlm_roberta_xl/test_modeling_xlm_roberta_xl.py +++ b/tests/models/xlm_roberta_xl/test_modeling_xlm_roberta_xl.py @@ -17,7 +17,7 @@ import tempfile import unittest -from transformers import XLMRobertaXLConfig, is_torch_available, set_seed +from transformers import XLMRobertaXLConfig, is_torch_available from transformers.testing_utils import require_torch, require_torch_sdpa, slow, torch_device from ...generation.test_utils import GenerationTesterMixin @@ -594,6 +594,7 @@ def test_eager_matches_sdpa_generate(self): self.assertTrue(torch.allclose(res_eager, res_sdpa)) + @require_torch class XLMRobertaModelXLIntegrationTest(unittest.TestCase): @slow From cff2fdad37ef9fe4aef5a5e9029fc45b3dbaff69 Mon Sep 17 00:00:00 2001 From: fxmarty <9808326+fxmarty@users.noreply.github.com> Date: Wed, 10 Jul 2024 15:29:31 +0200 Subject: [PATCH 09/10] consistency --- docs/source/en/perf_infer_gpu_one.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/source/en/perf_infer_gpu_one.md b/docs/source/en/perf_infer_gpu_one.md index 40603db8b347..2c7b1b9b0750 100644 --- a/docs/source/en/perf_infer_gpu_one.md +++ b/docs/source/en/perf_infer_gpu_one.md @@ -79,6 +79,7 @@ FlashAttention-2 is currently supported for the following 
architectures: * [Hubert](https://huggingface.co/docs/transformers/model_doc/hubert#transformers.HubertModel) * [data2vec_audio](https://huggingface.co/docs/transformers/main/en/model_doc/data2vec#transformers.Data2VecAudioModel) * [Sew](https://huggingface.co/docs/transformers/main/en/model_doc/sew#transformers.SEWModel) +* [SigLIP](https://huggingface.co/docs/transformers/model_doc/siglip) * [UniSpeech](https://huggingface.co/docs/transformers/v4.39.3/en/model_doc/unispeech#transformers.UniSpeechModel) * [unispeech_sat](https://huggingface.co/docs/transformers/v4.39.3/en/model_doc/unispeech-sat#transformers.UniSpeechSatModel) From 9b817cec5b76fb7670f0e7109493dc80e2a80ee2 Mon Sep 17 00:00:00 2001 From: fxmarty <9808326+fxmarty@users.noreply.github.com> Date: Thu, 11 Jul 2024 11:24:22 +0200 Subject: [PATCH 10/10] fix lists --- docs/source/en/perf_infer_gpu_one.md | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/docs/source/en/perf_infer_gpu_one.md b/docs/source/en/perf_infer_gpu_one.md index 2c7b1b9b0750..f6afc8cf1341 100644 --- a/docs/source/en/perf_infer_gpu_one.md +++ b/docs/source/en/perf_infer_gpu_one.md @@ -209,6 +209,7 @@ For now, Transformers supports SDPA inference and training for the following arc * [Gemma2](https://huggingface.co/docs/transformers/model_doc/gemma2#transformers.Gemma2Model) * [GPT2](https://huggingface.co/docs/transformers/model_doc/gpt2) * [GPTBigCode](https://huggingface.co/docs/transformers/model_doc/gpt_bigcode#transformers.GPTBigCodeModel) +* [GPTNeoX](https://huggingface.co/docs/transformers/model_doc/gpt_neox#transformers.GPTNeoXModel) * [Hubert](https://huggingface.co/docs/transformers/model_doc/hubert#transformers.HubertModel) * [Idefics](https://huggingface.co/docs/transformers/model_doc/idefics#transformers.IdeficsModel) * [JetMoe](https://huggingface.co/docs/transformers/model_doc/jetmoe#transformers.JetMoeModel) @@ -229,7 +230,8 @@ For now, Transformers supports SDPA inference and training for the following arc * [StableLm](https://huggingface.co/docs/transformers/model_doc/stablelm#transformers.StableLmModel) * [Starcoder2](https://huggingface.co/docs/transformers/model_doc/starcoder2#transformers.Starcoder2Model) * [UniSpeech](https://huggingface.co/docs/transformers/v4.39.3/en/model_doc/unispeech#transformers.UniSpeechModel) -* [unispeech_sat](https://huggingface.co/docs/transformers/v4.39.3/en/model_doc/unispeech-sat#transformers.UniSpeechSatModel)* [RoBERTa](https://huggingface.co/docs/transformers/model_doc/roberta#transformers.RobertaModel) +* [unispeech_sat](https://huggingface.co/docs/transformers/v4.39.3/en/model_doc/unispeech-sat#transformers.UniSpeechSatModel) +* [RoBERTa](https://huggingface.co/docs/transformers/model_doc/roberta#transformers.RobertaModel) * [ViT](https://huggingface.co/docs/transformers/model_doc/vit#transformers.ViTModel) * [ViTHybrid](https://huggingface.co/docs/transformers/model_doc/vit_hybrid#transformers.ViTHybridModel) * [ViTMAE](https://huggingface.co/docs/transformers/model_doc/vit_mae#transformers.ViTMAEModel)
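
For reference, a minimal usage sketch of the SDPA/eager switch that this patch series wires up for the RoBERTa family. The model class, checkpoint name, example sentence, and 1e-3 tolerance are taken from the integration tests above; the rest (tokenizer call, printout) is illustrative only and not part of the patches:

    # Minimal sketch: compare the SDPA and eager attention paths for XLM-RoBERTa.
    # Checkpoint, sentence, and tolerance mirror the integration tests in this series.
    import torch
    from transformers import AutoTokenizer, XLMRobertaModel

    checkpoint = "FacebookAI/xlm-roberta-base"
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    inputs = tokenizer("The dog is cute and lives in the garden house", return_tensors="pt")

    model_sdpa = XLMRobertaModel.from_pretrained(checkpoint, attn_implementation="sdpa")
    model_eager = XLMRobertaModel.from_pretrained(checkpoint, attn_implementation="eager")

    with torch.no_grad():
        out_sdpa = model_sdpa(**inputs).last_hidden_state
        out_eager = model_eager(**inputs).last_hidden_state

    # The two attention implementations should agree within a small numerical tolerance.
    print(torch.allclose(out_sdpa, out_eager, atol=1e-3))

When the installed torch version supports scaled_dot_product_attention, Transformers selects the SDPA implementation by default for supported models, so passing attn_implementation explicitly is mainly useful for pinning the backend, as in the eager-vs-SDPA comparisons these tests perform.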