
Naive Support for Hopper FP8 Prefill Kernel with Per-Head Quantization #869

Merged: 10 commits into flashinfer-ai:main on Feb 27, 2025

Conversation

happierpig (Collaborator)

Summary

This PR introduces naive FP8 tensor core computation following FA3's implementation. The main code modifications are located in include/flashinfer/attention/hopper/quantization, with test cases in src/fp8-dev. The primary changes include:

  1. In-Kernel V Transpose
    Since wgmma.fp8 requires both operands to be K-major (contiguous along the contraction dimension) and Q/K/V are all head_dim-major, V must be transposed before being fed to the tensor cores. We therefore perform an in-kernel transpose in shared memory using ldmatrix/stmatrix (kernel_traits.cuh#L54).

  2. P Requantization
    After the Q * K^T multiplication and before P * V, P is requantized per tensor using an oracle scale equal to the FP8 type's maximum finite value (p_scale = std::numeric_limits<...>::max()). This is based on the observation that the maximum value of P in online softmax is 1. This strategy follows the approach in SageAttention and makes better use of the 8-bit dynamic range than a direct cast.

  3. Fused Dequantization
    To reduce CUDA core computation overhead, both QK and PV dequantization steps are fused into the existing online softmax operations: QK dequantization is fused into sm_scale (code reference), and PV dequantization is fused into the finalize step, where the softmax denominator is applied to the output (see the sketch after this list).
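
To make the quantization math in items 2 and 3 concrete, here is a minimal host-side sketch (not the actual kernel code), assuming e4m3 FP8 with a maximum finite value of 448 and per-head q_scale/k_scale/v_scale computed elsewhere. All names are illustrative, FP8 values are modeled as plain floats for readability, and the real kernel performs this as a tiled online softmax on tensor cores.

```cpp
// Hypothetical reference of the FP8 attention scaling described above.
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

constexpr float kFP8E4M3Max = 448.f;  // oracle p_scale: online-softmax P lies in [0, 1]

// One head, one query row. q_fp8 is [d]; k_fp8 and v_fp8 are [n][d].
void attention_row_fp8_reference(const std::vector<float>& q_fp8,
                                 const std::vector<std::vector<float>>& k_fp8,
                                 const std::vector<std::vector<float>>& v_fp8,
                                 float q_scale, float k_scale, float v_scale,
                                 float sm_scale, std::vector<float>& out) {
  const std::size_t n = k_fp8.size(), d = q_fp8.size();
  // (3) QK dequantization fused into sm_scale.
  const float fused_sm_scale = sm_scale * q_scale * k_scale;

  std::vector<float> s(n);
  float row_max = -INFINITY;
  for (std::size_t j = 0; j < n; ++j) {
    float acc = 0.f;
    for (std::size_t c = 0; c < d; ++c) acc += q_fp8[c] * k_fp8[j][c];  // wgmma.fp8 surrogate
    s[j] = acc * fused_sm_scale;
    row_max = std::max(row_max, s[j]);
  }

  // (2) Since P <= 1, requantize P per tensor with p_scale = FP8 max
  // (the kernel would cast p_q to e4m3 before the P * V matmul).
  float denom = 0.f;
  std::vector<float> p_q(n);
  for (std::size_t j = 0; j < n; ++j) {
    float p = std::exp(s[j] - row_max);
    denom += p;
    p_q[j] = p * kFP8E4M3Max;
  }

  // (3) PV dequantization fused into the finalize step together with the
  // softmax denominator.
  out.assign(d, 0.f);
  for (std::size_t c = 0; c < d; ++c) {
    float acc = 0.f;
    for (std::size_t j = 0; j < n; ++j) acc += p_q[j] * v_fp8[j][c];
    out[c] = acc * v_scale / (kFP8E4M3Max * denom);
  }
}
```

Because P is bounded by 1, scaling it by the FP8 maximum uses the full 8-bit range, and the extra p_scale factor is undone at no additional cost when the denominator is applied in the finalize step.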

Remaining Work

  • Upstream modifications to separate head_dim_qk and head_dim_v
  • Optimize performance to close the gap between FA3 and FlashInfer FP8, possibly by tuning kernel launch parameters
  • Add sparse and quantized support in sparse_mainloop.cuh

Perf Benchmarks on H100

FlashInfer-FP8 on average provides a 20-30% throughput boost compared to FP16. However, a performance gap remains relative to FA3-FP8, which calls for further optimization. Refer to the scripts.

(Throughput benchmark figures)

Correctness (MSE)

To validate accuracy, we compute the MSE between the outputs of the different FP8 implementations and the output of FP16 FlashInfer. Refer to the scripts. Our implementation is slightly more accurate.
(MSE comparison figure)
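
A minimal sketch of the accuracy check described above, assuming the FP16 FlashInfer output is flattened to a float array and used as the reference; the function and variable names are illustrative, not from the repository.

```cpp
#include <cstddef>
#include <vector>

// Mean squared error of an FP8 implementation's output against the FP16 reference.
double mse_vs_fp16_reference(const std::vector<float>& fp8_out,
                             const std::vector<float>& fp16_ref) {
  double acc = 0.0;
  for (std::size_t i = 0; i < fp8_out.size(); ++i) {
    const double diff = static_cast<double>(fp8_out[i]) - static_cast<double>(fp16_ref[i]);
    acc += diff * diff;
  }
  return acc / static_cast<double>(fp8_out.size());
}
```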

@happierpig requested a review from yzh119 on February 18, 2025
@yzh119 mentioned this pull request on Feb 19, 2025
@yzh119 (Collaborator) left a comment:

I'm good with the PR in general; let's merge it first and then iterate.


} // namespace flashinfer

#endif // FLASHINFER_ATTENTION_HOPPER_VARIANTS_CUH_
#endif // FLASHINFER_ATTENTION_HOPPER_VARIANTS_CUH_
Collaborator:

Can you use pre-commit to format code?

happierpig (Collaborator, Author):

Sounds good. Done.

@yzh119 merged commit f5dec3d into flashinfer-ai:main on Feb 27, 2025