Update tests regarding attention types after #35235 #36024

Merged 8 commits on Feb 4, 2025

Changes from all commits
32 changes: 24 additions & 8 deletions tests/test_modeling_common.py
@@ -3872,11 +3872,13 @@ def test_attn_implementation_composite_models(self):
             for name, submodule in model.named_modules():
                 class_name = submodule.__class__.__name__
                 if (
-                    "SdpaAttention" in class_name
-                    or "SdpaSelfAttention" in class_name
-                    or "FlashAttention" in class_name
+                    class_name.endswith("Attention")
+                    and getattr(submodule, "config", None)
+                    and submodule.config._attn_implementation != "eager"
                 ):
-                    raise ValueError(f"The eager model should not have SDPA/FA2 attention layers but got {class_name}")
+                    raise ValueError(
+                        f"The eager model should not have SDPA/FA2 attention layers but got `{class_name}.config._attn_implementation={submodule.config._attn_implementation}`"
+                    )
 
     @require_torch_sdpa
     def test_sdpa_can_dispatch_non_composite_models(self):
@@ -3907,8 +3909,14 @@ def test_sdpa_can_dispatch_non_composite_models(self):
 
             for name, submodule in model_eager.named_modules():
                 class_name = submodule.__class__.__name__
-                if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
-                    raise ValueError(f"The eager model should not have SDPA attention layers but got {class_name}")
+                if (
+                    class_name.endswith("Attention")
+                    and getattr(submodule, "config", None)
+                    and submodule.config._attn_implementation == "sdpa"
+                ):
+                    raise ValueError(
+                        f"The eager model should not have SDPA attention layers but got `{class_name}.config._attn_implementation={submodule.config._attn_implementation}`"
+                    )
 
     @require_torch_sdpa
     def test_sdpa_can_dispatch_composite_models(self):
@@ -3959,7 +3967,11 @@ def test_sdpa_can_dispatch_composite_models(self):
 
             for name, submodule in model_eager.named_modules():
                 class_name = submodule.__class__.__name__
-                if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
+                if (
+                    class_name.endswith("Attention")
+                    and getattr(submodule, "config", None)
+                    and submodule.config._attn_implementation == "sdpa"
+                ):
                     raise ValueError("The eager model should not have SDPA attention layers")
 
     @parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
@@ -4446,7 +4458,11 @@ def test_flash_attn_2_can_dispatch_composite_models(self):
             has_fa2 = False
             for name, submodule in model_fa2.named_modules():
                 class_name = submodule.__class__.__name__
-                if "FlashAttention" in class_name:
+                if (
+                    "Attention" in class_name
+                    and getattr(submodule, "config", None)
+                    and submodule.config._attn_implementation == "flash_attention_2"

Collaborator (author): (implicitly assuming such attention modules always have a config attribute; otherwise it might suggest something is not propagating correctly to the attention module/layer)

Collaborator (author): Let me know if the latest version is what you expect :-) If so, I will apply similar changes to other places as you suggested.

Member: Yep, LGTM as long as the tests are green :)

+                ):
                     has_fa2 = True
                     break
             if not has_fa2:
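The review note above leans on an assumption worth making explicit: after #35235 the tests no longer look for backend-specific class names (SdpaAttention, FlashAttention, ...) but instead trust that every attention module exposes the config whose _attn_implementation was set at load time. Below is a minimal standalone sketch of that check pattern, not part of the PR; the checkpoint name is illustrative only.

from transformers import AutoModelForCausalLM

# Load with an explicit attention backend; "gpt2" is just an example checkpoint.
model = AutoModelForCausalLM.from_pretrained("gpt2", attn_implementation="eager")

for name, submodule in model.named_modules():
    class_name = submodule.__class__.__name__
    # Assumes attention modules carry a `config` attribute; if one does not,
    # the requested implementation may not be propagating down to that layer.
    if (
        class_name.endswith("Attention")
        and getattr(submodule, "config", None)
        and submodule.config._attn_implementation != "eager"
    ):
        raise ValueError(
            f"{name} ({class_name}) reports {submodule.config._attn_implementation} instead of eager"
        )
print("All attention modules report the requested implementation.")

Checking the propagated config rather than class names keeps these tests meaningful for models whose single attention class dispatches to SDPA or FlashAttention internally.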