From 0b7c32f7c6e879d98d717a3896e963b581cf819b Mon Sep 17 00:00:00 2001
From: Urszula Golowicz
Date: Tue, 3 Dec 2024 17:32:32 +0100
Subject: [PATCH] Restore performance in generate (#1546)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Signed-off-by: Urszula Golowicz
Co-authored-by: Marcin Łapiński
Co-authored-by: Adam Stachowicz <105052242+astachowiczhabana@users.noreply.github.com>
---
 .../habana/transformers/generation/utils.py | 18 +++++++++---------
 1 file changed, 9 insertions(+), 9 deletions(-)

diff --git a/optimum/habana/transformers/generation/utils.py b/optimum/habana/transformers/generation/utils.py
index 66a94ff65c..9468bea956 100644
--- a/optimum/habana/transformers/generation/utils.py
+++ b/optimum/habana/transformers/generation/utils.py
@@ -117,6 +117,8 @@
     "deepseek_v2",
 ]
 
+# The initial generated token index is set to 1 to accommodate the SOS (start of string) token.
+INITIAL_TOKEN_IDX = 1
 
 logger = logging.get_logger(__name__)
 
@@ -1149,7 +1151,7 @@ def generate(
             else:
                 assert generation_config.bucket_size <= 0, "Untested path for bucket>0"
             if model_kwargs.get("decoder_input_ids", None) is None:
-                token_idx = 1
+                token_idx = INITIAL_TOKEN_IDX
             else:
                 token_idx = model_kwargs["decoder_input_ids"].shape[-1]
             model_kwargs["token_idx"] = torch.tensor(token_idx, device=inputs_tensor.device)
@@ -2614,14 +2616,12 @@ def _sample(
 
         if batch_size > 1 and has_eos_stopping_criteria:
             eos_token_id = generation_config.eos_token_id
-            idx_bs = generation_config.max_length
-            for i in range(batch_size):
-                for idx in range(len(input_ids[i])):
-                    if input_ids[i][idx] == eos_token_id:
-                        idx_bs = idx
-                    if idx > idx_bs:
-                        input_ids[i][idx] = pad_token_id
-                idx_bs = generation_config.max_length
+            # Find the positions of the first eos_token_id in each sequence
+            eos_positions = (input_ids[:, INITIAL_TOKEN_IDX:] == eos_token_id).int().argmax(dim=1) + INITIAL_TOKEN_IDX
+            # Create a mask for positions greater than the first eos_token_id
+            mask = torch.arange(max_length).expand(batch_size, max_length) > eos_positions.unsqueeze(1)
+            # Apply the mask to set positions greater than the first eos_token_id to pad_token_id
+            input_ids[mask] = pad_token_id
 
         if return_dict_in_generate:
             if self.config.is_encoder_decoder:
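
For reference, below is a minimal standalone sketch of the vectorized padding step introduced by this patch. The token ids, sequence values, and eos/pad ids are illustrative assumptions, not values from the library; only the eos_positions/mask computation mirrors the patched code.

import torch

# Same constant as in the patch: generated tokens start at index 1 (index 0 holds the SOS token).
INITIAL_TOKEN_IDX = 1
eos_token_id = 2   # assumed EOS id for this example
pad_token_id = 0   # assumed pad id for this example

# Two generated sequences of length 8; EOS (2) appears at a different position in each row.
input_ids = torch.tensor(
    [
        [5, 7, 9, 2, 9, 9, 9, 9],
        [5, 7, 9, 9, 9, 2, 9, 9],
    ]
)
batch_size, max_length = input_ids.shape

# First EOS position per row: argmax over a boolean tensor returns the index of the first True.
# (argmax returns 0 when no element matches, so rows without any EOS need separate consideration.)
eos_positions = (input_ids[:, INITIAL_TOKEN_IDX:] == eos_token_id).int().argmax(dim=1) + INITIAL_TOKEN_IDX

# Mask every position strictly after the first EOS and overwrite it with the pad token.
mask = torch.arange(max_length).expand(batch_size, max_length) > eos_positions.unsqueeze(1)
input_ids[mask] = pad_token_id

print(input_ids)
# tensor([[5, 7, 9, 2, 0, 0, 0, 0],
#         [5, 7, 9, 9, 9, 2, 0, 0]])

This single batched pass replaces the per-token Python loop removed in the last hunk, which is where the restored performance comes from: the work moves from Python-level iteration over every token of every sequence to a handful of tensor operations.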