[Perf] Mem align KV caches for CUDA devices (MLA perf improvement) #12676
Conversation
👋 Hi! Thank you for contributing to the vLLM project. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging. 🚀

This pull request has merge conflicts that must be resolved before it can be merged.
Force-pushed from 59ab887 to 4d3d413
Signed-off-by: simon-mo <xmo@berkeley.edu> Signed-off-by: Lucas Wilkinson <lcwilkins@redhat.com> Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Force-pushed from 4d3d413 to bd75f96
Great find! Makes a lot of sense!
vllm/worker/cache_engine.py
Outdated
if current_platform.is_cuda() and envs.VLLM_CUDA_MEM_ALIGN_KV_CACHE:
    alloc_entry_size = align_to_256bytes(entry_size, self.dtype)
else:
    alloc_entry_size = entry_size
alloc_shape = (*kv_cache_shape[:2], alloc_entry_size)

for _ in range(self.num_attention_layers):
    # null block in CpuGpuBlockAllocator requires at least that
    # block to be zeroed-out.
    # We zero-out everything for simplicity.
-   kv_cache.append(
-       torch.zeros(kv_cache_shape,
-                   dtype=self.dtype,
-                   pin_memory=pin_memory,
-                   device=device))
+   layer_kv_cache = torch.zeros(alloc_shape,
+                                dtype=self.dtype,
+                                pin_memory=pin_memory,
+                                device=device)
+
+   if alloc_entry_size != entry_size:
+       layer_kv_cache = layer_kv_cache[..., :entry_size]
+
+   kv_cache.append(layer_kv_cache.view(kv_cache_shape))
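For readers following along, here is a minimal sketch of what an alignment helper like align_to_256bytes could look like; the actual helper in this PR may differ, and the sketch assumes entry_size counts elements of dtype rather than bytes:

    import torch

    def align_to_256bytes(entry_size: int, dtype: torch.dtype) -> int:
        # Round a per-entry element count up so each entry starts on a
        # 256-byte boundary. entry_size is in elements of `dtype`, not bytes.
        elem_bytes = torch.tensor([], dtype=dtype).element_size()
        entry_bytes = entry_size * elem_bytes
        aligned_bytes = ((entry_bytes + 255) // 256) * 256
        return aligned_bytes // elem_bytes

    # e.g. for MLA with bf16: align_to_256bytes(576, torch.bfloat16) -> 640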
The implementation looks good to me. A couple of comments noting what the padding and views are doing would be nice to make it a little easier to follow (as well as noting that this is kind of a special case for MLA).
vllm/envs.py
Outdated
@@ -539,6 +540,15 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]:
     "VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON":
     lambda: bool(int(os.getenv("VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON", "0"))
     ),
+
+    # When on an Nvidia GPU, aligns single entries (within a page) so they are 256
+    # byte aligned for better performance; this increases the memory usage of
vllm/worker/cache_engine.py
Outdated
@@ -75,15 +80,30 @@ def _allocate_kv_cache(
         num_blocks, self.block_size, self.num_kv_heads, self.head_size)
     pin_memory = is_pin_memory_available() if device == "cpu" else False
     kv_cache: List[torch.Tensor] = []
+
+    entry_shape = kv_cache_shape[2:]
We should assert that the number of dimensions is what we expect, and/or possibly reverse-index to deal with the different shapes. For instance:

Flash attention has 5 dims:
    return (2, num_blocks, block_size, num_kv_heads, head_size)

Pallas attention has 4 dims (vllm/attention/backends/pallas.py, line 40 in a1a2aaa):
    return (num_kv_heads, num_blocks, block_size, head_size)

Triton MLA has 3 dims:
    return (num_blocks, block_size, head_size)
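To make the suggestion concrete, a hypothetical reverse-indexing variant (illustrative only, not code from this PR; it reuses the align_to_256bytes sketch above) could assert a sane rank and pad just the innermost, contiguous dimension regardless of layout:

    import torch

    def aligned_alloc_shape(kv_cache_shape: tuple, dtype: torch.dtype) -> tuple:
        # Reverse-index so the same code handles the 5-dim (FlashAttention),
        # 4-dim (Pallas) and 3-dim (Triton MLA) layouts above: only the last
        # (contiguous) dim, i.e. head_size, gets padded.
        assert len(kv_cache_shape) >= 3, (
            f"unexpected KV cache shape: {kv_cache_shape}")
        *outer, head_size = kv_cache_shape
        return (*outer, align_to_256bytes(head_size, dtype))

    # e.g. aligned_alloc_shape((1024, 16, 576), torch.bfloat16) -> (1024, 16, 640)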
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
Sorry! Hold on, there may be accuracy issues. Edit: accuracy issues resolved.
"VLLM_CUDA_MEM_ALIGN_KV_CACHE": | ||
lambda: bool(int(os.getenv("VLLM_CUDA_MEM_ALIGN_KV_CACHE", "1"))), |
Do we need to flag this? I think we can just default to this behavior without switching back.
I agree with the concern about reducing the cache space by ~11%, although maybe we consider this change necessary enough for performance to remove the choice, like you say.

This means for MLA with a head dim of 576 (like DeepSeek V2/V3) and an fp16/bf16 cache, we allocate 640 elements per cache entry instead of 576 (1280 bytes instead of 1152). This increases the size of the cache by ~11% (wasted), but leads to a worthwhile performance gain.
Given that it can increase the size of the KV cache, I wanted it on by default but with a flag to turn it off in case a user really wants to maximize KV-cache size.
I think we can just turn this on by default?
It is on by default already (the default value is "1").
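For anyone who does want to reclaim that cache space, a small sketch of opting out (assuming, as with vLLM's other env flags, that the variable just needs to be set before the engine is created):

    import os

    # "0" disables the 256-byte padding, trading the MLA perf gain for the
    # ~11% of KV-cache space it would otherwise consume.
    os.environ["VLLM_CUDA_MEM_ALIGN_KV_CACHE"] = "0"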
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
…(MLA perf improvement)
def copy_blocks_mla(kv_caches: List[torch.Tensor],
                    block_mapping: torch.Tensor) -> None:
    torch.ops._C_cache_ops.copy_blocks_mla(kv_caches, block_mapping)
Why didn't we need this kernel before?
I think we did..., and I think this may solve some bugs (TBH I'm not sure how copy_blocks is used by the wider system).
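For context, a hedged sketch of how this would presumably be invoked; the [src, dst] pair layout of block_mapping is an assumption carried over from the existing non-MLA copy_blocks convention, and the cache shapes are purely illustrative:

    from typing import List

    import torch

    # Illustrative shapes: one combined MLA cache tensor per layer, laid out
    # as (num_blocks, block_size, head_size) = (8, 16, 576).
    kv_caches: List[torch.Tensor] = [
        torch.zeros(8, 16, 576, dtype=torch.bfloat16, device="cuda")
        for _ in range(2)
    ]
    # Assumed convention: [num_pairs, 2] int64 tensor of [src_block, dst_block].
    block_mapping = torch.tensor([[0, 4], [1, 5]],
                                 dtype=torch.int64, device="cuda")
    copy_blocks_mla(kv_caches, block_mapping)  # copies block 0->4 and 1->5 per layer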
Great work -- LGTM!
…llm-project#12676) Signed-off-by: simon-mo <xmo@berkeley.edu> Signed-off-by: Lucas Wilkinson <lcwilkins@redhat.com> Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com> Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com> Co-authored-by: simon-mo <xmo@berkeley.edu> Signed-off-by: Felix Marty <felmarty@amd.com>
…llm-project#12676) Signed-off-by: simon-mo <xmo@berkeley.edu> Signed-off-by: Lucas Wilkinson <lcwilkins@redhat.com> Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com> Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com> Co-authored-by: simon-mo <xmo@berkeley.edu>
Hi, I'm using an 8xH200 setup and unable to reproduce the results from benchmarks/benchmark_throughput.py using the latest 0.7.2 version. My results:
With VLLM_CUDA_MEM_ALIGN_KV_CACHE=1:
The throughput seems nearly identical in both cases. Could you suggest potential causes for this discrepancy?
@leepoly what model is this? This only affects MLA (i.e. DeepSeek V2/3). Edit: nvm, I assume you are using R1 since the numbers look very comparable; I'll try to re-run the numbers tomorrow to see if there is something weird going on.
Yes, I use the DeepSeek V3 model, and I simply used the script you provided. Even with VLLM_CUDA_MEM_ALIGN_KV_CACHE=0, the reported throughput (1.06 rps) already roughly matches your results with 256B alignment (1.10 rps).
Generally, Nvidia hardware likes 256-byte alignment (the reasons are foggy due to the blackbox nature of Nvidia hardware), and memory allocated via the CUDA Runtime is guaranteed to be 256-byte aligned (see https://docs.nvidia.com/cuda/cuda-c-best-practices-guide/#a-sequential-but-misaligned-access-pattern).
This PR aligns KV cache entries to start on 256-byte boundaries. This mainly targets MLA, since for "normal" attention with typical head dims (say 64 or 128) the entries are naturally 256-byte aligned.
This means for MLA with a head dim of 576 (like DeepSeek V2/V3) and an fp16/bf16 cache, we allocate 640 elements per cache entry instead of 576 (1280 bytes instead of 1152). This increases the size of the cache by ~11% (wasted), but leads to a worthwhile performance gain.
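As a quick sanity check on those numbers, a sketch assuming a 2-byte fp16/bf16 element size:

    elem_bytes = 2                                     # fp16 / bf16
    entry_bytes = 576 * elem_bytes                     # 1152 bytes per entry
    padded_bytes = ((entry_bytes + 255) // 256) * 256  # 1280 bytes
    padded_elems = padded_bytes // elem_bytes          # 640 elements
    overhead = padded_bytes / entry_bytes - 1          # ~0.11, i.e. ~11% extra

    # By contrast, a "normal" fp16 head dim of 128 is 128 * 2 = 256 bytes,
    # already a multiple of 256, so no padding is needed.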
Results DeepSeek-R1 on 8xH200
Accuracy: