
Commit c495f47

tthakkal and regisss authored
Enable fp8 inference for Llava-Next and add Fused_SDPA (huggingface#1120)
Co-authored-by: regisss <15324346+regisss@users.noreply.github.com>
1 parent 609e450 commit c495f47

9 files changed, +483 −17 lines changed


examples/image-to-text/README.md

+111-9
@@ -15,28 +15,68 @@ limitations under the License.
 -->
 
 # Image to Text Examples
 
-This directory contains a script that showcases how to use the Transformers pipeline API to run image to text task on HPUs.
+This directory contains a script that showcases how to perform image to text generation on Intel® Gaudi® AI Accelerators.
 
 ## Single-HPU inference
 
+Models that have been validated:
+  - [nlpconnect/vit-gpt2-image-captioning](https://huggingface.co/nlpconnect/vit-gpt2-image-captioning)
+  - [Salesforce/blip-image-captioning-large](https://huggingface.co/Salesforce/blip-image-captioning-large)
+  - [Salesforce/blip-image-captioning-base](https://huggingface.co/Salesforce/blip-image-captioning-base)
+  - [llava-hf/llava-1.5-7b-hf](https://huggingface.co/llava-hf/llava-1.5-7b-hf)
+  - [llava-hf/llava-1.5-13b-hf](https://huggingface.co/llava-hf/llava-1.5-13b-hf)
+  - [llava-hf/llava-v1.6-mistral-7b-hf](https://huggingface.co/llava-hf/llava-v1.6-mistral-7b-hf)
+  - [llava-hf/llava-v1.6-vicuna-7b-hf](https://huggingface.co/llava-hf/llava-v1.6-vicuna-7b-hf)
+  - [llava-hf/llava-v1.6-vicuna-13b-hf](https://huggingface.co/llava-hf/llava-v1.6-vicuna-13b-hf)
+
+### Inference with BF16
+
+To run Salesforce/blip-image-captioning-large inference, use the following command:
 ```bash
 python3 run_pipeline.py \
     --model_name_or_path Salesforce/blip-image-captioning-large \
     --image_path "https://ankur3107.github.io/assets/images/image-captioning-example.png" \
     --use_hpu_graphs \
     --bf16
 ```
-Models that have been validated:
-  - [nlpconnect/vit-gpt2-image-captioning](https://huggingface.co/nlpconnect/vit-gpt2-image-captioning)
-  - [Salesforce/blip-image-captioning-large](https://huggingface.co/Salesforce/blip-image-captioning-large)
-  - [Salesforce/blip-image-captioning-base](https://huggingface.co/Salesforce/blip-image-captioning-base)
 
-### Running with FP8
+To run Llava-1.5-7b inference, use the following command:
+```bash
+python3 run_pipeline.py \
+    --model_name_or_path llava-hf/llava-1.5-7b-hf \
+    --use_hpu_graphs \
+    --bf16
+```
+
+To run Llava-1.5-13b inference, use the following command:
+```bash
+python3 run_pipeline.py \
+    --model_name_or_path llava-hf/llava-1.5-13b-hf \
+    --use_hpu_graphs \
+    --bf16
+```
+
+To run Llava-v1.6-mistral-7b inference, use the following command:
+```bash
+python3 run_pipeline.py \
+    --model_name_or_path llava-hf/llava-v1.6-mistral-7b-hf \
+    --use_hpu_graphs \
+    --bf16
+```
 
-Llava-1.5-7b and Llava-1.5-13b in FP8 are enabled using the Quantization Toolkit (HQT), which provides model measurement and quantization capabilities in PyTorch.
+To run Llava-v1.6-vicuna-13b inference, use the following command:
+```bash
+python3 run_pipeline.py \
+    --model_name_or_path llava-hf/llava-v1.6-vicuna-13b-hf \
+    --use_hpu_graphs \
+    --bf16
+```
+
+### Inference with FP8
 
-More information on enabling fp8 in SynapseAI is available here:
+Inference for Llava-1.5-7b, Llava-1.5-13b, Llava-v1.6-mistral-7b and Llava-v1.6-vicuna-13b in FP8 precision is enabled using the Quantization Toolkit (HQT), which provides model measurement and quantization capabilities in PyTorch.
+
+More information on enabling FP8 in SynapseAI is available here:
 https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Inference_Using_FP8.html
 
 Here is an example to measure the tensor quantization statistics on Llava-1.5-7b:
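All of these QUANT_CONFIG examples follow the same two-step flow: one run with `maxabs_measure.json` to collect tensor statistics, then one run with `maxabs_quant.json` to execute in FP8 using them. The sketch below summarizes how that toggle is typically wired around the pipeline; `finish_measurements` appears in the `run_pipeline.py` diff further down, while the pipeline construction, `device="hpu"`, and the `prep_model` call are assumptions made for illustration, not code from this commit.

```python
# Hedged sketch of the two-step HQT workflow driven by QUANT_CONFIG:
#   1) QUANT_CONFIG=maxabs_measure.json -> run once to collect tensor statistics
#   2) QUANT_CONFIG=maxabs_quant.json   -> run again to execute in FP8 using them
import os

import torch
from transformers import pipeline

quant_config = os.getenv("QUANT_CONFIG", "")  # e.g. ./quantization_config/maxabs_measure.json

generator = pipeline(
    "image-to-text",
    model="llava-hf/llava-1.5-7b-hf",
    torch_dtype=torch.bfloat16,
    device="hpu",  # assumption: HPU-targeted pipeline, as in run_pipeline.py
)

if quant_config:
    import habana_quantization_toolkit

    habana_quantization_toolkit.prep_model(generator.model)  # assumed API for patching the model

generator(
    "https://llava-vl.github.io/static/images/view.jpg",
    prompt="<image>\nUSER: What's the content of the image?\nASSISTANT:",
    generate_kwargs={"max_new_tokens": 100},
)

if quant_config:
    # Dump the collected statistics so the maxabs_quant.json run can consume them
    # (see the finish_measurements call in the run_pipeline.py diff below).
    habana_quantization_toolkit.finish_measurements(generator.model)
```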
@@ -56,3 +96,65 @@ QUANT_CONFIG=./quantization_config/maxabs_quant.json python run_pipeline.py \
     --use_hpu_graphs \
     --bf16
 ```
+
+
+Here is an example to measure the tensor quantization statistics on Llava-v1.6-mistral-7b:
+```bash
+QUANT_CONFIG=./quantization_config/maxabs_measure.json python run_pipeline.py \
+    --model_name_or_path llava-hf/llava-v1.6-mistral-7b-hf \
+    --image_path "https://llava-vl.github.io/static/images/view.jpg" \
+    --use_hpu_graphs \
+    --bf16
+```
+
+Here is an example to quantize the model based on previous measurements for Llava-v1.6-mistral-7b:
+```bash
+QUANT_CONFIG=./quantization_config/maxabs_quant.json python run_pipeline.py \
+    --model_name_or_path llava-hf/llava-v1.6-mistral-7b-hf \
+    --image_path "https://llava-vl.github.io/static/images/view.jpg" \
+    --use_hpu_graphs \
+    --bf16
+```
+
+Here is an example to measure the tensor quantization statistics on Llava-v1.6-vicuna-13b:
+```bash
+QUANT_CONFIG=./quantization_config/maxabs_measure.json python run_pipeline.py \
+    --model_name_or_path llava-hf/llava-v1.6-vicuna-13b-hf \
+    --image_path "https://llava-vl.github.io/static/images/view.jpg" \
+    --use_hpu_graphs \
+    --bf16
+```
+
+Here is an example to quantize the model based on previous measurements for Llava-v1.6-vicuna-13b:
+```bash
+QUANT_CONFIG=./quantization_config/maxabs_quant.json python run_pipeline.py \
+    --model_name_or_path llava-hf/llava-v1.6-vicuna-13b-hf \
+    --image_path "https://llava-vl.github.io/static/images/view.jpg" \
+    --use_hpu_graphs \
+    --bf16
+```
+
+
+### Inference with FusedSDPA
+
+Habana FusedSDPA is a fused and optimized implementation of torch.nn.functional.scaled_dot_product_attention() for Gaudi. For more details, refer to the [Gaudi online documentation](https://docs.habana.ai/en/latest/PyTorch/Model_Optimization_PyTorch/Optimization_in_PyTorch_Models.html?highlight=fusedsdpa#using-fused-scaled-dot-product-attention-fusedsdpa). Currently, FusedSDPA works with BF16 precision for Llava models.
+
+Use the following command to run Llava-1.5-7b inference with FusedSDPA:
+```bash
+python3 run_pipeline.py \
+    --model_name_or_path llava-hf/llava-1.5-7b-hf \
+    --image_path "https://llava-vl.github.io/static/images/view.jpg" \
+    --use_hpu_graphs \
+    --bf16 \
+    --use_flash_attention
+```
+
+
+Use the following command to run Llava-v1.6-mistral-7b inference with FusedSDPA:
+```bash
+python3 run_pipeline.py \
+    --model_name_or_path llava-hf/llava-v1.6-mistral-7b-hf \
+    --image_path "https://llava-vl.github.io/static/images/view.jpg" \
+    --use_hpu_graphs \
+    --bf16 \
+    --use_flash_attention
+```
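For intuition on what `--use_flash_attention` switches on, here is a minimal sketch contrasting stock PyTorch SDPA with the Habana kernel. The `FusedSDPA` import path and `apply` signature follow Habana's documentation but should be treated as assumptions; only the flag and the BF16 restriction come from this commit.

```python
# Minimal sketch: FusedSDPA as a drop-in for torch.nn.functional.scaled_dot_product_attention.
# The habana_frameworks import only resolves on a Gaudi machine with the SynapseAI stack.
import torch
import torch.nn.functional as F


def sdpa(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, use_flash_attention: bool = False) -> torch.Tensor:
    if use_flash_attention:
        from habana_frameworks.torch.hpex.kernels import FusedSDPA  # assumed import path

        # Same math as F.scaled_dot_product_attention, fused into a single HPU kernel.
        # Assumed positional signature: (q, k, v, attn_mask, dropout_p, is_causal).
        return FusedSDPA.apply(q, k, v, None, 0.0, False)
    return F.scaled_dot_product_attention(q, k, v)


# BF16 toy tensors, matching the "BF16 only for Llava" note above.
q = k = v = torch.randn(1, 8, 16, 64, dtype=torch.bfloat16)
out = sdpa(q, k, v, use_flash_attention=False)  # pass True when running on HPU
```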

examples/image-to-text/run_pipeline.py

+8-2
@@ -91,6 +91,12 @@ def main():
         action="store_true",
         help="Whether to ignore eos, set False to disable it.",
     )
+    parser.add_argument(
+        "--use_flash_attention",
+        action="store_true",
+        help="Whether to enable Habana Flash Attention, provided that the model supports it.",
+    )
+
     args = parser.parse_args()
 
     # set args.quant_config with env variable if it is set
@@ -109,7 +115,7 @@ def main():
         args.prompt = "<image>\nUSER: What's the content of the image?\nASSISTANT:"
     elif args.prompt is None and model_type == "llava_next":
         args.prompt = "[INST] <image>\nWhat is shown in this image? [/INST]"
-        if args.model_name_or_path == "llava-hf/llava-v1.6-vicuna-13b-hf":
+        if args.model_name_or_path in ["llava-hf/llava-v1.6-vicuna-13b-hf", "llava-hf/llava-v1.6-vicuna-7b-hf"]:
             args.prompt = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. USER: <image>\nWhat is shown in this image? ASSISTANT:"
 
     image_paths = args.image_path
@@ -149,6 +155,7 @@ def main():
         "hpu_graphs": args.use_hpu_graphs,
         "max_new_tokens": args.max_new_tokens,
         "ignore_eos": args.ignore_eos,
+        "use_flash_attention": args.use_flash_attention,
     }
     if args.use_hpu_graphs:
         from habana_frameworks.torch.hpu import wrap_in_hpu_graph
@@ -165,7 +172,6 @@ def main():
     # warm up
     for i in range(args.warmup):
         generator(images, prompt=args.prompt, batch_size=args.batch_size, generate_kwargs=generate_kwargs)
-
     torch.hpu.synchronize()
    if args.quant_config:
        habana_quantization_toolkit.finish_measurements(generator.model)
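Taken together, these hunks thread the new switch through to generation: argparse defines `--use_flash_attention`, and its value lands in the `generate_kwargs` dictionary the pipeline forwards via `generator(..., generate_kwargs=generate_kwargs)`. A condensed, self-contained sketch of that path follows; the argument names and dictionary keys come from the diff, while the parser defaults, example argv, and the closing comment are illustrative assumptions.

```python
# Condensed sketch of how --use_flash_attention travels from the CLI into
# generate_kwargs in run_pipeline.py. Names mirror the diff above.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--use_hpu_graphs", action="store_true")
parser.add_argument("--max_new_tokens", type=int, default=100)  # illustrative default
parser.add_argument("--ignore_eos", action="store_true")
parser.add_argument(
    "--use_flash_attention",
    action="store_true",
    help="Whether to enable Habana Flash Attention, provided that the model supports it.",
)
args = parser.parse_args(["--use_hpu_graphs", "--use_flash_attention"])  # example argv

generate_kwargs = {
    "hpu_graphs": args.use_hpu_graphs,
    "max_new_tokens": args.max_new_tokens,
    "ignore_eos": args.ignore_eos,
    "use_flash_attention": args.use_flash_attention,  # new key added by this commit
}

# Per the diff, the script then calls:
#   generator(images, prompt=args.prompt, batch_size=args.batch_size, generate_kwargs=generate_kwargs)
# so the Gaudi Llava models receive use_flash_attention=True and switch to FusedSDPA.
print(generate_kwargs)
```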

optimum/habana/transformers/modeling_utils.py

+10
@@ -28,7 +28,12 @@
 from .models import (
     GaudiBloomForCausalLM,
     GaudiBloomMLP,
+    GaudiCLIPAttention,
+    GaudiCLIPEncoder,
+    GaudiCLIPEncoderLayer,
     GaudiCLIPVisionEmbeddings,
+    GaudiCLIPVisionModel,
+    GaudiCLIPVisionTransformer,
     GaudiCodeGenAttention,
     GaudiCodeGenForCausalLM,
     GaudiFalconAttention,
@@ -376,6 +381,11 @@ def adapt_transformers_to_gaudi():
 
     # Optimization for Clip on Gaudi
     transformers.models.clip.modeling_clip.CLIPVisionEmbeddings = GaudiCLIPVisionEmbeddings
+    transformers.models.clip.modeling_clip.CLIPAttention = GaudiCLIPAttention
+    transformers.models.clip.modeling_clip.CLIPEncoderLayer = GaudiCLIPEncoderLayer
+    transformers.models.clip.modeling_clip.CLIPEncoder = GaudiCLIPEncoder
+    transformers.models.clip.modeling_clip.CLIPVisionTransformer = GaudiCLIPVisionTransformer
+    transformers.models.clip.modeling_clip.CLIPVisionModel = GaudiCLIPVisionModel
 
     # Optimization for falcon generation on Gaudi
     transformers.models.falcon.modeling_falcon.FalconAttention = GaudiFalconAttention
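Because `adapt_transformers_to_gaudi()` swaps these classes at module level, the Gaudi CLIP implementations are picked up by any Llava or Llava-Next checkpoint instantiated afterwards. A small sketch of that effect; the import paths exist in this repo, and the assertion simply restates the assignments in the diff above.

```python
# Sketch: after adapt_transformers_to_gaudi(), the stock CLIP vision classes in
# transformers point at the Gaudi implementations assigned above, so a Llava-Next
# model built afterwards uses the Gaudi CLIP attention/encoder stack.
import transformers
from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
from optimum.habana.transformers.models import GaudiCLIPVisionModel

adapt_transformers_to_gaudi()

# The module-level class has been replaced, as done in the hunk above.
assert transformers.models.clip.modeling_clip.CLIPVisionModel is GaudiCLIPVisionModel
```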

optimum/habana/transformers/models/__init__.py

+8-1
@@ -31,7 +31,14 @@
     gaudi_bloom_convert_to_standard_cache,
     gaudi_bloom_model_forward,
 )
-from .clip import GaudiCLIPVisionEmbeddings
+from .clip import (
+    GaudiCLIPAttention,
+    GaudiCLIPEncoder,
+    GaudiCLIPEncoderLayer,
+    GaudiCLIPVisionEmbeddings,
+    GaudiCLIPVisionModel,
+    GaudiCLIPVisionTransformer,
+)
 from .codegen import (
     GaudiCodeGenAttention,
     GaudiCodeGenForCausalLM,
optimum/habana/transformers/models/clip/__init__.py

@@ -1 +1,8 @@
-from .modeling_clip import GaudiCLIPVisionEmbeddings
+from .modeling_clip import (
+    GaudiCLIPAttention,
+    GaudiCLIPEncoder,
+    GaudiCLIPEncoderLayer,
+    GaudiCLIPVisionEmbeddings,
+    GaudiCLIPVisionModel,
+    GaudiCLIPVisionTransformer,
+)
