enable tiiuae/falcon-11B-vlm in image_to_text example. fix the incorr… #1490

Merged · 4 commits · Nov 28, 2024
56 changes: 36 additions & 20 deletions examples/image-to-text/README.md
@@ -32,6 +32,8 @@ Models that have been validated:
- [llava-hf/llama3-llava-next-8b-hf](https://huggingface.co/llava-hf/llama3-llava-next-8b-hf)
- [HuggingFaceM4/idefics2-8b](https://huggingface.co/HuggingFaceM4/idefics2-8b)
- [meta-llama/Llama-3.2-11B-Vision-Instruct](https://huggingface.co/meta-llama/Llama-3.2-11B-Vision-Instruct)
- [meta-llama/Llama-3.2-90B-Vision-Instruct](https://huggingface.co/meta-llama/Llama-3.2-90B-Vision-Instruct)
- [tiiuae/falcon-11B-vlm](https://huggingface.co/tiiuae/falcon-11B-vlm)
- [google/paligemma-3b-mix-224](https://huggingface.co/google/paligemma-3b-mix-224)

### Inference with BF16
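If this PR adds a dedicated single-card command for the new model, it sits in the collapsed part of the diff. As a rough sketch only (flags borrowed from the other examples in this README and from `run_pipeline.py`; the exact command in the full file may differ), a one-HPU BF16 run of tiiuae/falcon-11B-vlm would look like:
```bash
python run_pipeline.py \
    --model_name_or_path tiiuae/falcon-11B-vlm \
    --image_path "https://llava-vl.github.io/static/images/view.jpg" \
    --use_hpu_graphs \
    --bf16 \
    --use_flash_attention \
    --flash_attention_recompute
```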
@@ -381,35 +383,49 @@ To enable multi-card inference, you must set the environment variable `PT_HPU_ENABLE_LAZY_COLLECTIVES`
Use the following commands to run Llava-v1.6-mistral-7b BF16 inference with FusedSDPA on 8 HPUs:
```bash
PT_HPU_ENABLE_LAZY_COLLECTIVES=true python ../gaudi_spawn.py --use_deepspeed --world_size 8 run_pipeline.py \
    --model_name_or_path llava-hf/llava-v1.6-mistral-7b-hf \
    --image_path "https://llava-vl.github.io/static/images/view.jpg" \
    --use_hpu_graphs \
    --bf16 \
    --use_flash_attention \
    --flash_attention_recompute
```

Use the following commands to run Llama-3.2-90B-Vision-Instruct BF16 inference with FusedSDPA on 8 HPUs:
```bash
PT_HPU_ENABLE_LAZY_COLLECTIVES=true python ../gaudi_spawn.py --use_deepspeed --world_size 8 run_pipeline.py \
    --model_name_or_path meta-llama/Llama-3.2-90B-Vision-Instruct \
    --image_path "https://llava-vl.github.io/static/images/view.jpg" \
    --use_hpu_graphs \
    --bf16 \
    --use_flash_attention \
    --flash_attention_recompute
```


### FP8 Inference with FusedSDPA on 8 HPUs

Use the following commands to run Llava-v1.6-mistral-7b FP8 inference with FusedSDPA on 8 HPUs.
Here is an example of measuring the tensor quantization statistics on Llava-v1.6-mistral-7b on 8 HPUs:
```bash
QUANT_CONFIG=./quantization_config/maxabs_measure.json PT_HPU_ENABLE_LAZY_COLLECTIVES=true python ../gaudi_spawn.py \
    --use_deepspeed --world_size 8 run_pipeline.py \
    --model_name_or_path llava-hf/llava-v1.6-mistral-7b-hf \
    --image_path "https://llava-vl.github.io/static/images/view.jpg" \
    --use_hpu_graphs \
    --bf16 \
    --use_flash_attention \
    --flash_attention_recompute
```

Here is an example of quantizing the model based on previous measurements for Llava-v1.6-mistral-7b on 8 HPUs:
```bash
QUANT_CONFIG=./quantization_config/maxabs_quant.json PT_HPU_ENABLE_LAZY_COLLECTIVES=true python ../gaudi_spawn.py \
    --use_deepspeed --world_size 8 run_pipeline.py \
    --model_name_or_path llava-hf/llava-v1.6-mistral-7b-hf \
    --image_path "https://llava-vl.github.io/static/images/view.jpg" \
    --use_hpu_graphs \
    --bf16 \
    --use_flash_attention \
    --flash_attention_recompute
```
55 changes: 40 additions & 15 deletions examples/image-to-text/run_pipeline.py
@@ -166,6 +166,11 @@ def main():
action="store_true",
help="Whether to enable Habana Flash Attention in recompute mode on first token generation. This gives an opportunity of splitting graph internally which helps reduce memory consumption.",
)
parser.add_argument(
"--limit_hpu_graphs",
action="store_true",
help="Whether to Skip HPU Graph usage for first token to save memory",
)
parser.add_argument(
"--use_kv_cache",
action="store_true",
@@ -184,7 +189,8 @@
    os.environ.setdefault("EXPERIMENTAL_WEIGHT_SHARING", "FALSE")
    adapt_transformers_to_gaudi()

    config = AutoConfig.from_pretrained(args.model_name_or_path)
    model_type = config.model_type
    if args.image_path is None and model_type in ["llava", "idefics2", "mllama"]:
        args.image_path = ["https://llava-vl.github.io/static/images/view.jpg"]
    elif args.image_path is None and model_type == "paligemma":
@@ -196,21 +202,36 @@
"https://github.com/haotian-liu/LLaVA/blob/1a91fc274d7c35a9b50b3cb29c4247ae5837ce39/images/llava_v1_5_radar.jpg?raw=true"
]

if args.prompt is None and model_type in ["llava", "idefics2", "llava_next", "mllama", "paligemma"]:
if model_type in ["llava", "idefics2", "llava_next", "mllama", "paligemma"]:
processor = AutoProcessor.from_pretrained(args.model_name_or_path)
conversation = [
{
"role": "user",
"content": [
{"type": "text", "text": "What is shown in this image?"},
{"type": "image"},
],
}
]
if model_type == "paligemma":
args.prompt = "caption es"
else:
args.prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
if args.prompt is None:
if processor.chat_template is not None:
conversation = [
{
"role": "user",
"content": [
{"type": "text", "text": "What is shown in this image?"},
{"type": "image"},
],
}
]
args.prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
else:
image_token_id = None
if hasattr(config, "image_token_id"):
# idefics
image_token_id = config.image_token_id
elif hasattr(config, "image_token_index"):
# mllama/falcon_vlm
image_token_id = config.image_token_index
if image_token_id is None:
image_str = "<image>"
else:
image_str = str(processor.tokenizer.added_tokens_decoder[image_token_id])
if model_type == "paligemma":
args.prompt = "caption es"
else:
args.prompt = f"User:{image_str}\nWhat is shown in this image?\nAssistant:"

    image_paths = args.image_path
    image_paths_len = len(image_paths)
@@ -268,13 +289,17 @@ def main():

        generator.model = wrap_in_hpu_graph(generator.model)

    if "falcon-11B-vlm" in args.model_name_or_path:
        # WA falcon vlm issue that image_token_id == embed size.
        generator.model.resize_token_embeddings(generator.tokenizer.vocab_size + 1)
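    # Illustrative sketch, not part of the diff: the same workaround written as an explicit guard,
    # following the comment above (the image token id equals the current embedding-table size, so
    # growing the table by one row makes that id a valid index):
    #
    #   embed_size = generator.model.get_input_embeddings().num_embeddings
    #   if getattr(config, "image_token_index", -1) >= embed_size:
    #       generator.model.resize_token_embeddings(embed_size + 1)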
    generate_kwargs = {
        "lazy_mode": True,
        "hpu_graphs": args.use_hpu_graphs,
        "max_new_tokens": args.max_new_tokens,
        "ignore_eos": args.ignore_eos,
        "use_flash_attention": args.use_flash_attention,
        "flash_attention_recompute": args.flash_attention_recompute,
        "limit_hpu_graphs": args.limit_hpu_graphs,
    }
    if args.use_kv_cache:
        generate_kwargs["use_cache"] = args.use_kv_cache
12 changes: 11 additions & 1 deletion optimum/habana/transformers/models/clip/modeling_clip.py
@@ -79,12 +79,14 @@ def forward(
        output_attentions: Optional[bool] = False,
        use_flash_attention: Optional[bool] = False,
        flash_attention_recompute: Optional[bool] = False,
        flash_attention_fast_softmax: Optional[bool] = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Copied from CLIPAttention.forward: https://github.com/huggingface/transformers/blob/ab0f050b42d903f34d6eb97f3f8c0c07f0517ad2/src/transformers/models/clip/modeling_clip.py
        The only differences are:
        - add new args use_flash_attention to enable FusedSDPA
        - add new args flash_attention_recompute
        - add new args flash_attention_fast_softmax
        """
        bsz, tgt_len, _ = hidden_states.size()
        attn_weights_reshaped = None
@@ -102,9 +104,17 @@ def forward(
        if FusedSDPA and use_flash_attention:
            import habana_frameworks.torch.hpu as ht

            # FusedSDPA's softmax_mode is a string: "fast" enables the fast-softmax variant, otherwise the literal "None" is passed.
            softmax_mode = "fast" if flash_attention_fast_softmax else "None"
            with ht.sdp_kernel(enable_recompute=flash_attention_recompute):
                attn_output = self.fused_scaled_dot_product_attention(
                    query_states,
                    key_states,
                    value_states,
                    attention_mask,
                    self.dropout,
                    False,
                    1,
                    softmax_mode,
                )
        else:
            attn_weights = self.bmm1(query_states, key_states.transpose(1, 2))
optimum/habana/transformers/models/falcon/modeling_falcon.py
@@ -617,7 +617,7 @@ class GaudiFalconDecoderLayer(FalconDecoderLayer):
"""

def __init__(self, config: FalconConfig, layer_idx=None):
super().__init__(config)
super().__init__(config, layer_idx=layer_idx)
self.self_attention = GaudiFalconAttention(config, layer_idx)

def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len):
optimum/habana/transformers/models/paligemma/modeling_paligemma.py
@@ -48,6 +48,7 @@ def forward(
        return_dict: Optional[bool] = None,
        num_logits_to_keep: int = 0,
        token_idx: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Union[Tuple, PaliGemmaCausalLMOutputWithPast]:
        """
        Inherits from PaliGemmaForConditionalGeneration::forward https://github.com/huggingface/transformers/blob/v4.45.1/src/transformers/models/paligemma/modeling_paligemma.py#L402
1 change: 1 addition & 0 deletions tests/test_image_to_text_example.py
@@ -22,6 +22,7 @@
("google/paligemma-3b-mix-224", 1, 132.8949150246155),
("HuggingFaceM4/idefics2-8b", 1, 21.89944593215077),
("meta-llama/Llama-3.2-11B-Vision-Instruct", 1, 20.407843538649303),
("tiiuae/falcon-11B-vlm", 1, 27.0566558689559327),
],
"fp8": [
("llava-hf/llava-1.5-7b-hf", 1, 98.72578382705062),