Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[V1][Sampler] Don't apply temp for greedy-only #13311

Merged
merged 3 commits into from
Feb 15, 2025
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 11 additions & 6 deletions vllm/v1/sample/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,6 @@ def forward(
logits = self.apply_logits_bias(logits, sampling_metadata)
# Apply penalties (e.g., min_tokens, freq_penalties).
logits = self.apply_penalties(logits, sampling_metadata)
# Apply temperature.
logits = self.apply_temperature(logits, sampling_metadata.temperature)
# Sample the next token.
sampled = self.sample(logits, sampling_metadata)

Expand Down Expand Up @@ -82,9 +80,17 @@ def sample(
) -> torch.Tensor:
assert not (sampling_metadata.all_greedy
and sampling_metadata.all_random)
if sampling_metadata.all_greedy:
return self.greedy_sample(logits)
if sampling_metadata.all_random:
greedy_sampled = None
else:
greedy_sampled = self.greedy_sample(logits)
if sampling_metadata.all_greedy:
return greedy_sampled

# Apply temperature.
logits = self.apply_temperature(logits, sampling_metadata.temperature)

# Apply top_k and/or top_p.
random_sampled = self.topk_topp_sampler(
logits,
sampling_metadata.generators,
Expand All @@ -93,10 +99,9 @@ def sample(
sampling_metadata.no_top_p,
sampling_metadata.top_p,
)
if sampling_metadata.all_random:
if greedy_sampled is None:
return random_sampled

greedy_sampled = self.greedy_sample(logits)
sampled = torch.where(
sampling_metadata.temperature < _SAMPLING_EPS,
greedy_sampled,
Expand Down