diff --git a/README.md b/README.md
index 842c9725a7..d7a8c0c261 100644
--- a/README.md
+++ b/README.md
@@ -228,6 +228,7 @@ The following model architectures, tasks and device distributions have been vali
| OWLViT | | Single card | [zero shot object detection](https://github.com/huggingface/optimum-habana/tree/main/examples/zero-shot-object-detection) |
| ClipSeg | | Single card | [object segmentation](https://github.com/huggingface/optimum-habana/tree/main/examples/object-segementation) |
| Llava / Llava-next | | Single card | [image to text](https://github.com/huggingface/optimum-habana/tree/main/examples/image-to-text) |
+| Idefics2 | LoRA | Single card | [image to text](https://github.com/huggingface/optimum-habana/tree/main/examples/image-to-text) |
| Segment Anything Model | | Single card | [object segmentation](https://github.com/huggingface/optimum-habana/tree/main/examples/object-segementation) |
| VideoMAE | | Single card | [Video classification](https://github.com/huggingface/optimum-habana/tree/main/examples/video-classification) |
| TableTransformer | | Single card | [table object detection](https://github.com/huggingface/optimum-habana/tree/main/examples/table-detection) |
diff --git a/docs/source/index.mdx b/docs/source/index.mdx
index 947e41b6fe..18ac4ada6f 100644
--- a/docs/source/index.mdx
+++ b/docs/source/index.mdx
@@ -74,6 +74,7 @@ In the tables below, ✅ means single-card, multi-card and DeepSpeed have all be
| OWLViT | | Single card | [zero shot object detection](https://github.com/huggingface/optimum-habana/tree/main/examples/zero-shot-object-detection) |
| ClipSeg | | Single card | [object segmentation](https://github.com/huggingface/optimum-habana/tree/main/examples/object-segementation) |
| Llava / Llava-next | | Single card | [image to text](https://github.com/huggingface/optimum-habana/tree/main/examples/image-to-text) |
+| Idefics2 | LoRA | Single card | [image to text](https://github.com/huggingface/optimum-habana/tree/main/examples/image-to-text) |
| SAM | | Single card | [object segmentation](https://github.com/huggingface/optimum-habana/tree/main/examples/object-segementation) |
| VideoMAE | | Single card | [Video classification](https://github.com/huggingface/optimum-habana/tree/main/examples/video-classification) |
| TableTransformer | | Single card | [table object detection](https://github.com/huggingface/optimum-habana/tree/main/examples/table-detection) |
diff --git a/examples/image-to-text/README.md b/examples/image-to-text/README.md
index 2ac99dc829..2ad32174ef 100644
--- a/examples/image-to-text/README.md
+++ b/examples/image-to-text/README.md
@@ -30,6 +30,7 @@ Models that have been validated:
- [llava-hf/llava-v1.6-vicuna-13b-hf](https://huggingface.co/llava-hf/llava-v1.6-vicuna-13b-hf)
- [llava-hf/llava-v1.6-34b-hf](https://huggingface.co/llava-hf/llava-v1.6-34b-hf)
- [llava-hf/llama3-llava-next-8b-hf](https://huggingface.co/llava-hf/llama3-llava-next-8b-hf)
+ - [HuggingFaceM4/idefics2-8b](https://huggingface.co/HuggingFaceM4/idefics2-8b)
### Inference with BF16
@@ -92,6 +93,15 @@ python3 run_pipeline.py \
--bf16
```
+To run idefics2 inference, use the following command:
+
+```bash
+python3 run_pipeline.py \
+ --model_name_or_path HuggingFaceM4/idefics2-8b \
+ --use_hpu_graphs \
+ --bf16
+```
+
### Inference with FP8
Inference for Llava-1.5-7b, Llava-1.5-13b, Llava-v1.6-mistral-7b and Llava-v1.6-vicuna-13b in FP8 precision is enabled using [Intel Neural Compressor (INC)](https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Inference_Using_FP8.html), which provides model measurement and quantization capabilities in PyTorch. INC is used by default for measurement and quantization. The Habana Quantization Toolkit (HQT), which was used previously, will be removed in future releases. To use HQT, disable INC by setting the following environment variable: `USE_INC=0`.
@@ -101,56 +111,56 @@ https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Inference_Using_FP
Here is an example to measure the tensor quantization statistics on Llava-1.5-7b:
```bash
QUANT_CONFIG=./quantization_config/maxabs_measure.json python run_pipeline.py \
---model_name_or_path llava-hf/llava-1.5-7b-hf \
---image_path "https://llava-vl.github.io/static/images/view.jpg" \
---use_hpu_graphs \
---bf16
+ --model_name_or_path llava-hf/llava-1.5-7b-hf \
+ --image_path "https://llava-vl.github.io/static/images/view.jpg" \
+ --use_hpu_graphs \
+ --bf16
```
Here is an example to quantize the model based on previous measurements for Llava-1.5-7b:
```bash
QUANT_CONFIG=./quantization_config/maxabs_quant.json python run_pipeline.py \
---model_name_or_path llava-hf/llava-1.5-7b-hf \
---image_path "https://llava-vl.github.io/static/images/view.jpg" \
---use_hpu_graphs \
---bf16
+ --model_name_or_path llava-hf/llava-1.5-7b-hf \
+ --image_path "https://llava-vl.github.io/static/images/view.jpg" \
+ --use_hpu_graphs \
+ --bf16
```
Here is an example to measure the tensor quantization statistics on Llava-v1.6-mistral-7b:
```bash
QUANT_CONFIG=./quantization_config/maxabs_measure.json python run_pipeline.py \
---model_name_or_path llava-hf/llava-v1.6-mistral-7b-hf \
---image_path "https://llava-vl.github.io/static/images/view.jpg" \
---use_hpu_graphs \
---bf16
+ --model_name_or_path llava-hf/llava-v1.6-mistral-7b-hf \
+ --image_path "https://llava-vl.github.io/static/images/view.jpg" \
+ --use_hpu_graphs \
+ --bf16
```
Here is an example to quantize the model based on previous measurements for Llava-v1.6-mistral-7b:
```bash
QUANT_CONFIG=./quantization_config/maxabs_quant.json python run_pipeline.py \
---model_name_or_path llava-hf/llava-v1.6-mistral-7b-hf \
---image_path "https://llava-vl.github.io/static/images/view.jpg" \
---use_hpu_graphs \
---bf16
+ --model_name_or_path llava-hf/llava-v1.6-mistral-7b-hf \
+ --image_path "https://llava-vl.github.io/static/images/view.jpg" \
+ --use_hpu_graphs \
+ --bf16
```
Here is an example to measure the tensor quantization statistics on Llava-v1.6-vicuna-13b:
```bash
QUANT_CONFIG=./quantization_config/maxabs_measure.json python run_pipeline.py \
---model_name_or_path llava-hf/llava-v1.6-vicuna-13b-hf \
---image_path "https://llava-vl.github.io/static/images/view.jpg" \
---use_hpu_graphs \
---bf16
+ --model_name_or_path llava-hf/llava-v1.6-vicuna-13b-hf \
+ --image_path "https://llava-vl.github.io/static/images/view.jpg" \
+ --use_hpu_graphs \
+ --bf16
```
Here is an example to quantize the model based on previous measurements for Llava-v1.6-vicuna-13b:
```bash
QUANT_CONFIG=./quantization_config/maxabs_quant.json python run_pipeline.py \
---model_name_or_path llava-hf/llava-v1.6-vicuna-13b-hf \
---image_path "https://llava-vl.github.io/static/images/view.jpg" \
---use_hpu_graphs \
---bf16
+ --model_name_or_path llava-hf/llava-v1.6-vicuna-13b-hf \
+ --image_path "https://llava-vl.github.io/static/images/view.jpg" \
+ --use_hpu_graphs \
+ --bf16
```
### Inference with FusedSDPA
@@ -186,21 +196,92 @@ Use the following commands to run Llava-v1.6-mistral-7b FP8 inference with Fused
Here is an example of measuring the tensor quantization statistics on Llava-v1.6-mistral-7b:
```bash
QUANT_CONFIG=./quantization_config/maxabs_measure.json python run_pipeline.py \
---model_name_or_path llava-hf/llava-v1.6-mistral-7b-hf \
---image_path "https://llava-vl.github.io/static/images/view.jpg" \
---use_hpu_graphs \
---bf16 \
---use_flash_attention \
---flash_attention_recompute
+ --model_name_or_path llava-hf/llava-v1.6-mistral-7b-hf \
+ --image_path "https://llava-vl.github.io/static/images/view.jpg" \
+ --use_hpu_graphs \
+ --bf16 \
+ --use_flash_attention \
+ --flash_attention_recompute
```
Here is an example of quantizing the model based on previous measurements for Llava-v1.6-mistral-7b:
```bash
QUANT_CONFIG=./quantization_config/maxabs_quant.json python run_pipeline.py \
---model_name_or_path llava-hf/llava-v1.6-mistral-7b-hf \
---image_path "https://llava-vl.github.io/static/images/view.jpg" \
---use_hpu_graphs \
---bf16 \
---use_flash_attention \
---flash_attention_recompute
+ --model_name_or_path llava-hf/llava-v1.6-mistral-7b-hf \
+ --image_path "https://llava-vl.github.io/static/images/view.jpg" \
+ --use_hpu_graphs \
+ --bf16 \
+ --use_flash_attention \
+ --flash_attention_recompute
+```
+## LoRA Finetuning
+
+To run LoRA finetuning, you can use `run_image2text_lora_finetune.py`.
+Here is a single-card command example for HuggingFaceM4/idefics2-8b:
+
+```bash
+python3 run_image2text_lora_finetune.py \
+ --model_name_or_path HuggingFaceM4/idefics2-8b \
+ --dataset_name nielsr/docvqa_1200_examples \
+ --bf16 True \
+ --output_dir ./model_lora_llama \
+ --num_train_epochs 1 \
+ --per_device_train_batch_size 2 \
+ --per_device_eval_batch_size 2 \
+ --gradient_accumulation_steps 8 \
+ --weight_decay 0.01 \
+ --logging_steps 25 \
+ --eval_strategy "no" \
+ --save_strategy "no" \
+ --learning_rate 5e-5 \
+ --warmup_steps 50 \
+ --lr_scheduler_type "constant" \
+ --input_column_names 'image' 'query' \
+ --output_column_names 'answers' \
+ --remove_unused_columns False \
+ --do_train \
+ --do_eval \
+ --use_habana \
+ --use_lazy_mode \
+ --lora_rank=8 \
+ --lora_alpha=8 \
+ --lora_dropout=0.1 \
+ --max_seq_length=512 \
+ --use_hpu_graphs_for_inference \
+ --low_cpu_mem_usage True \
+ --lora_target_modules '.*(text_model|modality_projection|perceiver_resampler).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$'
+```
+
+Here is a multi-card (8-card) command example using `gaudi_spawn.py`:
+
+```bash
+python3 ../gaudi_spawn.py \
+ --world_size 8 --use_mpi run_image2text_lora_finetune.py \
+ --model_name_or_path HuggingFaceM4/idefics2-8b \
+ --dataset_name nielsr/docvqa_1200_examples \
+ --bf16 True \
+ --output_dir ./model_lora_llama \
+ --num_train_epochs 1 \
+ --per_device_train_batch_size 2 \
+ --per_device_eval_batch_size 2 \
+ --gradient_accumulation_steps 8 \
+ --weight_decay 0.01 \
+ --logging_steps 25 \
+ --eval_strategy "no" \
+ --save_strategy "no" \
+ --learning_rate 5e-5 \
+ --warmup_steps 50 \
+ --lr_scheduler_type "constant" \
+ --input_column_names 'image' 'query' \
+ --output_column_names 'answers' \
+ --remove_unused_columns False \
+ --do_train \
+ --do_eval \
+ --use_habana \
+ --use_lazy_mode \
+ --lora_rank=8 \
+ --lora_alpha=8 \
+ --lora_dropout=0.1 \
+ --max_seq_length=512 \
+ --use_hpu_graphs_for_inference \
+ --low_cpu_mem_usage True \
+ --lora_target_modules '".*(text_model|modality_projection|perceiver_resampler).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$"'
```
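+
+After finetuning, the LoRA adapter weights are saved to the directory passed to `--output_dir` (`./model_lora_llama` in the examples above). Below is a minimal sketch of how that adapter could be loaded with PEFT and merged into the base model for inference, mirroring what the finetuning script itself does with `merge_and_unload()` before its sample generation; adjust the adapter path and dtype to your run:
+
+```python
+import torch
+from peft import PeftModel
+from transformers import AutoModelForVision2Seq, AutoProcessor
+
+from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
+
+# Apply the Gaudi-specific model patches before loading, as done in run_pipeline.py.
+adapt_transformers_to_gaudi()
+
+# Load the base model, attach the finetuned LoRA adapter, then merge it for plain inference.
+processor = AutoProcessor.from_pretrained("HuggingFaceM4/idefics2-8b")
+base_model = AutoModelForVision2Seq.from_pretrained("HuggingFaceM4/idefics2-8b", torch_dtype=torch.bfloat16)
+model = PeftModel.from_pretrained(base_model, "./model_lora_llama")
+model = model.merge_and_unload()
+```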
diff --git a/examples/image-to-text/requirements.txt b/examples/image-to-text/requirements.txt
new file mode 100644
index 0000000000..045bd69e64
--- /dev/null
+++ b/examples/image-to-text/requirements.txt
@@ -0,0 +1,2 @@
+peft == 0.12.0
+Levenshtein
diff --git a/examples/image-to-text/run_image2text_lora_finetune.py b/examples/image-to-text/run_image2text_lora_finetune.py
new file mode 100644
index 0000000000..9ee5af3f77
--- /dev/null
+++ b/examples/image-to-text/run_image2text_lora_finetune.py
@@ -0,0 +1,548 @@
+# Apache v2 license
+# Copyright (C) 2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+LoRA fine-tuning script for the image-to-text task.
+Adapted from the following source:
+https://colab.research.google.com/drive/1rm3AGquGEYXfeeizE40bbDtcWh5S4Nlq?usp=sharing
+"""
+
+import logging
+import os
+import random
+import sys
+from dataclasses import dataclass, field
+from typing import List, Optional
+
+import Levenshtein
+import torch
+import transformers
+from datasets import load_dataset
+from peft import LoraConfig, get_peft_model
+from transformers import (
+ AutoConfig,
+ AutoModelForVision2Seq,
+ AutoProcessor,
+ HfArgumentParser,
+)
+from transformers.trainer_utils import is_main_process
+
+from optimum.habana import GaudiConfig, GaudiTrainer, GaudiTrainingArguments
+
+
+try:
+ from optimum.habana.utils import check_optimum_habana_min_version
+except ImportError:
+
+ def check_optimum_habana_min_version(*a, **b):
+ return ()
+
+
+os.environ["WANDB_DISABLED"] = "true"
+
+logger = logging.getLogger(__name__)
+
+# Will error if the minimal version of Optimum Habana is not installed. Remove at your own risks.
+check_optimum_habana_min_version("1.10.0")
+
+
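+# ANLS (Average Normalized Levenshtein Similarity) is the usual metric for DocVQA-style question
+# answering: each prediction is scored as 1 - normalized edit distance against the closest
+# ground-truth answer, with scores whose normalized distance exceeds the threshold tau set to 0.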
+def normalized_levenshtein(s1, s2):
+ len_s1, len_s2 = len(s1), len(s2)
+ distance = Levenshtein.distance(s1, s2)
+ return distance / max(len_s1, len_s2)
+
+
+def similarity_score(a_ij, o_q_i, tau=0.5):
+ nl = normalized_levenshtein(a_ij, o_q_i)
+ return 1 - nl if nl < tau else 0
+
+
+def average_normalized_levenshtein_similarity(ground_truth, predicted_answers):
+ assert len(ground_truth) == len(predicted_answers), "Length of ground_truth and predicted_answers must match."
+
+ N = len(ground_truth)
+ total_score = 0
+
+ for i in range(N):
+ a_i = ground_truth[i]
+ o_q_i = predicted_answers[i]
+ if o_q_i == "":
+ print("Warning: Skipped an empty prediction.")
+ max_score = 0
+ else:
+ max_score = max(similarity_score(a_ij, o_q_i) for a_ij in a_i)
+
+ total_score += max_score
+
+ return total_score / N
+
+
+@dataclass
+class ModelArguments:
+ """
+ Arguments pertaining to which model/config/processor we are going to fine-tune, or train from scratch.
+ """
+
+ model_name_or_path: Optional[str] = field(
+ default=None,
+ metadata={
+ "help": "The model checkpoint for weights initialization."
+ "Don't set if you want to train a model from scratch."
+ },
+ )
+ config_name: Optional[str] = field(
+ default=None,
+ metadata={"help": "Pretrained config name or path if not the same as model_name"},
+ )
+ cache_dir: Optional[str] = field(
+ default=None,
+ metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
+ )
+ token: Optional[str] = field(
+ default=None,
+ metadata={"help": "auth token for private models"},
+ )
+ use_fast_tokenizer: bool = field(
+ default=True,
+ metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
+ )
+ model_revision: str = field(
+ default="main",
+ metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
+ )
+ trust_remote_code: bool = field(
+ default=False,
+ metadata={
+ "help": (
+ "Whether to trust the execution of code from datasets/models defined on the Hub."
+ " This option should only be set to `True` for repositories you trust and in which you have read the"
+ " code, as it will execute code present on the Hub on your local machine."
+ )
+ },
+ )
+ use_cache: bool = field(
+ default=True,
+ metadata={
+ "help": (
+ "Whether or not the model should return the last key/values attentions (not used by all models)."
+ "Only relevant if `config.is_decoder=True`."
+ )
+ },
+ )
+ low_cpu_mem_usage: bool = field(
+ default=False,
+ metadata={
+ "help": (
+ "It is an option to create the model as an empty shell, then only materialize its parameters when the pretrained weights are loaded."
+ "When set to True, it will benefit LLM loading time and RAM consumption."
+ )
+ },
+ )
+ load_meta_device: bool = field(
+ default=False,
+ metadata={
+ "help": (
+ "It is an option to load the model to the device instead of the host, so it can reduce the host RAM usage."
+ "https://huggingface.co/blog/accelerate-large-models"
+ )
+ },
+ )
+    do_image_splitting: bool = field(default=False, metadata={"help": "Whether to do image splitting during finetuning."})
+
+
+@dataclass
+class DataArguments:
+ """
+ Arguments pertaining to what data we are going to input our model for training and eval.
+ """
+
+ dataset_name: Optional[str] = field(
+ default=None,
+ metadata={"help": "The name of the dataset to use (via the datasets library)."},
+ )
+ dataset_config_name: Optional[str] = field(
+ default=None,
+ metadata={"help": "The configuration name of the dataset to use (via the datasets library)."},
+ )
+ max_seq_length: Optional[int] = field(
+ default=512,
+ metadata={
+ "help": "The maximum total input sequence length after tokenization. Sequences longer "
+ "than this will be truncated."
+ },
+ )
+ overwrite_cache: bool = field(
+ default=False,
+ metadata={"help": "Overwrite the cached preprocessed datasets or not."},
+ )
+ max_train_samples: Optional[int] = field(
+ default=None,
+ metadata={
+ "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ },
+ )
+ max_eval_samples: Optional[int] = field(
+ default=None,
+ metadata={
+ "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
+ "value if set."
+ },
+ )
+ dataset_seed: int = field(
+ default=42,
+ metadata={
+ "help": "Seed to use in dataset processing, different seeds might yield different datasets. This seed and the seed in training arguments are not related"
+ },
+ )
+ save_last_ckpt: bool = field(
+ default=True, metadata={"help": "Whether to save checkpoint at the end of the training."}
+ )
+ input_column_names: List[str] = field(
+ default_factory=lambda: None,
+ metadata={
+ "help": "Name of the column in the dataset that optionally provides context or input for the task. By "
+ "default, 'image,query' columns are used"
+ },
+ )
+ output_column_names: List[str] = field(
+ default_factory=lambda: None,
+ metadata={
+ "help": "Name of the column in the dataset with the answer to the instruction. By default, the "
+ "'answers' column is used"
+ },
+ )
+
+
+@dataclass
+class FinetuneArguments:
+ """
+ Arguments of finetune we are going to apply on the model.
+ """
+
+ lora_rank: int = field(
+ default=8,
+ metadata={"help": "Rank parameter in the LoRA method."},
+ )
+ lora_alpha: int = field(
+ default=8,
+ metadata={"help": "Alpha parameter in the LoRA method."},
+ )
+ lora_dropout: float = field(
+ default=0.1,
+ metadata={"help": "Dropout parameter in the LoRA method."},
+ )
+ lora_target_modules: str = field(
+ default=None,
+ metadata={"help": "Target modules for the LoRA/AdaLoRA method."},
+ )
+
+
+class MyDataCollator:
+ def __init__(self, processor, max_seq_length):
+ self.processor = processor
+ self.image_token_id = processor.tokenizer.additional_special_tokens_ids[
+            processor.tokenizer.additional_special_tokens.index("<image>")
+ ]
+ self.max_seq_length = max_seq_length
+
+ def __call__(self, examples):
+ texts = []
+ images = []
+ keys = list(examples[0].keys())
+ if not all(key in ["image", "query", "answers"] for key in keys):
+ raise ValueError("Unsupported dataset format")
+ for example in examples:
+ image = example["image"]
+ question = example["query"]["en"]
+ answer = random.choice(example["answers"])
+ messages = [
+ {
+ "role": "user",
+ "content": [
+ {"type": "text", "text": "Answer briefly."},
+ {"type": "image"},
+ {"type": "text", "text": question},
+ ],
+ },
+ {"role": "assistant", "content": [{"type": "text", "text": answer}]},
+ ]
+ text = self.processor.apply_chat_template(messages, add_generation_prompt=False)
+ texts.append(text.strip())
+ images.append([image])
+
+ batch = self.processor(
+ text=texts,
+ images=images,
+ return_tensors="pt",
+ padding="max_length",
+ truncation=True,
+ max_length=self.max_seq_length,
+ )
+
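+        # Mask out padding and image tokens in the labels so they are ignored (-100) by the LM loss.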
+ labels = batch["input_ids"].clone()
+ labels[labels == self.processor.tokenizer.pad_token_id] = -100
+ labels[labels == self.image_token_id] = -100
+ batch["labels"] = labels
+
+ return batch
+
+
+def eval(processor, model, dataset, batch_size, use_lazy_mode, use_hpu_graphs, max_seq_length):
+ from tqdm import tqdm
+
+ answers_unique = []
+ generated_texts_unique = []
+
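+    # Generate an answer for every example in batches on HPU, then score the predictions
+    # against the reference answers with ANLS.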
+ for i in tqdm(range(0, len(dataset), batch_size)):
+ examples = dataset[i : i + batch_size]
+ answers_unique.extend(examples["answers"])
+ images = [[im] for im in examples["image"]]
+ texts = []
+ for q in examples["query"]:
+ messages = [
+ {
+ "role": "user",
+ "content": [
+ {"type": "text", "text": "Answer briefly."},
+ {"type": "image"},
+ {"type": "text", "text": q["en"]},
+ ],
+ }
+ ]
+ text = processor.apply_chat_template(messages, add_generation_prompt=True)
+ texts.append(text.strip())
+ inputs = processor(
+ text=texts,
+ images=images,
+ return_tensors="pt",
+ padding="max_length",
+ truncation=True,
+ max_length=max_seq_length,
+ )
+ inputs = {k: v.to("hpu") for k, v in inputs.items()}
+ generated_ids = model.generate(
+ **inputs, max_new_tokens=64, ignore_eos=False, lazy_mode=use_lazy_mode, hpu_graphs=use_hpu_graphs
+ )
+ generated_texts = processor.batch_decode(
+ generated_ids[:, inputs["input_ids"].size(1) :], skip_special_tokens=True
+ )
+ generated_texts_unique.extend(generated_texts)
+ generated_texts_unique = [g.strip().strip(".") for g in generated_texts_unique]
+ anls = average_normalized_levenshtein_similarity(
+ ground_truth=answers_unique,
+ predicted_answers=generated_texts_unique,
+ )
+ return anls
+
+
+def main():
+ parser = HfArgumentParser((ModelArguments, DataArguments, GaudiTrainingArguments, FinetuneArguments))
+ if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
+ # If we pass only one argument to the script and it's the path to a json file,
+ # let's parse it to get our arguments.
+ model_args, data_args, training_args, finetune_args = parser.parse_json_file(
+ json_file=os.path.abspath(sys.argv[1])
+ )
+ else:
+ (
+ model_args,
+ data_args,
+ training_args,
+ finetune_args,
+ ) = parser.parse_args_into_dataclasses()
+
+ # Setup logging
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ handlers=[logging.StreamHandler(sys.stdout)],
+ )
+ logger.setLevel(logging.INFO if is_main_process(training_args.local_rank) else logging.WARN)
+
+ if is_main_process(training_args.local_rank):
+ transformers.utils.logging.set_verbosity_info()
+ transformers.utils.logging.enable_default_handler()
+ transformers.utils.logging.enable_explicit_format()
+
+ processor = AutoProcessor.from_pretrained(
+ model_args.model_name_or_path,
+ do_image_splitting=model_args.do_image_splitting,
+ padding_side="right",
+ )
+ setattr(processor.image_processor, "pad_to_longest_edge", True)
+ config_kwargs = {
+ "cache_dir": model_args.cache_dir,
+ "revision": model_args.model_revision,
+ "trust_remote_code": True if model_args.trust_remote_code else None,
+ "use_cache": False if training_args.gradient_checkpointing else model_args.use_cache,
+ "token": model_args.token,
+ }
+ if model_args.config_name:
+ config = AutoConfig.from_pretrained(model_args.config_name, **config_kwargs)
+ elif model_args.model_name_or_path:
+ config = AutoConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs)
+ else:
+ raise ValueError("Please provide value for model_name_or_path or config_name.")
+
+ # Load model
+ if model_args.model_name_or_path:
+ model_dtype = torch.bfloat16 if training_args.bf16 else None
+ model = AutoModelForVision2Seq.from_pretrained(
+ model_args.model_name_or_path,
+ from_tf=bool(".ckpt" in model_args.model_name_or_path),
+ config=config,
+ cache_dir=model_args.cache_dir,
+ revision=model_args.model_revision,
+ trust_remote_code=True if model_args.trust_remote_code else None,
+ torch_dtype=model_dtype,
+ low_cpu_mem_usage=model_args.low_cpu_mem_usage,
+ device_map=training_args.device.type if model_args.load_meta_device else None,
+ token=model_args.token,
+ )
+ else:
+        raise ValueError("Must provide model_name_or_path to load a pretrained Vision2Seq model.")
+
+ lora_config = LoraConfig(
+ r=finetune_args.lora_rank,
+ lora_alpha=finetune_args.lora_alpha,
+ lora_dropout=finetune_args.lora_dropout,
+ target_modules=finetune_args.lora_target_modules,
+ init_lora_weights="gaussian",
+ )
+ model = get_peft_model(model, lora_config)
+ model.print_trainable_parameters()
+
+ train_dataset = load_dataset(
+ data_args.dataset_name,
+ data_args.dataset_config_name,
+ cache_dir=model_args.cache_dir,
+ token=model_args.token,
+ trust_remote_code=model_args.trust_remote_code,
+ split="train",
+ )
+
+ train_dataset = train_dataset.remove_columns(
+ [
+ col
+ for col in train_dataset.column_names
+ if col not in (data_args.input_column_names + data_args.output_column_names)
+ ]
+ )
+
+ eval_dataset = load_dataset(
+ data_args.dataset_name,
+ data_args.dataset_config_name,
+ cache_dir=model_args.cache_dir,
+ token=model_args.token,
+ trust_remote_code=model_args.trust_remote_code,
+ split="test",
+ )
+
+ eval_dataset = eval_dataset.remove_columns(
+ [
+ col
+ for col in eval_dataset.column_names
+ if col not in (data_args.input_column_names + data_args.output_column_names)
+ ]
+ )
+
+ data_collator = MyDataCollator(processor, max_seq_length=data_args.max_seq_length)
+
+ gaudi_config = GaudiConfig()
+ gaudi_config.use_fused_adam = True
+ gaudi_config.use_fused_clip_norm = True
+
+ trainer = GaudiTrainer(
+ model=model,
+ args=training_args,
+ gaudi_config=gaudi_config,
+ data_collator=data_collator,
+ train_dataset=train_dataset,
+ eval_dataset=eval_dataset,
+ )
+
+ if training_args.do_train:
+ train_result = trainer.train()
+ trainer.save_model()
+ metrics = train_result.metrics
+ trainer.log_metrics("train", metrics)
+ trainer.save_metrics("train", metrics)
+
+ if is_main_process(training_args.local_rank):
+ processor.tokenizer.padding_side = "left"
+
+ example = eval_dataset[15]
+ model.eval()
+ model = model.merge_and_unload()
+ if model_dtype == torch.bfloat16:
+ model = model.to(torch.bfloat16)
+
+ image = example["image"]
+ query = example["query"]
+
+ messages = [
+ {
+ "role": "user",
+ "content": [
+ {"type": "text", "text": "Answer briefly."},
+ {"type": "image"},
+ {"type": "text", "text": query["en"]},
+ ],
+ }
+ ]
+ text = processor.apply_chat_template(messages, add_generation_prompt=True)
+ inputs = processor(
+ text=[text.strip()],
+ images=[image],
+ return_tensors="pt",
+ padding="max_length",
+ truncation=True,
+ max_length=data_args.max_seq_length,
+ )
+ inputs = {k: v.to("hpu") for k, v in inputs.items()}
+ generated_ids = model.generate(
+ **inputs,
+ max_new_tokens=64,
+ ignore_eos=False,
+ lazy_mode=training_args.use_lazy_mode,
+ hpu_graphs=training_args.use_hpu_graphs_for_inference,
+ )
+ generated_texts = processor.batch_decode(
+ generated_ids[:, inputs["input_ids"].size(1) :], skip_special_tokens=True
+ )
+ logger.info(f"generated: {generated_texts}")
+ if training_args.do_eval:
+ if training_args.use_hpu_graphs_for_inference:
+ from habana_frameworks.torch.hpu import wrap_in_hpu_graph
+
+ model = wrap_in_hpu_graph(model)
+
+ anls = eval(
+ processor=processor,
+ model=model,
+ dataset=eval_dataset,
+ batch_size=training_args.per_device_eval_batch_size,
+ use_lazy_mode=training_args.use_lazy_mode,
+ use_hpu_graphs=training_args.use_hpu_graphs_for_inference,
+ max_seq_length=data_args.max_seq_length,
+ )
+ eval_metrics = {"eval_accuracy": anls}
+ trainer.log_metrics("eval", eval_metrics)
+ trainer.save_metrics("eval", eval_metrics)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/examples/image-to-text/run_pipeline.py b/examples/image-to-text/run_pipeline.py
index d80939b43f..00ade6aa5b 100644
--- a/examples/image-to-text/run_pipeline.py
+++ b/examples/image-to-text/run_pipeline.py
@@ -23,7 +23,7 @@
import PIL.Image
import requests
import torch
-from transformers import AutoConfig, LlavaNextProcessor, LlavaProcessor, pipeline
+from transformers import AutoConfig, AutoProcessor, pipeline
from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
@@ -155,17 +155,15 @@ def main():
adapt_transformers_to_gaudi()
model_type = AutoConfig.from_pretrained(args.model_name_or_path).model_type
- if args.image_path is None and model_type == "llava":
+ if args.image_path is None and model_type in ["llava", "idefics2"]:
args.image_path = ["https://llava-vl.github.io/static/images/view.jpg"]
elif args.image_path is None and model_type == "llava_next":
args.image_path = [
"https://github.com/haotian-liu/LLaVA/blob/1a91fc274d7c35a9b50b3cb29c4247ae5837ce39/images/llava_v1_5_radar.jpg?raw=true"
]
- if args.prompt is None and model_type in ("llava", "llava_next"):
- if model_type == "llava":
- processor = LlavaProcessor.from_pretrained(args.model_name_or_path)
- elif model_type == "llava_next":
- processor = LlavaNextProcessor.from_pretrained(args.model_name_or_path)
+
+ if args.prompt is None and model_type in ["llava", "idefics2", "llava_next"]:
+ processor = AutoProcessor.from_pretrained(args.model_name_or_path)
conversation = [
{
"role": "user",
@@ -228,6 +226,17 @@ def main():
if args.quant_config:
generator.model = setup_quantization(generator.model, args)
+    # TODO: delete once the pipeline integrates AutoProcessor as its preprocessing engine
+ if model_type in ["idefics2"]:
+ from transformers.image_utils import load_image
+
+ def preprocess(self, image, prompt=None, timeout=None):
+ image = load_image(image, timeout=timeout)
+ model_inputs = processor(images=image, text=prompt, return_tensors=self.framework)
+ return model_inputs
+
+ generator.__class__.preprocess = preprocess
+
# warm up
for i in range(args.warmup):
generator(images, prompt=args.prompt, batch_size=args.batch_size, generate_kwargs=generate_kwargs)
diff --git a/optimum/habana/transformers/generation/stopping_criteria.py b/optimum/habana/transformers/generation/stopping_criteria.py
index 69325ab7b3..844ffa50f2 100644
--- a/optimum/habana/transformers/generation/stopping_criteria.py
+++ b/optimum/habana/transformers/generation/stopping_criteria.py
@@ -38,8 +38,10 @@ def gaudi_MaxLengthCriteria_call(
) -> Union[torch.BoolTensor, bool]:
token_idx = kwargs.get("token_idx", None)
if token_idx is not None:
- assert not kwargs["needs_tensor_output"]
- return token_idx >= self.max_length
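+        # With static shapes (token_idx is set), return a plain Python bool unless the caller
+        # requested a tensor result, in which case broadcast the boolean into a constant tensor.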
+ if not kwargs["needs_tensor_output"]:
+ return token_idx >= self.max_length
+ else:
+ return create_return_const_tensor(input_ids, token_idx >= self.max_length)
else:
cur_len = input_ids.shape[-1]
is_done = cur_len >= self.max_length
@@ -68,7 +70,6 @@ def gaudi_EosTokenCriteria_call(
self.eos_token_id = self.eos_token_id.to(input_ids.device)
token_idx = kwargs.get("token_idx", None)
if token_idx is not None:
- assert not kwargs["needs_tensor_output"]
is_done = torch.isin(input_ids[:, token_idx - 1], self.eos_token_id)
else:
is_done = torch.isin(input_ids[:, -1], self.eos_token_id)
@@ -78,19 +79,15 @@ def gaudi_EosTokenCriteria_call(
return torch.all(is_done).item()
-def needs_tensor_output(token_idx, ignore_eos, eos_token_id) -> bool:
- if token_idx is None:
- return not ignore_eos and eos_token_id is not None
- else:
- # token_idx is present, so we have static shapes, so using single boolean
- return False
+def needs_tensor_output(ignore_eos, eos_token_id) -> bool:
+ return not ignore_eos and eos_token_id is not None
def gaudi_StoppingCriteriaList_call(
self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
) -> Union[torch.BoolTensor, bool]:
kwargs["needs_tensor_output"] = needs_tensor_output(
- kwargs.get("token_idx", None), kwargs.get("ignore_eos", True), kwargs.get("eos_token_id", None)
+ kwargs.get("ignore_eos", True), kwargs.get("eos_token_id", None)
)
is_done = (
torch.full((input_ids.shape[0],), 0, device=input_ids.device, dtype=torch.int8)
diff --git a/optimum/habana/transformers/generation/utils.py b/optimum/habana/transformers/generation/utils.py
index a76ea59e87..35e83663dc 100644
--- a/optimum/habana/transformers/generation/utils.py
+++ b/optimum/habana/transformers/generation/utils.py
@@ -108,6 +108,7 @@
"qwen2_moe",
"gemma",
"whisper",
+ "idefics2",
]
diff --git a/optimum/habana/transformers/modeling_utils.py b/optimum/habana/transformers/modeling_utils.py
index 2fd24148be..010d3abeb3 100644
--- a/optimum/habana/transformers/modeling_utils.py
+++ b/optimum/habana/transformers/modeling_utils.py
@@ -30,6 +30,7 @@
GAUDI_WHISPER_ATTENTION_CLASSES,
DeciLMConfig,
DeciLMForCausalLM,
+ Gaudi2Idefics2ImageProcessor,
GaudiBloomForCausalLM,
GaudiBloomMLP,
GaudiCLIPAttention,
@@ -64,6 +65,9 @@
GaudiGPTNeoXAttention,
GaudiGPTNeoXForCausalLM,
GaudiGPTNeoXLayer,
+ GaudiIdefics2ForConditionalGeneration,
+ GaudiIdefics2Model,
+ GaudiIdefics2VisionEmbeddings,
GaudiLlamaAttention,
GaudiLlamaDecoderLayer,
GaudiLlamaDynamicNTKScalingRotaryEmbedding,
@@ -410,6 +414,14 @@ def adapt_transformers_to_gaudi():
GaudiLlavaNextForConditionalGeneration
)
+ # Optimization for idefics2 on Gaudi
+ transformers.models.idefics2.modeling_idefics2.Idefics2ForConditionalGeneration = (
+ GaudiIdefics2ForConditionalGeneration
+ )
+ transformers.models.idefics2.modeling_idefics2.Idefics2Model = GaudiIdefics2Model
+ transformers.models.idefics2.image_processing_idefics2.Idefics2ImageProcessor = Gaudi2Idefics2ImageProcessor
+ transformers.models.idefics2.modeling_idefics2.Idefics2VisionEmbeddings = GaudiIdefics2VisionEmbeddings
+
# Optimization for Clip on Gaudi
transformers.models.clip.modeling_clip.CLIPVisionEmbeddings = GaudiCLIPVisionEmbeddings
transformers.models.clip.modeling_clip.CLIPAttention = GaudiCLIPAttention
diff --git a/optimum/habana/transformers/models/__init__.py b/optimum/habana/transformers/models/__init__.py
index 8c9a045efa..7e781272c3 100644
--- a/optimum/habana/transformers/models/__init__.py
+++ b/optimum/habana/transformers/models/__init__.py
@@ -103,6 +103,12 @@
GaudiGPTJForCausalLM,
GaudiGPTJModel,
)
+from .idefics2 import (
+ Gaudi2Idefics2ImageProcessor,
+ GaudiIdefics2ForConditionalGeneration,
+ GaudiIdefics2Model,
+ GaudiIdefics2VisionEmbeddings,
+)
from .llama import (
GaudiLlamaAttention,
GaudiLlamaDecoderLayer,
diff --git a/optimum/habana/transformers/models/idefics2/__init__.py b/optimum/habana/transformers/models/idefics2/__init__.py
new file mode 100644
index 0000000000..5c52749e7a
--- /dev/null
+++ b/optimum/habana/transformers/models/idefics2/__init__.py
@@ -0,0 +1,2 @@
+from .image_processing_idefics2 import Gaudi2Idefics2ImageProcessor
+from .modeling_idefics2 import GaudiIdefics2ForConditionalGeneration, GaudiIdefics2Model, GaudiIdefics2VisionEmbeddings
diff --git a/optimum/habana/transformers/models/idefics2/image_processing_idefics2.py b/optimum/habana/transformers/models/idefics2/image_processing_idefics2.py
new file mode 100644
index 0000000000..3bf45a36de
--- /dev/null
+++ b/optimum/habana/transformers/models/idefics2/image_processing_idefics2.py
@@ -0,0 +1,84 @@
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from typing import Iterable, List, Optional, Union
+
+import numpy as np
+from transformers.image_processing_utils import BatchFeature
+from transformers.image_utils import ChannelDimension, infer_channel_dimension_format
+from transformers.models.idefics2.image_processing_idefics2 import (
+ Idefics2ImageProcessor,
+ get_max_height_width,
+ make_pixel_mask,
+)
+from transformers.utils import TensorType
+
+
+class Gaudi2Idefics2ImageProcessor(Idefics2ImageProcessor):
+ def pad(
+ self,
+ images: List[np.ndarray],
+ constant_values: Union[float, Iterable[float]] = 0,
+ return_pixel_mask: bool = True,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ data_format: Optional[ChannelDimension] = None,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ ) -> BatchFeature:
+ """
+ Inherits from Idefics2ImageProcessor::pad https://github.com/huggingface/transformers/blob/v4.43.4/src/transformers/models/idefics2/image_processing_idefics2.py#L314
+ The only differences are:
+        - the pad size uses `longest_edge`, so the image size does not change across batches, which is intended to speed up finetuning
+ """
+
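+        # Padding every image to a fixed (longest_edge, longest_edge) size keeps tensor shapes
+        # constant across batches, which is the source of the finetuning speed-up noted above.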
+ if getattr(self, "pad_to_longest_edge", False):
+ pad_size = (self.size["longest_edge"], self.size["longest_edge"])
+ else:
+ pad_size = get_max_height_width(images, input_data_format=input_data_format)
+
+ batch_size = len(images)
+ max_num_images = max(len(images_) for images_ in images)
+ input_data_format = (
+ infer_channel_dimension_format(images[0][0]) if input_data_format is None else input_data_format
+ )
+ data_format = input_data_format if data_format is None else data_format
+
+ def empty_image(size, input_data_format):
+ if input_data_format == ChannelDimension.FIRST:
+ return np.zeros((3, *size), dtype=np.uint8)
+ elif input_data_format == ChannelDimension.LAST:
+ return np.zeros((*size, 3), dtype=np.uint8)
+ raise ValueError("Invalid channel dimension format.")
+
+ padded_images_list = [
+ [empty_image(pad_size, data_format) for _ in range(max_num_images)] for _ in range(batch_size)
+ ]
+ padded_masks = [[np.zeros(pad_size) for _ in range(max_num_images)] for _ in range(batch_size)]
+
+ for batch_idx in range(batch_size):
+ for sample_idx, image in enumerate(images[batch_idx]):
+ padded_images_list[batch_idx][sample_idx] = self._pad_image(
+ image,
+ pad_size,
+ constant_values=constant_values,
+ data_format=data_format,
+ input_data_format=input_data_format,
+ )
+ padded_masks[batch_idx][sample_idx] = make_pixel_mask(
+ image, output_size=pad_size, input_data_format=input_data_format
+ )
+
+ padded_masks = padded_masks if return_pixel_mask else None
+ return padded_images_list, padded_masks
diff --git a/optimum/habana/transformers/models/idefics2/modeling_idefics2.py b/optimum/habana/transformers/models/idefics2/modeling_idefics2.py
new file mode 100644
index 0000000000..7b92bca9c3
--- /dev/null
+++ b/optimum/habana/transformers/models/idefics2/modeling_idefics2.py
@@ -0,0 +1,510 @@
+# coding=utf-8
+# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch Idefics2 model."""
+
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.utils.checkpoint
+from torch.nn import CrossEntropyLoss
+from transformers.cache_utils import Cache, DynamicCache
+from transformers.models.idefics2.modeling_idefics2 import (
+ Idefics2BaseModelOutputWithPast,
+ Idefics2CausalLMOutputWithPast,
+ Idefics2ForConditionalGeneration,
+ Idefics2Model,
+ Idefics2VisionEmbeddings,
+)
+from transformers.utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class GaudiIdefics2VisionEmbeddings(Idefics2VisionEmbeddings):
+ def forward(self, pixel_values: torch.FloatTensor, patch_attention_mask: torch.BoolTensor) -> torch.Tensor:
+ """
+ Inherits from Idefics2VisionEmbeddings::forward https://github.com/huggingface/transformers/blob/v4.43.4/src/transformers/models/idefics2/modeling_idefics2.py#L159
+ The only differences are:
+        - cast nb_patches_h and nb_patches_w to int to avoid numerical issues in torch.arange, whose result could otherwise have nb_patches_h/nb_patches_w + 1 elements
+        - remove the to("cpu") call on p_attn_mask
+ """
+
+ batch_size, _, max_im_h, max_im_w = pixel_values.shape
+
+ patch_embeds = self.patch_embedding(pixel_values)
+ embeddings = patch_embeds.flatten(2).transpose(1, 2)
+
+ max_nb_patches_h, max_nb_patches_w = max_im_h // self.patch_size, max_im_w // self.patch_size
+ boundaries = torch.arange(1 / self.num_patches_per_side, 1.0, 1 / self.num_patches_per_side)
+ position_ids = torch.full(
+ size=(batch_size, max_nb_patches_h * max_nb_patches_w),
+ fill_value=0,
+ device=self.position_embedding.weight.device,
+ )
+
+ for batch_idx, p_attn_mask in enumerate(patch_attention_mask):
+ nb_patches_h = int(p_attn_mask[:, 0].sum())
+ nb_patches_w = int(p_attn_mask[0].sum())
+
+ fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h)
+ fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w)
+
+ bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True)
+ bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True)
+
+ pos_ids = (bucket_coords_h[:, None] * self.num_patches_per_side + bucket_coords_w).flatten()
+ position_ids[batch_idx][p_attn_mask.view(-1)] = pos_ids.to(self.position_embedding.weight.device)
+ embeddings = embeddings + self.position_embedding(position_ids)
+ return embeddings
+
+
+class GaudiIdefics2Model(Idefics2Model):
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ pixel_attention_mask: Optional[torch.BoolTensor] = None,
+ image_hidden_states: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, Idefics2BaseModelOutputWithPast]:
+ """
+ Inherits from Idefics2Model::forward https://github.com/huggingface/transformers/blob/v4.43.4/src/transformers/models/idefics2/modeling_idefics2.py#L1303
+ The only differences are:
+        - ignore the new Cache path, which is not used on HPU
+        - `unfold` is not supported on HPU, so it is replaced with an equivalent conv2d
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if self.training and self.text_model.gradient_checkpointing and use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+ )
+ use_cache = False
+
+ # retrieve input_ids and inputs_embeds
+ if input_ids is not None:
+ batch_size, seq_length = input_ids.shape
+ elif inputs_embeds is not None:
+ batch_size, seq_length, _ = inputs_embeds.shape
+ else:
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+ past_seen_tokens = 0
+ return_legacy_cache = True
+ use_new_cache = False # Ignoring new Cache path for HPU
+ if use_cache and use_new_cache:
+ if not isinstance(past_key_values, Cache): # kept for BC (non `Cache` `past_key_values` inputs)
+ return_legacy_cache = True
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+ past_seen_tokens = past_key_values.get_seq_length()
+ else:
+ if past_key_values is not None:
+ past_seen_tokens = past_key_values[0][0].shape[2]
+ if inputs_embeds is not None and input_ids is None and past_seen_tokens == 0:
+ raise ValueError("When first calling the model, if input_embeds are passed, input_ids should not be None.")
+
+ if inputs_embeds is None:
+ inputs_embeds = self.text_model.get_input_embeddings()(input_ids)
+
+ # START VISUAL INPUTS INTEGRATION
+ if pixel_values is not None and image_hidden_states is not None:
+ raise ValueError("You cannot specify both pixel_values and image_hidden_states at the same time")
+ elif pixel_values is not None:
+ batch_size, num_images, num_channels, height, width = pixel_values.shape
+ pixel_values = pixel_values.to(dtype=self.dtype) # fp16 compatibility
+ pixel_values = pixel_values.view(batch_size * num_images, *pixel_values.shape[2:])
+
+ # Remove padding images - padding images are full 0.
+ nb_values_per_image = pixel_values.shape[1:].numel()
+ real_images_inds = (pixel_values == 0.0).sum(dim=(-1, -2, -3)) != nb_values_per_image
+ pixel_values = pixel_values[real_images_inds].contiguous()
+
+ # Handle the vision attention mask
+ if pixel_attention_mask is None:
+ pixel_attention_mask = torch.ones(
+ size=(pixel_values.size(0), pixel_values.size(2), pixel_values.size(3)),
+ dtype=torch.bool,
+ device=pixel_values.device,
+ )
+ else:
+                # Remove padding images from the mask
+ pixel_attention_mask = pixel_attention_mask.view(
+ batch_size * num_images, *pixel_attention_mask.shape[2:]
+ )
+ pixel_attention_mask = pixel_attention_mask[real_images_inds].contiguous()
+
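+            # `unfold` is not supported on HPU, so the patch-level attention mask is derived with a
+            # ones conv2d kernel: each output element counts the valid pixels in a patch, and a patch
+            # is attended only when all patch_size * patch_size of its pixels are valid.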
+ patch_size = self.config.vision_config.patch_size
+ conv_kernel = torch.ones(
+ [1, 1, patch_size, patch_size], dtype=pixel_values.dtype, device=pixel_values.device
+ )
+ patches_subgrid = torch.nn.functional.conv2d(
+ pixel_attention_mask.unsqueeze(1).to(conv_kernel.dtype), conv_kernel, stride=patch_size
+ ).squeeze(1)
+ patch_attention_mask = torch.eq(patches_subgrid, (patch_size * patch_size))
+
+ # Get sequence from the vision encoder
+ image_hidden_states = self.vision_model(
+ pixel_values=pixel_values,
+ patch_attention_mask=patch_attention_mask,
+ ).last_hidden_state
+
+ # Modality projection & resampling
+ image_hidden_states = self.connector(
+ image_hidden_states, attention_mask=patch_attention_mask.view(pixel_values.size(0), -1)
+ )
+
+ elif image_hidden_states is not None:
+ image_hidden_states = image_hidden_states.to(dtype=self.dtype, device=input_ids.device)
+
+ if past_seen_tokens == 0 and inputs_embeds is not None and image_hidden_states is not None:
+ # When we generate, we don't want to replace the potential image_token_id that we generated by images
+ # that simply don't exist
+ inputs_embeds = self.inputs_merger(
+ input_ids=input_ids,
+ inputs_embeds=inputs_embeds,
+ image_hidden_states=image_hidden_states,
+ )
+
+ outputs = self.text_model(
+ inputs_embeds=inputs_embeds,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ if return_legacy_cache and use_cache:
+ if return_dict:
+ outputs.past_key_values = (
+ outputs.past_key_values.to_legacy_cache()
+ if isinstance(outputs.past_key_values, Cache)
+ else outputs.past_key_values
+ )
+ else:
+ outputs[1] = outputs[1].to_legacy_cache() if isinstance(outputs[1], Cache) else outputs[1]
+
+ if not return_dict:
+ return tuple(v for v in [*outputs, image_hidden_states] if v is not None)
+
+ return Idefics2BaseModelOutputWithPast(
+ last_hidden_state=outputs.last_hidden_state,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ image_hidden_states=image_hidden_states,
+ )
+
+ def inputs_merger(
+ self,
+ input_ids: torch.LongTensor,
+ inputs_embeds: Optional[torch.Tensor],
+ image_hidden_states: Optional[torch.Tensor],
+ ):
+ """
+ Inherits from Idefics2Model::inputs_merger https://github.com/huggingface/transformers/blob/v4.43.4/src/transformers/models/idefics2/modeling_idefics2.py#L1268
+ The only differences are:
+        - replace `==` with torch.where to work around an issue with HPU graphs
+ """
+ num_images, _, vision_hidden_size = image_hidden_states.shape
+ special_image_token_mask = torch.where(input_ids == self.image_token_id)
+ new_inputs_embeds = inputs_embeds.clone()
+ reshaped_image_hidden_states = image_hidden_states.view(-1, vision_hidden_size)
+ new_inputs_embeds[special_image_token_mask] = reshaped_image_hidden_states
+ return new_inputs_embeds
+
+
+class GaudiIdefics2ForConditionalGeneration(Idefics2ForConditionalGeneration):
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ pixel_attention_mask: Optional[torch.BoolTensor] = None,
+ image_hidden_states: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ token_idx: Optional[torch.Tensor] = None,
+ ) -> Union[Tuple, Idefics2CausalLMOutputWithPast]:
+ """
+ Inherits from Idefics2ForConditionalGeneration::forward https://github.com/huggingface/transformers/blob/v4.43.4/src/transformers/models/idefics2/modeling_idefics2.py#L1505
+ The only differences are:
+        - add the new arg token_idx
+ """
+ if token_idx is not None:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+ if input_ids is not None:
+ batch_size, seq_length = input_ids.shape
+ elif inputs_embeds is not None:
+ batch_size, seq_length, _ = inputs_embeds.shape
+ else:
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+ past_seen_tokens = 0
+ return_legacy_cache = True
+ use_new_cache = False # Ignoring new Cache path for HPU
+ if use_cache and use_new_cache:
+ if not isinstance(past_key_values, Cache): # kept for BC (non `Cache` `past_key_values` inputs)
+ return_legacy_cache = True
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+ past_seen_tokens = past_key_values.get_seq_length()
+ else:
+ if past_key_values is not None:
+ past_seen_tokens = past_key_values[0][0].shape[2]
+ if inputs_embeds is not None and input_ids is None and past_seen_tokens == 0:
+ raise ValueError(
+ "When first calling the model, if input_embeds are passed, input_ids should not be None."
+ )
+
+ if inputs_embeds is None:
+ inputs_embeds = self.model.text_model.get_input_embeddings()(input_ids)
+
+ # START VISUAL INPUTS INTEGRATION
+ if pixel_values is not None and image_hidden_states is not None:
+ raise ValueError("You cannot specify both pixel_values and image_hidden_states at the same time")
+ elif image_hidden_states is not None:
+ image_hidden_states = image_hidden_states.to(dtype=self.dtype, device=input_ids.device)
+
+ if past_seen_tokens == 0 and inputs_embeds is not None and image_hidden_states is not None:
+ # When we generate, we don't want to replace the potential image_token_id that we generated by images
+ # that simply don't exist
+ inputs_embeds = self.model.inputs_merger(
+ input_ids=input_ids,
+ inputs_embeds=inputs_embeds,
+ image_hidden_states=image_hidden_states,
+ )
+
+ outputs = self.model.text_model(
+ inputs_embeds=inputs_embeds,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ token_idx=token_idx,
+ )
+
+ if return_legacy_cache and use_cache:
+ if return_dict:
+ outputs.past_key_values = (
+ outputs.past_key_values.to_legacy_cache()
+ if isinstance(outputs.past_key_values, Cache)
+ else outputs.past_key_values
+ )
+ else:
+ outputs[1] = outputs[1].to_legacy_cache() if isinstance(outputs[1], Cache) else outputs[1]
+
+ hidden_states = outputs[0]
+ logits = self.lm_head(hidden_states)
+ logits = logits.float()
+
+ loss = None
+ if labels is not None:
+ labels = labels.to(logits.device)
+ # Shift so that tokens < n predict n
+ if attention_mask is not None:
+ shift_attention_mask = attention_mask[..., 1:].to(logits.device)
+ shift_logits = logits[..., :-1, :][shift_attention_mask != 0].contiguous()
+ shift_labels = labels[..., 1:][shift_attention_mask != 0].contiguous()
+ else:
+ shift_logits = logits[..., :-1, :].contiguous()
+ shift_labels = labels[..., 1:].contiguous()
+ # Flatten the tokens
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
+
+ if not return_dict:
+ output = (logits,) + outputs[1:]
+ return (loss,) + output if loss is not None else output
+
+ return Idefics2CausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ image_hidden_states=image_hidden_states,
+ )
+ else:
+ return super().forward(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ pixel_values=pixel_values,
+ pixel_attention_mask=pixel_attention_mask,
+ image_hidden_states=image_hidden_states,
+ labels=labels,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ def prepare_inputs_for_generation(
+ self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
+ ):
+ """
+        Inherits from Idefics2ForConditionalGeneration::prepare_inputs_for_generation https://github.com/huggingface/transformers/blob/v4.43.4/src/transformers/models/idefics2/modeling_idefics2.py#L1622
+        The only differences are:
+        - add the new arg token_idx
+        - support legacy (non-`Cache`) and None past_key_values
+        - move the vision model forward pass into prepare_inputs_for_generation
+ """
+ past_length = 0
+ token_idx = kwargs.get("token_idx", None)
+ # Omit tokens covered by past_key_values
+ if past_key_values is not None:
+ # Past key values are always initialized with a `Cache` object -> no need for if-else anymore
+ if isinstance(past_key_values, Cache):
+ past_length = past_key_values.get_seq_length()
+ max_cache_length = past_key_values.get_max_length()
+ else:
+ past_length = past_key_values[0][0].shape[2]
+ max_cache_length = None
+
+ # Keep only the unprocessed tokens:
+ # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
+ # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
+ # input)
+ if token_idx is not None:
+ input_ids = torch.index_select(input_ids, 1, token_idx - 1)
+ elif attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
+ input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
+ # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
+ # input_ids based on the past_length.
+ elif past_length < input_ids.shape[1]:
+ input_ids = input_ids[:, past_length:]
+ # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
+
+ # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
+ if (
+ max_cache_length is not None
+ and attention_mask is not None
+ and past_length + input_ids.shape[1] > max_cache_length
+ ):
+ attention_mask = attention_mask[:, -max_cache_length:]
+
+ position_ids = kwargs.get("position_ids", None)
+ if attention_mask is not None and position_ids is None:
+ # create position_ids on the fly for batch generation
+ position_ids = attention_mask.long().cumsum(-1) - 1
+ position_ids.masked_fill_(attention_mask == 0, 1)
+ if past_key_values:
+ if token_idx is not None:
+ position_ids = torch.index_select(position_ids, 1, token_idx - 1)
+ else:
+ position_ids = position_ids[:, -input_ids.shape[1] :]
+
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+ if inputs_embeds is not None and past_length == 0:
+ model_inputs = {"inputs_embeds": inputs_embeds}
+ else:
+ model_inputs = {"input_ids": input_ids}
+
+ image_hidden_states = kwargs.get("image_hidden_states", None)
+ if image_hidden_states is not None:
+ pixel_values = None
+ pixel_attention_mask = None
+ else:
+ pixel_values = kwargs.get("pixel_values", None)
+ pixel_attention_mask = kwargs.get("pixel_attention_mask", None)
+
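+        # In static-shape mode (token_idx is set), the vision encoder and modality projection are
+        # run here when pixel_values are provided, so that forward() only consumes the resulting
+        # image_hidden_states (see the docstring note about moving the vision model here).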
+ if token_idx is not None and pixel_values is not None:
+ batch_size, num_images, num_channels, height, width = pixel_values.shape
+ pixel_values = pixel_values.to(dtype=self.dtype) # fp16 compatibility
+ pixel_values = pixel_values.view(batch_size * num_images, *pixel_values.shape[2:])
+
+ # Remove padding images - padding images are full 0.
+ nb_values_per_image = pixel_values.shape[1:].numel()
+ real_images_inds = (pixel_values == 0.0).sum(dim=(-1, -2, -3)) != nb_values_per_image
+ pixel_values = pixel_values[real_images_inds].contiguous()
+
+ # Handle the vision attention mask
+ if pixel_attention_mask is None:
+ pixel_attention_mask = torch.ones(
+ size=(pixel_values.size(0), pixel_values.size(2), pixel_values.size(3)),
+ dtype=torch.bool,
+ device=pixel_values.device,
+ )
+ else:
+                # Remove padding images from the mask
+ pixel_attention_mask = pixel_attention_mask.view(
+ batch_size * num_images, *pixel_attention_mask.shape[2:]
+ )
+ pixel_attention_mask = pixel_attention_mask[real_images_inds].contiguous()
+
+ patch_size = self.config.vision_config.patch_size
+ conv_kernel = torch.ones(
+ [1, 1, patch_size, patch_size], dtype=pixel_values.dtype, device=pixel_values.device
+ )
+ patches_subgrid = torch.nn.functional.conv2d(
+ pixel_attention_mask.unsqueeze(1).to(conv_kernel.dtype), conv_kernel, stride=patch_size
+ ).squeeze(1)
+ patch_attention_mask = torch.eq(patches_subgrid, (patch_size * patch_size))
+
+ # Get sequence from the vision encoder
+ image_hidden_states = self.model.vision_model(
+ pixel_values=pixel_values,
+ patch_attention_mask=patch_attention_mask,
+ ).last_hidden_state
+
+ # Modality projection & resampling
+ image_hidden_states = self.model.connector(
+ image_hidden_states, attention_mask=patch_attention_mask.view(pixel_values.size(0), -1)
+ )
+ pixel_values = None
+ pixel_attention_mask = None
+
+ model_inputs.update(
+ {
+ "position_ids": position_ids,
+ "past_key_values": past_key_values,
+ "use_cache": kwargs.get("use_cache"),
+ "attention_mask": attention_mask,
+ "pixel_values": pixel_values,
+ "pixel_attention_mask": pixel_attention_mask,
+ "image_hidden_states": image_hidden_states,
+ "token_idx": token_idx,
+ }
+ )
+ return model_inputs
diff --git a/tests/baselines/idefics2_8b.json b/tests/baselines/idefics2_8b.json
new file mode 100644
index 0000000000..f40995c72d
--- /dev/null
+++ b/tests/baselines/idefics2_8b.json
@@ -0,0 +1,38 @@
+{
+ "gaudi2": {
+ "image2text_lora_finetune": {
+ "num_train_epochs": 2,
+ "eval_batch_size": 4,
+ "distribution": {
+ "multi_card": {
+ "learning_rate": 5e-5,
+ "train_batch_size": 2,
+ "train_runtime": 286,
+ "train_samples_per_second": 11.8,
+ "eval_accuracy": 0.6,
+ "extra_arguments": [
+ "--bf16",
+ "--gradient_accumulation_steps 8",
+ "--eval_strategy no",
+ "--save_strategy no",
+ "--warmup_steps 50",
+ "--lr_scheduler_type constant",
+ "--max_grad_norm 0.3",
+ "--logging_steps 1",
+ "--use_hpu_graphs_for_inference",
+ "--lora_rank 8",
+ "--lora_alpha 8",
+ "--lora_dropout 0.1",
+ "--lora_target_modules '.*(text_model|modality_projection|perceiver_resampler).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$'",
+ "--low_cpu_mem_usage True",
+ "--adam_epsilon 1e-08",
+                        "--input_column_names image query",
+                        "--output_column_names answers",
+ "--remove_unused_columns False",
+ "--max_seq_length 512"
+ ]
+ }
+ }
+ }
+ }
+}
diff --git a/tests/test_examples.py b/tests/test_examples.py
index 4fd61d9b7f..f24d250880 100644
--- a/tests/test_examples.py
+++ b/tests/test_examples.py
@@ -201,6 +201,11 @@ def is_valid_model_type(model_type: str) -> bool:
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
["t5"],
),
+ "run_image2text_lora_finetune": _get_supported_models_for_script(
+ MODELS_TO_TEST_MAPPING,
+ MODEL_MAPPING,
+ ["idefics2"],
+ ),
}
@@ -230,16 +235,21 @@ def to_test(
"meta-llama/LlamaGuard-7b",
]
+ case_only_in_gaudi2 = [
+ "sft",
+ "dpo",
+ "reward_modeling",
+ "ppo",
+ "prompt_tuning",
+ "peft_poly",
+ "run_sequence_classification",
+ "run_image2text_lora_finetune",
+ ]
+
if (fsdp or fp8) and not IS_GAUDI2:
return False
elif (
- "sft" in example_name
- or "dpo" in example_name
- or "reward_modeling" in example_name
- or "ppo" in example_name
- or "prompt_tuning" in example_name
- or "peft_poly" in example_name
- or example_name == "run_sequence_classification"
+ any(case in example_name for case in case_only_in_gaudi2)
or task_name in ("llama-adapter", "vera", "ia3", "adalora", "ln_tuning", "mamamiya405/finred")
) and not IS_GAUDI2:
return False
@@ -922,6 +932,16 @@ class MultiCardCausalLanguageModelingLoRAFP8ExampleTester(
DATASET_NAME = "tatsu-lab/alpaca"
+class MultiCardImageToTextModelingLoRAExampleTester(
+ ExampleTesterBase,
+ metaclass=ExampleTestMeta,
+ example_name="run_image2text_lora_finetune",
+ multi_card=True,
+):
+ TASK_NAME = "image2text_lora_finetune"
+ DATASET_NAME = "nielsr/docvqa_1200_examples"
+
+
class MultiCardCausalLanguageModelingVeraExampleTester(
ExampleTesterBase, metaclass=ExampleTestMeta, example_name="run_lora_clm", multi_card=True
):
diff --git a/tests/test_image_to_text_example.py b/tests/test_image_to_text_example.py
index 81d6ec3d19..1cb8b95b33 100644
--- a/tests/test_image_to_text_example.py
+++ b/tests/test_image_to_text_example.py
@@ -19,6 +19,7 @@
("llava-hf/llava-v1.6-mistral-7b-hf", 1, 33.17984878151546),
("llava-hf/llava-v1.6-vicuna-7b-hf", 1, 35.00608681379742),
("llava-hf/llava-v1.6-vicuna-13b-hf", 1, 23.527610042925),
+ ("HuggingFaceM4/idefics2-8b", 1, 21.89944593215077),
],
"fp8": [
("llava-hf/llava-1.5-7b-hf", 1, 98.72578382705062),
diff --git a/tests/utils.py b/tests/utils.py
index daee779fa9..18c00a564c 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -61,6 +61,7 @@
"code_llama": [("codellama/CodeLlama-13b-Instruct-hf", "Habana/llama")],
"protst": [("mila-intel/protst-esm1b-for-sequential-classification", "Habana/gpt2")],
"qwen2": [("Qwen/Qwen2-7B", "Habana/qwen"), ("Qwen/Qwen2-72B", "Habana/qwen")],
+ "idefics2": [("HuggingFaceM4/idefics2-8b", "Habana/gpt2")],
}
MODELS_TO_TEST_FOR_QUESTION_ANSWERING = [