
compute flops for scaled_dot_product_flash_attention #7871
Open
pmeier opened this issue Aug 23, 2023 · 6 comments
@pmeier (Collaborator) commented Aug 23, 2023

https://github.com/pytorch/vision/actions/runs/5941974400/job/16117254380

Failures start with 9c4f738, which is the first commit that used yesterday's (20230822) PyTorch nightly.

   =========================== short test summary info ============================
  FAILED test/test_extended_models.py::test_schema_meta_validation[vit_b_16] - AssertionError: assert not [(ViT_B_16_Weights.IMAGENET1K_V1, '_ops'), (ViT_B_16_Weights.IMAGENET1K_SWAG_E2E_V1, '_ops')]
  FAILED test/test_extended_models.py::test_schema_meta_validation[vit_b_32] - AssertionError: assert not [(ViT_B_32_Weights.IMAGENET1K_V1, '_ops')]
  FAILED test/test_extended_models.py::test_schema_meta_validation[vit_h_14] - AssertionError: assert not [(ViT_H_14_Weights.IMAGENET1K_SWAG_E2E_V1, '_ops'), (ViT_H_14_Weights.IMAGENET1K_SWAG_LINEAR_V1, '_ops')]
  FAILED test/test_extended_models.py::test_schema_meta_validation[vit_l_16] - AssertionError: assert not [(ViT_L_16_Weights.IMAGENET1K_V1, '_ops'), (ViT_L_16_Weights.IMAGENET1K_SWAG_E2E_V1, '_ops')]
  FAILED test/test_extended_models.py::test_schema_meta_validation[vit_l_32] - AssertionError: assert not [(ViT_L_32_Weights.IMAGENET1K_V1, '_ops')]
  ====== 5 failed, 658 passed, 1 skipped, 430 warnings in 511.39s (0:08:31) ======

It is not obvious from the error message, but the failure comes from the fact that the number we have on record, e.g. the _ops value in the weights meta, no longer matches what we calculate.

cc @seemethere

@NicolasHug (Member) commented

Looks the same as #7349. Skip the test :(

@pmeier (Collaborator, Author) commented Aug 23, 2023

If it is flakiness, the error is huge:

# common_extended_utils is the torchvision test helper used by test_extended_models.py,
# so this needs to run from torchvision's test/ directory.
from common_extended_utils import get_ops
from torchvision import models

for model_fn, weights in [
    (models.vit_b_16, models.ViT_B_16_Weights.IMAGENET1K_V1),
    (models.vit_b_16, models.ViT_B_16_Weights.IMAGENET1K_SWAG_E2E_V1),
    (models.vit_b_32, models.ViT_B_32_Weights.IMAGENET1K_V1),
    (models.vit_h_14, models.ViT_H_14_Weights.IMAGENET1K_SWAG_E2E_V1),
    (models.vit_h_14, models.ViT_H_14_Weights.IMAGENET1K_SWAG_LINEAR_V1),
    (models.vit_l_16, models.ViT_L_16_Weights.IMAGENET1K_V1),
    (models.vit_l_16, models.ViT_L_16_Weights.IMAGENET1K_SWAG_E2E_V1),
    (models.vit_l_32, models.ViT_L_32_Weights.IMAGENET1K_V1),
]:
    expected = weights.meta["_ops"]
    actual = get_ops(model=model_fn(weights=weights), weight=weights)
    exact_float_match = expected == actual

    print(f"{model_fn.__name__} / {weights.name}: {actual=}, {expected=}, {exact_float_match=}")
vit_b_16 / IMAGENET1K_V1: actual=16.849, expected=17.564, exact_float_match=False
vit_b_16 / IMAGENET1K_SWAG_E2E_V1: actual=49.348, expected=55.484, exact_float_match=False
vit_b_32 / IMAGENET1K_V1: actual=4.363, expected=4.409, exact_float_match=False
vit_h_14 / IMAGENET1K_SWAG_E2E_V1: actual=862.961, expected=1016.717, exact_float_match=False
vit_h_14 / IMAGENET1K_SWAG_LINEAR_V1: actual=161.884, expected=167.295, exact_float_match=False
vit_l_16 / IMAGENET1K_V1: actual=59.647, expected=61.555, exact_float_match=False
vit_l_16 / IMAGENET1K_SWAG_E2E_V1: actual=310.346, expected=361.986, exact_float_match=False
vit_l_32 / IMAGENET1K_V1: actual=15.255, expected=15.378, exact_float_match=False

Maybe something was merged into core that actually reduced the number significantly?

@pmeier (Collaborator, Author) commented Aug 23, 2023

It is definitely not flakiness this time. For some reason, we no longer have values for aten.bmm, which is what is causing the drop. The other ops are unaffected.

@pmeier (Collaborator, Author) commented Aug 23, 2023

PyTorch core has a new aten._scaled_dot_product_flash_attention that is not tracked by us. See pytorch/pytorch#103826.
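
One way to see which ops a forward pass actually hits (just a sketch using TorchDispatchMode to log dispatched ATen ops, independent of how get_ops does its counting; it assumes a build where ViT's attention goes through F.scaled_dot_product_attention):

import torch
from torch.utils._python_dispatch import TorchDispatchMode
from torchvision import models


class OpLogger(TorchDispatchMode):
    """Collect the names of all ATen ops dispatched under this mode."""

    def __init__(self):
        super().__init__()
        self.ops = set()

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        self.ops.add(str(func))
        return func(*args, **(kwargs or {}))


model = models.vit_b_16().eval()
with torch.no_grad(), OpLogger() as logger:
    model(torch.randn(1, 3, 224, 224))

# Per the discussion above: older nightlies dispatch aten.bmm for the attention
# matmuls, while the 20230822 nightly dispatches a fused scaled_dot_product op.
print(sorted(op for op in logger.ops if "bmm" in op or "scaled_dot_product" in op))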

@pmeier (Collaborator, Author) commented Aug 23, 2023

We are currently tracking the following ops:

flop_mapping = {
    aten.mm: matmul_flop,
    aten.matmul: matmul_flop,
    aten.addmm: addmm_flop,
    aten.bmm: bmm_flop,
    aten.convolution: conv_flop,
    aten._convolution: conv_flop,
    aten.convolution_backward: conv_backward_flop,
    quantized.conv2d: quant_conv_flop,
    quantized.conv2d_relu: quant_conv_flop,
}

We need to add an

aten._scaled_dot_product_flash_attention: scaled_dot_product_flash_attention_flop

entry there. The function scaled_dot_product_flash_attention_flop needs to count the flops of the underlying kernel.
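
A rough sketch of what that counter could look like (hypothetical, not a worked-out implementation; it assumes the helper receives the op's input tensors like the other entries in the mapping, that query/key/value come in as (B, H, S, D) tensors, and that one multiply-accumulate counts as one op as in bmm_flop, ignoring the softmax and scaling):

def scaled_dot_product_flash_attention_flop(inputs, outputs):
    # Flash attention fuses two batched matmuls:
    #   Q @ K^T          : (B*H) x (S, D) @ (D, S) -> B * H * S * S * D MACs
    #   softmax(...) @ V : (B*H) x (S, S) @ (S, D) -> B * H * S * S * D MACs
    b, h, s, d = inputs[0].shape  # query, assumed to be (B, H, S, D)
    return 2 * b * h * s * s * d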

That is quite some work. @NicolasHug are you ok with me disabling this subtest for the offending models for now and fixing it after the release?

@NicolasHug (Member) commented

yeah

pmeier changed the title from "extended unittests are failing" to "compute flops for scaled_dot_product_flash_attention" on Aug 24, 2023