diff --git a/examples/text-generation/run_lm_eval.py b/examples/text-generation/run_lm_eval.py
index 4f90306354..8d61118890 100644
--- a/examples/text-generation/run_lm_eval.py
+++ b/examples/text-generation/run_lm_eval.py
@@ -75,10 +75,15 @@ def __init__(self, tokenizer, model, args, options):
         self.options = options
         self._device = args.device
         self.model_inputs = {"use_cache": self.options.use_cache}
-        if self.model.config.model_type == "llama":
+        if self.model.config.model_type in ["llama", "falcon"]:
             self.model_inputs.update(
                 {
                     "reuse_cache": self.options.reuse_cache,
+                }
+            )
+        if self.model.config.model_type == "llama":
+            self.model_inputs.update(
+                {
                     "attn_softmax_bf16": self.options.attn_softmax_bf16,
                 }
             )
@@ -131,11 +136,7 @@ def _model_call(self, inps):
         if self.options.static_shapes:
             bucket_length = self.find_bucket(seq_length)
             if self.options.use_cache and self.options.reuse_cache:
-                self.model.allocate_kv_cache(
-                    bs,
-                    bucket_length + 1,
-                    bucket_length
-                )
+                self.model.allocate_kv_cache(bs, bucket_length + 1, bucket_length)
             padding_length = bucket_length - seq_length
             inps = F.pad(inps, (0, padding_length), value=self.model.config.pad_token_id)
         logits = self.model(inps.to(self._device), **self.model_inputs)["logits"].cpu()
@@ -177,6 +178,7 @@ def main():
         habana_quantization_toolkit.finish_measurements(model)
     if args.const_serialization_path and os.path.isdir(args.const_serialization_path):
         import shutil
+
         shutil.rmtree(args.const_serialization_path)
diff --git a/tests/test_text_generation_example.py b/tests/test_text_generation_example.py
index ff6f94d002..00602dbd0e 100644
--- a/tests/test_text_generation_example.py
+++ b/tests/test_text_generation_example.py
@@ -26,6 +26,9 @@
         ("mistralai/Mistral-7B-v0.1", 125.26115369093216),
         ("mistralai/Mixtral-8x7B-v0.1", 23.78652574031883),
     ],
+    "fp8": [
+        ("tiiuae/falcon-180B", 47.67900945905787),
+    ],
     "deepspeed": [
         ("bigscience/bloomz", 36.34664210641816),
         ("meta-llama/Llama-2-70b-hf", 61.973950428647164),
     ],
@@ -69,6 +72,7 @@ def _test_text_generation(
     deepspeed: bool = False,
     world_size: int = 8,
     torch_compile: bool = False,
+    fp8: bool = False,
 ):
     command = ["python3"]
     path_to_example_dir = Path(__file__).resolve().parent.parent / "examples"
@@ -103,6 +107,13 @@ def _test_text_generation(
     if not deepspeed:
         command.append("--bf16")
 
+    if fp8:
+        command += [
+            "--fp8",
+            "--reuse_cache",
+            "--trim_logits",
+        ]
+
     with TemporaryDirectory() as tmp_dir:
         command.append(f"--output_dir {tmp_dir}")
         print(f"\n\nCommand to test: {' '.join(command)}\n")
@@ -112,6 +123,15 @@
         pattern = re.compile(r"([\"\'].+?[\"\'])|\s")
         command = [x for y in command for x in re.split(pattern, y) if x]
 
+        if fp8:
+            os.environ["QUANT_CONFIG"] = os.path.join(
+                path_to_example_dir, "text-generation/quantization_config/maxabs_measure_include_outputs.json"
+            )
+            subprocess.run(command)
+            os.environ["QUANT_CONFIG"] = os.path.join(
+                path_to_example_dir, "text-generation/quantization_config/maxabs_quant.json"
+            )
+
         proc = subprocess.run(command)
 
         # Ensure the run finished without any issue
@@ -135,6 +155,13 @@
     _test_text_generation(model_name, baseline, token)
 
 
+@pytest.mark.parametrize("model_name, baseline", MODELS_TO_TEST["fp8"])
+def test_text_generation_fp8(model_name: str, baseline: float, token: str):
+    deepspeed = True if "falcon-180B" in model_name else False
+    world_size = 8 if "falcon-180B" in model_name else None
+    _test_text_generation(model_name, baseline, token, deepspeed=deepspeed, world_size=world_size, fp8=True)
+
+
 @pytest.mark.parametrize("model_name, baseline", MODELS_TO_TEST["deepspeed"])
 def test_text_generation_deepspeed(model_name: str, baseline: float, token: str):
     world_size = 2 if "opt-66b" in model_name else 8
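
For context on what the new fp8 branch of the test exercises: it runs the generation example twice with the same command, first with QUANT_CONFIG pointing at the measurement config (to collect activation statistics), then with the quantization config (to run fp8 inference using those statistics). Below is a minimal standalone sketch of that flow; the run_generation.py entry point and the --model_name_or_path flag are assumptions, while the config paths and the --fp8/--reuse_cache/--trim_logits flags mirror the diff above.

    # Sketch only: assumes examples/text-generation/run_generation.py is the script
    # the test builds its command around and that it accepts --model_name_or_path.
    import os
    import subprocess

    command = [
        "python3",
        "examples/text-generation/run_generation.py",  # assumed entry point
        "--model_name_or_path", "tiiuae/falcon-180B",
        "--fp8",
        "--reuse_cache",
        "--trim_logits",
    ]

    # Pass 1: measurement run, QUANT_CONFIG selects the maxabs measurement config.
    os.environ["QUANT_CONFIG"] = "examples/text-generation/quantization_config/maxabs_measure_include_outputs.json"
    subprocess.run(command)

    # Pass 2: quantized run, reusing the measurements collected above.
    os.environ["QUANT_CONFIG"] = "examples/text-generation/quantization_config/maxabs_quant.json"
    subprocess.run(command)

Setting os.environ before each subprocess.run call works because the child process inherits the parent's environment, which is the same mechanism the test relies on when it swaps QUANT_CONFIG between the two passes.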