Commit
add falcon180b FP8 test (#104) (#123)
Co-authored-by: Sun Choi <schoi@habana.ai>
vivekgoe and schoi-habana authored Mar 20, 2024
1 parent 9e0975f commit acc65c1
Showing 2 changed files with 35 additions and 6 deletions.
14 changes: 8 additions & 6 deletions examples/text-generation/run_lm_eval.py
@@ -75,10 +75,15 @@ def __init__(self, tokenizer, model, args, options):
self.options = options
self._device = args.device
self.model_inputs = {"use_cache": self.options.use_cache}
-        if self.model.config.model_type == "llama":
+        if self.model.config.model_type in ["llama", "falcon"]:
self.model_inputs.update(
{
"reuse_cache": self.options.reuse_cache,
}
)
if self.model.config.model_type == "llama":
self.model_inputs.update(
{
"attn_softmax_bf16": self.options.attn_softmax_bf16,
}
)
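A note on the membership test above: writing the condition as 'model_type == "llama" or "falcon"' would always evaluate truthy in Python, because it parses as '(model_type == "llama") or "falcon"' and a non-empty string is truthy. A standalone illustration (a sketch, not part of this diff):

model_type = "gpt2"
# Pitfall: parsed as (model_type == "llama") or "falcon"; the non-empty
# string "falcon" is truthy, so this holds for every model type.
assert model_type == "llama" or "falcon"
# The membership test evaluates as intended:
assert model_type not in ["llama", "falcon"]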
@@ -131,11 +136,7 @@ def _model_call(self, inps):
if self.options.static_shapes:
bucket_length = self.find_bucket(seq_length)
if self.options.use_cache and self.options.reuse_cache:
-            self.model.allocate_kv_cache(
-                bs,
-                bucket_length + 1,
-                bucket_length
-            )
+            self.model.allocate_kv_cache(bs, bucket_length + 1, bucket_length)
padding_length = bucket_length - seq_length
inps = F.pad(inps, (0, padding_length), value=self.model.config.pad_token_id)
logits = self.model(inps.to(self._device), **self.model_inputs)["logits"].cpu()
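For readers unfamiliar with static-shape bucketing, here is a minimal standalone sketch of the padding math above; the bucket sizes and the exact rounding inside find_bucket are assumptions for illustration, not taken from this file:

import torch
import torch.nn.functional as F

BUCKETS = [128, 256, 512, 1024]  # hypothetical bucket sizes

def find_bucket(seq_length):
    # Round up to the nearest bucket so every forward pass sees one of a
    # small, fixed set of shapes and avoids graph recompilation.
    return next(b for b in BUCKETS if b >= seq_length)

inps = torch.randint(0, 1000, (4, 300))           # batch of token ids
bucket_length = find_bucket(inps.shape[-1])       # 512 here
padding_length = bucket_length - inps.shape[-1]   # right-pad up to the bucket
inps = F.pad(inps, (0, padding_length), value=0)  # pad value: pad_token_id (0 assumed)
print(inps.shape)                                 # torch.Size([4, 512])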
@@ -177,6 +178,7 @@ def main():
habana_quantization_toolkit.finish_measurements(model)
if args.const_serialization_path and os.path.isdir(args.const_serialization_path):
import shutil

shutil.rmtree(args.const_serialization_path)


27 changes: 27 additions & 0 deletions tests/test_text_generation_example.py
@@ -26,6 +26,9 @@
("mistralai/Mistral-7B-v0.1", 125.26115369093216),
("mistralai/Mixtral-8x7B-v0.1", 23.78652574031883),
],
"fp8": [
("tiiuae/falcon-180B", 47.67900945905787),
],
"deepspeed": [
("bigscience/bloomz", 36.34664210641816),
("meta-llama/Llama-2-70b-hf", 61.973950428647164),
@@ -69,6 +72,7 @@ def _test_text_generation(
deepspeed: bool = False,
world_size: int = 8,
torch_compile: bool = False,
fp8: bool = False,
):
command = ["python3"]
path_to_example_dir = Path(__file__).resolve().parent.parent / "examples"
@@ -103,6 +107,13 @@
if not deepspeed:
command.append("--bf16")

if fp8:
command += [
"--fp8",
"--reuse_cache",
"--trim_logits",
]

with TemporaryDirectory() as tmp_dir:
command.append(f"--output_dir {tmp_dir}")
print(f"\n\nCommand to test: {' '.join(command)}\n")
@@ -112,6 +123,15 @@
pattern = re.compile(r"([\"\'].+?[\"\'])|\s")
command = [x for y in command for x in re.split(pattern, y) if x]
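An aside on the split just above: the capturing group keeps quoted substrings intact while whitespace is consumed, and the "if x" filter drops both the empty strings and the None entries that re.split emits when the group does not participate in a match. A standalone illustration with assumed inputs:

import re

pattern = re.compile(r"([\"\'].+?[\"\'])|\s")
command = ["python3 run_generation.py", "--prompt 'hello world'"]
flat = [x for y in command for x in re.split(pattern, y) if x]
print(flat)  # ['python3', 'run_generation.py', '--prompt', "'hello world'"]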

if fp8:
os.environ["QUANT_CONFIG"] = os.path.join(
path_to_example_dir, "text-generation/quantization_config/maxabs_measure_include_outputs.json"
)
subprocess.run(command)
os.environ["QUANT_CONFIG"] = os.path.join(
path_to_example_dir, "text-generation/quantization_config/maxabs_quant.json"
)

proc = subprocess.run(command)

# Ensure the run finished without any issue
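The FP8 path above is a two-phase flow driven by the QUANT_CONFIG environment variable: a first run with the measurement config collects max-abs calibration statistics, then the same command is re-run with the quantization config for the timed pass. Condensed to its essentials (a sketch; the helper name run_two_phase_fp8 is hypothetical, the two JSON file names come from the diff):

import os
import subprocess

def run_two_phase_fp8(command, config_dir):
    # Phase 1: measurement pass collects calibration statistics.
    os.environ["QUANT_CONFIG"] = os.path.join(config_dir, "maxabs_measure_include_outputs.json")
    subprocess.run(command)
    # Phase 2: quantized pass applies FP8 scales derived from phase 1.
    os.environ["QUANT_CONFIG"] = os.path.join(config_dir, "maxabs_quant.json")
    return subprocess.run(command)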
@@ -135,6 +155,13 @@ def test_text_generation_bf16(model_name: str, baseline: float, token: str):
_test_text_generation(model_name, baseline, token)


@pytest.mark.parametrize("model_name, baseline", MODELS_TO_TEST["fp8"])
def test_text_generation_fp8(model_name: str, baseline: float, token: str):
    deepspeed = "falcon-180B" in model_name
world_size = 8 if "falcon-180B" in model_name else None
_test_text_generation(model_name, baseline, token, deepspeed=deepspeed, world_size=world_size, fp8=True)
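For reference, the single fp8 entry in MODELS_TO_TEST expands to one pytest case roughly equivalent to this direct call (the token value is a placeholder supplied by the suite's fixture):

_test_text_generation(
    "tiiuae/falcon-180B",
    47.67900945905787,
    token="<hf-token>",   # provided by the test fixture
    deepspeed=True,       # falcon-180B runs under DeepSpeed
    world_size=8,         # across 8 devices
    fp8=True,
)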


@pytest.mark.parametrize("model_name, baseline", MODELS_TO_TEST["deepspeed"])
def test_text_generation_deepspeed(model_name: str, baseline: float, token: str):
world_size = 2 if "opt-66b" in model_name else 8
