[V1][Sampler] Don't apply temp for greedy-only (#13311)
Signed-off-by: Nick Hill <nhill@redhat.com>
njhill authored Feb 15, 2025
1 parent e7eea5a commit 6a854c7
Showing 1 changed file with 15 additions and 9 deletions.
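
Context for the change: greedy decoding picks the argmax of the logits, and dividing the logits by a positive temperature never changes which entry is largest, so a greedy-only batch can skip apply_temperature entirely. A minimal standalone check of that invariance (illustrative only, not part of the diff):

import torch

logits = torch.randn(4, 32000)  # e.g. a batch of 4 requests over a 32k vocab
for temp in (0.3, 0.7, 1.5):
    # Argmax is invariant under division by any positive temperature.
    assert torch.equal(logits.argmax(dim=-1), (logits / temp).argmax(dim=-1))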
24 changes: 15 additions & 9 deletions — vllm/v1/sample/sampler.py

@@ -41,8 +41,6 @@ def forward(
         logits = self.apply_logits_bias(logits, sampling_metadata)
         # Apply penalties (e.g., min_tokens, freq_penalties).
         logits = self.apply_penalties(logits, sampling_metadata)
-        # Apply temperature.
-        logits = self.apply_temperature(logits, sampling_metadata.temperature)
         # Sample the next token.
         sampled = self.sample(logits, sampling_metadata)
 
@@ -82,9 +80,21 @@ def sample(
     ) -> torch.Tensor:
         assert not (sampling_metadata.all_greedy
                     and sampling_metadata.all_random)
-        if sampling_metadata.all_greedy:
-            return self.greedy_sample(logits)
+        if sampling_metadata.all_random:
+            greedy_sampled = None
+        else:
+            greedy_sampled = self.greedy_sample(logits)
+            if sampling_metadata.all_greedy:
+                return greedy_sampled
 
+        # Apply temperature.
+        logits = self.apply_temperature(logits, sampling_metadata.temperature)
+
+        # Apply min_p.
+        if not sampling_metadata.no_min_p:
+            logits = self.apply_min_p(logits, sampling_metadata.min_p)
+
+        # Apply top_k and/or top_p.
         random_sampled = self.topk_topp_sampler(
             logits,
             sampling_metadata.generators,
@@ -94,13 +104,9 @@ def sample(
             sampling_metadata.top_p,
         )
 
-        if not sampling_metadata.no_min_p:
-            logits = self.apply_min_p(logits, sampling_metadata.min_p)
-
-        if sampling_metadata.all_random:
+        if greedy_sampled is None:
             return random_sampled
 
-        greedy_sampled = self.greedy_sample(logits)
         sampled = torch.where(
             sampling_metadata.temperature < _SAMPLING_EPS,
             greedy_sampled,

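For mixed batches, the rewritten sample() computes greedy_sampled from the raw logits up front and merges it with random_sampled per request via torch.where. A rough standalone sketch of that merge; the _SAMPLING_EPS value and the clamp-based temperature handling here are assumptions for illustration, not the vLLM implementation:

import torch

_SAMPLING_EPS = 1e-5  # assumed value for this sketch
logits = torch.randn(3, 100)
temperature = torch.tensor([0.0, 0.8, 1.2])  # request 0 is greedy

# Greedy tokens come from the raw logits, before any temperature scaling.
greedy_sampled = logits.argmax(dim=-1)

# Random sampling scales by temperature; the clamp keeps the greedy row
# finite (its random result is discarded by torch.where below anyway).
scaled = logits / temperature.clamp(min=_SAMPLING_EPS).unsqueeze(1)
random_sampled = torch.multinomial(torch.softmax(scaled, dim=-1), 1).squeeze(1)

# Per-request selection: greedy wherever temperature is (near) zero.
sampled = torch.where(temperature < _SAMPLING_EPS, greedy_sampled, random_sampled)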