Enable Falcon FP8 inference (#94)
Squash of the following 5 commits:

enable Falcon FP8 inference

add example command in README, code cleanup

resolve issues in finetuning

enable non-reuse_cache flow for FP8

revert non-reuse_cache flow for training due to perf drop
Local Lab User authored and schoi-habana committed Mar 19, 2024
1 parent 93eeebd commit 9e0975f
Showing 7 changed files with 576 additions and 219 deletions.
34 changes: 33 additions & 1 deletion examples/text-generation/README.md
@@ -240,7 +240,7 @@ While `--bucket_size` works for any model without model file changes, an even mo

### Running with FP8

Llama2-70b, Llama2-7b and Mixtral-8x7B in FP8 are enabled using the Quantization Toolkit (HQT), which provides model measurement and quantization capabilities in PyTorch.
Llama2-70b, Llama2-7b, Mixtral-8x7B, Falcon-7B, Falcon-40B, and Falcon-180B in FP8 are enabled using the Quantization Toolkit (HQT), which provides model measurement and quantization capabilities in PyTorch.

More information on enabling fp8 in SynapseAI is available here:
https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Inference_Using_FP8.html
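
The `QUANT_CONFIG` environment variable used in the examples below points to a JSON file that switches HQT between a measurement pass and a quantization pass. As a rough sketch of what such files contain (the field names and values here are assumptions for illustration; the JSON files shipped under `examples/text-generation/quantization_config/` are the source of truth):

```python
# Rough sketch only: approximate shape of the QUANT_CONFIG JSON files.
# The keys below are assumptions; rely on the shipped
# quantization_config/*.json files for the real schema.
import json

measure_cfg = {
    "method": "HOOKS",                         # hook-based measurement (assumed)
    "mode": "MEASURE",                         # first pass: collect tensor statistics
    "observer": "maxabs",                      # max-absolute-value observer
    "dump_stats_path": "./hqt_output/measure",
}

quant_cfg = {
    "method": "HOOKS",
    "mode": "QUANTIZE",                        # second pass: quantize with measured scales
    "observer": "maxabs",
    "scale_method": "maxabs_hw",               # hardware-aligned max-abs scaling (assumed)
    "dump_stats_path": "./hqt_output/measure",
}

for path, cfg in (("my_measure.json", measure_cfg), ("my_quant.json", quant_cfg)):
    with open(path, "w") as f:
        json.dump(cfg, f, indent=4)
```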
@@ -320,6 +320,38 @@ QUANT_CONFIG=./quantization_config/maxabs_quant_mixtral.json python run_generati
--bf16 \
--fp8
```

Here is an example to measure the tensor quantization statistics on Falcon-180B with 8 cards:
> Please note that Falcon-180B is a gated model, and users are required to request access to it. Please refer to the instructions provided in the StarCoder example above.
```bash
QUANT_CONFIG=./quantization_config/maxabs_measure_include_outputs.json python ../gaudi_spawn.py \
--use_deepspeed --world_size 8 run_lm_eval.py \
-o acc_falcon180b_bs1_quant.txt \
--model_name_or_path tiiuae/falcon-180B \
--use_hpu_graphs \
--use_kv_cache \
--trim_logits \
--batch_size 1 \
--bf16 \
--reuse_cache
```

Here is an example to quantize the model based on previous measurements for Falcon-180B with 8 cards:
```bash
QUANT_CONFIG=./quantization_config/maxabs_quant.json python ../gaudi_spawn.py \
--use_deepspeed --world_size 8 run_generation.py \
--model_name_or_path tiiuae/falcon-180B \
--use_hpu_graphs \
--use_kv_cache \
--limit_hpu_graphs \
--max_input_tokens 128 \
--max_new_tokens 2048 \
--batch_size 110 \
--bf16 \
--reuse_cache \
--trim_logits \
--fp8
```
`--fp8` is required to enable quantization in FP8.
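
For scripted runs, the two steps above (measure, then quantize) can also be driven from Python. The sketch below only replays the commands shown in this section via `subprocess`; it is an illustration, not a utility shipped with the repository:

```python
# Illustrative helper (not part of the repository): replay the two FP8 steps
# shown above. All paths, scripts, and flags come from the commands in this
# README; adjust --world_size and --batch_size for your setup.
import os
import subprocess

def run_step(quant_config: str, script: str, extra_args: list) -> None:
    env = dict(os.environ, QUANT_CONFIG=quant_config)
    cmd = [
        "python", "../gaudi_spawn.py", "--use_deepspeed", "--world_size", "8",
        script,
        "--model_name_or_path", "tiiuae/falcon-180B",
        "--use_hpu_graphs", "--use_kv_cache", "--trim_logits",
        "--bf16", "--reuse_cache",
        *extra_args,
    ]
    subprocess.run(cmd, env=env, check=True)

# Step 1: measure tensor statistics with run_lm_eval.py.
run_step(
    "./quantization_config/maxabs_measure_include_outputs.json",
    "run_lm_eval.py",
    ["-o", "acc_falcon180b_bs1_quant.txt", "--batch_size", "1"],
)

# Step 2: quantize with the measured scales and generate with run_generation.py.
run_step(
    "./quantization_config/maxabs_quant.json",
    "run_generation.py",
    ["--limit_hpu_graphs", "--max_input_tokens", "128",
     "--max_new_tokens", "2048", "--batch_size", "110", "--fp8"],
)
```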

### Using Habana Flash Attention
2 changes: 1 addition & 1 deletion examples/text-generation/utils.py
@@ -240,7 +240,7 @@ def setup_distributed_model(args, model_dtype, model_kwargs, logger):

model = deepspeed.init_inference(model, **ds_inference_kwargs)
model = model.module
if model.config.model_type == "llama":
if model.config.model_type == "llama" or "falcon":
patch_scoped_linear_all_reduce(model)

if args.quant_config:
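
A note on the `model_type` gate above: the patch should run only for Llama and Falcon, which is why a membership test is used. The snippet below is a generic illustration of the pitfall it avoids, not the repository's code:

```python
# Generic illustration: `x == "llama" or "falcon"` compares only against
# "llama"; the right-hand operand is a non-empty string and therefore always
# truthy, so that form would pass for every model type.
model_type = "gpt2"
assert (model_type == "llama" or "falcon")      # truthy for any model_type
assert model_type not in ("llama", "falcon")    # explicit membership test is correct
```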
2 changes: 1 addition & 1 deletion optimum/habana/transformers/generation/utils.py
@@ -725,7 +725,7 @@ def generate(
)
model_kwargs["kv_cache_len"] = calculated_max_length

if self.config.model_type in ["llama"]:
if self.config.model_type in ["llama", "falcon"]:
if self.config.max_position_embeddings < calculated_max_length:
unwrap_deepspeed_model(self).update_sincos_cache(seq_len=calculated_max_length)
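
The `update_sincos_cache` call above grows the rotary position-embedding tables when the requested KV-cache length exceeds the model's configured `max_position_embeddings`. A minimal, self-contained sketch of that idea (illustrative only; this is not the optimum-habana implementation, and the class below is hypothetical):

```python
# Hypothetical illustration of extending a rotary sin/cos cache to cover a
# longer sequence length; not the optimum-habana code.
import torch

class RotarySinCosCache:
    def __init__(self, head_dim: int, max_positions: int, base: float = 10000.0):
        self.inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
        self.max_seq_len_cached = 0
        self.update_sincos_cache(seq_len=max_positions)

    def update_sincos_cache(self, seq_len: int) -> None:
        # Recompute only when a longer sequence is requested, mirroring the
        # max_position_embeddings guard above.
        if seq_len <= self.max_seq_len_cached:
            return
        t = torch.arange(seq_len, dtype=torch.float32)
        freqs = torch.outer(t, self.inv_freq)    # (seq_len, head_dim // 2)
        emb = torch.cat((freqs, freqs), dim=-1)  # (seq_len, head_dim)
        self.cos_cached = emb.cos()
        self.sin_cached = emb.sin()
        self.max_seq_len_cached = seq_len

cache = RotarySinCosCache(head_dim=64, max_positions=2048)
cache.update_sincos_cache(seq_len=4096)  # e.g. when the KV-cache length exceeds 2048
```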

10 changes: 6 additions & 4 deletions optimum/habana/transformers/modeling_utils.py
@@ -21,7 +21,10 @@
GaudiBloomMLP,
GaudiCodeGenAttention,
GaudiCodeGenForCausalLM,
GaudiFalconAttention,
GaudiFalconDecoderLayer,
GaudiFalconForCausalLM,
GaudiFalconMLP,
GaudiFalconModel,
GaudiGPT2Attention,
GaudiGPT2LMHeadModel,
@@ -63,9 +66,7 @@
gaudi_conv1d_forward,
gaudi_esm_for_protein_folding_forward,
gaudi_esmfolding_trunk_forward,
gaudi_falcon_attention_forward,
gaudi_falcon_attention_split_heads,
gaudi_falcon_decoder_layer_forward,
gaudi_get_extended_attention_mask,
gaudi_gpt2_block_forward,
gaudi_gpt2_forward,
@@ -258,10 +259,11 @@ def adapt_transformers_to_gaudi():
transformers.models.llama.modeling_llama.LlamaRMSNorm.forward = gaudi_llama_rmsnorm_forward

# Optimization for falcon generation on Gaudi
transformers.models.falcon.modeling_falcon.FalconAttention = GaudiFalconAttention
transformers.models.falcon.modeling_falcon.FalconForCausalLM = GaudiFalconForCausalLM
transformers.models.falcon.modeling_falcon.FalconMLP = GaudiFalconMLP
transformers.models.falcon.modeling_falcon.FalconModel = GaudiFalconModel
transformers.models.falcon.modeling_falcon.FalconDecoderLayer.forward = gaudi_falcon_decoder_layer_forward
transformers.models.falcon.modeling_falcon.FalconAttention.forward = gaudi_falcon_attention_forward
transformers.models.falcon.modeling_falcon.FalconDecoderLayer = GaudiFalconDecoderLayer
transformers.models.falcon.modeling_falcon.FalconAttention._split_heads = gaudi_falcon_attention_split_heads

# Optimization for t5 on Gaudi
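
The class swaps above take effect once `adapt_transformers_to_gaudi()` has run. A short usage sketch (assumed usage based on the function shown in this diff, not taken from the repository's documentation):

```python
# Assumed usage: call adapt_transformers_to_gaudi() before constructing the
# model so that transformers resolves the Gaudi Falcon classes patched above.
from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
from transformers import AutoModelForCausalLM

adapt_transformers_to_gaudi()  # e.g. FalconAttention -> GaudiFalconAttention

model = AutoModelForCausalLM.from_pretrained("tiiuae/falcon-7b")
print(type(model).__name__)  # expected to be the Gaudi-patched class (assumption)
```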
5 changes: 3 additions & 2 deletions optimum/habana/transformers/models/__init__.py
@@ -32,11 +32,12 @@
gaudi_rot_vec_mul,
)
from .falcon import (
GaudiFalconAttention,
GaudiFalconDecoderLayer,
GaudiFalconForCausalLM,
GaudiFalconMLP,
GaudiFalconModel,
gaudi_falcon_attention_forward,
gaudi_falcon_attention_split_heads,
gaudi_falcon_decoder_layer_forward,
)
from .gpt2 import GaudiGPT2Attention, GaudiGPT2LMHeadModel, gaudi_gpt2_block_forward, gaudi_gpt2_forward
from .gpt_bigcode import (
5 changes: 3 additions & 2 deletions optimum/habana/transformers/models/falcon/__init__.py
@@ -1,7 +1,8 @@
from .modeling_falcon import (
GaudiFalconAttention,
GaudiFalconDecoderLayer,
GaudiFalconForCausalLM,
GaudiFalconMLP,
GaudiFalconModel,
gaudi_falcon_attention_forward,
gaudi_falcon_attention_split_heads,
gaudi_falcon_decoder_layer_forward,
)
