
[ROCm] Apply FP8 weights padding to values not divisible by 512 bytes on ROCm #13231

Merged: 4 commits merged into vllm-project:main on Feb 22, 2025

Conversation

@gshtras (Contributor) commented on Feb 13, 2025:

A performance improvement for ROCm, working around a hardware limitation.

In GEMM on MI300, significant Tagram channel hotspot problems can occur when the stride of a matrix is a multiple of 512 bytes. This is especially true for TN transpose cases, where it can increase the latency of VMEM instructions and cause a significant drop in performance. Where possible, the application can use stride padding when allocating memory for the matrices to avoid any stride that is a multiple of 512 bytes (for example, for TN F16 GEMM, lda = M + 128 when M % 256 == 0).
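
To make the trick concrete, here is a minimal runnable sketch of the pad-then-slice approach (hypothetical sizes; fp16 is used so it runs on any build, whereas the PR applies this to FP8 weights, where element_size() is 1):

```python
import torch
import torch.nn.functional as F

weight = torch.randn(4096, 4096, dtype=torch.float16)
# 4096 elements * 2 bytes = 8192 bytes per row: a multiple of 512
assert (weight.stride(-2) * weight.element_size()) % 512 == 0

num_pad = 256 // weight.element_size()  # 256 bytes of padding, in elements
# Pad the last dim, then slice the padding back off: the values are unchanged,
# but the view keeps the wider underlying rows, so the row stride grows by
# 256 bytes and is no longer a multiple of 512.
weight = F.pad(weight, (0, num_pad), "constant", 0)[..., :-num_pad]

assert (weight.stride(-2) * weight.element_size()) % 512 != 0  # now 8448 bytes
assert not weight.is_contiguous()
```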

One requirement for this is that w8a8_block_fp8_matmul support non-contiguous weights, which it already appears to do, so the leftover assertion is obsolete.
While maintaining the same correctness, this shows the following latency improvements on ROCm:

amd/Llama-3.1-8B-Instruct-FP8-KV, bs=64, in=512, out=512, tp=1: 5.95s -> 5.7s (4%)
amd/Llama-3.1-70B-Instruct-FP8-KV, bs=64, in=512, out=512, tp=1: 25.6s -> 24.3s (5%)
deepseek-ai/DeepSeek-R1, bs=64, in=256, out=256, tp=8: 26.1s -> 24.9s (5%)

Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>

Commit: "… strides"
Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>

github-actions (bot) commented:

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, they only run fastcheck CI, which runs a small, essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run full CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

gshtras changed the title from "[ROCm] Apply FP8 weights padding to 256 bytes on ROCm" to "[ROCm] Apply FP8 weights padding to values not divisible by 512 bytes on ROCm" on Feb 13, 2025
hongxiayang added the rocm (Related to AMD ROCm) label on Feb 13, 2025
```python
# Excerpt from the change under review (the opening of the `if` is truncated
# in this view; it gates the padding on FP8 weights on ROCm):
        and (weight.stride(-2) * weight.element_size()) % 512 == 0):
    num_pad = 256 // weight.element_size()
    weight = F.pad(weight, (0, num_pad), "constant", 0)[..., :-num_pad]
    torch.cuda.empty_cache()
```
Contributor:

is empty_cache really necessary here?

gshtras (Author):

Without it there is a possibility of having double the memory stay allocated, depending on the allocator's behavior.
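
A small illustration of the concern (hypothetical sizes; requires a CUDA/ROCm device): F.pad materializes a new padded buffer while the old weight still exists, and after reassignment the old block is returned to PyTorch's caching allocator rather than to the device, so reserved memory stays roughly doubled until empty_cache() releases it.

```python
import torch
import torch.nn.functional as F

w = torch.empty(8192, 8192, device="cuda", dtype=torch.float16)  # ~128 MiB
w = F.pad(w, (0, 128))[..., :-128]   # allocates a second, slightly larger buffer
print(torch.cuda.memory_reserved())  # roughly two weights' worth reserved

torch.cuda.empty_cache()             # returns the cached, now-unused block
print(torch.cuda.memory_reserved())  # roughly one weight's worth
```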

@NickLucche (Contributor) commented:

Thanks for contributing! 🙏🏻
I only had a few comments to add while the actual review from the code owners is pending.

Co-authored-by: Michael Goin <mgoin64@gmail.com>
Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
gshtras force-pushed the fp8_padding_upstream branch from 6106325 to f3da192 on February 18, 2025 at 17:35
```diff
@@ -477,7 +477,7 @@ def w8a8_block_fp8_matmul(
     assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1]
     M = A.numel() // A.shape[-1]

-    assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2
+    assert B.ndim == 2 and Bs.ndim == 2
```
Collaborator:

Are we sure this is okay?

gshtras (Author):

The kernel works just fine with a padded non-contiguous tensor. And in any scenario other than padding, the weight should already be contiguous, so no existing workflow should break.
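
A quick sketch of why (generic shapes, not the vLLM code): the pad-then-slice result keeps a unit stride in the last dimension and only grows the leading stride, and a kernel that addresses B through its explicit strides never relies on contiguity.

```python
import torch
import torch.nn.functional as F

B = F.pad(torch.randn(1024, 1024), (0, 128))[..., :-128]
print(B.is_contiguous())  # False
print(B.stride())         # (1152, 1): element (i, j) sits at offset i*1152 + j
```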

Collaborator:

One other option is just to call weight.contiguous() after we pad it in process_weights_after_loading?

Collaborator:

WDYT?

gshtras (Author):

This would remove the padding, reverting the effect of F.pad.
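
A tiny demonstration of that point (hypothetical shape): .contiguous() repacks the rows densely, restoring exactly the 512-byte-multiple stride the padding was meant to avoid.

```python
import torch
import torch.nn.functional as F

w = torch.zeros(4, 256, dtype=torch.float16)        # row stride: 512 bytes
p = F.pad(w, (0, 128))[..., :-128]                  # row stride: 768 bytes
print(p.stride(0) * p.element_size())               # 768
print(p.contiguous().stride(0) * p.element_size())  # 512 again: padding undone
```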

Collaborator:

Sorry, that was a dumb comment by me.

Contributor:

@gshtras I agree contiguous here was overly strict. But should we still check that the stride is 1 for the last dimension, i.e. B.stride(-1) == 1?
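
For concreteness, the suggested check would look something like this (a sketch of the proposal, not merged code):

```python
# Allow padded, non-contiguous B, but still require densely packed rows.
assert B.ndim == 2 and B.stride(-1) == 1 and Bs.ndim == 2
```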

robertgshaw2-redhat added the ready (ONLY add when PR is ready to merge/full CI is needed) label on Feb 21, 2025
@robertgshaw2-redhat (Collaborator) commented:

Nice work!

simon-mo merged commit c904fdd into vllm-project:main on Feb 22, 2025
42 of 46 checks passed
Labels: ready (ONLY add when PR is ready to merge/full CI is needed), rocm (Related to AMD ROCm)
7 participants