From 59e28c30fa3a91213f569bccef73f082afa8c656 Mon Sep 17 00:00:00 2001
From: Cyril Vallez
Date: Fri, 10 Jan 2025 11:49:12 +0100
Subject: [PATCH] Fix flex_attention in training mode (#35605)

* fix flex

* add test

* style
---
 src/transformers/integrations/flex_attention.py |  2 +-
 tests/test_modeling_common.py                   | 13 +++++++++++++
 2 files changed, 14 insertions(+), 1 deletion(-)

diff --git a/src/transformers/integrations/flex_attention.py b/src/transformers/integrations/flex_attention.py
index 66ffc563883..53bd80bf045 100644
--- a/src/transformers/integrations/flex_attention.py
+++ b/src/transformers/integrations/flex_attention.py
@@ -27,7 +27,7 @@ def causal_mod(score, b, h, q_idx, kv_idx):
         if softcap is not None:
             score = softcap * torch.tanh(score / softcap)
         if causal_mask is not None:
-            score += causal_mask[b][0][q_idx][kv_idx]
+            score = score + causal_mask[b][0][q_idx][kv_idx]
         return score
 
     attn_output, attention_weights = flex_attention(
diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py
index c29a15efd33..c57ef6ed05f 100755
--- a/tests/test_modeling_common.py
+++ b/tests/test_modeling_common.py
@@ -4790,6 +4790,19 @@ def test_forward_with_num_logits_to_keep(self):
 
         # Assert the last tokens are actually the same (except for the natural fluctuation due to order of FP ops)
         self.assertTrue(torch.allclose(all_logits[:, -1:, :], last_token_logits, atol=1e-5))
 
+    @require_torch_gpu
+    def test_flex_attention_with_grads(self):
+        for model_class in self.all_model_classes:
+            if not model_class._supports_flex_attn:
+                self.skipTest(reason="This model does not support flex attention")
+            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+            config._attn_implementation = "flex_attention"
+            model = model_class(config).to(device=torch_device, dtype=torch.float16)
+            self.assertTrue(model.config._attn_implementation == "flex_attention")
+
+            # If this does not raise an error, the test passes (see https://github.com/huggingface/transformers/pull/35605)
+            _ = model(inputs_dict["input_ids"].to(torch_device))
+
 
 global_rng = random.Random()
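
For context on the one-line change above: the patch replaces an in-place `score += ...` with the out-of-place `score = score + ...` inside the score-modification callable passed to `flex_attention`. Below is a minimal sketch of the underlying PyTorch behaviour, assuming only standard autograd semantics; the `score` and `mask` tensors are illustrative stand-ins, not the actual tensors used inside `flex_attention`.

```python
import torch

# Minimal sketch (illustrative names, not the transformers code itself):
# an in-place add on a leaf tensor that requires grad raises a RuntimeError,
# while the out-of-place form builds a new autograd node and succeeds.
score = torch.zeros(4, requires_grad=True)  # stand-in for an attention score
mask = torch.randn(4)                       # stand-in for a causal-mask slice

try:
    score += mask  # in-place: fails when score requires grad
except RuntimeError as err:
    print(f"in-place add failed: {err}")

score = score + mask  # out-of-place: differentiable, no error
score.sum().backward()
```

The added test_flex_attention_with_grads test exercises this path: it instantiates the model (which defaults to training mode, so parameters require grad), sets `config._attn_implementation = "flex_attention"`, and runs a single forward pass, which the in-place update inside the score modifier previously caused to error out.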