From 04901ae0746548856463e733a04eb89f40812903 Mon Sep 17 00:00:00 2001
From: regisss <15324346+regisss@users.noreply.github.com>
Date: Mon, 9 Dec 2024 16:10:56 -0600
Subject: [PATCH] Revert PR #1473 (#1582)

---
 .../text-generation/requirements_lm_eval.txt |  3 ++-
 examples/text-generation/run_lm_eval.py      | 19 ++++++++++++++-----
 2 files changed, 16 insertions(+), 6 deletions(-)

diff --git a/examples/text-generation/requirements_lm_eval.txt b/examples/text-generation/requirements_lm_eval.txt
index 272b9365db..e632dc1236 100644
--- a/examples/text-generation/requirements_lm_eval.txt
+++ b/examples/text-generation/requirements_lm_eval.txt
@@ -1 +1,2 @@
-https://github.com/EleutherAI/lm-evaluation-harness/archive/c1d8795da7610d507cb191c2769c5e7bf1060a35.zip
+https://github.com/EleutherAI/lm-evaluation-harness/archive/0bf683b4e6a9df359b3156ba9ba8d62bdd47e0c0.zip
+datasets==2.21.0
diff --git a/examples/text-generation/run_lm_eval.py b/examples/text-generation/run_lm_eval.py
index bec3ec6f6e..689860fc7c 100644
--- a/examples/text-generation/run_lm_eval.py
+++ b/examples/text-generation/run_lm_eval.py
@@ -29,7 +29,6 @@
 import psutil
 import torch
 import torch.nn.functional as F
-from lm_eval.models.huggingface import HFLM
 
 # Local imports
 from run_generation import setup_parser
@@ -92,15 +91,17 @@ def setup_lm_eval_parser():
     return args
 
 
-class HabanaModelAdapter(HFLM):
+class HabanaModelAdapter(lm_eval.base.BaseLM):
     def __init__(self, tokenizer, model, args, options):
-        super().__init__(pretrained=model, tokenizer=tokenizer, batch_size=args.batch_size)
+        super().__init__()
         self.tokenizer = tokenizer
+        self.model = model
+        self._batch_size = args.batch_size
         self.buckets = sorted(args.buckets)
         self.options = options
         self._device = args.device
         self.model_inputs = {"use_cache": self.options.use_cache}
-        if self._model.config.model_type in [
+        if self.model.config.model_type in [
             "llama",
             "mistral",
             "falcon",
@@ -136,7 +137,7 @@ def __init__(self, tokenizer, model, args, options):
 
     def warm_up(self):
         for bucket_size in reversed(self.buckets):
-            inps = torch.ones((self.batch_size, bucket_size), dtype=torch.int64)
+            inps = torch.ones((self._batch_size, bucket_size), dtype=torch.int64)
             self._model_call(inps)
         pass
 
@@ -148,6 +149,14 @@ def eot_token_id(self):
     def max_length(self):
         return self.buckets[-1]
 
+    @property
+    def max_gen_toks(self):
+        raise NotImplementedError()
+
+    @property
+    def batch_size(self):
+        return self._batch_size
+
     @property
     def device(self):
         # We need to do padding ourselves, otherwise we'll end up with recompilations
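
For reference, this revert restores the older lm-evaluation-harness `BaseLM` interface, in which the adapter supplies `batch_size`, `max_length`, and `device` itself rather than delegating to `HFLM`. Below is a minimal usage sketch under that assumption: the `SimpleNamespace` fields mirror the attributes that `HabanaModelAdapter.__init__` reads above, while the `lm_eval.evaluator.evaluate` / `lm_eval.tasks.get_task_dict` calls follow the pre-0.4 harness API; the exact values and the "hellaswag" task are illustrative, not part of this patch.

    # Hypothetical usage sketch (not part of the patch): drives the reverted
    # BaseLM-based adapter with the pre-0.4 lm-evaluation-harness API.
    from types import SimpleNamespace

    import lm_eval.evaluator
    import lm_eval.tasks

    # Field names mirror what HabanaModelAdapter.__init__ reads; values are examples.
    args = SimpleNamespace(batch_size=8, buckets=[128, 256, 512], device="hpu")
    options = SimpleNamespace(use_cache=True)

    # tokenizer and model are assumed to have been prepared elsewhere (e.g. by the
    # surrounding run_lm_eval.py setup code).
    adapter = HabanaModelAdapter(tokenizer, model, args, options)
    adapter.warm_up()  # runs _model_call once per bucket size, largest first, to avoid recompilations

    results = lm_eval.evaluator.evaluate(
        lm=adapter,
        task_dict=lm_eval.tasks.get_task_dict(["hellaswag"]),
    )

The warm-up pass matters on HPU because each new input shape triggers a graph compilation; pre-running every bucket size up front keeps that cost out of the measured evaluation.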