Revert PR #1473 (#1582)
Authored by regisss on Dec 9, 2024
Commit 04901ae (1 parent: cfc0c48)
Showing 2 changed files with 16 additions and 6 deletions.
examples/text-generation/requirements_lm_eval.txt (3 changes: 2 additions, 1 deletion)
@@ -1 +1,2 @@
-https://github.com/EleutherAI/lm-evaluation-harness/archive/c1d8795da7610d507cb191c2769c5e7bf1060a35.zip
+https://github.com/EleutherAI/lm-evaluation-harness/archive/0bf683b4e6a9df359b3156ba9ba8d62bdd47e0c0.zip
+datasets==2.21.0
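
The pin above moves the example back to a harness revision that still exposes the lm_eval.base.BaseLM interface subclassed in run_lm_eval.py below, and fixes datasets to 2.21.0. As a rough sketch of how such an adapter is typically driven through that older harness API (the run_eval helper, the example task name and the limit argument are illustrative, not part of this commit):

import lm_eval.evaluator
import lm_eval.tasks


def run_eval(lm, task_names=("hellaswag",), limit=None):
    # `lm` is expected to be an lm_eval.base.BaseLM subclass,
    # e.g. the HabanaModelAdapter defined in run_lm_eval.py.
    task_dict = lm_eval.tasks.get_task_dict(list(task_names))
    # The older evaluator API takes the adapter and a task dict directly.
    return lm_eval.evaluator.evaluate(lm=lm, task_dict=task_dict, limit=limit)
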
examples/text-generation/run_lm_eval.py (19 changes: 14 additions, 5 deletions)
@@ -29,7 +29,6 @@
import psutil
import torch
import torch.nn.functional as F
-from lm_eval.models.huggingface import HFLM

# Local imports
from run_generation import setup_parser
@@ -92,15 +91,17 @@ def setup_lm_eval_parser():
    return args


-class HabanaModelAdapter(HFLM):
+class HabanaModelAdapter(lm_eval.base.BaseLM):
    def __init__(self, tokenizer, model, args, options):
-        super().__init__(pretrained=model, tokenizer=tokenizer, batch_size=args.batch_size)
+        super().__init__()
        self.tokenizer = tokenizer
        self.model = model
        self._batch_size = args.batch_size
        self.buckets = sorted(args.buckets)
        self.options = options
        self._device = args.device
        self.model_inputs = {"use_cache": self.options.use_cache}
-        if self._model.config.model_type in [
+        if self.model.config.model_type in [
            "llama",
            "mistral",
            "falcon",
@@ -136,7 +137,7 @@ def __init__(self, tokenizer, model, args, options):

    def warm_up(self):
        for bucket_size in reversed(self.buckets):
-            inps = torch.ones((self.batch_size, bucket_size), dtype=torch.int64)
+            inps = torch.ones((self._batch_size, bucket_size), dtype=torch.int64)
            self._model_call(inps)
        pass

@@ -148,6 +149,14 @@ def eot_token_id(self):
    def max_length(self):
        return self.buckets[-1]

+    @property
+    def max_gen_toks(self):
+        raise NotImplementedError()
+
+    @property
+    def batch_size(self):
+        return self._batch_size
+
    @property
    def device(self):
        # We need to do padding ourselves, otherwise we'll end up with recompilations
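
Going back to lm_eval.base.BaseLM is why max_gen_toks, batch_size and device reappear as explicit properties: BaseLM leaves them abstract, whereas HFLM supplies them itself. The comment on the device property refers to HPU graph recompilation: inputs are padded up to one of the configured bucket lengths so the model only ever sees a fixed set of shapes. A minimal sketch of that bucket-padding idea (find_bucket and pad_to_bucket are illustrative helpers, not functions from this file):

import torch
import torch.nn.functional as F


def find_bucket(buckets, length):
    # Smallest configured bucket that fits the sequence; buckets is sorted ascending.
    return next(b for b in buckets if b >= length)


def pad_to_bucket(input_ids, buckets, pad_token_id):
    # Pad the sequence dimension up to the bucket size so only a fixed set of
    # shapes is ever compiled, avoiding recompilation on each new length.
    bucket = find_bucket(buckets, input_ids.shape[-1])
    return F.pad(input_ids, (0, bucket - input_ids.shape[-1]), value=pad_token_id)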
