diff --git a/engines/python/setup/djl_python/transformers_neuronx_scheduler/optimum_modeling.py b/engines/python/setup/djl_python/transformers_neuronx_scheduler/optimum_modeling.py
index b6f24d4ac..bbf1e019d 100644
--- a/engines/python/setup/djl_python/transformers_neuronx_scheduler/optimum_modeling.py
+++ b/engines/python/setup/djl_python/transformers_neuronx_scheduler/optimum_modeling.py
@@ -25,7 +25,7 @@
 from transformers import PretrainedConfig
 from transformers_neuronx import bucket
 from transformers_neuronx.constants import LAYOUT_BSH, LAYOUT_HSB
-from optimum.neuron.generation import TokenSelector
+from djl_python.transformers_neuronx_scheduler.optimum_token_selector import OptimumTokenSelector
 from optimum.neuron.utils.version_utils import check_compiler_compatibility, get_neuronxcc_version
 from optimum.modeling_base import OptimizedModel
 from transformers.generation import StoppingCriteriaList
@@ -238,11 +238,12 @@ def generate(
         self._validate_model_kwargs(model_kwargs)
 
         # Instantiate a TokenSelector for the specified configuration
-        selector = TokenSelector.create(input_ids,
-                                        generation_config,
-                                        self,
-                                        self.max_length,
-                                        stopping_criteria=stopping_criteria)
+        selector = OptimumTokenSelector.create(
+            input_ids,
+            generation_config,
+            self,
+            self.max_length,
+            stopping_criteria=stopping_criteria)
 
         # Verify that the inputs are compatible with the model static input dimensions
         batch_size, sequence_length = input_ids.shape
@@ -280,7 +281,7 @@ def generate(
     def generate_tokens(
         self,
         input_ids: torch.LongTensor,
-        selector: TokenSelector,
+        selector: OptimumTokenSelector,
         batch_size: int,
         attention_mask: Optional[torch.Tensor] = None,
         **model_kwargs,
@@ -291,7 +292,7 @@ def generate_tokens(
         Args:
             input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                 The sequence used as a prompt for the generation.
-            selector (`TokenSelector`):
+            selector (`OptimumTokenSelector`):
                 The object implementing the generation logic based on transformers processors and stopping criterias.
             batch_size (`int`):
                 The actual input batch size. Used to avoid generating tokens for padded inputs.
diff --git a/engines/python/setup/djl_python/transformers_neuronx_scheduler/optimum_neuron_scheduler.py b/engines/python/setup/djl_python/transformers_neuronx_scheduler/optimum_neuron_scheduler.py
index 61b2b6a2b..24f689196 100644
--- a/engines/python/setup/djl_python/transformers_neuronx_scheduler/optimum_neuron_scheduler.py
+++ b/engines/python/setup/djl_python/transformers_neuronx_scheduler/optimum_neuron_scheduler.py
@@ -24,7 +24,6 @@
 from dataclasses import dataclass
 
 from djl_python.transformers_neuronx_scheduler.slot import Slot
-from djl_python.rolling_batch.rolling_batch import filter_unused_generation_params
 from djl_python.request import Request
 from djl_python.transformers_neuronx_scheduler.token_selector import TokenSelector
 from djl_python.transformers_neuronx_scheduler.speculation import (
diff --git a/engines/python/setup/djl_python/transformers_neuronx_scheduler/optimum_token_selector.py b/engines/python/setup/djl_python/transformers_neuronx_scheduler/optimum_token_selector.py
new file mode 100644
index 000000000..c8631b172
--- /dev/null
+++ b/engines/python/setup/djl_python/transformers_neuronx_scheduler/optimum_token_selector.py
@@ -0,0 +1,222 @@
+#!/usr/bin/env python
+#
+# Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file
+# except in compliance with the License. A copy of the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS"
+# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
+# the specific language governing permissions and limitations under the License.
+# The code below is heavily inspired by Optimum Neuron:
+# https://github.com/huggingface/optimum-neuron/blob/main/optimum/neuron/generation/token_selector.py
+
+import copy
+import logging
+from typing import TYPE_CHECKING, List, Optional
+
+import torch
+from transformers.generation import (
+    GenerationConfig,
+    GenerationMixin,
+    LogitsProcessorList,
+    StoppingCriteriaList,
+)
+from transformers.generation.utils import GenerationMode
+
+from optimum.neuron.generation import FusedLogitsWarper
+
+if TYPE_CHECKING:
+    from transformers import PreTrainedTokenizer
+
+logger = logging.getLogger(__name__)
+
+
+# TODO: This is a temporary solution to avoid Optimum's dependency on transformers<4.42.
+class OptimumTokenSelector:
+    """Implements the token selection logic corresponding to a generation configuration.
+
+    This class combines and uses the logits processors and stopping criteria implemented in
+    the transformers library.
+
+    The algorithm to select these objects is heavily inspired by the transformers `GenerationMixin.generate()`
+    method, but the actual token selection methods are specific.
+
+    This class does not inherit from `GenerationMixin` because it does not
+    include the code to produce the token logits.
+    Separating the production of the token logits from the token selection allows this class
+    to be used with different generation paradigms, either synchronously using a single `OptimumTokenSelector`
+    in `GenerationMixin.generate()` or asynchronously using multiple `OptimumTokenSelector` instances inside an inference endpoint.
+
+    The constructor of this class should not be called directly: instances should be obtained by
+    calling `OptimumTokenSelector.create()`.
+    """
+
+    def __init__(
+        self,
+        mode: GenerationMode,
+        logits_processor: LogitsProcessorList,
+        stopping_criteria: StoppingCriteriaList,
+        eos_token_ids: List[int],
+        pad_token_id: int,
+        logits_warper: Optional[LogitsProcessorList] = None,
+        seed: Optional[int] = 0,
+    ):
+        self.mode = mode
+        self.logits_processor = logits_processor
+        self.stopping_criteria = stopping_criteria
+        self.eos_token_ids = eos_token_ids
+        self.pad_token_id = pad_token_id
+        self.logits_warper = logits_warper
+        self.generator = torch.Generator()
+        self.generator.manual_seed(seed)
+
+    @classmethod
+    def create(
+        cls,
+        input_ids: torch.Tensor,
+        generation_config: GenerationConfig,
+        model: GenerationMixin,
+        max_seq_length: int,
+        stopping_criteria: Optional[StoppingCriteriaList] = None,
+        tokenizer: Optional["PreTrainedTokenizer"] = None,
+        seed: Optional[int] = 0,
+    ) -> "OptimumTokenSelector":
+        r"""Creates the `OptimumTokenSelector` for a specific generation configuration.
+
+        Args:
+            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+                The sequence used as a prompt for the generation.
+            generation_config (`~transformers.generation.GenerationConfig`):
+                The generation configuration to parametrize the token selection.
+            model (`~transformers.generation.GenerationMixin`):
+                The model providing the internal helpers for selecting the logits processors and stopping criteria.
+            max_seq_length (`int`):
+                The maximum number of input + generated tokens for this model. It depends on the model compilation parameters.
+            stopping_criteria (`Optional[transformers.generation.StoppingCriteriaList]`, defaults to `None`):
+                Custom stopping criteria that complement the default stopping criteria built from arguments and a
+                generation config.
+            tokenizer (`Optional[transformers.PreTrainedTokenizer]`, defaults to `None`):
+                A tokenizer used when stop strings are passed to generate.
+            seed (`Optional[int]`, defaults to `0`):
+                The optional seed for sampling.
+        Return:
+            `OptimumTokenSelector`: The selector instance for the specified configuration.
+        """
+        generation_config.validate()
+        generation_config = copy.deepcopy(generation_config)
+
+        unsupported_generation_flags = [
+            "output_attentions",
+            "output_hidden_states",
+            "output_scores",
+            "return_dict_in_generate",
+        ]
+        for flag in unsupported_generation_flags:
+            if getattr(generation_config, flag, False):
+                raise ValueError(f"{flag} is not supported for generation.")
+
+        if generation_config.max_new_tokens is not None:
+            logger.warning(
+                f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length` (="
+                f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
+                "Please refer to the documentation for more information. "
+                "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
+            )
+            generation_config.max_length = generation_config.max_new_tokens + input_ids.shape[
+                -1]
+
+        min_length = generation_config.min_length
+        if min_length > max_seq_length:
+            raise ValueError(
+                f"The minimum generation length ({min_length}) exceeds the model maximum sequence length ({max_seq_length})"
+            )
+        max_length = generation_config.max_length
+        if max_length > max_seq_length:
+            logger.warning(
+                f"Adjusting the maximum generation length ({max_length}) to the model maximum sequence length ({max_seq_length})"
+            )
+            generation_config.max_length = max_seq_length
+
+        # The generation requires special tokens
+        eos_token_id = generation_config.eos_token_id
+        # This is not supposed to happen for any of the models we support
+        assert eos_token_id is not None
+        eos_token_ids = eos_token_id if isinstance(eos_token_id,
+                                                   list) else [eos_token_id]
+        generation_config._eos_token_tensor = torch.tensor(
+            eos_token_ids, device=input_ids.device)
+        if generation_config.pad_token_id is None:
+            logger.warning(
+                f"Setting `pad_token_id` to `eos_token_id`:{eos_token_ids[0]} for open-ended generation."
+            )
+            generation_config.pad_token_id = eos_token_ids[0]
+
+        # Instantiate transformers library processors and criteria
+        logits_processor = model._get_logits_processor(
+            generation_config,
+            input_ids_seq_length=input_ids.shape[-1],
+            encoder_input_ids=input_ids,
+            prefix_allowed_tokens_fn=None,
+            logits_processor=LogitsProcessorList(),
+        )
+        if stopping_criteria is None:
+            stopping_criteria = StoppingCriteriaList()
+        stopping_criteria = model._get_stopping_criteria(
+            generation_config,
+            stopping_criteria=stopping_criteria,
+            tokenizer=tokenizer)
+
+        generation_mode = generation_config.get_generation_mode()
+        if generation_mode not in [
+                GenerationMode.GREEDY_SEARCH, GenerationMode.SAMPLE
+        ]:
+            raise ValueError("Unsupported generation mode")
+
+        logits_warper = None
+        if generation_mode == GenerationMode.SAMPLE:
+            logits_warper = FusedLogitsWarper.from_config(generation_config)
+
+        return cls(
+            mode=generation_mode,
+            logits_processor=logits_processor,
+            stopping_criteria=stopping_criteria,
+            logits_warper=logits_warper,
+            eos_token_ids=eos_token_ids,
+            pad_token_id=generation_config.pad_token_id,
+            seed=seed,
+        )
+
+    def select(self, input_ids: torch.LongTensor,
+               logits: torch.Tensor) -> torch.LongTensor:
+        """Select the next tokens from the candidate logits.
+
+        Args:
+            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+                The sequence used as a prompt for the generation (not used in all generation modes).
+            logits (`torch.Tensor` of shape `(batch_size, vocab_size)`):
+                The logits corresponding to the generated tokens.
+
+        Return:
+            `torch.LongTensor`: A `torch.LongTensor` containing the selected tokens.
+        """
+        scores = self.logits_processor(input_ids, logits)
+        if self.mode == GenerationMode.SAMPLE:
+            return self._sample(scores)
+        else:
+            return torch.argmax(scores, dim=-1)
+
+    def _sample(self, scores: torch.Tensor) -> torch.LongTensor:
+        # Get [batch_size, kept] scores and indices instead of [batch_size, vocab_size] scores
+        scores, next_token_indices = self.logits_warper(scores)
+
+        # Sample from the renormalized distribution over the kept tokens
+        probs = torch.nn.functional.softmax(scores, dim=-1)
+        next_tokens = torch.multinomial(probs,
+                                        num_samples=1,
+                                        generator=self.generator)
+        # Convert the filtered tokens to actual vocabulary tokens
+        next_tokens = torch.gather(next_token_indices, 1, next_tokens)
+        return next_tokens.squeeze(1)
diff --git a/engines/python/setup/djl_python/transformers_neuronx_scheduler/token_selector.py b/engines/python/setup/djl_python/transformers_neuronx_scheduler/token_selector.py
index c4d88ce1c..1a82bd5ff 100644
--- a/engines/python/setup/djl_python/transformers_neuronx_scheduler/token_selector.py
+++ b/engines/python/setup/djl_python/transformers_neuronx_scheduler/token_selector.py
@@ -169,7 +169,8 @@ def create(
 
         logits_warper = None
         if generation_mode == GenerationMode.SAMPLE:
-            logits_warper = model._get_logits_warper(generation_config)
+            logits_warper = model._get_logits_warper(generation_config,
+                                                     device=model.device)
             if len(logits_warper) == 0:
                 generation_mode = GenerationMode.GREEDY_SEARCH
 
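Usage sketch (illustrative only, not part of this patch): the new `OptimumTokenSelector` is meant to be driven by a caller that produces the logits itself, as `generate_tokens()` does in `optimum_modeling.py` above. The sketch below assumes `transformers>=4.42` and `optimum-neuron` are installed; the `gpt2` checkpoint, the prompt, and `max_seq_length=1024` are placeholders standing in for the Neuron-compiled model and its compile-time sequence length, since `create()` only relies on `GenerationMixin` helpers.

import copy

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from djl_python.transformers_neuronx_scheduler.optimum_token_selector import OptimumTokenSelector

# Placeholder model/tokenizer; the scheduler would use a Neuron-compiled model here.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

# Build a sampling configuration on top of the model's own defaults,
# which already carry a valid eos_token_id.
generation_config = copy.deepcopy(model.generation_config)
generation_config.max_new_tokens = 16
generation_config.do_sample = True
generation_config.top_k = 50

input_ids = tokenizer("Hello, my name is", return_tensors="pt").input_ids
selector = OptimumTokenSelector.create(input_ids,
                                       generation_config,
                                       model,
                                       max_seq_length=1024)

# The caller owns the forward pass; the selector only turns the
# last-position logits into next-token ids.
for _ in range(generation_config.max_new_tokens):
    logits = model(input_ids).logits[:, -1, :]
    next_tokens = selector.select(input_ids, logits)
    input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
    if next_tokens[0].item() in selector.eos_token_ids:
        break

print(tokenizer.decode(input_ids[0], skip_special_tokens=True))

This separation is what lets the same selector serve both the synchronous `generate()` path and the asynchronous per-request slots in the rolling-batch scheduler.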