From 03e2bc5853d04c3eda56030233d7629bfaca9707 Mon Sep 17 00:00:00 2001 From: Harish Subramony <81822986+hsubramony@users.noreply.github.com> Date: Thu, 6 Mar 2025 06:39:37 -0800 Subject: [PATCH] Fix for text-generation, AttributeError: 'GenerationConfig' object has no attribute 'use_fused_rope' (#1823) Co-authored-by: regisss <15324346+regisss@users.noreply.github.com> --- examples/text-generation/run_lm_eval.py | 3 +++ optimum/habana/transformers/modeling_utils.py | 1 - 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/examples/text-generation/run_lm_eval.py b/examples/text-generation/run_lm_eval.py index 7f0797489f..43eef61b12 100644 --- a/examples/text-generation/run_lm_eval.py +++ b/examples/text-generation/run_lm_eval.py @@ -27,6 +27,7 @@ import psutil import torch import torch.nn.functional as F +import transformers from lm_eval import evaluator, utils from lm_eval.models.huggingface import HFLM, TemplateLM @@ -36,6 +37,7 @@ from transformers.generation import GenerationConfig from utils import finalize_quantization, initialize_model, save_model +from optimum.habana.transformers.generation import GaudiGenerationConfig from optimum.habana.utils import get_hpu_memory_stats @@ -248,6 +250,7 @@ def get_model_dtype(model) -> str: def main() -> None: # Modified based on cli_evaluate function in https://github.com/EleutherAI/lm-evaluation-harness/blob/v0.4.7/lm_eval/__main__.py/#L268 args = setup_lm_eval_parser() + transformers.GenerationConfig = GaudiGenerationConfig model, _, tokenizer, generation_config = initialize_model(args, logger) if args.trust_remote_code: # trust_remote_code fix was introduced in lm_eval 0.4.3 diff --git a/optimum/habana/transformers/modeling_utils.py b/optimum/habana/transformers/modeling_utils.py index 291059cbec..53ab91433b 100644 --- a/optimum/habana/transformers/modeling_utils.py +++ b/optimum/habana/transformers/modeling_utils.py @@ -376,7 +376,6 @@ def adapt_transformers_to_gaudi(): GaudiGenerationMixin._prepare_cache_for_generation ) transformers.generation.GenerationConfig = GaudiGenerationConfig - transformers.GenerationConfig = GaudiGenerationConfig transformers.generation.configuration_utils.GenerationConfig = GaudiGenerationConfig transformers.modeling_utils.GenerationConfig = GaudiGenerationConfig transformers.generation.MaxLengthCriteria.__call__ = gaudi_MaxLengthCriteria_call