Flash-Attn: fix generation when no attention mask or no padding #32241

Merged: 4 commits, merged Jul 26, 2024
Changes from 3 commits
6 changes: 4 additions & 2 deletions src/transformers/modeling_flash_attention_utils.py
@@ -264,9 +264,11 @@ def _flash_attention_forward(
         )
         attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
 
-    # if position_ids is provided and check not all examples (row) contain only 1 sequence,
+    # if position_ids is provided and check not all examples (row) contain only 1 sequence, and is in pre-fill/training stage
     # then use `flash_attn_varlen_func` to prevent cross-example attention and also allow padding free approach
-    elif position_ids is not None and not (position_ids[:, -1] == position_ids.size(1) - 1).all():
+    elif (
+        position_ids is not None and not (position_ids[:, -1] == position_ids.size(1) - 1).all() and query_length != 1
+    ):
         batch_size = query_states.size(0)
         query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = prepare_fa2_from_position_ids(
             query_states, key_states, value_states, position_ids
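For context, here is a minimal sketch of how the amended condition behaves (editorial illustration only, not part of the diff; the helper name and tensors are made up):

import torch

def takes_varlen_path(position_ids, query_length):
    # Mirrors the amended elif condition above (illustration only).
    return (
        position_ids is not None
        and not (position_ids[:, -1] == position_ids.size(1) - 1).all()
        and query_length != 1
    )

# Pre-fill with two sequences packed into one row: the last position id (3) differs from
# seq_len - 1 (9), so the padding-free varlen path is taken.
packed_position_ids = torch.tensor([[0, 1, 2, 3, 4, 5, 0, 1, 2, 3]])
print(takes_varlen_path(packed_position_ids, query_length=packed_position_ids.size(1)))  # True

# Decoding with a cache: only the new token's position id is passed, so query_length == 1.
# Without the added query_length check this step could be misrouted into the varlen path;
# with it, the regular flash_attn_func path is used instead.
decode_position_ids = torch.tensor([[10]])
print(takes_varlen_path(decode_position_ids, query_length=1))  # False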
15 changes: 14 additions & 1 deletion tests/test_modeling_common.py
@@ -4270,6 +4270,18 @@ def test_flash_attn_2_generate_use_cache(self):
                     use_cache=True,
                 )
 
+                # Generate with one batch only to test generation when attention mask will be None
+                # when real inputs are used, because there is no padding. See issue #32237 for more
+                dummy_input = dummy_input[:1, ...]
+                dummy_attention_mask = torch.ones_like(dummy_attention_mask[:1, ...])
Collaborator:
The attention mask gets thrown away in the update_causal_mask function if it's full of ones. Thus we always end up in the case where the position ids are automatically generated and (position_ids[:, -1] == position_ids.size(1) - 1).all() holds.

I should have asked for a test with the dummy inputs being packed: batch size one, but the position ids would be [0, 1, 2, 3, 4, 5, 6, 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3], which I don't think was tested.
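For illustration, a minimal sketch (not part of the PR) of such packed inputs and how the position-id check classifies them:

import torch

# Three sequences of lengths 7, 8 and 4 packed into a single batch-size-one row,
# as described in the comment above.
position_ids = torch.tensor([[0, 1, 2, 3, 4, 5, 6, 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3]])

# A single unpacked sequence of length 19 would end with position id 18; here the last
# position id is 3, so the row is classified as packed.
is_single_sequence = bool((position_ids[:, -1] == position_ids.size(1) - 1).all())
print(is_single_sequence)  # False -> packed, the varlen path should be taken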

Member Author:
Yes, that was the issue. I added a test following the path by which the issue was discovered, but it can certainly also be covered within test_padding_flash_attention_with_position_ids.

Collaborator:
Oops, I meant: can we update the test to make sure we try a case with packed inputs? 😅

Collaborator:
No worries

Member Author:
Oh I see. Anyway, we have to take some time to sort out the tests, including flash-attn. Too many are failing now, and who knows when it started.

Collaborator:
yep

+                _ = model.generate(
+                    dummy_input,
+                    attention_mask=dummy_attention_mask,
+                    max_new_tokens=max_new_tokens,
+                    do_sample=False,
+                    use_cache=True,
+                )
+
     @require_flash_attn
     @require_torch_gpu
     @require_bitsandbytes
@@ -4342,6 +4354,8 @@ def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self):
                 self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
 
             config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+            if not 0 in inputs_dict.get("attention_mask", []) or "attention_mask" not in inputs_dict:
+                self.skipTest("Model dummy inputs should contain padding in their attention mask")
 
             dummy_input = inputs_dict[model_class.main_input_name]
             if dummy_input.dtype in [torch.float32, torch.bfloat16]:
@@ -4356,7 +4370,6 @@ def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self):
             with tempfile.TemporaryDirectory() as tmpdirname:
                 model.save_pretrained(tmpdirname)
 
-                assert 0 in inputs_dict["attention_mask"], "assert padding in testing inputs"
                 # ensure left padding, to adapt for some models
                 if 0 in inputs_dict["attention_mask"][:, -1]:
                     inputs_dict["attention_mask"] = inputs_dict["attention_mask"].flip(1)
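As an aside, a minimal illustration (not part of the diff) of what the left-padding flip above does, using a hypothetical right-padded mask:

import torch

# Hypothetical attention mask: first row right-padded (real length 3), second row full length.
attention_mask = torch.tensor([[1, 1, 1, 0, 0],
                               [1, 1, 1, 1, 1]])

# A 0 in the last column means the batch is right-padded; flipping along the sequence
# dimension converts it to left padding: [[0, 0, 1, 1, 1], [1, 1, 1, 1, 1]].
if 0 in attention_mask[:, -1]:
    attention_mask = attention_mask.flip(1)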