Enable tiiuae/falcon-11B-vlm in image_to_text example (#1490)
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
sywangyi authored Nov 28, 2024
1 parent 22c6adb commit b36fb2b
Showing 6 changed files with 90 additions and 37 deletions.
56 changes: 36 additions & 20 deletions examples/image-to-text/README.md
@@ -32,6 +32,8 @@ Models that have been validated:
- [llava-hf/llama3-llava-next-8b-hf](https://huggingface.co/llava-hf/llama3-llava-next-8b-hf)
- [HuggingFaceM4/idefics2-8b](https://huggingface.co/HuggingFaceM4/idefics2-8b)
- [meta-llama/Llama-3.2-11B-Vision-Instruct](https://huggingface.co/meta-llama/Llama-3.2-11B-Vision-Instruct)
- [meta-llama/Llama-3.2-90B-Vision-Instruct](https://huggingface.co/meta-llama/Llama-3.2-90B-Vision-Instruct)
- [tiiuae/falcon-11B-vlm](https://huggingface.co/tiiuae/falcon-11B-vlm)
- [google/paligemma-3b-mix-224](https://huggingface.co/google/paligemma-3b-mix-224)

### Inference with BF16
@@ -381,35 +383,49 @@ To enable multi-card inference, you must set the environment variable `PT_HPU_ENABLE_LAZY_COLLECTIVES`
Use the following commands to run Llava-v1.6-mistral-7b BF16 inference with FusedSDPA on 8 HPUs:
```bash
PT_HPU_ENABLE_LAZY_COLLECTIVES=true python ../gaudi_spawn.py --use_deepspeed --world_size 8 run_pipeline.py \
--model_name_or_path llava-hf/llava-v1.6-mistral-7b-hf \
--image_path "https://llava-vl.github.io/static/images/view.jpg" \
--use_hpu_graphs \
--bf16 \
--use_flash_attention \
--flash_attention_recompute
--model_name_or_path llava-hf/llava-v1.6-mistral-7b-hf \
--image_path "https://llava-vl.github.io/static/images/view.jpg" \
--use_hpu_graphs \
--bf16 \
--use_flash_attention \
--flash_attention_recompute
```

Use the following commands to run Llama-3.2-90B-Vision-Instruct BF16 inference with FusedSDPA on 8 HPUs:
```bash
PT_HPU_ENABLE_LAZY_COLLECTIVES=true python ../gaudi_spawn.py --use_deepspeed --world_size 8 run_pipeline.py \
--model_name_or_path meta-llama/Llama-3.2-90B-Vision-Instruct \
--image_path "https://llava-vl.github.io/static/images/view.jpg" \
--use_hpu_graphs \
--bf16 \
--use_flash_attention \
--flash_attention_recompute
```
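
The collapsed portions of this README diff are not shown here. As an illustrative sketch only (not text taken from the commit), a single-card BF16 run of the newly validated tiiuae/falcon-11B-vlm could reuse the same flags as the examples above:
```bash
python run_pipeline.py \
    --model_name_or_path tiiuae/falcon-11B-vlm \
    --image_path "https://llava-vl.github.io/static/images/view.jpg" \
    --use_hpu_graphs \
    --bf16 \
    --use_flash_attention \
    --flash_attention_recompute
```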


### FP8 Inference with FusedSDPA on 8 HPUs

Use the following commands to run Llava-v1.6-mistral-7b FP8 inference with FusedSDPA on 8 HPUs.
Here is an example of measuring the tensor quantization statistics on Llava-v1.6-mistral-7b on 8 HPUs:
```bash
QUANT_CONFIG=./quantization_config/maxabs_measure.json PT_HPU_ENABLE_LAZY_COLLECTIVES=true python ../gaudi_spawn.py --use_deepspeed --world_size 8 run_pipeline.py \
--model_name_or_path llava-hf/llava-v1.6-mistral-7b-hf \
--image_path "https://llava-vl.github.io/static/images/view.jpg" \
--use_hpu_graphs \
--bf16 \
--use_flash_attention \
--flash_attention_recompute
QUANT_CONFIG=./quantization_config/maxabs_measure.json PT_HPU_ENABLE_LAZY_COLLECTIVES=true python ../gaudi_spawn.py \
--use_deepspeed --world_size 8 run_pipeline.py \
--model_name_or_path llava-hf/llava-v1.6-mistral-7b-hf \
--image_path "https://llava-vl.github.io/static/images/view.jpg" \
--use_hpu_graphs \
--bf16 \
--use_flash_attention \
--flash_attention_recompute
```

Here is an example of quantizing the model based on previous measurements for Llava-v1.6-mistral-7b on 8 HPUs:
```bash
QUANT_CONFIG=./quantization_config/maxabs_quant.json PT_HPU_ENABLE_LAZY_COLLECTIVES=true python ../gaudi_spawn.py --use_deepspeed --world_size 8 run_pipeline.py \
--model_name_or_path llava-hf/llava-v1.6-mistral-7b-hf \
--image_path "https://llava-vl.github.io/static/images/view.jpg" \
--use_hpu_graphs \
--bf16 \
--use_flash_attention \
--flash_attention_recompute
QUANT_CONFIG=./quantization_config/maxabs_quant.json PT_HPU_ENABLE_LAZY_COLLECTIVES=true python ../gaudi_spawn.py \
--use_deepspeed --world_size 8 run_pipeline.py \
--model_name_or_path llava-hf/llava-v1.6-mistral-7b-hf \
--image_path "https://llava-vl.github.io/static/images/view.jpg" \
--use_hpu_graphs \
--bf16 \
--use_flash_attention \
--flash_attention_recompute
```
55 changes: 40 additions & 15 deletions examples/image-to-text/run_pipeline.py
@@ -166,6 +166,11 @@ def main():
action="store_true",
help="Whether to enable Habana Flash Attention in recompute mode on first token generation. This gives an opportunity of splitting graph internally which helps reduce memory consumption.",
)
parser.add_argument(
"--limit_hpu_graphs",
action="store_true",
help="Whether to Skip HPU Graph usage for first token to save memory",
)
parser.add_argument(
"--use_kv_cache",
action="store_true",
@@ -184,7 +189,8 @@ def main():
os.environ.setdefault("EXPERIMENTAL_WEIGHT_SHARING", "FALSE")
adapt_transformers_to_gaudi()

model_type = AutoConfig.from_pretrained(args.model_name_or_path).model_type
config = AutoConfig.from_pretrained(args.model_name_or_path)
model_type = config.model_type
if args.image_path is None and model_type in ["llava", "idefics2", "mllama"]:
args.image_path = ["https://llava-vl.github.io/static/images/view.jpg"]
elif args.image_path is None and model_type == "paligemma":
@@ -196,21 +202,36 @@ def main():
"https://github.com/haotian-liu/LLaVA/blob/1a91fc274d7c35a9b50b3cb29c4247ae5837ce39/images/llava_v1_5_radar.jpg?raw=true"
]

if args.prompt is None and model_type in ["llava", "idefics2", "llava_next", "mllama", "paligemma"]:
if model_type in ["llava", "idefics2", "llava_next", "mllama", "paligemma"]:
processor = AutoProcessor.from_pretrained(args.model_name_or_path)
conversation = [
{
"role": "user",
"content": [
{"type": "text", "text": "What is shown in this image?"},
{"type": "image"},
],
}
]
if model_type == "paligemma":
args.prompt = "caption es"
else:
args.prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
if args.prompt is None:
if processor.chat_template is not None:
conversation = [
{
"role": "user",
"content": [
{"type": "text", "text": "What is shown in this image?"},
{"type": "image"},
],
}
]
args.prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
else:
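# No chat template available: fall back to a plain prompt built around the model's image placeholder token.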
image_token_id = None
if hasattr(config, "image_token_id"):
# idefics
image_token_id = config.image_token_id
elif hasattr(config, "image_token_index"):
# mllama/falcon_vlm
image_token_id = config.image_token_index
if image_token_id is None:
image_str = "<image>"
else:
image_str = str(processor.tokenizer.added_tokens_decoder[image_token_id])
if model_type == "paligemma":
args.prompt = "caption es"
else:
args.prompt = f"User:{image_str}\nWhat is shown in this image?\nAssistant:"

image_paths = args.image_path
image_paths_len = len(image_paths)
@@ -268,13 +289,17 @@ def main():

generator.model = wrap_in_hpu_graph(generator.model)

if "falcon-11B-vlm" in args.model_name_or_path:
# Workaround for a Falcon VLM issue where image_token_id equals the embedding table size: grow the embeddings by one so the image token id is in range.
generator.model.resize_token_embeddings(generator.tokenizer.vocab_size + 1)
generate_kwargs = {
"lazy_mode": True,
"hpu_graphs": args.use_hpu_graphs,
"max_new_tokens": args.max_new_tokens,
"ignore_eos": args.ignore_eos,
"use_flash_attention": args.use_flash_attention,
"flash_attention_recompute": args.flash_attention_recompute,
"limit_hpu_graphs": args.limit_hpu_graphs,
}
if args.use_kv_cache:
generate_kwargs["use_cache"] = args.use_kv_cache
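The new `--limit_hpu_graphs` flag is simply forwarded to `generate()` through `generate_kwargs`, skipping HPU graph capture for the memory-heavy first-token step. As a sketch (flags chosen for illustration, not taken from this commit), it can be combined with `--use_hpu_graphs`:
```bash
python run_pipeline.py \
    --model_name_or_path tiiuae/falcon-11B-vlm \
    --image_path "https://llava-vl.github.io/static/images/view.jpg" \
    --use_hpu_graphs \
    --limit_hpu_graphs \
    --bf16
```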
12 changes: 11 additions & 1 deletion optimum/habana/transformers/models/clip/modeling_clip.py
@@ -79,12 +79,14 @@ def forward(
output_attentions: Optional[bool] = False,
use_flash_attention: Optional[bool] = False,
flash_attention_recompute: Optional[bool] = False,
flash_attention_fast_softmax: Optional[bool] = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""
Copied from CLIPAttention.forward: https://github.com/huggingface/transformers/blob/ab0f050b42d903f34d6eb97f3f8c0c07f0517ad2/src/transformers/models/clip/modeling_clip.py
The only differences are:
- add new args use_flash_attention to enable FusedSDPA
- add new args flash_attention_recompute
- add new args flash_attention_fast_softmax
"""
bsz, tgt_len, _ = hidden_states.size()
attn_weights_reshaped = None
@@ -102,9 +104,17 @@ def forward(
if FusedSDPA and use_flash_attention:
import habana_frameworks.torch.hpu as ht

# softmax_mode is passed to FusedSDPA as a string: "fast" selects the faster, lower-precision softmax, "None" keeps the default.
softmax_mode = "fast" if flash_attention_fast_softmax else "None"
with ht.sdp_kernel(enable_recompute=flash_attention_recompute):
attn_output = self.fused_scaled_dot_product_attention(
query_states, key_states, value_states, attention_mask, self.dropout, False, 1, "fast"
query_states,
key_states,
value_states,
attention_mask,
self.dropout,
False,
1,
softmax_mode,
)
else:
attn_weights = self.bmm1(query_states, key_states.transpose(1, 2))
2 changes: 1 addition & 1 deletion optimum/habana/transformers/models/falcon/modeling_falcon.py
@@ -617,7 +617,7 @@ class GaudiFalconDecoderLayer(FalconDecoderLayer):
"""

def __init__(self, config: FalconConfig, layer_idx=None):
super().__init__(config)
super().__init__(config, layer_idx=layer_idx)
self.self_attention = GaudiFalconAttention(config, layer_idx)

def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len):
1 change: 1 addition & 0 deletions optimum/habana/transformers/models/paligemma/modeling_paligemma.py
@@ -48,6 +48,7 @@ def forward(
return_dict: Optional[bool] = None,
num_logits_to_keep: int = 0,
token_idx: Optional[torch.Tensor] = None,
**kwargs,
) -> Union[Tuple, PaliGemmaCausalLMOutputWithPast]:
"""
Inherits from PaliGemmaForConditionalGeneration::forward https://github.com/huggingface/transformers/blob/v4.45.1/src/transformers/models/paligemma/modeling_paligemma.py#L402
Expand Down
1 change: 1 addition & 0 deletions tests/test_image_to_text_example.py
@@ -22,6 +22,7 @@
("google/paligemma-3b-mix-224", 1, 132.8949150246155),
("HuggingFaceM4/idefics2-8b", 1, 21.89944593215077),
("meta-llama/Llama-3.2-11B-Vision-Instruct", 1, 20.407843538649303),
("tiiuae/falcon-11B-vlm", 1, 27.0566558689559327),
],
"fp8": [
("llava-hf/llava-1.5-7b-hf", 1, 98.72578382705062),
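With the baseline above in place, the new model can be exercised locally through a standard pytest selection; this is a sketch of a typical invocation, not a command taken from the repository docs:
```bash
python -m pytest tests/test_image_to_text_example.py -v -s -k falcon
```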
