From e85d86398ac92261d3341a846990fe61103a3a9b Mon Sep 17 00:00:00 2001
From: Fanli Lin
Date: Tue, 6 Aug 2024 18:18:58 +0800
Subject: [PATCH] add the missing flash attention test marker (#32419)

* add flash attention check

* fix

* fix

* add the missing marker

* bug fix

* add one more

* remove order

* add one more
---
 tests/models/gemma/test_modeling_gemma.py           | 2 +-
 tests/models/llama/test_modeling_llama.py           | 1 +
 tests/models/mistral/test_modeling_mistral.py       | 3 ++-
 tests/models/qwen2/test_modeling_qwen2.py           | 1 +
 tests/models/qwen2_moe/test_modeling_qwen2_moe.py   | 1 +
 tests/models/stablelm/test_modeling_stablelm.py     | 2 ++
 tests/models/starcoder2/test_modeling_starcoder2.py | 1 +
 7 files changed, 9 insertions(+), 2 deletions(-)

diff --git a/tests/models/gemma/test_modeling_gemma.py b/tests/models/gemma/test_modeling_gemma.py
index 46a40c6f6624..831ce1dec69b 100644
--- a/tests/models/gemma/test_modeling_gemma.py
+++ b/tests/models/gemma/test_modeling_gemma.py
@@ -628,9 +628,9 @@ def test_model_2b_sdpa(self):
 
         self.assertEqual(output_text, EXPECTED_TEXTS)
 
-    @pytest.mark.flash_attn_test
     @require_flash_attn
     @require_read_token
+    @pytest.mark.flash_attn_test
     def test_model_2b_flash_attn(self):
         model_id = "google/gemma-2b"
         EXPECTED_TEXTS = [
diff --git a/tests/models/llama/test_modeling_llama.py b/tests/models/llama/test_modeling_llama.py
index de17c088ed77..a32fa3437e11 100644
--- a/tests/models/llama/test_modeling_llama.py
+++ b/tests/models/llama/test_modeling_llama.py
@@ -620,6 +620,7 @@ def test_flash_attn_2_generate_padding_right(self):
     @require_flash_attn
     @require_torch_gpu
     @slow
+    @pytest.mark.flash_attn_test
     def test_use_flash_attention_2_true(self):
         """
         NOTE: this is the only test testing that the legacy `use_flash_attention=2` argument still works as intended.
diff --git a/tests/models/mistral/test_modeling_mistral.py b/tests/models/mistral/test_modeling_mistral.py
index 3f47ddde1fa2..0da7ae72add7 100644
--- a/tests/models/mistral/test_modeling_mistral.py
+++ b/tests/models/mistral/test_modeling_mistral.py
@@ -576,9 +576,10 @@ def test_model_7b_dola_generation(self):
         backend_empty_cache(torch_device)
         gc.collect()
 
+    @require_flash_attn
     @require_bitsandbytes
     @slow
-    @require_flash_attn
+    @pytest.mark.flash_attn_test
     def test_model_7b_long_prompt(self):
         EXPECTED_OUTPUT_TOKEN_IDS = [306, 338]
         # An input with 4097 tokens that is above the size of the sliding window
diff --git a/tests/models/qwen2/test_modeling_qwen2.py b/tests/models/qwen2/test_modeling_qwen2.py
index fcb7278cd591..4d6c432f2042 100644
--- a/tests/models/qwen2/test_modeling_qwen2.py
+++ b/tests/models/qwen2/test_modeling_qwen2.py
@@ -544,6 +544,7 @@ def test_model_450m_generation(self):
     @require_bitsandbytes
     @slow
     @require_flash_attn
+    @pytest.mark.flash_attn_test
     def test_model_450m_long_prompt(self):
         EXPECTED_OUTPUT_TOKEN_IDS = [306, 338]
         # An input with 4097 tokens that is above the size of the sliding window
diff --git a/tests/models/qwen2_moe/test_modeling_qwen2_moe.py b/tests/models/qwen2_moe/test_modeling_qwen2_moe.py
index 36f1db4693a5..0425172a6fba 100644
--- a/tests/models/qwen2_moe/test_modeling_qwen2_moe.py
+++ b/tests/models/qwen2_moe/test_modeling_qwen2_moe.py
@@ -606,6 +606,7 @@ def test_model_a2_7b_generation(self):
     @require_bitsandbytes
     @slow
     @require_flash_attn
+    @pytest.mark.flash_attn_test
     def test_model_a2_7b_long_prompt(self):
         EXPECTED_OUTPUT_TOKEN_IDS = [306, 338]
         # An input with 4097 tokens that is above the size of the sliding window
diff --git a/tests/models/stablelm/test_modeling_stablelm.py b/tests/models/stablelm/test_modeling_stablelm.py
index b0d7261de645..5f2052a0bed7 100644
--- a/tests/models/stablelm/test_modeling_stablelm.py
+++ b/tests/models/stablelm/test_modeling_stablelm.py
@@ -16,6 +16,7 @@
 
 import unittest
 
+import pytest
 from parameterized import parameterized
 
 from transformers import StableLmConfig, is_torch_available, set_seed
@@ -539,6 +540,7 @@ def test_model_tiny_random_stablelm_2_generation(self):
     @require_bitsandbytes
     @slow
     @require_flash_attn
+    @pytest.mark.flash_attn_test
     def test_model_3b_long_prompt(self):
         EXPECTED_OUTPUT_TOKEN_IDS = [3, 3, 3]
         input_ids = [306, 338] * 2047
diff --git a/tests/models/starcoder2/test_modeling_starcoder2.py b/tests/models/starcoder2/test_modeling_starcoder2.py
index edbc1bce6396..c1c7d45d4f18 100644
--- a/tests/models/starcoder2/test_modeling_starcoder2.py
+++ b/tests/models/starcoder2/test_modeling_starcoder2.py
@@ -528,6 +528,7 @@ def test_starcoder2_batched_generation_eager(self):
         self.assertEqual(EXPECTED_TEXT, output_text)
 
     @require_flash_attn
+    @pytest.mark.flash_attn_test
     def test_starcoder2_batched_generation_fa2(self):
         EXPECTED_TEXT = [
             "Hello my name is Younes and I am a student at the University of Liverpool. I am currently studying for my MSc in Computer Science. I am interested in the field of Machine Learning and I am currently working on",
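---

Note: the convention this patch enforces is that every flash-attention test carries both the `@require_flash_attn` skip decorator and the `@pytest.mark.flash_attn_test` marker, so the whole group can be selected or deselected with pytest's `-m` flag. A minimal sketch of the pattern follows; the class name and test body are hypothetical, only the decorator stack reflects what the patch adds (all four decorators are real transformers/pytest utilities):

import unittest

import pytest

from transformers.testing_utils import require_flash_attn, require_torch_gpu, slow


class ExampleIntegrationTest(unittest.TestCase):  # hypothetical class, for illustration only
    @require_flash_attn           # skip unless flash-attn is installed
    @require_torch_gpu            # flash attention requires a CUDA device
    @slow                         # run only when RUN_SLOW=1 is set
    @pytest.mark.flash_attn_test  # let `pytest -m flash_attn_test` select this test
    def test_example_flash_attn(self):
        # a real test would load the model with attn_implementation="flash_attention_2"
        # and compare generation output against expected text, as the tests above do
        ...

With the marker in place, `pytest -m flash_attn_test tests/models/` runs exactly this group of tests and `-m "not flash_attn_test"` excludes it; that filtering silently misses any test where the marker was forgotten, which is what this patch fixes.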