CUDA: remove DMMV, consolidate F16 mult mat vec #10318

Merged: 1 commit into ggerganov:master on Nov 17, 2024

Conversation

JohannesGaessler (Collaborator)

The FP16 batch size 1 kernels on master are relatively old and poorly written. This PR replaces them with a single kernel that both reduces the complexity and improves the performance of the CUDA backend. Since there are now better alternatives for all previous use cases of the dequantize_mul_mat_vec kernels, I removed them. This also makes it possible to remove some compilation options; those choices are now essentially made at runtime based on GPU features and matrix shape.
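
For readers unfamiliar with the general pattern, below is a minimal, hypothetical sketch of a batch-size-1 FP16 mat-vec kernel: one warp accumulates one output row and then does a warp reduction. The function name, argument layout, and launch configuration are illustrative assumptions, not the actual code in mmv.cu.

```cuda
// Hypothetical sketch only (not the PR's mmv.cu): one warp computes one
// output row of dst = src0 * src1 for batch size 1, accumulating in FP32.
#include <cuda_fp16.h>

__global__ void mmv_f16_sketch(const half * __restrict__ x,   // src0: nrows x ncols, row-major
                               const half * __restrict__ y,   // src1: ncols (single column)
                               float      * __restrict__ dst, // dst:  nrows
                               const int    ncols) {
    const int row  = blockIdx.x;   // launched with one 32-thread block per row
    const int lane = threadIdx.x;  // 0..31

    float sum = 0.0f;
    for (int col = lane; col < ncols; col += warpSize) {
        sum += __half2float(x[row*ncols + col]) * __half2float(y[col]);
    }

    // reduce the per-thread partial sums within the warp
    for (int offset = warpSize/2; offset > 0; offset >>= 1) {
        sum += __shfl_down_sync(0xffffffff, sum, offset);
    }

    if (lane == 0) {
        dst[row] = sum;
    }
}
```

On GPUs with fast FP16 arithmetic a real implementation would typically use vectorized half2 loads and intermediate half2 accumulation instead of converting every element to FP32; this is the kind of per-GPU choice that can now be made at runtime rather than at compile time.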

Performance changes

| GPU | Model | FlashAttention | Test | t/s master | t/s cuda-mmv-5 | Speedup |
|---|---|---|---|---|---|---|
| RX 6800 | llama 1B F16 | No | tg128 | 70.01 | 93.96 | 1.34 |
| RX 6800 | llama 1B F16 | No | tg8192 | 36.62 | 74.87 | 2.04 |
| RX 6800 | llama 1B F16 | Yes | tg128 | 71.34 | 91.49 | 1.28 |
| RX 6800 | llama 1B F16 | Yes | tg8192 | 45.12 | 52.22 | 1.16 |
| RX 6800 | llama 8B Q4_0 | No | tg128 | 57.60 | 62.16 | 1.08 |
| RX 6800 | llama 8B Q4_0 | No | tg8192 | 14.16 | 35.14 | 2.48 |
| RX 6800 | llama 8B Q4_0 | Yes | tg128 | 55.82 | 55.75 | 1.00 |
| RX 6800 | llama 8B Q4_0 | Yes | tg8192 | 27.48 | 27.48 | 1.00 |
| RTX 3090 | llama 1B F16 | No | tg128 | 251.38 | 259.00 | 1.03 |
| RTX 3090 | llama 1B F16 | No | tg8192 | 200.04 | 219.62 | 1.10 |
| RTX 3090 | llama 1B F16 | Yes | tg128 | 258.41 | 259.51 | 1.00 |
| RTX 3090 | llama 1B F16 | Yes | tg8192 | 212.36 | 214.25 | 1.01 |
| RTX 3090 | llama 8B Q4_0 | No | tg128 | 132.38 | 135.45 | 1.02 |
| RTX 3090 | llama 8B Q4_0 | No | tg8192 | 99.37 | 104.29 | 1.05 |
| RTX 3090 | llama 8B Q4_0 | Yes | tg128 | 137.31 | 136.56 | 0.99 |
| RTX 3090 | llama 8B Q4_0 | Yes | tg8192 | 111.86 | 112.06 | 1.00 |
| RTX 4090 | llama 1B F16 | No | tg128 | 274.95 | 285.66 | 1.04 |
| RTX 4090 | llama 1B F16 | No | tg8192 | 236.64 | 257.96 | 1.09 |
| RTX 4090 | llama 1B F16 | Yes | tg128 | 291.35 | 293.05 | 1.01 |
| RTX 4090 | llama 1B F16 | Yes | tg8192 | 247.34 | 248.90 | 1.01 |
| RTX 4090 | llama 8B Q4_0 | No | tg128 | 150.02 | 153.32 | 1.02 |
| RTX 4090 | llama 8B Q4_0 | No | tg8192 | 123.97 | 131.14 | 1.06 |
| RTX 4090 | llama 8B Q4_0 | Yes | tg128 | 157.97 | 158.06 | 1.00 |
| RTX 4090 | llama 8B Q4_0 | Yes | tg8192 | 138.00 | 138.11 | 1.00 |
| P40 | llama 1B F16 | No | tg128 | 100.31 | 103.92 | 1.04 |
| P40 | llama 1B F16 | No | tg8192 | 65.47 | 72.57 | 1.11 |
| P40 | llama 1B F16 | Yes | tg128 | 98.81 | 99.88 | 1.01 |
| P40 | llama 1B F16 | Yes | tg8192 | 82.43 | 83.28 | 1.01 |
| P40 | llama 8B Q4_0 | No | tg128 | 50.36 | 51.49 | 1.02 |
| P40 | llama 8B Q4_0 | No | tg8192 | 29.46 | 31.69 | 1.08 |
| P40 | llama 8B Q4_0 | Yes | tg128 | 48.49 | 48.54 | 1.00 |
| P40 | llama 8B Q4_0 | Yes | tg8192 | 39.19 | 39.21 | 1.00 |

This PR affects performance not only for FP16 models but also for all other models when FlashAttention is not used. The code path for LLaMA 3 8B q4_0 with FlashAttention has not changed; any differences there are just noise, and I only included those rows for comparison. I noticed that with this PR, even on modern GPUs, there are cases for FP16 models where the kernel I added is faster than cuBLAS, specifically when src0 is thin in terms of rows. My suspicion is that cuBLAS uses tensor cores even at batch size 1, which reduces the number of CUDA blocks and causes suboptimal utilization for thin matrices.
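
To illustrate the kind of runtime decision described above, a shape- and feature-based dispatch heuristic could look roughly like the following. The helper name, parameters, and the row threshold are made up for this sketch and are not the actual logic in ggml-cuda.cu.

```cuda
// Illustrative only: decide between a custom mat-vec kernel and a cuBLAS
// GEMM based on matrix shape and GPU features. Names and the threshold
// value are placeholders, not the PR's real heuristic.
#include <cstdint>

static bool use_custom_mmv(int64_t src0_rows, int64_t src1_cols, bool gpu_has_fast_fp16) {
    if (src1_cols != 1) {
        return false;  // batched matrix multiplication: leave it to the GEMM paths
    }
    if (!gpu_has_fast_fp16) {
        return true;   // older GPUs (e.g. P40-class): FP32-accumulating mat-vec path
    }
    // Thin src0: a tensor-core GEMM launches few blocks and underutilizes the
    // GPU, so the plain mat-vec kernel tends to win here.
    const int64_t thin_row_threshold = 4096;  // hypothetical cutoff
    return src0_rows <= thin_row_threshold;
}
```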

JohannesGaessler added the "Review Complexity: Medium" label on Nov 15, 2024.
ggerganov (Owner)

On the RTX 3090, tg8192 without FA is faster than with FA. Is this expected?

JohannesGaessler (Collaborator, Author)

I forgot to mention: this is because the 1B model has a head size of 64 (vs. 128 in the larger LLaMA variants), and the CUDA FA kernel is comparatively poorly optimized for that head size.

ggerganov requested a review from Copilot on November 15, 2024 at 20:50.


Copilot reviewed 2 out of 10 changed files in this pull request and generated no suggestions.

Files not reviewed (8)
  • Makefile: Language not supported
  • ggml/CMakeLists.txt: Language not supported
  • ggml/src/ggml-cuda/CMakeLists.txt: Language not supported
  • ggml/src/ggml-cuda/ggml-cuda.cu: Language not supported
  • ggml/src/ggml-cuda/mmv.cu: Language not supported
  • ggml/src/ggml-cuda/mmv.cuh: Language not supported
  • ggml/src/ggml-hip/CMakeLists.txt: Language not supported
  • ggml/src/ggml-musa/CMakeLists.txt: Language not supported
JohannesGaessler (Collaborator, Author)

That's one way to stop the singularity from happening lol. What's funny is that it counts the deletion of dmmv.cu as a successful review even though it's impossible for it to read the code.

github-actions bot added the "documentation" and "Nvidia GPU" labels on Nov 15, 2024.
slaren (Collaborator) commented on Nov 16, 2024

Incidentally, this also seems to improve performance with the stories260k.gguf F32 model, but I am not sure why. The model is F32, so it shouldn't use this path, but maybe it is used for the attention. I was also wondering whether it would be worth making a version of this kernel for F32 as well, since it currently only supports F16.

| Model | Test | t/s master | t/s cuda-mmv-5 | Speedup |
|---|---|---|---|---|
| llama ?B all F32 (guessed) | pp512 | 278910.28 | 278139.21 | 1.00 |
| llama ?B all F32 (guessed) | tg1024 | 2522.71 | 2882.23 | 1.14 |

JohannesGaessler merged commit c3ea58a into ggerganov:master on Nov 17, 2024 (54 checks passed).
JohannesGaessler (Collaborator, Author)

For an FP32 model the only difference is in the attention, specifically the calculation of KQV. My expectation for FP32 would be that cuBLAS GEMM is faster (but then, I had also expected cuBLAS GEMM to be faster for FP16).
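
Picking up the F32 suggestion above: in principle the same kernel body could cover both types by templating over the src0 element type. The following is a hypothetical sketch in the spirit of the earlier one, not code from this PR.

```cuda
// Hypothetical sketch: template the warp-per-row mat-vec kernel over the
// src0 element type so that half and float share one implementation.
#include <cuda_fp16.h>

template <typename T>  // T = half or float
__global__ void mmv_sketch(const T * __restrict__ x, const float * __restrict__ y,
                           float * __restrict__ dst, const int ncols) {
    const int row  = blockIdx.x;
    const int lane = threadIdx.x;

    float sum = 0.0f;
    for (int col = lane; col < ncols; col += warpSize) {
        sum += float(x[row*ncols + col]) * y[col];  // half converts via its float operator
    }
    for (int offset = warpSize/2; offset > 0; offset >>= 1) {
        sum += __shfl_down_sync(0xffffffff, sum, offset);
    }
    if (lane == 0) {
        dst[row] = sum;
    }
}

// Example launches (host side), assuming device pointers and one warp per row:
//   mmv_sketch<half> <<<nrows, 32>>>(x_f16, y, dst, ncols);
//   mmv_sketch<float><<<nrows, 32>>>(x_f32, y, dst, ncols);
```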
