Skip to content

Commit

Permalink
Fix for text-generation, AttributeError: 'GenerationConfig' object has no attribute 'use_fused_rope' (#1823)
Browse files Browse the repository at this point in the history

Co-authored-by: regisss <15324346+regisss@users.noreply.github.com>
  • Loading branch information
hsubramony and regisss authored Mar 6, 2025
1 parent cebe3ab commit 03e2bc5
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 1 deletion.
3 changes: 3 additions & 0 deletions examples/text-generation/run_lm_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import psutil
import torch
import torch.nn.functional as F
import transformers
from lm_eval import evaluator, utils
from lm_eval.models.huggingface import HFLM, TemplateLM

Expand All @@ -36,6 +37,7 @@
from transformers.generation import GenerationConfig
from utils import finalize_quantization, initialize_model, save_model

from optimum.habana.transformers.generation import GaudiGenerationConfig
from optimum.habana.utils import get_hpu_memory_stats


Expand Down Expand Up @@ -248,6 +250,7 @@ def get_model_dtype(model) -> str:
def main() -> None:
# Modified based on cli_evaluate function in https://github.com/EleutherAI/lm-evaluation-harness/blob/v0.4.7/lm_eval/__main__.py/#L268
args = setup_lm_eval_parser()
transformers.GenerationConfig = GaudiGenerationConfig
model, _, tokenizer, generation_config = initialize_model(args, logger)
if args.trust_remote_code:
# trust_remote_code fix was introduced in lm_eval 0.4.3
Expand Down
1 change: 0 additions & 1 deletion optimum/habana/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,7 +376,6 @@ def adapt_transformers_to_gaudi():
GaudiGenerationMixin._prepare_cache_for_generation
)
transformers.generation.GenerationConfig = GaudiGenerationConfig
transformers.GenerationConfig = GaudiGenerationConfig
transformers.generation.configuration_utils.GenerationConfig = GaudiGenerationConfig
transformers.modeling_utils.GenerationConfig = GaudiGenerationConfig
transformers.generation.MaxLengthCriteria.__call__ = gaudi_MaxLengthCriteria_call
Expand Down

0 comments on commit 03e2bc5

Please sign in to comment.