From df6b9198bf33b345dc13fe53e98f692262826e70 Mon Sep 17 00:00:00 2001
From: regisss <15324346+regisss@users.noreply.github.com>
Date: Sun, 20 Oct 2024 16:18:39 +0200
Subject: [PATCH] Fix Llama 3.1 generation (#1444)

---
 examples/language-modeling/run_lora_clm.py |  4 ----
 examples/text-generation/utils.py          | 22 ++++++++++++++--------
 2 files changed, 14 insertions(+), 12 deletions(-)

diff --git a/examples/language-modeling/run_lora_clm.py b/examples/language-modeling/run_lora_clm.py
index 0b16be0725..1c6e29da25 100644
--- a/examples/language-modeling/run_lora_clm.py
+++ b/examples/language-modeling/run_lora_clm.py
@@ -700,10 +700,6 @@ def main():
         raise ValueError("Must provide model_name_or_path to load a pretrained CausalLM model.")
 
     if model.config.model_type == "llama":
-        # unwind broken decapoda-research config
-        model.generation_config.pad_token_id = 0
-        model.generation_config.bos_token_id = 1
-        model.generation_config.eos_token_id = 2
         if model_args.attn_softmax_bf16:
             model.generation_config.attn_softmax_bf16 = True
         if model_args.use_flash_attention:
diff --git a/examples/text-generation/utils.py b/examples/text-generation/utils.py
index 2b0fe0d328..cb734071b0 100644
--- a/examples/text-generation/utils.py
+++ b/examples/text-generation/utils.py
@@ -525,16 +525,22 @@ def setup_tokenizer(args, model, assistant_model):
     tokenizer.padding_side = "left"
 
     if model.config.model_type == "llama":
-        # unwind broken decapoda-research config
-        model.generation_config.pad_token_id = 0
-        model.generation_config.bos_token_id = 1
-        model.generation_config.eos_token_id = 2
+        if model.generation_config.pad_token_id is None:
+            if isinstance(model.generation_config.eos_token_id, int):
+                model.generation_config.pad_token_id = model.generation_config.eos_token_id
+            elif isinstance(model.generation_config.eos_token_id, list):
+                model.generation_config.pad_token_id = model.generation_config.eos_token_id[0]
         if assistant_model is not None:
-            assistant_model.generation_config.pad_token_id = 0
-            assistant_model.generation_config.bos_token_id = 1
-            assistant_model.generation_config.eos_token_id = 2
+            if assistant_model.generation_config.pad_token_id is None:
+                if isinstance(assistant_model.generation_config.eos_token_id, int):
+                    assistant_model.generation_config.pad_token_id = assistant_model.generation_config.eos_token_id
+                elif isinstance(assistant_model.generation_config.eos_token_id, list):
+                    assistant_model.generation_config.pad_token_id = assistant_model.generation_config.eos_token_id[0]
         tokenizer.bos_token_id = model.generation_config.bos_token_id
-        tokenizer.eos_token_id = model.generation_config.eos_token_id
+        if isinstance(model.generation_config.eos_token_id, int):
+            tokenizer.eos_token_id = model.generation_config.eos_token_id
+        elif isinstance(model.generation_config.eos_token_id, list):
+            tokenizer.eos_token_id = model.generation_config.eos_token_id[0]
         tokenizer.pad_token_id = model.generation_config.pad_token_id
         tokenizer.pad_token = tokenizer.decode(tokenizer.pad_token_id)
         tokenizer.eos_token = tokenizer.decode(tokenizer.eos_token_id)
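
For review context, below is a minimal standalone sketch of the fallback the patch introduces. The helper name `resolve_pad_token_id` and the concrete token ids are illustrative only, not part of the patch; the point is that Llama 3.1 checkpoints ship `eos_token_id` as a list, which is why the old hardcoded decapoda-research ids (pad=0, bos=1, eos=2) broke generation.

    from transformers import GenerationConfig

    def resolve_pad_token_id(generation_config):
        # Mirror the patch's logic: only fill in pad_token_id when it is
        # unset, reusing eos_token_id and taking the first entry when the
        # checkpoint stores a list of EOS ids.
        if generation_config.pad_token_id is None:
            eos = generation_config.eos_token_id
            if isinstance(eos, int):
                generation_config.pad_token_id = eos
            elif isinstance(eos, list):
                generation_config.pad_token_id = eos[0]
        return generation_config.pad_token_id

    # Config shaped like a Llama 3.1 checkpoint's (ids are examples only):
    cfg = GenerationConfig(bos_token_id=128000, eos_token_id=[128001, 128008, 128009])
    print(resolve_pad_token_id(cfg))  # -> 128001, the first EOS id

The same int-or-list handling is applied to the assistant model's generation config and to `tokenizer.eos_token_id`, so padding and decoding stay consistent whichever form the checkpoint uses.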