From 134eabdc2123cb0fbedacb0e9e5b04acdac93e29 Mon Sep 17 00:00:00 2001
From: Yun Liu
Date: Mon, 2 Dec 2024 09:46:27 +0800
Subject: [PATCH] Fix generation results being erased within a batch when
 ignore_eos=False and the model's pad_token == eos_token (e.g. Llama3)

When inputs are left-padded and pad_token_id == eos_token_id, the old
masking loop treated the padding prefix as the first EOS, so the real
prompt and generated tokens after it were overwritten with pad_token.
Skip the left-padding prefix before searching for the first EOS, then
pad out only the tokens generated after it.
---
 .../habana/transformers/generation/utils.py | 23 +++++++++++++------
 1 file changed, 16 insertions(+), 7 deletions(-)

diff --git a/optimum/habana/transformers/generation/utils.py b/optimum/habana/transformers/generation/utils.py
index 453e1f22d1..a390794c94 100644
--- a/optimum/habana/transformers/generation/utils.py
+++ b/optimum/habana/transformers/generation/utils.py
@@ -2612,15 +2612,24 @@ def _sample(
             streamer.end()
 
         if batch_size > 1 and has_eos_stopping_criteria:
+            # mask the over-generated tokens after the first eos_token with pad_token
             eos_token_id = generation_config.eos_token_id
-            idx_bs = generation_config.max_length
+            def find_first_eos_token_idx_in_input_ids(batch_id) -> int:
+                idx = 0
+                max_length = len(input_ids[batch_id])
+                while idx < max_length and input_ids[batch_id][idx] == pad_token_id:
+                    idx += 1
+                if isinstance(eos_token_id, list):
+                    while idx < max_length and input_ids[batch_id][idx] not in eos_token_id:
+                        idx += 1
+                elif isinstance(eos_token_id, int):
+                    while idx < max_length and input_ids[batch_id][idx] != eos_token_id:
+                        idx += 1
+                return idx
             for i in range(batch_size):
-                for idx in range(len(input_ids[i])):
-                    if input_ids[i][idx] == eos_token_id:
-                        idx_bs = idx
-                    if idx > idx_bs:
-                        input_ids[i][idx] = pad_token_id
-                idx_bs = generation_config.max_length
+                eos_idx = find_first_eos_token_idx_in_input_ids(i)
+                for j in range(eos_idx + 1, len(input_ids[i])):
+                    input_ids[i][j] = pad_token_id
 
         if return_dict_in_generate:
             if self.config.is_encoder_decoder:
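
Note for reviewers: below is a minimal, self-contained sketch of the masking
logic this patch introduces, runnable outside the Optimum Habana codebase.
The helper name mask_after_first_eos and the toy token IDs are illustrative
assumptions, not identifiers from optimum/habana.

    # Sketch of the patch's masking logic (illustrative, not library code).
    # With pad_token == eos_token (as in Llama3), the left-padding prefix
    # must be skipped before searching for the first EOS; otherwise the
    # whole generation after the padding gets erased.

    PAD_TOKEN_ID = 2
    EOS_TOKEN_ID = 2  # same ID as the pad token: the problematic case

    def mask_after_first_eos(seq, pad_token_id, eos_token_id):
        eos_ids = eos_token_id if isinstance(eos_token_id, list) else [eos_token_id]
        idx = 0
        # 1. Skip the left-padding prefix (safe even when pad == eos).
        while idx < len(seq) and seq[idx] == pad_token_id:
            idx += 1
        # 2. Find the first EOS of the actual generation.
        while idx < len(seq) and seq[idx] not in eos_ids:
            idx += 1
        # 3. Replace everything after that EOS with pad tokens.
        return seq[: idx + 1] + [pad_token_id] * (len(seq) - idx - 1)

    # Left-padded prompt [2, 2], generation [5, 6, 2 (EOS), 7, 8 (overflow)]:
    print(mask_after_first_eos([2, 2, 5, 6, 2, 7, 8], PAD_TOKEN_ID, EOS_TOKEN_ID))
    # -> [2, 2, 5, 6, 2, 2, 2]: only the overflow after the real EOS is masked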