Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ROCm][Kernel] Using the correct warp_size value #12789

Merged
merged 1 commit into from
Feb 6, 2025

Conversation

gshtras
Copy link
Contributor

@gshtras gshtras commented Feb 5, 2025

The code assumes WARP_SIZE to be equal to 32, based on the hard-coded values of shared int32_t shared_counts[32][8];
This is not the case on ROCm, which leads to half the array not being initialized

…n ROCm

Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
Copy link

github-actions bot commented Feb 5, 2025

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

Copy link
Member

@mgoin mgoin left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@chenyang78
Copy link
Contributor

Hmm, seems WARP_SIZE is defined to be 32 for ROCM?

vllm/csrc/cuda_compat.h

Lines 7 to 11 in bc1bdec

#ifndef USE_ROCM
#define WARP_SIZE 32
#else
#define WARP_SIZE warpSize
#endif

Caveat: I am new to vllm so I could be wrong :)

@LucasWilkinson may take a look, too. Thanks!

@gshtras
Copy link
Contributor Author

gshtras commented Feb 5, 2025

#define WARP_SIZE warpSize

builtin warpSize is 64
Note that hard-coded 32 is in the ifndef USE_ROCM

@chenyang78
Copy link
Contributor

chenyang78 commented Feb 5, 2025

#define WARP_SIZE warpSize

builtin warpSize is 64 Note that hard-coded 32 is in the ifndef USE_ROCM

Ah, it's ifndef :) My mistake. Sorry for the noise.

Copy link
Collaborator

@LucasWilkinson LucasWilkinson left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, although im not a ROCm expert

@mgoin mgoin added the ready ONLY add when PR is ready to merge/full CI is needed label Feb 5, 2025
@simon-mo simon-mo merged commit 5b19b93 into vllm-project:main Feb 6, 2025
65 of 74 checks passed
Copy link
Contributor

@houseroad houseroad left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wondering it's warp size affecting the accuracy or perf or both?

@gshtras gshtras deleted the warp_size_fix_upstream branch February 7, 2025 15:55
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
ready ONLY add when PR is ready to merge/full CI is needed
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants