diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 31fe095a91bc0..543ccf171d1c1 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -31,6 +31,7 @@
 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                         KVCacheSpec)
 from vllm.v1.outputs import LogprobsTensors, ModelRunnerOutput
+from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.sample.rejection_sampler import INVALID_TOKEN_ID
 from vllm.v1.spec_decode.ngram_proposer import NgramProposer
 from vllm.v1.utils import bind_kv_cache
@@ -1303,11 +1304,34 @@ def profile_run(self) -> None:
         if get_pp_group().is_last_rank:
             hidden_states = hidden_states[logit_indices]
             logits = self.model.compute_logits(hidden_states, None)
-            # TODO(woosuk): Consider the memory usage of the sampler.
+            dummy_tensors = lambda v: torch.full(
+                (num_reqs, ), v, device=self.device)
+            dummy_metadata = SamplingMetadata(
+                temperature=dummy_tensors(0.5),
+                all_greedy=False,
+                all_random=False,
+                spec_token_ids=None,
+                top_p=dummy_tensors(0.9),
+                top_k=dummy_tensors(logits.size(1) - 1),
+                min_p=None,
+                generators={},
+                max_num_logprobs=None,
+                no_penalties=True,
+                prompt_token_ids=torch.ones_like(logits, dtype=torch.int64),
+                frequency_penalties=dummy_tensors(0.1),
+                presence_penalties=dummy_tensors(0.1),
+                repetition_penalties=dummy_tensors(0.1),
+                output_token_ids=[[] for _ in range(num_reqs)],
+                min_tokens={},
+                logit_bias=[None for _ in range(num_reqs)])
+            sampler_output = self.model.sample(
+                logits=logits, sampling_metadata=dummy_metadata)
         else:
             logits = None
+            sampler_output = None
+            dummy_metadata = None
         torch.cuda.synchronize()
-        del hidden_states, logits
+        del hidden_states, logits, sampler_output, dummy_metadata
         self.encoder_cache.clear()
         gc.collect()
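
The change addresses the removed TODO: profile_run previously measured peak GPU memory for the forward pass and logits only, so sampler-side allocations (top-p/top-k sorting, penalty tensors) were not reflected in the profiled peak. The dummy SamplingMetadata exercises those paths (non-greedy sampling, top_k at logits.size(1) - 1, penalties populated) so the measured peak accounts for the sampler as well. Below is a minimal, self-contained sketch of the same idea outside vLLM; the function profile_sampler_memory and its simplified sampler body are illustrative stand-ins, not vLLM APIs.

```python
import torch


def profile_sampler_memory(num_reqs: int, vocab_size: int) -> int:
    """Measure peak GPU memory of a dummy, near-worst-case sampling step."""
    device = torch.device("cuda")  # requires a CUDA device
    torch.cuda.reset_peak_memory_stats(device)

    # Dummy per-request parameters, mirroring the diff's dummy_tensors lambda.
    temperature = torch.full((num_reqs, ), 0.5, device=device)
    # top_k = vocab_size - 1 is close to the worst case: nearly the whole
    # vocabulary has to be selected/sorted per request.
    top_k = vocab_size - 1

    logits = torch.randn(num_reqs, vocab_size, device=device)

    # Stand-in for a sampler forward pass: temperature scaling, top-k
    # masking, softmax, and a multinomial draw.
    scaled = logits / temperature.unsqueeze(-1)
    topk_vals, _ = scaled.topk(top_k, dim=-1)
    cutoff = topk_vals[:, -1:]
    masked = scaled.masked_fill(scaled < cutoff, float("-inf"))
    probs = torch.softmax(masked, dim=-1)
    next_tokens = torch.multinomial(probs, num_samples=1)

    torch.cuda.synchronize()
    del logits, scaled, topk_vals, cutoff, masked, probs, next_tokens
    return torch.cuda.max_memory_allocated(device)


if __name__ == "__main__":
    peak = profile_sampler_memory(num_reqs=256, vocab_size=32000)
    print(f"peak GPU memory during dummy sampling: {peak / 2**20:.1f} MiB")
```

As in the diff, the intermediate tensors are explicitly deleted after a synchronize so only the recorded peak, not lingering allocations, informs the memory budget derived from profiling.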