Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

vulkan: Optimize some mat-vec mul quant shaders #10296

Merged
merged 1 commit into from
Nov 16, 2024

Conversation

jeffbolznv
Copy link
Collaborator

Compute two result elements per workgroup (for Q{4,5}_{0,1}). This reuses the B loads across the rows and also reuses some addressing calculations. This required manually partially unrolling the loop, since the compiler is less willing to unroll outer loops.

Add bounds-checking on the last iteration of the loop. I think this was at least partly broken before. I'd also like to be able to disable robustness for some of these pipelines in the future, to get a bit more perf.

Optimize the Q4_K shader to vectorize most loads and reduce the number of bit twiddling instructions. It should be possible to do something similar to other Qi_K shaders. I can maybe do this, but happy for somebody else to do it.

Perf results below. Still slower than CUDA (which is using dp4a), but a nice boost. Definitely worth testing on some other hardware, too.

Before:
  MUL_MAT(type_a=f32,type_b=f32,m=4096,n=1,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                   10224 runs -   492.38 us/run - 117.44 MFLOP/run - 238.52 GFLOPS
  MUL_MAT(type_a=f16,type_b=f32,m=4096,n=1,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                   20448 runs -   248.94 us/run - 117.44 MFLOP/run - 471.76 GFLOPS
  MUL_MAT(type_a=q4_0,type_b=f32,m=4096,n=1,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                  54528 runs -    92.30 us/run - 117.44 MFLOP/run -   1.27 TFLOPS
  MUL_MAT(type_a=q4_1,type_b=f32,m=4096,n=1,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                  45156 runs -   110.73 us/run - 117.44 MFLOP/run -   1.06 TFLOPS
  MUL_MAT(type_a=q5_0,type_b=f32,m=4096,n=1,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                  21300 runs -   241.06 us/run - 117.44 MFLOP/run - 487.19 GFLOPS
  MUL_MAT(type_a=q5_1,type_b=f32,m=4096,n=1,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                  16188 runs -   309.08 us/run - 117.44 MFLOP/run - 379.97 GFLOPS
  MUL_MAT(type_a=q8_0,type_b=f32,m=4096,n=1,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                  34932 runs -   145.27 us/run - 117.44 MFLOP/run - 808.44 GFLOPS
  MUL_MAT(type_a=q2_K,type_b=f32,m=4096,n=1,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                  51120 runs -    97.85 us/run - 117.44 MFLOP/run -   1.20 TFLOPS
  MUL_MAT(type_a=q3_K,type_b=f32,m=4096,n=1,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                  39192 runs -   128.78 us/run - 117.44 MFLOP/run - 911.94 GFLOPS
  MUL_MAT(type_a=q4_K,type_b=f32,m=4096,n=1,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                  50268 runs -    99.85 us/run - 117.44 MFLOP/run -   1.18 TFLOPS
  MUL_MAT(type_a=q5_K,type_b=f32,m=4096,n=1,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                  40896 runs -   124.83 us/run - 117.44 MFLOP/run - 940.77 GFLOPS
  MUL_MAT(type_a=q6_K,type_b=f32,m=4096,n=1,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                  40896 runs -   123.86 us/run - 117.44 MFLOP/run - 948.14 GFLOPS
  MUL_MAT(type_a=iq4_nl,type_b=f32,m=4096,n=1,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                11928 runs -   440.35 us/run - 117.44 MFLOP/run - 266.70 GFLOPS
  
| baichuan 13B Q4_0              |   7.44 GiB |    13.90 B | Vulkan     | 1000 |         tg128 |         35.03  0.13 |
| starcoder2 7B Q4_0             |   3.88 GiB |     7.40 B | Vulkan     | 1000 |         tg128 |         55.92  0.37 |
| phi3 3B Q4_K - Medium          |   2.23 GiB |     3.82 B | Vulkan     | 1000 |         tg128 |         93.81  0.97 |

After:
  MUL_MAT(type_a=f32,type_b=f32,m=4096,n=1,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                   10224 runs -   493.00 us/run - 117.44 MFLOP/run - 238.21 GFLOPS
  MUL_MAT(type_a=f16,type_b=f32,m=4096,n=1,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                   20448 runs -   251.57 us/run - 117.44 MFLOP/run - 466.83 GFLOPS
  MUL_MAT(type_a=q4_0,type_b=f32,m=4096,n=1,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                  65604 runs -    77.18 us/run - 117.44 MFLOP/run -   1.52 TFLOPS
  MUL_MAT(type_a=q4_1,type_b=f32,m=4096,n=1,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                  55380 runs -    91.59 us/run - 117.44 MFLOP/run -   1.28 TFLOPS
  MUL_MAT(type_a=q5_0,type_b=f32,m=4096,n=1,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                  36636 runs -   139.45 us/run - 117.44 MFLOP/run - 842.17 GFLOPS
  MUL_MAT(type_a=q5_1,type_b=f32,m=4096,n=1,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                  33228 runs -   154.00 us/run - 117.44 MFLOP/run - 762.60 GFLOPS
  MUL_MAT(type_a=q8_0,type_b=f32,m=4096,n=1,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                  35784 runs -   141.88 us/run - 117.44 MFLOP/run - 827.74 GFLOPS
  MUL_MAT(type_a=q2_K,type_b=f32,m=4096,n=1,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                  53676 runs -    94.49 us/run - 117.44 MFLOP/run -   1.24 TFLOPS
  MUL_MAT(type_a=q3_K,type_b=f32,m=4096,n=1,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                  39192 runs -   129.18 us/run - 117.44 MFLOP/run - 909.13 GFLOPS
  MUL_MAT(type_a=q4_K,type_b=f32,m=4096,n=1,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                  82644 runs -    60.54 us/run - 117.44 MFLOP/run -   1.94 TFLOPS
  MUL_MAT(type_a=q5_K,type_b=f32,m=4096,n=1,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                  39192 runs -   130.01 us/run - 117.44 MFLOP/run - 903.33 GFLOPS
  MUL_MAT(type_a=q6_K,type_b=f32,m=4096,n=1,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                  40896 runs -   123.02 us/run - 117.44 MFLOP/run - 954.66 GFLOPS
  MUL_MAT(type_a=iq4_nl,type_b=f32,m=4096,n=1,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                11076 runs -   459.25 us/run - 117.44 MFLOP/run - 255.72 GFLOPS

| baichuan 13B Q4_0              |   7.44 GiB |    13.90 B | Vulkan     | 1000 |         tg128 |         40.93  0.30 |
| starcoder2 7B Q4_0             |   3.88 GiB |     7.40 B | Vulkan     | 1000 |         tg128 |         64.30  0.82 |
| phi3 3B Q4_K - Medium          |   2.23 GiB |     3.82 B | Vulkan     | 1000 |         tg128 |        105.04  0.55 |

Split out from #10206, but the code is pretty different.

Compute two result elements per workgroup (for Q{4,5}_{0,1}). This reuses
the B loads across the rows and also reuses some addressing calculations.
This required manually partially unrolling the loop, since the compiler
is less willing to unroll outer loops.

Add bounds-checking on the last iteration of the loop. I think this was at
least partly broken before.

Optimize the Q4_K shader to vectorize most loads and reduce the number of
bit twiddling instructions.
@jeffbolznv jeffbolznv requested a review from 0cc4m November 14, 2024 18:18
@0cc4m
Copy link
Collaborator

0cc4m commented Nov 14, 2024

Thank you, this is quite impressive!

I tested these models:

  • Llama 3 8B Q4_K_S
  • Llama 3 8B q4_0
  • Llama 2 13B q4_0
  • Mistral Nemo q5_0

Nvidia RTX 3090:

Before:

model size params backend ngl test t/s
llama 8B Q4_K - Small 4.36 GiB 8.03 B Vulkan 99 tg128 66.12 ± 0.99
llama 8B Q4_0 5.61 GiB 8.03 B Vulkan 99 tg128 57.64 ± 0.04
llama 13B Q4_0 6.86 GiB 13.02 B Vulkan 99 tg128 38.73 ± 0.12
llama 13B Q5_0 7.93 GiB 12.25 B Vulkan 99 tg128 20.79 ± 0.04

After:

model size params backend ngl test t/s
llama 8B Q4_K - Small 4.36 GiB 8.03 B Vulkan 99 tg128 77.54 ± 0.13
llama 8B Q4_0 5.61 GiB 8.03 B Vulkan 99 tg128 73.46 ± 0.09
llama 13B Q4_0 6.86 GiB 13.02 B Vulkan 99 tg128 50.04 ± 0.05
llama 13B Q5_0 7.93 GiB 12.25 B Vulkan 99 tg128 38.89 ± 0.11

Nvidia Tesla P40:

Before:

model size params backend ngl test t/s
llama 8B Q4_K - Small 4.36 GiB 8.03 B Vulkan 99 tg128 21.52 ± 0.03
llama 8B Q4_0 5.61 GiB 8.03 B Vulkan 99 tg128 22.80 ± 0.02
llama 13B Q4_0 6.86 GiB 13.02 B Vulkan 99 tg128 13.85 ± 0.00
llama 13B Q5_0 7.93 GiB 12.25 B Vulkan 99 tg128 8.49 ± 0.01

After:

model size params backend ngl test t/s
llama 8B Q4_K - Small 4.36 GiB 8.03 B Vulkan 99 tg128 38.30 ± 0.12
llama 8B Q4_0 5.61 GiB 8.03 B Vulkan 99 tg128 22.19 ± 0.01
llama 13B Q4_0 6.86 GiB 13.02 B Vulkan 99 tg128 13.05 ± 0.00
llama 13B Q5_0 7.93 GiB 12.25 B Vulkan 99 tg128 8.92 ± 0.59

AMD Radeon Pro VII:

Before:

model size params backend ngl test t/s
llama 8B Q4_K - Small 4.36 GiB 8.03 B Vulkan 99 tg128 25.14 ± 0.52
llama 8B Q4_0 5.61 GiB 8.03 B Vulkan 99 tg128 28.82 ± 0.54
llama 13B Q4_0 6.86 GiB 13.02 B Vulkan 99 tg128 16.30 ± 1.93
llama 13B Q5_0 7.93 GiB 12.25 B Vulkan 99 tg128 14.74 ± 0.06

After:

model size params backend ngl test t/s
llama 8B Q4_K - Small 4.36 GiB 8.03 B Vulkan 99 tg128 29.90 ± 0.03
llama 8B Q4_0 5.61 GiB 8.03 B Vulkan 99 tg128 38.49 ± 3.24
llama 13B Q4_0 6.86 GiB 13.02 B Vulkan 99 tg128 26.55 ± 0.35
llama 13B Q5_0 7.93 GiB 12.25 B Vulkan 99 tg128 18.09 ± 0.21

AMD Radeon RX 6800 XT:

Before:

model size params backend ngl threads test t/s
llama 8B Q4_K - Small 4.36 GiB 8.03 B Vulkan 99 8 tg128 83.11 ± 0.12
llama 8B Q4_0 5.61 GiB 8.03 B Vulkan 99 8 tg128 59.96 ± 0.35
llama 13B Q4_0 6.86 GiB 13.02 B Vulkan 99 8 tg128 39.01 ± 0.06
llama 13B Q5_0 7.93 GiB 12.25 B Vulkan 99 8 tg128 30.71 ± 0.01

After:

model size params backend ngl threads test t/s
llama 8B Q4_K - Small 4.36 GiB 8.03 B Vulkan 99 8 tg128 84.17 ± 0.02
llama 8B Q4_0 5.61 GiB 8.03 B Vulkan 99 8 tg128 70.52 ± 0.38
llama 13B Q4_0 6.86 GiB 13.02 B Vulkan 99 8 tg128 45.76 ± 0.17
llama 13B Q5_0 7.93 GiB 12.25 B Vulkan 99 8 tg128 37.49 ± 0.08

@netrunnereve
Copy link
Collaborator

This is 50% faster on Q4_0 with a RX 570, very nice!

Master

model size params backend ngl threads test t/s
llama 8B Q4_0 4.33 GiB 8.03 B Vulkan 100 8 pp512 95.83 ± 0.19
llama 8B Q4_0 4.33 GiB 8.03 B Vulkan 100 8 tg128 7.74 ± 0.06
llama 8B Q4_K - Small 4.36 GiB 8.03 B Vulkan 100 8 pp512 76.86 ± 0.09
llama 8B Q4_K - Small 4.36 GiB 8.03 B Vulkan 100 8 tg128 8.09 ± 0.02

PR

model size params backend ngl threads test t/s
llama 8B Q4_0 4.33 GiB 8.03 B Vulkan 100 8 pp512 96.81 ± 0.09
llama 8B Q4_0 4.33 GiB 8.03 B Vulkan 100 8 tg128 11.61 ± 0.05
llama 8B Q4_K - Small 4.36 GiB 8.03 B Vulkan 100 8 pp512 76.83 ± 0.43
llama 8B Q4_K - Small 4.36 GiB 8.03 B Vulkan 100 8 tg128 9.18 ± 0.02

@0cc4m
Copy link
Collaborator

0cc4m commented Nov 15, 2024

We also didn't break Intel yet:

Intel ARC A770:
Before:

model size params backend ngl test t/s
llama 8B Q4_K - Small 4.36 GiB 8.03 B Vulkan 99 tg128 22.56 ± 0.05
llama 8B Q4_0 5.61 GiB 8.03 B Vulkan 99 tg128 28.77 ± 0.05
llama 13B Q4_0 6.86 GiB 13.02 B Vulkan 99 tg128 20.51 ± 0.03
llama 13B Q5_0 7.93 GiB 12.25 B Vulkan 99 tg128 4.87 ± 0.00
llama 13B Q4_0 6.60 GiB 12.25 B Vulkan 99 tg128 21.41 ± 0.08

After:

model size params backend ngl test t/s
llama 8B Q4_K - Small 4.36 GiB 8.03 B Vulkan 99 tg128 28.89 ± 0.09
llama 8B Q4_0 5.61 GiB 8.03 B Vulkan 99 tg128 34.64 ± 0.04
llama 13B Q4_0 6.86 GiB 13.02 B Vulkan 99 tg128 21.22 ± 0.01
llama 13B Q5_0 7.93 GiB 12.25 B Vulkan 99 tg128 4.74 ± 0.00
llama 13B Q4_0 6.60 GiB 12.25 B Vulkan 99 tg128 22.46 ± 0.02

It got a little faster in most cases, but Intel is still quirky. Besides low prompt processing cause ANV doesn't like my mul_mm shader, it seems to have an issue with q5_0 Matrix Vector Multiplication too. I added q4_0 to make sure the issue is not the model. Maybe you have an idea what's going on. Alignment issue with the quant struct? I've had those in the past.

Here's most of the quants on a 1.1B model on Intel A770. Something's definitely wrong with Q5_0 on ARC.

model size params backend ngl test t/s
llama 1B F16 2.05 GiB 1.10 B Vulkan 99 tg128 86.78 ± 0.20
llama 1B Q4_0 606.53 MiB 1.10 B Vulkan 99 tg128 123.14 ± 0.10
llama 1B Q4_1 668.18 MiB 1.10 B Vulkan 99 tg128 113.38 ± 0.18
llama 1B Q5_0 729.84 MiB 1.10 B Vulkan 99 tg128 45.57 ± 0.03
llama 1B Q5_1 791.50 MiB 1.10 B Vulkan 99 tg128 108.84 ± 0.35
llama 1B Q2_K - Medium 411.41 MiB 1.10 B Vulkan 99 tg128 98.81 ± 0.28
llama 1B Q3_K - Medium 523.67 MiB 1.10 B Vulkan 99 tg128 92.34 ± 0.04
llama 1B Q4_K - Medium 636.18 MiB 1.10 B Vulkan 99 tg128 110.08 ± 0.21
llama 1B Q5_K - Medium 745.11 MiB 1.10 B Vulkan 99 tg128 110.75 ± 0.04
llama 1B Q6_K 860.86 MiB 1.10 B Vulkan 99 tg128 111.68 ± 0.27
llama 1B Q8_0 1.09 GiB 1.10 B Vulkan 99 tg128 96.21 ± 0.74

Edit: But this is unrelated to this PR.

@jeffbolznv
Copy link
Collaborator Author

Just to make sure I understand, the issue with Q5_0 on Intel is not a functional issue, and is a preexisting performance issue? I don't have any experience with those GPUs. I can't think of a reason Q5_0 would be so much worse than Q5_1.

@0cc4m
Copy link
Collaborator

0cc4m commented Nov 15, 2024

Just to make sure I understand, the issue with Q5_0 on Intel is not a functional issue, and is a preexisting performance issue? I don't have any experience with those GPUs. I can't think of a reason Q5_0 would be so much worse than Q5_1.

Yeah, this is a preexisting performance issue. I've found many quirks with Intel GPUs over the time of trying to make them work, and I only got them to a usable state, not to the performance they should be capable of. I don't really know how I can figure out what the cause of this is, I suspect the driver is not mature/optimized enough.

@netrunnereve
Copy link
Collaborator

So I went ahead and tried out subgroup adds with these changes and this time it has a negligible effect compared to #10206 (at least that's the case on the RX 570, I don't have the W8100 with me currently). I'm only seeing a 1% improvement with the code below.

--- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp
@@ -6,6 +6,7 @@
 #extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
 
 #extension GL_EXT_null_initializer : enable
+#extension GL_KHR_shader_subgroup_arithmetic : enable
 
 #include "mul_mat_vec_base.comp"
 
@@ -16,8 +17,6 @@ layout (constant_id = 1) const uint NUM_ROWS = 1;
 
 uint a_offset, b_offset, d_offset, y_offset;
 
-shared FLOAT_TYPE tmpsh[NUM_ROWS][BLOCK_SIZE];
-
 void iter(inout FLOAT_TYPE temp[NUM_ROWS], const uint first_row, const uint num_rows, const uint tid, const uint i, bool lastiter)
 {
     const uint col = i*BLOCK_SIZE + 2*tid;
@@ -79,21 +78,11 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
 
     // sum up partial sums and write back result
     [[unroll]] for (uint n = 0; n < num_rows; ++n) {
-        tmpsh[n][tid] = temp[n];
-    }
-    barrier();
-    [[unroll]] for (uint s = BLOCK_SIZE/2; s > 0; s >>= 1) {
-        if (tid < s) {
-            [[unroll]] for (uint n = 0; n < num_rows; ++n) {
-                tmpsh[n][tid] += tmpsh[n][tid + s];
-            }
-        }
-        barrier();
+        temp[n] = subgroupAdd(temp[n]);
     }
-    if (tid == 0) {
-        [[unroll]] for (uint n = 0; n < num_rows; ++n) {
-            data_d[d_offset + first_row + n] = D_TYPE(tmpsh[n][0]);
-        }
+
+    if (tid < num_rows) {
+        data_d[d_offset + first_row + tid] = D_TYPE(temp[tid]);
     }
 }

@0cc4m
Copy link
Collaborator

0cc4m commented Nov 16, 2024

Out of curiosity I tried @netrunnereve's subgroupAdd change and it gives correct results on Nvidia and AMD, but not on Intel. So the subgroup operations issue still exists on ANV.

@0cc4m 0cc4m merged commit 772703c into ggerganov:master Nov 16, 2024
51 of 53 checks passed
@netrunnereve
Copy link
Collaborator

On the other hand I'm able to get noticeably faster results on the 570 by adjusting the number of rows per workgroup. I wonder if it's worth making these things tweakable at some point like how we have K_QUANTS_PER_ITERATION.

model size params backend ngl threads test t/s improvement from pre merge master
llama 8B Q4_0 (1 row) 4.33 GiB 8.03 B Vulkan 100 8 tg128 9.80 ± 0.05 27 %
llama 8B Q4_0 (2 row default) 4.33 GiB 8.03 B Vulkan 100 8 tg128 11.60 ± 0.07 50%
llama 8B Q4_0 (8 rows) 4.33 GiB 8.03 B Vulkan 100 8 tg128 12.64 ± 0.04 63%
llama 8B Q4_0 (16 rows) 4.33 GiB 8.03 B Vulkan 100 8 tg128 12.49 ± 0.05 61%

The numbers start going downhill once I go past 16 rows on Q4_0. Q8_0 also has some nice improvements as seen below.

model size params backend ngl threads test t/s
llama 8B Q8_0 (1 row default) 7.95 GiB 8.03 B Vulkan 100 8 tg128 6.60 ± 0.02
llama 8B Q8_0 (2 rows) 7.95 GiB 8.03 B Vulkan 100 8 tg128 7.98 ± 0.00
llama 8B Q8_0 (4 rows) 7.95 GiB 8.03 B Vulkan 100 8 tg128 9.38 ± 0.02
llama 8B Q8_0 (8 rows) 7.95 GiB 8.03 B Vulkan 100 8 tg128 9.27 ± 0.02

@jeffbolznv
Copy link
Collaborator Author

Since it's a spec constant, we could select the value in ggml_vk_load_shaders based on the GPU.

I'm working on another change that does 8 consecutive K values per iteration and is able to use larger loads as a result. It would be good to test the number of rows again with that change. Maybe it's getting some of the same gains, or maybe the gains will stack.

arthw pushed a commit to arthw/llama.cpp that referenced this pull request Nov 18, 2024
Compute two result elements per workgroup (for Q{4,5}_{0,1}). This reuses
the B loads across the rows and also reuses some addressing calculations.
This required manually partially unrolling the loop, since the compiler
is less willing to unroll outer loops.

Add bounds-checking on the last iteration of the loop. I think this was at
least partly broken before.

Optimize the Q4_K shader to vectorize most loads and reduce the number of
bit twiddling instructions.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants