diff --git a/src/transformers/modeling_attn_mask_utils.py b/src/transformers/modeling_attn_mask_utils.py
index fb85d018c9f9..9340dbe9f6cb 100755
--- a/src/transformers/modeling_attn_mask_utils.py
+++ b/src/transformers/modeling_attn_mask_utils.py
@@ -413,7 +413,7 @@ def _prepare_4d_attention_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len:
         `(batch_size, key_value_length)`
 
     Args:
-        mask (`torch.Tensor` or `None`):
+        mask (`torch.Tensor`):
             A 2D attention mask of shape `(batch_size, key_value_length)`
         dtype (`torch.dtype`):
             The torch dtype the created mask shall have.
@@ -429,36 +429,25 @@ def _prepare_4d_attention_mask_for_sdpa(mask: torch.Tensor, dtype: torch.dtype,
         `(batch_size, key_value_length)`
 
     Args:
-        mask (`torch.Tensor` or `None`):
+        mask (`torch.Tensor`):
             A 2D attention mask of shape `(batch_size, key_value_length)`
         dtype (`torch.dtype`):
             The torch dtype the created mask shall have.
         tgt_len (`int`):
             The target length or query length the created mask shall have.
     """
-    batch_size, key_value_length = mask.shape
+    _, key_value_length = mask.shape
     tgt_len = tgt_len if tgt_len is not None else key_value_length
 
-    # torch.jit.trace, symbolic_trace and torchdynamo with fullgraph=True are unable to capture the controlflow `is_causal=attention_mask is None and q_len > 1`
-    # used as an SDPA argument. We keep compatibility with these tracing tools by always using SDPA's `attn_mask` argument in case we are tracing.
-    # TODO: For dynamo, rather use a check on fullgraph=True once this is possible (https://github.com/pytorch/pytorch/pull/120400).
     is_tracing = (
         torch.jit.is_tracing()
        or isinstance(mask, torch.fx.Proxy)
        or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
    )
 
+    # torch.jit.trace, symbolic_trace and torchdynamo with fullgraph=True are unable to capture data-dependent controlflows.
     if not is_tracing and torch.all(mask == 1):
-        if tgt_len == 1:
-            # For query_length == 1, causal attention and bi-directional attention are the same.
-            return None
-        elif key_value_length == tgt_len:
-            return None
-        else:
-            # Unfortunately, for query_length > 1 and key_value_length != query_length, we can not generally ignore the attention mask, as SDPA causal mask generation
-            # may be wrong. We will set is_causal=False in SDPA and rely on Transformers attention_mask instead, hence not setting it to None here.
-            # Reference: https://github.com/pytorch/pytorch/issues/108108
-            return AttentionMaskConverter._expand_mask(mask=mask, dtype=dtype, tgt_len=tgt_len)
+        return None
     else:
         return AttentionMaskConverter._expand_mask(mask=mask, dtype=dtype, tgt_len=tgt_len)
 
diff --git a/src/transformers/models/bert/modeling_bert.py b/src/transformers/models/bert/modeling_bert.py
index 957944435b85..33fa431b39a9 100755
--- a/src/transformers/models/bert/modeling_bert.py
+++ b/src/transformers/models/bert/modeling_bert.py
@@ -432,7 +432,9 @@ def forward(
         # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
         # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create
         # a causal mask in case tgt_len == 1.
-        is_causal = True if self.is_decoder and attention_mask is None and tgt_len > 1 else False
+        is_causal = (
+            True if self.is_decoder and not is_cross_attention and attention_mask is None and tgt_len > 1 else False
+        )
 
         attn_output = torch.nn.functional.scaled_dot_product_attention(
             query_layer,
diff --git a/tests/models/bert/test_modeling_bert.py b/tests/models/bert/test_modeling_bert.py
index ff5f65f26b90..8b2dbc3634ba 100644
--- a/tests/models/bert/test_modeling_bert.py
+++ b/tests/models/bert/test_modeling_bert.py
@@ -16,7 +16,7 @@
 import tempfile
 import unittest
 
-from transformers import BertConfig, is_torch_available
+from transformers import AutoTokenizer, BertConfig, is_torch_available
 from transformers.models.auto import get_values
 from transformers.testing_utils import (
     CaptureLogger,
@@ -747,3 +747,36 @@ def test_inference_no_head_relative_embedding_key_query(self):
         )
 
         self.assertTrue(torch.allclose(output[:, 1:4, 1:4], expected_slice, atol=1e-4))
+
+    def test_sdpa_ignored_mask(self):
+        pkv = []
+
+        model = BertModel.from_pretrained("hf-internal-testing/tiny-random-BertModel", attn_implementation="eager")
+        model_sdpa = BertModel.from_pretrained("hf-internal-testing/tiny-random-BertModel", attn_implementation="sdpa")
+
+        model = model.eval()
+        model_sdpa = model_sdpa.eval()
+
+        for _ in range(model.config.num_hidden_layers):
+            num_heads = model.config.num_attention_heads
+            head_dim = model.config.hidden_size // model.config.num_attention_heads
+            pkv.append([torch.rand(1, num_heads, 3, head_dim), torch.rand(1, num_heads, 3, head_dim)])
+
+        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-BertModel")
+        inp = tokenizer("I am in Paris and", return_tensors="pt")
+
+        del inp["attention_mask"]
+
+        with torch.no_grad():
+            res_eager = model(**inp)
+            res_sdpa = model_sdpa(**inp)
+            self.assertTrue(
+                torch.allclose(res_eager.last_hidden_state, res_sdpa.last_hidden_state, atol=1e-5, rtol=1e-4)
+            )
+
+            # Case where query length != kv_length.
+            res_eager = model(**inp, past_key_values=pkv)
+            res_sdpa = model_sdpa(**inp, past_key_values=pkv)
+            self.assertTrue(
+                torch.allclose(res_eager.last_hidden_state, res_sdpa.last_hidden_state, atol=1e-5, rtol=1e-4)
+            )
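Below is a minimal standalone sketch (not part of the patch; the tensor shapes and variable names are made up for illustration) of the equivalence the simplified `_prepare_4d_attention_mask_for_sdpa` relies on: when the 2D padding mask is all ones, its expanded 4D additive form is an all-zero bias, so passing `attn_mask=None` to SDPA with `is_causal=False` should give the same result even when the query length differs from the key/value length, which is the situation the removed `key_value_length != tgt_len` branch used to special-case.

import torch
import torch.nn.functional as F

# Illustrative shapes only; q_len != kv_len mimics the past_key_values case exercised by the new test.
batch, heads, q_len, kv_len, head_dim = 2, 4, 5, 8, 16
query = torch.randn(batch, heads, q_len, head_dim)
key = torch.randn(batch, heads, kv_len, head_dim)
value = torch.randn(batch, heads, kv_len, head_dim)

# An all-ones 2D padding mask expands to an additive 4D bias that is zero everywhere.
mask_2d = torch.ones(batch, kv_len)
expanded = (1.0 - mask_2d)[:, None, None, :].expand(batch, 1, q_len, kv_len) * torch.finfo(query.dtype).min

# With a zero bias, SDPA with the explicit mask and SDPA with attn_mask=None (and is_causal=False)
# compute the same attention, so dropping the all-ones mask is safe outside of tracing.
out_with_mask = F.scaled_dot_product_attention(query, key, value, attn_mask=expanded)
out_no_mask = F.scaled_dot_product_attention(query, key, value, attn_mask=None, is_causal=False)

print(torch.allclose(out_with_mask, out_no_mask, atol=1e-6))  # expected: True (up to backend numerics)

The `not is_cross_attention` guard added in `modeling_bert.py` matters for the same reason: a causal pattern only makes sense over the decoder's own sequence, so the cross-attention path must keep `is_causal=False` instead of inheriting causality whenever no attention mask is passed.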