Naive Support for Hopper FP8 Prefill Kernel with Per-Head Quantization #869
Conversation
I'm good with the PR in general and let's merge it first and then iterate.
}  // namespace flashinfer

#endif  // FLASHINFER_ATTENTION_HOPPER_VARIANTS_CUH_
Can you use pre-commit to format the code?
Sounds good. Done.
Summary
This PR introduces naive FP8 tensor core computation following FA3's implementation. The main code modifications are located in `include/flashinfer/attention/hopper/quantization`, with test cases in `src/fp8-dev`. The primary changes include:

In-Kernel V Transpose
Since `wgmma.fp8` requires K-major layouts for both operands, and Q/K/V are all head_dim-major, V must be transposed before it is fed into the tensor cores. We therefore provide an in-kernel transpose in shared memory using `ldmatrix`/`stmatrix` (kernel_traits.cuh#L54).
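Below is a minimal sketch of the transpose idea only, assuming a plain element-wise shared-memory copy; the actual implementation in kernel_traits.cuh uses `ldmatrix`/`stmatrix` with swizzled layouts, and the function and template names here are illustrative.

```cuda
#include <cuda_fp8.h>

// Illustrative sketch: transpose a KV_TILE x HEAD_DIM fp8 V tile held in
// shared memory from head_dim-major to K-major, as required by wgmma.fp8.
// The real kernel uses ldmatrix/stmatrix and swizzling instead of this
// plain element-wise indexing.
template <int KV_TILE, int HEAD_DIM>
__device__ void transpose_v_tile(const __nv_fp8_e4m3* __restrict__ v_in,  // [KV_TILE][HEAD_DIM]
                                 __nv_fp8_e4m3* __restrict__ v_out) {     // [HEAD_DIM][KV_TILE]
  // All threads of the block cooperatively copy elements to their
  // transposed positions.
  for (int idx = threadIdx.x; idx < KV_TILE * HEAD_DIM; idx += blockDim.x) {
    int kv = idx / HEAD_DIM;  // position along the KV dimension
    int d  = idx % HEAD_DIM;  // position along head_dim
    v_out[d * KV_TILE + kv] = v_in[kv * HEAD_DIM + d];
  }
  __syncthreads();
}
```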
P Requantization
After the `Q * K` multiplication and before `P * V`, P is requantized per tensor using an oracle scale: `p_scale = std::numeric_limits::max();`. This is based on the observation that the maximum value of P in online softmax is 1. This strategy follows the approach in SageAttention and increases the utilization of the 8-bit range compared to a direct cast.
Fused Dequantization
To reduce CUDA core computation overhead, both QK and PV dequantization steps are fused into existing online softmax operations: QK dequantization is fused into `sm_scale` (code reference), and PV dequantization is fused into the `finalize` step, where the denominator is applied to the output.
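A rough sketch of the bookkeeping behind this fusion; all scale names (`q_scale`, `k_scale`, `v_scale`, `p_scale`) are illustrative rather than the PR's actual variables:

```cuda
// Illustrative: instead of dequantizing the QK^T and PV accumulators with
// extra CUDA-core work, the dequantization factors are folded into
// multiplications the kernel already performs:
//   - q_scale * k_scale is folded into sm_scale (applied before softmax);
//   - v_scale / p_scale is folded into the finalize step, together with the
//     softmax denominator that normalizes the output.
struct FusedScales {
  float sm_scale_fused;  // sm_scale * q_scale * k_scale
  float o_scale_fused;   // v_scale / p_scale
};

inline FusedScales fuse_dequant_scales(float sm_scale, float q_scale, float k_scale,
                                       float v_scale, float p_scale) {
  return FusedScales{sm_scale * q_scale * k_scale, v_scale / p_scale};
}
```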
Remaining Work
- Support for differing `head_dim_qk` and `head_dim_v`
- `sparse_mainloop.cuh`
Perf Benchmarks on H100
FlashInfer-FP8 on average provides a 20-30% throughput boost compared to FP16. However, there is still a performance gap relative to FA3-FP8, calling for further optimization. Refer to the scripts.
Correctness (MSE)
To validate accuracy, we compute the MSE between the outputs of the different FP8 implementations and the output of FP16 FlashInfer. Refer to the scripts. Our implementation is slightly better.
