Skip to content

Commit

Permalink
[BugFix] Fix min_tokens behaviour for multiple eos tokens (vllm-project#…)
Browse files Browse the repository at this point in the history
  • Loading branch information
njhill authored and prashantgupta24 committed Jul 1, 2024
1 parent ca516bd commit c5ff677
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 13 deletions.
7 changes: 2 additions & 5 deletions vllm/engine/llm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -606,12 +606,9 @@ def _create_sequence_group_with_sampling(
# Defensive copy of SamplingParams, which are used by the sampler,
# this doesn't deep-copy LogitsProcessor objects
sampling_params = sampling_params.clone()
# Add the eos token id into the sampling_params to support min_tokens
# processing
if seq.eos_token_id is not None:
sampling_params.all_stop_token_ids.add(seq.eos_token_id)

sampling_params.update_from_generation_config(
self.generation_config_fields)
self.generation_config_fields, seq.eos_token_id)

# Create the sequence group.
seq_group = SequenceGroup(
Expand Down
29 changes: 21 additions & 8 deletions vllm/sampling_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,17 +280,30 @@ def _verify_greedy_sampling(self) -> None:
f"Got {self.best_of}.")

def update_from_generation_config(
self, generation_config: Dict[str, Any]) -> None:
self,
generation_config: Dict[str, Any],
model_eos_token_id: Optional[int] = None) -> None:
"""Update if there are non-default values from generation_config"""

if model_eos_token_id is not None:
# Add the eos token id into the sampling_params to support
# min_tokens processing.
self.all_stop_token_ids.add(model_eos_token_id)

# Update eos_token_id for generation
if (not self.ignore_eos) and (eos_ids :=
generation_config.get("eos_token_id")):
if (eos_ids := generation_config.get("eos_token_id")) is not None:
# it can be either int or list of int
if isinstance(eos_ids, int):
eos_ids = [eos_ids]
original_stop_token_ids = set(self.stop_token_ids)
original_stop_token_ids.update(eos_ids)
self.stop_token_ids = list(original_stop_token_ids)
eos_ids = {eos_ids} if isinstance(eos_ids, int) else set(eos_ids)
if model_eos_token_id is not None:
# We don't need to include the primary eos_token_id in
# stop_token_ids since it's handled separately for stopping
# purposes.
eos_ids.discard(model_eos_token_id)
if eos_ids:
self.all_stop_token_ids.update(eos_ids)
if not self.ignore_eos:
eos_ids.update(self.stop_token_ids)
self.stop_token_ids = list(eos_ids)

@cached_property
def sampling_type(self) -> SamplingType:
Expand Down

0 comments on commit c5ff677

Please sign in to comment.