
Add some minimal optimizations for CDNA #10498

Merged
merged 2 commits
Nov 27, 2024

Conversation

@IMbackK (Contributor) commented Nov 25, 2024

This PR adds some minimal optimizations for CDNA.

Mainly, the MMQ kernels perform extremely poorly on CDNA. One reason is that the compiler runs out of architectural VGPRs and spills into Acc VGPRs, which is terrible for performance. MMQ also doesn't make use of MFMA, while rocBLAS can.

This PR therefore makes ggml_cuda_should_use_mmq return false more often for CDNA (and for Vega20, as rocBLAS is faster there too).
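A minimal sketch of the idea (simplified; not the exact diff in this PR; the existing per-type and batch-size heuristics are omitted, and constant names follow common.cuh):

static bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
    GGML_UNUSED(type);
    GGML_UNUSED(ne11);
    // On CDNA the MMQ kernels spill into Acc VGPRs and cannot use MFMA,
    // so dequantize + rocBLAS is the faster path there (and on Vega20).
    if (cc == CC_CDNA || cc == CC_VEGA20) {
        return false;
    }
    // ... the existing heuristics for the other architectures continue here ...
    return true;
}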

To allow rocBLAS to use MFMA, this PR also sets the compute type to 32 bit. CDNA cannot do 16-bit accumulation with MFMA, and rocBLAS does not give you higher precision than you asked for, even if that would result in better performance, so we need to set ROCBLAS_COMPUTE_32F on CDNA so that MFMA gets used.

This PR improves prompt processing by about 2x almost uniformly across a variety of batch and model sizes. rocprof2 still shows the MMQ kernels taking >95% of the wall time, and there would still be a huge amount to gain before performance on CDNA/GCN is actually decent.

  • Self-reported review complexity:
    • Low
    • Medium
    • High

@github-actions bot added the Nvidia GPU label (Issues specific to Nvidia GPUs) Nov 25, 2024
@JohannesGaessler self-requested a review November 25, 2024 18:20
@JohannesGaessler (Collaborator) left a comment

Thank you. See my comments about directly writing back 32 bit floats (which I assume would be possible). This could be done in a separate PR though so if you want we can just merge this as-is.

(I assume you've already tried alternative values for mmq y size and max x size.)

ggml/src/ggml-cuda/common.cuh (review thread outdated, resolved)
Comment on lines 1110 to 1112
cublasComputeType_t cu_compute_type = CUBLAS_COMPUTE_16F;
if(ggml_cuda_info().devices[ctx.device].cc == CC_CDNA)
cu_compute_type = CUBLAS_COMPUTE_32F;
Collaborator:

If the computation is done as 32 bit floats anyways you should be able to get a bit more performance by writing back the results as 32 bit directly instead of writing back as 16 bit and then converting to 32 bit.
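For illustration, the suggestion amounts to something like the following cublasGemmEx call (variable names here are placeholders, not the actual ggml-cuda code): with CUBLAS_COMPUTE_32F the destination can be requested as FP32 directly, skipping the FP16 write-back and the separate conversion kernel.

// alpha/beta must be FP32 when the compute type is CUBLAS_COMPUTE_32F.
const float alpha = 1.0f;
const float beta  = 0.0f;
// A and B stay FP16, but the result C is written out as FP32 in one call.
cublasGemmEx(handle, CUBLAS_OP_T, CUBLAS_OP_N,
             m, n, k,
             &alpha,
             src0_f16, CUDA_R_16F, lda,
             src1_f16, CUDA_R_16F, ldb,
             &beta,
             dst_f32,  CUDA_R_32F, ldc,
             CUBLAS_COMPUTE_32F, CUBLAS_GEMM_DEFAULT);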

Collaborator:

Suggested change
-cublasComputeType_t cu_compute_type = CUBLAS_COMPUTE_16F;
-if(ggml_cuda_info().devices[ctx.device].cc == CC_CDNA)
-cu_compute_type = CUBLAS_COMPUTE_32F;
+cublasComputeType_t cu_compute_type = CUBLAS_COMPUTE_16F;
+if (ggml_cuda_info().devices[ctx.device].cc == CC_CDNA) {
+    cu_compute_type = CUBLAS_COMPUTE_32F;
+}

@IMbackK (Contributor, Author) Nov 26, 2024

If the computation is done as 32 bit floats anyways you should be able to get a bit more performance by writing back the results as 32 bit directly instead of writing back as 16 bit and then converting to 32 bit.

I hacked this in and found no measurable difference. This is no surprise, as the time spent in the kernels of ggml_cuda_op_mul_mat_cublas and ggml_cuda_mul_mat_batched_cublas is minuscule compared to the time spent in other places, so it's not worth the effort for now.

ggml/src/ggml-cuda/ggml-cuda.cu (review thread outdated, resolved)
@github-actions bot added the ggml label (changes relating to the ggml tensor library for machine learning) Nov 26, 2024
@IMbackK (Contributor, Author) commented Nov 26, 2024

I fixed the nits and also completed the defines for the various ROCm arches, so that we have values for all major generations supported by rocBLAS in case someone wants to do some optimization work in the future.
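Roughly what the completed set looks like (values here are illustrative, keyed to the minimum gfx target of each generation; see common.cuh for the actual defines):

#define CC_OFFSET_AMD 1000000
#define CC_VEGA   (CC_OFFSET_AMD + 900)  // Vega 10 (gfx900), e.g. MI25
#define CC_VEGA20 (CC_OFFSET_AMD + 906)  // Vega 20 (gfx906), e.g. MI50 / Radeon VII
#define CC_CDNA   (CC_OFFSET_AMD + 908)  // CDNA 1 (gfx908), e.g. MI100
#define CC_CDNA2  (CC_OFFSET_AMD + 910)  // CDNA 2 (gfx90a), e.g. MI210 / MI250
#define CC_CDNA3  (CC_OFFSET_AMD + 942)  // CDNA 3 (gfx942), e.g. MI300
#define CC_RDNA1  (CC_OFFSET_AMD + 1010) // RDNA 1 (gfx1010)
#define CC_RDNA2  (CC_OFFSET_AMD + 1030) // RDNA 2 (gfx1030)
#define CC_RDNA3  (CC_OFFSET_AMD + 1100) // RDNA 3 (gfx1100)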

@IMbackK (Contributor, Author) commented Nov 26, 2024

(I assume you've already tried alternative values for mmq y size and max x size.)

Yes, I was not able to gain any major performance by doing that, or to come anywhere close to rocBLAS.
I have not looked at the kernels in detail, but besides the register pressure limiting occupancy, another thing that is immediately obvious is that quite a few kernels dispatch groups of work items that are n*WARP_SIZE wide, while WARP_SIZE is simply defined as 32.

Of course, the warp size of GCN and CDNA is 64. I'm somewhat surprised it works at all, given that things like layer norm can easily be sensitive to warp size.
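For illustration only (not something this PR changes), a wavefront-aware definition could look roughly like this; the gfx target macros are the HIP compiler's predefined ones, and the list would need to cover every GCN/CDNA target:

#if defined(__gfx900__) || defined(__gfx906__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx942__)
#define WARP_SIZE 64 // GCN/CDNA wavefronts are 64 lanes wide
#else
#define WARP_SIZE 32 // NVIDIA warps and RDNA wave32
#endif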

@IMbackK (Contributor, Author) commented Nov 26, 2024

Some quick numbers:

This PR:

  Device 0: AMD Instinct MI100, compute capability 9.0, VMM: no
| model                          |       size |     params | backend    | ngl | n_batch |          test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | ------: | ------------: | -------------------: |
| llama 8B Q4_K - Medium         |   4.58 GiB |     8.03 B | ROCm       |  99 |      64 |          pp64 |        434.23 ± 0.43 |
| llama 8B Q4_K - Medium         |   4.58 GiB |     8.03 B | ROCm       |  99 |      64 |         pp128 |        433.79 ± 0.39 |
| llama 8B Q4_K - Medium         |   4.58 GiB |     8.03 B | ROCm       |  99 |      64 |         pp512 |        430.21 ± 0.28 |
| llama 8B Q4_K - Medium         |   4.58 GiB |     8.03 B | ROCm       |  99 |     512 |          pp64 |        433.94 ± 0.20 |
| llama 8B Q4_K - Medium         |   4.58 GiB |     8.03 B | ROCm       |  99 |     512 |         pp128 |        785.42 ± 0.49 |
| llama 8B Q4_K - Medium         |   4.58 GiB |     8.03 B | ROCm       |  99 |     512 |         pp512 |       1108.48 ± 0.63 |
| llama 8B Q4_K - Medium         |   4.58 GiB |     8.03 B | ROCm       |  99 |    1024 |          pp64 |        433.87 ± 0.41 |
| llama 8B Q4_K - Medium         |   4.58 GiB |     8.03 B | ROCm       |  99 |    1024 |         pp128 |        783.96 ± 1.65 |
| llama 8B Q4_K - Medium         |   4.58 GiB |     8.03 B | ROCm       |  99 |    1024 |         pp512 |       1106.24 ± 0.35 |
| llama 8B Q4_K - Medium         |   4.58 GiB |     8.03 B | ROCm       |  99 |    2048 |          pp64 |        432.29 ± 0.85 |
| llama 8B Q4_K - Medium         |   4.58 GiB |     8.03 B | ROCm       |  99 |    2048 |         pp128 |        783.59 ± 0.92 |
| llama 8B Q4_K - Medium         |   4.58 GiB |     8.03 B | ROCm       |  99 |    2048 |         pp512 |       1104.26 ± 0.70 |

Master:

  Device 0: AMD Instinct MI100, compute capability 9.0, VMM: no
| model                          |       size |     params | backend    | ngl | n_batch |          test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | ------: | ------------: | -------------------: |
| llama 8B Q4_K - Medium         |   4.58 GiB |     8.03 B | ROCm       |  99 |      64 |          pp64 |        209.77 ± 0.08 |
| llama 8B Q4_K - Medium         |   4.58 GiB |     8.03 B | ROCm       |  99 |      64 |         pp128 |        209.63 ± 0.08 |
| llama 8B Q4_K - Medium         |   4.58 GiB |     8.03 B | ROCm       |  99 |      64 |         pp512 |        208.54 ± 0.12 |
| llama 8B Q4_K - Medium         |   4.58 GiB |     8.03 B | ROCm       |  99 |     512 |          pp64 |        209.35 ± 0.11 |
| llama 8B Q4_K - Medium         |   4.58 GiB |     8.03 B | ROCm       |  99 |     512 |         pp128 |        357.08 ± 0.39 |
| llama 8B Q4_K - Medium         |   4.58 GiB |     8.03 B | ROCm       |  99 |     512 |         pp512 |        466.96 ± 0.64 |
| llama 8B Q4_K - Medium         |   4.58 GiB |     8.03 B | ROCm       |  99 |    1024 |          pp64 |        208.90 ± 0.13 |
| llama 8B Q4_K - Medium         |   4.58 GiB |     8.03 B | ROCm       |  99 |    1024 |         pp128 |        355.16 ± 0.25 |
| llama 8B Q4_K - Medium         |   4.58 GiB |     8.03 B | ROCm       |  99 |    1024 |         pp512 |        463.69 ± 0.37 |
| llama 8B Q4_K - Medium         |   4.58 GiB |     8.03 B | ROCm       |  99 |    2048 |          pp64 |        208.57 ± 0.13 |
| llama 8B Q4_K - Medium         |   4.58 GiB |     8.03 B | ROCm       |  99 |    2048 |         pp128 |        353.69 ± 0.23 |
| llama 8B Q4_K - Medium         |   4.58 GiB |     8.03 B | ROCm       |  99 |    2048 |         pp512 |        461.89 ± 0.35 |

@8XXD8 commented Nov 27, 2024

If I add gfx900 and gfx906 to the CDNA define in hip.h instead of the GCN one, they get a big increase in prompt processing too.
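Roughly what that change looks like (a sketch; the actual guard in hip.h may differ), routing gfx900/gfx906 through the CDNA path as well:

#if defined(__gfx900__) || defined(__gfx906__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx942__)
#define CDNA // treat Vega 10/20 like CDNA so the rocBLAS / FP32 compute-type paths are taken
#endif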

Master:

  Device 0: AMD Radeon Pro VII, compute capability 9.0, VMM: no
| model                          |       size |     params | backend    | ngl |          test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | ------------: | -------------------: |
| llama 8B Q8_0                  |   7.95 GiB |     8.03 B | ROCm       |  99 |         pp512 |        217.53 ± 0.73 |

  Device 0: AMD Radeon Instinct MI25, compute capability 9.0, VMM: no
| model                          |       size |     params | backend    | ngl |          test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | ------------: | -------------------: |
| llama 8B Q8_0                  |   7.95 GiB |     8.03 B | ROCm       |  99 |         pp512 |        153.69 ± 0.38 |

  Device 0: AMD Radeon Pro VII, compute capability 9.0, VMM: no
  Device 1: AMD Radeon Pro VII, compute capability 9.0, VMM: no
  Device 2: AMD Radeon Pro VII, compute capability 9.0, VMM: no
  Device 3: AMD Radeon Instinct MI25, compute capability 9.0, VMM: no
  Device 4: AMD Radeon Instinct MI25, compute capability 9.0, VMM: no
| model                          |       size |     params | backend    | ngl |    sm |          test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | ----: | ------------: | -------------------: |
| llama ?B Q4_K - Medium         |  68.19 GiB |   122.61 B | ROCm       |  99 |   row |         pp512 |          8.08 ± 0.01 |

  Device 0: AMD Radeon Pro VII, compute capability 9.0, VMM: no
  Device 1: AMD Radeon Pro VII, compute capability 9.0, VMM: no
  Device 2: AMD Radeon Pro VII, compute capability 9.0, VMM: no
  Device 3: AMD Radeon Instinct MI25, compute capability 9.0, VMM: no
  Device 4: AMD Radeon Instinct MI25, compute capability 9.0, VMM: no
| model                          |       size |     params | backend    | ngl |    sm |          test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | ----: | ------------: | -------------------: |
| llama 70B Q6_K                 |  53.91 GiB |    70.55 B | ROCm       |  99 |   row |         pp512 |         14.24 ± 0.03 |

build: c9b00a70 (4191)

PR:

  Device 0: AMD Radeon Pro VII, compute capability 9.0, VMM: no
| model                          |       size |     params | backend    | ngl |          test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | ------------: | -------------------: |
| llama 8B Q8_0                  |   7.95 GiB |     8.03 B | ROCm       |  99 |         pp512 |        510.84 ± 4.09 |

  Device 0: AMD Radeon Instinct MI25, compute capability 9.0, VMM: no
| model                          |       size |     params | backend    | ngl |          test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | ------------: | -------------------: |
| llama 8B Q8_0                  |   7.95 GiB |     8.03 B | ROCm       |  99 |         pp512 |        242.81 ± 0.65 |

  Device 0: AMD Radeon Pro VII, compute capability 9.0, VMM: no
  Device 1: AMD Radeon Pro VII, compute capability 9.0, VMM: no
  Device 2: AMD Radeon Pro VII, compute capability 9.0, VMM: no
  Device 3: AMD Radeon Instinct MI25, compute capability 9.0, VMM: no
  Device 4: AMD Radeon Instinct MI25, compute capability 9.0, VMM: no
| model                          |       size |     params | backend    | ngl |    sm |          test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | ----: | ------------: | -------------------: |
| llama ?B Q4_K - Medium         |  68.19 GiB |   122.61 B | ROCm       |  99 |   row |         pp512 |         24.41 ± 0.04 |

  Device 0: AMD Radeon Pro VII, compute capability 9.0, VMM: no
  Device 1: AMD Radeon Pro VII, compute capability 9.0, VMM: no
  Device 2: AMD Radeon Pro VII, compute capability 9.0, VMM: no
  Device 3: AMD Radeon Instinct MI25, compute capability 9.0, VMM: no
  Device 4: AMD Radeon Instinct MI25, compute capability 9.0, VMM: no
| model                          |       size |     params | backend    | ngl |    sm |          test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | ----: | ------------: | -------------------: |
| llama 70B Q6_K                 |  53.91 GiB |    70.55 B | ROCm       |  99 |   row |         pp512 |         35.49 ± 0.11 |

build: 003b9f7b (4166)

@JohannesGaessler (Collaborator)

I will merge this soon unless you also want to address the comment by @8XXD8 .

@IMbackK (Contributor, Author) commented Nov 27, 2024

Added the equivalent of what @8XXD8 did, to also speed up GCN.

@JohannesGaessler merged commit 3ad5451 into ggerganov:master Nov 27, 2024
49 of 50 checks passed
arthw pushed a commit to arthw/llama.cpp that referenced this pull request Dec 20, 2024
* Add some minimal optimizations for CDNA

* ggml_cuda: set launch bounds also for GCN as it helps there too
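For reference, a launch-bounds annotation in CUDA/HIP looks like the sketch below (kernel name and bound values are placeholders, not the ones this commit touches). __launch_bounds__(max threads per block, minimum occupancy hint) lets the compiler limit per-thread register usage so the requested occupancy can actually be reached, which helps on GCN/CDNA as well.

__global__ void __launch_bounds__(256, 1) scale_f32_sketch(const float * x, float * dst, const float s, const int n) {
    const int i = blockDim.x * blockIdx.x + threadIdx.x;
    if (i >= n) {
        return;
    }
    dst[i] = s * x[i];
}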