Ignore non-causal mask in more cases with SDPA #30138

Merged (7 commits, Jun 3, 2024)
21 changes: 5 additions & 16 deletions src/transformers/modeling_attn_mask_utils.py
@@ -413,7 +413,7 @@ def _prepare_4d_attention_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len:
    `(batch_size, key_value_length)`

    Args:
-        mask (`torch.Tensor` or `None`):
+        mask (`torch.Tensor`):
            A 2D attention mask of shape `(batch_size, key_value_length)`
        dtype (`torch.dtype`):
            The torch dtype the created mask shall have.
@@ -429,36 +429,25 @@ def _prepare_4d_attention_mask_for_sdpa(mask: torch.Tensor, dtype: torch.dtype,
    `(batch_size, key_value_length)`

    Args:
-        mask (`torch.Tensor` or `None`):
+        mask (`torch.Tensor`):
            A 2D attention mask of shape `(batch_size, key_value_length)`
        dtype (`torch.dtype`):
            The torch dtype the created mask shall have.
        tgt_len (`int`):
            The target length or query length the created mask shall have.
    """
-    batch_size, key_value_length = mask.shape
+    _, key_value_length = mask.shape

Contributor @minostauros commented on Apr 9, 2024:
Can we change the input arg `mask: torch.Tensor` to `mask: Optional[torch.Tensor]` and return `None` immediately if `mask` is `None`? The docstring is not compliant with the actual input (it says ``mask (`torch.Tensor` or `None`)``).
Will it break the is_tracing check?


The PR author (Contributor) replied:
@minostauros Yes, ideally we would want to do that. In practice, the calls to these functions in the modeling files are always guarded by:

`if attention_mask is not None:`

but IMO we should indeed accept `Optional[torch.Tensor]`. I'll leave that to another PR.
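For readers following the discussion, here is a minimal sketch of what the suggested `Optional[torch.Tensor]` handling could look like. It is not part of this PR; the wrapper name is hypothetical, and it assumes the private helper is importable from the path shown in this diff:

```python
from typing import Optional

import torch

from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa


def prepare_sdpa_mask_or_none(
    mask: Optional[torch.Tensor], dtype: torch.dtype, tgt_len: Optional[int] = None
) -> Optional[torch.Tensor]:
    # Hypothetical wrapper illustrating the reviewer's suggestion: accept None and
    # bail out early, so callers would not need the usual
    # `if attention_mask is not None:` guard in the modeling files.
    if mask is None:
        return None
    return _prepare_4d_attention_mask_for_sdpa(mask, dtype, tgt_len=tgt_len)
```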

    tgt_len = tgt_len if tgt_len is not None else key_value_length

    # torch.jit.trace, symbolic_trace and torchdynamo with fullgraph=True are unable to capture the controlflow `is_causal=attention_mask is None and q_len > 1`
    # used as an SDPA argument. We keep compatibility with these tracing tools by always using SDPA's `attn_mask` argument in case we are tracing.
    # TODO: For dynamo, rather use a check on fullgraph=True once this is possible (https://github.com/pytorch/pytorch/pull/120400).
    is_tracing = (
        torch.jit.is_tracing()
        or isinstance(mask, torch.fx.Proxy)
        or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
    )

    # torch.jit.trace, symbolic_trace and torchdynamo with fullgraph=True are unable to capture data-dependent controlflows.
    if not is_tracing and torch.all(mask == 1):
-        if tgt_len == 1:
-            # For query_length == 1, causal attention and bi-directional attention are the same.
-            return None
-        elif key_value_length == tgt_len:
-            return None
-        else:
-            # Unfortunately, for query_length > 1 and key_value_length != query_length, we can not generally ignore the attention mask, as SDPA causal mask generation
-            # may be wrong. We will set is_causal=False in SDPA and rely on Transformers attention_mask instead, hence not setting it to None here.
-            # Reference: https://github.com/pytorch/pytorch/issues/108108
-            return AttentionMaskConverter._expand_mask(mask=mask, dtype=dtype, tgt_len=tgt_len)
+        return None
    else:
        return AttentionMaskConverter._expand_mask(mask=mask, dtype=dtype, tgt_len=tgt_len)

4 changes: 3 additions & 1 deletion src/transformers/models/bert/modeling_bert.py
@@ -432,7 +432,9 @@ def forward(
        # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
        # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create
        # a causal mask in case tgt_len == 1.
-        is_causal = True if self.is_decoder and attention_mask is None and tgt_len > 1 else False
+        is_causal = (
+            True if self.is_decoder and not is_cross_attention and attention_mask is None and tgt_len > 1 else False
+        )
Comment on lines +435 to +437, from the PR author:

FYI @hackyon


        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_layer,
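For context on why `is_causal` now also checks `not is_cross_attention`: SDPA's built-in causal mask is tied to the query and key lengths, while encoder-decoder cross-attention should attend to every encoder position. A standalone illustration (not from the PR; shapes chosen arbitrarily):

```python
import torch
import torch.nn.functional as F

q = torch.rand(1, 4, 5, 8)  # (batch, heads, tgt_len=5, head_dim): decoder queries
k = torch.rand(1, 4, 9, 8)  # encoder keys: a different length than the queries
v = torch.rand(1, 4, 9, 8)

# Cross-attention is bidirectional: every query may attend to every encoder position.
bidirectional = F.scaled_dot_product_attention(q, k, v, is_causal=False)

# With is_causal=True, SDPA applies a lower-triangular (tgt_len, src_len) mask,
# hiding most encoder positions from the early queries, which is not what
# cross-attention should do. Hence the extra `not is_cross_attention` condition.
causal = F.scaled_dot_product_attention(q, k, v, is_causal=True)

print(torch.allclose(bidirectional, causal))  # False
```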
35 changes: 34 additions & 1 deletion tests/models/bert/test_modeling_bert.py
@@ -16,7 +16,7 @@
import tempfile
import unittest

-from transformers import BertConfig, is_torch_available
+from transformers import AutoTokenizer, BertConfig, is_torch_available
from transformers.models.auto import get_values
from transformers.testing_utils import (
CaptureLogger,
@@ -747,3 +747,36 @@ def test_inference_no_head_relative_embedding_key_query(self):
        )

        self.assertTrue(torch.allclose(output[:, 1:4, 1:4], expected_slice, atol=1e-4))

    def test_sdpa_ignored_mask(self):
        pkv = []

        model = BertModel.from_pretrained("hf-internal-testing/tiny-random-BertModel", attn_implementation="eager")
        model_sdpa = BertModel.from_pretrained("hf-internal-testing/tiny-random-BertModel", attn_implementation="sdpa")

        model = model.eval()
        model_sdpa = model_sdpa.eval()

        for _ in range(model.config.num_hidden_layers):
            num_heads = model.config.num_attention_heads
            head_dim = model.config.hidden_size // model.config.num_attention_heads
            pkv.append([torch.rand(1, num_heads, 3, head_dim), torch.rand(1, num_heads, 3, head_dim)])

        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-BertModel")
        inp = tokenizer("I am in Paris and", return_tensors="pt")

        del inp["attention_mask"]

        with torch.no_grad():
            res_eager = model(**inp)
            res_sdpa = model_sdpa(**inp)
            self.assertTrue(
                torch.allclose(res_eager.last_hidden_state, res_sdpa.last_hidden_state, atol=1e-5, rtol=1e-4)
            )

            # Case where query length != kv_length.
            res_eager = model(**inp, past_key_values=pkv)
            res_sdpa = model_sdpa(**inp, past_key_values=pkv)
            self.assertTrue(
                torch.allclose(res_eager.last_hidden_state, res_sdpa.last_hidden_state, atol=1e-5, rtol=1e-4)
            )