From 25306d7a82b83a4c33ce8f633779a3e1712655b7 Mon Sep 17 00:00:00 2001
From: "Wang, Yi"
Date: Sat, 7 Dec 2024 05:24:34 +0800
Subject: [PATCH] Fix mllama test (#1569)

Signed-off-by: Wang, Yi A
---
 optimum/habana/transformers/generation/utils.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/optimum/habana/transformers/generation/utils.py b/optimum/habana/transformers/generation/utils.py
index 9468bea956..6b5a2534c3 100644
--- a/optimum/habana/transformers/generation/utils.py
+++ b/optimum/habana/transformers/generation/utils.py
@@ -2407,9 +2407,11 @@ def _sample(
         assert "position_ids" not in model_kwargs, "Untested path"
 
         token_idx = model_kwargs.get("token_idx", None)
+        start_token_idx = cur_len
         if token_idx is not None:
             # Update cur_len in case of static shapes
             cur_len = (token_idx + model_kwargs.get("inputs_embeds_offset", 0)).item()
+            start_token_idx = token_idx
 
         time_to_first_token_done = False
         model_kwargs["pad_done"] = False
@@ -2617,7 +2619,10 @@ def _sample(
         if batch_size > 1 and has_eos_stopping_criteria:
             eos_token_id = generation_config.eos_token_id
             # Find the positions of the first eos_token_id in each sequence
-            eos_positions = (input_ids[:, INITIAL_TOKEN_IDX:] == eos_token_id).int().argmax(dim=1) + INITIAL_TOKEN_IDX
+            eos_positions = (
+                torch.isin(input_ids[:, start_token_idx:], torch.tensor(eos_token_id)).int().argmax(dim=1)
+                + start_token_idx
+            )
             # Create a mask for positions greater than the first eos_token_id
             mask = torch.arange(max_length).expand(batch_size, max_length) > eos_positions.unsqueeze(1)
             # Apply the mask to set positions greater than the first eos_token_id to pad_token_id
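
Note: the following is a standalone sketch, not part of the patch, illustrating the padding logic the second hunk implements. All values (pad_token_id, eos_token_id, start_token_idx, input_ids) are toy data; in the real code they come from generation_config and the static-shape token_idx handling added in the first hunk. torch.isin accepts a list of EOS ids, which a plain equality comparison against a scalar cannot, and slicing from start_token_idx keeps EOS tokens that appear inside the prompt from being mistaken for the end of generation.

import torch

pad_token_id = 0
eos_token_id = [2, 3]          # multiple EOS ids, as generation_config may provide
start_token_idx = 3            # prompt length; generated tokens start at this index
max_length = 8

# Two sequences: a prompt of length 3, then generated tokens with EOS at different spots.
input_ids = torch.tensor(
    [
        [5, 2, 7, 9, 8, 2, 4, 6],  # EOS (2) inside the prompt must be ignored
        [5, 6, 7, 9, 3, 1, 1, 1],  # EOS (3) right after the prompt
    ]
)
batch_size = input_ids.shape[0]

# First EOS position in the generated part of each sequence,
# shifted back to absolute positions by adding start_token_idx.
eos_positions = (
    torch.isin(input_ids[:, start_token_idx:], torch.tensor(eos_token_id)).int().argmax(dim=1)
    + start_token_idx
)

# Everything after the first EOS becomes pad_token_id.
mask = torch.arange(max_length).expand(batch_size, max_length) > eos_positions.unsqueeze(1)
input_ids[mask] = pad_token_id
print(input_ids)
# tensor([[5, 2, 7, 9, 8, 2, 0, 0],
#         [5, 6, 7, 9, 3, 0, 0, 0]])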