CUDA: remove DMMV, consolidate F16 mult mat vec #10318
Conversation
On an RTX 3090, tg8192 without FA is faster than with FA. Is this expected?
I forgot to mention: this is because the 1b model has a head size of 64 (vs. 128 in the larger LLaMA variants) and the CUDA FA kernel is comparatively poorly optimized for that head size.
Copilot reviewed 2 out of 10 changed files in this pull request and generated no suggestions.
Files not reviewed (8)
- Makefile: Language not supported
- ggml/CMakeLists.txt: Language not supported
- ggml/src/ggml-cuda/CMakeLists.txt: Language not supported
- ggml/src/ggml-cuda/ggml-cuda.cu: Language not supported
- ggml/src/ggml-cuda/mmv.cu: Language not supported
- ggml/src/ggml-cuda/mmv.cuh: Language not supported
- ggml/src/ggml-hip/CMakeLists.txt: Language not supported
- ggml/src/ggml-musa/CMakeLists.txt: Language not supported
That's one way to stop the singularity from happening lol. What's funny is that it counts the deletion of …
Incidentally this also seems to improve performance with the …
For an FP32 model the only difference is in the attention, particularly the calculation of KQV. My expectation for FP32 would be that cuBLAS GEMM would be faster (but I had also expected cuBLAS GEMM to be faster for FP16).
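For context (my own paraphrase, assuming the default FP16 KV cache; not part of the original comment): even with an FP32 model, the two attention matrix multiplications still have FP16 operands because K and V come from the cache, roughly

$$\mathrm{KQ} = K^{\top} Q, \qquad \mathrm{KQV} = V \cdot \operatorname{softmax}\!\left(\tfrac{1}{\sqrt{d_{\text{head}}}}\,\mathrm{KQ} + \text{mask}\right),$$

which at batch size 1 are exactly the kind of FP16 matrix-vector products this PR rewrites.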
The FP16 batch size 1 kernels on master are relatively old and poorly written. This PR replaces them with a single kernel, which both reduces the complexity and improves the performance of the CUDA backend. Since there are now better alternatives for all previous use cases of the `dequantize_mul_mat_vec` kernels, I removed them. This also makes it possible to remove some compilation options; those choices are now essentially made at runtime based on GPU features and matrix shape.
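To illustrate the general shape of such a kernel (this is only a minimal sketch of the one-block-per-output-row approach, not the actual code in `ggml/src/ggml-cuda/mmv.cu`, which additionally handles strides, multiple source types, and larger block sizes):

```cuda
#include <cuda_fp16.h>

// Sketch: y = A*x for a row-major (nrows x ncols) FP16 matrix A and FP16
// vector x, accumulating in FP32. One CUDA block handles one output row;
// the reduction below assumes blockDim.x == 32 (a single warp).
__global__ void mul_mat_vec_f16(const half * __restrict__ A,
                                const half * __restrict__ x,
                                float * __restrict__ y,
                                const int ncols) {
    const int    row     = blockIdx.x;
    const half * row_ptr = A + (size_t) row * ncols;

    float sum = 0.0f;
    for (int col = threadIdx.x; col < ncols; col += blockDim.x) {
        sum += __half2float(row_ptr[col]) * __half2float(x[col]);
    }

    // warp-level reduction of the partial dot products
    for (int offset = 16; offset > 0; offset >>= 1) {
        sum += __shfl_down_sync(0xffffffff, sum, offset);
    }

    if (threadIdx.x == 0) {
        y[row] = sum;
    }
}

// Example launch: mul_mat_vec_f16<<<nrows, 32>>>(A, x, y, ncols);
```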
Performance changes

This PR affects performance not only for FP16 models but also for all other models when not using FlashAttention. The code for LLaMA 3 8b q4_0 with FlashAttention has not changed; any differences are just noise, and I only included it for comparison. I noticed that with this PR, even on modern GPUs, there are cases where the kernel I added is faster for FP16 models than using cuBLAS, specifically when `src0` is thin in terms of rows. My suspicion is that the cuBLAS code uses tensor cores even at batch size 1, which reduces the number of CUDA blocks and causes suboptimal utilization for thin matrices.
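A back-of-the-envelope example of that utilization argument (my own numbers, not measurements from the PR; the 128-row tile height is a hypothetical value):

```cpp
#include <cstdio>

int main() {
    const int nrows     = 2048; // a "thin" src0 in terms of rows
    const int sm_count  = 82;   // streaming multiprocessors on an RTX 3090
    const int gemm_tile = 128;  // hypothetical tensor-core tile height

    // A tiled GEMM launches one block per row tile; a mat-vec kernel that
    // assigns one block per output row launches far more blocks.
    const int gemm_blocks = nrows / gemm_tile; // 16 blocks: less than 1 per SM
    const int mmv_blocks  = nrows;             // 2048 blocks: many waves per SM

    printf("GEMM blocks: %d, mat-vec blocks: %d, SMs: %d\n",
           gemm_blocks, mmv_blocks, sm_count);
    return 0;
}
```

With only 16 blocks, most of the 82 SMs sit idle, which would explain why the dedicated mat-vec kernel can win for thin matrices at batch size 1.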