[TKW] Enable each MMA to have its own intrinsic #287
Conversation
Signed-off-by: Stanley Winata <stanley.winata@amd.com>
```python
filename = f"wave_attention_{'x'.join(map(str, shape))}.mlir"
with open(filename, "w") as f:
    f.write(mb.module_op.get_asm())
rmse = torch.sqrt(torch.mean(torch.square(output - torch_ref)))
```
Can you use `assert_close` with appropriate `atol` and `rtol` here instead of manually computing this?
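For reference, a minimal sketch of what the suggested check could look like; the tolerance values here are placeholders to illustrate the API, not values tuned for FP8:

```python
import torch

# Placeholder tolerances -- these would need tuning for FP8 accuracy.
torch.testing.assert_close(output, torch_ref, atol=1e-2, rtol=1e-2)
```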
Yeah, but the FP8 range/accuracy is quite limited, so I'm not sure atol/rtol would be meaningful here. For example, we have a case where the rtol is just ridiculously large:
```
Mismatched elements: 2613487 / 2621440 (99.7%)
Greatest absolute difference: 0.04659026861190796 at index (26, 785, 31) (up to 1e-05 allowed)
Greatest relative difference: 1510450.625 at index (2, 593, 37) (up to 1.3e-06 allowed)

(Pdb) torch_ref[2, 593, 37]
tensor(9.8498e-09, device='cuda:0')
(Pdb) output[2, 593, 37]
tensor(0.0149, device='cuda:0')
```
We may be able to improve this once we improve the quantization/dequantization to get better range, but for now I don't think atol/rtol will give a very good idea of accuracy, due to the inconsistency of our FP8 range.
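To illustrate the point, here is a standalone sketch using the values from the mismatch above: when the reference value is near zero, the relative difference explodes even though the absolute error is small on the scale of the data.

```python
import torch

ref = torch.tensor(9.8498e-09)  # near-zero reference, as in torch_ref[2, 593, 37]
out = torch.tensor(0.0149)      # kernel output at that index

abs_diff = (out - ref).abs()     # ~0.0149: small in absolute terms
rel_diff = abs_diff / ref.abs()  # ~1.5e6: blown up by the near-zero denominator
print(abs_diff.item(), rel_diff.item())
```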
Comparing the same values as above, these tensors qualitatively look quite good, especially without any rescaling/adjusting of the range.
```
(Pdb) output[0,:]
tensor([[ 0.0500, -0.0215,  0.0078,  ..., -0.0116, -0.0850,  0.0386],
        [ 0.0232,  0.0168,  0.0137,  ...,  0.1215,  0.0033,  0.0275],
        [-0.0044, -0.0127,  0.0993,  ...,  0.0158,  0.0504,  0.0125],
        ...,
        [ 0.0696, -0.0049, -0.0455,  ...,  0.0299, -0.0396,  0.0955],
        [-0.0473,  0.1092, -0.0237,  ...,  0.0194, -0.0045,  0.0034],
        [-0.0104,  0.0096,  0.0394,  ...,  0.0266, -0.0525,  0.0692]],
       device='cuda:0')
(Pdb) torch_ref[0,:]
tensor([[ 0.0467, -0.0231,  0.0064,  ..., -0.0106, -0.0863,  0.0408],
        [ 0.0245,  0.0132,  0.0127,  ...,  0.1255,  0.0051,  0.0280],
        [-0.0083, -0.0181,  0.1016,  ...,  0.0180,  0.0563,  0.0145],
        ...,
        [ 0.0675, -0.0112, -0.0434,  ...,  0.0295, -0.0368,  0.0981],
        [-0.0471,  0.1066, -0.0230,  ...,  0.0206, -0.0002, -0.0072],
        [-0.0089,  0.0065,  0.0376,  ...,  0.0283, -0.0549,  0.0682]],
       device='cuda:0')
```
Sounds good, let's merge for now and iterate
Overall, LGTM! Just a question about using `assert_close`.
In order to align layouts in chained GEMM or attention in FP8, we'd need to use different intrinsics for the 1st and 2nd MMA. To achieve this, we'd need to do:
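As a rough sketch of the idea only: the `mma_type` keyword on `tkw.mma` and the specific `MMAType` variant names below are assumptions about what this PR exposes, not a confirmed API, and the surrounding kernel setup (constraints, register reads) is elided.

```python
# Hypothetical sketch: chained GEMM where each MMA selects its own intrinsic.
# Assumes this PR adds an `mma_type` keyword to tkw.mma; the MMAType variant
# names are illustrative guesses.

# 1st mma: intrinsic chosen so its result layout matches the operand
# layout the 2nd mma expects.
x_j = tkw.mma(q_reg, k_reg, c_reg, mma_type=MMAType.F32_32x32x16_F8)

# 2nd mma: a different intrinsic, aligned with the 1st mma's output layout.
acc = tkw.mma(p_reg, v_reg, acc, mma_type=MMAType.F32_16x16x32_F8)
```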