From cd3743869964684475368632670838c0c4570e00 Mon Sep 17 00:00:00 2001 From: "Wang, Yi A" Date: Thu, 14 Nov 2024 18:19:29 -0800 Subject: [PATCH 1/3] enable tiiuae/falcon-11B-vlm in image_to_text example. fix the incorrect output Signed-off-by: Wang, Yi A --- examples/image-to-text/README.md | 56 ++++++++++++------- examples/image-to-text/run_pipeline.py | 5 +- .../transformers/models/clip/modeling_clip.py | 12 +++- .../models/falcon/modeling_falcon.py | 2 +- .../models/llava_next/modeling_llava_next.py | 5 +- tests/test_image_to_text_example.py | 1 + 6 files changed, 57 insertions(+), 24 deletions(-) diff --git a/examples/image-to-text/README.md b/examples/image-to-text/README.md index 5916de4a29..d967e954ad 100644 --- a/examples/image-to-text/README.md +++ b/examples/image-to-text/README.md @@ -32,6 +32,8 @@ Models that have been validated: - [llava-hf/llama3-llava-next-8b-hf](https://huggingface.co/llava-hf/llama3-llava-next-8b-hf) - [HuggingFaceM4/idefics2-8b](https://huggingface.co/HuggingFaceM4/idefics2-8b) - [meta-llama/Llama-3.2-11B-Vision-Instruct](https://huggingface.co/meta-llama/Llama-3.2-11B-Vision-Instruct) + - [meta-llama/Llama-3.2-90B-Vision-Instruct](https://huggingface.co/meta-llama/Llama-3.2-90B-Vision-Instruct) + - [tiiuae/falcon-11B-vlm](https://huggingface.co/tiiuae/falcon-11B-vlm) ### Inference with BF16 @@ -374,35 +376,49 @@ To enable multi-card inference, you must set the environment variable `PT_HPU_EN Use the following commands to run Llava-v1.6-mistral-7b BF16 inference with FusedSDPA on 8 HPUs: ```bash PT_HPU_ENABLE_LAZY_COLLECTIVES=true python ../gaudi_spawn.py --use_deepspeed --world_size 8 run_pipeline.py \ ---model_name_or_path llava-hf/llava-v1.6-mistral-7b-hf \ ---image_path "https://llava-vl.github.io/static/images/view.jpg" \ ---use_hpu_graphs \ ---bf16 \ ---use_flash_attention \ ---flash_attention_recompute + --model_name_or_path llava-hf/llava-v1.6-mistral-7b-hf \ + --image_path "https://llava-vl.github.io/static/images/view.jpg" \ + --use_hpu_graphs \ + --bf16 \ + --use_flash_attention \ + --flash_attention_recompute +``` + +Use the following commands to run Llama-3.2-90B-Vision-Instruct BF16 inference with FusedSDPA on 8 HPUs: +```bash +PT_HPU_ENABLE_LAZY_COLLECTIVES=true python ../gaudi_spawn.py --use_deepspeed --world_size 8 run_pipeline.py \ + --model_name_or_path meta-llama/Llama-3.2-90B-Vision-Instruct \ + --image_path "https://llava-vl.github.io/static/images/view.jpg" \ + --use_hpu_graphs \ + --bf16 \ + --use_flash_attention \ + --flash_attention_recompute ``` + ### FP8 Inference with FusedSDPA on 8 HPUs Use the following commands to run Llava-v1.6-mistral-7b FP8 inference with FusedSDPA on 8 HPUs. 
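FP8 inference is a two-step flow: a calibration run first records per-tensor statistics, and a second run then executes in FP8 using those statistics. The two steps differ only in the `QUANT_CONFIG` file they point to, as condensed in the sketch below (the flags and config paths are taken from the full commands that follow; the `COMMON_ARGS` variable is only shorthand for this sketch):

```bash
# The calibration and quantization steps share the same command line; only QUANT_CONFIG changes.
COMMON_ARGS="--model_name_or_path llava-hf/llava-v1.6-mistral-7b-hf --image_path https://llava-vl.github.io/static/images/view.jpg --use_hpu_graphs --bf16 --use_flash_attention --flash_attention_recompute"

# Step 1: calibration run, records maxabs tensor statistics (no FP8 execution yet)
QUANT_CONFIG=./quantization_config/maxabs_measure.json PT_HPU_ENABLE_LAZY_COLLECTIVES=true \
    python ../gaudi_spawn.py --use_deepspeed --world_size 8 run_pipeline.py $COMMON_ARGS

# Step 2: quantized run, executes inference in FP8 using the statistics recorded in step 1
QUANT_CONFIG=./quantization_config/maxabs_quant.json PT_HPU_ENABLE_LAZY_COLLECTIVES=true \
    python ../gaudi_spawn.py --use_deepspeed --world_size 8 run_pipeline.py $COMMON_ARGS
```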
Here is an example of measuring the tensor quantization statistics on Llava-v1.6-mistral-7b on 8 HPUs: ```bash -QUANT_CONFIG=./quantization_config/maxabs_measure.json PT_HPU_ENABLE_LAZY_COLLECTIVES=true python ../gaudi_spawn.py --use_deepspeed --world_size 8 run_pipeline.py \ ---model_name_or_path llava-hf/llava-v1.6-mistral-7b-hf \ ---image_path "https://llava-vl.github.io/static/images/view.jpg" \ ---use_hpu_graphs \ ---bf16 \ ---use_flash_attention \ ---flash_attention_recompute +QUANT_CONFIG=./quantization_config/maxabs_measure.json PT_HPU_ENABLE_LAZY_COLLECTIVES=true python ../gaudi_spawn.py \ + --use_deepspeed --world_size 8 run_pipeline.py \ + --model_name_or_path llava-hf/llava-v1.6-mistral-7b-hf \ + --image_path "https://llava-vl.github.io/static/images/view.jpg" \ + --use_hpu_graphs \ + --bf16 \ + --use_flash_attention \ + --flash_attention_recompute ``` Here is an example of quantizing the model based on previous measurements for Llava-v1.6-mistral-7b on 8 HPUs: ```bash -QUANT_CONFIG=./quantization_config/maxabs_quant.json PT_HPU_ENABLE_LAZY_COLLECTIVES=true python ../gaudi_spawn.py --use_deepspeed --world_size 8 run_pipeline.py \ ---model_name_or_path llava-hf/llava-v1.6-mistral-7b-hf \ ---image_path "https://llava-vl.github.io/static/images/view.jpg" \ ---use_hpu_graphs \ ---bf16 \ ---use_flash_attention \ ---flash_attention_recompute +QUANT_CONFIG=./quantization_config/maxabs_quant.json PT_HPU_ENABLE_LAZY_COLLECTIVES=true python ../gaudi_spawn.py \ + --use_deepspeed --world_size 8 run_pipeline.py \ + --model_name_or_path llava-hf/llava-v1.6-mistral-7b-hf \ + --image_path "https://llava-vl.github.io/static/images/view.jpg" \ + --use_hpu_graphs \ + --bf16 \ + --use_flash_attention \ + --flash_attention_recompute ``` \ No newline at end of file diff --git a/examples/image-to-text/run_pipeline.py b/examples/image-to-text/run_pipeline.py index 75b391ea2e..57539b31d2 100644 --- a/examples/image-to-text/run_pipeline.py +++ b/examples/image-to-text/run_pipeline.py @@ -202,7 +202,10 @@ def main(): ], } ] - args.prompt = processor.apply_chat_template(conversation, add_generation_prompt=True) + if processor.chat_template is not None: + args.prompt = processor.apply_chat_template(conversation, add_generation_prompt=True) + else: + args.prompt = "User:\nWhat is shown in this image?\nAssistant:" image_paths = args.image_path image_paths_len = len(image_paths) diff --git a/optimum/habana/transformers/models/clip/modeling_clip.py b/optimum/habana/transformers/models/clip/modeling_clip.py index ef5d604ec9..9ced56b120 100644 --- a/optimum/habana/transformers/models/clip/modeling_clip.py +++ b/optimum/habana/transformers/models/clip/modeling_clip.py @@ -79,12 +79,14 @@ def forward( output_attentions: Optional[bool] = False, use_flash_attention: Optional[bool] = False, flash_attention_recompute: Optional[bool] = False, + flash_attention_fast_softmax: Optional[bool] = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: """ Copied from CLIPAttention.forward: https://github.com/huggingface/transformers/blob/ab0f050b42d903f34d6eb97f3f8c0c07f0517ad2/src/transformers/models/clip/modeling_clip.py The only differences are: - add new args use_flash_attention to enable FusedSDPA - add new args flash_attention_recompute + - add new args flash_attention_fast_softmax """ bsz, tgt_len, _ = hidden_states.size() attn_weights_reshaped = None @@ -102,9 +104,17 @@ def forward( if FusedSDPA and use_flash_attention: import habana_frameworks.torch.hpu as ht + softmax_mode = "fast" if flash_attention_fast_softmax 
else "None" with ht.sdp_kernel(enable_recompute=flash_attention_recompute): attn_output = self.fused_scaled_dot_product_attention( - query_states, key_states, value_states, attention_mask, self.dropout, False, 1, "fast" + query_states, + key_states, + value_states, + attention_mask, + self.dropout, + False, + 1, + softmax_mode, ) else: attn_weights = self.bmm1(query_states, key_states.transpose(1, 2)) diff --git a/optimum/habana/transformers/models/falcon/modeling_falcon.py b/optimum/habana/transformers/models/falcon/modeling_falcon.py index f4c9d454ab..c54cc0100e 100644 --- a/optimum/habana/transformers/models/falcon/modeling_falcon.py +++ b/optimum/habana/transformers/models/falcon/modeling_falcon.py @@ -615,7 +615,7 @@ class GaudiFalconDecoderLayer(FalconDecoderLayer): """ def __init__(self, config: FalconConfig, layer_idx=None): - super().__init__(config) + super().__init__(config, layer_idx=layer_idx) self.self_attention = GaudiFalconAttention(config, layer_idx) def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len): diff --git a/optimum/habana/transformers/models/llava_next/modeling_llava_next.py b/optimum/habana/transformers/models/llava_next/modeling_llava_next.py index 6cf728d014..69738ea09e 100644 --- a/optimum/habana/transformers/models/llava_next/modeling_llava_next.py +++ b/optimum/habana/transformers/models/llava_next/modeling_llava_next.py @@ -278,7 +278,10 @@ def prepare_inputs_for_generation( ) # 1. Extract the input embeddings - inputs_embeds = self.get_input_embeddings()(input_ids) + input_ids_mask = input_ids.clone() + input_ids_mask[input_ids == self.config.image_token_index] = 0 + inputs_embeds = self.get_input_embeddings()(input_ids_mask) + # 2. Merge text and images batch_size, num_patches, num_channels, height, width = pixel_values.shape reshaped_pixel_values = pixel_values.view(batch_size * num_patches, num_channels, height, width) diff --git a/tests/test_image_to_text_example.py b/tests/test_image_to_text_example.py index 60049bf46e..66fdefadc7 100644 --- a/tests/test_image_to_text_example.py +++ b/tests/test_image_to_text_example.py @@ -21,6 +21,7 @@ ("llava-hf/llava-v1.6-vicuna-13b-hf", 1, 23.527610042925), ("HuggingFaceM4/idefics2-8b", 1, 21.89944593215077), ("meta-llama/Llama-3.2-11B-Vision-Instruct", 1, 20.407843538649303), + ("tiiuae/falcon-11B-vlm", 1, 27.0566558689559327), ], "fp8": [ ("llava-hf/llava-1.5-7b-hf", 1, 98.72578382705062), From 306a46ae77d77cef67e18399485209b0f2f6c37f Mon Sep 17 00:00:00 2001 From: "Wang, Yi A" Date: Mon, 18 Nov 2024 21:56:17 -0800 Subject: [PATCH 2/3] fix issue if there's no chat template Signed-off-by: Wang, Yi A --- examples/image-to-text/run_pipeline.py | 46 +++++++++++++------ .../models/llava_next/modeling_llava_next.py | 5 +- 2 files changed, 32 insertions(+), 19 deletions(-) diff --git a/examples/image-to-text/run_pipeline.py b/examples/image-to-text/run_pipeline.py index 57539b31d2..9d04e0582a 100644 --- a/examples/image-to-text/run_pipeline.py +++ b/examples/image-to-text/run_pipeline.py @@ -184,28 +184,41 @@ def main(): os.environ.setdefault("EXPERIMENTAL_WEIGHT_SHARING", "FALSE") adapt_transformers_to_gaudi() - model_type = AutoConfig.from_pretrained(args.model_name_or_path).model_type + config = AutoConfig.from_pretrained(args.model_name_or_path) + model_type = config.model_type if args.image_path is None and model_type in ["llava", "idefics2", "mllama"]: args.image_path = ["https://llava-vl.github.io/static/images/view.jpg"] elif args.image_path is None and model_type == "llava_next": args.image_path 
= [ "https://github.com/haotian-liu/LLaVA/blob/1a91fc274d7c35a9b50b3cb29c4247ae5837ce39/images/llava_v1_5_radar.jpg?raw=true" ] - if args.prompt is None and model_type in ["llava", "idefics2", "llava_next", "mllama"]: + if model_type in ["llava", "idefics2", "llava_next", "mllama"]: processor = AutoProcessor.from_pretrained(args.model_name_or_path) - conversation = [ - { - "role": "user", - "content": [ - {"type": "text", "text": "What is shown in this image?"}, - {"type": "image"}, - ], - } - ] - if processor.chat_template is not None: - args.prompt = processor.apply_chat_template(conversation, add_generation_prompt=True) - else: - args.prompt = "User:\nWhat is shown in this image?\nAssistant:" + if args.prompt is None: + if processor.chat_template is not None: + conversation = [ + { + "role": "user", + "content": [ + {"type": "text", "text": "What is shown in this image?"}, + {"type": "image"}, + ], + } + ] + args.prompt = processor.apply_chat_template(conversation, add_generation_prompt=True) + else: + image_token_id = None + if hasattr(config, "image_token_id"): + # idefics + image_token_id = config.image_token_id + elif hasattr(config, "image_token_index"): + # mllama/falcon_vlm + image_token_id = config.image_token_index + if image_token_id is None: + image_str = "" + else: + image_str = str(processor.tokenizer.added_tokens_decoder[image_token_id]) + args.prompt = f"User:{image_str}\nWhat is shown in this image?\nAssistant:" image_paths = args.image_path image_paths_len = len(image_paths) @@ -263,6 +276,9 @@ def main(): generator.model = wrap_in_hpu_graph(generator.model) + if "falcon-11B-vlm" in args.model_name_or_path: + # WA falcon vlm issue that image_token_id == embed size. + generator.model.resize_token_embeddings(generator.tokenizer.vocab_size + 1) generate_kwargs = { "lazy_mode": True, "hpu_graphs": args.use_hpu_graphs, diff --git a/optimum/habana/transformers/models/llava_next/modeling_llava_next.py b/optimum/habana/transformers/models/llava_next/modeling_llava_next.py index 69738ea09e..6cf728d014 100644 --- a/optimum/habana/transformers/models/llava_next/modeling_llava_next.py +++ b/optimum/habana/transformers/models/llava_next/modeling_llava_next.py @@ -278,10 +278,7 @@ def prepare_inputs_for_generation( ) # 1. Extract the input embeddings - input_ids_mask = input_ids.clone() - input_ids_mask[input_ids == self.config.image_token_index] = 0 - inputs_embeds = self.get_input_embeddings()(input_ids_mask) - + inputs_embeds = self.get_input_embeddings()(input_ids) # 2. Merge text and images batch_size, num_patches, num_channels, height, width = pixel_values.shape reshaped_pixel_values = pixel_values.view(batch_size * num_patches, num_channels, height, width) From 36aa0c5ebff84842fcb65c469b733b57b31bf386 Mon Sep 17 00:00:00 2001 From: "Wang, Yi A" Date: Tue, 26 Nov 2024 19:41:42 -0800 Subject: [PATCH 3/3] add limit_hpu_graphs in image-to-text Signed-off-by: Wang, Yi A --- examples/image-to-text/run_pipeline.py | 6 ++++++ .../transformers/models/paligemma/modeling_paligemma.py | 1 + 2 files changed, 7 insertions(+) diff --git a/examples/image-to-text/run_pipeline.py b/examples/image-to-text/run_pipeline.py index 54aec2b88b..ee584363ca 100644 --- a/examples/image-to-text/run_pipeline.py +++ b/examples/image-to-text/run_pipeline.py @@ -166,6 +166,11 @@ def main(): action="store_true", help="Whether to enable Habana Flash Attention in recompute mode on first token generation. 
This gives an opportunity of splitting graph internally which helps reduce memory consumption.", ) + parser.add_argument( + "--limit_hpu_graphs", + action="store_true", + help="Whether to Skip HPU Graph usage for first token to save memory", + ) parser.add_argument( "--use_kv_cache", action="store_true", @@ -294,6 +299,7 @@ def main(): "ignore_eos": args.ignore_eos, "use_flash_attention": args.use_flash_attention, "flash_attention_recompute": args.flash_attention_recompute, + "limit_hpu_graphs": args.limit_hpu_graphs, } if args.use_kv_cache: generate_kwargs["use_cache"] = args.use_kv_cache diff --git a/optimum/habana/transformers/models/paligemma/modeling_paligemma.py b/optimum/habana/transformers/models/paligemma/modeling_paligemma.py index 596d99ea82..84d5014135 100644 --- a/optimum/habana/transformers/models/paligemma/modeling_paligemma.py +++ b/optimum/habana/transformers/models/paligemma/modeling_paligemma.py @@ -48,6 +48,7 @@ def forward( return_dict: Optional[bool] = None, num_logits_to_keep: int = 0, token_idx: Optional[torch.Tensor] = None, + **kwargs, ) -> Union[Tuple, PaliGemmaCausalLMOutputWithPast]: """ Inherits from PaliGemmaForConditionalGeneration::forward https://github.com/huggingface/transformers/blob/v4.45.1/src/transformers/models/paligemma/modeling_paligemma.py#L402