Add Eval() API to support LM-Eval or EvalPlus benchmark harnesses #750

Merged
merged 13 commits on Dec 5, 2024
71 changes: 70 additions & 1 deletion gptqmodel/models/auto.py
@@ -9,7 +9,7 @@
from huggingface_hub import list_repo_files
from transformers import AutoConfig

from ..utils import BACKEND
from ..utils import BACKEND, EVAL, EVALPLUS_TASK, LM_EVAL_TASK
from ..utils.logger import setup_logger
from ..utils.model import check_and_get_model_type
from .base import BaseGPTQModel, QuantizeConfig
@@ -217,3 +217,72 @@ def from_quantized(
verify_hash=verify_hash,
**kwargs,
)

@classmethod
def eval(
cls,
model_id_or_path: str,
framework: EVAL,
tasks: Union[List[LM_EVAL_TASK], List[EVALPLUS_TASK]],
batch: int = 1,
trust_remote_code: bool = False,
):
if framework is None:
raise ValueError("eval parameter: `framework` cannot be set to None")

if not isinstance(tasks, list):
raise ValueError("eval parameter: `tasks` must be of List type")

if framework == EVAL.LM_EVAL:
for task in tasks:
if task not in LM_EVAL_TASK.get_task_enums():
raise ValueError(f"lm_eval support tasks: {LM_EVAL_TASK.get_all_tasks_string()}")

from pathlib import Path

from gptqmodel.utils.eval import lm_eval
from lm_eval.utils import make_table
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(model_id_or_path, trust_remote_code=trust_remote_code)

result_path = Path("lm_eval_results")
result_path.mkdir(parents=True, exist_ok=True)

results = lm_eval(
model_id_or_path,
model_name="hf",
model_args=f"pretrained={model_id_or_path},gptqmodel=True",
tasks=[task.value for task in tasks],
trust_remote_code=trust_remote_code,
batch_size=batch,
apply_chat_template=tokenizer.chat_template is not None,
output_path=str(result_path)
)
print('--------lm_eval Eval Result---------')
print(make_table(results))
if "groups" in results:
print(make_table(results, "groups"))
print('--------lm_eval Result End---------')
return results
elif framework == EVAL.EVALPLUS:
for task in tasks:
if task not in EVALPLUS_TASK.get_task_enums():
raise ValueError(f"evalplus support tasks: {EVALPLUS_TASK.get_all_tasks_string()}")
from gptqmodel.utils.eval import evalplus, evalplus_make_table

results = {}
for task in tasks:
base_formatted, plus_formatted, result_path = evalplus(
model=model_id_or_path,
dataset=task.value,
batch=batch,
trust_remote_code=trust_remote_code,
)
results[task.value] = {"base tests": base_formatted, "base + extra tests": plus_formatted, "results_path": result_path}
print('--------evalplus Eval Result---------')
evalplus_make_table(results)
print('--------evalplus Result End---------')
return results
else:
raise ValueError(f"Eval backend support: {EVAL.get_all_eval_backend_string()}")
1 change: 1 addition & 0 deletions gptqmodel/utils/__init__.py
@@ -1,3 +1,4 @@
from .backend import BACKEND, get_backend
from .eval import EVAL, EVALPLUS_TASK, LM_EVAL_TASK
from .perplexity import Perplexity
from .vram import get_vram
116 changes: 114 additions & 2 deletions gptqmodel/utils/lm_eval.py → gptqmodel/utils/eval.py
@@ -1,6 +1,115 @@
import json
import os
from enum import Enum
from typing import List, Optional, Union


class EVAL(Enum):
LM_EVAL = 0
EVALPLUS = 1

@classmethod
def get_task_enums(cls):
return list(cls)

@classmethod
def get_full_name(cls, member):
return f"{cls.__name__}.{member.name}"

@classmethod
def get_all_eval_backend_string(cls):
full_names = [cls.get_full_name(member) for member in cls]
return ', '.join(full_names)


class LM_EVAL_TASK(Enum):
ARC_CHALLENGE = "arc_challenge"
MMLU = "mmlu"
HELLASWAG = "hellaswag"
GSM8K_COT = "gsm8k_cot"

@classmethod
def get_task_enums(cls):
return list(cls)

@classmethod
def get_full_name(cls, member):
return f"{cls.__name__}.{member.name}"

@classmethod
def get_all_tasks_string(cls):
full_names = [cls.get_full_name(member) for member in cls]
return ', '.join(full_names)


class EVALPLUS_TASK(Enum):
HUMAN = "humaneval"
MBPP = "mbpp"

@classmethod
def get_task_enums(cls):
return list(cls)

@classmethod
def get_full_name(cls, member):
return f"{cls.__name__}.{member.name}"

@classmethod
def get_all_tasks_string(cls):
full_names = [cls.get_full_name(member) for member in cls]
return ', '.join(full_names)


def evalplus(
model: str,
dataset: str,
batch: int = 1,
trust_remote_code: bool = False,
):
try:
from evalplus.evaluate import evaluate
except BaseException:
raise ValueError("evalplus is not installed. Please install via `pip install gptqmodel[evalplus]`.")

assert dataset in ["humaneval", "mbpp"], f"Invalid dataset {dataset}"

evaluate(dataset=dataset, model=model, backend="gptqmodel", bs=batch, trust_remote_code=trust_remote_code,
greedy=True)

result_path = model.strip("./").replace("/", "--") + "_gptqmodel_temp_0.0_eval_results.json"
result_path = os.path.join("evalplus_results", dataset, result_path)

if not os.path.exists(result_path):
raise FileNotFoundError(f"No such file: {result_path}")

try:
with open(result_path, 'r') as file:
data = json.load(file)
except json.JSONDecodeError:
raise ValueError(f"Failed to decode JSON: {result_path}")

try:
pass_at_k = data["pass_at_k"]
base = float(pass_at_k["base"]["pass@1"])
plus = float(pass_at_k["plus"]["pass@1"])

base_formatted = format(base, ".3f")
plus_formatted = format(plus, ".3f")
except KeyError as e:
raise ValueError(f"Required key not found in JSON: {str(e)}")
except ValueError as e:
raise ValueError(f"Data format error: {str(e)}")

return base_formatted, plus_formatted, result_path


def evalplus_make_table(results):
print("| Tasks | base tests | base + extra tests |")
print("|-------------|------------|--------------------|")
for task, metrics in results.items():
print(f"| {task} | {metrics['base tests']} | {metrics['base + extra tests']} |")


try:
from lm_eval import simple_evaluate
from lm_eval.loggers import EvaluationTracker, WandbLogger
@@ -10,7 +119,6 @@
except BaseException:
raise ValueError("lm_eval is not installed. Please install via `pip install gptqmodel[eval]`.")


def lm_eval(
model,
model_args: str = "",
@@ -45,6 +153,7 @@ def lm_eval(
wandb_name: Optional[str] = None,
show_config: bool = False,
trust_remote_code: bool = False,
device: Optional[str] = None,
):
if model_name == "hf":
model_name = HFLM(
@@ -60,7 +169,7 @@
model=model_name,
model_args=model_args,
tasks=tasks,
device=str(model.device),
device=device,
num_fewshot=num_fewshot,
batch_size=batch_size,
max_batch_size=max_batch_size,
@@ -123,3 +232,6 @@ def lm_eval(
return results
else:
raise ValueError('lm_eval run failed, please check your arguments and environment')



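The lower-level helpers can also be called directly, which is what the new tests do. A minimal sketch of the evalplus path, again with a placeholder model path:

# Hypothetical direct call into gptqmodel.utils.eval, mirroring tests/test_evalplus.py.
from gptqmodel.utils.eval import evalplus, evalplus_make_table

base_score, plus_score, results_path = evalplus(
    model="meta-llama/Llama-3.2-1B-Instruct",  # placeholder model id or local path
    dataset="humaneval",                       # or "mbpp"
    batch=32,
)
evalplus_make_table({
    "humaneval": {"base tests": base_score, "base + extra tests": plus_score},
})

lm_eval() is invoked the same way as in the eval() classmethod above, with the new optional device argument left as None unless a specific device string is required.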
2 changes: 1 addition & 1 deletion gptqmodel/utils/importer.py
@@ -56,8 +56,8 @@ def hf_select_quant_linear(
group_size: int,
desc_act: bool,
sym: bool,
backend: Optional[BACKEND] = None,
checkpoint_format: str,
backend: Optional[BACKEND] = None,
meta: Optional[Dict[str, any]] = None,
device_map: Optional[Union[str, dict]] = None,
) -> Type[BaseQuantLinear]:
2 changes: 1 addition & 1 deletion setup.py
@@ -247,7 +247,7 @@ def run(self):
'ipex': ["intel_extension_for_pytorch>=2.5.0"],
'auto_round': ["auto_round>=0.3"],
'logger': ["clearml", "random_word", "plotly"],
'eval': ["lm_eval>=0.4.6"],
'eval': ["lm_eval>=0.4.6", "evalplus>=0.3.1"],
'triton': ["triton>=2.0.0"]
},
include_dirs=include_dirs,
2 changes: 1 addition & 1 deletion tests/models/model_test.py
@@ -14,7 +14,7 @@
from gptqmodel import BACKEND, GPTQModel # noqa: E402
from gptqmodel.quantization import FORMAT # noqa: E402
from gptqmodel.quantization.config import QuantizeConfig # noqa: E402
from gptqmodel.utils.lm_eval import lm_eval # noqa: E402
from gptqmodel.utils.eval import lm_eval # noqa: E402
from lm_eval.utils import make_table # noqa: E402
from transformers import AutoTokenizer # noqa: E402

40 changes: 40 additions & 0 deletions tests/test_eval.py
@@ -0,0 +1,40 @@
import os
import unittest
from typing import Union

from gptqmodel import GPTQModel
from gptqmodel.utils import EVAL, EVALPLUS_TASK, LM_EVAL_TASK
from parameterized import parameterized

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"


class TestEval(unittest.TestCase):
@classmethod
def setUpClass(self):
self.MODEL_ID = "/monster/data/model/Llama-3.2-1B-Instruct"

@parameterized.expand(
[
(EVAL.LM_EVAL, LM_EVAL_TASK.ARC_CHALLENGE),
(EVAL.EVALPLUS, EVALPLUS_TASK.HUMAN)
]
)
def test_eval(self, eval_backend: EVAL, task: Union[LM_EVAL_TASK, EVALPLUS_TASK]):
results = GPTQModel.eval(self.MODEL_ID, framework=eval_backend, tasks=[task], batch=32)
if eval_backend == EVAL.LM_EVAL:
acc_score = results['results'].get(task.value, {}).get('acc,none')
acc_norm_score = results['results'].get(task.value, {}).get('acc_norm,none')

self.assertGreaterEqual(acc_score, 0.31, "acc score does not match expected result")
self.assertGreaterEqual(acc_norm_score, 0.35, "acc_norm score does not match expected result")
elif eval_backend == EVAL.EVALPLUS:
result = results.get(task.value)
base_formatted, plus_formatted, _ = float(result.get("base tests")), float(result.get("base + extra tests")), result.get("results_path")
self.assertGreaterEqual(base_formatted, 0.31, "Base score does not match expected result")
self.assertGreaterEqual(plus_formatted, 0.29, "Plus score does not match expected result")





19 changes: 19 additions & 0 deletions tests/test_evalplus.py
@@ -0,0 +1,19 @@
import os
import unittest

from gptqmodel.utils.eval import evalplus

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"


class TestEvalplus(unittest.TestCase):
@classmethod
def setUpClass(self):
self.MODEL_ID = "/monster/data/model/Llama-3.2-1B-Instruct"

def test_evalplus(self):
base_formatted, plus_formatted, _ = evalplus(model=self.MODEL_ID, dataset='humaneval')
self.assertGreaterEqual(float(base_formatted), 0.31, "Base score does not match expected result")
self.assertGreaterEqual(float(plus_formatted), 0.29, "Plus score does not match expected result")


2 changes: 1 addition & 1 deletion tests/test_lm_eval.py
@@ -7,7 +7,7 @@
import unittest # noqa: E402

from gptqmodel import GPTQModel # noqa: E402
from gptqmodel.utils.lm_eval import lm_eval # noqa: E402
from gptqmodel.utils.eval import lm_eval # noqa: E402
from lm_eval.utils import make_table # noqa: E402


Expand Down