
[TKW] Enable each MMA to have its own intrinsic #287

Merged
merged 4 commits into iree-org:main
Nov 22, 2024

Conversation

raikonenfnu
Contributor

@raikonenfnu raikonenfnu commented Nov 21, 2024

To align layouts in a chained GEMM or attention in FP8, we need to use different intrinsics for the 1st and 2nd MMA. To achieve this, we need to:

  1. Set an optional MMA_Type as an arg on tkw.mma.
  2. Modify index_seq_analysis and constraints to use the MMA op's intrinsic type, when available, instead of the hw_constraint type.
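The precedence rule in step 2 can be sketched as follows; the names here (`MMAType`, `resolve_mma_type`) are illustrative placeholders, not the actual TKW API:

```python
# Hypothetical sketch of the fallback described in step 2: an MMA op's own
# intrinsic type, if set, takes precedence over the hardware constraint's
# default. Enum members are illustrative, not the real intrinsic list.
from enum import Enum
from typing import Optional

class MMAType(Enum):
    F32_16x16x16_F16 = 0
    F32_16x16x32_F8 = 1
    F32_32x32x16_F8 = 2

def resolve_mma_type(op_mma_type: Optional[MMAType],
                     hw_default: MMAType) -> MMAType:
    """Prefer the per-op intrinsic when the user set one on tkw.mma."""
    return op_mma_type if op_mma_type is not None else hw_default
```

This keeps existing kernels working unchanged (they fall back to the hw_constraint intrinsic) while letting a chained GEMM give each of its two MMAs a different intrinsic.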

Signed-off-by: Stanley Winata <stanley.winata@amd.com>
filename = f"wave_attention_{'x'.join(map(str, shape))}.mlir"
with open(filename, "w") as f:
    f.write(mb.module_op.get_asm())
rmse = torch.sqrt(torch.mean(torch.square(output - torch_ref)))
Contributor


Can you use assert_close with appropriate atol and rtol here instead of manually computing this?
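For reference, `torch.testing.assert_close` passes when, elementwise, `|actual - expected| <= atol + rtol * |expected|`. A minimal pure-Python sketch of that check; the tolerance values and sample rows are illustrative (rows taken from the tensor dumps later in this thread):

```python
# Sketch of the elementwise check assert_close performs. Defaults here match
# the tolerances quoted in the mismatch report below (1e-05 / 1.3e-06).
def close_enough(actual, expected, atol=1e-5, rtol=1.3e-6):
    return all(abs(a - e) <= atol + rtol * abs(e)
               for a, e in zip(actual, expected))

# First rows of the dumps further down in this thread:
output_row = [0.0500, -0.0215, 0.0078]
ref_row = [0.0467, -0.0231, 0.0064]
```

With the default tolerances these FP8-derived rows fail the check (differences around 3e-3 against atol 1e-5), which is the crux of the discussion below.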

Contributor Author

@raikonenfnu raikonenfnu Nov 22, 2024


Yeah, but the FP8 range/accuracy is quite limited, so I'm not sure atol/rtol is meaningful here. For example, we have a case where the rtol is just ridiculously large:

Mismatched elements: 2613487 / 2621440 (99.7%)
Greatest absolute difference: 0.04659026861190796 at index (26, 785, 31) (up to 1e-05 allowed)
Greatest relative difference: 1510450.625 at index (2, 593, 37) (up to 1.3e-06 allowed)
(Pdb) torch_ref[2, 593, 37]
tensor(9.8498e-09, device='cuda:0')
(Pdb) output[2, 593, 37]
tensor(0.0149, device='cuda:0')

We may be able to improve this once we improve the quantization-dequantization to get better range, but for now atol/rtol won't give a very good idea of accuracy due to the inconsistency of our FP8 range.
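To illustrate why the relative tolerance blows up: the relative difference divides by the reference value, which for these FP8 outputs can be vanishingly small. Using the values from the debug session above:

```python
# Relative difference at the worst-offending index from the report above.
# When the reference is ~1e-8, even a tiny absolute error yields a huge rtol.
ref = 9.8498e-09   # torch_ref[2, 593, 37]
out = 0.0149       # output[2, 593, 37]
rel_diff = abs(out - ref) / abs(ref)   # ~1.51e6, matching the reported 1510450.625
```

An absolute metric like RMSE sidesteps this near-zero-denominator problem, which is why the test computes RMSE instead.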

Contributor Author

@raikonenfnu raikonenfnu Nov 22, 2024


Comparing the same values as above, these tensors qualitatively look quite good, especially without any rescaling/adjusting of the range.

```
tensor([[ 0.0500, -0.0215,  0.0078,  ..., -0.0116, -0.0850,  0.0386],
        [ 0.0232,  0.0168,  0.0137,  ...,  0.1215,  0.0033,  0.0275],
        [-0.0044, -0.0127,  0.0993,  ...,  0.0158,  0.0504,  0.0125],
        ...,
        [ 0.0696, -0.0049, -0.0455,  ...,  0.0299, -0.0396,  0.0955],
        [-0.0473,  0.1092, -0.0237,  ...,  0.0194, -0.0045,  0.0034],
        [-0.0104,  0.0096,  0.0394,  ...,  0.0266, -0.0525,  0.0692]],
       device='cuda:0')
(Pdb) torch_ref[0,:]
tensor([[ 0.0467, -0.0231,  0.0064,  ..., -0.0106, -0.0863,  0.0408],
        [ 0.0245,  0.0132,  0.0127,  ...,  0.1255,  0.0051,  0.0280],
        [-0.0083, -0.0181,  0.1016,  ...,  0.0180,  0.0563,  0.0145],
        ...,
        [ 0.0675, -0.0112, -0.0434,  ...,  0.0295, -0.0368,  0.0981],
        [-0.0471,  0.1066, -0.0230,  ...,  0.0206, -0.0002, -0.0072],
        [-0.0089,  0.0065,  0.0376,  ...,  0.0283, -0.0549,  0.0682]],
```

Contributor


Sounds good, let's merge for now and iterate

Contributor

@harsh-nod harsh-nod left a comment


Overall, LGTM! Just a question about using assert_close.

Signed-off-by: Stanley Winata <stanley.winata@amd.com>
@raikonenfnu raikonenfnu merged commit 965247e into iree-org:main Nov 22, 2024
8 checks passed