From c5ff677e19ec10ed521f4aacbde99b58a822249b Mon Sep 17 00:00:00 2001
From: Nick Hill
Date: Thu, 27 Jun 2024 11:31:11 -0700
Subject: [PATCH] [BugFix] Fix `min_tokens` behaviour for multiple eos tokens
 (#5849)

---
 vllm/engine/llm_engine.py |  7 ++-----
 vllm/sampling_params.py   | 29 +++++++++++++++++++++--------
 2 files changed, 23 insertions(+), 13 deletions(-)

diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index 0ad957ef9f958..4b427b1fb2f22 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -606,12 +606,9 @@ def _create_sequence_group_with_sampling(
         # Defensive copy of SamplingParams, which are used by the sampler,
         # this doesn't deep-copy LogitsProcessor objects
         sampling_params = sampling_params.clone()
-        # Add the eos token id into the sampling_params to support min_tokens
-        # processing
-        if seq.eos_token_id is not None:
-            sampling_params.all_stop_token_ids.add(seq.eos_token_id)
+
         sampling_params.update_from_generation_config(
-            self.generation_config_fields)
+            self.generation_config_fields, seq.eos_token_id)
 
         # Create the sequence group.
         seq_group = SequenceGroup(
diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py
index 9d8a361353e26..a2caae21a86e3 100644
--- a/vllm/sampling_params.py
+++ b/vllm/sampling_params.py
@@ -280,17 +280,30 @@ def _verify_greedy_sampling(self) -> None:
                              f"Got {self.best_of}.")
 
     def update_from_generation_config(
-            self, generation_config: Dict[str, Any]) -> None:
+            self,
+            generation_config: Dict[str, Any],
+            model_eos_token_id: Optional[int] = None) -> None:
         """Update if there are non-default values from generation_config"""
+
+        if model_eos_token_id is not None:
+            # Add the eos token id into the sampling_params to support
+            # min_tokens processing.
+            self.all_stop_token_ids.add(model_eos_token_id)
+
         # Update eos_token_id for generation
-        if (not self.ignore_eos) and (eos_ids :=
-                                      generation_config.get("eos_token_id")):
+        if (eos_ids := generation_config.get("eos_token_id")) is not None:
             # it can be either int or list of int
-            if isinstance(eos_ids, int):
-                eos_ids = [eos_ids]
-            original_stop_token_ids = set(self.stop_token_ids)
-            original_stop_token_ids.update(eos_ids)
-            self.stop_token_ids = list(original_stop_token_ids)
+            eos_ids = {eos_ids} if isinstance(eos_ids, int) else set(eos_ids)
+            if model_eos_token_id is not None:
+                # We don't need to include the primary eos_token_id in
+                # stop_token_ids since it's handled separately for stopping
+                # purposes.
+                eos_ids.discard(model_eos_token_id)
+            if eos_ids:
+                self.all_stop_token_ids.update(eos_ids)
+                if not self.ignore_eos:
+                    eos_ids.update(self.stop_token_ids)
+                    self.stop_token_ids = list(eos_ids)
 
     @cached_property
     def sampling_type(self) -> SamplingType:
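
For illustration, the standalone sketch below restates the merging logic introduced by this patch. MiniSamplingParams is a stripped-down stand-in for vllm.SamplingParams (the real class has many more fields and checks), and the token ids in the usage example are hypothetical. The point it demonstrates: every EOS id, whether the model's primary one or an extra one declared in generation_config, ends up in all_stop_token_ids (the set consulted by min_tokens handling), while only the secondary ids are merged into stop_token_ids, and only when ignore_eos is unset, since the primary EOS id is handled separately for stopping.

from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Set


@dataclass
class MiniSamplingParams:
    """Stripped-down, illustrative stand-in for vllm.SamplingParams."""
    ignore_eos: bool = False
    stop_token_ids: List[int] = field(default_factory=list)
    # Union of user stop ids and every known EOS id; this is the set the
    # min_tokens handling consults when suppressing stop tokens.
    all_stop_token_ids: Set[int] = field(default_factory=set)

    def __post_init__(self) -> None:
        # Seed with the user-provided stop ids.
        self.all_stop_token_ids.update(self.stop_token_ids)

    def update_from_generation_config(
            self,
            generation_config: Dict[str, Any],
            model_eos_token_id: Optional[int] = None) -> None:
        # The model's primary EOS id always counts as a stop for min_tokens.
        if model_eos_token_id is not None:
            self.all_stop_token_ids.add(model_eos_token_id)

        if (eos_ids := generation_config.get("eos_token_id")) is not None:
            # generation_config may provide a single id or a list of ids.
            eos_ids = {eos_ids} if isinstance(eos_ids, int) else set(eos_ids)
            # The primary EOS id is handled separately for stopping, so it
            # need not be duplicated in stop_token_ids.
            if model_eos_token_id is not None:
                eos_ids.discard(model_eos_token_id)
            if eos_ids:
                self.all_stop_token_ids.update(eos_ids)
                if not self.ignore_eos:
                    eos_ids.update(self.stop_token_ids)
                    self.stop_token_ids = list(eos_ids)


if __name__ == "__main__":
    # Hypothetical ids: 2 is the model's EOS token, 32000/32001 are extra
    # terminators declared in the model's generation_config.
    params = MiniSamplingParams(stop_token_ids=[7])
    params.update_from_generation_config({"eos_token_id": [2, 32000, 32001]},
                                         model_eos_token_id=2)
    print(sorted(params.stop_token_ids))      # [7, 32000, 32001]
    print(sorted(params.all_stop_token_ids))  # [2, 7, 32000, 32001]

Running the sketch prints [7, 32000, 32001] for stop_token_ids and [2, 7, 32000, 32001] for all_stop_token_ids, showing that the min_tokens suppression set now covers the secondary EOS ids as well as the primary one.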