[Perf] Mem align KV caches for CUDA devices (MLA perf improvement) #12676

Merged
11 commits merged into vllm-project:main on Feb 5, 2025

Conversation

Collaborator

@LucasWilkinson LucasWilkinson commented Feb 3, 2025

Generally, Nvidia hardware likes 256-byte alignment (the reasons are foggy due to the blackbox nature of Nvidia hardware), and memory allocated via the CUDA Runtime is guaranteed to be 256-byte aligned (see https://docs.nvidia.com/cuda/cuda-c-best-practices-guide/#a-sequential-but-misaligned-access-pattern).

This PR aligns KV cache entries so they start on 256-byte boundaries. This mainly targets MLA, since for "normal attention" with typical head dims (say 64 or 128) the entries are already naturally 256-byte aligned.

This means that for MLA with a head dim of 576 (like DeepSeek V2/V3) and an fp16/bf16 cache, we allocate 640 elements per cache entry instead of 576 (1280 bytes instead of 1152). This increases the size of the cache by ~11% (wasted space), but leads to a worthwhile performance gain.
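For reference, the padding is just rounding the per-entry element count up to the next 256-byte multiple for the cache dtype. Below is a minimal sketch of that arithmetic; the helper name mirrors the align_to_256bytes call in the diff discussed later, but the body here is illustrative rather than the exact implementation.

```python
import torch

def align_to_256bytes(extent: int, dtype: torch.dtype) -> int:
    # Round the per-entry element count up so each entry starts on a
    # 256-byte boundary for the given dtype.
    elem_size = torch.tensor([], dtype=dtype).element_size()
    align_elems = 256 // elem_size
    return ((extent + align_elems - 1) // align_elems) * align_elems

# DeepSeek V2/V3 MLA entry: 576 bf16 elements (1152 bytes) -> 640 elements (1280 bytes)
assert align_to_256bytes(576, torch.bfloat16) == 640
```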

Results DeepSeek-R1 on 8xH200

VLLM_CUDA_MEM_ALIGN_KV_CACHE=0  python3 benchmarks/benchmark_throughput.py --model /data/nm/models/DeepSeek-R1 --trust-remote-code --tensor-parallel-size 8 --max-model-len 8000 --enable-chunked-prefill False --input-len 2000 --output-len 1000  --num-prompts 100
...
Throughput: 0.76 requests/s, 2289.10 total tokens/s, 763.03 output tokens/s
VLLM_CUDA_MEM_ALIGN_KV_CACHE=1  python3 benchmarks/benchmark_throughput.py --model /data/nm/models/DeepSeek-R1 --trust-remote-code --tensor-parallel-size 8 --max-model-len 8000 --enable-chunked-prefill False --input-len 2000 --output-len 1000  --num-prompts 100
...
Throughput: 1.10 requests/s, 3287.09 total tokens/s, 1095.70 output tokens/s

Accuracy:

VLLM_MLA_DISABLE=1 lm_eval --model vllm --model_args pretrained=/data/nm/models/DeepSeek-R1,tensor_parallel_size=8,dtype=auto,gpu_memory_utilization=0.9,trust_remote_code=True,max_model_len=16384,enforce_eager=False --task gsm8k --num_fewshot=5 --limit 100
...
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  | 0.94|±  |0.0239|
|     |       |strict-match    |     5|exact_match|↑  | 0.94|±  |0.0239|


VLLM_MLA_DISABLE=0 VLLM_CUDA_MEM_ALIGN_KV_CACHE=0 lm_eval --model vllm --model_args pretrained=/data/nm/models/DeepSeek-R1,tensor_parallel_size=8,dtype=auto,gpu_memory_utilization=0.9,trust_remote_code=True,max_model_len=16384,enforce_eager=False --task gsm8k --num_fewshot=5 --limit 100
...
INFO 02-03 14:26:12 executor_base.py:110] # CUDA blocks: 30218, # CPU blocks: 3819
INFO 02-03 14:26:12 executor_base.py:115] Maximum concurrency for 16384 tokens per request: 29.51x
...
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  | 0.94|±  |0.0239|
|     |       |strict-match    |     5|exact_match|↑  | 0.94|±  |0.0239|


VLLM_MLA_DISABLE=0 VLLM_CUDA_MEM_ALIGN_KV_CACHE=1 lm_eval --model vllm --model_args pretrained=/data/nm/models/DeepSeek-R1,tensor_parallel_size=8,dtype=auto,gpu_memory_utilization=0.9,trust_remote_code=True,max_model_len=16384,enforce_eager=False --task gsm8k --num_fewshot=5 --limit 100
...
INFO 02-03 14:33:20 executor_base.py:110] # CUDA blocks: 27196, # CPU blocks: 3437
INFO 02-03 14:33:20 executor_base.py:115] Maximum concurrency for 16384 tokens per request: 26.56x
...
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  | 0.94|±  |0.0239|
|     |       |strict-match    |     5|exact_match|↑  | 0.94|±  |0.0239|


github-actions bot commented Feb 3, 2025

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which covers a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add ready label to the PR
  • Enable auto-merge.

🚀


mergify bot commented Feb 3, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @LucasWilkinson.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Feb 3, 2025
@robertgshaw2-redhat robertgshaw2-redhat force-pushed the lwilkinson/mem-align-kv-cache branch from 59ab887 to 4d3d413 Compare February 3, 2025 15:59
Signed-off-by: simon-mo <xmo@berkeley.edu>
Signed-off-by: Lucas Wilkinson <lcwilkins@redhat.com>
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
@robertgshaw2-redhat robertgshaw2-redhat force-pushed the lwilkinson/mem-align-kv-cache branch from 4d3d413 to bd75f96 Compare February 3, 2025 15:59
@mergify mergify bot removed the needs-rebase label Feb 3, 2025
Collaborator

@tlrmchlsmth tlrmchlsmth left a comment

Great find! Makes a lot of sense!

Comment on lines 88 to 106
+if current_platform.is_cuda() and envs.VLLM_CUDA_MEM_ALIGN_KV_CACHE:
+    alloc_entry_size = align_to_256bytes(entry_size, self.dtype)
+else:
+    alloc_entry_size = entry_size
+alloc_shape = (*kv_cache_shape[:2], alloc_entry_size)
+
 for _ in range(self.num_attention_layers):
     # null block in CpuGpuBlockAllocator requires at least that
     # block to be zeroed-out.
     # We zero-out everything for simplicity.
-    kv_cache.append(
-        torch.zeros(kv_cache_shape,
-                    dtype=self.dtype,
-                    pin_memory=pin_memory,
-                    device=device))
+    layer_kv_cache = torch.zeros(alloc_shape,
+                                 dtype=self.dtype,
+                                 pin_memory=pin_memory,
+                                 device=device)
+
+    if alloc_entry_size != entry_size:
+        layer_kv_cache = layer_kv_cache[..., :entry_size]
+
+    kv_cache.append(layer_kv_cache.view(kv_cache_shape))
Collaborator

The implementation looks good to me. A couple of comments noting what the padding and views are doing would be nice to make it a little easier to follow (as well as noting that this is kind of a special case for MLA).
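For readers following the diff above, here is a self-contained sketch of what the padded allocation plus narrowing view accomplishes. Names are illustrative and the entry indexing is simplified to the trailing dimension; this is not the exact CacheEngine code.

```python
import torch

def allocate_aligned_layer_cache(kv_cache_shape: tuple,
                                 dtype: torch.dtype,
                                 device: str = "cpu",
                                 align: bool = True) -> torch.Tensor:
    # Treat the trailing dimension as the per-token cache entry
    # (e.g. 576 for DeepSeek MLA) and pad it to a 256-byte multiple.
    entry_size = kv_cache_shape[-1]
    if align:
        elem_size = torch.tensor([], dtype=dtype).element_size()
        align_elems = 256 // elem_size
        alloc_entry_size = -(-entry_size // align_elems) * align_elems
    else:
        alloc_entry_size = entry_size

    # Allocate with the padded last dimension, zeroed for the null block.
    alloc_shape = (*kv_cache_shape[:-1], alloc_entry_size)
    layer_kv_cache = torch.zeros(alloc_shape, dtype=dtype, device=device)

    # Narrow back to the logical entry size. The slice keeps the padded
    # strides, so every entry still begins on an aligned offset while the
    # padding elements are simply never touched.
    if alloc_entry_size != entry_size:
        layer_kv_cache = layer_kv_cache[..., :entry_size]

    # View with the shape the attention backend expects.
    return layer_kv_cache.view(kv_cache_shape)

# Triton MLA layout (num_blocks, block_size, head_size):
cache = allocate_aligned_layer_cache((1024, 16, 576), torch.bfloat16)
print(cache.shape, cache.stride())  # torch.Size([1024, 16, 576]) (10240, 640, 1)
```

The key point is that the slice keeps the padded strides, so the logical 576-element entries remain aligned while the extra 64 elements per entry are never read or written.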

Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
vllm/envs.py Outdated
@@ -539,6 +540,15 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]:
"VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON":
lambda: bool(int(os.getenv("VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON", "0"))
),

# When on a Nvidia GPU aligns single entrys (within a page) so they are 256
# byte aligned for better performance, this increases the memory usage of
Member

Suggested change
# byte aligned for better performance, this increases the memory usage of
# byte aligned for better performance, this increases the memory usage of

@@ -75,15 +80,30 @@ def _allocate_kv_cache(
        num_blocks, self.block_size, self.num_kv_heads, self.head_size)
    pin_memory = is_pin_memory_available() if device == "cpu" else False
    kv_cache: List[torch.Tensor] = []

    entry_shape = kv_cache_shape[2:]
Member

We should assert that the number of dimensions is what we expect, and/or possibly index from the end to deal with the different shapes.

For instance:
Flash attention has 5 dims

return (2, num_blocks, block_size, num_kv_heads, head_size)

Pallas attention has 4 dims

return (num_kv_heads, num_blocks, block_size, head_size)

Triton MLA has 3 dims

return (num_blocks, block_size, head_size)
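One hedged way to address the rank concern (an illustrative sketch, not the code in this PR): assert the rank is one of the expected layouts and index the head dimension from the end, so the same padding logic covers all three backends.

```python
from typing import Tuple

import torch

def padded_kv_cache_shape(kv_cache_shape: Tuple[int, ...],
                          dtype: torch.dtype) -> Tuple[int, ...]:
    # Hypothetical sketch of the reviewer's suggestion: check that the
    # rank is one we recognize, then reverse-index so the padded extent
    # is always the trailing head_size dimension.
    assert len(kv_cache_shape) in (3, 4, 5), \
        f"unexpected KV cache rank: {len(kv_cache_shape)}"
    elem_size = torch.tensor([], dtype=dtype).element_size()
    align_elems = 256 // elem_size
    padded = -(-kv_cache_shape[-1] // align_elems) * align_elems
    return (*kv_cache_shape[:-1], padded)

# Triton MLA layout (num_blocks, block_size, head_size):
print(padded_kv_cache_shape((1024, 16, 576), torch.bfloat16))      # (1024, 16, 640)
# FlashAttention layout (2, num_blocks, block_size, num_kv_heads, head_size):
print(padded_kv_cache_shape((2, 1024, 16, 8, 128), torch.float16))  # last dim stays 128
```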

Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
Collaborator Author

LucasWilkinson commented Feb 3, 2025

Sorry! Hold on, there may be accuracy issues.

Edit: Accuracy issues resolved

Comment on lines +550 to +551
"VLLM_CUDA_MEM_ALIGN_KV_CACHE":
lambda: bool(int(os.getenv("VLLM_CUDA_MEM_ALIGN_KV_CACHE", "1"))),
Collaborator

@simon-mo simon-mo Feb 3, 2025

Do we need to flag this? I think we can just default to this behavior without switching back.

Member

I agree with the concern about reducing the cache space by ~11%, although maybe we consider this change necessary enough for performance to remove the choice, like you say.

> This means that for MLA with a head dim of 576 (like DeepSeek V2/V3) and an fp16/bf16 cache, we allocate 640 elements per cache entry instead of 576 (1280 bytes instead of 1152). This increases the size of the cache by ~11% (wasted space), but leads to a worthwhile performance gain.

Collaborator Author

@LucasWilkinson LucasWilkinson Feb 3, 2025

Given that it can increase the size of the KV cache, I wanted it on by default but with a flag to turn it off in case a user really wants to maximize KV-cache size.

> I think we can just turn this on by default?

It is on by default already (the default value is "1").

Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
@simon-mo simon-mo mentioned this pull request Feb 3, 2025
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
yessenzhar pushed a commit to deepinfra/vllm that referenced this pull request Feb 3, 2025
Signed-off-by: simon-mo <xmo@berkeley.edu>
Comment on lines +1040 to +1042
def copy_blocks_mla(kv_caches: List[torch.Tensor],
                    block_mapping: torch.Tensor) -> None:
    torch.ops._C_cache_ops.copy_blocks_mla(kv_caches, block_mapping)
Collaborator

Why didn't we need this kernel before?

Collaborator Author

I think we did ...; I think this may fix some bugs (TBH I'm not sure how copy_blocks is used by the wider system).

Collaborator

@tlrmchlsmth tlrmchlsmth left a comment

Great work -- LGTM!

@tlrmchlsmth tlrmchlsmth added the ready ONLY add when PR is ready to merge/full CI is needed label Feb 4, 2025
@tlrmchlsmth tlrmchlsmth enabled auto-merge (squash) February 4, 2025 14:04
@simon-mo simon-mo disabled auto-merge February 5, 2025 02:22
@simon-mo simon-mo merged commit 75e9430 into vllm-project:main Feb 5, 2025
71 of 73 checks passed
fxmarty-amd pushed a commit to fxmarty-amd/vllm that referenced this pull request Feb 7, 2025
…llm-project#12676)

Signed-off-by: simon-mo <xmo@berkeley.edu>
Signed-off-by: Lucas Wilkinson <lcwilkins@redhat.com>
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
Co-authored-by: simon-mo <xmo@berkeley.edu>
Signed-off-by: Felix Marty <felmarty@amd.com>
ShangmingCai pushed a commit to ShangmingCai/vllm that referenced this pull request Feb 10, 2025
…llm-project#12676)

Signed-off-by: simon-mo <xmo@berkeley.edu>
Signed-off-by: Lucas Wilkinson <lcwilkins@redhat.com>
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
Co-authored-by: simon-mo <xmo@berkeley.edu>

leepoly commented Feb 12, 2025

Hi, I'm using an 8xH200 setup and I'm unable to reproduce the results from benchmarks/benchmark_throughput.py using the latest 0.7.2 version.

My results:
With VLLM_CUDA_MEM_ALIGN_KV_CACHE=0:

  • CUDA blocks: 34,497
  • Throughput: 1.06 requests/s, 3,170.77 total tokens/s, 1,056.92 output tokens/s

With VLLM_CUDA_MEM_ALIGN_KV_CACHE=1:

  • CUDA blocks: 31,047
  • Throughput: 1.06 requests/s, 3,185.99 total tokens/s, 1,062.00 output tokens/s

The throughput seems nearly identical in both cases. Could you suggest potential causes for this discrepancy?

Collaborator Author

LucasWilkinson commented Feb 12, 2025

@leepoly What model is this? This only affects MLA (i.e. DeepSeek V2/3).

Edit: nvm, I assume you are using R1 since the numbers look very comparable; I'll try to re-run the numbers tomorrow to see if there is something weird going on.


leepoly commented Feb 12, 2025

> @leepoly What model is this? This only affects MLA (i.e. DeepSeek V2/3).
>
> Edit: nvm, I assume you are using R1 since the numbers look very comparable; I'll try to re-run the numbers tomorrow to see if there is something weird going on.

Yes, I use the DeepSeek V3 model, and I simply used the script you provided: VLLM_CUDA_MEM_ALIGN_KV_CACHE=0 python3 benchmarks/benchmark_throughput.py --model /data/nm/models/DeepSeek-R1 --trust-remote-code --tensor-parallel-size 8 --max-model-len 8000 --enable-chunked-prefill False --input-len 2000 --output-len 1000 --num-prompts 100

Even with VLLM_CUDA_MEM_ALIGN_KV_CACHE=0, the reported throughput (1.06 req/s) already roughly matches your results with 256-byte alignment enabled (1.10 req/s).

pathorn pushed a commit to deepinfra/vllm that referenced this pull request Feb 14, 2025
…llm-project#12676)

Signed-off-by: simon-mo <xmo@berkeley.edu>
Signed-off-by: Lucas Wilkinson <lcwilkins@redhat.com>
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
Co-authored-by: simon-mo <xmo@berkeley.edu>
panf2333 pushed a commit to yottalabsai/vllm that referenced this pull request Feb 18, 2025
…llm-project#12676)

Signed-off-by: simon-mo <xmo@berkeley.edu>
Signed-off-by: Lucas Wilkinson <lcwilkins@redhat.com>
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
Co-authored-by: simon-mo <xmo@berkeley.edu>
pathorn pushed a commit to deepinfra/vllm that referenced this pull request Feb 19, 2025
…llm-project#12676)

Signed-off-by: simon-mo <xmo@berkeley.edu>
Signed-off-by: Lucas Wilkinson <lcwilkins@redhat.com>
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
Co-authored-by: simon-mo <xmo@berkeley.edu>
kerthcet pushed a commit to kerthcet/vllm that referenced this pull request Feb 21, 2025
…llm-project#12676)

Signed-off-by: simon-mo <xmo@berkeley.edu>
Signed-off-by: Lucas Wilkinson <lcwilkins@redhat.com>
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
Co-authored-by: simon-mo <xmo@berkeley.edu>
Labels
ready ONLY add when PR is ready to merge/full CI is needed