From fb958c71977061fc0761f8f9464a51bce6ed2825 Mon Sep 17 00:00:00 2001
From: shiyu_li
Date: Sat, 20 Jul 2024 12:14:48 +0800
Subject: [PATCH 1/4] support 3D/4D attention mask in bert

---
 src/transformers/models/bert/modeling_bert.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/transformers/models/bert/modeling_bert.py b/src/transformers/models/bert/modeling_bert.py
index 33fa431b39a9..9c1cc3614a77 100755
--- a/src/transformers/models/bert/modeling_bert.py
+++ b/src/transformers/models/bert/modeling_bert.py
@@ -1093,7 +1093,7 @@ def forward(
         )
 
         # Expand the attention mask
-        if use_sdpa_attention_masks:
+        if use_sdpa_attention_masks and attention_mask.dim() == 2:
             # Expand the attention mask for SDPA.
             # [bsz, seq_len] -> [bsz, 1, seq_len, seq_len]
             if self.config.is_decoder:
@@ -1120,7 +1120,7 @@ def forward(
             if encoder_attention_mask is None:
                 encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
 
-            if use_sdpa_attention_masks:
+            if use_sdpa_attention_masks and encoder_attention_mask.dim() == 2:
                 # Expand the attention mask for SDPA.
                 # [bsz, seq_len] -> [bsz, 1, seq_len, seq_len]
                 encoder_extended_attention_mask = _prepare_4d_attention_mask_for_sdpa(

From 40900295a3ef647f52a9ce87f6459b444670889e Mon Sep 17 00:00:00 2001
From: shiyu_li
Date: Sat, 20 Jul 2024 13:06:02 +0800
Subject: [PATCH 2/4] test cases

---
 tests/models/bert/test_modeling_bert.py | 38 +++++++++++++++++++++++++
 1 file changed, 38 insertions(+)

diff --git a/tests/models/bert/test_modeling_bert.py b/tests/models/bert/test_modeling_bert.py
index 6ae9f6c279de..ac83011eeefe 100644
--- a/tests/models/bert/test_modeling_bert.py
+++ b/tests/models/bert/test_modeling_bert.py
@@ -498,6 +498,14 @@ def test_model_various_embeddings(self):
             config_and_inputs[0].position_embedding_type = type
             self.model_tester.create_and_check_model(*config_and_inputs)
 
+    def test_model_3d_mask_shapes(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        # manipulate input_mask
+        config_and_inputs = list(config_and_inputs)
+        batch_size, seq_length = config_and_inputs[3].shape
+        config_and_inputs[3] = random_attention_mask([batch_size, seq_length, seq_length])
+        self.model_tester.create_and_check_model(*config_and_inputs)
+
     def test_model_as_decoder(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
         self.model_tester.create_and_check_model_as_decoder(*config_and_inputs)
@@ -530,6 +538,36 @@ def test_model_as_decoder_with_default_input_mask(self):
             encoder_attention_mask,
         )
 
+    def test_model_as_decoder_with_3d_input_mask(self):
+        (
+            config,
+            input_ids,
+            token_type_ids,
+            input_mask,
+            sequence_labels,
+            token_labels,
+            choice_labels,
+            encoder_hidden_states,
+            encoder_attention_mask,
+        ) = self.model_tester.prepare_config_and_inputs_for_decoder()
+
+        batch_size, seq_length = input_mask.shape
+        input_mask = random_attention_mask([batch_size, seq_length, seq_length])
+        batch_size, seq_length = encoder_attention_mask.shape
+        encoder_attention_mask = random_attention_mask([batch_size, seq_length, seq_length])
+
+        self.model_tester.create_and_check_model_as_decoder(
+            config,
+            input_ids,
+            token_type_ids,
+            input_mask,
+            sequence_labels,
+            token_labels,
+            choice_labels,
+            encoder_hidden_states,
+            encoder_attention_mask,
+        )
+
     def test_for_causal_lm(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
         self.model_tester.create_and_check_for_causal_lm(*config_and_inputs)
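Taken together, the guard added in PATCH 1/4 and the tests in PATCH 2/4 suggest that a caller can now hand `BertModel` a per-query 3D attention mask even when the SDPA implementation is selected: as I read the diff, a mask that is not 2D skips `_prepare_4d_attention_mask_for_sdpa` and falls back to the generic `get_extended_attention_mask` expansion. A minimal usage sketch under that assumption follows; the checkpoint name and the masking pattern are placeholders, not part of the patch.

```python
# Sketch only, not part of the patch: assumes a pre-shaped 3D mask bypasses the
# SDPA-specific expansion and is expanded by get_extended_attention_mask instead.
import torch
from transformers import AutoTokenizer, BertModel

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = BertModel.from_pretrained("bert-base-uncased", attn_implementation="sdpa")
model.eval()

inputs = tokenizer(["a short example sentence"], return_tensors="pt")
batch_size, seq_len = inputs["input_ids"].shape

# Per-query mask of shape (batch_size, seq_len, seq_len): row i lists the key
# positions that query position i may attend to (1 = attend, 0 = masked).
mask_3d = torch.ones(batch_size, seq_len, seq_len, dtype=torch.long)
mask_3d[:, :, -1] = 0  # illustrative pattern: hide the final token from every query

with torch.no_grad():
    outputs = model(input_ids=inputs["input_ids"], attention_mask=mask_3d)
print(outputs.last_hidden_state.shape)  # (batch_size, seq_len, hidden_size)
```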
From d9a6f6e6c866d54006fbd811d7f9121a50da944c Mon Sep 17 00:00:00 2001
From: shiyu_li
Date: Fri, 2 Aug 2024 13:31:40 +0800
Subject: [PATCH 3/4] update doc

---
 src/transformers/models/bert/modeling_bert.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/src/transformers/models/bert/modeling_bert.py b/src/transformers/models/bert/modeling_bert.py
index 9c1cc3614a77..1c79c26b18d8 100755
--- a/src/transformers/models/bert/modeling_bert.py
+++ b/src/transformers/models/bert/modeling_bert.py
@@ -1020,10 +1020,15 @@ def forward(
         return_dict: Optional[bool] = None,
     ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
         r"""
+        attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)` or `(batch_size, sequence_length, target_length)`, *optional*):
+            Mask to avoid performing attention on padding token indices. Mask values selected in [0, 1]:
+
+            - 1 for tokens that are not masked,
+            - 0 for tokens that are masked.
         encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
             Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
             the model is configured as a decoder.
-        encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
+        encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)` or `(batch_size, sequence_length, target_length)`, *optional*):
             Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
             the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
 

From 5771792a7cb2de512b2e5952c023e63e4005c4f9 Mon Sep 17 00:00:00 2001
From: shiyu_li
Date: Fri, 2 Aug 2024 13:43:16 +0800
Subject: [PATCH 4/4] fix doc

---
 src/transformers/models/bert/modeling_bert.py | 7 +------
 1 file changed, 1 insertion(+), 6 deletions(-)

diff --git a/src/transformers/models/bert/modeling_bert.py b/src/transformers/models/bert/modeling_bert.py
index 1c79c26b18d8..820d0074d91b 100755
--- a/src/transformers/models/bert/modeling_bert.py
+++ b/src/transformers/models/bert/modeling_bert.py
@@ -908,7 +908,7 @@ class BertForPreTrainingOutput(ModelOutput):
             [`PreTrainedTokenizer.__call__`] for details.
 
             [What are input IDs?](../glossary#input-ids)
-        attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
+        attention_mask (`torch.FloatTensor` of shape `({0})`or `(batch_size, sequence_length, target_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
 
            - 1 for tokens that are **not masked**,
@@ -1020,11 +1020,6 @@ def forward(
         return_dict: Optional[bool] = None,
     ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
         r"""
-        attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)` or `(batch_size, sequence_length, target_length)`, *optional*):
-            Mask to avoid performing attention on padding token indices. Mask values selected in [0, 1]:
-
-            - 1 for tokens that are not masked,
-            - 0 for tokens that are masked.
         encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
             Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
             the model is configured as a decoder.
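The docstring updates in PATCH 3/4 and PATCH 4/4 advertise an additional `(batch_size, sequence_length, target_length)` mask shape with the usual convention of 1 for tokens that are not masked and 0 for tokens that are masked. As a rough sketch of how such a mask might be derived from an ordinary 2D padding mask; the helper below is hypothetical and not part of the patch or of transformers.

```python
# Illustrative only: build the (batch_size, sequence_length, target_length) mask
# named in the updated docstrings from a 2D padding mask. Hypothetical helper.
import torch

def expand_padding_mask_to_3d(padding_mask: torch.Tensor, sequence_length: int) -> torch.Tensor:
    """Broadcast a (batch_size, target_length) padding mask to
    (batch_size, sequence_length, target_length); 1 = not masked, 0 = masked."""
    batch_size, target_length = padding_mask.shape
    return padding_mask[:, None, :].expand(batch_size, sequence_length, target_length).clone()

# Two sequences of length 4; the second one ends with two padding tokens.
padding_mask = torch.tensor([[1, 1, 1, 0],
                             [1, 1, 0, 0]])
mask_3d = expand_padding_mask_to_3d(padding_mask, sequence_length=4)
print(mask_3d.shape)  # torch.Size([2, 4, 4])
print(mask_3d[1])     # each query row repeats the padding pattern [1, 1, 0, 0]
```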