From 185d41f08916ff79292c664c387df03b2ba8a7e4 Mon Sep 17 00:00:00 2001
From: Jennifer Zhao <7443418+JenZhao@users.noreply.github.com>
Date: Thu, 20 Feb 2025 07:13:06 +0000
Subject: [PATCH 1/6] sampler memory

Signed-off-by: Jennifer Zhao <7443418+JenZhao@users.noreply.github.com>
---
 vllm/v1/worker/gpu_model_runner.py | 29 ++++++++++++++++++++++++++---
 1 file changed, 26 insertions(+), 3 deletions(-)

diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 31fe095a91bc0..073b94b5d3dbb 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -36,7 +36,7 @@
 from vllm.v1.utils import bind_kv_cache
 from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
 from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
-
+from vllm.v1.sample.metadata import SamplingMetadata
 if TYPE_CHECKING:
     from vllm.v1.core.scheduler_output import SchedulerOutput
 
@@ -1303,11 +1303,34 @@ def profile_run(self) -> None:
         if get_pp_group().is_last_rank:
             hidden_states = hidden_states[logit_indices]
             logits = self.model.compute_logits(hidden_states, None)
-            # TODO(woosuk): Consider the memory usage of the sampler.
+            penalties = torch.full((logits.size(0),), 0.0, device=self.device)
+            dummy_metadata = SamplingMetadata(
+                temperature=torch.full((logits.size(0),), 0.5, device=self.device),
+                all_greedy=False,
+                all_random=False,
+                spec_token_ids=None,
+                top_p=torch.full((logits.size(0),), 0.99, device=self.device),
+                top_k=torch.full((logits.size(0),), logits.size(1) - 1, device=self.device),
+                min_p=None,
+                generators={},
+                max_num_logprobs=None,
+                no_penalties=True,
+                prompt_token_ids=None,
+                frequency_penalties=penalties,
+                presence_penalties=penalties,
+                repetition_penalties=penalties,
+                output_token_ids=[[] for _ in range(logits.size(0))],
+                min_tokens={},
+                logit_bias=[None for _ in range(logits.size(0))]
+            )
+            sampler_output = self.model.sample(logits=logits, sampling_metadata=dummy_metadata)
         else:
             logits = None
+            sampler_output = None
+            penalties = None
+            dummy_metadata = None
         torch.cuda.synchronize()
-        del hidden_states, logits
+        del hidden_states, logits, sampler_output, penalties, dummy_metadata
         self.encoder_cache.clear()
         gc.collect()
 

From 1d86c3c19f9da5ad0887f4b3aa3950e514ad0b6f Mon Sep 17 00:00:00 2001
From: Jennifer Zhao <7443418+JenZhao@users.noreply.github.com>
Date: Thu, 20 Feb 2025 08:27:57 +0000
Subject: [PATCH 2/6] address comments

Signed-off-by: Jennifer Zhao <7443418+JenZhao@users.noreply.github.com>
---
 vllm/v1/worker/gpu_model_runner.py | 23 ++++++++++++++---------
 1 file changed, 14 insertions(+), 9 deletions(-)

diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 073b94b5d3dbb..0b9d40f3d05d2 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -31,12 +31,13 @@
 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                         KVCacheSpec)
 from vllm.v1.outputs import LogprobsTensors, ModelRunnerOutput
+from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.sample.rejection_sampler import INVALID_TOKEN_ID
 from vllm.v1.spec_decode.ngram_proposer import NgramProposer
 from vllm.v1.utils import bind_kv_cache
 from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
 from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
-from vllm.v1.sample.metadata import SamplingMetadata
+
 if TYPE_CHECKING:
     from vllm.v1.core.scheduler_output import SchedulerOutput
 
@@ -1303,14 +1304,18 @@ def profile_run(self) -> None:
         if get_pp_group().is_last_rank:
             hidden_states = hidden_states[logit_indices]
             logits = self.model.compute_logits(hidden_states, None)
-            penalties = torch.full((logits.size(0),), 0.0, device=self.device)
+            penalties = torch.full((num_reqs, ), 0.0, device=self.device)
             dummy_metadata = SamplingMetadata(
-                temperature=torch.full((logits.size(0),), 0.5, device=self.device),
+                temperature=torch.full((num_reqs, ),
+                                       0.5,
+                                       device=self.device),
                 all_greedy=False,
                 all_random=False,
                 spec_token_ids=None,
-                top_p=torch.full((logits.size(0),), 0.99, device=self.device),
-                top_k=torch.full((logits.size(0),), logits.size(1) - 1, device=self.device),
+                top_p=torch.full((num_reqs, ), 0.99, device=self.device),
+                top_k=torch.full((num_reqs, ),
+                                 logits.size(1) - 1,
+                                 device=self.device),
                 min_p=None,
                 generators={},
                 max_num_logprobs=None,
@@ -1319,11 +1324,11 @@ def profile_run(self) -> None:
                 frequency_penalties=penalties,
                 presence_penalties=penalties,
                 repetition_penalties=penalties,
-                output_token_ids=[[] for _ in range(logits.size(0))],
+                output_token_ids=[[] for _ in range(num_reqs)],
                 min_tokens={},
-                logit_bias=[None for _ in range(logits.size(0))]
-            )
-            sampler_output = self.model.sample(logits=logits, sampling_metadata=dummy_metadata)
+                logit_bias=[None for _ in range(num_reqs)])
+            sampler_output = self.model.sample(
+                logits=logits, sampling_metadata=dummy_metadata)
         else:
             logits = None
             sampler_output = None

From 7b3106ee7431cc380ab3d2d7032dd3b5458ef300 Mon Sep 17 00:00:00 2001
From: Jennifer Zhao <7443418+JenZhao@users.noreply.github.com>
Date: Fri, 21 Feb 2025 03:38:40 +0000
Subject: [PATCH 3/6] address comments

Signed-off-by: Jennifer Zhao <7443418+JenZhao@users.noreply.github.com>
---
 vllm/v1/worker/gpu_model_runner.py | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 0b9d40f3d05d2..fe132993295cb 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -1304,7 +1304,8 @@ def profile_run(self) -> None:
         if get_pp_group().is_last_rank:
             hidden_states = hidden_states[logit_indices]
             logits = self.model.compute_logits(hidden_states, None)
-            penalties = torch.full((num_reqs, ), 0.0, device=self.device)
+            penalties = lambda: torch.full(
+                (num_reqs, ), 0.0, device=self.device)
             dummy_metadata = SamplingMetadata(
                 temperature=torch.full((num_reqs, ),
                                        0.5,
@@ -1321,9 +1322,9 @@ def profile_run(self) -> None:
                 max_num_logprobs=None,
                 no_penalties=True,
                 prompt_token_ids=None,
-                frequency_penalties=penalties,
-                presence_penalties=penalties,
-                repetition_penalties=penalties,
+                frequency_penalties=penalties(),
+                presence_penalties=penalties(),
+                repetition_penalties=penalties(),
                 output_token_ids=[[] for _ in range(num_reqs)],
                 min_tokens={},
                 logit_bias=[None for _ in range(num_reqs)])

From b7a4fb67caab20174a569fe3e3db28f9b5b6530b Mon Sep 17 00:00:00 2001
From: Jennifer Zhao <7443418+JenZhao@users.noreply.github.com>
Date: Fri, 21 Feb 2025 06:59:19 +0000
Subject: [PATCH 4/6] update penalties tensors

Signed-off-by: Jennifer Zhao <7443418+JenZhao@users.noreply.github.com>
---
 vllm/v1/worker/gpu_model_runner.py | 23 ++++++++++-------------
 1 file changed, 10 insertions(+), 13 deletions(-)

diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index fe132993295cb..f652fa048b0a4 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -1304,8 +1304,8 @@ def profile_run(self) -> None:
         if get_pp_group().is_last_rank:
             hidden_states = hidden_states[logit_indices]
             logits = self.model.compute_logits(hidden_states, None)
-            penalties = lambda: torch.full(
-                (num_reqs, ), 0.0, device=self.device)
+            dummy_tensors = lambda v: torch.full(
+                (num_reqs, ), v, device=self.device)
             dummy_metadata = SamplingMetadata(
                 temperature=torch.full((num_reqs, ),
                                        0.5,
@@ -1313,18 +1313,16 @@ def profile_run(self) -> None:
                 all_greedy=False,
                 all_random=False,
                 spec_token_ids=None,
-                top_p=torch.full((num_reqs, ), 0.99, device=self.device),
-                top_k=torch.full((num_reqs, ),
-                                 logits.size(1) - 1,
-                                 device=self.device),
+                top_p=dummy_tensors(0.9),
+                top_k=dummy_tensors(logits.size(1) - 1),
                 min_p=None,
                 generators={},
                 max_num_logprobs=None,
-                no_penalties=True,
-                prompt_token_ids=None,
-                frequency_penalties=penalties(),
-                presence_penalties=penalties(),
-                repetition_penalties=penalties(),
+                no_penalties=False,
+                prompt_token_ids=torch.ones_like(logits, dtype=torch.long),
+                frequency_penalties=dummy_tensors(0.1),
+                presence_penalties=dummy_tensors(0.1),
+                repetition_penalties=dummy_tensors(0.1),
                 output_token_ids=[[] for _ in range(num_reqs)],
                 min_tokens={},
                 logit_bias=[None for _ in range(num_reqs)])
@@ -1333,10 +1331,9 @@ def profile_run(self) -> None:
         else:
             logits = None
             sampler_output = None
-            penalties = None
             dummy_metadata = None
         torch.cuda.synchronize()
-        del hidden_states, logits, sampler_output, penalties, dummy_metadata
+        del hidden_states, logits, sampler_output, dummy_metadata
         self.encoder_cache.clear()
         gc.collect()
 

From 98c2f90ff052095f6fb8ec3a24e0564677396d00 Mon Sep 17 00:00:00 2001
From: Jennifer Zhao <7443418+JenZhao@users.noreply.github.com>
Date: Fri, 21 Feb 2025 07:03:47 +0000
Subject: [PATCH 5/6] fix

Signed-off-by: Jennifer Zhao <7443418+JenZhao@users.noreply.github.com>
---
 vllm/v1/worker/gpu_model_runner.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index f652fa048b0a4..7078fcd4d1480 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -1307,9 +1307,7 @@ def profile_run(self) -> None:
             dummy_tensors = lambda v: torch.full(
                 (num_reqs, ), v, device=self.device)
             dummy_metadata = SamplingMetadata(
-                temperature=torch.full((num_reqs, ),
-                                       0.5,
-                                       device=self.device),
+                temperature=dummy_tensors(0.5),
                 all_greedy=False,
                 all_random=False,
                 spec_token_ids=None,

From cecc5fdbae9c2f97b8ea2b16de07605247eae2da Mon Sep 17 00:00:00 2001
From: Jennifer Zhao <7443418+JenZhao@users.noreply.github.com>
Date: Sat, 22 Feb 2025 04:11:15 +0000
Subject: [PATCH 6/6] set no_penalties=True

Signed-off-by: Jennifer Zhao <7443418+JenZhao@users.noreply.github.com>
---
 vllm/v1/worker/gpu_model_runner.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 7078fcd4d1480..543ccf171d1c1 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -1316,8 +1316,8 @@ def profile_run(self) -> None:
                 min_p=None,
                 generators={},
                 max_num_logprobs=None,
-                no_penalties=False,
-                prompt_token_ids=torch.ones_like(logits, dtype=torch.long),
+                no_penalties=True,
+                prompt_token_ids=torch.ones_like(logits, dtype=torch.int64),
                 frequency_penalties=dummy_tensors(0.1),
                 presence_penalties=dummy_tensors(0.1),
                 repetition_penalties=dummy_tensors(0.1),
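
For reference, below is a consolidated sketch of the dummy-sampler block in profile_run() as it reads once all six patches are applied. It is reassembled from the hunks above rather than copied from the tree, and names such as num_reqs, logit_indices, self.device, and self.encoder_cache are assumed to come unchanged from the surrounding profile_run() context:

    if get_pp_group().is_last_rank:
        hidden_states = hidden_states[logit_indices]
        logits = self.model.compute_logits(hidden_states, None)
        # Build a fresh per-request tensor filled with v on each call.
        dummy_tensors = lambda v: torch.full(
            (num_reqs, ), v, device=self.device)
        dummy_metadata = SamplingMetadata(
            temperature=dummy_tensors(0.5),
            all_greedy=False,
            all_random=False,
            spec_token_ids=None,
            top_p=dummy_tensors(0.9),
            top_k=dummy_tensors(logits.size(1) - 1),
            min_p=None,
            generators={},
            max_num_logprobs=None,
            no_penalties=True,
            prompt_token_ids=torch.ones_like(logits, dtype=torch.int64),
            frequency_penalties=dummy_tensors(0.1),
            presence_penalties=dummy_tensors(0.1),
            repetition_penalties=dummy_tensors(0.1),
            output_token_ids=[[] for _ in range(num_reqs)],
            min_tokens={},
            logit_bias=[None for _ in range(num_reqs)])
        # Run the sampler once so its memory usage is included in profiling,
        # replacing the old "TODO: Consider the memory usage of the sampler".
        sampler_output = self.model.sample(
            logits=logits, sampling_metadata=dummy_metadata)
    else:
        logits = None
        sampler_output = None
        dummy_metadata = None
    torch.cuda.synchronize()
    del hidden_states, logits, sampler_output, dummy_metadata
    self.encoder_cache.clear()
    gc.collect()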