diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 36a1db65dbe1..7fedc4e75441 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -4459,7 +4459,8 @@ def test_flash_attn_2_can_dispatch_composite_models(self): for name, submodule in model_fa2.named_modules(): class_name = submodule.__class__.__name__ if ( - class_name.endswith("Attention") + "Attention" in class_name + and getattr(submodule, "config", None) and submodule.config._attn_implementation == "flash_attention_2" ): has_fa2 = True