compute flops for scaled_dot_product_flash_attention #7871
Looks the same as #7349. Skip the test :(
If it is flakiness, the error is huge. This script compares the value we have on record with what we compute now:

```python
# Run from vision/test so that common_extended_utils is importable.
from common_extended_utils import get_ops
from torchvision import models

for model_fn, weights in [
    (models.vit_b_16, models.ViT_B_16_Weights.IMAGENET1K_V1),
    (models.vit_b_16, models.ViT_B_16_Weights.IMAGENET1K_SWAG_E2E_V1),
    (models.vit_b_32, models.ViT_B_32_Weights.IMAGENET1K_V1),
    (models.vit_h_14, models.ViT_H_14_Weights.IMAGENET1K_SWAG_E2E_V1),
    (models.vit_h_14, models.ViT_H_14_Weights.IMAGENET1K_SWAG_LINEAR_V1),
    (models.vit_l_16, models.ViT_L_16_Weights.IMAGENET1K_V1),
    (models.vit_l_16, models.ViT_L_16_Weights.IMAGENET1K_SWAG_E2E_V1),
    (models.vit_l_32, models.ViT_L_32_Weights.IMAGENET1K_V1),
]:
    # Value on record in the weights metadata vs. the value computed today.
    expected = weights.meta["_ops"]
    actual = get_ops(model=model_fn(weights=weights), weight=weights)
    exact_float_match = expected == actual
    print(f"{model_fn.__name__} / {weights.name}: {actual=}, {expected=}, {exact_float_match=}")
```
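To tell float noise apart from a real change, it can help to print the relative deviation instead of only an exact-match flag. A minimal sketch; `relative_error`, `looks_like_flakiness` and the `1e-3` threshold are hypothetical choices, not part of the torchvision test suite:

```python
import math


def relative_error(expected: float, actual: float) -> float:
    """Relative deviation of `actual` from the recorded `expected` value."""
    if expected == 0:
        return math.inf if actual != 0 else 0.0
    return abs(actual - expected) / abs(expected)


def looks_like_flakiness(expected: float, actual: float, rel_tol: float = 1e-3) -> bool:
    """Hypothetical heuristic: tiny deviations suggest float noise, large ones a real change."""
    return relative_error(expected, actual) <= rel_tol


# Illustrative numbers only, not real _ops values.
print(looks_like_flakiness(10.0, 10.001))  # small drift
print(looks_like_flakiness(10.0, 8.0))     # 20% off, not noise
```

A deviation of several percent, as seen here, points at a change in how the ops are counted rather than at nondeterminism.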
Maybe something was merged into core that actually reduced the number significantly?
It is definitely not flakiness this time. For some reason, we no longer have values for …
PyTorch core has a new …
We are currently tracking the following ops in vision/test/common_extended_utils.py (lines 143 to 153 at c486bb1).
We need to add an `aten._scaled_dot_product_flash_attention: scaled_dot_product_flash_attention_flop` entry there. The function … That is quite some work. @NicolasHug, are you OK with me disabling this subtest for the offending models for now and fixing this after the release?
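The new entry mainly has to account for the two matrix multiplications inside attention. A minimal sketch, assuming the counter only charges for `Q @ K^T` and `attn @ V` (ignoring scaling, softmax and dropout) and, like fvcore-style counters, counts one multiply-add as one FLOP. `sdpa_flop_from_shapes` is a hypothetical name working on plain shapes; the existing entries in common_extended_utils.py take the op's inputs and outputs, so adapting it to that signature is left open:

```python
from math import prod
from typing import Sequence


def sdpa_flop_from_shapes(
    q_shape: Sequence[int],
    k_shape: Sequence[int],
    v_shape: Sequence[int],
    flops_per_mac: int = 1,
) -> int:
    """Hypothetical helper: matmul FLOPs of scaled dot-product attention.

    Shapes follow the (..., L, E), (..., S, E), (..., S, Ev) layout of
    torch.nn.functional.scaled_dot_product_attention.
    """
    *batch_q, L, E = q_shape
    *batch_k, S, E_k = k_shape
    *batch_v, S_v, Ev = v_shape
    assert E == E_k and S == S_v, "incompatible q/k/v shapes"
    assert batch_q == batch_k == batch_v, "batch dims must match (no broadcasting here)"
    batch = prod(batch_q)  # prod([]) == 1 for unbatched inputs
    qk = batch * L * S * E   # Q @ K^T -> (..., L, S)
    av = batch * L * S * Ev  # attn @ V -> (..., L, Ev)
    return flops_per_mac * (qk + av)


# Example: 12 heads, 197 tokens, head dim 64 (ViT-B/16-like, batch size 1).
print(sdpa_flop_from_shapes((1, 12, 197, 64), (1, 12, 197, 64), (1, 12, 197, 64)))
```

Whether `flops_per_mac` should be 1 or 2 depends on the convention the existing `matmul_flop`-style entries use; it has to match them, or the totals will not be comparable with the recorded `_ops` values.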
yeah
https://github.com/pytorch/vision/actions/runs/5941974400/job/16117254380
Failures start with 9c4f738, which is the first commit that used yesterday's (20230822) PyTorch nightly.
It is not obvious from the error message, but the failure comes from the fact that the number we have on record, e.g. in vision/torchvision/models/vision_transformer.py (line 392 at 11e49de), no longer matches what we calculate.
cc @seemethere