Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

HIP: force max threads per block to be 1024 #11621

Merged
merged 1 commit into from
Feb 4, 2025

Conversation

fxzjshm
Copy link
Contributor

@fxzjshm fxzjshm commented Feb 3, 2025

Some old compilers still use 256. Explicitly set it to 1024 to get correct result from ops like ARGMAX and GROUP_NORM.

Related: #10610, #11619

CC @IMbackK

@github-actions github-actions bot added the ggml changes relating to the ggml tensor library for machine learning label Feb 3, 2025
Copy link
Collaborator

@IMbackK IMbackK left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

over all i am fine with this, but will defer to @slaren on if this kind of vendor behavior is something we want to support (see discussion in #11619)

@@ -40,6 +40,9 @@ find_package(hip REQUIRED)
find_package(hipblas REQUIRED)
find_package(rocblas REQUIRED)

# Workaround old compilers
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please move this down a bit as the find_package calls and the version check below are logically related operations

@slaren
Copy link
Member

slaren commented Feb 3, 2025

I saw the discussion, but don't have any knowledge about HIP/ROCm to have an opinion about this. If you think that it is not likely to cause issues to other users, feel free to merge it.

Some old compilers still use 256. Explicitly set it to 1024 to get correct
result from ops like ARGMAX and GROUP_NORM.

Related: ggml-org#10610, ggml-org#11619
Signed-off-by: fxzjshm <fxzjshm@163.com>
@fxzjshm
Copy link
Contributor Author

fxzjshm commented Feb 4, 2025

@IMbackK Moved. Is this place proper?

@slaren This compiler flag is documented at https://clang.llvm.org/docs/ClangCommandLineReference.html#cmdoption-clang-gpu-max-threads-per-block. I've also compiled with ROCm 6.3.1 and no compile error is given, now testing test-backend-ops.

Update: test-backend-ops w/ ROCm 6.3.1 on gfx1100 passed.

Copy link
Collaborator

@IMbackK IMbackK left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This just resets the default, also maximum value for all current amd gpus, it dosent change the code generation at all on sane versions of llvm at this time. We might run in to problems in the future if amd changes this for a new gpu arch - but i think this is an acceptable risk.

@IMbackK IMbackK merged commit 3ec9fd4 into ggml-org:master Feb 4, 2025
1 check passed
tinglou pushed a commit to tinglou/llama.cpp that referenced this pull request Feb 13, 2025
Some old/vendor forked version of llvm still use 256. Explicitly set it to 1024 to align with upstream llvm.

Signed-off-by: fxzjshm <fxzjshm@163.com>
orca-zhang pushed a commit to orca-zhang/llama.cpp that referenced this pull request Feb 26, 2025
Some old/vendor forked version of llvm still use 256. Explicitly set it to 1024 to align with upstream llvm.

Signed-off-by: fxzjshm <fxzjshm@163.com>
arthw pushed a commit to arthw/llama.cpp that referenced this pull request Feb 26, 2025
Some old/vendor forked version of llvm still use 256. Explicitly set it to 1024 to align with upstream llvm.

Signed-off-by: fxzjshm <fxzjshm@163.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
ggml changes relating to the ggml tensor library for machine learning
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants