[V1][Core] Fix memory issue with logits & sampling #13776

Draft: ywang96 wants to merge 5 commits into main

Conversation

@ywang96 (Member) commented Feb 24, 2025

Reopened from #13721, which was accidentally merged.

This PR is a follow-up to #13594 (comment), which describes the memory issue during online serving that persists even after sampler profiling was added to profile_run. After some investigation, the root cause is memory fragmentation from logits and the other sampling-related tensors, since we do not preallocate buffers for them beforehand.

This memory issue can be reproduced by changing the temperature below from zero to a non-zero value, which triggers the logits sampling code path.

"temperature": 0.0,

Server command:
VLLM_USE_V1=1 vllm serve meta-llama/Llama-3.1-8B --disable-log-requests --no-enable-prefix-caching

Client command:

python3 benchmarks/benchmark_serving.py \
        --model meta-llama/Llama-3.1-8B \
        --dataset-name sharegpt \
        --dataset-path ShareGPT_V3_unfiltered_cleaned_split.json \
        --ignore-eos \
        --num-prompts 7200 \
        --request-rate 60

The graphs below track the memory usage of the server process on 1xH100. The timestamp starts when the server is ready to receive traffic, and it starts receiving traffic around T=25.
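
As a side note, per-process numbers like these can be collected with a simple polling script; a minimal sketch using pynvml (the polling interval and output format are my own choices, not necessarily what was used for these graphs):

```python
# Poll GPU 0 and print the memory used by each compute process (MiB).
# Assumes the pynvml bindings are installed (pip install nvidia-ml-py).
import time
import pynvml

pynvml.nvmlInit()
handle = pynvml.nvmlDeviceGetHandleByIndex(0)
while True:
    for proc in pynvml.nvmlDeviceGetComputeRunningProcesses(handle):
        used = proc.usedGpuMemory or 0  # may be None without sufficient permissions
        print(f"t={time.time():.0f} pid={proc.pid} used={used / 2**20:.0f} MiB")
    time.sleep(1)
```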

On main, since logits.shape[0] grows incrementally due to the nature of the online traffic pattern, memory usage keeps climbing as a result of fragmentation from the intermediate tensors derived from logits in Sampler. This issue will not crash the server, since PyTorch itself garbage-collects these cached buffers to prevent OOM (as seen in the dips in the graph), but it should be fixed and handled by vLLM for a few obvious reasons (e.g., releasing that memory requires synchronization).
[Figure: gpu0_memory_usage — GPU 0 memory usage of the server process on main]

The root cause is that the sampler was not included in compile_or_warm_up_model, while capture_model implicitly calls torch.cuda.empty_cache(); therefore, even though the sampler's memory usage was captured in profile_run, the corresponding memory buffers were cleared by this method.
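
A minimal standalone sketch of the interaction (not vLLM code; the tensor shape is just a stand-in for a worst-case logits allocation):

```python
import torch

# Simulate profiling: allocate worst-case logits, then free them. The CUDA
# caching allocator keeps the freed blocks reserved for reuse.
logits = torch.empty(1024, 128_256, device="cuda")  # e.g. max_num_reqs x vocab_size
del logits
print(torch.cuda.memory_reserved())  # blocks are still cached

# What capture_model does implicitly via torch.cuda.empty_cache():
torch.cuda.empty_cache()
print(torch.cuda.memory_reserved())  # cached blocks released back to the driver;
# once traffic arrives, logits of gradually growing batch sizes are allocated
# from scratch again, which is where the fragmentation shows up.
```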

This PR addresses the issue by adding dummy_sampler_run and calling it after the model forward pass itself has been warmed up and captured. We do not want to fold both into _dummy_run, since that method is needed elsewhere for other purposes.
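
Roughly, the warmup amounts to running the sampler once at the worst-case batch size right after graph capture; a hedged sketch (the signature, the make_dummy_sampling_metadata helper, and the max_num_reqs/vocab_size parameters are illustrative, not the exact vLLM implementation):

```python
import torch

@torch.inference_mode()
def dummy_sampler_run(sampler, max_num_reqs: int, vocab_size: int, device="cuda"):
    # Run the sampler once on worst-case-sized dummy logits so the caching
    # allocator reserves buffers for logits and the intermediate sampling
    # tensors up front, after capture_model has already emptied the cache.
    dummy_logits = torch.zeros(max_num_reqs, vocab_size, device=device)
    dummy_metadata = make_dummy_sampling_metadata(max_num_reqs, device)  # hypothetical helper
    return sampler(dummy_logits, dummy_metadata)
```
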
[Figure: gpu0_memory_usage (1) — GPU 0 memory usage of the server process with this PR]
With this PR the memory usage is fairly stable: the initial server memory usage increases from ~68K MiB to ~73K MiB (this accounts for all the sampling-related buffers that were included during profiling but cleared by the warmup), and it stays stable during actual inference. We do observe a very small bump when the server starts receiving traffic, but IMO it is small enough to leave for later investigation.

In addition, with the default gpu_memory_utilization=0.9 one might expect the initial server footprint to be roughly 81559 * 0.9 = 73403 MiB plus whatever the CUDA graphs require (although that is technically not how it works), so the memory usage with this fix (~73714 MiB) should be acceptable. Since PyTorch no longer needs to garbage-collect, we also observe a tiny perf improvement.
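
For reference, the arithmetic behind those numbers (a back-of-the-envelope check only, not how vLLM actually computes its memory budget):

```python
total_mib = 81559                    # H100 80 GB as reported, in MiB
gpu_memory_utilization = 0.9         # default GMU
budget_mib = total_mib * gpu_memory_utilization
print(round(budget_mib))             # 73403
print(73714 - round(budget_mib))     # ~311 MiB on top, roughly the CUDA graphs etc.
```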

Main
============ Serving Benchmark Result ============
Successful requests:                     7200      
Benchmark duration (s):                  138.31    
Total input tokens:                      1582725   
Total generated tokens:                  1442778   
Request throughput (req/s):              52.06     
Output token throughput (tok/s):         10431.65  
Total Token throughput (tok/s):          21875.15  
---------------Time to First Token----------------
Mean TTFT (ms):                          132.21    
Median TTFT (ms):                        111.04    
P99 TTFT (ms):                           392.95    
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          35.49     
Median TPOT (ms):                        33.55     
P99 TPOT (ms):                           58.06     
---------------Inter-token Latency----------------
Mean ITL (ms):                           34.90     
Median ITL (ms):                         33.16     
P99 ITL (ms):                            94.75     
==================================================

This PR
============ Serving Benchmark Result ============
Successful requests:                     7200      
Benchmark duration (s):                  137.63    
Total input tokens:                      1582725   
Total generated tokens:                  1442778   
Request throughput (req/s):              52.32     
Output token throughput (tok/s):         10483.35  
Total Token throughput (tok/s):          21983.57  
---------------Time to First Token----------------
Mean TTFT (ms):                          111.15    
Median TTFT (ms):                        101.92    
P99 TTFT (ms):                           283.02    
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          35.12     
Median TPOT (ms):                        33.42     
P99 TPOT (ms):                           57.61     
---------------Inter-token Latency----------------
Mean ITL (ms):                           34.56     
Median ITL (ms):                         32.93     
P99 ITL (ms):                            87.28     
==================================================

An alternative fix is to keep a persistent buffer for logits, but that may run into some practical issues, so we will leave it for future investigation as well.
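
For completeness, a sketch of what that alternative could look like (illustrative only; the practical issues alluded to above, e.g. sampler ops that still return fresh tensors, are exactly why it is left for later):

```python
import torch

class PersistentLogitsBuffer:
    """Preallocate worst-case logits once and hand out views per batch."""

    def __init__(self, max_num_reqs: int, vocab_size: int, device="cuda"):
        self._buf = torch.empty(max_num_reqs, vocab_size, device=device)

    def get(self, num_reqs: int) -> torch.Tensor:
        # A view into the preallocated buffer; no new device allocation,
        # so a batch size that grows over time cannot fragment memory here.
        return self._buf[:num_reqs]
```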

ywang96 and others added 5 commits February 23, 2025 01:15
Signed-off-by: Roger Wang <ywang@roblox.com>

@mergify mergify bot added the v1 label Feb 24, 2025
@youkaichao (Member) commented:

Is it compatible with sleep mode now?

@WoosukKwon (Collaborator) commented:

Any updates?

@JaheimLee commented:

Is it related to this issue?
