From 4fb058ad0283786227e8f013d18006c742e5214c Mon Sep 17 00:00:00 2001 From: "Wang, Yi A" Date: Sun, 18 Aug 2024 18:49:39 -0700 Subject: [PATCH 01/12] add idefics2 Signed-off-by: Wang, Yi A --- .../habana/transformers/generation/utils.py | 1 + optimum/habana/transformers/modeling_utils.py | 8 + .../habana/transformers/models/__init__.py | 1 + .../transformers/models/idefics2/__init__.py | 1 + .../models/idefics2/modeling_idefics2.py | 340 ++++++++++++++++++ 5 files changed, 351 insertions(+) create mode 100644 optimum/habana/transformers/models/idefics2/__init__.py create mode 100644 optimum/habana/transformers/models/idefics2/modeling_idefics2.py diff --git a/optimum/habana/transformers/generation/utils.py b/optimum/habana/transformers/generation/utils.py index 284f646a48..03b0f00a1c 100755 --- a/optimum/habana/transformers/generation/utils.py +++ b/optimum/habana/transformers/generation/utils.py @@ -105,6 +105,7 @@ "stablelm", "mamba", "deci", + "idefics2", ] diff --git a/optimum/habana/transformers/modeling_utils.py b/optimum/habana/transformers/modeling_utils.py index 621e391bfb..41007b62bd 100644 --- a/optimum/habana/transformers/modeling_utils.py +++ b/optimum/habana/transformers/modeling_utils.py @@ -56,6 +56,8 @@ GaudiGPTJModel, GaudiGPTNeoXForCausalLM, GaudiGPTNeoXLayer, + GaudiIdefics2ForConditionalGeneration, + GaudiIdefics2Model, GaudiLlamaAttention, GaudiLlamaDecoderLayer, GaudiLlamaDynamicNTKScalingRotaryEmbedding, @@ -380,6 +382,12 @@ def adapt_transformers_to_gaudi(): GaudiLlavaNextForConditionalGeneration ) + # Optimization for idefics2 on Gaudi + transformers.models.idefics2.modeling_idefics2.Idefics2ForConditionalGeneration = ( + GaudiIdefics2ForConditionalGeneration + ) + transformers.models.idefics2.modeling_idefics2.Idefics2Model = GaudiIdefics2Model + # Optimization for Clip on Gaudi transformers.models.clip.modeling_clip.CLIPVisionEmbeddings = GaudiCLIPVisionEmbeddings transformers.models.clip.modeling_clip.CLIPAttention = GaudiCLIPAttention diff --git a/optimum/habana/transformers/models/__init__.py b/optimum/habana/transformers/models/__init__.py index 99ef65c4e4..ab61fc05af 100644 --- a/optimum/habana/transformers/models/__init__.py +++ b/optimum/habana/transformers/models/__init__.py @@ -96,6 +96,7 @@ GaudiGPTJForCausalLM, GaudiGPTJModel, ) +from .idefics2 import GaudiIdefics2ForConditionalGeneration, GaudiIdefics2Model from .llama import ( GaudiLlamaAttention, GaudiLlamaDecoderLayer, diff --git a/optimum/habana/transformers/models/idefics2/__init__.py b/optimum/habana/transformers/models/idefics2/__init__.py new file mode 100644 index 0000000000..a862776ad9 --- /dev/null +++ b/optimum/habana/transformers/models/idefics2/__init__.py @@ -0,0 +1 @@ +from .modeling_idefics2 import GaudiIdefics2ForConditionalGeneration, GaudiIdefics2Model diff --git a/optimum/habana/transformers/models/idefics2/modeling_idefics2.py b/optimum/habana/transformers/models/idefics2/modeling_idefics2.py new file mode 100644 index 0000000000..aeb1825540 --- /dev/null +++ b/optimum/habana/transformers/models/idefics2/modeling_idefics2.py @@ -0,0 +1,340 @@ +# coding=utf-8 +# Copyright 2024 the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch Idefics2 model.""" + +from typing import List, Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch.nn import CrossEntropyLoss +from transformers.cache_utils import Cache, DynamicCache +from transformers.models.idefics2.modeling_idefics2 import ( + Idefics2BaseModelOutputWithPast, + Idefics2CausalLMOutputWithPast, + Idefics2ForConditionalGeneration, + Idefics2Model, +) +from transformers.utils import logging + + +logger = logging.get_logger(__name__) + + +class GaudiIdefics2Model(Idefics2Model): + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + pixel_attention_mask: Optional[torch.BoolTensor] = None, + image_hidden_states: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + token_idx: Optional[torch.Tensor] = None, + ) -> Union[Tuple, Idefics2BaseModelOutputWithPast]: + """ + Inherits from Idefics2Model::forward https://github.com/huggingface/transformers/blob/v4.43.4/src/transformers/models/idefics2/modeling_idefics2.py#L1303 + The only differences are: + - add new args token_idx + - ignoring new Cache path for HPU + - unfold is not supported in HPU, fallback to cpu + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if self.training and self.text_model.gradient_checkpointing and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
+ ) + use_cache = False + + # retrieve input_ids and inputs_embeds + if input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + past_seen_tokens = 0 + return_legacy_cache = True + use_new_cache = False # Ignoring new Cache path for HPU + if use_cache and use_new_cache: + if not isinstance(past_key_values, Cache): # kept for BC (non `Cache` `past_key_values` inputs) + return_legacy_cache = True + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + past_seen_tokens = past_key_values.get_seq_length() + else: + if past_key_values is not None: + past_seen_tokens = past_key_values[0][0].shape[2] + if inputs_embeds is not None and input_ids is None and past_seen_tokens == 0: + raise ValueError("When first calling the model, if input_embeds are passed, input_ids should not be None.") + + if inputs_embeds is None: + inputs_embeds = self.text_model.get_input_embeddings()(input_ids) + + # START VISUAL INPUTS INTEGRATION + if pixel_values is not None and image_hidden_states is not None: + raise ValueError("You cannot specify both pixel_values and image_hidden_states at the same time") + elif pixel_values is not None: + batch_size, num_images, num_channels, height, width = pixel_values.shape + pixel_values = pixel_values.to(dtype=self.dtype) # fp16 compatibility + pixel_values = pixel_values.view(batch_size * num_images, *pixel_values.shape[2:]) + + # Remove padding images - padding images are full 0. + nb_values_per_image = pixel_values.shape[1:].numel() + real_images_inds = (pixel_values == 0.0).sum(dim=(-1, -2, -3)) != nb_values_per_image + pixel_values = pixel_values[real_images_inds].contiguous() + + # Handle the vision attention mask + if pixel_attention_mask is None: + pixel_attention_mask = torch.ones( + size=(pixel_values.size(0), pixel_values.size(2), pixel_values.size(3)), + dtype=torch.bool, + device=pixel_values.device, + ) + else: + # Remove padding images from the mask/pP p + pixel_attention_mask = pixel_attention_mask.view( + batch_size * num_images, *pixel_attention_mask.shape[2:] + ) + pixel_attention_mask = pixel_attention_mask[real_images_inds].contiguous() + + patch_size = self.config.vision_config.patch_size + patches_subgrid = pixel_attention_mask.cpu().unfold(dimension=1, size=patch_size, step=patch_size) + patches_subgrid = patches_subgrid.cpu().unfold(dimension=2, size=patch_size, step=patch_size) + patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool() + + # Get sequence from the vision encoder + image_hidden_states = self.vision_model( + pixel_values=pixel_values, + patch_attention_mask=patch_attention_mask, + ).last_hidden_state + + # Modality projection & resampling + image_hidden_states = self.connector( + image_hidden_states, attention_mask=patch_attention_mask.view(pixel_values.size(0), -1) + ) + + elif image_hidden_states is not None: + image_hidden_states = image_hidden_states.to(dtype=self.dtype, device=input_ids.device) + + if past_seen_tokens == 0 and inputs_embeds is not None and image_hidden_states is not None: + # When we generate, we don't want to replace the potential image_token_id that we generated by images + # that simply don't exist + inputs_embeds = self.inputs_merger( + input_ids=input_ids, + inputs_embeds=inputs_embeds, + image_hidden_states=image_hidden_states, + ) + + outputs = self.text_model( + inputs_embeds=inputs_embeds, + 
attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + token_idx=token_idx, + ) + + if return_legacy_cache and use_cache: + outputs.past_key_values = ( + outputs.past_key_values.to_legacy_cache() + if isinstance(outputs.past_key_values, Cache) + else outputs.past_key_values + ) + + if not return_dict: + return tuple(v for v in [*outputs, image_hidden_states] if v is not None) + + return Idefics2BaseModelOutputWithPast( + last_hidden_state=outputs.last_hidden_state, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + image_hidden_states=image_hidden_states, + ) + + +class GaudiIdefics2ForConditionalGeneration(Idefics2ForConditionalGeneration): + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + pixel_attention_mask: Optional[torch.BoolTensor] = None, + image_hidden_states: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + token_idx: Optional[torch.Tensor] = None, + ) -> Union[Tuple, Idefics2CausalLMOutputWithPast]: + """ + Inherits from Idefics2ForConditionalGeneration::forward https://github.com/huggingface/transformers/blob/v4.43.4/src/transformers/models/idefics2/modeling_idefics2.py#L1505 + The only differences are: + - add new args token_idx + """ + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + pixel_values=pixel_values, + pixel_attention_mask=pixel_attention_mask, + image_hidden_states=image_hidden_states, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + token_idx=token_idx, + ) + + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + logits = logits.float() + + loss = None + if labels is not None: + labels = labels.to(logits.device) + # Shift so that tokens < n predict n + if attention_mask is not None: + shift_attention_mask = attention_mask[..., 1:].to(logits.device) + shift_logits = logits[..., :-1, :][shift_attention_mask != 0].contiguous() + shift_labels = labels[..., 1:][shift_attention_mask != 0].contiguous() + else: + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + 
return Idefics2CausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ image_hidden_states=outputs.image_hidden_states,
+ )
+
+ def prepare_inputs_for_generation(
+ self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
+ ):
+ """
+ Inherits from Idefics2ForConditionalGeneration::prepare_inputs_for_generation https://github.com/huggingface/transformers/blob/v4.43.4/src/transformers/models/idefics2/modeling_idefics2.py#L1622
+ The only differences are:
+ - add new args token_idx
+ - add support for None / legacy (non-`Cache`) past_key_values
+ """
+ past_length = 0
+ token_idx = kwargs.get("token_idx", None)
+ # Omit tokens covered by past_key_values
+ if past_key_values is not None:
+ # Past key values are always initialized with a `Cache` object -> no need for if-else anymore
+ if isinstance(past_key_values, Cache):
+ past_length = past_key_values.get_seq_length()
+ max_cache_length = past_key_values.get_max_length()
+ else:
+ past_length = past_key_values[0][0].shape[2]
+ max_cache_length = None
+
+ # Keep only the unprocessed tokens:
+ # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
+ # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
+ # input)
+ if token_idx is not None:
+ input_ids = torch.index_select(input_ids, 1, token_idx - 1)
+ elif attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
+ input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
+ # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
+ # input_ids based on the past_length.
+ elif past_length < input_ids.shape[1]:
+ input_ids = input_ids[:, past_length:]
+ # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
+
+ # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
+ if ( + max_cache_length is not None + and attention_mask is not None + and past_length + input_ids.shape[1] > max_cache_length + ): + attention_mask = attention_mask[:, -max_cache_length:] + + position_ids = kwargs.get("position_ids", None) + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + if token_idx is not None: + position_ids = torch.index_select(position_ids, 1, token_idx - 1) + else: + position_ids = position_ids[:, -input_ids.shape[1] :] + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_length == 0: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + image_hidden_states = kwargs.get("image_hidden_states", None) + if image_hidden_states is not None: + pixel_values = None + pixel_attention_mask = None + else: + pixel_values = kwargs.get("pixel_values", None) + pixel_attention_mask = kwargs.get("pixel_attention_mask", None) + model_inputs.update( + { + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "attention_mask": attention_mask, + "pixel_values": pixel_values, + "pixel_attention_mask": pixel_attention_mask, + "image_hidden_states": image_hidden_states, + "token_idx": token_idx, + } + ) + return model_inputs From b5beca41f642ebf4ae309303cfafb3c6d6d5cf0b Mon Sep 17 00:00:00 2001 From: "Wang, Yi A" Date: Sun, 18 Aug 2024 22:13:46 -0700 Subject: [PATCH 02/12] add example and test Signed-off-by: Wang, Yi A --- examples/image-to-text/README.md | 6 ++++++ examples/image-to-text/run_pipeline.py | 18 ++++++++++++++++-- tests/test_image_to_text_example.py | 1 + 3 files changed, 23 insertions(+), 2 deletions(-) diff --git a/examples/image-to-text/README.md b/examples/image-to-text/README.md index 0f1a2624d4..dbcd1d7647 100644 --- a/examples/image-to-text/README.md +++ b/examples/image-to-text/README.md @@ -72,6 +72,12 @@ python3 run_pipeline.py \ --bf16 ``` +To run idefics2 inference, use the following command: +python3 run_pipeline.py \ + --model_name_or_path HuggingFaceM4/idefics2-8b \ + --use_hpu_graphs \ + --bf16 + ### Inference with FP8 Inference for Llava-1.5-7b, Llava-1.5-13b, Llava-v1.6-mistral-7b and Llava-v1.6-vicuna-13b in FP8 precision are enabled using the Quantization Toolkit (HQT), which provides model measurement and quantization capabilities in PyTorch. 
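As a rough illustration of what the new README command does for idefics2, here is a minimal standalone sketch (not part of this patch). It assumes a Gaudi device with `habana_frameworks` installed, uses the processor's chat template to insert the image token, and leaves out the HPU-graph / lazy-mode handling that the example script takes care of:

```python
import torch
import habana_frameworks.torch.core as htcore  # noqa: F401  loads the HPU backend
from transformers import AutoModelForVision2Seq, AutoProcessor
from transformers.image_utils import load_image

from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi

adapt_transformers_to_gaudi()  # swaps in the Gaudi-optimized Idefics2 classes added by this patch

model_id = "HuggingFaceM4/idefics2-8b"
processor = AutoProcessor.from_pretrained(model_id)
model = AutoModelForVision2Seq.from_pretrained(model_id, torch_dtype=torch.bfloat16).to("hpu")

image = load_image("https://llava-vl.github.io/static/images/view.jpg")
messages = [
    {
        "role": "user",
        "content": [
            {"type": "image"},
            {"type": "text", "text": "What's the content of the image?"},
        ],
    }
]
prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
inputs = processor(text=prompt, images=[image], return_tensors="pt").to("hpu")

generated_ids = model.generate(**inputs, max_new_tokens=64)
print(processor.batch_decode(generated_ids[:, inputs["input_ids"].shape[1] :], skip_special_tokens=True))
```

The pipeline-based example reaches the same result; the run_pipeline.py diff that follows additionally overrides the pipeline's `preprocess` step so that `AutoProcessor` handles the image-plus-text inputs for idefics2.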
diff --git a/examples/image-to-text/run_pipeline.py b/examples/image-to-text/run_pipeline.py index 239d6fa4e4..5f8cd8b29c 100644 --- a/examples/image-to-text/run_pipeline.py +++ b/examples/image-to-text/run_pipeline.py @@ -105,13 +105,13 @@ def main(): adapt_transformers_to_gaudi() model_type = AutoConfig.from_pretrained(args.model_name_or_path).model_type - if args.image_path is None and model_type == "llava": + if args.image_path is None and model_type in ["llava", "idefics2"]: args.image_path = ["https://llava-vl.github.io/static/images/view.jpg"] elif args.image_path is None and model_type == "llava_next": args.image_path = [ "https://github.com/haotian-liu/LLaVA/blob/1a91fc274d7c35a9b50b3cb29c4247ae5837ce39/images/llava_v1_5_radar.jpg?raw=true" ] - if args.prompt is None and model_type == "llava": + if args.prompt is None and model_type in ["llava", "idefics2"]: args.prompt = "\nUSER: What's the content of the image?\nASSISTANT:" elif args.prompt is None and model_type == "llava_next": args.prompt = "[INST] \nWhat is shown in this image? [/INST]" @@ -169,6 +169,20 @@ def main(): htcore.hpu_initialize(generator.model) + # delete once pipeline integrate AutoProcessor as preprocess engine + if model_type in ["idefics2"]: + from transformers import AutoProcessor + + processor = AutoProcessor.from_pretrained(args.model_name_or_path) + from transformers.image_utils import load_image + + def preprocess(self, image, prompt=None, timeout=None): + image = load_image(image, timeout=timeout) + model_inputs = processor(images=image, text=prompt, return_tensors=self.framework) + return model_inputs + + generator.__class__.preprocess = preprocess + # warm up for i in range(args.warmup): generator(images, prompt=args.prompt, batch_size=args.batch_size, generate_kwargs=generate_kwargs) diff --git a/tests/test_image_to_text_example.py b/tests/test_image_to_text_example.py index e35324046f..bf71d34ecb 100644 --- a/tests/test_image_to_text_example.py +++ b/tests/test_image_to_text_example.py @@ -19,6 +19,7 @@ ("llava-hf/llava-v1.6-mistral-7b-hf", 1, 33.17984878151546), ("llava-hf/llava-v1.6-vicuna-7b-hf", 1, 35.00608681379742), ("llava-hf/llava-v1.6-vicuna-13b-hf", 1, 23.527610042925), + ("HuggingFaceM4/idefics2-8b", 1, 25.886229336641385), ], "fp8": [ ("llava-hf/llava-1.5-7b-hf", 1, 115.48515989461843), From a3890a1681e0f38b42ad6fbddacf4d20b2366d6b Mon Sep 17 00:00:00 2001 From: "Wang, Yi A" Date: Sun, 18 Aug 2024 22:19:38 -0700 Subject: [PATCH 03/12] fix readme Signed-off-by: Wang, Yi A --- examples/image-to-text/README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/examples/image-to-text/README.md b/examples/image-to-text/README.md index dbcd1d7647..2ec40ee650 100644 --- a/examples/image-to-text/README.md +++ b/examples/image-to-text/README.md @@ -73,10 +73,12 @@ python3 run_pipeline.py \ ``` To run idefics2 inference, use the following command: +```bash python3 run_pipeline.py \ --model_name_or_path HuggingFaceM4/idefics2-8b \ --use_hpu_graphs \ --bf16 +``` ### Inference with FP8 From 13f5f779208efd7426312b6214c9e6b1d5029413 Mon Sep 17 00:00:00 2001 From: "Wang, Yi A" Date: Fri, 6 Sep 2024 00:54:38 -0700 Subject: [PATCH 04/12] image2text finetune Signed-off-by: Wang, Yi A --- examples/image-to-text/README.md | 67 +++ examples/image-to-text/requirements.txt | 2 + .../run_image2text_lora_finetune.py | 487 ++++++++++++++++++ 3 files changed, 556 insertions(+) create mode 100644 examples/image-to-text/requirements.txt create mode 100644 examples/image-to-text/run_image2text_lora_finetune.py diff 
--git a/examples/image-to-text/README.md b/examples/image-to-text/README.md index 2ec40ee650..81f15b50b1 100644 --- a/examples/image-to-text/README.md +++ b/examples/image-to-text/README.md @@ -187,3 +187,70 @@ QUANT_CONFIG=./quantization_config/maxabs_quant.json python run_pipeline.py \ --use_hpu_graphs \ --bf16 --use_flash_attention ``` +## LORA Finetune + +To run LoRA finetuning, you can use `run_image2text_lora_finetune.py`. +Here are single-/multi-device command examples for HuggingFaceM4/idefics2-8b. + +```bash +python3 run_image2text_lora_finetune.py \ + --model_name_or_path HuggingFaceM4/idefics2-8b \ + --dataset_name nielsr/docvqa_1200_examples \ + --bf16 True \ + --output_dir ./model_lora_llama \ + --num_train_epochs 1 \ + --per_device_train_batch_size 2 \ + --per_device_eval_batch_size 2 \ + --gradient_accumulation_steps 8 \ + --weight_decay 0.01 \ + --logging_steps 25 \ + --eval_strategy epoch \ + --save_strategy "no" \ + --learning_rate 1e-4 \ + --warmup_steps 50 \ + --lr_scheduler_type "constant" \ + --input_column_names 'image' 'query' \ + --output_column_names 'answers' \ + --remove_unused_columns False \ + --do_train \ + --do_eval \ + --use_habana \ + --use_lazy_mode \ + --lora_rank=8 \ + --lora_alpha=8 \ + --lora_dropout=0.1 \ + --low_cpu_mem_usage True \ + --lora_target_modules '.*(text_model|modality_projection|perceiver_resampler).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$' +``` + +```bash +python3 ../gaudi_spawn.py \ + --world_size 8 --use_mpi run_image2text_lora_finetune.py \ + --model_name_or_path HuggingFaceM4/idefics2-8b \ + --dataset_name nielsr/docvqa_1200_examples \ + --bf16 True \ + --output_dir ./model_lora_llama \ + --num_train_epochs 1 \ + --per_device_train_batch_size 2 \ + --per_device_eval_batch_size 2 \ + --gradient_accumulation_steps 8 \ + --weight_decay 0.01 \ + --logging_steps 25 \ + --eval_strategy epoch \ + --save_strategy "no" \ + --learning_rate 1e-4 \ + --warmup_steps 50 \ + --lr_scheduler_type "constant" \ + --input_column_names 'image' 'query' \ + --output_column_names 'answers' \ + --remove_unused_columns False \ + --do_train \ + --do_eval \ + --use_habana \ + --use_lazy_mode \ + --lora_rank=8 \ + --lora_alpha=8 \ + --lora_dropout=0.1 \ + --low_cpu_mem_usage True \ + --lora_target_modules '".*(text_model|modality_projection|perceiver_resampler).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$"' +``` diff --git a/examples/image-to-text/requirements.txt b/examples/image-to-text/requirements.txt new file mode 100644 index 0000000000..045bd69e64 --- /dev/null +++ b/examples/image-to-text/requirements.txt @@ -0,0 +1,2 @@ +peft == 0.12.0 +Levenshtein diff --git a/examples/image-to-text/run_image2text_lora_finetune.py b/examples/image-to-text/run_image2text_lora_finetune.py new file mode 100644 index 0000000000..4c83d25572 --- /dev/null +++ b/examples/image-to-text/run_image2text_lora_finetune.py @@ -0,0 +1,487 @@ +# Apache v2 license +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +""" +poly tuning script for sequence-to-sequence modeling +Adapted from the following sources: +https://colab.research.google.com/drive/1rm3AGquGEYXfeeizE40bbDtcWh5S4Nlq?usp=sharing +""" + +import json +import logging +import os +import random +import sys +from dataclasses import dataclass, field +from typing import List, Optional + +import Levenshtein +import torch +from datasets import load_dataset +from peft import LoraConfig, get_peft_model +from transformers import ( + AutoConfig, + AutoModelForVision2Seq, + AutoProcessor, + HfArgumentParser, +) +from transformers.trainer_utils import is_main_process + +from optimum.habana import GaudiConfig, GaudiTrainer, GaudiTrainingArguments + + +try: + from optimum.habana.utils import check_optimum_habana_min_version +except ImportError: + + def check_optimum_habana_min_version(*a, **b): + return () + +os.environ["WANDB_DISABLED"] = "true" + +logger = logging.getLogger(__name__) + +# Will error if the minimal version of Optimum Habana is not installed. Remove at your own risks. +check_optimum_habana_min_version("1.10.0") + + +def normalized_levenshtein(s1, s2): + len_s1, len_s2 = len(s1), len(s2) + distance = Levenshtein.distance(s1, s2) + return distance / max(len_s1, len_s2) + +def similarity_score(a_ij, o_q_i, tau=0.5): + nl = normalized_levenshtein(a_ij, o_q_i) + return 1 - nl if nl < tau else 0 + +def average_normalized_levenshtein_similarity(ground_truth, predicted_answers): + assert len(ground_truth) == len(predicted_answers), "Length of ground_truth and predicted_answers must match." + + N = len(ground_truth) + total_score = 0 + + for i in range(N): + a_i = ground_truth[i] + o_q_i = predicted_answers[i] + if o_q_i == "": + print("Warning: Skipped an empty prediction.") + max_score = 0 + else: + max_score = max(similarity_score(a_ij, o_q_i) for a_ij in a_i) + + total_score += max_score + + return total_score / N + + +@dataclass +class ModelArguments: + """ + Arguments pertaining to which model/config/processor we are going to fine-tune, or train from scratch. + """ + + model_name_or_path: Optional[str] = field( + default=None, + metadata={ + "help": "The model checkpoint for weights initialization." + "Don't set if you want to train a model from scratch." + }, + ) + config_name: Optional[str] = field( + default=None, + metadata={"help": "Pretrained config name or path if not the same as model_name"}, + ) + cache_dir: Optional[str] = field( + default=None, + metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"}, + ) + token: Optional[str] = field( + default=None, + metadata={"help": "auth token for private models"}, + ) + use_fast_tokenizer: bool = field( + default=True, + metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."}, + ) + model_revision: str = field( + default="main", + metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, + ) + trust_remote_code: bool = field( + default=False, + metadata={ + "help": ( + "Whether to trust the execution of code from datasets/models defined on the Hub." + " This option should only be set to `True` for repositories you trust and in which you have read the" + " code, as it will execute code present on the Hub on your local machine." 
+ ) + }, + ) + use_cache: bool = field( + default=True, + metadata={ + "help": ( + "Whether or not the model should return the last key/values attentions (not used by all models)." + "Only relevant if `config.is_decoder=True`." + ) + }, + ) + low_cpu_mem_usage: bool = field( + default=False, + metadata={ + "help": ( + "It is an option to create the model as an empty shell, then only materialize its parameters when the pretrained weights are loaded." + "When set to True, it will benefit LLM loading time and RAM consumption." + ) + }, + ) + load_meta_device: bool = field( + default=False, + metadata={ + "help": ( + "It is an option to load the model to the device instead of the host, so it can reduce the host RAM usage." + "https://huggingface.co/blog/accelerate-large-models" + ) + }, + ) + do_image_splitting: bool = field( + default=False, metadata={"help": "Whether to do image split during finetune."} + ) + +@dataclass +class DataArguments: + """ + Arguments pertaining to what data we are going to input our model for training and eval. + """ + + dataset_name: Optional[str] = field( + default=None, + metadata={"help": "The name of the dataset to use (via the datasets library)."}, + ) + dataset_config_name: Optional[str] = field( + default=None, + metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}, + ) + max_seq_length: Optional[int] = field( + default=512, + metadata={ + "help": "The maximum total input sequence length after tokenization. Sequences longer " + "than this will be truncated." + }, + ) + overwrite_cache: bool = field( + default=False, + metadata={"help": "Overwrite the cached preprocessed datasets or not."}, + ) + pad_to_max_length: bool = field( + default=False, + metadata={ + "help": "Whether to pad all samples to `max_seq_length`. " + "If False, will pad the samples dynamically when batching to the maximum length in the batch." + }, + ) + max_train_samples: Optional[int] = field( + default=None, + metadata={ + "help": "For debugging purposes or quicker training, truncate the number of training examples to this " + "value if set." + }, + ) + max_eval_samples: Optional[int] = field( + default=None, + metadata={ + "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this " + "value if set." + }, + ) + dataset_seed: int = field( + default=42, + metadata={ + "help": "Seed to use in dataset processing, different seeds might yield different datasets. This seed and the seed in training arguments are not related" + }, + ) + save_last_ckpt: bool = field( + default=True, metadata={"help": "Whether to save checkpoint at the end of the training."} + ) + input_column_names: List[str] = field( + default_factory=lambda: None, + metadata={ + "help": "Name of the column in the dataset that optionally provides context or input for the task. By " + "default, 'image,query' columns are used" + }, + ) + output_column_names: List[str] = field( + default_factory=lambda: None, + metadata={ + "help": "Name of the column in the dataset with the answer to the instruction. By default, the " + "'answers' column is used" + }, + ) + + +@dataclass +class FinetuneArguments: + """ + Arguments of finetune we are going to apply on the model. 
+ """ + + lora_rank: int = field( + default=8, + metadata={"help": "Rank parameter in the LoRA method."}, + ) + lora_alpha: int = field( + default=8, + metadata={"help": "Alpha parameter in the LoRA method."}, + ) + lora_dropout: float = field( + default=0.1, + metadata={"help": "Dropout parameter in the LoRA method."}, + ) + lora_target_modules: str = field( + default=None, + metadata={"help": "Target modules for the LoRA/AdaLoRA method."}, + ) + +class MyDataCollator: + def __init__(self, processor): + self.processor = processor + self.image_token_id = processor.tokenizer.additional_special_tokens_ids[ + processor.tokenizer.additional_special_tokens.index("") + ] + + def __call__(self, examples): + texts = [] + images = [] + keys = list(examples[0].keys()) + if not all(key in ["image","query","answers"] for key in keys): + raise ValueError("Unsupported dataset format") + for example in examples: + image = example["image"] + question = example["query"]["en"] + answer = random.choice(example["answers"]) + messages = [ + { + "role": "user", + "content": [ + {"type": "text", "text": "Answer briefly."}, + {"type": "image"}, + {"type": "text", "text": question} + ] + }, + { + "role": "assistant", + "content": [ + {"type": "text", "text": answer} + ] + } + ] + text = self.processor.apply_chat_template(messages, add_generation_prompt=False) + texts.append(text.strip()) + images.append([image]) + + batch = self.processor(text=texts, images=images, return_tensors="pt", padding=True) + + labels = batch["input_ids"].clone() + labels[labels == self.processor.tokenizer.pad_token_id] = -100 + labels[labels == self.image_token_id] = -100 + batch["labels"] = labels + + return batch + + +def eval(processor, model, eval_dataset, eval_batch_size): + from tqdm import tqdm + answers_unique = [] + generated_texts_unique = [] + + for i in tqdm(range(0, len(eval_dataset), eval_batch_size)): + examples = eval_dataset[i: i + eval_batch_size] + answers_unique.extend(examples["answers"]) + images = [[im] for im in examples["image"]] + texts = [] + for q in examples["query"]: + messages = [ + { + "role": "user", + "content": [ + {"type": "text", "text": "Answer briefly."}, + {"type": "image"}, + {"type": "text", "text": q["en"]} + ] + } + ] + text = processor.apply_chat_template(messages, add_generation_prompt=True) + texts.append(text.strip()) + inputs = processor(text=texts, images=images, return_tensors="pt", padding=True) + inputs = {k: v.to("hpu") for k, v in inputs.items()} + generated_ids = model.generate(**inputs, max_new_tokens=64) + generated_texts = processor.batch_decode(generated_ids[:, inputs["input_ids"].size(1):], skip_special_tokens=True) + generated_texts_unique.extend(generated_texts) + generated_texts_unique = [g.strip().strip(".") for g in generated_texts_unique] + anls = average_normalized_levenshtein_similarity( + ground_truth=answers_unique, predicted_answers=generated_texts_unique, + ) + return anls + + +def main(): + parser = HfArgumentParser((ModelArguments, DataArguments, GaudiTrainingArguments, FinetuneArguments)) + if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): + # If we pass only one argument to the script and it's the path to a json file, + # let's parse it to get our arguments. 
+ model_args, data_args, training_args, finetune_args = parser.parse_json_file( + json_file=os.path.abspath(sys.argv[1]) + ) + else: + ( + model_args, + data_args, + training_args, + finetune_args, + ) = parser.parse_args_into_dataclasses() + + # Setup logging + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + handlers=[logging.StreamHandler(sys.stdout)], + ) + logger.setLevel(logging.INFO if is_main_process(training_args.local_rank) else logging.WARN) + + processor = AutoProcessor.from_pretrained( + model_args.model_name_or_path, + do_image_splitting=model_args.do_image_splitting, + padding_side="right", + ) + + config_kwargs = { + "cache_dir": model_args.cache_dir, + "revision": model_args.model_revision, + "trust_remote_code": True if model_args.trust_remote_code else None, + "use_cache": False if training_args.gradient_checkpointing else model_args.use_cache, + "token": model_args.token, + } + if model_args.config_name: + config = AutoConfig.from_pretrained(model_args.config_name, **config_kwargs) + elif model_args.model_name_or_path: + config = AutoConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs) + else: + raise ValueError("Please provide value for model_name_or_path or config_name.") + + # Load model + if model_args.model_name_or_path: + model_dtype = torch.bfloat16 if training_args.bf16 else None + model = AutoModelForVision2Seq.from_pretrained( + model_args.model_name_or_path, + from_tf=bool(".ckpt" in model_args.model_name_or_path), + config=config, + cache_dir=model_args.cache_dir, + revision=model_args.model_revision, + trust_remote_code=True if model_args.trust_remote_code else None, + torch_dtype=model_dtype, + low_cpu_mem_usage=model_args.low_cpu_mem_usage, + device_map=training_args.device.type if model_args.load_meta_device else None, + token=model_args.token, + ) + else: + raise ValueError("Must provide model_name_or_path to load a pretrained CausalLM model.") + + lora_config = LoraConfig( + r=finetune_args.lora_rank, + lora_alpha=finetune_args.lora_alpha, + lora_dropout=finetune_args.lora_dropout, + target_modules=finetune_args.lora_target_modules, + init_lora_weights="gaussian", + ) + model = get_peft_model(model, lora_config) + model.print_trainable_parameters() + + train_dataset = load_dataset( + data_args.dataset_name, + data_args.dataset_config_name, + cache_dir=model_args.cache_dir, + token=model_args.token, + trust_remote_code=model_args.trust_remote_code, + split="train") + + train_dataset = train_dataset.remove_columns([col for col in train_dataset.column_names if col not in (data_args.input_column_names + data_args.output_column_names)]) + + eval_dataset = load_dataset( + data_args.dataset_name, + data_args.dataset_config_name, + cache_dir=model_args.cache_dir, + token=model_args.token, + trust_remote_code=model_args.trust_remote_code, + split="test") + + eval_dataset = eval_dataset.remove_columns([col for col in eval_dataset.column_names if col not in (data_args.input_column_names + data_args.output_column_names)]) + + data_collator = MyDataCollator(processor) + + gaudi_config = GaudiConfig() + gaudi_config.use_fused_adam = True + gaudi_config.use_fused_clip_norm = True + + trainer = GaudiTrainer( + model=model, + args=training_args, + gaudi_config=gaudi_config, + data_collator=data_collator, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + ) + + if training_args.do_train: + trainer.train() + trainer.save_model() + + if 
is_main_process(training_args.local_rank): + example = eval_dataset[15] + model.eval() + model = model.merge_and_unload() + if model_dtype == torch.bfloat16: + model = model.to(torch.bfloat16) + + image = example["image"] + query = example["query"] + + messages = [ + { + "role": "user", + "content": [ + {"type": "text", "text": "Answer briefly."}, + {"type": "image"}, + {"type": "text", "text": query["en"]} + ] + } + ] + processor.tokenizer.padding_side = 'left' + text = processor.apply_chat_template(messages, add_generation_prompt=True) + inputs = processor(text=[text.strip()], images=[image], return_tensors="pt", padding=True) + inputs = {k: v.to("hpu") for k, v in inputs.items()} + generated_ids = model.generate(**inputs, max_new_tokens=64) + generated_texts = processor.batch_decode(generated_ids[:, inputs["input_ids"].size(1):], skip_special_tokens=True) + logger.info(f"generated: {generated_texts}") + if training_args.do_eval: + anls = eval(processor=processor,model=model,eval_dataset=eval_dataset,eval_batch_size=training_args.per_device_eval_batch_size) + logger.info(f"anls = {anls}") + if training_args.output_dir is not None: + metrics = {"accuracy": anls} + with open(f"{training_args.output_dir}/accuracy_metrics.json", mode="w") as file: + json.dump(metrics, file) + + +if __name__ == "__main__": + main() From 21ace730307bdca06add8e212d7265989e880938 Mon Sep 17 00:00:00 2001 From: "Wang, Yi A" Date: Fri, 6 Sep 2024 07:35:14 -0700 Subject: [PATCH 05/12] format fix Signed-off-by: Wang, Yi A --- .../run_image2text_lora_finetune.py | 95 ++++++++++++------- 1 file changed, 59 insertions(+), 36 deletions(-) diff --git a/examples/image-to-text/run_image2text_lora_finetune.py b/examples/image-to-text/run_image2text_lora_finetune.py index 4c83d25572..ec2ffbf89f 100644 --- a/examples/image-to-text/run_image2text_lora_finetune.py +++ b/examples/image-to-text/run_image2text_lora_finetune.py @@ -49,6 +49,7 @@ def check_optimum_habana_min_version(*a, **b): return () + os.environ["WANDB_DISABLED"] = "true" logger = logging.getLogger(__name__) @@ -62,10 +63,12 @@ def normalized_levenshtein(s1, s2): distance = Levenshtein.distance(s1, s2) return distance / max(len_s1, len_s2) + def similarity_score(a_ij, o_q_i, tau=0.5): nl = normalized_levenshtein(a_ij, o_q_i) return 1 - nl if nl < tau else 0 + def average_normalized_levenshtein_similarity(ground_truth, predicted_answers): assert len(ground_truth) == len(predicted_answers), "Length of ground_truth and predicted_answers must match." 
@@ -156,9 +159,8 @@ class ModelArguments: ) }, ) - do_image_splitting: bool = field( - default=False, metadata={"help": "Whether to do image split during finetune."} - ) + do_image_splitting: bool = field(default=False, metadata={"help": "Whether to do image split during finetune."}) + @dataclass class DataArguments: @@ -254,6 +256,7 @@ class FinetuneArguments: metadata={"help": "Target modules for the LoRA/AdaLoRA method."}, ) + class MyDataCollator: def __init__(self, processor): self.processor = processor @@ -265,7 +268,7 @@ def __call__(self, examples): texts = [] images = [] keys = list(examples[0].keys()) - if not all(key in ["image","query","answers"] for key in keys): + if not all(key in ["image", "query", "answers"] for key in keys): raise ValueError("Unsupported dataset format") for example in examples: image = example["image"] @@ -277,15 +280,10 @@ def __call__(self, examples): "content": [ {"type": "text", "text": "Answer briefly."}, {"type": "image"}, - {"type": "text", "text": question} - ] + {"type": "text", "text": question}, + ], }, - { - "role": "assistant", - "content": [ - {"type": "text", "text": answer} - ] - } + {"role": "assistant", "content": [{"type": "text", "text": answer}]}, ] text = self.processor.apply_chat_template(messages, add_generation_prompt=False) texts.append(text.strip()) @@ -303,11 +301,12 @@ def __call__(self, examples): def eval(processor, model, eval_dataset, eval_batch_size): from tqdm import tqdm + answers_unique = [] generated_texts_unique = [] for i in tqdm(range(0, len(eval_dataset), eval_batch_size)): - examples = eval_dataset[i: i + eval_batch_size] + examples = eval_dataset[i : i + eval_batch_size] answers_unique.extend(examples["answers"]) images = [[im] for im in examples["image"]] texts = [] @@ -318,8 +317,8 @@ def eval(processor, model, eval_dataset, eval_batch_size): "content": [ {"type": "text", "text": "Answer briefly."}, {"type": "image"}, - {"type": "text", "text": q["en"]} - ] + {"type": "text", "text": q["en"]}, + ], } ] text = processor.apply_chat_template(messages, add_generation_prompt=True) @@ -327,11 +326,14 @@ def eval(processor, model, eval_dataset, eval_batch_size): inputs = processor(text=texts, images=images, return_tensors="pt", padding=True) inputs = {k: v.to("hpu") for k, v in inputs.items()} generated_ids = model.generate(**inputs, max_new_tokens=64) - generated_texts = processor.batch_decode(generated_ids[:, inputs["input_ids"].size(1):], skip_special_tokens=True) + generated_texts = processor.batch_decode( + generated_ids[:, inputs["input_ids"].size(1) :], skip_special_tokens=True + ) generated_texts_unique.extend(generated_texts) generated_texts_unique = [g.strip().strip(".") for g in generated_texts_unique] anls = average_normalized_levenshtein_similarity( - ground_truth=answers_unique, predicted_answers=generated_texts_unique, + ground_truth=answers_unique, + predicted_answers=generated_texts_unique, ) return anls @@ -409,24 +411,38 @@ def main(): model.print_trainable_parameters() train_dataset = load_dataset( - data_args.dataset_name, - data_args.dataset_config_name, - cache_dir=model_args.cache_dir, - token=model_args.token, - trust_remote_code=model_args.trust_remote_code, - split="train") + data_args.dataset_name, + data_args.dataset_config_name, + cache_dir=model_args.cache_dir, + token=model_args.token, + trust_remote_code=model_args.trust_remote_code, + split="train", + ) - train_dataset = train_dataset.remove_columns([col for col in train_dataset.column_names if col not in 
(data_args.input_column_names + data_args.output_column_names)]) + train_dataset = train_dataset.remove_columns( + [ + col + for col in train_dataset.column_names + if col not in (data_args.input_column_names + data_args.output_column_names) + ] + ) eval_dataset = load_dataset( - data_args.dataset_name, - data_args.dataset_config_name, - cache_dir=model_args.cache_dir, - token=model_args.token, - trust_remote_code=model_args.trust_remote_code, - split="test") + data_args.dataset_name, + data_args.dataset_config_name, + cache_dir=model_args.cache_dir, + token=model_args.token, + trust_remote_code=model_args.trust_remote_code, + split="test", + ) - eval_dataset = eval_dataset.remove_columns([col for col in eval_dataset.column_names if col not in (data_args.input_column_names + data_args.output_column_names)]) + eval_dataset = eval_dataset.remove_columns( + [ + col + for col in eval_dataset.column_names + if col not in (data_args.input_column_names + data_args.output_column_names) + ] + ) data_collator = MyDataCollator(processor) @@ -463,19 +479,26 @@ def main(): "content": [ {"type": "text", "text": "Answer briefly."}, {"type": "image"}, - {"type": "text", "text": query["en"]} - ] + {"type": "text", "text": query["en"]}, + ], } ] - processor.tokenizer.padding_side = 'left' + processor.tokenizer.padding_side = "left" text = processor.apply_chat_template(messages, add_generation_prompt=True) inputs = processor(text=[text.strip()], images=[image], return_tensors="pt", padding=True) inputs = {k: v.to("hpu") for k, v in inputs.items()} generated_ids = model.generate(**inputs, max_new_tokens=64) - generated_texts = processor.batch_decode(generated_ids[:, inputs["input_ids"].size(1):], skip_special_tokens=True) + generated_texts = processor.batch_decode( + generated_ids[:, inputs["input_ids"].size(1) :], skip_special_tokens=True + ) logger.info(f"generated: {generated_texts}") if training_args.do_eval: - anls = eval(processor=processor,model=model,eval_dataset=eval_dataset,eval_batch_size=training_args.per_device_eval_batch_size) + anls = eval( + processor=processor, + model=model, + eval_dataset=eval_dataset, + eval_batch_size=training_args.per_device_eval_batch_size, + ) logger.info(f"anls = {anls}") if training_args.output_dir is not None: metrics = {"accuracy": anls} From ad61736abedda0ffb0c76ae62d7f74d804ce61b9 Mon Sep 17 00:00:00 2001 From: "Wang, Yi A" Date: Sun, 8 Sep 2024 22:52:48 -0700 Subject: [PATCH 06/12] fix generate problem, eos is incorrect handled if batch_size > 1 move vision model to prepare input(image size may change if the input is dataset, can not use hpu graph to optimize it) add ci test Signed-off-by: Wang, Yi A --- examples/image-to-text/README.md | 12 +- .../run_image2text_lora_finetune.py | 93 +++++-- .../generation/stopping_criteria.py | 19 +- .../models/idefics2/modeling_idefics2.py | 242 +++++++++++++----- tests/baselines/idefics2_8b.json | 38 +++ tests/test_examples.py | 36 ++- tests/test_image_to_text_example.py | 2 +- tests/utils.py | 1 + 8 files changed, 332 insertions(+), 111 deletions(-) create mode 100644 tests/baselines/idefics2_8b.json diff --git a/examples/image-to-text/README.md b/examples/image-to-text/README.md index 81f15b50b1..203edcb77b 100644 --- a/examples/image-to-text/README.md +++ b/examples/image-to-text/README.md @@ -204,9 +204,9 @@ python3 run_image2text_lora_finetune.py \ --gradient_accumulation_steps 8 \ --weight_decay 0.01 \ --logging_steps 25 \ - --eval_strategy epoch \ + --eval_strategy "no" \ --save_strategy "no" \ - --learning_rate 1e-4 
\ + --learning_rate 5e-5 \ --warmup_steps 50 \ --lr_scheduler_type "constant" \ --input_column_names 'image' 'query' \ @@ -219,6 +219,8 @@ python3 run_image2text_lora_finetune.py \ --lora_rank=8 \ --lora_alpha=8 \ --lora_dropout=0.1 \ + --max_seq_length=512 \ + --use_hpu_graphs_for_inference \ --low_cpu_mem_usage True \ --lora_target_modules '.*(text_model|modality_projection|perceiver_resampler).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$' ``` @@ -236,9 +238,9 @@ python3 ../gaudi_spawn.py \ --gradient_accumulation_steps 8 \ --weight_decay 0.01 \ --logging_steps 25 \ - --eval_strategy epoch \ + --eval_strategy "no" \ --save_strategy "no" \ - --learning_rate 1e-4 \ + --learning_rate 5e-5 \ --warmup_steps 50 \ --lr_scheduler_type "constant" \ --input_column_names 'image' 'query' \ @@ -251,6 +253,8 @@ python3 ../gaudi_spawn.py \ --lora_rank=8 \ --lora_alpha=8 \ --lora_dropout=0.1 \ + --max_seq_length=512 \ + --use_hpu_graphs_for_inference \ --low_cpu_mem_usage True \ --lora_target_modules '".*(text_model|modality_projection|perceiver_resampler).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$"' ``` diff --git a/examples/image-to-text/run_image2text_lora_finetune.py b/examples/image-to-text/run_image2text_lora_finetune.py index ec2ffbf89f..d5e23344c2 100644 --- a/examples/image-to-text/run_image2text_lora_finetune.py +++ b/examples/image-to-text/run_image2text_lora_finetune.py @@ -14,12 +14,11 @@ # See the License for the specific language governing permissions and # limitations under the License. """ -poly tuning script for sequence-to-sequence modeling +lora fine tuning script for image-to-text case Adapted from the following sources: https://colab.research.google.com/drive/1rm3AGquGEYXfeeizE40bbDtcWh5S4Nlq?usp=sharing """ -import json import logging import os import random @@ -29,6 +28,7 @@ import Levenshtein import torch +import transformers from datasets import load_dataset from peft import LoraConfig, get_peft_model from transformers import ( @@ -187,13 +187,6 @@ class DataArguments: default=False, metadata={"help": "Overwrite the cached preprocessed datasets or not."}, ) - pad_to_max_length: bool = field( - default=False, - metadata={ - "help": "Whether to pad all samples to `max_seq_length`. " - "If False, will pad the samples dynamically when batching to the maximum length in the batch." 
- }, - ) max_train_samples: Optional[int] = field( default=None, metadata={ @@ -258,11 +251,12 @@ class FinetuneArguments: class MyDataCollator: - def __init__(self, processor): + def __init__(self, processor, max_seq_length): self.processor = processor self.image_token_id = processor.tokenizer.additional_special_tokens_ids[ processor.tokenizer.additional_special_tokens.index("") ] + self.max_seq_length = max_seq_length def __call__(self, examples): texts = [] @@ -289,7 +283,14 @@ def __call__(self, examples): texts.append(text.strip()) images.append([image]) - batch = self.processor(text=texts, images=images, return_tensors="pt", padding=True) + batch = self.processor( + text=texts, + images=images, + return_tensors="pt", + padding="max_length", + truncation=True, + max_length=self.max_seq_length, + ) labels = batch["input_ids"].clone() labels[labels == self.processor.tokenizer.pad_token_id] = -100 @@ -299,14 +300,14 @@ def __call__(self, examples): return batch -def eval(processor, model, eval_dataset, eval_batch_size): +def eval(processor, model, dataset, batch_size, use_lazy_mode, use_hpu_graphs, max_seq_length): from tqdm import tqdm answers_unique = [] generated_texts_unique = [] - for i in tqdm(range(0, len(eval_dataset), eval_batch_size)): - examples = eval_dataset[i : i + eval_batch_size] + for i in tqdm(range(0, len(dataset), batch_size)): + examples = dataset[i : i + batch_size] answers_unique.extend(examples["answers"]) images = [[im] for im in examples["image"]] texts = [] @@ -323,9 +324,18 @@ def eval(processor, model, eval_dataset, eval_batch_size): ] text = processor.apply_chat_template(messages, add_generation_prompt=True) texts.append(text.strip()) - inputs = processor(text=texts, images=images, return_tensors="pt", padding=True) + inputs = processor( + text=texts, + images=images, + return_tensors="pt", + padding="max_length", + truncation=True, + max_length=max_seq_length, + ) inputs = {k: v.to("hpu") for k, v in inputs.items()} - generated_ids = model.generate(**inputs, max_new_tokens=64) + generated_ids = model.generate( + **inputs, max_new_tokens=64, ignore_eos=False, lazy_mode=use_lazy_mode, hpu_graphs=use_hpu_graphs + ) generated_texts = processor.batch_decode( generated_ids[:, inputs["input_ids"].size(1) :], skip_special_tokens=True ) @@ -362,6 +372,11 @@ def main(): ) logger.setLevel(logging.INFO if is_main_process(training_args.local_rank) else logging.WARN) + if is_main_process(training_args.local_rank): + transformers.utils.logging.set_verbosity_info() + transformers.utils.logging.enable_default_handler() + transformers.utils.logging.enable_explicit_format() + processor = AutoProcessor.from_pretrained( model_args.model_name_or_path, do_image_splitting=model_args.do_image_splitting, @@ -444,7 +459,7 @@ def main(): ] ) - data_collator = MyDataCollator(processor) + data_collator = MyDataCollator(processor, max_seq_length=data_args.max_seq_length) gaudi_config = GaudiConfig() gaudi_config.use_fused_adam = True @@ -460,10 +475,15 @@ def main(): ) if training_args.do_train: - trainer.train() + train_result = trainer.train() trainer.save_model() + metrics = train_result.metrics + trainer.log_metrics("train", metrics) + trainer.save_metrics("train", metrics) if is_main_process(training_args.local_rank): + processor.tokenizer.padding_side = "left" + example = eval_dataset[15] model.eval() model = model.merge_and_unload() @@ -485,25 +505,44 @@ def main(): ] processor.tokenizer.padding_side = "left" text = processor.apply_chat_template(messages, 
add_generation_prompt=True) - inputs = processor(text=[text.strip()], images=[image], return_tensors="pt", padding=True) + inputs = processor( + text=[text.strip()], + images=[image], + return_tensors="pt", + padding="max_length", + truncation=True, + max_length=data_args.max_seq_length, + ) inputs = {k: v.to("hpu") for k, v in inputs.items()} - generated_ids = model.generate(**inputs, max_new_tokens=64) + generated_ids = model.generate( + **inputs, + max_new_tokens=64, + ignore_eos=False, + lazy_mode=training_args.use_lazy_mode, + hpu_graphs=training_args.use_hpu_graphs_for_inference, + ) generated_texts = processor.batch_decode( generated_ids[:, inputs["input_ids"].size(1) :], skip_special_tokens=True ) logger.info(f"generated: {generated_texts}") if training_args.do_eval: + if training_args.use_hpu_graphs_for_inference: + from habana_frameworks.torch.hpu import wrap_in_hpu_graph + + model = wrap_in_hpu_graph(model) + anls = eval( processor=processor, model=model, - eval_dataset=eval_dataset, - eval_batch_size=training_args.per_device_eval_batch_size, + dataset=eval_dataset, + batch_size=training_args.per_device_eval_batch_size, + use_lazy_mode=training_args.use_lazy_mode, + use_hpu_graphs=training_args.use_hpu_graphs_for_inference, + max_seq_length=data_args.max_seq_length, ) - logger.info(f"anls = {anls}") - if training_args.output_dir is not None: - metrics = {"accuracy": anls} - with open(f"{training_args.output_dir}/accuracy_metrics.json", mode="w") as file: - json.dump(metrics, file) + eval_metrics = {"eval_accuracy": anls} + trainer.log_metrics("eval", eval_metrics) + trainer.save_metrics("eval", eval_metrics) if __name__ == "__main__": diff --git a/optimum/habana/transformers/generation/stopping_criteria.py b/optimum/habana/transformers/generation/stopping_criteria.py index dac7aadd92..6efe2ea87f 100644 --- a/optimum/habana/transformers/generation/stopping_criteria.py +++ b/optimum/habana/transformers/generation/stopping_criteria.py @@ -38,8 +38,10 @@ def gaudi_MaxLengthCriteria_call( ) -> Union[torch.BoolTensor, bool]: token_idx = kwargs.get("token_idx", None) if token_idx is not None: - assert not kwargs["needs_tensor_output"] - return token_idx >= self.max_length + if not kwargs["needs_tensor_output"]: + return token_idx >= self.max_length + else: + return create_return_const_tensor(input_ids, token_idx >= self.max_length) else: cur_len = input_ids.shape[-1] is_done = cur_len >= self.max_length @@ -57,8 +59,10 @@ def gaudi_MaxNewTokensCriteria_call( ) -> Union[torch.BoolTensor, bool]: token_idx = kwargs.get("token_idx", None) if token_idx is not None: - assert not kwargs["needs_tensor_output"] - return token_idx >= self.max_length + if not kwargs["needs_tensor_output"]: + return token_idx >= self.max_length + else: + return create_return_const_tensor(input_ids, token_idx >= self.max_length) else: is_done = input_ids.shape[-1] >= self.max_length return create_return_const_tensor(input_ids, is_done) @@ -80,7 +84,6 @@ def gaudi_EosTokenCriteria_call( self.eos_token_id = self.eos_token_id.to(input_ids.device) token_idx = kwargs.get("token_idx", None) if token_idx is not None: - assert not kwargs["needs_tensor_output"] is_done = torch.isin(input_ids[:, token_idx - 1], self.eos_token_id) else: is_done = torch.isin(input_ids[:, -1], self.eos_token_id) @@ -91,11 +94,7 @@ def gaudi_EosTokenCriteria_call( def needs_tensor_output(token_idx, ignore_eos, eos_token_id) -> bool: - if token_idx is None: - return not ignore_eos and eos_token_id is not None - else: - # token_idx is present, so 
we have static shapes, so using single boolean - return False + return not ignore_eos and eos_token_id is not None def gaudi_StoppingCriteriaList_call( diff --git a/optimum/habana/transformers/models/idefics2/modeling_idefics2.py b/optimum/habana/transformers/models/idefics2/modeling_idefics2.py index aeb1825540..324edd2812 100644 --- a/optimum/habana/transformers/models/idefics2/modeling_idefics2.py +++ b/optimum/habana/transformers/models/idefics2/modeling_idefics2.py @@ -47,12 +47,10 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, - token_idx: Optional[torch.Tensor] = None, ) -> Union[Tuple, Idefics2BaseModelOutputWithPast]: """ Inherits from Idefics2Model::forward https://github.com/huggingface/transformers/blob/v4.43.4/src/transformers/models/idefics2/modeling_idefics2.py#L1303 The only differences are: - - add new args token_idx - ignoring new Cache path for HPU - unfold is not supported in HPU, fallback to cpu """ @@ -157,15 +155,17 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, - token_idx=token_idx, ) if return_legacy_cache and use_cache: - outputs.past_key_values = ( - outputs.past_key_values.to_legacy_cache() - if isinstance(outputs.past_key_values, Cache) - else outputs.past_key_values - ) + if return_dict: + outputs.past_key_values = ( + outputs.past_key_values.to_legacy_cache() + if isinstance(outputs.past_key_values, Cache) + else outputs.past_key_values + ) + else: + outputs[1] = outputs[1].to_legacy_cache() if isinstance(outputs[1], Cache) else outputs[1] if not return_dict: return tuple(v for v in [*outputs, image_hidden_states] if v is not None) @@ -178,6 +178,24 @@ def forward( image_hidden_states=image_hidden_states, ) + def inputs_merger( + self, + input_ids: torch.LongTensor, + inputs_embeds: Optional[torch.Tensor], + image_hidden_states: Optional[torch.Tensor], + ): + """ + Inherits from Idefics2Model::inputs_merger https://github.com/huggingface/transformers/blob/v4.43.4/src/transformers/models/idefics2/modeling_idefics2.py#L1268 + The only differences are: + - replace `==` with torch.where to fix the issue in hpu graph + """ + num_images, _, vision_hidden_size = image_hidden_states.shape + special_image_token_mask = torch.where(input_ids == self.image_token_id) + new_inputs_embeds = inputs_embeds.clone() + reshaped_image_hidden_states = image_hidden_states.view(-1, vision_hidden_size) + new_inputs_embeds[special_image_token_mask] = reshaped_image_hidden_states + return new_inputs_embeds + class GaudiIdefics2ForConditionalGeneration(Idefics2ForConditionalGeneration): def forward( @@ -202,61 +220,122 @@ def forward( The only differences are: - add new args token_idx """ - - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - pixel_values=pixel_values, - pixel_attention_mask=pixel_attention_mask, - image_hidden_states=image_hidden_states, - use_cache=use_cache, - 
output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - token_idx=token_idx, - ) - - hidden_states = outputs[0] - logits = self.lm_head(hidden_states) - logits = logits.float() - - loss = None - if labels is not None: - labels = labels.to(logits.device) - # Shift so that tokens < n predict n - if attention_mask is not None: - shift_attention_mask = attention_mask[..., 1:].to(logits.device) - shift_logits = logits[..., :-1, :][shift_attention_mask != 0].contiguous() - shift_labels = labels[..., 1:][shift_attention_mask != 0].contiguous() + if token_idx is not None: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + if input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape else: - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = CrossEntropyLoss() - loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + raise ValueError("You have to specify either input_ids or inputs_embeds") + + past_seen_tokens = 0 + return_legacy_cache = True + use_new_cache = False # Ignoring new Cache path for HPU + if use_cache and use_new_cache: + if not isinstance(past_key_values, Cache): # kept for BC (non `Cache` `past_key_values` inputs) + return_legacy_cache = True + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + past_seen_tokens = past_key_values.get_seq_length() + else: + if past_key_values is not None: + past_seen_tokens = past_key_values[0][0].shape[2] + if inputs_embeds is not None and input_ids is None and past_seen_tokens == 0: + raise ValueError( + "When first calling the model, if input_embeds are passed, input_ids should not be None." 
+ ) - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output + if inputs_embeds is None: + inputs_embeds = self.model.text_model.get_input_embeddings()(input_ids) + + # START VISUAL INPUTS INTEGRATION + if pixel_values is not None and image_hidden_states is not None: + raise ValueError("You cannot specify both pixel_values and image_hidden_states at the same time") + elif image_hidden_states is not None: + image_hidden_states = image_hidden_states.to(dtype=self.dtype, device=input_ids.device) + + if past_seen_tokens == 0 and inputs_embeds is not None and image_hidden_states is not None: + # When we generate, we don't want to replace the potential image_token_id that we generated by images + # that simply don't exist + inputs_embeds = self.model.inputs_merger( + input_ids=input_ids, + inputs_embeds=inputs_embeds, + image_hidden_states=image_hidden_states, + ) - return Idefics2CausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - image_hidden_states=outputs.image_hidden_states, - ) + outputs = self.model.text_model( + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + token_idx=token_idx, + ) + + if return_legacy_cache and use_cache: + if return_dict: + outputs.past_key_values = ( + outputs.past_key_values.to_legacy_cache() + if isinstance(outputs.past_key_values, Cache) + else outputs.past_key_values + ) + else: + outputs[1] = outputs[1].to_legacy_cache() if isinstance(outputs[1], Cache) else outputs[1] + + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + logits = logits.float() + + loss = None + if labels is not None: + labels = labels.to(logits.device) + # Shift so that tokens < n predict n + if attention_mask is not None: + shift_attention_mask = attention_mask[..., 1:].to(logits.device) + shift_logits = logits[..., :-1, :][shift_attention_mask != 0].contiguous() + shift_labels = labels[..., 1:][shift_attention_mask != 0].contiguous() + else: + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return Idefics2CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + image_hidden_states=image_hidden_states, + ) + else: + return super().forward( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + pixel_values=pixel_values, + pixel_attention_mask=pixel_attention_mask, + image_hidden_states=image_hidden_states, + labels=labels, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) def prepare_inputs_for_generation( self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs @@ -325,6 +404,49 @@ def prepare_inputs_for_generation( else: pixel_values = kwargs.get("pixel_values", None) 
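# [Illustrative sketch, not part of the patch] The hunk added just below precomputes the
# vision features inside prepare_inputs_for_generation. Padding image slots are all-zero
# tensors, so they are dropped before the vision tower with the same test used here
# (standalone example with hypothetical shapes):
#
#     import torch
#     pixel_values = torch.zeros(4, 3, 28, 28)        # (batch * num_images, C, H, W)
#     pixel_values[0] = torch.rand(3, 28, 28)         # only slot 0 holds a real image
#     nb_values_per_image = pixel_values.shape[1:].numel()
#     real_images_inds = (pixel_values == 0.0).sum(dim=(-1, -2, -3)) != nb_values_per_image
#     print(real_images_inds)                         # tensor([ True, False, False, False])
#     pixel_values = pixel_values[real_images_inds]   # -> shape (1, 3, 28, 28)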
pixel_attention_mask = kwargs.get("pixel_attention_mask", None) + + if token_idx is not None and pixel_values is not None: + batch_size, num_images, num_channels, height, width = pixel_values.shape + pixel_values = pixel_values.to(dtype=self.dtype) # fp16 compatibility + pixel_values = pixel_values.view(batch_size * num_images, *pixel_values.shape[2:]) + + # Remove padding images - padding images are full 0. + nb_values_per_image = pixel_values.shape[1:].numel() + real_images_inds = (pixel_values == 0.0).sum(dim=(-1, -2, -3)) != nb_values_per_image + pixel_values = pixel_values[real_images_inds].contiguous() + + # Handle the vision attention mask + if pixel_attention_mask is None: + pixel_attention_mask = torch.ones( + size=(pixel_values.size(0), pixel_values.size(2), pixel_values.size(3)), + dtype=torch.bool, + device=pixel_values.device, + ) + else: + # Remove padding images from the mask/pP p + pixel_attention_mask = pixel_attention_mask.view( + batch_size * num_images, *pixel_attention_mask.shape[2:] + ) + pixel_attention_mask = pixel_attention_mask[real_images_inds].contiguous() + + patch_size = self.config.vision_config.patch_size + patches_subgrid = pixel_attention_mask.cpu().unfold(dimension=1, size=patch_size, step=patch_size) + patches_subgrid = patches_subgrid.cpu().unfold(dimension=2, size=patch_size, step=patch_size) + patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool() + + # Get sequence from the vision encoder + image_hidden_states = self.model.vision_model( + pixel_values=pixel_values, + patch_attention_mask=patch_attention_mask, + ).last_hidden_state + + # Modality projection & resampling + image_hidden_states = self.model.connector( + image_hidden_states, attention_mask=patch_attention_mask.view(pixel_values.size(0), -1) + ) + pixel_values = None + pixel_attention_mask = None + model_inputs.update( { "position_ids": position_ids, diff --git a/tests/baselines/idefics2_8b.json b/tests/baselines/idefics2_8b.json new file mode 100644 index 0000000000..52ae5dd4d9 --- /dev/null +++ b/tests/baselines/idefics2_8b.json @@ -0,0 +1,38 @@ +{ + "gaudi2": { + "image2text_lora_finetune": { + "num_train_epochs": 2, + "eval_batch_size": 4, + "distribution": { + "multi_card": { + "learning_rate": 5e-5, + "train_batch_size": 2, + "train_runtime": 420, + "train_samples_per_second": 6.728, + "eval_accuracy": 0.54, + "extra_arguments": [ + "--bf16", + "--gradient_accumulation_steps 8", + "--eval_strategy no", + "--save_strategy no", + "--warmup_steps 50", + "--lr_scheduler_type constant", + "--max_grad_norm 0.3", + "--logging_steps 1", + "--use_hpu_graphs_for_inference", + "--lora_rank 8", + "--lora_alpha 8", + "--lora_dropout 0.1", + "--lora_target_modules '.*(text_model|modality_projection|perceiver_resampler).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$'", + "--low_cpu_mem_usage True", + "--adam_epsilon 1e-08", + "--input_column_name image query", + "--output_column_name answers", + "--remove_unused_columns False", + "--max_seq_length 512" + ] + } + } + } + } +} diff --git a/tests/test_examples.py b/tests/test_examples.py index 3670b32693..fdd4286c42 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -201,6 +201,11 @@ def is_valid_model_type(model_type: str) -> bool: MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING, ["t5"], ), + "run_image2text_lora_finetune": _get_supported_models_for_script( + MODELS_TO_TEST_MAPPING, + MODEL_MAPPING, + ["idefics2"], + ), } @@ -230,17 +235,20 @@ def to_test( "meta-llama/LlamaGuard-7b", ] + case_only_in_gaudi2 = [ + 
"sft", + "dpo", + "reward_modeling", + "ppo", + "prompt_tuning", + "peft_poly", + "run_sequence_classification", + "run_image2text_lora_finetune", + ] + if (fsdp or fp8) and not IS_GAUDI2: return False - elif ( - "sft" in example_name - or "dpo" in example_name - or "reward_modeling" in example_name - or "ppo" in example_name - or "prompt_tuning" in example_name - or "peft_poly" in example_name - or example_name == "run_sequence_classification" - ) and not IS_GAUDI2: + elif any(case in example_name for case in case_only_in_gaudi2) and not IS_GAUDI2: return False elif "llama" in model_name and "trl-sft-chat" in task_name: return False @@ -899,3 +907,13 @@ class MultiCardCausalLanguageModelingLoRAFP8ExampleTester( ): TASK_NAME = "tatsu-lab/alpaca_fp8" DATASET_NAME = "tatsu-lab/alpaca" + + +class MultiCardImageToTextModelingLoRAExampleTester( + ExampleTesterBase, + metaclass=ExampleTestMeta, + example_name="run_image2text_lora_finetune", + multi_card=True, +): + TASK_NAME = "image2text_lora_finetune" + DATASET_NAME = "nielsr/docvqa_1200_examples" diff --git a/tests/test_image_to_text_example.py b/tests/test_image_to_text_example.py index bf71d34ecb..ddfc6092a2 100644 --- a/tests/test_image_to_text_example.py +++ b/tests/test_image_to_text_example.py @@ -19,7 +19,7 @@ ("llava-hf/llava-v1.6-mistral-7b-hf", 1, 33.17984878151546), ("llava-hf/llava-v1.6-vicuna-7b-hf", 1, 35.00608681379742), ("llava-hf/llava-v1.6-vicuna-13b-hf", 1, 23.527610042925), - ("HuggingFaceM4/idefics2-8b", 1, 25.886229336641385), + ("HuggingFaceM4/idefics2-8b", 1, 24.07768894366222), ], "fp8": [ ("llava-hf/llava-1.5-7b-hf", 1, 115.48515989461843), diff --git a/tests/utils.py b/tests/utils.py index 3e9114d14a..f79cb73284 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -61,6 +61,7 @@ "code_llama": [("codellama/CodeLlama-13b-Instruct-hf", "Habana/llama")], "protst": [("mila-intel/protst-esm1b-for-sequential-classification", "Habana/gpt2")], "qwen2": [("Qwen/Qwen2-7B", "Habana/qwen")], + "idefics2": [("HuggingFaceM4/idefics2-8b", "Habana/gpt2")], } MODELS_TO_TEST_FOR_QUESTION_ANSWERING = [ From b6b2baeb42293cbc96942d2b12745ac5589eec99 Mon Sep 17 00:00:00 2001 From: "Wang, Yi A" Date: Tue, 10 Sep 2024 07:06:26 -0700 Subject: [PATCH 07/12] update doc Signed-off-by: Wang, Yi A --- README.md | 1 + docs/source/index.mdx | 1 + examples/image-to-text/README.md | 1 + 3 files changed, 3 insertions(+) diff --git a/README.md b/README.md index 2edc24d2f8..91587ebe9e 100644 --- a/README.md +++ b/README.md @@ -214,6 +214,7 @@ The following model architectures, tasks and device distributions have been vali | OWLViT | |
<li>Single card</li> | <li>[zero shot object detection](https://github.com/huggingface/optimum-habana/tree/main/examples/zero-shot-object-detection)</li> |
| ClipSeg | | <li>Single card</li> | <li>[object segmentation](https://github.com/huggingface/optimum-habana/tree/main/examples/object-segementation)</li> |
| Llava / Llava-next | | <li>Single card</li> | <li>[image to text](https://github.com/huggingface/optimum-habana/tree/main/examples/image-to-text)</li> |
+| idefics2 | :heavy_check_mark: | <li>Single card</li> | <li>[image to text](https://github.com/huggingface/optimum-habana/tree/main/examples/image-to-text)</li> |
| Segment Anything Model | | <li>Single card</li> | <li>[object segmentation](https://github.com/huggingface/optimum-habana/tree/main/examples/object-segementation)</li> |
| VideoMAE | | <li>Single card</li> | <li>[Video classification](https://github.com/huggingface/optimum-habana/tree/main/examples/video-classification)</li> |
| TableTransformer | | <li>Single card</li> | <li>[table object detection](https://github.com/huggingface/optimum-habana/tree/main/examples/table-detection)</li> |
diff --git a/docs/source/index.mdx b/docs/source/index.mdx
index 9b6de456c5..306d0c786c 100644
--- a/docs/source/index.mdx
+++ b/docs/source/index.mdx
@@ -72,6 +72,7 @@ In the tables below, ✅ means single-card, multi-card and DeepSpeed have all be
| OWLViT | | <li>Single card</li> | <li>[zero shot object detection](https://github.com/huggingface/optimum-habana/tree/main/examples/zero-shot-object-detection)</li> |
| ClipSeg | | <li>Single card</li> | <li>[object segmentation](https://github.com/huggingface/optimum-habana/tree/main/examples/object-segementation)</li> |
| Llava / Llava-next | | <li>Single card</li> | <li>[image to text](https://github.com/huggingface/optimum-habana/tree/main/examples/image-to-text)</li> |
+| idefics2 | ✅ | <li>Single card</li> | <li>[image to text](https://github.com/huggingface/optimum-habana/tree/main/examples/image-to-text)</li> |
| SAM | | <li>Single card</li> | <li>[object segmentation](https://github.com/huggingface/optimum-habana/tree/main/examples/object-segementation)</li> |
| VideoMAE | | <li>Single card</li> | <li>[Video classification](https://github.com/huggingface/optimum-habana/tree/main/examples/video-classification)</li> |
| TableTransformer | | <li>Single card</li> | <li>[table object detection](https://github.com/huggingface/optimum-habana/tree/main/examples/table-detection)</li> |
diff --git a/examples/image-to-text/README.md b/examples/image-to-text/README.md
index 203edcb77b..7654809ae7 100644
--- a/examples/image-to-text/README.md
+++ b/examples/image-to-text/README.md
@@ -28,6 +28,7 @@ Models that have been validated:
 - [llava-hf/llava-v1.6-mistral-7b-hf](https://huggingface.co/llava-hf/llava-v1.6-mistral-7b-hf)
 - [llava-hf/llava-v1.6-vicuna-7b-hf](https://huggingface.co/llava-hf/llava-v1.6-vicuna-7b-hf)
 - [llava-hf/llava-v1.6-vicuna-13b-hf](https://huggingface.co/llava-hf/llava-v1.6-vicuna-13b-hf)
+ - [HuggingFaceM4/idefics2-8b](https://huggingface.co/HuggingFaceM4/idefics2-8b)

### Inference with BF16
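In bf16, the flow that the image-to-text example wires up for idefics2 boils down to the sketch below. This is a simplified, standalone rendering: the sample image path, the device handling and the omission of lazy-mode/HPU-graph flags are assumptions, not code taken from the patch.

```python
import torch
from PIL import Image
from transformers import AutoModelForVision2Seq, AutoProcessor

from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi

# Swap in the Gaudi-optimized idefics2 classes before loading the model
adapt_transformers_to_gaudi()

processor = AutoProcessor.from_pretrained("HuggingFaceM4/idefics2-8b")
model = AutoModelForVision2Seq.from_pretrained(
    "HuggingFaceM4/idefics2-8b", torch_dtype=torch.bfloat16
).to("hpu")

image = Image.open("example.jpg")  # placeholder image, assumed to exist
messages = [
    {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "Describe the image."}]}
]
prompt = processor.apply_chat_template(messages, add_generation_prompt=True)

inputs = processor(text=[prompt.strip()], images=[image], return_tensors="pt")
inputs = {k: v.to("hpu") for k, v in inputs.items()}
generated_ids = model.generate(**inputs, max_new_tokens=64)
print(processor.batch_decode(generated_ids[:, inputs["input_ids"].size(1) :], skip_special_tokens=True))
```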
From 998574e3c278611604840918585a25ac31755812 Mon Sep 17 00:00:00 2001
From: "Wang, Yi A"
Date: Tue, 10 Sep 2024 07:12:24 -0700
Subject: [PATCH 08/12] fix doc

Signed-off-by: Wang, Yi A
---
 README.md             | 2 +-
 docs/source/index.mdx | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/README.md b/README.md
index 91587ebe9e..f09704cf15 100644
--- a/README.md
+++ b/README.md
@@ -214,7 +214,7 @@ The following model architectures, tasks and device distributions have been vali
| OWLViT | | <li>Single card</li> | <li>[zero shot object detection](https://github.com/huggingface/optimum-habana/tree/main/examples/zero-shot-object-detection)</li> |
| ClipSeg | | <li>Single card</li> | <li>[object segmentation](https://github.com/huggingface/optimum-habana/tree/main/examples/object-segementation)</li> |
| Llava / Llava-next | | <li>Single card</li> | <li>[image to text](https://github.com/huggingface/optimum-habana/tree/main/examples/image-to-text)</li> |
-| idefics2 | :heavy_check_mark: | <li>Single card</li> | <li>[image to text](https://github.com/huggingface/optimum-habana/tree/main/examples/image-to-text)</li> |
+| idefics2 | <li>LoRA</li> | <li>Single card</li> | <li>[image to text](https://github.com/huggingface/optimum-habana/tree/main/examples/image-to-text)</li> |
| Segment Anything Model | | <li>Single card</li> | <li>[object segmentation](https://github.com/huggingface/optimum-habana/tree/main/examples/object-segementation)</li> |
| VideoMAE | | <li>Single card</li> | <li>[Video classification](https://github.com/huggingface/optimum-habana/tree/main/examples/video-classification)</li> |
| TableTransformer | | <li>Single card</li> | <li>[table object detection](https://github.com/huggingface/optimum-habana/tree/main/examples/table-detection)</li> |
diff --git a/docs/source/index.mdx b/docs/source/index.mdx
index 306d0c786c..6aaade1d8a 100644
--- a/docs/source/index.mdx
+++ b/docs/source/index.mdx
@@ -72,7 +72,7 @@ In the tables below, ✅ means single-card, multi-card and DeepSpeed have all be
| OWLViT | | <li>Single card</li> | <li>[zero shot object detection](https://github.com/huggingface/optimum-habana/tree/main/examples/zero-shot-object-detection)</li> |
| ClipSeg | | <li>Single card</li> | <li>[object segmentation](https://github.com/huggingface/optimum-habana/tree/main/examples/object-segementation)</li> |
| Llava / Llava-next | | <li>Single card</li> | <li>[image to text](https://github.com/huggingface/optimum-habana/tree/main/examples/image-to-text)</li> |
-| idefics2 | ✅ | <li>Single card</li> | <li>[image to text](https://github.com/huggingface/optimum-habana/tree/main/examples/image-to-text)</li> |
+| idefics2 | <li>LoRA</li> | <li>Single card</li> | <li>[image to text](https://github.com/huggingface/optimum-habana/tree/main/examples/image-to-text)</li> |
| SAM | | <li>Single card</li> | <li>[object segmentation](https://github.com/huggingface/optimum-habana/tree/main/examples/object-segementation)</li> |
| VideoMAE | | <li>Single card</li> | <li>[Video classification](https://github.com/huggingface/optimum-habana/tree/main/examples/video-classification)</li> |
| TableTransformer | | <li>Single card</li> | <li>[table object detection](https://github.com/huggingface/optimum-habana/tree/main/examples/table-detection)</li> |
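The "LoRA" entries introduced above map to the adapter configuration exercised by `tests/baselines/idefics2_8b.json`. As a rough sketch of that setup (only the rank, alpha, dropout and target-module regex come from the baseline; the peft calls and the model-loading line are illustrative assumptions):

```python
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForVision2Seq

model = AutoModelForVision2Seq.from_pretrained("HuggingFaceM4/idefics2-8b")

lora_config = LoraConfig(
    r=8,               # --lora_rank 8
    lora_alpha=8,      # --lora_alpha 8
    lora_dropout=0.1,  # --lora_dropout 0.1
    # Regex taken from the baseline's --lora_target_modules argument
    target_modules=".*(text_model|modality_projection|perceiver_resampler).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$",
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
```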
  • | From 960b730ce37ed93495ac1b4276b1d9c4a156ac59 Mon Sep 17 00:00:00 2001 From: "Wang, Yi A" Date: Sat, 14 Sep 2024 00:43:30 -0700 Subject: [PATCH 09/12] add static shape support in image process, replace unfold with conv2d to speedup finetune Signed-off-by: Wang, Yi A --- .../run_image2text_lora_finetune.py | 3 +- optimum/habana/transformers/modeling_utils.py | 4 + .../habana/transformers/models/__init__.py | 7 +- .../transformers/models/idefics2/__init__.py | 3 +- .../idefics2/image_processing_idefics2.py | 84 +++++++++++++++++++ .../models/idefics2/modeling_idefics2.py | 62 ++++++++++++-- tests/baselines/idefics2_8b.json | 6 +- 7 files changed, 155 insertions(+), 14 deletions(-) create mode 100644 optimum/habana/transformers/models/idefics2/image_processing_idefics2.py diff --git a/examples/image-to-text/run_image2text_lora_finetune.py b/examples/image-to-text/run_image2text_lora_finetune.py index d5e23344c2..9ee5af3f77 100644 --- a/examples/image-to-text/run_image2text_lora_finetune.py +++ b/examples/image-to-text/run_image2text_lora_finetune.py @@ -382,7 +382,7 @@ def main(): do_image_splitting=model_args.do_image_splitting, padding_side="right", ) - + setattr(processor.image_processor, "pad_to_longest_edge", True) config_kwargs = { "cache_dir": model_args.cache_dir, "revision": model_args.model_revision, @@ -503,7 +503,6 @@ def main(): ], } ] - processor.tokenizer.padding_side = "left" text = processor.apply_chat_template(messages, add_generation_prompt=True) inputs = processor( text=[text.strip()], diff --git a/optimum/habana/transformers/modeling_utils.py b/optimum/habana/transformers/modeling_utils.py index 41007b62bd..a9efd272ee 100644 --- a/optimum/habana/transformers/modeling_utils.py +++ b/optimum/habana/transformers/modeling_utils.py @@ -28,6 +28,7 @@ from .models import ( DeciLMConfig, DeciLMForCausalLM, + Gaudi2Idefics2ImageProcessor, GaudiBloomForCausalLM, GaudiBloomMLP, GaudiCLIPAttention, @@ -58,6 +59,7 @@ GaudiGPTNeoXLayer, GaudiIdefics2ForConditionalGeneration, GaudiIdefics2Model, + GaudiIdefics2VisionEmbeddings, GaudiLlamaAttention, GaudiLlamaDecoderLayer, GaudiLlamaDynamicNTKScalingRotaryEmbedding, @@ -387,6 +389,8 @@ def adapt_transformers_to_gaudi(): GaudiIdefics2ForConditionalGeneration ) transformers.models.idefics2.modeling_idefics2.Idefics2Model = GaudiIdefics2Model + transformers.models.idefics2.image_processing_idefics2.Idefics2ImageProcessor = Gaudi2Idefics2ImageProcessor + transformers.models.idefics2.modeling_idefics2.Idefics2VisionEmbeddings = GaudiIdefics2VisionEmbeddings # Optimization for Clip on Gaudi transformers.models.clip.modeling_clip.CLIPVisionEmbeddings = GaudiCLIPVisionEmbeddings diff --git a/optimum/habana/transformers/models/__init__.py b/optimum/habana/transformers/models/__init__.py index ab61fc05af..8b6e305883 100644 --- a/optimum/habana/transformers/models/__init__.py +++ b/optimum/habana/transformers/models/__init__.py @@ -96,7 +96,12 @@ GaudiGPTJForCausalLM, GaudiGPTJModel, ) -from .idefics2 import GaudiIdefics2ForConditionalGeneration, GaudiIdefics2Model +from .idefics2 import ( + Gaudi2Idefics2ImageProcessor, + GaudiIdefics2ForConditionalGeneration, + GaudiIdefics2Model, + GaudiIdefics2VisionEmbeddings, +) from .llama import ( GaudiLlamaAttention, GaudiLlamaDecoderLayer, diff --git a/optimum/habana/transformers/models/idefics2/__init__.py b/optimum/habana/transformers/models/idefics2/__init__.py index a862776ad9..5c52749e7a 100644 --- a/optimum/habana/transformers/models/idefics2/__init__.py +++ 
b/optimum/habana/transformers/models/idefics2/__init__.py @@ -1 +1,2 @@ -from .modeling_idefics2 import GaudiIdefics2ForConditionalGeneration, GaudiIdefics2Model +from .image_processing_idefics2 import Gaudi2Idefics2ImageProcessor +from .modeling_idefics2 import GaudiIdefics2ForConditionalGeneration, GaudiIdefics2Model, GaudiIdefics2VisionEmbeddings diff --git a/optimum/habana/transformers/models/idefics2/image_processing_idefics2.py b/optimum/habana/transformers/models/idefics2/image_processing_idefics2.py new file mode 100644 index 0000000000..3bf45a36de --- /dev/null +++ b/optimum/habana/transformers/models/idefics2/image_processing_idefics2.py @@ -0,0 +1,84 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import Iterable, List, Optional, Union + +import numpy as np +from transformers.image_processing_utils import BatchFeature +from transformers.image_utils import ChannelDimension, infer_channel_dimension_format +from transformers.models.idefics2.image_processing_idefics2 import ( + Idefics2ImageProcessor, + get_max_height_width, + make_pixel_mask, +) +from transformers.utils import TensorType + + +class Gaudi2Idefics2ImageProcessor(Idefics2ImageProcessor): + def pad( + self, + images: List[np.ndarray], + constant_values: Union[float, Iterable[float]] = 0, + return_pixel_mask: bool = True, + return_tensors: Optional[Union[str, TensorType]] = None, + data_format: Optional[ChannelDimension] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ) -> BatchFeature: + """ + Inherits from Idefics2ImageProcessor::pad https://github.com/huggingface/transformers/blob/v4.43.4/src/transformers/models/idefics2/image_processing_idefics2.py#L314 + The only differences are: + - pad size use longest_edge, so the image size will not change, aims to accelerate finetune speed + """ + + if getattr(self, "pad_to_longest_edge", False): + pad_size = (self.size["longest_edge"], self.size["longest_edge"]) + else: + pad_size = get_max_height_width(images, input_data_format=input_data_format) + + batch_size = len(images) + max_num_images = max(len(images_) for images_ in images) + input_data_format = ( + infer_channel_dimension_format(images[0][0]) if input_data_format is None else input_data_format + ) + data_format = input_data_format if data_format is None else data_format + + def empty_image(size, input_data_format): + if input_data_format == ChannelDimension.FIRST: + return np.zeros((3, *size), dtype=np.uint8) + elif input_data_format == ChannelDimension.LAST: + return np.zeros((*size, 3), dtype=np.uint8) + raise ValueError("Invalid channel dimension format.") + + padded_images_list = [ + [empty_image(pad_size, data_format) for _ in range(max_num_images)] for _ in range(batch_size) + ] + padded_masks = [[np.zeros(pad_size) for _ in range(max_num_images)] for _ in range(batch_size)] + + for batch_idx in range(batch_size): + for sample_idx, image in enumerate(images[batch_idx]): + 
padded_images_list[batch_idx][sample_idx] = self._pad_image( + image, + pad_size, + constant_values=constant_values, + data_format=data_format, + input_data_format=input_data_format, + ) + padded_masks[batch_idx][sample_idx] = make_pixel_mask( + image, output_size=pad_size, input_data_format=input_data_format + ) + + padded_masks = padded_masks if return_pixel_mask else None + return padded_images_list, padded_masks diff --git a/optimum/habana/transformers/models/idefics2/modeling_idefics2.py b/optimum/habana/transformers/models/idefics2/modeling_idefics2.py index 324edd2812..7b92bca9c3 100644 --- a/optimum/habana/transformers/models/idefics2/modeling_idefics2.py +++ b/optimum/habana/transformers/models/idefics2/modeling_idefics2.py @@ -25,6 +25,7 @@ Idefics2CausalLMOutputWithPast, Idefics2ForConditionalGeneration, Idefics2Model, + Idefics2VisionEmbeddings, ) from transformers.utils import logging @@ -32,6 +33,44 @@ logger = logging.get_logger(__name__) +class GaudiIdefics2VisionEmbeddings(Idefics2VisionEmbeddings): + def forward(self, pixel_values: torch.FloatTensor, patch_attention_mask: torch.BoolTensor) -> torch.Tensor: + """ + Inherits from Idefics2VisionEmbeddings::forward https://github.com/huggingface/transformers/blob/v4.43.4/src/transformers/models/idefics2/modeling_idefics2.py#L159 + The only differences are: + - add int() in nb_patches_h. nb_patches_w to avoid overflow in torch.arange. sometimes return shape is nb_patches_h/nb_patch_w + 1 + - delete to("cpu") of p_attn_mask + """ + + batch_size, _, max_im_h, max_im_w = pixel_values.shape + + patch_embeds = self.patch_embedding(pixel_values) + embeddings = patch_embeds.flatten(2).transpose(1, 2) + + max_nb_patches_h, max_nb_patches_w = max_im_h // self.patch_size, max_im_w // self.patch_size + boundaries = torch.arange(1 / self.num_patches_per_side, 1.0, 1 / self.num_patches_per_side) + position_ids = torch.full( + size=(batch_size, max_nb_patches_h * max_nb_patches_w), + fill_value=0, + device=self.position_embedding.weight.device, + ) + + for batch_idx, p_attn_mask in enumerate(patch_attention_mask): + nb_patches_h = int(p_attn_mask[:, 0].sum()) + nb_patches_w = int(p_attn_mask[0].sum()) + + fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h) + fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w) + + bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True) + bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True) + + pos_ids = (bucket_coords_h[:, None] * self.num_patches_per_side + bucket_coords_w).flatten() + position_ids[batch_idx][p_attn_mask.view(-1)] = pos_ids.to(self.position_embedding.weight.device) + embeddings = embeddings + self.position_embedding(position_ids) + return embeddings + + class GaudiIdefics2Model(Idefics2Model): def forward( self, @@ -52,7 +91,7 @@ def forward( Inherits from Idefics2Model::forward https://github.com/huggingface/transformers/blob/v4.43.4/src/transformers/models/idefics2/modeling_idefics2.py#L1303 The only differences are: - ignoring new Cache path for HPU - - unfold is not supported in HPU, fallback to cpu + - unfold is not supported in HPU, replace with conv2d """ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -120,9 +159,13 @@ def forward( pixel_attention_mask = pixel_attention_mask[real_images_inds].contiguous() patch_size = self.config.vision_config.patch_size - patches_subgrid = pixel_attention_mask.cpu().unfold(dimension=1, 
size=patch_size, step=patch_size) - patches_subgrid = patches_subgrid.cpu().unfold(dimension=2, size=patch_size, step=patch_size) - patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool() + conv_kernel = torch.ones( + [1, 1, patch_size, patch_size], dtype=pixel_values.dtype, device=pixel_values.device + ) + patches_subgrid = torch.nn.functional.conv2d( + pixel_attention_mask.unsqueeze(1).to(conv_kernel.dtype), conv_kernel, stride=patch_size + ).squeeze(1) + patch_attention_mask = torch.eq(patches_subgrid, (patch_size * patch_size)) # Get sequence from the vision encoder image_hidden_states = self.vision_model( @@ -345,6 +388,7 @@ def prepare_inputs_for_generation( The only differences are: - add new args token_idx - add None "Cache" past_key_values support + - move vision_model to prepare_input_for_generation """ past_length = 0 token_idx = kwargs.get("token_idx", None) @@ -430,9 +474,13 @@ def prepare_inputs_for_generation( pixel_attention_mask = pixel_attention_mask[real_images_inds].contiguous() patch_size = self.config.vision_config.patch_size - patches_subgrid = pixel_attention_mask.cpu().unfold(dimension=1, size=patch_size, step=patch_size) - patches_subgrid = patches_subgrid.cpu().unfold(dimension=2, size=patch_size, step=patch_size) - patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool() + conv_kernel = torch.ones( + [1, 1, patch_size, patch_size], dtype=pixel_values.dtype, device=pixel_values.device + ) + patches_subgrid = torch.nn.functional.conv2d( + pixel_attention_mask.unsqueeze(1).to(conv_kernel.dtype), conv_kernel, stride=patch_size + ).squeeze(1) + patch_attention_mask = torch.eq(patches_subgrid, (patch_size * patch_size)) # Get sequence from the vision encoder image_hidden_states = self.model.vision_model( diff --git a/tests/baselines/idefics2_8b.json b/tests/baselines/idefics2_8b.json index 52ae5dd4d9..c70c8efef8 100644 --- a/tests/baselines/idefics2_8b.json +++ b/tests/baselines/idefics2_8b.json @@ -7,9 +7,9 @@ "multi_card": { "learning_rate": 5e-5, "train_batch_size": 2, - "train_runtime": 420, - "train_samples_per_second": 6.728, - "eval_accuracy": 0.54, + "train_runtime": 240, + "train_samples_per_second": 14, + "eval_accuracy": 0.6, "extra_arguments": [ "--bf16", "--gradient_accumulation_steps 8", From baca7e90e674d1f679f76b81231403a516f77f76 Mon Sep 17 00:00:00 2001 From: "Wang, Yi A" Date: Tue, 17 Sep 2024 06:26:18 -0700 Subject: [PATCH 10/12] blip does not has chat template Signed-off-by: Wang, Yi A --- examples/image-to-text/run_pipeline.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/image-to-text/run_pipeline.py b/examples/image-to-text/run_pipeline.py index 66db3b9a4a..49d59615bc 100644 --- a/examples/image-to-text/run_pipeline.py +++ b/examples/image-to-text/run_pipeline.py @@ -162,7 +162,7 @@ def main(): "https://github.com/haotian-liu/LLaVA/blob/1a91fc274d7c35a9b50b3cb29c4247ae5837ce39/images/llava_v1_5_radar.jpg?raw=true" ] - if args.prompt is None: + if args.prompt is None and model_type in ["llava", "idefics2", "llava_next"]: processor = AutoProcessor.from_pretrained(args.model_name_or_path) conversation = [ { From 09a42ad1570f23c59d4b5eb5f302d6a0fccf0c8c Mon Sep 17 00:00:00 2001 From: "Wang, Yi A" Date: Tue, 17 Sep 2024 22:34:43 -0700 Subject: [PATCH 11/12] update README and test since default "use_kv_cache" is False if it's not indicated explicitily Signed-off-by: Wang, Yi A --- examples/image-to-text/README.md | 90 +++++++++++++++++------------ tests/test_image_to_text_example.py | 10 ++-- 2 
files changed, 60 insertions(+), 40 deletions(-) diff --git a/examples/image-to-text/README.md b/examples/image-to-text/README.md index 509ccbcd02..6d66697db1 100644 --- a/examples/image-to-text/README.md +++ b/examples/image-to-text/README.md @@ -40,6 +40,7 @@ python3 run_pipeline.py \ --model_name_or_path Salesforce/blip-image-captioning-large \ --image_path "https://ankur3107.github.io/assets/images/image-captioning-example.png" \ --use_hpu_graphs \ + --use_kv_cache \ --bf16 ``` @@ -48,6 +49,7 @@ To run Llava-1.5-7b inference, use the following command: python3 run_pipeline.py \ --model_name_or_path llava-hf/llava-1.5-7b-hf \ --use_hpu_graphs \ + --use_kv_cache \ --bf16 ``` @@ -56,6 +58,7 @@ To run Llava-1.5-13b inference, use the following command: python3 run_pipeline.py \ --model_name_or_path llava-hf/llava-1.5-13b-hf \ --use_hpu_graphs \ + --use_kv_cache \ --bf16 ``` @@ -64,6 +67,7 @@ To run Llava-v1.6-mistral-7b inference, use the following command: python3 run_pipeline.py \ --model_name_or_path llava-hf/llava-v1.6-mistral-7b-hf \ --use_hpu_graphs \ + --use_kv_cache \ --bf16 ``` @@ -72,6 +76,7 @@ To run Llava-v1.6-vicuna-13b inference, use the following command: python3 run_pipeline.py \ --model_name_or_path llava-hf/llava-v1.6-vicuna-13b-hf \ --use_hpu_graphs \ + --use_kv_cache \ --bf16 ``` @@ -81,6 +86,7 @@ To run Llava-hf/llava-v1.6-34b-hf inference, use the following command: python3 run_pipeline.py \ --model_name_or_path llava-hf/llava-v1.6-34b-hf \ --use_hpu_graphs \ + --use_kv_cache \ --bf16 ``` @@ -90,6 +96,7 @@ To run Llava-hf/llama3-llava-next-8b-hf inference, use the following command: python3 run_pipeline.py \ --model_name_or_path llava-hf/llama3-llava-next-8b-hf \ --use_hpu_graphs \ + --use_kv_cache \ --bf16 ``` @@ -99,6 +106,7 @@ To run idefics2 inference, use the following command: python3 run_pipeline.py \ --model_name_or_path HuggingFaceM4/idefics2-8b \ --use_hpu_graphs \ + --use_kv_cache \ --bf16 ``` @@ -111,56 +119,62 @@ https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Inference_Using_FP Here is an example to measure the tensor quantization statistics on Llava-1.5-7b: ```bash QUANT_CONFIG=./quantization_config/maxabs_measure.json python run_pipeline.py \ ---model_name_or_path llava-hf/llava-1.5-7b-hf \ ---image_path "https://llava-vl.github.io/static/images/view.jpg" \ ---use_hpu_graphs \ ---bf16 + --model_name_or_path llava-hf/llava-1.5-7b-hf \ + --image_path "https://llava-vl.github.io/static/images/view.jpg" \ + --use_hpu_graphs \ + --use_kv_cache \ + --bf16 ``` Here is an example to quantize the model based on previous measurements for Llava-1.5-7b: ```bash QUANT_CONFIG=./quantization_config/maxabs_quant.json python run_pipeline.py \ ---model_name_or_path llava-hf/llava-1.5-7b-hf \ ---image_path "https://llava-vl.github.io/static/images/view.jpg" \ ---use_hpu_graphs \ ---bf16 + --model_name_or_path llava-hf/llava-1.5-7b-hf \ + --image_path "https://llava-vl.github.io/static/images/view.jpg" \ + --use_hpu_graphs \ + --use_kv_cache \ + --bf16 ``` Here is an example to measure the tensor quantization statistics on Llava-v1.6-mistral-7b: ```bash QUANT_CONFIG=./quantization_config/maxabs_measure.json python run_pipeline.py \ ---model_name_or_path llava-hf/llava-v1.6-mistral-7b-hf \ ---image_path "https://llava-vl.github.io/static/images/view.jpg" \ ---use_hpu_graphs \ ---bf16 + --model_name_or_path llava-hf/llava-v1.6-mistral-7b-hf \ + --image_path "https://llava-vl.github.io/static/images/view.jpg" \ + --use_hpu_graphs \ + --use_kv_cache \ + --bf16 
``` Here is an example to quantize the model based on previous measurements for Llava-v1.6-mistral-7b: ```bash QUANT_CONFIG=./quantization_config/maxabs_quant.json python run_pipeline.py \ ---model_name_or_path llava-hf/llava-v1.6-mistral-7b-hf \ ---image_path "https://llava-vl.github.io/static/images/view.jpg" \ ---use_hpu_graphs \ ---bf16 + --model_name_or_path llava-hf/llava-v1.6-mistral-7b-hf \ + --image_path "https://llava-vl.github.io/static/images/view.jpg" \ + --use_hpu_graphs \ + --use_kv_cache \ + --bf16 ``` Here is an example to measure the tensor quantization statistics on Llava-v1.6-vicuna-13b: ```bash QUANT_CONFIG=./quantization_config/maxabs_measure.json python run_pipeline.py \ ---model_name_or_path llava-hf/llava-v1.6-vicuna-13b-hf \ ---image_path "https://llava-vl.github.io/static/images/view.jpg" \ ---use_hpu_graphs \ ---bf16 + --model_name_or_path llava-hf/llava-v1.6-vicuna-13b-hf \ + --image_path "https://llava-vl.github.io/static/images/view.jpg" \ + --use_hpu_graphs \ + --use_kv_cache \ + --bf16 ``` Here is an example to quantize the model based on previous measurements for Llava-v1.6-vicuna-13b: ```bash QUANT_CONFIG=./quantization_config/maxabs_quant.json python run_pipeline.py \ ---model_name_or_path llava-hf/llava-v1.6-vicuna-13b-hf \ ---image_path "https://llava-vl.github.io/static/images/view.jpg" \ ---use_hpu_graphs \ ---bf16 + --model_name_or_path llava-hf/llava-v1.6-vicuna-13b-hf \ + --image_path "https://llava-vl.github.io/static/images/view.jpg" \ + --use_hpu_graphs \ + --use_kv_cache \ + --bf16 ``` ### Inference with FusedSDPA @@ -173,6 +187,7 @@ python3 run_pipeline.py \ --model_name_or_path llava-hf/llava-1.5-7b-hf \ --image_path "https://llava-vl.github.io/static/images/view.jpg" \ --use_hpu_graphs \ + --use_kv_cache \ --bf16 \ --use_flash_attention \ --flash_attention_recompute @@ -185,6 +200,7 @@ python3 run_pipeline.py \ --model_name_or_path llava-hf/llava-v1.6-mistral-7b-hf \ --image_path "https://llava-vl.github.io/static/images/view.jpg" \ --use_hpu_graphs \ + --use_kv_cache \ --bf16 \ --use_flash_attention \ --flash_attention_recompute @@ -196,23 +212,25 @@ Use the following commands to run Llava-v1.6-mistral-7b FP8 inference with Fused Here is an example of measuring the tensor quantization statistics on Llava-v1.6-mistral-7b: ```bash QUANT_CONFIG=./quantization_config/maxabs_measure.json python run_pipeline.py \ ---model_name_or_path llava-hf/llava-v1.6-mistral-7b-hf \ ---image_path "https://llava-vl.github.io/static/images/view.jpg" \ ---use_hpu_graphs \ ---bf16 \ ---use_flash_attention \ ---flash_attention_recompute + --model_name_or_path llava-hf/llava-v1.6-mistral-7b-hf \ + --image_path "https://llava-vl.github.io/static/images/view.jpg" \ + --use_hpu_graphs \ + --use_kv_cache \ + --bf16 \ + --use_flash_attention \ + --flash_attention_recompute ``` Here is an example of quantizing the model based on previous measurements for Llava-v1.6-mistral-7b: ```bash QUANT_CONFIG=./quantization_config/maxabs_quant.json python run_pipeline.py \ ---model_name_or_path llava-hf/llava-v1.6-mistral-7b-hf \ ---image_path "https://llava-vl.github.io/static/images/view.jpg" \ ---use_hpu_graphs \ ---bf16 \ ---use_flash_attention \ ---flash_attention_recompute + --model_name_or_path llava-hf/llava-v1.6-mistral-7b-hf \ + --image_path "https://llava-vl.github.io/static/images/view.jpg" \ + --use_hpu_graphs \ + --use_kv_cache \ + --bf16 \ + --use_flash_attention \ + --flash_attention_recompute ``` ## LORA Finetune diff --git a/tests/test_image_to_text_example.py 
b/tests/test_image_to_text_example.py index ddfc6092a2..33e88e1fe2 100644 --- a/tests/test_image_to_text_example.py +++ b/tests/test_image_to_text_example.py @@ -14,16 +14,16 @@ # Gaudi2 CI baselines MODELS_TO_TEST = { "bf16": [ - ("llava-hf/llava-1.5-7b-hf", 1, 87.2901500056982), + ("llava-hf/llava-1.5-7b-hf", 1, 82.3422128290106), ("llava-hf/llava-1.5-13b-hf", 1, 51.04717105443364), ("llava-hf/llava-v1.6-mistral-7b-hf", 1, 33.17984878151546), ("llava-hf/llava-v1.6-vicuna-7b-hf", 1, 35.00608681379742), ("llava-hf/llava-v1.6-vicuna-13b-hf", 1, 23.527610042925), - ("HuggingFaceM4/idefics2-8b", 1, 24.07768894366222), + ("HuggingFaceM4/idefics2-8b", 1, 21.89944593215077), ], "fp8": [ - ("llava-hf/llava-1.5-7b-hf", 1, 115.48515989461843), - ("llava-hf/llava-1.5-13b-hf", 1, 78.2635142547838), + ("llava-hf/llava-1.5-7b-hf", 1, 105.25707848037551), + ("llava-hf/llava-1.5-13b-hf", 1, 66.40730104076319), ("llava-hf/llava-v1.6-mistral-7b-hf", 1, 45.011551008367084), ("llava-hf/llava-v1.6-vicuna-7b-hf", 1, 45.18544502949674), ("llava-hf/llava-v1.6-vicuna-13b-hf", 1, 30.9535718774675), @@ -58,6 +58,8 @@ def _test_image_to_text( f"--model_name_or_path {model_name}", f"--batch_size {batch_size}", "--max_new_tokens 20", + "--ignore_eos", + "--use_kv_cache", ] command += [ From a282dbeb50cea16a4ae846e37f3a4095c13c46ae Mon Sep 17 00:00:00 2001 From: "Wang, Yi A" Date: Thu, 10 Oct 2024 19:42:55 -0700 Subject: [PATCH 12/12] remove token_idx in needs_tensor_out Signed-off-by: Wang, Yi A --- optimum/habana/transformers/generation/stopping_criteria.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/optimum/habana/transformers/generation/stopping_criteria.py b/optimum/habana/transformers/generation/stopping_criteria.py index be14dc73d7..844ffa50f2 100644 --- a/optimum/habana/transformers/generation/stopping_criteria.py +++ b/optimum/habana/transformers/generation/stopping_criteria.py @@ -79,7 +79,7 @@ def gaudi_EosTokenCriteria_call( return torch.all(is_done).item() -def needs_tensor_output(token_idx, ignore_eos, eos_token_id) -> bool: +def needs_tensor_output(ignore_eos, eos_token_id) -> bool: return not ignore_eos and eos_token_id is not None @@ -87,7 +87,7 @@ def gaudi_StoppingCriteriaList_call( self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs ) -> Union[torch.BoolTensor, bool]: kwargs["needs_tensor_output"] = needs_tensor_output( - kwargs.get("token_idx", None), kwargs.get("ignore_eos", True), kwargs.get("eos_token_id", None) + kwargs.get("ignore_eos", True), kwargs.get("eos_token_id", None) ) is_done = ( torch.full((input_ids.shape[0],), 0, device=input_ids.device, dtype=torch.int8)