[V1][Core] Fix memory issue with logits & sampling #13776
Reopened from #13721, which was accidentally merged.
This PR is a follow-up to #13594 (comment), which describes the memory issue during online serving that persists even after sampler profiling was added to `profile_run`. After some investigation, the root cause is memory fragmentation from the logits and other sampling-related tensors, since we do not preallocate buffers for them beforehand.

The issue can be reproduced by changing the temperature in the file below to a nonzero value so that the logits sampling code path is triggered:

vllm/benchmarks/backend_request_func.py, line 249 at eb24dc4
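The gist of that tweak, as a hedged sketch (the surrounding payload fields are assumed from the benchmark's OpenAI-style request builder, not quoted verbatim from that revision):

```python
# Sketch of the reproduction change in vllm/benchmarks/backend_request_func.py
# (field names are assumed, not copied from revision eb24dc4).
payload = {
    "model": request_func_input.model,
    "prompt": request_func_input.prompt,
    "temperature": 0.7,  # was 0.0 (greedy); any nonzero value exercises the sampling/logits path
    "max_tokens": request_func_input.output_len,
}
```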
Server command:

```
VLLM_USE_V1=1 vllm serve meta-llama/Llama-3.1-8B --disable-log-requests --no-enable-prefix-caching
```

Client command:
The graphs below track the memory usage of the server process on 1xH100. The timestamp starts when the server is ready to receive traffic, and it starts receiving traffic at around T=25.
On main, since `logits.shape[0]` grows incrementally due to the nature of the online traffic pattern, memory usage keeps increasing as a result of memory fragmentation from the intermediate tensors derived from `logits` in `Sampler`. This issue will not crash the server, since PyTorch itself garbage collects these cached buffers to prevent OOM (as seen from the dips in the graph), but it should be fixed and handled by vLLM for a few obvious reasons (e.g., releasing the memory requires synchronization).

The root cause is that the sampler was not included in `compile_or_warm_up_model`, while `capture_model` implicitly calls `torch.cuda.empty_cache()`. Therefore, even though the sampler's memory usage was captured in `profile_run`, the corresponding buffers were released again when `capture_model` emptied the cache.
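To make the failure mode concrete, here is a standalone PyTorch sketch (illustrative only, not vLLM code; the sizes and temperature value are made up) of how repeatedly allocating logits-shaped tensors with a creeping first dimension tends to grow the caching allocator's reserved memory, and how `torch.cuda.empty_cache()`, the call that `capture_model` ends up triggering, is what gives it back:

```python
import torch

VOCAB_SIZE = 128 * 1024  # large vocabulary, roughly Llama-3.1 scale, for illustration

def reserved_mib() -> float:
    # Memory held by PyTorch's CUDA caching allocator, including cached free blocks.
    return torch.cuda.memory_reserved() / (1 << 20)

print(f"baseline reserved:     {reserved_mib():.0f} MiB")

# Mimic online traffic: the sampled batch size (logits.shape[0]) creeps upward
# request by request, so each step asks the allocator for a different block size.
for batch_size in range(8, 513, 8):
    logits = torch.randn(batch_size, VOCAB_SIZE, device="cuda")
    probs = torch.softmax(logits / 0.7, dim=-1)  # intermediate sampling tensor
    del logits, probs  # freed back to the cache, but kept reserved by PyTorch

print(f"after varying batches: {reserved_mib():.0f} MiB")  # typically much higher

torch.cuda.empty_cache()  # the synchronizing release that capture_model triggers
print(f"after empty_cache:     {reserved_mib():.0f} MiB")
```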

This PR addresses the issue by adding `dummy_sampler_run` and calling it after the model forward itself has been warmed up and captured; see the sketch below. We do not fold both into `_dummy_run`, since that method is needed elsewhere for other purposes.
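A minimal sketch of the intended ordering in the worker's warm-up path (method names and signatures such as `_dummy_run(num_tokens=...)` and `_dummy_sampler_run(...)` are assumed for illustration, not quoted from the diff):

```python
# Sketch only: the key point is that the sampler warm-up must run *after*
# capture_model(), because capture_model() implicitly calls torch.cuda.empty_cache()
# and would otherwise throw the sampler's buffers away again.
def compile_or_warm_up_model(self) -> None:
    if not self.model_config.enforce_eager:
        # Captures CUDA graphs and, as a side effect, empties the CUDA cache.
        self.model_runner.capture_model()

    # Run a dummy forward at the largest batch we expect to sample for, then run
    # the sampler on its output so logits and related tensors are allocated once,
    # up front, at their worst-case shapes.
    max_num_reqs = self.scheduler_config.max_num_seqs  # assumed upper bound
    hidden_states = self.model_runner._dummy_run(num_tokens=max_num_reqs)
    self.model_runner._dummy_sampler_run(hidden_states)
```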
With this change, memory usage is much more stable. The initial server memory usage increases from ~68K MiB to ~73K MiB (this accounts for all the sampling-related buffers that were counted during profiling but previously cleared during warm-up), and it stays flat during actual inference. We do observe a very small bump when the server starts receiving traffic, but IMO it is small enough to leave for later investigation.
In addition, with the default gpu_memory_utilization=0.9 one may expect the initial server launch to take roughly 81559 * 0.9 = 73403 MiB plus whatever the CUDA graphs require (although, technically speaking, that is not exactly how it works), so the memory usage with this fix (~73714 MiB) should be acceptable. Since PyTorch no longer needs to garbage collect the cached buffers, we also observe a tiny perf improvement.
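Spelled out (numbers restated from the paragraph above, not newly measured):

```python
# Back-of-the-envelope check of the quoted numbers (MiB).
total_hbm = 81559         # reported GPU memory on the 1xH100
gmu = 0.9                 # default --gpu-memory-utilization
budget = total_hbm * gmu  # ~73403 MiB usable by vLLM
observed = 73714          # initial server memory usage with this PR
print(f"budget ~ {budget:.0f} MiB, overshoot ~ {observed - budget:.0f} MiB")  # ~311 MiB for CUDA graphs etc.
```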
An alternative fix would be to keep a persistent buffer for `logits`, but that may run into some practical issues, so we will leave it for future investigation as well.
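For context only, a rough sketch of what such a persistent-buffer approach could look like (hypothetical names, not part of this PR):

```python
import torch

class PersistentLogitsBuffer:
    """Illustrative only: allocate the worst-case logits tensor once and reuse slices."""

    def __init__(self, max_num_reqs: int, vocab_size: int,
                 dtype: torch.dtype = torch.float32, device: str = "cuda"):
        # One allocation at startup at the worst-case shape; never freed or resized.
        self._buffer = torch.empty(max_num_reqs, vocab_size, dtype=dtype, device=device)

    def compute_logits(self, hidden_states: torch.Tensor,
                       lm_head_weight: torch.Tensor) -> torch.Tensor:
        # hidden_states: (num_reqs, hidden_size); lm_head_weight: (vocab_size, hidden_size).
        # Assumes both share the buffer's dtype so the out= matmul is valid.
        num_reqs = hidden_states.shape[0]
        out = self._buffer[:num_reqs]
        # Write the projection result into the preallocated slice instead of
        # materializing a fresh (num_reqs, vocab_size) tensor each step.
        torch.matmul(hidden_states, lm_head_weight.t(), out=out)
        return out
```

The trade-off is that downstream code must be safe with views into a shared, worst-case-sized buffer, which is likely among the practical issues mentioned above.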