diff --git a/engines/python/setup/djl_python/transformers_neuronx_scheduler/optimum_neuron_scheduler.py b/engines/python/setup/djl_python/transformers_neuronx_scheduler/optimum_neuron_scheduler.py
index d75a0600b..b53f9ce07 100644
--- a/engines/python/setup/djl_python/transformers_neuronx_scheduler/optimum_neuron_scheduler.py
+++ b/engines/python/setup/djl_python/transformers_neuronx_scheduler/optimum_neuron_scheduler.py
@@ -190,6 +190,15 @@ def increment_cache_id(self):
     def trim_cache_id(self):
         self._cache_id = self._cache_id.max()
 
+    def is_slot_eos_token(self, token) -> bool:
+        if hasattr(self._generation_config, "eos_token_id"):
+            if isinstance(self._generation_config.eos_token_id, int):
+                return token == self._generation_config.eos_token_id
+            else:
+                return token in self._generation_config.eos_token_id
+        else:
+            return False
+
     @property
     def stopped(self) -> bool:
         return self._selector.stopping_criteria(self._tokens, None)
@@ -381,7 +390,9 @@ def _generate_token(
             slot.append(next_token, next_token_text)
             generated_text = None
             finish_reason = None
-            if next_token == self.tokenizer.eos_token_id:
+            if slot.is_slot_eos_token(next_token):
+                finish_reason = FinishReason.FINISH_REASON_EOS_TOKEN
+            elif next_token == self.tokenizer.eos_token_id:
                 finish_reason = FinishReason.FINISH_REASON_EOS_TOKEN
             elif slot.stopped:
                 finish_reason = FinishReason.FINISH_REASON_LENGTH
@@ -403,7 +414,8 @@ def _generate_token(
                     token_id=next_token,
                     token_logprob=next_log_prob,
                     token_text=next_token_text,
-                    token_is_special=(next_token in [self.special_tokens]),
+                    token_is_special=(next_token in [self.special_tokens])
+                    or (finish_reason == FinishReason.FINISH_REASON_EOS_TOKEN),
                     generated_text=generated_text,
                 ))
         return generations
diff --git a/engines/python/setup/djl_python/transformers_neuronx_scheduler/token_selector.py b/engines/python/setup/djl_python/transformers_neuronx_scheduler/token_selector.py
index b94149b74..612866b08 100644
--- a/engines/python/setup/djl_python/transformers_neuronx_scheduler/token_selector.py
+++ b/engines/python/setup/djl_python/transformers_neuronx_scheduler/token_selector.py
@@ -13,7 +13,7 @@
 # The below code is heavily inspired from Optimum Neuron under the following link:
 # https://github.com/huggingface/optimum-neuron/blob/974f34336bb36b1b64890c191c558a1575372be7/optimum/neuron/generation/token_selector.py
 import logging
-from typing import Optional
+from typing import Optional, Union, List
 import torch
 from transformers.generation import (
     GenerationConfig,
@@ -63,7 +63,7 @@ def __init__(
         mode: GenerationMode,
         logits_processor: LogitsProcessorList,
         stopping_criteria: StoppingCriteriaList,
-        eos_token_id: int,
+        eos_token_id: Union[List[int], int],
         pad_token_id: int,
         logits_warper: Optional[LogitsProcessorList] = None,
     ):
@@ -139,9 +139,9 @@ def create(cls, input_ids: torch.Tensor,
         # The generation requires special tokens
         eos_token_id = generation_config.eos_token_id
         # This is not supposed to happen for any of the models we support
-        assert eos_token_id is not None and not isinstance(eos_token_id, list)
         if generation_config.pad_token_id is None:
-            generation_config.pad_token_id = eos_token_id
+            generation_config.pad_token_id = eos_token_id if isinstance(
+                eos_token_id, int) else eos_token_id[0]
 
         generation_mode = model._get_generation_mode(generation_config, None)
         if generation_mode not in [
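
Note: the helper below is a minimal standalone sketch, not part of the diff above, of the eos_token_id handling that Slot.is_slot_eos_token introduces. It assumes only transformers' GenerationConfig, which accepts eos_token_id as either an int or a list of ints; the function name is_eos_token is illustrative, and unlike the PR's method it also tolerates an explicit eos_token_id=None.

from typing import List, Optional, Union

from transformers.generation import GenerationConfig


def is_eos_token(generation_config: GenerationConfig, token: int) -> bool:
    # Mirror Slot.is_slot_eos_token: accept the legacy int form or the
    # newer list form of eos_token_id, and treat a missing/None value as
    # "never an EOS token".
    eos_token_id: Optional[Union[int, List[int]]] = getattr(
        generation_config, "eos_token_id", None)
    if eos_token_id is None:
        return False
    if isinstance(eos_token_id, int):
        return token == eos_token_id
    return token in eos_token_id


# Both config forms behave identically through the helper:
assert is_eos_token(GenerationConfig(eos_token_id=2), 2)
assert is_eos_token(GenerationConfig(eos_token_id=[2, 32000]), 32000)
assert not is_eos_token(GenerationConfig(), 2)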